Example #1
0
    def text_to_instance(
        self,  # type: ignore
        question: str,
        answer: Optional[bool] = None,
        decomposition: Optional[List[Dict[str, Any]]] = None,
        generated_decomposition: Optional[List[Dict[str, Any]]] = None,
        facts: Optional[List[str]] = None,
    ) -> Optional[Instance]:
        tokenizer_wrapper = self._tokenizer_wrapper
        fields = {}
        pad_token_id = tokenizer_wrapper.tokenizer.pad_token_id

        (context, paragraphs) = self.generate_context_from_paragraphs(
            question=question,
            decomposition=decomposition,
            generated_decomposition=generated_decomposition,
            facts=facts,
        )

        if self._paragraphs_source is not None:
            if context is None:
                if self._is_training and self._skip_if_context_missing:
                    return None
                context = " "
            encoded_input = tokenizer_wrapper.encode(context, question)
        else:
            encoded_input = tokenizer_wrapper.encode(question)

        excluded_keys = ["offset_mapping", "special_tokens_mask"]
        encoded_input_fields = {
            key: LabelsField(value, padding_value=pad_token_id)
            if key == "input_ids" else LabelsField(value)
            for key, value in encoded_input.items() if key not in excluded_keys
        }
        fields["tokens"] = DictionaryField(encoded_input_fields,
                                           length=len(
                                               encoded_input["input_ids"]))

        if answer is not None:
            fields["label"] = LabelField(int(answer), skip_indexing=True)

        # make the metadata

        metadata = {
            "question": question,
            "answer": answer,
            "decomposition": decomposition,
            "generated_decomposition": generated_decomposition,
            "paragraphs": paragraphs["unified"],
            "queries":
            paragraphs["queries"] if "queries" in paragraphs else [],
        }
        fields["metadata"] = MetadataField(metadata)

        return Instance(fields)
Example #2
0
    def get_empty_answer_fields(self, **kwargs: Dict[str, Any]) -> Dict[str, Field]:
        number_occurrences_in_passage: List[Dict[str, Any]] = \
            kwargs['number_occurrences_in_passage']

        fields: Dict[str, Field] = {}

        fields['answer_as_expressions'] = ListField(
            [LabelsField([0] * len(number_occurrences_in_passage))])
        if self._special_numbers:
            fields['answer_as_expressions_extra'] = ListField(
                [LabelsField([0] * len(self._special_numbers))])

        return fields
Example #3
0
def get_number_indices_field(number_occurrences_in_passage: List[Dict[str, Any]]):
    number_token_indices = [
        LabelsField(number_occurrence['indices'], padding_value=-1)
        for number_occurrence in number_occurrences_in_passage
    ]

    return ListField(number_token_indices)
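A minimal usage sketch for the helper above (the occurrence dicts are illustrative; in the real reader they come from number extraction over the passage, and LabelsField/ListField are this project's field classes):

    # Hypothetical input: each dict records a number's value and the wordpiece
    # indices where it occurs in the encoded passage.
    number_occurrences_in_passage = [
        {'value': 14, 'indices': [3]},          # number covered by a single wordpiece
        {'value': 1066, 'indices': [17, 18]},   # number split across two wordpieces
    ]
    field = get_number_indices_field(number_occurrences_in_passage)
    # -> a ListField of LabelsField; padding_value=-1 lets entries of unequal
    #    length be padded and batched together.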
Example #4
0
    def _get_wordpiece_indices_field(wordpieces: List[List[int]]):
        wordpiece_token_indices = []
        ingested_indices = []
        i = 0
        while i < len(wordpieces):
            current_wordpieces = wordpieces[i]
            if len(current_wordpieces) > 1:
                wordpiece_token_indices.append(LabelsField(current_wordpieces, padding_value=-1))
                i = current_wordpieces[-1] + 1
            else:
                i += 1

        # Hack to guarantee a minimal length for the padded field
        # (following dataset_readers.reading_comprehension.drop from allennlp)
        wordpiece_token_indices.append(LabelsField([-1], padding_value=-1))

        return ListField(wordpiece_token_indices)
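A short illustration of how the loop above collects only multi-piece words (the wordpieces list mirrors the question_wordpieces example shown in Example #8's comments; the function is presumably used as a static helper, called as self._get_wordpiece_indices_field(...) in Example #9):

    # Tokens 5-7 belong to one word, so each of them maps to the same group [5, 6, 7].
    wordpieces = [[0], [1], [2], [3], [4], [5, 6, 7], [5, 6, 7], [5, 6, 7], [8]]
    field = _get_wordpiece_indices_field(wordpieces)
    # Collects a single LabelsField([5, 6, 7]) (plus the trailing LabelsField([-1])
    # padding hack): after index 5 the loop jumps to current_wordpieces[-1] + 1 == 8,
    # so the duplicate groups at indices 6 and 7 are skipped.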
Example #5
0
    def text_to_instance(
        self,  # type: ignore
        question: str,
        context: Optional[str] = None,
        answer: Optional[bool] = None,
    ) -> Instance:
        tokenizer_wrapper = self._tokenizer_wrapper
        fields = {}
        pad_token_id = tokenizer_wrapper.tokenizer.pad_token_id

        if "?" not in question:
            question += "?"

        if self._with_context:
            if context is None:
                context = ""
            encoded_input = tokenizer_wrapper.encode(context, question)
        else:
            encoded_input = tokenizer_wrapper.encode(question)

        encoded_input_fields = {
            key: LabelsField(value, padding_value=pad_token_id)
            if key == "input_ids"
            else LabelsField(value)
            for key, value in encoded_input.items()
        }
        fields["tokens"] = DictionaryField(
            encoded_input_fields, length=len(encoded_input["input_ids"])
        )

        if answer is not None:
            fields["label"] = LabelField(int(answer), skip_indexing=True)

        # make the metadata
        metadata = {
            "question": question,
            "context": context,
            "answer": answer,
            "tokenized_input": tokenizer_wrapper.convert_ids_to_tokens(encoded_input["input_ids"]),
        }
        fields["metadata"] = MetadataField(metadata)

        return Instance(fields)
Example #6
0
    def get_answer_fields(
            self, **kwargs: Dict[str, Any]) -> Tuple[Dict[str, Field], bool]:
        number_occurrences_in_passage: List[Dict[str, Any]] = \
            kwargs['number_occurrences_in_passage']
        answer_texts: List[str] = kwargs['answer_texts']

        fields: Dict[str, Field] = {}

        target_numbers = get_target_numbers(answer_texts)

        # Get possible ways to arrive at target numbers with add/sub
        valid_expressions: List[List[int]] = \
            self._find_valid_add_sub_expressions_with_rounding(
                self._special_numbers + [number_occurrence['value'] for number_occurrence in number_occurrences_in_passage],
                target_numbers,
                self._max_numbers_expression)

        if len(valid_expressions) > 0:
            has_answer = True

            add_sub_signs_field: List[Field] = []
            special_signs_field: List[Field] = []

            for signs_for_one_add_sub_expressions in valid_expressions:
                special_signs = signs_for_one_add_sub_expressions[:len(self._special_numbers)]
                normal_signs = signs_for_one_add_sub_expressions[len(self._special_numbers):]
                add_sub_signs_field.append(LabelsField(normal_signs))
                special_signs_field.append(LabelsField(special_signs))

            fields['answer_as_expressions'] = ListField(add_sub_signs_field)
            if self._special_numbers:
                fields['answer_as_expressions_extra'] = ListField(
                    special_signs_field)
        else:
            has_answer = False
            fields.update(self.get_empty_answer_fields(**kwargs))

        return fields, has_answer
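A small sketch of the sign-splitting step above (the sign values are placeholders; the actual encoding is whatever _find_valid_add_sub_expressions_with_rounding produces, which is not shown here):

    # Hypothetical: two special numbers followed by three passage numbers.
    special_numbers = [100, 1]
    signs_for_one_add_sub_expressions = [0, 1, 2, 0, 1]
    special_signs = signs_for_one_add_sub_expressions[:len(special_numbers)]  # [0, 1]
    normal_signs = signs_for_one_add_sub_expressions[len(special_numbers):]   # [2, 0, 1]
    # Each valid expression thus yields one LabelsField over the passage numbers
    # ('answer_as_expressions') and one over the special numbers
    # ('answer_as_expressions_extra').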
Example #7
0
    def text_to_instance(self,
                         question_text: str,
                         passage_text: str,
                         passage_tokens: List[Token],
                         passage_text_index_to_token_index: List[int],
                         passage_wordpieces: List[List[int]],
                         question_id: str = None,
                         passage_id: str = None,
                         answer_annotations: List[Dict] = None,
                         original_answer_annotations: List[List[Dict]] = None,
                         answer_type: str = None,
                         instance_index: int = None) -> Optional[Instance]:

        # Tokenize question
        question_tokens = self._tokenizer.tokenize_with_offsets(question_text)
        question_text_index_to_token_index = index_text_to_tokens(
            question_text, question_tokens)
        question_words = self._word_tokenize(question_text)
        question_alignment = self._tokenizer.align_tokens_to_tokens(
            question_text, question_words, question_tokens)
        question_wordpieces = self._tokenizer.alignment_to_token_wordpieces(
            question_alignment)

        # Index tokens
        encoded_inputs = self._tokenizer.encode_plus(
            [token.text for token in question_tokens],
            [token.text for token in passage_tokens],
            add_special_tokens=True,
            max_length=self._max_pieces,
            truncation_strategy='only_second',
            return_token_type_ids=True,
            return_special_tokens_mask=True)
        question_passage_token_type_ids = encoded_inputs['token_type_ids']
        question_passage_special_tokens_mask = encoded_inputs[
            'special_tokens_mask']
        question_position = self._tokenizer.get_type_position_in_sequence(
            0, question_passage_token_type_ids,
            question_passage_special_tokens_mask)
        passage_position = self._tokenizer.get_type_position_in_sequence(
            1, question_passage_token_type_ids,
            question_passage_special_tokens_mask)
        question_passage_tokens, num_of_tokens_per_type = self._tokenizer.convert_to_tokens(
            encoded_inputs, [{
                'tokens': question_tokens,
                'wordpieces': question_wordpieces,
                'position': question_position
            }, {
                'tokens': passage_tokens,
                'wordpieces': passage_wordpieces,
                'position': passage_position
            }])

        # Adjust wordpieces
        question_passage_wordpieces = self._tokenizer.adjust_wordpieces(
            [{
                'wordpieces': question_wordpieces,
                'position': question_position,
                'num_of_tokens': num_of_tokens_per_type[0]
            }, {
                'wordpieces': passage_wordpieces,
                'position': passage_position,
                'num_of_tokens': num_of_tokens_per_type[1]
            }], question_passage_tokens)

        # Adjust text index to token index
        question_text_index_to_token_index = [
            token_index + question_position
            for i, token_index in enumerate(question_text_index_to_token_index)
            if token_index < num_of_tokens_per_type[0]
        ]
        passage_text_index_to_token_index = [
            token_index + passage_position
            for i, token_index in enumerate(passage_text_index_to_token_index)
            if token_index < num_of_tokens_per_type[1]
        ]

        # Truncation-related code
        encoded_passage_tokens_length = num_of_tokens_per_type[1]
        if encoded_passage_tokens_length > 0:
            if encoded_passage_tokens_length < len(passage_tokens):
                first_truncated_passage_token = passage_tokens[
                    encoded_passage_tokens_length]
                max_passage_length = first_truncated_passage_token.idx
            else:
                max_passage_length = -1
        else:
            max_passage_length = 0

        fields: Dict[str, Field] = {}

        fields['question_passage_tokens'] = question_passage_field = LabelsField(
            encoded_inputs['input_ids'])
        fields['question_passage_token_type_ids'] = LabelsField(
            question_passage_token_type_ids)
        fields['question_passage_special_tokens_mask'] = LabelsField(
            question_passage_special_tokens_mask)
        fields['question_passage_pad_mask'] = LabelsField(
            [1] * len(question_passage_tokens))

        # in a word broken up into pieces, every piece except the first should be ignored when calculating the loss
        first_wordpiece_mask = [
            i == wordpieces[0]
            for i, wordpieces in enumerate(question_passage_wordpieces)
        ]
        fields['first_wordpiece_mask'] = LabelsField(first_wordpiece_mask)

        # Compile question, passage, answer metadata
        metadata = {
            'original_passage': passage_text,
            'original_question': question_text,
            'passage_tokens': passage_tokens,
            'question_tokens': question_tokens,
            'question_passage_tokens': question_passage_tokens,
            'question_passage_wordpieces': question_passage_wordpieces,
            'passage_id': passage_id,
            'question_id': question_id,
            'max_passage_length': max_passage_length
        }
        if instance_index is not None:
            metadata['instance_index'] = instance_index

        if answer_annotations:
            _, answer_texts = extract_answer_info_from_annotation(
                answer_annotations[0])
            answer_texts = list(OrderedDict.fromkeys(answer_texts))

            gold_indexes = {'question': [], 'passage': []}
            for original_answer_annotation in original_answer_annotations[0]:
                if original_answer_annotation['text'] in answer_texts:
                    gold_index = original_answer_annotation['answer_start']
                    if gold_index not in gold_indexes['passage']:
                        gold_indexes['passage'].append(gold_index)

            metadata['answer_annotations'] = answer_annotations

            kwargs = {
                'seq_tokens': question_passage_tokens,
                'seq_field': question_passage_field,
                'seq_wordpieces': question_passage_wordpieces,
                'question_text': question_text,
                'question_text_index_to_token_index': question_text_index_to_token_index,
                'passage_text': passage_text[:max_passage_length] if max_passage_length > -1 else passage_text,
                'passage_text_index_to_token_index': passage_text_index_to_token_index,
                'answer_texts': answer_texts,
                'gold_indexes': gold_indexes,
                'answer_type': answer_type,  # TODO: Elad - Probably temporary, used to mimic the old reader's behavior
                'is_training': self._is_training,  # TODO: Elad - Probably temporary, used to mimic the old reader's behavior
                'old_reader_behavior': self._old_reader_behavior  # TODO: Elad - temporary, used to mimic the old reader's behavior
            }

            answer_generator_names = None
            if self._answer_generator_names_per_type is not None:
                answer_generator_names = self._answer_generator_names_per_type[
                    answer_type]

            has_answer = False
            for answer_generator_name, answer_field_generator in self._answer_field_generators.items():
                if answer_generator_names is None or answer_generator_name in answer_generator_names:
                    answer_fields, generator_has_answer = answer_field_generator.get_answer_fields(
                        **kwargs)
                    fields.update(answer_fields)
                    has_answer |= generator_has_answer
                else:
                    fields.update(
                        answer_field_generator.get_empty_answer_fields(
                            **kwargs))

            # throw away instances without possible answer generation
            if self._is_training and not has_answer:
                return None

        fields['metadata'] = MetadataField(metadata)

        return Instance(fields)
Example #8
0
    def text_to_instance(self,
                         question_text: str,
                         passage_text: str,
                         passage_tokens: List[Token],
                         passage_text_index_to_token_index: List[int],
                         passage_wordpieces: List[List[int]],
                         number_occurrences_in_passage: List[Dict[str, Any]],
                         question_id: str = None,
                         passage_id: str = None,
                         answer_annotations: List[Dict] = None,
                         answer_type: str = None,
                         instance_index: int = None) -> Optional[Instance]:
        """
        Processes a question/answer pair (related to a passage) and prepares an `Instance`.

        Parameters
        ----------
        question_text: str,
        passage_text: str,
        passage_tokens: List[Token],
        passage_text_index_to_token_index: List[int],
        passage_wordpieces: List[List[int]],
        number_occurrences_in_passage: List[Dict[str, Any]],
        question_id: str = None,
        passage_id: str = None,
        answer_annotations: List[Dict] = None,
        answer_type: str = None,
        instance_index: int = None

        Returns
        -------
            instance: Optional[Instance]
            an `Instance` containing all the data that the `Model` takes as input
        """

        # We alter it, so use a copy to keep it usable for the next questions with the same paragraph
        number_occurrences_in_passage = [number_occurrence.copy() for number_occurrence in
                                         number_occurrences_in_passage]

        """ Tokenize question
        - question_tokens: [ĠHow, Ġmany, Ġpoints, Ġdid, Ġthe, Ġbu, cc, aneers, Ġneed, ...
        - question_text_index_to_token_index: [0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, ...
        - question_words: [How, many, points, did, the, buccaneers, need, to, tie, in, the, first, ?] 
        - question_alignment: [[0], [1], [2], [3], [4], [5, 6, 7], [8], [9], [10], [11], [12], [13], [14]]
        - question_wordpieces: [[0], [1], [2], [3], [4], [5, 6, 7], [5, 6, 7], [5, 6, 7], [8], [9], [10], ...
        """
        question_tokens = self._tokenizer.tokenize_with_offsets(question_text)
        question_text_index_to_token_index = index_text_to_tokens(question_text, question_tokens)
        question_words = self._word_tokenize(question_text)
        question_alignment = self._tokenizer.align_tokens_to_tokens(question_text, question_words, question_tokens)
        question_wordpieces = self._tokenizer.alignment_to_token_wordpieces(question_alignment)

        """ Index tokens
        encoded_inputs is a dictionary with:
        - 'special_tokens_mask',
        - 'input_ids',
        - 'token_type_ids',
        - 'attention_mask'
        """
        encoded_inputs = self._tokenizer.encode_plus([token.text for token in question_tokens],
                                                     [token.text for token in passage_tokens],
                                                     add_special_tokens=True, max_length=self._max_pieces,
                                                     truncation_strategy='only_second',
                                                     return_token_type_ids=True,
                                                     return_special_tokens_mask=True)

        question_passage_token_type_ids = encoded_inputs['token_type_ids']
        question_passage_special_tokens_mask = encoded_inputs['special_tokens_mask']

        question_position = self._tokenizer.get_type_position_in_sequence(0, question_passage_token_type_ids,
                                                                          question_passage_special_tokens_mask)
        passage_position = self._tokenizer.get_type_position_in_sequence(1, question_passage_token_type_ids,
                                                                         question_passage_special_tokens_mask)
        question_passage_tokens, num_of_tokens_per_type = self._tokenizer.convert_to_tokens(encoded_inputs, [
            {'tokens': question_tokens, 'wordpieces': question_wordpieces, 'position': question_position},
            {'tokens': passage_tokens, 'wordpieces': passage_wordpieces, 'position': passage_position}
        ])

        # Adjust wordpieces
        question_passage_wordpieces = self._tokenizer.adjust_wordpieces([
            {'wordpieces': question_wordpieces, 'position': question_position,
             'num_of_tokens': num_of_tokens_per_type[0]},
            {'wordpieces': passage_wordpieces, 'position': passage_position, 'num_of_tokens': num_of_tokens_per_type[1]}
        ], question_passage_tokens)

        # Adjust text index to token index
        question_text_index_to_token_index = [token_index + question_position for i, token_index in
                                              enumerate(question_text_index_to_token_index)
                                              if token_index < num_of_tokens_per_type[0]]
        passage_text_index_to_token_index = [token_index + passage_position for i, token_index in
                                             enumerate(passage_text_index_to_token_index)
                                             if token_index < num_of_tokens_per_type[1]]

        # Truncation-related code
        encoded_passage_tokens_length = num_of_tokens_per_type[1]
        if encoded_passage_tokens_length > 0:
            if encoded_passage_tokens_length < len(passage_tokens):
                first_truncated_passage_token = passage_tokens[encoded_passage_tokens_length]
                max_passage_length = first_truncated_passage_token.idx
                number_occurrences_in_passage = \
                    clipped_passage_num(number_occurrences_in_passage, encoded_passage_tokens_length)
            else:
                max_passage_length = -1
        else:
            max_passage_length = 0

        # update the indices of the numbers with respect to the question.
        for number_occurrence in number_occurrences_in_passage:
            number_occurrence['indices'] = [index + passage_position for index in number_occurrence['indices']]

        # Hack to guarantee a minimal length of the padded numbers
        # (following dataset_readers.reading_comprehension.drop from allennlp)
        number_occurrences_in_passage.append({
            'value': 0,
            'indices': [-1]
        })

        # create fields dictionary for the `Instance`
        fields: Dict[str, Field] = {}

        fields['question_passage_tokens'] = question_passage_field = LabelsField(encoded_inputs['input_ids'])
        fields['question_passage_token_type_ids'] = LabelsField(question_passage_token_type_ids)
        fields['question_passage_special_tokens_mask'] = LabelsField(question_passage_special_tokens_mask)
        fields['question_passage_pad_mask'] = LabelsField([1] * len(question_passage_tokens))

        # in a word broken up into pieces, every piece except the first should be ignored when calculating the loss
        first_wordpiece_mask = [i == wordpieces[0] for i, wordpieces in enumerate(question_passage_wordpieces)]
        fields['first_wordpiece_mask'] = LabelsField(first_wordpiece_mask)

        fields['number_indices'] = get_number_indices_field(number_occurrences_in_passage)

        # Compile question, passage, answer metadata
        metadata = {'original_passage': passage_text,
                    'original_question': question_text,
                    'original_numbers': [number_occurrence['value'] for number_occurrence in
                                         number_occurrences_in_passage],
                    'passage_tokens': passage_tokens,
                    'question_tokens': question_tokens,
                    'question_passage_tokens': question_passage_tokens,
                    'question_passage_wordpieces': question_passage_wordpieces,
                    'passage_id': passage_id,
                    'question_id': question_id,
                    'max_passage_length': max_passage_length}
        if instance_index is not None:
            metadata['instance_index'] = instance_index

        if answer_annotations:
            # Get answer type, answer text, tokenize
            # For multi-span, remove repeating answers. Although possible, in the dataset it is mostly mistakes.
            answer_accessor, answer_texts = extract_answer_info_from_annotation(answer_annotations[0])
            if answer_accessor == AnswerAccessor.SPAN.value:
                answer_texts = list(OrderedDict.fromkeys(answer_texts))

            metadata['answer_annotations'] = answer_annotations

            kwargs = {
                'seq_tokens': question_passage_tokens,
                'seq_field': question_passage_field,
                'seq_wordpieces': question_passage_wordpieces,
                'question_text': question_text,
                'question_text_index_to_token_index': question_text_index_to_token_index,
                'passage_text': passage_text[:max_passage_length] if max_passage_length > -1 else passage_text,
                'passage_text_index_to_token_index': passage_text_index_to_token_index,
                'answer_texts': answer_texts,
                'number_occurrences_in_passage': number_occurrences_in_passage,
                'answer_type': answer_type,  # TODO: Elad - Probably temporary, used to mimic the old reader's behavior
                'is_training': self._is_training,
                # TODO: Elad - Probably temporary, used to mimic the old reader's behavior
                'old_reader_behavior': self._old_reader_behavior
                # TODO: Elad - temporary, used to mimic the old reader's behavior
            }

            answer_generator_names = None
            if self._answer_generator_names_per_type is not None:
                answer_generator_names = self._answer_generator_names_per_type[answer_type]

            has_answer = False
            for answer_generator_name, answer_field_generator in self._answer_field_generators.items():
                if answer_generator_names is None or answer_generator_name in answer_generator_names:
                    answer_fields, generator_has_answer = answer_field_generator.get_answer_fields(**kwargs)
                    fields.update(answer_fields)
                    has_answer |= generator_has_answer
                else:
                    fields.update(answer_field_generator.get_empty_answer_fields(**kwargs))

            # throw away instances without possible answer generation
            if self._is_training and not has_answer:
                return None

        fields['metadata'] = MetadataField(metadata)

        return Instance(fields)
Example #9
0
    def get_answer_fields(
            self, **kwargs: Dict[str, Any]) -> Tuple[Dict[str, Field], bool]:
        seq_tokens: List[Token] = kwargs['seq_tokens']
        seq_wordpieces: List[List[int]] = kwargs['seq_wordpieces']
        question_text_index_to_token_index: List[int] = kwargs['question_text_index_to_token_index']
        question_text: str = kwargs['question_text']
        passage_text_index_to_token_index: List[int] = kwargs['passage_text_index_to_token_index']
        passage_text: str = kwargs['passage_text']
        answer_texts: List[str] = kwargs['answer_texts']
        gold_indexes: Dict[str, Optional[List[int]]] = kwargs.get(
            'gold_indexes', {'question': None, 'passage': None})

        fields: Dict[str, Field] = {}

        spans_dict = {}
        all_spans = []
        is_missing_answer = False
        for i, answer_text in enumerate(answer_texts):
            answer_spans = []
            if not self._ignore_question:
                answer_spans += find_valid_spans(question_text, [answer_text], 
                                                question_text_index_to_token_index, 
                                                seq_tokens, seq_wordpieces, 
                                                gold_indexes['question'])
            answer_spans += find_valid_spans(passage_text, [answer_text], 
                                             passage_text_index_to_token_index, 
                                             seq_tokens, seq_wordpieces, 
                                             gold_indexes['passage'])
            if len(answer_spans) == 0:
                is_missing_answer = True
                continue
            spans_dict[answer_text] = answer_spans
            all_spans.extend(answer_spans)

        old_reader_behavior = kwargs['old_reader_behavior']
        if old_reader_behavior:
            answer_type = kwargs['answer_type']
            is_training = kwargs['is_training']
            if is_training:
                if answer_type in SPAN_ANSWER_TYPES:
                    if is_missing_answer:
                        all_spans = []

        if len(all_spans) > 0:
            has_answer = True

            fields['wordpiece_indices'] = self._get_wordpiece_indices_field(seq_wordpieces)

            no_answer_bios = self._get_empty_answer(seq_tokens)

            text_to_disjoint_bios: List[ListField] = []
            flexibility_count = 1
            for answer_text in answer_texts:
                spans = spans_dict[answer_text] if answer_text in spans_dict else []
                if len(spans) == 0:
                    continue

                disjoint_bios: List[LabelsField] = []
                for span_ind, span in enumerate(spans):
                    bios = self._create_sequence_labels([span], len(seq_tokens))
                    disjoint_bios.append(LabelsField(bios))

                text_to_disjoint_bios.append(ListField(disjoint_bios))
                flexibility_count *= ((2**len(spans)) - 1)

            fields['answer_as_text_to_disjoint_bios'] = ListField(text_to_disjoint_bios)

            if (flexibility_count < self._flexibility_threshold):
                # generate all non-empty span combinations per each text
                spans_combinations_dict = {}
                for key, spans in spans_dict.items():
                    spans_combinations_dict[key] = all_combinations = []
                    for i in range(1, len(spans) + 1):
                        all_combinations += list(itertools.combinations(spans, i))

                # calculate product between all the combinations per each text
                packed_gold_spans_list = itertools.product(*list(spans_combinations_dict.values()))
                bios_list: List[LabelsField] = []
                for packed_gold_spans in packed_gold_spans_list:
                    gold_spans = [s for sublist in packed_gold_spans for s in sublist]
                    bios = self._create_sequence_labels(gold_spans, len(seq_tokens))
                    bios_list.append(LabelsField(bios))

                fields['answer_as_list_of_bios'] = ListField(bios_list)
                fields['answer_as_text_to_disjoint_bios'] = ListField([ListField([no_answer_bios])])
            else:
                fields['answer_as_list_of_bios'] = ListField([no_answer_bios])

            bio_labels = self._create_sequence_labels(all_spans, len(seq_tokens))
            fields['span_bio_labels'] = LabelsField(bio_labels)

            fields['is_bio_mask'] = LabelField(1, skip_indexing=True)
        else:
            has_answer = False
            fields.update(self.get_empty_answer_fields(**kwargs))

        return fields, has_answer
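A compact illustration of the flexibility bookkeeping above (the answer text and span positions are placeholders; _create_sequence_labels is this project's BIO labeler):

    # Hypothetical: one answer text was found at two token spans.
    spans_dict = {'Chicago': [(5, 6), (20, 21)]}
    flexibility_count = (2 ** 2) - 1   # 3 non-empty subsets of the two spans
    # If 3 < self._flexibility_threshold, 'answer_as_list_of_bios' enumerates all
    # 3 subsets -- {(5, 6)}, {(20, 21)}, {(5, 6), (20, 21)} -- each rendered as one
    # BIO LabelsField, while 'answer_as_text_to_disjoint_bios' is reset to the
    # no-answer placeholder. Otherwise the two fields swap roles.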
Example #10
0
    def _get_empty_answer(seq_tokens: List[Token]):
        return LabelsField([0] * len(seq_tokens))
Example #11
0
    def text_to_instance(
        self,  # type: ignore
        question: str,
        modified_question: str,
        context: str,
        encoded_input: Dict[str, Any],
        answers: List[str] = None,
        first_answer_start_offset: Optional[int] = None,
        is_impossible: Optional[bool] = None,
        qid: str = None,
        window_index: int = None,
        is_boolq: bool = False,
    ) -> Optional[Instance]:
        tokenizer_wrapper = self._tokenizer_wrapper
        offset_mapping = encoded_input["offset_mapping"]
        special_tokens_mask = encoded_input["special_tokens_mask"]

        token_answer_span = None
        if first_answer_start_offset is not None and answers:
            answer = answers[0]
            relevant_sequence_index = 0 if is_boolq else 1
            tokens_groups = group_tokens_by_whole_words(
                [modified_question, context], offset_mapping,
                special_tokens_mask)
            valid_spans = find_valid_spans(
                modified_question if is_boolq else context,
                answer,
                offset_mapping,
                special_tokens_mask,
                functools.partial(
                    get_token_answer_span,
                    sequence_index=relevant_sequence_index,
                ),
                tokens_groups,
                first_answer_start_offset,
            )
            token_answer_span = valid_spans[0] if len(valid_spans) > 0 else None

        if self._is_training and (token_answer_span is None and
                                  (is_impossible is None
                                   or is_impossible is False)):
            return None

        seq_boundaries = get_sequence_boundaries(special_tokens_mask)
        (
            first_context_token_index,
            last_context_token_index,
        ) = (seq_boundaries[1] if len(seq_boundaries) > 1 else (-1, -1))

        fields = {}
        pad_token_id = tokenizer_wrapper.tokenizer.pad_token_id

        excluded_keys = [
            "offset_mapping",
            "special_tokens_mask",
            "overflow_to_sample_mapping",
        ]

        encoded_input_fields = {
            key: LabelsField(value, padding_value=pad_token_id)
            if key == "input_ids" else LabelsField(value)
            for key, value in encoded_input.items() if key not in excluded_keys
        }
        fields["question_with_context"] = DictionaryField(
            encoded_input_fields, length=len(encoded_input["input_ids"]))

        # make the answer span
        seq_field = encoded_input_fields["input_ids"]
        if token_answer_span is not None:
            assert all(i >= 0 for i in token_answer_span)
            assert token_answer_span.start <= token_answer_span.end

            fields["answer_span"] = SpanField(
                token_answer_span.start,
                token_answer_span.end,
                seq_field,
            )
        else:
            # We have to put in something even when we don't have an answer, so that this instance can be batched
            # together with other instances that have answers.
            if is_impossible is True:
                fields["answer_span"] = SpanField(0, 0, seq_field)
            else:
                fields["answer_span"] = SpanField(-1, -1, seq_field)

        # make the context span, i.e., the span of text from which possible answers should be drawn
        fields["context_span"] = SpanField(
            first_context_token_index,
            last_context_token_index,
            seq_field,
        )
        fields["yes_no_span"] = SpanField(
            seq_boundaries[0][0],
            seq_boundaries[0][0] + 1,
            seq_field,
        )

        if token_answer_span is None:
            token_answer_span = Span(-1, -1)

        # make the metadata
        metadata = {
            "question": question,
            "modified_question": modified_question,
            "context": context,
            "offset_mapping": offset_mapping,
            "special_tokens_mask": special_tokens_mask,
            "answers": answers,
            "first_answer_start_offset": first_answer_start_offset,
            "id": qid,
            "window_index": window_index,
            "token_answer_span": token_answer_span,
            "is_impossible": is_impossible,
            "is_boolq": is_boolq,
        }
        fields["metadata"] = MetadataField(metadata)

        return Instance(fields)