Example #1
  def __init__(self, partitioned_dim_sizes, inner_dim_sizes,
               dim_size_dtype=None):
    """Creates a RaggedTensorDynamicShape.

    Args:
      partitioned_dim_sizes: A `list` of 0-D or 1-D integer `Tensor`, one for
        each partitioned dimension.  If dimension `d` is uniform, then
        `partitioned_dim_sizes[d]` must be an integer scalar, specifying the
        size of all slices across dimension `d`.  If dimension `d` is ragged,
        then `partitioned_dim_sizes[d]` must be an integer vector, specifying
        the size of each slice across dimension `d`.
      inner_dim_sizes: A 1-D integer `Tensor`, whose length is equal to the
        number of inner dimensions.  `inner_dim_sizes[n]` is the size of all
        slices across the `n`th inner dimension (which is the
        `(len(partitioned_dim_sizes)+n)`th dimension in the overall tensor).
      dim_size_dtype: dtype for dimension sizes.  If not specified, then it
        is chosen based on the dtypes of `partitioned_dim_sizes` and
        `inner_dim_sizes`.
    """
    assert isinstance(partitioned_dim_sizes, (list, tuple))

    with ops.name_scope(None, 'RaggedTensorDynamicShape',
                        (partitioned_dim_sizes, inner_dim_sizes)):
      partitioned_dim_sizes = tuple(
          ops.convert_to_tensor(size, name='partitioned_dimension_size_%d' % i)
          for (i, size) in enumerate(partitioned_dim_sizes))
      inner_dim_sizes = ops.convert_to_tensor(
          inner_dim_sizes, name='inner_dim_sizes')

      # Validate shapes.
      if partitioned_dim_sizes:
        for axis, dimension_size in enumerate(partitioned_dim_sizes):
          if dimension_size.shape.ndims is None:
            raise ValueError(
                'rank of partitioned_dim_sizes[%d] is unknown' % axis)
          dimension_size.shape.with_rank_at_most(1)
        if partitioned_dim_sizes[0].shape.ndims == 1:
          raise ValueError('outermost partitioned dimension must be uniform')
        if partitioned_dim_sizes[-1].shape.ndims == 0:
          raise ValueError('innermost partitioned dimension must be ragged')
      inner_dim_sizes.shape.assert_has_rank(1)

      # Convert dimension size tensors to a single dtype.
      if dim_size_dtype is None:
        dim_size_dtypes = set([p.dtype for p in partitioned_dim_sizes
                               if p.shape.ndims == 1])
        if not dim_size_dtypes:
          dim_size_dtype = dtypes.int64
        elif len(dim_size_dtypes) == 1:
          dim_size_dtype = dim_size_dtypes.pop()
        else:
          if not ragged_config.auto_cast_partition_dtype():
            raise ValueError('partitioned_dim_sizes must have matching dtypes')
          dim_size_dtype = dtypes.int64
      partitioned_dim_sizes = tuple(math_ops.cast(p, dim_size_dtype)
                                    for p in partitioned_dim_sizes)
      inner_dim_sizes = math_ops.cast(inner_dim_sizes, dim_size_dtype)

      self._partitioned_dim_sizes = partitioned_dim_sizes
      self._inner_dim_sizes = inner_dim_sizes
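A minimal sketch of the shape components this constructor expects, assuming TensorFlow 2.x with `tf.ragged`; the variable names below are illustrative, not part of the API:

import tensorflow as tf

# Rank-3 ragged tensor: uniform outer dim (2), ragged middle dim, uniform inner dim (2).
rt = tf.ragged.constant([[[1, 2], [3, 4]], [[5, 6]]], ragged_rank=1)

outer_size = rt.nrows()                      # scalar: size of the uniform outermost dimension
ragged_sizes = rt.row_lengths()              # vector: per-row sizes of the ragged dimension
inner_sizes = tf.shape(rt.flat_values)[1:]   # sizes of the trailing uniform dimensions

# These would be passed as partitioned_dim_sizes=(outer_size, ragged_sizes) and
# inner_dim_sizes=inner_sizes in the constructor above.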
Example #2
def _convert_declared_ragged(current, declared):
    """Converts an output with RaggedTensorType into a _RaggedTensorComponents."""
    # Check that the ragged ranks match up.
    # + 1 to account for the rank of the outermost dimension.
    current_ragged_rank = getattr(current, "ragged_rank", 0)
    if declared.ragged_rank != current_ragged_rank + 1:
        raise ValueError(
            "The declared ragged rank (%d) mismatches the result (%d)" %
            (declared.ragged_rank, current_ragged_rank + 1))

    # Check that dtypes match up.
    if declared.dtype != current.dtype:
        raise ValueError("The declared dtype (%s) mismatches the result (%s)" %
                         (declared.dtype, current.dtype))
    if (isinstance(current, ragged_tensor.RaggedTensor)
            and declared.row_splits_dtype != current.row_splits.dtype):
        if not ragged_config.auto_cast_partition_dtype():
            raise ValueError(
                "The declared row_splits dtype (%s) mismatches the result (%s)."
                "  Use RaggedTensor.with_row_splits_dtype to convert it." %
                (declared.row_splits_dtype, current.row_splits.dtype))
        current = current.with_row_splits_dtype(declared.row_splits_dtype)

    if isinstance(current, ragged_tensor.RaggedTensor):
        return current
    else:
        nrows = array_ops.shape(current, out_type=declared.row_splits_dtype)[0]
        row_length = array_ops.expand_dims(nrows, axis=0)
        return _RaggedTensorComponents(flat_values=current,
                                       nested_row_lengths=(),
                                       outer_row_length=row_length)
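A hedged illustration of the conversion recommended by the error message above, assuming TensorFlow 2.x; the tensors are made up for the example:

import tensorflow as tf

rt = tf.ragged.constant([[1, 2], [3]])             # row_splits defaults to int64
rt32 = rt.with_row_splits_dtype(tf.int32)          # explicit conversion, as the error suggests
print(rt.row_splits.dtype, rt32.row_splits.dtype)  # <dtype: 'int64'> <dtype: 'int32'>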
Example #3
  def __init__(self, partitioned_dim_sizes, inner_dim_sizes,
               dim_size_dtype=None):
    """Creates a RaggedTensorDynamicShape.

    Args:
      partitioned_dim_sizes: A `list` of 0-D or 1-D integer `Tensor`, one for
        each partitioned dimension.  If dimension `d` is uniform, then
        `partitioned_dim_sizes[d]` must be an integer scalar, specifying the
        size of all slices across dimension `d`.  If dimension `d` is ragged,
        then `partitioned_dim_sizes[d]` must be an integer vector, specifying
        the size of each slice across dimension `d`.
      inner_dim_sizes: A 1-D integer `Tensor`, whose length is equal to the
        number of inner dimensions.  `inner_dim_sizes[n]` is the size of all
        slices across the `n`th inner dimension (which is the
        `(len(partitioned_dim_sizes)+n)`th dimension in the overall tensor).
      dim_size_dtype: dtype for dimension sizes.  If not specified, then it
        is chosen based on the dtypes of `partitioned_dim_sizes` and
        `inner_dim_sizes`.
    """
    assert isinstance(partitioned_dim_sizes, (list, tuple))

    with ops.name_scope(None, 'RaggedTensorDynamicShape',
                        (partitioned_dim_sizes, inner_dim_sizes)):
      partitioned_dim_sizes = tuple(
          ops.convert_to_tensor(size, name='partitioned_dimension_size_%d' % i)
          for (i, size) in enumerate(partitioned_dim_sizes))
      inner_dim_sizes = ops.convert_to_tensor(
          inner_dim_sizes, name='inner_dim_sizes')

      # Validate shapes.
      if partitioned_dim_sizes:
        for axis, dimension_size in enumerate(partitioned_dim_sizes):
          if dimension_size.shape.ndims is None:
            raise ValueError(
                'rank of partitioned_dim_sizes[%d] is unknown' % axis)
          dimension_size.shape.with_rank_at_most(1)
        if partitioned_dim_sizes[0].shape.ndims == 1:
          raise ValueError('outermost partitioned dimension must be uniform')
        if partitioned_dim_sizes[-1].shape.ndims == 0:
          raise ValueError('innermost partitioned dimension must be ragged')
      inner_dim_sizes.shape.assert_has_rank(1)

      # Convert dimension size tensors to a single dtype.
      if dim_size_dtype is None:
        dim_size_dtypes = set([p.dtype for p in partitioned_dim_sizes
                               if p.shape.ndims == 1])
        if not dim_size_dtypes:
          dim_size_dtype = dtypes.int64
        elif len(dim_size_dtypes) == 1:
          dim_size_dtype = dim_size_dtypes.pop()
        else:
          if not ragged_config.auto_cast_partition_dtype():
            raise ValueError('partitioned_dim_sizes must have matching dtypes')
          dim_size_dtype = dtypes.int64
      partitioned_dim_sizes = tuple(math_ops.cast(p, dim_size_dtype)
                                    for p in partitioned_dim_sizes)
      inner_dim_sizes = math_ops.cast(inner_dim_sizes, dim_size_dtype)

      self._partitioned_dim_sizes = partitioned_dim_sizes
      self._inner_dim_sizes = inner_dim_sizes
Example #4
def _convert_declared_ragged(current, declared):
  """Converts an output with RaggedTensorType into a _RaggedTensorComponents."""
  # Check that the ragged ranks match up.
  # + 1 to account for the rank of the outermost dimension.
  current_ragged_rank = getattr(current, "ragged_rank", 0)
  if declared.ragged_rank != current_ragged_rank + 1:
    raise ValueError(
        "The declared ragged rank (%d) mismatches the result (%d)" %
        (declared.ragged_rank, current_ragged_rank + 1))

  # Check that dtypes match up.
  if declared.dtype != current.dtype:
    raise ValueError(
        "The declared dtype (%s) mismatches the result (%s)" %
        (declared.dtype, current.dtype))
  if (isinstance(current, ragged_tensor.RaggedTensor) and
      declared.row_splits_dtype != current.row_splits.dtype):
    if not ragged_config.auto_cast_partition_dtype():
      raise ValueError(
          "The declared row_splits dtype (%s) mismatches the result (%s)."
          "  Use RaggedTensor.with_row_splits_dtype to convert it."
          % (declared.row_splits_dtype, current.row_splits.dtype))
    current = current.with_row_splits_dtype(declared.row_splits_dtype)

  if isinstance(current, ragged_tensor.RaggedTensor):
    return current
  else:
    nrows = array_ops.shape(current, out_type=declared.row_splits_dtype)[0]
    row_length = array_ops.expand_dims(nrows, axis=0)
    return _RaggedTensorComponents(
        flat_values=current,
        nested_row_lengths=(),
        outer_row_length=row_length)
Example #5
def _broadcast_to_ragged_shape(rt_input, dst_shape, broadcast_inner_dimensions):
  """Broadcasts rt_input to the ragged shape `dst_shape`."""
  # Check that rt_input and dst_shape have the same row_splits dtype.
  if (isinstance(rt_input, ragged_tensor.RaggedTensor) and
      rt_input.row_splits.dtype != dst_shape.dim_size_dtype):
    if not ragged_config.auto_cast_partition_dtype():
      raise ValueError('rt_input and dst_shape have different row_split '
                       'dtypes; use RaggedTensor.with_row_splits_dtype() or '
                       'RaggedTensorDynamicShape.with_dim_size_dtype() to '
                       'convert to a compatible dtype.')
    rt_input = rt_input.with_row_splits_dtype(dtypes.int64)
    dst_shape = dst_shape.with_dim_size_dtype(dtypes.int64)

  # dst_shape's rank and ragged_rank must be greater than or equal to rt_input's
  if rt_input.shape.ndims is None or dst_shape.rank is None:
    raise ValueError('Unable to broadcast: unknown rank')
  if rt_input.shape.ndims > dst_shape.rank:
    raise ValueError('Incompatible with shape: rank mismatch')
  if (isinstance(rt_input, ragged_tensor.RaggedTensor) and
      rt_input.ragged_rank >= dst_shape.num_partitioned_dimensions):
    raise ValueError('Incompatible with shape: ragged rank mismatch')

  src_shape = RaggedTensorDynamicShape.from_tensor(rt_input)
  src_shape = src_shape.broadcast_to_rank(dst_shape.rank)

  # Add dimensions to rt_input so its rank and ragged_rank matches dst_shape.
  if dst_shape.rank > rt_input.shape.ndims:
    if rt_input.shape.ndims < dst_shape.num_inner_dimensions + 1:
      rt_input = array_ops.reshape(
          rt_input, array_ops.concat([[-1], dst_shape.inner_dim_sizes], axis=0))
    for _ in range(dst_shape.rank - rt_input.shape.ndims):
      if ragged_tensor.is_ragged(rt_input):
        nrows = rt_input.nrows()
      else:
        nrows = array_ops.shape(rt_input,
                                out_type=dst_shape.dim_size_dtype)[0]
      rt_input = ragged_tensor.RaggedTensor.from_row_lengths(rt_input, [nrows],
                                                             validate=False)

  # Add ragged dimensions to match dst_shape.
  if ragged_tensor.is_ragged(rt_input):
    inner_rank_diff = (
        rt_input.flat_values.shape.ndims - 1 - dst_shape.num_inner_dimensions)
    if inner_rank_diff > 0:
      rt_input = rt_input.with_flat_values(
          ragged_tensor.RaggedTensor.from_tensor(
              rt_input.flat_values, ragged_rank=inner_rank_diff,
              row_splits_dtype=dst_shape.dim_size_dtype))
  else:
    rt_input = ragged_tensor.RaggedTensor.from_tensor(
        rt_input, ragged_rank=dst_shape.num_partitioned_dimensions - 1,
        row_splits_dtype=dst_shape.dim_size_dtype)

  # Do broadcasting for any dimensions that will remain uniform.  We can do
  # these all at once, since they're independent of one another.
  multiples = [1] * dst_shape.rank
  for axis in range(dst_shape.num_partitioned_dimensions):
    if not src_shape.is_ragged(axis) and not dst_shape.is_ragged(axis):
      src_size = src_shape.dimension_size(axis)
      dst_size = dst_shape.dimension_size(axis)
      if ((tensor_util.constant_value(src_size) in (1, None)) and
          (tensor_util.constant_value(dst_size) != 1)):
        multiples[axis] = array_ops.where(
            math_ops.equal(src_size, 1), dst_size, 1)
  if not all(isinstance(v, int) and v == 1 for v in multiples):
    multiples = array_ops.stack(multiples, axis=0)
    rt_input = ragged_array_ops.tile(rt_input, multiples)

  if broadcast_inner_dimensions:
    rt_input = rt_input.with_flat_values(
        array_ops.reshape(
            rt_input.flat_values,
            array_ops.concat([[-1], dst_shape.inner_dim_sizes], axis=0)))

  # Do broadcasting for dimensions that become ragged.  We must do these from
  # outermost to innermost.
  for axis in range(dst_shape.num_partitioned_dimensions):
    if not src_shape.is_ragged(axis) and dst_shape.is_ragged(axis):
      dst_size = dst_shape.dimension_size(axis)
      rt_input = _ragged_tile_axis(rt_input, axis, dst_size,
                                   dst_shape.dim_size_dtype)

  return rt_input
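The user-visible effect of this helper is ragged broadcasting in elementwise ops. A minimal sketch, assuming TensorFlow 2.x, where a dense size-1 dimension is broadcast across ragged rows:

import tensorflow as tf

rt = tf.ragged.constant([[1, 2, 3], [4]])   # ragged shape [2, None]
bias = tf.constant([[10], [20]])            # dense shape [2, 1]
# The size-1 dimension of `bias` is broadcast across each ragged row.
print((rt + bias).to_list())                # [[11, 12, 13], [24]]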
Example #6
def _broadcast_to_ragged_shape(rt_input, dst_shape,
                               broadcast_inner_dimensions):
    """Broadcasts rt_input to the ragged shape `dst_shape`."""
    # Check that rt_input and dst_shape have the same row_splits dtype.
    if (isinstance(rt_input, ragged_tensor.RaggedTensor)
            and rt_input.row_splits.dtype != dst_shape.dim_size_dtype):
        if not ragged_config.auto_cast_partition_dtype():
            raise ValueError(
                'rt_input and dst_shape have different row_split '
                'dtypes; use RaggedTensor.with_row_splits_dtype() or '
                'RaggedTensorDynamicShape.with_dim_size_dtype() to '
                'convert to a compatible dtype.')
        rt_input = rt_input.with_row_splits_dtype(dtypes.int64)
        dst_shape = dst_shape.with_dim_size_dtype(dtypes.int64)

    # dst_shape's rank and ragged_rank must be greater than or equal to rt_input's
    if rt_input.shape.ndims is None or dst_shape.rank is None:
        raise ValueError('Unable to broadcast: unknown rank')
    if rt_input.shape.ndims > dst_shape.rank:
        raise ValueError('Incompatible with shape: rank mismatch')
    if (isinstance(rt_input, ragged_tensor.RaggedTensor)
            and rt_input.ragged_rank >= dst_shape.num_partitioned_dimensions):
        raise ValueError('Incompatible with shape: ragged rank mismatch')

    src_shape = RaggedTensorDynamicShape.from_tensor(rt_input)
    src_shape = src_shape.broadcast_to_rank(dst_shape.rank)

    # Add dimensions to rt_input so its rank and ragged_rank matches dst_shape.
    if dst_shape.rank > rt_input.shape.ndims:
        if rt_input.shape.ndims < dst_shape.num_inner_dimensions + 1:
            rt_input = array_ops.reshape(
                rt_input,
                array_ops.concat([[-1], dst_shape.inner_dim_sizes], axis=0))
        for _ in range(dst_shape.rank - rt_input.shape.ndims):
            if ragged_tensor.is_ragged(rt_input):
                nrows = rt_input.nrows()
            else:
                nrows = array_ops.shape(rt_input,
                                        out_type=dst_shape.dim_size_dtype)[0]
            rt_input = ragged_tensor.RaggedTensor.from_row_lengths(
                rt_input, [nrows], validate=False)

    # Add ragged dimensions to match dst_shape.
    if ragged_tensor.is_ragged(rt_input):
        inner_rank_diff = (rt_input.flat_values.shape.ndims - 1 -
                           dst_shape.num_inner_dimensions)
        if inner_rank_diff > 0:
            rt_input = rt_input.with_flat_values(
                ragged_tensor.RaggedTensor.from_tensor(
                    rt_input.flat_values,
                    ragged_rank=inner_rank_diff,
                    row_splits_dtype=dst_shape.dim_size_dtype))
    else:
        rt_input = ragged_tensor.RaggedTensor.from_tensor(
            rt_input,
            ragged_rank=dst_shape.num_partitioned_dimensions - 1,
            row_splits_dtype=dst_shape.dim_size_dtype)

    # Do broadcasting for any dimensions that will remain uniform.  We can do
    # these all at once, since they're independent of one another.
    multiples = [1] * dst_shape.rank
    for axis in range(dst_shape.num_partitioned_dimensions):
        if not src_shape.is_ragged(axis) and not dst_shape.is_ragged(axis):
            src_size = src_shape.dimension_size(axis)
            dst_size = dst_shape.dimension_size(axis)
            if ((tensor_util.constant_value(src_size) in (1, None))
                    and (tensor_util.constant_value(dst_size) != 1)):
                multiples[axis] = array_ops.where(math_ops.equal(src_size, 1),
                                                  dst_size, 1)
    if not all(isinstance(v, int) and v == 1 for v in multiples):
        multiples = array_ops.stack(multiples, axis=0)
        rt_input = ragged_array_ops.tile(rt_input, multiples)

    if broadcast_inner_dimensions:
        rt_input = rt_input.with_flat_values(
            array_ops.reshape(
                rt_input.flat_values,
                array_ops.concat([[-1], dst_shape.inner_dim_sizes], axis=0)))

    # Do broadcasting for dimensions that become ragged.  We must do these from
    # outermost to innermost.
    for axis in range(dst_shape.num_partitioned_dimensions):
        if not src_shape.is_ragged(axis) and dst_shape.is_ragged(axis):
            dst_size = dst_shape.dimension_size(axis)
            rt_input = _ragged_tile_axis(rt_input, axis, dst_size,
                                         dst_shape.dim_size_dtype)

    return rt_input
Example #7
def map_flat_values(op, *args, **kwargs):
    """Applies `op` to the values of one or more RaggedTensors.

  Replaces any `RaggedTensor` in `args` or `kwargs` with its `flat_values`
  tensor, and then calls `op`.  Returns a `RaggedTensor` that is constructed
  from the input `RaggedTensor`s' `nested_row_splits` and the value returned by
  the `op`.

  If the input arguments contain multiple `RaggedTensor`s, then they must have
  identical `nested_row_splits`.

  Examples:

  ```python
  >>> rt = ragged.constant([[1, 2, 3], [], [4, 5], [6]])
  >>> ragged.map_flat_values(tf.ones_like, rt).eval().tolist()
  [[1, 1, 1], [], [1, 1], [1]]
  >>> ragged.map_flat_values(tf.multiply, rt, rt).eval().tolist()
  [[1, 4, 9], [], [16, 25], [36]]
  >>> ragged.map_flat_values(tf.add, rt, 5).eval().tolist()
  [[6, 7, 8], [], [9, 10], [11]]
  ```

  Args:
    op: The operation that should be applied to the RaggedTensor `flat_values`.
      `op` is typically an element-wise operation (such as math_ops.add), but
      any operation that preserves the size of the outermost dimension can be
      used.  I.e., `shape[0]` of the value returned by `op` must match
      `shape[0]` of the `RaggedTensor`s' `flat_values` tensors.
    *args: Arguments for `op`.
    **kwargs: Keyword arguments for `op`.

  Returns:
    A `RaggedTensor` whose `ragged_rank` matches the `ragged_rank` of all
    input `RaggedTensor`s.
  Raises:
    ValueError: If args contains no `RaggedTensors`, or if the `nested_splits`
      of the input `RaggedTensor`s are not identical.
  """
    # Replace RaggedTensors with their values; and collect the splits tensors
    # from each RaggedTensor.
    nested_splits_lists = []
    inner_args = _replace_ragged_with_flat_values(args, nested_splits_lists)
    inner_kwargs = _replace_ragged_with_flat_values(kwargs,
                                                    nested_splits_lists)
    if not nested_splits_lists:
        return op(*args, **kwargs)

    split_dtypes = set(splits[0].dtype for splits in nested_splits_lists)
    if len(split_dtypes) > 1:
        if not ragged_config.auto_cast_partition_dtype():
            raise ValueError(
                "Input RaggedTensors have mismatched row_splits dtypes; "
                "use RaggedTensor.with_row_splits_dtype() to convert "
                "them to compatible dtypes.")

        nested_splits_lists = [
            [math_ops.cast(s, dtypes.int64) for s in nested_splits]  # pylint: disable=g-complex-comprehension
            for nested_splits in nested_splits_lists
        ]

    with ops.control_dependencies(
            ragged_util.assert_splits_match(nested_splits_lists)):
        # Delegate to op, and then compose the result from the transformed values
        # and the splits.
        return ragged_tensor.RaggedTensor.from_nested_row_splits(
            op(*inner_args, **inner_kwargs),
            nested_splits_lists[0],
            validate=False)
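A sketch of the composition step at the end of this function: apply `op` to `flat_values`, then rebuild the result from the input's row splits. Assumes TensorFlow 2.x; `doubled` is an illustrative name:

import tensorflow as tf

rt = tf.ragged.constant([[1, 2, 3], [], [4, 5]])
doubled = tf.RaggedTensor.from_nested_row_splits(
    tf.multiply(rt.flat_values, 2),   # op applied to the flattened values
    rt.nested_row_splits)             # reuse the input's partitioning
print(doubled.to_list())              # [[2, 4, 6], [], [8, 10]]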
Example #8
def map_flat_values(op, *args, **kwargs):
    """Applies `op` to the `flat_values` of one or more RaggedTensors.

  Replaces any `RaggedTensor` in `args` or `kwargs` with its `flat_values`
  tensor (which collapses all ragged dimensions), and then calls `op`.  Returns
  a `RaggedTensor` that is constructed from the input `RaggedTensor`s'
  `nested_row_splits` and the value returned by the `op`.

  If the input arguments contain multiple `RaggedTensor`s, then they must have
  identical `nested_row_splits`.

  This operation is generally used to apply elementwise operations to each value
  in a `RaggedTensor`.

  Warning: `tf.ragged.map_flat_values` does *not* apply `op` to each row of a
  ragged tensor.  This difference is important for non-elementwise operations,
  such as `tf.reduce_sum`.  If you wish to apply a non-elementwise operation to
  each row of a ragged tensor, use `tf.map_fn` instead.  (You may need to
  specify an `output_signature` when using `tf.map_fn` with ragged tensors.)

  Examples:

  >>> rt = tf.ragged.constant([[1, 2, 3], [], [4, 5], [6]])
  >>> tf.ragged.map_flat_values(tf.ones_like, rt)
  <tf.RaggedTensor [[1, 1, 1], [], [1, 1], [1]]>
  >>> tf.ragged.map_flat_values(tf.multiply, rt, rt)
  <tf.RaggedTensor [[1, 4, 9], [], [16, 25], [36]]>
  >>> tf.ragged.map_flat_values(tf.add, rt, 5)
  <tf.RaggedTensor [[6, 7, 8], [], [9, 10], [11]]>

  Example with a non-elementwise operation (note that `map_flat_values` and
  `map_fn` return different results):

  >>> rt = tf.ragged.constant([[1.0, 3.0], [], [3.0, 6.0, 3.0]])
  >>> def normalized(x):
  ...   return x / tf.reduce_sum(x)
  >>> tf.ragged.map_flat_values(normalized, rt)
  <tf.RaggedTensor [[0.0625, 0.1875], [], [0.1875, 0.375, 0.1875]]>
  >>> tf.map_fn(normalized, rt)
  <tf.RaggedTensor [[0.25, 0.75], [], [0.25, 0.5, 0.25]]>

  Args:
    op: The operation that should be applied to the RaggedTensor `flat_values`.
      `op` is typically an element-wise operation (such as math_ops.add), but
      any operation that preserves the size of the outermost dimension can be
      used.  I.e., `shape[0]` of the value returned by `op` must match
      `shape[0]` of the `RaggedTensor`s' `flat_values` tensors.
    *args: Arguments for `op`.
    **kwargs: Keyword arguments for `op`.

  Returns:
    A `RaggedTensor` whose `ragged_rank` matches the `ragged_rank` of all
    input `RaggedTensor`s.
  Raises:
    ValueError: If args contains no `RaggedTensors`, or if the `nested_splits`
      of the input `RaggedTensor`s are not identical.
  """
    # Replace RaggedTensors with their values; and collect the partition tensors
    # from each RaggedTensor.
    partition_lists = []
    flat_values_nrows = []
    inner_args = _replace_ragged_with_flat_values(args, partition_lists,
                                                  flat_values_nrows)
    inner_kwargs = _replace_ragged_with_flat_values(kwargs, partition_lists,
                                                    flat_values_nrows)
    if not partition_lists:
        return op(*args, **kwargs)

    # If we can statically determine that the inputs are incompatible, then raise
    # an error.  (We can't guarantee full compatibility statically, so we need to
    # perform some runtime checks too; but this allows us to fail sooner in some
    # cases.)
    if flat_values_nrows:
        flat_values_nrows = set(flat_values_nrows)
        if len(flat_values_nrows) != 1:
            raise ValueError(
                "Input RaggedTensors' flat_values must all have the "
                "same outer-dimension size.  Got sizes: %s" %
                flat_values_nrows)
        flat_values_nrows = flat_values_nrows.pop()  # Get the single element
    else:
        flat_values_nrows = None

    partition_dtypes = set(p[0].dtype for p in partition_lists)
    if len(partition_dtypes) > 1:
        if not ragged_config.auto_cast_partition_dtype():
            raise ValueError(
                "Input RaggedTensors have mismatched row partition "
                "dtypes; use RaggedTensor.with_row_splits_dtype() to "
                "convert them to compatible dtypes.")

        partition_lists = [
            [p.with_dtype(dtypes.int64) for p in partition_list]  # pylint: disable=g-complex-comprehension
            for partition_list in partition_lists
        ]

    # Delegate to `op`
    op_output = op(*inner_args, **inner_kwargs)
    # Check that the result has the expected shape (if known).
    if flat_values_nrows is not None:
        if not op_output.shape[:1].is_compatible_with([flat_values_nrows]):
            raise ValueError(
                "tf.ragged.map_flat_values requires that the output of `op` have "
                "the same outer-dimension size as flat_values of any ragged "
                "inputs. (output shape: %s; expected outer dimension size: %s)"
                % (op_output.shape, flat_values_nrows))
    # Compose the result from the transformed values and the partitions.
    return ragged_tensor.RaggedTensor._from_nested_row_partitions(  # pylint: disable=protected-access
        op_output,
        _merge_partition_lists(partition_lists),
        validate=False)
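A hedged sketch of the outer-dimension check added above: on TensorFlow versions that include this check, an `op` that changes the outer size of `flat_values` (such as `tf.reduce_sum`) is rejected:

import tensorflow as tf

rt = tf.ragged.constant([[1.0, 2.0], [3.0]])
try:
  # reduce_sum returns a scalar, so its shape[0] cannot match flat_values' 3 rows.
  tf.ragged.map_flat_values(tf.reduce_sum, rt)
except ValueError as err:
  print("rejected:", err)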
Example #9
def map_flat_values(op, *args, **kwargs):
  """Applies `op` to the values of one or more RaggedTensors.

  Replaces any `RaggedTensor` in `args` or `kwargs` with its `flat_values`
  tensor, and then calls `op`.  Returns a `RaggedTensor` that is constructed
  from the input `RaggedTensor`s' `nested_row_splits` and the value returned by
  the `op`.

  If the input arguments contain multiple `RaggedTensor`s, then they must have
  identical `nested_row_splits`.

  Examples:

  ```python
  >>> rt = ragged.constant([[1, 2, 3], [], [4, 5], [6]])
  >>> ragged.map_flat_values(tf.ones_like, rt).eval().tolist()
  [[1, 1, 1], [], [1, 1], [1]]
  >>> ragged.map_flat_values(tf.multiply, rt, rt).eval().tolist()
  [[1, 4, 9], [], [16, 25], [36]]
  >>> ragged.map_flat_values(tf.add, rt, 5).eval().tolist()
  [[6, 7, 8], [], [9, 10], [11]]
  ```

  Args:
    op: The operation that should be applied to the RaggedTensor `flat_values`.
      `op` is typically an element-wise operation (such as math_ops.add), but
      any operation that preserves the size of the outermost dimension can be
      used.  I.e., `shape[0]` of the value returned by `op` must match
      `shape[0]` of the `RaggedTensor`s' `flat_values` tensors.
    *args: Arguments for `op`.
    **kwargs: Keyword arguments for `op`.

  Returns:
    A `RaggedTensor` whose `ragged_rank` matches the `ragged_rank` of all
    input `RaggedTensor`s.
  Raises:
    ValueError: If args contains no `RaggedTensors`, or if the `nested_splits`
      of the input `RaggedTensor`s are not identical.
  """
  # Replace RaggedTensors with their values; and collect the splits tensors
  # from each RaggedTensor.
  nested_splits_lists = []
  inner_args = _replace_ragged_with_flat_values(args, nested_splits_lists)
  inner_kwargs = _replace_ragged_with_flat_values(kwargs, nested_splits_lists)
  if not nested_splits_lists:
    return op(*args, **kwargs)

  split_dtypes = set(splits[0].dtype for splits in nested_splits_lists)
  if len(split_dtypes) > 1:
    if not ragged_config.auto_cast_partition_dtype():
      raise ValueError("Input RaggedTensors have mismatched row_splits dtypes; "
                       "use RaggedTensor.with_row_splits_dtype() to convert "
                       "them to compatible dtypes.")

    nested_splits_lists = [
        [math_ops.cast(s, dtypes.int64) for s in nested_splits]  # pylint: disable=g-complex-comprehension
        for nested_splits in nested_splits_lists]

  with ops.control_dependencies(
      ragged_util.assert_splits_match(nested_splits_lists)):
    # Delegate to op, and then compose the result from the transformed values
    # and the splits.
    return ragged_tensor.RaggedTensor.from_nested_row_splits(
        op(*inner_args, **inner_kwargs), nested_splits_lists[0], validate=False)
Example #10
def map_flat_values(op, *args, **kwargs):
  """Applies `op` to the values of one or more RaggedTensors.

  Replaces any `RaggedTensor` in `args` or `kwargs` with its `flat_values`
  tensor, and then calls `op`.  Returns a `RaggedTensor` that is constructed
  from the input `RaggedTensor`s' `nested_row_splits` and the value returned by
  the `op`.

  If the input arguments contain multiple `RaggedTensor`s, then they must have
  identical `nested_row_splits`.

  Examples:

  >>> rt = tf.ragged.constant([[1, 2, 3], [], [4, 5], [6]])
  >>> map_flat_values(tf.ones_like, rt).to_list()
  [[1, 1, 1], [], [1, 1], [1]]
  >>> map_flat_values(tf.multiply, rt, rt).to_list()
  [[1, 4, 9], [], [16, 25], [36]]
  >>> map_flat_values(tf.add, rt, 5).to_list()
  [[6, 7, 8], [], [9, 10], [11]]

  Args:
    op: The operation that should be applied to the RaggedTensor `flat_values`.
      `op` is typically an element-wise operation (such as math_ops.add), but
      any operation that preserves the size of the outermost dimension can be
      used.  I.e., `shape[0]` of the value returned by `op` must match
      `shape[0]` of the `RaggedTensor`s' `flat_values` tensors.
    *args: Arguments for `op`.
    **kwargs: Keyword arguments for `op`.

  Returns:
    A `RaggedTensor` whose `ragged_rank` matches the `ragged_rank` of all
    input `RaggedTensor`s.
  Raises:
    ValueError: If args contains no `RaggedTensors`, or if the `nested_splits`
      of the input `RaggedTensor`s are not identical.
  """
  # Replace RaggedTensors with their values; and collect the splits tensors
  # from each RaggedTensor.
  nested_splits_lists = []
  flat_values_nrows = []
  inner_args = _replace_ragged_with_flat_values(args, nested_splits_lists,
                                                flat_values_nrows)
  inner_kwargs = _replace_ragged_with_flat_values(kwargs, nested_splits_lists,
                                                  flat_values_nrows)
  if not nested_splits_lists:
    return op(*args, **kwargs)
  if flat_values_nrows:
    flat_values_nrows = set(flat_values_nrows)
    if len(flat_values_nrows) != 1:
      raise ValueError("Input RaggedTensors' flat_values must all have the "
                       "same outer-dimension size.  Got sizes: %s" %
                       flat_values_nrows)
    flat_values_nrows = flat_values_nrows.pop()  # Get the single element
  else:
    flat_values_nrows = None

  split_dtypes = set(splits[0].dtype for splits in nested_splits_lists)
  if len(split_dtypes) > 1:
    if not ragged_config.auto_cast_partition_dtype():
      raise ValueError("Input RaggedTensors have mismatched row_splits dtypes; "
                       "use RaggedTensor.with_row_splits_dtype() to convert "
                       "them to compatible dtypes.")

    nested_splits_lists = [
        [math_ops.cast(s, dtypes.int64) for s in nested_splits]  # pylint: disable=g-complex-comprehension
        for nested_splits in nested_splits_lists]

  with ops.control_dependencies(
      ragged_util.assert_splits_match(nested_splits_lists)):
    # Delegate to `op`
    op_output = op(*inner_args, **inner_kwargs)
    # Check that the result has the expected shape (if known).
    if flat_values_nrows is not None:
      if not op_output.shape[:1].is_compatible_with([flat_values_nrows]):
        raise ValueError(
            "tf.ragged.map_flat_values requires that the output of `op` have "
            "the same outer-dimension size as flat_values of any ragged "
            "inputs. (output shape: %s; expected outer dimension size: %s)" %
            (op_output.shape, flat_values_nrows))
    # Compose the result from the transformed values and the splits.
    return ragged_tensor.RaggedTensor.from_nested_row_splits(
        op_output, nested_splits_lists[0], validate=False)
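A sketch of the dtype check above, assuming TensorFlow 2.x: two inputs with different `row_splits` dtypes must be reconciled (for example via `with_row_splits_dtype`) unless auto-casting is enabled:

import tensorflow as tf

a = tf.ragged.constant([[1, 2], [3]], row_splits_dtype=tf.int32)
b = tf.ragged.constant([[10, 20], [30]], row_splits_dtype=tf.int64)
# Convert one input so both partitions use the same dtype before mapping.
b = b.with_row_splits_dtype(tf.int32)
print(tf.ragged.map_flat_values(tf.add, a, b).to_list())   # [[11, 22], [33]]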