Example 1
def assert_state_is_compatible(expected_state, state):
    """Asserts that states are compatible.

    Args:
      expected_state: The reference state.
      state: The state that must be compatible with :obj:`expected_state`.

    Raises:
      ValueError: if the states are incompatible.
    """
    # Check structure compatibility.
    compat.nest.assert_same_structure(expected_state, state)

    # Check shape compatibility.
    expected_state_flat = compat.nest.flatten(expected_state)
    state_flat = compat.nest.flatten(state)

    for x, y in zip(expected_state_flat, state_flat):
        if compat.is_tensor(x):
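            # Only the depth (last dimension) must be equal; other dimensions
            # such as the batch size may differ.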
            expected_depth = x.get_shape().as_list()[-1]
            depth = y.get_shape().as_list()[-1]
            if depth != expected_depth:
                raise ValueError(
                    "Tensor in state has shape %s which is incompatible "
                    "with the target shape %s" % (y.shape, x.shape))
Example 2
    def detokenize(self, tokens, sequence_length=None):
        """Detokenizes tokens.

        The Tensor version supports batches of tokens.

        Args:
          tokens: The tokens as a 1-D or 2-D ``tf.Tensor`` or list of Python
            strings.
          sequence_length: The length of each sequence. Required if :obj:`tokens`
            is a ``tf.Tensor``.

        Returns:
          A 0-D or 1-D string ``tf.Tensor`` if :obj:`tokens` is a ``tf.Tensor``,
          or a Python unicode string otherwise.

        Raises:
          ValueError: if the rank of :obj:`tokens` is greater than 2.
          ValueError: if :obj:`tokens` is a 2-D ``tf.Tensor`` and
            :obj:`sequence_length` is not set.
        """
        if compat.is_tensor(tokens):
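            # Dispatch on rank: 1-D is a single sequence, 2-D is a batch of sequences.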
            rank = len(tokens.get_shape().as_list())
            if rank == 1:
                return self._detokenize_tensor(tokens)
            elif rank == 2:
                if sequence_length is None:
                    raise ValueError(
                        "sequence_length is required for Tensor detokenization"
                    )
                return self._detokenize_batch_tensor(tokens, sequence_length)
            else:
                raise ValueError(
                    "Unsupported tensor rank for detokenization: {}".format(
                        rank))
        else:
            tokens = [tf.compat.as_text(token) for token in tokens]
            return self._detokenize_string(tokens)
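
A hedged usage sketch, assuming a concrete subclass whose `_detokenize_string` joins tokens with spaces (the behavior of OpenNMT-tf's `SpaceTokenizer`; the class name and the outputs shown in comments are assumptions):

import tensorflow as tf

tokenizer = SpaceTokenizer()  # Assumed concrete subclass.

# List of Python strings: returns a single unicode string.
tokenizer.detokenize(["Hello", "world", "!"])  # -> "Hello world !"

# 2-D batch: sequence_length is required so each row's padding is dropped.
tokens = tf.constant([["Hello", "world", "!"], ["Hi", "", ""]])
lengths = tf.constant([3, 1])
tokenizer.detokenize(tokens, sequence_length=lengths)  # -> 1-D string tf.Tensor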
Example 3
    def tokenize(self, text):
        """Tokenizes text.

        Args:
          text: The text to tokenize as a ``tf.Tensor`` or Python string.

        Returns:
          A 1-D string ``tf.Tensor`` if :obj:`text` is a ``tf.Tensor`` or a list of
          Python unicode strings otherwise.

        Raises:
          ValueError: if the rank of :obj:`text` is greater than 0.
        """
        if compat.is_tensor(text):
            rank = len(text.get_shape().as_list())
            if rank == 0:
                return self._tokenize_tensor(text)
            else:
                raise ValueError(
                    "Unsupported tensor rank for tokenization: {}".format(
                        rank))
        else:
            text = tf.compat.as_text(text)
            return self._tokenize_string(text)
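
A matching sketch for tokenization, under the same `SpaceTokenizer` assumption:

import tensorflow as tf

tokenizer = SpaceTokenizer()  # Assumed concrete subclass.

tokenizer.tokenize("Hello world !")               # -> ["Hello", "world", "!"]
tokenizer.tokenize(tf.constant("Hello world !"))  # -> 1-D string tf.Tensor
# tokenizer.tokenize(tf.constant(["a b", "c"]))   # Rank 1 > 0: raises ValueError.

Batches of raw text are presumably meant to be tokenized element-wise (e.g. via a tf.data map), which is why only scalar string tensors are accepted here.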