Exemplo n.º 1
0
    def projx(self, z: torch.Tensor) -> torch.Tensor:
        r"""
        Project point :math:`z` on the manifold.

        In this space, we need to ensure that Y = Id - \overline{Z}Z is positive definite.

        Steps to project (Z is a complex symmetric matrix):
        1) Takagi factorization: Z = ŜDS^*
        2) D_tilde = clamp(D, max=1 - epsilon)
        3) Z_tilde = Ŝ D_tilde S^*

        Matrices whose Takagi values are all already strictly below 1 - epsilon are returned
        unchanged; every other matrix is replaced by Z_tilde.

        :param z: points to be projected: (b, 2, n, n)
        :return: projected points, same shape as z
        """
        z = super().projx(z)
        upper_bound = 1 - sm.EPS[z.dtype]

        eigenvalues, s = self.takagi_factorization.factorize(z)
        eigenvalues_tilde = torch.clamp(eigenvalues, max=upper_bound)

        diag_tilde = sm.diag_embed(eigenvalues_tilde)

        z_tilde = sm.bmm3(sm.conjugate(s), diag_tilde, sm.conj_trans(s))

        # we do this so no operation is applied on the matrices that already belong to the space.
        # This prevents modifying values due to numerical instabilities/floating point ops
        batch_wise_mask = torch.all(eigenvalues < upper_bound, dim=-1, keepdim=True)
        already_in_space_mask = batch_wise_mask.unsqueeze(-1).unsqueeze(-1).expand_as(z)

        # Tensor.sum() reduces in a single op; the builtin sum() would iterate the batch in
        # Python and return a plain int 0 (which has no .item()) for an empty batch.
        self.projected_points += len(z) - batch_wise_mask.sum().item()

        return torch.where(already_in_space_mask, z, z_tilde)
Exemplo n.º 2
0
    def inner(self,
              z: torch.Tensor,
              u: torch.Tensor,
              v=None,
              *,
              keepdim=False) -> torch.Tensor:
        r"""
        Inner product for tangent vectors at point :math:`z`.
        For the upper half space model, the inner product at point z = x + iy of the vectors u, v
        is (z, u, v are complex symmetric matrices):

        g_{z}(u, v) = Re tr[ y^-1 u y^-1 \overline{v} ]

        :param z: torch.Tensor point on the manifold: b x 2 x n x n
        :param u: torch.Tensor tangent vector at point :math:`z`: b x 2 x n x n
        :param v: Optional[torch.Tensor] tangent vector at point :math:`z`: b x 2 x n x n.
            Defaults to ``u``, i.e. the inner product of ``u`` with itself.
        :param keepdim: not used by this implementation
        :return: torch.Tensor inner product (broadcastable): b x 2 x 1 x 1. The real part of the
            trace is written into both the real and the imaginary slot.
        """
        if v is None:
            v = u
        # y^-1 is a real matrix; embed it as complex with a zero imaginary part
        inv_imag_z = torch.inverse(sm.imag(z))
        inv_imag_z = sm.stick(inv_imag_z, torch.zeros_like(inv_imag_z))

        res = sm.bmm3(inv_imag_z, u, inv_imag_z)
        res = sm.bmm(res, sm.conjugate(v))
        real_part = sm.trace(sm.real(res), keepdim=True)  # b x 1
        real_part = torch.unsqueeze(real_part, -1)  # b x 1 x 1
        return sm.stick(real_part, real_part)  # b x 2 x 1 x 1
Exemplo n.º 3
0
    def test_takagi_factorization_very_large_values(self):
        # Random symmetric matrices scaled by 1000 must still be reconstructed as conj(S) D S^*.
        matrix = get_random_symmetric_matrices(3, 3) * 1000

        takagi_values, unitary = TakagiFactorization(3).factorize(matrix)

        real_diag = torch.diag_embed(takagi_values)
        complex_diag = sm.stick(real_diag, torch.zeros_like(real_diag))
        reconstruction = sm.bmm3(sm.conjugate(unitary), complex_diag, sm.conj_trans(unitary))

        self.assertAllClose(matrix, reconstruction)
Exemplo n.º 4
0
    def test_takagi_factorization_real_neg_imag_neg(self):
        # Negate both the real and the imaginary parts of a random symmetric matrix and check
        # that the Takagi factors still reconstruct it as conj(S) D S^*.
        base = get_random_symmetric_matrices(3, 3)
        negated = sm.stick(sm.real(base) * -1, sm.imag(base) * -1)

        takagi_values, unitary = TakagiFactorization(3).factorize(negated)

        real_diag = torch.diag_embed(takagi_values)
        complex_diag = sm.stick(real_diag, torch.zeros_like(real_diag))
        reconstruction = sm.bmm3(sm.conjugate(unitary), complex_diag, sm.conj_trans(unitary))

        self.assertAllClose(negated, reconstruction)
Exemplo n.º 5
0
    def test_takagi_factorization_real_identity(self):
        # For the identity, the factorization must be trivial: S is the identity and all
        # Takagi values are 1, besides reconstructing the input.
        identity = sm.identity_like(get_random_symmetric_matrices(3, 3))

        values, unitary = TakagiFactorization(3).factorize(identity)

        diag = torch.diag_embed(values)
        diag = sm.stick(diag, torch.zeros_like(diag))
        rebuilt = sm.bmm3(sm.conjugate(unitary), diag, sm.conj_trans(unitary))

        self.assertAllClose(identity, rebuilt)
        self.assertAllClose(identity, unitary)
        self.assertAllClose(torch.ones_like(values), values)
Exemplo n.º 6
0
    def egrad2rgrad(self, z: torch.Tensor, u: torch.Tensor) -> torch.Tensor:
        """
        Transform gradient computed using autodiff to the correct Riemannian gradient for the point :math:`z`.

        With the metric g_z(u, v) = Re tr[ (Id - ẑz)^-1 u (Id - zẑ)^-1 conj(v) ] (ẑ: conjugate of z),
        the Riemannian gradient of a function f(z) on the bounded domain is
            grad_R(f(z)) = (Id - ẑz) * grad_E(f(z)) * (Id - zẑ)
        The left and right factors differ: for symmetric z, (Id - zẑ) is the transpose of (Id - ẑz),
        so a symmetric Euclidean gradient is mapped to a symmetric (tangent) matrix. Using
        (Id - ẑz) on both sides would not preserve symmetry.

        :param z: point on the manifold. Shape: (b, 2, n, n)
        :param u: gradient to be projected: Shape: same as z
        :return grad vector in the Riemannian manifold. Shape: same as z
        """
        identity = sm.identity_like(z)
        conj_z = sm.conjugate(z)
        id_minus_conj_z_z = sm.subtract(identity, sm.bmm(conj_z, z))
        id_minus_z_conj_z = sm.subtract(identity, sm.bmm(z, conj_z))
        return sm.bmm3(id_minus_conj_z_z, u, id_minus_z_conj_z)
Exemplo n.º 7
0
    def test_takagi_factorization_real_diagonal(self):
        # Keep only the entries where identity_like(...) equals 1, zeroing everything else.
        full = get_random_symmetric_matrices(3, 3) * 10
        on_diagonal = sm.identity_like(full) == 1
        diagonal_matrix = torch.where(on_diagonal, full, torch.zeros_like(full))

        values, unitary = TakagiFactorization(3).factorize(diagonal_matrix)

        diag = torch.diag_embed(values)
        diag = sm.stick(diag, torch.zeros_like(diag))
        rebuilt = sm.bmm3(sm.conjugate(unitary), diag, sm.conj_trans(unitary))
        self.assertAllClose(diagonal_matrix, rebuilt)

        # The absolute values of the real part sum to 1 along the last dim
        # (expected: a single +-1 per eigenvector, zeros elsewhere).
        row_abs_sums = torch.sum(torch.abs(sm.real(unitary)), dim=-1)
        self.assertAllClose(torch.ones_like(row_abs_sums), row_abs_sums)
        # The imaginary part of the eigenvectors is expected to vanish; checked via its total sum.
        self.assertAllClose(torch.zeros(1), torch.sum(sm.imag(unitary)))
Exemplo n.º 8
0
    def egrad2rgrad(self, z: torch.Tensor, u: torch.Tensor) -> torch.Tensor:
        """
        Transform gradient computed using autodiff to the correct Riemannian gradient for the point :math:`z`.

        If you have a function f(z) on Mn, then the Riemannian gradient is
            grad_R(f(z)) = (Id + ẑz) * grad_E(f(z)) * (Id + zẑ)
        where ẑ is the complex conjugate of z.

        :param z: point on the manifold. Shape: (b, 2, n, n)
        :param u: gradient to be projected: Shape: same as z
        :return grad vector in the Riemannian manifold. Shape: same as z
        """
        # named `identity` (not `id`) so the builtin id() is not shadowed
        identity = sm.identity_like(z)
        conj_z = sm.conjugate(z)
        id_plus_conj_z_z = identity + sm.bmm(conj_z, z)
        id_plus_z_conj_z = identity + sm.bmm(z, conj_z)
        return sm.bmm3(id_plus_conj_z_z, u, id_plus_z_conj_z)
Exemplo n.º 9
0
    def test_bmm3(self):
        # Checks the batched triple product x @ y @ z on a single 2x2 complex example
        # whose expected value was computed by hand.
        def complex_batch(real, imag):
            return sm.stick(torch.Tensor(real), torch.Tensor(imag))

        x = complex_batch([[[1, -3], [5, -7]]], [[[9, -11], [-14, 15]]])
        y = complex_batch([[[9, -11], [-14, 15]]], [[[1, -3], [5, -7]]])
        z = complex_batch([[[-3, -1], [-2, 5]]], [[[-1, 3], [0, -2]]])

        expected = complex_batch([[[142, -1782], [-418, 1357]]],
                                 [[[-268, -948], [190, 2871]]])

        self.assertAllClose(expected, sm.bmm3(x, y, z))
Exemplo n.º 10
0
    def dist(self, w: torch.Tensor, x: torch.Tensor, *, keepdim=False) -> torch.Tensor:
        """
        Distance between two points of the compact dual.

        Given W, X in the compact dual:
            1 - TakagiFact(W) -> W = ÛPU*
            2 - U unitary, P diagonal, Û: U conjugate, U*: U conjugate transpose
            3 - Define A = (Id + P^2)^(-1/2)
            4 - Define M = [(A  -AP), (AP  A)] * [(U^t  0), (0  U)]
            5 - MW = 0 by construction. Y = MX implies
                Y = [(A  -AP), (AP  A)] * [(U^t  0), (0  U)] * X
                Y = [(A  -AP), (AP  A)] * U^tXU        Let's call Q = U^tXU
                Y = (AQ - AP) (APQ + A)^-1
            6 - TakagiFact(Y) = ŜDS*
            7 - Distance = self.metric applied to arctan(d_k), with d_k the diagonal entries of D
                (e.g. sqrt[ sum ( arctan(d_k)^2 ) ] for the Riemannian distance)

        :param w, x: b x 2 x n x n: elements in the Compact Dual
        :param keepdim: not used by this implementation
        :return: distance between w and x in the compact dual, as computed by ``self.metric``
        """
        p, u = self.takagi_factorization.factorize(w)  # p: b x n, u: b x 2 x n x n

        # A = (Id + P^2)^(-1/2): since (Id + P^2) is diagonal, its matrix square root and
        # its inverse reduce to element-wise operations on the diagonal entries.
        a = 1 / torch.sqrt(1 + p**2)
        a = sm.diag_embed(a)
        p = sm.diag_embed(p)

        q = sm.bmm3(sm.transpose(u), x, u)                  # Q = U^t X U
        ap = sm.bmm(a, p)
        aq_minus_ap = sm.subtract(sm.bmm(a, q), ap)         # AQ - AP
        apq_plus_a = sm.add(sm.bmm(ap, q), a)               # APQ + A
        y = sm.bmm(aq_minus_ap, sm.inverse(apq_plus_a))     # b x 2 x n x n

        d, _ = self.takagi_factorization.factorize(y)       # d: b x n; eigenvectors unused

        return self.metric.compute_metric(torch.atan(d))
Exemplo n.º 11
0
    def inner(self,
              z: torch.Tensor,
              u: torch.Tensor,
              v=None,
              *,
              keepdim=False) -> torch.Tensor:
        r"""
        Inner product for tangent vectors at point :math:`z`.
        For the bounded domain model, the inner product at point z of the vectors u, v
        is (z, u, v are complex symmetric matrices):

        g_{z}(u, v) = Re tr[ (Id -  ẑz)^-1 u (Id - zẑ)^-1 \overline{v} ]

        :param z: torch.Tensor point on the manifold: b x 2 x n x n
        :param u: torch.Tensor tangent vector at point :math:`z`: b x 2 x n x n
        :param v: Optional[torch.Tensor] tangent vector at point :math:`z`: b x 2 x n x n.
            Defaults to ``u``, i.e. the inner product of ``u`` with itself.
        :param keepdim: not used by this implementation
        :return: torch.Tensor inner product (broadcastable): b x 2 x 1 x 1. The real part of the
            trace is written into both the real and the imaginary slot.
        """
        if v is None:
            v = u
        identity = sm.identity_like(z)
        conj_z = sm.conjugate(z)
        conj_z_z = sm.bmm(conj_z, z)
        z_conj_z = sm.bmm(z, conj_z)

        id_minus_conj_z_z = sm.subtract(identity, conj_z_z)
        id_minus_z_conj_z = sm.subtract(identity, z_conj_z)

        inv_id_minus_conj_z_z = sm.inverse(id_minus_conj_z_z)
        inv_id_minus_z_conj_z = sm.inverse(id_minus_z_conj_z)

        res = sm.bmm3(inv_id_minus_conj_z_z, u, inv_id_minus_z_conj_z)
        res = sm.bmm(res, sm.conjugate(v))
        real_part = sm.trace(sm.real(res), keepdim=True)
        real_part = torch.unsqueeze(real_part, -1)  # b x 1 x 1
        return sm.stick(real_part, real_part)  # b x 2 x 1 x 1