예제 #1
0
    def test_freq_filtering(self, need_pad, use_unk, special_tokens):
        """Verify frequency-based filtering of a vocabulary built from text.

        Every whitespace-separated token found in the ``.txt`` files yielded
        by ``dataset_path_iterator(self.data_path, ".txt")`` is added to a
        base vocabulary, which is then filtered with a
        ``FrequencyVocabFilter`` bounded to counts between 2 and 4.

        Args:
            need_pad: Forwarded to ``Vocabulary`` as ``use_pad``.
            use_unk: Forwarded to ``Vocabulary`` as ``use_unk``.
            special_tokens: Forwarded to ``Vocabulary`` as ``special_tokens``.
        """
        min_freq, max_freq = 2, 4

        base_vocab = Vocabulary(
            use_pad=need_pad,
            use_unk=use_unk,
            special_tokens=special_tokens,
        )

        # Populate the base vocabulary token by token.
        for path in dataset_path_iterator(self.data_path, ".txt"):
            with open(path) as text_file:
                for line in text_file:
                    for token in line.strip().split():
                        base_vocab.add_element(token)

        freq_filter = FrequencyVocabFilter(
            base_vocab, min_frequency=min_freq, max_frequency=max_freq
        )
        filtered = base_vocab.filter(freq_filter)

        for element, element_id in base_vocab.vocab_items():
            if base_vocab.is_special_token(element_id):
                # Special elements are expected to survive the filtering.
                self.assertTrue(filtered.has_element(element))
                continue

            count = base_vocab.get_count(element)
            if not min_freq <= count <= max_freq:
                # Out-of-range counts must have been dropped.
                self.assertFalse(filtered.has_element(element))
                continue

            # In-range elements are kept together with their original count.
            self.assertTrue(filtered.has_element(element))
            self.assertEqual(count, filtered.get_count(element))

        # Both lookup tables of the base vocabulary must have the same size.
        self.assertEqual(
            len(base_vocab._element2id), len(base_vocab._id2element)
        )
예제 #2
0
    def initialize(self, config: Union[Dict, Config]):
        """Store the configuration and create the vocabulary when applicable.

        The given ``config`` is wrapped in a ``Config`` together with
        ``self.default_configs()``. A ``Vocabulary`` is created from the
        configured values unless ``vocab_method`` is ``"custom"``, in which
        case ``self._vocab`` is set to ``None``.

        Args:
            config: Configuration passed to ``Config`` as its first argument.
        """
        self.config = Config(config, self.default_configs())

        vocab_method = self.config.vocab_method
        if vocab_method == "custom":
            # No vocabulary is built here for the "custom" method.
            self._vocab = None
        else:
            self._vocab = Vocabulary(
                method=vocab_method,
                use_pad=self.config.need_pad,
                use_unk=self.config.vocab_use_unk,
                pad_value=self.config.pad_value,
                unk_value=self.config.unk_value,
            )
        self._vocab_method = vocab_method
예제 #3
0
    def __init__(self, config: Union[Dict, Config]):
        """Create the extractor from a configuration.

        The given ``config`` is wrapped in a ``Config`` together with
        ``self.default_configs()``. A ``Vocabulary`` is created unless
        ``vocab_method`` is ``"raw"``, in which case ``self.vocab`` is
        ``None``.

        Args:
            config: Configuration passed to ``Config`` as its first argument.

        Raises:
            AttributeError: If the resulting ``entry_type`` is ``None``.
        """
        self.config = Config(config, self.default_configs())

        if self.config.entry_type is None:
            raise AttributeError("entry_type needs to be specified in "
                                 "the configuration of an extractor.")

        if self.config.vocab_method == "raw":
            # The "raw" method gets no vocabulary.
            self.vocab: Optional[Vocabulary] = None
        else:
            self.vocab = Vocabulary(
                method=self.config.vocab_method,
                need_pad=self.config.need_pad,
                use_unk=self.config.vocab_use_unk,
            )
예제 #4
0
    def initialize(self, config: Union[Dict, Config]):
        """Store the configuration, resolve the entry type, build the vocab.

        The given ``config`` is wrapped in a ``Config`` together with
        ``self.default_configs()``. The configured ``entry_type`` is resolved
        through ``get_class``. A ``Vocabulary`` is created from the
        configured values unless ``vocab_method`` is ``"custom"``, in which
        case ``self._vocab`` is set to ``None``.

        Args:
            config: Configuration passed to ``Config`` as its first argument.

        Raises:
            AttributeError: If the resulting ``entry_type`` is ``None``.
        """
        # pylint: disable=attribute-defined-outside-init
        self.config = Config(config, self.default_configs())

        entry_type = self.config.entry_type
        if entry_type is None:
            raise AttributeError("`entry_type` needs to be specified in "
                                 "the configuration of an extractor.")
        self._entry_type = get_class(entry_type)

        vocab_method = self.config.vocab_method
        if vocab_method == "custom":
            # No vocabulary is built here for the "custom" method.
            self._vocab = None
        else:
            self._vocab = Vocabulary(
                method=vocab_method,
                use_pad=self.config.need_pad,
                use_unk=self.config.vocab_use_unk,
                pad_value=self.config.pad_value,
                unk_value=self.config.unk_value,
            )
        self._vocab_method = vocab_method
예제 #5
0
    def test_indexing_vocab(self):
        """Exercise ``Vocabulary`` for every method / pad / unk combination.

        For each combination of ``method`` ("indexing" or "one-hot"),
        ``need_pad`` and ``use_unk`` the test checks element insertion,
        round-tripping through ``element2repr`` / ``id2element``, the
        vocabulary length, ``has_element``, and the representations of the
        PAD and UNK elements.
        """
        # The fixture is never mutated, so one list serves all combinations.
        elements = [
            "EU", "rejects", "German", "call", "to", "boycott", "British",
            "lamb", "."
        ]

        for method, need_pad, use_unk in product(
                ["indexing", "one-hot"], [True, False], [True, False]):
            is_indexing = method == "indexing"
            vocab = Vocabulary(method=method,
                               need_pad=need_pad,
                               use_unk=use_unk)

            # Adding the same elements a second time must not grow the vocab.
            for item in elements:
                vocab.add_element(item)
            size_after_first_pass = len(vocab)
            for item in elements:
                vocab.add_element(item)
            self.assertEqual(size_after_first_pass, len(vocab))

            representation = [vocab.element2repr(item) for item in elements]
            self.assertGreater(len(representation), 0)

            # "indexing" yields ints, the other method yields lists.
            self.assertIsInstance(representation[0],
                                  int if is_indexing else list)

            # Map every representation back to its element.
            recovered_elements = [
                vocab.id2element(rep if is_indexing else self.argmax(rep))
                for rep in representation
            ]
            self.assertListEqual(elements, recovered_elements)

            # Check __len__: unique elements plus the enabled special ones.
            self.assertEqual(
                len(set(elements)) + int(use_unk) + int(need_pad), len(vocab))
            saved_len = len(vocab)

            # Check has_element for known elements and for unrelated ints.
            for item in elements:
                self.assertTrue(vocab.has_element(item))
            for number in range(10):
                self.assertFalse(vocab.has_element(number))

            # Check PAD_ELEMENT
            if need_pad:
                expected_pad_repr = (
                    0 if is_indexing else [0] * (len(vocab) - 1))
                self.assertEqual(expected_pad_repr,
                                 vocab.element2repr(Vocabulary.PAD_ELEMENT))

            # Check UNK_ELEMENT
            if not use_unk:
                continue
            if is_indexing:
                expected_unk_repr = int(need_pad)
            else:
                expected_unk_repr = [0] * (len(vocab) - int(need_pad))
                expected_unk_repr[0] = 1
            self.assertEqual(expected_unk_repr,
                             vocab.element2repr(Vocabulary.UNK_ELEMENT))
            # An unseen element shares the UNK representation and must not
            # change the vocabulary size.
            self.assertEqual(expected_unk_repr,
                             vocab.element2repr("random_element"))
            self.assertEqual(saved_len, len(vocab))
예제 #6
0
    def test_vocabulary(self):
        """Exercise ``Vocabulary`` for every method / pad / unk combination.

        For each combination of ``method`` ("indexing" or "one-hot"),
        ``need_pad`` and ``use_unk`` the test checks element insertion,
        round-tripping through ``element2repr`` / ``id2element``, the
        vocabulary length, ``has_element``, the representations of the PAD
        and UNK special tokens, and that the state survives a pickle
        round trip.
        """
        # The fixture is never mutated, so one list serves all combinations.
        elements = [
            "EU",
            "rejects",
            "German",
            "call",
            "to",
            "boycott",
            "British",
            "lamb",
            ".",
        ]
        # Attributes compared between a vocabulary and its unpickled copy.
        state_attributes = (
            "method",
            "use_pad",
            "use_unk",
            "_element2id",
            "_id2element",
            "next_id",
        )

        for method, need_pad, use_unk in product(
                ["indexing", "one-hot"], [True, False], [True, False]):
            is_indexing = method == "indexing"
            # As stated here: https://github.com/python/typing/issues/511
            # If we use the generic type here we cannot pickle the class
            # in python 3.6 or earlier (the issue is fixed in 3.7).
            # So here we do not use the type annotation for testing.
            vocab = Vocabulary(method=method,
                               use_pad=need_pad,
                               use_unk=use_unk)

            # Adding the same elements a second time must not grow the vocab.
            for item in elements:
                vocab.add_element(item)
            size_after_first_pass = len(vocab)
            for item in elements:
                vocab.add_element(item)
            self.assertEqual(size_after_first_pass, len(vocab))

            representation = [vocab.element2repr(item) for item in elements]
            self.assertGreater(len(representation), 0)

            # "indexing" yields ints, the other method yields lists.
            self.assertIsInstance(representation[0],
                                  int if is_indexing else list)

            # Map every representation back to its element.
            recovered_elements = [
                vocab.id2element(rep if is_indexing else self.argmax(rep))
                for rep in representation
            ]
            self.assertListEqual(elements, recovered_elements)

            # Check __len__: unique elements plus the enabled special ones.
            self.assertEqual(
                len(set(elements)) + int(use_unk) + int(need_pad), len(vocab))
            saved_len = len(vocab)

            # Check has_element for known elements and for unrelated ints.
            for item in elements:
                self.assertTrue(vocab.has_element(item))
            for number in range(10):
                self.assertFalse(vocab.has_element(number))

            # Check the PAD special token.
            if need_pad:
                expected_pad_repr = (
                    0 if is_indexing else [0] * (len(vocab) - 1))
                self.assertEqual(expected_pad_repr,
                                 vocab.element2repr(SpecialTokens.PAD))

            # Check the UNK special token.
            if use_unk:
                if is_indexing:
                    expected_unk_repr = int(need_pad)
                else:
                    expected_unk_repr = [0] * (len(vocab) - int(need_pad))
                    expected_unk_repr[0] = 1
                self.assertEqual(expected_unk_repr,
                                 vocab.element2repr(SpecialTokens.UNK))
                # An unseen element shares the UNK representation and must
                # not change the vocabulary size.
                self.assertEqual(expected_unk_repr,
                                 vocab.element2repr("random_element"))
                self.assertEqual(saved_len, len(vocab))

            # Check state: a pickle round trip must preserve every attribute.
            new_vocab = pkl.loads(pkl.dumps(vocab))
            for attribute in state_attributes:
                self.assertEqual(getattr(vocab, attribute),
                                 getattr(new_vocab, attribute))
예제 #7
0
    def test_custom_vocab(self, method, expected_pad_value,
                          expected_unk_value):
        """Check a vocabulary whose special elements are marked manually.

        The vocabulary is created without built-in PAD/UNK handling, filled
        with three special elements and three counted elements, and then
        elements 0 and 2 are marked as "PAD" and "UNK" respectively.

        Args:
            method: Forwarded to ``Vocabulary`` as ``method``.
            expected_pad_value: Expected result of ``get_pad_value()`` and of
                ``element2repr("[PAD]")``.
            expected_unk_value: Expected representation of an element that
                was never added.
        """
        vocab = Vocabulary(method=method, use_pad=False, use_unk=False)

        # Special elements go in first, so they are the first three added.
        special_elements = ("[PAD]", "[CLS]", "[UNK]")
        counted_elements = {"a": 2, "b": 3, "c": 4}

        for special in special_elements:
            vocab.add_special_element(special)
        for element, count in counted_elements.items():
            vocab.add_element(element, count=count)

        # Set the first element [PAD] to be the padding value.
        vocab.mark_special_element(0, "PAD")
        # Set the third element [UNK] to be the unknown value.
        vocab.mark_special_element(2, "UNK")

        # Check that padding values are the same as the expected representation.
        self.assertEqual(vocab.get_pad_value(), expected_pad_value)
        self.assertEqual(vocab.element2repr("[PAD]"), expected_pad_value)

        # Check that unknown words are mapped to expected representation.
        self.assertEqual(vocab.element2repr("something else"),
                         expected_unk_value)

        # The first three ids are special tokens; asking for their count
        # must raise.
        for special_id in range(3):
            self.assertTrue(vocab.is_special_token(special_id))
            with self.assertRaises(InvalidOperationException):
                vocab.get_count(special_id)