def text_to_instance(self,
                     question_text: str,
                     passage_text: str,
                     passage_tokens: List[Token],
                     passage_text_index_to_token_index: List[int],
                     passage_wordpieces: List[List[int]],
                     question_id: str = None,
                     passage_id: str = None,
                     answer_annotations: List[Dict] = None,
                     original_answer_annotations: List[List[Dict]] = None,
                     answer_type: str = None,
                     instance_index: int = None) -> Optional[Instance]:
    # Tokenize question
    question_tokens = self._tokenizer.tokenize_with_offsets(question_text)
    question_text_index_to_token_index = index_text_to_tokens(question_text, question_tokens)
    question_words = self._word_tokenize(question_text)
    question_alignment = self._tokenizer.align_tokens_to_tokens(question_text, question_words, question_tokens)
    question_wordpieces = self._tokenizer.alignment_to_token_wordpieces(question_alignment)

    # Index tokens
    encoded_inputs = self._tokenizer.encode_plus(
        [token.text for token in question_tokens],
        [token.text for token in passage_tokens],
        add_special_tokens=True,
        max_length=self._max_pieces,
        truncation_strategy='only_second',
        return_token_type_ids=True,
        return_special_tokens_mask=True)
    question_passage_token_type_ids = encoded_inputs['token_type_ids']
    question_passage_special_tokens_mask = encoded_inputs['special_tokens_mask']

    question_position = self._tokenizer.get_type_position_in_sequence(
        0, question_passage_token_type_ids, question_passage_special_tokens_mask)
    passage_position = self._tokenizer.get_type_position_in_sequence(
        1, question_passage_token_type_ids, question_passage_special_tokens_mask)
    question_passage_tokens, num_of_tokens_per_type = self._tokenizer.convert_to_tokens(
        encoded_inputs,
        [{'tokens': question_tokens, 'wordpieces': question_wordpieces, 'position': question_position},
         {'tokens': passage_tokens, 'wordpieces': passage_wordpieces, 'position': passage_position}])

    # Adjust wordpieces
    question_passage_wordpieces = self._tokenizer.adjust_wordpieces(
        [{'wordpieces': question_wordpieces, 'position': question_position,
          'num_of_tokens': num_of_tokens_per_type[0]},
         {'wordpieces': passage_wordpieces, 'position': passage_position,
          'num_of_tokens': num_of_tokens_per_type[1]}],
        question_passage_tokens)

    # Adjust text index to token index
    question_text_index_to_token_index = [
        token_index + question_position
        for token_index in question_text_index_to_token_index
        if token_index < num_of_tokens_per_type[0]
    ]
    passage_text_index_to_token_index = [
        token_index + passage_position
        for token_index in passage_text_index_to_token_index
        if token_index < num_of_tokens_per_type[1]
    ]

    # Truncation-related code
    encoded_passage_tokens_length = num_of_tokens_per_type[1]
    if encoded_passage_tokens_length > 0:
        if encoded_passage_tokens_length < len(passage_tokens):
            first_truncated_passage_token = passage_tokens[encoded_passage_tokens_length]
            max_passage_length = first_truncated_passage_token.idx
        else:
            max_passage_length = -1
    else:
        max_passage_length = 0

    fields: Dict[str, Field] = {}
    fields['question_passage_tokens'] = question_passage_field = LabelsField(encoded_inputs['input_ids'])
    fields['question_passage_token_type_ids'] = LabelsField(question_passage_token_type_ids)
    fields['question_passage_special_tokens_mask'] = LabelsField(question_passage_special_tokens_mask)
    fields['question_passage_pad_mask'] = LabelsField([1] * len(question_passage_tokens))

    # In a word broken up into pieces, every piece except the first should be ignored when calculating the loss
    first_wordpiece_mask = [i == wordpieces[0] for i, wordpieces in enumerate(question_passage_wordpieces)]
    fields['first_wordpiece_mask'] = LabelsField(first_wordpiece_mask)

    # Compile question, passage, answer metadata
    metadata = {'original_passage': passage_text,
                'original_question': question_text,
                'passage_tokens': passage_tokens,
                'question_tokens': question_tokens,
                'question_passage_tokens': question_passage_tokens,
                'question_passage_wordpieces': question_passage_wordpieces,
                'passage_id': passage_id,
                'question_id': question_id,
                'max_passage_length': max_passage_length}
    if instance_index is not None:
        metadata['instance_index'] = instance_index

    if answer_annotations:
        _, answer_texts = extract_answer_info_from_annotation(answer_annotations[0])
        answer_texts = list(OrderedDict.fromkeys(answer_texts))

        gold_indexes = {'question': [], 'passage': []}
        for original_answer_annotation in original_answer_annotations[0]:
            if original_answer_annotation['text'] in answer_texts:
                gold_index = original_answer_annotation['answer_start']
                if gold_index not in gold_indexes['passage']:
                    gold_indexes['passage'].append(gold_index)

        metadata['answer_annotations'] = answer_annotations

        kwargs = {
            'seq_tokens': question_passage_tokens,
            'seq_field': question_passage_field,
            'seq_wordpieces': question_passage_wordpieces,
            'question_text': question_text,
            'question_text_index_to_token_index': question_text_index_to_token_index,
            'passage_text': passage_text[:max_passage_length] if max_passage_length > -1 else passage_text,
            'passage_text_index_to_token_index': passage_text_index_to_token_index,
            'answer_texts': answer_texts,
            'gold_indexes': gold_indexes,
            'answer_type': answer_type,  # TODO: Elad - Probably temporary, used to mimic the old reader's behavior
            'is_training': self._is_training,  # TODO: Elad - Probably temporary, used to mimic the old reader's behavior
            'old_reader_behavior': self._old_reader_behavior  # TODO: Elad - temporary, used to mimic the old reader's behavior
        }

        answer_generator_names = None
        if self._answer_generator_names_per_type is not None:
            answer_generator_names = self._answer_generator_names_per_type[answer_type]

        has_answer = False
        for answer_generator_name, answer_field_generator in self._answer_field_generators.items():
            if answer_generator_names is None or answer_generator_name in answer_generator_names:
                answer_fields, generator_has_answer = answer_field_generator.get_answer_fields(**kwargs)
                fields.update(answer_fields)
                has_answer |= generator_has_answer
            else:
                fields.update(answer_field_generator.get_empty_answer_fields(**kwargs))

        # Throw away instances without possible answer generation
        if self._is_training and not has_answer:
            return None

    fields['metadata'] = MetadataField(metadata)
    return Instance(fields)
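
# Illustrative sketch (standalone, not part of the reader): how the first-wordpiece mask built
# above behaves. The grouping below is a toy value; in the reader it comes from
# self._tokenizer.adjust_wordpieces(...).
def build_first_wordpiece_mask(question_passage_wordpieces):
    # A position contributes to the loss only if it is the first piece of its word
    return [i == wordpieces[0] for i, wordpieces in enumerate(question_passage_wordpieces)]

# A word split into three pieces occupying positions 5, 6 and 7 keeps only position 5 active:
_toy_wordpieces = [[0], [1], [2], [3], [4], [5, 6, 7], [5, 6, 7], [5, 6, 7], [8]]
assert build_first_wordpiece_mask(_toy_wordpieces) == \
    [True, True, True, True, True, True, False, False, True]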
def _read(self, file_path: str):
    if not self._lazy and self._pickle['action'] == 'load':
        # Try to load the data; if it fails then read it from scratch and save it
        loaded_pkl = load_pkl(self._pickle, self._is_training)
        if loaded_pkl is not None:
            for instance in loaded_pkl:
                yield instance
            return
        else:
            self._pickle['action'] = 'save'

    # Support for multi-dataset training: file_path may be a comma-separated list of dataset files
    global_index = 0
    instances_count = 0
    instances = []
    for single_file_path in file_path.split(','):
        single_file_path_cached = cached_path(single_file_path)
        print(single_file_path)
        with open(single_file_path_cached, encoding='utf8') as dataset_file:
            dataset = json.load(dataset_file)
        dataset = standardize_dataset(dataset, self._standardize_text_func)

        for passage_id, passage_info in dataset.items():
            passage_text = passage_info['passage']

            # Tokenize passage
            passage_tokens = self._tokenizer.tokenize_with_offsets(passage_text)
            passage_text_index_to_token_index = index_text_to_tokens(passage_text, passage_tokens)
            passage_words = self._word_tokenize(passage_text)
            passage_alignment = self._tokenizer.align_tokens_to_tokens(passage_text, passage_words, passage_tokens)
            passage_wordpieces = self._tokenizer.alignment_to_token_wordpieces(passage_alignment)

            # Process questions from this passage
            for relative_index, qa_pair in enumerate(passage_info['qa_pairs']):
                if 0 < self._max_instances <= instances_count:
                    if not self._lazy and self._pickle['action'] == 'save':
                        save_pkl(instances, self._pickle, self._is_training)
                    return

                question_id = qa_pair['query_id']
                question_text = qa_pair['question']
                answer_annotations: List[Dict] = list()
                original_answer_annotations: List[List[Dict]] = list()
                answer_type = None
                if 'answer' in qa_pair and qa_pair['answer']:
                    answer = qa_pair['answer']
                    original_answer = qa_pair['original_answer']
                    answer_type = get_answer_type(answer)
                    if answer_type is None or answer_type not in self._answer_types_filter:
                        continue
                    answer_annotations.append(answer)
                    original_answer_annotations.append(original_answer)

                    # If the standardization deleted characters then we need to adjust answer_start
                    deletion_indexes = \
                        self._standardize_text_func(passage_info['original_passage'], deletions_tracking=True)[1]
                    for span in original_answer:
                        answer_start = span['answer_start']
                        for index in deletion_indexes.keys():
                            if span['answer_start'] > index:
                                answer_start -= deletion_indexes[index]
                        span['answer_start'] = answer_start

                if self._is_training and answer_type is None:
                    continue

                instance = None
                try:
                    instance = self.text_to_instance(question_text, passage_text, passage_tokens,
                                                     passage_text_index_to_token_index, passage_wordpieces,
                                                     question_id, passage_id, answer_annotations,
                                                     original_answer_annotations, answer_type,
                                                     global_index + relative_index)
                except BaseException as e:
                    from termcolor import colored
                    print(colored(e, 'green'))
                    print(colored('Found exception in question: ' + question_id + ' ' + question_text, 'green'))

                if instance is not None:
                    instances_count += 1
                    if not self._lazy:
                        instances.append(instance)
                    yield instance

            global_index += len(passage_info['qa_pairs'])

    if not self._lazy and self._pickle['action'] == 'save':
        save_pkl(instances, self._pickle, self._is_training)
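
# Illustrative sketch (standalone helper, with an assumption about the interface): the
# answer_start adjustment in _read above assumes deletion_indexes maps a character position in
# the original passage to the number of characters deleted at that position. This mirrors that loop.
def adjust_answer_starts(spans, deletion_indexes):
    for span in spans:
        answer_start = span['answer_start']
        for index, deleted_count in deletion_indexes.items():
            # Only deletions that happened before the span's start shift it left
            if span['answer_start'] > index:
                answer_start -= deleted_count
        span['answer_start'] = answer_start

# Two characters removed at position 10 shift a span starting at 25 back to 23
_toy_spans = [{'text': 'touchdown', 'answer_start': 25}]
adjust_answer_starts(_toy_spans, {10: 2})
assert _toy_spans[0]['answer_start'] == 23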
def gen_dataset_instances(self, dataset, global_info):
    for passage_id, passage_info in tqdm(dataset.items()):
        passage_text = passage_info['passage']

        # Tokenize passage
        passage_tokens = self._tokenizer.tokenize_with_offsets(passage_text)
        passage_text_index_to_token_index = index_text_to_tokens(passage_text, passage_tokens)
        passage_words = self._word_tokenize(passage_text)
        passage_alignment = self._tokenizer.align_tokens_to_tokens(passage_text, passage_words, passage_tokens)
        passage_wordpieces = self._tokenizer.alignment_to_token_wordpieces(passage_alignment)

        # Process questions from this passage
        for relative_index, qa_pair in enumerate(passage_info['qa_pairs']):
            if 0 < self._max_instances <= global_info['instances_count']:
                if not self._lazy and self._pickle['action'] == 'save':
                    save_pkl(global_info['instances'], self._pickle, self._is_training)
                return

            question_id = qa_pair['query_id']
            question_text = qa_pair['question']
            answer_annotations: List[Dict] = list()
            original_answer_annotations: List[List[Dict]] = list()
            answer_type = None
            if 'answer' in qa_pair and qa_pair['answer']:
                answer = qa_pair['answer']
                original_answer = qa_pair['original_answer']
                answer_type = get_answer_type(answer)
                if answer_type is None or answer_type not in self._answer_types_filter:
                    continue
                answer_annotations.append(answer)
                original_answer_annotations.append(original_answer)

                # If the standardization deleted characters then we need to adjust answer_start
                deletion_indexes = self._standardize_text_func(
                    passage_info['original_passage'], deletions_tracking=True)[1]
                for span in original_answer:
                    answer_start = span['answer_start']
                    for index in deletion_indexes.keys():
                        if span['answer_start'] > index:
                            answer_start -= deletion_indexes[index]
                    span['answer_start'] = answer_start

            if self._is_training and answer_type is None:
                continue

            instance = self.text_to_instance(question_text, passage_text, passage_tokens,
                                             passage_text_index_to_token_index, passage_wordpieces,
                                             question_id, passage_id, answer_annotations,
                                             original_answer_annotations, answer_type,
                                             global_info['global_index'] + relative_index)

            if instance is not None:
                global_info['instances_count'] += 1
                if not self._lazy:
                    global_info['instances'].append(instance)
                yield instance

        global_info['global_index'] += len(passage_info['qa_pairs'])
def text_to_instance(self,
                     question_text: str,
                     passage_text: str,
                     passage_tokens: List[Token],
                     passage_text_index_to_token_index: List[int],
                     passage_wordpieces: List[List[int]],
                     number_occurrences_in_passage: List[Dict[str, Any]],
                     question_id: str = None,
                     passage_id: str = None,
                     answer_annotations: List[Dict] = None,
                     answer_type: str = None,
                     instance_index: int = None) -> Optional[Instance]:
    """
    Processes a question/answer pair (related to a passage) and prepares an `Instance`.

    Parameters
    ----------
    question_text: str
    passage_text: str
    passage_tokens: List[Token]
    passage_text_index_to_token_index: List[int]
    passage_wordpieces: List[List[int]]
    number_occurrences_in_passage: List[Dict[str, Any]]
    question_id: str = None
    passage_id: str = None
    answer_annotations: List[Dict] = None
    answer_type: str = None
    instance_index: int = None

    Returns
    -------
    instance: Optional[Instance]
        an `Instance` containing all the data that the `Model` takes as input
    """
    # We alter it, so use a copy to keep it usable for the next questions with the same paragraph
    number_occurrences_in_passage = [number_occurrence.copy()
                                     for number_occurrence in number_occurrences_in_passage]

    """
    Tokenize question
    - question_tokens: [ĠHow, Ġmany, Ġpoints, Ġdid, Ġthe, Ġbu, cc, aneers, Ġneed, ...
    - question_text_index_to_token_index: [0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, ...
    - question_words: [How, many, points, did, the, buccaneers, need, to, tie, in, the, first, ?]
    - question_alignment: [[0], [1], [2], [3], [4], [5, 6, 7], [8], [9], [10], [11], [12], [13], [14]]
    - question_wordpieces: [[0], [1], [2], [3], [4], [5, 6, 7], [5, 6, 7], [5, 6, 7], [8], [9], [10], ...
    """
    question_tokens = self._tokenizer.tokenize_with_offsets(question_text)
    question_text_index_to_token_index = index_text_to_tokens(question_text, question_tokens)
    question_words = self._word_tokenize(question_text)
    question_alignment = self._tokenizer.align_tokens_to_tokens(question_text, question_words, question_tokens)
    question_wordpieces = self._tokenizer.alignment_to_token_wordpieces(question_alignment)

    """
    Index tokens
    encoded_inputs is a dictionary with:
    - 'special_tokens_mask'
    - 'input_ids'
    - 'token_type_ids'
    - 'attention_mask'
    """
    encoded_inputs = self._tokenizer.encode_plus([token.text for token in question_tokens],
                                                 [token.text for token in passage_tokens],
                                                 add_special_tokens=True,
                                                 max_length=self._max_pieces,
                                                 truncation_strategy='only_second',
                                                 return_token_type_ids=True,
                                                 return_special_tokens_mask=True)
    question_passage_token_type_ids = encoded_inputs['token_type_ids']
    question_passage_special_tokens_mask = encoded_inputs['special_tokens_mask']

    question_position = self._tokenizer.get_type_position_in_sequence(
        0, question_passage_token_type_ids, question_passage_special_tokens_mask)
    passage_position = self._tokenizer.get_type_position_in_sequence(
        1, question_passage_token_type_ids, question_passage_special_tokens_mask)
    question_passage_tokens, num_of_tokens_per_type = self._tokenizer.convert_to_tokens(encoded_inputs, [
        {'tokens': question_tokens, 'wordpieces': question_wordpieces, 'position': question_position},
        {'tokens': passage_tokens, 'wordpieces': passage_wordpieces, 'position': passage_position}
    ])

    # Adjust wordpieces
    question_passage_wordpieces = self._tokenizer.adjust_wordpieces([
        {'wordpieces': question_wordpieces, 'position': question_position,
         'num_of_tokens': num_of_tokens_per_type[0]},
        {'wordpieces': passage_wordpieces, 'position': passage_position,
         'num_of_tokens': num_of_tokens_per_type[1]}
    ], question_passage_tokens)

    # Adjust text index to token index
    question_text_index_to_token_index = [token_index + question_position
                                          for token_index in question_text_index_to_token_index
                                          if token_index < num_of_tokens_per_type[0]]
    passage_text_index_to_token_index = [token_index + passage_position
                                         for token_index in passage_text_index_to_token_index
                                         if token_index < num_of_tokens_per_type[1]]

    # Truncation-related code
    encoded_passage_tokens_length = num_of_tokens_per_type[1]
    if encoded_passage_tokens_length > 0:
        if encoded_passage_tokens_length < len(passage_tokens):
            first_truncated_passage_token = passage_tokens[encoded_passage_tokens_length]
            max_passage_length = first_truncated_passage_token.idx
            number_occurrences_in_passage = \
                clipped_passage_num(number_occurrences_in_passage, encoded_passage_tokens_length)
        else:
            max_passage_length = -1
    else:
        max_passage_length = 0

    # Update the indices of the numbers with respect to the question
    for number_occurrence in number_occurrences_in_passage:
        number_occurrence['indices'] = [index + passage_position for index in number_occurrence['indices']]

    # Hack to guarantee a minimal length of padded numbers
    # (following dataset_readers.reading_comprehension.drop from allennlp)
    number_occurrences_in_passage.append({
        'value': 0,
        'indices': [-1]
    })

    # Create the fields dictionary for the `Instance`
    fields: Dict[str, Field] = {}
    fields['question_passage_tokens'] = question_passage_field = LabelsField(encoded_inputs['input_ids'])
    fields['question_passage_token_type_ids'] = LabelsField(question_passage_token_type_ids)
    fields['question_passage_special_tokens_mask'] = LabelsField(question_passage_special_tokens_mask)
    fields['question_passage_pad_mask'] = LabelsField([1] * len(question_passage_tokens))

    # In a word broken up into pieces, every piece except the first should be ignored when calculating the loss
    first_wordpiece_mask = [i == wordpieces[0] for i, wordpieces in enumerate(question_passage_wordpieces)]
    fields['first_wordpiece_mask'] = LabelsField(first_wordpiece_mask)

    fields['number_indices'] = get_number_indices_field(number_occurrences_in_passage)

    # Compile question, passage, answer metadata
    metadata = {'original_passage': passage_text,
                'original_question': question_text,
                'original_numbers': [number_occurrence['value']
                                     for number_occurrence in number_occurrences_in_passage],
                'passage_tokens': passage_tokens,
                'question_tokens': question_tokens,
                'question_passage_tokens': question_passage_tokens,
                'question_passage_wordpieces': question_passage_wordpieces,
                'passage_id': passage_id,
                'question_id': question_id,
                'max_passage_length': max_passage_length}
    if instance_index is not None:
        metadata['instance_index'] = instance_index

    if answer_annotations:
        # Get answer type and answer texts.
        # For multi-span, remove repeating answers. Although possible, in the dataset these are mostly mistakes.
        answer_accessor, answer_texts = extract_answer_info_from_annotation(answer_annotations[0])
        if answer_accessor == AnswerAccessor.SPAN.value:
            answer_texts = list(OrderedDict.fromkeys(answer_texts))

        metadata['answer_annotations'] = answer_annotations

        kwargs = {
            'seq_tokens': question_passage_tokens,
            'seq_field': question_passage_field,
            'seq_wordpieces': question_passage_wordpieces,
            'question_text': question_text,
            'question_text_index_to_token_index': question_text_index_to_token_index,
            'passage_text': passage_text[:max_passage_length] if max_passage_length > -1 else passage_text,
            'passage_text_index_to_token_index': passage_text_index_to_token_index,
            'answer_texts': answer_texts,
            'number_occurrences_in_passage': number_occurrences_in_passage,
            'answer_type': answer_type,  # TODO: Elad - Probably temporary, used to mimic the old reader's behavior
            'is_training': self._is_training,  # TODO: Elad - Probably temporary, used to mimic the old reader's behavior
            'old_reader_behavior': self._old_reader_behavior  # TODO: Elad - temporary, used to mimic the old reader's behavior
        }

        answer_generator_names = None
        if self._answer_generator_names_per_type is not None:
            answer_generator_names = self._answer_generator_names_per_type[answer_type]

        has_answer = False
        for answer_generator_name, answer_field_generator in self._answer_field_generators.items():
            if answer_generator_names is None or answer_generator_name in answer_generator_names:
                answer_fields, generator_has_answer = answer_field_generator.get_answer_fields(**kwargs)
                fields.update(answer_fields)
                has_answer |= generator_has_answer
            else:
                fields.update(answer_field_generator.get_empty_answer_fields(**kwargs))

        # Throw away instances without possible answer generation
        if self._is_training and not has_answer:
            return None

    fields['metadata'] = MetadataField(metadata)
    return Instance(fields)
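
# Illustrative sketch (standalone, toy data): how the number handling above re-bases passage
# number indices by passage_position and appends the padding entry {'value': 0, 'indices': [-1]}.
def rebase_number_occurrences(number_occurrences_in_passage, passage_position):
    occurrences = [occurrence.copy() for occurrence in number_occurrences_in_passage]
    for occurrence in occurrences:
        occurrence['indices'] = [index + passage_position for index in occurrence['indices']]
    occurrences.append({'value': 0, 'indices': [-1]})  # guarantees a minimal padded length
    return occurrences

# With the passage starting at position 12 of the joined question+passage token sequence:
assert rebase_number_occurrences([{'value': 7, 'indices': [3]}], 12) == \
    [{'value': 7, 'indices': [15]}, {'value': 0, 'indices': [-1]}]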
def _read(self, file_path: str):
    """
    Reads the file containing the dataset and prepares the data for the model.
    It processes the passage text and calls the `text_to_instance` method to process
    the question/answer text.

    Parameters
    ----------
    file_path: str
        the file location of the dataset

    Returns
    -------
    instances: Iterable
        an Iterable (a generator) over the pre-processed `Instance` objects
    """
    instances_count = 0
    if not self._lazy and self._pickle['action'] == 'load':
        # Try to load the data; if it fails then read it from scratch and save it
        loaded_pkl = load_pkl(self._pickle, self._is_training)
        if loaded_pkl is not None:
            for instance in loaded_pkl:
                if 0 < self._max_instances <= instances_count:
                    return
                instances_count += 1
                yield instance
            return
        else:
            self._pickle['action'] = 'save'

    file_path = cached_path(file_path)
    with open(file_path, encoding='utf8') as dataset_file:
        dataset = json.load(dataset_file)
    dataset = standardize_dataset(dataset, self._standardize_text_func)

    global_index = 0
    instances = []
    for passage_id, passage_info in tqdm(dataset.items()):
        passage_text = passage_info['passage']

        """
        Tokenize passage
        - passage_tokens: [ĠTo, Ġstart, Ġthe, Ġseason, ...
        - passage_text_index_to_token_index: [0, 0, 0, 1, 1, 1, 1, 1, 1, ...
        - passage_words: [To, start, the, season, ,, the, ...
        - passage_alignment: [[0], [1], [2], [3], [4], [5], ...
        - passage_wordpieces: [[0], [1], [2], [3], ...
        """
        passage_tokens = self._tokenizer.tokenize_with_offsets(passage_text)
        passage_text_index_to_token_index = index_text_to_tokens(passage_text, passage_tokens)
        passage_words = number_extraction_tokens = self._word_tokenize(passage_text)
        passage_alignment = self._tokenizer.align_tokens_to_tokens(passage_text, passage_words, passage_tokens)
        passage_wordpieces = self._tokenizer.alignment_to_token_wordpieces(passage_alignment)

        number_occurrences_in_passage = \
            extract_number_occurrences(number_extraction_tokens, passage_alignment)

        # Process questions from this passage
        for relative_index, qa_pair in enumerate(passage_info['qa_pairs']):
            if 0 < self._max_instances <= instances_count:
                if not self._lazy and self._pickle['action'] == 'save':
                    save_pkl(instances, self._pickle, self._is_training)
                return

            # Extract the question id and question text from the pair
            question_id = qa_pair['query_id']
            question_text = qa_pair['question']

            # This list will contain the answer and the validated answers
            answer_annotations: List[Dict] = list()
            answer_type = None
            if 'answer' in qa_pair and qa_pair['answer']:
                # Extract the answer
                answer = qa_pair['answer']
                # Deduce the answer type by checking the answer's fields
                answer_type = get_answer_type(answer)
                # Skip answers with a non-valid type
                if answer_type is None or answer_type not in self._answer_types_filter:
                    continue
                answer_annotations.append(answer)

            if 'validated_answers' in qa_pair and qa_pair['validated_answers']:
                answer_annotations += qa_pair['validated_answers']

            if self._is_training and answer_type is None:
                continue

            instance = self.text_to_instance(question_text, passage_text, passage_tokens,
                                             passage_text_index_to_token_index, passage_wordpieces,
                                             number_occurrences_in_passage, question_id, passage_id,
                                             answer_annotations, answer_type,
                                             global_index + relative_index)

            if instance is not None:
                instances_count += 1
                if not self._lazy:
                    instances.append(instance)
                yield instance

        global_index += len(passage_info['qa_pairs'])

    if not self._lazy and self._pickle['action'] == 'save':
        save_pkl(instances, self._pickle, self._is_training)
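
# Illustrative sketch (hypothetical toy data): the minimal dataset structure this _read expects,
# inferred from the keys accessed above. In DROP-style data the 'answer' dictionary carries
# 'number', 'spans' and 'date' entries, which is what get_answer_type is assumed to inspect.
_toy_dataset = {
    'nfl_1184': {
        'passage': 'To start the season, the Buccaneers ...',
        'qa_pairs': [
            {
                'query_id': 'q_0',
                'question': 'How many points did the Buccaneers need to tie in the first?',
                'answer': {'number': '3', 'spans': [], 'date': {'day': '', 'month': '', 'year': ''}},
                'validated_answers': [],
            },
        ],
    },
}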
def _read(self, file_path: str):
    instances_count = 0
    if not self._lazy and self._pickle['action'] == 'load':
        # Try to load the data; if it fails then read it from scratch and save it
        loaded_pkl = load_pkl(self._pickle, self._is_training)
        if loaded_pkl is not None:
            for instance in loaded_pkl:
                if 0 < self._max_instances <= instances_count:
                    return
                instances_count += 1
                yield instance
            return
        else:
            self._pickle['action'] = 'save'

    file_path = cached_path(file_path)
    with open(file_path, encoding='utf8') as dataset_file:
        dataset = json.load(dataset_file)
    dataset = standardize_dataset(dataset, self._standardize_text_func)

    global_index = 0
    instances = []
    for passage_id, passage_info in tqdm(dataset.items()):
        passage_text = passage_info['passage']

        # Tokenize passage
        passage_tokens = self._tokenizer.tokenize_with_offsets(passage_text)
        passage_text_index_to_token_index = index_text_to_tokens(passage_text, passage_tokens)
        passage_words = number_extraction_tokens = self._word_tokenize(passage_text)
        passage_alignment = self._tokenizer.align_tokens_to_tokens(passage_text, passage_words, passage_tokens)
        passage_wordpieces = self._tokenizer.alignment_to_token_wordpieces(passage_alignment)

        number_occurrences_in_passage = \
            extract_number_occurrences(number_extraction_tokens, passage_alignment)

        # Process questions from this passage
        for relative_index, qa_pair in enumerate(passage_info['qa_pairs']):
            if 0 < self._max_instances <= instances_count:
                if not self._lazy and self._pickle['action'] == 'save':
                    save_pkl(instances, self._pickle, self._is_training)
                return

            question_id = qa_pair['query_id']
            question_text = qa_pair['question']
            answer_annotations: List[Dict] = list()
            answer_type = None
            if 'answer' in qa_pair and qa_pair['answer']:
                answer = qa_pair['answer']
                answer_type = get_answer_type(answer)
                if answer_type is None or answer_type not in self._answer_types_filter:
                    continue
                answer_annotations.append(answer)

            if 'validated_answers' in qa_pair and qa_pair['validated_answers']:
                answer_annotations += qa_pair['validated_answers']

            if self._is_training and answer_type is None:
                continue

            instance = self.text_to_instance(question_text, passage_text, passage_tokens,
                                             passage_text_index_to_token_index, passage_wordpieces,
                                             number_occurrences_in_passage, question_id, passage_id,
                                             answer_annotations, answer_type,
                                             global_index + relative_index)

            if instance is not None:
                instances_count += 1
                if not self._lazy:
                    instances.append(instance)
                yield instance

        global_index += len(passage_info['qa_pairs'])

    if not self._lazy and self._pickle['action'] == 'save':
        save_pkl(instances, self._pickle, self._is_training)
def get_analysis(dataset_pkl_path, predictions_path, heads):
    with open(os.path.join(dataset_pkl_path), 'rb') as dataset_pkl:
        dataset = pickle.load(dataset_pkl)

    predictions = {}
    with open(os.path.join(predictions_path), 'rb') as predictions_file:
        while True:
            line = predictions_file.readline()
            if not line:
                break
            prediction = json.loads(line)
            predictions[prediction['query_id']] = prediction

    tokenizer = HuggingfaceTransformersTokenizer(pretrained_model='roberta-large')
    word_tokenizer = custom_word_tokenizer()
    word_tokenize = \
        lambda text: [token for token in split_tokens_by_hyphen(word_tokenizer.tokenize(text))]

    num_gold_tokens_to_stats = defaultdict(lambda: defaultdict(int))
    num_occurrences_stats = defaultdict(lambda: defaultdict(int))

    for i, instance in tqdm(enumerate(dataset)):
        full_prediction = predictions[instance['metadata']['question_id']]
        if full_prediction['predicted_ability'] not in heads:
            continue

        question_text = instance['metadata']['original_question']
        question_tokens = tokenizer.tokenize_with_offsets(question_text)
        question_text_index_to_token_index = index_text_to_tokens(question_text, question_tokens)
        question_words = word_tokenize(question_text)
        question_alignment = tokenizer.align_tokens_to_tokens(question_text, question_words, question_tokens)
        question_wordpieces = tokenizer.alignment_to_token_wordpieces(question_alignment)

        passage_text = instance['metadata']['original_passage']
        passage_tokens = tokenizer.tokenize_with_offsets(passage_text)
        passage_text_index_to_token_index = index_text_to_tokens(passage_text, passage_tokens)
        passage_words = word_tokenize(passage_text)
        passage_alignment = tokenizer.align_tokens_to_tokens(passage_text, passage_words, passage_tokens)
        passage_wordpieces = tokenizer.alignment_to_token_wordpieces(passage_alignment)

        passage_text = passage_text[:instance['metadata']['max_passage_length']]
        question_passage_tokens = instance['metadata']['question_passage_tokens']
        question_passage_wordpieces = instance['metadata']['question_passage_wordpieces']

        # Index tokens
        encoded_inputs = tokenizer.encode_plus([token.text for token in question_tokens],
                                               [token.text for token in passage_tokens],
                                               add_special_tokens=True,
                                               max_length=MAX_PIECES,
                                               truncation_strategy='only_second',
                                               return_token_type_ids=True,
                                               return_special_tokens_mask=True)
        question_passage_token_type_ids = encoded_inputs['token_type_ids']
        question_passage_special_tokens_mask = encoded_inputs['special_tokens_mask']

        question_position = tokenizer.get_type_position_in_sequence(
            0, question_passage_token_type_ids, question_passage_special_tokens_mask)
        passage_position = tokenizer.get_type_position_in_sequence(
            1, question_passage_token_type_ids, question_passage_special_tokens_mask)
        question_passage_tokens, num_of_tokens_per_type = tokenizer.convert_to_tokens(encoded_inputs, [
            {'tokens': question_tokens, 'wordpieces': question_wordpieces, 'position': question_position},
            {'tokens': passage_tokens, 'wordpieces': passage_wordpieces, 'position': passage_position}
        ])

        # Adjust text index to token index
        question_text_index_to_token_index = [token_index + question_position
                                              for token_index in question_text_index_to_token_index
                                              if token_index < num_of_tokens_per_type[0]]
        passage_text_index_to_token_index = [token_index + passage_position
                                             for token_index in passage_text_index_to_token_index
                                             if token_index < num_of_tokens_per_type[1]]

        prediction = full_prediction['answer']['value']
        maximizing_ground_truth = full_prediction['maximizing_ground_truth']
        answer_accessor, answer_texts = extract_answer_info_from_annotation(maximizing_ground_truth)

        gold_indexes = {'question': None, 'passage': None}

        alignment = align_predicted_and_maximizing_gold(prediction, answer_texts)

        for gold_index, predicted_index in enumerate(alignment):
            num_of_gold_tokens = len(tokenizer.tokenize(answer_texts[gold_index]))
            num_of_predicted_tokens = len(tokenizer.tokenize(prediction[predicted_index]))

            num_gold_tokens_to_stats[num_of_gold_tokens]['count'] += 1
            num_gold_tokens_to_stats[num_of_gold_tokens]['num_predicted_tokens'] += num_of_predicted_tokens
            num_gold_tokens_to_stats[num_of_gold_tokens]['em'] += full_prediction['em'] * 100
            num_gold_tokens_to_stats[num_of_gold_tokens]['f1'] += full_prediction['f1'] * 100

            answer_spans = []
            answer_spans += find_valid_spans(question_text, [answer_texts[gold_index]],
                                             question_text_index_to_token_index,
                                             question_passage_tokens, question_passage_wordpieces,
                                             gold_indexes['question'])
            answer_spans += find_valid_spans(passage_text, [answer_texts[gold_index]],
                                             passage_text_index_to_token_index,
                                             question_passage_tokens, question_passage_wordpieces,
                                             gold_indexes['passage'])
            occurrences = len(answer_spans)

            num_occurrences_stats[occurrences]['count'] += 1
            num_occurrences_stats[occurrences]['em'] += full_prediction['em'] * 100
            num_occurrences_stats[occurrences]['f1'] += full_prediction['f1'] * 100

    return num_gold_tokens_to_stats, num_occurrences_stats
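
# Illustrative usage sketch (hypothetical paths and head names): averaging the per-bucket sums
# accumulated by get_analysis. 'multiple_spans' is an assumed head name, not taken from this file.
if __name__ == '__main__':
    gold_token_stats, occurrence_stats = get_analysis(
        'data/drop_dataset_dev.pkl',           # pickled instances carrying a 'metadata' dict
        'predictions/dev_predictions.jsonl',   # one JSON prediction per line
        heads=['multiple_spans'])
    for num_gold_tokens, stats in sorted(gold_token_stats.items()):
        print(num_gold_tokens, stats['count'],
              stats['em'] / stats['count'], stats['f1'] / stats['count'])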