Example #1
0
def test_vgae():
    """Smoke-test ``VGAE.encode`` and ``VGAE.kl_loss`` in train and eval mode."""
    # The stand-in encoder simply returns its input twice as a 2-tuple.
    vgae = VGAE(encoder=lambda inputs: (inputs, inputs))
    features = torch.Tensor([[1, -1], [1, 2], [2, 1]])

    # After encoding, the KL term must be strictly positive.
    vgae.encode(features)
    kl_value = vgae.kl_loss().item()
    assert kl_value > 0

    # Encoding must also run once the model is switched to evaluation mode.
    vgae.eval()
    vgae.encode(features)
Example #2
0
def train_model_and_save_embeddings(dataset, data, epochs, learning_rate,
                                    device):
    """Train a VGAE on sampled neighbourhoods and export full-graph embeddings.

    The model is built from an ``EmbeddingEncoder`` and a ``CosineSimDecoder``
    and optimised with Adam on batches produced by ``NeighborSampler``. After
    every epoch the embeddings of the whole graph are written to
    ``large_emb.pt`` (overwriting the previous file).

    Args:
        dataset: Object exposing ``num_nodes``.
        data: Graph object exposing ``x`` and ``edge_index``.
        epochs: Number of passes over the sampler.
        learning_rate: Learning rate handed to Adam.
        device: Device the model and tensors are moved to.

    Returns:
        The trained ``VGAE`` model.
    """
    # Define Model
    encoder = EmbeddingEncoder(emb_dim=200,
                               out_channels=64,
                               n_nodes=dataset.num_nodes).to(device)
    decoder = CosineSimDecoder().to(device)
    model = VGAE(encoder=encoder, decoder=decoder).to(device)

    node_features, train_pos_edge_index = data.x.to(
        device), data.edge_index.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    # Every edge endpoint must be a valid node id.
    assert data.edge_index.max().item() < dataset.num_nodes

    data_loader = NeighborSampler(data,
                                  size=[25, 10],
                                  num_hops=2,
                                  batch_size=10000,
                                  shuffle=False,
                                  add_self_loops=False)

    for epoch in tqdm(range(epochs)):
        # Re-enter training mode each epoch because the export step below
        # switches the model to evaluation mode.
        model.train()
        epoch_loss = 0.0
        for data_flow in tqdm(data_loader()):
            optimizer.zero_grad()

            data_flow = data_flow.to(device)
            block = data_flow[0]
            embeddings = model.encode(
                node_features[block.n_id], block.edge_index
            )  # TODO Avoid computation of all node features!

            # Reconstruction loss plus the KL term scaled by the batch's
            # node count.
            loss = model.recon_loss(embeddings, block.edge_index)
            loss = loss + (1 / len(block.n_id)) * model.kl_loss()

            epoch_loss += loss.item()

            # Compute gradients
            loss.backward()
            # Perform optimization step
            optimizer.step()

        # Export embeddings for the whole graph. Evaluation mode and
        # ``no_grad`` keep the export free of sampling noise and avoid
        # building an autograd graph that is never back-propagated.
        model.eval()
        with torch.no_grad():
            z = model.encode(node_features, train_pos_edge_index)

        torch.save(z.cpu(), "large_emb.pt")

        print(f"Loss after epoch {epoch} / {epochs}: {epoch_loss}")

    return model
Example #3
0
def test_init():
    """Check that each autoencoder variant can be built from linear layers."""
    enc = torch.nn.Linear(16, 32)
    dec = torch.nn.Linear(32, 16)
    disc = torch.nn.Linear(32, 1)

    # Plain and variational autoencoders take (encoder, decoder).
    GAE(enc, dec)
    VGAE(enc, dec)

    # The adversarial variants additionally take a discriminator.
    ARGA(enc, disc, dec)
    ARGVA(enc, disc, dec)
        super().__init__()
        self.conv_mu = GCNConv(in_channels, out_channels)
        self.conv_logstd = GCNConv(in_channels, out_channels)

    def forward(self, x, edge_index):
        """Return ``(conv_mu(x, edge_index), conv_logstd(x, edge_index))``."""
        mu = self.conv_mu(x, edge_index)
        logstd = self.conv_logstd(x, edge_index)
        return mu, logstd


# Input width follows the dataset; the latent width is fixed at 16.
in_channels = dataset.num_features
out_channels = 16

# Pick the autoencoder from the ``variational`` and ``linear`` switches.
if args.variational:
    if args.linear:
        model = VGAE(VariationalLinearEncoder(in_channels, out_channels))
    else:
        model = VGAE(VariationalGCNEncoder(in_channels, out_channels))
else:
    if args.linear:
        model = GAE(LinearEncoder(in_channels, out_channels))
    else:
        model = GAE(GCNEncoder(in_channels, out_channels))

model = model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)


def train():
    model.train()
    optimizer.zero_grad()
    z = model.encode(train_data.x, train_data.edge_index)
    loss = model.recon_loss(z, train_data.pos_edge_label_index)
    if args.variational:
        loss = loss + (1 / train_data.num_nodes) * model.kl_loss()
    loss.backward()
Example #5
0
    if args.dataset in ['cora', 'citeseer', 'pubmed']:
        path = os.path.join(os.path.dirname(os.path.realpath(__file__)), '.',
                            'data', args.dataset)
        data = Planetoid(path, args.dataset)[0]
    else:
        data = load_wiki.load_data()

    data.edge_index = gutils.to_undirected(data.edge_index)
    data = GAE.split_edges(GAE, data)

    num_features = data.x.shape[1]
    aucs = []
    aps = []
    for run in range(args.runs):
        model = VGAE(VGAE_Encoder(num_features))
        optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

        # Training loop
        for epoch in range(args.epochs):
            model.train()
            optimizer.zero_grad()
            z = model.encode(data.x, data.train_pos_edge_index)
            loss = model.recon_loss(
                z, data.train_pos_edge_index)  #0.01*model.kl_loss()
            loss.backward()
            optimizer.step()

            # Log validation metrics
            if epoch % args.val_freq == 0:
                model.eval()
Example #6
0
def main():
    """Train one VGAE per fold and write the top gene predictions per disease.

    For every fold the function

    1. builds an undirected graph from the fold's ``train`` disease-gene
       rows, the gene-gene network and the disease-disease network (all
       tab-separated text files, every edge stored with weight 1),
    2. trains a VGAE for 4000 epochs on identity node features,
    3. decodes the dense score matrix and, for every disease that has a
       ``test`` row in the fold file, writes its 150 highest-scoring genes
       to the fold's prediction file.
    """
    model_name = 'VGAE'
    disease_gene_files = [
        'data/OMIM/3-fold-1.txt', 'data/OMIM/3-fold-2.txt',
        'data/OMIM/3-fold-3.txt'
    ]
    disease_disease_file = 'data/MimMiner/MimMiner.txt'
    gene_gene_file = 'data/HumanNetV2/HumanNet_V2.txt'
    prediction_files = [
        f'data/prediction/{model_name}/prediction-3-fold-1.txt',
        f'data/prediction/{model_name}/prediction-3-fold-2.txt',
        f'data/prediction/{model_name}/prediction-3-fold-3.txt'
    ]

    # Walk over all folds. The previous hard-coded index 3 was out of range
    # for these three-element lists and raised IndexError immediately.
    for fold_file, prediction_file in zip(disease_gene_files,
                                          prediction_files):
        g_nx = nx.Graph()
        with open(fold_file, 'r') as f:
            for line in f:
                node1, node2, tag = line.strip().split('\t')
                # Only training associations may enter the graph.
                if tag == 'train':
                    g_nx.add_node(node1)
                    g_nx.add_node(node2)
                    g_nx.add_edge(node1, node2, weight=1)
        with open(gene_gene_file, 'r') as f:
            for line in f:
                node1, node2 = line.strip().split('\t')
                g_nx.add_node(node1)
                g_nx.add_node(node2)
                g_nx.add_edge(node1, node2, weight=1)
        with open(disease_disease_file, 'r') as f:
            for line in f:
                # The third column is parsed but every edge still gets
                # weight 1, matching the other two sources.
                node1, node2, _weight = line.strip().split('\t')
                g_nx.add_node(node1)
                g_nx.add_node(node2)
                g_nx.add_edge(node1, node2, weight=1)
        print('read data success')

        # Map node names to consecutive integer ids in insertion order.
        name_id = dict(zip(g_nx.nodes(), range(g_nx.number_of_nodes())))
        g_nx = nx.relabel_nodes(g_nx, name_id)

        # transform from networkx to pyg data
        g_nx = g_nx.to_directed() if not nx.is_directed(g_nx) else g_nx
        edge_index = torch.tensor(list(g_nx.edges)).t().contiguous()
        data = {}
        data['edge_index'] = edge_index.view(2, -1)
        data = torch_geometric.data.Data.from_dict(data)
        data.num_nodes = g_nx.number_of_nodes()
        # One-hot (identity) features: one column per node.
        data.x = torch.from_numpy(np.eye(data.num_nodes)).float()
        data.train_mask = data.val_mask = data.test_mask = data.y = None
        print(
            f'Graph information:\nNode:{data.num_nodes}\nEdge:{data.num_edges}\nFeature:{data.num_node_features}'
        )

        channels = 128
        dev = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        model = VGAE(Encoder(data.num_node_features, channels)).to(dev)
        x, train_pos_edge_index = data.x.to(dev), data.edge_index.to(dev)
        optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

        for epoch in range(4000):
            model.train()
            optimizer.zero_grad()
            z = model.encode(x, train_pos_edge_index)
            loss = model.recon_loss(
                z,
                train_pos_edge_index) + (1 / data.num_nodes) * model.kl_loss()
            loss.backward()
            optimizer.step()
            nowTime = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
            print(f'{nowTime}\tepoch:{epoch}\tloss:{loss}')

        # Score all node pairs in evaluation mode and without autograd so
        # the predictions are not perturbed by sampling noise and no
        # gradient graph is kept for the dense score matrix.
        model.eval()
        with torch.no_grad():
            z = model.encode(x, train_pos_edge_index)
            pred = model.decoder.forward_all(z).cpu().detach().numpy().tolist()

        # Split the known node names by their prefix.
        diseases = set()
        genes = set()
        for key in name_id:
            if key.startswith('g_'):
                genes.add(key)
            elif key.startswith('d_'):
                diseases.add(key)

        test_diseases = set()
        with open(fold_file, 'r') as f:
            for line in f:
                disease, gene, tag = line.strip().split('\t')
                if tag == 'test':
                    test_diseases.add(disease)

        with open(prediction_file, 'w') as f:
            for disease in test_diseases:
                if disease not in diseases:
                    # Disease absent from the graph: every gene scores 0.
                    sims = {gene: 0 for gene in genes}
                else:
                    row = pred[name_id[disease]]
                    sims = {gene: row[name_id[gene]] for gene in genes}
                sorted_sims = sorted(sims.items(),
                                     key=lambda item: item[1],
                                     reverse=True)
                # Keep only the 150 best-ranked genes per disease.
                for gene, sim in sorted_sims[:150]:
                    f.write(disease + '\t' + gene + '\t' + str(sim) + '\n')
Example #7
0
def run_experiment(args):
    """
    Run one link-prediction experiment for the given arguments.

    Builds a GAE or VGAE from ``args.model`` / ``args.decoder``, splits the
    edges of the loaded dataset, trains with early stopping, reloads
    ``checkpoint.pt``, reports test AUC/AP and finally evaluates the
    reconstructed graph either densely (``not args.lsh``) or by comparing a
    dense adjacency with an ``LSHDecoder`` adjacency (``args.lsh``).

    :param args: namespace providing at least the attributes read below
        (``dataset``, ``model``, ``latent_dim``, ``decoder``, ``epochs``,
        ``lsh``, ``load_model``, ``save_embeddings``, ...)
    :return: a dict of metrics when ``args.lsh`` is set; the non-LSH branch
        has no ``return`` statement and therefore yields ``None``
    """
    dataset, data = load_data(args.dataset)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Define Model
    encoder = create_encoder(args.model, dataset.num_features,
                             args.latent_dim).to(device)
    decoder = create_decoder(args.decoder).to(device)

    # Any value other than 'GAE' selects the variational model.
    if args.model == 'GAE':
        model = GAE(encoder=encoder, decoder=decoder).to(device)
    else:
        model = VGAE(encoder=encoder, decoder=decoder).to(device)

    # Split edges of a torch_geometric.data.Data object into pos negative train/val/test edges
    # default ratios of positive edges: val_ratio=0.05, test_ratio=0.1
    print("Data.edge_index.size", data.edge_index.size(1))
    data = model.split_edges(data)
    node_features, train_pos_edge_index = data.x.to(
        device), data.train_pos_edge_index.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

    def train_epoch():
        """
        Run a single full-graph optimisation step.

        :return: log - dict with the training loss under the key 'loss'
        """
        # Todo: Add logging of results

        model.train()
        optimizer.zero_grad()
        # Compute latent embedding Z
        latent_embeddings = model.encode(node_features, train_pos_edge_index)

        # Calculate the reconstruction loss and, for VGAE only, add the
        # KL term scaled by the number of nodes
        loss = model.recon_loss(latent_embeddings, train_pos_edge_index)
        if args.model in ['VGAE']:
            loss = loss + (1 / data.num_nodes) * model.kl_loss()

        # Compute gradients
        loss.backward()
        # Perform optimization step
        optimizer.step()

        # print("Train-Epoch: {} Loss: {}".format(epoch, loss))

        # ToDo: Add logging via Tensorboard
        # NOTE(review): the loss is stored as a tensor, not a Python float.
        log = {'loss': loss}

        return log

    def test(pos_edge_index, neg_edge_index):
        """Encode the training graph in eval mode and score the given edges."""
        model.eval()
        with torch.no_grad():
            # compute latent var
            z = model.encode(node_features, train_pos_edge_index)

        # model.test return - AUC, AP
        return model.test(z, pos_edge_index, neg_edge_index)

    def test_naive_graph(z, sample_size=1000):
        """
        Decode a dense adjacency from ``z`` and report precision/recall.

        When ``args.sample_dense_evaluation`` is set only a random sample of
        ``sample_size`` embeddings is decoded; otherwise all of ``z`` is.

        :return: precision, recall
        """

        if args.sample_dense_evaluation:
            graph_type = "sampled"
            z_sample, index_mapping = sample_graph(z, sample_size)
            t = time.time()
            # Sigmoid is applied only for the 'dot' decoder.
            adjacency = model.decoder.forward_all(
                z_sample, sigmoid=(args.decoder == 'dot'))
        else:
            graph_type = "full"
            t = time.time()
            adjacency = model.decoder.forward_all(
                z, sigmoid=(args.decoder == 'dot'))

        print(f"Computing {graph_type} graph took {time.time() - t} seconds.")
        print(
            f"Adjacency matrix takes {adjacency.element_size() * adjacency.nelement() / 10 ** 6} MB of memory."
        )

        # Derive the threshold once and cache it on ``args`` for later calls.
        if args.min_sim_absolute_value is None:
            args.min_sim_absolute_value, _ = sample_percentile(
                args.min_sim,
                adjacency,
                dist_measure=args.decoder,
                sample_size=sample_size)

        if args.sample_dense_evaluation:
            precision, recall = sampled_dense_precision_recall(
                data, adjacency, index_mapping, args.min_sim_absolute_value)
        else:
            precision, recall = dense_precision_recall(
                data, adjacency, args.min_sim_absolute_value)

        print("Predicted {} adjacency matrix has precision {} and recall {}!".
              format(graph_type, precision, recall))

        return precision, recall

    def sample_graph(z, sample_size):
        """Draw up to ``sample_size`` rows of ``z`` without replacement."""
        N, D = z.shape

        # Never ask for more rows than ``z`` has.
        sample_size = min(sample_size, N)
        sample_ix = np.random.choice(np.arange(N),
                                     size=sample_size,
                                     replace=False)

        # Returns the sampled embeddings, and a mapping from their indices to the originals
        return z[sample_ix], {i: sample_ix[i] for i in np.arange(sample_size)}

    def test_compare_lsh_naive_graphs(z, assure_correctness=True):
        """
        Build a dense and an LSH adjacency from ``z`` and compare them.

        :param z: latent embeddings to decode
        :param assure_correctness: forwarded unchanged to ``LSHDecoder``
        :return: naive precision, recall, time and size, followed by LSH
            precision, recall, time and size, followed by the precision and
            recall of the LSH matrix w.r.t. the dense one
        """
        # Naive Adjacency-Matrix (Non-LSH-Version)
        t = time.time()
        # Sigmoid is applied only for the 'dot' decoder; for every other
        # decoder the raw scores are kept so thresholds compare directly
        # with LSH
        naive_adjacency = model.decoder.forward_all(
            z, sigmoid=(args.decoder == 'dot'))
        naive_time = time.time() - t
        naive_size = naive_adjacency.element_size() * naive_adjacency.nelement(
        ) / 10**6

        # NOTE(review): here the percentile is sampled from ``z`` whereas
        # test_naive_graph samples it from the adjacency - confirm intended.
        if args.min_sim_absolute_value is None:
            args.min_sim_absolute_value, _ = sample_percentile(
                args.min_sim, z, dist_measure=args.decoder)

        print(
            "______________________________Naive Graph Computation KPI____________________________________________"
        )
        print(f"Computing naive graph took {naive_time} seconds.")
        print(f"Naive adjacency matrix takes {naive_size} MB of memory.")

        # LSH-Adjacency-Matrix:
        t = time.time()
        lsh_adjacency = LSHDecoder(bands=args.lsh_bands,
                                   rows=args.lsh_rows,
                                   verbose=True,
                                   assure_correctness=assure_correctness,
                                   sim_thresh=args.min_sim_absolute_value)(z)
        lsh_time = time.time() - t
        # Size is based on the number of stored entries (``_nnz``).
        lsh_size = lsh_adjacency.element_size() * lsh_adjacency._nnz() / 10**6

        print(
            "__________________________________LSH Graph Computation KPI__________________________________________"
        )
        print(f"Computing LSH graph took {lsh_time} seconds.")
        print(f"Sparse adjacency matrix takes {lsh_size} MB of memory.")

        print(
            "________________________________________Precision-Recall_____________________________________________"
        )
        # 1) Evaluation: Both Adjacency matrices against ground truth graph
        naive_precision, naive_recall = dense_precision_recall(
            data, naive_adjacency, args.min_sim_absolute_value)

        lsh_precision, lsh_recall = sparse_precision_recall(
            data, lsh_adjacency)

        print(
            f"Naive-Precision {naive_precision}; Naive-Recall {naive_recall}")
        print(f"LSH-Precision {lsh_precision}; LSH-Recall {lsh_recall}")

        print(
            "_____________________________Comparison Sparse vs Dense______________________________________________"
        )
        # 2) Evaluation: Compare both adjacency matrices against each other
        compare_precision, compare_recall = sparse_v_dense_precision_recall(
            naive_adjacency, lsh_adjacency, args.min_sim_absolute_value)
        print(
            f"LSH sparse matrix has {compare_precision} precision and {compare_recall} recall w.r.t. the naively generated dense matrix!"
        )

        return naive_precision, naive_recall, naive_time, naive_size, lsh_precision, lsh_recall, lsh_time, lsh_size, compare_precision, compare_recall

    # Training routine
    early_stopping = EarlyStopping(args.use_early_stopping,
                                   patience=args.early_stopping_patience,
                                   verbose=True)

    logs = []

    # Optionally warm-start from an existing checkpoint file.
    if args.load_model and os.path.isfile("checkpoint.pt"):
        print("Loading model from savefile...")
        model.load_state_dict(torch.load("checkpoint.pt"))

    # Training is skipped only when a model is loaded AND patience is 0.
    if not (args.load_model and args.early_stopping_patience == 0):
        # NOTE(review): range(1, args.epochs) performs args.epochs - 1
        # iterations, and ``epoch`` stays unbound when args.epochs <= 1,
        # which would make the test-result print below fail.
        for epoch in range(1, args.epochs):
            log = train_epoch()
            logs.append(log)

            # Validation metrics
            val_auc, val_ap = test(data.val_pos_edge_index,
                                   data.val_neg_edge_index)
            print('Validation-Epoch: {:03d}, AUC: {:.4f}, AP: {:.4f}'.format(
                epoch, val_auc, val_ap))

            # Stop training if validation scores have not improved
            early_stopping(val_ap, model)
            if early_stopping.early_stop:
                print("Applying early-stopping")
                break
    else:
        epoch = 0

    # Load best encoder
    # NOTE(review): presumably 'checkpoint.pt' is written by EarlyStopping;
    # confirm it always exists at this point.
    print("Load best model for evaluation.")
    model.load_state_dict(torch.load('checkpoint.pt'))
    print(
        "__________________________________________________________________________"
    )
    # Training is finished, calculate test metrics
    test_auc, test_ap = test(data.test_pos_edge_index,
                             data.test_neg_edge_index)
    print('Test Results: {:03d}, AUC: {:.4f}, AP: {:.4f}'.format(
        epoch, test_auc, test_ap))

    # Check if early stopping was applied or not - if not: model might not be done with training
    if args.epochs == epoch + 1:
        print("Model might need more epochs - Increase number of Epochs!")

    # Evaluate full graph
    # NOTE(review): this encode runs outside torch.no_grad(); the model is
    # still in eval mode from the preceding test() call.
    latent_embeddings = model.encode(node_features, train_pos_edge_index)

    # Save embeddings to embeddings folder if flag is set
    if args.save_embeddings:
        embeddings_folder = osp.join(osp.dirname(osp.abspath(__file__)),
                                     'embeddings')
        if not osp.isdir(embeddings_folder):
            os.makedirs(embeddings_folder)

        torch.save(
            latent_embeddings,
            osp.join(embeddings_folder,
                     args.dataset + "_" + args.decoder + ".pt"))

    if not args.lsh:
        # Compute precision recall w.r.t the ground truth graph
        # NOTE(review): the two metrics are computed but never returned;
        # this branch falls off the end of the function (returns None).
        graph_precision, graph_recall = test_naive_graph(latent_embeddings)
        # Drop the model references before emptying the CUDA cache.
        del model
        del encoder
        del decoder
        torch.cuda.empty_cache()
    else:
        # Precision w.r.t. the generated graph
        naive_precision, naive_recall, naive_time, naive_size, lsh_precision, \
        lsh_recall, lsh_time, lsh_size, \
        compare_precision, compare_recall = test_compare_lsh_naive_graphs(
            latent_embeddings)

        # Drop the model references before emptying the CUDA cache.
        del model
        del encoder
        del decoder
        torch.cuda.empty_cache()

        return {
            'args': args,
            'test_auc': test_auc,
            'test_ap': test_ap,
            'naive_precision': naive_precision,
            'naive_recall': naive_recall,
            'naive_time': naive_time,
            'naive_size': naive_size,
            'lsh_precision': lsh_precision,
            'lsh_recall': lsh_recall,
            'lsh_time': lsh_time,
            'lsh_size': lsh_size,
            'compare_precision': compare_precision,
            'compare_recall': compare_recall
        }
Example #8
0
    print("use dataset: CiteSeer")
elif args.dataset.lower() == 'PubMed'.lower():
    dataset = Planetoid(root='tmp', name='PubMed')
    print("use dataset: PubMed")
data = dataset[0]

# Split the edges on a clone so that ``data`` itself is not the object
# handed to the splitter.
enhanced_data = train_test_split_edges(data.clone(),
                                       val_ratio=0.1,
                                       test_ratio=0.2)

# The model is trained on the training edges only and reconstructs the
# edges of the original graph.
train_data = Data(x=enhanced_data.x,
                  edge_index=enhanced_data['train_pos_edge_index']).to(DEVICE)
target_data = data.to(DEVICE)

# Compare by value: ``is`` tests object identity, which is not guaranteed to
# hold for a string parsed from the command line, so the VGAE branch could be
# skipped even when 'VGAE' was requested.
if args.model == 'VGAE':
    model = VGAE(encoder=VEncoder(data['x'].shape[1])).to(DEVICE)
else:
    model = GAE(encoder=Encoder(data['x'].shape[1])).to(DEVICE)

optimizer = torch.optim.Adam(model.parameters(),
                             lr=args.learning_rate,
                             weight_decay=5e-4)


def model_train():
    print("========Start training========")
    for epoch in range(args.num_epoch):
        model.train()
        optimizer.zero_grad()
        z = model.encode(train_data)
        recon_loss = model.recon_loss(z, target_data['edge_index'])
                    help='Initial learning rate.')
parser.add_argument('--res',
                    type=str,
                    default="True",
                    help='Residual connection')
args = parser.parse_args()

# Download datasets
# NOTE(review): ``os.join`` / ``os.dirname`` / ``os.realpath`` only resolve
# if ``os`` is bound to ``os.path`` in this file (import not visible here);
# with a plain ``import os`` these need to be ``os.path.*`` - confirm.
path = os.join(os.dirname(os.realpath(__file__)), '..', 'data', args.dataset)
dataset = Planetoid(path, args.dataset)

dev = torch.device(args.dev)

# ``args.model`` comes from the argument parser, so compare it with the model
# *name*. Comparing with the ``VGAE`` class alone never matches a string and
# always selected the GAE branch; the class itself is still accepted in case
# a caller passes it directly.
if args.model in ('VGAE', VGAE):
    model = VGAE(
        Encoder_VGAE(dataset.num_features, args.hidden1, args.hidden2,
                     args.depth, args.res)).to(dev)
else:
    model = GAE(
        Encoder_GAE(dataset.num_features, args.hidden1, args.hidden2,
                    args.depth, args.res)).to(dev)

# Per-run scores are collected here.
auc_score_list = []
ap_score_list = []

print("Dataset: ", args.dataset, " Model: ", args.model, ", Residual :",
      args.res, ", Layer depth:", args.depth, " ")

for i in range(1, args.runs + 1):
    data = dataset[0]
    data.train_mask = data.val_mask = data.test_mask = data.y = None
Example #10
0
def run_VGAE(input_data,
             output_dir,
             epochs=1000,
             lr=0.01,
             weight_decay=0.0005):
    """Train a VGAE for link prediction, log every epoch and plot the curves.

    Args:
        input_data: Graph object; it is cloned, so the caller's object is
            not the one that gets edge-split.
        output_dir: Directory for the log/figure files (created via
            ``makepath``).
        epochs: Number of training epochs.
        lr: Learning rate passed to Adam.
        weight_decay: Weight decay passed to Adam.

    Returns:
        None. Results are written through ``write_log`` and
        ``plot_linkpred``.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print('Device: '.ljust(32), device)
    print('Model Name: '.ljust(32), 'VGAE')
    print('Model params:{:19} lr: {}     weight_decay: {}'.format(
        '', lr, weight_decay))
    print('Total number of epochs to run: '.ljust(32), epochs)
    print('*' * 70)

    data = input_data.clone().to(device)
    model = VGAE(VGAEncoder(data.num_features,
                            data.num_classes.item())).to(device)
    data = model.split_edges(data)
    x, train_pos_edge_index = data.x.to(device), data.train_pos_edge_index.to(
        device)
    data.train_idx = data.test_idx = data.y = None

    # Honour the hyper-parameters that are advertised in the banner above;
    # previously the optimizer silently used lr=0.01 and no weight decay.
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=lr,
                                 weight_decay=weight_decay)

    # The output location does not depend on the epoch, so prepare it once.
    # This also keeps ``figname`` defined when ``epochs`` is 0.
    makepath(output_dir)
    figname = os.path.join(
        output_dir, "_".join(
            (VGAE.__name__, str(lr), str(weight_decay), str(epochs))))

    # Progress is printed roughly ten times; guard against epochs < 10,
    # where the interval would otherwise be 0 and the modulo would fail.
    report_every = max(1, epochs // 10)

    train_losses = []
    test_losses = []
    aucs = []
    aps = []
    for epoch in range(1, epochs + 1):
        # Switch back to training mode every epoch: the evaluation step
        # below puts the model into eval mode.
        model.train()
        optimizer.zero_grad()
        z = model.encode(x, train_pos_edge_index)
        train_loss = model.recon_loss(
            z, train_pos_edge_index) + (1 / data.num_nodes) * model.kl_loss()
        train_losses.append(train_loss.item())
        train_loss.backward()
        optimizer.step()

        model.eval()
        with torch.no_grad():
            z = model.encode(x, train_pos_edge_index)
            auc, ap = model.test(z, data.test_pos_edge_index,
                                 data.test_neg_edge_index)
            test_loss = model.recon_loss(
                z, data.test_pos_edge_index
            ) + (1 / data.num_nodes) * model.kl_loss()
        test_losses.append(test_loss.item())
        aucs.append(auc)
        aps.append(ap)

        if epoch % report_every == 0:
            print(
                'Epoch: {}        Train loss: {}    Test loss: {}    AUC: {}    AP: {:.4f}'
                .format(epoch, train_loss, test_loss, auc, ap))
        if epoch == epochs:
            print(
                '-' * 65,
                '\nFinal epoch: {}  Train loss: {}    Test loss: {}    AUC: {}    AP: {}'
                .format(epoch, train_loss, test_loss, auc, ap))
        # A log line is written for every epoch, as before.
        log = 'Final epoch: {}    Train loss: {}    Test loss: {}    AUC: {}    AP: {}'.format(
            epoch, train_loss, test_loss, auc, ap)
        write_log(log, figname)
    print('-' * 65)

    plot_linkpred(train_losses, test_losses, aucs, aps, output_dir, epochs,
                  figname)
    return
Example #11
0
    def forward(self, x, edge_index):
        """Apply ``conv_mu`` and then ``conv_logstd`` to the same inputs.

        Returns the two results as a ``(mu, logstd)`` tuple.
        """
        mean_out = self.conv_mu(x, edge_index)
        logstd_out = self.conv_logstd(x, edge_index)
        return (mean_out, logstd_out)


# Latent width is fixed; the input width is taken from the dataset.
out_channels = 16
num_features = dataset.num_features

# Select the model from the ``variational`` / ``linear`` switches.
if args.variational and args.linear:
    model = VGAE(VariationalLinearEncoder(num_features, out_channels))
elif args.variational:
    model = VGAE(VariationalGCNEncoder(num_features, out_channels))
elif args.linear:
    model = GAE(LinearEncoder(num_features, out_channels))
else:
    model = GAE(GCNEncoder(num_features, out_channels))

# Prefer the GPU when one is available.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Move the model and the training tensors to the chosen device.
model = model.to(device)
x = data.x.to(device)
train_pos_edge_index = data.train_pos_edge_index.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)


def train():
    model.train()
    optimizer.zero_grad()
    z = model.encode(x, train_pos_edge_index)
    loss = model.recon_loss(z, train_pos_edge_index)
Example #12
0
def run_model(dataset, conf):
    """Embed table cells with a VGAE and evaluate the resulting table vectors.

    Steps: tokenize the tables, build a graph whose edges come from row and
    column co-occurrence, train a VGAE on it, extract the latent cell
    vectors, aggregate them into table vectors and score those.

    Args:
        dataset: Table corpus handed to ``create_corpus`` and
            ``evaluate_model``.
        conf: Mapping with the keys ``add_attr``, ``shuffle_vocab``,
            ``row_edges_sample``, ``column_edges_sample``, ``vector_size``
            and ``epoch_num``.

    Returns:
        Tuple ``(cell_vectors, vec_list, losses, result_score)`` where
        ``losses`` holds one ``(recon_loss, kl_loss, total_loss)`` tuple per
        epoch.
    """
    # ## 1) Build Table graph
    # ### Tables tokenization
    tokenized_tables, vocabulary, cell_dict, reversed_dictionary = create_corpus(
        dataset, include_attr=conf["add_attr"])
    # The explicit comparison is kept so that only a value equal to True
    # enables shuffling, exactly as before.
    if conf["shuffle_vocab"] == True:
        shuffled_vocab = shuffle_vocabulary(vocabulary)
    else:
        shuffled_vocab = None

    nodes = build_node_features(vocabulary)
    row_edges_index, row_edges_weights = build_graph_edges(
        tokenized_tables,
        s_vocab=shuffled_vocab,
        sample_frac=conf["row_edges_sample"],
        columns=False)
    col_edges_index, col_edges_weights = build_graph_edges(
        tokenized_tables,
        s_vocab=shuffled_vocab,
        sample_frac=conf["column_edges_sample"],
        columns=True)

    # Row and column edges are concatenated into a single edge set.
    edges = torch.cat((row_edges_index, col_edges_index), dim=1)
    weights = torch.cat((row_edges_weights, col_edges_weights), dim=0)
    graph_data = Data(x=nodes, edge_index=edges, edge_attr=weights)

    # ## 2 ) Run Table Auto-Encoder Model:
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    graph_data = graph_data.to(device)

    # Train on the tensors that were moved to ``device``; using the original
    # ``nodes`` / ``edges`` would feed CPU tensors to a model that may live
    # on the GPU.
    x, train_pos_edge_index = graph_data.x, graph_data.edge_index

    class Encoder(torch.nn.Module):
        """Two-layer GCN encoder producing the two VGAE encoder outputs."""

        def __init__(self, in_channels, out_channels):
            super(Encoder, self).__init__()
            self.conv1 = GCNConv(in_channels, 2 * out_channels, cached=True)
            self.conv_mu = GCNConv(2 * out_channels, out_channels, cached=True)
            self.conv_logvar = GCNConv(2 * out_channels,
                                       out_channels,
                                       cached=True)

        def forward(self, x, edge_index):
            x = F.relu(self.conv1(x, edge_index))
            return self.conv_mu(x, edge_index), self.conv_logvar(x, edge_index)

    channels = conf["vector_size"]
    enc = Encoder(graph_data.num_features, channels)
    model = VGAE(enc)
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

    def train(model, optimizer, x, train_pos_edge_index):
        """Run one optimisation step and return (recon, kl, total) losses."""
        model.train()
        optimizer.zero_grad()
        z = model.encode(x, train_pos_edge_index)
        rl = model.recon_loss(z, train_pos_edge_index)
        kl = model.kl_loss()

        loss = rl + kl

        loss.backward()
        optimizer.step()
        return (rl, kl, loss)

    losses = []
    for epoch in range(conf["epoch_num"]):
        loss = train(model, optimizer, x, train_pos_edge_index)
        # Record each epoch exactly once (it used to be appended twice).
        losses.append(loss)
        print(epoch, loss)

    # ### 3) Extract the latent cell vectors, generate table vectors:
    def get_cell_vectors(model, x, train_pos_edge_index):
        """Encode in eval mode and return the embeddings plus a NumPy copy."""
        model.eval()
        with torch.no_grad():
            z = model.encode(x, train_pos_edge_index)
            # ``.cpu()`` is required before ``.numpy()`` when running on GPU.
            cell_vectors = z.cpu().numpy()
        return z, cell_vectors

    z, cell_vectors = get_cell_vectors(model, x, train_pos_edge_index)

    vec_list = generate_table_vectors(cell_vectors,
                                      tokenized_tables,
                                      s_vocab=shuffled_vocab)

    # ## 4) Evaluate the model
    result_score = evaluate_model(dataset, vec_list, k=5)
    return cell_vectors, vec_list, losses, result_score
# Constants consumed further down the script.
node_recon_loss_weight = 10.0
use_gin = True

# Both data files are addressed relative to this file's directory.
path = Path(__file__).parent.joinpath("../../test/data/BBA-subset-100.h5")
node_feature_path = Path(__file__).parent.joinpath(
    "../../test/data/onehot_bba_amino_acid_labels.npy")

dataset = ContactMapDataset(
    path,
    "contact_map",
    ["rmsd"],
    node_feature_path=node_feature_path,
)
# One graph per batch, visited in random order.
loader = DataLoader(dataset, batch_size=1, shuffle=True)

# Select node AE: the ``linear`` switch picks the plain GAE, otherwise the
# variational model is used.
node_ae = (GAE(GCNEncoder(num_features, node_out_channels))
           if args.linear else VGAE(
               VariationalGCNEncoder(num_features, node_out_channels)))

# Select graph AE
encoder = VariationalGraphEncoder(node_out_channels, hidden_channels,
                                  graph_out_channels, depth, pool_ratios, act,
                                  variational, use_gin)
decoder = VariationalGraphDecoder(
    graph_out_channels,
    hidden_channels,
if __name__ == "__main__":
    # Use the first command-line argument as the data directory when one is
    # given and fall back to the default otherwise. The previous expression
    # had the two branches swapped and read ``sys.argv[2]``, so it raised
    # IndexError whenever no argument was passed.
    # NOTE(review): assumes the path is meant to be the first argument.
    filePath = sys.argv[1] if len(sys.argv) > 1 else '../wholeYear/'
    dataset = WholeYearDataset(filePath)
    d = dataset[0]

    # The return value is discarded, as before.
    train_test_split_edges(d)

    #parameters
    out_channels = 2
    num_features = d.num_features
    # Shared by all three runs below.
    num_epochs = 1000
    learning_rate = 0.001

    # GAE with the first encoder variant.
    model_gae1 = GAE(GCNEncoder(num_features, out_channels))
    areasUnderCurve_gae_weekday, precisions_gae_weekday, losses_gae_weekday = runAutoencoder(
        model_gae1, d, num_epochs, torch.optim.Adam, learning_rate)
    plotAUC_AP_Loss(areasUnderCurve_gae_weekday, precisions_gae_weekday,
                    losses_gae_weekday, num_epochs, "GAE 1: 2 Convolutions")

    # GAE with the second encoder variant.
    model2 = GAE(GCNEncoder2(num_features, out_channels))
    areasUnderCurve_gae_weekday_model2, precisions_gae_weekday_model2, losses_gae_weekday_model2 = runAutoencoder(
        model2, d, num_epochs, torch.optim.Adam, learning_rate)
    plotAUC_AP_Loss(areasUnderCurve_gae_weekday_model2,
                    precisions_gae_weekday_model2, losses_gae_weekday_model2,
                    num_epochs, "GAE 2: 2 Convolutions 1 Linear")

    # Variational model; its results are indexed positionally for plotting.
    modelVgae = VGAE(VariationalGCNEncoder(num_features, out_channels))
    runVariational1 = runVariational(modelVgae, d, num_epochs,
                                     torch.optim.Adam, learning_rate)
    plotAUC_AP_Loss(runVariational1[0], runVariational1[1], runVariational1[2],
                    num_epochs, "GV|AE 1: 2 Convolutions")