Example #1
from typing import List

import torch
from torch import Tensor as T

# Tensorizer, ReaderBatch and BiEncoderBatch come from the surrounding project.


def create_biencoder_input_from_reader_input(
    tensorizer: Tensorizer,
    reader_batch: ReaderBatch,
) -> BiEncoderBatch:
    """
    Convert a reader batch into a bi-encoder batch by splitting each concatenated
    (question, passage title, passage) sequence back into its question and passage
    components.
    """

    input_ids = reader_batch.input_ids  # (N, M, L)
    question_ids: List[T] = []  # len N
    context_ids: List[T] = []  # len N * M

    for input_id_i in input_ids:
        for j, input_id in enumerate(input_id_i):
            ids = tensorizer.unconcatenate_inputs(
                input_id, components={"question", "passage_title", "passage"})

            if ids is None:  # full padding
                context_ids.append(input_id)
                continue

            # Question
            question_id = tensorizer.concatenate_inputs(
                ids={"question": ids["question"].tolist()},
                get_passage_offset=False,
                to_max_length=True,
            )
            if j == 0:
                # All M rows of a sample share the same question, so record it only once
                question_ids.append(question_id)
            else:
                assert (question_id == question_ids[-1]).all()

            # Passage
            passage_title = ids["passage_title"]
            passage = ids["passage"]
            context_ids.append(
                tensorizer.concatenate_inputs(
                    ids={
                        "passage_title": passage_title.tolist(),
                        "passage": passage.tolist()
                    },
                    get_passage_offset=False,
                    to_max_length=True,
                ))

    question_ids = torch.stack(question_ids)
    context_ids = torch.stack(context_ids)

    question_segments = torch.zeros_like(question_ids)
    context_segments = torch.zeros_like(context_ids)

    biencoder_batch = BiEncoderBatch(
        question_ids=question_ids,
        question_segments=question_segments,
        context_IDs=None,  # not used
        context_ids=context_ids,
        ctx_segments=context_segments,
        is_positive=None,  # not used
        hard_negatives=None,  # not used
        encoder_type=None,  # not used
    )
    return biencoder_batch
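
A minimal, self-contained sketch (plain torch tensors and hypothetical shapes, not the project's Tensorizer) of the shape bookkeeping the loop above performs, ignoring the actual splitting and re-concatenation of token IDs: a reader batch of shape (N, M, L) yields N question rows and N * M context rows.

import torch

N, M, L = 2, 3, 8                                   # questions, passages per question, sequence length
reader_input_ids = torch.randint(1, 100, (N, M, L))

# All M rows of a sample encode the same question, so the question is kept
# once per sample while every row contributes one context.
question_rows = reader_input_ids[:, 0, :]           # (N, L)
context_rows = reader_input_ids.reshape(N * M, L)   # (N * M, L)

assert question_rows.shape == (N, L)
assert context_rows.shape == (N * M, L)
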
Example #2
    @classmethod
    def create_biencoder_input_tokenized(
        cls,
        samples: List[BiEncoderSampleTokenized],
        tensorizer: Tensorizer,
        insert_title: bool,
        num_hard_negatives: int = 0,
        num_bm25_negatives: int = 0,
        shuffle: bool = True,
        shuffle_positives: bool = False,
        hard_neg_fallback: bool = True,
        query_token: str = None,
    ) -> BiEncoderBatch:
        """
        Creates a batch of the biencoder training tuple using tokenized data.
        :param samples: list of BiEncoderSampleTokenized-s to create the batch for
        :param tensorizer: components to create model input tensors from a text sequence
        :param insert_title: enables title insertion at the beginning of the context sequences
        :param num_hard_negatives: amount of hard negatives (densely retrieved) per question
        :param num_bm25_negatives: amount of BM25 negatives (sparsely retrieved) per question
        :param shuffle: shuffles negative passages pools
        :param shuffle_positives: shuffles positive passages pools. This is only effective for samples whose gold
            passage is available. In that case, the positive chosen is not necessarily the gold passage. Otherwise,
            the positive passages will be shuffled regardless of this parameter.
        :param hard_neg_fallback: if a sample has no hard negatives, fall back to using its BM25 negatives instead
        :param query_token: optional special query token; must currently be None (see assertion below)
        :return: BiEncoderBatch tuple
        """
        question_tensors: List[T] = []
        ctx_ids: List[int] = []  # passage IDs
        ctx_tensors: List[T] = []
        positive_ctx_indices = []
        hard_neg_ctx_indices = []

        # Strict settings
        assert insert_title is True  # for now only allow `insert_title` to be True
        assert query_token is None

        for sample in samples:
            # Skip samples without positive passages (either gold or distant positives)
            if len(sample.positive_passages) == 0:
                continue
            # ctx+ & [ctx-] composition: take the first (gold) positive by default,
            # or sample a random positive when shuffling positives or when the
            # first positive is not gold

            if (shuffle and shuffle_positives) or (
                    not sample.positive_passages[0].is_gold):
                positive_ctxs = sample.positive_passages
                positive_ctx = positive_ctxs[np.random.choice(
                    len(positive_ctxs))]
            else:
                positive_ctx = sample.positive_passages[0]

            bm25_neg_ctxs = sample.bm25_negative_passages
            hard_neg_ctxs = sample.hard_negative_passages
            question_ids = sample.query_ids

            if shuffle:
                random.shuffle(bm25_neg_ctxs)
                random.shuffle(hard_neg_ctxs)

            if hard_neg_fallback and len(hard_neg_ctxs) == 0:
                hard_neg_ctxs = bm25_neg_ctxs[0:num_hard_negatives]

            bm25_neg_ctxs = bm25_neg_ctxs[0:num_bm25_negatives]
            hard_neg_ctxs = hard_neg_ctxs[0:num_hard_negatives]

            all_ctxs = [positive_ctx] + bm25_neg_ctxs + hard_neg_ctxs
            hard_negatives_start_idx = 1 + len(bm25_neg_ctxs)
            hard_negatives_end_idx = len(all_ctxs)

            current_ctxs_len = len(ctx_tensors)

            # Context IDs
            ctx_id = [ctx.id for ctx in all_ctxs]
            ctx_ids.extend(ctx_id)

            # Context tensors
            sample_ctxs_tensors = [
                tensorizer.concatenate_inputs(
                    ids={
                        "passage_title": list(ctx.title_ids),
                        "passage": list(ctx.text_ids)
                    },
                    get_passage_offset=False,
                    to_max_length=True,
                ) for ctx in all_ctxs
            ]
            ctx_tensors.extend(sample_ctxs_tensors)

            positive_ctx_indices.append(current_ctxs_len)
            hard_neg_ctx_indices.append([
                i for i in range(
                    current_ctxs_len + hard_negatives_start_idx,
                    current_ctxs_len + hard_negatives_end_idx,
                )
            ])

            question_tensors.append(
                tensorizer.concatenate_inputs(ids={"question": question_ids},
                                              get_passage_offset=False,
                                              to_max_length=True))

        ctx_ids = torch.tensor(ctx_ids, dtype=torch.int64)
        ctxs_tensor = torch.cat([ctx.view(1, -1) for ctx in ctx_tensors],
                                dim=0)
        questions_tensor = torch.cat([q.view(1, -1) for q in question_tensors],
                                     dim=0)

        ctx_segments = torch.zeros_like(ctxs_tensor)
        question_segments = torch.zeros_like(questions_tensor)

        return BiEncoderBatch(
            questions_tensor,
            question_segments,
            ctx_ids,
            ctxs_tensor,
            ctx_segments,
            positive_ctx_indices,
            hard_neg_ctx_indices,
            "question",
        )
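
The positive / hard-negative bookkeeping above flattens the contexts of all samples into one list, so the recorded indices are offsets into that flat list. A standalone sketch with hypothetical per-sample counts (2 BM25 and 2 hard negatives each) shows where the indices land:

positive_ctx_indices, hard_neg_ctx_indices = [], []
n_ctxs = 0
for n_bm25, n_hard in [(2, 2), (2, 2)]:     # hypothetical per-sample negative counts
    positive_ctx_indices.append(n_ctxs)     # the positive is always the first context of a sample
    hard_start = n_ctxs + 1 + n_bm25        # skip the positive and the BM25 negatives
    hard_end = n_ctxs + 1 + n_bm25 + n_hard
    hard_neg_ctx_indices.append(list(range(hard_start, hard_end)))
    n_ctxs = hard_end                       # the flat context list grows per sample

print(positive_ctx_indices)                 # [0, 5]
print(hard_neg_ctx_indices)                 # [[3, 4], [8, 9]]
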
Example #3
from typing import List

import numpy as np
import torch
from torch import Tensor as T

# TokenizedWikipediaPassages, ReaderPassage, Tensorizer and the helper functions
# _get_positive_idx / _get_answer_spans come from the surrounding project.


def _create_question_passages_tensors(
    wiki_data: TokenizedWikipediaPassages,
    question_token_ids: np.ndarray,
    tensorizer: Tensorizer,
    positives: List[ReaderPassage],
    negatives: List[ReaderPassage],
    total_size: int,
    empty_ids: T,
    max_n_answers: int,
    is_train: bool,
    is_random: bool = True
):
    max_len = empty_ids.size(0)
    pad_token_id = tensorizer.get_pad_id()

    if is_train:
        # select just one positive
        positive_idx = _get_positive_idx(positives, max_len, is_random)
        if positive_idx is None:
            return None

        positive = positives[positive_idx]

        if getattr(positive, "sequence_ids", None) is None:
            # Load in passage tokens and title tokens
            positive.load_tokens(
                question_token_ids=question_token_ids,
                **wiki_data.get_tokenized_data(int(positive.id))
            )
            sequence_ids, passage_offset = tensorizer.concatenate_inputs({
                "question": positive.question_token_ids,
                "passage_title": positive.title_token_ids,
                "passage": positive.passage_token_ids,
            }, get_passage_offset=True)

            positive.sequence_ids = sequence_ids
            positive.passage_offset = passage_offset
            positive.answers_spans = [
                (start + passage_offset, end + passage_offset)
                for start, end in positive.answers_spans
            ]

        positive_a_spans = _get_answer_spans(positive_idx, positives, max_len)[0: max_n_answers]

        answer_starts = [span[0] for span in positive_a_spans]
        answer_ends = [span[1] for span in positive_a_spans]

        assert all(s < max_len for s in answer_starts)
        assert all(e < max_len for e in answer_ends)

        positive_input_ids = tensorizer.to_max_length(positive.sequence_ids.numpy(), apply_max_len=True)
        positive_input_ids = torch.from_numpy(positive_input_ids)

        answer_starts_tensor = torch.zeros((total_size, max_n_answers)).long()
        answer_starts_tensor[0, 0:len(answer_starts)] = torch.tensor(answer_starts)  # only first passage contains the answer

        answer_ends_tensor = torch.zeros((total_size, max_n_answers)).long()
        answer_ends_tensor[0, 0:len(answer_ends)] = torch.tensor(answer_ends)  # only first passage contains the answer

        answer_mask = torch.zeros((total_size, max_n_answers), dtype=torch.long)
        answer_mask[0, 0:len(answer_starts)] = torch.tensor([1 for _ in range(len(answer_starts))])

        positives_IDs: List[int] = [positive.id]
        positives_selected = [positive_input_ids]

    else:
        positives_IDs: List[int] = []
        positives_selected = []
        answer_starts_tensor = None
        answer_ends_tensor = None
        answer_mask = None

    positives_num = len(positives_selected)
    if is_random:
        negative_idxs = np.random.permutation(range(len(negatives)))
    else:
        negative_idxs = range(len(negatives) - positives_num)

    # Keep only as many negatives as fit in the batch alongside the selected positives
    negative_idxs = negative_idxs[:total_size - positives_num]

    negatives_IDs: List[int] = []
    negatives_selected = []

    for negative_idx in negative_idxs:
        negative = negatives[negative_idx]

        if getattr(negative, "sequence_ids", None) is None:
            # Load in passage tokens and title tokens
            negative.load_tokens(
                question_token_ids=question_token_ids,
                **wiki_data.get_tokenized_data(int(negative.id))
            )
            # Concatenate input tokens
            sequence_ids, passage_offset = tensorizer.concatenate_inputs({
                "question": negative.question_token_ids,
                "passage_title": negative.title_token_ids,
                "passage": negative.passage_token_ids,
            }, get_passage_offset=True)
            negative.sequence_ids = sequence_ids
            negative.passage_offset = passage_offset

        negatives_IDs.append(negative.id)

        negative_input_ids = tensorizer.to_max_length(negative.sequence_ids.numpy(), apply_max_len=True)
        negatives_selected.append(torch.from_numpy(negative_input_ids))

    # Pad with placeholder sequences (ID -1) when there are not enough negatives
    while len(negatives_selected) < total_size - positives_num:
        negatives_IDs.append(-1)
        negatives_selected.append(empty_ids.clone())

    context_IDs = torch.tensor(positives_IDs + negatives_IDs, dtype=torch.int64)
    input_ids = torch.stack(positives_selected + negatives_selected, dim=0)
    assert len(context_IDs) == len(input_ids)

    return context_IDs, input_ids, answer_starts_tensor, answer_ends_tensor, answer_mask
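
Because concatenation places the question and title tokens in front of the passage tokens, answer spans recorded relative to the passage have to be shifted by the passage offset (as done for positive.answers_spans above). A toy sketch with made-up numbers:

passage_offset = 12                         # hypothetical length of the question + title prefix
answers_spans = [(3, 5), (20, 24)]          # (start, end) token positions within the passage

shifted_spans = [(start + passage_offset, end + passage_offset)
                 for start, end in answers_spans]
print(shifted_spans)                        # [(15, 17), (32, 36)]
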