Example #1
0
def unrag_tensor(x: tf.RaggedTensor, max_size: int, axis: int) -> tf.Tensor:
    """Converts a ragged tensor to a dense tensor by padding to a maximum size.

    This function is useful for converting ragged tensors to a fixed size when one
    or more of the dimensions are of variable length.

    Args:
        x: Ragged tensor to convert.
        max_size: Size to set the selected axis (or axes) to. Either a scalar, which
            is applied to every entry of `axis`, or a 1-D sequence with the same
            length as `axis`.
        axis: Axis (or axes) of `x` to set to `max_size`. This must specify ragged
            dimensions. Negative indices are counted from the last dimension.

    Returns:
        A dense version of `x`. The dimension(s) at `axis` are replaced by
        `max_size`; the remaining dimensions are taken from `x.bounding_shape()`.
        Missing entries are filled with NaN cast to `x.dtype`. Non-floating dtypes
        cannot represent NaN, so for those the fill value is whatever `tf.cast`
        produces for NaN.
    """
    bounding_shape = x.bounding_shape()

    # `tf.tensor_scatter_nd_update` needs non-negative indices of shape (n, 1).
    axis = tf.cast(axis, tf.int64)
    axis = axis % len(x.shape)  # Handle negative indices.
    axis = tf.reshape(axis, [-1, 1])

    # It also needs one update per index, i.e. shape (n,). Broadcasting lets a
    # scalar `max_size` be used with several axes.
    max_size = tf.cast(max_size, bounding_shape.dtype)
    max_size = tf.reshape(max_size, [-1])
    max_size = tf.broadcast_to(max_size, tf.shape(axis)[:1])

    shape = tf.tensor_scatter_nd_update(bounding_shape, axis, max_size)
    # `np.NaN` was removed in NumPy 2.0; `np.nan` works across NumPy versions.
    return x.to_tensor(default_value=tf.cast(np.nan, x.dtype), shape=shape)
def _truncate_row_lengths(ragged_tensor: tf.RaggedTensor,
                          new_lengths: tf.Tensor) -> tf.RaggedTensor:
  """Shortens each outermost row of `ragged_tensor` to the requested length.

  Args:
    ragged_tensor: Ragged tensor whose outermost rows are sliced.
    new_lengths: Target length for each row. It is broadcast to the number of
      rows, so a scalar applies the same length to every row.

  Returns:
    A ragged tensor whose row `i` is `ragged_tensor[i][0:new_lengths[i]]`.
    Rows already shorter than their target are kept whole.
  """
  num_rows = ragged_tensor.bounding_shape()[0:1]
  row_limits = tf.broadcast_to(new_lengths, num_rows)

  def _slice_row(row_and_limit):
    row, limit = row_and_limit
    return row[0:limit]

  # Each mapped element is one row (outer dimension removed), so its ragged
  # rank is one less than the input's.
  row_spec = tf.RaggedTensorSpec(dtype=ragged_tensor.dtype,
                                 ragged_rank=ragged_tensor.ragged_rank - 1)
  # NOTE(review): newer TF releases prefer `fn_output_signature=` over the
  # deprecated `dtype=`; kept as-is for compatibility with older versions.
  truncated = tf.map_fn(_slice_row, (ragged_tensor, row_limits),
                        dtype=row_spec)

  # Work around broken shape propagation: the map_fn result has unknown rank,
  # so re-assert the flat_values rank taken from the input.
  unknown_dims = [None] * ragged_tensor.flat_values.shape.rank
  return truncated.with_flat_values(
      tf.ensure_shape(truncated.flat_values, unknown_dims))