def main(_):
    """Trains fSVD-based WYS embeddings on an OGB link-prediction dataset.

    Performs FLAGS.num_runs independent randomized-SVD factorizations of the
    WYS/DeepWalk proximity matrix and reports mean/std of hits@FLAGS.hits on
    the validation and test splits.

    Args:
        _: Unused positional argument (absl.app.run passes argv).
    """
    ds = LinkPropPredDataset(FLAGS.dataset)
    split_edge = ds.get_edge_split()
    train_edges = split_edge['train']['edge']
    # Symmetrize: append reversed edges so the adjacency matrix is undirected.
    train_edges = np.concatenate([train_edges, train_edges[:, ::-1]], axis=0)
    spa = scipy.sparse.csr_matrix(
        (np.ones([len(train_edges)]),
         (train_edges[:, 0], train_edges[:, 1])))
    mult_f = tf_fsvd.WYSDeepWalkPF(
        spa, window=FLAGS.wys_window, mult_degrees=False,
        neg_sample_coef=FLAGS.wys_neg_coef)

    # The evaluator and the edge splits are identical across runs; build them
    # once instead of reloading the dataset inside the loop (original code
    # re-instantiated LinkPropPredDataset and Evaluator on every run).
    evaluator = Evaluator(name=FLAGS.dataset)
    evaluator.K = FLAGS.hits

    tt = tqdm.tqdm(range(FLAGS.num_runs))
    test_metrics = []
    val_metrics = []
    for run in tt:
        u, s, v = tf_fsvd.fsvd(
            mult_f, FLAGS.k, n_iter=FLAGS.svd_iters, n_redundancy=FLAGS.k * 3)
        metrics = []
        for split in ('test', 'valid'):
            pos_edges = split_edge[split]['edge']
            neg_edges = split_edge[split]['edge_neg']
            # Score an edge (i, j) as the inner product <(u*s)[i], v[j]>.
            pos_scores = tf.reduce_sum(
                tf.gather(u * s, pos_edges[:, 0]) *
                tf.gather(v, pos_edges[:, 1]), axis=1).numpy()
            neg_scores = tf.reduce_sum(
                tf.gather(u * s, neg_edges[:, 0]) *
                tf.gather(v, neg_edges[:, 1]), axis=1).numpy()
            metric = evaluator.eval({
                'y_pred_pos': pos_scores,
                'y_pred_neg': neg_scores
            })
            metrics.append(metric['hits@%i' % FLAGS.hits])
        test_metrics.append(metrics[0])
        val_metrics.append(metrics[1])
        tt.set_description(
            'HITS@%i: validate=%g; test=%g' % (
                FLAGS.hits, np.mean(val_metrics), np.mean(test_metrics)))
    # BUG FIX: the original summary line had an unfilled %i placeholder, and
    # the per-split lines hard-coded "HITS@20" even though the metric actually
    # computed is hits@FLAGS.hits.
    print('\n\n *** Trained for %i times and average metrics are:'
          % FLAGS.num_runs)
    print('HITS@%i test: mean=%g; std=%g'
          % (FLAGS.hits, np.mean(test_metrics), np.std(test_metrics)))
    print('HITS@%i validate: mean=%g; std=%g'
          % (FLAGS.hits, np.mean(val_metrics), np.std(val_metrics)))
def evaluate_hits(name, pos_pred, neg_pred, K):
    """Score link predictions with the OGB hits@K metric.

    Args:
        name (str): Dataset name used to construct the OGB ``Evaluator``.
        pos_pred (Tensor): Predicted scores for positive edges.
        neg_pred (Tensor): Predicted scores for negative edges.
        K (int): Cutoff rank for the hits metric.

    Returns:
        float: The hits@K score.
    """
    scorer = Evaluator(name)
    scorer.K = K
    result = scorer.eval({
        'y_pred_pos': pos_pred,
        'y_pred_neg': neg_pred,
    })
    return result[f'hits@{K}']
def train_model(
    train_graph: pyg.torch_geometric.data.Data,
    valid_graph: pyg.torch_geometric.data.Data,
    train_dl: data.DataLoader,
    dev_dl: data.DataLoader,
    evaluator: Evaluator,
    model: nn.Module,
    optimizer: optim.Optimizer,
    lr_scheduler: optim.lr_scheduler._LRScheduler,
    args: argparse.Namespace,
) -> nn.Module:
    """Train a link-prediction model with per-epoch validation and checkpointing.

    Each epoch runs a training pass over ``train_dl`` (positive edges, with
    negatives drawn on the fly via ``negative_sampling``), evaluates Hits@K
    (K in {10, 20, 30}) on both splits via the OGB evaluator, logs everything
    to TensorBoard, saves the model whenever validation Hits@20 improves, and
    keeps a rotating set of periodic checkpoints.

    Args:
        train_graph: Graph providing ``adj_t``/``edge_index``/``x`` for training.
        valid_graph: Graph used for the validation forward passes.
        train_dl: DataLoader yielding 1-tuples of positive-edge batches.
        dev_dl: DataLoader yielding (edges, labels) batches for validation.
        evaluator: OGB evaluator; its ``K`` attribute is mutated in place here.
        model: Module called as ``model(adj_t, edges)`` returning edge scores.
        optimizer: Optimizer stepped once per training batch.
        lr_scheduler: Stepped with the batch loss when ``args.use_scheduler``.
        args: Namespace with train_epochs, batch sizes, log_dir, experiment,
            save_path, checkpoint_freq, num_checkpoints, use_scheduler.

    Returns:
        The trained model (same object, mutated in place).
    """
    device = model_utils.get_device()
    loss_fn = nn.functional.binary_cross_entropy
    val_loss_fn = nn.functional.binary_cross_entropy
    # NOTE(review): best_val_loss is initialized but never updated or read
    # below — looks like dead state; confirm before removing.
    best_val_loss = torch.tensor(float('inf'))
    best_val_hits = torch.tensor(0.0)
    saved_checkpoints = []
    writer = SummaryWriter(log_dir=f'{args.log_dir}/{args.experiment}')
    for e in range(1, args.train_epochs + 1):
        print(f'Training epoch {e}...')

        # Training portion
        torch.cuda.empty_cache()
        torch.set_grad_enabled(True)
        with tqdm(total=args.train_batch_size * len(train_dl)) as progress_bar:
            model.train()
            # Load graph into GPU
            adj_t = train_graph.adj_t.to(device)
            edge_index = train_graph.edge_index.to(device)
            x = train_graph.x.to(device)
            pos_pred = []
            neg_pred = []
            for i, (y_pos_edges,) in enumerate(train_dl):
                # Batches arrive as (E, 2); transpose to (2, E) for the model.
                y_pos_edges = y_pos_edges.to(device).T
                # Draw one negative edge per positive edge for this batch.
                y_neg_edges = negative_sampling(
                    edge_index,
                    num_nodes=train_graph.num_nodes,
                    num_neg_samples=y_pos_edges.shape[1]
                ).to(device)
                # Ground truth edge labels (1 or 0)
                y_batch = torch.cat(
                    [torch.ones(y_pos_edges.shape[1]),
                     torch.zeros(y_neg_edges.shape[1])], dim=0).to(device)
                # Forward pass on model
                optimizer.zero_grad()
                y_pred = model(adj_t, torch.cat(
                    [y_pos_edges, y_neg_edges], dim=1))
                loss = loss_fn(y_pred, y_batch)
                # Backward pass and optimization
                loss.backward()
                optimizer.step()
                if args.use_scheduler:
                    lr_scheduler.step(loss)
                # Accuracy at a 0.5 threshold: round the sigmoid outputs.
                batch_acc = torch.mean(
                    1 - torch.abs(
                        y_batch.detach() - torch.round(y_pred.detach()))).item()
                # Keep detached per-class predictions for epoch-level Hits@K.
                pos_pred += [y_pred[y_batch == 1].detach()]
                neg_pred += [y_pred[y_batch == 0].detach()]
                progress_bar.update(y_pos_edges.shape[1])
                progress_bar.set_postfix(loss=loss.item(), acc=batch_acc)
                writer.add_scalar(
                    "train/Loss", loss,
                    ((e - 1) * len(train_dl) + i) * args.train_batch_size)
                writer.add_scalar(
                    "train/Accuracy", batch_acc,
                    ((e - 1) * len(train_dl) + i) * args.train_batch_size)
                # Free per-batch GPU tensors eagerly to reduce peak memory.
                del y_pos_edges
                del y_neg_edges
                del y_pred
                del loss
            # Release the graph tensors before the evaluation phase.
            del adj_t
            del edge_index
            del x

        # Training set evaluation Hits@K Metrics
        pos_pred = torch.cat(pos_pred, dim=0)
        neg_pred = torch.cat(neg_pred, dim=0)
        results = {}
        for K in [10, 20, 30]:
            evaluator.K = K
            hits = evaluator.eval({
                'y_pred_pos': pos_pred,
                'y_pred_neg': neg_pred,
            })[f'hits@{K}']
            results[f'Hits@{K}'] = hits
        print()
        print(f'Train Statistics')
        print('*' * 30)
        for k, v in results.items():
            print(f'{k}: {v}')
            writer.add_scalar(
                f"train/{k}", v, (pos_pred.shape[0] + neg_pred.shape[0]) * e)
        print('*' * 30)
        del pos_pred
        del neg_pred

        # Validation portion
        torch.cuda.empty_cache()
        torch.set_grad_enabled(False)
        with tqdm(total=args.val_batch_size * len(dev_dl)) as progress_bar:
            model.eval()
            adj_t = valid_graph.adj_t.to(device)
            edge_index = valid_graph.edge_index.to(device)
            x = valid_graph.x.to(device)
            val_loss = 0.0
            accuracy = 0
            num_samples_processed = 0
            pos_pred = []
            neg_pred = []
            for i, (edges_batch, y_batch) in enumerate(dev_dl):
                edges_batch = edges_batch.T.to(device)
                y_batch = y_batch.to(device)
                # Forward pass on model in validation environment
                y_pred = model(adj_t, edges_batch)
                loss = val_loss_fn(y_pred, y_batch)
                num_samples_processed += edges_batch.shape[1]
                batch_acc = torch.mean(
                    1 - torch.abs(y_batch - torch.round(y_pred))).item()
                # Sample-weighted running averages for the progress bar.
                accuracy += batch_acc * edges_batch.shape[1]
                val_loss += loss.item() * edges_batch.shape[1]
                pos_pred += [y_pred[y_batch == 1].detach()]
                neg_pred += [y_pred[y_batch == 0].detach()]
                progress_bar.update(edges_batch.shape[1])
                progress_bar.set_postfix(
                    val_loss=val_loss / num_samples_processed,
                    acc=accuracy/num_samples_processed)
                writer.add_scalar(
                    "Val/Loss", loss,
                    ((e - 1) * len(dev_dl) + i) * args.val_batch_size)
                writer.add_scalar(
                    "Val/Accuracy", batch_acc,
                    ((e - 1) * len(dev_dl) + i) * args.val_batch_size)
                del edges_batch
                del y_batch
                del y_pred
                del loss
            del adj_t
            del edge_index
            del x

        # Validation evaluation Hits@K Metrics
        pos_pred = torch.cat(pos_pred, dim=0)
        neg_pred = torch.cat(neg_pred, dim=0)
        results = {}
        for K in [10, 20, 30]:
            evaluator.K = K
            hits = evaluator.eval({
                'y_pred_pos': pos_pred,
                'y_pred_neg': neg_pred,
            })[f'hits@{K}']
            results[f'Hits@{K}'] = hits
        print()
        print(f'Validation Statistics')
        print('*' * 30)
        for k, v in results.items():
            print(f'{k}: {v}')
            writer.add_scalar(
                f"Val/{k}", v, (pos_pred.shape[0] + neg_pred.shape[0]) * e)
        print('*' * 30)
        del pos_pred
        del neg_pred

        # Save model if it's the best one yet.
        # NOTE(review): compares a float metric against a tensor
        # (best_val_hits starts as torch.tensor(0.0)) — works via tensor
        # comparison semantics, but confirm this is intentional.
        if results['Hits@20'] > best_val_hits:
            best_val_hits = results['Hits@20']
            filename = f'{args.save_path}/{args.experiment}/{model.__class__.__name__}_best_val.checkpoint'
            model_utils.save_model(model, filename)
            print(f'Model saved!')
            print(f'Best validation Hits@20 yet: {best_val_hits}')

        # Save model on checkpoints.
        if e % args.checkpoint_freq == 0:
            filename = f'{args.save_path}/{args.experiment}/{model.__class__.__name__}_epoch_{e}.checkpoint'
            model_utils.save_model(model, filename)
            print(f'Model checkpoint reached!')
            saved_checkpoints.append(filename)
            # Delete checkpoints if there are too many
            while len(saved_checkpoints) > args.num_checkpoints:
                os.remove(saved_checkpoints.pop(0))

    return model
def main(): parser = argparse.ArgumentParser(description='OGBL-DDI (MADGraph)') parser.add_argument('--lr', type=float, default=0.005) parser.add_argument('--epochs', type=int, default=200) parser.add_argument('--eval_steps', type=int, default=5) parser.add_argument('--runs', type=int, default=10) parser.add_argument('--batch_size', type=int, default=4 * 1024) parser.add_argument('--dim', type=int, default=12) parser.add_argument('--heads', type=int, default=12) parser.add_argument('--samples', type=int, default=8) parser.add_argument('--nearest', type=int, default=8) parser.add_argument('--seed', type=int, default=0) parser.add_argument('--sentinels', type=int, default=8) parser.add_argument('--memory', type=str, default='all') parser.add_argument('--softmin', type=bool, default=True) parser.add_argument('--output_csv', type=str, default='') args = parser.parse_args() print(args) DNAME = 'ogbl-ddi' dataset = LinkPropPredDataset(name=DNAME) graph = dataset[0] n_nodes = graph['num_nodes'] data = dataset.get_edge_split() for group in 'train valid test'.split(): if group in data: sets = data[group] for key in ('edge', 'edge_neg'): if key in sets: sets[key] = gpu(torch.from_numpy(sets[key])) data['eval_train'] = { 'edge': data['train']['edge'][torch.randperm( data['train']['edge'].shape[0])[:data['valid']['edge'].shape[0]]] } model = MADGraph( n_nodes=n_nodes, node_feats=args.dim, src=data['train']['edge'][:, 0], dst=data['train']['edge'][:, 1], n_samples=args.samples, n_heads=args.heads, n_sentinels=args.sentinels, memory=['none', 'stat', 'all'].index(args.memory), softmin=args.softmin, n_nearest=args.nearest, ) params = [p for net in [model] for p in net.parameters()] print('params:', sum(p.numel() for p in params)) evaluator = Evaluator(name=DNAME) loggers = { 'Hits@10': Logger(args.runs, args), 'Hits@20': Logger(args.runs, args), 'Hits@30': Logger(args.runs, args), } for run in range(args.runs): torch.manual_seed(args.seed + run) opt = optim.Adam(params, 
lr=args.lr) torch.nn.init.xavier_uniform_(model.pos.data) torch.nn.init.xavier_uniform_(model.field.data) model.uncertainty.data = model.uncertainty.data * 0 + 1 for epoch in range(1, args.epochs + 1): model.train() for chunk in sample(data['train']['edge'], args.batch_size): opt.zero_grad() p_edge = torch.sigmoid(model(chunk)) edge_neg_chunk = gpu(torch.randint(0, n_nodes, chunk.shape)) p_edge_neg = torch.sigmoid(model(edge_neg_chunk)) loss = (-torch.log(1e-5 + 1 - p_edge_neg).mean() - torch.log(1e-5 + p_edge).mean()) loss.backward() opt.step() if epoch % args.eval_steps: continue with torch.no_grad(): model.eval() p_train = torch.cat([ model(chunk) for chunk in sample( data['eval_train']['edge'], args.batch_size) ]) n_train = torch.cat([ model(chunk) for chunk in sample(data['valid']['edge_neg'], args.batch_size) ]) p_valid = torch.cat([ model(chunk) for chunk in sample(data['valid']['edge'], args.batch_size) ]) n_valid = n_train p_test = torch.cat([ model(chunk) for chunk in sample(data['test']['edge'], args.batch_size) ]) n_test = torch.cat([ model(chunk) for chunk in sample(data['test']['edge_neg'], args.batch_size) ]) for K in [10, 20, 30]: evaluator.K = K key = f'Hits@{K}' h_train = evaluator.eval({ 'y_pred_pos': p_train, 'y_pred_neg': n_train, })[f'hits@{K}'] h_valid = evaluator.eval({ 'y_pred_pos': p_valid, 'y_pred_neg': n_valid, })[f'hits@{K}'] h_test = evaluator.eval({ 'y_pred_pos': p_test, 'y_pred_neg': n_test, })[f'hits@{K}'] loggers[key].add_result(run, (h_train, h_valid, h_test)) print(key) print(f'Run: {run + 1:02d}, ' f'Epoch: {epoch:02d}, ' f'Loss: {loss:.4f}, ' f'Train: {100 * h_train:.2f}%, ' f'Valid: {100 * h_valid:.2f}%, ' f'Test: {100 * h_test:.2f}%') print('---') for key in loggers.keys(): print(key) loggers[key].print_statistics(run) for key in loggers.keys(): print(key) loggers[key].print_statistics()