Example #1
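# Assumed imports for this listing. The project-specific helpers
# (load_encoder, load_phrase_index, load_qa_pairs, get_question_dataloader,
# get_question_results, get_top_phrases, annotate_phrase_vecs, evaluate) and
# `logger` are provided by the surrounding codebase and are not shown here;
# when args.fp16 is set, `amp` is expected to be NVIDIA Apex's amp module.
import copy
import math
import os

import torch
from tqdm import tqdm
from transformers import AdamW, get_linear_schedule_with_warmup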
def query2vec(queries):
    # Encode each query string into start/end query vectors plus its tokens.
    # tokenizer, max_query_length, batch_size, device, query_encoder, debug and
    # logger are free variables here, presumably captured from the enclosing
    # routine this function was originally nested in.
    question_dataloader, question_examples, query_features = get_question_dataloader(
        queries, tokenizer, max_query_length, batch_size=batch_size)
    question_results = get_question_results(question_examples,
                                            query_features,
                                            question_dataloader,
                                            device,
                                            query_encoder,
                                            batch_size=batch_size)
    if debug:
        logger.info(
            f"{len(query_features)} queries: {' '.join(query_features[0].tokens_)}"
        )

    # One (start_vec, end_vec, tokens) tuple per question
    outs = []
    for qr_idx, question_result in enumerate(question_results):
        out = (question_result.start_vec.tolist(),
               question_result.end_vec.tolist(),
               query_features[qr_idx].tokens_)
        outs.append(out)
    return outs

def train_query_encoder(args, mips=None):
    # Freeze one for MIPS
    device = 'cuda' if args.cuda else 'cpu'
    logger.info("Loading pretrained encoder: this one is for MIPS (fixed)")
    pretrained_encoder, tokenizer, _ = load_encoder(device, args)

    # Train a copy of it
    logger.info("Copying target encoder")
    target_encoder = copy.deepcopy(pretrained_encoder)

    # MIPS
    if mips is None:
        mips = load_phrase_index(args)

    # Optimizer setting
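    # Keep the word-embedding matrix frozen; all other encoder parameters are trained.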
    def is_train_param(name):
        if name.endswith(".embeddings.word_embeddings.weight"):
            logger.info(f'freezing {name}')
            return False
        return True

    no_decay = ["bias", "LayerNorm.weight"]
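    # Two parameter groups: weight decay on trainable weights, none on biases and LayerNorm.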
    optimizer_grouped_parameters = [{
            "params": [
                p for n, p in target_encoder.named_parameters() \
                    if not any(nd in n for nd in no_decay) and is_train_param(n)
            ],
            "weight_decay": 0.01,
        }, {
            "params": [
                p for n, p in target_encoder.named_parameters() \
                    if any(nd in n for nd in no_decay) and is_train_param(n)
            ],
            "weight_decay": 0.0
        },
    ]
    optimizer = AdamW(optimizer_grouped_parameters,
                      lr=args.learning_rate,
                      eps=args.adam_epsilon)
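    # Length of the linear-warmup schedule: question batches per epoch, scaled by
    # gradient accumulation and the number of epochs.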
    step_per_epoch = math.ceil(
        len(load_qa_pairs(args.train_path, args)[1]) /
        args.per_gpu_train_batch_size)
    t_total = int(step_per_epoch // args.gradient_accumulation_steps *
                  args.num_train_epochs)
    logger.info(f"Train for {t_total} iterations")
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=args.warmup_steps,
        num_training_steps=t_total)
    eval_steps = math.ceil(
        len(load_qa_pairs(args.dev_path, args)[1]) / args.eval_batch_size)
    logger.info(f"Test takes {eval_steps} iterations")

    # Scale the per-GPU batch size down for gradient accumulation
    args.per_gpu_train_batch_size = int(args.per_gpu_train_batch_size /
                                        args.gradient_accumulation_steps)
    best_acc = -1000.0
    for ep_idx in range(int(args.num_train_epochs)):

        # Training
        total_loss = 0.0
        total_accs = []
        total_accs_k = []

        # Load training dataset
        q_ids, questions, answers, titles = load_qa_pairs(args.train_path,
                                                          args,
                                                          shuffle=True)
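        # Retrieve top phrases for each question batch from the MIPS index using the
        # frozen pretrained encoder; these serve as candidates for training.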
        pbar = tqdm(
            get_top_phrases(mips, q_ids, questions, answers, titles,
                            pretrained_encoder, tokenizer,
                            args.per_gpu_train_batch_size, args))

        for step_idx, (q_ids, questions, answers, titles,
                       outs) in enumerate(pbar):
            train_dataloader, _, _ = get_question_dataloader(
                questions,
                tokenizer,
                args.max_query_length,
                batch_size=args.per_gpu_train_batch_size)
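            # Phrase start/end vectors of the retrieved candidates, plus answer-level
            # (tgts) and passage-level (p_tgts) target annotations per question.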
            svs, evs, tgts, p_tgts = annotate_phrase_vecs(
                mips, q_ids, questions, answers, titles, outs, args)

            target_encoder.train()
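            # Build tensors for the phrase vectors and targets, dropping missing
            # (None) target entries per question.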
            svs_t = torch.Tensor(svs).to(device)
            evs_t = torch.Tensor(evs).to(device)
            tgts_t = [
                torch.Tensor([tgt_ for tgt_ in tgt
                              if tgt_ is not None]).to(device) for tgt in tgts
            ]
            p_tgts_t = [
                torch.Tensor([tgt_ for tgt_ in tgt
                              if tgt_ is not None]).to(device)
                for tgt in p_tgts
            ]

            # Train query encoder
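            # The whole question batch is encoded as a single dataloader batch,
            # so this loop body runs exactly once per step.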
            assert len(train_dataloader) == 1
            for batch in train_dataloader:
                batch = tuple(t.to(device) for t in batch)
                loss, accs = target_encoder.train_query(
                    input_ids_=batch[0],
                    attention_mask_=batch[1],
                    token_type_ids_=batch[2],
                    start_vecs=svs_t,
                    end_vecs=evs_t,
                    targets=tgts_t,
                    p_targets=p_tgts_t,
                )

                # Optimize, get acc and report
                if loss is not None:
                    if args.gradient_accumulation_steps > 1:
                        loss = loss / args.gradient_accumulation_steps
                    if args.fp16:
                        with amp.scale_loss(loss, optimizer) as scaled_loss:
                            scaled_loss.backward()
                    else:
                        loss.backward()

                    total_loss += loss.mean().item()
                    if args.fp16:
                        torch.nn.utils.clip_grad_norm_(
                            amp.master_params(optimizer), args.max_grad_norm)
                    else:
                        torch.nn.utils.clip_grad_norm_(
                            target_encoder.parameters(), args.max_grad_norm)

                    optimizer.step()
                    scheduler.step()  # Update learning rate schedule
                    target_encoder.zero_grad()

                    pbar.set_description(
                        f"Ep {ep_idx+1} Tr loss: {loss.mean().item():.2f}, acc: {sum(accs)/len(accs):.3f}"
                    )

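                # acc@1 comes from train_query; acc@top_k records whether any gold
                # phrase survived retrieval (non-empty target list).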
                if accs is not None:
                    total_accs += accs
                    total_accs_k += [len(tgt) > 0 for tgt in tgts_t]
                else:
                    total_accs += [0.0] * len(tgts_t)
                    total_accs_k += [0.0] * len(tgts_t)

        step_idx += 1  # convert the last 0-based step index into an iteration count
        logger.info(
            f"Avg train loss ({step_idx} iterations): {total_loss/step_idx:.2f} | train "
            +
            f"acc@1: {sum(total_accs)/len(total_accs):.3f} | acc@{args.top_k}: {sum(total_accs_k)/len(total_accs_k):.3f}"
        )

        # Evaluation
        new_args = copy.deepcopy(args)
        new_args.top_k = 10
        new_args.save_pred = False
        new_args.test_path = args.dev_path
        dev_em, dev_f1, dev_emk, dev_f1k = evaluate(new_args, mips,
                                                    target_encoder, tokenizer)
        logger.info(f"Development set acc@1: {dev_em:.3f}, f1@1: {dev_f1:.3f}")

        # Save best model
        if dev_em > best_acc:
            best_acc = dev_em
            save_path = args.output_dir
            if not os.path.exists(save_path):
                os.makedirs(save_path)
            target_encoder.save_pretrained(save_path)
            logger.info(
                f"Saved best model with acc {best_acc:.3f} into {save_path}")

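        # Refresh the frozen retrieval encoder with the newly trained weights (every
        # epoch here), so the next epoch's phrases are retrieved with the updated
        # query encoder.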
        if (ep_idx + 1) % 1 == 0:
            logger.info('Updating pretrained encoder')
            pretrained_encoder = copy.deepcopy(target_encoder)

    print()
    logger.info(f"Best model has acc {best_acc:.3f} saved as {save_path}")