def __init__(self, partitioned_dim_sizes, inner_dim_sizes,
             dim_size_dtype=None):
  """Creates a RaggedTensorDynamicShape.

  Args:
    partitioned_dim_sizes: A `list` of 0-D or 1-D integer `Tensor`, one for
      each partitioned dimension.  If dimension `d` is uniform, then
      `partitioned_dim_sizes[d]` must be an integer scalar, specifying the
      size of all slices across dimension `d`.  If dimension `d` is ragged,
      then `partitioned_dim_sizes[d]` must be an integer vector, specifying
      the size of each slice across dimension `d`.
    inner_dim_sizes: A 1-D integer `Tensor`, whose length is equal to the
      number of inner dimensions.  `inner_dim_sizes[n]` is the size of all
      slices across the `n`th inner dimension (which is the
      `(len(partitioned_dim_sizes)+n)`th dimension in the overall tensor).
    dim_size_dtype: dtype for dimension sizes.  If not specified, then it is
      chosen based on the dtypes of `partitioned_dim_sizes` and
      `inner_dim_sizes`.
  """
  assert isinstance(partitioned_dim_sizes, (list, tuple))

  with ops.name_scope(None, 'RaggedTensorDynamicShape',
                      (partitioned_dim_sizes, inner_dim_sizes)):
    partitioned_dim_sizes = tuple(
        ops.convert_to_tensor(size, name='partitioned_dimension_size_%d' % i)
        for (i, size) in enumerate(partitioned_dim_sizes))
    inner_dim_sizes = ops.convert_to_tensor(
        inner_dim_sizes, name='inner_dim_sizes')

    # Validate shapes.
    if partitioned_dim_sizes:
      for axis, dimension_size in enumerate(partitioned_dim_sizes):
        if dimension_size.shape.ndims is None:
          raise ValueError(
              'rank of partitioned_dim_sizes[%d] is unknown' % axis)
        dimension_size.shape.with_rank_at_most(1)
      if partitioned_dim_sizes[0].shape.ndims == 1:
        raise ValueError('outermost partitioned dimension must be uniform')
      if partitioned_dim_sizes[-1].shape.ndims == 0:
        raise ValueError('innermost partitioned dimension must be ragged')
    inner_dim_sizes.shape.assert_has_rank(1)

    # Convert dimension size tensors to a single dtype.
    if dim_size_dtype is None:
      dim_size_dtypes = set([p.dtype for p in partitioned_dim_sizes
                             if p.shape.ndims == 1])
      if not dim_size_dtypes:
        dim_size_dtype = dtypes.int64
      elif len(dim_size_dtypes) == 1:
        dim_size_dtype = dim_size_dtypes.pop()
      else:
        if not ragged_config.auto_cast_partition_dtype():
          raise ValueError('partitioned_dim_sizes must have matching dtypes')
        dim_size_dtype = dtypes.int64
    partitioned_dim_sizes = tuple(math_ops.cast(p, dim_size_dtype)
                                  for p in partitioned_dim_sizes)
    inner_dim_sizes = math_ops.cast(inner_dim_sizes, dim_size_dtype)

    self._partitioned_dim_sizes = partitioned_dim_sizes
    self._inner_dim_sizes = inner_dim_sizes
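# A minimal construction sketch, assuming the enclosing RaggedTensorDynamicShape
# class and this module's imports: the shape of [[1, 2, 3], [4]] has one
# uniform outer dimension (size 2) and one ragged dimension with row lengths
# [3, 1].  Python ints/lists are accepted, since the constructor converts them
# to tensors.
shape = RaggedTensorDynamicShape(
    partitioned_dim_sizes=[2, [3, 1]],  # scalar = uniform, vector = ragged
    inner_dim_sizes=[])                 # no trailing uniform dimensions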
def _convert_declared_ragged(current, declared):
  """Converts an output with RaggedTensorType into a _RaggedTensorComponents."""
  # Check that the ragged ranks match up.
  # + 1 to account for the rank of the outermost dimension.
  current_ragged_rank = getattr(current, "ragged_rank", 0)
  if declared.ragged_rank != current_ragged_rank + 1:
    raise ValueError(
        "The declared ragged rank (%d) mismatches the result (%d)" %
        (declared.ragged_rank, current_ragged_rank + 1))

  # Check that dtypes match up.
  if declared.dtype != current.dtype:
    raise ValueError(
        "The declared dtype (%s) mismatches the result (%s)" %
        (declared.dtype, current.dtype))
  if (isinstance(current, ragged_tensor.RaggedTensor) and
      declared.row_splits_dtype != current.row_splits.dtype):
    if not ragged_config.auto_cast_partition_dtype():
      raise ValueError(
          "The declared row_splits dtype (%s) mismatches the result (%s)."
          " Use RaggedTensor.with_row_splits_dtype to convert it." %
          (declared.row_splits_dtype, current.row_splits.dtype))
    current = current.with_row_splits_dtype(declared.row_splits_dtype)

  if isinstance(current, ragged_tensor.RaggedTensor):
    return current
  else:
    nrows = array_ops.shape(current, out_type=declared.row_splits_dtype)[0]
    row_length = array_ops.expand_dims(nrows, axis=0)
    return _RaggedTensorComponents(
        flat_values=current,
        nested_row_lengths=(),
        outer_row_length=row_length)
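# The dense fall-through branch, sketched with the public TF 2.x API (the
# module-private _RaggedTensorComponents wrapper just records these pieces):
# a dense `current` with nrows rows becomes a single outer row of length nrows.
import tensorflow as tf

current = tf.constant([[1, 2], [3, 4], [5, 6]])   # dense op output
nrows = tf.shape(current, out_type=tf.int64)[0]   # 3
outer_row_length = tf.expand_dims(nrows, axis=0)  # [3]
rt = tf.RaggedTensor.from_row_lengths(current, outer_row_length)
print(rt)  # <tf.RaggedTensor [[[1, 2], [3, 4], [5, 6]]]>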
def _broadcast_to_ragged_shape(rt_input, dst_shape, broadcast_inner_dimensions):
  """Broadcasts rt_input to the ragged shape `dst_shape`."""
  # Check that rt_input and dst_shape have the same row_splits dtype.
  if (isinstance(rt_input, ragged_tensor.RaggedTensor) and
      rt_input.row_splits.dtype != dst_shape.dim_size_dtype):
    if not ragged_config.auto_cast_partition_dtype():
      raise ValueError('rt_input and dst_shape have different row_split '
                       'dtypes; use RaggedTensor.with_row_splits_dtype() or '
                       'RaggedTensorDynamicShape.with_dim_size_dtype() to '
                       'convert to a compatible dtype.')
    rt_input = rt_input.with_row_splits_dtype(dtypes.int64)
    dst_shape = dst_shape.with_dim_size_dtype(dtypes.int64)

  # dst_shape's rank and ragged_rank must be greater than or equal to
  # rt_input's.
  if rt_input.shape.ndims is None or dst_shape.rank is None:
    raise ValueError('Unable to broadcast: unknown rank')
  if rt_input.shape.ndims > dst_shape.rank:
    raise ValueError('Incompatible with shape: rank mismatch')
  if (isinstance(rt_input, ragged_tensor.RaggedTensor) and
      rt_input.ragged_rank >= dst_shape.num_partitioned_dimensions):
    raise ValueError('Incompatible with shape: ragged rank mismatch')

  src_shape = RaggedTensorDynamicShape.from_tensor(rt_input)
  src_shape = src_shape.broadcast_to_rank(dst_shape.rank)

  # Add dimensions to rt_input so its rank and ragged_rank matches dst_shape.
  if dst_shape.rank > rt_input.shape.ndims:
    if rt_input.shape.ndims < dst_shape.num_inner_dimensions + 1:
      rt_input = array_ops.reshape(
          rt_input, array_ops.concat([[-1], dst_shape.inner_dim_sizes],
                                     axis=0))
    for _ in range(dst_shape.rank - rt_input.shape.ndims):
      if ragged_tensor.is_ragged(rt_input):
        nrows = rt_input.nrows()
      else:
        nrows = array_ops.shape(rt_input,
                                out_type=dst_shape.dim_size_dtype)[0]
      rt_input = ragged_tensor.RaggedTensor.from_row_lengths(
          rt_input, [nrows], validate=False)

  # Add ragged dimensions to match dst_shape.
  if ragged_tensor.is_ragged(rt_input):
    inner_rank_diff = (
        rt_input.flat_values.shape.ndims - 1 - dst_shape.num_inner_dimensions)
    if inner_rank_diff > 0:
      rt_input = rt_input.with_flat_values(
          ragged_tensor.RaggedTensor.from_tensor(
              rt_input.flat_values,
              ragged_rank=inner_rank_diff,
              row_splits_dtype=dst_shape.dim_size_dtype))
  else:
    rt_input = ragged_tensor.RaggedTensor.from_tensor(
        rt_input,
        ragged_rank=dst_shape.num_partitioned_dimensions - 1,
        row_splits_dtype=dst_shape.dim_size_dtype)

  # Do broadcasting for any dimensions that will remain uniform.  We can do
  # these all at once, since they're independent of one another.
  multiples = [1] * dst_shape.rank
  for axis in range(dst_shape.num_partitioned_dimensions):
    if not src_shape.is_ragged(axis) and not dst_shape.is_ragged(axis):
      src_size = src_shape.dimension_size(axis)
      dst_size = dst_shape.dimension_size(axis)
      if ((tensor_util.constant_value(src_size) in (1, None)) and
          (tensor_util.constant_value(dst_size) != 1)):
        multiples[axis] = array_ops.where(
            math_ops.equal(src_size, 1), dst_size, 1)
  if not all(isinstance(v, int) and v == 1 for v in multiples):
    multiples = array_ops.stack(multiples, axis=0)
    rt_input = ragged_array_ops.tile(rt_input, multiples)

  if broadcast_inner_dimensions:
    rt_input = rt_input.with_flat_values(
        array_ops.reshape(
            rt_input.flat_values,
            array_ops.concat([[-1], dst_shape.inner_dim_sizes], axis=0)))

  # Do broadcasting for dimensions that become ragged.  We must do these from
  # outermost to innermost.
  for axis in range(dst_shape.num_partitioned_dimensions):
    if not src_shape.is_ragged(axis) and dst_shape.is_ragged(axis):
      dst_size = dst_shape.dimension_size(axis)
      rt_input = _ragged_tile_axis(rt_input, axis, dst_size,
                                   dst_shape.dim_size_dtype)

  return rt_input
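# A usage sketch of the broadcasting this helper implements, reached through
# public elementwise ops in TF 2.x: a dense [3, 1] tensor broadcasts across
# each ragged row.
import tensorflow as tf

rt = tf.ragged.constant([[1, 2, 3], [], [4, 5]])
bias = tf.constant([[10], [20], [30]])
print(rt + bias)  # <tf.RaggedTensor [[11, 12, 13], [], [34, 35]]>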
def map_flat_values(op, *args, **kwargs):
  """Applies `op` to the values of one or more RaggedTensors.

  Replaces any `RaggedTensor` in `args` or `kwargs` with its `flat_values`
  tensor, and then calls `op`.  Returns a `RaggedTensor` that is constructed
  from the input `RaggedTensor`s' `nested_row_splits` and the value returned by
  the `op`.

  If the input arguments contain multiple `RaggedTensor`s, then they must have
  identical `nested_row_splits`.

  Examples:

  ```python
  >>> rt = ragged.constant([[1, 2, 3], [], [4, 5], [6]])
  >>> ragged.map_flat_values(tf.ones_like, rt).eval().tolist()
  [[1, 1, 1], [], [1, 1], [1]]
  >>> ragged.map_flat_values(tf.multiply, rt, rt).eval().tolist()
  [[1, 4, 9], [], [16, 25], [36]]
  >>> ragged.map_flat_values(tf.add, rt, 5).eval().tolist()
  [[6, 7, 8], [], [9, 10], [11]]
  ```

  Args:
    op: The operation that should be applied to the RaggedTensor `flat_values`.
      `op` is typically an element-wise operation (such as math_ops.add), but
      any operation that preserves the size of the outermost dimension can be
      used.  I.e., `shape[0]` of the value returned by `op` must match
      `shape[0]` of the `RaggedTensor`s' `flat_values` tensors.
    *args: Arguments for `op`.
    **kwargs: Keyword arguments for `op`.

  Returns:
    A `RaggedTensor` whose `ragged_rank` matches the `ragged_rank` of all
    input `RaggedTensor`s.

  Raises:
    ValueError: If args contains no `RaggedTensors`, or if the `nested_splits`
      of the input `RaggedTensor`s are not identical.
  """
  # Replace RaggedTensors with their values; and collect the splits tensors
  # from each RaggedTensor.
  nested_splits_lists = []
  inner_args = _replace_ragged_with_flat_values(args, nested_splits_lists)
  inner_kwargs = _replace_ragged_with_flat_values(kwargs, nested_splits_lists)
  if not nested_splits_lists:
    return op(*args, **kwargs)

  split_dtypes = set(splits[0].dtype for splits in nested_splits_lists)
  if len(split_dtypes) > 1:
    if not ragged_config.auto_cast_partition_dtype():
      raise ValueError("Input RaggedTensors have mismatched row_splits "
                       "dtypes; use RaggedTensor.with_row_splits_dtype() to "
                       "convert them to compatible dtypes.")
    nested_splits_lists = [
        [math_ops.cast(s, dtypes.int64) for s in nested_splits]  # pylint: disable=g-complex-comprehension
        for nested_splits in nested_splits_lists
    ]

  with ops.control_dependencies(
      ragged_util.assert_splits_match(nested_splits_lists)):
    # Delegate to op, and then compose the result from the transformed values
    # and the splits.
    return ragged_tensor.RaggedTensor.from_nested_row_splits(
        op(*inner_args, **inner_kwargs), nested_splits_lists[0],
        validate=False)
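# Caveat for this revision: with validate=False and no output-shape check, an
# op that changes the outer dimension of flat_values (e.g. a reduction) is not
# caught here, and the row splits no longer describe the returned values.  A
# minimal sketch of the hazard, using the public TF 2.x API:
import tensorflow as tf

rt = tf.ragged.constant([[1, 2], [3]])
# tf.ragged.map_flat_values(tf.reduce_sum, rt)  # collapses flat_values to a
#                                               # scalar, so shape[0] no longer
#                                               # matches the row splits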
def map_flat_values(op, *args, **kwargs):
  """Applies `op` to the `flat_values` of one or more RaggedTensors.

  Replaces any `RaggedTensor` in `args` or `kwargs` with its `flat_values`
  tensor (which collapses all ragged dimensions), and then calls `op`.  Returns
  a `RaggedTensor` that is constructed from the input `RaggedTensor`s'
  `nested_row_splits` and the value returned by the `op`.

  If the input arguments contain multiple `RaggedTensor`s, then they must have
  identical `nested_row_splits`.

  This operation is generally used to apply elementwise operations to each
  value in a `RaggedTensor`.

  Warning: `tf.ragged.map_flat_values` does *not* apply `op` to each row of a
  ragged tensor.  This difference is important for non-elementwise operations,
  such as `tf.reduce_sum`.  If you wish to apply a non-elementwise operation to
  each row of a ragged tensor, use `tf.map_fn` instead.  (You may need to
  specify an `output_signature` when using `tf.map_fn` with ragged tensors.)

  Examples:

  >>> rt = tf.ragged.constant([[1, 2, 3], [], [4, 5], [6]])
  >>> tf.ragged.map_flat_values(tf.ones_like, rt)
  <tf.RaggedTensor [[1, 1, 1], [], [1, 1], [1]]>
  >>> tf.ragged.map_flat_values(tf.multiply, rt, rt)
  <tf.RaggedTensor [[1, 4, 9], [], [16, 25], [36]]>
  >>> tf.ragged.map_flat_values(tf.add, rt, 5)
  <tf.RaggedTensor [[6, 7, 8], [], [9, 10], [11]]>

  Example with a non-elementwise operation (note that `map_flat_values` and
  `map_fn` return different results):

  >>> rt = tf.ragged.constant([[1.0, 3.0], [], [3.0, 6.0, 3.0]])
  >>> def normalized(x):
  ...   return x / tf.reduce_sum(x)
  >>> tf.ragged.map_flat_values(normalized, rt)
  <tf.RaggedTensor [[0.0625, 0.1875], [], [0.1875, 0.375, 0.1875]]>
  >>> tf.map_fn(normalized, rt)
  <tf.RaggedTensor [[0.25, 0.75], [], [0.25, 0.5, 0.25]]>

  Args:
    op: The operation that should be applied to the RaggedTensor `flat_values`.
      `op` is typically an element-wise operation (such as math_ops.add), but
      any operation that preserves the size of the outermost dimension can be
      used.  I.e., `shape[0]` of the value returned by `op` must match
      `shape[0]` of the `RaggedTensor`s' `flat_values` tensors.
    *args: Arguments for `op`.
    **kwargs: Keyword arguments for `op`.

  Returns:
    A `RaggedTensor` whose `ragged_rank` matches the `ragged_rank` of all
    input `RaggedTensor`s.

  Raises:
    ValueError: If args contains no `RaggedTensors`, or if the `nested_splits`
      of the input `RaggedTensor`s are not identical.
  """
  # Replace RaggedTensors with their values; and collect the partitions tensors
  # from each RaggedTensor.
  partition_lists = []
  flat_values_nrows = []
  inner_args = _replace_ragged_with_flat_values(args, partition_lists,
                                                flat_values_nrows)
  inner_kwargs = _replace_ragged_with_flat_values(kwargs, partition_lists,
                                                  flat_values_nrows)
  if not partition_lists:
    return op(*args, **kwargs)

  # If we can statically determine that the inputs are incompatible, then raise
  # an error.  (We can't guarantee full compatibility statically, so we need to
  # perform some runtime checks too; but this allows us to fail sooner in some
  # cases.)
  if flat_values_nrows:
    flat_values_nrows = set(flat_values_nrows)
    if len(flat_values_nrows) != 1:
      raise ValueError("Input RaggedTensors' flat_values must all have the "
                       "same outer-dimension size.  Got sizes: %s" %
                       flat_values_nrows)
    flat_values_nrows = flat_values_nrows.pop()  # Get the single element
  else:
    flat_values_nrows = None

  partition_dtypes = set(p[0].dtype for p in partition_lists)
  if len(partition_dtypes) > 1:
    if not ragged_config.auto_cast_partition_dtype():
      raise ValueError("Input RaggedTensors have mismatched row partition "
                       "dtypes; use RaggedTensor.with_row_splits_dtype() to "
                       "convert them to compatible dtypes.")
    partition_lists = [
        [p.with_dtype(dtypes.int64) for p in partition_list]  # pylint: disable=g-complex-comprehension
        for partition_list in partition_lists
    ]

  # Delegate to `op`
  op_output = op(*inner_args, **inner_kwargs)
  # Check that the result has the expected shape (if known).
  if flat_values_nrows is not None:
    if not op_output.shape[:1].is_compatible_with([flat_values_nrows]):
      raise ValueError(
          "tf.ragged.map_flat_values requires that the output of `op` have "
          "the same outer-dimension size as flat_values of any ragged "
          "inputs. (output shape: %s; expected outer dimension size: %s)" %
          (op_output.shape, flat_values_nrows))
  # Compose the result from the transformed values and the partitions.
  return ragged_tensor.RaggedTensor._from_nested_row_partitions(  # pylint: disable=protected-access
      op_output, _merge_partition_lists(partition_lists), validate=False)
def map_flat_values(op, *args, **kwargs):
  """Applies `op` to the values of one or more RaggedTensors.

  Replaces any `RaggedTensor` in `args` or `kwargs` with its `flat_values`
  tensor, and then calls `op`.  Returns a `RaggedTensor` that is constructed
  from the input `RaggedTensor`s' `nested_row_splits` and the value returned by
  the `op`.

  If the input arguments contain multiple `RaggedTensor`s, then they must have
  identical `nested_row_splits`.

  Examples:

  >>> rt = tf.ragged.constant([[1, 2, 3], [], [4, 5], [6]])
  >>> map_flat_values(tf.ones_like, rt).to_list()
  [[1, 1, 1], [], [1, 1], [1]]
  >>> map_flat_values(tf.multiply, rt, rt).to_list()
  [[1, 4, 9], [], [16, 25], [36]]
  >>> map_flat_values(tf.add, rt, 5).to_list()
  [[6, 7, 8], [], [9, 10], [11]]

  Args:
    op: The operation that should be applied to the RaggedTensor `flat_values`.
      `op` is typically an element-wise operation (such as math_ops.add), but
      any operation that preserves the size of the outermost dimension can be
      used.  I.e., `shape[0]` of the value returned by `op` must match
      `shape[0]` of the `RaggedTensor`s' `flat_values` tensors.
    *args: Arguments for `op`.
    **kwargs: Keyword arguments for `op`.

  Returns:
    A `RaggedTensor` whose `ragged_rank` matches the `ragged_rank` of all
    input `RaggedTensor`s.

  Raises:
    ValueError: If args contains no `RaggedTensors`, or if the `nested_splits`
      of the input `RaggedTensor`s are not identical.
  """
  # Replace RaggedTensors with their values; and collect the splits tensors
  # from each RaggedTensor.
  nested_splits_lists = []
  flat_values_nrows = []
  inner_args = _replace_ragged_with_flat_values(args, nested_splits_lists,
                                                flat_values_nrows)
  inner_kwargs = _replace_ragged_with_flat_values(kwargs, nested_splits_lists,
                                                  flat_values_nrows)
  if not nested_splits_lists:
    return op(*args, **kwargs)

  if flat_values_nrows:
    flat_values_nrows = set(flat_values_nrows)
    if len(flat_values_nrows) != 1:
      raise ValueError("Input RaggedTensors' flat_values must all have the "
                       "same outer-dimension size.  Got sizes: %s" %
                       flat_values_nrows)
    flat_values_nrows = flat_values_nrows.pop()  # Get the single element
  else:
    flat_values_nrows = None

  split_dtypes = set(splits[0].dtype for splits in nested_splits_lists)
  if len(split_dtypes) > 1:
    if not ragged_config.auto_cast_partition_dtype():
      raise ValueError("Input RaggedTensors have mismatched row_splits "
                       "dtypes; use RaggedTensor.with_row_splits_dtype() to "
                       "convert them to compatible dtypes.")
    nested_splits_lists = [
        [math_ops.cast(s, dtypes.int64) for s in nested_splits]  # pylint: disable=g-complex-comprehension
        for nested_splits in nested_splits_lists
    ]

  with ops.control_dependencies(
      ragged_util.assert_splits_match(nested_splits_lists)):
    # Delegate to `op`
    op_output = op(*inner_args, **inner_kwargs)
    # Check that the result has the expected shape (if known).
    if flat_values_nrows is not None:
      if not op_output.shape[:1].is_compatible_with([flat_values_nrows]):
        raise ValueError(
            "tf.ragged.map_flat_values requires that the output of `op` have "
            "the same outer-dimension size as flat_values of any ragged "
            "inputs. (output shape: %s; expected outer dimension size: %s)" %
            (op_output.shape, flat_values_nrows))
    # Compose the result from the transformed values and the splits.
    return ragged_tensor.RaggedTensor.from_nested_row_splits(
        op_output, nested_splits_lists[0], validate=False)
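# The static check added in this revision, triggered through the public API:
# two inputs whose flat_values have different outer-dimension sizes now fail
# at graph-construction time instead of producing a malformed result.
import tensorflow as tf

x = tf.ragged.constant([[1, 2], [3]])     # flat_values has 3 elements
y = tf.ragged.constant([[1], [2, 3, 4]])  # flat_values has 4 elements
# tf.ragged.map_flat_values(tf.add, x, y)  # ValueError: flat_values must all
#                                          # have the same outer-dimension size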