def closest_rotation_matrix(mat):
    """
    Compute the closest rotation matrix of a given matrix mat,
    in terms of the Frobenius norm.

    The result comes from the singular value decomposition mat = U S V.
    U V is the closest orthogonal matrix to mat. When U V has
    determinant -1 (a reflection), the direction associated with the
    smallest singular value is flipped: U diag(1, ..., 1, -1) V. This is
    the closest matrix with determinant +1. The same procedure is used
    for every dimension n, and it also works for singular inputs.

    Parameters
    ----------
    mat : array-like
        Square matrix or batch of square matrices. It is converted to a
        3-d array [n_mats, n, n] by gs.to_ndarray (presumably a single
        [n, n] matrix is promoted to [1, n, n] -- confirm in backend).

    Returns
    -------
    rot_mat : array-like, shape=[n_mats, n, n]
        Closest rotation matrix to each input matrix.
    """
    mat = gs.to_ndarray(mat, to_ndim=3)

    _, mat_dim_1, mat_dim_2 = mat.shape
    assert mat_dim_1 == mat_dim_2

    # With mat = U diag(S) V, the product U V is the orthogonal polar
    # factor of mat, for any n.
    mat_unitary_u, _, mat_unitary_v = gs.linalg.svd(mat)
    rot_mat = gs.matmul(mat_unitary_u, mat_unitary_v)

    # U V may be a reflection. Flip the last singular direction, which
    # assumes singular values are sorted in descending order as numpy's
    # svd guarantees, wherever det(U V) is negative.
    mask = gs.where(gs.linalg.det(rot_mat) < 0)
    flip_diag = gs.diag(gs.array([1.] * (mat_dim_1 - 1) + [-1.]))

    rot_mat[mask] = gs.matmul(
        gs.matmul(mat_unitary_u[mask], flip_diag),
        mat_unitary_v[mask])

    assert rot_mat.ndim == 3
    return rot_mat
Пример #2
0
    def projection(self, mat):
        """
        Project a matrix on SO(n), using the Frobenius norm.

        The projection comes from the singular value decomposition
        mat = U S V. U V is the closest orthogonal matrix. When it has
        determinant -1, the direction of the smallest singular value is
        flipped: U diag(1, ..., 1, -1) V. The result therefore always has
        determinant +1, whatever the dimension self.n. Singular inputs
        are also handled.

        Parameters
        ----------
        mat : array-like
            Square matrix or batch of square matrices of size self.n,
            converted to a 3-d array [n_mats, n, n] by gs.to_ndarray.

        Returns
        -------
        rot_mat : array-like, shape=[n_mats, n, n]
            Projection of each input matrix on SO(n).
        """
        # TODO(nina): projection when the point_type is not 'matrix'?
        mat = gs.to_ndarray(mat, to_ndim=3)

        _, mat_dim_1, mat_dim_2 = mat.shape
        assert mat_dim_1 == mat_dim_2 == self.n

        # With mat = U diag(S) V, U V is the orthogonal polar factor of
        # mat; this holds for any n, so no special case for n == 3.
        mat_unitary_u, _, mat_unitary_v = gs.linalg.svd(mat)
        rot_mat = gs.matmul(mat_unitary_u, mat_unitary_v)

        # Where U V is a reflection, flip the last singular direction.
        # This assumes singular values are sorted in descending order,
        # as numpy's svd guarantees.
        mask = gs.nonzero(gs.linalg.det(rot_mat) < 0)
        flip_diag = gs.diag(gs.array([1.] * (self.n - 1) + [-1.]))

        rot_mat[mask] = gs.matmul(
            gs.matmul(mat_unitary_u[mask], flip_diag),
            mat_unitary_v[mask])

        assert gs.ndim(rot_mat) == 3
        return rot_mat
Пример #3
0
    def test_make_symmetric_and_is_symmetric_vectorization(self):
        """Symmetrizing a batch of random 5x5 matrices yields symmetric ones."""
        random_mats = gs.random.rand(self.n_samples, 5, 5)

        symmetrized = spd_matrices_space.make_symmetric(random_mats)
        every_one_symmetric = gs.all(
            spd_matrices_space.is_symmetric(symmetrized))
        self.assertTrue(every_one_symmetric)
Пример #4
0
 def test_is_symmetric_vectorization(self):
     """Check that every point from self.space.random_uniform is symmetric."""
     sampled = self.space.random_uniform(n_samples=self.n_samples)
     symmetric_flags = spd_matrices_space.is_symmetric(sampled)
     self.assertTrue(gs.all(symmetric_flags))
Пример #5
0
    def test_is_symmetric(self):
        """is_symmetric accepts a symmetric matrix and rejects a non-symmetric one."""
        symmetric = gs.array([[1, 2], [2, 1]])
        self.assertTrue(spd_matrices_space.is_symmetric(symmetric))

        asymmetric = gs.array([
            [1., 0.6, -3.],
            [6., -7., 0.],
            [0., 7., 8.],
        ])
        self.assertFalse(spd_matrices_space.is_symmetric(asymmetric))