def inner(self, z: torch.Tensor, u: torch.Tensor, v=None, *, keepdim=False) -> torch.Tensor:
    r"""
    Inner product for tangent vectors at point :math:`z`.

    For the upper half space model, the inner product at the point z = x + iy
    of the tangent vectors u, v (z, u, v are complex symmetric matrices) is:

        g_{z}(u, v) = tr[ y^-1 u y^-1 \ov{v} ]

    :param z: torch.Tensor point on the manifold: b x 2 x n x n
    :param u: torch.Tensor tangent vector at point :math:`z`: b x 2 x n x n
    :param v: Optional[torch.Tensor] tangent vector at point :math:`z`: b x 2 x n x n
    :param keepdim: bool keep the last dim?
    :return: torch.Tensor inner product (broadcastable): b x 2 x 1 x 1
    """
    v = u if v is None else v
    # y^-1 lifted back to a "complex" tensor with a zero imaginary channel
    y_inverse = torch.inverse(sm.imag(z))
    y_inverse = sm.stick(y_inverse, torch.zeros_like(y_inverse))
    # y^-1 u y^-1 \ov{v}
    product = sm.bmm(sm.bmm3(y_inverse, u, y_inverse), sm.conjugate(v))
    trace = sm.trace(sm.real(product), keepdim=True)  # b x 1
    trace = trace.unsqueeze(-1)                       # b x 1 x 1
    # the real trace is replicated over both channels: b x 2 x 1 x 1
    return sm.stick(trace, trace)
def egrad2rgrad(self, z: torch.Tensor, u: torch.Tensor) -> torch.Tensor:
    r"""
    Transform a gradient computed with autodiff into the Riemannian gradient
    at the point :math:`z`.

    If you have a function f(z) on Mn, then the Riemannian gradient is:

        grad_R(f(z)) = (Id + ẑz) * grad_E(f(z)) * (Id + zẑ)

    :param z: point on the manifold. Shape: (b, 2, n, n)
    :param u: gradient to be projected: Shape: same than z
    :return grad vector in the Riemannian manifold. Shape: same than z
    """
    eye = sm.identity_like(z)
    z_bar = sm.conjugate(z)
    left_factor = eye + sm.bmm(z_bar, z)   # Id + ẑz
    right_factor = eye + sm.bmm(z, z_bar)  # Id + zẑ
    return sm.bmm3(left_factor, u, right_factor)
def get_id_minus_conjugate_z_times_z(z: torch.Tensor):
    r"""
    Compute :math:`Id - \overline{z}z`.

    :param z: b x 2 x n x n
    :return: Id - \overline{z}z, same shape as z
    """
    return sm.subtract(sm.identity_like(z), sm.bmm(sm.conjugate(z), z))
def dist(self, w: torch.Tensor, x: torch.Tensor, *, keepdim=False) -> torch.Tensor:
    """
    Distance between two points of the compact dual, via Takagi factorizations.

    Given W, X in the compact dual:
    1 - TakagiFact(W) -> W = ÛPU*
    2 - U unitary, P diagonal, Û: U conjugate, U*: U conjugate transpose
    3 - Define A = (Id + P^2)^(-1/2)
    4 - Define M = [(A -AP), (AP A)] * [(U^t 0), (0 U)]
    5 - MW = 0 by construction. Y = MX implies
            Y = [(A -AP), (AP A)] * [(U^t 0), (0 U)] * X
            Y = [(A -AP), (AP A)] * U^tXU
        Lets call Q = U^tXU
            Y = (AQ - AP) (APQ + A)^-1
    6 - TakagiFact(Y) = ŜDS*
    7 - Distance = sqrt[ sum ( arctan(d_k)^2 ) ] with d_k the diagonal entries of D

    :param w, x: b x 2 x n x n: elements in the Compact Dual
    :param keepdim: NOTE(review): currently unused — any reduction behavior is
        delegated to ``self.metric.compute_metric``; confirm this is intended.
    :return: distance between w and x in the compact dual
    """
    p, u = self.takagi_factorization.factorize(w)  # p: b x n, u: b x 2 x n x n
    # Define A: since (Id + P^2) is diagonal, taking the matrix sqrt is just the sqrt of the entries
    # Moreover, then taking the inverse of that is taking the inverse of the entries.
    a = 1 + p**2
    a = 1 / torch.sqrt(a)
    a = sm.diag_embed(a)  # A = (Id + P^2)^(-1/2) as a (complex) diagonal matrix
    p = sm.diag_embed(p)  # P as a diagonal matrix
    q = sm.bmm3(sm.transpose(u), x, u)  # Q = U^t X U
    ap = sm.bmm(a, p)
    aq_minus_ap = sm.subtract(sm.bmm(a, q), ap)  # AQ - AP
    apq_plus_a_inv = sm.add(sm.bmm(ap, q), a)    # APQ + A
    apq_plus_a_inv = sm.inverse(apq_plus_a_inv)  # (APQ + A)^-1
    y = sm.bmm(aq_minus_ap, apq_plus_a_inv)  # b x 2 x n x n
    # Only the singular-value part d of Y's Takagi factorization is needed;
    # the unitary factor s is intentionally unused here.
    d, s = self.takagi_factorization.factorize(y)  # d = b x n
    d = torch.atan(d)  # arctan of the diagonal entries (step 7)
    dist = self.metric.compute_metric(d)
    return dist
def inverse_cayley_transform(z: torch.Tensor) -> torch.Tensor:
    r"""
    T^-1(Z): Bounded Domain Model -> Upper Half Space model

        T^-1(Z) = i (Id + Z)(Id - Z)^-1

    :param z: b x 2 x n x n: PRE: z \in Bounded Domain Manifold
    :return: y \in Upper Half Space Manifold
    """
    eye = sm.identity_like(z)
    numerator = sm.multiply_by_i(sm.add(eye, z))      # i (Id + Z)
    denominator_inv = sm.inverse(sm.subtract(eye, z))  # (Id - Z)^-1
    return sm.bmm(numerator, denominator_inv)
def inner(self, z: torch.Tensor, u: torch.Tensor, v=None, *, keepdim=False) -> torch.Tensor:
    r"""
    Inner product for tangent vectors at point :math:`z`.

    For the bounded domain model, the inner product at point z of the tangent
    vectors u, v (z, u, v are complex symmetric matrices) is:

        g_{z}(u, v) = tr[ (Id - ẑz)^-1 u (Id - zẑ)^-1 \ov{v} ]

    :param z: torch.Tensor point on the manifold: b x 2 x n x n
    :param u: torch.Tensor tangent vector at point :math:`z`: b x 2 x n x n
    :param v: Optional[torch.Tensor] tangent vector at point :math:`z`: b x 2 x n x n
    :param keepdim: bool keep the last dim?
    :return: torch.Tensor inner product (broadcastable): b x 2 x 1 x 1
    """
    v = u if v is None else v
    eye = sm.identity_like(z)
    z_bar = sm.conjugate(z)
    left = sm.inverse(sm.subtract(eye, sm.bmm(z_bar, z)))   # (Id - ẑz)^-1
    right = sm.inverse(sm.subtract(eye, sm.bmm(z, z_bar)))  # (Id - zẑ)^-1
    # (Id - ẑz)^-1 u (Id - zẑ)^-1 \ov{v}
    product = sm.bmm(sm.bmm3(left, u, right), sm.conjugate(v))
    trace = sm.trace(sm.real(product), keepdim=True)  # b x 1
    trace = trace.unsqueeze(-1)                       # b x 1 x 1
    # the real trace is replicated over both channels: b x 2 x 1 x 1
    return sm.stick(trace, trace)
def cayley_transform(z: torch.Tensor) -> torch.Tensor:
    r"""
    T(Z): Upper Half Space model -> Bounded Domain Model

        T(Z) = (Z - i Id)(Z + i Id)^-1

    :param z: b x 2 x n x n: PRE: z \in Upper Half Space Manifold
    :return: y \in Bounded Domain Manifold
    """
    eye = sm.identity_like(z)
    # i * Id: built by swapping the channels of the identity
    # (real(Id) = Id goes into the imaginary channel, imag(Id) = 0 into the real one)
    i_times_id = sm.stick(sm.imag(eye), sm.real(eye))
    numerator = sm.subtract(z, i_times_id)               # Z - i Id
    denominator_inv = sm.inverse(sm.add(z, i_times_id))  # (Z + i Id)^-1
    return sm.bmm(numerator, denominator_inv)
def test_bmm(self):
    """Complex batched matmul: (a+bi)(c+di) = (ac - bd) + (ad + bc)i."""
    lhs = sm.stick(
        torch.Tensor([[[1, -3], [5, -7]]]),
        torch.Tensor([[[9, -11], [-14, 15]]]),
    )
    rhs = sm.stick(
        torch.Tensor([[[9, -11], [-14, 15]]]),
        torch.Tensor([[[1, -3], [5, -7]]]),
    )
    expected = sm.stick(
        torch.Tensor([[[97, -106], [82, -97]]]),
        torch.Tensor([[[221, -246], [-366, 413]]]),
    )
    self.assertAllClose(expected, sm.bmm(lhs, rhs))