Exemplo n.º 1
0
def train_model_and_save_embeddings(dataset, data, epochs, learning_rate,
                                    device, emb_path="large_emb.pt"):
    """Train a VGAE on neighbour-sampled mini-batches and save node embeddings.

    After every epoch the full-graph node embeddings are recomputed and
    written to ``emb_path`` with ``torch.save`` (overwriting the previous
    epoch's file).

    Args:
        dataset: Object exposing ``num_nodes``; sizes the embedding encoder
            and bounds the node ids allowed in ``data.edge_index``.
        data: Graph object with ``x`` and ``edge_index``; also passed to
            ``NeighborSampler`` to produce the mini-batches.
        epochs: Number of passes over the sampler.
        learning_rate: Adam learning rate.
        device: Torch device the model and tensors are moved to.
        emb_path: Destination file for the saved embeddings. Defaults to the
            previously hard-coded ``"large_emb.pt"``.

    Returns:
        The trained ``VGAE`` model.
    """
    # Define model: embedding encoder + cosine-similarity decoder.
    encoder = EmbeddingEncoder(emb_dim=200,
                               out_channels=64,
                               n_nodes=dataset.num_nodes).to(device)
    decoder = CosineSimDecoder().to(device)
    model = VGAE(encoder=encoder, decoder=decoder).to(device)

    node_features = data.x.to(device)
    full_edge_index = data.edge_index.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    # Every edge endpoint must be a valid row of the node-embedding table.
    assert data.edge_index.max().item() < dataset.num_nodes

    data_loader = NeighborSampler(data,
                                  size=[25, 10],
                                  num_hops=2,
                                  batch_size=10000,
                                  shuffle=False,
                                  add_self_loops=False)

    for epoch in tqdm(range(epochs)):
        # Re-enter training mode each epoch: the embedding export at the end
        # of the previous epoch switched the model to eval mode.
        model.train()
        epoch_loss = 0.0
        for data_flow in tqdm(data_loader()):
            optimizer.zero_grad()

            data_flow = data_flow.to(device)
            block = data_flow[0]
            embeddings = model.encode(
                node_features[block.n_id], block.edge_index
            )  # TODO Avoid computation of all node features!

            # KL term is scaled by the number of nodes in this batch.
            loss = model.recon_loss(embeddings, block.edge_index)
            loss = loss + (1 / len(block.n_id)) * model.kl_loss()

            epoch_loss += loss.item()

            loss.backward()
            optimizer.step()

        # Export full-graph embeddings. Without no_grad this built an autograd
        # graph over the entire graph every epoch, only to save the tensor.
        # eval() makes VGAE.encode return the posterior mean instead of a
        # random sample, so the saved embeddings are deterministic.
        model.eval()
        with torch.no_grad():
            z = model.encode(node_features, full_edge_index)
        torch.save(z.cpu(), emb_path)

        print(f"Loss after epoch {epoch} / {epochs}: {epoch_loss}")

    return model
Exemplo n.º 2
0
    else:
        data = load_wiki.load_data()

    data.edge_index = gutils.to_undirected(data.edge_index)
    data = GAE.split_edges(GAE, data)

    num_features = data.x.shape[1]
    aucs = []
    aps = []
    for run in range(args.runs):
        model = VGAE(VGAE_Encoder(num_features))
        optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

        # Training loop
        for epoch in range(args.epochs):
            model.train()
            optimizer.zero_grad()
            z = model.encode(data.x, data.train_pos_edge_index)
            loss = model.recon_loss(
                z, data.train_pos_edge_index)  #0.01*model.kl_loss()
            loss.backward()
            optimizer.step()

            # Log validation metrics
            if epoch % args.val_freq == 0:
                model.eval()
                with torch.no_grad():
                    z = model.encode(data.x, data.train_pos_edge_index)
                    auc, ap = model.test(z, data.val_pos_edge_index,
                                         data.val_neg_edge_index)
                print('Train loss: {:.4f}, Validation AUC-ROC: {:.4f}, '
Exemplo n.º 3
0
def _iter_tsv_rows(path):
    """Yield the tab-separated fields of each non-blank line in *path*."""
    with open(path, 'r') as f:
        for line in f:
            stripped = line.strip()
            if stripped:  # tolerate blank lines instead of failing to unpack
                yield stripped.split('\t')


def main(folds=(3,)):
    """Train a VGAE on a disease/gene graph and write top-ranked gene predictions.

    For each requested fold this:
      1. builds an undirected graph from the fold's ``train`` disease-gene
         links, the HumanNet gene-gene links and the MimMiner
         disease-disease links (all with edge weight 1);
      2. trains a VGAE for 4000 epochs on identity (one-hot) node features;
      3. for every disease appearing in the fold's ``test`` links, writes its
         150 highest-scoring genes (``disease<TAB>gene<TAB>score``) to the
         fold's prediction file. Diseases not in the graph get score 0.

    Args:
        folds: 1-based fold numbers to run, matching the ``3-fold-<k>`` file
            names. Defaults to ``(3,)``. The original ``for counter in [3]``
            indexed the three-element file lists out of range and always
            raised ``IndexError``.

    Raises:
        ValueError: if a fold number is outside ``1..len(disease_gene_files)``.
    """
    import heapq
    import os

    model_name = 'VGAE'
    disease_gene_files = [
        'data/OMIM/3-fold-1.txt', 'data/OMIM/3-fold-2.txt',
        'data/OMIM/3-fold-3.txt'
    ]
    disease_disease_file = 'data/MimMiner/MimMiner.txt'
    gene_gene_file = 'data/HumanNetV2/HumanNet_V2.txt'
    prediction_files = [
        f'data/prediction/{model_name}/prediction-3-fold-1.txt',
        f'data/prediction/{model_name}/prediction-3-fold-2.txt',
        f'data/prediction/{model_name}/prediction-3-fold-3.txt'
    ]

    for fold in folds:
        counter = fold - 1  # fold numbers are 1-based; the lists are 0-based
        if not 0 <= counter < len(disease_gene_files):
            raise ValueError(
                f'fold must be between 1 and {len(disease_gene_files)}, '
                f'got {fold}')

        # add_edge() inserts missing endpoints in (node1, node2) order, so the
        # node order (and thus name_id below) matches explicit add_node calls.
        g_nx = nx.Graph()
        for node1, node2, tag in _iter_tsv_rows(disease_gene_files[counter]):
            if tag == 'train':  # test links must not leak into training
                g_nx.add_edge(node1, node2, weight=1)
        for node1, node2 in _iter_tsv_rows(gene_gene_file):
            g_nx.add_edge(node1, node2, weight=1)
        for node1, node2, _similarity in _iter_tsv_rows(disease_disease_file):
            # The third column is read but ignored: every edge gets weight 1.
            g_nx.add_edge(node1, node2, weight=1)
        print('read data success')

        name_id = dict(zip(g_nx.nodes(), range(g_nx.number_of_nodes())))
        g_nx = nx.relabel_nodes(g_nx, name_id)

        # transform from networkx to pyg data
        g_nx = g_nx.to_directed() if not nx.is_directed(g_nx) else g_nx
        edge_index = torch.tensor(list(g_nx.edges)).t().contiguous()
        data = torch_geometric.data.Data.from_dict(
            {'edge_index': edge_index.view(2, -1)})
        data.num_nodes = g_nx.number_of_nodes()
        data.x = torch.from_numpy(np.eye(data.num_nodes)).float()
        data.train_mask = data.val_mask = data.test_mask = data.y = None
        print(
            f'Graph information:\nNode:{data.num_nodes}\nEdge:{data.num_edges}\nFeature:{data.num_node_features}'
        )

        channels = 128
        dev = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        model = VGAE(Encoder(data.num_node_features, channels)).to(dev)
        x, train_pos_edge_index = data.x.to(dev), data.edge_index.to(dev)
        optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

        for epoch in range(4000):
            model.train()
            optimizer.zero_grad()
            z = model.encode(x, train_pos_edge_index)
            loss = model.recon_loss(
                z,
                train_pos_edge_index) + (1 / data.num_nodes) * model.kl_loss()
            loss.backward()
            optimizer.step()
            now_time = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
            print(f'{now_time}\tepoch:{epoch}\tloss:{loss.item()}')

        # Score all node pairs. eval() makes VGAE.encode return the posterior
        # mean rather than a random sample; no_grad avoids building an
        # autograd graph over the dense N x N score matrix.
        model.eval()
        with torch.no_grad():
            z = model.encode(x, train_pos_edge_index)
            pred = model.decoder.forward_all(z).cpu().numpy()

        # Sorted so that tie-breaking among equal scores is reproducible.
        genes = sorted(name for name in name_id if name.startswith('g_'))
        diseases = {name for name in name_id if name.startswith('d_')}

        test_diseases = {
            disease
            for disease, _gene, tag in _iter_tsv_rows(
                disease_gene_files[counter])
            if tag == 'test'
        }

        out_path = prediction_files[counter]
        out_dir = os.path.dirname(out_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        with open(out_path, 'w') as f:
            for disease in sorted(test_diseases):
                if disease in diseases:
                    row = pred[name_id[disease]]
                    # float() gives the same value/str as the old .tolist().
                    sims = ((gene, float(row[name_id[gene]])) for gene in genes)
                else:
                    # Disease absent from the training graph: no evidence.
                    sims = ((gene, 0) for gene in genes)
                # Same result as a stable descending sort truncated to 150.
                top = heapq.nlargest(150, sims, key=lambda item: item[1])
                f.writelines(f'{disease}\t{gene}\t{sim}\n' for gene, sim in top)
Exemplo n.º 4
0
def run_VGAE(input_data,
             output_dir,
             epochs=1000,
             lr=0.01,
             weight_decay=0.0005):
    """Train a VGAE for link prediction and record per-epoch metrics.

    The graph is split into train/test edges with ``model.split_edges``.
    Each epoch performs one optimisation step on the training edges, then
    evaluates AUC/AP and the reconstruction + KL loss on the held-out test
    edges. Metrics are printed roughly ten times per run, passed to
    ``write_log`` every epoch, and plotted with ``plot_linkpred`` at the end.

    Args:
        input_data: Graph object providing ``num_features`` and
            ``num_classes``; it is cloned, and only the clone is modified.
        output_dir: Directory for the log/plot files (created via
            ``makepath``).
        epochs: Number of training epochs.
        lr: Adam learning rate (previously ignored; 0.01 was hard-coded).
        weight_decay: Adam L2 weight decay (previously ignored).
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print('Device: '.ljust(32), device)
    print('Model Name: '.ljust(32), 'VGAE')
    print('Model params:{:19} lr: {}     weight_decay: {}'.format(
        '', lr, weight_decay))
    print('Total number of epochs to run: '.ljust(32), epochs)
    print('*' * 70)

    data = input_data.clone().to(device)
    model = VGAE(VGAEncoder(data.num_features,
                            data.num_classes.item())).to(device)
    data = model.split_edges(data)
    x = data.x.to(device)
    train_pos_edge_index = data.train_pos_edge_index.to(device)
    # edge_attr used to be moved to the device here but was never used, and
    # it crashed on graphs without edge attributes, so it is no longer read.
    data.train_idx = data.test_idx = data.y = None
    optimizer = torch.optim.Adam(model.parameters(), lr=lr,
                                 weight_decay=weight_decay)

    # Loop-invariant output location. Defining it here also keeps figname
    # bound for plot_linkpred when epochs == 0.
    makepath(output_dir)
    figname = os.path.join(
        output_dir, "_".join(
            (VGAE.__name__, str(lr), str(weight_decay), str(epochs))))
    # Report about ten times per run; max(1, ...) avoids modulo-by-zero
    # when epochs < 10.
    report_every = max(1, epochs // 10)

    train_losses = []
    test_losses = []
    aucs = []
    aps = []
    for epoch in range(1, epochs + 1):
        # train() must be re-set every epoch because the evaluation below
        # leaves the model in eval mode.
        model.train()
        optimizer.zero_grad()
        z = model.encode(x, train_pos_edge_index)
        loss = model.recon_loss(
            z, train_pos_edge_index) + (1 / data.num_nodes) * model.kl_loss()
        loss.backward()
        optimizer.step()
        train_loss = loss.item()
        train_losses.append(train_loss)

        # Evaluate on the held-out edges without tracking gradients.
        model.eval()
        with torch.no_grad():
            z = model.encode(x, train_pos_edge_index)
            auc, ap = model.test(z, data.test_pos_edge_index,
                                 data.test_neg_edge_index)
            test_loss = (model.recon_loss(z, data.test_pos_edge_index) +
                         (1 / data.num_nodes) * model.kl_loss()).item()
        test_losses.append(test_loss)
        aucs.append(auc)
        aps.append(ap)

        if epoch % report_every == 0:
            print(
                'Epoch: {}        Train loss: {}    Test loss: {}    AUC: {}    AP: {:.4f}'
                .format(epoch, train_loss, test_loss, auc, ap))
        if epoch == epochs:
            print(
                '-' * 65,
                '\nFinal epoch: {}  Train loss: {}    Test loss: {}    AUC: {}    AP: {}'
                .format(epoch, train_loss, test_loss, auc, ap))
        # NOTE(review): this runs every epoch despite the "Final epoch" label.
        # Whether entries accumulate or overwrite depends on write_log; confirm.
        log = 'Final epoch: {}    Train loss: {}    Test loss: {}    AUC: {}    AP: {}'.format(
            epoch, train_loss, test_loss, auc, ap)
        write_log(log, figname)
    print('-' * 65)

    plot_linkpred(train_losses, test_losses, aucs, aps, output_dir, epochs,
                  figname)
    return