Exemplo n.º 1
0
 def load_latest_checkpoint(self, model, optimizer=None):
     """Restore ``model`` (and optionally ``optimizer``) from ``self.logdir``.

     Delegates to the global ``load_latest_checkpoint`` function, but only
     when resuming has been enabled via ``self.continue_training``.

     Args:
         model: Model whose state is restored.
         optimizer: Optional optimizer whose state is restored as well.

     Returns:
         Tuple ``(model, optimizer, iteration)`` as returned by the global
         helper.

     Raises:
         RuntimeError: If ``self.continue_training`` is falsy.
     """
     if self.continue_training:
         # Inside a method body this name is looked up in the module globals,
         # so it refers to the free function, not to this method.
         restored = load_latest_checkpoint(self.logdir, model, optimizer)
         model, optimizer, iteration = restored
         return model, optimizer, iteration
     raise RuntimeError(
         f"Trying to load the latest checkpoint from logdir {self.logdir}, "
         "but did not set `continue_training=true` in configuration.")
Exemplo n.º 2
0
def train():
    """Fine-tune a pretrained HierarchicalAttentionNet on the small FNN splits.

    All configuration comes from ``FLAGS``. The starting weights are taken
    from run ``FLAGS.run_desc``:

    - When ``FLAGS.epoch`` is 0, they come from ``HierarchicalAttentionNet_model.pt``
      in that run's models directory.
    - Otherwise they come from the checkpoint saved for that epoch.

    Checkpoints, results and the final model of the fine-tuning run are
    written under run ``FLAGS.run_desc_tl``.

    Raises:
        ValueError: If ``FLAGS.data_dir`` does not exist.
    """
    model_type = FLAGS.model_type
    run_desc = FLAGS.run_desc
    run_desc_tl = FLAGS.run_desc_tl
    data_dir = Path(FLAGS.data_dir)
    # Source run: where the pretrained weights are read from.
    checkpoints_dir = Path(FLAGS.checkpoints_dir) / model_type / run_desc
    models_dir = Path(FLAGS.models_dir) / model_type / run_desc
    # Target run: where the transfer-learning outputs are written.
    checkpoints_dir_tl = Path(FLAGS.checkpoints_dir) / model_type / run_desc_tl
    models_dir_tl = Path(FLAGS.models_dir) / model_type / run_desc_tl
    results_dir_tl = Path(FLAGS.results_dir) / model_type / run_desc_tl
    learning_rate = FLAGS.learning_rate
    batch_size_fn = FLAGS.batch_size
    # FLAGS.epoch may arrive as a string or an int; normalise it so the
    # comparison with '0' below works for both.
    epoch_no = str(FLAGS.epoch)
    sent_hidden_dim = FLAGS.sent_hidden_dim
    doc_hidden_dim = FLAGS.doc_hidden_dim

    if not data_dir.exists():
        raise ValueError('Data directory does not exist')

    # create other directories if they do not exist
    create_directories(checkpoints_dir_tl, models_dir_tl, results_dir_tl)

    # load the data
    print('Loading the data...')

    # Load the GloVe and/or ELMo embedders selected by substrings of model_type.
    glove_dim = 0
    elmo_dim = 0
    GloVe_vectors = None
    ELMo = None
    if 'glove' in model_type:
        GloVe_vectors = GloVe()
        glove_dim = WORD_EMBED_DIM
        print('Uploaded GloVe embeddings.')
    if 'elmo' in model_type:
        ELMo = Elmo(options_file=ELMO_OPTIONS_FILE,
                    weight_file=ELMO_WEIGHT_FILE,
                    num_output_representations=1,
                    requires_grad=False,
                    dropout=0).to(DEVICE)
        elmo_dim = ELMO_EMBED_DIM
        print('Uploaded Elmo embeddings.')
    input_dim = glove_dim + elmo_dim

    # Small FNN splits used for fine-tuning.
    FNN_DL_small = {}
    for split in ['train', 'test', 'val']:
        dataset = FNNDataset(data_dir / ('FNN_small_' + split + '.pkl'),
                             GloVe_vectors, ELMo)
        FNN_DL_small[split] = data.DataLoader(dataset=dataset,
                                              batch_size=batch_size_fn,
                                              num_workers=0,
                                              shuffle=True,
                                              drop_last=True,
                                              collate_fn=PadSortBatchFNN())
    print('Uploaded FNN data.')

    print('Initializing the model for transfer learning...', end=' ')
    model = HierarchicalAttentionNet(input_dim=input_dim,
                                     sent_hidden_dim=sent_hidden_dim,
                                     doc_hidden_dim=doc_hidden_dim,
                                     num_classes=NUM_CLASSES_FN,
                                     dropout=0).to(DEVICE)
    print('Done!')
    print_model_parameters(model)
    print()
    print('Working on: ', end='')
    print(DEVICE)

    # Plain (unweighted) cross-entropy; class [0] is real, class [1] is fake.
    loss_func_fn = nn.CrossEntropyLoss()
    optimizer = optim.Adam(params=model.parameters(), lr=learning_rate)

    results = {
        'epoch': [],
        'train_loss': [],
        'train_accuracy': [],
        'val_loss': [],
        'val_accuracy': []
    }
    # Restore the pretrained weights (and optimizer state) from the source run.
    if epoch_no == '0':
        model_path = models_dir / 'HierarchicalAttentionNet_model.pt'
        _, _, _ = load_latest_checkpoint(model_path, model, optimizer)
    else:
        checkpoint_path = checkpoints_dir / (
            'HierarchicalAttentionNet_Adam_checkpoint_' + epoch_no + '_.pt')
        _, _, _ = load_checkpoint(checkpoint_path, model, optimizer)
    print(f'Starting transfer learning on the model extracted from {epoch_no}')

    for i in range(MAX_EPOCHS):
        print(f'Epoch {i+1:0{len(str(MAX_EPOCHS))}}/{MAX_EPOCHS}:')
        # one epoch of training
        model.train()
        train_loss_fn, train_acc_fn = train_epoch_fn(FNN_DL_small['train'],
                                                     model, optimizer,
                                                     loss_func_fn)

        # one epoch of eval
        model.eval()
        val_loss_fn, val_acc_fn = eval_epoch_fn(FNN_DL_small['val'], model,
                                                loss_func_fn)

        results['epoch'].append(i)
        results['train_loss'].append(train_loss_fn)
        results['train_accuracy'].append(train_acc_fn)
        results['val_loss'].append(val_loss_fn)
        results['val_accuracy'].append(val_acc_fn)
        # NOTE(review): this is the current epoch's validation accuracy, not
        # the best seen so far - confirm what create_checkpoint expects.
        best_accuracy = torch.tensor(val_acc_fn).max().item()
        create_checkpoint(checkpoints_dir_tl, i, model, optimizer, results,
                          best_accuracy)

    # save and plot the results
    save_results(results_dir_tl, results, model)
    save_model(models_dir_tl, model)
def train():
    """Train a HierarchicalAttentionNet on the full FNN splits, resuming if possible.

    All configuration comes from ``FLAGS``. The cross-entropy loss is weighted
    by the inverse class balance of ``FNN_train.pkl``. Training resumes from
    the latest checkpoint in the run's checkpoint directory.

    The learning rate is halved after every 4th epoch. A checkpoint is
    written after each epoch, and the results and model are saved at the end.

    Raises:
        ValueError: If ``FLAGS.data_dir`` does not exist.
    """
    model_type = FLAGS.model_type
    run_desc = FLAGS.run_desc
    data_dir = Path(FLAGS.data_dir)
    checkpoints_dir = Path(FLAGS.checkpoints_dir) / model_type / run_desc
    models_dir = Path(FLAGS.models_dir) / model_type / run_desc
    results_dir = Path(FLAGS.results_dir) / model_type / run_desc
    sent_hidden_dim = FLAGS.sent_hidden_dim
    doc_hidden_dim = FLAGS.doc_hidden_dim

    if not data_dir.exists():
        raise ValueError('Data directory does not exist')

    # create other directories if they do not exist
    create_directories(checkpoints_dir, models_dir, results_dir)

    # load the data
    print('Loading the data...')

    # Load the GloVe and/or ELMo embedders selected by substrings of model_type.
    glove_dim = 0
    elmo_dim = 0
    GloVe_vectors = None
    ELMo = None
    if 'glove' in model_type:
        GloVe_vectors = GloVe()
        glove_dim = WORD_EMBED_DIM
        print('Uploaded GloVe embeddings.')
    if 'elmo' in model_type:
        ELMo = Elmo(options_file=ELMO_OPTIONS_FILE,
                    weight_file=ELMO_WEIGHT_FILE,
                    num_output_representations=1,
                    requires_grad=False,
                    dropout=0).to(DEVICE)
        elmo_dim = ELMO_EMBED_DIM
        print('Uploaded Elmo embeddings.')
    input_dim = glove_dim + elmo_dim

    FNN = {}
    FNN_DL = {}
    for path in ['train', 'val', 'test']:
        FNN[path] = FNNDataset(data_dir / ('FNN_' + path + '.pkl'),
                               GloVe_vectors, ELMo)
        FNN_DL[path] = data.DataLoader(dataset=FNN[path],
                                       batch_size=BATCH_SIZE_FN,
                                       num_workers=0,
                                       shuffle=True,
                                       drop_last=True,
                                       collate_fn=PadSortBatchFNN())
    print('Uploaded FNN data.')

    print('Initializing the model...', end=' ')
    model = HierarchicalAttentionNet(input_dim=input_dim,
                                     sent_hidden_dim=sent_hidden_dim,
                                     doc_hidden_dim=doc_hidden_dim,
                                     num_classes=NUM_CLASSES_FN,
                                     dropout=0).to(DEVICE)
    print('Done!')
    print('Working on: ', end='')
    print(DEVICE)
    print_model_parameters(model)
    print()

    # Weight the loss by the opposite class's share of the training set:
    # class [0] is real, class [1] is fake.
    real_ratio, fake_ratio = get_class_balance(data_dir / 'FNN_train.pkl')
    weights = [(1.0 - real_ratio), (1.0 - fake_ratio)]
    print(weights)
    class_weights = torch.FloatTensor(weights).to(DEVICE)
    loss_func_fn = nn.CrossEntropyLoss(weight=class_weights)
    optimizer = optim.Adam(params=model.parameters(), lr=LEARNING_RATE)

    # Load the last checkpoint (if it exists) into model and optimizer.
    empty_results = {
        'epoch': [],
        'train_loss': [],
        'train_accuracy': [],
        'val_loss': [],
        'val_accuracy': []
    }
    epoch, loaded_results, best_accuracy = load_latest_checkpoint(
        checkpoints_dir, model, optimizer)
    # Keep the stored history when it has the expected layout; otherwise
    # start from an empty history instead of failing on a missing key.
    if isinstance(loaded_results, dict) and all(
            key in loaded_results for key in empty_results):
        results = loaded_results
    else:
        results = empty_results
    if epoch == 0:
        print(f'Starting training at epoch {epoch + 1}...')
    else:
        print(f'Resuming training from epoch {epoch + 1}...')

    for i in range(epoch, MAX_EPOCHS):
        print(f'Epoch {i+1:0{len(str(MAX_EPOCHS))}}/{MAX_EPOCHS}:')
        # one epoch of training
        model.train()
        train_loss_fn, train_acc_fn = train_epoch_fn(FNN_DL['train'], model,
                                                     optimizer, loss_func_fn)

        # one epoch of eval
        model.eval()
        val_loss_fn, val_acc_fn = eval_epoch_fn(FNN_DL['val'], model,
                                                loss_func_fn)

        results['epoch'].append(i)
        results['train_loss'].append(train_loss_fn)
        results['train_accuracy'].append(train_acc_fn)
        results['val_loss'].append(val_loss_fn)
        results['val_accuracy'].append(val_acc_fn)
        # NOTE(review): this is the current epoch's validation accuracy, not
        # the best seen so far - confirm what create_checkpoint expects.
        best_accuracy = torch.tensor(val_acc_fn).max().item()
        create_checkpoint(checkpoints_dir, i, model, optimizer, results,
                          best_accuracy)

        # Halve the learning rate after every 4th epoch. The existing
        # optimizer is adjusted in place, which keeps Adam's moment estimates
        # and, after a resume, continues from the restored learning rate
        # instead of jumping back to LEARNING_RATE / 2.
        if (i + 1) % 4 == 0:
            for param_group in optimizer.param_groups:
                param_group['lr'] /= 2

    # save and plot the results
    save_results(results_dir, results, model)
    save_model(models_dir, model)
def train():
    """Evaluate a saved multi-head HierarchicalAttentionNet on the FNN test split.

    Despite its name this function does not train. It restores the latest
    checkpoint from ``CHECKPOINTS_DIR_DEFAULT``, runs the fake-news head
    (``task='FN'``) over ``DATA_DIR_DEFAULT / 'FNN_test.pkl'`` and prints
    scikit-learn precision/recall/F-score/support.

    Per-class scores are printed every 100 batches; micro and macro averages
    are printed at the end.
    """
    # NOTE(review): FLAGS are not consulted here; the *_DEFAULT directories are
    # used directly - confirm that is intended.
    GloVe_vectors = GloVe()

    ELMo = Elmo(options_file=ELMO_OPTIONS_FILE,
                weight_file=ELMO_WEIGHT_FILE,
                num_output_representations=1,
                requires_grad=False,
                dropout=0).to(DEVICE)

    FNN_test = FNNDataset(DATA_DIR_DEFAULT / 'FNN_test.pkl', GloVe_vectors,
                          ELMo)
    # drop_last=False so the final partial batch is evaluated as well;
    # otherwise up to BATCH_SIZE_FN - 1 test samples are silently skipped.
    FNN_DL_test = data.DataLoader(dataset=FNN_test,
                                  batch_size=BATCH_SIZE_FN,
                                  num_workers=0,
                                  shuffle=True,
                                  drop_last=False,
                                  collate_fn=PadSortBatch())

    # presumably GloVe (300) + ELMo (1024) embedding sizes - TODO confirm
    # against WORD_EMBED_DIM / ELMO_EMBED_DIM.
    input_dim = 300 + 1024
    num_classes_nli = None

    model = HierarchicalAttentionNet(input_dim=input_dim,
                                     hidden_dim=WORD_HIDDEN_DIM,
                                     num_classes_task_fn=NUM_CLASSES_FN,
                                     embedding=None,
                                     num_classes_task_nli=num_classes_nli,
                                     dropout=0).to(DEVICE)
    # Not used for training; passed only so the checkpoint loader can restore it.
    optimizer = optim.Adam(params=model.parameters(), lr=LEARNING_RATE)

    _, _, _ = load_latest_checkpoint(CHECKPOINTS_DIR_DEFAULT, model, optimizer)
    model.eval()

    y_pred = []
    y_true = []
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        for step, batch in enumerate(FNN_DL_test):
            articles, article_dims, labels = batch
            out = model(batch=articles, batch_dims=article_dims, task='FN')
            # Collect the whole batch; .item() only works for a batch of one.
            y_pred.extend(out.argmax(dim=1).reshape(-1).tolist())
            y_true.extend(labels.reshape(-1).tolist())
            if step % 100 == 0 and step != 0:
                print(
                    sklearn.metrics.precision_recall_fscore_support(
                        y_true, y_pred, average=None))
    print(
        sklearn.metrics.precision_recall_fscore_support(y_true,
                                                        y_pred,
                                                        average='micro'))
    print(
        sklearn.metrics.precision_recall_fscore_support(y_true,
                                                        y_pred,
                                                        average='macro'))
def _mtl_loss_weights(train_loss_fn, train_loss_nli, loss_fn_weight_dataset,
                      loss_nli_weight_dataset, temperature):
    """Return the ``(fn_weight, nli_weight)`` pair used for the next MTL batch.

    Until both tasks have at least two recorded batch losses, the static
    dataset-size weights are returned unchanged. After that:

    1. Each task's ratio of its last two batch losses goes through a softmax
       with the given ``temperature``.
    2. Each result is multiplied by the task's dataset weight.
    3. A second softmax normalises the two products so the weights sum to 1.
    """
    if len(train_loss_fn) < 2 or len(train_loss_nli) < 2:
        return loss_fn_weight_dataset, loss_nli_weight_dataset

    loss_fn_exp = math.exp(
        (train_loss_fn[-1] / train_loss_fn[-2]) / temperature)
    loss_nli_exp = math.exp(
        (train_loss_nli[-1] / train_loss_nli[-2]) / temperature)
    balance_fn = loss_fn_exp / (loss_fn_exp + loss_nli_exp)
    balance_nli = loss_nli_exp / (loss_fn_exp + loss_nli_exp)

    scaled_fn = math.exp(loss_fn_weight_dataset * balance_fn)
    scaled_nli = math.exp(loss_nli_weight_dataset * balance_nli)
    total = scaled_fn + scaled_nli
    return scaled_fn / total, scaled_nli / total


def _train_epoch_mtl(fnn_loader, snli_loader, model, optimizer, loss_func_fn,
                     loss_func_nli, fnn_train_len, snli_train_len, temperature):
    """Run one multi-task epoch, randomly interleaving FNN and SNLI batches.

    FNN batches are drawn with a probability proportional to FNN's share of
    the total number of batches. Once one loader is exhausted, the remaining
    batches come from the other one. A batch whose training step raises is
    reported and skipped.

    Returns:
        ``(train_loss_fn, train_acc_fn, train_loss_nli, train_acc_nli)``:
        lists with one entry per successfully trained batch of each task.
    """
    train_loss_fn, train_acc_fn = [], []
    train_loss_nli, train_acc_nli = [], []

    # Static weights: each task gets the other task's share of training samples.
    total_len = fnn_train_len + snli_train_len
    loss_fn_weight_dataset = 1 - fnn_train_len / total_len
    loss_nli_weight_dataset = 1 - snli_train_len / total_len

    # Per-mille chance of picking an FNN batch (compared to randint(0, 1000)).
    fnn_batches = fnn_train_len / BATCH_SIZE_FN
    snli_batches = snli_train_len / BATCH_SIZE_NLI
    chance_fn = 1000 * fnn_batches / (fnn_batches + snli_batches)

    print(
        f'Train set length, FNN: {fnn_train_len}. Train set length, SNLI: {snli_train_len}.'
    )
    print(
        f'Training set to batch size ratio for Fake News Detection is {fnn_batches}.'
    )
    print(
        f'Training set to batch size ratio for Language Inference is {snli_batches}.'
    )

    iterator_fnn = enumerate(fnn_loader)
    iterator_snli = enumerate(snli_loader)
    done_fnn, done_snli = False, False
    step_fnn = 0
    step_snli = 0

    while not (done_fnn and done_snli):
        loss_fn_weight, loss_nli_weight = _mtl_loss_weights(
            train_loss_fn, train_loss_nli, loss_fn_weight_dataset,
            loss_nli_weight_dataset, temperature)

        # Once one loader is exhausted, draw from the other one directly
        # instead of repeatedly hitting StopIteration on the finished loader.
        use_fnn = done_snli or (not done_fnn and
                                np.random.randint(0, 1000) < chance_fn)
        if use_fnn:
            try:
                step_fnn, batch_fnn = next(iterator_fnn)
            except StopIteration:
                done_fnn = True
            else:
                try:
                    batch_loss_fn, batch_acc_fn = train_batch_fn(
                        batch_fnn, model, optimizer, loss_func_fn,
                        loss_fn_weight)
                except Exception as err:  # best effort: skip a failing batch
                    print(f'Error in batch: {err}')
                else:
                    train_loss_fn.append(batch_loss_fn)
                    train_acc_fn.append(batch_acc_fn)
                    if step_fnn % 50 == 0 and step_fnn != 0:
                        print(f'Processed {step_fnn} FNN batches.')
                        print(f'Accuracy: {batch_acc_fn}.')
                        print(
                            f'Weight for loss for NLI is {loss_nli_weight}, for loss for FN is {loss_fn_weight}.'
                        )
        else:
            try:
                step_snli, batch_snli = next(iterator_snli)
            except StopIteration:
                done_snli = True
            else:
                try:
                    batch_loss_nli, batch_acc_nli = train_batch_nli(
                        batch_snli, model, optimizer, loss_func_nli,
                        loss_nli_weight)
                except Exception as err:  # best effort: skip a failing batch
                    print(f'Error in batch: {err}')
                else:
                    train_loss_nli.append(batch_loss_nli)
                    train_acc_nli.append(batch_acc_nli)
                    if step_snli % 50 == 0 and step_snli != 0:
                        print(f'Processed {step_snli} SNLI batches.')
                        print(f'Accuracy: {batch_acc_nli}.')
                        print(
                            f'Weight for loss for NLI is {loss_nli_weight}, for loss for FN is {loss_fn_weight}.'
                        )
        print(f'FNN batch {step_fnn}')
        print(f'SNLI batch {step_snli}')

    return train_loss_fn, train_acc_fn, train_loss_nli, train_acc_nli


def train():
    """Train a single- or multi-task HierarchicalAttentionNet, resuming if possible.

    ``FLAGS.model_type`` selects the setup:

    - ``'STL'`` trains the fake-news head on the FNN data only.
    - ``'MTL'`` also trains the NLI head on SNLI, interleaving both tasks
      within each epoch.

    Only GloVe embeddings are used. A checkpoint is written after every
    epoch, and the results and model are saved and plotted at the end.

    Raises:
        ValueError: If ``FLAGS.model_type`` is neither ``'STL'`` nor ``'MTL'``,
            or if ``FLAGS.data_dir`` does not exist.
    """
    model_type = FLAGS.model_type
    run_desc = FLAGS.run_desc
    data_dir = Path(FLAGS.data_dir)
    checkpoints_dir = Path(FLAGS.checkpoints_dir) / model_type / run_desc
    models_dir = Path(FLAGS.models_dir) / model_type / run_desc
    results_dir = Path(FLAGS.results_dir) / model_type / run_desc

    # Only STL and MTL are implemented; fail early with a clear message
    # rather than with a NameError deep inside the training loop.
    if model_type not in ('STL', 'MTL'):
        raise ValueError(
            f"Unsupported model_type {model_type!r}: expected 'STL' or 'MTL'.")
    only_fn = model_type == 'STL'

    # check if data directory exists
    if not data_dir.exists():
        raise ValueError('Data directory does not exist')

    # create other directories if they do not exist
    create_directories(checkpoints_dir, models_dir, results_dir)

    # load the data
    print('Loading the data...')

    # Only GloVe embeddings are loaded in this variant (ELMo is disabled).
    GloVe_vectors = GloVe()
    print('Uploaded GloVe embeddings.')

    FNN = {}
    FNN_DL = {}
    for path in ['train', 'val', 'test']:
        FNN[path] = FNNDataset(data_dir / ('FNN_' + path + '.pkl'),
                               GloVe_vectors)
        FNN_DL[path] = data.DataLoader(dataset=FNN[path],
                                       batch_size=BATCH_SIZE_FN,
                                       num_workers=0,
                                       shuffle=True,
                                       drop_last=True,
                                       collate_fn=PadSortBatch())
    print('Uploaded FNN data.')
    if not only_fn:
        SNLI = {}
        SNLI_DL = {}
        for path in ['train', 'val', 'test']:
            SNLI[path] = SNLIDataset(data_dir / ('SNLI_' + path + '.pkl'),
                                     GloVe_vectors)
            SNLI_DL[path] = data.DataLoader(dataset=SNLI[path],
                                            batch_size=BATCH_SIZE_NLI,
                                            num_workers=0,
                                            shuffle=True,
                                            drop_last=True,
                                            collate_fn=PadSortBatchSNLI())
        print('Uploaded SNLI data.')
        snli_train_len = len(SNLI['train'])
    fnn_train_len = len(FNN['train'])

    # initialize the model, according to the model type
    print('Initializing the model...', end=' ')
    if only_fn:
        num_classes_nli = None
        print("Loading an STL HAN model.")
    else:
        num_classes_nli = 3
        print("Loading an MTL HAN model.")
    # ELMo is disabled here, so the input is the word embedding alone.
    input_dim = WORD_EMBED_DIM
    model = HierarchicalAttentionNet(input_dim=input_dim,
                                     hidden_dim=WORD_HIDDEN_DIM,
                                     num_classes_task_fn=NUM_CLASSES_FN,
                                     embedding=None,
                                     num_classes_task_nli=num_classes_nli,
                                     dropout=0).to(DEVICE)
    print('Done!')
    print('Working on: ', end='')
    print(DEVICE)
    print_model_parameters(model)
    print()

    # Weight the FN loss by the opposite class's share of the training set:
    # class [0] is real, class [1] is fake.
    real_ratio, fake_ratio = get_class_balance(data_dir / 'FNN_train.pkl')
    weights = [(1.0 - real_ratio), (1.0 - fake_ratio)]
    print(weights)
    class_weights = torch.FloatTensor(weights).to(DEVICE)
    loss_func_fn = nn.CrossEntropyLoss(weight=class_weights)
    if not only_fn:
        loss_func_nli = nn.CrossEntropyLoss()
        temperature = 2
    optimizer = optim.Adam(params=model.parameters(), lr=LEARNING_RATE)

    # Load the last checkpoint (if it exists) into model and optimizer.
    epoch, loaded_results, best_accuracy = load_latest_checkpoint(
        checkpoints_dir, model, optimizer)
    # Keep the checkpoint's history when it has the per-task layout written
    # below; otherwise start from empty per-task histories.
    if (isinstance(loaded_results, dict) and 'fn' in loaded_results
            and 'nli' in loaded_results):
        results = loaded_results
    else:
        metric_keys = ('epoch', 'train_loss', 'train_accuracy', 'val_loss',
                       'val_accuracy')
        results = {
            task: {key: [] for key in metric_keys}
            for task in ('fn', 'nli')
        }
    if epoch == 0:
        print(f'Starting training at epoch {epoch + 1}...')
    else:
        print(f'Resuming training from epoch {epoch + 1}...')

    for i in range(epoch, MAX_EPOCHS):
        print(f'Epoch {i+1:0{len(str(MAX_EPOCHS))}}/{MAX_EPOCHS}:')
        # one epoch of training
        model.train()
        if only_fn:
            train_loss_fn, train_acc_fn = train_epoch_fn(
                FNN_DL['train'], model, optimizer, loss_func_fn)
        else:
            (train_loss_fn, train_acc_fn, train_loss_nli,
             train_acc_nli) = _train_epoch_mtl(FNN_DL['train'],
                                               SNLI_DL['train'], model,
                                               optimizer, loss_func_fn,
                                               loss_func_nli, fnn_train_len,
                                               snli_train_len, temperature)

        # one epoch of eval
        model.eval()
        val_loss_fn, val_acc_fn = eval_epoch_fn(FNN_DL['val'], model,
                                                loss_func_fn)
        epoch_metrics = {
            'fn': (train_loss_fn, train_acc_fn, val_loss_fn, val_acc_fn)
        }
        if not only_fn:
            val_loss_nli, val_acc_nli = eval_epoch_nli(SNLI_DL['val'], model,
                                                       loss_func_nli)
            epoch_metrics['nli'] = (train_loss_nli, train_acc_nli,
                                    val_loss_nli, val_acc_nli)

        for task, (tr_loss, tr_acc, v_loss, v_acc) in epoch_metrics.items():
            results[task]['epoch'].append(i)
            results[task]['train_loss'].append(tr_loss)
            results[task]['train_accuracy'].append(tr_acc)
            results[task]['val_loss'].append(v_loss)
            results[task]['val_accuracy'].append(v_acc)
            print(results)

        # The fake-news task is the primary one, so its validation accuracy is
        # what gets stored for both STL and MTL runs.
        best_accuracy = torch.tensor(val_acc_fn).max().item()
        # Save under the current epoch index so every epoch gets its own
        # checkpoint and a resume continues after the last finished epoch.
        create_checkpoint(checkpoints_dir, i, model, optimizer, results,
                          best_accuracy)

    # save and plot the results
    save_results(results_dir, results, model)
    save_model(models_dir, model)
    plot_results(results_dir, results, model)
def test():
    """Evaluate a trained HierarchicalAttentionNet on the FNN test split.

    All configuration comes from ``FLAGS``. The weights come from run
    ``FLAGS.run_desc``:

    - When ``FLAGS.epoch`` is 0, they come from ``HierarchicalAttentionNet_model.pt``
      in the models directory.
    - Otherwise they come from the checkpoint saved for that epoch.

    Micro-averaged, macro-averaged and per-class precision, recall, F-score
    and support are printed.

    Raises:
        ValueError: If ``FLAGS.data_dir`` does not exist.
    """
    model_type = FLAGS.model_type
    run_desc = FLAGS.run_desc
    data_dir = Path(FLAGS.data_dir)
    checkpoints_dir = Path(FLAGS.checkpoints_dir) / model_type / run_desc
    models_dir = Path(FLAGS.models_dir) / model_type / run_desc
    results_dir = Path(FLAGS.results_dir) / model_type / run_desc
    # FLAGS.epoch may arrive as a string or an int; normalise it so the
    # comparison with '0' below works for both.
    epoch_no = str(FLAGS.epoch)
    sent_hidden_dim = FLAGS.sent_hidden_dim
    doc_hidden_dim = FLAGS.doc_hidden_dim

    if not data_dir.exists():
        raise ValueError('Data directory does not exist')

    # create other directories if they do not exist
    create_directories(checkpoints_dir, models_dir, results_dir)

    # load the data
    print('Loading the data...')

    # Load the GloVe and/or ELMo embedders selected by substrings of model_type.
    glove_dim = 0
    elmo_dim = 0
    GloVe_vectors = None
    ELMo = None
    if 'glove' in model_type:
        GloVe_vectors = GloVe()
        glove_dim = WORD_EMBED_DIM
        print('Uploaded GloVe embeddings.')
    if 'elmo' in model_type:
        ELMo = Elmo(options_file=ELMO_OPTIONS_FILE,
                    weight_file=ELMO_WEIGHT_FILE,
                    num_output_representations=1,
                    requires_grad=False,
                    dropout=0).to(DEVICE)
        elmo_dim = ELMO_EMBED_DIM
        print('Uploaded Elmo embeddings.')
    input_dim = glove_dim + elmo_dim

    FNN_test = FNNDataset(data_dir / 'FNN_test.pkl', GloVe_vectors, ELMo)
    # drop_last=False so the final partial batch is evaluated as well;
    # otherwise up to BATCH_SIZE_FN - 1 test samples are silently skipped.
    FNN_DL_test = data.DataLoader(dataset=FNN_test,
                                  batch_size=BATCH_SIZE_FN,
                                  num_workers=0,
                                  shuffle=True,
                                  drop_last=False,
                                  collate_fn=PadSortBatchFNN())
    print('Uploaded FNN data.')

    print('Initializing the model...', end=' ')
    model = HierarchicalAttentionNet(input_dim=input_dim,
                                     sent_hidden_dim=sent_hidden_dim,
                                     doc_hidden_dim=doc_hidden_dim,
                                     num_classes=NUM_CLASSES_FN,
                                     dropout=0).to(DEVICE)
    print('Done!')
    print('Working on: ', end='')
    print(DEVICE)
    print_model_parameters(model)
    print()

    # Not used for training; passed only so the checkpoint loaders can restore it.
    optimizer = optim.Adam(params=model.parameters(), lr=LEARNING_RATE)

    if epoch_no == '0':
        model_path = models_dir / 'HierarchicalAttentionNet_model.pt'
        _, _, _ = load_latest_checkpoint(model_path, model, optimizer)
    else:
        checkpoint_path = checkpoints_dir / (
            'HierarchicalAttentionNet_Adam_checkpoint_' + epoch_no + '_.pt')
        _, _, _ = load_checkpoint(checkpoint_path, model, optimizer)
    model.eval()

    y_pred = []
    y_true = []
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        for articles, article_dims, labels in FNN_DL_test:
            out = model(batch=articles, batch_dims=article_dims)
            # Collect the whole batch; .item() only works for a batch of one.
            y_pred.extend(out.argmax(dim=1).reshape(-1).tolist())
            y_true.extend(labels.reshape(-1).tolist())
    print(
        sklearn.metrics.precision_recall_fscore_support(y_true,
                                                        y_pred,
                                                        average='micro'))
    print(
        sklearn.metrics.precision_recall_fscore_support(y_true,
                                                        y_pred,
                                                        average='macro'))
    print(
        sklearn.metrics.precision_recall_fscore_support(y_true,
                                                        y_pred,
                                                        average=None))