def test_curvature(self):
    """Check the curvature of the invariant metric on ``self.matrix_so3``.

    Four situations are exercised, in this order:

    * ``curvature_at_identity(x, y, x)`` against ``1.0 / 8 * y``;
    * a batch of two copies of the same inputs via ``curvature``;
    * the same inputs translated to a random point of the group, with the
      reference value translated the same way;
    * ``curvature(y, y, z)``, which is expected to be all zeros.

    Here ``x, y, z`` are the vectors returned by ``metric.normal_basis``
    applied to the group's Lie-algebra basis.
    """
    group = self.matrix_so3
    metric = InvariantMetric(group=group)
    basis_x, basis_y, basis_z = metric.normal_basis(group.lie_algebra.basis)

    # Single evaluation at the identity.
    at_identity = metric.curvature_at_identity(basis_x, basis_y, basis_x)
    reference = 1.0 / 8 * basis_y
    self.assertAllClose(at_identity, reference)

    # Batched evaluation: two identical samples stacked together.
    batch_a = gs.stack([basis_x] * 2)
    batch_b = gs.stack([basis_y] * 2)
    batch_result = metric.curvature(batch_a, batch_b, batch_a)
    self.assertAllClose(batch_result, gs.array([reference] * 2))

    # Evaluation at a random point, with inputs and reference translated there.
    point = group.random_uniform()
    translate = group.tangent_translation_map(point)
    moved_x = translate(basis_x)
    moved_y = translate(basis_y)
    translated_result = metric.curvature(moved_x, moved_y, moved_x, point)
    self.assertAllClose(translated_result, translate(reference))

    # Repeating the first two arguments should give a vanishing result.
    self.assertAllClose(
        metric.curvature(basis_y, basis_y, basis_z), gs.zeros_like(basis_z)
    )
def test_curvature(
    self, group, tangent_vec_a, tangent_vec_b, tangent_vec_c, expected
):
    """Compare ``InvariantMetric(group).curvature`` with ``expected``.

    The three tangent vectors are passed unchanged, and ``base_point`` is
    given explicitly as ``None``.
    """
    curvature = InvariantMetric(group).curvature(
        tangent_vec_a, tangent_vec_b, tangent_vec_c, base_point=None
    )
    self.assertAllClose(curvature, expected)
def test_curvature_translation_point(
    self, group, tangent_vec_a, tangent_vec_b, tangent_vec_c, point, expected
):
    """Check curvature at ``point`` on tangent vectors translated there.

    The tangent vectors and the reference value ``expected`` all go
    through ``group.tangent_translation_map(point)`` first. The curvature
    is then evaluated with ``point`` as the base point and compared with
    the translated reference.
    """
    metric = InvariantMetric(group)
    translate = group.tangent_translation_map(point)
    moved_a, moved_b, moved_c = (
        translate(vec) for vec in (tangent_vec_a, tangent_vec_b, tangent_vec_c)
    )
    result = metric.curvature(moved_a, moved_b, moved_c, point)
    self.assertAllClose(result, translate(expected))