import random
from typing import List

import numpy as np
import torch
from torch import Tensor as T

# NOTE: project-specific types used below (Tensorizer, ReaderBatch, BiEncoderBatch,
# BiEncoderSampleTokenized, ReaderPassage, TokenizedWikipediaPassages) and the
# helpers _get_positive_idx / _get_answer_spans are defined elsewhere in this repo.


def create_biencoder_input_from_reader_input(
    tensorizer: Tensorizer,
    reader_batch: ReaderBatch,
) -> BiEncoderBatch:
    input_ids = reader_batch.input_ids  # (N, M, L)
    question_ids: List[T] = []  # len N
    context_ids: List[T] = []  # len N * M

    for input_id_i in input_ids:
        for j, input_id in enumerate(input_id_i):
            ids = tensorizer.unconcatenate_inputs(
                input_id,
                components={"question", "passage_title", "passage"},
            )

            if ids is None:  # full padding
                context_ids.append(input_id)
                continue

            # Question
            question_id = tensorizer.concatenate_inputs(
                ids={"question": ids["question"].tolist()},
                get_passage_offset=False,
                to_max_length=True,
            )
            if j == 0:
                question_ids.append(question_id)
            else:
                assert (question_id == question_ids[-1]).all()

            # Passage
            passage_title = ids["passage_title"]
            passage = ids["passage"]
            context_ids.append(
                tensorizer.concatenate_inputs(
                    ids={
                        "passage_title": passage_title.tolist(),
                        "passage": passage.tolist(),
                    },
                    get_passage_offset=False,
                    to_max_length=True,
                )
            )

    question_ids = torch.stack(question_ids)  # (N, L)
    context_ids = torch.stack(context_ids)  # (N * M, L)

    question_segments = torch.zeros_like(question_ids)
    context_segments = torch.zeros_like(context_ids)

    biencoder_batch = BiEncoderBatch(
        question_ids=question_ids,
        question_segments=question_segments,
        context_IDs=None,  # not used
        context_ids=context_ids,
        ctx_segments=context_segments,
        is_positive=None,  # not used
        hard_negatives=None,  # not used
        encoder_type=None,  # not used
    )
    return biencoder_batch
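# Usage sketch (hedged, illustrative only): given a ReaderBatch whose input_ids
# have shape (N, M, L) -- N questions, M passages per question, each row a
# question + passage_title + passage concatenation -- the call
#
#     biencoder_batch = create_biencoder_input_from_reader_input(tensorizer, reader_batch)
#
# yields question_ids of shape (N, L) and context_ids of shape (N * M, L),
# re-padded to max length by the tensorizer; segment tensors are all zeros and
# the positive/hard-negative bookkeeping fields are left as None.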
# NOTE: defined as a classmethod on its owning class elsewhere in this repo (hence `cls`).
def create_biencoder_input_tokenized(
    cls,
    samples: List[BiEncoderSampleTokenized],
    tensorizer: Tensorizer,
    insert_title: bool,
    num_hard_negatives: int = 0,
    num_bm25_negatives: int = 0,
    shuffle: bool = True,
    shuffle_positives: bool = False,
    hard_neg_fallback: bool = True,
    query_token: str = None,
) -> BiEncoderBatch:
    """
    Creates a batch of the biencoder training tuple using tokenized data.
    :param samples: list of BiEncoderSampleTokenized-s to create the batch for
    :param tensorizer: component to create model input tensors from a text sequence
    :param insert_title: enables title insertion at the beginning of the context sequences
    :param num_hard_negatives: number of hard negatives (densely retrieved) per question
    :param num_bm25_negatives: number of BM25 negatives (sparsely retrieved) per question
    :param shuffle: shuffles negative passages pools
    :param shuffle_positives: shuffles positive passages pools. This is only effective for samples whose
        gold passage is available. In that case, the positive chosen is not necessarily the gold passage.
        Otherwise, the positive passages will be shuffled regardless of this parameter.
    :return: BiEncoderBatch tuple
    """
    question_tensors: List[T] = []
    ctx_ids: List[int] = []  # passage IDs
    ctx_tensors: List[T] = []
    positive_ctx_indices = []
    hard_neg_ctx_indices = []

    # Strict settings
    assert insert_title is True  # for now only allow `insert_title` to be True
    assert query_token is None

    for sample in samples:
        # Skip samples without positive passages (either gold or distant positives)
        if len(sample.positive_passages) == 0:
            continue

        # ctx+ & [ctx-] composition
        # as of now, take the first (gold) ctx+ only
        if (shuffle and shuffle_positives) or (not sample.positive_passages[0].is_gold):
            positive_ctxs = sample.positive_passages
            positive_ctx = positive_ctxs[np.random.choice(len(positive_ctxs))]
        else:
            positive_ctx = sample.positive_passages[0]

        bm25_neg_ctxs = sample.bm25_negative_passages
        hard_neg_ctxs = sample.hard_negative_passages
        question_ids = sample.query_ids

        if shuffle:
            random.shuffle(bm25_neg_ctxs)
            random.shuffle(hard_neg_ctxs)

        if hard_neg_fallback and len(hard_neg_ctxs) == 0:
            hard_neg_ctxs = bm25_neg_ctxs[0:num_hard_negatives]

        bm25_neg_ctxs = bm25_neg_ctxs[0:num_bm25_negatives]
        hard_neg_ctxs = hard_neg_ctxs[0:num_hard_negatives]

        all_ctxs = [positive_ctx] + bm25_neg_ctxs + hard_neg_ctxs
        hard_negatives_start_idx = 1 + len(bm25_neg_ctxs)
        hard_negatives_end_idx = len(all_ctxs)

        current_ctxs_len = len(ctx_tensors)

        # Context IDs
        ctx_id = [ctx.id for ctx in all_ctxs]
        ctx_ids.extend(ctx_id)

        # Context tensors
        sample_ctxs_tensors = [
            tensorizer.concatenate_inputs(
                ids={
                    "passage_title": list(ctx.title_ids),
                    "passage": list(ctx.text_ids),
                },
                get_passage_offset=False,
                to_max_length=True,
            )
            for ctx in all_ctxs
        ]
        ctx_tensors.extend(sample_ctxs_tensors)

        positive_ctx_indices.append(current_ctxs_len)
        hard_neg_ctx_indices.append(
            [
                i
                for i in range(
                    current_ctxs_len + hard_negatives_start_idx,
                    current_ctxs_len + hard_negatives_end_idx,
                )
            ]
        )

        question_tensors.append(
            tensorizer.concatenate_inputs(
                ids={"question": question_ids},
                get_passage_offset=False,
                to_max_length=True,
            )
        )

    ctx_ids = torch.tensor(ctx_ids, dtype=torch.int64)
    ctxs_tensor = torch.cat([ctx.view(1, -1) for ctx in ctx_tensors], dim=0)
    questions_tensor = torch.cat([q.view(1, -1) for q in question_tensors], dim=0)

    ctx_segments = torch.zeros_like(ctxs_tensor)
    question_segments = torch.zeros_like(questions_tensor)

    return BiEncoderBatch(
        questions_tensor,
        question_segments,
        ctx_ids,
        ctxs_tensor,
        ctx_segments,
        positive_ctx_indices,
        hard_neg_ctx_indices,
        "question",
    )
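# Usage sketch (hedged, illustrative only): with num_bm25_negatives=2 and
# num_hard_negatives=1, and assuming every sample supplies enough negatives,
# each sample contributes 1 + 2 + 1 = 4 rows to ctxs_tensor, so for the k-th
# sample:
#
#     positive_ctx_indices[k] == 4 * k
#     hard_neg_ctx_indices[k] == [4 * k + 3]
#
# Samples with fewer negatives simply contribute fewer rows, which is why both
# index lists are derived from len(ctx_tensors) rather than a fixed stride.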
def _create_question_passages_tensors(
    wiki_data: TokenizedWikipediaPassages,
    question_token_ids: np.ndarray,
    tensorizer: Tensorizer,
    positives: List[ReaderPassage],
    negatives: List[ReaderPassage],
    total_size: int,
    empty_ids: T,
    max_n_answers: int,
    is_train: bool,
    is_random: bool = True,
):
    max_len = empty_ids.size(0)
    pad_token_id = tensorizer.get_pad_id()

    if is_train:
        # select just one positive
        positive_idx = _get_positive_idx(positives, max_len, is_random)
        if positive_idx is None:
            return None

        positive = positives[positive_idx]
        if getattr(positive, "sequence_ids", None) is None:
            # Load in passage tokens and title tokens
            positive.load_tokens(
                question_token_ids=question_token_ids,
                **wiki_data.get_tokenized_data(int(positive.id)),
            )
            sequence_ids, passage_offset = tensorizer.concatenate_inputs(
                {
                    "question": positive.question_token_ids,
                    "passage_title": positive.title_token_ids,
                    "passage": positive.passage_token_ids,
                },
                get_passage_offset=True,
            )

            positive.sequence_ids = sequence_ids
            positive.passage_offset = passage_offset
            positive.answers_spans = [
                (start + passage_offset, end + passage_offset)
                for start, end in positive.answers_spans
            ]

        positive_a_spans = _get_answer_spans(positive_idx, positives, max_len)[0:max_n_answers]

        answer_starts = [span[0] for span in positive_a_spans]
        answer_ends = [span[1] for span in positive_a_spans]

        assert all(s < max_len for s in answer_starts)
        assert all(e < max_len for e in answer_ends)

        positive_input_ids = tensorizer.to_max_length(positive.sequence_ids.numpy(), apply_max_len=True)
        positive_input_ids = torch.from_numpy(positive_input_ids)

        answer_starts_tensor = torch.zeros((total_size, max_n_answers)).long()
        answer_starts_tensor[0, 0:len(answer_starts)] = torch.tensor(answer_starts)  # only first passage contains the answer

        answer_ends_tensor = torch.zeros((total_size, max_n_answers)).long()
        answer_ends_tensor[0, 0:len(answer_ends)] = torch.tensor(answer_ends)  # only first passage contains the answer

        answer_mask = torch.zeros((total_size, max_n_answers), dtype=torch.long)
        answer_mask[0, 0:len(answer_starts)] = torch.tensor([1 for _ in range(len(answer_starts))])

        positives_IDs: List[int] = [positive.id]
        positives_selected = [positive_input_ids]

    else:
        positives_IDs: List[int] = []
        positives_selected = []
        answer_starts_tensor = None
        answer_ends_tensor = None
        answer_mask = None

    positives_num = len(positives_selected)
    negative_idxs = np.random.permutation(range(len(negatives))) if is_random else range(
        len(negatives) - positives_num
    )
    negative_idxs = negative_idxs[:total_size - positives_num]

    negatives_IDs: List[int] = []
    negatives_selected = []

    for negative_idx in negative_idxs:
        negative = negatives[negative_idx]

        if getattr(negative, "sequence_ids", None) is None:
            # Load in passage tokens and title tokens
            negative.load_tokens(
                question_token_ids=question_token_ids,
                **wiki_data.get_tokenized_data(int(negative.id)),
            )
            # Concatenate input tokens
            sequence_ids, passage_offset = tensorizer.concatenate_inputs(
                {
                    "question": negative.question_token_ids,
                    "passage_title": negative.title_token_ids,
                    "passage": negative.passage_token_ids,
                },
                get_passage_offset=True,
            )

            negative.sequence_ids = sequence_ids
            negative.passage_offset = passage_offset

        negatives_IDs.append(negative.id)

        negative_input_ids = tensorizer.to_max_length(negative.sequence_ids.numpy(), apply_max_len=True)
        negatives_selected.append(torch.from_numpy(negative_input_ids))

    while len(negatives_selected) < total_size - positives_num:
        negatives_IDs.append(-1)
        negatives_selected.append(empty_ids.clone())

    context_IDs = torch.tensor(positives_IDs + negatives_IDs, dtype=torch.int64)
    input_ids = torch.stack([t for t in positives_selected + negatives_selected], dim=0)
    assert len(context_IDs) == len(input_ids)

    return context_IDs, input_ids, answer_starts_tensor, answer_ends_tensor, answer_mask
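# Return-value sketch (hedged): when is_train is True the tuple is
#
#     context_IDs        -> shape (total_size,); the positive passage ID first,
#                           then negative IDs, with -1 for padding rows
#     input_ids          -> shape (total_size, max_len); positive row first,
#                           then negatives, then empty_ids padding rows
#     answer_starts/ends -> shape (total_size, max_n_answers); only row 0
#                           (the positive passage) is populated
#     answer_mask        -> same shape; 1 where an answer span is valid
#
# When is_train is False the three answer tensors are returned as None and only
# negative passages are packed.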