Beispiel #1
0
 def sqdist(self, x, y, c):
     """Return ``K * arcosh(-<x, y> / K) ** 2`` capped at 50, with ``K = 1 / c``.

     ``<x, y>`` is whatever ``self.minkowski_dot(x, y)`` returns.
     NOTE(review): this looks like the squared geodesic distance on a
     hyperboloid of curvature ``-c`` -- confirm against the enclosing class.

     Args:
         x: first operand; its ``dtype`` selects the clamping tolerance.
         y: second operand, forwarded to ``minkowski_dot``.
         c: value whose reciprocal is used as ``K``.

     Returns:
         The squared distance, clamped from above at 50.0.

     Raises:
         KeyError: if ``x.dtype`` is neither ``torch.float32`` nor
             ``torch.float64`` (the only entries in the tolerance table).
     """
     inv_c = 1. / c
     inner = self.minkowski_dot(x, y)
     # Per-dtype margin that keeps the arcosh argument strictly above 1.
     tolerance = {torch.float32: 1e-7, torch.float64: 1e-15}
     lower_bound = 1.0 + tolerance[x.dtype]
     ratio = (-inner / inv_c).clamp(min=lower_bound)
     dist = arcosh(ratio)
     squared = inv_c * dist ** 2
     return squared.clamp(max=50.0)
Beispiel #2
0
 def sqdist(self, x, y, c):
     """Return ``K * arcosh(-<x, y> / K) ** 2`` capped at 50, with ``K = 1 / c``.

     ``<x, y>`` is the result of ``self.minkowski_dot(x, y)``. The arcosh
     argument is clamped from below at ``1 + self.eps[x.dtype]``.
     NOTE(review): presumably the squared hyperboloid distance for
     curvature ``-c`` -- confirm against the enclosing class.

     Args:
         x: first operand; its ``dtype`` indexes ``self.eps``.
         y: second operand, forwarded to ``minkowski_dot``.
         c: value whose reciprocal is used as ``K``.

     Returns:
         The squared distance, clamped from above at 50.0.
     """
     inv_c = 1. / c
     inner = self.minkowski_dot(x, y)
     lower_bound = 1.0 + self.eps[x.dtype]
     ratio = (-inner / inv_c).clamp(min=lower_bound)
     dist = arcosh(ratio)
     squared = inv_c * dist ** 2
     # Cap the distance to avoid NaNs in the Fermi-Dirac decoder.
     return squared.clamp(max=50.0)
Beispiel #3
0
 def logmap0(self, x, c):
     """Build a tensor shaped like ``x`` with a zero first column and scaled tail.

     With ``sqrtK = (1 / c) ** 0.5`` and ``y`` the trailing ``d`` components of
     ``x`` (everything but the first along the last axis), columns ``1:`` of
     the result hold ``sqrtK * arcosh(x[:, 0:1] / sqrtK) * y / ||y||``.
     The norm is clamped from below at ``self.min_norm`` and the arcosh
     argument at ``1 + self.eps[x.dtype]``.
     NOTE(review): presumably the logarithmic map at the hyperboloid origin;
     the ``[:, ...]`` indexing assumes ``x`` is 2-D ``(n, d + 1)`` -- confirm.

     Args:
         x: input tensor; its first column feeds arcosh, the rest give the
             direction.
         c: value whose reciprocal square root is the scale ``sqrtK``.

     Returns:
         A new tensor like ``x``; column 0 stays zero.
     """
     scale = (1. / c) ** 0.5
     tail_dim = x.size(-1) - 1
     # Drop the first component and flatten to rows of length ``tail_dim``.
     tail = x.narrow(-1, 1, tail_dim).view(-1, tail_dim)
     # Lower-bound the norm so the division below cannot hit zero.
     tail_norm = torch.norm(tail, p=2, dim=1, keepdim=True).clamp(min=self.min_norm)
     out = torch.zeros_like(x)
     lower_bound = 1.0 + self.eps[x.dtype]
     ratio = (x[:, :1] / scale).clamp(min=lower_bound)
     magnitude = scale * arcosh(ratio)
     out[:, 1:] = magnitude * tail / tail_norm
     return out
Beispiel #4
0
 def induced_distance(self, x, y, c):
     """Return ``sqrt(c) * arcosh(-<x, y> / c + eps)``.

     ``<x, y>`` is the result of ``self.l_inner(x, y)`` and ``eps`` is
     ``self.eps[x.dtype]``, added to the arcosh argument.
     NOTE(review): here ``c`` scales the result directly, not via ``1 / c`` --
     confirm which curvature convention the enclosing class uses.

     Args:
         x: first operand; its ``dtype`` indexes ``self.eps``.
         y: second operand, forwarded to ``l_inner``.
         c: scale parameter; its square root multiplies the result.

     Returns:
         The distance value computed by the formula above.
     """
     inner = self.l_inner(x, y)
     scale = c ** 0.5
     # Shift the argument by a dtype-dependent epsilon before arcosh.
     arg = -inner / c + self.eps[x.dtype]
     return scale * arcosh(arg)