def train_model_and_save_embeddings(dataset, data, epochs, learning_rate,
                                    device, embeddings_path="large_emb.pt"):
    """Train a VGAE on sampled sub-graphs and save full-graph embeddings.

    After every epoch the embeddings of all nodes are recomputed on the full
    edge set and written to ``embeddings_path`` (overwriting the previous
    epoch's file), and the accumulated epoch loss is printed.

    Args:
        dataset: Object exposing ``num_nodes``; used to size the encoder and
            to sanity-check the edge indices.
        data: Graph object exposing ``x`` and ``edge_index``; it is also
            handed to ``NeighborSampler``.
        epochs: Number of passes over the sampler.
        learning_rate: Learning rate for the Adam optimizer.
        device: Torch device the model and tensors are moved to.
        embeddings_path: Destination passed to ``torch.save``. Defaults to
            ``"large_emb.pt"``, the path that used to be hard-coded.

    Returns:
        The trained model, left in training mode.
    """
    # Define model
    encoder = EmbeddingEncoder(emb_dim=200, out_channels=64,
                               n_nodes=dataset.num_nodes).to(device)
    decoder = CosineSimDecoder().to(device)
    model = VGAE(encoder=encoder, decoder=decoder).to(device)

    node_features = data.x.to(device)
    train_pos_edge_index = data.edge_index.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    # Every edge endpoint must be a valid node id for the embedding lookup.
    assert data.edge_index.max().item() < dataset.num_nodes, \
        "edge_index references a node id >= dataset.num_nodes"

    data_loader = NeighborSampler(data, size=[25, 10], num_hops=2,
                                  batch_size=10000, shuffle=False,
                                  add_self_loops=False)

    model.train()
    for epoch in tqdm(range(epochs)):
        epoch_loss = 0.0
        for data_flow in tqdm(data_loader()):
            optimizer.zero_grad()
            data_flow = data_flow.to(device)
            block = data_flow[0]
            # TODO Avoid computation of all node features!
            embeddings = model.encode(node_features[block.n_id],
                                      block.edge_index)
            loss = model.recon_loss(embeddings, block.edge_index)
            # KL term is scaled by the number of nodes in the sampled block.
            loss = loss + (1 / len(block.n_id)) * model.kl_loss()
            epoch_loss += loss.item()

            # Compute gradients
            loss.backward()
            # Perform optimization step
            optimizer.step()

        # Export embeddings without tracking gradients: the full-graph
        # forward pass is inference only, so building an autograd graph for
        # it just wastes memory. Eval mode is used so the export is not a
        # noisy training-time sample (assuming VGAE follows PyG semantics,
        # where encode() returns the mean outside training -- TODO confirm
        # if VGAE is a local class).
        model.eval()
        with torch.no_grad():
            z = model.encode(node_features, train_pos_edge_index)
        torch.save(z.cpu(), embeddings_path)
        model.train()

        print(f"Loss after epoch {epoch} / {epochs}: {epoch_loss}")
    return model
else: data = load_wiki.load_data() data.edge_index = gutils.to_undirected(data.edge_index) data = GAE.split_edges(GAE, data) num_features = data.x.shape[1] aucs = [] aps = [] for run in range(args.runs): model = VGAE(VGAE_Encoder(num_features)) optimizer = torch.optim.Adam(model.parameters(), lr=0.01) # Training loop for epoch in range(args.epochs): model.train() optimizer.zero_grad() z = model.encode(data.x, data.train_pos_edge_index) loss = model.recon_loss( z, data.train_pos_edge_index) #0.01*model.kl_loss() loss.backward() optimizer.step() # Log validation metrics if epoch % args.val_freq == 0: model.eval() with torch.no_grad(): z = model.encode(data.x, data.train_pos_edge_index) auc, ap = model.test(z, data.val_pos_edge_index, data.val_neg_edge_index) print('Train loss: {:.4f}, Validation AUC-ROC: {:.4f}, '
def _build_training_graph(disease_gene_file, gene_gene_file,
                          disease_disease_file):
    """Read the three tab-separated edge lists into one undirected graph.

    Only disease-gene rows tagged ``'train'`` are used. Every edge is stored
    with weight 1; the weight column of the disease-disease file is read but
    intentionally not used, matching the previous behaviour.
    """
    graph = nx.Graph()
    with open(disease_gene_file, 'r') as f:
        for line in f:
            node1, node2, tag = line.strip().split('\t')
            if tag == 'train':
                # add_edge also inserts missing endpoints (node1, then node2).
                graph.add_edge(node1, node2, weight=1)
    with open(gene_gene_file, 'r') as f:
        for line in f:
            node1, node2 = line.strip().split('\t')
            graph.add_edge(node1, node2, weight=1)
    with open(disease_disease_file, 'r') as f:
        for line in f:
            node1, node2, _weight = line.strip().split('\t')
            graph.add_edge(node1, node2, weight=1)
    return graph


def _write_top_predictions(path, test_diseases, diseases, genes, name_id,
                           pred, top_k=150):
    """Write the ``top_k`` highest-scoring genes for each test disease.

    ``pred`` is indexed as ``pred[disease_id][gene_id]``. Diseases that were
    not part of the training graph get a score of 0 for every gene.
    """
    with open(path, 'w') as f:
        for disease in test_diseases:
            if disease not in diseases:
                sims = dict.fromkeys(genes, 0)
            else:
                # Convert only the needed row to Python floats instead of
                # materialising the whole N x N matrix as nested lists.
                row = pred[name_id[disease]].tolist()
                sims = {gene: row[name_id[gene]] for gene in genes}
            ranked = sorted(sims.items(), key=lambda item: item[1],
                            reverse=True)
            for gene, sim in ranked[:top_k]:
                f.write(disease + '\t' + gene + '\t' + str(sim) + '\n')


def main():
    """Train a VGAE per cross-validation fold and write gene predictions.

    For each fold, a graph is built from the fold's training disease-gene
    edges plus the gene-gene and disease-disease networks, a VGAE is trained
    on it for 4000 epochs, and for every test disease of that fold the 150
    best-scoring genes are written to the fold's prediction file.

    The fold loop previously read ``for counter in [3]``, which indexes the
    three-element file lists out of range and raises ``IndexError``; it now
    iterates over every configured fold.
    """
    model_name = 'VGAE'
    disease_gene_files = [
        'data/OMIM/3-fold-1.txt',
        'data/OMIM/3-fold-2.txt',
        'data/OMIM/3-fold-3.txt',
    ]
    disease_disease_file = 'data/MimMiner/MimMiner.txt'
    gene_gene_file = 'data/HumanNetV2/HumanNet_V2.txt'
    prediction_files = [
        f'data/prediction/{model_name}/prediction-3-fold-1.txt',
        f'data/prediction/{model_name}/prediction-3-fold-2.txt',
        f'data/prediction/{model_name}/prediction-3-fold-3.txt',
    ]

    for counter in range(len(disease_gene_files)):
        g_nx = _build_training_graph(disease_gene_files[counter],
                                     gene_gene_file, disease_disease_file)
        print('read data success')

        # Map node names to consecutive integer ids.
        name_id = dict(zip(g_nx.nodes(), range(g_nx.number_of_nodes())))
        g_nx = nx.relabel_nodes(g_nx, name_id)

        # transform from networkx to pyg data
        g_nx = g_nx.to_directed() if not nx.is_directed(g_nx) else g_nx
        edge_index = torch.tensor(list(g_nx.edges)).t().contiguous()
        data = torch_geometric.data.Data.from_dict(
            {'edge_index': edge_index.view(2, -1)})
        data.num_nodes = g_nx.number_of_nodes()
        # One-hot (identity) node features.
        data.x = torch.from_numpy(np.eye(data.num_nodes)).float()
        data.train_mask = data.val_mask = data.test_mask = data.y = None
        print(
            f'Graph information:\nNode:{data.num_nodes}\nEdge:{data.num_edges}\nFeature:{data.num_node_features}'
        )

        channels = 128
        dev = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        model = VGAE(Encoder(data.num_node_features, channels)).to(dev)
        x, train_pos_edge_index = data.x.to(dev), data.edge_index.to(dev)
        optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

        for epoch in range(4000):
            model.train()
            optimizer.zero_grad()
            z = model.encode(x, train_pos_edge_index)
            loss = model.recon_loss(
                z, train_pos_edge_index) + (1 / data.num_nodes) * model.kl_loss()
            loss.backward()
            optimizer.step()
            nowTime = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
            print(f'{nowTime}\tepoch:{epoch}\tloss:{loss}')

        # Inference: eval mode and no autograd, so predictions are not built
        # from a noisy training-time sample (assuming VGAE follows PyG
        # semantics, where encode() returns the mean outside training).
        model.eval()
        with torch.no_grad():
            z = model.encode(x, train_pos_edge_index)
            pred = model.decoder.forward_all(z).cpu().detach().numpy()

        # Node names carry their type as a prefix.
        genes = {key for key in name_id if key.startswith('g_')}
        diseases = {key for key in name_id if key.startswith('d_')}

        test_diseases = set()
        with open(disease_gene_files[counter], 'r') as f:
            for line in f:
                disease, gene, tag = line.strip().split('\t')
                if tag == 'test':
                    test_diseases.add(disease)

        _write_top_predictions(prediction_files[counter], test_diseases,
                               diseases, genes, name_id, pred)
def run_VGAE(input_data, output_dir, epochs=1000, lr=0.01, weight_decay=0.0005):
    """Train a VGAE for link prediction, logging and plotting the results.

    Each epoch performs one optimisation step on the training edges, then
    evaluates AUC / AP and the loss on the held-out test edges. Progress is
    printed roughly ten times over the run; the final metrics are written
    with ``write_log`` and the curves are drawn with ``plot_linkpred``.

    Args:
        input_data: Graph data object; it is cloned, so the caller's object
            is not modified.
        output_dir: Directory for the log and the plots (created through
            ``makepath``).
        epochs: Number of training epochs.
        lr: Learning rate for Adam. Previously ignored in favour of a
            hard-coded 0.01 (identical to the default).
        weight_decay: L2 penalty for Adam. Previously reported in the
            output but never handed to the optimizer.

    Returns:
        None.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print('Device: '.ljust(32), device)
    print('Model Name: '.ljust(32), 'VGAE')
    print('Model params:{:19} lr: {} weight_decay: {}'.format(
        '', lr, weight_decay))
    print('Total number of epochs to run: '.ljust(32), epochs)
    print('*' * 70)

    data = input_data.clone().to(device)
    model = VGAE(VGAEncoder(data.num_features,
                            data.num_classes.item())).to(device)
    data = model.split_edges(data)
    x = data.x.to(device)
    train_pos_edge_index = data.train_pos_edge_index.to(device)
    data.train_idx = data.test_idx = data.y = None

    # Use the hyper-parameters the caller asked for (and that are logged).
    optimizer = torch.optim.Adam(model.parameters(), lr=lr,
                                 weight_decay=weight_decay)

    train_losses = []
    test_losses = []
    aucs = []
    aps = []

    # Loop-invariant: output location and log/figure base name.
    makepath(output_dir)
    figname = os.path.join(
        output_dir,
        "_".join((VGAE.__name__, str(lr), str(weight_decay), str(epochs))))

    # Guard against a zero modulus when epochs < 10.
    log_every = max(1, epochs // 10)

    for epoch in range(1, epochs + 1):
        # Re-enter training mode every epoch: the evaluation below switches
        # the model to eval mode, which would otherwise persist.
        model.train()
        optimizer.zero_grad()
        z = model.encode(x, train_pos_edge_index)
        train_loss = model.recon_loss(
            z, train_pos_edge_index) + (1 / data.num_nodes) * model.kl_loss()
        train_losses.append(train_loss.item())
        train_loss.backward()
        optimizer.step()

        model.eval()
        with torch.no_grad():
            z = model.encode(x, train_pos_edge_index)
            auc, ap = model.test(z, data.test_pos_edge_index,
                                 data.test_neg_edge_index)
            test_loss = model.recon_loss(
                z, data.test_pos_edge_index) + (1 / data.num_nodes) * model.kl_loss()
        test_losses.append(test_loss.item())
        aucs.append(auc)
        aps.append(ap)

        if epoch % log_every == 0:
            print(
                'Epoch: {} Train loss: {} Test loss: {} AUC: {} AP: {:.4f}'
                .format(epoch, train_loss, test_loss, auc, ap))

        if epoch == epochs:
            print(
                '-' * 65,
                '\nFinal epoch: {} Train loss: {} Test loss: {} AUC: {} AP: {}'
                .format(epoch, train_loss, test_loss, auc, ap))
            log = 'Final epoch: {} Train loss: {} Test loss: {} AUC: {} AP: {}'.format(
                epoch, train_loss, test_loss, auc, ap)
            write_log(log, figname)
            print('-' * 65)

    plot_linkpred(train_losses, test_losses, aucs, aps, output_dir, epochs,
                  figname)
    return