def test_on_demand_op_with_dynamic_output(self):
  with ops.device("/device:TPU:0"):
    where_output = array_ops.where([True, False, True])
  self.assertAllEqual(where_output, [[0], [2]])

  with ops.device("/device:TPU:0"):
    repeat_output = array_ops.repeat(math_ops.range(2), [1, 4])
  self.assertAllEqual(repeat_output, [0, 1, 1, 1, 1])
def test_on_demand_op_with_dynamic_output(self):
  if FLAGS.tpu_use_tfrt:
    self.skipTest("Support dynamic output in TFRT, see b/192576400")
  with ops.device("/device:TPU:0"):
    where_output = array_ops.where([True, False, True])
  self.assertAllEqual(where_output, [[0], [2]])

  with ops.device("/device:TPU:0"):
    repeat_output = array_ops.repeat(math_ops.range(2), [1, 4])
  self.assertAllEqual(repeat_output, [0, 1, 1, 1, 1])
def accum(self, keys, old_values, new_values, exists, name=None):
  """Inserts `keys` with `new_values` if they do not exist, or accumulates
  the delta `new_values - old_values` onto existing `keys`.

  This API helps relieve the stale-gradient problem in asynchronous
  training.

  Args:
    keys: Keys to insert. Can be a tensor of any shape. Must match the
      table's key type.
    old_values: Old values to be associated with `keys`. Must be a tensor
      of arrays with the same shape as `keys` and match the table's value
      type.
    new_values: New values to be associated with `keys`. Must be a tensor
      of arrays with the same shape as `keys` and match the table's value
      type.
    exists: A bool tensor indicating whether each key already exists. Must
      have the same shape as `keys`.
    name: A name for the operation (optional).

  Returns:
    The created Operation.

  Raises:
    TypeError: when `keys` or `values` doesn't match the table data types.
  """
  exists = ops.convert_to_tensor(exists, dtypes.bool, name="original_exists")
  exists = array_ops.reshape(exists, shape=[-1, 1])
  exists_expanded = array_ops.repeat(exists, axis=-1, repeats=self.dim)
  exists_expanded = array_ops.reshape(exists_expanded,
                                      shape=array_ops.shape(old_values))
  values_or_deltas = array_ops.where(exists_expanded,
                                     new_values - old_values,
                                     new_values,
                                     name="values_or_deltas")
  partition_index = self.partition_fn(keys, self.shard_num)
  keys_partitions, _ = make_partition(keys, partition_index, self.shard_num)
  values_or_deltas_partitions, _ = make_partition(values_or_deltas,
                                                  partition_index,
                                                  self.shard_num)
  exists_partitions, _ = make_partition(exists, partition_index,
                                        self.shard_num)
  ops_ = []
  for idx in range(len(self.devices)):
    with ops.device(self.devices[idx]):
      ops_.append(self._tables[idx].accum(keys_partitions[idx],
                                          values_or_deltas_partitions[idx],
                                          exists_partitions[idx],
                                          name=name))
  return control_flow_ops.group(ops_)
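
# A minimal sketch of the value-selection logic in `accum`, using the public
# TF API and hypothetical toy data: existing keys receive the delta
# `new - old`, while unseen keys receive `new` directly, so an accumulating
# insert yields `new` in both cases.
import tensorflow as tf

old = tf.constant([[1., 1.], [2., 2.]])
new = tf.constant([[5., 5.], [7., 7.]])
exists = tf.constant([True, False])

# Broadcast the per-key flag across the value dimension, mirroring the
# reshape + repeat of `exists` above.
mask = tf.repeat(tf.reshape(exists, [-1, 1]), repeats=old.shape[1], axis=-1)
values_or_deltas = tf.where(mask, new - old, new)
print(values_or_deltas.numpy())  # [[4. 4.] [7. 7.]]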
def _random_flip(image, flip_index, seed, scope_name, flip_3D_together=False):
  """Randomly (50% chance) flip an image along axis `flip_index`.

  Args:
    image: 4-D Tensor of shape `[batch, height, width, channels]` or 3-D
      Tensor of shape `[height, width, channels]`.
    flip_index: Dimension along which to flip the image. Vertical: 0,
      Horizontal: 1.
    seed: A Python integer. Used to create a random seed. See
      `tf.compat.v1.set_random_seed` for behavior.
    scope_name: Name of the scope in which the ops are added.
    flip_3D_together: If True, draw a single random value for the whole
      batch, so every image in the batch is flipped (or not) together.

  Returns:
    A tensor of the same type and shape as `image`.

  Raises:
    ValueError: if the shape of `image` is not supported.
  """
  with ops.name_scope(None, scope_name, [image]) as scope:
    image = ops.convert_to_tensor(image, name='image')
    shape = image.get_shape()
    if shape.ndims == 3 or shape.ndims is None:
      uniform_random = random_ops.random_uniform([], 0, 1.0, seed=seed)
      mirror_cond = math_ops.less(uniform_random, .5)
      result = control_flow_ops.cond(
          mirror_cond,
          lambda: array_ops.reverse(image, [flip_index]),
          lambda: image,
          name=scope)
      return fix_image_flip_shape(image, result)
    elif shape.ndims == 4:
      batch_size = array_ops.shape(image)[0]
      if flip_3D_together:
        uniform_random = array_ops.repeat(
            random_ops.random_uniform([1], 0, 1.0, seed=seed), batch_size)
      else:
        uniform_random = random_ops.random_uniform([batch_size], 0, 1.0,
                                                   seed=seed)
      flips = math_ops.round(
          array_ops.reshape(uniform_random, [batch_size, 1, 1, 1]))
      flips = math_ops.cast(flips, image.dtype)
      flipped_input = array_ops.reverse(image, [flip_index + 1])
      return flips * flipped_input + (1 - flips) * image
    else:
      raise ValueError('\'image\' must have either 3 or 4 dimensions.')
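
# A minimal sketch (public TF API, toy data) of the mask-based batched flip
# used above: rather than a per-image tf.cond, a per-image 0/1 mask blends
# the flipped and original tensors in one vectorized expression.
import tensorflow as tf

images = tf.reshape(tf.range(2 * 2 * 2 * 1, dtype=tf.float32), [2, 2, 2, 1])
flips = tf.round(tf.random.uniform([2, 1, 1, 1], seed=42))  # 0. or 1. per image
flipped = tf.reverse(images, axis=[2])  # flip along width (flip_index 1 + 1)
result = flips * flipped + (1 - flips) * images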
def _matmul_3d_with_batch_dim_folding(a, b, **kwargs):
  """Multiply batches of 2D matrices where only `a.shape[1]` is ragged.

  Args:
    a: A RaggedTensor with `shape=[B, (I), J]`. (ragged_rank must be 1.)
    b: A Tensor with `shape=[B, J, K]`.
    **kwargs: Additional arguments for `tf.matmul` (e.g. transpose_a).
      transpose_a and adjoint_a must not be true.

  Returns:
    A RaggedTensor with `shape=[B, (I), K]`.
  """
  # reshaped_a.shape = [sum(i_1, i_2, ..., i_B), 1, J]
  reshaped_a = array_ops.expand_dims(a.values, 1)
  # reshaped_b.shape = [sum(i_1, i_2, ..., i_B), J, K]
  reshaped_b = array_ops.repeat(b, a.row_lengths(), axis=0)
  # flat_result.shape = [sum(i_1, i_2, ..., i_B), 1, K]
  flat_result = math_ops.matmul(reshaped_a, reshaped_b, **kwargs)
  # result.shape = [B, (I), K]
  return a.with_values(array_ops.squeeze(flat_result, axis=1))
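
# A minimal sketch (public TF API, toy shapes) of the batch-dim folding
# above: each ragged row of `a` becomes a [1, J] matrix, and the matching
# [J, K] slice of `b` is repeated once per row, so a single batched matmul
# covers every ragged row at once.
import tensorflow as tf

a = tf.ragged.constant([[[1., 2.]], [[3., 4.], [5., 6.]]],
                       ragged_rank=1)                    # [B=2, (I), J=2]
b = tf.reshape(tf.range(8, dtype=tf.float32), [2, 2, 2])  # [B=2, J=2, K=2]

reshaped_a = tf.expand_dims(a.values, 1)            # [sum(I), 1, J]
reshaped_b = tf.repeat(b, a.row_lengths(), axis=0)  # [sum(I), J, K]
flat = tf.matmul(reshaped_a, reshaped_b)            # [sum(I), 1, K]
result = a.with_values(tf.squeeze(flat, axis=1))    # [B, (I), K]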
def split(value: ragged_tensor.Ragged,
          num_or_size_splits,
          axis=0,
          num=None,
          name=None):
  """Splits a RaggedTensor `value` into a list of sub RaggedTensors.

  If `num_or_size_splits` is an `int`, then it splits `value` along the
  dimension `axis` into `num_or_size_splits` smaller RaggedTensors. This
  requires that `value.shape[axis]` is divisible by `num_or_size_splits`.

  If `num_or_size_splits` is a 1-D Tensor (or list), then `value` is split
  into `len(num_or_size_splits)` elements. The shape of the `i`-th element
  has the same size as the `value` except along dimension `axis` where the
  size is `num_or_size_splits[i]`.

  Splitting along a ragged dimension is not allowed.

  For example:

  >>> rt = tf.RaggedTensor.from_row_lengths(
  ...     np.arange(6 * 3).reshape(6, 3), row_lengths=[1, 2, 2, 1])
  >>> rt.shape
  TensorShape([4, None, 3])
  >>>
  >>> rt1, rt2 = tf.split(rt, 2)  # uniform splits
  >>> rt1.shape
  TensorShape([2, None, 3])
  >>> rt2.shape
  TensorShape([2, None, 3])
  >>>
  >>> rt3, rt4, rt5 = tf.split(rt, [1, 2, 1])  # ragged splits
  >>> rt3.shape
  TensorShape([1, None, 3])
  >>> rt4.shape
  TensorShape([2, None, 3])
  >>> rt5.shape
  TensorShape([1, None, 3])
  >>>
  >>> rt6, rt7 = tf.split(rt, [1, 2], axis=2)  # splits along axis 2
  >>> rt6.shape
  TensorShape([4, None, 1])
  >>> rt7.shape
  TensorShape([4, None, 2])

  Args:
    value: The `RaggedTensor` to split.
    num_or_size_splits: Either an `int` indicating the number of splits
      along `axis` or a 1-D integer `Tensor` or Python list containing the
      sizes of each output tensor along `axis`. If a Python int, then it
      must evenly divide `value.shape[axis]`; otherwise the sum of sizes
      along the split axis must match that of `value`.
    axis: An `int` or scalar `int32` `Tensor`. The dimension along which to
      split. Must be in the range `[-rank(value), rank(value))`. Defaults
      to 0.
    num: An `int` used to specify the number of outputs when
      `num_or_size_splits` is a 1-D list or `Tensor` and its length is
      statically unknown, e.g., specifying `tf.TensorSpec(None)` with the
      `input_signature` argument of `tf.function` (optional).
    name: A name for the operation (optional).

  Returns:
    if `num_or_size_splits` is an `int` returns a list of
    `num_or_size_splits` `RaggedTensor` objects; if `num_or_size_splits` is
    a 1-D Tensor returns `num_or_size_splits.get_shape[0]` `RaggedTensor`
    objects resulting from splitting `value`.

  Raises:
    ValueError: If the dimension `axis` of `value` is a ragged dimension.
    ValueError: If `num` is unspecified and cannot be inferred.
    ValueError: If `num` is specified but doesn't match the length of
      `num_or_size_splits`.
    ValueError: If `num_or_size_splits` is an `int` and less than 1.
    TypeError: If `num_or_size_splits` is not an `int` or 1-D list or 1-D
      `Tensor`.
    InvalidArgumentError: If the `axis` dimension of `value` cannot be
      exactly split by `num_or_size_splits`.
    InvalidArgumentError: If `num_or_size_splits` contains negative
      integers.
    InvalidArgumentError: If `num_or_size_splits`'s static shape is unknown
      and its dynamic shape is inconsistent with `num`.
    InvalidArgumentError: If `num_or_size_splits`'s static rank is unknown
      and `axis` is a negative integer.
  """
""" with ops.name_scope(name, 'RaggedSplit'): value = ragged_tensor.convert_to_tensor_or_ragged_tensor( value, name='value') if isinstance(num_or_size_splits, int) and num_or_size_splits == 1: return [value] # static assert check_ops.assert_integer_v2( num_or_size_splits, message=('`num_or_size_splits` must be an `int` or 1-D list or ' '`Tensor` of integers.')) value_shape = ragged_shape.RaggedShape.from_tensor(value) axis = array_ops.get_positive_axis(axis, value_shape.rank) try: dim_size = value_shape[axis] except ValueError: raise ValueError('Cannot split a ragged dimension. Got `value` with ' f'shape {value_shape} and `axis` {axis}.') if isinstance(num_or_size_splits, int): # Uniform split num_splits = num_or_size_splits if num_splits < 1: raise ValueError('`num_or_size_splits` must be >=1 if it is an `int`.' f'Received {num_or_size_splits}.') split_length = math_ops.floordiv(dim_size, num_splits) split_lengths = array_ops.repeat(split_length, num_splits) else: # Ragged split num_splits = None split_lengths = ops.convert_to_tensor(num_or_size_splits) if split_lengths.shape.ndims is not None: if split_lengths.shape.ndims != 1: raise TypeError('`num_or_size_splits` must be an `int` or 1-D list ' f'or `Tensor`. Received {num_or_size_splits}.') num_splits = tensor_shape.dimension_value(split_lengths.shape[0]) if num_splits is None: if num is None: raise ValueError('`num` must be specified as an `int` when the ' 'size of `num_or_size_split` is statically ' f'unknown. Received `num`: {num} and ' f'`num_or_size_split`: {num_or_size_splits}.') num_splits = num else: if num is not None and num != num_splits: raise ValueError('`num` does not match the size of ' f'`num_or_size_split`. Received `num`: {num} and ' f'size of `num_or_size_split`: {num_splits}.') splits = array_ops.concat([[0], math_ops.cumsum(split_lengths)], axis=0) checks = [] checks.append( check_ops.assert_non_negative_v2( num_or_size_splits, message='`num_or_size_splits` must be non-negative.')) checks.append( check_ops.assert_equal_v2( num_splits, array_ops.shape(split_lengths)[0], message='`num` is inconsistent with `num_or_size_split.shape[0]`.')) checks.append( check_ops.assert_equal_v2( math_ops.cast(dim_size, splits.dtype), splits[-1], message=('Cannot exactly split the `axis` dimension of `value` ' 'with the given `num_or_size_split`.'))) splits = control_flow_ops.with_dependencies(checks, splits) splited_rts = [] slices = [slice(None)] * (axis + 1) for i in range(num_splits): slices[-1] = slice(splits[i], splits[i + 1]) splited_rts.append(value[tuple(slices)]) return splited_rts
def repeat(values, repeats, axis):
  return array_ops.repeat(values, repeats, axis)
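
# A brief illustration (public tf.repeat API, toy inputs) of the semantics
# this thin wrapper forwards to:
import tensorflow as tf

tf.repeat([1, 2], repeats=[2, 3], axis=0)       # [1, 1, 2, 2, 2]
tf.repeat([[1, 2], [3, 4]], repeats=2, axis=1)  # [[1, 1, 2, 2], [3, 3, 4, 4]]
tf.repeat([[1, 2], [3, 4]], repeats=2)          # axis=None flattens first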
def _batch_gather(params, indices, axis, batch_dims):
  """Helper that implements the body for ragged gather() when batch_dims>0.

  Args:
    params: The tensor from which to gather values.
    indices: The indices of values to gather.
    axis: The axis in `params` to gather `indices` from.
    batch_dims: The number of batch dimensions.

  Returns:
    A potentially ragged tensor.
  """
  # Perform static checks that `params` and `indices` have compatible batch
  # dimensions. Note: we do not perform *runtime* checks that `params` and
  # `indices` actually have the same row-splits (because we wish to avoid
  # the runtime cost of those checks). If `params` and `indices` are
  # incompatible, the resulting `RaggedTensor` may be nonsensical.
  if not params.shape[:batch_dims].is_compatible_with(
      indices.shape[:batch_dims]):
    raise ValueError('batch shape from indices %s does not match params '
                     'shape %s' % (indices.shape[:batch_dims], params.shape))

  if batch_dims > 1:
    # Convert params & indices to ragged tensors.
    if not isinstance(params, ragged_tensor.RaggedTensor):
      if indices.uniform_row_length is None:
        raise ValueError(
            'batch shape from indices does not match params shape: ragged '
            'indices dimension corresponds to uniform params dimension')
      params = ragged_tensor.RaggedTensor.from_tensor(
          params, ragged_rank=1, row_splits_dtype=indices.row_splits.dtype)
    if not isinstance(indices, ragged_tensor.RaggedTensor):
      if params.uniform_row_length is None:
        raise ValueError(
            'batch shape from indices does not match params shape: ragged '
            'params dimension corresponds to uniform indices dimension')
      indices = ragged_tensor.RaggedTensor.from_tensor(
          indices, ragged_rank=1, row_splits_dtype=params.row_splits.dtype)
    # Flatten the two outer batch dimensions into a single batch dimension,
    # and recurse.
    return params.with_values(
        _gather(params.values, indices.values, axis - 1, batch_dims - 1))

  if axis > 1:
    # Convert an axis dimension into a batch dimension, by adding a
    # dimension to `indices`, and tiling it to match `params`. E.g., if
    # `params` had shape `[B, P1, P2]`, and `indices` had shape
    # `[B, I1, I2]`, then we tile `indices` to have shape `[B, P1, I1, I2]`.
    # That way, we can treat the `P1` dimension as a batch dimension.
    if not isinstance(indices, ragged_tensor.RaggedTensor):
      adjusted_indices = params.with_values(
          array_ops.repeat(indices, params.row_lengths(), 0))
    else:
      if not isinstance(params, ragged_tensor.RaggedTensor):
        params = ragged_tensor.RaggedTensor.from_tensor(
            params, ragged_rank=1, row_splits_dtype=indices.row_splits.dtype)
      adjusted_indices = _gather(
          indices,
          params.with_values(
              array_ops.repeat(math_ops.range(params.nrows()),
                               params.row_lengths())),
          0, 0)
    return _batch_gather(params, adjusted_indices, axis, batch_dims + 1)

  if indices.shape.rank is None:
    raise ValueError('rank(indices) must be known statically')

  assert batch_dims == 1
  # If params.shape=[B, P1...PN] and indices.shape=[B, I1...IM], then:
  #
  #   output[b, i1...im, p2...pn] =
  #       params[b, indices[b, i1...im], p2...pn]
  #
  # We construct `output` by flattening `params`, adjusting the `indices`
  # to point into that flattened list, and recursively calling `gather`.
  flat_params = _flatten_dims_0_and_1(params)
  adjustments = _row_starts(params, indices.dtype)  # offset for each batch
  # increase adjustments's rank so it broadcasts w/ the outer dim of indices
  adjustments = _increase_rank_to(adjustments, indices.shape.ndims)
  adjusted_indices = indices + adjustments
  return _gather(flat_params, adjusted_indices, axis - 1, 0)
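
# A minimal dense-tensor sketch (public TF API, toy data) of the
# batch_dims == 1 strategy above: flatten the batch into the leading axis,
# offset each batch item's indices by its row start, then do one flat gather.
import tensorflow as tf

params = tf.constant([[10, 11, 12], [20, 21, 22]])  # [B=2, P=3]
indices = tf.constant([[2, 0], [1, 1]])             # [B=2, I=2]

flat_params = tf.reshape(params, [-1])              # [B * P]
row_starts = tf.range(2)[:, None] * 3               # offset of each batch row
adjusted = indices + row_starts                     # indices into flat_params
result = tf.gather(flat_params, adjusted)           # [[12, 10], [21, 21]]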
def _add_batched_ragged_partition(rt, partition, tensor_dict, feature_key,
                                  validate, outer_splits=None):
  """Adds a batched ragged partition tensor to a batched ragged tensor.

  Args:
    rt: A RaggedTensor with shape [batch_size, ...].
    partition: The partition configuration object. Specifies the key that
      should be used to look up the partition tensor (unless partition is a
      RaggedFeature.UniformRowLength, in which case there is no partition
      tensor). The specified tensor must have shape [batch_size, ...].
    tensor_dict: The dictionary mapping keys to tensors.
    feature_key: The name of the feature being parsed (for error messages).
    validate: Whether to validate that the values form a valid RaggedTensor.
    outer_splits: If not None, then we have two batch dimensions, and this
      is the row-splits for the collapsed batch dimension. Every partition
      tensor must have an outer row_splits that matches this value.

  Returns:
    A new RaggedTensor where each batch item `rt[i]` has been partitioned
    using `partition_t[i]`.
  """
  if isinstance(partition, RaggedFeature.UniformRowLength):
    if rt.ragged_rank > 1:
      length = ops.convert_to_tensor(partition.length, rt.row_splits.dtype)
      return ragged_tensor.RaggedTensor.from_row_splits(
          ragged_tensor.RaggedTensor.from_uniform_row_length(
              rt.values, length, validate=validate),
          rt.row_splits // length,
          validate=validate)
    else:
      reshaped_vals = array_ops.reshape(
          rt.values,
          array_ops.concat(
              [[-1, partition.length], array_ops.shape(rt.values)[1:]],
              axis=0))
      return ragged_tensor.RaggedTensor.from_row_splits(
          reshaped_vals, rt.row_splits // partition.length,
          validate=validate)

  partition_t = tensor_dict[partition.key]
  if partition_t.values.dtype != rt.row_splits.dtype:
    partition_t = math_ops.cast(partition_t, rt.row_splits.dtype)

  checks = []
  if outer_splits is not None:
    if validate:
      checks.append(
          check_ops.assert_equal(
              outer_splits,
              partition_t.row_splits,
              message="Feature %s: values and partitions are not aligned" %
              feature_key))
    partition_t = partition_t.values

  with ops.control_dependencies(checks):
    if isinstance(partition,
                  (RaggedFeature.RowSplits, RaggedFeature.RowLimits)):
      if isinstance(partition, RaggedFeature.RowSplits):
        partition_t = partition_t[:, 1:]
      adjusted_limits = partition_t.values + array_ops.repeat(
          rt.row_starts(), partition_t.row_lengths())
      return partition_t.with_values(
          ragged_tensor.RaggedTensor.from_row_limits(
              rt.values, adjusted_limits, validate=validate))
    elif isinstance(partition, RaggedFeature.RowStarts):
      adjusted_starts = partition_t.values + array_ops.repeat(
          rt.row_starts(), partition_t.row_lengths())
      return partition_t.with_values(
          ragged_tensor.RaggedTensor.from_row_starts(
              rt.values, adjusted_starts, validate=validate))
    elif isinstance(partition, RaggedFeature.RowLengths):
      return partition_t.with_values(
          ragged_tensor.RaggedTensor.from_row_lengths(
              rt.values, partition_t.values, validate=validate))
    elif isinstance(partition, RaggedFeature.ValueRowIds):
      nrows = math_ops.maximum(  # number of rows in each batch item
          ragged_math_ops.reduce_max(partition_t + 1, axis=1), 0)
      adjusted_rowids = partition_t.values + array_ops.repeat(
          math_ops.cumsum(nrows, exclusive=True), partition_t.row_lengths())
      return ragged_tensor.RaggedTensor.from_row_lengths(
          ragged_tensor.RaggedTensor.from_value_rowids(
              rt.values, adjusted_rowids, validate=validate),
          nrows,
          validate=validate)

  raise ValueError(f"Unhandled partition type {partition!r}")
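
# A minimal sketch (public TF API, toy data) of the offset adjustment used
# for RowLimits-style partitions above: each batch item's row limits are
# shifted by that item's row start so they index into the flattened values.
import tensorflow as tf

rt = tf.ragged.constant([[1, 2, 3], [4, 5]])                # batch of 2 items
limits = tf.ragged.constant([[1, 3], [2]], dtype=tf.int64)  # per-item limits
adjusted = limits.values + tf.repeat(rt.row_starts(), limits.row_lengths())
# rt.row_starts() == [0, 3]  ->  adjusted == [1, 3, 5]
inner = tf.RaggedTensor.from_row_limits(rt.values, adjusted)
result = limits.with_values(inner)  # [[[1], [2, 3]], [[4, 5]]]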
def repeat(a, repeats, axis=None):
  a = asarray(a).data
  repeats = asarray(repeats).data
  return np_utils.tensor_to_ndarray(array_ops.repeat(a, repeats, axis))
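
# For reference, the public counterpart of this numpy-compatibility wrapper
# is tf.experimental.numpy.repeat, which follows numpy.repeat semantics:
import tensorflow.experimental.numpy as tnp

tnp.repeat([[1, 2], [3, 4]], 2)               # flattened: [1 1 2 2 3 3 4 4]
tnp.repeat([[1, 2], [3, 4]], [1, 2], axis=0)  # first row once, second twice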
def _add_batched_ragged_partition(rt, partition, tensor_dict, validate):
  """Adds a batched ragged partition tensor to a batched ragged tensor.

  Args:
    rt: A RaggedTensor with shape [batch_size, ...].
    partition: The partition configuration object. Specifies the key that
      should be used to look up the partition tensor (unless partition is a
      RaggedFeature.UniformRowLength, in which case there is no partition
      tensor). The specified tensor must have shape [batch_size, ...].
    tensor_dict: The dictionary mapping keys to tensors.
    validate: Whether to validate that the values form a valid RaggedTensor.

  Returns:
    A new RaggedTensor where each batch item `rt[i]` has been partitioned
    using `partition_t[i]`.
  """
  if isinstance(partition, RaggedFeature.UniformRowLength):
    if rt.ragged_rank > 1:
      length = ops.convert_to_tensor(partition.length, rt.row_splits.dtype)
      return ragged_tensor.RaggedTensor.from_row_splits(
          ragged_tensor.RaggedTensor.from_uniform_row_length(
              rt.values, length, validate=validate),
          rt.row_splits // length,
          validate=validate)
    else:
      reshaped_vals = array_ops.reshape(
          rt.values,
          array_ops.concat(
              [[-1, partition.length], array_ops.shape(rt.values)[1:]],
              axis=0))
      return ragged_tensor.RaggedTensor.from_row_splits(
          reshaped_vals, rt.row_splits // partition.length,
          validate=validate)

  partition_t = tensor_dict[partition.key]
  if partition_t.values.dtype != rt.row_splits.dtype:
    partition_t = math_ops.cast(partition_t, rt.row_splits.dtype)

  if isinstance(partition, (RaggedFeature.RowSplits, RaggedFeature.RowLimits)):
    if isinstance(partition, RaggedFeature.RowSplits):
      partition_t = partition_t[:, 1:]
    adjusted_limits = partition_t.values + array_ops.repeat(
        rt.row_starts(), partition_t.row_lengths())
    return partition_t.with_values(
        ragged_tensor.RaggedTensor.from_row_limits(
            rt.values, adjusted_limits, validate=validate))
  elif isinstance(partition, RaggedFeature.RowStarts):
    adjusted_starts = partition_t.values + array_ops.repeat(
        rt.row_starts(), partition_t.row_lengths())
    return partition_t.with_values(
        ragged_tensor.RaggedTensor.from_row_starts(
            rt.values, adjusted_starts, validate=validate))
  elif isinstance(partition, RaggedFeature.RowLengths):
    return partition_t.with_values(
        ragged_tensor.RaggedTensor.from_row_lengths(
            rt.values, partition_t.values, validate=validate))
  elif isinstance(partition, RaggedFeature.ValueRowIds):
    nrows = math_ops.maximum(  # number of rows in each batch item
        ragged_math_ops.reduce_max(partition_t + 1, axis=1), 0)
    adjusted_rowids = partition_t.values + array_ops.repeat(
        math_ops.cumsum(nrows, exclusive=True), partition_t.row_lengths())
    return ragged_tensor.RaggedTensor.from_row_lengths(
        ragged_tensor.RaggedTensor.from_value_rowids(
            rt.values, adjusted_rowids, validate=validate),
        nrows,
        validate=validate)

  raise ValueError("Unhandled partition type %r" % partition)