def testTensorParamsAndTensorIndices(self):
  params = ['a', 'b', 'c', 'd', 'e']
  indices = [2, 0, 2, 1]
  self.assertAllEqual(
      ragged_gather_ops.gather(params, indices), [b'c', b'a', b'c', b'b'])
  self.assertIsInstance(ragged_gather_ops.gather(params, indices), ops.Tensor)
def _broadcast_ragged_targets_for_overlap(target_start, target_limit,
                                          source_splits):
  """Repeats target indices for each source item in the same batch.

  Args:
    target_start: `<int>[batch_size, (target_size)]`
    target_limit: `<int>[batch_size, (target_size)]`
    source_splits: `<int64>[batch_size, (source_size+1)]`

  Returns:
    `<int>[batch_size, (source_size), (target_size)]`.
    A tuple of ragged tensors `(tiled_target_start, tiled_target_limit)` where:

    * `tiled_target_start[b, s, t] = target_start[b, t]`
    * `tiled_target_limit[b, s, t] = target_limit[b, t]`
  """
  source_batch_ids = segment_id_ops.row_splits_to_segment_ids(source_splits)
  target_start = ragged_tensor.RaggedTensor.from_value_rowids(
      ragged_gather_ops.gather(target_start, source_batch_ids),
      source_batch_ids)
  target_limit = ragged_tensor.RaggedTensor.from_value_rowids(
      ragged_gather_ops.gather(target_limit, source_batch_ids),
      source_batch_ids)
  return (target_start, target_limit)
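
# Illustrative sketch (my addition, using the public tf.ragged API rather than
# the internal modules above): tiling one batch's targets once per source item
# in that batch. All values here are hypothetical.
import tensorflow as tf

source_splits = tf.constant([0, 2, 3], dtype=tf.int64)  # batch 0: 2 sources; batch 1: 1
target_start = tf.ragged.constant([[0, 5], [7]])
# One batch id per source item: [0, 0, 1].
source_batch_ids = tf.ragged.row_splits_to_segment_ids(source_splits)
# Gather each batch's row of targets once per source item in that batch,
# then regroup the gathered rows by source item.
tiled = tf.RaggedTensor.from_value_rowids(
    tf.gather(target_start, source_batch_ids), source_batch_ids)
print(tiled)  # [[[0, 5], [0, 5]], [[7]]]
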
def testReturnNbestAndDetokenize(self):
  sp = SentencepieceTokenizer(
      self.model, nbest_size=2, out_type=dtypes.int32, return_nbest=True)
  sentences = ['I love carpet', 'Never tell me the odds']
  result = sp.tokenize(ragged_factory_ops.constant(sentences))
  detokenized = sp.detokenize(result)
  self.assertAllEqual(
      _utf8(sentences), ragged_gather_ops.gather(detokenized, [0, 2]))
  self.assertAllEqual(
      _utf8(sentences), ragged_gather_ops.gather(detokenized, [1, 3]))
def testDocStringExamples(self):
  params = constant_op.constant(['a', 'b', 'c', 'd', 'e'])
  indices = constant_op.constant([3, 1, 2, 1, 0])
  ragged_params = ragged_factory_ops.constant([['a', 'b', 'c'], ['d'], [],
                                               ['e']])
  ragged_indices = ragged_factory_ops.constant([[3, 1, 2], [1], [], [0]])
  self.assertAllEqual(
      ragged_gather_ops.gather(params, ragged_indices),
      [[b'd', b'b', b'c'], [b'b'], [], [b'a']])
  self.assertAllEqual(
      ragged_gather_ops.gather(ragged_params, indices),
      [[b'e'], [b'd'], [], [b'd'], [b'a', b'b', b'c']])
  self.assertAllEqual(
      ragged_gather_ops.gather(ragged_params, ragged_indices),
      [[[b'e'], [b'd'], []], [[b'd']], [], [[b'a', b'b', b'c']]])
def testOutOfBoundsError(self):
  tensor_params = ['a', 'b', 'c']
  tensor_indices = [0, 1, 2]
  ragged_params = ragged_factory_ops.constant([['a', 'b'], ['c']])
  ragged_indices = ragged_factory_ops.constant([[0, 3]])
  with self.assertRaisesRegexp(errors.InvalidArgumentError,
                               r'indices\[1\] = 3 is not in \[0, 3\)'):
    self.evaluate(ragged_gather_ops.gather(tensor_params, ragged_indices))
  with self.assertRaisesRegexp(errors.InvalidArgumentError,
                               r'indices\[2\] = 2 is not in \[0, 2\)'):
    self.evaluate(ragged_gather_ops.gather(ragged_params, tensor_indices))
  with self.assertRaisesRegexp(errors.InvalidArgumentError,
                               r'indices\[1\] = 3 is not in \[0, 2\)'):
    self.evaluate(ragged_gather_ops.gather(ragged_params, ragged_indices))
def batch_gather(params: ragged_tensor.RaggedOrDense,
                 indices: ragged_tensor.RaggedOrDense,
                 name=None):
  """Gathers slices from `params` according to `indices` with batch dims.

  This operation is similar to `gather`, but it assumes that the leading `N`
  dimensions of `indices` and `params` are batch dimensions, and performs a
  gather within each batch.  In particular, when using this operation with `N`
  batch dimensions `B1...BN`:

  * `indices` has shape `[B1...BN, I]`
  * `params` has shape `[B1...BN, P1...PM]`.
  * `result` has shape `[B1...BN, I, P2...PM]`.
  * `result[b1...bN, i, p2...pM] =
    params[b1...bN, indices[b1...bN, i], p2...pM]`

  Args:
    params: A potentially ragged tensor with shape `[B1...BN, P1...PM]` (`N>=0`,
      `M>0`).
    indices: A potentially ragged tensor with shape `[B1...BN, I]` (`N>=0`).
    name: A name for the operation (optional).

  Returns:
    A potentially ragged tensor with shape `[B1...BN, I, P2...PM]`.
    `result.ragged_rank = max(indices.ragged_rank, params.ragged_rank)`.

  #### Example:

  >>> params = tf.ragged.constant([['a', 'b', 'c'], ['d'], [], ['e']])
  >>> indices = tf.ragged.constant([[1, 2, 0], [], [], [0, 0]])
  >>> tf.compat.v1.batch_gather(params, indices)
  <tf.RaggedTensor [[b'b', b'c', b'a'], [], [], [b'e', b'e']]>
  """
  return ragged_gather_ops.gather(params, indices, batch_dims=-1, name=name)
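
# Usage sketch (my addition): for the rank-2 ragged inputs in the docstring
# example, `batch_dims=-1` amounts to a within-batch gather, which can also be
# expressed through the public `tf.gather` API with `batch_dims=1`.
import tensorflow as tf

params = tf.ragged.constant([['a', 'b', 'c'], ['d'], [], ['e']])
indices = tf.ragged.constant([[1, 2, 0], [], [], [0, 0]])
result = tf.gather(params, indices, batch_dims=1)
print(result)  # [[b'b', b'c', b'a'], [], [], [b'e', b'e']]
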
def testRaggedParamsAndTensorIndices(self):
  params = ragged_factory_ops.constant([['a', 'b'], ['c', 'd', 'e'], ['f'], [],
                                        ['g']])
  indices = [2, 0, 2, 1]
  self.assertAllEqual(
      ragged_gather_ops.gather(params, indices),
      [[b'f'], [b'a', b'b'], [b'f'], [b'c', b'd', b'e']])
def testGradient(self,
                 params,
                 indices,
                 expected_out,
                 out_grad,
                 expected_grad,
                 params_ragged_rank=None):
  """Tests that ragged_gather generates the right gradient.

  Args:
    params: The `params` that should be passed to `gather`.
    indices: The `indices` that should be passed to `gather`.
    expected_out: The expected value of `gather(params, indices)`.
      `expected_out.shape = indices.shape + params.shape[1:]`.
    out_grad: The value that should be fed in as the gradient for `out` when
      testing the gradient of `ragged_gather`.  Must have the same shape as
      `expected_out`.
    expected_grad: The expected gradient that should be returned for `params`.
      Must have the same shape as `params`.
    params_ragged_rank: The ragged_rank of `params`.
  """
  if context.executing_eagerly():
    return

  params = ragged_factory_ops.constant(
      params, dtype=dtypes.float32, ragged_rank=params_ragged_rank)
  indices = constant_op.constant(indices, dtype=dtypes.int32)
  out_ragged_rank = params.ragged_rank + indices.shape.ndims - 1
  out_grad = ragged_factory_ops.constant(
      out_grad, dtype=dtypes.float32, ragged_rank=out_ragged_rank)
  expected_out = ragged_factory_ops.constant(
      expected_out, dtype=dtypes.float32, ragged_rank=out_ragged_rank)
  expected_grad = ragged_factory_ops.constant(
      expected_grad, dtype=dtypes.float32, ragged_rank=params.ragged_rank)

  out = ragged_gather_ops.gather(params, indices)
  self.assertAllClose(out, expected_out)

  grads = gradients_impl.gradients(
      out.flat_values,
      (params.nested_row_splits + (params.flat_values, indices,)),
      out_grad.flat_values)
  param_nested_splits_grads = grads[:-2]
  params_flat_values_grad = grads[-2]
  indices_grad = grads[-1]
  self.assertEqual(indices_grad, None)
  for splits_grad in param_nested_splits_grads:
    self.assertEqual(splits_grad, None)

  # The gradient generates an IndexedSlices; convert back to a normal Tensor.
  self.assertIsInstance(params_flat_values_grad, indexed_slices.IndexedSlices)
  params_flat_values_grad = ops.convert_to_tensor(params_flat_values_grad)

  params_grad = params.with_flat_values(params_flat_values_grad)
  self.assertAllClose(params_grad, expected_grad, atol=2e-6, rtol=2e-6)
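
# Eager-mode sketch (my addition): the same gradient check performed with
# tf.GradientTape and the public API; the values here are hypothetical.
import tensorflow as tf

params = tf.ragged.constant([[1.0, 2.0], [3.0]])
indices = tf.constant([1, 0, 1])
with tf.GradientTape() as tape:
  tape.watch(params.flat_values)
  out = tf.gather(params, indices)
  loss = tf.reduce_sum(out.flat_values)
# The gather gradient arrives as IndexedSlices; densify it to inspect.
grad = tf.convert_to_tensor(tape.gradient(loss, params.flat_values))
# Each param value gets one unit of gradient per time its row was gathered.
print(grad)  # [1.0, 1.0, 2.0]
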
def _ragged_gather_v1(params,
                      indices,
                      validate_indices=None,
                      name=None,
                      axis=0,
                      batch_dims=0):
  return ragged_gather_ops.gather(
      params=params,
      indices=indices,
      validate_indices=validate_indices,
      axis=axis,
      batch_dims=batch_dims,
      name=name)
def testRaggedParamsAndRaggedIndices(self):
  params = ragged_factory_ops.constant([['a', 'b'], ['c', 'd', 'e'], ['f'], [],
                                        ['g']])
  indices = ragged_factory_ops.constant([[2, 1], [1, 2, 0], [3]])
  self.assertAllEqual(
      ragged_gather_ops.gather(params, indices),
      [[[b'f'], [b'c', b'd', b'e']],                # [[p[2], p[1]],
       [[b'c', b'd', b'e'], [b'f'], [b'a', b'b']],  #  [p[1], p[2], p[0]],
       [[]]]                                        #  [p[3]]]
  )  # pyformat: disable
def test3DRaggedParamsAnd2DTensorIndices(self):
  params = ragged_factory_ops.constant([[['a', 'b'], []],
                                        [['c', 'd'], ['e'], ['f']], [['g']]])
  indices = [[1, 2], [0, 1], [2, 2]]
  self.assertAllEqual(
      ragged_gather_ops.gather(params, indices),
      [[[[b'c', b'd'], [b'e'], [b'f']], [[b'g']]],            # [[p1, p2],
       [[[b'a', b'b'], []], [[b'c', b'd'], [b'e'], [b'f']]],  #  [p0, p1],
       [[[b'g']], [[b'g']]]]                                  #  [p2, p2]]
  )  # pyformat: disable
def testTensorParamsAnd4DRaggedIndices(self):
  indices = ragged_factory_ops.constant(
      [[[[3, 4], [0, 6]], []],
       [[[2, 1], [1, 0]], [[2, 5]], [[2, 3]]],
       [[[1, 0]]]],  # pyformat: disable
      ragged_rank=2,
      inner_shape=(2,))
  params = ['a', 'b', 'c', 'd', 'e', 'f', 'g']
  self.assertAllEqual(
      ragged_gather_ops.gather(params, indices),
      [[[[b'd', b'e'], [b'a', b'g']], []],
       [[[b'c', b'b'], [b'b', b'a']], [[b'c', b'f']], [[b'c', b'd']]],
       [[[b'b', b'a']]]])  # pyformat: disable
def _ragged_stack_concat_axis_1(rt_inputs, stack_values):
  """Helper function to concatenate or stack ragged tensors along axis 1.

  Args:
    rt_inputs: A list of RaggedTensors, all with the same rank and ragged_rank.
    stack_values: Boolean.  If true, then stack values; otherwise, concatenate
      them.

  Returns:
    A RaggedTensor.
  """
  num_inputs = len(rt_inputs)

  rt_nrows = rt_inputs[0].nrows()
  nrows_msg = 'Input tensors have incompatible shapes.'
  nrows_checks = [
      check_ops.assert_equal(rt.nrows(), rt_nrows, message=nrows_msg)
      for rt in rt_inputs[1:]
  ]

  with ops.control_dependencies(nrows_checks):
    # Concatenate the inputs together to put them in a single ragged tensor.
    concatenated_rt = _ragged_stack_concat_axis_0(rt_inputs, stack_values=False)

    # Use ragged.gather to permute the rows of concatenated_rt.  In particular,
    #   permuted_rt = [rt_inputs[0][0], ..., rt_inputs[N][0],
    #                  rt_inputs[0][1], ..., rt_inputs[N][1],
    #                  ...,
    #                  rt_inputs[0][M], ..., rt_input[N][M]]
    # where `N=num_inputs-1` and `M=rt_nrows-1`.
    row_indices = math_ops.range(rt_nrows * num_inputs)
    row_index_matrix = array_ops.reshape(row_indices, [num_inputs, -1])
    transposed_row_index_matrix = array_ops.transpose(row_index_matrix)
    row_permutation = array_ops.reshape(transposed_row_index_matrix, [-1])
    permuted_rt = ragged_gather_ops.gather(concatenated_rt, row_permutation)

    if stack_values:
      # Add a new splits tensor to group together the values.
      stack_splits = math_ops.range(0, rt_nrows * num_inputs + 1, num_inputs)
      _copy_row_shape(rt_inputs, stack_splits)
      return ragged_tensor.RaggedTensor.from_row_splits(
          permuted_rt, stack_splits, validate=False)
    else:
      # Merge together adjacent rows by dropping the row-split indices that
      # separate them.
      concat_splits = permuted_rt.row_splits[::num_inputs]
      _copy_row_shape(rt_inputs, concat_splits)
      return ragged_tensor.RaggedTensor.from_row_splits(
          permuted_rt.values, concat_splits, validate=False)
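
# Illustrative sketch (my addition): the row-permutation trick in isolation,
# using the public API. Rows of two axis-0-concatenated ragged tensors are
# interleaved, then adjacent rows are merged by dropping row splits, which
# yields concatenation along axis 1.
import tensorflow as tf

a = tf.ragged.constant([[1, 2], [3]])
b = tf.ragged.constant([[4], [5, 6]])
concatenated = tf.concat([a, b], axis=0)       # [[1, 2], [3], [4], [5, 6]]
num_inputs, nrows = 2, 2
perm = tf.reshape(
    tf.transpose(tf.reshape(tf.range(nrows * num_inputs), [num_inputs, -1])),
    [-1])                                      # [0, 2, 1, 3]
interleaved = tf.gather(concatenated, perm)    # [[1, 2], [4], [3], [5, 6]]
merged = tf.RaggedTensor.from_row_splits(
    interleaved.values, interleaved.row_splits[::num_inputs])
print(merged)  # [[1, 2, 4], [3, 5, 6]]
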
def _build_ragged_tensor_from_value_ranges(starts, limits, step, values):
  """Returns a `RaggedTensor` containing the specified sequences of values.

  Returns a RaggedTensor `output` where:

  ```python
  output.shape[0] = starts.shape[0]
  output[i] = values[starts[i]:limits[i]:step]
  ```

  Requires that `starts.shape == limits.shape` and
  `0 <= starts[i] <= limits[i] <= values.shape[0]`.

  Args:
    starts: 1D integer Tensor specifying the start indices for the sequences
      of values to include.
    limits: 1D integer Tensor specifying the limit indices for the sequences
      of values to include.
    step: Integer value specifying the step size for strided slices.
    values: The set of values to select from.

  Returns:
    A `RaggedTensor`.

  Raises:
    ValueError: Until the prerequisite ops are checked in.
  """
  # Use `ragged_range` to get the index of each value we should include.
  if step is None:
    step = 1
  step = ops.convert_to_tensor(step, name="step")
  if step.dtype.is_integer:
    step = math_ops.cast(step, starts.dtype)
  else:
    raise TypeError("slice strides must be integers or None")
  value_indices = ragged_math_ops.range(
      starts, limits, step, row_splits_dtype=starts.dtype)

  # Use `ragged_gather` or `array_ops.gather` to collect the values.
  if isinstance(values, ragged_tensor.RaggedTensor):
    gathered_values = ragged_gather_ops.gather(
        params=values, indices=value_indices.values)
  else:
    gathered_values = array_ops.gather(
        params=values, indices=value_indices.values)

  # Assemble the RaggedTensor from splits & values.
  return value_indices.with_values(gathered_values)
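
# Usage sketch (my addition): the same range-then-gather construction through
# the public API, producing output[i] = values[starts[i]:limits[i]] here
# (step=1 for brevity).
import tensorflow as tf

values = tf.constant([10, 11, 12, 13, 14, 15])
starts = tf.constant([0, 3, 2])
limits = tf.constant([2, 6, 2])
value_indices = tf.ragged.range(starts, limits)  # [[0, 1], [3, 4, 5], []]
gathered = tf.gather(values, value_indices.values)
print(value_indices.with_values(gathered))       # [[10, 11], [13, 14, 15], []]
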
def _build_ragged_tensor_from_value_ranges(starts, limits, step, values):
  """Returns a `RaggedTensor` containing the specified sequences of values.

  Returns a RaggedTensor `output` where:

  ```python
  output.shape[0] = starts.shape[0]
  output[i] = values[starts[i]:limits[i]:step]
  ```

  Requires that `starts.shape == limits.shape` and
  `0 <= starts[i] <= limits[i] <= values.shape[0]`.

  Args:
    starts: 1D integer Tensor specifying the start indices for the sequences
      of values to include.
    limits: 1D integer Tensor specifying the limit indices for the sequences
      of values to include.
    step: Integer value specifying the step size for strided slices.
    values: The set of values to select from.

  Returns:
    A `RaggedTensor`.

  Raises:
    ValueError: Until the prerequisite ops are checked in.
  """
  # Use `ragged_range` to get the index of each value we should include.
  if step is None:
    step = 1
  step = ops.convert_to_tensor(step, name="step")
  if step.dtype.is_integer:
    step = math_ops.cast(step, dtypes.int64)
  else:
    raise TypeError("slice strides must be integers or None")
  value_indices = ragged_math_ops.range(starts, limits, step)

  # Use `ragged_gather` or `array_ops.gather` to collect the values.
  if isinstance(values, ragged_tensor.RaggedTensor):
    gathered_values = ragged_gather_ops.gather(
        params=values, indices=value_indices.values)
  else:
    gathered_values = array_ops.gather(
        params=values, indices=value_indices.values)

  # Assemble the RaggedTensor from splits & values.
  return value_indices.with_values(gathered_values)
def testRaggedGather(self,
                     params,
                     indices,
                     expected,
                     axis=None,
                     batch_dims=0,
                     params_ragged_rank=None,
                     indices_ragged_rank=None):
  params = ragged_factory_ops.constant(params, ragged_rank=params_ragged_rank)
  indices = ragged_factory_ops.constant(
      indices, ragged_rank=indices_ragged_rank)
  actual = ragged_gather_ops.gather(
      params, indices, axis=axis, batch_dims=batch_dims)
  self.assertAllEqual(actual, self._str_to_bytes(expected))
def testMatchesDenseGather(self,
                           params_shape,
                           indices_shape,
                           axis=None,
                           batch_dims=0):
  # Build random params & indices matrices w/ the expected shapes.
  if axis is None:
    axis = batch_dims
  params = np.random.randint(100, size=params_shape, dtype=np.int32)
  indices = np.random.randint(
      params_shape[axis], size=indices_shape, dtype=np.int32)

  # Use array_ops.gather to get the expected value.
  expected = array_ops.gather(
      params, indices, axis=axis, batch_dims=batch_dims)

  # Build ragged tensors with varying ragged_ranks from params & axis.
  params_tensors = [params] + [
      ragged_tensor.RaggedTensor.from_tensor(params, ragged_rank=i)
      for i in range(1, len(params_shape))
  ]
  indices_tensors = [indices] + [
      ragged_tensor.RaggedTensor.from_tensor(indices, ragged_rank=i)
      for i in range(1, len(indices_shape))
  ]

  # For each combination of params & axis tensors, check that
  # ragged_gather_ops.gather matches array_ops.gather.
  for params_tensor in params_tensors:
    for indices_tensor in indices_tensors:
      actual = ragged_gather_ops.gather(
          params_tensor, indices_tensor, axis=axis, batch_dims=batch_dims)
      if isinstance(actual, ragged_tensor.RaggedTensor):
        actual = actual.to_tensor()
      self.assertAllEqual(
          expected, actual, 'params.ragged_rank=%s, indices.ragged_rank=%s' %
          (getattr(params_tensor, 'ragged_rank', 0),
           getattr(indices_tensor, 'ragged_rank', 0)))
def _broadcast_ragged_sources_for_overlap(source_start, source_limit,
                                          target_splits):
  """Repeats source indices for each target item in the same batch.

  Args:
    source_start: `<int>[batch_size, (source_size)]`
    source_limit: `<int>[batch_size, (source_size)]`
    target_splits: `<int64>[batch_size, (target_size+1)]`

  Returns:
    `<int>[batch_size, (source_size), (target_size)]`.
    A tuple of ragged tensors `(tiled_source_start, tiled_source_limit)` where:

    * `tiled_source_start[b, s, t] = source_start[b, s]`
    * `tiled_source_limit[b, s, t] = source_limit[b, s]`
  """
  source_splits = source_start.row_splits
  target_rowlens = target_splits[1:] - target_splits[:-1]
  source_batch_ids = segment_id_ops.row_splits_to_segment_ids(source_splits)

  # <int64>[sum(source_size[b] for b in range(batch_size))]
  # source_repeats[i] is the number of target spans in the batch that contains
  # source span i.  We need to add a new ragged dimension that repeats each
  # source span this number of times.
  source_repeats = ragged_gather_ops.gather(target_rowlens, source_batch_ids)

  # <int64>[sum(source_size[b] for b in range(batch_size)) + 1]
  # The row_splits tensor for the inner ragged dimension of the result tensors.
  inner_splits = array_ops.concat([[0], math_ops.cumsum(source_repeats)],
                                  axis=0)

  # <int64>[sum(source_size[b] * target_size[b] for b in range(batch_size))]
  # Indices for gathering source indices.
  source_indices = segment_id_ops.row_splits_to_segment_ids(inner_splits)

  source_start = ragged_tensor.RaggedTensor.from_nested_row_splits(
      array_ops.gather(source_start.values, source_indices),
      [source_splits, inner_splits])
  source_limit = ragged_tensor.RaggedTensor.from_nested_row_splits(
      array_ops.gather(source_limit.values, source_indices),
      [source_splits, inner_splits])

  return source_start, source_limit
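
# Illustrative sketch (my addition): adding an inner ragged dimension that
# repeats each flat value a per-value number of times, via cumulative-sum
# row splits as above. All values here are hypothetical.
import tensorflow as tf

source_start = tf.ragged.constant([[0, 4], [9]])
repeats = tf.constant([3, 3, 2], dtype=tf.int64)  # one count per flat value
inner_splits = tf.concat(
    [tf.constant([0], dtype=tf.int64), tf.cumsum(repeats)], axis=0)
gather_ids = tf.ragged.row_splits_to_segment_ids(inner_splits)  # [0,0,0,1,1,1,2,2]
tiled = tf.RaggedTensor.from_nested_row_splits(
    tf.gather(source_start.values, gather_ids),
    [source_start.row_splits, inner_splits])
print(tiled)  # [[[0, 0, 0], [4, 4, 4]], [[9, 9]]]
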
def _elementwise_where(condition, x, y):
  """Ragged version of tf.where(condition, x, y)."""
  condition_is_ragged = isinstance(condition, ragged_tensor.RaggedTensor)
  x_is_ragged = isinstance(x, ragged_tensor.RaggedTensor)
  y_is_ragged = isinstance(y, ragged_tensor.RaggedTensor)
  if not (condition_is_ragged or x_is_ragged or y_is_ragged):
    return array_ops.where(condition, x, y)

  elif condition_is_ragged and x_is_ragged and y_is_ragged:
    return ragged_functional_ops.map_flat_values(array_ops.where, condition, x,
                                                 y)
  elif not condition_is_ragged:
    # Concatenate x and y, and then use `gather` to assemble the selected rows.
    condition.shape.assert_has_rank(1)
    x_nrows = _nrows(x)
    x_and_y = ragged_concat_ops.concat([x, y], axis=0)
    indices = array_ops.where(condition, math_ops.range(x_nrows),
                              x_nrows + math_ops.range(_nrows(y)))
    return ragged_gather_ops.gather(x_and_y, indices)
  else:
    raise ValueError('Input shapes do not match.')
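
# Illustrative sketch (my addition): the concat + gather row selection used in
# the dense-condition branch above, written with the public API.
import tensorflow as tf

condition = tf.constant([True, False, True])
x = tf.ragged.constant([[1, 2], [3], [4, 5, 6]])
y = tf.ragged.constant([[7], [8, 9], []])
x_and_y = tf.concat([x, y], axis=0)
nrows = x.nrows()
# Row i of the result comes from x[i] if condition[i], else from y[i].
indices = tf.where(condition, tf.range(nrows), nrows + tf.range(y.nrows()))
print(tf.gather(x_and_y, indices))  # [[1, 2], [8, 9], [4, 5, 6]]
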
def batch_gather(params, indices, name=None):
  """Gathers slices from `params` according to `indices` with batch dims.

  This operation is similar to `gather`, but it assumes that the leading `N`
  dimensions of `indices` and `params` are batch dimensions, and performs a
  gather within each batch.  In particular, when using this operation with `N`
  batch dimensions `B1...BN`:

  * `indices` has shape `[B1...BN, I]`
  * `params` has shape `[B1...BN, P1...PM]`.
  * `result` has shape `[B1...BN, I, P2...PM]`.
  * `result[b1...bN, i, p2...pM] =
    params[b1...bN, indices[b1...bN, i], p2...pM]`

  Args:
    params: A potentially ragged tensor with shape `[B1...BN, P1...PM]` (`N>=0`,
      `M>0`).
    indices: A potentially ragged tensor with shape `[B1...BN, I]` (`N>=0`).
    name: A name for the operation (optional).

  Returns:
    A potentially ragged tensor with shape `[B1...BN, I, P2...PM]`.
    `result.ragged_rank = max(indices.ragged_rank, params.ragged_rank)`.

  #### Example:

  ```python
  >>> params = tf.ragged.constant([['a', 'b', 'c'], ['d'], [], ['e']])
  >>> indices = tf.ragged.constant([[1, 2, 0], [], [], [0, 0]])
  >>> tf.batch_gather(params, indices)
  [['b', 'c', 'a'], [], [], ['e', 'e']]
  ```
  """
  if not (ragged_tensor.is_ragged(params) or ragged_tensor.is_ragged(indices)):
    return array_ops.batch_gather(params, indices, name)

  with ops.name_scope(name, 'RaggedBatchGather', [params, indices]):
    params = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        params, name='params')
    indices = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        indices, name='indices')
    indices_ndims = indices.shape.ndims
    if indices_ndims is None:
      raise ValueError(
          'batch_gather does not allow indices with unknown shape.')
    if indices_ndims == 0:
      raise ValueError('indices.rank must be at least 1.')

    if ragged_tensor.is_ragged(indices):
      # If the outermost ragged dimension is a batch dimension, recurse.
      if indices_ndims > 2:
        if not ragged_tensor.is_ragged(params):
          raise ValueError('batch shape from indices does '
                           'not match params shape')
        checks = [check_ops.assert_equal(params.row_splits, indices.row_splits)]
        with ops.control_dependencies(checks):
          return ragged_tensor.RaggedTensor.from_row_splits(
              batch_gather(params.values, indices.values), indices.row_splits)

      # Otherwise, indices is a 2D ragged tensor with 1 ragged dimension.
      else:
        # Ensure that `params` is ragged and has at least 2 dimensions.
        if not ragged_tensor.is_ragged(params):
          if params.shape.ndims is not None and params.shape.ndims < 2:
            raise ValueError('batch shape from indices does '
                             'not match params shape')
          params = ragged_conversion_ops.from_tensor(params, ragged_rank=1)

        # Adjust indices from within-batch to global (in params.values), and
        # then use ragged.gather to gather them.
        num_indices = indices.row_lengths()
        params_starts = params.row_starts()
        adjustments = ragged_util.repeat(params_starts, num_indices, axis=0)
        adjusted_index_values = math_ops.to_int64(indices.values) + adjustments
        return ragged_tensor.RaggedTensor.from_row_splits(
            ragged_gather_ops.gather(params.values, adjusted_index_values),
            indices.row_splits)

    else:  # params is a RaggedTensor and indices is a Tensor.
      if indices_ndims == 1:
        return ragged_gather_ops.gather(params, indices)
      elif indices_ndims == 2:
        # Adjust indices from batch-local to global (in params.values)
        adjustments = array_ops.expand_dims(params.row_starts(), 1)
        adjusted_indices = math_ops.to_int64(indices) + adjustments
        return ragged_gather_ops.gather(params.values, adjusted_indices)
      else:
        raise ValueError(
            'batch shape from indices does not match params shape')
def batch_gather(params, indices, name=None):
  """Gathers slices from `params` according to `indices` with batch dims.

  This operation is similar to `gather`, but it assumes that the leading `N`
  dimensions of `indices` and `params` are batch dimensions, and performs a
  gather within each batch.  In particular, when using this operation with `N`
  batch dimensions `B1...BN`:

  * `indices` has shape `[B1...BN, I]`
  * `params` has shape `[B1...BN, P1...PM]`.
  * `result` has shape `[B1...BN, I, P2...PM]`.
  * `result[b1...bN, i, p2...pM] =
    params[b1...bN, indices[b1...bN, i], p2...pM]`

  Args:
    params: A potentially ragged tensor with shape `[B1...BN, P1...PM]` (`N>=0`,
      `M>0`).
    indices: A potentially ragged tensor with shape `[B1...BN, I]` (`N>=0`).
    name: A name for the operation (optional).

  Returns:
    A potentially ragged tensor with shape `[B1...BN, I, P2...PM]`.
    `result.ragged_rank = max(indices.ragged_rank, params.ragged_rank)`.

  #### Example:

  ```python
  >>> params = tf.ragged.constant([['a', 'b', 'c'], ['d'], [], ['e']])
  >>> indices = tf.ragged.constant([[1, 2, 0], [], [], [0, 0]])
  >>> tf.compat.v1.batch_gather(params, indices)
  [['b', 'c', 'a'], [], [], ['e', 'e']]
  ```
  """
  if not (ragged_tensor.is_ragged(params) or ragged_tensor.is_ragged(indices)):
    return array_ops.batch_gather(params, indices, name)

  with ops.name_scope(name, 'RaggedBatchGather', [params, indices]):
    params = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        params, name='params')
    indices = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        indices, name='indices')
    params, indices = ragged_tensor.match_row_splits_dtypes(params, indices)
    indices_ndims = indices.shape.ndims
    if indices_ndims is None:
      raise ValueError(
          'batch_gather does not allow indices with unknown shape.')
    if indices_ndims == 0:
      raise ValueError('indices.rank must be at least 1.')

    if ragged_tensor.is_ragged(indices):
      # If the outermost ragged dimension is a batch dimension, recurse.
      if indices_ndims > 2:
        if not ragged_tensor.is_ragged(params):
          raise ValueError('batch shape from indices does '
                           'not match params shape')
        checks = [check_ops.assert_equal(params.row_splits, indices.row_splits)]
        with ops.control_dependencies(checks):
          return ragged_tensor.RaggedTensor.from_row_splits(
              batch_gather(params.values, indices.values), indices.row_splits,
              validate=False)

      # Otherwise, indices is a 2D ragged tensor with 1 ragged dimension.
      else:
        # Ensure that `params` is ragged and has at least 2 dimensions.
        if not ragged_tensor.is_ragged(params):
          if params.shape.ndims is not None and params.shape.ndims < 2:
            raise ValueError('batch shape from indices does '
                             'not match params shape')
          params = ragged_tensor.RaggedTensor.from_tensor(
              params, ragged_rank=1, row_splits_dtype=indices.row_splits.dtype)

        # Adjust indices from within-batch to global (in params.values), and
        # then use ragged.gather to gather them.
        num_indices = indices.row_lengths()
        params_starts = params.row_starts()
        adjustments = ragged_util.repeat(params_starts, num_indices, axis=0)
        adjusted_index_values = (
            math_ops.cast(indices.values, adjustments.dtype) + adjustments)
        return ragged_tensor.RaggedTensor.from_row_splits(
            ragged_gather_ops.gather(params.values, adjusted_index_values),
            indices.row_splits,
            validate=False)

    else:  # params is a RaggedTensor and indices is a Tensor.
      if indices_ndims == 1:
        return ragged_gather_ops.gather(params, indices)
      elif indices_ndims == 2:
        # Adjust indices from batch-local to global (in params.values)
        adjustments = array_ops.expand_dims(params.row_starts(), 1)
        adjusted_indices = (
            math_ops.cast(indices, adjustments.dtype) + adjustments)
        return ragged_gather_ops.gather(params.values, adjusted_indices)
      else:
        raise ValueError(
            'batch shape from indices does not match params shape')
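
# Illustrative sketch (my addition): the within-batch -> global index
# adjustment that batch_gather performs for ragged indices, shown with the
# public API.
import tensorflow as tf

params = tf.ragged.constant([['a', 'b', 'c'], ['d'], [], ['e']])
indices = tf.ragged.constant([[1, 2, 0], [], [], [0, 0]])
# Each batch row b starts at params.row_starts()[b] within params.values.
starts = params.row_starts()                  # [0, 3, 4, 4]
adjustments = tf.repeat(starts, indices.row_lengths())
global_indices = tf.cast(indices.values, adjustments.dtype) + adjustments
result = tf.RaggedTensor.from_row_splits(
    tf.gather(params.values, global_indices), indices.row_splits)
print(result)  # [[b'b', b'c', b'a'], [], [], [b'e', b'e']]
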
def testRaggedParamsAndScalarIndices(self):
  params = ragged_factory_ops.constant([['a', 'b'], ['c', 'd', 'e'], ['f'], [],
                                        ['g']])
  indices = 1
  self.assertRaggedEqual(
      ragged_gather_ops.gather(params, indices), [b'c', b'd', b'e'])
def testTensorParamsAndRaggedIndices(self):
  params = ['a', 'b', 'c', 'd', 'e']
  indices = ragged_factory_ops.constant([[2, 1], [1, 2, 0], [3]])
  self.assertAllEqual(
      ragged_gather_ops.gather(params, indices),
      [[b'c', b'b'], [b'b', b'c', b'a'], [b'd']])