Code Example #1
File: run_bert.py Project: LanJiang315/PR-project
def train(model, args):
    model_save_dir = os.path.join(args.result_dir, args.model_name)
    if not os.path.exists(model_save_dir):
        os.mkdir(model_save_dir)
    train_samples = load_pairwise_data(args, "train")
    train_dataloader = DataLoader(train_samples, shuffle=True, batch_size=args.train_batch_size)
    train_loss = losses.MultipleNegativesRankingLoss(model)

    # Development set: measure binary classification accuracy of cosine scores against the gold labels
    logging.info("Read validation dataset")
    dev_samples = load_pairwise_data(args, "val")
    binary_acc_evaluator = evaluation.BinaryClassificationEvaluator.from_input_examples(dev_samples, name='val')

    # Configure the training
    warmup_steps = math.ceil(len(train_dataloader) * args.num_epochs * 0.1)  # 10% of train data for warm-up
    logging.info("Warmup-steps: {}".format(warmup_steps))

    # Train the model
    model.fit(train_objectives=[(train_dataloader, train_loss)],
              evaluator=binary_acc_evaluator,
              epochs=args.num_epochs,
              evaluation_steps=1000,
              warmup_steps=warmup_steps,
              save_best_model=True,
              output_path=model_save_dir)
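
The snippet assumes load_pairwise_data returns InputExample pairs. A minimal, hypothetical sketch of that shape (toy data, not from the project): MultipleNegativesRankingLoss only needs (anchor, positive) pairs, and every other positive in a batch serves as an in-batch negative.

from sentence_transformers import InputExample

# Hypothetical pairwise samples of the kind load_pairwise_data is assumed to return:
train_samples = [
    InputExample(texts=["how do I reset my password", "steps to reset a password"]),
    InputExample(texts=["what is the refund policy", "how to request a refund"]),
]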
Code Example #2
def get_loss(loss_type, model):
    if loss_type == 'BatchAllTripletLoss':
        return losses.BatchAllTripletLoss(model=model)

    if loss_type == 'BatchHardSoftMarginTripletLoss':
        return losses.BatchHardSoftMarginTripletLoss(model=model)

    if loss_type == 'BatchHardTripletLoss':
        return losses.BatchHardTripletLoss(model=model)

    if loss_type == 'BatchSemiHardTripletLoss':
        return losses.BatchSemiHardTripletLoss(model=model)

    if loss_type == 'ContrastiveLoss':
        return losses.ContrastiveLoss(model=model)

    if loss_type == 'CosineSimilarityLoss':
        return losses.CosineSimilarityLoss(model=model)

    if loss_type == 'MegaBatchMarginLoss':
        return losses.MegaBatchMarginLoss(model=model)

    if loss_type == 'MultipleNegativesRankingLoss':
        return losses.MultipleNegativesRankingLoss(model=model)

    if loss_type == 'OnlineContrastiveLoss':
        return losses.OnlineContrastiveLoss(model=model)

    raise ValueError('Invalid loss type')
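
Since every branch above instantiates an argument-free loss the same way, an equivalent and more compact dispatch is a name-to-class registry. A sketch with identical behavior:

LOSS_REGISTRY = {
    'BatchAllTripletLoss': losses.BatchAllTripletLoss,
    'BatchHardSoftMarginTripletLoss': losses.BatchHardSoftMarginTripletLoss,
    'BatchHardTripletLoss': losses.BatchHardTripletLoss,
    'BatchSemiHardTripletLoss': losses.BatchSemiHardTripletLoss,
    'ContrastiveLoss': losses.ContrastiveLoss,
    'CosineSimilarityLoss': losses.CosineSimilarityLoss,
    'MegaBatchMarginLoss': losses.MegaBatchMarginLoss,
    'MultipleNegativesRankingLoss': losses.MultipleNegativesRankingLoss,
    'OnlineContrastiveLoss': losses.OnlineContrastiveLoss,
}

def get_loss(loss_type, model):
    if loss_type not in LOSS_REGISTRY:
        raise ValueError('Invalid loss type')
    return LOSS_REGISTRY[loss_type](model=model)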
Code Example #3
def cli_main():
    # Multilingual pre-trained model the author mentions in the GitHub issues: xlm-r-40langs-bert-base-nli-stsb-mean-tokens
    # Multilingual pre-trained model for information retrieval tasks: distilbert-multilingual-nli-stsb-quora-ranking
    model = SentenceTransformer(
        'distilbert-multilingual-nli-stsb-quora-ranking')

    num_epochs = 10
    train_batch_size = 64
    model_save_path = os.path.join(
        cur_dir, 'output/training_MultipleNegativesRankingLoss-' +
        datetime.now().strftime("%Y-%m-%d_%H-%M-%S"))
    os.makedirs(model_save_path, exist_ok=True)

    colab_dir = "/content/drive/My Drive/data/nlp"
    data_file = os.path.join(colab_dir, "LCCC-large.json")
    train_samples = get_data(data_file)

    # After reading the train_samples, we create a SentencesDataset and a DataLoader
    train_dataset = SentencesDataset(train_samples, model=model)
    train_dataloader = DataLoader(train_dataset,
                                  shuffle=True,
                                  batch_size=train_batch_size)
    train_loss = losses.MultipleNegativesRankingLoss(model)

    ###### Information Retrieval evaluation ######
    evaluators = []
    data_file = os.path.join(colab_dir, "STC.json")
    max_ir_num = 5000
    max_corpus_size = 100000
    ir_queries, ir_corpus, ir_relevant_docs = get_iq_corpus(
        data_file, max_ir_num, max_corpus_size)

    ir_evaluator = evaluation.InformationRetrievalEvaluator(
        ir_queries, ir_corpus, ir_relevant_docs)
    evaluators.append(ir_evaluator)
    seq_evaluator = evaluation.SequentialEvaluator(
        evaluators, main_score_function=lambda scores: scores[-1])

    logging.info("Evaluate model without training")
    seq_evaluator(model, epoch=0, steps=0, output_path=model_save_path)

    # Train the model
    model.fit(train_objectives=[(train_dataloader, train_loss)],
              evaluator=seq_evaluator,
              epochs=num_epochs,
              warmup_steps=1000,
              output_path=model_save_path,
              output_path_ignore_not_empty=True)
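
InformationRetrievalEvaluator expects id-keyed dictionaries; toy data illustrating the shapes get_iq_corpus is assumed to return:

# Illustrative toy shapes (assumed; not from the project):
ir_queries = {'q1': 'how to train a sentence encoder'}          # query id -> query text
ir_corpus = {'d1': 'training sentence encoders with ranking losses',
             'd2': 'an unrelated document'}                     # doc id -> doc text
ir_relevant_docs = {'q1': {'d1'}}                               # query id -> set of relevant doc ids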
Code Example #4
def train():
    # We construct the SentenceTransformer bi-encoder from scratch
    word_embedding_model = models.Transformer(model_name, max_seq_length=350)
    pooling_model = models.Pooling(
        word_embedding_model.get_word_embedding_dimension())
    model = SentenceTransformer(modules=[word_embedding_model, pooling_model])

    model_save_path = 'output/training_ms-marco_bi-encoder-' + model_name.replace(
        "/", "-") + '-' + datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

    # Read our training file. qidpidtriples consists of triplets (qid, positive_pid, negative_pid)
    train_filepath = os.path.join(
        data_folder, 'msmarco-qidpidtriples.rnd-shuf.train-eval.tsv')

    # Create the evaluator that is called during training
    queries = read_queries()
    corpus = read_corpus()
    dev_queries, dev_corpus, dev_rel_docs = prepare_data_for_evaluation(
        queries, corpus)
    ir_evaluator = evaluation.InformationRetrievalEvaluator(
        dev_queries, dev_corpus, dev_rel_docs, name='ms-marco-train_eval')

    # For training the SentenceTransformer model, we need a dataset, a dataloader, and a loss used for training.
    train_dataset = TripletsDataset(model=model,
                                    queries=queries,
                                    corpus=corpus,
                                    triplets_file=train_filepath)
    train_dataloader = DataLoader(train_dataset,
                                  shuffle=False,
                                  batch_size=train_batch_size)
    train_loss = losses.MultipleNegativesRankingLoss(model=model)

    # Train the model
    model.fit(train_objectives=[(train_dataloader, train_loss)],
              evaluator=ir_evaluator,
              epochs=1,
              warmup_steps=1000,
              output_path=model_save_path,
              evaluation_steps=5000,
              use_amp=True)
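
Since the qidpidtriples file yields (query, positive, negative) triplets, MultipleNegativesRankingLoss here also uses each explicit negative as a hard negative on top of the in-batch negatives. A toy triplet (hypothetical data):

from sentence_transformers import InputExample

triplet = InputExample(texts=[
    'what is the capital of france',       # query
    'Paris is the capital of France.',     # positive passage
    'Berlin is the capital of Germany.',   # hard negative passage
])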
Code Example #5
# (Assumed context: the listing begins mid-loop. The surrounding reader loop,
# reconstructed here so the fragment is self-contained, follows the
# sentence-transformers Quora duplicate-questions example this snippet derives from.)
train_samples_MultipleNegativesRankingLoss = []
train_samples_ConstrativeLoss = []

with open(os.path.join(qqp_dataset_path, "classification/train_pairs.tsv"),
          encoding='utf8') as fIn:
    reader = csv.DictReader(fIn, delimiter='\t', quoting=csv.QUOTE_NONE)
    for row in reader:
        train_samples_ConstrativeLoss.append(
            InputExample(texts=[row['question1'], row['question2']],
                         label=int(row['is_duplicate'])))
        if row['is_duplicate'] == '1':
            train_samples_MultipleNegativesRankingLoss.append(
                InputExample(texts=[row['question1'], row['question2']],
                             label=1))
            train_samples_MultipleNegativesRankingLoss.append(
                InputExample(texts=[row['question2'], row['question1']],
                             label=1)
            )  # if A is a duplicate of B, then B is a duplicate of A

# Create data loader and loss for MultipleNegativesRankingLoss
train_dataset_MultipleNegativesRankingLoss = SentencesDataset(
    train_samples_MultipleNegativesRankingLoss, model=model)
train_dataloader_MultipleNegativesRankingLoss = DataLoader(
    train_dataset_MultipleNegativesRankingLoss,
    shuffle=True,
    batch_size=train_batch_size)
train_loss_MultipleNegativesRankingLoss = losses.MultipleNegativesRankingLoss(
    model)

# Create data loader and loss for OnlineContrastiveLoss
train_dataset_ConstrativeLoss = SentencesDataset(train_samples_ConstrativeLoss,
                                                 model=model)
train_dataloader_ConstrativeLoss = DataLoader(train_dataset_ConstrativeLoss,
                                              shuffle=True,
                                              batch_size=train_batch_size)
train_loss_ConstrativeLoss = losses.OnlineContrastiveLoss(
    model=model, distance_metric=distance_metric, margin=margin)
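
The distance_metric and margin used just above come from an elided part of the script; plausible values, matching the sentence-transformers Quora duplicate-questions example this snippet appears to follow (an assumption), are:

# Assumed definitions from the elided part of the script:
distance_metric = losses.SiameseDistanceMetric.COSINE_DISTANCE
margin = 0.5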

################### Development Evaluators ##################
# We add 3 evaluators that evaluate the model on Duplicate Questions pair classification,
# Duplicate Questions Mining, and Duplicate Questions Information Retrieval
evaluators = []
logging.info(
    "Step 3: Train bi-encoder: {} over labeled QQP (target dataset)".format(
        model_name))

# Convert the dataset to a DataLoader ready for training
logging.info("Loading BERT labeled QQP dataset")
qqp_train_data = [
    InputExample(texts=[data[0], data[1]], label=score)
    for (data, score) in zip(silver_data, binary_silver_scores)
]

train_dataloader = DataLoader(qqp_train_data,
                              shuffle=True,
                              batch_size=batch_size)
train_loss = losses.MultipleNegativesRankingLoss(bi_encoder)

###### Classification ######
# Given (question1, question2), is this a duplicate or not?
# The evaluator will compute the embeddings for both questions and then compute
# a cosine similarity. If the similarity is above a threshold, we have a duplicate.
logging.info("Read QQP dev dataset")

dev_sentences1 = []
dev_sentences2 = []
dev_labels = []

with open(os.path.join(qqp_dataset_path, "classification/dev_pairs.tsv"),
          encoding='utf8') as fIn:
    reader = csv.DictReader(fIn, delimiter='\t', quoting=csv.QUOTE_NONE)
    for row in reader:
        # (Assumed completion: the listing truncates here; the loop fills the
        # dev lists declared above.)
        dev_sentences1.append(row['question1'])
        dev_sentences2.append(row['question2'])
        dev_labels.append(int(row['is_duplicate']))
Code Example #7
def my_own_train(use_wandb=False):

    model_name = 'msmarco-distilbert-base-v2'
    train_batch_size = 16  # uncommented: the DataLoader below needs it
    num_epochs = 8
    model_save_path = 'output/training_ms-marco_bi-encoder-' + model_name.replace(
        "/", "-") + '-' + datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

    # Load a pre-trained sentence transformer model
    model = SentenceTransformer(model_name)
    model.max_seq_length = 64
    # sequence lengths of 128 or more run out of memory during evaluation
    model.parallel_tokenization = False
    learning_rate = 5e-5
    warmup_steps = 3000
    evaluation_steps = 2000

    code_snippets_file = '../unif/data/parallel_bodies_n1000'
    descriptions_file = '../unif/data/parallel_desc_n1000'
    dataset = CodeDescRawTripletDataset(code_snippets_file, descriptions_file)
    train_data_loader = DataLoader(dataset,
                                   shuffle=True,
                                   batch_size=train_batch_size)
    train_loss = losses.MultipleNegativesRankingLoss(model=model)

    code_ids, desc_ids, relevant_docs = prepare_dataset_for_topn_evaluation(
        dataset)
    # ir_evaluator = InformationRetrievalEvaluator(desc_ids, code_ids, relevant_docs, name='all-train_eval')
    ir_evaluator = WandbInformationRetrievalEvaluator(desc_ids,
                                                      code_ids,
                                                      relevant_docs,
                                                      name='all-train_eval',
                                                      use_wandb=use_wandb)

    if use_wandb:
        wandb.init(project='code-search', name='sbert', reinit=True)
        config = wandb.config
        config.max_seq_length = model.max_seq_length
        config.learning_rate = learning_rate
        config.warmup_steps = warmup_steps
        config.evaluation_steps = evaluation_steps
        # config.embedding_size = embedding_size
        # config.evaluate_size = evaluate_size
        # config.margin = margin
        config.num_epochs = num_epochs
        config.train_size = len(dataset)
        wandb.watch(model)

    model.fit(
        train_objectives=[(train_data_loader, train_loss)],
        evaluator=ir_evaluator,
        epochs=num_epochs,
        # scheduler = 'warmupconstant',
        warmup_steps=warmup_steps,
        output_path=model_save_path,
        optimizer_params={'lr': learning_rate},
        evaluation_steps=evaluation_steps,
        use_amp=True)

    test_evaluator = InformationRetrievalEvaluator(desc_ids,
                                                   code_ids,
                                                   relevant_docs,
                                                   name='all-test_ir')
    test_evaluator(model)
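
WandbInformationRetrievalEvaluator is project-specific and not part of sentence-transformers; a minimal sketch of such a wrapper (assumed, not the project's actual class):

import wandb
from sentence_transformers.evaluation import InformationRetrievalEvaluator

class WandbInformationRetrievalEvaluator(InformationRetrievalEvaluator):
    # Sketch: forward the main IR score to wandb in addition to the usual CSV log.
    def __init__(self, *args, use_wandb=False, **kwargs):
        super().__init__(*args, **kwargs)
        self.use_wandb = use_wandb

    def __call__(self, model, output_path=None, epoch=-1, steps=-1):
        score = super().__call__(model, output_path=output_path, epoch=epoch, steps=steps)
        if self.use_wandb:
            wandb.log({'ir_score': score, 'epoch': epoch, 'steps': steps})
        return score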
Code Example #8
    def train(self, train_df, eval_df):
        """

        :param train_df: dataframe with columns 'text_a', 'text_b', 'labels'
        :param eval_df: dataframe with columns 'text_a', 'text_b', 'labels'
        :return:
        """

        # format training data
        if "text_a" in train_df.columns and "text_b" in train_df.columns and "labels" in train_df.columns:
            if self.args.do_lower_case:
                train_df.loc[:, 'text_a'] = train_df['text_a'].str.lower()
                train_df.loc[:, 'text_b'] = train_df['text_b'].str.lower()

            train_examples = [
                InputExample(str(i), texts=[text_a, text_b], label=label)
                for i, (text_a, text_b, label) in enumerate(
                    zip(
                        train_df["text_a"].astype(str),
                        train_df["text_b"].astype(str),
                        train_df["labels"].astype(int),
                    ))
            ]
        else:
            raise KeyError(
                'Training data processing - Required columns not found!')

        # format evaluation data
        if "text_a" in train_df.columns and "text_b" in train_df.columns and "labels" in eval_df.columns:
            if self.args.do_lower_case:
                eval_df.loc[:, 'text_a'] = eval_df['text_a'].str.lower()
                eval_df.loc[:, 'text_b'] = eval_df['text_b'].str.lower()

            evaluator = evaluation.BinaryClassificationEvaluator(
                list(eval_df["text_a"]),
                list(eval_df["text_b"]),
                list(eval_df["labels"].astype(int)),
                batch_size=self.args.eval_batch_size)
        else:
            raise KeyError(
                'Evaluation data processing - Required columns not found!')

        # Define train dataset, the dataloader and the train loss
        train_dataloader = DataLoader(train_examples,
                                      shuffle=True,
                                      batch_size=self.args.train_batch_size)
        if self.args.loss_func is not None and self.args.loss_func == 'MultipleNegativesRankingLoss':
            train_loss = losses.MultipleNegativesRankingLoss(self.model)
        else:
            distance_metric = losses.SiameseDistanceMetric.COSINE_DISTANCE
            train_loss = losses.OnlineContrastiveLoss(
                model=self.model,
                distance_metric=distance_metric,
                margin=self.args.margin)

        # Tune the model
        self.model.fit(
            train_objectives=[(train_dataloader, train_loss)],
            epochs=self.args.num_train_epochs,
            warmup_steps=self.args.warmup_steps,
            optimizer_params={'lr': self.args.learning_rate},
            weight_decay=self.args.weight_decay,
            evaluator=evaluator,
            evaluation_steps=self.args.evaluate_during_training_steps,
            max_grad_norm=self.args.max_grad_norm,
            output_path=self.args.best_model_dir,
            show_progress_bar=self.args.show_progress_bar)

        evaluation_file = os.path.join(self.args.best_model_dir,
                                       evaluator.csv_file)
        eval_results_df = pd.read_csv(evaluation_file)
        eval_results_df.sort_values(self.score_type,
                                    inplace=True,
                                    ascending=False,
                                    ignore_index=True)
        self.threshold = eval_results_df.loc[0, self.threshold_type]
        print(
            f'Set model threshold to {self.threshold}, achieving a {self.score_type} of {eval_results_df.loc[0, self.score_type]}'
        )

        return self.threshold
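
At inference time the tuned threshold gates duplicate decisions. A hedged usage sketch (the helper below is hypothetical, not part of the class):

from sentence_transformers import util

def is_duplicate(model, text_a, text_b, threshold):
    # Embed both texts and compare their cosine similarity to the tuned threshold
    # (util.cos_sim; named pytorch_cos_sim in older sentence-transformers releases).
    emb = model.encode([text_a, text_b], convert_to_tensor=True)
    return util.cos_sim(emb[0], emb[1]).item() >= threshold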
Code Example #9
def main():
    parser = argparse.ArgumentParser()

    # Input and output configs
    parser.add_argument("--task", default=None, type=str, required=True,
                        help="the task to run bert ranker for")
    parser.add_argument("--data_folder", default=None, type=str, required=True,
                        help="the folder containing data")
    parser.add_argument("--output_dir", default=None, type=str, required=True,
                        help="the folder to output predictions")
    parser.add_argument("--negative_sampler", default="random", type=str, required=False,
                        help="negative sampling procedure to use ['random', 'bm25', 'sentence_transformer']")
    parser.add_argument("--anserini_folder", default="", type=str, required=True,
                        help="Path containing the anserini bin <anserini_folder>/target/appassembler/bin/IndexCollection")
    parser.add_argument("--sentence_bert_ns_model", default="all-MiniLM-L6-v2", type=str, required=False,
                        help="model to use for sentenceBERT negative sampling.")

    parser.add_argument('--denoise_negatives', dest='denoise_negatives', action='store_true')
    parser.add_argument('--no-denoise_negatives', dest='denoise_negatives', action='store_false')
    parser.set_defaults(denoise_negatives=False)
    parser.add_argument("--num_ns_for_denoising", default=100, type=int, required=False,
                        help="Only used for --denoise_negatives. Number of total of samples to retrieve and get the bottom 10.")

    parser.add_argument("--generative_sampling_model", default="all-MiniLM-L6-v2", type=str, required=False,
                        help="model to use for generating negative samples on the go.")

    parser.add_argument('--remove_cand_subsets', dest='remove_cand_subsets', action='store_true')
    parser.add_argument('--dont_remove_cand_subsets', dest='remove_cand_subsets', action='store_false')
    parser.set_defaults(remove_cand_subsets=True)

    # Which part of the context to use when sampling negatives.
    parser.add_argument('--last_utterance_only', dest='last_utterance_only', action='store_true')
    parser.add_argument('--all_utterances', dest='last_utterance_only', action='store_false')
    parser.set_defaults(last_utterance_only=False)

    # External corpus to augment negative sampling
    parser.add_argument('--external_corpus', dest='use_external_corpus', action='store_true')
    parser.add_argument('--dont_use_external_corpus', dest='use_external_corpus', action='store_false')
    parser.set_defaults(use_external_corpus=False)

    # Training procedure
    parser.add_argument("--num_epochs", default=3, type=int, required=False,
                        help="Number of epochs for training.")
    parser.add_argument("--train_batch_size", default=8, type=int, required=False,
                        help="Training batch size.")
    # Model hyperparameters
    parser.add_argument("--transformer_model", default="bert-base-cased", type=str, required=False,
                        help="Bert model to use (default = bert-base-cased).")
    parser.add_argument("--loss", default='MultipleNegativesRankingLoss', type=str, required=False,
                        help="Loss function to use ['MultipleNegativesRankingLoss', 'TripletLoss', 'MarginMSELoss']")

    ## Wandb project name 
    parser.add_argument("--wandb_project", default='train_sentence_transformer', type=str, required=False,
                        help="name of the wandb project")
    parser.add_argument("--seed", default=42, type=int, required=False,
                        help="Random seed.")

    args = parser.parse_args()

    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    max_seq_length = 300
    if args.transformer_model == 'all-mpnet-base-v2' or args.transformer_model == 'msmarco-bert-base-dot-v5':
        model = SentenceTransformer(args.transformer_model)
        model.max_seq_length = max_seq_length
    else:
        word_embedding_model = models.Transformer(args.transformer_model, max_seq_length=max_seq_length)
        tokens = ['[UTTERANCE_SEP]', '[TURN_SEP]', '[AUG]']
        word_embedding_model.tokenizer.add_tokens(tokens, special_tokens=True)
        word_embedding_model.auto_model.resize_token_embeddings(len(word_embedding_model.tokenizer))
        pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension(),
                                pooling_mode_mean_tokens=True,
                                pooling_mode_cls_token=False,
                                pooling_mode_max_tokens=False)
        model = SentenceTransformer(modules=[word_embedding_model, pooling_model])


    eval_only = False
    if eval_only:
        logging.info("Skipping training (eval_only=True)")
    
    else:
        logging.info("Creating train CRR dataset for {} using {}.".format(args.task, args.negative_sampler))
        crr_reader = CRRBenchmarkDataReader('{}/{}'.format(args.data_folder, args.task))
        train_data = crr_reader.get_examples("train.tsv", args.negative_sampler,
                                    args.anserini_folder, args.sentence_bert_ns_model, args.loss, args.output_dir,
                                    True, False,
                                    args.denoise_negatives, args.num_ns_for_denoising,
                                    args.generative_sampling_model,
                                    args.remove_cand_subsets,
                                    args.last_utterance_only,
                                    args.use_external_corpus)
        train_dataloader = DataLoader(train_data, shuffle=True, batch_size=args.train_batch_size)
    
    if args.loss == 'MultipleNegativesRankingLoss':
        train_loss = losses.MultipleNegativesRankingLoss(model=model, similarity_fct=util.dot_score)
    elif args.loss == 'MarginMSELoss':
        train_loss = losses.MarginMSELoss(model=model)
    elif args.loss == 'TripletLoss':
        train_loss = losses.TripletLoss(model=model)
    elif args.loss == 'ContrastiveLoss':
        train_loss = losses.ContrastiveLoss(model=model)
    elif args.loss == 'OnlineContrastiveLoss':
        train_loss = losses.OnlineContrastiveLoss(model=model)
    else:
        raise ValueError('Unsupported loss: {}'.format(args.loss))


    ns_description = args.negative_sampler
    if args.negative_sampler == 'sentence_transformer':
        ns_description+="_{}".format(args.sentence_bert_ns_model)

    if args.negative_sampler == 'generative':
        ns_description+="_{}".format(args.generative_sampling_model)

    wandb.init(project=args.wandb_project)
    wandb.config.update(args)

    if not eval_only:  # dev data used during training, not the final evaluation
        logging.info("Getting eval data")
        examples_dev = crr_reader.get_examples('valid.tsv', 
            args.negative_sampler, args.anserini_folder, args.sentence_bert_ns_model, args.loss, args.output_dir, eval_data=True)
        examples_dev = examples_dev[0:(11*500)]
        eval_samples = []
        docs = []
        for i, example in enumerate(examples_dev):
            if (i+1)%11==0:
                eval_samples.append({'query': example.texts[0], 
                                    'positive': [example.texts[1]],
                                    'negative': docs
                })
                docs=[]
            else:
                docs.append(example.texts[2])
        evaluator = RerankingEvaluator(eval_samples, write_csv=True, similarity_fct=util.dot_score)
        warmup_steps = math.ceil(len(train_data)*args.num_epochs/args.train_batch_size*0.1) #10% of train data for warm-up
        logging.info("Warmup-steps: {}".format(warmup_steps))

        logging.info("Fitting sentenceBERT for {}".format(args.task))

        model.fit(train_objectives=[(train_dataloader, train_loss)],
            evaluator=evaluator,
            epochs=args.num_epochs,
            evaluation_steps=100,          
            steps_per_epoch=10000,        
            warmup_steps=warmup_steps,
            output_path=args.output_dir+"{}_{}_ns_{}_loss_{}".format(args.transformer_model, args.task, ns_description, args.loss))

    logging.info("Evaluating for full retrieval of responses to dialogue.")

    train = pd.read_csv(args.data_folder+args.task+"/train.tsv", sep="\t")
    test = pd.read_csv(args.data_folder+args.task+"/test.tsv", sep="\t")

    ns_test_sentenceBERT = negative_sampling.SentenceBERTNegativeSampler(list(train["response"].values)+list(test["response"].values), 10, 
                   args.data_folder+args.task+"/test_sentenceBERTembeds", -1, 
                   args.output_dir+"{}_{}_ns_{}_loss_{}".format(args.transformer_model, args.task, ns_description, args.loss),
                   use_cache_for_embeddings=False)
    
    ns_info = [
        (ns_test_sentenceBERT, 
        ["cand_sentenceBERT_{}".format(i) for i in range(10)] + ["sentenceBERT_retrieved_relevant", "sentenceBERT_rank"], 
        'sentenceBERT')
    ]
    examples = []
    examples_cols = ["context", "relevant_response"] + \
        reduce(lambda x,y:x+y, [t[1] for t in ns_info])
    logging.info("Retrieving candidates using different negative sampling strategies for {}.".format(args.task))
    recall_df = []
    for idx, row in enumerate(tqdm(test.itertuples(index=False), total=len(test))):
        context = row[0]
        relevant_response = row[1]
        instance = [context, relevant_response]

        for ns, _, ns_name in ns_info:
            ns_candidates, scores, had_relevant, rank_relevant, _ = ns.sample(context, [relevant_response])
            for candidate in ns_candidates:  # renamed from 'ns' to avoid shadowing the sampler
                instance.append(candidate)
            instance.append(had_relevant)
            instance.append(rank_relevant)
            r10 = 1 if had_relevant else 0
            r1 = 1 if rank_relevant == 0 else 0
            recall_df.append([r10, r1])
        examples.append(instance)

    recall_df = pd.DataFrame(recall_df, columns=["R@10", "R@1"])
    examples_df = pd.DataFrame(examples, columns=examples_cols)
    logging.info("R@10: {}".format(examples_df[[c for c in examples_df.columns if 'retrieved_relevant' in c]].sum()/examples_df.shape[0]))
    wandb.log({'R@10': (examples_df[[c for c in examples_df.columns if 'retrieved_relevant' in c]].sum()/examples_df.shape[0]).values[0]})
    rank_col = [c for c in examples_df.columns if 'rank' in c][0]
    logging.info("R@1: {}".format(examples_df[examples_df[rank_col]==0].shape[0]/examples_df.shape[0]))
    wandb.log({'R@1': examples_df[examples_df[rank_col]==0].shape[0]/examples_df.shape[0]})
    recall_df.to_csv(args.output_dir+"/recall_df_{}_{}_ns_{}_loss_{}.csv".format(args.transformer_model.replace("/", "-"), args.task, ns_description.replace("/", "-"), args.loss), index=False, sep="\t")
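
RerankingEvaluator consumes dictionaries with 'query', 'positive', and 'negative' keys, which the grouping loop above assembles from blocks of 11 dev examples (one positive plus ten negatives per query). A toy sample (hypothetical data):

eval_sample = {
    'query': 'any plans for the weekend?',
    'positive': ['we could go hiking on saturday'],
    'negative': ['the invoice is attached', 'my flight lands at noon'],
}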
Code Example #10
File: sentence_transformer_cli.py Project: j5bd/q
def train(model_name_or_path: str,
          hf_dataset: str,
          aspect: str,
          fold: Union[int, str],
          output_path: str,
          train_epochs: int = 3,
          train_batch_size: int = 25,
          eval_batch_size: int = 32,
          evaluation_steps: int = 5000,
          train_on_test: bool = False,
          loss: str = 'multiple_negatives_ranking',
          override: bool = False):
    """

    # $MODEL_NAME $HF_DATASET $ASPECT $FOLD $OUTPUT_DIR --train_epochs=3 --train_batch_size=$TRAIN_BATCH_SIZE --eval_batch_size=$EVAL_BATCH_SIZE

    Run with:
    $ export CUDA_VISIBLE_DEVICES=1
    $ ./sentence_transformer_cli.py train scibert-scivocab-uncased paperswithcode_task_docs 1 ./output/st_scibert/1 --train_epochs=3 --train_batch_size=25 --eval_batch_size=32


    :param loss: Training loss function (choices: multiple_negatives_ranking, cosine)
    :param train_on_test: If True, joint training on train and test set (validation disabled)
    :param aspect:
    :param evaluation_steps:
    :param train_epochs:
    :param model_name_or_path:
    :param hf_dataset:
    :param fold:
    :param output_path:
    :param train_batch_size:
    :param eval_batch_size:
    :param override:
    :return:
    """

    top_ks = [5, 10, 25, 50]
    # cuda_device = -1

    # hf_dataset = 'paperswithcode_task_docs'
    # model_name_or_path = 'scibert-scivocab-uncased'
    # fold = 1
    max_token_length = 336  # see pwc_token_stats.ipynb
    nlp_cache_dir = './data/nlp_cache'

    # train_batch_size = 25
    # eval_batch_size = 32
    # override = False

    # output_path = './output/pwc_task_st/1/sci-bert'
    # output_path = os.path.join(output_path, str(fold), model_name_or_path)  # output/1/sci-bert

    if os.path.exists(output_path) and not override:
        logger.error(f'Stop. Output path exists already: {output_path}')
        sys.exit(1)

    # if cuda_device >= 0:
    #     os.environ["CUDA_VISIBLE_DEVICES"] = str(cuda_device)

    # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Model path from env
    if not os.path.exists(model_name_or_path) and os.path.exists(
            os.path.join(env['bert_dir'], model_name_or_path)):
        model_name_or_path = os.path.join(env['bert_dir'], model_name_or_path)

    word_embedding_model = Transformer(model_name_or_path,
                                       max_seq_length=max_token_length)
    pooling_model = Pooling(
        word_embedding_model.get_word_embedding_dimension())

    model = SentenceTransformer(modules=[word_embedding_model, pooling_model])
    # tokenizer = BertTokenizer.from_pretrained(model_name_or_path)

    # dataset
    docs_ds = load_dataset(get_local_hf_dataset_path(hf_dataset),
                           name='docs',
                           cache_dir=nlp_cache_dir,
                           split='docs')
    train_ds = load_dataset(get_local_hf_dataset_path(hf_dataset),
                            name='relations',
                            cache_dir=nlp_cache_dir,
                            split=get_train_split(aspect, fold))
    test_ds = load_dataset(get_local_hf_dataset_path(hf_dataset),
                           name='relations',
                           cache_dir=nlp_cache_dir,
                           split=get_test_split(aspect, fold))

    # filter for positive labels only
    train_ds = train_ds.filter(lambda row: row['label'] == 'y')

    logger.info(f'After filtering: {len(train_ds):,}')

    # joint training on train and test?
    if train_on_test:
        #
        # import pyarrow
        # from datasets.arrow_dataset import Dataset
        #
        # full_ds_table = pyarrow.concat_tables([train_ds.data, test_ds.data])
        # full_ds = Dataset(arrow_table=full_ds_table)
        raise NotImplementedError('TODO Evaluator')
    else:
        # standard training on the train split only
        train_sds = DocumentPairSentencesDataset(docs_ds,
                                                 train_ds,
                                                 model,
                                                 max_length=max_token_length,
                                                 forced_length=0)
        train_sds.tokenize_all_docs()

        evaluator = NearestNeighborsEvaluator(model,
                                              docs_ds,
                                              test_ds,
                                              top_ks=top_ks,
                                              batch_size=eval_batch_size,
                                              show_progress_bar=True)

    if loss == 'cosine':
        train_loss = losses.CosineSimilarityLoss(model)
    elif loss == 'multiple_negatives_ranking':
        # A nice advantage of MultipleNegativesRankingLoss is that it only requires positive pairs
        # https://github.com/UKPLab/sentence-transformers/tree/master/examples/training/quora_duplicate_questions
        train_loss = losses.MultipleNegativesRankingLoss(model)
    else:
        raise ValueError(f'Unsupported loss function: {loss}')

    train_dl = DataLoader(train_sds, shuffle=True, batch_size=train_batch_size)

    # Training
    model.fit(
        train_objectives=[(train_dl, train_loss)],
        epochs=train_epochs,  # try 1-4
        warmup_steps=100,
        evaluator=evaluator,
        evaluation_steps=evaluation_steps,  # increase to 5000 (full dataset => 20k steps)
        output_path=output_path,
        output_path_ignore_not_empty=True)

    logger.info('Training done')
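
As the comment above notes, MultipleNegativesRankingLoss needs only positive pairs. A self-contained toy run (the checkpoint name is an arbitrary public model, not from the project):

from torch.utils.data import DataLoader
from sentence_transformers import SentenceTransformer, InputExample, losses

model = SentenceTransformer('paraphrase-MiniLM-L6-v2')
pairs = [InputExample(texts=['a query', 'its matching document']),
         InputExample(texts=['another query', 'another matching document'])]
loader = DataLoader(pairs, shuffle=True, batch_size=2)
loss = losses.MultipleNegativesRankingLoss(model)
model.fit(train_objectives=[(loader, loss)], epochs=1, warmup_steps=0)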