def test_vocab(self):
     # A default ByteVocabulary should report 259 ids and round-trip the
     # module-level fixtures through both the Python and the TF code paths.
     byte_vocab = vocabularies.ByteVocabulary()
     self.assertEqual(259, byte_vocab.vocab_size)

     # The fixture is `.decode()`d before being handed to the Python encoder.
     python_ids = byte_vocab.encode(_TEST_STRING.decode())
     self.assertSequenceEqual(_TEST_BYTE_IDS, python_ids)

     # The Python decoder's output is normalised to bytes before comparing
     # against the fixture.
     python_bytes = tf.compat.as_bytes(byte_vocab.decode(_TEST_BYTE_IDS))
     self.assertEqual(_TEST_STRING, python_bytes)

     # TF encoder: the resulting tensor is materialised and compared as a tuple.
     tf_ids = tuple(byte_vocab.encode_tf(_TEST_STRING).numpy())
     self.assertEqual(_TEST_BYTE_IDS, tf_ids)

     # TF decoder: the `.numpy()` value is compared to the fixture directly.
     tf_bytes = byte_vocab.decode_tf(_TEST_BYTE_IDS).numpy()
     self.assertEqual(_TEST_STRING, tf_bytes)
示例#2
0
 def test_vocab(self):
     # A default ByteVocabulary should report 259 ids and round-trip the
     # class-level fixtures through encode/decode and their TF counterparts.
     byte_vocab = vocabularies.ByteVocabulary()
     self.assertEqual(259, byte_vocab.vocab_size)

     # Python encoder.
     python_ids = byte_vocab.encode(self.TEST_STRING)
     self.assertSequenceEqual(self.TEST_BYTE_IDS, python_ids)

     # Python decoder.
     python_text = byte_vocab.decode(self.TEST_BYTE_IDS)
     self.assertEqual(self.TEST_STRING, python_text)

     # TF encoder: the resulting tensor is materialised and compared as a tuple.
     tf_ids = tuple(byte_vocab.encode_tf(self.TEST_STRING).numpy())
     self.assertEqual(self.TEST_BYTE_IDS, tf_ids)

     # TF decoding goes through the file-level `_decode_tf` helper.
     tf_text = _decode_tf(byte_vocab, self.TEST_BYTE_IDS)
     self.assertEqual(self.TEST_STRING, tf_text)
 def test_not_equal(self):
     # Two vocabularies built with different constructor arguments must not
     # compare equal.
     default_vocab = vocabularies.ByteVocabulary()
     # NOTE(review): the positional 10 is presumably `extra_ids` -- confirm
     # against the ByteVocabulary constructor signature.
     customised_vocab = vocabularies.ByteVocabulary(10)
     self.assertNotEqual(default_vocab, customised_vocab)
 def test_out_of_vocab(self):
     # Decoding an id above the reported vocabulary size is expected to give
     # an empty string.
     byte_vocab = vocabularies.ByteVocabulary()
     self.assertEqual(259, byte_vocab.vocab_size)

     ids_beyond_vocab = [260]
     decoded = byte_vocab.decode(ids_beyond_vocab)
     self.assertEqual("", decoded)
 def test_extra_ids(self):
     # With extra_ids=10 the reported size is expected to be 269; id 100
     # should still decode to "a", and id 268 to an empty string.
     byte_vocab = vocabularies.ByteVocabulary(extra_ids=10)
     self.assertEqual(269, byte_vocab.vocab_size)

     decoded_regular = byte_vocab.decode([100])
     self.assertEqual("a", decoded_regular)

     decoded_last_id = byte_vocab.decode([268])
     self.assertEqual("", decoded_last_id)