def test_aux_differential_srv_transform_vectorization(
    self, dim, n_sampling_points, curve_a, curve_b
):
    """Test differential of square root velocity transform.

    Check vectorization: applying ``aux_differential_srv_transform`` to a
    batch of two (tangent vector, curve) pairs must give the same result
    as applying it to each pair separately and stacking the outputs.
    """
    # The ``dim`` test parameter is used as given (it was previously
    # overwritten with a hard-coded 3, which made the parameter dead).
    # NOTE(review): assumes ``dim`` matches the dimension of ``r3``, as the
    # sibling tests in this file also pair ``dim`` with SRVMetric(r3).
    curves = gs.stack([curve_a, curve_b])
    # One random tangent vector per curve, each (n_sampling_points, dim).
    tangent_vecs = gs.random.rand(2, n_sampling_points, dim)
    srv_metric_r3 = SRVMetric(r3)

    # Batched call over both curves at once.
    result = srv_metric_r3.aux_differential_srv_transform(tangent_vecs, curves)

    # Reference: one call per curve, stacked along a new leading axis.
    res_a = srv_metric_r3.aux_differential_srv_transform(tangent_vecs[0], curve_a)
    res_b = srv_metric_r3.aux_differential_srv_transform(tangent_vecs[1], curve_b)
    expected = gs.stack([res_a, res_b])

    self.assertAllClose(result, expected)
def test_aux_differential_srv_transform_inverse(
    self, dim, n_sampling_points, curve_a
):
    """Round-trip check for the inverse SRV differential.

    A tangent vector is pushed forward at ``curve_a`` with
    ``aux_differential_srv_transform`` and pulled back with
    ``aux_differential_srv_transform_inverse``. The round trip must
    recover the original tangent vector up to tolerance.
    """
    ramp = gs.linspace(0.0, 1.0, n_sampling_points)
    # Tile gives (dim, n_sampling_points); the transpose makes it
    # (n_sampling_points, dim), i.e. the same ramp on every coordinate.
    tangent_vec = gs.transpose(gs.tile(ramp, (dim, 1)))

    metric = SRVMetric(r3)
    pushed = metric.aux_differential_srv_transform(tangent_vec, curve_a)
    pulled_back = metric.aux_differential_srv_transform_inverse(pushed, curve_a)

    self.assertAllClose(pulled_back, tangent_vec, atol=1e-3, rtol=1e-3)
def test_aux_differential_srv_transform(
    self, dim, n_sampling_points, n_curves, curve_fun_a
):
    """Test differential of square root velocity transform.

    Check that its value at (curve, tangent_vec) coincides with the
    derivative at zero of the square root velocity transform of a path of
    curves starting at curve with initial derivative tangent_vec.

    The derivative is approximated by a forward finite difference along
    the straight path ``curve_a + t * tangent_vec`` sampled at
    ``n_curves`` evenly spaced times ``t`` in [0, 1]. ``n_curves`` must be
    at least 2.
    """
    srv_metric_r3 = SRVMetric(r3)
    sampling_times = gs.linspace(0.0, 1.0, n_sampling_points)
    curve_a = curve_fun_a(sampling_times)
    # Tile gives (dim, n_sampling_points); the transpose makes it
    # (n_sampling_points, dim).
    tangent_vec = gs.transpose(
        gs.tile(gs.linspace(1.0, 2.0, n_sampling_points), (dim, 1))
    )
    result = srv_metric_r3.aux_differential_srv_transform(tangent_vec, curve_a)

    times = gs.linspace(0.0, 1.0, n_curves)
    # path_of_curves[i] = curve_a + times[i] * tangent_vec
    path_of_curves = curve_a + gs.einsum("i,jk->ijk", times, tangent_vec)
    srv_path = srv_metric_r3.srv_transform(path_of_curves)
    # linspace over [0, 1] with n_curves points has spacing
    # 1 / (n_curves - 1), so the difference quotient must be scaled by
    # (n_curves - 1), not n_curves. The old scale carried an extra
    # relative error of 1 / n_curves that the tolerance was hiding.
    expected = (n_curves - 1) * (srv_path[1] - srv_path[0])
    self.assertAllClose(result, expected, atol=1e-3, rtol=1e-3)