def test_vocab(self):
    """Check size and encode/decode round-trips of a default ByteVocabulary.

    The fixture string and ids come from the class attributes
    ``TEST_STRING`` and ``TEST_BYTE_IDS``.
    """
    byte_vocab = vocabularies.ByteVocabulary()
    text = self.TEST_STRING
    expected_ids = self.TEST_BYTE_IDS

    # NOTE(review): 259 is presumably 256 byte values plus 3 reserved ids --
    # confirm against the ByteVocabulary implementation.
    self.assertEqual(259, byte_vocab.vocab_size)

    # Python-side encode / decode.
    encoded = byte_vocab.encode(text)
    self.assertSequenceEqual(expected_ids, encoded)
    decoded = byte_vocab.decode(expected_ids)
    self.assertEqual(text, decoded)

    # `encode_tf` output is materialised with `.numpy()` and turned into a
    # tuple so it can be compared with the expected ids directly.
    encoded_tf = tuple(byte_vocab.encode_tf(text).numpy())
    self.assertEqual(expected_ids, encoded_tf)

    # Decoding through the module-level `_decode_tf` helper must give back
    # the original string as well.
    decoded_tf = _decode_tf(byte_vocab, expected_ids)
    self.assertEqual(text, decoded_tf)
def test_not_equal(self):
    """Vocabularies built with different constructor arguments differ."""
    default_vocab = vocabularies.ByteVocabulary()
    # NOTE(review): the positional 10 is presumably `extra_ids` (the
    # `test_extra_ids` case passes it by keyword) -- confirm.
    configured_vocab = vocabularies.ByteVocabulary(10)

    self.assertNotEqual(default_vocab, configured_vocab)
def test_out_of_vocab(self):
    """Decoding an id beyond `vocab_size` yields an empty string."""
    byte_vocab = vocabularies.ByteVocabulary()
    self.assertEqual(259, byte_vocab.vocab_size)

    # 260 is larger than the reported vocab_size of 259.
    decoded = byte_vocab.decode([260])
    self.assertEqual("", decoded)
def test_extra_ids(self):
    """Check size and decoding of a ByteVocabulary built with extra_ids=10."""
    byte_vocab = vocabularies.ByteVocabulary(extra_ids=10)

    # Ten more than the 259 asserted for the default vocabulary elsewhere
    # in these tests.
    self.assertEqual(269, byte_vocab.vocab_size)

    # An ordinary id still decodes to its character.
    # NOTE(review): 100 -> "a" (byte 97) suggests an id offset of 3 --
    # confirm against the ByteVocabulary implementation.
    regular_text = byte_vocab.decode([100])
    self.assertEqual("a", regular_text)

    # 268 is the highest id below vocab_size (presumably one of the extra
    # ids); it is expected to decode to the empty string.
    extra_text = byte_vocab.decode([268])
    self.assertEqual("", extra_text)