Example 1
def where_v2(condition, x=None, y=None, name=None):
    """Return the elements where `condition` is `True`.

  : If both `x` and `y` are None: Retrieve indices of true elements.

    Returns the coordinates of true elements of `condition`. The coordinates
    are returned in a 2-D tensor with shape
    `[num_true_values, dim_size(condition)]`, where `result[i]` is the
    coordinates of the `i`th true value (in row-major order).

  : If both `x` and `y` are non-`None`: Multiplex between `x` and `y`.

    Choose an output shape from the shapes of `condition`, `x`, and `y` that
    all three shapes are broadcastable to; and then use the broadcasted
    `condition` tensor as a mask that chooses whether the corresponding element
    in the output should be taken from `x` (if `condition` is true) or `y` (if
    `condition` is false).

  >>> # Example: retrieve indices of true elements
  >>> tf.where(tf.ragged.constant([[True, False], [True]]))
  <tf.Tensor: shape=(2, 2), dtype=int64, numpy=array([[0, 0], [1, 0]])>

  >>> # Example: multiplex between `x` and `y`
  >>> tf.where(tf.ragged.constant([[True, False], [True, False, True]]),
  ...          tf.ragged.constant([['A', 'B'], ['C', 'D', 'E']]),
  ...          tf.ragged.constant([['a', 'b'], ['c', 'd', 'e']]))
  <tf.RaggedTensor [[b'A', b'b'], [b'C', b'd', b'E']]>

  Args:
    condition: A potentially ragged tensor of type `bool`
    x: A potentially ragged tensor (optional).
    y: A potentially ragged tensor (optional).  Must be specified if `x` is
      specified.  Must have the same rank and type as `x`.
    name: A name for the operation (optional).

  Returns:
    : If both `x` and `y` are `None`:
      A `Tensor` with shape `(num_true, rank(condition))`.
    : Otherwise:
      A potentially ragged tensor with the same type as `x` and `y`, and whose
      shape is broadcast-compatible with `x`, `y`, and `condition`.

  Raises:
    ValueError: When exactly one of `x` or `y` is non-`None`; or when
      `condition`, `x`, and `y` have incompatible shapes.
  """
    if (x is None) != (y is None):
        raise ValueError('x and y must be either both None or both non-None')

    with ops.name_scope('RaggedWhere', name, [condition, x, y]):
        condition = ragged_tensor.convert_to_tensor_or_ragged_tensor(
            condition, name='condition')
        if x is None:
            return _coordinate_where(condition)
        else:
            x = ragged_tensor.convert_to_tensor_or_ragged_tensor(x, name='x')
            y = ragged_tensor.convert_to_tensor_or_ragged_tensor(y, name='y')
            condition, x, y = ragged_tensor.match_row_splits_dtypes(
                condition, x, y)
            return _elementwise_where_v2(condition, x, y)
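As a quick usage sketch (assuming TensorFlow 2.x, where `tf.where` dispatches to this ragged handler), both modes can be exercised directly on `tf.ragged.constant` inputs:

```python
import tensorflow as tf

# Coordinate mode: with x and y omitted, tf.where returns the indices
# of the True elements as a dense [num_true, 2] tensor.
condition = tf.ragged.constant([[True, False], [True]])
print(tf.where(condition))  # [[0 0] [1 0]]

# Multiplex mode: pick from x where condition is True, else from y.
cond = tf.ragged.constant([[True, False], [True, False, True]])
x = tf.ragged.constant([['A', 'B'], ['C', 'D', 'E']])
y = tf.ragged.constant([['a', 'b'], ['c', 'd', 'e']])
print(tf.where(cond, x, y))  # [[b'A', b'b'], [b'C', b'd', b'E']]
```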
Example 2
    def handle(self, args, kwargs):
        # Extract the binary args.
        if len(args) > 1:
            x = args[0]
            y = args[1]
            args = args[2:]
        elif args:
            kwargs = kwargs.copy()
            x = args[0]
            y = kwargs.pop(self._y, None)
            args = args[1:]
        else:
            kwargs = kwargs.copy()
            x = kwargs.pop(self._x, None)
            y = kwargs.pop(self._y, None)

        # Bail if we don't have at least one ragged argument.
        x_is_ragged = ragged_tensor.is_ragged(x)
        y_is_ragged = ragged_tensor.is_ragged(y)
        if not (x_is_ragged or y_is_ragged):
            return self.NOT_SUPPORTED

        # Convert args to tensors.  Bail if conversion fails.
        try:
            x = ragged_tensor.convert_to_tensor_or_ragged_tensor(
                x,
                name=self._x,
                preferred_dtype=(y.dtype if y_is_ragged else None))
            y = ragged_tensor.convert_to_tensor_or_ragged_tensor(
                y,
                name=self._y,
                preferred_dtype=(x.dtype if x_is_ragged else None))
        except (TypeError, ValueError):
            return self.NOT_SUPPORTED

        if x_is_ragged and y_is_ragged:
            x, y = ragged_tensor.match_row_splits_dtypes(x, y)

        if ((x_is_ragged and y_is_ragged)
                or (x_is_ragged and x.flat_values.shape.ndims <= y.shape.ndims)
                or
            (y_is_ragged and y.flat_values.shape.ndims <= x.shape.ndims)):
            bcast_shape = ragged_tensor_shape.broadcast_dynamic_shape(
                ragged_tensor_shape.RaggedTensorDynamicShape.from_tensor(x),
                ragged_tensor_shape.RaggedTensorDynamicShape.from_tensor(y))
            x = ragged_tensor_shape.broadcast_to(
                x, bcast_shape, broadcast_inner_dimensions=False)
            y = ragged_tensor_shape.broadcast_to(
                y, bcast_shape, broadcast_inner_dimensions=False)

        x_values = x.flat_values if ragged_tensor.is_ragged(x) else x
        y_values = y.flat_values if ragged_tensor.is_ragged(y) else y
        mapped_values = self._original_op(x_values, y_values, *args, **kwargs)
        if ragged_tensor.is_ragged(x):
            return x.with_flat_values(mapped_values)
        else:
            return y.with_flat_values(mapped_values)
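For a sense of the broadcast branch above, here is a small sketch (assuming TF 2.x, where ragged dispatch is active for `tf.add`): a scalar operand skips broadcasting entirely, while a dense `[3, 1]` operand takes the `broadcast_to` path because its rank exceeds that of the ragged operand's `flat_values`.

```python
import tensorflow as tf

rt = tf.ragged.constant([[1, 2, 3], [4], [5, 6]])

# Scalar operand: applied straight to flat_values, no shape broadcasting.
print(tf.add(rt, 10))  # [[11, 12, 13], [14], [15, 16]]

# Dense [3, 1] operand: broadcast row-wise against the ragged rows.
col = tf.constant([[10], [20], [30]])
print(tf.add(rt, col))  # [[11, 12, 13], [24], [35, 36]]
```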
Example 3
  def handle(self, args, kwargs):
    # Extract the binary args.
    if len(args) > 1:
      x = args[0]
      y = args[1]
      args = args[2:]
    elif args:
      kwargs = kwargs.copy()
      x = args[0]
      y = kwargs.pop(self._y, None)
      args = args[1:]
    else:
      kwargs = kwargs.copy()
      x = kwargs.pop(self._x, None)
      y = kwargs.pop(self._y, None)

    # Bail if we don't have at least one ragged argument.
    x_is_ragged = ragged_tensor.is_ragged(x)
    y_is_ragged = ragged_tensor.is_ragged(y)
    if not (x_is_ragged or y_is_ragged):
      return self.NOT_SUPPORTED

    # Convert args to tensors.  Bail if conversion fails.
    try:
      if not x_is_ragged:
        x = ops.convert_to_tensor(x, name=self._x, preferred_dtype=y.dtype)
      if not y_is_ragged:
        y = ops.convert_to_tensor(y, name=self._y, preferred_dtype=x.dtype)
    except (TypeError, ValueError):
      return self.NOT_SUPPORTED

    if x_is_ragged and y_is_ragged:
      x, y = ragged_tensor.match_row_splits_dtypes(x, y)

    if ((x_is_ragged and y_is_ragged) or
        (x_is_ragged and x.flat_values.shape.ndims <= y.shape.ndims) or
        (y_is_ragged and y.flat_values.shape.ndims <= x.shape.ndims)):
      bcast_shape = ragged_tensor_shape.broadcast_dynamic_shape(
          ragged_tensor_shape.RaggedTensorDynamicShape.from_tensor(x),
          ragged_tensor_shape.RaggedTensorDynamicShape.from_tensor(y))
      x = ragged_tensor_shape.broadcast_to(
          x, bcast_shape, broadcast_inner_dimensions=False)
      y = ragged_tensor_shape.broadcast_to(
          y, bcast_shape, broadcast_inner_dimensions=False)

    x_values = x.flat_values if ragged_tensor.is_ragged(x) else x
    y_values = y.flat_values if ragged_tensor.is_ragged(y) else y
    mapped_values = self._original_op(x_values, y_values, *args, **kwargs)
    if ragged_tensor.is_ragged(x):
      return x.with_flat_values(mapped_values)
    else:
      return y.with_flat_values(mapped_values)
Example 4
 def handle(self, args, kwargs):
     if args:
         x, args = args[0], args[1:]
     else:
         kwargs = kwargs.copy()
         x = kwargs.pop(self._x, None)
     if x is None:
         return self.NOT_SUPPORTED
     if self._arg_is_list:
         found_ragged = False
         for elt in x:
             if ragged_tensor.is_ragged(elt):
                 found_ragged = True
             elif not _is_convertible_to_tensor(elt):
                 return self.NOT_SUPPORTED
         if found_ragged:
             x = [
                 ragged_tensor.convert_to_tensor_or_ragged_tensor(elt)
                 if ragged_tensor.is_ragged(elt) else elt for elt in x
             ]
             x = ragged_tensor.match_row_splits_dtypes(*x)
             ragged_elts = [
                 elt for elt in x if ragged_tensor.is_ragged(elt)
             ]
             nested_splits_lists = [
                 elt.nested_row_splits for elt in ragged_elts
             ]
             flat_values = [
                 elt.flat_values if ragged_tensor.is_ragged(elt) else elt
                 for elt in x
             ]
             with ops.control_dependencies(
                     ragged_util.assert_splits_match(nested_splits_lists)):
                 return ragged_elts[0].with_flat_values(
                     self._original_op(flat_values, *args, **kwargs))
         else:
             return self.NOT_SUPPORTED
     else:
         found_ragged = ragged_tensor.is_ragged(x)
         if found_ragged:
             x = ragged_tensor.convert_to_tensor_or_ragged_tensor(
                 x, name=self._x)
             mapped_values = self._original_op(x.flat_values, *args,
                                               **kwargs)
             return x.with_flat_values(mapped_values)
         else:
             return self.NOT_SUPPORTED
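The non-list branch boils down to the `flat_values` round trip, which can be reproduced by hand (a sketch assuming TF 2.x, where unary elementwise ops such as `tf.abs` dispatch through a handler like this):

```python
import tensorflow as tf

rt = tf.ragged.constant([[1, -2], [-3, 4, 5]])

# Dispatched form: the handler applies the op to flat_values and keeps
# the row partitions unchanged.
print(tf.abs(rt))  # [[1, 2], [3, 4, 5]]

# Equivalent manual round trip:
print(rt.with_flat_values(tf.abs(rt.flat_values)))
```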
Example 5
def ragged_binary_elementwise_op(op, x, y):
    """Binary elementwise api handler for RaggedTensors."""
    x_is_ragged = ragged_tensor.is_ragged(x)
    y_is_ragged = ragged_tensor.is_ragged(y)

    # Convert args to tensors.
    x = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        x, preferred_dtype=(y.dtype if y_is_ragged else None))
    y = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        y, preferred_dtype=x.dtype)

    if x_is_ragged and y_is_ragged:
        x, y = ragged_tensor.match_row_splits_dtypes(x, y)

    # Perform broadcasting, when appropriate.
    if ((x_is_ragged and y_is_ragged)
            or (x_is_ragged and x.flat_values.shape.ndims <= y.shape.ndims)
            or (y_is_ragged and y.flat_values.shape.ndims <= x.shape.ndims)):
        # If both x and y are ragged, they must have the same row_splits_dtype now.
        if x_is_ragged:
            dim_size_dtype = x.row_splits.dtype
        else:
            dim_size_dtype = y.row_splits.dtype

        shape_x = ragged_tensor_shape.RaggedTensorDynamicShape.from_tensor(
            x, dim_size_dtype=dim_size_dtype)
        shape_y = ragged_tensor_shape.RaggedTensorDynamicShape.from_tensor(
            y, dim_size_dtype=dim_size_dtype)
        bcast_shape = ragged_tensor_shape.broadcast_dynamic_shape(
            shape_x, shape_y)
        x = ragged_tensor_shape.broadcast_to(x,
                                             bcast_shape,
                                             broadcast_inner_dimensions=False)
        y = ragged_tensor_shape.broadcast_to(y,
                                             bcast_shape,
                                             broadcast_inner_dimensions=False)

    x_values = x.flat_values if ragged_tensor.is_ragged(x) else x
    y_values = y.flat_values if ragged_tensor.is_ragged(y) else y
    mapped_values = op(x_values, y_values)
    if isinstance(mapped_values, bool):
        return mapped_values  # Special case for tensor_equals.
    if ragged_tensor.is_ragged(x):
        return x.with_flat_values(mapped_values)
    else:
        return y.with_flat_values(mapped_values)
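When both operands are ragged with matching row splits, the tail of this function reduces to mapping the op over `flat_values` and reattaching one operand's partitions. A minimal sketch (assuming TF 2.x):

```python
import tensorflow as tf

x = tf.ragged.constant([[1.0, 2.0], [3.0]])
y = tf.ragged.constant([[10.0, 20.0], [30.0]])

# Manual equivalent of the final lines above:
print(x.with_flat_values(tf.multiply(x.flat_values, y.flat_values)))
# Dispatched form, which routes through a handler like this one:
print(tf.multiply(x, y))  # [[10.0, 40.0], [90.0]]
```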
Example 6
 def handle(self, args, kwargs):
   if args:
     x, args = args[0], args[1:]
   else:
     kwargs = kwargs.copy()
     x = kwargs.pop(self._x, None)
   if x is None:
     return self.NOT_SUPPORTED
   if self._arg_is_list:
     found_ragged = False
     for elt in x:
       if ragged_tensor.is_ragged(elt):
         found_ragged = True
       elif not _is_convertible_to_tensor(elt):
         return self.NOT_SUPPORTED
     if found_ragged:
       x = ragged_tensor.match_row_splits_dtypes(*x)
       nested_splits_lists = [
           elt.nested_row_splits for elt in x if ragged_tensor.is_ragged(elt)
       ]
       flat_values = [
           elt.flat_values if ragged_tensor.is_ragged(elt) else elt
           for elt in x
       ]
       with ops.control_dependencies(
           ragged_util.assert_splits_match(nested_splits_lists)):
         return ragged_tensor.RaggedTensor.from_nested_row_splits(
             self._original_op(flat_values, *args, **kwargs),
             nested_splits_lists[0], validate=False)
     else:
       return self.NOT_SUPPORTED
   else:
     found_ragged = ragged_tensor.is_ragged(x)
     if found_ragged:
       mapped_values = self._original_op(x.flat_values, *args, **kwargs)
       return x.with_flat_values(mapped_values)
     else:
       return self.NOT_SUPPORTED
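The list branch is what backs ops like `tf.add_n` on ragged inputs: all ragged elements must share row splits, which the `assert_splits_match` control dependency enforces before the flat values are combined. A usage sketch (assuming `tf.add_n` is registered with this dispatcher, as in recent TF releases):

```python
import tensorflow as tf

a = tf.ragged.constant([[1, 2], [3]])
b = tf.ragged.constant([[10, 20], [30]])
c = tf.ragged.constant([[100, 200], [300]])

# All inputs share row_splits [0, 2, 3], so the flat values are summed
# and the common partitions are reused.
print(tf.add_n([a, b, c]))  # [[111, 222], [333]]
```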
def _ragged_segment_aggregate(unsorted_segment_op,
                              data,
                              segment_ids,
                              num_segments,
                              separator=None,
                              name=None):
  """Aggregates along segments of a RaggedTensor using `unsorted_segment_op`.

  Returns a RaggedTensor `output` with `num_segments` rows, where the row
  `output[i]` is formed by combining all rows of `data` whose corresponding
  `segment_id` is `i`.  The values in each row are combined using
  `unsorted_segment_op`.

  The length of the row `output[i]` will be the maximum of the lengths of
  all rows of `data` whose corresponding `segment_id` is `i`.  If no `data`
  rows correspond to a given segment ID, then the output row for that segment
  ID will be empty.

  Args:
    unsorted_segment_op: The tensorflow `op` that should be used to combine
      values in each row.  Must have the same signature and basic behavior as
      `unsorted_segment_sum`, `unsorted_segment_max`, etc.
    data: A `RaggedTensor` containing the values to be combined.
    segment_ids: A `Tensor` or `RaggedTensor`.  Must have type `int64` or
      `int32`.  `segment_ids.shape` must be a prefix of `data.shape`.
      `segment_ids` is not required to be sorted.
    num_segments: An `int32` or `int64` scalar.
    separator: An optional string. Defaults to None. The separator to use when
      joining. Only used for string types.
    name: A name prefix for the returned tensor (optional).

  Returns:
    A `RaggedTensor` containing the aggregated values.  The returned tensor
    has the same dtype as `data`, and its shape is
    `[num_segments] + data.shape[segment_ids.rank:]`.
  Raises:
    ValueError: If segment_ids.shape is not a prefix of data.shape.
  """
  if not (ragged_tensor.is_ragged(data) or
          ragged_tensor.is_ragged(segment_ids)):
    if separator is not None:
      # A non-None separator means a string op like unsorted_segment_join,
      # which takes an extra argument.
      return unsorted_segment_op(data, segment_ids, num_segments, separator,
                                 name)
    else:
      return unsorted_segment_op(data, segment_ids, num_segments, name)

  with ops.name_scope(name, 'RaggedSegment',
                      [data, segment_ids, num_segments]) as name:
    data = ragged_tensor.convert_to_tensor_or_ragged_tensor(data, name='data')
    segment_ids = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        segment_ids, name='segment_ids')
    data, segment_ids = ragged_tensor.match_row_splits_dtypes(data, segment_ids)
    if segment_ids.dtype not in (dtypes.int32, dtypes.int64):
      raise ValueError('segment_ids must have dtype int32 or int64.')

    if ragged_tensor.is_ragged(segment_ids):
      if not ragged_tensor.is_ragged(data):
        raise ValueError('segment_ids.shape must be a prefix of data.shape, '
                         'but segment_ids is ragged and data is not.')
      check_splits = check_ops.assert_equal(
          segment_ids.row_splits,
          data.row_splits,
          message='segment_ids.shape must be a prefix of data.shape')
      with ops.control_dependencies([check_splits]):
        return _ragged_segment_aggregate(unsorted_segment_op, data.values,
                                         segment_ids.values, num_segments,
                                         separator)

    # Find the length of each row in data.  (shape=[data_nrows])
    data_row_lengths = data.row_splits[1:] - data.row_splits[:-1]

    # Find the length that each output row will have.  The length of the row
    # corresponding to segment `id` is `max(data_row_lengths[i])` where
    # `segment_ids[i]=id`.  (shape=[output_nrows])
    output_row_lengths = math_ops.maximum(
        math_ops.unsorted_segment_max(data_row_lengths, segment_ids,
                                      num_segments), 0)

    # Build the splits tensor for the output RaggedTensor.
    output_splits = array_ops.concat([
        array_ops.zeros([1], output_row_lengths.dtype),
        math_ops.cumsum(output_row_lengths)
    ],
                                     axis=0)

    # For each row in `data`, find the start & limit position where that row's
    # values will be aggregated in output.values.
    data_row_to_out_row_start = array_ops.gather(output_splits, segment_ids)
    data_row_to_out_row_limit = data_row_to_out_row_start + data_row_lengths

    # For each value in `data.values`, find the position where it will be
    # aggregated in `output.values`.
    data_val_to_out_val_index = range(data_row_to_out_row_start,
                                      data_row_to_out_row_limit).values

    # Recursively aggregate the values.
    output_values = _ragged_segment_aggregate(unsorted_segment_op, data.values,
                                              data_val_to_out_val_index,
                                              output_splits[-1], separator)
    return ragged_tensor.RaggedTensor.from_row_splits(
        output_values, output_splits, validate=False)
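To see the row-length rule in action ("output row `i` is as long as the longest contributing `data` row"), here is a sketch via the public segment ops, assuming ragged dispatch for `tf.math.unsorted_segment_sum` is registered (as in TF 2.x):

```python
import tensorflow as tf

data = tf.ragged.constant([[1, 2, 3], [4], [5, 6]])
segment_ids = tf.constant([0, 0, 1])

# Segment 0 combines rows [1, 2, 3] and [4] elementwise; its output row
# has length max(3, 1) = 3.  Segment 1 is just row [5, 6].
print(tf.math.unsorted_segment_sum(data, segment_ids, num_segments=2))
# [[5, 2, 3], [5, 6]]
```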
def batch_gather(params, indices, name=None):
  """Gathers slices from `params` according to `indices` with batch dims.

  This operation is similar to `gather`, but it assumes that the leading `N`
  dimensions of `indices` and `params` are batch dimensions, and performs a
  gather within each batch.  In particular, when using this operation with `N`
  batch dimensions `B1...BN`:

  * `indices` has shape `[B1...BN, I]`
  * `params` has shape `[B1...BN, P1...PM]`.
  * `result` has shape `[B1...BN, I, P2...PM]`.
  * `result[b1...bN, i, p2...pM] =
    params[b1...bN, indices[b1...bN, i], p2...pM]`

  Args:
    params: A potentially ragged tensor with shape `[B1...BN, P1...PM]` (`N>=0`,
      `M>0`).
    indices: A potentially ragged tensor with shape `[B1...BN, I]` (`N>=0`).
    name: A name for the operation (optional).

  Returns:
    A potentially ragged tensor with shape `[B1...BN, I, P2...PM]`.
    `result.ragged_rank = max(indices.ragged_rank, params.ragged_rank)`.

  #### Example:
    ```python
    >>> params = tf.ragged.constant([['a', 'b', 'c'], ['d'], [], ['e']])
    >>> indices = tf.ragged.constant([[1, 2, 0], [], [], [0, 0]])
    >>> tf.compat.v1.batch_gather(params, indices)
    [['b', 'c', 'a'], [], [], ['e', 'e']]
    ```
  """
  if not (ragged_tensor.is_ragged(params) or ragged_tensor.is_ragged(indices)):
    return array_ops.batch_gather(params, indices, name)

  with ops.name_scope(name, 'RaggedBatchGather', [params, indices]):
    params = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        params, name='params')
    indices = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        indices, name='indices')
    params, indices = ragged_tensor.match_row_splits_dtypes(params, indices)
    indices_ndims = indices.shape.ndims
    if indices_ndims is None:
      raise ValueError(
          'batch_gather does not allow indices with unknown shape.')
    if indices_ndims == 0:
      raise ValueError('indices.rank must be at least 1.')

    if ragged_tensor.is_ragged(indices):
      # If the outermost ragged dimension is a batch dimension, recurse.
      if indices_ndims > 2:
        if not ragged_tensor.is_ragged(params):
          raise ValueError('batch shape from indices does '
                           'not match params shape')
        checks = [check_ops.assert_equal(params.row_splits, indices.row_splits)]
        with ops.control_dependencies(checks):
          return ragged_tensor.RaggedTensor.from_row_splits(
              batch_gather(params.values, indices.values), indices.row_splits,
              validate=False)

      # Otherwise, indices is a 2D ragged tensor with 1 ragged dimension.
      else:
        # Ensure that `params` is ragged and has at least 2 dimensions.
        if not ragged_tensor.is_ragged(params):
          if params.shape.ndims is not None and params.shape.ndims < 2:
            raise ValueError('batch shape from indices does '
                             'not match params shape')
          params = ragged_tensor.RaggedTensor.from_tensor(
              params, ragged_rank=1,
              row_splits_dtype=indices.row_splits.dtype)

        # Adjust indices from within-batch to global (in params.values), and
        # then use ragged.gather to gather them.
        num_indices = indices.row_lengths()
        params_starts = params.row_starts()
        adjustments = ragged_util.repeat(params_starts, num_indices, axis=0)
        adjusted_index_values = (
            math_ops.cast(indices.values, adjustments.dtype) + adjustments)
        return ragged_tensor.RaggedTensor.from_row_splits(
            ragged_gather_ops.gather(params.values, adjusted_index_values),
            indices.row_splits, validate=False)

    else:  # params is a RaggedTensor and indices is a Tensor.
      if indices_ndims == 1:
        return ragged_gather_ops.gather(params, indices)
      elif indices_ndims == 2:
        # Adjust indices from batch-local to global (in params.values)
        adjustments = array_ops.expand_dims(params.row_starts(), 1)
        adjusted_indices = (
            math_ops.cast(indices, adjustments.dtype) + adjustments)
        return ragged_gather_ops.gather(params.values, adjusted_indices)
      else:
        raise ValueError('batch shape from indices does not match params shape')
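The core trick in the 2-D ragged-`indices` branch is turning batch-local indices into global positions in `params.values`. The following standalone sketch (assuming TF 2.x; `tf.repeat` stands in for `ragged_util.repeat`) reproduces it:

```python
import tensorflow as tf

params = tf.ragged.constant([['a', 'b', 'c'], ['d'], [], ['e']])
indices = tf.ragged.constant([[1, 2, 0], [], [], [0, 0]])

starts = params.row_starts()                      # [0, 3, 4, 4]
adjustments = tf.repeat(starts, indices.row_lengths())
global_idx = tf.cast(indices.values, adjustments.dtype) + adjustments
result = tf.RaggedTensor.from_row_splits(
    tf.gather(params.values, global_idx), indices.row_splits)
print(result)  # [[b'b', b'c', b'a'], [], [], [b'e', b'e']]
```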
def gather(params,
           indices,
           validate_indices=None,
           axis=0,
           batch_dims=0,
           name=None):
    """Gathers ragged slices from `params` axis `0` according to `indices`.

  Returns `RaggedTensor` output, such that:

  ```python
  output.shape = indices.shape + params.shape[1:]
  output.ragged_rank = indices.shape.ndims + params.ragged_rank
  output[i...j, d0...dn] = params[indices[i...j], d0...dn]
  ```

  `params` may be ragged.  `indices` may be ragged.
  `indices` must have dtype `int32` or `int64`. If any index is out of bounds,
  then an error is raised.

  Examples:

  >>> params = tf.constant(['a', 'b', 'c', 'd', 'e'])
  >>> indices = tf.constant([3, 1, 2, 1, 0])
  >>> ragged_params = tf.ragged.constant([['a', 'b', 'c'], ['d'], [], ['e']])
  >>> ragged_indices = tf.ragged.constant([[3, 1, 2], [1], [], [0]])

  >>> tf.gather(params, ragged_indices)
  <tf.RaggedTensor [[b'd', b'b', b'c'], [b'b'], [], [b'a']]>

  >>> tf.gather(ragged_params, indices)
  <tf.RaggedTensor [[b'e'], [b'd'], [], [b'd'], [b'a', b'b', b'c']]>

  >>> tf.gather(ragged_params, ragged_indices)
  <tf.RaggedTensor [[[b'e'], [b'd'], []], [[b'd']], [], [[b'a', b'b', b'c']]]>

  Args:
    params: The potentially ragged tensor from which to gather values. Must be
      at least rank 1.
    indices: The potentially ragged tensor indicating which values to gather.
      Must have dtype `int32` or `int64`.  Values must be in the range `[0,
      params.shape[0]]`.
    validate_indices: Ignored.
    axis: Must be zero.
    batch_dims: Must be zero.
    name: A name for the operation (optional).

  Returns:
    A `RaggedTensor`, where `output.dtype=params.dtype` and
    `output.shape=indices.shape + params.shape[1:]` and
    `output.ragged_rank=indices.shape.ndims + params.ragged_rank`.

  Raises:
    ValueError: If indices.shape.ndims is not known statically.
  """
    del validate_indices
    if not isinstance(axis, int) or axis != 0:
        raise ValueError('axis != 0 is not supported for ragged gather yet.')
    if not isinstance(batch_dims, int) or batch_dims != 0:
        raise ValueError(
            'batch_dims != 0 is not supported for ragged gather yet.')
    with ops.name_scope(name, 'RaggedGather', [params, indices]):
        params = ragged_tensor.convert_to_tensor_or_ragged_tensor(
            params, name='params')
        indices = ragged_tensor.convert_to_tensor_or_ragged_tensor(
            indices, name='indices')
        params, indices = ragged_tensor.match_row_splits_dtypes(
            params, indices)

        if ragged_tensor.is_ragged(indices):
            return indices.with_values(gather(params, indices.values))

        if not ragged_tensor.is_ragged(params):
            return array_ops.gather(params, indices)

        indices = ops.convert_to_tensor(indices)
        if indices.shape.ndims is None:
            raise ValueError('indices.shape.ndims must be known statically')

        result = gen_ragged_array_ops.ragged_gather(
            indices=indices,
            params_dense_values=params.flat_values,
            params_nested_splits=params.nested_row_splits,
            OUTPUT_RAGGED_RANK=indices.shape.ndims +
            len(params.nested_row_splits) - 1)

        # Compose the RaggedTensor from splits & values.
        return ragged_tensor.RaggedTensor.from_nested_row_splits(
            result.output_dense_values,
            result.output_nested_splits,
            validate=False)
def gather_nd(params, indices, batch_dims=0, name=None):
    """Gather slices from `params` using `n`-dimensional indices.

  This operation is similar to `gather`, but it uses the innermost dimension
  of `indices` to define a slice into `params`.  In particular, if:

  * `indices` has shape `[A1...AN, I]`
  * `params` has shape `[B1...BM]`

  Then:

  * `result` has shape `[A1...AN, B_{I+1}...BM]`.
  * `result[a1...aN] = params[indices[a1...aN, :]]`

  Args:
    params: A potentially ragged tensor with shape `[A1...AN, I]`.
    indices: A potentially ragged tensor with shape `[B1...BM]`.
    batch_dims: Must be zero.
    name: A name for the operation (optional).

  Returns:
    A potentially ragged tensor with shape `[A1...AN, B_{I+1}...BM]`.

  #### Examples:

  >>> params = tf.ragged.constant(
  ...     [ [ ['000', '001'], ['010'              ]          ],
  ...       [ ['100'       ], ['110', '111', '112'], ['120'] ],
  ...       [ [            ], ['210'              ]          ] ])

  >>> # Gather 2D slices from a 3D tensor
  >>> tf.gather_nd(params, [[2], [0]])
  <tf.RaggedTensor [[[], [b'210']], [[b'000', b'001'], [b'010']]]>

  >>> # Gather 1D slices from a 3D tensor
  >>> tf.gather_nd(params, [[2, 1], [0, 0]])
  <tf.RaggedTensor [[b'210'], [b'000', b'001']]>

  >>> # Gather scalars from a 3D tensor
  >>> tf.gather_nd(params, [[0, 0, 1], [1, 1, 2]]).numpy()
  array([b'001', b'112'], dtype=object)
  """
    if not isinstance(batch_dims, int) or batch_dims != 0:
        raise ValueError(
            'batch_dims != 0 is not supported for ragged gather yet.')
    if not (ragged_tensor.is_ragged(params)
            or ragged_tensor.is_ragged(indices)):
        return array_ops.gather_nd(params, indices, name)

    with ops.name_scope(name, 'RaggedGatherNd', [params, indices]):

        params = ragged_tensor.convert_to_tensor_or_ragged_tensor(
            params, name='params')
        indices = ragged_tensor.convert_to_tensor_or_ragged_tensor(
            indices, name='indices')
        params, indices = ragged_tensor.match_row_splits_dtypes(
            params, indices)
        indices_shape = indices.shape
        indices_ndims = indices_shape.ndims
        if indices_ndims is None:
            raise ValueError('indices.rank must be statically known.')
        if indices_ndims == 0:
            raise ValueError('indices.rank must be at least 1.')
        if (ragged_tensor.is_ragged(indices)
                and indices_ndims == indices.ragged_rank + 1):
            raise ValueError(
                'The innermost dimension of indices may not be ragged')

        # `index_size` is the "n" in "gather_nd" -- i.e., the number of dimensions
        # that each index slices into.
        index_size = tensor_shape.dimension_value(indices_shape[-1])
        if index_size is None:
            raise ValueError('indices.shape[-1] must be statically known.')

        # If `indices` has more than 2 dimensions, then recurse.  If `indices` is
        # dense, then we convert it to ragged before recursing, and then convert
        # the result back to `dense` if appropriate.
        if indices_ndims > 2:
            indices_is_dense = not ragged_tensor.is_ragged(indices)
            if indices_is_dense:
                indices = ragged_tensor.RaggedTensor.from_tensor(
                    indices,
                    ragged_rank=indices_ndims - 2,
                    row_splits_dtype=params.row_splits.dtype)
            result = indices.with_flat_values(
                gather_nd(params, indices.flat_values))
            if (indices_is_dense and ragged_tensor.is_ragged(result)
                    and result.ragged_rank == indices_ndims - 2):
                result = ragged_tensor.RaggedTensor.to_tensor(result)
            return result

        # indices_ndims <= 2, and the innermost dimension of indices may not be
        # ragged, so `indices` must not be ragged.
        assert not ragged_tensor.is_ragged(indices)
        assert ragged_tensor.is_ragged(params)

        # Handle corner case: An empty index tuple selects the entire `params`
        # value.  So if `index_size` is zero, then tile `params`.
        if index_size == 0:
            params_ndims = params.ragged_rank + array_ops.rank(
                params.flat_values)
            for dim in range(indices_ndims - 1):
                params = ragged_array_ops.expand_dims(params, axis=0)
            multiples = array_ops.concat([
                array_ops.shape(indices)[:-1],
                array_ops.ones([params_ndims], dtypes.int32)
            ],
                                         axis=0)
            return ragged_array_ops.tile(params, multiples)

        # When index_size=1, we can just flatten the index tuples and use gather.
        elif index_size == 1:
            flattened_index_tuples = array_ops.reshape(indices, [-1])
            return gather(params, flattened_index_tuples)

        # Otherwise, params is a RaggedTensor, and indices is a 1D or 2D Tensor.
        # Flatten both the index tuples and the params, such that the flattened
        # index tuples point to the correct values in the flattened params; and
        # then use ragged.gather on the flattened index tuples & params.
        else:
            indices = math_ops.cast(indices, params.row_splits.dtype)

            # Flatten the outermost 2 dimensions of the index tuples & params.
            flattened_index_tuples = array_ops.gather(params.row_splits,
                                                      indices[..., 0])
            flattened_index_tuples += indices[..., 1]
            flattened_params = params.values

            # Flatten any remaining dimensions.
            for dim in range(2, index_size):
                if not ragged_tensor.is_ragged(flattened_params):
                    flattened_index_tuples = array_ops.expand_dims(
                        flattened_index_tuples, axis=1)
                    flattened_index_tuples = array_ops.concat(
                        [flattened_index_tuples, indices[..., dim:]], axis=1)
                    return array_ops.gather_nd(flattened_params,
                                               flattened_index_tuples)

                flattened_index_tuples = array_ops.gather(
                    flattened_params.row_starts(), flattened_index_tuples)
                flattened_index_tuples += indices[..., dim]
                flattened_params = flattened_params.values

            # Gather using the flattened index tuples and params.
            return gather(flattened_params, flattened_index_tuples)
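The `index_size == 2` branch above can be reproduced by hand to show how a `(row, col)` pair becomes a flat position in `params.values` (a sketch, assuming TF 2.x):

```python
import tensorflow as tf

params = tf.ragged.constant([['a', 'b'], ['c'], ['d', 'e', 'f']])
indices = tf.constant([[2, 1], [0, 0]])  # (row, col) index tuples

# row_splits[row] is where that row begins in params.values; adding the
# column offset yields the flat position.
flat = (tf.gather(params.row_splits, indices[:, 0])
        + tf.cast(indices[:, 1], params.row_splits.dtype))
print(tf.gather(params.values, flat))  # [b'e' b'a']
print(tf.gather_nd(params, indices))   # same result via the public op
```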
def gather(params,
           indices,
           validate_indices=None,
           axis=None,
           batch_dims=0,
           name=None):
    """Gathers ragged slices from `params` axis `0` according to `indices`.

  See `tf.gather` for full documentation.  (This version has the same API
  as `tf.gather`, but supports ragged `params` and `indices`.)

  Examples:

  >>> params = tf.constant(['a', 'b', 'c', 'd', 'e'])
  >>> indices = tf.constant([3, 1, 2, 1, 0])
  >>> ragged_params = tf.ragged.constant([['a', 'b', 'c'], ['d'], [], ['e']])
  >>> ragged_indices = tf.ragged.constant([[3, 1, 2], [1], [], [0]])

  >>> tf.gather(params, ragged_indices)
  <tf.RaggedTensor [[b'd', b'b', b'c'], [b'b'], [], [b'a']]>

  >>> tf.gather(ragged_params, indices)
  <tf.RaggedTensor [[b'e'], [b'd'], [], [b'd'], [b'a', b'b', b'c']]>

  >>> tf.gather(ragged_params, ragged_indices)
  <tf.RaggedTensor [[[b'e'], [b'd'], []], [[b'd']], [], [[b'a', b'b', b'c']]]>

  Args:
    params: The potentially ragged tensor from which to gather values. Must be
      at least rank 1.
    indices: The potentially ragged tensor indicating which values to gather.
      Must have dtype `int32` or `int64`.  Values must be in the range `[0,
      params.shape[0]]`.
    validate_indices: Ignored.
    axis: The axis in `params` to gather `indices` from.
    batch_dims: The number of batch dimensions.
    name: A name for the operation (optional).

  Returns:
    A `RaggedTensor`, where `output.dtype=params.dtype` and
    `output.shape=indices.shape + params.shape[1:]` and
    `output.ragged_rank=indices.shape.ndims + params.ragged_rank`.

  Raises:
    ValueError: If indices.shape.ndims is not known statically.
  """
    del validate_indices

    with ops.name_scope(name, 'RaggedGather', [params, indices]):
        params = ragged_tensor.convert_to_tensor_or_ragged_tensor(
            params, name='params')
        indices = ragged_tensor.convert_to_tensor_or_ragged_tensor(
            indices, name='indices')
        params, indices = ragged_tensor.match_row_splits_dtypes(
            params, indices)

        if batch_dims != indices.shape.rank:
            batch_dims = array_ops.get_positive_axis(
                batch_dims,
                indices.shape.rank,
                axis_name='batch_dims',
                ndims_name='rank(indices)')
        if params.shape.rank is not None and batch_dims >= params.shape.rank:
            raise ValueError('batch_dims must be less than rank(params)')
        if axis is None:
            axis = batch_dims
        axis = array_ops.get_positive_axis(axis,
                                           params.shape.rank,
                                           ndims_name='rank(params)')
        if axis < batch_dims:
            raise ValueError(
                'axis must be greater than or equal to batch_dims')
        if indices.shape.rank is not None:
            if not 0 <= batch_dims <= indices.shape.rank:
                raise ValueError(
                    'batch_dims=%s must be between 0 and rank(indices)=%s' %
                    (batch_dims, indices.shape.rank))

        return _gather(params, indices, axis, batch_dims)
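Unlike the earlier version, this overload accepts nonzero `axis` and `batch_dims`. A brief usage sketch (assuming a TF 2.x release where ragged `tf.gather` supports `batch_dims`):

```python
import tensorflow as tf

params = tf.ragged.constant([['a', 'b', 'c'], ['d', 'e']])
indices = tf.ragged.constant([[2, 0], [1]])

# batch_dims=1: row i of the result is params[i] gathered at indices[i].
print(tf.gather(params, indices, batch_dims=1))
# [[b'c', b'a'], [b'e']]
```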
def batch_gather_with_default(params,
                              indices,
                              default_value='',
                              name=None):
  """Same as `batch_gather` but inserts `default_value` for invalid indices.

  This operation is similar to `batch_gather` except that it will substitute
  the value for invalid indices with `default_value` as the contents.
  See `batch_gather` for more details.


  Args:
    params: A potentially ragged tensor with shape `[B1...BN, P1...PM]` (`N>=0`,
      `M>0`).
    indices: A potentially ragged tensor with shape `[B1...BN, I]` (`N>=0`).
    default_value: A value to be inserted in places where `indices` are out of
      bounds. Must be the same dtype as params and either a scalar or rank 1.
    name: A name for the operation (optional).

  Returns:
    A potentially ragged tensor with shape `[B1...BN, I, P2...PM]`.
    `result.ragged_rank = max(indices.ragged_rank, params.ragged_rank)`.

  #### Example:
    ```python
    >>> params = tf.ragged.constant([
          ['a', 'b', 'c'],
          ['d'],
          [],
          ['e']])
    >>> indices = tf.ragged.constant([[1, 2, -1], [], [], [0, 10]])
    >>> batch_gather_with_default(params, indices, 'FOO')
    [['b', 'c', 'FOO'], [], [], ['e', 'FOO']]
  ```
  """
  with ops.name_scope(name, 'RaggedBatchGatherWithDefault'):
    params = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        params, name='params',
    )
    indices = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        indices, name='indices',
    )
    default_value = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        default_value, name='default_value',
    )
    row_splits_dtype, (params, indices, default_value) = (
        ragged_tensor.match_row_splits_dtypes(params, indices, default_value,
                                              return_dtype=True))
    # TODO(hterry): lift this restriction and support default_values
    #               of rank > 1
    if default_value.shape.ndims not in (0, 1):
      raise ValueError('"default_value" must be a scalar or vector')
    upper_bounds = None
    if indices.shape.ndims is None:
      raise ValueError('Indices must have a known rank.')
    if params.shape.ndims is None:
      raise ValueError('Params must have a known rank.')

    num_batch_dimensions = indices.shape.ndims - 1
    pad = None
    # The logic for this works as follows:
    # - create a padded params, where:
    #    padded_params[b1...bn, 0] = default_value
    #    padded_params[b1...bn, i] = params[b1...bn, i-1] (i>0)
    # - create an `upper_bounds` Tensor that contains the number of elements
    #   in each innermost rank. Broadcast `upper_bounds` to be the same shape
    #   as `indices`.
    # - check to see which index in `indices` are out of bounds and substitute
    #   it with the index containing `default_value` (the first).
    # - call batch_gather with the indices adjusted.
    with ops.control_dependencies([
        check_ops.assert_greater_equal(array_ops.rank(params),
                                       array_ops.rank(indices))]):
      if ragged_tensor.is_ragged(params):
        row_lengths = ragged_array_ops.expand_dims(
            params.row_lengths(axis=num_batch_dimensions),
            axis=-1)
        upper_bounds = math_ops.cast(row_lengths, indices.dtype)

        pad_shape = _get_pad_shape(params, indices, row_splits_dtype)

        pad = ragged_tensor_shape.broadcast_to(
            default_value, pad_shape)
      else:
        params_shape = array_ops.shape(params)
        pad_shape = array_ops.concat([
            params_shape[:num_batch_dimensions],
            [1],
            params_shape[num_batch_dimensions + 1:params.shape.ndims]
        ], 0)
        upper_bounds = params_shape[num_batch_dimensions]
        pad = array_ops.broadcast_to(default_value, pad_shape)

      # Add `default_value` as the first value in the innermost (ragged) rank.
      pad = math_ops.cast(pad, params.dtype)
      padded_params = array_ops.concat(
          [pad, params], axis=num_batch_dimensions)

      # Adjust the indices by substituting out-of-bound indices to the
      # default-value index (which is the first element)
      shifted_indices = indices + 1
      is_out_of_bounds = (indices < 0) | (indices > upper_bounds)
      adjusted_indices = ragged_where_op.where(
          is_out_of_bounds,
          x=array_ops.zeros_like(indices), y=shifted_indices,
      )
      return array_ops.batch_gather(
          params=padded_params, indices=adjusted_indices, name=name)
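The pad-and-shift scheme described in the comment block is easiest to see on dense inputs. A minimal standalone sketch (my own re-derivation, assuming TF 2.x, using `>=` as the upper-bound test):

```python
import tensorflow as tf

params = tf.constant([['a', 'b', 'c'], ['d', 'e', 'f']])
indices = tf.constant([[1, -1], [0, 99]])
default = tf.constant('FOO')

# Prepend the default as column 0, shift valid indices up by one, and
# route out-of-bounds indices to that default column.
padded = tf.concat([tf.fill([2, 1], default), params], axis=1)
upper = tf.shape(params, out_type=indices.dtype)[1]
oob = (indices < 0) | (indices >= upper)
adjusted = tf.where(oob, tf.zeros_like(indices), indices + 1)
print(tf.gather(padded, adjusted, batch_dims=1))
# [[b'b' b'FOO'] [b'd' b'FOO']]
```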
Example 13
def _ragged_stack_concat_helper(rt_inputs, axis, stack_values):
    """Helper function to concatenate or stack ragged tensors.

  Args:
    rt_inputs: A list of RaggedTensors or Tensors to combine.
    axis: The axis along which to concatenate or stack.
    stack_values: A boolean -- if true, then stack values; otherwise,
      concatenate them.

  Returns:
    A RaggedTensor.
  Raises:
    ValueError: If rt_inputs is empty, or if axis is out of range.
  """
    # Validate parameters.
    if not rt_inputs:
        raise ValueError('rt_inputs may not be empty.')

    # Convert input tensors.
    rt_inputs = [
        ragged_tensor.convert_to_tensor_or_ragged_tensor(rt_input,
                                                         name='rt_input')
        for rt_input in rt_inputs
    ]
    row_splits_dtype, rt_inputs = ragged_tensor.match_row_splits_dtypes(
        *rt_inputs, return_dtype=True)
    rt_inputs = list(rt_inputs)

    # Special case: if there's only one input, then return it as-is.
    if len(rt_inputs) == 1:
        if stack_values:
            return ragged_array_ops.expand_dims(rt_inputs[0], axis=axis)
        else:
            return rt_inputs[0]

    # Check the rank (number of dimensions) of the input tensors.
    ndims = None
    for rt in rt_inputs:
        if ndims is None:
            ndims = rt.shape.ndims
        else:
            rt.shape.assert_has_rank(ndims)

    out_ndims = ndims if (ndims is None or not stack_values) else ndims + 1
    axis = ragged_util.get_positive_axis(axis, out_ndims)

    # If all the inputs are Tensors, and we're combining the final dimension,
    # then we can delegate to the tf.stack/tf.concat operation, and return a
    # Tensor.
    if all(not ragged_tensor.is_ragged(rt) for rt in rt_inputs):
        if ndims is not None and (axis == out_ndims - 1 or axis == ndims - 1):
            if stack_values:
                return array_ops.stack(rt_inputs, axis)
            else:
                return array_ops.concat(rt_inputs, axis)

    # Convert any Tensor inputs to RaggedTensors.  This makes it
    # possible to concatenate Tensors and RaggedTensors together.
    for i in range(len(rt_inputs)):
        if not ragged_tensor.is_ragged(rt_inputs[i]):
            rt_inputs[i] = ragged_tensor.RaggedTensor.from_tensor(
                rt_inputs[i], ragged_rank=1, row_splits_dtype=row_splits_dtype)

    # Convert the input tensors to all have the same ragged_rank.
    ragged_rank = max(max(rt.ragged_rank for rt in rt_inputs), 1)
    rt_inputs = [
        _increase_ragged_rank_to(rt, ragged_rank, row_splits_dtype)
        for rt in rt_inputs
    ]

    if axis == 0:
        return _ragged_stack_concat_axis_0(rt_inputs, stack_values)
    elif axis == 1:
        return _ragged_stack_concat_axis_1(rt_inputs, stack_values)
    else:  # axis > 1: recurse.
        values = [rt.values for rt in rt_inputs]
        splits = [[rt_input.row_splits] for rt_input in rt_inputs]
        with ops.control_dependencies(ragged_util.assert_splits_match(splits)):
            return ragged_tensor.RaggedTensor.from_row_splits(
                _ragged_stack_concat_helper(values, axis - 1, stack_values),
                splits[0][0],
                validate=False)
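This helper backs the public ragged `tf.concat` and `tf.ragged.stack`; a short usage sketch (assuming TF 2.x dispatch):

```python
import tensorflow as tf

a = tf.ragged.constant([[1, 2], [3]])
b = tf.ragged.constant([[4], [5, 6, 7]])

print(tf.concat([a, b], axis=0))        # [[1, 2], [3], [4], [5, 6, 7]]
print(tf.concat([a, b], axis=1))        # [[1, 2, 4], [3, 5, 6, 7]]
print(tf.ragged.stack([a, b], axis=0))  # shape [2, 2, None]
```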
Example 14
def gather_nd(params, indices, batch_dims=0, name=None):
  """Gather slices from `params` using `n`-dimensional indices.

  This operation is similar to `gather`, but it uses the innermost dimension
  of `indices` to define a slice into `params`.  In particular, if:

  * `indices` has shape `[A1...AN, I]`
  * `params` has shape `[B1...BM]`

  Then:

  * `result` has shape `[A1...AN, B_{I+1}...BM]`.
  * `result[a1...aN] = params[indices[a1...aN, :]]`

  Args:
    params: A potentially ragged tensor with shape `[A1...AN, I]`.
    indices: A potentially ragged tensor with shape `[B1...BM]`.
    batch_dims: Must be zero.
    name: A name for the operation (optional).

  Returns:
    A potentially ragged tensor with shape `[A1...AN, B_{I+1}...BM]`.

  #### Examples:
    ```python
    >>> params = tf.compat.v1.ragged.constant_value(
    ...     [ [ ['000', '001'], ['010'              ]          ],
    ...       [ ['100'       ], ['110', '111', '112'], ['120'] ],
    ...       [ [            ], ['210'              ]          ] ])

    >>> # Gather 2D slices from a 3D tensor
    >>> ragged.gather_nd(params, [[2], [0]])
    [ [ [            ], ['210'] ]
      [ ['000', '001'], ['010'] ] ]

    >>> # Gather 1D slices from a 3D tensor
    >>> ragged.gather_nd(params, [[2, 1], [0, 0]])
    [['210'], ['000', '001']]

    >>> # Gather scalars from a 3D tensor
    >>> ragged.gather_nd(params, [[0, 0, 1], [1, 1, 2]])
    ['001', '112']
    ```
  """
  if not isinstance(batch_dims, int) or batch_dims != 0:
    raise ValueError('batch_dims != 0 is not supported for ragged gather yet.')
  if not (ragged_tensor.is_ragged(params) or ragged_tensor.is_ragged(indices)):
    return array_ops.gather_nd(params, indices, name)

  with ops.name_scope(name, 'RaggedGatherNd', [params, indices]):

    params = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        params, name='params')
    indices = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        indices, name='indices')
    params, indices = ragged_tensor.match_row_splits_dtypes(params, indices)
    indices_shape = indices.shape
    indices_ndims = indices_shape.ndims
    if indices_ndims is None:
      raise ValueError('indices.rank must be statically known.')
    if indices_ndims == 0:
      raise ValueError('indices.rank must be at least 1.')
    if (ragged_tensor.is_ragged(indices) and
        indices_ndims == indices.ragged_rank + 1):
      raise ValueError('The innermost dimension of indices may not be ragged')

    # `index_size` is the "n" in "gather_nd" -- i.e., the number of dimensions
    # that each index slices into.
    index_size = tensor_shape.dimension_value(indices_shape[-1])
    if index_size is None:
      raise ValueError('indices.shape[-1] must be statically known.')

    # If `indices` has more than 2 dimensions, then recurse.  If `indices` is
    # dense, then we convert it to ragged before recursing, and then convert
    # the result back to `dense` if appropriate.
    if indices_ndims > 2:
      indices_is_dense = not ragged_tensor.is_ragged(indices)
      if indices_is_dense:
        indices = ragged_tensor.RaggedTensor.from_tensor(
            indices, ragged_rank=indices_ndims - 2,
            row_splits_dtype=params.row_splits.dtype)
      result = indices.with_flat_values(gather_nd(params, indices.flat_values))
      if (indices_is_dense and ragged_tensor.is_ragged(result) and
          result.ragged_rank == indices_ndims - 2):
        result = ragged_tensor.RaggedTensor.to_tensor(result)
      return result

    # indices_ndims <= 2, and the innermost dimension of indices may not be
    # ragged, so `indices` must not be ragged.
    assert not ragged_tensor.is_ragged(indices)
    assert ragged_tensor.is_ragged(params)

    # Handle corner case: An empty index tuple selects the entire `params`
    # value.  So if `index_size` is zero, then tile `params`.
    if index_size == 0:
      params_ndims = params.ragged_rank + array_ops.rank(params.flat_values)
      for dim in range(indices_ndims - 1):
        params = ragged_array_ops.expand_dims(params, axis=0)
      multiples = array_ops.concat([
          array_ops.shape(indices)[:-1],
          array_ops.ones([params_ndims], dtypes.int32)
      ],
                                   axis=0)
      return ragged_array_ops.tile(params, multiples)

    # When index_size=1, we can just flatten the index tuples and use gather.
    elif index_size == 1:
      flattened_index_tuples = array_ops.reshape(indices, [-1])
      return gather(params, flattened_index_tuples)

    # Otherwise, params is a RaggedTensor, and indices is a 1D or 2D Tensor.
    # Flatten both the index tuples and the params, such that the flattened
    # index tuples point to the correct values in the flattened params; and
    # then use ragged.gather on the flattened index tuples & params.
    else:
      indices = math_ops.cast(indices, params.row_splits.dtype)

      # Flatten the outermost 2 dimensions of the index tuples & params.
      flattened_index_tuples = array_ops.gather(params.row_splits,
                                                indices[..., 0])
      flattened_index_tuples += indices[..., 1]
      flattened_params = params.values

      # Flatten any remaining dimensions.
      for dim in range(2, index_size):
        if not ragged_tensor.is_ragged(flattened_params):
          flattened_index_tuples = array_ops.expand_dims(
              flattened_index_tuples, axis=1)
          flattened_index_tuples = array_ops.concat(
              [flattened_index_tuples, indices[..., dim:]], axis=1)
          return array_ops.gather_nd(flattened_params, flattened_index_tuples)

        flattened_index_tuples = array_ops.gather(
            flattened_params.row_starts(), flattened_index_tuples)
        flattened_index_tuples += indices[..., dim]
        flattened_params = flattened_params.values

      # Gather using the flattened index tuples and params.
      return gather(flattened_params, flattened_index_tuples)
def batch_gather(params, indices, name=None):
    """Gathers slices from `params` according to `indices` with batch dims.

  This operation is similar to `gather`, but it assumes that the leading `N`
  dimensions of `indices` and `params` are batch dimensions, and performs a
  gather within each batch.  In particular, when using this operation with `N`
  batch dimensions `B1...BN`:

  * `indices` has shape `[B1...BN, I]`
  * `params` has shape `[B1...BN, P1...PM]`.
  * `result` has shape `[B1...BN, I, P2...PM]`.
  * `result[b1...bN, i, p2...pM] =
    params[b1...bN, indices[b1...bN, i], p2...pM]`

  Args:
    params: A potentially ragged tensor with shape `[B1...BN, P1...PM]` (`N>=0`,
      `M>0`).
    indices: A potentially ragged tensor with shape `[B1...BN, I]` (`N>=0`).
    name: A name for the operation (optional).

  Returns:
    A potentially ragged tensor with shape `[B1...BN, I, P2...PM]`.
    `result.ragged_rank = max(indices.ragged_rank, params.ragged_rank)`.

  #### Example:
    ```python
    >>> params = tf.ragged.constant([['a', 'b', 'c'], ['d'], [], ['e']])
    >>> indices = tf.ragged.constant([[1, 2, 0], [], [], [0, 0]])
    >>> tf.compat.v1.batch_gather(params, indices)
    [['b', 'c', 'a'], [], [], ['e', 'e']]
    ```
  """
    if not (ragged_tensor.is_ragged(params)
            or ragged_tensor.is_ragged(indices)):
        return array_ops.batch_gather(params, indices, name)

    with ops.name_scope(name, 'RaggedBatchGather', [params, indices]):
        params = ragged_tensor.convert_to_tensor_or_ragged_tensor(
            params, name='params')
        indices = ragged_tensor.convert_to_tensor_or_ragged_tensor(
            indices, name='indices')
        params, indices = ragged_tensor.match_row_splits_dtypes(
            params, indices)
        indices_ndims = indices.shape.ndims
        if indices_ndims is None:
            raise ValueError(
                'batch_gather does not allow indices with unknown shape.')
        if indices_ndims == 0:
            raise ValueError('indices.rank must be at least 1.')

        if ragged_tensor.is_ragged(indices):
            # If the outermost ragged dimension is a batch dimension, recurse.
            if indices_ndims > 2:
                if not ragged_tensor.is_ragged(params):
                    raise ValueError('batch shape from indices does '
                                     'not match params shape')
                checks = [
                    check_ops.assert_equal(params.row_splits,
                                           indices.row_splits)
                ]
                with ops.control_dependencies(checks):
                    return ragged_tensor.RaggedTensor.from_row_splits(
                        batch_gather(params.values, indices.values),
                        indices.row_splits,
                        validate=False)

            # Otherwise, indices is a 2D ragged tensor with 1 ragged dimension.
            else:
                # Ensure that `params` is ragged and has at least 2 dimensions.
                if not ragged_tensor.is_ragged(params):
                    if params.shape.ndims is not None and params.shape.ndims < 2:
                        raise ValueError('batch shape from indices does '
                                         'not match params shape')
                    params = ragged_tensor.RaggedTensor.from_tensor(
                        params,
                        ragged_rank=1,
                        row_splits_dtype=indices.row_splits.dtype)

                # Adjust indices from within-batch to global (in params.values), and
                # then use ragged.gather to gather them.
                num_indices = indices.row_lengths()
                params_starts = params.row_starts()
                adjustments = ragged_util.repeat(params_starts,
                                                 num_indices,
                                                 axis=0)
                adjusted_index_values = (
                    math_ops.cast(indices.values, adjustments.dtype) +
                    adjustments)
                return ragged_tensor.RaggedTensor.from_row_splits(
                    ragged_gather_ops.gather(params.values,
                                             adjusted_index_values),
                    indices.row_splits,
                    validate=False)

        else:  # params is a RaggedTensor and indices is a Tensor.
            if indices_ndims == 1:
                return ragged_gather_ops.gather(params, indices)
            elif indices_ndims == 2:
                # Adjust indices from batch-local to global (in params.values)
                adjustments = array_ops.expand_dims(params.row_starts(), 1)
                adjusted_indices = (math_ops.cast(indices, adjustments.dtype) +
                                    adjustments)
                return ragged_gather_ops.gather(params.values,
                                                adjusted_indices)
            else:
                raise ValueError(
                    'batch shape from indices does not match params shape')
Example 16
def _ragged_stack_concat_helper(rt_inputs, axis, stack_values):
  """Helper function to concatenate or stack ragged tensors.

  Args:
    rt_inputs: A list of RaggedTensors or Tensors to combine.
    axis: The axis along which to concatenate or stack.
    stack_values: A boolean -- if true, then stack values; otherwise,
      concatenate them.

  Returns:
    A RaggedTensor.
  Raises:
    ValueError: If rt_inputs is empty, or if axis is out of range.
  """
  # Validate parameters.
  if not rt_inputs:
    raise ValueError('rt_inputs may not be empty.')

  # Convert input tensors.
  rt_inputs = [
      ragged_tensor.convert_to_tensor_or_ragged_tensor(
          rt_input, name='rt_input') for rt_input in rt_inputs
  ]
  row_splits_dtype, rt_inputs = ragged_tensor.match_row_splits_dtypes(
      *rt_inputs, return_dtype=True)
  rt_inputs = list(rt_inputs)

  # Special case: if there's only one input, then return it as-is.
  if len(rt_inputs) == 1:
    if stack_values:
      return ragged_array_ops.expand_dims(rt_inputs[0], axis=axis)
    else:
      return rt_inputs[0]

  # Check the rank (number of dimensions) of the input tensors.
  ndims = None
  for rt in rt_inputs:
    if ndims is None:
      ndims = rt.shape.ndims
    else:
      rt.shape.assert_has_rank(ndims)

  out_ndims = ndims if (ndims is None or not stack_values) else ndims + 1
  axis = ragged_util.get_positive_axis(axis, out_ndims)

  # If all the inputs are Tensors, and we're combining the final dimension,
  # then we can delegate to the tf.stack/tf.concat operation, and return a
  # Tensor.
  if all(not ragged_tensor.is_ragged(rt) for rt in rt_inputs):
    if ndims is not None and (axis == out_ndims - 1 or axis == ndims - 1):
      if stack_values:
        return array_ops.stack(rt_inputs, axis)
      else:
        return array_ops.concat(rt_inputs, axis)

  # Convert any Tensor inputs to RaggedTensors.  This makes it
  # possible to concatenate Tensors and RaggedTensors together.
  for i in range(len(rt_inputs)):
    if not ragged_tensor.is_ragged(rt_inputs[i]):
      rt_inputs[i] = ragged_tensor.RaggedTensor.from_tensor(
          rt_inputs[i], ragged_rank=1, row_splits_dtype=row_splits_dtype)

  # Convert the input tensors to all have the same ragged_rank.
  ragged_rank = max(max(rt.ragged_rank for rt in rt_inputs), 1)
  rt_inputs = [_increase_ragged_rank_to(rt, ragged_rank, row_splits_dtype)
               for rt in rt_inputs]

  if axis == 0:
    return _ragged_stack_concat_axis_0(rt_inputs, stack_values)
  elif axis == 1:
    return _ragged_stack_concat_axis_1(rt_inputs, stack_values)
  else:  # axis > 1: recurse.
    values = [rt.values for rt in rt_inputs]
    splits = [[rt_input.row_splits] for rt_input in rt_inputs]
    with ops.control_dependencies(ragged_util.assert_splits_match(splits)):
      return ragged_tensor.RaggedTensor.from_row_splits(
          _ragged_stack_concat_helper(values, axis - 1, stack_values),
          splits[0][0], validate=False)
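A minimal sketch of the two behaviors this helper implements, using the public `tf.ragged.stack` / `tf.concat` APIs (TF 2.x assumed; data illustrative): stacking adds a new dimension, while concatenation combines along an existing one.

```python
import tensorflow as tf

a = tf.ragged.constant([[1, 2], [3]])
b = tf.ragged.constant([[4], [5, 6]])

# Stack: adds a new outer dimension -> [[[1, 2], [3]], [[4], [5, 6]]]
tf.ragged.stack([a, b], axis=0)
# Concat on axis 0: appends rows -> [[1, 2], [3], [4], [5, 6]]
tf.concat([a, b], axis=0)
# Concat on axis 1: merges corresponding rows -> [[1, 2, 4], [3, 5, 6]]
tf.concat([a, b], axis=1)
```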
def batch_gather_with_default(params,
                              indices,
                              default_value='',
                              name=None):
  """Same as `batch_gather` but inserts `default_value` for invalid indices.

  This operation is similar to `batch_gather` except that it substitutes
  `default_value` for any invalid (out-of-bounds) indices.
  See `batch_gather` for more details.

  Args:
    params: A potentially ragged tensor with shape `[B1...BN, P1...PM]` (`N>=0`,
      `M>0`).
    indices: A potentially ragged tensor with shape `[B1...BN, I]` (`N>=0`).
    default_value: A value to be inserted in places where `indices` are out of
      bounds. Must be the same dtype as params and either a scalar or rank 1.
    name: A name for the operation (optional).

  Returns:
    A potentially ragged tensor with shape `[B1...BN, I, P2...PM]`.
    `result.ragged_rank = max(indices.ragged_rank, params.ragged_rank)`.

  #### Example:
    ```python
    >>> params = tf.ragged.constant([
    ...     ['a', 'b', 'c'],
    ...     ['d'],
    ...     [],
    ...     ['e']])
    >>> indices = tf.ragged.constant([[1, 2, -1], [], [], [0, 10]])
    >>> batch_gather_with_default(params, indices, 'FOO')
    [['b', 'c', 'FOO'], [], [], ['e', 'FOO']]
    ```
  """
  with ops.name_scope(name, 'RaggedBatchGatherWithDefault'):
    params = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        params, name='params',
    )
    indices = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        indices, name='indices',
    )
    default_value = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        default_value, name='default_value',
    )
    row_splits_dtype, (params, indices, default_value) = (
        ragged_tensor.match_row_splits_dtypes(params, indices, default_value,
                                              return_dtype=True))
    # TODO(hterry): lift this restriction and support default_values of
    #               rank > 1
    if default_value.shape.ndims not in (0, 1):
      raise ValueError('"default_value" must be a scalar or vector')
    upper_bounds = None
    if indices.shape.ndims is None:
      raise ValueError('Indices must have a known rank.')
    if params.shape.ndims is None:
      raise ValueError('Params must have a known rank.')

    num_batch_dimensions = indices.shape.ndims - 1
    pad = None
    # The logic for this works as follows:
    # - create a padded params, where:
    #    padded_params[b1...bn, 0] = default_value
    #    padded_params[b1...bn, i] = params[b1...bn, i-1] (i>0)
    # - create an `upper_bounds` Tensor that contains the number of elements
    #   in each innermost rank. Broadcast `upper_bounds` to be the same shape
    #   as `indices`.
    # - check which indices in `indices` are out of bounds, and substitute
    #   them with the index containing `default_value` (the first).
    # - call batch_gather with the indices adjusted.
    with ops.control_dependencies([
        check_ops.assert_greater_equal(array_ops.rank(params),
                                       array_ops.rank(indices))]):
      if ragged_tensor.is_ragged(params):
        row_lengths = ragged_array_ops.expand_dims(
            params.row_lengths(axis=num_batch_dimensions),
            axis=-1)
        upper_bounds = math_ops.cast(row_lengths, indices.dtype)

        pad_shape = _get_pad_shape(params, indices, row_splits_dtype)

        pad = ragged_tensor_shape.broadcast_to(
            default_value, pad_shape)
      else:
        params_shape = array_ops.shape(params)
        pad_shape = array_ops.concat([
            params_shape[:num_batch_dimensions],
            [1],
            params_shape[num_batch_dimensions + 1:params.shape.ndims]
        ], 0)
        upper_bounds = params_shape[num_batch_dimensions]
        pad = array_ops.broadcast_to(default_value, pad_shape)

      # Add `default_value` as the first value in the innermost (ragged) rank.
      pad = math_ops.cast(pad, params.dtype)
      padded_params = array_ops.concat(
          [pad, params], axis=num_batch_dimensions)

      # Adjust the indices by mapping out-of-bound indices to the
      # default-value index (which is the first element).  Note that an
      # index equal to the row length is also out of bounds.
      shifted_indices = indices + 1
      is_out_of_bounds = (indices < 0) | (indices >= upper_bounds)
      adjusted_indices = ragged_where_op.where(
          is_out_of_bounds,
          x=array_ops.zeros_like(indices), y=shifted_indices,
      )
      return array_ops.batch_gather(
          params=padded_params, indices=adjusted_indices, name=name)
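The pad-and-shift trick above can be reproduced on dense tensors with public ops alone; a minimal sketch (TF 2.x assumed; the data and variable names are illustrative, not part of the library):

```python
import tensorflow as tf

params = tf.constant([['a', 'b', 'c'],
                      ['d', 'e', 'f']])
indices = tf.constant([[1, -1], [0, 99]])
default_value = tf.constant('FOO')

# Prepend the default so that position 0 of each padded row holds it.
pad = tf.fill([2, 1], default_value)
padded_params = tf.concat([pad, params], axis=1)

# Shift valid indices up by one; send invalid indices to position 0.
upper_bounds = tf.shape(params)[1]
is_out_of_bounds = (indices < 0) | (indices >= upper_bounds)
adjusted_indices = tf.where(is_out_of_bounds,
                            tf.zeros_like(indices), indices + 1)

tf.gather(padded_params, adjusted_indices, batch_dims=1)
# => [[b'b', b'FOO'], [b'd', b'FOO']]
```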
Example n. 18
def map_fn(fn,
           elems,
           dtype=None,
           parallel_iterations=None,
           back_prop=True,
           swap_memory=False,
           infer_shape=True,
           name=None):
  """map on the list of tensors unpacked from `elems` on dimension 0.

  The simplest version of `map_fn` repeatedly applies the callable `fn` to a
  sequence of elements from first to last. The elements are made of the
  tensors unpacked from `elems`. `dtype` is the data type of the return
  value of `fn`. Users must provide `dtype` if it is different from
  the data type of `elems`.

  Suppose that `elems` is unpacked into `values`, a list of tensors. The shape
  of the result tensor is `[values.shape[0]] + fn(values[0]).shape`.

  This method also allows multi-arity `elems` and output of `fn`.  If `elems`
  is a (possibly nested) list or tuple of tensors, then each of these tensors
  must have a matching first (unpack) dimension.  The signature of `fn` may
  match the structure of `elems`.  That is, if `elems` is
  `(t1, [t2, t3, [t4, t5]])`, then an appropriate signature for `fn` is:
  `fn = lambda (t1, [t2, t3, [t4, t5]]):`.

  Furthermore, `fn` may emit a different structure than its input.  For example,
  `fn` may look like: `fn = lambda t1: return (t1 + 1, t1 - 1)`.  In this case,
  the `dtype` parameter is not optional: `dtype` must be a type or (possibly
  nested) tuple of types matching the output of `fn`.

  To apply a functional operation to the nonzero elements of a SparseTensor
  one of the following methods is recommended. First, if the function is
  expressible as TensorFlow ops, use

  ```python
    result = SparseTensor(input.indices, fn(input.values), input.dense_shape)
  ```

  If, however, the function is not expressible as a TensorFlow op, then use

  ```python
  result = SparseTensor(
    input.indices, map_fn(fn, input.values), input.dense_shape)
  ```

  instead.

  When executing eagerly, map_fn does not execute in parallel even if
  `parallel_iterations` is set to a value > 1. You can still get the
  performance benefits of running a function in parallel by using the
  `tf.contrib.eager.defun` decorator,

  ```python
  # Assume the function being used in map_fn is fn.
  # To ensure map_fn calls fn in parallel, use the defun decorator.
  @tf.contrib.eager.defun
  def func(tensor):
    return tf.map_fn(fn, tensor)
  ```

  Note that if you use the defun decorator, any non-TensorFlow Python code
  that you may have written in your function won't get executed. See
  `tf.contrib.eager.defun` for more details. The recommendation would be to
  debug without defun but switch to defun to get performance benefits of
  running map_fn in parallel.

  Args:
    fn: The callable to be performed.  It accepts one argument, which will have
      the same (possibly nested) structure as `elems`.  Its output must have the
      same structure as `dtype` if one is provided, otherwise it must have the
      same structure as `elems`.
    elems: A tensor or (possibly nested) sequence of tensors, each of which will
      be unpacked along their first dimension.  The nested sequence of the
      resulting slices will be applied to `fn`.
    dtype: (optional) The output type(s) of `fn`.  If `fn` returns a structure
      of Tensors differing from the structure of `elems`, then `dtype` is not
      optional and must have the same structure as the output of `fn`. Use
      `RaggedTensorType` to declare an output of type `RaggedTensor`.
    parallel_iterations: (optional) The number of iterations allowed to run in
      parallel. When graph building, the default value is 10. While executing
      eagerly, the default value is set to 1.
    back_prop: (optional) True enables support for back propagation.
    swap_memory: (optional) True enables GPU-CPU memory swapping.
    infer_shape: (optional) False disables tests for consistent output shapes.
    name: (optional) Name prefix for the returned tensors.

  Returns:
    A possibly nested sequence of potentially ragged tensors.  Each
    tensor packs the results of applying `fn` to tensors unpacked from `elems`
    along the first dimension, from first to last.

  Raises:
    TypeError: if `fn` is not callable or the structure of the output of
      `fn` and `dtype` do not match, or if elems is a SparseTensor.
    ValueError: if the lengths of the output of `fn` and `dtype` do not match.

  #### Examples:

    ```python
    elems = np.array([1, 2, 3, 4, 5, 6])
    squares = map_fn(lambda x: x * x, elems)
    # squares == [1, 4, 9, 16, 25, 36]
    ```

    ```python
    elems = (np.array([1, 2, 3]), np.array([-1, 1, -1]))
    alternate = map_fn(lambda x: x[0] * x[1], elems, dtype=tf.int64)
    # alternate == [-1, 2, -3]
    ```

    ```python
    elems = np.array([1, 2, 3])
    alternates = map_fn(lambda x: (x, -x), elems, dtype=(tf.int64, tf.int64))
    # alternates[0] == [1, 2, 3]
    # alternates[1] == [-1, -2, -3]
    ```

    ```python
    elems = ragged.constant([[1, 2, 3], [4, 5], [6, 7]])
    mean = map_fn(tf.reduce_mean, elems)
    # mean == [2, 4, 6]
    ```

    ```python
    elems = ragged.constant([[1, 2, 3], [4, 5], [6, 7]], dtype=tf.int64)
    out = map_fn(fn=lambda x: x+1, elems,
      dtype=ragged.RaggedTensorType(type=tf.int64, ragged_rank=0))
    # out = ragged.constant([[2, 3, 4], [5, 6], [7, 8]])
    ```
  """
  if not callable(fn):
    raise TypeError("fn must be callable.")

  if isinstance(elems, sparse_tensor.SparseTensor):
    raise TypeError(
        "To perform a map on the values of a sparse tensor use either "
        " SparseTensor(input.indices, fn(input.values), input.dense_shape) or "
        " SparseTensor(input.indices, map_fn(fn, input.values), "
        "input.dense_shape)")

  in_graph_mode = not context.executing_eagerly()
  # Set the default number of parallel_iterations depending on graph/eager mode.
  if in_graph_mode and not parallel_iterations:
    parallel_iterations = 10
  elif not in_graph_mode and not parallel_iterations:
    parallel_iterations = 1

  if not in_graph_mode and parallel_iterations > 1:
    logging.log_first_n(logging.WARN, "Setting parallel_iterations > 1 has no "
                        "effect when executing eagerly. Consider calling map_fn"
                        " with tf.contrib.eager.defun to execute fn in "
                        "parallel.", 1)
    parallel_iterations = 1

  input_is_sequence = nest.is_sequence(elems)
  input_flatten = lambda x: nest.flatten(x) if input_is_sequence else [x]

  def input_pack(x):
    return nest.pack_sequence_as(elems, x) if input_is_sequence else x[0]

  elems_flat = input_flatten(elems)
  elems_flat = ragged_tensor.match_row_splits_dtypes(*elems_flat)

  with ops.name_scope(name, "map", elems_flat):
    # TODO(akshayka): Remove the in_graph_mode check once caching devices are
    # supported in Eager
    if in_graph_mode:
      # Any get_variable calls in fn will cache the first call locally
      # and not issue repeated network I/O requests for each iteration.
      varscope = vs.get_variable_scope()
      varscope_caching_device_was_none = False
      if varscope.caching_device is None:
        # TODO(ebrevdo): Change to using colocate_with here and in other
        # methods.
        varscope.set_caching_device(lambda op: op.device)
        varscope_caching_device_was_none = True

    elems_flat = [
        ragged_tensor.convert_to_tensor_or_ragged_tensor(elem, name="elem")
        for elem in elems_flat
    ]

    # We can either infer the output, or we can assume that it will be the same
    # as the input structure.
    dtype = dtype or input_pack([elem.dtype for elem in elems_flat])

    # Find the number of iterations; n may be known statically.
    if isinstance(elems_flat[0], ragged_tensor.RaggedTensor):
      n = elems_flat[0].nrows(out_type=dtypes.int32)
    else:
      static_shape = elems_flat[0].shape
      if static_shape.ndims is not None and static_shape.ndims < 1:
        if len(elems_flat) == 1:
          raise ValueError(
              "elems must be a 1+ dimensional Tensor, not a scalar")
        else:
          raise ValueError(
              "elements in elems must be 1+ dimensional Tensors, not scalars")
      n = (tensor_shape.dimension_value(static_shape[0]) or
           array_ops.shape(elems_flat[0])[0])

    n = math_ops.cast(n, dtype=dtypes.int32)
    # Create a flat list of TAs.

    # Flatten the dtype structure to a list.
    dtype_flat = nest.flatten(dtype)

    # decompose to components
    dtype_components = [_maybe_decompose_dtype(d) for d in dtype_flat]
    dtype_components_flat = nest.flatten(dtype_components)

    # Create TensorArrays.
    accs_ta = [
        tensor_array_ops.TensorArray(
            dtype=t, dynamic_size=False, infer_shape=infer_shape, size=n)
        for t in dtype_components_flat
    ]

    i = constant_op.constant(0, dtype=dtypes.int32)

    def compute(i, tas):
      """The loop body of map_fn.

      Args:
        i: the loop counter
        tas: the flat TensorArray accumulator list

      Returns:
        (i + 1, tas): the updated counter + updated TensorArrays

      Raises:
        TypeError: if dtype and packed_fn_values structure do not match
        ValueError: if dtype and packed_fn_values lengths do not match
      """
      # Get the Tensors or RaggedTensors sliced at i, then pack them back
      # into the original structure.
      packed_values = input_pack([elem_flat[i] for elem_flat in elems_flat])
      packed_fn_values = fn(packed_values)

      # Check that the structure of the output matches what was declared or
      # inferred.
      # nest.assert_same_structure(dtype or elems, packed_fn_values)

      # Flatten and decompose to a list of Tensors
      flat_fn_values = nest.flatten(packed_fn_values)

      # If we declared that we expect a RaggedTensor output but `fn` returned
      # a plain Tensor, try to convert it to a RaggedTensor.
      flat_fn_composite_tensors = list(
          _convert_declared(flat_fn_values, dtype_flat))

      flat_fn_components = [
          _maybe_decompose_tensor(t) for t in flat_fn_composite_tensors
      ]
      flat_fn_tensors = nest.flatten(flat_fn_components)

      # Write to TAs.
      tas = [ta.write(i, value) for (ta, value) in zip(tas, flat_fn_tensors)]

      return (i + 1, tas)

    _, r_a = control_flow_ops.while_loop(
        lambda i, _: i < n, compute, (i, accs_ta),
        parallel_iterations=parallel_iterations,
        back_prop=back_prop,
        swap_memory=swap_memory,
        maximum_iterations=n)

    # TODO(akshayka): Remove the in_graph_mode check once caching devices are
    # supported in Eager
    if in_graph_mode and varscope_caching_device_was_none:
      varscope.set_caching_device(None)

    # Pack back into a list of components
    results_as_components = nest.pack_sequence_as(dtype_components, r_a)

    # Stack TensorArrays for Tensor outputs, and concat RaggedTensor outputs.
    def _stack_or_concat(e):
      if isinstance(e, _RaggedTensorComponents):
        return _concat_ragged_tensor_components(e)
      else:
        result = e.stack()
        return result

    results_flat_components = [
        _stack_or_concat(e) for e in results_as_components
    ]

    results_packed = [
        _maybe_recompose_tensor(c) for c in results_flat_components
    ]
    results_packed = nest.pack_sequence_as(dtype, results_packed)
    return results_packed
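At its core, the implementation above is a `while_loop` that writes `fn`'s per-element outputs into `TensorArray` accumulators and stacks them at the end. A stripped-down sketch of that pattern for the dense, single-output case (illustrative only, not the library implementation):

```python
import tensorflow as tf

def mini_map_fn(fn, elems):
  n = tf.shape(elems)[0]
  ta = tf.TensorArray(dtype=elems.dtype, size=n)

  def body(i, ta):
    # Apply fn to the i'th element and accumulate the result.
    return i + 1, ta.write(i, fn(elems[i]))

  _, ta = tf.while_loop(lambda i, _: i < n, body, (tf.constant(0), ta),
                        maximum_iterations=n)
  return ta.stack()

mini_map_fn(lambda x: x * x, tf.constant([1, 2, 3]))  # => [1, 4, 9]
```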
Example n. 19
def boolean_mask(data, mask, keepdims=False, name=None):
  """Applies a boolean mask to `data`.

  Returns a potentially ragged tensor that is formed by retaining the elements
  in `data` where the corresponding value in `mask` is `True`.

  If `keepdims` is true then outer dimensions (corresponding to the `mask`
  dimensions) are preserved, and:

  * `output[a1...aA, i, b1...bB] = data[a1...aA, j, b1...bB]`

     Where `j` is the `i`th `True` entry of `mask[a1...aA]`.

  If `keepdims` is false, then the outer dimensions are collapsed (similar to
  the behavior of `tf.boolean_mask`), and:

  * `output[i, b1...bB] = data[a1...aA, b1...bB]`

     Where `(a1...aA)` is the `i`th `True` entry of `mask`
     (in row-major order).

  Args:
    data: A potentially ragged tensor.
    mask: A potentially ragged boolean tensor.  `mask`'s shape must be a prefix
      of `data`'s shape.  `rank(mask)` must be known statically.
    keepdims: Whether to preserve the outer dimensions (`keepdims=True`) or
      flatten them (`keepdims=False`).
    name: A name prefix for the returned tensor (optional).

  Returns:
    A potentially ragged tensor that is formed by retaining the elements in
    `data` where the corresponding value in `mask` is `True`.

    If `keepdims` is false:

    * `rank(output) = rank(data) - rank(mask) + 1`.
    * `output.ragged_rank = max(data.ragged_rank - rank(mask) + 1, 0)`.

    If `keepdims` is true:

    * `rank(output) = rank(data)`.
    * `output.ragged_rank = max(data.ragged_rank, rank(mask) - 1)`.

  Raises:
    ValueError: if `rank(mask)` is not known statically; or if `mask.shape` is
      not a prefix of `data.shape`.

  #### Examples:
    ```python
    >>> # Aliases for True & False so data and mask line up.
    >>> T, F = (True, False)

    >>> tf.ragged.boolean_mask(  # Mask a 2D Tensor.  Flatten outer dims.
    ...     data=[[1, 2, 3], [4, 5, 6], [7, 8, 9]],
    ...     mask=[[T, F, T], [F, F, F], [T, F, F]],
    ...     keepdims=False).tolist()
    [1, 3, 7]

    >>> tf.ragged.boolean_mask(  # Mask a 2D Tensor.  Preserve outer dims.
    ...     data=[[1, 2, 3], [4, 5, 6], [7, 8, 9]],
    ...     mask=[[T, F, T], [F, F, F], [T, F, F]],
    ...     keepdims=True).tolist()
    [[1, 3], [], [7]]

    >>> tf.ragged.boolean_mask(  # Mask a 2D RaggedTensor.  Flatten outer dims.
    ...     tf.ragged.constant([[1, 2, 3], [4], [5, 6]]),
    ...     tf.ragged.constant([[F, F, T], [F], [T, T]]),
    ...     keepdims=False).tolist()
    [3, 5, 6]

    >>> tf.ragged.boolean_mask(  # Mask a 2D RaggedTensor.  Preserve outer dims.
    ...     tf.ragged.constant([[1, 2, 3], [4], [5, 6]]),
    ...     tf.ragged.constant([[F, F, T], [F], [T, T]]),
    ...     keepdims=True).tolist()
    [[3], [], [5, 6]]

    >>> tf.ragged.boolean_mask(  # Mask rows of a 2D RaggedTensor.
    ...     tf.ragged.constant([[1, 2, 3], [4], [5, 6]]),
    ...     tf.ragged.constant([True, False, True]),
    ...     keepdims=True).tolist()
    [[1, 2, 3], [5, 6]]
    ```
  """
  with ops.name_scope(name, 'RaggedMask', [data, mask]):
    # Convert inputs to tensors.
    data = ragged_tensor.convert_to_tensor_or_ragged_tensor(data, name='data')
    mask = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        mask, dtypes.bool, name='mask')
    row_splits_dtype, (data, mask) = ragged_tensor.match_row_splits_dtypes(
        data, mask, return_dtype=True)

    # Get static rank of mask.
    if mask.shape.ndims is None:
      raise ValueError('mask.shape.ndims must be known statically.')
    elif mask.shape.ndims == 0:
      raise ValueError('mask cannot be scalar.')

    # If mask is ragged, then recurse with a non-ragged mask.
    if ragged_tensor.is_ragged(mask):
      if not ragged_tensor.is_ragged(data):
        data = ragged_tensor.RaggedTensor.from_tensor(
            data, ragged_rank=mask.ragged_rank,
            row_splits_dtype=mask.row_splits.dtype)
      # Check that mask.nested_row_splits is a prefix of
      # data.nested_row_splits.
      splits_list = [
          mask.nested_row_splits, data.nested_row_splits[:mask.ragged_rank]
      ]
      with ops.control_dependencies(
          ragged_util.assert_splits_match(splits_list)):
        # Strip off ragged `splits` until `mask` is non-ragged.  Keep the splits
        # that we strip off in `splits`, so we can add them back on after
        # we recursively mask the non-ragged data.
        splits = []
        while ragged_tensor.is_ragged(mask):
          if mask.shape.ndims > 2:
            splits.append(mask.row_splits)
          else:
            # Count the number of True mask values in each row to find the
            # lengths of the filtered rows; then convert to splits.
            int_mask = ragged_functional_ops.map_flat_values(
                math_ops.cast, mask, dtype=row_splits_dtype)
            masked_row_lengths = ragged_math_ops.reduce_sum(int_mask, axis=1)
            splits.append(ragged_util.lengths_to_splits(masked_row_lengths))
          mask = mask.values
          data = data.values

        # Recursively apply the nested non-ragged mask to the nested data.
        masked_values = boolean_mask(data, mask, keepdims)

        # Add the ragged `splits` back to the result.
        if keepdims:
          masked_values = ragged_tensor.RaggedTensor.from_nested_row_splits(
              masked_values, splits, validate=False)

        return masked_values

    # If mask is non-ragged and has rank 1, and data is ragged, then build a
    # ragged tensor with the indicated rows.
    elif ragged_tensor.is_ragged(data) and mask.shape.ndims == 1:
      # Get the masked splits: first get the length of each row, then filter
      # out the rows that we are deleting, and convert that filtered set of
      # masks back to a splits tensor.
      lengths = data.row_lengths()
      masked_lengths = array_ops.boolean_mask(lengths, mask)
      masked_splits = ragged_util.lengths_to_splits(masked_lengths)

      # Get the masked values: first get row ids corresponding to each
      # value, then use tf.gather to build a boolean mask that's false for
      # values that come from rows that we are deleting, and use that mask to
      # construct the masked values tensor.
      segment_ids = segment_id_ops.row_splits_to_segment_ids(data.row_splits)
      segment_mask = array_ops.gather(mask, segment_ids)
      masked_values = boolean_mask(data.values, segment_mask, keepdims=False)

      return ragged_tensor.RaggedTensor.from_row_splits(masked_values,
                                                        masked_splits,
                                                        validate=False)

    # If mask is non-ragged and has rank>1, then convert it to be ragged,
    # with a ragged rank matching data.
    if ragged_tensor.is_ragged(data):
      mask = ragged_tensor.RaggedTensor.from_tensor(
          mask, ragged_rank=min(data.ragged_rank, mask.shape.ndims - 1),
          row_splits_dtype=data.row_splits.dtype)
      return boolean_mask(data, mask, keepdims)

    # Otherwise, data and mask are both `Tensor`s.
    else:
      # Apply `boolean_mask` to get the masked values.
      masked_values = array_ops.boolean_mask(data, mask)

      if mask.shape.ndims >= 2 and keepdims:
        # Add the innermost ragged dimension.  For each innermost cell, get the
        # number of values it contains.  Then flatten that to get a list of
        # cell lengths, and convert it to splits.  Finally, combine the splits
        # and values to get the innermost ragged tensor.
        masked_lengths = math_ops.count_nonzero(mask, axis=-1,
                                                dtype=row_splits_dtype)
        flattened_masked_lengths = array_ops.reshape(masked_lengths, [-1])
        masked_values = ragged_tensor.RaggedTensor.from_row_lengths(
            masked_values, flattened_masked_lengths, validate=False)

        # Wrap remaining ragged dimensions.
        if mask.shape.ndims > 2 and keepdims:
          mask_shape = array_ops.shape(mask, out_type=row_splits_dtype)
          split_size = math_ops.cumprod(mask_shape) + 1
          for dim in range(mask.shape.ndims - 3, -1, -1):
            elt_size = mask_shape[dim + 1]
            masked_splits = math_ops.range(split_size[dim]) * elt_size
            masked_values = ragged_tensor.RaggedTensor.from_row_splits(
                masked_values, masked_splits, validate=False)

      return masked_values
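The rank-1-mask-over-ragged-data branch above can also be reproduced with public ops; a small sketch (TF 2.x assumed; data illustrative):

```python
import tensorflow as tf

data = tf.ragged.constant([[1, 2, 3], [4], [5, 6]])
mask = tf.constant([True, False, True])

# Row id of each flat value: [0, 0, 0, 1, 2, 2]
segment_ids = data.value_rowids()
# Keep a value iff its row is kept.
segment_mask = tf.gather(mask, segment_ids)

masked_values = tf.boolean_mask(data.values, segment_mask)  # [1 2 3 5 6]
masked_lengths = tf.boolean_mask(data.row_lengths(), mask)  # [3 2]
tf.RaggedTensor.from_row_lengths(masked_values, masked_lengths)
# => [[1, 2, 3], [5, 6]]
```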
Example n. 20
def where(condition, x=None, y=None, name=None):
  """Return the elements, either from `x` or `y`, depending on the `condition`.

  : If both `x` and `y` are `None`:
    Returns the coordinates of true elements of `condition`. The coordinates
    are returned in a 2-D tensor with shape
    `[num_true_values, rank(condition)]`, where `result[i]` is the
    coordinates of the `i`th true value (in row-major order).

  : If both `x` and `y` are non-`None`:
    Returns a tensor formed by selecting values from `x` where condition is
    true, and from `y` when condition is false.  In particular:

    : If `condition`, `x`, and `y` all have the same shape:

      * `result[i1...iN] = x[i1...iN]` if `condition[i1...iN]` is true.
      * `result[i1...iN] = y[i1...iN]` if `condition[i1...iN]` is false.

    : Otherwise:

      * `condition` must be a vector.
      * `x` and `y` must have the same number of dimensions.
      * The outermost dimensions of `condition`, `x`, and `y` must all have the
        same size.
      * `result[i] = x[i]` if `condition[i]` is true.
      * `result[i] = y[i]` if `condition[i]` is false.

  Args:
    condition: A potentially ragged tensor of type `bool`.
    x: A potentially ragged tensor (optional).
    y: A potentially ragged tensor (optional).  Must be specified if `x` is
      specified.  Must have the same rank and type as `x`.
    name: A name for the operation (optional).

  Returns:
    : If both `x` and `y` are `None`:
      A `Tensor` with shape `(num_true, rank(condition))`.
    : Otherwise:
      A potentially ragged tensor with the same type, rank, and outermost
      dimension size as `x` and `y`.
      `result.ragged_rank = max(x.ragged_rank, y.ragged_rank)`.

  Raises:
    ValueError: When exactly one of `x` or `y` is non-`None`; or when
      `condition`, `x`, and `y` have incompatible shapes.

  #### Examples:

  >>> # Coordinates where condition is true.
  >>> condition = tf.ragged.constant([[True, False, True], [False, True]])
  >>> print(where(condition))
  tf.Tensor(
  [[0 0]
   [0 2]
   [1 1]], shape=(3, 2), dtype=int64)

  >>> # Elementwise selection between x and y, based on condition.
  >>> condition = tf.ragged.constant([[True, False, True], [False, True]])
  >>> x = tf.ragged.constant([['A', 'B', 'C'], ['D', 'E']])
  >>> y = tf.ragged.constant([['a', 'b', 'c'], ['d', 'e']])
  >>> print(where(condition, x, y))
  <tf.RaggedTensor [[b'A', b'b', b'C'], [b'd', b'E']]>

  >>> # Row selection between x and y, based on condition.
  >>> condition = [True, False]
  >>> x = tf.ragged.constant([['A', 'B', 'C'], ['D', 'E']])
  >>> y = tf.ragged.constant([['a', 'b', 'c'], ['d', 'e']])
  >>> print(where(condition, x, y))
  <tf.RaggedTensor [[b'A', b'B', b'C'], [b'd', b'e']]>
  """
  if (x is None) != (y is None):
    raise ValueError('x and y must be either both None or both non-None')
  with ops.name_scope('RaggedWhere', name, [condition, x, y]):
    condition = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        condition, name='condition')
    if x is None:
      return _coordinate_where(condition)
    else:
      x = ragged_tensor.convert_to_tensor_or_ragged_tensor(x, name='x')
      y = ragged_tensor.convert_to_tensor_or_ragged_tensor(y, name='y')
      condition, x, y = ragged_tensor.match_row_splits_dtypes(condition, x, y)
      return _elementwise_where(condition, x, y)
Example n. 21
def where(condition, x=None, y=None, name=None):
  """Return the elements, either from `x` or `y`, depending on the `condition`.

  : If both `x` and `y` are `None`:
    Returns the coordinates of true elements of `condition`. The coordinates
    are returned in a 2-D tensor with shape
    `[num_true_values, rank(condition)]`, where `result[i]` is the
    coordinates of the `i`th true value (in row-major order).

  : If both `x` and `y` are non-`None`:
    Returns a tensor formed by selecting values from `x` where condition is
    true, and from `y` when condition is false.  In particular:

    : If `condition`, `x`, and `y` all have the same shape:

      * `result[i1...iN] = x[i1...iN]` if `condition[i1...iN]` is true.
      * `result[i1...iN] = y[i1...iN]` if `condition[i1...iN]` is false.

    : Otherwise:

      * `condition` must be a vector.
      * `x` and `y` must have the same number of dimensions.
      * The outermost dimensions of `condition`, `x`, and `y` must all have the
        same size.
      * `result[i] = x[i]` if `condition[i]` is true.
      * `result[i] = y[i]` if `condition[i]` is false.

  Args:
    condition: A potentially ragged tensor of type `bool`.
    x: A potentially ragged tensor (optional).
    y: A potentially ragged tensor (optional).  Must be specified if `x` is
      specified.  Must have the same rank and type as `x`.
    name: A name for the operation (optional).

  Returns:
    : If both `x` and `y` are `None`:
      A `Tensor` with shape `(num_true, rank(condition))`.
    : Otherwise:
      A potentially ragged tensor with the same type, rank, and outermost
      dimension size as `x` and `y`.
      `result.ragged_rank = max(x.ragged_rank, y.ragged_rank)`.

  Raises:
    ValueError: When exactly one of `x` or `y` is non-`None`; or when
      `condition`, `x`, and `y` have incompatible shapes.

  #### Examples:
    ```python
    >>> # Coordinates where condition is true.
    >>> condition = tf.compat.v1.ragged.constant_value(
    ...     [[True, False, True], [False, True]])
    >>> ragged.where(condition)
    [[0, 0], [0, 2], [1, 1]]

    >>> # Elementwise selection between x and y, based on condition.
    >>> condition = tf.compat.v1.ragged.constant_value(
    ...     [[True, False, True], [False, True]])
    >>> x = tf.compat.v1.ragged.constant_value([['A', 'B', 'C'], ['D', 'E']])
    >>> y = tf.compat.v1.ragged.constant_value([['a', 'b', 'c'], ['d', 'e']])
    >>> ragged.where(condition, x, y)
    [['A', 'b', 'C'], ['d', 'E']]

    >>> # Row selection between x and y, based on condition.
    >>> condition = [True, False]
    >>> x = tf.compat.v1.ragged.constant_value([['A', 'B', 'C'], ['D', 'E']])
    >>> y = tf.compat.v1.ragged.constant_value([['a', 'b', 'c'], ['d', 'e']])
    >>> ragged.where(condition, x, y)
    [['A', 'B', 'C'], ['d', 'e']]
    ```
  """
  if (x is None) != (y is None):
    raise ValueError('x and y must be either both None or both non-None')
  with ops.name_scope('RaggedWhere', name, [condition, x, y]):
    condition = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        condition, name='condition')
    if x is None:
      return _coordinate_where(condition)
    else:
      x = ragged_tensor.convert_to_tensor_or_ragged_tensor(x, name='x')
      y = ragged_tensor.convert_to_tensor_or_ragged_tensor(y, name='y')
      condition, x, y = ragged_tensor.match_row_splits_dtypes(condition, x, y)
      return _elementwise_where(condition, x, y)
Example n. 22
def gather(params, indices, validate_indices=None, axis=0, batch_dims=0,
           name=None):
  """Gathers ragged slices from `params` axis `0` according to `indices`.

  Returns `RaggedTensor` output, such that:

  ```python
  output.shape = indices.shape + params.shape[1:]
  output.ragged_rank = indices.shape.ndims + params.ragged_rank
  output[i...j, d0...dn] = params[indices[i...j], d0...dn]
  ```

  `params` may be ragged.  `indices` may be ragged.
  `indices` must have dtype `int32` or `int64`. If any index is out of bounds,
  then an error is returned.

  Examples:

  ```python
  >>> params = tf.constant(['a', 'b', 'c', 'd', 'e'])
  >>> indices = tf.constant([3, 1, 2, 1, 0])
  >>> ragged_params = tf.ragged.constant([['a', 'b', 'c'], ['d'], [], ['e']])
  >>> ragged_indices = tf.ragged.constant([[3, 1, 2], [1], [], [0]])

  >>> print(ragged.gather(params, ragged_indices))
  [['d', 'b', 'c'], ['b'], [], ['a']]

  >>> print(ragged.gather(ragged_params, indices))
  [['e'], ['d'], [], ['d'], ['a', 'b', 'c']]

  >>> print(ragged.gather(ragged_params, ragged_indices))
  [[['e'], ['d'], []], [['d']], [], [['a', 'b', 'c']]]
  ```

  Args:
    params: The potentially ragged tensor from which to gather values. Must be
      at least rank 1.
    indices: The potentially ragged tensor indicating which values to gather.
      Must have dtype `int32` or `int64`.  Values must be in the range
      `[0, params.shape[0])`.
    validate_indices: Ignored.
    axis: Must be zero.
    batch_dims: Must be zero.
    name: A name for the operation (optional).

  Returns:
    A `RaggedTensor`, where `output.dtype=params.dtype` and
    `output.shape=indices.shape + params.shape[1:]` and
    `output.ragged_rank=indices.shape.ndims + params.ragged_rank`.

  Raises:
    ValueError: If indices.shape.ndims is not known statically.
  """
  del validate_indices
  if not isinstance(axis, int) or axis != 0:
    raise ValueError('axis != 0 is not supported for ragged gather yet.')
  if not isinstance(batch_dims, int) or batch_dims != 0:
    raise ValueError('batch_dims != 0 is not supported for ragged gather yet.')
  with ops.name_scope(name, 'RaggedGather', [params, indices]):
    params = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        params, name='params')
    indices = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        indices, name='indices')
    params, indices = ragged_tensor.match_row_splits_dtypes(params, indices)

    if ragged_tensor.is_ragged(indices):
      return indices.with_values(gather(params, indices.values))

    if not ragged_tensor.is_ragged(params):
      return array_ops.gather(params, indices)

    indices = ops.convert_to_tensor(indices)
    if indices.shape.ndims is None:
      raise ValueError('indices.shape.ndims must be known statically')

    result = gen_ragged_array_ops.ragged_gather(
        indices=indices,
        params_dense_values=params.flat_values,
        params_nested_splits=params.nested_row_splits,
        OUTPUT_RAGGED_RANK=(indices.shape.ndims +
                            len(params.nested_row_splits) - 1))

    # Compose the RaggedTensor from splits & values.
    return ragged_tensor.RaggedTensor.from_nested_row_splits(
        result.output_dense_values, result.output_nested_splits, validate=False)
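The ragged-indices case above reduces to a gather on the flat index values, re-wrapped with the indices' own row partition; a short sketch with public ops (data illustrative):

```python
import tensorflow as tf

params = tf.constant(['a', 'b', 'c', 'd', 'e'])
ragged_indices = tf.ragged.constant([[3, 1, 2], [1], [], [0]])

# Gather on the flat values, then restore the ragged row partition.
ragged_indices.with_values(tf.gather(params, ragged_indices.values))
# => [[b'd', b'b', b'c'], [b'b'], [], [b'a']]
```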
Example n. 23
def map_fn(fn,
           elems,
           dtype=None,
           parallel_iterations=None,
           back_prop=True,
           swap_memory=False,
           infer_shape=True,
           name=None):
    """map on the list of tensors unpacked from `elems` on dimension 0.

  The simplest version of `map_fn` repeatedly applies the callable `fn` to a
  sequence of elements from first to last. The elements are made of the
  tensors unpacked from `elems`. `dtype` is the data type of the return
  value of `fn`. Users must provide `dtype` if it is different from
  the data type of `elems`.

  Suppose that `elems` is unpacked into `values`, a list of tensors. The shape
  of the result tensor is `[values.shape[0]] + fn(values[0]).shape`.

  This method also allows multi-arity `elems` and output of `fn`.  If `elems`
  is a (possibly nested) list or tuple of tensors, then each of these tensors
  must have a matching first (unpack) dimension.  The signature of `fn` may
  match the structure of `elems`.  That is, if `elems` is
  `(t1, [t2, t3, [t4, t5]])`, then an appropriate signature for `fn` is:
  `fn = lambda (t1, [t2, t3, [t4, t5]]):`.

  Furthermore, `fn` may emit a different structure than its input.  For example,
  `fn` may look like: `fn = lambda t1: return (t1 + 1, t1 - 1)`.  In this case,
  the `dtype` parameter is not optional: `dtype` must be a type or (possibly
  nested) tuple of types matching the output of `fn`.

  To apply a functional operation to the nonzero elements of a SparseTensor
  one of the following methods is recommended. First, if the function is
  expressible as TensorFlow ops, use

  ```python
    result = SparseTensor(input.indices, fn(input.values), input.dense_shape)
  ```

  If, however, the function is not expressible as a TensorFlow op, then use

  ```python
  result = SparseTensor(
    input.indices, map_fn(fn, input.values), input.dense_shape)
  ```

  instead.

  When executing eagerly, map_fn does not execute in parallel even if
  `parallel_iterations` is set to a value > 1. You can still get the
  performance benefits of running a function in parallel by using the
  `tf.contrib.eager.defun` decorator,

  ```python
  # Assume the function being used in map_fn is fn.
  # To ensure map_fn calls fn in parallel, use the defun decorator.
  @tf.contrib.eager.defun
  def func(tensor):
    return tf.map_fn(fn, tensor)
  ```

  Note that if you use the defun decorator, any non-TensorFlow Python code
  that you may have written in your function won't get executed. See
  `tf.contrib.eager.defun` for more details. The recommendation would be to
  debug without defun but switch to defun to get performance benefits of
  running map_fn in parallel.

  Args:
    fn: The callable to be performed.  It accepts one argument, which will have
      the same (possibly nested) structure as `elems`.  Its output must have the
      same structure as `dtype` if one is provided, otherwise it must have the
      same structure as `elems`.
    elems: A tensor or (possibly nested) sequence of tensors, each of which will
      be unpacked along their first dimension.  The nested sequence of the
      resulting slices will be applied to `fn`.
    dtype: (optional) The output type(s) of `fn`.  If `fn` returns a structure
      of Tensors differing from the structure of `elems`, then `dtype` is not
      optional and must have the same structure as the output of `fn`. Use
      `RaggedTensorType` to declare an output of type `RaggedTensor`.
    parallel_iterations: (optional) The number of iterations allowed to run in
      parallel. When graph building, the default value is 10. While executing
      eagerly, the default value is set to 1.
    back_prop: (optional) True enables support for back propagation.
    swap_memory: (optional) True enables GPU-CPU memory swapping.
    infer_shape: (optional) False disables tests for consistent output shapes.
    name: (optional) Name prefix for the returned tensors.

  Returns:
    A possibly nested sequence of potentially ragged tensors.  Each
    tensor packs the results of applying `fn` to tensors unpacked from `elems`
    along the first dimension, from first to last.

  Raises:
    TypeError: if `fn` is not callable or the structure of the output of
      `fn` and `dtype` do not match, or if elems is a SparseTensor.
    ValueError: if the lengths of the output of `fn` and `dtype` do not match.

  #### Examples:

    ```python
    elems = np.array([1, 2, 3, 4, 5, 6])
    squares = map_fn(lambda x: x * x, elems)
    # squares == [1, 4, 9, 16, 25, 36]
    ```

    ```python
    elems = (np.array([1, 2, 3]), np.array([-1, 1, -1]))
    alternate = map_fn(lambda x: x[0] * x[1], elems, dtype=tf.int64)
    # alternate == [-1, 2, -3]
    ```

    ```python
    elems = np.array([1, 2, 3])
    alternates = map_fn(lambda x: (x, -x), elems, dtype=(tf.int64, tf.int64))
    # alternates[0] == [1, 2, 3]
    # alternates[1] == [-1, -2, -3]
    ```

    ```python
    elems = ragged.constant([[1, 2, 3], [4, 5], [6, 7]])
    mean = map_fn(tf.reduce_mean, elems)
    # mean == [2, 4, 6]
    ```

    ```python
    elems = ragged.constant([[1, 2, 3], [4, 5], [6, 7]], dtype=tf.int64)
    out = map_fn(fn=lambda x: x+1, elems,
      dtype=ragged.RaggedTensorType(type=tf.int64, ragged_rank=0))
    # out = tf.ragged.constant([[2, 3, 4], [5, 6], [7, 8]])
    ```
  """
    if not callable(fn):
        raise TypeError("fn must be callable.")

    if isinstance(elems, sparse_tensor.SparseTensor):
        raise TypeError(
            "To perform a map on the values of a sparse tensor use either "
            " SparseTensor(input.indices, fn(input.values), input.dense_shape) or "
            " SparseTensor(input.indices, map_fn(fn, input.values), "
            "input.dense_shape)")

    in_graph_mode = not context.executing_eagerly()
    # Set the default number of parallel_iterations depending on graph/eager mode.
    if in_graph_mode and not parallel_iterations:
        parallel_iterations = 10
    elif not in_graph_mode and not parallel_iterations:
        parallel_iterations = 1

    if not in_graph_mode and parallel_iterations > 1:
        logging.log_first_n(
            logging.WARN, "Setting parallel_iterations > 1 has no "
            "effect when executing eagerly. Consider calling map_fn"
            " with tf.contrib.eager.defun to execute fn in "
            "parallel.", 1)
        parallel_iterations = 1

    input_is_sequence = nest.is_sequence(elems)
    input_flatten = lambda x: nest.flatten(x) if input_is_sequence else [x]

    def input_pack(x):
        return nest.pack_sequence_as(elems, x) if input_is_sequence else x[0]

    elems_flat = input_flatten(elems)
    elems_flat = ragged_tensor.match_row_splits_dtypes(*elems_flat)

    with ops.name_scope(name, "map", elems_flat):
        # TODO(akshayka): Remove the in_graph_mode check once caching devices are
        # supported in Eager
        if in_graph_mode:
            # Any get_variable calls in fn will cache the first call locally
            # and not issue repeated network I/O requests for each iteration.
            varscope = vs.get_variable_scope()
            varscope_caching_device_was_none = False
            if varscope.caching_device is None:
                # TODO(ebrevdo): Change to using colocate_with here and in other
                # methods.
                varscope.set_caching_device(lambda op: op.device)
                varscope_caching_device_was_none = True

        elems_flat = [
            ragged_tensor.convert_to_tensor_or_ragged_tensor(elem, name="elem")
            for elem in elems_flat
        ]

        # We can either infer the output, or we can assume that it will be the same
        # as the input structure.
        dtype = dtype or input_pack([elem.dtype for elem in elems_flat])

        # Find the number of iterations; n may be known statically.
        if isinstance(elems_flat[0], ragged_tensor.RaggedTensor):
            n = elems_flat[0].nrows(out_type=dtypes.int32)
        else:
            static_shape = elems_flat[0].shape
            if static_shape.ndims is not None and static_shape.ndims < 1:
                if len(elems_flat) == 1:
                    raise ValueError(
                        "elems must be a 1+ dimensional Tensor, not a scalar")
                else:
                    raise ValueError(
                        "elements in elems must be 1+ dimensional Tensors, not scalars"
                    )
            n = (tensor_shape.dimension_value(static_shape[0])
                 or array_ops.shape(elems_flat[0])[0])

        n = math_ops.cast(n, dtype=dtypes.int32)
        # Create a flat list of TAs.

        # Flatten the dtype structure to a list.
        dtype_flat = nest.flatten(dtype)

        # decompose to components
        dtype_components = [_maybe_decompose_dtype(d) for d in dtype_flat]
        dtype_components_flat = nest.flatten(dtype_components)

        # Create TensorArrays.
        accs_ta = [
            tensor_array_ops.TensorArray(dtype=t,
                                         dynamic_size=False,
                                         infer_shape=infer_shape,
                                         size=n) for t in dtype_components_flat
        ]

        i = constant_op.constant(0, dtype=dtypes.int32)

        def compute(i, tas):
            """The loop body of map_fn.

      Args:
        i: the loop counter
        tas: the flat TensorArray accumulator list

      Returns:
        (i + 1, tas): the updated counter + updated TensorArrays

      Raises:
        TypeError: if dtype and packed_fn_values structure do not match
        ValueError: if dtype and packed_fn_values lengths do not match
      """
            # Get the Tensors or RaggedTensors sliced at i, then pack them
            # back into the original structure.
            packed_values = input_pack(
                [elem_flat[i] for elem_flat in elems_flat])
            packed_fn_values = fn(packed_values)

            # Check that the structure of the output matches what was declared or
            # inferred.
            # nest.assert_same_structure(dtype or elems, packed_fn_values)

            # Flatten and decompose to a list of Tensors
            flat_fn_values = nest.flatten(packed_fn_values)

            # If we declared that we expect a RaggedTensor output but `fn`
            # returned a plain Tensor, try to convert it to a RaggedTensor.
            flat_fn_composite_tensors = list(
                _convert_declared(flat_fn_values, dtype_flat))

            flat_fn_components = [
                _maybe_decompose_tensor(t) for t in flat_fn_composite_tensors
            ]
            flat_fn_tensors = nest.flatten(flat_fn_components)

            # Write to TAs.
            tas = [
                ta.write(i, value)
                for (ta, value) in zip(tas, flat_fn_tensors)
            ]

            return (i + 1, tas)

        _, r_a = control_flow_ops.while_loop(
            lambda i, _: i < n,
            compute, (i, accs_ta),
            parallel_iterations=parallel_iterations,
            back_prop=back_prop,
            swap_memory=swap_memory,
            maximum_iterations=n)

        # TODO(akshayka): Remove the in_graph_mode check once caching devices are
        # supported in Eager
        if in_graph_mode and varscope_caching_device_was_none:
            varscope.set_caching_device(None)

        # Pack back into a list of components
        results_as_components = nest.pack_sequence_as(dtype_components, r_a)

        # Stack TensorArrays for Tensor outputs, and concat RaggedTensor outputs.
        def _stack_or_concat(e):
            if isinstance(e, _RaggedTensorComponents):
                return _concat_ragged_tensor_components(e)
            else:
                result = e.stack()
                return result

        results_flat_components = [
            _stack_or_concat(e) for e in results_as_components
        ]

        results_packed = [
            _maybe_recompose_tensor(c) for c in results_flat_components
        ]
        results_packed = nest.pack_sequence_as(dtype, results_packed)
        return results_packed
Example n. 24
def boolean_mask(data, mask, name=None):
  """Applies a boolean mask to `data` without flattening the mask dimensions.

  Returns a potentially ragged tensor that is formed by retaining the elements
  in `data` where the corresponding value in `mask` is `True`.

  * `output[a1...aA, i, b1...bB] = data[a1...aA, j, b1...bB]`

     Where `j` is the `i`th `True` entry of `mask[a1...aA]`.

  Note that `output` preserves the mask dimensions `a1...aA`; this differs
  from `tf.boolean_mask`, which flattens those dimensions.

  Args:
    data: A potentially ragged tensor.
    mask: A potentially ragged boolean tensor.  `mask`'s shape must be a prefix
      of `data`'s shape.  `rank(mask)` must be known statically.
    name: A name prefix for the returned tensor (optional).

  Returns:
    A potentially ragged tensor that is formed by retaining the elements in
    `data` where the corresponding value in `mask` is `True`.

    * `rank(output) = rank(data)`.
    * `output.ragged_rank = max(data.ragged_rank, rank(mask) - 1)`.

  Raises:
    ValueError: if `rank(mask)` is not known statically; or if `mask.shape` is
      not a prefix of `data.shape`.

  #### Examples:

  >>> # Aliases for True & False so data and mask line up.
  >>> T, F = (True, False)

  >>> tf.ragged.boolean_mask(  # Mask a 2D Tensor.
  ...     data=[[1, 2, 3], [4, 5, 6], [7, 8, 9]],
  ...     mask=[[T, F, T], [F, F, F], [T, F, F]]).to_list()
  [[1, 3], [], [7]]

  >>> tf.ragged.boolean_mask(  # Mask a 2D RaggedTensor.
  ...     tf.ragged.constant([[1, 2, 3], [4], [5, 6]]),
  ...     tf.ragged.constant([[F, F, T], [F], [T, T]])).to_list()
  [[3], [], [5, 6]]

  >>> tf.ragged.boolean_mask(  # Mask rows of a 2D RaggedTensor.
  ...     tf.ragged.constant([[1, 2, 3], [4], [5, 6]]),
  ...     tf.ragged.constant([True, False, True])).to_list()
  [[1, 2, 3], [5, 6]]
  """
  with ops.name_scope(name, 'RaggedMask', [data, mask]):
    # Convert inputs to tensors.
    data = ragged_tensor.convert_to_tensor_or_ragged_tensor(data, name='data')
    mask = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        mask, dtypes.bool, name='mask')
    row_splits_dtype, (data, mask) = ragged_tensor.match_row_splits_dtypes(
        data, mask, return_dtype=True)

    # Get static rank of mask.
    if mask.shape.ndims is None:
      raise ValueError('mask.shape.ndims must be known statically.')
    elif mask.shape.ndims == 0:
      raise ValueError('mask cannot be scalar.')

    # If mask is ragged, then recurse with a non-ragged mask.
    if ragged_tensor.is_ragged(mask):
      if not ragged_tensor.is_ragged(data):
        data = ragged_tensor.RaggedTensor.from_tensor(
            data,
            ragged_rank=mask.ragged_rank,
            row_splits_dtype=mask.row_splits.dtype)
      # Check that mask.nested_row_splits is a prefix of
      # data.nested_row_splits.
      splits_list = [
          mask.nested_row_splits, data.nested_row_splits[:mask.ragged_rank]
      ]
      with ops.control_dependencies(
          ragged_util.assert_splits_match(splits_list)):
        # Strip off ragged `splits` until `mask` is non-ragged.  Keep the splits
        # that we strip off in `splits`, so we can add them back on after
        # we recursively mask the non-ragged data.
        splits = []
        while ragged_tensor.is_ragged(mask):
          if mask.shape.ndims > 2:
            splits.append(mask.row_splits)
          else:
            # Count the number of True mask values in each row to find the
            # lengths of the filtered rows; then convert to splits.
            int_mask = ragged_functional_ops.map_flat_values(
                math_ops.cast, mask, dtype=row_splits_dtype)
            masked_row_lengths = ragged_math_ops.reduce_sum(int_mask, axis=1)
            splits.append(ragged_util.lengths_to_splits(masked_row_lengths))
          mask = mask.values
          data = data.values

        # Recursively apply the nested non-ragged mask to the nested data.
        masked_values = boolean_mask(data, mask)

        # Add the ragged `splits` back to the result.
        masked_values = ragged_tensor.RaggedTensor.from_nested_row_splits(
            masked_values, splits, validate=False)

        return masked_values

    # If mask is non-ragged and has rank 1, and data is ragged, then build a
    # ragged tensor with the indicated rows.
    elif ragged_tensor.is_ragged(data) and mask.shape.ndims == 1:
      # Get the masked splits: first get the length of each row, then filter
      # out the rows that we are deleting, and convert that filtered set of
      # masks back to a splits tensor.
      lengths = data.row_lengths()
      masked_lengths = array_ops.boolean_mask(lengths, mask)
      masked_splits = ragged_util.lengths_to_splits(masked_lengths)

      # Get the masked values: first get row ids corresponding to each
      # value, then use tf.gather to build a boolean mask that's false for
      # values that come from rows that we are deleting, and use that mask to
      # construct the masked values tensor.
      segment_ids = segment_id_ops.row_splits_to_segment_ids(data.row_splits)
      segment_mask = array_ops.gather(mask, segment_ids)
      masked_values = boolean_mask(data.values, segment_mask)

      return ragged_tensor.RaggedTensor.from_row_splits(
          masked_values, masked_splits, validate=False)

    # If mask is non-ragged and has rank>1, then convert it to be ragged,
    # with a ragged rank matching data.
    if ragged_tensor.is_ragged(data):
      mask = ragged_tensor.RaggedTensor.from_tensor(
          mask,
          ragged_rank=min(data.ragged_rank, mask.shape.ndims - 1),
          row_splits_dtype=data.row_splits.dtype)
      return boolean_mask(data, mask)

    # Otherwise, data and mask are both `Tensor`s.
    else:
      # Apply `boolean_mask` to get the masked values.
      masked_values = array_ops.boolean_mask(data, mask)

      if mask.shape.ndims >= 2:
        # Add the innermost ragged dimension.  For each innermost cell, get the
        # number of values it contains.  Then flatten that to get a list of
        # cell lengths, and convert it to splits.  Finally, combine the splits
        # and values to get the innermost ragged tensor.
        masked_lengths = math_ops.count_nonzero(
            mask, axis=-1, dtype=row_splits_dtype)
        flattened_masked_lengths = array_ops.reshape(masked_lengths, [-1])
        masked_values = ragged_tensor.RaggedTensor.from_row_lengths(
            masked_values, flattened_masked_lengths, validate=False)

        # Wrap remaining ragged dimensions.
        if mask.shape.ndims > 2:
          mask_shape = array_ops.shape(mask, out_type=row_splits_dtype)
          split_size = math_ops.cumprod(mask_shape) + 1
          for dim in range(mask.shape.ndims - 3, -1, -1):
            elt_size = mask_shape[dim + 1]
            masked_splits = math_ops.range(split_size[dim]) * elt_size
            masked_values = ragged_tensor.RaggedTensor.from_row_splits(
                masked_values, masked_splits, validate=False)

      return masked_values
Example n. 25
def _ragged_segment_aggregate(unsorted_segment_op,
                              data,
                              segment_ids,
                              num_segments,
                              name=None):
  """Aggregates along segments of a RaggedTensor using `unsorted_segment_op`.

  Returns a RaggedTensor `output` with `num_segments` rows, where the row
  `output[i]` is formed by combining all rows of `data` whose corresponding
  `segment_id` is `i`.  The values in each row are combined using
  `unsorted_segment_op`.

  The length of the row `output[i]` will be the maximum of the lengths of
  all rows of `data` whose corresponding `segment_id` is `i`.  If no `data`
  rows correspond to a given segment ID, then the output row for that segment
  ID will be empty.

  Args:
    unsorted_segment_op: The tensorflow `op` that should be used to combine
      values in each row.  Must have the same signature and basic behavior as
      `unsorted_segment_sum`, `unsorted_segment_max`, etc.
    data: A `RaggedTensor` containing the values to be combined.
    segment_ids: A `Tensor` or `RaggedTensor`.  Must have type `int64` or
      `int32`.  `segment_ids.shape` must be a prefix of `data.shape`.
      `segment_ids` is not required to be sorted.
    num_segments: An `int32` or `int64` scalar.
    name: A name prefix for the returned tensor (optional).

  Returns:
    A `RaggedTensor` containing the aggregated values.  The returned tensor
    has the same dtype as `data`, and its shape is
    `[num_segments] + data.shape[segment_ids.rank:]`.
  Raises:
    ValueError: If segment_ids.shape is not a prefix of data.shape.
  """
  if not (ragged_tensor.is_ragged(data) or
          ragged_tensor.is_ragged(segment_ids)):
    return unsorted_segment_op(data, segment_ids, num_segments, name)

  with ops.name_scope(name, 'RaggedSegment',
                      [data, segment_ids, num_segments]) as name:
    data = ragged_tensor.convert_to_tensor_or_ragged_tensor(data, name='data')
    segment_ids = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        segment_ids, name='segment_ids')
    data, segment_ids = ragged_tensor.match_row_splits_dtypes(data, segment_ids)
    if segment_ids.dtype not in (dtypes.int32, dtypes.int64):
      raise ValueError('segment_ids must have dtype int32 or int64.')

    if ragged_tensor.is_ragged(segment_ids):
      if not ragged_tensor.is_ragged(data):
        raise ValueError('segment_ids.shape must be a prefix of data.shape, '
                         'but segment_ids is ragged and data is not.')
      check_splits = check_ops.assert_equal(
          segment_ids.row_splits,
          data.row_splits,
          message='segment_ids.shape must be a prefix of data.shape')
      with ops.control_dependencies([check_splits]):
        return _ragged_segment_aggregate(unsorted_segment_op, data.values,
                                         segment_ids.values, num_segments, name)

    # Find the length of each row in data.  (shape=[data_nrows])
    data_row_lengths = data.row_splits[1:] - data.row_splits[:-1]

    # Find the length that each output row will have.  The length of the row
    # corresponding to segment `id` is `max(data_row_lengths[i])` where
    # `segment_ids[i]=id`.  (shape=[output_nrows])
    output_row_lengths = math_ops.maximum(
        math_ops.unsorted_segment_max(data_row_lengths, segment_ids,
                                      num_segments), 0)

    # Build the splits tensor for the output RaggedTensor.
    output_splits = array_ops.concat([
        array_ops.zeros([1], output_row_lengths.dtype),
        math_ops.cumsum(output_row_lengths)
    ], axis=0)

    # For each row in `data`, find the start & limit position where that row's
    # values will be aggregated in output.values.
    data_row_to_out_row_start = array_ops.gather(output_splits, segment_ids)
    data_row_to_out_row_limit = data_row_to_out_row_start + data_row_lengths

    # For each value in `data.values`, find the position where it will be
    # aggregated in `output.values`; i.e., the target output-values index
    # for each data-values index.
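    # Note: `range` here is the ragged range op defined in this module
    # (cf. `tf.ragged.range`), not Python's builtin `range`.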
    data_val_to_out_val_index = range(data_row_to_out_row_start,
                                      data_row_to_out_row_limit).values

    # Recursively aggregate the values.
    output_values = _ragged_segment_aggregate(unsorted_segment_op, data.values,
                                              data_val_to_out_val_index,
                                              output_splits[-1])
    return ragged_tensor.RaggedTensor.from_row_splits(
        output_values, output_splits, validate=False)
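The output-length rule described in the docstring (each output row is as long as the longest input row mapped to that segment) can be checked with public ops; a small sketch (TF 2.x assumed; data illustrative):

```python
import tensorflow as tf

data = tf.ragged.constant([[1.0, 2.0], [3.0], [4.0, 5.0, 6.0]])
segment_ids = tf.constant([0, 0, 1])
num_segments = 2

row_lengths = data.row_lengths()  # [2, 1, 3]
# maximum(..., 0) handles empty segments, for which unsorted_segment_max
# returns the dtype's smallest possible value.
output_row_lengths = tf.maximum(
    tf.math.unsorted_segment_max(row_lengths, segment_ids, num_segments), 0)
# => [2, 3]: segment 0 combines rows of lengths 2 and 1; segment 1 gets
#    the single row of length 3.
```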