Example #1
0
def embed_passages(opt, passages, model, tokenizer):
    """Embed a collection of passages with the given encoder model.

    Args:
        opt: options namespace; reads ``per_gpu_batch_size`` and ``world_size``.
        passages: passage records consumed by ``src.data.TextDataset``.
        model: encoder exposing ``embed_text()``, ``config.passage_maxlength``
            and ``apply_passage_mask``.
        tokenizer: tokenizer handed to ``src.data.TextCollator``.

    Returns:
        Tuple ``(allids, allembeddings)``: a flat list of passage ids and a
        numpy array stacking one embedding row per passage.
    """
    batch_size = opt.per_gpu_batch_size * opt.world_size
    collator = src.data.TextCollator(tokenizer, model.config.passage_maxlength)
    dataset = src.data.TextDataset(passages,
                                   title_prefix='title:',
                                   passage_prefix='context:')
    dataloader = DataLoader(dataset,
                            batch_size=batch_size,
                            drop_last=False,
                            num_workers=10,
                            collate_fn=collator)
    # Fix: switch to inference mode so dropout/batch-norm layers are inactive.
    # torch.no_grad() only disables gradient tracking, not training-mode
    # layers; embed_questions() below already does this.
    model.eval()
    total = 0
    allids, allembeddings = [], []
    with torch.no_grad():
        for k, (ids, text_ids, text_mask) in enumerate(dataloader):
            # assumes CUDA is available (hard-coded .cuda()) — TODO confirm
            embeddings = model.embed_text(text_ids=text_ids.cuda(),
                                          text_mask=text_mask.cuda(),
                                          apply_mask=model.apply_passage_mask)
            embeddings = embeddings.cpu()
            total += len(ids)

            allids.append(ids)
            allembeddings.append(embeddings)
            # Log progress every 10 batches.
            if k % 10 == 0:
                logger.info('Encoded passages %d', total)

    allembeddings = torch.cat(allembeddings, dim=0).numpy()
    # Flatten the per-batch id lists into one list aligned with the rows.
    allids = [x for idlist in allids for x in idlist]
    return allids, allembeddings
Example #2
0
def embed_questions(opt, data, model, tokenizer):
    """Encode every question in ``data`` into a dense embedding.

    Args:
        opt: options namespace; reads ``per_gpu_batch_size``, ``world_size``,
            ``question_maxlength`` and ``device``.
        data: question records consumed by ``src.data.Dataset``.
        model: encoder exposing ``embed_text()``, ``apply_question_mask``
            and ``extract_cls``.
        tokenizer: tokenizer handed to ``src.data.Collator``.

    Returns:
        A numpy array with one embedding row per question.
    """
    dataloader = DataLoader(
        src.data.Dataset(data),
        batch_size=opt.per_gpu_batch_size * opt.world_size,
        drop_last=False,
        num_workers=10,
        collate_fn=src.data.Collator(opt.question_maxlength, tokenizer),
    )
    model.eval()
    chunks = []
    with torch.no_grad():
        for batch in dataloader:
            # Batch layout: (idx, _, _, question_ids, question_mask);
            # only the token ids and the attention mask are needed here.
            _, _, _, question_ids, question_mask = batch
            width = question_ids.size(-1)
            ids_on_device = question_ids.to(opt.device).view(-1, width)
            mask_on_device = question_mask.to(opt.device).view(-1, width)
            chunks.append(model.embed_text(
                text_ids=ids_on_device,
                text_mask=mask_on_device,
                apply_mask=model.apply_question_mask,
                extract_cls=model.extract_cls,
            ))

    embeddings = torch.cat(chunks, dim=0)
    logger.info(f'Questions embeddings shape: {embeddings.size()}')

    return embeddings.cpu().numpy()