def logm(self, x, y, dim=-1, keepdim=True): x = x + self.eps #sqrt_c = np.sqrt(self.c) diff = self.add(-x, y) norm_diff_scaled = diff.norm(2, dim=dim, keepdim=keepdim) / self.s norm_diff_scaled.data.clamp_(min=self.min_norm) lam = self.lambd(x) return 2. * atanh(norm_diff_scaled, self.eps) * diff.div(norm_diff_scaled * lam)
def scalar_mul(self, r, x, dim=-1, keepdim=True):
    """Mobius scalar multiplication r (x) x on the s-scaled Poincare ball.

    Scales the hyperbolic "length" of `x` by the factor `r` while keeping
    its direction, then projects the result back into the ball.
    """
    # Nudge away from the origin so the norm division below is safe.
    shifted = x + self.eps
    scale = self.s
    x_norm = shifted.norm(2, dim=dim, keepdim=keepdim)
    x_norm.data.clamp_(min=self.min_norm)
    # r acts on the atanh-transformed (hyperbolic) radius.
    radial = r * atanh(x_norm / scale, self.eps)
    radial.data.clamp_(min=self.eps)
    rescaled = scale * tanh(radial) * shifted.div(x_norm)
    return self.proj2ball(rescaled)
def mat_mul(self, M, x, dim=-1, keepdim=True): x = x + self.eps #Mx = M.mm(x.t()).t() if dim != -1 or M.dim() == 2: Mx = torch.tensordot(x, M, dims=([dim], [1])) else: Mx = torch.matmul(M, x.unsqueeze(-1)).squeeze(-1) norm_Mx = Mx.norm(2, dim=dim, keepdim=keepdim) / self.s norm_x = x.norm(2, dim=dim, keepdim=keepdim) / self.s norm_Mx.data.clamp_(min=self.min_norm) norm_x.data.clamp_(min=self.min_norm) result = tanh(norm_Mx.div(norm_x) * atanh(norm_x, self.eps)) * Mx.div(norm_Mx) return self.proj2ball(result)
def logm_zero(self, y, dim=-1, keepdim=True):
    """Logarithmic map of `y` at the origin of the s-scaled Poincare ball.

    Returns atanh(||y||/s) times the unit direction of `y` (after the
    usual eps shift and norm clamping).
    """
    shifted = y + self.eps
    scaled_norm = shifted.norm(2, dim=dim, keepdim=keepdim) / self.s
    # Clamp on .data guards the division without entering the autograd graph.
    scaled_norm.data.clamp_(min=self.min_norm)
    direction = shifted.div(scaled_norm)
    return atanh(scaled_norm, self.eps) * direction
def distance(self, x, y):
    """Geodesic distance between `x` and `y` on the s-scaled Poincare ball."""
    ratio = self.gyrodistance(x, y) / self.s
    return 2 * self.s * atanh(ratio, self.eps)
def _distance(self, x, y):
    """Unscaled distance helper: atanh of the gyrodistance over s.

    Same as `distance` without the leading `2 * self.s` factor.
    """
    # Fix: pass self.eps to atanh for numerical clamping, consistent with
    # every other atanh call in this class (e.g. `distance`, `logm`).
    return atanh(self.gyrodistance(x, y) / self.s, self.eps)
def poincare_matrix(u, v, manifold, s=1.):
    """Pairwise Poincare distances between the rows of `u` and `v`.

    Broadcasts u against v to form all pairs, takes the norm of each
    Mobius difference, and maps it through 2*s*atanh(. / s).

    NOTE(review): unlike the class methods, atanh is called here without
    an explicit eps argument -- confirm its default is acceptable.
    """
    rows = u.unsqueeze(1)        # (N, 1, D)
    cols = v.unsqueeze(0)        # (1, M, D)
    gyro = manifold.add(-rows, cols).norm(dim=-1)
    return 2 * s * atanh(gyro / s)