Ejemplo n.º 1
0
    def test_pad_when_fix_length_is_not_none(self):
        """NestedField(fix_length=3) keeps <s>, one word and </s> per sentence."""

        def make_chars(**extra):
            # Build a fresh character-level field for each NestedField under test.
            char_field = data.Field(tokenize=list, unk_token="<cunk>", pad_token="<cpad>",
                                    init_token="<w>", eos_token="</w>")
            return data.NestedField(char_field, init_token="<s>", eos_token="</s>",
                                    fix_length=3, **extra)

        def word(letters, width=7):
            # Wrap `letters` in <w> ... </w>, then right-pad with <cpad> to `width`.
            row = ["<w>", *letters, "</w>"]
            return row + ["<cpad>"] * (width - len(row))

        sentences = [
            ["john", "loves", "mary"],
            ["mary", "cries"],
        ]
        # Only the first word of each sentence survives. Every character row
        # is 7 wide.
        expected = [
            [word(["<s>"]), word(first), word(["</s>"])]
            for first in ("john", "mary")
        ]

        assert make_chars().pad(sentences) == expected

        # With include_lengths=True, pad also returns sentence and word lengths.
        chars_with_lengths = make_chars(include_lengths=True)
        padded_batch, sentence_lengths, word_lengths = chars_with_lengths.pad(sentences)
        assert padded_batch == expected
        assert sentence_lengths == [3, 3]
        assert word_lengths == [[3, 6, 3], [3, 6, 3]]
Ejemplo n.º 2
0
    def test_serialization(self):
        """A NestedField survives a torch.save / torch.load round trip."""

        def word(letters, width=7):
            # A string argument is split into characters. A list argument
            # (e.g. ["<s>"]) is inserted as a single token.
            row = ["<w>", *letters, "</w>"]
            return row + ["<cpad>"] * (width - len(row))

        field = data.NestedField(data.Field(batch_first=True))
        fields = [("words", field)]
        examples = [data.Example.fromlist([text], fields)
                    for text in ("john loves mary", "mary cries")]
        field.build_vocab(data.Dataset(examples, fields))

        batch = [
            [word(["<s>"]), word("john"), word("loves"), word("mary"), word(["</s>"])],
            [word(["<s>"]), word("mary"), word("cries"), word(["</s>"]), ["<cpad>"] * 7],
        ]

        pickle_path = os.path.join(self.test_dir, "char_field.pl")
        torch.save(field, pickle_path)
        restored = torch.load(pickle_path)
        assert restored == field

        # The restored field must map the same batch to the same ids.
        original_ids = field.numericalize(batch)
        restored_ids = restored.numericalize(batch)
        assert (original_ids == restored_ids).all()
Ejemplo n.º 3
0
    def test_preprocess(self):
        """Word-level and character-level preprocessing are both applied."""

        def upper_chars(chars):
            return [c.upper() for c in chars]

        def reverse_words(words):
            return reversed(words)

        field = data.NestedField(
            data.Field(tokenize=list, preprocessing=upper_chars),
            preprocessing=reverse_words,
        )

        # Words come back in reverse order, each split into upper-cased characters.
        expected = [list(w) for w in ("MARY", "LOVES", "JOHN")]
        assert field.preprocess("john loves mary") == expected
Ejemplo n.º 4
0
    def test_pad_when_pad_first_is_true(self):
        """With pad_first=True, the shorter sentence gets its blank row at the front."""

        def make_chars(**extra):
            # Build a fresh character-level field for each NestedField under test.
            char_field = data.Field(tokenize=list, unk_token="<cunk>", pad_token="<cpad>",
                                    init_token="<w>", eos_token="</w>")
            return data.NestedField(char_field, init_token="<s>", eos_token="</s>",
                                    pad_first=True, **extra)

        def word(letters, width=7):
            row = ["<w>", *letters, "</w>"]
            return row + ["<cpad>"] * (width - len(row))

        sentences = [
            [list("john"), list("loves"), list("mary")],
            [list("mary"), list("cries")],
        ]
        expected = [
            [word(["<s>"]), word("john"), word("loves"), word("mary"), word(["</s>"])],
            [["<cpad>"] * 7, word(["<s>"]), word("mary"), word("cries"), word(["</s>"])],
        ]

        assert make_chars().pad(sentences) == expected

        # With include_lengths=True, the padded row reports a word length of 0.
        chars_with_lengths = make_chars(include_lengths=True)
        padded_batch, sentence_lengths, word_lengths = chars_with_lengths.pad(sentences)
        assert padded_batch == expected
        assert sentence_lengths == [5, 4]
        assert word_lengths == [[3, 6, 7, 6, 3], [0, 3, 6, 7, 3]]
Ejemplo n.º 5
0
    def test_pad_when_nesting_field_is_not_sequential(self):
        """A non-sequential nesting field makes NestedField pad whole words."""
        word_field = data.Field(sequential=False, unk_token="<cunk>",
                                pad_token="<cpad>", init_token="<w>", eos_token="</w>")
        chars = data.NestedField(word_field, init_token="<s>", eos_token="</s>")

        sentences = ["john loves mary".split(), "mary cries".split()]
        # The short sentence is padded with "<pad>", not the nesting field's "<cpad>".
        expected = [
            ["<s>", "john", "loves", "mary", "</s>"],
            ["<s>", "mary", "cries", "</s>", "<pad>"],
        ]
        assert chars.pad(sentences) == expected
Ejemplo n.º 6
0
    def test_pad_when_nesting_field_has_fix_length(self):
        """A fix_length on the nesting field truncates and pads every word to 5."""

        def make_chars(**extra):
            # Build a fresh character-level field for each NestedField under test.
            char_field = data.Field(tokenize=list, unk_token="<cunk>", pad_token="<cpad>",
                                    init_token="<w>", eos_token="</w>", fix_length=5)
            return data.NestedField(char_field, init_token="<s>", eos_token="</s>", **extra)

        def word(letters, width=5):
            # fix_length=5 leaves room for three characters between <w> and </w>.
            row = ["<w>", *letters[:3], "</w>"]
            return row + ["<cpad>"] * (width - len(row))

        sentences = [
            ["john", "loves", "mary"],
            ["mary", "cries"],
        ]
        expected = [
            [word(["<s>"]), word("john"), word("loves"), word("mary"), word(["</s>"])],
            [word(["<s>"]), word("mary"), word("cries"), word(["</s>"]), ["<cpad>"] * 5],
        ]

        assert make_chars().pad(sentences) == expected

        # With include_lengths=True, word lengths are capped at the fix_length.
        chars_with_lengths = make_chars(include_lengths=True)
        padded_batch, sentence_lengths, word_lengths = chars_with_lengths.pad(sentences)
        assert padded_batch == expected
        assert sentence_lengths == [5, 4]
        assert word_lengths == [[3, 5, 5, 5, 3], [3, 5, 5, 3, 0]]
Ejemplo n.º 7
0
    def test_build_vocab_from_iterable(self):
        """build_vocab accepts several raw iterables and counts every character."""
        chars = data.NestedField(data.Field(unk_token="<cunk>", pad_token="<cpad>"))
        first_source = [[list("aaa"), list("bbb"), ["c"]], [list("bbb"), list("aaa")]]
        second_source = [[list("ccc"), list("bbb")], [list("bbb")]]
        chars.build_vocab(first_source, second_source)

        vocab_tokens = ["a", "b", "c", "<cunk>", "<cpad>"]
        assert len(chars.vocab) == len(vocab_tokens)
        assert all(token in chars.vocab.stoi for token in vocab_tokens)

        # The NestedField and its nesting field share the same frequency counts.
        expected_freqs = Counter(a=6, b=12, c=4)
        assert chars.vocab.freqs == chars.nesting_field.vocab.freqs == expected_freqs
Ejemplo n.º 8
0
    def test_build_vocab_from_dataset(self):
        """build_vocab over a Dataset honours min_freq and keeps special tokens."""
        char_field = data.Field(tokenize=list, unk_token="<cunk>", pad_token="<cpad>",
                                init_token="<w>", eos_token="</w>")
        chars = data.NestedField(char_field, init_token="<s>", eos_token="</s>")
        fields = [("chars", chars)]
        dataset = data.Dataset(
            [data.Example.fromlist([text], fields) for text in ("aaa bbb c", "bbb aaa")],
            fields,
        )

        # "c" is counted once, so min_freq=2 keeps it out of the vocabulary.
        chars.build_vocab(dataset, min_freq=2)

        vocab_tokens = "a b <w> </w> <s> </s> <cunk> <cpad>".split()
        assert len(chars.vocab) == len(vocab_tokens)
        for token in vocab_tokens:
            assert token in chars.vocab.stoi

        expected_freqs = Counter(a=6, b=6, c=1)
        assert chars.vocab.freqs == chars.nesting_field.vocab.freqs == expected_freqs
Ejemplo n.º 9
0
    def test_init_minimal(self):
        """A NestedField built from a default Field exposes the expected defaults."""
        nesting_field = data.Field()
        field = data.NestedField(nesting_field)

        assert isinstance(field, data.Field)
        assert field.nesting_field is nesting_field

        # Boolean flags.
        assert field.sequential and field.use_vocab and field.batch_first
        assert not (field.include_lengths or field.pad_first)

        # Options left unset on the NestedField itself.
        for name in ("init_token", "eos_token", "fix_length", "preprocessing", "postprocessing"):
            assert getattr(field, name) is None

        # Options mirrored from the nesting field.
        for name in ("unk_token", "lower", "pad_token"):
            assert getattr(field, name) == getattr(nesting_field, name)

        assert field.dtype is torch.long
        assert field.tokenize("a b c") == "a b c".split()
Ejemplo n.º 10
0
    def test_pad_when_no_init_and_eos_tokens(self):
        """Without sentence-level <s>/</s> tokens, only the words themselves are padded."""
        char_field = data.Field(tokenize=list, unk_token="<cunk>", pad_token="<cpad>",
                                init_token="<w>", eos_token="</w>")
        chars = data.NestedField(char_field)

        def word(letters, width=7):
            row = ["<w>", *letters, "</w>"]
            return row + ["<cpad>"] * (width - len(row))

        sentences = [["john", "loves", "mary"], ["mary", "cries"]]
        expected = [
            [word("john"), word("loves"), word("mary")],
            [word("mary"), word("cries"), ["<cpad>"] * 7],
        ]
        assert chars.pad(sentences) == expected
Ejemplo n.º 11
0
    def test_init_full(self):
        """Every keyword passed to NestedField is reflected on the instance."""

        def reverse_all(xs):
            return list(reversed(xs))

        def upper_all(xs):
            return [x.upper() for x in xs]

        field = data.NestedField(
            data.Field(),
            use_vocab=False,
            init_token="<s>",
            eos_token="</s>",
            fix_length=10,
            dtype=torch.float,
            preprocessing=reverse_all,
            postprocessing=upper_all,
            tokenize=list,
            pad_first=True,
        )

        assert not field.use_vocab
        assert (field.init_token, field.eos_token) == ("<s>", "</s>")
        assert field.fix_length == 10
        assert field.dtype is torch.float
        # The stored callables behave like the functions passed in.
        assert field.preprocessing(["a", "b", "c"]) == ["c", "b", "a"]
        assert field.postprocessing(["a", "b", "c"]) == ["A", "B", "C"]
        assert field.tokenize("abc") == list("abc")
        assert field.pad_first
Ejemplo n.º 12
0
    def __init__(self, cfg):
        """Read settings from ``cfg``, build the torchtext fields, then set up data and vocabularies.

        Args:
            cfg: configuration object exposing ``cfg.dataset`` and ``cfg.model``
                namespaces; the attributes read from each are listed below.
                ``cfg.dataset.lazy_reader_type`` is optional and defaults to None.

        Raises:
            ValueError: if ``cfg.model.usage_tokenizer_type`` is not "sub_token"
                or "split", or if ``cfg.model.target_tokenizer_type`` is neither
                None nor "camel_case".
        """
        # Dataset splits; presumably filled in by self.setup() below — TODO confirm.
        self.train = None
        self.val = None
        self.test = None

        # Dataset settings.
        self.dataset_path = cfg.dataset.path
        self.dataset_type = cfg.dataset.type
        self.lazy_reader_type = getattr(cfg.dataset, "lazy_reader_type", None)
        self.dataset_split_strategy = cfg.dataset.split_strategy
        self.usage_vocab_max_size = cfg.dataset.usage_vocab_max_size
        self.usage_vocab_min_freq = cfg.dataset.usage_vocab_min_freq
        self.target_vocab_max_size = cfg.dataset.target_vocab_max_size
        self.target_vocab_min_freq = cfg.dataset.target_vocab_min_freq

        # Model settings that control sequence lengths and tokenization.
        self.max_sequence_length = cfg.model.max_sequence_length
        self.max_num_usages = cfg.model.max_num_usages
        self.max_target_length = cfg.model.max_target_length
        self.usage_tokenizer_type = cfg.model.usage_tokenizer_type
        self.target_tokenizer_type = cfg.model.target_tokenizer_type

        if self.usage_tokenizer_type == "sub_token":
            usage_tokenizer = sub_tokenizer(length=self.max_sequence_length)
        elif self.usage_tokenizer_type == "split":
            usage_tokenizer = split
        else:
            raise ValueError(f"Usage tokenizer of type {self.usage_tokenizer_type} is not supported")

        if self.target_tokenizer_type is None:
            target_tokenizer = to_list
        elif self.target_tokenizer_type == "camel_case":
            target_tokenizer = camel_case_split
        else:
            # Report the offending *target* tokenizer setting (previously this
            # interpolated usage_tokenizer_type, which is a valid value here).
            raise ValueError(f"Target tokenizer of type {self.target_tokenizer_type} is not supported")

        # Configuring torchtext fields for automatic padding and numericalization of batches.
        # Usages: each example holds up to max_num_usages usages, and each usage
        # is tokenized and fixed to max_sequence_length tokens. include_lengths=True
        # makes the field also report lengths.
        self.usage_field = data.NestedField(data.Field(sequential=True,
                                                       use_vocab=True,
                                                       init_token=None,
                                                       eos_token=None,
                                                       fix_length=self.max_sequence_length,
                                                       dtype=torch.long,
                                                       tokenize=usage_tokenizer,
                                                       batch_first=True,
                                                       is_target=False),
                                            fix_length=self.max_num_usages,
                                            include_lengths=True)
        # Target: a single token sequence wrapped in INIT_TOKEN / EOS_TOKEN and
        # fixed to max_target_length; not batch-first.
        self.target_field = data.Field(sequential=True,
                                       use_vocab=True,
                                       init_token=INIT_TOKEN,
                                       eos_token=EOS_TOKEN,
                                       fix_length=self.max_target_length,
                                       dtype=torch.long,
                                       tokenize=target_tokenizer,
                                       batch_first=False,
                                       is_target=True,
                                       include_lengths=True)
        # Presumably maps raw record keys ('variable', 'ngrams') to
        # (attribute name, field) pairs for example construction — verify
        # against self.setup().
        self.example_fields = {'variable': ('target', self.target_field),
                               'ngrams': ('usages', self.usage_field)}
        self.dataset_fields = {'target': self.target_field,
                               'usages': self.usage_field}
        self.setup()
        self.build_vocabs()
Ejemplo n.º 13
0
    def test_numericalize(self):
        """numericalize yields a 3-D batch-first tensor, with and without include_lengths."""

        def word(letters, width=7):
            row = ["<w>", *letters, "</w>"]
            return row + ["<cpad>"] * (width - len(row))

        def build_field(include_lengths):
            # Fresh field whose vocabulary is built from two short sentences.
            field = data.NestedField(data.Field(batch_first=True),
                                     include_lengths=include_lengths)
            fields = [("words", field)]
            examples = [data.Example.fromlist([text], fields)
                        for text in ("john loves mary", "mary cries")]
            field.build_vocab(data.Dataset(examples, fields))
            return field

        def make_batch():
            # Return a new list on each call so the two checks never share data.
            return [
                [word(["<s>"]), word("john"), word("loves"), word("mary"), word(["</s>"])],
                [word(["<s>"]), word("mary"), word("cries"), word(["</s>"]), ["<cpad>"] * 7],
            ]

        def check(field, batch, tensor):
            assert tensor.dim() == 3
            assert tensor.size(0) == len(batch)
            for example, numericalized_example in zip(batch, tensor):
                verify_numericalized_example(
                    field, example, numericalized_example, batch_first=True)

        plain_field = build_field(include_lengths=False)
        plain_batch = make_batch()
        check(plain_field, plain_batch, plain_field.numericalize(plain_batch))

        # With include_lengths, numericalize takes and returns (batch, seq_len, word_len).
        length_field = build_field(include_lengths=True)
        length_batch = make_batch()
        numericalized, seq_len, word_len = length_field.numericalize(
            (length_batch, [5, 4], [[3, 6, 7, 6, 3], [3, 6, 7, 3, 0]]))
        assert len(seq_len) == 2
        assert len(word_len) == 2
        check(length_field, length_batch, numericalized)
Ejemplo n.º 14
0
    def test_init_with_nested_field_as_nesting_field(self):
        """Using a NestedField as the nesting field raises ValueError."""
        inner = data.NestedField(data.Field())
        with pytest.raises(ValueError, match="nesting field must not be another NestedField"):
            data.NestedField(inner)
Ejemplo n.º 15
0
    def test_init_when_nesting_field_has_include_lengths_equal_true(self):
        """A nesting field with include_lengths=True is rejected with ValueError."""
        with pytest.raises(ValueError, match="nesting field cannot have include_lengths=True"):
            data.NestedField(data.Field(include_lengths=True))
Ejemplo n.º 16
0
    def test_init_when_nesting_field_is_not_sequential(self):
        nesting_field = data.Field(sequential=False)
        field = data.NestedField(nesting_field)

        assert field.pad_token == "<pad>"