def get_answer_fields(self,
                      **kwargs: Dict[str, Any]) -> Tuple[Dict[str, Field], bool]:
    answer_texts: List[str] = kwargs['answer_texts']
    fields: Dict[str, Field] = {}

    # Label the answer with every count in [0, max_count] that matches a target number;
    # if none match, this head has no answer and we fall back to empty answer fields.
    target_numbers = get_target_numbers(answer_texts)
    numbers_for_count = list(range(self._max_count + 1))
    valid_counts: List[int] = DropReader.find_valid_counts(numbers_for_count, target_numbers)

    if len(valid_counts) > 0:
        has_answer = True
        counts_field: List[Field] = [LabelField(count_label, skip_indexing=True)
                                     for count_label in valid_counts]
        fields['answer_as_counts'] = ListField(counts_field)
    else:
        has_answer = False
        fields.update(self.get_empty_answer_fields(**kwargs))

    return fields, has_answer
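# A minimal sketch (not part of the reader) of how the counting labels above are derived.
# It assumes DropReader.find_valid_counts keeps every candidate count that equals one of
# the target numbers, and that get_target_numbers parses numbers out of the answer texts.
def _sketch_find_valid_counts(candidate_counts, target_numbers):
    targets = set(target_numbers)
    return [count for count in candidate_counts if count in targets]

# Example: a gold answer of "3" with max_count = 10 yields [3]; a non-numeric answer
# yields [] and get_answer_fields above falls back to get_empty_answer_fields.
assert _sketch_find_valid_counts(range(11), [3]) == [3]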
def load_drop_data():
    DROP_DEV_FILE = '/home/tony/answer-generation/raw_data/drop/drop_dataset_dev.json'

    data = {}
    with open(DROP_DEV_FILE) as dataset_file:
        dataset = json.load(dataset_file)

    for passage_id, passage_info in dataset.items():
        passage = clean_string(passage_info["passage"])

        for question_answer in passage_info["qa_pairs"]:
            question_id = question_answer["query_id"]
            question = clean_string(question_answer["question"])

            answer_annotations = []
            if "answer" in question_answer:
                answer_annotations.append(question_answer["answer"])
            if "validated_answers" in question_answer:
                answer_annotations += question_answer["validated_answers"]

            # Extract out the label per annotation
            answer_annotations = [' '.join(DropReader.extract_answer_info_from_annotation(a)[1])
                                  for a in answer_annotations]

            # Get the most common answer as the gold answer
            reference = clean_string(str(most_frequent(answer_annotations)))

            data[question_id] = {'context': passage,
                                 'question': question,
                                 'reference': reference,
                                 'candidates': set()}

    return data
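# A small usage sketch, assuming the DROP dev file exists at the hard-coded path above.
# Each entry maps a query_id to its context, question, majority-vote reference answer,
# and an (initially empty) candidate set that later code can fill in.
if __name__ == '__main__':
    drop_data = load_drop_data()
    for question_id, entry in list(drop_data.items())[:3]:
        print(question_id)
        print('Q:', entry['question'])
        print('Gold:', entry['reference'])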
def clean(self, passage, question, answer, passage_tagging, question_tagging):
    # Drop any answer span that matches a passage span containing a named-entity tag;
    # return the cleaned answer only if at least one span was removed.
    passage_tokens = [Token(w) for w in passage_tagging['words']]
    spans = DropReader.find_valid_spans(passage_tokens, answer['spans'])

    new_answer_texts = []
    cleaned = False
    for answer_text in answer['spans']:
        valid = True
        for span in spans:
            span_text = ' '.join(passage_tagging['words'][span[0]:span[1] + 1]).lower()
            if answer_text.lower() != span_text:
                continue
            if any(tag != 'O' for tag in passage_tagging['tags'][span[0]:span[1] + 1]):
                valid = False
                cleaned = True
                break
        if valid:
            new_answer_texts.append(answer_text)

    if not cleaned:
        return None

    new_answer = answer.copy()
    new_answer['spans'] = new_answer_texts
    return {'answer': new_answer}
def text_to_instance(self,
                     question_text: str,
                     passage_text: str,
                     passage_tokens: List[Token],
                     passage_spans: List[Tuple[int, int]],
                     numbers_in_passage: List[Any],
                     number_words: List[str],
                     number_indices: List[int],
                     number_len: List[int],
                     question_id: str = None,
                     passage_id: str = None,
                     answer_annotations: List[Dict] = None) -> Union[Instance, None]:
    # Tokenize question and passage
    question_tokens = self.tokenizer.tokenize(question_text)
    qlen = len(question_tokens)
    plen = len(passage_tokens)

    question_passage_tokens = [Token('[CLS]')] + question_tokens + [Token('[SEP]')] + passage_tokens
    if len(question_passage_tokens) > self.max_pieces - 1:
        question_passage_tokens = question_passage_tokens[:self.max_pieces - 1]
        passage_tokens = passage_tokens[:self.max_pieces - qlen - 3]
        plen = len(passage_tokens)
        number_indices, number_len, numbers_in_passage = \
            clipped_passage_num(number_indices, number_len, numbers_in_passage, plen)

    question_passage_tokens += [Token('[SEP]')]
    number_indices = [index + qlen + 2 for index in number_indices] + [-1]
    # Not done in-place so they won't change the numbers saved for the passage
    number_len = number_len + [1]
    numbers_in_passage = numbers_in_passage + [0]
    number_tokens = [Token(str(number)) for number in numbers_in_passage]
    extra_number_tokens = [Token(str(num)) for num in self.extra_numbers]

    mask_indices = [0, qlen + 1, len(question_passage_tokens) - 1]

    if self.extract_spans:
        # adapt indexes to question_passage_tokens sequence
        passage_spans = [(span[0] + qlen + 2, span[1] + qlen + 2) for span in passage_spans]
        # remove spans of truncated part of passage
        passage_spans = [span for span in passage_spans if span[1] <= len(question_passage_tokens)]
        # make span indexes inclusive
        passage_spans = [(span[0], span[1] - 1) for span in passage_spans]

    fields: Dict[str, Field] = {}

    # Add feature fields
    question_passage_field = TextField(question_passage_tokens, self.token_indexers)
    fields["question_passage"] = question_passage_field

    if self.extract_spans:
        passage_span_fields = [SpanField(span[0], span[1], question_passage_field)
                               for span in passage_spans]
        fields["passage_spans"] = ListField(passage_span_fields)

    number_token_indices = \
        [ArrayField(np.arange(start_ind, start_ind + number_len[i]), padding_value=-1)
         for i, start_ind in enumerate(number_indices)]
    fields["number_indices"] = ListField(number_token_indices)
    numbers_in_passage_field = TextField(number_tokens, self.token_indexers)
    extra_numbers_field = TextField(extra_number_tokens, self.token_indexers)
    all_numbers_field = TextField(extra_number_tokens + number_tokens, self.token_indexers)
    mask_index_fields: List[Field] = [IndexField(index, question_passage_field)
                                      for index in mask_indices]
    fields["mask_indices"] = ListField(mask_index_fields)

    # Compile question, passage, answer metadata
    metadata = {"original_passage": passage_text,
                "original_question": question_text,
                "original_numbers": numbers_in_passage,
                "original_number_words": number_words,
                "extra_numbers": self.extra_numbers,
                "passage_tokens": passage_tokens,
                "question_tokens": question_tokens,
                "question_passage_tokens": question_passage_tokens,
                "passage_id": passage_id,
                "question_id": question_id}

    if answer_annotations:
        for annotation in answer_annotations:
            tokenized_spans = [[token.text for token in self.tokenizer.tokenize(answer)]
                               for answer in annotation['spans']]
            annotation['spans'] = [tokenlist_to_passage(token_list)
                                   for token_list in tokenized_spans]

        # Get answer type, answer text, tokenize
        answer_type, answer_texts = \
            DropReader.extract_answer_info_from_annotation(answer_annotations[0])
        tokenized_answer_texts = []
        num_spans = min(len(answer_texts), self.max_spans)
        for answer_text in answer_texts:
            answer_tokens = self.tokenizer.tokenize(answer_text)
            tokenized_answer_texts.append(' '.join(token.text for token in answer_tokens))

        metadata["answer_annotations"] = answer_annotations
        metadata["answer_texts"] = answer_texts
        metadata["answer_tokens"] = tokenized_answer_texts

        # Find answer text in question and passage
        valid_question_spans = DropReader.find_valid_spans(question_tokens, tokenized_answer_texts)
        for span_ind, span in enumerate(valid_question_spans):
            valid_question_spans[span_ind] = (span[0] + 1, span[1] + 1)
        valid_passage_spans = DropReader.find_valid_spans(passage_tokens, tokenized_answer_texts)
        for span_ind, span in enumerate(valid_passage_spans):
            valid_passage_spans[span_ind] = (span[0] + qlen + 2, span[1] + qlen + 2)

        # Get target numbers
        target_numbers = []
        for answer_text in answer_texts:
            number = self.word_to_num(answer_text)
            if number is not None:
                target_numbers.append(number)

        # Get possible ways to arrive at target numbers with add/sub
        valid_expressions: List[List[int]] = []
        exp_strings = None
        if answer_type in ["number", "date"]:
            if self.exp_search == 'full':
                expressions = get_full_exp(list(enumerate(self.extra_numbers + numbers_in_passage)),
                                           target_numbers,
                                           self.operations,
                                           self.op_dict,
                                           self.max_depth)
                zipped = list(zip(*expressions))
                if zipped:
                    valid_expressions = list(zipped[0])
                    exp_strings = list(zipped[1])
            elif self.exp_search == 'add_sub':
                valid_expressions = \
                    DropReader.find_valid_add_sub_expressions(self.extra_numbers + numbers_in_passage,
                                                              target_numbers,
                                                              self.max_numbers_expression)
            elif self.exp_search == 'template':
                valid_expressions, exp_strings = \
                    get_template_exp(self.extra_numbers + numbers_in_passage,
                                     target_numbers,
                                     self.templates,
                                     self.template_strings)
                exp_strings = sum(exp_strings, [])

        # Get possible ways to arrive at target numbers with counting
        valid_counts: List[int] = []
        if answer_type in ["number"]:
            numbers_for_count = list(range(self.max_count + 1))
            valid_counts = DropReader.find_valid_counts(numbers_for_count, target_numbers)

        # Update metadata with answer info
        answer_info = {"answer_passage_spans": valid_passage_spans,
                       "answer_question_spans": valid_question_spans,
                       "num_spans": num_spans,
                       "expressions": valid_expressions,
                       "counts": valid_counts}
        if self.exp_search in ['template', 'full']:
            answer_info['expr_text'] = exp_strings
        metadata["answer_info"] = answer_info

        # Add answer fields
        passage_span_fields: List[Field] = [SpanField(span[0], span[1], question_passage_field)
                                            for span in valid_passage_spans]
        if not passage_span_fields:
            passage_span_fields.append(SpanField(-1, -1, question_passage_field))
        fields["answer_as_passage_spans"] = ListField(passage_span_fields)

        question_span_fields: List[Field] = [SpanField(span[0], span[1], question_passage_field)
                                             for span in valid_question_spans]
        if not question_span_fields:
            question_span_fields.append(SpanField(-1, -1, question_passage_field))
        fields["answer_as_question_spans"] = ListField(question_span_fields)

        if self.exp_search == 'add_sub':
            add_sub_signs_field: List[Field] = []
            extra_signs_field: List[Field] = []
            for signs_for_one_add_sub_expressions in valid_expressions:
                extra_signs = signs_for_one_add_sub_expressions[:len(self.extra_numbers)]
                normal_signs = signs_for_one_add_sub_expressions[len(self.extra_numbers):]
                add_sub_signs_field.append(SequenceLabelField(normal_signs, numbers_in_passage_field))
                extra_signs_field.append(SequenceLabelField(extra_signs, extra_numbers_field))
            if not add_sub_signs_field:
                add_sub_signs_field.append(SequenceLabelField([0] * len(number_tokens),
                                                              numbers_in_passage_field))
            if not extra_signs_field:
                extra_signs_field.append(SequenceLabelField([0] * len(self.extra_numbers),
                                                            extra_numbers_field))
            fields["answer_as_expressions"] = ListField(add_sub_signs_field)
            if self.extra_numbers:
                fields["answer_as_expressions_extra"] = ListField(extra_signs_field)
        elif self.exp_search in ['template', 'full']:
            expression_indices = []
            for expression in valid_expressions:
                if not expression:
                    expression.append(3 * [-1])
                expression_indices.append(ArrayField(np.array(expression), padding_value=-1))
            if not expression_indices:
                expression_indices = \
                    [ArrayField(np.array([3 * [-1]]), padding_value=-1)
                     for _ in range(len(self.templates))]
            fields["answer_as_expressions"] = ListField(expression_indices)

        count_fields: List[Field] = [LabelField(count_label, skip_indexing=True)
                                     for count_label in valid_counts]
        if not count_fields:
            count_fields.append(LabelField(-1, skip_indexing=True))
        fields["answer_as_counts"] = ListField(count_fields)

        fields["num_spans"] = LabelField(num_spans, skip_indexing=True)

    fields["metadata"] = MetadataField(metadata)

    return Instance(fields)
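# A minimal sketch, assuming the usual DROP add/sub convention in which each number gets a
# sign label (0 = unused, 1 = plus, 2 = minus). "answer_as_expressions" above stores one such
# sign sequence per valid expression; this shows how one sequence evaluates to a target number.
def _sketch_evaluate_signs(numbers, signs):
    sign_values = {0: 0, 1: 1, 2: -1}
    return sum(sign_values[sign] * number for number, sign in zip(numbers, signs))

# Example: passage numbers [100, 25, 3] with signs [1, 2, 0] -> 100 - 25 = 75.
assert _sketch_evaluate_signs([100, 25, 3], [1, 2, 0]) == 75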
def text_to_instance(self,
                     question_text: str,
                     passage_text: str,
                     passage_tokens: List[Token],
                     numbers_in_passage: List[Any],
                     number_words: List[str],
                     number_indices: List[int],
                     number_len: List[int],
                     question_id: str = None,
                     passage_id: str = None,
                     answer_annotations: List[Dict] = None,
                     specific_answer_type: str = None) -> Optional[Instance]:
    # Tokenize question and passage
    '''
    ###
    all_number_in_qp_tokens = [qp_tokens[idx] for idx in number_indices]
    unit_tokens = self.tokenizer.tokenize(answer_annotations[0]['unit'])
    valid_unit_spans = DropReader.find_valid_spans(question_tokens, [answer_annotations[0]['unit']])
    assert len(valid_unit_spans) == 1
    ### index + 1 since there is an CLS token at the front
    valid_unit_spans = [(valid_unit_spans[0][0]+1, valid_unit_spans[0][1]+1)]
    '''
    question_tokens = self.tokenizer.tokenize(question_text)
    question_tokens = fill_token_indices(question_tokens, question_text, self._uncased, self.basic_tokenizer)
    qlen = len(question_tokens)

    qp_tokens = [Token('[CLS]')] + question_tokens + [Token('[SEP]')] + passage_tokens

    # if qp has more than max_pieces tokens (including [CLS] and [SEP]), clip the passage
    max_passage_length = -1
    if len(qp_tokens) > self.max_pieces - 1:
        qp_tokens = qp_tokens[:self.max_pieces - 1]
        passage_tokens = passage_tokens[:self.max_pieces - qlen - 3]
        plen = len(passage_tokens)
        number_indices, number_len, numbers_in_passage = \
            clipped_passage_num(number_indices, number_len, numbers_in_passage, plen)
        max_passage_length = token_to_span(passage_tokens[-1])[1] if plen > 0 else 0

    qp_tokens += [Token('[SEP]')]
    # update the indices of the numbers with respect to the question.
    # Not done in-place so they won't change the numbers saved for the passage
    number_indices = [index + qlen + 2 for index in number_indices] + [-1]
    number_len = number_len + [1]
    numbers_in_passage = numbers_in_passage + [0]
    number_tokens = [Token(str(number)) for number in numbers_in_passage]
    extra_number_tokens = [Token(str(num)) for num in self.extra_numbers]

    mask_indices = [0, qlen + 1, len(qp_tokens) - 1]

    fields: Dict[str, Field] = {}

    # Add feature fields
    qp_field = TextField(qp_tokens, self.token_indexers)
    fields["question_passage"] = qp_field

    number_token_indices = \
        [ArrayField(np.arange(start_ind, start_ind + number_len[i]), padding_value=-1)
         for i, start_ind in enumerate(number_indices)]
    fields["number_indices"] = ListField(number_token_indices)
    numbers_in_passage_field = TextField(number_tokens, self.token_indexers)
    extra_numbers_field = TextField(extra_number_tokens, self.token_indexers)
    mask_index_fields: List[Field] = [IndexField(index, qp_field) for index in mask_indices]
    fields["mask_indices"] = ListField(mask_index_fields)

    # Compile question, passage, answer metadata
    metadata = {"original_passage": passage_text,
                "original_question": question_text,
                "original_numbers": numbers_in_passage,
                "original_number_words": number_words,
                "extra_numbers": self.extra_numbers,
                "passage_tokens": passage_tokens,
                "question_tokens": question_tokens,
                "question_passage_tokens": qp_tokens,
                "passage_id": passage_id,
                "question_id": question_id,
                "max_passage_length": max_passage_length}

    # in a word broken up into pieces, every piece except the first should be ignored when calculating the loss
    wordpiece_mask = [not token.text.startswith('##') for token in qp_tokens]
    wordpiece_mask = np.array(wordpiece_mask)
    fields['bio_wordpiece_mask'] = ArrayField(wordpiece_mask, dtype=np.int64)

    if answer_annotations:
        # Get answer type, answer text, tokenize
        # For multi-span, remove repeating answers. Although possible, in the dataset it is mostly mistakes.
        if answer_annotations[0]['yesno']:
            answer_type = YESNO_ANSER_TYPE
            answer_texts = 'true' if answer_annotations[0]['yesno'] == '1' else 'false'
        else:
            answer_type, answer_texts = DropReader.extract_answer_info_from_annotation(answer_annotations[0])
            if answer_type == SPAN_ANSWER_TYPE:
                answer_texts = list(OrderedDict.fromkeys(answer_texts))

        tokenized_answer_texts = []
        for answer_text in answer_texts:
            answer_tokens = self.tokenizer.tokenize(answer_text)
            tokenized_answer_text = ' '.join(token.text for token in answer_tokens)
            if tokenized_answer_text not in tokenized_answer_texts and tokenized_answer_text != '':
                tokenized_answer_texts.append(tokenized_answer_text)

        metadata["answer_annotations"] = answer_annotations
        metadata["answer_texts"] = answer_texts
        metadata["answer_tokens"] = tokenized_answer_texts

        # Find unit text in question
        # import pdb; pdb.set_trace()
        if answer_annotations[0]['unit'] != '':
            # print('answer_annotations[0][unit] = ' + str(answer_annotations[0]['unit']))
            valid_unit_spans = DropReader.find_valid_spans(question_tokens, [answer_annotations[0]['unit']])
            ## assert len(valid_unit_spans) <= 1
            ### index + 1 since there is an CLS token at the front
            valid_unit_spans = [(unit_span[0] + 1, unit_span[1] + 1) for unit_span in valid_unit_spans]
        else:
            valid_unit_spans = []

        # Find answer text in question and passage
        # if len(tokenized_answer_texts) == 1 and tokenized_answer_texts[0] == '':
        #     import pdb; pdb.set_trace()
        valid_question_spans = DropReader.find_valid_spans(question_tokens, tokenized_answer_texts)
        for span_ind, span in enumerate(valid_question_spans):
            valid_question_spans[span_ind] = (span[0] + 1, span[1] + 1)
        valid_passage_spans = DropReader.find_valid_spans(passage_tokens, tokenized_answer_texts)
        for span_ind, span in enumerate(valid_passage_spans):
            valid_passage_spans[span_ind] = (span[0] + qlen + 2, span[1] + qlen + 2)

        # throw away an instance in training if a span appearing in the answer is missing from the question and passage
        if self._is_training:
            if specific_answer_type in SPAN_ANSWER_TYPES:
                for tokenized_answer_text in tokenized_answer_texts:
                    temp_spans = DropReader.find_valid_spans(qp_field, [tokenized_answer_text])
                    if len(temp_spans) == 0:
                        return None

        # Get target numbers
        target_numbers = []
        if specific_answer_type != MULTIPLE_SPAN or self.multispan_allow_all_heads_to_answer:
            for answer_text in answer_texts:
                number = self.word_to_num(answer_text, self.improve_number_extraction)
                if number is not None:
                    target_numbers.append(number)

        # Get possible ways to arrive at target numbers with add/sub
        valid_expressions: List[List[int]] = []
        exp_strings = None
        if answer_type in ["number", "date"]:
            if self.target_number_rounding:
                valid_expressions = \
                    find_valid_add_sub_expressions_with_rounding(
                        self.extra_numbers + numbers_in_passage,
                        target_numbers,
                        self.max_numbers_expression)
            else:
                valid_expressions = \
                    DropReader.find_valid_add_sub_expressions(self.extra_numbers + numbers_in_passage,
                                                              target_numbers,
                                                              self.max_numbers_expression)
            if len(target_numbers) == 0:
                import pdb; pdb.set_trace()
            if self.discard_impossible_number_questions:
                # The train set was verified to have all of its target_numbers lists of length 1.
                if (answer_type == "number" and
                        len(valid_expressions) == 0 and
                        self._is_training and
                        self.max_count < target_numbers[0]):
                    # The number to predict can't be derived from any head, so we shouldn't train on it:
                    # arithmetic - no expressions that yield the number to predict;
                    # counting - the maximal count is smaller than the number to predict.

                    # However, although the answer is marked in the dataset as a number type answer,
                    # maybe it cannot be found due to a bug in DROP's text parsing.
                    # So in addition, we try to find the answer as a span in the text.
                    # If the answer is indeed a span in the text, we don't discard that question.
                    if len(valid_question_spans) == 0 and len(valid_passage_spans) == 0:
                        return None
                    if not self.keep_impossible_number_questions_which_exist_as_spans:
                        return None

        # Get possible ways to arrive at target numbers with counting
        valid_counts: List[int] = []
        if answer_type in ["number"]:
            numbers_for_count = list(range(self.max_count + 1))
            valid_counts = DropReader.find_valid_counts(numbers_for_count, target_numbers)

        valid_yesno: int = -1
        if answer_type in ["yesno"]:
            valid_yesno = 1 if answer_texts == 'true' else 0

        # Update metadata with answer info
        answer_info = {"answer_passage_spans": valid_passage_spans,
                       "answer_question_spans": valid_question_spans,
                       "expressions": valid_expressions,
                       "counts": valid_counts,
                       "unit": valid_unit_spans,
                       "yesno": valid_yesno}
        metadata["answer_info"] = answer_info

        # Add answer fields
        passage_span_fields: List[Field] = []
        if specific_answer_type != MULTIPLE_SPAN or self.multispan_allow_all_heads_to_answer:
            passage_span_fields = [SpanField(span[0], span[1], qp_field) for span in valid_passage_spans]
        if not passage_span_fields:
            passage_span_fields.append(SpanField(-1, -1, qp_field))
        fields["answer_as_passage_spans"] = ListField(passage_span_fields)

        question_span_fields: List[Field] = []
        if specific_answer_type != MULTIPLE_SPAN or self.multispan_allow_all_heads_to_answer:
            question_span_fields = [SpanField(span[0], span[1], qp_field) for span in valid_question_spans]
        if not question_span_fields:
            question_span_fields.append(SpanField(-1, -1, qp_field))
        fields["answer_as_question_spans"] = ListField(question_span_fields)

        add_sub_signs_field: List[Field] = []
        extra_signs_field: List[Field] = []
        for signs_for_one_add_sub_expressions in valid_expressions:
            extra_signs = signs_for_one_add_sub_expressions[:len(self.extra_numbers)]
            normal_signs = signs_for_one_add_sub_expressions[len(self.extra_numbers):]
            add_sub_signs_field.append(SequenceLabelField(normal_signs, numbers_in_passage_field))
            extra_signs_field.append(SequenceLabelField(extra_signs, extra_numbers_field))
        if not add_sub_signs_field:
            add_sub_signs_field.append(SequenceLabelField([0] * len(number_tokens), numbers_in_passage_field))
        if not extra_signs_field:
            extra_signs_field.append(SequenceLabelField([0] * len(self.extra_numbers), extra_numbers_field))
        fields["answer_as_expressions"] = ListField(add_sub_signs_field)
        if self.extra_numbers:
            fields["answer_as_expressions_extra"] = ListField(extra_signs_field)

        # Add unit field
        unit_span_fields: List[Field] = [SpanField(span[0], span[1], qp_field) for span in valid_unit_spans]
        if not unit_span_fields:
            unit_span_fields.append(SpanField(-1, -1, qp_field))
        fields["answer_as_unit_spans"] = ListField(unit_span_fields)

        count_fields: List[Field] = [LabelField(count_label, skip_indexing=True) for count_label in valid_counts]
        if not count_fields:
            count_fields.append(LabelField(-1, skip_indexing=True))
        fields["answer_as_counts"] = ListField(count_fields)

        yesno_field: List[Field] = [LabelField(valid_yesno, skip_indexing=True)]
        fields["answer_as_yesno"] = ListField(yesno_field)

        no_answer_bios = SequenceLabelField([0] * len(qp_tokens), sequence_field=qp_field)
        if (specific_answer_type in self.bio_types) and \
                (len(valid_passage_spans) > 0 or len(valid_question_spans) > 0):
            # Used for flexible BIO loss
            # START

            spans_dict = {}
            text_to_disjoint_bios: List[ListField] = []
            flexibility_count = 1
            for tokenized_answer_text in tokenized_answer_texts:
                spans = DropReader.find_valid_spans(qp_tokens, [tokenized_answer_text])
                if len(spans) == 0:
                    # possible if the passage was clipped, but not for all of the answers
                    continue
                spans_dict[tokenized_answer_text] = spans

                disjoint_bios: List[SequenceLabelField] = []
                for span_ind, span in enumerate(spans):
                    bios = create_bio_labels([span], len(qp_field))
                    disjoint_bios.append(SequenceLabelField(bios, sequence_field=qp_field))

                text_to_disjoint_bios.append(ListField(disjoint_bios))
                flexibility_count *= ((2 ** len(spans)) - 1)

            fields["answer_as_text_to_disjoint_bios"] = ListField(text_to_disjoint_bios)

            if (flexibility_count < self.flexibility_threshold):
                # generate all non-empty span combinations per each text
                spans_combinations_dict = {}
                for key, spans in spans_dict.items():
                    spans_combinations_dict[key] = all_combinations = []
                    for i in range(1, len(spans) + 1):
                        all_combinations += list(itertools.combinations(spans, i))

                # calculate product between all the combinations per each text
                packed_gold_spans_list = itertools.product(*list(spans_combinations_dict.values()))
                bios_list: List[SequenceLabelField] = []
                for packed_gold_spans in packed_gold_spans_list:
                    gold_spans = [s for sublist in packed_gold_spans for s in sublist]
                    bios = create_bio_labels(gold_spans, len(qp_field))
                    bios_list.append(SequenceLabelField(bios, sequence_field=qp_field))

                fields["answer_as_list_of_bios"] = ListField(bios_list)
                fields["answer_as_text_to_disjoint_bios"] = ListField([ListField([no_answer_bios])])
            else:
                fields["answer_as_list_of_bios"] = ListField([no_answer_bios])

            # END

            # Used for both "require-all" BIO loss and flexible loss
            bio_labels = create_bio_labels(valid_question_spans + valid_passage_spans, len(qp_field))
            fields['span_bio_labels'] = SequenceLabelField(bio_labels, sequence_field=qp_field)

            fields["is_bio_mask"] = LabelField(1, skip_indexing=True)
        else:
            fields["answer_as_text_to_disjoint_bios"] = ListField([ListField([no_answer_bios])])
            fields["answer_as_list_of_bios"] = ListField([no_answer_bios])

            # create all 'O' BIO labels for non-span questions
            fields['span_bio_labels'] = no_answer_bios

            fields["is_bio_mask"] = LabelField(0, skip_indexing=True)

    fields["metadata"] = MetadataField(metadata)

    return Instance(fields)
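# A minimal sketch (assumed behavior, not the project's implementation) of create_bio_labels
# used above: token-level tags with 1 = B (span start), 2 = I (inside), 0 = O, over inclusive
# (start, end) spans of the [CLS] question [SEP] passage [SEP] sequence.
def _sketch_create_bio_labels(spans, n_labels):
    labels = [0] * n_labels
    for start, end in spans:            # spans are inclusive on both ends
        labels[start] = 1               # B
        for i in range(start + 1, end + 1):
            labels[i] = 2               # I
    return labels

# Example: two gold spans in a 10-token sequence.
assert _sketch_create_bio_labels([(2, 3), (6, 6)], 10) == [0, 0, 1, 2, 0, 0, 1, 0, 0, 0]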
def augment(self, passage_id, question_id, passage_text, answer_texts):
    # Valid for augmentation
    if self._augmentation_ratio < 1 and random.uniform(0, 1) > self._augmentation_ratio:
        return []

    # Only multi span questions are augmented
    if answer_texts is None or len(answer_texts) <= 1:
        return []

    if passage_id == self._cached_passage_id:
        passage_tags = self._cached_passage_tags
        reconstructed_passage_text = self._reconstructed_passage_text
    else:
        passage_tags = self._tagger.predict_json({"sentence": passage_text})
        passage_tags['idxs'] = []

        temp_passage = passage_text
        reconstructed_passage_text = ''
        absolute_index = 0
        for i, word in enumerate(passage_tags['words']):
            tag = passage_tags['tags'][i]

            relative_index = temp_passage.index(word)
            first_part = temp_passage[:relative_index]
            second_part = word
            reconstructed_passage_text += first_part + second_part

            start_idx = absolute_index + relative_index
            end_idx = start_idx + len(word)  # exclusive
            passage_tags['idxs'].append((start_idx, end_idx))

            absolute_index = end_idx
            temp_passage = temp_passage[relative_index + len(word):]

        self._cached_passage_id = passage_id
        self._cached_passage_tags = passage_tags
        self._reconstructed_passage_text = reconstructed_passage_text

    if reconstructed_passage_text != passage_text:
        return []

    # Don't try augmentation if there are shared words between the answers or repeating words in an answer
    words_set = set()
    words_count = 0
    for answer_text in answer_texts:
        words = answer_text.split()
        words_set.update(words)
        words_count += len(words)
    if len(words_set) != words_count:
        return []

    # Validations
    spans = DropReader.find_valid_spans([Token(word) for word in passage_tags['words']], answer_texts)

    # Validate each answer has a span (tokenizing here is by the tagger so we can't assume
    # our fixes to the data are enough)
    for answer_text in answer_texts:
        answer_has_tag = False
        for span in spans:
            span_text = ' '.join(passage_tags['words'][span[0]:span[1] + 1]).lower()
            if answer_text.lower() == span_text:
                answer_has_tag = True
                break
        if not answer_has_tag:
            return []

    # Validate all spans have PER tags and that there is no span with PER immediately before
    # or after, to avoid replacement of partial names
    for span in spans:
        answer_tags = passage_tags['tags'][span[0]:span[1] + 1]
        span_length = len(answer_tags)

        # if not all(tag.endswith('PER') for tag in answer_tags):
        # Should we enforce a proper BILOU tagging sequence?
        if not self.is_valid_BILOU(answer_tags):
            return []

        # Check if we need it
        '''if span[0] > 0 and passage_tags['tags'][span[0] - 1].endswith('PER'):
            return []
        if span[1] < len(passage_tags) - 1 and passage_tags['tags'][span[1] + 1].endswith('PER'):
            return []'''

    # Heavier stuff, after fast validations
    new_passage_text = passage_text
    PLACEHOLDER_SYMBOL = "#"
    pending_swaps = []
    swaps_mapping = [(i + 1) % len(answer_texts) for i in range(len(answer_texts))]
    subs_per_answer_text = [get_all_subsequences(answer_text.split()) for answer_text in answer_texts]

    for i, answer_text in enumerate(answer_texts):
        replacer_answer_index = swaps_mapping[i]

        for sub in subs_per_answer_text[i]:
            spans_per_sub = DropReader.find_valid_spans([Token(word) for word in passage_tags['words']], [sub])

            relative_start_idx = answer_text.index(sub)
            relative_end_idx = relative_start_idx + len(sub)  # exclusive

            if (len(spans_per_sub) > 0) and answer_text.index(sub) != 0 and (relative_end_idx != len(answer_text)):
                # We have a span that is from the middle of the answer. Too ambiguous, ignore question
                return []

            partition = 1
            if relative_start_idx == 0 and relative_end_idx == len(answer_text):
                partition = -1
            elif relative_start_idx == 0:
                partition = 0

            for span in spans_per_sub:
                answer_tags = passage_tags['tags'][span[0]:span[1] + 1]

                # if not all(tag.endswith('PER') for tag in answer_tags):
                # Should we enforce a proper BILOU tagging sequence?
                if not self.is_valid_BILOU(answer_tags):
                    continue

                first_pard_end = passage_tags['idxs'][span[0]][0]
                last_part_start = passage_tags['idxs'][span[1]][1]

                pending_swaps.append({
                    'replacer_answer_index': replacer_answer_index,
                    'partition': partition,
                    'first_pard_end': first_pard_end,
                    'last_part_start': last_part_start
                })

                first_part = new_passage_text[:passage_tags['idxs'][span[0]][0]]
                last_part = new_passage_text[passage_tags['idxs'][span[1]][1]:]
                new_passage_text = first_part + (PLACEHOLDER_SYMBOL) * (last_part_start - first_pard_end) + last_part

    for swap in sorted(pending_swaps, key=lambda x: x['first_pard_end'], reverse=True):
        replacer_answer_index = swap['replacer_answer_index']
        partition = swap['partition']
        first_pard_end = swap['first_pard_end']
        last_part_start = swap['last_part_start']

        replacer_answer = answer_texts[replacer_answer_index]
        if partition == -1:
            replacer = replacer_answer
        else:
            answer_parts = replacer_answer.split()
            if len(answer_parts) > 1:
                replacer = answer_parts[0] if partition == 0 else ' '.join(answer_parts[1:])
            else:
                replacer = replacer_answer

        first_part = new_passage_text[:first_pard_end]
        last_part = new_passage_text[last_part_start:]
        new_passage_text = first_part + replacer + last_part

    return [new_passage_text]
def clean(self, passage, question, answer, passage_tagging, question_tagging):
    passage_tokens = [Token(w) for w in passage_tagging['words']]
    spans = DropReader.find_valid_spans(passage_tokens, answer['spans'])

    if not spans:
        return None

    new_answer_texts = []
    cleaned = False
    for answer_text in answer['spans']:
        if len(answer_text.split()) <= 1:
            continue

        new_answer_text = answer_text
        for span in spans:
            span_text = ' '.join(passage_tagging['words'][span[0]:span[1] + 1]).lower()
            if answer_text.lower() != span_text:
                continue

            span_tags = passage_tagging['tags'][span[0]:span[1] + 1]
            count_o = sum(tag == 'O' for tag in span_tags)
            other_than_o = len(span_tags) - count_o
            if count_o == 0 or other_than_o == 0:
                break

            tags_to_trim = ['O'] if count_o <= other_than_o else ['ORG', 'LOC', 'PER', 'MISC']

            # Remove words only from the start and from the end to keep a valid span
            span_words = passage_tagging['words'][span[0]:span[1] + 1]
            words_count = len(span_words)

            remove_from_start = True
            remove_from_end = True
            for i in range(words_count):
                if remove_from_start and all(not span_tags[i].endswith(tag) for tag in tags_to_trim):
                    remove_from_start = False
                if remove_from_end and all(not span_tags[words_count - i - 1].endswith(tag)
                                           for tag in tags_to_trim):
                    remove_from_end = False

                if remove_from_start:
                    del span_words[0]
                if remove_from_end:
                    del span_words[-1]

                if not remove_from_end and not remove_from_start:
                    break

            new_answer_text = ' '.join(span_words)
            cleaned = True
            break

        new_answer_texts.append(new_answer_text)

    if not cleaned:
        return None

    new_answer = answer.copy()
    new_answer['spans'] = new_answer_texts
    return {'answer': new_answer}
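# A small illustration of the trimming rule above, using hypothetical tagger output: whichever
# tag family ('O' vs. named-entity tags) is in the minority gets trimmed, and only from the
# ends of the span so the remaining answer is still a contiguous span of the passage.
#
#   span_words = ['quarterback', 'Peyton', 'Manning']
#   span_tags  = ['O', 'B-PER', 'L-PER']     # one 'O' vs. two PER tags -> trim the 'O' side
#
# After the trimming loop in clean(), the leading 'O' word is dropped and
# new_answer_text == 'Peyton Manning'.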