Esempio n. 1
0
def _gather(params, indices, axis, batch_dims):
  """Implements ragged gather() once all arguments have been normalized.

  Callers are expected to have already converted `params` and `indices` to
  tensors or ragged tensors, and to have normalized `axis` and `batch_dims`
  to non-negative values, so recursive calls to _gather can skip those steps.

  Args:
    params: The (potentially ragged) tensor to gather values from.
    indices: The (potentially ragged) indices of the values to gather.
    axis: The axis in `params` to gather `indices` from.
    batch_dims: The number of batch dimensions.

  Returns:
    A potentially ragged tensor.

  Raises:
    ValueError: If the ragged-params / dense-indices path is reached and the
      rank of `indices` is not known statically.
  """
  ragged_params = ragged_tensor.is_ragged(params)
  ragged_indices = ragged_tensor.is_ragged(indices)

  # Neither input is ragged: defer to the regular dense gather.
  if not ragged_params and not ragged_indices:
    return array_ops.gather(params, indices, axis=axis, batch_dims=batch_dims)

  # Batched gathers and gathers along an inner axis have dedicated helpers.
  if batch_dims > 0:
    return _batch_gather(params, indices, axis, batch_dims)
  if axis > 0:
    return _axis_gather(params, indices, axis)

  # Ragged indices: gather with the flattened values, then re-attach the
  # row partitioning of `indices` to the gathered result.
  if ragged_indices:
    gathered_values = _gather(params, indices.values, 0, 0)
    return indices.with_values(gathered_values)

  # From here on `params` is ragged and `indices` is dense.
  indices_rank = indices.shape.ndims
  if indices_rank is None:
    raise ValueError('rank(indices) must be known statically')

  flat_values = params.flat_values
  nested_splits = params.nested_row_splits
  gathered = gen_ragged_array_ops.ragged_gather(
      indices=indices,
      params_dense_values=flat_values,
      params_nested_splits=nested_splits,
      OUTPUT_RAGGED_RANK=indices_rank + len(nested_splits) - 1)
  result = ragged_tensor.RaggedTensor.from_nested_row_splits(
      gathered.output_dense_values,
      gathered.output_nested_splits,
      validate=False)

  # With rank-0 or rank-1 indices there are no dense outer dimensions to
  # annotate, so the result can be returned as-is.
  if indices_rank <= 1:
    return result

  # Inject uniform_row_lengths into the result RaggedTensors for dimensions
  # corresponding to dense outer dimensions of `indices`.
  # TODO(edloper): Change this to construct the result using RowPartition
  # objects instead, so we don't need to modify private variables.
  dense_shape = array_ops.shape(indices, out_type=params.row_splits.dtype)
  row_counts = math_ops.cumprod(dense_shape)
  level = result
  for inner_dim in range(1, indices_rank):
    # pylint: disable=protected-access
    level._cached_nrows = row_counts[inner_dim - 1]
    level._uniform_row_length = dense_shape[inner_dim]
    level = level.values

  return result
Esempio n. 2
0
def gather(params, indices, validate_indices=None, axis=0, batch_dims=0,
           name=None):
  """Gathers ragged slices from `params` axis `0` according to `indices`.

  The output satisfies:

  ```python
  output.shape = indices.shape + params.shape[1:]
  output[i...j, d0...dn] = params[indices[i...j], d0...dn]
  ```

  Either, both or neither of `params` and `indices` may be ragged.  When
  neither is ragged the result is the dense output of `array_ops.gather`.
  `indices` must have dtype `int32` or `int64`. If any index is out of bounds,
  then an error is returned.

  Examples:

  ```python
  >>> params = tf.constant(['a', 'b', 'c', 'd', 'e'])
  >>> indices = tf.constant([3, 1, 2, 1, 0])
  >>> ragged_params = tf.ragged.constant([['a', 'b', 'c'], ['d'], [], ['e']])
  >>> ragged_indices = tf.ragged.constant([[3, 1, 2], [1], [], [0]])

  >>> print(ragged.gather(params, ragged_indices))
  [['d', 'b', 'c'], ['b'], [], ['a']]

  >>> print(ragged.gather(ragged_params, indices))
  [['e'], ['d'], [], ['d'], ['a', 'b', 'c']]

  >>> print(ragged.gather(ragged_params, ragged_indices))
  [[['e'], ['d'], []], [['d']], [], [['a', 'b', 'c']]]
  ```

  Args:
    params: The potentially ragged tensor from which to gather values. Must be
      at least rank 1.
    indices: The potentially ragged tensor indicating which values to gather.
      Must have dtype `int32` or `int64`.  Values must be in the range `[0,
      params.shape[0])`.
    validate_indices: Ignored.
    axis: Must be the Python int zero.
    batch_dims: Must be the Python int zero.
    name: A name for the operation (optional).

  Returns:
    A potentially ragged tensor with `output.dtype=params.dtype` and
    `output.shape=indices.shape + params.shape[1:]`.  For dense `indices` and
    ragged `params`, the gather op is invoked with an output ragged rank of
    `indices.shape.ndims + len(params.nested_row_splits) - 1`.

  Raises:
    ValueError: If `axis` or `batch_dims` is anything other than the int `0`,
      or if indices.shape.ndims is not known statically.
  """
  del validate_indices  # Accepted only for signature compatibility.

  # Only the default axis / batch_dims are implemented by this version.
  axis_is_zero = isinstance(axis, int) and axis == 0
  if not axis_is_zero:
    raise ValueError('axis != 0 is not supported for ragged gather yet.')
  batch_dims_is_zero = isinstance(batch_dims, int) and batch_dims == 0
  if not batch_dims_is_zero:
    raise ValueError('batch_dims != 0 is not supported for ragged gather yet.')

  with ops.name_scope(name, 'RaggedGather', [params, indices]):
    params = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        params, name='params')
    indices = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        indices, name='indices')
    params, indices = ragged_tensor.match_row_splits_dtypes(params, indices)

    # Ragged indices: gather with the flattened values, then wrap the result
    # in the same row partitioning as `indices`.
    if ragged_tensor.is_ragged(indices):
      gathered_values = gather(params, indices.values)
      return indices.with_values(gathered_values)

    # Dense params with dense indices: the regular gather op suffices.
    if not ragged_tensor.is_ragged(params):
      return array_ops.gather(params, indices)

    # Remaining case: ragged params, dense indices.
    indices = ops.convert_to_tensor(indices)
    indices_rank = indices.shape.ndims
    if indices_rank is None:
      raise ValueError('indices.shape.ndims must be known statically')

    flat_values = params.flat_values
    nested_splits = params.nested_row_splits
    gathered = gen_ragged_array_ops.ragged_gather(
        indices=indices,
        params_dense_values=flat_values,
        params_nested_splits=nested_splits,
        OUTPUT_RAGGED_RANK=indices_rank + len(nested_splits) - 1)

    # Compose the RaggedTensor from splits & values.
    return ragged_tensor.RaggedTensor.from_nested_row_splits(
        gathered.output_dense_values,
        gathered.output_nested_splits,
        validate=False)
Esempio n. 3
0
def gather(params, indices, validate_indices=None, axis=0, name=None):
  """Gathers ragged slices from `params` axis `0` according to `indices`.

  The output satisfies:

  ```python
  output.shape = indices.shape + params.shape[1:]
  output[i...j, d0...dn] = params[indices[i...j], d0...dn]
  ```

  Either, both or neither of `params` and `indices` may be ragged.  When
  neither is ragged the result is the dense output of `array_ops.gather`.
  `indices` must have dtype `int32` or `int64`. If any index is out of bounds,
  then an error is returned.

  Examples:

  ```python
  >>> params = tf.constant(['a', 'b', 'c', 'd', 'e'])
  >>> indices = tf.constant([3, 1, 2, 1, 0])
  >>> ragged_params = tf.ragged.constant([['a', 'b', 'c'], ['d'], [], ['e']])
  >>> ragged_indices = tf.ragged.constant([[3, 1, 2], [1], [], [0]])

  >>> print(ragged.gather(params, ragged_indices))
  [['d', 'b', 'c'], ['b'], [], ['a']]

  >>> print(ragged.gather(ragged_params, indices))
  [['e'], ['d'], [], ['d'], ['a', 'b', 'c']]

  >>> print(ragged.gather(ragged_params, ragged_indices))
  [[['e'], ['d'], []], [['d']], [], [['a', 'b', 'c']]]
  ```

  Args:
    params: The potentially ragged tensor from which to gather values. Must be
      at least rank 1.
    indices: The potentially ragged tensor indicating which values to gather.
      Must have dtype `int32` or `int64`.  Values must be in the range `[0,
      params.shape[0])`.
    validate_indices: Ignored.
    axis: Must be the Python int zero.
    name: A name for the operation (optional).

  Returns:
    A potentially ragged tensor with `output.dtype=params.dtype` and
    `output.shape=indices.shape + params.shape[1:]`.  For dense `indices` and
    ragged `params`, the gather op is invoked with an output ragged rank of
    `indices.shape.ndims + len(params.nested_row_splits) - 1`.

  Raises:
    ValueError: If `axis` is anything other than the int `0`, or if
      indices.shape.ndims is not known statically.
  """
  del validate_indices  # Accepted only for signature compatibility.

  # The guard rejects every axis other than the int 0 (including negative and
  # non-int values), so the message says "!= 0" rather than "> 0".
  if not isinstance(axis, int) or axis != 0:
    raise ValueError('axis != 0 is not supported for ragged gather yet.')

  with ops.name_scope(name, 'RaggedGather', [params, indices]):
    params = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        params, name='params')
    indices = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        indices, name='indices')

    # Ragged indices: gather with the flattened values, then wrap the result
    # in the same row partitioning as `indices`.
    if ragged_tensor.is_ragged(indices):
      return indices.with_values(gather(params, indices.values))

    # Dense params with dense indices: the regular gather op suffices.
    if not ragged_tensor.is_ragged(params):
      return array_ops.gather(params, indices)

    # Remaining case: ragged params, dense indices.
    indices = ops.convert_to_tensor(indices)
    if indices.shape.ndims is None:
      raise ValueError('indices.shape.ndims must be known statically')

    result = gen_ragged_array_ops.ragged_gather(
        indices=indices,
        params_dense_values=params.flat_values,
        params_nested_splits=params.nested_row_splits,
        OUTPUT_RAGGED_RANK=indices.shape.ndims + len(params.nested_row_splits) -
        1)

    # Compose the RaggedTensor from splits & values.
    return ragged_tensor.RaggedTensor.from_nested_row_splits(
        result.output_dense_values, result.output_nested_splits)