def mobius_gru_cell(
    input: torch.Tensor,
    hx: torch.Tensor,
    weight_ih: torch.Tensor,
    weight_hh: torch.Tensor,
    bias: torch.Tensor,
    c: torch.Tensor,
    nonlin=None,
):
    W_ir, W_ih, W_iz = weight_ih.chunk(3)
    b_r, b_h, b_z = bias
    W_hr, W_hh, W_hz = weight_hh.chunk(3)

    # Update and reset gates: Mobius affine transform, then log-map to the
    # tangent space at the origin before applying the sigmoid.
    z_t = pmath.logmap0(one_rnn_transform(W_hz, hx, W_iz, input, b_z, c), c=c).sigmoid()
    r_t = pmath.logmap0(one_rnn_transform(W_hr, hx, W_ir, input, b_r, c), c=c).sigmoid()

    # Candidate hidden state from the reset-gated previous state.
    rh_t = pmath.mobius_pointwise_mul(r_t, hx, c=c)
    h_tilde = one_rnn_transform(W_hh, rh_t, W_ih, input, b_h, c)
    if nonlin is not None:
        h_tilde = pmath.mobius_fn_apply(nonlin, h_tilde, c=c)

    # Hyperbolic interpolation between the previous state and the candidate.
    delta_h = pmath.mobius_add(-hx, h_tilde, c=c)
    h_out = pmath.mobius_add(hx, pmath.mobius_pointwise_mul(z_t, delta_h, c=c), c=c)
    return h_out
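# The gate computations above call one_rnn_transform, which is defined elsewhere in
# the codebase. A minimal sketch of what it is assumed to compute -- the Mobius
# analogue of W @ h + U @ x + b on the Poincare ball (names and exact signature are
# assumptions, not the original implementation):
def one_rnn_transform(W, h, U, x, b, c):
    W_otimes_h = pmath.mobius_matvec(W, h, c=c)
    U_otimes_x = pmath.mobius_matvec(U, x, c=c)
    Wh_plus_Ux = pmath.mobius_add(W_otimes_h, U_otimes_x, c=c)
    return pmath.mobius_add(Wh_plus_Ux, b, c=c)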
def forward(self, input):
    x, adj = input
    h = pmath.logmap0(x)
    h, _ = self.conv((h, adj))
    h = F.dropout(h, p=self.p, training=self.training)
    h = pmath.project(pmath.expmap0(h))
    h = F.relu(h)
    output = h, adj
    return output
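# The layer above uses the common tangent-space pattern: map points off the ball with
# logmap0, apply a Euclidean operation, then map back with expmap0 and project onto
# the ball. A minimal standalone sketch of that pattern (euclidean_op is a
# hypothetical callable, e.g. an nn.Linear; pmath is assumed to be geoopt's
# Poincare-ball math module):
def apply_in_tangent_space(x, euclidean_op, c=1.0):
    u = pmath.logmap0(x, c=c)                          # ball -> tangent space at the origin
    u = euclidean_op(u)                                # any Euclidean map
    return pmath.project(pmath.expmap0(u, c=c), c=c)   # tangent space -> ball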
def forward(self, sentence_feats, len_tweets, time_feats):
    """
    sentence_feats: sentence features (B x 5 x 30 x N)
    len_tweets: number of tweets per day (B x 5)
    time_feats: time features (B x 5 x 30)
    """
    # Lift the Euclidean sentence features onto the Poincare ball.
    sentence_feats = pmath_geo.expmap0(sentence_feats, c=self.c)
    # time_feats = pmath_geo.expmap0(time_feats, c=self.c)

    sentence_feats = sentence_feats.permute(1, 0, 2, 3)
    len_days, self.bs, _, _ = sentence_feats.size()
    h_init, c_init = self.init_hidden()
    len_tweets = len_tweets.permute(1, 0)
    time_feats = time_feats.permute(1, 0, 2)

    # Run the tweet-level LSTM one day at a time and keep the hidden state at
    # the last valid tweet of each day.
    lstm1_out = torch.zeros(len_days, self.bs, self.lstm1_outshape).to(self.device)
    for i in range(len_days):
        temp_lstmout, (_, _) = self.lstm1(sentence_feats[i], time_feats[i], (h_init, c_init))
        last_idx = len_tweets[i]
        last_idx = last_idx.type(torch.int).tolist()
        temp_hn = torch.zeros(self.bs, 1, self.lstm1_outshape)
        for j in range(self.bs):
            temp_hn[j] = temp_lstmout[j, last_idx[j] - 1, :]
        lstm1_out[i] = temp_hn.squeeze(1)

    lstm1_out = lstm1_out.permute(1, 0, 2)
    batch_size = lstm1_out.shape[0]
    num_of_timesteps = lstm1_out.shape[1]

    # Hyperbolic encoder over the per-day summaries.
    all_outputs, cell_output = self.cell_source(lstm1_out.permute(1, 0, 2))
    cell_output = cell_output[-1]

    # Map back to the tangent space for the Euclidean classification heads.
    x = pmath_geo.logmap0(cell_output, c=self.c)
    x = self.drop(self.relu(self.linear3(x)))
    cse_output = self.linear4(x)
    margin_output = self.linear5(x)
    return cse_output, margin_output
def forward(self, input):
    x, adj = input
    x = self.hy_linear.forward(x)

    edge_index = adj._indices()
    edge_index, _ = remove_self_loops(edge_index)
    # Note: in recent PyTorch Geometric versions add_self_loops returns
    # (edge_index, edge_attr), so only the indices are kept here.
    edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))

    # Aggregate in the tangent space at the origin, then map back to the manifold.
    log_x = pmath.logmap0(x, c=1.0)  # get log(x) as input to the GCN aggregation
    log_x = log_x.view(-1, self.heads, self.out_channels)
    out = self.propagate(edge_index, x=log_x, num_nodes=x.size(0), original_x=x)

    out = self.manifold.proj_tan0(out, c=self.c)
    out = self.act(out)
    out = self.manifold.proj_tan0(out, c=self.c)
    return self.manifold.proj(self.manifold.expmap0(out, c=self.c), c=self.c)
def train(args):
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if int(args.double_precision):
        torch.set_default_dtype(torch.float64)
    if int(args.cuda) >= 0:
        torch.cuda.manual_seed(args.seed)
    args.device = 'cuda:' + str(args.cuda) if int(args.cuda) >= 0 else 'cpu'
    args.patience = args.epochs if not args.patience else int(args.patience)
    logging.getLogger().setLevel(logging.INFO)
    if args.save:
        if not args.save_dir:
            dt = datetime.datetime.now()
            date = f"{dt.year}_{dt.month}_{dt.day}"
            models_dir = os.path.join(os.environ['LOG_DIR'], args.task, date)
            save_dir = get_dir_name(models_dir)
        else:
            save_dir = args.save_dir
        logging.basicConfig(level=logging.INFO,
                            handlers=[
                                logging.FileHandler(os.path.join(save_dir, 'log.txt')),
                                logging.StreamHandler()
                            ])

    logging.info(f'Using: {args.device}')
    logging.info("Using seed {}.".format(args.seed))

    # Load data
    data = load_data(args, os.path.join(os.environ['DATAPATH'], args.dataset))
    args.n_nodes, args.feat_dim = data['features'].shape
    if args.task == 'nc':
        Model = NCModel
        args.n_classes = int(data['labels'].max() + 1)
        logging.info(f'Num classes: {args.n_classes}')
    else:
        args.nb_false_edges = len(data['train_edges_false'])
        args.nb_edges = len(data['train_edges'])
        if args.task == 'lp':
            Model = LPModel

    if not args.lr_reduce_freq:
        args.lr_reduce_freq = args.epochs

    # Model and optimizer
    model = Model(args)
    logging.info(str(model))
    optimizer = getattr(optimizers, args.optimizer)(params=model.parameters(),
                                                    lr=args.lr,
                                                    weight_decay=args.weight_decay)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                   step_size=int(args.lr_reduce_freq),
                                                   gamma=float(args.gamma))
    tot_params = sum([np.prod(p.size()) for p in model.parameters()])
    logging.info(f"Total number of parameters: {tot_params}")
    if args.cuda is not None and int(args.cuda) >= 0:
        os.environ['CUDA_VISIBLE_DEVICES'] = str(args.cuda)
        model = model.to(args.device)
        for x, val in data.items():
            if torch.is_tensor(data[x]):
                data[x] = data[x].to(args.device)

    # Train model
    t_total = time.time()
    counter = 0
    best_val_metrics = model.init_metric_dict()
    best_test_metrics = None
    best_emb = None
    for epoch in range(args.epochs):
        t = time.time()
        model.train()
        optimizer.zero_grad()
        embeddings = model.encode(data['features'], data['adj_train_norm'])
        train_metrics = model.compute_metrics(embeddings, data, 'train', args)
        train_metrics['loss'].backward()
        if args.grad_clip is not None:
            max_norm = float(args.grad_clip)
            all_params = list(model.parameters())
            for param in all_params:
                torch.nn.utils.clip_grad_norm_(param, max_norm)
        optimizer.step()
        lr_scheduler.step()
        if (epoch + 1) % args.log_freq == 0:
            logging.info(" ".join([
                'Epoch: {:04d}'.format(epoch + 1),
                'lr: {}'.format(lr_scheduler.get_lr()[0]),
                format_metrics(train_metrics, 'train'),
                'time: {:.4f}s'.format(time.time() - t)
            ]))
        if (epoch + 1) % args.eval_freq == 0:
            model.eval()
            embeddings = model.encode(data['features'], data['adj_train_norm'])
            val_metrics = model.compute_metrics(embeddings, data, 'val', args)
            if (epoch + 1) % args.log_freq == 0:
                logging.info(" ".join([
                    'Epoch: {:04d}'.format(epoch + 1),
                    format_metrics(val_metrics, 'val')
                ]))
            if model.has_improved(best_val_metrics, val_metrics):
                best_test_metrics = model.compute_metrics(embeddings, data, 'test', args)
                if isinstance(embeddings, tuple):
                    best_emb = torch.cat(
                        (pmath.logmap0(embeddings[0], c=1.0), embeddings[1]), dim=1).cpu()
                else:
                    best_emb = embeddings.cpu()
                if args.save:
                    np.save(os.path.join(save_dir, 'embeddings.npy'), best_emb.detach().numpy())
                best_val_metrics = val_metrics
                counter = 0
            else:
                counter += 1
                if counter == args.patience and epoch > args.min_epochs:
                    logging.info("Early stopping")
                    break

    logging.info("Optimization Finished!")
    logging.info("Total time elapsed: {:.4f}s".format(time.time() - t_total))
    if not best_test_metrics:
        model.eval()
        best_emb = model.encode(data['features'], data['adj_train_norm'])
        best_test_metrics = model.compute_metrics(best_emb, data, 'test', args)
    logging.info(" ".join(["Val set results:", format_metrics(best_val_metrics, 'val')]))
    logging.info(" ".join(["Test set results:", format_metrics(best_test_metrics, 'test')]))
    if args.save:
        if isinstance(best_emb, tuple):
            best_emb = torch.cat(
                (pmath.logmap0(best_emb[0], c=1.0), best_emb[1]), dim=1).cpu()
        else:
            best_emb = best_emb.cpu()
        np.save(os.path.join(save_dir, 'embeddings.npy'), best_emb.detach().numpy())
        if hasattr(model.encoder, 'att_adj'):
            filename = os.path.join(save_dir, args.dataset + '_att_adj.p')
            pickle.dump(model.encoder.att_adj.cpu().to_dense(), open(filename, 'wb'))
            print('Dumped attention adj: ' + filename)
        json.dump(vars(args), open(os.path.join(save_dir, 'config.json'), 'w'))
        torch.save(model.state_dict(), os.path.join(save_dir, 'model.pth'))
        logging.info(f"Saved model in {save_dir}")
def forward(self, input):
    source_input = input[0]
    target_input = input[1]
    alignment = input[2]
    batch_size = alignment.shape[0]

    source_input_data = self.embedding(source_input.data)
    target_input_data = self.embedding(target_input.data)

    zero_hidden = torch.zeros(self.num_layers,
                              batch_size,
                              self.hidden_dim,
                              device=self.device or source_input.device,
                              dtype=source_input_data.dtype)

    if self.embedding_type == "eucl" and "hyp" in self.cell_type:
        source_input_data = pmath.expmap0(source_input_data, c=self.c)
        target_input_data = pmath.expmap0(target_input_data, c=self.c)
    elif self.embedding_type == "hyp" and "eucl" in self.cell_type:
        source_input_data = pmath.logmap0(source_input_data, c=self.c)
        target_input_data = pmath.logmap0(target_input_data, c=self.c)

    # ht: (num_layers * num_directions, batch, hidden_size)
    source_input = torch.nn.utils.rnn.PackedSequence(source_input_data, source_input.batch_sizes)
    target_input = torch.nn.utils.rnn.PackedSequence(target_input_data, target_input.batch_sizes)
    _, source_hidden = self.cell_source(source_input, zero_hidden)
    _, target_hidden = self.cell_target(target_input, zero_hidden)

    # take hiddens from the last layer
    source_hidden = source_hidden[-1]
    target_hidden = target_hidden[-1][alignment]

    if self.decision_type == "hyp":
        if "eucl" in self.cell_type:
            source_hidden = pmath.expmap0(source_hidden, c=self.c)
            target_hidden = pmath.expmap0(target_hidden, c=self.c)
        source_projected = self.projector_source(source_hidden)
        target_projected = self.projector_target(target_hidden)
        projected = pmath.mobius_add(source_projected, target_projected, c=self.ball.c)
        if self.use_distance_as_feature:
            dist = pmath.dist(source_hidden, target_hidden, dim=-1, keepdim=True, c=self.ball.c) ** 2
            bias = pmath.mobius_scalar_mul(dist, self.dist_bias, c=self.ball.c)
            projected = pmath.mobius_add(projected, bias, c=self.ball.c)
    else:
        if "hyp" in self.cell_type:
            source_hidden = pmath.logmap0(source_hidden, c=self.c)
            target_hidden = pmath.logmap0(target_hidden, c=self.c)
        projected = self.projector(torch.cat((source_hidden, target_hidden), dim=-1))
        if self.use_distance_as_feature:
            dist = torch.sum((source_hidden - target_hidden).pow(2), dim=-1, keepdim=True)
            bias = self.dist_bias * dist
            projected = projected + bias

    logits = self.logits(projected)
    # CrossEntropy accepts logits
    return logits
def forward(self, x_h, x_e):
    dist = (pmath.logmap0(x_h, c=self.c) - x_e).pow(2).sum(dim=-1) * self.att
    x_h = dist.view([-1, 1]) * pmath.logmap0(x_h, c=self.c)
    x_h = F.dropout(x_h, p=self.drop, training=self.training)
    x_e = x_e + x_h
    return x_e
def forward(self, inputs, timestamps, hidden_states, reverse=False):
    b, seq, embed = inputs.size()
    h = hidden_states[0]
    _c = hidden_states[1]
    if self.cuda_flag:
        h = h.cuda()
        _c = _c.cuda()

    outputs = []
    hidden_state_h = []
    hidden_state_c = []
    for s in range(seq):
        # Time-aware decomposition of the cell state: discount the short-term
        # part by the elapsed time, keep the long-term part, then recombine.
        c_s1 = pmath_geo.expmap0(
            torch.tanh(pmath_geo.logmap0(pmath_geo.mobius_matvec(self.W_d, _c, c=self.c), c=self.c)),
            c=self.c)  # short-term memory
        c_s2 = pmath_geo.mobius_pointwise_mul(
            c_s1, timestamps[:, s:s + 1].expand_as(c_s1), c=self.c)  # discounted short-term memory
        c_l = pmath_geo.mobius_add(-c_s1, _c, c=self.c)  # long-term memory
        c_adj = pmath_geo.mobius_add(c_l, c_s2, c=self.c)

        W_f, W_i, W_o, W_c_tmp = self.W_all.chunk(4, dim=1)
        U_f, U_i, U_o, U_c_tmp = self.U_all.chunk(4, dim=0)

        # Gates: Mobius affine transform, log-mapped to the tangent space for the sigmoid.
        f = pmath_geo.logmap0(one_rnn_transform(W_f, h, U_f, inputs[:, s], self.c), c=self.c).sigmoid()
        i = pmath_geo.logmap0(one_rnn_transform(W_i, h, U_i, inputs[:, s], self.c), c=self.c).sigmoid()
        o = pmath_geo.logmap0(one_rnn_transform(W_o, h, U_o, inputs[:, s], self.c), c=self.c).sigmoid()
        c_tmp = pmath_geo.logmap0(one_rnn_transform(W_c_tmp, h, U_c_tmp, inputs[:, s], self.c), c=self.c).sigmoid()

        f_dot_c_adj = pmath_geo.mobius_pointwise_mul(f, c_adj, c=self.c)
        i_dot_c_tmp = pmath_geo.mobius_pointwise_mul(i, c_tmp, c=self.c)
        _c = pmath_geo.mobius_add(i_dot_c_tmp, f_dot_c_adj, c=self.c)
        h = pmath_geo.mobius_pointwise_mul(o, pmath_geo.expmap0(torch.tanh(_c), c=self.c), c=self.c)

        outputs.append(o)
        hidden_state_c.append(_c)
        hidden_state_h.append(h)

    if reverse:
        outputs.reverse()
        hidden_state_c.reverse()
        hidden_state_h.reverse()
    outputs = torch.stack(outputs, 1)
    hidden_state_c = torch.stack(hidden_state_c, 1)
    hidden_state_h = torch.stack(hidden_state_h, 1)

    return outputs, (h, _c)
def inverse_exp_map_mu0(x: Tensor, radius: Tensor) -> Tensor:
    return pm.logmap0(x, c=_c(radius))
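# Usage sketch for the inverse map above. pm.expmap0 is geoopt's exponential map at
# the origin; the round trip through inverse_exp_map_mu0 should recover the tangent
# vector (the tolerance and the radius-to-curvature helper _c are assumptions):
#   x_tan  = torch.randn(4, 8)
#   x_ball = pm.expmap0(x_tan, c=_c(radius))
#   assert torch.allclose(inverse_exp_map_mu0(x_ball, radius), x_tan, atol=1e-4)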