class TestGrassmannianCanonicalMetric(RiemannianMetricTestCase, metaclass=Parametrizer):
    """Parametrized tests for ``GrassmannianCanonicalMetric``.

    Cases come from ``GrassmannianCanonicalMetricTestData``; generic checks
    inherited from ``RiemannianMetricTestCase`` are switched off through the
    ``skip_test_*`` class attributes.

    NOTE(review): this file defines a second class named
    ``TestGrassmannianCanonicalMetric`` further down; that later definition
    rebinds the module-level name — confirm which version is intended.
    """

    metric = GrassmannianCanonicalMetric
    connection = metric

    skip_test_log_after_exp = True
    skip_test_exp_geodesic_ivp = True
    # Skipped whenever ``np_backend()`` is false.
    skip_test_log_is_tangent = not np_backend()

    testing_data = GrassmannianCanonicalMetricTestData()

    def test_exp(self, n, k, tangent_vec, base_point, expected):
        """Check ``exp`` of ``tangent_vec`` at ``base_point`` for the metric on Gr(n, k)."""
        grassmann_metric = self.metric(n, k)
        mapped = grassmann_metric.exp(gs.array(tangent_vec), gs.array(base_point))
        self.assertAllClose(mapped, gs.array(expected))
class TestMinkowskiMetric(RiemannianMetricTestCase, metaclass=Parametrizer):
    """Parametrized tests for ``MinkowskiMetric``.

    Cases come from ``MinkowskiMetricTestData``; generic checks inherited
    from ``RiemannianMetricTestCase`` are switched off through the
    ``skip_test_*`` class attributes.

    NOTE(review): this file defines a second class named
    ``TestMinkowskiMetric`` further down; that later definition rebinds the
    module-level name — confirm which version is intended.
    """

    connection = MinkowskiMetric
    metric = connection

    skip_test_parallel_transport_ivp_is_isometry = True
    skip_test_parallel_transport_bvp_is_isometry = True
    skip_test_exp_geodesic_ivp = True
    skip_test_dist_is_positive = True
    skip_test_squared_dist_is_positive = True
    # Skipped whenever ``np_backend()`` is false.
    skip_test_dist_is_norm_of_log = not np_backend()
    skip_test_dist_is_symmetric = not np_backend()
    skip_test_triangle_inequality_of_dist = True

    testing_data = MinkowskiMetricTestData()

    def test_metric_matrix(self, dim, expected):
        """Compare the metric matrix in dimension ``dim`` with ``expected``."""
        matrix = self.metric(dim).metric_matrix()
        self.assertAllClose(matrix, gs.array(expected))

    def test_inner_product(self, dim, point_a, point_b, expected):
        """Compare ``inner_product(point_a, point_b)`` with ``expected``."""
        minkowski = self.metric(dim)
        vec_a = gs.array(point_a)
        vec_b = gs.array(point_b)
        self.assertAllClose(minkowski.inner_product(vec_a, vec_b), gs.array(expected))

    def test_squared_norm(self, dim, point, expected):
        """Compare ``squared_norm(point)`` with ``expected``."""
        minkowski = self.metric(dim)
        value = minkowski.squared_norm(gs.array(point))
        self.assertAllClose(value, gs.array(expected))

    def test_exp(self, dim, tangent_vec, base_point, expected):
        """Compare ``exp(tangent_vec, base_point)`` with ``expected``."""
        mapped = self.metric(dim).exp(gs.array(tangent_vec), gs.array(base_point))
        self.assertAllClose(mapped, gs.array(expected))

    def test_log(self, dim, point, base_point, expected):
        """Compare ``log(point, base_point)`` with ``expected``."""
        lifted = self.metric(dim).log(gs.array(point), gs.array(base_point))
        self.assertAllClose(lifted, gs.array(expected))

    def test_squared_dist(self, dim, point_a, point_b, expected):
        """Compare ``squared_dist(point_a, point_b)`` with ``expected``."""
        sq_dist = self.metric(dim).squared_dist(gs.array(point_a), gs.array(point_b))
        self.assertAllClose(sq_dist, gs.array(expected))
class TestMinkowskiMetric(RiemannianMetricTestCase, metaclass=Parametrizer):
    """Parametrized tests for ``MinkowskiMetric``.

    The nested ``MinkowskiMetricTestData`` supplies hand-computed smoke cases
    and randomized inputs for the generic checks inherited from
    ``RiemannianMetricTestCase``.

    NOTE(review): this file defines another class named
    ``TestMinkowskiMetric`` earlier; this definition rebinds the
    module-level name — confirm which version is intended.
    """

    connection = metric = MinkowskiMetric
    skip_test_parallel_transport_ivp_is_isometry = True
    skip_test_parallel_transport_bvp_is_isometry = True
    skip_test_exp_geodesic_ivp = True
    # The metric matrix is indefinite (diag(-1, 1) in dim 2, see
    # ``metric_matrix_test_data``), so squared "distances" can be negative:
    # positivity and the triangle inequality are not expected to hold.
    skip_test_dist_is_positive = True
    skip_test_squared_dist_is_positive = True
    skip_test_triangle_inequality_of_dist = True
    # Skipped whenever ``np_backend()`` is false.
    skip_test_dist_is_norm_of_log = not np_backend()
    skip_test_dist_is_symmetric = not np_backend()

    class MinkowskiMetricTestData(_RiemannianMetricTestData):
        """Inputs for the Minkowski metric tests (two random dimensions)."""

        n_list = random.sample(range(2, 4), 2)
        metric_args_list = [(n, ) for n in n_list]
        shape_list = metric_args_list
        space_list = [Minkowski(n) for n in n_list]
        n_points_list = random.sample(range(1, 3), 2)
        n_tangent_vecs_list = random.sample(range(1, 3), 2)
        n_points_a_list = random.sample(range(1, 3), 2)
        n_points_b_list = [1]
        alpha_list = [1] * 2
        n_rungs_list = [1] * 2
        scheme_list = ["pole"] * 2

        def metric_matrix_test_data(self):
            smoke_data = [dict(dim=2, expected=[[-1.0, 0.0], [0.0, 1.0]])]
            return self.generate_tests(smoke_data)

        def inner_product_test_data(self):
            smoke_data = [
                dict(dim=2, point_a=[0.0, 1.0], point_b=[2.0, 10.0], expected=10.0),
                dict(
                    dim=2,
                    point_a=[[-1.0, 0.0], [1.0, 0.0], [2.0, math.sqrt(3)]],
                    point_b=[
                        [2.0, -math.sqrt(3)],
                        [4.0, math.sqrt(15)],
                        [-4.0, math.sqrt(15)],
                    ],
                    expected=[2.0, -4.0, 14.70820393],
                ),
            ]
            return self.generate_tests(smoke_data)

        def squared_norm_test_data(self):
            # Key names match the parameters of ``test_squared_norm``
            # (``dim, point, expected``); previously this used ``vector``.
            smoke_data = [dict(dim=2, point=[-2.0, 4.0], expected=12.0)]
            return self.generate_tests(smoke_data)

        def squared_dist_test_data(self):
            smoke_data = [
                dict(
                    dim=2,
                    point_a=[2.0, -math.sqrt(3)],
                    point_b=[4.0, math.sqrt(15)],
                    expected=27.416407,
                )
            ]
            return self.generate_tests(smoke_data)

        def exp_test_data(self):
            smoke_data = [
                dict(
                    dim=2,
                    tangent_vec=[2.0, math.sqrt(3)],
                    base_point=[1.0, 0.0],
                    expected=[3.0, math.sqrt(3)],
                )
            ]
            return self.generate_tests(smoke_data)

        def log_test_data(self):
            smoke_data = [
                dict(
                    dim=2,
                    point=[2.0, math.sqrt(3)],
                    base_point=[-1.0, 0.0],
                    expected=[3.0, math.sqrt(3)],
                )
            ]
            return self.generate_tests(smoke_data)

        def exp_shape_test_data(self):
            return self._exp_shape_test_data(
                self.metric_args_list, self.space_list, self.shape_list
            )

        def log_shape_test_data(self):
            return self._log_shape_test_data(
                self.metric_args_list,
                self.space_list,
            )

        def squared_dist_is_symmetric_test_data(self):
            return self._squared_dist_is_symmetric_test_data(
                self.metric_args_list,
                self.space_list,
                self.n_points_a_list,
                self.n_points_b_list,
                atol=gs.atol * 1000,
            )

        def exp_belongs_test_data(self):
            return self._exp_belongs_test_data(
                self.metric_args_list,
                self.space_list,
                self.shape_list,
                self.n_tangent_vecs_list,
                belongs_atol=gs.atol * 1000,
            )

        def log_is_tangent_test_data(self):
            return self._log_is_tangent_test_data(
                self.metric_args_list,
                self.space_list,
                self.n_points_list,
                is_tangent_atol=gs.atol * 1000,
            )

        def geodesic_ivp_belongs_test_data(self):
            return self._geodesic_ivp_belongs_test_data(
                self.metric_args_list,
                self.space_list,
                self.shape_list,
                self.n_points_list,
                belongs_atol=gs.atol * 1000,
            )

        def geodesic_bvp_belongs_test_data(self):
            return self._geodesic_bvp_belongs_test_data(
                self.metric_args_list,
                self.space_list,
                self.n_points_list,
                belongs_atol=gs.atol * 1000,
            )

        def log_then_exp_test_data(self):
            return self._log_then_exp_test_data(
                self.metric_args_list,
                self.space_list,
                self.n_points_list,
                rtol=gs.rtol * 100,
                atol=gs.atol * 10000,
            )

        def exp_then_log_test_data(self):
            return self._exp_then_log_test_data(
                self.metric_args_list,
                self.space_list,
                self.shape_list,
                self.n_tangent_vecs_list,
                rtol=gs.rtol * 100,
                atol=gs.atol * 10000,
            )

        def exp_ladder_parallel_transport_test_data(self):
            return self._exp_ladder_parallel_transport_test_data(
                self.metric_args_list,
                self.space_list,
                self.shape_list,
                self.n_tangent_vecs_list,
                self.n_rungs_list,
                self.alpha_list,
                self.scheme_list,
            )

        def exp_geodesic_ivp_test_data(self):
            return self._exp_geodesic_ivp_test_data(
                self.metric_args_list,
                self.space_list,
                self.shape_list,
                self.n_tangent_vecs_list,
                self.n_points_list,
                rtol=gs.rtol * 1000,
                atol=gs.atol * 1000,
            )

        def parallel_transport_ivp_is_isometry_test_data(self):
            return self._parallel_transport_ivp_is_isometry_test_data(
                self.metric_args_list,
                self.space_list,
                self.shape_list,
                self.n_tangent_vecs_list,
                is_tangent_atol=gs.atol * 1000,
                atol=gs.atol * 1000,
            )

        def parallel_transport_bvp_is_isometry_test_data(self):
            return self._parallel_transport_bvp_is_isometry_test_data(
                self.metric_args_list,
                self.space_list,
                self.shape_list,
                self.n_tangent_vecs_list,
                is_tangent_atol=gs.atol * 1000,
                atol=gs.atol * 1000,
            )

        def dist_is_symmetric_test_data(self):
            return self._dist_is_symmetric_test_data(
                self.metric_args_list,
                self.space_list,
                self.n_points_a_list,
                self.n_points_b_list,
            )

        def dist_is_norm_of_log_test_data(self):
            return self._dist_is_norm_of_log_test_data(
                self.metric_args_list,
                self.space_list,
                self.n_points_a_list,
                self.n_points_b_list,
            )

        def dist_point_to_itself_is_zero_test_data(self):
            return self._dist_point_to_itself_is_zero_test_data(
                self.metric_args_list, self.space_list, self.n_points_list
            )

        def inner_product_is_symmetric_test_data(self):
            return self._inner_product_is_symmetric_test_data(
                self.metric_args_list,
                self.space_list,
                self.shape_list,
                self.n_tangent_vecs_list,
            )

    testing_data = MinkowskiMetricTestData()

    def test_metric_matrix(self, dim, expected):
        """Compare the metric matrix in dimension ``dim`` with ``expected``."""
        metric = self.metric(dim)
        self.assertAllClose(metric.metric_matrix(), gs.array(expected))

    def test_inner_product(self, dim, point_a, point_b, expected):
        """Compare ``inner_product(point_a, point_b)`` with ``expected``."""
        metric = self.metric(dim)
        self.assertAllClose(
            metric.inner_product(gs.array(point_a), gs.array(point_b)),
            gs.array(expected),
        )

    def test_squared_norm(self, dim, point, expected):
        """Compare ``squared_norm(point)`` with ``expected``."""
        metric = self.metric(dim)
        self.assertAllClose(metric.squared_norm(gs.array(point)), gs.array(expected))

    def test_exp(self, dim, tangent_vec, base_point, expected):
        """Compare ``exp(tangent_vec, base_point)`` with ``expected``."""
        result = self.metric(dim).exp(gs.array(tangent_vec), gs.array(base_point))
        self.assertAllClose(result, gs.array(expected))

    def test_log(self, dim, point, base_point, expected):
        """Compare ``log(point, base_point)`` with ``expected``."""
        result = self.metric(dim).log(gs.array(point), gs.array(base_point))
        self.assertAllClose(result, gs.array(expected))

    def test_squared_dist(self, dim, point_a, point_b, expected):
        """Compare ``squared_dist(point_a, point_b)`` with ``expected``."""
        result = self.metric(dim).squared_dist(gs.array(point_a), gs.array(point_b))
        self.assertAllClose(result, gs.array(expected))
class TestGrassmannianCanonicalMetric(RiemannianMetricTestCase, metaclass=Parametrizer):
    """Parametrized tests for ``GrassmannianCanonicalMetric`` on Gr(n, k).

    The nested ``GrassmannianCanonicalMetricTestData`` supplies hand-built
    ``exp`` smoke cases plus randomized inputs for the generic checks
    inherited from ``RiemannianMetricTestCase``.

    NOTE(review): this file defines another class named
    ``TestGrassmannianCanonicalMetric`` earlier; this definition rebinds the
    module-level name — confirm which version is intended.
    """

    metric = GrassmannianCanonicalMetric
    connection = metric
    skip_test_exp_then_log = True
    skip_test_exp_geodesic_ivp = True
    # Skipped whenever ``np_backend()`` is false.
    skip_test_log_is_tangent = not np_backend()

    class GrassmannianCanonicalMetricTestData(_RiemannianMetricTestData):
        """Inputs for the Grassmannian metric tests (two random (n, k) pairs)."""

        # n is drawn from {3, 4}; each k is drawn with 2 <= k < n.
        n_list = random.sample(range(3, 5), 2)
        k_list = [random.sample(range(2, n), 1)[0] for n in n_list]
        metric_args_list = list(zip(n_list, k_list))
        # Per-space (n, n) shape handed to the shape-based helpers.
        shape_list = [(n, n) for n in n_list]
        space_list = [Grassmannian(*args) for args in metric_args_list]
        n_points_list = random.sample(range(1, 5), 2)
        n_points_a_list = random.sample(range(1, 5), 2)
        n_points_b_list = [1]
        n_tangent_vecs_list = random.sample(range(1, 5), 2)
        alpha_list = [1, 1]
        n_rungs_list = [1, 1]
        scheme_list = ["pole", "pole"]

        def exp_test_data(self):
            """Smoke cases for ``exp`` at the stacked base point ``[p_xy, p_yz]``."""

            def stacked_base():
                # Fresh array on every call, one per use site.
                return gs.array([p_xy, p_yz])

            single_tangent = Matrices.bracket(pi_2 * r_y, stacked_base())
            paired_tangent = Matrices.bracket(
                pi_2 * gs.array([r_y, r_z]), stacked_base()
            )
            first_case = dict(
                n=3,
                k=2,
                tangent_vec=single_tangent,
                base_point=stacked_base(),
                expected=gs.array([p_yz, p_xy]),
            )
            second_case = dict(
                n=3,
                k=2,
                tangent_vec=paired_tangent,
                base_point=stacked_base(),
                expected=gs.array([p_yz, p_xz]),
            )
            return self.generate_tests([first_case, second_case])

        def exp_shape_test_data(self):
            return self._exp_shape_test_data(
                self.metric_args_list, self.space_list, self.shape_list
            )

        def log_shape_test_data(self):
            return self._log_shape_test_data(self.metric_args_list, self.space_list)

        def squared_dist_is_symmetric_test_data(self):
            return self._squared_dist_is_symmetric_test_data(
                self.metric_args_list,
                self.space_list,
                self.n_points_a_list,
                self.n_points_b_list,
                atol=gs.atol * 1000,
            )

        def exp_belongs_test_data(self):
            return self._exp_belongs_test_data(
                self.metric_args_list,
                self.space_list,
                self.shape_list,
                self.n_tangent_vecs_list,
                belongs_atol=gs.atol * 1000,
            )

        def log_is_tangent_test_data(self):
            return self._log_is_tangent_test_data(
                self.metric_args_list,
                self.space_list,
                self.n_points_list,
                is_tangent_atol=gs.atol * 1000,
            )

        def geodesic_ivp_belongs_test_data(self):
            return self._geodesic_ivp_belongs_test_data(
                self.metric_args_list,
                self.space_list,
                self.shape_list,
                self.n_points_list,
                belongs_atol=gs.atol * 10000,
            )

        def geodesic_bvp_belongs_test_data(self):
            return self._geodesic_bvp_belongs_test_data(
                self.metric_args_list,
                self.space_list,
                self.n_points_list,
                belongs_atol=gs.atol * 10000,
            )

        def log_then_exp_test_data(self):
            return self._log_then_exp_test_data(
                self.metric_args_list,
                self.space_list,
                self.n_points_list,
                rtol=gs.rtol * 100,
                atol=gs.atol * 10000,
            )

        def exp_then_log_test_data(self):
            return self._exp_then_log_test_data(
                self.metric_args_list,
                self.space_list,
                self.shape_list,
                self.n_tangent_vecs_list,
                rtol=gs.rtol * 100,
                atol=gs.atol * 10000,
            )

        def exp_ladder_parallel_transport_test_data(self):
            return self._exp_ladder_parallel_transport_test_data(
                self.metric_args_list,
                self.space_list,
                self.shape_list,
                self.n_tangent_vecs_list,
                self.n_rungs_list,
                self.alpha_list,
                self.scheme_list,
            )

        def exp_geodesic_ivp_test_data(self):
            return self._exp_geodesic_ivp_test_data(
                self.metric_args_list,
                self.space_list,
                self.shape_list,
                self.n_tangent_vecs_list,
                self.n_points_list,
                rtol=gs.rtol * 10000,
                atol=gs.atol * 10000,
            )

        def parallel_transport_ivp_is_isometry_test_data(self):
            return self._parallel_transport_ivp_is_isometry_test_data(
                self.metric_args_list,
                self.space_list,
                self.shape_list,
                self.n_tangent_vecs_list,
                is_tangent_atol=gs.atol * 1000,
                atol=gs.atol * 1000,
            )

        def parallel_transport_bvp_is_isometry_test_data(self):
            return self._parallel_transport_bvp_is_isometry_test_data(
                self.metric_args_list,
                self.space_list,
                self.shape_list,
                self.n_tangent_vecs_list,
                is_tangent_atol=gs.atol * 1000,
                atol=gs.atol * 1000,
            )

        def dist_is_symmetric_test_data(self):
            return self._dist_is_symmetric_test_data(
                self.metric_args_list,
                self.space_list,
                self.n_points_a_list,
                self.n_points_b_list,
                atol=gs.atol * 1000,
            )

        def dist_is_positive_test_data(self):
            return self._dist_is_positive_test_data(
                self.metric_args_list,
                self.space_list,
                self.n_points_a_list,
                self.n_points_b_list,
            )

        def squared_dist_is_positive_test_data(self):
            return self._squared_dist_is_positive_test_data(
                self.metric_args_list,
                self.space_list,
                self.n_points_a_list,
                self.n_points_b_list,
            )

        def dist_is_norm_of_log_test_data(self):
            return self._dist_is_norm_of_log_test_data(
                self.metric_args_list,
                self.space_list,
                self.n_points_a_list,
                self.n_points_b_list,
                atol=gs.atol * 1000,
            )

        def dist_point_to_itself_is_zero_test_data(self):
            return self._dist_point_to_itself_is_zero_test_data(
                self.metric_args_list, self.space_list, self.n_points_list
            )

        def inner_product_is_symmetric_test_data(self):
            return self._inner_product_is_symmetric_test_data(
                self.metric_args_list,
                self.space_list,
                self.shape_list,
                self.n_tangent_vecs_list,
            )

    testing_data = GrassmannianCanonicalMetricTestData()

    def test_exp(self, n, k, tangent_vec, base_point, expected):
        """Check ``exp`` of ``tangent_vec`` at ``base_point`` for the metric on Gr(n, k)."""
        grassmann_metric = self.metric(n, k)
        mapped = grassmann_metric.exp(gs.array(tangent_vec), gs.array(base_point))
        self.assertAllClose(mapped, gs.array(expected))