def pdist(self, x, squared=False):
    """Compute pairwise distances between the SPD matrices in the batch ``x``.

    Only the strict upper triangle of the pair grid is evaluated (distances
    are symmetric), and each pair is reduced through the same
    congruence-then-norm pipeline used for single distances.

    Args:
        x: Batch of matrices, shape ``(n, d, d)`` (``ndim == 3`` is asserted).
        squared: If True, return squared distances.

    Returns:
        The distances for all ``n * (n - 1) / 2`` unordered pairs.
    """
    assert x.ndim == 3
    num = x.shape[0]
    inv_factor, _ = self.invchol(x)
    # Indices (i, j) with i < j — one entry per unordered pair.
    pair_idx = torch.triu_indices(num, num, 1, device=x.device)
    # Congruence L_i^{-1} X_j L_i^{-T} for every selected pair.
    congruences = tb.axat(inv_factor[pair_idx[0]], x[pair_idx[1]])
    return self._norm_log(congruences, squared=squared)
def exp(self, x, u):
    """Exponential map of the tangent vector ``u`` at the point ``x``.

    Computes :math:`L \\exp(L^{-1} U L^{-T}) L^T` where :math:`X = L L^T`
    is the Cholesky factorization of ``x``.
    """
    # TODO(ccruceru): Replace this with :math:`X \exp(X^{-1} U)` once
    # general matrix exponential is supported.
    congr, chol = self._lult(x, u, ret_chol=True)
    return tb.axat(chol, tb.symexpm(congr))
def randvec(self, x, norm=1):
    """Sample a random tangent vector at ``x`` with the given norm.

    The tangent vector is :math:`X^{1/2} U X^{1/2}`, where :math:`U` is a
    symmetric matrix uniformly sampled from the unit sphere (in the
    corresponding vector space of symmetric matrices).
    This corresponds to the parallel transport of ``U`` from the identity
    matrix to ``X``.

    Args:
        x: Base point(s) on the manifold, shape ``(..., d, d)``.
        norm: Target norm of the sampled vector (default 1).

    Returns:
        A tangent vector at ``x`` with the requested norm.
    """
    shape = x.shape[:-2] + (self.dim, )
    # Fix: allocate directly with dtype/device instead of the deprecated
    # legacy ``Tensor.new`` + ``out=`` pattern; behavior is identical.
    u = torch.randn(shape, dtype=x.dtype, device=x.device)
    # Normalize to the unit sphere, then rescale to the requested norm.
    u.div_(u.norm(dim=-1, keepdim=True)).mul_(norm)
    u = self.from_vec(u)
    x_sqrt = tb.spdsqrtm(x, wmin=self.wmin, wmax=self.wmax)
    return tb.axat(x_sqrt, u)
def log(self, x, y):
    """Logarithm map of ``y`` at the base point ``x``.

    Computes :math:`L \\log(L^{-1} Y L^{-T}) L^T` where :math:`X = L L^T`
    is the Cholesky factorization of ``x`` — the inverse of :meth:`exp`.
    """
    congr, chol = self._lult(x, y, ret_chol=True)
    return tb.axat(chol, tb.spdlogm(congr))
def egrad2rgrad(self, x, u):
    """Convert a Euclidean gradient ``u`` at ``x`` to a Riemannian gradient.

    Projects ``u`` onto the tangent space at ``x`` and applies the
    congruence :math:`X \\, \\mathrm{proj}(U) \\, X^T`.
    """
    tangent = self.proju(x, u)
    return tb.axat(x, tangent)
def _lult(self, x, u, ret_chol=False):
    """Congruence of ``u`` by the inverse Cholesky factor of ``x``.

    Returns the pair ``(L^{-1} U L^{-T}, L)``; the second element is the
    Cholesky factor only when ``ret_chol`` is True (as produced by
    :meth:`invchol`).
    """
    inv_factor, factor = self.invchol(x, ret_chol=ret_chol)
    return tb.axat(inv_factor, u), factor