def test_multiply_by_i(self):
    # (a + ib) * i == -b + ia: the real part becomes -Im(x), the imaginary part becomes Re(x).
    x = get_random_symmetric_matrices(10, 4)
    product = sm.multiply_by_i(x)
    self.assertAllEqual(sm.real(product), -sm.imag(x))
    self.assertAllEqual(sm.imag(product), sm.real(x))
def projx(self, z: torch.Tensor) -> torch.Tensor:
    """
    Project point :math:`z` on the manifold.

    In this space, we need to ensure that Y = Im(X) is positive definite.
    Since the matrix Y is symmetric, it is possible to diagonalize it.
    For a diagonal matrix the condition is just that all diagonal entries are positive,
    so we clamp the values that are <=0 in the diagonal to an epsilon, and then restore
    the matrix back into non-diagonal form using the base change matrix that was obtained
    from the diagonalization.

    Steps to project: Y = Im(z)
        1) Y = SDS^-1
        2) D_tilde = clamp(D, min=epsilon)
        3) Y_tilde = SD_tildeS^-1

    Side effect: increments ``self.projected_points`` by ``len(z)`` minus the number of
    entries set in the mask returned by ``sm.positive_conjugate_projection``.

    :param z: points to be projected: (b, 2, n, n)
    :return: projected points with the real part of ``z`` kept and the imaginary part
        replaced by its projection: (b, 2, n, n)
    """
    z = super().projx(z)
    y = sm.imag(z)
    y_tilde, batchwise_mask = sm.positive_conjugate_projection(y)
    # NOTE(review): batchwise_mask presumably flags the points that did NOT need projection,
    # so this counts how many were modified — confirm against sm.positive_conjugate_projection.
    # Reduce with Tensor.sum: the builtin sum() would iterate the tensor element by element in Python.
    self.projected_points += len(z) - batchwise_mask.sum().item()
    return sm.stick(sm.real(z), y_tilde)
def inner(self, z: torch.Tensor, u: torch.Tensor, v=None, *, keepdim=False) -> torch.Tensor:
    r"""
    Inner product for tangent vectors at point :math:`z`.

    For the upper half space model, the inner product at point z = x + iy of the vectors u, v
    it is (z, u, v are complex symmetric matrices):

        g_{z}(u, v) = tr[ y^-1 u y^-1 \ov{v} ]

    Only the real part of the trace is computed.

    :param z: torch.Tensor point on the manifold: b x 2 x n x n
    :param u: torch.Tensor tangent vector at point :math:`z`: b x 2 x n x n
    :param v: Optional[torch.Tensor] tangent vector at point :math:`z`: b x 2 x n x n.
        When omitted, ``u`` is used, giving the squared norm of ``u``.
    :param keepdim: bool, accepted for interface compatibility; it is not used and the
        result always has shape b x 2 x 1 x 1
    :return: torch.Tensor inner product (broadcastable): b x 2 x 1 x 1. The same real value
        is stored in both the real and the imaginary slot.
    """
    if v is None:
        v = u
    inv_imag_z = torch.inverse(sm.imag(z))
    # Embed the real matrix y^-1 as a complex matrix (zero imaginary part) so it can be
    # multiplied with the complex tangent vectors.
    inv_imag_z = sm.stick(inv_imag_z, torch.zeros_like(inv_imag_z))
    res = sm.bmm3(inv_imag_z, u, inv_imag_z)
    res = sm.bmm(res, sm.conjugate(v))
    real_part = sm.trace(sm.real(res), keepdim=True)  # b x 1
    real_part = torch.unsqueeze(real_part, -1)  # b x 1 x 1
    return sm.stick(real_part, real_part)  # b x 2 x 1 x 1
def test_stick(self):
    # Real and imaginary parts are b x n x n matrices; sticking them must produce a
    # b x 2 x n x n complex point, from which both parts are recovered unchanged.
    x_real = torch.rand(10, 4, 4)
    x_imag = torch.rand(10, 4, 4)
    x = sm.stick(x_real, x_imag)
    self.assertEqual((10, 2, 4, 4), tuple(x.shape))
    self.assertAllEqual(x_real, sm.real(x))
    self.assertAllEqual(x_imag, sm.imag(x))
def test_distance_is_symmetric_real_neg_imag_pos(self):
    # dist(a, b) must equal dist(b, a) even when a's real part has been negated.
    first = self.manifold.random(10)
    first = sm.stick(-sm.real(first), sm.imag(first))
    second = self.manifold.random(10)
    forward = self.manifold.dist(first, second)
    backward = self.manifold.dist(second, first)
    self.assertAllClose(forward, backward)
def test_to_symmetric(self):
    # Both the real and the imaginary part of a generated matrix must equal their own transpose.
    x = get_random_symmetric_matrices(10, 4)
    for part in (sm.real(x), sm.imag(x)):
        self.assertAllEqual(part, part.transpose(-1, -2))
def test_distance_is_symmetric_only_imaginary_matrices(self):
    # Distance symmetry for points whose real part is zeroed out.
    a = self.manifold.random(10)
    b = self.manifold.random(10)
    zero_real = torch.zeros_like(sm.real(a))

    def purely_imaginary(point):
        return sm.stick(zero_real, sm.imag(point))

    a = purely_imaginary(a)
    b = purely_imaginary(b)
    self.assertAllClose(self.manifold.dist(a, b), self.manifold.dist(b, a))
def test_takagi_factorization_real_neg_imag_neg(self):
    # Negate both parts, factorize, and check that conj(S) D S^H rebuilds the input.
    a = get_random_symmetric_matrices(3, 3)
    a = sm.stick(-sm.real(a), -sm.imag(a))
    eigenvalues, s = TakagiFactorization(3).factorize(a)
    real_diag = torch.diag_embed(eigenvalues)
    complex_diag = sm.stick(real_diag, torch.zeros_like(real_diag))
    rebuilt = sm.bmm3(sm.conjugate(s), complex_diag, sm.conj_trans(s))
    self.assertAllClose(a, rebuilt)
def test_proj_x_real_neg_imag_neg(self):
    # Negating both parts moves the points off the manifold; projx must bring them back.
    points = get_random_symmetric_matrices(10, self.dims)
    points = sm.stick(-sm.real(points), -sm.imag(points))
    projected = self.manifold.projx(points)
    # the projection keeps the matrices symmetric
    self.assertAllClose(projected, sm.transpose(projected))
    # and every projected point lies on the manifold
    for p in projected:
        self.assertTrue(self.manifold.check_point_on_manifold(p))
def test_conj_transpose(self):
    # The conjugate transpose transposes both parts and negates the imaginary one.
    x = get_random_symmetric_matrices(10, 4)
    manual = sm.stick(sm.real(x).transpose(-1, -2), -sm.imag(x).transpose(-1, -2))
    self.assertAllEqual(manual, sm.conj_trans(x))
def test_transpose(self):
    # sm.transpose must transpose the real and the imaginary part independently.
    x = get_random_symmetric_matrices(10, 4)
    transposed = sm.transpose(x)
    manual = sm.stick(sm.real(x).transpose(-1, -2), sm.imag(x).transpose(-1, -2))
    self.assertAllEqual(manual, transposed)
    # the inputs are symmetric, so transposing leaves them unchanged
    self.assertAllEqual(x, transposed)
def test_cayley_transform_real_neg_imag_pos(self):
    # Round trip Upper Half Space -> Bounded Domain -> Upper Half Space on points whose
    # real part has been negated.
    x = self.upper_half_manifold.random(10)
    x = sm.stick(-sm.real(x), sm.imag(x))
    in_bounded = cayley_transform(x)
    round_trip = inverse_cayley_transform(in_bounded)
    self.assertAllClose(x, round_trip)
    # intermediate points belong to the Bounded Domain manifold
    for p in in_bounded:
        self.assertTrue(self.bounded_manifold.check_point_on_manifold(p))
    # recovered points belong to the Upper Half Space manifold
    for p in round_trip:
        self.assertTrue(self.upper_half_manifold.check_point_on_manifold(p))
def cayley_transform(z: torch.Tensor) -> torch.Tensor:
    r"""
    T(Z): Upper Half Space model -> Bounded Domain Model
    T(Z) = (Z - i Id)(Z + i Id)^-1

    :param z: b x 2 x n x n: PRE: z \in Upper Half Space Manifold
    :return: y \in Bounded Domain Manifold
    """
    identity = sm.identity_like(z)
    # Swap the real and imaginary parts of Id to obtain i * Id.
    # NOTE(review): relies on sm.identity_like having a zero imaginary part — confirm.
    i_identity = sm.stick(sm.imag(identity), sm.real(identity))
    z_minus_id = sm.subtract(z, i_identity)
    inv_z_plus_id = sm.inverse(sm.add(z, i_identity))
    return sm.bmm(z_minus_id, inv_z_plus_id)
def test_takagi_factorization_real_diagonal(self):
    a = get_random_symmetric_matrices(3, 3) * 10
    # Keep only the entries where identity_like equals 1 and zero everything else.
    # NOTE(review): assumes identity_like is 1 only on the real diagonal, which makes
    # `a` a real diagonal matrix — confirm.
    a = torch.where(sm.identity_like(a) == 1, a, torch.zeros_like(a))
    eigenvalues, s = TakagiFactorization(3).factorize(a)
    diagonal = torch.diag_embed(eigenvalues)
    diagonal = sm.stick(diagonal, torch.zeros_like(diagonal))
    self.assertAllClose(a, sm.bmm3(sm.conjugate(s), diagonal, sm.conj_trans(s)))

    # real part of eigenvectors is made of vectors with one 1 and all zeros
    real_part = torch.sum(torch.abs(sm.real(s)), dim=-1)
    self.assertAllClose(torch.ones_like(real_part), real_part)

    # imaginary part of eigenvectors is all zeros; compared element-wise so that entries
    # of opposite sign cannot cancel each other out in a signed sum
    imag_part = sm.imag(s)
    self.assertAllClose(torch.zeros_like(imag_part), imag_part)
def egrad2rgrad(self, z: torch.Tensor, u: torch.Tensor) -> torch.Tensor:
    """
    Transform gradient computed using autodiff to the correct Riemannian gradient for the point :math:`z`.

    If you have a function f(z) on Hn, then the Riemannian gradient is y * grad_eucl(f(z)) * y,
    where y is the imaginary part of z, and multiplication is matrix multiplication.
    The product is applied separately to the real and the imaginary part of ``u``.

    :param z: point on the manifold. Shape: (b, 2, n, n)
    :param u: Euclidean gradient to be transformed. Shape: same as z
    :return: gradient vector in the Riemannian manifold. Shape: same as z
    """
    y = sm.imag(z)
    # `@` (torch.matmul) is identical to bmm for 3-D inputs and additionally
    # accepts extra leading batch dimensions.
    real_grad = y @ sm.real(u) @ y
    imag_grad = y @ sm.imag(u) @ y
    return sm.stick(real_grad, imag_grad)
def test_to_hermitian(self):
    # A Hermitian matrix has a symmetric real part and an antisymmetric imaginary part.
    h = sm.to_hermitian(torch.rand(4, 2, 3, 3))
    real_part = sm.real(h)
    imag_part = sm.imag(h)

    # 1 - the real part equals its transpose
    self.assertAllEqual(real_part, real_part.transpose(-1, -2))

    # 2 - the imaginary diagonal is zero
    diag = imag_part.diagonal(dim1=-2, dim2=-1)
    self.assertAllEqual(diag, torch.zeros_like(diag))

    # 3 - strictly-upper imaginary entries are the negated strictly-lower ones
    upper = imag_part.triu(diagonal=1)
    lower = imag_part.tril(diagonal=-1)
    self.assertAllEqual(upper, -lower.transpose(-1, -2))
def inner(self, z: torch.Tensor, u: torch.Tensor, v=None, *, keepdim=False) -> torch.Tensor:
    r"""
    Inner product for tangent vectors at point :math:`z`.

    For the bounded domain model, the inner product at point z of the vectors u, v
    it is (z, u, v are complex symmetric matrices, ẑ is the complex conjugate of z):

        g_{z}(u, v) = tr[ (Id - ẑz)^-1 u (Id - zẑ)^-1 \ov{v} ]

    Only the real part of the trace is computed.

    :param z: torch.Tensor point on the manifold: b x 2 x n x n
    :param u: torch.Tensor tangent vector at point :math:`z`: b x 2 x n x n
    :param v: Optional[torch.Tensor] tangent vector at point :math:`z`: b x 2 x n x n.
        When omitted, ``u`` is used, giving the squared norm of ``u``.
    :param keepdim: bool, accepted for interface compatibility; it is not used and the
        result always has shape b x 2 x 1 x 1
    :return: torch.Tensor inner product (broadcastable): b x 2 x 1 x 1. The same real value
        is stored in both the real and the imaginary slot.
    """
    if v is None:
        v = u
    identity = sm.identity_like(z)
    conj_z = sm.conjugate(z)
    # Left factor (Id - ẑz)^-1 and right factor (Id - zẑ)^-1.
    inv_id_minus_conj_z_z = sm.inverse(sm.subtract(identity, sm.bmm(conj_z, z)))
    inv_id_minus_z_conj_z = sm.inverse(sm.subtract(identity, sm.bmm(z, conj_z)))
    res = sm.bmm3(inv_id_minus_conj_z_z, u, inv_id_minus_z_conj_z)
    res = sm.bmm(res, sm.conjugate(v))
    real_part = sm.trace(sm.real(res), keepdim=True)
    real_part = torch.unsqueeze(real_part, -1)  # b x 1 x 1
    return sm.stick(real_part, real_part)  # b x 2 x 1 x 1