def test_gae():
    """Check GAE encoding, decoding, edge splitting, loss and evaluation."""
    gae = GAE(encoder=lambda x: x)
    gae.reset_parameters()

    # The encoder is the identity, so the embedding must equal the input.
    features = torch.Tensor([[1, -1], [1, 2], [2, 1]])
    latent = gae.encode(features)
    assert latent.tolist() == features.tolist()

    # Expected dense output: sigmoid of the pairwise inner products of the
    # three embedding rows.
    dense_adj = gae.decode(latent)
    inner_products = torch.Tensor([[+2, -1, +1], [-1, +5, +4], [+1, +4, +5]])
    assert dense_adj.tolist() == torch.sigmoid(inner_products).tolist()

    # Decoding only the pairs (0, 1) and (1, 2).
    pairs = torch.tensor([[0, 1], [1, 2]])
    pair_scores = gae.decode_indices(latent, pairs)
    assert pair_scores.tolist() == torch.sigmoid(torch.Tensor([-1, 4])).tolist()

    # A chain graph 0 -> 1 -> ... -> 10 with ten edges.
    chain = torch.tensor([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
                          [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])
    graph = Data(edge_index=chain)
    graph = gae.split_edges(graph, val_ratio=0.2, test_ratio=0.3)

    expected_sizes = (
        ('val_pos_edge_index', (2, 2)),
        ('val_neg_edge_index', (2, 2)),
        ('test_pos_edge_index', (2, 3)),
        ('test_neg_edge_index', (2, 3)),
        ('train_pos_edge_index', (2, 5)),
        ('train_neg_adj_mask', (11, 11)),
    )
    for attr_name, size in expected_sizes:
        assert getattr(graph, attr_name).size() == size

    expected_mask_total = (11**2 - 11) / 2 - 4 - 6 - 5
    assert graph.train_neg_adj_mask.sum().item() == expected_mask_total

    # Random embeddings are enough to exercise the loss and the metrics.
    latent = torch.randn(11, 16)
    recon = gae.recon_loss(latent, graph.train_pos_edge_index)
    assert recon.item() > 0

    auc, ap = gae.test(latent, graph.val_pos_edge_index,
                       graph.val_neg_edge_index)
    assert 0 <= auc <= 1 and 0 <= ap <= 1
def test_gae():
    """Check GAE encoding, decoding, loss and evaluation on split edges."""
    gae = GAE(encoder=lambda x: x)
    gae.reset_parameters()

    # The encoder is the identity, so the embedding must equal the input.
    features = torch.Tensor([[1, -1], [1, 2], [2, 1]])
    latent = gae.encode(features)
    assert latent.tolist() == features.tolist()

    # Expected dense output: sigmoid of the pairwise inner products of the
    # three embedding rows.
    dense_adj = gae.decoder.forward_all(latent)
    inner_products = torch.Tensor([[+2, -1, +1], [-1, +5, +4], [+1, +4, +5]])
    assert dense_adj.tolist() == torch.sigmoid(inner_products).tolist()

    # Decoding only the pairs (0, 1) and (1, 2).
    pairs = torch.tensor([[0, 1], [1, 2]])
    pair_scores = gae.decode(latent, pairs)
    assert pair_scores.tolist() == torch.sigmoid(torch.Tensor([-1, 4])).tolist()

    # A chain graph 0 -> 1 -> ... -> 10; the node count is derived from the
    # largest node index that appears in the edge list.
    chain = torch.tensor([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
                          [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])
    graph = Data(edge_index=chain)
    graph.num_nodes = chain.max().item() + 1
    graph = train_test_split_edges(graph, val_ratio=0.2, test_ratio=0.3)

    # Random embeddings are enough to exercise the loss and the metrics.
    latent = torch.randn(11, 16)
    recon = gae.recon_loss(latent, graph.train_pos_edge_index)
    assert recon.item() > 0

    auc, ap = gae.test(latent, graph.val_pos_edge_index,
                       graph.val_neg_edge_index)
    assert 0 <= auc <= 1 and 0 <= ap <= 1
def test_gae():
    """Check GAE encoding, decoding, loss and evaluation via RandomLinkSplit."""
    gae = GAE(encoder=lambda x: x)
    gae.reset_parameters()

    # The encoder is the identity, so the embedding must equal the input.
    features = torch.Tensor([[1, -1], [1, 2], [2, 1]])
    latent = gae.encode(features)
    assert latent.tolist() == features.tolist()

    # Expected dense output: sigmoid of the pairwise inner products of the
    # three embedding rows.
    dense_adj = gae.decoder.forward_all(latent)
    inner_products = torch.Tensor([[+2, -1, +1], [-1, +5, +4], [+1, +4, +5]])
    assert dense_adj.tolist() == torch.sigmoid(inner_products).tolist()

    # Decoding only the pairs (0, 1) and (1, 2).
    pairs = torch.tensor([[0, 1], [1, 2]])
    pair_scores = gae.decode(latent, pairs)
    assert pair_scores.tolist() == torch.sigmoid(torch.Tensor([-1, 4])).tolist()

    # A chain graph 0 -> 1 -> ... -> 10 with an explicit node count.
    chain = torch.tensor([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
                          [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])
    graph = Data(edge_index=chain, num_nodes=11)
    splitter = RandomLinkSplit(split_labels=True,
                               add_negative_train_samples=False)
    train_split, val_split, _ = splitter(graph)

    # Random embeddings are enough to exercise the loss and the metrics.
    latent = torch.randn(11, 16)
    recon = gae.recon_loss(latent, train_split.pos_edge_label_index)
    assert recon.item() > 0

    auc, ap = gae.test(latent, val_split.pos_edge_label_index,
                       val_split.neg_edge_label_index)
    assert 0 <= auc <= 1 and 0 <= ap <= 1
def run_GAE(input_data, output_dir, epochs=1000, lr=0.01, weight_decay=0.0005):
    """Train a GAE for link prediction, then log and plot the results.

    Each epoch performs one optimisation step on the training edges and then
    evaluates AUC, AP and reconstruction loss on the test edges. Progress is
    printed roughly ten times over the run; after the last epoch a summary is
    written with ``write_log`` and the curves are drawn with
    ``plot_linkpred``.

    Args:
        input_data: Graph object exposing ``clone()``, ``to(device)``,
            ``num_features`` and ``num_classes`` (the latter must support
            ``.item()``). It is cloned, so the caller's object is not
            modified. Presumably a torch_geometric ``Data`` -- confirm
            against the caller.
        output_dir: Directory passed to ``makepath`` and used as the prefix
            of the log/figure name.
        epochs: Number of training epochs.
        lr: Learning rate for Adam.
        weight_decay: Weight decay for Adam.

    Returns:
        None.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print('Device: '.ljust(32), device)
    print('Model Name: '.ljust(32), 'GAE')
    print('Model params:{:19} lr: {} weight_decay: {}'.format(
        '', lr, weight_decay))
    print('Total number of epochs to run: '.ljust(32), epochs)
    print('*' * 70)

    # One clone is enough: it provides the model dimensions and is the
    # object handed to split_edges.
    data = input_data.clone().to(device)
    in_channels = data.num_features
    out_channels = data.num_classes.item()

    model = GAE(GAEncoder(in_channels, out_channels)).to(device)
    split_data = model.split_edges(data)
    x = split_data.x.to(device)
    train_pos_edge_index = split_data.train_pos_edge_index.to(device)
    split_data.train_idx = split_data.test_idx = data.y = None

    optimizer = torch.optim.Adam(model.parameters(), lr=lr,
                                 weight_decay=weight_decay)

    # These do not depend on the epoch, so compute them once.
    figname = os.path.join(
        output_dir, "_".join((GAE.__name__, str(lr), str(weight_decay))))
    makepath(output_dir)
    # Guard against a zero modulus when fewer than ten epochs are requested.
    report_every = max(1, epochs // 10)

    train_losses, test_losses = [], []
    aucs = []
    aps = []

    for epoch in range(1, epochs + 1):
        # Evaluation below switches to eval mode, so training mode has to be
        # restored at the start of every epoch, not just once.
        model.train()
        optimizer.zero_grad()
        z = model.encode(x, train_pos_edge_index)
        loss = model.recon_loss(z, train_pos_edge_index)
        loss.backward()
        optimizer.step()
        # Store a plain float so the history does not keep autograd tensors.
        train_loss = loss.item()
        train_losses.append(train_loss)

        model.eval()
        with torch.no_grad():
            z = model.encode(x, train_pos_edge_index)
            auc, ap = model.test(z, split_data.test_pos_edge_index,
                                 split_data.test_neg_edge_index)
            test_loss = model.recon_loss(
                z, split_data.test_pos_edge_index).item()
        test_losses.append(test_loss)
        aucs.append(auc)
        aps.append(ap)

        if epoch % report_every == 0:
            print(
                'Epoch: {} Train loss: {} Test loss: {} AUC: {} AP: {}'
                .format(epoch, train_loss, test_loss, auc, ap))

    if not train_losses:
        # No epoch ran, so there is nothing to summarise or plot.
        return

    print(
        '-' * 65,
        '\nFinal epoch: {} Train loss: {} Test loss: {} AUC: {} AP: {}'
        .format(epoch, train_loss, test_loss, auc, ap))
    log = 'Final epoch: {} Train loss: {} Test loss: {} AUC: {} AP: {}'.format(
        epoch, train_loss, test_loss, auc, ap)
    write_log(log, figname)
    print('-' * 65)

    plot_linkpred(train_losses, test_losses, aucs, aps, output_dir, epochs,
                  figname)
    return
def _load_pickle(path):
    """Return the object pickled at ``path``, closing the file afterwards."""
    # NOTE(review): unpickling can execute arbitrary code; these cache files
    # must only ever come from a trusted location.
    with open(path, 'rb') as handle:
        return pickle.load(handle)


def _dump_pickle(obj, path):
    """Pickle ``obj`` to ``path``, creating the parent directory if needed."""
    import os

    os.makedirs(os.path.dirname(path), exist_ok=True)
    with open(path, 'wb') as handle:
        pickle.dump(obj, handle)


def _assign_edge_index(data, edge_index):
    """Store ``edge_index`` on the attribute that matches ``data.setting``."""
    if data.setting == 'inductive':
        data.train_edge_index = edge_index
    else:
        data.edge_index = edge_index


def perturb_edges(data, name, remove_pct, add_pct, hidden_channels=16,
                  epochs=400):
    """Replace the edges of ``data`` in place with a GAE-guided perturbation.

    The result is looked up in two cache levels before any training is done:

    1. ``{ROOT}/cache/edge/{name}_{remove_pct}_{add_pct}.pt`` holds a finished
       perturbed edge index, which is used as is.
    2. ``{ROOT}/cache/edge/{name}.pt`` holds ``(A_pred, adj_orig)``, from which
       the perturbed graph is re-sampled with ``sample_graph_det``.

    If neither exists, a GAE with a ``GCNEncoder`` is trained on
    ``ori_x``, the embedding with the best validation AUC is turned into an
    edge-probability matrix, and both cache files are written.

    The new edge index is stored in ``data.train_edge_index`` when
    ``data.setting == 'inductive'`` and in ``data.edge_index`` otherwise.

    Args:
        data: Graph object with ``setting``, ``ori_x`` and ``edge_index``
            (plus ``train_x``, ``train_edge_index`` and ``train_y`` in the
            inductive setting). Modified in place.
        name: Dataset name used to build the cache file names.
        remove_pct: Removal amount forwarded to ``sample_graph_det``.
        add_pct: Addition amount forwarded to ``sample_graph_det``.
        hidden_channels: Output size passed to ``GCNEncoder``.
        epochs: Number of GAE training epochs.

    Returns:
        None. Returns immediately, leaving ``data`` untouched, when both
        ``remove_pct`` and ``add_pct`` are 0.
    """
    if remove_pct == 0 and add_pct == 0:
        return

    perturbed_path = f'{ROOT}/cache/edge/{name}_{remove_pct}_{add_pct}.pt'
    prediction_path = f'{ROOT}/cache/edge/{name}.pt'

    # Cache level 1: a finished perturbed edge index.
    try:
        cached = _load_pickle(perturbed_path)
    except FileNotFoundError:
        pass
    else:
        print(f'Use cached edge augmentation for dataset {name}')
        _assign_edge_index(data, cached)
        return

    # Cache level 2: stored predictions, only the sampling has to be redone.
    try:
        A_pred, adj_orig = _load_pickle(prediction_path)
    except FileNotFoundError:
        print(
            f'cache/edge/{name}_{remove_pct}_{add_pct}.pt not found! Regenerating it now'
        )
    else:
        A = sample_graph_det(adj_orig, A_pred, remove_pct, add_pct)
        resampled, _ = from_scipy_sparse_matrix(A)
        # Honour the inductive setting here as well, so this branch agrees
        # with the other two.
        _assign_edge_index(data, resampled)
        _dump_pickle(resampled, perturbed_path)
        return

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    if data.setting == 'inductive':
        train_data = Data(x=data.train_x, ori_x=data.ori_x,
                          edge_index=data.train_edge_index, y=data.train_y)
    else:
        train_data = deepcopy(data)

    # Keep the unsplit edges; they define the original adjacency below.
    edge_index = deepcopy(train_data.edge_index)
    train_data = train_test_split_edges(train_data, val_ratio=0.1,
                                        test_ratio=0)

    num_features = train_data.ori_x.shape[1]
    model = GAE(GCNEncoder(num_features, hidden_channels))
    model = model.to(device)
    x = train_data.ori_x.to(device)
    train_pos_edge_index = train_data.train_pos_edge_index.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

    best_val_auc = 0
    best_z = None
    for epoch in range(1, epochs + 1):
        model.train()
        optimizer.zero_grad()
        z = model.encode(x, train_pos_edge_index)
        loss = model.recon_loss(z, train_pos_edge_index)
        loss.backward()
        optimizer.step()

        model.eval()
        with torch.no_grad():
            z = model.encode(x, train_pos_edge_index)
            auc, ap = model.test(z, train_data.val_pos_edge_index,
                                 train_data.val_neg_edge_index)
        print('Val | Epoch: {:03d}, AUC: {:.4f}, AP: {:.4f}'.format(
            epoch, auc, ap))
        if auc > best_val_auc:
            best_val_auc = auc
            best_z = deepcopy(z)

    if best_z is None:
        # No epoch ran or none improved on AUC 0: fall back to the current
        # model's embedding so the prediction step still has an input.
        model.eval()
        with torch.no_grad():
            best_z = model.encode(x, train_pos_edge_index)

    # Build the predictions from the best validation embedding rather than
    # from whatever the final epoch produced.
    A_pred = torch.sigmoid(torch.mm(best_z, best_z.T)).cpu().numpy()
    adj_orig = to_scipy_sparse_matrix(edge_index).asformat('csr')
    adj_pred = sample_graph_det(adj_orig, A_pred, remove_pct, add_pct)
    perturbed, _ = from_scipy_sparse_matrix(adj_pred)
    _assign_edge_index(data, perturbed)

    _dump_pickle((A_pred, adj_orig), prediction_path)
    _dump_pickle(perturbed, perturbed_path)