def where_v2(condition, x=None, y=None, name=None): """Return the elements where `condition` is `True`. : If both `x` and `y` are None: Retrieve indices of true elements. Returns the coordinates of true elements of `condition`. The coordinates are returned in a 2-D tensor with shape `[num_true_values, rank(condition)]`, where `result[i]` is the coordinates of the `i`th true value (in row-major order). : If both `x` and `y` are non-`None`: Multiplex between `x` and `y`. Choose an output shape from the shapes of `condition`, `x`, and `y` that all three shapes are broadcastable to; and then use the broadcasted `condition` tensor as a mask that chooses whether the corresponding element in the output should be taken from `x` (if `condition` is true) or `y` (if `condition` is false). >>> # Example: retrieve indices of true elements >>> tf.where(tf.ragged.constant([[True, False], [True]])) <tf.Tensor: shape=(2, 2), dtype=int64, numpy= array([[0, 0], [1, 0]])> >>> # Example: multiplex between `x` and `y` >>> tf.where(tf.ragged.constant([[True, False], [True, False, True]]), ... tf.ragged.constant([['A', 'B'], ['C', 'D', 'E']]), ... tf.ragged.constant([['a', 'b'], ['c', 'd', 'e']])) <tf.RaggedTensor [[b'A', b'b'], [b'C', b'd', b'E']]> Args: condition: A potentially ragged tensor of type `bool`. x: A potentially ragged tensor (optional). y: A potentially ragged tensor (optional). Must be specified if `x` is specified. Must have the same rank and type as `x`. name: A name for the operation (optional). Returns: : If both `x` and `y` are `None`: A `Tensor` with shape `(num_true, rank(condition))`. : Otherwise: A potentially ragged tensor with the same type as `x` and `y`, and whose shape is broadcast-compatible with `x`, `y`, and `condition`. Raises: ValueError: When exactly one of `x` or `y` is non-`None`; or when `condition`, `x`, and `y` have incompatible shapes. """ if (x is None) != (y is None): raise ValueError('x and y must be either both None or both non-None') with ops.name_scope('RaggedWhere', name, [condition, x, y]): condition = ragged_tensor.convert_to_tensor_or_ragged_tensor( condition, name='condition') if x is None: return _coordinate_where(condition) else: x = ragged_tensor.convert_to_tensor_or_ragged_tensor(x, name='x') y = ragged_tensor.convert_to_tensor_or_ragged_tensor(y, name='y') condition, x, y = ragged_tensor.match_row_splits_dtypes( condition, x, y) return _elementwise_where_v2(condition, x, y)
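For concreteness, here is a minimal sketch of the two modes above (assuming a TF 2.x build where `tf.where` dispatches to this ragged handler; the commented outputs follow the semantics described in the docstring):

```python
import tensorflow as tf

cond = tf.ragged.constant([[True, False], [True, False, True]])
x = tf.ragged.constant([[1, 2], [3, 4, 5]])
y = tf.ragged.constant([[-1, -2], [-3, -4, -5]])

# Coordinate mode: no `x`/`y` -> dense [num_true, rank(cond)] index tensor.
print(tf.where(cond))        # [[0, 0], [1, 0], [1, 2]]

# Multiplex mode: both `x` and `y` -> ragged output with the same shape.
print(tf.where(cond, x, y))  # [[1, -2], [3, -4, 5]]

# Supplying exactly one of `x`/`y` raises a ValueError, per the check above.
```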
def handle(self, args, kwargs): # Extract the binary args. if len(args) > 1: x = args[0] y = args[1] args = args[2:] elif args: kwargs = kwargs.copy() x = args[0] y = kwargs.pop(self._y, None) args = args[1:] else: kwargs = kwargs.copy() x = kwargs.pop(self._x, None) y = kwargs.pop(self._y, None) # Bail if we don't have at least one ragged argument. x_is_ragged = ragged_tensor.is_ragged(x) y_is_ragged = ragged_tensor.is_ragged(y) if not (x_is_ragged or y_is_ragged): return self.NOT_SUPPORTED # Convert args to tensors. Bail if conversion fails. try: x = ragged_tensor.convert_to_tensor_or_ragged_tensor( x, name=self._x, preferred_dtype=(y.dtype if y_is_ragged else None)) y = ragged_tensor.convert_to_tensor_or_ragged_tensor( y, name=self._y, preferred_dtype=(x.dtype if x_is_ragged else None)) except (TypeError, ValueError): return self.NOT_SUPPORTED if x_is_ragged and y_is_ragged: x, y = ragged_tensor.match_row_splits_dtypes(x, y) if ((x_is_ragged and y_is_ragged) or (x_is_ragged and x.flat_values.shape.ndims <= y.shape.ndims) or (y_is_ragged and y.flat_values.shape.ndims <= x.shape.ndims)): bcast_shape = ragged_tensor_shape.broadcast_dynamic_shape( ragged_tensor_shape.RaggedTensorDynamicShape.from_tensor(x), ragged_tensor_shape.RaggedTensorDynamicShape.from_tensor(y)) x = ragged_tensor_shape.broadcast_to( x, bcast_shape, broadcast_inner_dimensions=False) y = ragged_tensor_shape.broadcast_to( y, bcast_shape, broadcast_inner_dimensions=False) x_values = x.flat_values if ragged_tensor.is_ragged(x) else x y_values = y.flat_values if ragged_tensor.is_ragged(y) else y mapped_values = self._original_op(x_values, y_values, *args, **kwargs) if ragged_tensor.is_ragged(x): return x.with_flat_values(mapped_values) else: return y.with_flat_values(mapped_values)
def handle(self, args, kwargs): # Extract the binary args. if len(args) > 1: x = args[0] y = args[1] args = args[2:] elif args: kwargs = kwargs.copy() x = args[0] y = kwargs.pop(self._y, None) args = args[1:] else: kwargs = kwargs.copy() x = kwargs.pop(self._x, None) y = kwargs.pop(self._y, None) # Bail if we don't have at least one ragged argument. x_is_ragged = ragged_tensor.is_ragged(x) y_is_ragged = ragged_tensor.is_ragged(y) if not (x_is_ragged or y_is_ragged): return self.NOT_SUPPORTED # Convert args to tensors. Bail if conversion fails. try: if not x_is_ragged: x = ops.convert_to_tensor(x, name=self._x, preferred_dtype=y.dtype) if not y_is_ragged: y = ops.convert_to_tensor(y, name=self._y, preferred_dtype=x.dtype) except (TypeError, ValueError): return self.NOT_SUPPORTED if x_is_ragged and y_is_ragged: x, y = ragged_tensor.match_row_splits_dtypes(x, y) if ((x_is_ragged and y_is_ragged) or (x_is_ragged and x.flat_values.shape.ndims <= y.shape.ndims) or (y_is_ragged and y.flat_values.shape.ndims <= x.shape.ndims)): bcast_shape = ragged_tensor_shape.broadcast_dynamic_shape( ragged_tensor_shape.RaggedTensorDynamicShape.from_tensor(x), ragged_tensor_shape.RaggedTensorDynamicShape.from_tensor(y)) x = ragged_tensor_shape.broadcast_to( x, bcast_shape, broadcast_inner_dimensions=False) y = ragged_tensor_shape.broadcast_to( y, bcast_shape, broadcast_inner_dimensions=False) x_values = x.flat_values if ragged_tensor.is_ragged(x) else x y_values = y.flat_values if ragged_tensor.is_ragged(y) else y mapped_values = self._original_op(x_values, y_values, *args, **kwargs) if ragged_tensor.is_ragged(x): return x.with_flat_values(mapped_values) else: return y.with_flat_values(mapped_values)
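Both `handle` variants above implement the same recipe: pop the binary args, bail out unless at least one is ragged, broadcast as needed, then apply the wrapped op to the `flat_values` and reattach the row splits. A minimal sketch of the observable behavior (assuming the standard ragged dispatch registrations for `tf.add` and `__mul__`):

```python
import tensorflow as tf

rt = tf.ragged.constant([[1, 2], [3]])

# A scalar (or any convertible value) takes the non-ragged branch: the op
# is applied to rt.flat_values and the row splits are reused unchanged.
print(tf.add(rt, 10))   # [[11, 12], [13]]

# Two ragged operands must have compatible row splits; the op is applied
# to their flat_values and the splits are reattached.
rt2 = tf.ragged.constant([[10, 20], [30]])
print(rt * rt2)         # [[10, 40], [90]]
```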
def handle(self, args, kwargs): if args: x, args = args[0], args[1:] else: kwargs = kwargs.copy() x = kwargs.pop(self._x, None) if x is None: return self.NOT_SUPPORTED if self._arg_is_list: found_ragged = False for elt in x: if ragged_tensor.is_ragged(elt): found_ragged = True elif not _is_convertible_to_tensor(elt): return self.NOT_SUPPORTED if found_ragged: x = [ ragged_tensor.convert_to_tensor_or_ragged_tensor(elt) if ragged_tensor.is_ragged(elt) else elt for elt in x ] x = ragged_tensor.match_row_splits_dtypes(*x) ragged_elts = [ elt for elt in x if ragged_tensor.is_ragged(elt) ] nested_splits_lists = [ elt.nested_row_splits for elt in ragged_elts ] flat_values = [ elt.flat_values if ragged_tensor.is_ragged(elt) else elt for elt in x ] with ops.control_dependencies( ragged_util.assert_splits_match(nested_splits_lists)): return ragged_elts[0].with_flat_values( self._original_op(flat_values, *args, **kwargs)) else: return self.NOT_SUPPORTED else: found_ragged = ragged_tensor.is_ragged(x) if found_ragged: x = ragged_tensor.convert_to_tensor_or_ragged_tensor( x, name=self._x) mapped_values = self._original_op(x.flat_values, *args, **kwargs) return x.with_flat_values(mapped_values) else: return self.NOT_SUPPORTED
def ragged_binary_elementwise_op(op, x, y): """Binary elementwise API handler for RaggedTensors.""" x_is_ragged = ragged_tensor.is_ragged(x) y_is_ragged = ragged_tensor.is_ragged(y) # Convert args to tensors. x = ragged_tensor.convert_to_tensor_or_ragged_tensor( x, preferred_dtype=(y.dtype if y_is_ragged else None)) y = ragged_tensor.convert_to_tensor_or_ragged_tensor( y, preferred_dtype=x.dtype) if x_is_ragged and y_is_ragged: x, y = ragged_tensor.match_row_splits_dtypes(x, y) # Perform broadcasting, when appropriate. if ((x_is_ragged and y_is_ragged) or (x_is_ragged and x.flat_values.shape.ndims <= y.shape.ndims) or (y_is_ragged and y.flat_values.shape.ndims <= x.shape.ndims)): # If both x and y are ragged, they must have the same row_splits_dtype now. if x_is_ragged: dim_size_dtype = x.row_splits.dtype else: dim_size_dtype = y.row_splits.dtype shape_x = ragged_tensor_shape.RaggedTensorDynamicShape.from_tensor( x, dim_size_dtype=dim_size_dtype) shape_y = ragged_tensor_shape.RaggedTensorDynamicShape.from_tensor( y, dim_size_dtype=dim_size_dtype) bcast_shape = ragged_tensor_shape.broadcast_dynamic_shape( shape_x, shape_y) x = ragged_tensor_shape.broadcast_to(x, bcast_shape, broadcast_inner_dimensions=False) y = ragged_tensor_shape.broadcast_to(y, bcast_shape, broadcast_inner_dimensions=False) x_values = x.flat_values if ragged_tensor.is_ragged(x) else x y_values = y.flat_values if ragged_tensor.is_ragged(y) else y mapped_values = op(x_values, y_values) if isinstance(mapped_values, bool): return mapped_values # Special case for tensor_equals. if ragged_tensor.is_ragged(x): return x.with_flat_values(mapped_values) else: return y.with_flat_values(mapped_values)
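A short sketch of the broadcasting condition above: the dynamic-shape broadcast path is only taken when both operands are ragged, or when the dense operand has enough dimensions to reach into the ragged one. (This assumes TF 2.x operator dispatch for ragged tensors; the commented outputs are what the broadcast rules imply.)

```python
import tensorflow as tf

rt = tf.ragged.constant([[1, 2], [3]])

# A dense [2, 1] operand broadcasts across each ragged row, exercising the
# RaggedTensorDynamicShape path above.
row_bias = tf.constant([[10], [20]])
print(rt + row_bias)   # [[11, 12], [23]]

# A scalar skips broadcasting entirely: the op is applied directly to
# rt.flat_values and the row splits are reused.
print(rt + 1)          # [[2, 3], [4]]
```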
def handle(self, args, kwargs): if args: x, args = args[0], args[1:] else: kwargs = kwargs.copy() x = kwargs.pop(self._x, None) if x is None: return self.NOT_SUPPORTED if self._arg_is_list: found_ragged = False for elt in x: if ragged_tensor.is_ragged(elt): found_ragged = True elif not _is_convertible_to_tensor(elt): return self.NOT_SUPPORTED if found_ragged: x = ragged_tensor.match_row_splits_dtypes(*x) nested_splits_lists = [ elt.nested_row_splits for elt in x if ragged_tensor.is_ragged(elt) ] flat_values = [ elt.flat_values if ragged_tensor.is_ragged(elt) else elt for elt in x ] with ops.control_dependencies( ragged_util.assert_splits_match(nested_splits_lists)): return ragged_tensor.RaggedTensor.from_nested_row_splits( self._original_op(flat_values, *args, **kwargs), nested_splits_lists[0], validate=False) else: return self.NOT_SUPPORTED else: found_ragged = ragged_tensor.is_ragged(x) if found_ragged: mapped_values = self._original_op(x.flat_values, *args, **kwargs) return x.with_flat_values(mapped_values) else: return self.NOT_SUPPORTED
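The unary and list-valued `handle` variants above share one idea: run the wrapped op on `flat_values` and reattach the splits, asserting in the list case that all ragged inputs agree on their splits. A minimal sketch (assuming ragged dispatch is registered for `tf.sqrt` and `tf.add_n`, as in stock TF):

```python
import tensorflow as tf

rt = tf.ragged.constant([[1.0, 4.0], [9.0]])

# Unary case: the op runs on rt.flat_values ([1., 4., 9.]) and the
# row splits are reattached unchanged.
print(tf.sqrt(rt))             # [[1.0, 2.0], [3.0]]

# List case: every ragged input must have matching row splits, which the
# handler asserts before combining the flat_values.
print(tf.add_n([rt, rt, rt]))  # [[3.0, 12.0], [27.0]]
```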
def _ragged_segment_aggregate(unsorted_segment_op, data, segment_ids, num_segments, separator=None, name=None): """Aggregates along segments of a RaggedTensor using `unsorted_segment_op`. Returns a RaggedTensor `output` with `num_segments` rows, where the row `output[i]` is formed by combining all rows of `data` whose corresponding `segment_id` is `i`. The values in each row are combined using `unsorted_segment_op`. The length of the row `output[i]` will be the maximum of the lengths of all rows of `data` whose corresponding `segment_id` is `i`. If no `data` rows correspond to a given segment ID, then the output row for that segment ID will be empty. Args: unsorted_segment_op: The tensorflow `op` that should be used to combine values in each row. Must have the same signature and basic behavior as `unsorted_segment_sum`, `unsorted_segment_max`, etc. data: A `RaggedTensor` containing the values to be combined. segment_ids: A `Tensor` or `RaggedTensor`. Must have type `int64` or `int32`. `segment_ids.shape` must be a prefix of `data.shape`. `segment_ids` is not required to be sorted. num_segments: An `int32` or `int64` scalar. separator: An optional string. Defaults to None. The separator to use when joining. Only used for string types. name: A name prefix for the returned tensor (optional). Returns: A `RaggedTensor` containing the aggregated values. The returned tensor has the same dtype as `data`, and its shape is `[num_segments] + data.shape[segment_ids.rank:]`. Raises: ValueError: If segment_ids.shape is not a prefix of data.shape. """ if not (ragged_tensor.is_ragged(data) or ragged_tensor.is_ragged(segment_ids)): if separator is not None: # It uses unsorted_segment_join. return unsorted_segment_op(data, segment_ids, num_segments, separator, name) else: return unsorted_segment_op(data, segment_ids, num_segments, name) with ops.name_scope(name, 'RaggedSegment', [data, segment_ids, num_segments]) as name: data = ragged_tensor.convert_to_tensor_or_ragged_tensor(data, name='data') segment_ids = ragged_tensor.convert_to_tensor_or_ragged_tensor( segment_ids, name='segment_ids') data, segment_ids = ragged_tensor.match_row_splits_dtypes(data, segment_ids) if segment_ids.dtype not in (dtypes.int32, dtypes.int64): raise ValueError('segment_ids must have dtype int32 or int64.') if ragged_tensor.is_ragged(segment_ids): if not ragged_tensor.is_ragged(data): raise ValueError('segment_ids.shape must be a prefix of data.shape, ' 'but segment_ids is ragged and data is not.') check_splits = check_ops.assert_equal( segment_ids.row_splits, data.row_splits, message='segment_ids.shape must be a prefix of data.shape') with ops.control_dependencies([check_splits]): return _ragged_segment_aggregate(unsorted_segment_op, data.values, segment_ids.values, num_segments, separator) # Find the length of each row in data. (shape=[data_nrows]) data_row_lengths = data.row_splits[1:] - data.row_splits[:-1] # Find the length that each output row will have. The length of the row # corresponding to segment `id` is `max(data_row_lengths[i])` where # `segment_ids[i]=id`. (shape=[output_nrows]) output_row_lengths = math_ops.maximum( math_ops.unsorted_segment_max(data_row_lengths, segment_ids, num_segments), 0) # Build the splits tensor for the output RaggedTensor. output_splits = array_ops.concat([ array_ops.zeros([1], output_row_lengths.dtype), math_ops.cumsum(output_row_lengths) ], axis=0) # For each row in `data`, find the start & limit position where that row's # values will be aggregated in output.values. 
data_row_to_out_row_start = array_ops.gather(output_splits, segment_ids) data_row_to_out_row_limit = data_row_to_out_row_start + data_row_lengths # For each value in `data.values`, find the position where it will be # aggregated in `output.values`. # Get the target output values index for each data values index. data_val_to_out_val_index = range(data_row_to_out_row_start, data_row_to_out_row_limit).values # Recursively aggregate the values. output_values = _ragged_segment_aggregate(unsorted_segment_op, data.values, data_val_to_out_val_index, output_splits[-1], separator) return ragged_tensor.RaggedTensor.from_row_splits( output_values, output_splits, validate=False)
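To make the row-length bookkeeping above concrete, here is a small sketch. It assumes a TF version where `tf.math.unsorted_segment_sum` is dispatched to this ragged implementation; the commented output follows the docstring's max-row-length rule:

```python
import tensorflow as tf

data = tf.ragged.constant([[1, 2], [3, 4, 5], [6]])
segment_ids = tf.constant([0, 0, 1])

# Rows 0 and 1 are combined elementwise into segment 0; the output row
# length is the max of the combined row lengths (here 3).
print(tf.math.unsorted_segment_sum(data, segment_ids, num_segments=2))
# [[4, 6, 5], [6]]
```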
def batch_gather(params, indices, name=None): """Gathers slices from `params` according to `indices` with batch dims. This operation is similar to `gather`, but it assumes that the leading `N` dimensions of `indices` and `params` are batch dimensions, and performs a gather within each batch. In particular, when using this operation with `N` batch dimensions `B1...BN`: * `indices` has shape `[B1...BN, I]` * `params` has shape `[B1...BN, P1...PM]`. * `result` has shape `[B1...BN, I, P2...PM]`. * `result[b1...bN, i, p2...pM] = params[b1...bN, indices[b1...bN, i], p2...pM]` Args: params: A potentially ragged tensor with shape `[B1...BN, P1...PM]` (`N>=0`, `M>0`). indices: A potentially ragged tensor with shape `[B1...BN, I]` (`N>=0`). name: A name for the operation (optional). Returns: A potentially ragged tensor with shape `[B1...BN, I, P2...PM]`. `result.ragged_rank = max(indices.ragged_rank, params.ragged_rank)`. #### Example: ```python >>> params = tf.ragged.constant([['a', 'b', 'c'], ['d'], [], ['e']]) >>> indices = tf.ragged.constant([[1, 2, 0], [], [], [0, 0]]) >>> tf.compat.v1.batch_gather(params, indices) [['b', 'c', 'a'], [], [], ['e', 'e']] ``` """ if not (ragged_tensor.is_ragged(params) or ragged_tensor.is_ragged(indices)): return array_ops.batch_gather(params, indices, name) with ops.name_scope(name, 'RaggedBatchGather', [params, indices]): params = ragged_tensor.convert_to_tensor_or_ragged_tensor( params, name='params') indices = ragged_tensor.convert_to_tensor_or_ragged_tensor( indices, name='indices') params, indices = ragged_tensor.match_row_splits_dtypes(params, indices) indices_ndims = indices.shape.ndims if indices_ndims is None: raise ValueError( 'batch_gather does not allow indices with unknown shape.') if indices_ndims == 0: raise ValueError('indices.rank must be at least 1.') if ragged_tensor.is_ragged(indices): # If the outermost ragged dimension is a batch dimension, recurse. if indices_ndims > 2: if not ragged_tensor.is_ragged(params): raise ValueError('batch shape from indices does ' 'not match params shape') checks = [check_ops.assert_equal(params.row_splits, indices.row_splits)] with ops.control_dependencies(checks): return ragged_tensor.RaggedTensor.from_row_splits( batch_gather(params.values, indices.values), indices.row_splits, validate=False) # Otherwise, indices is a 2D ragged tensor with 1 ragged dimension. else: # Ensure that `params` is ragged and has at least 2 dimensions. if not ragged_tensor.is_ragged(params): if params.shape.ndims is not None and params.shape.ndims < 2: raise ValueError('batch shape from indices does ' 'not match params shape') params = ragged_tensor.RaggedTensor.from_tensor( params, ragged_rank=1, row_splits_dtype=indices.row_splits.dtype) # Adjust indices from within-batch to global (in params.values), and # then use ragged.gather to gather them. num_indices = indices.row_lengths() params_starts = params.row_starts() adjustments = ragged_util.repeat(params_starts, num_indices, axis=0) adjusted_index_values = ( math_ops.cast(indices.values, adjustments.dtype) + adjustments) return ragged_tensor.RaggedTensor.from_row_splits( ragged_gather_ops.gather(params.values, adjusted_index_values), indices.row_splits, validate=False) else: # params is a RaggedTensor and indices is a Tensor. 
if indices_ndims == 1: return ragged_gather_ops.gather(params, indices) elif indices_ndims == 2: # Adjust indices from batch-local to global (in params.values) adjustments = array_ops.expand_dims(params.row_starts(), 1) adjusted_indices = ( math_ops.cast(indices, adjustments.dtype) + adjustments) return ragged_gather_ops.gather(params.values, adjusted_indices) else: raise ValueError('batch shape from indices does not match params shape')
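The dense-`indices` branch above adjusts per-batch indices by `params.row_starts()` before gathering from `params.values`. A minimal sketch of that branch (the docstring already shows the ragged-`indices` case):

```python
import tensorflow as tf

params = tf.ragged.constant([['a', 'b', 'c'], ['d', 'e']])
indices = tf.constant([[2, 0], [1, 1]])

# row_starts() is [0, 3], so the adjusted global indices are [[2, 0], [4, 4]],
# which are gathered directly from params.values.
print(tf.compat.v1.batch_gather(params, indices))
# [[b'c', b'a'], [b'e', b'e']]
```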
def gather(params, indices, validate_indices=None, axis=0, batch_dims=0, name=None): """Gathers ragged slices from `params` axis `0` according to `indices`. Returns `RaggedTensor` output, such that: ```python output.shape = indices.shape + params.shape[1:] output.ragged_rank = indices.shape.ndims + params.ragged_rank output[i...j, d0...dn] = params[indices[i...j], d0...dn] ``` `params` may be ragged. `indices` may be ragged. `indices` must have dtype `int32` or `int64`. If any index is out of bounds, then an error is returned. Examples: >>> params = tf.constant(['a', 'b', 'c', 'd', 'e']) >>> indices = tf.constant([3, 1, 2, 1, 0]) >>> ragged_params = tf.ragged.constant([['a', 'b', 'c'], ['d'], [], ['e']]) >>> ragged_indices = tf.ragged.constant([[3, 1, 2], [1], [], [0]]) >>> tf.gather(params, ragged_indices) <tf.RaggedTensor [[b'd', b'b', b'c'], [b'b'], [], [b'a']]> >>> tf.gather(ragged_params, indices) <tf.RaggedTensor [[b'e'], [b'd'], [], [b'd'], [b'a', b'b', b'c']]> >>> tf.gather(ragged_params, ragged_indices) <tf.RaggedTensor [[[b'e'], [b'd'], []], [[b'd']], [], [[b'a', b'b', b'c']]]> Args: params: The potentially ragged tensor from which to gather values. Must be at least rank 1. indices: The potentially ragged tensor indicating which values to gather. Must have dtype `int32` or `int64`. Values must be in the range `[0, params.shape[0]]`. validate_indices: Ignored. axis: Must be zero. batch_dims: Must be zero. name: A name for the operation (optional). Returns: A `RaggedTensor`, where `output.dtype=params.dtype` and `output.shape=indices.shape + params.shape[1:]` and `output.ragged_rank=indices.shape.ndims + params.ragged_rank`. Raises: ValueError: If indices.shape.ndims is not known statically. """ del validate_indices if not isinstance(axis, int) or axis != 0: raise ValueError('axis != 0 is not supported for ragged gather yet.') if not isinstance(batch_dims, int) or batch_dims != 0: raise ValueError( 'batch_dims != 0 is not supported for ragged gather yet.') with ops.name_scope(name, 'RaggedGather', [params, indices]): params = ragged_tensor.convert_to_tensor_or_ragged_tensor( params, name='params') indices = ragged_tensor.convert_to_tensor_or_ragged_tensor( indices, name='indices') params, indices = ragged_tensor.match_row_splits_dtypes( params, indices) if ragged_tensor.is_ragged(indices): return indices.with_values(gather(params, indices.values)) if not ragged_tensor.is_ragged(params): return array_ops.gather(params, indices) indices = ops.convert_to_tensor(indices) if indices.shape.ndims is None: raise ValueError('indices.shape.ndims must be known statically') result = gen_ragged_array_ops.ragged_gather( indices=indices, params_dense_values=params.flat_values, params_nested_splits=params.nested_row_splits, OUTPUT_RAGGED_RANK=indices.shape.ndims + len(params.nested_row_splits) - 1) # Compose the RaggedTensor from splits & values. return ragged_tensor.RaggedTensor.from_nested_row_splits( result.output_dense_values, result.output_nested_splits, validate=False)
def gather_nd(params, indices, batch_dims=0, name=None): """Gather slices from `params` using `n`-dimensional indices. This operation is similar to `gather`, but it uses the innermost dimension of `indices` to define a slice into `params`. In particular, if: * `indices` has shape `[A1...AN, I]` * `params` has shape `[B1...BM]` Then: * `result` has shape `[A1...AN, B_{I+1}...BM]`. * `result[a1...aN] = params[indices[a1...aN, :]]` Args: params: A potentially ragged tensor with shape `[B1...BM]`. indices: A potentially ragged tensor with shape `[A1...AN, I]`. batch_dims: Must be zero. name: A name for the operation (optional). Returns: A potentially ragged tensor with shape `[A1...AN, B_{I+1}...BM]`. #### Examples: >>> params = tf.ragged.constant( ... [ [ ['000', '001'], ['010' ] ], ... [ ['100' ], ['110', '111', '112'], ['120'] ], ... [ [ ], ['210' ] ] ]) >>> # Gather 2D slices from a 3D tensor >>> tf.gather_nd(params, [[2], [0]]) <tf.RaggedTensor [[[], [b'210']], [[b'000', b'001'], [b'010']]]> >>> # Gather 1D slices from a 3D tensor >>> tf.gather_nd(params, [[2, 1], [0, 0]]) <tf.RaggedTensor [[b'210'], [b'000', b'001']]> >>> # Gather scalars from a 3D tensor >>> tf.gather_nd(params, [[0, 0, 1], [1, 1, 2]]).numpy() array([b'001', b'112'], dtype=object) """ if not isinstance(batch_dims, int) or batch_dims != 0: raise ValueError( 'batch_dims != 0 is not supported for ragged gather yet.') if not (ragged_tensor.is_ragged(params) or ragged_tensor.is_ragged(indices)): return array_ops.gather_nd(params, indices, name) with ops.name_scope(name, 'RaggedGatherNd', [params, indices]): params = ragged_tensor.convert_to_tensor_or_ragged_tensor( params, name='params') indices = ragged_tensor.convert_to_tensor_or_ragged_tensor( indices, name='indices') params, indices = ragged_tensor.match_row_splits_dtypes( params, indices) indices_shape = indices.shape indices_ndims = indices_shape.ndims if indices_ndims is None: raise ValueError('indices.rank must be statically known.') if indices_ndims == 0: raise ValueError('indices.rank must be at least 1.') if (ragged_tensor.is_ragged(indices) and indices_ndims == indices.ragged_rank + 1): raise ValueError( 'The innermost dimension of indices may not be ragged') # `index_size` is the "n" in "gather_nd" -- i.e., the number of dimensions # that each index slices into. index_size = tensor_shape.dimension_value(indices_shape[-1]) if index_size is None: raise ValueError('indices.shape[-1] must be statically known.') # If `indices` has more than 2 dimensions, then recurse. If `indices` is # dense, then we convert it to ragged before recursing, and then convert # the result back to `dense` if appropriate. if indices_ndims > 2: indices_is_dense = not ragged_tensor.is_ragged(indices) if indices_is_dense: indices = ragged_tensor.RaggedTensor.from_tensor( indices, ragged_rank=indices_ndims - 2, row_splits_dtype=params.row_splits.dtype) result = indices.with_flat_values( gather_nd(params, indices.flat_values)) if (indices_is_dense and ragged_tensor.is_ragged(result) and result.ragged_rank == indices_ndims - 2): result = ragged_tensor.RaggedTensor.to_tensor(result) return result # indices_ndims <= 2, and the innermost dimension of indices may not be # ragged, so `indices` must not be ragged. assert not ragged_tensor.is_ragged(indices) assert ragged_tensor.is_ragged(params) # Handle corner case: An empty index tuple selects the entire `params` # value. So if `index_size` is zero, then tile `params`.
if index_size == 0: params_ndims = params.ragged_rank + array_ops.rank( params.flat_values) for dim in range(indices_ndims - 1): params = ragged_array_ops.expand_dims(params, axis=0) multiples = array_ops.concat([ array_ops.shape(indices)[:-1], array_ops.ones([params_ndims], dtypes.int32) ], axis=0) return ragged_array_ops.tile(params, multiples) # When index_size=1, we can just flatten the index tuples and use gather. elif index_size == 1: flattened_index_tuples = array_ops.reshape(indices, [-1]) return gather(params, flattened_index_tuples) # Otherwise, params is a RaggedTensor, and indices is a 1D or 2D Tensor. # Flatten both the index tuples and the params, such that the flattened # index tuples point to the correct values in the flattened params; and # then use ragged.gather on the flattened index tuples & params. else: indices = math_ops.cast(indices, params.row_splits.dtype) # Flatten the outermost 2 dimensions of the index tuples & params. flattened_index_tuples = array_ops.gather(params.row_splits, indices[..., 0]) flattened_index_tuples += indices[..., 1] flattened_params = params.values # Flatten any remaining dimensions. for dim in range(2, index_size): if not ragged_tensor.is_ragged(flattened_params): flattened_index_tuples = array_ops.expand_dims( flattened_index_tuples, axis=1) flattened_index_tuples = array_ops.concat( [flattened_index_tuples, indices[..., dim:]], axis=1) return array_ops.gather_nd(flattened_params, flattened_index_tuples) flattened_index_tuples = array_ops.gather( flattened_params.row_starts(), flattened_index_tuples) flattened_index_tuples += indices[..., dim] flattened_params = flattened_params.values # Gather using the flattened index tuples and params. return gather(flattened_params, flattened_index_tuples)
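A small sketch of the flattening step above for `index_size == 2`: each `[row, col]` tuple becomes `params.row_splits[row] + col`, an index into the flat values. (This assumes `tf.gather_nd` dispatches to this ragged implementation.)

```python
import tensorflow as tf

params = tf.ragged.constant([['a', 'b'], ['c', 'd', 'e']])

# row_splits is [0, 2, 5], so [1, 2] flattens to 2 + 2 = 4 and [0, 0] to 0,
# both indexing the flat values ['a', 'b', 'c', 'd', 'e'].
print(tf.gather_nd(params, [[1, 2], [0, 0]]))   # [b'e', b'a']
```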
def gather(params, indices, validate_indices=None, axis=None, batch_dims=0, name=None): """Gathers ragged slices from `params` axis `0` according to `indices`. See `tf.gather` for full documentation. (This version has the same API as `tf.gather`, but supports ragged `params` and `indices`.) Examples: >>> params = tf.constant(['a', 'b', 'c', 'd', 'e']) >>> indices = tf.constant([3, 1, 2, 1, 0]) >>> ragged_params = tf.ragged.constant([['a', 'b', 'c'], ['d'], [], ['e']]) >>> ragged_indices = tf.ragged.constant([[3, 1, 2], [1], [], [0]]) >>> tf.gather(params, ragged_indices) <tf.RaggedTensor [[b'd', b'b', b'c'], [b'b'], [], [b'a']]> >>> tf.gather(ragged_params, indices) <tf.RaggedTensor [[b'e'], [b'd'], [], [b'd'], [b'a', b'b', b'c']]> >>> tf.gather(ragged_params, ragged_indices) <tf.RaggedTensor [[[b'e'], [b'd'], []], [[b'd']], [], [[b'a', b'b', b'c']]]> Args: params: The potentially ragged tensor from which to gather values. Must be at least rank 1. indices: The potentially ragged tensor indicating which values to gather. Must have dtype `int32` or `int64`. Values must be in the range `[0, params.shape[0]]`. validate_indices: Ignored. axis: The axis in `params` to gather `indices` from. batch_dims: The number of batch dimensions. name: A name for the operation (optional). Returns: A `RaggedTensor`, where `output.dtype=params.dtype` and `output.shape=indices.shape + params.shape[1:]` and `output.ragged_rank=indices.shape.ndims + params.ragged_rank`. Raises: ValueError: If indices.shape.ndims is not known statically. """ del validate_indices with ops.name_scope(name, 'RaggedGather', [params, indices]): params = ragged_tensor.convert_to_tensor_or_ragged_tensor( params, name='params') indices = ragged_tensor.convert_to_tensor_or_ragged_tensor( indices, name='indices') params, indices = ragged_tensor.match_row_splits_dtypes( params, indices) if batch_dims != indices.shape.rank: batch_dims = array_ops.get_positive_axis( batch_dims, indices.shape.rank, axis_name='batch_dims', ndims_name='rank(indices)') if params.shape.rank is not None and batch_dims >= params.shape.rank: raise ValueError('batch_dims must be less than rank(params)') if axis is None: axis = batch_dims axis = array_ops.get_positive_axis(axis, params.shape.rank, ndims_name='rank(params)') if axis < batch_dims: raise ValueError( 'axis must be greater than or equal to batch_dims') if indices.shape.rank is not None: if not 0 <= batch_dims <= indices.shape.rank: raise ValueError( 'batch_dims=%s must be between 0 and rank(indices)=%s' % (batch_dims, indices.shape.rank)) return _gather(params, indices, axis, batch_dims)
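Unlike the earlier `axis=0`-only version, this variant accepts `axis` and `batch_dims`. A minimal sketch of the `batch_dims=1` case (assuming a TF version where `tf.gather` supports ragged inputs with batch dimensions):

```python
import tensorflow as tf

params = tf.ragged.constant([['a', 'b', 'c'], ['d', 'e']])
indices = tf.ragged.constant([[2, 0], [1]])

# batch_dims=1: gather within each row of `params` rather than along axis 0.
print(tf.gather(params, indices, batch_dims=1))
# [[b'c', b'a'], [b'e']]
```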
def batch_gather_with_default(params, indices, default_value='', name=None): """Same as `batch_gather` but inserts `default_value` for invalid indices. This operation is similar to `batch_gather` except that it substitutes `default_value` for the values at invalid indices. See `batch_gather` for more details. Args: params: A potentially ragged tensor with shape `[B1...BN, P1...PM]` (`N>=0`, `M>0`). indices: A potentially ragged tensor with shape `[B1...BN, I]` (`N>=0`). default_value: A value to be inserted in places where `indices` are out of bounds. Must be the same dtype as params and either a scalar or rank 1. name: A name for the operation (optional). Returns: A potentially ragged tensor with shape `[B1...BN, I, P2...PM]`. `result.ragged_rank = max(indices.ragged_rank, params.ragged_rank)`. #### Example: ```python >>> params = tf.ragged.constant([ ['a', 'b', 'c'], ['d'], [], ['e']]) >>> indices = tf.ragged.constant([[1, 2, -1], [], [], [0, 10]]) >>> batch_gather_with_default(params, indices, 'FOO') [['b', 'c', 'FOO'], [], [], ['e', 'FOO']] ``` """ with ops.name_scope(name, 'RaggedBatchGatherWithDefault'): params = ragged_tensor.convert_to_tensor_or_ragged_tensor( params, name='params', ) indices = ragged_tensor.convert_to_tensor_or_ragged_tensor( indices, name='indices', ) default_value = ragged_tensor.convert_to_tensor_or_ragged_tensor( default_value, name='default_value', ) row_splits_dtype, (params, indices, default_value) = ( ragged_tensor.match_row_splits_dtypes(params, indices, default_value, return_dtype=True)) # TODO(hterry): lift this restriction and support default_values # of rank > 1 if default_value.shape.ndims not in (0, 1): raise ValueError('"default_value" must be a scalar or vector') upper_bounds = None if indices.shape.ndims is None: raise ValueError('Indices must have a known rank.') if params.shape.ndims is None: raise ValueError('Params must have a known rank.') num_batch_dimensions = indices.shape.ndims - 1 pad = None # The logic for this works as follows: # - create a padded params, where: # padded_params[b1...bn, 0] = default_value # padded_params[b1...bn, i] = params[b1...bn, i-1] (i>0) # - create an `upper_bounds` Tensor that contains the number of elements # in each innermost rank. Broadcast `upper_bounds` to be the same shape # as `indices`. # - check to see which indices in `indices` are out of bounds and substitute # them with the index containing `default_value` (the first). # - call batch_gather with the indices adjusted. with ops.control_dependencies([ check_ops.assert_greater_equal(array_ops.rank(params), array_ops.rank(indices))]): if ragged_tensor.is_ragged(params): row_lengths = ragged_array_ops.expand_dims( params.row_lengths(axis=num_batch_dimensions), axis=-1) upper_bounds = math_ops.cast(row_lengths, indices.dtype) pad_shape = _get_pad_shape(params, indices, row_splits_dtype) pad = ragged_tensor_shape.broadcast_to( default_value, pad_shape) else: params_shape = array_ops.shape(params) pad_shape = array_ops.concat([ params_shape[:num_batch_dimensions], [1], params_shape[num_batch_dimensions + 1:params.shape.ndims] ], 0) upper_bounds = params_shape[num_batch_dimensions] pad = array_ops.broadcast_to(default_value, pad_shape) # Add `default_value` as the first value in the innermost (ragged) rank.
pad = math_ops.cast(pad, params.dtype) padded_params = array_ops.concat( [pad, params], axis=num_batch_dimensions) # Adjust the indices by replacing out-of-bound indices with the # default-value index (which is the first element). shifted_indices = indices + 1 is_out_of_bounds = (indices < 0) | (indices > upper_bounds) adjusted_indices = ragged_where_op.where( is_out_of_bounds, x=array_ops.zeros_like(indices), y=shifted_indices, ) return array_ops.batch_gather( params=padded_params, indices=adjusted_indices, name=name)
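Putting the padding and index-shifting together, here is a sketch of the end-to-end behavior. The module path is an assumption on my part; `batch_gather_with_default` is not part of the public `tf.ragged` namespace:

```python
import tensorflow as tf
from tensorflow.python.ops.ragged import ragged_batch_gather_with_default_op as bg

params = tf.ragged.constant([['a', 'b', 'c'], ['d'], [], ['e']])
indices = tf.ragged.constant([[1, 2, -1], [], [], [0, 10]])

# Internally: params is padded to [['FOO','a','b','c'], ['FOO','d'], ...],
# in-bounds indices are shifted by +1, and out-of-bounds ones map to 0.
print(bg.batch_gather_with_default(params, indices, 'FOO'))
# [[b'b', b'c', b'FOO'], [], [], [b'e', b'FOO']]
```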
def _ragged_stack_concat_helper(rt_inputs, axis, stack_values): """Helper function to concatenate or stack ragged tensors. Args: rt_inputs: A list of RaggedTensors or Tensors to combine. axis: The axis along which to concatenate or stack. stack_values: A boolean -- if true, then stack values; otherwise, concatenate them. Returns: A RaggedTensor. Raises: ValueError: If rt_inputs is empty, or if axis is out of range. """ # Validate parameters. if not rt_inputs: raise ValueError('rt_inputs may not be empty.') # Convert input tensors. rt_inputs = [ ragged_tensor.convert_to_tensor_or_ragged_tensor(rt_input, name='rt_input') for rt_input in rt_inputs ] row_splits_dtype, rt_inputs = ragged_tensor.match_row_splits_dtypes( *rt_inputs, return_dtype=True) rt_inputs = list(rt_inputs) # Special case: if there's only one input, then return it as-is. if len(rt_inputs) == 1: if stack_values: return ragged_array_ops.expand_dims(rt_inputs[0], axis=axis) else: return rt_inputs[0] # Check the rank (number of dimensions) of the input tensors. ndims = None for rt in rt_inputs: if ndims is None: ndims = rt.shape.ndims else: rt.shape.assert_has_rank(ndims) out_ndims = ndims if (ndims is None or not stack_values) else ndims + 1 axis = ragged_util.get_positive_axis(axis, out_ndims) # If all the inputs are Tensors, and we're combining the final dimension, # then we can delegate to the tf.stack/tf.concat operation, and return a # Tensor. if all(not ragged_tensor.is_ragged(rt) for rt in rt_inputs): if ndims is not None and (axis == out_ndims - 1 or axis == ndims - 1): if stack_values: return array_ops.stack(rt_inputs, axis) else: return array_ops.concat(rt_inputs, axis) # Convert any Tensor inputs to RaggedTensors. This makes it # possible to concatenate Tensors and RaggedTensors together. for i in range(len(rt_inputs)): if not ragged_tensor.is_ragged(rt_inputs[i]): rt_inputs[i] = ragged_tensor.RaggedTensor.from_tensor( rt_inputs[i], ragged_rank=1, row_splits_dtype=row_splits_dtype) # Convert the input tensors to all have the same ragged_rank. ragged_rank = max(max(rt.ragged_rank for rt in rt_inputs), 1) rt_inputs = [ _increase_ragged_rank_to(rt, ragged_rank, row_splits_dtype) for rt in rt_inputs ] if axis == 0: return _ragged_stack_concat_axis_0(rt_inputs, stack_values) elif axis == 1: return _ragged_stack_concat_axis_1(rt_inputs, stack_values) else: # axis > 1: recurse. values = [rt.values for rt in rt_inputs] splits = [[rt_input.row_splits] for rt_input in rt_inputs] with ops.control_dependencies(ragged_util.assert_splits_match(splits)): return ragged_tensor.RaggedTensor.from_row_splits( _ragged_stack_concat_helper(values, axis - 1, stack_values), splits[0][0], validate=False)
def gather_nd(params, indices, batch_dims=0, name=None): """Gather slices from `params` using `n`-dimensional indices. This operation is similar to `gather`, but it uses the innermost dimension of `indices` to define a slice into `params`. In particular, if: * `indices` has shape `[A1...AN, I]` * `params` has shape `[B1...BM]` Then: * `result` has shape `[A1...AN, B_{I+1}...BM]`. * `result[a1...aN] = params[indices[a1...aN, :]]` Args: params: A potentially ragged tensor with shape `[B1...BM]`. indices: A potentially ragged tensor with shape `[A1...AN, I]`. batch_dims: Must be zero. name: A name for the operation (optional). Returns: A potentially ragged tensor with shape `[A1...AN, B_{I+1}...BM]`. #### Examples: ```python >>> params = tf.compat.v1.ragged.constant_value( ... [ [ ['000', '001'], ['010' ] ], ... [ ['100' ], ['110', '111', '112'], ['120'] ], ... [ [ ], ['210' ] ] ]) >>> # Gather 2D slices from a 3D tensor >>> ragged.gather_nd(params, [[2], [0]]) [ [ [ ], ['210'] ] [ ['000', '001'], ['010'] ] ] >>> # Gather 1D slices from a 3D tensor >>> ragged.gather_nd(params, [[2, 1], [0, 0]]) [['210'], ['000', '001']] >>> # Gather scalars from a 3D tensor >>> ragged.gather_nd(params, [[0, 0, 1], [1, 1, 2]]) ['001', '112'] ``` """ if not isinstance(batch_dims, int) or batch_dims != 0: raise ValueError('batch_dims != 0 is not supported for ragged gather yet.') if not (ragged_tensor.is_ragged(params) or ragged_tensor.is_ragged(indices)): return array_ops.gather_nd(params, indices, name) with ops.name_scope(name, 'RaggedGatherNd', [params, indices]): params = ragged_tensor.convert_to_tensor_or_ragged_tensor( params, name='params') indices = ragged_tensor.convert_to_tensor_or_ragged_tensor( indices, name='indices') params, indices = ragged_tensor.match_row_splits_dtypes(params, indices) indices_shape = indices.shape indices_ndims = indices_shape.ndims if indices_ndims is None: raise ValueError('indices.rank must be statically known.') if indices_ndims == 0: raise ValueError('indices.rank must be at least 1.') if (ragged_tensor.is_ragged(indices) and indices_ndims == indices.ragged_rank + 1): raise ValueError('The innermost dimension of indices may not be ragged') # `index_size` is the "n" in "gather_nd" -- i.e., the number of dimensions # that each index slices into. index_size = tensor_shape.dimension_value(indices_shape[-1]) if index_size is None: raise ValueError('indices.shape[-1] must be statically known.') # If `indices` has more than 2 dimensions, then recurse. If `indices` is # dense, then we convert it to ragged before recursing, and then convert # the result back to `dense` if appropriate. if indices_ndims > 2: indices_is_dense = not ragged_tensor.is_ragged(indices) if indices_is_dense: indices = ragged_tensor.RaggedTensor.from_tensor( indices, ragged_rank=indices_ndims - 2, row_splits_dtype=params.row_splits.dtype) result = indices.with_flat_values(gather_nd(params, indices.flat_values)) if (indices_is_dense and ragged_tensor.is_ragged(result) and result.ragged_rank == indices_ndims - 2): result = ragged_tensor.RaggedTensor.to_tensor(result) return result # indices_ndims <= 2, and the innermost dimension of indices may not be # ragged, so `indices` must not be ragged. assert not ragged_tensor.is_ragged(indices) assert ragged_tensor.is_ragged(params) # Handle corner case: An empty index tuple selects the entire `params` # value. So if `index_size` is zero, then tile `params`.
if index_size == 0: params_ndims = params.ragged_rank + array_ops.rank(params.flat_values) for dim in range(indices_ndims - 1): params = ragged_array_ops.expand_dims(params, axis=0) multiples = array_ops.concat([ array_ops.shape(indices)[:-1], array_ops.ones([params_ndims], dtypes.int32) ], axis=0) return ragged_array_ops.tile(params, multiples) # When index_size=1, we can just flatten the index tuples and use gather. elif index_size == 1: flattened_index_tuples = array_ops.reshape(indices, [-1]) return gather(params, flattened_index_tuples) # Otherwise, params is a RaggedTensor, and indices is a 1D or 2D Tensor. # Flatten both the index tuples and the params, such that the flattened # index tuples point to the correct values in the flattened params; and # then use ragged.gather on the flattened index tuples & params. else: indices = math_ops.cast(indices, params.row_splits.dtype) # Flatten the outermost 2 dimensions of the index tuples & params. flattened_index_tuples = array_ops.gather(params.row_splits, indices[..., 0]) flattened_index_tuples += indices[..., 1] flattened_params = params.values # Flatten any remaining dimensions. for dim in range(2, index_size): if not ragged_tensor.is_ragged(flattened_params): flattened_index_tuples = array_ops.expand_dims( flattened_index_tuples, axis=1) flattened_index_tuples = array_ops.concat( [flattened_index_tuples, indices[..., dim:]], axis=1) return array_ops.gather_nd(flattened_params, flattened_index_tuples) flattened_index_tuples = array_ops.gather( flattened_params.row_starts(), flattened_index_tuples) flattened_index_tuples += indices[..., dim] flattened_params = flattened_params.values # Gather using the flattened index tuples and params. return gather(flattened_params, flattened_index_tuples)
def batch_gather(params, indices, name=None): """Gathers slices from `params` according to `indices` with batch dims. This operation is similar to `gather`, but it assumes that the leading `N` dimensions of `indices` and `params` are batch dimensions, and performs a gather within each batch. In particular, when using this operation with `N` batch dimensions `B1...BN`: * `indices` has shape `[B1...BN, I]` * `params` has shape `[B1...BN, P1...PM]`. * `result` has shape `[B1...BN, I, P2...PM]`. * `result[b1...bN, i, p2...pM] = params[b1...bN, indices[b1...bN, i], p2...pM]` Args: params: A potentially ragged tensor with shape `[B1...BN, P1...PM]` (`N>=0`, `M>0`). indices: A potentially ragged tensor with shape `[B1...BN, I]` (`N>=0`). name: A name for the operation (optional). Returns: A potentially ragged tensor with shape `[B1...BN, I, P2...PM]`. `result.ragged_rank = max(indices.ragged_rank, params.ragged_rank)`. #### Example: ```python >>> params = tf.ragged.constant([['a', 'b', 'c'], ['d'], [], ['e']]) >>> indices = tf.ragged.constant([[1, 2, 0], [], [], [0, 0]]) >>> tf.compat.v1.batch_gather(params, indices) [['b', 'c', 'a'], [], [], ['e', 'e']] ``` """ if not (ragged_tensor.is_ragged(params) or ragged_tensor.is_ragged(indices)): return array_ops.batch_gather(params, indices, name) with ops.name_scope(name, 'RaggedBatchGather', [params, indices]): params = ragged_tensor.convert_to_tensor_or_ragged_tensor( params, name='params') indices = ragged_tensor.convert_to_tensor_or_ragged_tensor( indices, name='indices') params, indices = ragged_tensor.match_row_splits_dtypes( params, indices) indices_ndims = indices.shape.ndims if indices_ndims is None: raise ValueError( 'batch_gather does not allow indices with unknown shape.') if indices_ndims == 0: raise ValueError('indices.rank must be at least 1.') if ragged_tensor.is_ragged(indices): # If the outermost ragged dimension is a batch dimension, recurse. if indices_ndims > 2: if not ragged_tensor.is_ragged(params): raise ValueError('batch shape from indices does ' 'not match params shape') checks = [ check_ops.assert_equal(params.row_splits, indices.row_splits) ] with ops.control_dependencies(checks): return ragged_tensor.RaggedTensor.from_row_splits( batch_gather(params.values, indices.values), indices.row_splits, validate=False) # Otherwise, indices is a 2D ragged tensor with 1 ragged dimension. else: # Ensure that `params` is ragged and has at least 2 dimensions. if not ragged_tensor.is_ragged(params): if params.shape.ndims is not None and params.shape.ndims < 2: raise ValueError('batch shape from indices does ' 'not match params shape') params = ragged_tensor.RaggedTensor.from_tensor( params, ragged_rank=1, row_splits_dtype=indices.row_splits.dtype) # Adjust indices from within-batch to global (in params.values), and # then use ragged.gather to gather them. num_indices = indices.row_lengths() params_starts = params.row_starts() adjustments = ragged_util.repeat(params_starts, num_indices, axis=0) adjusted_index_values = ( math_ops.cast(indices.values, adjustments.dtype) + adjustments) return ragged_tensor.RaggedTensor.from_row_splits( ragged_gather_ops.gather(params.values, adjusted_index_values), indices.row_splits, validate=False) else: # params is a RaggedTensor and indices is a Tensor. 
if indices_ndims == 1: return ragged_gather_ops.gather(params, indices) elif indices_ndims == 2: # Adjust indices from batch-local to global (in params.values) adjustments = array_ops.expand_dims(params.row_starts(), 1) adjusted_indices = (math_ops.cast(indices, adjustments.dtype) + adjustments) return ragged_gather_ops.gather(params.values, adjusted_indices) else: raise ValueError( 'batch shape from indices does not match params shape')
def _ragged_stack_concat_helper(rt_inputs, axis, stack_values): """Helper function to concatenate or stack ragged tensors. Args: rt_inputs: A list of RaggedTensors or Tensors to combine. axis: The axis along which to concatenate or stack. stack_values: A boolean -- if true, then stack values; otherwise, concatenate them. Returns: A RaggedTensor. Raises: ValueError: If rt_inputs is empty, or if axis is out of range. """ # Validate parameters. if not rt_inputs: raise ValueError('rt_inputs may not be empty.') # Convert input tensors. rt_inputs = [ ragged_tensor.convert_to_tensor_or_ragged_tensor( rt_input, name='rt_input') for rt_input in rt_inputs ] row_splits_dtype, rt_inputs = ragged_tensor.match_row_splits_dtypes( *rt_inputs, return_dtype=True) rt_inputs = list(rt_inputs) # Special case: if there's only one input, then return it as-is. if len(rt_inputs) == 1: if stack_values: return ragged_array_ops.expand_dims(rt_inputs[0], axis=axis) else: return rt_inputs[0] # Check the rank (number of dimensions) of the input tensors. ndims = None for rt in rt_inputs: if ndims is None: ndims = rt.shape.ndims else: rt.shape.assert_has_rank(ndims) out_ndims = ndims if (ndims is None or not stack_values) else ndims + 1 axis = ragged_util.get_positive_axis(axis, out_ndims) # If all the inputs are Tensors, and we're combining the final dimension, # then we can delegate to the tf.stack/tf.concat operation, and return a # Tensor. if all(not ragged_tensor.is_ragged(rt) for rt in rt_inputs): if ndims is not None and (axis == out_ndims - 1 or axis == ndims - 1): if stack_values: return array_ops.stack(rt_inputs, axis) else: return array_ops.concat(rt_inputs, axis) # Convert any Tensor inputs to RaggedTensors. This makes it # possible to concatenate Tensors and RaggedTensors together. for i in range(len(rt_inputs)): if not ragged_tensor.is_ragged(rt_inputs[i]): rt_inputs[i] = ragged_tensor.RaggedTensor.from_tensor( rt_inputs[i], ragged_rank=1, row_splits_dtype=row_splits_dtype) # Convert the input tensors to all have the same ragged_rank. ragged_rank = max(max(rt.ragged_rank for rt in rt_inputs), 1) rt_inputs = [_increase_ragged_rank_to(rt, ragged_rank, row_splits_dtype) for rt in rt_inputs] if axis == 0: return _ragged_stack_concat_axis_0(rt_inputs, stack_values) elif axis == 1: return _ragged_stack_concat_axis_1(rt_inputs, stack_values) else: # axis > 1: recurse. values = [rt.values for rt in rt_inputs] splits = [[rt_input.row_splits] for rt_input in rt_inputs] with ops.control_dependencies(ragged_util.assert_splits_match(splits)): return ragged_tensor.RaggedTensor.from_row_splits( _ragged_stack_concat_helper(values, axis - 1, stack_values), splits[0][0], validate=False)
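This helper backs the public ragged concat/stack entry points. A minimal sketch of the three axis cases it handles (assuming standard ragged dispatch for `tf.concat` and `tf.stack`):

```python
import tensorflow as tf

a = tf.ragged.constant([[1, 2], [3]])
b = tf.ragged.constant([[4], [5, 6]])

print(tf.concat([a, b], axis=0))   # [[1, 2], [3], [4], [5, 6]]
print(tf.concat([a, b], axis=1))   # [[1, 2, 4], [3, 5, 6]]
print(tf.stack([a, b], axis=0))    # [[[1, 2], [3]], [[4], [5, 6]]]
```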
def batch_gather_with_default(params, indices, default_value='', name=None): """Same as `batch_gather` but inserts `default_value` for invalid indices. This operation is similar to `batch_gather` except that it substitutes `default_value` for the values at invalid indices. See `batch_gather` for more details. Args: params: A potentially ragged tensor with shape `[B1...BN, P1...PM]` (`N>=0`, `M>0`). indices: A potentially ragged tensor with shape `[B1...BN, I]` (`N>=0`). default_value: A value to be inserted in places where `indices` are out of bounds. Must be the same dtype as params and either a scalar or rank 1. name: A name for the operation (optional). Returns: A potentially ragged tensor with shape `[B1...BN, I, P2...PM]`. `result.ragged_rank = max(indices.ragged_rank, params.ragged_rank)`. #### Example: ```python >>> params = tf.ragged.constant([ ['a', 'b', 'c'], ['d'], [], ['e']]) >>> indices = tf.ragged.constant([[1, 2, -1], [], [], [0, 10]]) >>> batch_gather_with_default(params, indices, 'FOO') [['b', 'c', 'FOO'], [], [], ['e', 'FOO']] ``` """ with ops.name_scope(name, 'RaggedBatchGatherWithDefault'): params = ragged_tensor.convert_to_tensor_or_ragged_tensor( params, name='params', ) indices = ragged_tensor.convert_to_tensor_or_ragged_tensor( indices, name='indices', ) default_value = ragged_tensor.convert_to_tensor_or_ragged_tensor( default_value, name='default_value', ) row_splits_dtype, (params, indices, default_value) = ( ragged_tensor.match_row_splits_dtypes(params, indices, default_value, return_dtype=True)) # TODO(hterry): lift this restriction and support default_values # of rank > 1 if default_value.shape.ndims not in (0, 1): raise ValueError('"default_value" must be a scalar or vector') upper_bounds = None if indices.shape.ndims is None: raise ValueError('Indices must have a known rank.') if params.shape.ndims is None: raise ValueError('Params must have a known rank.') num_batch_dimensions = indices.shape.ndims - 1 pad = None # The logic for this works as follows: # - create a padded params, where: # padded_params[b1...bn, 0] = default_value # padded_params[b1...bn, i] = params[b1...bn, i-1] (i>0) # - create an `upper_bounds` Tensor that contains the number of elements # in each innermost rank. Broadcast `upper_bounds` to be the same shape # as `indices`. # - check to see which indices in `indices` are out of bounds and substitute # them with the index containing `default_value` (the first). # - call batch_gather with the indices adjusted. with ops.control_dependencies([ check_ops.assert_greater_equal(array_ops.rank(params), array_ops.rank(indices))]): if ragged_tensor.is_ragged(params): row_lengths = ragged_array_ops.expand_dims( params.row_lengths(axis=num_batch_dimensions), axis=-1) upper_bounds = math_ops.cast(row_lengths, indices.dtype) pad_shape = _get_pad_shape(params, indices, row_splits_dtype) pad = ragged_tensor_shape.broadcast_to( default_value, pad_shape) else: params_shape = array_ops.shape(params) pad_shape = array_ops.concat([ params_shape[:num_batch_dimensions], [1], params_shape[num_batch_dimensions + 1:params.shape.ndims] ], 0) upper_bounds = params_shape[num_batch_dimensions] pad = array_ops.broadcast_to(default_value, pad_shape) # Add `default_value` as the first value in the innermost (ragged) rank.
pad = math_ops.cast(pad, params.dtype) padded_params = array_ops.concat( [pad, params], axis=num_batch_dimensions) # Adjust the indices by replacing out-of-bound indices with the # default-value index (which is the first element). shifted_indices = indices + 1 is_out_of_bounds = (indices < 0) | (indices > upper_bounds) adjusted_indices = ragged_where_op.where( is_out_of_bounds, x=array_ops.zeros_like(indices), y=shifted_indices, ) return array_ops.batch_gather( params=padded_params, indices=adjusted_indices, name=name)
def map_fn(fn, elems, dtype=None, parallel_iterations=None, back_prop=True, swap_memory=False, infer_shape=True, name=None): """map on the list of tensors unpacked from `elems` on dimension 0. The simplest version of `map_fn` repeatedly applies the callable `fn` to a sequence of elements from first to last. The elements are made of the tensors unpacked from `elems`. `dtype` is the data type of the return value of `fn`. Users must provide `dtype` if it is different from the data type of `elems`. Suppose that `elems` is unpacked into `values`, a list of tensors. The shape of the result tensor is `[values.shape[0]] + fn(values[0]).shape`. This method also allows multi-arity `elems` and output of `fn`. If `elems` is a (possibly nested) list or tuple of tensors, then each of these tensors must have a matching first (unpack) dimension. The signature of `fn` may match the structure of `elems`. That is, if `elems` is `(t1, [t2, t3, [t4, t5]])`, then an appropriate signature for `fn` is: `fn = lambda (t1, [t2, t3, [t4, t5]]):`. Furthermore, `fn` may emit a different structure than its input. For example, `fn` may look like: `fn = lambda t1: return (t1 + 1, t1 - 1)`. In this case, the `dtype` parameter is not optional: `dtype` must be a type or (possibly nested) tuple of types matching the output of `fn`. To apply a functional operation to the nonzero elements of a SparseTensor one of the following methods is recommended. First, if the function is expressible as TensorFlow ops, use ```python result = SparseTensor(input.indices, fn(input.values), input.dense_shape) ``` If, however, the function is not expressible as a TensorFlow op, then use ```python result = SparseTensor( input.indices, map_fn(fn, input.values), input.dense_shape) ``` instead. When executing eagerly, map_fn does not execute in parallel even if `parallel_iterations` is set to a value > 1. You can still get the performance benefits of running a function in parallel by using the `tf.contrib.eager.defun` decorator, ```python # Assume the function being used in map_fn is fn. # To ensure map_fn calls fn in parallel, use the defun decorator. @tf.contrib.eager.defun def func(tensor): return tf.map_fn(fn, tensor) ``` Note that if you use the defun decorator, any non-TensorFlow Python code that you may have written in your function won't get executed. See `tf.contrib.eager.defun` for more details. The recommendation would be to debug without defun but switch to defun to get performance benefits of running map_fn in parallel. Args: fn: The callable to be performed. It accepts one argument, which will have the same (possibly nested) structure as `elems`. Its output must have the same structure as `dtype` if one is provided, otherwise it must have the same structure as `elems`. elems: A tensor or (possibly nested) sequence of tensors, each of which will be unpacked along their first dimension. The nested sequence of the resulting slices will be applied to `fn`. dtype: (optional) The output type(s) of `fn`. If `fn` returns a structure of Tensors differing from the structure of `elems`, then `dtype` is not optional and must have the same structure as the output of `fn`. Use `RaggedTensorType` to declare an output of type `RaggedTensor`. parallel_iterations: (optional) The number of iterations allowed to run in parallel. When graph building, the default value is 10. While executing eagerly, the default value is set to 1. back_prop: (optional) True enables support for back propagation. swap_memory: (optional) True enables GPU-CPU memory swapping. 
infer_shape: (optional) False disables tests for consistent output shapes. name: (optional) Name prefix for the returned tensors. Returns: A possibly nested sequence of potentially ragged tensors. Each tensor packs the results of applying `fn` to tensors unpacked from `elems` along the first dimension, from first to last. Raises: TypeError: if `fn` is not callable or the structure of the output of `fn` and `dtype` do not match, or if elems is a SparseTensor. ValueError: if the lengths of the output of `fn` and `dtype` do not match. #### Examples: ```python elems = np.array([1, 2, 3, 4, 5, 6]) squares = map_fn(lambda x: x * x, elems) # squares == [1, 4, 9, 16, 25, 36] ``` ```python elems = (np.array([1, 2, 3]), np.array([-1, 1, -1])) alternate = map_fn(lambda x: x[0] * x[1], elems, dtype=tf.int64) # alternate == [-1, 2, -3] ``` ```python elems = np.array([1, 2, 3]) alternates = map_fn(lambda x: (x, -x), elems, dtype=(tf.int64, tf.int64)) # alternates[0] == [1, 2, 3] # alternates[1] == [-1, -2, -3] ``` ```python elems=ragged.constant([[1, 2, 3], [4, 5], [6, 7]]) mean = map_fn(tf.reduce_mean, elems) # mean == [2, 4, 6] ``` ```python elems=ragged.constant([[1, 2, 3], [4, 5], [6, 7]], dtype=tf.int64) out = map_fn(fn=lambda x: x+1, elems, dtype=ragged.RaggedTensorType(type=tf.int64, ragged_rank=0)) # out = ragged.constant([[2, 3, 4], [5, 6], [7, 8]]) ``` """ if not callable(fn): raise TypeError("fn must be callable.") if isinstance(elems, sparse_tensor.SparseTensor): raise TypeError( "To perform a map on the values of a sparse tensor use either " " SparseTensor(input.indices, fn(input.values), input.dense_shape) or " " SparseTensor(input.indices, map_fn(fn, input.values), " "input.dense_shape)") in_graph_mode = not context.executing_eagerly() # Set the default number of parallel_iterations depending on graph/eager mode. if in_graph_mode and not parallel_iterations: parallel_iterations = 10 elif not in_graph_mode and not parallel_iterations: parallel_iterations = 1 if not in_graph_mode and parallel_iterations > 1: logging.log_first_n(logging.WARN, "Setting parallel_iterations > 1 has no " "effect when executing eagerly. Consider calling map_fn" " with tf.contrib.eager.defun to execute fn in " "parallel.", 1) parallel_iterations = 1 input_is_sequence = nest.is_sequence(elems) input_flatten = lambda x: nest.flatten(x) if input_is_sequence else [x] def input_pack(x): return nest.pack_sequence_as(elems, x) if input_is_sequence else x[0] elems_flat = input_flatten(elems) elems_flat = ragged_tensor.match_row_splits_dtypes(*elems_flat) with ops.name_scope(name, "map", elems_flat): # TODO(akshayka): Remove the in_graph_mode check once caching devices are # supported in Eager if in_graph_mode: # Any get_variable calls in fn will cache the first call locally # and not issue repeated network I/O requests for each iteration. varscope = vs.get_variable_scope() varscope_caching_device_was_none = False if varscope.caching_device is None: # TODO(ebrevdo): Change to using colocate_with here and in other # methods. varscope.set_caching_device(lambda op: op.device) varscope_caching_device_was_none = True elems_flat = [ ragged_tensor.convert_to_tensor_or_ragged_tensor(elem, name="elem") for elem in elems_flat ] # We can either infer the output, or we can assume that it will be the same # as the input structure. dtype = dtype or input_pack([elem.dtype for elem in elems_flat]) # Find the number of iterations, n may be known statically. 
    if isinstance(elems_flat[0], ragged_tensor.RaggedTensor):
      n = elems_flat[0].nrows(out_type=dtypes.int32)
    else:
      static_shape = elems_flat[0].shape
      if static_shape.ndims is not None and static_shape.ndims < 1:
        if len(elems_flat) == 1:
          raise ValueError(
              "elems must be a 1+ dimensional Tensor, not a scalar")
        else:
          raise ValueError(
              "elements in elems must be 1+ dimensional Tensors, not scalars")
      n = (tensor_shape.dimension_value(static_shape[0]) or
           array_ops.shape(elems_flat[0])[0])

    n = math_ops.cast(n, dtype=dtypes.int32)
    # Create a flat list of TAs.

    # Flatten the dtype structure to a list.
    dtype_flat = nest.flatten(dtype)

    # decompose to components
    dtype_components = [_maybe_decompose_dtype(d) for d in dtype_flat]
    dtype_components_flat = nest.flatten(dtype_components)

    # Create TensorArrays.
    accs_ta = [
        tensor_array_ops.TensorArray(
            dtype=t, dynamic_size=False, infer_shape=infer_shape, size=n)
        for t in dtype_components_flat
    ]

    i = constant_op.constant(0, dtype=dtypes.int32)

    def compute(i, tas):
      """The loop body of map_fn.

      Args:
        i: the loop counter
        tas: the flat TensorArray accumulator list

      Returns:
        (i + 1, tas): the updated counter + updated TensorArrays

      Raises:
        TypeError: if dtype and packed_fn_values structure do not match
        ValueError: if dtype and packed_fn_values lengths do not match
      """
      # Get Tensors or RaggedTensors sliced at i, then pack them back into the
      # original structure.
      packed_values = input_pack([elem_flat[i] for elem_flat in elems_flat])
      packed_fn_values = fn(packed_values)

      # Check that the structure of the output matches what was declared or
      # inferred.
      # nest.assert_same_structure(dtype or elems, packed_fn_values)

      # Flatten and decompose to a list of Tensors
      flat_fn_values = nest.flatten(packed_fn_values)

      # If we declared a RaggedTensor output but `fn` returned a plain
      # Tensor, try to convert it to a RaggedTensor.
      flat_fn_composite_tensors = list(
          _convert_declared(flat_fn_values, dtype_flat))

      flat_fn_components = [
          _maybe_decompose_tensor(t) for t in flat_fn_composite_tensors
      ]
      flat_fn_tensors = nest.flatten(flat_fn_components)

      # Write to TAs.
      tas = [ta.write(i, value) for (ta, value) in zip(tas, flat_fn_tensors)]

      return (i + 1, tas)

    _, r_a = control_flow_ops.while_loop(
        lambda i, _: i < n,
        compute, (i, accs_ta),
        parallel_iterations=parallel_iterations,
        back_prop=back_prop,
        swap_memory=swap_memory,
        maximum_iterations=n)

    # TODO(akshayka): Remove the in_graph_mode check once caching devices are
    # supported in Eager
    if in_graph_mode and varscope_caching_device_was_none:
      varscope.set_caching_device(None)

    # Pack back into a list of components
    results_as_components = nest.pack_sequence_as(dtype_components, r_a)

    # Stack TensorArrays for Tensor outputs, and concat RaggedTensor outputs.
    def _stack_or_concat(e):
      if isinstance(e, _RaggedTensorComponents):
        return _concat_ragged_tensor_components(e)
      else:
        return e.stack()

    results_flat_components = [
        _stack_or_concat(e) for e in results_as_components
    ]

    results_packed = [
        _maybe_recompose_tensor(c) for c in results_flat_components
    ]
    results_packed = nest.pack_sequence_as(dtype, results_packed)
    return results_packed
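# --- Illustrative usage sketch (editor's addition, not original library code).
# A minimal, hedged example of the ragged-aware `map_fn` above: each
# variable-length row of a RaggedTensor is passed to `fn`, and the per-row
# results are repacked along dimension 0. Assumes eager execution and that
# `tf.ragged.constant` is available; the values are invented for illustration.
def _example_map_fn_usage():
  import tensorflow as tf  # Illustrative only; library code imports submodules.
  elems = tf.ragged.constant([[1, 2, 3], [4, 5], [6, 7]])
  # tf.reduce_mean is applied to each row, giving [2, 4, 6].
  return map_fn(tf.reduce_mean, elems)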
def boolean_mask(data, mask, keepdims=False, name=None): """Applies a boolean mask to `data`. Returns a potentially ragged tensor that is formed by retaining the elements in `data` where the corresponding value in `mask` is `True`. If `keepdims` is true then outer dimensions (corresponding to the `mask` dimensions) are preserved, and: * `output[a1...aA, i, b1...bB] = data[a1...aA, j, b1...bB]` Where `j` is the `i`th `True` entry of `mask[a1...aA]`. If `keepdims` is false, then the outer dimensions are collapsed (similar to the behavior of `tf.boolean_mask`), and: * `output[i, b1...bB] = data[a1...aA, b1...bB]` Where `(a1...aA)` is the `i`th `True` entry of `mask` (in row-major order). Args: data: A potentially ragged tensor. mask: A potentially ragged boolean tensor. `mask`'s shape must be a prefix of `data`'s shape. `rank(mask)` must be known statically. keepdims: Whether to preserve the outer dimensions (`keepdims=True`) or flatten them (`keepdims=False`). name: A name prefix for the returned tensor (optional). Returns: A potentially ragged tensor that is formed by retaining the elements in `data` where the corresponding value in `mask` is `True`. If `keepdims` is false: * `rank(output) = rank(data) - rank(mask) + 1`. * `output.ragged_rank = max(data.ragged_rank - rank(mask) + 1, 0)`. If `keepdims` is true: * `rank(output) = rank(data)`. * `output.ragged_rank = max(data.ragged_rank, rank(mask) - 1)`. Raises: ValueError: if `rank(mask)` is not known statically; or if `mask.shape` is not a prefix of `data.shape`. #### Examples: ```python >>> # Aliases for True & False so data and mask line up. >>> T, F = (True, False) >>> tf.ragged.boolean_mask( # Mask a 2D Tensor. Flatten outer dims. ... data=[[1, 2, 3], [4, 5, 6], [7, 8, 9]], ... mask=[[T, F, T], [F, F, F], [T, F, F]], ... keepdims=False).tolist() [1, 3, 7] >>> tf.ragged.boolean_mask( # Mask a 2D Tensor. Preserve outer dims. ... data=[[1, 2, 3], [4, 5, 6], [7, 8, 9]], ... mask=[[T, F, T], [F, F, F], [T, F, F]], ... keepdims=True).tolist() [[1, 3], [], [7]] >>> tf.ragged.boolean_mask( # Mask a 2D RaggedTensor. Flatten outer dims. ... tf.ragged.constant([[1, 2, 3], [4], [5, 6]]), ... tf.ragged.constant([[F, F, T], [F], [T, T]]), ... keepdims=False).tolist() [3, 5, 6] >>> tf.ragged.boolean_mask( # Mask a 2D RaggedTensor. Preserve outer dims. ... tf.ragged.constant([[1, 2, 3], [4], [5, 6]]), ... tf.ragged.constant([[F, F, T], [F], [T, T]]), ... keepdims=True).tolist() [[3], [], [5, 6]] >>> tf.ragged.boolean_mask( # Mask rows of a 2D RaggedTensor. ... tf.ragged.constant([[1, 2, 3], [4], [5, 6]]), ... tf.ragged.constant([True, False, True]), ... keepdims=True).tolist() [[1, 2, 3], [5, 6]] ``` """ with ops.name_scope(name, 'RaggedMask', [data, mask]): # Convert inputs to tensors. data = ragged_tensor.convert_to_tensor_or_ragged_tensor(data, name='data') mask = ragged_tensor.convert_to_tensor_or_ragged_tensor( mask, dtypes.bool, name='mask') row_splits_dtype, (data, mask) = ragged_tensor.match_row_splits_dtypes( data, mask, return_dtype=True) # Get static rank of mask. if mask.shape.ndims is None: raise ValueError('mask.shape.ndims must be known statically.') elif mask.shape.ndims == 0: raise ValueError('mask cannot be scalar.') # If mask is ragged, then recurse with a non-ragged mask. 
if ragged_tensor.is_ragged(mask): if not ragged_tensor.is_ragged(data): data = ragged_tensor.RaggedTensor.from_tensor( data, ragged_rank=mask.ragged_rank, row_splits_dtype=mask.row_splits.dtype) # Check that mask.nested_row_splits is a prefix of # data.nested_row_splits. splits_list = [ mask.nested_row_splits, data.nested_row_splits[:mask.ragged_rank] ] with ops.control_dependencies( ragged_util.assert_splits_match(splits_list)): # Strip off ragged `splits` until `mask` is non-ragged. Keep the splits # that we strip off in `splits`, so we can add them back on after # we recursively mask the non-ragged data. splits = [] while ragged_tensor.is_ragged(mask): if mask.shape.ndims > 2: splits.append(mask.row_splits) else: # Count the number of True mask values in each row to find the # lengths of the filtered rows; then convert to splits. int_mask = ragged_functional_ops.map_flat_values( math_ops.cast, mask, dtype=row_splits_dtype) masked_row_lengths = ragged_math_ops.reduce_sum(int_mask, axis=1) splits.append(ragged_util.lengths_to_splits(masked_row_lengths)) mask = mask.values data = data.values # Recursively apply the nested non-ragged mask to the nested data. masked_values = boolean_mask(data, mask, keepdims) # Add the ragged `splits` back to the result. if keepdims: masked_values = ragged_tensor.RaggedTensor.from_nested_row_splits( masked_values, splits, validate=False) return masked_values # If mask is non-ragged and has rank 1, and data is ragged, then build a # ragged tensor with the indicated rows. elif ragged_tensor.is_ragged(data) and mask.shape.ndims == 1: # Get the masked splits: first get the length of each row, then filter # out the rows that we are deleting, and convert that filtered set of # masks back to a splits tensor. lengths = data.row_lengths() masked_lengths = array_ops.boolean_mask(lengths, mask) masked_splits = ragged_util.lengths_to_splits(masked_lengths) # Get the masked values: first get row ids corresponding to each # value, then use tf.gather to build a boolean mask that's false for # values that come from rows that we are deleting, and use that mask to # construct the masked values tensor. segment_ids = segment_id_ops.row_splits_to_segment_ids(data.row_splits) segment_mask = array_ops.gather(mask, segment_ids) masked_values = boolean_mask(data.values, segment_mask, keepdims=False) return ragged_tensor.RaggedTensor.from_row_splits(masked_values, masked_splits, validate=False) # If mask is non-ragged and has rank>1, then convert it to be ragged, # with a ragged rank matching data. if ragged_tensor.is_ragged(data): mask = ragged_tensor.RaggedTensor.from_tensor( mask, ragged_rank=min(data.ragged_rank, mask.shape.ndims - 1), row_splits_dtype=data.row_splits.dtype) return boolean_mask(data, mask, keepdims) # Otherwise, data and mask are both `Tensor`s. else: # Apply `boolean_mask` to get the masked values. masked_values = array_ops.boolean_mask(data, mask) if mask.shape.ndims >= 2 and keepdims: # Add the innermost ragged dimension. For each innermost cell, get the # number of values it contains. Then flatten that to get a list of # cell lengths, and convert it to splits. Finally, combine the splits # and values to get the innermost ragged tensor. masked_lengths = math_ops.count_nonzero(mask, axis=-1, dtype=row_splits_dtype) flattened_masked_lengths = array_ops.reshape(masked_lengths, [-1]) masked_values = ragged_tensor.RaggedTensor.from_row_lengths( masked_values, flattened_masked_lengths, validate=False) # Wrap remaining ragged dimensions. 
if mask.shape.ndims > 2 and keepdims: mask_shape = array_ops.shape(mask, out_type=row_splits_dtype) split_size = math_ops.cumprod(mask_shape) + 1 for dim in range(mask.shape.ndims - 3, -1, -1): elt_size = mask_shape[dim + 1] masked_splits = math_ops.range(split_size[dim]) * elt_size masked_values = ragged_tensor.RaggedTensor.from_row_splits( masked_values, masked_splits, validate=False) return masked_values
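# --- Illustrative usage sketch (editor's addition, not original library code).
# A hedged example contrasting the two `keepdims` modes of `boolean_mask`
# above: flattening the mask dimensions versus preserving them. Assumes eager
# execution; the values are invented for illustration.
def _example_boolean_mask_keepdims():
  import tensorflow as tf  # Illustrative only.
  data = tf.ragged.constant([[1, 2, 3], [4], [5, 6]])
  mask = tf.ragged.constant([[False, False, True], [False], [True, True]])
  flat = boolean_mask(data, mask, keepdims=False)  # -> [3, 5, 6]
  kept = boolean_mask(data, mask, keepdims=True)   # -> [[3], [], [5, 6]]
  return flat, kept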
def where(condition, x=None, y=None, name=None):
  """Return the elements, either from `x` or `y`, depending on the `condition`.

  : If both `x` and `y` are `None`:
    Returns the coordinates of true elements of `condition`. The coordinates
    are returned in a 2-D tensor with shape
    `[num_true_values, rank(condition)]`, where `result[i]` is the coordinates
    of the `i`th true value (in row-major order).

  : If both `x` and `y` are non-`None`:
    Returns a tensor formed by selecting values from `x` where condition is
    true, and from `y` when condition is false. In particular:

    : If `condition`, `x`, and `y` all have the same shape:

      * `result[i1...iN] = x[i1...iN]` if `condition[i1...iN]` is true.
      * `result[i1...iN] = y[i1...iN]` if `condition[i1...iN]` is false.

    : Otherwise:

      * `condition` must be a vector.
      * `x` and `y` must have the same number of dimensions.
      * The outermost dimensions of `condition`, `x`, and `y` must all have
        the same size.
      * `result[i] = x[i]` if `condition[i]` is true.
      * `result[i] = y[i]` if `condition[i]` is false.

  Args:
    condition: A potentially ragged tensor of type `bool`.
    x: A potentially ragged tensor (optional).
    y: A potentially ragged tensor (optional). Must be specified if `x` is
      specified. Must have the same rank and type as `x`.
    name: A name of the operation (optional).

  Returns:
    : If both `x` and `y` are `None`:
      A `Tensor` with shape `(num_true, rank(condition))`.
    : Otherwise:
      A potentially ragged tensor with the same type, rank, and outermost
      dimension size as `x` and `y`.
      `result.ragged_rank = max(x.ragged_rank, y.ragged_rank)`.

  Raises:
    ValueError: When exactly one of `x` or `y` is non-`None`; or when
      `condition`, `x`, and `y` have incompatible shapes.

  #### Examples:

  >>> # Coordinates where condition is true.
  >>> condition = tf.ragged.constant([[True, False, True], [False, True]])
  >>> print(where(condition))
  tf.Tensor(
  [[0 0]
   [0 2]
   [1 1]], shape=(3, 2), dtype=int64)

  >>> # Elementwise selection between x and y, based on condition.
  >>> condition = tf.ragged.constant([[True, False, True], [False, True]])
  >>> x = tf.ragged.constant([['A', 'B', 'C'], ['D', 'E']])
  >>> y = tf.ragged.constant([['a', 'b', 'c'], ['d', 'e']])
  >>> print(where(condition, x, y))
  <tf.RaggedTensor [[b'A', b'b', b'C'], [b'd', b'E']]>

  >>> # Row selection between x and y, based on condition.
  >>> condition = [True, False]
  >>> x = tf.ragged.constant([['A', 'B', 'C'], ['D', 'E']])
  >>> y = tf.ragged.constant([['a', 'b', 'c'], ['d', 'e']])
  >>> print(where(condition, x, y))
  <tf.RaggedTensor [[b'A', b'B', b'C'], [b'd', b'e']]>
  """
  if (x is None) != (y is None):
    raise ValueError('x and y must be either both None or both non-None')

  with ops.name_scope('RaggedWhere', name, [condition, x, y]):
    condition = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        condition, name='condition')
    if x is None:
      return _coordinate_where(condition)
    else:
      x = ragged_tensor.convert_to_tensor_or_ragged_tensor(x, name='x')
      y = ragged_tensor.convert_to_tensor_or_ragged_tensor(y, name='y')
      condition, x, y = ragged_tensor.match_row_splits_dtypes(condition, x, y)
      return _elementwise_where(condition, x, y)
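# --- Illustrative usage sketch (editor's addition, not original library code).
# A hedged example of the two modes of `where` above: coordinate retrieval
# when `x` and `y` are omitted, and elementwise multiplexing when both are
# given. Assumes eager execution; the values are invented for illustration.
def _example_where_usage():
  import tensorflow as tf  # Illustrative only.
  condition = tf.ragged.constant([[True, False, True], [False, True]])
  coords = where(condition)  # -> [[0, 0], [0, 2], [1, 1]]
  x = tf.ragged.constant([['A', 'B', 'C'], ['D', 'E']])
  y = tf.ragged.constant([['a', 'b', 'c'], ['d', 'e']])
  selected = where(condition, x, y)  # -> [[b'A', b'b', b'C'], [b'd', b'E']]
  return coords, selected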
def gather(params, indices, validate_indices=None, axis=0, batch_dims=0,
           name=None):
  """Gathers ragged slices from `params` axis `0` according to `indices`.

  Returns `RaggedTensor` output, such that:

  ```python
  output.shape = indices.shape + params.shape[1:]
  output.ragged_rank = indices.shape.ndims + params.ragged_rank
  output[i...j, d0...dn] = params[indices[i...j], d0...dn]
  ```

  `params` may be ragged. `indices` may be ragged. `indices` must have dtype
  `int32` or `int64`. If any index is out of bounds, then an error is raised.

  Examples:

  ```python
  >>> params = tf.constant(['a', 'b', 'c', 'd', 'e'])
  >>> indices = tf.constant([3, 1, 2, 1, 0])
  >>> ragged_params = tf.ragged.constant([['a', 'b', 'c'], ['d'], [], ['e']])
  >>> ragged_indices = tf.ragged.constant([[3, 1, 2], [1], [], [0]])

  >>> print(ragged.gather(params, ragged_indices))
  [['d', 'b', 'c'], ['b'], [], ['a']]

  >>> print(ragged.gather(ragged_params, indices))
  [['e'], ['d'], [], ['d'], ['a', 'b', 'c']]

  >>> print(ragged.gather(ragged_params, ragged_indices))
  [[['e'], ['d'], []], [['d']], [], [['a', 'b', 'c']]]
  ```

  Args:
    params: The potentially ragged tensor from which to gather values. Must be
      at least rank 1.
    indices: The potentially ragged tensor indicating which values to gather.
      Must have dtype `int32` or `int64`. Values must be in the range
      `[0, params.shape[0])`.
    validate_indices: Ignored.
    axis: Must be zero.
    batch_dims: Must be zero.
    name: A name for the operation (optional).

  Returns:
    A `RaggedTensor`, where `output.dtype=params.dtype` and
    `output.shape=indices.shape + params.shape[1:]` and
    `output.ragged_rank=indices.shape.ndims + params.ragged_rank`.

  Raises:
    ValueError: If indices.shape.ndims is not known statically.
  """
  del validate_indices
  if not isinstance(axis, int) or axis != 0:
    raise ValueError('axis != 0 is not supported for ragged gather yet.')
  if not isinstance(batch_dims, int) or batch_dims != 0:
    raise ValueError('batch_dims != 0 is not supported for ragged gather yet.')
  with ops.name_scope(name, 'RaggedGather', [params, indices]):
    params = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        params, name='params')
    indices = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        indices, name='indices')
    params, indices = ragged_tensor.match_row_splits_dtypes(params, indices)

    if ragged_tensor.is_ragged(indices):
      return indices.with_values(gather(params, indices.values))

    if not ragged_tensor.is_ragged(params):
      return array_ops.gather(params, indices)

    indices = ops.convert_to_tensor(indices)
    if indices.shape.ndims is None:
      raise ValueError('indices.shape.ndims must be known statically')

    result = gen_ragged_array_ops.ragged_gather(
        indices=indices,
        params_dense_values=params.flat_values,
        params_nested_splits=params.nested_row_splits,
        OUTPUT_RAGGED_RANK=indices.shape.ndims +
        len(params.nested_row_splits) - 1)

    # Compose the RaggedTensor from splits & values.
    return ragged_tensor.RaggedTensor.from_nested_row_splits(
        result.output_dense_values, result.output_nested_splits,
        validate=False)
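# --- Illustrative usage sketch (editor's addition, not original library code).
# A hedged example of ragged `gather` above: ragged indices into a dense
# params vector yield a RaggedTensor that mirrors the row structure of the
# indices. Assumes eager execution; values are invented for illustration.
def _example_gather_usage():
  import tensorflow as tf  # Illustrative only.
  params = tf.constant(['a', 'b', 'c', 'd', 'e'])
  ragged_indices = tf.ragged.constant([[3, 1, 2], [1], [], [0]])
  # -> [[b'd', b'b', b'c'], [b'b'], [], [b'a']]
  return gather(params, ragged_indices)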
def boolean_mask(data, mask, name=None): """Applies a boolean mask to `data` without flattening the mask dimensions. Returns a potentially ragged tensor that is formed by retaining the elements in `data` where the corresponding value in `mask` is `True`. * `output[a1...aA, i, b1...bB] = data[a1...aA, j, b1...bB]` Where `j` is the `i`th `True` entry of `mask[a1...aA]`. Note that `output` preserves the mask dimensions `a1...aA`; this differs from `tf.boolean_mask`, which flattens those dimensions. Args: data: A potentially ragged tensor. mask: A potentially ragged boolean tensor. `mask`'s shape must be a prefix of `data`'s shape. `rank(mask)` must be known statically. name: A name prefix for the returned tensor (optional). Returns: A potentially ragged tensor that is formed by retaining the elements in `data` where the corresponding value in `mask` is `True`. * `rank(output) = rank(data)`. * `output.ragged_rank = max(data.ragged_rank, rank(mask) - 1)`. Raises: ValueError: if `rank(mask)` is not known statically; or if `mask.shape` is not a prefix of `data.shape`. #### Examples: >>> # Aliases for True & False so data and mask line up. >>> T, F = (True, False) >>> tf.ragged.boolean_mask( # Mask a 2D Tensor. ... data=[[1, 2, 3], [4, 5, 6], [7, 8, 9]], ... mask=[[T, F, T], [F, F, F], [T, F, F]]).to_list() [[1, 3], [], [7]] >>> tf.ragged.boolean_mask( # Mask a 2D RaggedTensor. ... tf.ragged.constant([[1, 2, 3], [4], [5, 6]]), ... tf.ragged.constant([[F, F, T], [F], [T, T]])).to_list() [[3], [], [5, 6]] >>> tf.ragged.boolean_mask( # Mask rows of a 2D RaggedTensor. ... tf.ragged.constant([[1, 2, 3], [4], [5, 6]]), ... tf.ragged.constant([True, False, True])).to_list() [[1, 2, 3], [5, 6]] """ with ops.name_scope(name, 'RaggedMask', [data, mask]): # Convert inputs to tensors. data = ragged_tensor.convert_to_tensor_or_ragged_tensor(data, name='data') mask = ragged_tensor.convert_to_tensor_or_ragged_tensor( mask, dtypes.bool, name='mask') row_splits_dtype, (data, mask) = ragged_tensor.match_row_splits_dtypes( data, mask, return_dtype=True) # Get static rank of mask. if mask.shape.ndims is None: raise ValueError('mask.shape.ndims must be known statically.') elif mask.shape.ndims == 0: raise ValueError('mask cannot be scalar.') # If mask is ragged, then recurse with a non-ragged mask. if ragged_tensor.is_ragged(mask): if not ragged_tensor.is_ragged(data): data = ragged_tensor.RaggedTensor.from_tensor( data, ragged_rank=mask.ragged_rank, row_splits_dtype=mask.row_splits.dtype) # Check that mask.nested_row_splits is a prefix of # data.nested_row_splits. splits_list = [ mask.nested_row_splits, data.nested_row_splits[:mask.ragged_rank] ] with ops.control_dependencies( ragged_util.assert_splits_match(splits_list)): # Strip off ragged `splits` until `mask` is non-ragged. Keep the splits # that we strip off in `splits`, so we can add them back on after # we recursively mask the non-ragged data. splits = [] while ragged_tensor.is_ragged(mask): if mask.shape.ndims > 2: splits.append(mask.row_splits) else: # Count the number of True mask values in each row to find the # lengths of the filtered rows; then convert to splits. int_mask = ragged_functional_ops.map_flat_values( math_ops.cast, mask, dtype=row_splits_dtype) masked_row_lengths = ragged_math_ops.reduce_sum(int_mask, axis=1) splits.append(ragged_util.lengths_to_splits(masked_row_lengths)) mask = mask.values data = data.values # Recursively apply the nested non-ragged mask to the nested data. 
masked_values = boolean_mask(data, mask) # Add the ragged `splits` back to the result. masked_values = ragged_tensor.RaggedTensor.from_nested_row_splits( masked_values, splits, validate=False) return masked_values # If mask is non-ragged and has rank 1, and data is ragged, then build a # ragged tensor with the indicated rows. elif ragged_tensor.is_ragged(data) and mask.shape.ndims == 1: # Get the masked splits: first get the length of each row, then filter # out the rows that we are deleting, and convert that filtered set of # masks back to a splits tensor. lengths = data.row_lengths() masked_lengths = array_ops.boolean_mask(lengths, mask) masked_splits = ragged_util.lengths_to_splits(masked_lengths) # Get the masked values: first get row ids corresponding to each # value, then use tf.gather to build a boolean mask that's false for # values that come from rows that we are deleting, and use that mask to # construct the masked values tensor. segment_ids = segment_id_ops.row_splits_to_segment_ids(data.row_splits) segment_mask = array_ops.gather(mask, segment_ids) masked_values = boolean_mask(data.values, segment_mask) return ragged_tensor.RaggedTensor.from_row_splits( masked_values, masked_splits, validate=False) # If mask is non-ragged and has rank>1, then convert it to be ragged, # with a ragged rank matching data. if ragged_tensor.is_ragged(data): mask = ragged_tensor.RaggedTensor.from_tensor( mask, ragged_rank=min(data.ragged_rank, mask.shape.ndims - 1), row_splits_dtype=data.row_splits.dtype) return boolean_mask(data, mask) # Otherwise, data and mask are both `Tensor`s. else: # Apply `boolean_mask` to get the masked values. masked_values = array_ops.boolean_mask(data, mask) if mask.shape.ndims >= 2: # Add the innermost ragged dimension. For each innermost cell, get the # number of values it contains. Then flatten that to get a list of # cell lengths, and convert it to splits. Finally, combine the splits # and values to get the innermost ragged tensor. masked_lengths = math_ops.count_nonzero( mask, axis=-1, dtype=row_splits_dtype) flattened_masked_lengths = array_ops.reshape(masked_lengths, [-1]) masked_values = ragged_tensor.RaggedTensor.from_row_lengths( masked_values, flattened_masked_lengths, validate=False) # Wrap remaining ragged dimensions. if mask.shape.ndims > 2: mask_shape = array_ops.shape(mask, out_type=row_splits_dtype) split_size = math_ops.cumprod(mask_shape) + 1 for dim in range(mask.shape.ndims - 3, -1, -1): elt_size = mask_shape[dim + 1] masked_splits = math_ops.range(split_size[dim]) * elt_size masked_values = ragged_tensor.RaggedTensor.from_row_splits( masked_values, masked_splits, validate=False) return masked_values
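# --- Illustrative usage sketch (editor's addition, not original library code).
# A hedged example of the mask-dimension-preserving `boolean_mask` above,
# covering both a full-depth cell mask and the rank-1 row-selection case.
# Assumes eager execution; the values are invented for illustration.
def _example_boolean_mask_v2():
  import tensorflow as tf  # Illustrative only.
  data = tf.ragged.constant([[1, 2, 3], [4], [5, 6]])
  cell_mask = tf.ragged.constant([[False, False, True], [False], [True, True]])
  row_mask = tf.constant([True, False, True])
  masked_cells = boolean_mask(data, cell_mask)  # -> [[3], [], [5, 6]]
  masked_rows = boolean_mask(data, row_mask)    # -> [[1, 2, 3], [5, 6]]
  return masked_cells, masked_rows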
def _ragged_segment_aggregate(unsorted_segment_op,
                              data,
                              segment_ids,
                              num_segments,
                              name=None):
  """Aggregates along segments of a RaggedTensor using `unsorted_segment_op`.

  Returns a RaggedTensor `output` with `num_segments` rows, where the row
  `output[i]` is formed by combining all rows of `data` whose corresponding
  `segment_id` is `i`. The values in each row are combined using
  `unsorted_segment_op`.

  The length of the row `output[i]` will be the maximum of the lengths of
  all rows of `data` whose corresponding `segment_id` is `i`. If no `data`
  rows correspond to a given segment ID, then the output row for that segment
  ID will be empty.

  Args:
    unsorted_segment_op: The tensorflow `op` that should be used to combine
      values in each row. Must have the same signature and basic behavior as
      `unsorted_segment_sum`, `unsorted_segment_max`, etc.
    data: A `RaggedTensor` containing the values to be combined.
    segment_ids: A `Tensor` or `RaggedTensor`. Must have type `int64` or
      `int32`. `segment_ids.shape` must be a prefix of `data.shape`.
      `segment_ids` is not required to be sorted.
    num_segments: An `int32` or `int64` scalar.
    name: A name prefix for the returned tensor (optional).

  Returns:
    A `RaggedTensor` containing the aggregated values. The returned tensor
    has the same dtype as `data`, and its shape is
    `[num_segments] + data.shape[segment_ids.rank:]`.

  Raises:
    ValueError: If segment_ids.shape is not a prefix of data.shape.
  """
  if not (ragged_tensor.is_ragged(data) or
          ragged_tensor.is_ragged(segment_ids)):
    return unsorted_segment_op(data, segment_ids, num_segments, name)

  with ops.name_scope(name, 'RaggedSegment',
                      [data, segment_ids, num_segments]) as name:
    data = ragged_tensor.convert_to_tensor_or_ragged_tensor(data, name='data')
    segment_ids = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        segment_ids, name='segment_ids')
    data, segment_ids = ragged_tensor.match_row_splits_dtypes(
        data, segment_ids)
    if segment_ids.dtype not in (dtypes.int32, dtypes.int64):
      raise ValueError('segment_ids must have dtype int32 or int64.')

    if ragged_tensor.is_ragged(segment_ids):
      if not ragged_tensor.is_ragged(data):
        raise ValueError('segment_ids.shape must be a prefix of data.shape, '
                         'but segment_ids is ragged and data is not.')
      check_splits = check_ops.assert_equal(
          segment_ids.row_splits,
          data.row_splits,
          message='segment_ids.shape must be a prefix of data.shape')
      with ops.control_dependencies([check_splits]):
        return _ragged_segment_aggregate(unsorted_segment_op, data.values,
                                         segment_ids.values, num_segments,
                                         name)

    # Find the length of each row in data. (shape=[data_nrows])
    data_row_lengths = data.row_splits[1:] - data.row_splits[:-1]

    # Find the length that each output row will have. The length of the row
    # corresponding to segment `id` is `max(data_row_lengths[i])` where
    # `segment_ids[i]=id`. (shape=[output_nrows])
    output_row_lengths = math_ops.maximum(
        math_ops.unsorted_segment_max(data_row_lengths, segment_ids,
                                      num_segments), 0)

    # Build the splits tensor for the output RaggedTensor.
    output_splits = array_ops.concat([
        array_ops.zeros([1], output_row_lengths.dtype),
        math_ops.cumsum(output_row_lengths)
    ], axis=0)

    # For each row in `data`, find the start & limit position where that
    # row's values will be aggregated in output.values.
    data_row_to_out_row_start = array_ops.gather(output_splits, segment_ids)
    data_row_to_out_row_limit = data_row_to_out_row_start + data_row_lengths

    # For each value in `data.values`, find the position where it will be
    # aggregated in `output.values`.
# Get the target output values index for each data values index. data_val_to_out_val_index = range(data_row_to_out_row_start, data_row_to_out_row_limit).values # Recursively aggregate the values. output_values = _ragged_segment_aggregate(unsorted_segment_op, data.values, data_val_to_out_val_index, output_splits[-1]) return ragged_tensor.RaggedTensor.from_row_splits( output_values, output_splits, validate=False)
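# --- Illustrative usage sketch (editor's addition, not original library code).
# A hedged example of `_ragged_segment_aggregate` above: rows 0 and 1 share
# segment id 0, so they are combined position-wise with unsorted_segment_sum,
# and the output row takes the longer row's length. Assumes eager execution;
# the values are invented for illustration.
def _example_ragged_segment_aggregate():
  import tensorflow as tf  # Illustrative only.
  data = tf.ragged.constant([[1, 2, 3], [4, 5], [6]])
  segment_ids = tf.constant([0, 0, 1])
  # Position-wise sums: [1 + 4, 2 + 5, 3] -> [5, 7, 3]; segment 1 keeps [6].
  return _ragged_segment_aggregate(
      tf.math.unsorted_segment_sum, data, segment_ids, num_segments=2)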