def testMismatchRaggedRank(self):
  # map_fn must raise a ValueError when the ragged_rank declared in `dtype`
  # is not the ragged_rank of the value that the mapped function returns.
  nested = ragged_factory_ops.constant([[[1, 2, 3]], [[4, 5], [6, 7]]])
  sum_over_rows = lambda row: ragged_math_ops.reduce_sum(row, axis=0)
  expected_error = r'(?s)Expected `fn` to return.*But it returned.*'

  with self.assertRaisesRegex(ValueError, expected_error):
    # The declared type is built inside the context so that everything the
    # original statement evaluated is still covered by the assertion.
    declared_type = ragged_tensor.RaggedTensorType(
        dtype=dtypes.int64, ragged_rank=23)
    ragged_map_ops.map_fn(sum_over_rows, nested, dtype=declared_type)
def testMismatchRaggedRank(self):
  # map_fn must reject a declared ragged_rank (23) that disagrees with the
  # ragged_rank of the mapped function's result, using this exact message.
  nested = ragged_factory_ops.constant([[[1, 2, 3]], [[4, 5], [6, 7]]])
  sum_over_rows = lambda row: ragged_math_ops.reduce_sum(row, axis=0)
  expected_message = (
      r'The declared ragged rank (23) mismatches the result (1)')

  with self.assertRaisesWithLiteralMatch(ValueError, expected_message):
    # Built inside the context so the assertion still covers construction
    # of the declared type, as in the single-statement form.
    declared_type = ragged_tensor.RaggedTensorType(
        dtype=dtypes.int64, ragged_rank=23)
    ragged_map_ops.map_fn(sum_over_rows, nested, dtype=declared_type)
def testMismatchRaggedRank(self):
  # The dtype passed to map_fn declares ragged_rank=23; the test expects a
  # ValueError whose text matches the literal message below.
  ragged_input = ragged_factory_ops.constant([[[1, 2, 3]], [[4, 5], [6, 7]]])
  reduce_axis0 = lambda elem: ragged_math_ops.reduce_sum(elem, axis=0)
  literal_error = r'The declared ragged rank (23) mismatches the result (1)'

  with self.assertRaisesWithLiteralMatch(ValueError, literal_error):
    # Kept inside the `with` so the same expressions are evaluated under
    # the assertion as before.
    wrong_rank_type = ragged_tensor.RaggedTensorType(
        dtype=dtypes.int64, ragged_rank=23)
    ragged_map_ops.map_fn(reduce_axis0, ragged_input, dtype=wrong_rank_type)
def boolean_mask(data, mask, name=None):
  """Filters `data` by a boolean `mask` while keeping the mask dimensions.

  The result is a potentially ragged tensor holding exactly those elements of
  `data` whose corresponding entry in `mask` is `True`:

  * `output[a1...aA, i, b1...bB] = data[a1...aA, j, b1...bB]`

  Where `j` is the `i`th `True` entry of `mask[a1...aA]`.

  The mask dimensions `a1...aA` are preserved in `output`. This is the
  difference from `tf.boolean_mask`, which flattens those dimensions.

  Args:
    data: A potentially ragged tensor.
    mask: A potentially ragged boolean tensor.  `mask`'s shape must be a prefix
      of `data`'s shape.  `rank(mask)` must be known statically.
    name: A name prefix for the returned tensor (optional).

  Returns:
    A potentially ragged tensor that is formed by retaining the elements in
    `data` where the corresponding value in `mask` is `True`.

    * `rank(output) = rank(data)`.
    * `output.ragged_rank = max(data.ragged_rank, rank(mask) - 1)`.

  Raises:
    ValueError: if `rank(mask)` is not known statically; or if `mask.shape` is
      not a prefix of `data.shape`.

  #### Examples:

  >>> # Aliases for True & False so data and mask line up.
  >>> T, F = (True, False)

  >>> tf.ragged.boolean_mask(  # Mask a 2D Tensor.
  ...     data=[[1, 2, 3], [4, 5, 6], [7, 8, 9]],
  ...     mask=[[T, F, T], [F, F, F], [T, F, F]]).to_list()
  [[1, 3], [], [7]]

  >>> tf.ragged.boolean_mask(  # Mask a 2D RaggedTensor.
  ...     tf.ragged.constant([[1, 2, 3], [4], [5, 6]]),
  ...     tf.ragged.constant([[F, F, T], [F], [T, T]])).to_list()
  [[3], [], [5, 6]]

  >>> tf.ragged.boolean_mask(  # Mask rows of a 2D RaggedTensor.
  ...     tf.ragged.constant([[1, 2, 3], [4], [5, 6]]),
  ...     tf.ragged.constant([True, False, True])).to_list()
  [[1, 2, 3], [5, 6]]
  """
  with ops.name_scope(name, 'RaggedMask', [data, mask]):
    # Normalize both arguments to (possibly ragged) tensors that agree on a
    # single row_splits dtype.
    data = ragged_tensor.convert_to_tensor_or_ragged_tensor(data, name='data')
    mask = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        mask, dtypes.bool, name='mask')
    row_splits_dtype, (data, mask) = ragged_tensor.match_row_splits_dtypes(
        data, mask, return_dtype=True)

    # The mask rank must be statically known and at least 1.
    mask_rank = mask.shape.ndims
    if mask_rank is None:
      raise ValueError('mask.shape.ndims must be known statically.')
    if mask_rank == 0:
      raise ValueError('mask cannot be scalar.')

    # Case 1: ragged mask.  Peel off its ragged dimensions and recurse on the
    # non-ragged remainder.
    if ragged_tensor.is_ragged(mask):
      if not ragged_tensor.is_ragged(data):
        data = ragged_tensor.RaggedTensor.from_tensor(
            data,
            ragged_rank=mask.ragged_rank,
            row_splits_dtype=mask.row_splits.dtype)

      # mask.nested_row_splits has to be a prefix of data.nested_row_splits.
      splits_to_compare = [
          mask.nested_row_splits,
          data.nested_row_splits[:mask.ragged_rank],
      ]
      with ops.control_dependencies(
          ragged_util.assert_splits_match(splits_to_compare)):
        # Collect one splits tensor per stripped ragged dimension so the
        # dimensions can be re-attached to the recursive result.
        outer_splits = []
        while ragged_tensor.is_ragged(mask):
          if mask.shape.ndims > 2:
            outer_splits.append(mask.row_splits)
          else:
            # Innermost mask dimension: the filtered row lengths are the
            # per-row counts of True values.
            mask_as_int = ragged_functional_ops.map_flat_values(
                math_ops.cast, mask, dtype=row_splits_dtype)
            true_counts = ragged_math_ops.reduce_sum(mask_as_int, axis=1)
            outer_splits.append(ragged_util.lengths_to_splits(true_counts))
          mask = mask.values
          data = data.values

        # Mask the stripped data with the stripped (non-ragged) mask, then
        # wrap the result in the collected splits.
        flat_result = boolean_mask(data, mask)
        return ragged_tensor.RaggedTensor.from_nested_row_splits(
            flat_result, outer_splits, validate=False)

    if ragged_tensor.is_ragged(data):
      # Case 2: ragged data, rank-1 dense mask.  The mask selects whole rows.
      if mask_rank == 1:
        # Row splits of the result: lengths of the surviving rows only.
        kept_lengths = array_ops.boolean_mask(data.row_lengths(), mask)
        kept_splits = ragged_util.lengths_to_splits(kept_lengths)

        # Values of the result: broadcast the row mask onto every value via
        # its row id, then mask the values with it.
        row_ids = segment_id_ops.row_splits_to_segment_ids(data.row_splits)
        value_mask = array_ops.gather(mask, row_ids)
        kept_values = boolean_mask(data.values, value_mask)

        return ragged_tensor.RaggedTensor.from_row_splits(
            kept_values, kept_splits, validate=False)

      # Case 3: ragged data, dense mask of rank > 1.  Make the mask ragged
      # (matching data's ragged rank where possible) and recurse into case 1.
      mask = ragged_tensor.RaggedTensor.from_tensor(
          mask,
          ragged_rank=min(data.ragged_rank, mask_rank - 1),
          row_splits_dtype=data.row_splits.dtype)
      return boolean_mask(data, mask)

    # Case 4: both `data` and `mask` are plain `Tensor`s.
    result = array_ops.boolean_mask(data, mask)
    if mask_rank < 2:
      return result

    # Innermost ragged dimension: count the True values in each innermost
    # mask cell, flatten the counts, and use them as row lengths.
    cell_counts = math_ops.count_nonzero(
        mask, axis=-1, dtype=row_splits_dtype)
    flat_cell_counts = array_ops.reshape(cell_counts, [-1])
    result = ragged_tensor.RaggedTensor.from_row_lengths(
        result, flat_cell_counts, validate=False)

    # Remaining (uniform) mask dimensions, wrapped from the inside out.
    if mask_rank > 2:
      mask_shape = array_ops.shape(mask, out_type=row_splits_dtype)
      split_size = math_ops.cumprod(mask_shape) + 1
      for dim in reversed(range(mask_rank - 2)):
        elt_size = mask_shape[dim + 1]
        uniform_splits = math_ops.range(split_size[dim]) * elt_size
        result = ragged_tensor.RaggedTensor.from_row_splits(
            result, uniform_splits, validate=False)

    return result
def testReduceKeepsInnerDimensionShape(self):
  # Regression test for bug [b/139823356]: reducing over a ragged axis must
  # leave the static size of the uniform inner dimension in the result shape.
  ragged_input = ragged_factory_ops.constant([[[[1, 1]]]], ragged_rank=2)
  input_dims = ragged_input.shape.as_list()
  self.assertEqual(input_dims, [1, None, None, 2])

  summed = ragged_math_ops.reduce_sum(ragged_input, axis=2)
  summed_dims = summed.shape.as_list()
  self.assertEqual(summed_dims, [1, None, 2])
class RaggedMapOpTest(test_util.TensorFlowTestCase, parameterized.TestCase):

  # Every case below maps `fn` over `elems` with ragged_map_ops.map_fn.  The
  # comment above each case gives the input shape -> output shape, with
  # ragged dimensions written in parentheses.
  @parameterized.parameters([
      # [d1, (d2)] -> [d1]
      {
          'fn': mo.reduce_mean,
          'elems': [[1, 2, 3], [4, 5], [6, 7]],
          'elems_dtype': dtypes.int32,
          'expected_output': [2, 4, 6],
          'result_dtype': dtypes.int32,
      },
      {
          'fn': string_ops.reduce_join,
          'elems': [['foo', 'bar', 'baz'], ['a'], ['b', 'c']],
          'expected_output': [b'foobarbaz', b'a', b'bc'],
          'elems_dtype': dtypes.string,
          'result_dtype': dtypes.string,
      },
      # [d1, (d2)] -> [d1, 2]
      {
          'fn': lambda row: array_ops.stack(
              [mo.reduce_mean(row), mo.reduce_sum(row)]),
          'elems': [[1, 2, 3], [4, 5], [6, 7]],
          'expected_output': [[2, 6], [4.5, 9], [6.5, 13]],
          'elems_dtype': dtypes.float32,
          'result_dtype': dtypes.float32,
          'expected_ragged_rank': 0,
      },
      # [d1, (d2)] -> [d1, (d2)]
      {
          'fn': lambda row: row + np.int64(1),
          'elems': [[1, 2, 3], [4, 5], [6, 7]],
          'expected_output': [[2, 3, 4], [5, 6], [7, 8]],
          'elems_dtype': dtypes.int64,
          'result_dtype': ragged_tensor.RaggedTensorType(
              dtype=dtypes.int64, ragged_rank=1),
      },
      # [d1, (d2), d3] -> [d1, (d2), d3]
      {
          'fn': lambda row: row + np.int64(1),
          'elems': [[[1, 2], [3, 4]], [], [[5, 6], [7, 8], [9, 0]]],
          'elems_ragged_rank': 1,
          'expected_ragged_rank': 1,
          'result_dtype': ragged_tensor.RaggedTensorType(
              dtype=dtypes.int64, ragged_rank=1),
          'expected_output': [[[2, 3], [4, 5]], [],
                              [[6, 7], [8, 9], [10, 1]]],
      },
      # [d1, (d2)] -> [d1, (d2), (d3)]
      {
          'fn': lambda row: ragged_tensor.RaggedTensor.from_row_starts(
              row, [0]),
          'elems': [[1, 2, 3], [4, 5], [6, 7]],
          'expected_output': [[[1, 2, 3]], [[4, 5]], [[6, 7]]],
          'result_dtype': ragged_tensor.RaggedTensorType(
              dtype=dtypes.int64, ragged_rank=2),
      },
      # [d1, (d2), (d3)] -> [d1, (d2), (d3)]
      {
          'fn': lambda row: ragged_functional_ops.map_flat_values(
              mo.add, row, 1),
          'elems': [[[1, 2, 3]], [[4, 5], [6, 7]]],
          'expected_output': [[[2, 3, 4]], [[5, 6], [7, 8]]],
          'result_dtype': ragged_tensor.RaggedTensorType(
              dtype=dtypes.int64, ragged_rank=2),
      },
      # [d1, (d2), (d3)] -> [d1, (d2)]
      {
          'fn': lambda row: ragged_math_ops.reduce_sum(row, axis=1),
          'elems': [[[1, 2, 3]], [[4, 5], [6, 7]]],
          'expected_output': [[6], [9, 13]],
          'result_dtype': ragged_tensor.RaggedTensorType(
              dtype=dtypes.int64, ragged_rank=1),
      },
      # [d1, (d2), (d3)] -> [d1, (d3)]
      {
          'fn': lambda row: ragged_math_ops.reduce_sum(row, axis=0),
          'elems': [[[1, 2, 3]], [[4, 5], [6, 7]]],
          'expected_output': [[1, 2, 3], [10, 12]],
          'result_dtype': ragged_tensor.RaggedTensorType(
              dtype=dtypes.int64, ragged_rank=1),
      },
      # [d1, (d2), (d3)] -> [d1]
      {
          'fn': ragged_math_ops.reduce_sum,
          'elems': [[[1, 2, 3]], [[4, 5], [6, 7]]],
          'expected_output': [6, 22],
          'result_dtype': dtypes.int64,
      },
      # [d1] -> [d1, (d2)]
      {
          'fn': mo.range,
          'elems': [4, 0, 2],
          'expected_output': [[0, 1, 2, 3], [], [0, 1]],
          'result_dtype': ragged_tensor.RaggedTensorType(
              dtype=dtypes.int64, ragged_rank=1),
      },
      # [d1] -> [d1, (d2), (d3)]
      {
          'fn': lambda limit: ragged_math_ops.range(mo.range(limit)),
          'elems': [5, 0, 3],
          'expected_output': [[[], [0], [0, 1], [0, 1, 2], [0, 1, 2, 3]], [],
                              [[], [0], [0, 1]]],
          'result_dtype': ragged_tensor.RaggedTensorType(
              dtype=dtypes.int64, ragged_rank=2),
      },
      # [d1, (d2), (d3), (d4a), (d5)] -> [d1, (d2), (d3), (d4b), (d5)]
      {
          'fn': lambda row: row + np.int64(1),
          'elems': [[[[[1, 2, 3]], [[4], [5]]]], [[[[6, 7]]], [[[8], []]]]],
          'expected_output': [[[[[2, 3, 4]], [[5], [6]]]],
                              [[[[7, 8]]], [[[9], []]]]],
          'result_dtype': ragged_tensor.RaggedTensorType(
              dtype=dtypes.int64, ragged_rank=4),
      },
  ])
  def testRaggedMap(
      self,
      fn,
      elems,
      expected_output,
      expected_ragged_rank=None,
      result_ragged_rank=None,
      elems_ragged_rank=None,
      elems_dtype=dtypes.int64,
      result_dtype=None,
      infer_shape=True,
  ):
    # Build the ragged input, map `fn` over it, and compare against a ragged
    # constant built from the expected nested list.
    ragged_elems = ragged_factory_ops.constant(
        elems, elems_dtype, elems_ragged_rank)
    mapped = ragged_map_ops.map_fn(
        fn=fn, elems=ragged_elems, dtype=result_dtype, infer_shape=infer_shape)

    expected = ragged_factory_ops.constant(
        expected_output, ragged_rank=expected_ragged_rank)
    self.assertAllEqual(expected, mapped)

  # Test mapping over a dict of RTs with a function that returns one scalar
  # per element.
  def testRaggedMapOnStructure(self):
    batman = ragged_factory_ops.constant([[1, 2, 3], [4], [5, 6, 7]])
    # robin == [[10, 20, 30], [40], [50, 60, 70]]
    robin = ragged_functional_ops.map_flat_values(mo.multiply, batman, 10)

    def _sum_both(row):
      batman_total = mo.reduce_sum(row['batman'])
      robin_total = mo.reduce_sum(row['robin'])
      return batman_total + robin_total

    totals = ragged_map_ops.map_fn(
        fn=_sum_both,
        elems={'batman': batman, 'robin': robin},
        dtype=dtypes.int32,
    )

    self.assertAllEqual(totals, [66, 44, 198])

  # Test mapping over a dict of RTs can produce a dict of RTs.
  def testRaggedMapOnStructure_RaggedOutputs(self):
    batman = ragged_factory_ops.constant([[1, 2, 3], [4], [5, 6, 7]])
    # robin == [[10, 20, 30], [40], [50, 60, 70]]
    robin = ragged_functional_ops.map_flat_values(mo.multiply, batman, 10)

    def _add_one(row):
      return {
          'batman': row['batman'] + 1,
          'robin': row['robin'] + 1,
      }

    ragged_int32 = ragged_tensor.RaggedTensorType(
        dtype=dtypes.int32, ragged_rank=1)
    incremented = ragged_map_ops.map_fn(
        fn=_add_one,
        elems={'batman': batman, 'robin': robin},
        infer_shape=False,
        dtype={'batman': ragged_int32, 'robin': ragged_int32},
    )

    self.assertAllEqual(incremented['batman'], [[2, 3, 4], [5], [6, 7, 8]])
    self.assertAllEqual(incremented['robin'],
                        [[11, 21, 31], [41], [51, 61, 71]])

  # Test mapping over a (dense, ragged) tuple: each value of a ragged row is
  # paired with the index of that row.
  def testZip(self):
    rows = ragged_factory_ops.constant(
        [[10, 20], [30, 40], [50, 60], [70], [80, 90, 100]], dtypes.int64)
    row_ids = array_ops.expand_dims(
        mo.range(rows.nrows(out_type=dtypes.int64)), axis=1)

    def _pair_with_row_id(pair):
      row_id, row = pair
      tiled_id = array_ops.tile(row_id, array_ops.shape(row))
      return array_ops.stack([tiled_id, row], axis=1)

    zipped = ragged_map_ops.map_fn(
        _pair_with_row_id,
        (row_ids, rows),
        dtype=ragged_tensor.RaggedTensorType(
            dtype=dtypes.int64, ragged_rank=1),
        infer_shape=False)

    self.assertAllEqual(
        zipped,
        [[[0, 10], [0, 20]],
         [[1, 30], [1, 40]],
         [[2, 50], [2, 60]],
         [[3, 70]],
         [[4, 80], [4, 90], [4, 100]]])

  # Test gathering, per row, the tokens selected by that row's indices.
  def testBatchGather(self):
    tokens = ragged_factory_ops.constant([['hello', '.', 'there'],
                                          ['merhaba'],
                                          ['bonjour', '.', 'ca va', '?']])
    indices = ragged_factory_ops.constant([[0, 2], [0], [0, 2]])

    def _gather_row(pair):
      row_tokens, row_indices = pair
      return array_ops.gather(row_tokens, row_indices)

    gathered = ragged_map_ops.map_fn(
        _gather_row,
        (tokens, indices),
        dtype=ragged_tensor.RaggedTensorType(
            dtype=dtypes.string, ragged_rank=1),
        infer_shape=False)

    self.assertAllEqual(
        gathered, [[b'hello', b'there'], [b'merhaba'], [b'bonjour', b'ca va']])

  # Test that a declared ragged_rank which `fn` does not produce is rejected.
  def testMismatchRaggedRank(self):
    nested = ragged_factory_ops.constant([[[1, 2, 3]], [[4, 5], [6, 7]]])
    sum_over_rows = lambda row: ragged_math_ops.reduce_sum(row, axis=0)
    expected_error = r'(?s)Expected `fn` to return.*But it returned.*'

    with self.assertRaisesRegex(ValueError, expected_error):
      declared_type = ragged_tensor.RaggedTensorType(
          dtype=dtypes.int64, ragged_rank=23)
      ragged_map_ops.map_fn(sum_over_rows, nested, dtype=declared_type)

  # Same check as above, for a `fn` that returns a RaggedTensor built with
  # from_row_starts.
  def testMismatchRaggedRank2(self):
    rows = ragged_factory_ops.constant([[1, 2, 3], [4, 5], [6, 7]])
    wrap_row = lambda row: ragged_tensor.RaggedTensor.from_row_starts(row, [0])
    expected_error = r'(?s)Expected `fn` to return.*But it returned.*'

    with self.assertRaisesRegex(ValueError, expected_error):
      declared_type = ragged_tensor.RaggedTensorType(
          dtype=dtypes.int64, ragged_rank=10)
      ragged_map_ops.map_fn(wrap_row, rows, dtype=declared_type)

  # Test that an identity map_fn round-trips a RaggedTensor that was built
  # from a SparseTensor.
  def testMapOnSparseTensor(self):
    sparse = sparse_tensor.SparseTensor(
        indices=[[0, 0], [0, 1], [1, 0], [1, 1]],
        values=[0, 5, 0, 4],
        dense_shape=[2, 2],
    )
    ragged_from_sparse = ragged_tensor.RaggedTensor.from_sparse(sparse)
    identity_mapped = ragged_map_ops.map_fn(
        lambda row: row,
        ragged_from_sparse,
    )
    self.assertAllEqual(identity_mapped, [[0, 5], [0, 4]])
def boolean_mask(data, mask, keepdims=False, name=None):
  """Filters `data` by a boolean `mask`.

  The result is a potentially ragged tensor holding exactly those elements of
  `data` whose corresponding entry in `mask` is `True`.

  With `keepdims` true, the outer dimensions (the ones that correspond to the
  `mask` dimensions) are kept, and:

  * `output[a1...aA, i, b1...bB] = data[a1...aA, j, b1...bB]`

  Where `j` is the `i`th `True` entry of `mask[a1...aA]`.

  With `keepdims` false, the outer dimensions are collapsed (similar to the
  behavior of `tf.boolean_mask`), and:

  * `output[i, b1...bB] = data[a1...aA, b1...bB]`

  Where `(a1...aA)` is the `i`th `True` entry of `mask`
  (in row-major order).

  Args:
    data: A potentially ragged tensor.
    mask: A potentially ragged boolean tensor.  `mask`'s shape must be a prefix
      of `data`'s shape.  `rank(mask)` must be known statically.
    keepdims: Whether to preserve the outer dimensions (`keepdims=True`) or
      flatten them (`keepdims=False`).
    name: A name prefix for the returned tensor (optional).

  Returns:
    A potentially ragged tensor that is formed by retaining the elements in
    `data` where the corresponding value in `mask` is `True`.

    If `keepdims` is false:

    * `rank(output) = rank(data) - rank(mask) + 1`.
    * `output.ragged_rank = max(data.ragged_rank - rank(mask) + 1, 0)`.

    If `keepdims` is true:

    * `rank(output) = rank(data)`.
    * `output.ragged_rank = max(data.ragged_rank, rank(mask) - 1)`.

  Raises:
    ValueError: if `rank(mask)` is not known statically; or if `mask.shape` is
      not a prefix of `data.shape`.

  #### Examples:
    ```python
    >>> # Aliases for True & False so data and mask line up.
    >>> T, F = (True, False)

    >>> tf.ragged.boolean_mask(  # Mask a 2D Tensor.  Flatten outer dims.
    ...     data=[[1, 2, 3], [4, 5, 6], [7, 8, 9]],
    ...     mask=[[T, F, T], [F, F, F], [T, F, F]],
    ...     keepdims=False).tolist()
    [1, 3, 7]

    >>> tf.ragged.boolean_mask(  # Mask a 2D Tensor.  Preserve outer dims.
    ...     data=[[1, 2, 3], [4, 5, 6], [7, 8, 9]],
    ...     mask=[[T, F, T], [F, F, F], [T, F, F]],
    ...     keepdims=True).tolist()
    [[1, 3], [], [7]]

    >>> tf.ragged.boolean_mask(  # Mask a 2D RaggedTensor.  Flatten outer dims.
    ...     tf.ragged.constant([[1, 2, 3], [4], [5, 6]]),
    ...     tf.ragged.constant([[F, F, T], [F], [T, T]]),
    ...     keepdims=False).tolist()
    [3, 5, 6]

    >>> tf.ragged.boolean_mask(  # Mask a 2D RaggedTensor.  Preserve outer dims.
    ...     tf.ragged.constant([[1, 2, 3], [4], [5, 6]]),
    ...     tf.ragged.constant([[F, F, T], [F], [T, T]]),
    ...     keepdims=True).tolist()
    [[3], [], [5, 6]]

    >>> tf.ragged.boolean_mask(  # Mask rows of a 2D RaggedTensor.
    ...     tf.ragged.constant([[1, 2, 3], [4], [5, 6]]),
    ...     tf.ragged.constant([True, False, True]),
    ...     keepdims=True).tolist()
    [[1, 2, 3], [5, 6]]
    ```
  """
  with ops.name_scope(name, 'RaggedMask', [data, mask]):
    # Normalize both arguments to (possibly ragged) tensors that agree on a
    # single row_splits dtype.
    data = ragged_tensor.convert_to_tensor_or_ragged_tensor(data, name='data')
    mask = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        mask, dtypes.bool, name='mask')
    row_splits_dtype, (data, mask) = ragged_tensor.match_row_splits_dtypes(
        data, mask, return_dtype=True)

    # The mask rank must be statically known and at least 1.
    mask_rank = mask.shape.ndims
    if mask_rank is None:
      raise ValueError('mask.shape.ndims must be known statically.')
    if mask_rank == 0:
      raise ValueError('mask cannot be scalar.')

    # Case 1: ragged mask.  Peel off its ragged dimensions and recurse on the
    # non-ragged remainder.
    if ragged_tensor.is_ragged(mask):
      if not ragged_tensor.is_ragged(data):
        data = ragged_tensor.RaggedTensor.from_tensor(
            data,
            ragged_rank=mask.ragged_rank,
            row_splits_dtype=mask.row_splits.dtype)

      # mask.nested_row_splits has to be a prefix of data.nested_row_splits.
      splits_to_compare = [
          mask.nested_row_splits,
          data.nested_row_splits[:mask.ragged_rank],
      ]
      with ops.control_dependencies(
          ragged_util.assert_splits_match(splits_to_compare)):
        # Collect one splits tensor per stripped ragged dimension so the
        # dimensions can be re-attached to the recursive result.
        outer_splits = []
        while ragged_tensor.is_ragged(mask):
          if mask.shape.ndims > 2:
            outer_splits.append(mask.row_splits)
          else:
            # Innermost mask dimension: the filtered row lengths are the
            # per-row counts of True values.
            mask_as_int = ragged_functional_ops.map_flat_values(
                math_ops.cast, mask, dtype=row_splits_dtype)
            true_counts = ragged_math_ops.reduce_sum(mask_as_int, axis=1)
            outer_splits.append(ragged_util.lengths_to_splits(true_counts))
          mask = mask.values
          data = data.values

        # Mask the stripped data with the stripped (non-ragged) mask.
        flat_result = boolean_mask(data, mask, keepdims)

        # The collected splits are only re-attached when the outer
        # dimensions are being kept.
        if not keepdims:
          return flat_result
        return ragged_tensor.RaggedTensor.from_nested_row_splits(
            flat_result, outer_splits, validate=False)

    if ragged_tensor.is_ragged(data):
      # Case 2: ragged data, rank-1 dense mask.  The mask selects whole rows.
      if mask_rank == 1:
        # Row splits of the result: lengths of the surviving rows only.
        kept_lengths = array_ops.boolean_mask(data.row_lengths(), mask)
        kept_splits = ragged_util.lengths_to_splits(kept_lengths)

        # Values of the result: broadcast the row mask onto every value via
        # its row id, then mask the values with it.  The per-value mask has
        # rank 1, so this recursive call always uses keepdims=False.
        row_ids = segment_id_ops.row_splits_to_segment_ids(data.row_splits)
        value_mask = array_ops.gather(mask, row_ids)
        kept_values = boolean_mask(data.values, value_mask, keepdims=False)

        return ragged_tensor.RaggedTensor.from_row_splits(
            kept_values, kept_splits, validate=False)

      # Case 3: ragged data, dense mask of rank > 1.  Make the mask ragged
      # (matching data's ragged rank where possible) and recurse into case 1.
      mask = ragged_tensor.RaggedTensor.from_tensor(
          mask,
          ragged_rank=min(data.ragged_rank, mask_rank - 1),
          row_splits_dtype=data.row_splits.dtype)
      return boolean_mask(data, mask, keepdims)

    # Case 4: both `data` and `mask` are plain `Tensor`s.
    result = array_ops.boolean_mask(data, mask)

    # A rank-1 mask, or keepdims=False, needs no ragged wrapping at all.
    if mask_rank < 2 or not keepdims:
      return result

    # Innermost ragged dimension: count the True values in each innermost
    # mask cell, flatten the counts, and use them as row lengths.
    cell_counts = math_ops.count_nonzero(
        mask, axis=-1, dtype=row_splits_dtype)
    flat_cell_counts = array_ops.reshape(cell_counts, [-1])
    result = ragged_tensor.RaggedTensor.from_row_lengths(
        result, flat_cell_counts, validate=False)

    # Remaining (uniform) mask dimensions, wrapped from the inside out.
    if mask_rank > 2:
      mask_shape = array_ops.shape(mask, out_type=row_splits_dtype)
      split_size = math_ops.cumprod(mask_shape) + 1
      for dim in reversed(range(mask_rank - 2)):
        elt_size = mask_shape[dim + 1]
        uniform_splits = math_ops.range(split_size[dim]) * elt_size
        result = ragged_tensor.RaggedTensor.from_row_splits(
            result, uniform_splits, validate=False)

    return result