def test_file_backed_with_args(self):
  with test_utils.tmp_dir(self.get_temp_dir()) as tmp_dir:
    # Set all the args to non-default values, including Tokenizer
    tokenizer = text_encoder.Tokenizer(
        reserved_tokens=['<FOOBAR>'], alphanum_only=False)
    encoder = text_encoder.TokenTextEncoder(
        vocab_list=['hi', 'bye', ZH_HELLO],
        lowercase=True,
        oov_buckets=2,
        oov_token='ZOO',
        tokenizer=tokenizer)
    vocab_fname = os.path.join(tmp_dir, 'vocab')
    encoder.save_to_file(vocab_fname)

    file_backed_encoder = text_encoder.TokenTextEncoder.load_from_file(
        vocab_fname)
    self.assertEqual(encoder.tokens, file_backed_encoder.tokens)
    self.assertEqual(encoder.vocab_size, file_backed_encoder.vocab_size)
    self.assertEqual(encoder.lowercase, file_backed_encoder.lowercase)
    self.assertEqual(encoder.oov_token, file_backed_encoder.oov_token)
    self.assertEqual(encoder.tokenizer.alphanum_only,
                     file_backed_encoder.tokenizer.alphanum_only)
    self.assertEqual(encoder.tokenizer.reserved_tokens,
                     file_backed_encoder.tokenizer.reserved_tokens)
def _init_from_list(self, subwords):
  """Initializes the encoder from a list of subwords."""
  subwords = [tf.compat.as_text(s) for s in subwords if s]
  self._subwords = subwords
  # Note that internally everything is 0-indexed. Padding is dealt with at the
  # end of encode and the beginning of decode.
  self._subword_to_id = {s: i for i, s in enumerate(subwords)}

  # We remember the maximum length of any subword to avoid having to
  # check arbitrarily long strings.
  self._max_subword_len = max(
      len(_UNDERSCORE_REPLACEMENT), max([len(s) for s in subwords] or [1]))

  # Initialize the cache
  self._cache_size = 2**20
  self._token_to_ids_cache = [(None, None)] * self._cache_size

  # Setup tokenizer
  # Reserved tokens are all tokens that are mixed alphanum and non-alphanum.
  reserved_tokens = set([_UNDERSCORE_REPLACEMENT])
  for t in self._subwords:
    if text_encoder.is_mixed_alphanum(t):
      reserved_tokens.add(t)
  self._tokenizer = text_encoder.Tokenizer(
      alphanum_only=False, reserved_tokens=reserved_tokens)
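# A minimal sketch (hypothetical standalone helper, not part of the encoder)
# of the lookup pattern the fixed-size `_token_to_ids_cache` above supports:
# slots are addressed by `hash(token) % cache_size`, and a colliding token
# simply overwrites its slot. `compute_ids` stands in for the real
# subword-splitting routine.
def _cached_token_to_ids(cache, token, compute_ids):
  pos = hash(token) % len(cache)
  cached_token, cached_ids = cache[pos]
  if cached_token == token:  # Cache hit: this token was stored here before.
    return cached_ids
  ids = compute_ids(token)  # Cache miss or collision: recompute and store.
  cache[pos] = (token, ids)
  return ids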
def test_tokenization(self):
  encoder = text_encoder.TokenTextEncoder(vocab_list=['hi', 'bye', ZH_HELLO])
  text = 'hi<<>><<>foo!^* bar && bye (%s hi)' % ZH_HELLO
  self.assertEqual(['hi', 'foo', 'bar', 'bye', ZH_HELLO.strip(), 'hi'],
                   text_encoder.Tokenizer().tokenize(text))
  # ids are shifted by 1 because id 0 is reserved for padding; id 3 is the
  # out-of-vocabulary bucket, hit by 'foo' and 'bar'.
  self.assertEqual([i + 1 for i in [0, 3, 3, 1, 2, 0]], encoder.encode(text))
def test_reserved_tokens(self):
  text = 'hello worldbar bar foozoo zoo FOO<EOS>'
  # Reserved tokens are split out even mid-word ('worldbar' -> 'world', 'bar'),
  # and matching is case-sensitive ('foozoo' is not split on 'FOO').
  tokens = [
      'hello', ' ', 'world', 'bar', ' ', 'bar', ' ', 'foozoo', ' ', 'zoo',
      ' ', 'FOO', '<EOS>'
  ]
  tokenizer = text_encoder.Tokenizer(
      alphanum_only=False, reserved_tokens=['<EOS>', 'FOO', 'bar'])
  self.assertEqual(tokens, tokenizer.tokenize(text))
  self.assertEqual(text, tokenizer.join(tokenizer.tokenize(text)))
def test_with_nonalphanum(self):
  text = 'hi world<<>><<>foo!^* bar && bye (%s hi)' % ZH_HELLO
  tokens = [
      'hi', ' ', 'world', '<<>><<>', 'foo', '!^* ', 'bar', ' && ', 'bye',
      ' (', ZH_HELLO.strip(), ' ', 'hi', ')'
  ]
  tokenizer = text_encoder.Tokenizer(alphanum_only=False)
  self.assertEqual(tokens, tokenizer.tokenize(text))
  self.assertEqual(text, tokenizer.join(tokenizer.tokenize(text)))
def _token_counts_from_generator(generator, max_chars, reserved_tokens):
  """Builds token counts from generator."""
  reserved_tokens = list(reserved_tokens) + [_UNDERSCORE_REPLACEMENT]
  tokenizer = text_encoder.Tokenizer(
      alphanum_only=False, reserved_tokens=reserved_tokens)
  num_chars = 0
  token_counts = collections.defaultdict(int)
  for s in generator:
    s = tf.compat.as_text(s)
    # Truncate the final string so the character budget is not exceeded.
    if max_chars and (num_chars + len(s)) >= max_chars:
      s = s[:(max_chars - num_chars)]
    tokens = tokenizer.tokenize(s)
    tokens = _prepare_tokens_for_encode(tokens)
    for t in tokens:
      token_counts[t] += 1
    if max_chars:
      num_chars += len(s)
      if num_chars > max_chars:
        break
  return token_counts
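# Hypothetical usage sketch for `_token_counts_from_generator`: counting
# tokens over a small in-memory corpus with no character budget and no extra
# reserved tokens. The exact keys depend on `_prepare_tokens_for_encode`
# (e.g. how trailing spaces are folded into tokens), so only a loose total
# is checked here.
def _example_counts():
  corpus = iter(['hello world', 'hello there'])
  counts = _token_counts_from_generator(
      corpus, max_chars=None, reserved_tokens=[])
  assert sum(counts.values()) >= 4  # At least the four alphanumeric tokens.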
def test_whitespace(self, s, exp):
  # Parameterized test: `s` is the input string, `exp` the expected token list.
  tokenizer = text_encoder.Tokenizer(alphanum_only=False)
  self.assertEqual(exp, tokenizer.tokenize(s))
  self.assertEqual(s, tokenizer.join(tokenizer.tokenize(s)))
def test_default(self):
  text = 'hi<<>><<>foo!^* bar && bye (%s hi)' % ZH_HELLO
  self.assertEqual(['hi', 'foo', 'bar', 'bye', ZH_HELLO.strip(), 'hi'],
                   text_encoder.Tokenizer().tokenize(text))