def logmap(self, p1, p2, c):
    """Map ``p2`` into the tangent space at base point ``p1``.

    Evaluates::

        2 / sqrt(c) / lambda(p1) * artanh(sqrt(c) * ||d||) * d / ||d||

    where ``d = self.mobius_add(-p1, p2, c)`` and ``lambda`` is
    ``self._lambda_x``. ``_lambda_x`` is presumably the conformal factor at
    ``p1``; confirm against its definition.

    Args:
        p1: Base point(s). Norms are taken over the last dimension.
        p2: Point(s) to map, combined with ``p1`` via ``mobius_add``.
        c: Curvature parameter. Its square root is taken, so it is expected
            to be non-negative.

    Returns:
        The tangent vector(s) described by the formula above.
    """
    diff = self.mobius_add(-p1, p2, c)
    # Clamp the norm from below so the division by ``diff_norm`` and the
    # artanh argument stay finite. Clamp it from above at ``self.max_norm``.
    diff_norm = diff.norm(dim=-1, p=2, keepdim=True)
    diff_norm = diff_norm.clamp_min(self.min_norm).clamp_max(self.max_norm)
    conformal = self._lambda_x(p1, c)
    root_c = c**0.5
    coeff = 2 / root_c / conformal
    return coeff * artanh(root_c * diff_norm) * diff / diff_norm
def logmap0(self, p, c):
    """Map point(s) ``p`` into the tangent space at the origin.

    Evaluates ``artanh(sqrt(c) * ||p||) / (sqrt(c) * ||p||) * p``. This is the
    usual Poincare-ball logarithmic map at the origin; the identification is
    based on the formula and should be confirmed against the class.

    The norm over the last dimension is clamped to
    ``[self.min_norm, self.max_norm]`` before use. This avoids dividing by
    zero for points at the origin.

    Args:
        p: Point(s). Norms are taken over the last dimension.
        c: Curvature parameter. Its square root is taken, so it is expected
            to be non-negative.

    Returns:
        Tensor of the same shape as ``p``.

    Warns:
        RuntimeWarning: if the result contains infinite values.
    """
    sqrt_c = c ** 0.5
    p_norm = p.norm(dim=-1, p=2, keepdim=True).clamp_min(self.min_norm).clamp_max(self.max_norm)
    scale = 1. / sqrt_c * artanh(sqrt_c * p_norm) / p_norm
    # Compute the product once. The original evaluated it twice: once for the
    # check and once for the return value.
    res = scale * p
    if torch.isinf(res).any():
        # Replaces a leftover debug print with a filterable, descriptive warning.
        import warnings
        warnings.warn("logmap0 produced infinite values", RuntimeWarning)
    return res
def sqdist(self, p1, p2, c):
    """Return the squared distance between ``p1`` and ``p2``.

    Evaluates ``(2 / sqrt(c) * artanh(sqrt(c) * ||mobius_add(-p1, p2, c)||)) ** 2``.
    The norm is taken over the last dimension with ``keepdim=False``, so that
    dimension is removed from the result.

    Note that, unlike ``logmap``/``logmap0``, the norm is not clamped here.

    Args:
        p1: First point(s).
        p2: Second point(s), combined with ``p1`` via ``mobius_add``.
        c: Curvature parameter. Its square root is taken, so it is expected
            to be non-negative.

    Returns:
        Squared distance(s), with the last dimension reduced away.

    Warns:
        RuntimeWarning: if the (unsquared) distance contains infinite values.
    """
    sqrt_c = c**0.5
    dist_c = artanh(sqrt_c * self.mobius_add(-p1, p2, c, dim=-1).norm(
        dim=-1, p=2, keepdim=False))
    dist = dist_c * 2 / sqrt_c
    if torch.isinf(dist).any():
        # Replaces a leftover debug print with a filterable, descriptive warning.
        import warnings
        warnings.warn("sqdist produced infinite distances", RuntimeWarning)
    return dist**2
def mobius_matvec(self, m, x, c):
    """Apply matrix ``m`` to point(s) ``x`` in the Mobius sense.

    Evaluates::

        tanh(||Mx|| / ||x|| * artanh(sqrt(c) * ||x||)) * Mx / (||Mx|| * sqrt(c))

    with ``Mx = x @ m.transpose(-1, -2)``. Because of that transpose, the last
    dimension of ``m`` must match the last dimension of ``x``.

    Both norms are clamped from below at ``self.min_norm``. Rows where ``Mx``
    is exactly zero are returned as zeros.

    Args:
        m: Matrix (or batch of matrices) applied on the right via
            ``m.transpose(-1, -2)``.
        x: Input point(s). Norms are taken over the last dimension.
        c: Curvature parameter. Its square root is taken, so it is expected
            to be non-negative.

    Returns:
        Tensor with the shape of ``Mx``.
    """
    sqrt_c = c ** 0.5
    x_norm = x.norm(dim=-1, keepdim=True, p=2).clamp_min(self.min_norm)
    mx = x @ m.transpose(-1, -2)
    mx_norm = mx.norm(dim=-1, keepdim=True, p=2).clamp_min(self.min_norm)
    res_c = tanh(mx_norm / x_norm * artanh(sqrt_c * x_norm)) * mx / (mx_norm * sqrt_c)
    # torch.where requires a bool condition; a uint8 mask (the old
    # ``.prod(dtype=torch.uint8)`` form) is rejected by current PyTorch.
    # ``all`` gives the same "every entry of the row is zero" mask as a bool.
    cond = (mx == 0).all(dim=-1, keepdim=True)
    res_0 = torch.zeros(1, dtype=res_c.dtype, device=res_c.device)
    res = torch.where(cond, res_0, res_c)
    return res