Beispiel #1
0
 def testTensorParamsAndTensorIndices(self):
     """Gathers from a plain list with plain list indices.

     Checks both the gathered values and that the result is an `ops.Tensor`.
     """
     letters = ['a', 'b', 'c', 'd', 'e']
     positions = [2, 0, 2, 1]
     expected = [b'c', b'a', b'c', b'b']

     # The op is built twice, once per assertion.
     value_result = ragged_array_ops.gather(letters, positions)
     self.assertRaggedEqual(value_result, expected)

     type_result = ragged_array_ops.gather(letters, positions)
     self.assertIsInstance(type_result, ops.Tensor)
Beispiel #2
0
 def testDocStringExamples(self):
     """Exercises `gather` with mixed dense and ragged params/indices.

     Three pairings are checked: dense params with ragged indices, ragged
     params with dense indices, and ragged params with ragged indices.
     """
     dense_params = constant_op.constant(['a', 'b', 'c', 'd', 'e'])
     dense_indices = constant_op.constant([3, 1, 2, 1, 0])
     rt_params = ragged_factory_ops.constant(
         [['a', 'b', 'c'], ['d'], [], ['e']])
     rt_indices = ragged_factory_ops.constant([[3, 1, 2], [1], [], [0]])

     # Dense params + ragged indices: the result mirrors the indices' rows.
     dense_by_ragged = ragged_array_ops.gather(dense_params, rt_indices)
     self.assertRaggedEqual(
         dense_by_ragged,
         [[b'd', b'b', b'c'], [b'b'], [], [b'a']])

     # Ragged params + dense indices: each index selects a whole row.
     ragged_by_dense = ragged_array_ops.gather(rt_params, dense_indices)
     self.assertRaggedEqual(
         ragged_by_dense,
         [[b'e'], [b'd'], [], [b'd'], [b'a', b'b', b'c']])

     # Ragged params + ragged indices: rows of selected rows.
     ragged_by_ragged = ragged_array_ops.gather(rt_params, rt_indices)
     self.assertRaggedEqual(
         ragged_by_ragged,
         [[[b'e'], [b'd'], []], [[b'd']], [], [[b'a', b'b', b'c']]])
Beispiel #3
0
 def testRaggedParamsAndTensorIndices(self):
     """Gathers whole rows of a ragged constant using a plain index list."""
     rt_params = ragged_factory_ops.constant(
         [['a', 'b'], ['c', 'd', 'e'], ['f'], [], ['g']])
     row_ids = [2, 0, 2, 1]

     # Each index picks one row of `rt_params`; row 2 is selected twice.
     selected_rows = ragged_array_ops.gather(rt_params, row_ids)
     expected_rows = [[b'f'], [b'a', b'b'], [b'f'], [b'c', b'd', b'e']]
     self.assertRaggedEqual(selected_rows, expected_rows)
Beispiel #4
0
 def testOutOfBoundsError(self):
     """Out-of-range indices must raise `errors.InvalidArgumentError`.

     Covers list params with ragged indices, ragged params with list indices,
     and ragged params with ragged indices.
     """
     # NOTE(review): `assertRaisesRegexp` is a deprecated unittest alias of
     # `assertRaisesRegex`; kept as-is because the base class is not visible.
     dense_params = ['a', 'b', 'c']
     dense_indices = [0, 1, 2]
     rt_params = ragged_factory_ops.constant([['a', 'b'], ['c']])
     rt_indices = ragged_factory_ops.constant([[0, 3]])

     # Index 3 is past the end of the three-element list.
     with self.assertRaisesRegexp(
             errors.InvalidArgumentError,
             r'indices\[1\] = 3 is not in \[0, 3\)'):
         gathered = ragged_array_ops.gather(dense_params, rt_indices)
         self.evaluate(gathered)

     # `rt_params` has two rows, so row index 2 is out of range.
     with self.assertRaisesRegexp(
             errors.InvalidArgumentError,
             r'indices\[2\] = 2 is not in \[0, 2\)'):
         gathered = ragged_array_ops.gather(rt_params, dense_indices)
         self.evaluate(gathered)

     # Same two-row params, with the bad index coming from ragged indices.
     with self.assertRaisesRegexp(
             errors.InvalidArgumentError,
             r'indices\[1\] = 3 is not in \[0, 2\)'):
         gathered = ragged_array_ops.gather(rt_params, rt_indices)
         self.evaluate(gathered)
Beispiel #5
0
def _ragged_gather_v1(params, indices, validate_indices=None, name=None,
                      axis=0):
  """Forwards to `ragged_array_ops.gather` using keyword arguments.

  This wrapper's positional order places `name` before `axis`; every argument
  is passed through by keyword, so that order does not affect the callee.

  Args:
    params: Passed through as `params`.
    indices: Passed through as `indices`.
    validate_indices: Passed through as `validate_indices`.
    name: Passed through as `name`.
    axis: Passed through as `axis`.

  Returns:
    Whatever `ragged_array_ops.gather` returns.
  """
  forwarded = dict(
      params=params,
      indices=indices,
      validate_indices=validate_indices,
      axis=axis,
      name=name)
  return ragged_array_ops.gather(**forwarded)
Beispiel #6
0
 def testTensorParamsAnd4DRaggedIndices(self):
     """Gathers from a flat list using deeply nested ragged indices.

     The indices are built with `ragged_rank=2` and `inner_shape=(2,)`; the
     expected result has the same nesting with each index replaced by the
     corresponding element of the list.
     """
     nested_indices = ragged_factory_ops.constant(
         [[[[3, 4], [0, 6]], []],
          [[[2, 1], [1, 0]], [[2, 5]], [[2, 3]]],
          [[[1, 0]]]],  # pyformat: disable
         ragged_rank=2,
         inner_shape=(2,))
     letters = ['a', 'b', 'c', 'd', 'e', 'f', 'g']

     gathered = ragged_array_ops.gather(letters, nested_indices)
     expected = [
         [[[b'd', b'e'], [b'a', b'g']], []],
         [[[b'c', b'b'], [b'b', b'a']], [[b'c', b'f']], [[b'c', b'd']]],
         [[[b'b', b'a']]],
     ]  # pyformat: disable
     self.assertRaggedEqual(gathered, expected)
Beispiel #7
0
 def testRaggedParamsAndRaggedIndices(self):
     """Gathers rows of a ragged constant using ragged indices.

     In the comments below, `p[i]` denotes row `i` of the params.
     """
     rt_params = ragged_factory_ops.constant(
         [['a', 'b'], ['c', 'd', 'e'], ['f'], [], ['g']])
     rt_indices = ragged_factory_ops.constant([[2, 1], [1, 2, 0], [3]])

     gathered = ragged_array_ops.gather(rt_params, rt_indices)
     expected = [
         # [p[2], p[1]]
         [[b'f'], [b'c', b'd', b'e']],
         # [p[1], p[2], p[0]]
         [[b'c', b'd', b'e'], [b'f'], [b'a', b'b']],
         # [p[3]]
         [[]],
     ]  # pyformat: disable
     self.assertRaggedEqual(gathered, expected)
Beispiel #8
0
 def test3DRaggedParamsAnd2DTensorIndices(self):
     """Gathers from nested ragged params using a 2-D list of indices.

     In the comments below, `pN` denotes outer row `N` of the params.
     """
     rt_params = ragged_factory_ops.constant(
         [[['a', 'b'], []],
          [['c', 'd'], ['e'], ['f']],
          [['g']]])
     index_grid = [[1, 2], [0, 1], [2, 2]]

     gathered = ragged_array_ops.gather(rt_params, index_grid)
     expected = [
         # [p1, p2]
         [[[b'c', b'd'], [b'e'], [b'f']], [[b'g']]],
         # [p0, p1]
         [[[b'a', b'b'], []], [[b'c', b'd'], [b'e'], [b'f']]],
         # [p2, p2]
         [[[b'g']], [[b'g']]],
     ]  # pyformat: disable
     self.assertRaggedEqual(gathered, expected)
Beispiel #9
0
def _build_ragged_tensor_from_value_ranges(starts, limits, step, values):
    """Builds a `RaggedTensor` whose rows are strided slices of `values`.

    The result `output` satisfies:

    ```python
    output.shape[0] = starts.shape[0]
    output[i] = values[starts[i]:limits[i]:step]
    ```

    Requires that `starts.shape == limits.shape` and
    `0 <= starts[i] <= limits[i] <= values.shape[0]`.

    Args:
      starts: 1D integer Tensor with the start index of each row's slice.
      limits: 1D integer Tensor with the limit index of each row's slice.
      step: Integer stride shared by all slices; `None` is treated as 1.
      values: The values to select from; may be a `RaggedTensor`.

    Returns:
      A `RaggedTensor`.

    Raises:
      TypeError: If `step` does not convert to an integer-typed tensor.
    """
    # Normalize the stride: default to 1, require an integer dtype, then cast
    # to int64 before handing it to the ragged range op.
    stride = ops.convert_to_tensor(1 if step is None else step, name="step")
    if not stride.dtype.is_integer:
        raise TypeError("slice strides must be integers or None")
    stride = math_ops.cast(stride, dtypes.int64)

    # One row of indices per (start, limit) pair.
    index_ranges = ragged_math_ops.range(starts, limits, stride)

    # Pick the gather implementation that matches the kind of `values`.
    if isinstance(values, ragged_tensor.RaggedTensor):
        gather_fn = ragged_array_ops.gather
    else:
        gather_fn = array_ops.gather
    picked = gather_fn(params=values, indices=index_ranges.values)

    # Reuse the row partitioning of the index ranges for the gathered values.
    return index_ranges.with_values(picked)
def _build_ragged_tensor_from_value_ranges(starts, limits, step, values):
  """Assembles a `RaggedTensor` from per-row strided ranges over `values`.

  The returned `output` satisfies:

  ```python
  output.shape[0] = starts.shape[0]
  output[i] = values[starts[i]:limits[i]:step]
  ```

  Requires that `starts.shape == limits.shape` and
  `0 <= starts[i] <= limits[i] <= values.shape[0]`.

  Args:
    starts: 1D integer Tensor giving where each row's range begins.
    limits: 1D integer Tensor giving where each row's range ends.
    step: Integer step size for the strided ranges; `None` means 1.
    values: The set of values to select from (dense or `RaggedTensor`).

  Returns:
    A `RaggedTensor`.

  Raises:
    TypeError: If `step` converts to a tensor with a non-integer dtype.
  """
  # A missing step is treated as a step of 1.
  step_tensor = ops.convert_to_tensor(1 if step is None else step, name="step")
  if not step_tensor.dtype.is_integer:
    raise TypeError("slice strides must be integers or None")
  step_tensor = math_ops.cast(step_tensor, dtypes.int64)

  # Ragged range yields, per row, the indices of the values to include.
  row_indices = ragged_math_ops.range(starts, limits, step_tensor)
  flat_indices = row_indices.values

  # Ragged `values` need the ragged gather; anything else uses the dense one.
  values_are_ragged = isinstance(values, ragged_tensor.RaggedTensor)
  if values_are_ragged:
    collected = ragged_array_ops.gather(params=values, indices=flat_indices)
  else:
    collected = array_ops.gather(params=values, indices=flat_indices)

  # Keep the row structure of the indices, swapping in the gathered values.
  return row_indices.with_values(collected)
Beispiel #11
0
 def testTensorParamsAndRaggedIndices(self):
     """Gathers from a plain list using ragged indices.

     The expected result has the same row lengths as the indices.
     """
     letters = ['a', 'b', 'c', 'd', 'e']
     rt_indices = ragged_factory_ops.constant([[2, 1], [1, 2, 0], [3]])

     gathered = ragged_array_ops.gather(letters, rt_indices)
     expected = [[b'c', b'b'], [b'b', b'c', b'a'], [b'd']]
     self.assertRaggedEqual(gathered, expected)