Example #1
def plot_distribution(envs, imitator_model, num_trajs, seed, metric):
    """
    Compare the distribution of reward/time-step of the expert and the imitator.
    :param env:
    :param imitator_model:
    :param num_trajs:
    :param seed:
    :param metric:
    :return:
    """
    rows = 3
    columns = 2
    fig, axs = plt.subplots(rows,
                            columns,
                            figsize=(15, 10),
                            constrained_layout=True)
    for i in range(rows):
        for j in range(columns):
            env = envs[i * columns + j]
            gym_env = gym.make(env)

            # collect expert results
            expert_model_path = 'assets/expert_models/{}_ppo_0.p'.format(env)
            expert, _, running_state, _ = pickle.load(
                open(expert_model_path, "rb"))
            expert_results = evaluate_model(gym_env,
                                            expert,
                                            running_state,
                                            num_trajs=50,
                                            verbose=False)

            imitator_model_path = "assets/imitator_models/{}_{}_traj{}_seed{}.p".format(
                env, imitator_model, num_trajs, seed)
            imitator = pickle.load(open(imitator_model_path, "rb"))[0]
            imitator_results = evaluate_model(gym_env,
                                              imitator,
                                              running_state,
                                              num_trajs=50,
                                              verbose=False)

            axs[i, j].hist(expert_results[metric],
                           density=True,
                           alpha=0.5,
                           label='expert')
            axs[i, j].hist(imitator_results[metric],
                           density=True,
                           alpha=0.5,
                           label=imitator_model)
            axs[i, j].set_title('{} {} Density Curve'.format(env, metric))
            axs[i, j].legend(loc='upper right')
    for ax in axs.flat:
        ax.set(xlabel='{}'.format(metric), ylabel='density')

    fig.savefig(
        'assets/imitator_plots/{}_{}_{}_{}_density_comparison_plot.png'.format(
            envs, imitator_model, seed, metric))
    plt.close(fig)
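A minimal invocation sketch; the environment ids, method name, and metric key below are illustrative assumptions, not values taken from the original script:

# Hypothetical call; assumes the expert/imitator pickles exist at the paths
# built inside plot_distribution for each environment.
envs = ['Hopper-v2', 'Walker2d-v2', 'HalfCheetah-v2',
        'Ant-v2', 'Swimmer-v2', 'Humanoid-v2']
plot_distribution(envs, imitator_model='bc', num_trajs=5,
                  seed=0, metric='episodes_rewards')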
Example #2
    def get_update(self, model, epoch_num, device, valid_flag=True):
        model.eval()
        if valid_flag:
            (list_input, item_input) = self.validArrDubles
            num_inst = self.num_valid_instances * self.valid_dim
            posItemlst = self.valid_pos_items  # parameter for evaluate_model
            matShape = (self.num_valid_instances, self.valid_dim)
        else:
            (list_input, item_input) = self.testArrDubles
            num_inst = self.num_test_instances * self.valid_dim
            posItemlst = self.test_pos_items  # parameter for evaluate_model
            matShape = (self.num_test_instances, self.valid_dim)

        batch_siz = self.valid_batch_siz * self.valid_dim

        full_pred_torch_lst = []
        list_input_ten = torch.from_numpy(list_input.astype(np.long)).to(
            device)  ## could be moved to gpu before-hand
        item_input_ten = torch.from_numpy(item_input.astype(
            np.long)).to(device)
        user_input = self.list_user_vec[list_input]
        user_input_ten = torch.from_numpy(user_input.astype(
            np.long)).to(device)
        batch = Batch(num_inst, batch_siz, shuffle=False)
        ##
        ind = 0
        while batch.has_next_batch():
            batch_indices = batch.get_next_batch_indices()
            if self.params.method == 'bpr' or self.params.loss == 'pairwise':
                user_neg_input = None
                y_pred = model(
                    user_input_ten[batch_indices], user_neg_input,
                    list_input_ten[batch_indices],
                    item_input_ten[batch_indices])  # first argument for user
            else:
                y_pred = model(user_indices=user_input_ten[batch_indices],
                               list_indices=list_input_ten[batch_indices],
                               item_indices=item_input_ten[batch_indices]
                               )  # first argument for user
            full_pred_torch_lst.append(y_pred.detach().cpu().numpy())

        full_pred_np = np.concatenate(
            full_pred_torch_lst)  #.data.cpu().numpy()

        predMatrix = np.array(full_pred_np).reshape(matShape)
        itemMatrix = np.array(item_input).reshape(matShape)
        '''
        print('predMatrix')
        print(predMatrix[0:20,0:20])
        print('itemMatrix')
        print(itemMatrix[0:20,0:20])
        '''

        (hits, ndcgs, maps) = evaluate_model(posItemlst=posItemlst,
                                             itemMatrix=itemMatrix,
                                             predMatrix=predMatrix,
                                             k=self.at_k,
                                             num_thread=self.num_thread)
        return (hits, ndcgs, maps)
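The Batch helper used above is not shown in this snippet; a minimal stand-in that matches how it is called here (sequential index batches, optional shuffling) could look like the following — a sketch under that assumption, not the project's actual class:

import numpy as np

class Batch:
    # Minimal sketch of the batch-index iterator assumed by get_update.
    def __init__(self, num_inst, batch_siz, shuffle=False):
        self.indices = np.arange(num_inst)
        if shuffle:
            np.random.shuffle(self.indices)
        self.batch_siz = batch_siz
        self.pos = 0

    def has_next_batch(self):
        return self.pos < len(self.indices)

    def get_next_batch_indices(self):
        batch = self.indices[self.pos:self.pos + self.batch_siz]
        self.pos += self.batch_siz
        return batch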
Example #3
    def get_update(self, model, epoch_num, device, valid_flag=True):
        model.eval()
        if valid_flag:
            (list_input, item_input) = self.validArrDubles
            num_inst = self.num_valid_instances * self.valid_dim
            posItemlst = self.valid_pos_items  # parameter for evaluate_model
            matShape = (self.num_valid_instances, self.valid_dim)
        else:
            (list_input, item_input) = self.testArrDubles
            num_inst = self.num_test_instances * self.valid_dim
            posItemlst = self.test_pos_items  # parameter for evaluate_model
            matShape = (self.num_test_instances, self.valid_dim)

        batch_siz = self.valid_batch_siz * self.valid_dim

        full_pred_torch_lst = []
        list_input_ten = torch.from_numpy(list_input.astype(np.long)).to(
            device)  ## could be moved to gpu before-hand
        item_input_ten = torch.from_numpy(item_input.astype(
            np.long)).to(device)
        user_input = self.list_user_vec[list_input]
        user_input_ten = torch.from_numpy(user_input.astype(
            np.long)).to(device)
        batch = Batch(num_inst, batch_siz, shuffle=False)
        while batch.has_next_batch():
            batch_indices = batch.get_next_batch_indices()

            if valid_flag:
                item_seq = torch.from_numpy(self.params.train_matrix_item_seq[
                    list_input[batch_indices]].astype(np.long)).to(
                        device)  ## ##for_test
            else:
                item_seq = torch.from_numpy(
                    self.params.train_matrix_item_seq_for_test[
                        list_input[batch_indices]].astype(np.long)).to(
                            device)  ## ##for_test
            y_pred = model(user_indices=user_input_ten[batch_indices],
                           list_indices=list_input_ten[batch_indices],
                           item_seq=item_seq,
                           test_item_indices=item_input_ten[batch_indices],
                           train=False,
                           network='seq')  # ##
            full_pred_torch_lst.append(y_pred.detach().cpu().numpy())

        full_pred_np = np.concatenate(
            full_pred_torch_lst)  #.data.cpu().numpy()
        # ==============================

        predMatrix = np.array(full_pred_np).reshape(matShape)
        itemMatrix = np.array(item_input).reshape(matShape)

        (hits, ndcgs, maps) = evaluate_model(posItemlst=posItemlst,
                                             itemMatrix=itemMatrix,
                                             predMatrix=predMatrix,
                                             k=self.at_k,
                                             num_thread=self.num_thread)
        return (hits, ndcgs, maps)
Example #4
def evaluate_cifar_models():
    """
    Evaluates the CIFAR-100 baseline and its ExpandNet variants (FC, CL, CL+FC).
    :return:
    """

    print('***  Baseline ***')
    evaluate_model(net=Cifar_Tiny(num_classes=100),
                   path='results/models/tiny_cifar100_3_' + seed + '.model',
                   result_path='results/evals/tiny_cifar100_3_' + seed +
                   '.pickle',
                   dataset_name='cifar100',
                   dataset_loader=cifar100_loader)

    print('***  ExpandNet-FC ***')
    evaluate_model(
        net=Cifar_Tiny_ExpandNet_fc(num_classes=100),
        path='results/models/tiny_cifar100_3_enet_fc_' + seed + '.model',
        result_path='results/evals/tiny_cifar100_3_enet_fc_' + seed +
        '.pickle',
        dataset_name='cifar100',
        dataset_loader=cifar100_loader)

    print('***  ExpandNet-CL ***')
    evaluate_model(
        net=Cifar_Tiny_ExpandNet_cl(num_classes=100),
        path='results/models/tiny_cifar100_3_enet_cl_' + seed + '.model',
        result_path='results/evals/tiny_cifar100_3_enet_cl_' + seed +
        '.pickle',
        dataset_name='cifar100',
        dataset_loader=cifar100_loader)

    print('***  ExpandNet-CL+FC ***')
    evaluate_model(
        net=Cifar_Tiny_ExpandNet_cl_fc(num_classes=100),
        path='results/models/tiny_cifar100_3_enet_cl_fc_' + seed + '.model',
        result_path='results/evals/tiny_cifar100_3_enet_cl_fc_' + seed +
        '.pickle',
        dataset_name='cifar100',
        dataset_loader=cifar100_loader)
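For reference, evaluate_model here receives a network, a checkpoint path, a result pickle path, and a dataset loader; a rough sketch of what such a helper might do (the body and the loader signature are assumptions, not the repository's actual code):

import pickle
import torch

def evaluate_model(net, path, result_path, dataset_name='cifar100',
                   dataset_loader=None):
    # Load the trained weights and measure top-1 accuracy on the test split.
    net.load_state_dict(torch.load(path, map_location='cpu'))
    net.eval()
    _, test_loader = dataset_loader(batch_size=128)  # loader signature assumed
    correct, total = 0, 0
    with torch.no_grad():
        for images, labels in test_loader:
            preds = net(images).argmax(dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)
    acc = correct / total
    print('{} top-1 accuracy: {:.4f}'.format(dataset_name, acc))
    with open(result_path, 'wb') as f:
        pickle.dump({'accuracy': acc}, f)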
Example #5
    log_list = {"bc_loss": [],
                "uncertainty_cost":[],
                "avg_reward": [],
                "std_reward": []}

    total_timesteps = 0

    for i_iter in range(args.max_iter_num):
        batch, log = agent.collect_samples(args.min_batch_size)
        # train DRIL
        t0 = time.time()
        loss = imitator.train(batch)
        t1 = time.time()

        imitator.policy.actor.to('cpu')
        episode_rewards = evaluate_model(env, imitator.policy.actor, running_state=running_state, verbose=False)[
            'episodes_rewards']
        imitator.policy.actor.to(args.device)

        if i_iter % args.log_interval == 0:
            print('{}\tT_update: {:.4f}\t training loss: {:.2f}\t uncertainty cost: {:.4f}'
                  '\t R_avg: {:.2f}\t R_std: {:.2f}'.format(
                i_iter, t1 - t0, loss['bc_loss'], loss['uncertainty_cost'],
                episode_rewards.mean(), episode_rewards.std()))

        if args.save_model_interval > 0 and (i_iter + 1) % args.save_model_interval == 0:
            to_device(torch.device('cpu'), imitator.policy)
            pickle.dump((imitator.policy, imitator.config),
                        open(os.path.join(assets_dir(), 'imitator_models/{}_dril_traj{}_seed{}.p'.format(args.env_name,
                                                                                                       args.num_trajs,
                                                                                                       args.seed)),
                             'wb'))
Example #6
def evaluate_cifar_models():
    """
    Evaluates the CIFAR-10 baseline and its ExpandNet variants (FC, CL, CL+FC, CK, CK+FC).
    :return:
    """

    print('***  Baseline ***')
    evaluate_model(net=Cifar_Tiny(num_classes=10),
                   path='results/models/tiny_cifar10_7_' + seed + '.model',
                   result_path='results/evals/tiny_cifar10_7_' + seed +
                   '.pickle')

    print('***  ExpandNet-FC ***')
    evaluate_model(
        net=Cifar_Tiny_ExpandNet_fc(num_classes=10),
        path='results/models/tiny_cifar10_7_enet_fc_' + seed + '.model',
        result_path='results/evals/tiny_cifar10_7_enet_fc_' + seed + '.pickle')

    print('***  ExpandNet-CL ***')
    evaluate_model(
        net=Cifar_Tiny_ExpandNet_cl(num_classes=10),
        path='results/models/tiny_cifar10_7_enet_cl_' + seed + '.model',
        result_path='results/evals/tiny_cifar10_7_enet_cl_' + seed + '.pickle')

    print('***  ExpandNet-CL+FC ***')
    evaluate_model(
        net=Cifar_Tiny_ExpandNet_cl_fc(num_classes=10),
        path='results/models/tiny_cifar10_7_enet_cl_fc_' + seed + '.model',
        result_path='results/evals/tiny_cifar10_7_enet_cl_fc_' + seed +
        '.pickle')

    print('***  ExpandNet-CK ***')
    evaluate_model(
        net=Cifar_Tiny_ExpandNet_ck(num_classes=10),
        path='results/models/tiny_cifar10_7_enet_ck_' + seed + '.model',
        result_path='results/evals/tiny_cifar10_7_enet_ck_' + seed + '.pickle')

    print('***  ExpandNet-CK+FC ***')
    evaluate_model(
        net=Cifar_Tiny_ExpandNet_ck_fc(num_classes=10),
        path='results/models/tiny_cifar10_7_enet_ck_fc_' + seed + '.model',
        result_path='results/evals/tiny_cifar10_7_enet_ck_fc_' + seed +
        '.pickle')
Example #7
def run_model(dataset, conf):
    # ## 1) Build Table graph
    # ### Tables tokenization
    tokenized_tables, vocabulary, cell_dict, reversed_dictionary = corpus_tuple = create_corpus(
        dataset, include_attr=conf["add_attr"])
    if conf["shuffle_vocab"] == True:
        shuffled_vocab = shuffle_vocabulary(vocabulary)
    else:
        shuffled_vocab = None

    nodes = build_node_features(vocabulary)
    row_edges_index, row_edges_weights = build_graph_edges(
        tokenized_tables,
        s_vocab=shuffled_vocab,
        sample_frac=conf["row_edges_sample"],
        columns=False)
    col_edges_index, col_edges_weights = build_graph_edges(
        tokenized_tables,
        s_vocab=shuffled_vocab,
        sample_frac=conf["column_edges_sample"],
        columns=True)

    edges = torch.cat((row_edges_index, col_edges_index), dim=1)
    weights = torch.cat((row_edges_weights, col_edges_weights), dim=0)
    graph_data = Data(x=nodes, edge_index=edges, edge_attr=weights)

    # ## 2 ) Run Table Auto-Encoder Model:
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    loader = DataLoader(torch.arange(graph_data.num_nodes),
                        batch_size=128,
                        shuffle=True)
    graph_data = graph_data.to(device)

    def train():
        model.train()
        total_loss = 0
        for subset in loader:
            optimizer.zero_grad()
            loss = model.loss(graph_data.edge_index, subset.to(device))
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        return total_loss / len(loader)

    model = Node2Vec(graph_data.num_nodes,
                     embedding_dim=conf["vector_size"],
                     walk_length=conf["n2v_walk_length"],
                     context_size=conf["n2v_context_size"],
                     walks_per_node=conf["n2v_walks_per_node"])
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    losses = []
    for epoch in range(conf["epoch_num"]):
        loss = train()
        print('Epoch: {:02d}, Loss: {:.4f}'.format(epoch, loss))
        losses.append(float(loss))
    # ### 3) Extract the latent cell vectors, generate table vectors:
    model.eval()
    with torch.no_grad():
        z = model(torch.arange(graph_data.num_nodes, device=device))
        cell_vectors = z.cpu().numpy()
    vec_list = generate_table_vectors(cell_vectors,
                                      tokenized_tables,
                                      s_vocab=shuffled_vocab)

    # ## 3) Evaluate the model
    result_score = evaluate_model(dataset, vec_list, k=5)
    return cell_vectors, vec_list, losses, result_score
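The conf dictionary must supply the keys read above; an illustrative configuration (the values are assumptions, not the authors' settings):

conf = {
    "add_attr": True,             # include attribute tokens when building the corpus
    "shuffle_vocab": False,
    "row_edges_sample": 1.0,      # fraction of row edges to sample
    "column_edges_sample": 1.0,
    "vector_size": 128,           # Node2Vec embedding dimension
    "n2v_walk_length": 20,
    "n2v_context_size": 10,
    "n2v_walks_per_node": 10,
    "epoch_num": 50,
}
cell_vectors, vec_list, losses, score = run_model(dataset, conf)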
Example #8
                        default=False,
                        help='print verbose')
    args = parser.parse_args()

    args.model_path = "assets/imitator_models/{}_{}_traj5_seed{}.p".format(
        args.env_name, args.model, args.seed)
    args.expert_path = "assets/expert_models/{}_ppo_0.p".format(args.env_name)

    # load expert and state normalization
    expert, _, running_state, _ = pickle.load(open(args.expert_path, "rb"))
    running_state.fix = True

    # load imitator
    imitator = pickle.load(open(args.model_path, "rb"))[0]
    """environment"""
    env = gym.make(args.env_name)
    state_dim = env.observation_space.shape[0]
    is_disc_action = len(env.action_space.shape) == 0
    action_dim = env.action_space.n if is_disc_action else env.action_space.shape[
        0]

    print("=======================================")
    print("Expert Settings: " + args.expert_path)
    print("Imitator Settings: " + args.model_path)
    print("---------------------------------------")
    print("Evaluating Models for {} Trajectories".format(args.num_trajs))
    # TODO
    print("Expert:")
    evaluate_model(env, expert, running_state)
    print("Imitator:")
    evaluate_model(env, imitator, running_state)
Example #9
    data_loader_test = torch.utils.data.DataLoader(dataset_test,
                                                   batch_size=args.batch_size,
                                                   shuffle=False,
                                                   num_workers=4,
                                                   collate_fn=mycollate)

    #Setup model
    model = get_classifier_model(args.model_name,
                                 num_classes=10,
                                 use_pretrained=False,
                                 train_last_layer_only=True)

    model.load_state_dict(torch.load(args.testmodel_path))
    model.to(args.device)

    score, predictions, ground_truth, class_names = evaluate_model(
        model, data_loader_test, args.device)
    print("Evaluation score {}".format(score))

    #Create and save confusion matrix
    conf_mat = confusion_matrix(ground_truth, predictions, normalize='true')
    df_cm = pd.DataFrame(conf_mat)
    fig, ax = plt.subplots(figsize=(10, 7))
    conf_mat_fig = sn.heatmap(df_cm,
                              annot=True,
                              xticklabels=class_names,
                              yticklabels=class_names,
                              cmap="Blues",
                              cbar=False)
    plt.xlabel('Prediction', fontsize=18)
    ax.xaxis.set_label_position('top')
    plt.ylabel('Actual', fontsize=18)
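The snippet ends before the figure is written to disk; a plausible final step (the output file name is an assumption):

fig.savefig('confusion_matrix.png', bbox_inches='tight')
plt.close(fig)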
Example #10
def run_model(dataset, conf):
    # ## 1) Build Table graph
    # ### Tables tokenization
    tokenized_tables, vocabulary, cell_dict, reversed_dictionary = corpus_tuple = create_corpus(
        dataset, include_attr=conf["add_attr"])
    if conf["shuffle_vocab"] == True:
        shuffled_vocab = shuffle_vocabulary(vocabulary)
    else:
        shuffled_vocab = None

    nodes = build_node_features(vocabulary)
    row_edges_index, row_edges_weights = build_graph_edges(
        tokenized_tables,
        s_vocab=shuffled_vocab,
        sample_frac=conf["row_edges_sample"],
        columns=False)
    col_edges_index, col_edges_weights = build_graph_edges(
        tokenized_tables,
        s_vocab=shuffled_vocab,
        sample_frac=conf["column_edges_sample"],
        columns=True)

    all_row_edges_index, all_row_edges_weights = build_graph_edges(
        tokenized_tables,
        s_vocab=shuffled_vocab,
        sample_frac=1.0,
        columns=False)
    all_col_edges_index, all_col_edges_weights = build_graph_edges(
        tokenized_tables,
        s_vocab=shuffled_vocab,
        sample_frac=1.0,
        columns=True)
    all_possible_edges = torch.cat((all_row_edges_index, all_col_edges_index),
                                   dim=1)

    edges = torch.cat((row_edges_index, col_edges_index), dim=1)
    weights = torch.cat((row_edges_weights, col_edges_weights), dim=0)
    graph_data = Data(x=nodes, edge_index=edges, edge_attr=weights)

    # ## 2 ) Run Table Auto-Encoder Model:
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    loader = DataLoader(torch.arange(graph_data.num_nodes),
                        batch_size=128,
                        shuffle=True)
    graph_data = graph_data.to(device)

    x, train_pos_edge_index = nodes, edges

    EPS = 1e-15
    MAX_LOGVAR = 10

    class TVGAE(GAE):
        r"""The Variational Graph Auto-Encoder model from the
        `"Variational Graph Auto-Encoders" <https://arxiv.org/abs/1611.07308>`_
        paper.

        Args:
            encoder (Module): The encoder module to compute :math:`\mu` and
                :math:`\log\sigma^2`.
            decoder (Module, optional): The decoder module. If set to :obj:`None`,
                will default to the
                :class:`torch_geometric.nn.models.InnerProductDecoder`.
                (default: :obj:`None`)
        """
        def __init__(self, encoder, decoder=None):
            super(TVGAE, self).__init__(encoder, decoder)

        def reparametrize(self, mu, logvar):
            if self.training:
                return mu + torch.randn_like(logvar) * torch.exp(logvar)
            else:
                return mu

        def encode(self, *args, **kwargs):
            """"""
            self.__rmu__, self.__rlogvar__, self.__cmu__, self.__clogvar__ = self.encoder(
                *args, **kwargs)
            self.__rlogvar__ = self.__rlogvar__.clamp(max=MAX_LOGVAR)
            self.__clogvar__ = self.__clogvar__.clamp(max=MAX_LOGVAR)
            zr = self.reparametrize(self.__rmu__, self.__rlogvar__)
            zc = self.reparametrize(self.__cmu__, self.__clogvar__)
            z = torch.cat((zr, zc), 0)
            return z

        def kl_loss(self):

            rmu = self.__rmu__
            rlogvar = self.__rlogvar__

            cmu = self.__cmu__
            clogvar = self.__clogvar__

            rkl = -0.5 * torch.mean(
                torch.sum(1 + rlogvar - rmu**2 - rlogvar.exp(), dim=1))
            ckl = -0.5 * torch.mean(
                torch.sum(1 + clogvar - cmu**2 - clogvar.exp(), dim=1))
            return (rkl, ckl)

        def recon_loss(self, z, pos_edge_index, all_possible_edges):
            EPS = 1e-15
            MAX_LOGVAR = 10

            pos_loss = -torch.log(
                self.decoder(z, pos_edge_index, sigmoid=True) + EPS).mean()

            # Do not include self-loops in negative samples
            pos_edge_index, _ = remove_self_loops(pos_edge_index)
            pos_edge_index, _ = add_self_loops(pos_edge_index)

            neg_edge_index = negative_sampling(all_possible_edges, z.size(0))
            neg_loss = -torch.log(1 - self.decoder(
                z, neg_edge_index, sigmoid=True) + EPS).mean()

            return pos_loss + neg_loss

    class Encoder(torch.nn.Module):
        def __init__(self, in_channels, out_channels):
            super(Encoder, self).__init__()
            self.conv_rows = GCNConv(in_channels,
                                     2 * out_channels,
                                     cached=True)
            self.conv_cols = GCNConv(in_channels,
                                     2 * out_channels,
                                     cached=True)

            self.conv_rmu = GCNConv(2 * out_channels,
                                    out_channels,
                                    cached=True)
            self.conv_rlogvar = GCNConv(2 * out_channels,
                                        out_channels,
                                        cached=True)

            self.conv_cmu = GCNConv(2 * out_channels,
                                    out_channels,
                                    cached=True)
            self.conv_clogvar = GCNConv(2 * out_channels,
                                        out_channels,
                                        cached=True)

        def forward(self, x, row_edge_index, col_edge_index):
            xr = F.relu(self.conv_rows(x, row_edge_index))
            xc = F.relu(self.conv_cols(x, col_edge_index))
            return self.conv_rmu(xr, row_edge_index),\
                self.conv_rlogvar(xr, row_edge_index),\
                self.conv_cmu(xc, col_edge_index),\
                self.conv_clogvar(xc, col_edge_index)

    channels = conf["vector_size"]

    enc = Encoder(graph_data.num_features, channels)
    model = TVGAE(enc)
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

    def train(model, optimizer, x, row_edges, col_edges):
        model.train()
        optimizer.zero_grad()
        z = model.encode(x, row_edges, col_edges)
        mid = int(len(z) / 2)
        zr = z[:mid]
        zc = z[mid:]

        #recon loss:
        rrl = model.recon_loss(zr, row_edges, all_possible_edges)
        crl = model.recon_loss(zc, col_edges, all_possible_edges)
        #loss = rrl+crl

        rkl, ckl = model.kl_loss()
        #loss = rkl+ckl

        loss = rrl + crl + rkl + ckl

        loss.backward()
        optimizer.step()
        #return loss,rrl,crl
        return loss, rrl, crl, rkl, ckl

    def get_cell_vectors(model, x, row_edges_index, col_edges_index):
        model.eval()
        with torch.no_grad():
            z = model.encode(x, row_edges_index, col_edges_index)
            cell_vectors = z.cpu().numpy()
        return z, cell_vectors

    losses = []
    results = []
    for epoch in range(conf["epoch_num"]):
        #loss,row_loss,col_loss = train(model,optimizer,x,row_edges_index,col_edges_index)
        loss = train(model, optimizer, x, row_edges_index, col_edges_index)
        losses.append(loss)
        print(epoch, loss)
        z, cell_vectors = get_cell_vectors(model, x, row_edges_index,
                                           col_edges_index)
        vec_list = generate_table_vectors(cell_vectors,
                                          tokenized_tables,
                                          s_vocab=shuffled_vocab)
        result_score = evaluate_model(dataset, vec_list, k=5)
        print(result_score)
        results.append(result_score)

    # ### 3) Extract the latent cell vectors, generate table vectors:

    #z,cell_vectors = get_cell_vectors(model,x,train_pos_edge_index)

    #vec_list=generate_table_vectors(cell_vectors,tokenized_tables,s_vocab=shuffled_vocab)

    # ## 3) Evaluate the model
    #result_score=evaluate_model(dataset,vec_list,k=5)
    return cell_vectors, vec_list, losses, results
Example #11
def run_model(dataset, conf):
    # ## 1) Build Table graph
    # ### Tables tokenization
    tokenized_tables, vocabulary, cell_dict, reversed_dictionary = corpus_tuple = create_corpus(
        dataset, include_attr=conf["add_attr"])
    if conf["shuffle_vocab"] == True:
        shuffled_vocab = shuffle_vocabulary(vocabulary)
    else:
        shuffled_vocab = None

    nodes = build_node_features(vocabulary)
    row_edges_index, row_edges_weights = build_graph_edges(
        tokenized_tables,
        s_vocab=shuffled_vocab,
        sample_frac=conf["row_edges_sample"],
        columns=False)
    col_edges_index, col_edges_weights = build_graph_edges(
        tokenized_tables,
        s_vocab=shuffled_vocab,
        sample_frac=conf["column_edges_sample"],
        columns=True)

    edges = torch.cat((row_edges_index, col_edges_index), dim=1)
    weights = torch.cat((row_edges_weights, col_edges_weights), dim=0)
    graph_data = Data(x=nodes, edge_index=edges, edge_attr=weights)

    # ## 2 ) Run Table Auto-Encoder Model:
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    loader = DataLoader(torch.arange(graph_data.num_nodes),
                        batch_size=128,
                        shuffle=True)
    graph_data = graph_data.to(device)

    x, train_pos_edge_index = nodes, edges

    class Encoder(torch.nn.Module):
        def __init__(self, in_channels, out_channels):
            super(Encoder, self).__init__()
            self.conv1 = GCNConv(in_channels, 2 * out_channels, cached=True)
            self.conv_mu = GCNConv(2 * out_channels, out_channels, cached=True)
            self.conv_logvar = GCNConv(2 * out_channels,
                                       out_channels,
                                       cached=True)

        def forward(self, x, edge_index):
            x = F.relu(self.conv1(x, edge_index))
            return self.conv_mu(x, edge_index), self.conv_logvar(x, edge_index)

    channels = conf["vector_size"]
    enc = Encoder(graph_data.num_features, channels)
    model = VGAE(enc)
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

    def train(model, optimizer, x, train_pos_edge_index):
        model.train()
        optimizer.zero_grad()
        z = model.encode(x, train_pos_edge_index)
        rl = model.recon_loss(z, train_pos_edge_index)
        kl = model.kl_loss()

        loss = rl + kl

        loss.backward()
        optimizer.step()
        return (rl, kl, loss)

    losses = []
    for epoch in range(conf["epoch_num"]):
        loss = train(model, optimizer, x, train_pos_edge_index)
        losses.append(loss)
        print(epoch, loss)
    # ### 3) Extract the latent cell vectors, generate table vectors:
    def get_cell_vectors(model, x, train_pos_edge_index):
        model.eval()
        with torch.no_grad():
            z = model.encode(x, train_pos_edge_index)
            cell_vectors = z.cpu().numpy()
        return z, cell_vectors

    z, cell_vectors = get_cell_vectors(model, x, train_pos_edge_index)

    vec_list = generate_table_vectors(cell_vectors,
                                      tokenized_tables,
                                      s_vocab=shuffled_vocab)

    # ## 3) Evaluate the model
    result_score = evaluate_model(dataset, vec_list, k=5)
    return cell_vectors, vec_list, losses, result_score
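As in the earlier graph examples, conf carries the sampling and training settings; an illustrative configuration for this VGAE variant (the values are assumptions):

conf = {
    "add_attr": True,
    "shuffle_vocab": False,
    "row_edges_sample": 1.0,
    "column_edges_sample": 1.0,
    "vector_size": 128,   # latent dimension of the VGAE encoder
    "epoch_num": 50,
}
cell_vectors, vec_list, losses, result_score = run_model(dataset, conf)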