# Example #1
def _log_or_raise(message: str, should_raise: bool) -> None:
    """Report a span-matching anomaly: raise ValueError if `should_raise`, else log a warning."""
    if should_raise:
        raise ValueError(message)
    logger.warning(message)


def _find_answer_spans(
    tensorizer: Tensorizer,
    ctx: DataPassage,
    question: str,
    answers: List[str],
    answers_token_ids: List[List[int]],
    warn_if_no_answer: bool = False,
    raise_if_no_answer: bool = False,
    warn_if_has_answer: bool = False,
    raise_if_has_answer: bool = False,
    recheck_negatives: bool = False,
) -> DataPassage:
    """
    Find token-level answer spans in a passage and update `ctx` in place.

    Matches each tokenized answer in `answers_token_ids` against
    `ctx.passage_token_ids`, stores the flattened list of found spans on
    `ctx.answers_spans`, and refreshes `ctx.has_answer` accordingly.

    :param tensorizer: used only to decode token ids back to text for
        warning/error messages.
    :param ctx: passage to annotate; mutated and returned.
    :param question: question string (included in messages only).
    :param answers: answer strings; `answers_token_ids[i]` must be the
        tokenization of `answers[i]`.
    :param answers_token_ids: tokenized answers to search for.
    :param warn_if_no_answer: warn when no span is found (e.g. a passage
        previously labeled positive).
    :param raise_if_no_answer: raise ValueError instead of warning.
    :param warn_if_has_answer: warn when a span IS found (e.g. a passage
        previously labeled negative).
    :param raise_if_has_answer: raise ValueError instead of warning.
    :param recheck_negatives: if False, passages whose `has_answer` is
        falsy are returned untouched.
    :returns: the same `ctx` object, updated.
    :raises ValueError: per the `raise_if_*` flags above.
    """
    # Known negatives are skipped entirely unless a re-check is requested.
    if (not recheck_negatives) and (not ctx.has_answer):
        return ctx

    # One list of candidate spans per answer; flatten and drop falsy entries.
    spans_per_answer = [
        _find_answer_positions(ctx.passage_token_ids, answers_token_ids[i])
        for i in range(len(answers))
    ]
    answer_spans = [span for spans in spans_per_answer for span in spans]
    answer_spans = list(filter(None, answer_spans))
    ctx.answers_spans = answer_spans

    # Report inconsistencies between the previous label and the spans actually
    # found. The two conditions are mutually exclusive, so at most one message
    # is built per call.
    no_answer = len(answer_spans) == 0
    prefix = None
    should_raise = False
    if no_answer and (warn_if_no_answer or raise_if_no_answer):
        prefix = "No answer found in"
        should_raise = raise_if_no_answer
    elif (not no_answer) and (warn_if_has_answer or raise_if_has_answer):
        prefix = "Answer FOUND in"
        should_raise = raise_if_has_answer

    if prefix is not None:
        # Decode passage text/title only when a message is actually emitted.
        passage_text = tensorizer.tensor_to_text(
            torch.from_numpy(ctx.passage_token_ids))
        passage_title = tensorizer.tensor_to_text(
            torch.from_numpy(ctx.title_token_ids))
        _log_or_raise(
            f"{prefix} passage id={ctx.id} text={passage_text}, title={passage_title}, "
            f"answers={answers}, question={question}",
            should_raise,
        )

    ctx.has_answer = bool(answer_spans)

    return ctx
# Example #2
def _select_passages(
    wiki_data: TokenizedWikipediaPassages,
    sample: Dict,
    bm25_sample: Tuple[Tuple[int, float], ...],
    question: str,
    processed_question: str,
    question_token_ids: np.ndarray,
    answers: List[str],
    expanded_answers: List[List[str]],
    all_answers: List[str],
    tensorizer: Tensorizer,
    gold_passage_map: Dict[str, DataPassage],
    processed_gold_passage_map: Dict[str, DataPassage],
    cfg: PreprocessingCfg,
    is_train_set: bool,
    check_pre_tokenized_data: bool,
) -> Tuple[List[DataPassage], List[DataPassage], List[DataPassage],
           List[DataPassage], List[DataPassage], List[DataPassage],
           List[DataPassage]]:
    """
    Select and process valid passages for training/evaluation.

    Pipeline for one question sample:
      1. Tokenize all answer strings.
      2. Look up the gold passage (by `question` or `processed_question`)
         and keep it only if an answer span is actually found in it.
      3. Load dense-retriever contexts (`sample["ctxs"]`) and BM25 contexts
         (`bm25_sample`), locate answer spans in each.
      4. Split both pools into positives/negatives; at training time,
         optionally keep only positives from gold wiki pages, routing the
         rest into negatives or "distantly positive" lists per `cfg`.
      5. Truncate the BM25 lists to `cfg.max_bm25_positives` /
         `cfg.max_bm25_negatives`.

    :param wiki_data: pre-tokenized wikipedia passage store.
    :param sample: retriever output dict; `sample["ctxs"]` holds dense
        contexts (schema otherwise defined by the caller — not visible here).
    :param bm25_sample: (passage_id, score) pairs from BM25 retrieval.
    :param question: canonical question string.
    :param processed_question: normalized variant of the question.
    :param question_token_ids: tokenized question, loaded into each ctx.
    :param answers: original answers (used here for logging only).
    :param expanded_answers: answer expansions (logging only).
    :param all_answers: full answer set matched against passage tokens.
    :param tensorizer: text <-> token-id converter.
    :param gold_passage_map: question -> gold passage, keyed by raw question.
    :param processed_gold_passage_map: same, keyed by processed question.
    :param cfg: preprocessing flags (`recheck_negatives`,
        `gold_page_only_positives`, `should_negatives_contain_answer`,
        `include_gold_passage`, `max_bm25_positives`, `max_bm25_negatives`).
    :param is_train_set: positive-filtering is applied only when True.
    :param check_pre_tokenized_data: forwarded to `_load_tokens_into_ctx`.
    :returns: 7-tuple of passage lists: (gold, positives, negatives,
        distantly-positives, bm25-positives, bm25-negatives,
        bm25-distantly-positives).
    """
    # Tokenize answers once; index i corresponds to all_answers[i].
    answers_token_ids: List[np.ndarray] = [
        tensorizer.text_to_tensor(a, add_special_tokens=False).numpy()
        for a in all_answers
    ]

    # Gold context; we want to cover more gold passages, that's why we are matching both
    # `processed_question` and `question` (canonical question).
    if question in processed_gold_passage_map or processed_question in processed_gold_passage_map:
        # Raw question takes precedence over the processed one.
        if question in processed_gold_passage_map:
            gold_ctx = processed_gold_passage_map[question]
        else:
            gold_ctx = processed_gold_passage_map[processed_question]

        gold_ctx = _load_tokens_into_ctx(
            gold_ctx,
            question_token_ids,
            wiki_data,
            tensorizer,
            check_pre_tokenized_data,
        )  # load question, passage title and passage tokens into the context object
        gold_ctx = _find_answer_spans(
            tensorizer,
            gold_ctx,
            question,
            all_answers,
            answers_token_ids,
            warn_if_no_answer=True,
            raise_if_no_answer=False,
            warn_if_has_answer=False,
            raise_if_has_answer=False,
            recheck_negatives=False,
        )  # find answer spans for all passages
        # Keep the gold passage only if an answer span was actually found.
        if gold_ctx.has_answer:
            gold_ctxs = [gold_ctx]
        else:
            gold_ctxs = []
    else:
        gold_ctxs = []

    # Densely retrieved contexts
    ctxs = [DataPassage(is_from_bm25=False, **ctx) for ctx in sample["ctxs"]]
    ctxs = [
        _load_tokens_into_ctx(ctx, question_token_ids, wiki_data, tensorizer,
                              check_pre_tokenized_data) for ctx in ctxs
    ]  # load question, passage title and passage tokens into the context object
    # Find answer spans for all passages. Warnings fire only when the
    # re-computed spans contradict the retriever's original `has_answer` label.
    ctxs: List[DataPassage] = [
        _find_answer_spans(
            tensorizer,
            ctx,
            question,
            all_answers,
            answers_token_ids,
            warn_if_no_answer=ctx.
            has_answer,  # warn if originally it contains answer string
            warn_if_has_answer=(
                not ctx.has_answer
            ),  # warn if originally it does NOT contain answer string
            recheck_negatives=cfg.recheck_negatives,
        ) for ctx in ctxs
    ]

    # Sparsely retrieved contexts (BM25)
    bm25_ctxs = [
        DataPassage(id=passage_id, score=score, is_from_bm25=True)
        for passage_id, score in bm25_sample
    ]
    bm25_ctxs = [
        _load_tokens_into_ctx(ctx, question_token_ids, wiki_data, tensorizer,
                              check_pre_tokenized_data) for ctx in bm25_ctxs
    ]  # load question, passage title and passage tokens into the context object
    # Find answer spans for all passages. BM25 passages carry no prior label,
    # so negatives are always re-checked and no warnings are emitted.
    bm25_ctxs: List[DataPassage] = [
        _find_answer_spans(
            tensorizer,
            ctx,
            question,
            all_answers,
            answers_token_ids,
            warn_if_no_answer=False,
            warn_if_has_answer=False,
            recheck_negatives=True,  # `has_answer` of any BM25 passage is None
        ) for ctx in bm25_ctxs
    ]

    # Filter positives and negatives using distant supervision
    positive_samples = list(filter(lambda ctx: ctx.has_answer, ctxs))
    distantly_positive_samples: List[DataPassage] = []
    negative_samples = list(filter(lambda ctx: not ctx.has_answer, ctxs))
    bm25_positive_samples = list(filter(lambda ctx: ctx.has_answer, bm25_ctxs))
    bm25_distantly_positive_samples: List[DataPassage] = []
    bm25_negative_samples = list(
        filter(lambda ctx: not ctx.has_answer, bm25_ctxs))

    # Filter unwanted positive passages if training
    if is_train_set:

        # Get positives that are from gold positive passages
        if cfg.gold_page_only_positives:
            selected_positive_ctxs: List[DataPassage] = []
            selected_negative_ctxs: List[DataPassage] = negative_samples
            selected_bm25_positive_ctxs: List[DataPassage] = []
            selected_bm25_negative_ctxs: List[
                DataPassage] = bm25_negative_samples

            # Same routing rule applied to the dense and BM25 pools: positives
            # from a gold wiki page are kept; the rest become negatives or
            # distant positives depending on `should_negatives_contain_answer`.
            for positives, selected_positives, selected_negatives, distantly_positives in [
                (positive_samples, selected_positive_ctxs,
                 selected_negative_ctxs, distantly_positive_samples),
                (bm25_positive_samples, selected_bm25_positive_ctxs,
                 selected_bm25_negative_ctxs, bm25_distantly_positive_samples)
            ]:

                for ctx in positives:
                    is_from_gold = _is_from_gold_wiki_page(
                        gold_passage_map, ctx,
                        tensorizer.tensor_to_text(
                            torch.from_numpy(ctx.title_token_ids)), question)
                    if is_from_gold:
                        selected_positives.append(ctx)
                    else:  # if it has answer but does not come from gold passage
                        if cfg.should_negatives_contain_answer:
                            selected_negatives.append(ctx)
                        else:
                            distantly_positives.append(ctx)
        else:
            selected_positive_ctxs = positive_samples
            selected_negative_ctxs = negative_samples
            selected_bm25_positive_ctxs = bm25_positive_samples
            selected_bm25_negative_ctxs = bm25_negative_samples

        # Fallback to positive ctx not from gold passages
        if len(selected_positive_ctxs) == 0:
            selected_positive_ctxs = positive_samples
        if len(selected_bm25_positive_ctxs) == 0:
            selected_bm25_positive_ctxs = bm25_positive_samples

        # Optionally include gold passage itself if it is still not in the positives list
        if cfg.include_gold_passage:
            if question in gold_passage_map:
                gold_passage = gold_passage_map[question]
                gold_passage.is_gold = True
                gold_passage.has_answer = True  # assuming it has answer

                # Re-verify the assumption above; recheck_negatives=True forces
                # the span search even though has_answer was just set.
                gold_passage = _find_answer_spans(
                    tensorizer,
                    gold_passage,
                    question,
                    all_answers,
                    answers_token_ids,
                    warn_if_no_answer=False,
                    warn_if_has_answer=False,
                    recheck_negatives=True,
                )  # warn below

                if not gold_passage.has_answer:
                    logger.warning(
                        "No answer found in GOLD passage: passage='%s', question='%s', answers=%s, expanded_answers=%s",
                        gold_passage.passage_text,
                        question,
                        answers,
                        expanded_answers,
                    )
                selected_positive_ctxs.append(
                    gold_passage
                )  # append anyway, since we need this for retriever (not reader)
            else:
                logger.warning(f"Question '{question}' has no gold positive")

    else:
        # NOTE: See `create_reader_input` function in `reader.py` to see how
        # positive and negative samples are merged (keeping their original order)
        selected_positive_ctxs = positive_samples
        selected_negative_ctxs = negative_samples
        selected_bm25_positive_ctxs = bm25_positive_samples
        selected_bm25_negative_ctxs = bm25_negative_samples

    # Restrict number of BM25 passages
    selected_bm25_positive_ctxs = selected_bm25_positive_ctxs[:cfg.
                                                              max_bm25_positives]
    selected_bm25_negative_ctxs = selected_bm25_negative_ctxs[:cfg.
                                                              max_bm25_negatives]

    return (
        gold_ctxs,
        selected_positive_ctxs,
        selected_negative_ctxs,
        distantly_positive_samples,
        selected_bm25_positive_ctxs,
        selected_bm25_negative_ctxs,
        bm25_distantly_positive_samples,
    )