def sqdist(self, p1, p2, c):
    """Return the squared distance between ``p1`` and ``p2`` for curvature ``c``.

    The result is ``(2 / sqrt(c) * artanh(sqrt(c) * ||mobius_add(-p1, p2)||)) ** 2``
    where the L2 norm is taken over the last dimension, which is reduced
    away (``keepdim=False``).

    Args:
        p1: first point(s); negated before the Mobius addition.
        p2: second point(s).
        c: curvature value; its square root scales the norm and the result.

    Returns:
        The squared distance, with the last dimension of the inputs removed.
    """
    root_c = c ** 0.5
    # Mobius "difference" of the two points, then its length along the last axis.
    diff = self.mobius_add(-p1, p2, c, dim=-1)
    diff_norm = diff.norm(dim=-1, p=2, keepdim=False)
    atanh_term = artanh(root_c * diff_norm)
    return (atanh_term * 2 / root_c) ** 2
def distance(self, p1, p2):
    """Return the distance between ``p1`` and ``p2`` using the curvature ``self.c``.

    Computes ``2 / sqrt(c) * artanh(sqrt(c) * ||mobius_add(-p1, p2)||)`` with
    the L2 norm taken over the last dimension, which is reduced away.

    Args:
        p1: first point(s); negated before the Mobius addition.
        p2: second point(s).

    Returns:
        The distance, with the last dimension of the inputs removed.
    """
    last_axis = -1
    root_c = self.c ** 0.5
    # Length of the Mobius "difference" along the last axis.
    diff = self.mobius_add(-p1, p2, dim=last_axis)
    diff_norm = diff.norm(dim=last_axis, p=2, keepdim=False)
    return artanh(root_c * diff_norm) * 2 / root_c
def mid_point_poincare(self, x, y, c: Curvature, manifold):
    """Return the point halfway from ``x`` to ``y`` using ``manifold``'s Mobius addition.

    The Mobius difference ``mobius_add(-x, y)`` is rescaled so that the
    artanh of its (curvature-scaled) norm is multiplied by 0.5, and the
    rescaled vector is Mobius-added back onto ``x``.

    Args:
        x: start point(s).
        y: end point(s).
        c: curvature object; ``c.c`` supplies the numeric value used for the
            square root, while ``c`` itself is forwarded to ``mobius_add``.
        manifold: object providing ``mobius_add(a, b, c=...)``.

    Returns:
        The result of ``manifold.mobius_add(x, rescaled_difference, c=c)``.
    """
    fraction = 0.5  # scale applied to the artanh term (halfway between x and y)
    root_c = c.c ** 0.5
    offset = manifold.mobius_add(-x, y, c=c)
    # Clamp the norm away from zero so the division below is always defined.
    offset_norm = torch.clamp_min(offset.norm(dim=-1, keepdim=True, p=2), 1e-15)
    shrink = tanh(fraction * artanh(root_c * offset_norm))
    half_offset = shrink * (offset / offset_norm) / root_c
    return manifold.mobius_add(x, half_offset, c=c)
def mobius_matvec(self, m, x, c):
    """Apply the matrix ``m`` to ``x`` with Mobius (curvature-aware) rescaling.

    The plain product ``x @ m^T`` is rescaled to
    ``tanh(||mx|| / ||x|| * artanh(sqrt(c) * ||x||)) * mx / (||mx|| * sqrt(c))``,
    with norms taken over the last dimension and clamped below by
    ``self.min_norm``. Rows whose plain product is exactly zero are mapped
    to zero.

    Args:
        m: matrix (or batch of matrices); its last two dimensions are
            transposed before the product.
        x: input point(s); the last dimension is the vector dimension.
        c: curvature value; its square root scales the norms.

    Returns:
        A tensor shaped like ``x @ m.transpose(-1, -2)``.
    """
    sqrt_c = c ** 0.5
    x_norm = x.norm(dim=-1, keepdim=True, p=2).clamp_min(self.min_norm)
    mx = x @ m.transpose(-1, -2)
    mx_norm = mx.norm(dim=-1, keepdim=True, p=2).clamp_min(self.min_norm)
    res_c = tanh(mx_norm / x_norm * artanh(sqrt_c * x_norm)) * mx / (mx_norm * sqrt_c)
    # Build the mask with ``.all`` so it is a boolean tensor: ``torch.where``
    # documents a bool condition and uint8 masks are deprecated there.
    is_zero_row = (mx == 0).all(dim=-1, keepdim=True)
    res_0 = torch.zeros(1, dtype=res_c.dtype, device=res_c.device)
    return torch.where(is_zero_row, res_0, res_c)
def sqdist(self, p1, p2, c):
    """Return the squared distance between ``p1`` and ``p2`` for curvature ``c``.

    Before computing the distance, a sanity check asserts that the
    curvature-scaled norm of the Mobius difference of the *projected* points
    lies in ``[-1, 1]``. The distance itself is then computed from the raw
    (unprojected) inputs as
    ``(2 / sqrt(c) * artanh(sqrt(c) * ||mobius_add(-p1, p2)||)) ** 2``.

    Args:
        p1: first point(s); negated before the Mobius addition.
        p2: second point(s).
        c: curvature value; its square root scales the norm and the result.

    Returns:
        The squared distance, with the last dimension of the inputs removed.

    Raises:
        AssertionError: if the sanity check on the projected points fails.
    """
    root_c = c**0.5

    # Sanity check, evaluated on projected copies of the inputs.
    projected_diff = self.mobius_add(
        self.proj(-p1, c), self.proj(p2, c), c, dim=-1)
    check = root_c * projected_diff.norm(dim=-1, p=2, keepdim=False)
    assert torch.max(check) <= 1
    assert torch.min(check) >= -1

    # Actual distance, computed from the inputs as given.
    raw_diff = self.mobius_add(-p1, p2, c, dim=-1)
    raw_norm = raw_diff.norm(dim=-1, p=2, keepdim=False)
    atanh_term = artanh(root_c * raw_norm)
    return (atanh_term * 2 / root_c)**2
def logmap0(self, p, c):
    """Rescale ``p`` by ``artanh(sqrt(c) * ||p||) / (sqrt(c) * ||p||)``.

    The L2 norm is taken over the last dimension (kept for broadcasting) and
    clamped below by ``self.min_norm`` so the division is always defined.

    Args:
        p: input point(s); the last dimension is the vector dimension.
        c: curvature value; its square root scales the norm.

    Returns:
        ``p`` multiplied by the per-vector scale factor.
    """
    root_c = c**0.5
    norm = p.norm(dim=-1, p=2, keepdim=True).clamp_min(self.min_norm)
    inv_root_c = 1. / root_c
    factor = inv_root_c * artanh(root_c * norm) / norm
    return factor * p
def logmap(self, p1, p2, c):
    """Return the rescaled Mobius difference of ``p2`` relative to ``p1``.

    With ``sub = mobius_add(-p1, p2, c)`` and ``lam = self._lambda_x(p1, c)``,
    the result is
    ``2 / sqrt(c) / lam * artanh(sqrt(c) * ||sub||) * sub / ||sub||``.
    The norm is taken over the last dimension (kept for broadcasting) and
    clamped below by ``self.min_norm``.

    Args:
        p1: base point(s); negated before the Mobius addition.
        p2: target point(s).
        c: curvature value.

    Returns:
        A tensor shaped like the Mobius difference ``sub``.
    """
    diff = self.mobius_add(-p1, p2, c)
    # Clamp so the final division by the norm is always defined.
    diff_norm = diff.norm(dim=-1, p=2, keepdim=True).clamp_min(self.min_norm)
    lam_p1 = self._lambda_x(p1, c)
    root_c = c**0.5
    coeff = 2 / root_c / lam_p1 * artanh(root_c * diff_norm)
    return coeff * diff / diff_norm