import torch
from torch_geometric.data import Data
from torch_geometric.nn import GAE


def test_gae():
    # Identity encoder: z == x, so the inner-product decoder reduces to x @ x.T.
    model = GAE(encoder=lambda x: x)

    x = torch.Tensor([[1, -1], [1, 2], [2, 1]])
    z = model.encode(x)
    assert z.tolist() == x.tolist()

    adj = model.decode_all(z, sigmoid=False)
    assert adj.tolist() == [[+2, -1, +1], [-1, +5, +4], [+1, +4, +5]]

    edge_index = torch.tensor([[0, 1], [1, 2]])
    value = model.decode_indices(z, edge_index, sigmoid=False)
    assert value.tolist() == [-1, 4]

    edge_index = torch.tensor([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
                               [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])
    data = Data(edge_index=edge_index)
    data = model.split_edges(data, val_ratio=0.2, test_ratio=0.3)

    assert data.val_pos_edge_index.size() == (2, 2)
    assert data.val_neg_edge_index.size() == (2, 2)
    assert data.test_pos_edge_index.size() == (2, 3)
    assert data.test_neg_edge_index.size() == (2, 3)
    assert data.train_pos_edge_index.size() == (2, 5)
    assert data.train_neg_adj_mask.size() == (11, 11)
    assert data.train_neg_adj_mask.sum().item() == (11**2 - 11) / 2 - 4 - 6 - 5

    z = torch.randn(11, 16)
    loss = model.loss(z, data.train_pos_edge_index, data.train_neg_adj_mask)
    assert loss.item() > 0

    auc, ap = model.evaluate(z, data.val_pos_edge_index,
                             data.val_neg_edge_index)
    assert auc >= 0 and auc <= 1 and ap >= 0 and ap <= 1
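
The helper names in this test come from an early torch_geometric release; later releases renamed split_edges, loss and evaluate to train_test_split_edges (moved to torch_geometric.utils), recon_loss and test. A minimal sketch of the same checks against that newer API, assuming those names, could look like this:

# Rough equivalent of the calls above for newer torch_geometric releases,
# where the edge split lives in torch_geometric.utils and the GAE methods
# are called recon_loss / test (names assumed from those releases).
import torch
from torch_geometric.data import Data
from torch_geometric.nn import GAE
from torch_geometric.utils import train_test_split_edges

model = GAE(encoder=lambda x: x)

edge_index = torch.tensor([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
                           [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])
data = Data(edge_index=edge_index, num_nodes=11)
data = train_test_split_edges(data, val_ratio=0.2, test_ratio=0.3)

z = torch.randn(11, 16)
loss = model.recon_loss(z, data.train_pos_edge_index)   # was model.loss(...)
auc, ap = model.test(z, data.val_pos_edge_index,        # was model.evaluate(...)
                     data.val_neg_edge_index)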
Example #2
from torch_geometric.data import DataLoader
from torch_geometric.nn import GAE

# `models`, `build_optimizer` and `test` are project-local helpers assumed to be
# importable from the surrounding repository.


def train(dataset, args, writer=None):
    task = args.task
    loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True)

    # Link prediction wraps the GNN encoder in a GAE; node classification uses it directly.
    if task == 'link':
        model = GAE(models.GNNStack(dataset.num_node_features, args.hidden_dim,
                                    int(dataset.num_classes), args))
    elif task == 'node':
        model = models.GNNStack(dataset.num_node_features, args.hidden_dim,
                                int(dataset.num_classes), args)
    else:
        raise RuntimeError("Unknown task.")

    metrics_for_labels = args.metrics_for_labels == 'True'
    scheduler, opt = build_optimizer(args, model.parameters())
    print("Training\nModel: {}, Data representation: {}, Dataset: {}, Task type: {}".format(
        args.model_name, args.graph_type, args.dataset, args.task))
    metric_text = 'test accuracy' if task == 'node' else 'test precision'
    for epoch in range(args.epochs):
        total_loss = 0
        model.train()
        for batch in loader:
            opt.zero_grad()
            if task == 'node':
                # Node classification: train on the nodes selected by train_mask.
                pred = model(batch)
                label = batch.y
                pred = pred[batch.train_mask]
                label = label[batch.train_mask]
                loss = model.loss(pred, label)
            else:
                # Link prediction: GAE reconstruction loss on the positive training edges.
                train_pos_edge_index = batch.train_pos_edge_index
                z = model.encode(batch)
                loss = model.recon_loss(z, train_pos_edge_index)
            loss.backward()
            opt.step()
            total_loss += loss.item() * batch.num_graphs
        total_loss /= len(loader.dataset)
        if writer is None:
            print(total_loss)
        else:
            writer.add_scalar("loss", total_loss, epoch)

        if epoch % 10 == 0:
            test_metric, _ = test(loader, model, task=task)
            if writer is None:
                print(test_metric, metric_text)
            else:
                writer.add_scalar(metric_text, test_metric, epoch)
        if metrics_for_labels and epoch == args.epochs - 1:
            _, labels_metrics = test(loader, model, task=task, metrics_for_labels=metrics_for_labels)
            print('{} for labels:\n {}'.format(metric_text, labels_metrics))
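
The 'link' branch above assumes that every Data object in the dataset already carries train_pos_edge_index. A minimal sketch of preparing such a dataset up front, assuming the train_test_split_edges helper from torch_geometric.utils and toy random graphs in place of a real dataset:

# Toy preparation sketch: pre-split each graph so that batch.train_pos_edge_index
# exists inside the training loop of train(). A plain Python list of Data objects
# stands in for a real dataset here.
import torch
from torch_geometric.data import Data
from torch_geometric.utils import train_test_split_edges

graphs = []
for _ in range(4):
    edge_index = torch.randint(0, 20, (2, 60))            # random toy graph on 20 nodes
    data = Data(x=torch.randn(20, 8), edge_index=edge_index)
    data = train_test_split_edges(data, val_ratio=0.1, test_ratio=0.2)
    graphs.append(data)

# `graphs` can now be wrapped in a DataLoader and iterated exactly like `loader`
# in train(); each batch then exposes train_pos_edge_index for model.recon_loss().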