def test_custom_vocab(self, method, expected_pad_value, expected_unk_value):
    """Verify PAD/UNK handling when the special elements are added by hand.

    The vocabulary is created with ``use_pad=False`` and ``use_unk=False``.
    The special elements "[PAD]", "[CLS]" and "[UNK]" are then added
    manually, followed by three counted elements. Ids 0 and 2 are then marked
    as PAD and UNK.

    The test checks the following:
      * the pad value and the representation of "[PAD]" both equal
        ``expected_pad_value``;
      * an element never added to the vocabulary maps to ``expected_unk_value``;
      * ids 0..2 are special tokens, and asking for their count raises
        ``InvalidOperationException``.
    """
    vocabulary = Vocabulary(method=method, use_pad=False, use_unk=False)

    # Special elements are inserted first, so they take ids 0, 1, 2 in this order.
    special_elements = ["[PAD]", "[CLS]", "[UNK]"]
    counted_elements = [("a", 2), ("b", 3), ("c", 4)]
    for element in special_elements:
        vocabulary.add_special_element(element)
    for element, frequency in counted_elements:
        vocabulary.add_element(element, count=frequency)

    # "[PAD]" sits at id 0 and "[UNK]" at id 2 (the third element).
    pad_index, unk_index = 0, 2
    vocabulary.mark_special_element(pad_index, "PAD")
    vocabulary.mark_special_element(unk_index, "UNK")

    # The padding value must match the expected representation, both via the
    # dedicated getter and via a lookup of the "[PAD]" element itself.
    self.assertEqual(vocabulary.get_pad_value(), expected_pad_value)
    self.assertEqual(vocabulary.element2repr("[PAD]"), expected_pad_value)

    # Any unseen element falls back to the unknown representation.
    self.assertEqual(vocabulary.element2repr("something else"), expected_unk_value)

    # Special tokens carry no frequency, so get_count must refuse them.
    for token_id in range(len(special_elements)):
        self.assertTrue(vocabulary.is_special_token(token_id))
        with self.assertRaises(InvalidOperationException):
            vocabulary.get_count(token_id)
def test_freq_filtering(self, need_pad, use_unk, special_tokens):
    """Check that FrequencyVocabFilter keeps exactly the in-range elements.

    A vocabulary is built from every whitespace-separated token in the
    ``.txt`` files under ``self.data_path``. It is then filtered with
    ``FrequencyVocabFilter`` using inclusive bounds
    [``min_frequency``, ``max_frequency``].

    For every element of the base vocabulary, the test checks the following:
      * special tokens always survive filtering;
      * regular elements survive if and only if their count lies within the
        bounds, and surviving elements keep their original count.

    Finally, both the base and the filtered vocabularies must have internal
    element<->id maps of equal size.
    """
    # Single source of truth for the bounds, used both to build the filter
    # and to compute the expected outcome below.
    min_frequency, max_frequency = 2, 4

    base_vocab = Vocabulary(use_pad=need_pad, use_unk=use_unk, special_tokens=special_tokens)
    for p in dataset_path_iterator(self.data_path, ".txt"):
        # Use an explicit encoding so the result does not depend on the
        # platform's default locale encoding.
        with open(p, encoding="utf-8") as f:
            for line in f:
                # str.split() with no arguments already ignores leading and
                # trailing whitespace, including the newline.
                for w in line.split():
                    base_vocab.add_element(w)

    vocab_filter = FrequencyVocabFilter(
        base_vocab, min_frequency=min_frequency, max_frequency=max_frequency
    )
    filtered = base_vocab.filter(vocab_filter)

    for e, eid in base_vocab.vocab_items():
        if base_vocab.is_special_token(eid):
            # The filtered vocab must keep every special element.
            self.assertTrue(filtered.has_element(e))
            continue
        base_count = base_vocab.get_count(e)
        if min_frequency <= base_count <= max_frequency:
            self.assertTrue(filtered.has_element(e))
            self.assertEqual(base_count, filtered.get_count(e))
        else:
            self.assertFalse(filtered.has_element(e))

    # Filtering must not corrupt the source vocabulary, and the produced
    # vocabulary must itself be internally consistent.
    self.assertEqual(len(base_vocab._element2id), len(base_vocab._id2element))
    self.assertEqual(len(filtered._element2id), len(filtered._id2element))