def _elementwise_where_v2(condition, x, y):
  """Ragged version of tf.where_v2(condition, x, y)."""
  # Broadcast x, y, and condition to have the same shape.
  if not (condition.shape.is_fully_defined() and x.shape.is_fully_defined() and
          y.shape.is_fully_defined() and x.shape == y.shape and
          condition.shape == x.shape):
    shape_c = ragged_tensor_shape.RaggedTensorDynamicShape.from_tensor(
        condition)
    shape_x = ragged_tensor_shape.RaggedTensorDynamicShape.from_tensor(x)
    shape_y = ragged_tensor_shape.RaggedTensorDynamicShape.from_tensor(y)
    shape = ragged_tensor_shape.broadcast_dynamic_shape(
        shape_c, ragged_tensor_shape.broadcast_dynamic_shape(shape_x, shape_y))
    condition = ragged_tensor_shape.broadcast_to(condition, shape)
    x = ragged_tensor_shape.broadcast_to(x, shape)
    y = ragged_tensor_shape.broadcast_to(y, shape)

  condition_is_ragged = isinstance(condition, ragged_tensor.RaggedTensor)
  x_is_ragged = isinstance(x, ragged_tensor.RaggedTensor)
  y_is_ragged = isinstance(y, ragged_tensor.RaggedTensor)
  if not (condition_is_ragged or x_is_ragged or y_is_ragged):
    return array_ops.where_v2(condition, x, y)

  # After broadcasting, all ragged inputs share the same nested row splits,
  # so where_v2 can be applied directly to their flat values.
  return ragged_functional_ops.map_flat_values(array_ops.where_v2, condition,
                                               x, y)
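# A minimal usage sketch (an added illustration, not part of the original
# module; assumes TF 2.x eager execution): with ragged inputs, tf.where
# dispatches to the ragged implementation above, which broadcasts all three
# arguments to a common shape and then applies where_v2 to the flat values.
import tensorflow as tf

condition = tf.ragged.constant([[True, False], [True]])
x = tf.ragged.constant([[1, 2], [3]])
y = tf.ragged.constant([[9, 8], [7]])
print(tf.where(condition, x, y))  # <tf.RaggedTensor [[1, 8], [3]]>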
def handle(self, args, kwargs):
  # Extract the binary args.
  if len(args) > 1:
    x = args[0]
    y = args[1]
    args = args[2:]
  elif args:
    kwargs = kwargs.copy()
    x = args[0]
    y = kwargs.pop(self._y, None)
    args = args[1:]
  else:
    kwargs = kwargs.copy()
    x = kwargs.pop(self._x, None)
    y = kwargs.pop(self._y, None)

  # Bail if we don't have at least one ragged argument.
  x_is_ragged = ragged_tensor.is_ragged(x)
  y_is_ragged = ragged_tensor.is_ragged(y)
  if not (x_is_ragged or y_is_ragged):
    return self.NOT_SUPPORTED

  # Convert args to tensors.  Bail if conversion fails.
  try:
    x = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        x, name=self._x, preferred_dtype=(y.dtype if y_is_ragged else None))
    y = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        y, name=self._y, preferred_dtype=(x.dtype if x_is_ragged else None))
  except (TypeError, ValueError):
    return self.NOT_SUPPORTED

  if x_is_ragged and y_is_ragged:
    x, y = ragged_tensor.match_row_splits_dtypes(x, y)

  # Broadcast the operands to a common (possibly ragged) shape whenever a
  # dense operand's rank could reach the rank of the ragged flat values.
  if ((x_is_ragged and y_is_ragged) or
      (x_is_ragged and x.flat_values.shape.ndims <= y.shape.ndims) or
      (y_is_ragged and y.flat_values.shape.ndims <= x.shape.ndims)):
    bcast_shape = ragged_tensor_shape.broadcast_dynamic_shape(
        ragged_tensor_shape.RaggedTensorDynamicShape.from_tensor(x),
        ragged_tensor_shape.RaggedTensorDynamicShape.from_tensor(y))
    x = ragged_tensor_shape.broadcast_to(
        x, bcast_shape, broadcast_inner_dimensions=False)
    y = ragged_tensor_shape.broadcast_to(
        y, bcast_shape, broadcast_inner_dimensions=False)

  x_values = x.flat_values if ragged_tensor.is_ragged(x) else x
  y_values = y.flat_values if ragged_tensor.is_ragged(y) else y
  mapped_values = self._original_op(x_values, y_values, *args, **kwargs)
  if ragged_tensor.is_ragged(x):
    return x.with_flat_values(mapped_values)
  else:
    return y.with_flat_values(mapped_values)
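# An illustration of the three extraction branches above (self._x and
# self._y hold the wrapped op's parameter names, e.g. 'x' and 'y'; the
# handler and values shown are hypothetical, not from the original module):
#
#   handler.handle((a, b), {})             # x=a, y=b  (both positional)
#   handler.handle((a,), {'y': b})         # x=a, y popped from kwargs
#   handler.handle((), {'x': a, 'y': b})   # both popped from kwargs
#
# Any remaining args/kwargs are forwarded unchanged to self._original_op,
# and NOT_SUPPORTED tells the dispatcher to fall back to the original op.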
def handle(self, args, kwargs):
  # Extract the binary args.
  if len(args) > 1:
    x = args[0]
    y = args[1]
    args = args[2:]
  elif args:
    kwargs = kwargs.copy()
    x = args[0]
    y = kwargs.pop(self._y, None)
    args = args[1:]
  else:
    kwargs = kwargs.copy()
    x = kwargs.pop(self._x, None)
    y = kwargs.pop(self._y, None)

  # Bail if we don't have at least one ragged argument.
  x_is_ragged = ragged_tensor.is_ragged(x)
  y_is_ragged = ragged_tensor.is_ragged(y)
  if not (x_is_ragged or y_is_ragged):
    return self.NOT_SUPPORTED

  # Convert args to tensors.  Bail if conversion fails.
  try:
    if not x_is_ragged:
      x = ops.convert_to_tensor(x, name=self._x, preferred_dtype=y.dtype)
    if not y_is_ragged:
      y = ops.convert_to_tensor(y, name=self._y, preferred_dtype=x.dtype)
  except (TypeError, ValueError):
    return self.NOT_SUPPORTED

  if x_is_ragged and y_is_ragged:
    x, y = ragged_tensor.match_row_splits_dtypes(x, y)

  if ((x_is_ragged and y_is_ragged) or
      (x_is_ragged and x.flat_values.shape.ndims <= y.shape.ndims) or
      (y_is_ragged and y.flat_values.shape.ndims <= x.shape.ndims)):
    bcast_shape = ragged_tensor_shape.broadcast_dynamic_shape(
        ragged_tensor_shape.RaggedTensorDynamicShape.from_tensor(x),
        ragged_tensor_shape.RaggedTensorDynamicShape.from_tensor(y))
    x = ragged_tensor_shape.broadcast_to(
        x, bcast_shape, broadcast_inner_dimensions=False)
    y = ragged_tensor_shape.broadcast_to(
        y, bcast_shape, broadcast_inner_dimensions=False)

  x_values = x.flat_values if ragged_tensor.is_ragged(x) else x
  y_values = y.flat_values if ragged_tensor.is_ragged(y) else y
  mapped_values = self._original_op(x_values, y_values, *args, **kwargs)
  if ragged_tensor.is_ragged(x):
    return x.with_flat_values(mapped_values)
  else:
    return y.with_flat_values(mapped_values)
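# A small example of why preferred_dtype matters in this variant (an added
# sketch, assuming TF 2.x eager execution): a plain Python scalar paired
# with an int64 ragged tensor is converted to int64 rather than the int32
# default, so the underlying op sees matching dtypes.
import tensorflow as tf

rt = tf.ragged.constant([[1, 2], [3]], dtype=tf.int64)
print((rt + 1).dtype)  # tf.int64, not tf.int32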
def ragged_binary_elementwise_op(op, x, y):
  """Binary elementwise API handler for RaggedTensors."""
  x_is_ragged = ragged_tensor.is_ragged(x)
  y_is_ragged = ragged_tensor.is_ragged(y)

  # Convert args to tensors.
  x = ragged_tensor.convert_to_tensor_or_ragged_tensor(
      x, preferred_dtype=(y.dtype if y_is_ragged else None))
  y = ragged_tensor.convert_to_tensor_or_ragged_tensor(
      y, preferred_dtype=x.dtype)

  if x_is_ragged and y_is_ragged:
    x, y = ragged_tensor.match_row_splits_dtypes(x, y)

  # Perform broadcasting, when appropriate.
  if ((x_is_ragged and y_is_ragged) or
      (x_is_ragged and x.flat_values.shape.ndims <= y.shape.ndims) or
      (y_is_ragged and y.flat_values.shape.ndims <= x.shape.ndims)):
    # If both x and y are ragged, they must have the same row_splits_dtype
    # at this point.
    if x_is_ragged:
      dim_size_dtype = x.row_splits.dtype
    else:
      dim_size_dtype = y.row_splits.dtype
    shape_x = ragged_tensor_shape.RaggedTensorDynamicShape.from_tensor(
        x, dim_size_dtype=dim_size_dtype)
    shape_y = ragged_tensor_shape.RaggedTensorDynamicShape.from_tensor(
        y, dim_size_dtype=dim_size_dtype)
    bcast_shape = ragged_tensor_shape.broadcast_dynamic_shape(shape_x, shape_y)
    x = ragged_tensor_shape.broadcast_to(
        x, bcast_shape, broadcast_inner_dimensions=False)
    y = ragged_tensor_shape.broadcast_to(
        y, bcast_shape, broadcast_inner_dimensions=False)

  x_values = x.flat_values if ragged_tensor.is_ragged(x) else x
  y_values = y.flat_values if ragged_tensor.is_ragged(y) else y
  mapped_values = op(x_values, y_values)
  if isinstance(mapped_values, bool):
    return mapped_values  # Special case for tensor_equals.
  if ragged_tensor.is_ragged(x):
    return x.with_flat_values(mapped_values)
  else:
    return y.with_flat_values(mapped_values)
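# A brief usage sketch (added, not from the original module): when both
# operands are ragged with identical row splits, the op runs on the flat
# values directly; a scalar operand needs no ragged broadcasting at all.
import tensorflow as tf

a = tf.ragged.constant([[1.0, 2.0], [3.0]])
b = tf.ragged.constant([[10.0, 20.0], [30.0]])
print(tf.add(a, b))  # <tf.RaggedTensor [[11.0, 22.0], [33.0]]>
print(a * 2.0)       # <tf.RaggedTensor [[2.0, 4.0], [6.0]]>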
def batch_gather_with_default(params, indices, default_value='', name=None):
  """Same as `batch_gather` but inserts `default_value` for invalid indices.

  This operation is similar to `batch_gather` except that it substitutes
  `default_value` for the result wherever an index is invalid. See
  `batch_gather` for more details.

  Args:
    params: A potentially ragged tensor with shape `[B1...BN, P1...PM]`
      (`N>=0`, `M>0`).
    indices: A potentially ragged tensor with shape `[B1...BN, I]` (`N>=0`).
    default_value: A value to be inserted in places where `indices` are out
      of bounds. Must have the same dtype as `params` and be either a scalar
      or rank 1.
    name: A name for the operation (optional).

  Returns:
    A potentially ragged tensor with shape `[B1...BN, I, P2...PM]`.
    `result.ragged_rank = max(indices.ragged_rank, params.ragged_rank)`.

  #### Example:

  ```python
  >>> params = tf.ragged.constant([['a', 'b', 'c'], ['d'], [], ['e']])
  >>> indices = tf.ragged.constant([[1, 2, -1], [], [], [0, 10]])
  >>> batch_gather_with_default(params, indices, 'FOO')
  [['b', 'c', 'FOO'], [], [], ['e', 'FOO']]
  ```
  """
  with ops.name_scope(name, 'RaggedBatchGatherWithDefault'):
    params = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        params, name='params')
    indices = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        indices, name='indices')
    default_value = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        default_value, name='default_value')
    # TODO(hterry): lift this restriction and support default_values of
    # rank > 1.
    if default_value.shape.ndims not in (0, 1):
      raise ValueError('"default_value" must be a scalar or vector')
    upper_bounds = None
    if indices.shape.ndims is None:
      raise ValueError('Indices must have a known rank.')
    if params.shape.ndims is None:
      raise ValueError('Params must have a known rank.')

    num_batch_dimensions = indices.shape.ndims - 1
    pad = None
    # The logic for this works as follows:
    # - Create a padded params, where:
    #     padded_params[b1...bn, 0] = default_value
    #     padded_params[b1...bn, i] = params[b1...bn, i-1] (i>0)
    # - Create an `upper_bounds` tensor that contains the number of elements
    #   in each innermost rank, and broadcast it to the same shape as
    #   `indices`.
    # - Check which indices in `indices` are out of bounds and substitute
    #   them with the index containing `default_value` (the first).
    # - Call batch_gather with the adjusted indices.
    with ops.control_dependencies([
        check_ops.assert_greater_equal(array_ops.rank(params),
                                       array_ops.rank(indices))]):
      if ragged_tensor.is_ragged(params):
        row_lengths = ragged_array_ops.expand_dims(
            params.row_lengths(axis=num_batch_dimensions), axis=-1)
        upper_bounds = math_ops.cast(row_lengths, indices.dtype)

        pad_shape = _get_pad_shape(params, indices)
        pad = ragged_tensor_shape.broadcast_to(default_value, pad_shape)
      else:
        params_shape = array_ops.shape(params)
        pad_shape = array_ops.concat([
            params_shape[:num_batch_dimensions], [1],
            params_shape[num_batch_dimensions + 1:params.shape.ndims]
        ], 0)
        upper_bounds = params_shape[num_batch_dimensions]
        pad = array_ops.broadcast_to(default_value, pad_shape)

      # Add `default_value` as the first value in the innermost (ragged)
      # rank.
      pad = math_ops.cast(pad, params.dtype)
      padded_params = array_ops.concat([pad, params],
                                       axis=num_batch_dimensions)

      # Adjust the indices by substituting out-of-bound indices with the
      # default-value index (which is the first element).
      shifted_indices = indices + 1
      is_out_of_bounds = (indices < 0) | (indices > upper_bounds)
      adjusted_indices = ragged_where_op.where(
          is_out_of_bounds, x=array_ops.zeros_like(indices),
          y=shifted_indices)
      return array_ops.batch_gather(
          params=padded_params, indices=adjusted_indices, name=name)
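# A usage sketch mirroring the docstring example (an added illustration;
# assumes eager execution and that batch_gather_with_default is in scope):
# the out-of-bound indices -1 and 10 are redirected to the padded slot that
# holds the default value.
import tensorflow as tf

params = tf.ragged.constant([['a', 'b', 'c'], ['d'], [], ['e']])
indices = tf.ragged.constant([[1, 2, -1], [], [], [0, 10]])
result = batch_gather_with_default(params, indices, 'FOO')
# -> [[b'b', b'c', b'FOO'], [], [], [b'e', b'FOO']]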
def testRaggedBroadcastTo(self, x, dim_sizes, expected):
  shape = RaggedTensorDynamicShape.from_dim_sizes(dim_sizes)
  result = ragged_tensor_shape.broadcast_to(x, shape)
  self.assertEqual(
      getattr(result, 'ragged_rank', 0), getattr(expected, 'ragged_rank', 0))
  self.assertAllEqual(result, expected)
def testRaggedBroadcastTo(self, x, dim_sizes, expected):
  shape = RaggedTensorDynamicShape.from_dim_sizes(dim_sizes)
  result = ragged_tensor_shape.broadcast_to(x, shape)
  self.assertEqual(
      getattr(result, 'ragged_rank', 0), getattr(expected, 'ragged_rank', 0))
  self.assertRaggedEqual(result, expected)
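# A hypothetical instance of the parameters this test receives (the values
# below are assumptions in the style of the parameterized test data, not
# taken from the suite): a dense [2, 1] tensor broadcast to a ragged shape
# with row lengths [2, 3].
x = tf.constant([[10], [20]])
shape = RaggedTensorDynamicShape.from_dim_sizes([2, [2, 3]])
result = ragged_tensor_shape.broadcast_to(x, shape)
# result == tf.ragged.constant([[10, 10], [20, 20, 20]])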
def _broadcast_elementwise_args(elementwise_args):
  """Broadcasts the values of `elementwise_args` to have compatible shapes.

  Args:
    elementwise_args: A dictionary whose values are potentially ragged
      tensors.

  Returns:
    A tuple `(broadcast_args, broadcast_splits, checks)` where:

    * `broadcast_args` is a dictionary with the same keys as
      `elementwise_args`, mapping to broadcasted tensors.
    * `broadcast_splits` is the broadcasted nested row splits.
    * `checks` is a possibly empty tuple of assertion operations that should
      be added as control dependencies.

  Raises:
    ValueError: If broadcasting fails.
  """
  # No elementwise arguments were used: nothing to do!
  if not elementwise_args:
    return elementwise_args, (), ()

  # A single elementwise argument was used: no broadcasting necessary.
  if len(elementwise_args) == 1:
    arg = list(elementwise_args.values())[0]
    if ragged_tensor.is_ragged(arg):
      return elementwise_args, arg.nested_row_splits, ()
    else:
      return elementwise_args, (), ()

  # Multiple elementwise arguments.
  else:
    is_ragged = [ragged_tensor.is_ragged(t) for t in elementwise_args.values()]
    if not any(is_ragged):
      return elementwise_args, (), ()

    # If we have a single ragged tensor plus a set of scalars, then we can
    # rely on the underlying elementwise op to do broadcasting.
    if (sum(is_ragged) == 1 and
        all((ragged_tensor.is_ragged(t) or t.shape.ndims == 0)
            for t in elementwise_args.values())):
      nested_splits_lists = [
          t.nested_row_splits
          for t in elementwise_args.values()
          if ragged_tensor.is_ragged(t)
      ][0]
      return elementwise_args, nested_splits_lists, ()

    else:
      # Get the shapes of all the elementwise arguments.
      shapes = [
          ragged_tensor_shape.RaggedTensorDynamicShape.from_tensor(t)
          for t in elementwise_args.values()
      ]

      # Broadcast the shapes to all have the same rank (the max rank).
      ranks = [t.shape.ndims for t in elementwise_args.values()]
      if any(rank is None for rank in ranks):
        raise ValueError('Unable to broadcast: unknown rank')
      broadcast_rank = max(ranks)
      shapes = [shape.broadcast_to_rank(broadcast_rank) for shape in shapes]

      # For each dimension, broadcast the shapes to be compatible.
      for axis in range(broadcast_rank):
        # For each i, broadcast shape[i+1] to be compatible with shape[i];
        # and then finally broadcast shape[0] to be compatible with
        # shape[-1].
        for i in range(len(shapes)):
          j = (i + 1) % len(shapes)
          dim_size = shapes[i].dimension_size(axis)
          shapes[j] = shapes[j].broadcast_dimension(axis, dim_size)
      broadcast_shape = shapes[0]

      # Broadcast every elementwise arg to the shape that we calculated.
      elementwise_args = dict(
          (key, ragged_tensor_shape.broadcast_to(t, broadcast_shape, False))
          for (key, t) in elementwise_args.items())
      nested_splits_lists = list(
          elementwise_args.values())[0].nested_row_splits
      return elementwise_args, nested_splits_lists, ()
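# A hedged sketch of invoking this private helper (the argument names and
# values here are hypothetical, not from the original module): a dense
# [2, 1] tensor is broadcast against its ragged companion, and the shared
# nested row splits are returned alongside the broadcast arguments.
args = {
    'x': tf.ragged.constant([[1, 2], [3]]),
    'y': tf.constant([[10], [20]]),
}
broadcast_args, splits, checks = _broadcast_elementwise_args(args)
# broadcast_args['y'] -> [[10, 10], [20]]; splits matches x's row splits.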