示例#1
0
    def test_length_limit_works(self):
        """Long inputs are split into several windows that overlap by `stride`.

        With `length_limit=100`, `max_query_length=10` and `stride=20`, the
        squad.json fixture yields 12 instances. The checks cover:
        - where the context begins in the first window;
        - that the first two windows share their leading question tokens but
          differ in context;
        - that the windows overlap as dictated by `stride`.
        """
        max_query_length = 10
        stride = 20

        reader = TransformerSquadReader(
            length_limit=100,
            max_query_length=max_query_length,
            stride=stride,
            skip_invalid_examples=False,
        )
        fixture_path = FIXTURES_ROOT / "rc" / "squad.json"
        instances = ensure_list(reader.read(fixture_path))

        assert len(instances) == 12
        first, second = instances[0], instances[1]

        # Layout: <pair start tokens> question <pair mid tokens> context ...
        # In the first window the question occupies exactly `max_query_length`
        # tokens, so the context starts right after start + question + mid.
        tokenizer = reader._tokenizer
        expected_context_start = (
            len(tokenizer.sequence_pair_start_tokens)
            + max_query_length
            + len(tokenizer.sequence_pair_mid_tokens)
        )
        assert first.fields["context_span"].span_start == expected_context_start

        def wordpieces(instance):
            # Surface text of every token in the combined question/context field.
            field_tokens = instance.fields["question_with_context"].tokens
            return [piece.text for piece in field_tokens]

        first_text = wordpieces(first)
        second_text = wordpieces(second)

        # The leading (question) portion is identical across the two windows,
        # while further in, the context windows diverge.
        shared_prefix = max_query_length + 2
        assert first_text[:shared_prefix] == second_text[:shared_prefix]
        probe = max_query_length + 3
        assert first_text[probe] != second_text[probe]

        # The first window ends with "##rot" followed by the final separator.
        # The same wordpiece appears at offset `stride - 1` of the second
        # window's context, consistent with consecutive windows overlapping
        # by `stride` tokens.
        assert first_text[-1] == "[SEP]"
        assert first_text[-2] == "##rot"
        overlap_index = second.fields["context_span"].span_start + stride - 1
        assert second_text[overlap_index] == "##rot"
示例#2
0
    def test_read_from_file(self):
        """A default reader turns squad.json into 5 correctly segmented instances.

        The first instance is checked in detail:
        - it starts with "[CLS]" and ends with "[SEP]";
        - question tokens carry type id 0 and context tokens carry type id 1;
        - the context and answer spans point at the expected wordpieces.

        Every instance is then checked for the question/context type-id split.
        """
        reader = TransformerSquadReader()
        fixture_path = FIXTURES_ROOT / "rc" / "squad.json"
        instances = ensure_list(reader.read(fixture_path))
        assert len(instances) == 5

        first = instances[0]
        first_tokens = first.fields["question_with_context"].tokens
        texts = [tok.text for tok in first_tokens]
        segment_ids = [tok.type_id for tok in first_tokens]

        # Question segment at the front.
        assert texts[:3] == ["[CLS]", "To", "whom"]
        assert segment_ids[:3] == [0, 0, 0]

        # Context segment at the back, closed by the separator.
        assert texts[-3:] == ["Mary", ".", "[SEP]"]
        assert segment_ids[-3:] == [1, 1, 1]

        context_span = first.fields["context_span"]
        context_start = context_span.span_start
        context_last = context_span.span_end

        assert texts[context_start] == "Architectural"
        assert segment_ids[context_start] == 1

        # span_end is inclusive: the token just past it is the trailing [SEP].
        assert texts[context_last + 1] == "[SEP]"
        assert segment_ids[context_last + 1] == 1
        assert texts[context_last] == "."
        assert segment_ids[context_last] == 1

        answer_span = first.fields["answer_span"]
        answer_pieces = texts[answer_span.span_start:answer_span.span_end + 1]
        assert answer_pieces == [
            "Saint", "Bern", "##ade", "##tte", "So", "##ubi", "##rous"
        ]

        # Everything before the context is type 0; the context itself is type 1.
        for instance in instances:
            ids = [
                tok.type_id
                for tok in instance.fields["question_with_context"].tokens
            ]
            span = instance.fields["context_span"]
            before_context = ids[:span.span_start]
            inside_context = ids[span.span_start:span.span_end + 1]
            assert all(type_id == 0 for type_id in before_context)
            assert all(type_id == 1 for type_id in inside_context)
示例#3
0
    def test_read_from_file_squad2(self, include_cls_index: bool):
        """Read the SQuAD 2.0 fixture with and without a separate CLS index.

        `include_cls_index` is presumably supplied by a pytest parametrization
        that is not visible here (TODO confirm). When it is true, each
        instance's "cls_index" field must point at the "[CLS]" token.
        """
        reader = TransformerSquadReader()

        # With BERT the "[CLS]" token already comes first, so the reader should
        # not track a CLS index unless explicitly told to.
        assert reader._include_cls_index is False
        reader._include_cls_index = include_cls_index

        fixture_path = FIXTURES_ROOT / "rc" / "squad2.json"
        instances = ensure_list(reader.read(fixture_path))
        assert len(instances) == 6

        head_tokens = instances[0].fields["question_with_context"].tokens
        head_texts = [tok.text for tok in head_tokens]
        head_segments = [tok.type_id for tok in head_tokens]

        assert head_texts[:3] == ["[CLS]", "This", "is"]
        assert head_segments[:3] == [0, 0, 0]

        assert head_texts[-3:] == ["Mary", ".", "[SEP]"]
        assert head_segments[-3:] == [1, 1, 1]

        for instance in instances:
            tokens = instance.fields["question_with_context"].tokens
            segments = [tok.type_id for tok in tokens]
            span = instance.fields["context_span"]
            before_context = segments[:span.span_start]
            inside_context = segments[span.span_start:span.span_end + 1]
            assert all(type_id == 0 for type_id in before_context)
            assert all(type_id == 1 for type_id in inside_context)

            if not include_cls_index:
                continue
            cls_position = instance.fields["cls_index"].sequence_index
            assert tokens[cls_position].text == "[CLS]"
示例#4
0
    def test_roberta_bug(self):
        """Check the reader's workaround for RoBERTa's space-sensitive tokens.

        The reader first splits on whitespace and then re-tokenizes each piece
        with the transformer's wordpiece tokenizer. RoBERTa produces different
        tokens depending on whether a word is preceded by a space, and the
        whitespace split throws those spaces away. The reader works around
        this; here we verify that the second instance's question comes out
        with the space-prefixed token for "sits".
        """
        reader = TransformerSquadReader(transformer_model_name="roberta-base")
        fixture_path = FIXTURES_ROOT / "rc" / "squad.json"
        instances = ensure_list(reader.read(fixture_path))
        assert instances
        assert len(instances) == 5

        question_tokens = instances[1].fields["question_with_context"].tokens
        texts = [tok.text for tok in question_tokens]
        vocab_ids = [tok.text_id for tok in question_tokens]

        # "Ġsits" (not "sits") shows the dropped space was restored.
        assert texts[:3] == ["<s>", "What", "Ġsits"]
        assert vocab_ids[:3] == [0, 2264, 6476]
 def setup_method(self):
     """Build the reader, vocabulary, model and predictor used by the tests.

     A fresh `Vocabulary()` backs the `TransformerQA` model. The predictor
     wraps that model together with the reader.
     """
     super().setup_method()
     # Small length limit and stride — presumably so fixture contexts are
     # split into several windows; TODO confirm against the tests using it.
     reader = TransformerSquadReader(length_limit=50, stride=10)
     vocab = Vocabulary()
     model = TransformerQA(vocab)
     self.reader = reader
     self.vocab = vocab
     self.model = model
     self.predictor = TransformerQAPredictor(model, reader)