def get_mask_values(self, masked_lm_ids):
  """Get the values used for masking, random injection or no-op.

  Every selected item receives one random draw `r` from `_get_random`
  (presumably uniform in `[0, 1)` -- confirm against `_get_random`) and is
  rewritten according to the bucket that `r` falls into:

    * `r < mask_token_rate`: replaced by the mask token.
    * `mask_token_rate <= r < mask_token_rate + random_token_rate`: replaced
      by a random id drawn from `[0, vocab_size)`.
    * otherwise: the original id is kept.

  Args:
    masked_lm_ids: a `RaggedTensor` of n dimensions and dtype int32 or int64
      whose values are the ids of items that have been selected for masking.

  Returns:
    a `RaggedTensor` of the same dtype and shape with `masked_lm_ids` whose
    values contain either the mask token, randomly injected token or original
    value.
  """
  validate_rates = control_flow_ops.Assert(
      self._mask_token_rate + self._random_token_rate <= 1,
      ["mask_token_rate + random_token_rate must be <= 1"])
  with ops.control_dependencies([validate_rates]):
    # Generate a random number for all mask-able items. Items that should be
    # treated atomically (e.g. all wordpieces in a token, span, etc) will have
    # the same random number.
    random_uniform = _get_random(masked_lm_ids)

    # Merge down to rank 2.
    if random_uniform.ragged_rank != 1:
      random_uniform = random_uniform.merge_dims(1, -1)
    draws = random_uniform.flat_values

    # Cast both rates once so every comparison below uses the same dtype.
    mask_rate = math_ops.cast(self._mask_token_rate, dtypes.float32)
    random_rate = math_ops.cast(self._random_token_rate, dtypes.float32)

    mask_values = masked_lm_ids
    ids_dtype = mask_values.flat_values.dtype

    # Maybe add mask token `mask_token_rate`% of the time.
    all_mask_flat = array_ops.tile(
        [self._mask_token], array_ops.shape(mask_values.flat_values))
    should_mask_flat = draws < mask_rate
    mask_values = mask_values.with_flat_values(
        ragged_where_op.where(
            should_mask_flat,
            x=math_ops.cast(all_mask_flat, ids_dtype),
            y=mask_values.flat_values))

    # Maybe inject random token `random_token_rate`% of the time. The window
    # sits directly above the masking window, so its upper edge is the *sum*
    # of both rates. Comparing against `random_token_rate` alone leaves the
    # window empty whenever `random_token_rate <= mask_token_rate`.
    all_random_flat = random_ops.random_uniform(
        array_ops.shape(mask_values.flat_values),
        maxval=math_ops.cast(self._vocab_size, dtypes.float32))
    should_inject_random_flat = math_ops.logical_and(
        draws >= mask_rate,
        draws < mask_rate + random_rate)
    mask_values = mask_values.with_flat_values(
        ragged_where_op.where(
            should_inject_random_flat,
            x=math_ops.cast(all_random_flat, ids_dtype),
            y=mask_values.flat_values))
    return mask_values
def _multivalent_span_alignment(overlaps):
  """Lists, for every source span, the target spans that overlap it.

  Args:
    overlaps: `[D1...DB, source_size, target_size]` tensor (dense or ragged).
      `overlaps[b1...bB, i, j]` is true if source span `i` overlaps target
      span `j` (in batch `b1...bB`).

  Returns:
    `<int64>[D1...DB, source_size, (num_aligned_spans)]`:
    `result[b1...bB, i, n]=j` if target span `j` is the `n`'th target span
    that aligns with source span `i` (in batch `b1...bB`).
  """
  rank = overlaps.shape.ndims
  assert rank is not None  # guaranteed/checked by span_overlaps()
  assert rank >= 2

  # Case 1: no batch dimension. The coordinates of the true entries are
  # (source, target) pairs; group the targets by their source row.
  if rank == 2:
    assert not isinstance(overlaps, ragged_tensor.RaggedTensor)
    hits = array_ops.where(overlaps)
    return ragged_tensor.RaggedTensor.from_value_rowids(
        values=hits[:, 1],
        value_rowids=hits[:, 0],
        nrows=array_ops.shape(overlaps, out_type=dtypes.int64)[0])

  # Case 2: exactly one batch dimension. Work on the values with the batch
  # dimension stripped off, then re-attach the batch partitioning.
  if rank == 3:
    if not isinstance(overlaps, ragged_tensor.RaggedTensor):
      overlaps = ragged_tensor.RaggedTensor.from_tensor(
          overlaps, ragged_rank=1)
    inner = overlaps.values
    hits = ragged_where_op.where(inner)
    if isinstance(inner, ragged_tensor.RaggedTensor):
      num_source_rows = inner.nrows()
    else:
      num_source_rows = array_ops.shape(inner, out_type=dtypes.int64)[0]
    aligned = ragged_tensor.RaggedTensor.from_value_rowids(
        values=hits[:, 1],
        value_rowids=hits[:, 0],
        nrows=num_source_rows)
    return overlaps.with_values(aligned)

  # Case 3: several batch dimensions. Peel off the outermost one and recurse
  # on what remains.
  if not isinstance(overlaps, ragged_tensor.RaggedTensor):
    overlaps = ragged_tensor.RaggedTensor.from_tensor(
        overlaps, ragged_rank=overlaps.shape.ndims - 3)
  return overlaps.with_values(_multivalent_span_alignment(overlaps.values))
def get_segments(self, sentences):
  """Pairs every sentence with a follow-up sentence and a label.

  For each sentence the candidate "next" sentence is obtained by rolling its
  row by -1, so the last sentence of a row is paired with the first sentence
  of that same row. The alternative candidate comes from applying
  `self._shuffle_fn` to the flat sentence values. A draw from
  `self._random_fn` compared against `self._random_next_sentence_threshold`
  decides which candidate is used.

  Args:
    sentences: A `RaggedTensor` of strings w/ shape [batch, (num_sentences)].

  Returns:
    A tuple of (segment_a, segment_b, is_next_sentence) where:

    segment_a: `sentences` with its last two dimensions merged, i.e. all the
      original sentences.
    segment_b: Same shape as `segment_a`; holds either the rolled (next)
      sentence or the sentence produced by `self._shuffle_fn`.
    is_next_sentence: Same shape as `segment_a`, dtype int64; non-zero where
      `segment_b` was taken from the rolled (next) sentences.
  """
  # Within each row, shift the sentences one position to the left.
  shift_left = functools.partial(manip_ops.roll, axis=0, shift=-1)
  following = ragged_map_ops.map_fn(
      shift_left,
      sentences,
      dtype=ragged_tensor.RaggedTensorType(dtypes.string, 1),
      infer_shape=False)

  shuffled = sentences.with_flat_values(
      self._shuffle_fn(sentences.flat_values))

  # Randomly decide if we should use next sentence or throw in a random
  # sentence.
  draws = self._random_fn(sentences.flat_values.shape)
  use_following = sentences.with_flat_values(
      draws > self._random_next_sentence_threshold)
  second_segment = ragged_where_op.where(
      use_following, x=following, y=shuffled)

  # Get rid of the docs dimensions.
  first_segment = sentences.merge_dims(-2, -1)
  second_segment = second_segment.merge_dims(-2, -1)
  labels = math_ops.cast(use_following.merge_dims(-2, -1), dtypes.int64)
  return first_segment, second_segment, labels
def batch_gather_with_default(params, indices, default_value='', name=None):
  """Same as `batch_gather` but inserts `default_value` for invalid indices.

  This operation is similar to `batch_gather` except that it will substitute
  the value for invalid indices with `default_value` as the contents.
  See `batch_gather` for more details.

  An index is invalid when it is negative or when it is greater than or equal
  to the number of elements available along the gathered dimension.

  Args:
    params: A potentially ragged tensor with shape `[B1...BN, P1...PM]` (`N>=0`,
      `M>0`).
    indices: A potentially ragged tensor with shape `[B1...BN, I]` (`N>=0`).
    default_value: A value to be inserted in places where `indices` are out of
      bounds. Must be the same dtype as params and either a scalar or rank 1.
    name: A name for the operation (optional).

  Returns:
    A potentially ragged tensor with shape `[B1...BN, I, P2...PM]`.
    `result.ragged_rank = max(indices.ragged_rank, params.ragged_rank)`.

  Raises:
    ValueError: If `default_value` is neither a scalar nor a vector, or if the
      rank of `indices` or `params` is not statically known.

  #### Example:
    ```python
    >>> params = tf.ragged.constant([
          ['a', 'b', 'c'],
          ['d'],
          [],
          ['e']])
    >>> indices = tf.ragged.constant([[1, 2, -1], [], [], [0, 10]])
    >>> batch_gather_with_default(params, indices, 'FOO')
    [['b', 'c', 'FOO'], [], [], ['e', 'FOO']]
    ```
  """
  with ops.name_scope(name, 'RaggedBatchGatherWithDefault'):
    params = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        params, name='params',
    )
    indices = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        indices, name='indices',
    )
    default_value = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        default_value, name='default_value',
    )
    # TODO(hterry): lift this restriction and support default_values of
    #               of rank > 1
    # Compare by value; an unknown rank (None) is rejected as well.
    if default_value.shape.ndims not in (0, 1):
      raise ValueError('"default_value" must be a scalar or vector')
    if indices.shape.ndims is None:
      raise ValueError('Indices must have a known rank.')
    if params.shape.ndims is None:
      raise ValueError('Params must have a known rank.')

    num_batch_dimensions = indices.shape.ndims - 1

    # The logic for this works as follows:
    # - create a padded params, where:
    #    padded_params[b1...bn, 0] = default_value
    #    padded_params[b1...bn, i] = params[b1...bn, i-1] (i>0)
    # - create an `upper_bounds` Tensor that contains the number of elements
    #   in each innermost rank. Broadcast `upper_bounds` to be the same shape
    #   as `indices`.
    # - check to see which index in `indices` are out of bounds and substitute
    #   it with the index containing `default_value` (the first).
    # - call batch_gather with the indices adjusted.
    with ops.control_dependencies([
        check_ops.assert_greater_equal(array_ops.rank(params),
                                       array_ops.rank(indices))
    ]):
      if ragged_tensor.is_ragged(params):
        row_lengths = ragged_array_ops.expand_dims(
            params.row_lengths(axis=num_batch_dimensions), axis=-1)
        upper_bounds = math_ops.cast(row_lengths, indices.dtype)

        pad_shape = _get_pad_shape(params, indices)
        pad = ragged_tensor_shape.broadcast_to(default_value, pad_shape)
      else:
        params_shape = array_ops.shape(params)
        pad_shape = array_ops.concat([
            params_shape[:num_batch_dimensions],
            [1],
            params_shape[num_batch_dimensions + 1:params.shape.ndims]
        ], 0)
        # Cast like the ragged branch does so the comparison against
        # `indices` below never mixes integer dtypes.
        upper_bounds = math_ops.cast(params_shape[num_batch_dimensions],
                                     indices.dtype)
        pad = array_ops.broadcast_to(default_value, pad_shape)

      # Add `default_value` as the first value in the innermost (ragged) rank.
      pad = math_ops.cast(pad, params.dtype)
      padded_params = array_ops.concat(
          [pad, params], axis=num_batch_dimensions)

      # Adjust the indices by substituting out-of-bound indices to the
      # default-value index (which is the first element). Valid indices are
      # `0 .. upper_bounds - 1`, so `upper_bounds` itself is out of bounds;
      # leaving it unflagged would shift it past the end of `padded_params`.
      shifted_indices = indices + 1
      is_out_of_bounds = (indices < 0) | (indices >= upper_bounds)
      adjusted_indices = ragged_where_op.where(
          is_out_of_bounds,
          x=array_ops.zeros_like(indices),
          y=shifted_indices,
      )
      return array_ops.batch_gather(
          params=padded_params, indices=adjusted_indices, name=name)
def testRaggedWhereErrors(self, condition, error, message, x=None, y=None):
  """Checks that `ragged_where_op.where` rejects the given arguments.

  The arguments are presumably supplied by a parameterized-test decorator that
  is not visible here.

  Args:
    condition: The `condition` argument passed to `where`.
    error: The exception type the call is expected to raise.
    message: A regular expression the exception message must match.
    x: Optional `x` argument passed to `where`.
    y: Optional `y` argument passed to `where`.
  """
  expect_failure = self.assertRaisesRegex(error, message)
  with expect_failure:
    ragged_where_op.where(condition, x=x, y=y)
def testRaggedWhere(self, condition, expected, x=None, y=None):
  """Checks that `ragged_where_op.where` produces the expected value.

  The arguments are presumably supplied by a parameterized-test decorator that
  is not visible here.

  Args:
    condition: The `condition` argument passed to `where`.
    expected: The value the result is compared against.
    x: Optional `x` argument passed to `where`.
    y: Optional `y` argument passed to `where`.
  """
  actual = ragged_where_op.where(condition, x=x, y=y)
  self.assertAllEqual(actual, expected)
def batch_gather_with_default(params, indices, default_value='', name=None):
  """Same as `batch_gather` but inserts `default_value` for invalid indices.

  This operation is similar to `batch_gather` except that it will substitute
  the value for invalid indices with `default_value` as the contents.
  See `batch_gather` for more details.

  An index is invalid when it is negative or when it is greater than or equal
  to the number of elements available along the gathered dimension.

  Args:
    params: A potentially ragged tensor with shape `[B1...BN, P1...PM]` (`N>=0`,
      `M>0`).
    indices: A potentially ragged tensor with shape `[B1...BN, I]` (`N>=0`).
    default_value: A value to be inserted in places where `indices` are out of
      bounds. Must be the same dtype as params and either a scalar or rank 1.
    name: A name for the operation (optional).

  Returns:
    A potentially ragged tensor with shape `[B1...BN, I, P2...PM]`.
    `result.ragged_rank = max(indices.ragged_rank, params.ragged_rank)`.

  Raises:
    ValueError: If `default_value` is neither a scalar nor a vector, or if the
      rank of `indices` or `params` is not statically known.

  #### Example:
    ```python
    >>> params = tf.ragged.constant([
          ['a', 'b', 'c'],
          ['d'],
          [],
          ['e']])
    >>> indices = tf.ragged.constant([[1, 2, -1], [], [], [0, 10]])
    >>> batch_gather_with_default(params, indices, 'FOO')
    [['b', 'c', 'FOO'], [], [], ['e', 'FOO']]
    ```
  """
  with ops.name_scope(name, 'RaggedBatchGatherWithDefault'):
    params = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        params, name='params',
    )
    indices = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        indices, name='indices',
    )
    default_value = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        default_value, name='default_value',
    )
    # TODO(hterry): lift this restriction and support default_values of
    #               of rank > 1
    # Compare by value; an unknown rank (None) is rejected as well.
    if default_value.shape.ndims not in (0, 1):
      raise ValueError('"default_value" must be a scalar or vector')
    if indices.shape.ndims is None:
      raise ValueError('Indices must have a known rank.')
    if params.shape.ndims is None:
      raise ValueError('Params must have a known rank.')

    num_batch_dimensions = indices.shape.ndims - 1

    # The logic for this works as follows:
    # - create a padded params, where:
    #    padded_params[b1...bn, 0] = default_value
    #    padded_params[b1...bn, i] = params[b1...bn, i-1] (i>0)
    # - create an `upper_bounds` Tensor that contains the number of elements
    #   in each innermost rank. Broadcast `upper_bounds` to be the same shape
    #   as `indices`.
    # - check to see which index in `indices` are out of bounds and substitute
    #   it with the index containing `default_value` (the first).
    # - call batch_gather with the indices adjusted.
    with ops.control_dependencies([
        check_ops.assert_greater_equal(array_ops.rank(params),
                                       array_ops.rank(indices))
    ]):
      if ragged_tensor.is_ragged(params):
        row_lengths = ragged_array_ops.expand_dims(
            params.row_lengths(axis=num_batch_dimensions), axis=-1)
        upper_bounds = math_ops.cast(row_lengths, indices.dtype)

        pad_shape = _get_pad_shape(params, indices)
        pad = ragged_tensor_shape.broadcast_to(default_value, pad_shape)
      else:
        params_shape = array_ops.shape(params)
        pad_shape = array_ops.concat([
            params_shape[:num_batch_dimensions],
            [1],
            params_shape[num_batch_dimensions + 1:params.shape.ndims]
        ], 0)
        # Cast like the ragged branch does so the comparison against
        # `indices` below never mixes integer dtypes.
        upper_bounds = math_ops.cast(params_shape[num_batch_dimensions],
                                     indices.dtype)
        pad = array_ops.broadcast_to(default_value, pad_shape)

      # Add `default_value` as the first value in the innermost (ragged) rank.
      pad = math_ops.cast(pad, params.dtype)
      padded_params = array_ops.concat(
          [pad, params], axis=num_batch_dimensions)

      # Adjust the indices by substituting out-of-bound indices to the
      # default-value index (which is the first element). Valid indices are
      # `0 .. upper_bounds - 1`, so `upper_bounds` itself is out of bounds;
      # leaving it unflagged would shift it past the end of `padded_params`.
      shifted_indices = indices + 1
      is_out_of_bounds = (indices < 0) | (indices >= upper_bounds)
      adjusted_indices = ragged_where_op.where(
          is_out_of_bounds,
          x=array_ops.zeros_like(indices),
          y=shifted_indices,
      )
      return array_ops.batch_gather(
          params=padded_params, indices=adjusted_indices, name=name)
def testRaggedWhereErrors(self, condition, error, message, x=None, y=None):
  """Checks that `ragged_where_op.where` rejects the given arguments.

  The arguments are presumably supplied by a parameterized-test decorator that
  is not visible here.

  Args:
    condition: The `condition` argument passed to `where`.
    error: The exception type the call is expected to raise.
    message: A regular expression the exception message must match.
    x: Optional `x` argument passed to `where`.
    y: Optional `y` argument passed to `where`.
  """
  # `assertRaisesRegexp` is a deprecated alias that was removed in
  # Python 3.12; `assertRaisesRegex` is the supported spelling.
  with self.assertRaisesRegex(error, message):
    ragged_where_op.where(condition, x, y)
def testRaggedWhere(self, condition, expected, x=None, y=None):
  """Checks that `ragged_where_op.where` produces the expected value.

  The arguments are presumably supplied by a parameterized-test decorator that
  is not visible here.

  Args:
    condition: The `condition` argument passed to `where`.
    expected: The value the result is compared against.
    x: Optional `x` argument passed to `where`.
    y: Optional `y` argument passed to `where`.
  """
  actual = ragged_where_op.where(condition, x=x, y=y)
  self.assertRaggedEqual(actual, expected)