def _broadcast_elementwise_args(elementwise_args):
  """Broadcasts the values of `elementwise_args` to have compatible shapes.

  Args:
    elementwise_args: A dictionary whose keys are potentially ragged tensors.

  Returns:
    A tuple `(broadcast_args, broadcast_splits, checks)` where:

    * `broadcast_args` is a dictionary with the same keys as
      `elementwise_args`, mapping to broadcasted tensors.
    * `broadcast_splits` is the broadcasted nested row splits.
    * `checks` is a possibly empty tuple of assertion operations that should
      be added as control dependencies.

  Raises:
    ValueError: If broadcasting fails.
  """
  # No elementwise arguments were used: nothing to do!
  if not elementwise_args:
    return elementwise_args, (), ()

  # A single elementwise argument was used: no broadcasting necessary.
  if len(elementwise_args) == 1:
    arg = list(elementwise_args.values())[0]
    if ragged_tensor.is_ragged(arg):
      return elementwise_args, arg.nested_row_splits, ()
    else:
      return elementwise_args, (), ()

  # Multiple elementwise arguments.
  else:
    is_ragged = [ragged_tensor.is_ragged(t) for t in elementwise_args.values()]
    if not any(is_ragged):
      return elementwise_args, (), ()

    # Support limited broadcasting (namely, scalar + ragged).  Full
    # broadcasting support will be added later.
    if all((ragged_tensor.is_ragged(t) or t.shape.ndims == 0)
           for t in elementwise_args.values()):
      nested_splits_lists = [
          t.nested_row_splits
          for t in elementwise_args.values()
          if ragged_tensor.is_ragged(t)
      ]
      if len(nested_splits_lists) == 1:
        checks = ()
      else:
        if any(t.shape.ndims is None for t in elementwise_args.values()):
          raise ValueError('Ragged elementwise ops require that rank (number '
                           'of dimensions) be statically known.')
        if len(set(t.shape.ndims for t in elementwise_args.values())) != 1:
          raise ValueError('Ragged elementwise ops do not support '
                           'broadcasting yet')
        checks = ragged_util.assert_splits_match(nested_splits_lists)
      return (elementwise_args, nested_splits_lists[0], checks)
    else:
      raise ValueError('Ragged elementwise ops do not support broadcasting yet')
  def assertRaggedAlmostEqual(self, a, b, places=7):
    """Asserts that two potentially ragged values are element-wise close.

    Both values are converted to Python lists and compared to `places`
    decimal places.  When neither argument is a Python list or tuple, their
    ragged ranks (0 for non-ragged values) must also be equal.
    """
    self.assertNestedListAlmostEqual(
        self._GetPyList(a), self._GetPyList(b), places, context='value')

    if any(isinstance(t, (list, tuple)) for t in (a, b)):
      return
    ragged_ranks = [t.ragged_rank if ragged_tensor.is_ragged(t) else 0
                    for t in (a, b)]
    self.assertEqual(*ragged_ranks)
def _unicode_decode(input, input_encoding, errors, replacement_char,
                    replace_control_characters, with_offsets):
  """Decodes each string into a sequence of codepoints.

  Args:
    input: A string `Tensor` or `RaggedTensor`, or a value convertible to one
      by `ragged_tensor.convert_to_tensor_or_ragged_tensor`.  Its rank must be
      statically known.
    input_encoding: Passed through to the decode op.
    errors: Passed through to the decode op.
    replacement_char: Passed through to the decode op.
    replace_control_characters: Passed through to the decode op.
    with_offsets: If true, use `unicode_decode_with_offsets` and also return
      the `char_to_byte_starts` offsets.

  Returns:
    `codepoints`, or `(codepoints, offsets)` when `with_offsets` is true.  For
    a scalar `input` these are the op's flat outputs.  Otherwise they are
    `RaggedTensor`s whose outer dimensions follow `input` and whose innermost
    ragged dimension holds the characters of each string.

  Raises:
    ValueError: If the rank of `input` is not statically known.
  """
  input = ragged_tensor.convert_to_tensor_or_ragged_tensor(input, name="input")
  input_ndims = input.shape.ndims
  if input_ndims is None:
    raise ValueError("Rank of `input` must be statically known.")

  if input_ndims > 1:
    # Convert to a ragged tensor with ragged_rank = input_ndims - 1, so that
    # `flat_values` is a 1-D vector of strings.
    if not ragged_tensor.is_ragged(input):
      input = ragged_tensor.RaggedTensor.from_tensor(
          input, ragged_rank=input_ndims - 1)
    elif input.ragged_rank < input_ndims - 1:
      # `flat_values` has rank `input_ndims - input.ragged_rank`.  To reach a
      # total ragged rank of `input_ndims - 1`, it needs
      # `input_ndims - input.ragged_rank - 1` more ragged dimensions.  This
      # is always smaller than its rank, as `from_tensor` requires.
      input = input.with_flat_values(
          ragged_tensor.RaggedTensor.from_tensor(
              input.flat_values,
              ragged_rank=input_ndims - input.ragged_rank - 1))

  # Reshape the input to a flat vector, and apply the gen_string_ops op.
  if ragged_tensor.is_ragged(input):
    flat_input = array_ops.reshape(input.flat_values, [-1])
  else:
    flat_input = array_ops.reshape(input, [-1])

  if with_offsets:
    decode_op = gen_string_ops.unicode_decode_with_offsets
  else:
    decode_op = gen_string_ops.unicode_decode
  flat_result = decode_op(
      input=flat_input,
      input_encoding=input_encoding,
      errors=errors,
      replacement_char=replacement_char,
      replace_control_characters=replace_control_characters)

  if input_ndims == 0:
    codepoints = flat_result.char_values
    if with_offsets:
      offsets = flat_result.char_to_byte_starts
  else:
    # One row of characters per input string.
    codepoints = ragged_tensor.RaggedTensor.from_row_splits(
        flat_result.char_values, flat_result.row_splits, validate=False)
    if input_ndims > 1:
      # Re-attach the outer (ragged) dimensions of the input.
      codepoints = input.with_flat_values(codepoints)
    if with_offsets:
      offsets = ragged_tensor.RaggedTensor.from_row_splits(
          flat_result.char_to_byte_starts, flat_result.row_splits,
          validate=False)
      if input_ndims > 1:
        offsets = input.with_flat_values(offsets)

  if with_offsets:
    return codepoints, offsets
  else:
    return codepoints
  def assertRaggedEqual(self, a, b):
    """Asserts that two potentially ragged values are equal.

    Both values are converted to Python lists and compared for equality.  If
    neither is a Python list or tuple, their ragged ranks (0 for non-ragged
    values) must match as well.
    """
    a_as_list, b_as_list = self._GetPyList(a), self._GetPyList(b)
    self.assertEqual(a_as_list, b_as_list)

    if any(isinstance(t, (list, tuple)) for t in (a, b)):
      return
    rank_a, rank_b = (t.ragged_rank if ragged_tensor.is_ragged(t) else 0
                      for t in (a, b))
    self.assertEqual(rank_a, rank_b)
Example #5
0
  def handle(self, args, kwargs):
    """Dispatches a binary elementwise op when either operand is ragged.

    Returns `self.NOT_SUPPORTED` if neither operand is ragged, or if the
    non-ragged operand cannot be converted to a tensor.  Otherwise the
    operands are broadcast when needed, `self._original_op` is applied to
    their flat values, and the result takes on the row partitions of the
    ragged operand (`x` if ragged, else `y`).
    """
    # Pull the two operands out of the positional / keyword arguments.
    num_positional = len(args)
    if num_positional >= 2:
      x, y, args = args[0], args[1], args[2:]
    else:
      kwargs = kwargs.copy()
      if num_positional == 1:
        x, args = args[0], args[1:]
      else:
        x = kwargs.pop(self._x, None)
      y = kwargs.pop(self._y, None)

    x_is_ragged = ragged_tensor.is_ragged(x)
    y_is_ragged = ragged_tensor.is_ragged(y)
    if not x_is_ragged and not y_is_ragged:
      return self.NOT_SUPPORTED

    # Convert the dense operand(s); give up if conversion is impossible.
    try:
      if not x_is_ragged:
        x = ops.convert_to_tensor(x, name=self._x, preferred_dtype=y.dtype)
      if not y_is_ragged:
        y = ops.convert_to_tensor(y, name=self._y, preferred_dtype=x.dtype)
    except (TypeError, ValueError):
      return self.NOT_SUPPORTED

    both_ragged = x_is_ragged and y_is_ragged
    if both_ragged:
      x, y = ragged_tensor.match_row_splits_dtypes(x, y)

    needs_broadcast = (
        both_ragged or
        (x_is_ragged and x.flat_values.shape.ndims <= y.shape.ndims) or
        (y_is_ragged and y.flat_values.shape.ndims <= x.shape.ndims))
    if needs_broadcast:
      shape_cls = ragged_tensor_shape.RaggedTensorDynamicShape
      bcast_shape = ragged_tensor_shape.broadcast_dynamic_shape(
          shape_cls.from_tensor(x), shape_cls.from_tensor(y))
      x = ragged_tensor_shape.broadcast_to(
          x, bcast_shape, broadcast_inner_dimensions=False)
      y = ragged_tensor_shape.broadcast_to(
          y, bcast_shape, broadcast_inner_dimensions=False)

    x_values = x.flat_values if ragged_tensor.is_ragged(x) else x
    y_values = y.flat_values if ragged_tensor.is_ragged(y) else y
    mapped_values = self._original_op(x_values, y_values, *args, **kwargs)
    ragged_operand = x if ragged_tensor.is_ragged(x) else y
    return ragged_operand.with_flat_values(mapped_values)
  def ragged_op(*args, **kwargs):
    """Ragged version of `op`.

    Collects the elementwise arguments described by `elementwise_arg_infos`
    (from the enclosing scope) and converts them to tensors or ragged
    tensors.  It then checks their row partitions with
    `_broadcast_elementwise_args`, calls `op` on the dense inner values, and
    re-attaches the shared nested row splits to the result.

    Args:
      *args: Positional arguments for `op`.
      **kwargs: Keyword arguments for `op`.

    Returns:
      `op`'s result, wrapped by `ragged_factory_ops.from_nested_row_splits`
      using the nested row splits returned by `_broadcast_elementwise_args`.

    Raises:
      ValueError: If `_broadcast_elementwise_args` cannot broadcast the
        elementwise arguments.
    """
    # Copy to a list so that individual positional arguments can be replaced.
    args = list(args)

    # Collect all of the elementwise arguments, and put them in a single
    # dict whose values are the (potentially ragged) tensors that need to
    # be broadcast to a common shape.  The keys of this dict are tuples
    # (argkey, index), where argkey is an int for positional args or a string
    # for keyword args; and index is None for non-list args and the index of the
    # tensor for list args.
    elementwise_args = {}
    for (name, position, is_list) in elementwise_arg_infos.values():
      if position < len(args):
        if is_list:
          # Copy the sequence so that replacing its elements below does not
          # mutate the caller's object.
          args[position] = list(args[position])
          for (index, arg) in enumerate(args[position]):
            elementwise_args[position, index] = arg
        else:
          elementwise_args[position, None] = args[position]
      elif name in kwargs:
        if is_list:
          kwargs[name] = list(kwargs[name])
          for (i, arg) in enumerate(kwargs[name]):
            elementwise_args[name, i] = arg
        else:
          elementwise_args[name, None] = kwargs[name]

    with ops.name_scope(None, op.__name__, elementwise_args.values()):
      # Convert all inputs to tensors or ragged tensors.
      # (Reassigning existing keys while iterating is safe: the dict's size
      # does not change.)
      for ((key, index), tensor) in elementwise_args.items():
        # NOTE(review): `key` is a position (int) or a keyword name (str);
        # this lookup assumes `elementwise_arg_infos` is keyed by both —
        # confirm against its construction.
        argname = elementwise_arg_infos[key].name
        converted = ragged_factory_ops.convert_to_tensor_or_ragged_tensor(
            tensor, name=argname)
        elementwise_args[key, index] = converted

      # Broadcast tensors to have compatible shapes.
      broadcast_args, result_splits, broadcast_check_ops = \
          _broadcast_elementwise_args(elementwise_args)

      # Replace tensor arguments with their dense values.  Non-ragged
      # (dense) arguments are left in place unchanged.
      for ((key, index), tensor) in broadcast_args.items():
        if ragged_tensor.is_ragged(tensor):
          if isinstance(key, int) and index is None:
            args[key] = tensor.inner_values
          elif isinstance(key, int) and index is not None:
            args[key][index] = tensor.inner_values
          elif isinstance(key, str) and index is None:
            kwargs[key] = tensor.inner_values
          else:
            assert isinstance(key, str) and index is not None
            kwargs[key][index] = tensor.inner_values

      # Call the elementwise op on the broadcasted dense values.  The
      # splits-match assertions run before the op via control dependencies.
      with ops.control_dependencies(broadcast_check_ops):
        result_values = op(*args, **kwargs)

      # Restore any ragged dimensions that we stripped off, and return the
      # result.
      return ragged_factory_ops.from_nested_row_splits(result_values,
                                                       result_splits)
def _ragged_tile_axis(rt_input, axis, repeats, row_splits_dtype):
  """Tile a dimension of a RaggedTensor to match a ragged shape.

  Args:
    rt_input: A `Tensor` or `RaggedTensor`.  A dense input is first converted
      to a `RaggedTensor` with `ragged_rank=1`.
    axis: The dimension to tile.  Must be greater than 0, because the
      outermost dimension may not be ragged.
    repeats: Repeat counts.  They become the row lengths of the tiled
      dimension and are passed to `ragged_util.repeat_ranges`.
    row_splits_dtype: Row-splits dtype used when converting a dense input.

  Returns:
    A `RaggedTensor` built with `from_nested_row_lengths` (or, for
    `axis > 1`, `rt_input` with tiled values).
  """
  assert axis > 0  # Outermost dimension may not be ragged.

  if not ragged_tensor.is_ragged(rt_input):
    rt_input = ragged_tensor.RaggedTensor.from_tensor(
        rt_input, ragged_rank=1, row_splits_dtype=row_splits_dtype)

  if axis > 1:
    # Descend into the values until the target axis is the first partition.
    tiled_values = _ragged_tile_axis(rt_input.values, axis - 1, repeats,
                                     row_splits_dtype)
    return rt_input.with_values(tiled_values)

  nested_splits = rt_input.nested_row_splits
  nested_lengths = rt_input.nested_row_lengths()
  ranges = nested_splits[0]

  tiled_lengths = [repeats]
  for level in range(1, len(nested_lengths)):
    tiled_lengths.append(
        ragged_util.repeat_ranges(nested_lengths[level], ranges, repeats))
    # Map the row ranges down to the next partitioning level.
    ranges = array_ops.gather(nested_splits[level], ranges)
  tiled_flat_values = ragged_util.repeat_ranges(rt_input.flat_values, ranges,
                                                repeats)
  return ragged_tensor.RaggedTensor.from_nested_row_lengths(
      tiled_flat_values, tiled_lengths, validate=False)
def reduce_mean(rt_input, axis=None, name=None):
  """For docs, see: _RAGGED_REDUCE_DOCSTRING.

  The mean is computed as `reduce_sum(rt_input) / reduce_sum(ones)`, where
  `ones` has the same (ragged or dense) structure as `rt_input`.
  """
  with ops.name_scope(name, 'RaggedReduceMean', [rt_input, axis]):
    total = reduce_sum(rt_input, axis)
    if not ragged_tensor.is_ragged(rt_input):
      ones = array_ops.ones_like(rt_input)
    else:
      ones = ragged_factory_ops.from_nested_row_splits(
          array_ops.ones_like(rt_input.inner_values),
          rt_input.nested_row_splits)
    count = reduce_sum(ones, axis)
    if not ragged_tensor.is_ragged(total):
      return total / count
    return ragged_factory_ops.from_nested_row_splits(
        total.inner_values / count.inner_values, total.nested_row_splits)
Example #9
0
  def assertDatasetsEqual(self, dataset1, dataset2):
    """Checks that datasets are equal. Supports both graph and eager mode.

    The datasets must have mutually compatible structures and yield the same
    number of elements.  Each flattened component is compared according to
    its kind:

    * sparse: `assertSparseValuesEqual`
    * ragged: `assertRaggedEqual`
    * string dtype: `assertAllEqual`
    * everything else: `assertAllClose`
    """
    self.assertTrue(dataset_ops.get_structure(dataset1).is_compatible_with(
        dataset_ops.get_structure(dataset2)))
    self.assertTrue(dataset_ops.get_structure(dataset2).is_compatible_with(
        dataset_ops.get_structure(dataset1)))
    flattened_types = nest.flatten(
        dataset_ops.get_legacy_output_types(dataset1))

    next1 = self.getNext(dataset1)
    next2 = self.getNext(dataset2)

    while True:
      try:
        op1 = self.evaluate(next1())
      except errors.OutOfRangeError:
        # dataset1 is exhausted, so dataset2 must be exhausted as well.
        with self.assertRaises(errors.OutOfRangeError):
          self.evaluate(next2())
        break
      op2 = self.evaluate(next2())

      op1 = nest.flatten(op1)
      op2 = nest.flatten(op2)
      # Use a test assertion rather than a bare `assert`: the latter is
      # stripped under `python -O` and reports no values on failure.
      self.assertEqual(len(op1), len(op2))
      for i, (value1, value2) in enumerate(zip(op1, op2)):
        if sparse_tensor.is_sparse(value1):
          self.assertSparseValuesEqual(value1, value2)
        elif ragged_tensor.is_ragged(value1):
          self.assertRaggedEqual(value1, value2)
        elif flattened_types[i] == dtypes.string:
          self.assertAllEqual(value1, value2)
        else:
          self.assertAllClose(value1, value2)
def to_sparse(rt_input, name=None):
  """Converts a `RaggedTensor` into a sparse tensor.

  Example:

  ```python
  >>> rt = ragged.constant([[1, 2, 3], [4], [], [5, 6]])
  >>> ragged.to_sparse(rt).eval()
  SparseTensorValue(indices=[[0, 0], [0, 1], [0, 2], [1, 0], [3, 0], [3, 1]],
                    values=[1, 2, 3, 4, 5, 6],
                    dense_shape=[4, 3])
  ```

  Args:
    rt_input: The input `RaggedTensor`.
    name: A name prefix for the returned tensors (optional).

  Returns:
    A SparseTensor with the same values as `rt_input`.

  Raises:
    TypeError: If `rt_input` is not ragged.
  """
  if ragged_tensor.is_ragged(rt_input):
    with ops.name_scope(name, 'RaggedToSparse', [rt_input]):
      rt_input = ragged_factory_ops.convert_to_tensor_or_ragged_tensor(
          rt_input, name='rt_input')
      converted = gen_ragged_conversion_ops.ragged_tensor_to_sparse(
          rt_input.nested_row_splits, rt_input.inner_values, name=name)
      return sparse_tensor.SparseTensor(
          indices=converted.sparse_indices,
          values=converted.sparse_values,
          dense_shape=converted.sparse_dense_shape)
  raise TypeError('Expected RaggedTensor, got %s' % type(rt_input).__name__)
Example #11
0
def normalize_tensors(tensors):
  """Converts a nested structure of tensor-like objects to tensors.

  * `SparseTensor`-like inputs are converted to `SparseTensor`.
  * Ragged inputs are converted with `convert_to_tensor_or_ragged_tensor`.
  * `TensorArray` inputs are passed through.
  * Everything else is converted to a dense `Tensor`.

  Args:
    tensors: A nested structure of tensor-like, list,
      `SparseTensor`, `SparseTensorValue`, or `TensorArray` objects.

  Returns:
    A nested structure of tensor, `SparseTensor`, or `TensorArray` objects.
  """

  def _normalize_component(index, component):
    """Converts one flattened component, naming it `component_<index>`."""
    component_name = "component_%d" % index
    if sparse_tensor_lib.is_sparse(component):
      return sparse_tensor_lib.SparseTensor.from_value(component)
    if ragged_tensor.is_ragged(component):
      return ragged_tensor.convert_to_tensor_or_ragged_tensor(
          component, name=component_name)
    if isinstance(component, tensor_array_ops.TensorArray):
      return component
    return ops.convert_to_tensor(component, name=component_name)

  flat_tensors = nest.flatten(tensors)
  with ops.name_scope("normalize_tensors"):
    prepared = [
        _normalize_component(i, t) for i, t in enumerate(flat_tensors)
    ]
  return nest.pack_sequence_as(tensors, prepared)
Example #12
0
def rank(input, name=None):  # pylint: disable=redefined-builtin
  """Returns the rank of a RaggedTensor.

  Returns a 0-D `int32` `Tensor` representing the rank of `input`.

  For example:

  ```python
  # shape of tensor 't' is [2, None, None]
  t = tf.ragged.constant([[[1], [2, 2]], [[3, 3, 3], [4, 4, 4, 4]]])
  tf.rank(t)  # 3
  ```

  Args:
    input: A `RaggedTensor` (a non-ragged input is handed to
      `array_ops.rank`).
    name: A name for the operation (optional).

  Returns:
    A `Tensor` of type `int32`.
  """
  with ops.name_scope(name, 'RaggedRank', [input]) as name:
    if ragged_tensor.is_ragged(input):
      # Each ragged partition contributes one dimension on top of the rank
      # of the flat values.
      return input.ragged_rank + array_ops.rank(input.flat_values)
    return array_ops.rank(input, name)
Example #13
0
def reduce_mean(input_tensor, axis=None, keepdims=None, name=None):
  """For docs, see: _RAGGED_REDUCE_DOCSTRING.

  The mean is computed as `reduce_sum(input_tensor) / reduce_sum(ones)`,
  where `ones` has the same (ragged or dense) structure as `input_tensor`.
  """
  with ops.name_scope(name, 'RaggedReduceMean', [input_tensor, axis]):
    total = reduce_sum(input_tensor, axis, keepdims)
    if not ragged_tensor.is_ragged(input_tensor):
      ones = array_ops.ones_like(input_tensor)
    else:
      ones = ragged_tensor.RaggedTensor.from_nested_row_splits(
          array_ops.ones_like(input_tensor.flat_values),
          input_tensor.nested_row_splits)
    count = reduce_sum(ones, axis, keepdims)
    if not ragged_tensor.is_ragged(total):
      return total / count
    mean_values = total.flat_values / count.flat_values
    return ragged_tensor.RaggedTensor.from_nested_row_splits(
        mean_values, total.nested_row_splits)
def _replace_ragged_with_flat_values(value, nested_splits_lists):
  """Replace RaggedTensors with their flat_values, and record their splits.

  Returns a copy of `value` in which every nested `RaggedTensor` is replaced
  by its `flat_values` tensor.  Lists, tuples, and dicts are searched
  recursively; tuples (including namedtuples) are rebuilt as plain tuples and
  dicts as plain dicts.  Any other value is returned as-is.

  Each replaced `RaggedTensor`'s `nested_row_splits` is appended to
  `nested_splits_lists`, in traversal order.

  Args:
    value: The value that should be transformed by replacing `RaggedTensors`.
    nested_splits_lists: An output parameter used to record the `nested_splits`
      for any `RaggedTensors` that were replaced.

  Returns:
    A copy of `value` with nested `RaggedTensors` replaced by their
    `flat_values`.
  """
  if ragged_tensor.is_ragged(value):
    # Base case: record the partitioning, keep only the flat values.
    value = ragged_tensor.convert_to_tensor_or_ragged_tensor(value)
    nested_splits_lists.append(value.nested_row_splits)
    return value.flat_values

  if isinstance(value, list):
    return [_replace_ragged_with_flat_values(item, nested_splits_lists)
            for item in value]
  if isinstance(value, tuple):
    return tuple(_replace_ragged_with_flat_values(item, nested_splits_lists)
                 for item in value)
  if isinstance(value, dict):
    return {key: _replace_ragged_with_flat_values(item, nested_splits_lists)
            for key, item in value.items()}
  return value
Example #15
0
 def _eval_tensor(self, tensor):
   """Evaluates `tensor`, handling `RaggedTensor`s recursively.

   A ragged tensor becomes a `RaggedTensorValue` built from its evaluated
   `values` and `row_splits`.  Everything else is delegated to
   `test_util.TensorFlowTestCase._eval_tensor`.
   """
   if not ragged_tensor.is_ragged(tensor):
     return test_util.TensorFlowTestCase._eval_tensor(self, tensor)
   evaluated_values = self._eval_tensor(tensor.values)
   evaluated_splits = self._eval_tensor(tensor.row_splits)
   return ragged_tensor_value.RaggedTensorValue(evaluated_values,
                                                evaluated_splits)
Example #16
0
  def testToBatchedTensorList(self, value_fn, element_0_fn):
    """Round-trips a batched value through its batched tensor list.

    Checks that every non-variant tensor in the list has leading dimension 2,
    and that unbatching row 0 of each tensor reproduces `element_0_fn()`.
    """
    batched_value = value_fn()
    spec = structure.Structure.from_value(batched_value)
    tensor_list = spec._to_batched_tensor_list(batched_value)

    # The batch dimension is 2 for all of the test cases.
    # NOTE(mrry): `tf.shape()` does not currently work for the DT_VARIANT
    # tensors in which we store sparse tensors.
    shape_checkable = [t for t in tensor_list if t.dtype != dtypes.variant]
    for t in shape_checkable:
      self.assertEqual(2, self.evaluate(array_ops.shape(t)[0]))

    # Unbatch row 0 of every tensor and compare against the expected
    # first element.
    expected_element_0 = self.evaluate(element_0_fn())
    unbatched_spec = spec._unbatch()
    actual_element_0 = unbatched_spec._from_tensor_list(
        [t[0] for t in tensor_list])

    expected_parts = nest.flatten(expected_element_0)
    actual_parts = nest.flatten(actual_element_0)
    for expected, actual in zip(expected_parts, actual_parts):
      if sparse_tensor.is_sparse(expected):
        self.assertSparseValuesEqual(expected, actual)
      elif ragged_tensor.is_ragged(expected):
        self.assertRaggedEqual(expected, actual)
      else:
        self.assertAllEqual(expected, actual)
Example #17
0
 def eval_to_list(self, tensor):
   """Evaluates `tensor` and converts the result to Python lists.

   Ragged results are converted with `to_list()` and NumPy arrays with
   `tolist()`.  Any other result is returned unchanged.
   """
   result = self.evaluate(tensor)
   if ragged_tensor.is_ragged(result):
     return result.to_list()
   if isinstance(result, np.ndarray):
     return result.tolist()
   return result
  def testFromTensorSlicesMixedRagged(self):
    """Slices a tuple mixing dense, sparse and ragged components."""
    dense_a = np.tile(np.array([[1], [2], [3]]), 20)
    dense_b = np.tile(np.array([[12], [13], [14]]), 22)
    dense_c = np.array([37.0, 38.0, 39.0])
    sparse_a = sparse_tensor.SparseTensorValue(
        indices=np.array([[0, 0], [1, 0], [2, 0]]),
        values=np.array([0, 0, 0]),
        dense_shape=np.array([3, 1]))
    sparse_b = sparse_tensor.SparseTensorValue(
        indices=np.array([[0, 0], [1, 1], [2, 2]]),
        values=np.array([1, 2, 3]),
        dense_shape=np.array([3, 3]))
    ragged = ragged_factory_ops.constant_value([[[0]], [[1]], [[2]]])
    components = (dense_a, dense_b, dense_c, sparse_a, sparse_b, ragged)

    dataset = dataset_ops.Dataset.from_tensor_slices(components)
    get_next = self.getNext(dataset)

    # Expected slices of the sparse and ragged components, per row:
    # - sparse_a: always a single 0 at index 0;
    # - sparse_b: value row + 1 at index row;
    # - ragged: [[row]].
    expected = [
        (sparse_tensor.SparseTensorValue(
            indices=np.array([[0]]),
            values=np.array([0]),
            dense_shape=np.array([1])),
         sparse_tensor.SparseTensorValue(
             indices=np.array([[row]]),
             values=np.array([row + 1]),
             dense_shape=np.array([3])),
         ragged_factory_ops.constant_value([[row]]))
        for row in range(3)
    ]
    dense_rows = list(zip(*components[:3]))
    for row in range(3):
      results = self.evaluate(get_next())
      expected_row = dense_rows[row] + expected[row]
      for component, result_component in zip(expected_row, results):
        if sparse_tensor.is_sparse(component):
          self.assertSparseValuesEqual(component, result_component)
        elif ragged_tensor.is_ragged(component):
          self.assertRaggedEqual(component, result_component)
        else:
          self.assertAllEqual(component, result_component)
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(get_next())
Example #19
0
def _increase_ragged_rank_to(rt_input, ragged_rank):
  """Adds ragged dimensions to `rt_input` so it has the desired ragged rank."""
  if ragged_rank > 0:
    if not ragged_tensor.is_ragged(rt_input):
      rt_input = ragged_conversion_ops.from_tensor(rt_input)
    if rt_input.ragged_rank < ragged_rank:
      rt_input = rt_input.with_values(
          _increase_ragged_rank_to(rt_input.values, ragged_rank - 1))
  return rt_input
def _increase_ragged_rank_to(rt_input, ragged_rank, row_splits_dtype):
  """Adds ragged dimensions to `rt_input` so it has the desired ragged rank."""
  if ragged_rank > 0:
    if not ragged_tensor.is_ragged(rt_input):
      rt_input = ragged_tensor.RaggedTensor.from_tensor(
          rt_input, row_splits_dtype=row_splits_dtype)
    if rt_input.ragged_rank < ragged_rank:
      rt_input = rt_input.with_values(
          _increase_ragged_rank_to(rt_input.values, ragged_rank - 1,
                                   row_splits_dtype))
  return rt_input
Example #21
0
def segment_mean(data, segment_ids, num_segments, name=None):
  """For docs, see: _RAGGED_SEGMENT_DOCSTRING.

  The mean is computed as `segment_sum(data) / segment_sum(ones)`, where
  `ones` has the same (ragged or dense) structure as `data`.
  """
  with ops.name_scope(name, 'RaggedSegmentMean',
                      [data, segment_ids, num_segments]):
    total = segment_sum(data, segment_ids, num_segments)
    if ragged_tensor.is_ragged(data):
      ones = ragged_tensor.RaggedTensor.from_nested_row_splits(
          array_ops.ones_like(data.flat_values), data.nested_row_splits)
    else:
      # Dense `data` has no `flat_values` / `nested_row_splits`.  Count with a
      # dense tensor of ones instead of failing with an AttributeError.
      # NOTE(review): relies on `segment_sum` accepting dense `data` — confirm.
      ones = array_ops.ones_like(data)
    count = segment_sum(ones, segment_ids, num_segments)
    if ragged_tensor.is_ragged(total):
      return total.with_flat_values(total.flat_values / count.flat_values)
    else:
      return total / count
Example #22
0
 def from_tensor(cls, rt_input):
   """Constructs a ragged shape for a potentially ragged tensor.

   A dense input yields a shape with no partitioned dimensions.  A ragged
   input yields partitioned dimension sizes `(nrows, *nested_row_lengths)`
   and inner dimension sizes `shape(flat_values)[1:]`.
   """
   with ops.name_scope(None, 'RaggedTensorDynamicShapeFromTensor', [rt_input]):
     rt_input = ragged_tensor.convert_to_tensor_or_ragged_tensor(rt_input)
     if ragged_tensor.is_ragged(rt_input):
       outer_sizes = (rt_input.nrows(),) + rt_input.nested_row_lengths()
       inner_sizes = array_ops.shape(rt_input.flat_values)[1:]
       return RaggedTensorDynamicShape(outer_sizes, inner_sizes)
     return cls([], array_ops.shape(rt_input))
def segment_mean(data, segment_ids, num_segments, name=None):
  """For docs, see: _RAGGED_SEGMENT_DOCSTRING.

  The mean is computed as `segment_sum(data) / segment_sum(ones)`, where
  `ones` has the same (ragged or dense) structure as `data`.
  """
  with ops.name_scope(name, 'RaggedSegmentMean',
                      [data, segment_ids, num_segments]):
    total = segment_sum(data, segment_ids, num_segments)
    if ragged_tensor.is_ragged(data):
      ones = ragged_tensor.RaggedTensor.from_nested_row_splits(
          array_ops.ones_like(data.flat_values), data.nested_row_splits,
          validate=False)
    else:
      # Dense `data` has no `flat_values` / `nested_row_splits`.  Count with a
      # dense tensor of ones instead of failing with an AttributeError.
      # NOTE(review): relies on `segment_sum` accepting dense `data` — confirm.
      ones = array_ops.ones_like(data)
    count = segment_sum(ones, segment_ids, num_segments)
    if ragged_tensor.is_ragged(total):
      return total.with_flat_values(total.flat_values / count.flat_values)
    else:
      return total / count
Example #24
0
 def from_tensor(cls, rt_input, dim_size_dtype=None):
   """Constructs a ragged shape for a potentially ragged tensor.

   Args:
     rt_input: A `Tensor` or `RaggedTensor`, or a value convertible to one.
     dim_size_dtype: Optional dtype for the dimension sizes.  It is forwarded
       to the constructor for both dense and ragged inputs.

   Returns:
     For a dense input, a shape with no partitioned dimensions.  For a ragged
     input, a shape whose partitioned dimension sizes are
     `(nrows, *nested_row_lengths)` and whose inner dimension sizes are
     `shape(flat_values)[1:]`.
   """
   with ops.name_scope(None, 'RaggedTensorDynamicShapeFromTensor', [rt_input]):
     rt_input = ragged_tensor.convert_to_tensor_or_ragged_tensor(rt_input)
     if not ragged_tensor.is_ragged(rt_input):
       # Forward `dim_size_dtype` here too; otherwise a dtype requested by the
       # caller would be silently ignored for dense inputs.
       return cls([], array_ops.shape(rt_input),
                  dim_size_dtype=dim_size_dtype)
     else:
       partitioned_dim_sizes = (
           (rt_input.nrows(),) + rt_input.nested_row_lengths())
       return RaggedTensorDynamicShape(
           partitioned_dim_sizes,
           array_ops.shape(rt_input.flat_values)[1:],
           dim_size_dtype=dim_size_dtype)
Example #25
0
 def handle(self, args, kwargs):
     """Dispatches a unary elementwise op (or an op over a list) on ragged input.

     Returns `self.NOT_SUPPORTED` in three cases: the operand is missing, no
     ragged value is involved, or (for list operands) a non-ragged element is
     not convertible to a tensor.  Otherwise `self._original_op` runs on the
     flat values and the result is re-wrapped with the ragged row partitions.
     """
     if not args:
         kwargs = kwargs.copy()
         x = kwargs.pop(self._x, None)
     else:
         x = args[0]
         args = args[1:]
     if x is None:
         return self.NOT_SUPPORTED

     if not self._arg_is_list:
         if not ragged_tensor.is_ragged(x):
             return self.NOT_SUPPORTED
         return x.with_flat_values(
             self._original_op(x.flat_values, *args, **kwargs))

     any_ragged = False
     for elt in x:
         if ragged_tensor.is_ragged(elt):
             any_ragged = True
         elif not _is_convertible_to_tensor(elt):
             return self.NOT_SUPPORTED
     if not any_ragged:
         return self.NOT_SUPPORTED

     splits_per_ragged = [
         elt.nested_row_splits for elt in x if ragged_tensor.is_ragged(elt)
     ]
     dense_inputs = [
         elt.flat_values if ragged_tensor.is_ragged(elt) else elt for elt in x
     ]
     # All ragged elements must share the same partitioning.
     with ops.control_dependencies(
             ragged_util.assert_splits_match(splits_per_ragged)):
         return ragged_tensor.RaggedTensor.from_nested_row_splits(
             self._original_op(dense_inputs, *args, **kwargs),
             splits_per_ragged[0])
Example #26
0
  def is_supported(self, args, kwargs):
    """Returns True if at least one of the ragged-capable arguments is ragged.

    Each entry of `self._ragged_args` is looked up by position (if enough
    positional args were given) or else by keyword name.  Returns False
    immediately in two cases: a list-valued argument is not a list/tuple, or
    a non-ragged value fails `_is_convertible_to_tensor`.
    """
    saw_ragged = False
    for info in self._ragged_args:
      if info.position < len(args):
        value = args[info.position]
      else:
        value = kwargs.get(info.name, None)

      if info.is_list and not isinstance(value, (list, tuple)):
        return False
      candidates = value if info.is_list else (value,)
      for candidate in candidates:
        if ragged_tensor.is_ragged(candidate):
          saw_ragged = True
        elif not _is_convertible_to_tensor(candidate):
          return False
    return saw_ragged
Example #27
0
def reduce_mean(input_tensor: ragged_tensor.Ragged,
                axis=None,
                keepdims=None,
                name=None):
  """For docs, see: _RAGGED_REDUCE_DOCSTRING.

  The mean is computed as `reduce_sum(input_tensor) / reduce_sum(ones)`,
  where `ones` has the same (ragged or dense) structure as `input_tensor`.
  """
  with ops.name_scope(name, 'RaggedReduceMean', [input_tensor, axis]):
    total = reduce_sum(input_tensor, axis, keepdims)
    input_is_ragged = ragged_tensor.is_ragged(input_tensor)
    ones = (
        ragged_tensor.RaggedTensor.from_nested_row_splits(
            array_ops.ones_like(input_tensor.flat_values),
            input_tensor.nested_row_splits,
            validate=False)
        if input_is_ragged else array_ops.ones_like(input_tensor))
    count = reduce_sum(ones, axis, keepdims)
    if not ragged_tensor.is_ragged(total):
      return total / count
    return ragged_tensor.RaggedTensor.from_nested_row_splits(
        total.flat_values / count.flat_values,
        total.nested_row_splits,
        validate=False)
Example #28
0
    def testUnaryElementwiseOp(self, x, op=math_ops.abs, **extra_args):
        """Checks that `op(x)` keeps x's shape and matches `op` on flat values."""
        if test_util.IsBuiltWithROCm():
            # TODO(rocm):
            # This fails on ROCm...see JIRA ticket 236756
            self.skipTest('Fails on ROCM')

        result = op(x, **extra_args)

        # Reference: apply the wrapped op directly to the dense values.
        reference_input = x.flat_values if ragged_tensor.is_ragged(x) else x
        expected_flat_values = array_ops.reshape(
            op(reference_input, **extra_args), [-1])

        # The result must have the same shape as the input ...
        self.assertSameShape(x, result)

        # ... and the same values once flattened.
        result_dense = (result.flat_values
                        if ragged_tensor.is_ragged(result) else result)
        result_flat_values = array_ops.reshape(result_dense, [-1])
        self.assertAllEqual(expected_flat_values, result_flat_values)
Example #29
0
    def is_supported(self, args, kwargs):
        """Returns True if at least one ragged-capable argument is ragged.

        Arguments are found by position when enough positional args were
        passed, otherwise by keyword name.  Returns False as soon as a
        list-valued argument is not a list/tuple, or when any non-ragged
        value fails `_is_convertible_to_tensor`.
        """
        saw_ragged = False
        for info in self._ragged_args:
            in_positional = info.position < len(args)
            value = (args[info.position] if in_positional
                     else kwargs.get(info.name, None))

            if info.is_list and not isinstance(value, (list, tuple)):
                return False
            for candidate in (value if info.is_list else (value,)):
                if ragged_tensor.is_ragged(candidate):
                    saw_ragged = True
                elif not _is_convertible_to_tensor(candidate):
                    return False
        return saw_ragged
Example #30
0
    def from_tensor(cls, rt_input, dim_size_dtype=None):
        """Constructs a ragged shape for a potentially ragged tensor.

        Args:
          rt_input: A `Tensor` or `RaggedTensor`, or a value convertible to
            one.
          dim_size_dtype: Optional dtype for the dimension sizes.  It is
            forwarded to the constructor for both dense and ragged inputs.

        Returns:
          For a dense input, a shape with no partitioned dimensions.  For a
          ragged input, a shape with these sizes:

          * partitioned dimensions: `nrows` followed by, for each ragged
            level, either its `uniform_row_length` or its `row_lengths()`;
          * inner dimensions: `shape(flat_values)[1:]`.
        """
        with ops.name_scope(None, 'RaggedTensorDynamicShapeFromTensor',
                            [rt_input]):
            rt_input = ragged_tensor.convert_to_tensor_or_ragged_tensor(
                rt_input)
            if not ragged_tensor.is_ragged(rt_input):
                # Forward `dim_size_dtype` here too; otherwise a dtype
                # requested by the caller would be silently ignored for
                # dense inputs.
                return cls([], array_ops.shape(rt_input),
                           dim_size_dtype=dim_size_dtype)
            else:
                partitioned_dim_sizes = [rt_input.nrows()]
                rt = rt_input
                while ragged_tensor.is_ragged(rt):
                    # Levels with a uniform row length record that value;
                    # other levels record their per-row lengths.
                    if rt.uniform_row_length is None:
                        partitioned_dim_sizes.append(rt.row_lengths())
                    else:
                        partitioned_dim_sizes.append(rt.uniform_row_length)
                    rt = rt.values

                return RaggedTensorDynamicShape(tuple(partitioned_dim_sizes),
                                                array_ops.shape(
                                                    rt_input.flat_values)[1:],
                                                dim_size_dtype=dim_size_dtype)
Example #31
0
def _convert_to_structured_field_value(value):
    """Converts `value` to a Tensor, RaggedTensor, or StructuredTensor.

    Values that already are a `Tensor`, `RaggedTensor`, or `StructuredTensor`
    are returned unchanged.  Other ragged values (as detected by
    `ragged_tensor.is_ragged`) go through
    `convert_to_tensor_or_ragged_tensor`.  Anything else goes through
    `ops.convert_to_tensor`.

    Args:
      value: The field value to convert.

    Returns:
      A `Tensor`, `RaggedTensor`, or `StructuredTensor`.

    Raises:
      TypeError: If `value` cannot be converted to a tensor.
    """
    if isinstance(value,
                  (ops.Tensor, ragged_tensor.RaggedTensor, StructuredTensor)):
        return value
    if ragged_tensor.is_ragged(value):
        return ragged_tensor.convert_to_tensor_or_ragged_tensor(value)
    try:
        return ops.convert_to_tensor(value)
    except (ValueError, TypeError) as e:
        # Wrap `value` in a 1-tuple.  With a bare `% value`, a tuple `value`
        # is unpacked as format arguments, so the formatting itself fails and
        # the intended message is lost.
        raise TypeError('Unexpected type for value in `fields`: %r' %
                        (value,)) from e
Example #32
0
    def _preprocess(self, inputs):
        """Standardizes, splits, and optionally n-grams `inputs`.

        Steps, each controlled by layer configuration:
          1. Standardization.  `LOWER_AND_STRIP_PUNCTUATION` lowercases the
             strings and removes `DEFAULT_STRIP_REGEX` matches.  A callable is
             applied directly.  `None` skips the step.
          2. Splitting.  Axis 1 is squeezed out, so it must have size 1.  The
             strings are then split on whitespace (`SPLIT_ON_WHITESPACE`) or
             by a user callable.
          3. N-grams.  If `self._ngrams` is set, `ragged_string_ops.ngrams` is
             applied with a single-space separator.

        Args:
          inputs: A string `Tensor` or `RaggedTensor`.

        Returns:
          The processed tensor.  It is dense or ragged depending on the
          configuration.

        Raises:
          ValueError: If `self._standardize` or `self._split` is not a
            supported option.
        """
        if self._standardize == LOWER_AND_STRIP_PUNCTUATION:
            if ragged_tensor.is_ragged(inputs):
                lowercase_inputs = ragged_functional_ops.map_flat_values(
                    gen_string_ops.string_lower, inputs)
                # Depending on configuration, we may never touch the non-data
                # tensor in the ragged inputs tensor. If that is the case, and
                # this is the only layer in the keras model, running it will
                # throw an error. To get around this, we wrap the result in an
                # identity.
                lowercase_inputs = array_ops.identity(lowercase_inputs)
            else:
                lowercase_inputs = gen_string_ops.string_lower(inputs)
            inputs = string_ops.regex_replace(lowercase_inputs,
                                              DEFAULT_STRIP_REGEX, "")
        elif callable(self._standardize):
            inputs = self._standardize(inputs)
        elif self._standardize is not None:
            # The 1-tuple keeps a tuple-valued option from breaking `%`.
            raise ValueError(
                ("%s is not a supported standardization. "
                 "TextVectorization supports the following options "
                 "for `standardize`: None, "
                 "'lower_and_strip_punctuation', or a "
                 "Callable.") % (self._standardize,))

        if self._split is not None:
            # If we are splitting, we validate that the 1st axis is of
            # dimension 1 and so can be squeezed out. We do this here instead
            # of after splitting for performance reasons - it's more expensive
            # to squeeze a ragged tensor.
            inputs = array_ops.squeeze(inputs, axis=1)
            if self._split == SPLIT_ON_WHITESPACE:
                # This treats multiple whitespaces as one whitespace, and
                # strips leading and trailing whitespace.
                inputs = ragged_string_ops.string_split_v2(inputs)
            elif callable(self._split):
                inputs = self._split(inputs)
            else:
                # The missing space after "splitting." previously glued the
                # two sentences together.
                raise ValueError(
                    ("%s is not a supported splitting. "
                     "TextVectorization supports the following options "
                     "for `split`: None, 'whitespace', or a Callable.") %
                    (self._split,))

        # Note that 'inputs' here can be either ragged or dense depending on
        # the configuration choices for this Layer. The strings.ngrams op,
        # however, does support both ragged and dense inputs.
        if self._ngrams is not None:
            inputs = ragged_string_ops.ngrams(inputs,
                                              ngram_width=self._ngrams,
                                              separator=" ")

        return inputs
Example #33
0
 def handle(self, args, kwargs):
   """Dispatches `self._original_op` over ragged operands.

   The primary operand is `args[0]` when positional args are given.
   Otherwise it is popped from a copy of `kwargs` under the key `self._x`.
   The remaining args/kwargs are passed through to the original op.

   Returns:
     A ragged result sharing the operand's row partitioning.  Returns
     `self.NOT_SUPPORTED` when the operand is missing, when no ragged operand
     is present, or when a list element cannot be converted to a tensor.
   """
   if args:
     x = args[0]
     args = args[1:]
   else:
     kwargs = kwargs.copy()
     x = kwargs.pop(self._x, None)
   if x is None:
     return self.NOT_SUPPORTED

   if not self._arg_is_list:
     # Single operand: map the op over its flat values.
     if not ragged_tensor.is_ragged(x):
       return self.NOT_SUPPORTED
     mapped = self._original_op(x.flat_values, *args, **kwargs)
     return x.with_flat_values(mapped)

   # List operand: every element must be ragged or tensor-convertible, and
   # at least one must be ragged.
   saw_ragged = False
   for item in x:
     if ragged_tensor.is_ragged(item):
       saw_ragged = True
     elif not _is_convertible_to_tensor(item):
       return self.NOT_SUPPORTED
   if not saw_ragged:
     return self.NOT_SUPPORTED

   x = ragged_tensor.match_row_splits_dtypes(*x)
   splits_per_item = [
       item.nested_row_splits for item in x if ragged_tensor.is_ragged(item)
   ]
   flat_items = [
       item.flat_values if ragged_tensor.is_ragged(item) else item
       for item in x
   ]
   # All ragged elements must share row splits; assert that at runtime.
   with ops.control_dependencies(
       ragged_util.assert_splits_match(splits_per_item)):
     return ragged_tensor.RaggedTensor.from_nested_row_splits(
         self._original_op(flat_items, *args, **kwargs),
         splits_per_item[0], validate=False)
Example #34
0
def string_format(template, inputs, placeholder="{}", summarize=3, name=None):
  """Version of tf.strings.format that handles RaggedTensors.

  Args:
    template: A Python string containing `placeholder` once per input.
    inputs: A single tensor / RaggedTensor, or a list of them.
    placeholder: The substring in `template` to replace with each input.
    summarize: Passed through to the per-input formatting calls.
    name: Optional name scope for the created ops.

  Returns:
    A scalar string tensor.  When `template` has no placeholders, this is
    the constant template piece.  Otherwise it is the `reduce_join` of the
    template pieces interleaved with each formatted input.

  Raises:
    ValueError: If the number of placeholders differs from the number of
      inputs.
  """
  if tensor_util.is_tf_type(inputs) or ragged_tensor.is_ragged(inputs):
    inputs = [inputs]

  template_pieces = template.split(placeholder)
  num_placeholders = len(template_pieces) - 1
  if len(inputs) != num_placeholders:
    raise ValueError("num placeholders in template and num inputs must match"
                     ": {} vs {}".format(num_placeholders, len(inputs)))

  with ops.name_scope(name, "StringFormat", [inputs]):
    output_pieces = [constant_op.constant(template_pieces[0])]
    # Lengths were checked above, so zip pairs every input with the template
    # text that follows its placeholder.
    for value, trailing_text in zip(inputs, template_pieces[1:]):
      if ragged_tensor.is_ragged(value):
        formatted = ragged_tensor_to_string(value, summarize)
      else:
        formatted = string_ops.string_format(
            "{}", [value], summarize=summarize)
      output_pieces.append(formatted)
      output_pieces.append(constant_op.constant(trailing_text))
    if len(output_pieces) == 1:
      return output_pieces[0]
    return string_ops.reduce_join(output_pieces)
Example #35
0
 def call(self, inputs):
     """Applies `self._round_and_truncate` and casts the result to int64.

     For a `tf.SparseTensor`, only the values are transformed and the
     indices and dense shape are kept.  For a ragged input, the transform is
     mapped over the flat values.  Any other input is transformed directly.
     """
     if isinstance(inputs, tf.SparseTensor):
         rounded_values = self._round_and_truncate(inputs.values)
         transformed = tf.SparseTensor(
             indices=inputs.indices,
             values=rounded_values,
             dense_shape=inputs.dense_shape,
         )
     elif ragged_tensor.is_ragged(inputs):
         transformed = ragged_functional_ops.map_flat_values(
             self._round_and_truncate, inputs)
     else:
         transformed = self._round_and_truncate(inputs)
     return tf.cast(transformed, tf.int64)
def _structured_tensor_like(t):
    """Create a StructuredTensor with the shape of a (composite) tensor.

    Args:
      t: A dense `ops.Tensor`, a ragged tensor, or (for anything else) an
        object with `shape`, `row_partitions`, and `nrows()`, i.e. a
        `StructuredTensor`.

    Returns:
      A field-less `StructuredTensor`.  Dense tensors are delegated to
      `_structured_tensor_from_dense_tensor`.  Ragged inputs contribute their
      static shape and all nested row partitions.  Structured inputs
      contribute their shape, row partitions, and `nrows()`.
    """
    if isinstance(t, ops.Tensor):
        return _structured_tensor_from_dense_tensor(t)
    if not ragged_tensor.is_ragged(t):
        # Neither dense nor ragged: treated as a StructuredTensor.
        return StructuredTensor.from_fields({},
                                            shape=t.shape,
                                            row_partitions=t.row_partitions,
                                            nrows=t.nrows())
    return StructuredTensor.from_fields(
        {},
        shape=t.get_shape(),
        row_partitions=_all_nested_row_partitions(t))
Example #37
0
def _RaggedSubstr(text_input, begin, end):
    """Extracts substrings `[begin, end)` from `text_input` via `substr_v2`.

    Args:
      text_input: A string `Tensor` or `RaggedTensor`.  A ragged input is
        reduced to its `flat_values`.
      begin: Start positions.  Either a `RaggedTensor`, in which case `end`
        is read through `end.flat_values` and so must be ragged too, or a
        rank-1 `Tensor`.
      end: End positions, with the same structure as `begin`.

    Returns:
      If `begin` is ragged, a `RaggedTensor` with `begin`'s row partitioning
      holding one substring per (begin, end) pair.  Each pair is taken from
      the text element selected by `begin`'s innermost value row id.
      Otherwise, the `substr_v2` result of a scalar text with rank-1
      positions.

    Raises:
      ValueError: If `begin` is dense and is not rank 1, or if `text_input`
        is not a scalar in that case.
    """
    if ragged_tensor.is_ragged(text_input):
        text_input_flat = text_input.flat_values
    else:
        text_input_flat = ops.convert_to_tensor(text_input)

    if ragged_tensor.is_ragged(begin):
        # Repeat each text element once per substring taken from it.
        # NOTE(review): assumes the rows of `begin`'s innermost ragged level
        # line up with the elements of `text_input_flat`.
        broadcasted_text = array_ops.gather_v2(text_input_flat,
                                               begin.nested_value_rowids()[-1])

        # Convert broadcasted_text into a 1D tensor.
        broadcasted_text = array_ops.reshape(broadcasted_text, [-1])
        size = math_ops.sub(end.flat_values, begin.flat_values)
        new_tokens = string_ops.substr_v2(broadcasted_text, begin.flat_values,
                                          size)
        return begin.with_flat_values(new_tokens)

    # Explicit checks rather than `assert`, which is stripped under `-O`.
    if begin.shape.ndims != 1:
        raise ValueError('Dense `begin` must have rank 1, got rank %s.' %
                         (begin.shape.ndims,))
    if text_input_flat.shape.ndims != 0:
        raise ValueError('With dense `begin`, `text_input` must be a scalar, '
                         'got rank %s.' % (text_input_flat.shape.ndims,))
    size = math_ops.sub(end, begin)
    return string_ops.substr_v2(text_input_flat, begin, size)
Example #38
0
def ragged_binary_elementwise_op(op, x, y):
    """Binary elementwise api handler for RaggedTensors.

    Both operands are converted to tensors, each preferring the other's
    dtype.  They are broadcast against each other in two cases: both are
    ragged, or one is ragged and its flat values have rank no larger than
    the other operand's rank.  `op` is then applied to the flat values, and
    the result is re-wrapped with the row partitioning of `x` if `x` is
    ragged, otherwise `y`.

    Args:
      op: A binary function on dense tensors.
      x: First operand (Tensor, RaggedTensor, or convertible value).
      y: Second operand (Tensor, RaggedTensor, or convertible value).

    Returns:
      The re-wrapped ragged result, or the Python `bool` returned by `op`
      when it returns one.
    """
    x_was_ragged = ragged_tensor.is_ragged(x)
    y_was_ragged = ragged_tensor.is_ragged(y)

    # `x` only prefers `y`'s dtype when `y` is already ragged (and typed).
    x_preferred_dtype = y.dtype if y_was_ragged else None
    x = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        x, preferred_dtype=x_preferred_dtype)
    y = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        y, preferred_dtype=x.dtype)

    both_ragged = x_was_ragged and y_was_ragged
    if both_ragged:
        x, y = ragged_tensor.match_row_splits_dtypes(x, y)

    # Perform broadcasting, when appropriate.
    should_broadcast = (
        both_ragged
        or (x_was_ragged and x.flat_values.shape.ndims <= y.shape.ndims)
        or (y_was_ragged and y.flat_values.shape.ndims <= x.shape.ndims))
    if should_broadcast:
        rts = ragged_tensor_shape
        shape_of_x = rts.RaggedTensorDynamicShape.from_tensor(x)
        shape_of_y = rts.RaggedTensorDynamicShape.from_tensor(y)
        bcast_shape = rts.broadcast_dynamic_shape(shape_of_x, shape_of_y)
        x = rts.broadcast_to(x, bcast_shape, broadcast_inner_dimensions=False)
        y = rts.broadcast_to(y, bcast_shape, broadcast_inner_dimensions=False)

    x_values = x.flat_values if ragged_tensor.is_ragged(x) else x
    y_values = y.flat_values if ragged_tensor.is_ragged(y) else y
    mapped_values = op(x_values, y_values)
    if isinstance(mapped_values, bool):
        return mapped_values  # Special case for tensor_equals.
    template = x if ragged_tensor.is_ragged(x) else y
    return template.with_flat_values(mapped_values)
Example #39
0
  def __init__(self, shape, fields):
    """Creates a `StructuredTensor` from a dictionary of fields.

    Args:
      shape: A `TensorShape`: static information about the shape of the
        `StructuredTensor`.  Must have a known `rank`.
      fields: A dictionary mapping from string to `Tensor`, `RaggedTensor`, or
        `StructuredTensor`, providing the values for individual fields in each
        structure.  If `ndims > 0`, then every tensor in `fields` must have the
        same shape in the first `shape.rank` dimensions; and that shape must be
        compatible with `shape`.

    Raises:
      ValueError: If `shape` has unknown rank, a field name does not match
        `_FIELD_NAME_RE`, or a field's leading dimensions are incompatible
        with `shape` (raised by `TensorShape.merge_with`).
      TypeError: If `fields` is not a dict, a key is not a `str`, or a value
        cannot be converted to a tensor.
    """
    # Normalize and validate the shape once.  It was previously repeated
    # after the field loop, where the second rank check was unreachable.
    shape = tensor_shape.as_shape(shape)
    rank = shape.ndims
    if rank is None:
      raise ValueError("StructuredTensor's shape must have known rank.")
    if not isinstance(fields, dict):
      raise TypeError('fields must be a dictionary, got %s' %
                      type(fields).__name__)
    self._fields = {}
    with ops.name_scope(None, 'StructuredTensor', fields.values()):
      for (key, value) in fields.items():
        # Keys/values are wrapped in 1-tuples for `%` so tuple-valued inputs
        # produce the intended message instead of a formatting error.
        if not isinstance(key, str):
          raise TypeError('Unexpected type for key in `fields`: %r' % (key,))
        if not _FIELD_NAME_RE.match(key):
          raise ValueError('Field name %r is not currently allowed.' % key)
        if not isinstance(
            value, (ops.Tensor, ragged_tensor.RaggedTensor, StructuredTensor)):
          if ragged_tensor.is_ragged(value):
            value = ragged_tensor.convert_to_tensor_or_ragged_tensor(value)
          else:
            try:
              value = ops.convert_to_tensor(value)
            except (ValueError, TypeError) as e:
              raise TypeError('Unexpected type for value in `fields`: %r' %
                              (value,)) from e
        self._fields[key] = value

    # Check the static TensorShape for this StructuredTensor: every field's
    # leading `rank` dimensions must be compatible with `shape`.
    self._static_shape = shape
    if rank > 0:
      for value in self._fields.values():
        self._static_shape = self._static_shape.merge_with(value.shape[:rank])
  def testListValuedElementwiseOp(self, inputs, op=math_ops.add_n,
                                  **extra_args):
    """Checks a list-valued op on (possibly ragged) inputs against dense op.

    The result must have the same shape as `inputs[0]`.  Its flattened
    values must equal `op` applied to the inputs' flat values (flattened).
    Pass `use_kwargs=True` to call `op(inputs=...)` instead of positionally.
    """
    if extra_args.pop('use_kwargs', False):
      result = op(inputs=inputs, **extra_args)
    else:
      result = op(inputs, **extra_args)

    def _dense_part(t):
      # Ragged tensors contribute their flat values; dense ones themselves.
      return t.flat_values if ragged_tensor.is_ragged(t) else t

    # Reference: run the wrapped op directly on the dense values.
    dense_inputs = [_dense_part(t) for t in inputs]
    expected_flat_values = array_ops.reshape(
        op(dense_inputs, **extra_args), [-1])

    self.assertSameShape(inputs[0], result)

    result_flat_values = array_ops.reshape(_dense_part(result), [-1])
    self.assertAllEqual(expected_flat_values, result_flat_values)
  def testBinaryElementwiseOp(self, x, y, op=math_ops.add, **extra_args):
    """Checks a binary op on (possibly ragged) operands against the dense op.

    `use_kwargs` may name which of `x`/`y` to pass by keyword.  `y` alone
    gives `op(x, y=y)`; both give `op(x=x, y=y)`; otherwise both are passed
    positionally.  The result must have `y`'s shape, and its flattened
    values must match `op` on the flat values.
    """
    keyword_args = extra_args.pop('use_kwargs', ())
    if 'y' not in keyword_args:
      result = op(x, y, **extra_args)
    elif 'x' in keyword_args:
      result = op(x=x, y=y, **extra_args)
    else:
      result = op(x, y=y, **extra_args)

    def _dense_part(t):
      # Ragged tensors contribute their flat values; dense ones themselves.
      return t.flat_values if ragged_tensor.is_ragged(t) else t

    # Reference: run the wrapped op directly on the dense values.
    expected_flat_values = array_ops.reshape(
        op(_dense_part(x), _dense_part(y), **extra_args), [-1])

    self.assertSameShape(y, result)

    result_flat_values = array_ops.reshape(_dense_part(result), [-1])
    self.assertAllEqual(expected_flat_values, result_flat_values)
    def tokenize(self, input, name=None):  # pylint: disable=redefined-builtin
        """Tokenizes a tensor of UTF-8 strings.

        Args:
          input: A `RaggedTensor` or `Tensor` of UTF-8 strings with any shape.
          name: The name argument that is passed to the op function.

        Returns:
          A `RaggedTensor` of tokenized text. The returned shape is the shape
          of the input tensor with an added ragged dimension for tokens of
          each string.

        Raises:
          ValueError: If the rank of `input` is not statically known.
        """
        with ops.name_scope(name, "SentenceTokenizer", [input, self]):
            input_tensor = ragged_tensor.convert_to_tensor_or_ragged_tensor(
                input)
            rank = input_tensor.shape.ndims
            if rank is None:
                raise ValueError(
                    "Rank of input_tensor must be statically known.")

            if ragged_tensor.is_ragged(input_tensor):
                # Tokenize the flat values, then restore the ragged structure.
                flat_tokens = self.tokenize(input_tensor.flat_values)
                return input_tensor.with_flat_values(flat_tokens)

            if rank == 0:
                # Tokenize as a batch of one and drop the batch dimension.
                batched = self.tokenize(array_ops.stack([input_tensor]))
                return batched.values

            if rank > 1:
                # Convert the input tensor to ragged and process it.
                return self.tokenize(
                    ragged_conversion_ops.from_tensor(input_tensor))

            # Rank 1: the shape the op expects, so call it directly.
            output_values, row_splits = (
                gen_sentencepiece_tokenizer.sentencepiece_tokenize_op(
                    self._model_resource.resource_handle,
                    input_tensor,
                    self.nbest_size,
                    self.alpha,
                    self.add_bos,
                    self.add_eos,
                    self.reverse,
                    self.out_type,
                    return_nbest=self.return_nbest))
            return RaggedTensor.from_nested_row_splits(
                flat_values=output_values,
                nested_row_splits=[row_splits],
                validate=False)
Example #43
0
def from_tensor(tensor,
                lengths=None,
                padding=None,
                ragged_rank=1,
                row_splits_dtype=dtypes.int64,
                name=None):
  """Converts `tensor` to a `RaggedTensor`, passing ragged inputs through.

  If `tensor` is already ragged, it is returned unchanged and every other
  argument is ignored.  Otherwise all arguments are forwarded to
  `RaggedTensor.from_tensor`.

  Args:
    tensor: The tensor to convert.
    lengths: Forwarded to `RaggedTensor.from_tensor`.
    padding: Forwarded to `RaggedTensor.from_tensor`.
    ragged_rank: Forwarded to `RaggedTensor.from_tensor`.
    row_splits_dtype: Forwarded to `RaggedTensor.from_tensor`.
    name: Forwarded to `RaggedTensor.from_tensor`.

  Returns:
    A ragged tensor.
  """
  if not ragged_tensor.is_ragged(tensor):
    return ragged_tensor.RaggedTensor.from_tensor(
        tensor,
        lengths=lengths,
        padding=padding,
        ragged_rank=ragged_rank,
        row_splits_dtype=row_splits_dtype,
        name=name)
  return tensor
Example #44
0
  def lookup(self, inputs):
    """Perform a table lookup.

    Sparse inputs (`SparseTensor` / `SparseTensorValue`) are routed to
    `self._sparse_lookup` before any conversion is attempted, because they
    do not go through tensor conversion cleanly.  Everything else is
    converted with `convert_to_tensor_or_ragged_tensor`.  The result then
    goes to `self._ragged_lookup` or `self._tensor_lookup`.
    """
    sparse_types = (sparse_tensor.SparseTensor,
                    sparse_tensor.SparseTensorValue)
    if isinstance(inputs, sparse_types):
      return self._sparse_lookup(inputs)

    converted = ragged_tensor.convert_to_tensor_or_ragged_tensor(inputs)
    if ragged_tensor.is_ragged(converted):
      lookup_fn = self._ragged_lookup
    else:
      lookup_fn = self._tensor_lookup
    return lookup_fn(converted)
Example #45
0
 def call(self, inputs):
   """Hashes `inputs` into `self.num_bins` buckets.

   Ragged inputs are hashed through `map_flat_values`.  A `SparseTensor` has
   its values hashed while the indices and dense shape are kept.  Other
   inputs are hashed directly.  All hashing calls use the op name 'hash'.
   """
   # TODO(tanzheny): Add int support.
   hash_fn = self._get_string_to_hash_bucket_fn()
   if ragged_tensor.is_ragged(inputs):
     return ragged_functional_ops.map_flat_values(
         hash_fn, inputs, num_buckets=self.num_bins, name='hash')
   if not isinstance(inputs, sparse_tensor.SparseTensor):
     return hash_fn(inputs, self.num_bins, name='hash')

   hashed_values = hash_fn(inputs.values, self.num_bins, name='hash')
   return sparse_tensor.SparseTensor(
       indices=inputs.indices,
       values=hashed_values,
       dense_shape=inputs.dense_shape)
Example #46
0
    def call(self, inputs):
        """Preprocesses, indexes, and formats `inputs`.

        Args:
          inputs: A tensor, or a list, tuple, or numpy array that is converted
            to one.  A rank-1 input gets a trailing axis of size 1.

        Returns:
          - With no output mode: the preprocessed inputs.
          - In `INT` mode: dense indices (ragged results padded with 0).  The
            static shape is `(None, None)`, or `(None, output_sequence_length)`
            when a sequence length is configured; each row is then
            zero-padded or truncated to that length.
          - Otherwise: the output of `self._vectorize_layer`.
        """
        if isinstance(inputs, (list, tuple, np.ndarray)):
            inputs = ops.convert_to_tensor(inputs)
        if inputs.shape.rank == 1:
            inputs = array_ops.expand_dims(inputs, axis=-1)

        self._called = True
        inputs = self._preprocess(inputs)

        # No output processing requested: return right away.
        if self._output_mode is None:
            return inputs

        indexed_data = self._index_lookup_layer(inputs)

        if self._output_mode != INT:
            # Non-integer outputs are produced by the vectorization layer.
            return self._vectorize_layer(indexed_data)

        # Integer output is dense, so densify ragged results first.
        if ragged_tensor.is_ragged(indexed_data):
            dense_data = indexed_data.to_tensor(default_value=0)
        else:
            dense_data = indexed_data

        target_len = self._output_sequence_length
        if target_len is None:
            dense_data.set_shape(tensor_shape.TensorShape((None, None)))
            return dense_data

        # The actual sequence length is only known at runtime, so choose
        # between padding and trimming dynamically.
        sequence_len = K.shape(dense_data)[1]
        pad_amt = target_len - sequence_len

        def pad_fn():
            return array_ops.pad(dense_data, [[0, 0], [0, pad_amt]])

        def slice_fn():
            return dense_data[:, :target_len]

        output_tensor = control_flow_ops.cond(
            sequence_len < target_len,
            true_fn=pad_fn,
            false_fn=slice_fn)
        output_tensor.set_shape(tensor_shape.TensorShape((None, target_len)))
        return output_tensor
Example #47
0
    def tokenize_with_offsets(self, input_strs):
        """Tokenizes a tensor of UTF-8 strings into words with [start,end) offsets.

        Args:
          input_strs: An N-dimensional `Tensor` or `RaggedTensor` of UTF-8
            strings.

        Returns:
          A tuple `(tokens, start_offsets, limit_offsets)` where:
            * `tokens` is a `RaggedTensor` of strings where
              `tokens[i1...iN, j]` is the string content of the `j-th` token
              in `input_strs[i1...iN]`
            * `start_offsets` is a `RaggedTensor` of int64s where
              `start_offsets[i1...iN, j]` is the byte offset for the start of
              the `j-th` token in `input_strs[i1...iN]`.
            * `limit_offsets` is a `RaggedTensor` of int64s where
              `limit_offsets[i1...iN, j]` is the byte offset immediately after
              the end of the `j-th` token in `input_strs[i1...iN]`.

        Raises:
          ValueError: If the rank of `input_strs` is not statically known.
        """
        input_strs = ragged_tensor.convert_to_tensor_or_ragged_tensor(
            input_strs)
        rank = input_strs.shape.ndims
        if rank is None:
            raise ValueError('input must have a known rank.')

        # The hub_module only accepts rank-1 input and produces rank-2
        # tokens/starts/ends.  Other ranks are reduced to rank 1 first and
        # the outputs are mapped back to the expected shape afterwards.
        if rank == 1:
            return self._predict_tokens(input_strs)

        if rank == 0:
            # Build a rank-1 batch holding the single string, then drop the
            # batch dimension from each output.
            single_batch = array_ops.stack([input_strs])
            tokens, starts, limits = self._predict_tokens(single_batch)
            return tokens.flat_values, starts.flat_values, limits.flat_values

        if not ragged_tensor.is_ragged(input_strs):
            input_strs = ragged_tensor.RaggedTensor.from_tensor(
                input_strs, ragged_rank=rank - 1)

        # [number strings, (number codepoints)]
        tokens, starts, limits = self._predict_tokens(input_strs.flat_values)
        return (input_strs.with_flat_values(tokens),
                input_strs.with_flat_values(starts),
                input_strs.with_flat_values(limits))
Example #48
0
 def call(self, inputs):
     """Bucketizes `inputs` against `self.bins` with `gen_math_ops.Bucketize`.

     Ragged and `SparseTensor` inputs are bucketized value-wise and keep
     their structure.  Their pass-through component tensors are routed
     through `identity` (see the inline comment).  Other inputs are
     bucketized directly.
     """
     if ragged_tensor.is_ragged(inputs):
         bucketed = ragged_functional_ops.map_flat_values(
             gen_math_ops.Bucketize, input=inputs, boundaries=self.bins)
         # Ragged map_flat_values doesn't touch the non-values tensors in the
         # ragged composite tensor. If this op is the only op a Keras model,
         # this can cause errors in Graph mode, so wrap the tensor in an
         # identity.
         return array_ops.identity(bucketed)
     if not isinstance(inputs, sparse_tensor.SparseTensor):
         return gen_math_ops.Bucketize(input=inputs, boundaries=self.bins)

     bucketed_values = gen_math_ops.Bucketize(input=inputs.values,
                                              boundaries=self.bins)
     return sparse_tensor.SparseTensor(
         indices=array_ops.identity(inputs.indices),
         values=bucketed_values,
         dense_shape=array_ops.identity(inputs.dense_shape))
Example #49
0
    def compute(self, values, accumulator=None):
        """Compute a step in this computation and return the accumulator.

        `values` is normalized to nested Python lists.  Ragged tensors go
        through `to_list()`, eager tensors through `numpy()`, and numpy arrays
        through `tolist()`.  Each element of each inner sequence then
        increments `accumulator.count_dict[token]`.

        Args:
          values: A ragged tensor, eager tensor, numpy array, or iterable of
            iterables of tokens.
          accumulator: Accumulator to update in place.  When `None`, a fresh
            one is created with `self._create_accumulator()`.

        Returns:
          The updated accumulator (the same object if one was passed in).
        """
        # Normalize every supported input kind down to nested Python lists.
        if ragged_tensor.is_ragged(values):
            values = values.to_list()
        if isinstance(values, ops.EagerTensor):
            values = values.numpy()
        if isinstance(values, np.ndarray):
            values = values.tolist()

        if accumulator is None:
            accumulator = self._create_accumulator()

        # TODO(momernick): Benchmark improvements to this algorithm.
        for doc in values:
            for tok in doc:
                accumulator.count_dict[tok] += 1

        return accumulator
  def call(self, inputs):
    """Bucketizes `inputs` by the boundaries in `self.bins`.

    Returns integer bucket ids when `self.output_mode == INTEGER`.  Otherwise
    returns their one-hot encoding with `len(self.bins) + 1` classes.
    """
    if not ragged_tensor.is_ragged(inputs):
      bucket_ids = math_ops._bucketize(inputs, boundaries=self.bins)  # pylint: disable=protected-access
    else:
      bucket_ids = ragged_functional_ops.map_flat_values(
          math_ops._bucketize, inputs, boundaries=self.bins)  # pylint: disable=protected-access
      # Ragged map_flat_values doesn't touch the non-values tensors in the
      # ragged composite tensor. If this op is the only op a Keras model,
      # this can cause errors in Graph mode, so wrap the tensor in an identity.
      bucket_ids = array_ops.identity(bucket_ids)

    if self.output_mode != INTEGER:
      # `bins` holds the boundaries between buckets, so there are
      # len(bins) + 1 possible outputs.
      # TODO(momernick): This will change when we have the ability to adapt().
      return array_ops.one_hot(bucket_ids, depth=len(self.bins) + 1)
    return bucket_ids
Example #51
0
  def call(self, inputs, invert=False):
    """Looks up `inputs` in `self._table`, or `self._inverse_table` if `invert`.

    Ragged inputs are looked up through `map_flat_values`.  Sparse inputs
    have their values looked up and keep their indices and dense shape.
    The result always passes through `array_ops.identity`.
    """
    lookup_table = self._inverse_table if invert else self._table
    sparse_types = (sparse_tensor.SparseTensor,
                    sparse_tensor.SparseTensorValue)
    if ragged_tensor.is_ragged(inputs):
      # Table lookup ops don't natively support ragged tensors, so look up
      # every flat value instead.
      looked_up = ragged_functional_ops.map_flat_values(lookup_table.lookup,
                                                        inputs)
    elif isinstance(inputs, sparse_types):
      looked_up = sparse_tensor.SparseTensor(inputs.indices,
                                             lookup_table.lookup(inputs.values),
                                             inputs.dense_shape)
    else:
      looked_up = lookup_table.lookup(inputs)

    # Composite tensors can pass component tensors through untouched, which
    # errors if this is the only layer in the model; identity avoids that.
    return array_ops.identity(looked_up)
Example #52
0
def cont_bow(source, window, seed=None, name=None):
    """Generates `Continuous bag-of-words` target and context pairs from batched list of tokens.

    Args:
        source: `2-D` string `Tensor` or `RaggedTensor`, batched lists of tokens [sentences, tokens].
        window: `int`, size of context before and after target token, must be > 0.
        seed: `int`, used to create a random seed (optional).
            See @{tf.random.set_seed} for behavior.
        name: `string`, a name for the operation (optional).

    Returns:
        `1-D` string `Tensor`: target tokens.
        `2-D` string `RaggedTensor`: context tokens.
        `2-D` int32 `RaggedTensor`: context positions.

    Raises:
        ValueError: If `source` is not rank 2, or its ragged rank is not 1.
    """
    with tf.name_scope(name or 'cont_bow'):
        source = ragged_tensor.convert_to_tensor_or_ragged_tensor(
            source, name='source')
        if source.shape.rank != 2:
            raise ValueError('Rank of `source` must equals 2')

        # Dense input is converted to ragged so the op always sees
        # values + row_splits.
        if not ragged_tensor.is_ragged(source):
            source = ragged_tensor.RaggedTensor.from_tensor(source,
                                                            ragged_rank=1)
        if source.ragged_rank != 1:
            raise ValueError('Ragged rank of `source` must equals 1')

        op_seed, op_seed2 = random_seed.get_seed(seed)
        target, context_values, context_splits, context_positions = \
            tfmiss_ops.miss_cont_bow(
                source_values=source.values,
                source_splits=source.row_splits,
                window=window,
                seed=op_seed,
                seed2=op_seed2)

        def _with_context_splits(flat):
            # Context tokens and positions share the same row splits.
            return tf.RaggedTensor.from_row_splits(flat, context_splits)

        return (target,
                _with_context_splits(context_values),
                _with_context_splits(context_positions))
    def detokenize(self, input, name=None):  # pylint: disable=redefined-builtin
        """Detokenizes tokens into preprocessed text.

        Args:
          input: A `RaggedTensor` or `Tensor` of UTF-8 string tokens with a
            rank of at least 1.
          name: The name argument that is passed to the op function.

        Returns:
          A N-1 dimensional string Tensor or RaggedTensor of the detokenized
          text.

        Raises:
          ValueError: If the rank of `input` is not statically known or is 0.
        """
        with ops.name_scope(name, "SentenceTokenizer", [input, self]):
            input_tensor = ragged_tensor.convert_to_tensor_or_ragged_tensor(
                input)
            rank = input_tensor.shape.ndims
            if rank is None:
                raise ValueError(
                    "Rank of input_tensor must be statically known.")
            if rank == 0:
                raise ValueError("Rank of input_tensor must be at least 1.")

            if not ragged_tensor.is_ragged(input_tensor):
                if rank == 1:
                    # Detokenize as a batch of one and reshape to a scalar.
                    batched = self.detokenize(array_ops.stack([input_tensor]))
                    return array_ops.reshape(batched, [])
                # Higher-rank dense input: convert to ragged and recurse.
                return self.detokenize(
                    ragged_conversion_ops.from_tensor(input_tensor))

            if input_tensor.flat_values.shape.ndims > 1:
                # Multi-dimensional flat values are processed on their own;
                # the output keeps the input's nested splits.
                detokenized = self.detokenize(input_tensor.flat_values)
                return input_tensor.with_flat_values(detokenized)
            if input_tensor.ragged_rank > 1:
                # Peel off one ragged level and recurse on its values.
                detokenized = self.detokenize(input_tensor.values)
                return input_tensor.with_values(detokenized)
            return gen_sentencepiece_tokenizer.sentencepiece_detokenize_op(
                self._model_resource.resource_handle,
                input_tensor.flat_values, input_tensor.row_splits,
                self.add_bos, self.add_eos, self.reverse)
Example #54
0
def _replace_ragged_with_flat_values(value, partition_lists,
                                     flat_values_nrows):
    """Replace RaggedTensors with their flat_values, and record their partitions.

    Returns a copy of `value` in which every nested ragged tensor is replaced
    by its `flat_values` tensor.  Lists, tuples, and dicts are searched
    recursively (and rebuilt as plain `list` / `tuple` / `dict`).  Any other
    value is returned as is.

    Args:
      value: The value that should be transformed by replacing
        `RaggedTensors`.
      partition_lists: Output list.  Each replaced ragged tensor's nested row
        partitions are appended to it, in traversal order.
      flat_values_nrows: Output list of int.  For each replaced ragged tensor
        whose `flat_values` has a statically known outer dimension, that size
        is appended.

    Returns:
      A copy of `value` with nested `RaggedTensors` replaced by their
      `flat_values`.
    """
    if ragged_tensor.is_ragged(value):
        rt = ragged_tensor.convert_to_tensor_or_ragged_tensor(value)
        partition_lists.append(rt._nested_row_partitions)  # pylint: disable=protected-access
        static_nrows = tensor_shape.dimension_at_index(rt.flat_values.shape,
                                                       0).value
        if static_nrows is not None:
            flat_values_nrows.append(static_nrows)
        return rt.flat_values

    def _recurse(item):
        return _replace_ragged_with_flat_values(item, partition_lists,
                                                flat_values_nrows)

    if isinstance(value, list):
        return [_recurse(item) for item in value]
    if isinstance(value, tuple):
        return tuple(_recurse(item) for item in value)
    if isinstance(value, dict):
        return {key: _recurse(item) for key, item in value.items()}
    return value
Example #55
0
def tile(input, multiples, name=None):  # pylint: disable=redefined-builtin
  """Tiles a `RaggedTensor` by replicating it along each dimension.

  Dimension `i` of `input` is repeated `multiples[i]` times.  For each
  dimension `axis`, every output element's length along that dimension equals
  the corresponding input element's length times `multiples[axis]`.  A
  non-ragged `input` is forwarded to the dense `tile` op.

  Args:
    input: A `RaggedTensor`.  Dense tensors are also accepted and are handled
      by the dense `tile` op.
    multiples: A 1-D integer `Tensor` whose length must equal the number of
      dimensions in `input`.
    name: A name for the operation (optional).

  Returns:
    For a ragged `input`, a `RaggedTensor` with the same type, rank, and
    ragged_rank as `input`; otherwise the result of the dense `tile` op.

  #### Example:
    ```python
    >>> rt = tf.ragged.constant([[1, 2], [3]])
    >>> ragged.tile(rt, [3, 2])
    [[1, 2, 1, 2], [3, 3], [1, 2, 1, 2], [3, 3], [1, 2, 1, 2], [3, 3]]
    ```
  """
  with ops.name_scope(name, 'RaggedTile', [input, multiples]):
    rt_input = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        input, name='input')
    if ragged_tensor.is_ragged(rt_input):
      int_multiples = ragged_util.convert_to_int_tensor(
          multiples, name='multiples', dtype=rt_input.row_splits.dtype)
      int_multiples.shape.assert_has_rank(1)

      # When `multiples` is a compile-time constant, the helpers can use it
      # to skip dimensions whose multiple is 1; otherwise this is None.
      static_multiples = tensor_util.constant_value(int_multiples)

      tiled_values = _tile_ragged_values(rt_input, int_multiples,
                                         static_multiples)
      tiled_splits = _tile_ragged_splits(rt_input, int_multiples,
                                         static_multiples)
      return ragged_tensor.RaggedTensor.from_nested_row_splits(
          tiled_values, tiled_splits, validate=False)
    return array_ops.tile(rt_input, multiples, name)
Example #56
0
 def _compareOutputToExpected(self, result_values, expected_values,
                              assert_items_equal):
   if assert_items_equal:
     # TODO(shivaniagrawal): add support for nested elements containing sparse
     # tensors when needed.
     self.assertItemsEqual(result_values, expected_values)
     return
   for i in range(len(result_values)):
     nest.assert_same_structure(result_values[i], expected_values[i])
     for result_value, expected_value in zip(
         nest.flatten(result_values[i]), nest.flatten(expected_values[i])):
       if sparse_tensor.is_sparse(result_value):
         self.assertSparseValuesEqual(result_value, expected_value)
       elif ragged_tensor.is_ragged(result_value):
         self.assertRaggedEqual(result_value, expected_value)
       else:
         self.assertAllEqual(
             result_value,
             expected_value,
             msg=("Result value: {}.  Expected value: {}"
                  .format(result_value, expected_value)))
Example #57
0
def size(input, out_type=dtypes.int32, name=None):  # pylint: disable=redefined-builtin
  """Counts the elements of a tensor that may be ragged.

  A `RaggedTensor`'s size is defined as the size of its `flat_values` (its
  inner values).  Any other input is handed to the dense `size` op unchanged.

  Args:
    input: A `Tensor` or `RaggedTensor`.
    out_type: The numeric dtype of the returned count.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` of type `out_type`.

  #### Example:
    ```python
    >>> tf.size(tf.ragged.constant([[1, 2], [3]]))
    3
    ```
  """
  target = input.flat_values if ragged_tensor.is_ragged(input) else input
  return array_ops.size(target, out_type=out_type, name=name)
def _broadcast_to_ragged_shape(rt_input, dst_shape, broadcast_inner_dimensions):
  """Broadcasts rt_input to the ragged shape `dst_shape`.

  The work happens in stages:

  1. Reconcile row-splits dtypes.
  2. Validate static ranks.
  3. Add outer dimensions and ragged partitions until `rt_input`'s structure
     lines up with `dst_shape`.
  4. Tile uniform dimensions.
  5. Optionally reshape the inner (`flat_values`) dimensions.
  6. Tile the dimensions that become ragged.

  Args:
    rt_input: The `Tensor` or `RaggedTensor` to broadcast.  Its rank must be
      statically known.
    dst_shape: The target shape.  Presumably a `RaggedTensorDynamicShape`
      (TODO confirm).  This code uses its `rank`, `dim_size_dtype`,
      `num_partitioned_dimensions`, `num_inner_dimensions`, `inner_dim_sizes`,
      `is_ragged()`, `dimension_size()` and `with_dim_size_dtype()` members.
    broadcast_inner_dimensions: If true, `rt_input.flat_values` is reshaped to
      `[-1] + dst_shape.inner_dim_sizes` after the uniform-dimension tiling.

  Returns:
    `rt_input` after broadcasting.  It is converted to a `RaggedTensor`
    before any tiling is applied.

  Raises:
    ValueError: If `rt_input` is a `RaggedTensor` whose row-splits dtype
      differs from `dst_shape.dim_size_dtype` and
      `ragged_config.auto_cast_partition_dtype()` is false.  Also raised if
      either rank is statically unknown, if `rt_input` has a higher rank than
      `dst_shape`, or if `rt_input` is a `RaggedTensor` with
      `ragged_rank >= dst_shape.num_partitioned_dimensions`.
  """
  # Check that rt_input and dst_shape have the same row_splits dtype.
  if (isinstance(rt_input, ragged_tensor.RaggedTensor) and
      rt_input.row_splits.dtype != dst_shape.dim_size_dtype):
    if not ragged_config.auto_cast_partition_dtype():
      raise ValueError('rt_input and dst_shape have different row_split '
                       'dtypes; use RaggedTensor.with_row_splits_dtype() or '
                       'RaggedTensorDynamicShape.with_dim_size_dtype() to '
                       'convert to a compatible dtype.')
    # Auto-casting is enabled: move both sides to int64 so they agree.
    rt_input = rt_input.with_row_splits_dtype(dtypes.int64)
    dst_shape = dst_shape.with_dim_size_dtype(dtypes.int64)

  # dst_shape's rank and ragged_rank must be greater than or equal to rt_input's
  if rt_input.shape.ndims is None or dst_shape.rank is None:
    raise ValueError('Unable to broadcast: unknown rank')
  if rt_input.shape.ndims > dst_shape.rank:
    raise ValueError('Incompatible with shape: rank mismatch')
  if (isinstance(rt_input, ragged_tensor.RaggedTensor) and
      rt_input.ragged_rank >= dst_shape.num_partitioned_dimensions):
    raise ValueError('Incompatible with shape: ragged rank mismatch')

  # `src_shape` is captured here, before rt_input is restructured below.  The
  # tiling loops later compare this original shape (extended to
  # dst_shape.rank) against dst_shape to decide which axes need broadcasting.
  src_shape = RaggedTensorDynamicShape.from_tensor(rt_input)
  src_shape = src_shape.broadcast_to_rank(dst_shape.rank)

  # Add dimensions to rt_input so its rank and ragged_rank matches dst_shape.
  if dst_shape.rank > rt_input.shape.ndims:
    # Too few dimensions to hold dst_shape's inner dimensions: reshape to
    # [-1] + dst_shape.inner_dim_sizes first.
    if rt_input.shape.ndims < dst_shape.num_inner_dimensions + 1:
      rt_input = array_ops.reshape(
          rt_input, array_ops.concat([[-1], dst_shape.inner_dim_sizes], axis=0))
    # The range is evaluated after the optional reshape above.  Each
    # iteration wraps rt_input as a single row of length `nrows`, which adds
    # one outer dimension of size 1.
    for _ in range(dst_shape.rank - rt_input.shape.ndims):
      if ragged_tensor.is_ragged(rt_input):
        nrows = rt_input.nrows()
      else:
        nrows = array_ops.shape(rt_input,
                                out_type=dst_shape.dim_size_dtype)[0]
      rt_input = ragged_tensor.RaggedTensor.from_row_lengths(rt_input, [nrows],
                                                             validate=False)

  # Add ragged dimensions to match dst_shape.
  if ragged_tensor.is_ragged(rt_input):
    # flat_values may carry more dimensions than dst_shape's inner dimensions.
    # Turn the extra leading ones into (ragged) partitioned dimensions.
    inner_rank_diff = (
        rt_input.flat_values.shape.ndims - 1 - dst_shape.num_inner_dimensions)
    if inner_rank_diff > 0:
      rt_input = rt_input.with_flat_values(
          ragged_tensor.RaggedTensor.from_tensor(
              rt_input.flat_values, ragged_rank=inner_rank_diff,
              row_splits_dtype=dst_shape.dim_size_dtype))
  else:
    # Dense input: convert to a RaggedTensor partitioned like dst_shape.
    rt_input = ragged_tensor.RaggedTensor.from_tensor(
        rt_input, ragged_rank=dst_shape.num_partitioned_dimensions - 1,
        row_splits_dtype=dst_shape.dim_size_dtype)

  # Do broadcasting for any dimensions that will remain uniform.  We can do
  # these all at once, since they're independent of one another.
  multiples = [1] * dst_shape.rank
  for axis in range(dst_shape.num_partitioned_dimensions):
    if not src_shape.is_ragged(axis) and not dst_shape.is_ragged(axis):
      src_size = src_shape.dimension_size(axis)
      dst_size = dst_shape.dimension_size(axis)
      # Only candidates for tiling are axes whose source size is 1 or not
      # statically known and whose destination size is not statically 1.
      # At runtime the `where` selects dst_size when src_size == 1, else 1.
      if ((tensor_util.constant_value(src_size) in (1, None)) and
          (tensor_util.constant_value(dst_size) != 1)):
        multiples[axis] = array_ops.where(
            math_ops.equal(src_size, 1), dst_size, 1)
  # Skip the tile op entirely when every multiple is still the Python int 1.
  if not all(isinstance(v, int) and v == 1 for v in multiples):
    multiples = array_ops.stack(multiples, axis=0)
    rt_input = ragged_array_ops.tile(rt_input, multiples)

  if broadcast_inner_dimensions:
    rt_input = rt_input.with_flat_values(
        array_ops.reshape(
            rt_input.flat_values,
            array_ops.concat([[-1], dst_shape.inner_dim_sizes], axis=0)))

  # Do broadcasting for dimensions that become ragged.  We must do these from
  # outermost to innermost.
  for axis in range(dst_shape.num_partitioned_dimensions):
    if not src_shape.is_ragged(axis) and dst_shape.is_ragged(axis):
      dst_size = dst_shape.dimension_size(axis)
      rt_input = _ragged_tile_axis(rt_input, axis, dst_size,
                                   dst_shape.dim_size_dtype)

  return rt_input
def gather(params, indices, validate_indices=None, axis=0, batch_dims=0,
           name=None):
  """Gathers ragged slices from `params` axis `0` according to `indices`.

  Returns `output` such that:

  ```python
  output.shape = indices.shape + params.shape[1:]
  output[i...j, d0...dn] = params[indices[i...j], d0...dn]
  ```

  `params` may be ragged.  `indices` may be ragged.  If neither is ragged, the
  result is the dense `Tensor` produced by the dense `gather` op.  Otherwise
  the result is a `RaggedTensor` whose ragged rank is:

  * `indices.shape.ndims + params.ragged_rank - 1` if `params` is ragged;
  * `indices.ragged_rank` if only `indices` is ragged.

  `indices` must have dtype `int32` or `int64`. If any index is out of bounds,
  then an error is returned.

  Examples:

  ```python
  >>> params = tf.constant(['a', 'b', 'c', 'd', 'e'])
  >>> indices = tf.constant([3, 1, 2, 1, 0])
  >>> ragged_params = tf.ragged.constant([['a', 'b', 'c'], ['d'], [], ['e']])
  >>> ragged_indices = tf.ragged.constant([[3, 1, 2], [1], [], [0]])

  >>> print(ragged.gather(params, ragged_indices))
  [['d', 'b', 'c'], ['b'], [], ['a']]

  >>> print(ragged.gather(ragged_params, indices))
  [['e'], ['d'], [], ['d'], ['a', 'b', 'c']]

  >>> print(ragged.gather(ragged_params, ragged_indices))
  [[['e'], ['d'], []], [['d']], [], [['a', 'b', 'c']]]
  ```

  Args:
    params: The potentially ragged tensor from which to gather values. Must be
      at least rank 1.
    indices: The potentially ragged tensor indicating which values to gather.
      Must have dtype `int32` or `int64`.  Values must be in the half-open
      range `[0, params.shape[0])`.
    validate_indices: Ignored.
    axis: Must be zero.
    batch_dims: Must be zero.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` (if neither input is ragged) or a `RaggedTensor` (otherwise),
    where `output.dtype=params.dtype` and
    `output.shape=indices.shape + params.shape[1:]`.

  Raises:
    ValueError: If `axis` or `batch_dims` is not an `int` equal to `0`.
      Also raised if `indices.shape.ndims` is not known statically when it
      is needed, i.e. when gathering from a ragged `params` with dense
      `indices`.
  """
  del validate_indices
  if not isinstance(axis, int) or axis != 0:
    raise ValueError('axis != 0 is not supported for ragged gather yet.')
  if not isinstance(batch_dims, int) or batch_dims != 0:
    raise ValueError('batch_dims != 0 is not supported for ragged gather yet.')
  with ops.name_scope(name, 'RaggedGather', [params, indices]):
    params = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        params, name='params')
    indices = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        indices, name='indices')
    params, indices = ragged_tensor.match_row_splits_dtypes(params, indices)

    # Ragged indices: gather with their (flatter) values, then reattach the
    # indices' row partitions on top of the result.
    if ragged_tensor.is_ragged(indices):
      return indices.with_values(gather(params, indices.values))

    if not ragged_tensor.is_ragged(params):
      return array_ops.gather(params, indices)

    indices = ops.convert_to_tensor(indices)
    if indices.shape.ndims is None:
      raise ValueError('indices.shape.ndims must be known statically')

    # Each dense index dimension becomes a partitioned output dimension, on
    # top of params' own partitions, minus the outermost params dimension
    # that the indices replace.
    result = gen_ragged_array_ops.ragged_gather(
        indices=indices,
        params_dense_values=params.flat_values,
        params_nested_splits=params.nested_row_splits,
        OUTPUT_RAGGED_RANK=indices.shape.ndims + len(params.nested_row_splits) -
        1)

    # Compose the RaggedTensor from splits & values.
    return ragged_tensor.RaggedTensor.from_nested_row_splits(
        result.output_dense_values, result.output_nested_splits, validate=False)
def gather_nd(params, indices, batch_dims=0, name=None):
  """Gather slices from `params` using `n`-dimensional indices.

  This operation is similar to `gather`, but it uses the innermost dimension
  of `indices` to define a slice into `params`.  In particular, if:

  * `indices` has shape `[A1...AN, I]`
  * `params` has shape `[B1...BM]`

  Then:

  * `result` has shape `[A1...AN, B_{I+1}...BM]`.
  * `result[a1...aN] = params[indices[a1...aN, :]]`

  If neither `params` nor `indices` is ragged, this defers to the dense
  `gather_nd` op.

  Args:
    params: A potentially ragged tensor with shape `[B1...BM]`.
    indices: A potentially ragged tensor with shape `[A1...AN, I]`.  Its rank
      and its innermost dimension size `I` must be statically known, and its
      innermost dimension may not be ragged.
    batch_dims: Must be zero.
    name: A name for the operation (optional).

  Returns:
    A potentially ragged tensor with shape `[A1...AN, B_{I+1}...BM]`.

  Raises:
    ValueError: If `batch_dims` is not an `int` equal to `0`.  When `params`
      or `indices` is ragged, it is also raised if:

      * the rank of `indices` is not statically known, or is zero;
      * the innermost dimension of `indices` is ragged;
      * `indices.shape[-1]` is not statically known.

  #### Examples:
    ```python
    >>> params = tf.compat.v1.ragged.constant_value(
    ...     [ [ ['000', '001'], ['010'              ]          ],
    ...       [ ['100'       ], ['110', '111', '112'], ['120'] ],
    ...       [ [            ], ['210'              ]          ] ])

    >>> # Gather 2D slices from a 3D tensor
    >>> ragged.gather_nd(params, [[2], [0]])
    [ [ [            ], ['210'] ]
      [ ['000', '001'], ['010'] ] ]

    >>> # Gather 1D slices from a 3D tensor
    >>> ragged.gather_nd(params, [[2, 1], [0, 0]])
    [['210'], ['000', '001']]

    >>> # Gather scalars from a 3D tensor
    >>> ragged.gather_nd(params, [[0, 0, 1], [1, 1, 2]])
    ['001', '112']
    ```
  """
  if not isinstance(batch_dims, int) or batch_dims != 0:
    raise ValueError('batch_dims != 0 is not supported for ragged gather yet.')
  if not (ragged_tensor.is_ragged(params) or ragged_tensor.is_ragged(indices)):
    return array_ops.gather_nd(params, indices, name)

  with ops.name_scope(name, 'RaggedGatherNd', [params, indices]):

    params = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        params, name='params')
    indices = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        indices, name='indices')
    params, indices = ragged_tensor.match_row_splits_dtypes(params, indices)
    indices_shape = indices.shape
    indices_ndims = indices_shape.ndims
    if indices_ndims is None:
      raise ValueError('indices.rank be statically known.')
    if indices_ndims == 0:
      raise ValueError('indices.rank must be at least 1.')
    if (ragged_tensor.is_ragged(indices) and
        indices_ndims == indices.ragged_rank + 1):
      raise ValueError('The innermost dimension of indices may not be ragged')

    # `index_size` is the "n" in "gather_nd" -- i.e., the number of dimensions
    # that each index slices into.
    index_size = tensor_shape.dimension_value(indices_shape[-1])
    if index_size is None:
      raise ValueError('indices.shape[-1] must be statically known.')

    # If `indices` has more than 2 dimensions, then recurse.  If `indices` is
    # dense, then we convert it to ragged before recursing, and then convert
    # the result back to `dense` if appropriate.
    if indices_ndims > 2:
      indices_is_dense = not ragged_tensor.is_ragged(indices)
      if indices_is_dense:
        indices = ragged_tensor.RaggedTensor.from_tensor(
            indices, ragged_rank=indices_ndims - 2,
            row_splits_dtype=params.row_splits.dtype)
      result = indices.with_flat_values(gather_nd(params, indices.flat_values))
      if (indices_is_dense and ragged_tensor.is_ragged(result) and
          result.ragged_rank == indices_ndims - 2):
        result = ragged_tensor.RaggedTensor.to_tensor(result)
      return result

    # indices_ndims <= 2, and the innermost dimension of indices may not be
    # ragged, so `indices` must not be ragged.
    assert not ragged_tensor.is_ragged(indices)
    assert ragged_tensor.is_ragged(params)

    # Handle corner case: An empty index tuple selects the entire `params`
    # value.  So if `index_size` is zero, then tile `params`.
    if index_size == 0:
      params_ndims = params.ragged_rank + array_ops.rank(params.flat_values)
      # Add one leading size-1 dimension per outer dimension of `indices`.
      for _ in range(indices_ndims - 1):
        params = ragged_array_ops.expand_dims(params, axis=0)
      multiples = array_ops.concat([
          array_ops.shape(indices)[:-1],
          array_ops.ones([params_ndims], dtypes.int32)
      ], axis=0)
      return ragged_array_ops.tile(params, multiples)

    # When index_size=1, we can just flatten the index tuples and use gather.
    elif index_size == 1:
      flattened_index_tuples = array_ops.reshape(indices, [-1])
      return gather(params, flattened_index_tuples)

    # Otherwise, params is a RaggedTensor, and indices is a 1D or 2D Tensor.
    # Flatten both the index tuples and the params, such that the flattened
    # index tuples point to the correct values in the flattened params; and
    # then use ragged.gather on the flattened index tuples & params.
    else:
      indices = math_ops.cast(indices, params.row_splits.dtype)

      # Flatten the outermost 2 dimensions of the index tuples & params.
      flattened_index_tuples = array_ops.gather(params.row_splits,
                                                indices[..., 0])
      flattened_index_tuples += indices[..., 1]
      flattened_params = params.values

      # Flatten any remaining dimensions.
      for dim in range(2, index_size):
        # Once the remaining params are dense, finish with a single dense
        # gather_nd over [flattened_index, remaining index components].
        if not ragged_tensor.is_ragged(flattened_params):
          flattened_index_tuples = array_ops.expand_dims(
              flattened_index_tuples, axis=1)
          flattened_index_tuples = array_ops.concat(
              [flattened_index_tuples, indices[..., dim:]], axis=1)
          return array_ops.gather_nd(flattened_params, flattened_index_tuples)

        flattened_index_tuples = array_ops.gather(
            flattened_params.row_starts(), flattened_index_tuples)
        flattened_index_tuples += indices[..., dim]
        flattened_params = flattened_params.values

      # Gather using the flattened index tuples and params.
      return gather(flattened_params, flattened_index_tuples)