示例#1
0
 def pdist(self, x, squared=False):
     """Evaluate ``_norm_log`` over every unordered pair of matrices in ``x``.

     ``x`` must be a 3-D tensor whose first dimension indexes the matrices.
     Pairs ``(i, j)`` with ``i < j`` are enumerated through the strict upper
     triangle; for each pair, the first value returned by ``self.invchol``
     at index ``i`` is combined with ``x[j]`` via ``tb.axat``.

     NOTE(review): presumably this yields the pairwise (optionally squared)
     distances in condensed, row-major upper-triangular order -- confirm
     against ``_norm_log``.
     """
     assert x.ndim == 3
     num_points = x.shape[0]
     # Strict upper triangle (offset 1) skips the diagonal and mirrored pairs.
     pair_idx = torch.triu_indices(
         num_points, num_points, offset=1, device=x.device)
     first, second = pair_idx[0], pair_idx[1]
     chol_inv = self.invchol(x)[0]
     whitened = tb.axat(chol_inv[first], x[second])
     return self._norm_log(whitened, squared=squared)
示例#2
0
    def exp(self, x, u):
        """Map ``u`` through the symmetric matrix exponential at ``x``.

        ``u`` is first transformed by ``_lult`` (which also returns the
        factor ``l`` of ``x``), ``tb.symexpm`` is applied to the result, and
        the outcome is transformed back with ``tb.axat(l, ...)``.

        NOTE(review): presumably this is the exponential map of the manifold
        at ``x`` -- confirm against the ``tb`` helpers.
        """
        # TODO(ccruceru): Replace this with :math:`X \exp(X^{-1} U)` once
        # general matrix exponential is supported.
        whitened, chol = self._lult(x, u, ret_chol=True)
        return tb.axat(chol, tb.symexpm(whitened))
示例#3
0
 def randvec(self, x, norm=1):
     """Sample a random tangent vector at ``x``.

     A vector with ``self.dim`` entries is drawn from a standard normal for
     every leading (batch) index of ``x``, rescaled in place so that its
     Euclidean norm along the last axis equals ``norm``, converted with
     ``self.from_vec`` and finally combined with the square root of ``x``
     through ``tb.axat``.

     :param x: tensor whose last two dimensions are the matrix dimensions;
         the sample shares its dtype and device.
     :param norm: target norm of the sampled coordinate vector before it is
         converted and transported (default 1).
     :return: the result of ``tb.axat(x_sqrt, u)``.
     """
     # The tangent vector is :math:`X^{1/2} U X^{1/2}`, where :math:`U` is a
     # symmetric matrix uniformly sampled from the unit sphere (in the
     # corresponding vector space of symmetric matrices).
     # This corresponds to the parallel transport of ``U`` from the identity
     # matrix to ``X``.
     shape = x.shape[:-2] + (self.dim, )
     # Allocate the sample directly with the dtype/device of ``x`` instead of
     # going through the legacy ``x.new(shape)`` constructor as an ``out=``
     # buffer: that constructor may interpret a plain tuple as data rather
     # than as a size, forcing ``randn`` to resize the buffer.
     u = torch.randn(shape, dtype=x.dtype, device=x.device)
     u.div_(u.norm(dim=-1, keepdim=True)).mul_(norm)
     u = self.from_vec(u)
     x_sqrt = tb.spdsqrtm(x, wmin=self.wmin, wmax=self.wmax)
     return tb.axat(x_sqrt, u)
示例#4
0
    def log(self, x, y):
        """Map ``y`` through the SPD matrix logarithm at ``x``.

        ``y`` is first transformed by ``_lult`` (which also returns the
        factor ``l`` of ``x``), ``tb.spdlogm`` is applied to the result, and
        the outcome is transformed back with ``tb.axat(l, ...)``.

        NOTE(review): presumably this is the logarithm map of the manifold
        at ``x`` -- confirm against the ``tb`` helpers.
        """
        whitened, chol = self._lult(x, y, ret_chol=True)
        return tb.axat(chol, tb.spdlogm(whitened))
示例#5
0
 def egrad2rgrad(self, x, u):
     """Convert the gradient ``u`` at ``x`` via ``proju`` and ``tb.axat``.

     ``u`` is first passed through ``self.proju(x, u)`` and the projection
     is then combined with ``x`` using ``tb.axat``.

     NOTE(review): by its name this turns a Euclidean gradient into a
     Riemannian one -- confirm the metric against ``proju``/``tb.axat``.
     """
     projected = self.proju(x, u)
     rgrad = tb.axat(x, projected)
     return rgrad
示例#6
0
 def _lult(self, x, u, ret_chol=False):
     """Transform ``u`` with the inverse factor of ``x`` from ``invchol``.

     :param x: input forwarded to ``self.invchol``.
     :param u: tensor combined with the inverse factor through ``tb.axat``.
     :param ret_chol: forwarded to ``self.invchol``; presumably controls
         whether the (non-inverted) factor is actually computed -- verify
         against ``invchol``.
     :return: tuple ``(tb.axat(l_inv, u), l)`` where ``l_inv, l`` are the
         two values returned by ``self.invchol``.
     """
     chol_inv, chol = self.invchol(x, ret_chol=ret_chol)
     return tb.axat(chol_inv, u), chol