Example #1
def _elementwise_where_v2(condition, x, y):
    """Ragged version of tf.where_v2(condition, x, y)."""
    # Broadcast x, y, and condition to have the same shape.
    if not (condition.shape.is_fully_defined() and x.shape.is_fully_defined()
            and y.shape.is_fully_defined() and x.shape == y.shape
            and condition.shape == x.shape):
        shape_c = ragged_tensor_shape.RaggedTensorDynamicShape.from_tensor(
            condition)
        shape_x = ragged_tensor_shape.RaggedTensorDynamicShape.from_tensor(x)
        shape_y = ragged_tensor_shape.RaggedTensorDynamicShape.from_tensor(y)
        shape = ragged_tensor_shape.broadcast_dynamic_shape(
            shape_c,
            ragged_tensor_shape.broadcast_dynamic_shape(shape_x, shape_y))
        condition = ragged_tensor_shape.broadcast_to(condition, shape)
        x = ragged_tensor_shape.broadcast_to(x, shape)
        y = ragged_tensor_shape.broadcast_to(y, shape)

    condition_is_ragged = isinstance(condition, ragged_tensor.RaggedTensor)
    x_is_ragged = isinstance(x, ragged_tensor.RaggedTensor)
    y_is_ragged = isinstance(y, ragged_tensor.RaggedTensor)
    if not (condition_is_ragged or x_is_ragged or y_is_ragged):
        return array_ops.where_v2(condition, x, y)

    return ragged_functional_ops.map_flat_values(array_ops.where_v2, condition,
                                                 x, y)
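
A minimal usage sketch of what this handler enables through the public API, assuming a TensorFlow 2.x build where ragged dispatch is registered for `tf.where` (the printed results are illustrative):

import tensorflow as tf

# All three arguments are broadcast to a common (ragged) shape, then
# tf.where is applied to the flat values.
condition = tf.ragged.constant([[True, False, True], [False]])
x = tf.ragged.constant([[1, 2, 3], [4]])
y = tf.ragged.constant([[10, 20, 30], [40]])
print(tf.where(condition, x, y))  # [[1, 20, 3], [40]]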
Example #2
    def handle(self, args, kwargs):
        # Extract the binary args.
        if len(args) > 1:
            x = args[0]
            y = args[1]
            args = args[2:]
        elif args:
            kwargs = kwargs.copy()
            x = args[0]
            y = kwargs.pop(self._y, None)
            args = args[1:]
        else:
            kwargs = kwargs.copy()
            x = kwargs.pop(self._x, None)
            y = kwargs.pop(self._y, None)

        # Bail if we don't have at least one ragged argument.
        x_is_ragged = ragged_tensor.is_ragged(x)
        y_is_ragged = ragged_tensor.is_ragged(y)
        if not (x_is_ragged or y_is_ragged):
            return self.NOT_SUPPORTED

        # Convert args to tensors.  Bail if conversion fails.
        try:
            x = ragged_tensor.convert_to_tensor_or_ragged_tensor(
                x,
                name=self._x,
                preferred_dtype=(y.dtype if y_is_ragged else None))
            y = ragged_tensor.convert_to_tensor_or_ragged_tensor(
                y,
                name=self._y,
                preferred_dtype=(x.dtype if x_is_ragged else None))
        except (TypeError, ValueError):
            return self.NOT_SUPPORTED

        if x_is_ragged and y_is_ragged:
            x, y = ragged_tensor.match_row_splits_dtypes(x, y)

        if ((x_is_ragged and y_is_ragged)
                or (x_is_ragged and x.flat_values.shape.ndims <= y.shape.ndims)
                or
            (y_is_ragged and y.flat_values.shape.ndims <= x.shape.ndims)):
            bcast_shape = ragged_tensor_shape.broadcast_dynamic_shape(
                ragged_tensor_shape.RaggedTensorDynamicShape.from_tensor(x),
                ragged_tensor_shape.RaggedTensorDynamicShape.from_tensor(y))
            x = ragged_tensor_shape.broadcast_to(
                x, bcast_shape, broadcast_inner_dimensions=False)
            y = ragged_tensor_shape.broadcast_to(
                y, bcast_shape, broadcast_inner_dimensions=False)

        x_values = x.flat_values if ragged_tensor.is_ragged(x) else x
        y_values = y.flat_values if ragged_tensor.is_ragged(y) else y
        mapped_values = self._original_op(x_values, y_values, *args, **kwargs)
        if ragged_tensor.is_ragged(x):
            return x.with_flat_values(mapped_values)
        else:
            return y.with_flat_values(mapped_values)
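
A short sketch of the behavior this dispatcher implements, using the public API (assuming ragged dispatch is enabled for the binary math ops, as in TF 2.x; outputs are illustrative):

import tensorflow as tf

# Ragged + scalar: the op simply maps over the flat values.
rt = tf.ragged.constant([[1, 2, 3], [4]])
print(tf.add(rt, 10))                  # [[11, 12, 13], [14]]

# Ragged + dense [2, 1]: the dense operand is broadcast to the ragged shape
# before the op is applied, which is the broadcast_to path shown above.
print(rt + tf.constant([[10], [20]]))  # [[11, 12, 13], [24]]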
Example #3
  def handle(self, args, kwargs):
    # Extract the binary args.
    if len(args) > 1:
      x = args[0]
      y = args[1]
      args = args[2:]
    elif args:
      kwargs = kwargs.copy()
      x = args[0]
      y = kwargs.pop(self._y, None)
      args = args[1:]
    else:
      kwargs = kwargs.copy()
      x = kwargs.pop(self._x, None)
      y = kwargs.pop(self._y, None)

    # Bail if we don't have at least one ragged argument.
    x_is_ragged = ragged_tensor.is_ragged(x)
    y_is_ragged = ragged_tensor.is_ragged(y)
    if not (x_is_ragged or y_is_ragged):
      return self.NOT_SUPPORTED

    # Convert args to tensors.  Bail if conversion fails.
    try:
      if not x_is_ragged:
        x = ops.convert_to_tensor(x, name=self._x, preferred_dtype=y.dtype)
      if not y_is_ragged:
        y = ops.convert_to_tensor(y, name=self._y, preferred_dtype=x.dtype)
    except (TypeError, ValueError):
      return self.NOT_SUPPORTED

    if x_is_ragged and y_is_ragged:
      x, y = ragged_tensor.match_row_splits_dtypes(x, y)

    if ((x_is_ragged and y_is_ragged) or
        (x_is_ragged and x.flat_values.shape.ndims <= y.shape.ndims) or
        (y_is_ragged and y.flat_values.shape.ndims <= x.shape.ndims)):
      bcast_shape = ragged_tensor_shape.broadcast_dynamic_shape(
          ragged_tensor_shape.RaggedTensorDynamicShape.from_tensor(x),
          ragged_tensor_shape.RaggedTensorDynamicShape.from_tensor(y))
      x = ragged_tensor_shape.broadcast_to(
          x, bcast_shape, broadcast_inner_dimensions=False)
      y = ragged_tensor_shape.broadcast_to(
          y, bcast_shape, broadcast_inner_dimensions=False)

    x_values = x.flat_values if ragged_tensor.is_ragged(x) else x
    y_values = y.flat_values if ragged_tensor.is_ragged(y) else y
    mapped_values = self._original_op(x_values, y_values, *args, **kwargs)
    if ragged_tensor.is_ragged(x):
      return x.with_flat_values(mapped_values)
    else:
      return y.with_flat_values(mapped_values)
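
This variant converts dense operands with `ops.convert_to_tensor` instead of `convert_to_tensor_or_ragged_tensor`, but the ragged-with-ragged path is the same. A small sketch of that path through the public API (illustrative):

import tensorflow as tf

# Both operands ragged: row splits must match, and the op runs on the flat values.
a = tf.ragged.constant([[1.0, 2.0], [3.0]])
b = tf.ragged.constant([[10.0, 20.0], [30.0]])
print(tf.multiply(a, b))  # [[10.0, 40.0], [90.0]]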
Example #4
def ragged_binary_elementwise_op(op, x, y):
    """Binary elementwise api handler for RaggedTensors."""
    x_is_ragged = ragged_tensor.is_ragged(x)
    y_is_ragged = ragged_tensor.is_ragged(y)

    # Convert args to tensors.
    x = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        x, preferred_dtype=(y.dtype if y_is_ragged else None))
    y = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        y, preferred_dtype=x.dtype)

    if x_is_ragged and y_is_ragged:
        x, y = ragged_tensor.match_row_splits_dtypes(x, y)

    # Perform broadcasting, when appropriate.
    if ((x_is_ragged and y_is_ragged)
            or (x_is_ragged and x.flat_values.shape.ndims <= y.shape.ndims)
            or (y_is_ragged and y.flat_values.shape.ndims <= x.shape.ndims)):
        # If both x and y are ragged, they must have the same row_splits_dtype now.
        if x_is_ragged:
            dim_size_dtype = x.row_splits.dtype
        else:
            dim_size_dtype = y.row_splits.dtype

        shape_x = ragged_tensor_shape.RaggedTensorDynamicShape.from_tensor(
            x, dim_size_dtype=dim_size_dtype)
        shape_y = ragged_tensor_shape.RaggedTensorDynamicShape.from_tensor(
            y, dim_size_dtype=dim_size_dtype)
        bcast_shape = ragged_tensor_shape.broadcast_dynamic_shape(
            shape_x, shape_y)
        x = ragged_tensor_shape.broadcast_to(x,
                                             bcast_shape,
                                             broadcast_inner_dimensions=False)
        y = ragged_tensor_shape.broadcast_to(y,
                                             bcast_shape,
                                             broadcast_inner_dimensions=False)

    x_values = x.flat_values if ragged_tensor.is_ragged(x) else x
    y_values = y.flat_values if ragged_tensor.is_ragged(y) else y
    mapped_values = op(x_values, y_values)
    if isinstance(mapped_values, bool):
        return mapped_values  # Special case for tensor_equals.
    if ragged_tensor.is_ragged(x):
        return x.with_flat_values(mapped_values)
    else:
        return y.with_flat_values(mapped_values)
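
Comparison operators go through the same path and return a ragged boolean tensor; the `isinstance(mapped_values, bool)` check only matters for `tensor_equals`, which can return a plain Python bool. A sketch with the public API (illustrative output):

import tensorflow as tf

# Elementwise comparison against a scalar yields a ragged bool tensor.
rt = tf.ragged.constant([[1, 5, 3], [8]])
print(tf.greater(rt, 4))  # [[False, True, False], [True]]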
Example #5
def batch_gather_with_default(params, indices, default_value='', name=None):
    """Same as `batch_gather` but inserts `default_value` for invalid indices.

  This operation is similar to `batch_gather`, except that it substitutes
  `default_value` for any invalid indices.
  See `batch_gather` for more details.


  Args:
    params: A potentially ragged tensor with shape `[B1...BN, P1...PM]` (`N>=0`,
      `M>0`).
    indices: A potentially ragged tensor with shape `[B1...BN, I]` (`N>=0`).
    default_value: A value to be inserted in places where `indices` are out of
      bounds. Must be the same dtype as params and either a scalar or rank 1.
    name: A name for the operation (optional).

  Returns:
    A potentially ragged tensor with shape `[B1...BN, I, P2...PM]`.
    `result.ragged_rank = max(indices.ragged_rank, params.ragged_rank)`.

  #### Example:
    ```python
    >>> params = tf.ragged.constant([
          ['a', 'b', 'c'],
          ['d'],
          [],
          ['e']])
    >>> indices = tf.ragged.constant([[1, 2, -1], [], [], [0, 10]])
    >>> batch_gather_with_default(params, indices, 'FOO')
    [['b', 'c', 'FOO'], [], [], ['e', 'FOO']]
  ```
  """
    with ops.name_scope(name, 'RaggedBatchGatherWithDefault'):
        params = ragged_tensor.convert_to_tensor_or_ragged_tensor(
            params,
            name='params',
        )
        indices = ragged_tensor.convert_to_tensor_or_ragged_tensor(
            indices,
            name='indices',
        )
        default_value = ragged_tensor.convert_to_tensor_or_ragged_tensor(
            default_value,
            name='default_value',
        )
        # TODO(hterry): lift this restriction and support default_values
        #               of rank > 1
        if default_value.shape.ndims not in (0, 1):
            raise ValueError('"default_value" must be a scalar or vector')
        upper_bounds = None
        if indices.shape.ndims is None:
            raise ValueError('Indices must have a known rank.')
        if params.shape.ndims is None:
            raise ValueError('Params must have a known rank.')

        num_batch_dimensions = indices.shape.ndims - 1
        pad = None
        # The logic for this works as follows:
        # - create a padded params, where:
        #    padded_params[b1...bn, 0] = default_value
        #    padded_params[b1...bn, i] = params[b1...bn, i-1] (i>0)
        # - create an `upper_bounds` Tensor that contains the number of elements
        #   in each innermost rank. Broadcast `upper_bounds` to be the same shape
        #   as `indices`.
        # - check to see which index in `indices` are out of bounds and substitute
        #   it with the index containing `default_value` (the first).
        # - call batch_gather with the indices adjusted.
        with ops.control_dependencies([
                check_ops.assert_greater_equal(array_ops.rank(params),
                                               array_ops.rank(indices))
        ]):
            if ragged_tensor.is_ragged(params):
                row_lengths = ragged_array_ops.expand_dims(
                    params.row_lengths(axis=num_batch_dimensions), axis=-1)
                upper_bounds = math_ops.cast(row_lengths, indices.dtype)

                pad_shape = _get_pad_shape(params, indices)

                pad = ragged_tensor_shape.broadcast_to(default_value,
                                                       pad_shape)
            else:
                params_shape = array_ops.shape(params)
                pad_shape = array_ops.concat([
                    params_shape[:num_batch_dimensions], [1],
                    params_shape[num_batch_dimensions + 1:params.shape.ndims]
                ], 0)
                upper_bounds = params_shape[num_batch_dimensions]
                pad = array_ops.broadcast_to(default_value, pad_shape)

            # Add `default_value` as the first value in the innermost (ragged) rank.
            pad = math_ops.cast(pad, params.dtype)
            padded_params = array_ops.concat([pad, params],
                                             axis=num_batch_dimensions)

            # Adjust the indices by substituting out-of-bound indices to the
            # default-value index (which is the first element)
            shifted_indices = indices + 1
            is_out_of_bounds = (indices < 0) | (indices > upper_bounds)
            adjusted_indices = ragged_where_op.where(
                is_out_of_bounds,
                x=array_ops.zeros_like(indices),
                y=shifted_indices,
            )
            return array_ops.batch_gather(params=padded_params,
                                          indices=adjusted_indices,
                                          name=name)
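
The pad-and-shift trick used above can be reproduced by hand on a dense example with public ops; this is only an illustrative sketch ('FOO' is the default value, and `batch_dims=1` in `tf.gather` stands in for `batch_gather`):

import tensorflow as tf

# Prepend the default value as column 0, shift all indices by one, and route
# out-of-bounds indices to position 0, which now holds the default.
params = tf.constant([['a', 'b', 'c'], ['d', 'e', 'f']])
indices = tf.constant([[1, -1], [0, 5]])
padded = tf.concat([tf.constant('FOO', shape=[2, 1]), params], axis=1)
oob = (indices < 0) | (indices >= tf.shape(params)[1])
adjusted = tf.where(oob, tf.zeros_like(indices), indices + 1)
print(tf.gather(padded, adjusted, batch_dims=1))  # [['b', 'FOO'], ['d', 'FOO']]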
Example #7
 def testRaggedBroadcastTo(self, x, dim_sizes, expected):
     shape = RaggedTensorDynamicShape.from_dim_sizes(dim_sizes)
     result = ragged_tensor_shape.broadcast_to(x, shape)
     self.assertEqual(getattr(result, 'ragged_rank', 0),
                      getattr(expected, 'ragged_rank', 0))
     self.assertAllEqual(result, expected)
Example #8
 def testRaggedBroadcastTo(self, x, dim_sizes, expected):
   shape = RaggedTensorDynamicShape.from_dim_sizes(dim_sizes)
   result = ragged_tensor_shape.broadcast_to(x, shape)
   self.assertEqual(
       getattr(result, 'ragged_rank', 0), getattr(expected, 'ragged_rank', 0))
   self.assertRaggedEqual(result, expected)
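
For reference, a standalone sketch of the kind of case these two parameterized tests exercise, calling the internal `ragged_tensor_shape` module directly (not a public API; the import path and behavior are assumed from the snippets above and may differ between TF versions):

import tensorflow as tf
from tensorflow.python.ops.ragged import ragged_tensor_shape

# A rank-2 ragged shape with three rows of lengths 3, 0 and 2.
shape = ragged_tensor_shape.RaggedTensorDynamicShape.from_dim_sizes([3, [3, 0, 2]])

# Broadcast a scalar to that shape.
print(ragged_tensor_shape.broadcast_to(tf.constant(10), shape))
# [[10, 10, 10], [], [10, 10]]

# Broadcast a dense [3, 1] tensor to the same shape.
print(ragged_tensor_shape.broadcast_to(tf.constant([[1], [2], [3]]), shape))
# [[1, 1, 1], [], [3, 3]]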
Example #9
def _broadcast_elementwise_args(elementwise_args):
  """Broadcasts the values of `elementwise_args` to have compatible shapes.

  Args:
    elementwise_args: A dictionary whose values are potentially ragged tensors.

  Returns:
    A tuple `(broadcast_args, broadcast_splits, checks)` where:

    * `broadcast_args` is a dictionary with the same keys as
      `elementwise_args`, mapping to broadcasted tensors.
    * `broadcast_splits` is the broadcasted nested row splits.
    * `checks` is a possibly empty tuple of assertion operations that should
      be added as control dependencies.

  Raises:
    ValueError: If broadcasting fails.
  """
  # No elementwise arguments were used: nothing to do!
  if not elementwise_args:
    return elementwise_args, (), ()

  # A single elementwise argument was used: no broadcasting necessary.
  if len(elementwise_args) == 1:
    arg = list(elementwise_args.values())[0]
    if ragged_tensor.is_ragged(arg):
      return elementwise_args, arg.nested_row_splits, ()
    else:
      return elementwise_args, (), ()

  # Multiple elementwise arguments.
  else:
    is_ragged = [ragged_tensor.is_ragged(t) for t in elementwise_args.values()]
    if not any(is_ragged):
      return elementwise_args, (), ()

    # If we have a single ragged tensor plus a set of scalars, then we can
    # rely on the underlying elementwise op to do broadcasting.
    if (sum(is_ragged) == 1 and
        all((ragged_tensor.is_ragged(t) or t.shape.ndims == 0)
            for t in elementwise_args.values())):
      nested_splits_lists = [
          t.nested_row_splits
          for t in elementwise_args.values()
          if ragged_tensor.is_ragged(t)][0]
      return elementwise_args, nested_splits_lists, ()

    else:
      # Get the shapes of all the elementwise arguments.
      shapes = [ragged_tensor_shape.RaggedTensorDynamicShape.from_tensor(t)
                for t in elementwise_args.values()]

      # Broadcast the shapes to all have the same rank (the max rank).
      ranks = [t.shape.ndims for t in elementwise_args.values()]
      if any(rank is None for rank in ranks):
        raise ValueError('Unable to broadcast: unknown rank')
      broadcast_rank = max(ranks)
      shapes = [shape.broadcast_to_rank(broadcast_rank) for shape in shapes]

      # For each dimension, broadcast the shapes to be compatible.
      for axis in range(broadcast_rank):
        # For each i, broadcast shape[i+1] to be compatible with shape[i]; and
        # then finally broadcast shape[0] to be compatible with shape[-1].
        for i in range(len(shapes)):
          j = (i + 1) % len(shapes)
          dim_size = shapes[i].dimension_size(axis)
          shapes[j] = shapes[j].broadcast_dimension(axis, dim_size)
      broadcast_shape = shapes[0]

      # Broadcast every elementwise arg to the shape that we calculated.
      elementwise_args = dict([
          (key, ragged_tensor_shape.broadcast_to(t, broadcast_shape, False))
          for (key, t) in elementwise_args.items()])
      nested_splits_lists = list(elementwise_args.values())[0].nested_row_splits
      return elementwise_args, nested_splits_lists, ()
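
A sketch of an op with more than one elementwise argument that relies on this helper, through the public API (assuming ragged dispatch covers `tf.clip_by_value`; output illustrative):

import tensorflow as tf

# One ragged tensor plus scalar bounds: the 'single ragged tensor plus scalars'
# fast path above applies, so no row splits need to be broadcast.
rt = tf.ragged.constant([[1, 9, 3], [12]])
print(tf.clip_by_value(rt, clip_value_min=2, clip_value_max=10))  # [[2, 9, 3], [10]]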