def _init_from_list(self, subwords):
    """Sets up vocabulary, lookup table, cache and tokenizer from subwords.

    Args:
      subwords: iterable of subword strings. Falsy entries (e.g. empty
        strings) are dropped; the rest are passed through
        `tf.compat.as_text`.
    """
    vocab = [tf.compat.as_text(piece) for piece in subwords if piece]
    self._subwords = vocab

    # Ids are 0-indexed internally; padding is handled at the end of encode
    # and the beginning of decode.
    lookup = {}
    for index, piece in enumerate(vocab):
        lookup[piece] = index
    self._subword_to_id = lookup

    # Track the longest subword so that matching never has to consider
    # arbitrarily long strings. An empty vocabulary counts as length 1.
    longest_subword = max(len(piece) for piece in vocab) if vocab else 1
    self._max_subword_len = max(len(_UNDERSCORE_REPLACEMENT), longest_subword)

    # Fixed-size token -> ids cache; every slot starts as (None, None).
    self._cache_size = 2**20
    self._token_to_ids_cache = [(None, None)] * self._cache_size

    # Tokenizer setup: the underscore replacement plus every subword that
    # mixes alphanumeric and non-alphanumeric characters is reserved.
    reserved = {_UNDERSCORE_REPLACEMENT}
    reserved.update(
        piece for piece in self._subwords
        if text_encoder.is_mixed_alphanum(piece))
    self._tokenizer = text_encoder.Tokenizer(
        alphanum_only=False, reserved_tokens=reserved)
def test_is_mixed_alphanum(self):
    """Checks text_encoder.is_mixed_alphanum on known inputs.

    Expects False for 'hi' and for ZH_HELLO without its last character, and
    True for 'hi.', 'hi.bye', 'hi ' and the full ZH_HELLO.
    """
    is_mixed = text_encoder.is_mixed_alphanum

    for text in ('hi', ZH_HELLO[:-1]):
        self.assertFalse(is_mixed(text))

    for text in ('hi.', 'hi.bye', 'hi ', ZH_HELLO):
        self.assertTrue(is_mixed(text))
def _validate_build_arguments(max_subword_length, reserved_tokens, target_vocab_size): """Validate arguments for SubwordTextEncoder.build_from_corpus.""" if max_subword_length <= 0: raise ValueError( "max_subword_length must be > 0. Note that memory and compute for " "building the vocabulary scale quadratically in the length of the " "longest token.") for t in reserved_tokens: if t.endswith("_") or not text_encoder.is_mixed_alphanum(t): raise ValueError( "Reserved tokens must not end with _ and they must contain a mix " "of alphanumeric and non-alphanumeric characters. For example, " "'<EOS>'.") # Minimum vocab size = bytes + pad + 1 minimum_vocab_size = text_encoder.NUM_BYTES + 1 + 1 if target_vocab_size < minimum_vocab_size: raise ValueError("target_vocab_size must be >= %d. Got %d" % (minimum_vocab_size, target_vocab_size))