Example #1
0
    def test_exp_and_belongs(self):
        """Exponentiating a projected tangent vector must stay on H2.

        Starts from the apex [1, 0, 0] of the 2-dimensional hyperboloid,
        checks it belongs, projects an ambient vector onto its tangent
        space, applies the Riemannian exponential and checks the result
        still belongs to the manifold.
        """
        plane = Hyperbolic(dimension=2)
        metric = plane.metric

        apex = gs.array([1., 0., 0.])
        with self.session():
            self.assertTrue(gs.eval(plane.belongs(apex)))

        ambient_vector = gs.array([1., 2., 1.])
        tangent_vec = plane.projection_to_tangent_space(
            vector=ambient_vector, base_point=apex)
        end_point = metric.exp(tangent_vec=tangent_vec, base_point=apex)
        with self.session():
            self.assertTrue(gs.eval(plane.belongs(end_point)))
Example #2
0
class TestFrechetMean(geomstats.tests.TestCase):
    """Tests for the Frechet mean estimator and the `variance` helper."""

    _multiprocess_can_split_ = True

    def setUp(self):
        self.sphere = Hypersphere(dimension=4)
        self.hyperbolic = Hyperbolic(dimension=3)
        self.euclidean = Euclidean(dimension=2)
        self.minkowski = Minkowski(dimension=2)

    @geomstats.tests.np_only
    def test_adaptive_gradient_descent_sphere(self):
        """The mean of two sphere points has opposite logs towards them."""
        n_tests = 100
        result = gs.zeros(n_tests)
        expected = gs.zeros(n_tests)

        for i in range(n_tests):
            # take 2 random points, compute their mean, and verify that
            # log of each at the mean is opposite
            points = self.sphere.random_uniform(n_samples=2)
            mean = _adaptive_gradient_descent(points=points,
                                              metric=self.sphere.metric)

            logs = self.sphere.metric.log(point=points, base_point=mean)
            result[i] = gs.linalg.norm(logs[1, :] + logs[0, :])

        self.assertAllClose(expected, result, rtol=1e-10, atol=1e-10)

    @geomstats.tests.np_and_pytorch_only
    def test_estimate_and_belongs_sphere(self):
        """The fitted mean of two sphere points lies on the sphere."""
        point_a = gs.array([1., 0., 0., 0., 0.])
        point_b = gs.array([0., 1., 0., 0., 0.])
        points = gs.zeros((2, point_a.shape[0]))
        points[0, :] = point_a
        points[1, :] = point_b

        mean = FrechetMean(metric=self.sphere.metric)
        mean.fit(points)

        result = self.sphere.belongs(mean.estimate_)
        expected = gs.array([[True]])
        self.assertAllClose(result, expected)

    @geomstats.tests.np_and_pytorch_only
    def test_variance_sphere(self):
        """Variance of identical points around themselves is zero."""
        point = gs.array([0., 0., 0., 0., 1.])
        points = gs.zeros((2, point.shape[0]))
        points[0, :] = point
        points[1, :] = point

        result = variance(points, base_point=point, metric=self.sphere.metric)
        expected = helper.to_scalar(0.)

        self.assertAllClose(expected, result)

    @geomstats.tests.np_and_pytorch_only
    def test_estimate_sphere(self):
        """The mean of identical sphere points is that point."""
        point = gs.array([0., 0., 0., 0., 1.])
        points = gs.zeros((2, point.shape[0]))
        points[0, :] = point
        points[1, :] = point

        mean = FrechetMean(metric=self.sphere.metric)
        mean.fit(X=points)

        result = mean.estimate_
        expected = helper.to_vector(point)

        self.assertAllClose(expected, result)

    @geomstats.tests.np_and_tf_only
    def test_variance_hyperbolic(self):
        """Variance of identical hyperbolic points is zero."""
        point = gs.array([2., 1., 1., 1.])
        points = gs.array([point, point])
        result = variance(points,
                          base_point=point,
                          metric=self.hyperbolic.metric)
        expected = helper.to_scalar(0.)

        self.assertAllClose(result, expected)

    @geomstats.tests.np_and_tf_only
    def test_estimate_hyperbolic(self):
        """The mean of identical hyperbolic points is that point."""
        point = gs.array([2., 1., 1., 1.])
        points = gs.array([point, point])

        mean = FrechetMean(metric=self.hyperbolic.metric)
        mean.fit(X=points)

        result = mean.estimate_
        expected = helper.to_vector(point)

        self.assertAllClose(result, expected)

    @geomstats.tests.np_and_tf_only
    def test_estimate_and_belongs_hyperbolic(self):
        """The mean of random hyperbolic points lies on the manifold."""
        point_a = self.hyperbolic.random_uniform()
        point_b = self.hyperbolic.random_uniform()
        point_c = self.hyperbolic.random_uniform()
        points = gs.concatenate([point_a, point_b, point_c], axis=0)

        mean = FrechetMean(metric=self.hyperbolic.metric)
        mean.fit(X=points)

        result = self.hyperbolic.belongs(mean.estimate_)
        expected = gs.array([[True]])

        self.assertAllClose(result, expected)

    def test_mean_euclidean(self):
        """Euclidean Frechet mean equals the (weighted) arithmetic mean."""
        point = gs.array([[1., 4.]])

        mean = FrechetMean(metric=self.euclidean.metric)
        points = [point, point, point]
        mean.fit(points)

        result = mean.estimate_
        expected = point
        expected = helper.to_vector(expected)

        self.assertAllClose(result, expected)

        points = gs.array([[1., 2.], [2., 3.], [3., 4.], [4., 5.]])
        weights = gs.array([1., 2., 1., 2.])

        mean = FrechetMean(metric=self.euclidean.metric)
        mean.fit(points, weights=weights)

        result = mean.estimate_
        # Weighted arithmetic mean: sum(w_i * x_i) / sum(w_i), sum(w) = 6.
        expected = gs.array([16. / 6., 22. / 6.])
        expected = helper.to_vector(expected)

        self.assertAllClose(result, expected)

    def test_variance_euclidean(self):
        """Euclidean variance around the origin is the weighted sq norm."""
        points = gs.array([[1., 2.], [2., 3.], [3., 4.], [4., 5.]])
        weights = gs.array([1., 2., 1., 2.])
        base_point = gs.zeros(2)
        result = variance(points,
                          weights=weights,
                          base_point=base_point,
                          metric=self.euclidean.metric)
        # With the origin as base point we expect the weighted average of
        # the points' squared norms (5, 13, 25, 41), divided by sum(w) = 6.
        expected = (1 * 5. + 2 * 13. + 1 * 25. + 2 * 41.) / 6.
        expected = helper.to_scalar(expected)

        self.assertAllClose(result, expected)

    def test_mean_minkowski(self):
        """Minkowski mean of identical points is that point, and a
        weighted mean of distinct points belongs to the space."""
        point = gs.array([[2., -math.sqrt(3)]])
        points = [point, point, point]

        mean = FrechetMean(metric=self.minkowski.metric)
        mean.fit(points)
        result = mean.estimate_

        expected = point
        expected = helper.to_vector(expected)

        self.assertAllClose(result, expected)

        points = gs.array([[1., 0.], [2., math.sqrt(3)], [3., math.sqrt(8)],
                           [4., math.sqrt(24)]])
        weights = gs.array([1., 2., 1., 2.])

        mean = FrechetMean(metric=self.minkowski.metric)
        mean.fit(points, weights=weights)
        result = mean.estimate_
        result = self.minkowski.belongs(result)
        expected = gs.array([[True]])

        self.assertAllClose(result, expected)

    def test_variance_minkowski(self):
        """Weighted Minkowski variance of distinct points is non-zero."""
        points = gs.array([[1., 0.], [2., math.sqrt(3)], [3., math.sqrt(8)],
                           [4., math.sqrt(24)]])
        weights = gs.array([1., 2., 1., 2.])
        base_point = gs.array([-1., 0.])
        var = variance(points,
                       weights=weights,
                       base_point=base_point,
                       metric=self.minkowski.metric)
        result = helper.to_scalar(var != 0)
        # Only checks that the variance is non-zero; the exact weighted
        # value is deliberately not asserted here.
        expected = helper.to_scalar(gs.array([True]))
        self.assertAllClose(result, expected)
class TestHyperbolicMethods(geomstats.tests.TestCase):
    """Coordinate-change and metric consistency tests for H2 (point_type)."""

    def setUp(self):
        gs.random.seed(1234)
        self.dimension = 2

        self.extrinsic_manifold = Hyperbolic(dimension=self.dimension)
        self.ball_manifold = Hyperbolic(dimension=self.dimension,
                                        point_type='ball')
        self.intrinsic_manifold = Hyperbolic(dimension=self.dimension,
                                             point_type='intrinsic')
        self.half_plane_manifold = Hyperbolic(dimension=self.dimension,
                                              point_type='half-plane')
        self.ball_metric = HyperbolicMetric(dimension=self.dimension,
                                            point_type='ball')
        self.extrinsic_metric = HyperbolicMetric(dimension=self.dimension,
                                                 point_type='extrinsic')
        self.n_samples = 10

    @geomstats.tests.np_and_pytorch_only
    def test_extrinsic_ball_extrinsic(self):
        """Round trip extrinsic -> ball -> extrinsic is the identity."""
        x_in = gs.array([[0.5, 7]])
        x = self.intrinsic_manifold.to_coordinates(x_in,
                                                   to_point_type='extrinsic')
        x_b = self.extrinsic_manifold.to_coordinates(x, to_point_type='ball')
        x2 = self.ball_manifold.to_coordinates(x_b, to_point_type='extrinsic')
        self.assertAllClose(x, x2, atol=1e-8)

    @geomstats.tests.np_and_pytorch_only
    def test_extrinsic_half_plane_extrinsic(self):
        """Round trip extrinsic -> half-plane -> extrinsic is the identity."""
        x_in = gs.array([[0.5, 7]])
        x = self.intrinsic_manifold.to_coordinates(x_in,
                                                   to_point_type='extrinsic')
        x_up = self.extrinsic_manifold.to_coordinates(
            x, to_point_type='half-plane')

        x2 = self.half_plane_manifold.to_coordinates(x_up,
                                                     to_point_type='extrinsic')
        self.assertAllClose(x, x2, atol=1e-8)

    @geomstats.tests.np_and_pytorch_only
    def test_intrinsic_extrinsic_intrinsic(self):
        """Round trip intrinsic -> extrinsic -> intrinsic is the identity."""
        x_intr = gs.array([[0.5, 7]])
        x_extr = self.intrinsic_manifold.to_coordinates(
            x_intr, to_point_type='extrinsic')
        x_intr2 = self.extrinsic_manifold.to_coordinates(
            x_extr, to_point_type='intrinsic')
        self.assertAllClose(x_intr, x_intr2, atol=1e-8)

    @geomstats.tests.np_and_pytorch_only
    def test_ball_extrinsic_ball(self):
        """Round trip ball -> extrinsic -> ball is the identity."""
        x = gs.array([[0.5, 0.2]])
        x_e = self.ball_manifold.to_coordinates(x, to_point_type='extrinsic')
        x2 = self.extrinsic_manifold.to_coordinates(x_e, to_point_type='ball')
        self.assertAllClose(x, x2, atol=1e-10)

    @geomstats.tests.np_and_pytorch_only
    def test_belongs_ball(self):
        """A point inside the unit ball belongs to the ball model."""
        x = gs.array([[0.5, 0.2]])
        belong = self.ball_manifold.belongs(x)
        # Use the unittest assertion: a bare `assert` is stripped under -O
        # and reports no useful failure message.
        self.assertTrue(belong[0])

    @geomstats.tests.np_and_pytorch_only
    def test_distance_ball_extrinsic_from_ball(self):
        """Ball and extrinsic distances agree for points given in the ball."""
        x_ball = gs.array([[0.7, 0.2]])
        y_ball = gs.array([[0.2, 0.2]])
        x_extr = self.ball_manifold.to_coordinates(x_ball,
                                                   to_point_type='extrinsic')
        y_extr = self.ball_manifold.to_coordinates(y_ball,
                                                   to_point_type='extrinsic')
        dst_ball = self.ball_metric.dist(x_ball, y_ball)
        dst_extr = self.extrinsic_metric.dist(x_extr, y_extr)
        self.assertAllClose(dst_ball, dst_extr)

    @geomstats.tests.np_and_pytorch_only
    def test_distance_ball_extrinsic_from_extr(self):
        """Ball and extrinsic distances agree for intrinsic-derived points."""
        x_int = gs.array([[10, 0.2]])
        y_int = gs.array([[1, 6.]])
        x_extr = self.intrinsic_manifold.to_coordinates(
            x_int, to_point_type='extrinsic')
        y_extr = self.intrinsic_manifold.to_coordinates(
            y_int, to_point_type='extrinsic')
        x_ball = self.extrinsic_manifold.to_coordinates(x_extr,
                                                        to_point_type='ball')
        y_ball = self.extrinsic_manifold.to_coordinates(y_extr,
                                                        to_point_type='ball')
        dst_ball = self.ball_metric.dist(x_ball, y_ball)
        dst_extr = self.extrinsic_metric.dist(x_extr, y_extr)
        self.assertAllClose(dst_ball, dst_extr)

    @geomstats.tests.np_and_pytorch_only
    def test_distance_ball_extrinsic_from_extr_5_dim(self):
        """Same as above for H4 (extrinsic points have 5 coordinates)."""
        x_int = gs.array([[10, 0.2, 3, 4]])
        y_int = gs.array([[1, 6, 2., 1]])
        extrinsic_manifold = Hyperbolic(4, point_type='extrinsic')
        ball_metric = HyperbolicMetric(4, point_type='ball')
        extrinsic_metric = HyperbolicMetric(4, point_type='extrinsic')
        x_extr = extrinsic_manifold.from_coordinates(
            x_int, from_point_type='intrinsic')
        y_extr = extrinsic_manifold.from_coordinates(
            y_int, from_point_type='intrinsic')
        x_ball = extrinsic_manifold.to_coordinates(x_extr,
                                                   to_point_type='ball')
        y_ball = extrinsic_manifold.to_coordinates(y_extr,
                                                   to_point_type='ball')
        dst_ball = ball_metric.dist(x_ball, y_ball)
        dst_extr = extrinsic_metric.dist(x_extr, y_extr)
        self.assertAllClose(dst_ball, dst_extr)

    @geomstats.tests.np_and_pytorch_only
    def test_log_exp_ball_extrinsic_from_extr(self):
        """exp(log) computed in the ball matches exp(log) in extrinsic."""
        x_int = gs.array([[4., 0.2]])
        y_int = gs.array([[3., 3]])
        x_extr = self.intrinsic_manifold.to_coordinates(
            x_int, to_point_type='extrinsic')
        y_extr = self.intrinsic_manifold.to_coordinates(
            y_int, to_point_type='extrinsic')
        x_ball = self.extrinsic_manifold.to_coordinates(x_extr,
                                                        to_point_type='ball')
        y_ball = self.extrinsic_manifold.to_coordinates(y_extr,
                                                        to_point_type='ball')

        x_ball_log_exp = self.ball_metric.exp(
            self.ball_metric.log(y_ball, x_ball), x_ball)

        x_extr_a = self.extrinsic_metric.exp(
            self.extrinsic_metric.log(y_extr, x_extr), x_extr)
        x_extr_b = self.extrinsic_manifold.from_coordinates(
            x_ball_log_exp, from_point_type='ball')
        self.assertAllClose(x_extr_a, x_extr_b, atol=1e-4)

    @geomstats.tests.np_and_pytorch_only
    def test_log_exp_ball(self):
        """exp(log(y, x), x) recovers y in the ball model."""
        x = gs.array([[0.1, 0.2]])
        y = gs.array([[0.2, 0.5]])

        log = self.ball_metric.log(y, x)
        exp = self.ball_metric.exp(log, x)
        self.assertAllClose(exp, y)

    @geomstats.tests.np_and_pytorch_only
    def test_log_exp_ball_batch(self):
        """Batched exp(log(y, x), x) recovers every y in the ball model."""
        x = gs.array([[0.1, 0.2]])
        y = gs.array([[0.2, 0.5], [0.1, 0.7]])

        log = self.ball_metric.log(y, x)
        exp = self.ball_metric.exp(log, x)
        self.assertAllClose(exp, y)
class TestHyperbolicMethods(geomstats.tests.TestCase):
    """Tests for the hyperboloid model of 3-dimensional hyperbolic space."""

    def setUp(self):
        gs.random.seed(1234)
        self.dimension = 3
        self.space = Hyperbolic(dimension=self.dimension)
        self.metric = self.space.metric
        self.n_samples = 10

    def test_random_uniform_and_belongs(self):
        """A uniformly sampled point belongs to the space."""
        point = self.space.random_uniform()
        result = self.space.belongs(point)
        expected = gs.array([[True]])

        self.assertAllClose(result, expected)

    def test_random_uniform(self):
        """A single sample has shape (1, dimension + 1)."""
        result = self.space.random_uniform()

        self.assertAllClose(gs.shape(result), (1, self.dimension + 1))

    def test_intrinsic_and_extrinsic_coords(self):
        """
        Test that the composition of
        intrinsic_to_extrinsic_coords and
        extrinsic_to_intrinsic_coords
        gives the identity.
        """
        point_int = gs.ones(self.dimension)
        point_ext = self.space.intrinsic_to_extrinsic_coords(point_int)
        result = self.space.extrinsic_to_intrinsic_coords(point_ext)
        expected = point_int
        expected = helper.to_vector(expected)
        self.assertAllClose(result, expected)

        point_ext = gs.array([2.0, 1.0, 1.0, 1.0])
        point_int = self.space.extrinsic_to_intrinsic_coords(point_ext)
        result = self.space.intrinsic_to_extrinsic_coords(point_int)
        expected = point_ext
        expected = helper.to_vector(expected)

        self.assertAllClose(result, expected)

    def test_intrinsic_and_extrinsic_coords_vectorization(self):
        """
        Test that the composition of
        intrinsic_to_extrinsic_coords and
        extrinsic_to_intrinsic_coords
        gives the identity.
        """
        point_int = gs.array([[.1, 0., 0., .1, 0., 0.],
                              [.1, .1, .1, .4, .1, 0.],
                              [.1, .3, 0., .1, 0., 0.],
                              [-0.1, .1, -.4, .1, -.01, 0.],
                              [0., 0., .1, .1, -0.08, -0.1],
                              [.1, .1, .1, .1, 0., -0.5]])
        point_ext = self.space.intrinsic_to_extrinsic_coords(point_int)
        result = self.space.extrinsic_to_intrinsic_coords(point_ext)
        expected = point_int
        expected = helper.to_vector(expected)

        self.assertAllClose(result, expected)

        point_ext = gs.array([[2., 1., 1., 1.],
                              [4., 1., 3., math.sqrt(5.)],
                              [3., 2., 0., 2.]])
        point_int = self.space.extrinsic_to_intrinsic_coords(point_ext)
        result = self.space.intrinsic_to_extrinsic_coords(point_int)
        expected = point_ext
        expected = helper.to_vector(expected)

        self.assertAllClose(result, expected)

    def test_log_and_exp_general_case(self):
        """
        Test that the riemannian exponential
        and the riemannian logarithm are inverse.

        Expect their composition to give the identity function.
        """
        # Riemannian Log then Riemannian Exp
        # General case
        base_point = gs.array([4.0, 1., 3.0, math.sqrt(5.)])
        point = gs.array([2.0, 1.0, 1.0, 1.0])

        log = self.metric.log(point=point, base_point=base_point)

        result = self.metric.exp(tangent_vec=log, base_point=base_point)
        expected = helper.to_vector(point)
        self.assertAllClose(result, expected)

    def test_exp_and_belongs(self):
        """exp of a projected tangent vector on H2 stays on H2."""
        H2 = Hyperbolic(dimension=2)
        METRIC = H2.metric

        base_point = gs.array([1., 0., 0.])
        with self.session():
            self.assertTrue(gs.eval(H2.belongs(base_point)))

        tangent_vec = H2.projection_to_tangent_space(
            vector=gs.array([1., 2., 1.]),
            base_point=base_point)
        exp = METRIC.exp(tangent_vec=tangent_vec,
                         base_point=base_point)
        with self.session():
            self.assertTrue(gs.eval(H2.belongs(exp)))

    @geomstats.tests.np_and_pytorch_only
    def test_exp_vectorization(self):
        """exp broadcasts over tangent vectors and/or base points."""
        n_samples = 3
        dim = self.dimension + 1

        one_vec = gs.array([2.0, 1.0, 1.0, 1.0])
        one_base_point = gs.array([4.0, 3., 1.0, math.sqrt(5)])
        n_vecs = gs.array([[2., 1., 1., 1.],
                           [4., 1., 3., math.sqrt(5.)],
                           [3., 2., 0., 2.]])
        n_base_points = gs.array([
            [2.0, 0.0, 1.0, math.sqrt(2)],
            [5.0, math.sqrt(8), math.sqrt(8), math.sqrt(8)],
            [1.0, 0.0, 0.0, 0.0]])

        one_tangent_vec = self.space.projection_to_tangent_space(
            one_vec, base_point=one_base_point)
        result = self.metric.exp(one_tangent_vec, one_base_point)
        self.assertAllClose(gs.shape(result), (1, dim))

        n_tangent_vecs = self.space.projection_to_tangent_space(
            n_vecs, base_point=one_base_point)
        result = self.metric.exp(n_tangent_vecs, one_base_point)
        self.assertAllClose(gs.shape(result), (n_samples, dim))

        expected = gs.zeros((n_samples, dim))

        with self.session():
            for i in range(n_samples):
                expected[i] = gs.eval(
                    self.metric.exp(n_tangent_vecs[i], one_base_point))
            expected = helper.to_vector(gs.array(expected))
            self.assertAllClose(result, expected)

        one_tangent_vec = self.space.projection_to_tangent_space(
            one_vec, base_point=n_base_points)
        result = self.metric.exp(one_tangent_vec, n_base_points)
        self.assertAllClose(gs.shape(result), (n_samples, dim))

        expected = gs.zeros((n_samples, dim))
        with self.session():
            for i in range(n_samples):
                expected[i] = gs.eval(self.metric.exp(one_tangent_vec[i],
                                      n_base_points[i]))
            expected = helper.to_vector(gs.array(expected))
            self.assertAllClose(result, expected)

        n_tangent_vecs = self.space.projection_to_tangent_space(
            n_vecs, base_point=n_base_points)
        result = self.metric.exp(n_tangent_vecs, n_base_points)
        self.assertAllClose(gs.shape(result), (n_samples, dim))

        expected = gs.zeros((n_samples, dim))
        with self.session():
            for i in range(n_samples):
                expected[i] = gs.eval(self.metric.exp(n_tangent_vecs[i],
                                      n_base_points[i]))
            expected = helper.to_vector(gs.array(expected))
            self.assertAllClose(result, expected)

    def test_log_vectorization(self):
        """log broadcasts over points and/or base points (shapes only)."""
        n_samples = 3
        dim = self.dimension + 1

        one_point = gs.array([2.0, 1.0, 1.0, 1.0])
        one_base_point = gs.array([4.0, 3., 1.0, math.sqrt(5)])
        n_points = gs.array([[2.0, 1.0, 1.0, 1.0],
                             [4.0, 1., 3.0, math.sqrt(5)],
                             [3.0, 2.0, 0.0, 2.0]])
        n_base_points = gs.array([
            [2.0, 0.0, 1.0, math.sqrt(2)],
            [5.0, math.sqrt(8), math.sqrt(8), math.sqrt(8)],
            [1.0, 0.0, 0.0, 0.0]])

        result = self.metric.log(one_point, one_base_point)
        self.assertAllClose(gs.shape(result), (1, dim))

        result = self.metric.log(n_points, one_base_point)
        self.assertAllClose(gs.shape(result), (n_samples, dim))

        result = self.metric.log(one_point, n_base_points)
        self.assertAllClose(gs.shape(result), (n_samples, dim))

        result = self.metric.log(n_points, n_base_points)
        self.assertAllClose(gs.shape(result), (n_samples, dim))

    def test_inner_product(self):
        """
        Test that the inner product between two tangent vectors
        is the Minkowski inner product.
        """
        minkowski_space = Minkowski(self.dimension + 1)
        base_point = gs.array(
            [1.16563816, 0.36381045, -0.47000603, 0.07381469])

        tangent_vec_a = self.space.projection_to_tangent_space(
            vector=gs.array([10., 200., 1., 1.]),
            base_point=base_point)

        tangent_vec_b = self.space.projection_to_tangent_space(
            vector=gs.array([11., 20., -21., 0.]),
            base_point=base_point)

        result = self.metric.inner_product(
            tangent_vec_a, tangent_vec_b, base_point)

        expected = minkowski_space.metric.inner_product(
            tangent_vec_a, tangent_vec_b, base_point)

        with self.session():
            self.assertAllClose(result, expected)

    def test_squared_norm_and_squared_dist(self):
        """
        Test that the squared distance between two points is
        the squared norm of their logarithm.
        """
        point_a = gs.array([2.0, 1.0, 1.0, 1.0])
        point_b = gs.array([4.0, 1., 3.0, math.sqrt(5)])
        log = self.metric.log(point=point_a, base_point=point_b)
        result = self.metric.squared_norm(vector=log)
        expected = self.metric.squared_dist(point_a, point_b)

        with self.session():
            self.assertAllClose(result, expected)

    def test_norm_and_dist(self):
        """
        Test that the distance between two points is
        the norm of their logarithm.
        """
        point_a = gs.array([2.0, 1.0, 1.0, 1.0])
        point_b = gs.array([4.0, 1., 3.0, math.sqrt(5)])
        log = self.metric.log(point=point_a, base_point=point_b)
        result = self.metric.norm(vector=log)
        expected = self.metric.dist(point_a, point_b)

        with self.session():
            self.assertAllClose(result, expected)

    def test_log_and_exp_edge_case(self):
        """
        Test that the riemannian exponential
        and the riemannian logarithm are inverse.

        Expect their composition to give the identity function.
        """
        # Riemannian Log then Riemannian Exp
        # Edge case: two very close points, base_point and point,
        # form an angle < epsilon
        base_point_intrinsic = gs.array([1., 2., 3.])
        base_point = self.space.intrinsic_to_extrinsic_coords(
            base_point_intrinsic)
        point_intrinsic = (base_point_intrinsic +
                           1e-12 * gs.array([-1., -2., 1.]))
        point = self.space.intrinsic_to_extrinsic_coords(
            point_intrinsic)

        log = self.metric.log(point=point, base_point=base_point)
        result = self.metric.exp(tangent_vec=log, base_point=base_point)
        expected = point

        with self.session():
            self.assertAllClose(result, expected)

    @geomstats.tests.np_and_tf_only
    def test_exp_and_log_and_projection_to_tangent_space_general_case(self):
        """
        Test that the riemannian exponential
        and the riemannian logarithm are inverse.

        Expect their composition to give the identity function.
        """
        # Riemannian Exp then Riemannian Log
        # General case
        base_point = gs.array([4.0, 1., 3.0, math.sqrt(5)])
        vector = gs.array([2.0, 1.0, 1.0, 1.0])
        vector = self.space.projection_to_tangent_space(
            vector=vector,
            base_point=base_point)
        exp = self.metric.exp(tangent_vec=vector, base_point=base_point)
        result = self.metric.log(point=exp, base_point=base_point)

        expected = vector
        with self.session():
            self.assertAllClose(result, expected)

    def test_dist(self):
        """Distance between a point and itself is 0."""
        point_a = gs.array([4.0, 1., 3.0, math.sqrt(5)])
        point_b = point_a
        result = self.metric.dist(point_a, point_b)
        expected = gs.array([[0]])

        with self.session():
            self.assertAllClose(result, expected)

    @geomstats.tests.np_and_pytorch_only
    def test_dist_poincare(self):
        """Distance computed after switching the metric to ball coords."""

        point_a = gs.array([0.5, 0.5])
        point_b = gs.array([0.5, -0.5])

        # self.metric is the same object as self.space.metric.
        self.space.metric.point_type = 'ball'

        dist_a_b = self.metric.dist(point_a, point_b)
        self.space.metric.point_type = 'extrinsic'

        result = dist_a_b
        expected = gs.array([[2.887270927429199]])

        with self.session():
            self.assertAllClose(result, expected)

    def test_exp_poincare(self):
        """Placeholder for exp in ball coordinates.

        NOTE(review): this only compares 0 with 0 and never calls exp;
        TODO replace with a real check of exp in ball coordinates.
        """

        self.space.metric.point_type = 'ball'
        result = 0
        expected = 0
        self.space.metric.point_type = 'extrinsic'
        with self.session():
            self.assertAllClose(result, expected)

    @geomstats.tests.np_only
    def test_log_poincare(self):
        """log computed after switching the metric to ball coords."""

        point = gs.array([0.3, 0.5])
        base_point = gs.array([0.3, 0.3])

        self.space.metric.point_type = 'ball'
        result = self.space.metric.log(point, base_point)
        expected = gs.array([-0.01733576, 0.21958634])

        self.space.metric.point_type = 'extrinsic'
        with self.session():
            self.assertAllClose(result, expected)

    def test_exp_and_dist_and_projection_to_tangent_space(self):
        """dist(base_point, exp(v)) equals the norm of tangent vector v."""
        base_point = gs.array([4.0, 1., 3.0, math.sqrt(5)])
        vector = gs.array([0.001, 0., -.00001, -.00003])
        tangent_vec = self.space.projection_to_tangent_space(
            vector=vector,
            base_point=base_point)
        exp = self.metric.exp(tangent_vec=tangent_vec,
                              base_point=base_point)

        result = self.metric.dist(base_point, exp)
        # A distance must be compared with a norm, not a squared norm:
        # the previous comparison only passed because both are tiny here.
        expected = self.metric.embedding_metric.norm(tangent_vec)
        # Loose tolerance kept from the original; presumably needed for
        # single-precision backends — TODO confirm before tightening.
        with self.session():
            self.assertAllClose(result, expected, atol=1e-2)

    def test_geodesic_and_belongs(self):
        """All points sampled along a geodesic belong to the space."""
        # TODO(nina): Fix this tests, as it fails when geodesic goes "too far"
        initial_point = gs.array([4.0, 1., 3.0, math.sqrt(5)])
        n_geodesic_points = 100
        vector = gs.array([1., 0., 0., 0.])

        initial_tangent_vec = self.space.projection_to_tangent_space(
            vector=vector,
            base_point=initial_point)
        geodesic = self.metric.geodesic(
            initial_point=initial_point,
            initial_tangent_vec=initial_tangent_vec)

        t = gs.linspace(start=0., stop=1., num=n_geodesic_points)
        points = geodesic(t)

        result = self.space.belongs(points)
        expected = gs.array(n_geodesic_points * [[True]])

        with self.session():
            self.assertAllClose(expected, result)

    def test_exp_and_log_and_projection_to_tangent_space_edge_case(self):
        """
        Test that the riemannian exponential and
        the riemannian logarithm are inverse.

        Expect their composition to give the identity function.
        """
        # Riemannian Exp then Riemannian Log
        # Edge case: tangent vector has norm < epsilon
        base_point = gs.array([2., 1., 1., 1.])
        vector = 1e-10 * gs.array([.06, -51., 6., 5.])

        # NOTE(review): `vector` is passed to exp unprojected while the
        # expectation is its projection — confirm exp projects internally.
        exp = self.metric.exp(tangent_vec=vector, base_point=base_point)
        result = self.metric.log(point=exp, base_point=base_point)
        expected = self.space.projection_to_tangent_space(
            vector=vector,
            base_point=base_point)

        self.assertAllClose(result, expected, atol=1e-8)

    @geomstats.tests.np_and_tf_only
    def test_variance(self):
        """Variance of identical points is zero."""
        point = gs.array([2., 1., 1., 1.])
        points = gs.array([point, point])
        result = self.metric.variance(points)
        expected = helper.to_scalar(0.)

        self.assertAllClose(result, expected)

    @geomstats.tests.np_and_tf_only
    def test_mean(self):
        """Mean of identical points is that point."""
        point = gs.array([2., 1., 1., 1.])
        points = gs.array([point, point])
        result = self.metric.mean(points)
        expected = helper.to_vector(point)

        self.assertAllClose(result, expected)

    @geomstats.tests.np_and_tf_only
    def test_mean_and_belongs(self):
        """Mean of random points belongs to the space."""
        point_a = self.space.random_uniform()
        point_b = self.space.random_uniform()
        point_c = self.space.random_uniform()
        points = gs.concatenate([point_a, point_b, point_c], axis=0)

        mean = self.metric.mean(points)
        result = self.space.belongs(mean)
        expected = gs.array([[True]])

        self.assertAllClose(result, expected)

    @geomstats.tests.np_only
    def test_scaled_inner_product(self):
        """Scaling the space by s scales the inner product by s ** 2."""
        base_point_intrinsic = gs.array([1, 1, 1])
        base_point = self.space.intrinsic_to_extrinsic_coords(
            base_point_intrinsic)
        tangent_vec_a = gs.array([1, 2, 3, 4])
        tangent_vec_b = gs.array([5, 6, 7, 8])
        tangent_vec_a = self.space.projection_to_tangent_space(
            tangent_vec_a,
            base_point)
        tangent_vec_b = self.space.projection_to_tangent_space(
            tangent_vec_b,
            base_point)
        scale = 2
        default_space = Hyperbolic(dimension=self.dimension)
        scaled_space = Hyperbolic(dimension=self.dimension, scale=scale)
        inner_product_default_metric = \
            default_space.metric.inner_product(
                tangent_vec_a,
                tangent_vec_b,
                base_point)
        inner_product_scaled_metric = \
            scaled_space.metric.inner_product(
                tangent_vec_a,
                tangent_vec_b,
                base_point)
        result = inner_product_scaled_metric
        expected = scale ** 2 * inner_product_default_metric
        self.assertAllClose(result, expected)

    @geomstats.tests.np_only
    def test_scaled_squared_norm(self):
        """Scaling the space by s scales the squared norm by s ** 2."""
        base_point_intrinsic = gs.array([1, 1, 1])
        base_point = self.space.intrinsic_to_extrinsic_coords(
            base_point_intrinsic)
        tangent_vec = gs.array([1, 2, 3, 4])
        tangent_vec = self.space.projection_to_tangent_space(
            tangent_vec, base_point)
        scale = 2
        default_space = Hyperbolic(dimension=self.dimension)
        scaled_space = Hyperbolic(dimension=self.dimension, scale=scale)
        squared_norm_default_metric = default_space.metric.squared_norm(
            tangent_vec, base_point)
        squared_norm_scaled_metric = scaled_space.metric.squared_norm(
            tangent_vec, base_point)
        result = squared_norm_scaled_metric
        expected = scale ** 2 * squared_norm_default_metric
        self.assertAllClose(result, expected)

    @geomstats.tests.np_only
    def test_scaled_distance(self):
        """Scaling the space by s scales distances by s."""
        point_a_intrinsic = gs.array([1, 2, 3])
        point_b_intrinsic = gs.array([4, 5, 6])
        point_a = self.space.intrinsic_to_extrinsic_coords(point_a_intrinsic)
        point_b = self.space.intrinsic_to_extrinsic_coords(point_b_intrinsic)
        scale = 2
        default_space = Hyperbolic(dimension=self.dimension)
        scaled_space = Hyperbolic(dimension=self.dimension, scale=scale)
        distance_default_metric = default_space.metric.dist(point_a, point_b)
        distance_scaled_metric = scaled_space.metric.dist(point_a, point_b)
        result = distance_scaled_metric
        expected = scale * distance_default_metric
        self.assertAllClose(result, expected)
class TestHyperbolicMethods(geomstats.tests.TestCase):
    def setUp(self):
        """Seed the RNG and build the 2-D hyperbolic spaces under test.

        One ``Hyperbolic`` manifold is created per coordinate system used by
        the tests (the constructor's default ``coords_type`` for
        ``extrinsic_manifold``, then 'ball', 'intrinsic' and 'half-plane'),
        plus the 'ball' and 'extrinsic' ``HyperbolicMetric`` objects that the
        distance and log/exp tests compare.
        """
        gs.random.seed(1234)
        self.dimension = 2
        dim = self.dimension

        # One manifold per coordinate system.
        self.extrinsic_manifold = Hyperbolic(dimension=dim)
        self.ball_manifold = Hyperbolic(dimension=dim, coords_type='ball')
        self.intrinsic_manifold = Hyperbolic(
            dimension=dim, coords_type='intrinsic')
        self.half_plane_manifold = Hyperbolic(
            dimension=dim, coords_type='half-plane')

        # Metrics whose distances / log / exp are cross-checked.
        self.ball_metric = HyperbolicMetric(dimension=dim, coords_type='ball')
        self.extrinsic_metric = HyperbolicMetric(
            dimension=dim, coords_type='extrinsic')

        self.n_samples = 10

    @geomstats.tests.np_and_pytorch_only
    def test_extrinsic_ball_extrinsic(self):
        """Round-trip extrinsic -> ball -> extrinsic recovers the point."""
        intrinsic_point = gs.array([[0.5, 7]])
        extrinsic_point = self.intrinsic_manifold.to_coordinates(
            intrinsic_point, to_coords_type='extrinsic')
        ball_point = self.extrinsic_manifold.to_coordinates(
            extrinsic_point, to_coords_type='ball')
        round_trip = self.ball_manifold.to_coordinates(
            ball_point, to_coords_type='extrinsic')
        self.assertAllClose(extrinsic_point, round_trip, atol=1e-8)

    @geomstats.tests.np_and_pytorch_only
    def test_belongs_intrinsic(self):
        """A point given in intrinsic coordinates is accepted by belongs."""
        self.assertTrue(
            self.intrinsic_manifold.belongs(gs.array([[0.5, 7]])))

    @geomstats.tests.np_and_pytorch_only
    def test_belongs_extrinsic(self):
        """belongs accepts a lifted point and rejects an arbitrary 3-vector.

        The accepted point is produced by converting intrinsic coordinates
        to extrinsic ones; the rejected point is a hand-written vector.
        """
        point_on_manifold = self.intrinsic_manifold.to_coordinates(
            gs.array([[0.5, 7]]), 'extrinsic')
        point_off_manifold = gs.array([[0.5, 7, 3.]])

        self.assertTrue(self.extrinsic_manifold.belongs(point_on_manifold))
        self.assertFalse(self.extrinsic_manifold.belongs(point_off_manifold))

    @geomstats.tests.np_and_pytorch_only
    def test_belongs_ball(self):
        """A point of norm < 1 is in the ball; one of norm > 1 is not."""
        inside = gs.array([[0.5, 0.5]])
        outside = gs.array([[0.8, 0.8]])
        manifold = self.ball_manifold
        self.assertTrue(manifold.belongs(inside))
        self.assertFalse(manifold.belongs(outside))

    @geomstats.tests.np_and_pytorch_only
    def test_belongs_half_plane(self):
        """A positive last coordinate is in the half-plane; negative is not."""
        upper = gs.array([[0.5, 0.5]])
        lower = gs.array([[0.8, -0.8]])
        manifold = self.half_plane_manifold
        self.assertTrue(manifold.belongs(upper))
        self.assertFalse(manifold.belongs(lower))

    @geomstats.tests.np_and_pytorch_only
    def test_extrinsic_half_plane_extrinsic(self):
        """Round-trip extrinsic -> half-plane -> extrinsic recovers point."""
        extrinsic_point = self.intrinsic_manifold.to_coordinates(
            gs.array([[0.5, 7]]), to_coords_type='extrinsic')
        half_plane_point = self.extrinsic_manifold.to_coordinates(
            extrinsic_point, to_coords_type='half-plane')
        round_trip = self.half_plane_manifold.to_coordinates(
            half_plane_point, to_coords_type='extrinsic')
        self.assertAllClose(extrinsic_point, round_trip, atol=1e-8)

    @geomstats.tests.np_and_pytorch_only
    def test_intrinsic_extrinsic_intrinsic(self):
        """Round-trip intrinsic -> extrinsic -> intrinsic recovers point."""
        start = gs.array([[0.5, 7]])
        lifted = self.intrinsic_manifold.to_coordinates(
            start, to_coords_type='extrinsic')
        back = self.extrinsic_manifold.to_coordinates(
            lifted, to_coords_type='intrinsic')
        self.assertAllClose(start, back, atol=1e-8)

    @geomstats.tests.np_and_pytorch_only
    def test_ball_extrinsic_ball(self):
        """Round-trip ball -> extrinsic -> ball recovers the point."""
        ball_point = gs.array([[0.5, 0.2]])
        extrinsic_point = self.ball_manifold.to_coordinates(
            ball_point, to_coords_type='extrinsic')
        self.assertAllClose(
            ball_point,
            self.extrinsic_manifold.to_coordinates(
                extrinsic_point, to_coords_type='ball'),
            atol=1e-10)

    @geomstats.tests.np_and_pytorch_only
    def test_distance_ball_extrinsic_from_ball(self):
        """Ball and extrinsic metrics agree on points given in the ball.

        Both points are converted to extrinsic coordinates and the distance
        computed in each coordinate system is compared.
        """
        ball_points = (gs.array([[0.7, 0.2]]), gs.array([[0.2, 0.2]]))
        point_a_extr, point_b_extr = [
            self.ball_manifold.to_coordinates(p, to_coords_type='extrinsic')
            for p in ball_points]
        dist_in_ball = self.ball_metric.dist(*ball_points)
        dist_in_extr = self.extrinsic_metric.dist(point_a_extr, point_b_extr)
        self.assertAllClose(dist_in_ball, dist_in_extr)

    @geomstats.tests.np_and_pytorch_only
    def test_distance_ball_extrinsic_from_extr(self):
        """Ball and extrinsic metrics agree on points lifted from intrinsic.

        Intrinsic points are converted to extrinsic coordinates, then to
        ball coordinates, and the distances in both systems are compared.
        """
        intrinsic_points = (gs.array([[10, 0.2]]), gs.array([[1, 6.]]))
        extrinsic_points = [
            self.intrinsic_manifold.to_coordinates(
                p, to_coords_type='extrinsic')
            for p in intrinsic_points]
        ball_points = [
            self.extrinsic_manifold.to_coordinates(p, to_coords_type='ball')
            for p in extrinsic_points]
        self.assertAllClose(
            self.ball_metric.dist(*ball_points),
            self.extrinsic_metric.dist(*extrinsic_points))

    @geomstats.tests.np_and_pytorch_only
    def test_distance_ball_extrinsic_from_extr_5_dim(self):
        """Ball and extrinsic metrics agree for a 4-dimensional space.

        The ``5_dim`` in the name presumably refers to the number of
        extrinsic coordinates. Points are built from intrinsic coordinates,
        converted to ball coordinates, and the distances are compared.
        """
        dim = 4
        intrinsic_points = (gs.array([[10, 0.2, 3, 4]]),
                            gs.array([[1, 6, 2., 1]]))
        extrinsic_manifold = Hyperbolic(dim, coords_type='extrinsic')
        ball_metric = HyperbolicMetric(dim, coords_type='ball')
        extrinsic_metric = HyperbolicMetric(dim, coords_type='extrinsic')
        extrinsic_points = [
            extrinsic_manifold.from_coordinates(
                p, from_coords_type='intrinsic')
            for p in intrinsic_points]
        ball_points = [
            extrinsic_manifold.to_coordinates(p, to_coords_type='ball')
            for p in extrinsic_points]
        dist_in_ball = ball_metric.dist(*ball_points)
        dist_in_extr = extrinsic_metric.dist(*extrinsic_points)
        self.assertAllClose(dist_in_ball, dist_in_extr)

    @geomstats.tests.np_and_pytorch_only
    def test_log_exp_ball_extrinsic_from_extr(self):
        """Compare log exp in different parameterizations.

        Currently disabled. The test skips explicitly so the runner reports
        it as skipped rather than as a vacuous pass. The commented-out body
        is kept as the reference for re-enabling it.
        """
        # TODO(Hazaatiti): Fix this test
        self.skipTest(
            'log/exp comparison between ball and extrinsic coordinates '
            'is disabled until fixed')
        # Reference body. Keyword names are updated to the to_coords_type /
        # from_coords_type API used by the other tests in this class.
        # x_int = gs.array([[4., 0.2]])
        # y_int = gs.array([[3., 3]])
        # x_extr = self.intrinsic_manifold.to_coordinates(
        #     x_int, to_coords_type='extrinsic')
        # y_extr = self.intrinsic_manifold.to_coordinates(
        #     y_int, to_coords_type='extrinsic')
        # x_ball = self.extrinsic_manifold.to_coordinates(
        #     x_extr, to_coords_type='ball')
        # y_ball = self.extrinsic_manifold.to_coordinates(
        #     y_extr, to_coords_type='ball')

        # x_ball_log_exp = self.ball_metric.exp(
        #     self.ball_metric.log(y_ball, x_ball), x_ball)

        # x_extr_a = self.extrinsic_metric.exp(
        #     self.extrinsic_metric.log(y_extr, x_extr), x_extr)
        # x_extr_b = self.extrinsic_manifold.from_coordinates(
        #     x_ball_log_exp, from_coords_type='ball')
        # self.assertAllClose(x_extr_a, x_extr_b, atol=1e-4)

    @geomstats.tests.np_only
    def test_log_exp_ball(self):
        """exp(log(y, x), x) recovers y in ball coordinates."""
        base_point = gs.array([[0.1, 0.2]])
        target = gs.array([[0.2, 0.5]])
        tangent_vec = self.ball_metric.log(
            point=target, base_point=base_point)
        recovered = self.ball_metric.exp(
            tangent_vec=tangent_vec, base_point=base_point)
        # NOTE(review): atol=1e-1 is very loose; presumably it absorbs
        # numerical error in the ball log/exp. Confirm before tightening.
        self.assertAllClose(recovered, target, atol=1e-1)

    @geomstats.tests.np_only
    def test_log_exp_ball_vectorization(self):
        """exp(log(y, x), x) recovers each of several targets y.

        A single base point of shape (1, 2) is paired with two targets of
        shape (2, 2). Presumably the base point is broadcast over the
        targets.
        """
        base_point = gs.array([[0.1, 0.2]])
        targets = gs.array([[0.2, 0.5], [0.1, 0.7]])
        tangent_vecs = self.ball_metric.log(
            point=targets, base_point=base_point)
        recovered = self.ball_metric.exp(
            tangent_vec=tangent_vecs, base_point=base_point)
        self.assertAllClose(recovered, targets, atol=1e-1)

    @geomstats.tests.np_only
    def test_log_exp_ball_null_tangent(self):
        """exp of a zero tangent vector returns the base point unchanged."""
        base_points = gs.array([[0.1, 0.2], [0.1, 0.2]])
        zero_vecs = gs.zeros((2, 2))
        self.assertAllClose(
            self.ball_metric.exp(tangent_vec=zero_vecs, base_point=base_points),
            base_points,
            atol=1e-10)