def testWrapClass(self):
    """Regroups two nested values using `values.Mirrored` as the wrapper.

    Walks the regrouped result (tuple -> list -> dict), checks every leaf
    with `self._is_per_replica(..., values.Mirrored)`, and finally checks
    that both `select_replica()` and `select_replica_mirrored()` give back
    the original per-replica inputs.
    """
    # Normally a mirrored value would be the same across devices, but
    # for a test it is convenient to be able to tell the values apart.
    wrapper = values.Mirrored
    result = distribute_utils.regroup(
        (_nested_value("1"), _nested_value("2")), wrapper)

    # Outermost level: a 3-tuple.
    self.assertIsInstance(result, tuple)
    self.assertLen(result, 3)
    head, middle, tail = result
    self._is_per_replica(head, ["a1", "a2"], wrapper)
    self._is_per_replica(tail, ["h1", "h2"], wrapper)

    # Middle element: a 3-element list.
    self.assertIsInstance(middle, list)
    self.assertLen(middle, 3)
    self._is_per_replica(middle[0], ["b1", "b2"], wrapper)
    self._is_per_replica(middle[2], ["g1", "g2"], wrapper)

    # Innermost level: a dict with keys "c" and "e".
    inner = middle[1]
    self.assertIsInstance(inner, dict)
    self.assertEqual({"c", "e"}, set(inner))
    self._is_per_replica(inner["c"], ["d1", "d2"], wrapper)
    self._is_per_replica(inner["e"], ["f1", "f2"], wrapper)

    # Also test that we can undo the merge using select_replica().
    for replica_id, suffix in enumerate(("1", "2")):
        self.assertEqual(
            _nested_value(suffix),
            distribute_utils.select_replica(replica_id, result))

    # Values are marked as mirrored, so select_replica_mirrored() is
    # allowed as well.
    for replica_id, suffix in enumerate(("1", "2")):
        self.assertEqual(
            _nested_value(suffix),
            distribute_utils.select_replica_mirrored(replica_id, result))
def testNested(self):
    """Regroups two nested values without specifying a wrapper class.

    Checks the nested container types of the result, checks each leaf with
    `self._is_per_replica`, checks that `select_replica()` recovers the
    inputs, and checks that `select_replica_mirrored()` raises `TypeError`.
    """
    result = distribute_utils.regroup(
        (_nested_value("1"), _nested_value("2")))

    # Outermost level: a 3-tuple.
    self.assertIsInstance(result, tuple)
    self.assertLen(result, 3)
    head, middle, tail = result
    self._is_per_replica(head, ["a1", "a2"])
    self._is_per_replica(tail, ["h1", "h2"])

    # Middle element: a 3-element list.
    self.assertIsInstance(middle, list)
    self.assertLen(middle, 3)
    self._is_per_replica(middle[0], ["b1", "b2"])
    self._is_per_replica(middle[2], ["g1", "g2"])

    # Innermost level: a dict with keys "c" and "e".
    inner = middle[1]
    self.assertIsInstance(inner, dict)
    self.assertEqual({"c", "e"}, set(inner))
    self._is_per_replica(inner["c"], ["d1", "d2"])
    self._is_per_replica(inner["e"], ["f1", "f2"])

    # Also test that we can undo the merge using select_replica().
    for replica_id, suffix in enumerate(("1", "2")):
        self.assertEqual(
            _nested_value(suffix),
            distribute_utils.select_replica(replica_id, result))

    # select_replica_mirrored() should fail due to non-mirrored values.
    for replica_id in (0, 1):
        with self.assertRaises(TypeError):
            distribute_utils.select_replica_mirrored(replica_id, result)
def _test_input_fn_iterator(self, task_type, task_id, num_gpus, input_fn,
                            expected_values, test_reinitialize=True,
                            ignore_order=False):
    """Iterates an input-fn iterator and compares it with `expected_values`.

    Args:
      task_type: forwarded to `self._get_test_object`.
      task_id: forwarded to `self._get_test_object`.
      num_gpus: forwarded to `self._get_test_object`.
      input_fn: passed to `distribution.make_input_fn_iterator`.
      expected_values: iterable with one expected per-replica sequence for
        each `get_next()` call.
      test_reinitialize: when true, re-run the initializer after exhaustion
        and check all values a second time.
      ignore_order: compare with `assertCountEqual` instead of `assertEqual`.
    """
    distribution, master_target, config = self._get_test_object(
        task_type, task_id, num_gpus)
    devices = distribution.extended.worker_devices

    with ops.Graph().as_default(), \
         self.cached_session(config=config, target=master_target) as sess:
        iterator = distribution.make_input_fn_iterator(input_fn)

        def fetch_next():
            # A fresh get_next() per call, evaluated once per worker device.
            next_element = iterator.get_next()
            return sess.run([
                distribute_utils.select_replica(r, next_element)
                for r in range(len(devices))
            ])

        def check_all_values():
            for expected_value in expected_values:
                computed_value = fetch_next()
                compare = (self.assertCountEqual if ignore_order
                           else self.assertEqual)
                compare(list(expected_value), list(computed_value))

        sess.run(iterator.initializer)
        check_all_values()

        # error raised by calling optional_get_value on an Optional of None
        with self.assertRaises(errors.InvalidArgumentError):
            fetch_next()

        # After re-initializing the iterator, should be able to iterate
        # again.
        if test_reinitialize:
            sess.run(iterator.initializer)
            check_all_values()
def _update_non_slot(self, colocate_with, fn, args, kwargs, group):
    """Calls `fn` once per device in `colocate_with` and regroups the results.

    Args:
      colocate_with: tuple of devices; `fn` is called under each in turn.
      fn: callable invoked with the i-th replica's slice of `args`/`kwargs`.
      args: positional arguments, sliced with `select_replica`.
      kwargs: keyword arguments, sliced with `select_replica`.
      group: forwarded to `distribute_utils.update_regroup`.

    Returns:
      The result of `distribute_utils.update_regroup` over all calls.
    """
    assert isinstance(colocate_with, tuple)
    # TODO(josh11b): In eager mode, use one thread per device.
    per_device_results = []
    for replica_id, device in enumerate(colocate_with):
        scope_name = "update_%d" % replica_id
        with ops.device(device), \
             distribute_lib.UpdateContext(replica_id), \
             ops.name_scope(scope_name):
            replica_args = distribute_utils.select_replica(replica_id, args)
            replica_kwargs = distribute_utils.select_replica(
                replica_id, kwargs)
            per_device_results.append(fn(*replica_args, **replica_kwargs))
    return distribute_utils.update_regroup(self, per_device_results, group)
def _update(self, var, fn, args, kwargs, group):
    """Calls `fn` on every component of `var` and regroups the results.

    Args:
      var: a `values.DistributedVariable`; `fn` is called once for each entry
        of `var.values`, on that entry's device.
      fn: callable invoked as `fn(component, *replica_args, **replica_kwargs)`.
      args: positional arguments, sliced with `select_replica`.
      kwargs: keyword arguments, sliced with `select_replica`.
      group: forwarded to `distribute_utils.update_regroup`.

    Returns:
      The result of `distribute_utils.update_regroup` over all calls.
    """
    # TODO(josh11b): In eager mode, use one thread per device.
    assert isinstance(var, values.DistributedVariable)
    per_component_results = []
    for replica_id, component in enumerate(var.values):
        scope_name = "update_%d" % replica_id
        with ops.device(component.device), \
             distribute_lib.UpdateContext(replica_id), \
             ops.name_scope(scope_name):
            # If args and kwargs are not mirrored, the value is returned as
            # is.
            replica_args = distribute_utils.select_replica(replica_id, args)
            replica_kwargs = distribute_utils.select_replica(
                replica_id, kwargs)
            per_component_results.append(
                fn(component, *replica_args, **replica_kwargs))
    return distribute_utils.update_regroup(
        self, per_component_results, group)
def rewrite_fn(*args):
    """The rewritten step fn running on TPU."""
    del args  # Inputs are pulled from the iterator, not from the caller.

    per_replica_inputs = multi_worker_iterator.get_next()

    def inputs_for(replica_id):
        # Each entry is a one-element tuple wrapping that replica's slice of
        # the iterator output.
        return (nest.map_structure(
            lambda x: distribute_utils.select_replica(replica_id, x),
            per_replica_inputs),)

    replicate_inputs = [
        inputs_for(replica_id)
        for replica_id in range(self._num_replicas_in_sync)
    ]

    replicate_outputs = tpu.replicate(
        run_fn,
        replicate_inputs,
        device_assignment=self._device_assignment,
        xla_options=tpu.XLAOptions(use_spmd_for_xla_partitioning=False))

    # If run_fn has tensor outputs, tpu.replicate returns a list of list. We
    # will flatten it in this case. If run_fn has no tensor outputs,
    # tpu.replicate returns a list of no_ops, we will keep the output as it
    # is.
    if isinstance(replicate_outputs[0], list):
        return nest.flatten(replicate_outputs)
    return replicate_outputs
def _call_for_each_replica(self, fn, args, kwargs):
    """Calls `fn` once per device in `self._devices` and regroups the results.

    When `_outside_run_graph()` is not `None` (the nested case), `fn` is
    called a single time on the unwrapped arguments and its results are
    re-wrapped instead.

    Args:
      fn: must be a `def_function.Function`.
      args: positional arguments; sliced per replica with `select_replica`.
      kwargs: keyword arguments; sliced per replica with `select_replica`.

    Returns:
      `distribute_utils.regroup` of the per-replica return values, or the
      wrapped results of the single call in the nested case.
    """
    # For now, `fn` must be an @tf.function.
    # TODO(josh11b): Relax this restriction?  Main problem is if
    # (a) executing eagerly, (b) `fn` not @tf.function, and
    # (c) executed frequently.
    assert isinstance(fn, def_function.Function)

    if _outside_run_graph() is not None:
        # Nested case, should just use outer function's context for things
        # like the current replica index.
        # TODO(josh11b): Test this case!
        with MirroredFunctionReplicaContext(self._container_strategy()):
            unwrapped_args = nest.map_structure(_unwrap_tensors, args)
            unwrapped_kwargs = nest.map_structure(_unwrap_tensors, kwargs)
            results = fn(*unwrapped_args, **unwrapped_kwargs)
            return nest.map_structure(_wrap_tensors, results)

    def run_replica(replica_id):
        # Slice the arguments for this replica at call time so selection
        # happens inside whatever contexts the caller has entered.
        return fn(*distribute_utils.select_replica(replica_id, args),
                  **distribute_utils.select_replica(replica_id, kwargs))

    _replica_index.graph_outside_run = ops.get_default_graph()
    return_values = []

    try:
        with MirroredFunctionReplicaContext(self._container_strategy()):
            for index, device in enumerate(self._devices):
                _replica_index.current = index
                with ops.device(device):
                    if not context.executing_eagerly():
                        return_values.append(run_replica(index))
                        continue
                    # NOTE: These functions need to execute concurrently if
                    # they use a collective op. This is a particular concern
                    # with eager execution.
                    with context.execution_mode(context.ASYNC):
                        return_values.append(run_replica(index))
    finally:
        # Always reset the replica bookkeeping, even when `fn` raises.
        _replica_index.graph_outside_run = None
        _replica_index.current = None

    return distribute_utils.regroup(return_values)
def testSameId(self):
    """Regroups tuples whose second element is the same object in each.

    The shared object is expected to come back as-is (checked by identity),
    while the differing first elements are checked with
    `self._is_per_replica`.
    """
    foo = object()
    result = distribute_utils.regroup((("a", foo), ("b", foo)))

    self.assertIsInstance(result, tuple)
    self.assertLen(result, 2)
    self._is_per_replica(result[0], ["a", "b"])
    self.assertIs(foo, result[1])

    # Test select_replica(), should undo the merge done by regroup().
    for replica_id, expected_label in enumerate(("a", "b")):
        selected = distribute_utils.select_replica(replica_id, result)
        self.assertIsInstance(selected, tuple)
        self.assertLen(selected, 2)
        self.assertEqual(expected_label, selected[0])
        self.assertIs(foo, selected[1])
def test_get_next_as_optional(iterator):
    """Drains `iterator` via `get_next_as_optional()` and checks each batch.

    After every expected value has been consumed, one more optional is
    fetched; it must report no value, and reading its value must raise
    `errors.InvalidArgumentError`.
    """
    def values_per_replica(optional):
        # get_value() is called once per replica, as part of the selection.
        return [
            distribute_utils.select_replica(r, optional.get_value())
            for r in range(len(devices))
        ]

    for expected_value in expected_values:
        next_element = iterator.get_next_as_optional()
        computed_value = evaluate(values_per_replica(next_element))
        self.assertEqual(len(expected_value), len(computed_value))
        for expected, computed in zip(expected_value, computed_value):
            self.assertAllEqual(expected, computed)

    exhausted = iterator.get_next_as_optional()
    self.assertFalse(self.evaluate(exhausted.has_value()))
    with self.assertRaises(errors.InvalidArgumentError):
        evaluate(values_per_replica(exhausted))
def _test_input_fn_iterator(self, iterator, devices, expected_values,
                            sess=None, test_reinitialize=True,
                            ignore_order=False):
    """Iterates `iterator` and compares its output with `expected_values`.

    Args:
      iterator: object exposing `initializer` and `get_next()`.
      devices: sequence whose length is the number of replicas to select.
      expected_values: iterable with one expected per-replica sequence for
        each `get_next()` call.
      sess: optional session; when falsy, `self.evaluate` is used instead.
      test_reinitialize: when true, re-run the initializer after exhaustion
        and check all values a second time.
      ignore_order: compare with `assertCountEqual` instead of `assertEqual`.
    """
    def evaluate(x):
        return sess.run(x) if sess else self.evaluate(x)

    def fetch_next():
        # A fresh get_next() per call, evaluated once per device.
        next_element = iterator.get_next()
        return evaluate([
            distribute_utils.select_replica(r, next_element)
            for r in range(len(devices))
        ])

    def check_all_values():
        for expected_value in expected_values:
            computed_value = fetch_next()
            compare = (self.assertCountEqual if ignore_order
                       else self.assertEqual)
            compare(expected_value, computed_value)

    evaluate(iterator.initializer)
    check_all_values()

    with self.assertRaises(errors.OutOfRangeError):
        fetch_next()

    # After re-initializing the iterator, should be able to iterate again.
    if test_reinitialize:
        evaluate(iterator.initializer)
        check_all_values()
def testNamedTuple(self):
    """Regroups a list of namedtuple subclass instances.

    The merged result must itself be an `EstimatorSpec`, its non-tensor
    `mode` field must compare equal to the input mode, each remaining field
    must expose the per-device inputs through `.values`, and
    `select_replica()` must give back each original spec.
    """
    # We include toy implementations of Scaffold and EstimatorSpec to
    # avoid a dependency on Estimator here.

    class Scaffold(object):
        pass

    spec_fields = ["mode", "loss", "train_op", "scaffold"]

    class EstimatorSpec(collections.namedtuple("EstimatorSpec", spec_fields)):

        def __new__(cls, mode, loss, train_op, scaffold=None):
            return super(EstimatorSpec, cls).__new__(
                cls,
                mode=mode,
                loss=loss,
                train_op=train_op,
                scaffold=scaffold or Scaffold())

    train_mode = mode_keys.EstimatorModeKeys.TRAIN

    with context.graph_mode(), ops.Graph().as_default():
        created_estimator_specs = [
            EstimatorSpec(
                mode=train_mode,
                loss=constant_op.constant(device_id / 2),
                train_op=array_ops.identity(
                    constant_op.constant(device_id)))
            for device_id in range(3)
        ]

        merged_estimator_spec = distribute_utils.regroup(
            created_estimator_specs)

        self.assertIsInstance(merged_estimator_spec, EstimatorSpec)
        self.assertEqual(train_mode, merged_estimator_spec.mode)

        for device_id, created in enumerate(created_estimator_specs):
            self.assertEqual(
                created.loss, merged_estimator_spec.loss.values[device_id])
            self.assertEqual(
                created.train_op,
                merged_estimator_spec.train_op.values[device_id])
            # Scaffold is populated by `EstimatorSpec.__new__`.
            self.assertEqual(
                created.scaffold,
                merged_estimator_spec.scaffold.values[device_id])
            self.assertIsInstance(created.scaffold, Scaffold)
            # Also test that we can undo the merge using select_replica()
            self.assertEqual(
                created,
                distribute_utils.select_replica(
                    device_id, merged_estimator_spec))
def test_get_next(iterator):
    """Drains `iterator` via `get_next()`, re-initializes, and drains again.

    Between the two passes the iterator must raise
    `errors.OutOfRangeError`. Re-initialization runs the iterator's
    initializer outside eager mode; in eager mode a new iterator is built
    from `dataset`, or the test is skipped for "wrap_into_iterator".
    """
    def fetch(source):
        next_element = source.get_next()
        return evaluate([
            distribute_utils.select_replica(r, next_element)
            for r in range(len(devices))
        ])

    def check_all_values(source):
        for expected_value in expected_values:
            computed_value = fetch(source)
            self.assertEqual(len(expected_value), len(computed_value))
            for expected, computed in zip(expected_value, computed_value):
                self.assertAllEqual(expected, computed)

    check_all_values(iterator)

    with self.assertRaises(errors.OutOfRangeError):
        fetch(iterator)

    # After re-initializing the iterator, should be able to iterate again.
    if not ops.executing_eagerly_outside_functions():
        evaluate(control_flow_ops.group(iterator.initializer))
    elif api_type == "wrap_into_iterator":
        self.skipTest("unsupported test combination")
    else:
        iterator = iter(dataset)

    check_all_values(iterator)
def tpu_function(args, kwargs):
    """TF Function used to replicate the user computation."""
    if kwargs is None:
        kwargs = {}

    # Remove None at the end of args as they are not replicatable
    # If there are None in the middle we can't do anything about it
    # so let those cases fail.
    # For example when Keras model predict is used they pass the targets as
    # None. We want to handle it here so all client libraries don't have to
    # do this as other strategies can handle None values better.
    while args and args[-1] is None:
        args = args[:-1]

    # Used to re-structure flattened output tensors from `tpu.replicate()`
    # into a structured format.
    result = [[]]

    def replicated_fn(replica_id, replica_args, replica_kwargs):
        """Wraps user function to provide replica ID and `Tensor` inputs."""
        with _TPUReplicaContext(strategy, replica_id_in_sync_group=replica_id):
            result[0] = fn(*replica_args, **replica_kwargs)
        return result[0]

    # By replica: [replica id constant, that replica's args, its kwargs].
    replicate_inputs = [
        [
            constant_op.constant(i, dtype=dtypes.int32),
            distribute_utils.select_replica(i, args),
            distribute_utils.select_replica(i, kwargs),
        ]
        for i in range(strategy.num_replicas_in_sync)
    ]

    def rank_of(value):
        # Tensors report their static rank; anything else is measured by
        # numpy.
        if tensor_util.is_tensor(value):
            return value.shape.rank
        return np.ndim(value)

    # Construct and pass `maximum_shapes` so that we could support dynamic
    # shapes using dynamic padder.
    maximum_shapes = None
    if options.experimental_enable_dynamic_batch_size and replicate_inputs:
        first_replica_inputs = replicate_inputs[0]
        flat_shapes = [
            tensor_shape.TensorShape([None] * rank_of(input_tensor))
            for input_tensor in nest.flatten(first_replica_inputs)
        ]
        maximum_shapes = nest.pack_sequence_as(
            first_replica_inputs, flat_shapes)

    padding_spec = (
        tpu.PaddingSpec.POWER_OF_TWO
        if options.experimental_bucketizing_dynamic_shape else None)

    with strategy.scope():
        replicate_outputs = tpu.replicate(
            replicated_fn,
            replicate_inputs,
            device_assignment=self._device_assignment,
            maximum_shapes=maximum_shapes,
            padding_spec=padding_spec,
            xla_options=tpu.XLAOptions(use_spmd_for_xla_partitioning=False))

    # Remove all no ops that may have been added during 'tpu.replicate()'
    if isinstance(result[0], list):
        result[0] = [
            output for output in result[0]
            if not isinstance(output, ops.Operation)
        ]

    # Workaround for `tpu.replicate` behaviour when single `Tensor` returned.
    structure = result[0]
    if structure is None or isinstance(structure, ops.Operation):
        replicate_outputs = [None] * len(replicate_outputs)
    else:
        replicate_outputs = [
            nest.pack_sequence_as(structure, nest.flatten(replica_output))
            for replica_output in replicate_outputs
        ]

    return distribute_utils.regroup(replicate_outputs)
def _test_input_iteration(self, input_type, api_type, iteration_type,
                          dataset_or_input_fn, worker_device_pairs,
                          expected_values, strategy, sess=None,
                          split_batch_by=None, input_context=None):
    """Builds a distributed iterator/dataset and checks what it yields.

    Args:
      input_type: forwarded to the wrapping helpers; "input_fn" is skipped
        together with `api_type == "wrap_into_iterator"`.
      api_type: "wrap_into_iterator" selects `self._wrap_iterator`; any other
        value selects `self._wrap_dataset`.
      iteration_type: "get_next" or "for_loop".
      dataset_or_input_fn: forwarded to the wrapping helpers.
      worker_device_pairs: pairs whose second items are flattened into the
        device list; also passed to `input_lib.InputWorkers`.
      expected_values: one expected per-replica sequence per step.
      strategy: forwarded to the wrapping helpers.
      sess: optional session; when falsy, `self.evaluate` is used instead.
      split_batch_by: forwarded to the wrapping helpers.
      input_context: forwarded to the wrapping helpers.
    """
    wraps_iterator = api_type == "wrap_into_iterator"

    if iteration_type == "for_loop" and not context.executing_eagerly():
        self.skipTest("unsupported test combination.")
    if wraps_iterator and iteration_type == "for_loop":
        self.skipTest("unsupported test combination.")
    if wraps_iterator and input_type == "input_fn":
        self.skipTest("unsupported test combination.")

    devices = nest.flatten([ds for _, ds in worker_device_pairs])
    input_workers = input_lib.InputWorkers(worker_device_pairs)

    def select_all(element):
        # One selected value per device, in replica order.
        return [
            distribute_utils.select_replica(r, element)
            for r in range(len(devices))
        ]

    def select_all_from_optional(optional):
        # get_value() is called once per replica, as part of the selection.
        return [
            distribute_utils.select_replica(r, optional.get_value())
            for r in range(len(devices))
        ]

    def assert_step_equal(expected, actual):
        self.assertEqual(len(expected), len(actual))
        for j in range(len(expected)):
            self.assertAllEqual(expected[j], actual[j])

    if wraps_iterator:
        iterator = self._wrap_iterator(
            input_type,
            dataset_or_input_fn,
            input_workers,
            devices,
            split_batch_by,
            strategy,
            input_context=input_context)
    else:
        # wrapping into a dataset:
        dataset = self._wrap_dataset(
            input_type,
            dataset_or_input_fn,
            input_workers,
            split_batch_by,
            strategy,
            input_context=input_context)

        if ops.executing_eagerly_outside_functions():
            iterator = iter(dataset)
        elif isinstance(dataset, input_lib.DistributedDatasetV1):
            iterator = dataset.make_initializable_iterator()
        else:
            self.skipTest("unsupported test combination")

    if isinstance(iterator, composite_tensor.CompositeTensor):
        nest.assert_same_structure(
            iterator, iterator._type_spec, expand_composites=True)

    if iteration_type == "get_next":

        def evaluate(x):
            return sess.run(x) if sess else self.evaluate(x)

        if not ops.executing_eagerly_outside_functions():
            evaluate(control_flow_ops.group(iterator.initializer))

        def test_get_next(iterator):
            def check_all_values(source):
                for expected_value in expected_values:
                    next_element = source.get_next()
                    computed_value = evaluate(select_all(next_element))
                    assert_step_equal(expected_value, computed_value)

            check_all_values(iterator)

            with self.assertRaises(errors.OutOfRangeError):
                next_element = iterator.get_next()
                evaluate(select_all(next_element))

            # After re-initializing the iterator, should be able to iterate
            # again.
            if not ops.executing_eagerly_outside_functions():
                evaluate(control_flow_ops.group(iterator.initializer))
            elif wraps_iterator:
                self.skipTest("unsupported test combination")
            else:
                iterator = iter(dataset)

            check_all_values(iterator)

        def test_get_next_as_optional(iterator):
            for expected_value in expected_values:
                next_element = iterator.get_next_as_optional()
                computed_value = evaluate(
                    select_all_from_optional(next_element))
                assert_step_equal(expected_value, computed_value)

            next_element = iterator.get_next_as_optional()
            self.assertFalse(self.evaluate(next_element.has_value()))
            with self.assertRaises(errors.InvalidArgumentError):
                evaluate(select_all_from_optional(next_element))

        test_get_next(iterator)

        # re-initializing the iterator
        if not tf2.enabled():
            self.skipTest("Not testing get_next_as_optional in TF1")
        elif wraps_iterator:
            self.skipTest("unsupported test combination")
        else:
            iterator = iter(dataset)

        test_get_next_as_optional(iterator)

    if iteration_type == "for_loop" and context.executing_eagerly():
        actual_values = [self.evaluate(select_all(x)) for x in dataset]
        for i, expected_value in enumerate(expected_values):
            assert_step_equal(expected_value, actual_values[i])
def testOneDevice(self):
    """With a single per-replica value, regroup/select_replica round-trip."""
    single_replica_inputs = (_nested_value("1"),)
    result = distribute_utils.regroup(single_replica_inputs)

    # On one device regroup() and select_replica() are basically identity.
    self.assertEqual(_nested_value("1"), result)

    selected = distribute_utils.select_replica(0, result)
    self.assertEqual(_nested_value("1"), selected)
def _call_for_each_replica(distribution, fn, args, kwargs):
    """Run `fn` once per replica/worker device, each call on its own thread.

    Args:
      distribution: the DistributionStrategy object.
      fn: function to run (will be run once per replica, each in its own
        thread).
      args: positional arguments for `fn`; sliced per replica.
      kwargs: keyword arguments for `fn`; sliced per replica.

    Returns:
      Merged return value of `fn` across all replicas, or `None` when the
      coordinator requests a stop while the replicas are being stepped.

    Raises:
      RuntimeError: If fn() calls get_replica_context().merge_call() a
        different number of times from the available devices.
    """
    # TODO(josh11b): Add this option once we add synchronization to variable
    # creation. Until then, this is pretty unsafe to use.
    run_concurrently = False
    if not context.executing_eagerly():
        # Needed for per-thread device, etc. contexts in graph mode.
        ops.get_default_graph().switch_to_thread_local()

    coord = coordinator.Coordinator(
        clean_stop_exception_types=(_RequestedStop,))

    shared_variable_store = {}
    devices = distribution.extended.worker_devices

    thread_local_callables = _get_thread_local_configuration_callable()

    def make_replica_thread(index):
        variable_creator_fn = shared_variable_creator.make_fn(
            shared_variable_store, index)
        return _MirroredReplicaThread(
            distribution, coord, index, devices, variable_creator_fn, fn,
            distribute_utils.caching_scope_local,
            distribute_utils.select_replica(index, args),
            distribute_utils.select_replica(index, kwargs),
            thread_local_callables)

    # TODO(isaprykin): Create these threads once instead of during every call.
    threads = [make_replica_thread(index) for index in range(len(devices))]

    for thread in threads:
        thread.start()

    # When `fn` starts `should_run` event is set on _MirroredReplicaThread
    # (`MRT`) threads. The execution waits until
    # `MRT.has_paused` is set, which indicates that either `fn` is
    # complete or a `get_replica_context().merge_call()` is called.  If `fn`
    # is complete, then `MRT.done` is set to True.  Otherwise, arguments
    # of `get_replica_context().merge_call` from all paused threads are
    # grouped and the `merge_fn` is performed.  Results of the
    # `get_replica_context().merge_call` are then set to `MRT.merge_result`.
    # Each such `get_replica_context().merge_call` call returns the
    # `MRT.merge_result` for that thread when `MRT.should_run` event
    # is reset again. Execution of `fn` resumes.

    try:
        with coord.stop_on_exception():
            all_done = False
            while not all_done and not coord.should_stop():
                finished = []
                if run_concurrently:
                    # Release every thread first, then wait for each one.
                    for thread in threads:
                        thread.should_run.set()
                    for thread in threads:
                        thread.has_paused.wait()
                        thread.has_paused.clear()
                        if coord.should_stop():
                            return None
                        finished.append(thread.done)
                else:
                    # Release and wait for one thread at a time.
                    for thread in threads:
                        thread.should_run.set()
                        thread.has_paused.wait()
                        thread.has_paused.clear()
                        if coord.should_stop():
                            return None
                        finished.append(thread.done)
                if coord.should_stop():
                    return None

                all_done = all(finished)
                if all_done:
                    break
                if any(finished):
                    raise RuntimeError(
                        "Some replicas made a different number of "
                        "replica_context().merge_call() calls.")

                # get_replica_context().merge_call() case
                merge_args = distribute_utils.regroup(
                    tuple(t.merge_args for t in threads))
                merge_kwargs = distribute_utils.regroup(
                    tuple(t.merge_kwargs for t in threads))

                # We capture the name_scope of the MRT when we call merge_fn
                # to ensure that if we have opened a name scope in the MRT,
                # it will be respected when executing the merge function. We
                # only capture the name_scope from the first MRT and assume
                # it is the same for all other MRTs.
                first_thread = threads[0]
                mtt_captured_name_scope = first_thread.captured_name_scope
                mtt_captured_var_scope = first_thread.captured_var_scope

                # Capture and merge the control dependencies from all the
                # threads.
                mtt_captured_control_deps = set().union(
                    *(t.captured_control_deps for t in threads))

                with ops.name_scope(mtt_captured_name_scope), \
                     ops.control_dependencies(mtt_captured_control_deps), \
                     variable_scope.variable_scope(mtt_captured_var_scope):
                    merge_result = first_thread.merge_fn(
                        distribution, *merge_args, **merge_kwargs)

                # Hand each thread its own slice of the merge result.
                for replica_id, thread in enumerate(threads):
                    thread.merge_result = distribute_utils.select_replica(
                        replica_id, merge_result)
    finally:
        # Release every thread one last time before joining them.
        for thread in threads:
            thread.should_run.set()
        coord.join(threads)

    return distribute_utils.regroup(tuple(t.main_result for t in threads))
def testSelectReplica(self, distribution, synchronization, aggregation):
    """`select_replica(0, v)` on a variable must return `v` itself."""
    with distribution.scope():
        v = variables_lib.Variable(
            1.,
            synchronization=synchronization,
            aggregation=aggregation)
    selected = distribute_utils.select_replica(0, v)
    self.assertIs(v, selected)
def testRaggedSparse(self, distribution, input_type, drop_remainder,
                     defun_type):
    """Test with `RaggedTensor`s and `SparseTensor`s."""
    if not tf2.enabled():
        self.skipTest("Only V2 is supported.")

    wrappers = {
        "lambda": lambda f: f,
        "tf_function": def_function.function,
    }
    defun = wrappers[defun_type]
    distribution.extended.experimental_enable_get_next_as_optional = True
    global_batch_size = 8

    def dataset_fn(ctx=None):
        ctx = ctx or distribute_lib.InputContext()
        batch_size = ctx.get_per_replica_batch_size(global_batch_size)
        # Use 20 which isn't divisible by 8 to test partial batch behavior.
        row_lengths = (np.arange(20) % 4).astype(np.int64)
        flat_values = np.repeat(
            np.arange(20, dtype=np.float32), row_lengths)
        ragged_tensor = ragged_tensor_lib.RaggedTensor.from_row_lengths(
            flat_values, row_lengths)
        features = {
            "dense": ragged_tensor.to_tensor(),
            "ragged": ragged_tensor,
            "sparse": ragged_tensor.to_sparse(),
        }
        dataset = dataset_ops.DatasetV2.from_tensor_slices(features)
        dataset = dataset.shard(
            ctx.num_input_pipelines, ctx.input_pipeline_id)
        return dataset.batch(batch_size, drop_remainder=drop_remainder)

    dataset_or_input_fn = self._create_dataset_or_input_fn(
        input_type, dataset_fn)
    dataset = self._wrap_dataset(
        input_type,
        dataset_or_input_fn,
        distribution.extended._input_workers,
        len(distribution.extended.worker_devices),
        distribution)

    # Assert that the tensors are rebatched and sparsity is preserved.
    per_replica_batch = defun(lambda x: next(iter(x)))(dataset)
    expected_dense_by_replica = (
        [[0., 0., 0.], [1., 0., 0.], [2., 2., 0.], [3., 3., 3.]],
        [[0., 0., 0.], [5., 0., 0.], [6., 6., 0.], [7., 7., 7.]],
    )
    for replica_id, expected_dense in enumerate(expected_dense_by_replica):
        self.assertAllEqual(
            distribute_utils.select_replica(
                replica_id, per_replica_batch["dense"]),
            expected_dense)

    # Transitively check the ragged and sparse tensors by densification.
    for i in range(2):
        ragged_i = distribute_utils.select_replica(
            i, per_replica_batch["ragged"])
        sparse_i = distribute_utils.select_replica(
            i, per_replica_batch["sparse"])
        dense_i = distribute_utils.select_replica(
            i, per_replica_batch["dense"])

        self.assertLen(ragged_i.values, 6)
        self.assertAllEqual(ragged_i.to_tensor(), dense_i)
        self.assertLen(sparse_i.indices, 6)
        self.assertAllEqual(
            sparse_ops.sparse_tensor_to_dense(sparse_i), dense_i)

    # Iterate through all the batches and sum them up.
    def sum_batch(per_replica_features):
        """Sums the `PerReplica` values in the `per_replica_features` map."""

        def map_fn(per_replica_values):
            # Sparse components are summed over their `.values`; anything
            # else goes straight to reduce_sum.
            if all(map(sparse_tensor.is_sparse, per_replica_values.values)):
                sum_op = lambda x: math_ops.reduce_sum(x.values)
            else:
                sum_op = math_ops.reduce_sum
            per_replica_sums = distribution.run(
                sum_op, (per_replica_values,))
            return distribution.reduce(
                reduce_util.ReduceOp.SUM, per_replica_sums, axis=None)

        return nest.map_structure(map_fn, per_replica_features)

    def _reduce(state, batch):
        sums = sum_batch(batch)
        return {name: value + sums[name] for name, value in state.items()}

    def sum_for_loop(dataset):
        sums = {"dense": 0., "ragged": 0., "sparse": 0.}
        for batch in dataset:
            sums = _reduce(sums, batch)
        return sums

    def sum_while_loop(iterator, reduce_fn):
        sums = {"dense": 0., "ragged": 0., "sparse": 0.}
        while True:
            try:
                sums = reduce_fn(sums, iterator)
            except (StopIteration, errors.OutOfRangeError):
                return sums

    while_sums = sum_while_loop(
        iter(dataset),
        defun(lambda state, iterator: _reduce(state, next(iterator))))
    # When there's no partial batch, the sum is smaller.
    expected_while_sum = 200. if drop_remainder else 310.
    self.assertAllEqual(nest.flatten(while_sums), [expected_while_sum] * 3)

    for_sums = defun(sum_for_loop)(dataset)
    # For loops always call get next as optional inside tf functions, so we
    # expect 310 here when using an input function (as there are 5 batches
    # of size 4 round robined over 2 replicas).
    input_fn_in_tf_function = (
        defun_type == "tf_function" and input_type == "input_fn")
    if not drop_remainder or input_fn_in_tf_function:
        expected_for_sum = 310.
    else:
        expected_for_sum = 200.
    self.assertAllEqual(nest.flatten(for_sums), [expected_for_sum] * 3)