def mobius_gru_cell(
    input: torch.Tensor,
    hx: torch.Tensor,
    weight_ih: torch.Tensor,
    weight_hh: torch.Tensor,
    bias: torch.Tensor,
    c: torch.Tensor,
    nonlin=None,
):
    # Split the stacked projections into reset (r), candidate (h), update (z) blocks.
    W_ir, W_ih, W_iz = weight_ih.chunk(3)
    b_r, b_h, b_z = bias
    W_hr, W_hh, W_hz = weight_hh.chunk(3)
    # Gates are computed in the tangent space at the origin (logmap0), then squashed.
    z_t = pmath.logmap0(one_rnn_transform(W_hz, hx, W_iz, input, b_z, c), c=c).sigmoid()
    r_t = pmath.logmap0(one_rnn_transform(W_hr, hx, W_ir, input, b_r, c), c=c).sigmoid()
    # Candidate state from the reset-gated hidden state.
    rh_t = pmath.mobius_pointwise_mul(r_t, hx, c=c)
    h_tilde = one_rnn_transform(W_hh, rh_t, W_ih, input, b_h, c)
    if nonlin is not None:
        h_tilde = pmath.mobius_fn_apply(nonlin, h_tilde, c=c)
    # Möbius analogue of the Euclidean update h + z * (h_tilde - h).
    delta_h = pmath.mobius_add(-hx, h_tilde, c=c)
    h_out = pmath.mobius_add(hx, pmath.mobius_pointwise_mul(z_t, delta_h, c=c), c=c)
    return h_out
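# A minimal usage sketch for mobius_gru_cell. Assumptions, not from the source:
# `pmath` is geoopt's Poincaré-ball math module (geoopt.manifolds.poincare.math
# in older geoopt releases), weights split into three equal blocks along dim 0,
# and `bias` is a (3, hidden) tensor of hyperbolic biases. Sizes are illustrative.
import torch
import geoopt.manifolds.poincare.math as pmath

batch, insize, hidden = 4, 8, 8
c = torch.tensor(1.0)
x = pmath.expmap0(torch.randn(batch, insize) * 1e-2, c=c)    # input on the ball
h0 = pmath.expmap0(torch.randn(batch, hidden) * 1e-2, c=c)   # previous hidden state
weight_ih = torch.randn(3 * hidden, insize) * 0.1
weight_hh = torch.randn(3 * hidden, hidden) * 0.1
bias = pmath.expmap0(torch.randn(3, hidden) * 1e-2, c=c)
h1 = mobius_gru_cell(x, h0, weight_ih, weight_hh, bias, c)   # (batch, hidden)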
def forward(self, x_h, x_e):
    # Lift the Euclidean features onto the ball once and reuse the projection.
    x_e_hyp = pmath.project(pmath.expmap0(x_e, c=self.c), c=self.c)
    # Distance-based attention weight between the hyperbolic and lifted features.
    dist = pmath.dist(x_h, x_e_hyp, c=self.c) * self.att
    x_e = pmath.project(pmath.mobius_scalar_mul(dist.view([-1, 1]), x_e_hyp, c=self.c), c=self.c)
    x_e = F.dropout(x_e, p=self.drop, training=self.training)
    x_h = pmath.project(pmath.mobius_add(x_h, x_e, c=self.c), c=self.c)
    return x_h
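# Hypothetical host module for the fusion forward above; the class name and
# hyper-parameter values are assumptions made only so the snippet runs.
import torch
import torch.nn as nn
import torch.nn.functional as F
import geoopt.manifolds.poincare.math as pmath

class HypFusion(nn.Module):
    def __init__(self, att=1.0, drop=0.1, c=1.0):
        super().__init__()
        self.att, self.drop, self.c = att, drop, c

HypFusion.forward = forward  # bind the module-level forward defined above

layer = HypFusion()
x_h = pmath.expmap0(torch.randn(4, 8) * 1e-2, c=1.0)  # hyperbolic stream
x_e = torch.randn(4, 8) * 1e-2                        # Euclidean stream
out = layer(x_h, x_e)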
def exp_map_x_polar(x: Tensor, radius: Tensor, v: Tensor, c: Tensor) -> Tensor:
    v = v + eps  # perturb v so a zero direction does not divide by zero
    norm_v = torch.norm(v, p=2, dim=-1, keepdim=True)
    assert len(radius.shape) == len(v.shape)
    assert radius.shape[:-1] == v.shape[:-1]
    # Point at hyperbolic distance `radius` from the origin along v / ||v||;
    # Möbius-translating it by x moves distance `radius` away from x.
    second_term = (torch.tanh(c.sqrt() * radius / 2) / (c.sqrt() * norm_v)) * v
    assert x.shape == second_term.shape
    return pm.mobius_add(x, second_term, c=c)
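# Sketch checking that exp_map_x_polar lands exactly `radius` away from x
# (left Möbius translation is an isometry). Assumptions: `pm` is geoopt's
# Poincaré math module and `eps` is a small module-level constant like 1e-15.
import torch
import geoopt.manifolds.poincare.math as pm

eps = 1e-15
c = torch.tensor(1.0)
x = pm.expmap0(torch.randn(5, 3) * 1e-2, c=c)
v = torch.randn(5, 3)
radius = torch.full((5, 1), 0.5)
y = exp_map_x_polar(x, radius, v, c)
assert torch.allclose(pm.dist(x, y, c=c, keepdim=True), radius, atol=1e-4)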
def poincare_distance_c(x: Tensor, y: Tensor, c: Tensor, keepdim: bool = True, **kwargs: Any) -> Tensor:
    # Closed form d_c(x, y) = (2 / sqrt(c)) * atanh(sqrt(c) * ||(-x) ⊕_c y||),
    # equivalent to pm.dist(x, y, c=c, keepdim=keepdim).
    sqrt_c = sqrt(c)
    mob = pm.mobius_add(-x, y, c=c, dim=-1).norm(dim=-1, p=2, keepdim=keepdim)
    dist_c = atanh(sqrt_c * mob)
    res = dist_c * 2 / sqrt_c
    assert torch.isfinite(res).all()
    return res
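# Sanity-check sketch against geoopt's built-in distance. Assumptions: `pm` is
# geoopt's Poincaré math module and `sqrt`/`atanh` resolve to the torch ops.
import torch
import geoopt.manifolds.poincare.math as pm

sqrt, atanh = torch.sqrt, torch.atanh
c = torch.tensor(1.0)
x = pm.project(torch.randn(10, 4) * 0.1, c=c)
y = pm.project(torch.randn(10, 4) * 0.1, c=c)
d_ours = poincare_distance_c(x, y, c, keepdim=False)
d_ref = pm.dist(x, y, c=c)
assert torch.allclose(d_ours, d_ref, atol=1e-5)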
def mobius_linear(
    input,
    weight,
    bias=None,
    hyperbolic_input=True,
    hyperbolic_bias=True,
    nonlin=None,
    c=1.0,
):
    if hyperbolic_input:
        output = pmath.mobius_matvec(weight, input, c=c)
    else:
        output = torch.nn.functional.linear(input, weight)
        output = pmath.expmap0(output, c=c)
    if bias is not None:
        if not hyperbolic_bias:
            bias = pmath.expmap0(bias, c=c)
        output = pmath.mobius_add(output, bias, c=c)
    if nonlin is not None:
        output = pmath.mobius_fn_apply(nonlin, output, c=c)
    output = pmath.project(output, c=c)
    return output
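# Usage sketch: push a Euclidean feature batch through the hyperbolic linear
# layer. Shapes and initial scales are illustrative assumptions.
import torch
import geoopt.manifolds.poincare.math as pmath

weight = torch.randn(6, 4) * 0.1
bias = pmath.expmap0(torch.zeros(6), c=1.0)  # hyperbolic bias at the origin
x = torch.randn(2, 4)
y = mobius_linear(x, weight, bias=bias, hyperbolic_input=False, nonlin=torch.tanh)
# y lies inside the unit ball thanks to the final project().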
def one_rnn_transform(W, h, U, x, b, c):
    # Möbius version of the affine RNN transform: (W ⊗_c h) ⊕_c (U ⊗_c x) ⊕_c b.
    W_otimes_h = pmath.mobius_matvec(W, h, c=c)
    U_otimes_x = pmath.mobius_matvec(U, x, c=c)
    Wh_plus_Ux = pmath.mobius_add(W_otimes_h, U_otimes_x, c=c)
    return pmath.mobius_add(Wh_plus_Ux, b, c=c)
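# Sketch of a limit check: as c → 0 the Möbius operations approach their
# Euclidean counterparts, so the transform should approach h @ W.T + x @ U.T + b.
# The tiny curvature and loose tolerance are illustrative assumptions.
import torch
import geoopt.manifolds.poincare.math as pmath

W, U = torch.randn(5, 5) * 0.1, torch.randn(5, 5) * 0.1
h, x, b = torch.randn(2, 5) * 0.1, torch.randn(2, 5) * 0.1, torch.randn(5) * 0.1
c = torch.tensor(1e-8)
hyp = one_rnn_transform(W, h, U, x, b, c)
eucl = h @ W.t() + x @ U.t() + b
assert torch.allclose(hyp, eucl, atol=1e-3)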
def mobius_add(x, y, c):
    return pmath.mobius_add(x, y, c=c)
def forward(self, input):
    source_input = input[0]
    target_input = input[1]
    alignment = input[2]
    batch_size = alignment.shape[0]

    source_input_data = self.embedding(source_input.data)
    target_input_data = self.embedding(target_input.data)

    zero_hidden = torch.zeros(self.num_layers, batch_size, self.hidden_dim,
                              device=self.device or source_input_data.device,
                              dtype=source_input_data.dtype)

    # Move embeddings into the geometry the recurrent cells expect.
    if self.embedding_type == "eucl" and "hyp" in self.cell_type:
        source_input_data = pmath.expmap0(source_input_data, c=self.c)
        target_input_data = pmath.expmap0(target_input_data, c=self.c)
    elif self.embedding_type == "hyp" and "eucl" in self.cell_type:
        source_input_data = pmath.logmap0(source_input_data, c=self.c)
        target_input_data = pmath.logmap0(target_input_data, c=self.c)

    # ht: (num_layers * num_directions, batch, hidden_size)
    source_input = torch.nn.utils.rnn.PackedSequence(
        source_input_data, source_input.batch_sizes)
    target_input = torch.nn.utils.rnn.PackedSequence(
        target_input_data, target_input.batch_sizes)
    _, source_hidden = self.cell_source(source_input, zero_hidden)
    _, target_hidden = self.cell_target(target_input, zero_hidden)

    # Take hiddens from the last layer; align target rows to their sources.
    source_hidden = source_hidden[-1]
    target_hidden = target_hidden[-1][alignment]

    if self.decision_type == "hyp":
        if "eucl" in self.cell_type:
            source_hidden = pmath.expmap0(source_hidden, c=self.c)
            target_hidden = pmath.expmap0(target_hidden, c=self.c)
        source_projected = self.projector_source(source_hidden)
        target_projected = self.projector_target(target_hidden)
        projected = pmath.mobius_add(source_projected, target_projected,
                                     c=self.ball.c)
        if self.use_distance_as_feature:
            dist = pmath.dist(source_hidden, target_hidden,
                              dim=-1, keepdim=True, c=self.ball.c) ** 2
            bias = pmath.mobius_scalar_mul(dist, self.dist_bias, c=self.ball.c)
            projected = pmath.mobius_add(projected, bias, c=self.ball.c)
    else:
        if "hyp" in self.cell_type:
            source_hidden = pmath.logmap0(source_hidden, c=self.c)
            target_hidden = pmath.logmap0(target_hidden, c=self.c)
        projected = self.projector(
            torch.cat((source_hidden, target_hidden), dim=-1))
        if self.use_distance_as_feature:
            dist = torch.sum((source_hidden - target_hidden).pow(2),
                             dim=-1, keepdim=True)
            bias = self.dist_bias * dist
            projected = projected + bias

    logits = self.logits(projected)  # raw logits; CrossEntropyLoss applies softmax
    return logits
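# Hypothetical call sketch; the wrapping module `model` and the toy token ids
# are assumptions. Inputs are two packed sequences plus an alignment index that
# pairs each target sequence with a source sequence in the batch.
import torch

src = torch.nn.utils.rnn.pack_sequence([torch.tensor([1, 4, 2]), torch.tensor([3, 5])])
tgt = torch.nn.utils.rnn.pack_sequence([torch.tensor([2, 2]), torch.tensor([4])])
alignment = torch.tensor([0, 1])  # target i is scored against source alignment[i]
logits = model((src, tgt, alignment))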
def forward(self, inputs, timestamps, hidden_states, reverse=False):
    b, seq, embed = inputs.size()
    h = hidden_states[0]
    _c = hidden_states[1]
    if self.cuda_flag:
        h = h.cuda()
        _c = _c.cuda()
    outputs = []
    hidden_state_h = []
    hidden_state_c = []
    for s in range(seq):
        # Short-term memory: a Möbius-linear transform of the cell state.
        c_s1 = pmath_geo.expmap0(
            torch.tanh(
                pmath_geo.logmap0(
                    pmath_geo.mobius_matvec(self.W_d, _c, c=self.c), c=self.c)),
            c=self.c)
        # Discounted short-term memory, scaled by the elapsed-time feature.
        c_s2 = pmath_geo.mobius_pointwise_mul(
            c_s1, timestamps[:, s:s + 1].expand_as(c_s1), c=self.c)
        # Long-term memory: remove the short-term part, re-add the discounted one.
        c_l = pmath_geo.mobius_add(-c_s1, _c, c=self.c)
        c_adj = pmath_geo.mobius_add(c_l, c_s2, c=self.c)
        W_f, W_i, W_o, W_c_tmp = self.W_all.chunk(4, dim=1)
        U_f, U_i, U_o, U_c_tmp = self.U_all.chunk(4, dim=0)
        # NB: these calls pass the curvature as the fifth argument, so they
        # assume a bias-free variant one_rnn_transform(W, h, U, x, c), not the
        # six-argument version defined earlier in this section.
        f = pmath_geo.logmap0(one_rnn_transform(W_f, h, U_f, inputs[:, s], self.c),
                              c=self.c).sigmoid()
        i = pmath_geo.logmap0(one_rnn_transform(W_i, h, U_i, inputs[:, s], self.c),
                              c=self.c).sigmoid()
        o = pmath_geo.logmap0(one_rnn_transform(W_o, h, U_o, inputs[:, s], self.c),
                              c=self.c).sigmoid()
        c_tmp = pmath_geo.logmap0(
            one_rnn_transform(W_c_tmp, h, U_c_tmp, inputs[:, s], self.c),
            c=self.c).sigmoid()
        # Standard LSTM mixing, done with Möbius operations.
        f_dot_c_adj = pmath_geo.mobius_pointwise_mul(f, c_adj, c=self.c)
        i_dot_c_tmp = pmath_geo.mobius_pointwise_mul(i, c_tmp, c=self.c)
        _c = pmath_geo.mobius_add(i_dot_c_tmp, f_dot_c_adj, c=self.c)
        h = pmath_geo.mobius_pointwise_mul(
            o, pmath_geo.expmap0(torch.tanh(_c), c=self.c), c=self.c)
        outputs.append(o)
        hidden_state_c.append(_c)
        hidden_state_h.append(h)
    if reverse:
        outputs.reverse()
        hidden_state_c.reverse()
        hidden_state_h.reverse()
    outputs = torch.stack(outputs, 1)
    hidden_state_c = torch.stack(hidden_state_c, 1)
    hidden_state_h = torch.stack(hidden_state_h, 1)
    return outputs, (h, _c)
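# Hypothetical driver for the time-aware hyperbolic LSTM forward above; the
# module instance `cell` and all sizes are assumptions for illustration.
import torch

batch, seq_len, dim = 4, 7, 16
inputs = torch.randn(batch, seq_len, dim) * 1e-2   # points near the ball origin
timestamps = torch.rand(batch, seq_len)            # per-step elapsed-time factors
h0 = torch.zeros(batch, dim)
c0 = torch.zeros(batch, dim)
outputs, (h, c_state) = cell(inputs, timestamps, (h0, c0))
# outputs: (batch, seq_len, dim); h, c_state: final hidden and cell states.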