def _get_selected_item_positions(item_selector, input_ids, axis=1):
    """Return the positions of the items picked by `item_selector`.

    Args:
      item_selector: an instance of `ItemSelector`; its `get_selection_mask`
        decides which items are picked.
      input_ids: a `RaggedTensor` with n dimensions whose items are selected on.
      axis: (optional) An int giving the dimension to apply selection on.
        Default is the 1st dimension.

    Returns:
      A `RaggedTensor` of int64s, with rank 2, shape [batch, (num_selections)],
      whose values are the positions of the items that have been selected.
    """
    # Ask the selector which items are chosen.
    selection_mask = item_selector.get_selection_mask(input_ids, axis)

    # Number the items of every row. Dimensions after the batch dimension are
    # collapsed first so that the row lengths cover all nested items.
    if input_ids.ragged_rank > 1:
        row_view = input_ids.merge_dims(1, -1)
    else:
        row_view = input_ids
    per_row_positions = ragged_math_ops.range(row_view.row_lengths())
    position_rt = input_ids.with_flat_values(per_row_positions.flat_values)

    # Keep only the positions whose mask entry is set, then flatten the result
    # down to [batch, (num_selections)].
    selected = ragged_array_ops.boolean_mask(position_rt, selection_mask)
    if selected.ragged_rank > 1:
        selected = selected.merge_dims(1, -1)
    return selected
def tokenize_with_offsets(self, input):  # pylint: disable=redefined-builtin
    """Split UTF-8 strings into Unicode characters and report their offsets.

    Returned token tensors are of integer type.

    Args:
      input: A `RaggedTensor` or `Tensor` of UTF-8 strings with any shape.

    Returns:
      A tuple `(tokens, start_offsets, end_offsets)` where:

      * `tokens`: A `RaggedTensor` of codepoints (integer type).
      * `start_offsets`: A `RaggedTensor` of the tokens' starting byte offset.
      * `end_offsets`: A `RaggedTensor` of the tokens' ending byte offset.
    """
    with ops.name_scope(None, "UnicodeCharTokenize", [input]):
        strings = ragged_tensor.convert_to_tensor_or_ragged_tensor(input)
        decoded = ragged_string_ops.unicode_decode_with_offsets(
            strings, "UTF-8")
        codepoints, byte_start_offsets = decoded

        # One length per string, with a trailing axis added so it can be
        # concatenated onto the per-character offsets below.
        lengths = string_ops.string_length(strings)
        strlens = math_ops.cast(
            array_ops.expand_dims(lengths, -1), dtypes.int64)

        # A 0-length string has no tokens, so it must not contribute an end
        # offset: drop its length entry entirely.
        final_ends = ragged_array_ops.boolean_mask(strlens, strlens > 0)

        # Every token ends where the next one starts; the last token ends at
        # the length of the string.
        next_starts = byte_start_offsets[..., 1:]
        byte_end_offsets = array_ops.concat([next_starts, final_ends], -1)
        return codepoints, byte_start_offsets, byte_end_offsets
def trim(self, segments):
    """Truncate the list of `segments`.

    The segments are truncated with the masks produced by `generate_mask`,
    i.e. according to the truncation strategy that method implements.

    Args:
      segments: A list of `RaggedTensor`s w/ shape [num_batch, (num_items)].

    Returns:
      A list of `RaggedTensor`s with one entry per mask/segment pair, each
      shaped like its counterpart in `segments` but with the unwanted values
      dropped.
    """
    with ops.name_scope("Trimmer/Trim"):
        converted = []
        for segment in segments:
            converted.append(
                ragged_tensor.convert_to_tensor_or_ragged_tensor(segment))

        masks = self.generate_mask(converted)

        trimmed = []
        for segment, mask in zip(converted, masks):
            # The mask's row-splits dtype is aligned with the segment's before
            # the mask is applied.
            aligned_mask = mask.with_row_splits_dtype(segment.row_splits.dtype)
            trimmed.append(
                ragged_array_ops.boolean_mask(segment, aligned_mask))
        return trimmed
def remove_subtokens(inputs, mask):
    """Remove wordpiece subtokens.

    Args:
      inputs: a sequence that must hold exactly one encoder input.
      mask: values cast to boolean; entries that cast to False are dropped.

    Returns:
      The result of `boolean_mask` on the single encoder input, converted with
      `to_tensor()`.

    Raises:
      AssertionError: if `inputs` does not contain exactly one element.
    """
    input_count = len(inputs)
    if input_count != 1:
        message = ("'%s' cannot have multiple inputs" %
                   constants.ENCODER_REMOVE_SUBTOKENS)
        raise AssertionError(message)

    encoder_input = get_encoder_input(inputs[0])
    keep = tf.cast(mask, tf.bool)
    kept = boolean_mask(encoder_input, keep)
    return kept.to_tensor()
def get_selection_mask(self, input_ids, axis):
    """Build a boolean mask marking the items chosen after shuffling.

    Starting from the items the superclass reports as selectable, each row is
    shuffled with `self._shuffle_fn` and cut down to a per-row budget of
    `ceil(num_selectable * self.selection_rate)`, capped at
    `self.max_selections_per_batch`. The chosen items are then marked in a
    mask shaped like `input_ids`.

    Args:
      input_ids: the items to select from. It is accessed through
        `flat_values` / `with_flat_values`, so it is presumably a
        `RaggedTensor` — confirm against the superclass contract.
      axis: int, the dimension selection is applied on.

    Returns:
      A boolean tensor with the ragged structure of `input_ids`; when `axis`
      is smaller than the result's ragged rank, the dimensions below `axis`
      are reduced away with `reduce_all`.
    """
    # NOTE(review): self._shuffle_fn is presumably a random shuffle unless the
    # caller overrides it — confirm in the constructor. If so, adding, removing
    # or reordering ops here may change seeded results.
    selectable = super(RandomItemSelector, self).get_selectable(input_ids, axis)
    # Run the selection algorithm on positions RT: every item is replaced by
    # its index into `input_ids.flat_values`.
    positions_flat = math_ops.range(array_ops.size(input_ids.flat_values))
    positions = input_ids.with_flat_values(positions_flat)
    # Mask out positions that are not selectable
    positions = ragged_array_ops.boolean_mask(positions, selectable)
    # merge to the desired axis
    positions = positions.merge_dims(1, axis) if axis > 1 else positions
    # Figure out how many we are going to select: a per-row count, rounded up
    # and then capped elementwise by max_selections_per_batch.
    num_to_select = math_ops.ceil(
        math_ops.cast(positions.row_lengths(), dtypes.float32) *
        self.selection_rate)
    num_to_select = math_ops.minimum(num_to_select,
                                     self.max_selections_per_batch)
    num_to_select = math_ops.cast(num_to_select, dtypes.int64)

    # Shuffle and trim to items that are going to be selected
    def _shuffle_and_trim(x):
        # x is one (row of positions, budget for that row) pair from map_fn.
        positions, top_n = x
        if isinstance(positions, ragged_tensor.RaggedTensor):
            # The row is itself ragged: shuffle the indices of its sub-rows
            # and keep the first top_n sub-rows whole.
            positions_at_axis = math_ops.range(positions.nrows())
            chosen_positions_at_axis = self._shuffle_fn(
                positions_at_axis)[:top_n]
            return array_ops.gather(positions, chosen_positions_at_axis)
        else:
            # Flat row: shuffle the positions and keep the first top_n.
            shuffled = self._shuffle_fn(positions)
            return shuffled[:top_n]

    # Each row yields a variable number of positions, so the per-row output is
    # declared as ragged with one ragged dimension less than `positions`.
    selected_for_mask = map_fn.map_fn(
        _shuffle_and_trim, (positions, num_to_select),
        fn_output_signature=ragged_tensor.RaggedTensorSpec(
            ragged_rank=positions.ragged_rank - 1, dtype=positions.dtype))
    # Declare the flat values as rank 1 with unknown length.
    selected_for_mask.flat_values.set_shape([None])
    # Construct the result which is a boolean RT
    # Scatter 1's to positions that have been selected_for_mask
    update_values = array_ops.ones_like(selected_for_mask.flat_values)
    update_values = math_ops.cast(update_values, input_ids.dtype)
    update_indices = selected_for_mask.flat_values
    update_indices = array_ops.expand_dims(update_indices, -1)
    # NOTE(review): indices are cast to input_ids.dtype, which assumes that
    # dtype is an integer type accepted as scatter indices — confirm.
    update_indices = math_ops.cast(update_indices, input_ids.dtype)
    results_flat = array_ops.zeros_like(input_ids.flat_values)
    results_flat = gen_array_ops.tensor_scatter_update(
        results_flat, update_indices, update_values)
    results = math_ops.cast(input_ids.with_flat_values(results_flat), dtypes.bool)
    if axis < results.ragged_rank:
        # Collapse the dimensions from ragged_rank down to axis + 1: a group
        # at `axis` stays True only if all of its nested items are True.
        reduce_axis = list(range(results.ragged_rank, axis, -1))
        results = math_ops.reduce_all(results, reduce_axis)
    return results
def get_selectable(self, input_ids, axis):
    """See `get_selectable()` in superclass.

    Narrows the superclass's selectable mask down to the items picked by
    `_get_selection_mask` with a budget of `self._num_to_select`.
    """
    base_selectable = super(FirstNItemSelector, self).get_selectable(
        input_ids, axis)
    pos_axis = array_ops.get_positive_axis(
        axis, input_ids.ragged_rank + input_ids.flat_values.shape.rank)

    # Replace every item by its index into the flat values, then keep only the
    # indices of items the superclass allows to be selected.
    flat_index = math_ops.range(array_ops.size(input_ids.flat_values))
    index_rt = input_ids.with_flat_values(flat_index)
    candidates = ragged_array_ops.boolean_mask(index_rt, base_selectable)

    # Merge to the desired axis.
    if pos_axis > 1:
        candidates = candidates.merge_dims(1, pos_axis)

    # Get a selection mask based off of how many items are desired.
    merged_axis = pos_axis - (pos_axis - 1)
    keep_mask = _get_selection_mask(candidates, self._num_to_select,
                                    merged_axis)

    # Drop the candidates that were not selected.
    chosen = ragged_array_ops.boolean_mask(candidates, keep_mask)

    # Rebuild a mask matching the original input: start from all zeros and
    # scatter ones onto the flat positions that were chosen.
    ones = chosen.with_flat_values(array_ops.ones_like(chosen.flat_values))
    zeros = math_ops.cast(
        array_ops.zeros(array_ops.shape(input_ids.flat_values)), dtypes.int32)
    scatter_indices = array_ops.expand_dims(chosen.flat_values, -1)
    scattered = array_ops.tensor_scatter_update(
        zeros, scatter_indices, ones.flat_values)
    mask = math_ops.cast(input_ids.with_flat_values(scattered), dtypes.bool)

    # Reduce until input.shape[:axis]
    trailing_dims = input_ids.shape.ndims - pos_axis - 1
    for _ in range(trailing_dims):
        mask = math_ops.reduce_all(mask, -1)
    return mask
def testNothingSelector(self,
                        masking_inputs,
                        unselectable_ids,
                        expected_selected_items,
                        num_to_select=2,
                        axis=1,
                        description=""):
    """Checks the items left after masking with `NothingSelector`.

    The inputs are masked with the selector's `get_selectable` result and
    compared against `expected_selected_items`. `unselectable_ids`,
    `num_to_select` and `description` are accepted but not used in the body.
    """
    ragged_inputs = ragged_factory_ops.constant(masking_inputs)
    selector = item_selector_ops.NothingSelector()
    selectable_mask = selector.get_selectable(ragged_inputs, axis)
    self.assertAllEqual(
        ragged_array_ops.boolean_mask(ragged_inputs, selectable_mask),
        expected_selected_items)
def testGetSelectable(self,
                      masking_inputs,
                      expected_selectable,
                      num_to_select=2,
                      unselectable_ids=None,
                      axis=1,
                      description=""):
    """Checks which items `FirstNItemSelector.get_selectable` lets through.

    The inputs are masked with the selector's selectable mask and compared
    against `expected_selectable`. `description` is not used in the body.
    """
    ragged_inputs = ragged_factory_ops.constant(masking_inputs)
    selector = item_selector_ops.FirstNItemSelector(
        num_to_select=num_to_select, unselectable_ids=unselectable_ids)
    selectable_mask = selector.get_selectable(ragged_inputs, axis)
    self.assertAllEqual(
        ragged_array_ops.boolean_mask(ragged_inputs, selectable_mask),
        expected_selectable)
def testGetSelectionMask(self,
                         masking_inputs,
                         expected_selected_items,
                         unselectable_ids=None,
                         axis=1,
                         shuffle_fn="",
                         description=""):
    """Checks the items picked by `RandomItemSelector.get_selection_mask`.

    `shuffle_fn` is a string switch: "reverse" makes the selector reverse the
    last axis, anything else makes it use the identity. `description` is not
    used in the body.
    """
    if shuffle_fn == "reverse":
        shuffle_fn = functools.partial(array_ops.reverse, axis=[-1])
    else:
        shuffle_fn = array_ops.identity

    ragged_inputs = ragged_factory_ops.constant(masking_inputs)
    selector = item_selector_ops.RandomItemSelector(
        max_selections_per_batch=2,
        selection_rate=1,
        shuffle_fn=shuffle_fn,
        unselectable_ids=unselectable_ids,
    )
    mask = selector.get_selection_mask(ragged_inputs, axis)
    self.assertAllEqual(
        ragged_array_ops.boolean_mask(ragged_inputs, mask),
        expected_selected_items)
def testBooleanMask(self, descr, data, mask, expected):
    """Checks that `boolean_mask(data, mask)` equals `expected`.

    `descr` is not used in the body.
    """
    self.assertAllEqual(ragged_array_ops.boolean_mask(data, mask), expected)
def testBooleanMask(self, descr, data, mask, keepdims, expected):
    """Checks that `boolean_mask(data, mask, keepdims=...)` equals `expected`.

    `descr` is not used in the body.
    """
    masked = ragged_array_ops.boolean_mask(data, mask, keepdims=keepdims)
    self.assertRaggedEqual(masked, expected)