Exemplo n.º 1
0
    def test_split_horizontal_vertical(
        self, times, n_discretized_curves, curve_a, curve_b
    ):
        """Test split horizontal vertical.

        Check that the horizontal and vertical parts of a tangent vector are
        orthogonal with respect to the SRVMetric inner product, and check that
        ``split_horizontal_vertical`` gives the same result when vectorized as
        when called curve by curve.
        """
        srv_metric_r3 = SRVMetric(r3)
        quotient_srv_metric_r3 = DiscreteCurves(ambient_manifold=r3).quotient_srv_metric
        geod = srv_metric_r3.geodesic(initial_curve=curve_a, end_curve=curve_b)
        geod = geod(times)
        # Scaled finite difference of the first two geodesic samples, used as
        # a tangent vector based at curve_a.
        tangent_vec = n_discretized_curves * (geod[1, :, :] - geod[0, :, :])
        (
            tangent_vec_hor,
            tangent_vec_ver,
            _,
        ) = quotient_srv_metric_r3.split_horizontal_vertical(tangent_vec, curve_a)
        result = srv_metric_r3.inner_product(tangent_vec_hor, tangent_vec_ver, curve_a)
        expected = 0.0
        self.assertAllClose(result, expected, atol=1e-4)

        # Vectorized call over all consecutive pairs of geodesic samples.
        tangent_vecs = n_discretized_curves * (geod[1:] - geod[:-1])
        _, _, result = quotient_srv_metric_r3.split_horizontal_vertical(
            tangent_vecs, geod[:-1]
        )
        # Reference: one call per (tangent vector, base curve) pair. Iterating
        # the arrays themselves keeps the loop length tied to the data rather
        # than to n_discretized_curves.
        expected = gs.stack(
            [
                quotient_srv_metric_r3.split_horizontal_vertical(
                    single_tangent_vec, base_curve
                )[2]
                for single_tangent_vec, base_curve in zip(tangent_vecs, geod[:-1])
            ]
        )
        self.assertAllClose(result, expected)
Exemplo n.º 2
0
 def test_horizontal_geodesic(self, n_sampling_points, curve_a, n_times):
     """Test horizontal geodesic.

     The discrete time derivative of the horizontal geodesic from curve_a
     to a straight vertical segment must have a vanishing vertical part
     (up to atol=1e-3) at every time step.
     """
     # Target curve: the segment from (0, 0, 1) to (0, 0, 0.5).
     x_coords = gs.zeros(n_sampling_points)
     y_coords = gs.zeros(n_sampling_points)
     z_coords = gs.linspace(1.0, 0.5, n_sampling_points)
     curve_b = gs.transpose(gs.stack((x_coords, y_coords, z_coords)))

     quotient_metric = DiscreteCurves(ambient_manifold=r3).quotient_srv_metric
     geodesic_fun = quotient_metric.horizontal_geodesic(curve_a, curve_b)
     times = gs.linspace(0.0, 1.0, n_times)
     geodesic_points = geodesic_fun(times)

     step_diffs = geodesic_points[1:] - geodesic_points[:-1]
     velocity_vecs = n_times * step_diffs
     _, _, vertical_norms = quotient_metric.split_horizontal_vertical(
         velocity_vecs, geodesic_points[:-1]
     )
     vertical_sizes = gs.sum(vertical_norms**2, axis=1) ** (1 / 2)
     self.assertAllClose(vertical_sizes, gs.zeros(n_times - 1), atol=1e-3)
Exemplo n.º 3
0
    def test_cartesian_to_polar_and_polar_to_cartesian(self, a, b, rtol, atol):
        """Check that polar_to_cartesian undoes cartesian_to_polar."""
        planar_curves = DiscreteCurves(ambient_manifold=r2)
        elastic_metric = ElasticMetric(a=a, b=b)
        curve = planar_curves.random_point()

        round_trip = elastic_metric.polar_to_cartesian(
            elastic_metric.cartesian_to_polar(curve)
        )
        self.assertAllClose(round_trip, curve, rtol, atol)
Exemplo n.º 4
0
 def setup_method(self):
     """Seed the backend RNG and build the manifolds and metrics used by the tests."""
     gs.random.seed(123)
     # Manifolds of various geometries.
     self.sphere, self.hyperbolic = Hypersphere(dim=4), Hyperboloid(dim=3)
     self.euclidean, self.minkowski = Euclidean(dim=2), Minkowski(dim=2)
     # SO(3) in vector and (default) matrix representations.
     self.so3 = SpecialOrthogonal(n=3, point_type="vector")
     self.so_matrix = SpecialOrthogonal(n=3)
     # Planar curves and an elastic metric on them.
     self.curves_2d = DiscreteCurves(ambient_manifold=R2)
     self.elastic_metric = ElasticMetric(ambient_manifold=R2, a=1, b=1)
    def setUp(self):
        """Build the sample curves, curve spaces and metrics used by the tests.

        Curves a, b and c are geodesics of the 2-sphere starting at
        [0, 0, 1] with initial velocities along +x, +y and -x. They are
        sampled at ``n_sampling_points`` regularly spaced times in [0, 1].
        """
        s2 = Hypersphere(dim=2)
        r2 = Euclidean(dim=2)
        r3 = s2.embedding_space

        start_point = [0.0, 0.0, 1.0]
        start_velocities = (
            [1.0, 0.0, 0.0],
            [0.0, 1.0, 0.0],
            [-1.0, 0.0, 0.0],
        )
        curve_fun_a, curve_fun_b, curve_fun_c = [
            s2.metric.geodesic(
                initial_point=start_point, initial_tangent_vec=velocity
            )
            for velocity in start_velocities
        ]
        self.curve_fun_a = curve_fun_a

        self.n_sampling_points = 10
        self.sampling_times = gs.linspace(0.0, 1.0, self.n_sampling_points)
        self.curve_a, self.curve_b, self.curve_c = (
            curve_fun(self.sampling_times)
            for curve_fun in (curve_fun_a, curve_fun_b, curve_fun_c)
        )

        self.space_closed_curves_in_euclidean_2d = ClosedDiscreteCurves(
            ambient_manifold=r2
        )

        self.a, self.b = 1, 1
        self.space_elastic_curves = ElasticCurves(self.a, self.b)
        self.elastic_metric = self.space_elastic_curves.elastic_metric

        self.n_discretized_curves = 5
        self.times = gs.linspace(0.0, 1.0, self.n_discretized_curves)
        gs.random.seed(1234)
        curves_in_r3 = DiscreteCurves(ambient_manifold=r3)
        curves_in_s2 = DiscreteCurves(ambient_manifold=s2)
        self.space_curves_in_euclidean_3d = curves_in_r3
        self.space_curves_in_sphere_2d = curves_in_s2
        self.l2_metric_s2 = curves_in_s2.l2_metric(self.n_sampling_points)
        self.l2_metric_r3 = curves_in_r3.l2_metric(self.n_sampling_points)
        self.srv_metric_r3 = curves_in_r3.square_root_velocity_metric
        self.quotient_srv_metric_r3 = (
            curves_in_r3.quotient_square_root_velocity_metric
        )
Exemplo n.º 6
0
    def test_f_transform_and_srv_transform_vectorization(self, rtol, atol):
        """Check that the f transform matches the SRV transform on a batch.

        The two transforms coincide for an elastic metric with a=1, b=1/2.
        """
        planar_curves = DiscreteCurves(ambient_manifold=r2)
        srv_like_metric = ElasticMetric(a=1, b=0.5)

        batch = planar_curves.random_point(n_samples=2)

        f_transformed = srv_like_metric.f_transform(batch)
        srv_transformed = planar_curves.srv_metric.srv_transform(batch)
        self.assertAllClose(f_transformed, srv_transformed, rtol, atol)
Exemplo n.º 7
0
    def test_f_transform_and_inverse(self, a, b, rtol, atol):
        """Check the f transform's output shape and that its inverse recovers the curve."""
        planar_curves = DiscreteCurves(ambient_manifold=r2)
        elastic_metric = ElasticMetric(a=a, b=b)
        curve = planar_curves.random_point()

        transformed = elastic_metric.f_transform(curve)
        recovered = elastic_metric.f_transform_inverse(transformed, curve[0])

        # The transform has one sample fewer than the curve, in the plane.
        self.assertAllClose(transformed.shape, (curve.shape[0] - 1, 2))
        self.assertAllClose(recovered, curve, rtol, atol)
Exemplo n.º 8
0
 def test_quotient_dist(self, sampling_times, curve_fun_a, curve_a,
                        n_sampling_points):
     """Test quotient distance.

     Check that the quotient SRV distance is invariant under
     reparametrization. Resampling curve_fun_a at the times
     ``sampling_times**2`` must not change the distance to curve_b,
     compared with ``curve_a`` (presumably curve_fun_a sampled at
     ``sampling_times``).
     """
     # Same underlying curve as curve_a, traversed with a different speed.
     curve_a_resampled = curve_fun_a(sampling_times**2)
     # Target curve: the segment from (0, 0, 1) to (0, 0, 0.5).
     curve_b = gs.transpose(
         gs.stack((
             gs.zeros(n_sampling_points),
             gs.zeros(n_sampling_points),
             gs.linspace(1.0, 0.5, n_sampling_points),
         )))
     quotient_srv_metric_r3 = DiscreteCurves(
         ambient_manifold=r3).quotient_square_root_velocity_metric
     result = quotient_srv_metric_r3.dist(curve_a_resampled, curve_b)
     expected = quotient_srv_metric_r3.dist(curve_a, curve_b)
     self.assertAllClose(result, expected, atol=1e-3, rtol=1e-3)
Exemplo n.º 9
0
    def test_f_transform_and_srv_transform(self, curve, rtol, atol):
        """Check that the f transform matches the SRV transform on one curve.

        The two transforms coincide for an elastic metric with a=1, b=1/2.
        """
        planar_curves = DiscreteCurves(ambient_manifold=r2)
        srv_like_metric = ElasticMetric(a=1, b=0.5)

        f_transformed = srv_like_metric.f_transform(curve)
        srv_transformed = planar_curves.srv_metric.srv_transform(curve)
        self.assertAllClose(f_transformed, srv_transformed, rtol, atol)
Exemplo n.º 10
0
    def setUp(self):
        """Build the sample curves, curve spaces and metrics used by the tests.

        The three sample curves are geodesics of the 2-sphere leaving
        [0, 0, 1] along +x, +y and -x. They are sampled at
        ``n_sampling_points`` regularly spaced times in [0, 1].
        """
        s2 = Hypersphere(dim=2)
        r3 = s2.embedding_manifold

        start = [0., 0., 1.]
        geodesic_funs = [
            s2.metric.geodesic(initial_point=start, initial_tangent_vec=vec)
            for vec in ([1., 0., 0.], [0., 1., 0.], [-1., 0., 0.])
        ]

        self.n_sampling_points = 10
        sampling_times = gs.linspace(0., 1., self.n_sampling_points)
        discretized = [fun(sampling_times) for fun in geodesic_funs]

        self.n_discretized_curves = 5
        self.times = gs.linspace(0., 1., self.n_discretized_curves)
        self.atol = 1e-6
        gs.random.seed(1234)
        curves_in_r3 = DiscreteCurves(ambient_manifold=r3)
        curves_in_s2 = DiscreteCurves(ambient_manifold=s2)
        self.space_curves_in_euclidean_3d = curves_in_r3
        self.space_curves_in_sphere_2d = curves_in_s2
        self.l2_metric_s2 = curves_in_s2.l2_metric(self.n_sampling_points)
        self.l2_metric_r3 = curves_in_r3.l2_metric(self.n_sampling_points)
        self.srv_metric_r3 = curves_in_r3.square_root_velocity_metric
        self.curve_a, self.curve_b, self.curve_c = discretized
Exemplo n.º 11
0
    def test_cells(self):
        """Test the cells dataset returned by ``data_utils.load_cells``.

        Check that cells, cell lines and treatments each have 650 entries.
        Check that every cell belongs to the space of planar discrete curves.
        Check that cell lines and treatments only take the expected labels.
        """
        cells, cell_lines, treatments = data_utils.load_cells()
        expected = 650
        for collection in (cells, cell_lines, treatments):
            self.assertAllClose(len(collection), expected)

        planar_curves_space = DiscreteCurves(R2)

        result = planar_curves_space.belongs(cells)
        self.assertTrue(gs.all(result))

        # Label checks operate on plain Python values, so the builtin ``all``
        # is enough (and short-circuits) -- no backend array is needed.
        allowed_lines = ("dlm8", "dunn")
        self.assertTrue(all(line in allowed_lines for line in cell_lines))

        allowed_treatments = ("control", "cytd", "jasp")
        self.assertTrue(
            all(treatment in allowed_treatments for treatment in treatments)
        )
Exemplo n.º 12
0
    def test_f_transform(self):
        """Check that the f transform coincides with the SRV transform.

        This holds for the elastic parameters a=1 and b=1/2.
        """
        r2 = Euclidean(dim=2)
        elastic_metric = ElasticMetric(a=1.0, b=0.5)
        curves_r2 = DiscreteCurves(ambient_manifold=r2)
        # Keep the x and z coordinates of curve_a to get a planar curve.
        planar_curve = gs.stack((self.curve_a[:, 0], self.curve_a[:, 2]), axis=-1)

        f_transformed = elastic_metric.f_transform(planar_curve)
        srv_metric = curves_r2.square_root_velocity_metric
        srv_transformed = gs.squeeze(srv_metric.srv_transform(planar_curve))
        self.assertAllClose(f_transformed, srv_transformed)
Exemplo n.º 13
0
    def test_srv_transform_and_srv_transform_inverse(self, rtol, atol):
        """Check that srv_transform_inverse undoes srv_transform on a batch."""
        metric = SRVMetric(ambient_manifold=r3)
        curves = DiscreteCurves(r3).random_point(n_samples=2)

        srv_curves = metric.srv_transform(curves)
        starting_points = curves[:, 0]
        recovered = metric.srv_transform_inverse(srv_curves, starting_points)

        # One SRV sample fewer than curve samples, per curve in the batch.
        n_curves, n_points = curves.shape[0], curves.shape[1]
        self.assertAllClose(srv_curves.shape, (n_curves, n_points - 1, 3))
        self.assertAllClose(recovered, curves, rtol, atol)
Exemplo n.º 14
0
    def test_elastic_and_srv_dist(self):
        """Test that SRV dist and elastic dist coincide.

        For a=1 and b=1/2.
        """
        r2 = Euclidean(dim=2)
        elastic_metric = ElasticMetric(a=1.0, b=0.5)
        curves_r2 = DiscreteCurves(ambient_manifold=r2)
        # Keep the x and z coordinates of the 3D test curves to get planar curves.
        curve_a_projected = gs.stack((self.curve_a[:, 0], self.curve_a[:, 2]), axis=-1)
        curve_b_projected = gs.stack((self.curve_b[:, 0], self.curve_b[:, 2]), axis=-1)
        result = elastic_metric.dist(curve_a_projected, curve_b_projected)
        expected = curves_r2.square_root_velocity_metric.dist(
            curve_a_projected, curve_b_projected
        )
        self.assertAllClose(result, expected)
Exemplo n.º 15
0
    def test_f_transform_inverse_and_srv_transform_inverse(self, curve, rtol, atol):
        """Test that the inverse f transform coincides with the inverse SRVF.

        This is valid for an f transform with a=1, b=1/2. The points of
        ``curve`` after the first one stand in for a transformed curve. The
        first point of ``curve`` is used as the starting point of the
        reconstruction.
        """
        curves_space = DiscreteCurves(ambient_manifold=r2)

        el_metric = ElasticMetric(a=1, b=0.5)
        starting_point = curve[0]
        fake_transformed_curve = curve[1:, :]

        result = el_metric.f_transform_inverse(fake_transformed_curve, starting_point)
        expected = curves_space.srv_metric.srv_transform_inverse(
            fake_transformed_curve, starting_point
        )
        self.assertAllClose(result, expected, rtol, atol)
Exemplo n.º 16
0
class L2CurvesMetricTestData(_RiemannianMetricTestData):
    """Parametrized test data for the L2 metric on discrete curves.

    Each ``*_test_data`` method forwards the shared configuration lists below
    to the matching ``_*_test_data`` helper inherited from
    ``_RiemannianMetricTestData``. Some pass relaxed tolerances, expressed as
    multiples of ``gs.atol`` / ``gs.rtol``.
    """

    # One configuration per ambient manifold (presumably R^2 and R^3,
    # defined at module level outside this view).
    ambient_manifolds_list = [r2, r3]
    metric_args_list = [(ambient_manifolds, )
                        for ambient_manifolds in ambient_manifolds_list]
    # Shape of one curve per ambient manifold; presumably
    # (number of sampling points, ambient dimension) -- TODO confirm.
    shape_list = [(10, 2), (10, 3)]
    space_list = [
        DiscreteCurves(ambient_manifolds)
        for ambient_manifolds in ambient_manifolds_list
    ]
    # Batch sizes are drawn with ``random.sample`` when the class body runs.
    # They can change from run to run unless ``random`` is seeded first.
    # Keep these statements in this order: reordering changes the draws.
    n_points_list = random.sample(range(2, 5), 2)
    n_tangent_vecs_list = random.sample(range(2, 5), 2)
    n_points_a_list = random.sample(range(2, 5), 2)
    # NOTE(review): single entry while the other lists have two. If the base
    # helpers zip these lists, only the first (r2) configuration is exercised
    # by the tests that use it -- confirm intent.
    n_points_b_list = [1]
    batch_size_list = random.sample(range(2, 5), 2)
    # Settings for the ladder parallel-transport test, one per manifold.
    alpha_list = [1] * 2
    n_rungs_list = [1] * 2
    scheme_list = ["pole"] * 2

    def exp_shape_test_data(self):
        return self._exp_shape_test_data(self.metric_args_list,
                                         self.space_list, self.shape_list)

    def log_shape_test_data(self):
        return self._log_shape_test_data(self.metric_args_list,
                                         self.space_list)

    def squared_dist_is_symmetric_test_data(self):
        return self._squared_dist_is_symmetric_test_data(
            self.metric_args_list,
            self.space_list,
            self.n_points_a_list,
            self.n_points_b_list,
            atol=gs.atol * 1000,
        )

    def exp_belongs_test_data(self):
        return self._exp_belongs_test_data(
            self.metric_args_list,
            self.space_list,
            self.shape_list,
            self.n_tangent_vecs_list,
            belongs_atol=gs.atol * 1000,
        )

    def log_is_tangent_test_data(self):
        return self._log_is_tangent_test_data(
            self.metric_args_list,
            self.space_list,
            self.n_points_list,
            is_tangent_atol=gs.atol * 1000,
        )

    def geodesic_ivp_belongs_test_data(self):
        return self._geodesic_ivp_belongs_test_data(
            self.metric_args_list,
            self.space_list,
            self.shape_list,
            self.n_points_list,
            belongs_atol=gs.atol * 1000,
        )

    def geodesic_bvp_belongs_test_data(self):
        return self._geodesic_bvp_belongs_test_data(
            self.metric_args_list,
            self.space_list,
            self.n_points_list,
            belongs_atol=gs.atol * 1000,
        )

    def exp_after_log_test_data(self):
        return self._exp_after_log_test_data(
            self.metric_args_list,
            self.space_list,
            self.n_points_list,
            rtol=gs.rtol * 100,
            atol=gs.atol * 10000,
        )

    def log_after_exp_test_data(self):
        return self._log_after_exp_test_data(
            self.metric_args_list,
            self.space_list,
            self.shape_list,
            self.n_tangent_vecs_list,
            rtol=gs.rtol * 100,
            atol=gs.atol * 10000,
        )

    def exp_ladder_parallel_transport_test_data(self):
        return self._exp_ladder_parallel_transport_test_data(
            self.metric_args_list,
            self.space_list,
            self.shape_list,
            self.n_tangent_vecs_list,
            self.n_rungs_list,
            self.alpha_list,
            self.scheme_list,
        )

    def exp_geodesic_ivp_test_data(self):
        return self._exp_geodesic_ivp_test_data(
            self.metric_args_list,
            self.space_list,
            self.shape_list,
            self.n_tangent_vecs_list,
            self.n_points_list,
            rtol=gs.rtol * 100000,
            atol=gs.atol * 100000,
        )

    def parallel_transport_ivp_is_isometry_test_data(self):
        return self._parallel_transport_ivp_is_isometry_test_data(
            self.metric_args_list,
            self.space_list,
            self.shape_list,
            self.n_tangent_vecs_list,
            is_tangent_atol=gs.atol * 1000,
            atol=gs.atol * 1000,
        )

    def parallel_transport_bvp_is_isometry_test_data(self):
        return self._parallel_transport_bvp_is_isometry_test_data(
            self.metric_args_list,
            self.space_list,
            self.shape_list,
            self.n_tangent_vecs_list,
            is_tangent_atol=gs.atol * 1000,
            atol=gs.atol * 1000,
        )

    def dist_is_symmetric_test_data(self):
        return self._dist_is_symmetric_test_data(
            self.metric_args_list,
            self.space_list,
            self.n_points_a_list,
            self.n_points_b_list,
        )

    def dist_is_positive_test_data(self):
        return self._dist_is_positive_test_data(
            self.metric_args_list,
            self.space_list,
            self.n_points_a_list,
            self.n_points_b_list,
        )

    def squared_dist_is_positive_test_data(self):
        return self._squared_dist_is_positive_test_data(
            self.metric_args_list,
            self.space_list,
            self.n_points_a_list,
            self.n_points_b_list,
        )

    def dist_is_norm_of_log_test_data(self):
        return self._dist_is_norm_of_log_test_data(
            self.metric_args_list,
            self.space_list,
            self.n_points_a_list,
            self.n_points_b_list,
        )

    def dist_point_to_itself_is_zero_test_data(self):
        return self._dist_point_to_itself_is_zero_test_data(
            self.metric_args_list, self.space_list, self.n_points_list)

    def inner_product_is_symmetric_test_data(self):
        return self._inner_product_is_symmetric_test_data(
            self.metric_args_list,
            self.space_list,
            self.shape_list,
            self.n_tangent_vecs_list,
        )

    def triangle_inequality_of_dist_test_data(self):
        return self._triangle_inequality_of_dist_test_data(
            self.metric_args_list, self.space_list, self.n_points_list)

    def l2_metric_geodesic_test_data(self):
        # Smoke case built from module-level s2, curve_a, curve_b, times and
        # n_sampling_points (defined outside this view).
        # NOTE(review): the key is spelled ``ambient_manfold``. It must match
        # the parameter name of the consuming test function (not visible
        # here), so do not rename one without the other.
        smoke_data = [
            dict(
                ambient_manfold=s2,
                curve_a=curve_a,
                curve_b=curve_b,
                times=times,
                n_sampling_points=n_sampling_points,
            )
        ]
        return self.generate_tests(smoke_data)
Exemplo n.º 17
0
class TestDiscreteCurves(geomstats.tests.TestCase):
    """Tests for discrete curves in the 2-sphere and in its embedding space."""

    def setUp(self):
        """Build the sample curves, curve spaces and metrics used by the tests.

        Curves a, b and c are geodesics of the 2-sphere starting at
        [0, 0, 1] with initial velocities along +x, +y and -x. They are
        sampled at ``n_sampling_points`` regularly spaced times in [0, 1].
        """
        s2 = Hypersphere(dim=2)
        r3 = s2.embedding_space

        initial_point = [0., 0., 1.]
        initial_tangent_vec_a = [1., 0., 0.]
        initial_tangent_vec_b = [0., 1., 0.]
        initial_tangent_vec_c = [-1., 0., 0.]

        curve_a = s2.metric.geodesic(initial_point=initial_point,
                                     initial_tangent_vec=initial_tangent_vec_a)
        curve_b = s2.metric.geodesic(initial_point=initial_point,
                                     initial_tangent_vec=initial_tangent_vec_b)
        curve_c = s2.metric.geodesic(initial_point=initial_point,
                                     initial_tangent_vec=initial_tangent_vec_c)

        self.n_sampling_points = 10
        sampling_times = gs.linspace(0., 1., self.n_sampling_points)
        discretized_curve_a = curve_a(sampling_times)
        discretized_curve_b = curve_b(sampling_times)
        discretized_curve_c = curve_c(sampling_times)

        self.n_discretized_curves = 5
        self.times = gs.linspace(0., 1., self.n_discretized_curves)
        gs.random.seed(1234)
        self.space_curves_in_euclidean_3d = DiscreteCurves(ambient_manifold=r3)
        self.space_curves_in_sphere_2d = DiscreteCurves(ambient_manifold=s2)
        self.l2_metric_s2 = self.space_curves_in_sphere_2d.l2_metric(
            self.n_sampling_points)
        self.l2_metric_r3 = self.space_curves_in_euclidean_3d.l2_metric(
            self.n_sampling_points)
        self.srv_metric_r3 = self.space_curves_in_euclidean_3d.\
            square_root_velocity_metric
        self.curve_a = discretized_curve_a
        self.curve_b = discretized_curve_b
        self.curve_c = discretized_curve_c

    def test_belongs(self):
        """Test belongs on a single curve, a list of curves and a batch array."""
        result = self.space_curves_in_sphere_2d.belongs(self.curve_a)
        self.assertTrue(result)

        # A plain list whose curves have different numbers of sampling points.
        curve_ab = [self.curve_a[:-1], self.curve_b]
        result = self.space_curves_in_sphere_2d.belongs(curve_ab)
        self.assertTrue(gs.all(result))

        curve_ab = gs.array([self.curve_a, self.curve_b])
        result = self.space_curves_in_sphere_2d.belongs(curve_ab)
        self.assertTrue(gs.all(result))

    def test_l2_metric_log_and_squared_norm_and_dist(self):
        """Test that squared norm of logarithm is squared dist."""
        log_ab = self.l2_metric_s2.log(point=self.curve_b,
                                       base_point=self.curve_a)
        result = self.l2_metric_s2.squared_norm(vector=log_ab,
                                                base_point=self.curve_a)
        expected = self.l2_metric_s2.dist(self.curve_a, self.curve_b)**2

        self.assertAllClose(result, expected)

    def test_l2_metric_log_and_exp(self):
        """Test that exp and log are inverse maps."""
        tangent_vec = self.l2_metric_s2.log(point=self.curve_b,
                                            base_point=self.curve_a)
        result = self.l2_metric_s2.exp(tangent_vec=tangent_vec,
                                       base_point=self.curve_a)
        expected = self.curve_b

        self.assertAllClose(result, expected)

    def test_l2_metric_inner_product_vectorization(self):
        """Test the vectorization of inner_product."""
        n_samples = self.n_discretized_curves
        curves_ab = self.l2_metric_s2.geodesic(self.curve_a, self.curve_b)
        curves_bc = self.l2_metric_s2.geodesic(self.curve_b, self.curve_c)
        curves_ab = curves_ab(self.times)
        curves_bc = curves_bc(self.times)

        tangent_vecs = self.l2_metric_s2.log(point=curves_bc,
                                             base_point=curves_ab)

        result = self.l2_metric_s2.inner_product(tangent_vecs, tangent_vecs,
                                                 curves_ab)

        self.assertAllClose(gs.shape(result), (n_samples, ))

    def test_l2_metric_dist_vectorization(self):
        """Test the vectorization of dist."""
        n_samples = self.n_discretized_curves
        curves_ab = self.l2_metric_s2.geodesic(self.curve_a, self.curve_b)
        curves_bc = self.l2_metric_s2.geodesic(self.curve_b, self.curve_c)
        curves_ab = curves_ab(self.times)
        curves_bc = curves_bc(self.times)

        result = self.l2_metric_s2.dist(curves_ab, curves_bc)
        self.assertAllClose(gs.shape(result), (n_samples, ))

    def test_l2_metric_exp_vectorization(self):
        """Test the vectorization of exp."""
        curves_ab = self.l2_metric_s2.geodesic(self.curve_a, self.curve_b)
        curves_bc = self.l2_metric_s2.geodesic(self.curve_b, self.curve_c)
        curves_ab = curves_ab(self.times)
        curves_bc = curves_bc(self.times)

        tangent_vecs = self.l2_metric_s2.log(point=curves_bc,
                                             base_point=curves_ab)

        result = self.l2_metric_s2.exp(tangent_vec=tangent_vecs,
                                       base_point=curves_ab)
        self.assertAllClose(gs.shape(result), gs.shape(curves_ab))

    def test_l2_metric_log_vectorization(self):
        """Test the vectorization of log."""
        curves_ab = self.l2_metric_s2.geodesic(self.curve_a, self.curve_b)
        curves_bc = self.l2_metric_s2.geodesic(self.curve_b, self.curve_c)
        curves_ab = curves_ab(self.times)
        curves_bc = curves_bc(self.times)

        tangent_vecs = self.l2_metric_s2.log(point=curves_bc,
                                             base_point=curves_ab)

        result = tangent_vecs
        self.assertAllClose(gs.shape(result), gs.shape(curves_ab))

    def test_l2_metric_geodesic(self):
        """Test the geodesic method of L2Metric.

        The L2 geodesic between two curves is computed point by point with
        the ambient metric's geodesic between corresponding samples.
        """
        curves_ab = self.l2_metric_s2.geodesic(self.curve_a, self.curve_b)
        curves_ab = curves_ab(self.times)

        result = curves_ab
        expected = []
        for k in range(self.n_sampling_points):
            geod = self.l2_metric_s2.ambient_metric.geodesic(
                initial_point=self.curve_a[k, :], end_point=self.curve_b[k, :])
            expected.append(geod(self.times))
        expected = gs.stack(expected, axis=1)
        self.assertAllClose(result, expected)

    def test_srv_metric_pointwise_inner_product(self):
        """Test the output shape of pointwise_inner_product, batched and not."""
        curves_ab = self.l2_metric_s2.geodesic(self.curve_a, self.curve_b)
        curves_bc = self.l2_metric_s2.geodesic(self.curve_b, self.curve_c)
        curves_ab = curves_ab(self.times)
        curves_bc = curves_bc(self.times)

        tangent_vecs = self.l2_metric_s2.log(point=curves_bc,
                                             base_point=curves_ab)
        result = self.srv_metric_r3.pointwise_inner_product(
            tangent_vec_a=tangent_vecs,
            tangent_vec_b=tangent_vecs,
            base_curve=curves_ab)
        expected_shape = (self.n_discretized_curves, self.n_sampling_points)
        self.assertAllClose(gs.shape(result), expected_shape)

        result = self.srv_metric_r3.pointwise_inner_product(
            tangent_vec_a=tangent_vecs[0],
            tangent_vec_b=tangent_vecs[0],
            base_curve=curves_ab[0])
        expected_shape = (self.n_sampling_points, )
        self.assertAllClose(gs.shape(result), expected_shape)

    def test_square_root_velocity_and_inverse(self):
        """Test of square_root_velocity and its inverse.

        N.B: Here curves_ab are seen as curves in R3 and not S2.
        """
        curves_ab = self.l2_metric_s2.geodesic(self.curve_a, self.curve_b)
        curves_ab = curves_ab(self.times)

        curves = curves_ab
        srv_curves = self.srv_metric_r3.square_root_velocity(curves)
        starting_points = curves[:, 0, :]
        result = self.srv_metric_r3.square_root_velocity_inverse(
            srv_curves, starting_points)
        expected = curves

        self.assertAllClose(result, expected)

    def test_srv_metric_exp_and_log(self):
        """Test that exp and log are inverse maps and vectorized.

        N.B: Here curves_ab and curves_bc are seen as curves in R3 and not S2.
        """
        curves_ab = self.l2_metric_s2.geodesic(self.curve_a, self.curve_b)
        curves_bc = self.l2_metric_s2.geodesic(self.curve_b, self.curve_c)
        curves_ab = curves_ab(self.times)
        curves_bc = curves_bc(self.times)

        log = self.srv_metric_r3.log(point=curves_bc, base_point=curves_ab)
        result = self.srv_metric_r3.exp(tangent_vec=log, base_point=curves_ab)
        expected = curves_bc

        self.assertAllClose(gs.squeeze(result), expected)

    def test_srv_metric_geodesic(self):
        """Test the geodesic of the SRV metric between two curves in R3.

        For curves in a Euclidean space, the SRV geodesic should equal the
        inverse SRV transform of the L2 geodesic between the curves' SRVs.
        The starting points follow the ambient geodesic between the curves'
        first points.
        N.B: Here curve_a and curve_b are seen as curves in R3 and not S2.
        """
        geod = self.srv_metric_r3.geodesic(initial_curve=self.curve_a,
                                           end_curve=self.curve_b)
        result = geod(self.times)

        srv_a = self.srv_metric_r3.square_root_velocity(self.curve_a)
        srv_b = self.srv_metric_r3.square_root_velocity(self.curve_b)
        # SRVs have one sample fewer than the curves they come from.
        l2_metric = self.space_curves_in_euclidean_3d.l2_metric(
            self.n_sampling_points - 1)
        geod_srv = l2_metric.geodesic(initial_point=srv_a, end_point=srv_b)
        geod_srv = geod_srv(self.times)

        starting_points = self.srv_metric_r3.ambient_metric.geodesic(
            initial_point=self.curve_a[0, :], end_point=self.curve_b[0, :])
        starting_points = starting_points(self.times)

        expected = self.srv_metric_r3.square_root_velocity_inverse(
            geod_srv, starting_points)

        self.assertAllClose(result, expected)

    def test_srv_metric_dist_and_geod(self):
        """Test that the length of the geodesic gives the distance.

        N.B: Here curve_a and curve_b are seen as curves in R3 and not S2.
        """
        geod = self.srv_metric_r3.geodesic(initial_curve=self.curve_a,
                                           end_curve=self.curve_b)
        geod = geod(self.times)

        srv = self.srv_metric_r3.square_root_velocity(geod)

        # Scaled finite differences of the SRVs along the geodesic; the
        # scaling cancels with the division by n_discretized_curves below.
        srv_derivative = self.n_discretized_curves * (srv[1:, :] - srv[:-1, :])
        l2_metric = self.space_curves_in_euclidean_3d.l2_metric(
            self.n_sampling_points - 1)
        norms = l2_metric.norm(srv_derivative, geod[:-1, :-1, :])
        result = gs.sum(norms, 0) / self.n_discretized_curves

        expected = self.srv_metric_r3.dist(self.curve_a, self.curve_b)[0]
        self.assertAllClose(result, expected)

    def test_random_and_belongs(self):
        """Test that random curves belong to the space, single and batched."""
        random_curves = self.space_curves_in_sphere_2d.random_point()
        result = self.space_curves_in_sphere_2d.belongs(random_curves)
        self.assertTrue(result)
        self.assertAllClose(random_curves.shape, (10, 3))

        random_curves = self.space_curves_in_sphere_2d.random_point(2)
        result = self.space_curves_in_sphere_2d.belongs(random_curves)
        self.assertTrue(gs.all(result))

    def test_is_tangent_to_tangent(self):
        """Test that to_tangent outputs tangent vectors, single and batched."""
        point = self.space_curves_in_sphere_2d.random_point()
        vector = self.space_curves_in_sphere_2d.random_point()
        tangent_vec = self.space_curves_in_sphere_2d.to_tangent(vector, point)
        result = self.space_curves_in_sphere_2d.is_tangent(tangent_vec, point)
        self.assertTrue(result)

        point = self.space_curves_in_sphere_2d.random_point(2)
        vector = self.space_curves_in_sphere_2d.random_point(2)
        tangent_vec = self.space_curves_in_sphere_2d.to_tangent(vector, point)
        result = self.space_curves_in_sphere_2d.is_tangent(tangent_vec, point)
        self.assertTrue(gs.all(result))
Exemplo n.º 18
0
class TestDiscreteCurves(geomstats.tests.TestCase):
    """Tests for discrete curves and their L2, SRV, elastic and quotient metrics.

    ``setup_method`` samples three geodesics of the 2-sphere and builds the
    curve spaces and metrics shared by the tests below.
    """

    def setup_method(self):
        """Build test curves, curve spaces and metrics, and seed the RNG."""
        s2 = Hypersphere(dim=2)
        r2 = Euclidean(dim=2)
        r3 = s2.embedding_space

        # Three S2 geodesics starting at [0, 0, 1] with different velocities.
        initial_point = [0.0, 0.0, 1.0]
        initial_tangent_vec_a = [1.0, 0.0, 0.0]
        initial_tangent_vec_b = [0.0, 1.0, 0.0]
        initial_tangent_vec_c = [-1.0, 0.0, 0.0]

        curve_fun_a = s2.metric.geodesic(
            initial_point=initial_point, initial_tangent_vec=initial_tangent_vec_a
        )
        curve_fun_b = s2.metric.geodesic(
            initial_point=initial_point, initial_tangent_vec=initial_tangent_vec_b
        )
        curve_fun_c = s2.metric.geodesic(
            initial_point=initial_point, initial_tangent_vec=initial_tangent_vec_c
        )
        self.curve_fun_a = curve_fun_a

        self.n_sampling_points = 10
        self.sampling_times = gs.linspace(0.0, 1.0, self.n_sampling_points)
        self.curve_a = curve_fun_a(self.sampling_times)
        self.curve_b = curve_fun_b(self.sampling_times)
        self.curve_c = curve_fun_c(self.sampling_times)

        self.space_curves_in_euclidean_3d = DiscreteCurves(ambient_manifold=r3)
        self.space_curves_in_sphere_2d = DiscreteCurves(ambient_manifold=s2)
        self.space_closed_curves_in_euclidean_2d = ClosedDiscreteCurves(
            ambient_manifold=r2
        )

        self.l2_metric_s2 = L2CurvesMetric(ambient_manifold=s2)
        self.l2_metric_r3 = L2CurvesMetric(ambient_manifold=r3)
        self.srv_metric_r3 = (
            self.space_curves_in_euclidean_3d.square_root_velocity_metric
        )
        self.quotient_srv_metric_r3 = (
            self.space_curves_in_euclidean_3d.quotient_square_root_velocity_metric
        )
        self.a = 1
        self.b = 1
        self.elastic_metric = ElasticMetric(self.a, self.b)

        # Number of curves used to discretize paths of curves on [0, 1].
        self.n_discretized_curves = 5
        self.times = gs.linspace(0.0, 1.0, self.n_discretized_curves)
        gs.random.seed(1234)

    def test_belongs(self):
        """Test belongs on one curve, a list of curves and an array of curves."""
        result = self.space_curves_in_sphere_2d.belongs(self.curve_a)
        self.assertTrue(result)

        # A list of curves with different numbers of sampling points.
        curve_ab = [self.curve_a[:-1], self.curve_b]
        result = self.space_curves_in_sphere_2d.belongs(curve_ab)
        self.assertTrue(gs.all(result))

        curve_ab = gs.array([self.curve_a, self.curve_b])
        result = self.space_curves_in_sphere_2d.belongs(curve_ab)
        self.assertTrue(gs.all(result))

    @geomstats.tests.np_autograd_and_torch_only
    def test_l2_metric_log_and_squared_norm_and_dist(self):
        """Test that squared norm of logarithm is squared dist."""
        log_ab = self.l2_metric_s2.log(point=self.curve_b, base_point=self.curve_a)
        result = self.l2_metric_s2.squared_norm(vector=log_ab, base_point=self.curve_a)
        expected = self.l2_metric_s2.dist(self.curve_a, self.curve_b) ** 2

        self.assertAllClose(result, expected)

    def test_l2_metric_log_and_exp(self):
        """Test that exp and log are inverse maps."""
        tangent_vec = self.l2_metric_s2.log(point=self.curve_b, base_point=self.curve_a)
        result = self.l2_metric_s2.exp(tangent_vec=tangent_vec, base_point=self.curve_a)
        expected = self.curve_b

        self.assertAllClose(result, expected)

    def test_l2_metric_inner_product_vectorization(self):
        """Test the vectorization inner_product."""
        n_samples = self.n_discretized_curves
        curves_ab = self.l2_metric_s2.geodesic(self.curve_a, self.curve_b)
        curves_bc = self.l2_metric_s2.geodesic(self.curve_b, self.curve_c)
        curves_ab = curves_ab(self.times)
        curves_bc = curves_bc(self.times)

        tangent_vecs = self.l2_metric_s2.log(point=curves_bc, base_point=curves_ab)

        result = self.l2_metric_s2.inner_product(tangent_vecs, tangent_vecs, curves_ab)

        self.assertAllClose(gs.shape(result), (n_samples,))

    def test_l2_metric_dist_vectorization(self):
        """Test the vectorization of dist."""
        n_samples = self.n_discretized_curves
        curves_ab = self.l2_metric_s2.geodesic(self.curve_a, self.curve_b)
        curves_bc = self.l2_metric_s2.geodesic(self.curve_b, self.curve_c)
        curves_ab = curves_ab(self.times)
        curves_bc = curves_bc(self.times)

        result = self.l2_metric_s2.dist(curves_ab, curves_bc)
        self.assertAllClose(gs.shape(result), (n_samples,))

    def test_l2_metric_exp_vectorization(self):
        """Test the vectorization of exp."""
        curves_ab = self.l2_metric_s2.geodesic(self.curve_a, self.curve_b)
        curves_bc = self.l2_metric_s2.geodesic(self.curve_b, self.curve_c)
        curves_ab = curves_ab(self.times)
        curves_bc = curves_bc(self.times)

        tangent_vecs = self.l2_metric_s2.log(point=curves_bc, base_point=curves_ab)

        result = self.l2_metric_s2.exp(tangent_vec=tangent_vecs, base_point=curves_ab)
        self.assertAllClose(gs.shape(result), gs.shape(curves_ab))

    def test_l2_metric_log_vectorization(self):
        """Test the vectorization of log."""
        curves_ab = self.l2_metric_s2.geodesic(self.curve_a, self.curve_b)
        curves_bc = self.l2_metric_s2.geodesic(self.curve_b, self.curve_c)
        curves_ab = curves_ab(self.times)
        curves_bc = curves_bc(self.times)

        tangent_vecs = self.l2_metric_s2.log(point=curves_bc, base_point=curves_ab)

        result = tangent_vecs
        self.assertAllClose(gs.shape(result), gs.shape(curves_ab))

    def test_l2_metric_geodesic(self):
        """Test the geodesic method of L2Metric.

        It should be the pointwise ambient geodesic between sampling points.
        """
        curves_ab = self.l2_metric_s2.geodesic(self.curve_a, self.curve_b)
        curves_ab = curves_ab(self.times)

        result = curves_ab
        expected = []
        for k in range(self.n_sampling_points):
            geod = self.l2_metric_s2.ambient_metric.geodesic(
                initial_point=self.curve_a[k, :], end_point=self.curve_b[k, :]
            )
            expected.append(geod(self.times))
        expected = gs.stack(expected, axis=1)
        self.assertAllClose(result, expected)

    def test_srv_metric_pointwise_inner_products(self):
        """Test output shapes of pointwise_inner_products, batched and single."""
        curves_ab = self.l2_metric_s2.geodesic(self.curve_a, self.curve_b)
        curves_bc = self.l2_metric_s2.geodesic(self.curve_b, self.curve_c)
        curves_ab = curves_ab(self.times)
        curves_bc = curves_bc(self.times)

        tangent_vecs = self.l2_metric_s2.log(point=curves_bc, base_point=curves_ab)
        result = self.srv_metric_r3.l2_metric.pointwise_inner_products(
            tangent_vec_a=tangent_vecs, tangent_vec_b=tangent_vecs, base_curve=curves_ab
        )
        expected_shape = (self.n_discretized_curves, self.n_sampling_points)
        self.assertAllClose(gs.shape(result), expected_shape)

        result = self.srv_metric_r3.l2_metric.pointwise_inner_products(
            tangent_vec_a=tangent_vecs[0],
            tangent_vec_b=tangent_vecs[0],
            base_curve=curves_ab[0],
        )
        expected_shape = (self.n_sampling_points,)
        self.assertAllClose(gs.shape(result), expected_shape)

    def test_srv_transform_and_inverse(self):
        """Test of SRVT and its inverse.

        N.B: Here curves_ab are seen as curves in R3 and not S2.
        """
        curves_ab = self.l2_metric_s2.geodesic(self.curve_a, self.curve_b)
        curves_ab = curves_ab(self.times)

        curves = curves_ab
        srv_curves = self.srv_metric_r3.srv_transform(curves)
        starting_points = curves[:, 0, :]
        result = self.srv_metric_r3.srv_transform_inverse(srv_curves, starting_points)
        expected = curves

        self.assertAllClose(result, expected)

    def test_srv_metric_exp_and_log(self):
        """Test that exp and log are inverse maps and vectorized.

        N.B: Here curves_ab and curves_bc are seen as curves in R3 and not S2.
        """
        curves_ab = self.l2_metric_s2.geodesic(self.curve_a, self.curve_b)
        curves_bc = self.l2_metric_s2.geodesic(self.curve_b, self.curve_c)
        curves_ab = curves_ab(self.times)
        curves_bc = curves_bc(self.times)

        log = self.srv_metric_r3.log(point=curves_bc, base_point=curves_ab)
        result = self.srv_metric_r3.exp(tangent_vec=log, base_point=curves_ab)
        expected = curves_bc

        self.assertAllClose(gs.squeeze(result), expected)

    def test_srv_metric_geodesic(self):
        """Test the geodesic of the SRV metric between two curves in R3.

        It should be the inverse SRV transform of the L2 geodesic between
        the curves' SRVs, with starting points following the ambient
        geodesic between the curves' starting points.

        N.B: Here curve_a and curve_b are seen as curves in R3 and not S2.
        """
        geod = self.srv_metric_r3.geodesic(
            initial_curve=self.curve_a, end_curve=self.curve_b
        )
        result = geod(self.times)

        srv_a = self.srv_metric_r3.srv_transform(self.curve_a)
        srv_b = self.srv_metric_r3.srv_transform(self.curve_b)
        geod_srv = self.l2_metric_r3.geodesic(initial_point=srv_a, end_point=srv_b)
        geod_srv = geod_srv(self.times)

        starting_points = self.srv_metric_r3.ambient_metric.geodesic(
            initial_point=self.curve_a[0, :], end_point=self.curve_b[0, :]
        )
        starting_points = starting_points(self.times)

        expected = self.srv_metric_r3.srv_transform_inverse(geod_srv, starting_points)

        self.assertAllClose(result, expected)

    def test_srv_metric_dist_and_geod(self):
        """Test that the length of the geodesic gives the distance.

        N.B: Here curve_a and curve_b are seen as curves in R3 and not S2.
        """
        geod = self.srv_metric_r3.geodesic(
            initial_curve=self.curve_a, end_curve=self.curve_b
        )
        geod = geod(self.times)
        srv = self.srv_metric_r3.srv_transform(geod)
        # The time-step scaling factor cancels in the ratio below.
        srv_derivative = self.n_discretized_curves * (srv[1:, :] - srv[:-1, :])
        norms = self.srv_metric_r3.l2_metric.norm(srv_derivative)
        result = gs.sum(norms, 0) / self.n_discretized_curves

        expected = self.srv_metric_r3.dist(self.curve_a, self.curve_b)
        self.assertAllClose(result, expected)

    def test_random_and_belongs(self):
        """Test that random curves on S2 belong to the space (single and batch)."""
        random = self.space_curves_in_sphere_2d.random_point()
        result = self.space_curves_in_sphere_2d.belongs(random)
        self.assertTrue(result)
        self.assertAllClose(random.shape, (10, 3))

        random = self.space_curves_in_sphere_2d.random_point(2)
        result = self.space_curves_in_sphere_2d.belongs(random)
        self.assertTrue(gs.all(result))

    def test_is_tangent_to_tangent(self):
        """Test that to_tangent outputs vectors accepted by is_tangent."""
        point = self.space_curves_in_sphere_2d.random_point()
        vector = self.space_curves_in_sphere_2d.random_point()
        tangent_vec = self.space_curves_in_sphere_2d.to_tangent(vector, point)
        result = self.space_curves_in_sphere_2d.is_tangent(tangent_vec, point)
        self.assertTrue(result)

        point = self.space_curves_in_sphere_2d.random_point(2)
        vector = self.space_curves_in_sphere_2d.random_point(2)
        tangent_vec = self.space_curves_in_sphere_2d.to_tangent(vector, point)
        result = self.space_curves_in_sphere_2d.is_tangent(tangent_vec, point)
        self.assertTrue(gs.all(result))

    @geomstats.tests.np_and_autograd_only
    def test_projection_closed_curves(self):
        """Test that projecting the projection returns projection.

        Also test that the projection is a closed curve.
        """
        planar_closed_curves = self.space_closed_curves_in_euclidean_2d

        cells, _, _ = data_utils.load_cells()
        curves = [cell[:-10] for cell in cells[:5]]

        for curve in curves:
            proj = planar_closed_curves.project(curve)
            expected = proj
            result = planar_closed_curves.project(proj)
            self.assertAllClose(result, expected)

            result = proj[-1, :]
            expected = proj[0, :]
            self.assertAllClose(result, expected, rtol=10 * gs.rtol)

    def test_srv_inner_product(self):
        """Test that srv_inner_product works as expected.

        Also test that the resulting shape is right.
        """
        curves_ab = self.l2_metric_s2.geodesic(self.curve_a, self.curve_b)
        curves_bc = self.l2_metric_s2.geodesic(self.curve_b, self.curve_c)
        curves_ab = curves_ab(self.times)
        curves_bc = curves_bc(self.times)
        srvs_ab = self.srv_metric_r3.srv_transform(curves_ab)
        srvs_bc = self.srv_metric_r3.srv_transform(curves_bc)

        result = self.srv_metric_r3.l2_metric.inner_product(srvs_ab, srvs_bc)
        products = srvs_ab * srvs_bc
        expected = [gs.sum(product) for product in products]
        expected = gs.array(expected) / (srvs_ab.shape[-2] + 1)
        self.assertAllClose(result, expected)

        result = result.shape
        expected = [srvs_ab.shape[0]]
        self.assertAllClose(result, expected)

    def test_srv_norm(self):
        """Test that srv_norm works as expected.

        Also test that the resulting shape is right.
        """
        curves_ab = self.l2_metric_s2.geodesic(self.curve_a, self.curve_b)
        curves_ab = curves_ab(self.times)
        srvs_ab = self.srv_metric_r3.srv_transform(curves_ab)

        result = self.srv_metric_r3.l2_metric.norm(srvs_ab)
        products = srvs_ab * srvs_ab
        sums = [gs.sum(product) for product in products]
        squared_norm = gs.array(sums) / (srvs_ab.shape[-2] + 1)
        expected = gs.sqrt(squared_norm)
        self.assertAllClose(result, expected)

        result = result.shape
        expected = [srvs_ab.shape[0]]
        self.assertAllClose(result, expected)

    def test_f_transform(self):
        """Test that the f transform coincides with the SRVF.

        With the parameters: a=1, b=1/2.
        """
        r2 = Euclidean(dim=2)
        elastic_metric = ElasticMetric(a=1.0, b=0.5)
        curves_r2 = DiscreteCurves(ambient_manifold=r2)
        curve_a_projected = gs.stack((self.curve_a[:, 0], self.curve_a[:, 2]), axis=-1)

        result = elastic_metric.f_transform(curve_a_projected)
        expected = gs.squeeze(
            curves_r2.square_root_velocity_metric.srv_transform(curve_a_projected)
        )
        self.assertAllClose(result, expected)

    @geomstats.tests.np_autograd_and_tf_only
    def test_f_transform_and_inverse(self):
        """Test that the inverse is right."""
        cells, _, _ = data_utils.load_cells()
        curve = cells[0]
        metric = self.elastic_metric
        f = metric.f_transform(curve)
        f_inverse = metric.f_transform_inverse(f, curve[0])

        result = f.shape
        expected = (curve.shape[0] - 1, 2)
        self.assertAllClose(result, expected)

        result = f_inverse
        expected = curve
        self.assertAllClose(result, expected)

    @geomstats.tests.np_autograd_and_torch_only
    def test_elastic_dist(self):
        """Test shape and positivity."""
        cells, _, _ = data_utils.load_cells()
        curve_1, curve_2 = cells[0][:10], cells[1][:10]
        metric = self.elastic_metric
        dist = metric.dist(curve_1, curve_2)

        result = dist.shape
        expected = ()
        self.assertAllClose(result, expected)

        result = dist > 0
        self.assertTrue(result)

    @geomstats.tests.np_autograd_and_torch_only
    def test_elastic_and_srv_dist(self):
        """Test that SRV dist and elastic dist coincide.

        For a=1 and b=1/2.
        """
        r2 = Euclidean(dim=2)
        elastic_metric = ElasticMetric(a=1.0, b=0.5)
        curves_r2 = DiscreteCurves(ambient_manifold=r2)
        curve_a_projected = gs.stack((self.curve_a[:, 0], self.curve_a[:, 2]), axis=-1)
        curve_b_projected = gs.stack((self.curve_b[:, 0], self.curve_b[:, 2]), axis=-1)
        result = elastic_metric.dist(curve_a_projected, curve_b_projected)
        expected = curves_r2.square_root_velocity_metric.dist(
            curve_a_projected, curve_b_projected
        )
        self.assertAllClose(result, expected)

    def test_cartesian_to_polar_and_inverse(self):
        """Test that going back to cartesian works."""
        cells, _, _ = data_utils.load_cells()
        curve = cells[0]

        metric = self.elastic_metric
        norms, args = metric.cartesian_to_polar(curve)

        result = metric.polar_to_cartesian(norms, args)
        expected = curve
        self.assertAllClose(result, expected, rtol=10000 * gs.rtol)

    @geomstats.tests.np_and_autograd_only
    def test_aux_differential_srv_transform(self):
        """Test differential of square root velocity transform.

        Check that its value at (curve, tangent_vec) coincides
        with the derivative at zero of the square root velocity
        transform of a path of curves starting at curve with
        initial derivative tangent_vec.
        """
        dim = 3
        n_sampling_points = 2000
        sampling_times = gs.linspace(0.0, 1.0, n_sampling_points)
        curve_a = self.curve_fun_a(sampling_times)
        tangent_vec = gs.transpose(
            gs.tile(gs.linspace(1.0, 2.0, n_sampling_points), (dim, 1))
        )
        result = self.srv_metric_r3.aux_differential_srv_transform(tangent_vec, curve_a)

        n_curves = 2000
        times = gs.linspace(0.0, 1.0, n_curves)
        path_of_curves = curve_a + gs.einsum("i,jk->ijk", times, tangent_vec)
        srv_path = self.srv_metric_r3.srv_transform(path_of_curves)
        # Forward difference at t=0: the time step is times[1] = 1 / (n_curves - 1).
        expected = (n_curves - 1) * (srv_path[1] - srv_path[0])
        self.assertAllClose(result, expected, atol=1e-3, rtol=1e-3)

    @geomstats.tests.np_and_autograd_only
    def test_aux_differential_srv_transform_inverse(self):
        """Test inverse of differential of square root velocity transform.

        Check that it is the inverse of aux_differential_srv_transform.
        """
        dim = 3
        tangent_vec = gs.transpose(
            gs.tile(gs.linspace(0.0, 1.0, self.n_sampling_points), (dim, 1))
        )
        d_srv = self.srv_metric_r3.aux_differential_srv_transform(
            tangent_vec, self.curve_a
        )
        result = self.srv_metric_r3.aux_differential_srv_transform_inverse(
            d_srv, self.curve_a
        )
        expected = tangent_vec
        self.assertAllClose(result, expected, atol=1e-3, rtol=1e-3)

    def test_aux_differential_srv_transform_vectorization(self):
        """Test differential of square root velocity transform.

        Check vectorization.
        """
        dim = 3
        curves = gs.stack((self.curve_a, self.curve_b))
        tangent_vecs = gs.random.rand(2, self.n_sampling_points, dim)
        result = self.srv_metric_r3.aux_differential_srv_transform(tangent_vecs, curves)

        res_a = self.srv_metric_r3.aux_differential_srv_transform(
            tangent_vecs[0], self.curve_a
        )
        res_b = self.srv_metric_r3.aux_differential_srv_transform(
            tangent_vecs[1], self.curve_b
        )
        expected = gs.stack([res_a, res_b])
        self.assertAllClose(result, expected)

    def test_srv_inner_product_elastic(self):
        """Test inner product of SRVMetric.

        Check that the pullback metric gives an elastic metric
        with parameters a=1, b=1/2.
        """
        tangent_vec_a = gs.random.rand(self.n_sampling_points, 3)
        tangent_vec_b = gs.random.rand(self.n_sampling_points, 3)
        result = self.srv_metric_r3.inner_product(
            tangent_vec_a, tangent_vec_b, self.curve_a
        )

        r3 = Euclidean(3)
        d_vec_a = (self.n_sampling_points - 1) * (
            tangent_vec_a[1:, :] - tangent_vec_a[:-1, :]
        )
        d_vec_b = (self.n_sampling_points - 1) * (
            tangent_vec_b[1:, :] - tangent_vec_b[:-1, :]
        )
        velocity_vec = (self.n_sampling_points - 1) * (
            self.curve_a[1:, :] - self.curve_a[:-1, :]
        )
        velocity_norm = r3.metric.norm(velocity_vec)
        unit_velocity_vec = gs.einsum("ij,i->ij", velocity_vec, 1 / velocity_norm)
        a_param = 1
        b_param = 1 / 2
        # Elastic integrand: a^2 <dh, dk> - (a^2 - b^2) <dh, T><dk, T>,
        # divided by the speed |c'|.
        integrand = (
            a_param**2 * gs.sum(d_vec_a * d_vec_b, axis=1)
            - (a_param**2 - b_param**2)
            * gs.sum(d_vec_a * unit_velocity_vec, axis=1)
            * gs.sum(d_vec_b * unit_velocity_vec, axis=1)
        ) / velocity_norm
        expected = gs.sum(integrand) / self.n_sampling_points
        self.assertAllClose(result, expected)

    def test_srv_inner_product_and_dist(self):
        """Test that the norm of the log and the distance coincide.

        For curves with same / different starting points, and for
        the translation invariant / non invariant SRV metric.
        """
        r3 = Euclidean(dim=3)
        curve_b_transl = self.curve_b + gs.array([1.0, 0.0, 0.0])
        curve_b = [self.curve_b, curve_b_transl]
        translation_invariant = [True, False]
        for curve in curve_b:
            for param in translation_invariant:
                srv_metric = SRVMetric(ambient_manifold=r3, translation_invariant=param)
                log = srv_metric.log(point=curve, base_point=self.curve_a)
                result = srv_metric.norm(vector=log, base_point=self.curve_a)
                expected = srv_metric.dist(self.curve_a, curve)
                self.assertAllClose(result, expected)

    def test_srv_inner_product_vectorization(self):
        """Test inner product of SRVMetric.

        Check vectorization.
        """
        dim = 3
        curves = gs.stack((self.curve_a, self.curve_b))
        tangent_vecs_1 = gs.random.rand(2, self.n_sampling_points, dim)
        tangent_vecs_2 = gs.random.rand(2, self.n_sampling_points, dim)
        result = self.srv_metric_r3.inner_product(
            tangent_vecs_1, tangent_vecs_2, curves
        )

        res_a = self.srv_metric_r3.inner_product(
            tangent_vecs_1[0], tangent_vecs_2[0], self.curve_a
        )
        res_b = self.srv_metric_r3.inner_product(
            tangent_vecs_1[1], tangent_vecs_2[1], self.curve_b
        )
        expected = gs.stack((res_a, res_b))
        self.assertAllClose(result, expected)

    @geomstats.tests.np_autograd_and_torch_only
    def test_split_horizontal_vertical(self):
        """Test split horizontal vertical.

        Check that horizontal and vertical parts of any tangent
        vector are orthogonal with respect to the SRVMetric inner
        product, and check vectorization.
        """
        geod = self.srv_metric_r3.geodesic(
            initial_curve=self.curve_a, end_curve=self.curve_b
        )
        geod = geod(self.times)
        tangent_vec = self.n_discretized_curves * (geod[1, :, :] - geod[0, :, :])
        (
            tangent_vec_hor,
            tangent_vec_ver,
            _,
        ) = self.quotient_srv_metric_r3.split_horizontal_vertical(
            tangent_vec, self.curve_a
        )
        result = self.srv_metric_r3.inner_product(
            tangent_vec_hor, tangent_vec_ver, self.curve_a
        )
        expected = 0.0
        self.assertAllClose(result, expected, atol=1e-4)

        tangent_vecs = self.n_discretized_curves * (geod[1:] - geod[:-1])
        _, _, result = self.quotient_srv_metric_r3.split_horizontal_vertical(
            tangent_vecs, geod[:-1]
        )
        expected = []
        for i in range(self.n_discretized_curves - 1):
            _, _, res = self.quotient_srv_metric_r3.split_horizontal_vertical(
                tangent_vecs[i], geod[i]
            )
            expected.append(res)
        expected = gs.stack(expected)
        self.assertAllClose(result, expected)

    def test_space_derivative(self):
        """Test space derivative.

        Check result on an example and vectorization.
        """
        n_points = 3
        dim = 3
        curve = gs.random.rand(n_points, dim)
        result = self.srv_metric_r3.space_derivative(curve)
        delta = 1 / n_points
        # One-sided differences at the ends, centered difference inside.
        d_curve_1 = (curve[1] - curve[0]) / delta
        d_curve_2 = (curve[2] - curve[0]) / (2 * delta)
        d_curve_3 = (curve[2] - curve[1]) / delta
        expected = gs.squeeze(
            gs.vstack(
                (
                    gs.to_ndarray(d_curve_1, 2),
                    gs.to_ndarray(d_curve_2, 2),
                    gs.to_ndarray(d_curve_3, 2),
                )
            )
        )
        self.assertAllClose(result, expected)

        path_of_curves = gs.random.rand(
            self.n_discretized_curves, self.n_sampling_points, dim
        )
        result = self.srv_metric_r3.space_derivative(path_of_curves)
        expected = []
        for i in range(self.n_discretized_curves):
            expected.append(self.srv_metric_r3.space_derivative(path_of_curves[i]))
        expected = gs.stack(expected)
        self.assertAllClose(result, expected)

    @geomstats.tests.np_autograd_and_torch_only
    def test_horizontal_geodesic(self):
        """Test horizontal geodesic.

        Check that the time derivative of the geodesic is
        horizontal at all time.
        """
        curve_b = gs.transpose(
            gs.stack(
                (
                    gs.zeros(self.n_sampling_points),
                    gs.zeros(self.n_sampling_points),
                    gs.linspace(1.0, 0.5, self.n_sampling_points),
                )
            )
        )
        horizontal_geod_fun = self.quotient_srv_metric_r3.horizontal_geodesic(
            self.curve_a, curve_b
        )
        n_times = 20
        times = gs.linspace(0.0, 1.0, n_times)
        horizontal_geod = horizontal_geod_fun(times)
        # Forward difference in time: the step is 1 / (n_times - 1).
        velocity_vec = (n_times - 1) * (horizontal_geod[1:] - horizontal_geod[:-1])
        _, _, vertical_norms = self.quotient_srv_metric_r3.split_horizontal_vertical(
            velocity_vec, horizontal_geod[:-1]
        )
        result = gs.sum(vertical_norms**2, axis=1) ** (1 / 2)
        expected = gs.zeros(n_times - 1)
        self.assertAllClose(result, expected, atol=1e-3)

    @geomstats.tests.np_autograd_and_torch_only
    def test_quotient_dist(self):
        """Test quotient distance.

        Check that the quotient distance is invariant under
        reparametrization: curve_a resampled at squared times is at the
        same quotient distance from curve_b as curve_a itself.
        """
        curve_a_resampled = self.curve_fun_a(self.sampling_times**2)
        curve_b = gs.transpose(
            gs.stack(
                (
                    gs.zeros(self.n_sampling_points),
                    gs.zeros(self.n_sampling_points),
                    gs.linspace(1.0, 0.5, self.n_sampling_points),
                )
            )
        )
        result = self.quotient_srv_metric_r3.dist(curve_a_resampled, curve_b)
        expected = self.quotient_srv_metric_r3.dist(self.curve_a, curve_b)
        self.assertAllClose(result, expected, atol=1e-3, rtol=1e-3)
            horizontal_path[j, :, :] = spline_j(phi_inverse(t_space))

        phi_inverse = CubicSpline(phi[-1, :], t_space)
        horizontal_path[-1, :, :] = spline_b(phi_inverse(t_space))

        new_end_curve = horizontal_path[-1, :, :]
        gap = (np.sum(np.linalg.norm(new_end_curve - end_curve,
                                     axis=-1)**2))**(1 / 2)
        end_curve = new_end_curve.copy()
        print(gap)

    return horizontal_path


R3 = Euclidean(dim=3)
curves3D = DiscreteCurves(ambient_manifold=R3)

# Read each data file exactly once and reuse the arrays both to find the
# smallest number of points and to resample every curve.
n_files = 10
raw_curves = [
    np.loadtxt('data/femme{}.txt'.format(i + 1)) for i in range(n_files)
]
n_points_F = np.array([curve.shape[0] for curve in raw_curves], dtype=float)

# Every curve is resampled to one point fewer than the shortest curve.
ns = np.min(n_points_F) - 1
n_resampled = int(ns)
F = np.zeros((n_files, n_resampled, 3))
for i, F1 in enumerate(raw_curves):
    # Pick n_resampled indices evenly spread over the original samples, then
    # translate so that each curve starts at the origin.
    resample_idx = np.floor(
        np.linspace(0, F1.shape[0] - 1, n_resampled)).astype('int')
    F[i, :, :] = F1[resample_idx, :] - F1[0]

# NOTE(review): hard-coded sub-curve of the second file; meaning of the
# 199:399 window is not established here.
c1 = F[1, 199:399, :]


def horizontal_geodesic(curve_a, curve_b, n_times=10, threshold=1e-3,
                        max_iter=None):
    """Compute the horizontal geodesic between curve_a and curve_b.

    The geodesic lives in the fiber bundle induced by the action of
    reparameterizations. Each iteration does three things:

    1. Compute the SRV geodesic between ``curve_a`` and the current end curve.
    2. Split it with ``hvsplit`` and integrate a path of reparameterizations
       ``phi`` from the returned quantities.
    3. Resample the intermediate curves along the inverse of ``phi``.

    The loop stops once the end curve moves by at most ``threshold`` between
    two iterations. The displacement is measured as the square root of the
    sum of squared pointwise Euclidean distances.

    Parameters
    ----------
    curve_a : array-like, shape=[n_points + 1, dim]
        Initial curve.
    curve_b : array-like, shape=[n_points + 1, dim]
        End curve, sampled at the same number of points as curve_a.
    n_times : int
        Number of time steps used to discretize the path of curves.
    threshold : float
        Convergence threshold on the displacement of the end curve.
    max_iter : int or None
        Optional maximum number of iterations. At least one iteration is
        always performed. ``None`` (default) iterates until convergence.

    Returns
    -------
    horizontal_path : ndarray, shape=[n_times + 1, n_points + 1, dim]
        Discretized horizontal path of curves starting at curve_a.
    """
    dim = curve_a.shape[1]
    Rdim = Euclidean(dim)
    curves = DiscreteCurves(ambient_manifold=Rdim)

    n_points = curve_a.shape[0] - 1
    t_space = np.linspace(0., 1., n_points + 1)
    t_time = np.linspace(0., 1., n_times + 1)

    spline_b = CubicSpline(t_space, curve_b, axis=0)

    initial_curve = curve_a.copy()
    end_curve = curve_b.copy()
    # Start at +inf so that at least one iteration always runs; otherwise a
    # threshold >= the initial gap would leave horizontal_path undefined.
    gap = np.inf
    n_iter = 0

    while gap > threshold:

        # Compute geodesic path of curves
        srv_geod_fun = curves.square_root_velocity_metric.geodesic(
            initial_curve=initial_curve, end_curve=end_curve)
        geod = srv_geod_fun(t_time)
        # NOTE(review): semantics of M and K come from hvsplit (defined
        # elsewhere); they are used below as coefficients of phi's update.
        M, K, _, _ = hvsplit(geod)

        # Compute path of reparameterizations
        phi = np.zeros((n_times + 1, n_points + 1))
        phi_t = np.zeros((n_times + 1, n_points))
        phi_s = np.zeros((n_times, n_points))
        test_phi = np.zeros(n_times)
        phi[0, :] = np.linspace(0., 1., n_points + 1)
        phi[:, -1] = np.ones(n_times + 1)
        for j in range(n_times):
            phi_t[j, 0] = n_points * (phi[j, 1] - phi[j, 0])
            phi_s[j, 0] = phi_t[j, 0] * M[j, 0] / (n_points * K[j, 0])
            phi[j + 1, 0] = phi[j, 0] + 1 / n_times * phi_s[j, 0]
            for k in range(1, n_points):  # Matlab k = 2 : n
                # One-sided difference whose direction depends on sign(M).
                if M[j, k] > 0:
                    phi_t[j, k] = n_points * (phi[j, k + 1] - phi[j, k])
                else:
                    phi_t[j, k] = n_points * (phi[j, k] - phi[j, k - 1])

                phi_s[j, k] = phi_t[j, k] * M[j, k] / (n_points * K[j, k])
                phi[j + 1, k] = phi[j, k] + 1 / n_times * phi_s[j, k]

            # Count decreasing steps of phi at the new time (from index 1 on)
            # and warn only about the current step, instead of re-warning at
            # every later step once any earlier one has failed.
            test_phi[j] = np.sum(phi[j + 1, 2:] - phi[j + 1, 1:-1] < 0)
            if test_phi[j]:
                print(test_phi)
                print(
                    'Warning: phi(s) is non increasing for at least one time s.'
                )

        # Compute horizontal path of curves
        horizontal_path = np.zeros(geod.shape)
        horizontal_path[0, :, :] = curve_a
        for j in range(1, n_times):
            spline_j = CubicSpline(t_space, geod[j, :, :], axis=0)
            # Inverse of phi[j], obtained by interpolating with swapped axes.
            phi_inverse = CubicSpline(phi[j, :], t_space)
            horizontal_path[j, :, :] = spline_j(phi_inverse(t_space))

        # The end curve is always resampled from the original curve_b.
        phi_inverse = CubicSpline(phi[-1, :], t_space)
        horizontal_path[-1, :, :] = spline_b(phi_inverse(t_space))

        new_end_curve = horizontal_path[-1, :, :]
        gap = (np.sum(np.linalg.norm(new_end_curve - end_curve,
                                     axis=-1)**2))**(1 / 2)
        end_curve = new_end_curve.copy()
        print(gap)

        n_iter += 1
        if max_iter is not None and n_iter >= max_iter:
            break

    return horizontal_path
Exemplo n.º 21
0
class TestGeodesicRegression(geomstats.tests.TestCase):
    """Tests for ``GeodesicRegression`` on several manifolds.

    ``setup_method`` builds noiseless regression data on R^3, the
    4-sphere, SE(2) and 2D discrete curves. Each response is obtained by
    applying the manifold's ``exp`` to ``X * coef`` at ``intercept`` (for
    R^3, literally ``intercept + X * coef``), with ``X`` a centered uniform
    sample. The tests check the loss at the true parameters, its
    value/gradient, direct scipy minimization of the loss, and ``fit``.
    """

    _multiprocess_can_split_ = True

    # ------------------------------------------------------------------
    # Fixtures
    # ------------------------------------------------------------------

    def setup_method(self):
        # Seeded for reproducibility. The helpers below must run in this
        # order: they all consume the same global random stream.
        gs.random.seed(1234)
        self.n_samples = 20
        self._setup_euclidean()
        self._setup_hypersphere()
        self._setup_se2()
        self._setup_curves_2d()

    def _centered_regressors(self):
        """Draw ``n_samples`` uniform scalars and shift them to zero mean."""
        X = gs.random.rand(self.n_samples)
        return X - gs.mean(X)

    def _setup_euclidean(self):
        """Noiseless data on R^3: ``y = intercept + X * coef``."""
        self.dim_eucl = 3
        self.shape_eucl = (self.dim_eucl, )
        self.eucl = Euclidean(dim=self.dim_eucl)
        self.X_eucl = self._centered_regressors()
        self.intercept_eucl_true = self.eucl.random_point()
        self.coef_eucl_true = self.eucl.random_point()

        self.y_eucl = (self.intercept_eucl_true +
                       self.X_eucl[:, None] * self.coef_eucl_true)
        self.param_eucl_true = gs.vstack(
            [self.intercept_eucl_true, self.coef_eucl_true])
        # Guess: first observation as intercept, a randomly perturbed copy
        # of it as coefficient.
        self.param_eucl_guess = gs.vstack([
            self.y_eucl[0],
            self.y_eucl[0] + gs.random.normal(size=self.shape_eucl)
        ])

    def _setup_hypersphere(self):
        """Noiseless data on the 4-sphere, generated with ``metric.exp``."""
        self.dim_sphere = 4
        self.shape_sphere = (self.dim_sphere + 1, )
        self.sphere = Hypersphere(dim=self.dim_sphere)
        self.X_sphere = self._centered_regressors()
        self.intercept_sphere_true = self.sphere.random_point()
        # NOTE(review): the coefficient is projected onto the sphere, not
        # onto the tangent space at the intercept, and is fed to ``exp`` as
        # is. Confirm this is intended; it may explain the loose (0.6)
        # tolerances on the transported sphere coefficient below.
        self.coef_sphere_true = self.sphere.projection(
            gs.random.rand(self.dim_sphere + 1))

        self.y_sphere = self.sphere.metric.exp(
            self.X_sphere[:, None] * self.coef_sphere_true,
            base_point=self.intercept_sphere_true,
        )

        self.param_sphere_true = gs.vstack(
            [self.intercept_sphere_true, self.coef_sphere_true])
        self.param_sphere_guess = gs.vstack([
            self.y_sphere[0],
            self.sphere.to_tangent(gs.random.normal(size=self.shape_sphere),
                                   self.y_sphere[0]),
        ])

    def _setup_se2(self):
        """Noiseless data on SE(2) (3x3 matrices), left canonical metric."""
        self.se2 = SpecialEuclidean(n=2)
        self.metric_se2 = self.se2.left_canonical_metric
        self.metric_se2.default_point_type = "matrix"

        self.shape_se2 = (3, 3)
        self.X_se2 = self._centered_regressors()

        self.intercept_se2_true = self.se2.random_point()
        self.coef_se2_true = self.se2.to_tangent(
            5.0 * gs.random.rand(*self.shape_se2), self.intercept_se2_true)

        self.y_se2 = self.metric_se2.exp(
            self.X_se2[:, None, None] * self.coef_se2_true[None],
            self.intercept_se2_true,
        )

        # Parameters are stored flattened: one row per intercept/coef.
        self.param_se2_true = gs.vstack([
            gs.flatten(self.intercept_se2_true),
            gs.flatten(self.coef_se2_true),
        ])
        self.param_se2_guess = gs.vstack([
            gs.flatten(self.y_se2[0]),
            gs.flatten(
                self.se2.to_tangent(gs.random.normal(size=self.shape_se2),
                                    self.y_se2[0])),
        ])

    def _setup_curves_2d(self):
        """Noiseless data on planar discrete curves with the SRV metric."""
        n_sampling_points = 8
        self.curves_2d = DiscreteCurves(R2)
        self.metric_curves_2d = self.curves_2d.srv_metric
        self.metric_curves_2d.default_point_type = "matrix"

        self.shape_curves_2d = (n_sampling_points, 2)
        self.X_curves_2d = self._centered_regressors()

        self.intercept_curves_2d_true = self.curves_2d.random_point(
            n_sampling_points=n_sampling_points)
        self.coef_curves_2d_true = self.curves_2d.to_tangent(
            5.0 * gs.random.rand(*self.shape_curves_2d),
            self.intercept_curves_2d_true)

        # Added because of GitHub issue #1575: the base point is repeated
        # once per sample before calling exp.
        intercept_curves_2d_true_repeated = gs.tile(
            gs.expand_dims(self.intercept_curves_2d_true, axis=0),
            (self.n_samples, 1, 1),
        )
        self.y_curves_2d = self.metric_curves_2d.exp(
            self.X_curves_2d[:, None, None] * self.coef_curves_2d_true[None],
            intercept_curves_2d_true_repeated,
        )

        self.param_curves_2d_true = gs.vstack([
            gs.flatten(self.intercept_curves_2d_true),
            gs.flatten(self.coef_curves_2d_true),
        ])
        self.param_curves_2d_guess = gs.vstack([
            gs.flatten(self.y_curves_2d[0]),
            gs.flatten(
                self.curves_2d.to_tangent(
                    gs.random.normal(size=self.shape_curves_2d),
                    self.y_curves_2d[0])),
        ])

    # ------------------------------------------------------------------
    # Shared helpers (underscore-prefixed so pytest does not collect them)
    # ------------------------------------------------------------------

    @staticmethod
    def _make_regression(space, metric, method, **kwargs):
        """Build a ``GeodesicRegression`` with the settings most tests share.

        Fixed settings: ``center_X=False``, ``max_iter=50``,
        ``init_step_size=0.1``, ``verbose=True``. ``kwargs`` carries the
        test-specific options (``regularization``, ``initialization``).
        """
        return GeodesicRegression(
            space,
            metric=metric,
            center_X=False,
            method=method,
            max_iter=50,
            init_step_size=0.1,
            verbose=True,
            **kwargs,
        )

    def _assert_zero_loss_at_true_param(self, space, metric, X, y, param,
                                        shape):
        """Assert the extrinsic loss is a scalar equal to 0 at ``param``."""
        gr = self._make_regression(space, metric, "extrinsic")
        loss = gr._loss(X, y, param, shape)
        self.assertAllClose(loss.shape, ())
        self.assertTrue(gs.isclose(loss, 0.0))

    def _assert_nontrivial_value_and_grad(self, loss_value, loss_grad,
                                          expected_grad_shape):
        """Assert a scalar, non-zero, non-NaN loss with a usable gradient.

        The gradient must have ``expected_grad_shape``, contain no NaN, and
        not be identically zero.
        """
        self.assertAllClose(loss_value.shape, ())
        self.assertAllClose(loss_grad.shape, expected_grad_shape)

        self.assertFalse(gs.isclose(loss_value, 0.0))
        self.assertFalse(gs.isnan(loss_value))
        self.assertFalse(
            gs.all(gs.isclose(loss_grad, gs.zeros(expected_grad_shape))))
        self.assertTrue(gs.all(~gs.isnan(loss_grad)))

    def _check_value_and_grad_with_and_without_numpy(self, loss_of_param,
                                                     param_guess,
                                                     expected_grad_shape):
        """Run ``value_and_grad`` natively and with ``to_numpy=True``."""
        # Without numpy conversion
        objective_with_grad = gs.autodiff.value_and_grad(loss_of_param)
        loss_value, loss_grad = objective_with_grad(param_guess)
        self._assert_nontrivial_value_and_grad(loss_value, loss_grad,
                                               expected_grad_shape)

        # With numpy conversion; converted back to backend arrays/tensors
        # before the checks.
        objective_with_grad = gs.autodiff.value_and_grad(loss_of_param,
                                                         to_numpy=True)
        loss_value, loss_grad = objective_with_grad(param_guess)
        self._assert_nontrivial_value_and_grad(gs.array(loss_value),
                                               gs.array(loss_grad),
                                               expected_grad_shape)

    @staticmethod
    def _minimize_loss(loss_of_param, initial_guess, tol=None):
        """Minimize ``loss_of_param`` with scipy's CG and autodiff gradients.

        ``tol=None`` is scipy's own default, so omitting it is the same as
        calling ``minimize`` without a tolerance.
        """
        objective_with_grad = gs.autodiff.value_and_grad(loss_of_param,
                                                         to_numpy=True)
        return minimize(
            objective_with_grad,
            initial_guess,
            method="CG",
            jac=True,
            tol=tol,
            options={
                "disp": True,
                "maxiter": 50
            },
        )

    def _fit_and_check_shapes(self, gr, X, y, shape):
        """Fit ``gr`` and check both estimated parameters have ``shape``.

        Returns ``(intercept_hat, coef_hat, training_score)``.
        """
        gr.fit(X, y, compute_training_score=True)
        intercept_hat, coef_hat = gr.intercept_, gr.coef_
        self.assertAllClose(intercept_hat.shape, shape)
        self.assertAllClose(coef_hat.shape, shape)
        return intercept_hat, coef_hat, gr.training_score_

    def _assert_transported_coef_close(self, metric, intercept_hat, coef_hat,
                                       intercept_true, coef_true, **tol):
        """Compare ``coef_hat`` with ``coef_true`` after parallel transport.

        ``coef_hat`` lives at ``intercept_hat``; it is transported along the
        geodesic towards ``intercept_true`` before comparison. ``tol`` is
        forwarded to ``assertAllClose`` (e.g. ``atol=...``).
        """
        tangent_vec_of_transport = metric.log(intercept_true,
                                              base_point=intercept_hat)

        transported_coef_hat = metric.parallel_transport(
            tangent_vec=coef_hat,
            base_point=intercept_hat,
            direction=tangent_vec_of_transport,
        )

        self.assertAllClose(transported_coef_hat, coef_true, **tol)

    # ------------------------------------------------------------------
    # Loss at the true parameters
    # ------------------------------------------------------------------

    def test_loss_euclidean(self):
        """Test that the loss is 0 at the true parameters."""
        self._assert_zero_loss_at_true_param(
            self.eucl, self.eucl.metric, self.X_eucl, self.y_eucl,
            self.param_eucl_true, self.shape_eucl)

    def test_loss_hypersphere(self):
        """Test that the loss is 0 at the true parameters."""
        self._assert_zero_loss_at_true_param(
            self.sphere, self.sphere.metric, self.X_sphere, self.y_sphere,
            self.param_sphere_true, self.shape_sphere)

    @geomstats.tests.autograd_and_tf_only
    def test_loss_se2(self):
        """Test that the loss is 0 at the true parameters."""
        self._assert_zero_loss_at_true_param(
            self.se2, self.metric_se2, self.X_se2, self.y_se2,
            self.param_se2_true, self.shape_se2)

    @geomstats.tests.autograd_only
    def test_loss_curves_2d(self):
        """Test that the loss is 0 at the true parameters."""
        self._assert_zero_loss_at_true_param(
            self.curves_2d, self.metric_curves_2d, self.X_curves_2d,
            self.y_curves_2d, self.param_curves_2d_true,
            self.shape_curves_2d)

    # ------------------------------------------------------------------
    # Loss value and gradient
    # ------------------------------------------------------------------

    @geomstats.tests.autograd_tf_and_torch_only
    def test_value_and_grad_loss_euclidean(self):
        gr = self._make_regression(self.eucl, self.eucl.metric, "extrinsic",
                                   regularization=0)

        def loss_of_param(param):
            return gr._loss(self.X_eucl, self.y_eucl, param, self.shape_eucl)

        self._check_value_and_grad_with_and_without_numpy(
            loss_of_param, self.param_eucl_guess, (2, self.dim_eucl))

    @geomstats.tests.autograd_tf_and_torch_only
    def test_value_and_grad_loss_hypersphere(self):
        gr = self._make_regression(self.sphere, self.sphere.metric,
                                   "extrinsic", regularization=0)

        def loss_of_param(param):
            return gr._loss(self.X_sphere, self.y_sphere, param,
                            self.shape_sphere)

        self._check_value_and_grad_with_and_without_numpy(
            loss_of_param, self.param_sphere_guess, (2, self.dim_sphere + 1))

    @geomstats.tests.autograd_and_tf_only
    def test_value_and_grad_loss_se2(self):
        gr = self._make_regression(self.se2, self.metric_se2, "extrinsic")

        def loss_of_param(param):
            return gr._loss(self.X_se2, self.y_se2, param, self.shape_se2)

        # Parameters are flattened, hence one row of 9 entries per param.
        expected_grad_shape = (
            2,
            self.shape_se2[0] * self.shape_se2[1],
        )

        objective_with_grad = gs.autodiff.value_and_grad(loss_of_param)
        loss_value, _ = objective_with_grad(self.param_se2_true)
        self.assertTrue(gs.isclose(loss_value, 0.0))

        loss_value, loss_grad = objective_with_grad(self.param_se2_guess)
        self._assert_nontrivial_value_and_grad(loss_value, loss_grad,
                                               expected_grad_shape)

        # With numpy conversion (outputs checked as returned).
        objective_with_grad = gs.autodiff.value_and_grad(loss_of_param,
                                                         to_numpy=True)
        loss_value, loss_grad = objective_with_grad(self.param_se2_guess)
        self._assert_nontrivial_value_and_grad(loss_value, loss_grad,
                                               expected_grad_shape)

    # ------------------------------------------------------------------
    # Direct minimization of the extrinsic loss with scipy
    # ------------------------------------------------------------------

    @geomstats.tests.autograd_tf_and_torch_only
    def test_loss_minimization_extrinsic_euclidean(self):
        """Minimize loss from noiseless data."""
        gr = GeodesicRegression(self.eucl, regularization=0)

        def loss_of_param(param):
            return gr._loss(self.X_eucl, self.y_eucl, param, self.shape_eucl)

        res = self._minimize_loss(loss_of_param,
                                  gs.flatten(self.param_eucl_guess),
                                  tol=10 * gs.atol)
        self.assertAllClose(gs.array(res.x).shape, (self.dim_eucl * 2, ))
        self.assertAllClose(res.fun, 0.0, atol=1000 * gs.atol)

        # Cast required because minimization happens in scipy in float64
        param_hat = gs.cast(gs.array(res.x), self.param_eucl_true.dtype)

        intercept_hat, coef_hat = gs.split(param_hat, 2)
        coef_hat = self.eucl.to_tangent(coef_hat, intercept_hat)
        self.assertAllClose(intercept_hat, self.intercept_eucl_true)

        self._assert_transported_coef_close(
            self.eucl.metric, intercept_hat, coef_hat,
            self.intercept_eucl_true, self.coef_eucl_true, atol=10 * gs.atol)

    @geomstats.tests.autograd_tf_and_torch_only
    def test_loss_minimization_extrinsic_hypersphere(self):
        """Minimize loss from noiseless data."""
        gr = GeodesicRegression(self.sphere, regularization=0)

        def loss_of_param(param):
            return gr._loss(self.X_sphere, self.y_sphere, param,
                            self.shape_sphere)

        res = self._minimize_loss(loss_of_param,
                                  gs.flatten(self.param_sphere_guess),
                                  tol=10 * gs.atol)
        self.assertAllClose(
            gs.array(res.x).shape, ((self.dim_sphere + 1) * 2, ))
        self.assertAllClose(res.fun, 0.0, atol=5e-3)

        # Cast required because minimization happens in scipy in float64
        param_hat = gs.cast(gs.array(res.x), self.param_sphere_true.dtype)

        intercept_hat, coef_hat = gs.split(param_hat, 2)
        # The optimizer works in the ambient space; map back to the sphere
        # and its tangent space before comparing.
        intercept_hat = self.sphere.projection(intercept_hat)
        coef_hat = self.sphere.to_tangent(coef_hat, intercept_hat)
        self.assertAllClose(intercept_hat,
                            self.intercept_sphere_true,
                            atol=5e-2)

        self._assert_transported_coef_close(
            self.sphere.metric, intercept_hat, coef_hat,
            self.intercept_sphere_true, self.coef_sphere_true, atol=0.6)

    @geomstats.tests.autograd_and_tf_only
    def test_loss_minimization_extrinsic_se2(self):
        gr = self._make_regression(self.se2, self.metric_se2, "extrinsic")

        def loss_of_param(param):
            return gr._loss(self.X_se2, self.y_se2, param, self.shape_se2)

        res = self._minimize_loss(loss_of_param,
                                  gs.flatten(self.param_se2_guess))
        self.assertAllClose(gs.array(res.x).shape, (18, ))

        self.assertAllClose(res.fun, 0.0, atol=1e-6)

        # Cast required because minimization happens in scipy in float64
        param_hat = gs.cast(gs.array(res.x), self.param_se2_true.dtype)

        intercept_hat, coef_hat = gs.split(param_hat, 2)
        intercept_hat = gs.reshape(intercept_hat, self.shape_se2)
        coef_hat = gs.reshape(coef_hat, self.shape_se2)

        intercept_hat = self.se2.projection(intercept_hat)
        coef_hat = self.se2.to_tangent(coef_hat, intercept_hat)
        self.assertAllClose(intercept_hat, self.intercept_se2_true, atol=1e-4)

        self._assert_transported_coef_close(
            self.se2.metric, intercept_hat, coef_hat,
            self.intercept_se2_true, self.coef_se2_true, atol=0.6)

    # ------------------------------------------------------------------
    # fit(), extrinsic method
    # ------------------------------------------------------------------

    @geomstats.tests.autograd_tf_and_torch_only
    def test_fit_extrinsic_euclidean(self):
        gr = self._make_regression(self.eucl, self.eucl.metric, "extrinsic",
                                   initialization="random",
                                   regularization=0.9)

        intercept_hat, coef_hat, training_score = self._fit_and_check_shapes(
            gr, self.X_eucl, self.y_eucl, self.shape_eucl)
        self.assertAllClose(training_score, 1.0, atol=500 * gs.atol)
        self.assertAllClose(intercept_hat, self.intercept_eucl_true)

        self._assert_transported_coef_close(
            self.eucl.metric, intercept_hat, coef_hat,
            self.intercept_eucl_true, self.coef_eucl_true)

    @geomstats.tests.autograd_tf_and_torch_only
    def test_fit_extrinsic_hypersphere(self):
        gr = self._make_regression(self.sphere, self.sphere.metric,
                                   "extrinsic", initialization="random",
                                   regularization=0.9)

        intercept_hat, coef_hat, training_score = self._fit_and_check_shapes(
            gr, self.X_sphere, self.y_sphere, self.shape_sphere)
        self.assertAllClose(training_score, 1.0, atol=500 * gs.atol)
        self.assertAllClose(intercept_hat,
                            self.intercept_sphere_true,
                            atol=5e-3)

        self._assert_transported_coef_close(
            self.sphere.metric, intercept_hat, coef_hat,
            self.intercept_sphere_true, self.coef_sphere_true, atol=0.6)

    @geomstats.tests.autograd_and_tf_only
    def test_fit_extrinsic_se2(self):
        gr = self._make_regression(self.se2, self.metric_se2, "extrinsic",
                                   initialization="warm_start")

        intercept_hat, coef_hat, training_score = self._fit_and_check_shapes(
            gr, self.X_se2, self.y_se2, self.shape_se2)
        self.assertTrue(gs.isclose(training_score, 1.0))
        self.assertAllClose(intercept_hat, self.intercept_se2_true, atol=1e-4)

        self._assert_transported_coef_close(
            self.se2.metric, intercept_hat, coef_hat,
            self.intercept_se2_true, self.coef_se2_true, atol=0.6)

    # ------------------------------------------------------------------
    # fit(), riemannian method
    # ------------------------------------------------------------------

    @geomstats.tests.autograd_tf_and_torch_only
    def test_fit_riemannian_euclidean(self):
        gr = self._make_regression(self.eucl, self.eucl.metric, "riemannian")

        intercept_hat, coef_hat, training_score = self._fit_and_check_shapes(
            gr, self.X_eucl, self.y_eucl, self.shape_eucl)

        self.assertAllClose(training_score, 1.0, atol=0.1)
        self.assertAllClose(intercept_hat, self.intercept_eucl_true)

        self._assert_transported_coef_close(
            self.eucl.metric, intercept_hat, coef_hat,
            self.intercept_eucl_true, self.coef_eucl_true, atol=1e-2)

    @geomstats.tests.autograd_tf_and_torch_only
    def test_fit_riemannian_hypersphere(self):
        gr = self._make_regression(self.sphere, self.sphere.metric,
                                   "riemannian")

        intercept_hat, coef_hat, training_score = self._fit_and_check_shapes(
            gr, self.X_sphere, self.y_sphere, self.shape_sphere)

        self.assertAllClose(training_score, 1.0, atol=0.1)
        self.assertAllClose(intercept_hat,
                            self.intercept_sphere_true,
                            atol=1e-2)

        self._assert_transported_coef_close(
            self.sphere.metric, intercept_hat, coef_hat,
            self.intercept_sphere_true, self.coef_sphere_true, atol=0.6)

    @geomstats.tests.autograd_and_tf_only
    def test_fit_riemannian_se2(self):
        # Explicit initialization: first observation, zero coefficient.
        init = (self.y_se2[0], gs.zeros_like(self.y_se2[0]))
        gr = self._make_regression(self.se2, self.metric_se2, "riemannian",
                                   initialization=init)

        intercept_hat, coef_hat, training_score = self._fit_and_check_shapes(
            gr, self.X_se2, self.y_se2, self.shape_se2)
        self.assertAllClose(training_score, 1.0, atol=1e-4)
        self.assertAllClose(intercept_hat, self.intercept_se2_true, atol=1e-4)

        self._assert_transported_coef_close(
            self.se2.metric, intercept_hat, coef_hat,
            self.intercept_se2_true, self.coef_se2_true, atol=0.6)
Exemplo n.º 22
0
    def setup_method(self):
        """Build noiseless geodesic-regression fixtures on four spaces.

        Spaces: R^3, the 4-sphere, SE(2) and planar discrete curves. For
        each one this stores the true intercept/coefficient, centered
        regressors ``X``, the responses ``y`` built from them, the stacked
        true parameters, and a randomized initial guess.
        """
        # Seeded for reproducibility. Every block below draws from the same
        # random stream, so statement order determines the generated data.
        gs.random.seed(1234)
        self.n_samples = 20

        # Set up for euclidean
        self.dim_eucl = 3
        self.shape_eucl = (self.dim_eucl, )
        self.eucl = Euclidean(dim=self.dim_eucl)
        X = gs.random.rand(self.n_samples)
        # Centered regressors (zero mean).
        self.X_eucl = X - gs.mean(X)
        self.intercept_eucl_true = self.eucl.random_point()
        self.coef_eucl_true = self.eucl.random_point()

        # Straight-line responses: y = intercept + X * coef.
        self.y_eucl = (self.intercept_eucl_true +
                       self.X_eucl[:, None] * self.coef_eucl_true)
        self.param_eucl_true = gs.vstack(
            [self.intercept_eucl_true, self.coef_eucl_true])
        # Guess: first observation as intercept, a randomly perturbed copy
        # of it as coefficient.
        self.param_eucl_guess = gs.vstack([
            self.y_eucl[0],
            self.y_eucl[0] + gs.random.normal(size=self.shape_eucl)
        ])

        # Set up for hypersphere
        self.dim_sphere = 4
        self.shape_sphere = (self.dim_sphere + 1, )
        self.sphere = Hypersphere(dim=self.dim_sphere)
        X = gs.random.rand(self.n_samples)
        self.X_sphere = X - gs.mean(X)
        self.intercept_sphere_true = self.sphere.random_point()
        # NOTE(review): the coefficient is projected onto the sphere rather
        # than onto the tangent space at the intercept — confirm intended.
        self.coef_sphere_true = self.sphere.projection(
            gs.random.rand(self.dim_sphere + 1))

        # Responses: exp of X * coef at the true intercept.
        self.y_sphere = self.sphere.metric.exp(
            self.X_sphere[:, None] * self.coef_sphere_true,
            base_point=self.intercept_sphere_true,
        )

        self.param_sphere_true = gs.vstack(
            [self.intercept_sphere_true, self.coef_sphere_true])
        self.param_sphere_guess = gs.vstack([
            self.y_sphere[0],
            self.sphere.to_tangent(gs.random.normal(size=self.shape_sphere),
                                   self.y_sphere[0]),
        ])

        # Set up for special euclidean
        self.se2 = SpecialEuclidean(n=2)
        self.metric_se2 = self.se2.left_canonical_metric
        # Mutates the metric object obtained from self.se2 so that it
        # treats points as 3x3 matrices.
        self.metric_se2.default_point_type = "matrix"

        self.shape_se2 = (3, 3)
        X = gs.random.rand(self.n_samples)
        self.X_se2 = X - gs.mean(X)

        self.intercept_se2_true = self.se2.random_point()
        self.coef_se2_true = self.se2.to_tangent(
            5.0 * gs.random.rand(*self.shape_se2), self.intercept_se2_true)

        self.y_se2 = self.metric_se2.exp(
            self.X_se2[:, None, None] * self.coef_se2_true[None],
            self.intercept_se2_true,
        )

        # Matrix parameters are stored flattened, one row per parameter.
        self.param_se2_true = gs.vstack([
            gs.flatten(self.intercept_se2_true),
            gs.flatten(self.coef_se2_true),
        ])
        self.param_se2_guess = gs.vstack([
            gs.flatten(self.y_se2[0]),
            gs.flatten(
                self.se2.to_tangent(gs.random.normal(size=self.shape_se2),
                                    self.y_se2[0])),
        ])

        # Set up for discrete curves
        n_sampling_points = 8
        self.curves_2d = DiscreteCurves(R2)
        self.metric_curves_2d = self.curves_2d.srv_metric
        self.metric_curves_2d.default_point_type = "matrix"

        # Each curve is n_sampling_points points in R^2.
        self.shape_curves_2d = (n_sampling_points, 2)
        X = gs.random.rand(self.n_samples)
        self.X_curves_2d = X - gs.mean(X)

        self.intercept_curves_2d_true = self.curves_2d.random_point(
            n_sampling_points=n_sampling_points)
        self.coef_curves_2d_true = self.curves_2d.to_tangent(
            5.0 * gs.random.rand(*self.shape_curves_2d),
            self.intercept_curves_2d_true)

        # Added because of GitHub issue #1575
        # The base curve is repeated once per sample before calling exp.
        intercept_curves_2d_true_repeated = gs.tile(
            gs.expand_dims(self.intercept_curves_2d_true, axis=0),
            (self.n_samples, 1, 1),
        )
        self.y_curves_2d = self.metric_curves_2d.exp(
            self.X_curves_2d[:, None, None] * self.coef_curves_2d_true[None],
            intercept_curves_2d_true_repeated,
        )

        self.param_curves_2d_true = gs.vstack([
            gs.flatten(self.intercept_curves_2d_true),
            gs.flatten(self.coef_curves_2d_true),
        ])
        self.param_curves_2d_guess = gs.vstack([
            gs.flatten(self.y_curves_2d[0]),
            gs.flatten(
                self.curves_2d.to_tangent(
                    gs.random.normal(size=self.shape_curves_2d),
                    self.y_curves_2d[0])),
        ])
Exemplo n.º 23
0
class TestFrechetMean(geomstats.tests.TestCase):
    """Tests for ``FrechetMean`` and ``variance`` across several metrics.

    Exercised spaces: hypersphere, hyperboloid, Euclidean, Minkowski,
    SO(3) in rotation-vector and matrix form, SPD matrices, Stiefel
    manifolds, and discrete curves (``DiscreteCurves(R2)``) under the SRV
    and elastic metrics.
    """

    # nose multiprocess-plugin flag allowing the tests of this class to be
    # distributed across worker processes.
    _multiprocess_can_split_ = True

    def setup_method(self):
        """Seed the RNG and build the spaces/metrics shared by the tests."""
        gs.random.seed(123)
        self.sphere = Hypersphere(dim=4)
        self.hyperbolic = Hyperboloid(dim=3)
        self.euclidean = Euclidean(dim=2)
        self.minkowski = Minkowski(dim=2)
        self.so3 = SpecialOrthogonal(n=3, point_type="vector")
        self.so_matrix = SpecialOrthogonal(n=3)
        self.curves_2d = DiscreteCurves(R2)
        self.elastic_metric = ElasticMetric(a=1, b=1, ambient_manifold=R2)

    def test_logs_at_mean_curves_2d(self):
        """Check SRV images of the logs at the mean of two curves cancel.

        For two random curves, the logs of each curve at their Frechet mean
        are mapped through ``aux_differential_srv_transform``; those images
        (not the raw logs) are expected to be opposite.
        """
        n_tests = 10
        metric = self.curves_2d.srv_metric
        estimator = FrechetMean(metric=metric, init_step_size=1.0)

        result = []
        for _ in range(n_tests):
            # take 2 random points, compute their mean, and verify that
            # log of each at the mean is opposite
            points = self.curves_2d.random_point(n_samples=2)
            estimator.fit(points)
            mean = estimator.estimate_

            # Expand and tile are added because of GitHub issue 1575
            mean = gs.expand_dims(mean, axis=0)
            mean = gs.tile(mean, (2, 1, 1))

            logs = metric.log(point=points, base_point=mean)
            logs_srv = metric.aux_differential_srv_transform(logs, curve=mean)
            # Note that the logs are NOT inverse, only the logs_srv are.
            result.append(gs.linalg.norm(logs_srv[1, :] + logs_srv[0, :]))
        result = gs.stack(result)
        expected = gs.zeros(n_tests)
        self.assertAllClose(expected, result, atol=1e-5)

    def test_logs_at_mean_default_gradient_descent_sphere(self):
        """Check logs of two sphere points at their mean are opposite.

        The mean is computed with the default gradient descent.
        """
        n_tests = 10
        estimator = FrechetMean(metric=self.sphere.metric,
                                method="default",
                                init_step_size=1.0)

        result = []
        for _ in range(n_tests):
            # take 2 random points, compute their mean, and verify that
            # log of each at the mean is opposite
            points = self.sphere.random_uniform(n_samples=2)
            estimator.fit(points)
            mean = estimator.estimate_

            logs = self.sphere.metric.log(point=points, base_point=mean)
            result.append(gs.linalg.norm(logs[1, :] + logs[0, :]))
        result = gs.stack(result)
        expected = gs.zeros(n_tests)
        self.assertAllClose(expected, result)

    def test_logs_at_mean_adaptive_gradient_descent_sphere(self):
        """Check logs of two sphere points at their mean are opposite.

        The mean is computed with the adaptive gradient descent.
        """
        n_tests = 10
        estimator = FrechetMean(metric=self.sphere.metric, method="adaptive")

        result = []
        for _ in range(n_tests):
            # take 2 random points, compute their mean, and verify that
            # log of each at the mean is opposite
            points = self.sphere.random_uniform(n_samples=2)
            estimator.fit(points)
            mean = estimator.estimate_

            logs = self.sphere.metric.log(point=points, base_point=mean)
            result.append(gs.linalg.norm(logs[1, :] + logs[0, :]))
        result = gs.stack(result)

        expected = gs.zeros(n_tests)
        self.assertAllClose(expected, result)

    def test_estimate_shape_default_gradient_descent_sphere(self):
        """Check the default-descent sphere mean has shape ``(5,)``.

        ``dim`` is the ambient dimension: the 4-sphere lives in R^5.
        """
        dim = 5
        point_a = gs.array([1.0, 0.0, 0.0, 0.0, 0.0])
        point_b = gs.array([0.0, 1.0, 0.0, 0.0, 0.0])
        points = gs.array([point_a, point_b])

        mean = FrechetMean(metric=self.sphere.metric,
                           method="default",
                           verbose=True)
        mean.fit(points)
        result = mean.estimate_

        self.assertAllClose(gs.shape(result), (dim, ))

    def test_estimate_shape_adaptive_gradient_descent_sphere(self):
        """Check the adaptive-descent sphere mean has shape ``(5,)``."""
        dim = 5
        point_a = gs.array([1.0, 0.0, 0.0, 0.0, 0.0])
        point_b = gs.array([0.0, 1.0, 0.0, 0.0, 0.0])
        points = gs.array([point_a, point_b])

        mean = FrechetMean(metric=self.sphere.metric, method="adaptive")
        mean.fit(points)
        result = mean.estimate_

        self.assertAllClose(gs.shape(result), (dim, ))

    def test_estimate_shape_elastic_metric(self):
        """Check the elastic-metric mean has the shape of a single curve."""
        points = self.curves_2d.random_point(n_samples=2)

        mean = FrechetMean(metric=self.elastic_metric)
        mean.fit(points)
        result = mean.estimate_

        self.assertAllClose(gs.shape(result), points.shape[1:])

    def test_estimate_shape_metric(self):
        """Check the SRV-metric mean has the shape of a single curve."""
        points = self.curves_2d.random_point(n_samples=2)

        mean = FrechetMean(metric=self.curves_2d.srv_metric)
        mean.fit(points)
        result = mean.estimate_

        self.assertAllClose(gs.shape(result), points.shape[1:])

    def test_estimate_and_belongs_default_gradient_descent_sphere(self):
        """Check the default-descent mean of sphere points is on the sphere."""
        point_a = gs.array([1.0, 0.0, 0.0, 0.0, 0.0])
        point_b = gs.array([0.0, 1.0, 0.0, 0.0, 0.0])
        points = gs.array([point_a, point_b])

        mean = FrechetMean(metric=self.sphere.metric, method="default")
        mean.fit(points)

        result = self.sphere.belongs(mean.estimate_)
        expected = True
        self.assertAllClose(result, expected)

    def test_estimate_and_belongs_curves_2d(self):
        """Check the SRV mean of two curves belongs to the curve space."""
        points = self.curves_2d.random_point(n_samples=2)

        mean = FrechetMean(metric=self.curves_2d.srv_metric)
        mean.fit(points)

        result = self.curves_2d.belongs(mean.estimate_)
        expected = True
        self.assertAllClose(result, expected)

    def test_estimate_default_gradient_descent_so3(self):
        """Check the logs of SO(3) points at their mean sum to zero.

        This is the first-order optimality condition of the Frechet mean,
        here in the rotation-vector representation.
        """
        points = self.so3.random_uniform(2)

        mean_vec = FrechetMean(metric=self.so3.bi_invariant_metric,
                               method="default",
                               init_step_size=1.0)
        mean_vec.fit(points)

        logs = self.so3.bi_invariant_metric.log(points, mean_vec.estimate_)
        result = gs.sum(logs, axis=0)
        expected = gs.zeros_like(points[0])
        self.assertAllClose(result, expected)

    def test_estimate_and_belongs_default_gradient_descent_so3(self):
        """Check the mean of 10 rotation vectors belongs to SO(3)."""
        point = self.so3.random_uniform(10)

        mean_vec = FrechetMean(metric=self.so3.bi_invariant_metric,
                               method="default")
        mean_vec.fit(point)

        result = self.so3.belongs(mean_vec.estimate_)
        expected = True
        self.assertAllClose(result, expected)

    @geomstats.tests.np_autograd_and_tf_only
    def test_estimate_default_gradient_descent_so_matrix(self):
        """Check the logs of SO(3) matrices at their mean sum to zero."""
        points = self.so_matrix.random_uniform(2)
        mean_vec = FrechetMean(
            metric=self.so_matrix.bi_invariant_metric,
            method="default",
            init_step_size=1.0,
        )
        mean_vec.fit(points)
        logs = self.so_matrix.bi_invariant_metric.log(points,
                                                      mean_vec.estimate_)
        result = gs.sum(logs, axis=0)
        expected = gs.zeros_like(points[0])

        self.assertAllClose(result, expected, atol=1e-5)

    @geomstats.tests.np_autograd_and_tf_only
    def test_estimate_and_belongs_default_gradient_descent_so_matrix(self):
        """Check the default-descent mean of SO(3) matrices is in SO(3)."""
        point = self.so_matrix.random_uniform(10)

        mean = FrechetMean(metric=self.so_matrix.bi_invariant_metric,
                           method="default")
        mean.fit(point)

        result = self.so_matrix.belongs(mean.estimate_)
        expected = True
        self.assertAllClose(result, expected)

    @geomstats.tests.np_autograd_and_tf_only
    def test_estimate_and_belongs_adaptive_gradient_descent_so_matrix(self):
        """Check the adaptive-descent mean of SO(3) matrices is in SO(3)."""
        point = self.so_matrix.random_uniform(10)

        mean = FrechetMean(
            metric=self.so_matrix.bi_invariant_metric,
            method="adaptive",
            init_step_size=0.5,
            verbose=True,
        )
        mean.fit(point)

        result = self.so_matrix.belongs(mean.estimate_)
        self.assertTrue(result)

    @geomstats.tests.np_autograd_and_tf_only
    def test_estimate_and_coincide_default_so_vec_and_mat(self):
        """Check SO(3) means agree in matrix and rotation-vector forms."""
        point = self.so_matrix.random_uniform(3)

        mean = FrechetMean(metric=self.so_matrix.bi_invariant_metric,
                           method="default")
        mean.fit(point)
        expected = mean.estimate_

        mean_vec = FrechetMean(metric=self.so3.bi_invariant_metric,
                               method="default")
        point_vec = self.so3.rotation_vector_from_matrix(point)
        mean_vec.fit(point_vec)
        result_vec = mean_vec.estimate_
        result = self.so3.matrix_from_rotation_vector(result_vec)

        self.assertAllClose(result, expected)

    def test_estimate_and_belongs_adaptive_gradient_descent_sphere(self):
        """Check the adaptive-descent mean of sphere points is on the sphere."""
        point_a = gs.array([1.0, 0.0, 0.0, 0.0, 0.0])
        point_b = gs.array([0.0, 1.0, 0.0, 0.0, 0.0])
        points = gs.array([point_a, point_b])

        mean = FrechetMean(metric=self.sphere.metric, method="adaptive")
        mean.fit(points)

        result = self.sphere.belongs(mean.estimate_)
        expected = True
        self.assertAllClose(result, expected)

    def test_variance_sphere(self):
        """Check the variance of identical points about themselves is 0."""
        point = gs.array([0.0, 0.0, 0.0, 0.0, 1.0])
        points = gs.array([point, point])

        result = variance(points, base_point=point, metric=self.sphere.metric)
        expected = gs.array(0.0)

        self.assertAllClose(expected, result)

    def test_estimate_default_gradient_descent_sphere(self):
        """Check the default-descent mean of two copies of a point is it."""
        point = gs.array([0.0, 0.0, 0.0, 0.0, 1.0])
        points = gs.array([point, point])

        mean = FrechetMean(metric=self.sphere.metric, method="default")
        mean.fit(X=points)

        result = mean.estimate_
        expected = point

        self.assertAllClose(expected, result)

    def test_estimate_elastic_metric(self):
        """Check the elastic mean of two copies of a curve is that curve."""
        point = self.curves_2d.random_point(n_samples=1)
        points = gs.array([point, point])

        mean = FrechetMean(metric=self.elastic_metric)
        mean.fit(X=points)

        result = mean.estimate_
        expected = point

        self.assertAllClose(expected, result)

    def test_estimate_curves_2d(self):
        """Check the SRV mean of two copies of a curve is that curve."""
        point = self.curves_2d.random_point(n_samples=1)
        points = gs.array([point, point])

        mean = FrechetMean(metric=self.curves_2d.srv_metric)
        mean.fit(X=points)

        result = mean.estimate_
        expected = point

        self.assertAllClose(expected, result)

    def test_estimate_adaptive_gradient_descent_sphere(self):
        """Check the adaptive-descent mean of two copies of a point is it."""
        point = gs.array([0.0, 0.0, 0.0, 0.0, 1.0])
        points = gs.array([point, point])

        mean = FrechetMean(metric=self.sphere.metric, method="adaptive")
        mean.fit(X=points)

        result = mean.estimate_
        expected = point

        self.assertAllClose(expected, result)

    def test_estimate_spd(self):
        """Check the affine-invariant mean of two equal SPD matrices."""
        point = SPDMatrices(3).random_point()
        points = gs.array([point, point])
        mean = FrechetMean(metric=SPDMetricAffine(3), point_type="matrix")
        mean.fit(X=points)
        result = mean.estimate_
        expected = point
        self.assertAllClose(expected, result)

    def test_estimate_spd_two_samples(self):
        """Check the SPD mean of two points is their geodesic midpoint."""
        space = SPDMatrices(3)
        metric = SPDMetricAffine(3)
        point = space.random_point(2)
        mean = FrechetMean(metric)
        mean.fit(point)
        result = mean.estimate_
        # Midpoint: half-way along the geodesic from point[1] to point[0].
        expected = metric.exp(metric.log(point[0], point[1]) / 2, point[1])
        self.assertAllClose(expected, result)

    def test_variance_hyperbolic(self):
        """Check the hyperbolic variance of identical points is 0."""
        point = gs.array([2.0, 1.0, 1.0, 1.0])
        points = gs.array([point, point])
        result = variance(points,
                          base_point=point,
                          metric=self.hyperbolic.metric)
        expected = gs.array(0.0)

        self.assertAllClose(result, expected)

    def test_estimate_hyperbolic(self):
        """Check the hyperbolic mean of two copies of a point is it."""
        point = gs.array([2.0, 1.0, 1.0, 1.0])
        points = gs.array([point, point])

        mean = FrechetMean(metric=self.hyperbolic.metric)
        mean.fit(X=points)
        expected = point

        result = mean.estimate_

        self.assertAllClose(result, expected)

    def test_estimate_and_belongs_hyperbolic(self):
        """Check the mean of three hyperboloid points is on the hyperboloid."""
        point_a = self.hyperbolic.random_point()
        point_b = self.hyperbolic.random_point()
        point_c = self.hyperbolic.random_point()
        points = gs.stack([point_a, point_b, point_c], axis=0)

        mean = FrechetMean(metric=self.hyperbolic.metric)
        mean.fit(X=points)

        result = self.hyperbolic.belongs(mean.estimate_)
        expected = True

        self.assertAllClose(result, expected)

    def test_mean_euclidean_shape(self):
        """Check the Euclidean mean of 2-vectors has shape ``(2,)``."""
        dim = 2
        point = gs.array([1.0, 4.0])

        mean = FrechetMean(metric=self.euclidean.metric)
        points = [point, point, point]
        mean.fit(points)

        result = mean.estimate_

        self.assertAllClose(gs.shape(result), (dim, ))

    def test_mean_euclidean(self):
        """Check unweighted and weighted Euclidean means."""
        point = gs.array([1.0, 4.0])

        mean = FrechetMean(metric=self.euclidean.metric)
        points = [point, point, point]
        mean.fit(points)

        result = mean.estimate_
        expected = point

        self.assertAllClose(result, expected)

        points = gs.array([[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0]])
        weights = [1.0, 2.0, 1.0, 2.0]

        mean = FrechetMean(metric=self.euclidean.metric)
        mean.fit(points, weights=weights)

        result = mean.estimate_
        # Weighted average: (1*p0 + 2*p1 + 1*p2 + 2*p3) / (1 + 2 + 1 + 2).
        expected = gs.array([16.0 / 6.0, 22.0 / 6.0])

        self.assertAllClose(result, expected)

    def test_variance_euclidean(self):
        """Check the weighted Euclidean variance about the origin."""
        points = gs.array([[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0]])
        weights = gs.array([1.0, 2.0, 1.0, 2.0])
        base_point = gs.zeros(2)
        result = variance(points,
                          weights=weights,
                          base_point=base_point,
                          metric=self.euclidean.metric)
        # we expect the average of the points' sq norms.
        expected = gs.array((1 * 5.0 + 2 * 13.0 + 1 * 25.0 + 2 * 41.0) / 6.0)

        self.assertAllClose(result, expected)

    def test_mean_matrices_shape(self):
        """Check the mean of 2x2 matrices has shape ``(2, 2)``."""
        m, n = (2, 2)
        point = gs.array([[1.0, 4.0], [2.0, 3.0]])

        metric = MatricesMetric(m, n)
        mean = FrechetMean(metric=metric, point_type="matrix")
        points = [point, point, point]
        mean.fit(points)

        result = mean.estimate_

        self.assertAllClose(gs.shape(result), (m, n))

    def test_mean_matrices(self):
        """Check the mean of three copies of a matrix is that matrix."""
        m, n = (2, 2)
        point = gs.array([[1.0, 4.0], [2.0, 3.0]])

        metric = MatricesMetric(m, n)
        mean = FrechetMean(metric=metric, point_type="matrix")
        points = [point, point, point]
        mean.fit(points)

        result = mean.estimate_
        expected = point

        self.assertAllClose(result, expected)

    def test_mean_minkowski_shape(self):
        """Check the Minkowski mean of 2-vectors has shape ``(2,)``."""
        dim = 2
        point = gs.array([2.0, -math.sqrt(3)])
        points = [point, point, point]

        mean = FrechetMean(metric=self.minkowski.metric)
        mean.fit(points)
        result = mean.estimate_

        self.assertAllClose(gs.shape(result), (dim, ))

    def test_mean_minkowski(self):
        """Check unweighted and weighted Minkowski means."""
        point = gs.array([2.0, -math.sqrt(3)])
        points = [point, point, point]

        mean = FrechetMean(metric=self.minkowski.metric)
        mean.fit(points)
        result = mean.estimate_

        expected = point

        self.assertAllClose(result, expected)

        points = gs.array([[1.0, 0.0], [2.0, math.sqrt(3)],
                           [3.0, math.sqrt(8)], [4.0, math.sqrt(24)]])
        weights = gs.array([1.0, 2.0, 1.0, 2.0])

        mean = FrechetMean(metric=self.minkowski.metric)
        mean.fit(points, weights=weights)
        result = self.minkowski.belongs(mean.estimate_)

        self.assertTrue(result)

    def test_variance_minkowski(self):
        """Check the weighted Minkowski variance is non-zero."""
        points = gs.array([[1.0, 0.0], [2.0, math.sqrt(3)],
                           [3.0, math.sqrt(8)], [4.0, math.sqrt(24)]])
        weights = gs.array([1.0, 2.0, 1.0, 2.0])
        base_point = gs.array([-1.0, 0.0])
        var = variance(points,
                       weights=weights,
                       base_point=base_point,
                       metric=self.minkowski.metric)
        # The Minkowski metric is indefinite, so the exact value is not
        # pinned here; only check that the variance does not vanish.
        result = var != 0
        expected = True
        self.assertAllClose(result, expected)

    def test_one_point(self):
        """Check the mean of a single point is that point.

        The point is fitted both as a bare vector and as a batch of one.
        """
        point = gs.array([0.0, 0.0, 0.0, 0.0, 1.0])

        mean = FrechetMean(metric=self.sphere.metric, method="default")
        mean.fit(X=point)

        result = mean.estimate_
        expected = point
        self.assertAllClose(expected, result)

        # A batch containing only this point must give the same estimate.
        mean = FrechetMean(metric=self.sphere.metric, method="default")
        mean.fit(X=gs.array([point]))

        result = mean.estimate_
        expected = point
        self.assertAllClose(expected, result)

    def test_batched(self):
        """Check batch mode matches fitting each batch slice separately.

        ``data[:, j]`` is one independent sample set; the batch estimate
        should stack the per-slice means.
        """
        space = SPDMatrices(3)
        metric = SPDMetricAffine(3)
        point = space.random_point(4)
        mean_batch = FrechetMean(metric, method="batch", verbose=True)
        data = gs.stack([point[:2], point[2:]], axis=1)
        mean_batch.fit(data)
        result = mean_batch.estimate_

        mean = FrechetMean(metric)
        mean.fit(data[:, 0])
        expected_1 = mean.estimate_
        mean.fit(data[:, 1])
        expected_2 = mean.estimate_
        expected = gs.stack([expected_1, expected_2])
        self.assertAllClose(expected, result)

    @geomstats.tests.np_and_autograd_only
    def test_stiefel_two_samples(self):
        """Check the Stiefel mean of two points is their geodesic midpoint."""
        space = Stiefel(3, 2)
        metric = space.metric
        point = space.random_point(2)
        mean = FrechetMean(metric)
        mean.fit(point)
        result = mean.estimate_
        expected = metric.exp(metric.log(point[0], point[1]) / 2, point[1])
        self.assertAllClose(expected, result)

    @geomstats.tests.np_and_autograd_only
    def test_stiefel_n_samples(self):
        """Check the default-descent Stiefel mean belongs to the manifold."""
        space = Stiefel(3, 2)
        metric = space.metric
        point = space.random_point(2)
        mean = FrechetMean(metric,
                           method="default",
                           init_step_size=0.5,
                           verbose=True)
        mean.fit(point)
        result = space.belongs(mean.estimate_)
        self.assertTrue(result)

    def test_circle_mean(self):
        """Check the dedicated circle mean matches the extrinsic estimate."""
        space = Hypersphere(1)
        points = space.random_uniform(10)
        mean_circle = FrechetMean(space.metric)
        mean_circle.fit(points)
        estimate_circle = mean_circle.estimate_

        # set a wrong dimension so that the extrinsic coordinates are used
        # (this mutates ``space.metric`` in place; ``space`` is test-local).
        metric = space.metric
        metric.dim = 2
        mean_extrinsic = FrechetMean(metric)
        mean_extrinsic.fit(points)
        estimate_extrinsic = mean_extrinsic.estimate_
        self.assertAllClose(estimate_circle, estimate_extrinsic)
Exemplo n.º 24
0
class SRVMetricTestData(_RiemannianMetricTestData):
    # Ambient manifolds on which the SRV metric is tested (r2 / r3 are
    # defined elsewhere in this module — presumably Euclidean R^2 and R^3).
    ambient_manifolds_list = [r2, r3]
    # One single-argument constructor tuple per ambient manifold.
    metric_args_list = [(ambient_manifolds, )
                        for ambient_manifolds in ambient_manifolds_list]
    # Point shapes, one per ambient manifold; presumably
    # (n_sampling_points, ambient_dim) — TODO confirm.
    shape_list = [(10, 2), (10, 3)]
    # One DiscreteCurves space per ambient manifold, same order as above.
    space_list = [
        DiscreteCurves(ambient_manifolds)
        for ambient_manifolds in ambient_manifolds_list
    ]
    # Two distinct sizes drawn from {2, 3, 4}. These draws consume the
    # stdlib ``random`` state at class-definition time, in this statement
    # order, so values vary between runs unless ``random`` is seeded
    # beforehand.
    n_points_list = random.sample(range(2, 5), 2)
    n_tangent_vecs_list = random.sample(range(2, 5), 2)
    # Batch sizes for the pairwise tests (dist / squared_dist symmetry etc.).
    n_points_a_list = [1, 2]
    n_points_b_list = [1, 2]
    # Not referenced by the methods visible here; presumably read by the
    # base test-data class — verify.
    batch_size_list = random.sample(range(2, 5), 2)
    # Ladder parallel-transport settings, one entry per space.
    alpha_list = [1] * 2
    n_rungs_list = [1] * 2
    scheme_list = ["pole"] * 2

    def exp_shape_test_data(self):
        """Build test cases checking the output shape of ``exp``."""
        inputs = (self.metric_args_list, self.space_list, self.shape_list)
        return self._exp_shape_test_data(*inputs)

    def log_shape_test_data(self):
        """Build test cases checking the output shape of ``log``."""
        inputs = (self.metric_args_list, self.space_list)
        return self._log_shape_test_data(*inputs)

    def squared_dist_is_symmetric_test_data(self):
        """Build test cases checking that ``squared_dist`` is symmetric."""
        tolerance = gs.atol * 1000
        point_counts = (self.n_points_a_list, self.n_points_b_list)
        return self._squared_dist_is_symmetric_test_data(
            self.metric_args_list, self.space_list, *point_counts,
            atol=tolerance)

    def exp_belongs_test_data(self):
        """Build test cases checking that ``exp`` lands on the space."""
        tolerance = gs.atol * 1000
        inputs = (
            self.metric_args_list,
            self.space_list,
            self.shape_list,
            self.n_tangent_vecs_list,
        )
        return self._exp_belongs_test_data(*inputs, belongs_atol=tolerance)

    def log_is_tangent_test_data(self):
        """Build test cases checking that ``log`` returns tangent vectors."""
        tolerance = gs.atol * 1000
        inputs = (self.metric_args_list, self.space_list, self.n_points_list)
        return self._log_is_tangent_test_data(*inputs,
                                              is_tangent_atol=tolerance)

    def geodesic_ivp_belongs_test_data(self):
        """Build test cases checking initial-value geodesics stay on space."""
        tolerance = gs.atol * 1000
        inputs = (
            self.metric_args_list,
            self.space_list,
            self.shape_list,
            self.n_points_list,
        )
        return self._geodesic_ivp_belongs_test_data(*inputs,
                                                    belongs_atol=tolerance)

    def geodesic_bvp_belongs_test_data(self):
        """Build test cases checking boundary-value geodesics stay on space."""
        tolerance = gs.atol * 1000
        inputs = (self.metric_args_list, self.space_list, self.n_points_list)
        return self._geodesic_bvp_belongs_test_data(*inputs,
                                                    belongs_atol=tolerance)

    def exp_after_log_test_data(self):
        """Build test cases checking ``exp(log(q, p), p)`` recovers ``q``."""
        tolerances = dict(rtol=gs.rtol * 100, atol=gs.atol * 10000)
        inputs = (self.metric_args_list, self.space_list, self.n_points_list)
        return self._exp_after_log_test_data(*inputs, **tolerances)

    def log_after_exp_test_data(self):
        """Build test cases checking ``log(exp(v, p), p)`` recovers ``v``."""
        tolerances = dict(rtol=gs.rtol * 100, atol=gs.atol * 10000)
        inputs = (
            self.metric_args_list,
            self.space_list,
            self.shape_list,
            self.n_tangent_vecs_list,
        )
        return self._log_after_exp_test_data(*inputs, **tolerances)

    def exp_ladder_parallel_transport_test_data(self):
        """Build test cases for ladder parallel transport followed by exp."""
        ladder_settings = (self.n_rungs_list, self.alpha_list,
                           self.scheme_list)
        return self._exp_ladder_parallel_transport_test_data(
            self.metric_args_list, self.space_list, self.shape_list,
            self.n_tangent_vecs_list, *ladder_settings)

    def exp_geodesic_ivp_test_data(self):
        """Build test cases comparing ``exp`` with initial-value geodesics."""
        tolerances = dict(rtol=gs.rtol * 100000, atol=gs.atol * 100000)
        inputs = (
            self.metric_args_list,
            self.space_list,
            self.shape_list,
            self.n_tangent_vecs_list,
            self.n_points_list,
        )
        return self._exp_geodesic_ivp_test_data(*inputs, **tolerances)

    def parallel_transport_ivp_is_isometry_test_data(self):
        """Build test cases checking initial-value transport is isometric."""
        tolerance = gs.atol * 1000
        inputs = (
            self.metric_args_list,
            self.space_list,
            self.shape_list,
            self.n_tangent_vecs_list,
        )
        return self._parallel_transport_ivp_is_isometry_test_data(
            *inputs, is_tangent_atol=tolerance, atol=tolerance)

    def parallel_transport_bvp_is_isometry_test_data(self):
        """Build test cases checking boundary-value transport is isometric."""
        tolerance = gs.atol * 1000
        inputs = (
            self.metric_args_list,
            self.space_list,
            self.shape_list,
            self.n_tangent_vecs_list,
        )
        return self._parallel_transport_bvp_is_isometry_test_data(
            *inputs, is_tangent_atol=tolerance, atol=tolerance)

    def dist_is_symmetric_test_data(self):
        """Build test cases checking that ``dist`` is symmetric."""
        point_counts = (self.n_points_a_list, self.n_points_b_list)
        return self._dist_is_symmetric_test_data(self.metric_args_list,
                                                 self.space_list,
                                                 *point_counts)

    def dist_is_positive_test_data(self):
        """Build test cases checking that ``dist`` is non-negative."""
        point_counts = (self.n_points_a_list, self.n_points_b_list)
        return self._dist_is_positive_test_data(self.metric_args_list,
                                                self.space_list,
                                                *point_counts)

    def squared_dist_is_positive_test_data(self):
        """Build test cases checking that ``squared_dist`` is non-negative."""
        point_counts = (self.n_points_a_list, self.n_points_b_list)
        return self._squared_dist_is_positive_test_data(self.metric_args_list,
                                                        self.space_list,
                                                        *point_counts)

    def dist_is_norm_of_log_test_data(self):
        """Build test cases checking ``dist`` equals the norm of ``log``."""
        point_counts = (self.n_points_a_list, self.n_points_b_list)
        return self._dist_is_norm_of_log_test_data(self.metric_args_list,
                                                   self.space_list,
                                                   *point_counts)

    def dist_point_to_itself_is_zero_test_data(self):
        """Build test cases checking ``dist(p, p)`` vanishes."""
        inputs = (self.metric_args_list, self.space_list, self.n_points_list)
        return self._dist_point_to_itself_is_zero_test_data(*inputs)

    def inner_product_is_symmetric_test_data(self):
        """Build test cases checking the inner product is symmetric."""
        inputs = (self.metric_args_list, self.space_list, self.shape_list,
                  self.n_tangent_vecs_list)
        return self._inner_product_is_symmetric_test_data(*inputs)

    def triangle_inequality_of_dist_test_data(self):
        """Build test cases checking ``dist`` obeys the triangle inequality."""
        inputs = (self.metric_args_list, self.space_list, self.n_points_list)
        return self._triangle_inequality_of_dist_test_data(*inputs)

    def aux_differential_srv_transform_test_data(self):
        """Build the smoke case for ``aux_differential_srv_transform``.

        Uses 2000 sampling points, 2000 curves and the module-level
        ``curve_fun_a``; the case list is passed to ``generate_tests``.
        """
        case = {
            "dim": 3,
            "n_sampling_points": 2000,
            "n_curves": 2000,
            "curve_fun_a": curve_fun_a,
        }
        return self.generate_tests([case])

    def aux_differential_srv_transform_inverse_test_data(self):
        """Build the smoke case for the inverse of the SRV differential."""
        case = {
            "dim": 3,
            "n_sampling_points": n_sampling_points,
            "curve_a": curve_a,
        }
        return self.generate_tests([case])

    def aux_differential_srv_transform_vectorization_test_data(self):
        """Build the smoke case for vectorized SRV-differential calls."""
        case = {
            "dim": 3,
            "n_sampling_points": n_sampling_points,
            "curve_a": curve_a,
            "curve_b": curve_b,
        }
        return self.generate_tests([case])

    def srv_inner_product_elastic_test_data(self):
        """Build the smoke case comparing SRV and elastic inner products."""
        case = {
            "dim": 3,
            "n_sampling_points": n_sampling_points,
            "curve_a": curve_a,
        }
        return self.generate_tests([case])

    def srv_inner_product_and_dist_test_data(self):
        """Build the smoke case relating the SRV inner product and dist."""
        case = {"dim": 3, "curve_a": curve_a, "curve_b": curve_b}
        return self.generate_tests([case])

    def srv_inner_product_vectorization_test_data(self):
        """Build the smoke case for vectorized SRV inner products."""
        case = {
            "dim": 3,
            "n_sampling_points": n_sampling_points,
            "curve_a": curve_a,
            "curve_b": curve_b,
        }
        return self.generate_tests([case])

    def split_horizontal_vertical_test_data(self):
        """Build the smoke case for the horizontal/vertical split."""
        case = {
            "times": times,
            "n_discretized_curves": n_discretized_curves,
            "curve_a": curve_a,
            "curve_b": curve_b,
        }
        return self.generate_tests([case])

    def space_derivative_test_data(self):
        """Build the smoke case for the space derivative of curves."""
        case = {
            "dim": 3,
            "n_points": 3,
            "n_discretized_curves": n_discretized_curves,
            "n_sampling_points": n_sampling_points,
        }
        return self.generate_tests([case])

    def srv_inner_product_test_data(self):
        """Build the smoke case for the SRV inner product."""
        case = {
            "curve_a": curve_a,
            "curve_b": curve_b,
            "curve_c": curve_c,
            "times": times,
        }
        return self.generate_tests([case])

    def srv_norm_test_data(self):
        """Build the smoke case for the SRV norm."""
        case = {"curve_a": curve_a, "curve_b": curve_b, "times": times}
        return self.generate_tests([case])

    def srv_metric_pointwise_inner_products_test_data(self):
        """Build the smoke case for SRV pointwise inner products."""
        case = {
            "times": times,
            "curve_a": curve_a,
            "curve_b": curve_b,
            "curve_c": curve_c,
            "n_discretized_curves": n_discretized_curves,
            "n_sampling_points": n_sampling_points,
        }
        return self.generate_tests([case])

    def srv_transform_and_inverse_test_data(self):
        """Build the smoke case for the SRV transform and its inverse."""
        case = {"times": times, "curve_a": curve_a, "curve_b": curve_b}
        return self.generate_tests([case])