Example #1
    def fit(self):
        """
        Trains the transformer-based neural ranker.
        """
        logging.info("Total batches per epoch : {}".format(len(self.train_loader)))
        logging.info("Validating every {} epoch.".format(self.validate_epochs))        
        val_ndcg = 0
        for epoch in range(self.num_epochs):
            for inputs in tqdm(self.train_loader, desc="Epoch {}".format(epoch), total=len(self.train_loader)):
                self.model.train()

                for k, v in inputs.items():
                    inputs[k] = v.to(self.device)                

                outputs = self.model(**inputs)
                last_hidden_states = outputs[0]
                loss_metric, err_pos, sparsity = self.metric_loss(last_hidden_states, inputs['labels'])

                if self.num_gpu > 1:
                    loss_metric = loss_metric.mean()

                loss_metric.backward()

                nn.utils.clip_grad_norm_(self.model.parameters(),
                                         self.max_grad_norm)
                self.optimizer.step()
                self.optimizer.zero_grad()

            if self.validate_epochs > 0 and epoch % self.validate_epochs == 0:
                logits, labels = self.predict(loader=self.val_loader)
                res = results_analyses_tools.evaluate_and_aggregate(logits, labels, ['ndcg_cut_10'])
                val_ndcg = res['ndcg_cut_10']
                if val_ndcg > self.best_ndcg:
                    self.best_ndcg = val_ndcg
                if self.sacred_ex is not None:
                    self.sacred_ex.log_scalar('eval_ndcg_10', val_ndcg, epoch+1)                    

            logging.info('Epoch {} val nDCG@10 {:.3f}'.format(epoch + 1, val_ndcg))
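The validation step above reports nDCG@10 through results_analyses_tools.evaluate_and_aggregate. As a point of reference, a simplified nDCG@k with linear gains (one common formulation; the per-query score/label layout is an assumption, not taken from the repository) could look like this:

import math

# Hypothetical, simplified stand-in for the ndcg_cut_10 aggregation above:
# scores and relevances are assumed to hold one query's candidate scores
# and graded relevance labels.
def ndcg_at_k(scores, relevances, k=10):
    ranked = [rel for _, rel in sorted(zip(scores, relevances),
                                       key=lambda p: p[0], reverse=True)]
    dcg = sum(rel / math.log2(i + 2) for i, rel in enumerate(ranked[:k]))
    ideal = sorted(relevances, reverse=True)
    idcg = sum(rel / math.log2(i + 2) for i, rel in enumerate(ideal[:k]))
    return dcg / idcg if idcg > 0 else 0.0

print(ndcg_at_k([0.9, 0.8, 0.1], [0, 1, 0]))  # ~0.63: the relevant candidate is ranked second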
Example #2
    def fit(self):
        """
        Trains the transformer-based neural ranker.
        """
        logging.info("Total batches per epoch : {}".format(len(self.train_loader)))
        logging.info("Validating every {} epoch.".format(self.validate_epochs))        
        val_ndcg = 0
        for epoch in range(self.num_epochs):
            for i, inputs in tqdm(enumerate(self.train_loader), desc="Epoch {}".format(epoch), total=len(self.train_loader)):
                self.model.train()
                for k, v in inputs.items():
                    if k != 'query':
                        inputs[k] = v.to(self.device)

                # outputs = self.model(**inputs)
                outputs = self.model(attention_mask=inputs['attention_mask'],
                                     input_ids=inputs['input_ids'],
                                     token_type_ids=inputs['token_type_ids'],
                                     labels=inputs['labels'])
                loss = outputs[0]  # with labels provided, the model returns (loss, logits, ...)

                if self.num_gpu > 1:
                    loss = loss.mean() 

                loss.backward()

                nn.utils.clip_grad_norm_(self.model.parameters(),
                                         self.max_grad_norm)
                self.optimizer.step()
                self.optimizer.zero_grad()

            if self.validate_epochs > 0 and epoch % self.validate_epochs == 0:
                logits, labels, all_ids, all_queries, all_logits_without_acc = self.predict(loader=self.val_loader)
                res = results_analyses_tools.evaluate_and_aggregate(logits, labels, ['ndcg_cut_10'])
                val_ndcg = res['ndcg_cut_10']
                if val_ndcg > self.best_ndcg:
                    self.best_ndcg = val_ndcg
                if self.sacred_ex is not None:
                    self.sacred_ex.log_scalar('eval_ndcg_10', val_ndcg, epoch+1)                    

            logging.info('Epoch {} val nDCG@10 {:.3f}'.format(epoch + 1, val_ndcg))
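Both fit() variants above follow the same PyTorch training-step recipe: switch to train mode, move the batch to the device, average the per-GPU losses when the model is wrapped in DataParallel, then back-propagate with gradient clipping. A minimal, self-contained sketch of that recipe, with a toy model and synthetic data standing in for the transformer ranker and its data loaders:

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

# Toy stand-in for the ranker: returns (loss, logits) like the HF models above.
class ToyRanker(nn.Module):
    def __init__(self):
        super().__init__()
        self.scorer = nn.Linear(16, 2)
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, features, labels):
        logits = self.scorer(features)
        return self.loss_fn(logits, labels), logits

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
num_gpu = torch.cuda.device_count()

model = ToyRanker().to(device)
if num_gpu > 1:
    model = nn.DataParallel(model)  # each replica returns its own loss
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
max_grad_norm = 1.0

loader = DataLoader(TensorDataset(torch.randn(64, 16),
                                  torch.randint(0, 2, (64,))), batch_size=8)

for epoch in range(2):
    for features, labels in loader:
        model.train()
        features, labels = features.to(device), labels.to(device)

        loss = model(features, labels)[0]
        if num_gpu > 1:
            loss = loss.mean()  # average the per-GPU losses gathered by DataParallel

        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        optimizer.step()
        optimizer.zero_grad()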
Example #3
def run_experiment(args):
    args.run_id = str(ex.current_run._id)

    tokenizer = BertTokenizer.from_pretrained(args.transformer_model)
    #Load datasets
    ## Conversation Response Ranking
    if args.task in ["mantis", "msdialog", "ubuntu_dstc8"]: 
        add_turn_separator = (args.task != "ubuntu_dstc8") # Ubuntu data has several utterances from same user in the context.
        train = preprocess_crr.read_crr_tsv_as_df(args.data_folder+args.task+"/train.tsv", args.sample_data, add_turn_separator)
        valid = preprocess_crr.read_crr_tsv_as_df(args.data_folder+args.task+"/valid.tsv", args.sample_data, add_turn_separator)
        special_tokens_dict = {'additional_special_tokens': ['[UTTERANCE_SEP]', '[TURN_SEP]'] }
        tokenizer.add_special_tokens(special_tokens_dict)
    ## Similar Question Retrieval and Passage Retrieval
    elif args.task in ["qqp", "linkso", "trec2020pr"]:
        if args.sample_data == -1:
            args.sample_data = None
        train = pd.read_csv(args.data_folder+args.task+"/train.tsv", sep="\t", nrows=args.sample_data)
        valid = pd.read_csv(args.data_folder+args.task+"/valid.tsv", sep="\t", nrows=args.sample_data)
    elif args.task=="scisumm":
        train, valid = preprocess_scisumm.transform_to_dfs("../data/Training-Set-2019/Task1/From-Training-Set-2018/")

    #Choose the negative candidate sampler
    document_col = train.columns[1]
    if args.train_negative_sampler == 'random':
        ns_train = negative_sampling.RandomNegativeSampler(list(train[document_col].values), args.num_ns_train)
    elif args.train_negative_sampler == 'bm25':
        ns_train = negative_sampling.BM25NegativeSamplerPyserini(list(train[document_col].values), args.num_ns_train, 
                    args.data_folder+"/"+args.task+"/anserini_train/", args.sample_data, args.anserini_folder)
    elif args.train_negative_sampler == 'sentenceBERT':
        ns_train = negative_sampling.SentenceBERTNegativeSampler(list(train[document_col].values), args.num_ns_train, 
                    args.data_folder+"/"+args.task+"/train_sentenceBERTembeds", args.sample_data, args.bert_sentence_model)

    if args.test_negative_sampler == 'random':
        ns_val = negative_sampling.RandomNegativeSampler(list(valid[document_col].values) + list(train[document_col].values), args.num_ns_eval)
    elif args.test_negative_sampler == 'bm25':
        ns_val = negative_sampling.BM25NegativeSamplerPyserini(list(valid[document_col].values) + list(train[document_col].values),
                    args.num_ns_eval, args.data_folder+"/"+args.task+"/anserini_valid/", args.sample_data, args.anserini_folder)
    elif args.test_negative_sampler == 'sentenceBERT':
        ns_val = negative_sampling.SentenceBERTNegativeSampler(list(valid[document_col].values) + list(train[document_col].values),
                    args.num_ns_eval, args.data_folder+"/"+args.task+"/valid_sentenceBERTembeds", args.sample_data, args.bert_sentence_model)

    #Create the loaders for the datasets, with the respective negative samplers
    dataloader = dataset.QueryDocumentDataLoader(train, valid, valid,
                                tokenizer, ns_train, ns_val,
                                'classification', args.train_batch_size, 
                                args.val_batch_size, args.max_seq_len, 
                                args.sample_data, args.data_folder + args.task)

    train_loader, val_loader, test_loader = dataloader.get_pytorch_dataloaders()


    #Instantiate transformer model to be used
    model = BertForSequenceClassification.from_pretrained(args.transformer_model)
    model.resize_token_embeddings(len(dataloader.tokenizer))

    #Instantiate trainer that handles fitting.
    trainer = transformer_trainer.TransformerTrainer(model, train_loader, val_loader, test_loader, 
                                 args.num_ns_eval, "classification", tokenizer,
                                 args.validate_every_epochs, args.num_validation_instances,
                                 args.num_epochs, args.lr, args.sacred_ex)

    #Train
    model_name = model.__class__.__name__
    logging.info("Fitting {} for {}{}".format(model_name, args.data_folder, args.task))
    trainer.fit()

    #Predict for test
    logging.info("Predicting")
    preds, labels = trainer.test()
    res = results_analyses_tools.evaluate_and_aggregate(preds, labels, ['R_10@1',
                    'R_10@2',
                    'R_10@5',
                    'R_2@1'])
    for metric, v in res.items():
        logging.info("Test {} : {:4f}".format(metric, v))

    #Saving predictions and labels to a file
    max_preds_column = max([len(l) for l in preds])
    preds_df = pd.DataFrame(preds, columns=["prediction_"+str(i) for i in range(max_preds_column)])
    preds_df.to_csv(args.output_dir+"/"+args.run_id+"/predictions.csv", index=False)

    labels_df = pd.DataFrame(labels, columns=["label_"+str(i) for i in range(max_preds_column)])
    labels_df.to_csv(args.output_dir+"/"+args.run_id+"/labels.csv", index=False)

    #Saving model to a file
    if args.save_model:
        torch.save(model.state_dict(), args.output_dir+"/"+args.run_id+"/model")

    #In case we want to get uncertainty estimations at prediction time
    if args.predict_with_uncertainty_estimation:  
        logging.info("Predicting with dropout.")      
        preds, uncertainties, labels, foward_passes_preds = trainer.test_with_dropout(args.num_foward_prediction_passes)
        res = results_analyses_tools.evaluate_and_aggregate(preds, labels, ['R_10@1'])
        for metric, v in res.items():
            logging.info("Test (w. dropout and {} foward passes) {} : {:4f}".format(args.num_foward_prediction_passes, metric, v))
        
        max_preds_column = max([len(l) for l in preds])
        preds_df = pd.DataFrame(preds, columns=["prediction_"+str(i) for i in range(max_preds_column)])
        preds_df.to_csv(args.output_dir+"/"+args.run_id+"/predictions_with_dropout.csv", index=False)

        for i, f_pass_preds in enumerate(foward_passes_preds):
            preds_df = pd.DataFrame(f_pass_preds, columns=["prediction_"+str(j) for j in range(max_preds_column)])
            preds_df.to_csv(args.output_dir+"/"+args.run_id+"/predictions_with_dropout_f_pass_{}.csv".format(i), index=False)

        labels_df = pd.DataFrame(labels, columns=["label_"+str(i) for i in range(max_preds_column)])
        labels_df.to_csv(args.output_dir+"/"+args.run_id+"/labels.csv", index=False)
        
        uncertainties_df = pd.DataFrame(uncertainties, columns=["uncertainty_"+str(i) for i in range(max_preds_column)])
        uncertainties_df.to_csv(args.output_dir+"/"+args.run_id+"/uncertainties.csv", index=False)

    return trainer.best_ndcg
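run_experiment evaluates the ranker with list-wise metrics such as R_10@1 (was the relevant candidate ranked first among its 10 candidates). A simplified stand-in for that part of results_analyses_tools, assuming preds[i] and labels[i] hold the scores and binary labels of query i's candidate list (the same layout written to predictions.csv and labels.csv above):

# Hypothetical, simplified recall@k over per-query candidate lists
# (R_10@1 corresponds to k=1 with 10 candidates per query).
def recall_at_k(preds, labels, k):
    hits = 0
    for scores, rels in zip(preds, labels):
        ranked = sorted(zip(scores, rels), key=lambda p: p[0], reverse=True)
        hits += int(any(rel > 0 for _, rel in ranked[:k]))
    return hits / len(preds)

preds = [[0.9, 0.1, 0.3], [0.2, 0.8, 0.5]]
labels = [[1, 0, 0], [1, 0, 0]]
print(recall_at_k(preds, labels, k=1))  # 0.5: only the first query's relevant doc is ranked first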
Example #4
    def fit(self):
        """
        Trains the transformer-based neural ranker.
        """
        logging.info("Total batches per epoch : {}".format(
            len(self.train_loader)))
        if self.validate_every_epochs > 0:
            logging.info("Validating every {} epoch.".format(
                self.validate_every_epochs))
        if self.validate_every_steps > 0:
            logging.info("Validating every {} step.".format(
                self.validate_every_steps))
        if self._has_wandb:
            wandb.watch(self.model)

        total_steps = 0
        total_loss = 0
        total_instances = 0

        if self.num_training_instances == -1:
            actual_epochs = self.num_epochs
        else:
            instances_in_one_epoch = len(
                self.train_loader) * self.train_loader.batch_size
            actual_epochs = -(-self.num_training_instances //
                              instances_in_one_epoch)  # rounding up
            logging.info(
                "Actual epochs (rounded up): {}".format(actual_epochs))

        for epoch in range(actual_epochs):
            epoch_batches_tqdm = tqdm(self.train_loader,
                                      desc="Epoch {}, steps".format(epoch),
                                      total=len(self.train_loader))
            for batch_inputs in epoch_batches_tqdm:
                self.model.train()

                for k, v in batch_inputs.items():
                    batch_inputs[k] = v.to(self.device)

                outputs = self.model(**batch_inputs)
                loss = outputs[0]

                if self.num_gpu > 1:
                    loss = loss.mean()

                loss.backward()
                total_loss += loss.item()

                nn.utils.clip_grad_norm_(self.model.parameters(),
                                         self.max_grad_norm)
                self.optimizer.step()
                self.optimizer.zero_grad()
                total_steps += 1
                total_instances += batch_inputs[k].shape[0]

                if self.num_training_instances != -1 and total_instances >= self.num_training_instances:
                    logging.info(
                        "Reached num_training_instances of {} ({} batches). Early stopping."
                        .format(self.num_training_instances, total_steps))
                    break

                #logging for steps
                is_validation_step = (self.validate_every_steps > 0
                                      and total_steps %
                                      self.validate_every_steps == 0)
                if is_validation_step:
                    logits, labels, _ = self.predict(loader=self.val_loader)
                    res = results_analyses_tools.evaluate_and_aggregate(
                        logits, labels, [self.validation_metric])
                    val_metric_res = res[self.validation_metric]
                    if val_metric_res > self.best_eval_metric:
                        self.best_eval_metric = val_metric_res
                    if self.sacred_ex is not None:
                        self.sacred_ex.log_scalar(
                            self.validation_metric + "_by_step",
                            val_metric_res, total_steps)
                        self.sacred_ex.log_scalar("avg_loss_by_step",
                                                  total_loss / total_steps,
                                                  total_steps)
                    if self._has_wandb:
                        wandb.log({
                            'step': total_steps,
                            self.validation_metric + "_by_step": val_metric_res
                        })
                        wandb.log({
                            'step': total_steps,
                            "avg_loss_by_step": total_loss / total_steps
                        })

                    epoch_batches_tqdm.set_description(
                        "Epoch {} ({}: {:3f}), steps".format(
                            epoch, self.validation_metric, val_metric_res))

            #logging for epochs
            is_validation_epoch = (self.validate_every_epochs > 0
                                   and epoch % self.validate_every_epochs == 0)
            if is_validation_epoch:
                logits, labels, _ = self.predict(loader=self.val_loader)
                res = results_analyses_tools.evaluate_and_aggregate(
                    logits, labels, [self.validation_metric])
                val_metric_res = res[self.validation_metric]
                if val_metric_res > self.best_eval_metric:
                    self.best_eval_metric = val_metric_res
                if self.sacred_ex is not None:
                    self.sacred_ex.log_scalar(
                        self.validation_metric + "_by_epoch", val_metric_res,
                        epoch + 1)
                    self.sacred_ex.log_scalar("avg_loss_by_epoch",
                                              total_loss / total_steps,
                                              epoch + 1)
                if self._has_wandb:
                    wandb.log({
                        'epoch': epoch + 1,
                        self.validation_metric + "_by_epoch": val_metric_res
                    })
                    wandb.log({
                        'epoch': epoch + 1,
                        "avg_loss_by_epoch": total_loss / total_steps
                    })
                epoch_batches_tqdm.set_description(
                    "Epoch {} ({}: {:3f}), steps".format(
                        epoch, self.validation_metric, val_metric_res))
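This fit() caps training by number of instances rather than epochs, converting the cap into an epoch count with the negative-floor-division idiom -(-a // b), which is ceiling division for positive integers. A tiny illustration with made-up numbers:

# Ceiling division via -(-a // b), as used for actual_epochs above.
num_training_instances = 1000           # made-up cap
instances_in_one_epoch = 30 * 8         # made-up: 30 batches of size 8

actual_epochs = -(-num_training_instances // instances_in_one_epoch)
assert actual_epochs == 5               # 4 epochs would cover only 960 instances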
Example #5
def run_experiment(args):
    args.run_id = str(ex.current_run._id)

    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    tokenizer = BertTokenizer.from_pretrained(args.transformer_model)
    # Conversation Response Ranking datasets needs special tokens
    if args.task in ["mantis", "msdialog", "ubuntu_dstc8"]:
        special_tokens_dict = {
            'additional_special_tokens': ['[UTTERANCE_SEP]', '[TURN_SEP]']
        }
        tokenizer.add_special_tokens(special_tokens_dict)

    #Load datasets
    train = pd.read_csv(
        args.data_folder + args.task + "/train.tsv",
        sep="\t",
        nrows=args.sample_data if args.sample_data != -1 else None)
    valid = pd.read_csv(
        args.data_folder + args.task + "/valid.tsv",
        sep="\t",
        nrows=args.sample_data if args.sample_data != -1 else None)

    #Choose the negative candidate sampler
    document_col = train.columns[1]
    if args.train_negative_sampler == 'random':
        ns_train = negative_sampling.RandomNegativeSampler(
            list(train[document_col].values), args.num_ns_train)
    elif args.train_negative_sampler == 'bm25':
        ns_train = negative_sampling.BM25NegativeSamplerPyserini(
            list(train[document_col].values), args.num_ns_train,
            args.data_folder + args.task + "/anserini_train/",
            args.sample_data, args.anserini_folder)
    elif args.train_negative_sampler == 'sentenceBERT':
        ns_train = negative_sampling.SentenceBERTNegativeSampler(
            list(train[document_col].values), args.num_ns_train,
            args.data_folder + args.task + "/train_sentenceBERTembeds",
            args.sample_data, args.bert_sentence_model)

    if args.test_negative_sampler == 'random':
        ns_val = negative_sampling.RandomNegativeSampler(
            list(valid[document_col].values) +
            list(train[document_col].values), args.num_ns_eval)
    elif args.test_negative_sampler == 'bm25':
        ns_val = negative_sampling.BM25NegativeSamplerPyserini(
            list(valid[document_col].values) +
            list(train[document_col].values), args.num_ns_eval,
            args.data_folder + args.task + "/anserini_valid/",
            args.sample_data, args.anserini_folder)
    elif args.test_negative_sampler == 'sentenceBERT':
        ns_val = negative_sampling.SentenceBERTNegativeSampler(
            list(valid[document_col].values) +
            list(train[document_col].values), args.num_ns_eval,
            args.data_folder + args.task + "/valid_sentenceBERTembeds",
            args.sample_data, args.bert_sentence_model)

    #Create the loaders for the datasets, with the respective negative samplers
    dataloader = dataset.QueryDocumentDataLoader(
        train, valid, valid, tokenizer, ns_train, ns_val, 'classification',
        args.train_batch_size, args.val_batch_size, args.max_seq_len,
        args.sample_data, args.data_folder + args.task)

    train_loader, val_loader, test_loader = dataloader.get_pytorch_dataloaders(
    )

    #Instantiate transformer model to be used
    model = pointwise_bert.BertForPointwiseLearning.from_pretrained(
        args.transformer_model,
        loss_function=args.loss_function,
        smoothing=args.smoothing)

    model.resize_token_embeddings(len(dataloader.tokenizer))

    #Instantiate trainer that handles fitting.
    trainer = transformer_trainer.TransformerTrainer(
        model,
        train_loader,
        val_loader,
        test_loader,
        args.num_ns_eval,
        "classification",
        tokenizer,
        args.validate_every_epochs,
        args.num_validation_batches,
        args.num_epochs,
        args.lr,
        args.sacred_ex,
        args.validate_every_steps,
        validation_metric='R_10@1',
        num_training_instances=args.num_training_instances)

    #Train
    model_name = model.__class__.__name__
    logging.info("Fitting {} for {}{}".format(model_name, args.data_folder,
                                              args.task))
    trainer.fit()

    #Predict for test
    logging.info("Predicting for the validation set.")
    preds, labels, softmax_logits = trainer.test()
    res = results_analyses_tools.evaluate_and_aggregate(
        preds, labels, ['R_10@1'])
    for metric, v in res.items():
        logging.info("Test {} : {:3f}".format(metric, v))
        wandb.log({'step': 0, "dev_" + metric: v})

    #Saving predictions and labels to a file
    max_preds_column = max([len(l) for l in preds])
    preds_df = pd.DataFrame(
        preds,
        columns=["prediction_" + str(i) for i in range(max_preds_column)])
    preds_df.to_csv(args.output_dir + "/" + args.run_id + "/predictions.csv",
                    index=False)

    softmax_df = pd.DataFrame(
        softmax_logits,
        columns=["prediction_" + str(i) for i in range(max_preds_column)])
    softmax_df.to_csv(args.output_dir + "/" + args.run_id +
                      "/predictions_softmax.csv",
                      index=False)

    labels_df = pd.DataFrame(
        labels, columns=["label_" + str(i) for i in range(max_preds_column)])
    labels_df.to_csv(args.output_dir + "/" + args.run_id + "/labels.csv",
                     index=False)

    #Saving model to a file
    if args.save_model:
        torch.save(model.state_dict(),
                   args.output_dir + "/" + args.run_id + "/model")

    #In case we want to get uncertainty estimations at prediction time
    if args.predict_with_uncertainty_estimation:
        logging.info("Predicting with MC dropout for the validation set.")
        preds, labels, softmax_logits, foward_passes_preds, uncertainties = trainer.test_with_dropout(
            args.num_foward_prediction_passes)
        res = results_analyses_tools.evaluate_and_aggregate(
            preds, labels, ['R_10@1'])
        for metric, v in res.items():
            logging.info(
                "Test (w. dropout and {} forward passes) {} : {:3f}".format(
                    args.num_foward_prediction_passes, metric, v))

        max_preds_column = max([len(l) for l in preds])
        preds_df = pd.DataFrame(
            preds,
            columns=["prediction_" + str(i) for i in range(max_preds_column)])
        preds_df.to_csv(args.output_dir + "/" + args.run_id +
                        "/predictions_with_dropout.csv",
                        index=False)

        softmax_df = pd.DataFrame(
            softmax_logits,
            columns=["prediction_" + str(i) for i in range(max_preds_column)])
        softmax_df.to_csv(args.output_dir + "/" + args.run_id +
                          "/predictions_with_dropout_softmax.csv",
                          index=False)

        for i, f_pass_preds in enumerate(foward_passes_preds):
            preds_df = pd.DataFrame(f_pass_preds,
                                    columns=[
                                        "prediction_" + str(j)
                                        for j in range(max_preds_column)
                                    ])
            preds_df.to_csv(
                args.output_dir + "/" + args.run_id +
                "/predictions_with_dropout_f_pass_{}.csv".format(i),
                index=False)

        labels_df = pd.DataFrame(
            labels,
            columns=["label_" + str(i) for i in range(max_preds_column)])
        labels_df.to_csv(args.output_dir + "/" + args.run_id + "/labels.csv",
                         index=False)

        uncertainties_df = pd.DataFrame(
            uncertainties,
            columns=["uncertainty_" + str(i) for i in range(max_preds_column)])
        uncertainties_df.to_csv(args.output_dir + "/" + args.run_id +
                                "/uncertainties.csv",
                                index=False)

    return trainer.best_eval_metric
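The predict_with_uncertainty_estimation branch above relies on MC dropout (trainer.test_with_dropout): dropout stays active at prediction time, several stochastic forward passes are averaged, and their spread is reported as uncertainty. A minimal sketch of that idea, with a toy model standing in for the ranker:

import torch
import torch.nn as nn

# Toy model standing in for the ranker; only the Dropout layer matters here.
model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Dropout(p=0.5), nn.Linear(32, 1))

def predict_with_dropout(model, inputs, num_passes=10):
    model.train()                    # keep dropout active at prediction time
    with torch.no_grad():
        passes = torch.stack([model(inputs) for _ in range(num_passes)])
    return passes.mean(dim=0), passes.var(dim=0)  # averaged prediction, per-example uncertainty

inputs = torch.randn(4, 16)
mean_pred, uncertainty = predict_with_dropout(model, inputs, num_passes=10)
print(mean_pred.shape, uncertainty.shape)  # torch.Size([4, 1]) torch.Size([4, 1])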
Example #6
def run_experiment(args):
    args.run_id = str(ex.current_run._id)

    tokenizer = BertTokenizer.from_pretrained(args.transformer_model)
    # Load datasets
    ## Conversation Response Ranking
    if args.task in ["mantis", "msdialog", "ubuntu_dstc8"]:
        add_turn_separator = (
            args.task != "ubuntu_dstc8"
        )  # Ubuntu data has several utterances from same user in the context.
        train = preprocess_crr.read_crr_tsv_as_df(
            args.data_folder + args.task + "/train.tsv", args.sample_data,
            add_turn_separator)
        valid = preprocess_crr.read_crr_tsv_as_df(
            args.data_folder + args.task + "/valid.tsv", args.sample_data,
            add_turn_separator)
        special_tokens_dict = {
            'additional_special_tokens': ['[UTTERANCE_SEP]', '[TURN_SEP]']
        }
        tokenizer.add_special_tokens(special_tokens_dict)
    ## Similar Question Retrieval and Passage Retrieval
    elif args.task in ["qqp", "linkso", "trec2020pr"]:
        if args.sample_data == -1:
            args.sample_data = None
        train = pd.read_csv(args.data_folder + args.task + "/train.tsv",
                            sep="\t",
                            nrows=args.sample_data)
        valid = pd.read_csv(args.data_folder + args.task + "/valid.tsv",
                            sep="\t",
                            nrows=args.sample_data)
    elif args.task == "scisumm":
        train, valid = preprocess_scisumm.transform_to_dfs(
            "../data/Training-Set-2019/Task1/From-Training-Set-2018/")
    elif args.task == "scisumm_ranked":
        train, valid, test = preprocess_scisumm_ranked.transform_to_dfs(
            args.path_to_ranked_file, args.path_to_ranked_test,
            args.path_to_ranked_dev)

    # Choose the negative candidate sampler
    document_col = train.columns[1]
    ns_train = None
    ns_val = None
    if args.train_negative_sampler == 'random':
        ns_train = negative_sampling.RandomNegativeSampler(
            list(train[document_col].values), args.num_ns_train)
    elif args.train_negative_sampler == 'bm25':
        ns_train = negative_sampling.BM25NegativeSamplerPyserini(
            list(train[document_col].values), args.num_ns_train,
            args.data_folder + "/" + args.task + "/anserini_train/",
            args.sample_data, args.anserini_folder)
    elif args.train_negative_sampler == 'sentenceBERT':
        ns_train = negative_sampling.SentenceBERTNegativeSampler(
            list(train[document_col].values), args.num_ns_train,
            args.data_folder + "/" + args.task + "/train_sentenceBERTembeds",
            args.sample_data, args.bert_sentence_model)
    if args.test_negative_sampler == 'random':
        ns_val = negative_sampling.RandomNegativeSampler(
            list(valid[document_col].values) +
            list(train[document_col].values), args.num_ns_eval)
    elif args.test_negative_sampler == 'bm25':
        ns_val = negative_sampling.BM25NegativeSamplerPyserini(
            list(valid[document_col].values) +
            list(train[document_col].values), args.num_ns_eval,
            args.data_folder + "/" + args.task + "/anserini_valid/",
            args.sample_data, args.anserini_folder)
    elif args.test_negative_sampler == 'sentenceBERT':
        ns_val = negative_sampling.SentenceBERTNegativeSampler(
            list(valid[document_col].values) +
            list(train[document_col].values), args.num_ns_eval,
            args.data_folder + "/" + args.task + "/valid_sentenceBERTembeds",
            args.sample_data, args.bert_sentence_model)

    # Create the loaders for the datasets, with the respective negative samplers
    dataloader = dataset.QueryDocumentDataLoader(
        train, valid, test, tokenizer, ns_train, ns_val, 'classification',
        args.train_batch_size, args.val_batch_size, args.max_seq_len,
        args.sample_data, args.data_folder + "/" + args.task)
    if args.task == "scisumm_ranked":
        with_ranked_list = True
    else:
        with_ranked_list = False
    train_loader, val_loader, test_loader = dataloader.get_pytorch_dataloaders(
        with_ranked_list)

    # Instantiate transformer model to be used
    model = BertForSequenceClassification.from_pretrained(
        args.transformer_model)
    model.resize_token_embeddings(len(dataloader.tokenizer))

    # Instantiate trainer that handles fitting.
    trainer = transformer_trainer.TransformerTrainer(
        model, train_loader, val_loader, test_loader, args.num_ns_eval,
        "classification", tokenizer, args.validate_every_epochs,
        args.num_validation_instances, args.num_epochs, args.lr,
        args.sacred_ex)

    # Train
    model_name = model.__class__.__name__
    logging.info("Fitting {} for {}{}".format(model_name, args.data_folder,
                                              args.task))
    trainer.fit()

    # Predict for test
    logging.info("Predicting")
    preds, labels, doc_ids, all_queries, preds_without_acc = trainer.validate()
    res = results_analyses_tools.evaluate_and_aggregate(
        preds, labels, [
            'R_10@1', 'R_10@2', 'R_10@5', 'R_2@1', 'accuracy_0.3',
            'accuracy_0.3_upto_1', 'precision_0.3', 'recall_0.3',
            'f_score_0.3', 'accuracy_0.4', 'accuracy_0.4_upto_1',
            'precision_0.4', 'recall_0.4', 'f_score_0.4', 'accuracy_0.5',
            'accuracy_0.5_upto_1', 'precision_0.5', 'recall_0.5', 'f_score_0.5'
        ])
    for metric, v in res.items():
        logging.info("Test {} : {:4f}".format(metric, v))

    # Saving predictions and labels to a file
    max_preds_column = max([len(l) for l in preds])
    preds_df = pd.DataFrame(
        preds,
        columns=["prediction_" + str(i) for i in range(max_preds_column)])
    preds_df.to_csv(args.output_dir + "/" + args.run_id + "/predictions.csv",
                    index=False)

    labels_df = pd.DataFrame(
        labels, columns=["label_" + str(i) for i in range(max_preds_column)])
    labels_df.to_csv(args.output_dir + "/" + args.run_id + "/labels.csv",
                     index=False)

    new_preds = list((np.array(preds_without_acc) > 0.3).astype(int))
    d = {
        'query': all_queries,
        'doc_id': doc_ids,
        'label': new_preds,
        'similiarity': preds_without_acc
    }

    df_doc_ids = pd.DataFrame(d)
    df_doc_ids_ones = df_doc_ids[df_doc_ids['label'] == 1]
    df_doc_ids_ones = df_doc_ids_ones.groupby('query').agg(list).reset_index()
    df_doc_ids_non_ones = df_doc_ids.groupby('query').agg(list).reset_index()
    new_df = []
    for i, row in df_doc_ids_non_ones.iterrows():
        if all([v == 0 for v in row['label']]):
            # No document cleared the 0.3 threshold for this query, so fall
            # back to the single most similar document (sort descending).
            highest_value = [
                x for _, x in sorted(zip(row['similiarity'], row['doc_id']),
                                     key=lambda pair: pair[0],
                                     reverse=True)
            ]
            highest_value_sim = sorted(row['similiarity'], reverse=True)

            row['label'] = [1]
            row['doc_id'] = [highest_value[0]]
            row['similiarity'] = [highest_value_sim[0]]

            new_df.append(row)

    result = pd.concat([df_doc_ids_ones, pd.DataFrame(new_df)])
    result.to_csv(args.output_dir + "/" + args.run_id + "/doc_ids_dev.csv",
                  index=False,
                  sep='\t')

    # predict on the test set
    preds, labels, doc_ids, all_queries, preds_without_acc = trainer.test()

    new_preds = list((np.array(preds_without_acc) > 0.3).astype(int))
    d = {
        'query': all_queries,
        'doc_id': doc_ids,
        'label': new_preds,
        'similiarity': preds_without_acc
    }

    df_doc_ids = pd.DataFrame(d)
    df_doc_ids_ones = df_doc_ids[df_doc_ids['label'] == 1]
    df_doc_ids_ones = df_doc_ids_ones.groupby('query').agg(list).reset_index()
    df_doc_ids_non_ones = df_doc_ids.groupby('query').agg(list).reset_index()
    new_df = []
    for i, row in df_doc_ids_non_ones.iterrows():
        if all([v == 0 for v in row['label']]):
            # No document cleared the 0.3 threshold for this query, so fall
            # back to the single most similar document (sort descending).
            highest_value = [
                x for _, x in sorted(zip(row['similiarity'], row['doc_id']),
                                     key=lambda pair: pair[0],
                                     reverse=True)
            ]
            highest_value_sim = sorted(row['similiarity'], reverse=True)

            row['label'] = [1]
            row['doc_id'] = [highest_value[0]]
            row['similiarity'] = [highest_value_sim[0]]

            new_df.append(row)

    result = pd.concat([df_doc_ids_ones, pd.DataFrame(new_df)])
    result.to_csv(args.output_dir + "/" + args.run_id + "/doc_ids_test.csv",
                  index=False,
                  sep='\t')

    # Saving model to a file
    if args.save_model:
        torch.save(model.state_dict(),
                   args.output_dir + "/" + args.run_id + "/model")

    # In case we want to get uncertainty estimations at prediction time
    if args.predict_with_uncertainty_estimation:
        logging.info("Predicting with dropout.")
        preds, uncertainties, labels, foward_passes_preds = trainer.test_with_dropout(
            args.num_foward_prediction_passes)
        res = results_analyses_tools.evaluate_and_aggregate(
            preds, labels, ['R_10@1'])
        for metric, v in res.items():
            logging.info(
                "Test (w. dropout and {} forward passes) {} : {:4f}".format(
                    args.num_foward_prediction_passes, metric, v))

        max_preds_column = max([len(l) for l in preds])
        preds_df = pd.DataFrame(
            preds,
            columns=["prediction_" + str(i) for i in range(max_preds_column)])
        preds_df.to_csv(args.output_dir + "/" + args.run_id +
                        "/predictions_with_dropout.csv",
                        index=False)

        for i, f_pass_preds in enumerate(foward_passes_preds):
            preds_df = pd.DataFrame(f_pass_preds,
                                    columns=[
                                        "prediction_" + str(j)
                                        for j in range(max_preds_column)
                                    ])
            preds_df.to_csv(
                args.output_dir + "/" + args.run_id +
                "/predictions_with_dropout_f_pass_{}.csv".format(i),
                index=False)

        labels_df = pd.DataFrame(
            labels,
            columns=["label_" + str(i) for i in range(max_preds_column)])
        labels_df.to_csv(args.output_dir + "/" + args.run_id + "/labels.csv",
                         index=False)

        uncertainties_df = pd.DataFrame(
            uncertainties,
            columns=["uncertainty_" + str(i) for i in range(max_preds_column)])
        uncertainties_df.to_csv(args.output_dir + "/" + args.run_id +
                                "/uncertainties.csv",
                                index=False)

    return trainer.best_ndcg
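The dev/test post-processing above labels a document as relevant when its similarity exceeds 0.3 and, for queries where no document clears the threshold, keeps only the single most similar one. A compact standalone version of that logic on toy data (column names kept as in the code above, including the 'similiarity' spelling):

import pandas as pd

df = pd.DataFrame({
    'query': ['q1', 'q1', 'q2', 'q2'],
    'doc_id': ['d1', 'd2', 'd3', 'd4'],
    'similiarity': [0.9, 0.2, 0.25, 0.1],
})

selected = df[df['similiarity'] > 0.3]                        # q1 -> d1
no_hit_queries = set(df['query']) - set(selected['query'])    # q2 has nothing above the threshold
fallback = (df[df['query'].isin(no_hit_queries)]
            .sort_values('similiarity', ascending=False)
            .groupby('query')
            .head(1))                                         # q2 -> d3, its most similar doc

result = pd.concat([selected, fallback]).assign(label=1)
print(result)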