Пример #1
0
 def setUp(self):
     """Seed ``gs.random`` and attach the manifolds used by the tests."""
     gs.random.seed(123)

     # (attribute name, instance) pairs; instances are built in this order
     # and then bound to ``self`` under the given names.
     fixtures = (
         ('sphere', Hypersphere(dim=4)),
         ('hyperbolic', Hyperboloid(dim=3)),
         ('euclidean', Euclidean(dim=2)),
         ('minkowski', Minkowski(dim=2)),
         ('so3', SpecialOrthogonal(n=3, point_type='vector')),
         ('so_matrix', SpecialOrthogonal(n=3)),
     )
     for attribute, manifold in fixtures:
         setattr(self, attribute, manifold)
Пример #2
0
    def setUp(self):
        """Seed ``gs.random`` and build the Minkowski space fixture.

        The space is ``Minkowski(2)``; its ``metric`` attribute is stored
        alongside it, together with the sample count and time-like index
        the tests read from ``self``.
        """
        gs.random.seed(1234)

        dimension = 2
        minkowski = Minkowski(dimension)

        self.time_like_dim = 0
        self.dimension = dimension
        self.space = minkowski
        self.metric = minkowski.metric
        self.n_samples = 10
Пример #3
0
 def setup_method(self):
     """Seed ``gs.random`` and attach the fixtures used by the tests."""
     gs.random.seed(123)

     # (attribute name, instance) pairs; instances are built in this order
     # and then bound to ``self`` under the given names.
     fixtures = (
         ("sphere", Hypersphere(dim=4)),
         ("hyperbolic", Hyperboloid(dim=3)),
         ("euclidean", Euclidean(dim=2)),
         ("minkowski", Minkowski(dim=2)),
         ("so3", SpecialOrthogonal(n=3, point_type="vector")),
         ("so_matrix", SpecialOrthogonal(n=3)),
         ("curves_2d", DiscreteCurves(R2)),
         ("elastic_metric", ElasticMetric(a=1, b=1, ambient_manifold=R2)),
     )
     for attribute, instance in fixtures:
         setattr(self, attribute, instance)
Пример #4
0
    def test_inner_product(self):
        """Compare ``self.metric``'s inner product with the Minkowski one.

        Two ambient vectors are projected onto the tangent space of
        ``self.space`` at a fixed base point; the inner product given by
        ``self.metric`` must match the one given by the metric of
        ``Minkowski(self.dimension + 1)`` on the same tangent vectors.
        """
        base_point = gs.array(
            [1.16563816, 0.36381045, -0.47000603, 0.07381469])
        ambient_vectors = (
            gs.array([10., 200., 1., 1.]),
            gs.array([11., 20., -21., 0.]),
        )
        tangent_vec_a, tangent_vec_b = [
            self.space.projection_to_tangent_space(
                vector=ambient_vector, base_point=base_point)
            for ambient_vector in ambient_vectors
        ]
        inner_product_args = (tangent_vec_a, tangent_vec_b, base_point)

        result = self.metric.inner_product(*inner_product_args)
        expected = Minkowski(self.dimension + 1).metric.inner_product(
            *inner_product_args)

        with self.session():
            self.assertAllClose(result, expected)
Пример #5
0
    def __init__(self, dimension, point_type='extrinsic', scale=1):
        """Initialise the space with a Minkowski embedding manifold.

        Parameters
        ----------
        dimension : int
            Must be a strictly positive ``int`` (checked with ``assert``);
            the embedding manifold is ``Minkowski(dimension + 1)``.
        point_type : str, optional
            Stored on the instance and passed to ``HyperbolicMetric``.
        scale : int, optional
            Stored on the instance and passed to ``HyperbolicMetric``.
        """
        assert isinstance(dimension, int) and dimension > 0

        minkowski = Minkowski(dimension + 1)
        super(Hyperbolic, self).__init__(
            dimension=dimension, embedding_manifold=minkowski)

        self.embedding_metric = self.embedding_manifold.metric
        self.point_type = point_type
        self.scale = scale
        self.metric = HyperbolicMetric(
            self.dimension, point_type, self.scale)

        # Coordinate converters, keyed by '<source>-<target>'.
        converters = {
            'ball-extrinsic':
                Hyperbolic._ball_to_extrinsic_coordinates,
            'extrinsic-ball':
                Hyperbolic._extrinsic_to_ball_coordinates,
            'intrinsic-extrinsic':
                Hyperbolic._intrinsic_to_extrinsic_coordinates,
            'extrinsic-intrinsic':
                Hyperbolic._extrinsic_to_intrinsic_coordinates,
            'extrinsic-half-plane':
                Hyperbolic._extrinsic_to_half_plane_coordinates,
            'half-plane-extrinsic':
                Hyperbolic._half_plane_to_extrinsic_coordinates,
            'extrinsic-extrinsic':
                Hyperbolic._extrinsic_to_extrinsic_coordinates,
        }
        self.transform_to = converters

        # Membership checks, keyed by coordinate type.
        self.belongs_to = {'ball': Hyperbolic._belongs_ball}
Пример #6
0
 def __init__(self, dim, coords_type='extrinsic', scale=1):
     """Initialise the hyperboloid with ``Minkowski(dim + 1)`` as embedding.

     ``coords_type`` is stored on the instance; the metric is a
     ``HyperboloidMetric`` built from the instance's ``dim``,
     ``coords_type`` and ``scale`` attributes.
     """
     minkowski = Minkowski(dim + 1)
     super(Hyperboloid, self).__init__(
         dim=dim, scale=scale, embedding_manifold=minkowski)

     self.coords_type = coords_type
     self.point_type = Hyperboloid.default_point_type
     self.embedding_metric = self.embedding_manifold.metric
     self.metric = HyperboloidMetric(
         self.dim, self.coords_type, self.scale)
Пример #7
0
 def test_inner_product_matrix_vector(self, default_point_type):
     """Check the metric matrix of a Euclidean(3) x Minkowski(3) product.

     The expected matrix is the 6 x 6 identity with entry (3, 3) set
     to -1. ``default_point_type`` is accepted but not used in the body.
     """
     factors = [Euclidean(3), Minkowski(3)]
     product = ProductManifold(manifolds=factors)
     base_point = product.random_point(1)

     result = product.metric.metric_matrix(base_point)

     expected = gs.eye(6)
     expected[3, 3] = -1
     self.assertAllClose(result, expected)
Пример #8
0
 def test_inner_product_is_minkowski_inner_product(
     self, dim, tangent_vec_a, tangent_vec_b, base_point
 ):
     """Check the metric's inner product equals the Minkowski one.

     The metric under test is ``self.metric(dim)``; the reference is the
     metric of ``Minkowski(dim + 1)``, evaluated on the same arguments.
     """
     inner_product_args = (tangent_vec_a, tangent_vec_b, base_point)

     result = self.metric(dim).inner_product(*inner_product_args)
     expected = Minkowski(dim + 1).metric.inner_product(*inner_product_args)

     self.assertAllClose(result, expected)
Пример #9
0
 def test_inner_product_matrix_matrix(self):
     """Check the inner-product matrix of a matrix-typed product manifold.

     The product is Euclidean(3) x Minkowski(3); the expected matrix is
     the 6 x 6 identity with entry (3, 3) set to -1.
     """
     factors = [Euclidean(3), Minkowski(3)]
     product = ProductManifold(
         manifolds=factors, default_point_type='matrix')
     base_point = product.random_uniform(1)

     result = product.metric.inner_product_matrix(base_point)

     expected = gs.eye(6)
     expected[3, 3] = -1
     self.assertAllClose(result, expected)
Пример #10
0
 def __init__(self, dim, coords_type='extrinsic', scale=1):
     """Initialise the hyperboloid as a level set in ``Minkowski(dim + 1)``.

     The parent initialiser receives the Minkowski metric's
     ``squared_norm`` as submersion with value -1, and its
     ``inner_product`` as tangent submersion.
     """
     minkowski = Minkowski(dim + 1)
     super(Hyperboloid, self).__init__(
         dim=dim,
         embedding_space=minkowski,
         submersion=minkowski.metric.squared_norm,
         value=-1.,
         tangent_submersion=minkowski.metric.inner_product,
         scale=scale,
     )

     self.coords_type = coords_type
     self.point_type = Hyperboloid.default_point_type
     self.metric = HyperboloidMetric(
         self.dim, self.coords_type, self.scale)
Пример #11
0
 def __init__(self, dim, coords_type='extrinsic', scale=1):
     """Set the hyperboloid's attributes without calling a parent init.

     The embedding manifold is ``Minkowski(dim + 1)``; the metric is a
     ``HyperboloidMetric`` built from the stored ``dim``, ``coords_type``
     and ``scale``.
     """
     # TODO (ninamiolane): Call __init__ from parent classes
     # and remove ignore rule of corresponding DeepSource issue
     self.scale = scale
     self.dim = dim
     self.coords_type = coords_type
     self.point_type = Hyperboloid.default_point_type

     minkowski = Minkowski(dim + 1)
     self.embedding_manifold = minkowski
     self.embedding_metric = minkowski.metric

     self.metric = HyperboloidMetric(
         self.dim, self.coords_type, self.scale)
Пример #12
0
 def __init__(self, dim, coords_type="extrinsic", scale=1, **kwargs):
     """Initialise the hyperboloid as a level set in ``Minkowski(dim + 1)``.

     A ``HyperboloidMetric(dim, coords_type, scale)`` is always built and
     used as the ``metric`` keyword unless the caller supplied one in
     ``kwargs``. Remaining ``kwargs`` are forwarded to the parent
     initialiser.
     """
     minkowski = Minkowski(dim + 1)
     default_metric = HyperboloidMetric(dim, coords_type, scale)
     kwargs.setdefault("metric", default_metric)

     super(Hyperboloid, self).__init__(
         dim=dim,
         embedding_space=minkowski,
         submersion=minkowski.metric.squared_norm,
         value=-1.0,
         tangent_submersion=minkowski.metric.inner_product,
         scale=scale,
         **kwargs
     )

     self.coords_type = coords_type
     self.point_type = Hyperboloid.default_point_type
Пример #13
0
class TestFrechetMean(geomstats.tests.TestCase):
    """Tests of ``FrechetMean`` and ``variance`` on several manifolds."""

    _multiprocess_can_split_ = True

    def setUp(self):
        """Seed ``gs.random`` and attach the manifolds shared by the tests."""
        gs.random.seed(123)

        # (attribute name, instance) pairs, built and attached in order.
        fixtures = (
            ('sphere', Hypersphere(dim=4)),
            ('hyperbolic', Hyperboloid(dim=3)),
            ('euclidean', Euclidean(dim=2)),
            ('minkowski', Minkowski(dim=2)),
            ('so3', SpecialOrthogonal(n=3, point_type='vector')),
            ('so_matrix', SpecialOrthogonal(n=3)),
        )
        for attribute, manifold in fixtures:
            setattr(self, attribute, manifold)

    def _summed_log_norms(self, estimator, n_tests):
        """Return, per trial, the norm of the summed logs at the fitted mean.

        Each trial draws two points with ``self.sphere.random_uniform``,
        fits ``estimator`` on them and takes the logs of both points at the
        estimate. The norms of the summed logs are stacked and returned.
        """
        norms = []
        for _ in range(n_tests):
            pair = self.sphere.random_uniform(n_samples=2)
            estimator.fit(pair)
            logs = self.sphere.metric.log(
                point=pair, base_point=estimator.estimate_)
            norms.append(gs.linalg.norm(logs[1, :] + logs[0, :]))
        return gs.stack(norms)

    def test_logs_at_mean_default_gradient_descent_sphere(self):
        """Logs of two sphere points at their 'default' mean are opposite."""
        n_tests = 10
        estimator = FrechetMean(
            metric=self.sphere.metric, method='default', lr=1.)

        result = self._summed_log_norms(estimator, n_tests)

        self.assertAllClose(gs.zeros(n_tests), result)

    def test_logs_at_mean_adaptive_gradient_descent_sphere(self):
        """Logs of two sphere points at their 'adaptive' mean are opposite."""
        n_tests = 10
        estimator = FrechetMean(metric=self.sphere.metric, method='adaptive')

        result = self._summed_log_norms(estimator, n_tests)

        self.assertAllClose(gs.zeros(n_tests), result)

    def test_estimate_shape_default_gradient_descent_sphere(self):
        """The 'default' estimate of two 5-vectors has shape ``(5,)``."""
        points = gs.array([
            gs.array([1., 0., 0., 0., 0.]),
            gs.array([0., 1., 0., 0., 0.])])

        estimator = FrechetMean(
            metric=self.sphere.metric, method='default', verbose=True)
        estimator.fit(points)

        expected_shape = (5,)
        self.assertAllClose(gs.shape(estimator.estimate_), expected_shape)

    def test_estimate_shape_adaptive_gradient_descent_sphere(self):
        """The 'adaptive' estimate of two 5-vectors has shape ``(5,)``."""
        points = gs.array([
            gs.array([1., 0., 0., 0., 0.]),
            gs.array([0., 1., 0., 0., 0.])])

        estimator = FrechetMean(metric=self.sphere.metric, method='adaptive')
        estimator.fit(points)

        expected_shape = (5,)
        self.assertAllClose(gs.shape(estimator.estimate_), expected_shape)

    def test_estimate_and_belongs_default_gradient_descent_sphere(self):
        """The 'default' estimate passes ``self.sphere.belongs``."""
        points = gs.array([
            gs.array([1., 0., 0., 0., 0.]),
            gs.array([0., 1., 0., 0., 0.])])

        estimator = FrechetMean(metric=self.sphere.metric, method='default')
        estimator.fit(points)

        belongs = self.sphere.belongs(estimator.estimate_)
        self.assertAllClose(belongs, True)

    def test_estimate_default_gradient_descent_so3(self):
        """Logs of two SO(3) vectors at their mean sum to zero."""
        metric = self.so3.bi_invariant_metric
        points = self.so3.random_uniform(2)

        estimator = FrechetMean(metric=metric, method='default', lr=1.)
        estimator.fit(points)

        logs = self.so3.bi_invariant_metric.log(points, estimator.estimate_)
        self.assertAllClose(gs.sum(logs, axis=0), gs.zeros_like(points[0]))

    def test_estimate_and_belongs_default_gradient_descent_so3(self):
        """The 'default' estimate of ten SO(3) vectors belongs to so3."""
        samples = self.so3.random_uniform(10)

        estimator = FrechetMean(
            metric=self.so3.bi_invariant_metric, method='default')
        estimator.fit(samples)

        belongs = self.so3.belongs(estimator.estimate_)
        self.assertAllClose(belongs, True)

    @geomstats.tests.np_and_tf_only
    def test_estimate_default_gradient_descent_so_matrix(self):
        """Logs of two SO(3) matrices at their mean sum to zero."""
        metric = self.so_matrix.bi_invariant_metric
        points = self.so_matrix.random_uniform(2)

        estimator = FrechetMean(metric=metric, method='default', lr=1.)
        estimator.fit(points)

        logs = self.so_matrix.bi_invariant_metric.log(
            points, estimator.estimate_)
        self.assertAllClose(
            gs.sum(logs, axis=0), gs.zeros_like(points[0]), atol=1e-5)

    @geomstats.tests.np_and_tf_only
    def test_estimate_and_belongs_default_gradient_descent_so_matrix(self):
        """The 'default' estimate of ten SO(3) matrices belongs to so_matrix."""
        samples = self.so_matrix.random_uniform(10)

        estimator = FrechetMean(
            metric=self.so_matrix.bi_invariant_metric, method='default')
        estimator.fit(samples)

        belongs = self.so_matrix.belongs(estimator.estimate_)
        self.assertAllClose(belongs, True)

    @geomstats.tests.np_and_tf_only
    def test_estimate_and_belongs_adaptive_gradient_descent_so_matrix(self):
        """The 'adaptive' estimate of ten SO(3) matrices belongs to so_matrix."""
        samples = self.so_matrix.random_uniform(10)

        estimator = FrechetMean(
            metric=self.so_matrix.bi_invariant_metric,
            method='adaptive',
            verbose=True,
            lr=.5)
        estimator.fit(samples)

        self.assertTrue(self.so_matrix.belongs(estimator.estimate_))

    @geomstats.tests.np_and_tf_only
    def test_estimate_and_coincide_default_so_vec_and_mat(self):
        """Matrix and rotation-vector estimates agree on the same samples."""
        matrices = self.so_matrix.random_uniform(3)

        matrix_estimator = FrechetMean(
            metric=self.so_matrix.bi_invariant_metric, method='default')
        matrix_estimator.fit(matrices)
        expected = matrix_estimator.estimate_

        vector_estimator = FrechetMean(
            metric=self.so3.bi_invariant_metric, method='default')
        vector_estimator.fit(self.so3.rotation_vector_from_matrix(matrices))
        result = self.so3.matrix_from_rotation_vector(
            vector_estimator.estimate_)

        self.assertAllClose(result, expected)

    def test_estimate_and_belongs_adaptive_gradient_descent_sphere(self):
        """The 'adaptive' estimate passes ``self.sphere.belongs``."""
        points = gs.array([
            gs.array([1., 0., 0., 0., 0.]),
            gs.array([0., 1., 0., 0., 0.])])

        estimator = FrechetMean(metric=self.sphere.metric, method='adaptive')
        estimator.fit(points)

        belongs = self.sphere.belongs(estimator.estimate_)
        self.assertAllClose(belongs, True)

    def test_variance_sphere(self):
        """Variance of a repeated sphere point about itself is zero."""
        pole = gs.array([0., 0., 0., 0., 1.])
        points = gs.array([pole, pole])

        result = variance(
            points, base_point=pole, metric=self.sphere.metric)

        self.assertAllClose(gs.array(0.), result)

    def test_estimate_default_gradient_descent_sphere(self):
        """The 'default' estimate of a repeated sphere point is that point."""
        pole = gs.array([0., 0., 0., 0., 1.])
        points = gs.array([pole, pole])

        estimator = FrechetMean(metric=self.sphere.metric, method='default')
        estimator.fit(X=points)

        self.assertAllClose(pole, estimator.estimate_)

    def test_estimate_adaptive_gradient_descent_sphere(self):
        """The 'adaptive' estimate of a repeated sphere point is that point."""
        pole = gs.array([0., 0., 0., 0., 1.])
        points = gs.array([pole, pole])

        estimator = FrechetMean(metric=self.sphere.metric, method='adaptive')
        estimator.fit(X=points)

        self.assertAllClose(pole, estimator.estimate_)

    def test_estimate_spd(self):
        """The estimate of a repeated SPD matrix is that matrix."""
        sample = SPDMatrices(3).random_point()
        points = gs.array([sample, sample])

        estimator = FrechetMean(
            metric=SPDMetricAffine(3), point_type='matrix')
        estimator.fit(X=points)

        self.assertAllClose(sample, estimator.estimate_)

    def test_variance_hyperbolic(self):
        """Variance of a repeated hyperbolic point about itself is zero."""
        sample = gs.array([2., 1., 1., 1.])
        points = gs.array([sample, sample])

        result = variance(
            points, base_point=sample, metric=self.hyperbolic.metric)

        self.assertAllClose(result, gs.array(0.))

    def test_estimate_hyperbolic(self):
        """The estimate of a repeated hyperbolic point is that point."""
        sample = gs.array([2., 1., 1., 1.])
        points = gs.array([sample, sample])

        estimator = FrechetMean(metric=self.hyperbolic.metric)
        estimator.fit(X=points)

        self.assertAllClose(estimator.estimate_, sample)

    def test_estimate_and_belongs_hyperbolic(self):
        """The estimate of three random points belongs to the hyperboloid."""
        samples = [self.hyperbolic.random_point() for _ in range(3)]
        points = gs.stack(samples, axis=0)

        estimator = FrechetMean(metric=self.hyperbolic.metric)
        estimator.fit(X=points)

        belongs = self.hyperbolic.belongs(estimator.estimate_)
        self.assertAllClose(belongs, True)

    def test_mean_euclidean_shape(self):
        """The Euclidean estimate of 2-vectors has shape ``(2,)``."""
        sample = gs.array([1., 4.])

        estimator = FrechetMean(metric=self.euclidean.metric)
        estimator.fit([sample] * 3)

        expected_shape = (2,)
        self.assertAllClose(gs.shape(estimator.estimate_), expected_shape)

    def test_mean_euclidean(self):
        """Check unweighted and weighted Euclidean estimates."""
        # Unweighted: three copies of one point give back that point.
        sample = gs.array([1., 4.])
        estimator = FrechetMean(metric=self.euclidean.metric)
        estimator.fit([sample] * 3)

        self.assertAllClose(estimator.estimate_, sample)

        # Weighted: compared against hard-coded expected coordinates.
        points = gs.array([
            [1., 2.],
            [2., 3.],
            [3., 4.],
            [4., 5.]])
        weights = [1., 2., 1., 2.]

        estimator = FrechetMean(metric=self.euclidean.metric)
        estimator.fit(points, weights=weights)

        self.assertAllClose(
            estimator.estimate_, gs.array([16. / 6., 22. / 6.]))

    def test_variance_euclidean(self):
        """Weighted Euclidean variance about the origin."""
        points = gs.array([
            [1., 2.],
            [2., 3.],
            [3., 4.],
            [4., 5.]])
        weights = gs.array([1., 2., 1., 2.])

        result = variance(
            points,
            weights=weights,
            base_point=gs.zeros(2),
            metric=self.euclidean.metric)

        # The expected value is the weighted average of the squared norms.
        expected = gs.array((1 * 5. + 2 * 13. + 1 * 25. + 2 * 41.) / 6.)
        self.assertAllClose(result, expected)

    def test_mean_matrices_shape(self):
        """The estimate of 2 x 2 matrices has shape ``(2, 2)``."""
        n_rows, n_cols = (2, 2)
        sample = gs.array([
            [1., 4.],
            [2., 3.]])

        estimator = FrechetMean(
            metric=MatricesMetric(n_rows, n_cols), point_type='matrix')
        estimator.fit([sample] * 3)

        self.assertAllClose(
            gs.shape(estimator.estimate_), (n_rows, n_cols))

    def test_mean_matrices(self):
        """The estimate of a repeated 2 x 2 matrix is that matrix."""
        n_rows, n_cols = (2, 2)
        sample = gs.array([
            [1., 4.],
            [2., 3.]])

        estimator = FrechetMean(
            metric=MatricesMetric(n_rows, n_cols), point_type='matrix')
        estimator.fit([sample] * 3)

        self.assertAllClose(estimator.estimate_, sample)

    def test_mean_minkowski_shape(self):
        """The Minkowski estimate of 2-vectors has shape ``(2,)``."""
        sample = gs.array([2., -math.sqrt(3)])

        estimator = FrechetMean(metric=self.minkowski.metric)
        estimator.fit([sample] * 3)

        expected_shape = (2,)
        self.assertAllClose(gs.shape(estimator.estimate_), expected_shape)

    def test_mean_minkowski(self):
        """Check unweighted and weighted Minkowski estimates."""
        # Unweighted: three copies of one point give back that point.
        sample = gs.array([2., -math.sqrt(3)])
        estimator = FrechetMean(metric=self.minkowski.metric)
        estimator.fit([sample] * 3)

        self.assertAllClose(estimator.estimate_, sample)

        # Weighted: only membership of the estimate is checked.
        points = gs.array([
            [1., 0.],
            [2., math.sqrt(3)],
            [3., math.sqrt(8)],
            [4., math.sqrt(24)]])
        weights = gs.array([1., 2., 1., 2.])

        estimator = FrechetMean(metric=self.minkowski.metric)
        estimator.fit(points, weights=weights)

        belongs = self.minkowski.belongs(estimator.estimate_)
        self.assertAllClose(belongs, gs.array(True))

    def test_variance_minkowski(self):
        """Weighted Minkowski variance about ``[-1, 0]`` is non-zero."""
        points = gs.array([
            [1., 0.],
            [2., math.sqrt(3)],
            [3., math.sqrt(8)],
            [4., math.sqrt(24)]])
        weights = gs.array([1., 2., 1., 2.])

        var = variance(
            points,
            weights=weights,
            base_point=gs.array([-1., 0.]),
            metric=self.minkowski.metric)

        # Only non-vanishing is asserted, not the numerical value.
        is_nonzero = var != 0
        self.assertAllClose(is_nonzero, True)

    def test_one_point(self):
        """Fitting on a single sphere point returns that point."""
        pole = gs.array([0., 0., 0., 0., 1.])

        for method in ('default', 'frechet-poincare-ball'):
            estimator = FrechetMean(metric=self.sphere.metric, method=method)
            estimator.fit(X=pole)

            self.assertAllClose(pole, estimator.estimate_)
Пример #14
0
 def setUp(self):
     """Seed ``gs.random`` and attach the manifolds used by the tests."""
     gs.random.seed(123)

     # (attribute name, instance) pairs; instances are built in this order
     # and then bound to ``self`` under the given names.
     fixtures = (
         ('sphere', Hypersphere(dim=4)),
         ('hyperbolic', Hyperboloid(dim=3)),
         ('euclidean', Euclidean(dim=2)),
         ('minkowski', Minkowski(dim=2)),
     )
     for attribute, manifold in fixtures:
         setattr(self, attribute, manifold)
Пример #15
0
 def setUp(self):
     """Attach the manifolds used by the tests to the test instance."""
     # (attribute name, instance) pairs; instances are built in this order
     # and then bound to ``self`` under the given names.
     fixtures = (
         ('sphere', Hypersphere(dimension=4)),
         ('hyperbolic', Hyperbolic(dimension=3)),
         ('euclidean', Euclidean(dimension=2)),
         ('minkowski', Minkowski(dimension=2)),
     )
     for attribute, manifold in fixtures:
         setattr(self, attribute, manifold)
Пример #16
0
class TestFrechetMean(geomstats.tests.TestCase):
    """Tests of ``FrechetMean`` and ``variance`` on several manifolds."""

    _multiprocess_can_split_ = True

    def setUp(self):
        """Attach the manifolds shared by the tests to the test instance."""
        # (attribute name, instance) pairs, built and attached in order.
        fixtures = (
            ('sphere', Hypersphere(dimension=4)),
            ('hyperbolic', Hyperbolic(dimension=3)),
            ('euclidean', Euclidean(dimension=2)),
            ('minkowski', Minkowski(dimension=2)),
        )
        for attribute, manifold in fixtures:
            setattr(self, attribute, manifold)

    @geomstats.tests.np_only
    def test_adaptive_gradient_descent_sphere(self):
        """Logs of two sphere points at their adaptive mean are opposite."""
        n_tests = 100
        norms = gs.zeros(n_tests)
        expected = gs.zeros(n_tests)

        for trial in range(n_tests):
            # Draw two points, compute their mean, and record the norm of
            # the sum of their logs at that mean.
            pair = self.sphere.random_uniform(n_samples=2)
            mean = _adaptive_gradient_descent(
                points=pair, metric=self.sphere.metric)

            logs = self.sphere.metric.log(point=pair, base_point=mean)
            norms[trial] = gs.linalg.norm(logs[1, :] + logs[0, :])

        self.assertAllClose(expected, norms, rtol=1e-10, atol=1e-10)

    @geomstats.tests.np_and_pytorch_only
    def test_estimate_and_belongs_sphere(self):
        """The estimate of two sphere points passes ``sphere.belongs``."""
        point_a = gs.array([1., 0., 0., 0., 0.])
        point_b = gs.array([0., 1., 0., 0., 0.])
        points = gs.zeros((2, point_a.shape[0]))
        for row, point in enumerate((point_a, point_b)):
            points[row, :] = point

        estimator = FrechetMean(metric=self.sphere.metric)
        estimator.fit(points)

        belongs = self.sphere.belongs(estimator.estimate_)
        self.assertAllClose(belongs, gs.array([[True]]))

    @geomstats.tests.np_and_pytorch_only
    def test_variance_sphere(self):
        """Variance of a repeated sphere point about itself is zero."""
        pole = gs.array([0., 0., 0., 0., 1.])
        points = gs.zeros((2, pole.shape[0]))
        for row in range(2):
            points[row, :] = pole

        result = variance(
            points, base_point=pole, metric=self.sphere.metric)

        self.assertAllClose(helper.to_scalar(0.), result)

    @geomstats.tests.np_and_pytorch_only
    def test_estimate_sphere(self):
        """The estimate of a repeated sphere point is that point."""
        pole = gs.array([0., 0., 0., 0., 1.])
        points = gs.zeros((2, pole.shape[0]))
        for row in range(2):
            points[row, :] = pole

        estimator = FrechetMean(metric=self.sphere.metric)
        estimator.fit(X=points)

        result = estimator.estimate_
        self.assertAllClose(helper.to_vector(pole), result)

    @geomstats.tests.np_and_tf_only
    def test_variance_hyperbolic(self):
        """Variance of a repeated hyperbolic point about itself is zero."""
        sample = gs.array([2., 1., 1., 1.])
        points = gs.array([sample, sample])

        result = variance(
            points, base_point=sample, metric=self.hyperbolic.metric)

        self.assertAllClose(result, helper.to_scalar(0.))

    @geomstats.tests.np_and_tf_only
    def test_estimate_hyperbolic(self):
        """The estimate of a repeated hyperbolic point is that point."""
        sample = gs.array([2., 1., 1., 1.])
        points = gs.array([sample, sample])

        estimator = FrechetMean(metric=self.hyperbolic.metric)
        estimator.fit(X=points)

        result = estimator.estimate_
        self.assertAllClose(result, helper.to_vector(sample))

    @geomstats.tests.np_and_tf_only
    def test_estimate_and_belongs_hyperbolic(self):
        """The estimate of three random points passes ``hyperbolic.belongs``."""
        samples = [self.hyperbolic.random_uniform() for _ in range(3)]
        points = gs.concatenate(samples, axis=0)

        estimator = FrechetMean(metric=self.hyperbolic.metric)
        estimator.fit(X=points)

        belongs = self.hyperbolic.belongs(estimator.estimate_)
        self.assertAllClose(belongs, gs.array([[True]]))

    def test_mean_euclidean(self):
        """Check unweighted and weighted Euclidean estimates."""
        # Unweighted: three copies of one point give back that point.
        sample = gs.array([[1., 4.]])
        estimator = FrechetMean(metric=self.euclidean.metric)
        estimator.fit([sample] * 3)

        result = estimator.estimate_
        self.assertAllClose(result, helper.to_vector(sample))

        # Weighted: compared against hard-coded expected coordinates.
        points = gs.array([[1., 2.], [2., 3.], [3., 4.], [4., 5.]])
        weights = gs.array([1., 2., 1., 2.])

        estimator = FrechetMean(metric=self.euclidean.metric)
        estimator.fit(points, weights=weights)

        result = estimator.estimate_
        expected = helper.to_vector(gs.array([16. / 6., 22. / 6.]))
        self.assertAllClose(result, expected)

    def test_variance_euclidean(self):
        """Weighted Euclidean variance about the origin."""
        points = gs.array([[1., 2.], [2., 3.], [3., 4.], [4., 5.]])
        weights = gs.array([1., 2., 1., 2.])

        result = variance(
            points,
            weights=weights,
            base_point=gs.zeros(2),
            metric=self.euclidean.metric)

        # The expected value is the weighted average of the squared norms.
        expected = helper.to_scalar(
            (1 * 5. + 2 * 13. + 1 * 25. + 2 * 41.) / 6.)
        self.assertAllClose(result, expected)

    def test_mean_minkowski(self):
        """Check unweighted and weighted Minkowski estimates."""
        # Unweighted: three copies of one point give back that point.
        sample = gs.array([[2., -math.sqrt(3)]])
        estimator = FrechetMean(metric=self.minkowski.metric)
        estimator.fit([sample] * 3)

        result = estimator.estimate_
        self.assertAllClose(result, helper.to_vector(sample))

        # Weighted: only membership of the estimate is checked.
        points = gs.array([
            [1., 0.],
            [2., math.sqrt(3)],
            [3., math.sqrt(8)],
            [4., math.sqrt(24)]])
        weights = gs.array([1., 2., 1., 2.])

        estimator = FrechetMean(metric=self.minkowski.metric)
        estimator.fit(points, weights=weights)

        belongs = self.minkowski.belongs(estimator.estimate_)
        self.assertAllClose(belongs, gs.array([[True]]))

    def test_variance_minkowski(self):
        """Weighted Minkowski variance about ``[-1, 0]`` is non-zero."""
        points = gs.array([
            [1., 0.],
            [2., math.sqrt(3)],
            [3., math.sqrt(8)],
            [4., math.sqrt(24)]])
        weights = gs.array([1., 2., 1., 2.])

        var = variance(
            points,
            weights=weights,
            base_point=gs.array([-1., 0.]),
            metric=self.minkowski.metric)

        # Only non-vanishing is asserted, not the numerical value.
        is_nonzero = helper.to_scalar(var != 0)
        expected = helper.to_scalar(gs.array([True]))
        self.assertAllClose(is_nonzero, expected)
Пример #17
0
from geomstats.geometry.minkowski import Minkowski
from geomstats.geometry.product_manifold import (
    NFoldManifold,
    NFoldMetric,
    ProductManifold,
)
from geomstats.geometry.product_riemannian_metric import ProductRiemannianMetric
from geomstats.geometry.special_orthogonal import SpecialOrthogonal
from tests.conftest import Parametrizer
from tests.data_generation import _ManifoldTestData, _RiemannianMetricTestData
from tests.geometry_test_cases import ManifoldTestCase, RiemannianMetricTestCase

# First smoke configuration: a Hypersphere and a Hyperboloid, both with dim=2.
smoke_manifolds_1 = [Hypersphere(dim=2), Hyperboloid(dim=2)]
# Metrics of the same two manifold types; note they are read from newly
# constructed instances, not from the objects in smoke_manifolds_1.
smoke_metrics_1 = [Hypersphere(dim=2).metric, Hyperboloid(dim=2).metric]

# Second smoke configuration: Euclidean(3) and Minkowski(3).
smoke_manifolds_2 = [Euclidean(3), Minkowski(3)]
# As above, the metrics come from separately constructed instances.
smoke_metrics_2 = [Euclidean(3).metric, Minkowski(3).metric]


class TestProductManifold(ManifoldTestCase, metaclass=Parametrizer):
    space = ProductManifold
    skip_test_random_tangent_vec_is_tangent = True
    skip_test_projection_belongs = True

    class ProductManifoldTestData(_ManifoldTestData):

        n_list = random.sample(range(2, 4), 2)
        default_point_list = ["vector", "matrix"]
        manifolds_list = [[Hypersphere(dim=n),
                           Hyperboloid(dim=n)] for n in n_list]
        space_args_list = [(manifold, None, default_point)
Пример #18
0
class TestFrechetMean(geomstats.tests.TestCase):
    _multiprocess_can_split_ = True

    def setUp(self):
        """Seed the backend RNG and build the manifolds used by the tests."""
        # Fixed seed so that randomly drawn points are reproducible.
        gs.random.seed(123)
        self.sphere = Hypersphere(dim=4)
        self.hyperbolic = Hyperboloid(dim=3)
        self.euclidean = Euclidean(dim=2)
        self.minkowski = Minkowski(dim=2)

    @geomstats.tests.np_only
    def test_logs_at_mean_default_gradient_descent_sphere(self):
        """Check that logs of two points at their 'default' mean cancel.

        For each trial, two random sphere points are drawn, their Frechet
        mean is estimated, and the norm of the sum of the two log vectors
        at that mean is recorded; every norm must be close to zero.
        """
        n_trials = 100
        frechet_mean = FrechetMean(metric=self.sphere.metric, method='default')

        residuals = gs.zeros(n_trials)
        for trial in range(n_trials):
            pair = self.sphere.random_uniform(n_samples=2)
            frechet_mean.fit(pair)
            center = frechet_mean.estimate_

            tangent_vecs = self.sphere.metric.log(
                point=pair, base_point=center)
            residuals[trial] = gs.linalg.norm(
                tangent_vecs[1, :] + tangent_vecs[0, :])

        self.assertAllClose(
            gs.zeros(n_trials), residuals, rtol=1e-10, atol=1e-10)

    @geomstats.tests.np_only
    def test_logs_at_mean_adaptive_gradient_descent_sphere(self):
        """Check that logs of two points at their 'adaptive' mean cancel.

        For each trial, two random sphere points are drawn, their Frechet
        mean is estimated, and the norm of the sum of the two log vectors
        at that mean is recorded; every norm must be close to zero.
        """
        n_trials = 100
        frechet_mean = FrechetMean(
            metric=self.sphere.metric, method='adaptive')

        residuals = gs.zeros(n_trials)
        for trial in range(n_trials):
            pair = self.sphere.random_uniform(n_samples=2)
            frechet_mean.fit(pair)
            center = frechet_mean.estimate_

            tangent_vecs = self.sphere.metric.log(
                point=pair, base_point=center)
            residuals[trial] = gs.linalg.norm(
                tangent_vecs[1, :] + tangent_vecs[0, :])

        self.assertAllClose(
            gs.zeros(n_trials), residuals, rtol=1e-10, atol=1e-10)

    @geomstats.tests.np_and_pytorch_only
    def test_estimate_shape_default_gradient_descent_sphere(self):
        """Check the shape of the 'default' estimate for two sphere points."""
        expected_shape = (5, )
        first = gs.array([1., 0., 0., 0., 0.])
        second = gs.array([0., 1., 0., 0., 0.])
        samples = gs.array([first, second])

        estimator = FrechetMean(metric=self.sphere.metric, method='default')
        estimator.fit(samples)

        self.assertAllClose(gs.shape(estimator.estimate_), expected_shape)

    @geomstats.tests.np_and_pytorch_only
    def test_estimate_shape_adaptive_gradient_descent_sphere(self):
        """Check the shape of the 'adaptive' estimate for two sphere points."""
        expected_shape = (5, )
        first = gs.array([1., 0., 0., 0., 0.])
        second = gs.array([0., 1., 0., 0., 0.])
        samples = gs.array([first, second])

        estimator = FrechetMean(metric=self.sphere.metric, method='adaptive')
        estimator.fit(samples)

        self.assertAllClose(gs.shape(estimator.estimate_), expected_shape)

    @geomstats.tests.np_and_pytorch_only
    def test_estimate_and_belongs_default_gradient_descent_sphere(self):
        """Check that the 'default' estimate belongs to the sphere."""
        first = gs.array([1., 0., 0., 0., 0.])
        second = gs.array([0., 1., 0., 0., 0.])
        samples = gs.array([first, second])

        estimator = FrechetMean(metric=self.sphere.metric, method='default')
        estimator.fit(samples)

        on_sphere = self.sphere.belongs(estimator.estimate_)
        self.assertAllClose(on_sphere, True)

    @geomstats.tests.np_and_pytorch_only
    def test_estimate_and_belongs_adaptive_gradient_descent_sphere(self):
        """Check that the 'adaptive' estimate belongs to the sphere."""
        first = gs.array([1., 0., 0., 0., 0.])
        second = gs.array([0., 1., 0., 0., 0.])
        samples = gs.array([first, second])

        estimator = FrechetMean(metric=self.sphere.metric, method='adaptive')
        estimator.fit(samples)

        on_sphere = self.sphere.belongs(estimator.estimate_)
        self.assertAllClose(on_sphere, True)

    @geomstats.tests.np_and_pytorch_only
    def test_variance_sphere(self):
        """Check that a repeated point has zero variance about itself."""
        pole = gs.array([0., 0., 0., 0., 1.])
        samples = gs.array([pole, pole])

        computed = variance(
            samples, base_point=pole, metric=self.sphere.metric)

        self.assertAllClose(gs.array(0.), computed)

    @geomstats.tests.np_and_pytorch_only
    def test_estimate_default_gradient_descent_sphere(self):
        """Check that the 'default' mean of a repeated point is that point."""
        pole = gs.array([0., 0., 0., 0., 1.])
        samples = gs.array([pole, pole])

        estimator = FrechetMean(metric=self.sphere.metric, method='default')
        estimator.fit(X=samples)

        self.assertAllClose(pole, estimator.estimate_)

    @geomstats.tests.np_and_pytorch_only
    def test_estimate_adaptive_gradient_descent_sphere(self):
        """Check that the 'adaptive' mean of a repeated point is that point."""
        pole = gs.array([0., 0., 0., 0., 1.])
        samples = gs.array([pole, pole])

        estimator = FrechetMean(metric=self.sphere.metric, method='adaptive')
        estimator.fit(X=samples)

        self.assertAllClose(pole, estimator.estimate_)

    @geomstats.tests.np_and_pytorch_only
    def test_estimate_transform_sphere(self):
        """Check that ``transform`` sends a repeated sphere point to zeros."""
        pole = gs.array([0., 0., 0., 0., 1.])
        samples = gs.array([pole, pole])

        estimator = FrechetMean(metric=self.sphere.metric)
        estimator.fit(X=samples)
        transformed = estimator.transform(samples)

        self.assertAllClose(gs.zeros_like(samples), transformed)

    @geomstats.tests.np_and_pytorch_only
    def test_estimate_transform_spd(self):
        """Check that ``transform`` sends a repeated SPD matrix to zeros.

        The transformed output is compared with a zero array of shape
        (2, 6).
        """
        spd_matrix = SPDMatrices(3).random_uniform()
        samples = gs.array([spd_matrix, spd_matrix])

        estimator = FrechetMean(metric=SPDMetricAffine(3), point_type='matrix')
        estimator.fit(X=samples)
        transformed = estimator.transform(samples)

        self.assertAllClose(gs.zeros((2, 6)), transformed)

    def test_fit_transform_hyperbolic(self):
        """Check that ``fit_transform`` of a repeated point yields zeros."""
        repeated = gs.array([2., 1., 1., 1.])
        samples = gs.array([repeated, repeated])

        estimator = FrechetMean(metric=self.hyperbolic.metric)
        transformed = estimator.fit_transform(X=samples)

        self.assertAllClose(gs.zeros_like(samples), transformed)

    def test_inverse_transform_hyperbolic(self):
        """Check that ``inverse_transform`` undoes ``fit_transform``."""
        samples = self.hyperbolic.random_uniform(10)

        estimator = FrechetMean(metric=self.hyperbolic.metric)
        transformed = estimator.fit_transform(X=samples)
        recovered = estimator.inverse_transform(transformed)

        self.assertAllClose(samples, recovered)

    @geomstats.tests.np_and_pytorch_only
    def test_inverse_transform_spd(self):
        """Check that ``inverse_transform`` undoes ``fit_transform`` on SPD."""
        samples = SPDMatrices(3).random_uniform(10)

        estimator = FrechetMean(metric=SPDMetricAffine(3), point_type='matrix')
        transformed = estimator.fit_transform(X=samples)
        recovered = estimator.inverse_transform(transformed)

        self.assertAllClose(samples, recovered)

    @geomstats.tests.np_and_tf_only
    def test_variance_hyperbolic(self):
        """Check that a repeated point has zero variance about itself."""
        repeated = gs.array([2., 1., 1., 1.])
        samples = gs.array([repeated, repeated])

        computed = variance(
            samples,
            base_point=repeated,
            metric=self.hyperbolic.metric)

        self.assertAllClose(computed, gs.array(0.))

    @geomstats.tests.np_and_tf_only
    def test_estimate_hyperbolic(self):
        """Check that the mean of a repeated point is that point."""
        repeated = gs.array([2., 1., 1., 1.])
        samples = gs.array([repeated, repeated])

        estimator = FrechetMean(metric=self.hyperbolic.metric)
        estimator.fit(X=samples)

        self.assertAllClose(estimator.estimate_, repeated)

    @geomstats.tests.np_and_tf_only
    def test_estimate_and_belongs_hyperbolic(self):
        """Check that the mean of three random points stays on the manifold."""
        # Three separate draws, in this order, to keep the RNG sequence.
        first = self.hyperbolic.random_uniform()
        second = self.hyperbolic.random_uniform()
        third = self.hyperbolic.random_uniform()
        samples = gs.stack([first, second, third], axis=0)

        estimator = FrechetMean(metric=self.hyperbolic.metric)
        estimator.fit(X=samples)

        on_manifold = self.hyperbolic.belongs(estimator.estimate_)
        self.assertAllClose(on_manifold, True)

    def test_mean_euclidean_shape(self):
        """Check the shape of the Euclidean mean of a repeated 2D point."""
        expected_shape = (2, )
        repeated = gs.array([1., 4.])

        estimator = FrechetMean(metric=self.euclidean.metric)
        estimator.fit([repeated] * 3)

        self.assertAllClose(gs.shape(estimator.estimate_), expected_shape)

    def test_mean_euclidean(self):
        """Check Euclidean Frechet means, unweighted and weighted.

        First a repeated point must be its own mean; then a weighted
        sample is compared with a precomputed weighted average.
        """
        repeated = gs.array([1., 4.])

        estimator = FrechetMean(metric=self.euclidean.metric)
        estimator.fit([repeated] * 3)

        self.assertAllClose(estimator.estimate_, repeated)

        samples = gs.array([[1., 2.], [2., 3.], [3., 4.], [4., 5.]])
        sample_weights = gs.array([1., 2., 1., 2.])

        weighted_estimator = FrechetMean(metric=self.euclidean.metric)
        weighted_estimator.fit(samples, weights=sample_weights)

        weighted_mean = gs.array([16. / 6., 22. / 6.])
        self.assertAllClose(weighted_estimator.estimate_, weighted_mean)

    def test_variance_euclidean(self):
        """Check the weighted Euclidean variance about the origin.

        The expected value is the weighted average of the squared norms
        of the samples.
        """
        samples = gs.array([[1., 2.], [2., 3.], [3., 4.], [4., 5.]])
        sample_weights = gs.array([1., 2., 1., 2.])
        origin = gs.zeros(2)

        computed = variance(
            samples,
            weights=sample_weights,
            base_point=origin,
            metric=self.euclidean.metric)

        # Squared norms of the samples are 5, 13, 25 and 41; the weights
        # sum to 6.
        weighted_sq_norms = 1 * 5. + 2 * 13. + 1 * 25. + 2 * 41.
        expected = gs.array(weighted_sq_norms / 6.)

        self.assertAllClose(computed, expected)

    def test_mean_matrices_shape(self):
        """Check the shape of the mean of a repeated 2x2 matrix."""
        n_rows, n_cols = (2, 2)
        matrix = gs.array([[1., 4.], [2., 3.]])

        matrices_metric = MatricesMetric(n_rows, n_cols)
        estimator = FrechetMean(metric=matrices_metric, point_type='matrix')
        estimator.fit([matrix] * 3)

        self.assertAllClose(
            gs.shape(estimator.estimate_), (n_rows, n_cols))

    def test_mean_matrices(self):
        """Check that the mean of a repeated 2x2 matrix is that matrix."""
        n_rows, n_cols = (2, 2)
        matrix = gs.array([[1., 4.], [2., 3.]])

        matrices_metric = MatricesMetric(n_rows, n_cols)
        estimator = FrechetMean(metric=matrices_metric, point_type='matrix')
        estimator.fit([matrix] * 3)

        self.assertAllClose(estimator.estimate_, matrix)

    def test_mean_minkowski_shape(self):
        """Check the shape of the Minkowski mean of a repeated 2D point."""
        expected_shape = (2, )
        repeated = gs.array([2., -math.sqrt(3)])
        samples = [repeated] * 3

        estimator = FrechetMean(metric=self.minkowski.metric)
        estimator.fit(samples)

        self.assertAllClose(gs.shape(estimator.estimate_), expected_shape)

    def test_mean_minkowski(self):
        """Check Minkowski Frechet means, unweighted and weighted.

        First a repeated point must be its own mean; then the weighted
        mean of four points must belong to the Minkowski space.
        """
        repeated = gs.array([2., -math.sqrt(3)])
        samples = [repeated] * 3

        estimator = FrechetMean(metric=self.minkowski.metric)
        estimator.fit(samples)

        self.assertAllClose(estimator.estimate_, repeated)

        weighted_samples = gs.array([
            [1., 0.],
            [2., math.sqrt(3)],
            [3., math.sqrt(8)],
            [4., math.sqrt(24)],
        ])
        sample_weights = gs.array([1., 2., 1., 2.])

        weighted_estimator = FrechetMean(metric=self.minkowski.metric)
        weighted_estimator.fit(weighted_samples, weights=sample_weights)

        in_space = self.minkowski.belongs(weighted_estimator.estimate_)
        self.assertAllClose(in_space, gs.array(True))

    def test_variance_minkowski(self):
        """Check the weighted variance of points under the Minkowski metric.

        The variance about ``base_point`` is compared with the weighted
        average of the Minkowski squared distances from ``base_point`` to
        each sample, in the same way ``test_variance_euclidean`` does for
        the Euclidean metric. Asserting only ``var != 0`` would accept any
        wrong non-zero value.
        """
        points = gs.array([[1., 0.], [2., math.sqrt(3)], [3., math.sqrt(8)],
                           [4., math.sqrt(24)]])
        weights = gs.array([1., 2., 1., 2.])
        base_point = gs.array([-1., 0.])
        result = variance(points,
                          weights=weights,
                          base_point=base_point,
                          metric=self.minkowski.metric)
        # With the (-, +) signature, the squared distances from base_point
        # to the four samples are:
        #   [2, 0]        -> -4 + 0   = -4
        #   [3, sqrt(3)]  -> -9 + 3   = -6
        #   [4, sqrt(8)]  -> -16 + 8  = -8
        #   [5, sqrt(24)] -> -25 + 24 = -1
        # and the weights sum to 6.
        expected = gs.array((1 * -4. + 2 * -6. + 1 * -8. + 2 * -1.) / 6.)

        self.assertAllClose(result, expected)
Пример #19
0
class TestMinkowskiMethods(geomstats.tests.TestCase):
    def setUp(self):
        """Seed the backend RNG and build the 2D Minkowski space under test."""
        # Fixed seed so that randomly drawn points are reproducible.
        gs.random.seed(1234)

        # Index of the coordinate that carries the negative sign in the
        # expected inner products computed by the tests.
        self.time_like_dim = 0
        self.dimension = 2
        self.space = Minkowski(self.dimension)
        self.metric = self.space.metric
        self.n_samples = 10

    def test_belongs(self):
        """Check that ``belongs`` accepts a single 2D point."""
        candidate = gs.array([-1., 3.])
        membership = self.space.belongs(candidate)

        self.assertAllClose(membership, gs.array([[True]]))

    def test_random_uniform(self):
        """Check that a single random sample has shape (1, dimension)."""
        sample = self.space.random_uniform()
        expected_shape = (1, self.dimension)
        self.assertAllClose(gs.shape(sample), expected_shape)

    def test_random_uniform_and_belongs(self):
        """Check that a random sample belongs to the space."""
        sample = self.space.random_uniform()
        membership = self.space.belongs(sample)
        self.assertAllClose(membership, gs.array([[True]]))

    def test_inner_product_matrix(self):
        """Check that the 2D inner-product matrix is diag(-1, 1)."""
        matrix = self.metric.inner_product_matrix()
        self.assertAllClose(matrix, gs.array([[-1., 0.], [0., 1.]]))

    def test_inner_product(self):
        """Check the inner product of two vectors against a manual formula.

        The expected value is the Euclidean dot product minus twice the
        product of the time-like coordinates.
        """
        vec_a = gs.array([0., 1.])
        vec_b = gs.array([2., 10.])

        computed = self.metric.inner_product(vec_a, vec_b)

        time_idx = self.time_like_dim
        expected = helper.to_scalar(gs.dot(vec_a, vec_b))
        expected -= (2 * vec_a[time_idx] * vec_b[time_idx])

        self.assertAllClose(computed, expected)

    def test_inner_product_vectorization(self):
        """Check inner products for single and batched inputs.

        The one-to-one value is compared with a manual formula, the
        batched results are checked for shape (n_samples, 1), and the
        many-to-many values are compared with a per-row manual formula.
        """
        n_samples = 3
        time_idx = self.time_like_dim
        single_a = gs.array([[-1., 0.]])
        single_b = gs.array([[1.0, 0.]])

        batch_a = gs.array([
            [-1., 0.],
            [1., 0.],
            [2., math.sqrt(3)]])
        batch_b = gs.array([
            [2., -math.sqrt(3)],
            [4.0, math.sqrt(15)],
            [-4.0, math.sqrt(15)]])

        result = self.metric.inner_product(single_a, single_b)
        expected = gs.dot(single_a, gs.transpose(single_b))
        expected -= (2 * single_a[:, time_idx] * single_b[:, time_idx])
        expected = helper.to_scalar(expected)

        result_no = self.metric.inner_product(batch_a, single_b)
        result_on = self.metric.inner_product(single_a, batch_b)
        result_nn = self.metric.inner_product(batch_a, batch_b)

        self.assertAllClose(result, expected)
        for batched in (result_no, result_on, result_nn):
            self.assertAllClose(gs.shape(batched), (n_samples, 1))

        with self.session():
            rowwise = np.zeros(n_samples)
            for row in range(n_samples):
                # Dot product first, then the time-like correction, as two
                # separate updates of the same entry.
                rowwise[row] = gs.eval(gs.dot(batch_a[row], batch_b[row]))
                rowwise[row] -= (2 * gs.eval(batch_a[row, time_idx])
                                 * gs.eval(batch_b[row, time_idx]))
            rowwise_expected = helper.to_scalar(gs.array(rowwise))

            self.assertAllClose(result_nn, rowwise_expected)

    def test_squared_norm(self):
        """Check the squared norm of [-2, 4], expected to be 12."""
        vector = gs.array([-2., 4.])

        computed = self.metric.squared_norm(vector)
        self.assertAllClose(computed, gs.array([[12.]]))

    def test_squared_norm_vectorization(self):
        """Check that batched squared norms have shape (n_samples, 1)."""
        n_samples = 3
        vectors = gs.array([
            [-1., 0.],
            [1., 0.],
            [2., math.sqrt(3)]])

        sq_norms = self.metric.squared_norm(vectors)
        expected_shape = (n_samples, 1)
        self.assertAllClose(gs.shape(sq_norms), expected_shape)

    def test_exp(self):
        """Check that ``exp`` equals base point plus tangent vector."""
        start = gs.array([1.0, 0.])
        tangent = gs.array([2., math.sqrt(3)])

        end_point = self.metric.exp(tangent_vec=tangent,
                                    base_point=start)
        expected = helper.to_vector(start + tangent)
        self.assertAllClose(end_point, expected)

    def test_exp_vectorization(self):
        """Check ``exp`` for single and batched inputs.

        The one-to-one value is compared with the sum of its inputs; the
        batched combinations are checked for shape (n_samples, dim).
        """
        batch_shape = (3, self.dimension)
        single_tangent_vec = gs.array([[-1., 0.]])
        single_base_point = gs.array([[1.0, 0.]])

        many_tangent_vecs = gs.array([
            [-1., 0.],
            [1., 0.],
            [2., math.sqrt(3)]])
        many_base_points = gs.array([
            [2., -math.sqrt(3)],
            [4.0, math.sqrt(15)],
            [-4.0, math.sqrt(15)]])

        end_point = self.metric.exp(single_tangent_vec, single_base_point)
        expected = helper.to_vector(single_tangent_vec + single_base_point)
        self.assertAllClose(end_point, expected)

        batched_cases = (
            (many_tangent_vecs, single_base_point),
            (single_tangent_vec, many_base_points),
            (many_tangent_vecs, many_base_points),
        )
        for tangent_vecs, base_points in batched_cases:
            end_points = self.metric.exp(tangent_vecs, base_points)
            self.assertAllClose(gs.shape(end_points), batch_shape)

    def test_log(self):
        """Check that ``log`` equals point minus base point."""
        start = gs.array([-1., 0.])
        target = gs.array([2., math.sqrt(3)])

        tangent = self.metric.log(point=target, base_point=start)
        expected = helper.to_vector(target - start)
        self.assertAllClose(tangent, expected)

    def test_log_vectorization(self):
        """Check ``log`` for single and batched inputs.

        The one-to-one value is compared with the difference of its
        inputs; the batched combinations are checked for shape
        (n_samples, dim).
        """
        batch_shape = (3, self.dimension)
        single_point = gs.array([[-1., 0.]])
        single_base_point = gs.array([[1.0, 0.]])

        many_points = gs.array([
            [-1., 0.],
            [1., 0.],
            [2., math.sqrt(3)]])
        many_base_points = gs.array([
            [2., -math.sqrt(3)],
            [4.0, math.sqrt(15)],
            [-4.0, math.sqrt(15)]])

        tangent = self.metric.log(single_point, single_base_point)
        expected = helper.to_vector(single_point - single_base_point)
        self.assertAllClose(tangent, expected)

        batched_cases = (
            (many_points, single_base_point),
            (single_point, many_base_points),
            (many_points, many_base_points),
        )
        for points, base_points in batched_cases:
            tangents = self.metric.log(points, base_points)
            self.assertAllClose(gs.shape(tangents), batch_shape)

    def test_squared_dist(self):
        """Check the squared distance against a manual formula.

        The expected value is the Euclidean squared norm of the
        difference minus twice its squared time-like coordinate.
        """
        start = gs.array([2., -math.sqrt(3)])
        end = gs.array([4.0, math.sqrt(15)])

        computed = self.metric.squared_dist(start, end)

        displacement = end - start
        time_idx = self.time_like_dim
        expected = gs.dot(displacement, displacement)
        expected -= 2 * displacement[time_idx] * displacement[time_idx]
        expected = helper.to_scalar(expected)
        self.assertAllClose(computed, expected)

    def test_geodesic_and_belongs(self):
        """Check that points sampled along a geodesic belong to the space."""
        n_steps = 100
        start = gs.array([2., -math.sqrt(3)])
        velocity = gs.array([2., 0.])

        path = self.metric.geodesic(
            initial_point=start,
            initial_tangent_vec=velocity)

        times = gs.linspace(start=0., stop=1., num=n_steps)
        membership = self.space.belongs(path(times))

        self.assertAllClose(membership, gs.array(n_steps * [[True]]))

    def test_mean(self):
        """Check the metric's mean, unweighted and weighted.

        First a repeated point must be its own mean; then the weighted
        mean of four points must belong to the space.
        """
        repeated = gs.array([[2., -math.sqrt(3)]])
        unweighted_mean = self.metric.mean(points=[repeated] * 3)

        self.assertAllClose(unweighted_mean, helper.to_vector(repeated))

        samples = gs.array([
            [1., 0.],
            [2., math.sqrt(3)],
            [3., math.sqrt(8)],
            [4., math.sqrt(24)]])
        sample_weights = gs.array([1., 2., 1., 2.])

        weighted_mean = self.metric.mean(samples, sample_weights)
        membership = self.space.belongs(weighted_mean)

        self.assertAllClose(membership, gs.array([[True]]))

    def test_variance(self):
        """Check that the weighted variance about a base point is non-zero.

        Only the non-vanishing of the variance is asserted; its value is
        not compared against a reference number.
        """
        samples = gs.array([
            [1., 0.],
            [2., math.sqrt(3)],
            [3., math.sqrt(8)],
            [4., math.sqrt(24)]])
        sample_weights = gs.array([1., 2., 1., 2.])
        reference_point = gs.array([-1., 0.])

        computed = self.metric.variance(
            samples, sample_weights, reference_point)

        is_nonzero = helper.to_scalar(computed != 0)
        self.assertAllClose(is_nonzero, helper.to_scalar(gs.array([True])))
Пример #20
0
class TestMinkowski(geomstats.tests.TestCase):
    def setUp(self):
        """Seed the backend RNG and build the 2D Minkowski space under test."""
        # Fixed seed so that randomly drawn points are reproducible.
        gs.random.seed(1234)

        # Index of the coordinate that carries the negative sign in the
        # expected inner products computed by the tests.
        self.time_like_dim = 0
        self.dimension = 2
        self.space = Minkowski(self.dimension)
        self.metric = self.space.metric
        self.n_samples = 10

    def test_belongs(self):
        """Check that ``belongs`` accepts a single 2D point."""
        candidate = gs.array([-1., 3.])
        membership = self.space.belongs(candidate)

        self.assertAllClose(membership, True)

    def test_random_uniform(self):
        """Check that a single random sample has shape (dimension,)."""
        sample = self.space.random_uniform()
        expected_shape = (self.dimension, )
        self.assertAllClose(gs.shape(sample), expected_shape)

    def test_random_uniform_and_belongs(self):
        """Check that a random sample belongs to the space."""
        sample = self.space.random_uniform()
        membership = self.space.belongs(sample)
        self.assertAllClose(membership, True)

    def test_inner_product_matrix(self):
        """Check that the 2D inner-product matrix is diag(-1, 1)."""
        matrix = self.metric.inner_product_matrix()
        self.assertAllClose(matrix, gs.array([[-1., 0.], [0., 1.]]))

    def test_inner_product(self):
        """Check the inner product of two vectors against a manual formula.

        The expected value is the Euclidean dot product minus twice the
        product of the time-like coordinates.
        """
        vec_a = gs.array([0., 1.])
        vec_b = gs.array([2., 10.])

        computed = self.metric.inner_product(vec_a, vec_b)

        time_idx = self.time_like_dim
        expected = gs.dot(vec_a, vec_b)
        expected -= (2 * vec_a[time_idx] * vec_b[time_idx])

        self.assertAllClose(computed, expected)

    def test_inner_product_vectorization(self):
        """Check inner products for single and batched inputs.

        The one-to-one value is compared with a manual formula, the
        batched results are checked for shape (n_samples,), and the
        many-to-many values are compared with a per-row manual formula.
        """
        n_samples = 3
        time_idx = self.time_like_dim
        single_a = gs.array([-1., 0.])
        single_b = gs.array([1.0, 0.])

        batch_a = gs.array([[-1., 0.], [1., 0.], [2., math.sqrt(3)]])
        batch_b = gs.array([
            [2., -math.sqrt(3)],
            [4.0, math.sqrt(15)],
            [-4.0, math.sqrt(15)],
        ])

        result = self.metric.inner_product(single_a, single_b)
        expected = gs.dot(single_a, gs.transpose(single_b))
        expected -= (2 * single_a[time_idx] * single_b[time_idx])

        result_no = self.metric.inner_product(batch_a, single_b)
        result_on = self.metric.inner_product(single_a, batch_b)
        result_nn = self.metric.inner_product(batch_a, batch_b)

        self.assertAllClose(result, expected)
        for batched in (result_no, result_on, result_nn):
            self.assertAllClose(gs.shape(batched), (n_samples, ))

        rowwise = np.zeros(n_samples)
        for row in range(n_samples):
            # Dot product first, then the time-like correction, as two
            # separate updates of the same entry.
            rowwise[row] = gs.dot(batch_a[row], batch_b[row])
            rowwise[row] -= (2 * batch_a[row, time_idx] *
                             batch_b[row, time_idx])

        self.assertAllClose(result_nn, rowwise)

    def test_squared_norm(self):
        """Check the squared norm of [-2, 4], expected to be 12."""
        vector = gs.array([-2., 4.])

        computed = self.metric.squared_norm(vector)
        self.assertAllClose(computed, 12.)

    def test_squared_norm_vectorization(self):
        """Check that batched squared norms have shape (n_samples,)."""
        n_samples = 3
        vectors = gs.array([[-1., 0.], [1., 0.], [2., math.sqrt(3)]])

        sq_norms = self.metric.squared_norm(vectors)
        expected_shape = (n_samples, )
        self.assertAllClose(gs.shape(sq_norms), expected_shape)

    def test_exp(self):
        """Check that ``exp`` equals base point plus tangent vector."""
        start = gs.array([1.0, 0.])
        tangent = gs.array([2., math.sqrt(3)])

        end_point = self.metric.exp(tangent_vec=tangent, base_point=start)
        self.assertAllClose(end_point, start + tangent)

    def test_exp_vectorization(self):
        """Check ``exp`` for single and batched inputs.

        The one-to-one value is compared with the sum of its inputs; the
        batched combinations are checked for shape (n_samples, dim).
        """
        batch_shape = (3, self.dimension)
        single_tangent_vec = gs.array([-1., 0.])
        single_base_point = gs.array([1.0, 0.])

        many_tangent_vecs = gs.array(
            [[-1., 0.], [1., 0.], [2., math.sqrt(3)]])
        many_base_points = gs.array([
            [2., -math.sqrt(3)],
            [4.0, math.sqrt(15)],
            [-4.0, math.sqrt(15)],
        ])

        end_point = self.metric.exp(single_tangent_vec, single_base_point)
        self.assertAllClose(
            end_point, single_tangent_vec + single_base_point)

        batched_cases = (
            (many_tangent_vecs, single_base_point),
            (single_tangent_vec, many_base_points),
            (many_tangent_vecs, many_base_points),
        )
        for tangent_vecs, base_points in batched_cases:
            end_points = self.metric.exp(tangent_vecs, base_points)
            self.assertAllClose(gs.shape(end_points), batch_shape)

    def test_log(self):
        """Check that ``log`` equals point minus base point."""
        start = gs.array([-1., 0.])
        target = gs.array([2., math.sqrt(3)])

        tangent = self.metric.log(point=target, base_point=start)
        self.assertAllClose(tangent, target - start)

    def test_log_vectorization(self):
        """Check ``log`` for single and batched inputs.

        The one-to-one value is compared with the difference of its
        inputs; the batched combinations are checked for shape
        (n_samples, dim).
        """
        batch_shape = (3, self.dimension)
        single_point = gs.array([-1., 0.])
        single_base_point = gs.array([1.0, 0.])

        many_points = gs.array([[-1., 0.], [1., 0.], [2., math.sqrt(3)]])
        many_base_points = gs.array([
            [2., -math.sqrt(3)],
            [4.0, math.sqrt(15)],
            [-4.0, math.sqrt(15)],
        ])

        tangent = self.metric.log(single_point, single_base_point)
        self.assertAllClose(tangent, single_point - single_base_point)

        batched_cases = (
            (many_points, single_base_point),
            (single_point, many_base_points),
            (many_points, many_base_points),
        )
        for points, base_points in batched_cases:
            tangents = self.metric.log(points, base_points)
            self.assertAllClose(gs.shape(tangents), batch_shape)

    def test_squared_dist(self):
        """Check the squared distance against a manual formula.

        The expected value is the Euclidean squared norm of the
        difference minus twice its squared time-like coordinate.
        """
        start = gs.array([2., -math.sqrt(3)])
        end = gs.array([4.0, math.sqrt(15)])

        computed = self.metric.squared_dist(start, end)

        displacement = end - start
        time_idx = self.time_like_dim
        expected = gs.dot(displacement, displacement)
        expected -= 2 * displacement[time_idx] * displacement[time_idx]
        self.assertAllClose(computed, expected)

    def test_geodesic_and_belongs(self):
        """Check that points sampled along a geodesic belong to the space."""
        n_steps = 100
        start = gs.array([2., -math.sqrt(3)])
        velocity = gs.array([2., 0.])

        path = self.metric.geodesic(
            initial_point=start,
            initial_tangent_vec=velocity)

        times = gs.linspace(start=0., stop=1., num=n_steps)
        membership = self.space.belongs(path(times))

        self.assertAllClose(membership, gs.array(n_steps * [True]))
Пример #21
0
    class MinkowskiMetricTestData(_RiemannianMetricTestData):
        """Test data for the Minkowski metric.

        The ``*_test_data`` methods either wrap hand-written smoke cases
        with ``self.generate_tests`` or forward the class-level argument
        lists to a ``self._<name>_test_data`` helper, presumably provided
        by ``_RiemannianMetricTestData`` (not visible here; confirm).
        """

        # The two values 2 and 3, in random order.
        n_list = random.sample(range(2, 4), 2)
        metric_args_list = [(n, ) for n in n_list]
        # Same list object as metric_args_list (an alias, not a copy).
        shape_list = metric_args_list
        space_list = [Minkowski(n) for n in n_list]
        # Each of the next three lists holds 1 and 2 in random order.
        n_points_list = random.sample(range(1, 3), 2)
        n_tangent_vecs_list = random.sample(range(1, 3), 2)
        n_points_a_list = random.sample(range(1, 3), 2)
        n_points_b_list = [1]
        # Arguments forwarded to the ladder parallel transport helper.
        alpha_list = [1] * 2
        n_rungs_list = [1] * 2
        scheme_list = ["pole"] * 2

        def metric_matrix_test_data(self):
            """Return the smoke case for the 2D metric matrix, diag(-1, 1)."""
            smoke_data = [dict(dim=2, expected=[[-1.0, 0.0], [0.0, 1.0]])]
            return self.generate_tests(smoke_data)

        def inner_product_test_data(self):
            """Return smoke cases for the inner product.

            One case uses a single pair of 2D vectors, the other three
            pairs at once.
            """
            smoke_data = [
                dict(dim=2,
                     point_a=[0.0, 1.0],
                     point_b=[2.0, 10.0],
                     expected=10.0),
                dict(
                    dim=2,
                    point_a=[[-1.0, 0.0], [1.0, 0.0], [2.0, math.sqrt(3)]],
                    point_b=[
                        [2.0, -math.sqrt(3)],
                        [4.0, math.sqrt(15)],
                        [-4.0, math.sqrt(15)],
                    ],
                    expected=[2.0, -4.0, 14.70820393],
                ),
            ]
            return self.generate_tests(smoke_data)

        def squared_norm_test_data(self):
            """Return the smoke case for the squared norm of [-2, 4]."""
            smoke_data = [dict(dim=2, vector=[-2.0, 4.0], expected=12.0)]
            return self.generate_tests(smoke_data)

        def squared_dist_test_data(self):
            """Return the smoke case for the squared distance of two points."""
            smoke_data = [
                dict(
                    dim=2,
                    point_a=[2.0, -math.sqrt(3)],
                    point_b=[4.0, math.sqrt(15)],
                    expected=27.416407,
                )
            ]
            return self.generate_tests(smoke_data)

        def exp_test_data(self):
            """Return the smoke case for exp.

            The expected value is the sum of base point and tangent vector.
            """
            smoke_data = [
                dict(
                    dim=2,
                    tangent_vec=[2.0, math.sqrt(3)],
                    base_point=[1.0, 0.0],
                    expected=[3.0, math.sqrt(3)],
                )
            ]
            return self.generate_tests(smoke_data)

        def log_test_data(self):
            """Return the smoke case for log.

            The expected value is the point minus the base point.
            """
            smoke_data = [
                dict(
                    dim=2,
                    point=[2.0, math.sqrt(3)],
                    base_point=[-1.0, 0.0],
                    expected=[3.0, math.sqrt(3)],
                )
            ]
            return self.generate_tests(smoke_data)

        def exp_shape_test_data(self):
            """Forward the metric args, spaces and shapes to the exp-shape helper."""
            return self._exp_shape_test_data(self.metric_args_list,
                                             self.space_list, self.shape_list)

        def log_shape_test_data(self):
            """Forward the metric args and spaces to the log-shape helper."""
            return self._log_shape_test_data(
                self.metric_args_list,
                self.space_list,
            )

        def squared_dist_is_symmetric_test_data(self):
            """Forward to the squared-dist symmetry helper with a relaxed atol."""
            return self._squared_dist_is_symmetric_test_data(
                self.metric_args_list,
                self.space_list,
                self.n_points_a_list,
                self.n_points_b_list,
                atol=gs.atol * 1000,
            )

        def exp_belongs_test_data(self):
            """Forward to the exp-belongs helper with a relaxed belongs_atol."""
            return self._exp_belongs_test_data(
                self.metric_args_list,
                self.space_list,
                self.shape_list,
                self.n_tangent_vecs_list,
                belongs_atol=gs.atol * 1000,
            )

        def log_is_tangent_test_data(self):
            """Forward to the log-is-tangent helper with a relaxed tolerance."""
            return self._log_is_tangent_test_data(
                self.metric_args_list,
                self.space_list,
                self.n_points_list,
                is_tangent_atol=gs.atol * 1000,
            )

        def geodesic_ivp_belongs_test_data(self):
            """Forward to the geodesic-IVP-belongs helper with a relaxed atol."""
            return self._geodesic_ivp_belongs_test_data(
                self.metric_args_list,
                self.space_list,
                self.shape_list,
                self.n_points_list,
                belongs_atol=gs.atol * 1000,
            )

        def geodesic_bvp_belongs_test_data(self):
            """Forward to the geodesic-BVP-belongs helper with a relaxed atol."""
            return self._geodesic_bvp_belongs_test_data(
                self.metric_args_list,
                self.space_list,
                self.n_points_list,
                belongs_atol=gs.atol * 1000,
            )

        def log_then_exp_test_data(self):
            """Forward to the log-then-exp helper with relaxed rtol and atol."""
            return self._log_then_exp_test_data(
                self.metric_args_list,
                self.space_list,
                self.n_points_list,
                rtol=gs.rtol * 100,
                atol=gs.atol * 10000,
            )

        def exp_then_log_test_data(self):
            """Forward to the exp-then-log helper with relaxed rtol and atol."""
            return self._exp_then_log_test_data(
                self.metric_args_list,
                self.space_list,
                self.shape_list,
                self.n_tangent_vecs_list,
                rtol=gs.rtol * 100,
                atol=gs.atol * 10000,
            )

        def exp_ladder_parallel_transport_test_data(self):
            """Forward to the ladder parallel transport helper.

            Passes the rung counts, alphas and schemes defined on the class.
            """
            return self._exp_ladder_parallel_transport_test_data(
                self.metric_args_list,
                self.space_list,
                self.shape_list,
                self.n_tangent_vecs_list,
                self.n_rungs_list,
                self.alpha_list,
                self.scheme_list,
            )

        def exp_geodesic_ivp_test_data(self):
            """Forward to the exp-geodesic-IVP helper with relaxed tolerances."""
            return self._exp_geodesic_ivp_test_data(
                self.metric_args_list,
                self.space_list,
                self.shape_list,
                self.n_tangent_vecs_list,
                self.n_points_list,
                rtol=gs.rtol * 1000,
                atol=gs.atol * 1000,
            )

        def parallel_transport_ivp_is_isometry_test_data(self):
            """Forward to the parallel-transport IVP isometry helper."""
            return self._parallel_transport_ivp_is_isometry_test_data(
                self.metric_args_list,
                self.space_list,
                self.shape_list,
                self.n_tangent_vecs_list,
                is_tangent_atol=gs.atol * 1000,
                atol=gs.atol * 1000,
            )

        def parallel_transport_bvp_is_isometry_test_data(self):
            """Forward to the parallel-transport BVP isometry helper."""
            return self._parallel_transport_bvp_is_isometry_test_data(
                self.metric_args_list,
                self.space_list,
                self.shape_list,
                self.n_tangent_vecs_list,
                is_tangent_atol=gs.atol * 1000,
                atol=gs.atol * 1000,
            )

        def dist_is_symmetric_test_data(self):
            """Forward to the dist symmetry helper with default tolerances."""
            return self._dist_is_symmetric_test_data(
                self.metric_args_list,
                self.space_list,
                self.n_points_a_list,
                self.n_points_b_list,
            )

        def dist_is_norm_of_log_test_data(self):
            """Forward to the dist-is-norm-of-log helper."""
            return self._dist_is_norm_of_log_test_data(
                self.metric_args_list,
                self.space_list,
                self.n_points_a_list,
                self.n_points_b_list,
            )

        def dist_point_to_itself_is_zero_test_data(self):
            """Forward to the dist-point-to-itself-is-zero helper."""
            return self._dist_point_to_itself_is_zero_test_data(
                self.metric_args_list, self.space_list, self.n_points_list)

        def inner_product_is_symmetric_test_data(self):
            """Forward to the inner-product symmetry helper."""
            return self._inner_product_is_symmetric_test_data(
                self.metric_args_list,
                self.space_list,
                self.shape_list,
                self.n_tangent_vecs_list,
            )
Пример #22
0
class TestFrechetMean(geomstats.tests.TestCase):
    _multiprocess_can_split_ = True

    def setup_method(self):
        """Seed the backend RNG and build the manifolds shared by the tests."""
        gs.random.seed(123)

        # Manifolds whose points are plain vectors.
        self.sphere = Hypersphere(dim=4)
        self.hyperbolic = Hyperboloid(dim=3)
        self.euclidean = Euclidean(dim=2)
        self.minkowski = Minkowski(dim=2)

        # 3D rotations: once with the "vector" point type, once with the
        # constructor's default point type.
        self.so3 = SpecialOrthogonal(n=3, point_type="vector")
        self.so_matrix = SpecialOrthogonal(n=3)

        # Discrete curves and an elastic metric sharing the same ambient.
        ambient = R2
        self.curves_2d = DiscreteCurves(ambient)
        self.elastic_metric = ElasticMetric(
            a=1, b=1, ambient_manifold=ambient)

    def test_logs_at_mean_curves_2d(self):
        """Check that the SRV-transformed logs at the mean of two curves cancel.

        Ten times over: draw two random curves, fit their Frechet mean under
        the SRV metric, take the logs of both curves at that mean, and verify
        that the outputs of ``aux_differential_srv_transform`` for the two
        logs sum to zero up to ``atol=1e-5``.
        """
        n_tests = 10
        srv_metric = self.curves_2d.srv_metric
        estimator = FrechetMean(metric=srv_metric, init_step_size=1.0)

        def residual_norm():
            curves = self.curves_2d.random_point(n_samples=2)
            estimator.fit(curves)

            # Expand and tile are added because of GitHub issue 1575
            tiled_mean = gs.tile(
                gs.expand_dims(estimator.estimate_, axis=0), (2, 1, 1))

            logs = srv_metric.log(point=curves, base_point=tiled_mean)
            logs_srv = srv_metric.aux_differential_srv_transform(
                logs, curve=tiled_mean)
            # Note that the logs are NOT inverse, only the logs_srv are.
            return gs.linalg.norm(logs_srv[1, :] + logs_srv[0, :])

        result = gs.stack([residual_norm() for _ in range(n_tests)])
        self.assertAllClose(gs.zeros(n_tests), result, atol=1e-5)

    def test_logs_at_mean_default_gradient_descent_sphere(self):
        """Check that the logs of two sphere points at their mean are opposite.

        Uses ``method="default"`` with an initial step size of 1.0 and repeats
        the check on ten random pairs of points.
        """
        n_tests = 10
        sphere_metric = self.sphere.metric
        estimator = FrechetMean(
            metric=sphere_metric, method="default", init_step_size=1.0)

        residuals = []
        for _ in range(n_tests):
            pair = self.sphere.random_uniform(n_samples=2)
            estimator.fit(pair)

            logs = sphere_metric.log(
                point=pair, base_point=estimator.estimate_)
            # Opposite logs sum to zero, so the norm of the sum is the error.
            residuals.append(gs.linalg.norm(logs[1, :] + logs[0, :]))

        self.assertAllClose(gs.zeros(n_tests), gs.stack(residuals))

    def test_logs_at_mean_adaptive_gradient_descent_sphere(self):
        """Check that the logs of two sphere points at their mean are opposite.

        Uses ``method="adaptive"`` and repeats the check on ten random pairs
        of points.
        """
        n_tests = 10
        sphere_metric = self.sphere.metric
        estimator = FrechetMean(metric=sphere_metric, method="adaptive")

        residuals = []
        for _ in range(n_tests):
            pair = self.sphere.random_uniform(n_samples=2)
            estimator.fit(pair)

            logs = sphere_metric.log(
                point=pair, base_point=estimator.estimate_)
            # Opposite logs sum to zero, so the norm of the sum is the error.
            residuals.append(gs.linalg.norm(logs[1, :] + logs[0, :]))

        self.assertAllClose(gs.zeros(n_tests), gs.stack(residuals))

    def test_estimate_shape_default_gradient_descent_sphere(self):
        """Check the shape of the ``method="default"`` estimate on the sphere.

        The mean of two 5-component points must itself have shape ``(5,)``.
        """
        dim = 5
        first = gs.array([1.0, 0.0, 0.0, 0.0, 0.0])
        second = gs.array([0.0, 1.0, 0.0, 0.0, 0.0])

        estimator = FrechetMean(
            metric=self.sphere.metric, method="default", verbose=True)
        estimator.fit(gs.array([first, second]))

        self.assertAllClose(gs.shape(estimator.estimate_), (dim,))

    def test_estimate_shape_adaptive_gradient_descent_sphere(self):
        """Check the shape of the ``method="adaptive"`` estimate on the sphere.

        The mean of two 5-component points must itself have shape ``(5,)``.
        """
        dim = 5
        first = gs.array([1.0, 0.0, 0.0, 0.0, 0.0])
        second = gs.array([0.0, 1.0, 0.0, 0.0, 0.0])

        estimator = FrechetMean(metric=self.sphere.metric, method="adaptive")
        estimator.fit(gs.array([first, second]))

        self.assertAllClose(gs.shape(estimator.estimate_), (dim,))

    def test_estimate_shape_elastic_metric(self):
        """Check the elastic-metric mean of two curves has one curve's shape."""
        curves = self.curves_2d.random_point(n_samples=2)

        estimator = FrechetMean(metric=self.elastic_metric)
        estimator.fit(curves)

        # Dropping the leading sample axis gives the shape of one curve.
        single_curve_shape = curves.shape[1:]
        self.assertAllClose(
            gs.shape(estimator.estimate_), single_curve_shape)

    def test_estimate_shape_metric(self):
        """Check the SRV-metric mean of two curves has one curve's shape."""
        curves = self.curves_2d.random_point(n_samples=2)

        estimator = FrechetMean(metric=self.curves_2d.srv_metric)
        estimator.fit(curves)

        # Dropping the leading sample axis gives the shape of one curve.
        single_curve_shape = curves.shape[1:]
        self.assertAllClose(
            gs.shape(estimator.estimate_), single_curve_shape)

    def test_estimate_and_belongs_default_gradient_descent_sphere(self):
        """Check the ``method="default"`` mean of two points is on the sphere."""
        first = gs.array([1.0, 0.0, 0.0, 0.0, 0.0])
        second = gs.array([0.0, 1.0, 0.0, 0.0, 0.0])

        estimator = FrechetMean(metric=self.sphere.metric, method="default")
        estimator.fit(gs.array([first, second]))

        on_sphere = self.sphere.belongs(estimator.estimate_)
        self.assertAllClose(on_sphere, True)

    def test_estimate_and_belongs_curves_2d(self):
        """Check the SRV-metric mean of two curves belongs to the curve space."""
        curves = self.curves_2d.random_point(n_samples=2)

        estimator = FrechetMean(metric=self.curves_2d.srv_metric)
        estimator.fit(curves)

        is_curve = self.curves_2d.belongs(estimator.estimate_)
        self.assertAllClose(is_curve, True)

    def test_estimate_default_gradient_descent_so3(self):
        """Check the logs of two rotation vectors at their mean sum to zero.

        Uses the bi-invariant metric of ``self.so3`` with ``method="default"``
        and an initial step size of 1.0.
        """
        bi_invariant = self.so3.bi_invariant_metric
        rotations = self.so3.random_uniform(2)

        estimator = FrechetMean(
            metric=bi_invariant, method="default", init_step_size=1.0)
        estimator.fit(rotations)

        log_sum = gs.sum(
            bi_invariant.log(rotations, estimator.estimate_), axis=0)
        self.assertAllClose(log_sum, gs.zeros_like(rotations[0]))

    def test_estimate_and_belongs_default_gradient_descent_so3(self):
        """Check the mean of ten rotation vectors belongs to ``self.so3``."""
        rotations = self.so3.random_uniform(10)

        estimator = FrechetMean(
            metric=self.so3.bi_invariant_metric, method="default")
        estimator.fit(rotations)

        is_rotation = self.so3.belongs(estimator.estimate_)
        self.assertAllClose(is_rotation, True)

    @geomstats.tests.np_autograd_and_tf_only
    def test_estimate_default_gradient_descent_so_matrix(self):
        """Check the logs of two ``so_matrix`` points at their mean sum to zero.

        Uses the bi-invariant metric of ``self.so_matrix`` with
        ``method="default"`` and an initial step size of 1.0; the comparison
        uses ``atol=1e-5``.
        """
        bi_invariant = self.so_matrix.bi_invariant_metric
        rotations = self.so_matrix.random_uniform(2)

        estimator = FrechetMean(
            metric=bi_invariant, method="default", init_step_size=1.0)
        estimator.fit(rotations)

        log_sum = gs.sum(
            bi_invariant.log(rotations, estimator.estimate_), axis=0)
        self.assertAllClose(
            log_sum, gs.zeros_like(rotations[0]), atol=1e-5)

    @geomstats.tests.np_autograd_and_tf_only
    def test_estimate_and_belongs_default_gradient_descent_so_matrix(self):
        """Check the default-method mean of ten points is in ``so_matrix``."""
        rotations = self.so_matrix.random_uniform(10)

        estimator = FrechetMean(
            metric=self.so_matrix.bi_invariant_metric, method="default")
        estimator.fit(rotations)

        is_rotation = self.so_matrix.belongs(estimator.estimate_)
        self.assertAllClose(is_rotation, True)

    @geomstats.tests.np_autograd_and_tf_only
    def test_estimate_and_belongs_adaptive_gradient_descent_so_matrix(self):
        """Check the adaptive-method mean of ten points is in ``so_matrix``.

        The estimator runs with ``init_step_size=0.5`` and ``verbose=True``.
        """
        rotations = self.so_matrix.random_uniform(10)

        estimator = FrechetMean(
            metric=self.so_matrix.bi_invariant_metric,
            method="adaptive",
            init_step_size=0.5,
            verbose=True,
        )
        estimator.fit(rotations)

        self.assertTrue(self.so_matrix.belongs(estimator.estimate_))

    @geomstats.tests.np_autograd_and_tf_only
    def test_estimate_and_coincide_default_so_vec_and_mat(self):
        """Check the matrix and rotation-vector estimates of the mean agree.

        The same three rotations are averaged once directly with the metric
        of ``self.so_matrix`` and once after conversion to rotation vectors
        with the metric of ``self.so3``; the vector estimate is converted
        back to a matrix before the comparison.
        """
        rotations = self.so_matrix.random_uniform(3)

        matrix_estimator = FrechetMean(
            metric=self.so_matrix.bi_invariant_metric, method="default")
        matrix_estimator.fit(rotations)
        expected = matrix_estimator.estimate_

        vector_estimator = FrechetMean(
            metric=self.so3.bi_invariant_metric, method="default")
        vector_estimator.fit(self.so3.rotation_vector_from_matrix(rotations))
        result = self.so3.matrix_from_rotation_vector(
            vector_estimator.estimate_)

        self.assertAllClose(result, expected)

    def test_estimate_and_belongs_adaptive_gradient_descent_sphere(self):
        """Check the ``method="adaptive"`` mean of two points is on the sphere."""
        first = gs.array([1.0, 0.0, 0.0, 0.0, 0.0])
        second = gs.array([0.0, 1.0, 0.0, 0.0, 0.0])

        estimator = FrechetMean(metric=self.sphere.metric, method="adaptive")
        estimator.fit(gs.array([first, second]))

        on_sphere = self.sphere.belongs(estimator.estimate_)
        self.assertAllClose(on_sphere, True)

    def test_variance_sphere(self):
        """Check the variance of two identical sphere points at that point is 0."""
        pole = gs.array([0.0, 0.0, 0.0, 0.0, 1.0])
        duplicated = gs.array([pole, pole])

        result = variance(
            duplicated, base_point=pole, metric=self.sphere.metric)

        self.assertAllClose(gs.array(0.0), result)

    def test_estimate_default_gradient_descent_sphere(self):
        """Check the ``method="default"`` mean of a repeated point is the point."""
        pole = gs.array([0.0, 0.0, 0.0, 0.0, 1.0])

        estimator = FrechetMean(metric=self.sphere.metric, method="default")
        estimator.fit(X=gs.array([pole, pole]))

        self.assertAllClose(pole, estimator.estimate_)

    def test_estimate_elastic_metric(self):
        """Check the elastic-metric mean of a repeated curve is that curve."""
        curve = self.curves_2d.random_point(n_samples=1)

        estimator = FrechetMean(metric=self.elastic_metric)
        estimator.fit(X=gs.array([curve, curve]))

        self.assertAllClose(curve, estimator.estimate_)

    def test_estimate_curves_2d(self):
        """Check the SRV-metric mean of a repeated curve is that curve."""
        curve = self.curves_2d.random_point(n_samples=1)

        estimator = FrechetMean(metric=self.curves_2d.srv_metric)
        estimator.fit(X=gs.array([curve, curve]))

        self.assertAllClose(curve, estimator.estimate_)

    def test_estimate_adaptive_gradient_descent_sphere(self):
        """Check the ``method="adaptive"`` mean of a repeated point is the point."""
        pole = gs.array([0.0, 0.0, 0.0, 0.0, 1.0])

        estimator = FrechetMean(metric=self.sphere.metric, method="adaptive")
        estimator.fit(X=gs.array([pole, pole]))

        self.assertAllClose(pole, estimator.estimate_)

    def test_estimate_spd(self):
        """Check the affine-metric mean of a repeated SPD matrix is the matrix."""
        matrix = SPDMatrices(3).random_point()
        duplicated = gs.array([matrix, matrix])

        estimator = FrechetMean(
            metric=SPDMetricAffine(3), point_type="matrix")
        estimator.fit(X=duplicated)

        self.assertAllClose(matrix, estimator.estimate_)

    def test_estimate_spd_two_samples(self):
        """Check the mean of two SPD matrices against a half-log step.

        The expected value is obtained by taking the log of the first matrix
        at the second, halving it, and applying ``exp`` at the second matrix.
        """
        space = SPDMatrices(3)
        metric = SPDMetricAffine(3)
        pair = space.random_point(2)

        estimator = FrechetMean(metric)
        estimator.fit(pair)
        result = estimator.estimate_

        half_log = metric.log(pair[0], pair[1]) / 2
        expected = metric.exp(half_log, pair[1])
        self.assertAllClose(expected, result)

    def test_variance_hyperbolic(self):
        """Check the variance of a repeated hyperboloid point at itself is 0."""
        base = gs.array([2.0, 1.0, 1.0, 1.0])
        duplicated = gs.array([base, base])

        result = variance(
            duplicated, base_point=base, metric=self.hyperbolic.metric)

        self.assertAllClose(result, gs.array(0.0))

    def test_estimate_hyperbolic(self):
        """Check the hyperbolic mean of a repeated point is that point."""
        base = gs.array([2.0, 1.0, 1.0, 1.0])

        estimator = FrechetMean(metric=self.hyperbolic.metric)
        estimator.fit(X=gs.array([base, base]))

        self.assertAllClose(estimator.estimate_, base)

    def test_estimate_and_belongs_hyperbolic(self):
        """Check the mean of three random points belongs to the hyperboloid."""
        # Three separate draws, stacked along a new leading axis.
        samples = gs.stack(
            [self.hyperbolic.random_point() for _ in range(3)], axis=0)

        estimator = FrechetMean(metric=self.hyperbolic.metric)
        estimator.fit(X=samples)

        on_hyperboloid = self.hyperbolic.belongs(estimator.estimate_)
        self.assertAllClose(on_hyperboloid, True)

    def test_mean_euclidean_shape(self):
        """Check the Euclidean mean of 2-component points has shape ``(2,)``.

        The points are passed to ``fit`` as a plain Python list.
        """
        dim = 2
        sample = gs.array([1.0, 4.0])

        estimator = FrechetMean(metric=self.euclidean.metric)
        estimator.fit([sample, sample, sample])

        self.assertAllClose(gs.shape(estimator.estimate_), (dim,))

    def test_mean_euclidean(self):
        """Check the Euclidean mean, unweighted and weighted.

        First a point repeated three times must average to itself; then four
        points with weights ``[1, 2, 1, 2]`` must average to
        ``[16 / 6, 22 / 6]``.
        """
        euclidean_metric = self.euclidean.metric

        # Unweighted: identical points passed as a plain list.
        sample = gs.array([1.0, 4.0])
        estimator = FrechetMean(metric=euclidean_metric)
        estimator.fit([sample, sample, sample])
        self.assertAllClose(estimator.estimate_, sample)

        # Weighted: distinct points, weights passed as a plain list.
        samples = gs.array([[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0]])
        weights = [1.0, 2.0, 1.0, 2.0]
        weighted_estimator = FrechetMean(metric=euclidean_metric)
        weighted_estimator.fit(samples, weights=weights)

        expected = gs.array([16.0 / 6.0, 22.0 / 6.0])
        self.assertAllClose(weighted_estimator.estimate_, expected)

    def test_variance_euclidean(self):
        """Check the weighted Euclidean variance of four points at the origin."""
        samples = gs.array([[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0]])
        weights = gs.array([1.0, 2.0, 1.0, 2.0])
        origin = gs.zeros(2)

        result = variance(
            samples,
            weights=weights,
            base_point=origin,
            metric=self.euclidean.metric,
        )

        # we expect the average of the points' sq norms.
        weighted_sq_norms = 1 * 5.0 + 2 * 13.0 + 1 * 25.0 + 2 * 41.0
        expected = gs.array(weighted_sq_norms / 6.0)
        self.assertAllClose(result, expected)

    def test_mean_matrices_shape(self):
        """Check the mean of 2x2 matrices under ``MatricesMetric`` is 2x2."""
        m, n = (2, 2)
        matrix = gs.array([[1.0, 4.0], [2.0, 3.0]])

        estimator = FrechetMean(
            metric=MatricesMetric(m, n), point_type="matrix")
        estimator.fit([matrix, matrix, matrix])

        self.assertAllClose(gs.shape(estimator.estimate_), (m, n))

    def test_mean_matrices(self):
        """Check the ``MatricesMetric`` mean of a repeated matrix is the matrix."""
        m, n = (2, 2)
        matrix = gs.array([[1.0, 4.0], [2.0, 3.0]])

        estimator = FrechetMean(
            metric=MatricesMetric(m, n), point_type="matrix")
        estimator.fit([matrix, matrix, matrix])

        self.assertAllClose(estimator.estimate_, matrix)

    def test_mean_minkowski_shape(self):
        """Check the Minkowski mean of 2-component points has shape ``(2,)``."""
        dim = 2
        sample = gs.array([2.0, -math.sqrt(3)])

        estimator = FrechetMean(metric=self.minkowski.metric)
        estimator.fit([sample, sample, sample])

        self.assertAllClose(gs.shape(estimator.estimate_), (dim,))

    def test_mean_minkowski(self):
        """Check the Minkowski mean, unweighted and weighted.

        First a point repeated three times must average to itself; then the
        weighted mean of four points must belong to ``self.minkowski``.
        """
        minkowski_metric = self.minkowski.metric

        # Unweighted: identical points passed as a plain list.
        sample = gs.array([2.0, -math.sqrt(3)])
        estimator = FrechetMean(metric=minkowski_metric)
        estimator.fit([sample, sample, sample])
        self.assertAllClose(estimator.estimate_, sample)

        # Weighted: only membership of the estimate is checked.
        samples = gs.array([
            [1.0, 0.0],
            [2.0, math.sqrt(3)],
            [3.0, math.sqrt(8)],
            [4.0, math.sqrt(24)],
        ])
        weights = gs.array([1.0, 2.0, 1.0, 2.0])
        weighted_estimator = FrechetMean(metric=minkowski_metric)
        weighted_estimator.fit(samples, weights=weights)

        self.assertTrue(
            self.minkowski.belongs(weighted_estimator.estimate_))

    def test_variance_minkowski(self):
        """Check the weighted Minkowski variance of four points is non-zero.

        Only ``var != 0`` is asserted; the value itself is not compared
        against a reference.
        """
        samples = gs.array([
            [1.0, 0.0],
            [2.0, math.sqrt(3)],
            [3.0, math.sqrt(8)],
            [4.0, math.sqrt(24)],
        ])
        weights = gs.array([1.0, 2.0, 1.0, 2.0])
        base_point = gs.array([-1.0, 0.0])

        var = variance(
            samples,
            weights=weights,
            base_point=base_point,
            metric=self.minkowski.metric,
        )

        is_nonzero = var != 0
        self.assertAllClose(is_nonzero, True)

    def test_one_point(self):
        """Check that fitting a single sphere point returns that point.

        The input is one unbatched 5-component point passed directly to
        ``fit``, not a collection of points.
        """
        point = gs.array([0.0, 0.0, 0.0, 0.0, 1.0])

        estimator = FrechetMean(metric=self.sphere.metric, method="default")
        estimator.fit(X=point)

        # NOTE(review): this check used to be repeated verbatim with the same
        # ``method="default"`` configuration. The copy added no coverage and
        # was dropped; if the second run was meant to exercise another
        # method, add it as a separate test.
        self.assertAllClose(point, estimator.estimate_)

    def test_batched(self):
        """Check ``method="batch"`` against fitting each batch separately.

        Four random SPD matrices are split in two halves and stacked along
        axis 1; the batched estimate must equal the stack of the estimates
        fitted on ``data[:, 0]`` and ``data[:, 1]``.
        """
        space = SPDMatrices(3)
        metric = SPDMetricAffine(3)
        samples = space.random_point(4)

        batch_estimator = FrechetMean(metric, method="batch", verbose=True)
        data = gs.stack([samples[:2], samples[2:]], axis=1)
        batch_estimator.fit(data)
        result = batch_estimator.estimate_

        # Reference: one estimator refitted on each batch in turn.
        single_estimator = FrechetMean(metric)
        per_batch = []
        for batch_index in (0, 1):
            single_estimator.fit(data[:, batch_index])
            per_batch.append(single_estimator.estimate_)

        self.assertAllClose(gs.stack(per_batch), result)

    @geomstats.tests.np_and_autograd_only
    def test_stiefel_two_samples(self):
        """Check the mean of two Stiefel points against a half-log step.

        The expected value is obtained by taking the log of the first point
        at the second, halving it, and applying ``exp`` at the second point.
        """
        space = Stiefel(3, 2)
        stiefel_metric = space.metric
        pair = space.random_point(2)

        estimator = FrechetMean(stiefel_metric)
        estimator.fit(pair)
        result = estimator.estimate_

        half_log = stiefel_metric.log(pair[0], pair[1]) / 2
        expected = stiefel_metric.exp(half_log, pair[1])
        self.assertAllClose(expected, result)

    @geomstats.tests.np_and_autograd_only
    def test_stiefel_n_samples(self):
        """Check the default-method mean of Stiefel points stays on the manifold.

        Two random points are averaged with ``init_step_size=0.5`` and
        ``verbose=True``.
        """
        space = Stiefel(3, 2)
        stiefel_metric = space.metric
        samples = space.random_point(2)

        estimator = FrechetMean(
            stiefel_metric,
            method="default",
            init_step_size=0.5,
            verbose=True,
        )
        estimator.fit(samples)

        self.assertTrue(space.belongs(estimator.estimate_))

    def test_circle_mean(self):
        """Check the circle mean is unaffected by overriding ``metric.dim``.

        The same ten points are averaged twice: once with the metric of
        ``Hypersphere(1)`` as built, once after its ``dim`` is set to 2.
        """
        circle = Hypersphere(1)
        samples = circle.random_uniform(10)

        circle_estimator = FrechetMean(circle.metric)
        circle_estimator.fit(samples)
        estimate_circle = circle_estimator.estimate_

        # set a wrong dimension so that the extrinsic coordinates are used
        altered_metric = circle.metric
        altered_metric.dim = 2
        extrinsic_estimator = FrechetMean(altered_metric)
        extrinsic_estimator.fit(samples)

        self.assertAllClose(estimate_circle, extrinsic_estimator.estimate_)