def mobius_linear(
    input,
    weight,
    bias=None,
    hyperbolic_input=True,
    hyperbolic_bias=True,
    nonlin=None,
    k=-1.0,
):
    """Mobius (hyperbolic) counterpart of ``torch.nn.functional.linear``.

    Performs a matrix-vector product, bias addition, and optional
    nonlinearity entirely with gyrovector operations on a manifold of
    curvature ``k`` (default -1.0, the standard Poincare ball).

    Args:
        input: input tensor; already on the manifold if
            ``hyperbolic_input`` is True, Euclidean otherwise.
        weight: Euclidean weight matrix of the linear map.
        bias: optional bias; already on the manifold if
            ``hyperbolic_bias`` is True, Euclidean otherwise.
        hyperbolic_input: whether ``input`` is a manifold point.
        hyperbolic_bias: whether ``bias`` is a manifold point.
        nonlin: optional callable applied via ``mobius_fn_apply``
            (maps to tangent space, applies fn, maps back).
        k: manifold curvature (scalar, converted to a tensor here).

    Returns:
        Output tensor projected back onto the manifold for numerical
        stability.
    """
    # gmath ops expect curvature as a tensor, not a Python float.
    k = torch.tensor(k)
    if hyperbolic_input:
        # FIX: use the gmath namespace like every other geometry call in
        # this function — the bare name `mobius_matvec` is not defined in
        # this block and would raise NameError unless separately imported.
        output = gmath.mobius_matvec(weight, input, k=k)
    else:
        # Euclidean input: apply the linear map first, then lift the
        # result onto the manifold with the exponential map at the origin.
        output = torch.nn.functional.linear(input, weight)
        output = gmath.expmap0(output, k=k)
    if bias is not None:
        if not hyperbolic_bias:
            # Lift a Euclidean bias onto the manifold before adding.
            bias = gmath.expmap0(bias, k=k)
        # Broadcast the bias across the batch dimension explicitly.
        output = gmath.mobius_add(output, bias.unsqueeze(0).expand_as(output), k=k)
    if nonlin is not None:
        output = gmath.mobius_fn_apply(nonlin, output, k=k)
    # Re-project to guard against points drifting off the ball.
    output = gmath.project(output, k=k)
    return output
def project_embeds(self):
    """Clamp all learned hyperbolic embeddings back onto the ball.

    Intended as a numerical-stability step (e.g. after optimizer
    updates): any embedding table that lives on a hyperbolic manifold
    is re-projected in place, with gradients disabled.
    """
    with torch.no_grad():
        # Word embeddings: only when trainable and hyperbolic.
        if self.train_word_embeds and self.args.embedding_metric == cs.HY:
            self.word_lut.data = pmath.project(
                self.word_lut, k=self.word_embed_manifold.k)
        # Positional embeddings of both attention modules share the
        # context-attention manifold's curvature.
        if self.args.attn_metric == cs.HY:
            attn_k = self.ctx_attn.manifold.k
            for attn in (self.ctx_attn, self.mention_encoder.mention_attn):
                attn.position_embeds.data = pmath.project(
                    attn.position_embeds, k=attn_k)
        # Character embeddings use the mention encoder's own curvature.
        encoder = self.mention_encoder
        encoder.char_lut.data = pmath.project(
            encoder.char_lut, k=encoder.manifold.k)
def mobius_linear(
    input,
    weight,
    bias=None,
    hyperbolic_input=True,
    hyperbolic_bias=True,
    nonlin=None,
    c=1.0,
):
    """Gyrovector (Mobius) version of a linear layer with curvature ``c``.

    Args:
        input: input tensor; a manifold point when ``hyperbolic_input``
            is True, Euclidean otherwise.
        weight: Euclidean weight matrix.
        bias: optional bias; a manifold point when ``hyperbolic_bias``
            is True, Euclidean otherwise.
        hyperbolic_input: whether ``input`` already lies on the manifold.
        hyperbolic_bias: whether ``bias`` already lies on the manifold.
        nonlin: optional nonlinearity applied via ``mobius_fn_apply``.
        c: ball curvature parameter, converted per-tensor via ``to_k``
            (presumably to match each tensor's dtype/device — hence it
            is deliberately recomputed for every call).

    Returns:
        Output tensor projected back onto the ball.
    """
    if hyperbolic_input:
        out = pmath.mobius_matvec(weight, input, k=to_k(c, weight))
    else:
        # Euclidean path: linear map first, then exponential map at 0.
        euclid = torch.nn.functional.linear(input, weight)
        out = pmath.expmap0(euclid, k=to_k(c, euclid))
    if bias is not None:
        if not hyperbolic_bias:
            bias = pmath.expmap0(bias, k=to_k(c, bias))
        out = pmath.mobius_add(out, bias, k=to_k(c, out))
    if nonlin is not None:
        out = pmath.mobius_fn_apply(nonlin, out, k=to_k(c, out))
    # Final projection keeps the result numerically inside the ball.
    return pmath.project(out, k=to_k(c, out))
def add(self, a, b):
    """Mobius-add ``a`` and ``b`` on this ball, re-projecting the result.

    The projection step keeps the sum numerically inside the ball.
    """
    curvature = self.ball.k
    summed = pmath.mobius_add(a, b, k=curvature)
    return pmath.project(summed, k=curvature)