def one_rnn_transform(W, h, U, x, b, c):
    """Compute one hyperbolic RNN pre-activation step.

    Evaluates ``((W ⊗ h) ⊕ (U ⊗ x)) ⊕ b`` where ⊗ is ``pmath.mobius_matvec``
    and ⊕ is ``pmath.mobius_add``. Möbius addition is not associative, so the
    grouping shown above is significant.

    Args:
        W: matrix applied to the previous state ``h``.
        h: previous hidden state.
        U: matrix applied to the current input ``x``.
        x: current input.
        b: bias, Möbius-added last.
        c: curvature-like parameter; turned into ``k`` via ``to_k`` for every
            operation (presumably ball curvature -- confirm against ``to_k``).

    Returns:
        The result of the final ``pmath.mobius_add``.
    """

    def curvature_for(ref):
        # to_k is given a reference tensor at each step -- presumably so the
        # resulting k matches that tensor's dtype/device; confirm.
        return to_k(c, ref)

    recurrent = pmath.mobius_matvec(W, h, k=curvature_for(W))
    driven = pmath.mobius_matvec(U, x, k=curvature_for(U))
    summed = pmath.mobius_add(recurrent, driven, k=curvature_for(recurrent))
    return pmath.mobius_add(summed, b, k=curvature_for(summed))
def transition(self, x, h):
    """Advance the hidden state by one hyperbolic RNN step.

    Computes ``((w ⊗ h) ⊕ (u ⊗ x)) ⊕ b`` with ⊗ = ``pmath.mobius_matvec`` and
    ⊕ = ``pmath.mobius_add``, using ``self.w``, ``self.u``, ``self.b`` and the
    curvature ``self.ball.k``.

    :param x: batch x input
    :param h: previous hidden state; presumably batch x hidden (an earlier
        docstring said "hidden x hidden", which is inconsistent with the
        batch x hidden result -- confirm)
    :return: batch x hidden
    """
    k = self.ball.k
    recurrent = pmath.mobius_matvec(self.w, h, k=k)
    driven = pmath.mobius_matvec(self.u, x, k=k)
    # Möbius addition is not associative: combine the two terms first, then b.
    return pmath.mobius_add(pmath.mobius_add(recurrent, driven, k=k), self.b, k=k)
def mobius_linear(
    input,
    weight,
    bias=None,
    hyperbolic_input=True,
    hyperbolic_bias=True,
    nonlin=None,
    c=1.0,
):
    """Hyperbolic (Möbius) counterpart of ``torch.nn.functional.linear``.

    Pipeline:
      1. Transform ``input`` by ``weight``. If ``hyperbolic_input`` is true,
         ``pmath.mobius_matvec`` is used. Otherwise a Euclidean
         ``torch.nn.functional.linear`` is applied, followed by ``pmath.expmap0``.
      2. If ``bias`` is given, Möbius-add it. When ``hyperbolic_bias`` is false,
         the bias is first mapped with ``pmath.expmap0``.
      3. If ``nonlin`` is given, apply it through ``pmath.mobius_fn_apply``.
      4. Pass the result through ``pmath.project``.

    Args:
        input: input tensor.
        weight: weight matrix.
        bias: optional bias; ``None`` skips step 2.
        hyperbolic_input: whether ``input`` is already hyperbolic.
        hyperbolic_bias: whether ``bias`` is already hyperbolic.
        nonlin: optional callable applied via ``pmath.mobius_fn_apply``.
        c: curvature-like parameter converted with ``to_k`` at each step.

    Returns:
        The projected output tensor.
    """
    if not hyperbolic_input:
        euclidean = torch.nn.functional.linear(input, weight)
        result = pmath.expmap0(euclidean, k=to_k(c, euclidean))
    else:
        result = pmath.mobius_matvec(weight, input, k=to_k(c, weight))

    if bias is not None:
        hyp_bias = bias if hyperbolic_bias else pmath.expmap0(bias, k=to_k(c, bias))
        # k is derived from the pre-addition result, as before.
        result = pmath.mobius_add(result, hyp_bias, k=to_k(c, result))

    if nonlin is not None:
        result = pmath.mobius_fn_apply(nonlin, result, k=to_k(c, result))

    return pmath.project(result, k=to_k(c, result))
def one_rnn_transform(W, h, U, x, b, k):
    """Compute ``((W ⊗ h) ⊕ (U ⊗ x)) ⊕ b`` with a fixed curvature ``k``.

    ⊗ is ``gmath.mobius_matvec`` and ⊕ is ``gmath.mobius_add``. The same ``k``
    is passed to all four operations.

    Args:
        W: matrix applied to the previous state ``h``.
        h: previous hidden state.
        U: matrix applied to the current input ``x``.
        x: current input.
        b: bias, Möbius-added last.
        k: curvature forwarded unchanged to every gmath call.

    Returns:
        The result of the outer ``gmath.mobius_add``.
    """
    recurrent = gmath.mobius_matvec(W, h, k=k)
    driven = gmath.mobius_matvec(U, x, k=k)
    # Grouping matters (Möbius addition is not associative): terms first, then b.
    return gmath.mobius_add(gmath.mobius_add(recurrent, driven, k=k), b, k=k)