def mobius_gru_cell(
    input: torch.Tensor,
    hx: torch.Tensor,
    weight_ih: torch.Tensor,
    weight_hh: torch.Tensor,
    bias: torch.Tensor,
    c: torch.Tensor,
    nonlin=None,
):
    W_ir, W_ih, W_iz = weight_ih.chunk(3)
    b_r, b_h, b_z = bias
    W_hr, W_hh, W_hz = weight_hh.chunk(3)
    # update (z) and reset (r) gates: apply the Mobius-affine transform, map to
    # the tangent space at the origin (logmap0), then a Euclidean sigmoid
    z_t = pmath.logmap0(one_rnn_transform(W_hz, hx, W_iz, input, b_z, c), c=c).sigmoid()
    r_t = pmath.logmap0(one_rnn_transform(W_hr, hx, W_ir, input, b_r, c), c=c).sigmoid()

    rh_t = pmath.mobius_pointwise_mul(r_t, hx, c=c)
    h_tilde = one_rnn_transform(W_hh, rh_t, W_ih, input, b_h, c)

    if nonlin is not None:
        h_tilde = pmath.mobius_fn_apply(nonlin, h_tilde, c=c)
    delta_h = pmath.mobius_add(-hx, h_tilde, c=c)
    h_out = pmath.mobius_add(hx, pmath.mobius_pointwise_mul(z_t, delta_h, c=c), c=c)
    return h_out
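A minimal smoke test for the cell above, assuming geoopt's Poincare-ball math module for pmath and the one_rnn_transform helper defined later in this file; the shapes are illustrative:

import torch
import geoopt.manifolds.poincare.math as pmath  # assumed Poincare-ball math module

torch.manual_seed(0)
batch, in_dim, hid_dim = 4, 6, 8
c = torch.tensor(1.0)

x = pmath.expmap0(0.1 * torch.randn(batch, in_dim), c=c)       # hyperbolic input
h0 = pmath.expmap0(0.1 * torch.randn(batch, hid_dim), c=c)     # hyperbolic hidden state
weight_ih = 0.1 * torch.randn(3 * hid_dim, in_dim)             # chunked into W_ir, W_ih, W_iz
weight_hh = 0.1 * torch.randn(3 * hid_dim, hid_dim)            # chunked into W_hr, W_hh, W_hz
bias = pmath.expmap0(0.01 * torch.randn(3, hid_dim), c=c)      # unpacked into b_r, b_h, b_z

h1 = mobius_gru_cell(x, h0, weight_ih, weight_hh, bias, c, nonlin=torch.tanh)
print(h1.shape)                       # torch.Size([4, 8])
print((h1.norm(dim=-1) < 1.0).all())  # stays inside the unit ball for c = 1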
Example #2
    def forward(self, x_h, x_e):
        # lift the Euclidean features x_e onto the Poincare ball
        x_e = pmath.project(pmath.expmap0(x_e, c=self.c), c=self.c)
        # attention score from the hyperbolic distance between the two views
        dist = pmath.dist(x_h, x_e, c=self.c) * self.att
        # Mobius-scale the lifted features by the score
        x_e = pmath.project(pmath.mobius_scalar_mul(dist.view(-1, 1), x_e, c=self.c),
                            c=self.c)
        x_e = F.dropout(x_e, p=self.drop, training=self.training)
        x_h = pmath.project(pmath.mobius_add(x_h, x_e, c=self.c), c=self.c)
        return x_h
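A toy trace of the same distance-gated fusion outside the class, assuming pmath is geoopt's Poincare-ball math module and using scalar stand-ins for the self.att / self.drop / self.c attributes:

import torch
import torch.nn.functional as F
import geoopt.manifolds.poincare.math as pmath

c, att, drop = torch.tensor(1.0), 0.5, 0.1         # stand-ins for self.c / self.att / self.drop
x_h = pmath.expmap0(0.1 * torch.randn(4, 8), c=c)  # hyperbolic branch
x_e = 0.1 * torch.randn(4, 8)                      # Euclidean branch

x_e = pmath.project(pmath.expmap0(x_e, c=c), c=c)
dist = pmath.dist(x_h, x_e, c=c) * att
x_e = pmath.project(pmath.mobius_scalar_mul(dist.view(-1, 1), x_e, c=c), c=c)
x_e = F.dropout(x_e, p=drop, training=False)       # identity outside training
out = pmath.project(pmath.mobius_add(x_h, x_e, c=c), c=c)
print(out.shape)  # torch.Size([4, 8])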
Example #3
import torch
from torch import Tensor
import geoopt.manifolds.poincare.math as pm  # assumed: geoopt's Poincare-ball math module

eps = 1e-5  # hypothetical value; the original defines this constant elsewhere

def exp_map_x_polar(x: Tensor, radius: Tensor, v: Tensor, c: Tensor) -> Tensor:
    v = v + eps  # perturb v to avoid dealing with v = 0
    norm_v = torch.norm(v, p=2, dim=-1, keepdim=True)
    assert len(radius.shape) == len(v.shape)
    assert radius.shape[:-1] == v.shape[:-1]

    # walk geodesic distance `radius` from x along the unit direction v / ||v||
    second_term = (torch.tanh(c.sqrt() * radius / 2) / (c.sqrt() * norm_v)) * v
    assert x.shape == second_term.shape
    return pm.mobius_add(x, second_term, c=c)
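By left cancellation, (-x) (+) (x (+) t) = t, so the returned point lies at geodesic distance radius from x; a quick check reusing the imports above:

c = torch.tensor(1.0)
x = pm.expmap0(0.1 * torch.randn(4, 3), c=c)   # base points on the ball
v = torch.randn(4, 3)                          # only the direction of v matters
radius = torch.rand(4, 1)                      # geodesic radii

y = exp_map_x_polar(x, radius, v, c)
print(torch.allclose(pm.dist(x, y, c=c, keepdim=True), radius, atol=1e-4))  # True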
Example #4
from typing import Any

def poincare_distance_c(x: Tensor,
                        y: Tensor,
                        c: Tensor,
                        keepdim: bool = True,
                        **kwargs: Any) -> Tensor:
    # closed form of pm.dist(x, y, c=c, keepdim=keepdim, **kwargs):
    # d_c(x, y) = (2 / sqrt(c)) * artanh(sqrt(c) * ||(-x) (+) y||)
    sqrt_c = c.sqrt()
    mob = pm.mobius_add(-x, y, c=c, dim=-1).norm(dim=-1, p=2, keepdim=keepdim)
    dist_c = torch.atanh(sqrt_c * mob)
    res = dist_c * 2 / sqrt_c
    assert torch.isfinite(res).all()
    return res
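Since this is the standard closed form of the c-ball distance, it should agree with the library call it replaces; a quick sanity check:

c = torch.tensor(1.0)
x = pm.expmap0(0.1 * torch.randn(5, 3), c=c)
y = pm.expmap0(0.1 * torch.randn(5, 3), c=c)

d = poincare_distance_c(x, y, c, keepdim=False)
print(torch.allclose(d, pm.dist(x, y, c=c), atol=1e-5))                # matches pm.dist
print(torch.allclose(d, poincare_distance_c(y, x, c, keepdim=False)))  # symmetric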
Example #5
def mobius_linear(
    input,
    weight,
    bias=None,
    hyperbolic_input=True,
    hyperbolic_bias=True,
    nonlin=None,
    c=1.0,
):
    if hyperbolic_input:
        output = pmath.mobius_matvec(weight, input, c=c)
    else:
        output = torch.nn.functional.linear(input, weight)
        output = pmath.expmap0(output, c=c)
    if bias is not None:
        if not hyperbolic_bias:
            bias = pmath.expmap0(bias, c=c)
        output = pmath.mobius_add(output, bias, c=c)
    if nonlin is not None:
        output = pmath.mobius_fn_apply(nonlin, output, c=c)
    output = pmath.project(output, c=c)
    return output
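A short usage sketch for mobius_linear covering both input conventions; the sizes and the geoopt import are assumptions:

import torch
import geoopt.manifolds.poincare.math as pmath  # assumed Poincare-ball math module

weight = 0.1 * torch.randn(16, 8)
bias = pmath.expmap0(0.01 * torch.randn(16), c=1.0)  # bias already on the ball

x_eucl = 0.1 * torch.randn(4, 8)  # Euclidean input: mapped in with expmap0 internally
y = mobius_linear(x_eucl, weight, bias=bias, hyperbolic_input=False, nonlin=torch.tanh)

x_hyp = pmath.expmap0(x_eucl, c=1.0)  # hyperbolic input: used directly
z = mobius_linear(x_hyp, weight, bias=bias)
print(y.shape, z.shape)  # torch.Size([4, 16]) torch.Size([4, 16])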
Example #6
def one_rnn_transform(W, h, U, x, b, c):
    # computes (W (x) h) (+) (U (x) x) (+) b, entirely on the Poincare ball
    W_otimes_h = pmath.mobius_matvec(W, h, c=c)
    U_otimes_x = pmath.mobius_matvec(U, x, c=c)
    Wh_plus_Ux = pmath.mobius_add(W_otimes_h, U_otimes_x, c=c)
    return pmath.mobius_add(Wh_plus_Ux, b, c=c)
Example #7
def mobius_add(x, y, c):
    return pmath.mobius_add(x, y, c=c)
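Unlike its Euclidean counterpart, Mobius addition is neither commutative nor associative; a small illustration (imports assumed as above):

import torch
import geoopt.manifolds.poincare.math as pmath

c = torch.tensor(1.0)
x = pmath.expmap0(torch.tensor([0.2, 0.0]), c=c)
y = pmath.expmap0(torch.tensor([0.0, 0.3]), c=c)

print(mobius_add(x, y, c))                     # x (+) y ...
print(mobius_add(y, x, c))                     # ... differs from y (+) x
print(mobius_add(-x, mobius_add(x, y, c), c))  # left cancellation recovers y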
Example #8
    def forward(self, input):
        source_input = input[0]
        target_input = input[1]
        alignment = input[2]
        batch_size = alignment.shape[0]

        source_input_data = self.embedding(source_input.data)
        target_input_data = self.embedding(target_input.data)

        zero_hidden = torch.zeros(self.num_layers,
                                  batch_size,
                                  self.hidden_dim,
                                  device=self.device or source_input.device,
                                  dtype=source_input_data.dtype)

        if self.embedding_type == "eucl" and "hyp" in self.cell_type:
            source_input_data = pmath.expmap0(source_input_data, c=self.c)
            target_input_data = pmath.expmap0(target_input_data, c=self.c)
        elif self.embedding_type == "hyp" and "eucl" in self.cell_type:
            source_input_data = pmath.logmap0(source_input_data, c=self.c)
            target_input_data = pmath.logmap0(target_input_data, c=self.c)
        # ht: (num_layers * num_directions, batch, hidden_size)

        source_input = torch.nn.utils.rnn.PackedSequence(
            source_input_data, source_input.batch_sizes)
        target_input = torch.nn.utils.rnn.PackedSequence(
            target_input_data, target_input.batch_sizes)

        _, source_hidden = self.cell_source(source_input, zero_hidden)
        _, target_hidden = self.cell_target(target_input, zero_hidden)

        # take hiddens from the last layer
        source_hidden = source_hidden[-1]
        target_hidden = target_hidden[-1][alignment]

        if self.decision_type == "hyp":
            if "eucl" in self.cell_type:
                source_hidden = pmath.expmap0(source_hidden, c=self.c)
                target_hidden = pmath.expmap0(target_hidden, c=self.c)
            source_projected = self.projector_source(source_hidden)
            target_projected = self.projector_target(target_hidden)
            projected = pmath.mobius_add(source_projected,
                                         target_projected,
                                         c=self.ball.c)
            if self.use_distance_as_feature:
                dist = (pmath.dist(source_hidden,
                                   target_hidden,
                                   dim=-1,
                                   keepdim=True,
                                   c=self.ball.c)**2)
                bias = pmath.mobius_scalar_mul(dist,
                                               self.dist_bias,
                                               c=self.ball.c)
                projected = pmath.mobius_add(projected, bias, c=self.ball.c)
        else:
            if "hyp" in self.cell_type:
                source_hidden = pmath.logmap0(source_hidden, c=self.c)
                target_hidden = pmath.logmap0(target_hidden, c=self.c)
            projected = self.projector(
                torch.cat((source_hidden, target_hidden), dim=-1))
            if self.use_distance_as_feature:
                dist = torch.sum((source_hidden - target_hidden).pow(2),
                                 dim=-1,
                                 keepdim=True)
                bias = self.dist_bias * dist
                projected = projected + bias

        logits = self.logits(projected)
        # CrossEntropy accepts logits
        return logits
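The expmap0 / logmap0 calls above move representations between the Euclidean and hyperbolic parts of the model; near the origin the two maps invert each other, which this small check illustrates (names and sizes are illustrative):

import torch
import geoopt.manifolds.poincare.math as pmath

c = torch.tensor(1.0)
emb = 0.1 * torch.randn(4, 16)  # Euclidean embeddings, as with embedding_type == "eucl"
hyp = pmath.expmap0(emb, c=c)   # lifted onto the ball for a "hyp" cell
back = pmath.logmap0(hyp, c=c)  # mapped back for a "eucl" decision head
print(torch.allclose(emb, back, atol=1e-5))  # True: logmap0 inverts expmap0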
Example #9
    def forward(self, inputs, timestamps, hidden_states, reverse=False):
        b, seq, embed = inputs.size()
        h = hidden_states[0]
        _c = hidden_states[1]
        if self.cuda_flag:
            h = h.cuda()
            _c = _c.cuda()
        outputs = []
        hidden_state_h = []
        hidden_state_c = []

        for s in range(seq):
            c_s1 = pmath_geo.expmap0(
                torch.tanh(
                    pmath_geo.logmap0(pmath_geo.mobius_matvec(self.W_d,
                                                              _c,
                                                              c=self.c),
                                      c=self.c)),
                c=self.c)  # short term mem
            c_s2 = pmath_geo.mobius_pointwise_mul(
                c_s1, timestamps[:, s:s + 1].expand_as(c_s1),
                c=self.c)  # discounted short term mem
            c_l = pmath_geo.mobius_add(-c_s1, _c, c=self.c)  # long term mem
            c_adj = pmath_geo.mobius_add(c_l, c_s2, c=self.c)

            W_f, W_i, W_o, W_c_tmp = self.W_all.chunk(4, dim=1)
            U_f, U_i, U_o, U_c_tmp = self.U_all.chunk(4, dim=0)
            # gate pre-activations via logmap0 + sigmoid; note these calls
            # assume a bias-free one_rnn_transform(W, h, U, x, c) variant,
            # unlike the six-argument helper defined above
            f = pmath_geo.logmap0(one_rnn_transform(W_f, h, U_f, inputs[:, s],
                                                    self.c),
                                  c=self.c).sigmoid()
            i = pmath_geo.logmap0(one_rnn_transform(W_i, h, U_i, inputs[:, s],
                                                    self.c),
                                  c=self.c).sigmoid()
            o = pmath_geo.logmap0(one_rnn_transform(W_o, h, U_o, inputs[:, s],
                                                    self.c),
                                  c=self.c).sigmoid()
            c_tmp = pmath_geo.logmap0(one_rnn_transform(
                W_c_tmp, h, U_c_tmp, inputs[:, s], self.c),
                                      c=self.c).sigmoid()

            f_dot_c_adj = pmath_geo.mobius_pointwise_mul(f, c_adj, c=self.c)
            i_dot_c_tmp = pmath_geo.mobius_pointwise_mul(i, c_tmp, c=self.c)
            _c = pmath_geo.mobius_add(i_dot_c_tmp, f_dot_c_adj, c=self.c)

            h = pmath_geo.mobius_pointwise_mul(o,
                                               pmath_geo.expmap0(
                                                   torch.tanh(_c), c=self.c),
                                               c=self.c)
            outputs.append(o)
            hidden_state_c.append(_c)
            hidden_state_h.append(h)

        if reverse:
            outputs.reverse()
            hidden_state_c.reverse()
            hidden_state_h.reverse()
        outputs = torch.stack(outputs, 1)
        hidden_state_c = torch.stack(hidden_state_c, 1)
        hidden_state_h = torch.stack(hidden_state_h, 1)

        return outputs, (h, _c)
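A standalone trace of the time-discounted memory decomposition from the loop above; the weight matrix, cell state, and timestamp weights are hypothetical stand-ins, and the argument order of mobius_pointwise_mul mirrors the code above:

import torch
import geoopt.manifolds.poincare.math as pmath_geo  # assumed Poincare-ball math module

c = torch.tensor(1.0)
cell = pmath_geo.expmap0(0.1 * torch.randn(4, 8), c=c)  # stand-in for the cell state _c
W_d = 0.1 * torch.randn(8, 8)                           # stand-in for self.W_d
decay = (0.1 * torch.rand(4, 1)).expand(4, 8)           # stand-in for the timestamp weights

c_s1 = pmath_geo.expmap0(torch.tanh(pmath_geo.logmap0(
    pmath_geo.mobius_matvec(W_d, cell, c=c), c=c)), c=c)  # short term mem
c_s2 = pmath_geo.mobius_pointwise_mul(c_s1, decay, c=c)   # discounted short term mem
c_l = pmath_geo.mobius_add(-c_s1, cell, c=c)              # long term mem
c_adj = pmath_geo.mobius_add(c_l, c_s2, c=c)              # adjusted cell state
print(c_adj.shape)  # torch.Size([4, 8])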