def generate_question_vectors(
    question_encoder: torch.nn.Module,
    tensorizer: Tensorizer,
    questions: List[str],
    bsz: int,
    query_token: str = None,
    selector: RepTokenSelector = None,
) -> T:
    n = len(questions)
    query_vectors = []

    with torch.no_grad():
        for j, batch_start in enumerate(range(0, n, bsz)):
            batch_questions = questions[batch_start:batch_start + bsz]

            if query_token:
                if query_token == "[START_ENT]":
                    batch_token_tensors = [
                        _select_span_with_token(q, tensorizer, token_str=query_token)
                        for q in batch_questions
                    ]
                else:
                    batch_token_tensors = [
                        tensorizer.text_to_tensor(" ".join([query_token, q]))
                        for q in batch_questions
                    ]
            else:
                batch_token_tensors = [
                    tensorizer.text_to_tensor(q) for q in batch_questions
                ]

            q_ids_batch = torch.stack(batch_token_tensors, dim=0).cuda()
            q_seg_batch = torch.zeros_like(q_ids_batch).cuda()
            q_attn_mask = tensorizer.get_attn_mask(q_ids_batch)

            if selector:
                rep_positions = selector.get_positions(q_ids_batch, tensorizer)
                _, out, _ = BiEncoder.get_representation(
                    question_encoder,
                    q_ids_batch,
                    q_seg_batch,
                    q_attn_mask,
                    representation_token_pos=rep_positions,
                )
            else:
                _, out, _ = question_encoder(q_ids_batch, q_seg_batch, q_attn_mask)

            query_vectors.extend(out.cpu().split(1, dim=0))

            if len(query_vectors) % 100 == 0:
                logger.info("Encoded queries %d", len(query_vectors))

    query_tensor = torch.cat(query_vectors, dim=0)
    logger.info("Total encoded queries tensor %s", query_tensor.size())
    assert query_tensor.size(0) == len(questions)
    return query_tensor
def gen_ctx_vectors(ctx_rows: List[Tuple[object, str, str]], model: nn.Module,
                    tensorizer: Tensorizer,
                    insert_title: bool = True) -> List[Tuple[object, np.array]]:
    n = len(ctx_rows)
    bsz = args.batch_size
    total = 0
    results = []
    for j, batch_start in enumerate(range(0, n, bsz)):
        batch_token_tensors = [
            tensorizer.text_to_tensor(ctx[1], title=ctx[2] if insert_title else None)
            for ctx in ctx_rows[batch_start:batch_start + bsz]
        ]

        ctx_ids_batch = torch.stack(batch_token_tensors, dim=0)
        ctx_seg_batch = torch.zeros_like(ctx_ids_batch)
        ctx_attn_mask = tensorizer.get_attn_mask(ctx_ids_batch)
        with torch.no_grad():
            _, out, _ = model(ctx_ids_batch, ctx_seg_batch, ctx_attn_mask)
        out = out.cpu()

        ctx_ids = [r[0] for r in ctx_rows[batch_start:batch_start + bsz]]

        assert len(ctx_ids) == out.size(0)

        total += len(ctx_ids)

        results.extend([
            (ctx_ids[i], out[i].view(-1).numpy())
            for i in range(out.size(0))
        ])

        if total % 10 == 0:
            logger.info('Encoded passages %d', total)

    return results
def _do_biencoder_fwd_pass(model: nn.Module, input: BiEncoderBatch,
                           tensorizer: Tensorizer, args) -> (torch.Tensor, int):
    input = BiEncoderBatch(**move_to_device(input._asdict(), args.device))

    q_attn_mask = tensorizer.get_attn_mask(input.question_ids)
    ctx_attn_mask = tensorizer.get_attn_mask(input.context_ids)

    if model.training:
        model_out = model(input.question_ids, input.question_segments, q_attn_mask,
                          input.context_ids, input.ctx_segments, ctx_attn_mask)
    else:
        with torch.no_grad():
            model_out = model(input.question_ids, input.question_segments, q_attn_mask,
                              input.context_ids, input.ctx_segments, ctx_attn_mask)

    local_q_vector, local_ctx_vectors = model_out

    loss_function = BiEncoderNllLoss()

    loss, is_correct = _calc_loss(args, loss_function, local_q_vector, local_ctx_vectors,
                                  input.is_positive, input.hard_negatives)

    is_correct = is_correct.sum().item()

    if args.n_gpu > 1:
        loss = loss.mean()
    if args.gradient_accumulation_steps > 1:
        loss = loss / args.gradient_accumulation_steps

    return loss, is_correct
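# Illustrative sketch (not part of the original file): what a loss like
# `BiEncoderNllLoss` used above boils down to, assuming DPR-style dot-product
# similarity with in-batch negatives. All names below are hypothetical.
import torch
import torch.nn.functional as F

def _nll_loss_sketch(q_vectors: torch.Tensor, ctx_vectors: torch.Tensor,
                     positive_idx_per_question: List[int]) -> torch.Tensor:
    # scores[i, j] = similarity between question i and context j; every other
    # context in the batch acts as a negative for question i
    scores = torch.matmul(q_vectors, ctx_vectors.transpose(0, 1))
    log_softmax_scores = F.log_softmax(scores, dim=1)
    # negative log-likelihood of each question's positive context
    return F.nll_loss(log_softmax_scores,
                      torch.tensor(positive_idx_per_question),
                      reduction="mean")

# e.g. 2 questions, 4 contexts, positives of question i at index 2 * i
_example_loss = _nll_loss_sketch(torch.randn(2, 768), torch.randn(4, 768), [0, 2])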
def _do_biencoder_fwd_pass(
    model: nn.Module,
    input: BiEncoderBatch,
    tensorizer: Tensorizer,
    cfg,
    encoder_type: str,
    rep_positions=0,
    loss_scale: float = None,
) -> Tuple[torch.Tensor, int]:
    input = BiEncoderBatch(**move_to_device(input._asdict(), cfg.device))

    q_attn_mask = tensorizer.get_attn_mask(input.question_ids)
    ctx_attn_mask = tensorizer.get_attn_mask(input.context_ids)

    if model.training:
        model_out = model(
            input.question_ids,
            input.question_segments,
            q_attn_mask,
            input.context_ids,
            input.ctx_segments,
            ctx_attn_mask,
            encoder_type=encoder_type,
            representation_token_pos=rep_positions,
        )
    else:
        with torch.no_grad():
            model_out = model(
                input.question_ids,
                input.question_segments,
                q_attn_mask,
                input.context_ids,
                input.ctx_segments,
                ctx_attn_mask,
                encoder_type=encoder_type,
                representation_token_pos=rep_positions,
            )

    local_q_vector, local_ctx_vectors = model_out

    loss_function = BiEncoderNllLoss()

    loss, is_correct = _calc_loss(
        cfg,
        loss_function,
        local_q_vector,
        local_ctx_vectors,
        input.is_positive,
        input.hard_negatives,
        loss_scale=loss_scale,
    )

    is_correct = is_correct.sum().item()

    if cfg.n_gpu > 1:
        loss = loss.mean()
    if cfg.train.gradient_accumulation_steps > 1:
        # divide by the same config value that is checked above
        loss = loss / cfg.train.gradient_accumulation_steps

    return loss, is_correct
def create_biencoder_input_from_reader_input(
    tensorizer: Tensorizer,
    reader_batch: ReaderBatch,
) -> BiEncoderBatch:
    input_ids = reader_batch.input_ids  # (N, M, L)
    question_ids: List[T] = []  # len N
    context_ids: List[T] = []  # len N * M

    for input_id_i in input_ids:
        for j, input_id in enumerate(input_id_i):
            ids = tensorizer.unconcatenate_inputs(
                input_id, components={"question", "passage_title", "passage"})

            if ids is None:  # full padding
                context_ids.append(input_id)
                continue

            # Question
            question_id = tensorizer.concatenate_inputs(
                ids={"question": ids["question"].tolist()},
                get_passage_offset=False,
                to_max_length=True,
            )
            if j == 0:
                question_ids.append(question_id)
            else:
                assert (question_id == question_ids[-1]).all()

            # Passage
            passage_title = ids["passage_title"]
            passage = ids["passage"]
            context_ids.append(
                tensorizer.concatenate_inputs(
                    ids={
                        "passage_title": passage_title.tolist(),
                        "passage": passage.tolist(),
                    },
                    get_passage_offset=False,
                    to_max_length=True,
                ))

    question_ids = torch.stack(question_ids)
    context_ids = torch.stack(context_ids)

    question_segments = torch.zeros_like(question_ids)
    context_segments = torch.zeros_like(context_ids)

    biencoder_batch = BiEncoderBatch(
        question_ids=question_ids,
        question_segments=question_segments,
        context_IDs=None,  # not used
        context_ids=context_ids,
        ctx_segments=context_segments,
        is_positive=None,  # not used
        hard_negatives=None,  # not used
        encoder_type=None,  # not used
    )
    return biencoder_batch
def _run_preprocessing(tensorizer: Tensorizer):
    # temporarily disable auto-padding to save disk space usage of serialized files
    tensorizer.set_pad_to_max(False)
    serialized_files = convert_retriever_results(
        is_train,
        data_files[0],
        out_file_prefix,
        gold_passages_src,
        tensorizer,
        num_workers=self.args.num_workers,
    )
    tensorizer.set_pad_to_max(True)
    return serialized_files
def _find_answer_spans(
    tensorizer: Tensorizer,
    ctx: DataPassage,
    question: str,
    answers: List[str],
    answers_token_ids: List[List[int]],
    warn_if_no_answer: bool = False,
    raise_if_no_answer: bool = False,
    warn_if_has_answer: bool = False,
    raise_if_has_answer: bool = False,
    recheck_negatives: bool = False,
) -> DataPassage:
    if (not recheck_negatives) and (not ctx.has_answer):
        return ctx

    answer_spans = [
        _find_answer_positions(ctx.passage_token_ids, answers_token_ids[i])
        for i in range(len(answers))
    ]

    # flatten spans list
    answer_spans = [item for sublist in answer_spans for item in sublist]
    answers_spans = list(filter(None, answer_spans))
    ctx.answers_spans = answers_spans

    if len(answers_spans) == 0 and (warn_if_no_answer or raise_if_no_answer):
        passage_text = tensorizer.tensor_to_text(torch.from_numpy(ctx.passage_token_ids))
        passage_title = tensorizer.tensor_to_text(torch.from_numpy(ctx.title_token_ids))
        message = (
            f"No answer found in passage id={ctx.id} text={passage_text}, title={passage_title}, "
            f"answers={answers}, question={question}")

        if raise_if_no_answer:
            raise ValueError(message)
        else:
            logger.warning(message)

    if len(answers_spans) > 0 and (warn_if_has_answer or raise_if_has_answer):
        passage_text = tensorizer.tensor_to_text(torch.from_numpy(ctx.passage_token_ids))
        passage_title = tensorizer.tensor_to_text(torch.from_numpy(ctx.title_token_ids))
        message = (
            f"Answer FOUND in passage id={ctx.id} text={passage_text}, title={passage_title}, "
            f"answers={answers}, question={question}")

        if raise_if_has_answer:
            raise ValueError(message)
        else:
            logger.warning(message)

    ctx.has_answer = bool(answers_spans)
    return ctx
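# `_find_answer_positions` is defined elsewhere in the repo; a minimal,
# hypothetical sketch of what such a helper plausibly does (a naive
# token-subsequence search over the passage token ids) for reference:
def _find_answer_positions_sketch(ctx_ids, answer_ids) -> List[Tuple[int, int]]:
    a_len = len(answer_ids)
    occurrences = []
    for i in range(0, len(ctx_ids) - a_len + 1):
        # compare the candidate window against the answer token ids
        if (ctx_ids[i:i + a_len] == answer_ids).all():
            occurrences.append((i, i + a_len - 1))
    return occurrences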
def _extend_span_to_full_words(tensorizer: Tensorizer, tokens: List[int],
                               span: Tuple[int, int]) -> Tuple[int, int]:
    start_index, end_index = span
    max_len = len(tokens)
    while start_index > 0 and tensorizer.is_sub_word_id(tokens[start_index]):
        start_index -= 1

    while end_index < max_len - 1 and tensorizer.is_sub_word_id(tokens[end_index + 1]):
        end_index += 1

    return start_index, end_index
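# Worked example for `_extend_span_to_full_words`, using a stub in place of a
# real Tensorizer (in DPR, `is_sub_word_id` checks for WordPiece '##' pieces);
# the token ids below are hypothetical:
class _StubTensorizer:
    SUBWORD_IDS = {102, 103}  # ids standing in for '##squi' and '##to'

    def is_sub_word_id(self, token_id: int) -> bool:
        return token_id in self.SUBWORD_IDS

# tokens ['mo', '##squi', '##to', 'nets'] -> ids [101, 102, 103, 104];
# a span covering only '##squi' is widened left to 'mo' and right to '##to'
assert _extend_span_to_full_words(_StubTensorizer(), [101, 102, 103, 104], (1, 1)) == (0, 2)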
def gen_ctx_vectors(
    cfg: DictConfig,
    ctx_rows: List[Tuple[object, BiEncoderPassage]],
    model: nn.Module,
    tensorizer: Tensorizer,
    insert_title: bool = True,
) -> List[Tuple[object, np.array]]:
    n = len(ctx_rows)
    bsz = cfg.batch_size
    total = 0
    results = []
    for j, batch_start in enumerate(range(0, n, bsz)):
        batch = ctx_rows[batch_start : batch_start + bsz]
        batch_token_tensors = [
            tensorizer.text_to_tensor(ctx[1].text,
                                      title=ctx[1].title if insert_title else None)
            for ctx in batch
        ]

        ctx_ids_batch = move_to_device(torch.stack(batch_token_tensors, dim=0), cfg.device)
        ctx_seg_batch = move_to_device(torch.zeros_like(ctx_ids_batch), cfg.device)
        ctx_attn_mask = move_to_device(tensorizer.get_attn_mask(ctx_ids_batch), cfg.device)
        with torch.no_grad():
            _, out, _ = model(ctx_ids_batch, ctx_seg_batch, ctx_attn_mask)
        out = out.cpu()

        ctx_ids = [r[0] for r in batch]
        extra_info = []
        if len(batch[0]) > 3:
            extra_info = [r[3:] for r in batch]

        assert len(ctx_ids) == out.size(0)
        total += len(ctx_ids)

        # TODO: refactor to avoid 'if'
        if extra_info:
            results.extend([
                (ctx_ids[i], out[i].view(-1).numpy(), *extra_info[i])
                for i in range(out.size(0))
            ])
        else:
            results.extend([
                (ctx_ids[i], out[i].view(-1).numpy()) for i in range(out.size(0))
            ])

        if total % 10 == 0:
            logger.info("Encoded passages %d", total)
    return results
def gen_ctx_vectors(
    ctx_rows: List[Tuple[object, str, str]],
    model: nn.Module,
    tensorizer: Tensorizer,
    insert_title: bool = True) -> List[Tuple[object, np.array]]:
    n = len(ctx_rows)
    bsz = args.batch_size
    total = 0
    results = []
    for j, batch_start in enumerate(range(0, n, bsz)):
        # unlike the original DPR formatting, this variant prepends plain-text
        # 'title:' / 'context:' markers instead of using the tensorizer's title argument
        all_txt = []
        for ctx in ctx_rows[batch_start:batch_start + bsz]:
            if ctx[2]:
                txt = ['title:', ctx[2], 'context:', ctx[1]]
            else:
                txt = ['context:', ctx[1]]
            txt = ' '.join(txt)
            all_txt.append(txt)
        batch_token_tensors = [
            tensorizer.text_to_tensor(txt, max_length=250) for txt in all_txt
        ]

        ctx_ids_batch = move_to_device(torch.stack(batch_token_tensors, dim=0), args.device)
        ctx_seg_batch = move_to_device(torch.zeros_like(ctx_ids_batch), args.device)
        ctx_attn_mask = move_to_device(tensorizer.get_attn_mask(ctx_ids_batch), args.device)
        with torch.no_grad():
            _, out, _ = model(ctx_ids_batch, ctx_seg_batch, ctx_attn_mask)
        out = out.cpu()

        ctx_ids = [r[0] for r in ctx_rows[batch_start:batch_start + bsz]]

        assert len(ctx_ids) == out.size(0)

        total += len(ctx_ids)

        results.extend([(ctx_ids[i], out[i].numpy()) for i in range(out.size(0))])

        if total % 10 == 0:
            logger.info('Encoded passages %d', total)

    return results
def get_positions(self, input_ids: T, tensorizer: Tensorizer, model: torch.nn.Module = None):
    if not self.token_id:
        self.token_id = tensorizer.get_token_id(self.token)
    token_indexes = (input_ids == self.token_id).nonzero()

    # check that every sample in input_ids contains the special token,
    # and fall back to a default position otherwise
    bsz = input_ids.size(0)
    if bsz == token_indexes.size(0):
        return token_indexes

    token_indexes_result = []
    found_idx_cnt = 0
    for i in range(bsz):
        if (found_idx_cnt < token_indexes.size(0)
                and token_indexes[found_idx_cnt][0] == i):
            # this sample contains the special token
            token_indexes_result.append(token_indexes[found_idx_cnt])
            found_idx_cnt += 1
        else:
            logger.warning("missing special token %s", input_ids[i])
            token_indexes_result.append(
                torch.tensor([i, 0]).to(input_ids.device)
            )  # fall back to the 0-th token, i.e. CLS for BERT
    token_indexes_result = torch.stack(token_indexes_result, dim=0)
    return token_indexes_result
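# Illustration of the selector above: `nonzero()` on the equality mask yields
# one (row, position) pair per occurrence of the special token, and samples
# that lack the token fall back to position 0 (CLS for BERT). Token id 9 is
# hypothetical.
_token_id = 9
_input_ids = torch.tensor([
    [101, 9, 7, 102],  # token present at position 1
    [101, 5, 7, 102],  # token missing -> fallback row [1, 0]
])
# (_input_ids == _token_id).nonzero() == tensor([[0, 1]]); one hit for a batch
# of two, so `get_positions` would return tensor([[0, 1], [1, 0]])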
def _load_tokens_into_ctx(
    ctx: DataPassage,
    question_token_ids: np.ndarray,
    wiki_data: TokenizedWikipediaPassages,
    tensorizer: Tensorizer,
    check_pre_tokenized_data: bool = True,
) -> DataPassage:
    tokens = wiki_data.get_tokenized_data(int(ctx.id))

    # Double check if needed
    if ctx.passage_text is not None:
        orig_passage_ids = tensorizer.text_to_tensor(
            ctx.passage_text,
            add_special_tokens=False,
        ).numpy()
        if check_pre_tokenized_data and (
                len(orig_passage_ids) != len(tokens["passage_token_ids"])
                or not (orig_passage_ids == tokens["passage_token_ids"]).all()):
            raise ValueError(
                f"Passage token mismatch: id: {ctx.id}, orig: {orig_passage_ids}, "
                f"pre-processed: {tokens['passage_token_ids']}. If the sequence lengths are different,"
                f" this might be because the maximum length of the tokenizer is set differently during "
                f"pre-processing and training.")

        orig_title_ids = tensorizer.text_to_tensor(
            ctx.title,
            add_special_tokens=False,
        ).numpy()
        if check_pre_tokenized_data and (
                len(orig_title_ids) != len(tokens["title_token_ids"])
                or not (orig_title_ids == tokens["title_token_ids"]).all()):
            raise ValueError(
                f"Passage title token mismatch: id: {ctx.id}, orig: {orig_title_ids}, "
                f"pre-processed: {tokens['title_token_ids']}. If the sequence lengths are different,"
                f" this might be because the maximum length of the tokenizer is set differently during "
                f"pre-processing and training.")

    ctx.load_tokens(question_token_ids=question_token_ids,
                    **tokens)  # load question, passage and passage title tokens

    # Remove redundant data
    ctx.on_serialize(remove_tokens=False)

    return ctx
def get_best_spans(
    tensorizer: Tensorizer,
    start_logits: List,
    end_logits: List,
    ctx_ids: List,
    max_answer_length: int,
    passage_idx: int,
    relevance_score: float,
    top_spans: int = 1,
) -> List[SpanPrediction]:
    """
    Finds the best answer span for the extractive Q&A model
    """
    scores = []
    for (i, s) in enumerate(start_logits):
        for (j, e) in enumerate(end_logits[i : i + max_answer_length]):
            scores.append(((i, i + j), s + e))

    scores = sorted(scores, key=lambda x: x[1], reverse=True)

    chosen_span_intervals = []
    best_spans = []

    for (start_index, end_index), score in scores:
        assert start_index <= end_index
        length = end_index - start_index + 1
        assert length <= max_answer_length

        if any(
            [
                start_index <= prev_start_index <= prev_end_index <= end_index
                or prev_start_index <= start_index <= end_index <= prev_end_index
                for (prev_start_index, prev_end_index) in chosen_span_intervals
            ]
        ):
            continue

        # extend bpe subtokens to full tokens
        start_index, end_index = _extend_span_to_full_words(
            tensorizer, ctx_ids, (start_index, end_index)
        )

        predicted_answer = tensorizer.to_string(ctx_ids[start_index : end_index + 1]).upper()
        best_spans.append(
            SpanPrediction(predicted_answer, score, relevance_score, passage_idx, ctx_ids)
        )
        chosen_span_intervals.append((start_index, end_index))

        if len(chosen_span_intervals) == top_spans:
            break
    return best_spans
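# Minimal worked example of the span-ranking step in `get_best_spans`: every
# (start, end) pair within `max_answer_length` is scored by start_logit +
# end_logit, then spans are taken greedily in score order.
_start_logits = [0.9, 0.1, 2.0]
_end_logits = [0.2, 1.5, 0.3]
_max_answer_length = 2

_scores = []
for _i, _s in enumerate(_start_logits):
    for _j, _e in enumerate(_end_logits[_i:_i + _max_answer_length]):
        _scores.append(((_i, _i + _j), _s + _e))
_scores = sorted(_scores, key=lambda x: x[1], reverse=True)
# best span is (0, 1): 0.9 + 1.5 == 2.4, i.e. start at token 0, end at token 1
assert _scores[0] == ((0, 1), 2.4)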
def _select_span_with_token(
    text: str, tensorizer: Tensorizer, token_str: str = "[START_ENT]"
) -> T:
    id = tensorizer.get_token_id(token_str)
    query_tensor = tensorizer.text_to_tensor(text)

    if id not in query_tensor:
        query_tensor_full = tensorizer.text_to_tensor(text, apply_max_len=False)
        token_indexes = (query_tensor_full == id).nonzero()
        if token_indexes.size(0) > 0:
            start_pos = token_indexes[0, 0].item()
            # add some randomization to avoid overfitting to a specific token position
            left_shift = int(tensorizer.max_length / 2)
            rnd_shift = int((rnd.random() - 0.5) * left_shift / 2)
            left_shift += rnd_shift

            query_tensor = query_tensor_full[start_pos - left_shift :]
            cls_id = tensorizer.tokenizer.cls_token_id
            if query_tensor[0] != cls_id:
                query_tensor = torch.cat([torch.tensor([cls_id]), query_tensor], dim=0)

            from dpr.models.reader import _pad_to_len

            query_tensor = _pad_to_len(
                query_tensor, tensorizer.get_pad_id(), tensorizer.max_length
            )
            query_tensor[-1] = tensorizer.tokenizer.sep_token_id
            # logger.info('aligned query_tensor %s', query_tensor)

            assert id in query_tensor, "query_tensor={}".format(query_tensor)
            return query_tensor
        else:
            raise RuntimeError(
                "[START_ENT] token not found for Entity Linking sample query={}".format(text)
            )
    else:
        return query_tensor
def get_positions(self, input_ids: T, tensorizer: Tensorizer, model: torch.nn.Module):
    attention_masks = tensorizer.get_attn_mask(input_ids)
    rep_positions = []

    for attention_mask in attention_masks:
        if model.training:
            # pick a random non-padded position as the representation token
            input_length = (attention_mask != 0).sum()
            rep_position = random.randint(0, input_length - 1)
            rep_positions.append(rep_position)
        else:
            # fall back to the default static position
            rep_positions.append(self.static_position)
    # use int64: these positions are used for tensor indexing, where int8 would
    # both overflow for long inputs and be rejected by torch
    rep_positions = torch.tensor(rep_positions, dtype=torch.long).unsqueeze(-1).repeat(1, 2)
    return rep_positions
def _select_reader_passages(
    sample: Dict,
    question: str,
    tensorizer: Tensorizer,
    gold_passage_map: Optional[Dict[str, ReaderPassage]],
    gold_page_only_positives: bool,
    max_positives: int,
    max1_negatives: int,
    max2_negatives: int,
    max_retriever_passages: int,
    include_gold_passage: bool,
    is_train_set: bool,
) -> Tuple[List[ReaderPassage], List[ReaderPassage]]:
    answers = sample["answers"]

    ctxs = [ReaderPassage(**ctx) for ctx in sample["ctxs"]][0:max_retriever_passages]
    answers_token_ids = [tensorizer.text_to_tensor(a, add_special_tokens=False)
                         for a in answers]

    if is_train_set:
        positive_samples = list(filter(lambda ctx: ctx.has_answer, ctxs))
        negative_samples = list(filter(lambda ctx: not ctx.has_answer, ctxs))
    else:
        positive_samples = []
        negative_samples = ctxs

    positive_ctxs_from_gold_page = (
        list(
            filter(
                lambda ctx: _is_from_gold_wiki_page(gold_passage_map, ctx.title, question),
                positive_samples,
            ))
        if gold_page_only_positives and gold_passage_map
        else []
    )

    def find_answer_spans(ctx: ReaderPassage):
        if ctx.has_answer:
            if ctx.passage_token_ids is None:
                ctx.passage_token_ids = tensorizer.text_to_tensor(
                    ctx.passage_text, add_special_tokens=False)

            answer_spans = [
                _find_answer_positions(ctx.passage_token_ids, answers_token_ids[i])
                for i in range(len(answers))
            ]

            # flatten spans list
            answer_spans = [item for sublist in answer_spans for item in sublist]
            answers_spans = list(filter(None, answer_spans))
            ctx.answers_spans = answers_spans

            if not answers_spans:
                logger.warning(
                    "No answer found in passage id=%s text=%s, answers=%s, question=%s",
                    ctx.id,
                    "",  # ctx.passage_text
                    answers,
                    question,
                )
            ctx.has_answer = bool(answers_spans)
        return ctx

    # check if any of the selected ctx+ has answer spans
    selected_positive_ctxs = list(
        filter(
            lambda ctx: ctx.has_answer,
            [find_answer_spans(ctx) for ctx in positive_ctxs_from_gold_page],
        ))

    if not selected_positive_ctxs:  # fallback to positive ctx not from gold pages
        selected_positive_ctxs = list(
            filter(
                lambda ctx: ctx.has_answer,
                [find_answer_spans(ctx) for ctx in positive_samples],
            ))[0:max_positives]

    # optionally include gold passage itself if it is still not in the positives list
    if include_gold_passage and question in gold_passage_map:
        gold_passage = gold_passage_map[question]
        included_gold_passage = next(
            iter(ctx for ctx in selected_positive_ctxs
                 if ctx.passage_text == gold_passage.passage_text),
            None,
        )
        if not included_gold_passage:
            gold_passage.has_answer = True
            gold_passage = find_answer_spans(gold_passage)
            if not gold_passage.has_answer:
                logger.warning("No answer found in gold passage: %s", gold_passage)
            else:
                selected_positive_ctxs.append(gold_passage)

    max_negatives = (
        min(max(10 * len(selected_positive_ctxs), max1_negatives), max2_negatives)
        if is_train_set
        else DEFAULT_EVAL_PASSAGES
    )
    negative_samples = negative_samples[0:max_negatives]
    return selected_positive_ctxs, negative_samples
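# Worked example of the negatives budget computed at the end of
# `_select_reader_passages`: with 2 selected positives, max1_negatives=30 and
# max2_negatives=50, a train sample keeps min(max(10 * 2, 30), 50) == 30
# negatives; with 5 positives it would keep min(max(50, 30), 50) == 50.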
def preprocess_retriever_data(
    samples: List[Dict],
    gold_info_file: Optional[str],
    tensorizer: Tensorizer,
    cfg: ReaderPreprocessingCfg = DEFAULT_PREPROCESSING_CFG_TRAIN,
    is_train_set: bool = True,
) -> Iterable[ReaderSample]:
    """
    Converts retriever results into reader training data.
    :param samples: samples from the retriever's json file results
    :param gold_info_file: optional path for the 'gold passages & questions' file. Required to get best results for NQ
    :param tensorizer: Tensorizer object for text to model input tensors conversions
    :param cfg: ReaderPreprocessingCfg object with positive and negative passage selection parameters
    :param is_train_set: if the data should be processed as a train set
    :return: iterable of ReaderSample objects which can be consumed by the reader model
    """
    sep_tensor = tensorizer.get_pair_separator_ids()  # separator can be a multi token

    gold_passage_map, canonical_questions = (
        _get_gold_ctx_dict(gold_info_file) if gold_info_file else ({}, {})
    )

    no_positive_passages = 0
    positives_from_gold = 0

    def create_reader_sample_ids(sample: ReaderPassage, question: str):
        question_and_title = tensorizer.text_to_tensor(sample.title, title=question,
                                                       add_special_tokens=True)
        if sample.passage_token_ids is None:
            sample.passage_token_ids = tensorizer.text_to_tensor(
                sample.passage_text, add_special_tokens=False)

        all_concatenated, shift = _concat_pair(
            question_and_title,
            sample.passage_token_ids,
            tailing_sep=sep_tensor if cfg.use_tailing_sep else None,
        )

        sample.sequence_ids = all_concatenated
        sample.passage_offset = shift
        assert shift > 1
        if sample.has_answer and is_train_set:
            sample.answers_spans = [(span[0] + shift, span[1] + shift)
                                    for span in sample.answers_spans]
        return sample

    for sample in samples:
        question = sample["question"]
        question_txt = sample["query_text"] if "query_text" in sample else question

        if canonical_questions and question_txt in canonical_questions:
            question_txt = canonical_questions[question_txt]

        positive_passages, negative_passages = _select_reader_passages(
            sample,
            question_txt,
            tensorizer,
            gold_passage_map,
            cfg.gold_page_only_positives,
            cfg.max_positives,
            cfg.max_negatives,
            cfg.min_negatives,
            cfg.max_retriever_passages,
            cfg.include_gold_passage,
            is_train_set,
        )
        # create concatenated sequence ids for each passage and adjust answer spans
        positive_passages = [create_reader_sample_ids(s, question) for s in positive_passages]
        negative_passages = [create_reader_sample_ids(s, question) for s in negative_passages]

        if is_train_set and len(positive_passages) == 0:
            no_positive_passages += 1
            if cfg.skip_no_positves:
                continue

        if next(iter(ctx for ctx in positive_passages if ctx.score == -1), None):
            positives_from_gold += 1

        if is_train_set:
            yield ReaderSample(
                question,
                sample["answers"],
                positive_passages=positive_passages,
                negative_passages=negative_passages,
            )
        else:
            yield ReaderSample(question, sample["answers"], passages=negative_passages)

    logger.info("no positive passages samples: %d", no_positive_passages)
    logger.info("positive passages from gold samples: %d", positives_from_gold)
def _create_question_passages_tensors(
    wiki_data: TokenizedWikipediaPassages,
    question_token_ids: np.ndarray,
    tensorizer: Tensorizer,
    positives: List[ReaderPassage],
    negatives: List[ReaderPassage],
    total_size: int,
    empty_ids: T,
    max_n_answers: int,
    is_train: bool,
    is_random: bool = True,
):
    max_len = empty_ids.size(0)
    pad_token_id = tensorizer.get_pad_id()

    if is_train:
        # select just one positive
        positive_idx = _get_positive_idx(positives, max_len, is_random)
        if positive_idx is None:
            return None

        positive = positives[positive_idx]

        if getattr(positive, "sequence_ids", None) is None:
            # load in passage tokens and title tokens
            positive.load_tokens(
                question_token_ids=question_token_ids,
                **wiki_data.get_tokenized_data(int(positive.id))
            )
            sequence_ids, passage_offset = tensorizer.concatenate_inputs({
                "question": positive.question_token_ids,
                "passage_title": positive.title_token_ids,
                "passage": positive.passage_token_ids,
            }, get_passage_offset=True)

            positive.sequence_ids = sequence_ids
            positive.passage_offset = passage_offset
            positive.answers_spans = [
                (start + passage_offset, end + passage_offset)
                for start, end in positive.answers_spans
            ]

        positive_a_spans = _get_answer_spans(positive_idx, positives, max_len)[0:max_n_answers]

        answer_starts = [span[0] for span in positive_a_spans]
        answer_ends = [span[1] for span in positive_a_spans]

        assert all(s < max_len for s in answer_starts)
        assert all(e < max_len for e in answer_ends)

        positive_input_ids = tensorizer.to_max_length(positive.sequence_ids.numpy(),
                                                      apply_max_len=True)
        positive_input_ids = torch.from_numpy(positive_input_ids)

        answer_starts_tensor = torch.zeros((total_size, max_n_answers)).long()
        answer_starts_tensor[0, 0:len(answer_starts)] = torch.tensor(
            answer_starts)  # only first passage contains the answer

        answer_ends_tensor = torch.zeros((total_size, max_n_answers)).long()
        answer_ends_tensor[0, 0:len(answer_ends)] = torch.tensor(
            answer_ends)  # only first passage contains the answer

        answer_mask = torch.zeros((total_size, max_n_answers), dtype=torch.long)
        answer_mask[0, 0:len(answer_starts)] = torch.tensor(
            [1 for _ in range(len(answer_starts))])

        positives_IDs: List[int] = [positive.id]
        positives_selected = [positive_input_ids]

    else:
        positives_IDs: List[int] = []
        positives_selected = []
        answer_starts_tensor = None
        answer_ends_tensor = None
        answer_mask = None

    positives_num = len(positives_selected)
    negative_idxs = np.random.permutation(range(len(negatives))) if is_random else range(
        len(negatives) - positives_num)
    negative_idxs = negative_idxs[:total_size - positives_num]

    negatives_IDs: List[int] = []
    negatives_selected = []

    for negative_idx in negative_idxs:
        negative = negatives[negative_idx]

        if getattr(negative, "sequence_ids", None) is None:
            # load in passage tokens and title tokens
            negative.load_tokens(
                question_token_ids=question_token_ids,
                **wiki_data.get_tokenized_data(int(negative.id))
            )
            # concatenate input tokens
            sequence_ids, passage_offset = tensorizer.concatenate_inputs({
                "question": negative.question_token_ids,
                "passage_title": negative.title_token_ids,
                "passage": negative.passage_token_ids,
            }, get_passage_offset=True)

            negative.sequence_ids = sequence_ids
            negative.passage_offset = passage_offset

        negatives_IDs.append(negative.id)

        negative_input_ids = tensorizer.to_max_length(negative.sequence_ids.numpy(),
                                                      apply_max_len=True)
        negatives_selected.append(torch.from_numpy(negative_input_ids))

    while len(negatives_selected) < total_size - positives_num:
        negatives_IDs.append(-1)
        negatives_selected.append(empty_ids.clone())

    context_IDs = torch.tensor(positives_IDs + negatives_IDs, dtype=torch.int64)
    input_ids = torch.stack([t for t in positives_selected + negatives_selected], dim=0)
    assert len(context_IDs) == len(input_ids)

    return context_IDs, input_ids, answer_starts_tensor, answer_ends_tensor, answer_mask
def create_biencoder_input2(
    cls,
    samples: List[BiEncoderSample],
    tensorizer: Tensorizer,
    insert_title: bool,
    num_hard_negatives: int = 0,
    num_other_negatives: int = 0,
    shuffle: bool = True,
    shuffle_positives: bool = False,
    hard_neg_fallback: bool = True,
    query_token: str = None,
) -> BiEncoderBatch:
    """
    Creates a batch of the biencoder training tuple.
    :param samples: list of BiEncoderSample-s to create the batch for
    :param tensorizer: components to create model input tensors from a text sequence
    :param insert_title: enables title insertion at the beginning of the context sequences
    :param num_hard_negatives: amount of hard negatives per question (taken from samples' pools)
    :param num_other_negatives: amount of other negatives per question (taken from samples' pools)
    :param shuffle: shuffles negative passages pools
    :param shuffle_positives: shuffles positive passages pools
    :return: BiEncoderBatch tuple
    """
    question_tensors = []
    ctx_tensors = []
    positive_ctx_indices = []
    hard_neg_ctx_indices = []

    for sample in samples:
        # ctx+ & [ctx-] composition
        # as of now, take the first (gold) ctx+ only
        if shuffle and shuffle_positives:
            positive_ctxs = sample.positive_passages
            positive_ctx = positive_ctxs[np.random.choice(len(positive_ctxs))]
        else:
            positive_ctx = sample.positive_passages[0]

        neg_ctxs = sample.negative_passages
        hard_neg_ctxs = sample.hard_negative_passages
        question = sample.query
        # question = normalize_question(sample.query)

        if shuffle:
            random.shuffle(neg_ctxs)
            random.shuffle(hard_neg_ctxs)

        if hard_neg_fallback and len(hard_neg_ctxs) == 0:
            hard_neg_ctxs = neg_ctxs[0:num_hard_negatives]

        neg_ctxs = neg_ctxs[0:num_other_negatives]
        hard_neg_ctxs = hard_neg_ctxs[0:num_hard_negatives]

        all_ctxs = [positive_ctx] + neg_ctxs + hard_neg_ctxs
        hard_negatives_start_idx = 1
        hard_negatives_end_idx = 1 + len(hard_neg_ctxs)

        current_ctxs_len = len(ctx_tensors)

        sample_ctxs_tensors = [
            tensorizer.text_to_tensor(
                ctx.text, title=ctx.title if (insert_title and ctx.title) else None)
            for ctx in all_ctxs
        ]

        ctx_tensors.extend(sample_ctxs_tensors)
        positive_ctx_indices.append(current_ctxs_len)
        hard_neg_ctx_indices.append([
            i
            for i in range(
                current_ctxs_len + hard_negatives_start_idx,
                current_ctxs_len + hard_negatives_end_idx,
            )
        ])

        if query_token:
            # TODO: tmp workaround for EL, remove or revise
            if query_token == "[START_ENT]":
                query_span = _select_span_with_token(question, tensorizer,
                                                     token_str=query_token)
                question_tensors.append(query_span)
            else:
                question_tensors.append(
                    tensorizer.text_to_tensor(" ".join([query_token, question])))
        else:
            question_tensors.append(tensorizer.text_to_tensor(question))

    ctxs_tensor = torch.cat([ctx.view(1, -1) for ctx in ctx_tensors], dim=0)
    questions_tensor = torch.cat([q.view(1, -1) for q in question_tensors], dim=0)

    ctx_segments = torch.zeros_like(ctxs_tensor)
    question_segments = torch.zeros_like(questions_tensor)

    return BiEncoderBatch(
        questions_tensor,
        question_segments,
        ctxs_tensor,
        ctx_segments,
        positive_ctx_indices,
        hard_neg_ctx_indices,
        "question",
    )
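# Index layout recorded by `create_biencoder_input2`, e.g. for 2 samples with
# num_other_negatives=0 and num_hard_negatives=2. Note the hard-negative start
# index of 1 only matches the actual hard-negative positions when other
# negatives are absent, which the `create_biencoder_input` variant further
# below also points out:
#   ctx_tensors:          [pos0, hn0a, hn0b, pos1, hn1a, hn1b]
#   positive_ctx_indices: [0, 3]
#   hard_neg_ctx_indices: [[1, 2], [4, 5]]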
def _select_passages(
    wiki_data: TokenizedWikipediaPassages,
    sample: Dict,
    bm25_sample: Tuple[Tuple[int, float]],
    question: str,
    processed_question: str,
    question_token_ids: np.ndarray,
    answers: List[str],
    expanded_answers: List[List[str]],
    all_answers: List[str],
    tensorizer: Tensorizer,
    gold_passage_map: Dict[str, DataPassage],
    processed_gold_passage_map: Dict[str, DataPassage],
    cfg: PreprocessingCfg,
    is_train_set: bool,
    check_pre_tokenized_data: bool,
) -> Tuple[List[DataPassage], List[DataPassage], List[DataPassage], List[DataPassage],
           List[DataPassage], List[DataPassage], List[DataPassage]]:
    """
    Select and process valid passages for training/evaluation.
    """
    # Tokenize answers
    answers_token_ids: List[np.ndarray] = [
        tensorizer.text_to_tensor(a, add_special_tokens=False).numpy()
        for a in all_answers
    ]

    # Gold context; we want to cover more gold passages, that's why we are matching both
    # `processed_question` and `question` (canonical question).
    if question in processed_gold_passage_map or processed_question in processed_gold_passage_map:
        if question in processed_gold_passage_map:
            gold_ctx = processed_gold_passage_map[question]
        else:
            gold_ctx = processed_gold_passage_map[processed_question]

        gold_ctx = _load_tokens_into_ctx(
            gold_ctx,
            question_token_ids,
            wiki_data,
            tensorizer,
            check_pre_tokenized_data,
        )  # load question, passage title and passage tokens into the context object

        gold_ctx = _find_answer_spans(
            tensorizer,
            gold_ctx,
            question,
            all_answers,
            answers_token_ids,
            warn_if_no_answer=True,
            raise_if_no_answer=False,
            warn_if_has_answer=False,
            raise_if_has_answer=False,
            recheck_negatives=False,
        )  # find answer spans for all passages

        if gold_ctx.has_answer:
            gold_ctxs = [gold_ctx]
        else:
            gold_ctxs = []
    else:
        gold_ctxs = []

    # Densely retrieved contexts
    ctxs = [DataPassage(is_from_bm25=False, **ctx) for ctx in sample["ctxs"]]
    ctxs = [
        _load_tokens_into_ctx(ctx, question_token_ids, wiki_data, tensorizer,
                              check_pre_tokenized_data)
        for ctx in ctxs
    ]  # load question, passage title and passage tokens into the context object

    # Find answer spans for all passages
    ctxs: List[DataPassage] = [
        _find_answer_spans(
            tensorizer,
            ctx,
            question,
            all_answers,
            answers_token_ids,
            warn_if_no_answer=ctx.has_answer,  # warn if originally it contains answer string
            warn_if_has_answer=(not ctx.has_answer),  # warn if originally it does NOT contain answer string
            recheck_negatives=cfg.recheck_negatives,
        )
        for ctx in ctxs
    ]

    # Sparsely retrieved contexts (BM25)
    bm25_ctxs = [
        DataPassage(id=passage_id, score=score, is_from_bm25=True)
        for passage_id, score in bm25_sample
    ]
    bm25_ctxs = [
        _load_tokens_into_ctx(ctx, question_token_ids, wiki_data, tensorizer,
                              check_pre_tokenized_data)
        for ctx in bm25_ctxs
    ]  # load question, passage title and passage tokens into the context object

    # Find answer spans for all passages
    bm25_ctxs: List[DataPassage] = [
        _find_answer_spans(
            tensorizer,
            ctx,
            question,
            all_answers,
            answers_token_ids,
            warn_if_no_answer=False,
            warn_if_has_answer=False,
            recheck_negatives=True,  # `has_answer` of any BM25 passage is None
        )
        for ctx in bm25_ctxs
    ]

    # Filter positives and negatives using distant supervision
    positive_samples = list(filter(lambda ctx: ctx.has_answer, ctxs))
    distantly_positive_samples: List[DataPassage] = []
    negative_samples = list(filter(lambda ctx: not ctx.has_answer, ctxs))
    bm25_positive_samples = list(filter(lambda ctx: ctx.has_answer, bm25_ctxs))
    bm25_distantly_positive_samples: List[DataPassage] = []
    bm25_negative_samples = list(filter(lambda ctx: not ctx.has_answer, bm25_ctxs))

    # Filter unwanted positive passages if training
    if is_train_set:
        # Get positives that are from gold positive passages
        if cfg.gold_page_only_positives:
            selected_positive_ctxs: List[DataPassage] = []
            selected_negative_ctxs: List[DataPassage] = negative_samples
            selected_bm25_positive_ctxs: List[DataPassage] = []
            selected_bm25_negative_ctxs: List[DataPassage] = bm25_negative_samples

            for positives, selected_positives, selected_negatives, distantly_positives in [
                (positive_samples, selected_positive_ctxs,
                 selected_negative_ctxs, distantly_positive_samples),
                (bm25_positive_samples, selected_bm25_positive_ctxs,
                 selected_bm25_negative_ctxs, bm25_distantly_positive_samples),
            ]:
                for ctx in positives:
                    is_from_gold = _is_from_gold_wiki_page(
                        gold_passage_map,
                        ctx,
                        tensorizer.tensor_to_text(torch.from_numpy(ctx.title_token_ids)),
                        question,
                    )
                    if is_from_gold:
                        selected_positives.append(ctx)
                    else:  # if it has answer but does not come from gold passage
                        if cfg.should_negatives_contain_answer:
                            selected_negatives.append(ctx)
                        else:
                            distantly_positives.append(ctx)
        else:
            selected_positive_ctxs = positive_samples
            selected_negative_ctxs = negative_samples
            selected_bm25_positive_ctxs = bm25_positive_samples
            selected_bm25_negative_ctxs = bm25_negative_samples

        # Fallback to positive ctx not from gold passages
        if len(selected_positive_ctxs) == 0:
            selected_positive_ctxs = positive_samples
        if len(selected_bm25_positive_ctxs) == 0:
            selected_bm25_positive_ctxs = bm25_positive_samples

        # Optionally include gold passage itself if it is still not in the positives list
        if cfg.include_gold_passage:
            if question in gold_passage_map:
                gold_passage = gold_passage_map[question]
                gold_passage.is_gold = True
                gold_passage.has_answer = True  # assuming it has answer

                gold_passage = _find_answer_spans(
                    tensorizer,
                    gold_passage,
                    question,
                    all_answers,
                    answers_token_ids,
                    warn_if_no_answer=False,
                    warn_if_has_answer=False,
                    recheck_negatives=True,
                )  # warn below

                if not gold_passage.has_answer:
                    logger.warning(
                        "No answer found in GOLD passage: passage='%s', question='%s', "
                        "answers=%s, expanded_answers=%s",
                        gold_passage.passage_text,
                        question,
                        answers,
                        expanded_answers,
                    )
                selected_positive_ctxs.append(
                    gold_passage
                )  # append anyway, since we need this for retriever (not reader)
            else:
                logger.warning(f"Question '{question}' has no gold positive")

    else:
        # NOTE: See `create_reader_input` function in `reader.py` to see how
        # positive and negative samples are merged (keeping their original order)
        selected_positive_ctxs = positive_samples
        selected_negative_ctxs = negative_samples
        selected_bm25_positive_ctxs = bm25_positive_samples
        selected_bm25_negative_ctxs = bm25_negative_samples

    # Restrict number of BM25 passages
    selected_bm25_positive_ctxs = selected_bm25_positive_ctxs[:cfg.max_bm25_positives]
    selected_bm25_negative_ctxs = selected_bm25_negative_ctxs[:cfg.max_bm25_negatives]

    return (
        gold_ctxs,
        selected_positive_ctxs,
        selected_negative_ctxs,
        distantly_positive_samples,
        selected_bm25_positive_ctxs,
        selected_bm25_negative_ctxs,
        bm25_distantly_positive_samples,
    )
def create_biencoder_input(
    cls,
    samples: List,
    tensorizer: Tensorizer,
    insert_title: bool,
    num_hard_negatives: int = 0,
    num_other_negatives: int = 0,
    shuffle: bool = True,
    shuffle_positives: bool = False,
    max_retries: int = 100,
) -> BiEncoderBatch:
    """
    Creates a batch of the biencoder training tuple.
    :param samples: list of data items (from json) to create the batch for
    :param tensorizer: components to create model input tensors from a text sequence
    :param insert_title: enables title insertion at the beginning of the context sequences
    :param num_hard_negatives: amount of hard negatives per question (taken from samples' pools)
    :param num_other_negatives: amount of other negatives per question (taken from samples' pools)
    :param shuffle: shuffles negative passages pools
    :param shuffle_positives: shuffles positive passages pools
    :param max_retries: max retry count to find a unique positive context
    :return: BiEncoderBatch tuple
    """
    question_tensors = []
    ctx_tensors = []
    positive_ctx_indices = []
    hard_neg_ctx_indices = []
    used_ctxs = set()

    for sample in samples:
        # ctx+ & [ctx-] composition
        # as of now, take the first (gold) ctx+ only
        if shuffle and shuffle_positives:
            positive_ctxs = sample["positive_ctxs"]
            positive_ctx = positive_ctxs[np.random.choice(len(positive_ctxs))]
            retry_counter = 0
            while positive_ctx['text'] in used_ctxs and retry_counter < max_retries:
                positive_ctx = positive_ctxs[np.random.choice(len(positive_ctxs))]
                retry_counter += 1
            used_ctxs.add(positive_ctx['text'])
        else:
            positive_ctx = sample["positive_ctxs"][0]

        # TODO: probably add negative_ctxs validation
        neg_ctxs = sample["negative_ctxs"]
        hard_neg_ctxs = sample["hard_negative_ctxs"]
        question = normalize_question(sample["question"])

        if shuffle:
            random.shuffle(neg_ctxs)
            random.shuffle(hard_neg_ctxs)

        neg_ctxs = neg_ctxs[0:num_other_negatives]
        hard_neg_ctxs = hard_neg_ctxs[0:num_hard_negatives]

        all_ctxs = [positive_ctx] + neg_ctxs + hard_neg_ctxs
        # hard negatives start after the positive and the other negatives
        # (originally this was 1, which appears to be wrong)
        hard_negatives_start_idx = len(neg_ctxs) + 1
        hard_negatives_end_idx = len(neg_ctxs) + 1 + len(hard_neg_ctxs)

        current_ctxs_len = len(ctx_tensors)

        sample_ctxs_tensors = [
            tensorizer.text_to_tensor(
                ctx["text"], title=ctx["title"] if insert_title else None)
            for ctx in all_ctxs
        ]

        ctx_tensors.extend(sample_ctxs_tensors)
        positive_ctx_indices.append(current_ctxs_len)
        hard_neg_ctx_indices.append([
            i
            for i in range(
                current_ctxs_len + hard_negatives_start_idx,
                current_ctxs_len + hard_negatives_end_idx,
            )
        ])

        question_tensors.append(tensorizer.text_to_tensor(question))

    ctxs_tensor = torch.cat([ctx.view(1, -1) for ctx in ctx_tensors], dim=0)
    questions_tensor = torch.cat([q.view(1, -1) for q in question_tensors], dim=0)

    ctx_segments = torch.zeros_like(ctxs_tensor)
    question_segments = torch.zeros_like(questions_tensor)

    return BiEncoderBatch(
        questions_tensor,
        question_segments,
        ctxs_tensor,
        ctx_segments,
        positive_ctx_indices,
        hard_neg_ctx_indices,
    )
def create_graded_biencoder_input2(
    cls,
    samples: List[GradedBiEncoderSample],
    tensorizer: Tensorizer,
    insert_title: bool,
    num_hard_negatives: int = 0,
    num_other_negatives: int = 0,
    num_related: int = 0,
    num_highly_related: int = 0,
    shuffle: bool = True,
    shuffle_positives: bool = False,
    hard_neg_fallback: bool = True,
    query_token: str = None,
    relation_grades: list = [1.0, 1.0, 1.0, 0.0, 0.0],
) -> GradedBiEncoderBatch:
    """
    Creates a batch of the biencoder training tuple.
    :param samples: list of GradedBiEncoderSample-s to create the batch for
    :param tensorizer: components to create model input tensors from a text sequence
    :param insert_title: enables title insertion at the beginning of the context sequences
    :param num_hard_negatives: amount of hard negatives per question (taken from samples' pools)
    :param num_other_negatives: amount of other negatives per question (taken from samples' pools)
    :param shuffle: shuffles negative passages pools
    :param shuffle_positives: shuffles positive passages pools
    :return: BiEncoderBatch tuple
    """
    question_tensors = []
    ctx_tensors = []
    positive_ctx_indices = []
    hard_neg_ctx_indices = []
    negatives_ctx_indices = []
    related_ctx_indices = []
    highly_related_ctx_indices = []
    relations = []

    for sample in samples:
        # ctx+ & [ctx-] composition
        # as of now, take the first (gold) ctx+ only
        if shuffle and shuffle_positives:
            positive_ctxs = sample.positive_passages
            positive_ctx = positive_ctxs[np.random.choice(len(positive_ctxs))]
        else:
            positive_ctx = sample.positive_passages[0]

        neg_ctxs = sample.negative_passages
        hard_neg_ctxs = sample.hard_negative_passages
        related_ctxs = sample.related_passage
        highly_related_ctxs = sample.highly_related_passage
        question = sample.query
        # question = normalize_question(sample.query)

        if shuffle:
            random.shuffle(neg_ctxs)
            random.shuffle(hard_neg_ctxs)
            random.shuffle(related_ctxs)
            random.shuffle(highly_related_ctxs)

        if hard_neg_fallback and len(hard_neg_ctxs) == 0:
            hard_neg_ctxs = neg_ctxs[0:num_hard_negatives]

        neg_ctxs = neg_ctxs[0:num_other_negatives]
        hard_neg_ctxs = hard_neg_ctxs[0:num_hard_negatives]
        related_ctxs = related_ctxs[0:num_related]
        highly_related_ctxs = highly_related_ctxs[0:num_highly_related]

        all_ctxs = [positive_ctx] + neg_ctxs + hard_neg_ctxs + related_ctxs + highly_related_ctxs

        # relations
        rel_positive, rel_highly_related, rel_related, rel_negative, rel_hard_negative = relation_grades
        question_relations = []
        if relations != []:
            # pre-padding with negatives
            question_relations = [rel_negative] * len(relations[-1])
        question_relations.extend([rel_positive])
        question_relations.extend([rel_negative] * len(neg_ctxs))
        question_relations.extend([rel_hard_negative] * len(hard_neg_ctxs))
        question_relations.extend([rel_related] * len(related_ctxs))
        question_relations.extend([rel_highly_related] * len(highly_related_ctxs))
        relations.append(question_relations)

        # post-padding with negatives
        for relation in relations:
            if len(relation) < len(relations[-1]):
                num_negatives_to_post_pad = len(relations[-1]) - len(relation)
                relation.extend([rel_negative] * num_negatives_to_post_pad)

        # calculate all positions as absolute indices into the flattened ctx_tensors list
        current_ctxs_len = len(ctx_tensors)
        positive_ctx_indices.append(current_ctxs_len)

        negatives_start_idx = current_ctxs_len + 1
        negatives_end_idx = negatives_start_idx + len(neg_ctxs)
        negatives_ctx_indices.append(list(range(negatives_start_idx, negatives_end_idx)))

        hard_negatives_start_idx = negatives_end_idx
        hard_negatives_end_idx = hard_negatives_start_idx + len(hard_neg_ctxs)
        hard_neg_ctx_indices.append(
            list(range(hard_negatives_start_idx, hard_negatives_end_idx)))

        related_start_idx = hard_negatives_end_idx
        related_end_idx = related_start_idx + len(related_ctxs)
        related_ctx_indices.append(list(range(related_start_idx, related_end_idx)))

        highly_related_start_idx = related_end_idx
        highly_related_end_idx = highly_related_start_idx + len(highly_related_ctxs)
        highly_related_ctx_indices.append(
            list(range(highly_related_start_idx, highly_related_end_idx)))

        # add all ctxs to ctx_tensors
        sample_ctxs_tensors = [
            tensorizer.text_to_tensor(
                ctx.text, title=ctx.title if (insert_title and ctx.title) else None)
            for ctx in all_ctxs
        ]
        ctx_tensors.extend(sample_ctxs_tensors)

        if query_token:
            # TODO: tmp workaround for EL, remove or revise
            if query_token == "[START_ENT]":
                query_span = _select_span_with_token(question, tensorizer,
                                                     token_str=query_token)
                question_tensors.append(query_span)
            else:
                question_tensors.append(
                    tensorizer.text_to_tensor(" ".join([query_token, question])))
        else:
            question_tensors.append(tensorizer.text_to_tensor(question))

    ctxs_tensor = torch.cat([ctx.view(1, -1) for ctx in ctx_tensors], dim=0)
    questions_tensor = torch.cat([q.view(1, -1) for q in question_tensors], dim=0)

    ctx_segments = torch.zeros_like(ctxs_tensor)
    question_segments = torch.zeros_like(questions_tensor)

    return GradedBiEncoderBatch(
        questions_tensor,
        question_segments,
        ctxs_tensor,
        ctx_segments,
        positive_ctx_indices,
        hard_neg_ctx_indices,
        negatives_ctx_indices,
        related_ctx_indices,
        highly_related_ctx_indices,
        relations,
        "question",
    )
def create_biencoder_input_tokenized(
    cls,
    samples: List[BiEncoderSampleTokenized],
    tensorizer: Tensorizer,
    insert_title: bool,
    num_hard_negatives: int = 0,
    num_bm25_negatives: int = 0,
    shuffle: bool = True,
    shuffle_positives: bool = False,
    hard_neg_fallback: bool = True,
    query_token: str = None,
) -> BiEncoderBatch:
    """
    Creates a batch of the biencoder training tuple using tokenized data.
    :param samples: list of BiEncoderSampleTokenized-s to create the batch for
    :param tensorizer: components to create model input tensors from a text sequence
    :param insert_title: enables title insertion at the beginning of the context sequences
    :param num_hard_negatives: amount of hard negatives (densely retrieved) per question
    :param num_bm25_negatives: amount of BM25 negatives (sparsely retrieved) per question
    :param shuffle: shuffles negative passages pools
    :param shuffle_positives: shuffles positive passages pools. This is only effective for samples whose gold
        passage is available. In that case, the positive chosen is not necessarily the gold passage. Otherwise,
        the positive passages will be shuffled regardless of this parameter.
    :return: BiEncoderBatch tuple
    """
    question_tensors: List[T] = []
    ctx_ids: List[int] = []  # passage IDs
    ctx_tensors: List[T] = []
    positive_ctx_indices = []
    hard_neg_ctx_indices = []

    # Strict settings
    assert insert_title is True  # for now only allow `insert_title` to be True
    assert query_token is None

    for sample in samples:
        # Skip samples without positive passages (either gold or distant positives)
        if len(sample.positive_passages) == 0:
            continue
        # ctx+ & [ctx-] composition
        # as of now, take the first (gold) ctx+ only
        if (shuffle and shuffle_positives) or (not sample.positive_passages[0].is_gold):
            positive_ctxs = sample.positive_passages
            positive_ctx = positive_ctxs[np.random.choice(len(positive_ctxs))]
        else:
            positive_ctx = sample.positive_passages[0]

        bm25_neg_ctxs = sample.bm25_negative_passages
        hard_neg_ctxs = sample.hard_negative_passages
        question_ids = sample.query_ids

        if shuffle:
            random.shuffle(bm25_neg_ctxs)
            random.shuffle(hard_neg_ctxs)

        if hard_neg_fallback and len(hard_neg_ctxs) == 0:
            hard_neg_ctxs = bm25_neg_ctxs[0:num_hard_negatives]

        bm25_neg_ctxs = bm25_neg_ctxs[0:num_bm25_negatives]
        hard_neg_ctxs = hard_neg_ctxs[0:num_hard_negatives]

        all_ctxs = [positive_ctx] + bm25_neg_ctxs + hard_neg_ctxs
        hard_negatives_start_idx = 1 + len(bm25_neg_ctxs)
        hard_negatives_end_idx = len(all_ctxs)

        current_ctxs_len = len(ctx_tensors)

        # Context IDs
        ctx_id = [ctx.id for ctx in all_ctxs]
        ctx_ids.extend(ctx_id)

        # Context tensors
        sample_ctxs_tensors = [
            tensorizer.concatenate_inputs(
                ids={"passage_title": list(ctx.title_ids), "passage": list(ctx.text_ids)},
                get_passage_offset=False,
                to_max_length=True,
            )
            for ctx in all_ctxs
        ]
        ctx_tensors.extend(sample_ctxs_tensors)

        positive_ctx_indices.append(current_ctxs_len)
        hard_neg_ctx_indices.append([
            i
            for i in range(
                current_ctxs_len + hard_negatives_start_idx,
                current_ctxs_len + hard_negatives_end_idx,
            )
        ])

        question_tensors.append(
            tensorizer.concatenate_inputs(ids={"question": question_ids},
                                          get_passage_offset=False,
                                          to_max_length=True))

    ctx_ids = torch.tensor(ctx_ids, dtype=torch.int64)
    ctxs_tensor = torch.cat([ctx.view(1, -1) for ctx in ctx_tensors], dim=0)
    questions_tensor = torch.cat([q.view(1, -1) for q in question_tensors], dim=0)

    ctx_segments = torch.zeros_like(ctxs_tensor)
    question_segments = torch.zeros_like(questions_tensor)

    return BiEncoderBatch(
        questions_tensor,
        question_segments,
        ctx_ids,
        ctxs_tensor,
        ctx_segments,
        positive_ctx_indices,
        hard_neg_ctx_indices,
        "question",
    )
def _do_biencoder_fwd_pass(
    model: nn.Module,
    input: BiEncoderBatch,
    tensorizer: Tensorizer,
    loss_function,
    cfg,
    encoder_type: str,
    rep_positions_q=0,
    rep_positions_c=0,
    loss_scale: float = None,
    clustering: bool = False,
) -> Tuple[torch.Tensor, int]:
    input = BiEncoderBatch(**move_to_device(input._asdict(), cfg.device))

    q_attn_mask = tensorizer.get_attn_mask(input.question_ids)
    ctx_attn_mask = tensorizer.get_attn_mask(input.context_ids)

    if model.training:
        model_out = model(
            input.question_ids,
            input.question_segments,
            q_attn_mask,
            input.context_ids,
            input.ctx_segments,
            ctx_attn_mask,
            encoder_type=encoder_type,
            representation_token_pos_q=rep_positions_q,
            representation_token_pos_c=rep_positions_c,
        )
    else:
        with torch.no_grad():
            model_out = model(
                input.question_ids,
                input.question_segments,
                q_attn_mask,
                input.context_ids,
                input.ctx_segments,
                ctx_attn_mask,
                encoder_type=encoder_type,
                representation_token_pos_q=rep_positions_q,
                representation_token_pos_c=rep_positions_c,
            )

    local_q_vector, local_ctx_vectors = model_out

    if cfg.others.is_matching:  # MatchBiEncoder model
        loss, ml_is_correct, matching_is_correct = _calc_loss_matching(
            cfg,
            model,
            loss_function,
            local_q_vector,
            local_ctx_vectors,
            input.is_positive,
            input.hard_negatives,
            loss_scale=loss_scale,
        )
        ml_is_correct = ml_is_correct.sum().item()
        matching_is_correct = matching_is_correct.sum().item()
    else:
        loss, is_correct = calc_loss(
            cfg,
            loss_function,
            local_q_vector,
            local_ctx_vectors,
            input.is_positive,
            input.hard_negatives,
            loss_scale=loss_scale,
        )
        is_correct = is_correct.sum().item()

    if cfg.n_gpu > 1:
        loss = loss.mean()
    if cfg.train.gradient_accumulation_steps > 1:
        # divide by the same config value that is checked above
        loss = loss / cfg.train.gradient_accumulation_steps

    if clustering:
        assert not cfg.others.is_matching
        return loss, is_correct, model_out
    elif cfg.others.is_matching:
        return loss, ml_is_correct, matching_is_correct
    else:
        return loss, is_correct
def gen_ctx_vectors(
    cfg: DictConfig,
    ctx_rows: List[Tuple[object, BiEncoderPassage]],
    q_rows: List[object],
    model: nn.Module,
    tensorizer: Tensorizer,
    insert_title: bool = True,
) -> List[Tuple[object, np.array]]:
    # Selector for the representation token position (CLS by default)
    from dpr.data.biencoder_data import DEFAULT_SELECTOR
    selector = DEFAULT_SELECTOR

    n = len(ctx_rows)
    bsz = cfg.batch_size
    total = 0
    results = []
    for j, batch_start in enumerate(range(0, n, bsz)):
        # Passage preprocess
        # TODO; max seq length check
        batch = ctx_rows[batch_start:batch_start + bsz]
        batch_token_tensors = [
            tensorizer.text_to_tensor(ctx[1].text,
                                      title=ctx[1].title if insert_title else None)
            for ctx in batch
        ]
        ctx_ids_batch = move_to_device(torch.stack(batch_token_tensors, dim=0), cfg.device)
        ctx_seg_batch = move_to_device(torch.zeros_like(ctx_ids_batch), cfg.device)
        ctx_attn_mask = move_to_device(tensorizer.get_attn_mask(ctx_ids_batch), cfg.device)

        # Question preprocess
        q_batch = q_rows[batch_start:batch_start + bsz]
        q_batch_token_tensors = [tensorizer.text_to_tensor(qq) for qq in q_batch]
        q_ids_batch = move_to_device(torch.stack(q_batch_token_tensors, dim=0), cfg.device)
        q_seg_batch = move_to_device(torch.zeros_like(q_ids_batch), cfg.device)
        q_attn_mask = move_to_device(tensorizer.get_attn_mask(q_ids_batch), cfg.device)

        rep_positions = selector.get_positions(q_ids_batch, tensorizer)

        with torch.no_grad():
            q_dense, ctx_dense = model(
                q_ids_batch,
                q_seg_batch,
                q_attn_mask,
                ctx_ids_batch,
                ctx_seg_batch,
                ctx_attn_mask,
                representation_token_pos=rep_positions,
            )
        q_dense = q_dense.cpu()
        ctx_dense = ctx_dense.cpu()

        ctx_ids = [r[0] for r in batch]

        assert len(ctx_ids) == q_dense.size(0) == ctx_dense.size(0)
        total += len(ctx_ids)

        results.extend([
            (ctx_ids[i], q_dense[i].numpy(), ctx_dense[i].numpy(),
             q_dense[i].numpy().dot(ctx_dense[i].numpy()))
            for i in range(q_dense.size(0))
        ])

        if total % 10 == 0:
            logger.info("Encoded questions / passages %d", total)

    return results
def create_reader_input(
    wiki_data: TokenizedWikipediaPassages,
    tensorizer: Tensorizer,
    samples: List[ReaderSample],
    passages_per_question: int,
    max_length: int,
    max_n_answers: int,
    is_train: bool,
    shuffle: bool,
) -> ReaderBatch:
    """
    Creates a reader batch instance out of a list of ReaderSample-s. This is compatible with `GeneralDataset`.
    :param wiki_data: all tokenized wikipedia passages
    :param tensorizer: initialized tensorizer (which contains the tokenizer)
    :param samples: list of samples to create the batch for
    :param passages_per_question: amount of passages for every question in a batch
    :param max_length: max model input sequence length
    :param max_n_answers: max num of answers per single question
    :param is_train: if the samples are for a train set
    :param shuffle: should passages selection be randomized
    :return: ReaderBatch instance
    """
    context_IDs = []
    input_ids = []
    start_positions = []
    end_positions = []
    answers_masks = []

    empty_sequence = torch.Tensor().new_full((max_length,), tensorizer.get_pad_id(),
                                             dtype=torch.long)

    for sample in samples:
        if is_train:
            positive_ctxs = sample.positive_passages
            negative_ctxs = sample.negative_passages
        else:
            positive_ctxs = []
            negative_ctxs = sample.positive_passages + sample.negative_passages
            # Need to re-sort samples based on their scores
            negative_ctxs = sorted(negative_ctxs, key=lambda x: x.score, reverse=True)
        question_token_ids = sample.question_token_ids

        sample_tensors = _create_question_passages_tensors(
            wiki_data,
            question_token_ids,
            tensorizer,
            positive_ctxs,
            negative_ctxs,
            passages_per_question,
            empty_sequence,
            max_n_answers,
            is_train,
            is_random=shuffle,
        )

        if not sample_tensors:
            logger.warning('No valid passages combination for question=%s ', sample.question)
            continue

        context_ID, sample_input_ids, starts_tensor, ends_tensor, answer_mask = sample_tensors

        context_IDs.append(context_ID)
        input_ids.append(sample_input_ids)

        if is_train:
            start_positions.append(starts_tensor)
            end_positions.append(ends_tensor)
            answers_masks.append(answer_mask)

    context_IDs = torch.cat([IDs.unsqueeze(0) for IDs in context_IDs], dim=0)  # (N, M)
    input_ids = torch.cat([ids.unsqueeze(0) for ids in input_ids], dim=0)  # (N, M, L)

    if is_train:
        start_positions = torch.stack(start_positions, dim=0)
        end_positions = torch.stack(end_positions, dim=0)
        answers_masks = torch.stack(answers_masks, dim=0)

    return ReaderBatch(context_IDs, input_ids, start_positions, end_positions, answers_masks)
def generate_question_vectors(
    question_encoder: torch.nn.Module,
    tensorizer: Tensorizer,
    questions: List[str],
    bsz: int,
    query_token: str = None,
    selector: RepTokenSelector = None,
) -> T:
    n = len(questions)
    query_vectors = []

    with torch.no_grad():
        for j, batch_start in enumerate(range(0, n, bsz)):
            batch_questions = questions[batch_start:batch_start + bsz]

            if query_token:
                # TODO: tmp workaround for EL, remove or revise
                if query_token == "[START_ENT]":
                    batch_tensors = [
                        _select_span_with_token(q, tensorizer, token_str=query_token)
                        for q in batch_questions
                    ]
                else:
                    batch_tensors = [
                        tensorizer.text_to_tensor(" ".join([query_token, q]))
                        for q in batch_questions
                    ]
            elif isinstance(batch_questions[0], T):
                batch_tensors = [q for q in batch_questions]
            else:
                batch_tensors = [tensorizer.text_to_tensor(q) for q in batch_questions]

            # TODO: this only works for Wav2vec pipeline but will crash the regular text pipeline
            max_vector_len = max(q_t.size(1) for q_t in batch_tensors)
            min_vector_len = min(q_t.size(1) for q_t in batch_tensors)

            if max_vector_len != min_vector_len:
                # TODO: _pad_to_len move to utils
                from dpr.models.reader import _pad_to_len
                batch_tensors = [
                    _pad_to_len(q.squeeze(0), 0, max_vector_len) for q in batch_tensors
                ]

            q_ids_batch = torch.stack(batch_tensors, dim=0).cuda()
            q_seg_batch = torch.zeros_like(q_ids_batch).cuda()
            q_attn_mask = tensorizer.get_attn_mask(q_ids_batch)

            if selector:
                rep_positions = selector.get_positions(q_ids_batch, tensorizer)
                _, out, _ = BiEncoder.get_representation(
                    question_encoder,
                    q_ids_batch,
                    q_seg_batch,
                    q_attn_mask,
                    representation_token_pos=rep_positions,
                )
            else:
                _, out, _ = question_encoder(q_ids_batch, q_seg_batch, q_attn_mask)

            query_vectors.extend(out.cpu().split(1, dim=0))

            if len(query_vectors) % 100 == 0:
                logger.info("Encoded queries %d", len(query_vectors))

    query_tensor = torch.cat(query_vectors, dim=0)
    logger.info("Total encoded queries tensor %s", query_tensor.size())
    assert query_tensor.size(0) == len(questions)
    return query_tensor
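# Hedged usage sketch for `generate_question_vectors`; the encoder, tensorizer
# and checkpoint loading are assumed to be set up elsewhere as in the DPR
# inference scripts:
#
#   questions = ["who wrote the origin of species?", "what is dpr?"]
#   q_vectors = generate_question_vectors(
#       question_encoder,  # torch.nn.Module returning (sequence, pooled, hidden)
#       tensorizer,
#       questions,
#       bsz=32,
#   )
#   assert q_vectors.size(0) == len(questions)  # one dense vector per question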
def create_biencoder_input(
    cls,
    samples: List,
    tensorizer: Tensorizer,
    insert_title: bool,
    num_hard_negatives: int = 0,
    num_other_negatives: int = 0,
    shuffle: bool = True,
    shuffle_positives: bool = False,
    hard_neg_fallback: bool = True,
) -> BiEncoderBatch:
    """
    Creates a batch of the biencoder training tuple.
    :param samples: list of data items (from json) to create the batch for
    :param tensorizer: components to create model input tensors from a text sequence
    :param insert_title: enables title insertion at the beginning of the context sequences
    :param num_hard_negatives: amount of hard negatives per question (taken from samples' pools)
    :param num_other_negatives: amount of other negatives per question (taken from samples' pools)
    :param shuffle: shuffles negative passages pools
    :param shuffle_positives: shuffles positive passages pools
    :return: BiEncoderBatch tuple
    """
    question_tensors = []
    ctx_tensors = []
    positive_ctx_indices = []
    hard_neg_ctx_indices = []

    for sample in samples:
        # ctx+ & [ctx-] composition
        # as of now, take the first (gold) ctx+ only
        if shuffle and shuffle_positives:
            positive_ctxs = sample["positive_ctxs"]
            positive_ctx = positive_ctxs[np.random.choice(len(positive_ctxs))]
        else:
            positive_ctx = sample["positive_ctxs"][0]

        neg_ctxs = sample["negative_ctxs"]
        hard_neg_ctxs = sample["hard_negative_ctxs"]
        question = normalize_question(sample["question"])

        if shuffle:
            random.shuffle(neg_ctxs)
            random.shuffle(hard_neg_ctxs)

        if hard_neg_fallback and len(hard_neg_ctxs) == 0:
            hard_neg_ctxs = neg_ctxs[0:num_hard_negatives]

        neg_ctxs = neg_ctxs[0:num_other_negatives]
        hard_neg_ctxs = hard_neg_ctxs[0:num_hard_negatives]

        all_ctxs = [positive_ctx] + neg_ctxs + hard_neg_ctxs
        hard_negatives_start_idx = 1
        hard_negatives_end_idx = 1 + len(hard_neg_ctxs)

        current_ctxs_len = len(ctx_tensors)

        sample_ctxs_tensors = [
            tensorizer.text_to_tensor(
                ctx["text"],
                title=ctx["title"] if (insert_title and "title" in ctx) else None,
            )
            for ctx in all_ctxs
        ]

        ctx_tensors.extend(sample_ctxs_tensors)
        positive_ctx_indices.append(current_ctxs_len)
        hard_neg_ctx_indices.append([
            i
            for i in range(
                current_ctxs_len + hard_negatives_start_idx,
                current_ctxs_len + hard_negatives_end_idx,
            )
        ])

        question_tensors.append(tensorizer.text_to_tensor(question))

    ctxs_tensor = torch.cat([ctx.view(1, -1) for ctx in ctx_tensors], dim=0)
    questions_tensor = torch.cat([q.view(1, -1) for q in question_tensors], dim=0)

    ctx_segments = torch.zeros_like(ctxs_tensor)
    question_segments = torch.zeros_like(questions_tensor)

    return BiEncoderBatch(
        questions_tensor,
        question_segments,
        ctxs_tensor,
        ctx_segments,
        positive_ctx_indices,
        hard_neg_ctx_indices,
        "question",
    )
def _preprocess_retriever_data(
    samples: List[Dict],
    bm25_samples: List[Tuple[Tuple[int, float]]],
    wiki_data: TokenizedWikipediaPassages,
    gold_info_file: Optional[str],
    gold_info_processed_file: str,
    tensorizer: Tensorizer,
    cfg: PreprocessingCfg = DEFAULT_PREPROCESSING_CFG_TRAIN,
    is_train_set: bool = True,
    check_pre_tokenized_data: bool = True,
) -> Iterable[DataSample]:
    """
    Converts retriever results into general retriever/reader training data.
    :param samples: samples from the retriever's json file results
    :param bm25_samples: bm25 retrieval results; list of tuples of tuples of (passage_id, score), where passages
        of each sample are already sorted by their scores
    :param gold_info_file: optional path for the 'gold passages & questions' file. Required to get best results for NQ
    :param gold_info_processed_file: path to the preprocessed gold passages pickle file. Unlike `gold_passages_file`
        which contains original gold passages, this file should contain processed, matched, 100-word split passages
        that match with the original gold passages.
    :param tensorizer: Tensorizer object for text to model input tensors conversions
    :param cfg: PreprocessingCfg object with positive and negative passage selection parameters
    :param is_train_set: if the data should be processed as a train set
    :return: iterable of DataSample objects which can be consumed by the reader model
    """
    gold_passage_map, canonical_questions = (
        _get_gold_ctx_dict(gold_info_file) if gold_info_file is not None else ({}, {})
    )
    processed_gold_passage_map = (
        _get_processed_gold_ctx_dict(gold_info_processed_file)
        if gold_info_processed_file else {}
    )

    number_no_positive_samples = 0
    number_samples_from_gold = 0
    number_samples_with_gold = 0
    assert len(samples) == len(bm25_samples)

    for sample, bm25_sample in zip(samples, bm25_samples):
        # Refer to `_get_gold_ctx_dict` for why we need to distinguish between two types of questions.
        # Here `processed_question` refers to tokenized questions, while `question` refers to
        # canonical questions.
        processed_question = sample["question"]
        if processed_question in canonical_questions:
            question = canonical_questions[processed_question]
        else:
            question = processed_question

        question_token_ids: np.ndarray = tensorizer.text_to_tensor(
            normalize_question(question) if cfg.normalize_questions else question,
            add_special_tokens=False,
        ).numpy()

        orig_answers = sample["answers"]
        if cfg.expand_answers:
            expanded_answers = [get_expanded_answer(answer) for answer in orig_answers]
        else:
            expanded_answers = []
        all_answers = orig_answers + sum(expanded_answers, [])

        passages = _select_passages(
            wiki_data,
            sample,
            bm25_sample,
            question,
            processed_question,
            question_token_ids,
            orig_answers,
            expanded_answers,
            all_answers,
            tensorizer,
            gold_passage_map,
            processed_gold_passage_map,
            cfg,
            is_train_set,
            check_pre_tokenized_data,
        )
        gold_passages = passages[0]
        positive_passages, negative_passages, distantly_positive_passages = passages[1:4]
        bm25_positive_passages, bm25_negative_passages, bm25_distantly_positive_passages = passages[4:]

        if is_train_set and len(positive_passages) == 0:
            number_no_positive_samples += 1
            if cfg.skip_no_positives:
                continue

        if any(ctx for ctx in positive_passages if ctx.is_from_gold):
            number_samples_from_gold += 1

        if len(gold_passages) > 0:
            number_samples_with_gold += 1

        yield DataSample(
            question,
            question_token_ids=question_token_ids,
            answers=all_answers,
            orig_answers=orig_answers,
            expanded_answers=expanded_answers,
            # Gold
            gold_passages=gold_passages,
            # Dense
            positive_passages=positive_passages,
            distantly_positive_passages=distantly_positive_passages,
            negative_passages=negative_passages,
            # Sparse
            bm25_positive_passages=bm25_positive_passages,
            bm25_distantly_positive_passages=bm25_distantly_positive_passages,
            bm25_negative_passages=bm25_negative_passages,
        )

    logger.info(
        f"Number of samples whose at least one positive passage is "
        f"from the same article as the gold passage: {number_samples_from_gold}"
    )
    logger.info(f"Number of samples whose gold passage is available: {number_samples_with_gold}")
    logger.info(f"Number of samples with no positive passages: {number_no_positive_samples}")