コード例 #1
0
    def test_custom_vocab(self, method, expected_pad_value,
                          expected_unk_value):
        """Check a vocabulary whose PAD/UNK entries are registered by hand.

        The vocabulary is created with ``use_pad=False`` and
        ``use_unk=False``; three special elements are added first, then three
        counted elements, and finally ids 0 and 2 are marked as "PAD" and
        "UNK".

        Args:
            method: Forwarded to ``Vocabulary(method=...)``.
            expected_pad_value: Value that ``get_pad_value()`` and
                ``element2repr("[PAD]")`` are expected to return.
            expected_unk_value: Value that ``element2repr`` is expected to
                return for an element that was never added.
        """
        special_elements = ("[PAD]", "[CLS]", "[UNK]")
        counted_elements = (("a", 2), ("b", 3), ("c", 4))

        vocab = Vocabulary(method=method, use_pad=False, use_unk=False)

        # Special elements go in first so they occupy the leading ids.
        for token in special_elements:
            vocab.add_special_element(token)
        for token, occurrences in counted_elements:
            vocab.add_element(token, count=occurrences)

        # The first element ([PAD]) becomes the padding value and the third
        # element ([UNK]) becomes the unknown value.
        vocab.mark_special_element(0, "PAD")
        vocab.mark_special_element(2, "UNK")

        # Padding must resolve to the expected representation, both through
        # the dedicated getter and through a regular element lookup.
        self.assertEqual(vocab.get_pad_value(), expected_pad_value)
        self.assertEqual(vocab.element2repr("[PAD]"), expected_pad_value)

        # An element that was never added must fall back to the unknown
        # representation.
        self.assertEqual(vocab.element2repr("something else"),
                         expected_unk_value)

        # Every special id is reported as special, and asking for its count
        # is expected to raise.
        for special_id in range(len(special_elements)):
            self.assertTrue(vocab.is_special_token(special_id))
            with self.assertRaises(InvalidOperationException):
                vocab.get_count(special_id)
コード例 #2
0
    def test_freq_filtering(self, need_pad, use_unk, special_tokens):
        """Check that frequency filtering keeps exactly the expected elements.

        A base vocabulary is filled with every whitespace-separated token
        found in the ``.txt`` files under ``self.data_path`` and then filtered
        with a ``FrequencyVocabFilter`` bounded to counts in ``[2, 4]``.

        Args:
            need_pad: Forwarded to ``Vocabulary(use_pad=...)``.
            use_unk: Forwarded to ``Vocabulary(use_unk=...)``.
            special_tokens: Forwarded to ``Vocabulary(special_tokens=...)``.
        """
        lowest_kept, highest_kept = 2, 4

        base_vocab = Vocabulary(use_pad=need_pad,
                                use_unk=use_unk,
                                special_tokens=special_tokens)

        # Feed every token of every text file into the base vocabulary.
        for text_path in dataset_path_iterator(self.data_path, ".txt"):
            with open(text_path) as text_file:
                for raw_line in text_file:
                    # split() without arguments already ignores leading and
                    # trailing whitespace.
                    for word in raw_line.split():
                        base_vocab.add_element(word)

        frequency_filter = FrequencyVocabFilter(base_vocab,
                                                min_frequency=lowest_kept,
                                                max_frequency=highest_kept)
        filtered = base_vocab.filter(frequency_filter)

        for element, element_id in base_vocab.vocab_items():
            # Special elements must be present in the filtered vocabulary.
            if base_vocab.is_special_token(element_id):
                self.assertTrue(filtered.has_element(element))
                continue

            count_in_base = base_vocab.get_count(element)
            within_bounds = lowest_kept <= count_in_base <= highest_kept

            # Elements outside the frequency bounds must have been dropped.
            if not within_bounds:
                self.assertFalse(filtered.has_element(element))
                continue

            # Elements inside the bounds are kept with an unchanged count.
            self.assertTrue(filtered.has_element(element))
            self.assertEqual(count_in_base, filtered.get_count(element))

        # The two internal lookup tables must have the same size.
        # NOTE(review): this inspects base_vocab, not filtered -- confirm
        # that is the intended target.
        self.assertEqual(len(base_vocab._element2id),
                         len(base_vocab._id2element))