Example #1
import json
import math
import os
import random

import numpy as np
from torch.utils.data import DataLoader
from tqdm import tqdm
from sklearn.model_selection import train_test_split
from sentence_transformers import InputExample, SentencesDataset, losses
from sentence_transformers.evaluation import BinaryClassificationEvaluator


def get_loss(loss_type, model):
    """Return the sentence-transformers loss matching loss_type."""
    loss_classes = {
        'BatchAllTripletLoss': losses.BatchAllTripletLoss,
        'BatchHardSoftMarginTripletLoss': losses.BatchHardSoftMarginTripletLoss,
        'BatchHardTripletLoss': losses.BatchHardTripletLoss,
        'BatchSemiHardTripletLoss': losses.BatchSemiHardTripletLoss,
        'ContrastiveLoss': losses.ContrastiveLoss,
        'CosineSimilarityLoss': losses.CosineSimilarityLoss,
        'MegaBatchMarginLoss': losses.MegaBatchMarginLoss,
        'MultipleNegativesRankingLoss': losses.MultipleNegativesRankingLoss,
        'OnlineContrastiveLoss': losses.OnlineContrastiveLoss,
    }
    if loss_type not in loss_classes:
        raise ValueError('Invalid loss type: {}'.format(loss_type))
    return loss_classes[loss_type](model=model)
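
A minimal usage sketch for get_loss; the checkpoint name below is an assumption for illustration, not part of the original code:

# Hedged usage sketch; 'all-MiniLM-L6-v2' is an assumed checkpoint name.
from sentence_transformers import SentenceTransformer

model = SentenceTransformer('all-MiniLM-L6-v2')
loss = get_loss('ContrastiveLoss', model)  # -> losses.ContrastiveLoss(model=model)
# An unknown name raises ValueError, e.g. get_loss('FooLoss', model).
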
def finetune_sbert(model, df, rep_sents, finetune_cfg):
    """Finetune the Sentence-BERT."""
    # setup
    train_size = finetune_cfg.get("train_size", 200000)
    sample_per_pair = finetune_cfg.get("sample_per_pair", 5)
    train_batch_size = finetune_cfg.get("train_batch_size", 32)
    epochs = finetune_cfg.get("epochs", 1)
    train = []
    n_sampled = 0
    cnts = [0, 0]  # [neg, pos]
    max_label_size = train_size // 2
    genres = df.genres.apply(set)

    with tqdm(total=train_size, position=0) as pbar:
        # sample sentence pairs
        while n_sampled < train_size:
            id1, id2 = np.random.randint(0, len(df), 2)
            label = int(bool(set.intersection(genres[id1], genres[id2])))

            if cnts[label] > max_label_size:
                continue

            # Build the full cross product: every representative sentence
            # of row id1 paired with every representative sentence of row id2.
            sent_pairs = np.stack(np.meshgrid(rep_sents[id1],
                                              rep_sents[id2])).T.reshape(
                                                  -1, 2)
            if len(sent_pairs) <= sample_per_pair:
                samples = sent_pairs
            else:
                samples_idx = np.random.choice(sent_pairs.shape[0],
                                               sample_per_pair,
                                               replace=False)
                samples = sent_pairs[samples_idx]

            samples = [InputExample(texts=list(pair), label=label)
                       for pair in samples]
            train.extend(samples)

            n_sampled += len(samples)
            cnts[label] += len(samples)
            pbar.update(len(samples))

        # run finetune
        train_ds = SentencesDataset(train, model)
        train_obj = (
            DataLoader(train_ds, shuffle=True, batch_size=train_batch_size),
            losses.ContrastiveLoss(model=model),
        )
        model.fit(train_objectives=[train_obj],
                  epochs=epochs,
                  warmup_steps=100)
        os.makedirs("model/clustering/sbert", exist_ok=True)
        model.save("model/clustering/sbert")
def create_linked_posts(fl, data_dir, model, validate=None, is_test=False):
    """Build the dataloader, loss, evaluator and warm-up steps for the
    linked-posts task.

    Relies on the module-level config globals max_size, linked_posts_str,
    num_epochs and batch_size.
    """
    train_linked_posts = []
    disbn = []

    with open(os.path.join(data_dir, fl)) as f:
        data = json.load(f)
        for obj in data:
            label = 1 if obj['class'] == 'relevant' else 0
            disbn.append(label)

            train_linked_posts.append(
                InputExample(texts=[obj['text_1'], obj['text_2']],
                             label=label))
    random.shuffle(train_linked_posts)

    if is_test:
        return train_linked_posts

    if max_size:
        train_linked_posts = train_linked_posts[:max_size]

    # Recompute the labels after shuffling/truncation so they stay
    # aligned with the examples for stratified splitting below.
    disbn = [example.label for example in train_linked_posts]

    evaluator = None
    if linked_posts_str == validate:
        train_linked_posts, dev_linked_posts = train_test_split(
            train_linked_posts, stratify=disbn, test_size=0.1)
        evaluator = BinaryClassificationEvaluator.from_input_examples(
            dev_linked_posts, name='linked-posts')

    warmup_steps = math.ceil(
        len(train_linked_posts) * num_epochs / batch_size *
        0.1)  # 10% of train data for warm-up

    train_data_linked_posts = SentencesDataset(train_linked_posts, model=model)
    train_dataloader_linked_posts = DataLoader(train_data_linked_posts,
                                               shuffle=True,
                                               batch_size=batch_size)
    train_loss_linked_posts = losses.ContrastiveLoss(model=model)

    print('L: Number of training examples: ', len(train_linked_posts))

    global evaluation_steps
    # evaluate roughly every 10% of the steps in one epoch
    evaluation_steps = math.ceil(len(train_linked_posts) / batch_size * 0.1)

    return train_dataloader_linked_posts, train_loss_linked_posts, evaluator, warmup_steps
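
A sketch of wiring the returned objects into model.fit; the JSON file name is a placeholder and the globals follow the function's own assumptions:

# Hypothetical call; 'linked_posts.json' is a placeholder file name.
dataloader, loss, evaluator, warmup_steps = create_linked_posts(
    'linked_posts.json', data_dir, model, validate=linked_posts_str)
model.fit(train_objectives=[(dataloader, loss)],
          evaluator=evaluator,
          epochs=num_epochs,
          evaluation_steps=evaluation_steps,
          warmup_steps=warmup_steps)
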
Example #4
    dense_model = models.Dense(
        in_features=pooling_model.get_sentence_embedding_dimension(),
        out_features=sentence_embedding_dim,
        activation_function=nn.Tanh())

    model = SentenceTransformer(
        modules=[word_embedding_model, pooling_model, dense_model])

    train_examples = ld.load_dataset(dataset_name=dataset,
                                     dataset_type='train')

    train_dataset = SentencesDataset(train_examples, model)
    train_dataloader = DataLoader(train_dataset,
                                  shuffle=True,
                                  batch_size=batch_size)

    train_loss = losses.ContrastiveLoss(model=model)

    if task_type == "classification":
        evaluator = evaluation.BinaryClassificationEvaluator.from_input_examples(
            train_examples)
    else:
        evaluator = evaluation.EmbeddingSimilarityEvaluator.from_input_examples(
            train_examples)

    # Tune the model
    model.fit(train_objectives=[(train_dataloader, train_loss)],
              epochs=epochs,
              warmup_steps=100,
              output_path=model_save_path,
              evaluator=evaluator)
def create_search(collection,
                  query_file,
                  train,
                  data_dir,
                  model,
                  validate=None,
                  is_test=False):
    """Build the dataloader, loss, evaluator and warm-up steps for the
    search task.

    Relies on the module-level config globals max_size, search_str,
    num_epochs and batch_size.
    """
    corpus = {}
    with open(os.path.join(data_dir, collection), 'r', encoding='utf8') as fIn:
        for line in fIn:
            pid, passage = line.strip().split("\t")
            corpus[pid] = passage

    queries = {}
    with open(os.path.join(data_dir, query_file), 'r', encoding='utf8') as fIn:
        for line in fIn:
            qid, query = line.strip().split("\t")
            queries[qid] = query

    train_search = []
    disbn = []
    with open(os.path.join(data_dir, train), 'r', encoding='utf8') \
            as f:
        added_q = set()
        for line in f:
            qid, pos_id, neg_id = line.strip().split()
            query = queries[qid]
            passage = corpus[pos_id]
            neg_passage = corpus[neg_id]
            if qid not in added_q:
                train_search.append(
                    InputExample(texts=[query, passage], label=1))
                disbn.append(1)
                added_q.add(qid)
            train_search.append(
                InputExample(texts=[query, neg_passage], label=0))
            disbn.append(0)
    random.shuffle(train_search)

    if is_test:
        return train_search

    if max_size:
        train_search = train_search[:max_size]

    # Recompute the labels after shuffling/truncation so they stay
    # aligned with the examples for stratified splitting below.
    disbn = [example.label for example in train_search]

    evaluator = None
    if search_str == validate:
        train_search, dev_search = train_test_split(train_search,
                                                    stratify=disbn,
                                                    test_size=0.1)
        evaluator = BinaryClassificationEvaluator.from_input_examples(
            dev_search, name='search')

    warmup_steps = math.ceil(
        len(train_search) * num_epochs / batch_size *
        0.1)  # 10% of train data for warm-up

    # We create a DataLoader to load our train samples
    train_dataloader_search = DataLoader(train_search,
                                         shuffle=True,
                                         batch_size=batch_size)
    train_loss_search = losses.ContrastiveLoss(model=model)

    print('S: Number of training examples: ', len(train_search))

    global evaluation_steps
    # evaluate roughly every 10% of the steps in one epoch
    evaluation_steps = math.ceil(len(train_search) / batch_size * 0.1)

    return train_dataloader_search, train_loss_search, evaluator, warmup_steps
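
Since model.fit accepts several train objectives and round-robins between them, create_search can be combined with create_linked_posts from Example #1 for joint training; a hedged sketch with placeholder file names:

# Multi-task sketch; all file names below are placeholders.
search_dl, search_loss, search_eval, warmup_steps = create_search(
    'collection.tsv', 'queries.tsv', 'triples.tsv', data_dir, model,
    validate=search_str)
posts_dl, posts_loss, _, _ = create_linked_posts(
    'linked_posts.json', data_dir, model)
model.fit(train_objectives=[(search_dl, search_loss),
                            (posts_dl, posts_loss)],
          evaluator=search_eval,
          epochs=num_epochs,
          evaluation_steps=evaluation_steps,
          warmup_steps=warmup_steps)
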
def main():
    parser = argparse.ArgumentParser()

    # Input and output configs
    parser.add_argument("--task", default=None, type=str, required=True,
                        help="the task to run bert ranker for")
    parser.add_argument("--data_folder", default=None, type=str, required=True,
                        help="the folder containing data")
    parser.add_argument("--output_dir", default=None, type=str, required=True,
                        help="the folder to output predictions")
    parser.add_argument("--negative_sampler", default="random", type=str, required=False,
                        help="negative sampling procedure to use ['random', 'bm25', 'sentence_transformer']")
    parser.add_argument("--anserini_folder", default="", type=str, required=True,
                        help="Path containing the anserini bin <anserini_folder>/target/appassembler/bin/IndexCollection")
    parser.add_argument("--sentence_bert_ns_model", default="all-MiniLM-L6-v2", type=str, required=False,
                        help="model to use for sentenceBERT negative sampling.")

    parser.add_argument('--denoise_negatives', dest='denoise_negatives', action='store_true')
    parser.add_argument('--no-denoise_negatives', dest='denoise_negatives', action='store_false')
    parser.set_defaults(denoise_negatives=False)
    parser.add_argument("--num_ns_for_denoising", default=100, type=int, required=False,
                        help="Only used for --denoise_negatives. Number of total of samples to retrieve and get the bottom 10.")

    parser.add_argument("--generative_sampling_model", default="all-MiniLM-L6-v2", type=str, required=False,
                        help="model to use for generating negative samples on the go.")

    parser.add_argument('--remove_cand_subsets', dest='remove_cand_subsets', action='store_true')
    parser.add_argument('--dont_remove_cand_subsets', dest='remove_cand_subsets', action='store_false')
    parser.set_defaults(remove_cand_subsets=True)

    # which part of the context to use when sampling negatives
    parser.add_argument('--last_utterance_only', dest='last_utterance_only', action='store_true')
    parser.add_argument('--all_utterances', dest='last_utterance_only', action='store_false')
    parser.set_defaults(last_utterance_only=False)

    # External corpus to augment negative sampling
    parser.add_argument('--external_corpus', dest='use_external_corpus', action='store_true')
    parser.add_argument('--dont_use_external_corpus', dest='use_external_corpus', action='store_false')
    parser.set_defaults(use_external_corpus=False)

    # Training procedure
    parser.add_argument("--num_epochs", default=3, type=int, required=False,
                        help="Number of epochs for training.")
    parser.add_argument("--train_batch_size", default=8, type=int, required=False,
                        help="Training batch size.")
    # Model hyperparameters
    parser.add_argument("--transformer_model", default="bert-base-cased", type=str, required=False,
                        help="Bert model to use (default = bert-base-cased).")
    parser.add_argument("--loss", default='MultipleNegativesRankingLoss', type=str, required=False,
                        help="Loss function to use ['MultipleNegativesRankingLoss', 'TripletLoss', 'MarginMSELoss']")

    # Wandb project name
    parser.add_argument("--wandb_project", default='train_sentence_transformer', type=str, required=False,
                        help="name of the wandb project")
    parser.add_argument("--seed", default=42, type=int, required=False,
                        help="Random seed.")

    args = parser.parse_args()

    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    max_seq_length = 300
    if args.transformer_model in ('all-mpnet-base-v2', 'msmarco-bert-base-dot-v5'):
        model = SentenceTransformer(args.transformer_model)
        model.max_seq_length = max_seq_length
    else:
        word_embedding_model = models.Transformer(args.transformer_model, max_seq_length=max_seq_length)
        tokens = ['[UTTERANCE_SEP]', '[TURN_SEP]', '[AUG]']
        word_embedding_model.tokenizer.add_tokens(tokens, special_tokens=True)
        word_embedding_model.auto_model.resize_token_embeddings(len(word_embedding_model.tokenizer))
        pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension(),
                                       pooling_mode_mean_tokens=True,
                                       pooling_mode_cls_token=False,
                                       pooling_mode_max_tokens=False)
        model = SentenceTransformer(modules=[word_embedding_model, pooling_model])


    eval_only = False
    if eval_only:
        logging.info("Skipping training (eval_only=True)")
    else:
        logging.info("Creating train CRR dataset for {} using {}.".format(args.task, args.negative_sampler))
        crr_reader = CRRBenchmarkDataReader('{}/{}'.format(args.data_folder, args.task))
        train_data = crr_reader.get_examples("train.tsv", args.negative_sampler,
                                    args.anserini_folder, args.sentence_bert_ns_model, args.loss, args.output_dir,
                                    True, False,
                                    args.denoise_negatives, args.num_ns_for_denoising,
                                    args.generative_sampling_model,
                                    args.remove_cand_subsets,
                                    args.last_utterance_only,
                                    args.use_external_corpus)
        train_dataloader = DataLoader(train_data, shuffle=True, batch_size=args.train_batch_size)
    
    if args.loss == 'MultipleNegativesRankingLoss':
        train_loss = losses.MultipleNegativesRankingLoss(model=model, similarity_fct=util.dot_score)
    elif args.loss == 'MarginMSELoss':
        train_loss = losses.MarginMSELoss(model=model)
    elif args.loss == 'TripletLoss':
        train_loss = losses.TripletLoss(model=model)
    elif args.loss == 'ContrastiveLoss':
        train_loss = losses.ContrastiveLoss(model=model)
    elif args.loss == 'OnlineContrastiveLoss':
        train_loss = losses.OnlineContrastiveLoss(model=model)
    else:
        raise ValueError('Unsupported loss: {}'.format(args.loss))


    ns_description = args.negative_sampler
    if args.negative_sampler == 'sentence_transformer':
        ns_description+="_{}".format(args.sentence_bert_ns_model)

    if args.negative_sampler == 'generative':
        ns_description+="_{}".format(args.generative_sampling_model)

    wandb.init(project=args.wandb_project)
    wandb.config.update(args)

    if not eval_only:  # dev data used during training, not the final evaluation
        logging.info("Getting eval data")
        examples_dev = crr_reader.get_examples('valid.tsv', 
            args.negative_sampler, args.anserini_folder, args.sentence_bert_ns_model, args.loss, args.output_dir, eval_data=True)
        # valid.tsv yields blocks of 11 examples per query
        # (10 negatives + 1 positive); keep the first 500 queries
        examples_dev = examples_dev[0:(11*500)]
        eval_samples = []
        docs = []
        for i, example in enumerate(examples_dev):
            if (i + 1) % 11 == 0:
                eval_samples.append({'query': example.texts[0],
                                     'positive': [example.texts[1]],
                                     'negative': docs})
                docs = []
            else:
                docs.append(example.texts[2])
        evaluator = RerankingEvaluator(eval_samples, write_csv=True, similarity_fct=util.dot_score)
        warmup_steps = math.ceil(len(train_data)*args.num_epochs/args.train_batch_size*0.1) #10% of train data for warm-up
        logging.info("Warmup-steps: {}".format(warmup_steps))

        logging.info("Fitting sentenceBERT for {}".format(args.task))

        model.fit(train_objectives=[(train_dataloader, train_loss)],
                  evaluator=evaluator,
                  epochs=args.num_epochs,
                  evaluation_steps=100,
                  steps_per_epoch=10000,
                  warmup_steps=warmup_steps,
                  output_path=args.output_dir + "{}_{}_ns_{}_loss_{}".format(
                      args.transformer_model, args.task, ns_description, args.loss))

    logging.info("Evaluating for full retrieval of responses to dialogue.")

    train = pd.read_csv(args.data_folder+args.task+"/train.tsv", sep="\t")
    test = pd.read_csv(args.data_folder+args.task+"/test.tsv", sep="\t")

    ns_test_sentenceBERT = negative_sampling.SentenceBERTNegativeSampler(list(train["response"].values)+list(test["response"].values), 10, 
                   args.data_folder+args.task+"/test_sentenceBERTembeds", -1, 
                   args.output_dir+"{}_{}_ns_{}_loss_{}".format(args.transformer_model, args.task, ns_description, args.loss),
                   use_cache_for_embeddings=False)
    
    ns_info = [
        (ns_test_sentenceBERT, 
        ["cand_sentenceBERT_{}".format(i) for i in range(10)] + ["sentenceBERT_retrieved_relevant", "sentenceBERT_rank"], 
        'sentenceBERT')
    ]
    examples = []
    examples_cols = ["context", "relevant_response"] + \
        reduce(lambda x,y:x+y, [t[1] for t in ns_info])
    logging.info("Retrieving candidates using different negative sampling strategies for {}.".format(args.task))
    recall_df = []
    for idx, row in enumerate(tqdm(test.itertuples(index=False), total=len(test))):
        context = row[0]
        relevant_response = row[1]
        instance = [context, relevant_response]

        for sampler, _, ns_name in ns_info:
            ns_candidates, scores, had_relevant, rank_relevant, _ = sampler.sample(
                context, [relevant_response])
            instance.extend(ns_candidates)
            instance.append(had_relevant)
            instance.append(rank_relevant)
            r10 = 1 if had_relevant else 0
            r1 = 1 if rank_relevant == 0 else 0
            recall_df.append([r10, r1])
        examples.append(instance)

    recall_df = pd.DataFrame(recall_df, columns=["R@10", "R@1"])
    examples_df = pd.DataFrame(examples, columns=examples_cols)
    logging.info("R@10: {}".format(examples_df[[c for c in examples_df.columns if 'retrieved_relevant' in c]].sum()/examples_df.shape[0]))
    wandb.log({'R@10': (examples_df[[c for c in examples_df.columns if 'retrieved_relevant' in c]].sum()/examples_df.shape[0]).values[0]})
    rank_col = [c for c in examples_df.columns if 'rank' in c][0]
    logging.info("R@1: {}".format(examples_df[examples_df[rank_col]==0].shape[0]/examples_df.shape[0]))
    wandb.log({'R@1': examples_df[examples_df[rank_col]==0].shape[0]/examples_df.shape[0]})
    recall_df.to_csv(args.output_dir+"/recall_df_{}_{}_ns_{}_loss_{}.csv".format(args.transformer_model.replace("/", "-"), args.task, ns_description.replace("/", "-"), args.loss), index=False, sep="\t")
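
An illustrative command line for main(); the script name, task and paths are placeholders, while the flags match the parser above:

# python train_sentence_transformer.py \
#     --task mantis \
#     --data_folder data/ \
#     --output_dir output/ \
#     --anserini_folder anserini/ \
#     --negative_sampler bm25 \
#     --loss MultipleNegativesRankingLoss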