def unrag_tensor(x: tf.RaggedTensor, max_size: int, axis: int) -> tf.Tensor:
    """Converts a ragged tensor to a full tensor by padding to a maximum size.

    This function is useful for converting ragged tensors to a fixed size when
    one or more of the dimensions are of variable length.

    Args:
        x: Ragged tensor to convert.
        max_size: Maximum size of the axis to pad.
        axis: Axis of `x` to pad to `max_size`. This must specify ragged
            dimensions. Negative values count from the last axis. If more than
            one axis is specified, `max_size` must have the same length as
            `axis`.

    Returns:
        A padded version of `x`. Padding uses the equivalent of NaN in the
        tensor's native dtype. The size of each specified `axis` is replaced
        with the corresponding `max_size`; the remaining dimensions are set to
        the bounding shape of the ragged tensor.

    NOTE(review): for non-floating dtypes (ints, bool) NaN is not
    representable; the fill value is whatever `tf.cast(nan, dtype)` produces,
    which is not a true NaN — confirm callers only pass floating tensors or
    do not depend on the pad value.
    """
    bounding_shape = x.bounding_shape()
    axis = tf.cast(axis, tf.int64)
    axis = axis % len(x.shape)  # Handle negative indices.
    axis = tf.reshape(axis, [-1, 1])  # Ensure (n, 1) shape for indexing.
    max_size = tf.cast(max_size, bounding_shape.dtype)
    max_size = tf.reshape(max_size, [-1])  # Ensure (n,) shape for indexing.
    # Overwrite only the requested dimensions of the bounding shape.
    shape = tf.tensor_scatter_nd_update(bounding_shape, axis, max_size)
    # `np.NaN` was removed in NumPy 2.0; `np.nan` works on every NumPy version.
    return x.to_tensor(default_value=tf.cast(np.nan, x.dtype), shape=shape)
def _truncate_row_lengths(ragged_tensor: tf.RaggedTensor, new_lengths: tf.Tensor) -> tf.RaggedTensor:
    """Shortens each row of `ragged_tensor` to the requested length.

    Args:
        ragged_tensor: The ragged tensor whose rows are truncated.
        new_lengths: Target row lengths. They are broadcast against the number
            of rows (first entry of the bounding shape), so a scalar applies
            the same limit to every row.

    Returns:
        A ragged tensor whose i-th row is `ragged_tensor[i][0:new_lengths[i]]`.
    """
    row_count = ragged_tensor.bounding_shape()[0:1]
    limits = tf.broadcast_to(new_lengths, row_count)
    flat_rank = ragged_tensor.flat_values.shape.rank

    def _slice_row(row_and_limit):
        # map_fn hands each (row, limit) pair over as a single tuple argument.
        row, limit = row_and_limit
        return row[0:limit]

    # Each row has one fewer ragged dimension than the input.
    row_spec = tf.RaggedTensorSpec(
        dtype=ragged_tensor.dtype,
        ragged_rank=ragged_tensor.ragged_rank - 1,
    )
    truncated = tf.map_fn(_slice_row, (ragged_tensor, limits), dtype=row_spec)

    # Work around broken shape propagation: map_fn leaves the result with
    # unknown rank, so re-assert the rank on the flat values.
    pinned_values = tf.ensure_shape(truncated.flat_values, [None] * flat_rank)
    return truncated.with_flat_values(pinned_values)