def _find_answer_spans(
    tensorizer: Tensorizer,
    ctx: DataPassage,
    question: str,
    answers: List[str],
    answers_token_ids: List[List[int]],
    warn_if_no_answer: bool = False,
    raise_if_no_answer: bool = False,
    warn_if_has_answer: bool = False,
    raise_if_has_answer: bool = False,
    recheck_negatives: bool = False,
) -> DataPassage:
    """Find all answer spans in a passage and update the passage in place.

    Searches ``ctx.passage_token_ids`` for each tokenized answer, stores the
    list of found spans in ``ctx.answers_spans`` and recomputes
    ``ctx.has_answer`` from that list.

    :param tensorizer: used only to detokenize passage/title text for
        diagnostic messages.
    :param ctx: passage to annotate; mutated and also returned.
    :param question: question string, included in diagnostics only.
    :param answers: answer strings, included in diagnostics only; its length
        bounds how many entries of ``answers_token_ids`` are searched.
    :param answers_token_ids: tokenized answers, parallel to ``answers``.
    :param warn_if_no_answer: log a warning when no span is found.
    :param raise_if_no_answer: raise ``ValueError`` when no span is found.
    :param warn_if_has_answer: log a warning when a span IS found.
    :param raise_if_has_answer: raise ``ValueError`` when a span IS found.
    :param recheck_negatives: if False, passages already marked as negatives
        (``ctx.has_answer`` falsy) are returned untouched.
    :return: the same ``ctx`` object, updated.
    :raises ValueError: per the ``raise_if_*`` flags above.
    """
    # Skip passages already known to be negative unless explicitly rechecking.
    if (not recheck_negatives) and (not ctx.has_answer):
        return ctx

    # One (possibly empty) span list per answer. Only the first len(answers)
    # token-id lists are considered, mirroring the answers list.
    per_answer_spans = [
        _find_answer_positions(ctx.passage_token_ids, token_ids)
        for token_ids in answers_token_ids[: len(answers)]
    ]

    # Flatten and drop falsy entries (e.g. None placeholders), matching
    # the semantics of filter(None, ...).
    answers_spans = [
        span for sublist in per_answer_spans for span in sublist if span
    ]
    ctx.answers_spans = answers_spans

    if len(answers_spans) == 0 and (warn_if_no_answer or raise_if_no_answer):
        _report_answer_presence(
            tensorizer, ctx, question, answers,
            prefix="No answer found in passage",
            should_raise=raise_if_no_answer,
        )

    if len(answers_spans) > 0 and (warn_if_has_answer or raise_if_has_answer):
        _report_answer_presence(
            tensorizer, ctx, question, answers,
            prefix="Answer FOUND in passage",
            should_raise=raise_if_has_answer,
        )

    ctx.has_answer = bool(answers_spans)
    return ctx


def _report_answer_presence(
    tensorizer: Tensorizer,
    ctx: DataPassage,
    question: str,
    answers: List[str],
    prefix: str,
    should_raise: bool,
) -> None:
    """Build a diagnostic message about answer presence in `ctx`; raise
    ValueError if `should_raise`, otherwise log a warning."""
    passage_text = tensorizer.tensor_to_text(
        torch.from_numpy(ctx.passage_token_ids))
    passage_title = tensorizer.tensor_to_text(
        torch.from_numpy(ctx.title_token_ids))
    message = (
        f"{prefix} id={ctx.id} text={passage_text}, title={passage_title}, "
        f"answers={answers}, question={question}")
    if should_raise:
        raise ValueError(message)
    logger.warning(message)
def _select_passages(
    wiki_data: TokenizedWikipediaPassages,
    sample: Dict,
    bm25_sample: Tuple[Tuple[int, float]],
    question: str,
    processed_question: str,
    question_token_ids: np.ndarray,
    answers: List[str],
    expanded_answers: List[List[str]],
    all_answers: List[str],
    tensorizer: Tensorizer,
    gold_passage_map: Dict[str, DataPassage],
    processed_gold_passage_map: Dict[str, DataPassage],
    cfg: PreprocessingCfg,
    is_train_set: bool,
    check_pre_tokenized_data: bool,
) -> Tuple[List[DataPassage], List[DataPassage], List[DataPassage],
           List[DataPassage], List[DataPassage], List[DataPassage],
           List[DataPassage]]:
    """
    Select and process valid passages for training/evaluation.

    Classifies densely retrieved contexts (``sample["ctxs"]``) and sparsely
    retrieved BM25 contexts (``bm25_sample``, presumably (passage_id, score)
    pairs — confirm against caller) into positives, negatives and
    "distantly positive" passages, applying training-time filters from `cfg`.

    Returns a 7-tuple of passage lists:
    (gold contexts, dense positives, dense negatives, dense distant positives,
     BM25 positives, BM25 negatives, BM25 distant positives).
    """
    # Tokenize answers once; reused for span search in every candidate passage.
    answers_token_ids: List[np.ndarray] = [
        tensorizer.text_to_tensor(a, add_special_tokens=False).numpy()
        for a in all_answers
    ]

    # Gold context; we want to cover more gold passages, that's why we are matching both
    # `processed_question` and `question` (canonical question).
    if question in processed_gold_passage_map or processed_question in processed_gold_passage_map:
        if question in processed_gold_passage_map:
            gold_ctx = processed_gold_passage_map[question]
        else:
            gold_ctx = processed_gold_passage_map[processed_question]

        gold_ctx = _load_tokens_into_ctx(
            gold_ctx,
            question_token_ids,
            wiki_data,
            tensorizer,
            check_pre_tokenized_data,
        )  # load question, passage title and passage tokens into the context object
        gold_ctx = _find_answer_spans(
            tensorizer,
            gold_ctx,
            question,
            all_answers,
            answers_token_ids,
            warn_if_no_answer=True,
            raise_if_no_answer=False,
            warn_if_has_answer=False,
            raise_if_has_answer=False,
            recheck_negatives=False,
        )  # find answer spans for all passages
        # Keep the gold context only if an answer span was actually found.
        if gold_ctx.has_answer:
            gold_ctxs = [gold_ctx]
        else:
            gold_ctxs = []
    else:
        gold_ctxs = []

    # Densely retrieved contexts
    ctxs = [DataPassage(is_from_bm25=False, **ctx) for ctx in sample["ctxs"]]
    ctxs = [
        _load_tokens_into_ctx(ctx, question_token_ids, wiki_data, tensorizer,
                              check_pre_tokenized_data)
        for ctx in ctxs
    ]  # load question, passage title and passage tokens into the context object
    # Find answer spans for all passages
    ctxs: List[DataPassage] = [
        _find_answer_spans(
            tensorizer,
            ctx,
            question,
            all_answers,
            answers_token_ids,
            warn_if_no_answer=ctx.
            has_answer,  # warn if originally it contains answer string
            warn_if_has_answer=(
                not ctx.has_answer
            ),  # warn if originally it does NOT contain answer string
            recheck_negatives=cfg.recheck_negatives,
        ) for ctx in ctxs
    ]

    # Sparsely retrieved contexts (BM25)
    bm25_ctxs = [
        DataPassage(id=passage_id, score=score, is_from_bm25=True)
        for passage_id, score in bm25_sample
    ]
    bm25_ctxs = [
        _load_tokens_into_ctx(ctx, question_token_ids, wiki_data, tensorizer,
                              check_pre_tokenized_data)
        for ctx in bm25_ctxs
    ]  # load question, passage title and passage tokens into the context object
    # Find answer spans for all passages
    bm25_ctxs: List[DataPassage] = [
        _find_answer_spans(
            tensorizer,
            ctx,
            question,
            all_answers,
            answers_token_ids,
            warn_if_no_answer=False,
            warn_if_has_answer=False,
            recheck_negatives=True,  # `has_answer` of any BM25 passage is None
        ) for ctx in bm25_ctxs
    ]

    # Filter positives and negatives using distant supervision
    # (`has_answer` was just recomputed by `_find_answer_spans` above).
    positive_samples = list(filter(lambda ctx: ctx.has_answer, ctxs))
    distantly_positive_samples: List[DataPassage] = []
    negative_samples = list(filter(lambda ctx: not ctx.has_answer, ctxs))
    bm25_positive_samples = list(filter(lambda ctx: ctx.has_answer, bm25_ctxs))
    bm25_distantly_positive_samples: List[DataPassage] = []
    bm25_negative_samples = list(
        filter(lambda ctx: not ctx.has_answer, bm25_ctxs))

    # Filter unwanted positive passages if training
    if is_train_set:
        # Get positives that are from gold positive passages
        if cfg.gold_page_only_positives:
            selected_positive_ctxs: List[DataPassage] = []
            selected_negative_ctxs: List[DataPassage] = negative_samples
            selected_bm25_positive_ctxs: List[DataPassage] = []
            selected_bm25_negative_ctxs: List[
                DataPassage] = bm25_negative_samples

            # Same gold-page filtering applied to dense and BM25 positives;
            # the selected_* lists are mutated in place via the aliases below.
            for positives, selected_positives, selected_negatives, distantly_positives in [
                (positive_samples, selected_positive_ctxs,
                 selected_negative_ctxs, distantly_positive_samples),
                (bm25_positive_samples, selected_bm25_positive_ctxs,
                 selected_bm25_negative_ctxs, bm25_distantly_positive_samples)
            ]:
                for ctx in positives:
                    is_from_gold = _is_from_gold_wiki_page(
                        gold_passage_map, ctx,
                        tensorizer.tensor_to_text(
                            torch.from_numpy(ctx.title_token_ids)), question)
                    if is_from_gold:
                        selected_positives.append(ctx)
                    else:  # if it has answer but does not come from gold passage
                        if cfg.should_negatives_contain_answer:
                            selected_negatives.append(ctx)
                        else:
                            distantly_positives.append(ctx)
        else:
            selected_positive_ctxs = positive_samples
            selected_negative_ctxs = negative_samples
            selected_bm25_positive_ctxs = bm25_positive_samples
            selected_bm25_negative_ctxs = bm25_negative_samples

        # Fallback to positive ctx not from gold passages
        if len(selected_positive_ctxs) == 0:
            selected_positive_ctxs = positive_samples
        if len(selected_bm25_positive_ctxs) == 0:
            selected_bm25_positive_ctxs = bm25_positive_samples

        # Optionally include gold passage itself if it is still not in the positives list
        if cfg.include_gold_passage:
            if question in gold_passage_map:
                gold_passage = gold_passage_map[question]
                gold_passage.is_gold = True
                gold_passage.has_answer = True  # assuming it has answer
                gold_passage = _find_answer_spans(
                    tensorizer,
                    gold_passage,
                    question,
                    all_answers,
                    answers_token_ids,
                    warn_if_no_answer=False,
                    warn_if_has_answer=False,
                    recheck_negatives=True,
                )  # warn below
                if not gold_passage.has_answer:
                    logger.warning(
                        "No answer found in GOLD passage: passage='%s', question='%s', answers=%s, expanded_answers=%s",
                        gold_passage.passage_text,
                        question,
                        answers,
                        expanded_answers,
                    )
                # NOTE(review): appended even when it lacks an answer and
                # without a duplicate check against selected_positive_ctxs —
                # intentional per the comment below; verify downstream reader
                # tolerates answer-less / duplicate positives.
                selected_positive_ctxs.append(
                    gold_passage
                )  # append anyway, since we need this for retriever (not reader)
            else:
                logger.warning(f"Question '{question}' has no gold positive")

    else:
        # NOTE: See `create_reader_input` function in `reader.py` to see how
        # positive and negative samples are merged (keeping their original
        # order)
        selected_positive_ctxs = positive_samples
        selected_negative_ctxs = negative_samples
        selected_bm25_positive_ctxs = bm25_positive_samples
        selected_bm25_negative_ctxs = bm25_negative_samples

    # Restrict number of BM25 passages
    selected_bm25_positive_ctxs = selected_bm25_positive_ctxs[:cfg.
                                                              max_bm25_positives]
    selected_bm25_negative_ctxs = selected_bm25_negative_ctxs[:cfg.
                                                              max_bm25_negatives]

    return (
        gold_ctxs,
        selected_positive_ctxs,
        selected_negative_ctxs,
        distantly_positive_samples,
        selected_bm25_positive_ctxs,
        selected_bm25_negative_ctxs,
        bm25_distantly_positive_samples,
    )