Example #1
def generate_question_vectors(
    question_encoder: torch.nn.Module,
    tensorizer: Tensorizer,
    questions: List[str],
    bsz: int,
    query_token: str = None,
    selector: RepTokenSelector = None,
) -> T:
    n = len(questions)
    query_vectors = []

    with torch.no_grad():
        for j, batch_start in enumerate(range(0, n, bsz)):
            batch_questions = questions[batch_start:batch_start + bsz]

            if query_token:
                if query_token == "[START_ENT]":
                    batch_token_tensors = [
                        _select_span_with_token(q,
                                                tensorizer,
                                                token_str=query_token)
                        for q in batch_questions
                    ]
                else:
                    batch_token_tensors = [
                        tensorizer.text_to_tensor(" ".join([query_token, q]))
                        for q in batch_questions
                    ]
            else:
                batch_token_tensors = [
                    tensorizer.text_to_tensor(q) for q in batch_questions
                ]

            q_ids_batch = torch.stack(batch_token_tensors, dim=0).cuda()
            q_seg_batch = torch.zeros_like(q_ids_batch).cuda()
            q_attn_mask = tensorizer.get_attn_mask(q_ids_batch)

            if selector:
                rep_positions = selector.get_positions(q_ids_batch, tensorizer)

                _, out, _ = BiEncoder.get_representation(
                    question_encoder,
                    q_ids_batch,
                    q_seg_batch,
                    q_attn_mask,
                    representation_token_pos=rep_positions,
                )
            else:
                _, out, _ = question_encoder(q_ids_batch, q_seg_batch,
                                             q_attn_mask)

            query_vectors.extend(out.cpu().split(1, dim=0))

            if len(query_vectors) % 100 == 0:
                logger.info("Encoded queries %d", len(query_vectors))

    query_tensor = torch.cat(query_vectors, dim=0)
    logger.info("Total encoded queries tensor %s", query_tensor.size())
    assert query_tensor.size(0) == len(questions)
    return query_tensor
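A minimal, self-contained sketch of the batching pattern used above (fixed-size batches under torch.no_grad(), stack, encode, move to CPU, concatenate), with a toy mean-pooled embedding encoder standing in for the real BERT question encoder; all names below are illustrative:

import torch
import torch.nn as nn


def encode_in_batches(encoder: nn.Module, token_tensors: list, bsz: int) -> torch.Tensor:
    """Encode a list of equal-length 1D token-id tensors in fixed-size batches."""
    vectors = []
    with torch.no_grad():  # inference only: no autograd graph, lower memory use
        for start in range(0, len(token_tensors), bsz):
            ids = torch.stack(token_tensors[start:start + bsz], dim=0)
            out = encoder(ids)  # (batch, hidden)
            vectors.extend(out.cpu().split(1, dim=0))
    return torch.cat(vectors, dim=0)


class ToyEncoder(nn.Module):
    """Stand-in for the question encoder: mean-pooled token embeddings."""

    def __init__(self, vocab: int = 100, dim: int = 8):
        super().__init__()
        self.emb = nn.Embedding(vocab, dim)

    def forward(self, ids):
        return self.emb(ids).mean(dim=1)


tokens = [torch.randint(0, 100, (16,)) for _ in range(10)]
print(encode_in_batches(ToyEncoder(), tokens, bsz=4).shape)  # torch.Size([10, 8])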
Example #2
def gen_ctx_vectors(
    ctx_rows: List[Tuple[object, str, str]],
    model: nn.Module,
    tensorizer: Tensorizer,
    insert_title: bool = True,
) -> List[Tuple[object, np.array]]:
    n = len(ctx_rows)
    bsz = args.batch_size
    total = 0
    results = []
    for j, batch_start in enumerate(range(0, n, bsz)):

        batch_token_tensors = [tensorizer.text_to_tensor(ctx[1], title=ctx[2] if insert_title else None) for ctx in
                               ctx_rows[batch_start:batch_start + bsz]]

        ctx_ids_batch = torch.stack(batch_token_tensors, dim=0)
        ctx_seg_batch = torch.zeros_like(ctx_ids_batch)
        ctx_attn_mask = tensorizer.get_attn_mask(ctx_ids_batch)
        with torch.no_grad():
            _, out, _ = model(ctx_ids_batch, ctx_seg_batch, ctx_attn_mask)
        out = out.cpu()

        ctx_ids = [r[0] for r in ctx_rows[batch_start:batch_start + bsz]]

        assert len(ctx_ids) == out.size(0)

        total += len(ctx_ids)

        results.extend([
            (ctx_ids[i], out[i].view(-1).numpy())
            for i in range(out.size(0))
        ])

        if total % 10 == 0:
            logger.info('Encoded passages %d', total)

    return results
Example #3
def _load_tokens_into_ctx(
    ctx: DataPassage,
    question_token_ids: np.ndarray,
    wiki_data: TokenizedWikipediaPassages,
    tensorizer: Tensorizer,
    check_pre_tokenized_data: bool = True,
) -> DataPassage:
    tokens = wiki_data.get_tokenized_data(int(ctx.id))

    # Double check if needed
    if ctx.passage_text is not None:
        orig_passage_ids = tensorizer.text_to_tensor(
            ctx.passage_text,
            add_special_tokens=False,
        ).numpy()
        if check_pre_tokenized_data and (
            len(orig_passage_ids) != len(tokens["passage_token_ids"])
            or not (orig_passage_ids == tokens["passage_token_ids"]).all()
        ):
            raise ValueError(
                f"Passage token mismatch: id: {ctx.id}, orig: {orig_passage_ids}, "
                f"pre-processed: {tokens['passage_token_ids']}. If the sequence lengths are different,"
                f" this might be because the maximum length of the tokenizer is set differently during "
                f"pre-processing and training.")

        orig_title_ids = tensorizer.text_to_tensor(
            ctx.title,
            add_special_tokens=False,
        ).numpy()
        if check_pre_tokenized_data and (
            len(orig_title_ids) != len(tokens["title_token_ids"])
            or not (orig_title_ids == tokens["title_token_ids"]).all()
        ):
            raise ValueError(
                f"Passage title token mismatch: id: {ctx.id}, orig: {orig_title_ids}, "
                f"pre-processed: {tokens['title_token_ids']}. If the sequence lengths are different,"
                f" this might be because the maximum length of the tokenizer is set differently during "
                f"pre-processing and training.")

    ctx.load_tokens(
        question_token_ids=question_token_ids,
        **tokens)  # load question, passage and passage title tokens

    # Remove redundant data
    ctx.on_serialize(remove_tokens=False)

    return ctx
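The double-check above reduces to comparing freshly tokenized ids against cached ids, treating a length mismatch and an element mismatch as the same failure. A standalone sketch of that check (names are illustrative):

import numpy as np


def check_token_ids_match(fresh_ids: np.ndarray, cached_ids: np.ndarray, passage_id: int) -> None:
    """Raise if re-tokenized ids disagree with pre-tokenized ids
    (e.g. because the tokenizer max length differed during pre-processing)."""
    if len(fresh_ids) != len(cached_ids) or not (fresh_ids == cached_ids).all():
        raise ValueError(
            f"Token mismatch for passage {passage_id}: "
            f"fresh={fresh_ids}, cached={cached_ids}"
        )


check_token_ids_match(np.array([101, 2054]), np.array([101, 2054]), passage_id=7)  # passes silently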
Example #4
def gen_ctx_vectors(
    cfg: DictConfig,
    ctx_rows: List[Tuple[object, BiEncoderPassage]],
    model: nn.Module,
    tensorizer: Tensorizer,
    insert_title: bool = True,
) -> List[Tuple[object, np.array]]:
    n = len(ctx_rows)
    bsz = cfg.batch_size
    total = 0
    results = []
    for j, batch_start in enumerate(range(0, n, bsz)):
        batch = ctx_rows[batch_start : batch_start + bsz]
        batch_token_tensors = [
            tensorizer.text_to_tensor(
                ctx[1].text, title=ctx[1].title if insert_title else None
            )
            for ctx in batch
        ]

        ctx_ids_batch = move_to_device(
            torch.stack(batch_token_tensors, dim=0), cfg.device
        )
        ctx_seg_batch = move_to_device(torch.zeros_like(ctx_ids_batch), cfg.device)
        ctx_attn_mask = move_to_device(
            tensorizer.get_attn_mask(ctx_ids_batch), cfg.device
        )
        with torch.no_grad():
            _, out, _ = model(ctx_ids_batch, ctx_seg_batch, ctx_attn_mask)
        out = out.cpu()

        ctx_ids = [r[0] for r in batch]
        extra_info = []
        if len(batch[0]) > 3:
            extra_info = [r[3:] for r in batch]

        assert len(ctx_ids) == out.size(0)
        total += len(ctx_ids)

        # TODO: refactor to avoid 'if'
        if extra_info:
            results.extend(
                [
                    (ctx_ids[i], out[i].view(-1).numpy(), *extra_info[i])
                    for i in range(out.size(0))
                ]
            )
        else:
            results.extend(
                [(ctx_ids[i], out[i].view(-1).numpy()) for i in range(out.size(0))]
            )

        if total % 10 == 0:
            logger.info("Encoded passages %d", total)
    return results
Example #5
def gen_ctx_vectors(
        ctx_rows: List[Tuple[object, str, str]],
        model: nn.Module,
        tensorizer: Tensorizer,
        insert_title: bool = True) -> List[Tuple[object, np.array]]:
    n = len(ctx_rows)
    bsz = args.batch_size
    total = 0
    results = []
    for j, batch_start in enumerate(range(0, n, bsz)):

        all_txt = []
        for ctx in ctx_rows[batch_start:batch_start + bsz]:
            if ctx[2]:
                txt = ['title:', ctx[2], 'context:', ctx[1]]
            else:
                txt = ['context:', ctx[1]]
            txt = ' '.join(txt)
            all_txt.append(txt)
        batch_token_tensors = [
            tensorizer.text_to_tensor(txt, max_length=250) for txt in all_txt
        ]
        #batch_token_tensors = [tensorizer.text_to_tensor(ctx[1], title=ctx[2] if insert_title else None) for ctx in #original
        #                       ctx_rows[batch_start:batch_start + bsz]]                                             #original

        ctx_ids_batch = move_to_device(torch.stack(batch_token_tensors, dim=0),
                                       args.device)
        ctx_seg_batch = move_to_device(torch.zeros_like(ctx_ids_batch),
                                       args.device)
        ctx_attn_mask = move_to_device(tensorizer.get_attn_mask(ctx_ids_batch),
                                       args.device)
        with torch.no_grad():
            _, out, _ = model(ctx_ids_batch, ctx_seg_batch, ctx_attn_mask)
        out = out.cpu()

        ctx_ids = [r[0] for r in ctx_rows[batch_start:batch_start + bsz]]

        assert len(ctx_ids) == out.size(0)

        total += len(ctx_ids)

        #results.extend([
        #    (ctx_ids[i], out[i].view(-1).numpy())
        #    for i in range(out.size(0))
        #])

        results.extend([(ctx_ids[i], out[i].numpy())
                        for i in range(out.size(0))])

        if total % 10 == 0:
            logger.info('Encoded passages %d', total)

    return results
Example #6
def _select_span_with_token(
    text: str, tensorizer: Tensorizer, token_str: str = "[START_ENT]"
) -> T:
    id = tensorizer.get_token_id(token_str)
    query_tensor = tensorizer.text_to_tensor(text)

    if id not in query_tensor:
        query_tensor_full = tensorizer.text_to_tensor(text, apply_max_len=False)
        token_indexes = (query_tensor_full == id).nonzero()
        if token_indexes.size(0) > 0:
            start_pos = token_indexes[0, 0].item()
            # add some randomization to avoid overfitting to a specific token position

            left_shift = int(tensorizer.max_length / 2)
            rnd_shift = int((rnd.random() - 0.5) * left_shift / 2)
            left_shift += rnd_shift

            query_tensor = query_tensor_full[start_pos - left_shift :]
            cls_id = tensorizer.tokenizer.cls_token_id
            if query_tensor[0] != cls_id:
                query_tensor = torch.cat([torch.tensor([cls_id]), query_tensor], dim=0)

            from dpr.models.reader import _pad_to_len

            query_tensor = _pad_to_len(
                query_tensor, tensorizer.get_pad_id(), tensorizer.max_length
            )
            query_tensor[-1] = tensorizer.tokenizer.sep_token_id
            # logger.info('aligned query_tensor %s', query_tensor)

            assert id in query_tensor, "query_tensor={}".format(query_tensor)
            return query_tensor
        else:
            raise RuntimeError(
                "[START_ENT] toke not found for Entity Linking sample query={}".format(
                    text
                )
            )
    else:
        return query_tensor
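At its core, _select_span_with_token finds the marker token, slices a randomly jittered window around it, restores the [CLS]/[SEP] framing, and pads back to a fixed length. A self-contained sketch of that procedure; the token ids, the helper name, and the max(0, ...) bounds guard are illustrative additions, not part of the original API:

import random

import torch


def select_window_around_token(ids: torch.Tensor, marker_id: int, max_len: int,
                               cls_id: int = 101, sep_id: int = 102, pad_id: int = 0) -> torch.Tensor:
    """Slice a max_len window that keeps `marker_id` visible, with random left jitter."""
    positions = (ids == marker_id).nonzero()
    if positions.numel() == 0:
        raise RuntimeError("marker token not found")
    start = positions[0, 0].item()
    left_shift = max_len // 2
    left_shift += int((random.random() - 0.5) * left_shift / 2)  # jitter against positional overfitting
    window = ids[max(0, start - left_shift):]
    if window[0] != cls_id:  # restore [CLS] if it was sliced away
        window = torch.cat([torch.tensor([cls_id]), window])
    window = window[:max_len]
    if window.size(0) < max_len:  # pad back to the fixed length
        window = torch.cat([window, torch.full((max_len - window.size(0),), pad_id, dtype=torch.long)])
    window[-1] = sep_id  # force a [SEP] terminator
    return window


ids = torch.tensor([101] + [5] * 300 + [999] + [5] * 20 + [102])  # 999 plays the role of [START_ENT]
print(select_window_around_token(ids, marker_id=999, max_len=64).shape)  # torch.Size([64])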
Example #7
def _select_reader_passages(
    sample: Dict,
    question: str,
    tensorizer: Tensorizer,
    gold_passage_map: Optional[Dict[str, ReaderPassage]],
    gold_page_only_positives: bool,
    max_positives: int,
    max1_negatives: int,
    max2_negatives: int,
    max_retriever_passages: int,
    include_gold_passage: bool,
    is_train_set: bool,
) -> Tuple[List[ReaderPassage], List[ReaderPassage]]:
    answers = sample["answers"]

    ctxs = [ReaderPassage(**ctx) for ctx in sample["ctxs"]][0:max_retriever_passages]
    answers_token_ids = [tensorizer.text_to_tensor(a, add_special_tokens=False) for a in answers]

    if is_train_set:
        positive_samples = list(filter(lambda ctx: ctx.has_answer, ctxs))
        negative_samples = list(filter(lambda ctx: not ctx.has_answer, ctxs))
    else:
        positive_samples = []
        negative_samples = ctxs

    positive_ctxs_from_gold_page = (
        list(
            filter(
                lambda ctx: _is_from_gold_wiki_page(gold_passage_map, ctx.title, question),
                positive_samples,
            )
        )
        if gold_page_only_positives and gold_passage_map
        else []
    )

    def find_answer_spans(ctx: ReaderPassage):
        if ctx.has_answer:
            if ctx.passage_token_ids is None:
                ctx.passage_token_ids = tensorizer.text_to_tensor(ctx.passage_text, add_special_tokens=False)

            answer_spans = [
                _find_answer_positions(ctx.passage_token_ids, answers_token_ids[i]) for i in range(len(answers))
            ]

            # flatten spans list
            answer_spans = [item for sublist in answer_spans for item in sublist]
            answers_spans = list(filter(None, answer_spans))
            ctx.answers_spans = answers_spans

            if not answers_spans:
                logger.warning(
                    "No answer found in passage id=%s text=%s, answers=%s, question=%s",
                    ctx.id,
                    "",  # ctx.passage_text
                    answers,
                    question,
                )
            ctx.has_answer = bool(answers_spans)
        return ctx

    # check if any of the selected ctx+ has answer spans
    selected_positive_ctxs = list(
        filter(
            lambda ctx: ctx.has_answer,
            [find_answer_spans(ctx) for ctx in positive_ctxs_from_gold_page],
        )
    )

    if not selected_positive_ctxs:  # fallback to positive ctx not from gold pages
        selected_positive_ctxs = list(
            filter(
                lambda ctx: ctx.has_answer,
                [find_answer_spans(ctx) for ctx in positive_samples],
            )
        )[0:max_positives]

    # optionally include gold passage itself if it is still not in the positives list
    if include_gold_passage and question in gold_passage_map:
        gold_passage = gold_passage_map[question]
        included_gold_passage = next(
            iter(ctx for ctx in selected_positive_ctxs if ctx.passage_text == gold_passage.passage_text),
            None,
        )
        if not included_gold_passage:
            gold_passage.has_answer = True
            gold_passage = find_answer_spans(gold_passage)
            if not gold_passage.has_answer:
                logger.warning("No answer found in gold passage: %s", gold_passage)
            else:
                selected_positive_ctxs.append(gold_passage)

    max_negatives = (
        min(max(10 * len(selected_positive_ctxs), max1_negatives), max2_negatives)
        if is_train_set
        else DEFAULT_EVAL_PASSAGES
    )
    negative_samples = negative_samples[0:max_negatives]
    return selected_positive_ctxs, negative_samples
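_find_answer_positions, used inside find_answer_spans above, is essentially a sliding-window sublist match over token ids. A minimal sketch of that matching step with hypothetical names:

import torch


def find_answer_positions(passage_ids: torch.Tensor, answer_ids: torch.Tensor) -> list:
    """Return (start, end) index pairs where `answer_ids` occurs inside `passage_ids`."""
    n, m = passage_ids.size(0), answer_ids.size(0)
    spans = []
    for i in range(n - m + 1):
        if torch.equal(passage_ids[i:i + m], answer_ids):
            spans.append((i, i + m - 1))
    return spans


passage = torch.tensor([4, 8, 15, 16, 23, 42, 15, 16])
answer = torch.tensor([15, 16])
print(find_answer_positions(passage, answer))  # [(2, 3), (6, 7)]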
Example #8
    def create_biencoder_input2(
        cls,
        samples: List[BiEncoderSample],
        tensorizer: Tensorizer,
        insert_title: bool,
        num_hard_negatives: int = 0,
        num_other_negatives: int = 0,
        shuffle: bool = True,
        shuffle_positives: bool = False,
        hard_neg_fallback: bool = True,
        query_token: str = None,
    ) -> BiEncoderBatch:
        """
        Creates a batch of the biencoder training tuple.
        :param samples: list of BiEncoderSample-s to create the batch for
        :param tensorizer: components to create model input tensors from a text sequence
        :param insert_title: enables title insertion at the beginning of the context sequences
        :param num_hard_negatives: amount of hard negatives per question (taken from samples' pools)
        :param num_other_negatives: amount of other negatives per question (taken from samples' pools)
        :param shuffle: shuffles negative passages pools
        :param shuffle_positives: shuffles positive passages pools
        :return: BiEncoderBatch tuple
        """
        question_tensors = []
        ctx_tensors = []
        positive_ctx_indices = []
        hard_neg_ctx_indices = []

        for sample in samples:
            # ctx+ & [ctx-] composition
            # as of now, take the first(gold) ctx+ only

            if shuffle and shuffle_positives:
                positive_ctxs = sample.positive_passages
                positive_ctx = positive_ctxs[np.random.choice(
                    len(positive_ctxs))]
            else:
                positive_ctx = sample.positive_passages[0]

            neg_ctxs = sample.negative_passages
            hard_neg_ctxs = sample.hard_negative_passages
            question = sample.query
            # question = normalize_question(sample.query)

            if shuffle:
                random.shuffle(neg_ctxs)
                random.shuffle(hard_neg_ctxs)

            if hard_neg_fallback and len(hard_neg_ctxs) == 0:
                hard_neg_ctxs = neg_ctxs[0:num_hard_negatives]

            neg_ctxs = neg_ctxs[0:num_other_negatives]
            hard_neg_ctxs = hard_neg_ctxs[0:num_hard_negatives]

            all_ctxs = [positive_ctx] + neg_ctxs + hard_neg_ctxs
            hard_negatives_start_idx = 1
            hard_negatives_end_idx = 1 + len(hard_neg_ctxs)

            current_ctxs_len = len(ctx_tensors)

            sample_ctxs_tensors = [
                tensorizer.text_to_tensor(
                    ctx.text,
                    title=ctx.title if (insert_title and ctx.title) else None)
                for ctx in all_ctxs
            ]

            ctx_tensors.extend(sample_ctxs_tensors)
            positive_ctx_indices.append(current_ctxs_len)
            hard_neg_ctx_indices.append([
                i for i in range(
                    current_ctxs_len + hard_negatives_start_idx,
                    current_ctxs_len + hard_negatives_end_idx,
                )
            ])

            if query_token:
                # TODO: tmp workaround for EL, remove or revise
                if query_token == "[START_ENT]":
                    query_span = _select_span_with_token(question,
                                                         tensorizer,
                                                         token_str=query_token)
                    question_tensors.append(query_span)
                else:
                    question_tensors.append(
                        tensorizer.text_to_tensor(" ".join(
                            [query_token, question])))
            else:
                question_tensors.append(tensorizer.text_to_tensor(question))

        ctxs_tensor = torch.cat([ctx.view(1, -1) for ctx in ctx_tensors],
                                dim=0)
        questions_tensor = torch.cat([q.view(1, -1) for q in question_tensors],
                                     dim=0)

        ctx_segments = torch.zeros_like(ctxs_tensor)
        question_segments = torch.zeros_like(questions_tensor)

        return BiEncoderBatch(
            questions_tensor,
            question_segments,
            ctxs_tensor,
            ctx_segments,
            positive_ctx_indices,
            hard_neg_ctx_indices,
            "question",
        )
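The index bookkeeping above flattens each sample's [positive] + negatives + hard negatives into one growing context list while recording where the positive and the hard negatives land. Note that hard_negatives_start_idx = 1 is only correct when num_other_negatives is 0; the sketch below indexes past the soft negatives, as Example #14 also does:

def flatten_contexts(samples):
    """Flatten per-sample (positive, negatives, hard_negatives) triples into one
    context list, recording each sample's positive and hard-negative positions."""
    all_ctxs, positive_idx, hard_neg_idx = [], [], []
    for pos, negs, hard_negs in samples:
        base = len(all_ctxs)
        all_ctxs.extend([pos] + negs + hard_negs)
        positive_idx.append(base)  # the positive always comes first
        start = base + 1 + len(negs)  # hard negatives follow the soft ones
        hard_neg_idx.append(list(range(start, start + len(hard_negs))))
    return all_ctxs, positive_idx, hard_neg_idx


samples = [("p0", ["n0"], ["h0", "h1"]), ("p1", [], ["h2"])]
ctxs, pos, hard = flatten_contexts(samples)
print(pos, hard)  # [0, 4] [[2, 3], [5]]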
Example #9
    def create_biencoder_input(
        cls,
        samples: List,
        tensorizer: Tensorizer,
        insert_title: bool,
        num_hard_negatives: int = 0,
        num_other_negatives: int = 0,
        shuffle: bool = True,
        shuffle_positives: bool = False,
        hard_neg_fallback: bool = True,
    ) -> BiEncoderBatch:
        """
        Creates a batch of the biencoder training tuple.
        :param samples: list of data items (from json) to create the batch for
        :param tensorizer: components to create model input tensors from a text sequence
        :param insert_title: enables title insertion at the beginning of the context sequences
        :param num_hard_negatives: amount of hard negatives per question (taken from samples' pools)
        :param num_other_negatives: amount of other negatives per question (taken from samples' pools)
        :param shuffle: shuffles negative passages pools
        :param shuffle_positives: shuffles positive passages pools
        :return: BiEncoderBatch tuple
        """
        question_tensors = []
        ctx_tensors = []
        positive_ctx_indices = []
        hard_neg_ctx_indices = []

        for sample in samples:
            # ctx+ & [ctx-] composition
            # as of now, take the first(gold) ctx+ only
            if shuffle and shuffle_positives:
                positive_ctxs = sample["positive_ctxs"]
                positive_ctx = positive_ctxs[np.random.choice(
                    len(positive_ctxs))]
            else:
                positive_ctx = sample["positive_ctxs"][0]

            neg_ctxs = sample["negative_ctxs"]
            hard_neg_ctxs = sample["hard_negative_ctxs"]

            if shuffle:
                random.shuffle(neg_ctxs)
                random.shuffle(hard_neg_ctxs)

            if hard_neg_fallback and len(hard_neg_ctxs) == 0:
                hard_neg_ctxs = neg_ctxs[0:num_hard_negatives]

            neg_ctxs = neg_ctxs[0:num_other_negatives]
            hard_neg_ctxs = hard_neg_ctxs[0:num_hard_negatives]

            all_ctxs = [positive_ctx] + neg_ctxs + hard_neg_ctxs
            hard_negatives_start_idx = 1
            hard_negatives_end_idx = 1 + len(hard_neg_ctxs)

            current_ctxs_len = len(ctx_tensors)

            sample_ctxs_tensors = [
                tensorizer.text_to_tensor(
                    ctx["text"],
                    title=ctx["title"] if
                    (insert_title and "title" in ctx) else None,
                ) for ctx in all_ctxs
            ]

            ctx_tensors.extend(sample_ctxs_tensors)
            positive_ctx_indices.append(current_ctxs_len)
            hard_neg_ctx_indices.append([
                i for i in range(
                    current_ctxs_len + hard_negatives_start_idx,
                    current_ctxs_len + hard_negatives_end_idx,
                )
            ])

            question_tensors.append(tensorizer.text_to_tensor(question))

        ctxs_tensor = torch.cat([ctx.view(1, -1) for ctx in ctx_tensors],
                                dim=0)
        questions_tensor = torch.cat([q.view(1, -1) for q in question_tensors],
                                     dim=0)

        ctx_segments = torch.zeros_like(ctxs_tensor)
        question_segments = torch.zeros_like(questions_tensor)

        return BiEncoderBatch(
            questions_tensor,
            question_segments,
            ctxs_tensor,
            ctx_segments,
            positive_ctx_indices,
            hard_neg_ctx_indices,
            "question",
        )
Example #10
    def create_graded_biencoder_input2(
        cls,
        samples: List[GradedBiEncoderSample],
        tensorizer: Tensorizer,
        insert_title: bool,
        num_hard_negatives: int = 0,
        num_other_negatives: int = 0,
        num_related: int = 0,
        num_highly_related: int = 0,
        shuffle: bool = True,
        shuffle_positives: bool = False,
        hard_neg_fallback: bool = True,
        query_token: str = None,
        relation_grades: list = [1.0, 1.0, 1.0, 0.0, 0.0],
    ) -> GradedBiEncoderBatch:
        """
        Creates a batch of the biencoder training tuple.
        :param samples: list of GradedBiEncoderSample-s to create the batch for
        :param tensorizer: components to create model input tensors from a text sequence
        :param insert_title: enables title insertion at the beginning of the context sequences
        :param num_hard_negatives: amount of hard negatives per question (taken from samples' pools)
        :param num_other_negatives: amount of other negatives per question (taken from samples' pools)
        :param shuffle: shuffles negative passages pools
        :param shuffle_positives: shuffles positive passages pools
        :return: GradedBiEncoderBatch tuple
        """
        question_tensors = []
        ctx_tensors = []
        positive_ctx_indices = []
        hard_neg_ctx_indices = []
        negatives_ctx_indices = []
        related_ctx_indices = []
        highly_related_ctx_indices = []
        relations = []

        for sample in samples:
            # ctx+ & [ctx-] composition
            # as of now, take the first(gold) ctx+ only

            if shuffle and shuffle_positives:
                positive_ctxs = sample.positive_passages
                positive_ctx = positive_ctxs[np.random.choice(
                    len(positive_ctxs))]
            else:
                positive_ctx = sample.positive_passages[0]

            neg_ctxs = sample.negative_passages
            hard_neg_ctxs = sample.hard_negative_passages
            related_ctxs = sample.related_passage
            highly_related_ctxs = sample.highly_related_passage
            question = sample.query
            # question = normalize_question(sample.query)

            if shuffle:
                random.shuffle(neg_ctxs)
                random.shuffle(hard_neg_ctxs)
                random.shuffle(related_ctxs)
                random.shuffle(highly_related_ctxs)

            if hard_neg_fallback and len(hard_neg_ctxs) == 0:
                hard_neg_ctxs = neg_ctxs[0:num_hard_negatives]

            neg_ctxs = neg_ctxs[0:num_other_negatives]
            hard_neg_ctxs = hard_neg_ctxs[0:num_hard_negatives]
            related_ctxs = related_ctxs[0:num_related]
            highly_related_ctxs = highly_related_ctxs[0:num_highly_related]

            all_ctxs = [
                positive_ctx
            ] + neg_ctxs + hard_neg_ctxs + related_ctxs + highly_related_ctxs

            # relations
            rel_positive, rel_highly_related, rel_related, rel_negative, rel_hard_negative = relation_grades
            question_relations = []
            if relations:  # pre-padding with negatives
                question_relations = [rel_negative] * len(relations[-1])
            question_relations.extend([rel_positive])
            question_relations.extend([rel_negative] * len(neg_ctxs))
            question_relations.extend([rel_hard_negative] * len(hard_neg_ctxs))
            question_relations.extend([rel_related] * len(related_ctxs))
            question_relations.extend([rel_highly_related] *
                                      len(highly_related_ctxs))
            relations.append(question_relations)

            # post-padding with negatives
            for relation in relations:
                if len(relation) < len(relations[-1]):
                    num_negatives_to_post_pad = len(
                        relations[-1]) - len(relation)
                    relation.extend([rel_negative] * num_negatives_to_post_pad)

            # calculate all positions
            current_ctxs_len = len(ctx_tensors)
            positive_ctx_indices.append(current_ctxs_len)

            negatives_start_idx = 1 + current_ctxs_len
            negatives_end_idx = 1 + len(neg_ctxs) + current_ctxs_len
            negatives_idx_range = list(
                range(negatives_start_idx, negatives_end_idx))
            negatives_ctx_indices.append(negatives_idx_range)

            # `negatives_end_idx` already includes `current_ctxs_len`
            hard_negatives_start_idx = negatives_end_idx
            hard_negatives_end_idx = negatives_end_idx + len(hard_neg_ctxs)
            hard_negatives_idx_range = list(
                range(hard_negatives_start_idx, hard_negatives_end_idx))
            hard_neg_ctx_indices.append(hard_negatives_idx_range)

            related_start_idx = hard_negatives_end_idx
            related_end_idx = hard_negatives_end_idx + len(related_ctxs)
            related_idx_range = list(range(related_start_idx, related_end_idx))
            related_ctx_indices.append(related_idx_range)

            highly_related_start_idx = related_end_idx
            highly_related_end_idx = related_end_idx + len(highly_related_ctxs)
            highly_related_idx_range = list(
                range(highly_related_start_idx, highly_related_end_idx))
            highly_related_ctx_indices.append(highly_related_idx_range)

            # add all ctxs to ctx_tensors
            sample_ctxs_tensors = [
                tensorizer.text_to_tensor(
                    ctx.text,
                    title=ctx.title if (insert_title and ctx.title) else None)
                for ctx in all_ctxs
            ]

            ctx_tensors.extend(sample_ctxs_tensors)

            if query_token:
                # TODO: tmp workaround for EL, remove or revise
                if query_token == "[START_ENT]":
                    query_span = _select_span_with_token(question,
                                                         tensorizer,
                                                         token_str=query_token)
                    question_tensors.append(query_span)
                else:
                    question_tensors.append(
                        tensorizer.text_to_tensor(" ".join(
                            [query_token, question])))
            else:
                question_tensors.append(tensorizer.text_to_tensor(question))

        ctxs_tensor = torch.cat([ctx.view(1, -1) for ctx in ctx_tensors],
                                dim=0)
        questions_tensor = torch.cat([q.view(1, -1) for q in question_tensors],
                                     dim=0)

        ctx_segments = torch.zeros_like(ctxs_tensor)
        question_segments = torch.zeros_like(questions_tensor)

        return GradedBiEncoderBatch(
            questions_tensor,
            question_segments,
            ctxs_tensor,
            ctx_segments,
            positive_ctx_indices,
            hard_neg_ctx_indices,
            negatives_ctx_indices,
            related_ctx_indices,
            highly_related_ctx_indices,
            relations,
            "question",
        )
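The relations bookkeeping above keeps one grade row per question and pads with the negative grade so the matrix stays rectangular as wider rows arrive. A compact sketch of the pre-/post-padding logic (the function name is illustrative):

def pad_relations(relations: list, new_row: list, rel_negative: float = 0.0) -> list:
    """Append `new_row`, then pad every row with the negative grade to equal width."""
    if relations:  # pre-pad the new row up to the current width
        new_row = [rel_negative] * len(relations[-1]) + new_row
    relations.append(new_row)
    width = max(len(row) for row in relations)
    for row in relations:  # post-pad shorter rows
        row.extend([rel_negative] * (width - len(row)))
    return relations


rels = pad_relations([], [1.0, 0.5])
rels = pad_relations(rels, [1.0])
print(rels)  # [[1.0, 0.5, 0.0], [0.0, 0.0, 1.0]]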
Example #11
def _select_passages(
    wiki_data: TokenizedWikipediaPassages,
    sample: Dict,
    bm25_sample: Tuple[Tuple[int, float]],
    question: str,
    processed_question: str,
    question_token_ids: np.ndarray,
    answers: List[str],
    expanded_answers: List[List[str]],
    all_answers: List[str],
    tensorizer: Tensorizer,
    gold_passage_map: Dict[str, DataPassage],
    processed_gold_passage_map: Dict[str, DataPassage],
    cfg: PreprocessingCfg,
    is_train_set: bool,
    check_pre_tokenized_data: bool,
) -> Tuple[List[DataPassage], List[DataPassage], List[DataPassage],
           List[DataPassage], List[DataPassage], List[DataPassage],
           List[DataPassage]]:
    """
    Select and process valid passages for training/evaluation.
    """
    # Tokenize answers
    answers_token_ids: List[np.ndarray] = [
        tensorizer.text_to_tensor(a, add_special_tokens=False).numpy()
        for a in all_answers
    ]

    # Gold context; we want to cover more gold passages, that's why we are matching both
    # `processed_question` and `question` (canonical question).
    if question in processed_gold_passage_map or processed_question in processed_gold_passage_map:
        if question in processed_gold_passage_map:
            gold_ctx = processed_gold_passage_map[question]
        else:
            gold_ctx = processed_gold_passage_map[processed_question]

        gold_ctx = _load_tokens_into_ctx(
            gold_ctx,
            question_token_ids,
            wiki_data,
            tensorizer,
            check_pre_tokenized_data,
        )  # load question, passage title and passage tokens into the context object
        gold_ctx = _find_answer_spans(
            tensorizer,
            gold_ctx,
            question,
            all_answers,
            answers_token_ids,
            warn_if_no_answer=True,
            raise_if_no_answer=False,
            warn_if_has_answer=False,
            raise_if_has_answer=False,
            recheck_negatives=False,
        )  # find answer spans for all passages
        if gold_ctx.has_answer:
            gold_ctxs = [gold_ctx]
        else:
            gold_ctxs = []
    else:
        gold_ctxs = []

    # Densely retrieved contexts
    ctxs = [DataPassage(is_from_bm25=False, **ctx) for ctx in sample["ctxs"]]
    ctxs = [
        _load_tokens_into_ctx(ctx, question_token_ids, wiki_data, tensorizer,
                              check_pre_tokenized_data) for ctx in ctxs
    ]  # load question, passage title and passage tokens into the context object
    # Find answer spans for all passages
    ctxs: List[DataPassage] = [
        _find_answer_spans(
            tensorizer,
            ctx,
            question,
            all_answers,
            answers_token_ids,
            warn_if_no_answer=ctx.has_answer,  # warn if it originally contains an answer string
            warn_if_has_answer=not ctx.has_answer,  # warn if it originally does NOT contain an answer string
            recheck_negatives=cfg.recheck_negatives,
        )
        for ctx in ctxs
    ]

    # Sparsely retrieved contexts (BM25)
    bm25_ctxs = [
        DataPassage(id=passage_id, score=score, is_from_bm25=True)
        for passage_id, score in bm25_sample
    ]
    bm25_ctxs = [
        _load_tokens_into_ctx(ctx, question_token_ids, wiki_data, tensorizer,
                              check_pre_tokenized_data) for ctx in bm25_ctxs
    ]  # load question, passage title and passage tokens into the context object
    # Find answer spans for all passages
    bm25_ctxs: List[DataPassage] = [
        _find_answer_spans(
            tensorizer,
            ctx,
            question,
            all_answers,
            answers_token_ids,
            warn_if_no_answer=False,
            warn_if_has_answer=False,
            recheck_negatives=True,  # `has_answer` of any BM25 passage is None
        ) for ctx in bm25_ctxs
    ]

    # Filter positives and negatives using distant supervision
    positive_samples = list(filter(lambda ctx: ctx.has_answer, ctxs))
    distantly_positive_samples: List[DataPassage] = []
    negative_samples = list(filter(lambda ctx: not ctx.has_answer, ctxs))
    bm25_positive_samples = list(filter(lambda ctx: ctx.has_answer, bm25_ctxs))
    bm25_distantly_positive_samples: List[DataPassage] = []
    bm25_negative_samples = list(
        filter(lambda ctx: not ctx.has_answer, bm25_ctxs))

    # Filter unwanted positive passages if training
    if is_train_set:

        # Get positives that are from gold positive passages
        if cfg.gold_page_only_positives:
            selected_positive_ctxs: List[DataPassage] = []
            selected_negative_ctxs: List[DataPassage] = negative_samples
            selected_bm25_positive_ctxs: List[DataPassage] = []
            selected_bm25_negative_ctxs: List[DataPassage] = bm25_negative_samples

            for positives, selected_positives, selected_negatives, distantly_positives in [
                (positive_samples, selected_positive_ctxs,
                 selected_negative_ctxs, distantly_positive_samples),
                (bm25_positive_samples, selected_bm25_positive_ctxs,
                 selected_bm25_negative_ctxs, bm25_distantly_positive_samples)
            ]:

                for ctx in positives:
                    is_from_gold = _is_from_gold_wiki_page(
                        gold_passage_map,
                        ctx,
                        tensorizer.tensor_to_text(torch.from_numpy(ctx.title_token_ids)),
                        question,
                    )
                    if is_from_gold:
                        selected_positives.append(ctx)
                    else:  # if it has answer but does not come from gold passage
                        if cfg.should_negatives_contain_answer:
                            selected_negatives.append(ctx)
                        else:
                            distantly_positives.append(ctx)
        else:
            selected_positive_ctxs = positive_samples
            selected_negative_ctxs = negative_samples
            selected_bm25_positive_ctxs = bm25_positive_samples
            selected_bm25_negative_ctxs = bm25_negative_samples

        # Fallback to positive ctx not from gold passages
        if len(selected_positive_ctxs) == 0:
            selected_positive_ctxs = positive_samples
        if len(selected_bm25_positive_ctxs) == 0:
            selected_bm25_positive_ctxs = bm25_positive_samples

        # Optionally include gold passage itself if it is still not in the positives list
        if cfg.include_gold_passage:
            if question in gold_passage_map:
                gold_passage = gold_passage_map[question]
                gold_passage.is_gold = True
                gold_passage.has_answer = True  # assuming it has answer

                gold_passage = _find_answer_spans(
                    tensorizer,
                    gold_passage,
                    question,
                    all_answers,
                    answers_token_ids,
                    warn_if_no_answer=False,
                    warn_if_has_answer=False,
                    recheck_negatives=True,
                )  # warn below

                if not gold_passage.has_answer:
                    logger.warning(
                        "No answer found in GOLD passage: passage='%s', question='%s', answers=%s, expanded_answers=%s",
                        gold_passage.passage_text,
                        question,
                        answers,
                        expanded_answers,
                    )
                selected_positive_ctxs.append(
                    gold_passage
                )  # append anyway, since we need this for retriever (not reader)
            else:
                logger.warning(f"Question '{question}' has no gold positive")

    else:
        # NOTE: See `create_reader_input` function in `reader.py` to see how
        # positive and negative samples are merged (keeping their original order)
        selected_positive_ctxs = positive_samples
        selected_negative_ctxs = negative_samples
        selected_bm25_positive_ctxs = bm25_positive_samples
        selected_bm25_negative_ctxs = bm25_negative_samples

    # Restrict number of BM25 passages
    selected_bm25_positive_ctxs = selected_bm25_positive_ctxs[:cfg.max_bm25_positives]
    selected_bm25_negative_ctxs = selected_bm25_negative_ctxs[:cfg.max_bm25_negatives]

    return (
        gold_ctxs,
        selected_positive_ctxs,
        selected_negative_ctxs,
        distantly_positive_samples,
        selected_bm25_positive_ctxs,
        selected_bm25_negative_ctxs,
        bm25_distantly_positive_samples,
    )
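When gold_page_only_positives is set, answer-bearing passages that are not from the gold Wikipedia page get routed either back into the negatives or into a "distantly positive" pool. A small sketch of that routing decision; negatives_keep_answer mirrors cfg.should_negatives_contain_answer, and the names are illustrative:

def split_by_gold_page(positives, is_from_gold, negatives_keep_answer: bool):
    """Keep gold-page positives; demote the rest to negatives or distant positives."""
    selected, distant, demoted = [], [], []
    for ctx in positives:
        if is_from_gold(ctx):
            selected.append(ctx)
        elif negatives_keep_answer:
            demoted.append(ctx)  # has the answer, but treated as a negative
        else:
            distant.append(ctx)  # kept as a distantly supervised positive
    return selected, distant, demoted


pos, distant, demoted = split_by_gold_page(
    ["p1", "p2"], is_from_gold=lambda c: c == "p1", negatives_keep_answer=False)
print(pos, distant, demoted)  # ['p1'] ['p2'] []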
Example #12
def _preprocess_retriever_data(
    samples: List[Dict],
    bm25_samples: List[Tuple[Tuple[int, float]]],
    wiki_data: TokenizedWikipediaPassages,
    gold_info_file: Optional[str],
    gold_info_processed_file: str,
    tensorizer: Tensorizer,
    cfg: PreprocessingCfg = DEFAULT_PREPROCESSING_CFG_TRAIN,
    is_train_set: bool = True,
    check_pre_tokenized_data: bool = True,
) -> Iterable[DataSample]:
    """
    Converts retriever results into general retriever/reader training data.
    :param samples: samples from the retriever's json file results
    :param bm25_samples: bm25 retrieval results; list of tuples of tuples of (passage_id, score), where passages of each
        sample are already sorted by their scores
    :param gold_info_file: optional path for the 'gold passages & questions' file. Required to get best results for NQ
    :param gold_info_processed_file: path to the preprocessed gold passages pickle file. Unlike `gold_info_file`, which
        contains the original gold passages, this file should contain processed, matched, 100-word split passages that
        correspond to the original gold passages.
    :param tensorizer: Tensorizer object for text to model input tensors conversions
    :param cfg: PreprocessingCfg object with positive and negative passage selection parameters
    :param is_train_set: if the data should be processed as a train set
    :return: iterable of DataSample objects which can be consumed by the reader model
    """
    gold_passage_map, canonical_questions = (_get_gold_ctx_dict(gold_info_file)
                                             if gold_info_file is not None else
                                             ({}, {}))
    processed_gold_passage_map = (
        _get_processed_gold_ctx_dict(gold_info_processed_file)
        if gold_info_processed_file else {})

    number_no_positive_samples = 0
    number_samples_from_gold = 0
    number_samples_with_gold = 0
    assert len(samples) == len(bm25_samples)

    for sample, bm25_sample in zip(samples, bm25_samples):
        # Refer to `_get_gold_ctx_dict` for why we need to distinguish between two types of questions.
        # Here `processed_question` refers to the tokenized question, while `question` refers to the
        # canonical question.
        processed_question = sample["question"]
        if processed_question in canonical_questions:
            question = canonical_questions[processed_question]
        else:
            question = processed_question

        question_token_ids: np.ndarray = tensorizer.text_to_tensor(
            normalize_question(question)
            if cfg.normalize_questions else question,
            add_special_tokens=False,
        ).numpy()

        orig_answers = sample["answers"]
        if cfg.expand_answers:
            expanded_answers = [
                get_expanded_answer(answer) for answer in orig_answers
            ]
        else:
            expanded_answers = []
        all_answers = orig_answers + sum(expanded_answers, [])

        passages = _select_passages(
            wiki_data,
            sample,
            bm25_sample,
            question,
            processed_question,
            question_token_ids,
            orig_answers,
            expanded_answers,
            all_answers,
            tensorizer,
            gold_passage_map,
            processed_gold_passage_map,
            cfg,
            is_train_set,
            check_pre_tokenized_data,
        )
        gold_passages = passages[0]
        positive_passages, negative_passages, distantly_positive_passages = passages[1:4]
        bm25_positive_passages, bm25_negative_passages, bm25_distantly_positive_passages = passages[4:]

        if is_train_set and len(positive_passages) == 0:
            number_no_positive_samples += 1
            if cfg.skip_no_positives:
                continue

        if any(ctx for ctx in positive_passages if ctx.is_from_gold):
            number_samples_from_gold += 1

        if len(gold_passages) > 0:
            number_samples_with_gold += 1

        yield DataSample(
            question,
            question_token_ids=question_token_ids,
            answers=all_answers,
            orig_answers=orig_answers,
            expanded_answers=expanded_answers,
            # Gold
            gold_passages=gold_passages,
            # Dense
            positive_passages=positive_passages,
            distantly_positive_passages=distantly_positive_passages,
            negative_passages=negative_passages,
            # Sparse
            bm25_positive_passages=bm25_positive_passages,
            bm25_distantly_positive_passages=bm25_distantly_positive_passages,
            bm25_negative_passages=bm25_negative_passages,
        )

    logger.info(
        f"Number of samples whose at least one positive passage is "
        f"from the same article as the gold passage: {number_samples_from_gold}"
    )
    logger.info(
        f"Number of samples whose gold passage is available: {number_samples_with_gold}"
    )
    logger.info(
        f"Number of samples with no positive passages: {number_no_positive_samples}"
    )
Example #13
def gen_ctx_vectors(
    cfg: DictConfig,
    ctx_rows: List[Tuple[object, BiEncoderPassage]],
    q_rows: List[object],
    model: nn.Module,
    tensorizer: Tensorizer,
    insert_title: bool = True,
) -> List[Tuple[object, np.array]]:
    n = len(ctx_rows)
    bsz = cfg.batch_size
    total = 0
    results = []
    for j, batch_start in enumerate(range(0, n, bsz)):
        # Passage preprocessing. TODO: check max sequence length
        batch = ctx_rows[batch_start:batch_start + bsz]
        batch_token_tensors = [
            tensorizer.text_to_tensor(
                ctx[1].text, title=ctx[1].title if insert_title else None)
            for ctx in batch
        ]

        ctx_ids_batch = move_to_device(torch.stack(batch_token_tensors, dim=0),
                                       cfg.device)
        ctx_seg_batch = move_to_device(torch.zeros_like(ctx_ids_batch),
                                       cfg.device)
        ctx_attn_mask = move_to_device(tensorizer.get_attn_mask(ctx_ids_batch),
                                       cfg.device)

        # Question preprocess
        q_batch = q_rows[batch_start:batch_start + bsz]
        q_batch_token_tensors = [
            tensorizer.text_to_tensor(qq) for qq in q_batch
        ]

        q_ids_batch = move_to_device(torch.stack(q_batch_token_tensors, dim=0),
                                     cfg.device)
        q_seg_batch = move_to_device(torch.zeros_like(q_ids_batch), cfg.device)
        q_attn_mask = move_to_device(tensorizer.get_attn_mask(q_ids_batch),
                                     cfg.device)

        # Selector
        from dpr.data.biencoder_data import DEFAULT_SELECTOR
        selector = DEFAULT_SELECTOR
        rep_positions = selector.get_positions(q_ids_batch, tensorizer)

        with torch.no_grad():
            q_dense, ctx_dense = model(
                q_ids_batch,
                q_seg_batch,
                q_attn_mask,
                ctx_ids_batch,
                ctx_seg_batch,
                ctx_attn_mask,
                representation_token_pos=rep_positions,
            )
        q_dense = q_dense.cpu()
        ctx_dense = ctx_dense.cpu()
        ctx_ids = [r[0] for r in batch]

        assert len(ctx_ids) == q_dense.size(0) == ctx_dense.size(0)
        total += len(ctx_ids)

        results.extend([(ctx_ids[i], q_dense[i].numpy(), ctx_dense[i].numpy(),
                         q_dense[i].numpy().dot(ctx_dense[i].numpy()))
                        for i in range(q_dense.size(0))])

        if total % 10 == 0:
            logger.info("Encoded questions / passages %d", total)
            # break
    return results
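This variant additionally stores the question-passage dot product for each aligned pair. A minimal NumPy sketch of that scoring step:

import numpy as np


def score_pairs(q_vecs: np.ndarray, ctx_vecs: np.ndarray) -> np.ndarray:
    """Row-wise dot product between aligned question and passage vectors."""
    return np.einsum("ij,ij->i", q_vecs, ctx_vecs)


q = np.array([[1.0, 0.0], [0.5, 0.5]])
c = np.array([[1.0, 1.0], [2.0, 2.0]])
print(score_pairs(q, c))  # [1. 2.]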
Example #14
    def create_biencoder_input(
        cls,
        samples: List,
        tensorizer: Tensorizer,
        insert_title: bool,
        num_hard_negatives: int = 0,
        num_other_negatives: int = 0,
        shuffle: bool = True,
        shuffle_positives: bool = False,
        max_retrys: int = 100,
    ) -> BiEncoderBatch:
        """
        Creates a batch of the biencoder training tuple.
        :param samples: list of data items (from json) to create the batch for
        :param tensorizer: components to create model input tensors from a text sequence
        :param insert_title: enables title insertion at the beginning of the context sequences
        :param num_hard_negatives: amount of hard negatives per question (taken from samples' pools)
        :param num_other_negatives: amount of other negatives per question (taken from samples' pools)
        :param shuffle: shuffles negative passages pools
        :param shuffle_positives: shuffles positive passages pools
        :param max_retrys: max retry count to find unique positive context
        :return: BiEncoderBatch tuple
        """
        question_tensors = []
        ctx_tensors = []
        positive_ctx_indices = []
        hard_neg_ctx_indices = []

        used_ctxs = set()
        for sample in samples:
            # ctx+ & [ctx-] composition
            # as of now, take the first(gold) ctx+ only
            if shuffle and shuffle_positives:
                positive_ctxs = sample["positive_ctxs"]
                positive_ctx = positive_ctxs[np.random.choice(
                    len(positive_ctxs))]
                retry_counter = 0
                while (positive_ctx['text'] in used_ctxs
                       and retry_counter < max_retrys):
                    positive_ctx = positive_ctxs[np.random.choice(
                        len(positive_ctxs))]
                    retry_counter += 1
                used_ctxs.add(positive_ctx['text'])
            else:
                positive_ctx = sample["positive_ctxs"][0]

            #TODO: probably add negative_ctxs validation

            neg_ctxs = sample["negative_ctxs"]
            hard_neg_ctxs = sample["hard_negative_ctxs"]
            question = normalize_question(sample["question"])

            if shuffle:
                random.shuffle(neg_ctxs)
                random.shuffle(hard_neg_ctxs)

            neg_ctxs = neg_ctxs[0:num_other_negatives]
            hard_neg_ctxs = hard_neg_ctxs[0:num_hard_negatives]

            all_ctxs = [positive_ctx] + neg_ctxs + hard_neg_ctxs
            hard_negatives_start_idx = len(neg_ctxs) + 1  # originally that was 1 which I don't think is right
            hard_negatives_end_idx = len(neg_ctxs) + 1 + len(hard_neg_ctxs)

            current_ctxs_len = len(ctx_tensors)

            sample_ctxs_tensors = [
                tensorizer.text_to_tensor(
                    ctx["text"], title=ctx["title"] if insert_title else None)
                for ctx in all_ctxs
            ]

            ctx_tensors.extend(sample_ctxs_tensors)
            positive_ctx_indices.append(current_ctxs_len)
            hard_neg_ctx_indices.append([
                i for i in range(
                    current_ctxs_len + hard_negatives_start_idx,
                    current_ctxs_len + hard_negatives_end_idx,
                )
            ])

            question_tensors.append(tensorizer.text_to_tensor(question))

        ctxs_tensor = torch.cat([ctx.view(1, -1) for ctx in ctx_tensors],
                                dim=0)
        questions_tensor = torch.cat([q.view(1, -1) for q in question_tensors],
                                     dim=0)

        ctx_segments = torch.zeros_like(ctxs_tensor)
        question_segments = torch.zeros_like(questions_tensor)

        return BiEncoderBatch(
            questions_tensor,
            question_segments,
            ctxs_tensor,
            ctx_segments,
            positive_ctx_indices,
            hard_neg_ctx_indices,
        )
Example #15
def generate_question_vectors(
    question_encoder: torch.nn.Module,
    tensorizer: Tensorizer,
    questions: List[str],
    bsz: int,
    query_token: str = None,
    selector: RepTokenSelector = None,
) -> T:
    n = len(questions)
    query_vectors = []

    with torch.no_grad():
        for j, batch_start in enumerate(range(0, n, bsz)):
            batch_questions = questions[batch_start:batch_start + bsz]

            if query_token:
                # TODO: tmp workaround for EL, remove or revise
                if query_token == "[START_ENT]":
                    batch_tensors = [
                        _select_span_with_token(q,
                                                tensorizer,
                                                token_str=query_token)
                        for q in batch_questions
                    ]
                else:
                    batch_tensors = [
                        tensorizer.text_to_tensor(" ".join([query_token, q]))
                        for q in batch_questions
                    ]
            elif isinstance(batch_questions[0], T):
                batch_tensors = list(batch_questions)
            else:
                batch_tensors = [
                    tensorizer.text_to_tensor(q) for q in batch_questions
                ]

            # TODO: this only works for Wav2vec pipeline but will crash the regular text pipeline
            max_vector_len = max(q_t.size(1) for q_t in batch_tensors)
            min_vector_len = min(q_t.size(1) for q_t in batch_tensors)

            if max_vector_len != min_vector_len:
                # TODO: _pad_to_len move to utils
                from dpr.models.reader import _pad_to_len
                batch_tensors = [
                    _pad_to_len(q.squeeze(0), 0, max_vector_len)
                    for q in batch_tensors
                ]

            q_ids_batch = torch.stack(batch_tensors, dim=0).cuda()
            q_seg_batch = torch.zeros_like(q_ids_batch).cuda()
            q_attn_mask = tensorizer.get_attn_mask(q_ids_batch)

            if selector:
                rep_positions = selector.get_positions(q_ids_batch, tensorizer)

                _, out, _ = BiEncoder.get_representation(
                    question_encoder,
                    q_ids_batch,
                    q_seg_batch,
                    q_attn_mask,
                    representation_token_pos=rep_positions,
                )
            else:
                _, out, _ = question_encoder(q_ids_batch, q_seg_batch,
                                             q_attn_mask)

            query_vectors.extend(out.cpu().split(1, dim=0))

            if len(query_vectors) % 100 == 0:
                logger.info("Encoded queries %d", len(query_vectors))

    query_tensor = torch.cat(query_vectors, dim=0)
    logger.info("Total encoded queries tensor %s", query_tensor.size())
    assert query_tensor.size(0) == len(questions)
    return query_tensor
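The mixed-length branch above leans on _pad_to_len from dpr.models.reader. A self-contained equivalent under the assumption that it right-pads with the pad id and truncates past max_len (a sketch, not the library function itself):

import torch


def pad_to_len(seq: torch.Tensor, pad_id: int, max_len: int) -> torch.Tensor:
    """Truncate or right-pad a 1D tensor to exactly max_len."""
    s_len = seq.size(0)
    if s_len >= max_len:
        return seq[:max_len]
    return torch.cat([seq, torch.full((max_len - s_len,), pad_id, dtype=seq.dtype)])


batch = [torch.tensor([1, 2, 3]), torch.tensor([4, 5])]
longest = max(t.size(0) for t in batch)
print(torch.stack([pad_to_len(t, 0, longest) for t in batch]))  # tensor([[1, 2, 3], [4, 5, 0]])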