Example #1
def evaluate(args, model):
    """ Train the model """
    dev_dataset = SequenceDataset(
        TextTokenIdsCache(args.preprocess_dir, f"{args.mode}-query"),
        args.max_seq_length)
    collate_fn = get_collate_function(args.max_seq_length)
    batch_size = args.pergpu_eval_batch_size
    if args.n_gpu > 1:
        batch_size *= args.n_gpu
    dev_dataloader = DataLoader(dev_dataset,
                                batch_size=batch_size,
                                collate_fn=collate_fn)

    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)
    qembedding_memmap = np.memmap(args.qmemmap_path,
                                  dtype="float32",
                                  shape=(len(dev_dataset), 768),
                                  mode="w+")
    with torch.no_grad():
        # qoffsets holds each query's row offset into the output memmap.
        for step, (batch, qoffsets) in enumerate(tqdm(dev_dataloader)):
            batch = {k: v.to(args.model_device) for k, v in batch.items()}
            model.eval()
            embeddings = model(input_ids=batch["input_ids"],
                               attention_mask=batch["attention_mask"],
                               is_query=True)
            embeddings = embeddings.detach().cpu().numpy()
            qembedding_memmap[qoffsets] = embeddings
    return qembedding_memmap
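The query-embedding memmap written above can be reopened later without loading it fully into RAM. A minimal sketch, assuming the same float32 dtype and 768-dimensional embeddings as in the example; the file name is a placeholder for args.qmemmap_path:

import numpy as np

# Reopen the file produced by evaluate() read-only; dtype and width must match
# what was written (float32, 768 dimensions in the example above).
qembeddings = np.memmap("dev-query.memmap", dtype="float32",
                        mode="r").reshape(-1, 768)
print(qembeddings.shape)  # (num_queries, 768)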
Example #2
def doc_inference(model, args, embedding_size):
    if os.path.exists(args.doc_memmap_path):
        print(f"{args.doc_memmap_path} exists, skip inference")
        return
    doc_collator = single_get_collate_function(args.max_doc_length)
    ids_cache = TextTokenIdsCache(data_dir=args.preprocess_dir,
                                  prefix="passages")
    subset = list(range(len(ids_cache)))
    doc_dataset = SubsetSeqDataset(subset=subset,
                                   ids_cache=ids_cache,
                                   max_seq_length=args.max_doc_length)
    assert not os.path.exists(args.doc_memmap_path)
    doc_memmap = np.memmap(args.doc_memmap_path,
                           dtype=np.float32,
                           mode="w+",
                           shape=(len(doc_dataset), embedding_size))
    docid_memmap = np.memmap(args.docid_memmap_path,
                             dtype=np.int32,
                             mode="w+",
                             shape=(len(doc_dataset), ))
    try:
        prediction(model,
                   doc_collator,
                   args,
                   doc_dataset,
                   doc_memmap,
                   docid_memmap,
                   is_query=False)
    except:
        # Clean up the partially written memmaps so a rerun starts from scratch.
        subprocess.check_call(["rm", args.doc_memmap_path])
        subprocess.check_call(["rm", args.docid_memmap_path])
        raise
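doc_inference only reads a handful of attributes from args in this snippet (preprocess_dir, doc_memmap_path, docid_memmap_path, max_doc_length). A minimal sketch of a call site, assuming doc_inference and a loaded dual-encoder model are in scope; the paths and lengths are placeholders, and the downstream prediction helper will typically need further fields (batch size, device) that are not visible here:

from argparse import Namespace

args = Namespace(
    preprocess_dir="./preprocess",             # TextTokenIdsCache data directory
    doc_memmap_path="./passages.memmap",       # output: passage embeddings
    docid_memmap_path="./passages-id.memmap",  # output: passage ids
    max_doc_length=512,
)
doc_inference(model, args, embedding_size=768)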
Example #3
def query_inference(model, args, embedding_size):
    if os.path.exists(args.query_memmap_path):
        print(f"{args.query_memmap_path} exists, skip inference")
        return
    query_collator = single_get_collate_function(args.max_query_length)
    query_dataset = SequenceDataset(ids_cache=TextTokenIdsCache(
        data_dir=args.preprocess_dir, prefix=f"{args.mode}-query"),
                                    max_seq_length=args.max_query_length)
    query_memmap = np.memmap(args.query_memmap_path,
                             dtype=np.float32,
                             mode="w+",
                             shape=(len(query_dataset), embedding_size))
    queryids_memmap = np.memmap(args.queryids_memmap_path,
                                dtype=np.int32,
                                mode="w+",
                                shape=(len(query_dataset), ))
    try:
        prediction(model,
                   query_collator,
                   args,
                   query_dataset,
                   query_memmap,
                   queryids_memmap,
                   is_query=True)
    except:
        subprocess.check_call(["rm", args.query_memmap_path])
        subprocess.check_call(["rm", args.queryids_memmap_path])
        raise
Example #4
def query_inference(model, args, embedding_size):
    query_collator = single_get_collate_function(args.max_query_length)
    query_dataset = SequenceDataset(ids_cache=TextTokenIdsCache(
        data_dir=args.preprocess_dir, prefix=f"{args.mode}-query"),
                                    max_seq_length=args.max_query_length)
    query_memmap = np.memmap(args.query_memmap_path,
                             dtype=np.float32,
                             mode="w+",
                             shape=(len(query_dataset), embedding_size))
    queryids_memmap = np.memmap(args.queryids_memmap_path,
                                dtype=np.int32,
                                mode="w+",
                                shape=(len(query_dataset), ))

    prediction(model,
               query_collator,
               args,
               query_dataset,
               query_memmap,
               queryids_memmap,
               is_query=True)
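Once passage and query memmaps exist, retrieval is a brute-force Faiss search followed by mapping row positions back to the stored ids. A minimal sketch, assuming a flat inner-product index; the file names and embedding_size stand in for the corresponding args values:

import faiss
import numpy as np

embedding_size = 768  # must match the value passed at inference time
doc_embeddings = np.memmap("passages.memmap", dtype=np.float32,
                           mode="r").reshape(-1, embedding_size)
doc_ids = np.memmap("passages-id.memmap", dtype=np.int32, mode="r")
query_embeddings = np.memmap("dev-query.memmap", dtype=np.float32,
                             mode="r").reshape(-1, embedding_size)
query_ids = np.memmap("dev-query-id.memmap", dtype=np.int32, mode="r")

# Exact maximum inner product search; faiss returns (scores, row positions).
index = faiss.IndexFlatIP(embedding_size)
index.add(np.ascontiguousarray(doc_embeddings))
scores, positions = index.search(np.ascontiguousarray(query_embeddings), 100)
for qid, rows in zip(query_ids, positions):
    print(int(qid), doc_ids[rows][:10])  # top-10 passage ids for this query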
Example #5
def main():
    parser = HfArgumentParser(
        (ModelArguments, DataTrainingArguments, MyTrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    if (os.path.exists(training_args.output_dir)
            and os.listdir(training_args.output_dir) and training_args.do_train
            and not training_args.overwrite_output_dir):
        raise ValueError(
            f"Output directory ({training_args.output_dir}) already exists and is not empty. "
            "Use --overwrite_output_dir to overcome.")

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s-%(levelname)s-%(name)s- %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO
        if is_main_process(training_args.local_rank) else logging.WARN,
    )

    # Log on each process the small summary:
    logger.warning(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, "
        f"n_gpu: {training_args.n_gpu}, distributed training: {bool(training_args.local_rank != -1)}, "
        f"16-bits training: {training_args.fp16}")
    # Set the verbosity to info of the Transformers logger (on main process only):
    if is_main_process(training_args.local_rank):
        transformers.utils.logging.set_verbosity_info()
    logger.info(f"Training/evaluation parameters {training_args}")

    # Set seed before initializing model.
    set_seed(training_args.seed)

    config = RobertaConfig.from_pretrained(
        model_args.model_name_or_path,
        finetuning_task="msmarco",
        gradient_checkpointing=model_args.gradient_checkpointing)
    tokenizer = RobertaTokenizer.from_pretrained(
        model_args.model_name_or_path,
        use_fast=False,
    )
    config.gradient_checkpointing = model_args.gradient_checkpointing

    data_args.label_path = os.path.join(data_args.data_dir, "train-qrel.tsv")
    rel_dict = load_rel(data_args.label_path)
    if training_args.hard_neg:
        train_dataset = TrainInbatchWithHardDataset(
            rel_file=data_args.label_path,
            rank_file=data_args.hardneg_path,
            queryids_cache=TextTokenIdsCache(data_dir=data_args.data_dir,
                                             prefix="train-query"),
            docids_cache=TextTokenIdsCache(data_dir=data_args.data_dir,
                                           prefix="passages"),
            max_query_length=data_args.max_query_length,
            max_doc_length=data_args.max_doc_length,
            hard_num=training_args.per_query_hard_num)
        data_collator = triple_get_collate_function(
            data_args.max_query_length,
            data_args.max_doc_length,
            rel_dict=rel_dict,
            padding=training_args.padding)
        model_class = RobertaDot_InBatch
    elif training_args.rand_neg:
        train_dataset = TrainInbatchWithRandDataset(
            rel_file=data_args.label_path,
            rand_num=training_args.per_query_hard_num,
            queryids_cache=TextTokenIdsCache(data_dir=data_args.data_dir,
                                             prefix="train-query"),
            docids_cache=TextTokenIdsCache(data_dir=data_args.data_dir,
                                           prefix="passages"),
            max_query_length=data_args.max_query_length,
            max_doc_length=data_args.max_doc_length)
        data_collator = triple_get_collate_function(
            data_args.max_query_length,
            data_args.max_doc_length,
            rel_dict=rel_dict,
            padding=training_args.padding)
        model_class = RobertaDot_Rand
    else:
        train_dataset = TrainInbatchDataset(
            rel_file=data_args.label_path,
            queryids_cache=TextTokenIdsCache(data_dir=data_args.data_dir,
                                             prefix="train-query"),
            docids_cache=TextTokenIdsCache(data_dir=data_args.data_dir,
                                           prefix="passages"),
            max_query_length=data_args.max_query_length,
            max_doc_length=data_args.max_doc_length)
        data_collator = dual_get_collate_function(
            data_args.max_query_length,
            data_args.max_doc_length,
            rel_dict=rel_dict,
            padding=training_args.padding)
        model_class = RobertaDot_InBatch

    model = model_class.from_pretrained(
        model_args.model_name_or_path,
        from_tf=bool(".ckpt" in model_args.model_name_or_path),
        config=config,
    )

    # Initialize our Trainer
    trainer = DRTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=None,
        compute_metrics=None,
        tokenizer=tokenizer,
        data_collator=data_collator,
    )
    trainer.remove_callback(TensorBoardCallback)
    trainer.add_callback(
        MyTensorBoardCallback(tb_writer=SummaryWriter(
            os.path.join(training_args.output_dir, "log"))))
    trainer.add_callback(MyTrainerCallback())

    # Training
    trainer.train()
    trainer.save_model()  # Saves the tokenizer too for easy upload
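The checkpoint written by trainer.save_model() can be reloaded with the usual from_pretrained pattern. A minimal sketch, assuming the custom RobertaDot_InBatch class from the example is importable and that output_dir is the directory passed as --output_dir:

from transformers import RobertaConfig, RobertaTokenizer

output_dir = "./output"  # placeholder for training_args.output_dir
config = RobertaConfig.from_pretrained(output_dir)
tokenizer = RobertaTokenizer.from_pretrained(output_dir)
model = RobertaDot_InBatch.from_pretrained(output_dir, config=config)
model.eval()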
Example #6
def train(args, model):
    """ Train the model """
    tb_writer = SummaryWriter(args.log_dir)
    passage_embeddings = np.memmap(args.pembed_path,
                                   dtype=np.float32,
                                   mode="r").reshape(
                                       -1, model.output_embedding_size)

    args.train_batch_size = args.per_gpu_batch_size
    train_dataset = TrainQueryDataset(
        TextTokenIdsCache(args.preprocess_dir, "train-query"),
        os.path.join(args.preprocess_dir, "train-qrel.MSMARCO.tsv"),
        args.max_seq_length)

    train_sampler = RandomSampler(train_dataset)
    collate_fn = get_collate_function(args.max_seq_length)
    train_dataloader = DataLoader(train_dataset,
                                  sampler=train_sampler,
                                  batch_size=args.train_batch_size,
                                  collate_fn=collate_fn)

    t_total = (len(train_dataloader) // args.gradient_accumulation_steps
               * args.num_train_epochs)

    no_decay = ['bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {
            'params': [p for n, p in model.named_parameters()
                       if not any(nd in n for nd in no_decay)],
            'weight_decay': args.weight_decay
        },
        {
            'params': [p for n, p in model.named_parameters()
                       if any(nd in n for nd in no_decay)],
            'weight_decay': 0.0
        },
    ]
    optimizer = AdamW(optimizer_grouped_parameters,
                      lr=args.learning_rate,
                      eps=args.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=args.warmup_steps,
        num_training_steps=t_total)

    index = load_index(passage_embeddings, args.index_path,
                       args.faiss_gpu_index,
                       args.use_gpu and not args.index_cpu)

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Total train batch size (w. accumulation) = %d",
                args.train_batch_size * args.gradient_accumulation_steps)
    logger.info("  Gradient Accumulation steps = %d",
                args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    global_step = 0
    tr_loss, logging_loss = 0.0, 0.0
    tr_mrr, logging_mrr = 0.0, 0.0
    tr_recall, logging_recall = 0.0, 0.0
    model.zero_grad()
    train_iterator = trange(int(args.num_train_epochs), desc="Epoch")
    set_seed(args)  # Added here for reproducibility (even between python 2 and 3)

    for epoch_idx, _ in enumerate(train_iterator):
        epoch_iterator = tqdm(train_dataloader, desc="Iteration")
        for step, (batch, _, all_rel_poffsets) in enumerate(epoch_iterator):

            batch = {k: v.to(args.model_device) for k, v in batch.items()}
            model.train()
            query_embeddings = model(query_ids=batch["input_ids"],
                                     attention_mask_q=batch["attention_mask"],
                                     is_query=True)
            # Faiss search returns (scores, indices); keep the top-k passage offsets.
            I_nearest_neighbor = index.search(
                query_embeddings.detach().cpu().numpy(), args.neg_topk)[1]

            loss = 0
            for retrieve_poffsets, cur_rel_poffsets, qembedding in zip(
                    I_nearest_neighbor, all_rel_poffsets, query_embeddings):
                target_labels = np.isin(retrieve_poffsets,
                                        cur_rel_poffsets).astype(np.int32)

                # MRR@10: reciprocal rank of the first relevant passage in the top 10.
                first_rel_pos = np.where(target_labels[:10])[0]
                mrr = 1 / (1 + first_rel_pos[0]) if len(first_rel_pos) > 0 else 0

                tr_mrr += mrr / args.train_batch_size
                recall = 1 if mrr > 0 else 0
                tr_recall += recall / args.train_batch_size

                # If nothing relevant was retrieved, append the known relevant passages
                # so that at least one positive pair exists for the pairwise loss.
                if np.sum(target_labels) == 0:
                    retrieve_poffsets = np.hstack(
                        [retrieve_poffsets, cur_rel_poffsets])
                    target_labels = np.hstack(
                        [target_labels, [True] * len(cur_rel_poffsets)])

                target_labels = target_labels.reshape(-1, 1)
                rel_diff = target_labels - target_labels.T
                pos_pairs = (rel_diff > 0).astype(np.float32)
                num_pos_pairs = np.sum(pos_pairs, (0, 1))

                assert num_pos_pairs > 0
                neg_pairs = (rel_diff < 0).astype(np.float32)
                num_pairs = 2 * num_pos_pairs  # num pos pairs and neg pairs are always the same

                pos_pairs = torch.FloatTensor(pos_pairs).to(args.model_device)
                neg_pairs = torch.FloatTensor(neg_pairs).to(args.model_device)

                topK_passage_embeddings = torch.FloatTensor(
                    passage_embeddings[retrieve_poffsets]).to(
                        args.model_device)
                y_pred = (qembedding.unsqueeze(0) *
                          topK_passage_embeddings).sum(-1, keepdim=True)

                # RankNet pairwise terms: a positive pair (row relevant, column not)
                # is penalised when the relevant passage scores lower, and vice versa.
                C_pos = torch.log(1 + torch.exp(y_pred.t() - y_pred))
                C_neg = torch.log(1 + torch.exp(y_pred - y_pred.t()))

                C = pos_pairs * C_pos + neg_pairs * C_neg

                if args.metric is not None:
                    with torch.no_grad():
                        weights = metric_weights(y_pred, args.metric_cut)
                    C = C * weights
                cur_loss = torch.sum(C, (0, 1)) / num_pairs
                loss += cur_loss

            loss /= (args.train_batch_size * args.gradient_accumulation_steps)
            loss.backward()

            tr_loss += loss.item()
            torch.nn.utils.clip_grad_norm_(model.parameters(),
                                           args.max_grad_norm)

            if (step + 1) % args.gradient_accumulation_steps == 0:
                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1

                if args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    tb_writer.add_scalar('lr',
                                         scheduler.get_lr()[0], global_step)
                    cur_loss = (tr_loss - logging_loss) / args.logging_steps
                    tb_writer.add_scalar('train/all_loss', cur_loss,
                                         global_step)
                    logging_loss = tr_loss

                    cur_mrr = (tr_mrr - logging_mrr) / (
                        args.logging_steps * args.gradient_accumulation_steps)
                    tb_writer.add_scalar('train/mrr_10', cur_mrr, global_step)
                    logging_mrr = tr_mrr

                    cur_recall = (tr_recall - logging_recall) / (
                        args.logging_steps * args.gradient_accumulation_steps)
                    tb_writer.add_scalar('train/recall_10', cur_recall,
                                         global_step)
                    logging_recall = tr_recall

                if args.save_steps > 0 and global_step % args.save_steps == 0:
                    save_model(model, args.model_save_dir,
                               'ckpt-{}'.format(global_step), args)

        save_model(model, args.model_save_dir,
                   'epoch-{}'.format(epoch_idx + 1), args)
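The per-query loss above is a RankNet-style pairwise loss. A minimal, self-contained sketch on toy scores and labels, illustrating how pos_pairs, neg_pairs and the two log terms combine; all names here are local to the sketch:

import numpy as np
import torch

# Toy example: 4 retrieved passages, the first two are relevant.
labels = np.array([1, 1, 0, 0], dtype=np.int32).reshape(-1, 1)
scores = torch.tensor([[2.0], [0.5], [1.0], [-1.0]])  # model scores s_i

rel_diff = labels - labels.T
pos_pairs = torch.FloatTensor((rel_diff > 0).astype(np.float32))  # i relevant, j not
neg_pairs = torch.FloatTensor((rel_diff < 0).astype(np.float32))  # i not, j relevant

score_diffs = scores - scores.t()               # s_i - s_j
C_pos = torch.log(1 + torch.exp(-score_diffs))  # penalise s_i < s_j on positive pairs
C_neg = torch.log(1 + torch.exp(score_diffs))   # penalise s_i > s_j on negative pairs
loss = torch.sum(pos_pairs * C_pos + neg_pairs * C_neg) / (2 * pos_pairs.sum())
print(loss.item())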