def generate_question_vectors(
    question_encoder: torch.nn.Module,
    tensorizer: Tensorizer,
    questions: List[str],
    bsz: int,
    query_token: str = None,
    selector: RepTokenSelector = None,
) -> T:
    """Encode ``questions`` in batches of ``bsz`` and return one stacked tensor.

    Every batch is tokenised, moved to CUDA and pushed through
    ``question_encoder`` with gradients disabled; the per-question outputs are
    collected on the CPU.

    :param question_encoder: encoder module called as ``encoder(ids, segments, mask)``
    :param tensorizer: provides ``text_to_tensor`` and ``get_attn_mask``
    :param questions: raw question strings
    :param bsz: number of questions per forward pass
    :param query_token: optional token; ``"[START_ENT]"`` routes each question
        through ``_select_span_with_token``, any other value is prepended to
        the question text (space separated) before tokenisation
    :param selector: optional selector; when given, the representation is taken
        at the positions it returns via ``BiEncoder.get_representation``
    :return: tensor whose first dimension equals ``len(questions)``
    """

    def _tokenize(question: str) -> T:
        # Three tokenisation modes, chosen by ``query_token``.
        if not query_token:
            return tensorizer.text_to_tensor(question)
        if query_token == "[START_ENT]":
            return _select_span_with_token(question,
                                           tensorizer,
                                           token_str=query_token)
        return tensorizer.text_to_tensor(" ".join([query_token, question]))

    total_questions = len(questions)
    encoded = []

    with torch.no_grad():
        for offset in range(0, total_questions, bsz):
            token_tensors = [
                _tokenize(question)
                for question in questions[offset:offset + bsz]
            ]

            ids = torch.stack(token_tensors, dim=0).cuda()
            segments = torch.zeros_like(ids).cuda()
            attn_mask = tensorizer.get_attn_mask(ids)

            if selector:
                positions = selector.get_positions(ids, tensorizer)
                _, batch_out, _ = BiEncoder.get_representation(
                    question_encoder,
                    ids,
                    segments,
                    attn_mask,
                    representation_token_pos=positions,
                )
            else:
                _, batch_out, _ = question_encoder(ids, segments, attn_mask)

            # Keep one (1, dim) slice per question so the final cat restores order.
            encoded.extend(batch_out.cpu().split(1, dim=0))

            if len(encoded) % 100 == 0:
                logger.info("Encoded queries %d", len(encoded))

    query_tensor = torch.cat(encoded, dim=0)
    logger.info("Total encoded queries tensor %s", query_tensor.size())
    assert query_tensor.size(0) == len(questions)
    return query_tensor
def gen_ctx_vectors(ctx_rows: List[Tuple[object, str, str]], model: nn.Module, tensorizer: Tensorizer,
                    insert_title: bool = True) -> List[Tuple[object, np.array]]:
    """Encode passages in batches and pair every passage id with its vector.

    Each row is read as ``(id, text, title)``. The batch size comes from the
    module-level ``args.batch_size``.

    :param ctx_rows: passage rows to encode
    :param model: encoder called as ``model(ids, segments, attn_mask)``; the
        second element of its return value is used as the passage vectors
    :param tensorizer: provides ``text_to_tensor`` and ``get_attn_mask``
    :param insert_title: when false, ``title=None`` is passed to the tensorizer
    :return: list of ``(passage_id, flattened numpy vector)`` tuples
    """
    row_count = len(ctx_rows)
    batch_size = args.batch_size
    encoded_count = 0
    results = []

    for offset in range(0, row_count, batch_size):
        rows = ctx_rows[offset:offset + batch_size]

        token_tensors = []
        for row in rows:
            title = row[2] if insert_title else None
            token_tensors.append(tensorizer.text_to_tensor(row[1], title=title))

        ids = torch.stack(token_tensors, dim=0)
        segments = torch.zeros_like(ids)
        attn_mask = tensorizer.get_attn_mask(ids)
        with torch.no_grad():
            _, vectors, _ = model(ids, segments, attn_mask)
        vectors = vectors.cpu()

        passage_ids = [row[0] for row in rows]
        assert len(passage_ids) == vectors.size(0)
        encoded_count += len(passage_ids)

        for row_idx in range(vectors.size(0)):
            results.append(
                (passage_ids[row_idx], vectors[row_idx].view(-1).numpy()))

        if encoded_count % 10 == 0:
            logger.info('Encoded passages %d', encoded_count)

    return results
Example #3
0
def _do_biencoder_fwd_pass(model: nn.Module, input: BiEncoderBatch,
                           tensorizer: Tensorizer,
                           args) -> (torch.Tensor, int):
    """Run the bi-encoder on one batch and return its loss and hit count.

    The batch is moved to ``args.device`` first. The model runs under
    ``torch.no_grad()`` whenever ``model.training`` is false.

    :param model: bi-encoder returning ``(question_vectors, context_vectors)``
    :param input: batch tuple; rebuilt on the target device
    :param tensorizer: supplies ``get_attn_mask``
    :param args: reads ``device``, ``n_gpu`` and ``gradient_accumulation_steps``
    :return: ``(loss, is_correct)``; the loss is averaged when ``n_gpu > 1``
        and divided by the accumulation steps when those exceed 1
    """
    batch = BiEncoderBatch(**move_to_device(input._asdict(), args.device))

    # Positional arguments shared by the training and evaluation call.
    model_inputs = (
        batch.question_ids,
        batch.question_segments,
        tensorizer.get_attn_mask(batch.question_ids),
        batch.context_ids,
        batch.ctx_segments,
        tensorizer.get_attn_mask(batch.context_ids),
    )

    if model.training:
        local_q_vector, local_ctx_vectors = model(*model_inputs)
    else:
        with torch.no_grad():
            local_q_vector, local_ctx_vectors = model(*model_inputs)

    loss, correct_flags = _calc_loss(args, BiEncoderNllLoss(), local_q_vector,
                                     local_ctx_vectors, batch.is_positive,
                                     batch.hard_negatives)
    num_correct = correct_flags.sum().item()

    if args.n_gpu > 1:
        loss = loss.mean()

    accumulation_steps = args.gradient_accumulation_steps
    if accumulation_steps > 1:
        loss = loss / accumulation_steps

    return loss, num_correct
def _do_biencoder_fwd_pass(
    model: nn.Module,
    input: BiEncoderBatch,
    tensorizer: Tensorizer,
    cfg,
    encoder_type: str,
    rep_positions=0,
    loss_scale: float = None,
) -> Tuple[torch.Tensor, int]:
    """Run one bi-encoder forward pass and compute the batch loss.

    The batch is moved to ``cfg.device``, attention masks are derived from the
    question and context ids, and the model is called (under
    ``torch.no_grad()`` when ``model.training`` is false).

    :param model: bi-encoder returning ``(question_vectors, context_vectors)``
    :param input: batch tuple; rebuilt on ``cfg.device``
    :param tensorizer: supplies ``get_attn_mask``
    :param cfg: configuration; reads ``device``, ``n_gpu`` and
        ``train.gradient_accumulation_steps``
    :param encoder_type: forwarded to the model as ``encoder_type``
    :param rep_positions: forwarded to the model as ``representation_token_pos``
    :param loss_scale: forwarded to ``_calc_loss``
    :return: ``(loss, is_correct)`` where ``is_correct`` is the summed
        correctness tensor returned by ``_calc_loss``
    """
    input = BiEncoderBatch(**move_to_device(input._asdict(), cfg.device))

    q_attn_mask = tensorizer.get_attn_mask(input.question_ids)
    ctx_attn_mask = tensorizer.get_attn_mask(input.context_ids)

    # Same arguments in both modes; only the gradient context differs.
    model_args = (
        input.question_ids,
        input.question_segments,
        q_attn_mask,
        input.context_ids,
        input.ctx_segments,
        ctx_attn_mask,
    )
    model_kwargs = {
        "encoder_type": encoder_type,
        "representation_token_pos": rep_positions,
    }

    if model.training:
        model_out = model(*model_args, **model_kwargs)
    else:
        with torch.no_grad():
            model_out = model(*model_args, **model_kwargs)

    local_q_vector, local_ctx_vectors = model_out

    loss_function = BiEncoderNllLoss()

    loss, is_correct = _calc_loss(
        cfg,
        loss_function,
        local_q_vector,
        local_ctx_vectors,
        input.is_positive,
        input.hard_negatives,
        loss_scale=loss_scale,
    )

    is_correct = is_correct.sum().item()

    if cfg.n_gpu > 1:
        loss = loss.mean()

    # Read the step count once from the same config path that guards the
    # division (the original divided by ``cfg.gradient_accumulation_steps``
    # after testing ``cfg.train.gradient_accumulation_steps``).
    accumulation_steps = cfg.train.gradient_accumulation_steps
    if accumulation_steps > 1:
        loss = loss / accumulation_steps
    return loss, is_correct
Example #5
0
def create_biencoder_input_from_reader_input(
    tensorizer: Tensorizer,
    reader_batch: ReaderBatch,
) -> BiEncoderBatch:
    """Split concatenated reader inputs back into bi-encoder inputs.

    Every entry of ``reader_batch.input_ids`` is un-concatenated into its
    question, passage title and passage parts. The question is re-encoded once
    per outer entry (taken from the first inner entry; later ones are asserted
    to match), and each title + passage pair is re-encoded as a context.
    Entries for which ``unconcatenate_inputs`` returns ``None`` are passed
    through unchanged as contexts.

    :param tensorizer: supplies ``unconcatenate_inputs`` / ``concatenate_inputs``
    :param reader_batch: reader batch; only ``input_ids`` is read
    :return: ``BiEncoderBatch`` with ids and all-zero segment tensors; the
        remaining fields are set to ``None``
    """
    per_question_ids: List[T] = []  # one entry per outer item
    per_passage_ids: List[T] = []  # one entry per inner item

    for sample_inputs in reader_batch.input_ids:  # (N, M, L)
        for passage_idx, packed_ids in enumerate(sample_inputs):
            parts = tensorizer.unconcatenate_inputs(
                packed_ids, components={"question", "passage_title", "passage"})

            # Fully padded slot: keep it as-is so the context count is preserved.
            if parts is None:
                per_passage_ids.append(packed_ids)
                continue

            encoded_question = tensorizer.concatenate_inputs(
                ids={"question": parts["question"].tolist()},
                get_passage_offset=False,
                to_max_length=True,
            )
            if passage_idx != 0:
                assert (encoded_question == per_question_ids[-1]).all()
            else:
                per_question_ids.append(encoded_question)

            title_ids = parts["passage_title"]
            body_ids = parts["passage"]
            encoded_context = tensorizer.concatenate_inputs(
                ids={
                    "passage_title": title_ids.tolist(),
                    "passage": body_ids.tolist()
                },
                get_passage_offset=False,
                to_max_length=True,
            )
            per_passage_ids.append(encoded_context)

    stacked_questions = torch.stack(per_question_ids)
    stacked_contexts = torch.stack(per_passage_ids)

    return BiEncoderBatch(
        question_ids=stacked_questions,
        question_segments=torch.zeros_like(stacked_questions),
        context_IDs=None,  # not used
        context_ids=stacked_contexts,
        ctx_segments=torch.zeros_like(stacked_contexts),
        is_positive=None,  # not used
        hard_negatives=None,  # not used
        encoder_type=None,  # not used
    )
Example #6
0
 def _run_preprocessing(tensorizer: Tensorizer):
     """Serialize retriever results with auto-padding switched off.

     Padding is disabled on ``tensorizer`` for the duration of the conversion
     and re-enabled afterwards, even if the conversion raises.

     :param tensorizer: tensorizer whose pad-to-max flag is toggled
     :return: whatever ``convert_retriever_results`` returns
     """
     # temporarily disable auto-padding to save disk space usage of serialized files
     tensorizer.set_pad_to_max(False)
     try:
         # NOTE(review): the conversion receives ``self.tensorizer`` while the
         # flag is toggled on the ``tensorizer`` argument - confirm callers
         # always pass the same object.
         serialized_files = convert_retriever_results(is_train, data_files[0], out_file_prefix,
                                                      gold_passages_src,
                                                      self.tensorizer,
                                                      num_workers=self.args.num_workers)
     finally:
         # Restore padding so a failed conversion does not leave it disabled.
         tensorizer.set_pad_to_max(True)
     return serialized_files
Example #7
0
def _find_answer_spans(
    tensorizer: Tensorizer,
    ctx: DataPassage,
    question: str,
    answers: List[str],
    answers_token_ids: List[List[int]],
    warn_if_no_answer: bool = False,
    raise_if_no_answer: bool = False,
    warn_if_has_answer: bool = False,
    raise_if_has_answer: bool = False,
    recheck_negatives: bool = False,
) -> DataPassage:
    """Locate answer spans in ``ctx`` and update its answer bookkeeping.

    Passages with ``has_answer`` unset are returned untouched unless
    ``recheck_negatives`` is true. Otherwise ``ctx.answers_spans`` is set to
    the non-empty spans reported by ``_find_answer_positions`` and
    ``ctx.has_answer`` to whether any span was found.

    :param tensorizer: used only to decode ids for the diagnostic message
    :param ctx: passage to inspect; modified in place and returned
    :param question: question text, used in the diagnostic message
    :param answers: answer strings (one lookup is made per entry)
    :param answers_token_ids: token ids for each answer, indexed like ``answers``
    :param warn_if_no_answer: log a warning when no span is found
    :param raise_if_no_answer: raise ``ValueError`` when no span is found
    :param warn_if_has_answer: log a warning when a span is found
    :param raise_if_has_answer: raise ``ValueError`` when a span is found
    :param recheck_negatives: also search passages whose ``has_answer`` is false
    :return: the same ``ctx`` object
    :raises ValueError: see the ``raise_if_*`` flags
    """
    if not (recheck_negatives or ctx.has_answer):
        return ctx

    located = []
    for answer_idx in range(len(answers)):
        located.extend(
            _find_answer_positions(ctx.passage_token_ids,
                                   answers_token_ids[answer_idx]))

    # Drop empty/falsy entries.
    answers_spans = [span for span in located if span]
    ctx.answers_spans = answers_spans

    def _decode_passage():
        # Text and title are only decoded when a message is actually needed.
        text = tensorizer.tensor_to_text(
            torch.from_numpy(ctx.passage_token_ids))
        title = tensorizer.tensor_to_text(
            torch.from_numpy(ctx.title_token_ids))
        return text, title

    if not answers_spans:
        if warn_if_no_answer or raise_if_no_answer:
            passage_text, passage_title = _decode_passage()
            message = (
                f"No answer found in passage id={ctx.id} text={passage_text}, title={passage_title}, "
                f"answers={answers}, question={question}")
            if raise_if_no_answer:
                raise ValueError(message)
            logger.warning(message)
    elif warn_if_has_answer or raise_if_has_answer:
        passage_text, passage_title = _decode_passage()
        message = (
            f"Answer FOUND in passage id={ctx.id} text={passage_text}, title={passage_title}, "
            f"answers={answers}, question={question}")
        if raise_if_has_answer:
            raise ValueError(message)
        logger.warning(message)

    ctx.has_answer = bool(answers_spans)
    return ctx
Example #8
0
def _extend_span_to_full_words(tensorizer: Tensorizer, tokens: List[int], span: Tuple[int, int]) -> Tuple[int, int]:
    start_index, end_index = span
    max_len = len(tokens)
    while start_index > 0 and tensorizer.is_sub_word_id(tokens[start_index]):
        start_index -= 1

    while end_index < max_len - 1 and tensorizer.is_sub_word_id(tokens[end_index + 1]):
        end_index += 1

    return start_index, end_index
def gen_ctx_vectors(
    cfg: DictConfig,
    ctx_rows: List[Tuple[object, BiEncoderPassage]],
    model: nn.Module,
    tensorizer: Tensorizer,
    insert_title: bool = True,
) -> List[Tuple[object, np.array]]:
    """Encode passages batch by batch on ``cfg.device``.

    Each row is read as ``(id, passage, ...)`` where ``passage`` exposes
    ``text`` and ``title``. When the first row of a batch has more than three
    elements, everything from index 3 onwards of each row is appended to that
    row's result tuple.

    :param cfg: configuration; reads ``batch_size`` and ``device``
    :param ctx_rows: passage rows to encode
    :param model: encoder called as ``model(ids, segments, attn_mask)``; the
        second element of its return value is used as the passage vectors
    :param tensorizer: supplies ``text_to_tensor`` and ``get_attn_mask``
    :param insert_title: when false, ``title=None`` is passed to the tensorizer
    :return: list of ``(passage_id, flattened numpy vector, *extras)`` tuples
    """
    row_count = len(ctx_rows)
    batch_size = cfg.batch_size
    encoded_count = 0
    results = []

    for offset in range(0, row_count, batch_size):
        rows = ctx_rows[offset : offset + batch_size]

        token_tensors = []
        for row in rows:
            passage = row[1]
            token_tensors.append(
                tensorizer.text_to_tensor(
                    passage.text, title=passage.title if insert_title else None
                )
            )

        ids = move_to_device(torch.stack(token_tensors, dim=0), cfg.device)
        segments = move_to_device(torch.zeros_like(ids), cfg.device)
        attn_mask = move_to_device(tensorizer.get_attn_mask(ids), cfg.device)
        with torch.no_grad():
            _, vectors, _ = model(ids, segments, attn_mask)
        vectors = vectors.cpu()

        passage_ids = [row[0] for row in rows]

        # The first row decides whether this batch carries trailing extras;
        # an empty tuple per row keeps the result tuples two elements long.
        if len(rows[0]) > 3:
            extras = [row[3:] for row in rows]
        else:
            extras = [() for _ in rows]

        assert len(passage_ids) == vectors.size(0)
        encoded_count += len(passage_ids)

        for row_idx in range(vectors.size(0)):
            results.append(
                (
                    passage_ids[row_idx],
                    vectors[row_idx].view(-1).numpy(),
                    *extras[row_idx],
                )
            )

        if encoded_count % 10 == 0:
            logger.info("Encoded passages %d", encoded_count)
    return results
Example #10
0
def gen_ctx_vectors(
        ctx_rows: List[Tuple[object, str, str]],
        model: nn.Module,
        tensorizer: Tensorizer,
        insert_title: bool = True) -> List[Tuple[object, np.array]]:
    """Encode passages rendered as ``title: ... context: ...`` strings.

    Each row is read as ``(id, text, title)``. The title prefix is added
    whenever the row's title is truthy; ``insert_title`` is accepted but not
    consulted by this implementation. Texts are tokenised with
    ``max_length=250``. Batch size and device come from the module-level
    ``args``.

    :param ctx_rows: passage rows to encode
    :param model: encoder called as ``model(ids, segments, attn_mask)``; the
        second element of its return value is used as the passage vectors
    :param tensorizer: supplies ``text_to_tensor`` and ``get_attn_mask``
    :param insert_title: unused here
    :return: list of ``(passage_id, numpy vector)`` tuples (vectors are not flattened)
    """
    row_count = len(ctx_rows)
    batch_size = args.batch_size
    encoded_count = 0
    results = []

    for offset in range(0, row_count, batch_size):
        rows = ctx_rows[offset:offset + batch_size]

        rendered = []
        for row in rows:
            pieces = ['title:', row[2], 'context:', row[1]] if row[2] else ['context:', row[1]]
            rendered.append(' '.join(pieces))

        token_tensors = [
            tensorizer.text_to_tensor(text, max_length=250) for text in rendered
        ]

        ids = move_to_device(torch.stack(token_tensors, dim=0), args.device)
        segments = move_to_device(torch.zeros_like(ids), args.device)
        attn_mask = move_to_device(tensorizer.get_attn_mask(ids), args.device)
        with torch.no_grad():
            _, vectors, _ = model(ids, segments, attn_mask)
        vectors = vectors.cpu()

        passage_ids = [row[0] for row in rows]
        assert len(passage_ids) == vectors.size(0)
        encoded_count += len(passage_ids)

        for row_idx in range(vectors.size(0)):
            results.append((passage_ids[row_idx], vectors[row_idx].numpy()))

        if encoded_count % 10 == 0:
            logger.info('Encoded passages %d', encoded_count)

    return results
Example #11
0
    def get_positions(self,
                      input_ids: T,
                      tenzorizer: Tensorizer,
                      model: torch.nn.Module = None):
        """Return one ``[row, column]`` position of the special token per sample.

        The token id is looked up lazily from ``self.token`` and cached on
        ``self.token_id``. For every row of ``input_ids`` the left-most
        occurrence of the token is reported; rows without the token fall back
        to column 0 and a warning is logged.

        :param input_ids: 2-D tensor of token ids, one row per sample
        :param tenzorizer: supplies ``get_token_id``
        :param model: unused
        :return: integer tensor of shape ``(batch, 2)`` on the device of
            ``input_ids``
        """
        if not self.token_id:
            self.token_id = tenzorizer.get_token_id(self.token)
        token_indexes = (input_ids == self.token_id).nonzero()
        bsz = input_ids.size(0)

        # nonzero() lists hits in row-major order, so the first hit stored for
        # a row is its left-most occurrence. Keeping only that one avoids
        # misaligning later rows when a sample contains the token twice.
        first_hit = {}
        for row, col in token_indexes.tolist():
            first_hit.setdefault(row, col)

        # Exactly one hit in every row: the raw index tensor is already the answer.
        if token_indexes.size(0) == bsz and len(first_hit) == bsz:
            return token_indexes

        positions = []
        for i in range(bsz):
            col = first_hit.get(i)
            if col is None:
                logger.warning("missing special token %s", input_ids[i])
                col = 0  # setting 0-th token, i.e. CLS for BERT as the special one
            positions.append([i, col])

        return torch.tensor(positions, dtype=torch.long, device=input_ids.device)
Example #12
0
def _load_tokens_into_ctx(
    ctx: DataPassage,
    question_token_ids: np.ndarray,
    wiki_data: TokenizedWikipediaPassages,
    tensorizer: Tensorizer,
    check_pre_tokenized_data: bool = True,
) -> DataPassage:
    tokens = wiki_data.get_tokenized_data(int(ctx.id))

    # Double check if needed
    if ctx.passage_text is not None:
        orig_passage_ids = tensorizer.text_to_tensor(
            ctx.passage_text,
            add_special_tokens=False,
        ).numpy()
        if check_pre_tokenized_data and (len(orig_passage_ids) != len(tokens["passage_token_ids"]) or \
                not (orig_passage_ids == tokens["passage_token_ids"]).all()):
            raise ValueError(
                f"Passage token mismatch: id: {ctx.id}, orig: {orig_passage_ids}, "
                f"pre-processed: {tokens['passage_token_ids']}. If the sequence lengths are different,"
                f" this might be because the maximum length of the tokenizer is set differently during "
                f"pre-processing and training.")

        orig_title_ids = tensorizer.text_to_tensor(
            ctx.title,
            add_special_tokens=False,
        ).numpy()
        if check_pre_tokenized_data and (len(orig_title_ids) != len(tokens["title_token_ids"]) or \
                not (orig_title_ids == tokens["title_token_ids"]).all()):
            raise ValueError(
                f"Passage title token mismatch: id: {ctx.id}, orig: {orig_title_ids}, "
                f"pre-processed: {tokens['title_token_ids']}. If the sequence lengths are different,"
                f" this might be because the maximum length of the tokenizer is set differently during "
                f"pre-processing and training.")

    ctx.load_tokens(
        question_token_ids=question_token_ids,
        **tokens)  # load question, passage and passage title tokens

    # Remove redundant data
    ctx.on_serialize(remove_tokens=False)

    return ctx
Example #13
0
def get_best_spans(
    tensorizer: Tensorizer,
    start_logits: List,
    end_logits: List,
    ctx_ids: List,
    max_answer_length: int,
    passage_idx: int,
    relevance_score: float,
    top_spans: int = 1,
) -> List[SpanPrediction]:
    """
    Pick up to ``top_spans`` best-scoring answer spans for one passage.

    Every ``(start, end)`` pair with at most ``max_answer_length`` tokens is
    scored as ``start_logit + end_logit``. Candidates are visited from the
    highest score down; a candidate is skipped when it contains, or is
    contained in, a span that was already accepted. Accepted spans are widened
    by ``_extend_span_to_full_words`` and their text is upper-cased.

    :param tensorizer: supplies ``to_string`` (and sub-word detection)
    :param start_logits: per-token start scores
    :param end_logits: per-token end scores
    :param ctx_ids: token ids of the passage
    :param max_answer_length: maximum span length in tokens
    :param passage_idx: stored on every prediction
    :param relevance_score: stored on every prediction
    :param top_spans: maximum number of spans to return
    :return: accepted predictions, best first
    """
    candidates = [
        ((start, start + width), start_logit + end_logit)
        for start, start_logit in enumerate(start_logits)
        for width, end_logit in enumerate(end_logits[start : start + max_answer_length])
    ]
    candidates.sort(key=lambda candidate: candidate[1], reverse=True)

    accepted_intervals = []
    best_spans = []

    for (start_index, end_index), score in candidates:
        assert start_index <= end_index
        assert end_index - start_index + 1 <= max_answer_length

        # Nested in either direction with an already accepted span?
        nested = any(
            start_index <= taken_start <= taken_end <= end_index
            or taken_start <= start_index <= end_index <= taken_end
            for taken_start, taken_end in accepted_intervals
        )
        if nested:
            continue

        # Snap the boundaries outwards to whole words.
        start_index, end_index = _extend_span_to_full_words(
            tensorizer, ctx_ids, (start_index, end_index)
        )

        answer_ids = ctx_ids[start_index : end_index + 1]
        predicted_answer = tensorizer.to_string(answer_ids).upper()
        best_spans.append(
            SpanPrediction(
                predicted_answer, score, relevance_score, passage_idx, ctx_ids
            )
        )
        accepted_intervals.append((start_index, end_index))

        if len(accepted_intervals) == top_spans:
            break
    return best_spans
Example #14
0
def _select_span_with_token(
    text: str, tensorizer: Tensorizer, token_str: str = "[START_ENT]"
) -> T:
    """Tokenise ``text`` so that the marker token survives truncation.

    If the marker is already present in the regular tokenisation, that tensor
    is returned. Otherwise the untruncated tokenisation is cut so that it
    starts roughly half of ``tensorizer.max_length`` before the first marker
    (with a random offset), a CLS id is prepended if missing, the result is
    padded to ``max_length`` via ``_pad_to_len`` and its last id is set to SEP.

    :param text: query text expected to contain the marker
    :param tensorizer: supplies tokenisation, ``max_length`` and special ids
    :param token_str: marker token to look for
    :return: token id tensor containing the marker
    :raises RuntimeError: if the marker is absent from the full tokenisation
    """
    marker_id = tensorizer.get_token_id(token_str)
    query_tensor = tensorizer.text_to_tensor(text)

    # Fast path: truncation did not drop the marker.
    if marker_id in query_tensor:
        return query_tensor

    full_tensor = tensorizer.text_to_tensor(text, apply_max_len=False)
    marker_hits = (full_tensor == marker_id).nonzero()
    if marker_hits.size(0) == 0:
        raise RuntimeError(
            "[START_ENT] toke not found for Entity Linking sample query={}".format(
                text
            )
        )

    marker_pos = marker_hits[0, 0].item()

    # add some randomization to avoid overfitting to a specific token position
    left_shift = int(tensorizer.max_length / 2)
    jitter = int((rnd.random() - 0.5) * left_shift / 2)
    left_shift += jitter

    query_tensor = full_tensor[marker_pos - left_shift :]
    cls_id = tensorizer.tokenizer.cls_token_id
    if query_tensor[0] != cls_id:
        query_tensor = torch.cat([torch.tensor([cls_id]), query_tensor], dim=0)

    from dpr.models.reader import _pad_to_len

    query_tensor = _pad_to_len(
        query_tensor, tensorizer.get_pad_id(), tensorizer.max_length
    )
    query_tensor[-1] = tensorizer.tokenizer.sep_token_id

    assert marker_id in query_tensor, "query_tensor={}".format(query_tensor)
    return query_tensor
Example #15
0
    def get_positions(self, input_ids: T, tenzorizer: Tensorizer,
                      model: torch.nn.Module):
        """Choose one representation position per sample.

        While ``model.training`` is true, a position is drawn uniformly from
        the non-zero part of each sample's attention mask; otherwise
        ``self.static_position`` is used.

        :param input_ids: token ids, one row per sample
        :param tenzorizer: supplies ``get_attn_mask``
        :param model: only its ``training`` flag is read
        :return: ``(batch, 2)`` long tensor holding the chosen position in both
            columns; no device is specified when it is created
        """
        attention_masks = tenzorizer.get_attn_mask(input_ids)
        rep_positions = []

        for attention_mask in attention_masks:
            if model.training:
                # Number of attended tokens, as a plain int for random.randint.
                input_length = int((attention_mask != 0).sum().item())
                rep_positions.append(random.randint(0, input_length - 1))
            else:
                # Fall back to default
                rep_positions.append(self.static_position)

        # int64 instead of int8: positions of 128 and above do not fit in int8.
        rep_positions = torch.tensor(rep_positions,
                                     dtype=torch.long).unsqueeze(-1).repeat(
                                         1, 2)
        return rep_positions
Example #16
0
def _select_reader_passages(
    sample: Dict,
    question: str,
    tensorizer: Tensorizer,
    gold_passage_map: Optional[Dict[str, ReaderPassage]],
    gold_page_only_positives: bool,
    max_positives: int,
    max1_negatives: int,
    max2_negatives: int,
    max_retriever_passages: int,
    include_gold_passage: bool,
    is_train_set: bool,
) -> Tuple[List[ReaderPassage], List[ReaderPassage]]:
    """Split one retriever sample into positive and negative reader passages.

    For a train set, passages flagged ``has_answer`` are positive candidates
    and the rest are negatives; otherwise every passage is a negative.
    Positives from the gold wiki page are preferred; if none of them has an
    answer span, all positive candidates are searched and capped at
    ``max_positives``.

    :param sample: retriever record; reads ``"answers"`` and ``"ctxs"``
    :param question: question text used for gold lookups and log messages
    :param tensorizer: tokenises answers and, when missing, passage text
    :param gold_passage_map: question -> gold passage; may be ``None`` or empty
    :param gold_page_only_positives: restrict first-choice positives to the gold page
    :param max_positives: cap applied to the fallback positives
    :param max1_negatives: lower bound inside the train-set negatives formula
    :param max2_negatives: upper bound inside the train-set negatives formula
    :param max_retriever_passages: number of retrieved passages considered
    :param include_gold_passage: append the gold passage when it is not already selected
    :param is_train_set: whether positives are selected at all
    :return: ``(positive_passages, negative_passages)``
    """
    answers = sample["answers"]

    ctxs = [ReaderPassage(**ctx) for ctx in sample["ctxs"]][0:max_retriever_passages]
    answers_token_ids = [tensorizer.text_to_tensor(a, add_special_tokens=False) for a in answers]

    if is_train_set:
        positive_samples = list(filter(lambda ctx: ctx.has_answer, ctxs))
        negative_samples = list(filter(lambda ctx: not ctx.has_answer, ctxs))
    else:
        positive_samples = []
        negative_samples = ctxs

    positive_ctxs_from_gold_page = (
        list(
            filter(
                lambda ctx: _is_from_gold_wiki_page(gold_passage_map, ctx.title, question),
                positive_samples,
            )
        )
        if gold_page_only_positives and gold_passage_map
        else []
    )

    def find_answer_spans(ctx: ReaderPassage):
        # Recompute the answer spans of a passage flagged as positive and
        # clear the flag when no span is actually found.
        if ctx.has_answer:
            if ctx.passage_token_ids is None:
                ctx.passage_token_ids = tensorizer.text_to_tensor(ctx.passage_text, add_special_tokens=False)

            answer_spans = [
                _find_answer_positions(ctx.passage_token_ids, answers_token_ids[i]) for i in range(len(answers))
            ]

            # flatten spans list
            answer_spans = [item for sublist in answer_spans for item in sublist]
            answers_spans = list(filter(None, answer_spans))
            ctx.answers_spans = answers_spans

            if not answers_spans:
                logger.warning(
                    "No answer found in passage id=%s text=%s, answers=%s, question=%s",
                    ctx.id,
                    "",  # ctx.passage_text
                    answers,
                    question,
                )
            ctx.has_answer = bool(answers_spans)
        return ctx

    # check if any of the selected ctx+ has answer spans
    selected_positive_ctxs = list(
        filter(
            lambda ctx: ctx.has_answer,
            [find_answer_spans(ctx) for ctx in positive_ctxs_from_gold_page],
        )
    )

    if not selected_positive_ctxs:  # fallback to positive ctx not from gold pages
        selected_positive_ctxs = list(
            filter(
                lambda ctx: ctx.has_answer,
                [find_answer_spans(ctx) for ctx in positive_samples],
            )
        )[0:max_positives]

    # optionally include gold passage itself if it is still not in the positives list
    # (the truthiness check keeps a ``None`` map from raising TypeError on ``in``)
    if include_gold_passage and gold_passage_map and question in gold_passage_map:
        gold_passage = gold_passage_map[question]
        included_gold_passage = next(
            iter(ctx for ctx in selected_positive_ctxs if ctx.passage_text == gold_passage.passage_text),
            None,
        )
        if not included_gold_passage:
            gold_passage.has_answer = True
            gold_passage = find_answer_spans(gold_passage)
            if not gold_passage.has_answer:
                logger.warning("No answer found in gold passage: %s", gold_passage)
            else:
                selected_positive_ctxs.append(gold_passage)

    # Train: ten negatives per positive, at least max1_negatives, at most max2_negatives.
    max_negatives = (
        min(max(10 * len(selected_positive_ctxs), max1_negatives), max2_negatives)
        if is_train_set
        else DEFAULT_EVAL_PASSAGES
    )
    negative_samples = negative_samples[0:max_negatives]
    return selected_positive_ctxs, negative_samples
Example #17
0
def preprocess_retriever_data(
    samples: List[Dict],
    gold_info_file: Optional[str],
    tensorizer: Tensorizer,
    cfg: ReaderPreprocessingCfg = DEFAULT_PREPROCESSING_CFG_TRAIN,
    is_train_set: bool = True,
) -> Iterable[ReaderSample]:
    """
    Lazily turns retriever results into reader samples.

    For every input sample the passages are split into positives and negatives,
    each passage gets its concatenated question/title/passage ids, and one
    ReaderSample is yielded. Two summary counters are logged once the input is
    exhausted.
    :param samples: samples from the retriever's json file results
    :param gold_info_file: optional path for the 'gold passages & questions' file. Required to get best results for NQ
    :param tensorizer: Tensorizer object for text to model input tensors conversions
    :param cfg: ReaderPreprocessingCfg object with positive and negative passage selection parameters
    :param is_train_set: if the data should be processed as a train set
    :return: iterable of ReaderSample objects which can be consumed by the reader model
    """
    sep_tensor = tensorizer.get_pair_separator_ids()  # the separator may span several tokens

    if gold_info_file:
        gold_passage_map, canonical_questions = _get_gold_ctx_dict(gold_info_file)
    else:
        gold_passage_map, canonical_questions = {}, {}

    missing_positive_count = 0
    gold_positive_count = 0

    def _attach_sequence_ids(passage: ReaderPassage, question: str):
        # Build the full model input for one passage and shift its answer
        # spans by the offset at which the passage starts.
        question_and_title = tensorizer.text_to_tensor(passage.title, title=question, add_special_tokens=True)
        if passage.passage_token_ids is None:
            passage.passage_token_ids = tensorizer.text_to_tensor(passage.passage_text, add_special_tokens=False)

        tailing_sep = sep_tensor if cfg.use_tailing_sep else None
        all_concatenated, shift = _concat_pair(
            question_and_title,
            passage.passage_token_ids,
            tailing_sep=tailing_sep,
        )

        passage.sequence_ids = all_concatenated
        passage.passage_offset = shift
        assert shift > 1
        if passage.has_answer and is_train_set:
            shifted_spans = []
            for span in passage.answers_spans:
                shifted_spans.append((span[0] + shift, span[1] + shift))
            passage.answers_spans = shifted_spans
        return passage

    for sample in samples:
        question = sample["question"]

        question_txt = question
        if "query_text" in sample:
            question_txt = sample["query_text"]
        if canonical_questions and question_txt in canonical_questions:
            question_txt = canonical_questions[question_txt]

        positive_passages, negative_passages = _select_reader_passages(
            sample,
            question_txt,
            tensorizer,
            gold_passage_map,
            cfg.gold_page_only_positives,
            cfg.max_positives,
            cfg.max_negatives,
            cfg.min_negatives,
            cfg.max_retriever_passages,
            cfg.include_gold_passage,
            is_train_set,
        )

        # create concatenated sequence ids for each passage and adjust answer spans
        positive_passages = [_attach_sequence_ids(p, question) for p in positive_passages]
        negative_passages = [_attach_sequence_ids(p, question) for p in negative_passages]

        if is_train_set and len(positive_passages) == 0:
            missing_positive_count += 1
            if cfg.skip_no_positves:
                continue

        first_gold_positive = next(
            (ctx for ctx in positive_passages if ctx.score == -1), None
        )
        if first_gold_positive:
            gold_positive_count += 1

        if not is_train_set:
            yield ReaderSample(question, sample["answers"], passages=negative_passages)
        else:
            yield ReaderSample(
                question,
                sample["answers"],
                positive_passages=positive_passages,
                negative_passages=negative_passages,
            )

    logger.info("no positive passages samples: %d", missing_positive_count)
    logger.info("positive passages from gold samples: %d", gold_positive_count)
Example #18
0
def _create_question_passages_tensors(
    wiki_data: TokenizedWikipediaPassages,
    question_token_ids: np.ndarray,
    tensorizer: Tensorizer,
    positives: List[ReaderPassage],
    negatives: List[ReaderPassage],
    total_size: int,
    empty_ids: T,
    max_n_answers: int,
    is_train: bool,
    is_random: bool = True
):
    """
    Build the reader inputs for one question: passage IDs, token id rows and,
    when training, the answer span targets.

    In training mode a single positive is picked through `_get_positive_idx`
    and placed in row 0. The remaining rows are negatives, followed by clones
    of `empty_ids` (with passage ID -1) when there are not enough negatives.
    Passages that have no `sequence_ids` yet get their tokens loaded and
    concatenated here, and the result is stored on the passage object.

    :param wiki_data: source of tokenized data, queried by integer passage id
    :param question_token_ids: question token ids handed to `load_tokens`
    :param tensorizer: provides `concatenate_inputs`, `to_max_length`, `get_pad_id`
    :param positives: candidate positive passages (only read when `is_train`)
    :param negatives: candidate negative passages
    :param total_size: number of rows in the returned tensors
    :param empty_ids: filler row; its size along dim 0 is used as the max length
    :param max_n_answers: max number of answer spans kept for the positive
    :param is_train: whether to select a positive and build the answer tensors
    :param is_random: shuffle the negatives; also forwarded to `_get_positive_idx`
    :return: None when training and no positive could be selected, otherwise
        (context_IDs, input_ids, answer_starts_tensor, answer_ends_tensor,
        answer_mask); the last three are None when `is_train` is False
    """

    def _ensure_sequence_ids(passage):
        """Build and cache `sequence_ids` if missing; return the new passage offset, else None."""
        if getattr(passage, "sequence_ids", None) is not None:
            return None
        # Load in passage tokens and title tokens
        passage.load_tokens(
            question_token_ids=question_token_ids,
            **wiki_data.get_tokenized_data(int(passage.id))
        )
        # Concatenate input tokens
        concatenated, offset = tensorizer.concatenate_inputs({
            "question": passage.question_token_ids,
            "passage_title": passage.title_token_ids,
            "passage": passage.passage_token_ids,
        }, get_passage_offset=True)
        passage.sequence_ids = concatenated
        passage.passage_offset = offset
        return offset

    def _padded_row(passage):
        """Return the passage's sequence ids brought to max length, as a torch tensor."""
        row = tensorizer.to_max_length(passage.sequence_ids.numpy(), apply_max_len=True)
        return torch.from_numpy(row)

    max_len = empty_ids.size(0)
    pad_token_id = tensorizer.get_pad_id()  # not used below; call kept from the original

    if not is_train:
        positives_IDs: List[int] = []
        positives_selected = []
        answer_starts_tensor = None
        answer_ends_tensor = None
        answer_mask = None
    else:
        # select just one positive
        positive_idx = _get_positive_idx(positives, max_len, is_random)
        if positive_idx is None:
            return None

        positive = positives[positive_idx]

        new_offset = _ensure_sequence_ids(positive)
        if new_offset is not None:
            # Spans were relative to the passage; shift them into the concatenated sequence
            positive.answers_spans = [
                (start + new_offset, end + new_offset)
                for start, end in positive.answers_spans
            ]

        kept_spans = _get_answer_spans(positive_idx, positives, max_len)[0: max_n_answers]
        answer_starts = [span[0] for span in kept_spans]
        answer_ends = [span[1] for span in kept_spans]

        assert all(s < max_len for s in answer_starts)
        assert all(e < max_len for e in answer_ends)

        positive_input_ids = _padded_row(positive)

        # only the first passage (row 0) contains the answer
        answer_starts_tensor = torch.zeros((total_size, max_n_answers)).long()
        answer_starts_tensor[0, 0:len(answer_starts)] = torch.tensor(answer_starts)

        answer_ends_tensor = torch.zeros((total_size, max_n_answers)).long()
        answer_ends_tensor[0, 0:len(answer_ends)] = torch.tensor(answer_ends)

        answer_mask = torch.zeros((total_size, max_n_answers), dtype=torch.long)
        answer_mask[0, 0:len(answer_starts)] = torch.tensor([1] * len(answer_starts))

        positives_IDs: List[int] = [positive.id]
        positives_selected = [positive_input_ids]

    positives_num = len(positives_selected)
    negatives_budget = total_size - positives_num

    if is_random:
        negative_idxs = np.random.permutation(range(len(negatives)))
    else:
        negative_idxs = range(len(negatives) - positives_num)
    negative_idxs = negative_idxs[:negatives_budget]

    negatives_IDs: List[int] = []
    negatives_selected = []

    for negative_idx in negative_idxs:
        negative = negatives[negative_idx]
        _ensure_sequence_ids(negative)
        negatives_IDs.append(negative.id)
        negatives_selected.append(_padded_row(negative))

    # Fill the remaining rows with the empty placeholder (passage ID -1)
    missing = negatives_budget - len(negatives_selected)
    negatives_IDs.extend([-1] * missing)
    negatives_selected.extend(empty_ids.clone() for _ in range(missing))

    context_IDs = torch.tensor(positives_IDs + negatives_IDs, dtype=torch.int64)
    input_ids = torch.stack(positives_selected + negatives_selected, dim=0)
    assert len(context_IDs) == len(input_ids)

    return context_IDs, input_ids, answer_starts_tensor, answer_ends_tensor, answer_mask
Example #19
0
    def create_biencoder_input2(
        cls,
        samples: List[BiEncoderSample],
        tensorizer: Tensorizer,
        insert_title: bool,
        num_hard_negatives: int = 0,
        num_other_negatives: int = 0,
        shuffle: bool = True,
        shuffle_positives: bool = False,
        hard_neg_fallback: bool = True,
        query_token: str = None,
    ) -> BiEncoderBatch:
        """
        Creates a batch of the biencoder training tuple.

        The contexts of every sample are appended to one flat, batch-wide list
        in the order ``[positive] + other negatives + hard negatives``; the
        returned index lists refer to positions in that flat list.

        :param samples: list of BiEncoderSample-s to create the batch for
        :param tensorizer: components to create model input tensors from a text sequence
        :param insert_title: enables title insertion at the beginning of the context sequences
        :param num_hard_negatives: amount of hard negatives per question (taken from samples' pools)
        :param num_other_negatives: amount of other negatives per question (taken from samples' pools)
        :param shuffle: shuffles negative passages pools (in place, on the sample's own lists)
        :param shuffle_positives: picks a random positive instead of the first one (only when `shuffle` is set)
        :param hard_neg_fallback: when a sample has no hard negatives, use its first
            `num_hard_negatives` other negatives as hard negatives instead
        :param query_token: optional token; "[START_ENT]" selects a span through
            `_select_span_with_token`, any other value is prepended to the question text
        :return: BiEncoderBatch tuple
        """
        question_tensors = []
        ctx_tensors = []
        positive_ctx_indices = []
        hard_neg_ctx_indices = []

        for sample in samples:
            # ctx+ & [ctx-] composition
            if shuffle and shuffle_positives:
                positive_ctxs = sample.positive_passages
                positive_ctx = positive_ctxs[np.random.choice(
                    len(positive_ctxs))]
            else:
                # as of now, take the first(gold) ctx+ only
                positive_ctx = sample.positive_passages[0]

            neg_ctxs = sample.negative_passages
            hard_neg_ctxs = sample.hard_negative_passages
            question = sample.query

            if shuffle:
                random.shuffle(neg_ctxs)
                random.shuffle(hard_neg_ctxs)

            if hard_neg_fallback and not hard_neg_ctxs:
                hard_neg_ctxs = neg_ctxs[0:num_hard_negatives]

            neg_ctxs = neg_ctxs[0:num_other_negatives]
            hard_neg_ctxs = hard_neg_ctxs[0:num_hard_negatives]

            all_ctxs = [positive_ctx] + neg_ctxs + hard_neg_ctxs
            # Hard negatives come after the positive AND the other negatives,
            # so their first position is 1 + len(neg_ctxs), not 1.
            hard_negatives_start_idx = 1 + len(neg_ctxs)
            hard_negatives_end_idx = hard_negatives_start_idx + len(
                hard_neg_ctxs)

            current_ctxs_len = len(ctx_tensors)

            sample_ctxs_tensors = [
                tensorizer.text_to_tensor(
                    ctx.text,
                    title=ctx.title if (insert_title and ctx.title) else None)
                for ctx in all_ctxs
            ]

            ctx_tensors.extend(sample_ctxs_tensors)
            positive_ctx_indices.append(current_ctxs_len)
            hard_neg_ctx_indices.append(
                list(
                    range(
                        current_ctxs_len + hard_negatives_start_idx,
                        current_ctxs_len + hard_negatives_end_idx,
                    )))

            if query_token:
                # TODO: tmp workaround for EL, remove or revise
                if query_token == "[START_ENT]":
                    query_span = _select_span_with_token(question,
                                                         tensorizer,
                                                         token_str=query_token)
                    question_tensors.append(query_span)
                else:
                    question_tensors.append(
                        tensorizer.text_to_tensor(" ".join(
                            [query_token, question])))
            else:
                question_tensors.append(tensorizer.text_to_tensor(question))

        ctxs_tensor = torch.cat([ctx.view(1, -1) for ctx in ctx_tensors],
                                dim=0)
        questions_tensor = torch.cat([q.view(1, -1) for q in question_tensors],
                                     dim=0)

        # All segment ids are zero for both towers
        ctx_segments = torch.zeros_like(ctxs_tensor)
        question_segments = torch.zeros_like(questions_tensor)

        return BiEncoderBatch(
            questions_tensor,
            question_segments,
            ctxs_tensor,
            ctx_segments,
            positive_ctx_indices,
            hard_neg_ctx_indices,
            "question",
        )
Example #20
0
def _select_passages(
    wiki_data: TokenizedWikipediaPassages,
    sample: Dict,
    bm25_sample: Tuple[Tuple[int, float]],
    question: str,
    processed_question: str,
    question_token_ids: np.ndarray,
    answers: List[str],
    expanded_answers: List[List[str]],
    all_answers: List[str],
    tensorizer: Tensorizer,
    gold_passage_map: Dict[str, DataPassage],
    processed_gold_passage_map: Dict[str, DataPassage],
    cfg: PreprocessingCfg,
    is_train_set: bool,
    check_pre_tokenized_data: bool,
) -> Tuple[List[DataPassage], List[DataPassage], List[DataPassage],
           List[DataPassage], List[DataPassage], List[DataPassage],
           List[DataPassage]]:
    """
    Select and process valid passages for training/evaluation.

    Three pools are built for one question: the gold context (looked up in
    `processed_gold_passage_map`), the densely retrieved contexts from
    `sample["ctxs"]`, and the BM25 contexts from `bm25_sample`. Every context
    gets its tokens loaded through `_load_tokens_into_ctx` and its answer spans
    located through `_find_answer_spans`; the two retrieved pools are then
    split into positives and negatives by `has_answer`.

    :param wiki_data: passed through to `_load_tokens_into_ctx`
    :param sample: data item; only `sample["ctxs"]` is read here
    :param bm25_sample: (passage_id, score) pairs of the BM25 retrieved passages
    :param question: key for the gold map lookups; also used for the span
        search and in warnings
    :param processed_question: alternative key tried in `processed_gold_passage_map`
    :param question_token_ids: passed through to `_load_tokens_into_ctx`
    :param answers: only used in the warning about a gold passage without answer
    :param expanded_answers: only used in the warning about a gold passage
        without answer
    :param all_answers: answer strings that are tokenized and searched for
    :param tensorizer: used to tokenize answers and decode passage titles
    :param gold_passage_map: gold passages keyed by question; used for the gold
        page check and for `cfg.include_gold_passage`
    :param processed_gold_passage_map: gold passages used to build `gold_ctxs`
    :param cfg: the fields read here are `recheck_negatives`,
        `gold_page_only_positives`, `should_negatives_contain_answer`,
        `include_gold_passage`, `max_bm25_positives` and `max_bm25_negatives`
    :param is_train_set: enables the training-only filtering of positives
    :param check_pre_tokenized_data: passed through to `_load_tokens_into_ctx`
    :return: tuple of (gold_ctxs, selected_positive_ctxs, selected_negative_ctxs,
        distantly_positive_samples, selected_bm25_positive_ctxs,
        selected_bm25_negative_ctxs, bm25_distantly_positive_samples)
    """
    # Tokenize answers
    answers_token_ids: List[np.ndarray] = [
        tensorizer.text_to_tensor(a, add_special_tokens=False).numpy()
        for a in all_answers
    ]

    # Gold context; we want to cover more gold passages, that's why we are matching both
    # `processed_question` and `question` (canonical question).
    # `question` takes precedence when both keys are present.
    if question in processed_gold_passage_map or processed_question in processed_gold_passage_map:
        if question in processed_gold_passage_map:
            gold_ctx = processed_gold_passage_map[question]
        else:
            gold_ctx = processed_gold_passage_map[processed_question]

        gold_ctx = _load_tokens_into_ctx(
            gold_ctx,
            question_token_ids,
            wiki_data,
            tensorizer,
            check_pre_tokenized_data,
        )  # load question, passage title and passage tokens into the context object
        gold_ctx = _find_answer_spans(
            tensorizer,
            gold_ctx,
            question,
            all_answers,
            answers_token_ids,
            warn_if_no_answer=True,
            raise_if_no_answer=False,
            warn_if_has_answer=False,
            raise_if_has_answer=False,
            recheck_negatives=False,
        )  # find answer spans in the gold context
        # The gold context is only reported when an answer was found in it
        if gold_ctx.has_answer:
            gold_ctxs = [gold_ctx]
        else:
            gold_ctxs = []
    else:
        gold_ctxs = []

    # Densely retrieved contexts
    ctxs = [DataPassage(is_from_bm25=False, **ctx) for ctx in sample["ctxs"]]
    ctxs = [
        _load_tokens_into_ctx(ctx, question_token_ids, wiki_data, tensorizer,
                              check_pre_tokenized_data) for ctx in ctxs
    ]  # load question, passage title and passage tokens into the context object
    # Find answer spans for all passages
    ctxs: List[DataPassage] = [
        _find_answer_spans(
            tensorizer,
            ctx,
            question,
            all_answers,
            answers_token_ids,
            warn_if_no_answer=ctx.
            has_answer,  # warn if originally it contains answer string
            warn_if_has_answer=(
                not ctx.has_answer
            ),  # warn if originally it does NOT contain answer string
            recheck_negatives=cfg.recheck_negatives,
        ) for ctx in ctxs
    ]

    # Sparsely retrieved contexts (BM25)
    bm25_ctxs = [
        DataPassage(id=passage_id, score=score, is_from_bm25=True)
        for passage_id, score in bm25_sample
    ]
    bm25_ctxs = [
        _load_tokens_into_ctx(ctx, question_token_ids, wiki_data, tensorizer,
                              check_pre_tokenized_data) for ctx in bm25_ctxs
    ]  # load question, passage title and passage tokens into the context object
    # Find answer spans for all passages
    bm25_ctxs: List[DataPassage] = [
        _find_answer_spans(
            tensorizer,
            ctx,
            question,
            all_answers,
            answers_token_ids,
            warn_if_no_answer=False,
            warn_if_has_answer=False,
            recheck_negatives=True,  # `has_answer` of any BM25 passage is None
        ) for ctx in bm25_ctxs
    ]

    # Filter positives and negatives using distant supervision
    positive_samples = list(filter(lambda ctx: ctx.has_answer, ctxs))
    distantly_positive_samples: List[DataPassage] = []
    negative_samples = list(filter(lambda ctx: not ctx.has_answer, ctxs))
    bm25_positive_samples = list(filter(lambda ctx: ctx.has_answer, bm25_ctxs))
    bm25_distantly_positive_samples: List[DataPassage] = []
    bm25_negative_samples = list(
        filter(lambda ctx: not ctx.has_answer, bm25_ctxs))

    # Filter unwanted positive passages if training
    if is_train_set:

        # Get positives that are from gold positive passages
        if cfg.gold_page_only_positives:
            # NOTE: the negative lists below are the `*_negative_samples` lists
            # themselves (no copy), so appending to them also grows the originals.
            selected_positive_ctxs: List[DataPassage] = []
            selected_negative_ctxs: List[DataPassage] = negative_samples
            selected_bm25_positive_ctxs: List[DataPassage] = []
            selected_bm25_negative_ctxs: List[
                DataPassage] = bm25_negative_samples

            # Same routing for the dense pool and the BM25 pool
            for positives, selected_positives, selected_negatives, distantly_positives in [
                (positive_samples, selected_positive_ctxs,
                 selected_negative_ctxs, distantly_positive_samples),
                (bm25_positive_samples, selected_bm25_positive_ctxs,
                 selected_bm25_negative_ctxs, bm25_distantly_positive_samples)
            ]:

                for ctx in positives:
                    # The title is decoded back to text from its token ids for the check
                    is_from_gold = _is_from_gold_wiki_page(
                        gold_passage_map, ctx,
                        tensorizer.tensor_to_text(
                            torch.from_numpy(ctx.title_token_ids)), question)
                    if is_from_gold:
                        selected_positives.append(ctx)
                    else:  # if it has answer but does not come from gold passage
                        if cfg.should_negatives_contain_answer:
                            selected_negatives.append(ctx)
                        else:
                            distantly_positives.append(ctx)
        else:
            selected_positive_ctxs = positive_samples
            selected_negative_ctxs = negative_samples
            selected_bm25_positive_ctxs = bm25_positive_samples
            selected_bm25_negative_ctxs = bm25_negative_samples

        # Fallback to positive ctx not from gold passages
        # NOTE(review): when `cfg.should_negatives_contain_answer` is set, the
        # passages picked up here were also appended to the negatives above, so
        # they can end up in both lists -- confirm this is intended.
        if len(selected_positive_ctxs) == 0:
            selected_positive_ctxs = positive_samples
        if len(selected_bm25_positive_ctxs) == 0:
            selected_bm25_positive_ctxs = bm25_positive_samples

        # Optionally include the gold passage itself. It is appended without
        # checking whether it is already among the positives.
        if cfg.include_gold_passage:
            if question in gold_passage_map:
                # The object stored in `gold_passage_map` is modified in place
                gold_passage = gold_passage_map[question]
                gold_passage.is_gold = True
                gold_passage.has_answer = True  # assuming it has answer

                gold_passage = _find_answer_spans(
                    tensorizer,
                    gold_passage,
                    question,
                    all_answers,
                    answers_token_ids,
                    warn_if_no_answer=False,
                    warn_if_has_answer=False,
                    recheck_negatives=True,
                )  # warn below

                if not gold_passage.has_answer:
                    logger.warning(
                        "No answer found in GOLD passage: passage='%s', question='%s', answers=%s, expanded_answers=%s",
                        gold_passage.passage_text,
                        question,
                        answers,
                        expanded_answers,
                    )
                selected_positive_ctxs.append(
                    gold_passage
                )  # append anyway, since we need this for retriever (not reader)
            else:
                logger.warning(f"Question '{question}' has no gold positive")

    else:
        # NOTE: See `create_reader_input` function in `reader.py` to see how
        # positive and negative samples are merged (keeping their original order)
        selected_positive_ctxs = positive_samples
        selected_negative_ctxs = negative_samples
        selected_bm25_positive_ctxs = bm25_positive_samples
        selected_bm25_negative_ctxs = bm25_negative_samples

    # Restrict number of BM25 passages (the dense lists are not truncated here)
    selected_bm25_positive_ctxs = selected_bm25_positive_ctxs[:cfg.
                                                              max_bm25_positives]
    selected_bm25_negative_ctxs = selected_bm25_negative_ctxs[:cfg.
                                                              max_bm25_negatives]

    return (
        gold_ctxs,
        selected_positive_ctxs,
        selected_negative_ctxs,
        distantly_positive_samples,
        selected_bm25_positive_ctxs,
        selected_bm25_negative_ctxs,
        bm25_distantly_positive_samples,
    )
Example #21
0
    def create_biencoder_input(
        cls,
        samples: List,
        tensorizer: Tensorizer,
        insert_title: bool,
        num_hard_negatives: int = 0,
        num_other_negatives: int = 0,
        shuffle: bool = True,
        shuffle_positives: bool = False,
        max_retrys: int = 100,
    ) -> BiEncoderBatch:
        """
        Assemble one biencoder training batch from raw (json) data items.

        Per item the contexts are ordered as positive, other negatives, hard
        negatives and appended to one flat, batch-wide context list; the
        returned index lists point into that flat list.

        :param samples: list of data items (from json) to create the batch for
        :param tensorizer: components to create model input tensors from a text sequence
        :param insert_title: enables title insertion at the beginning of the context sequences
        :param num_hard_negatives: amount of hard negatives per question (taken from samples' pools)
        :param num_other_negatives: amount of other negatives per question (taken from samples' pools)
        :param shuffle: shuffles negative passages pools (in place)
        :param shuffle_positives: draws the positive at random (only when `shuffle` is set)
        :param max_retrys: max number of re-draws to find a positive whose text
            was not already used by an earlier item of this batch
        :return: BiEncoderBatch tuple
        """
        question_tensors = []
        ctx_tensors = []
        positive_ctx_indices = []
        hard_neg_ctx_indices = []
        seen_positive_texts = set()

        for item in samples:
            if shuffle and shuffle_positives:
                pool = item["positive_ctxs"]
                positive = pool[np.random.choice(len(pool))]
                # Re-draw while the text is a duplicate, giving up after max_retrys draws
                for _ in range(max_retrys):
                    if positive['text'] not in seen_positive_texts:
                        break
                    positive = pool[np.random.choice(len(pool))]
                seen_positive_texts.add(positive['text'])
            else:
                # without shuffling, take the first (gold) positive
                positive = item["positive_ctxs"][0]

            other_negs = item["negative_ctxs"]
            hard_negs = item["hard_negative_ctxs"]
            question = normalize_question(item["question"])

            if shuffle:
                random.shuffle(other_negs)
                random.shuffle(hard_negs)

            other_negs = other_negs[:num_other_negatives]
            hard_negs = hard_negs[:num_hard_negatives]

            # Position of this item's positive inside the flat context list
            base = len(ctx_tensors)
            # Hard negatives follow the positive and the other negatives
            first_hard = base + len(other_negs) + 1
            last_hard = first_hard + len(hard_negs)

            for ctx in [positive] + other_negs + hard_negs:
                ctx_tensors.append(
                    tensorizer.text_to_tensor(
                        ctx["text"],
                        title=ctx["title"] if insert_title else None))

            positive_ctx_indices.append(base)
            hard_neg_ctx_indices.append(list(range(first_hard, last_hard)))

            question_tensors.append(tensorizer.text_to_tensor(question))

        ctx_rows = [ctx.view(1, -1) for ctx in ctx_tensors]
        question_rows = [q.view(1, -1) for q in question_tensors]
        ctxs_tensor = torch.cat(ctx_rows, dim=0)
        questions_tensor = torch.cat(question_rows, dim=0)

        return BiEncoderBatch(
            questions_tensor,
            torch.zeros_like(questions_tensor),
            ctxs_tensor,
            torch.zeros_like(ctxs_tensor),
            positive_ctx_indices,
            hard_neg_ctx_indices,
        )
Example #22
0
    def create_graded_biencoder_input2(
        cls,
        samples: List[GradedBiEncoderSample],
        tensorizer: Tensorizer,
        insert_title: bool,
        num_hard_negatives: int = 0,
        num_other_negatives: int = 0,
        num_related: int = 0,
        num_highly_related: int = 0,
        shuffle: bool = True,
        shuffle_positives: bool = False,
        hard_neg_fallback: bool = True,
        query_token: str = None,
        relation_grades: list = [1.0, 1.0, 1.0, 0.0, 0.0],
    ) -> GradedBiEncoderBatch:
        """
        Creates a batch of the graded biencoder training tuple.

        The contexts of every sample are appended to one flat, batch-wide list
        in the order ``[positive] + negatives + hard negatives + related +
        highly related``. All returned index lists hold absolute positions in
        that flat list. `relations` holds one row of grades per question,
        covering every context of the batch; contexts that belong to other
        questions are graded as negatives.

        :param samples: list of GradedBiEncoderSample-s to create the batch for
        :param tensorizer: components to create model input tensors from a text sequence
        :param insert_title: enables title insertion at the beginning of the context sequences
        :param num_hard_negatives: amount of hard negatives per question (taken from samples' pools)
        :param num_other_negatives: amount of other negatives per question (taken from samples' pools)
        :param num_related: amount of related passages per question
        :param num_highly_related: amount of highly related passages per question
        :param shuffle: shuffles the negative and related passages pools (in place)
        :param shuffle_positives: picks a random positive instead of the first one (only when `shuffle` is set)
        :param hard_neg_fallback: when a sample has no hard negatives, use its first
            `num_hard_negatives` other negatives as hard negatives instead
        :param query_token: optional token; "[START_ENT]" selects a span through
            `_select_span_with_token`, any other value is prepended to the question text
        :param relation_grades: five grades in the order (positive, highly related,
            related, negative, hard negative); only read, never modified
        :return: GradedBiEncoderBatch tuple
        """
        question_tensors = []
        ctx_tensors = []
        positive_ctx_indices = []
        hard_neg_ctx_indices = []
        negatives_ctx_indices = []
        related_ctx_indices = []
        highly_related_ctx_indices = []
        relations = []

        # The grades do not depend on the sample, so unpack them once
        (rel_positive, rel_highly_related, rel_related, rel_negative,
         rel_hard_negative) = relation_grades

        for sample in samples:
            # ctx+ & [ctx-] composition
            if shuffle and shuffle_positives:
                positive_ctxs = sample.positive_passages
                positive_ctx = positive_ctxs[np.random.choice(
                    len(positive_ctxs))]
            else:
                # as of now, take the first(gold) ctx+ only
                positive_ctx = sample.positive_passages[0]

            neg_ctxs = sample.negative_passages
            hard_neg_ctxs = sample.hard_negative_passages
            related_ctxs = sample.related_passage
            highly_related_ctxs = sample.highly_related_passage
            question = sample.query

            if shuffle:
                random.shuffle(neg_ctxs)
                random.shuffle(hard_neg_ctxs)
                random.shuffle(related_ctxs)
                random.shuffle(highly_related_ctxs)

            if hard_neg_fallback and not hard_neg_ctxs:
                hard_neg_ctxs = neg_ctxs[0:num_hard_negatives]

            neg_ctxs = neg_ctxs[0:num_other_negatives]
            hard_neg_ctxs = hard_neg_ctxs[0:num_hard_negatives]
            related_ctxs = related_ctxs[0:num_related]
            highly_related_ctxs = highly_related_ctxs[0:num_highly_related]

            all_ctxs = [
                positive_ctx
            ] + neg_ctxs + hard_neg_ctxs + related_ctxs + highly_related_ctxs

            # relations: pre-pad with negatives for the contexts of all
            # earlier questions, then grade this question's own contexts
            question_relations = []
            if relations:
                question_relations = [rel_negative] * len(relations[-1])
            question_relations.append(rel_positive)
            question_relations.extend([rel_negative] * len(neg_ctxs))
            question_relations.extend([rel_hard_negative] * len(hard_neg_ctxs))
            question_relations.extend([rel_related] * len(related_ctxs))
            question_relations.extend([rel_highly_related] *
                                      len(highly_related_ctxs))

            # post-pad the rows of earlier questions with negatives so that
            # every row covers the contexts added by this question as well
            for relation in relations:
                relation.extend([rel_negative] *
                                (len(question_relations) - len(relation)))
            relations.append(question_relations)

            # calculate all positions; `current_ctxs_len` is added exactly
            # once (at the start), every later group simply begins where the
            # previous one ended
            current_ctxs_len = len(ctx_tensors)
            positive_ctx_indices.append(current_ctxs_len)

            negatives_start_idx = current_ctxs_len + 1
            negatives_end_idx = negatives_start_idx + len(neg_ctxs)
            negatives_ctx_indices.append(
                list(range(negatives_start_idx, negatives_end_idx)))

            hard_negatives_start_idx = negatives_end_idx
            hard_negatives_end_idx = hard_negatives_start_idx + len(
                hard_neg_ctxs)
            hard_neg_ctx_indices.append(
                list(range(hard_negatives_start_idx, hard_negatives_end_idx)))

            related_start_idx = hard_negatives_end_idx
            related_end_idx = related_start_idx + len(related_ctxs)
            related_ctx_indices.append(
                list(range(related_start_idx, related_end_idx)))

            highly_related_start_idx = related_end_idx
            highly_related_end_idx = highly_related_start_idx + len(
                highly_related_ctxs)
            highly_related_ctx_indices.append(
                list(range(highly_related_start_idx, highly_related_end_idx)))

            # add all ctxs to ctx_tensors
            sample_ctxs_tensors = [
                tensorizer.text_to_tensor(
                    ctx.text,
                    title=ctx.title if (insert_title and ctx.title) else None)
                for ctx in all_ctxs
            ]

            ctx_tensors.extend(sample_ctxs_tensors)

            if query_token:
                # TODO: tmp workaround for EL, remove or revise
                if query_token == "[START_ENT]":
                    query_span = _select_span_with_token(question,
                                                         tensorizer,
                                                         token_str=query_token)
                    question_tensors.append(query_span)
                else:
                    question_tensors.append(
                        tensorizer.text_to_tensor(" ".join(
                            [query_token, question])))
            else:
                question_tensors.append(tensorizer.text_to_tensor(question))

        ctxs_tensor = torch.cat([ctx.view(1, -1) for ctx in ctx_tensors],
                                dim=0)
        questions_tensor = torch.cat([q.view(1, -1) for q in question_tensors],
                                     dim=0)

        # All segment ids are zero for both towers
        ctx_segments = torch.zeros_like(ctxs_tensor)
        question_segments = torch.zeros_like(questions_tensor)

        return GradedBiEncoderBatch(
            questions_tensor,
            question_segments,
            ctxs_tensor,
            ctx_segments,
            positive_ctx_indices,
            hard_neg_ctx_indices,
            negatives_ctx_indices,
            related_ctx_indices,
            highly_related_ctx_indices,
            relations,
            "question",
        )
Example #23
0
    def create_biencoder_input_tokenized(
        cls,
        samples: List[BiEncoderSampleTokenized],
        tensorizer: Tensorizer,
        insert_title: bool,
        num_hard_negatives: int = 0,
        num_bm25_negatives: int = 0,
        shuffle: bool = True,
        shuffle_positives: bool = False,
        hard_neg_fallback: bool = True,
        query_token: str = None,
    ) -> BiEncoderBatch:
        """
        Build a biencoder training batch from already tokenized samples.

        Per sample the contexts are ordered as positive, BM25 negatives, hard
        negatives and appended to one flat, batch-wide list; the returned index
        lists and the passage ID tensor follow that flat order. Samples without
        any positive passage are skipped.

        :param samples: list of BiEncoderSampleTokenized-s to create the batch for
        :param tensorizer: components to create model input tensors from token ids
        :param insert_title: must be True (asserted); titles are always part of the context input
        :param num_hard_negatives: amount of hard negatives (densely retrieved) per question
        :param num_bm25_negatives: amount of BM25 negatives (sparsely retrieved) per question
        :param shuffle: shuffles negative passages pools (in place)
        :param shuffle_positives: shuffles positive passages pools. This is only effective for samples whose gold
            passage is available. In that case, the positive chosen is not necessarily the gold passage. Otherwise,
            the positive passages will be shuffled regardless of this parameter.
        :param hard_neg_fallback: when a sample has no hard negatives, use its first
            `num_hard_negatives` BM25 negatives as hard negatives instead
        :param query_token: must be None (asserted)
        :return: BiEncoderBatch tuple
        """
        # Strict settings
        assert insert_title is True  # for now only allow `insert_title` to be True
        assert query_token is None

        question_tensors: List[T] = []
        passage_ids: List[int] = []
        ctx_tensors: List[T] = []
        positive_ctx_indices = []
        hard_neg_ctx_indices = []

        for sample in samples:
            positives = sample.positive_passages
            # Skip samples without positive passages (either gold or distant positives)
            if len(positives) == 0:
                continue

            # Keep the first positive only when it is gold and positives are
            # not being shuffled; otherwise draw one at random.
            if not (shuffle and shuffle_positives) and positives[0].is_gold:
                positive_ctx = positives[0]
            else:
                positive_ctx = positives[np.random.choice(len(positives))]

            bm25_negs = sample.bm25_negative_passages
            hard_negs = sample.hard_negative_passages
            question_ids = sample.query_ids

            if shuffle:
                random.shuffle(bm25_negs)
                random.shuffle(hard_negs)

            if hard_neg_fallback and len(hard_negs) == 0:
                hard_negs = bm25_negs[:num_hard_negatives]

            bm25_negs = bm25_negs[:num_bm25_negatives]
            hard_negs = hard_negs[:num_hard_negatives]

            sample_ctxs = [positive_ctx] + bm25_negs + hard_negs
            # Position of this sample's positive inside the flat context list
            base = len(ctx_tensors)

            # Context IDs
            passage_ids.extend([ctx.id for ctx in sample_ctxs])

            # Context tensors: title + passage ids, brought to max length
            for ctx in sample_ctxs:
                ctx_tensors.append(
                    tensorizer.concatenate_inputs(
                        ids={
                            "passage_title": list(ctx.title_ids),
                            "passage": list(ctx.text_ids)
                        },
                        get_passage_offset=False,
                        to_max_length=True,
                    ))

            positive_ctx_indices.append(base)
            # Hard negatives occupy the tail of this sample's contexts
            hard_neg_ctx_indices.append(
                list(range(base + 1 + len(bm25_negs),
                           base + len(sample_ctxs))))

            question_tensors.append(
                tensorizer.concatenate_inputs(ids={"question": question_ids},
                                              get_passage_offset=False,
                                              to_max_length=True))

        ctx_ids_tensor = torch.tensor(passage_ids, dtype=torch.int64)
        ctx_rows = [ctx.view(1, -1) for ctx in ctx_tensors]
        question_rows = [q.view(1, -1) for q in question_tensors]
        ctxs_tensor = torch.cat(ctx_rows, dim=0)
        questions_tensor = torch.cat(question_rows, dim=0)

        return BiEncoderBatch(
            questions_tensor,
            torch.zeros_like(questions_tensor),
            ctx_ids_tensor,
            ctxs_tensor,
            torch.zeros_like(ctxs_tensor),
            positive_ctx_indices,
            hard_neg_ctx_indices,
            "question",
        )
Example #24
0
def _do_biencoder_fwd_pass(
    model: nn.Module,
    input: BiEncoderBatch,
    tensorizer: Tensorizer,
    loss_function,
    cfg,
    encoder_type: str,
    rep_positions_q=0,
    rep_positions_c=0,
    loss_scale: float = None,
    clustering: bool = False,
) -> Tuple[torch.Tensor, int]:
    """
    Runs a single bi-encoder forward pass over a batch and computes its loss.

    The model is invoked as-is when `model.training` is set and under `torch.no_grad()` otherwise.

    :param model: bi-encoder model; called with question inputs, context inputs and the keyword arguments
        `encoder_type`, `representation_token_pos_q` and `representation_token_pos_c`
    :param input: batch to process; rebuilt from the result of `move_to_device(..., cfg.device)` first
    :param tensorizer: used only to derive the attention masks of the question and context ids
    :param loss_function: forwarded to `_calc_loss_matching` / `calc_loss`
    :param cfg: config object; `device`, `n_gpu`, `others.is_matching` and
        `train.gradient_accumulation_steps` are read
    :param encoder_type: forwarded to the model
    :param rep_positions_q: forwarded to the model as `representation_token_pos_q`
    :param rep_positions_c: forwarded to the model as `representation_token_pos_c`
    :param loss_scale: forwarded to the loss computation
    :param clustering: if set, the raw model output is returned as well (not supported for matching models)
    :return: the result shape depends on the mode (the annotation only describes the default one):
        `(loss, is_correct)` by default,
        `(loss, ml_is_correct, matching_is_correct)` if `cfg.others.is_matching`,
        `(loss, is_correct, model_out)` if `clustering`
    """
    input = BiEncoderBatch(**move_to_device(input._asdict(), cfg.device))

    q_attn_mask = tensorizer.get_attn_mask(input.question_ids)
    ctx_attn_mask = tensorizer.get_attn_mask(input.context_ids)

    def _forward():
        # Single definition of the model call, shared by the train and the eval path
        return model(
            input.question_ids,
            input.question_segments,
            q_attn_mask,
            input.context_ids,
            input.ctx_segments,
            ctx_attn_mask,
            encoder_type=encoder_type,
            representation_token_pos_q=rep_positions_q,
            representation_token_pos_c=rep_positions_c,
        )

    if model.training:
        model_out = _forward()
    else:
        with torch.no_grad():
            model_out = _forward()

    local_q_vector, local_ctx_vectors = model_out

    if cfg.others.is_matching:  # MatchBiEncoder model
        loss, ml_is_correct, matching_is_correct = _calc_loss_matching(
            cfg,
            model,
            loss_function,
            local_q_vector,
            local_ctx_vectors,
            input.is_positive,
            input.hard_negatives,
            loss_scale=loss_scale,
        )
        ml_is_correct = ml_is_correct.sum().item()
        matching_is_correct = matching_is_correct.sum().item()
    else:
        loss, is_correct = calc_loss(
            cfg,
            loss_function,
            local_q_vector,
            local_ctx_vectors,
            input.is_positive,
            input.hard_negatives,
            loss_scale=loss_scale,
        )
        is_correct = is_correct.sum().item()

    if cfg.n_gpu > 1:
        loss = loss.mean()

    # The divisor must be read from the same config node as the guard (`cfg.train`)
    accumulation_steps = cfg.train.gradient_accumulation_steps
    if accumulation_steps > 1:
        loss = loss / accumulation_steps

    if clustering:
        assert not cfg.others.is_matching
        return loss, is_correct, model_out
    elif cfg.others.is_matching:
        return loss, ml_is_correct, matching_is_correct
    else:
        return loss, is_correct
Example #25
0
def gen_ctx_vectors(
    cfg: DictConfig,
    ctx_rows: List[Tuple[object, BiEncoderPassage]],
    q_rows: List[object],
    model: nn.Module,
    tensorizer: Tensorizer,
    insert_title: bool = True,
) -> List[Tuple[object, np.array]]:
    """
    Encodes passages and questions batch by batch with a joint bi-encoder model.

    `ctx_rows` and `q_rows` are sliced with the same offsets, so row i of one is paired with row i of the other.

    :param cfg: config; `batch_size` and `device` are read
    :param ctx_rows: (id, passage) pairs; the passage's `text` and (optionally) `title` get tensorized
    :param q_rows: question inputs, handed to `tensorizer.text_to_tensor` one by one
    :param model: called with question ids/segments/mask, context ids/segments/mask and
        `representation_token_pos`
    :param tensorizer: converts text to id tensors and derives attention masks
    :param insert_title: whether the passage title is passed along to the tensorizer
    :return: one tuple per row: (context id, question vector, context vector, their dot product).
        NOTE(review): the annotation advertises 2-tuples, the code produces 4-tuples.
    """
    batch_size = cfg.batch_size
    encoded = []
    num_encoded = 0

    def _as_model_inputs(token_tensors):
        # Stacks per-row id tensors and derives segments + attention mask, each sent to the target device
        ids = move_to_device(torch.stack(token_tensors, dim=0), cfg.device)
        segments = move_to_device(torch.zeros_like(ids), cfg.device)
        attn_mask = move_to_device(tensorizer.get_attn_mask(ids), cfg.device)
        return ids, segments, attn_mask

    for offset in range(0, len(ctx_rows), batch_size):
        # Passage preprocess # TODO; max seq length check
        ctx_batch = ctx_rows[offset:offset + batch_size]
        ctx_token_tensors = []
        for row in ctx_batch:
            passage = row[1]
            title = passage.title if insert_title else None
            ctx_token_tensors.append(
                tensorizer.text_to_tensor(passage.text, title=title))
        ctx_ids_batch, ctx_seg_batch, ctx_attn_mask = _as_model_inputs(
            ctx_token_tensors)

        # Question preprocess
        question_batch = q_rows[offset:offset + batch_size]
        q_ids_batch, q_seg_batch, q_attn_mask = _as_model_inputs(
            [tensorizer.text_to_tensor(item) for item in question_batch])

        # Selector
        from dpr.data.biencoder_data import DEFAULT_SELECTOR
        rep_positions = DEFAULT_SELECTOR.get_positions(q_ids_batch, tensorizer)

        with torch.no_grad():
            q_dense, ctx_dense = model(
                q_ids_batch,
                q_seg_batch,
                q_attn_mask,
                ctx_ids_batch,
                ctx_seg_batch,
                ctx_attn_mask,
                representation_token_pos=rep_positions,
            )
        q_dense = q_dense.cpu()
        ctx_dense = ctx_dense.cpu()
        ctx_ids = [row[0] for row in ctx_batch]

        assert len(ctx_ids) == q_dense.size(0) == ctx_dense.size(0)
        num_encoded += len(ctx_ids)

        for row_idx in range(q_dense.size(0)):
            q_vec = q_dense[row_idx].numpy()
            ctx_vec = ctx_dense[row_idx].numpy()
            encoded.append(
                (ctx_ids[row_idx], q_vec, ctx_vec, q_vec.dot(ctx_vec)))

        if num_encoded % 10 == 0:
            logger.info("Encoded questions / passages %d", num_encoded)
    return encoded
Example #26
0
def create_reader_input(
    wiki_data: TokenizedWikipediaPassages,
    tensorizer: Tensorizer,
    samples: List[ReaderSample],
    passages_per_question: int,
    max_length: int,
    max_n_answers: int,
    is_train: bool,
    shuffle: bool,
) -> ReaderBatch:
    """
    Builds a ReaderBatch from a list of ReaderSample-s. This is compatible with `GeneralDataset`.

    Samples for which `_create_question_passages_tensors` yields nothing are logged and left out.

    :param wiki_data: all tokenized wikipedia passages
    :param tensorizer: initialized tensorizer (which contains the tokenizer)
    :param samples: list of samples to create the batch for
    :param passages_per_question: amount of passages for every question in a batch
    :param max_length: max model input sequence length
    :param max_n_answers: max num of answers per single question
    :param is_train: if the samples are for a train set; only then are start/end positions and answer masks
        collected and stacked (otherwise they stay empty lists)
    :param shuffle: should passages selection be randomized
    :return: ReaderBatch instance
    """
    # All-padding sequence of the model input length, handed to the per-sample tensor builder
    padding_sequence = torch.Tensor().new_full((max_length,), tensorizer.get_pad_id(), dtype=torch.long)

    batch_context_ids = []
    batch_input_ids = []
    batch_starts = []
    batch_ends = []
    batch_answer_masks = []

    for sample in samples:
        if not is_train:
            # Evaluation: no passage is treated as positive; everything is ranked by score instead
            positives = []
            negatives = sorted(
                sample.positive_passages + sample.negative_passages,
                key=lambda x: x.score,
                reverse=True,
            )
        else:
            positives = sample.positive_passages
            negatives = sample.negative_passages

        sample_tensors = _create_question_passages_tensors(
            wiki_data,
            sample.question_token_ids,
            tensorizer,
            positives,
            negatives,
            passages_per_question,
            padding_sequence,
            max_n_answers,
            is_train,
            is_random=shuffle
        )

        if not sample_tensors:
            logger.warning('No valid passages combination for question=%s ', sample.question)
            continue

        ctx_ids, token_ids, starts, ends, answer_mask = sample_tensors
        batch_context_ids.append(ctx_ids)
        batch_input_ids.append(token_ids)

        if is_train:
            batch_starts.append(starts)
            batch_ends.append(ends)
            batch_answer_masks.append(answer_mask)

    # Stacking along a new leading axis gives one row per kept sample: (N, M)
    context_IDs = torch.stack(batch_context_ids, dim=0)
    input_ids = torch.stack(batch_input_ids, dim=0)

    if is_train:
        batch_starts = torch.stack(batch_starts, dim=0)
        batch_ends = torch.stack(batch_ends, dim=0)
        batch_answer_masks = torch.stack(batch_answer_masks, dim=0)

    return ReaderBatch(context_IDs, input_ids, batch_starts, batch_ends, batch_answer_masks)
Example #27
0
def generate_question_vectors(
    question_encoder: torch.nn.Module,
    tensorizer: Tensorizer,
    questions: List[str],
    bsz: int,
    query_token: str = None,
    selector: RepTokenSelector = None,
) -> T:
    """
    Encodes questions batch by batch and returns all question vectors as one tensor.

    Inputs are moved to the GPU via `.cuda()`, so a CUDA device is required. The whole computation runs
    under `torch.no_grad()`.

    :param question_encoder: question encoder model
    :param tensorizer: converts text to id tensors and derives the attention mask
    :param questions: question texts, or already prepared tensors (used as they are)
    :param bsz: batch size
    :param query_token: optional token; "[START_ENT]" routes every question through
        `_select_span_with_token`, any other value is prepended to the question text
    :param selector: optional representation token selector; if given, the encoder is invoked through
        `BiEncoder.get_representation` with the positions it selects
    :return: tensor holding one row per input question
    """
    n = len(questions)
    query_vectors = []

    with torch.no_grad():
        for batch_start in range(0, n, bsz):
            batch_questions = questions[batch_start:batch_start + bsz]

            if query_token:
                # TODO: tmp workaround for EL, remove or revise
                if query_token == "[START_ENT]":
                    batch_tensors = [
                        _select_span_with_token(q,
                                                tensorizer,
                                                token_str=query_token)
                        for q in batch_questions
                    ]
                else:
                    batch_tensors = [
                        tensorizer.text_to_tensor(" ".join([query_token, q]))
                        for q in batch_questions
                    ]
            elif isinstance(batch_questions[0], T):
                batch_tensors = list(batch_questions)
            else:
                batch_tensors = [
                    tensorizer.text_to_tensor(q) for q in batch_questions
                ]

            # Length equalization reads axis 1, which only exists on multi-dimensional inputs.
            # One-dimensional id tensors skip this step instead of raising on `size(1)`.
            # NOTE(review): presumably the multi-dimensional case is the Wav2vec pipeline - confirm.
            if all(q_t.dim() > 1 for q_t in batch_tensors):
                vector_lens = [q_t.size(1) for q_t in batch_tensors]
                max_vector_len = max(vector_lens)

                if max_vector_len != min(vector_lens):
                    # TODO: _pad_to_len move to utils
                    from dpr.models.reader import _pad_to_len
                    batch_tensors = [
                        _pad_to_len(q.squeeze(0), 0, max_vector_len)
                        for q in batch_tensors
                    ]

            q_ids_batch = torch.stack(batch_tensors, dim=0).cuda()
            q_seg_batch = torch.zeros_like(q_ids_batch).cuda()
            q_attn_mask = tensorizer.get_attn_mask(q_ids_batch)

            if selector:
                rep_positions = selector.get_positions(q_ids_batch, tensorizer)

                _, out, _ = BiEncoder.get_representation(
                    question_encoder,
                    q_ids_batch,
                    q_seg_batch,
                    q_attn_mask,
                    representation_token_pos=rep_positions,
                )
            else:
                _, out, _ = question_encoder(q_ids_batch, q_seg_batch,
                                             q_attn_mask)

            # Keep one (1, dim) slice per question so that they can be concatenated at the end
            query_vectors.extend(out.cpu().split(1, dim=0))

            if len(query_vectors) % 100 == 0:
                logger.info("Encoded queries %d", len(query_vectors))

    query_tensor = torch.cat(query_vectors, dim=0)
    logger.info("Total encoded queries tensor %s", query_tensor.size())
    assert query_tensor.size(0) == len(questions)
    return query_tensor
Example #28
0
    def create_biencoder_input(
        cls,
        samples: List,
        tensorizer: Tensorizer,
        insert_title: bool,
        num_hard_negatives: int = 0,
        num_other_negatives: int = 0,
        shuffle: bool = True,
        shuffle_positives: bool = False,
        hard_neg_fallback: bool = True,
    ) -> BiEncoderBatch:
        """
        Creates a batch of the biencoder training tuple.

        For every sample the contexts are laid out as [positive] + other negatives + hard negatives.
        The recorded hard negative indices point at the hard negative slots of that layout.

        :param samples: list of data items (from json) to create the batch for; the keys "question",
            "positive_ctxs", "negative_ctxs" and "hard_negative_ctxs" are read.
            NOTE(review): assumes "question" holds the raw question text - confirm against the data schema
        :param tensorizer: components to create model input tensors from a text sequence
        :param insert_title: enables title insertion at the beginning of the context sequences
        :param num_hard_negatives: amount of hard negatives per question (taken from samples' pools)
        :param num_other_negatives: amount of other negatives per question (taken from samples' pools)
        :param shuffle: shuffles negative passages pools (copies are shuffled, the samples stay untouched)
        :param shuffle_positives: shuffles positive passages pools
        :param hard_neg_fallback: if a sample has no hard negatives, the first `num_hard_negatives`
            regular negatives are used as hard negatives instead
        :return: BiEncoderBatch tuple
        """
        question_tensors = []
        ctx_tensors = []
        positive_ctx_indices = []
        hard_neg_ctx_indices = []

        for sample in samples:
            question = sample["question"]

            # ctx+ & [ctx-] composition
            # as of now, take the first(gold) ctx+ only
            if shuffle and shuffle_positives:
                positive_ctxs = sample["positive_ctxs"]
                positive_ctx = positive_ctxs[np.random.choice(
                    len(positive_ctxs))]
            else:
                positive_ctx = sample["positive_ctxs"][0]

            neg_ctxs = sample["negative_ctxs"]
            hard_neg_ctxs = sample["hard_negative_ctxs"]

            if shuffle:
                # Shuffle copies so that the caller's sample lists are not reordered as a side effect
                neg_ctxs = list(neg_ctxs)
                hard_neg_ctxs = list(hard_neg_ctxs)
                random.shuffle(neg_ctxs)
                random.shuffle(hard_neg_ctxs)

            if hard_neg_fallback and len(hard_neg_ctxs) == 0:
                hard_neg_ctxs = neg_ctxs[0:num_hard_negatives]

            neg_ctxs = neg_ctxs[0:num_other_negatives]
            hard_neg_ctxs = hard_neg_ctxs[0:num_hard_negatives]

            all_ctxs = [positive_ctx] + neg_ctxs + hard_neg_ctxs
            # Hard negatives come after the positive and the other negatives
            hard_negatives_start_idx = 1 + len(neg_ctxs)
            hard_negatives_end_idx = len(all_ctxs)

            current_ctxs_len = len(ctx_tensors)

            sample_ctxs_tensors = [
                tensorizer.text_to_tensor(
                    ctx["text"],
                    title=ctx["title"] if
                    (insert_title and "title" in ctx) else None,
                ) for ctx in all_ctxs
            ]

            ctx_tensors.extend(sample_ctxs_tensors)
            positive_ctx_indices.append(current_ctxs_len)
            hard_neg_ctx_indices.append(
                list(
                    range(
                        current_ctxs_len + hard_negatives_start_idx,
                        current_ctxs_len + hard_negatives_end_idx,
                    )))

            question_tensors.append(tensorizer.text_to_tensor(question))

        ctxs_tensor = torch.cat([ctx.view(1, -1) for ctx in ctx_tensors],
                                dim=0)
        questions_tensor = torch.cat([q.view(1, -1) for q in question_tensors],
                                     dim=0)

        ctx_segments = torch.zeros_like(ctxs_tensor)
        question_segments = torch.zeros_like(questions_tensor)

        return BiEncoderBatch(
            questions_tensor,
            question_segments,
            ctxs_tensor,
            ctx_segments,
            positive_ctx_indices,
            hard_neg_ctx_indices,
            "question",
        )
Example #29
0
def _preprocess_retriever_data(
    samples: List[Dict],
    bm25_samples: List[Tuple[Tuple[int, float]]],
    wiki_data: TokenizedWikipediaPassages,
    gold_info_file: Optional[str],
    gold_info_processed_file: str,
    tensorizer: Tensorizer,
    cfg: PreprocessingCfg = DEFAULT_PREPROCESSING_CFG_TRAIN,
    is_train_set: bool = True,
    check_pre_tokenized_data: bool = True,
) -> Iterable[DataSample]:
    """
    Turns retriever results into general retriever/reader training data, one DataSample at a time.

    This is a generator: nothing runs before the first item is requested, and the summary statistics
    are logged only once it has been fully consumed.

    :param samples: samples from the retriever's json file results
    :param bm25_samples: bm25 retrieval results; list of tuples of tuples of (passage_id, score), where passages of each
        sample are already sorted by their scores. Must be as long as `samples`
    :param wiki_data: tokenized wikipedia passages, handed to `_select_passages`
    :param gold_info_file: optional path for the 'gold passages & questions' file. Required to get best results for NQ
    :param gold_info_processed_file: path to the preprocessed gold passages pickle file. Unlike `gold_info_file` which
        contains original gold passages, this file should contain processed, matched, 100-word split passages that match
        with the original gold passages. An empty value disables it
    :param tensorizer: Tensorizer object for text to model input tensors conversions
    :param cfg: PreprocessingCfg object with positive and negative passage selection parameters
    :param is_train_set: if the data should be processed as a train set
    :param check_pre_tokenized_data: handed to `_select_passages`
    :return: iterable of DataSample objects which can be consumed by the reader model
    """
    if gold_info_file is not None:
        gold_passage_map, canonical_questions = _get_gold_ctx_dict(
            gold_info_file)
    else:
        gold_passage_map, canonical_questions = {}, {}

    if gold_info_processed_file:
        processed_gold_passage_map = _get_processed_gold_ctx_dict(
            gold_info_processed_file)
    else:
        processed_gold_passage_map = {}

    number_no_positive_samples = 0
    number_samples_from_gold = 0
    number_samples_with_gold = 0
    assert len(samples) == len(bm25_samples)

    for sample, bm25_sample in zip(samples, bm25_samples):
        # Refer to `_get_gold_ctx_dict` for why two kinds of questions are told apart:
        # `processed_question` is the tokenized question, `question` its canonical form (if one is known).
        processed_question = sample["question"]
        question = (canonical_questions[processed_question]
                    if processed_question in canonical_questions else
                    processed_question)

        question_text = (normalize_question(question)
                         if cfg.normalize_questions else question)
        question_token_ids: np.ndarray = tensorizer.text_to_tensor(
            question_text,
            add_special_tokens=False,
        ).numpy()

        orig_answers = sample["answers"]
        expanded_answers = ([
            get_expanded_answer(answer) for answer in orig_answers
        ] if cfg.expand_answers else [])
        # Flatten the per-answer expansions and append them to the original answers
        all_answers = orig_answers + sum(expanded_answers, [])

        (
            gold_passages,
            positive_passages,
            negative_passages,
            distantly_positive_passages,
            bm25_positive_passages,
            bm25_negative_passages,
            bm25_distantly_positive_passages,
        ) = _select_passages(
            wiki_data,
            sample,
            bm25_sample,
            question,
            processed_question,
            question_token_ids,
            orig_answers,
            expanded_answers,
            all_answers,
            tensorizer,
            gold_passage_map,
            processed_gold_passage_map,
            cfg,
            is_train_set,
            check_pre_tokenized_data,
        )

        if is_train_set and len(positive_passages) == 0:
            number_no_positive_samples += 1
            if cfg.skip_no_positives:
                continue

        if any(ctx for ctx in positive_passages if ctx.is_from_gold):
            number_samples_from_gold += 1

        if len(gold_passages) > 0:
            number_samples_with_gold += 1

        yield DataSample(
            question,
            question_token_ids=question_token_ids,
            answers=all_answers,
            orig_answers=orig_answers,
            expanded_answers=expanded_answers,
            # Gold
            gold_passages=gold_passages,
            # Dense
            positive_passages=positive_passages,
            distantly_positive_passages=distantly_positive_passages,
            negative_passages=negative_passages,
            # Sparse
            bm25_positive_passages=bm25_positive_passages,
            bm25_distantly_positive_passages=bm25_distantly_positive_passages,
            bm25_negative_passages=bm25_negative_passages,
        )

    # Summary statistics, reached only after the last sample has been yielded
    logger.info(
        f"Number of samples whose at least one positive passage is "
        f"from the same article as the gold passage: {number_samples_from_gold}"
    )
    logger.info(
        f"Number of samples whose gold passage is available: {number_samples_with_gold}"
    )
    logger.info(
        f"Number of samples with no positive passages: {number_no_positive_samples}"
    )