示例#1
0
    def _init_from_list(self, subwords):
        """Initializes the encoder from a list of subwords.

        Args:
            subwords: iterable of subword strings (bytes or text). Falsy
                (empty) entries are dropped.
        """
        # Normalize everything to text and drop empty subwords.
        subwords = [tf.compat.as_text(s) for s in subwords if s]
        self._subwords = subwords
        # Note that internally everything is 0-indexed. Padding is dealt with at the
        # end of encode and the beginning of decode.
        self._subword_to_id = {s: i for i, s in enumerate(subwords)}

        # We remember the maximum length of any subword to avoid having to
        # check arbitrarily long strings. `default=1` handles an empty subword
        # list; the underscore-replacement token is always a candidate.
        self._max_subword_len = max(
            len(_UNDERSCORE_REPLACEMENT),
            max((len(s) for s in subwords), default=1))

        # Initialize the fixed-size (token -> ids) cache.
        self._cache_size = 2**20
        self._token_to_ids_cache = [(None, None)] * self._cache_size

        # Setup tokenizer
        # Reserved tokens are all tokens that are mixed alphanum and non-alphanum.
        reserved_tokens = {_UNDERSCORE_REPLACEMENT}
        reserved_tokens.update(
            t for t in self._subwords if text_encoder.is_mixed_alphanum(t))
        self._tokenizer = text_encoder.Tokenizer(
            alphanum_only=False, reserved_tokens=reserved_tokens)
示例#2
0
 def test_is_mixed_alphanum(self):
   # Purely alphanumeric tokens are not "mixed".
   for token in ('hi', ZH_HELLO[:-1]):
     self.assertFalse(text_encoder.is_mixed_alphanum(token))
   # Tokens containing both alphanumeric and non-alphanumeric chars are.
   for token in ('hi.', 'hi.bye', 'hi ', ZH_HELLO):
     self.assertTrue(text_encoder.is_mixed_alphanum(token))
示例#3
0
def _validate_build_arguments(max_subword_length, reserved_tokens,
                              target_vocab_size):
    """Validate arguments for SubwordTextEncoder.build_from_corpus.

    Raises:
        ValueError: if any argument falls outside its valid range.
    """
    # Guard: subword length must be positive.
    if max_subword_length <= 0:
        raise ValueError(
            "max_subword_length must be > 0. Note that memory and compute for "
            "building the vocabulary scale quadratically in the length of the "
            "longest token.")
    # Guard: reserved tokens must mix alphanumeric and non-alphanumeric
    # characters and must not end with the escape character.
    has_bad_reserved_token = any(
        token.endswith("_") or not text_encoder.is_mixed_alphanum(token)
        for token in reserved_tokens)
    if has_bad_reserved_token:
        raise ValueError(
            "Reserved tokens must not end with _ and they must contain a mix "
            "of alphanumeric and non-alphanumeric characters. For example, "
            "'<EOS>'.")
    # Minimum vocab size = bytes + pad + 1
    min_vocab_size = text_encoder.NUM_BYTES + 1 + 1
    if target_vocab_size < min_vocab_size:
        raise ValueError("target_vocab_size must be >= %d. Got %d" %
                         (min_vocab_size, target_vocab_size))