def testRaggedTile(self, descr, rt_input, multiples, expected,
                   ragged_rank=None):
  rt = ragged_factory_ops.constant(rt_input, ragged_rank)

  expected_shape = [
      None if dim is None else dim * multiple
      for (dim, multiple) in zip(rt.shape.as_list(), multiples)
  ]

  # Test with both const & non-const multiples: ragged_tile has a few code
  # paths that optimize the case where multiples[d] is known to be 1.
  const_multiples = constant_op.constant(multiples, dtypes.int64)
  non_const_multiples = array_ops.placeholder_with_default(
      const_multiples, shape=[len(multiples)])

  for multiples_tensor in (const_multiples, non_const_multiples):
    tiled = ragged_array_ops.tile(rt, multiples_tensor)
    self.assertEqual(tiled.ragged_rank, rt.ragged_rank)
    self.assertEqual(tiled.shape.ndims, rt.shape.ndims)
    if multiples_tensor is const_multiples:
      self.assertEqual(tiled.shape.as_list(), expected_shape)
    with self.test_session():
      self.assertEqual(tiled.eval().tolist(), expected)
def testRaggedTileWithTensorInput(self):
  # When the input is a `Tensor`, ragged_tile just delegates to tf.tile.
  dt = constant_op.constant([[1, 2], [3, 4]])
  tiled = ragged_array_ops.tile(dt, [3, 2])
  expected = [[1, 2, 1, 2], [3, 4, 3, 4],
              [1, 2, 1, 2], [3, 4, 3, 4],
              [1, 2, 1, 2], [3, 4, 3, 4]]  # pyformat: disable
  self.assertAllEqual(tiled, expected)
def testRaggedTile(self, descr, rt_input, multiples, expected,
                   ragged_rank=None):
  rt = ragged_factory_ops.constant(rt_input, ragged_rank)

  expected_shape = [
      None if dim is None else dim * multiple
      for (dim, multiple) in zip(rt.shape.as_list(), multiples)
  ]

  # Test with both const & non-const multiples: ragged_tile has a few code
  # paths that optimize the case where multiples[d] is known to be 1.
  const_multiples = constant_op.constant(multiples, dtypes.int64)
  non_const_multiples = array_ops.placeholder_with_default(
      const_multiples, shape=[len(multiples)])

  for multiples_tensor in (const_multiples, non_const_multiples):
    tiled = ragged_array_ops.tile(rt, multiples_tensor)
    self.assertEqual(tiled.ragged_rank, rt.ragged_rank)
    self.assertEqual(tiled.shape.ndims, rt.shape.ndims)
    if multiples_tensor is const_multiples:
      self.assertEqual(tiled.shape.as_list(), expected_shape)
    self.assertAllEqual(tiled, expected)
def testRaggedTileWithTensorInput(self):
  # When the input is a `Tensor`, ragged_tile just delegates to tf.tile.
  dt = constant_op.constant([[1, 2], [3, 4]])
  tiled = ragged_array_ops.tile(dt, [3, 2])
  expected = [[1, 2, 1, 2], [3, 4, 3, 4],
              [1, 2, 1, 2], [3, 4, 3, 4],
              [1, 2, 1, 2], [3, 4, 3, 4]]  # pyformat: disable
  self.assertRaggedEqual(tiled, expected)
def testRaggedTileWithTensorInput(self):
  # When the input is a `Tensor`, ragged_tile just delegates to tf.tile.
  dt = constant_op.constant([[1, 2], [3, 4]])
  tiled = ragged_array_ops.tile(dt, [3, 2])
  expected = [[1, 2, 1, 2], [3, 4, 3, 4],
              [1, 2, 1, 2], [3, 4, 3, 4],
              [1, 2, 1, 2], [3, 4, 3, 4]]  # pyformat: disable
  with self.test_session():
    self.assertEqual(tiled.eval().tolist(), expected)
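# The tests above exercise ragged_array_ops.tile, which backs the public
# tf.tile dispatch for RaggedTensor inputs. A minimal standalone sketch of the
# behavior being tested (assumes TF 2.x eager mode; the values below are
# chosen only for illustration):
import tensorflow as tf

rt = tf.ragged.constant([[1, 2], [3]])
tiled = tf.tile(rt, [3, 2])
# multiples[0]=3 repeats the rows and multiples[1]=2 repeats the values within
# each (ragged) row, so the expected result is:
#   [[1, 2, 1, 2], [3, 3], [1, 2, 1, 2], [3, 3], [1, 2, 1, 2], [3, 3]]
print(tiled)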
def _broadcast_to_ragged_shape(rt_input, dst_shape,
                               broadcast_inner_dimensions):
  """Broadcasts rt_input to the ragged shape `dst_shape`."""
  # Check that rt_input and dst_shape have the same row_splits dtype.
  if (isinstance(rt_input, ragged_tensor.RaggedTensor) and
      rt_input.row_splits.dtype != dst_shape.dim_size_dtype):
    if not ragged_config.auto_cast_partition_dtype():
      raise ValueError('rt_input and dst_shape have different row_split '
                       'dtypes; use RaggedTensor.with_row_splits_dtype() or '
                       'RaggedTensorDynamicShape.with_dim_size_dtype() to '
                       'convert to a compatible dtype.')
    rt_input = rt_input.with_row_splits_dtype(dtypes.int64)
    dst_shape = dst_shape.with_dim_size_dtype(dtypes.int64)

  # dst_shape's rank and ragged_rank must be greater than or equal to rt_input's.
  if rt_input.shape.ndims is None or dst_shape.rank is None:
    raise ValueError('Unable to broadcast: unknown rank')
  if rt_input.shape.ndims > dst_shape.rank:
    raise ValueError('Incompatible with shape: rank mismatch')
  if (isinstance(rt_input, ragged_tensor.RaggedTensor) and
      rt_input.ragged_rank >= dst_shape.num_partitioned_dimensions):
    raise ValueError('Incompatible with shape: ragged rank mismatch')

  src_shape = RaggedTensorDynamicShape.from_tensor(rt_input)
  src_shape = src_shape.broadcast_to_rank(dst_shape.rank)

  # Add dimensions to rt_input so its rank and ragged_rank matches dst_shape.
  if dst_shape.rank > rt_input.shape.ndims:
    if rt_input.shape.ndims < dst_shape.num_inner_dimensions + 1:
      rt_input = array_ops.reshape(
          rt_input, array_ops.concat([[-1], dst_shape.inner_dim_sizes], axis=0))

    for _ in range(dst_shape.rank - rt_input.shape.ndims):
      if ragged_tensor.is_ragged(rt_input):
        nrows = rt_input.nrows()
      else:
        nrows = array_ops.shape(rt_input, out_type=dst_shape.dim_size_dtype)[0]
      rt_input = ragged_tensor.RaggedTensor.from_row_lengths(
          rt_input, [nrows], validate=False)

  # Add ragged dimensions to match dst_shape.
  if ragged_tensor.is_ragged(rt_input):
    inner_rank_diff = (
        rt_input.flat_values.shape.ndims - 1 - dst_shape.num_inner_dimensions)
    if inner_rank_diff > 0:
      rt_input = rt_input.with_flat_values(
          ragged_tensor.RaggedTensor.from_tensor(
              rt_input.flat_values,
              ragged_rank=inner_rank_diff,
              row_splits_dtype=dst_shape.dim_size_dtype))
  else:
    rt_input = ragged_tensor.RaggedTensor.from_tensor(
        rt_input,
        ragged_rank=dst_shape.num_partitioned_dimensions - 1,
        row_splits_dtype=dst_shape.dim_size_dtype)

  # Do broadcasting for any dimensions that will remain uniform.  We can do
  # these all at once, since they're independent of one another.
  multiples = [1] * dst_shape.rank
  for axis in range(dst_shape.num_partitioned_dimensions):
    if not src_shape.is_ragged(axis) and not dst_shape.is_ragged(axis):
      src_size = src_shape.dimension_size(axis)
      dst_size = dst_shape.dimension_size(axis)
      if ((tensor_util.constant_value(src_size) in (1, None)) and
          (tensor_util.constant_value(dst_size) != 1)):
        multiples[axis] = array_ops.where(
            math_ops.equal(src_size, 1), dst_size, 1)

  if not all(isinstance(v, int) and v == 1 for v in multiples):
    multiples = array_ops.stack(multiples, axis=0)
    rt_input = ragged_array_ops.tile(rt_input, multiples)

  if broadcast_inner_dimensions:
    rt_input = rt_input.with_flat_values(
        array_ops.reshape(
            rt_input.flat_values,
            array_ops.concat([[-1], dst_shape.inner_dim_sizes], axis=0)))

  # Do broadcasting for dimensions that become ragged.  We must do these from
  # outermost to innermost.
  for axis in range(dst_shape.num_partitioned_dimensions):
    if not src_shape.is_ragged(axis) and dst_shape.is_ragged(axis):
      dst_size = dst_shape.dimension_size(axis)
      rt_input = _ragged_tile_axis(rt_input, axis, dst_size,
                                   dst_shape.dim_size_dtype)

  return rt_input
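# _broadcast_to_ragged_shape is an internal helper; the broadcasting it
# implements (tiling dimensions that stay uniform, then expanding dimensions
# that become ragged) is what lets elementwise ops on RaggedTensors broadcast.
# A small sketch of the observable behavior through the public API (assumes
# TF 2.x eager mode; values are illustrative):
import tensorflow as tf

x = tf.ragged.constant([[1, 2], [3]])   # shape [2, (ragged)]
y = tf.constant([[10], [20]])           # shape [2, 1]
# y's size-1 inner dimension is broadcast against each ragged row of x,
# giving [[11, 12], [23]].
print(x + y)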
def gather_nd(params, indices, batch_dims=0, name=None):
  """Gather slices from `params` using `n`-dimensional indices.

  This operation is similar to `gather`, but it uses the innermost dimension
  of `indices` to define a slice into `params`.  In particular, if:

  * `indices` has shape `[A1...AN, I]`
  * `params` has shape `[B1...BM]`

  Then:

  * `result` has shape `[A1...AN, B_{I+1}...BM]`.
  * `result[a1...aN] = params[indices[a1...aN, :]]`

  Args:
    params: A potentially ragged tensor with shape `[B1...BM]`.
    indices: A potentially ragged tensor with shape `[A1...AN, I]`.
    batch_dims: Must be zero.
    name: A name for the operation (optional).

  Returns:
    A potentially ragged tensor with shape `[A1...AN, B_{I+1}...BM]`.

  #### Examples:
    ```python
    >>> params = tf.compat.v1.ragged.constant_value(
    ...     [ [ ['000', '001'], ['010'              ]          ],
    ...       [ ['100'       ], ['110', '111', '112'], ['120'] ],
    ...       [ [            ], ['210'              ]          ] ])

    >>> # Gather 2D slices from a 3D tensor
    >>> ragged.gather_nd(params, [[2], [0]])
    [ [ [            ], ['210'] ]
      [ ['000', '001'], ['010'] ] ]

    >>> # Gather 1D slices from a 3D tensor
    >>> ragged.gather_nd(params, [[2, 1], [0, 0]])
    [['210'], ['000', '001']]

    >>> # Gather scalars from a 3D tensor
    >>> ragged.gather_nd(params, [[0, 0, 1], [1, 1, 2]])
    ['001', '112']
    ```
  """
  if not isinstance(batch_dims, int) or batch_dims != 0:
    raise ValueError('batch_dims != 0 is not supported for ragged gather yet.')
  if not (ragged_tensor.is_ragged(params) or ragged_tensor.is_ragged(indices)):
    return array_ops.gather_nd(params, indices, name)

  with ops.name_scope(name, 'RaggedGatherNd', [params, indices]):

    params = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        params, name='params')
    indices = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        indices, name='indices')
    params, indices = ragged_tensor.match_row_splits_dtypes(params, indices)
    indices_shape = indices.shape
    indices_ndims = indices_shape.ndims
    if indices_ndims is None:
      raise ValueError('indices.rank must be statically known.')
    if indices_ndims == 0:
      raise ValueError('indices.rank must be at least 1.')
    if (ragged_tensor.is_ragged(indices) and
        indices_ndims == indices.ragged_rank + 1):
      raise ValueError('The innermost dimension of indices may not be ragged')

    # `index_size` is the "n" in "gather_nd" -- i.e., the number of dimensions
    # that each index slices into.
    index_size = tensor_shape.dimension_value(indices_shape[-1])
    if index_size is None:
      raise ValueError('indices.shape[-1] must be statically known.')

    # If `indices` has more than 2 dimensions, then recurse.  If `indices` is
    # dense, then we convert it to ragged before recursing, and then convert
    # the result back to `dense` if appropriate.
    if indices_ndims > 2:
      indices_is_dense = not ragged_tensor.is_ragged(indices)
      if indices_is_dense:
        indices = ragged_tensor.RaggedTensor.from_tensor(
            indices, ragged_rank=indices_ndims - 2,
            row_splits_dtype=params.row_splits.dtype)
      result = indices.with_flat_values(gather_nd(params, indices.flat_values))
      if (indices_is_dense and ragged_tensor.is_ragged(result) and
          result.ragged_rank == indices_ndims - 2):
        result = ragged_tensor.RaggedTensor.to_tensor(result)
      return result

    # indices_ndims <= 2, and the innermost dimension of indices may not be
    # ragged, so `indices` must not be ragged.
    assert not ragged_tensor.is_ragged(indices)
    assert ragged_tensor.is_ragged(params)

    # Handle corner case: An empty index tuple selects the entire `params`
    # value.  So if `index_size` is zero, then tile `params`.
    if index_size == 0:
      params_ndims = params.ragged_rank + array_ops.rank(params.flat_values)
      for dim in range(indices_ndims - 1):
        params = ragged_array_ops.expand_dims(params, axis=0)
      multiples = array_ops.concat([
          array_ops.shape(indices)[:-1],
          array_ops.ones([params_ndims], dtypes.int32)
      ], axis=0)
      return ragged_array_ops.tile(params, multiples)

    # When index_size=1, we can just flatten the index tuples and use gather.
    elif index_size == 1:
      flattened_index_tuples = array_ops.reshape(indices, [-1])
      return gather(params, flattened_index_tuples)

    # Otherwise, params is a RaggedTensor, and indices is a 1D or 2D Tensor.
    # Flatten both the index tuples and the params, such that the flattened
    # index tuples point to the correct values in the flattened params; and
    # then use ragged.gather on the flattened index tuples & params.
    else:
      indices = math_ops.cast(indices, params.row_splits.dtype)

      # Flatten the outermost 2 dimensions of the index tuples & params.
      flattened_index_tuples = array_ops.gather(params.row_splits,
                                                indices[..., 0])
      flattened_index_tuples += indices[..., 1]
      flattened_params = params.values

      # Flatten any remaining dimensions.
      for dim in range(2, index_size):
        if not ragged_tensor.is_ragged(flattened_params):
          flattened_index_tuples = array_ops.expand_dims(
              flattened_index_tuples, axis=1)
          flattened_index_tuples = array_ops.concat(
              [flattened_index_tuples, indices[..., dim:]], axis=1)
          return array_ops.gather_nd(flattened_params, flattened_index_tuples)

        flattened_index_tuples = array_ops.gather(
            flattened_params.row_starts(), flattened_index_tuples)
        flattened_index_tuples += indices[..., dim]
        flattened_params = flattened_params.values

      # Gather using the flattened index tuples and params.
      return gather(flattened_params, flattened_index_tuples)
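# A worked illustration (plain Python; the values are illustrative, not from
# the tests above) of the flattening step in the `else` branch for
# index_size == 2: a 2-D index (i, j) into a ragged `params` becomes a single
# index into `params.values` via row_splits[i] + j.
row_splits = [0, 2, 3, 6]                 # params = [['a', 'b'], ['c'], ['d', 'e', 'f']]
values = ['a', 'b', 'c', 'd', 'e', 'f']   # params.values

def flatten_index(i, j):
  # Mirrors: array_ops.gather(params.row_splits, indices[..., 0]) + indices[..., 1]
  return row_splits[i] + j

assert values[flatten_index(0, 1)] == 'b'  # params[0][1]
assert values[flatten_index(2, 1)] == 'e'  # params[2][1]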
def gather_nd(params, indices, batch_dims=0, name=None):
  """Gather slices from `params` using `n`-dimensional indices.

  This operation is similar to `gather`, but it uses the innermost dimension
  of `indices` to define a slice into `params`.  In particular, if:

  * `indices` has shape `[A1...AN, I]`
  * `params` has shape `[B1...BM]`

  Then:

  * `result` has shape `[A1...AN, B_{I+1}...BM]`.
  * `result[a1...aN] = params[indices[a1...aN, :]]`

  Args:
    params: A potentially ragged tensor with shape `[B1...BM]`.
    indices: A potentially ragged tensor with shape `[A1...AN, I]`.
    batch_dims: Must be zero.
    name: A name for the operation (optional).

  Returns:
    A potentially ragged tensor with shape `[A1...AN, B_{I+1}...BM]`.

  #### Examples:

  >>> params = tf.ragged.constant(
  ...     [ [ ['000', '001'], ['010'              ]          ],
  ...       [ ['100'       ], ['110', '111', '112'], ['120'] ],
  ...       [ [            ], ['210'              ]          ] ])

  >>> # Gather 2D slices from a 3D tensor
  >>> tf.gather_nd(params, [[2], [0]])
  <tf.RaggedTensor [[[], [b'210']], [[b'000', b'001'], [b'010']]]>

  >>> # Gather 1D slices from a 3D tensor
  >>> tf.gather_nd(params, [[2, 1], [0, 0]])
  <tf.RaggedTensor [[b'210'], [b'000', b'001']]>

  >>> # Gather scalars from a 3D tensor
  >>> tf.gather_nd(params, [[0, 0, 1], [1, 1, 2]]).numpy()
  array([b'001', b'112'], dtype=object)
  """
  if not isinstance(batch_dims, int) or batch_dims != 0:
    raise ValueError('batch_dims != 0 is not supported for ragged gather yet.')
  if not (ragged_tensor.is_ragged(params) or ragged_tensor.is_ragged(indices)):
    return array_ops.gather_nd(params, indices, name)

  with ops.name_scope(name, 'RaggedGatherNd', [params, indices]):

    params = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        params, name='params')
    indices = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        indices, name='indices')
    params, indices = ragged_tensor.match_row_splits_dtypes(params, indices)
    indices_shape = indices.shape
    indices_ndims = indices_shape.ndims
    if indices_ndims is None:
      raise ValueError('indices.rank must be statically known.')
    if indices_ndims == 0:
      raise ValueError('indices.rank must be at least 1.')
    if (ragged_tensor.is_ragged(indices) and
        indices_ndims == indices.ragged_rank + 1):
      raise ValueError('The innermost dimension of indices may not be ragged')

    # `index_size` is the "n" in "gather_nd" -- i.e., the number of dimensions
    # that each index slices into.
    index_size = tensor_shape.dimension_value(indices_shape[-1])
    if index_size is None:
      raise ValueError('indices.shape[-1] must be statically known.')

    # If `indices` has more than 2 dimensions, then recurse.  If `indices` is
    # dense, then we convert it to ragged before recursing, and then convert
    # the result back to `dense` if appropriate.
    if indices_ndims > 2:
      indices_is_dense = not ragged_tensor.is_ragged(indices)
      if indices_is_dense:
        indices = ragged_tensor.RaggedTensor.from_tensor(
            indices, ragged_rank=indices_ndims - 2,
            row_splits_dtype=params.row_splits.dtype)
      result = indices.with_flat_values(gather_nd(params, indices.flat_values))
      if (indices_is_dense and ragged_tensor.is_ragged(result) and
          result.ragged_rank == indices_ndims - 2):
        result = ragged_tensor.RaggedTensor.to_tensor(result)
      return result

    # indices_ndims <= 2, and the innermost dimension of indices may not be
    # ragged, so `indices` must not be ragged.
    assert not ragged_tensor.is_ragged(indices)
    assert ragged_tensor.is_ragged(params)

    # Handle corner case: An empty index tuple selects the entire `params`
    # value.  So if `index_size` is zero, then tile `params`.
    if index_size == 0:
      params_ndims = params.ragged_rank + array_ops.rank(params.flat_values)
      for dim in range(indices_ndims - 1):
        params = ragged_array_ops.expand_dims(params, axis=0)
      multiples = array_ops.concat([
          array_ops.shape(indices)[:-1],
          array_ops.ones([params_ndims], dtypes.int32)
      ], axis=0)
      return ragged_array_ops.tile(params, multiples)

    # When index_size=1, we can just flatten the index tuples and use gather.
    elif index_size == 1:
      flattened_index_tuples = array_ops.reshape(indices, [-1])
      return gather(params, flattened_index_tuples)

    # Otherwise, params is a RaggedTensor, and indices is a 1D or 2D Tensor.
    # Flatten both the index tuples and the params, such that the flattened
    # index tuples point to the correct values in the flattened params; and
    # then use ragged.gather on the flattened index tuples & params.
    else:
      indices = math_ops.cast(indices, params.row_splits.dtype)

      # Flatten the outermost 2 dimensions of the index tuples & params.
      flattened_index_tuples = array_ops.gather(params.row_splits,
                                                indices[..., 0])
      flattened_index_tuples += indices[..., 1]
      flattened_params = params.values

      # Flatten any remaining dimensions.
      for dim in range(2, index_size):
        if not ragged_tensor.is_ragged(flattened_params):
          flattened_index_tuples = array_ops.expand_dims(
              flattened_index_tuples, axis=1)
          flattened_index_tuples = array_ops.concat(
              [flattened_index_tuples, indices[..., dim:]], axis=1)
          return array_ops.gather_nd(flattened_params, flattened_index_tuples)

        flattened_index_tuples = array_ops.gather(
            flattened_params.row_starts(), flattened_index_tuples)
        flattened_index_tuples += indices[..., dim]
        flattened_params = flattened_params.values

      # Gather using the flattened index tuples and params.
      return gather(flattened_params, flattened_index_tuples)
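# The `index_size == 0` corner case above (an empty index tuple selects all of
# `params`) can be exercised through the public tf.gather_nd dispatch for
# RaggedTensor params. A minimal sketch (assumes TF 2.x eager mode; values are
# illustrative, and the expected output follows the tiling logic above rather
# than a documented example):
import tensorflow as tf

params = tf.ragged.constant([[1, 2], [3]])
indices = tf.zeros([2, 0], dtype=tf.int32)  # two empty index tuples
# Each empty index tuple selects all of `params`, so the expected result is
# `params` stacked twice: [[[1, 2], [3]], [[1, 2], [3]]].
print(tf.gather_nd(params, indices))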