Example #1
def testMismatchRaggedRank(self):
    elems = ragged_factory_ops.constant([[[1, 2, 3]], [[4, 5], [6, 7]]])
    fn = lambda x: ragged_math_ops.reduce_sum(x, axis=0)
    with self.assertRaisesRegex(
            ValueError, r'(?s)Expected `fn` to return.*But it returned.*'):
        _ = ragged_map_ops.map_fn(fn,
                                  elems,
                                  dtype=ragged_tensor.RaggedTensorType(
                                      dtype=dtypes.int64, ragged_rank=23))
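
For comparison, a minimal sketch of the same mismatch through the public API, assuming a TF 2.x build where `tf.map_fn` takes `fn_output_signature` (the successor of the `dtype`/`RaggedTensorType` argument used above); the exact error text varies by version.

import tensorflow as tf

elems = tf.ragged.constant([[[1, 2, 3]], [[4, 5], [6, 7]]], dtype=tf.int64)
fn = lambda x: tf.reduce_sum(x, axis=0)  # each [(d2), (d3)] slice -> rank-1 result

try:
  tf.map_fn(
      fn, elems,
      # Declaring ragged_rank=23 cannot match fn's rank-1 output.
      fn_output_signature=tf.RaggedTensorSpec(dtype=tf.int64, ragged_rank=23))
except ValueError as e:
  print(e)  # e.g. "Expected `fn` to return ... But it returned ..."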
Example #2
def testMismatchRaggedRank(self):
    elems = ragged_factory_ops.constant([[[1, 2, 3]], [[4, 5], [6, 7]]])
    fn = lambda x: ragged_math_ops.reduce_sum(x, axis=0)
    with self.assertRaisesWithLiteralMatch(
            ValueError,
            r'The declared ragged rank (23) mismatches the result (1)'):
        _ = ragged_map_ops.map_fn(fn,
                                  elems,
                                  dtype=ragged_tensor.RaggedTensorType(
                                      dtype=dtypes.int64, ragged_rank=23))
Example #3
def testMismatchRaggedRank(self):
  elems = ragged_factory_ops.constant([[[1, 2, 3]], [[4, 5], [6, 7]]])
  fn = lambda x: ragged_math_ops.reduce_sum(x, axis=0)
  with self.assertRaisesWithLiteralMatch(
      ValueError, r'The declared ragged rank (23) mismatches the result (1)'):
    _ = ragged_map_ops.map_fn(
        fn,
        elems,
        dtype=ragged_tensor.RaggedTensorType(
            dtype=dtypes.int64, ragged_rank=23))
Example #4
def boolean_mask(data, mask, name=None):
  """Applies a boolean mask to `data` without flattening the mask dimensions.

  Returns a potentially ragged tensor that is formed by retaining the elements
  in `data` where the corresponding value in `mask` is `True`.

  * `output[a1...aA, i, b1...bB] = data[a1...aA, j, b1...bB]`

     Where `j` is the `i`th `True` entry of `mask[a1...aA]`.

  Note that `output` preserves the mask dimensions `a1...aA`; this differs
  from `tf.boolean_mask`, which flattens those dimensions.

  Args:
    data: A potentially ragged tensor.
    mask: A potentially ragged boolean tensor.  `mask`'s shape must be a prefix
      of `data`'s shape.  `rank(mask)` must be known statically.
    name: A name prefix for the returned tensor (optional).

  Returns:
    A potentially ragged tensor that is formed by retaining the elements in
    `data` where the corresponding value in `mask` is `True`.

    * `rank(output) = rank(data)`.
    * `output.ragged_rank = max(data.ragged_rank, rank(mask) - 1)`.

  Raises:
    ValueError: if `rank(mask)` is not known statically; or if `mask.shape` is
      not a prefix of `data.shape`.

  #### Examples:

  >>> # Aliases for True & False so data and mask line up.
  >>> T, F = (True, False)

  >>> tf.ragged.boolean_mask(  # Mask a 2D Tensor.
  ...     data=[[1, 2, 3], [4, 5, 6], [7, 8, 9]],
  ...     mask=[[T, F, T], [F, F, F], [T, F, F]]).to_list()
  [[1, 3], [], [7]]

  >>> tf.ragged.boolean_mask(  # Mask a 2D RaggedTensor.
  ...     tf.ragged.constant([[1, 2, 3], [4], [5, 6]]),
  ...     tf.ragged.constant([[F, F, T], [F], [T, T]])).to_list()
  [[3], [], [5, 6]]

  >>> tf.ragged.boolean_mask(  # Mask rows of a 2D RaggedTensor.
  ...     tf.ragged.constant([[1, 2, 3], [4], [5, 6]]),
  ...     tf.ragged.constant([True, False, True])).to_list()
  [[1, 2, 3], [5, 6]]
  """
  with ops.name_scope(name, 'RaggedMask', [data, mask]):
    # Convert inputs to tensors.
    data = ragged_tensor.convert_to_tensor_or_ragged_tensor(data, name='data')
    mask = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        mask, dtypes.bool, name='mask')
    row_splits_dtype, (data, mask) = ragged_tensor.match_row_splits_dtypes(
        data, mask, return_dtype=True)

    # Get static rank of mask.
    if mask.shape.ndims is None:
      raise ValueError('mask.shape.ndims must be known statically.')
    elif mask.shape.ndims == 0:
      raise ValueError('mask cannot be scalar.')

    # If mask is ragged, then recurse with a non-ragged mask.
    if ragged_tensor.is_ragged(mask):
      if not ragged_tensor.is_ragged(data):
        data = ragged_tensor.RaggedTensor.from_tensor(
            data,
            ragged_rank=mask.ragged_rank,
            row_splits_dtype=mask.row_splits.dtype)
      # Check that mask.nested_row_splits is a prefix of
      # data.nested_row_splits.
      splits_list = [
          mask.nested_row_splits, data.nested_row_splits[:mask.ragged_rank]
      ]
      with ops.control_dependencies(
          ragged_util.assert_splits_match(splits_list)):
        # Strip off ragged `splits` until `mask` is non-ragged.  Keep the splits
        # that we strip off in `splits`, so we can add them back on after
        # we recursively mask the non-ragged data.
        splits = []
        while ragged_tensor.is_ragged(mask):
          if mask.shape.ndims > 2:
            splits.append(mask.row_splits)
          else:
            # Count the number of True mask values in each row to find the
            # lengths of the filtered rows; then convert to splits.
            int_mask = ragged_functional_ops.map_flat_values(
                math_ops.cast, mask, dtype=row_splits_dtype)
            masked_row_lengths = ragged_math_ops.reduce_sum(int_mask, axis=1)
            splits.append(ragged_util.lengths_to_splits(masked_row_lengths))
          mask = mask.values
          data = data.values

        # Recursively apply the nested non-ragged mask to the nested data.
        masked_values = boolean_mask(data, mask)

        # Add the ragged `splits` back to the result.
        masked_values = ragged_tensor.RaggedTensor.from_nested_row_splits(
            masked_values, splits, validate=False)

        return masked_values

    # If mask is non-ragged and has rank 1, and data is ragged, then build a
    # ragged tensor with the indicated rows.
    elif ragged_tensor.is_ragged(data) and mask.shape.ndims == 1:
      # Get the masked splits: first get the length of each row, then filter
      # out the rows that we are deleting, and convert that filtered set of
      # masks back to a splits tensor.
      lengths = data.row_lengths()
      masked_lengths = array_ops.boolean_mask(lengths, mask)
      masked_splits = ragged_util.lengths_to_splits(masked_lengths)

      # Get the masked values: first get row ids corresponding to each
      # value, then use tf.gather to build a boolean mask that's false for
      # values that come from rows that we are deleting, and use that mask to
      # construct the masked values tensor.
      segment_ids = segment_id_ops.row_splits_to_segment_ids(data.row_splits)
      segment_mask = array_ops.gather(mask, segment_ids)
      masked_values = boolean_mask(data.values, segment_mask)

      return ragged_tensor.RaggedTensor.from_row_splits(
          masked_values, masked_splits, validate=False)

    # If mask is non-ragged and has rank>1, then convert it to be ragged,
    # with a ragged rank matching data.
    if ragged_tensor.is_ragged(data):
      mask = ragged_tensor.RaggedTensor.from_tensor(
          mask,
          ragged_rank=min(data.ragged_rank, mask.shape.ndims - 1),
          row_splits_dtype=data.row_splits.dtype)
      return boolean_mask(data, mask)

    # Otherwise, data and mask are both `Tensor`s.
    else:
      # Apply `boolean_mask` to get the masked values.
      masked_values = array_ops.boolean_mask(data, mask)

      if mask.shape.ndims >= 2:
        # Add the innermost ragged dimension.  For each innermost cell, get the
        # number of values it contains.  Then flatten that to get a list of
        # cell lengths, and convert it to splits.  Finally, combine the splits
        # and values to get the innermost ragged tensor.
        masked_lengths = math_ops.count_nonzero(
            mask, axis=-1, dtype=row_splits_dtype)
        flattened_masked_lengths = array_ops.reshape(masked_lengths, [-1])
        masked_values = ragged_tensor.RaggedTensor.from_row_lengths(
            masked_values, flattened_masked_lengths, validate=False)

        # Wrap remaining ragged dimensions.
        if mask.shape.ndims > 2:
          mask_shape = array_ops.shape(mask, out_type=row_splits_dtype)
          split_size = math_ops.cumprod(mask_shape) + 1
          for dim in range(mask.shape.ndims - 3, -1, -1):
            elt_size = mask_shape[dim + 1]
            masked_splits = math_ops.range(split_size[dim]) * elt_size
            masked_values = ragged_tensor.RaggedTensor.from_row_splits(
                masked_values, masked_splits, validate=False)

      return masked_values
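
To make the splits bookkeeping in the rank-1-mask branch concrete, here is a sketch rebuilt from public ops (variable names are illustrative; assumes TF 2.x eager mode):

import tensorflow as tf

rt = tf.ragged.constant([[1, 2, 3], [4], [5, 6]])
mask = tf.constant([True, False, True])

# Keep the lengths of surviving rows, then keep only values whose row survives.
masked_lengths = tf.boolean_mask(rt.row_lengths(), mask)          # [3, 2]
segment_ids = tf.ragged.row_splits_to_segment_ids(rt.row_splits)  # [0, 0, 0, 1, 2, 2]
segment_mask = tf.gather(mask, segment_ids)                       # per-value mask
masked_values = tf.boolean_mask(rt.values, segment_mask)          # [1, 2, 3, 5, 6]

print(tf.RaggedTensor.from_row_lengths(masked_values, masked_lengths).to_list())
# [[1, 2, 3], [5, 6]]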
Example #5
def testReduceKeepsInnerDimensionShape(self):
  # Test for bug [b/139823356].
  rt = ragged_factory_ops.constant([[[[1, 1]]]], ragged_rank=2)
  self.assertEqual(rt.shape.as_list(), [1, None, None, 2])
  reduced = ragged_math_ops.reduce_sum(rt, axis=2)
  self.assertEqual(reduced.shape.as_list(), [1, None, 2])
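
The same shape behavior, sketched against the public API (assuming a TF 2.x release that includes the fix for b/139823356):

import tensorflow as tf

rt = tf.ragged.constant([[[[1, 1]]]], ragged_rank=2)  # shape [1, None, None, 2]
reduced = tf.reduce_sum(rt, axis=2)                   # reduce a ragged axis
print(reduced.shape.as_list())  # [1, None, 2] -- the inner dense dim survives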
Example #6
class RaggedMapOpTest(test_util.TensorFlowTestCase, parameterized.TestCase):
    @parameterized.parameters([
        # Each of the following parameter sets maps over a RaggedTensor and
        # applies a transformation whose result has the indicated shape:
        # [d1, (d2)] -> [d1]
        dict(
            fn=mo.reduce_mean,
            elems=[[1, 2, 3], [4, 5], [6, 7]],
            elems_dtype=dtypes.int32,
            expected_output=[2, 4, 6],
            result_dtype=dtypes.int32,
        ),
        dict(
            fn=string_ops.reduce_join,
            elems=[['foo', 'bar', 'baz'], ['a'], ['b', 'c']],
            expected_output=[b'foobarbaz', b'a', b'bc'],
            elems_dtype=dtypes.string,
            result_dtype=dtypes.string,
        ),
        # [d1, (d2)] -> [d1, 2]
        dict(
            fn=lambda x: array_ops.stack([mo.reduce_mean(x),
                                          mo.reduce_sum(x)]),
            # fn=self.stack_mean_and_sum,
            elems=[[1, 2, 3], [4, 5], [6, 7]],
            expected_output=[[2, 6], [4.5, 9], [6.5, 13]],
            elems_dtype=dtypes.float32,
            result_dtype=dtypes.float32,
            expected_ragged_rank=0,
        ),
        # [d1, (d2)] -> [d1, (d2)]
        dict(
            fn=lambda x: x + np.int64(1),
            elems=[[1, 2, 3], [4, 5], [6, 7]],
            expected_output=[[2, 3, 4], [5, 6], [7, 8]],
            elems_dtype=dtypes.int64,
            result_dtype=ragged_tensor.RaggedTensorType(dtype=dtypes.int64,
                                                        ragged_rank=1),
        ),
        # [d1, (d2), d3] -> [d1, (d2), d3]
        dict(
            fn=lambda x: x + np.int64(1),
            elems=[[[1, 2], [3, 4]], [], [[5, 6], [7, 8], [9, 0]]],
            elems_ragged_rank=1,
            expected_ragged_rank=1,
            result_dtype=ragged_tensor.RaggedTensorType(dtype=dtypes.int64,
                                                        ragged_rank=1),
            expected_output=[[[2, 3], [4, 5]], [], [[6, 7], [8, 9], [10, 1]]],
        ),
        # [d1, (d2)] -> [d1, (d2), (d3)]
        dict(
            fn=lambda x: ragged_tensor.RaggedTensor.from_row_starts(x, [0]),
            elems=[[1, 2, 3], [4, 5], [6, 7]],
            expected_output=[[[1, 2, 3]], [[4, 5]], [[6, 7]]],
            result_dtype=ragged_tensor.RaggedTensorType(dtype=dtypes.int64,
                                                        ragged_rank=2),
        ),
        # [d1, (d2), (d3)] -> [d1, (d2), (d3)]
        dict(
            fn=lambda x: ragged_functional_ops.map_flat_values(mo.add, x, 1),
            elems=[[[1, 2, 3]], [[4, 5], [6, 7]]],
            expected_output=[[[2, 3, 4]], [[5, 6], [7, 8]]],
            result_dtype=ragged_tensor.RaggedTensorType(dtype=dtypes.int64,
                                                        ragged_rank=2),
        ),
        # [d1, (d2), (d3)] -> [d1, (d2)]
        dict(
            fn=lambda x: ragged_math_ops.reduce_sum(x, axis=1),
            elems=[[[1, 2, 3]], [[4, 5], [6, 7]]],
            expected_output=[[6], [9, 13]],
            result_dtype=ragged_tensor.RaggedTensorType(dtype=dtypes.int64,
                                                        ragged_rank=1),
        ),
        # [d1, (d2), (d3)] -> [d1, (d3)]
        dict(
            fn=lambda x: ragged_math_ops.reduce_sum(x, axis=0),
            elems=[[[1, 2, 3]], [[4, 5], [6, 7]]],
            expected_output=[[1, 2, 3], [10, 12]],
            result_dtype=ragged_tensor.RaggedTensorType(dtype=dtypes.int64,
                                                        ragged_rank=1),
        ),
        # [d1, (d2), (d3)] -> [d1]
        dict(
            fn=ragged_math_ops.reduce_sum,
            elems=[[[1, 2, 3]], [[4, 5], [6, 7]]],
            expected_output=[6, 22],
            result_dtype=dtypes.int64,
        ),
        # [d1] -> [d1, (d2)]
        dict(
            fn=mo.range,
            elems=[4, 0, 2],
            expected_output=[[0, 1, 2, 3], [], [0, 1]],
            result_dtype=ragged_tensor.RaggedTensorType(dtype=dtypes.int64,
                                                        ragged_rank=1),
        ),
        # [d1] -> [d1, (d2), (d3)]
        dict(
            fn=lambda x: ragged_math_ops.range(mo.range(x)),
            elems=[5, 0, 3],
            expected_output=[[[], [0], [0, 1], [0, 1, 2], [0, 1, 2, 3]], [],
                             [[], [0], [0, 1]]],
            result_dtype=ragged_tensor.RaggedTensorType(dtype=dtypes.int64,
                                                        ragged_rank=2),
        ),
        # [d1, (d2), (d3), (d4a), (d5)] ->  [d1, (d2), (d3), (d4b), (d5)]
        dict(
            fn=lambda x: x + np.int64(1),
            elems=[[[[[1, 2, 3]], [[4], [5]]]], [[[[6, 7]]], [[[8], []]]]],
            expected_output=[[[[[2, 3, 4]], [[5], [6]]]],
                             [[[[7, 8]]], [[[9], []]]]],
            result_dtype=ragged_tensor.RaggedTensorType(dtype=dtypes.int64,
                                                        ragged_rank=4),
        ),
    ])
    def testRaggedMap(
        self,
        fn,
        elems,
        expected_output,
        expected_ragged_rank=None,
        result_ragged_rank=None,
        elems_ragged_rank=None,
        elems_dtype=dtypes.int64,
        result_dtype=None,
        infer_shape=True,
    ):
        elems = ragged_factory_ops.constant(elems, elems_dtype,
                                            elems_ragged_rank)
        output = ragged_map_ops.map_fn(fn=fn,
                                       elems=elems,
                                       dtype=result_dtype,
                                       infer_shape=infer_shape)

        expected_rt = ragged_factory_ops.constant(
            expected_output, ragged_rank=expected_ragged_rank)
        self.assertAllEqual(expected_rt, output)

    def testRaggedMapOnStructure(self):
        batman = ragged_factory_ops.constant([[1, 2, 3], [4], [5, 6, 7]])
        # [[10, 20, 30], [40], [50, 60, 70]]
        robin = ragged_functional_ops.map_flat_values(mo.multiply, batman, 10)

        features = {'batman': batman, 'robin': robin}

        def _reduce_sum_from_all(f):
            return mo.reduce_sum(f['batman']) + mo.reduce_sum(f['robin'])

        output = ragged_map_ops.map_fn(
            fn=_reduce_sum_from_all,
            elems=features,
            dtype=dtypes.int32,
        )

        self.assertAllEqual(output, [66, 44, 198])

    # Test mapping over a dict of RTs can produce a dict of RTs.
    def testRaggedMapOnStructure_RaggedOutputs(self):
        batman = ragged_factory_ops.constant([[1, 2, 3], [4], [5, 6, 7]])
        # [[10, 20, 30], [40], [50, 60, 70]]
        robin = ragged_functional_ops.map_flat_values(mo.multiply, batman, 10)

        features = {'batman': batman, 'robin': robin}

        def _increment(f):
            return {
                'batman': f['batman'] + 1,
                'robin': f['robin'] + 1,
            }

        output = ragged_map_ops.map_fn(
            fn=_increment,
            elems=features,
            infer_shape=False,
            dtype={
                'batman':
                ragged_tensor.RaggedTensorType(dtype=dtypes.int32,
                                               ragged_rank=1),
                'robin':
                ragged_tensor.RaggedTensorType(dtype=dtypes.int32,
                                               ragged_rank=1)
            },
        )

        self.assertAllEqual(output['batman'], [[2, 3, 4], [5], [6, 7, 8]])
        self.assertAllEqual(output['robin'],
                            [[11, 21, 31], [41], [51, 61, 71]])

    def testZip(self):
        x = ragged_factory_ops.constant(
            [[10, 20], [30, 40], [50, 60], [70], [80, 90, 100]], dtypes.int64)
        y = array_ops.expand_dims(mo.range(x.nrows(out_type=dtypes.int64)),
                                  axis=1)
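        # y is a column of row ids: [[0], [1], [2], [3], [4]].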

        def _zip(foo):
            y_val, x_val = foo
            bar = array_ops.tile(y_val, array_ops.shape(x_val))
            return array_ops.stack([bar, x_val], axis=1)

        output = ragged_map_ops.map_fn(_zip, (y, x),
                                       dtype=ragged_tensor.RaggedTensorType(
                                           dtype=dtypes.int64, ragged_rank=1),
                                       infer_shape=False)

        self.assertAllEqual(
            output,
            [[[0, 10], [0, 20]], [[1, 30], [1, 40]], [[2, 50], [2, 60]],
             [[3, 70]], [[4, 80], [4, 90], [4, 100]]])

    def testBatchGather(self):
        tokens = ragged_factory_ops.constant([['hello', '.', 'there'],
                                              ['merhaba'],
                                              ['bonjour', '.', 'ca va', '?']])
        indices = ragged_factory_ops.constant([[0, 2], [0], [0, 2]])

        def gather(x):
            tokens_val, indices_val = x
            return array_ops.gather(tokens_val, indices_val)

        data = tokens, indices
        out = ragged_map_ops.map_fn(gather,
                                    data,
                                    dtype=ragged_tensor.RaggedTensorType(
                                        dtype=dtypes.string, ragged_rank=1),
                                    infer_shape=False)

        self.assertAllEqual(
            out, [[b'hello', b'there'], [b'merhaba'], [b'bonjour', b'ca va']])

    def testMismatchRaggedRank(self):
        elems = ragged_factory_ops.constant([[[1, 2, 3]], [[4, 5], [6, 7]]])
        fn = lambda x: ragged_math_ops.reduce_sum(x, axis=0)
        with self.assertRaisesRegex(
                ValueError, r'(?s)Expected `fn` to return.*But it returned.*'):
            _ = ragged_map_ops.map_fn(fn,
                                      elems,
                                      dtype=ragged_tensor.RaggedTensorType(
                                          dtype=dtypes.int64, ragged_rank=23))

    def testMismatchRaggedRank2(self):
        elems = ragged_factory_ops.constant([[1, 2, 3], [4, 5], [6, 7]])
        fn = lambda x: ragged_tensor.RaggedTensor.from_row_starts(x, [0])
        with self.assertRaisesRegex(
                ValueError, r'(?s)Expected `fn` to return.*But it returned.*'):
            _ = ragged_map_ops.map_fn(fn,
                                      elems,
                                      dtype=ragged_tensor.RaggedTensorType(
                                          dtype=dtypes.int64, ragged_rank=10))

    def testMapOnSparseTensor(self):
        s = sparse_tensor.SparseTensor(
            indices=[[0, 0], [0, 1], [1, 0], [1, 1]],
            values=[0, 5, 0, 4],
            dense_shape=[2, 2],
        )
        t2 = ragged_tensor.RaggedTensor.from_sparse(s)
        id_t2 = ragged_map_ops.map_fn(
            lambda x: x,
            t2,
        )
        self.assertAllEqual(id_t2, [[0, 5], [0, 4]])
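
As a cross-check against the [d1] -> [d1, (d2)] row of the table above, a minimal public-API sketch (assuming TF 2.x, where `fn_output_signature` replaces the `dtype`/`RaggedTensorType` arguments used in this class):

import tensorflow as tf

out = tf.map_fn(
    tf.range, tf.constant([4, 0, 2], dtype=tf.int64),
    fn_output_signature=tf.RaggedTensorSpec(shape=[None], dtype=tf.int64))
print(out.to_list())  # [[0, 1, 2, 3], [], [0, 1]]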
Example #7
def boolean_mask(data, mask, keepdims=False, name=None):
  """Applies a boolean mask to `data`.

  Returns a potentially ragged tensor that is formed by retaining the elements
  in `data` where the corresponding value in `mask` is `True`.

  If `keepdims` is true then outer dimensions (corresponding to the `mask`
  dimensions) are preserved, and:

  * `output[a1...aA, i, b1...bB] = data[a1...aA, j, b1...bB]`

     Where `j` is the `i`th `True` entry of `mask[a1...aA]`.

  If `keepdims` is false, then the outer dimensions are collapsed (similar to
  the behavior of `tf.boolean_mask`), and:

  * `output[i, b1...bB] = data[a1...aA, b1...bB]`

     Where `(a1...aA)` is the `i`th `True` entry of `mask`
     (in row-major order).

  Args:
    data: A potentially ragged tensor.
    mask: A potentially ragged boolean tensor.  `mask`'s shape must be a prefix
      of `data`'s shape.  `rank(mask)` must be known statically.
    keepdims: Whether to preserve the outer dimensions (`keepdims=True`) or
      flatten them (`keepdims=False`).
    name: A name prefix for the returned tensor (optional).

  Returns:
    A potentially ragged tensor that is formed by retaining the elements in
    `data` where the corresponding value in `mask` is `True`.

    If `keepdims` is false:

    * `rank(output) = rank(data) - rank(mask) + 1`.
    * `output.ragged_rank = max(data.ragged_rank - rank(mask) + 1, 0)`.

    If `keepdims` is true:

    * `rank(output) = rank(data)`.
    * `output.ragged_rank = max(data.ragged_rank, rank(mask) - 1)`.

  Raises:
    ValueError: if `rank(mask)` is not known statically; or if `mask.shape` is
      not a prefix of `data.shape`.

  #### Examples:
    ```python
    >>> # Aliases for True & False so data and mask line up.
    >>> T, F = (True, False)

    >>> tf.ragged.boolean_mask(  # Mask a 2D Tensor.  Flatten outer dims.
    ...     data=[[1, 2, 3], [4, 5, 6], [7, 8, 9]],
    ...     mask=[[T, F, T], [F, F, F], [T, F, F]],
    ...     keepdims=False).to_list()
    [1, 3, 7]

    >>> tf.ragged.boolean_mask(  # Mask a 2D Tensor.  Preserve outer dims.
    ...     data=[[1, 2, 3], [4, 5, 6], [7, 8, 9]],
    ...     mask=[[T, F, T], [F, F, F], [T, F, F]],
    ...     keepdims=True).to_list()
    [[1, 3], [], [7]]

    >>> tf.ragged.boolean_mask(  # Mask a 2D RaggedTensor.  Flatten outer dims.
    ...     tf.ragged.constant([[1, 2, 3], [4], [5, 6]]),
    ...     tf.ragged.constant([[F, F, T], [F], [T, T]]),
    ...     keepdims=False).to_list()
    [3, 5, 6]

    >>> tf.ragged.boolean_mask(  # Mask a 2D RaggedTensor.  Preserve outer dims.
    ...     tf.ragged.constant([[1, 2, 3], [4], [5, 6]]),
    ...     tf.ragged.constant([[F, F, T], [F], [T, T]]),
    ...     keepdims=True).to_list()
    [[3], [], [5, 6]]

    >>> tf.ragged.boolean_mask(  # Mask rows of a 2D RaggedTensor.
    ...     tf.ragged.constant([[1, 2, 3], [4], [5, 6]]),
    ...     tf.ragged.constant([True, False, True]),
    ...     keepdims=True).to_list()
    [[1, 2, 3], [5, 6]]
    ```
  """
  with ops.name_scope(name, 'RaggedMask', [data, mask]):
    # Convert inputs to tensors.
    data = ragged_tensor.convert_to_tensor_or_ragged_tensor(data, name='data')
    mask = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        mask, dtypes.bool, name='mask')
    row_splits_dtype, (data, mask) = ragged_tensor.match_row_splits_dtypes(
        data, mask, return_dtype=True)

    # Get static rank of mask.
    if mask.shape.ndims is None:
      raise ValueError('mask.shape.ndims must be known statically.')
    elif mask.shape.ndims == 0:
      raise ValueError('mask cannot be scalar.')

    # If mask is ragged, then recurse with a non-ragged mask.
    if ragged_tensor.is_ragged(mask):
      if not ragged_tensor.is_ragged(data):
        data = ragged_tensor.RaggedTensor.from_tensor(
            data, ragged_rank=mask.ragged_rank,
            row_splits_dtype=mask.row_splits.dtype)
      # Check that mask.nested_row_splits is a prefix of
      # data.nested_row_splits.
      splits_list = [
          mask.nested_row_splits, data.nested_row_splits[:mask.ragged_rank]
      ]
      with ops.control_dependencies(
          ragged_util.assert_splits_match(splits_list)):
        # Strip off ragged `splits` until `mask` is non-ragged.  Keep the splits
        # that we strip off in `splits`, so we can add them back on after
        # we recursively mask the non-ragged data.
        splits = []
        while ragged_tensor.is_ragged(mask):
          if mask.shape.ndims > 2:
            splits.append(mask.row_splits)
          else:
            # Count the number of True mask values in each row to find the
            # lengths of the filtered rows; then convert to splits.
            int_mask = ragged_functional_ops.map_flat_values(
                math_ops.cast, mask, dtype=row_splits_dtype)
            masked_row_lengths = ragged_math_ops.reduce_sum(int_mask, axis=1)
            splits.append(ragged_util.lengths_to_splits(masked_row_lengths))
          mask = mask.values
          data = data.values

        # Recursively apply the nested non-ragged mask to the nested data.
        masked_values = boolean_mask(data, mask, keepdims)

        # Add the ragged `splits` back to the result.
        if keepdims:
          masked_values = ragged_tensor.RaggedTensor.from_nested_row_splits(
              masked_values, splits, validate=False)

        return masked_values

    # If mask is non-ragged and has rank 1, and data is ragged, then build a
    # ragged tensor with the indicated rows.
    elif ragged_tensor.is_ragged(data) and mask.shape.ndims == 1:
      # Get the masked splits: first get the length of each row, then filter
      # out the rows that we are deleting, and convert that filtered set of
      # masks back to a splits tensor.
      lengths = data.row_lengths()
      masked_lengths = array_ops.boolean_mask(lengths, mask)
      masked_splits = ragged_util.lengths_to_splits(masked_lengths)

      # Get the masked values: first get row ids corresponding to each
      # value, then use tf.gather to build a boolean mask that's false for
      # values that come from rows that we are deleting, and use that mask to
      # construct the masked values tensor.
      segment_ids = segment_id_ops.row_splits_to_segment_ids(data.row_splits)
      segment_mask = array_ops.gather(mask, segment_ids)
      masked_values = boolean_mask(data.values, segment_mask, keepdims=False)

      return ragged_tensor.RaggedTensor.from_row_splits(masked_values,
                                                        masked_splits,
                                                        validate=False)

    # If mask is non-ragged and has rank>1, then convert it to be ragged,
    # with a ragged rank matching data.
    if ragged_tensor.is_ragged(data):
      mask = ragged_tensor.RaggedTensor.from_tensor(
          mask, ragged_rank=min(data.ragged_rank, mask.shape.ndims - 1),
          row_splits_dtype=data.row_splits.dtype)
      return boolean_mask(data, mask, keepdims)

    # Otherwise, data and mask are both `Tensor`s.
    else:
      # Apply `boolean_mask` to get the masked values.
      masked_values = array_ops.boolean_mask(data, mask)

      if mask.shape.ndims >= 2 and keepdims:
        # Add the innermost ragged dimension.  For each innermost cell, get the
        # number of values it contains.  Then flatten that to get a list of
        # cell lengths, and convert it to splits.  Finally, combine the splits
        # and values to get the innermost ragged tensor.
        masked_lengths = math_ops.count_nonzero(mask, axis=-1,
                                                dtype=row_splits_dtype)
        flattened_masked_lengths = array_ops.reshape(masked_lengths, [-1])
        masked_values = ragged_tensor.RaggedTensor.from_row_lengths(
            masked_values, flattened_masked_lengths, validate=False)

        # Wrap remaining ragged dimensions.
        if mask.shape.ndims > 2 and keepdims:
          mask_shape = array_ops.shape(mask, out_type=row_splits_dtype)
          split_size = math_ops.cumprod(mask_shape) + 1
          for dim in range(mask.shape.ndims - 3, -1, -1):
            elt_size = mask_shape[dim + 1]
            masked_splits = math_ops.range(split_size[dim]) * elt_size
            masked_values = ragged_tensor.RaggedTensor.from_row_splits(
                masked_values, masked_splits, validate=False)

      return masked_values
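
For dense inputs, the `keepdims=False` path is designed to line up with `tf.boolean_mask`, which always flattens the mask dimensions; the no-`keepdims` version shown in Example #4 corresponds to the `keepdims=True` behavior. A small sketch of the dense baseline:

import tensorflow as tf

data = tf.constant([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
mask = tf.constant([[True, False, True],
                    [False, False, False],
                    [True, False, False]])
print(tf.boolean_mask(data, mask).numpy())  # [1 3 7]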