def closest_rotation_matrix(mat):
    """
    Compute the closest rotation matrix of a given matrix mat,
    in terms of the Frobenius norm.

    The input is first converted to a 3D array (a stack of square
    matrices) with gs.to_ndarray.

    For 3x3 matrices, the result is the product of the two unitary
    factors returned by gs.linalg.svd; entries of the stack whose
    product has a negative determinant are recomputed with
    diag(1, 1, -1) inserted between the two factors.

    For any other size, each matrix is multiplied on the right by the
    inverse of the matrix square root of (mat^T mat).
    """
    mat = gs.to_ndarray(mat, to_ndim=3)
    n_mats, n_rows, n_cols = mat.shape
    assert n_rows == n_cols

    if n_rows != 3:
        # General case: mat * (mat^T mat)^(-1/2), one matrix at a time.
        gram_mats = gs.matmul(gs.transpose(mat, axes=(0, 2, 1)), mat)
        inv_sqrt_mats = gs.zeros_like(mat)
        for idx in range(n_mats):
            gram = gram_mats[idx]
            assert spd_matrices_space.is_symmetric(gram)
            sqrt_gram = spd_matrices_space.sqrtm(gram)
            inv_sqrt_mats[idx] = gs.linalg.inv(sqrt_gram)
        rot_mat = gs.matmul(mat, inv_sqrt_mats)
    else:
        # 3x3 case: built from the SVD factors; the middle output of the
        # decomposition is not needed.
        left_factor, _, right_factor = gs.linalg.svd(mat)
        rot_mat = gs.matmul(left_factor, right_factor)

        # Select the candidates with negative determinant and rebuild
        # them with the sign-flipping diagonal matrix in the middle.
        mask = gs.where(gs.linalg.det(rot_mat) < 0)
        # NOTE(review): if gs.where mirrors numpy.where, mask is a
        # 1-tuple of index arrays, so len(mask) is 1 and the tile leaves
        # the 3x3 matrix unchanged (it then broadcasts in matmul) -
        # confirm against the backend.
        flip_mat = gs.tile(gs.diag([1, 1, -1]), len(mask))
        flipped_left = gs.matmul(left_factor[mask], flip_mat)
        rot_mat[mask] = gs.matmul(flipped_left, right_factor[mask])

    assert rot_mat.ndim == 3
    return rot_mat
def projection(self, mat):
    """
    Project a matrix on SO(n), using the Frobenius norm.

    The input is first converted to a 3D array (a stack of square
    matrices) with gs.to_ndarray; each matrix must be of size
    self.n x self.n.

    When self.n is 3, the result is the product of the two unitary
    factors returned by gs.linalg.svd; entries of the stack whose
    product has a negative determinant are recomputed with
    diag(1, 1, -1) inserted between the two factors.

    Otherwise, each matrix is multiplied on the right by the inverse
    of the matrix square root of (mat^T mat).
    """
    # TODO(nina): projection when the point_type is not 'matrix'?
    mat = gs.to_ndarray(mat, to_ndim=3)
    n_mats, n_rows, n_cols = mat.shape
    assert n_rows == n_cols == self.n

    if self.n != 3:
        # General case: mat * (mat^T mat)^(-1/2), one matrix at a time.
        gram_mats = gs.matmul(gs.transpose(mat, axes=(0, 2, 1)), mat)
        inv_sqrt_mats = gs.zeros_like(mat)
        for idx in range(n_mats):
            gram = gram_mats[idx]
            assert spd_matrices_space.is_symmetric(gram)
            sqrt_gram = spd_matrices_space.sqrtm(gram)
            inv_sqrt_mats[idx] = gs.linalg.inv(sqrt_gram)
        rot_mat = gs.matmul(mat, inv_sqrt_mats)
    else:
        # 3x3 case: built from the SVD factors; the middle output of the
        # decomposition is not needed.
        left_factor, _, right_factor = gs.linalg.svd(mat)
        rot_mat = gs.matmul(left_factor, right_factor)

        # Select the candidates with negative determinant and rebuild
        # them with the sign-flipping diagonal matrix in the middle.
        mask = gs.nonzero(gs.linalg.det(rot_mat) < 0)
        flip_diag = gs.array([1, 1, -1])
        # NOTE(review): if gs.nonzero mirrors numpy.nonzero, mask is a
        # 1-tuple of index arrays, so len(mask) is 1 and the tile leaves
        # the 3x3 matrix unchanged (it then broadcasts in matmul) -
        # confirm against the backend.
        flip_mat = gs.tile(gs.diag(flip_diag), len(mask))
        flipped_left = gs.matmul(left_factor[mask], flip_mat)
        rot_mat[mask] = gs.matmul(flipped_left, right_factor[mask])

    assert gs.ndim(rot_mat) == 3
    return rot_mat
def test_make_symmetric_and_is_symmetric_vectorization(self):
    """
    Check that make_symmetric, applied to a batch of random 5x5
    matrices, yields matrices that is_symmetric accepts.
    """
    sample_count = self.n_samples
    random_mats = gs.random.rand(sample_count, 5, 5)

    symmetrized = spd_matrices_space.make_symmetric(random_mats)
    symmetry_flags = spd_matrices_space.is_symmetric(symmetrized)

    self.assertTrue(gs.all(symmetry_flags))
def test_is_symmetric_vectorization(self):
    """
    Check that is_symmetric accepts every point of a batch drawn
    with self.space.random_uniform.
    """
    sample_count = self.n_samples
    sampled_points = self.space.random_uniform(n_samples=sample_count)

    symmetry_flags = spd_matrices_space.is_symmetric(sampled_points)

    self.assertTrue(gs.all(symmetry_flags))
def test_is_symmetric(self):
    """
    Check is_symmetric on one symmetric 2x2 matrix and on one
    non-symmetric 3x3 matrix.
    """
    symmetric = gs.array([
        [1, 2],
        [2, 1]])
    non_symmetric = gs.array([
        [1., 0.6, -3.],
        [6., -7., 0.],
        [0., 7., 8.]])

    self.assertTrue(spd_matrices_space.is_symmetric(symmetric))
    self.assertFalse(spd_matrices_space.is_symmetric(non_symmetric))