def testTensorParamsAndTensorIndices(self):
     params = ['a', 'b', 'c', 'd', 'e']
     indices = [2, 0, 2, 1]
     self.assertAllEqual(ragged_gather_ops.gather(params, indices),
                         [b'c', b'a', b'c', b'b'])
     self.assertIsInstance(ragged_gather_ops.gather(params, indices),
                           ops.Tensor)
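
For comparison, the same dense-in/dense-out behavior is available through the public API; a minimal sketch (assuming TF 2.x eager execution, with `tf.gather` standing in for the internal `ragged_gather_ops.gather`):

```python
import tensorflow as tf

params = ['a', 'b', 'c', 'd', 'e']
indices = [2, 0, 2, 1]
result = tf.gather(params, indices)  # dense params + dense indices -> dense Tensor
print(result)  # tf.Tensor([b'c' b'a' b'c' b'b'], shape=(4,), dtype=string)
```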
Example #2
def _broadcast_ragged_targets_for_overlap(target_start, target_limit,
                                          source_splits):
    """Repeats target indices for each source item in the same batch.

  Args:
    target_start: `<int>[batch_size, (target_size)]`
    target_limit: `<int>[batch_size, (target_size)]`
    source_splits: `<int64>[batch_size, (source_size+1)]`

  Returns:
    `<int>[batch_size, (source_size), (target_size)]`.
    A tuple of ragged tensors `(tiled_target_start, tiled_target_limit)` where:

    * `tiled_target_start[b, s, t] = target_start[b, t]`
    * `tiled_target_limit[b, s, t] = target_limit[b, t]`
  """
    source_batch_ids = segment_id_ops.row_splits_to_segment_ids(source_splits)

    target_start = ragged_tensor.RaggedTensor.from_value_rowids(
        ragged_gather_ops.gather(target_start, source_batch_ids),
        source_batch_ids)
    target_limit = ragged_tensor.RaggedTensor.from_value_rowids(
        ragged_gather_ops.gather(target_limit, source_batch_ids),
        source_batch_ids)
    return (target_start, target_limit)
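
The tiling trick above can be reproduced with the public TF2 API; a small illustration with made-up values (`tf.ragged.row_splits_to_segment_ids` and `tf.gather` mirror the internal `segment_id_ops` / `ragged_gather_ops` calls):

```python
import tensorflow as tf

target_start = tf.ragged.constant([[0, 5], [2]])        # <int>[batch_size, (target_size)]
source_splits = tf.constant([0, 3, 4], dtype=tf.int64)  # batch 0 has 3 sources, batch 1 has 1
batch_ids = tf.ragged.row_splits_to_segment_ids(source_splits)  # [0, 0, 0, 1]
tiled = tf.RaggedTensor.from_value_rowids(
    tf.gather(target_start, batch_ids), batch_ids)
print(tiled)  # [[[0, 5], [0, 5], [0, 5]], [[2]]] -- tiled[b, s, t] == target_start[b, t]
```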
Example #3
 def testReturnNbestAndDetokenize(self):
   sp = SentencepieceTokenizer(
       self.model, nbest_size=2, out_type=dtypes.int32, return_nbest=True)
   sentences = ['I love carpet', 'Never tell me the odds']
   result = sp.tokenize(ragged_factory_ops.constant(sentences))
   detokenized = sp.detokenize(result)
   self.assertAllEqual(
       _utf8(sentences), ragged_gather_ops.gather(detokenized, [0, 2]))
   self.assertAllEqual(
       _utf8(sentences), ragged_gather_ops.gather(detokenized, [1, 3]))
Example #4
 def testDocStringExamples(self):
     params = constant_op.constant(['a', 'b', 'c', 'd', 'e'])
     indices = constant_op.constant([3, 1, 2, 1, 0])
     ragged_params = ragged_factory_ops.constant([['a', 'b', 'c'], ['d'],
                                                  [], ['e']])
     ragged_indices = ragged_factory_ops.constant([[3, 1, 2], [1], [], [0]])
     self.assertAllEqual(ragged_gather_ops.gather(params, ragged_indices),
                         [[b'd', b'b', b'c'], [b'b'], [], [b'a']])
     self.assertAllEqual(ragged_gather_ops.gather(ragged_params, indices),
                         [[b'e'], [b'd'], [], [b'd'], [b'a', b'b', b'c']])
     self.assertAllEqual(
         ragged_gather_ops.gather(ragged_params, ragged_indices),
         [[[b'e'], [b'd'], []], [[b'd']], [], [[b'a', b'b', b'c']]])
Example #5
 def testOutOfBoundsError(self):
   tensor_params = ['a', 'b', 'c']
   tensor_indices = [0, 1, 2]
   ragged_params = ragged_factory_ops.constant([['a', 'b'], ['c']])
   ragged_indices = ragged_factory_ops.constant([[0, 3]])
   with self.assertRaisesRegexp(errors.InvalidArgumentError,
                                r'indices\[1\] = 3 is not in \[0, 3\)'):
     self.evaluate(ragged_gather_ops.gather(tensor_params, ragged_indices))
   with self.assertRaisesRegexp(errors.InvalidArgumentError,
                                r'indices\[2\] = 2 is not in \[0, 2\)'):
     self.evaluate(ragged_gather_ops.gather(ragged_params, tensor_indices))
   with self.assertRaisesRegexp(errors.InvalidArgumentError,
                                r'indices\[1\] = 3 is not in \[0, 2\)'):
     self.evaluate(ragged_gather_ops.gather(ragged_params, ragged_indices))
Example #6
def batch_gather(params: ragged_tensor.RaggedOrDense,
                 indices: ragged_tensor.RaggedOrDense,
                 name=None):
    """Gathers slices from `params` according to `indices` with batch dims.

  This operation is similar to `gather`, but it assumes that the leading `N`
  dimensions of `indices` and `params` are batch dimensions, and performs a
  gather within each batch.  In particular, when using this operation with `N`
  batch dimensions `B1...BN`:

  * `indices` has shape `[B1...BN, I]`
  * `params` has shape `[B1...BN, P1...PM]`.
  * `result` has shape `[B1...BN, I, P2...PM]`.
  * `result[b1...bN, i, p2...pM] =
    params[b1...bN, indices[b1...bN, i], p2...pM]`

  Args:
    params: A potentially ragged tensor with shape `[B1...BN, P1...PM]` (`N>=0`,
      `M>0`).
    indices: A potentially ragged tensor with shape `[B1...BN, I]` (`N>=0`).
    name: A name for the operation (optional).

  Returns:
    A potentially ragged tensor with shape `[B1...BN, I, P2...PM]`.
    `result.ragged_rank = max(indices.ragged_rank, params.ragged_rank)`.

  #### Example:

  >>> params = tf.ragged.constant([['a', 'b', 'c'], ['d'], [], ['e']])
  >>> indices = tf.ragged.constant([[1, 2, 0], [], [], [0, 0]])
  >>> tf.compat.v1.batch_gather(params, indices)
  <tf.RaggedTensor [[b'b', b'c', b'a'], [], [], [b'e', b'e']]>
  """
    return ragged_gather_ops.gather(params, indices, batch_dims=-1, name=name)
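
Because this wrapper simply forwards to `gather` with `batch_dims=-1`, the public `tf.gather` with `batch_dims=1` should produce the same result for the 2-D case; a quick sketch of the docstring example:

```python
import tensorflow as tf

params = tf.ragged.constant([['a', 'b', 'c'], ['d'], [], ['e']])
indices = tf.ragged.constant([[1, 2, 0], [], [], [0, 0]])
print(tf.gather(params, indices, batch_dims=1))
# <tf.RaggedTensor [[b'b', b'c', b'a'], [], [], [b'e', b'e']]>
```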
Example #7
 def testRaggedParamsAndTensorIndices(self):
   params = ragged_factory_ops.constant([['a', 'b'], ['c', 'd', 'e'], ['f'],
                                         [], ['g']])
   indices = [2, 0, 2, 1]
   self.assertAllEqual(
       ragged_gather_ops.gather(params, indices),
       [[b'f'], [b'a', b'b'], [b'f'], [b'c', b'd', b'e']])
Example #8
  def testGradient(self,
                   params,
                   indices,
                   expected_out,
                   out_grad,
                   expected_grad,
                   params_ragged_rank=None):
    """Tests that ragged_gather generates the right gradient.

    Args:
      params: The `params` that should be passed to `gather`.
      indices: The `indices` that should be passed to `gather`.
      expected_out: The expected value of `gather(params, indices)`.
        `expected_out.shape = indices.shape + params.shape[1:]`.
      out_grad: The value that should be fed in as the gradient for `out`
        when testing the gradient of `ragged_gather`.  Must have the same
        shape as `expected_out`.
      expected_grad: The expected gradient that should be returned for
        `params`.  Must have the same shape as `params`.
      params_ragged_rank: The ragged_rank of `params`.
    """
    if context.executing_eagerly():
      return

    params = ragged_factory_ops.constant(
        params, dtype=dtypes.float32, ragged_rank=params_ragged_rank)
    indices = constant_op.constant(indices, dtype=dtypes.int32)
    out_ragged_rank = params.ragged_rank + indices.shape.ndims - 1
    out_grad = ragged_factory_ops.constant(
        out_grad, dtype=dtypes.float32, ragged_rank=out_ragged_rank)
    expected_out = ragged_factory_ops.constant(
        expected_out, dtype=dtypes.float32, ragged_rank=out_ragged_rank)
    expected_grad = ragged_factory_ops.constant(
        expected_grad,
        dtype=dtypes.float32,
        ragged_rank=params.ragged_rank)

    out = ragged_gather_ops.gather(params, indices)
    self.assertAllClose(out, expected_out)

    grads = gradients_impl.gradients(
        out.flat_values,
        (params.nested_row_splits + (params.flat_values, indices,)),
        out_grad.flat_values)
    param_nested_splits_grads = grads[:-2]
    params_flat_values_grad = grads[-2]
    indices_grad = grads[-1]
    self.assertEqual(indices_grad, None)
    for splits_grad in param_nested_splits_grads:
      self.assertEqual(splits_grad, None)

    # The gradient generates an IndexedSlices; convert back to a normal Tensor.
    self.assertIsInstance(params_flat_values_grad, indexed_slices.IndexedSlices)
    params_flat_values_grad = ops.convert_to_tensor(params_flat_values_grad)

    params_grad = params.with_flat_values(params_flat_values_grad)
    self.assertAllClose(params_grad, expected_grad, atol=2e-6, rtol=2e-6)
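
In TF 2.x the same gradient can be checked in eager mode with `tf.GradientTape`; a minimal sketch (assuming gradients flow to the ragged tensor's `flat_values`, as the graph-mode test above verifies):

```python
import tensorflow as tf

params = tf.ragged.constant([[1.0, 2.0], [3.0]])
flat = params.flat_values                      # [1.0, 2.0, 3.0]
with tf.GradientTape() as tape:
  tape.watch(flat)
  out = tf.gather(params, [1, 0, 1])           # rows: [3.], [1., 2.], [3.]
  loss = tf.reduce_sum(out.flat_values)
grad = tf.convert_to_tensor(tape.gradient(loss, flat))
print(grad)  # [1. 1. 2.] -- row 1 (value 3.0) was gathered twice
```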
Example #9
def _ragged_gather_v1(params, indices, validate_indices=None, name=None,
                      axis=0, batch_dims=0):
  return ragged_gather_ops.gather(
      params=params,
      indices=indices,
      validate_indices=validate_indices,
      axis=axis,
      batch_dims=batch_dims,
      name=name)
Example #11
 def testRaggedParamsAndRaggedIndices(self):
   params = ragged_factory_ops.constant([['a', 'b'], ['c', 'd', 'e'], ['f'],
                                         [], ['g']])
   indices = ragged_factory_ops.constant([[2, 1], [1, 2, 0], [3]])
   self.assertAllEqual(
       ragged_gather_ops.gather(params, indices),
       [[[b'f'], [b'c', b'd', b'e']],                # [[p[2], p[1]      ],
        [[b'c', b'd', b'e'], [b'f'], [b'a', b'b']],  #  [p[1], p[2], p[0]],
        [[]]]                                        #  [p[3]            ]]
   )  # pyformat: disable
Example #12
 def test3DRaggedParamsAnd2DTensorIndices(self):
   params = ragged_factory_ops.constant([[['a', 'b'], []],
                                         [['c', 'd'], ['e'], ['f']], [['g']]])
   indices = [[1, 2], [0, 1], [2, 2]]
   self.assertAllEqual(
       ragged_gather_ops.gather(params, indices),
       [[[[b'c', b'd'], [b'e'], [b'f']], [[b'g']]],            # [[p1, p2],
        [[[b'a', b'b'], []], [[b'c', b'd'], [b'e'], [b'f']]],  #  [p0, p1],
        [[[b'g']], [[b'g']]]]                                  #  [p2, p2]]
   )  # pyformat: disable
Example #13
 def testTensorParamsAnd4DRaggedIndices(self):
   indices = ragged_factory_ops.constant(
       [[[[3, 4], [0, 6]], []], [[[2, 1], [1, 0]], [[2, 5]], [[2, 3]]],
        [[[1, 0]]]],  # pyformat: disable
       ragged_rank=2,
       inner_shape=(2,))
   params = ['a', 'b', 'c', 'd', 'e', 'f', 'g']
   self.assertAllEqual(
       ragged_gather_ops.gather(params, indices),
       [[[[b'd', b'e'], [b'a', b'g']], []],
        [[[b'c', b'b'], [b'b', b'a']], [[b'c', b'f']], [[b'c', b'd']]],
        [[[b'b', b'a']]]])  # pyformat: disable
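
The `ragged_rank=2, inner_shape=(2,)` arguments tell `ragged_factory_ops.constant` that the two outer dimensions are ragged while the innermost index pairs form a uniform dimension of size 2. A smaller 3-D illustration with the public API:

```python
import tensorflow as tf

rt = tf.ragged.constant([[[1, 2], [3, 4]], [[5, 6]]],
                        ragged_rank=1, inner_shape=(2,))
print(rt.shape)  # (2, None, 2): outer row lengths vary, inner pairs are uniform
```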
Example #14
def _ragged_stack_concat_axis_1(rt_inputs, stack_values):
  """Helper function to concatenate or stack ragged tensors along axis 1.

  Args:
    rt_inputs: A list of RaggedTensors, all with the same rank and ragged_rank.
    stack_values: Boolean.  If true, then stack values; otherwise, concatenate
      them.

  Returns:
    A RaggedTensor.
  """
  num_inputs = len(rt_inputs)

  rt_nrows = rt_inputs[0].nrows()
  nrows_msg = 'Input tensors have incompatible shapes.'
  nrows_checks = [
      check_ops.assert_equal(rt.nrows(), rt_nrows, message=nrows_msg)
      for rt in rt_inputs[1:]
  ]

  with ops.control_dependencies(nrows_checks):
    # Concatenate the inputs together to put them in a single ragged tensor.
    concatenated_rt = _ragged_stack_concat_axis_0(rt_inputs, stack_values=False)

    # Use ragged.gather to permute the rows of concatenated_rt.  In particular,
    #   permuted_rt = [rt_inputs[0][0], ..., rt_inputs[N][0],
    #                  rt_inputs[0][1], ..., rt_inputs[N][1],
    #                      ...,
    #                  rt_inputs[0][M], ..., rt_inputs[N][M]]
    # where `N=num_inputs-1` and `M=rt_nrows-1`.
    row_indices = math_ops.range(rt_nrows * num_inputs)
    row_index_matrix = array_ops.reshape(row_indices, [num_inputs, -1])
    transposed_row_index_matrix = array_ops.transpose(row_index_matrix)
    row_permutation = array_ops.reshape(transposed_row_index_matrix, [-1])
    permuted_rt = ragged_gather_ops.gather(concatenated_rt, row_permutation)

    if stack_values:
      # Add a new splits tensor to group together the values.
      stack_splits = math_ops.range(0, rt_nrows * num_inputs + 1, num_inputs)
      _copy_row_shape(rt_inputs, stack_splits)
      return ragged_tensor.RaggedTensor.from_row_splits(
          permuted_rt, stack_splits, validate=False)
    else:
      # Merge together adjacent rows by dropping the row-split indices that
      # separate them.
      concat_splits = permuted_rt.row_splits[::num_inputs]
      _copy_row_shape(rt_inputs, concat_splits)
      return ragged_tensor.RaggedTensor.from_row_splits(
          permuted_rt.values, concat_splits, validate=False)
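
The reshape/transpose/reshape permutation is the core trick here: it interleaves the rows of the stacked inputs. A standalone sketch of just that index computation (public API, illustrative sizes):

```python
import tensorflow as tf

num_inputs, nrows = 2, 3
row_indices = tf.range(nrows * num_inputs)                    # [0 1 2 3 4 5]
row_index_matrix = tf.reshape(row_indices, [num_inputs, -1])  # [[0 1 2], [3 4 5]]
row_permutation = tf.reshape(tf.transpose(row_index_matrix), [-1])
print(row_permutation)  # [0 3 1 4 2 5]: input 0's row i is followed by input 1's row i
```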
Example #15
def _build_ragged_tensor_from_value_ranges(starts, limits, step, values):
    """Returns a `RaggedTensor` containing the specified sequences of values.

  Returns a RaggedTensor `output` where:

  ```python
  output.shape[0] = starts.shape[0]
  output[i] = values[starts[i]:limits[i]:step]
  ```

  Requires that `starts.shape == limits.shape` and
  `0 <= starts[i] <= limits[i] <= values.shape[0]`.

  Args:
    starts: 1D integer Tensor specifying the start indices for the sequences of
      values to include.
    limits: 1D integer Tensor specifying the limit indices for the sequences of
      values to include.
    step: Integer value specifying the step size for strided slices.
    values: The set of values to select from.

  Returns:
    A `RaggedTensor`.

  Raises:
    ValueError: Until the prerequisite ops are checked in.
  """
    # Use `ragged_range` to get the index of each value we should include.
    if step is None:
        step = 1
    step = ops.convert_to_tensor(step, name="step")
    if step.dtype.is_integer:
        step = math_ops.cast(step, starts.dtype)
    else:
        raise TypeError("slice strides must be integers or None")
    value_indices = ragged_math_ops.range(starts,
                                          limits,
                                          step,
                                          row_splits_dtype=starts.dtype)

    # Use `ragged_gather` or `array_ops.gather` to collect the values.
    if isinstance(values, ragged_tensor.RaggedTensor):
        gathered_values = ragged_gather_ops.gather(
            params=values, indices=value_indices.values)
    else:
        gathered_values = array_ops.gather(params=values,
                                           indices=value_indices.values)

    # Assemble the RaggedTensor from splits & values.
    return value_indices.with_values(gathered_values)
def _build_ragged_tensor_from_value_ranges(starts, limits, step, values):
  """Returns a `RaggedTensor` containing the specified sequences of values.

  Returns a RaggedTensor `output` where:

  ```python
  output.shape[0] = starts.shape[0]
  output[i] = values[starts[i]:limits[i]:step]
  ```

  Requires that `starts.shape == limits.shape` and
  `0 <= starts[i] <= limits[i] <= values.shape[0]`.

  Args:
    starts: 1D integer Tensor specifying the start indices for the sequences of
      values to include.
    limits: 1D integer Tensor specifying the limit indices for the sequences of
      values to include.
    step: Integer value specifying the step size for strided slices.
    values: The set of values to select from.

  Returns:
    A `RaggedTensor`.

  Raises:
    ValueError: Until the prerequisite ops are checked in.
  """
  # Use `ragged_range` to get the index of each value we should include.
  if step is None:
    step = 1
  step = ops.convert_to_tensor(step, name="step")
  if step.dtype.is_integer:
    step = math_ops.cast(step, dtypes.int64)
  else:
    raise TypeError("slice strides must be integers or None")
  value_indices = ragged_math_ops.range(starts, limits, step)

  # Use `ragged_gather` or `array_ops.gather` to collect the values.
  if isinstance(values, ragged_tensor.RaggedTensor):
    gathered_values = ragged_gather_ops.gather(
        params=values, indices=value_indices.values)
  else:
    gathered_values = array_ops.gather(
        params=values, indices=value_indices.values)

  # Assemble the RaggedTensor from splits & values.
  return value_indices.with_values(gathered_values)
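
Both versions above follow the same two-step pattern: build ragged index ranges, then gather and reattach the row splits. A self-contained sketch with the public API and made-up values:

```python
import tensorflow as tf

values = tf.constant([10, 20, 30, 40, 50])
starts = tf.constant([0, 3])
limits = tf.constant([2, 5])
idx = tf.ragged.range(starts, limits)                 # [[0, 1], [3, 4]]
out = idx.with_values(tf.gather(values, idx.values))  # gather flat, keep row splits
print(out)  # [[10, 20], [40, 50]] == [values[0:2], values[3:5]]
```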
 def testRaggedGather(self,
                      params,
                      indices,
                      expected,
                      axis=None,
                      batch_dims=0,
                      params_ragged_rank=None,
                      indices_ragged_rank=None):
     params = ragged_factory_ops.constant(params,
                                          ragged_rank=params_ragged_rank)
     indices = ragged_factory_ops.constant(indices,
                                           ragged_rank=indices_ragged_rank)
     actual = ragged_gather_ops.gather(params,
                                       indices,
                                       axis=axis,
                                       batch_dims=batch_dims)
     self.assertAllEqual(actual, self._str_to_bytes(expected))
    def testMatchesDenseGather(self,
                               params_shape,
                               indices_shape,
                               axis=None,
                               batch_dims=0):
        # Build random params & indices matrices with the expected shapes.
        if axis is None:
            axis = batch_dims
        params = np.random.randint(100, size=params_shape, dtype=np.int32)
        indices = np.random.randint(params_shape[axis],
                                    size=indices_shape,
                                    dtype=np.int32)

        # Use array_ops.gather to get the expected value.
        expected = array_ops.gather(params,
                                    indices,
                                    axis=axis,
                                    batch_dims=batch_dims)

        # Build ragged tensors with varying ragged_ranks from params & indices.
        params_tensors = [params] + [
            ragged_tensor.RaggedTensor.from_tensor(params, ragged_rank=i)
            for i in range(1, len(params_shape))
        ]
        indices_tensors = [indices] + [
            ragged_tensor.RaggedTensor.from_tensor(indices, ragged_rank=i)
            for i in range(1, len(indices_shape))
        ]

        # For each combination of params & indices tensors, check that
        # ragged_gather_ops.gather matches array_ops.gather.
        for params_tensor in params_tensors:
            for indices_tensor in indices_tensors:
                actual = ragged_gather_ops.gather(params_tensor,
                                                  indices_tensor,
                                                  axis=axis,
                                                  batch_dims=batch_dims)
                if isinstance(actual, ragged_tensor.RaggedTensor):
                    actual = actual.to_tensor()
                self.assertAllEqual(
                    expected, actual,
                    'params.ragged_rank=%s, indices.ragged_rank=%s' %
                    (getattr(params_tensor, 'ragged_rank',
                             0), getattr(indices_tensor, 'ragged_rank', 0)))
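
The comparison works because `RaggedTensor.from_tensor` on uniform data is lossless: `to_tensor()` recovers the original dense tensor exactly, with no padding introduced. A tiny sketch:

```python
import tensorflow as tf

dense = tf.constant([[1, 2], [3, 4]])
rt = tf.RaggedTensor.from_tensor(dense, ragged_rank=1)
print(rt.to_tensor())  # [[1 2] [3 4]] -- identical to `dense`
```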
Example #20
def _broadcast_ragged_sources_for_overlap(source_start, source_limit,
                                          target_splits):
    """Repeats source indices for each target item in the same batch.

  Args:
    source_start: `<int>[batch_size, (source_size)]`
    source_limit: `<int>[batch_size, (source_size)]`
    target_splits: `<int64>[batch_size, (target_size+1)]`

  Returns:
    `<int>[batch_size, (source_size), (target_size)]`.
    A tuple of tensors `(tiled_source_start, tiled_source_limit)` where:

    * `tiled_source_start[b, s, t] = source_start[b, s]`
    * `tiled_source_limit[b, s, t] = source_limit[b, s]`
  """
    source_splits = source_start.row_splits
    target_rowlens = target_splits[1:] - target_splits[:-1]
    source_batch_ids = segment_id_ops.row_splits_to_segment_ids(source_splits)

    # <int64>[sum(source_size[b] for b in range(batch_size))]
    # source_repeats[i] is the number of target spans in the batch that contains
    # source span i.  We need to add a new ragged dimension that repeats each
    # source span this number of times.
    source_repeats = ragged_gather_ops.gather(target_rowlens, source_batch_ids)

    # <int64>[sum(source_size[b] for b in range(batch_size)) + 1]
    # The row_splits tensor for the inner ragged dimension of the result tensors.
    inner_splits = array_ops.concat([[0], math_ops.cumsum(source_repeats)],
                                    axis=0)

    # <int64>[sum(source_size[b] * target_size[b] for b in range(batch_size))]
    # Indices for gathering source indices.
    source_indices = segment_id_ops.row_splits_to_segment_ids(inner_splits)

    source_start = ragged_tensor.RaggedTensor.from_nested_row_splits(
        array_ops.gather(source_start.values, source_indices),
        [source_splits, inner_splits])
    source_limit = ragged_tensor.RaggedTensor.from_nested_row_splits(
        array_ops.gather(source_limit.values, source_indices),
        [source_splits, inner_splits])

    return source_start, source_limit
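
The cumsum-plus-segment-ids pattern used above is a general way to repeat item i a variable number of times. A minimal sketch of that index computation (public API, illustrative counts):

```python
import tensorflow as tf

source_repeats = tf.constant([2, 1, 3])  # how many times to repeat each source item
inner_splits = tf.concat([[0], tf.cumsum(source_repeats)], axis=0)  # [0 2 3 6]
source_indices = tf.ragged.row_splits_to_segment_ids(inner_splits)  # [0 0 1 2 2 2]
# tf.gather(some_values, source_indices) now repeats item i source_repeats[i] times.
```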
Example #21
def _elementwise_where(condition, x, y):
  """Ragged version of tf.where(condition, x, y)."""
  condition_is_ragged = isinstance(condition, ragged_tensor.RaggedTensor)
  x_is_ragged = isinstance(x, ragged_tensor.RaggedTensor)
  y_is_ragged = isinstance(y, ragged_tensor.RaggedTensor)

  if not (condition_is_ragged or x_is_ragged or y_is_ragged):
    return array_ops.where(condition, x, y)

  elif condition_is_ragged and x_is_ragged and y_is_ragged:
    return ragged_functional_ops.map_flat_values(array_ops.where, condition, x,
                                                 y)
  elif not condition_is_ragged:
    # Concatenate x and y, and then use `gather` to assemble the selected rows.
    condition.shape.assert_has_rank(1)
    x_nrows = _nrows(x)
    x_and_y = ragged_concat_ops.concat([x, y], axis=0)
    indices = array_ops.where(condition, math_ops.range(x_nrows),
                              x_nrows + math_ops.range(_nrows(y)))
    return ragged_gather_ops.gather(x_and_y, indices)

  else:
    raise ValueError('Input shapes do not match.')
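
The dense-condition branch is easiest to see with concrete values: rows of `x` and `y` are concatenated, and `where` picks a row index into either half. A sketch using the public API:

```python
import tensorflow as tf

condition = tf.constant([True, False, True])
x = tf.ragged.constant([[1], [2, 2], [3]])
y = tf.ragged.constant([[9, 9], [8], [7]])
x_and_y = tf.concat([x, y], axis=0)                          # x's rows, then y's rows
indices = tf.where(condition, tf.range(3), 3 + tf.range(3))  # [0, 4, 2]
print(tf.gather(x_and_y, indices))                           # [[1], [8], [3]]
```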
Example #23
def batch_gather(params, indices, name=None):
    """Gathers slices from `params` according to `indices` with batch dims.

  This operation is similar to `gather`, but it assumes that the leading `N`
  dimensions of `indices` and `params` are batch dimensions, and performs a
  gather within each batch.  In particular, when using this operation with `N`
  batch dimensions `B1...BN`:

  * `indices` has shape `[B1...BN, I]`
  * `params` has shape `[B1...BN, P1...PM]`.
  * `result` has shape `[B1...BN, I, P2...PM]`.
  * `result[b1...bN, i, p2...pM] =
    params[b1...bN, indices[b1...bN, i], p2...pM]`

  Args:
    params: A potentially ragged tensor with shape `[B1...BN, P1...PM]` (`N>=0`,
      `M>0`).
    indices: A potentially ragged tensor with shape `[B1...BN, I]` (`N>=0`).
    name: A name for the operation (optional).

  Returns:
    A potentially ragged tensor with shape `[B1...BN, I, P2...PM]`.
    `result.ragged_rank = max(indices.ragged_rank, params.ragged_rank)`.

  #### Example:
    ```python
    >>> params = tf.ragged.constant([['a', 'b', 'c'], ['d'], [], ['e']])
    >>> indices = tf.ragged.constant([[1, 2, 0], [], [], [0, 0]])
    >>> tf.batch_gather(params, indices)
    [['b', 'c', 'a'], [], [], ['e', 'e']]
    ```
  """
    if not (ragged_tensor.is_ragged(params)
            or ragged_tensor.is_ragged(indices)):
        return array_ops.batch_gather(params, indices, name)

    with ops.name_scope(name, 'RaggedBatchGather', [params, indices]):
        params = ragged_tensor.convert_to_tensor_or_ragged_tensor(
            params, name='params')
        indices = ragged_tensor.convert_to_tensor_or_ragged_tensor(
            indices, name='indices')
        indices_ndims = indices.shape.ndims
        if indices_ndims is None:
            raise ValueError(
                'batch_gather does not allow indices with unknown shape.')
        if indices_ndims == 0:
            raise ValueError('indices.rank must be at least 1.')

        if ragged_tensor.is_ragged(indices):
            # If the outermost ragged dimension is a batch dimension, recurse.
            if indices_ndims > 2:
                if not ragged_tensor.is_ragged(params):
                    raise ValueError('batch shape from indices does '
                                     'not match params shape')
                checks = [
                    check_ops.assert_equal(params.row_splits,
                                           indices.row_splits)
                ]
                with ops.control_dependencies(checks):
                    return ragged_tensor.RaggedTensor.from_row_splits(
                        batch_gather(params.values, indices.values),
                        indices.row_splits)

            # Otherwise, indices is a 2D ragged tensor with 1 ragged dimension.
            else:
                # Ensure that `params` is ragged and has at least 2 dimensions.
                if not ragged_tensor.is_ragged(params):
                    if params.shape.ndims is not None and params.shape.ndims < 2:
                        raise ValueError('batch shape from indices does '
                                         'not match params shape')
                    params = ragged_conversion_ops.from_tensor(params,
                                                               ragged_rank=1)

                # Adjust indices from within-batch to global (in params.values), and
                # then use ragged.gather to gather them.
                num_indices = indices.row_lengths()
                params_starts = params.row_starts()
                adjustments = ragged_util.repeat(params_starts,
                                                 num_indices,
                                                 axis=0)
                adjusted_index_values = math_ops.to_int64(
                    indices.values) + adjustments
                return ragged_tensor.RaggedTensor.from_row_splits(
                    ragged_gather_ops.gather(params.values,
                                             adjusted_index_values),
                    indices.row_splits)

        else:  # params is a RaggedTensor and indices is a Tensor.
            if indices_ndims == 1:
                return ragged_gather_ops.gather(params, indices)
            elif indices_ndims == 2:
                # Adjust indices from batch-local to global (in params.values)
                adjustments = array_ops.expand_dims(params.row_starts(), 1)
                adjusted_indices = math_ops.to_int64(indices) + adjustments
                return ragged_gather_ops.gather(params.values,
                                                adjusted_indices)
            else:
                raise ValueError(
                    'batch shape from indices does not match params shape')
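
The batch-local-to-global index adjustment above relies on `row_starts()`, which gives each row's offset into `params.values`; a short illustration:

```python
import tensorflow as tf

params = tf.ragged.constant([['a', 'b', 'c'], ['d'], [], ['e']])
print(params.row_starts())  # [0 3 4 4]
# A batch-local index i in row b maps to index row_starts()[b] + i in params.values.
```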
Example #24
def batch_gather(params, indices, name=None):
  """Gathers slices from `params` according to `indices` with batch dims.

  This operation is similar to `gather`, but it assumes that the leading `N`
  dimensions of `indices` and `params` are batch dimensions, and performs a
  gather within each batch.  In particular, when using this operation with `N`
  batch dimensions `B1...BN`:

  * `indices` has shape `[B1...BN, I]`
  * `params` has shape `[B1...BN, P1...PM]`.
  * `result` has shape `[B1...BN, I, P2...PM]`.
  * `result[b1...bN, i, p2...pM] =
    params[b1...bN, indices[b1...bN, i], p2...pM]`

  Args:
    params: A potentially ragged tensor with shape `[B1...BN, P1...PM]` (`N>=0`,
      `M>0`).
    indices: A potentially ragged tensor with shape `[B1...BN, I]` (`N>=0`).
    name: A name for the operation (optional).

  Returns:
    A potentially ragged tensor with shape `[B1...BN, I, P2...PM]`.
    `result.ragged_rank = max(indices.ragged_rank, params.ragged_rank)`.

  #### Example:
    ```python
    >>> params = tf.ragged.constant([['a', 'b', 'c'], ['d'], [], ['e']])
    >>> indices = tf.ragged.constant([[1, 2, 0], [], [], [0, 0]])
    >>> tf.compat.v1.batch_gather(params, indices)
    [['b', 'c', 'a'], [], [], ['e', 'e']]
    ```
  """
  if not (ragged_tensor.is_ragged(params) or ragged_tensor.is_ragged(indices)):
    return array_ops.batch_gather(params, indices, name)

  with ops.name_scope(name, 'RaggedBatchGather', [params, indices]):
    params = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        params, name='params')
    indices = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        indices, name='indices')
    params, indices = ragged_tensor.match_row_splits_dtypes(params, indices)
    indices_ndims = indices.shape.ndims
    if indices_ndims is None:
      raise ValueError(
          'batch_gather does not allow indices with unknown shape.')
    if indices_ndims == 0:
      raise ValueError('indices.rank must be at least 1.')

    if ragged_tensor.is_ragged(indices):
      # If the outermost ragged dimension is a batch dimension, recurse.
      if indices_ndims > 2:
        if not ragged_tensor.is_ragged(params):
          raise ValueError('batch shape from indices does '
                           'not match params shape')
        checks = [check_ops.assert_equal(params.row_splits, indices.row_splits)]
        with ops.control_dependencies(checks):
          return ragged_tensor.RaggedTensor.from_row_splits(
              batch_gather(params.values, indices.values), indices.row_splits,
              validate=False)

      # Otherwise, indices is a 2D ragged tensor with 1 ragged dimension.
      else:
        # Ensure that `params` is ragged and has at least 2 dimensions.
        if not ragged_tensor.is_ragged(params):
          if params.shape.ndims is not None and params.shape.ndims < 2:
            raise ValueError('batch shape from indices does '
                             'not match params shape')
          params = ragged_tensor.RaggedTensor.from_tensor(
              params, ragged_rank=1,
              row_splits_dtype=indices.row_splits.dtype)

        # Adjust indices from within-batch to global (in params.values), and
        # then use ragged.gather to gather them.
        num_indices = indices.row_lengths()
        params_starts = params.row_starts()
        adjustments = ragged_util.repeat(params_starts, num_indices, axis=0)
        adjusted_index_values = (
            math_ops.cast(indices.values, adjustments.dtype) + adjustments)
        return ragged_tensor.RaggedTensor.from_row_splits(
            ragged_gather_ops.gather(params.values, adjusted_index_values),
            indices.row_splits, validate=False)

    else:  # params is a RaggedTensor and indices is a Tensor.
      if indices_ndims == 1:
        return ragged_gather_ops.gather(params, indices)
      elif indices_ndims == 2:
        # Adjust indices from batch-local to global (in params.values)
        adjustments = array_ops.expand_dims(params.row_starts(), 1)
        adjusted_indices = (
            math_ops.cast(indices, adjustments.dtype) + adjustments)
        return ragged_gather_ops.gather(params.values, adjusted_indices)
      else:
        raise ValueError('batch shape from indices does not match params shape')
Example #25
 def testRaggedParamsAndScalarIndices(self):
     params = ragged_factory_ops.constant([['a', 'b'], ['c', 'd', 'e'],
                                           ['f'], [], ['g']])
     indices = 1
     self.assertRaggedEqual(ragged_gather_ops.gather(params, indices),
                            [b'c', b'd', b'e'])
Example #26
 def testTensorParamsAndRaggedIndices(self):
   params = ['a', 'b', 'c', 'd', 'e']
   indices = ragged_factory_ops.constant([[2, 1], [1, 2, 0], [3]])
   self.assertAllEqual(
       ragged_gather_ops.gather(params, indices),
       [[b'c', b'b'], [b'b', b'c', b'a'], [b'd']])