Example #1
  def _init_from_list(self, subwords):
    """Initializes the encoder from a list of subwords."""
    subwords = [tf.compat.as_text(s) for s in subwords if s]
    self._subwords = subwords
    # Note that internally everything is 0-indexed. Padding is dealt with at the
    # end of encode and the beginning of decode.
    self._subword_to_id = {s: i for i, s in enumerate(subwords)}

    # We remember the maximum length of any subword to avoid having to
    # check arbitrarily long strings.
    self._max_subword_len = max(
        len(_UNDERSCORE_REPLACEMENT), max([len(s) for s in subwords] or [1]))

    # Initialize the cache
    self._cache_size = 2**20
    self._token_to_ids_cache = [(None, None)] * self._cache_size

    # Setup tokenizer
    # Reserved tokens are all tokens that mix alphanumeric and
    # non-alphanumeric characters.
    reserved_tokens = set([_UNDERSCORE_REPLACEMENT])
    for t in self._subwords:
      if text_encoder.is_mixed_alphanum(t):
        reserved_tokens.add(t)
    self._tokenizer = text_encoder.Tokenizer(
        alphanum_only=False, reserved_tokens=reserved_tokens)
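
For context, here is a minimal standalone sketch of the bookkeeping _init_from_list performs. The subword list and the placeholder value used for _UNDERSCORE_REPLACEMENT are illustrative assumptions, not values taken from the library:

# Sketch only: reproduces the vocab map and max-length bookkeeping above.
_UNDERSCORE_REPLACEMENT = "\\&undsc"  # hypothetical 7-character placeholder

subwords = ["hell", "o_", "wor", "ld_"]
subword_to_id = {s: i for i, s in enumerate(subwords)}
max_subword_len = max(len(_UNDERSCORE_REPLACEMENT),
                      max([len(s) for s in subwords] or [1]))

print(subword_to_id)    # {'hell': 0, 'o_': 1, 'wor': 2, 'ld_': 3}
print(max_subword_len)  # 7 -- the placeholder is longer than any subword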
Example #2
  def test_is_mixed_alphanum(self):
    self.assertFalse(text_encoder.is_mixed_alphanum('hi'))
    self.assertFalse(text_encoder.is_mixed_alphanum(ZH_HELLO[:-1]))
    self.assertTrue(text_encoder.is_mixed_alphanum('hi.'))
    self.assertTrue(text_encoder.is_mixed_alphanum('hi.bye'))
    self.assertTrue(text_encoder.is_mixed_alphanum('hi '))
    self.assertTrue(text_encoder.is_mixed_alphanum(ZH_HELLO))
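
The assertions above pin down the contract: a token is "mixed" when it contains both alphanumeric and non-alphanumeric characters (a trailing space or punctuation counts as non-alphanumeric, and CJK characters count as alphanumeric). A minimal predicate consistent with these assertions, offered as a sketch rather than the library's actual implementation:

def is_mixed_alphanum(token):
  # True when the token contains both alphanumeric and
  # non-alphanumeric characters, e.g. 'hi.' or 'hi '.
  has_alnum = any(c.isalnum() for c in token)
  has_other = any(not c.isalnum() for c in token)
  return has_alnum and has_other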
Example #3
def _validate_build_arguments(max_subword_length, reserved_tokens,
                              target_vocab_size):
  """Validate arguments for SubwordTextEncoder.build_from_corpus."""
  if max_subword_length <= 0:
    raise ValueError(
        "max_subword_length must be > 0. Note that memory and compute for "
        "building the vocabulary scale quadratically in the length of the "
        "longest token.")
  for t in reserved_tokens:
    if t.endswith("_") or not text_encoder.is_mixed_alphanum(t):
      raise ValueError(
          "Reserved tokens must not end with _ and they must contain a mix "
          "of alphanumeric and non-alphanumeric characters. For example, "
          "'<EOS>'.")
  # Minimum vocab size = bytes + pad + 1
  minimum_vocab_size = text_encoder.NUM_BYTES + 1 + 1
  if target_vocab_size < minimum_vocab_size:
    raise ValueError("target_vocab_size must be >= %d. Got %d" %
                     (minimum_vocab_size, target_vocab_size))
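
A short usage sketch of the validator above. It assumes text_encoder.NUM_BYTES is 256 (one id per byte), which would make the minimum target_vocab_size 258; that constant's value is an assumption here, not shown in the excerpt:

# Passes: positive subword length, well-formed reserved token, large vocab.
_validate_build_arguments(
    max_subword_length=20, reserved_tokens=["<EOS>"],
    target_vocab_size=2**15)

# Raises ValueError: reserved tokens must not end with "_".
try:
  _validate_build_arguments(
      max_subword_length=20, reserved_tokens=["<EOS>_"],
      target_vocab_size=2**15)
except ValueError as e:
  print(e)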