def test_reserved_tokens(self):
  """Reserved tokens must prefix the vocab and encode to the lowest ids."""
  reserved = ['<EOS>', 'FOO', 'zoo']
  base_vocab = [u'hi', 'bye', ZH_HELLO]

  # A vocab that does not start with the reserved tokens is rejected.
  with self.assertRaisesWithPredicateMatch(ValueError, 'must start with'):
    text_encoder.TokenTextEncoder(
        vocab_list=base_vocab, reserved_tokens=reserved)

  encoder = text_encoder.TokenTextEncoder(
      vocab_list=reserved + base_vocab, reserved_tokens=reserved)

  # Text containing no reserved tokens.
  plain = 'hi<<>><<>foo!^* bar && bye (%s hi)' % ZH_HELLO
  plain_ids = [i + 1 for i in [3, 6, 6, 4, 5, 3]]
  self.assertEqual(plain_ids, encoder.encode(plain))

  # Same text with reserved tokens spliced in mid-string; their ids are
  # inserted at the matching positions.
  mixed = 'hi<<>><<>foo!<EOS>^* barFOO && bye (%szoo hi)' % ZH_HELLO
  expected = list(plain_ids)
  expected.insert(2, 1)  # <EOS>
  expected.insert(4, 2)  # FOO
  expected.insert(7, 3)  # zoo
  self.assertEqual(expected, encoder.encode(mixed))
def test_file_backed(self):
  """Encoder round-trips through a vocab file written to disk."""
  with test_utils.tmp_dir(self.get_temp_dir()) as tmp_dir:
    vocab_fname = os.path.join(tmp_dir, 'vocab.tokens')
    encoder = text_encoder.TokenTextEncoder(
        vocab_list=[u'hi', 'bye', ZH_HELLO])
    encoder.store_to_file(vocab_fname)
    restored = text_encoder.TokenTextEncoder(vocab_file=vocab_fname)
    # The file-backed encoder must expose the same token list.
    self.assertEqual(encoder.tokens, restored.tokens)
def test_translation_multiple_encoders(self):
  """Translation feature with a distinct encoder per language."""
  # English is token-encoded; Chinese is integer-encoded byte by byte.
  en_encoder = text_encoder.TokenTextEncoder(["hello", " "])
  zh_encoder = text_encoder.ByteTextEncoder()
  # Byte ids are the UTF-8 byte values shifted by one.
  zh_expected = [b + 1 for b in [228, 189, 160, 229, 165, 189, 32]]
  self.assertFeature(
      feature=features.Translation(
          languages=["en", "zh"], encoder=[en_encoder, zh_encoder]),
      shape={
          "en": (None,),
          "zh": (None,)
      },
      dtype={
          "en": tf.int64,
          "zh": tf.int64
      },
      tests=[
          testing.FeatureExpectationItem(
              value={
                  "en": EN_HELLO,
                  "zh": ZH_HELLO
              },
              expected={
                  "en": [1],
                  "zh": zh_expected
              },
          ),
      ],
  )
def test_file_backed_with_args(self):
  """Saving and loading preserves every non-default constructor argument."""
  with test_utils.tmp_dir(self.get_temp_dir()) as tmp_dir:
    # Use non-default values for everything, including a custom Tokenizer.
    tokenizer = text_encoder.Tokenizer(
        reserved_tokens=['<FOOBAR>'], alphanum_only=False)
    encoder = text_encoder.TokenTextEncoder(
        vocab_list=['hi', 'bye', ZH_HELLO],
        lowercase=True,
        oov_buckets=2,
        oov_token='ZOO',
        tokenizer=tokenizer)
    vocab_fname = os.path.join(tmp_dir, 'vocab')
    encoder.save_to_file(vocab_fname)

    restored = text_encoder.TokenTextEncoder.load_from_file(vocab_fname)
    # Every persisted attribute must survive the round trip.
    self.assertEqual(encoder.tokens, restored.tokens)
    self.assertEqual(encoder.vocab_size, restored.vocab_size)
    self.assertEqual(encoder.lowercase, restored.lowercase)
    self.assertEqual(encoder.oov_token, restored.oov_token)
    self.assertEqual(encoder.tokenizer.alphanum_only,
                     restored.tokenizer.alphanum_only)
    self.assertEqual(encoder.tokenizer.reserved_tokens,
                     restored.tokenizer.reserved_tokens)
def test_tokenization(self):
  """Default tokenizer splits on non-alphanumerics; encoder maps to ids."""
  encoder = text_encoder.TokenTextEncoder(vocab_list=['hi', 'bye', ZH_HELLO])
  text = 'hi<<>><<>foo!^* bar && bye (%s hi)' % ZH_HELLO
  # Punctuation is dropped; ZH_HELLO is kept (stripped of whitespace).
  self.assertEqual(['hi', 'foo', 'bar', 'bye', ZH_HELLO.strip(), 'hi'],
                   text_encoder.Tokenizer().tokenize(text))
  # 'foo' and 'bar' are out-of-vocab and share the single OOV id.
  self.assertEqual([i + 1 for i in [0, 3, 3, 1, 2, 0]],
                   encoder.encode(text))
def test_oov(self):
  """Out-of-vocab words hash to the OOV bucket and decode to oov_token."""
  encoder = text_encoder.TokenTextEncoder(
      vocab_list=['hi', 'bye', ZH_HELLO], oov_buckets=1, oov_token='UNK')
  # 'boo' and 'foo' both fall into the single OOV bucket (id 4).
  ids = [i + 1 for i in [0, 3, 3, 1]]
  self.assertEqual(ids, encoder.encode('hi boo foo bye'))
  self.assertEqual('hi UNK UNK bye', encoder.decode(ids))
  # 3 vocab tokens + 1 OOV bucket + 1 for the reserved 0 id.
  self.assertEqual(5, encoder.vocab_size)
def test_lowercase(self):
  """lowercase=True matches case-insensitively, except mixed tokens."""
  mixed_tokens = ['<EOS>', 'zoo!']
  encoder = text_encoder.TokenTextEncoder(
      vocab_list=mixed_tokens + ['hi', 'bye', ZH_HELLO], lowercase=True)
  # Plain words match regardless of case.
  self.assertEqual([3, 4, 3], encoder.encode('hi byE HI!'))
  # Mixed (non-alphanumeric) tokens are matched as well.
  self.assertEqual([3, 1, 4, 2, 3], encoder.encode('hi<EOS>byE Zoo! HI!'))
def test_file_backed(self):
  """save_to_file / load_from_file round-trips the token list."""
  with testing.tmp_dir(self.get_temp_dir()) as tmp_dir:
    vocab_fname = os.path.join(tmp_dir, 'vocab')
    encoder = text_encoder.TokenTextEncoder(
        vocab_list=['hi', 'bye', ZH_HELLO])
    encoder.save_to_file(vocab_fname)
    restored = text_encoder.TokenTextEncoder.load_from_file(vocab_fname)
    self.assertEqual(encoder.tokens, restored.tokens)
def test_with_tokenizer(self):
  """A custom tokenizer is used verbatim and excludes reserved_tokens."""

  class DummyTokenizer(object):
    # Ignores its input and always yields the same two tokens.

    def tokenize(self, s):
      del s
      return ['hi', 'bye']

  tokenizer = DummyTokenizer()
  base_vocab = [u'hi', 'bye', ZH_HELLO]
  reserved = ['<EOS>', 'FOO', 'foo', 'zoo']
  # Supplying both reserved_tokens and a custom tokenizer is an error.
  with self.assertRaisesWithPredicateMatch(
      ValueError, 'reserved_tokens must be None'):
    text_encoder.TokenTextEncoder(
        vocab_list=reserved + base_vocab,
        reserved_tokens=reserved,
        tokenizer=tokenizer)

  encoder = text_encoder.TokenTextEncoder(
      vocab_list=base_vocab, tokenizer=tokenizer)
  # The dummy tokenizer turns any input into ['hi', 'bye'] -> ids [1, 2],
  # proving the default tokenizer was not used.
  self.assertEqual([1, 2], encoder.encode('zoo foo'))
def test_multiple_oov(self):
  """With several OOV buckets, unknown words spread across bucket ids."""
  encoder = text_encoder.TokenTextEncoder(
      vocab_list=['hi', 'bye', ZH_HELLO], oov_buckets=2, oov_token='UNK')
  encoded = encoder.encode('hi boo zoo too foo bye')
  # In-vocab words keep their fixed ids at the ends.
  self.assertEqual(1, encoded[0])
  self.assertEqual(2, encoded[-1])
  # Unknown words hash into one of the two bucket ids (4 or 5); with four
  # distinct unknowns, both buckets get hit.
  self.assertIn(4, encoded)
  self.assertIn(5, encoded)
  # 3 vocab tokens + 2 OOV buckets + 1 for the reserved 0 id.
  self.assertEqual(6, encoder.vocab_size)
  # Both bucket ids decode to the shared oov_token.
  self.assertEqual('hi UNK UNK bye', encoder.decode([1, 4, 5, 2]))
def test_lowercase(self):
  """lowercase=True matches plain words case-insensitively; reserved
  tokens keep their exact casing and ids."""
  reserved = ['<EOS>', 'FOO', 'foo', 'zoo']
  encoder = text_encoder.TokenTextEncoder(
      vocab_list=reserved + [u'hi', 'bye', ZH_HELLO],
      reserved_tokens=reserved,
      lowercase=True)
  # Plain words match regardless of case.
  self.assertEqual([5, 6, 5], encoder.encode('hi byE HI'))
  # Reserved tokens are matched inline with their own ids.
  self.assertEqual([5, 1, 6, 2, 5], encoder.encode('hi<EOS>byE FOO HI'))
def test_with_tokenizer(self):
  """The encoder delegates tokenization to the tokenizer it was given."""

  class DummyTokenizer(object):
    # Ignores its input and always yields the same two tokens.

    def tokenize(self, s):
      del s
      return ['hi', 'bye']

  encoder = text_encoder.TokenTextEncoder(
      vocab_list=['hi', 'bye', ZH_HELLO], tokenizer=DummyTokenizer())
  # Any input becomes ['hi', 'bye'] -> ids [1, 2], proving the default
  # tokenizer was not used.
  self.assertEqual([1, 2], encoder.encode('zoo foo'))
def test_mixedalphanum_tokens(self):
  """Vocab tokens containing punctuation are matched inside running text."""
  mixed_tokens = ['<EOS>', 'zoo!', '!foo']
  encoder = text_encoder.TokenTextEncoder(
      vocab_list=mixed_tokens + ['hi', 'bye', ZH_HELLO])

  # Without any mixed tokens present:
  # hi=3, foo=OOV, bar=OOV, bye=4, ZH_HELLO=5, hi=3
  plain = 'hi<<>><<>foo!^* bar && bye (%s hi)' % ZH_HELLO
  self.assertEqual([i + 1 for i in [3, 6, 6, 4, 5, 3]],
                   encoder.encode(plain))

  # With mixed tokens embedded:
  # hi=3, foo=OOV, <EOS>=0, bar=OOV, zoo!=1, FOO=OOV, bye=4, ZH_HELLO=5, hi=3
  mixed = 'hi<<>><<>foo!<EOS>^* barzoo! FOO && bye (%s hi)' % ZH_HELLO
  self.assertEqual([i + 1 for i in [3, 6, 0, 6, 1, 6, 4, 5, 3]],
                   encoder.encode(mixed))
def test_encode_decode(self):
  """encode/decode round-trip over the base vocabulary."""
  encoder = text_encoder.TokenTextEncoder(
      vocab_list=['hi', 'bye', ZH_HELLO])
  ids = [i + 1 for i in [0, 1, 2, 0]]
  self.assertEqual(ids, encoder.encode('hi bye %s hi' % ZH_HELLO))
  # NOTE(review): the decoded string has no space between ZH_HELLO and
  # 'hi' — presumably ZH_HELLO itself carries trailing whitespace; confirm
  # against its definition.
  self.assertEqual('hi bye %shi' % ZH_HELLO, encoder.decode(ids))