示例#1
0
def test_vgae():
    """Smoke-test ``VGAE``: encode a tiny input and check the KL loss.

    The model is built around a stand-in encoder that returns its input
    twice, so no real network is involved.
    """

    def passthrough_encoder(inputs):
        # Hand the input back as both elements of the tuple the model
        # receives from its encoder.
        return inputs, inputs

    model = VGAE(encoder=passthrough_encoder)
    features = torch.Tensor([[1, -1], [1, 2], [2, 1]])

    # Encode once on the freshly built model; the KL loss read afterwards
    # must be strictly positive.
    model.encode(features)
    kl_value = model.kl_loss().item()
    assert kl_value > 0

    # Encoding must also run after switching the model to eval mode.
    model.eval()
    model.encode(features)
示例#2
0
        # Build the VGAE around the project-defined encoder and optimise all
        # of its parameters with Adam at a fixed learning rate of 0.01.
        model = VGAE(VGAE_Encoder(num_features))
        optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

        # Training loop
        for epoch in range(args.epochs):
            # Re-enter training mode each epoch because the validation step
            # below switches the model to eval mode.
            model.train()
            optimizer.zero_grad()
            z = model.encode(data.x, data.train_pos_edge_index)
            # Only the reconstruction loss is optimised here; the KL term
            # (0.01 * model.kl_loss()) is commented out on the line below.
            loss = model.recon_loss(
                z, data.train_pos_edge_index)  #0.01*model.kl_loss()
            loss.backward()
            optimizer.step()

            # Log validation metrics
            if epoch % args.val_freq == 0:
                model.eval()
                with torch.no_grad():
                    # Embeddings are still computed from the training edges;
                    # the validation edges are only used for scoring.
                    z = model.encode(data.x, data.train_pos_edge_index)
                    auc, ap = model.test(z, data.val_pos_edge_index,
                                         data.val_neg_edge_index)
                # `loss` is the training loss of this epoch's optimisation
                # step, not a validation loss.
                print('Train loss: {:.4f}, Validation AUC-ROC: {:.4f}, '
                      'AP: {:.4f} at epoch {:03d}'.format(
                          loss, auc, ap, epoch))

        # Final evaluation
        model.eval()
        with torch.no_grad():
            # Test scoring only runs when args.test is truthy.
            if args.test:
                z = model.encode(data.x, data.train_pos_edge_index)
                auc, ap = model.test(z, data.test_pos_edge_index,
                                     data.test_neg_edge_index)
示例#3
0
def run_VGAE(input_data,
             output_dir,
             epochs=1000,
             lr=0.01,
             weight_decay=0.0005):
    """Train a VGAE for link prediction and record its per-epoch metrics.

    The input graph is cloned, moved to the available device and split into
    train/test edges via ``model.split_edges``. Each epoch performs one
    optimisation step on the training edges, then evaluates AUC, AP and the
    loss on the held-out test edges. A log line is written every epoch and
    the collected curves are handed to ``plot_linkpred`` at the end.

    Args:
        input_data: Graph data object; it must provide ``clone()``, ``x``,
            ``num_features``, ``num_classes`` and ``num_nodes``. It is not
            modified (a clone is used).
        output_dir: Directory used for the log and plot outputs.
        epochs: Number of training epochs.
        lr: Learning rate passed to the Adam optimizer.
        weight_decay: Weight decay (L2 penalty) passed to the Adam optimizer.

    Returns:
        None.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print('Device: '.ljust(32), device)
    print('Model Name: '.ljust(32), 'VGAE')
    print('Model params:{:19} lr: {}     weight_decay: {}'.format(
        '', lr, weight_decay))
    print('Total number of epochs to run: '.ljust(32), epochs)
    print('*' * 70)

    data = input_data.clone().to(device)
    model = VGAE(VGAEncoder(data.num_features,
                            data.num_classes.item())).to(device)
    data = model.split_edges(data)
    x = data.x.to(device)
    train_pos_edge_index = data.train_pos_edge_index.to(device)
    data.train_idx = data.test_idx = data.y = None

    # Use the caller's hyper-parameters: they are the values printed above
    # and encoded in the output file name, so the optimizer must match them.
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=lr,
                                 weight_decay=weight_decay)

    # Output location is loop-invariant; resolving it up front also keeps
    # `figname` defined when `epochs` is 0.
    makepath(output_dir)
    figname = os.path.join(
        output_dir, "_".join(
            (VGAE.__name__, str(lr), str(weight_decay), str(epochs))))

    # Weight of the KL term relative to the reconstruction loss.
    kl_weight = 1 / data.num_nodes
    # Print progress roughly ten times per run; never 0, so runs with fewer
    # than 10 epochs do not divide by zero.
    report_every = max(1, epochs // 10)

    train_losses = []
    test_losses = []
    aucs = []
    aps = []
    for epoch in range(1, epochs + 1):
        # Evaluation below switches to eval mode, so training mode has to be
        # restored at the start of every epoch.
        model.train()
        optimizer.zero_grad()
        z = model.encode(x, train_pos_edge_index)
        train_loss = model.recon_loss(
            z, train_pos_edge_index) + kl_weight * model.kl_loss()
        train_losses.append(train_loss.item())
        train_loss.backward()
        optimizer.step()

        # Evaluation needs no gradients, so all of it runs under no_grad.
        model.eval()
        with torch.no_grad():
            z = model.encode(x, train_pos_edge_index)
            auc, ap = model.test(z, data.test_pos_edge_index,
                                 data.test_neg_edge_index)
            test_loss = model.recon_loss(
                z, data.test_pos_edge_index) + kl_weight * model.kl_loss()
        test_losses.append(test_loss.item())
        aucs.append(auc)
        aps.append(ap)

        if epoch % report_every == 0:
            print(
                'Epoch: {}        Train loss: {}    Test loss: {}    AUC: {}    AP: {:.4f}'
                .format(epoch, train_loss, test_loss, auc, ap))
        if epoch == epochs:
            print(
                '-' * 65,
                '\nFinal epoch: {}  Train loss: {}    Test loss: {}    AUC: {}    AP: {}'
                .format(epoch, train_loss, test_loss, auc, ap))
        # NOTE(review): this line is written for every epoch even though it
        # is labelled "Final epoch"; kept as-is in case the log format is
        # consumed elsewhere -- confirm whether that is intended.
        log = 'Final epoch: {}    Train loss: {}    Test loss: {}    AUC: {}    AP: {}'.format(
            epoch, train_loss, test_loss, auc, ap)
        write_log(log, figname)
    print('-' * 65)

    plot_linkpred(train_losses, test_losses, aucs, aps, output_dir, epochs,
                  figname)