def _broadcast_elementwise_args(elementwise_args):
  """Broadcasts the values of `elementwise_args` to have compatible shapes.

  Args:
    elementwise_args: A dictionary whose keys are potentially ragged tensors.

  Returns:
    A tuple `(broadcast_args, broadcast_splits, checks)` where:

    * `broadcast_args` is a dictionary with the same keys as
      `elementwise_args`, mapping to broadcasted tensors.
    * `broadcast_splits` is the broadcasted nested row splits.
    * `checks` is a possibly empty tuple of assertion operations that should
      be added as control dependencies.

  Raises:
    ValueError: If broadcasting fails.
  """
  # No elementwise arguments were used: nothing to do!
  if not elementwise_args:
    return elementwise_args, (), ()

  # A single elementwise argument was used: no broadcasting necessary.
  if len(elementwise_args) == 1:
    arg = list(elementwise_args.values())[0]
    if ragged_tensor.is_ragged(arg):
      return elementwise_args, arg.nested_row_splits, ()
    else:
      return elementwise_args, (), ()

  # Multiple elementwise arguments.
  else:
    is_ragged = [ragged_tensor.is_ragged(t) for t in elementwise_args.values()]
    if not any(is_ragged):
      return elementwise_args, (), ()

    # Support limited broadcasting (namely, scalar + ragged).  Full
    # broadcasting support will be added later.
    if all((ragged_tensor.is_ragged(t) or t.shape.ndims == 0)
           for t in elementwise_args.values()):
      nested_splits_lists = [
          t.nested_row_splits
          for t in elementwise_args.values()
          if ragged_tensor.is_ragged(t)
      ]
      if len(nested_splits_lists) == 1:
        checks = ()
      else:
        if any(t.shape.ndims is None for t in elementwise_args.values()):
          raise ValueError('Ragged elementwise ops require that rank (number '
                           'of dimensions) be statically known.')
        if len(set(t.shape.ndims for t in elementwise_args.values())) != 1:
          raise ValueError('Ragged elementwise ops do not support '
                           'broadcasting yet')
        checks = ragged_util.assert_splits_match(nested_splits_lists)
      return (elementwise_args, nested_splits_lists[0], checks)
    else:
      raise ValueError('Ragged elementwise ops do not support broadcasting yet')
Exemplo n.º 2
0
def map_flat_values(op, *args, **kwargs):
    """Applies `op` to the values of one or more RaggedTensors.

    Replaces any `RaggedTensor` in `args` or `kwargs` with its `flat_values`
    tensor, and then calls `op`.  Returns a `RaggedTensor` that is constructed
    from the input `RaggedTensor`s' `nested_row_splits` and the value returned
    by the `op`.

    If the input arguments contain multiple `RaggedTensor`s, then they must
    have identical `nested_row_splits`.  If they contain no `RaggedTensor`s at
    all, `op` is called with the original `args` and `kwargs` and its result
    is returned as-is.

    Examples:

    ```python
    >>> rt = ragged.constant([[1, 2, 3], [], [4, 5], [6]])
    >>> ragged.map_flat_values(tf.ones_like, rt).eval().tolist()
    [[1, 1, 1], [], [1, 1], [1]]
    >>> ragged.map_flat_values(tf.multiply, rt, rt).eval().tolist()
    [[1, 4, 9], [], [16, 25], [36]]
    >>> ragged.map_flat_values(tf.add, rt, 5).eval().tolist()
    [[6, 7, 8], [], [9, 10], [11]]
    ```

    Args:
      op: The operation that should be applied to the RaggedTensor
        `flat_values`.  `op` is typically an element-wise operation (such as
        math_ops.add), but any operation that preserves the size of the
        outermost dimension can be used.  I.e., `shape[0]` of the value
        returned by `op` must match `shape[0]` of the `RaggedTensor`s'
        `flat_values` tensors.
      *args: Arguments for `op`.
      **kwargs: Keyword arguments for `op`.

    Returns:
      A `RaggedTensor` whose `ragged_rank` matches the `ragged_rank` of all
      input `RaggedTensor`s; or, if no argument is ragged, whatever `op`
      returns.

    Raises:
      ValueError: If `ragged_util.assert_splits_match` statically detects that
        the `nested_splits` of the input `RaggedTensor`s are not identical.
        (Mismatches only detectable at runtime are presumably reported by the
        returned assertion ops instead -- TODO confirm.)
    """
    # Replace RaggedTensors with their values; and collect the splits tensors
    # from each RaggedTensor.
    nested_splits_lists = []
    inner_args = _replace_ragged_with_flat_values(args, nested_splits_lists)
    inner_kwargs = _replace_ragged_with_flat_values(kwargs,
                                                    nested_splits_lists)
    # No ragged inputs: there is nothing to re-wrap, so defer to `op` directly.
    if not nested_splits_lists:
        return op(*args, **kwargs)

    with ops.control_dependencies(
            ragged_util.assert_splits_match(nested_splits_lists)):
        # Delegate to op, and then compose the result from the transformed
        # values and the splits.
        return ragged_tensor.RaggedTensor.from_nested_row_splits(
            op(*inner_args, **inner_kwargs), nested_splits_lists[0])
Exemplo n.º 3
0
def map_flat_values(op, *args, **kwargs):
  """Applies `op` to the values of one or more RaggedTensors.

  Replaces any `RaggedTensor` in `args` or `kwargs` with its `flat_values`
  tensor, and then calls `op`.  Returns a `RaggedTensor` that is constructed
  from the input `RaggedTensor`s' `nested_row_splits` and the value returned by
  the `op`.

  If the input arguments contain multiple `RaggedTensor`s, then they must have
  identical `nested_row_splits`.  If they contain no `RaggedTensor`s at all,
  `op` is called with the original `args` and `kwargs` and its result is
  returned as-is.

  Examples:

  ```python
  >>> rt = ragged.constant([[1, 2, 3], [], [4, 5], [6]])
  >>> ragged.map_flat_values(tf.ones_like, rt).eval().tolist()
  [[1, 1, 1], [], [1, 1], [1]]
  >>> ragged.map_flat_values(tf.multiply, rt, rt).eval().tolist()
  [[1, 4, 9], [], [16, 25], [36]]
  >>> ragged.map_flat_values(tf.add, rt, 5).eval().tolist()
  [[6, 7, 8], [], [9, 10], [11]]
  ```

  Args:
    op: The operation that should be applied to the RaggedTensor `flat_values`.
      `op` is typically an element-wise operation (such as math_ops.add), but
      any operation that preserves the size of the outermost dimension can be
      used.  I.e., `shape[0]` of the value returned by `op` must match
      `shape[0]` of the `RaggedTensor`s' `flat_values` tensors.
    *args: Arguments for `op`.
    **kwargs: Keyword arguments for `op`.

  Returns:
    A `RaggedTensor` whose `ragged_rank` matches the `ragged_rank` of all
    input `RaggedTensor`s; or, if no argument is ragged, whatever `op`
    returns.

  Raises:
    ValueError: If `ragged_util.assert_splits_match` statically detects that
      the `nested_splits` of the input `RaggedTensor`s are not identical.
      (Mismatches only detectable at runtime are presumably reported by the
      returned assertion ops instead -- TODO confirm.)
  """
  # Replace RaggedTensors with their values; and collect the splits tensors
  # from each RaggedTensor.
  nested_splits_lists = []
  inner_args = _replace_ragged_with_flat_values(args, nested_splits_lists)
  inner_kwargs = _replace_ragged_with_flat_values(kwargs, nested_splits_lists)
  # No ragged inputs: there is nothing to re-wrap, so defer to `op` directly.
  if not nested_splits_lists:
    return op(*args, **kwargs)

  with ops.control_dependencies(
      ragged_util.assert_splits_match(nested_splits_lists)):
    # Delegate to op, and then compose the result from the transformed values
    # and the splits.
    return ragged_tensor.RaggedTensor.from_nested_row_splits(
        op(*inner_args, **inner_kwargs), nested_splits_lists[0])
Exemplo n.º 4
0
 def handle(self, args, kwargs):
     """Runs `self._original_op` on the flat values of ragged input(s).

     The dispatch target is the first positional argument or, when there are
     no positional arguments, the keyword argument named `self._x`.  If the
     target is ragged (or, when `self._arg_is_list` is set, is a sequence
     containing at least one ragged element), the op is applied to the flat
     values and the result is rewrapped with the ragged row partitioning.

     Args:
       args: Positional arguments for the op.
       kwargs: Keyword arguments for the op (copied, never mutated).

     Returns:
       The rewrapped op result, or `self.NOT_SUPPORTED` when this handler
       does not apply.
     """
     if not args:
         # Copy before popping so the caller's dict is left untouched.
         kwargs = kwargs.copy()
         x = kwargs.pop(self._x, None)
     else:
         x, args = args[0], args[1:]
     if x is None:
         return self.NOT_SUPPORTED

     if not self._arg_is_list:
         if not ragged_tensor.is_ragged(x):
             return self.NOT_SUPPORTED
         x = ragged_tensor.convert_to_tensor_or_ragged_tensor(
             x, name=self._x)
         result_values = self._original_op(x.flat_values, *args, **kwargs)
         return x.with_flat_values(result_values)

     # List-valued argument: every element must be ragged or convertible to
     # a Tensor, and at least one must be ragged.
     saw_ragged = False
     for item in x:
         if ragged_tensor.is_ragged(item):
             saw_ragged = True
         elif not _is_convertible_to_tensor(item):
             return self.NOT_SUPPORTED
     if not saw_ragged:
         return self.NOT_SUPPORTED

     items = [
         ragged_tensor.convert_to_tensor_or_ragged_tensor(item)
         if ragged_tensor.is_ragged(item) else item
         for item in x
     ]
     items = ragged_tensor.match_row_splits_dtypes(*items)
     ragged_items = [
         item for item in items if ragged_tensor.is_ragged(item)
     ]
     splits_per_input = [item.nested_row_splits for item in ragged_items]
     op_inputs = [
         item.flat_values if ragged_tensor.is_ragged(item) else item
         for item in items
     ]
     with ops.control_dependencies(
             ragged_util.assert_splits_match(splits_per_input)):
         return ragged_items[0].with_flat_values(
             self._original_op(op_inputs, *args, **kwargs))
Exemplo n.º 5
0
 def handle(self, args, kwargs):
     """Runs `self._original_op` on the inner values of ragged input(s).

     The dispatch target is the first positional argument or, when there are
     no positional arguments, the keyword argument named `self._x`.  If the
     target is ragged (or, when `self._arg_is_list` is set, is a sequence
     containing at least one ragged element), the op is applied to the inner
     values and the result is rewrapped with the ragged row splits.

     Args:
       args: Positional arguments for the op.
       kwargs: Keyword arguments for the op (copied, never mutated).

     Returns:
       The rewrapped op result, or `self.NOT_SUPPORTED` when this handler
       does not apply.
     """
     if not args:
         # Copy before popping so the caller's dict is left untouched.
         kwargs = kwargs.copy()
         x = kwargs.pop(self._x, None)
     else:
         x, args = args[0], args[1:]
     if x is None:
         return self.NOT_SUPPORTED

     if not self._arg_is_list:
         if not ragged_tensor.is_ragged(x):
             return self.NOT_SUPPORTED
         result_values = self._original_op(x.inner_values, *args, **kwargs)
         return x.with_inner_values(result_values)

     # List-valued argument: every element must be ragged or convertible to
     # a Tensor, and at least one must be ragged.
     saw_ragged = False
     for item in x:
         if ragged_tensor.is_ragged(item):
             saw_ragged = True
         elif not _is_convertible_to_tensor(item):
             return self.NOT_SUPPORTED
     if not saw_ragged:
         return self.NOT_SUPPORTED

     splits_per_input = [
         item.nested_row_splits for item in x
         if ragged_tensor.is_ragged(item)
     ]
     op_inputs = [
         item.inner_values if ragged_tensor.is_ragged(item) else item
         for item in x
     ]
     with ops.control_dependencies(
             ragged_util.assert_splits_match(splits_per_input)):
         return ragged_factory_ops.from_nested_row_splits(
             self._original_op(op_inputs, *args, **kwargs),
             splits_per_input[0])
Exemplo n.º 6
0
 def handle(self, args, kwargs):
   """Runs `self._original_op` on the flat values of ragged input(s).

   The dispatch target is the first positional argument or, when there are
   no positional arguments, the keyword argument named `self._x`.  If the
   target is ragged (or, when `self._arg_is_list` is set, is a sequence
   containing at least one ragged element), the op is applied to the flat
   values and the result is rebuilt from the ragged row splits.

   Args:
     args: Positional arguments for the op.
     kwargs: Keyword arguments for the op (copied, never mutated).

   Returns:
     The rewrapped op result, or `self.NOT_SUPPORTED` when this handler does
     not apply.
   """
   if not args:
     # Copy before popping so the caller's dict is left untouched.
     kwargs = kwargs.copy()
     x = kwargs.pop(self._x, None)
   else:
     x, args = args[0], args[1:]
   if x is None:
     return self.NOT_SUPPORTED

   if not self._arg_is_list:
     if not ragged_tensor.is_ragged(x):
       return self.NOT_SUPPORTED
     result_values = self._original_op(x.flat_values, *args, **kwargs)
     return x.with_flat_values(result_values)

   # List-valued argument: every element must be ragged or convertible to a
   # Tensor, and at least one must be ragged.
   saw_ragged = False
   for item in x:
     if ragged_tensor.is_ragged(item):
       saw_ragged = True
     elif not _is_convertible_to_tensor(item):
       return self.NOT_SUPPORTED
   if not saw_ragged:
     return self.NOT_SUPPORTED

   # Give every ragged element a common row_splits dtype before comparing.
   items = ragged_tensor.match_row_splits_dtypes(*x)
   splits_per_input = [
       item.nested_row_splits for item in items
       if ragged_tensor.is_ragged(item)
   ]
   op_inputs = [
       item.flat_values if ragged_tensor.is_ragged(item) else item
       for item in items
   ]
   with ops.control_dependencies(
       ragged_util.assert_splits_match(splits_per_input)):
     result_values = self._original_op(op_inputs, *args, **kwargs)
     return ragged_tensor.RaggedTensor.from_nested_row_splits(
         result_values, splits_per_input[0], validate=False)
Exemplo n.º 7
0
def boolean_mask(data, mask, name=None):
  """Applies a boolean mask to `data` without flattening the mask dimensions.

  Returns a potentially ragged tensor that is formed by retaining the elements
  in `data` where the corresponding value in `mask` is `True`.

  * `output[a1...aA, i, b1...bB] = data[a1...aA, j, b1...bB]`

     Where `j` is the `i`th `True` entry of `mask[a1...aA]`.

  Note that `output` preserves the mask dimensions `a1...aA`; this differs
  from `tf.boolean_mask`, which flattens those dimensions.

  Args:
    data: A potentially ragged tensor.
    mask: A potentially ragged boolean tensor.  `mask`'s shape must be a prefix
      of `data`'s shape.  `rank(mask)` must be known statically and must be at
      least 1 (scalar masks are rejected).
    name: A name prefix for the returned tensor (optional).

  Returns:
    A potentially ragged tensor that is formed by retaining the elements in
    `data` where the corresponding value in `mask` is `True`.

    * `rank(output) = rank(data)`.
    * `output.ragged_rank = max(data.ragged_rank, rank(mask) - 1)`.

  Raises:
    ValueError: if `rank(mask)` is not known statically; if `mask` is a
      scalar; or if `mask.shape` is not a prefix of `data.shape`.

  #### Examples:

  >>> # Aliases for True & False so data and mask line up.
  >>> T, F = (True, False)

  >>> tf.ragged.boolean_mask(  # Mask a 2D Tensor.
  ...     data=[[1, 2, 3], [4, 5, 6], [7, 8, 9]],
  ...     mask=[[T, F, T], [F, F, F], [T, F, F]]).to_list()
  [[1, 3], [], [7]]

  >>> tf.ragged.boolean_mask(  # Mask a 2D RaggedTensor.
  ...     tf.ragged.constant([[1, 2, 3], [4], [5, 6]]),
  ...     tf.ragged.constant([[F, F, T], [F], [T, T]])).to_list()
  [[3], [], [5, 6]]

  >>> tf.ragged.boolean_mask(  # Mask rows of a 2D RaggedTensor.
  ...     tf.ragged.constant([[1, 2, 3], [4], [5, 6]]),
  ...     tf.ragged.constant([True, False, True])).to_list()
  [[1, 2, 3], [5, 6]]
  """
  with ops.name_scope(name, 'RaggedMask', [data, mask]):
    # Convert inputs to tensors.
    data = ragged_tensor.convert_to_tensor_or_ragged_tensor(data, name='data')
    mask = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        mask, dtypes.bool, name='mask')
    # Give data and mask a common row_splits dtype; `row_splits_dtype` is
    # reused below whenever new splits/lengths are built.
    row_splits_dtype, (data, mask) = ragged_tensor.match_row_splits_dtypes(
        data, mask, return_dtype=True)

    # Get static rank of mask.
    if mask.shape.ndims is None:
      raise ValueError('mask.shape.ndims must be known statically.')
    elif mask.shape.ndims == 0:
      raise ValueError('mask cannot be scalar.')

    # If mask is ragged, then recurse with a non-ragged mask.
    if ragged_tensor.is_ragged(mask):
      if not ragged_tensor.is_ragged(data):
        data = ragged_tensor.RaggedTensor.from_tensor(
            data,
            ragged_rank=mask.ragged_rank,
            row_splits_dtype=mask.row_splits.dtype)
      # Check that mask.nested_row_splits is a prefix of
      # data.nested_row_splits.
      splits_list = [
          mask.nested_row_splits, data.nested_row_splits[:mask.ragged_rank]
      ]
      with ops.control_dependencies(
          ragged_util.assert_splits_match(splits_list)):
        # Strip off ragged `splits` until `mask` is non-ragged.  Keep the splits
        # that we strip off in `splits`, so we can add them back on after
        # we recursively mask the non-ragged data.
        splits = []
        while ragged_tensor.is_ragged(mask):
          if mask.shape.ndims > 2:
            # Outer ragged dimension: masking only removes innermost
            # elements, so the number of sub-rows per row is unchanged and
            # the existing splits are kept as-is.
            splits.append(mask.row_splits)
          else:
            # Count the number of True mask values in each row to find the
            # lengths of the filtered rows; then convert to splits.
            int_mask = ragged_functional_ops.map_flat_values(
                math_ops.cast, mask, dtype=row_splits_dtype)
            masked_row_lengths = ragged_math_ops.reduce_sum(int_mask, axis=1)
            splits.append(ragged_util.lengths_to_splits(masked_row_lengths))
          mask = mask.values
          data = data.values

        # Recursively apply the nested non-ragged mask to the nested data.
        masked_values = boolean_mask(data, mask)

        # Add the ragged `splits` back to the result.
        masked_values = ragged_tensor.RaggedTensor.from_nested_row_splits(
            masked_values, splits, validate=False)

        return masked_values

    # If mask is non-ragged and has rank 1, and data is ragged, then build a
    # ragged tensor with the indicated rows.
    elif ragged_tensor.is_ragged(data) and mask.shape.ndims == 1:
      # Get the masked splits: first get the length of each row, then filter
      # out the rows that we are deleting, and convert that filtered set of
      # masks back to a splits tensor.
      lengths = data.row_lengths()
      masked_lengths = array_ops.boolean_mask(lengths, mask)
      masked_splits = ragged_util.lengths_to_splits(masked_lengths)

      # Get the masked values: first get row ids corresponding to each
      # value, then use tf.gather to build a boolean mask that's false for
      # values that come from rows that we are deleting, and use that mask to
      # construct the masked values tensor.
      segment_ids = segment_id_ops.row_splits_to_segment_ids(data.row_splits)
      segment_mask = array_ops.gather(mask, segment_ids)
      masked_values = boolean_mask(data.values, segment_mask)

      return ragged_tensor.RaggedTensor.from_row_splits(
          masked_values, masked_splits, validate=False)

    # If mask is non-ragged and has rank>1, then convert it to be ragged,
    # with a ragged rank matching data.
    if ragged_tensor.is_ragged(data):
      mask = ragged_tensor.RaggedTensor.from_tensor(
          mask,
          ragged_rank=min(data.ragged_rank, mask.shape.ndims - 1),
          row_splits_dtype=data.row_splits.dtype)
      return boolean_mask(data, mask)

    # Otherwise, data and mask are both `Tensor`s.
    else:
      # Apply `boolean_mask` to get the masked values.
      masked_values = array_ops.boolean_mask(data, mask)

      if mask.shape.ndims >= 2:
        # Add the innermost ragged dimension.  For each innermost cell, get the
        # number of values it contains.  Then flatten that to get a list of
        # cell lengths, and convert it to splits.  Finally, combine the splits
        # and values to get the innermost ragged tensor.
        masked_lengths = math_ops.count_nonzero(
            mask, axis=-1, dtype=row_splits_dtype)
        flattened_masked_lengths = array_ops.reshape(masked_lengths, [-1])
        masked_values = ragged_tensor.RaggedTensor.from_row_lengths(
            masked_values, flattened_masked_lengths, validate=False)

        # Wrap remaining ragged dimensions.
        if mask.shape.ndims > 2:
          mask_shape = array_ops.shape(mask, out_type=row_splits_dtype)
          # split_size[d] = (product of mask_shape[0..d]) + 1, i.e. the length
          # of a uniform splits vector for the rows at dimension d.
          split_size = math_ops.cumprod(mask_shape) + 1
          # Wrap from the inside out, adding uniform (evenly spaced) splits
          # for each outer dense mask dimension.
          for dim in range(mask.shape.ndims - 3, -1, -1):
            elt_size = mask_shape[dim + 1]
            masked_splits = math_ops.range(split_size[dim]) * elt_size
            masked_values = ragged_tensor.RaggedTensor.from_row_splits(
                masked_values, masked_splits, validate=False)

      return masked_values
Exemplo n.º 8
0
def _ragged_stack_concat_helper(rt_inputs, axis, stack_values):
    """Helper function to concatenate or stack ragged tensors.

  Args:
    rt_inputs: A list of RaggedTensors or Tensors to combine.
    axis: The axis along which to concatenate or stack.
    stack_values: A boolean -- if true, then stack values; otherwise,
      concatenate them.

  Returns:
    A RaggedTensor.
  Raises:
    ValueError: If rt_inputs is empty, or if axis is out of range.
  """
    # Validate parameters.
    if not rt_inputs:
        raise ValueError('rt_inputs may not be empty.')

    # Convert input tensors.
    rt_inputs = [
        ragged_tensor.convert_to_tensor_or_ragged_tensor(rt_input,
                                                         name='rt_input')
        for rt_input in rt_inputs
    ]
    row_splits_dtype, rt_inputs = ragged_tensor.match_row_splits_dtypes(
        *rt_inputs, return_dtype=True)
    rt_inputs = list(rt_inputs)

    # Special case: if there's only one input, then return it as-is.
    if len(rt_inputs) == 1:
        if stack_values:
            return ragged_array_ops.expand_dims(rt_inputs[0], axis=axis)
        else:
            return rt_inputs[0]

    # Check the rank (number of dimensions) of the input tensors.
    ndims = None
    for rt in rt_inputs:
        if ndims is None:
            ndims = rt.shape.ndims
        else:
            rt.shape.assert_has_rank(ndims)

    out_ndims = ndims if (ndims is None or not stack_values) else ndims + 1
    axis = ragged_util.get_positive_axis(axis, out_ndims)

    # If all the inputs are Tensors, and we're combining the final dimension,
    # then we can delegate to the tf.stack/tf.concat operation, and return a
    # Tensor.
    if all(not ragged_tensor.is_ragged(rt) for rt in rt_inputs):
        if ndims is not None and (axis == out_ndims - 1 or axis == ndims - 1):
            if stack_values:
                return array_ops.stack(rt_inputs, axis)
            else:
                return array_ops.concat(rt_inputs, axis)

    # Convert any Tensor inputs to RaggedTensors.  This makes it
    # possible to concatenate Tensors and RaggedTensors together.
    for i in range(len(rt_inputs)):
        if not ragged_tensor.is_ragged(rt_inputs[i]):
            rt_inputs[i] = ragged_tensor.RaggedTensor.from_tensor(
                rt_inputs[i], ragged_rank=1, row_splits_dtype=row_splits_dtype)

    # Convert the input tensors to all have the same ragged_rank.
    ragged_rank = max(max(rt.ragged_rank for rt in rt_inputs), 1)
    rt_inputs = [
        _increase_ragged_rank_to(rt, ragged_rank, row_splits_dtype)
        for rt in rt_inputs
    ]

    if axis == 0:
        return _ragged_stack_concat_axis_0(rt_inputs, stack_values)
    elif axis == 1:
        return _ragged_stack_concat_axis_1(rt_inputs, stack_values)
    else:  # axis > 1: recurse.
        values = [rt.values for rt in rt_inputs]
        splits = [[rt_input.row_splits] for rt_input in rt_inputs]
        with ops.control_dependencies(ragged_util.assert_splits_match(splits)):
            return ragged_tensor.RaggedTensor.from_row_splits(
                _ragged_stack_concat_helper(values, axis - 1, stack_values),
                splits[0][0],
                validate=False)
Exemplo n.º 9
0
def ragged_assert_compatible_and_get_flat_values(values, mask=None):
  """If ragged, it checks the compatibility and then returns the flat_values.

  Note: If two tensors are dense, it does not check their compatibility.
  Note: Although two ragged tensors with different ragged ranks could have
        identical overall rank and dimension sizes and hence be compatible,
        we do not support those cases.

  Args:
    values: A list of potentially ragged tensors of the same ragged_rank, or
      a single potentially ragged tensor.
    mask: A potentially ragged tensor of the same ragged_rank as elements in
      `values`.

  Returns:
    A tuple `(values, mask)`.  When every element of `values` is ragged (and
    `mask` is `None` or ragged), each ragged tensor is replaced by its
    `flat_values` with a trailing dimension of size 1 added, guarded by
    assertions that the row splits match.  A single (non-list) `values`
    input yields a single tensor rather than a list.  Otherwise the inputs
    are returned unchanged.

  Raises:
    TypeError: If some but not all of `values` are ragged; if `values` is
      ragged but `mask` is a non-ragged tensor; or if `mask` is ragged while
      `values` is empty or not ragged.
  """
  if isinstance(values, list):
    # An empty list must not count as "all ragged": there would be no row
    # splits to compare a ragged `mask` against (this used to surface as an
    # IndexError on `nested_row_split_list[0]`).
    is_all_ragged = bool(values) and all(
        isinstance(rt, ragged_tensor.RaggedTensor) for rt in values)
    is_any_ragged = any(
        isinstance(rt, ragged_tensor.RaggedTensor) for rt in values)
  else:
    is_all_ragged = isinstance(values, ragged_tensor.RaggedTensor)
    is_any_ragged = is_all_ragged
  if (is_all_ragged and
      ((mask is None) or isinstance(mask, ragged_tensor.RaggedTensor))):
    to_be_stripped = False
    if not isinstance(values, list):
      values = [values]
      to_be_stripped = True

    # NOTE: we leave the flat_values compatibility to
    # tf.TensorShape `assert_is_compatible_with`
    # check if both dynamic dimensions are equal and then use the flat_values.
    nested_row_split_list = [rt.nested_row_splits for rt in values]
    assertion_list = ragged_util.assert_splits_match(nested_row_split_list)

    # if both are ragged sample_weights also should be ragged with same dims.
    if isinstance(mask, ragged_tensor.RaggedTensor):
      assertion_list_for_mask = ragged_util.assert_splits_match(
          [nested_row_split_list[0], mask.nested_row_splits])
      tmp = control_flow_ops.with_dependencies(assertion_list_for_mask,
                                               mask.flat_values)
      mask = array_ops.expand_dims(tmp, -1)

    # values has at least 1 element.
    flat_values = []
    for value in values:
      tmp = control_flow_ops.with_dependencies(assertion_list,
                                               value.flat_values)
      flat_values.append(array_ops.expand_dims(tmp, -1))

    values = flat_values[0] if to_be_stripped else flat_values

  elif is_any_ragged:
    raise TypeError('One of the inputs does not have acceptable types.')
  # values are empty or value are not ragged and mask is ragged.
  elif isinstance(mask, ragged_tensor.RaggedTensor):
    raise TypeError('Ragged mask is not allowed with non-ragged inputs.')

  return values, mask
  def __init__(self, shape, fields):
    """Creates a `StructuredTensor` from a dictionary of fields.

    Args:
      shape: A `TensorShape`: static information about the shape of the
        `StructuredTensor`.  Must have a known `rank`.
      fields: A dictionary mapping from string to `Tensor`, `RaggedTensor`, or
        `StructuredTensor`, providing the values for individual fields in each
        structure.  If `rank > 0`, then every tensor in `fields` must have the
        same shape in the first `shape.rank` dimensions; and that shape must be
        compatible with `shape`.  Values of any other type are converted with
        `ragged_tensor.convert_to_tensor_or_ragged_tensor` (if ragged) or
        `ops.convert_to_tensor`.

    Raises:
      ValueError: If `shape` has unknown rank; if a field name does not match
        `_FIELD_NAME_RE`; if the fields' leading dimensions are incompatible
        with `shape` (presumably raised by `TensorShape.merge_with`); or, when
        `rank > 1`, if only some fields are ragged or the ragged fields have
        inconsistent `ragged_rank`.
      TypeError: If `fields` is not a dict, a key is not a `str`, or a value
        cannot be converted to a tensor.
    """
    shape = tensor_shape.as_shape(shape)
    rank = shape.ndims
    if rank is None:
      raise ValueError("StructuredTensor's shape must have known rank.")
    if not isinstance(fields, dict):
      raise TypeError('fields must be a dictionary, got %s' %
                      type(fields).__name__)
    self._fields = {}
    with ops.name_scope(None, 'StructuredTensor', fields.values()):
      for (key, value) in fields.items():
        if not isinstance(key, str):
          # Wrap in a 1-tuple: a bare tuple key would otherwise be expanded
          # by `%` and break the message formatting.
          raise TypeError('Unexpected type for key in `fields`: %r' % (key,))
        if not _FIELD_NAME_RE.match(key):
          raise ValueError('Field name %r is not currently allowed.' % key)
        if not isinstance(
            value, (ops.Tensor, ragged_tensor.RaggedTensor, StructuredTensor)):
          if ragged_tensor.is_ragged(value):
            value = ragged_tensor.convert_to_tensor_or_ragged_tensor(value)
          else:
            try:
              value = ops.convert_to_tensor(value)
            except (ValueError, TypeError):
              # 1-tuple for the same reason as above: `value` may be a tuple.
              raise TypeError('Unexpected type for value in `fields`: %r' %
                              (value,))
        self._fields[key] = value

    # Check the static TensorShape for this StructuredTensor.
    self._static_shape = shape
    if rank > 0:
      for value in self._fields.values():
        self._static_shape = self._static_shape.merge_with(value.shape[:rank])

    self._nested_row_splits = []
    if rank > 1:
      # If any fields are ragged, then check that all row-splits match.
      shared_row_splits = []
      for field in self._fields.values():
        # TODO(edloper): A field shouldn't count as ragged if it has
        # uniform_row_length defined for all the dimensions in question.
        if isinstance(field, ragged_tensor.RaggedTensor):
          shared_row_splits.append(field.nested_row_splits[:rank - 1])
        elif isinstance(field, StructuredTensor) and field.ragged_rank > 0:
          shared_row_splits.append(field.nested_row_splits[:rank - 1])
      if shared_row_splits:
        if len(shared_row_splits) != len(self._fields):
          raise ValueError('Ragged StructuredTensor contains non-ragged fields')

        # Check if the splits are identical.  This should be the common case.
        identical_splits = True
        for splits in shared_row_splits[1:]:
          if len(splits) != len(shared_row_splits[0]):
            raise ValueError('Fields have inconsistent ragged_rank')
          for (s1, s2) in zip(splits, shared_row_splits[0]):
            if s1 is not s2:
              identical_splits = False

        if identical_splits:
          self._nested_row_splits = shared_row_splits[0]
        else:
          # If splits aren't identical, then add assertions to check that they
          # match.
          with ops.control_dependencies(
              ragged_util.assert_splits_match(shared_row_splits)):
            self._nested_row_splits = [array_ops.identity(splits)
                                       for splits in shared_row_splits[0]]
Exemplo n.º 11
0
def _ragged_stack_concat_helper(rt_inputs, axis, stack_values):
  """Helper function to concatenate or stack ragged tensors.

  Args:
    rt_inputs: A list of RaggedTensors or Tensors to combine.
    axis: The axis along which to concatenate or stack.
    stack_values: A boolean -- if true, then stack values; otherwise,
      concatenate them.

  Returns:
    A RaggedTensor.
  Raises:
    ValueError: If rt_inputs is empty, or if axis is out of range.
  """
  # Validate parameters.
  if not rt_inputs:
    raise ValueError('rt_inputs may not be empty.')

  # Convert input tensors.
  rt_inputs = [
      ragged_tensor.convert_to_tensor_or_ragged_tensor(
          rt_input, name='rt_input') for rt_input in rt_inputs
  ]
  row_splits_dtype, rt_inputs = ragged_tensor.match_row_splits_dtypes(
      *rt_inputs, return_dtype=True)
  rt_inputs = list(rt_inputs)

  # Special case: if there's only one input, then return it as-is.
  if len(rt_inputs) == 1:
    if stack_values:
      return ragged_array_ops.expand_dims(rt_inputs[0], axis=axis)
    else:
      return rt_inputs[0]

  # Check the rank (number of dimensions) of the input tensors.
  ndims = None
  for rt in rt_inputs:
    if ndims is None:
      ndims = rt.shape.ndims
    else:
      rt.shape.assert_has_rank(ndims)

  out_ndims = ndims if (ndims is None or not stack_values) else ndims + 1
  axis = ragged_util.get_positive_axis(axis, out_ndims)

  # If all the inputs are Tensors, and we're combining the final dimension,
  # then we can delegate to the tf.stack/tf.concat operation, and return a
  # Tensor.
  if all(not ragged_tensor.is_ragged(rt) for rt in rt_inputs):
    if ndims is not None and (axis == out_ndims - 1 or axis == ndims - 1):
      if stack_values:
        return array_ops.stack(rt_inputs, axis)
      else:
        return array_ops.concat(rt_inputs, axis)

  # Convert any Tensor inputs to RaggedTensors.  This makes it
  # possible to concatenate Tensors and RaggedTensors together.
  for i in range(len(rt_inputs)):
    if not ragged_tensor.is_ragged(rt_inputs[i]):
      rt_inputs[i] = ragged_tensor.RaggedTensor.from_tensor(
          rt_inputs[i], ragged_rank=1, row_splits_dtype=row_splits_dtype)

  # Convert the input tensors to all have the same ragged_rank.
  ragged_rank = max(max(rt.ragged_rank for rt in rt_inputs), 1)
  rt_inputs = [_increase_ragged_rank_to(rt, ragged_rank, row_splits_dtype)
               for rt in rt_inputs]

  if axis == 0:
    return _ragged_stack_concat_axis_0(rt_inputs, stack_values)
  elif axis == 1:
    return _ragged_stack_concat_axis_1(rt_inputs, stack_values)
  else:  # axis > 1: recurse.
    values = [rt.values for rt in rt_inputs]
    splits = [[rt_input.row_splits] for rt_input in rt_inputs]
    with ops.control_dependencies(ragged_util.assert_splits_match(splits)):
      return ragged_tensor.RaggedTensor.from_row_splits(
          _ragged_stack_concat_helper(values, axis - 1, stack_values),
          splits[0][0], validate=False)
Exemplo n.º 12
0
def map_flat_values(op, *args, **kwargs):
    """Applies `op` to the values of one or more RaggedTensors.

    Replaces any `RaggedTensor` in `args` or `kwargs` with its `flat_values`
    tensor, and then calls `op`.  Returns a `RaggedTensor` that is constructed
    from the input `RaggedTensor`s' `nested_row_splits` and the value returned
    by the `op`.

    If the input arguments contain multiple `RaggedTensor`s, then they must
    have identical `nested_row_splits`.

    If neither `args` nor `kwargs` contains a `RaggedTensor`, `op` is simply
    called with the original arguments and its result is returned unchanged.

    Examples:

    >>> rt = tf.ragged.constant([[1, 2, 3], [], [4, 5], [6]])
    >>> map_flat_values(tf.ones_like, rt).to_list()
    [[1, 1, 1], [], [1, 1], [1]]
    >>> map_flat_values(tf.multiply, rt, rt).to_list()
    [[1, 4, 9], [], [16, 25], [36]]
    >>> map_flat_values(tf.add, rt, 5).to_list()
    [[6, 7, 8], [], [9, 10], [11]]

    Args:
      op: The operation that should be applied to the RaggedTensor
        `flat_values`.  `op` is typically an element-wise operation (such as
        math_ops.add), but any operation that preserves the size of the
        outermost dimension can be used.  I.e., `shape[0]` of the value
        returned by `op` must match `shape[0]` of the `RaggedTensor`s'
        `flat_values` tensors.
      *args: Arguments for `op`.
      **kwargs: Keyword arguments for `op`.

    Returns:
      A `RaggedTensor` whose `ragged_rank` matches the `ragged_rank` of all
      input `RaggedTensor`s; or, when no input is a `RaggedTensor`, whatever
      `op(*args, **kwargs)` returns.

    Raises:
      ValueError: If the input `RaggedTensor`s have mismatched row_splits
        dtypes and automatic casting of partition dtypes is disabled.
      Equality of the inputs' `nested_row_splits` is checked with
      `ragged_util.assert_splits_match`; its checks are attached as control
      dependencies of the result.
    """
    # Replace RaggedTensors with their flat_values.  The helper is expected to
    # append each RaggedTensor's nested row splits to `nested_splits_lists`.
    nested_splits_lists = []
    inner_args = _replace_ragged_with_flat_values(args, nested_splits_lists)
    inner_kwargs = _replace_ragged_with_flat_values(kwargs,
                                                    nested_splits_lists)
    if not nested_splits_lists:
        # No ragged inputs: nothing to map over, so delegate directly.
        return op(*args, **kwargs)

    split_dtypes = {splits[0].dtype for splits in nested_splits_lists}
    if len(split_dtypes) > 1:
        if not ragged_config.auto_cast_partition_dtype():
            raise ValueError(
                "Input RaggedTensors have mismatched row_splits dtypes; "
                "use RaggedTensor.with_row_splits_dtype() to convert "
                "them to compatible dtypes.")

        # Auto-cast is enabled: unify all splits to int64 so they can be
        # compared with each other and reused for the result.
        nested_splits_lists = [
            [math_ops.cast(s, dtypes.int64) for s in nested_splits]  # pylint: disable=g-complex-comprehension
            for nested_splits in nested_splits_lists
        ]

    with ops.control_dependencies(
            ragged_util.assert_splits_match(nested_splits_lists)):
        # Delegate to op, and then compose the result from the transformed
        # values and the (shared) splits of the first ragged input.
        return ragged_tensor.RaggedTensor.from_nested_row_splits(
            op(*inner_args, **inner_kwargs),
            nested_splits_lists[0],
            validate=False)
# Example No. 13
# 0
def boolean_mask(data, mask, keepdims=False, name=None):
  """Applies a boolean mask to `data`.

  Returns a potentially ragged tensor that is formed by retaining the elements
  in `data` where the corresponding value in `mask` is `True`.

  If `keepdims` is true then outer dimensions (corresponding to the `mask`
  dimensions) are preserved, and:

  * `output[a1...aA, i, b1...bB] = data[a1...aA, j, b1...bB]`

     Where `j` is the `i`th `True` entry of `mask[a1...aA]`.

  If `keepdims` is false, then the outer dimensions are collapsed (similar to
  the behavior of `tf.boolean_mask`), and:

  * `output[i, b1...bB] = data[a1...aA, b1...bB]`

     Where `(a1...aA)` is the `i`th `True` entry of `mask`
     (in row-major order).

  Args:
    data: A potentially ragged tensor.
    mask: A potentially ragged boolean tensor.  `mask`'s shape must be a prefix
      of `data`'s shape.  `rank(mask)` must be known statically.
    keepdims: Whether to preserve the outer dimensions (`keepdims=True`) or
      flatten them (`keepdims=False`).
    name: A name prefix for the returned tensor (optional).

  Returns:
    A potentially ragged tensor that is formed by retaining the elements in
    `data` where the corresponding value in `mask` is `True`.

    If `keepdims` is false:

    * `rank(output) = rank(data) - rank(mask) + 1`.
    * `output.ragged_rank = max(data.ragged_rank - rank(mask) + 1, 0)`.

    If `keepdims` is true:

    * `rank(output) = rank(data)`.
    * `output.ragged_rank = max(data.ragged_rank, rank(mask) - 1)`.

  Raises:
    ValueError: if `rank(mask)` is not known statically; if `mask` is a scalar;
      or if `mask.shape` is not a prefix of `data.shape`.

  #### Examples:
    ```python
    >>> # Aliases for True & False so data and mask line up.
    >>> T, F = (True, False)

    >>> tf.ragged.boolean_mask(  # Mask a 2D Tensor.  Flatten outer dims.
    ...     data=[[1, 2, 3], [4, 5, 6], [7, 8, 9]],
    ...     mask=[[T, F, T], [F, F, F], [T, F, F]],
    ...     keepdims=False).numpy().tolist()
    [1, 3, 7]

    >>> tf.ragged.boolean_mask(  # Mask a 2D Tensor.  Preserve outer dims.
    ...     data=[[1, 2, 3], [4, 5, 6], [7, 8, 9]],
    ...     mask=[[T, F, T], [F, F, F], [T, F, F]],
    ...     keepdims=True).to_list()
    [[1, 3], [], [7]]

    >>> tf.ragged.boolean_mask(  # Mask a 2D RaggedTensor.  Flatten outer dims.
    ...     tf.ragged.constant([[1, 2, 3], [4], [5, 6]]),
    ...     tf.ragged.constant([[F, F, T], [F], [T, T]]),
    ...     keepdims=False).numpy().tolist()
    [3, 5, 6]

    >>> tf.ragged.boolean_mask(  # Mask a 2D RaggedTensor.  Preserve outer dims.
    ...     tf.ragged.constant([[1, 2, 3], [4], [5, 6]]),
    ...     tf.ragged.constant([[F, F, T], [F], [T, T]]),
    ...     keepdims=True).to_list()
    [[3], [], [5, 6]]

    >>> tf.ragged.boolean_mask(  # Mask rows of a 2D RaggedTensor.
    ...     tf.ragged.constant([[1, 2, 3], [4], [5, 6]]),
    ...     tf.ragged.constant([True, False, True]),
    ...     keepdims=True).to_list()
    [[1, 2, 3], [5, 6]]
    ```
  """
  with ops.name_scope(name, 'RaggedMask', [data, mask]):
    # Convert inputs to tensors.
    data = ragged_tensor.convert_to_tensor_or_ragged_tensor(data, name='data')
    mask = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        mask, dtypes.bool, name='mask')
    # Presumably brings data and mask to a common row_splits dtype; the
    # returned `row_splits_dtype` is used for every splits/lengths tensor
    # built below.
    row_splits_dtype, (data, mask) = ragged_tensor.match_row_splits_dtypes(
        data, mask, return_dtype=True)

    # Get static rank of mask.
    if mask.shape.ndims is None:
      raise ValueError('mask.shape.ndims must be known statically.')
    elif mask.shape.ndims == 0:
      raise ValueError('mask cannot be scalar.')

    # If mask is ragged, then recurse with a non-ragged mask.
    if ragged_tensor.is_ragged(mask):
      if not ragged_tensor.is_ragged(data):
        # Give dense data the same ragged structure as the mask so the two
        # can be peeled apart in lockstep below.
        data = ragged_tensor.RaggedTensor.from_tensor(
            data, ragged_rank=mask.ragged_rank,
            row_splits_dtype=mask.row_splits.dtype)
      # Check that mask.nested_row_splits is a prefix of
      # data.nested_row_splits.
      splits_list = [
          mask.nested_row_splits, data.nested_row_splits[:mask.ragged_rank]
      ]
      with ops.control_dependencies(
          ragged_util.assert_splits_match(splits_list)):
        # Strip off ragged `splits` until `mask` is non-ragged.  Keep the splits
        # that we strip off in `splits`, so we can add them back on after
        # we recursively mask the non-ragged data.
        splits = []
        # Each iteration removes one ragged dimension from both `mask` and
        # `data` (via `.values`).
        while ragged_tensor.is_ragged(mask):
          if mask.shape.ndims > 2:
            # Not the last mask dimension: masking only removes entries from
            # the innermost mask dimension, so these row splits are reused
            # unchanged.
            splits.append(mask.row_splits)
          else:
            # Count the number of True mask values in each row to find the
            # lengths of the filtered rows; then convert to splits.
            int_mask = ragged_functional_ops.map_flat_values(
                math_ops.cast, mask, dtype=row_splits_dtype)
            masked_row_lengths = ragged_math_ops.reduce_sum(int_mask, axis=1)
            splits.append(ragged_util.lengths_to_splits(masked_row_lengths))
          mask = mask.values
          data = data.values

        # Recursively apply the nested non-ragged mask to the nested data.
        masked_values = boolean_mask(data, mask, keepdims)

        # Add the ragged `splits` back to the result.  When keepdims is false
        # the outer dimensions are collapsed, so `splits` is discarded.
        if keepdims:
          masked_values = ragged_tensor.RaggedTensor.from_nested_row_splits(
              masked_values, splits, validate=False)

        return masked_values

    # If mask is non-ragged and has rank 1, and data is ragged, then build a
    # ragged tensor with the indicated rows.
    elif ragged_tensor.is_ragged(data) and mask.shape.ndims == 1:
      # Get the masked splits: first get the length of each row, then filter
      # out the rows that we are deleting, and convert that filtered set of
      # masks back to a splits tensor.
      lengths = data.row_lengths()
      masked_lengths = array_ops.boolean_mask(lengths, mask)
      masked_splits = ragged_util.lengths_to_splits(masked_lengths)

      # Get the masked values: first get row ids corresponding to each
      # value, then use tf.gather to build a boolean mask that's false for
      # values that come from rows that we are deleting, and use that mask to
      # construct the masked values tensor.
      segment_ids = segment_id_ops.row_splits_to_segment_ids(data.row_splits)
      segment_mask = array_ops.gather(mask, segment_ids)
      masked_values = boolean_mask(data.values, segment_mask, keepdims=False)

      return ragged_tensor.RaggedTensor.from_row_splits(masked_values,
                                                        masked_splits,
                                                        validate=False)

    # If mask is non-ragged and has rank>1, then convert it to be ragged,
    # with a ragged rank matching data.
    if ragged_tensor.is_ragged(data):
      mask = ragged_tensor.RaggedTensor.from_tensor(
          mask, ragged_rank=min(data.ragged_rank, mask.shape.ndims - 1),
          row_splits_dtype=data.row_splits.dtype)
      return boolean_mask(data, mask, keepdims)

    # Otherwise, data and mask are both `Tensor`s.
    else:
      # Apply `boolean_mask` to get the masked values.
      masked_values = array_ops.boolean_mask(data, mask)

      if mask.shape.ndims >= 2 and keepdims:
        # Add the innermost ragged dimension.  For each innermost cell, get the
        # number of values it contains.  Then flatten that to get a list of
        # cell lengths, and convert it to splits.  Finally, combine the splits
        # and values to get the innermost ragged tensor.
        masked_lengths = math_ops.count_nonzero(mask, axis=-1,
                                                dtype=row_splits_dtype)
        flattened_masked_lengths = array_ops.reshape(masked_lengths, [-1])
        masked_values = ragged_tensor.RaggedTensor.from_row_lengths(
            masked_values, flattened_masked_lengths, validate=False)

        # Wrap remaining ragged dimensions.  (keepdims is already known to be
        # true here.)  The outer mask dimensions are uniform, so each one's
        # row_splits are evenly spaced: 0, d, 2d, ..., where d is the size of
        # the next dimension.  `cumprod(mask_shape)[dim] + 1` is the number of
        # split points needed at depth `dim`.  Dimensions are wrapped from the
        # innermost remaining one outward.
        if mask.shape.ndims > 2 and keepdims:
          mask_shape = array_ops.shape(mask, out_type=row_splits_dtype)
          split_size = math_ops.cumprod(mask_shape) + 1
          for dim in range(mask.shape.ndims - 3, -1, -1):
            elt_size = mask_shape[dim + 1]
            masked_splits = math_ops.range(split_size[dim]) * elt_size
            masked_values = ragged_tensor.RaggedTensor.from_row_splits(
                masked_values, masked_splits, validate=False)

      return masked_values
# Example No. 14
# 0
def _broadcast_elementwise_args(elementwise_args):
    """Broadcasts the values of `elementwise_args` to have compatible shapes.

  Args:
    elementwise_args: A dictionary whose keys are potentially ragged tensors.

  Returns:
    A tuple `(broadcast_args, broadcast_splits, checks)` where:

    * `broadcast_args` is a dictionary with the same keys as
      `elementwise_args`, mapping to broadcasted tensors.
    * `broadcast_splits` is the broadcasted nested row splits.
    * `checks` is a possibly empty tuple of assertion operations that should
      be added as control dependencies.

  Raises:
    ValueError: If broadcasting fails.
  """
    # No elementwise arguments were used: nothing to do!
    if not elementwise_args:
        return elementwise_args, (), ()

    # A single elementwise argument was used: no broadcasting necessary.
    if len(elementwise_args) == 1:
        arg = list(elementwise_args.values())[0]
        if ragged_tensor.is_ragged(arg):
            return elementwise_args, arg.nested_row_splits, ()
        else:
            return elementwise_args, (), ()

    # Multiple elementwise arguments.
    else:
        is_ragged = [
            ragged_tensor.is_ragged(t) for t in elementwise_args.values()
        ]
        if not any(is_ragged):
            return elementwise_args, (), ()

        # Support limited broadcasting (namely, scalar + ragged).  Full
        # broadcasting support will be added later.
        if all((ragged_tensor.is_ragged(t) or t.shape.ndims == 0)
               for t in elementwise_args.values()):
            nested_splits_lists = [
                t.nested_row_splits for t in elementwise_args.values()
                if ragged_tensor.is_ragged(t)
            ]
            if len(nested_splits_lists) == 1:
                checks = ()
            else:
                if any(t.shape.ndims is None
                       for t in elementwise_args.values()):
                    raise ValueError(
                        'Ragged elementwise ops require that rank (number '
                        'of dimensions) be statically known.')
                if len(set(t.shape.ndims
                           for t in elementwise_args.values())) != 1:
                    raise ValueError('Ragged elementwise ops do not support '
                                     'broadcasting yet')
                checks = ragged_util.assert_splits_match(nested_splits_lists)
            return (elementwise_args, nested_splits_lists[0], checks)
        else:
            raise ValueError(
                'Ragged elementwise ops do not support broadcasting yet')
# Example No. 15
# 0
def map_flat_values(op, *args, **kwargs):
    """Applies `op` to the `flat_values` of one or more RaggedTensors.

    Replaces any `RaggedTensor` in `args` or `kwargs` with its `flat_values`
    tensor (which collapses all ragged dimensions), and then calls `op`.
    Returns a `RaggedTensor` that is constructed from the input
    `RaggedTensor`s' `nested_row_splits` and the value returned by the `op`.

    If the input arguments contain multiple `RaggedTensor`s, then they must
    have identical `nested_row_splits`.  If no argument is a `RaggedTensor`,
    `op` is simply called with the original arguments and its result is
    returned unchanged.

    This operation is generally used to apply elementwise operations to each
    value in a `RaggedTensor`.

    Warning: `tf.ragged.map_flat_values` does *not* apply `op` to each row of a
    ragged tensor.  This difference is important for non-elementwise
    operations, such as `tf.reduce_sum`.  If you wish to apply a
    non-elementwise operation to each row of a ragged tensor, use `tf.map_fn`
    instead.  (You may need to specify an `output_signature` when using
    `tf.map_fn` with ragged tensors.)

    Examples:

    >>> rt = tf.ragged.constant([[1, 2, 3], [], [4, 5], [6]])
    >>> tf.ragged.map_flat_values(tf.ones_like, rt)
    <tf.RaggedTensor [[1, 1, 1], [], [1, 1], [1]]>
    >>> tf.ragged.map_flat_values(tf.multiply, rt, rt)
    <tf.RaggedTensor [[1, 4, 9], [], [16, 25], [36]]>
    >>> tf.ragged.map_flat_values(tf.add, rt, 5)
    <tf.RaggedTensor [[6, 7, 8], [], [9, 10], [11]]>

    Example with a non-elementwise operation (note that `map_flat_values` and
    `map_fn` return different results):

    >>> rt = tf.ragged.constant([[1.0, 3.0], [], [3.0, 6.0, 3.0]])
    >>> def normalized(x):
    ...   return x / tf.reduce_sum(x)
    >>> tf.ragged.map_flat_values(normalized, rt)
    <tf.RaggedTensor [[0.0625, 0.1875], [], [0.1875, 0.375, 0.1875]]>
    >>> tf.map_fn(normalized, rt)
    <tf.RaggedTensor [[0.25, 0.75], [], [0.25, 0.5, 0.25]]>

    Args:
      op: The operation that should be applied to the RaggedTensor
        `flat_values`.  `op` is typically an element-wise operation (such as
        math_ops.add), but any operation that preserves the size of the
        outermost dimension can be used.  I.e., `shape[0]` of the value
        returned by `op` must match `shape[0]` of the `RaggedTensor`s'
        `flat_values` tensors.
      *args: Arguments for `op`.
      **kwargs: Keyword arguments for `op`.

    Returns:
      A `RaggedTensor` whose `ragged_rank` matches the `ragged_rank` of all
      input `RaggedTensor`s; or, when no input is a `RaggedTensor`, whatever
      `op(*args, **kwargs)` returns.

    Raises:
      ValueError: If the outer-dimension sizes collected for the inputs'
        `flat_values` differ from each other.
      ValueError: If the output of `op` has an outer dimension that is
        incompatible with the inputs' `flat_values`.
      ValueError: If the input `RaggedTensor`s have mismatched row_splits
        dtypes and automatic casting of partition dtypes is disabled.
      Equality of the inputs' `nested_row_splits` is checked with
      `ragged_util.assert_splits_match`; its checks are attached as control
      dependencies of the result.
    """
    # Replace RaggedTensors with their flat_values.  The helper is expected to
    # append each RaggedTensor's nested row splits to `nested_splits_lists`,
    # and the outer-dimension size of its flat_values (when available) to
    # `flat_values_nrows`.
    nested_splits_lists = []
    flat_values_nrows = []
    inner_args = _replace_ragged_with_flat_values(args, nested_splits_lists,
                                                  flat_values_nrows)
    inner_kwargs = _replace_ragged_with_flat_values(kwargs,
                                                    nested_splits_lists,
                                                    flat_values_nrows)
    if not nested_splits_lists:
        # No ragged inputs: nothing to map over, so delegate directly.
        return op(*args, **kwargs)

    # All collected flat_values outer sizes must agree; remember the common
    # size (or None if none were collected) to validate `op`'s output below.
    distinct_nrows = set(flat_values_nrows)
    if len(distinct_nrows) > 1:
        raise ValueError(
            "Input RaggedTensors' flat_values must all have the "
            "same outer-dimension size.  Got sizes: %s" %
            distinct_nrows)
    expected_nrows = next(iter(distinct_nrows)) if distinct_nrows else None

    split_dtypes = {splits[0].dtype for splits in nested_splits_lists}
    if len(split_dtypes) > 1:
        if not ragged_config.auto_cast_partition_dtype():
            raise ValueError(
                "Input RaggedTensors have mismatched row_splits dtypes; "
                "use RaggedTensor.with_row_splits_dtype() to convert "
                "them to compatible dtypes.")

        # Auto-cast is enabled: unify all splits to int64.
        nested_splits_lists = [
            [math_ops.cast(s, dtypes.int64) for s in nested_splits]  # pylint: disable=g-complex-comprehension
            for nested_splits in nested_splits_lists
        ]

    with ops.control_dependencies(
            ragged_util.assert_splits_match(nested_splits_lists)):
        # Delegate to `op`
        op_output = op(*inner_args, **inner_kwargs)
        # Check that the result has the expected shape (if known).
        if expected_nrows is not None:
            if not op_output.shape[:1].is_compatible_with([expected_nrows]):
                raise ValueError(
                    "tf.ragged.map_flat_values requires that the output of `op` have "
                    "the same outer-dimension size as flat_values of any ragged "
                    "inputs. (output shape: %s; expected outer dimension size: %s)"
                    % (op_output.shape, expected_nrows))
        # Compose the result from the transformed values and the splits.
        return ragged_tensor.RaggedTensor.from_nested_row_splits(
            op_output, nested_splits_lists[0], validate=False)