Example No. 1
  def test_file_backed_with_args(self):
    with test_utils.tmp_dir(self.get_temp_dir()) as tmp_dir:
      # Set all the args to non-default values, including Tokenizer
      tokenizer = text_encoder.Tokenizer(
          reserved_tokens=['<FOOBAR>'], alphanum_only=False)
      encoder = text_encoder.TokenTextEncoder(
          vocab_list=['hi', 'bye', ZH_HELLO],
          lowercase=True,
          oov_buckets=2,
          oov_token='ZOO',
          tokenizer=tokenizer)

      vocab_fname = os.path.join(tmp_dir, 'vocab')
      encoder.save_to_file(vocab_fname)

      file_backed_encoder = text_encoder.TokenTextEncoder.load_from_file(
          vocab_fname)
      self.assertEqual(encoder.tokens, file_backed_encoder.tokens)
      self.assertEqual(encoder.vocab_size, file_backed_encoder.vocab_size)
      self.assertEqual(encoder.lowercase, file_backed_encoder.lowercase)
      self.assertEqual(encoder.oov_token, file_backed_encoder.oov_token)
      self.assertEqual(encoder.tokenizer.alphanum_only,
                       file_backed_encoder.tokenizer.alphanum_only)
      self.assertEqual(encoder.tokenizer.reserved_tokens,
                       file_backed_encoder.tokenizer.reserved_tokens)
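
The assertions above only check that the configuration survives a save/load round trip. As a minimal sketch of what lowercase and oov_buckets actually do at encode time (the import path is an assumption matching older tensorflow_datasets releases; newer ones expose the same classes under tfds.deprecated.text, and the exact ids depend on vocabulary order):

from tensorflow_datasets.core.features.text import text_encoder  # assumed path

# Sketch: lowercase folding plus hashing of unknown tokens into OOV buckets.
encoder = text_encoder.TokenTextEncoder(
    vocab_list=['hi', 'bye'], lowercase=True, oov_buckets=2, oov_token='UNK')
ids = encoder.encode('Hi unknownword bye')
print(ids)                  # 'Hi' folds to 'hi'; 'unknownword' lands in one of the 2 OOV buckets
print(encoder.decode(ids))  # OOV ids should decode back to the oov_token, 'UNK'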
Example No. 2
  def _init_from_list(self, subwords):
    """Initializes the encoder from a list of subwords."""
    subwords = [tf.compat.as_text(s) for s in subwords if s]
    self._subwords = subwords
    # Note that internally everything is 0-indexed. Padding is dealt with at the
    # end of encode and the beginning of decode.
    self._subword_to_id = {s: i for i, s in enumerate(subwords)}

    # We remember the maximum length of any subword to avoid having to
    # check arbitrarily long strings.
    self._max_subword_len = max(
        len(_UNDERSCORE_REPLACEMENT), max([len(s) for s in subwords] or [1]))

    # Initialize the cache
    self._cache_size = 2**20
    self._token_to_ids_cache = [(None, None)] * self._cache_size

    # Setup tokenizer
    # Reserved tokens are all tokens that are mixed alphanum and non-alphanum.
    reserved_tokens = set([_UNDERSCORE_REPLACEMENT])
    for t in self._subwords:
      if text_encoder.is_mixed_alphanum(t):
        reserved_tokens.add(t)
    self._tokenizer = text_encoder.Tokenizer(
        alphanum_only=False, reserved_tokens=reserved_tokens)
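
As a rough illustration of the reserved-token setup above (a sketch only; it assumes is_mixed_alphanum flags strings that contain both alphanumeric and non-alphanumeric characters, which is what its use here suggests):

from tensorflow_datasets.core.features.text import text_encoder  # assumed path

# Sketch: subwords mixing alphanumeric and non-alphanumeric characters (e.g. 'foo_'
# or '<EOS>') are treated as reserved so the tokenizer keeps them in one piece.
for s in ['foo', 'foo_', '<EOS>', '!!']:
    print(s, text_encoder.is_mixed_alphanum(s))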
Example No. 3
  def test_tokenization(self):
    encoder = text_encoder.TokenTextEncoder(vocab_list=['hi', 'bye', ZH_HELLO])
    text = 'hi<<>><<>foo!^* bar && bye (%s hi)' % ZH_HELLO
    self.assertEqual(['hi', 'foo', 'bar', 'bye',
                      ZH_HELLO.strip(), 'hi'],
                     text_encoder.Tokenizer().tokenize(text))
    self.assertEqual([i + 1 for i in [0, 3, 3, 1, 2, 0]], encoder.encode(text))
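
The i + 1 shift reflects that the TFDS text encoders reserve id 0 for padding, so vocabulary ids start at 1. A quick sketch with the same assumed import as above:

# Sketch: 1-based id space because id 0 is reserved for padding.
encoder = text_encoder.TokenTextEncoder(vocab_list=['hi', 'bye'])
print(encoder.encode('hi bye'))  # expected [1, 2] given the vocab order
print(encoder.decode([1, 2]))    # should recover 'hi bye'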
Example No. 4
  def test_reserved_tokens(self):
    text = 'hello worldbar bar foozoo zoo FOO<EOS>'
    tokens = ['hello', ' ', 'world', 'bar', ' ', 'bar', ' ', 'foozoo',
              ' ', 'zoo', ' ', 'FOO', '<EOS>']
    tokenizer = text_encoder.Tokenizer(alphanum_only=False,
                                       reserved_tokens=['<EOS>', 'FOO', 'bar'])
    self.assertEqual(tokens, tokenizer.tokenize(text))
    self.assertEqual(text, tokenizer.join(tokenizer.tokenize(text)))
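
Reserved-token matching is exact and case-sensitive, and it splits matches out of the middle of words, which is why 'worldbar' is broken apart while 'foozoo' is not. Without reserved tokens the same substring stays glued together (sketch, same assumed import as above):

# Sketch: no reserved tokens, so 'worldbar' remains a single token.
plain = text_encoder.Tokenizer(alphanum_only=False)
print(plain.tokenize('hello worldbar'))  # expected ['hello', ' ', 'worldbar']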
Example No. 5
  def test_with_nonalphanum(self):
    text = 'hi world<<>><<>foo!^* bar &&  bye (%s hi)' % ZH_HELLO
    tokens = [
        'hi', ' ', 'world', '<<>><<>', 'foo', '!^* ', 'bar', ' &&  ',
        'bye', ' (',
        ZH_HELLO.strip(), '  ', 'hi', ')'
    ]
    tokenizer = text_encoder.Tokenizer(alphanum_only=False)
    self.assertEqual(tokens, tokenizer.tokenize(text))
    self.assertEqual(text, tokenizer.join(tokenizer.tokenize(text)))
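
Because alphanum_only=False keeps every delimiter as its own token, join(tokenize(text)) reproduces the input exactly. With the default alphanum_only=True, delimiters are dropped and the round trip is lossy (sketch, same assumed import as above):

# Sketch: the default tokenizer discards non-alphanumeric runs entirely.
default_tok = text_encoder.Tokenizer()
print(default_tok.tokenize('hi!! bye'))  # expected ['hi', 'bye']; the '!!' and spacing are gone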
Example No. 6
def _token_counts_from_generator(generator, max_chars, reserved_tokens):
  """Builds token counts from generator."""
  reserved_tokens = list(reserved_tokens) + [_UNDERSCORE_REPLACEMENT]
  tokenizer = text_encoder.Tokenizer(
      alphanum_only=False, reserved_tokens=reserved_tokens)
  num_chars = 0
  token_counts = collections.defaultdict(int)
  for s in generator:
    s = tf.compat.as_text(s)
    if max_chars and (num_chars + len(s)) >= max_chars:
      s = s[:(max_chars - num_chars)]
    tokens = tokenizer.tokenize(s)
    tokens = _prepare_tokens_for_encode(tokens)
    for t in tokens:
      token_counts[t] += 1
    if max_chars:
      num_chars += len(s)
      if num_chars > max_chars:
        break
  return token_counts
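
A hypothetical call of this helper, only meaningful inside the module that defines _token_counts_from_generator, _prepare_tokens_for_encode and _UNDERSCORE_REPLACEMENT (the corpus and max_chars value below are made up for illustration):

# Hypothetical usage: count tokens over a tiny corpus, stopping near 1000 characters.
corpus = iter(['hello world', 'hello <EOS>', 'goodbye world'])
counts = _token_counts_from_generator(
    corpus, max_chars=1000, reserved_tokens=['<EOS>'])
print(dict(counts))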
Example No. 7
  def test_whitespace(self, s, exp):
    tokenizer = text_encoder.Tokenizer(alphanum_only=False)
    self.assertEqual(exp, tokenizer.tokenize(s))
    self.assertEqual(s, tokenizer.join(tokenizer.tokenize(s)))
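
This test is parameterized, so s and exp are supplied by a decorator elsewhere. Plausible pairs (illustrative only, not the original test data) would exercise how runs of whitespace survive intact when alphanum_only=False:

# Hypothetical (s, exp) parameters; consecutive spaces are expected to stay as one token.
whitespace_cases = [
    ('a b', ['a', ' ', 'b']),
    ('a  b', ['a', '  ', 'b']),
]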
Example No. 8
  def test_default(self):
    text = 'hi<<>><<>foo!^* bar &&  bye (%s hi)' % ZH_HELLO
    self.assertEqual(['hi', 'foo', 'bar', 'bye', ZH_HELLO.strip(), 'hi'],
                     text_encoder.Tokenizer().tokenize(text))
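
The default tokenizer also treats digits as alphanumeric, so numbers survive while operators and punctuation are dropped (sketch, same assumed import as above):

# Sketch: digits count as alphanumeric for the default (alphanum_only=True) tokenizer.
print(text_encoder.Tokenizer().tokenize('2 + 2 = 4'))  # expected ['2', '2', '4']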