def inner(self, z: torch.Tensor, u: torch.Tensor, v=None, *, keepdim=False) -> torch.Tensor:
    r"""
    Inner product for tangent vectors at point :math:`z`.

    For the upper half space model, the inner product at point z = x + iy of the
    vectors u, v is (z, u, v are complex symmetric matrices):

        g_{z}(u, v) = tr[ y^-1 u y^-1 \ov{v} ]

    :param z: torch.Tensor point on the manifold: b x 2 x n x n
    :param u: torch.Tensor tangent vector at point :math:`z`: b x 2 x n x n
    :param v: Optional[torch.Tensor] tangent vector at point :math:`z`: b x 2 x n x n.
        When it is ``None``, ``u`` is used in its place.
    :param keepdim: bool, accepted for interface compatibility. It is not read by this
        implementation: the result always has the shape documented below.
    :return: torch.Tensor inner product (broadcastable): b x 2 x 1 x 1
    """
    if v is None:
        v = u

    # y^-1, where y is the imaginary part of z; paired with a zero tensor via
    # sm.stick so it can be multiplied with the b x 2 x n x n operands.
    inv_imag_z = torch.inverse(sm.imag(z))
    inv_imag_z = sm.stick(inv_imag_z, torch.zeros_like(inv_imag_z))

    # y^-1 u y^-1 \ov{v}
    res = sm.bmm3(inv_imag_z, u, inv_imag_z)
    res = sm.bmm(res, sm.conjugate(v))

    # Only the real part of the trace is kept.
    real_part = sm.trace(sm.real(res), keepdim=True)    # b x 1
    real_part = torch.unsqueeze(real_part, -1)          # b x 1 x 1
    # The same values are placed in both slots of the result.
    return sm.stick(real_part, real_part)               # b x 2 x 1 x 1
def projx(self, z: torch.Tensor) -> torch.Tensor:
    r"""
    Project point :math:`z` on the manifold.

    In this space, we need to ensure that Y = Id - \overline(Z)Z is positive definite.

    Steps to project: Z complex symmetric matrix
    1) Z = SDS^-1
    2) D_tilde = clamp(D, max=1 - epsilon)
    3) Z_tilde = Ŝ D_tilde S^*

    As a side effect, ``self.projected_points`` is increased by the number of batch
    entries that had at least one eigenvalue not strictly below ``1 - epsilon``.

    :param z: points to be projected: (b, 2, n, n)
    :return: tensor shaped like ``z``; entries whose eigenvalues were all below the
        threshold are returned as they came out of ``super().projx``, the others are
        replaced by their clamped reconstruction.
    """
    z = super().projx(z)

    eigenvalues, s = self.takagi_factorization.factorize(z)
    threshold = 1 - sm.EPS[z.dtype]
    eigenvalues_tilde = torch.clamp(eigenvalues, max=threshold)

    diag_tilde = sm.diag_embed(eigenvalues_tilde)
    z_tilde = sm.bmm3(sm.conjugate(s), diag_tilde, sm.conj_trans(s))

    # we do this so no operation is applied on the matrices that already belong to the space.
    # This prevents modifying values due to numerical instabilities/floating point ops
    batch_wise_mask = torch.all(eigenvalues < threshold, dim=-1, keepdim=True)
    already_in_space_mask = batch_wise_mask.unsqueeze(-1).unsqueeze(-1).expand_as(z)

    # Tensor.sum() counts the True entries in one vectorised call. The builtin sum()
    # iterates the tensor in Python and yields a plain int (no .item()) when the
    # batch is empty.
    self.projected_points += len(z) - batch_wise_mask.sum().item()

    return torch.where(already_in_space_mask, z, z_tilde)
def test_conjugate(self):
    """
    The imaginary part of ``sm.conjugate(x)`` must equal the negated imaginary part
    of ``x``, and reading the imaginary part of ``x`` again afterwards must give the
    values read before the call.
    """
    points = get_random_symmetric_matrices(10, 4)
    imag_before = sm.imag(points)

    conjugated = sm.conjugate(points)
    imag_of_conjugated = sm.imag(conjugated)

    self.assertAllEqual(-imag_before, imag_of_conjugated)
    self.assertAllEqual(imag_before, sm.imag(points))
def get_id_minus_conjugate_z_times_z(z: torch.Tensor):
    r"""
    Compute Id - \overline(z)z.

    :param z: b x 2 x n x n
    :return: Id - \overline(z)z
    """
    identity = sm.identity_like(z)
    conj_z_z = sm.bmm(sm.conjugate(z), z)
    return sm.subtract(identity, conj_z_z)
def test_takagi_factorization_very_large_values(self):
    """
    Factorize random symmetric matrices scaled by 1000 and check that
    conj(s) * diag(eigenvalues) * conj_trans(s) reproduces the input.
    """
    matrices = get_random_symmetric_matrices(3, 3) * 1000

    factorizer = TakagiFactorization(3)
    eigenvalues, s = factorizer.factorize(matrices)

    # Diagonal matrix of eigenvalues, paired with a zero tensor through sm.stick.
    real_diagonal = torch.diag_embed(eigenvalues)
    zero_diagonal = torch.zeros_like(real_diagonal)
    diagonal = sm.stick(real_diagonal, zero_diagonal)

    reconstruction = sm.bmm3(sm.conjugate(s), diagonal, sm.conj_trans(s))
    self.assertAllClose(matrices, reconstruction)
def test_takagi_factorization_real_neg_imag_neg(self):
    """
    Factorize random symmetric matrices whose real and imaginary parts have both
    been negated, and check that conj(s) * diag(eigenvalues) * conj_trans(s)
    reproduces the input.
    """
    base = get_random_symmetric_matrices(3, 3)
    negated_real = sm.real(base) * -1
    negated_imag = sm.imag(base) * -1
    matrices = sm.stick(negated_real, negated_imag)

    factorizer = TakagiFactorization(3)
    eigenvalues, s = factorizer.factorize(matrices)

    # Diagonal matrix of eigenvalues, paired with a zero tensor through sm.stick.
    real_diagonal = torch.diag_embed(eigenvalues)
    zero_diagonal = torch.zeros_like(real_diagonal)
    diagonal = sm.stick(real_diagonal, zero_diagonal)

    reconstruction = sm.bmm3(sm.conjugate(s), diagonal, sm.conj_trans(s))
    self.assertAllClose(matrices, reconstruction)
def test_takagi_factorization_real_identity(self):
    """
    Factorize the output of ``sm.identity_like`` and check three things: the
    reconstruction matches the input, ``s`` matches the input, and every
    eigenvalue is one.
    """
    template = get_random_symmetric_matrices(3, 3)
    identity = sm.identity_like(template)

    factorizer = TakagiFactorization(3)
    eigenvalues, s = factorizer.factorize(identity)

    # Diagonal matrix of eigenvalues, paired with a zero tensor through sm.stick.
    real_diagonal = torch.diag_embed(eigenvalues)
    zero_diagonal = torch.zeros_like(real_diagonal)
    diagonal = sm.stick(real_diagonal, zero_diagonal)

    reconstruction = sm.bmm3(sm.conjugate(s), diagonal, sm.conj_trans(s))
    self.assertAllClose(identity, reconstruction)
    self.assertAllClose(identity, s)

    expected_eigenvalues = torch.ones_like(eigenvalues)
    self.assertAllClose(expected_eigenvalues, eigenvalues)
def inner(self, z: torch.Tensor, u: torch.Tensor, v=None, *, keepdim=False) -> torch.Tensor:
    r"""
    Inner product for tangent vectors at point :math:`z`.

    For the bounded domain model, the inner product at point z of the vectors u, v
    is (z, u, v are complex symmetric matrices):

        g_{z}(u, v) = tr[ (Id - ẑz)^-1 u (Id - zẑ)^-1 \ov{v} ]

    :param z: torch.Tensor point on the manifold: b x 2 x n x n
    :param u: torch.Tensor tangent vector at point :math:`z`: b x 2 x n x n
    :param v: Optional[torch.Tensor] tangent vector at point :math:`z`: b x 2 x n x n.
        When it is ``None``, ``u`` is used in its place.
    :param keepdim: bool, accepted for interface compatibility. It is not read by this
        implementation: the result always has the shape documented below.
    :return: torch.Tensor inner product (broadcastable): b x 2 x 1 x 1
    """
    if v is None:
        v = u

    identity = sm.identity_like(z)
    conj_z = sm.conjugate(z)
    conj_z_z = sm.bmm(conj_z, z)
    z_conj_z = sm.bmm(z, conj_z)

    # Keep the matrices and their inverses under distinct names so every name
    # describes the value it is bound to.
    id_minus_conj_z_z = sm.subtract(identity, conj_z_z)
    id_minus_z_conj_z = sm.subtract(identity, z_conj_z)
    inv_id_minus_conj_z_z = sm.inverse(id_minus_conj_z_z)
    inv_id_minus_z_conj_z = sm.inverse(id_minus_z_conj_z)

    # (Id - ẑz)^-1 u (Id - zẑ)^-1 \ov{v}
    res = sm.bmm3(inv_id_minus_conj_z_z, u, inv_id_minus_z_conj_z)
    res = sm.bmm(res, sm.conjugate(v))

    # Only the real part of the trace is kept.
    real_part = sm.trace(sm.real(res), keepdim=True)
    real_part = torch.unsqueeze(real_part, -1)          # b x 1 x 1
    # The same values are placed in both slots of the result.
    return sm.stick(real_part, real_part)               # b x 2 x 1 x 1
def test_takagi_factorization_real_diagonal(self):
    """
    Factorize matrices that keep only the entries where ``sm.identity_like`` equals
    one (everything else is zeroed) and check the reconstruction and the shape of
    the resulting ``s``.
    """
    a = get_random_symmetric_matrices(3, 3) * 10
    # Keep only the positions where the identity-like tensor is 1.
    a = torch.where(sm.identity_like(a) == 1, a, torch.zeros_like(a))

    eigenvalues, s = TakagiFactorization(3).factorize(a)

    diagonal = torch.diag_embed(eigenvalues)
    diagonal = sm.stick(diagonal, torch.zeros_like(diagonal))

    self.assertAllClose(a, sm.bmm3(sm.conjugate(s), diagonal, sm.conj_trans(s)))

    # real part of eigenvectors is made of vectors with one 1 and all zeros
    real_part = torch.sum(torch.abs(sm.real(s)), dim=-1)
    self.assertAllClose(torch.ones_like(real_part), real_part)

    # imaginary part of eigenvectors is all zeros.
    # Sum absolute values: a signed sum would also be zero when non-zero entries
    # of opposite sign cancel, letting a wrong factorization pass.
    imag_magnitude = torch.sum(torch.abs(sm.imag(s)))
    self.assertAllClose(torch.zeros(1), imag_magnitude)
def egrad2rgrad(self, z: torch.Tensor, u: torch.Tensor) -> torch.Tensor:
    """
    Transform gradient computed using autodiff to the correct Riemannian gradient
    for the point :math:`z`.

    If you have a function f(z) on Mn, then the Riemannian gradient is

        grad_R(f(z)) = (Id + ẑz) * grad_E(f(z)) * (Id + zẑ)

    :param z: point on the manifold. Shape: (b, 2, n, n)
    :param u: gradient to be projected: Shape: same than z
    :return grad vector in the Riemannian manifold. Shape: same than z
    """
    # Named `identity` rather than `id` so the builtin is not shadowed.
    identity = sm.identity_like(z)
    conj_z = sm.conjugate(z)

    id_plus_conj_z_z = identity + sm.bmm(conj_z, z)
    id_plus_z_conj_z = identity + sm.bmm(z, conj_z)

    return sm.bmm3(id_plus_conj_z_z, u, id_plus_z_conj_z)