def _broadcast_elementwise_args(elementwise_args):
  """Broadcasts the values of `elementwise_args` to have compatible shapes.

  Args:
    elementwise_args: A dictionary whose values are potentially ragged
      tensors.

  Returns:
    A tuple `(broadcast_args, broadcast_splits, checks)` where:

    * `broadcast_args` is a dictionary with the same keys as
      `elementwise_args`, mapping to broadcasted tensors.
    * `broadcast_splits` is the broadcasted nested row splits.
    * `checks` is a possibly empty tuple of assertion operations that should
      be added as control dependencies.

  Raises:
    ValueError: If broadcasting fails.
  """
  # No elementwise arguments were used: nothing to do!
  if not elementwise_args:
    return elementwise_args, (), ()

  # A single elementwise argument was used: no broadcasting necessary.
  if len(elementwise_args) == 1:
    arg = list(elementwise_args.values())[0]
    if ragged_tensor.is_ragged(arg):
      return elementwise_args, arg.nested_row_splits, ()
    else:
      return elementwise_args, (), ()

  # Multiple elementwise arguments.
  else:
    is_ragged = [ragged_tensor.is_ragged(t) for t in elementwise_args.values()]
    if not any(is_ragged):
      return elementwise_args, (), ()

    # Support limited broadcasting (namely, scalar + ragged).  Full
    # broadcasting support will be added later.
    if all((ragged_tensor.is_ragged(t) or t.shape.ndims == 0)
           for t in elementwise_args.values()):
      nested_splits_lists = [
          t.nested_row_splits
          for t in elementwise_args.values()
          if ragged_tensor.is_ragged(t)
      ]
      if len(nested_splits_lists) == 1:
        checks = ()
      else:
        if any(t.shape.ndims is None for t in elementwise_args.values()):
          raise ValueError('Ragged elementwise ops require that rank (number '
                           'of dimensions) be statically known.')
        if len(set(t.shape.ndims for t in elementwise_args.values())) != 1:
          raise ValueError('Ragged elementwise ops do not support '
                           'broadcasting yet')
        checks = ragged_util.assert_splits_match(nested_splits_lists)
      return (elementwise_args, nested_splits_lists[0], checks)
    else:
      raise ValueError('Ragged elementwise ops do not support broadcasting yet')
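# A quick sketch of the limited broadcasting described above, exercising the
# helper directly.  Illustrative only: the argument names ('x', 'y') are
# hypothetical, and the TF-internal modules the helper uses (ragged_tensor,
# ragged_util) are assumed to be in scope.
import tensorflow as tf
from tensorflow.python.ops.ragged import ragged_tensor

rt = ragged_tensor.convert_to_tensor_or_ragged_tensor(
    tf.ragged.constant([[1, 2, 3], [4]]))
scalar = tf.constant(10)  # rank 0: the only non-ragged shape supported here

# Scalar + ragged passes through; the ragged argument's splits are returned.
args, splits, checks = _broadcast_elementwise_args({'x': rt, 'y': scalar})

# Ragged + non-scalar dense is rejected:
# _broadcast_elementwise_args({'x': rt, 'y': tf.constant([[1, 2]])})
# -> ValueError: Ragged elementwise ops do not support broadcasting yet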
def map_flat_values(op, *args, **kwargs):
  """Applies `op` to the values of one or more RaggedTensors.

  Replaces any `RaggedTensor` in `args` or `kwargs` with its `flat_values`
  tensor, and then calls `op`.  Returns a `RaggedTensor` that is constructed
  from the input `RaggedTensor`s' `nested_row_splits` and the value returned by
  the `op`.

  If the input arguments contain multiple `RaggedTensor`s, then they must have
  identical `nested_row_splits`.

  Examples:

  ```python
  >>> rt = ragged.constant([[1, 2, 3], [], [4, 5], [6]])
  >>> ragged.map_flat_values(tf.ones_like, rt).eval().tolist()
  [[1, 1, 1], [], [1, 1], [1]]
  >>> ragged.map_flat_values(tf.multiply, rt, rt).eval().tolist()
  [[1, 4, 9], [], [16, 25], [36]]
  >>> ragged.map_flat_values(tf.add, rt, 5).eval().tolist()
  [[6, 7, 8], [], [9, 10], [11]]
  ```

  Args:
    op: The operation that should be applied to the RaggedTensor `flat_values`.
      `op` is typically an element-wise operation (such as math_ops.add), but
      any operation that preserves the size of the outermost dimension can be
      used.  I.e., `shape[0]` of the value returned by `op` must match
      `shape[0]` of the `RaggedTensor`s' `flat_values` tensors.
    *args: Arguments for `op`.
    **kwargs: Keyword arguments for `op`.

  Returns:
    A `RaggedTensor` whose `ragged_rank` matches the `ragged_rank` of all
    input `RaggedTensor`s.

  Raises:
    ValueError: If args contains no `RaggedTensors`, or if the `nested_splits`
      of the input `RaggedTensor`s are not identical.
  """
  # Replace RaggedTensors with their values; and collect the splits tensors
  # from each RaggedTensor.
  nested_splits_lists = []
  inner_args = _replace_ragged_with_flat_values(args, nested_splits_lists)
  inner_kwargs = _replace_ragged_with_flat_values(kwargs, nested_splits_lists)
  if not nested_splits_lists:
    return op(*args, **kwargs)

  with ops.control_dependencies(
      ragged_util.assert_splits_match(nested_splits_lists)):
    # Delegate to op, and then compose the result from the transformed values
    # and the splits.
    return ragged_tensor.RaggedTensor.from_nested_row_splits(
        op(*inner_args, **inner_kwargs), nested_splits_lists[0])
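# Usage sketch (TF 2.x public API): multiple ragged inputs must share
# identical nested_row_splits, which is what assert_splits_match enforces.
import tensorflow as tf

a = tf.ragged.constant([[1, 2, 3], [], [4, 5], [6]])
b = tf.ragged.constant([[10, 20, 30], [], [40, 50], [60]])  # same partitions

tf.ragged.map_flat_values(tf.add, a, b)
# <tf.RaggedTensor [[11, 22, 33], [], [44, 55], [66]]>

# A ragged input with different row partitions fails the splits check:
c = tf.ragged.constant([[1], [2, 3], [4, 5], [6]])
# tf.ragged.map_flat_values(tf.add, a, c)  # raises: splits do not match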
def handle(self, args, kwargs):
  if args:
    x, args = args[0], args[1:]
  else:
    kwargs = kwargs.copy()
    x = kwargs.pop(self._x, None)
  if x is None:
    return self.NOT_SUPPORTED
  if self._arg_is_list:
    found_ragged = False
    for elt in x:
      if ragged_tensor.is_ragged(elt):
        found_ragged = True
      elif not _is_convertible_to_tensor(elt):
        return self.NOT_SUPPORTED
    if found_ragged:
      x = [
          ragged_tensor.convert_to_tensor_or_ragged_tensor(elt)
          if ragged_tensor.is_ragged(elt) else elt for elt in x
      ]
      x = ragged_tensor.match_row_splits_dtypes(*x)
      ragged_elts = [elt for elt in x if ragged_tensor.is_ragged(elt)]
      nested_splits_lists = [elt.nested_row_splits for elt in ragged_elts]
      flat_values = [
          elt.flat_values if ragged_tensor.is_ragged(elt) else elt
          for elt in x
      ]
      with ops.control_dependencies(
          ragged_util.assert_splits_match(nested_splits_lists)):
        return ragged_elts[0].with_flat_values(
            self._original_op(flat_values, *args, **kwargs))
    else:
      return self.NOT_SUPPORTED
  else:
    found_ragged = ragged_tensor.is_ragged(x)
    if found_ragged:
      x = ragged_tensor.convert_to_tensor_or_ragged_tensor(x, name=self._x)
      mapped_values = self._original_op(x.flat_values, *args, **kwargs)
      return x.with_flat_values(mapped_values)
    else:
      return self.NOT_SUPPORTED
def handle(self, args, kwargs):
  if args:
    x, args = args[0], args[1:]
  else:
    kwargs = kwargs.copy()
    x = kwargs.pop(self._x, None)
  if x is None:
    return self.NOT_SUPPORTED
  if self._arg_is_list:
    found_ragged = False
    for elt in x:
      if ragged_tensor.is_ragged(elt):
        found_ragged = True
      elif not _is_convertible_to_tensor(elt):
        return self.NOT_SUPPORTED
    if found_ragged:
      nested_splits_lists = [
          elt.nested_row_splits for elt in x if ragged_tensor.is_ragged(elt)
      ]
      inner_values = [
          elt.inner_values if ragged_tensor.is_ragged(elt) else elt
          for elt in x
      ]
      with ops.control_dependencies(
          ragged_util.assert_splits_match(nested_splits_lists)):
        return ragged_factory_ops.from_nested_row_splits(
            self._original_op(inner_values, *args, **kwargs),
            nested_splits_lists[0])
    else:
      return self.NOT_SUPPORTED
  else:
    found_ragged = ragged_tensor.is_ragged(x)
    if found_ragged:
      mapped_values = self._original_op(x.inner_values, *args, **kwargs)
      return x.with_inner_values(mapped_values)
    else:
      return self.NOT_SUPPORTED
def handle(self, args, kwargs):
  if args:
    x, args = args[0], args[1:]
  else:
    kwargs = kwargs.copy()
    x = kwargs.pop(self._x, None)
  if x is None:
    return self.NOT_SUPPORTED
  if self._arg_is_list:
    found_ragged = False
    for elt in x:
      if ragged_tensor.is_ragged(elt):
        found_ragged = True
      elif not _is_convertible_to_tensor(elt):
        return self.NOT_SUPPORTED
    if found_ragged:
      x = ragged_tensor.match_row_splits_dtypes(*x)
      nested_splits_lists = [
          elt.nested_row_splits for elt in x if ragged_tensor.is_ragged(elt)
      ]
      flat_values = [
          elt.flat_values if ragged_tensor.is_ragged(elt) else elt
          for elt in x
      ]
      with ops.control_dependencies(
          ragged_util.assert_splits_match(nested_splits_lists)):
        return ragged_tensor.RaggedTensor.from_nested_row_splits(
            self._original_op(flat_values, *args, **kwargs),
            nested_splits_lists[0], validate=False)
    else:
      return self.NOT_SUPPORTED
  else:
    found_ragged = ragged_tensor.is_ragged(x)
    if found_ragged:
      mapped_values = self._original_op(x.flat_values, *args, **kwargs)
      return x.with_flat_values(mapped_values)
    else:
      return self.NOT_SUPPORTED
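# The three `handle` variants above share one dispatch pattern: unwrap ragged
# arguments to their flat (or inner) values, apply the wrapped op, and re-wrap
# the result with the original row partitions.  A minimal standalone sketch of
# that pattern using only public TF 2.x API (`apply_unary_elementwise` is a
# hypothetical helper, not part of TensorFlow):
import tensorflow as tf

def apply_unary_elementwise(op, x):
  """Applies an elementwise `op` to `x`, preserving any ragged structure."""
  if isinstance(x, tf.RaggedTensor):
    # Unwrap, apply, re-wrap: the row partitions are reused unchanged.
    return x.with_flat_values(op(x.flat_values))
  return op(x)

rt = tf.ragged.constant([[1.0, 4.0], [9.0]])
apply_unary_elementwise(tf.sqrt, rt)
# <tf.RaggedTensor [[1.0, 2.0], [3.0]]>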
def boolean_mask(data, mask, name=None):
  """Applies a boolean mask to `data` without flattening the mask dimensions.

  Returns a potentially ragged tensor that is formed by retaining the elements
  in `data` where the corresponding value in `mask` is `True`.

  * `output[a1...aA, i, b1...bB] = data[a1...aA, j, b1...bB]`

     Where `j` is the `i`th `True` entry of `mask[a1...aA]`.

  Note that `output` preserves the mask dimensions `a1...aA`; this differs
  from `tf.boolean_mask`, which flattens those dimensions.

  Args:
    data: A potentially ragged tensor.
    mask: A potentially ragged boolean tensor.  `mask`'s shape must be a prefix
      of `data`'s shape.  `rank(mask)` must be known statically.
    name: A name prefix for the returned tensor (optional).

  Returns:
    A potentially ragged tensor that is formed by retaining the elements in
    `data` where the corresponding value in `mask` is `True`.

    * `rank(output) = rank(data)`.
    * `output.ragged_rank = max(data.ragged_rank, rank(mask) - 1)`.

  Raises:
    ValueError: if `rank(mask)` is not known statically; or if `mask.shape` is
      not a prefix of `data.shape`.

  #### Examples:

  >>> # Aliases for True & False so data and mask line up.
  >>> T, F = (True, False)

  >>> tf.ragged.boolean_mask(  # Mask a 2D Tensor.
  ...     data=[[1, 2, 3], [4, 5, 6], [7, 8, 9]],
  ...     mask=[[T, F, T], [F, F, F], [T, F, F]]).to_list()
  [[1, 3], [], [7]]

  >>> tf.ragged.boolean_mask(  # Mask a 2D RaggedTensor.
  ...     tf.ragged.constant([[1, 2, 3], [4], [5, 6]]),
  ...     tf.ragged.constant([[F, F, T], [F], [T, T]])).to_list()
  [[3], [], [5, 6]]

  >>> tf.ragged.boolean_mask(  # Mask rows of a 2D RaggedTensor.
  ...     tf.ragged.constant([[1, 2, 3], [4], [5, 6]]),
  ...     tf.ragged.constant([True, False, True])).to_list()
  [[1, 2, 3], [5, 6]]
  """
  with ops.name_scope(name, 'RaggedMask', [data, mask]):
    # Convert inputs to tensors.
    data = ragged_tensor.convert_to_tensor_or_ragged_tensor(data, name='data')
    mask = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        mask, dtypes.bool, name='mask')
    row_splits_dtype, (data, mask) = ragged_tensor.match_row_splits_dtypes(
        data, mask, return_dtype=True)

    # Get static rank of mask.
    if mask.shape.ndims is None:
      raise ValueError('mask.shape.ndims must be known statically.')
    elif mask.shape.ndims == 0:
      raise ValueError('mask cannot be scalar.')

    # If mask is ragged, then recurse with a non-ragged mask.
    if ragged_tensor.is_ragged(mask):
      if not ragged_tensor.is_ragged(data):
        data = ragged_tensor.RaggedTensor.from_tensor(
            data,
            ragged_rank=mask.ragged_rank,
            row_splits_dtype=mask.row_splits.dtype)
      # Check that mask.nested_row_splits is a prefix of
      # data.nested_row_splits.
      splits_list = [
          mask.nested_row_splits, data.nested_row_splits[:mask.ragged_rank]
      ]
      with ops.control_dependencies(
          ragged_util.assert_splits_match(splits_list)):
        # Strip off ragged `splits` until `mask` is non-ragged.  Keep the
        # splits that we strip off in `splits`, so we can add them back on
        # after we recursively mask the non-ragged data.
        splits = []
        while ragged_tensor.is_ragged(mask):
          if mask.shape.ndims > 2:
            splits.append(mask.row_splits)
          else:
            # Count the number of True mask values in each row to find the
            # lengths of the filtered rows; then convert to splits.
            int_mask = ragged_functional_ops.map_flat_values(
                math_ops.cast, mask, dtype=row_splits_dtype)
            masked_row_lengths = ragged_math_ops.reduce_sum(int_mask, axis=1)
            splits.append(ragged_util.lengths_to_splits(masked_row_lengths))
          mask = mask.values
          data = data.values

        # Recursively apply the nested non-ragged mask to the nested data.
        masked_values = boolean_mask(data, mask)

        # Add the ragged `splits` back to the result.
        masked_values = ragged_tensor.RaggedTensor.from_nested_row_splits(
            masked_values, splits, validate=False)

        return masked_values

    # If mask is non-ragged and has rank 1, and data is ragged, then build a
    # ragged tensor with the indicated rows.
    elif ragged_tensor.is_ragged(data) and mask.shape.ndims == 1:
      # Get the masked splits: first get the length of each row, then filter
      # out the rows that we are deleting, and convert that filtered set of
      # masks back to a splits tensor.
      lengths = data.row_lengths()
      masked_lengths = array_ops.boolean_mask(lengths, mask)
      masked_splits = ragged_util.lengths_to_splits(masked_lengths)

      # Get the masked values: first get row ids corresponding to each
      # value, then use tf.gather to build a boolean mask that's false for
      # values that come from rows that we are deleting, and use that mask to
      # construct the masked values tensor.
      segment_ids = segment_id_ops.row_splits_to_segment_ids(data.row_splits)
      segment_mask = array_ops.gather(mask, segment_ids)
      masked_values = boolean_mask(data.values, segment_mask)

      return ragged_tensor.RaggedTensor.from_row_splits(
          masked_values, masked_splits, validate=False)

    # If mask is non-ragged and has rank>1, then convert it to be ragged,
    # with a ragged rank matching data.
    if ragged_tensor.is_ragged(data):
      mask = ragged_tensor.RaggedTensor.from_tensor(
          mask,
          ragged_rank=min(data.ragged_rank, mask.shape.ndims - 1),
          row_splits_dtype=data.row_splits.dtype)
      return boolean_mask(data, mask)

    # Otherwise, data and mask are both `Tensor`s.
    else:
      # Apply `boolean_mask` to get the masked values.
      masked_values = array_ops.boolean_mask(data, mask)

      if mask.shape.ndims >= 2:
        # Add the innermost ragged dimension.  For each innermost cell, get
        # the number of values it contains.  Then flatten that to get a list
        # of cell lengths, and convert it to splits.  Finally, combine the
        # splits and values to get the innermost ragged tensor.
        masked_lengths = math_ops.count_nonzero(
            mask, axis=-1, dtype=row_splits_dtype)
        flattened_masked_lengths = array_ops.reshape(masked_lengths, [-1])
        masked_values = ragged_tensor.RaggedTensor.from_row_lengths(
            masked_values, flattened_masked_lengths, validate=False)

        # Wrap remaining ragged dimensions.
        if mask.shape.ndims > 2:
          mask_shape = array_ops.shape(mask, out_type=row_splits_dtype)
          split_size = math_ops.cumprod(mask_shape) + 1
          for dim in range(mask.shape.ndims - 3, -1, -1):
            elt_size = mask_shape[dim + 1]
            masked_splits = math_ops.range(split_size[dim]) * elt_size
            masked_values = ragged_tensor.RaggedTensor.from_row_splits(
                masked_values, masked_splits, validate=False)

      return masked_values
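# The ragged-mask recursion above also lets a ragged mask select from a dense
# tensor: `data` is first converted to a RaggedTensor whose splits must match
# the mask's.  Sketch using the public TF 2.x API:
import tensorflow as tf

data = tf.constant([[1, 2, 3], [4, 5, 6]])
mask = tf.ragged.constant([[True, False, True], [False, True, False]])

tf.ragged.boolean_mask(data, mask).to_list()
# [[1, 3], [5]]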
def _ragged_stack_concat_helper(rt_inputs, axis, stack_values):
  """Helper function to concatenate or stack ragged tensors.

  Args:
    rt_inputs: A list of RaggedTensors or Tensors to combine.
    axis: The axis along which to concatenate or stack.
    stack_values: A boolean -- if true, then stack values; otherwise,
      concatenate them.

  Returns:
    A RaggedTensor.

  Raises:
    ValueError: If rt_inputs is empty, or if axis is out of range.
  """
  # Validate parameters.
  if not rt_inputs:
    raise ValueError('rt_inputs may not be empty.')

  # Convert input tensors.
  rt_inputs = [
      ragged_tensor.convert_to_tensor_or_ragged_tensor(
          rt_input, name='rt_input') for rt_input in rt_inputs
  ]
  row_splits_dtype, rt_inputs = ragged_tensor.match_row_splits_dtypes(
      *rt_inputs, return_dtype=True)
  rt_inputs = list(rt_inputs)

  # Special case: if there's only one input, then return it as-is.
  if len(rt_inputs) == 1:
    if stack_values:
      return ragged_array_ops.expand_dims(rt_inputs[0], axis=axis)
    else:
      return rt_inputs[0]

  # Check the rank (number of dimensions) of the input tensors.
  ndims = None
  for rt in rt_inputs:
    if ndims is None:
      ndims = rt.shape.ndims
    else:
      rt.shape.assert_has_rank(ndims)

  out_ndims = ndims if (ndims is None or not stack_values) else ndims + 1
  axis = ragged_util.get_positive_axis(axis, out_ndims)

  # If all the inputs are Tensors, and we're combining the final dimension,
  # then we can delegate to the tf.stack/tf.concat operation, and return a
  # Tensor.
  if all(not ragged_tensor.is_ragged(rt) for rt in rt_inputs):
    if ndims is not None and (axis == out_ndims - 1 or axis == ndims - 1):
      if stack_values:
        return array_ops.stack(rt_inputs, axis)
      else:
        return array_ops.concat(rt_inputs, axis)

  # Convert any Tensor inputs to RaggedTensors.  This makes it
  # possible to concatenate Tensors and RaggedTensors together.
  for i in range(len(rt_inputs)):
    if not ragged_tensor.is_ragged(rt_inputs[i]):
      rt_inputs[i] = ragged_tensor.RaggedTensor.from_tensor(
          rt_inputs[i], ragged_rank=1, row_splits_dtype=row_splits_dtype)

  # Convert the input tensors to all have the same ragged_rank.
  ragged_rank = max(max(rt.ragged_rank for rt in rt_inputs), 1)
  rt_inputs = [
      _increase_ragged_rank_to(rt, ragged_rank, row_splits_dtype)
      for rt in rt_inputs
  ]

  if axis == 0:
    return _ragged_stack_concat_axis_0(rt_inputs, stack_values)
  elif axis == 1:
    return _ragged_stack_concat_axis_1(rt_inputs, stack_values)
  else:  # axis > 1: recurse.
    values = [rt.values for rt in rt_inputs]
    splits = [[rt_input.row_splits] for rt_input in rt_inputs]
    with ops.control_dependencies(ragged_util.assert_splits_match(splits)):
      return ragged_tensor.RaggedTensor.from_row_splits(
          _ragged_stack_concat_helper(values, axis - 1, stack_values),
          splits[0][0], validate=False)
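# Through the public API (TF 2.x), this helper is what lets tf.concat and
# tf.ragged.stack mix Tensors and RaggedTensors: dense inputs are promoted to
# ragged_rank 1 before combining.  Sketch:
import tensorflow as tf

rt = tf.ragged.constant([[1, 2], [3]])
dense = tf.constant([[4], [5]])  # shape (2, 1); promoted to ragged

tf.concat([rt, dense], axis=1).to_list()
# [[1, 2, 4], [3, 5]]

tf.ragged.stack([rt, dense], axis=0).to_list()
# [[[1, 2], [3]], [[4], [5]]]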
def ragged_assert_compatible_and_get_flat_values(values, mask=None):
  """If ragged, it checks the compatibility and then returns the flat_values.

     Note: If two tensors are dense, it does not check their compatibility.
     Note: Although two ragged tensors with different ragged ranks could have
           identical overall rank and dimension sizes and hence be compatible,
           we do not support those cases.

  Args:
     values: A list of potentially ragged tensors, all with the same
       ragged_rank.
     mask: A potentially ragged tensor with the same ragged_rank as the
       elements in `values`.

  Returns:
     A tuple in which the first element is the list of tensors and the second
     is the mask tensor. ([values], mask). The mask and the elements in
     `values` are equal to the flat_values of the input arguments (if they
     were ragged).
  """
  if isinstance(values, list):
    is_all_ragged = \
        all(isinstance(rt, ragged_tensor.RaggedTensor) for rt in values)
    is_any_ragged = \
        any(isinstance(rt, ragged_tensor.RaggedTensor) for rt in values)
  else:
    is_all_ragged = isinstance(values, ragged_tensor.RaggedTensor)
    is_any_ragged = is_all_ragged
  if (is_all_ragged and
      ((mask is None) or isinstance(mask, ragged_tensor.RaggedTensor))):
    to_be_stripped = False
    if not isinstance(values, list):
      values = [values]
      to_be_stripped = True

    # NOTE: we leave the flat_values compatibility to
    # tf.TensorShape `assert_is_compatible_with`
    # check if both dynamic dimensions are equal and then use the flat_values.
    nested_row_split_list = [rt.nested_row_splits for rt in values]
    assertion_list = ragged_util.assert_splits_match(nested_row_split_list)

    # If both are ragged, sample_weights also should be ragged with same dims.
    if isinstance(mask, ragged_tensor.RaggedTensor):
      assertion_list_for_mask = ragged_util.assert_splits_match(
          [nested_row_split_list[0], mask.nested_row_splits])
      tmp = control_flow_ops.with_dependencies(assertion_list_for_mask,
                                               mask.flat_values)
      mask = array_ops.expand_dims(tmp, -1)

    # values has at least 1 element.
    flat_values = []
    for value in values:
      tmp = control_flow_ops.with_dependencies(assertion_list,
                                               value.flat_values)
      flat_values.append(array_ops.expand_dims(tmp, -1))

    values = flat_values[0] if to_be_stripped else flat_values

  elif is_any_ragged:
    raise TypeError('One of the inputs does not have acceptable types.')
  # values are empty or values are not ragged and mask is ragged.
  elif isinstance(mask, ragged_tensor.RaggedTensor):
    raise TypeError('Ragged mask is not allowed with non-ragged inputs.')

  return values, mask
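# Call-pattern sketch, assuming the function above is in scope (it is a
# TF/Keras-internal utility, so its import path varies across releases):
import tensorflow as tf

y_true = tf.ragged.constant([[1.0, 2.0], [3.0]])
y_pred = tf.ragged.constant([[1.5, 2.5], [2.0]])  # same row partitions

values, mask = ragged_assert_compatible_and_get_flat_values(
    [y_true, y_pred], mask=None)
# `values` is a list of dense tensors of shape [3, 1] (each input's
# flat_values with an added inner axis); `mask` is still None.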
def __init__(self, shape, fields):
  """Creates a `StructuredTensor` from a dictionary of fields.

  Args:
    shape: A `TensorShape`: static information about the shape of the
      `StructuredTensor`.  Must have a known `rank`.
    fields: A dictionary mapping from string to `Tensor`, `RaggedTensor`, or
      `StructuredTensor`, providing the values for individual fields in each
      structure.  If `ndims > 0`, then every tensor in `fields` must have the
      same shape in the first `shape.rank` dimensions; and that shape must be
      compatible with `shape`.

  Returns:
    A `StructuredTensor`.
  """
  shape = tensor_shape.as_shape(shape)
  rank = shape.ndims
  if rank is None:
    raise ValueError("StructuredTensor's shape must have known rank.")
  if not isinstance(fields, dict):
    raise TypeError('fields must be a dictionary, got %s' %
                    type(fields).__name__)
  self._fields = {}
  with ops.name_scope(None, 'StructuredTensor', fields.values()):
    for (key, value) in fields.items():
      if not isinstance(key, str):
        raise TypeError('Unexpected type for key in `fields`: %r' % key)
      if not _FIELD_NAME_RE.match(key):
        raise ValueError('Field name %r is not currently allowed.' % key)
      if not isinstance(
          value, (ops.Tensor, ragged_tensor.RaggedTensor, StructuredTensor)):
        if ragged_tensor.is_ragged(value):
          value = ragged_tensor.convert_to_tensor_or_ragged_tensor(value)
        else:
          try:
            value = ops.convert_to_tensor(value)
          except (ValueError, TypeError):
            raise TypeError('Unexpected type for value in `fields`: %r' %
                            value)
      self._fields[key] = value

  # Check the static TensorShape for this StructuredTensor.
  self._static_shape = shape
  if rank > 0:
    for value in self._fields.values():
      self._static_shape = self._static_shape.merge_with(value.shape[:rank])

  self._nested_row_splits = []
  if rank > 1:
    # If any fields are ragged, then check that all row-splits match.
    shared_row_splits = []
    for field in self._fields.values():
      # TODO(edloper): A field shouldn't count as ragged if it has
      # uniform_row_length defined for all the dimensions in question.
      if isinstance(field, ragged_tensor.RaggedTensor):
        shared_row_splits.append(field.nested_row_splits[:rank - 1])
      elif isinstance(field, StructuredTensor) and field.ragged_rank > 0:
        shared_row_splits.append(field.nested_row_splits[:rank - 1])
    if shared_row_splits:
      if len(shared_row_splits) != len(self._fields):
        raise ValueError('Ragged StructuredTensor contains non-ragged fields')

      # Check if the splits are identical.  This should be the common case.
      identical_splits = True
      for splits in shared_row_splits[1:]:
        if len(splits) != len(shared_row_splits[0]):
          raise ValueError('Fields have inconsistent ragged_rank')
        for (s1, s2) in zip(splits, shared_row_splits[0]):
          if s1 is not s2:
            identical_splits = False

      if identical_splits:
        self._nested_row_splits = shared_row_splits[0]
      else:
        # If splits aren't identical, then add assertions to check that they
        # match.
        with ops.control_dependencies(
            ragged_util.assert_splits_match(shared_row_splits)):
          self._nested_row_splits = [
              array_ops.identity(splits) for splits in shared_row_splits[0]
          ]
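# Construction sketch.  Newer TF releases expose this class publicly as
# tf.experimental.StructuredTensor; the snippet below uses its from_fields
# factory rather than the internal constructor shown above:
import tensorflow as tf

st = tf.experimental.StructuredTensor.from_fields(
    shape=[2],
    fields={
        'name': tf.constant(['a', 'b']),
        'scores': tf.ragged.constant([[1.0, 2.0], [3.0]]),
    })
st.field_value('scores').to_list()
# [[1.0, 2.0], [3.0]]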
def map_flat_values(op, *args, **kwargs):
  """Applies `op` to the values of one or more RaggedTensors.

  Replaces any `RaggedTensor` in `args` or `kwargs` with its `flat_values`
  tensor, and then calls `op`.  Returns a `RaggedTensor` that is constructed
  from the input `RaggedTensor`s' `nested_row_splits` and the value returned by
  the `op`.

  If the input arguments contain multiple `RaggedTensor`s, then they must have
  identical `nested_row_splits`.

  Examples:

  >>> rt = tf.ragged.constant([[1, 2, 3], [], [4, 5], [6]])
  >>> map_flat_values(tf.ones_like, rt).to_list()
  [[1, 1, 1], [], [1, 1], [1]]
  >>> map_flat_values(tf.multiply, rt, rt).to_list()
  [[1, 4, 9], [], [16, 25], [36]]
  >>> map_flat_values(tf.add, rt, 5).to_list()
  [[6, 7, 8], [], [9, 10], [11]]

  Args:
    op: The operation that should be applied to the RaggedTensor `flat_values`.
      `op` is typically an element-wise operation (such as math_ops.add), but
      any operation that preserves the size of the outermost dimension can be
      used.  I.e., `shape[0]` of the value returned by `op` must match
      `shape[0]` of the `RaggedTensor`s' `flat_values` tensors.
    *args: Arguments for `op`.
    **kwargs: Keyword arguments for `op`.

  Returns:
    A `RaggedTensor` whose `ragged_rank` matches the `ragged_rank` of all
    input `RaggedTensor`s.

  Raises:
    ValueError: If args contains no `RaggedTensors`, or if the `nested_splits`
      of the input `RaggedTensor`s are not identical.
  """
  # Replace RaggedTensors with their values; and collect the splits tensors
  # from each RaggedTensor.
  nested_splits_lists = []
  inner_args = _replace_ragged_with_flat_values(args, nested_splits_lists)
  inner_kwargs = _replace_ragged_with_flat_values(kwargs, nested_splits_lists)
  if not nested_splits_lists:
    return op(*args, **kwargs)

  split_dtypes = set(splits[0].dtype for splits in nested_splits_lists)
  if len(split_dtypes) > 1:
    if not ragged_config.auto_cast_partition_dtype():
      raise ValueError("Input RaggedTensors have mismatched row_splits "
                       "dtypes; use RaggedTensor.with_row_splits_dtype() to "
                       "convert them to compatible dtypes.")

    nested_splits_lists = [
        [math_ops.cast(s, dtypes.int64) for s in nested_splits]  # pylint: disable=g-complex-comprehension
        for nested_splits in nested_splits_lists
    ]

  with ops.control_dependencies(
      ragged_util.assert_splits_match(nested_splits_lists)):
    # Delegate to op, and then compose the result from the transformed values
    # and the splits.
    return ragged_tensor.RaggedTensor.from_nested_row_splits(
        op(*inner_args, **inner_kwargs), nested_splits_lists[0],
        validate=False)
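# The row_splits dtype check above can be exercised directly (TF 2.x):
import tensorflow as tf

a = tf.ragged.constant([[1, 2], [3]], row_splits_dtype=tf.int32)
b = tf.ragged.constant([[4, 5], [6]], row_splits_dtype=tf.int64)

# Mismatched row_splits dtypes raise unless auto-casting is enabled:
# tf.ragged.map_flat_values(tf.add, a, b)  # ValueError

# Convert to a common dtype first:
tf.ragged.map_flat_values(tf.add, a.with_row_splits_dtype(tf.int64), b)
# <tf.RaggedTensor [[5, 7], [9]]>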
def boolean_mask(data, mask, keepdims=False, name=None):
  """Applies a boolean mask to `data`.

  Returns a potentially ragged tensor that is formed by retaining the elements
  in `data` where the corresponding value in `mask` is `True`.

  If `keepdims` is true then outer dimensions (corresponding to the `mask`
  dimensions) are preserved, and:

  * `output[a1...aA, i, b1...bB] = data[a1...aA, j, b1...bB]`

     Where `j` is the `i`th `True` entry of `mask[a1...aA]`.

  If `keepdims` is false, then the outer dimensions are collapsed (similar to
  the behavior of `tf.boolean_mask`), and:

  * `output[i, b1...bB] = data[a1...aA, b1...bB]`

     Where `(a1...aA)` is the `i`th `True` entry of `mask`
     (in row-major order).

  Args:
    data: A potentially ragged tensor.
    mask: A potentially ragged boolean tensor.  `mask`'s shape must be a prefix
      of `data`'s shape.  `rank(mask)` must be known statically.
    keepdims: Whether to preserve the outer dimensions (`keepdims=True`) or
      flatten them (`keepdims=False`).
    name: A name prefix for the returned tensor (optional).

  Returns:
    A potentially ragged tensor that is formed by retaining the elements in
    `data` where the corresponding value in `mask` is `True`.

    If `keepdims` is false:

    * `rank(output) = rank(data) - rank(mask) + 1`.
    * `output.ragged_rank = max(data.ragged_rank - rank(mask) + 1, 0)`.

    If `keepdims` is true:

    * `rank(output) = rank(data)`.
    * `output.ragged_rank = max(data.ragged_rank, rank(mask) - 1)`.

  Raises:
    ValueError: if `rank(mask)` is not known statically; or if `mask.shape` is
      not a prefix of `data.shape`.

  #### Examples:

  ```python
  >>> # Aliases for True & False so data and mask line up.
  >>> T, F = (True, False)

  >>> tf.ragged.boolean_mask(  # Mask a 2D Tensor.  Flatten outer dims.
  ...     data=[[1, 2, 3], [4, 5, 6], [7, 8, 9]],
  ...     mask=[[T, F, T], [F, F, F], [T, F, F]],
  ...     keepdims=False).tolist()
  [1, 3, 7]

  >>> tf.ragged.boolean_mask(  # Mask a 2D Tensor.  Preserve outer dims.
  ...     data=[[1, 2, 3], [4, 5, 6], [7, 8, 9]],
  ...     mask=[[T, F, T], [F, F, F], [T, F, F]],
  ...     keepdims=True).tolist()
  [[1, 3], [], [7]]

  >>> tf.ragged.boolean_mask(  # Mask a 2D RaggedTensor.  Flatten outer dims.
  ...     tf.ragged.constant([[1, 2, 3], [4], [5, 6]]),
  ...     tf.ragged.constant([[F, F, T], [F], [T, T]]),
  ...     keepdims=False).tolist()
  [3, 5, 6]

  >>> tf.ragged.boolean_mask(  # Mask a 2D RaggedTensor.  Preserve outer dims.
  ...     tf.ragged.constant([[1, 2, 3], [4], [5, 6]]),
  ...     tf.ragged.constant([[F, F, T], [F], [T, T]]),
  ...     keepdims=True).tolist()
  [[3], [], [5, 6]]

  >>> tf.ragged.boolean_mask(  # Mask rows of a 2D RaggedTensor.
  ...     tf.ragged.constant([[1, 2, 3], [4], [5, 6]]),
  ...     tf.ragged.constant([True, False, True]),
  ...     keepdims=True).tolist()
  [[1, 2, 3], [5, 6]]
  ```
  """
  with ops.name_scope(name, 'RaggedMask', [data, mask]):
    # Convert inputs to tensors.
    data = ragged_tensor.convert_to_tensor_or_ragged_tensor(data, name='data')
    mask = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        mask, dtypes.bool, name='mask')
    row_splits_dtype, (data, mask) = ragged_tensor.match_row_splits_dtypes(
        data, mask, return_dtype=True)

    # Get static rank of mask.
    if mask.shape.ndims is None:
      raise ValueError('mask.shape.ndims must be known statically.')
    elif mask.shape.ndims == 0:
      raise ValueError('mask cannot be scalar.')

    # If mask is ragged, then recurse with a non-ragged mask.
    if ragged_tensor.is_ragged(mask):
      if not ragged_tensor.is_ragged(data):
        data = ragged_tensor.RaggedTensor.from_tensor(
            data,
            ragged_rank=mask.ragged_rank,
            row_splits_dtype=mask.row_splits.dtype)
      # Check that mask.nested_row_splits is a prefix of
      # data.nested_row_splits.
      splits_list = [
          mask.nested_row_splits, data.nested_row_splits[:mask.ragged_rank]
      ]
      with ops.control_dependencies(
          ragged_util.assert_splits_match(splits_list)):
        # Strip off ragged `splits` until `mask` is non-ragged.  Keep the
        # splits that we strip off in `splits`, so we can add them back on
        # after we recursively mask the non-ragged data.
        splits = []
        while ragged_tensor.is_ragged(mask):
          if mask.shape.ndims > 2:
            splits.append(mask.row_splits)
          else:
            # Count the number of True mask values in each row to find the
            # lengths of the filtered rows; then convert to splits.
            int_mask = ragged_functional_ops.map_flat_values(
                math_ops.cast, mask, dtype=row_splits_dtype)
            masked_row_lengths = ragged_math_ops.reduce_sum(int_mask, axis=1)
            splits.append(ragged_util.lengths_to_splits(masked_row_lengths))
          mask = mask.values
          data = data.values

        # Recursively apply the nested non-ragged mask to the nested data.
        masked_values = boolean_mask(data, mask, keepdims)

        # Add the ragged `splits` back to the result.
        if keepdims:
          masked_values = ragged_tensor.RaggedTensor.from_nested_row_splits(
              masked_values, splits, validate=False)

        return masked_values

    # If mask is non-ragged and has rank 1, and data is ragged, then build a
    # ragged tensor with the indicated rows.
    elif ragged_tensor.is_ragged(data) and mask.shape.ndims == 1:
      # Get the masked splits: first get the length of each row, then filter
      # out the rows that we are deleting, and convert that filtered set of
      # masks back to a splits tensor.
      lengths = data.row_lengths()
      masked_lengths = array_ops.boolean_mask(lengths, mask)
      masked_splits = ragged_util.lengths_to_splits(masked_lengths)

      # Get the masked values: first get row ids corresponding to each
      # value, then use tf.gather to build a boolean mask that's false for
      # values that come from rows that we are deleting, and use that mask to
      # construct the masked values tensor.
      segment_ids = segment_id_ops.row_splits_to_segment_ids(data.row_splits)
      segment_mask = array_ops.gather(mask, segment_ids)
      masked_values = boolean_mask(data.values, segment_mask, keepdims=False)

      return ragged_tensor.RaggedTensor.from_row_splits(
          masked_values, masked_splits, validate=False)

    # If mask is non-ragged and has rank>1, then convert it to be ragged,
    # with a ragged rank matching data.
    if ragged_tensor.is_ragged(data):
      mask = ragged_tensor.RaggedTensor.from_tensor(
          mask,
          ragged_rank=min(data.ragged_rank, mask.shape.ndims - 1),
          row_splits_dtype=data.row_splits.dtype)
      return boolean_mask(data, mask, keepdims)

    # Otherwise, data and mask are both `Tensor`s.
    else:
      # Apply `boolean_mask` to get the masked values.
      masked_values = array_ops.boolean_mask(data, mask)

      if mask.shape.ndims >= 2 and keepdims:
        # Add the innermost ragged dimension.  For each innermost cell, get
        # the number of values it contains.  Then flatten that to get a list
        # of cell lengths, and convert it to splits.  Finally, combine the
        # splits and values to get the innermost ragged tensor.
        masked_lengths = math_ops.count_nonzero(
            mask, axis=-1, dtype=row_splits_dtype)
        flattened_masked_lengths = array_ops.reshape(masked_lengths, [-1])
        masked_values = ragged_tensor.RaggedTensor.from_row_lengths(
            masked_values, flattened_masked_lengths, validate=False)

      # Wrap remaining ragged dimensions.
      if mask.shape.ndims > 2 and keepdims:
        mask_shape = array_ops.shape(mask, out_type=row_splits_dtype)
        split_size = math_ops.cumprod(mask_shape) + 1
        for dim in range(mask.shape.ndims - 3, -1, -1):
          elt_size = mask_shape[dim + 1]
          masked_splits = math_ops.range(split_size[dim]) * elt_size
          masked_values = ragged_tensor.RaggedTensor.from_row_splits(
              masked_values, masked_splits, validate=False)

      return masked_values
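# Contrast of the two keepdims modes.  Note: the `keepdims` argument shown in
# this version was later removed; in current TF 2.x, tf.ragged.boolean_mask
# always preserves the mask dimensions (the keepdims=True behavior).
import tensorflow as tf

data = tf.ragged.constant([[1, 2, 3], [4], [5, 6]])
mask = tf.ragged.constant([[False, False, True], [False], [True, True]])

# With the signature above:
#   boolean_mask(data, mask, keepdims=False)  -> [3, 5, 6]
#   boolean_mask(data, mask, keepdims=True)   -> [[3], [], [5, 6]]

tf.ragged.boolean_mask(data, mask).to_list()  # current TF 2.x behavior
# [[3], [], [5, 6]]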
def map_flat_values(op, *args, **kwargs):
  """Applies `op` to the `flat_values` of one or more RaggedTensors.

  Replaces any `RaggedTensor` in `args` or `kwargs` with its `flat_values`
  tensor (which collapses all ragged dimensions), and then calls `op`.  Returns
  a `RaggedTensor` that is constructed from the input `RaggedTensor`s'
  `nested_row_splits` and the value returned by the `op`.

  If the input arguments contain multiple `RaggedTensor`s, then they must have
  identical `nested_row_splits`.

  This operation is generally used to apply elementwise operations to each
  value in a `RaggedTensor`.

  Warning: `tf.ragged.map_flat_values` does *not* apply `op` to each row of a
  ragged tensor.  This difference is important for non-elementwise operations,
  such as `tf.reduce_sum`.  If you wish to apply a non-elementwise operation to
  each row of a ragged tensor, use `tf.map_fn` instead.  (You may need to
  specify an `output_signature` when using `tf.map_fn` with ragged tensors.)

  Examples:

  >>> rt = tf.ragged.constant([[1, 2, 3], [], [4, 5], [6]])
  >>> tf.ragged.map_flat_values(tf.ones_like, rt)
  <tf.RaggedTensor [[1, 1, 1], [], [1, 1], [1]]>
  >>> tf.ragged.map_flat_values(tf.multiply, rt, rt)
  <tf.RaggedTensor [[1, 4, 9], [], [16, 25], [36]]>
  >>> tf.ragged.map_flat_values(tf.add, rt, 5)
  <tf.RaggedTensor [[6, 7, 8], [], [9, 10], [11]]>

  Example with a non-elementwise operation (note that `map_flat_values` and
  `map_fn` return different results):

  >>> rt = tf.ragged.constant([[1.0, 3.0], [], [3.0, 6.0, 3.0]])
  >>> def normalized(x):
  ...   return x / tf.reduce_sum(x)
  >>> tf.ragged.map_flat_values(normalized, rt)
  <tf.RaggedTensor [[0.0625, 0.1875], [], [0.1875, 0.375, 0.1875]]>
  >>> tf.map_fn(normalized, rt)
  <tf.RaggedTensor [[0.25, 0.75], [], [0.25, 0.5, 0.25]]>

  Args:
    op: The operation that should be applied to the RaggedTensor `flat_values`.
      `op` is typically an element-wise operation (such as math_ops.add), but
      any operation that preserves the size of the outermost dimension can be
      used.  I.e., `shape[0]` of the value returned by `op` must match
      `shape[0]` of the `RaggedTensor`s' `flat_values` tensors.
    *args: Arguments for `op`.
    **kwargs: Keyword arguments for `op`.

  Returns:
    A `RaggedTensor` whose `ragged_rank` matches the `ragged_rank` of all
    input `RaggedTensor`s.

  Raises:
    ValueError: If args contains no `RaggedTensors`, or if the `nested_splits`
      of the input `RaggedTensor`s are not identical.
  """
  # Replace RaggedTensors with their values; and collect the splits tensors
  # from each RaggedTensor.
  nested_splits_lists = []
  flat_values_nrows = []
  inner_args = _replace_ragged_with_flat_values(args, nested_splits_lists,
                                                flat_values_nrows)
  inner_kwargs = _replace_ragged_with_flat_values(kwargs, nested_splits_lists,
                                                  flat_values_nrows)
  if not nested_splits_lists:
    return op(*args, **kwargs)

  if flat_values_nrows:
    flat_values_nrows = set(flat_values_nrows)
    if len(flat_values_nrows) != 1:
      raise ValueError("Input RaggedTensors' flat_values must all have the "
                       "same outer-dimension size.  Got sizes: %s" %
                       flat_values_nrows)
    flat_values_nrows = flat_values_nrows.pop()  # Get the single element
  else:
    flat_values_nrows = None

  split_dtypes = set(splits[0].dtype for splits in nested_splits_lists)
  if len(split_dtypes) > 1:
    if not ragged_config.auto_cast_partition_dtype():
      raise ValueError("Input RaggedTensors have mismatched row_splits "
                       "dtypes; use RaggedTensor.with_row_splits_dtype() to "
                       "convert them to compatible dtypes.")

    nested_splits_lists = [
        [math_ops.cast(s, dtypes.int64) for s in nested_splits]  # pylint: disable=g-complex-comprehension
        for nested_splits in nested_splits_lists
    ]

  with ops.control_dependencies(
      ragged_util.assert_splits_match(nested_splits_lists)):
    # Delegate to `op`
    op_output = op(*inner_args, **inner_kwargs)
    # Check that the result has the expected shape (if known).
    if flat_values_nrows is not None:
      if not op_output.shape[:1].is_compatible_with([flat_values_nrows]):
        raise ValueError(
            "tf.ragged.map_flat_values requires that the output of `op` have "
            "the same outer-dimension size as flat_values of any ragged "
            "inputs. (output shape: %s; expected outer dimension size: %s)" %
            (op_output.shape, flat_values_nrows))
    # Compose the result from the transformed values and the splits.
    return ragged_tensor.RaggedTensor.from_nested_row_splits(
        op_output, nested_splits_lists[0], validate=False)