def test_translation_encoded(self):
    """Translation feature with one shared ByteTextEncoder for all languages."""
    # Byte values are shifted by one because id 0 is reserved for padding.
    en_expected = [i + 1 for i in [104, 101, 108, 108, 111, 32]]
    zh_expected = [i + 1 for i in [228, 189, 160, 229, 165, 189, 32]]
    self.assertFeature(
        feature=features.Translation(
            languages=["en", "zh"], encoder=text_encoder.ByteTextEncoder()),
        shape={"en": (None,), "zh": (None,)},
        dtype={"en": tf.int64, "zh": tf.int64},
        tests=[
            testing.FeatureExpectationItem(
                value={"en": EN_HELLO, "zh": ZH_HELLO},
                expected={"en": en_expected, "zh": zh_expected},
            ),
        ],
    )
def test_translation_multiple_encoders(self):
    """Translation feature with a distinct encoder per language."""
    # "en" uses a token vocabulary; "zh" falls back to byte encoding
    # (byte ids shifted by one for the pad id).
    encoders = [
        text_encoder.TokenTextEncoder(["hello", " "]),
        text_encoder.ByteTextEncoder(),
    ]
    zh_expected = [i + 1 for i in [228, 189, 160, 229, 165, 189, 32]]
    self.assertFeature(
        feature=features.Translation(languages=["en", "zh"], encoder=encoders),
        shape={"en": (None,), "zh": (None,)},
        dtype={"en": tf.int64, "zh": tf.int64},
        tests=[
            testing.FeatureExpectationItem(
                value={"en": EN_HELLO, "zh": ZH_HELLO},
                expected={"en": [1], "zh": zh_expected},
            ),
        ],
    )
def test_encode_decode(self):
    """ByteTextEncoder round-trips both ASCII and multi-byte unicode text."""
    byte_encoder = text_encoder.ByteTextEncoder()
    # Encoding: unicode and plain-str inputs map to the expected byte ids.
    self.assertEqual(self.EN_HELLO_IDS, byte_encoder.encode(EN_HELLO))
    self.assertEqual(self.EN_HELLO_IDS, byte_encoder.encode('hello '))
    self.assertEqual(self.ZH_HELLO_IDS, byte_encoder.encode(ZH_HELLO))
    # Decoding inverts encoding.
    self.assertEqual(EN_HELLO, byte_encoder.decode(self.EN_HELLO_IDS))
    self.assertEqual(ZH_HELLO, byte_encoder.decode(self.ZH_HELLO_IDS))
    # All byte values plus the reserved pad id.
    self.assertEqual(text_encoder.NUM_BYTES + 1, byte_encoder.vocab_size)
def test_file_backed(self, additional_tokens):
    """Saving then loading a vocab file preserves the encoder's state."""
    original = text_encoder.ByteTextEncoder(additional_tokens=additional_tokens)
    with test_utils.tmp_dir(self.get_temp_dir()) as tmp_dir:
        vocab_path = os.path.join(tmp_dir, 'vocab')
        original.save_to_file(vocab_path)
        restored = text_encoder.ByteTextEncoder.load_from_file(vocab_path)
        # The reloaded encoder must match the one that was saved.
        self.assertEqual(original.vocab_size, restored.vocab_size)
        self.assertEqual(original.additional_tokens, restored.additional_tokens)
def test_reserved_tokens(self):
    """Reserved tokens claim the low ids and shift byte ids upward."""
    # One token with non-alphanumeric chars, one uppercase, one lowercase.
    reserved_tokens = ['<EOS>', 'FOO', 'bar']
    encoder = text_encoder.ByteTextEncoder(reserved_tokens=reserved_tokens)
    offset = len(reserved_tokens)
    # Text without reserved tokens: byte ids shifted past the reserved range.
    hello_ids = [i + offset for i in self.ZH_HELLO_IDS]
    self.assertEqual(hello_ids, encoder.encode(ZH_HELLO))
    # Text containing the reserved tokens themselves.
    text_reserved = '%s %s%s%s' % (reserved_tokens[0], ZH_HELLO,
                                   reserved_tokens[1], reserved_tokens[2])
    self.assertEqual([1, 32 + 1 + offset] + hello_ids + [2, 3],
                     encoder.encode(text_reserved))
    # 256 byte values + pad + the reserved tokens.
    self.assertEqual(2**8 + 1 + offset, encoder.vocab_size)
def test_save_load_metadata(self):
    """Metadata round-trip: a freshly loaded Text feature encodes identically.

    Bug fix: the final assertion previously re-encoded with the original
    feature (`text_f`), which passes trivially and never exercised
    `load_metadata`. It must encode with the reloaded feature (`new_f`).
    """
    text_f = features.Text(
        encoder=text_encoder.ByteTextEncoder(additional_tokens=['HI']))
    text = u'HI 你好'
    ids = text_f.str2ints(text)
    # 'HI' is the first additional token; id 0 is the pad, so it encodes to 1.
    self.assertEqual(1, ids[0])
    with testing.tmp_dir(self.get_temp_dir()) as data_dir:
        feature_name = 'dummy'
        text_f.save_metadata(data_dir, feature_name)
        new_f = features.Text()
        new_f.load_metadata(data_dir, feature_name)
        # Assert against the RELOADED feature, not the original one.
        self.assertEqual(ids, new_f.str2ints(text))
def expectations(self):
    """Build expectations for plain and byte-encoded Text features."""
    plain_text = 'hello world'
    unicode_text = u'你好'
    # Plain Text feature: strings pass through serialized as bytes.
    plain_cases = [
        # Non-unicode input.
        test_utils.FeatureExpectationItem(
            value=plain_text,
            expected=tf.compat.as_bytes(plain_text),
        ),
        # Unicode input.
        test_utils.FeatureExpectationItem(
            value=unicode_text,
            expected=tf.compat.as_bytes(unicode_text),
        ),
        # Empty string.
        test_utils.FeatureExpectationItem(
            value='',
            expected=tf.compat.as_bytes(''),
        ),
    ]
    # Byte-encoded Text feature: ids shifted by one to reserve 0 for padding.
    encoded_cases = [
        test_utils.FeatureExpectationItem(
            value=unicode_text,
            expected=[i + 1 for i in [228, 189, 160, 229, 165, 189]],
        ),
        # Empty string encodes to an empty id list.
        test_utils.FeatureExpectationItem(
            value='',
            expected=[],
        ),
    ]
    return [
        test_utils.FeatureExpectation(
            name='text',
            feature=features.Text(),
            shape=(),
            dtype=tf.string,
            tests=plain_cases,
        ),
        test_utils.FeatureExpectation(
            name='text_unicode_encoded',
            feature=features.Text(encoder=text_encoder.ByteTextEncoder()),
            shape=(None,),
            dtype=tf.int64,
            tests=encoded_cases,
        ),
    ]
def test_additional_tokens(self):
    """Additional tokens claim the low ids and round-trip through decode."""
    # One token with non-alphanumeric chars, one uppercase, one lowercase.
    additional_tokens = ['<EOS>', 'FOO', 'bar']
    encoder = text_encoder.ByteTextEncoder(additional_tokens=additional_tokens)
    offset = len(additional_tokens)
    # Text without additional tokens: byte ids shifted past the token range.
    hello_ids = [i + offset for i in self.ZH_HELLO_IDS]
    self.assertEqual(hello_ids, encoder.encode(ZH_HELLO))
    # Text containing the additional tokens themselves.
    text_additional = '%s %s%s%s' % (additional_tokens[0], ZH_HELLO,
                                     additional_tokens[1], additional_tokens[2])
    expected_ids = [1, 32 + 1 + offset] + hello_ids + [2, 3]
    self.assertEqual(expected_ids, encoder.encode(text_additional))
    # Decoding inverts encoding, token ids included.
    self.assertEqual(text_additional, encoder.decode(expected_ids))
    # All byte values + pad + the additional tokens.
    self.assertEqual(text_encoder.NUM_BYTES + 1 + offset, encoder.vocab_size)
def test_text_encoded(self):
    """Text feature with a ByteTextEncoder yields shifted byte ids."""
    unicode_text = u'你好'
    # Byte values are shifted by one because id 0 is reserved for padding.
    expected_ids = [i + 1 for i in [228, 189, 160, 229, 165, 189]]
    self.assertFeature(
        feature=features.Text(encoder=text_encoder.ByteTextEncoder()),
        shape=(None,),
        dtype=tf.int64,
        tests=[
            testing.FeatureExpectationItem(
                value=unicode_text,
                expected=expected_ids,
            ),
            # Empty string encodes to an empty id list.
            testing.FeatureExpectationItem(
                value='',
                expected=[],
            ),
        ],
    )
def test_text_conversion(self):
    """str2ints followed by ints2str is the identity on unicode text."""
    feature = features.Text(encoder=text_encoder.ByteTextEncoder())
    original = u'你好'
    round_tripped = feature.ints2str(feature.str2ints(original))
    self.assertEqual(original, round_tripped)