def worker_fn(input_tensor):

  def replica_fn(input_tensor):
    return input_tensor + v

  run_result = strategy.run(replica_fn, args=(input_tensor,))
  check_ops.assert_equal_v2(run_result, 4)
  return run_result

def worker_fn(input_tensor):

  def replica_fn(input_tensor):
    # Within `replica_fn`, it has to be in a replica context.
    self.assertFalse(
        distribution_strategy_context.in_cross_replica_context())
    return input_tensor + v, input_tensor - v

  run_result = self.strategy.run(replica_fn, args=(input_tensor,))
  reduced_result = self.strategy.reduce('SUM', run_result, axis=None)
  check_ops.assert_equal_v2(reduced_result, expected_result)
  return reduced_result

def testSaveNonDistributed(self, distribution, synchronization, aggregation):
  # This test verifies that DistributedVariable behaves like the primary
  # variable when saving a non-distributed version of the model (the
  # default). The test asserts that the function traced under SaveContext
  # has no device annotations and only references the primary component of
  # the variable. Note: avoid capturing other eager tensors in this test to
  # keep the assertions simple.
  if isinstance(
      distribution.extended,
      parameter_server_strategy.ParameterServerStrategyExtended):
    self.skipTest("b/148689177: AggregatingVariable doesn't "
                  "conform to the Variable interface well")

  # tf.function requires the return value to be Tensors, which is not always
  # the case for properties and methods of Variable, so we simply discard
  # the return values.
  def _discard_return(f):
    f()
    return

  def _test(f, v):
    # This verifies that the function under SaveContext:
    #   - contains no device annotations.
    #   - only references the primary component of the variable.
    g = def_function.function(lambda: _discard_return(f))
    options = save_options.SaveOptions(
        experimental_variable_policy=save_options.VariablePolicy.NONE)
    with save_context.save_context(options):
      # The graph should contain no device.
      graph = g.get_concrete_function().graph
      for op in graph.get_operations():
        self.assertEqual(op.device, "", msg=str(op))
      # The function should only capture the primary variable. Note that it
      # may not have captures, e.g. v.aggregation.
      captures = list(graph.captures)
      self.assertLessEqual(len(captures), 1)
      if graph.captures:
        self.assertIs(captures[0][0], v._primary.handle)

  def _assert(cond):
    return control_flow_ops.Assert(cond, [cond])

  with distribution.scope():
    # We use three variables for convenience reasons. They have no special
    # meaning.
    # - v is used whenever possible.
    # - w is used for scatter and gather, which require the variable to be
    #   non-scalar.
    # - y is used when the dtype needs to be integer. Note that aggregation
    #   cannot be MEAN for integers.
    v = variables_lib.Variable(
        0.,
        synchronization=synchronization,
        aggregation=aggregation,
        trainable=True)
    w = variables_lib.Variable([0., 0., 0.],
                               synchronization=synchronization,
                               aggregation=aggregation,
                               trainable=True)
    if aggregation != variables_lib.VariableAggregation.MEAN:
      y = variables_lib.Variable(
          0, synchronization=synchronization, aggregation=aggregation)

  # pylint: disable=g-long-lambda

  # tf.Variable properties.
  _test(lambda: self.assertEqual(v.aggregation, aggregation), v)
  _test(lambda: self.assertIs(v.constraint, None), v)
  # TODO(crccw): should we raise an error instead?
  _test(lambda: self.assertEqual(v.device, v._primary.device), v)
  _test(lambda: self.assertEqual(v.dtype, dtypes.float32), v)
  if not context.executing_eagerly():
    _test(lambda: self.assertIs(v.graph, v._primary.graph), v)
  if not context.executing_eagerly():
    _test(lambda: _assert(v.initial_value == 0), v)
  _test(lambda: self.assertIs(v.initializer, v._primary.initializer), v)
  _test(lambda: self.assertEqual(v.name, "Variable:0"), v)
  if not context.executing_eagerly():
    _test(lambda: self.assertIs(v.op, v._primary.op), v)
  _test(lambda: self.assertEqual(v.shape, tensor_shape.TensorShape(())), v)
  _test(lambda: self.assertEqual(v.synchronization, synchronization), v)
  _test(lambda: self.assertEqual(v.trainable, True), v)

  # tf.Variable methods.
  _test(lambda: check_ops.assert_equal_v2(v.assign(1.), 1.), v)
  _test(lambda: check_ops.assert_equal_v2(v.assign_add(1.), 2.), v)
  _test(lambda: check_ops.assert_equal_v2(v.assign_sub(1.), 1.), v)
  # TODO(b/148689177): Implement batch_scatter_update.
  # count_up_to() is skipped since it's deprecated.
  # eval() is skipped since it shouldn't be called in a tf.function.
  # experimental_ref() is skipped since it's deprecated.
  # from_proto() is skipped since it shouldn't be called in a tf.function.
  # TODO(b/148689177): Implement gather_nd.
  _test(
      lambda: check_ops.assert_equal_v2(v.get_shape(),
                                        tensor_shape.TensorShape(())), v)
  # initialized_value() is skipped since it shouldn't be called in a
  # tf.function.
  # load() is skipped since it shouldn't be called in a tf.function.
  _test(lambda: check_ops.assert_equal_v2(v.read_value(), 1.), v)
  # ref() is skipped since it shouldn't be called in a tf.function.
  _test(
      lambda: check_ops.assert_equal_v2(
          w.scatter_add(_make_index_slices(values=[1., 2.], indices=[0, 2])),
          [1., 0., 2.]), w)
  _test(
      lambda: check_ops.assert_equal_v2(
          w.scatter_div(_make_index_slices(values=[4., 2.], indices=[0, 2])),
          [0.25, 0., 1.]), w)
  _test(
      lambda: check_ops.assert_equal_v2(
          w.scatter_max(_make_index_slices(values=[1., 0.5], indices=[1, 2])),
          [0.25, 1., 1.]), w)
  _test(
      lambda: check_ops.assert_equal_v2(
          w.scatter_min(_make_index_slices(values=[1., 0.5], indices=[0, 1])),
          [0.25, 0.5, 1.]), w)
  _test(
      lambda: check_ops.assert_equal_v2(
          w.scatter_mul(_make_index_slices(values=[2., 0.5], indices=[0, 1])),
          [0.5, 0.25, 1.]), w)
  # TODO(b/148689177): Implement scatter_nd_*
  _test(
      lambda: check_ops.assert_equal_v2(
          w.scatter_sub(_make_index_slices(values=[2., 0.5], indices=[0, 1])),
          [-1.5, -0.25, 1.]), w)
  _test(
      lambda: check_ops.assert_equal_v2(
          w.scatter_update(
              _make_index_slices(values=[2., 0.5], indices=[0, 1])),
          [2., 0.5, 1.]), w)
  # set_shape() is skipped since ResourceVariable doesn't implement it.
  # to_proto() is skipped since it shouldn't be called in a tf.function.
  _test(lambda: check_ops.assert_equal_v2(v.value(), 1.), v)

  # DistributedVariable should be treated as ResourceVariable, so it needs to
  # conform to the ResourceVariable interface as well.
  _test(lambda: self.assertIs(v.handle, v._primary.handle), v)

  # Convert to tensor.
  _test(lambda: check_ops.assert_equal_v2(ops.convert_to_tensor(v), 1.), v)

  # Control dependency.
  def _with_control_dep():
    with ops.control_dependencies([v.assign(1.)]):
      return array_ops.identity(1)

  _test(_with_control_dep, v)

  # Operator overloads.
  _test(lambda: check_ops.assert_equal_v2(v.assign(7.), 7.), v)
  _test(lambda: check_ops.assert_equal_v2(v + 1., 8.), v)
  _test(lambda: check_ops.assert_equal_v2(3 + v, 10.), v)
  _test(lambda: check_ops.assert_equal_v2(v + v, 14.), v)
  _test(lambda: check_ops.assert_equal_v2(v - 2., 5.), v)
  _test(lambda: check_ops.assert_equal_v2(v - v, 0.), v)
  _test(lambda: check_ops.assert_equal_v2(v * 2., 14.), v)
  _test(lambda: check_ops.assert_equal_v2(3 * v, 21.), v)
  _test(lambda: check_ops.assert_equal_v2(v * v, 49.), v)
  _test(
      lambda: check_ops.assert_equal_v2(
          math_ops.cast(v / 2., dtypes.float32), 3.5), v)
  _test(
      lambda: check_ops.assert_equal_v2(
          math_ops.cast(14. / v, dtypes.float32), 2.), v)
  _test(lambda: _assert(v < 12.), v)
  _test(lambda: _assert(v <= 12.), v)
  _test(lambda: _assert(not v > 12.), v)
  _test(lambda: _assert(not v >= 12.), v)
  _test(lambda: _assert(not 12. < v), v)
  _test(lambda: _assert(not 12. <= v), v)
  _test(lambda: _assert(12. > v), v)
  _test(lambda: _assert(12. >= v), v)
  _test(lambda: check_ops.assert_near_v2(pow(v, 3.), 343.), v)
  _test(lambda: check_ops.assert_near_v2(pow(2., v), 128.), v)
  _test(lambda: check_ops.assert_equal_v2(abs(v), 7.), v)

  # Operator overloads that only work for integers.
  if aggregation != variables_lib.VariableAggregation.MEAN:
    _test(lambda: check_ops.assert_equal_v2(y.assign(7), 7), y)
    _test(lambda: check_ops.assert_equal_v2(y // 2, 3), y)
    _test(lambda: check_ops.assert_equal_v2(15 // y, 2), y)
    _test(lambda: check_ops.assert_equal_v2(y % 2, 1), y)
    _test(lambda: check_ops.assert_equal_v2(16 % y, 2), y)
    _test(lambda: check_ops.assert_equal_v2(y & 3, 3), y)
    _test(lambda: check_ops.assert_equal_v2(3 & y, 3), y)
    _test(lambda: check_ops.assert_equal_v2(y | 8, 15), y)
    _test(lambda: check_ops.assert_equal_v2(16 | y, 23), y)
    _test(lambda: check_ops.assert_equal_v2(y ^ 3, 4), y)
    _test(lambda: check_ops.assert_equal_v2(11 ^ y, 12), y)
    _test(lambda: check_ops.assert_equal_v2(-y, -7), y)
    _test(lambda: check_ops.assert_equal_v2(~y, ~7), y)

  # Index.
  if isinstance(distribution.extended, tpu_strategy.TPUExtended):
    # TODO(b/161572567): slice assignment doesn't work for TPU.
    _test(lambda: check_ops.assert_equal_v2(w[0], 2.), w)
  else:
    _test(
        lambda: check_ops.assert_equal_v2(w[0].assign(1.), [1., 0.5, 1.]), w)
    _test(lambda: check_ops.assert_equal_v2(w[0], 1.), w)

def split(value: ragged_tensor.Ragged,
          num_or_size_splits,
          axis=0,
          num=None,
          name=None):
  """Splits a RaggedTensor `value` into a list of sub RaggedTensors.

  If `num_or_size_splits` is an `int`, then it splits `value` along the
  dimension `axis` into `num_or_size_splits` smaller RaggedTensors. This
  requires that `value.shape[axis]` is divisible by `num_or_size_splits`.

  If `num_or_size_splits` is a 1-D Tensor (or list), then `value` is split
  into `len(num_or_size_splits)` elements. The shape of the `i`-th element
  has the same size as the `value` except along dimension `axis` where the
  size is `num_or_size_splits[i]`.

  Splitting along a ragged dimension is not allowed.

  For example:

  >>> rt = tf.RaggedTensor.from_row_lengths(
  ...     np.arange(6 * 3).reshape(6, 3), row_lengths=[1, 2, 2, 1])
  >>> rt.shape
  TensorShape([4, None, 3])
  >>>
  >>> rt1, rt2 = tf.split(rt, 2)  # uniform splits
  >>> rt1.shape
  TensorShape([2, None, 3])
  >>> rt2.shape
  TensorShape([2, None, 3])
  >>>
  >>> rt3, rt4, rt5 = tf.split(rt, [1, 2, 1])  # ragged splits
  >>> rt3.shape
  TensorShape([1, None, 3])
  >>> rt4.shape
  TensorShape([2, None, 3])
  >>> rt5.shape
  TensorShape([1, None, 3])
  >>>
  >>> rt6, rt7 = tf.split(rt, [1, 2], axis=2)  # splits along axis 2
  >>> rt6.shape
  TensorShape([4, None, 1])
  >>> rt7.shape
  TensorShape([4, None, 2])

  Args:
    value: The `RaggedTensor` to split.
    num_or_size_splits: Either an `int` indicating the number of splits along
      `axis` or a 1-D integer `Tensor` or Python list containing the sizes of
      each output tensor along `axis`. If a Python int, then it must evenly
      divide `value.shape[axis]`; otherwise the sum of sizes along the split
      axis must match that of the `value`.
    axis: An `int` or scalar `int32` `Tensor`. The dimension along which to
      split. Must be in the range `[-rank(value), rank(value))`. Defaults to
      0.
    num: An `int` used to specify the number of outputs when
      `num_or_size_splits` is a 1-D list or `Tensor` and its length is
      statically unknown, e.g., when specifying `tf.TensorSpec(None)` with
      the `input_signature` argument of `tf.function` (optional).
    name: A name for the operation (optional).

  Returns:
    If `num_or_size_splits` is an `int`, returns a list of
    `num_or_size_splits` `RaggedTensor` objects; if `num_or_size_splits` is a
    1-D Tensor, returns `num_or_size_splits.get_shape()[0]` `RaggedTensor`
    objects resulting from splitting `value`.

  Raises:
    ValueError: If the dimension `axis` of `value` is a ragged dimension.
    ValueError: If `num` is unspecified and cannot be inferred.
    ValueError: If `num` is specified but doesn't match the length of
      `num_or_size_splits`.
    ValueError: If `num_or_size_splits` is an `int` and less than 1.
    TypeError: If `num_or_size_splits` is not an `int` or 1-D list or 1-D
      `Tensor`.
    InvalidArgumentError: If the `axis` dimension of `value` cannot be evenly
      split by `num_or_size_splits`.
    InvalidArgumentError: If `num_or_size_splits` contains negative integers.
    InvalidArgumentError: If `num_or_size_splits`'s static shape is unknown
      and its dynamic shape is inconsistent with `num`.
    InvalidArgumentError: If `num_or_size_splits`'s static rank is unknown
      and `axis` is a negative integer.
""" with ops.name_scope(name, 'RaggedSplit'): value = ragged_tensor.convert_to_tensor_or_ragged_tensor( value, name='value') if isinstance(num_or_size_splits, int) and num_or_size_splits == 1: return [value] # static assert check_ops.assert_integer_v2( num_or_size_splits, message=('`num_or_size_splits` must be an `int` or 1-D list or ' '`Tensor` of integers.')) value_shape = ragged_shape.RaggedShape.from_tensor(value) axis = array_ops.get_positive_axis(axis, value_shape.rank) try: dim_size = value_shape[axis] except ValueError: raise ValueError('Cannot split a ragged dimension. Got `value` with ' f'shape {value_shape} and `axis` {axis}.') if isinstance(num_or_size_splits, int): # Uniform split num_splits = num_or_size_splits if num_splits < 1: raise ValueError('`num_or_size_splits` must be >=1 if it is an `int`.' f'Received {num_or_size_splits}.') split_length = math_ops.floordiv(dim_size, num_splits) split_lengths = array_ops.repeat(split_length, num_splits) else: # Ragged split num_splits = None split_lengths = ops.convert_to_tensor(num_or_size_splits) if split_lengths.shape.ndims is not None: if split_lengths.shape.ndims != 1: raise TypeError('`num_or_size_splits` must be an `int` or 1-D list ' f'or `Tensor`. Received {num_or_size_splits}.') num_splits = tensor_shape.dimension_value(split_lengths.shape[0]) if num_splits is None: if num is None: raise ValueError('`num` must be specified as an `int` when the ' 'size of `num_or_size_split` is statically ' f'unknown. Received `num`: {num} and ' f'`num_or_size_split`: {num_or_size_splits}.') num_splits = num else: if num is not None and num != num_splits: raise ValueError('`num` does not match the size of ' f'`num_or_size_split`. Received `num`: {num} and ' f'size of `num_or_size_split`: {num_splits}.') splits = array_ops.concat([[0], math_ops.cumsum(split_lengths)], axis=0) checks = [] checks.append( check_ops.assert_non_negative_v2( num_or_size_splits, message='`num_or_size_splits` must be non-negative.')) checks.append( check_ops.assert_equal_v2( num_splits, array_ops.shape(split_lengths)[0], message='`num` is inconsistent with `num_or_size_split.shape[0]`.')) checks.append( check_ops.assert_equal_v2( math_ops.cast(dim_size, splits.dtype), splits[-1], message=('Cannot exactly split the `axis` dimension of `value` ' 'with the given `num_or_size_split`.'))) splits = control_flow_ops.with_dependencies(checks, splits) splited_rts = [] slices = [slice(None)] * (axis + 1) for i in range(num_splits): slices[-1] = slice(splits[i], splits[i + 1]) splited_rts.append(value[tuple(slices)]) return splited_rts