Example 1
  def testRaggedTile(self,
                     descr,
                     rt_input,
                     multiples,
                     expected,
                     ragged_rank=None):
    rt = ragged_factory_ops.constant(rt_input, ragged_rank)

    expected_shape = [
        None if dim is None else dim * multiple
        for (dim, multiple) in zip(rt.shape.as_list(), multiples)
    ]

    # Test with both const & non-const multiples: ragged_tile has a few code
    # paths that optimize the case where multiples[d] is known to be 1.
    const_multiples = constant_op.constant(multiples, dtypes.int64)
    non_const_multiples = array_ops.placeholder_with_default(
        const_multiples, shape=[len(multiples)])

    for multiples_tensor in (const_multiples, non_const_multiples):
      tiled = ragged_array_ops.tile(rt, multiples_tensor)
      self.assertEqual(tiled.ragged_rank, rt.ragged_rank)
      self.assertEqual(tiled.shape.ndims, rt.shape.ndims)
      if multiples_tensor is const_multiples:
        self.assertEqual(tiled.shape.as_list(), expected_shape)
      with self.test_session():
        self.assertEqual(tiled.eval().tolist(), expected)
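For context, here is a minimal sketch of what the test above exercises, written against the public TF2 API under the assumption that tf.tile dispatches to the ragged implementation for RaggedTensor inputs (a hedged illustration, not the test code itself):

import tensorflow as tf

# A ragged tensor with rows of different lengths.
rt = tf.ragged.constant([[1, 2], [3]])

# multiples=[3, 2]: repeat the outer (row) dimension 3 times and each row's
# values 2 times, mirroring tf.tile semantics on dense tensors.
tiled = tf.tile(rt, [3, 2])
print(tiled)
# Expected output, roughly:
# <tf.RaggedTensor [[1, 2, 1, 2], [3, 3],
#                   [1, 2, 1, 2], [3, 3],
#                   [1, 2, 1, 2], [3, 3]]>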
Example 2
 def testRaggedTileWithTensorInput(self):
     # When the input is a `Tensor`, ragged_tile just delegates to tf.tile.
     dt = constant_op.constant([[1, 2], [3, 4]])
     tiled = ragged_array_ops.tile(dt, [3, 2])
     expected = [[1, 2, 1, 2], [3, 4, 3, 4], [1, 2, 1, 2], [3, 4, 3, 4],
                 [1, 2, 1, 2], [3, 4, 3, 4]]  # pyformat: disable
     self.assertAllEqual(tiled, expected)
Example 3
    def testRaggedTile(self,
                       descr,
                       rt_input,
                       multiples,
                       expected,
                       ragged_rank=None):
        rt = ragged_factory_ops.constant(rt_input, ragged_rank)

        expected_shape = [
            None if dim is None else dim * multiple
            for (dim, multiple) in zip(rt.shape.as_list(), multiples)
        ]

        # Test with both const & non-const multiples: ragged_tile has a few code
        # paths that optimize the case where multiples[d] is known to be 1.
        const_multiples = constant_op.constant(multiples, dtypes.int64)
        non_const_multiples = array_ops.placeholder_with_default(
            const_multiples, shape=[len(multiples)])

        for multiples_tensor in (const_multiples, non_const_multiples):
            tiled = ragged_array_ops.tile(rt, multiples_tensor)
            self.assertEqual(tiled.ragged_rank, rt.ragged_rank)
            self.assertEqual(tiled.shape.ndims, rt.shape.ndims)
            if multiples_tensor is const_multiples:
                self.assertEqual(tiled.shape.as_list(), expected_shape)
            self.assertAllEqual(tiled, expected)
Example 4
 def testRaggedTileWithTensorInput(self):
   # When the input is a `Tensor`, ragged_tile just delegates to tf.tile.
   dt = constant_op.constant([[1, 2], [3, 4]])
   tiled = ragged_array_ops.tile(dt, [3, 2])
   expected = [[1, 2, 1, 2], [3, 4, 3, 4],
               [1, 2, 1, 2], [3, 4, 3, 4],
               [1, 2, 1, 2], [3, 4, 3, 4]]  # pyformat: disable
   self.assertRaggedEqual(tiled, expected)
Example 5
 def testRaggedTileWithTensorInput(self):
     # When the input is a `Tensor`, ragged_tile just delegates to tf.tile.
     dt = constant_op.constant([[1, 2], [3, 4]])
     tiled = ragged_array_ops.tile(dt, [3, 2])
     expected = [[1, 2, 1, 2], [3, 4, 3, 4], [1, 2, 1, 2], [3, 4, 3, 4],
                 [1, 2, 1, 2], [3, 4, 3, 4]]  # pyformat: disable
     with self.test_session():
         self.assertEqual(tiled.eval().tolist(), expected)
Example 6
def _broadcast_to_ragged_shape(rt_input, dst_shape, broadcast_inner_dimensions):
  """Broadcasts rt_input to the ragged shape `dst_shape`."""
  # Check that rt_input and dst_shape have the same row_splits dtype.
  if (isinstance(rt_input, ragged_tensor.RaggedTensor) and
      rt_input.row_splits.dtype != dst_shape.dim_size_dtype):
    if not ragged_config.auto_cast_partition_dtype():
      raise ValueError('rt_input and dst_shape have different row_split '
                       'dtypes; use RaggedTensor.with_row_splits_dtype() or '
                       'RaggedTensorDynamicShape.with_dim_size_dtype() to '
                       'convert to a compatible dtype.')
    rt_input = rt_input.with_row_splits_dtype(dtypes.int64)
    dst_shape = dst_shape.with_dim_size_dtype(dtypes.int64)

  # dst_shape's rank and ragged_rank must be greater than or equal to rt_input's
  if rt_input.shape.ndims is None or dst_shape.rank is None:
    raise ValueError('Unable to broadcast: unknown rank')
  if rt_input.shape.ndims > dst_shape.rank:
    raise ValueError('Incompatible with shape: rank mismatch')
  if (isinstance(rt_input, ragged_tensor.RaggedTensor) and
      rt_input.ragged_rank >= dst_shape.num_partitioned_dimensions):
    raise ValueError('Incompatible with shape: ragged rank mismatch')

  src_shape = RaggedTensorDynamicShape.from_tensor(rt_input)
  src_shape = src_shape.broadcast_to_rank(dst_shape.rank)

  # Add dimensions to rt_input so its rank and ragged_rank match dst_shape.
  if dst_shape.rank > rt_input.shape.ndims:
    if rt_input.shape.ndims < dst_shape.num_inner_dimensions + 1:
      rt_input = array_ops.reshape(
          rt_input, array_ops.concat([[-1], dst_shape.inner_dim_sizes], axis=0))
    for _ in range(dst_shape.rank - rt_input.shape.ndims):
      if ragged_tensor.is_ragged(rt_input):
        nrows = rt_input.nrows()
      else:
        nrows = array_ops.shape(rt_input,
                                out_type=dst_shape.dim_size_dtype)[0]
      rt_input = ragged_tensor.RaggedTensor.from_row_lengths(rt_input, [nrows],
                                                             validate=False)

  # Add ragged dimensions to match dst_shape.
  if ragged_tensor.is_ragged(rt_input):
    inner_rank_diff = (
        rt_input.flat_values.shape.ndims - 1 - dst_shape.num_inner_dimensions)
    if inner_rank_diff > 0:
      rt_input = rt_input.with_flat_values(
          ragged_tensor.RaggedTensor.from_tensor(
              rt_input.flat_values, ragged_rank=inner_rank_diff,
              row_splits_dtype=dst_shape.dim_size_dtype))
  else:
    rt_input = ragged_tensor.RaggedTensor.from_tensor(
        rt_input, ragged_rank=dst_shape.num_partitioned_dimensions - 1,
        row_splits_dtype=dst_shape.dim_size_dtype)

  # Do broadcasting for any dimensions that will remain uniform.  We can do
  # these all at once, since they're independent of one another.
  multiples = [1] * dst_shape.rank
  for axis in range(dst_shape.num_partitioned_dimensions):
    if not src_shape.is_ragged(axis) and not dst_shape.is_ragged(axis):
      src_size = src_shape.dimension_size(axis)
      dst_size = dst_shape.dimension_size(axis)
      if ((tensor_util.constant_value(src_size) in (1, None)) and
          (tensor_util.constant_value(dst_size) != 1)):
        multiples[axis] = array_ops.where(
            math_ops.equal(src_size, 1), dst_size, 1)
  if not all(isinstance(v, int) and v == 1 for v in multiples):
    multiples = array_ops.stack(multiples, axis=0)
    rt_input = ragged_array_ops.tile(rt_input, multiples)

  if broadcast_inner_dimensions:
    rt_input = rt_input.with_flat_values(
        array_ops.reshape(
            rt_input.flat_values,
            array_ops.concat([[-1], dst_shape.inner_dim_sizes], axis=0)))

  # Do broadcasting for dimensions that become ragged.  We must do these from
  # outermost to innermost.
  for axis in range(dst_shape.num_partitioned_dimensions):
    if not src_shape.is_ragged(axis) and dst_shape.is_ragged(axis):
      dst_size = dst_shape.dimension_size(axis)
      rt_input = _ragged_tile_axis(rt_input, axis, dst_size,
                                   dst_shape.dim_size_dtype)

  return rt_input
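To make the broadcasting behavior implemented by the private helper above concrete, here is a small hedged sketch using the public TF2 API, under the assumption that ragged elementwise ops broadcast a uniform dimension of size 1 against a ragged dimension (which is what _broadcast_to_ragged_shape handles internally):

import tensorflow as tf

x = tf.ragged.constant([[1, 2, 3], [4]])   # shape [2, None] (second dim ragged)
y = tf.constant([[10], [20]])              # shape [2, 1]    (uniform)

# The size-1 second dimension of `y` broadcasts against the ragged dimension
# of `x`: 10 is added to every value in row 0, and 20 to every value in row 1.
print(x + y)
# Expected output, roughly: <tf.RaggedTensor [[11, 12, 13], [24]]>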
Example 7
def gather_nd(params, indices, batch_dims=0, name=None):
  """Gather slices from `params` using `n`-dimensional indices.

  This operation is similar to `gather`, but it uses the innermost dimension
  of `indices` to define a slice into `params`.  In particular, if:

  * `indices` has shape `[A1...AN, I]`
  * `params` has shape `[B1...BM]`

  Then:

  * `result` has shape `[A1...AN, B_{I+1}...BM]`.
  * `result[a1...aN] = params[indices[a1...aN, :]]`

  Args:
    params: A potentially ragged tensor with shape `[B1...BM]`.
    indices: A potentially ragged tensor with shape `[A1...AN, I]`.
    batch_dims: Must be zero.
    name: A name for the operation (optional).

  Returns:
    A potentially ragged tensor with shape `[A1...AN, B_{I+1}...BM]`.

  #### Examples:
    ```python
    >>> params = tf.compat.v1.ragged.constant_value(
    ...     [ [ ['000', '001'], ['010'              ]          ],
    ...       [ ['100'       ], ['110', '111', '112'], ['120'] ],
    ...       [ [            ], ['210'              ]          ] ])

    >>> # Gather 2D slices from a 3D tensor
    >>> ragged.gather_nd(params, [[2], [0]])
    [ [ [            ], ['210'] ],
      [ ['000', '001'], ['010'] ] ]

    >>> # Gather 1D slices from a 3D tensor
    >>> ragged.gather_nd(params, [[2, 1], [0, 0]])
    [['210'], ['000', '001']]

    >>> # Gather scalars from a 3D tensor
    >>> ragged.gather_nd(params, [[0, 0, 1], [1, 1, 2]])
    ['001', '112']
    ```
  """
  if not isinstance(batch_dims, int) or batch_dims != 0:
    raise ValueError('batch_dims != 0 is not supported for ragged gather yet.')
  if not (ragged_tensor.is_ragged(params) or ragged_tensor.is_ragged(indices)):
    return array_ops.gather_nd(params, indices, name)

  with ops.name_scope(name, 'RaggedGatherNd', [params, indices]):

    params = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        params, name='params')
    indices = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        indices, name='indices')
    params, indices = ragged_tensor.match_row_splits_dtypes(params, indices)
    indices_shape = indices.shape
    indices_ndims = indices_shape.ndims
    if indices_ndims is None:
      raise ValueError('indices.rank must be statically known.')
    if indices_ndims == 0:
      raise ValueError('indices.rank must be at least 1.')
    if (ragged_tensor.is_ragged(indices) and
        indices_ndims == indices.ragged_rank + 1):
      raise ValueError('The innermost dimension of indices may not be ragged')

    # `index_size` is the "n" in "gather_nd" -- i.e., the number of dimensions
    # that each index slices into.
    index_size = tensor_shape.dimension_value(indices_shape[-1])
    if index_size is None:
      raise ValueError('indices.shape[-1] must be statically known.')

    # If `indices` has more than 2 dimensions, then recurse.  If `indices` is
    # dense, then we convert it to ragged before recursing, and then convert
    # the result back to `dense` if appropriate.
    if indices_ndims > 2:
      indices_is_dense = not ragged_tensor.is_ragged(indices)
      if indices_is_dense:
        indices = ragged_tensor.RaggedTensor.from_tensor(
            indices, ragged_rank=indices_ndims - 2,
            row_splits_dtype=params.row_splits.dtype)
      result = indices.with_flat_values(gather_nd(params, indices.flat_values))
      if (indices_is_dense and ragged_tensor.is_ragged(result) and
          result.ragged_rank == indices_ndims - 2):
        result = ragged_tensor.RaggedTensor.to_tensor(result)
      return result

    # indices_ndims <= 2, and the innermost dimension of indices may not be
    # ragged, so `indices` must not be ragged.
    assert not ragged_tensor.is_ragged(indices)
    assert ragged_tensor.is_ragged(params)

    # Handle corner case: An empty index tuple selects the entire `params`
    # value.  So if `index_size` is zero, then tile `params`.
    if index_size == 0:
      params_ndims = params.ragged_rank + array_ops.rank(params.flat_values)
      for dim in range(indices_ndims - 1):
        params = ragged_array_ops.expand_dims(params, axis=0)
      multiples = array_ops.concat([
          array_ops.shape(indices)[:-1],
          array_ops.ones([params_ndims], dtypes.int32)
      ],
                                   axis=0)
      return ragged_array_ops.tile(params, multiples)

    # When index_size=1, we can just flatten the index tuples and use gather.
    elif index_size == 1:
      flattened_index_tuples = array_ops.reshape(indices, [-1])
      return gather(params, flattened_index_tuples)

    # Otherwise, params is a RaggedTensor, and indices is a 1D or 2D Tensor.
    # Flatten both the index tuples and the params, such that the flattened
    # index tuples point to the correct values in the flattened params; and
    # then use ragged.gather on the flattened index tuples & params.
    else:
      indices = math_ops.cast(indices, params.row_splits.dtype)

      # Flatten the outermost 2 dimensions of the index tuples & params.
      flattened_index_tuples = array_ops.gather(params.row_splits,
                                                indices[..., 0])
      flattened_index_tuples += indices[..., 1]
      flattened_params = params.values

      # Flatten any remaining dimensions.
      for dim in range(2, index_size):
        if not ragged_tensor.is_ragged(flattened_params):
          flattened_index_tuples = array_ops.expand_dims(
              flattened_index_tuples, axis=1)
          flattened_index_tuples = array_ops.concat(
              [flattened_index_tuples, indices[..., dim:]], axis=1)
          return array_ops.gather_nd(flattened_params, flattened_index_tuples)

        flattened_index_tuples = array_ops.gather(
            flattened_params.row_starts(), flattened_index_tuples)
        flattened_index_tuples += indices[..., dim]
        flattened_params = flattened_params.values

      # Gather using the flattened index tuples and params.
      return gather(flattened_params, flattened_index_tuples)
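The index-flattening step above (gather params.row_splits at indices[..., 0], then add indices[..., 1]) is easiest to see with concrete numbers. A hedged plain-NumPy sketch with hypothetical data:

import numpy as np

# A ragged params [[a, b], [c], [d, e, f]] is stored as a flat values array
# plus row_splits marking where each row starts and ends.
values = np.array(['a', 'b', 'c', 'd', 'e', 'f'])
row_splits = np.array([0, 2, 3, 6])

indices = np.array([[2, 1],    # params[2][1] -> 'e'
                    [0, 0]])   # params[0][0] -> 'a'

# row_splits[i] is the offset of row i inside `values`; adding j then selects
# the j-th element of that row -- the same arithmetic as the code above.
flat_index = row_splits[indices[:, 0]] + indices[:, 1]
print(values[flat_index])      # expected: ['e' 'a']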
Example 8
def _broadcast_to_ragged_shape(rt_input, dst_shape,
                               broadcast_inner_dimensions):
    """Broadcasts rt_input to the ragged shape `dst_shape`."""
    # Check that rt_input and dst_shape have the same row_splits dtype.
    if (isinstance(rt_input, ragged_tensor.RaggedTensor)
            and rt_input.row_splits.dtype != dst_shape.dim_size_dtype):
        if not ragged_config.auto_cast_partition_dtype():
            raise ValueError(
                'rt_input and dst_shape have different row_split '
                'dtypes; use RaggedTensor.with_row_splits_dtype() or '
                'RaggedTensorDynamicShape.with_dim_size_dtype() to '
                'convert to a compatible dtype.')
        rt_input = rt_input.with_row_splits_dtype(dtypes.int64)
        dst_shape = dst_shape.with_dim_size_dtype(dtypes.int64)

    # dst_shape's rank and ragged_rank must be greater than or equal to rt_input's
    if rt_input.shape.ndims is None or dst_shape.rank is None:
        raise ValueError('Unable to broadcast: unknown rank')
    if rt_input.shape.ndims > dst_shape.rank:
        raise ValueError('Incompatible with shape: rank mismatch')
    if (isinstance(rt_input, ragged_tensor.RaggedTensor)
            and rt_input.ragged_rank >= dst_shape.num_partitioned_dimensions):
        raise ValueError('Incompatible with shape: ragged rank mismatch')

    src_shape = RaggedTensorDynamicShape.from_tensor(rt_input)
    src_shape = src_shape.broadcast_to_rank(dst_shape.rank)

    # Add dimensions to rt_input so its rank and ragged_rank match dst_shape.
    if dst_shape.rank > rt_input.shape.ndims:
        if rt_input.shape.ndims < dst_shape.num_inner_dimensions + 1:
            rt_input = array_ops.reshape(
                rt_input,
                array_ops.concat([[-1], dst_shape.inner_dim_sizes], axis=0))
        for _ in range(dst_shape.rank - rt_input.shape.ndims):
            if ragged_tensor.is_ragged(rt_input):
                nrows = rt_input.nrows()
            else:
                nrows = array_ops.shape(rt_input,
                                        out_type=dst_shape.dim_size_dtype)[0]
            rt_input = ragged_tensor.RaggedTensor.from_row_lengths(
                rt_input, [nrows], validate=False)

    # Add ragged dimensions to match dst_shape.
    if ragged_tensor.is_ragged(rt_input):
        inner_rank_diff = (rt_input.flat_values.shape.ndims - 1 -
                           dst_shape.num_inner_dimensions)
        if inner_rank_diff > 0:
            rt_input = rt_input.with_flat_values(
                ragged_tensor.RaggedTensor.from_tensor(
                    rt_input.flat_values,
                    ragged_rank=inner_rank_diff,
                    row_splits_dtype=dst_shape.dim_size_dtype))
    else:
        rt_input = ragged_tensor.RaggedTensor.from_tensor(
            rt_input,
            ragged_rank=dst_shape.num_partitioned_dimensions - 1,
            row_splits_dtype=dst_shape.dim_size_dtype)

    # Do broadcasting for any dimensions that will remain uniform.  We can do
    # these all at once, since they're independent of one another.
    multiples = [1] * dst_shape.rank
    for axis in range(dst_shape.num_partitioned_dimensions):
        if not src_shape.is_ragged(axis) and not dst_shape.is_ragged(axis):
            src_size = src_shape.dimension_size(axis)
            dst_size = dst_shape.dimension_size(axis)
            if ((tensor_util.constant_value(src_size) in (1, None))
                    and (tensor_util.constant_value(dst_size) != 1)):
                multiples[axis] = array_ops.where(math_ops.equal(src_size, 1),
                                                  dst_size, 1)
    if not all(isinstance(v, int) and v == 1 for v in multiples):
        multiples = array_ops.stack(multiples, axis=0)
        rt_input = ragged_array_ops.tile(rt_input, multiples)

    if broadcast_inner_dimensions:
        rt_input = rt_input.with_flat_values(
            array_ops.reshape(
                rt_input.flat_values,
                array_ops.concat([[-1], dst_shape.inner_dim_sizes], axis=0)))

    # Do broadcasting for dimensions that become ragged.  We must do these from
    # outermost to innermost.
    for axis in range(dst_shape.num_partitioned_dimensions):
        if not src_shape.is_ragged(axis) and dst_shape.is_ragged(axis):
            dst_size = dst_shape.dimension_size(axis)
            rt_input = _ragged_tile_axis(rt_input, axis, dst_size,
                                         dst_shape.dim_size_dtype)

    return rt_input
def gather_nd(params, indices, batch_dims=0, name=None):
    """Gather slices from `params` using `n`-dimensional indices.

  This operation is similar to `gather`, but it uses the innermost dimension
  of `indices` to define a slice into `params`.  In particular, if:

  * `indices` has shape `[A1...AN, I]`
  * `params` has shape `[B1...BM]`

  Then:

  * `result` has shape `[A1...AN, B_{I+1}...BM]`.
  * `result[a1...aN] = params[indices[a1...aN, :]]`

  Args:
    params: A potentially ragged tensor with shape `[B1...BM]`.
    indices: A potentially ragged tensor with shape `[A1...AN, I]`.
    batch_dims: Must be zero.
    name: A name for the operation (optional).

  Returns:
    A potentially ragged tensor with shape `[A1...AN, B_{I+1}...BM]`.

  #### Examples:

  >>> params = tf.ragged.constant(
  ...     [ [ ['000', '001'], ['010'              ]          ],
  ...       [ ['100'       ], ['110', '111', '112'], ['120'] ],
  ...       [ [            ], ['210'              ]          ] ])

  >>> # Gather 2D slices from a 3D tensor
  >>> tf.gather_nd(params, [[2], [0]])
  <tf.RaggedTensor [[[], [b'210']], [[b'000', b'001'], [b'010']]]>

  >>> # Gather 1D slices from a 3D tensor
  >>> tf.gather_nd(params, [[2, 1], [0, 0]])
  <tf.RaggedTensor [[b'210'], [b'000', b'001']]>

  >>> # Gather scalars from a 3D tensor
  >>> tf.gather_nd(params, [[0, 0, 1], [1, 1, 2]]).numpy()
  array([b'001', b'112'], dtype=object)
  """
    if not isinstance(batch_dims, int) or batch_dims != 0:
        raise ValueError(
            'batch_dims != 0 is not supported for ragged gather yet.')
    if not (ragged_tensor.is_ragged(params)
            or ragged_tensor.is_ragged(indices)):
        return array_ops.gather_nd(params, indices, name)

    with ops.name_scope(name, 'RaggedGatherNd', [params, indices]):

        params = ragged_tensor.convert_to_tensor_or_ragged_tensor(
            params, name='params')
        indices = ragged_tensor.convert_to_tensor_or_ragged_tensor(
            indices, name='indices')
        params, indices = ragged_tensor.match_row_splits_dtypes(
            params, indices)
        indices_shape = indices.shape
        indices_ndims = indices_shape.ndims
        if indices_ndims is None:
            raise ValueError('indices.rank must be statically known.')
        if indices_ndims == 0:
            raise ValueError('indices.rank must be at least 1.')
        if (ragged_tensor.is_ragged(indices)
                and indices_ndims == indices.ragged_rank + 1):
            raise ValueError(
                'The innermost dimension of indices may not be ragged')

        # `index_size` is the "n" in "gather_nd" -- i.e., the number of dimensions
        # that each index slices into.
        index_size = tensor_shape.dimension_value(indices_shape[-1])
        if index_size is None:
            raise ValueError('indices.shape[-1] must be statically known.')

        # If `indices` has more than 2 dimensions, then recurse.  If `indices` is
        # dense, then we convert it to ragged before recursing, and then convert
        # the result back to `dense` if appropriate.
        if indices_ndims > 2:
            indices_is_dense = not ragged_tensor.is_ragged(indices)
            if indices_is_dense:
                indices = ragged_tensor.RaggedTensor.from_tensor(
                    indices,
                    ragged_rank=indices_ndims - 2,
                    row_splits_dtype=params.row_splits.dtype)
            result = indices.with_flat_values(
                gather_nd(params, indices.flat_values))
            if (indices_is_dense and ragged_tensor.is_ragged(result)
                    and result.ragged_rank == indices_ndims - 2):
                result = ragged_tensor.RaggedTensor.to_tensor(result)
            return result

        # indices_ndims <= 2, and the innermost dimension of indices may not be
        # ragged, so `indices` must not be ragged.
        assert not ragged_tensor.is_ragged(indices)
        assert ragged_tensor.is_ragged(params)

        # Handle corner case: An empty index tuple selects the entire `params`
        # value.  So if `index_size` is zero, then tile `params`.
        if index_size == 0:
            params_ndims = params.ragged_rank + array_ops.rank(
                params.flat_values)
            for dim in range(indices_ndims - 1):
                params = ragged_array_ops.expand_dims(params, axis=0)
            multiples = array_ops.concat([
                array_ops.shape(indices)[:-1],
                array_ops.ones([params_ndims], dtypes.int32)
            ],
                                         axis=0)
            return ragged_array_ops.tile(params, multiples)

        # When index_size=1, we can just flatten the index tuples and use gather.
        elif index_size == 1:
            flattened_index_tuples = array_ops.reshape(indices, [-1])
            return gather(params, flattened_index_tuples)

        # Otherwise, params is a RaggedTensor, and indices is a 1D or 2D Tensor.
        # Flatten both the index tuples and the params, such that the flattened
        # index tuples point to the correct values in the flattened params; and
        # then use ragged.gather on the flattened index tuples & params.
        else:
            indices = math_ops.cast(indices, params.row_splits.dtype)

            # Flatten the outermost 2 dimensions of the index tuples & params.
            flattened_index_tuples = array_ops.gather(params.row_splits,
                                                      indices[..., 0])
            flattened_index_tuples += indices[..., 1]
            flattened_params = params.values

            # Flatten any remaining dimensions.
            for dim in range(2, index_size):
                if not ragged_tensor.is_ragged(flattened_params):
                    flattened_index_tuples = array_ops.expand_dims(
                        flattened_index_tuples, axis=1)
                    flattened_index_tuples = array_ops.concat(
                        [flattened_index_tuples, indices[..., dim:]], axis=1)
                    return array_ops.gather_nd(flattened_params,
                                               flattened_index_tuples)

                flattened_index_tuples = array_ops.gather(
                    flattened_params.row_starts(), flattened_index_tuples)
                flattened_index_tuples += indices[..., dim]
                flattened_params = flattened_params.values

            # Gather using the flattened index tuples and params.
            return gather(flattened_params, flattened_index_tuples)
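Finally, the index_size == 0 corner case handled above (an empty index tuple selects all of params) can be checked with a short hedged sketch against the public TF2 API, assuming tf.gather_nd dispatches to this ragged implementation for RaggedTensor inputs:

import tensorflow as tf

params = tf.ragged.constant([[1, 2], [3]])

# indices has shape [2, 0]: two index tuples, each of size zero.  Each empty
# tuple selects the whole `params`, so the result should be `params` stacked
# twice along a new outer dimension.
indices = tf.zeros([2, 0], dtype=tf.int32)
print(tf.gather_nd(params, indices))
# Expected output, roughly:
# <tf.RaggedTensor [[[1, 2], [3]], [[1, 2], [3]]]>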