def assert_state_is_compatible(expected_state, state):
    """Checks that :obj:`state` is compatible with :obj:`expected_state`.

    Compatibility means the two states share the same nested structure and,
    for every tensor entry, the same innermost (depth) dimension.

    Args:
      expected_state: The reference state.
      state: The state that must be compatible with :obj:`expected_state`.

    Raises:
      ValueError: if the states are incompatible.
    """
    # Structures must match exactly before comparing individual entries.
    compat.nest.assert_same_structure(expected_state, state)

    # Compare the innermost dimension of each tensor pair.
    reference_values = compat.nest.flatten(expected_state)
    actual_values = compat.nest.flatten(state)
    for reference, actual in zip(reference_values, actual_values):
        if not compat.is_tensor(reference):
            continue
        reference_depth = reference.get_shape().as_list()[-1]
        actual_depth = actual.get_shape().as_list()[-1]
        if actual_depth != reference_depth:
            raise ValueError(
                "Tensor in state has shape %s which is incompatible "
                "with the target shape %s" % (actual.shape, reference.shape))
def detokenize(self, tokens, sequence_length=None):
    """Detokenizes tokens.

    The Tensor version supports batches of tokens.

    Args:
      tokens: The tokens as a 1-D or 2-D ``tf.Tensor`` or list of Python
        strings.
      sequence_length: The length of each sequence. Required if
        :obj:`tokens` is a ``tf.Tensor``.

    Returns:
      A 0-D or 1-D string ``tf.Tensor`` if :obj:`tokens` is a ``tf.Tensor``
      or a Python unicode strings otherwise.

    Raises:
      ValueError: if the rank of :obj:`tokens` is greater than 2.
      ValueError: if :obj:`tokens` is a 2-D ``tf.Tensor`` and
        :obj:`sequence_length` is not set.
    """
    # Non-tensor inputs take the pure Python path.
    if not compat.is_tensor(tokens):
        as_text = [tf.compat.as_text(token) for token in tokens]
        return self._detokenize_string(as_text)

    # Dispatch on the static rank of the tensor input.
    rank = len(tokens.get_shape().as_list())
    if rank == 1:
        return self._detokenize_tensor(tokens)
    if rank == 2:
        # Batched detokenization needs per-sequence lengths.
        if sequence_length is None:
            raise ValueError(
                "sequence_length is required for Tensor detokenization"
            )
        return self._detokenize_batch_tensor(tokens, sequence_length)
    raise ValueError(
        "Unsupported tensor rank for detokenization: {}".format(
            rank))
def tokenize(self, text):
    """Tokenizes text.

    Args:
      text: The text to tokenize as a ``tf.Tensor`` or Python string.

    Returns:
      A 1-D string ``tf.Tensor`` if :obj:`text` is a ``tf.Tensor`` or a
      list of Python unicode strings otherwise.

    Raises:
      ValueError: if the rank of :obj:`text` is greater than 0.
    """
    # Plain Python strings take the pure Python path.
    if not compat.is_tensor(text):
        return self._tokenize_string(tf.compat.as_text(text))

    # Only scalar (rank 0) string tensors are supported.
    rank = len(text.get_shape().as_list())
    if rank != 0:
        raise ValueError(
            "Unsupported tensor rank for tokenization: {}".format(
                rank))
    return self._tokenize_tensor(text)