def _elementwise_where_v2(condition, x, y):
  """Ragged version of tf.where_v2(condition, x, y).

  Unless all three static shapes are fully defined and identical, broadcasts
  `condition`, `x`, and `y` to a common `RaggedTensorDynamicShape` first.
  Then applies `array_ops.where_v2`: directly when none of the (broadcast)
  values is a `RaggedTensor`, otherwise via
  `ragged_functional_ops.map_flat_values`.

  Args:
    condition: Selection condition (Tensor or RaggedTensor); presumably
      boolean -- TODO confirm against callers.
    x: Values chosen where `condition` is true (Tensor or RaggedTensor).
    y: Values chosen where `condition` is false (Tensor or RaggedTensor).

  Returns:
    The result of `array_ops.where_v2` if no input is ragged after
    broadcasting; otherwise the result of `map_flat_values(where_v2, ...)`.
  """
  # Broadcast x, y, and condition to have the same shape.
  if not (condition.shape.is_fully_defined() and x.shape.is_fully_defined() and
          y.shape.is_fully_defined() and x.shape == y.shape and
          condition.shape == x.shape):
    # Build all three dynamic shapes with a single dimension-size dtype (the
    # row_splits dtype of the first ragged input) so broadcasting never mixes
    # dtypes, e.g. an int32-row_splits RaggedTensor with a dense Tensor whose
    # shape would otherwise be built without regard to that dtype.  With no
    # ragged input, fall back to from_tensor's own default.
    ragged_inputs = [t for t in (condition, x, y)
                     if isinstance(t, ragged_tensor.RaggedTensor)]
    shape_kwargs = {}
    if ragged_inputs:
      shape_kwargs['dim_size_dtype'] = ragged_inputs[0].row_splits.dtype
    shape_c = ragged_tensor_shape.RaggedTensorDynamicShape.from_tensor(
        condition, **shape_kwargs)
    shape_x = ragged_tensor_shape.RaggedTensorDynamicShape.from_tensor(
        x, **shape_kwargs)
    shape_y = ragged_tensor_shape.RaggedTensorDynamicShape.from_tensor(
        y, **shape_kwargs)
    shape = ragged_tensor_shape.broadcast_dynamic_shape(
        shape_c, ragged_tensor_shape.broadcast_dynamic_shape(shape_x, shape_y))
    condition = ragged_tensor_shape.broadcast_to(condition, shape)
    x = ragged_tensor_shape.broadcast_to(x, shape)
    y = ragged_tensor_shape.broadcast_to(y, shape)

  condition_is_ragged = isinstance(condition, ragged_tensor.RaggedTensor)
  x_is_ragged = isinstance(x, ragged_tensor.RaggedTensor)
  y_is_ragged = isinstance(y, ragged_tensor.RaggedTensor)
  if not (condition_is_ragged or x_is_ragged or y_is_ragged):
    return array_ops.where_v2(condition, x, y)

  return ragged_functional_ops.map_flat_values(array_ops.where_v2, condition,
                                               x, y)
def testBroadcastDynamicShape(self, x_dims, y_dims, expected_dims):
  """Checks broadcast_dynamic_shape yields `expected_dims` in both arg orders."""
  x_shape, y_shape, expected = (
      RaggedTensorDynamicShape.from_dim_sizes(dims)
      for dims in (x_dims, y_dims, expected_dims))

  # Compute the broadcast in both operand orders before checking either one.
  results = [
      ragged_tensor_shape.broadcast_dynamic_shape(lhs, rhs)
      for lhs, rhs in ((x_shape, y_shape), (y_shape, x_shape))
  ]
  for result in results:
    self.assertShapeEq(expected, result)
def handle(self, args, kwargs):
  """Applies `self._original_op` when at least one binary operand is ragged.

  The two operands come from the first two positional args, or from the
  keyword args named `self._x` / `self._y`.  Any remaining args and kwargs
  are forwarded to `self._original_op`.

  Args:
    args: Positional arguments of the intercepted op call.
    kwargs: Keyword arguments of the intercepted op call.  The dict is copied
      before operands are popped from it, so the caller's dict is untouched.

  Returns:
    `self.NOT_SUPPORTED` if neither operand is ragged, or if converting the
    operands raises TypeError/ValueError.  Otherwise, the op applied to the
    operands' flat values, rewrapped with `with_flat_values` of the
    (post-broadcast) ragged operand.
  """
  # Extract the binary args.
  if len(args) > 1:
    x = args[0]
    y = args[1]
    args = args[2:]
  elif args:
    kwargs = kwargs.copy()
    x = args[0]
    y = kwargs.pop(self._y, None)
    args = args[1:]
  else:
    kwargs = kwargs.copy()
    x = kwargs.pop(self._x, None)
    y = kwargs.pop(self._y, None)

  # Bail if we don't have at least one ragged argument.
  x_is_ragged = ragged_tensor.is_ragged(x)
  y_is_ragged = ragged_tensor.is_ragged(y)
  if not (x_is_ragged or y_is_ragged):
    return self.NOT_SUPPORTED

  # Convert args to tensors.  Bail if conversion fails.
  try:
    x = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        x, name=self._x,
        preferred_dtype=(y.dtype if y_is_ragged else None))
    y = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        y, name=self._y,
        preferred_dtype=(x.dtype if x_is_ragged else None))
  except (TypeError, ValueError):
    return self.NOT_SUPPORTED

  if x_is_ragged and y_is_ragged:
    x, y = ragged_tensor.match_row_splits_dtypes(x, y)

  # Broadcast when both are ragged, or when the dense operand's rank reaches
  # the ragged operand's flat_values rank (so it cannot simply be applied
  # elementwise to flat_values).
  if ((x_is_ragged and y_is_ragged) or
      (x_is_ragged and x.flat_values.shape.ndims <= y.shape.ndims) or
      (y_is_ragged and y.flat_values.shape.ndims <= x.shape.ndims)):
    # Build both shapes with the ragged operand's row_splits dtype (identical
    # for both if both are ragged, thanks to match_row_splits_dtypes).
    # Otherwise a dense operand's shape is built without regard to that dtype,
    # and broadcasting an int32-row_splits operand against it mixes dtypes.
    if x_is_ragged:
      dim_size_dtype = x.row_splits.dtype
    else:
      dim_size_dtype = y.row_splits.dtype
    shape_x = ragged_tensor_shape.RaggedTensorDynamicShape.from_tensor(
        x, dim_size_dtype=dim_size_dtype)
    shape_y = ragged_tensor_shape.RaggedTensorDynamicShape.from_tensor(
        y, dim_size_dtype=dim_size_dtype)
    bcast_shape = ragged_tensor_shape.broadcast_dynamic_shape(
        shape_x, shape_y)
    x = ragged_tensor_shape.broadcast_to(
        x, bcast_shape, broadcast_inner_dimensions=False)
    y = ragged_tensor_shape.broadcast_to(
        y, bcast_shape, broadcast_inner_dimensions=False)

  x_values = x.flat_values if ragged_tensor.is_ragged(x) else x
  y_values = y.flat_values if ragged_tensor.is_ragged(y) else y
  mapped_values = self._original_op(x_values, y_values, *args, **kwargs)
  if ragged_tensor.is_ragged(x):
    return x.with_flat_values(mapped_values)
  else:
    return y.with_flat_values(mapped_values)
def handle(self, args, kwargs):
  """Applies `self._original_op` when at least one binary operand is ragged.

  The two operands come from the first two positional args, or from the
  keyword args named `self._x` / `self._y`.  Any remaining args and kwargs
  are forwarded to `self._original_op`.  Only the non-ragged operand is
  converted with `ops.convert_to_tensor`; a ragged operand is used as given.

  Args:
    args: Positional arguments of the intercepted op call.
    kwargs: Keyword arguments of the intercepted op call.  The dict is copied
      before operands are popped from it, so the caller's dict is untouched.

  Returns:
    `self.NOT_SUPPORTED` if neither operand is ragged, or if converting the
    dense operand raises TypeError/ValueError.  Otherwise, the op applied to
    the operands' flat values, rewrapped with `with_flat_values` of the
    (post-broadcast) ragged operand.
  """
  # Extract the binary args.
  if len(args) > 1:
    x = args[0]
    y = args[1]
    args = args[2:]
  elif args:
    kwargs = kwargs.copy()
    x = args[0]
    y = kwargs.pop(self._y, None)
    args = args[1:]
  else:
    kwargs = kwargs.copy()
    x = kwargs.pop(self._x, None)
    y = kwargs.pop(self._y, None)

  # Bail if we don't have at least one ragged argument.
  x_is_ragged = ragged_tensor.is_ragged(x)
  y_is_ragged = ragged_tensor.is_ragged(y)
  if not (x_is_ragged or y_is_ragged):
    return self.NOT_SUPPORTED

  # Convert args to tensors.  Bail if conversion fails.  At least one operand
  # is ragged here, so the other one's dtype is always available.
  try:
    if not x_is_ragged:
      x = ops.convert_to_tensor(x, name=self._x, preferred_dtype=y.dtype)
    if not y_is_ragged:
      y = ops.convert_to_tensor(y, name=self._y, preferred_dtype=x.dtype)
  except (TypeError, ValueError):
    return self.NOT_SUPPORTED

  if x_is_ragged and y_is_ragged:
    x, y = ragged_tensor.match_row_splits_dtypes(x, y)

  # Broadcast when both are ragged, or when the dense operand's rank reaches
  # the ragged operand's flat_values rank (so it cannot simply be applied
  # elementwise to flat_values).
  if ((x_is_ragged and y_is_ragged) or
      (x_is_ragged and x.flat_values.shape.ndims <= y.shape.ndims) or
      (y_is_ragged and y.flat_values.shape.ndims <= x.shape.ndims)):
    # Build both shapes with the ragged operand's row_splits dtype (identical
    # for both if both are ragged, thanks to match_row_splits_dtypes).
    # Otherwise a dense operand's shape is built without regard to that dtype,
    # and broadcasting an int32-row_splits operand against it mixes dtypes.
    if x_is_ragged:
      dim_size_dtype = x.row_splits.dtype
    else:
      dim_size_dtype = y.row_splits.dtype
    shape_x = ragged_tensor_shape.RaggedTensorDynamicShape.from_tensor(
        x, dim_size_dtype=dim_size_dtype)
    shape_y = ragged_tensor_shape.RaggedTensorDynamicShape.from_tensor(
        y, dim_size_dtype=dim_size_dtype)
    bcast_shape = ragged_tensor_shape.broadcast_dynamic_shape(
        shape_x, shape_y)
    x = ragged_tensor_shape.broadcast_to(
        x, bcast_shape, broadcast_inner_dimensions=False)
    y = ragged_tensor_shape.broadcast_to(
        y, bcast_shape, broadcast_inner_dimensions=False)

  x_values = x.flat_values if ragged_tensor.is_ragged(x) else x
  y_values = y.flat_values if ragged_tensor.is_ragged(y) else y
  mapped_values = self._original_op(x_values, y_values, *args, **kwargs)
  if ragged_tensor.is_ragged(x):
    return x.with_flat_values(mapped_values)
  else:
    return y.with_flat_values(mapped_values)
def ragged_binary_elementwise_op(op, x, y):
  """Binary elementwise api handler for RaggedTensors.

  Converts `x` and `y` to Tensors/RaggedTensors and broadcasts them against
  each other when needed.  Then applies `op` to their flat values and
  rewraps the result with the row partitioning of the ragged operand.

  Args:
    op: Binary elementwise op, called as `op(x_values, y_values)`.
    x: First operand (Tensor, RaggedTensor, or a value convertible to one).
    y: Second operand (Tensor, RaggedTensor, or a value convertible to one).

  Returns:
    - If neither operand is ragged: `op(x, y)` on the converted operands.
    - If `op` returns a Python `bool` (special case for tensor_equals): that
      bool, unchanged.
    - Otherwise: `op` applied to the flat values, rewrapped with
      `with_flat_values` of the (post-broadcast) ragged operand.
  """
  x_is_ragged = ragged_tensor.is_ragged(x)
  y_is_ragged = ragged_tensor.is_ragged(y)

  # Convert args to tensors.
  x = ragged_tensor.convert_to_tensor_or_ragged_tensor(
      x, preferred_dtype=(y.dtype if y_is_ragged else None))
  y = ragged_tensor.convert_to_tensor_or_ragged_tensor(
      y, preferred_dtype=x.dtype)

  if not (x_is_ragged or y_is_ragged):
    # Nothing to unwrap or rewrap.  Falling through would end in
    # `y.with_flat_values(...)` on a plain Tensor, so apply `op` directly.
    return op(x, y)

  if x_is_ragged and y_is_ragged:
    x, y = ragged_tensor.match_row_splits_dtypes(x, y)

  # Perform broadcasting, when appropriate.
  if ((x_is_ragged and y_is_ragged) or
      (x_is_ragged and x.flat_values.shape.ndims <= y.shape.ndims) or
      (y_is_ragged and y.flat_values.shape.ndims <= x.shape.ndims)):
    # If both x and y are ragged, they must have the same row_splits_dtype now.
    if x_is_ragged:
      dim_size_dtype = x.row_splits.dtype
    else:
      dim_size_dtype = y.row_splits.dtype

    shape_x = ragged_tensor_shape.RaggedTensorDynamicShape.from_tensor(
        x, dim_size_dtype=dim_size_dtype)
    shape_y = ragged_tensor_shape.RaggedTensorDynamicShape.from_tensor(
        y, dim_size_dtype=dim_size_dtype)
    bcast_shape = ragged_tensor_shape.broadcast_dynamic_shape(
        shape_x, shape_y)
    x = ragged_tensor_shape.broadcast_to(
        x, bcast_shape, broadcast_inner_dimensions=False)
    y = ragged_tensor_shape.broadcast_to(
        y, bcast_shape, broadcast_inner_dimensions=False)

  x_values = x.flat_values if ragged_tensor.is_ragged(x) else x
  y_values = y.flat_values if ragged_tensor.is_ragged(y) else y
  mapped_values = op(x_values, y_values)
  if isinstance(mapped_values, bool):
    return mapped_values  # Special case for tensor_equals.
  if ragged_tensor.is_ragged(x):
    return x.with_flat_values(mapped_values)
  else:
    return y.with_flat_values(mapped_values)