def mobius_gru_cell(
    input: torch.Tensor,
    hx: torch.Tensor,
    weight_ih: torch.Tensor,
    weight_hh: torch.Tensor,
    bias: torch.Tensor,
    c: torch.Tensor,
    nonlin=None,
):
    # Split the stacked parameters into reset (r), candidate (h) and update (z) parts.
    W_ir, W_ih, W_iz = weight_ih.chunk(3)
    b_r, b_h, b_z = bias
    W_hr, W_hh, W_hz = weight_hh.chunk(3)
    # Gates are computed in the tangent space at the origin and squashed with sigmoids.
    z_t = pmath.logmap0(one_rnn_transform(W_hz, hx, W_iz, input, b_z, c), c=c).sigmoid()
    r_t = pmath.logmap0(one_rnn_transform(W_hr, hx, W_ir, input, b_r, c), c=c).sigmoid()

    rh_t = pmath.mobius_pointwise_mul(r_t, hx, c=c)
    h_tilde = one_rnn_transform(W_hh, rh_t, W_ih, input, b_h, c)

    if nonlin is not None:
        h_tilde = pmath.mobius_fn_apply(nonlin, h_tilde, c=c)
    # Möbius analogue of the Euclidean GRU update: h_out = hx + z_t * (h_tilde - hx).
    delta_h = pmath.mobius_add(-hx, h_tilde, c=c)
    h_out = pmath.mobius_add(hx, pmath.mobius_pointwise_mul(z_t, delta_h, c=c), c=c)
    return h_out
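A minimal calling sketch for the cell above, assuming `pmath` is geoopt's Poincaré-ball math module and `one_rnn_transform(W, h, U, x, b, c)` is the companion helper from the same codebase that computes (W ⊗ h) ⊕ (U ⊗ x) ⊕ b; the shapes and constants are illustrative only:

import torch

batch, feat, hidden = 4, 8, 16
c = torch.tensor(1.0)
x = pmath.expmap0(torch.randn(batch, feat) * 1e-2, c=c)   # inputs on the Poincaré ball
h0 = torch.zeros(batch, hidden)                           # the origin is a valid hidden state
weight_ih = torch.randn(3 * hidden, feat) * 1e-2
weight_hh = torch.randn(3 * hidden, hidden) * 1e-2
bias = torch.zeros(3, hidden)                             # one bias row per gate (r, h, z)

h1 = mobius_gru_cell(x, h0, weight_ih, weight_hh, bias, c)
print(h1.shape)  # torch.Size([4, 16])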
Example #2
    def forward(self, input):
        x, adj = input
        # Map the hyperbolic features to the tangent space at the origin,
        # run the Euclidean graph convolution there, then map back onto the ball.
        h = pmath.logmap0(x)
        h, _ = self.conv((h, adj))
        h = F.dropout(h, p=self.p, training=self.training)
        h = pmath.project(pmath.expmap0(h))
        h = F.relu(h)
        output = h, adj
        return output
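The forward pass above illustrates a pattern that recurs throughout these examples: map hyperbolic features to the tangent space at the origin with logmap0, apply an ordinary Euclidean layer, then map back with expmap0 and re-project onto the ball. A self-contained sketch of the same trick, assuming an older geoopt release that still exposes `geoopt.manifolds.poincare.math` (newer releases moved these functions to `geoopt.manifolds.stereographic.math` with a curvature argument `k`):

import torch
import torch.nn.functional as F
import geoopt.manifolds.poincare.math as pmath

c = 1.0
lin = torch.nn.Linear(8, 8)

x = pmath.expmap0(torch.randn(4, 8) * 1e-2, c=c)   # points on the Poincaré ball
h = pmath.logmap0(x, c=c)                          # tangent space at the origin
h = F.relu(lin(h))                                 # ordinary Euclidean layer
y = pmath.project(pmath.expmap0(h, c=c), c=c)      # back onto the ball, numerically clamped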
Example #3
    def forward(self, sentence_feats, len_tweets, time_feats):
        """
        sentence_feat: sentence features (B*5*30*N),
        len_tweets: (B*5)
        time_feats: (B*5*30)
        """
        sentence_feats = pmath_geo.expmap0(sentence_feats, c=self.c)
        # time_feats = pmath_geo.expmap0(time_feats, c=self.c)
        sentence_feats = sentence_feats.permute(1, 0, 2, 3)
        len_days, self.bs, _, _ = sentence_feats.size()
        h_init, c_init = self.init_hidden()

        len_tweets = len_tweets.permute(1, 0)
        time_feats = time_feats.permute(1, 0, 2)

        lstm1_out = torch.zeros(len_days, self.bs,
                                self.lstm1_outshape).to(self.device)

        for i in range(len_days):
            temp_lstmout, (_, _) = self.lstm1(sentence_feats[i], time_feats[i],
                                              (h_init, c_init))
            # Index of the last valid tweet for each sample in the batch.
            last_idx = len_tweets[i]
            last_idx = last_idx.type(torch.int).tolist()
            temp_hn = torch.zeros(self.bs, 1, self.lstm1_outshape)
            for j in range(self.bs):
                # Keep the hidden state at the last valid timestep.
                temp_hn[j] = temp_lstmout[j, last_idx[j] - 1, :]
            lstm1_out[i] = temp_hn.squeeze(1)
        lstm1_out = lstm1_out.permute(1, 0, 2)
        batch_size = lstm1_out.shape[0]
        num_of_timesteps = lstm1_out.shape[1]
        # Hyperbolic exp
        all_outputs, cell_output = self.cell_source(lstm1_out.permute(1, 0, 2))

        cell_output = cell_output[-1]
        x = pmath_geo.logmap0(cell_output, c=self.c)
        x = self.drop(self.relu(self.linear3(x)))
        cse_output = self.linear4(x)
        margin_output = self.linear5(x)
        return cse_output, margin_output
Example #4
    def forward(self, input):
        x, adj = input
        x = self.hy_linear.forward(x)

        edge_index = adj._indices()
        edge_index, _ = remove_self_loops(edge_index)
        edge_index = add_self_loops(edge_index, num_nodes=x.size(0))

        log_x = pmath.logmap0(x, c=1.0)  # Get log(x) as input to GCN
        log_x = log_x.view(-1, self.heads, self.out_channels)
        out = self.propagate(edge_index,
                             x=log_x,
                             num_nodes=x.size(0),
                             original_x=x)
        out = self.manifold.proj_tan0(out, c=self.c)

        out = self.act(out)
        out = self.manifold.proj_tan0(out, c=self.c)

        return self.manifold.proj(self.manifold.expmap0(out, c=self.c),
                                  c=self.c)
Example #5
def train(args):
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if int(args.double_precision):
        torch.set_default_dtype(torch.float64)
    if int(args.cuda) >= 0:
        torch.cuda.manual_seed(args.seed)
    args.device = 'cuda:' + str(args.cuda) if int(args.cuda) >= 0 else 'cpu'
    args.patience = args.epochs if not args.patience else int(args.patience)
    logging.getLogger().setLevel(logging.INFO)
    if args.save:
        if not args.save_dir:
            dt = datetime.datetime.now()
            date = f"{dt.year}_{dt.month}_{dt.day}"
            models_dir = os.path.join(os.environ['LOG_DIR'], args.task, date)
            save_dir = get_dir_name(models_dir)
        else:
            save_dir = args.save_dir
        logging.basicConfig(level=logging.INFO,
                            handlers=[
                                logging.FileHandler(
                                    os.path.join(save_dir, 'log.txt')),
                                logging.StreamHandler()
                            ])

    logging.info(f'Using: {args.device}')
    logging.info("Using seed {}.".format(args.seed))

    # Load data
    data = load_data(args, os.path.join(os.environ['DATAPATH'], args.dataset))
    args.n_nodes, args.feat_dim = data['features'].shape

    if args.task == 'nc':
        Model = NCModel
        args.n_classes = int(data['labels'].max() + 1)
        logging.info(f'Num classes: {args.n_classes}')
    else:
        args.nb_false_edges = len(data['train_edges_false'])
        args.nb_edges = len(data['train_edges'])
        if args.task == 'lp':
            Model = LPModel

    if not args.lr_reduce_freq:
        args.lr_reduce_freq = args.epochs

    # Model and optimizer
    model = Model(args)
    logging.info(str(model))
    optimizer = getattr(optimizers,
                        args.optimizer)(params=model.parameters(),
                                        lr=args.lr,
                                        weight_decay=args.weight_decay)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                   step_size=int(
                                                       args.lr_reduce_freq),
                                                   gamma=float(args.gamma))
    tot_params = sum([np.prod(p.size()) for p in model.parameters()])
    logging.info(f"Total number of parameters: {tot_params}")
    if args.cuda is not None and int(args.cuda) >= 0:
        os.environ['CUDA_VISIBLE_DEVICES'] = str(args.cuda)
        model = model.to(args.device)
        for x, val in data.items():
            if torch.is_tensor(data[x]):
                data[x] = data[x].to(args.device)

    # Train model
    t_total = time.time()
    counter = 0
    best_val_metrics = model.init_metric_dict()
    best_test_metrics = None
    best_emb = None
    for epoch in range(args.epochs):
        t = time.time()
        model.train()
        optimizer.zero_grad()
        embeddings = model.encode(data['features'], data['adj_train_norm'])
        train_metrics = model.compute_metrics(embeddings, data, 'train', args)
        train_metrics['loss'].backward()

        if args.grad_clip is not None:
            max_norm = float(args.grad_clip)
            all_params = list(model.parameters())
            for param in all_params:
                torch.nn.utils.clip_grad_norm_(param, max_norm)
        optimizer.step()
        lr_scheduler.step()
        if (epoch + 1) % args.log_freq == 0:
            logging.info(" ".join([
                'Epoch: {:04d}'.format(epoch + 1),
                'lr: {}'.format(lr_scheduler.get_lr()[0]),
                format_metrics(train_metrics, 'train'),
                'time: {:.4f}s'.format(time.time() - t)
            ]))
        if (epoch + 1) % args.eval_freq == 0:
            model.eval()
            embeddings = model.encode(data['features'], data['adj_train_norm'])
            val_metrics = model.compute_metrics(embeddings, data, 'val', args)
            if (epoch + 1) % args.log_freq == 0:
                logging.info(" ".join([
                    'Epoch: {:04d}'.format(epoch + 1),
                    format_metrics(val_metrics, 'val')
                ]))
            if model.has_improved(best_val_metrics, val_metrics):
                best_test_metrics = model.compute_metrics(
                    embeddings, data, 'test', args)
                if isinstance(embeddings, tuple):
                    best_emb = torch.cat(
                        (pmath.logmap0(embeddings[0], c=1.0), embeddings[1]),
                        dim=1).cpu()
                else:
                    best_emb = embeddings.cpu()
                if args.save:
                    np.save(os.path.join(save_dir, 'embeddings.npy'),
                            best_emb.detach().numpy())

                best_val_metrics = val_metrics
                counter = 0
            else:
                counter += 1
                if counter == args.patience and epoch > args.min_epochs:
                    logging.info("Early stopping")
                    break

    logging.info("Optimization Finished!")
    logging.info("Total time elapsed: {:.4f}s".format(time.time() - t_total))
    if not best_test_metrics:
        model.eval()
        best_emb = model.encode(data['features'], data['adj_train_norm'])
        best_test_metrics = model.compute_metrics(best_emb, data, 'test', args)
    logging.info(" ".join(
        ["Val set results:",
         format_metrics(best_val_metrics, 'val')]))
    logging.info(" ".join(
        ["Test set results:",
         format_metrics(best_test_metrics, 'test')]))

    if args.save:
        if isinstance(best_emb, tuple):
            best_emb = torch.cat(
                (pmath.logmap0(best_emb[0], c=1.0), best_emb[1]), dim=1).cpu()
        else:
            best_emb = best_emb.cpu()
        np.save(os.path.join(save_dir, 'embeddings.npy'),
                best_emb.detach().numpy())
        if hasattr(model.encoder, 'att_adj'):
            filename = os.path.join(save_dir, args.dataset + '_att_adj.p')
            pickle.dump(model.encoder.att_adj.cpu().to_dense(),
                        open(filename, 'wb'))
            print('Dumped attention adj: ' + filename)

        json.dump(vars(args), open(os.path.join(save_dir, 'config.json'), 'w'))
        torch.save(model.state_dict(), os.path.join(save_dir, 'model.pth'))
        logging.info(f"Saved model in {save_dir}")
Example #6
    def forward(self, input):
        source_input = input[0]
        target_input = input[1]
        alignment = input[2]
        batch_size = alignment.shape[0]

        source_input_data = self.embedding(source_input.data)
        target_input_data = self.embedding(target_input.data)

        zero_hidden = torch.zeros(self.num_layers,
                                  batch_size,
                                  self.hidden_dim,
                                  device=self.device or source_input.device,
                                  dtype=source_input_data.dtype)

        if self.embedding_type == "eucl" and "hyp" in self.cell_type:
            source_input_data = pmath.expmap0(source_input_data, c=self.c)
            target_input_data = pmath.expmap0(target_input_data, c=self.c)
        elif self.embedding_type == "hyp" and "eucl" in self.cell_type:
            source_input_data = pmath.logmap0(source_input_data, c=self.c)
            target_input_data = pmath.logmap0(target_input_data, c=self.c)
        # ht: (num_layers * num_directions, batch, hidden_size)

        source_input = torch.nn.utils.rnn.PackedSequence(
            source_input_data, source_input.batch_sizes)
        target_input = torch.nn.utils.rnn.PackedSequence(
            target_input_data, target_input.batch_sizes)

        _, source_hidden = self.cell_source(source_input, zero_hidden)
        _, target_hidden = self.cell_target(target_input, zero_hidden)

        # take hiddens from the last layer
        source_hidden = source_hidden[-1]
        target_hidden = target_hidden[-1][alignment]

        if self.decision_type == "hyp":
            if "eucl" in self.cell_type:
                source_hidden = pmath.expmap0(source_hidden, c=self.c)
                target_hidden = pmath.expmap0(target_hidden, c=self.c)
            source_projected = self.projector_source(source_hidden)
            target_projected = self.projector_target(target_hidden)
            projected = pmath.mobius_add(source_projected,
                                         target_projected,
                                         c=self.ball.c)
            if self.use_distance_as_feature:
                dist = (pmath.dist(source_hidden,
                                   target_hidden,
                                   dim=-1,
                                   keepdim=True,
                                   c=self.ball.c)**2)
                bias = pmath.mobius_scalar_mul(dist,
                                               self.dist_bias,
                                               c=self.ball.c)
                projected = pmath.mobius_add(projected, bias, c=self.ball.c)
        else:
            if "hyp" in self.cell_type:
                source_hidden = pmath.logmap0(source_hidden, c=self.c)
                target_hidden = pmath.logmap0(target_hidden, c=self.c)
            projected = self.projector(
                torch.cat((source_hidden, target_hidden), dim=-1))
            if self.use_distance_as_feature:
                dist = torch.sum((source_hidden - target_hidden).pow(2),
                                 dim=-1,
                                 keepdim=True)
                bias = self.dist_bias * dist
                projected = projected + bias

        logits = self.logits(projected)
        # CrossEntropy accepts logits
        return logits
Example #7
    def forward(self, x_h, x_e):
        # Squared distance between the hyperbolic features (taken in the tangent
        # space at the origin) and the Euclidean features, scaled by self.att,
        # gates how much of the hyperbolic signal is added to the Euclidean one.
        dist = (pmath.logmap0(x_h, c=self.c) - x_e).pow(2).sum(dim=-1) * self.att
        x_h = dist.view([-1, 1]) * pmath.logmap0(x_h, c=self.c)
        x_h = F.dropout(x_h, p=self.drop, training=self.training)
        x_e = x_e + x_h
        return x_e
Example #8
    def forward(self, inputs, timestamps, hidden_states, reverse=False):
        b, seq, embed = inputs.size()
        h = hidden_states[0]
        _c = hidden_states[1]
        if self.cuda_flag:
            h = h.cuda()
            _c = _c.cuda()
        outputs = []
        hidden_state_h = []
        hidden_state_c = []

        for s in range(seq):
            c_s1 = pmath_geo.expmap0(
                torch.tanh(
                    pmath_geo.logmap0(pmath_geo.mobius_matvec(self.W_d,
                                                              _c,
                                                              c=self.c),
                                      c=self.c)),
                c=self.c)  # short term mem
            c_s2 = pmath_geo.mobius_pointwise_mul(
                c_s1, timestamps[:, s:s + 1].expand_as(c_s1),
                c=self.c)  # discounted short term mem
            c_l = pmath_geo.mobius_add(-c_s1, _c, c=self.c)  # long term mem
            c_adj = pmath_geo.mobius_add(c_l, c_s2, c=self.c)

            W_f, W_i, W_o, W_c_tmp = self.W_all.chunk(4, dim=1)
            U_f, U_i, U_o, U_c_tmp = self.U_all.chunk(4, dim=0)
            # Forget, input, output gates and the candidate are computed in the
            # tangent space at the origin and squashed with sigmoids.
            f = pmath_geo.logmap0(one_rnn_transform(W_f, h, U_f, inputs[:, s],
                                                    self.c),
                                  c=self.c).sigmoid()
            i = pmath_geo.logmap0(one_rnn_transform(W_i, h, U_i, inputs[:, s],
                                                    self.c),
                                  c=self.c).sigmoid()
            o = pmath_geo.logmap0(one_rnn_transform(W_o, h, U_o, inputs[:, s],
                                                    self.c),
                                  c=self.c).sigmoid()
            c_tmp = pmath_geo.logmap0(one_rnn_transform(
                W_c_tmp, h, U_c_tmp, inputs[:, s], self.c),
                                      c=self.c).sigmoid()

            f_dot_c_adj = pmath_geo.mobius_pointwise_mul(f, c_adj, c=self.c)
            i_dot_c_tmp = pmath_geo.mobius_pointwise_mul(i, c_tmp, c=self.c)
            _c = pmath_geo.mobius_add(i_dot_c_tmp, f_dot_c_adj, c=self.c)

            h = pmath_geo.mobius_pointwise_mul(o,
                                               pmath_geo.expmap0(
                                                   torch.tanh(_c), c=self.c),
                                               c=self.c)
            outputs.append(o)
            hidden_state_c.append(_c)
            hidden_state_h.append(h)

        if reverse:
            outputs.reverse()
            hidden_state_c.reverse()
            hidden_state_h.reverse()
        outputs = torch.stack(outputs, 1)
        hidden_state_c = torch.stack(hidden_state_c, 1)
        hidden_state_h = torch.stack(hidden_state_h, 1)

        return outputs, (h, _c)
Example #9
def inverse_exp_map_mu0(x: Tensor, radius: Tensor) -> Tensor:
    return pm.logmap0(x, c=_c(radius))
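A quick round-trip sanity check for the helper above, assuming `pm` is geoopt's Poincaré-ball math module and that `_c(radius)` converts a radius into the ball curvature (for example 1 / radius**2); both assumptions should be verified against the surrounding code:

import torch

radius = torch.tensor(1.0)
v = torch.randn(4, 8) * 1e-2                 # tangent vectors at the origin (mu_0)
x = pm.expmap0(v, c=_c(radius))              # exponential map onto the ball
v_rec = inverse_exp_map_mu0(x, radius)       # should recover v up to numerical error
print(torch.allclose(v, v_rec, atol=1e-5))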