def setUp(self):
    """Create a Stiefel(4, 3) fixture with two points and two tangent vectors.

    Tangent vectors constructed following:
    http://noodle.med.yale.edu/hdtag/notes/steifel_notes.pdf

    Each tangent vector at ``point_a`` is ``point_a @ A + point_perp @ B``
    with ``A`` skew-symmetric.
    """
    warnings.filterwarnings('ignore')
    gs.random.seed(1234)

    n_rows, n_cols = 4, 3
    self.p = n_cols
    self.n = n_rows
    self.space = Stiefel(n_rows, n_cols)
    self.n_samples = 10
    # n * p - p * (p + 1) / 2; p * (p + 1) is even, so // is exact.
    self.dimension = n_cols * n_rows - (n_cols * (n_cols + 1)) // 2

    inv_sqrt_two = 1. / gs.sqrt(2.)
    self.point_a = gs.array([
        [1., 0., 0.],
        [0., 1., 0.],
        [0., 0., 1.],
        [0., 0., 0.]])
    self.point_b = gs.array([
        [inv_sqrt_two, 0., 0.],
        [0., 1., 0.],
        [0., 0., 1.],
        [inv_sqrt_two, 0., 0.]])

    # Column vector orthogonal to every column of point_a.
    point_perp = gs.array([[0.], [0.], [0.], [1.]])

    # Skew-symmetric coefficient block and coefficients along point_perp.
    skew_block = gs.array([
        [0., 2., -5.],
        [-2., 0., -1.],
        [5., 1., 0.]])
    perp_coeffs = gs.array([[-2., 1., 4.]])

    def build_tangent(skew, coeffs):
        # Horizontal part spanned by point_a plus normal part along point_perp.
        return gs.matmul(self.point_a, skew) + gs.matmul(point_perp, coeffs)

    self.tangent_vector_1 = build_tangent(skew_block, perp_coeffs)
    self.tangent_vector_2 = build_tangent(skew_block, perp_coeffs)

    self.metric = self.space.canonical_metric
def test_stiefel_n_samples(self):
    """Check that the Frechet mean of random Stiefel(3, 2) points lies on the manifold."""
    stiefel = Stiefel(3, 2)
    stiefel_metric = stiefel.metric
    samples = stiefel.random_point(2)

    estimator = FrechetMean(
        stiefel_metric, lr=0.5, verbose=True, method="default")
    estimator.fit(samples)

    self.assertTrue(stiefel.belongs(estimator.estimate_))
def test_stiefel_two_samples(self):
    """Check that the Frechet mean of two points equals exp(log(p0, p1) / 2, p1)."""
    stiefel = Stiefel(3, 2)
    stiefel_metric = stiefel.metric
    pair = stiefel.random_point(2)

    estimator = FrechetMean(stiefel_metric)
    estimator.fit(pair)
    estimate = estimator.estimate_

    # Halfway along the geodesic from pair[1] towards pair[0].
    half_log = stiefel_metric.log(pair[0], pair[1]) / 2
    midpoint = stiefel_metric.exp(half_log, pair[1])
    self.assertAllClose(midpoint, estimate)
def test_to_grassmanniann_vectorized(self):
    """Check that to_grassmannian maps a batch of rotated copies of point1 to p_xy.

    Relies on module-level ``r_z``, ``point1`` and ``p_xy`` (not visible here;
    presumably ``r_z`` generates rotations that preserve span(point1) — confirm).
    """
    divisors = [2, 3, 4]
    generators = gs.array([gs.pi * r_z / divisor for divisor in divisors])
    rotated_points = Matrices.mul(GeneralLinear.exp(generators), point1)

    projectors = Stiefel.to_grassmannian(rotated_points)

    self.assertAllClose(projectors, gs.array([p_xy] * len(divisors)))
def log_two_sheets_error_test_data(self):
    """Build one case where log between points of Stiefel(3, 3) must raise ValueError.

    The two points are forced to have determinants of opposite sign.
    """
    stiefel = Stiefel(n=3, p=3)
    base_point = stiefel.random_point()
    base_det = gs.linalg.det(base_point)
    point = stiefel.random_point()
    point_det = gs.linalg.det(point)

    # For a 3x3 matrix det(-M) = -det(M), so negating flips the sign.
    same_sign = gs.all(base_det * point_det > 0.0)
    if same_sign:
        point *= -1.0

    case = dict(
        n=3,
        p=3,
        point=point,
        base_point=base_point,
        expected=pytest.raises(ValueError),
    )
    return self.generate_tests([], [case])
class TestStiefelMethods(geomstats.tests.TestCase):
    """Tests for the Stiefel manifold and its canonical metric."""

    def setUp(self):
        """Create a Stiefel(4, 3) space, two points and two tangent vectors.

        Tangent vectors constructed following:
        http://noodle.med.yale.edu/hdtag/notes/steifel_notes.pdf

        Each tangent vector at ``point_a`` is
        ``point_a @ A + point_perp @ B`` with ``A`` skew-symmetric.
        """
        warnings.filterwarnings('ignore')

        gs.random.seed(1234)

        self.p = 3
        self.n = 4
        self.space = Stiefel(self.n, self.p)
        self.n_samples = 10
        # n * p - p * (p + 1) / 2. p * (p + 1) is always even, so integer
        # division is exact and the dimension stays an int.
        self.dimension = self.p * self.n - (self.p * (self.p + 1)) // 2

        self.point_a = gs.array([
            [1., 0., 0.],
            [0., 1., 0.],
            [0., 0., 1.],
            [0., 0., 0.]])

        self.point_b = gs.array([
            [1. / gs.sqrt(2.), 0., 0.],
            [0., 1., 0.],
            [0., 0., 1.],
            [1. / gs.sqrt(2.), 0., 0.]])

        # Column vector orthogonal to every column of point_a.
        point_perp = gs.array([
            [0.],
            [0.],
            [0.],
            [1.]])

        # Skew-symmetric: matrix_a_1.T == -matrix_a_1.
        matrix_a_1 = gs.array([
            [0., 2., -5.],
            [-2., 0., -1.],
            [5., 1., 0.]])

        matrix_b_1 = gs.array([
            [-2., 1., 4.]])

        # NOTE(review): matrix_a_2 / matrix_b_2 equal matrix_a_1 / matrix_b_1,
        # so tangent_vector_2 == tangent_vector_1. Confirm whether a distinct
        # second tangent vector was intended.
        matrix_a_2 = gs.array([
            [0., 2., -5.],
            [-2., 0., -1.],
            [5., 1., 0.]])

        matrix_b_2 = gs.array([
            [-2., 1., 4.]])

        self.tangent_vector_1 = (
            gs.matmul(self.point_a, matrix_a_1)
            + gs.matmul(point_perp, matrix_b_1))

        self.tangent_vector_2 = (
            gs.matmul(self.point_a, matrix_a_2)
            + gs.matmul(point_perp, matrix_b_2))

        self.metric = self.space.canonical_metric

    @geomstats.tests.np_and_tf_only
    def test_belongs_shape(self):
        point = self.space.random_uniform()
        belongs = self.space.belongs(point)

        self.assertAllClose(gs.shape(belongs), ())

    @geomstats.tests.np_and_tf_only
    def test_random_uniform_and_belongs(self):
        point = self.space.random_uniform()
        result = self.space.belongs(point, tolerance=1e-4)
        expected = True

        self.assertAllClose(result, expected)

    @geomstats.tests.np_and_tf_only
    def test_random_uniform_shape(self):
        result = self.space.random_uniform()

        self.assertAllClose(gs.shape(result), (self.n, self.p))

    @geomstats.tests.np_only
    def test_log_and_exp(self):
        """
        Test that the Riemannian exponential
        and the Riemannian logarithm are inverse.

        Expect their composition to give the identity function.
        """
        # Riemannian Log then Riemannian Exp
        # General case
        base_point = self.point_a
        point = self.point_b

        log = self.metric.log(point=point, base_point=base_point)
        result = self.metric.exp(tangent_vec=log, base_point=base_point)
        expected = point

        self.assertAllClose(result, expected, atol=ATOL)

    @geomstats.tests.np_and_tf_only
    def test_exp_and_belongs(self):
        base_point = self.point_a
        tangent_vec = self.tangent_vector_1

        exp = self.metric.exp(
            tangent_vec=tangent_vec,
            base_point=base_point)
        result = self.space.belongs(exp)
        expected = True

        self.assertAllClose(result, expected)

    @geomstats.tests.np_and_tf_only
    def test_exp_vectorization_shape(self):
        n_samples = self.n_samples
        n = self.n
        p = self.p

        one_base_point = self.point_a
        one_tangent_vec = self.tangent_vector_1

        n_base_points = gs.tile(
            gs.to_ndarray(self.point_a, to_ndim=3),
            (n_samples, 1, 1))
        n_tangent_vecs = gs.tile(
            gs.to_ndarray(self.tangent_vector_2, to_ndim=3),
            (n_samples, 1, 1))

        # With single tangent vec and base point
        result = self.metric.exp(one_tangent_vec, one_base_point)
        self.assertAllClose(gs.shape(result), (n, p))

        # With n_samples tangent vecs and/or base points
        result = self.metric.exp(n_tangent_vecs, one_base_point)
        self.assertAllClose(gs.shape(result), (n_samples, n, p))

        result = self.metric.exp(one_tangent_vec, n_base_points)
        self.assertAllClose(gs.shape(result), (n_samples, n, p))

        # Both batched, as in the log/lifting/retraction vectorization tests.
        result = self.metric.exp(n_tangent_vecs, n_base_points)
        self.assertAllClose(gs.shape(result), (n_samples, n, p))

    @geomstats.tests.np_and_tf_only
    def test_log_vectorization_shape(self):
        n_samples = self.n_samples
        n = self.n
        p = self.p

        one_point = self.space.random_uniform()
        one_base_point = self.space.random_uniform()

        n_points = self.space.random_uniform(n_samples=n_samples)
        n_base_points = self.space.random_uniform(n_samples=n_samples)

        # With single point and base point
        result = self.metric.log(one_point, one_base_point)
        self.assertAllClose(gs.shape(result), (n, p))

        # With multiple points and base points
        result = self.metric.log(n_points, one_base_point)
        self.assertAllClose(gs.shape(result), (n_samples, n, p))

        result = self.metric.log(one_point, n_base_points)
        self.assertAllClose(gs.shape(result), (n_samples, n, p))

        result = self.metric.log(n_points, n_base_points)
        self.assertAllClose(gs.shape(result), (n_samples, n, p))

    @geomstats.tests.np_only
    def test_retractation_and_lifting(self):
        """
        Test that the retraction and the lifting are inverse.

        Lifting point_b at point_a then retracting should give back
        point_b; retracting tangent_vector_1 at point_a then lifting
        should give back tangent_vector_1.
        """
        base_point = self.point_a
        point = self.point_b
        tangent_vec = self.tangent_vector_1

        # Lifting then retraction
        lifted = self.metric.lifting(point=point, base_point=base_point)
        result = self.metric.retraction(
            tangent_vec=lifted, base_point=base_point)
        expected = point

        self.assertAllClose(result, expected, atol=ATOL)

        # Retraction then lifting
        retract = self.metric.retraction(
            tangent_vec=tangent_vec, base_point=base_point)
        result = self.metric.lifting(point=retract, base_point=base_point)
        expected = tangent_vec

        self.assertAllClose(result, expected, atol=ATOL)

    @geomstats.tests.np_only
    def test_lifting_vectorization_shape(self):
        n_samples = self.n_samples
        n = self.n
        p = self.p

        one_point = self.point_a
        one_base_point = self.point_b
        n_points = gs.tile(
            gs.to_ndarray(self.point_a, to_ndim=3),
            (n_samples, 1, 1))
        n_base_points = gs.tile(
            gs.to_ndarray(self.point_b, to_ndim=3),
            (n_samples, 1, 1))

        result = self.metric.lifting(one_point, one_base_point)
        self.assertAllClose(gs.shape(result), (n, p))

        result = self.metric.lifting(n_points, one_base_point)
        self.assertAllClose(gs.shape(result), (n_samples, n, p))

        result = self.metric.lifting(one_point, n_base_points)
        self.assertAllClose(gs.shape(result), (n_samples, n, p))

        result = self.metric.lifting(n_points, n_base_points)
        self.assertAllClose(gs.shape(result), (n_samples, n, p))

    @geomstats.tests.np_and_tf_only
    def test_retraction_vectorization_shape(self):
        n_samples = self.n_samples
        n = self.n
        p = self.p

        one_point = self.point_a
        n_points = gs.tile(
            gs.to_ndarray(one_point, to_ndim=3),
            (n_samples, 1, 1))
        one_tangent_vec = self.tangent_vector_1
        n_tangent_vecs = gs.tile(
            gs.to_ndarray(self.tangent_vector_2, to_ndim=3),
            (n_samples, 1, 1))

        result = self.metric.retraction(one_tangent_vec, one_point)
        self.assertAllClose(gs.shape(result), (n, p))

        result = self.metric.retraction(n_tangent_vecs, one_point)
        self.assertAllClose(gs.shape(result), (n_samples, n, p))

        result = self.metric.retraction(one_tangent_vec, n_points)
        self.assertAllClose(gs.shape(result), (n_samples, n, p))

        result = self.metric.retraction(n_tangent_vecs, n_points)
        self.assertAllClose(gs.shape(result), (n_samples, n, p))

    def test_inner_product(self):
        base_point = self.point_a
        tangent_vector_1 = self.tangent_vector_1
        tangent_vector_2 = self.tangent_vector_2

        result = self.metric.inner_product(
            tangent_vector_1,
            tangent_vector_2,
            base_point=base_point)
        self.assertAllClose(gs.shape(result), ())

    @geomstats.tests.np_and_pytorch_only
    def test_to_grassmannian(self):
        point2 = gs.array([[1., -1.], [1., 1.], [0., 0.]]) / gs.sqrt(2)
        result = self.space.to_grassmannian(point2)
        expected = p_xy
        self.assertAllClose(result, expected)

    @geomstats.tests.np_only
    def test_to_grassmanniann_vectorized(self):
        inf_rots = gs.array([gs.pi * r_z / n for n in [2, 3, 4]])
        rots = GeneralLinear.exp(inf_rots)
        points = Matrices.mul(rots, point1)

        result = Stiefel.to_grassmannian(points)
        expected = gs.array([p_xy, p_xy, p_xy])
        self.assertAllClose(result, expected)
class StiefelCanonicalMetricTestData(_RiemannianMetricTestData):
    """Parametrized test data for the canonical metric on Stiefel(n, p).

    The class attributes below draw from the module-level ``random`` generator
    when the class body executes; their order of definition determines which
    values are drawn, so keep it unchanged.
    """

    # Ambient dimensions n in {3, 4} (both, random order), and for each n a
    # column count p with 2 <= p < n.
    n_list = random.sample(range(3, 5), 2)
    p_list = [random.sample(range(2, n_dim), 1)[0] for n_dim in n_list]
    metric_args_list = list(zip(n_list, p_list))
    shape_list = metric_args_list
    space_list = [Stiefel(*n_and_p) for n_and_p in metric_args_list]
    n_points_list = random.sample(range(1, 5), 2)
    n_points_a_list = random.sample(range(1, 5), 2)
    n_points_b_list = [1]
    n_tangent_vecs_list = random.sample(range(1, 5), 2)
    alpha_list = [1] * 2
    n_rungs_list = [1] * 2
    scheme_list = ["pole"] * 2

    def log_two_sheets_error_test_data(self):
        """One Stiefel(3, 3) case whose points have opposite-sign determinants; log must raise."""
        stiefel = Stiefel(n=3, p=3)
        base_point = stiefel.random_point()
        base_det = gs.linalg.det(base_point)
        point = stiefel.random_point()
        point_det = gs.linalg.det(point)
        # For a 3x3 matrix det(-M) = -det(M), so negation flips the sign.
        if gs.all(base_det * point_det > 0.0):
            point *= -1.0
        case = dict(
            n=3,
            p=3,
            point=point,
            base_point=base_point,
            expected=pytest.raises(ValueError),
        )
        return self.generate_tests([], [case])

    def exp_shape_test_data(self):
        args, spaces = self.metric_args_list, self.space_list
        return self._exp_shape_test_data(args, spaces, self.shape_list)

    def log_shape_test_data(self):
        args, spaces = self.metric_args_list, self.space_list
        return self._log_shape_test_data(args, spaces)

    def squared_dist_is_symmetric_test_data(self):
        tolerance = gs.atol * 1000
        return self._squared_dist_is_symmetric_test_data(
            self.metric_args_list,
            self.space_list,
            self.n_points_a_list,
            self.n_points_b_list,
            atol=tolerance,
        )

    def exp_belongs_test_data(self):
        tolerance = gs.atol * 10000
        return self._exp_belongs_test_data(
            self.metric_args_list,
            self.space_list,
            self.shape_list,
            self.n_tangent_vecs_list,
            belongs_atol=tolerance,
        )

    def log_is_tangent_test_data(self):
        tolerance = gs.atol * 1000
        return self._log_is_tangent_test_data(
            self.metric_args_list,
            self.space_list,
            self.n_points_list,
            is_tangent_atol=tolerance,
        )

    def geodesic_ivp_belongs_test_data(self):
        tolerance = gs.atol * 1000
        return self._geodesic_ivp_belongs_test_data(
            self.metric_args_list,
            self.space_list,
            self.shape_list,
            self.n_points_list,
            belongs_atol=tolerance,
        )

    def geodesic_bvp_belongs_test_data(self):
        tolerance = gs.atol * 1000
        return self._geodesic_bvp_belongs_test_data(
            self.metric_args_list,
            self.space_list,
            self.n_points_list,
            belongs_atol=tolerance,
        )

    def exp_after_log_test_data(self):
        relative, absolute = gs.rtol * 100, gs.atol * 10000
        return self._exp_after_log_test_data(
            self.metric_args_list,
            self.space_list,
            self.n_points_list,
            rtol=relative,
            atol=absolute,
        )

    def log_after_exp_test_data(self):
        relative, absolute = gs.rtol * 100, gs.atol * 10000
        return self._log_after_exp_test_data(
            self.metric_args_list,
            self.space_list,
            self.shape_list,
            self.n_tangent_vecs_list,
            rtol=relative,
            atol=absolute,
        )

    def exp_ladder_parallel_transport_test_data(self):
        return self._exp_ladder_parallel_transport_test_data(
            self.metric_args_list,
            self.space_list,
            self.shape_list,
            self.n_tangent_vecs_list,
            self.n_rungs_list,
            self.alpha_list,
            self.scheme_list,
            atol=1e-1,
        )

    def exp_geodesic_ivp_test_data(self):
        relative, absolute = gs.rtol * 1000, gs.atol * 1000
        return self._exp_geodesic_ivp_test_data(
            self.metric_args_list,
            self.space_list,
            self.shape_list,
            self.n_tangent_vecs_list,
            self.n_points_list,
            rtol=relative,
            atol=absolute,
        )

    def dist_is_symmetric_test_data(self):
        tolerance = gs.atol * 1000
        return self._dist_is_symmetric_test_data(
            self.metric_args_list,
            self.space_list,
            self.n_points_a_list,
            self.n_points_b_list,
            atol=tolerance,
        )

    def dist_is_positive_test_data(self):
        args, spaces = self.metric_args_list, self.space_list
        return self._dist_is_positive_test_data(
            args, spaces, self.n_points_a_list, self.n_points_b_list
        )

    def squared_dist_is_positive_test_data(self):
        args, spaces = self.metric_args_list, self.space_list
        return self._squared_dist_is_positive_test_data(
            args, spaces, self.n_points_a_list, self.n_points_b_list
        )

    def dist_is_norm_of_log_test_data(self):
        args, spaces = self.metric_args_list, self.space_list
        return self._dist_is_norm_of_log_test_data(
            args, spaces, self.n_points_a_list, self.n_points_b_list
        )

    def dist_point_to_itself_is_zero_test_data(self):
        tolerance = gs.atol * 1000
        return self._dist_point_to_itself_is_zero_test_data(
            self.metric_args_list,
            self.space_list,
            self.n_points_list,
            atol=tolerance,
        )

    def inner_product_is_symmetric_test_data(self):
        args, spaces = self.metric_args_list, self.space_list
        return self._inner_product_is_symmetric_test_data(
            args, spaces, self.shape_list, self.n_tangent_vecs_list
        )

    def triangle_inequality_of_dist_test_data(self):
        return self._triangle_inequality_of_dist_test_data(
            self.metric_args_list,
            self.space_list,
            self.n_points_list,
            atol=1e-3,
        )

    def retraction_lifting_test_data(self):
        # Retraction-then-lifting reuses the log-after-exp data generator.
        relative, absolute = gs.rtol * 100, gs.atol * 10000
        return self._log_after_exp_test_data(
            self.metric_args_list,
            self.space_list,
            self.shape_list,
            self.n_tangent_vecs_list,
            rtol=relative,
            atol=absolute,
        )

    def lifting_retraction_test_data(self):
        # Lifting-then-retraction reuses the exp-after-log data generator.
        relative, absolute = gs.rtol * 100, gs.atol * 10000
        return self._exp_after_log_test_data(
            self.metric_args_list,
            self.space_list,
            self.n_points_list,
            rtol=relative,
            atol=absolute,
        )

    def retraction_shape_test_data(self):
        return self.exp_shape_test_data()

    def lifting_shape_test_data(self):
        return self.log_shape_test_data()