def get_answer_fields(
            self, **kwargs: Any) -> Tuple[Dict[str, Field], bool]:
        answer_texts: List[str] = kwargs['answer_texts']

        fields: Dict[str, Field] = {}

        target_numbers = get_target_numbers(answer_texts)

        numbers_for_count = list(range(self._max_count + 1))
        valid_counts: List[int] = DropReader.find_valid_counts(
            numbers_for_count, target_numbers)

        if len(valid_counts) > 0:
            has_answer = True

            counts_field: List[Field] = [
                LabelField(count_label, skip_indexing=True)
                for count_label in valid_counts
            ]
            fields['answer_as_counts'] = ListField(counts_field)
        else:
            has_answer = False
            fields.update(self.get_empty_answer_fields(**kwargs))

        return fields, has_answer
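
# Hedged sketch (not the original implementation): get_target_numbers is
# assumed to parse each answer string into a number when possible, so that
# DropReader.find_valid_counts can match candidate counts against them.
def get_target_numbers_sketch(answer_texts):
    targets = []
    for text in answer_texts:
        try:
            targets.append(float(text))
        except ValueError:
            # non-numeric answers contribute no count targets
            pass
    return targets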
Example #2
import json

# Relies on clean_string, most_frequent, and DropReader being defined or
# imported elsewhere in this module.
def load_drop_data():
    DROP_DEV_FILE = '/home/tony/answer-generation/raw_data/drop/drop_dataset_dev.json'

    data = {}

    with open(DROP_DEV_FILE) as dataset_file:
        dataset = json.load(dataset_file)

    for passage_id, passage_info in dataset.items():
        passage = clean_string(passage_info["passage"])
        for question_answer in passage_info["qa_pairs"]:
            question_id = question_answer["query_id"]
            question = clean_string(question_answer["question"])
            answer_annotations = []
            if "answer" in question_answer:
                answer_annotations.append(question_answer["answer"])
            if "validated_answers" in question_answer:
                answer_annotations += question_answer["validated_answers"]

            # Extract out the label per annotation
            answer_annotations = [
                ' '.join(DropReader.extract_answer_info_from_annotation(a)[1])
                for a in answer_annotations
            ]

            # Get the most common answer as the gold answer
            reference = clean_string(str(most_frequent(answer_annotations)))

            data[question_id] = {
                'context': passage,
                'question': question,
                'reference': reference,
                'candidates': set()
            }

    return data
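
# Hedged sketch of the most_frequent helper used above (assumed behavior:
# return the most common annotation string, ties broken arbitrarily):
from collections import Counter

def most_frequent_sketch(items):
    return Counter(items).most_common(1)[0][0]

# Example usage, assuming the DROP dev file exists at the hard-coded path:
# data = load_drop_data()
# print(len(data), 'questions loaded')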
    # Method of a cleaner class (context not shown); it drops answer spans
    # whose passage match overlaps a named-entity tag.
    def clean(self, passage, question, answer, passage_tagging,
              question_tagging):
        passage_tokens = [Token(w) for w in passage_tagging['words']]
        spans = DropReader.find_valid_spans(passage_tokens, answer['spans'])

        new_answer_texts = []

        cleaned = False

        for answer_text in answer['spans']:
            valid = True

            for span in spans:
                span_text = ' '.join(
                    passage_tagging['words'][span[0]:span[1] + 1]).lower()

                if answer_text.lower() != span_text:
                    continue

                # any non-'O' tag means the span overlaps a named entity
                if any(tag != 'O' for tag in
                       passage_tagging['tags'][span[0]:span[1] + 1]):
                    valid = False
                    cleaned = True
                    break

            if valid:
                new_answer_texts.append(answer_text)

        if not cleaned:
            return None

        new_answer = answer.copy()
        new_answer['spans'] = new_answer_texts

        return {'answer': new_answer}
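
# Hedged usage sketch: the tagger output is assumed to follow the AllenNLP
# NER predictor format ('words' plus BILOU 'tags'). With input like
example_tagging = {
    'words': ['John', 'Smith', 'scored', 'twice'],
    'tags': ['B-PER', 'L-PER', 'O', 'O'],
}
example_answer = {'spans': ['John Smith', 'twice']}
# clean(...) would drop 'John Smith' (its matched span carries PER tags),
# keep 'twice', and return {'answer': {'spans': ['twice']}}.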
Example #4
    def text_to_instance(
            self,
            question_text: str,
            passage_text: str,
            passage_tokens: List[Token],
            passage_spans: List[Tuple[int, int]],
            numbers_in_passage: List[Any],
            number_words: List[str],
            number_indices: List[int],
            number_len: List[int],
            question_id: Optional[str] = None,
            passage_id: Optional[str] = None,
            answer_annotations: Optional[List[Dict]] = None) -> Optional[Instance]:
        # Tokenize question and passage
        question_tokens = self.tokenizer.tokenize(question_text)
        qlen = len(question_tokens)
        plen = len(passage_tokens)

        question_passage_tokens = [Token('[CLS]')] + question_tokens + [
            Token('[SEP]')
        ] + passage_tokens
        if len(question_passage_tokens) > self.max_pieces - 1:
            question_passage_tokens = question_passage_tokens[:self.max_pieces - 1]
            passage_tokens = passage_tokens[:self.max_pieces - qlen - 3]
            plen = len(passage_tokens)
            number_indices, number_len, numbers_in_passage = \
                clipped_passage_num(number_indices, number_len, numbers_in_passage, plen)

        question_passage_tokens += [Token('[SEP]')]
        number_indices = [index + qlen + 2 for index in number_indices] + [-1]
        # Not done in-place so they won't change the numbers saved for the passage
        number_len = number_len + [1]
        numbers_in_passage = numbers_in_passage + [0]
        number_tokens = [Token(str(number)) for number in numbers_in_passage]
        extra_number_tokens = [Token(str(num)) for num in self.extra_numbers]

        mask_indices = [0, qlen + 1, len(question_passage_tokens) - 1]

        if self.extract_spans:
            # adapt indexes to question_passage_tokens sequence
            passage_spans = [(span[0] + qlen + 2, span[1] + qlen + 2)
                             for span in passage_spans]
            # remove spans of truncated part of passage
            passage_spans = [
                span for span in passage_spans
                if span[1] <= len(question_passage_tokens)
            ]
            # make span indexes inclusive
            passage_spans = [(span[0], span[1] - 1) for span in passage_spans]

        fields: Dict[str, Field] = {}

        # Add feature fields
        question_passage_field = TextField(question_passage_tokens,
                                           self.token_indexers)
        fields["question_passage"] = question_passage_field

        if self.extract_spans:
            passage_span_fields = [
                SpanField(span[0], span[1], question_passage_field)
                for span in passage_spans
            ]
            fields["passage_spans"] = ListField(passage_span_fields)

        number_token_indices = \
            [ArrayField(np.arange(start_ind, start_ind + number_len[i]), padding_value=-1)
             for i, start_ind in enumerate(number_indices)]
        fields["number_indices"] = ListField(number_token_indices)
        numbers_in_passage_field = TextField(number_tokens,
                                             self.token_indexers)
        extra_numbers_field = TextField(extra_number_tokens,
                                        self.token_indexers)
        all_numbers_field = TextField(extra_number_tokens + number_tokens,
                                      self.token_indexers)
        mask_index_fields: List[Field] = [
            IndexField(index, question_passage_field) for index in mask_indices
        ]
        fields["mask_indices"] = ListField(mask_index_fields)

        # Compile question, passage, answer metadata
        metadata = {
            "original_passage": passage_text,
            "original_question": question_text,
            "original_numbers": numbers_in_passage,
            "original_number_words": number_words,
            "extra_numbers": self.extra_numbers,
            "passage_tokens": passage_tokens,
            "question_tokens": question_tokens,
            "question_passage_tokens": question_passage_tokens,
            "passage_id": passage_id,
            "question_id": question_id
        }

        if answer_annotations:
            for annotation in answer_annotations:
                tokenized_spans = [[
                    token.text for token in self.tokenizer.tokenize(answer)
                ] for answer in annotation['spans']]
                annotation['spans'] = [
                    tokenlist_to_passage(token_list)
                    for token_list in tokenized_spans
                ]

            # Get answer type, answer text, tokenize
            answer_type, answer_texts = DropReader.extract_answer_info_from_annotation(
                answer_annotations[0])
            tokenized_answer_texts = []
            num_spans = min(len(answer_texts), self.max_spans)
            for answer_text in answer_texts:
                answer_tokens = self.tokenizer.tokenize(answer_text)
                tokenized_answer_texts.append(' '.join(
                    token.text for token in answer_tokens))

            metadata["answer_annotations"] = answer_annotations
            metadata["answer_texts"] = answer_texts
            metadata["answer_tokens"] = tokenized_answer_texts

            # Find answer text in question and passage
            valid_question_spans = DropReader.find_valid_spans(
                question_tokens, tokenized_answer_texts)
            for span_ind, span in enumerate(valid_question_spans):
                valid_question_spans[span_ind] = (span[0] + 1, span[1] + 1)
            valid_passage_spans = DropReader.find_valid_spans(
                passage_tokens, tokenized_answer_texts)
            for span_ind, span in enumerate(valid_passage_spans):
                valid_passage_spans[span_ind] = (span[0] + qlen + 2,
                                                 span[1] + qlen + 2)

            # Get target numbers
            target_numbers = []
            for answer_text in answer_texts:
                number = self.word_to_num(answer_text)
                if number is not None:
                    target_numbers.append(number)

            # Get possible ways to arrive at target numbers with add/sub

            valid_expressions: List[List[int]] = []
            exp_strings = None
            if answer_type in ["number", "date"]:
                if self.exp_search == 'full':
                    expressions = get_full_exp(
                        list(enumerate(self.extra_numbers +
                                       numbers_in_passage)), target_numbers,
                        self.operations, self.op_dict, self.max_depth)
                    zipped = list(zip(*expressions))
                    if zipped:
                        valid_expressions = list(zipped[0])
                        exp_strings = list(zipped[1])
                elif self.exp_search == 'add_sub':
                    valid_expressions = \
                        DropReader.find_valid_add_sub_expressions(self.extra_numbers + numbers_in_passage,
                                                                  target_numbers,
                                                                  self.max_numbers_expression)
                elif self.exp_search == 'template':
                    valid_expressions, exp_strings = \
                        get_template_exp(self.extra_numbers + numbers_in_passage,
                                         target_numbers,
                                         self.templates,
                                         self.template_strings)
                    exp_strings = sum(exp_strings, [])

            # Get possible ways to arrive at target numbers with counting
            valid_counts: List[int] = []
            if answer_type in ["number"]:
                numbers_for_count = list(range(self.max_count + 1))
                valid_counts = DropReader.find_valid_counts(
                    numbers_for_count, target_numbers)

            # Update metadata with answer info
            answer_info = {
                "answer_passage_spans": valid_passage_spans,
                "answer_question_spans": valid_question_spans,
                "num_spans": num_spans,
                "expressions": valid_expressions,
                "counts": valid_counts
            }
            if self.exp_search in ['template', 'full']:
                answer_info['expr_text'] = exp_strings
            metadata["answer_info"] = answer_info

            # Add answer fields
            passage_span_fields: List[Field] = [
                SpanField(span[0], span[1], question_passage_field)
                for span in valid_passage_spans
            ]
            if not passage_span_fields:
                passage_span_fields.append(
                    SpanField(-1, -1, question_passage_field))
            fields["answer_as_passage_spans"] = ListField(passage_span_fields)

            question_span_fields: List[Field] = [
                SpanField(span[0], span[1], question_passage_field)
                for span in valid_question_spans
            ]
            if not question_span_fields:
                question_span_fields.append(
                    SpanField(-1, -1, question_passage_field))
            fields["answer_as_question_spans"] = ListField(
                question_span_fields)

            if self.exp_search == 'add_sub':
                add_sub_signs_field: List[Field] = []
                extra_signs_field: List[Field] = []
                for signs_for_one_add_sub_expressions in valid_expressions:
                    extra_signs = signs_for_one_add_sub_expressions[:len(self.extra_numbers)]
                    normal_signs = signs_for_one_add_sub_expressions[len(self.extra_numbers):]
                    add_sub_signs_field.append(
                        SequenceLabelField(normal_signs,
                                           numbers_in_passage_field))
                    extra_signs_field.append(
                        SequenceLabelField(extra_signs, extra_numbers_field))
                if not add_sub_signs_field:
                    add_sub_signs_field.append(
                        SequenceLabelField([0] * len(number_tokens),
                                           numbers_in_passage_field))
                if not extra_signs_field:
                    extra_signs_field.append(
                        SequenceLabelField([0] * len(self.extra_numbers),
                                           extra_numbers_field))
                fields["answer_as_expressions"] = ListField(
                    add_sub_signs_field)
                if self.extra_numbers:
                    fields["answer_as_expressions_extra"] = ListField(
                        extra_signs_field)
            elif self.exp_search in ['template', 'full']:
                expression_indices = []
                for expression in valid_expressions:
                    if not expression:
                        expression.append(3 * [-1])
                    expression_indices.append(
                        ArrayField(np.array(expression), padding_value=-1))
                if not expression_indices:
                    expression_indices = \
                        [ArrayField(np.array([3 * [-1]]), padding_value=-1) for _ in range(len(self.templates))]
                fields["answer_as_expressions"] = ListField(expression_indices)

            count_fields: List[Field] = [
                LabelField(count_label, skip_indexing=True)
                for count_label in valid_counts
            ]
            if not count_fields:
                count_fields.append(LabelField(-1, skip_indexing=True))
            fields["answer_as_counts"] = ListField(count_fields)

            fields["num_spans"] = LabelField(num_spans, skip_indexing=True)

        fields["metadata"] = MetadataField(metadata)

        return Instance(fields)
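
# Hedged sketch of clipped_passage_num, called after truncation above
# (assumed behavior: keep only the numbers whose start index survives the
# clip, shortening the last kept number if its wordpieces were cut):
def clipped_passage_num_sketch(number_indices, number_len, numbers_in_passage, plen):
    keep = sum(1 for start in number_indices if start < plen)
    number_indices = number_indices[:keep]
    number_len = number_len[:keep]
    numbers_in_passage = numbers_in_passage[:keep]
    if number_indices and number_indices[-1] + number_len[-1] > plen:
        # the last surviving number was cut mid-span; shorten it
        number_len[-1] = plen - number_indices[-1]
    return number_indices, number_len, numbers_in_passage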
    def text_to_instance(self,
                         question_text: str,
                         passage_text: str,
                         passage_tokens: List[Token],
                         numbers_in_passage: List[Any],
                         number_words: List[str],
                         number_indices: List[int],
                         number_len: List[int],
                         question_id: str = None,
                         passage_id: str = None,
                         answer_annotations: List[Dict] = None,
                         specific_answer_type: str = None) -> Optional[Instance]:
        # Tokenize question and passage

        question_tokens = self.tokenizer.tokenize(question_text)
        question_tokens = fill_token_indices(question_tokens, question_text, self._uncased, self.basic_tokenizer)

        qlen = len(question_tokens)

        qp_tokens = [Token('[CLS]')] + question_tokens + [Token('[SEP]')] + passage_tokens

        # if qp has more than max_pieces tokens (including CLS and SEP), clip the passage
        max_passage_length = -1
        if len(qp_tokens) > self.max_pieces - 1:
            qp_tokens = qp_tokens[:self.max_pieces - 1]
            passage_tokens = passage_tokens[:self.max_pieces - qlen - 3]
            plen = len(passage_tokens)
            number_indices, number_len, numbers_in_passage = \
                clipped_passage_num(number_indices, number_len, numbers_in_passage, plen)
            max_passage_length = token_to_span(passage_tokens[-1])[1] if plen > 0 else 0
        
        qp_tokens += [Token('[SEP]')]
        # update the indices of the numbers with respect to the question.
        # Not done in-place so they won't change the numbers saved for the passage
        number_indices = [index + qlen + 2 for index in number_indices] + [-1]
        number_len = number_len + [1]
        numbers_in_passage = numbers_in_passage + [0]
        number_tokens = [Token(str(number)) for number in numbers_in_passage]
        extra_number_tokens = [Token(str(num)) for num in self.extra_numbers]
        
        mask_indices = [0, qlen + 1, len(qp_tokens) - 1]
        
        fields: Dict[str, Field] = {}
            
        # Add feature fields
        qp_field = TextField(qp_tokens, self.token_indexers)
        fields["question_passage"] = qp_field
       
        number_token_indices = \
            [ArrayField(np.arange(start_ind, start_ind + number_len[i]), padding_value=-1) 
             for i, start_ind in enumerate(number_indices)]
        fields["number_indices"] = ListField(number_token_indices)
        numbers_in_passage_field = TextField(number_tokens, self.token_indexers)
        extra_numbers_field = TextField(extra_number_tokens, self.token_indexers)
        mask_index_fields: List[Field] = [IndexField(index, qp_field) for index in mask_indices]
        fields["mask_indices"] = ListField(mask_index_fields)

        # Compile question, passage, answer metadata
        metadata = {"original_passage": passage_text,
                    "original_question": question_text,
                    "original_numbers": numbers_in_passage,
                    "original_number_words": number_words,
                    "extra_numbers": self.extra_numbers,
                    "passage_tokens": passage_tokens,
                    "question_tokens": question_tokens,
                    "question_passage_tokens": qp_tokens,
                    "passage_id": passage_id,
                    "question_id": question_id,
                    "max_passage_length": max_passage_length}

        # in a word broken up into pieces, every piece except the first should be ignored when calculating the loss
        wordpiece_mask = [not token.text.startswith('##') for token in qp_tokens]
        wordpiece_mask = np.array(wordpiece_mask)
        fields['bio_wordpiece_mask'] = ArrayField(wordpiece_mask, dtype=np.int64)

        if answer_annotations:            
            # Get answer type, answer text, tokenize
            # For multi-span, remove repeating answers. Although possible, in the dataset it is mostly mistakes.
            if answer_annotations[0]['yesno']:
                answer_type = YESNO_ANSER_TYPE
                # wrap in a list so the tokenization loop below iterates whole
                # answers instead of single characters
                answer_texts = ['true'] if answer_annotations[0]['yesno'] == '1' else ['false']
            else:
                answer_type, answer_texts = DropReader.extract_answer_info_from_annotation(answer_annotations[0])

            if answer_type == SPAN_ANSWER_TYPE:
                answer_texts = list(OrderedDict.fromkeys(answer_texts))
            tokenized_answer_texts = []
            for answer_text in answer_texts:
                answer_tokens = self.tokenizer.tokenize(answer_text)
                tokenized_answer_text = ' '.join(token.text for token in answer_tokens)
                if tokenized_answer_text not in tokenized_answer_texts and tokenized_answer_text != '':
                    tokenized_answer_texts.append(tokenized_answer_text)

            metadata["answer_annotations"] = answer_annotations
            metadata["answer_texts"] = answer_texts
            metadata["answer_tokens"] = tokenized_answer_texts
        
            # Find unit text in question
            if answer_annotations[0]['unit'] != '':
                valid_unit_spans = DropReader.find_valid_spans(question_tokens, [answer_annotations[0]['unit']])
                # shift indices by 1 for the [CLS] token at the front
                valid_unit_spans = [(unit_span[0] + 1, unit_span[1] + 1) for unit_span in valid_unit_spans]
            else:
                valid_unit_spans = []
            # Find answer text in question and passage
            valid_question_spans = DropReader.find_valid_spans(question_tokens, tokenized_answer_texts)
            for span_ind, span in enumerate(valid_question_spans):
                valid_question_spans[span_ind] = (span[0] + 1, span[1] + 1)
            valid_passage_spans = DropReader.find_valid_spans(passage_tokens, tokenized_answer_texts)
            for span_ind, span in enumerate(valid_passage_spans):
                valid_passage_spans[span_ind] = (span[0] + qlen + 2, span[1] + qlen + 2)

            # throw away an instance in training if a span appearing in the answer is missing from the question and passage
            if self._is_training:
                if specific_answer_type in SPAN_ANSWER_TYPES:
                    for tokenized_answer_text in tokenized_answer_texts:
                        temp_spans = DropReader.find_valid_spans(qp_tokens, [tokenized_answer_text])
                        if len(temp_spans) == 0:
                            return None

            # Get target numbers
            target_numbers = []
            if specific_answer_type != MULTIPLE_SPAN or self.multispan_allow_all_heads_to_answer:
                for answer_text in answer_texts:
                    number = self.word_to_num(answer_text, self.improve_number_extraction)
                    if number is not None:
                        target_numbers.append(number)
            
            # Get possible ways to arrive at target numbers with add/sub
            valid_expressions: List[List[int]] = []
            exp_strings = None
            if answer_type in ["number", "date"]:
                if self.target_number_rounding:
                    valid_expressions = \
                        find_valid_add_sub_expressions_with_rounding(
                            self.extra_numbers + numbers_in_passage,
                            target_numbers,
                            self.max_numbers_expression)
                else:
                    valid_expressions = \
                        DropReader.find_valid_add_sub_expressions(self.extra_numbers + numbers_in_passage,
                                                                  target_numbers,
                                                                  self.max_numbers_expression)
                # target_numbers can be empty (e.g. when no answer text parses
                # as a number); guard the target_numbers[0] access below.
                if self.discard_impossible_number_questions and target_numbers:
                    # The train set was verified to have all of its target_numbers lists of length 1.
                    if (answer_type == "number" and
                            len(valid_expressions) == 0 and
                            self._is_training and
                            self.max_count < target_numbers[0]):
                        # The number to predict can't be derived from any head, so we shouldn't train on it.
                        # arithmetic - no expressions that yield the number to predict.
                        # counting - the maximal count is smaller than the number to predict.

                        # However, although the answer is marked in the dataset as a number type answer,
                        # maybe it cannot be found due to a bug in DROP's text parsing.
                        # So in addition, we try to find the answer as a span in the text.
                        # If the answer is indeed a span in the text, we don't discard that question.
                        if len(valid_question_spans) == 0 and len(valid_passage_spans) == 0:
                            return None
                        if not self.keep_impossible_number_questions_which_exist_as_spans:
                            return None

            # Get possible ways to arrive at target numbers with counting
            valid_counts: List[int] = []
            if answer_type in ["number"]:
                numbers_for_count = list(range(self.max_count + 1))
                valid_counts = DropReader.find_valid_counts(numbers_for_count, target_numbers)

            valid_yesno: int = -1
            if answer_type in ["yesno"]:
                valid_yesno = 1 if answer_texts == ['true'] else 0

            # Update metadata with answer info
            answer_info = {"answer_passage_spans": valid_passage_spans,
                           "answer_question_spans": valid_question_spans,
                           "expressions": valid_expressions,
                           "counts": valid_counts,
                           "unit": valid_unit_spans,
                           "yesno": valid_yesno}

            metadata["answer_info"] = answer_info
        
            # Add answer fields
            passage_span_fields: List[Field] = []
            if specific_answer_type != MULTIPLE_SPAN or self.multispan_allow_all_heads_to_answer:
                passage_span_fields: List[Field] = [SpanField(span[0], span[1], qp_field) for span in valid_passage_spans]
            if not passage_span_fields:
                passage_span_fields.append(SpanField(-1, -1, qp_field))
            fields["answer_as_passage_spans"] = ListField(passage_span_fields)

            question_span_fields: List[Field] = []
            if specific_answer_type != MULTIPLE_SPAN or self.multispan_allow_all_heads_to_answer:
                question_span_fields: List[Field] = [SpanField(span[0], span[1], qp_field) for span in valid_question_spans]
            if not question_span_fields:
                question_span_fields.append(SpanField(-1, -1, qp_field))
            fields["answer_as_question_spans"] = ListField(question_span_fields)
            
            add_sub_signs_field: List[Field] = []
            extra_signs_field: List[Field] = []
            for signs_for_one_add_sub_expressions in valid_expressions:
                extra_signs = signs_for_one_add_sub_expressions[:len(self.extra_numbers)]
                normal_signs = signs_for_one_add_sub_expressions[len(self.extra_numbers):]
                add_sub_signs_field.append(SequenceLabelField(normal_signs, numbers_in_passage_field))
                extra_signs_field.append(SequenceLabelField(extra_signs, extra_numbers_field))
            if not add_sub_signs_field:
                add_sub_signs_field.append(SequenceLabelField([0] * len(number_tokens), numbers_in_passage_field))
            if not extra_signs_field:
                extra_signs_field.append(SequenceLabelField([0] * len(self.extra_numbers), extra_numbers_field))
            fields["answer_as_expressions"] = ListField(add_sub_signs_field)
            if self.extra_numbers:
                fields["answer_as_expressions_extra"] = ListField(extra_signs_field)


            '''
                Add unit_field
            '''
            unit_span_fields: List[Field] = []
            unit_span_fields: List[Field] = [SpanField(span[0], span[1], qp_field) for span in valid_unit_spans]
            if not unit_span_fields:
                unit_span_fields.append(SpanField(-1, -1, qp_field))
                
            fields["answer_as_unit_spans"] = ListField(unit_span_fields)

            count_fields: List[Field] = [LabelField(count_label, skip_indexing=True) for count_label in valid_counts]
            if not count_fields:
                count_fields.append(LabelField(-1, skip_indexing=True))
            fields["answer_as_counts"] = ListField(count_fields)
            
            yesno_field: List[Field] = [LabelField(valid_yesno, skip_indexing=True)]
            fields["answer_as_yesno"] = ListField(yesno_field)
            

            no_answer_bios = SequenceLabelField([0] * len(qp_tokens), sequence_field=qp_field)
            if (specific_answer_type in self.bio_types) and (len(valid_passage_spans) > 0 or len(valid_question_spans) > 0):
                
                # Used for flexible BIO loss
                # START
                
                spans_dict = {}
                text_to_disjoint_bios: List[ListField] = []
                flexibility_count = 1
                for tokenized_answer_text in tokenized_answer_texts:
                    spans = DropReader.find_valid_spans(qp_tokens, [tokenized_answer_text])
                    if len(spans) == 0:
                        # possible if the passage was clipped, but not for all of the answers
                        continue
                    spans_dict[tokenized_answer_text] = spans

                    disjoint_bios: List[SequenceLabelField] = []
                    for span_ind, span in enumerate(spans):
                        bios = create_bio_labels([span], len(qp_field))
                        disjoint_bios.append(SequenceLabelField(bios, sequence_field=qp_field))

                    text_to_disjoint_bios.append(ListField(disjoint_bios))
                    flexibility_count *= ((2**len(spans)) - 1)

                fields["answer_as_text_to_disjoint_bios"] = ListField(text_to_disjoint_bios)

                if (flexibility_count < self.flexibility_threshold):
                    # generate all non-empty span combinations per each text
                    spans_combinations_dict = {}
                    for key, spans in spans_dict.items():
                        spans_combinations_dict[key] = all_combinations = []
                        for i in range(1, len(spans) + 1):
                            all_combinations += list(itertools.combinations(spans, i))

                    # calculate product between all the combinations per each text
                    packed_gold_spans_list = itertools.product(*list(spans_combinations_dict.values()))
                    bios_list: List[SequenceLabelField] = []
                    for packed_gold_spans in packed_gold_spans_list:
                        gold_spans = [s for sublist in packed_gold_spans for s in sublist]
                        bios = create_bio_labels(gold_spans, len(qp_field))
                        bios_list.append(SequenceLabelField(bios, sequence_field=qp_field))
                    
                    fields["answer_as_list_of_bios"] = ListField(bios_list)
                    fields["answer_as_text_to_disjoint_bios"] = ListField([ListField([no_answer_bios])])
                else:
                    fields["answer_as_list_of_bios"] = ListField([no_answer_bios])

                # END

                # Used for both "require-all" BIO loss and flexible loss
                bio_labels = create_bio_labels(valid_question_spans + valid_passage_spans, len(qp_field))
                fields['span_bio_labels'] = SequenceLabelField(bio_labels, sequence_field=qp_field)

                fields["is_bio_mask"] = LabelField(1, skip_indexing=True)
            else:
                fields["answer_as_text_to_disjoint_bios"] = ListField([ListField([no_answer_bios])])
                fields["answer_as_list_of_bios"] = ListField([no_answer_bios])

                # create all 'O' BIO labels for non-span questions
                fields['span_bio_labels'] = no_answer_bios
                fields["is_bio_mask"] = LabelField(0, skip_indexing=True)

        fields["metadata"] = MetadataField(metadata)
        
        return Instance(fields)
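
# Hedged sketch of create_bio_labels, consistent with how it is used above
# (labels over the question-passage sequence, 0 = O, 1 = B, 2 = I, with
# inclusive span end indices); the original implementation may differ.
def create_bio_labels_sketch(spans, n_labels):
    labels = [0] * n_labels  # all 'O' by default
    for start, end in spans:
        labels[start] = 1  # 'B' at the span start
        labels[start + 1:end + 1] = [2] * (end - start)  # 'I' for the rest
    return labels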
    # Method of a data-augmentation class (context not shown): it swaps
    # person-name answers within the passage to synthesize a new passage.
    def augment(self, passage_id, question_id, passage_text, answer_texts):

        # Valid for augmentation
        if self._augmentation_ratio < 1 and random.uniform(
                0, 1) > self._augmentation_ratio:
            return []

        # Only multi span questions are augmented
        if answer_texts is None or len(answer_texts) <= 1:
            return []

        if passage_id == self._cached_passage_id:
            passage_tags = self._cached_passage_tags
            reconstructed_passage_text = self._reconstructed_passage_text
        else:
            passage_tags = self._tagger.predict_json(
                {"sentence": passage_text})
            passage_tags['idxs'] = []

            temp_passage = passage_text
            reconstructed_passage_text = ''

            absolute_index = 0
            for i, word in enumerate(passage_tags['words']):
                tag = passage_tags['tags'][i]

                relative_index = temp_passage.index(word)
                first_part = temp_passage[:relative_index]
                second_part = word
                reconstructed_passage_text += first_part + second_part

                start_idx = absolute_index + relative_index
                end_idx = start_idx + len(word)  # exclusive
                passage_tags['idxs'].append((start_idx, end_idx))

                absolute_index = end_idx
                temp_passage = temp_passage[relative_index + len(word):]

            self._cached_passage_id = passage_id
            self._cached_passage_tags = passage_tags
            self._reconstructed_passage_text = reconstructed_passage_text

        if reconstructed_passage_text != passage_text:
            return []

        # Don't try augmentation if there are shared words between the answers or repeating words in an answer
        words_set = set()
        words_count = 0
        for answer_text in answer_texts:
            words = answer_text.split()
            words_set.update(words)
            words_count += len(words)

            if len(words_set) != words_count:
                return []

        # Validations
        spans = DropReader.find_valid_spans(
            [Token(word) for word in passage_tags['words']], answer_texts)

        # Validate each answer has a span (tokenizing here is by the tagger so we can't assume our fixes to the data are enough)
        for answer_text in answer_texts:
            answer_has_tag = False

            for span in spans:
                span_text = ' '.join(passage_tags['words'][span[0]:span[1] +
                                                           1]).lower()

                if answer_text.lower() == span_text:
                    answer_has_tag = True
                    break

            if not answer_has_tag:
                return []

        # Validate all spans have PER tags and that there is no span with PER immediately before or after to avoid replacement of partial names
        for span in spans:
            answer_tags = passage_tags['tags'][span[0]:span[1] + 1]
            span_length = len(answer_tags)

            # Alternative check: all(tag.endswith('PER') for tag in answer_tags).
            # Should we enforce a proper BILOU tagging sequence?
            if not self.is_valid_BILOU(answer_tags):
                return []

            # Check if we need it
            '''if span[0] > 0 and passage_tags['tags'][span[0] - 1].endswith('PER'):
                return []

            if span[1] < len(passage_tags) - 1 and passage_tags['tags'][span[1] + 1].endswith('PER'):
                return []'''

        # Heavier stuff, after fast validations
        new_passage_text = passage_text

        PLACEHOLDER_SYMBOL = "#"
        pending_swaps = []
        swaps_mapping = [(i + 1) % len(answer_texts)
                         for i in range(len(answer_texts))]

        subs_per_answer_text = [
            get_all_subsequences(answer_text.split())
            for answer_text in answer_texts
        ]
        for i, answer_text in enumerate(answer_texts):
            replacer_answer_index = swaps_mapping[i]
            for sub in subs_per_answer_text[i]:
                spans_per_sub = DropReader.find_valid_spans(
                    [Token(word) for word in passage_tags['words']], [sub])
                relative_start_idx = answer_text.index(sub)
                relative_end_idx = relative_start_idx + len(sub)  # exclusive
                if (len(spans_per_sub) > 0 and relative_start_idx != 0
                        and relative_end_idx != len(answer_text)):
                    # We have a span that is from the middle of the answer. Too ambiguous, ignore question
                    return []
                partition = 1
                if relative_start_idx == 0 and relative_end_idx == len(
                        answer_text):
                    partition = -1
                elif relative_start_idx == 0:
                    partition = 0

                for span in spans_per_sub:
                    answer_tags = passage_tags['tags'][span[0]:span[1] + 1]
                    # if not all(tag.endswith('PER') for tag in answer_tags): # Should we enforce a proper BILOU tagging sequence?
                    if not self.is_valid_BILOU(answer_tags):
                        continue

                    first_part_end = passage_tags['idxs'][span[0]][0]
                    last_part_start = passage_tags['idxs'][span[1]][1]
                    pending_swaps.append({
                        'replacer_answer_index': replacer_answer_index,
                        'partition': partition,
                        'first_part_end': first_part_end,
                        'last_part_start': last_part_start
                    })
                    # blank out the span with placeholders, keeping offsets stable
                    first_part = new_passage_text[:first_part_end]
                    last_part = new_passage_text[last_part_start:]
                    new_passage_text = first_part + PLACEHOLDER_SYMBOL * (
                        last_part_start - first_part_end) + last_part

        for swap in sorted(pending_swaps,
                           key=lambda x: x['first_part_end'],
                           reverse=True):
            replacer_answer_index = swap['replacer_answer_index']
            partition = swap['partition']
            first_part_end = swap['first_part_end']
            last_part_start = swap['last_part_start']

            replacer_answer = answer_texts[replacer_answer_index]

            if partition == -1:
                replacer = replacer_answer
            else:
                answer_parts = replacer_answer.split()
                if len(answer_parts) > 1:
                    replacer = answer_parts[0] if partition == 0 else ' '.join(
                        answer_parts[1:])
                else:
                    replacer = replacer_answer

            first_part = new_passage_text[:first_part_end]
            last_part = new_passage_text[last_part_start:]

            new_passage_text = first_part + replacer + last_part

        return [new_passage_text]
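
# Hedged sketch of get_all_subsequences (assumed behavior): all contiguous
# word n-grams of an answer, joined back into strings, e.g.
# ['John', 'Smith'] -> ['John', 'Smith', 'John Smith'].
def get_all_subsequences_sketch(words):
    subs = []
    for length in range(1, len(words) + 1):
        for start in range(len(words) - length + 1):
            subs.append(' '.join(words[start:start + length]))
    return subs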
Example #7
    # Method of a cleaner class (context not shown); it trims mixed spans so
    # each answer is either all named-entity words or all plain words.
    def clean(self, passage, question, answer, passage_tagging,
              question_tagging):
        passage_tokens = [Token(w) for w in passage_tagging['words']]
        spans = DropReader.find_valid_spans(passage_tokens, answer['spans'])

        if not spans:
            return None

        new_answer_texts = []

        cleaned = False

        for answer_text in answer['spans']:
            if len(answer_text.split()) <= 1:
                continue

            new_answer_text = answer_text

            for span in spans:
                span_text = ' '.join(
                    passage_tagging['words'][span[0]:span[1] + 1]).lower()

                if answer_text.lower() != span_text:
                    continue

                span_tags = passage_tagging['tags'][span[0]:span[1] + 1]

                count_o = sum(tag == 'O' for tag in span_tags)
                other_than_o = len(span_tags) - count_o

                if count_o == 0 or other_than_o == 0:
                    break

                tags_to_trim = (['O'] if count_o <= other_than_o
                                else ['ORG', 'LOC', 'PER', 'MISC'])

                # Remove words only from the start and from the end to keep a valid span
                span_words = passage_tagging['words'][span[0]:span[1] + 1]
                words_count = len(span_words)

                remove_from_start = True
                remove_from_end = True

                for i in range(words_count):
                    if remove_from_start and all(not span_tags[i].endswith(tag)
                                                 for tag in tags_to_trim):
                        remove_from_start = False

                    if remove_from_end and all(
                            not span_tags[words_count - i - 1].endswith(tag)
                            for tag in tags_to_trim):
                        remove_from_end = False

                    if remove_from_start:
                        del span_words[0]

                    if remove_from_end:
                        del span_words[-1]

                    if not remove_from_end and not remove_from_start:
                        break

                new_answer_text = ' '.join(span_words)
                cleaned = True
                break

            new_answer_texts.append(new_answer_text)

        if not cleaned:
            return None

        new_answer = answer.copy()
        new_answer['spans'] = new_answer_texts

        return {'answer': new_answer}
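
# Hedged illustration: for a matched span tagged ['O', 'B-PER', 'L-PER'],
# count_o (1) is <= other_than_o (2), so 'O'-tagged words are trimmed from
# the span's ends. Under the assumed tagger format:
trim_tagging = {
    'words': ['quarterback', 'Peyton', 'Manning', 'threw', 'touchdowns'],
    'tags': ['O', 'B-PER', 'L-PER', 'O', 'O'],
}
trim_answer = {'spans': ['quarterback Peyton Manning']}
# clean(...) would trim the leading 'O' word and return
# {'answer': {'spans': ['Peyton Manning']}}.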