def fit(self, X, y):
    """Compute the Frechet mean of each class.

    Parameters
    ----------
    X : array-like, shape=[n_samples, dim] if point_type='vector'
                    shape=[n_samples, n, n] if point_type='matrix'
        Training data, where n_samples is the number of samples and
        dim (resp. n x n) is the shape of a single point.
    y : array-like, shape=[n_samples,]
        Training labels.

    Returns
    -------
    self : object
        The fitted classifier. After fitting, ``classes_`` holds the
        distinct labels returned by ``gs.unique(y)`` and
        ``mean_estimates_`` holds one Frechet mean per entry of
        ``classes_``, in the same order.
    """
    self.classes_ = gs.unique(y)
    # A single estimator is reused for every class; its estimate_ is read
    # immediately after each fit, before the next class is processed.
    mean_estimator = FrechetMean(
        metric=self.riemannian_metric, point_type=self.point_type
    )
    frechet_means = []
    for label in self.classes_:
        # `y == label` is already a boolean mask selecting this class's rows.
        X_c = X[y == label]
        frechet_means.append(mean_estimator.fit(X_c).estimate_)
    self.mean_estimates_ = gs.array(frechet_means)
    # Returning self follows the scikit-learn estimator convention and
    # allows chaining, e.g. `clf.fit(X, y).predict(X_new)`.
    return self
def test_unique(self):
    """Check that gs.unique keeps one copy of each repeated value.

    The input contains -1, 0 and 1 twice each; the result is compared
    against [-1, 0, 1].
    """
    expected = gs.array([-1, 0, 1])
    values = gs.array([-1, 0, 1, 1, 0, -1])
    self.assertAllClose(gs.unique(values), expected)