def _log_det_jacobian(self, x, y):
    """Return ``(dim - 2) * log(sinh(r) / r)`` computed from ``x``.

    ``r`` is the square root of ``functions.lorentzian_product(x, x)`` after
    passing it through ``functions.clamp(..., eps)``. ``dim`` is the size of
    the last axis of ``x``.

    Args:
        x: Input array/variable whose last axis is the vector dimension.
        y: Not used by this computation; kept for interface compatibility.

    Returns:
        The log-determinant term, with the shape produced by
        ``functions.lorentzian_product(x, x)``. This is presumably one value
        per vector in ``x``; TODO confirm against that helper.
    """
    # The clamp presumably bounds the product below by eps, keeping r > 0 so
    # the division by r below is well defined; confirm clamp's semantics.
    r = F.sqrt(functions.clamp(functions.lorentzian_product(x, x), eps))
    # Only the size of the last axis is needed, so read it from x directly
    # instead of materialising a normalised copy of x just for its shape.
    dim = x.shape[-1]
    # NOTE(review): F.sinh(r) overflows for large r (roughly r > 89 in
    # float32). Confirm inputs stay in range, or switch to a log-space form.
    return (dim - 2) * F.log(F.sinh(r) / r)
def pseudo_polar_projection(x):
    """Lift ``x`` to ``concat(cosh(n), sinh(n) * x / clamp(n, eps))``.

    ``n`` is the Euclidean norm of ``x`` over the last axis, kept as a
    size-1 axis so the ``cosh`` term forms one extra leading component.
    ``clamp(n, eps)`` is applied only to the divisor. The ``cosh`` and
    ``sinh`` terms use the unclamped norm.

    Args:
        x: Input array/variable; the norm is taken over its last axis.

    Returns:
        The concatenation along the last axis, one element longer than ``x``.
    """
    norm = F.sqrt(F.sum(F.square(x), axis=-1, keepdims=True))
    # Divide by the clamped norm so that a zero vector does not divide by 0.
    divisor = F.broadcast_to(clamp(norm, eps), x.shape)
    direction = x / divisor
    head = F.cosh(norm)
    scale = F.broadcast_to(F.sinh(norm), direction.shape)
    tail = scale * direction
    return F.concat((head, tail), axis=-1)
def forward(self, x):
    """Return the element-wise hyperbolic sine of ``x`` (via ``F.sinh``)."""
    return F.sinh(x)
def exponential_map(x, v):
    """Return ``cosh(n) * x + sinh(n) * v / n`` for ``n = sqrt(clamp(<v, v>, eps))``.

    ``<v, v>`` is ``lorentzian_product(v, keepdims=True)``. The retained axis
    lets ``n`` combine with ``x`` and ``v``. Given the name, this is presumably
    the exponential map at ``x`` applied to tangent vector ``v``; confirm with
    the callers.

    Args:
        x: Base point(s).
        v: Vector(s) to map; its norm is taken with ``lorentzian_product``.

    Returns:
        The mapped point(s), combining ``x`` and ``v`` element-wise.
    """
    squared = lorentzian_product(v, keepdims=True)
    # Clamping before sqrt keeps the later division by the norm away from 0
    # (assuming clamp enforces a lower bound of eps).
    norm = F.sqrt(clamp(squared, eps))
    base_term = F.cosh(norm) * x
    tangent_term = F.sinh(norm) * v / norm
    return base_term + tangent_term
def sinh(self, x):
    """Compute the element-wise hyperbolic sine of ``x`` using ``F.sinh``."""
    result = F.sinh(x)
    return result