def _saveVocab(self, name, words):
    """Serialize a vocabulary of *words* into the test temp dir.

    Args:
      name: File name to use inside ``self.get_temp_dir()``.
      words: Iterable of items; each one is passed through ``str()`` before
        being added to the vocabulary.

    Returns:
      The full path of the serialized vocabulary file.
    """
    built = vocab_lib.Vocab()
    # map() is lazy, so each str() conversion is immediately followed by its add().
    for token in map(str, words):
        built.add(token)
    target_path = os.path.join(self.get_temp_dir(), name)
    built.serialize(target_path)
    return target_path
def testSimpleVocab(self):
    """Check counting, two-way lookup and both pruning modes of a plain Vocab."""
    vocab = vocab_lib.Vocab()
    self.assertEqual(0, vocab.size)
    # Frequencies: toto x3, titi x2, tata x1. Repeated adds must not grow the size.
    for token in ("toto", "toto", "toto", "titi", "titi", "tata"):
        vocab.add(token)
    self.assertEqual(3, vocab.size)
    # lookup() maps token -> id and id -> token.
    self.assertEqual(1, vocab.lookup("titi"))
    self.assertEqual("titi", vocab.lookup(1))
    # Size-based pruning keeps two entries; the least frequent "tata" is dropped.
    pruned_size = vocab.prune(max_size=2)
    self.assertEqual(2, pruned_size.size)
    self.assertIsNone(pruned_size.lookup("tata"))
    # Frequency-based pruning keeps only tokens seen at least 3 times.
    pruned_frequency = vocab.prune(min_frequency=3)
    self.assertEqual(1, pruned_frequency.size)
    self.assertEqual(0, pruned_frequency.lookup("toto"))
def make_vocab_from_file(path, data_file):
    """Build a vocabulary from *data_file* and serialize it to *path*.

    The vocabulary is created with the padding, start-of-sentence and
    end-of-sentence special tokens (in that order), then filled from
    *data_file* via ``add_from_text``.

    Returns:
      *path*, unchanged, so callers can use the result directly.
    """
    result = vocab.Vocab(
        special_tokens=[
            constants.PADDING_TOKEN,
            constants.START_OF_SENTENCE_TOKEN,
            constants.END_OF_SENTENCE_TOKEN,
        ]
    )
    result.add_from_text(data_file)
    result.serialize(path)
    return path
def testVocabPadding(self):
    """Padding a 3-entry vocab to a multiple of 6 with one OOV bucket gives size 5."""
    vocab = vocab_lib.Vocab()
    for word in ("toto", "titi", "tata"):
        vocab.add(word)
    self.assertEqual(vocab.size, 3)
    vocab.pad_to_multiple(6, num_oov_buckets=1)
    # Expected 6 - 1: presumably the OOV bucket counts toward the multiple
    # without being stored in the vocabulary itself.
    self.assertEqual(vocab.size, 6 - 1)
def make_vocab(path, tokens):
    """Serialize a vocabulary containing *tokens* to *path*.

    The vocabulary starts with the padding, start-of-sentence and
    end-of-sentence special tokens, and each entry of *tokens* is then
    added in iteration order.

    Returns:
      *path*, unchanged.
    """
    specials = [
        constants.PADDING_TOKEN,
        constants.START_OF_SENTENCE_TOKEN,
        constants.END_OF_SENTENCE_TOKEN,
    ]
    result = vocab.Vocab(special_tokens=specials)
    for tok in tokens:
        result.add(tok)
    result.serialize(path)
    return path
def testVocabSaveAndLoad(self):
    """A vocab with special tokens keeps its size and ids across serialize/from_file."""
    original = vocab_lib.Vocab(special_tokens=["foo", "bar"])
    # Same add sequence as before: toto x3, then titi x2, then tata x1.
    for token, repeat in (("toto", 3), ("titi", 2), ("tata", 1)):
        for _ in range(repeat):
            original.add(token)
    vocab_file = os.path.join(self.get_temp_dir(), "vocab.txt")
    original.serialize(vocab_file)
    reloaded = vocab_lib.Vocab.from_file(vocab_file)
    self.assertEqual(original.size, reloaded.size)
    self.assertEqual(original.lookup("titi"), reloaded.lookup("titi"))
def testVocabWithSpecialTokens(self):
    """Special tokens take the first ids and survive size-based pruning."""
    vocab = vocab_lib.Vocab(special_tokens=["foo", "bar"])
    # The special tokens alone already count toward the size.
    self.assertEqual(2, vocab.size)
    for token, repeat in (("toto", 3), ("titi", 2), ("tata", 1)):
        for _ in range(repeat):
            vocab.add(token)
    self.assertEqual(5, vocab.size)
    # Regular tokens are numbered after the two special tokens.
    self.assertEqual(3, vocab.lookup("titi"))
    pruned_size = vocab.prune(max_size=3)
    self.assertEqual(3, pruned_size.size)
    self.assertEqual(0, pruned_size.lookup("foo"))
    self.assertEqual(1, pruned_size.lookup("bar"))