Ejemplo n.º 1
0
 def sqdist(self, p1, p2, c):
     """Return the squared distance between ``p1`` and ``p2`` for curvature ``c``.

     The distance is ``2 / sqrt(c) * artanh(sqrt(c) * ||(-p1) (+) p2||)``,
     where ``(+)`` is ``self.mobius_add`` and the norm is the L2 norm taken
     over the last dimension (that dimension is reduced away). The value
     returned is that distance squared.
     """
     root_c = c ** 0.5
     # Mobius "difference" of the two points, collapsed to its L2 norm.
     diff = self.mobius_add(-p1, p2, c, dim=-1)
     diff_norm = diff.norm(dim=-1, p=2, keepdim=False)
     half_scaled = artanh(root_c * diff_norm)
     distance = half_scaled * 2 / root_c
     return distance ** 2
Ejemplo n.º 2
0
 def distance(self, p1, p2):
     """Return the distance between ``p1`` and ``p2`` using ``self.c``.

     Computes ``2 / sqrt(c) * artanh(sqrt(c) * ||(-p1) (+) p2||)`` with
     ``c = self.c``, where ``(+)`` is ``self.mobius_add`` and the L2 norm is
     taken over the last dimension without keeping it.
     """
     root_c = self.c ** 0.5
     # Mobius "difference" of the two points along the last axis.
     diff = self.mobius_add(-p1, p2, dim=-1)
     diff_norm = diff.norm(dim=-1, p=2, keepdim=False)
     half_scaled = artanh(root_c * diff_norm)
     return half_scaled * 2 / root_c
Ejemplo n.º 3
0
 def mid_point_poincare(self, x, y, c: Curvature, manifold):
     """Return the point reached by moving half-way from ``x`` towards ``y``.

     The Mobius difference ``(-x) (+) y`` is rescaled so that its
     ``artanh``-scaled norm is halved, and the result is Mobius-added back
     onto ``x``. ``manifold`` supplies ``mobius_add``; ``c.c`` is the
     curvature value used for the square root.
     """
     root_c = c.c ** 0.5
     fraction = 0.5
     delta = manifold.mobius_add(-x, y, c=c)
     # Clamp the norm away from zero so the division below stays finite.
     delta_norm = torch.clamp_min(delta.norm(dim=-1, keepdim=True, p=2), 1e-15)
     direction = delta / delta_norm
     half_step = tanh(fraction * artanh(root_c * delta_norm)) * direction / root_c
     return manifold.mobius_add(x, half_step, c=c)
Ejemplo n.º 4
0
 def mobius_matvec(self, m, x, c):
     """Apply matrix ``m`` to ``x`` with the Mobius matrix-vector rule.

     The plain product ``mx = x @ m^T`` is rescaled to
     ``tanh(||mx|| / ||x|| * artanh(sqrt(c) * ||x||)) * mx / (||mx|| * sqrt(c))``,
     with norms taken over the last dimension and clamped below by
     ``self.min_norm``. Rows whose product ``mx`` is exactly zero are
     returned as zeros.

     Args:
         m: Matrix whose last two dimensions are transposed before the product.
         x: Input whose last dimension is multiplied against ``m``.
         c: Curvature value; its square root scales the result.

     Returns:
         The rescaled product, with all-zero rows of ``mx`` mapped to zero.
     """
     sqrt_c = c ** 0.5
     x_norm = x.norm(dim=-1, keepdim=True, p=2).clamp_min(self.min_norm)
     mx = x @ m.transpose(-1, -2)
     mx_norm = mx.norm(dim=-1, keepdim=True, p=2).clamp_min(self.min_norm)
     res_c = tanh(mx_norm / x_norm * artanh(sqrt_c * x_norm)) * mx / (mx_norm * sqrt_c)
     # Boolean mask of rows where every component of mx is zero. torch.where
     # expects a bool condition; the previous uint8 product mask is deprecated.
     cond = (mx == 0).all(dim=-1, keepdim=True)
     res_0 = torch.zeros(1, dtype=res_c.dtype, device=res_c.device)
     return torch.where(cond, res_0, res_c)
Ejemplo n.º 5
0
 def sqdist(self, p1, p2, c):
     """Return the squared distance between ``p1`` and ``p2`` for curvature ``c``.

     Before computing the distance, the ``artanh`` argument is evaluated on
     the projected points (``self.proj``) and asserted to lie in ``[-1, 1]``.
     The distance itself is ``2 / sqrt(c) * artanh(sqrt(c) * ||(-p1) (+) p2||)``
     using the points as given, with the L2 norm taken over the last
     dimension; the result is that distance squared.

     Raises:
         AssertionError: if the checked value leaves ``[-1, 1]``.
     """
     root_c = c ** 0.5
     # Sanity check, evaluated on the projected inputs.
     projected_diff = self.mobius_add(
         self.proj(-p1, c), self.proj(p2, c), c, dim=-1)
     check = root_c * projected_diff.norm(dim=-1, p=2, keepdim=False)
     assert torch.max(check) <= 1
     assert torch.min(check) >= -1
     # NOTE(review): the distance below uses the unprojected points, so the
     # asserted quantity is not the one actually passed to artanh - confirm
     # whether that is intended.
     raw_diff = self.mobius_add(-p1, p2, c, dim=-1)
     raw_norm = raw_diff.norm(dim=-1, p=2, keepdim=False)
     half_scaled = artanh(root_c * raw_norm)
     distance = half_scaled * 2 / root_c
     return distance ** 2
Ejemplo n.º 6
0
 def logmap0(self, p, c):
     """Rescale ``p`` by ``artanh(sqrt(c) * ||p||) / (sqrt(c) * ||p||)``.

     The L2 norm is taken over the last dimension (kept for broadcasting)
     and clamped below by ``self.min_norm`` to avoid dividing by zero.
     """
     root_c = c ** 0.5
     norm = p.norm(dim=-1, p=2, keepdim=True).clamp_min(self.min_norm)
     inv_root_c = 1. / root_c
     factor = inv_root_c * artanh(root_c * norm) / norm
     return factor * p
Ejemplo n.º 7
0
 def logmap(self, p1, p2, c):
     """Return the Mobius difference of ``p2`` from ``p1``, rescaled.

     Computes ``2 / sqrt(c) / lambda * artanh(sqrt(c) * ||d||) * d / ||d||``
     where ``d = (-p1) (+) p2`` via ``self.mobius_add`` and ``lambda`` is
     ``self._lambda_x(p1, c)``. The L2 norm is taken over the last dimension
     and clamped below by ``self.min_norm``.
     """
     diff = self.mobius_add(-p1, p2, c)
     # Clamp so the final division by the norm cannot hit zero.
     diff_norm = diff.norm(dim=-1, p=2, keepdim=True).clamp_min(self.min_norm)
     conformal = self._lambda_x(p1, c)
     root_c = c ** 0.5
     coeff = 2 / root_c / conformal
     return coeff * artanh(root_c * diff_norm) * diff / diff_norm