Example #1
def mobius_gru_cell(
    input: torch.Tensor,
    hx: torch.Tensor,
    weight_ih: torch.Tensor,
    weight_hh: torch.Tensor,
    bias: torch.Tensor,
    k: torch.Tensor,
    nonlin=None,
):
    # Split the stacked GRU parameters into reset / candidate / update blocks.
    W_ir, W_ih, W_iz = weight_ih.chunk(3)
    b_r, b_h, b_z = bias
    W_hr, W_hh, W_hz = weight_hh.chunk(3)

    # Gates are computed in the tangent space at the origin: map the Mobius
    # affine transform back with logmap0, then apply the sigmoid.
    z_t = gmath.logmap0(one_rnn_transform(W_hz, hx, W_iz, input, b_z, k),
                        k=k).sigmoid()
    r_t = gmath.logmap0(one_rnn_transform(W_hr, hx, W_ir, input, b_r, k),
                        k=k).sigmoid()

    # Candidate state from the reset-gated previous hidden state.
    rh_t = gmath.mobius_pointwise_mul(r_t, hx, k=k)
    h_tilde = one_rnn_transform(W_hh, rh_t, W_ih, input, b_h, k)

    if nonlin is not None:
        h_tilde = gmath.mobius_fn_apply(nonlin, h_tilde, k=k)
    # Hyperbolic analogue of h_out = hx + z_t * (h_tilde - hx).
    delta_h = gmath.mobius_add(-hx, h_tilde, k=k)
    h_out = gmath.mobius_add(hx,
                             gmath.mobius_pointwise_mul(z_t, delta_h, k=k),
                             k=k)
    return h_out
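A minimal usage sketch for the cell above. It assumes geoopt's stereographic math module imported as gmath, the one_rnn_transform helper from Example #10, and illustrative sizes; none of these choices are fixed by the snippet itself.

import torch
import geoopt.manifolds.stereographic.math as gmath

batch, input_dim, hidden_dim = 2, 4, 8
k = torch.tensor(-1.0)  # signed curvature; negative = Poincare ball

weight_ih = torch.randn(3 * hidden_dim, input_dim) * 0.01
weight_hh = torch.randn(3 * hidden_dim, hidden_dim) * 0.01
bias = torch.zeros(3, hidden_dim)  # unpacked into b_r, b_h, b_z

# Lift Euclidean noise onto the ball with the exponential map at the origin.
x = gmath.expmap0(torch.randn(batch, input_dim) * 0.01, k=k)
h0 = gmath.expmap0(torch.randn(batch, hidden_dim) * 0.01, k=k)

h1 = mobius_gru_cell(x, h0, weight_ih, weight_hh, bias, k)
print(h1.shape)  # torch.Size([2, 8]); the new hidden state stays on the ball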
Example #2
File: nets.py Project: nsuke/hyrnn
def one_rnn_transform(W, h, U, x, b, c):
    W_otimes_h = pmath.mobius_matvec(W, h, k=to_k(c, W))
    U_otimes_x = pmath.mobius_matvec(U, x, k=to_k(c, U))
    Wh_plus_Ux = pmath.mobius_add(W_otimes_h,
                                  U_otimes_x,
                                  k=to_k(c, W_otimes_h))
    return pmath.mobius_add(Wh_plus_Ux, b, k=to_k(c, Wh_plus_Ux))
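The to_k helper is not defined anywhere on this page. A minimal sketch of what it plausibly does, assuming geoopt's sign convention (the stereographic functions take a signed curvature k, negative on the Poincare ball, while this project passes a positive c):

import torch

def to_k(c, t):
    # Hypothetical helper: turn the positive curvature parameter c into the
    # signed curvature k = -c expected by pmath, on t's dtype and device.
    return -torch.as_tensor(c, dtype=t.dtype, device=t.device)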
Example #3
 def transition(self, x, h):
     """
     :param x: batch x input
     :param h: batch x hidden
     :return: batch x hidden
     """
     W_otimes_h = pmath.mobius_matvec(self.w, h, k=self.ball.k)
     U_otimes_x = pmath.mobius_matvec(self.u, x, k=self.ball.k)
     Wh_plus_Ux = pmath.mobius_add(W_otimes_h, U_otimes_x, k=self.ball.k)
     
     return pmath.mobius_add(Wh_plus_Ux, self.b, k=self.ball.k)
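This is the same Mobius affine transform as one_rnn_transform above, written as a method that reads its weights (self.w, self.u, self.b) and the curvature (self.ball.k) from the module instead of taking them as arguments.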
Example #4
def mobius_linear(
    input,
    weight,
    bias=None,
    hyperbolic_input=True,
    hyperbolic_bias=True,
    nonlin=None,
    k=-1.0,
):
    k = torch.tensor(k)
    if hyperbolic_input:
        output = gmath.mobius_matvec(weight, input, k=k)
    else:
        output = torch.nn.functional.linear(input, weight)
        output = gmath.expmap0(output, k=k)
    if bias is not None:
        if not hyperbolic_bias:
            bias = gmath.expmap0(bias, k=k)
        output = gmath.mobius_add(output,
                                  bias.unsqueeze(0).expand_as(output),
                                  k=k)
    if nonlin is not None:
        output = gmath.mobius_fn_apply(nonlin, output, k=k)
    output = gmath.project(output, k=k)
    return output
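An illustrative call with Euclidean inputs that the function lifts onto the ball; the sizes are arbitrary and torch.tanh stands in for any pointwise nonlinearity:

import torch

x = torch.randn(5, 4) * 0.01   # Euclidean batch
w = torch.randn(3, 4) * 0.01
b = torch.zeros(3)             # the origin is a valid hyperbolic bias
y = mobius_linear(x, w, bias=b, hyperbolic_input=False,
                  nonlin=torch.tanh, k=-1.0)
print(y.shape)                 # torch.Size([5, 3]), projected onto the ball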
Example #5
def exp_map_x_polar(x: Tensor, radius: Tensor, v: Tensor, c: Tensor) -> Tensor:
    # Exponential map at x in polar form: move from x along direction v / ||v||
    # by hyperbolic distance `radius`. `eps` is a small module-level constant
    # (e.g. 1e-15) that perturbs v to avoid dealing with v = 0.
    v = v + eps
    norm_v = torch.norm(v, p=2, dim=-1, keepdim=True)
    assert len(radius.shape) == len(v.shape)
    assert radius.shape[:-1] == v.shape[:-1]

    second_term = (torch.tanh(c.sqrt() * radius / 2) / (c.sqrt() * norm_v)) * v
    assert x.shape == second_term.shape
    return pm.mobius_add(x, second_term, c=c)
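A usage sketch under two assumptions: pm is geoopt's older Poincare math module (the positive-c convention this function uses), and eps is a small module-level constant such as 1e-15:

import torch
import geoopt.manifolds.poincare.math as pm  # older geoopt API

eps = 1e-15
c = torch.tensor(1.0)
x = torch.zeros(2, 3)   # start at the origin of the ball
v = torch.randn(2, 3)   # directions; only v / ||v|| matters
r = torch.rand(2, 1)    # hyperbolic radii, shape (..., 1)
y = exp_map_x_polar(x, r, v, c)  # points at hyperbolic distance r from x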
Example #6
def poincare_distance_c(x: Tensor, y: Tensor, c: Tensor, keepdim: bool = True, **kwargs: Any) -> Tensor:
    # Closed form of the Poincare distance, equivalent to
    # pm.dist(x, y, c=c, keepdim=keepdim, **kwargs).
    sqrt_c = torch.sqrt(c)
    mob = pm.mobius_add(-x, y, c=c, dim=-1).norm(dim=-1, p=2, keepdim=keepdim)
    arg = sqrt_c * mob
    dist_c = torch.atanh(arg)
    res = dist_c * 2 / sqrt_c
    assert torch.isfinite(res).all()
    return res
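A quick sanity check, assuming the same pm import: at the origin the Mobius sum (-x) ⊕ y reduces to y, so the distance collapses to (2 / sqrt(c)) * atanh(sqrt(c) * ||y||):

import torch

c = torch.tensor(1.0)
y = torch.tensor([[0.3, 0.0]])
x = torch.zeros_like(y)
d = poincare_distance_c(x, y, c)  # 2 * atanh(0.3) ≈ 0.6190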
Example #7
 def _dist2plane(self, x, a, p, c, k, keepdim: bool = False, signed: bool = False, dim: int = -1):
     """
     Taken from geoopt and corrected so it returns a_norm and this value does not have to be calculated twice
     """
     sqrt_c = c ** 0.5
     minus_p_plus_x = pmath.mobius_add(-p, x, k=k, dim=dim)
     mpx_sqnorm = minus_p_plus_x.pow(2).sum(dim=dim, keepdim=keepdim).clamp_min(MIN_NORM)
     mpx_dot_a = (minus_p_plus_x * a).sum(dim=dim, keepdim=keepdim)
     if not signed:
         mpx_dot_a = mpx_dot_a.abs()
     a_norm = a.norm(dim=dim, keepdim=keepdim, p=2).clamp_min(MIN_NORM)
     num = 2 * sqrt_c * mpx_dot_a
     denom = (1 - c * mpx_sqnorm) * a_norm
     return pmath.arsinh(num / denom.clamp_min(MIN_NORM)) / sqrt_c, a_norm
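Written out, the returned distance is arsinh(2·sqrt(c)·|⟨(−p) ⊕ x, a⟩| / ((1 − c·||(−p) ⊕ x||²)·||a||)) / sqrt(c): the distance from x to the Poincare hyperplane through p with normal direction a.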
Example #8
    def __init__(self, args, pos_embeds_rows, input_dims):
        super(DistanceAttention, self).__init__()

        # Position embeddings live on the manifold selected by the attention metric.
        self.manifold = (gt.PoincareBall()
                         if args.attn_metric == cs.HY else gt.Euclidean())
        with torch.no_grad():
            pos_embeds = init_embeddings(pos_embeds_rows, input_dims)
            if args.attn_metric == cs.HY:
                pos_embeds = pmath.expmap0(pos_embeds, k=self.manifold.k)
            beta = torch.Tensor(1).uniform_(-0.01, 0.01)
        self.position_embeds = gt.ManifoldParameter(pos_embeds,
                                                    manifold=self.manifold)

        if args.attn_metric == cs.HY:
            self.key_dense = hnn.MobiusLinear(input_dims,
                                              input_dims,
                                              hyperbolic_input=True,
                                              hyperbolic_bias=True)
            self.query_dense = hnn.MobiusLinear(input_dims,
                                                input_dims,
                                                hyperbolic_input=True,
                                                hyperbolic_bias=True)
            self.addition = lambda x, y: pmath.mobius_add(
                x, y, k=self.manifold.k)
            self.distance_function = lambda x, y: pmath.dist(
                x, y, k=self.manifold.k)
            self.midpoint = hnn.weighted_mobius_midpoint
        else:
            self.key_dense = nn.Linear(input_dims, input_dims)
            self.query_dense = nn.Linear(input_dims, input_dims)
            self.addition = torch.add
            self.distance_function = utils.euclidean_distance
            self.midpoint = hnn.weighted_euclidean_midpoint

        self.encoder_to_attn_map = define_mapping(args.encoder_metric,
                                                  args.attn_metric, args.c)
        self.attention_function = nn.Softmax(
            dim=1) if args.attn == "softmax" else nn.Sigmoid()
        self.beta = torch.nn.Parameter(beta, requires_grad=True)
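Choosing the metric-dependent operations (dense layers, addition, distance, midpoint) once in __init__ keeps the rest of the attention module metric-agnostic: the forward pass can call self.addition, self.distance_function, and self.midpoint without branching on args.attn_metric again.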
Example #9
File: nets.py Project: nsuke/hyrnn
def mobius_linear(
    input,
    weight,
    bias=None,
    hyperbolic_input=True,
    hyperbolic_bias=True,
    nonlin=None,
    c=1.0,
):
    if hyperbolic_input:
        output = pmath.mobius_matvec(weight, input, k=to_k(c, weight))
    else:
        output = torch.nn.functional.linear(input, weight)
        output = pmath.expmap0(output, k=to_k(c, output))
    if bias is not None:
        if not hyperbolic_bias:
            bias = pmath.expmap0(bias, k=to_k(c, bias))
        output = pmath.mobius_add(output, bias, k=to_k(c, output))
    if nonlin is not None:
        output = pmath.mobius_fn_apply(nonlin, output, k=to_k(c, output))
    output = pmath.project(output, k=to_k(c, output))
    return output
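This computes the same Mobius linear layer as Example #4; the difference is the curvature convention. Here the signature keeps a positive c and converts per call via to_k (sketched under Example #2), whereas Example #4 takes the signed curvature k directly and expands the bias explicitly instead of relying on broadcasting.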
Example #10
def one_rnn_transform(W, h, U, x, b, k):
    W_otimes_h = gmath.mobius_matvec(W, h, k=k)
    U_otimes_x = gmath.mobius_matvec(U, x, k=k)
    Wh_plus_Ux = gmath.mobius_add(W_otimes_h, U_otimes_x, k=k)
    return gmath.mobius_add(Wh_plus_Ux, b, k=k)
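A quick shape check, again assuming gmath is geoopt's stereographic math module and illustrative sizes:

import torch
import geoopt.manifolds.stereographic.math as gmath

k = torch.tensor(-1.0)
W = torch.randn(4, 4) * 0.01
U = torch.randn(4, 3) * 0.01
b = torch.zeros(4)
h = gmath.expmap0(torch.randn(2, 4) * 0.01, k=k)
x = gmath.expmap0(torch.randn(2, 3) * 0.01, k=k)
out = one_rnn_transform(W, h, U, x, b, k)  # shape (2, 4), points on the ball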
Example #11
File: model.py Project: nsuke/hyrnn
    def forward(self, input):
        source_input = input[0]
        target_input = input[1]
        alignment = input[2]
        batch_size = alignment.shape[0]

        source_input_data = self.embedding(source_input.data)
        target_input_data = self.embedding(target_input.data)

        zero_hidden = torch.zeros(self.num_layers,
                                  batch_size,
                                  self.hidden_dim,
                                  device=self.device or source_input.device,
                                  dtype=source_input_data.dtype)

        if self.embedding_type == "eucl" and "hyp" in self.cell_type:
            source_input_data = pmath.expmap0(source_input_data, c=self.c)
            target_input_data = pmath.expmap0(target_input_data, c=self.c)
        elif self.embedding_type == "hyp" and "eucl" in self.cell_type:
            source_input_data = pmath.logmap0(source_input_data, c=self.c)
            target_input_data = pmath.logmap0(target_input_data, c=self.c)
        # ht: (num_layers * num_directions, batch, hidden_size)

        source_input = torch.nn.utils.rnn.PackedSequence(
            source_input_data, source_input.batch_sizes)
        target_input = torch.nn.utils.rnn.PackedSequence(
            target_input_data, target_input.batch_sizes)

        _, source_hidden = self.cell_source(source_input, zero_hidden)
        _, target_hidden = self.cell_target(target_input, zero_hidden)

        # take hiddens from the last layer
        source_hidden = source_hidden[-1]
        target_hidden = target_hidden[-1][alignment]

        if self.decision_type == "hyp":
            if "eucl" in self.cell_type:
                source_hidden = pmath.expmap0(source_hidden, c=self.c)
                target_hidden = pmath.expmap0(target_hidden, c=self.c)
            source_projected = self.projector_source(source_hidden)
            target_projected = self.projector_target(target_hidden)
            projected = pmath.mobius_add(source_projected,
                                         target_projected,
                                         c=self.ball.c)
            if self.use_distance_as_feature:
                dist = (pmath.dist(source_hidden,
                                   target_hidden,
                                   dim=-1,
                                   keepdim=True,
                                   c=self.ball.c)**2)
                bias = pmath.mobius_scalar_mul(dist,
                                               self.dist_bias,
                                               c=self.ball.c)
                projected = pmath.mobius_add(projected, bias, c=self.ball.c)
        else:
            if "hyp" in self.cell_type:
                source_hidden = pmath.logmap0(source_hidden, c=self.c)
                target_hidden = pmath.logmap0(target_hidden, c=self.c)
            projected = self.projector(
                torch.cat((source_hidden, target_hidden), dim=-1))
            if self.use_distance_as_feature:
                dist = torch.sum((source_hidden - target_hidden).pow(2),
                                 dim=-1,
                                 keepdim=True)
                bias = self.dist_bias * dist
                projected = projected + bias

        logits = self.logits(projected)
        # CrossEntropy accepts logits
        return logits
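Note how the exp/log maps at the boundaries reconcile mismatched geometries: Euclidean embeddings are lifted onto the ball before a hyperbolic cell, hyperbolic hidden states are mapped back to the tangent space before a Euclidean decision layer, and the squared distance between the two hidden states can be mixed in as an extra feature in either geometry.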
Example #12
 def add(self, a, b):
     out = pmath.mobius_add(a, b, k=self.ball.k)
     return pmath.project(out, k=self.ball.k)
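Projecting after the Mobius addition keeps the result numerically strictly inside the ball, so accumulated floating-point error cannot push points onto or past the boundary, where subsequent operations (log maps, distances) would blow up.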