コード例 #1
0
ファイル: test_hyperbolic.py プロジェクト: xpennec/geomstats
    def test_log_and_exp_general_case_general_dim(self):
        """Check that exp undoes log on a 5-dimensional hyperboloid.

        The round trip is verified for a single pair of random points and
        then for batches made of ``self.n_samples`` copies of that pair,
        where the output shapes are checked as well.
        """
        dim = 5
        n_samples = self.n_samples
        batch_shape = (n_samples, dim + 1)

        space = Hyperboloid(dim=dim)
        metric = space.metric

        def tile(array):
            # Repeat one array n_samples times along a new leading axis.
            return gs.stack([array] * n_samples, axis=0)

        # Draw order matters for the seeded RNG: base point, then point.
        single_base = space.random_point()
        single_point = space.random_point()
        single_point = gs.cast(single_point, gs.float64)
        single_base = gs.cast(single_base, gs.float64)
        single_log = metric.log(point=single_point, base_point=single_base)

        # Single log followed by single exp.
        round_trip = metric.exp(tangent_vec=single_log, base_point=single_base)
        self.assertAllClose(round_trip, single_point)

        # Batched log must reproduce the single log on every row.
        batch_base = tile(single_base)
        batch_point = tile(single_point)
        batch_log_expected = tile(single_log)

        batch_log = metric.log(point=batch_point, base_point=batch_base)
        self.assertAllClose(gs.shape(batch_log), batch_shape)
        self.assertAllClose(batch_log, batch_log_expected)

        # Exp of the batched log gives the batched points back.
        batch_exp = metric.exp(tangent_vec=batch_log, base_point=batch_base)
        self.assertAllClose(gs.shape(batch_exp), batch_shape)
        self.assertAllClose(batch_exp, batch_point)

        # Batched exp applied to copies of the single log.
        tiled_log = tile(single_log)
        batch_exp = metric.exp(tangent_vec=tiled_log, base_point=batch_base)
        self.assertAllClose(gs.shape(batch_exp), batch_shape)
        self.assertAllClose(batch_exp, batch_point)
コード例 #2
0
class TestToTangentSpace(geomstats.tests.TestCase):
    """Tests of the ``ToTangentSpace`` transformer on several geometries."""

    _multiprocess_can_split_ = True

    def setup_method(self):
        """Seed the backend RNG and build the spaces shared by the tests."""
        gs.random.seed(123)
        self.sphere = Hypersphere(dim=4)
        self.hyperbolic = Hyperboloid(dim=3)
        self.euclidean = Euclidean(dim=2)
        self.minkowski = Minkowski(dim=2)
        self.so3 = SpecialOrthogonal(n=3, point_type="vector")
        self.so_matrix = SpecialOrthogonal(n=3, point_type="matrix")

    def test_estimate_transform_sphere(self):
        """Two copies of one sphere point must transform to zeros."""
        pole = gs.array([0.0, 0.0, 0.0, 0.0, 1.0])
        samples = gs.array([pole, pole])

        to_tangent = ToTangentSpace(geometry=self.sphere)
        to_tangent.fit(X=samples)
        transformed = to_tangent.transform(samples)

        self.assertAllClose(gs.zeros_like(samples), transformed)

    def test_inverse_transform_no_fit_sphere(self):
        """transform then inverse_transform at an explicit base point."""
        samples = self.sphere.random_uniform(3)
        # The first sample is the base point, the remaining ones are data.
        base_point, data = samples[0], samples[1:]

        to_tangent = ToTangentSpace(geometry=self.sphere)
        tangent_data = to_tangent.transform(data, base_point=base_point)
        recovered = to_tangent.inverse_transform(
            tangent_data, base_point=base_point)

        self.assertAllClose(data, recovered)

    @geomstats.tests.np_autograd_and_tf_only
    def test_estimate_transform_so_group(self):
        """Two copies of one rotation matrix must transform to zeros."""
        rotation = self.so_matrix.random_uniform()
        samples = gs.array([rotation, rotation])

        to_tangent = ToTangentSpace(geometry=self.so_matrix)
        to_tangent.fit(X=samples)
        transformed = to_tangent.transform(samples)

        self.assertAllClose(gs.zeros((2, 6)), transformed)

    def test_estimate_transform_spd(self):
        """Two copies of one SPD matrix must transform to zeros."""
        matrix = spd.SPDMatrices(3).random_point()
        samples = gs.stack([matrix, matrix])

        to_tangent = ToTangentSpace(geometry=spd.SPDMetricAffine(3))
        to_tangent.fit(X=samples)
        transformed = to_tangent.transform(samples)

        self.assertAllClose(gs.zeros((2, 6)), transformed, atol=1e-5)

    def test_fit_transform_hyperbolic(self):
        """fit_transform of two identical hyperboloid points gives zeros."""
        sample = gs.array([2.0, 1.0, 1.0, 1.0])
        samples = gs.array([sample, sample])

        to_tangent = ToTangentSpace(geometry=self.hyperbolic.metric)
        transformed = to_tangent.fit_transform(X=samples)

        self.assertAllClose(gs.zeros_like(samples), transformed)

    def test_inverse_transform_hyperbolic(self):
        """inverse_transform must undo fit_transform on the hyperboloid."""
        samples = self.hyperbolic.random_point(10)

        to_tangent = ToTangentSpace(geometry=self.hyperbolic.metric)
        tangent_data = to_tangent.fit_transform(X=samples)
        recovered = to_tangent.inverse_transform(tangent_data)

        self.assertAllClose(samples, recovered)

    def test_inverse_transform_spd(self):
        """inverse_transform must undo fit_transform for two SPD metrics."""
        samples = spd.SPDMatrices(3).random_point(10)

        # Same round trip, first log-Euclidean then affine-invariant.
        for metric_class in (spd.SPDMetricLogEuclidean, spd.SPDMetricAffine):
            to_tangent = ToTangentSpace(geometry=metric_class(3))
            tangent_data = to_tangent.fit_transform(X=samples)
            recovered = to_tangent.inverse_transform(tangent_data)
            self.assertAllClose(samples, recovered, atol=1e-4)

    @geomstats.tests.np_autograd_and_tf_only
    def test_inverse_transform_so(self):
        """Round trip on rotation matrices, using the identity as base."""
        samples = self.so_matrix.random_uniform(10)

        to_tangent = ToTangentSpace(
            geometry=self.so_matrix.bi_invariant_metric)
        tangent_data = to_tangent.transform(
            X=samples, base_point=self.so_matrix.identity)
        recovered = to_tangent.inverse_transform(
            tangent_data, base_point=self.so_matrix.identity)

        self.assertAllClose(samples, recovered)
コード例 #3
0
def main():
    """Run tangent PCA at the Frechet mean of random points on H2.

    Logs the projection of the first five points on the principal
    components, then shows a figure with the explained variance ratio on
    the left and, on the right, the mean, the two principal geodesics and
    the data drawn in the Poincare disk.
    """
    fig = plt.figure(figsize=(15, 5))

    hyperbolic_plane = Hyperboloid(dim=2)
    data = hyperbolic_plane.random_point(n_samples=140)
    metric = hyperbolic_plane.metric

    # Estimate the mean and use it as the base point of the tangent PCA.
    mean = FrechetMean(metric=metric)
    mean.fit(data)
    mean_estimate = mean.estimate_

    tpca = TangentPCA(metric=metric, n_components=2)
    tpca = tpca.fit(data, base_point=mean_estimate)
    tangent_projected_data = tpca.transform(data)

    # One geodesic through the mean per principal component.
    geodesics = [
        metric.geodesic(
            initial_point=mean_estimate,
            initial_tangent_vec=tpca.components_[index])
        for index in range(2)
    ]

    n_steps = 100
    t = np.linspace(-1, 1, n_steps)
    geodesic_points = [geodesic(t) for geodesic in geodesics]

    logging.info(
        'Coordinates of the Log of the first 5 data points at the mean, '
        'projected on the principal components:')
    logging.info('\n{}'.format(tangent_projected_data[:5]))

    # Left panel: explained variance ratio per component.
    ax_var = fig.add_subplot(121)
    xticks = np.arange(1, 3, 1)
    ax_var.xaxis.set_ticks(xticks)
    ax_var.set_title('Explained variance')
    ax_var.set_xlabel('Number of Principal Components')
    ax_var.set_ylim((0, 1))
    ax_var.plot(xticks, tpca.explained_variance_ratio_)

    # Right panel: mean, principal geodesics, then the data.
    ax = fig.add_subplot(122)
    disk = 'H2_poincare_disk'

    visualization.plot(
        mean_estimate, ax, space=disk, color='darkgreen', s=10)
    for points in geodesic_points:
        visualization.plot(points, ax, space=disk, linewidth=2)
    visualization.plot(data, ax, space=disk, color='black', alpha=0.7)

    plt.show()
コード例 #4
0
class TestFrechetMean(geomstats.tests.TestCase):
    """Tests of ``FrechetMean`` and ``variance`` on several spaces."""

    _multiprocess_can_split_ = True

    def setUp(self):
        """Seed the backend RNG and build the spaces shared by the tests."""
        gs.random.seed(123)
        self.sphere = Hypersphere(dim=4)
        self.hyperbolic = Hyperboloid(dim=3)
        self.euclidean = Euclidean(dim=2)
        self.minkowski = Minkowski(dim=2)
        self.so3 = SpecialOrthogonal(n=3, point_type='vector')
        self.so_matrix = SpecialOrthogonal(n=3)

    def _norms_of_summed_logs(self, estimator, n_tests):
        """Return, for n_tests random pairs, the norm of the summed logs.

        Each iteration draws two sphere points, fits ``estimator`` on them
        and sums the logs of both points taken at the estimated mean.
        """
        norms = []
        for _ in range(n_tests):
            pair = self.sphere.random_uniform(n_samples=2)
            estimator.fit(pair)
            center = estimator.estimate_

            logs = self.sphere.metric.log(point=pair, base_point=center)
            norms.append(gs.linalg.norm(logs[1, :] + logs[0, :]))
        return gs.stack(norms)

    def test_logs_at_mean_default_gradient_descent_sphere(self):
        """Logs of two points at their mean are opposite ('default')."""
        n_tests = 10
        estimator = FrechetMean(
            metric=self.sphere.metric, method='default', lr=1.)

        norms = self._norms_of_summed_logs(estimator, n_tests)

        self.assertAllClose(gs.zeros(n_tests), norms)

    def test_logs_at_mean_adaptive_gradient_descent_sphere(self):
        """Logs of two points at their mean are opposite ('adaptive')."""
        n_tests = 10
        estimator = FrechetMean(metric=self.sphere.metric, method='adaptive')

        norms = self._norms_of_summed_logs(estimator, n_tests)

        self.assertAllClose(gs.zeros(n_tests), norms)

    def test_estimate_shape_default_gradient_descent_sphere(self):
        """The 'default' estimate of two 5-vectors has shape (5,)."""
        n_coords = 5
        first = gs.array([1., 0., 0., 0., 0.])
        second = gs.array([0., 1., 0., 0., 0.])
        samples = gs.array([first, second])

        estimator = FrechetMean(
            metric=self.sphere.metric, method='default', verbose=True)
        estimator.fit(samples)

        self.assertAllClose(gs.shape(estimator.estimate_), (n_coords,))

    def test_estimate_shape_adaptive_gradient_descent_sphere(self):
        """The 'adaptive' estimate of two 5-vectors has shape (5,)."""
        n_coords = 5
        first = gs.array([1., 0., 0., 0., 0.])
        second = gs.array([0., 1., 0., 0., 0.])
        samples = gs.array([first, second])

        estimator = FrechetMean(metric=self.sphere.metric, method='adaptive')
        estimator.fit(samples)

        self.assertAllClose(gs.shape(estimator.estimate_), (n_coords,))

    def test_estimate_and_belongs_default_gradient_descent_sphere(self):
        """The 'default' estimate must belong to the sphere."""
        first = gs.array([1., 0., 0., 0., 0.])
        second = gs.array([0., 1., 0., 0., 0.])
        samples = gs.array([first, second])

        estimator = FrechetMean(metric=self.sphere.metric, method='default')
        estimator.fit(samples)

        belongs = self.sphere.belongs(estimator.estimate_)
        self.assertAllClose(belongs, True)

    def test_estimate_default_gradient_descent_so3(self):
        """Logs of two rotation vectors at their mean sum to zero."""
        samples = self.so3.random_uniform(2)

        estimator = FrechetMean(
            metric=self.so3.bi_invariant_metric, method='default', lr=1.)
        estimator.fit(samples)

        logs = self.so3.bi_invariant_metric.log(samples, estimator.estimate_)
        summed = gs.sum(logs, axis=0)

        self.assertAllClose(summed, gs.zeros_like(samples[0]))

    def test_estimate_and_belongs_default_gradient_descent_so3(self):
        """The mean of ten rotation vectors must belong to so3."""
        samples = self.so3.random_uniform(10)

        estimator = FrechetMean(
            metric=self.so3.bi_invariant_metric, method='default')
        estimator.fit(samples)

        belongs = self.so3.belongs(estimator.estimate_)
        self.assertAllClose(belongs, True)

    @geomstats.tests.np_and_tf_only
    def test_estimate_default_gradient_descent_so_matrix(self):
        """Logs of two rotation matrices at their mean sum to zero."""
        samples = self.so_matrix.random_uniform(2)
        estimator = FrechetMean(
            metric=self.so_matrix.bi_invariant_metric, method='default',
            lr=1.)
        estimator.fit(samples)

        logs = self.so_matrix.bi_invariant_metric.log(
            samples, estimator.estimate_)
        summed = gs.sum(logs, axis=0)

        self.assertAllClose(summed, gs.zeros_like(samples[0]), atol=1e-5)

    @geomstats.tests.np_and_tf_only
    def test_estimate_and_belongs_default_gradient_descent_so_matrix(self):
        """The 'default' mean of ten rotation matrices is a rotation."""
        samples = self.so_matrix.random_uniform(10)

        estimator = FrechetMean(
            metric=self.so_matrix.bi_invariant_metric, method='default')
        estimator.fit(samples)

        belongs = self.so_matrix.belongs(estimator.estimate_)
        self.assertAllClose(belongs, True)

    @geomstats.tests.np_and_tf_only
    def test_estimate_and_belongs_adaptive_gradient_descent_so_matrix(self):
        """The 'adaptive' mean of ten rotation matrices is a rotation."""
        samples = self.so_matrix.random_uniform(10)

        estimator = FrechetMean(
            metric=self.so_matrix.bi_invariant_metric, method='adaptive',
            verbose=True, lr=.5)
        estimator.fit(samples)

        self.assertTrue(self.so_matrix.belongs(estimator.estimate_))

    @geomstats.tests.np_and_tf_only
    def test_estimate_and_coincide_default_so_vec_and_mat(self):
        """Matrix and rotation-vector means must describe one rotation."""
        matrices = self.so_matrix.random_uniform(3)

        matrix_estimator = FrechetMean(
            metric=self.so_matrix.bi_invariant_metric, method='default')
        matrix_estimator.fit(matrices)
        expected = matrix_estimator.estimate_

        vector_estimator = FrechetMean(
            metric=self.so3.bi_invariant_metric, method='default')
        vectors = self.so3.rotation_vector_from_matrix(matrices)
        vector_estimator.fit(vectors)
        # Convert the vector mean back to a matrix before comparing.
        result = self.so3.matrix_from_rotation_vector(
            vector_estimator.estimate_)

        self.assertAllClose(result, expected)

    def test_estimate_and_belongs_adaptive_gradient_descent_sphere(self):
        """The 'adaptive' estimate must belong to the sphere."""
        first = gs.array([1., 0., 0., 0., 0.])
        second = gs.array([0., 1., 0., 0., 0.])
        samples = gs.array([first, second])

        estimator = FrechetMean(metric=self.sphere.metric, method='adaptive')
        estimator.fit(samples)

        belongs = self.sphere.belongs(estimator.estimate_)
        self.assertAllClose(belongs, True)

    def test_variance_sphere(self):
        """Variance of identical sphere points at that point is zero."""
        pole = gs.array([0., 0., 0., 0., 1.])
        samples = gs.array([pole, pole])

        result = variance(
            samples, base_point=pole, metric=self.sphere.metric)

        self.assertAllClose(gs.array(0.), result)

    def test_estimate_default_gradient_descent_sphere(self):
        """The 'default' mean of identical sphere points is that point."""
        pole = gs.array([0., 0., 0., 0., 1.])
        samples = gs.array([pole, pole])

        estimator = FrechetMean(metric=self.sphere.metric, method='default')
        estimator.fit(X=samples)

        self.assertAllClose(pole, estimator.estimate_)

    def test_estimate_adaptive_gradient_descent_sphere(self):
        """The 'adaptive' mean of identical sphere points is that point."""
        pole = gs.array([0., 0., 0., 0., 1.])
        samples = gs.array([pole, pole])

        estimator = FrechetMean(metric=self.sphere.metric, method='adaptive')
        estimator.fit(X=samples)

        self.assertAllClose(pole, estimator.estimate_)

    def test_estimate_spd(self):
        """The mean of two identical SPD matrices is that matrix."""
        matrix = SPDMatrices(3).random_point()
        samples = gs.array([matrix, matrix])

        estimator = FrechetMean(
            metric=SPDMetricAffine(3), point_type='matrix')
        estimator.fit(X=samples)

        self.assertAllClose(matrix, estimator.estimate_)

    def test_variance_hyperbolic(self):
        """Variance of identical hyperboloid points at that point is zero."""
        sample = gs.array([2., 1., 1., 1.])
        samples = gs.array([sample, sample])

        result = variance(
            samples, base_point=sample, metric=self.hyperbolic.metric)

        self.assertAllClose(result, gs.array(0.))

    def test_estimate_hyperbolic(self):
        """The mean of identical hyperboloid points is that point."""
        sample = gs.array([2., 1., 1., 1.])
        samples = gs.array([sample, sample])

        estimator = FrechetMean(metric=self.hyperbolic.metric)
        estimator.fit(X=samples)

        self.assertAllClose(estimator.estimate_, sample)

    def test_estimate_and_belongs_hyperbolic(self):
        """The mean of three random points belongs to the hyperboloid."""
        # Three separate draws, in this order, as the RNG is seeded.
        first = self.hyperbolic.random_point()
        second = self.hyperbolic.random_point()
        third = self.hyperbolic.random_point()
        samples = gs.stack([first, second, third], axis=0)

        estimator = FrechetMean(metric=self.hyperbolic.metric)
        estimator.fit(X=samples)

        belongs = self.hyperbolic.belongs(estimator.estimate_)
        self.assertAllClose(belongs, True)

    def test_mean_euclidean_shape(self):
        """The Euclidean mean of 2-vectors has shape (2,)."""
        n_coords = 2
        sample = gs.array([1., 4.])

        estimator = FrechetMean(metric=self.euclidean.metric)
        estimator.fit([sample, sample, sample])

        self.assertAllClose(gs.shape(estimator.estimate_), (n_coords,))

    def test_mean_euclidean(self):
        """Euclidean mean of identical points, then a weighted mean."""
        sample = gs.array([1., 4.])

        estimator = FrechetMean(metric=self.euclidean.metric)
        estimator.fit([sample, sample, sample])

        self.assertAllClose(estimator.estimate_, sample)

        # Weighted case: the expected value is the weighted average.
        samples = gs.array([
            [1., 2.],
            [2., 3.],
            [3., 4.],
            [4., 5.]])
        weights = [1., 2., 1., 2.]

        estimator = FrechetMean(metric=self.euclidean.metric)
        estimator.fit(samples, weights=weights)

        expected = gs.array([16. / 6., 22. / 6.])
        self.assertAllClose(estimator.estimate_, expected)

    def test_variance_euclidean(self):
        """Weighted variance at the origin is the mean squared norm."""
        samples = gs.array([
            [1., 2.],
            [2., 3.],
            [3., 4.],
            [4., 5.]])
        weights = gs.array([1., 2., 1., 2.])
        origin = gs.zeros(2)

        result = variance(
            samples, weights=weights, base_point=origin,
            metric=self.euclidean.metric)

        # Weighted average of the squared norms 5, 13, 25 and 41.
        expected = gs.array((1 * 5. + 2 * 13. + 1 * 25. + 2 * 41.) / 6.)
        self.assertAllClose(result, expected)

    def test_mean_matrices_shape(self):
        """The mean of 2x2 matrices has shape (2, 2)."""
        m, n = (2, 2)
        matrix = gs.array([
            [1., 4.],
            [2., 3.]])

        estimator = FrechetMean(
            metric=MatricesMetric(m, n), point_type='matrix')
        estimator.fit([matrix, matrix, matrix])

        self.assertAllClose(gs.shape(estimator.estimate_), (m, n))

    def test_mean_matrices(self):
        """The mean of identical 2x2 matrices is that matrix."""
        m, n = (2, 2)
        matrix = gs.array([
            [1., 4.],
            [2., 3.]])

        estimator = FrechetMean(
            metric=MatricesMetric(m, n), point_type='matrix')
        estimator.fit([matrix, matrix, matrix])

        self.assertAllClose(estimator.estimate_, matrix)

    def test_mean_minkowski_shape(self):
        """The Minkowski mean of 2-vectors has shape (2,)."""
        n_coords = 2
        sample = gs.array([2., -math.sqrt(3)])
        samples = [sample, sample, sample]

        estimator = FrechetMean(metric=self.minkowski.metric)
        estimator.fit(samples)

        self.assertAllClose(gs.shape(estimator.estimate_), (n_coords,))

    def test_mean_minkowski(self):
        """Minkowski mean of identical points, then a weighted mean."""
        sample = gs.array([2., -math.sqrt(3)])
        samples = [sample, sample, sample]

        estimator = FrechetMean(metric=self.minkowski.metric)
        estimator.fit(samples)

        self.assertAllClose(estimator.estimate_, sample)

        # Weighted case: only membership of the estimate is checked.
        samples = gs.array([
            [1., 0.],
            [2., math.sqrt(3)],
            [3., math.sqrt(8)],
            [4., math.sqrt(24)]])
        weights = gs.array([1., 2., 1., 2.])

        estimator = FrechetMean(metric=self.minkowski.metric)
        estimator.fit(samples, weights=weights)

        belongs = self.minkowski.belongs(estimator.estimate_)
        self.assertAllClose(belongs, gs.array(True))

    def test_variance_minkowski(self):
        """The weighted Minkowski variance of the samples is nonzero.

        Only ``var != 0`` is asserted; no specific value is checked.
        """
        samples = gs.array([
            [1., 0.],
            [2., math.sqrt(3)],
            [3., math.sqrt(8)],
            [4., math.sqrt(24)]])
        weights = gs.array([1., 2., 1., 2.])
        base_point = gs.array([-1., 0.])

        var = variance(
            samples, weights=weights, base_point=base_point,
            metric=self.minkowski.metric)

        is_nonzero = var != 0
        self.assertAllClose(is_nonzero, True)

    def test_one_point(self):
        """Fitting on a single point returns it, for two method names."""
        pole = gs.array([0., 0., 0., 0., 1.])

        for method in ('default', 'frechet-poincare-ball'):
            estimator = FrechetMean(metric=self.sphere.metric, method=method)
            estimator.fit(X=pole)

            self.assertAllClose(pole, estimator.estimate_)
コード例 #5
0
ファイル: test_hyperbolic.py プロジェクト: xpennec/geomstats
class TestHyperbolic(geomstats.tests.TestCase):
    def setup_method(self):
        """Seed the backend RNG and create the fixtures used by the tests."""
        gs.random.seed(1234)
        self.n_samples = 10
        self.dimension = 3

        hyperboloid = Hyperboloid(dim=self.dimension)
        self.space = hyperboloid
        self.metric = hyperboloid.metric
        self.ball_manifold = PoincareBall(dim=2)

    def test_belongs_intrinsic(self):
        """Random intrinsic coordinates must all pass ``belongs``."""
        self.space.coords_type = "intrinsic"
        coords = gs.random.rand(self.n_samples, self.dimension)

        all_belong = gs.all(self.space.belongs(coords))

        self.assertTrue(all_belong)

    def test_regularize_intrinsic(self):
        """Regularized intrinsic coordinates must pass extrinsic ``belongs``.

        ``coords_type`` is switched to "intrinsic" for ``regularize`` and
        back to "extrinsic" before the membership check.
        """
        self.space.coords_type = "intrinsic"
        coords = gs.random.rand(self.n_samples, self.dimension)
        regularized = self.space.regularize(coords)

        self.space.coords_type = "extrinsic"
        all_belong = gs.all(self.space.belongs(regularized))

        self.assertTrue(all_belong)

    def test_regularize_zero_norm(self):
        """Check the errors raised for the vector [-1, 1, 0, 0].

        ``regularize`` is expected to raise ValueError and
        ``extrinsic_to_intrinsic_coords`` to raise NameError.
        """
        point = gs.array([-1.0, 1.0, 0.0, 0.0])

        expectations = (
            (ValueError, self.space.regularize),
            (NameError, self.space.extrinsic_to_intrinsic_coords),
        )
        for error_type, func in expectations:
            with pytest.raises(error_type):
                func(point)

    def test_random_uniform_and_belongs(self):
        """A point drawn with ``random_point`` must belong to the space."""
        sample = self.space.random_point()

        belongs = self.space.belongs(sample)

        self.assertAllClose(belongs, True)

    def test_random_uniform(self):
        """A single random point has ``dimension + 1`` coordinates."""
        sample = self.space.random_point()
        expected_shape = (self.dimension + 1,)

        self.assertAllClose(gs.shape(sample), expected_shape)

    def test_projection_to_tangent_space(self):
        """Check that ``to_tangent`` gives a tangent vector and is idempotent.

        The projected vector must have a zero ``self.metric.inner_product``
        with the base point, and projecting the already projected vector a
        second time must leave it unchanged.
        """
        base_point = gs.array([1.0, 0.0, 0.0, 0.0])
        self.assertTrue(self.space.belongs(base_point))

        tangent_vec = self.space.to_tangent(
            vector=gs.array([1.0, 2.0, 1.0, 3.0]), base_point=base_point
        )
        result = self.metric.inner_product(tangent_vec, base_point)
        self.assertAllClose(result, 0.0)

        # Re-project the projected vector (not the raw one): repeating the
        # identical call would only compare the function with itself.
        result = self.space.to_tangent(
            vector=tangent_vec, base_point=base_point
        )
        self.assertAllClose(result, tangent_vec)

    def test_intrinsic_and_extrinsic_coords(self):
        """Check both round trips between intrinsic and extrinsic coords.

        First intrinsic -> extrinsic -> intrinsic with the dedicated
        conversion methods, then extrinsic -> intrinsic -> extrinsic with
        ``to_coordinates`` / ``from_coordinates``.
        """
        intrinsic = gs.ones(self.dimension)
        extrinsic = self.space.intrinsic_to_extrinsic_coords(intrinsic)
        back_to_intrinsic = self.space.extrinsic_to_intrinsic_coords(extrinsic)
        self.assertAllClose(back_to_intrinsic, intrinsic)

        extrinsic = gs.array([2.0, 1.0, 1.0, 1.0])
        intrinsic = self.space.to_coordinates(extrinsic, "intrinsic")
        back_to_extrinsic = self.space.from_coordinates(intrinsic, "intrinsic")
        self.assertAllClose(back_to_extrinsic, extrinsic)

    def test_intrinsic_and_extrinsic_coords_vectorization(self):
        """Check the coordinate round trips on batches of points.

        A batch of intrinsic coordinates goes through ``from_coordinates``
        then ``to_coordinates``; a batch of extrinsic coordinates goes the
        other way. Both must come back unchanged.
        """
        intrinsic_batch = gs.array(
            [
                [0.1, 0.0, 0.0, 0.1, 0.0, 0.0],
                [0.1, 0.1, 0.1, 0.4, 0.1, 0.0],
                [0.1, 0.3, 0.0, 0.1, 0.0, 0.0],
                [-0.1, 0.1, -0.4, 0.1, -0.01, 0.0],
                [0.0, 0.0, 0.1, 0.1, -0.08, -0.1],
                [0.1, 0.1, 0.1, 0.1, 0.0, -0.5],
            ]
        )
        as_extrinsic = self.space.from_coordinates(intrinsic_batch, "intrinsic")
        recovered = self.space.to_coordinates(as_extrinsic, "intrinsic")
        self.assertAllClose(recovered, helper.to_vector(intrinsic_batch))

        extrinsic_batch = gs.array(
            [
                [2.0, 1.0, 1.0, 1.0],
                [4.0, 1.0, 3.0, math.sqrt(5.0)],
                [3.0, 2.0, 0.0, 2.0],
            ]
        )
        as_intrinsic = self.space.to_coordinates(extrinsic_batch, "intrinsic")
        recovered = self.space.from_coordinates(as_intrinsic, "intrinsic")
        self.assertAllClose(recovered, helper.to_vector(extrinsic_batch))

    def test_log_and_exp_general_case(self):
        """Check that exp undoes log for one fixed pair of points."""
        base_point = gs.array([4.0, 1.0, 3.0, math.sqrt(5.0)])
        target = gs.array([2.0, 1.0, 1.0, 1.0])

        # Log of the target at the base point, then exp back.
        tangent_vec = self.metric.log(point=target, base_point=base_point)
        recovered = self.metric.exp(
            tangent_vec=tangent_vec, base_point=base_point)

        self.assertAllClose(recovered, target)

    def test_log_and_exp_general_case_general_dim(self):
        """Check that exp undoes log on a 5-dimensional hyperboloid.

        The round trip is verified for a single pair of random points and
        then for batches made of ``self.n_samples`` copies of that pair,
        where the output shapes are checked as well.
        """
        dim = 5
        n_samples = self.n_samples
        batch_shape = (n_samples, dim + 1)

        space = Hyperboloid(dim=dim)
        metric = space.metric

        def tile(array):
            # Repeat one array n_samples times along a new leading axis.
            return gs.stack([array] * n_samples, axis=0)

        # Draw order matters for the seeded RNG: base point, then point.
        single_base = space.random_point()
        single_point = space.random_point()
        single_point = gs.cast(single_point, gs.float64)
        single_base = gs.cast(single_base, gs.float64)
        single_log = metric.log(point=single_point, base_point=single_base)

        # Single log followed by single exp.
        round_trip = metric.exp(tangent_vec=single_log, base_point=single_base)
        self.assertAllClose(round_trip, single_point)

        # Batched log must reproduce the single log on every row.
        batch_base = tile(single_base)
        batch_point = tile(single_point)
        batch_log_expected = tile(single_log)

        batch_log = metric.log(point=batch_point, base_point=batch_base)
        self.assertAllClose(gs.shape(batch_log), batch_shape)
        self.assertAllClose(batch_log, batch_log_expected)

        # Exp of the batched log gives the batched points back.
        batch_exp = metric.exp(tangent_vec=batch_log, base_point=batch_base)
        self.assertAllClose(gs.shape(batch_exp), batch_shape)
        self.assertAllClose(batch_exp, batch_point)

        # Batched exp applied to copies of the single log.
        tiled_log = tile(single_log)
        batch_exp = metric.exp(tangent_vec=tiled_log, base_point=batch_base)
        self.assertAllClose(gs.shape(batch_exp), batch_shape)
        self.assertAllClose(batch_exp, batch_point)

    def test_exp_and_belongs(self):
        """The exp of a tangent vector must stay on the 2D hyperboloid."""
        plane = Hyperboloid(dim=2)
        plane_metric = plane.metric

        base_point = gs.array([1.0, 0.0, 0.0])
        self.assertTrue(plane.belongs(base_point))

        raw_vector = gs.array([1.0, 2.0, 1.0])
        tangent_vec = plane.to_tangent(vector=raw_vector, base_point=base_point)
        end_point = plane_metric.exp(
            tangent_vec=tangent_vec, base_point=base_point)

        self.assertTrue(plane.belongs(end_point))

    def test_exp_small_vec(self):
        """The exp of a tangent vector scaled by 1e-9 stays on the space."""
        plane = Hyperboloid(dim=2)
        plane_metric = plane.metric

        base_point = plane.regularize(gs.array([1.0, 0.0, 0.0]))
        self.assertTrue(plane.belongs(base_point))

        raw_vector = gs.array([1.0, 2.0, 1.0])
        projected = plane.to_tangent(vector=raw_vector, base_point=base_point)
        tiny_tangent_vec = 1e-9 * projected
        end_point = plane_metric.exp(
            tangent_vec=tiny_tangent_vec, base_point=base_point)

        self.assertTrue(plane.belongs(end_point))

    def test_exp_vectorization(self):
        """Check output shapes and values of ``exp`` on batched inputs.

        Four combinations are covered: one vector / one base point, several
        vectors / one base point, one vector / several base points, and
        several of each. Batched results are compared with row-by-row calls.
        """
        n_samples = 3
        dim = self.dimension + 1
        batch_shape = (n_samples, dim)

        one_vec = gs.array([2.0, 1.0, 1.0, 1.0])
        one_base_point = gs.array([4.0, 3.0, 1.0, math.sqrt(5)])
        n_vecs = gs.array(
            [
                [2.0, 1.0, 1.0, 1.0],
                [4.0, 1.0, 3.0, math.sqrt(5.0)],
                [3.0, 2.0, 0.0, 2.0],
            ]
        )
        n_base_points = gs.array(
            [
                [2.0, 0.0, 1.0, math.sqrt(2)],
                [5.0, math.sqrt(8), math.sqrt(8), math.sqrt(8)],
                [1.0, 0.0, 0.0, 0.0],
            ]
        )

        def as_batch(rows):
            # Assemble per-row results the same way for every comparison.
            stacked = gs.stack(rows, axis=0)
            return helper.to_vector(gs.array(stacked))

        # One vector, one base point: only the shape is checked.
        tangent = self.space.to_tangent(one_vec, base_point=one_base_point)
        result = self.metric.exp(tangent, one_base_point)
        self.assertAllClose(gs.shape(result), (dim,))

        # Several vectors, one base point.
        tangents = self.space.to_tangent(n_vecs, base_point=one_base_point)
        result = self.metric.exp(tangents, one_base_point)
        self.assertAllClose(gs.shape(result), batch_shape)

        expected = as_batch(
            [self.metric.exp(tangents[i], one_base_point)
             for i in range(n_samples)]
        )
        self.assertAllClose(result, expected, atol=1e-2)

        # One vector, several base points.
        tangents = self.space.to_tangent(one_vec, base_point=n_base_points)
        result = self.metric.exp(tangents, n_base_points)
        self.assertAllClose(gs.shape(result), batch_shape)

        expected = as_batch(
            [self.metric.exp(tangents[i], n_base_points[i])
             for i in range(n_samples)]
        )
        self.assertAllClose(result, expected)

        # Several vectors, several base points.
        tangents = self.space.to_tangent(n_vecs, base_point=n_base_points)
        result = self.metric.exp(tangents, n_base_points)
        self.assertAllClose(gs.shape(result), batch_shape)

        expected = as_batch(
            [self.metric.exp(tangents[i], n_base_points[i])
             for i in range(n_samples)]
        )
        self.assertAllClose(result, expected)

    def test_log_vectorization(self):
        """Check the output shape of ``log`` for single and batched inputs."""
        n_samples = 3
        dim = self.dimension + 1

        one_point = gs.array([2.0, 1.0, 1.0, 1.0])
        one_base_point = gs.array([4.0, 3.0, 1.0, math.sqrt(5)])
        n_points = gs.array(
            [[2.0, 1.0, 1.0, 1.0], [4.0, 1.0, 3.0, math.sqrt(5)], [3.0, 2.0, 0.0, 2.0]]
        )
        n_base_points = gs.array(
            [
                [2.0, 0.0, 1.0, math.sqrt(2)],
                [5.0, math.sqrt(8), math.sqrt(8), math.sqrt(8)],
                [1.0, 0.0, 0.0, 0.0],
            ]
        )

        # (point, base point, expected shape of the log)
        cases = (
            (one_point, one_base_point, (dim,)),
            (n_points, one_base_point, (n_samples, dim)),
            (one_point, n_base_points, (n_samples, dim)),
            (n_points, n_base_points, (n_samples, dim)),
        )
        for point, base_point, expected_shape in cases:
            log = self.metric.log(point, base_point)
            self.assertAllClose(gs.shape(log), expected_shape)

    def test_inner_product(self):
        """Compare the metric's inner product with the Minkowski one.

        Two vectors are projected to the tangent space at a fixed base
        point; both inner products must agree on them.
        """
        minkowski_space = Minkowski(self.dimension + 1)
        base_point = gs.array([1.16563816, 0.36381045, -0.47000603, 0.07381469])

        raw_a = gs.array([10.0, 200.0, 1.0, 1.0])
        raw_b = gs.array([11.0, 20.0, -21.0, 0.0])
        tangent_vec_a = self.space.to_tangent(vector=raw_a, base_point=base_point)
        tangent_vec_b = self.space.to_tangent(vector=raw_b, base_point=base_point)

        hyperbolic_value = self.metric.inner_product(
            tangent_vec_a, tangent_vec_b, base_point
        )
        minkowski_value = minkowski_space.metric.inner_product(
            tangent_vec_a, tangent_vec_b, base_point
        )

        self.assertAllClose(hyperbolic_value, minkowski_value)

    def test_squared_norm_and_squared_dist(self):
        """
        Test that the squared norm of the logarithm of a point equals
        the squared distance between that point and the base point.
        """
        sqrt_5 = math.sqrt(5)
        point_a = gs.array([2.0, 1.0, 1.0, 1.0])
        point_b = gs.array([4.0, 1.0, 3.0, sqrt_5])

        tangent_vec = self.metric.log(point=point_a, base_point=point_b)
        sq_norm_of_log = self.metric.squared_norm(vector=tangent_vec)
        sq_dist = self.metric.squared_dist(point_a, point_b)

        self.assertAllClose(sq_norm_of_log, sq_dist)

    def test_norm_and_dist(self):
        """
        Test that the norm of the logarithm of a point equals
        the distance between that point and the base point.
        """
        sqrt_5 = math.sqrt(5)
        point_a = gs.array([2.0, 1.0, 1.0, 1.0])
        point_b = gs.array([4.0, 1.0, 3.0, sqrt_5])

        tangent_vec = self.metric.log(point=point_a, base_point=point_b)
        norm_of_log = self.metric.norm(vector=tangent_vec)
        distance = self.metric.dist(point_a, point_b)

        self.assertAllClose(norm_of_log, distance)

    def test_log_and_exp_edge_case(self):
        """
        Test that the Riemannian exponential undoes the Riemannian
        logarithm when the point is extremely close to the base point.
        """
        # Log then Exp. The point differs from the base point by a
        # perturbation scaled by 1e-12 in intrinsic coordinates.
        intrinsic_base = gs.array([1.0, 2.0, 3.0])
        intrinsic_point = intrinsic_base + 1e-12 * gs.array([-1.0, -2.0, 1.0])
        base_point = self.space.from_coordinates(intrinsic_base, "intrinsic")
        point = self.space.from_coordinates(intrinsic_point, "intrinsic")

        tangent_vec = self.metric.log(point=point, base_point=base_point)
        recovered = self.metric.exp(tangent_vec=tangent_vec, base_point=base_point)

        self.assertAllClose(recovered, point)

    def test_exp_and_log_and_projection_to_tangent_space_general_case(self):
        """
        Test that the Riemannian logarithm undoes the Riemannian
        exponential for a tangent vector obtained by projection.
        """
        # Exp then Log, general case.
        base_point = gs.array([4.0, 1.0, 3.0, math.sqrt(5)])
        tangent_vec = self.space.to_tangent(
            vector=gs.array([2.0, 1.0, 1.0, 1.0]), base_point=base_point
        )

        end_point = self.metric.exp(tangent_vec=tangent_vec, base_point=base_point)
        recovered = self.metric.log(point=end_point, base_point=base_point)

        self.assertAllClose(recovered, tangent_vec)

    def test_dist(self):
        """
        Test that the distance between a point and itself is 0.
        """
        point = gs.array([4.0, 1.0, 3.0, math.sqrt(5)])
        self_distance = self.metric.dist(point, point)
        self.assertAllClose(self_distance, 0)

    def test_exp_and_dist_and_projection_to_tangent_space(self):
        """
        Test that the distance between a base point and the exponential
        of a tangent vector at that point is the norm of the tangent vector.
        """
        base_point = gs.array([4.0, 1.0, 3.0, math.sqrt(5)])
        vector = gs.array([0.001, 0.0, -0.00001, -0.00003])
        tangent_vec = self.space.to_tangent(vector=vector, base_point=base_point)
        exp = self.metric.exp(tangent_vec=tangent_vec, base_point=base_point)

        result = self.metric.dist(base_point, exp)
        # A distance is a length: compare it with the norm of the tangent
        # vector, i.e. the square root of its squared norm in the embedding
        # metric, not with the squared norm itself.
        sq_norm = self.metric.embedding_metric.squared_norm(tangent_vec)
        expected = gs.sqrt(sq_norm)
        # Both values are of order 1e-3 for this input, so the tolerance has
        # to be smaller than that for the comparison to be meaningful.
        self.assertAllClose(result, expected, atol=1e-3)

    def test_geodesic_and_belongs(self):
        """
        Test that points sampled along a geodesic belong to the space.
        """
        n_geodesic_points = 100
        start = gs.array([4.0, 1.0, 3.0, math.sqrt(5)])
        velocity = self.space.to_tangent(
            vector=gs.array([1.0, 0.0, 0.0, 0.0]), base_point=start
        )

        geodesic = self.metric.geodesic(
            initial_point=start, initial_tangent_vec=velocity
        )
        times = gs.linspace(start=0.0, stop=1.0, num=n_geodesic_points)
        sampled_points = geodesic(times)

        self.assertTrue(gs.all(self.space.belongs(sampled_points)))

    def test_geodesic_and_belongs_large_initial_velocity(self):
        """
        Test that points sampled along a geodesic with a larger initial
        velocity belong to the space, up to a relaxed tolerance.
        """
        n_geodesic_points = 100
        start = gs.array([4.0, 1.0, 3.0, math.sqrt(5)])
        velocity = self.space.to_tangent(
            vector=gs.array([2.0, 0.0, 0.0, 0.0]), base_point=start
        )

        geodesic = self.metric.geodesic(
            initial_point=start, initial_tangent_vec=velocity
        )
        times = gs.linspace(start=0.0, stop=1.0, num=n_geodesic_points)
        sampled_points = geodesic(times)

        # Membership is checked with a tolerance 1e4 times the default one.
        on_manifold = self.space.belongs(sampled_points, atol=gs.atol * 1e4)
        self.assertTrue(gs.all(on_manifold))

    def test_exp_and_log_and_projection_to_tangent_space_edge_case(self):
        """
        Test that the Riemannian logarithm of the Riemannian exponential
        of a very small vector is the projection of that vector onto
        the tangent space.
        """
        # Exp then Log, edge case: the vector is scaled by 1e-10.
        base_point = gs.array([2.0, 1.0, 1.0, 1.0])
        tiny_vector = 1e-10 * gs.array([0.06, -51.0, 6.0, 5.0])

        # The raw vector (not its projection) is passed to exp.
        end_point = self.metric.exp(tangent_vec=tiny_vector, base_point=base_point)
        recovered = self.metric.log(point=end_point, base_point=base_point)
        projected = self.space.to_tangent(vector=tiny_vector, base_point=base_point)

        self.assertAllClose(recovered, projected)

    def test_scaled_inner_product(self):
        """
        Test that the inner product of the metric with scale 2 is
        scale**2 times the inner product of the default metric.
        """
        scale = 2
        base_point = self.space.from_coordinates(
            gs.array([1.0, 1.0, 1.0]), "intrinsic"
        )
        tangent_vec_a = self.space.to_tangent(
            gs.array([1.0, 2.0, 3.0, 4.0]), base_point
        )
        tangent_vec_b = self.space.to_tangent(
            gs.array([5.0, 6.0, 7.0, 8.0]), base_point
        )

        unscaled_metric = Hyperboloid(dim=self.dimension).metric
        scaled_metric = Hyperboloid(dim=self.dimension, scale=scale).metric

        unscaled_value = unscaled_metric.inner_product(
            tangent_vec_a, tangent_vec_b, base_point
        )
        scaled_value = scaled_metric.inner_product(
            tangent_vec_a, tangent_vec_b, base_point
        )

        self.assertAllClose(scaled_value, scale**2 * unscaled_value)

    def test_scaled_squared_norm(self):
        """
        Test that the squared norm of the metric with scale 2 is
        scale**2 times the squared norm of the default metric.
        """
        scale = 2
        base_point = self.space.from_coordinates(
            gs.array([1.0, 1.0, 1.0]), "intrinsic"
        )
        tangent_vec = self.space.to_tangent(
            gs.array([1.0, 2.0, 3.0, 4.0]), base_point
        )

        unscaled_metric = Hyperboloid(dim=self.dimension).metric
        scaled_metric = Hyperboloid(dim=self.dimension, scale=scale).metric

        unscaled_value = unscaled_metric.squared_norm(tangent_vec, base_point)
        scaled_value = scaled_metric.squared_norm(tangent_vec, base_point)

        self.assertAllClose(scaled_value, scale**2 * unscaled_value)

    def test_scaled_distance(self):
        """
        Test that the distance of the metric with scale 2 is
        scale times the distance of the default metric.
        """
        scale = 2
        point_a = self.space.from_coordinates(gs.array([1.0, 2.0, 3.0]), "intrinsic")
        point_b = self.space.from_coordinates(gs.array([4.0, 5.0, 6.0]), "intrinsic")

        scaled_metric = Hyperboloid(dim=self.dimension, scale=scale).metric
        unscaled_distance = self.space.metric.dist(point_a, point_b)
        scaled_distance = scaled_metric.dist(point_a, point_b)

        self.assertAllClose(scaled_distance, scale * unscaled_distance)

    def test_is_tangent(self):
        """
        Test that the logarithm of a point is reported as tangent
        at the base point.
        """
        base_point = gs.array([4.0, 1.0, 3.0, math.sqrt(5.0)])
        target = gs.array([2.0, 1.0, 1.0, 1.0])

        tangent_vec = self.metric.log(point=target, base_point=base_point)

        self.assertTrue(self.space.is_tangent(tangent_vec, base_point))

    @geomstats.tests.np_autograd_and_tf_only
    def test_parallel_transport_vectorization(self):
        """
        Run the shared parallel transport helper on batches of 4 vectors
        and require every check it returns to be true.
        """
        hyperbolic_space = self.space
        batch_shape = (4, hyperbolic_space.dim + 1)

        checks = helper.test_parallel_transport(
            hyperbolic_space, hyperbolic_space.metric, batch_shape
        )
        for check in checks:
            self.assertTrue(check)

    def test_projection_and_belongs(self):
        """
        Test that projected points belong to the space, first through the
        shared helper and then for one explicit point.
        """
        batch_shape = (self.n_samples, self.dimension + 1)
        checks = helper.test_projection_and_belongs(
            self.space, batch_shape, atol=gs.atol * 100
        )
        for check in checks:
            self.assertTrue(check)

        # Explicit case: project a single hand-picked vector.
        raw_point = gs.array([0.0, 1.0, 0.0, 0.0])
        on_manifold = self.space.belongs(self.space.projection(raw_point))
        self.assertTrue(on_manifold)
コード例 #6
0
class TestFrechetMean(geomstats.tests.TestCase):
    _multiprocess_can_split_ = True

    def setup_method(self):
        """Seed the random generator and build the spaces and metrics used by the tests."""
        # Fixed seed so that the random points drawn by the tests are reproducible.
        gs.random.seed(123)
        self.sphere = Hypersphere(dim=4)
        self.hyperbolic = Hyperboloid(dim=3)
        self.euclidean = Euclidean(dim=2)
        self.minkowski = Minkowski(dim=2)
        # Rotations of dimension 3, once as vectors and once with the default
        # point type (used by the "so_matrix" tests).
        self.so3 = SpecialOrthogonal(n=3, point_type="vector")
        self.so_matrix = SpecialOrthogonal(n=3)
        # R2 is defined outside this method; it is the ambient space of the curves.
        self.curves_2d = DiscreteCurves(R2)
        self.elastic_metric = ElasticMetric(a=1, b=1, ambient_manifold=R2)

    def test_logs_at_mean_curves_2d(self):
        """Test that the SRV transforms of the logs of two curves at their mean cancel."""
        n_tests = 10
        srv_metric = self.curves_2d.srv_metric
        estimator = FrechetMean(metric=srv_metric, init_step_size=1.0)

        residuals = []
        for _ in range(n_tests):
            # Draw two random curves and estimate their mean; the logs of
            # the two curves at the mean are then compared.
            curves = self.curves_2d.random_point(n_samples=2)
            estimator.fit(curves)

            # Expand and tile are added because of GitHub issue 1575
            tiled_mean = gs.tile(
                gs.expand_dims(estimator.estimate_, axis=0), (2, 1, 1)
            )

            logs = srv_metric.log(point=curves, base_point=tiled_mean)
            logs_srv = srv_metric.aux_differential_srv_transform(
                logs, curve=tiled_mean
            )
            # Note that the logs are NOT inverse, only the logs_srv are.
            residuals.append(gs.linalg.norm(logs_srv[1, :] + logs_srv[0, :]))

        self.assertAllClose(gs.zeros(n_tests), gs.stack(residuals), atol=1e-5)

    def test_logs_at_mean_default_gradient_descent_sphere(self):
        """Test that the logs of two sphere points at their mean cancel (default method)."""
        n_tests = 10
        estimator = FrechetMean(
            metric=self.sphere.metric, method="default", init_step_size=1.0
        )

        residuals = []
        for _ in range(n_tests):
            # Draw two random points and estimate their mean; the two logs
            # at the mean should be opposite, so their sum should vanish.
            pair = self.sphere.random_uniform(n_samples=2)
            estimator.fit(pair)

            logs = self.sphere.metric.log(point=pair, base_point=estimator.estimate_)
            residuals.append(gs.linalg.norm(logs[1, :] + logs[0, :]))

        self.assertAllClose(gs.zeros(n_tests), gs.stack(residuals))

    def test_logs_at_mean_adaptive_gradient_descent_sphere(self):
        """Test that the logs of two sphere points at their mean cancel (adaptive method)."""
        n_tests = 10
        estimator = FrechetMean(metric=self.sphere.metric, method="adaptive")

        residuals = []
        for _ in range(n_tests):
            # Draw two random points and estimate their mean; the two logs
            # at the mean should be opposite, so their sum should vanish.
            pair = self.sphere.random_uniform(n_samples=2)
            estimator.fit(pair)

            logs = self.sphere.metric.log(point=pair, base_point=estimator.estimate_)
            residuals.append(gs.linalg.norm(logs[1, :] + logs[0, :]))

        self.assertAllClose(gs.zeros(n_tests), gs.stack(residuals))

    def test_estimate_shape_default_gradient_descent_sphere(self):
        """Test the shape of the mean of two sphere points (default method, verbose)."""
        dim = 5
        points = gs.array(
            [
                gs.array([1.0, 0.0, 0.0, 0.0, 0.0]),
                gs.array([0.0, 1.0, 0.0, 0.0, 0.0]),
            ]
        )

        estimator = FrechetMean(
            metric=self.sphere.metric, method="default", verbose=True
        )
        estimator.fit(points)

        self.assertAllClose(gs.shape(estimator.estimate_), (dim,))

    def test_estimate_shape_adaptive_gradient_descent_sphere(self):
        """Test the shape of the mean of two sphere points (adaptive method)."""
        dim = 5
        points = gs.array(
            [
                gs.array([1.0, 0.0, 0.0, 0.0, 0.0]),
                gs.array([0.0, 1.0, 0.0, 0.0, 0.0]),
            ]
        )

        estimator = FrechetMean(metric=self.sphere.metric, method="adaptive")
        estimator.fit(points)

        self.assertAllClose(gs.shape(estimator.estimate_), (dim,))

    def test_estimate_shape_elastic_metric(self):
        """Test that the elastic-metric mean of two curves has the shape of one curve."""
        curves = self.curves_2d.random_point(n_samples=2)

        estimator = FrechetMean(metric=self.elastic_metric)
        estimator.fit(curves)

        single_curve_shape = curves.shape[1:]
        self.assertAllClose(gs.shape(estimator.estimate_), single_curve_shape)

    def test_estimate_shape_metric(self):
        """Test that the SRV-metric mean of two curves has the shape of one curve."""
        curves = self.curves_2d.random_point(n_samples=2)

        estimator = FrechetMean(metric=self.curves_2d.srv_metric)
        estimator.fit(curves)

        single_curve_shape = curves.shape[1:]
        self.assertAllClose(gs.shape(estimator.estimate_), single_curve_shape)

    def test_estimate_and_belongs_default_gradient_descent_sphere(self):
        """Test that the mean of two sphere points belongs to the sphere (default method)."""
        points = gs.array(
            [
                gs.array([1.0, 0.0, 0.0, 0.0, 0.0]),
                gs.array([0.0, 1.0, 0.0, 0.0, 0.0]),
            ]
        )

        estimator = FrechetMean(metric=self.sphere.metric, method="default")
        estimator.fit(points)

        on_sphere = self.sphere.belongs(estimator.estimate_)
        self.assertAllClose(on_sphere, True)

    def test_estimate_and_belongs_curves_2d(self):
        """Test that the SRV-metric mean of two curves belongs to the curve space."""
        curves = self.curves_2d.random_point(n_samples=2)

        estimator = FrechetMean(metric=self.curves_2d.srv_metric)
        estimator.fit(curves)

        is_curve = self.curves_2d.belongs(estimator.estimate_)
        self.assertAllClose(is_curve, True)

    def test_estimate_default_gradient_descent_so3(self):
        """Test that the logs of two rotation vectors at their mean sum to zero."""
        rotations = self.so3.random_uniform(2)

        estimator = FrechetMean(
            metric=self.so3.bi_invariant_metric,
            method="default",
            init_step_size=1.0,
        )
        estimator.fit(rotations)

        logs_at_mean = self.so3.bi_invariant_metric.log(
            rotations, estimator.estimate_
        )
        self.assertAllClose(gs.sum(logs_at_mean, axis=0), gs.zeros_like(rotations[0]))

    def test_estimate_and_belongs_default_gradient_descent_so3(self):
        """Test that the mean of 10 rotation vectors belongs to the rotation group."""
        rotations = self.so3.random_uniform(10)

        estimator = FrechetMean(
            metric=self.so3.bi_invariant_metric, method="default"
        )
        estimator.fit(rotations)

        is_rotation = self.so3.belongs(estimator.estimate_)
        self.assertAllClose(is_rotation, True)

    @geomstats.tests.np_autograd_and_tf_only
    def test_estimate_default_gradient_descent_so_matrix(self):
        """Test that the logs of two rotation matrices at their mean sum to zero."""
        rotations = self.so_matrix.random_uniform(2)

        estimator = FrechetMean(
            metric=self.so_matrix.bi_invariant_metric,
            method="default",
            init_step_size=1.0,
        )
        estimator.fit(rotations)

        logs_at_mean = self.so_matrix.bi_invariant_metric.log(
            rotations, estimator.estimate_
        )
        summed_logs = gs.sum(logs_at_mean, axis=0)

        self.assertAllClose(summed_logs, gs.zeros_like(rotations[0]), atol=1e-5)

    @geomstats.tests.np_autograd_and_tf_only
    def test_estimate_and_belongs_default_gradient_descent_so_matrix(self):
        """Test that the mean of 10 rotation matrices belongs to the group (default method)."""
        rotations = self.so_matrix.random_uniform(10)

        estimator = FrechetMean(
            metric=self.so_matrix.bi_invariant_metric, method="default"
        )
        estimator.fit(rotations)

        is_rotation = self.so_matrix.belongs(estimator.estimate_)
        self.assertAllClose(is_rotation, True)

    @geomstats.tests.np_autograd_and_tf_only
    def test_estimate_and_belongs_adaptive_gradient_descent_so_matrix(self):
        """Test that the mean of 10 rotation matrices belongs to the group (adaptive method)."""
        rotations = self.so_matrix.random_uniform(10)

        estimator = FrechetMean(
            metric=self.so_matrix.bi_invariant_metric,
            method="adaptive",
            init_step_size=0.5,
            verbose=True,
        )
        estimator.fit(rotations)

        self.assertTrue(self.so_matrix.belongs(estimator.estimate_))

    @geomstats.tests.np_autograd_and_tf_only
    def test_estimate_and_coincide_default_so_vec_and_mat(self):
        """Test that the matrix and rotation-vector means of the same rotations agree."""
        rotation_mats = self.so_matrix.random_uniform(3)

        # Mean computed directly on the matrices.
        matrix_estimator = FrechetMean(
            metric=self.so_matrix.bi_invariant_metric, method="default"
        )
        matrix_estimator.fit(rotation_mats)
        expected = matrix_estimator.estimate_

        # Mean computed on rotation vectors, then converted back to a matrix.
        vector_estimator = FrechetMean(
            metric=self.so3.bi_invariant_metric, method="default"
        )
        rotation_vecs = self.so3.rotation_vector_from_matrix(rotation_mats)
        vector_estimator.fit(rotation_vecs)
        result = self.so3.matrix_from_rotation_vector(vector_estimator.estimate_)

        self.assertAllClose(result, expected)

    def test_estimate_and_belongs_adaptive_gradient_descent_sphere(self):
        """Test that the mean of two sphere points belongs to the sphere (adaptive method)."""
        points = gs.array(
            [
                gs.array([1.0, 0.0, 0.0, 0.0, 0.0]),
                gs.array([0.0, 1.0, 0.0, 0.0, 0.0]),
            ]
        )

        estimator = FrechetMean(metric=self.sphere.metric, method="adaptive")
        estimator.fit(points)

        on_sphere = self.sphere.belongs(estimator.estimate_)
        self.assertAllClose(on_sphere, True)

    def test_variance_sphere(self):
        """Test that the variance of two copies of the base point is zero on the sphere."""
        base = gs.array([0.0, 0.0, 0.0, 0.0, 1.0])
        duplicated = gs.array([base, base])

        computed = variance(duplicated, base_point=base, metric=self.sphere.metric)

        self.assertAllClose(gs.array(0.0), computed)

    def test_estimate_default_gradient_descent_sphere(self):
        """Test that the mean of two copies of a sphere point is that point (default method)."""
        pole = gs.array([0.0, 0.0, 0.0, 0.0, 1.0])
        duplicated = gs.array([pole, pole])

        estimator = FrechetMean(metric=self.sphere.metric, method="default")
        estimator.fit(X=duplicated)

        self.assertAllClose(pole, estimator.estimate_)

    def test_estimate_elastic_metric(self):
        """Test that the elastic-metric mean of two copies of a curve is that curve."""
        curve = self.curves_2d.random_point(n_samples=1)
        duplicated = gs.array([curve, curve])

        estimator = FrechetMean(metric=self.elastic_metric)
        estimator.fit(X=duplicated)

        self.assertAllClose(curve, estimator.estimate_)

    def test_estimate_curves_2d(self):
        """Test that the SRV-metric mean of two copies of a curve is that curve."""
        curve = self.curves_2d.random_point(n_samples=1)
        duplicated = gs.array([curve, curve])

        estimator = FrechetMean(metric=self.curves_2d.srv_metric)
        estimator.fit(X=duplicated)

        self.assertAllClose(curve, estimator.estimate_)

    def test_estimate_adaptive_gradient_descent_sphere(self):
        """Test that the mean of two copies of a sphere point is that point (adaptive method)."""
        pole = gs.array([0.0, 0.0, 0.0, 0.0, 1.0])
        duplicated = gs.array([pole, pole])

        estimator = FrechetMean(metric=self.sphere.metric, method="adaptive")
        estimator.fit(X=duplicated)

        self.assertAllClose(pole, estimator.estimate_)

    def test_estimate_spd(self):
        """Test that the mean of two copies of an SPD matrix is that matrix."""
        spd_matrix = SPDMatrices(3).random_point()
        duplicated = gs.array([spd_matrix, spd_matrix])

        estimator = FrechetMean(metric=SPDMetricAffine(3), point_type="matrix")
        estimator.fit(X=duplicated)

        self.assertAllClose(spd_matrix, estimator.estimate_)

    def test_estimate_spd_two_samples(self):
        """Test that the mean of two SPD matrices is the midpoint given by exp of half the log."""
        affine_metric = SPDMetricAffine(3)
        pair = SPDMatrices(3).random_point(2)

        estimator = FrechetMean(affine_metric)
        estimator.fit(pair)

        # Midpoint: shoot from pair[1] along half of the log of pair[0].
        half_log = affine_metric.log(pair[0], pair[1]) / 2
        midpoint = affine_metric.exp(half_log, pair[1])

        self.assertAllClose(midpoint, estimator.estimate_)

    def test_variance_hyperbolic(self):
        """Test that the variance of two copies of the base point is zero (hyperbolic metric)."""
        base = gs.array([2.0, 1.0, 1.0, 1.0])
        duplicated = gs.array([base, base])

        computed = variance(
            duplicated, base_point=base, metric=self.hyperbolic.metric
        )

        self.assertAllClose(computed, gs.array(0.0))

    def test_estimate_hyperbolic(self):
        """Test that the mean of two copies of a point is that point (hyperbolic metric)."""
        sample = gs.array([2.0, 1.0, 1.0, 1.0])
        duplicated = gs.array([sample, sample])

        estimator = FrechetMean(metric=self.hyperbolic.metric)
        estimator.fit(X=duplicated)

        self.assertAllClose(estimator.estimate_, sample)

    def test_estimate_and_belongs_hyperbolic(self):
        """Test that the mean of three random hyperbolic points belongs to the space."""
        # Three separate draws, stacked along a new leading axis.
        samples = gs.stack(
            [self.hyperbolic.random_point() for _ in range(3)], axis=0
        )

        estimator = FrechetMean(metric=self.hyperbolic.metric)
        estimator.fit(X=samples)

        on_manifold = self.hyperbolic.belongs(estimator.estimate_)
        self.assertAllClose(on_manifold, True)

    def test_mean_euclidean_shape(self):
        """Test the shape of the Euclidean mean of three copies of a 2D point."""
        dim = 2
        sample = gs.array([1.0, 4.0])

        estimator = FrechetMean(metric=self.euclidean.metric)
        estimator.fit([sample, sample, sample])

        self.assertAllClose(gs.shape(estimator.estimate_), (dim,))

    def test_mean_euclidean(self):
        """Test the Euclidean mean on repeated points and on weighted points."""
        # Case 1: three copies of the same point.
        sample = gs.array([1.0, 4.0])
        estimator = FrechetMean(metric=self.euclidean.metric)
        estimator.fit([sample, sample, sample])

        self.assertAllClose(estimator.estimate_, sample)

        # Case 2: four distinct points with weights.
        samples = gs.array([[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0]])
        sample_weights = [1.0, 2.0, 1.0, 2.0]

        weighted_estimator = FrechetMean(metric=self.euclidean.metric)
        weighted_estimator.fit(samples, weights=sample_weights)

        weighted_mean = gs.array([16.0 / 6.0, 22.0 / 6.0])
        self.assertAllClose(weighted_estimator.estimate_, weighted_mean)

    def test_variance_euclidean(self):
        """Test the weighted Euclidean variance of four points about the origin."""
        samples = gs.array([[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0]])
        sample_weights = gs.array([1.0, 2.0, 1.0, 2.0])
        origin = gs.zeros(2)

        computed = variance(
            samples,
            weights=sample_weights,
            base_point=origin,
            metric=self.euclidean.metric,
        )

        # Weighted average of the squared norms of the points:
        # squared norms are 5, 13, 25, 41 and the weights sum to 6.
        weighted_sq_norms = 1 * 5.0 + 2 * 13.0 + 1 * 25.0 + 2 * 41.0
        expected = gs.array(weighted_sq_norms / 6.0)

        self.assertAllClose(computed, expected)

    def test_mean_matrices_shape(self):
        """Test the shape of the mean of three copies of a 2x2 matrix."""
        n_rows, n_cols = 2, 2
        matrix = gs.array([[1.0, 4.0], [2.0, 3.0]])

        estimator = FrechetMean(
            metric=MatricesMetric(n_rows, n_cols), point_type="matrix"
        )
        estimator.fit([matrix, matrix, matrix])

        self.assertAllClose(gs.shape(estimator.estimate_), (n_rows, n_cols))

    def test_mean_matrices(self):
        """Test that the mean of three copies of a 2x2 matrix is that matrix."""
        n_rows, n_cols = 2, 2
        matrix = gs.array([[1.0, 4.0], [2.0, 3.0]])

        estimator = FrechetMean(
            metric=MatricesMetric(n_rows, n_cols), point_type="matrix"
        )
        estimator.fit([matrix, matrix, matrix])

        self.assertAllClose(estimator.estimate_, matrix)

    def test_mean_minkowski_shape(self):
        """Test the shape of the Minkowski mean of three copies of a 2D point."""
        dim = 2
        sample = gs.array([2.0, -math.sqrt(3)])

        estimator = FrechetMean(metric=self.minkowski.metric)
        estimator.fit([sample, sample, sample])

        self.assertAllClose(gs.shape(estimator.estimate_), (dim,))

    def test_mean_minkowski(self):
        """Test the Minkowski mean on repeated points and on weighted points."""
        # Case 1: the mean of three copies of a point is that point.
        sample = gs.array([2.0, -math.sqrt(3)])
        estimator = FrechetMean(metric=self.minkowski.metric)
        estimator.fit([sample, sample, sample])

        self.assertAllClose(estimator.estimate_, sample)

        # Case 2: the weighted mean of four points belongs to the space.
        samples = gs.array(
            [
                [1.0, 0.0],
                [2.0, math.sqrt(3)],
                [3.0, math.sqrt(8)],
                [4.0, math.sqrt(24)],
            ]
        )
        sample_weights = gs.array([1.0, 2.0, 1.0, 2.0])

        weighted_estimator = FrechetMean(metric=self.minkowski.metric)
        weighted_estimator.fit(samples, weights=sample_weights)

        self.assertTrue(self.minkowski.belongs(weighted_estimator.estimate_))

    def test_variance_minkowski(self):
        """Test that the weighted Minkowski variance of four points is non-zero."""
        samples = gs.array(
            [
                [1.0, 0.0],
                [2.0, math.sqrt(3)],
                [3.0, math.sqrt(8)],
                [4.0, math.sqrt(24)],
            ]
        )
        sample_weights = gs.array([1.0, 2.0, 1.0, 2.0])
        reference = gs.array([-1.0, 0.0])

        computed = variance(
            samples,
            weights=sample_weights,
            base_point=reference,
            metric=self.minkowski.metric,
        )

        # Only the fact that the variance does not vanish is checked here,
        # not its actual value.
        is_nonzero = computed != 0
        self.assertAllClose(is_nonzero, True)

    def test_one_point(self):
        """Test that fitting a single sphere point returns that point."""
        pole = gs.array([0.0, 0.0, 0.0, 0.0, 1.0])

        # NOTE(review): the same check runs twice with method="default";
        # the second run may have been meant for another method - confirm.
        for _ in range(2):
            estimator = FrechetMean(metric=self.sphere.metric, method="default")
            estimator.fit(X=pole)

            self.assertAllClose(pole, estimator.estimate_)

    def test_batched(self):
        """Test that the "batch" method matches fitting each column of the data separately."""
        affine_metric = SPDMetricAffine(3)
        spd_points = SPDMatrices(3).random_point(4)

        batch_estimator = FrechetMean(affine_metric, method="batch", verbose=True)
        # Two groups of two matrices, stacked along axis 1.
        batched_data = gs.stack([spd_points[:2], spd_points[2:]], axis=1)
        batch_estimator.fit(batched_data)
        result = batch_estimator.estimate_

        # Reference: fit each slice along axis 1 on its own.
        single_estimator = FrechetMean(affine_metric)
        per_group_means = []
        for group in range(2):
            single_estimator.fit(batched_data[:, group])
            per_group_means.append(single_estimator.estimate_)
        expected = gs.stack(per_group_means)

        self.assertAllClose(expected, result)

    @geomstats.tests.np_and_autograd_only
    def test_stiefel_two_samples(self):
        """Test that the mean of two Stiefel points is the midpoint given by exp of half the log."""
        stiefel = Stiefel(3, 2)
        stiefel_metric = stiefel.metric
        pair = stiefel.random_point(2)

        estimator = FrechetMean(stiefel_metric)
        estimator.fit(pair)

        # Midpoint: shoot from pair[1] along half of the log of pair[0].
        half_log = stiefel_metric.log(pair[0], pair[1]) / 2
        midpoint = stiefel_metric.exp(half_log, pair[1])

        self.assertAllClose(midpoint, estimator.estimate_)

    @geomstats.tests.np_and_autograd_only
    def test_stiefel_n_samples(self):
        """Test that the default-method mean of random Stiefel points belongs to the space."""
        stiefel = Stiefel(3, 2)
        stiefel_metric = stiefel.metric
        samples = stiefel.random_point(2)

        estimator = FrechetMean(
            stiefel_metric, method="default", init_step_size=0.5, verbose=True
        )
        estimator.fit(samples)

        self.assertTrue(stiefel.belongs(estimator.estimate_))

    def test_circle_mean(self):
        """Test that the circle mean is unchanged when the metric's dim is overridden to 2."""
        circle = Hypersphere(1)
        samples = circle.random_uniform(10)

        circle_estimator = FrechetMean(circle.metric)
        circle_estimator.fit(samples)
        circle_estimate = circle_estimator.estimate_

        # set a wrong dimension so that the extrinsic coordinates are used
        tampered_metric = circle.metric
        tampered_metric.dim = 2
        extrinsic_estimator = FrechetMean(tampered_metric)
        extrinsic_estimator.fit(samples)

        self.assertAllClose(circle_estimate, extrinsic_estimator.estimate_)