Example #1
0
    def text_to_instance(self,
                         question_text: str,
                         passage_text: str,
                         passage_tokens: List[Token],
                         passage_text_index_to_token_index: List[int],
                         passage_wordpieces: List[List[int]],
                         question_id: str = None,
                         passage_id: str = None,
                         answer_annotations: List[Dict] = None,
                         original_answer_annotations: List[List[Dict]] = None,
                         answer_type: str = None,
                         instance_index: int = None) -> Optional[Instance]:

        # Tokenize question
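        # question_tokens are the tokenizer's wordpieces with character offsets,
        # question_text_index_to_token_index maps each character of the question
        # to its token, and question_wordpieces groups, for every token, the
        # indices of the tokens that belong to the same word-level word.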
        question_tokens = self._tokenizer.tokenize_with_offsets(question_text)
        question_text_index_to_token_index = index_text_to_tokens(
            question_text, question_tokens)
        question_words = self._word_tokenize(question_text)
        question_alignment = self._tokenizer.align_tokens_to_tokens(
            question_text, question_words, question_tokens)
        question_wordpieces = self._tokenizer.alignment_to_token_wordpieces(
            question_alignment)

        # Index tokens
        encoded_inputs = self._tokenizer.encode_plus(
            [token.text for token in question_tokens],
            [token.text for token in passage_tokens],
            add_special_tokens=True,
            max_length=self._max_pieces,
            truncation_strategy='only_second',
            return_token_type_ids=True,
            return_special_tokens_mask=True)
        question_passage_token_type_ids = encoded_inputs['token_type_ids']
        question_passage_special_tokens_mask = encoded_inputs[
            'special_tokens_mask']
        question_position = self._tokenizer.get_type_position_in_sequence(
            0, question_passage_token_type_ids,
            question_passage_special_tokens_mask)
        passage_position = self._tokenizer.get_type_position_in_sequence(
            1, question_passage_token_type_ids,
            question_passage_special_tokens_mask)
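        # question_position / passage_position are the offsets at which the
        # question and passage tokens start inside the jointly encoded
        # sequence (i.e. after the special tokens that precede each segment).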
        question_passage_tokens, num_of_tokens_per_type = self._tokenizer.convert_to_tokens(
            encoded_inputs, [{
                'tokens': question_tokens,
                'wordpieces': question_wordpieces,
                'position': question_position
            }, {
                'tokens': passage_tokens,
                'wordpieces': passage_wordpieces,
                'position': passage_position
            }])

        # Adjust wordpieces
        question_passage_wordpieces = self._tokenizer.adjust_wordpieces(
            [{
                'wordpieces': question_wordpieces,
                'position': question_position,
                'num_of_tokens': num_of_tokens_per_type[0]
            }, {
                'wordpieces': passage_wordpieces,
                'position': passage_position,
                'num_of_tokens': num_of_tokens_per_type[1]
            }], question_passage_tokens)

        # Adjust text index to token index
        question_text_index_to_token_index = [
            token_index + question_position
            for i, token_index in enumerate(question_text_index_to_token_index)
            if token_index < num_of_tokens_per_type[0]
        ]
        passage_text_index_to_token_index = [
            token_index + passage_position
            for i, token_index in enumerate(passage_text_index_to_token_index)
            if token_index < num_of_tokens_per_type[1]
        ]

        # Truncation-related code
        encoded_passage_tokens_length = num_of_tokens_per_type[1]
        if encoded_passage_tokens_length > 0:
            if encoded_passage_tokens_length < len(passage_tokens):
                first_truncated_passage_token = passage_tokens[
                    encoded_passage_tokens_length]
                max_passage_length = first_truncated_passage_token.idx
            else:
                max_passage_length = -1
        else:
            max_passage_length = 0
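        # max_passage_length is in characters: -1 means the passage was not
        # truncated, 0 means no passage tokens survived truncation, otherwise
        # it is the character offset of the first truncated passage token.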

        fields: Dict[str, Field] = {}

        fields[
            'question_passage_tokens'] = question_passage_field = LabelsField(
                encoded_inputs['input_ids'])
        fields['question_passage_token_type_ids'] = LabelsField(
            question_passage_token_type_ids)
        fields['question_passage_special_tokens_mask'] = LabelsField(
            question_passage_special_tokens_mask)
        fields['question_passage_pad_mask'] = LabelsField(
            [1] * len(question_passage_tokens))

        # in a word broken up into pieces, every piece except the first should be ignored when calculating the loss
        first_wordpiece_mask = [
            i == wordpieces[0]
            for i, wordpieces in enumerate(question_passage_wordpieces)
        ]
        fields['first_wordpiece_mask'] = LabelsField(first_wordpiece_mask)

        # Compile question, passage, answer metadata
        metadata = {
            'original_passage': passage_text,
            'original_question': question_text,
            'passage_tokens': passage_tokens,
            'question_tokens': question_tokens,
            'question_passage_tokens': question_passage_tokens,
            'question_passage_wordpieces': question_passage_wordpieces,
            'passage_id': passage_id,
            'question_id': question_id,
            'max_passage_length': max_passage_length
        }
        if instance_index is not None:
            metadata['instance_index'] = instance_index

        if answer_annotations:
            _, answer_texts = extract_answer_info_from_annotation(
                answer_annotations[0])
            answer_texts = list(OrderedDict.fromkeys(answer_texts))

            gold_indexes = {'question': [], 'passage': []}
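            # gold_indexes will hold the character start offsets of the gold
            # answer spans inside the (standardized) passage text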
            for original_answer_annotation in original_answer_annotations[0]:
                if original_answer_annotation['text'] in answer_texts:
                    gold_index = original_answer_annotation['answer_start']
                    if gold_index not in gold_indexes['passage']:
                        gold_indexes['passage'].append(gold_index)

            metadata['answer_annotations'] = answer_annotations

            kwargs = {
                'seq_tokens': question_passage_tokens,
                'seq_field': question_passage_field,
                'seq_wordpieces': question_passage_wordpieces,
                'question_text': question_text,
                'question_text_index_to_token_index': question_text_index_to_token_index,
                'passage_text': passage_text[:max_passage_length] if max_passage_length > -1 else passage_text,
                'passage_text_index_to_token_index': passage_text_index_to_token_index,
                'answer_texts': answer_texts,
                'gold_indexes': gold_indexes,
                'answer_type': answer_type,  # TODO: Elad - Probably temporary, used to mimic the old reader's behavior
                'is_training': self._is_training,  # TODO: Elad - Probably temporary, used to mimic the old reader's behavior
                'old_reader_behavior': self._old_reader_behavior  # TODO: Elad - temporary, used to mimic the old reader's behavior
            }

            answer_generator_names = None
            if self._answer_generator_names_per_type is not None:
                answer_generator_names = self._answer_generator_names_per_type[
                    answer_type]

            has_answer = False
            for answer_generator_name, answer_field_generator in self._answer_field_generators.items(
            ):
                if answer_generator_names is None or answer_generator_name in answer_generator_names:
                    answer_fields, generator_has_answer = answer_field_generator.get_answer_fields(
                        **kwargs)
                    fields.update(answer_fields)
                    has_answer |= generator_has_answer
                else:
                    fields.update(
                        answer_field_generator.get_empty_answer_fields(
                            **kwargs))

            # throw away instances without possible answer generation
            if self._is_training and not has_answer:
                return None

        fields['metadata'] = MetadataField(metadata)

        return Instance(fields)
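
A minimal, self-contained sketch of the first-wordpiece mask computed above, using a made-up wordpiece grouping (the grouping and the assert are illustrative only, not taken from the dataset): only the first piece of each multi-piece word keeps True.

    # toy question_passage_wordpieces: tokens 2-4 are pieces of the same word
    question_passage_wordpieces = [[0], [1], [2, 3, 4], [2, 3, 4], [2, 3, 4], [5]]
    first_wordpiece_mask = [
        i == wordpieces[0]
        for i, wordpieces in enumerate(question_passage_wordpieces)
    ]
    assert first_wordpiece_mask == [True, True, True, False, False, True]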
Example #2
0
    def _read(self, file_path: str):
        if not self._lazy and self._pickle['action'] == 'load':
            # Try to load the data, if it fails then read it from scratch and save it
            loaded_pkl = load_pkl(self._pickle, self._is_training)
            if loaded_pkl is not None:
                for instance in loaded_pkl:
                    yield instance
                return
            else:
                self._pickle['action'] = 'save'

        # add support for multi-dataset training
        global_index = 0
        instances_count = 0
        instances = []
        for ind, single_file_path in enumerate(file_path.split(',')):
            single_file_path_cached = cached_path(single_file_path)
            print(single_file_path)
            with open(single_file_path_cached,
                      encoding="utf8") as dataset_file:
                dataset = json.load(dataset_file)

            dataset = standardize_dataset(dataset, self._standardize_text_func)
            for passage_id, passage_info in dataset.items():
                passage_text = passage_info['passage']

                # Tokenize passage
                passage_tokens = self._tokenizer.tokenize_with_offsets(
                    passage_text)
                passage_text_index_to_token_index = index_text_to_tokens(
                    passage_text, passage_tokens)
                passage_words = self._word_tokenize(passage_text)
                passage_alignment = self._tokenizer.align_tokens_to_tokens(
                    passage_text, passage_words, passage_tokens)
                passage_wordpieces = self._tokenizer.alignment_to_token_wordpieces(
                    passage_alignment)

                # Process questions from this passage
                for relative_index, qa_pair in enumerate(
                        passage_info['qa_pairs']):
                    if 0 < self._max_instances <= instances_count:
                        if not self._lazy and self._pickle['action'] == 'save':
                            save_pkl(instances, self._pickle,
                                     self._is_training)
                        return

                    question_id = qa_pair['query_id']
                    question_text = qa_pair['question']

                    answer_annotations: List[Dict] = list()
                    original_answer_annotations: List[List[Dict]] = list()
                    answer_type = None
                    if 'answer' in qa_pair and qa_pair['answer']:
                        answer = qa_pair['answer']
                        original_answer = qa_pair['original_answer']

                        answer_type = get_answer_type(answer)
                        if answer_type is None or answer_type not in self._answer_types_filter:
                            continue

                        answer_annotations.append(answer)
                        original_answer_annotations.append(original_answer)

                        # If the standardization deleted characters then we need to adjust answer_start
                        deletion_indexes = self._standardize_text_func(
                            passage_info['original_passage'],
                            deletions_tracking=True)[1]
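                        # deletion_indexes maps a character position in the
                        # original passage to the number of characters deleted
                        # there; every deletion before answer_start shifts the
                        # span start left by that amount.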
                        for span in original_answer:
                            answer_start = span['answer_start']
                            for index in deletion_indexes.keys():
                                if span['answer_start'] > index:
                                    answer_start -= deletion_indexes[index]
                            span['answer_start'] = answer_start

                    if self._is_training and answer_type is None:
                        continue

                    from termcolor import colored
                    instance = None
                    try:
                        instance = self.text_to_instance(
                            question_text, passage_text, passage_tokens,
                            passage_text_index_to_token_index,
                            passage_wordpieces, question_id, passage_id,
                            answer_annotations, original_answer_annotations,
                            answer_type, global_index + relative_index)
                    except BaseException as e:
                        print(colored(e, 'green'))
                        print(
                            colored("found exception in question: " +
                                    question_id + " " + question_text))

                    if instance is not None:
                        instances_count += 1
                        if not self._lazy:
                            instances.append(instance)
                        yield instance
                global_index += len(passage_info['qa_pairs'])
        if not self._lazy and self._pickle['action'] == 'save':
            save_pkl(instances, self._pickle, self._is_training)
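
A minimal sketch of the answer_start adjustment above, with toy data (the deletion positions and the span are made up, not taken from the dataset): every deletion that occurred before the span start shifts it left by the number of characters removed there.

    deletion_indexes = {5: 2, 40: 1}  # 2 chars deleted at position 5, 1 at 40
    span = {'answer_start': 20, 'text': 'Buccaneers'}
    answer_start = span['answer_start']
    for index in deletion_indexes:
        if span['answer_start'] > index:
            answer_start -= deletion_indexes[index]
    span['answer_start'] = answer_start
    assert span['answer_start'] == 18  # only the deletion at position 5 applies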
Example #3
0
    def gen_dataset_instances(self, dataset, global_info):

        for passage_id, passage_info in tqdm(dataset.items()):
            passage_text = passage_info['passage']

            # Tokenize passage
            passage_tokens = self._tokenizer.tokenize_with_offsets(
                passage_text)
            passage_text_index_to_token_index = index_text_to_tokens(
                passage_text, passage_tokens)
            passage_words = self._word_tokenize(passage_text)
            passage_alignment = self._tokenizer.align_tokens_to_tokens(
                passage_text, passage_words, passage_tokens)
            passage_wordpieces = self._tokenizer.alignment_to_token_wordpieces(
                passage_alignment)

            # Process questions from this passage
            for relative_index, qa_pair in enumerate(passage_info['qa_pairs']):
                if 0 < self._max_instances <= global_info["instances_count"]:
                    if not self._lazy and self._pickle['action'] == 'save':
                        save_pkl(global_info["instances"], self._pickle,
                                 self._is_training)
                    return

                question_id = qa_pair['query_id']
                question_text = qa_pair['question']

                answer_annotations: List[Dict] = list()
                original_answer_annotations: List[List[Dict]] = list()
                answer_type = None
                if 'answer' in qa_pair and qa_pair['answer']:
                    answer = qa_pair['answer']
                    original_answer = qa_pair['original_answer']

                    answer_type = get_answer_type(answer)
                    if answer_type is None or answer_type not in self._answer_types_filter:
                        continue

                    answer_annotations.append(answer)
                    original_answer_annotations.append(original_answer)

                    # If the standardization deleted characters then we need to adjust answer_start
                    deletion_indexes = self._standardize_text_func(
                        passage_info['original_passage'],
                        deletions_tracking=True)[1]
                    for span in original_answer:
                        answer_start = span['answer_start']
                        for index in deletion_indexes.keys():
                            if span['answer_start'] > index:
                                answer_start -= deletion_indexes[index]
                        span['answer_start'] = answer_start

                if self._is_training and answer_type is None:
                    continue

                instance = self.text_to_instance(
                    question_text, passage_text, passage_tokens,
                    passage_text_index_to_token_index, passage_wordpieces,
                    question_id, passage_id, answer_annotations,
                    original_answer_annotations, answer_type,
                    globel_info["global_index"] + relative_index)
                if instance is not None:
                    globel_info["instances_count"] += 1
                    if not self._lazy:
                        globel_info["instances"].append(instance)
                    yield instance
            globel_info["globel_index"] += len(passage_info['qa_pairs'])
        print("generate dataset instances")
Example #4
0
    def text_to_instance(self,
                         question_text: str,
                         passage_text: str,
                         passage_tokens: List[Token],
                         passage_text_index_to_token_index: List[int],
                         passage_wordpieces: List[List[int]],
                         number_occurrences_in_passage: List[Dict[str, Any]],
                         question_id: str = None,
                         passage_id: str = None,
                         answer_annotations: List[Dict] = None,
                         answer_type: str = None,
                         instance_index: int = None) -> Optional[Instance]:
        """
        Processes a question/answer pair (related to a passage) and prepares an `Instance`.

        Parameters
        ----------
        question_text: str,
        passage_text: str,
        passage_tokens: List[Token],
        passage_text_index_to_token_index: List[int],
        passage_wordpieces: List[List[int]],
        number_occurrences_in_passage: List[Dict[str, Any]],
        question_id: str = None,
        passage_id: str = None,
        answer_annotations: List[Dict] = None,
        answer_type: str = None,
        instance_index: int = None

        Returns
        -------
            instance: Optional[Instance]
                an `Instance` containing all the data that the `Model` takes as input
        """

        # We alter it, so use a copy to keep it usable for the next questions with the same paragraph
        number_occurrences_in_passage = [number_occurrence.copy() for number_occurrence in
                                         number_occurrences_in_passage]

        """ Tokenize question
        - question_tokens: [ĠHow, Ġmany, Ġpoints, Ġdid, Ġthe, Ġbu, cc, aneers, Ġneed, ...
        - question_text_index_to_token_index: [0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, ...
        - question_words: [How, many, points, did, the, buccaneers, need, to, tie, in, the, first, ?] 
        - question_alignment: [[0], [1], [2], [3], [4], [5, 6, 7], [8], [9], [10], [11], [12], [13], [14]]
        - question_wordpieces: [[0], [1], [2], [3], [4], [5, 6, 7], [5, 6, 7], [5, 6, 7], [8], [9], [10], ...
        """
        question_tokens = self._tokenizer.tokenize_with_offsets(question_text)
        question_text_index_to_token_index = index_text_to_tokens(question_text, question_tokens)
        question_words = self._word_tokenize(question_text)
        question_alignment = self._tokenizer.align_tokens_to_tokens(question_text, question_words, question_tokens)
        question_wordpieces = self._tokenizer.alignment_to_token_wordpieces(question_alignment)

        """ Index tokens
        encoded_inputs is a dictionary with:
        - 'special_tokens_mask',
        - 'input_ids',
        - 'token_type_ids',
        - 'attention_mask'
        """
        encoded_inputs = self._tokenizer.encode_plus([token.text for token in question_tokens],
                                                     [token.text for token in passage_tokens],
                                                     add_special_tokens=True, max_length=self._max_pieces,
                                                     truncation_strategy='only_second',
                                                     return_token_type_ids=True,
                                                     return_special_tokens_mask=True)

        question_passage_token_type_ids = encoded_inputs['token_type_ids']
        question_passage_special_tokens_mask = encoded_inputs['special_tokens_mask']

        question_position = self._tokenizer.get_type_position_in_sequence(0, question_passage_token_type_ids,
                                                                          question_passage_special_tokens_mask)
        passage_position = self._tokenizer.get_type_position_in_sequence(1, question_passage_token_type_ids,
                                                                         question_passage_special_tokens_mask)
        question_passage_tokens, num_of_tokens_per_type = self._tokenizer.convert_to_tokens(encoded_inputs, [
            {'tokens': question_tokens, 'wordpieces': question_wordpieces, 'position': question_position},
            {'tokens': passage_tokens, 'wordpieces': passage_wordpieces, 'position': passage_position}
        ])

        # Adjust wordpieces
        question_passage_wordpieces = self._tokenizer.adjust_wordpieces([
            {'wordpieces': question_wordpieces, 'position': question_position,
             'num_of_tokens': num_of_tokens_per_type[0]},
            {'wordpieces': passage_wordpieces, 'position': passage_position, 'num_of_tokens': num_of_tokens_per_type[1]}
        ], question_passage_tokens)

        # Adjust text index to token index
        question_text_index_to_token_index = [token_index + question_position for i, token_index in
                                              enumerate(question_text_index_to_token_index)
                                              if token_index < num_of_tokens_per_type[0]]
        passage_text_index_to_token_index = [token_index + passage_position for i, token_index in
                                             enumerate(passage_text_index_to_token_index)
                                             if token_index < num_of_tokens_per_type[1]]

        # Truncation-related code
        encoded_passage_tokens_length = num_of_tokens_per_type[1]
        if encoded_passage_tokens_length > 0:
            if encoded_passage_tokens_length < len(passage_tokens):
                first_truncated_passage_token = passage_tokens[encoded_passage_tokens_length]
                max_passage_length = first_truncated_passage_token.idx
                number_occurrences_in_passage = \
                    clipped_passage_num(number_occurrences_in_passage, encoded_passage_tokens_length)
            else:
                max_passage_length = -1
        else:
            max_passage_length = 0

        # shift the numbers' token indices so they point into the joint question-passage token sequence
        for number_occurrence in number_occurrences_in_passage:
            number_occurrence['indices'] = [index + passage_position for index in number_occurrence['indices']]

        # hack to guarantee a minimal length of padded numbers
        # (following dataset_readers.reading_comprehension.drop from allennlp)
        number_occurrences_in_passage.append({
            'value': 0,
            'indices': [-1]
        })
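        # the dummy 0 entry guarantees that every instance has at least one
        # number, even when the passage contains none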

        # create fields dictionary for the `Instance`
        fields: Dict[str, Field] = {}

        fields['question_passage_tokens'] = question_passage_field = LabelsField(encoded_inputs['input_ids'])
        fields['question_passage_token_type_ids'] = LabelsField(question_passage_token_type_ids)
        fields['question_passage_special_tokens_mask'] = LabelsField(question_passage_special_tokens_mask)
        fields['question_passage_pad_mask'] = LabelsField([1] * len(question_passage_tokens))

        # in a word broken up into pieces, every piece except the first should be ignored when calculating the loss
        first_wordpiece_mask = [i == wordpieces[0] for i, wordpieces in enumerate(question_passage_wordpieces)]
        fields['first_wordpiece_mask'] = LabelsField(first_wordpiece_mask)

        fields['number_indices'] = get_number_indices_field(number_occurrences_in_passage)

        # Compile question, passage, answer metadata
        metadata = {'original_passage': passage_text,
                    'original_question': question_text,
                    'original_numbers': [number_occurrence['value'] for number_occurrence in
                                         number_occurrences_in_passage],
                    'passage_tokens': passage_tokens,
                    'question_tokens': question_tokens,
                    'question_passage_tokens': question_passage_tokens,
                    'question_passage_wordpieces': question_passage_wordpieces,
                    'passage_id': passage_id,
                    'question_id': question_id,
                    'max_passage_length': max_passage_length}
        if instance_index is not None:
            metadata['instance_index'] = instance_index

        if answer_annotations:
            # Get answer type, answer text, tokenize
            # For multi-span, remove repeating answers. Although possible, in the dataset it is mostly mistakes.
            answer_accessor, answer_texts = extract_answer_info_from_annotation(answer_annotations[0])
            if answer_accessor == AnswerAccessor.SPAN.value:
                answer_texts = list(OrderedDict.fromkeys(answer_texts))

            metadata['answer_annotations'] = answer_annotations

            kwargs = {
                'seq_tokens': question_passage_tokens,
                'seq_field': question_passage_field,
                'seq_wordpieces': question_passage_wordpieces,
                'question_text': question_text,
                'question_text_index_to_token_index': question_text_index_to_token_index,
                'passage_text': passage_text[:max_passage_length] if max_passage_length > -1 else passage_text,
                'passage_text_index_to_token_index': passage_text_index_to_token_index,
                'answer_texts': answer_texts,
                'number_occurrences_in_passage': number_occurrences_in_passage,
                'answer_type': answer_type,  # TODO: Elad - Probably temporary, used to mimic the old reader's behavior
                'is_training': self._is_training,
                # TODO: Elad - Probably temporary, used to mimic the old reader's behavior
                'old_reader_behavior': self._old_reader_behavior
                # TODO: Elad - temporary, used to mimic the old reader's behavior
            }

            answer_generator_names = None
            if self._answer_generator_names_per_type is not None:
                answer_generator_names = self._answer_generator_names_per_type[answer_type]

            has_answer = False
            for answer_generator_name, answer_field_generator in self._answer_field_generators.items():
                if answer_generator_names is None or answer_generator_name in answer_generator_names:
                    answer_fields, generator_has_answer = answer_field_generator.get_answer_fields(**kwargs)
                    fields.update(answer_fields)
                    has_answer |= generator_has_answer
                else:
                    fields.update(answer_field_generator.get_empty_answer_fields(**kwargs))

            # throw away instances without possible answer generation
            if self._is_training and not has_answer:
                return None

        fields['metadata'] = MetadataField(metadata)

        return Instance(fields)
Example #5
0
    def _read(self, file_path: str):
        """
        This method reads the file containing the dataset and prepares the data for the model. It processes the
        passage text and calls `text_to_instance` method to process the question/answer text

        Parameters
        ----------
        file_path: str
            the file location of the dataset

        Returns
        -------
        instances: Iterable
            an Iterable (a generator) that yields the pre-processed data as `Instance` objects
        """

        instances_count = 0
        if not self._lazy and self._pickle['action'] == 'load':
            print("entra nell'if del load")
            # Try to load the data, if it fails then read it from scratch and save it
            loaded_pkl = load_pkl(self._pickle, self._is_training)
            print(loaded_pkl)
            if loaded_pkl is not None:
                for instance in loaded_pkl:
                    if 0 < self._max_instances <= instances_count:
                        return
                    instances_count += 1
                    yield instance
                return
            else:
                self._pickle['action'] = 'save'

        file_path = cached_path(file_path)
        with open(file_path, encoding="utf8") as dataset_file:
            dataset = json.load(dataset_file)

        dataset = standardize_dataset(dataset, self._standardize_text_func)

        global_index = 0
        instances = []
        for passage_id, passage_info in tqdm(dataset.items()):
            passage_text = passage_info['passage']

            """
            Tokenize passage
            - passage_tokens: [ĠTo, Ġstart, Ġthe, Ġseason, ...
            - passage_text_index_to_token_index: [0, 0, 0, 1, 1, 1, 1, 1, 1, ...
            - passage_words: [To, start, the, season, ,, the, ...
            - passage_alignment: [[0], [1], [2], [3], [4], [5], ...
            - passage_wordpieces: [[0], [1], [2], [3], ...
            """
            passage_tokens = self._tokenizer.tokenize_with_offsets(passage_text)
            passage_text_index_to_token_index = index_text_to_tokens(passage_text, passage_tokens)
            passage_words = number_extraction_tokens = self._word_tokenize(passage_text)
            passage_alignment = self._tokenizer.align_tokens_to_tokens(passage_text, passage_words, passage_tokens)
            passage_wordpieces = self._tokenizer.alignment_to_token_wordpieces(passage_alignment)

            number_occurrences_in_passage = \
                extract_number_occurrences(number_extraction_tokens, passage_alignment)
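            # each entry of number_occurrences_in_passage holds a number's
            # 'value' and the 'indices' of the passage tokens where it occurs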

            # Process questions from this passage
            for relative_index, qa_pair in enumerate(passage_info['qa_pairs']):
                if 0 < self._max_instances <= instances_count:
                    if not self._lazy and self._pickle['action'] == 'save':
                        save_pkl(instances, self._pickle, self._is_training)
                    return

                # extract question id and question text from the pair
                question_id = qa_pair['query_id']
                question_text = qa_pair['question']

                # this list will contain the answer and the validated answers
                answer_annotations: List[Dict] = list()
                answer_type = None
                if 'answer' in qa_pair and qa_pair['answer']:
                    # extract the answer
                    answer = qa_pair['answer']

                    # deduce the answer type by checking the answer's fields
                    answer_type = get_answer_type(answer)
                    # skip answers with an invalid type
                    if answer_type is None or answer_type not in self._answer_types_filter:
                        continue

                    answer_annotations.append(answer)

                if 'validated_answers' in qa_pair and qa_pair['validated_answers']:
                    answer_annotations += qa_pair['validated_answers']

                if self._is_training and answer_type is None:
                    continue

                instance = self.text_to_instance(question_text,
                                                 passage_text,
                                                 passage_tokens,
                                                 passage_text_index_to_token_index,
                                                 passage_wordpieces,
                                                 number_occurrences_in_passage,
                                                 question_id,
                                                 passage_id,
                                                 answer_annotations,
                                                 answer_type,
                                                 global_index + relative_index)
                if instance is not None:
                    instances_count += 1
                    if not self._lazy:
                        instances.append(instance)
                    yield instance
            global_index += len(passage_info['qa_pairs'])
        if not self._lazy and self._pickle['action'] == 'save':
            save_pkl(instances, self._pickle, self._is_training)
Example #6
0
    def _read(self, file_path: str):
        instances_count = 0
        if not self._lazy and self._pickle['action'] == 'load':
            # Try to load the data, if it fails then read it from scratch and save it
            loaded_pkl = load_pkl(self._pickle, self._is_training)
            if loaded_pkl is not None:
                for instance in loaded_pkl:
                    if 0 < self._max_instances <= instances_count:
                        return
                    instances_count += 1
                    yield instance
                return
            else:
                self._pickle['action'] = 'save'

        file_path = cached_path(file_path)
        with open(file_path, encoding='utf8') as dataset_file:
            dataset = json.load(dataset_file)

        dataset = standardize_dataset(dataset, self._standardize_text_func)

        global_index = 0
        instances = []
        for passage_id, passage_info in tqdm(dataset.items()):
            passage_text = passage_info['passage']

            # Tokenize passage
            passage_tokens = self._tokenizer.tokenize_with_offsets(
                passage_text)
            passage_text_index_to_token_index = index_text_to_tokens(
                passage_text, passage_tokens)
            passage_words = number_extraction_tokens = self._word_tokenize(
                passage_text)
            passage_alignment = self._tokenizer.align_tokens_to_tokens(
                passage_text, passage_words, passage_tokens)
            passage_wordpieces = self._tokenizer.alignment_to_token_wordpieces(
                passage_alignment)

            number_occurrences_in_passage =\
                extract_number_occurrences(number_extraction_tokens, passage_alignment)

            # Process questions from this passage
            for relative_index, qa_pair in enumerate(passage_info['qa_pairs']):
                if 0 < self._max_instances <= instances_count:
                    if not self._lazy and self._pickle['action'] == 'save':
                        save_pkl(instances, self._pickle, self._is_training)
                    return

                question_id = qa_pair['query_id']
                question_text = qa_pair['question']

                answer_annotations: List[Dict] = list()
                answer_type = None
                if 'answer' in qa_pair and qa_pair['answer']:
                    answer = qa_pair['answer']

                    answer_type = get_answer_type(answer)
                    if answer_type is None or answer_type not in self._answer_types_filter:
                        continue

                    answer_annotations.append(answer)

                if 'validated_answers' in qa_pair and qa_pair[
                        'validated_answers']:
                    answer_annotations += qa_pair['validated_answers']

                if self._is_training and answer_type is None:
                    continue

                instance = self.text_to_instance(
                    question_text, passage_text, passage_tokens,
                    passage_text_index_to_token_index, passage_wordpieces,
                    number_occurrences_in_passage, question_id, passage_id,
                    answer_annotations, answer_type,
                    global_index + relative_index)
                if instance is not None:
                    instances_count += 1
                    if not self._lazy:
                        instances.append(instance)
                    yield instance
            global_index += len(passage_info['qa_pairs'])
        if not self._lazy and self._pickle['action'] == 'save':
            save_pkl(instances, self._pickle, self._is_training)
def get_analysis(dataset_pkl_path, predictions_path, heads):
    with open(dataset_pkl_path, 'rb') as dataset_pkl:
        dataset = pickle.load(dataset_pkl)

    predictions = {}
    with open(predictions_path, 'rb') as predictions_file:
        for line in predictions_file:
            prediction = json.loads(line)
            predictions[prediction['query_id']] = prediction

    tokenizer = HuggingfaceTransformersTokenizer(
        pretrained_model='roberta-large')
    word_tokenizer = custom_word_tokenizer()
    word_tokenize =\
        lambda text: [token for token in split_tokens_by_hyphen(word_tokenizer.tokenize(text))]

    num_gold_tokens_to_stats = defaultdict(lambda: defaultdict(int))
    num_occurrences_stats = defaultdict(lambda: defaultdict(int))
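    # num_gold_tokens_to_stats buckets predictions by the number of tokens in
    # the gold answer; num_occurrences_stats buckets them by how many times
    # the gold answer text occurs in the question/passage.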
    for i, instance in tqdm(enumerate(dataset)):
        full_prediction = predictions[instance['metadata']['question_id']]
        if full_prediction['predicted_ability'] not in heads:
            continue

        question_text = instance['metadata']['original_question']
        question_tokens = tokenizer.tokenize_with_offsets(question_text)
        question_text_index_to_token_index = index_text_to_tokens(
            question_text, question_tokens)
        question_words = word_tokenize(question_text)
        question_alignment = tokenizer.align_tokens_to_tokens(
            question_text, question_words, question_tokens)
        question_wordpieces = tokenizer.alignment_to_token_wordpieces(
            question_alignment)

        passage_text = instance['metadata']['original_passage']
        passage_tokens = tokenizer.tokenize_with_offsets(passage_text)
        passage_text_index_to_token_index = index_text_to_tokens(
            passage_text, passage_tokens)
        passage_words = word_tokenize(passage_text)
        passage_alignment = tokenizer.align_tokens_to_tokens(
            passage_text, passage_words, passage_tokens)
        passage_wordpieces = tokenizer.alignment_to_token_wordpieces(
            passage_alignment)
        max_passage_length = instance['metadata']['max_passage_length']
        if max_passage_length > -1:
            passage_text = passage_text[:max_passage_length]

        question_passage_tokens = instance['metadata'][
            'question_passage_tokens']
        question_passage_wordpieces = instance['metadata'][
            'question_passage_wordpieces']

        # Index tokens
        encoded_inputs = tokenizer.encode_plus(
            [token.text for token in question_tokens],
            [token.text for token in passage_tokens],
            add_special_tokens=True,
            max_length=MAX_PIECES,
            truncation_strategy='only_second',
            return_token_type_ids=True,
            return_special_tokens_mask=True)
        question_passage_token_type_ids = encoded_inputs['token_type_ids']
        question_passage_special_tokens_mask = encoded_inputs[
            'special_tokens_mask']

        question_position = tokenizer.get_type_position_in_sequence(
            0, question_passage_token_type_ids,
            question_passage_special_tokens_mask)
        passage_position = tokenizer.get_type_position_in_sequence(
            1, question_passage_token_type_ids,
            question_passage_special_tokens_mask)
        question_passage_tokens, num_of_tokens_per_type = tokenizer.convert_to_tokens(
            encoded_inputs, [{
                'tokens': question_tokens,
                'wordpieces': question_wordpieces,
                'position': question_position
            }, {
                'tokens': passage_tokens,
                'wordpieces': passage_wordpieces,
                'position': passage_position
            }])

        # Adjust text index to token index
        question_text_index_to_token_index = [
            token_index + question_position
            for i, token_index in enumerate(question_text_index_to_token_index)
            if token_index < num_of_tokens_per_type[0]
        ]
        passage_text_index_to_token_index = [
            token_index + passage_position
            for i, token_index in enumerate(passage_text_index_to_token_index)
            if token_index < num_of_tokens_per_type[1]
        ]

        prediction = full_prediction['answer']['value']
        maximizing_ground_truth = full_prediction['maximizing_ground_truth']
        answer_accessor, answer_texts = extract_answer_info_from_annotation(
            maximizing_ground_truth)

        gold_indexes = {'question': None, 'passage': None}

        alignment = align_predicted_and_maximizing_gold(
            prediction, answer_texts)
        for gold_index, predicted_index in enumerate(alignment):
            num_of_gold_tokens = len(
                tokenizer.tokenize(answer_texts[gold_index]))
            num_of_predicted_tokens = len(
                tokenizer.tokenize(prediction[predicted_index]))
            num_gold_tokens_to_stats[num_of_gold_tokens]['count'] += 1
            num_gold_tokens_to_stats[num_of_gold_tokens][
                'num_predicted_tokens'] += num_of_predicted_tokens
            num_gold_tokens_to_stats[num_of_gold_tokens][
                'em'] += full_prediction['em'] * 100
            num_gold_tokens_to_stats[num_of_gold_tokens][
                'f1'] += full_prediction['f1'] * 100

            answer_spans = []
            answer_spans += find_valid_spans(
                question_text, [answer_texts[gold_index]],
                question_text_index_to_token_index, question_passage_tokens,
                question_passage_wordpieces, gold_indexes['question'])
            answer_spans += find_valid_spans(
                passage_text, [answer_texts[gold_index]],
                passage_text_index_to_token_index, question_passage_tokens,
                question_passage_wordpieces, gold_indexes['passage'])
            occurrences = len(answer_spans)
            num_occurrences_stats[occurrences]['count'] += 1
            num_occurrences_stats[occurrences][
                'em'] += full_prediction['em'] * 100
            num_occurrences_stats[occurrences][
                'f1'] += full_prediction['f1'] * 100

    return num_gold_tokens_to_stats, num_occurrences_stats
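
The returned dictionaries only accumulate sums; a hypothetical usage sketch, under the assumption that the caller averages 'em'/'f1' by each bucket's 'count' (the paths and the head name below are placeholders, not real files from this project):

    # 'dataset.pkl', 'predictions.jsonl' and 'multi_span' are illustrative placeholders
    num_gold_tokens_to_stats, num_occurrences_stats = get_analysis(
        'dataset.pkl', 'predictions.jsonl', heads=['multi_span'])
    for num_tokens, stats in sorted(num_gold_tokens_to_stats.items()):
        count = stats['count']
        print(f"{num_tokens} gold tokens: n={count}, "
              f"EM={stats['em'] / count:.1f}, F1={stats['f1'] / count:.1f}")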