コード例 #1
0
    def fit(self, X, y):
        """Compute the Frechet mean of each class.

        Parameters
        ----------
        X : array-like, shape=[n_samples, dim] if point_type='vector',
            or shape=[n_samples, n, n] if point_type='matrix'
            Training data, where n_samples is the number of samples and
            each sample is a point of shape [dim] (resp. [n, n]).
        y : array-like, shape=[n_samples,]
            Training labels.

        Returns
        -------
        self : object
            The fitted estimator itself, following the scikit-learn
            convention so that calls can be chained after ``fit``.

        Notes
        -----
        Sets the attributes ``classes_`` (the distinct labels of ``y``, as
        returned by ``gs.unique``) and ``mean_estimates_`` (one Frechet mean
        per entry of ``classes_``, stacked in the same order).
        """
        self.classes_ = gs.unique(y)
        # A single estimator is reused for every class; each ``fit`` call
        # recomputes ``estimate_`` from that class's samples.
        # NOTE(review): assumes FrechetMean.fit rebinds ``estimate_`` rather
        # than mutating a shared array in place - TODO confirm.
        mean_estimator = FrechetMean(metric=self.riemannian_metric,
                                     point_type=self.point_type)
        frechet_means = []
        for c in self.classes_:
            # Boolean mask selecting the samples whose label is ``c``.
            X_c = X[gs.where(y == c, True, False)]
            frechet_means.append(mean_estimator.fit(X_c).estimate_)
        self.mean_estimates_ = gs.array(frechet_means)
        # Return the estimator itself, as scikit-learn's fit contract requires.
        return self
コード例 #2
0
 def test_unique(self):
     """Check that gs.unique collapses repeats into the sorted distinct values."""
     expected = gs.array([-1, 0, 1])
     values_with_repeats = gs.array([-1, 0, 1, 1, 0, -1])
     self.assertAllClose(gs.unique(values_with_repeats), expected)