Exemplo n.º 1
0
 def expmap(self, u, p, c):
     """Map tangent vector ``u`` at base point ``p`` back onto the manifold.

     Evaluates ``p (+) tanh(sqrt(c) / 2 * lambda(p) * |u|) * u / (sqrt(c) * |u|)``,
     where ``(+)`` is ``self.mobius_add`` and ``lambda`` is ``self._lambda_x``
     (presumably the conformal factor of the ball -- confirm in the class).

     Args:
         u: tangent vector(s); the norm is taken over the last dimension.
         p: base point(s) the tangent vectors are attached to.
         c: curvature value, used as ``c ** 0.5`` and forwarded to the helpers.

     Returns:
         The point(s) reached from ``p`` along ``u``.
     """
     root_c = c**0.5
     # Clamp so a zero tangent vector does not divide by zero below.
     u_len = u.norm(dim=-1, p=2, keepdim=True).clamp_min(self.min_norm)
     angle = root_c / 2 * self._lambda_x(p, c) * u_len
     # ``tanh`` is a helper imported elsewhere in the file (not visible here).
     step = tanh(angle) * u / (root_c * u_len)
     return self.mobius_add(p, step, c)
Exemplo n.º 2
0
 def mid_point_poincare(self, x, y, c: Curvature, manifold, *, t: float = 0.5):
     """Return the point a fraction ``t`` along the geodesic from ``x`` to ``y``.

     The result is ``x (+) (t (x) ((-x) (+) y))``, where ``(+)`` is
     ``manifold.mobius_add`` and ``t (x) v`` is Moebius scalar multiplication,
     ``tanh(t * artanh(sqrt_c * |v|)) * v / (|v| * sqrt_c)``. With the default
     ``t=0.5`` this is the geodesic midpoint of ``x`` and ``y``, matching the
     previous hard-coded behaviour. ``t=0`` yields ``x (+) 0``. Assuming
     ``mobius_add`` is standard Moebius addition, ``t=1`` yields ``y``.

     Args:
         x: start point(s); norms are taken over the last dimension.
         y: end point(s), broadcastable against ``x``.
         c: curvature object. ``c.c`` is read to form ``sqrt(c)``, while the
             object itself is forwarded to ``manifold.mobius_add``.
         manifold: object providing ``mobius_add(a, b, c=...)``.
         t: keyword-only position along the geodesic (default 0.5, the
             midpoint).

     Returns:
         The interpolated point(s).
     """
     sqrt_c = c.c**0.5
     x_y = manifold.mobius_add(-x, y, c=c)
     # Clamp so coincident x and y do not divide by zero.
     norm = torch.clamp_min(x_y.norm(dim=-1, keepdim=True, p=2), 1e-15)
     # NOTE(review): artanh needs sqrt_c * norm < 1; whether the imported
     # ``artanh`` helper clamps its input is not visible here.
     x_y_t = tanh(t * artanh(sqrt_c * norm)) * (x_y / norm) / sqrt_c
     return manifold.mobius_add(x, x_y_t, c=c)
 def expmap(self, u, p):
     """Map tangent vector ``u`` at ``p`` onto the manifold, using ``self.c``.

     Same formula as the explicit-curvature variant, but the curvature comes
     from ``self.c``. It is converted to a tensor with ``u``'s dtype/device.
     The helpers ``self._lambda_x`` and ``self.mobius_add`` are called
     without a curvature argument.

     Args:
         u: tangent vector(s); the norm is taken over the last dimension.
         p: base point(s).

     Returns:
         ``self.mobius_add(p, tanh(sqrt(c) / 2 * lambda(p) * |u|) * u / (sqrt(c) * |u|))``.
     """
     # Match the curvature's dtype/device to the input before using it.
     root_c = torch.as_tensor(self.c).type_as(u) ** 0.5
     u_len = u.norm(dim=-1, p=2, keepdim=True).clamp_min(self.min_norm)
     factor = tanh(root_c / 2 * self._lambda_x(p) * u_len)
     return self.mobius_add(p, factor * u / (root_c * u_len))
Exemplo n.º 4
0
 def mobius_matvec(self, m, x, c):
     """Moebius matrix-vector product of ``m`` with ``x``.

     Computes ``tanh(|Mx| / |x| * artanh(sqrt(c) * |x|)) * Mx / (|Mx| * sqrt(c))``
     with ``Mx = x @ m.transpose(-1, -2)``. Rows whose ``Mx`` is exactly
     zero are mapped to the zero vector.

     Args:
         m: matrix applied as ``x @ m.transpose(-1, -2)``.
         x: input point(s); norms are taken over the last dimension.
         c: curvature value, used as ``c ** 0.5``.

     Returns:
         Tensor with the dtype and device of the computed product.
     """
     sqrt_c = c ** 0.5
     # Clamp norms so zero rows do not divide by zero.
     x_norm = x.norm(dim=-1, keepdim=True, p=2).clamp_min(self.min_norm)
     mx = x @ m.transpose(-1, -2)
     mx_norm = mx.norm(dim=-1, keepdim=True, p=2).clamp_min(self.min_norm)
     res_c = tanh(mx_norm / x_norm * artanh(sqrt_c * x_norm)) * mx / (mx_norm * sqrt_c)
     # torch.where requires a bool condition; the former uint8 mask built
     # via prod(dtype=torch.uint8) is rejected by current PyTorch.
     cond = (mx == 0).all(dim=-1, keepdim=True)
     res_0 = torch.zeros(1, dtype=res_c.dtype, device=res_c.device)
     return torch.where(cond, res_0, res_c)
Exemplo n.º 5
0
 def expmap0(self, u, c):
     """Map tangent vector ``u`` at the origin onto the manifold.

     Returns ``tanh(sqrt(c) * |u|) * u / (sqrt(c) * |u|)``. The norm is taken
     over the last dimension and clamped to ``self.min_norm``, so a zero
     vector maps to zero instead of producing NaN.

     Args:
         u: tangent vector(s) at the origin.
         c: curvature value, used as ``c ** 0.5``.
     """
     root_c = c**0.5
     u_len = u.norm(dim=-1, p=2, keepdim=True).clamp_min(self.min_norm)
     scaled = root_c * u_len
     return tanh(scaled) * u / scaled