Ejemplo n.º 1
0
    def test_curvature_derivative(self):
        """Check that ``curvature_derivative`` returns zeros on an so(3) basis.

        The basis is obtained from ``metric.normal_basis`` applied to the
        group's Lie-algebra basis. The check runs twice:

        * with the basis vectors themselves and no explicit base point;
        * with the same vectors pushed through
          ``group.tangent_translation_map`` at a random group element, passing
          that element as the base point.

        In both cases the expected value is an array of zeros shaped like a
        single basis vector.
        """
        group = self.matrix_so3
        metric = InvariantMetric(group=group)
        x, y, z = metric.normal_basis(group.lie_algebra.basis)

        def assert_vanishes(value):
            # Zeros shaped like one basis vector; the translated vectors are
            # compared against the same template.
            self.assertAllClose(value, gs.zeros_like(x))

        # Without an explicit base point.
        assert_vanishes(metric.curvature_derivative(x, y, z, x))

        # At a random group element: translate each basis vector first.
        point = group.random_uniform()
        to_point = group.tangent_translation_map(point)
        tan_a, tan_b, tan_c = map(to_point, (x, y, z))
        assert_vanishes(
            metric.curvature_derivative(tan_a, tan_b, tan_c, tan_a, point)
        )
Ejemplo n.º 2
0
    def test_curvature_derivative(self):
        """Check that ``curvature_derivative`` returns zeros on an so(3) basis.

        The metric is built on ``self.matrix_so3`` with an explicit
        ``SkewSymmetricMatrices(3)`` algebra. The basis is the algebra's
        orthonormal basis with respect to ``metric.metric_mat_at_identity``.
        Zeros are expected both for the raw basis (no base point passed) and
        for the basis translated to a random group element (that element
        passed as the base point).
        """
        group = self.matrix_so3
        algebra = SkewSymmetricMatrices(3)
        metric = InvariantMetric(group=group, algebra=algebra)
        basis = algebra.orthonormal_basis(metric.metric_mat_at_identity)
        x, y, z = basis

        # Without an explicit base point: (x, y, z, x).
        at_origin = metric.curvature_derivative(x, y, z, x)
        self.assertAllClose(at_origin, gs.zeros_like(x))

        # Same pattern after translating x, y, z to a random point; the first
        # translated vector is reused as the fourth argument.
        point = group.random_uniform()
        translate = group.tangent_translation_map(point)
        moved = [translate(vec) for vec in (x, y, z)]
        at_point = metric.curvature_derivative(*moved, moved[0], point)
        self.assertAllClose(at_point, gs.zeros_like(x))
Ejemplo n.º 3
0
 def test_curvature_derivative_tangent_translation_map(
     self,
     group,
     tangent_vec_a,
     tangent_vec_b,
     tangent_vec_c,
     tangent_vec_d,
     base_point,
     expected,
 ):
     """Compare ``curvature_derivative`` on translated vectors to ``expected``.

     Each of the four tangent vectors is mapped through
     ``group.tangent_translation_map(base_point)``. The four results, in
     order a, b, c, d, are passed together with ``base_point`` to the
     ``curvature_derivative`` of ``InvariantMetric(group=group)``. The output
     is checked against ``expected`` with ``assertAllClose``.

     NOTE(review): the arguments are presumably supplied by a parametrizing
     test-data generator that is not visible here.
     """
     metric = InvariantMetric(group=group)
     translate = group.tangent_translation_map(base_point)
     translated = [
         translate(vec)
         for vec in (tangent_vec_a, tangent_vec_b, tangent_vec_c, tangent_vec_d)
     ]
     self.assertAllClose(
         metric.curvature_derivative(*translated, base_point), expected
     )