def _whitespace_tokenize_with_offsets_encode_decode_wrapper(
    self, input_tensor):
  """Whitespace-tokenizes rank-1 UTF-8 strings and reports byte offsets.

  The strings are decoded into codepoints, tokenized in codepoint space, and
  the tokens are encoded back to UTF-8. The codepoint offsets produced by the
  tokenizer are translated into byte offsets of the original strings.

  Args:
    input_tensor: The single dimensional Tensor to tokenize.

  Returns:
    Tuple of RaggedTensors of tokenized text and byte offsets, with shapes
    [num_strings, (num_tokens or num_offsets)].
  """
  codepoints, start_bytes = ragged_string_ops.unicode_decode_with_offsets(
      input_tensor, "UTF-8")

  # The limit byte of a codepoint is the start byte of the following one; the
  # final codepoint of each string ends at the string's total byte length.
  following_starts = start_bytes[:, 1:]
  string_byte_lengths = math_ops.cast(
      array_ops.expand_dims(string_ops.string_length(input_tensor), 1),
      dtypes.int64)
  limit_bytes = array_ops.concat([following_starts, string_byte_lengths], 1)

  tokens, token_starts, token_limits = (
      self._whitespace_tokenize_codepoints_with_offsets(codepoints))

  encoded_tokens = ragged_string_ops.unicode_encode(tokens, "UTF-8")
  start_offsets = array_ops.batch_gather(start_bytes, token_starts)
  # NOTE(review): the -1 suggests `token_limits` are exclusive codepoint
  # positions, so the lookup uses the last codepoint inside each token.
  last_codepoint_positions = math_ops.subtract(token_limits, [1])
  limit_offsets = array_ops.batch_gather(limit_bytes, last_codepoint_positions)
  return (encoded_tokens, start_offsets, limit_offsets)
def testString(self):
  """Checks batch_gather on a 2x2 array of byte strings."""
  params = np.array([[b"asdf", b"zxcv"], [b"qwer", b"uiop"]])
  expected = [[b"qwer", b"uiop"]]
  with self.cached_session():
    indices_tf = constant_op.constant([1])
    gathered = array_ops.batch_gather(params, indices_tf)
    self.assertAllEqual(expected, self.evaluate(gathered))
def testUnknownIndices(self):
  """Checks the static result shape when the indices shape is unknown."""
  # This test needs a placeholder which means we need to construct a graph.
  graph = ops.Graph()
  with graph.as_default():
    params = constant_op.constant([[0, 1, 2]])
    indices = array_ops.placeholder(dtypes.int32, shape=[None, None])
    static_shape = array_ops.batch_gather(params, indices).get_shape()
    self.assertEqual([1, None], static_shape.as_list())
def testString(self):
  """Checks batch_gather on a 2x2 array of byte strings."""
  params = np.array([[b"asdf", b"zxcv"], [b"qwer", b"uiop"]])
  expected = [[b"qwer", b"uiop"]]
  with self.cached_session():
    indices_tf = constant_op.constant([1])
    gathered = array_ops.batch_gather(params, indices_tf)
    self.assertAllEqual(expected, gathered.eval())
def testEmptySlices(self):
  """Checks batch_gather on params whose trailing dimensions are empty."""
  index_types = (np.int32, np.int64)
  with self.session(use_gpu=True):
    for dtype in _TEST_TYPES:
      for itype in index_types:
        # 7 rows of shape (0, 0); picking two rows must give shape (2, 0, 0).
        empty_params = np.zeros((7, 0, 0), dtype=dtype.as_numpy_dtype)
        row_ids = np.array([3, 4], dtype=itype)
        gather = array_ops.batch_gather(empty_params, row_ids)
        self.assertAllEqual(gather.eval(), np.zeros((2, 0, 0)))
def testEmptySlices(self):
  """Checks batch_gather on params whose trailing dimensions are empty."""
  index_types = (np.int32, np.int64)
  with self.session(use_gpu=True):
    for dtype in _TEST_TYPES:
      for itype in index_types:
        # 7 rows of shape (0, 0); picking two rows must give shape (2, 0, 0).
        empty_params = np.zeros((7, 0, 0), dtype=dtype.as_numpy_dtype)
        row_ids = np.array([3, 4], dtype=itype)
        gather = array_ops.batch_gather(empty_params, row_ids)
        self.assertAllEqual(gather.eval(), np.zeros((2, 0, 0)))
def testSimpleGather(self):
  """Checks batch_gather with rank-1 input data and rank-1 indices."""
  data = np.array([0, 1, 2, 3, 7, 5, 8, 9, 10, 11, 15, 13])
  indices = [3, 4]
  # data[3] == 3 and data[4] == 7.
  expected_result = np.array([3, 7])
  with self.test_session(use_gpu=True):
    for dtype in _TEST_TYPES:
      params = constant_op.constant(self._buildParams(data, dtype))
      gather_t = array_ops.batch_gather(params, constant_op.constant(indices))
      np_val = self._buildParams(expected_result, dtype)
      self.assertAllEqual(np_val, gather_t.eval())
      self.assertEqual(np_val.shape, gather_t.get_shape())
def testSimpleGather(self):
  """Checks batch_gather with rank-1 input data and rank-1 indices."""
  data = np.array([0, 1, 2, 3, 7, 5, 8, 9, 10, 11, 15, 13])
  indices = [3, 4]
  # data[3] == 3 and data[4] == 7.
  expected_result = np.array([3, 7])
  with self.test_session(use_gpu=True):
    for dtype in _TEST_TYPES:
      params = constant_op.constant(self._buildParams(data, dtype))
      gather_t = array_ops.batch_gather(params, constant_op.constant(indices))
      np_val = self._buildParams(expected_result, dtype)
      self.assertAllEqual(np_val, gather_t.eval())
      self.assertEqual(np_val.shape, gather_t.get_shape())
def testHigherRank(self):
  """Checks batch_gather with rank-3 input data and rank-3 indices."""
  data = np.array([[[0, 1, 2], [3, 7, 5]], [[8, 9, 10], [11, 15, 13]]])
  indices = [[[2, 0], [1, 2]], [[2, 0], [0, 1]]]
  expected_result = np.array([[[2, 0], [7, 5]], [[10, 8], [11, 15]]])
  with self.session():
    for dtype in _TEST_TYPES:
      params = constant_op.constant(self._buildParams(data, dtype))
      gather_t = array_ops.batch_gather(params, constant_op.constant(indices))
      gather_val = self.evaluate(gather_t)
      np_val = self._buildParams(expected_result, dtype)
      self.assertAllEqual(np_val, gather_val)
      self.assertEqual(np_val.shape, gather_t.get_shape())
def test2DArray(self, indices_dtype):
  """Checks batch_gather with rank-2 input data and one index per row.

  Args:
    indices_dtype: dtype used to build the indices constant.
  """
  data = np.array([[0, 1, 2, 3, 7, 5], [8, 9, 10, 11, 15, 13]])
  indices = [[3], [4]]
  # Row 0, column 3 holds 3; row 1, column 4 holds 15.
  expected_result = np.array([[3], [15]])
  with self.session(use_gpu=True):
    for dtype in _TEST_TYPES:
      params = constant_op.constant(self._buildParams(data, dtype))
      indices_tf = constant_op.constant(indices, dtype=indices_dtype)
      gather_t = array_ops.batch_gather(params, indices_tf)
      np_val = self._buildParams(expected_result, dtype)
      self.assertAllEqual(np_val, gather_t.eval())
      self.assertEqual(np_val.shape, gather_t.get_shape())
def test2DArray(self, indices_dtype):
  """Checks batch_gather with rank-2 input data and one index per row.

  Args:
    indices_dtype: dtype used to build the indices constant.
  """
  data = np.array([[0, 1, 2, 3, 7, 5], [8, 9, 10, 11, 15, 13]])
  indices = [[3], [4]]
  # Row 0, column 3 holds 3; row 1, column 4 holds 15.
  expected_result = np.array([[3], [15]])
  with self.session(use_gpu=True):
    for dtype in _TEST_TYPES:
      params = constant_op.constant(self._buildParams(data, dtype))
      indices_tf = constant_op.constant(indices, dtype=indices_dtype)
      gather_t = array_ops.batch_gather(params, indices_tf)
      np_val = self._buildParams(expected_result, dtype)
      self.assertAllEqual(np_val, self.evaluate(gather_t))
      self.assertEqual(np_val.shape, gather_t.get_shape())
def testHigherRank(self):
  """Checks batch_gather with rank-3 input data and rank-3 indices."""
  data = np.array([[[0, 1, 2], [3, 7, 5]], [[8, 9, 10], [11, 15, 13]]])
  indices = [[[2, 0], [1, 2]], [[2, 0], [0, 1]]]
  expected_result = np.array([[[2, 0], [7, 5]], [[10, 8], [11, 15]]])
  with self.session(use_gpu=True):
    for dtype in _TEST_TYPES:
      params = constant_op.constant(self._buildParams(data, dtype))
      gather_t = array_ops.batch_gather(params, constant_op.constant(indices))
      gather_val = gather_t.eval()
      np_val = self._buildParams(expected_result, dtype)
      self.assertAllEqual(np_val, gather_val)
      self.assertEqual(np_val.shape, gather_t.get_shape())
def _ragged_tensor_scatter_nd_update(params, indices, updates):
  """Version of tensor_scatter_nd_update() where the values are ragged.

  Args:
    params: A `RaggedTensor` whose flat values are to be updated.
    indices: Per-batch positions inside `params`, in a form accepted by
      `batch_gather(params, indices)`.
    updates: A `RaggedTensor` whose flat values are written at the positions
      selected by `indices`; they are cast to the dtype of `params`' flat
      values.

  Returns:
    A `RaggedTensor` with the same row partitioning as `params` and the
    updated flat values.
  """
  # Create a RT in the shape of `params` and containing the "global" positions.
  # Here "global" means the element position in the flat values Tensor.
  global_positions_flat = math_ops.range(array_ops.size(params.flat_values))
  global_positions = params.with_flat_values(global_positions_flat)

  # Gathering the global positions translates the per-batch `indices` into
  # positions inside `params.flat_values`.
  global_indices = array_ops.batch_gather(global_positions, indices)

  # Scatter indices must stay an integer tensor. They are deliberately NOT
  # cast to `params.dtype`: that only worked when `params` itself held
  # int32/int64 values and produced invalid indices for any other dtype.
  update_indices = array_ops.expand_dims(global_indices.flat_values, -1)

  params_flat = params.flat_values
  update_values = math_ops.cast(updates.flat_values, params_flat.dtype)
  results_flat = array_ops.tensor_scatter_update(params_flat, update_indices,
                                                 update_values)
  return params.with_flat_values(results_flat)
a_top_k = tf.nn.top_k(a, k) indices = a_top_k[1] a_reshape = a.reshape(-1, 4) indices_reshape = tf.reshape(indices,(-1, 2)) gather_top_k = [] for i in range(len(a)): gather_top_k.append(tf.gather(a[i], indices_reshape[i])) gather_top_k = tf.stack(gather_top_k, axis=0) gather_top_k = tf.reshape(gather_top_k, (2, 2)) #gather_top_k = tf.expand_dims(gather_top_k, -1) batch_gather_top_k,batch_indices = batch_gather(a, indices) with tf.Session() as sess: sess.run(tf.global_variables_initializer()) a_top_k_arr = sess.run(a_top_k) print(a_top_k_arr) gather_top_k = sess.run(gather_top_k) print('gather_top_k: ',gather_top_k,sep='\n') print('a_top_k_arr[0]:',a_top_k_arr[0],sep='\n') print('batch_gather: ',sess.run(batch_gather_top_k),sep='\n') print('batch_indices: ',sess.run(batch_indices),sep='\n') #tensor = [[1,2],[3,4],[5,6]] mask = np.array([[True,False],[True,False],[True,False]]) """
def testBadIndicesCPU(self):
  """Checks that an out-of-range index raises an op error on CPU."""
  params = [[0, 1, 2], [3, 4, 5]]
  expected_message = r"indices\[0\] = 7 is not in \[0, 2\)"
  with ops.device_v2("cpu:0"):
    with self.assertRaisesOpError(expected_message):
      self.evaluate(array_ops.batch_gather(params, [7]))
def testBadIndicesCPU(self):
  """Checks that an out-of-range index raises an op error on CPU."""
  params = [[0, 1, 2], [3, 4, 5]]
  expected_message = r"indices\[0\] = 7 is not in \[0, 2\)"
  with self.session(use_gpu=False):
    with self.assertRaisesOpError(expected_message):
      array_ops.batch_gather(params, [7]).eval()
def testUnknownIndices(self):
  """Checks the static result shape when the indices shape is unknown."""
  params = constant_op.constant([[0, 1, 2]])
  indices = array_ops.placeholder(dtypes.int32, shape=[None, None])
  static_shape = array_ops.batch_gather(params, indices).get_shape()
  self.assertEqual([1, None], static_shape.as_list())
def batch_gather_with_default(params, indices, default_value='', name=None):
  """Same as `batch_gather` but inserts `default_value` for invalid indices.

  This operation is similar to `batch_gather` except that it will substitute
  the value for invalid indices with `default_value` as the contents.
  See `batch_gather` for more details.

  An index is invalid when it is negative, or when it is greater than or equal
  to the number of elements available in its batch.

  Args:
    params: A potentially ragged tensor with shape `[B1...BN, P1...PM]` (`N>=0`,
      `M>0`).
    indices: A potentially ragged tensor with shape `[B1...BN, I]` (`N>=0`).
    default_value: A value to be inserted in places where `indices` are out of
      bounds. Must be the same dtype as params and either a scalar or rank 1.
    name: A name for the operation (optional).

  Returns:
    A potentially ragged tensor with shape `[B1...BN, I, P2...PM]`.
    `result.ragged_rank = max(indices.ragged_rank, params.ragged_rank)`.

  Raises:
    ValueError: If `default_value` is neither a scalar nor a vector, or if the
      rank of `indices` or of `params` is not statically known.

  #### Example:
    ```python
    >>> params = tf.ragged.constant([
          ['a', 'b', 'c'],
          ['d'],
          [],
          ['e']])
    >>> indices = tf.ragged.constant([[1, 2, -1], [], [], [0, 10]])
    >>> batch_gather_with_default(params, indices, 'FOO')
    [['b', 'c', 'FOO'], [], [], ['e', 'FOO']]
    ```
  """
  with ops.name_scope(name, 'RaggedBatchGatherWithDefault'):
    params = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        params, name='params',
    )
    indices = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        indices, name='indices',
    )
    default_value = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        default_value, name='default_value',
    )
    # TODO(hterry): lift this restriction and support default_values of
    #               of rank > 1
    # Compare ranks by value; identity comparison against int literals is
    # implementation-dependent. An unknown rank (None) is rejected as well.
    if default_value.shape.ndims not in (0, 1):
      raise ValueError('"default_value" must be a scalar or vector')
    if indices.shape.ndims is None:
      raise ValueError('Indices must have a known rank.')
    if params.shape.ndims is None:
      raise ValueError('Params must have a known rank.')

    num_batch_dimensions = indices.shape.ndims - 1

    # The logic for this works as follows:
    # - create a padded params, where:
    #    padded_params[b1...bn, 0] = default_value
    #    padded_params[b1...bn, i] = params[b1...bn, i-1] (i>0)
    # - create an `upper_bounds` Tensor that contains the number of elements
    #   in each innermost rank. Broadcast `upper_bounds` to be the same shape
    #   as `indices`.
    # - check to see which index in `indices` are out of bounds and substitute
    #   it with the index containing `default_value` (the first).
    # - call batch_gather with the indices adjusted.
    with ops.control_dependencies([
        check_ops.assert_greater_equal(array_ops.rank(params),
                                       array_ops.rank(indices))
    ]):
      if ragged_tensor.is_ragged(params):
        row_lengths = ragged_array_ops.expand_dims(
            params.row_lengths(axis=num_batch_dimensions), axis=-1)
        upper_bounds = math_ops.cast(row_lengths, indices.dtype)

        pad_shape = _get_pad_shape(params, indices)
        pad = ragged_tensor_shape.broadcast_to(default_value, pad_shape)
      else:
        params_shape = array_ops.shape(params)
        pad_shape = array_ops.concat([
            params_shape[:num_batch_dimensions],
            [1],
            params_shape[num_batch_dimensions + 1:params.shape.ndims]
        ], 0)
        upper_bounds = params_shape[num_batch_dimensions]
        pad = array_ops.broadcast_to(default_value, pad_shape)

      # Add `default_value` as the first value in the innermost (ragged) rank.
      pad = math_ops.cast(pad, params.dtype)
      padded_params = array_ops.concat([pad, params], axis=num_batch_dimensions)

      # Adjust the indices by substituting out-of-bound indices to the
      # default-value index (which is the first element).
      # `upper_bounds` is the element count, so the last valid index is
      # `upper_bounds - 1`; an index equal to `upper_bounds` must also be
      # redirected, otherwise its shifted value points past the padded row.
      shifted_indices = indices + 1
      is_out_of_bounds = (indices < 0) | (indices >= upper_bounds)
      adjusted_indices = ragged_where_op.where(
          is_out_of_bounds,
          x=array_ops.zeros_like(indices),
          y=shifted_indices,
      )
      return array_ops.batch_gather(
          params=padded_params, indices=adjusted_indices, name=name)
def batch_gather_with_default(params, indices, default_value='', name=None):
  """Same as `batch_gather` but inserts `default_value` for invalid indices.

  This operation is similar to `batch_gather` except that it will substitute
  the value for invalid indices with `default_value` as the contents.
  See `batch_gather` for more details.

  An index is invalid when it is negative, or when it is greater than or equal
  to the number of elements available in its batch.

  Args:
    params: A potentially ragged tensor with shape `[B1...BN, P1...PM]` (`N>=0`,
      `M>0`).
    indices: A potentially ragged tensor with shape `[B1...BN, I]` (`N>=0`).
    default_value: A value to be inserted in places where `indices` are out of
      bounds. Must be the same dtype as params and either a scalar or rank 1.
    name: A name for the operation (optional).

  Returns:
    A potentially ragged tensor with shape `[B1...BN, I, P2...PM]`.
    `result.ragged_rank = max(indices.ragged_rank, params.ragged_rank)`.

  Raises:
    ValueError: If `default_value` is neither a scalar nor a vector, or if the
      rank of `indices` or of `params` is not statically known.

  #### Example:
    ```python
    >>> params = tf.ragged.constant([
          ['a', 'b', 'c'],
          ['d'],
          [],
          ['e']])
    >>> indices = tf.ragged.constant([[1, 2, -1], [], [], [0, 10]])
    >>> batch_gather_with_default(params, indices, 'FOO')
    [['b', 'c', 'FOO'], [], [], ['e', 'FOO']]
    ```
  """
  with ops.name_scope(name, 'RaggedBatchGatherWithDefault'):
    params = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        params, name='params',
    )
    indices = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        indices, name='indices',
    )
    default_value = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        default_value, name='default_value',
    )
    # TODO(hterry): lift this restriction and support default_values of
    #               of rank > 1
    # Compare ranks by value; identity comparison against int literals is
    # implementation-dependent. An unknown rank (None) is rejected as well.
    if default_value.shape.ndims not in (0, 1):
      raise ValueError('"default_value" must be a scalar or vector')
    if indices.shape.ndims is None:
      raise ValueError('Indices must have a known rank.')
    if params.shape.ndims is None:
      raise ValueError('Params must have a known rank.')

    num_batch_dimensions = indices.shape.ndims - 1

    # The logic for this works as follows:
    # - create a padded params, where:
    #    padded_params[b1...bn, 0] = default_value
    #    padded_params[b1...bn, i] = params[b1...bn, i-1] (i>0)
    # - create an `upper_bounds` Tensor that contains the number of elements
    #   in each innermost rank. Broadcast `upper_bounds` to be the same shape
    #   as `indices`.
    # - check to see which index in `indices` are out of bounds and substitute
    #   it with the index containing `default_value` (the first).
    # - call batch_gather with the indices adjusted.
    with ops.control_dependencies([
        check_ops.assert_greater_equal(array_ops.rank(params),
                                       array_ops.rank(indices))
    ]):
      if ragged_tensor.is_ragged(params):
        row_lengths = ragged_array_ops.expand_dims(
            params.row_lengths(axis=num_batch_dimensions), axis=-1)
        upper_bounds = math_ops.cast(row_lengths, indices.dtype)

        pad_shape = _get_pad_shape(params, indices)
        pad = ragged_tensor_shape.broadcast_to(default_value, pad_shape)
      else:
        params_shape = array_ops.shape(params)
        pad_shape = array_ops.concat([
            params_shape[:num_batch_dimensions],
            [1],
            params_shape[num_batch_dimensions + 1:params.shape.ndims]
        ], 0)
        upper_bounds = params_shape[num_batch_dimensions]
        pad = array_ops.broadcast_to(default_value, pad_shape)

      # Add `default_value` as the first value in the innermost (ragged) rank.
      pad = math_ops.cast(pad, params.dtype)
      padded_params = array_ops.concat([pad, params], axis=num_batch_dimensions)

      # Adjust the indices by substituting out-of-bound indices to the
      # default-value index (which is the first element).
      # `upper_bounds` is the element count, so the last valid index is
      # `upper_bounds - 1`; an index equal to `upper_bounds` must also be
      # redirected, otherwise its shifted value points past the padded row.
      shifted_indices = indices + 1
      is_out_of_bounds = (indices < 0) | (indices >= upper_bounds)
      adjusted_indices = ragged_where_op.where(
          is_out_of_bounds,
          x=array_ops.zeros_like(indices),
          y=shifted_indices,
      )
      return array_ops.batch_gather(
          params=padded_params, indices=adjusted_indices, name=name)
def batch_gather(params, indices, name=None):
  """Gathers slices from `params` according to `indices` with batch dims.

  Works like `gather`, except that the leading `N` dimensions of `indices`
  and `params` are treated as batch dimensions and the gather is carried out
  separately inside every batch. With `N` batch dimensions `B1...BN`:

  * `indices` has shape `[B1...BN, I]`
  * `params` has shape `[B1...BN, P1...PM]`.
  * `result` has shape `[B1...BN, I, P2...PM]`.
  * `result[b1...bN, i, p2...pM] =
    params[b1...bN, indices[b1...bN, i], p2...pM]`

  Args:
    params: A potentially ragged tensor with shape `[B1...BN, P1...PM]` (`N>=0`,
      `M>0`).
    indices: A potentially ragged tensor with shape `[B1...BN, I]` (`N>=0`).
    name: A name for the operation (optional).

  Returns:
    A potentially ragged tensor with shape `[B1...BN, I, P2...PM]`.
    `result.ragged_rank = max(indices.ragged_rank, params.ragged_rank)`.

  Raises:
    ValueError: If the rank of `indices` is unknown or zero, or if the batch
      shape of `indices` cannot be matched against `params`.

  #### Example:
    ```python
    >>> params = tf.ragged.constant([['a', 'b', 'c'], ['d'], [], ['e']])
    >>> indices = tf.ragged.constant([[1, 2, 0], [], [], [0, 0]])
    >>> tf.compat.v1.batch_gather(params, indices)
    [['b', 'c', 'a'], [], [], ['e', 'e']]
    ```
  """
  # Neither argument is ragged: defer to the dense implementation.
  if (not ragged_tensor.is_ragged(params) and
      not ragged_tensor.is_ragged(indices)):
    return array_ops.batch_gather(params, indices, name)

  with ops.name_scope(name, 'RaggedBatchGather', [params, indices]):
    params = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        params, name='params')
    indices = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        indices, name='indices')
    params, indices = ragged_tensor.match_row_splits_dtypes(params, indices)

    indices_ndims = indices.shape.ndims
    if indices_ndims is None:
      raise ValueError(
          'batch_gather does not allow indices with unknown shape.')
    if indices_ndims == 0:
      raise ValueError('indices.rank must be at least 1.')

    mismatch_message = 'batch shape from indices does not match params shape'

    if not ragged_tensor.is_ragged(indices):
      # params is a RaggedTensor and indices is a Tensor.
      if indices_ndims == 1:
        return ragged_gather_ops.gather(params, indices)
      if indices_ndims == 2:
        # Adjust indices from batch-local to global (in params.values)
        row_offsets = array_ops.expand_dims(params.row_starts(), 1)
        global_indices = (
            math_ops.cast(indices, row_offsets.dtype) + row_offsets)
        return ragged_gather_ops.gather(params.values, global_indices)
      raise ValueError(mismatch_message)

    if indices_ndims > 2:
      # The outermost ragged dimension is a batch dimension: strip it from
      # both arguments, recurse, and re-attach the row splits afterwards.
      if not ragged_tensor.is_ragged(params):
        raise ValueError(mismatch_message)
      same_splits = [
          check_ops.assert_equal(params.row_splits, indices.row_splits)
      ]
      with ops.control_dependencies(same_splits):
        inner_result = batch_gather(params.values, indices.values)
        return ragged_tensor.RaggedTensor.from_row_splits(
            inner_result, indices.row_splits, validate=False)

    # From here on, indices is a 2D ragged tensor with 1 ragged dimension.
    # Ensure that `params` is ragged and has at least 2 dimensions.
    if not ragged_tensor.is_ragged(params):
      if params.shape.ndims is not None and params.shape.ndims < 2:
        raise ValueError(mismatch_message)
      params = ragged_tensor.RaggedTensor.from_tensor(
          params, ragged_rank=1, row_splits_dtype=indices.row_splits.dtype)

    # Adjust indices from within-batch to global (in params.values), and
    # then use ragged.gather to gather them.
    indices_per_row = indices.row_lengths()
    row_starts = params.row_starts()
    row_offsets = ragged_util.repeat(row_starts, indices_per_row, axis=0)
    global_index_values = (
        math_ops.cast(indices.values, row_offsets.dtype) + row_offsets)
    gathered_values = ragged_gather_ops.gather(params.values,
                                               global_index_values)
    return ragged_tensor.RaggedTensor.from_row_splits(
        gathered_values, indices.row_splits, validate=False)
def testUnknownIndices(self):
  """Checks the static result shape when the indices shape is unknown."""
  params = constant_op.constant([[0, 1, 2]])
  indices = array_ops.placeholder(dtypes.int32, shape=[None, None])
  static_shape = array_ops.batch_gather(params, indices).get_shape()
  self.assertEqual([1, None], static_shape.as_list())
def mask_language_model(input_ids, item_selector, mask_values_chooser, axis=1):
  """Applies dynamic language model masking.

  `mask_language_model` implements the `Masked LM and Masking Procedure`
  described in `BERT: Pre-training of Deep Bidirectional Transformers for
  Language Understanding` (https://arxiv.org/pdf/1810.04805.pdf).
  An `ItemSelector` picks the items to mask and a `MaskValuesChooser` assigns
  the values written to the picked items. The purpose of this is to bias the
  representation towards the actual observed item.

  Masking is performed on items in an axis. A decision is taken independently
  at random to mask with [MASK], mask with random tokens from the full vocab,
  or not mask at all. Note that the masking decision is broadcasted to the
  sub-dimensions.

  For example, in a RaggedTensor of shape `[batch, (wordpieces)]` and if
  axis=1, each wordpiece independently gets masked (or not).

  With the following input:

  ```
  [[b"Sp", b"##onge", b"bob", b"Sq", b"##uare", b"##pants" ],
   [b"Bar", b"##ack", b"Ob", b"##ama"],
   [b"Mar", b"##vel", b"A", b"##ven", b"##gers"]],
  ```

  `mask_language_model` could end up masking individual wordpieces:

  ```
  [[b"[MASK]", b"##onge", b"bob", b"Sq", b"[MASK]", b"##pants" ],
   [b"Bar", b"##ack", b"[MASK]", b"##ama"],
   [b"[MASK]", b"##vel", b"A", b"##ven", b"##gers"]]
  ```

  Or with random token inserted:

  ```
  [[b"[MASK]", b"##onge", b"bob", b"Sq", b"[MASK]", b"##pants" ],
   [b"Bar", b"##ack", b"Sq", b"##ama"],   # random token inserted for 'Ob'
   [b"Bar", b"##vel", b"A", b"##ven", b"##gers"]]  # random token inserted for
                                                   # 'Mar'
  ```

  In a RaggedTensor of shape `[batch, (words), (wordpieces)]`, whole words get
  masked (or not). If a word gets masked, all its tokens are independently
  either replaced by `[MASK]`, by random tokens, or no substitution occurs.
  Note that any arbitrary spans that can be constructed by a `RaggedTensor` can
  be masked in the same way.

  For example, if we have a `RaggedTensor` with shape
  `[batch, (token), (wordpieces)]`:

  ```
  [[[b"Sp", b"##onge"], [b"bob"], [b"Sq", b"##uare", b"##pants"]],
   [[b"Bar", b"##ack"], [b"Ob", b"##ama"]],
   [[b"Mar", b"##vel"], [b"A", b"##ven", b"##gers"]]]
  ```

  `mask_language_model` could mask whole spans (items grouped together
  by the same 1st dimension):

  ```
  [[[b"[MASK]", b"[MASK]"], [b"bob"], [b"Sq", b"##uare", b"##pants"]],
   [[b"Bar", b"##ack"], [b"[MASK]", b"[MASK]"]],
   [[b"[MASK]", b"[MASK]"], [b"A", b"##ven", b"##gers"]]]
  ```

  or insert random items in spans:

  ```
  [[[b"Mar", b"##ama"], [b"bob"], [b"Sq", b"##uare", b"##pants"]],
   [[b"Bar", b"##ack"], [b"##onge", b"##gers"]],
   [[b"Ob", b"Sp"], [b"A", b"##ven", b"##gers"]]]
  ```

  Args:
    input_ids: A `RaggedTensor` of n dimensions (where n >= 2) on which
      masking will be applied to items up to dimension 1.
    item_selector: An instance of `ItemSelector` that is used for selecting
      items to be masked.
    mask_values_chooser: An instance of `MaskValuesChooser` which determines the
      values assigned to the ids chosen for masking.
    axis: the axis where items will be treated atomically for masking.

  Returns:
    A tuple of (masked_input_ids, masked_positions, masked_ids) where:

    masked_input_ids: A `RaggedTensor` in the same shape and dtype as
      `input_ids`, but with items in `masked_positions` possibly replaced
      with `mask_token`, random id, or no change.
    masked_positions: A `RaggedTensor` of ints with shape
      [batch, (num_masked)] containing the positions of items selected for
      masking.
    masked_ids: A `RaggedTensor` with shape [batch, (num_masked)] and same
      type as `input_ids` containing the original values before masking
      and thus used as labels for the task.

  Raises:
    ValueError: If `item_selector` is not an `ItemSelector` or
      `mask_values_chooser` is not a `MaskValuesChooser`.
  """
  if not isinstance(item_selector, item_selector_ops.ItemSelector):
    raise ValueError("`item_selector` must be an instance of `ItemSelector`")
  if not isinstance(mask_values_chooser, MaskValuesChooser):
    raise ValueError("`mask_values_chooser` must be an instance of "
                     "`MaskValuesChooser`")

  input_ids = ragged_tensor.convert_to_tensor_or_ragged_tensor(input_ids)

  # Identify the items that are maskable and obtain their positions in the
  # rank 2 space.
  positions = _get_selected_item_positions(item_selector, input_ids, axis)

  # Flatten everything down to a 2D RaggedTensor.
  if positions.ragged_rank != 1:
    positions = positions.merge_dims(1, -1)
  if input_ids.ragged_rank != 1:
    input_ids = input_ids.merge_dims(1, -1)

  # Gather all the current ids in the places selected for masking; these are
  # returned as the third element of the result.
  original_ids = array_ops.batch_gather(input_ids, positions)

  # Figure out what we are going to replace these values with -- either masked
  # token, random int id, or do nothing.
  replacement_values = mask_values_chooser.get_mask_values(original_ids)

  # Scatter the new mask values back to their respective positions.
  masked_input_ids = _ragged_tensor_scatter_nd_update(input_ids, positions,
                                                      replacement_values)
  return masked_input_ids, positions, original_ids
def batch_gather(params, indices, name=None):
  """Gathers slices from `params` according to `indices` with batch dims.

  Works like `gather`, except that the leading `N` dimensions of `indices`
  and `params` are treated as batch dimensions and the gather is carried out
  separately inside every batch. With `N` batch dimensions `B1...BN`:

  * `indices` has shape `[B1...BN, I]`
  * `params` has shape `[B1...BN, P1...PM]`.
  * `result` has shape `[B1...BN, I, P2...PM]`.
  * `result[b1...bN, i, p2...pM] =
    params[b1...bN, indices[b1...bN, i], p2...pM]`

  Args:
    params: A potentially ragged tensor with shape `[B1...BN, P1...PM]` (`N>=0`,
      `M>0`).
    indices: A potentially ragged tensor with shape `[B1...BN, I]` (`N>=0`).
    name: A name for the operation (optional).

  Returns:
    A potentially ragged tensor with shape `[B1...BN, I, P2...PM]`.
    `result.ragged_rank = max(indices.ragged_rank, params.ragged_rank)`.

  Raises:
    ValueError: If the rank of `indices` is unknown or zero, or if the batch
      shape of `indices` cannot be matched against `params`.

  #### Example:
    ```python
    >>> params = tf.ragged.constant([['a', 'b', 'c'], ['d'], [], ['e']])
    >>> indices = tf.ragged.constant([[1, 2, 0], [], [], [0, 0]])
    >>> ragged.batch_gather(params, indices)
    [['b', 'c', 'a'], [], [], ['e', 'e']]
    ```
  """
  # Neither argument is ragged: defer to the dense implementation.
  if (not ragged_tensor.is_ragged(params) and
      not ragged_tensor.is_ragged(indices)):
    return array_ops.batch_gather(params, indices, name)

  with ops.name_scope(name, 'RaggedBatchGather', [params, indices]):
    params = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        params, name='params')
    indices = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        indices, name='indices')

    indices_ndims = indices.shape.ndims
    if indices_ndims is None:
      raise ValueError(
          'batch_gather does not allow indices with unknown shape.')
    if indices_ndims == 0:
      raise ValueError('indices.rank must be at least 1.')

    mismatch_message = 'batch shape from indices does not match params shape'

    if not ragged_tensor.is_ragged(indices):
      # params is a RaggedTensor and indices is a Tensor.
      if indices_ndims == 1:
        return gather(params, indices)
      if indices_ndims == 2:
        # Adjust indices from batch-local to global (in params.values)
        row_offsets = array_ops.expand_dims(params.row_starts(), 1)
        global_indices = math_ops.to_int64(indices) + row_offsets
        return gather(params.values, global_indices)
      raise ValueError(mismatch_message)

    if indices_ndims > 2:
      # The outermost ragged dimension is a batch dimension: strip it from
      # both arguments, recurse, and re-attach the row splits afterwards.
      if not ragged_tensor.is_ragged(params):
        raise ValueError(mismatch_message)
      same_splits = [
          check_ops.assert_equal(params.row_splits, indices.row_splits)
      ]
      with ops.control_dependencies(same_splits):
        inner_result = batch_gather(params.values, indices.values)
        return ragged_tensor.RaggedTensor.from_row_splits(
            inner_result, indices.row_splits)

    # From here on, indices is a 2D ragged tensor with 1 ragged dimension.
    # Ensure that `params` is ragged and has at least 2 dimensions.
    if not ragged_tensor.is_ragged(params):
      if params.shape.ndims is not None and params.shape.ndims < 2:
        raise ValueError(mismatch_message)
      params = ragged_conversion_ops.from_tensor(params, ragged_rank=1)

    # Adjust indices from within-batch to global (in params.values), and
    # then use ragged.gather to gather them.
    indices_per_row = indices.row_lengths()
    row_starts = params.row_starts()
    row_offsets = ragged_util.repeat(row_starts, indices_per_row, axis=0)
    global_index_values = math_ops.to_int64(indices.values) + row_offsets
    gathered_values = gather(params.values, global_index_values)
    return ragged_tensor.RaggedTensor.from_row_splits(
        gathered_values, indices.row_splits)
def testBadIndicesCPU(self):
  """Checks that an out-of-range index raises an op error on CPU."""
  params = [[0, 1, 2], [3, 4, 5]]
  expected_message = r"indices\[0\] = 7 is not in \[0, 2\)"
  with self.session(use_gpu=False):
    with self.assertRaisesOpError(expected_message):
      array_ops.batch_gather(params, [7]).eval()