Example #1
    def flatten(self, tensor):
        """Flattens and caches the tensor's batch_dims."""
        with tf.name_scope('batch_flatten'):
            if self._batch_dims == 1:
                return tensor

            self._original_tensor_shape = composite.shape(tensor)

            if tensor.shape[self._batch_dims:].is_fully_defined():
                return composite.reshape(
                    tensor, [-1] + tensor.shape[self._batch_dims:].as_list())

            reshaped = composite.reshape(
                tensor,
                tf.concat(
                    [[-1], composite.shape(tensor)[self._batch_dims:]],
                    axis=0),
            )
            # If the batch dimensions are all defined but the rest are undefined,
            # `reshaped` will have None as the first squashed dim since we are calling
            # tf.shape above. Since we know how many batch_dims we have, we can check
            # if all the elements we want to squash are defined, allowing us to
            # call ensure_shape to set the shape of the squashed dim. Note that this
            # is only implemented for tf.Tensor and not SparseTensors.
            if (isinstance(tensor, tf.Tensor)
                    and tensor.shape[:self._batch_dims].is_fully_defined()):
                return tf.ensure_shape(
                    reshaped,
                    [np.prod(tensor.shape[:self._batch_dims], dtype=np.int64)]
                    + tensor.shape[self._batch_dims:])
            return reshaped
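
A minimal usage sketch for the method above. It assumes the enclosing class is TF-Agents' BatchSquash helper (here imported from tf_agents.networks.utils, an assumed location), whose constructor takes the number of leading batch dimensions to collapse:

import tensorflow as tf
from tf_agents.networks import utils  # assumed home of BatchSquash

squash = utils.BatchSquash(batch_dims=2)
x = tf.zeros([4, 5, 3])   # two batch dims (4, 5) plus one feature dim (3)
flat = squash.flatten(x)  # shape [20, 3]: the 4 x 5 batch dims collapse into 20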
Example #2
def flatten_multi_batched_nested_tensors(tensors, specs):
    """Reshape tensors to contain only one batch dimension.

  For each tensor, it checks the number of extra dimensions beyond those in
  the spec, and reshapes tensor to have only one batch dimension.
  NOTE: Each tensor's batch dimensions must be the same.

  Args:
    tensors: Nested list/tuple or dict of batched Tensors or SparseTensors.
    specs: Nested list/tuple or dict of TensorSpecs, describing the shape of the
      non-batched Tensors.

  Returns:
    A nested version of each tensor with a single batch dimension.
    A list of the batch dimensions which were flattened.
  Raises:
    ValueError: if the tensors and specs have incompatible dimensions or shapes.
  """
    assert_same_structure(
        tensors,
        specs,
        message='Tensors and specs do not have matching structures')
    flat_tensors = tf.nest.flatten(tensors)
    flat_spec_shapes = [spec_shape(s) for s in tf.nest.flatten(specs)]
    out_tensors = []
    batch_dims = []
    for i, (tensor, sp_shape) in enumerate(zip(flat_tensors,
                                               flat_spec_shapes)):
        if i == 0:  # Set batch_dims based on first tensor.
            batch_dims = tensor.shape[:tensor.shape.rank - sp_shape.rank]
            if batch_dims.is_fully_defined():
                batch_dims = batch_dims.as_list()
                batch_prod = np.prod(batch_dims)
                batch_dims = tf.constant(batch_dims, dtype=tf.int64)
            else:
                batch_dims = tf.shape(tensor)[:tensor.shape.rank -
                                              sp_shape.rank]
                batch_prod = tf.reduce_prod(batch_dims)
        if not sp_shape.is_fully_defined():
            # When shape of spec is not fully defined, we do not rely on it to
            # reshape the tensor but retain the original non-batch dims of tensors.
            non_batch_dims = tf.shape(tensor)[tensor.shape.rank -
                                              sp_shape.rank:]
            reshaped_dims = tf.concat([[batch_prod], non_batch_dims], 0)
        else:
            reshaped_dims = [batch_prod] + sp_shape.as_list()
        out_tensors.append(composite.reshape(tensor, reshaped_dims))
    return tf.nest.pack_sequence_as(tensors, out_tensors), batch_dims
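
A short usage sketch for the function above, assuming it is the flatten_multi_batched_nested_tensors that lives in tf_agents.utils.nest_utils:

import tensorflow as tf
from tf_agents.utils import nest_utils  # assumed home of this function

specs = {'obs': tf.TensorSpec([3], tf.float32),
         'act': tf.TensorSpec([2], tf.float32)}
tensors = {'obs': tf.zeros([4, 5, 3]),  # two extra (batch) dims: 4 and 5
           'act': tf.zeros([4, 5, 2])}

flat, batch_dims = nest_utils.flatten_multi_batched_nested_tensors(
    tensors, specs)
# flat['obs'].shape == [20, 3], flat['act'].shape == [20, 2];
# batch_dims holds [4, 5] (an int64 tensor here), usable to restore the shape.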
Example #3
  def unflatten(self, tensor):
    """Unflattens the tensor's batch_dims using the cached shape."""
    with tf.name_scope('batch_unflatten'):
      if self._batch_dims == 1:
        return tensor

      if self._original_tensor_shape is None:
        raise ValueError('Please call flatten before unflatten.')

      # pyformat: disable
      return composite.reshape(
          tensor,
          tf.concat([
              self._original_tensor_shape[:self._batch_dims],
              composite.shape(tensor)[1:]], axis=0)
      )
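
A round-trip sketch pairing unflatten with the flatten method of Example #1, under the same assumption that the enclosing class is TF-Agents' BatchSquash:

import tensorflow as tf
from tf_agents.networks import utils  # assumed home of BatchSquash

squash = utils.BatchSquash(batch_dims=2)
x = tf.zeros([4, 5, 3])
flat = squash.flatten(x)           # shape [20, 3]; caches the original shape
restored = squash.unflatten(flat)  # shape [4, 5, 3], rebuilt from the cache
# Calling unflatten before flatten raises ValueError: no shape has been cached.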
Example #4
  def flatten(self, tensor):
    """Flattens and caches the tensor's batch_dims."""
    with tf.name_scope('batch_flatten'):
      if self._batch_dims == 1:
        return tensor

      self._original_tensor_shape = composite.shape(tensor)

      if tensor.shape[self._batch_dims:].is_fully_defined():
        return composite.reshape(
            tensor, [-1] + tensor.shape[self._batch_dims:].as_list())

      return tf.reshape(
          tensor,
          tf.concat([[-1], composite.shape(tensor)[self._batch_dims:]], axis=0),
      )
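
This variant differs from Example #1 only in its fallback path: it builds the target shape dynamically with tf.shape and skips the tf.ensure_shape step, so the collapsed batch dimension stays None in the static shape even when every batch dim is statically known. A plain-TensorFlow sketch of that effect (no TF-Agents types assumed):

import tensorflow as tf

@tf.function(input_signature=[tf.TensorSpec([4, 5, None], tf.float32)])
def flatten_dynamic(t):
  # batch_dims = 2; the feature dim is unknown, so the dynamic fallback runs.
  return tf.reshape(t, tf.concat([[-1], tf.shape(t)[2:]], axis=0))

print(flatten_dynamic.get_concrete_function().output_shapes)
# Prints (None, None): the statically known 4 * 5 = 20 batch size is lost;
# Example #1 recovers it by wrapping the result in tf.ensure_shape.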