Exemplo n.º 1
0
def mobius_gru_loop(
    input: torch.Tensor,
    h0: torch.Tensor,
    weight_ih: torch.Tensor,
    weight_hh: torch.Tensor,
    bias: torch.Tensor,
    c: torch.Tensor,
    batch_sizes=None,
    hyperbolic_input: bool = False,
    hyperbolic_hidden_state0: bool = False,
    nonlin=None,
):
    """Apply ``mobius_gru_cell`` step by step over a whole sequence.

    Args:
        input: Sequence data. Without ``batch_sizes`` it is split along
            dimension 0, one slice per step. With ``batch_sizes`` it is
            consumed in consecutive chunks of ``batch_sizes[t]`` rows.
        h0: Initial hidden state.
        weight_ih: Forwarded unchanged to ``mobius_gru_cell``.
        weight_hh: Forwarded unchanged to ``mobius_gru_cell``.
        bias: Forwarded unchanged to ``mobius_gru_cell``.
        c: Curvature argument given to ``pmath.expmap0`` and to the cell.
        batch_sizes: Optional per-step row counts (presumably the
            ``batch_sizes`` of a ``PackedSequence`` -- confirm with callers).
        hyperbolic_input: When false, ``input`` is first passed through
            ``pmath.expmap0``.
        hyperbolic_hidden_state0: When false, ``h0`` is first passed through
            ``pmath.expmap0``.
        nonlin: Forwarded unchanged to ``mobius_gru_cell``.

    Returns:
        A pair ``(outs, h_last)``. Without ``batch_sizes``: the per-step
        states stacked along a new leading dimension, and the state after the
        final step. With ``batch_sizes``: the per-step states concatenated
        along dimension 0, and the rows that dropped out of the running state
        at each step concatenated in reverse step order.
    """
    hx = h0 if hyperbolic_hidden_state0 else pmath.expmap0(h0, c=c)
    if not hyperbolic_input:
        input = pmath.expmap0(input, c=c)

    def step(x_t, h_prev):
        # Single recurrence step; everything except the data is fixed.
        return mobius_gru_cell(
            input=x_t,
            hx=h_prev,
            weight_ih=weight_ih,
            weight_hh=weight_hh,
            bias=bias,
            nonlin=nonlin,
            c=c,
        )

    outs = []
    if batch_sizes is None:
        for x_t in input.unbind(0):
            hx = step(x_t, hx)
            outs.append(hx)
        return torch.stack(outs), hx

    finished = []
    last_step = len(batch_sizes) - 1
    for t in range(batch_sizes.size(0)):
        n_rows = batch_sizes[t]
        # Peel this step's rows off the front of the remaining data.
        ix, input = input[:n_rows], input[n_rows:]
        hx = step(ix, hx)
        outs.append(hx)
        if t < last_step:
            # Keep only the rows that continue into the next step; the
            # trailing rows are set aside as they will not be updated again.
            n_next = batch_sizes[t + 1]
            hx, ht = hx[:n_next], hx[n_next:]
            finished.append(ht)
        else:
            finished.append(hx)
    finished.reverse()
    h_last = torch.cat(finished)
    outs = torch.cat(outs)
    return outs, h_last
Exemplo n.º 2
0
 def forward(self, x_h, x_e):
     """Fuse ``x_e`` into ``x_h`` and return the updated ``x_h``.

     ``x_e`` is passed through ``pmath.expmap0`` and ``pmath.project``, scaled
     (``pmath.mobius_scalar_mul``) by its ``pmath.dist`` to ``x_h`` times
     ``self.att``, dropped out, and finally combined with ``x_h`` via
     ``pmath.mobius_add``.

     Args:
         x_h: First input; used directly in ``pmath.dist``/``pmath.mobius_add``.
         x_e: Second input; mapped with ``pmath.expmap0`` before use.

     Returns:
         The projected result of ``mobius_add(x_h, scaled x_e)``.
     """
     # The mapped-and-projected x_e is needed twice (distance and scaling);
     # compute it once instead of repeating the same two calls.
     e_mapped = pmath.project(pmath.expmap0(x_e, c=self.c))
     # NOTE(review): dist, the inner project and mobius_add are called without
     # c=self.c while the other calls pass it -- confirm the defaults match
     # self.c or that this is intentional.
     weight = pmath.dist(x_h, e_mapped) * self.att
     scaled = pmath.mobius_scalar_mul(weight.view([-1, 1]), e_mapped)
     scaled = pmath.project(scaled, c=self.c)
     scaled = F.dropout(scaled, p=self.drop, training=self.training)
     return pmath.project(pmath.mobius_add(x_h, scaled))
Exemplo n.º 3
0
 def forward(self, input):
     """Run ``self.conv`` on log-mapped features and map the result back.

     Args:
         input: A pair ``(x, adj)``; ``adj`` is handed to ``self.conv``
             unchanged and returned as-is.

     Returns:
         ``(h, adj)`` where ``h`` is ``relu(project(expmap0(dropout(conv))))``.
     """
     features, adj = input
     tangent = pmath.logmap0(features)
     # self.conv returns a pair; only its first element is used here.
     tangent, _ = self.conv((tangent, adj))
     tangent = F.dropout(tangent, p=self.p, training=self.training)
     mapped = pmath.project(pmath.expmap0(tangent))
     return F.relu(mapped), adj
Exemplo n.º 4
0
    def __init__(self, c, args):
        """Build the Euclidean and hyperbolic classification heads.

        Args:
            c: Curvature value; passed to the parent constructor, to
                ``geoopt.PoincareBall`` and to ``pmath.expmap0``.
            args: Configuration namespace. Attributes read here: ``manifold``,
                ``dim``, ``act``, ``dataset``, ``n_classes``, ``concat``
                (non-pubmed only), ``alpha``, ``dropout``, ``bias``,
                ``atten``, ``dist``, ``device``, ``drop_e``, ``drop_h``.
        """
        super(DualDecoder, self).__init__(c)
        self.manifold = getattr(manifolds, args.manifold)()
        self.in_features = args.dim
        act = getattr(F, args.act)

        # The two dataset configurations differ only in the third and fourth
        # positional arguments of the conv layers (presumably head count and
        # the concat flag -- confirm against GATConv/HGATConv).
        if args.dataset == 'pubmed':
            n_heads, concat = 8, False
        else:
            n_heads, concat = 1, args.concat

        self.cls_e = GATConv(
            self.in_features,
            args.n_classes,
            n_heads,
            concat,
            args.alpha,
            args.dropout,
            args.bias,
            lambda x: x,
        )
        self.cls_h = HGATConv(
            self.manifold,
            self.in_features,
            args.dim,
            n_heads,
            concat,
            args.alpha,
            args.dropout,
            args.bias,
            act,
            atten=args.atten,
            dist=args.dist,
        )

        self.out_features = args.n_classes
        self.c = c
        ball = geoopt.PoincareBall(c=c)
        self.ball = ball
        sphere = geoopt.manifolds.Sphere()
        self.sphere = sphere
        self.scale = nn.Parameter(torch.zeros(self.out_features))

        # Random draws stay in the original order (point first, then tangent)
        # so seeded initialisation is unchanged.
        point = torch.randn(self.out_features, self.in_features) / 4
        point = pmath.expmap0(point.to(args.device), c=c)
        tangent = torch.randn(self.out_features, self.in_features)
        self.point = geoopt.ManifoldParameter(point, manifold=ball)
        with torch.no_grad():
            self.tangent = geoopt.ManifoldParameter(
                tangent, manifold=sphere).proj_()

        self.decoder_name = 'DualDecoder'

        # Probability weights: one bias-free linear map per head output.
        self.w_e = nn.Linear(args.n_classes, 1, bias=False)
        self.w_h = nn.Linear(args.dim, 1, bias=False)
        self.drop_e = args.drop_e
        self.drop_h = args.drop_h
        self.reset_param()
Exemplo n.º 5
0
def mobius_linear(
    input,
    weight,
    bias=None,
    hyperbolic_input=True,
    hyperbolic_bias=True,
    nonlin=None,
    c=1.0,
):
    """Linear layer expressed with Mobius operations from ``pmath``.

    Args:
        input: Layer input.
        weight: Weight matrix.
        bias: Optional bias; skipped when ``None``.
        hyperbolic_input: When true the product is ``pmath.mobius_matvec``;
            otherwise an ordinary ``torch.nn.functional.linear`` followed by
            ``pmath.expmap0``.
        hyperbolic_bias: When false the bias is first passed through
            ``pmath.expmap0``.
        nonlin: Optional function applied through ``pmath.mobius_fn_apply``.
        c: Curvature argument passed to every ``pmath`` call.

    Returns:
        The result after a final ``pmath.project``.
    """
    if not hyperbolic_input:
        flat = torch.nn.functional.linear(input, weight)
        result = pmath.expmap0(flat, c=c)
    else:
        result = pmath.mobius_matvec(weight, input, c=c)

    if bias is not None:
        ball_bias = bias if hyperbolic_bias else pmath.expmap0(bias, c=c)
        result = pmath.mobius_add(result, ball_bias, c=c)

    if nonlin is not None:
        result = pmath.mobius_fn_apply(nonlin, result, c=c)

    return pmath.project(result, c=c)
Exemplo n.º 6
0
 def __init__(self, in_features, out_features, c=1.0):
     """Create the ``scale``, ``point`` and ``tangent`` parameters.

     Args:
         in_features: Size of the last dimension of ``point``/``tangent``.
         out_features: Number of rows of ``point``/``tangent`` and length of
             ``scale``.
         c: Curvature passed to ``geoopt.PoincareBall`` and ``pmath.expmap0``.
     """
     super().__init__()
     self.in_features, self.out_features = in_features, out_features

     ball = geoopt.PoincareBall(c=c)
     self.ball = ball
     sphere = geoopt.manifolds.Sphere()
     self.sphere = sphere

     self.scale = torch.nn.Parameter(torch.zeros(out_features))

     # Draw order (point, then tangent) is kept so seeded runs are unchanged.
     shape = (out_features, in_features)
     start = pmath.expmap0(torch.randn(*shape) / 4, c=c)
     direction = torch.randn(*shape)

     self.point = geoopt.ManifoldParameter(start, manifold=ball)
     with torch.no_grad():
         # proj_() is applied in place on the freshly created parameter.
         self.tangent = geoopt.ManifoldParameter(direction, manifold=sphere).proj_()
Exemplo n.º 7
0
 def __init__(
     self,
     input_size,
     hidden_size,
     num_layers=1,
     bias=True,
     nonlin=None,
     hyperbolic_input=True,
     hyperbolic_hidden_state0=True,
     c=1.0,
 ):
     """Allocate per-layer weights and biases of the recurrent module.

     Args:
         input_size: Feature size of the first layer's input.
         hidden_size: Hidden size; each weight has ``3 * hidden_size`` rows.
         num_layers: Number of stacked layers.
         bias: When true, one ``(3, hidden_size)`` manifold parameter per
             layer is created; when false, ``self.bias`` is ``None``.
         nonlin: Stored on ``self.nonlin``.
         hyperbolic_input: Stored on ``self.hyperbolic_input``.
         hyperbolic_hidden_state0: Stored on ``self.hyperbolic_hidden_state0``.
         c: Curvature passed to ``geoopt.PoincareBall``.
     """
     super().__init__()
     self.ball = geoopt.PoincareBall(c=c)
     self.input_size = input_size
     self.hidden_size = hidden_size
     self.num_layers = num_layers
     # ``self.bias`` is deliberately NOT pre-assigned from the boolean flag:
     # register_buffer() raises KeyError("attribute 'bias' already exists")
     # when a plain attribute of that name is present, which made
     # ``bias=False`` unusable.
     # Weights are allocated uninitialised; reset_parameters() is called at
     # the end of this constructor.
     self.weight_ih = torch.nn.ParameterList(
         [
             torch.nn.Parameter(
                 torch.Tensor(3 * hidden_size, input_size if i == 0 else hidden_size)
             )
             for i in range(num_layers)
         ]
     )
     self.weight_hh = torch.nn.ParameterList(
         [
             torch.nn.Parameter(torch.Tensor(3 * hidden_size, hidden_size))
             for _ in range(num_layers)
         ]
     )
     if bias:
         layer_biases = []
         for _ in range(num_layers):
             # Small random start, mapped with expmap0 before being wrapped
             # as a manifold parameter.
             init = torch.randn(3, hidden_size) * 1e-5
             layer_biases.append(
                 geoopt.ManifoldParameter(
                     pmath.expmap0(init, c=self.ball.c), manifold=self.ball
                 )
             )
         self.bias = torch.nn.ParameterList(layer_biases)
     else:
         self.register_buffer("bias", None)
     self.nonlin = nonlin
     self.hyperbolic_input = hyperbolic_input
     self.hyperbolic_hidden_state0 = hyperbolic_hidden_state0
     self.reset_parameters()
Exemplo n.º 8
0
 def __init__(self,
              *args,
              hyperbolic_input=True,
              hyperbolic_bias=True,
              nonlin=None,
              c=1.0,
              **kwargs):
     """Initialise the parent layer, then re-initialise weight and bias.

     Args:
         *args: Forwarded to the parent constructor.
         hyperbolic_input: Stored on ``self.hyperbolic_input``.
         hyperbolic_bias: When true (and the parent created a bias), the bias
             is re-wrapped as a ``geoopt.ManifoldParameter`` on a
             ``PoincareBall`` and re-initialised through ``pmath.expmap0``.
         nonlin: Stored on ``self.nonlin``.
         c: Curvature for the ball and for ``pmath.expmap0``.
         **kwargs: Forwarded to the parent constructor.
     """
     super().__init__(*args, **kwargs)

     wrap_bias = self.bias is not None and hyperbolic_bias
     if wrap_bias:
         manifold = geoopt.PoincareBall(c=c)
         self.ball = manifold
         self.bias = geoopt.ManifoldParameter(self.bias, manifold=manifold)
         with torch.no_grad():
             # Normal draw scaled by 1/4, then mapped with expmap0.
             fresh = pmath.expmap0(self.bias.normal_() / 4, c=c)
             self.bias.set_(fresh)

     with torch.no_grad():
         self.weight.normal_(std=1e-2)

     self.hyperbolic_bias = hyperbolic_bias
     self.hyperbolic_input = hyperbolic_input
     self.nonlin = nonlin
Exemplo n.º 9
0
    def forward(self, sentence_feats, len_tweets, time_feats):
        """Encode each day's items with ``self.lstm1``, then the day sequence.

        Shapes as given by the original author's notes (not verified here):
        ``sentence_feats`` is (B, 5, 30, N), ``len_tweets`` is (B, 5) and
        ``time_feats`` is (B, 5, 30).

        Side effect: ``self.bs`` is set from the (permuted) input size before
        ``self.init_hidden()`` is called.

        Returns:
            ``(cse_output, margin_output)``: the outputs of ``self.linear4``
            and ``self.linear5`` applied to the same hidden representation.
        """
        # Only the sentence features are mapped with expmap0; time_feats is
        # used as given.
        sentence_feats = pmath_geo.expmap0(sentence_feats, c=self.c)
        sentence_feats = sentence_feats.permute(1, 0, 2, 3)
        len_days, self.bs, _, _ = sentence_feats.size()
        h_init, c_init = self.init_hidden()

        len_tweets = len_tweets.permute(1, 0)
        time_feats = time_feats.permute(1, 0, 2)

        day_states = torch.zeros(len_days, self.bs,
                                 self.lstm1_outshape).to(self.device)

        for day in range(len_days):
            day_out, (_, _) = self.lstm1(sentence_feats[day], time_feats[day],
                                         (h_init, c_init))
            lengths = len_tweets[day].type(torch.int).tolist()
            picked = torch.zeros(self.bs, 1, self.lstm1_outshape)
            for sample in range(self.bs):
                # Output at position (length - 1) for this sample.
                picked[sample] = day_out[sample, lengths[sample] - 1, :]
            day_states[day] = picked.squeeze(1)

        # Hyperbolic recurrent cell over the day axis. day_states is already
        # laid out day-first, which is what the cell was previously given
        # after two mutually cancelling permutes.
        _, cell_output = self.cell_source(day_states)

        last_layer = cell_output[-1]
        hidden = pmath_geo.logmap0(last_layer, c=self.c)
        hidden = self.drop(self.relu(self.linear3(hidden)))
        cse_output = self.linear4(hidden)
        margin_output = self.linear5(hidden)
        return cse_output, margin_output
Exemplo n.º 10
0
def exp_0_map(v, c, dim=-1):
    """Apply ``pmath.expmap0`` to ``v`` and pass the result to ``project_hyp_vecs``.

    Args:
        v: Input tensor.
        c: Curvature argument used by both calls.
        dim: Dimension argument forwarded to ``pmath.expmap0``.
    """
    mapped = pmath.expmap0(v, c=c, dim=dim)
    projected = project_hyp_vecs(mapped, c)
    return projected
Exemplo n.º 11
0
    def forward(self, input):
        """Compute logits for a (source, target) pair of packed sequences.

        Args:
            input: Indexable with ``input[0]`` (source sequence), ``input[1]``
                (target sequence) and ``input[2]`` (alignment index applied
                to the target hidden states). Both sequences must expose
                ``.data`` and ``.batch_sizes``.

        Returns:
            The output of ``self.logits`` (unnormalised scores).
        """
        source_seq = input[0]
        target_seq = input[1]
        alignment = input[2]
        batch_size = alignment.shape[0]

        source_emb = self.embedding(source_seq.data)
        target_emb = self.embedding(target_seq.data)

        # One all-zero initial state, shared by both recurrent cells.
        zero_hidden = torch.zeros(
            self.num_layers,
            batch_size,
            self.hidden_dim,
            device=self.device or source_seq.device,
            dtype=source_emb.dtype,
        )

        # Convert the embeddings when embedding and cell types disagree.
        if self.embedding_type == "eucl" and "hyp" in self.cell_type:
            convert = pmath.expmap0
        elif self.embedding_type == "hyp" and "eucl" in self.cell_type:
            convert = pmath.logmap0
        else:
            convert = None
        if convert is not None:
            source_emb = convert(source_emb, c=self.c)
            target_emb = convert(target_emb, c=self.c)

        # Re-wrap the processed data with the original batch sizes.
        packed_source = torch.nn.utils.rnn.PackedSequence(
            source_emb, source_seq.batch_sizes)
        packed_target = torch.nn.utils.rnn.PackedSequence(
            target_emb, target_seq.batch_sizes)

        _, source_hidden = self.cell_source(packed_source, zero_hidden)
        _, target_hidden = self.cell_target(packed_target, zero_hidden)

        # Last layer only; target rows are re-ordered through ``alignment``.
        source_hidden = source_hidden[-1]
        target_hidden = target_hidden[-1][alignment]

        if self.decision_type == "hyp":
            ball_c = self.ball.c
            if "eucl" in self.cell_type:
                source_hidden = pmath.expmap0(source_hidden, c=self.c)
                target_hidden = pmath.expmap0(target_hidden, c=self.c)
            projected = pmath.mobius_add(
                self.projector_source(source_hidden),
                self.projector_target(target_hidden),
                c=ball_c,
            )
            if self.use_distance_as_feature:
                sq_dist = pmath.dist(source_hidden,
                                     target_hidden,
                                     dim=-1,
                                     keepdim=True,
                                     c=ball_c)**2
                dist_term = pmath.mobius_scalar_mul(sq_dist,
                                                    self.dist_bias,
                                                    c=ball_c)
                projected = pmath.mobius_add(projected, dist_term, c=ball_c)
        else:
            if "hyp" in self.cell_type:
                source_hidden = pmath.logmap0(source_hidden, c=self.c)
                target_hidden = pmath.logmap0(target_hidden, c=self.c)
            pair = torch.cat((source_hidden, target_hidden), dim=-1)
            projected = self.projector(pair)
            if self.use_distance_as_feature:
                sq_dist = torch.sum((source_hidden - target_hidden).pow(2),
                                    dim=-1,
                                    keepdim=True)
                projected = projected + self.dist_bias * sq_dist

        # The loss (CrossEntropy per the original note) takes raw logits.
        return self.logits(projected)
Exemplo n.º 12
0
    def __init__(
        self,
        vocab_size,
        embedding_dim,
        hidden_dim,
        project_dim,
        cell_type="eucl_rnn",
        embedding_type="eucl",
        decision_type="eucl",
        use_distance_as_feature=True,
        device=None,
        num_layers=1,
        num_classes=1,
        c=1.0,
    ):
        """Assemble embedding, recurrent cells, projector(s) and logit layer.

        Args:
            vocab_size: Number of embedding rows.
            embedding_dim: Embedding width; also the cells' input size.
            hidden_dim: Hidden size of the recurrent cells.
            project_dim: Output width of the projector(s).
            cell_type: ``"eucl_rnn"``, ``"eucl_gru"`` or ``"hyp_gru"``
                (case-insensitive).
            embedding_type: ``"eucl"`` or ``"hyp"`` (case-insensitive).
            decision_type: ``"eucl"`` or ``"hyp"`` (case-insensitive).
            use_distance_as_feature: When true a ``dist_bias`` parameter is
                created; otherwise ``dist_bias`` is registered as ``None``.
            device: Stored on ``self.device``.
            num_layers: Number of recurrent layers.
            num_classes: Output width of the logit layer.
            c: Curvature used by every hyperbolic component.

        Raises:
            NotImplementedError: For an unknown embedding, decision or cell
                type.
        """
        super(RNNBase, self).__init__()
        (cell_type, embedding_type,
         decision_type) = map(str.lower,
                              [cell_type, embedding_type, decision_type])
        if embedding_type == "eucl":
            self.embedding = hyrnn.LookupEmbedding(vocab_size,
                                                   embedding_dim,
                                                   manifold=geoopt.Euclidean())
            with torch.no_grad():
                self.embedding.weight.normal_()
        elif embedding_type == "hyp":
            self.embedding = hyrnn.LookupEmbedding(
                vocab_size,
                embedding_dim,
                manifold=geoopt.PoincareBall(c=c),
            )
            with torch.no_grad():
                # Normal draw scaled by 1/10, then mapped with expmap0.
                self.embedding.weight.set_(
                    pmath.expmap0(self.embedding.weight.normal_() / 10, c=c))
        else:
            raise NotImplementedError(
                "Unsuported embedding type: {0}".format(embedding_type))
        self.embedding_type = embedding_type
        if decision_type == "eucl":
            self.projector = nn.Linear(hidden_dim * 2, project_dim)
            self.logits = nn.Linear(project_dim, num_classes)
        elif decision_type == "hyp":
            self.projector_source = hyrnn.MobiusLinear(hidden_dim,
                                                       project_dim,
                                                       c=c)
            self.projector_target = hyrnn.MobiusLinear(hidden_dim,
                                                       project_dim,
                                                       c=c)
            # Forward the curvature here as well: the projectors above and
            # the Mobius operations in forward() all use ``c``, so leaving
            # this layer at its default would mix curvatures whenever
            # ``c != 1.0``.
            self.logits = hyrnn.MobiusDist2Hyperplane(project_dim,
                                                      num_classes,
                                                      c=c)
        else:
            raise NotImplementedError(
                "Unsuported decision type: {0}".format(decision_type))
        self.ball = geoopt.PoincareBall(c)
        if use_distance_as_feature:
            if decision_type == "eucl":
                self.dist_bias = nn.Parameter(torch.zeros(project_dim))
            else:
                self.dist_bias = geoopt.ManifoldParameter(
                    torch.zeros(project_dim), manifold=self.ball)
        else:
            self.register_buffer("dist_bias", None)
        self.decision_type = decision_type
        self.use_distance_as_feature = use_distance_as_feature
        # Device is kept here because of how the training framework
        # (catalyst, per the original note) is used.
        self.device = device
        self.num_layers = num_layers
        self.hidden_dim = hidden_dim
        self.c = c

        if cell_type == "eucl_rnn":
            self.cell = nn.RNN
        elif cell_type == "eucl_gru":
            self.cell = nn.GRU
        elif cell_type == "hyp_gru":
            self.cell = functools.partial(hyrnn.MobiusGRU, c=c)
        else:
            raise NotImplementedError(
                "Unsuported cell type: {0}".format(cell_type))
        self.cell_type = cell_type

        # Separate cells (no weight sharing) for the source and target sides.
        self.cell_source = self.cell(embedding_dim, self.hidden_dim,
                                     self.num_layers)
        self.cell_target = self.cell(embedding_dim, self.hidden_dim,
                                     self.num_layers)
Exemplo n.º 13
0
    def forward(self, inputs, timestamps, hidden_states, reverse=False):
        """Run the time-aware recurrent cell over ``inputs``.

        Args:
            inputs: 3-D tensor; the second dimension is iterated as the
                sequence axis (``inputs[:, s]`` is one step).
            timestamps: Indexed as ``timestamps[:, s:s + 1]`` and expanded to
                the shape of the short-term memory at each step.
            hidden_states: Indexable; ``[0]`` is the initial hidden state and
                ``[1]`` the initial cell state.
            reverse: When true, the per-step outputs are reversed before
                stacking. The returned ``(h, _c)`` are still those of the
                last processed step.

        Returns:
            ``(outputs, (h, _c))`` where ``outputs`` stacks the per-step
            output-gate activations along dimension 1.
        """
        # The unpacking also enforces that ``inputs`` is 3-D.
        _, seq, _ = inputs.size()
        h = hidden_states[0]
        _c = hidden_states[1]
        if self.cuda_flag:
            h = h.cuda()
            _c = _c.cuda()

        # The gate weight slices do not change between steps; split once.
        W_f, W_i, W_o, W_c_tmp = self.W_all.chunk(4, dim=1)
        U_f, U_i, U_o, U_c_tmp = self.U_all.chunk(4, dim=0)

        def gate(W, U, h_prev, x_t):
            # logmap0 of the recurrent transform, squashed with a sigmoid.
            pre = one_rnn_transform(W, h_prev, U, x_t, self.c)
            return pmath_geo.logmap0(pre, c=self.c).sigmoid()

        outputs = []
        for s in range(seq):
            x_s = inputs[:, s]

            # Short-term memory.
            # NOTE(review): this expmap0 is called without c=self.c, unlike
            # the other calls in this method -- confirm that is intended.
            c_s1 = pmath_geo.expmap0(
                torch.tanh(
                    pmath_geo.logmap0(
                        pmath_geo.mobius_matvec(self.W_d, _c, c=self.c),
                        c=self.c)))
            # Short-term memory discounted by the step's timestamp value.
            c_s2 = pmath_geo.mobius_pointwise_mul(
                c_s1, timestamps[:, s:s + 1].expand_as(c_s1), c=self.c)
            # Long-term memory: cell state with the short-term part removed.
            c_l = pmath_geo.mobius_add(-c_s1, _c, c=self.c)
            c_adj = pmath_geo.mobius_add(c_l, c_s2, c=self.c)

            f = gate(W_f, U_f, h, x_s)
            i = gate(W_i, U_i, h, x_s)
            o = gate(W_o, U_o, h, x_s)
            c_tmp = gate(W_c_tmp, U_c_tmp, h, x_s)

            f_dot_c_adj = pmath_geo.mobius_pointwise_mul(f, c_adj, c=self.c)
            i_dot_c_tmp = pmath_geo.mobius_pointwise_mul(i, c_tmp, c=self.c)
            _c = pmath_geo.mobius_add(i_dot_c_tmp, f_dot_c_adj, c=self.c)

            h = pmath_geo.mobius_pointwise_mul(
                o, pmath_geo.expmap0(torch.tanh(_c), c=self.c), c=self.c)
            # The collected per-step output is the output gate ``o``.
            outputs.append(o)

        if reverse:
            outputs.reverse()
        outputs = torch.stack(outputs, 1)

        return outputs, (h, _c)
Exemplo n.º 14
0
def exp_map_mu0_c(x: Tensor, c: Tensor) -> Tensor:
    """Return ``pm.expmap0`` of ``x`` using curvature argument ``c``."""
    mapped = pm.expmap0(x, c=c)
    return mapped