Example #1
    def test_splits_roberta(self):
        tokenizer = PretrainedTransformerTokenizer("roberta-base")

        sentence = "A, <mask> AllenNLP sentence."
        expected_tokens = ["<s>", "A", ",", "<mask>", "Allen", "N", "LP", "Ġsentence", ".", "</s>"]
        tokens = [t.text for t in tokenizer.tokenize(sentence)]
        assert tokens == expected_tokens

        # sentence pair
        sentence_1 = "A, <mask> AllenNLP sentence."
        sentence_2 = "A sentence."
        expected_tokens = [
            "<s>",
            "A",
            ",",
            "<mask>",
            "Allen",
            "N",
            "LP",
            "Ġsentence",
            ".",
            "</s>",
            "</s>",
            "A",
            "Ġsentence",
            ".",
            "</s>",
        ]
        tokens = [t.text for t in tokenizer.tokenize_sentence_pair(sentence_1, sentence_2)]
        assert tokens == expected_tokens
Example #2
    def test_as_array_produces_token_sequence_roberta_sentence_pair(self):
        tokenizer = AutoTokenizer.from_pretrained("roberta-base")
        allennlp_tokenizer = PretrainedTransformerTokenizer("roberta-base")
        indexer = PretrainedTransformerIndexer(model_name="roberta-base")
        default_format = "<s> AllenNLP is great! </s> </s> Really it is! </s>"
        tokens = tokenizer.tokenize(default_format)
        expected_ids = tokenizer.convert_tokens_to_ids(tokens)
        allennlp_tokens = allennlp_tokenizer.tokenize_sentence_pair(
            "AllenNLP is great!", "Really it is!")
        vocab = Vocabulary()
        indexed = indexer.tokens_to_indices(allennlp_tokens, vocab, "key")
        assert indexed["key"] == expected_ids
Example #3
    def test_as_array_produces_token_sequence_bert_cased_sentence_pair(self):
        tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
        allennlp_tokenizer = PretrainedTransformerTokenizer("bert-base-cased")
        indexer = PretrainedTransformerIndexer(model_name="bert-base-cased")
        default_format = "[CLS] AllenNLP is great! [SEP] Really it is! [SEP]"
        tokens = tokenizer.tokenize(default_format)
        expected_ids = tokenizer.convert_tokens_to_ids(tokens)
        allennlp_tokens = allennlp_tokenizer.tokenize_sentence_pair(
            "AllenNLP is great!", "Really it is!")
        vocab = Vocabulary()
        indexed = indexer.tokens_to_indices(allennlp_tokens, vocab)
        assert indexed["token_ids"] == expected_ids
Example #4
    def test_splits_cased_bert(self):
        tokenizer = PretrainedTransformerTokenizer("bert-base-cased")

        sentence = "A, [MASK] AllenNLP sentence."
        expected_tokens = [
            "[CLS]",
            "A",
            ",",
            "[MASK]",
            "Allen",
            "##NL",
            "##P",
            "sentence",
            ".",
            "[SEP]",
        ]
        tokens = [t.text for t in tokenizer.tokenize(sentence)]
        assert tokens == expected_tokens

        # sentence pair
        sentence_1 = "A, [MASK] AllenNLP sentence."
        sentence_2 = "A sentence."
        expected_tokens = [
            "[CLS]",
            "A",
            ",",
            "[MASK]",
            "Allen",
            "##NL",
            "##P",
            "sentence",
            ".",
            "[SEP]",
            "A",
            "sentence",
            ".",
            "[SEP]",
        ]
        tokens = [
            t.text
            for t in tokenizer.tokenize_sentence_pair(sentence_1, sentence_2)
        ]
        assert tokens == expected_tokens
Example #5
    def test_token_idx_sentence_pairs(self):
        first_sentence = "I went to the zoo yesterday, but they had only one animal."
        second_sentence = "It was a shitzu."
        expected_tokens = [
            "<s>",
            "I",
            "Ġwent",
            "Ġto",
            "Ġthe",
            "Ġzoo",
            "Ġyesterday",
            ",",
            "Ġbut",
            "Ġthey",
            "Ġhad",
            "Ġonly",
            "Ġone",
            "Ġanimal",
            ".",
            "</s>",
            "</s>",
            "It",
            "Ġwas",
            "Ġa",
            "Ġsh",
            "itz",
            "u",
            ".",
            "</s>",
        ]
        expected_idxs = [
            None,
            0,
            2,
            7,
            10,
            14,
            18,
            27,
            29,
            33,
            38,
            42,
            47,
            51,
            57,
            None,
            None,
            58,
            61,
            65,
            67,
            69,
            72,
            73,
            None,
        ]

        tokenizer = PretrainedTransformerTokenizer(
            "roberta-base", calculate_character_offsets=True)
        tokenized = tokenizer.tokenize_sentence_pair(first_sentence,
                                                     second_sentence)
        tokens = [t.text for t in tokenized]
        assert tokens == expected_tokens
        idxs = [t.idx for t in tokenized]
        assert idxs == expected_idxs

        # Assert that the first and the second sentence are run together with no space in between.
        first_sentence_end_index = tokens.index("</s>") - 1
        second_sentence_start_index = first_sentence_end_index + 3
        assert (idxs[first_sentence_end_index] +
                len(tokens[first_sentence_end_index]) ==
                idxs[second_sentence_start_index])
Example #6
class TransformerBinaryReader(DatasetReader):
    """
    Supports reading artisets for both transformer models (e.g., pretrained_model="roberta-base" and
    combine_input_fields=True) and NLI models with separate premise and hypothesis (set combine_input_fields=False).
    Empty contexts (premises) are replaced by "N/A" in the NLI case.
    max_pieces and add_prefix only apply to transformer models
    """
    def __init__(self,
                 pretrained_model: str = None,
                 tokenizer: Optional[Tokenizer] = None,
                 token_indexers: Dict[str, TokenIndexer] = None,
                 max_pieces: int = 512,
                 add_prefix: bool = False,
                 combine_input_fields: bool = True,
                 sample: int = -1) -> None:
        super().__init__()

        if pretrained_model is not None:
            self._tokenizer = PretrainedTransformerTokenizer(
                pretrained_model, max_length=max_pieces)
            token_indexer = PretrainedTransformerIndexer(pretrained_model)
            self._token_indexers = {'tokens': token_indexer}
        else:
            self._tokenizer = tokenizer or SpacyTokenizer()
            self._token_indexers = token_indexers or {
                "tokens": SingleIdTokenIndexer()
            }

        self._sample = sample
        self._add_prefix = add_prefix
        self._combine_input_fields = combine_input_fields
        self._debug_prints = -1

    @overrides
    def _read(self, file_path: str):
        self._debug_prints = 5
        cached_file_path = cached_path(file_path)

        if file_path.endswith('.gz'):
            data_file = gzip.open(cached_file_path, 'rb')
        else:
            data_file = open(cached_file_path, 'r')

        logger.info("Reading QA instances from jsonl dataset at: %s",
                    file_path)
        item_jsons = []
        for line in data_file:
            item_jsons.append(json.loads(line.strip()))

        if self._sample != -1:
            item_jsons = random.sample(item_jsons, self._sample)
            logger.info("Sampling %d examples", self._sample)

        for item_json in Tqdm.tqdm(item_jsons, total=len(item_jsons)):
            self._debug_prints -= 1
            if self._debug_prints >= 0:
                logger.info(f"====================================")
                logger.info(f"Input json: {item_json}")
            item_id = item_json["id"]

            statement_text = item_json["phrase"]
            metadata = item_json.get("metadata", {})
            context = item_json.get("context")

            yield self.text_to_instance(item_id=item_id,
                                        question=statement_text,
                                        answer_id=item_json["answer"],
                                        context=context,
                                        org_metadata=metadata)

        data_file.close()

    @overrides
    def text_to_instance(
            self,  # type: ignore
            item_id: str,
            question: str,
            answer_id: int = None,
            context: str = None,
            org_metadata: dict = {}) -> Instance:
        fields: Dict[str, Field] = {}
        if self._combine_input_fields:
            qa_tokens = self.transformer_features_from_qa(question, context)
            qa_field = TextField(qa_tokens, self._token_indexers)
            fields['phrase'] = qa_field
        else:
            premise = context
            if not premise:
                premise = "N/A"
            premise_tokens = self._tokenizer.tokenize(premise)
            hypothesis_tokens = self._tokenizer.tokenize(question)
            qa_tokens = [premise_tokens, hypothesis_tokens]
            fields["premise"] = TextField(premise_tokens, self._token_indexers)
            fields["hypothesis"] = TextField(hypothesis_tokens,
                                             self._token_indexers)

        if answer_id is not None:
            fields['label'] = LabelField(answer_id, skip_indexing=True)
        new_metadata = {
            "id": item_id,
            "question_text": question,
            "context": context,
            "correct_answer_index": answer_id
        }

        # TODO Alon get rid of this in production...
        if 'skills' in org_metadata:
            new_metadata.update({'skills': org_metadata['skills']})
        if 'tags' in org_metadata:
            new_metadata.update({'tags': org_metadata['tags']})

        if self._debug_prints >= 0:
            logger.info(f"Tokens: {qa_tokens}")
            logger.info(f"Label: {answer_id}")
        fields["metadata"] = MetadataField(new_metadata)
        return Instance(fields)

    def transformer_features_from_qa(self, question: str, context: str):
        if self._add_prefix:
            question = "Q: " + question
            if context is not None and len(context) > 0:
                context = "C: " + context
        if context is not None and len(context) > 0:
            tokens = self._tokenizer.tokenize_sentence_pair(question, context)
        else:
            tokens = self._tokenizer.tokenize(question)
        return tokens
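
As a rough illustration of the two modes described in the docstring above, here is a minimal,
hypothetical usage sketch. It assumes the AllenNLP imports the class itself relies on are in
place, that the "roberta-base" weights are available, and (for the NLI branch) that spaCy's
default English model is installed; the item values are invented for illustration.

# Hypothetical usage sketch, not part of the original example.
# Combined mode: question and context end up in a single 'phrase' TextField.
reader = TransformerBinaryReader(pretrained_model="roberta-base", max_pieces=256)
instance = reader.text_to_instance(
    item_id="demo-1",
    question="AllenNLP is great!",
    answer_id=1,
    context="Really it is!")
print(sorted(instance.fields))  # ['label', 'metadata', 'phrase']

# NLI mode: premise and hypothesis are kept in separate TextFields,
# and an empty premise is replaced by "N/A".
nli_reader = TransformerBinaryReader(combine_input_fields=False)
nli_instance = nli_reader.text_to_instance(
    item_id="demo-2",
    question="A hypothesis sentence.",
    answer_id=0,
    context="")
print(sorted(nli_instance.fields))  # ['hypothesis', 'label', 'metadata', 'premise']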
Example #7
class RuleReasoningReader(DatasetReader):
    """

    Parameters
    ----------
    """
    def __init__(self,
                 pretrained_model: str,
                 max_pieces: int = 512,
                 syntax: str = "rulebase",
                 add_prefix: Dict[str, str] = None,
                 skip_id_regex: str = None,
                 scramble_context: bool = False,
                 use_context_full: bool = False,
                 sample: int = -1) -> None:
        super().__init__()
        self._tokenizer = PretrainedTransformerTokenizer(pretrained_model,
                                                         max_length=max_pieces)
        self._tokenizer_internal = self._tokenizer.tokenizer
        token_indexer = PretrainedTransformerIndexer(pretrained_model)
        self._token_indexers = {'tokens': token_indexer}

        self._max_pieces = max_pieces
        self._add_prefix = add_prefix
        self._scramble_context = scramble_context
        self._use_context_full = use_context_full
        self._sample = sample
        self._syntax = syntax
        self._skip_id_regex = skip_id_regex

    @overrides
    def _read(self, file_path: str):
        instances = self._read_internal(file_path)
        return instances

    def _read_internal(self, file_path: str):
        # if `file_path` is a URL, redirect to the cache
        file_path = cached_path(file_path)
        counter = self._sample + 1
        debug = 5
        is_done = False

        with open(file_path, 'r') as data_file:
            logger.info("Reading instances from jsonl dataset at: %s",
                        file_path)
            for line in data_file:
                if is_done:
                    break
                item_json = json.loads(line.strip())
                item_id = item_json.get("id", "NA")
                if self._skip_id_regex and re.match(self._skip_id_regex,
                                                    item_id):
                    continue

                if self._syntax == "rulebase":
                    questions = item_json['questions']
                    if self._use_context_full:
                        context = item_json.get('context_full', '')
                    else:
                        context = item_json.get('context', "")
                elif self._syntax == "propositional-meta":
                    questions = item_json['questions'].items()
                    sentences = [x['text'] for x in item_json['triples'].values()] + \
                                [x['text'] for x in item_json['rules'].values()]
                    if self._scramble_context:
                        random.shuffle(sentences)
                    context = " ".join(sentences)
                else:
                    raise ValueError(f"Unknown syntax {self._syntax}")

                for question in questions:
                    counter -= 1
                    debug -= 1
                    if counter == 0:
                        is_done = True
                        break
                    if debug > 0:
                        logger.info(item_json)
                    if self._syntax == "rulebase":
                        text = question['text']
                        q_id = question.get('id')
                        label = None
                        if 'label' in question:
                            label = 1 if question['label'] else 0
                    elif self._syntax == "propositional-meta":
                        text = question[1]['question']
                        q_id = f"{item_id}-{question[0]}"
                        label = question[1].get('propAnswer')
                        if label is not None:
                            label = ["False", "True", "Unknown"].index(label)

                    yield self.text_to_instance(item_id=q_id,
                                                question_text=text,
                                                context=context,
                                                label=label,
                                                debug=debug)

    @overrides
    def text_to_instance(
            self,  # type: ignore
            item_id: str,
            question_text: str,
            label: int = None,
            context: str = None,
            debug: int = -1) -> Instance:
        # pylint: disable=arguments-differ
        fields: Dict[str, Field] = {}

        qa_tokens, segment_ids = self.transformer_features_from_qa(
            question_text, context)
        qa_field = TextField(qa_tokens, self._token_indexers)
        fields['phrase'] = qa_field

        metadata = {
            "id": item_id,
            "question_text": question_text,
            "tokens": [x.text for x in qa_tokens],
            "context": context
        }

        if label is not None:
            # We'll assume integer labels don't need indexing
            fields['label'] = LabelField(label,
                                         skip_indexing=isinstance(label, int))
            metadata['label'] = label

        if debug > 0:
            logger.info(f"qa_tokens = {qa_tokens}")
            logger.info(f"context = {context}")
            logger.info(f"label = {label}")

        fields["metadata"] = MetadataField(metadata)

        return Instance(fields)

    def transformer_features_from_qa(self, question: str, context: str):
        if self._add_prefix is not None:
            question = self._add_prefix.get("q", "") + question
            context = self._add_prefix.get("c", "") + context
        if context is not None:
            tokens = self._tokenizer.tokenize_sentence_pair(question, context)
        else:
            tokens = self._tokenizer.tokenize(question)
        segment_ids = [0] * len(tokens)

        return tokens, segment_ids
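
Below is a minimal, hypothetical sketch of driving this reader's text_to_instance directly.
The model name, context, and label are illustrative only and not taken from a real dataset;
the class's own AllenNLP imports are assumed to be in place.

# Hypothetical usage sketch, not part of the original example.
reader = RuleReasoningReader(pretrained_model="roberta-base", syntax="rulebase")
instance = reader.text_to_instance(
    item_id="demo-Q1",
    question_text="The dog is green.",
    context="If something is big then it is green. The dog is big.",
    label=1)
print([t.text for t in instance.fields["phrase"].tokens][:5])
print(instance.fields["label"].label)  # 1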
Example #8
class TransformerMCQAReader(DatasetReader):
    """

    Parameters
    ----------
    """
    def __init__(self,
                 pretrained_model: str,
                 max_pieces: int = 512,
                 num_choices: int = 4,
                 add_prefix: Dict[str, str] = None,
                 sample: int = -1) -> None:
        super().__init__()

        self._tokenizer = PretrainedTransformerTokenizer(pretrained_model)
        self._tokenizer_internal = self._tokenizer._tokenizer
        token_indexer = PretrainedTransformerIndexer(pretrained_model)
        self._token_indexers = {'tokens': token_indexer}

        self._max_pieces = max_pieces
        self._sample = sample
        self._num_choices = num_choices
        self._add_prefix = add_prefix or {}

        for model in [
                "roberta", "bert", "openai-gpt", "gpt2", "transfo-xl", "xlnet",
                "xlm"
        ]:
            if model in pretrained_model:
                self._model_type = model
                break

    @overrides
    def _read(self, file_path: str):
        cached_file_path = cached_path(file_path)

        if file_path.endswith('.gz'):
            data_file = gzip.open(cached_file_path, 'rb')
        else:
            data_file = open(cached_file_path, 'r')

        logger.info("Reading QA instances from jsonl dataset at: %s",
                    file_path)
        item_jsons = []
        for line in data_file:
            item_jsons.append(json.loads(line.strip()))

        if self._sample != -1:
            item_jsons = random.sample(item_jsons, self._sample)
            logger.info("Sampling %d examples", self._sample)

        for item_json in Tqdm.tqdm(item_jsons, total=len(item_jsons)):
            item_id = item_json["id"]

            question_text = item_json["question"]["stem"]

            choice_label_to_id = {}
            choice_text_list = []
            choice_context_list = []
            choice_label_list = []
            choice_annotations_list = []

            any_correct = False
            choice_id_correction = 0

            for choice_id, choice_item in enumerate(
                    item_json["question"]["choices"]):
                choice_label = choice_item["label"]
                choice_label_to_id[
                    choice_label] = choice_id - choice_id_correction
                choice_text = choice_item["text"]

                choice_text_list.append(choice_text)
                choice_label_list.append(choice_label)

                if item_json.get('answerKey') == choice_label:
                    if any_correct:
                        raise ValueError(
                            f"More than one correct answer found for {item_json}!"
                        )
                    any_correct = True

            if not any_correct and 'answerKey' in item_json:
                raise ValueError("No correct answer found for {item_json}!")

            answer_id = choice_label_to_id.get(item_json.get("answerKey"))
            # Pad choices with empty strings if not right number
            if len(choice_text_list) != self._num_choices:
                choice_text_list = (
                    choice_text_list +
                    self._num_choices * [''])[:self._num_choices]
                choice_context_list = (
                    choice_context_list +
                    self._num_choices * [None])[:self._num_choices]
                if answer_id is not None and answer_id >= self._num_choices:
                    logger.warning(
                        f"Skipping question with more than {self._num_choices} answers: {item_json}"
                    )
                    continue

            yield self.text_to_instance(item_id=item_id,
                                        question=question_text,
                                        choice_list=choice_text_list,
                                        answer_id=answer_id)

        data_file.close()

    @overrides
    def text_to_instance(
            self,  # type: ignore
            item_id: str,
            question: str,
            choice_list: List[str],
            answer_id: int = None) -> Instance:
        fields: Dict[str, Field] = {}

        qa_fields = []
        segment_ids_fields = []
        qa_tokens_list = []
        annotation_tags_fields = []
        for idx, choice in enumerate(choice_list):
            choice_annotations = []
            qa_tokens, segment_ids = self.transformer_features_from_qa(
                question, choice)
            qa_field = TextField(qa_tokens, self._token_indexers)
            segment_ids_field = SequenceLabelField(segment_ids, qa_field)
            qa_fields.append(qa_field)
            qa_tokens_list.append(qa_tokens)
            segment_ids_fields.append(segment_ids_field)

        fields['question'] = ListField(qa_fields)
        fields['segment_ids'] = ListField(segment_ids_fields)
        if answer_id is not None:
            fields['label'] = LabelField(answer_id, skip_indexing=True)

        metadata = {
            "id": item_id,
            "question_text": question,
            "choice_text_list": choice_list,
            "correct_answer_index": answer_id,
            "question_tokens_list": qa_tokens_list,
        }

        if len(annotation_tags_fields) > 0:
            fields['annotation_tags'] = ListField(annotation_tags_fields)
            metadata['annotation_tags'] = [
                x.array for x in annotation_tags_fields
            ]

        fields["metadata"] = MetadataField(metadata)

        return Instance(fields)

    def transformer_features_from_qa(self, question: str, answer: str):
        question = self._add_prefix.get("q", "") + question
        answer = self._add_prefix.get("a", "") + answer

        # Alon changing mask type:
        if self._model_type in ['roberta', 'xlnet']:
            question = question.replace('[MASK]', '<mask>')
        elif self._model_type in ['albert']:
            question = question.replace('[MASK]', '[MASK]>')

        tokens = self._tokenizer.tokenize_sentence_pair(question, answer)

        # TODO make sure the segments IDs do not contribute
        segment_ids = [0] * len(tokens)

        return tokens, segment_ids
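
Finally, a minimal, hypothetical sketch of this reader's per-question encoding. The question,
choices, and answer index are invented; only the resulting field layout is meant to be
illustrative, and the class's own AllenNLP imports are assumed to be in place.

# Hypothetical usage sketch, not part of the original example.
reader = TransformerMCQAReader(pretrained_model="roberta-base", num_choices=4)
instance = reader.text_to_instance(
    item_id="demo-MC1",
    question="Which of these is a fruit?",
    choice_list=["apple", "chair", "cloud", "spoon"],
    answer_id=0)
print(len(instance.fields["question"].field_list))  # 4: one TextField per choice
print(instance.fields["label"].label)               # 0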