def test_translation_multiple_encoders(self):
  """Checks a Translation feature that is given a list of two encoders."""
  # One token-based encoder and one byte-based encoder
  # (Unicode integer-encoded by byte).
  encoders = [
      text_encoder.TokenTextEncoder(["hello", " "]),
      text_encoder.ByteTextEncoder(),
  ]
  translation = features.Translation(languages=["en", "zh"], encoder=encoders)

  # Expected "zh" ids are these values, each incremented by one.
  zh_values = [228, 189, 160, 229, 165, 189, 32]
  expectation = testing.FeatureExpectationItem(
      value={"en": EN_HELLO, "zh": ZH_HELLO},
      expected={
          "en": [1],
          "zh": [value + 1 for value in zh_values],
      },
  )

  self.assertFeature(
      feature=translation,
      shape={"en": (None,), "zh": (None,)},
      dtype={"en": tf.int64, "zh": tf.int64},
      tests=[expectation],
      skip_feature_tests=True)
def test_tokenization(self):
  """Checks the default Tokenizer output and the ids encoded from it."""
  encoder = text_encoder.TokenTextEncoder(vocab_list=['hi', 'bye', ZH_HELLO])
  text = 'hi<<>><<>foo!^* bar && bye (%s hi)' % ZH_HELLO

  # The tokenizer is expected to yield only these tokens from the text.
  expected_tokens = ['hi', 'foo', 'bar', 'bye', ZH_HELLO.strip(), 'hi']
  actual_tokens = text_encoder.Tokenizer().tokenize(text)
  self.assertEqual(expected_tokens, actual_tokens)

  # Expected ids are the listed indices, each incremented by one.
  expected_ids = [index + 1 for index in (0, 3, 3, 1, 2, 0)]
  self.assertEqual(expected_ids, encoder.encode(text))
def test_file_backed_with_args(self):
  """Checks that non-default encoder settings survive a save/load cycle."""
  with testing.tmp_dir(self.get_temp_dir()) as tmp_dir:
    # Set all the args to non-default values, including Tokenizer
    custom_tokenizer = text_encoder.Tokenizer(
        reserved_tokens=['<FOOBAR>'], alphanum_only=False)
    encoder = text_encoder.TokenTextEncoder(
        vocab_list=['hi', 'bye', ZH_HELLO],
        lowercase=True,
        oov_buckets=2,
        oov_token='ZOO',
        tokenizer=custom_tokenizer)

    vocab_fname = os.path.join(tmp_dir, 'vocab')
    encoder.save_to_file(vocab_fname)
    restored = text_encoder.TokenTextEncoder.load_from_file(vocab_fname)

    # Encoder-level attributes must match after the round trip.
    for attr in ('tokens', 'vocab_size', 'lowercase', 'oov_token'):
      self.assertEqual(getattr(encoder, attr), getattr(restored, attr))

    # So must the settings of the attached tokenizer.
    for attr in ('alphanum_only', 'reserved_tokens'):
      self.assertEqual(
          getattr(encoder.tokenizer, attr),
          getattr(restored.tokenizer, attr))
def test_oov(self):
  """Checks encoding and decoding of unknown tokens with one OOV bucket."""
  encoder = text_encoder.TokenTextEncoder(
      vocab_list=['hi', 'bye', ZH_HELLO],
      oov_buckets=1,
      oov_token='UNK')

  # 'boo' and 'foo' are not in the vocab list; both are expected to share
  # the same id.
  expected_ids = [index + 1 for index in (0, 3, 3, 1)]
  encoded = encoder.encode('hi boo foo bye')
  self.assertEqual(expected_ids, encoded)

  decoded = encoder.decode(expected_ids)
  self.assertEqual('hi UNK UNK bye', decoded)

  self.assertEqual(5, encoder.vocab_size)
def test_lowercase(self):
  """Checks encoding with lowercase=True, with and without mixed tokens."""
  # The first two entries are the "mixed" (non purely alphanumeric) tokens.
  vocab = ['<EOS>', 'zoo!'] + ['hi', 'bye', ZH_HELLO]
  encoder = text_encoder.TokenTextEncoder(vocab_list=vocab, lowercase=True)

  cases = (
      # No mixed tokens
      ('hi byE HI!', [3, 4, 3]),
      # With mixed tokens
      ('hi<EOS>byE Zoo! HI!', [3, 1, 4, 2, 3]),
  )
  for text, expected_ids in cases:
    self.assertEqual(expected_ids, encoder.encode(text))
def test_file_backed(self):
  """Checks that an encoder reloaded from its vocab file has the same tokens."""
  with testing.tmp_dir(self.get_temp_dir()) as tmp_dir:
    vocab_path = os.path.join(tmp_dir, 'vocab')

    original = text_encoder.TokenTextEncoder(
        vocab_list=['hi', 'bye', ZH_HELLO])
    original.save_to_file(vocab_path)

    reloaded = text_encoder.TokenTextEncoder.load_from_file(vocab_path)
    self.assertEqual(original.tokens, reloaded.tokens)
def test_multiple_oov(self):
  """Checks encoding and decoding when two OOV buckets are configured."""
  encoder = text_encoder.TokenTextEncoder(
      vocab_list=['hi', 'bye', ZH_HELLO],
      oov_buckets=2,
      oov_token='UNK')

  encoded = encoder.encode('hi boo zoo too foo bye')

  # The known tokens at both ends keep their vocabulary ids.
  self.assertEqual(1, encoded[0])
  self.assertEqual(2, encoded[-1])

  # Both OOV ids must show up among the encoded unknown tokens.
  for oov_id in (4, 5):
    self.assertIn(oov_id, encoded)

  self.assertEqual(6, encoder.vocab_size)

  # Either OOV id decodes to the single OOV token.
  self.assertEqual('hi UNK UNK bye', encoder.decode([1, 4, 5, 2]))
def test_with_tokenizer(self):
  """Checks that a tokenizer passed to the encoder is the one used."""

  class DummyTokenizer(object):
    """Stub tokenizer that ignores its input."""

    def tokenize(self, s):
      del s  # Unused: the output is fixed regardless of the input.
      return ['hi', 'bye']

  encoder = text_encoder.TokenTextEncoder(
      vocab_list=['hi', 'bye', ZH_HELLO],
      tokenizer=DummyTokenizer())

  # Ensure it uses the passed tokenizer and not the default
  encoded = encoder.encode('zoo foo')
  self.assertEqual([1, 2], encoded)
def test_mixedalphanum_tokens(self):
  """Checks encoding when the vocab holds mixed alphanumeric tokens."""
  mixed_tokens = ['<EOS>', 'zoo!', '!foo']
  encoder = text_encoder.TokenTextEncoder(
      vocab_list=mixed_tokens + ['hi', 'bye', ZH_HELLO])

  cases = (
      # No mixed tokens
      # hi=3, foo=OOV, bar=OOV, bye=4, ZH_HELLO=5, hi=3
      (
          'hi<<>><<>foo!^* bar && bye (%s hi)' % ZH_HELLO,
          (3, 6, 6, 4, 5, 3),
      ),
      # With mixed tokens
      # hi=3, foo=OOV, <EOS>=0, bar=OOV, zoo!=1, FOO=OOV, bye=4, ZH_HELLO=5,
      # hi=3
      (
          'hi<<>><<>foo!<EOS>^* barzoo! FOO && bye (%s hi)' % ZH_HELLO,
          (3, 6, 0, 6, 1, 6, 4, 5, 3),
      ),
  )
  for text, indices in cases:
    # Expected ids are the listed indices, each incremented by one.
    expected_ids = [index + 1 for index in indices]
    self.assertEqual(expected_ids, encoder.encode(text))
def test_encode_decode(self):
  """Checks encode and decode against fixed expected values."""
  encoder = text_encoder.TokenTextEncoder(vocab_list=['hi', 'bye', ZH_HELLO])

  # Expected ids are the listed indices, each incremented by one.
  ids = [index + 1 for index in (0, 1, 2, 0)]
  sentence = 'hi bye %s hi' % ZH_HELLO
  self.assertEqual(ids, encoder.encode(sentence))

  # NOTE(review): the expected string has no explicit space before the last
  # 'hi' -- presumably ZH_HELLO already ends with whitespace; confirm against
  # its definition.
  expected_text = 'hi bye %shi' % ZH_HELLO
  self.assertEqual(expected_text, encoder.decode(ids))