Ejemplo n.º 1
0
    def test_aux_differential_srv_transform_vectorization(
        self, dim, n_sampling_points, curve_a, curve_b
    ):
        """Test differential of square root velocity transform.

        Check vectorization: applying ``aux_differential_srv_transform`` to a
        stacked batch of two (tangent vector, curve) pairs must give the same
        result as applying it to each pair separately and stacking the
        per-pair outputs.
        """
        # Use the ``dim`` argument (as the sibling SRV tests do) rather than
        # silently overriding it with a hard-coded value.
        curves = gs.stack((curve_a, curve_b))
        # One random tangent vector per curve, with one row per sampling point.
        tangent_vecs = gs.random.rand(2, n_sampling_points, dim)
        srv_metric_r3 = SRVMetric(r3)
        result = srv_metric_r3.aux_differential_srv_transform(tangent_vecs, curves)

        res_a = srv_metric_r3.aux_differential_srv_transform(tangent_vecs[0], curve_a)
        res_b = srv_metric_r3.aux_differential_srv_transform(tangent_vecs[1], curve_b)
        expected = gs.stack([res_a, res_b])
        self.assertAllClose(result, expected)
Ejemplo n.º 2
0
 def test_aux_differential_srv_transform_inverse(
     self, dim, n_sampling_points, curve_a
 ):
     """Check the inverse differential of the square root velocity transform.

     A tangent vector is pushed through ``aux_differential_srv_transform``
     at ``curve_a``, then mapped back with
     ``aux_differential_srv_transform_inverse`` at the same curve. The
     round trip must recover the original tangent vector up to tolerance.
     """
     metric = SRVMetric(r3)
     ramp = gs.linspace(0.0, 1.0, n_sampling_points)
     # Shape (n_sampling_points, dim): row i repeats ramp[i] in every coordinate.
     tangent_vec = gs.transpose(gs.tile(ramp, (dim, 1)))
     srv_diff = metric.aux_differential_srv_transform(tangent_vec, curve_a)
     recovered = metric.aux_differential_srv_transform_inverse(srv_diff, curve_a)
     self.assertAllClose(recovered, tangent_vec, atol=1e-3, rtol=1e-3)
Ejemplo n.º 3
0
    def test_aux_differential_srv_transform(self, dim, n_sampling_points,
                                            n_curves, curve_fun_a):
        """Test differential of square root velocity transform.

        Check that its value at (curve, tangent_vec) coincides
        with the derivative at zero of the square root velocity
        transform of a path of curves starting at curve with
        initial derivative tangent_vec.

        The reference derivative is a forward finite difference of the
        SRV transform along the straight path
        ``curve_a + t * tangent_vec`` for ``t`` in ``[0, 1]``.
        """
        srv_metric_r3 = SRVMetric(r3)
        sampling_times = gs.linspace(0.0, 1.0, n_sampling_points)
        curve_a = curve_fun_a(sampling_times)
        # Shape (n_sampling_points, dim): row i repeats the i-th value of a
        # ramp from 1.0 to 2.0 in every coordinate.
        tangent_vec = gs.transpose(
            gs.tile(gs.linspace(1.0, 2.0, n_sampling_points), (dim, 1)))
        result = srv_metric_r3.aux_differential_srv_transform(
            tangent_vec, curve_a)

        times = gs.linspace(0.0, 1.0, n_curves)
        # path_of_curves[i] = curve_a + times[i] * tangent_vec
        path_of_curves = curve_a + gs.einsum("i,jk->ijk", times, tangent_vec)
        srv_path = srv_metric_r3.srv_transform(path_of_curves)
        # linspace(0, 1, n_curves) has step 1 / (n_curves - 1), so the forward
        # difference must be scaled by (n_curves - 1), not n_curves.
        expected = (n_curves - 1) * (srv_path[1] - srv_path[0])
        self.assertAllClose(result, expected, atol=1e-3, rtol=1e-3)