def _select_span_with_token(
    text: str, tensorizer: Tensorizer, token_str: str = "[START_ENT]"
) -> T:
    """
    Tensorize `text` and make sure the `token_str` marker ends up inside the
    length-limited query tensor, re-windowing around the marker if necessary.
    """
    id = tensorizer.get_token_id(token_str)
    query_tensor = tensorizer.text_to_tensor(text)

    if id not in query_tensor:
        query_tensor_full = tensorizer.text_to_tensor(text, apply_max_len=False)
        token_indexes = (query_tensor_full == id).nonzero()
        if token_indexes.size(0) > 0:
            start_pos = token_indexes[0, 0].item()
            # add some randomization to avoid overfitting to a specific token position
            left_shift = int(tensorizer.max_length / 2)
            rnd_shift = int((rnd.random() - 0.5) * left_shift / 2)
            left_shift += rnd_shift

            query_tensor = query_tensor_full[start_pos - left_shift :]
            cls_id = tensorizer.tokenizer.cls_token_id
            if query_tensor[0] != cls_id:
                query_tensor = torch.cat([torch.tensor([cls_id]), query_tensor], dim=0)

            from dpr.models.reader import _pad_to_len

            query_tensor = _pad_to_len(
                query_tensor, tensorizer.get_pad_id(), tensorizer.max_length
            )
            query_tensor[-1] = tensorizer.tokenizer.sep_token_id
            # logger.info('aligned query_tensor %s', query_tensor)

            assert id in query_tensor, "query_tensor={}".format(query_tensor)
            return query_tensor
        else:
            raise RuntimeError(
                "[START_ENT] token not found for Entity Linking sample query={}".format(text)
            )
    else:
        return query_tensor
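
# Illustrative usage sketch (not part of the original module): how the span selection
# above might be exercised. `my_tensorizer` stands in for an already-initialized
# Tensorizer, and the query text is a hypothetical Entity Linking sample containing
# a [START_ENT] marker.
#
#   query = "The capital of [START_ENT] France [END_ENT] is Paris."
#   query_tensor = _select_span_with_token(query, my_tensorizer)
#   # The returned tensor has tensorizer.max_length tokens and is guaranteed to
#   # contain the [START_ENT] token id: when the marker would otherwise be truncated
#   # away, the full sequence is re-windowed around it with a random left shift.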
def _create_question_passages_tensors(
    wiki_data: TokenizedWikipediaPassages,
    question_token_ids: np.ndarray,
    tensorizer: Tensorizer,
    positives: List[ReaderPassage],
    negatives: List[ReaderPassage],
    total_size: int,
    empty_ids: T,
    max_n_answers: int,
    is_train: bool,
    is_random: bool = True,
):
    max_len = empty_ids.size(0)
    pad_token_id = tensorizer.get_pad_id()

    if is_train:
        # select just one positive
        positive_idx = _get_positive_idx(positives, max_len, is_random)
        if positive_idx is None:
            return None

        positive = positives[positive_idx]
        if getattr(positive, "sequence_ids", None) is None:
            # Load in passage tokens and title tokens
            positive.load_tokens(
                question_token_ids=question_token_ids,
                **wiki_data.get_tokenized_data(int(positive.id))
            )
            sequence_ids, passage_offset = tensorizer.concatenate_inputs(
                {
                    "question": positive.question_token_ids,
                    "passage_title": positive.title_token_ids,
                    "passage": positive.passage_token_ids,
                },
                get_passage_offset=True,
            )

            positive.sequence_ids = sequence_ids
            positive.passage_offset = passage_offset
            positive.answers_spans = [
                (start + passage_offset, end + passage_offset)
                for start, end in positive.answers_spans
            ]

        positive_a_spans = _get_answer_spans(positive_idx, positives, max_len)[0:max_n_answers]

        answer_starts = [span[0] for span in positive_a_spans]
        answer_ends = [span[1] for span in positive_a_spans]

        assert all(s < max_len for s in answer_starts)
        assert all(e < max_len for e in answer_ends)

        positive_input_ids = tensorizer.to_max_length(positive.sequence_ids.numpy(), apply_max_len=True)
        positive_input_ids = torch.from_numpy(positive_input_ids)

        answer_starts_tensor = torch.zeros((total_size, max_n_answers)).long()
        answer_starts_tensor[0, 0:len(answer_starts)] = torch.tensor(answer_starts)  # only the first passage contains the answer

        answer_ends_tensor = torch.zeros((total_size, max_n_answers)).long()
        answer_ends_tensor[0, 0:len(answer_ends)] = torch.tensor(answer_ends)  # only the first passage contains the answer

        answer_mask = torch.zeros((total_size, max_n_answers), dtype=torch.long)
        answer_mask[0, 0:len(answer_starts)] = torch.tensor([1 for _ in range(len(answer_starts))])

        positives_IDs: List[int] = [positive.id]
        positives_selected = [positive_input_ids]

    else:
        positives_IDs: List[int] = []
        positives_selected = []
        answer_starts_tensor = None
        answer_ends_tensor = None
        answer_mask = None

    positives_num = len(positives_selected)
    negative_idxs = np.random.permutation(range(len(negatives))) if is_random else range(
        len(negatives) - positives_num
    )
    negative_idxs = negative_idxs[: total_size - positives_num]

    negatives_IDs: List[int] = []
    negatives_selected = []

    for negative_idx in negative_idxs:
        negative = negatives[negative_idx]

        if getattr(negative, "sequence_ids", None) is None:
            # Load in passage tokens and title tokens
            negative.load_tokens(
                question_token_ids=question_token_ids,
                **wiki_data.get_tokenized_data(int(negative.id))
            )
            # Concatenate input tokens
            sequence_ids, passage_offset = tensorizer.concatenate_inputs(
                {
                    "question": negative.question_token_ids,
                    "passage_title": negative.title_token_ids,
                    "passage": negative.passage_token_ids,
                },
                get_passage_offset=True,
            )
            negative.sequence_ids = sequence_ids
            negative.passage_offset = passage_offset

        negatives_IDs.append(negative.id)

        negative_input_ids = tensorizer.to_max_length(negative.sequence_ids.numpy(), apply_max_len=True)
        negatives_selected.append(torch.from_numpy(negative_input_ids))

    while len(negatives_selected) < total_size - positives_num:
        negatives_IDs.append(-1)
        negatives_selected.append(empty_ids.clone())

    context_IDs = torch.tensor(positives_IDs + negatives_IDs, dtype=torch.int64)
    input_ids = torch.stack([t for t in positives_selected + negatives_selected], dim=0)
    assert len(context_IDs) == len(input_ids)

    return context_IDs, input_ids, answer_starts_tensor, answer_ends_tensor, answer_mask
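
# Illustrative sketch of the return values above (an interpretation, not from the
# original code), for a single question with total_size = M passages and sequences
# padded to empty_ids.size(0) = L tokens:
#
#   context_IDs          -> (M,)                passage ids, -1 for padding slots
#   input_ids            -> (M, L)              concatenated question + title + passage token ids
#   answer_starts_tensor -> (M, max_n_answers)  non-zero only in row 0; None when not training
#   answer_ends_tensor   -> (M, max_n_answers)  non-zero only in row 0; None when not training
#   answer_mask          -> (M, max_n_answers)  1 where a gold span exists; None when not training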
def create_reader_input(
    wiki_data: TokenizedWikipediaPassages,
    tensorizer: Tensorizer,
    samples: List[ReaderSample],
    passages_per_question: int,
    max_length: int,
    max_n_answers: int,
    is_train: bool,
    shuffle: bool,
) -> ReaderBatch:
    """
    Creates a reader batch instance out of a list of ReaderSample-s. This is compatible with `GeneralDataset`.
    :param wiki_data: all tokenized wikipedia passages
    :param tensorizer: initialized tensorizer (which contains the tokenizer)
    :param samples: list of samples to create the batch for
    :param passages_per_question: number of passages per question in a batch
    :param max_length: max model input sequence length
    :param max_n_answers: max number of answers per single question
    :param is_train: whether the samples are for a train set
    :param shuffle: whether passage selection should be randomized
    :return: ReaderBatch instance
    """
    context_IDs = []
    input_ids = []
    start_positions = []
    end_positions = []
    answers_masks = []

    empty_sequence = torch.Tensor().new_full((max_length,), tensorizer.get_pad_id(), dtype=torch.long)

    for sample in samples:
        if is_train:
            positive_ctxs = sample.positive_passages
            negative_ctxs = sample.negative_passages
        else:
            positive_ctxs = []
            negative_ctxs = sample.positive_passages + sample.negative_passages
            # Need to re-sort samples based on their scores
            negative_ctxs = sorted(negative_ctxs, key=lambda x: x.score, reverse=True)
        question_token_ids = sample.question_token_ids

        sample_tensors = _create_question_passages_tensors(
            wiki_data,
            question_token_ids,
            tensorizer,
            positive_ctxs,
            negative_ctxs,
            passages_per_question,
            empty_sequence,
            max_n_answers,
            is_train,
            is_random=shuffle,
        )

        if not sample_tensors:
            logger.warning("No valid passages combination for question=%s", sample.question)
            continue

        context_ID, sample_input_ids, starts_tensor, ends_tensor, answer_mask = sample_tensors

        context_IDs.append(context_ID)
        input_ids.append(sample_input_ids)

        if is_train:
            start_positions.append(starts_tensor)
            end_positions.append(ends_tensor)
            answers_masks.append(answer_mask)

    context_IDs = torch.cat([IDs.unsqueeze(0) for IDs in context_IDs], dim=0)  # (N, M)
    input_ids = torch.cat([ids.unsqueeze(0) for ids in input_ids], dim=0)  # (N, M, L)

    if is_train:
        start_positions = torch.stack(start_positions, dim=0)
        end_positions = torch.stack(end_positions, dim=0)
        answers_masks = torch.stack(answers_masks, dim=0)

    return ReaderBatch(context_IDs, input_ids, start_positions, end_positions, answers_masks)
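
# Illustrative usage sketch (an assumption, not part of the original module): building a
# training batch. `my_wiki_data`, `my_tensorizer`, and `my_samples` are placeholders for an
# initialized TokenizedWikipediaPassages store, a Tensorizer, and a list of ReaderSample-s;
# the numeric values are hypothetical.
#
#   batch = create_reader_input(
#       my_wiki_data,
#       my_tensorizer,
#       my_samples,
#       passages_per_question=24,
#       max_length=350,
#       max_n_answers=10,
#       is_train=True,
#       shuffle=True,
#   )
#   # The ReaderBatch fields follow the constructor order above: per-question passage ids
#   # of shape (N, M), input token ids of shape (N, M, L), and answer start/end/mask
#   # tensors of shape (N, M, max_n_answers) when is_train is True.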