# Training loop.
# NOTE(review): `args`, `model`, `data` and `optimizer` are defined outside this
# view; `args` presumably is the parsed CLI namespace (epochs, val_freq, test).
for epoch in range(args.epochs):
    model.train()
    optimizer.zero_grad()
    z = model.encode(data.x, data.train_pos_edge_index)
    # NOTE(review): the original carried a commented-out `0.01*model.kl_loss()`
    # term here; only the reconstruction loss is optimised. Confirm whether a
    # KL regulariser is wanted.
    loss = model.recon_loss(z, data.train_pos_edge_index)
    loss.backward()
    optimizer.step()

    # Log validation metrics every `args.val_freq` epochs (including epoch 0).
    if epoch % args.val_freq == 0:
        model.eval()
        with torch.no_grad():
            # Message passing uses only training edges; validation edges are
            # scored, never propagated over.
            z = model.encode(data.x, data.train_pos_edge_index)
            auc, ap = model.test(z, data.val_pos_edge_index,
                                 data.val_neg_edge_index)
        print('Train loss: {:.4f}, Validation AUC-ROC: {:.4f}, '
              'AP: {:.4f} at epoch {:03d}'.format(
                  loss.item(), auc, ap, epoch))

# Final evaluation: score the test edges when `args.test` is set, otherwise
# the validation edges. The embedding is the same in both cases, so it is
# computed once.
model.eval()
with torch.no_grad():
    z = model.encode(data.x, data.train_pos_edge_index)
    if args.test:
        auc, ap = model.test(z, data.test_pos_edge_index,
                             data.test_neg_edge_index)
    else:
        auc, ap = model.test(z, data.val_pos_edge_index,
                             data.val_neg_edge_index)
def run_VGAE(input_data, output_dir, epochs=1000, lr=0.01, weight_decay=0.0005):
    """Train a VGAE for link prediction and record per-epoch curves.

    Works on a clone of ``input_data`` moved to CUDA when available (else CPU).
    Edges are split with ``model.split_edges``. The model is then trained for
    ``epochs`` full-batch Adam steps. After every step, the test edges are
    scored (AUC / AP) and a test reconstruction-plus-KL loss is computed.
    Loss and metric histories go to ``plot_linkpred``. A summary of the final
    epoch is printed and written with ``write_log``.

    Args:
        input_data: graph data object exposing ``num_features`` and a
            ``num_classes`` value with ``.item()``, which is used as the
            encoder's output width.
        output_dir: directory for logs and plots (created via ``makepath``).
        epochs: number of training epochs.
        lr: Adam learning rate.
        weight_decay: Adam weight decay (L2 penalty).

    Returns:
        None.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print('Device: '.ljust(32), device)
    print('Model Name: '.ljust(32), 'VGAE')
    print('Model params:{:19} lr: {} weight_decay: {}'.format(
        '', lr, weight_decay))
    print('Total number of epochs to run: '.ljust(32), epochs)
    print('*' * 70)

    data = input_data.clone().to(device)
    model = VGAE(VGAEncoder(data.num_features,
                            data.num_classes.item())).to(device)
    data = model.split_edges(data)
    x = data.x.to(device)
    train_pos_edge_index = data.train_pos_edge_index.to(device)
    # Node-level labels/splits are not used for link prediction.
    data.train_idx = data.test_idx = data.y = None

    # Honour the configured hyper-parameters. They are also printed above and
    # baked into the output file name, so the run must actually use them.
    optimizer = torch.optim.Adam(model.parameters(), lr=lr,
                                 weight_decay=weight_decay)

    # The output location is loop-invariant. Defining it before the loop
    # also keeps `figname` defined when epochs == 0.
    makepath(output_dir)
    figname = os.path.join(
        output_dir,
        "_".join((VGAE.__name__, str(lr), str(weight_decay), str(epochs))))

    # Report roughly ten times per run. The max() guards epochs < 10, where
    # epochs // 10 == 0 would make the modulo below raise ZeroDivisionError.
    log_every = max(1, epochs // 10)
    # KL term is scaled by 1 / num_nodes, as in the original objective.
    kl_weight = 1 / data.num_nodes

    train_losses, test_losses, aucs, aps = [], [], [], []
    for epoch in range(1, epochs + 1):
        # Switch back to training mode every epoch. The evaluation step below
        # puts the model in eval mode. Without this, every epoch after the
        # first would train in eval mode (no dropout, and PyG's VGAE returns
        # mu instead of sampling).
        model.train()
        optimizer.zero_grad()
        z = model.encode(x, train_pos_edge_index)
        train_loss = (model.recon_loss(z, train_pos_edge_index)
                      + kl_weight * model.kl_loss())
        train_loss.backward()
        optimizer.step()
        train_losses.append(train_loss.item())

        model.eval()
        with torch.no_grad():
            z = model.encode(x, train_pos_edge_index)
            auc, ap = model.test(z, data.test_pos_edge_index,
                                 data.test_neg_edge_index)
            test_loss = (model.recon_loss(z, data.test_pos_edge_index)
                         + kl_weight * model.kl_loss())
        test_losses.append(test_loss.item())
        aucs.append(auc)
        aps.append(ap)

        # Print plain floats rather than tensor reprs.
        if epoch % log_every == 0:
            print('Epoch: {} Train loss: {} Test loss: {} AUC: {} AP: {:.4f}'
                  .format(epoch, train_losses[-1], test_losses[-1], auc, ap))
        if epoch == epochs:
            log = ('Final epoch: {} Train loss: {} Test loss: {} AUC: {} '
                   'AP: {}').format(epoch, train_losses[-1], test_losses[-1],
                                    auc, ap)
            print('-' * 65, '\n' + log)
            write_log(log, figname)
            print('-' * 65)

    plot_linkpred(train_losses, test_losses, aucs, aps, output_dir, epochs,
                  figname)
    return