Example #1
    def _read(self, file_path: str):
        """Read a DROP-style JSON file and yield one instance per question."""
        # If `file_path` is a URL, download it and read from the local cache.
        file_path = cached_path(file_path)
        logger.info("Reading file at %s", file_path)
        with open(file_path) as dataset_file:
            dataset = json.load(dataset_file)
        logger.info("Reading the dataset")
        kept_count, skip_count = 0, 0
        for passage_id, passage_info in dataset.items():
            passage_text = passage_info["passage"]
            passage_tokens = self._tokenizer.tokenize(passage_text)
            passage_tokens = split_tokens_by_hyphen(passage_tokens)
            for question_answer in passage_info["qa_pairs"]:
                question_id = question_answer["query_id"]
                question_text = question_answer["question"].strip()
                answer_annotations = []
                if "answer" in question_answer:
                    answer_annotations.append(question_answer["answer"])
                if "validated_answers" in question_answer:
                    answer_annotations += question_answer["validated_answers"]

                instance = self.text_to_instance(
                    question_text,
                    passage_text,
                    question_id,
                    passage_id,
                    answer_annotations,
                    passage_tokens,
                )
                if instance is not None:
                    kept_count += 1
                    yield instance
                else:
                    skip_count += 1
        logger.info(f"Skipped {skip_count} questions, kept {kept_count} questions.")
Example #2
    def text_to_instance(
        self,  # type: ignore
        question_text: str,
        passage_text: str,
        question_id: Optional[str] = None,
        passage_id: Optional[str] = None,
        answer_annotations: Optional[List[Dict]] = None,
        passage_tokens: Optional[List[Token]] = None,
    ) -> Optional[Instance]:

        if not passage_tokens:
            passage_tokens = self._tokenizer.tokenize(passage_text)
            passage_tokens = split_tokens_by_hyphen(passage_tokens)
        question_tokens = self._tokenizer.tokenize(question_text)
        question_tokens = split_tokens_by_hyphen(question_tokens)
        if self.passage_length_limit is not None:
            passage_tokens = passage_tokens[:self.passage_length_limit]
        if self.question_length_limit is not None:
            question_tokens = question_tokens[:self.question_length_limit]

        answer_type: Optional[str] = None
        answer_texts: List[str] = []
        if answer_annotations:
            # Currently we only use the first annotated answer. This does not
            # affect training, because the train set has only one annotation
            # per question.
            answer_type, answer_texts = self.extract_answer_info_from_annotation(
                answer_annotations[0])

        # Tokenize the answer texts so that answer spans can be matched
        # against the passage token by token.
        tokenized_answer_texts = []
        for answer_text in answer_texts:
            answer_tokens = self._tokenizer.tokenize(answer_text)
            answer_tokens = split_tokens_by_hyphen(answer_tokens)
            tokenized_answer_texts.append(
                " ".join(token.text for token in answer_tokens))

        if self.instance_format == "squad":
            valid_passage_spans = (
                self.find_valid_spans(passage_tokens, tokenized_answer_texts)
                if tokenized_answer_texts else [])
            if not valid_passage_spans:
                if "passage_span" in self.skip_when_all_empty:
                    return None
                else:
                    # Fall back to a dummy span pointing at the last passage token.
                    valid_passage_spans.append(
                        (len(passage_tokens) - 1, len(passage_tokens) - 1))
            return make_reading_comprehension_instance(
                question_tokens,
                passage_tokens,
                self._token_indexers,
                passage_text,
                valid_passage_spans,
                # this `answer_texts` will not be used for evaluation
                answer_texts,
                additional_metadata={
                    "original_passage": passage_text,
                    "original_question": question_text,
                    "passage_id": passage_id,
                    "question_id": question_id,
                    "valid_passage_spans": valid_passage_spans,
                    "answer_annotations": answer_annotations,
                },
            )
        elif self.instance_format == "bert":
            question_concat_passage_tokens = (
                question_tokens + [Token("[SEP]")] + passage_tokens)
            valid_passage_spans = []
            for span in self.find_valid_spans(passage_tokens,
                                              tokenized_answer_texts):
                # Shift the span so it indexes into `question + [SEP] + passage`.
                valid_passage_spans.append(
                    (span[0] + len(question_tokens) + 1,
                     span[1] + len(question_tokens) + 1))
            if not valid_passage_spans:
                if "passage_span" in self.skip_when_all_empty:
                    return None
                else:
                    valid_passage_spans.append((
                        len(question_concat_passage_tokens) - 1,
                        len(question_concat_passage_tokens) - 1,
                    ))
            answer_info = {
                # this `answer_texts` will not be used for evaluation
                "answer_texts": answer_texts,
                "answer_passage_spans": valid_passage_spans,
            }
            return self.make_bert_drop_instance(
                question_tokens,
                passage_tokens,
                question_concat_passage_tokens,
                self._token_indexers,
                passage_text,
                answer_info,
                additional_metadata={
                    "original_passage": passage_text,
                    "original_question": question_text,
                    "passage_id": passage_id,
                    "question_id": question_id,
                    "answer_annotations": answer_annotations,
                },
            )
        elif self.instance_format == "drop":
            numbers_in_passage = []
            number_indices = []
            for token_index, token in enumerate(passage_tokens):
                number = self.convert_word_to_number(token.text)
                if number is not None:
                    numbers_in_passage.append(number)
                    number_indices.append(token_index)
            # Hack: always append a dummy number (0 at index -1) so that the
            # padded number list is never empty.
            numbers_in_passage.append(0)
            number_indices.append(-1)
            numbers_as_tokens = [
                Token(str(number)) for number in numbers_in_passage
            ]

            valid_passage_spans = (
                self.find_valid_spans(passage_tokens, tokenized_answer_texts)
                if tokenized_answer_texts else [])
            valid_question_spans = (
                self.find_valid_spans(question_tokens, tokenized_answer_texts)
                if tokenized_answer_texts else [])

            # `answer_texts` is a list of valid answers; keep those that
            # parse as numbers.
            target_numbers = []
            for answer_text in answer_texts:
                number = self.convert_word_to_number(answer_text)
                if number is not None:
                    target_numbers.append(number)
            valid_signs_for_add_sub_expressions: List[List[int]] = []
            valid_counts: List[int] = []
            if answer_type in ["number", "date"]:
                valid_signs_for_add_sub_expressions = (
                    self.find_valid_add_sub_expressions(numbers_in_passage,
                                                        target_numbers))
            if answer_type == "number":
                # Currently we only support counting answers from 0 to 9.
                numbers_for_count = list(range(10))
                valid_counts = self.find_valid_counts(numbers_for_count,
                                                      target_numbers)

            type_to_answer_map = {
                "passage_span": valid_passage_spans,
                "question_span": valid_question_spans,
                "addition_subtraction": valid_signs_for_add_sub_expressions,
                "counting": valid_counts,
            }

            if self.skip_when_all_empty and not any(
                    type_to_answer_map[skip_type]
                    for skip_type in self.skip_when_all_empty):
                return None

            answer_info = {
                # this `answer_texts` will not be used for evaluation
                "answer_texts": answer_texts,
                "answer_passage_spans": valid_passage_spans,
                "answer_question_spans": valid_question_spans,
                "signs_for_add_sub_expressions": valid_signs_for_add_sub_expressions,
                "counts": valid_counts,
            }

            return self.make_marginal_drop_instance(
                question_tokens,
                passage_tokens,
                numbers_as_tokens,
                number_indices,
                self._token_indexers,
                passage_text,
                answer_info,
                additional_metadata={
                    "original_passage": passage_text,
                    "original_question": question_text,
                    "original_numbers": numbers_in_passage,
                    "passage_id": passage_id,
                    "question_id": question_id,
                    "answer_info": answer_info,
                    "answer_annotations": answer_annotations,
                },
            )
        else:
            raise ValueError(
                'Expected the instance format to be "drop", "squad" or "bert", '
                f"but got {self.instance_format}")