Example #1
0
def _elementwise_where_v2(condition, x, y):
    """Ragged version of tf.where_v2(condition, x, y)."""
    # If all three shapes are statically known, fully defined, and equal,
    # no broadcasting is required; otherwise broadcast to a common shape.
    shapes_match_statically = (
        condition.shape.is_fully_defined() and x.shape.is_fully_defined()
        and y.shape.is_fully_defined() and x.shape == y.shape
        and condition.shape == x.shape)
    if not shapes_match_statically:
        from_tensor = ragged_tensor_shape.RaggedTensorDynamicShape.from_tensor
        common_shape = ragged_tensor_shape.broadcast_dynamic_shape(
            from_tensor(condition),
            ragged_tensor_shape.broadcast_dynamic_shape(
                from_tensor(x), from_tensor(y)))
        condition = ragged_tensor_shape.broadcast_to(condition, common_shape)
        x = ragged_tensor_shape.broadcast_to(x, common_shape)
        y = ragged_tensor_shape.broadcast_to(y, common_shape)

    # With no ragged operand left, defer to the dense implementation.
    if not any(
            isinstance(t, ragged_tensor.RaggedTensor)
            for t in (condition, x, y)):
        return array_ops.where_v2(condition, x, y)

    # Apply where_v2 to the flat values, preserving the ragged structure.
    return ragged_functional_ops.map_flat_values(array_ops.where_v2, condition,
                                                 x, y)
 def testBroadcastDynamicShape(self, x_dims, y_dims, expected_dims):
     """Checks broadcast_dynamic_shape and verifies it is order-independent."""
     shapes = [
         RaggedTensorDynamicShape.from_dim_sizes(dims)
         for dims in (x_dims, y_dims, expected_dims)
     ]
     x_shape, y_shape, expected = shapes
     # Broadcasting must be commutative, so try both argument orders.
     for lhs, rhs in ((x_shape, y_shape), (y_shape, x_shape)):
         self.assertShapeEq(
             expected, ragged_tensor_shape.broadcast_dynamic_shape(lhs, rhs))
 def testBroadcastDynamicShape(self, x_dims, y_dims, expected_dims):
   x_shape = RaggedTensorDynamicShape.from_dim_sizes(x_dims)
   y_shape = RaggedTensorDynamicShape.from_dim_sizes(y_dims)
   expected = RaggedTensorDynamicShape.from_dim_sizes(expected_dims)
   result1 = ragged_tensor_shape.broadcast_dynamic_shape(x_shape, y_shape)
   result2 = ragged_tensor_shape.broadcast_dynamic_shape(y_shape, x_shape)
   self.assertShapeEq(expected, result1)
   self.assertShapeEq(expected, result2)
Example #4
0
    def handle(self, args, kwargs):
        """Dispatches a binary elementwise op that may have ragged operands."""
        # Pull the two binary operands out of args/kwargs.
        if len(args) > 1:
            x, y = args[0], args[1]
            args = args[2:]
        elif args:
            kwargs = kwargs.copy()
            x = args[0]
            y = kwargs.pop(self._y, None)
            args = args[1:]
        else:
            kwargs = kwargs.copy()
            x = kwargs.pop(self._x, None)
            y = kwargs.pop(self._y, None)

        x_is_ragged = ragged_tensor.is_ragged(x)
        y_is_ragged = ragged_tensor.is_ragged(y)
        # This handler only applies when at least one operand is ragged.
        if not x_is_ragged and not y_is_ragged:
            return self.NOT_SUPPORTED

        # Convert both operands, steering dtype inference toward the ragged
        # operand's dtype.  Bail if conversion fails.
        try:
            x = ragged_tensor.convert_to_tensor_or_ragged_tensor(
                x,
                name=self._x,
                preferred_dtype=(y.dtype if y_is_ragged else None))
            y = ragged_tensor.convert_to_tensor_or_ragged_tensor(
                y,
                name=self._y,
                preferred_dtype=(x.dtype if x_is_ragged else None))
        except (TypeError, ValueError):
            return self.NOT_SUPPORTED

        if x_is_ragged and y_is_ragged:
            x, y = ragged_tensor.match_row_splits_dtypes(x, y)

        # Broadcast when both operands are ragged, or when the ragged
        # operand's flat values do not out-rank the other operand.
        needs_broadcast = (
            (x_is_ragged and y_is_ragged)
            or (x_is_ragged and x.flat_values.shape.ndims <= y.shape.ndims)
            or (y_is_ragged and y.flat_values.shape.ndims <= x.shape.ndims))
        if needs_broadcast:
            bcast_shape = ragged_tensor_shape.broadcast_dynamic_shape(
                ragged_tensor_shape.RaggedTensorDynamicShape.from_tensor(x),
                ragged_tensor_shape.RaggedTensorDynamicShape.from_tensor(y))
            x = ragged_tensor_shape.broadcast_to(
                x, bcast_shape, broadcast_inner_dimensions=False)
            y = ragged_tensor_shape.broadcast_to(
                y, bcast_shape, broadcast_inner_dimensions=False)

        # Apply the wrapped op to the flat values, then restore the ragged
        # structure from whichever operand is ragged.
        x_values = x.flat_values if ragged_tensor.is_ragged(x) else x
        y_values = y.flat_values if ragged_tensor.is_ragged(y) else y
        mapped_values = self._original_op(x_values, y_values, *args, **kwargs)
        structure_source = x if ragged_tensor.is_ragged(x) else y
        return structure_source.with_flat_values(mapped_values)
  def handle(self, args, kwargs):
    """Handles a binary elementwise op call that may involve ragged tensors."""
    # Split the two binary operands off of args/kwargs.
    if len(args) > 1:
      x, y, args = args[0], args[1], args[2:]
    elif args:
      kwargs = kwargs.copy()
      x, args = args[0], args[1:]
      y = kwargs.pop(self._y, None)
    else:
      kwargs = kwargs.copy()
      x = kwargs.pop(self._x, None)
      y = kwargs.pop(self._y, None)

    x_is_ragged = ragged_tensor.is_ragged(x)
    y_is_ragged = ragged_tensor.is_ragged(y)
    # Only dispatch here when at least one operand is ragged.
    if not x_is_ragged and not y_is_ragged:
      return self.NOT_SUPPORTED

    # Convert any dense operand to a tensor; bail on conversion failure.
    try:
      if not x_is_ragged:
        x = ops.convert_to_tensor(x, name=self._x, preferred_dtype=y.dtype)
      if not y_is_ragged:
        y = ops.convert_to_tensor(y, name=self._y, preferred_dtype=x.dtype)
    except (TypeError, ValueError):
      return self.NOT_SUPPORTED

    if x_is_ragged and y_is_ragged:
      x, y = ragged_tensor.match_row_splits_dtypes(x, y)

    # Broadcast when both operands are ragged, or when the ragged operand's
    # flat values do not out-rank the other operand.
    if ((x_is_ragged and y_is_ragged) or
        (x_is_ragged and x.flat_values.shape.ndims <= y.shape.ndims) or
        (y_is_ragged and y.flat_values.shape.ndims <= x.shape.ndims)):
      joint_shape = ragged_tensor_shape.broadcast_dynamic_shape(
          ragged_tensor_shape.RaggedTensorDynamicShape.from_tensor(x),
          ragged_tensor_shape.RaggedTensorDynamicShape.from_tensor(y))
      x = ragged_tensor_shape.broadcast_to(
          x, joint_shape, broadcast_inner_dimensions=False)
      y = ragged_tensor_shape.broadcast_to(
          y, joint_shape, broadcast_inner_dimensions=False)

    # Run the wrapped op on flat values and re-wrap in the ragged structure.
    x_values = x.flat_values if ragged_tensor.is_ragged(x) else x
    y_values = y.flat_values if ragged_tensor.is_ragged(y) else y
    mapped = self._original_op(x_values, y_values, *args, **kwargs)
    if ragged_tensor.is_ragged(x):
      return x.with_flat_values(mapped)
    return y.with_flat_values(mapped)
Example #6
0
def ragged_binary_elementwise_op(op, x, y):
    """Binary elementwise api handler for RaggedTensors."""
    x_is_ragged = ragged_tensor.is_ragged(x)
    y_is_ragged = ragged_tensor.is_ragged(y)

    # Convert operands, steering x's dtype inference toward y's dtype when y
    # is ragged, and y's toward the (now converted) x's dtype.
    x = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        x, preferred_dtype=(y.dtype if y_is_ragged else None))
    y = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        y, preferred_dtype=x.dtype)

    if x_is_ragged and y_is_ragged:
        x, y = ragged_tensor.match_row_splits_dtypes(x, y)

    # Perform broadcasting, when appropriate: both operands ragged, or the
    # ragged operand's flat values do not out-rank the other operand.
    needs_broadcast = (
        (x_is_ragged and y_is_ragged)
        or (x_is_ragged and x.flat_values.shape.ndims <= y.shape.ndims)
        or (y_is_ragged and y.flat_values.shape.ndims <= x.shape.ndims))
    if needs_broadcast:
        # If both x and y are ragged, their row_splits dtypes match by now,
        # so either ragged operand can supply the dimension-size dtype.
        dim_size_dtype = (x if x_is_ragged else y).row_splits.dtype
        shape_of = ragged_tensor_shape.RaggedTensorDynamicShape.from_tensor
        bcast_shape = ragged_tensor_shape.broadcast_dynamic_shape(
            shape_of(x, dim_size_dtype=dim_size_dtype),
            shape_of(y, dim_size_dtype=dim_size_dtype))
        x = ragged_tensor_shape.broadcast_to(
            x, bcast_shape, broadcast_inner_dimensions=False)
        y = ragged_tensor_shape.broadcast_to(
            y, bcast_shape, broadcast_inner_dimensions=False)

    # Apply op to the flat values, then restore the ragged structure.
    x_values = x.flat_values if ragged_tensor.is_ragged(x) else x
    y_values = y.flat_values if ragged_tensor.is_ragged(y) else y
    mapped_values = op(x_values, y_values)
    if isinstance(mapped_values, bool):
        return mapped_values  # Special case for tensor_equals.
    if ragged_tensor.is_ragged(x):
        return x.with_flat_values(mapped_values)
    else:
        return y.with_flat_values(mapped_values)