def test_gae():
    """Exercise GAE encode/decode, edge splitting, loss and evaluation."""
    # With an identity encoder, encode() must return the features untouched.
    model = GAE(encoder=lambda x: x)
    feats = torch.Tensor([[1, -1], [1, 2], [2, 1]])

    emb = model.encode(feats)
    assert emb.tolist() == feats.tolist()

    # Without the sigmoid, dense decoding yields the raw pairwise inner
    # products of the embeddings (emb @ emb.T for these values).
    expected_adj = [[+2, -1, +1], [-1, +5, +4], [+1, +4, +5]]
    dense = model.decode_all(emb, sigmoid=False)
    assert dense.tolist() == expected_adj

    # Index-based decoding picks entries (0, 1) and (1, 2) of that matrix.
    pairs = torch.tensor([[0, 1], [1, 2]])
    picked = model.decode_indices(emb, pairs, sigmoid=False)
    assert picked.tolist() == [-1, 4]

    # A chain graph 0-1-2-...-10: 10 edges over 11 nodes.
    chain = torch.tensor([list(range(10)), list(range(1, 11))])
    split = model.split_edges(Data(edge_index=chain),
                              val_ratio=0.2, test_ratio=0.3)

    # 20% / 30% / remaining 50% of the 10 edges.
    assert split.val_pos_edge_index.size() == (2, 2)
    assert split.val_neg_edge_index.size() == (2, 2)
    assert split.test_pos_edge_index.size() == (2, 3)
    assert split.test_neg_edge_index.size() == (2, 3)
    assert split.train_pos_edge_index.size() == (2, 5)

    num_nodes = 11
    assert split.train_neg_adj_mask.size() == (num_nodes, num_nodes)
    # Start from the strict upper-triangle pair count; the subtracted terms
    # presumably cover the 10 positive edges plus the 5 sampled val/test
    # negatives -- confirm against split_edges.
    expected_free = (num_nodes**2 - num_nodes) / 2 - 4 - 6 - 5
    assert split.train_neg_adj_mask.sum().item() == expected_free

    # Random embeddings: only sanity-check the ranges of the outputs.
    rand_emb = torch.randn(num_nodes, 16)
    loss = model.loss(rand_emb, split.train_pos_edge_index,
                      split.train_neg_adj_mask)
    assert loss.item() > 0

    auc, ap = model.evaluate(rand_emb, split.val_pos_edge_index,
                             split.val_neg_edge_index)
    assert 0 <= auc <= 1 and 0 <= ap <= 1
def _train_epoch(model, loader, opt, task):
    """Run one optimisation pass over ``loader`` and return the averaged loss.

    For ``task == 'node'`` the loss is ``model.loss`` on the nodes selected by
    ``batch.train_mask``. For any other task (only 'link' reaches here) it is
    ``model.recon_loss`` on ``batch.train_pos_edge_index``. Each batch loss is
    weighted by ``batch.num_graphs``, and the sum is divided by
    ``len(loader.dataset)``.
    """
    total_loss = 0
    model.train()
    for batch in loader:
        opt.zero_grad()
        if task == 'node':
            pred = model(batch)[batch.train_mask]
            label = batch.y[batch.train_mask]
            loss = model.loss(pred, label)
        else:
            z = model.encode(batch)
            loss = model.recon_loss(z, batch.train_pos_edge_index)
        loss.backward()
        opt.step()
        total_loss += loss.item() * batch.num_graphs
    return total_loss / len(loader.dataset)


def train(dataset, args, writer=None):
    """Train a GNNStack model (wrapped in GAE for link prediction).

    Args:
        dataset: dataset passed to ``DataLoader``. It must expose
            ``num_node_features`` and ``num_classes``.
        args: configuration namespace. It reads ``task`` ('link' or 'node'),
            ``batch_size``, ``hidden_dim``, ``epochs``, ``metrics_for_labels``
            (the string 'True' or a bool), ``model_name``, ``graph_type`` and
            ``dataset``, and is also passed to ``models.GNNStack`` and
            ``build_optimizer``.
        writer: optional object with ``add_scalar(tag, value, step)``
            (presumably a TensorBoard SummaryWriter). When ``None``, metrics
            are printed to stdout instead.

    Raises:
        RuntimeError: if ``args.task`` is neither 'link' nor 'node'.
    """
    task = args.task
    # NOTE(review): the same shuffled loader is used for training and for
    # ``test``; which split ``test`` scores is presumably selected by masks
    # inside ``test`` -- confirm.
    loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True)

    if task not in ('link', 'node'):
        raise RuntimeError("Unknown task.")
    model = models.GNNStack(dataset.num_node_features, args.hidden_dim,
                            int(dataset.num_classes), args)
    if task == 'link':
        # Link prediction: use the GNN as the encoder of a graph autoencoder.
        model = GAE(model)

    # Accept both the string flag 'True' (e.g. from argparse) and a real bool.
    metrics_for_labels = str(args.metrics_for_labels) == 'True'
    scheduler, opt = build_optimizer(args, model.parameters())

    print("Training \nModel: {}, Data representation: {}. Dataset: {}, Task type: {}".format(
        args.model_name, args.graph_type, args.dataset, args.task))

    metric_text = 'test accuracy' if task == 'node' else 'test precision'
    for epoch in range(args.epochs):
        total_loss = _train_epoch(model, loader, opt, task)
        # The scheduler returned by build_optimizer was previously never
        # stepped, so any configured LR schedule had no effect. Assumes an
        # epoch-based scheduler whose step() takes no metric -- confirm.
        if scheduler is not None:
            scheduler.step()

        if writer is None:
            print(total_loss)
        else:
            writer.add_scalar("loss", total_loss, epoch)

        if epoch % 10 == 0:
            test_metric, _ = test(loader, model, task=task)
            if writer is None:
                print(test_metric, metric_text)
            else:
                writer.add_scalar(metric_text, test_metric, epoch)

        # Per-label metrics are reported once, after the final epoch.
        if metrics_for_labels and epoch == args.epochs - 1:
            _, labels_metrics = test(loader, model, task=task,
                                     metrics_for_labels=metrics_for_labels)
            print('{} for labels:\n {}'.format(metric_text, labels_metrics))