def _test_input_fn_iterator(self, task_type, task_id, num_gpus, input_fn,
                            expected_values):
  distribution, master_target, config = self._get_test_object(
      task_type, task_id, num_gpus)
  devices = distribution.extended.worker_devices

  with ops.Graph().as_default(), \
       self.cached_session(config=config,
                           target=master_target) as sess:
    iterator = distribution.make_input_fn_iterator(input_fn)
    sess.run(iterator.initialize())

    for expected_value in expected_values:
      next_element = iterator.get_next()
      computed_value = sess.run(
          [values.select_replica(r, next_element) for r in range(len(devices))])
      self.assertEqual(expected_value, computed_value)

    with self.assertRaises(errors.OutOfRangeError):
      next_element = iterator.get_next()
      sess.run(
          [values.select_replica(r, next_element) for r in range(len(devices))])

    # After re-initializing the iterator, should be able to iterate again.
    sess.run(iterator.initialize())

    for expected_value in expected_values:
      next_element = iterator.get_next()
      computed_value = sess.run(
          [values.select_replica(r, next_element) for r in range(len(devices))])
      self.assertEqual(expected_value, computed_value)
def testNested(self):
  device_map = values.ReplicaDeviceMap((_device_str(0), _device_str(1)))
  result = values.regroup(device_map,
                          (_nested_value("1"), _nested_value("2")))
  self.assertIsInstance(result, tuple)
  self.assertEqual(3, len(result))
  self._is_per_replica(result[0], ["a1", "a2"])
  self._is_per_replica(result[2], ["h1", "h2"])

  self.assertIsInstance(result[1], list)
  self.assertEqual(3, len(result[1]))
  self._is_per_replica(result[1][0], ["b1", "b2"])
  self._is_per_replica(result[1][2], ["g1", "g2"])

  self.assertIsInstance(result[1][1], dict)
  self.assertEqual(set(["c", "e"]), set(result[1][1].keys()))
  self._is_per_replica(result[1][1]["c"], ["d1", "d2"])
  self._is_per_replica(result[1][1]["e"], ["f1", "f2"])

  # Also test that we can undo the merge using select_replica()
  self.assertEqual(_nested_value("1"),
                   values.select_replica(0, result))
  self.assertEqual(_nested_value("2"),
                   values.select_replica(1, result))
  # select_device_mirrored() should fail due to non-mirrored values
  with self.assertRaises(TypeError):
    values.select_device_mirrored(_device_str(0), result)
  with self.assertRaises(TypeError):
    values.select_device_mirrored(_device_str(1), result)
def testWrapClass(self):
  # Normally a mirrored value would be the same across devices, but
  # for a test it is convenient to be able to tell the values apart.
  device_map = values.ReplicaDeviceMap((_device_str(0), _device_str(1)))
  result = values.regroup(device_map,
                          (_nested_value("1"), _nested_value("2")),
                          values.Mirrored)
  self.assertIsInstance(result, tuple)
  self.assertEqual(3, len(result))
  self._is_per_replica(result[0], ["a1", "a2"], values.Mirrored)
  self._is_per_replica(result[2], ["h1", "h2"], values.Mirrored)

  self.assertIsInstance(result[1], list)
  self.assertEqual(3, len(result[1]))
  self._is_per_replica(result[1][0], ["b1", "b2"], values.Mirrored)
  self._is_per_replica(result[1][2], ["g1", "g2"], values.Mirrored)

  self.assertIsInstance(result[1][1], dict)
  self.assertEqual(set(["c", "e"]), set(result[1][1].keys()))
  self._is_per_replica(result[1][1]["c"], ["d1", "d2"], values.Mirrored)
  self._is_per_replica(result[1][1]["e"], ["f1", "f2"], values.Mirrored)

  # Also test that we can undo the merge using select_replica()
  self.assertEqual(_nested_value("1"),
                   values.select_replica(0, result))
  self.assertEqual(_nested_value("2"),
                   values.select_replica(1, result))
  # Values are marked as mirrored, so select_device_mirrored() is allowed.
  self.assertEqual(_nested_value("1"),
                   values.select_device_mirrored(_device_str(0), result))
  self.assertEqual(_nested_value("2"),
                   values.select_device_mirrored(_device_str(1), result))
def _test_input_fn_iterator(self,
                            iterator,
                            devices,
                            expected_values,
                            sess=None,
                            test_reinitialize=True):
  evaluate = lambda x: sess.run(x) if sess else self.evaluate(x)
  evaluate(iterator.initialize())

  for expected_value in expected_values:
    next_element = iterator.get_next()
    computed_value = evaluate(
        [values.select_replica(r, next_element) for r in range(len(devices))])
    self.assertEqual(expected_value, computed_value)

  with self.assertRaises(errors.OutOfRangeError):
    next_element = iterator.get_next()
    evaluate(
        [values.select_replica(r, next_element) for r in range(len(devices))])

  # After re-initializing the iterator, should be able to iterate again.
  if test_reinitialize:
    evaluate(iterator.initialize())
    for expected_value in expected_values:
      next_element = iterator.get_next()
      computed_value = evaluate([
          values.select_replica(r, next_element) for r in range(len(devices))
      ])
      self.assertEqual(expected_value, computed_value)
def _tpu_run(strategy, fn, args, kwargs):
  """Common implementation of TPUStrategy.experimental_run_v2."""
  if context.executing_eagerly() and not ops.inside_function():
    raise NotImplementedError(
        "Eager mode not supported in TPUStrategy outside TF functions.")

  if kwargs is None:
    kwargs = {}

  # Used to re-structure flattened output tensors from `tpu.replicate()`
  # into a structured format.
  result = [[]]

  def replicated_fn(replica_id, replica_args, replica_kwargs):
    """Wraps user function to provide replica ID and `Tensor` inputs."""
    with _TPUReplicaContext(strategy, replica_id_in_sync_group=replica_id):
      result[0] = fn(*replica_args, **replica_kwargs)
    return result[0]

  replicate_inputs = []  # By replica.
  for i in range(strategy.num_replicas_in_sync):
    replicate_inputs.append(
        [constant_op.constant(i, dtype=dtypes.int32),
         values.select_replica(i, args),
         values.select_replica(i, kwargs)])

  # Construct and pass `maximum_shapes` so that we could support dynamic
  # shapes using dynamic padder.
  if replicate_inputs:
    maximum_shapes = []
    flattened_list = nest.flatten(replicate_inputs[0])
    for input_tensor in flattened_list:
      maximum_shapes.append(input_tensor.get_shape())
    maximum_shapes = nest.pack_sequence_as(replicate_inputs[0],
                                           maximum_shapes)
  else:
    maximum_shapes = None

  with strategy.scope():
    replicate_outputs = tpu.replicate(replicated_fn, replicate_inputs,
                                      maximum_shapes=maximum_shapes)

  # Remove all no ops that may have been added during 'tpu.replicate()'
  if isinstance(result[0], list):
    result[0] = [
        output for output in result[0] if tensor_util.is_tensor(output)
    ]

  # Workaround for `tpu.replicate` behaviour when single `Tensor` returned.
  replicate_outputs = [
      nest.pack_sequence_as(result[0], nest.flatten(replica_output))
      for replica_output in replicate_outputs
  ]

  device_map = strategy.extended._device_map  # pylint: disable=protected-access
  return values.regroup(device_map, replicate_outputs)
def testNamedTupleEstimatorSpec(self):
  with context.graph_mode(), ops.Graph().as_default():
    devices = []
    created_estimator_specs = []

    for device_id in range(3):
      spec = model_fn_lib.EstimatorSpec(
          mode=model_fn_lib.ModeKeys.TRAIN,
          loss=constant_op.constant(device_id / 2),
          train_op=array_ops.identity(constant_op.constant(device_id)))
      devices.append(_device_str(device_id))
      created_estimator_specs.append(spec)

    device_map = values.ReplicaDeviceMap(devices)
    merged_estimator_spec = values.regroup(
        device_map, created_estimator_specs)

    self.assertTrue(
        isinstance(merged_estimator_spec, model_fn_lib.EstimatorSpec))
    self.assertEqual(model_fn_lib.ModeKeys.TRAIN, merged_estimator_spec.mode)
    for device_id in range(3):
      d = _device_str(device_id)
      self.assertEqual(created_estimator_specs[device_id].loss,
                       merged_estimator_spec.loss.get(d))
      self.assertEqual(created_estimator_specs[device_id].train_op,
                       merged_estimator_spec.train_op.get(d))
      # Scaffold is populated by `EstimatorSpec.__new__`.
      self.assertEqual(created_estimator_specs[device_id].scaffold,
                       merged_estimator_spec.scaffold.get(d))
      # Also test that we can undo the merge using select_replica()
      self.assertEqual(created_estimator_specs[device_id],
                       values.select_replica(device_id,
                                             merged_estimator_spec))
def _test_iterator(self, sess, iterator, devices, expected_values):
  next_element = iterator.get_next()
  for r, device in enumerate(devices):
    v = values.select_replica(r, next_element)
    # The `v` here can be a tuple.
    for element in nest.flatten(v):
      self.assertTrue(element.device in device)

  for expected_value in expected_values:
    t = [values.select_replica(r, next_element) for r in range(len(devices))]
    actual = sess.run(t)
    self.assertEqual(expected_value, actual)

  with self.assertRaises(errors.OutOfRangeError):
    sess.run(
        [values.select_replica(r, next_element) for r in range(len(devices))])
def experimental_run_v2(self, fn, args=(), kwargs=None):
  """See base class."""
  if context.executing_eagerly() and not ops.inside_function():
    raise NotImplementedError(
        "Eager mode not supported in TPUStrategy outside TF functions.")

  if kwargs is None:
    kwargs = {}

  # Used to re-structure flattened output tensors from `tpu.replicate()`
  # into a structured format.
  result = [[]]

  def replicated_fn(replica_id, replica_args, replica_kwargs):
    """Wraps user function to provide replica ID and `Tensor` inputs."""
    with _TPUReplicaContext(self, replica_id_in_sync_group=replica_id):
      result[0] = fn(*replica_args, **replica_kwargs)
    return result[0]

  replicate_inputs = []  # By replica.
  for i in range(self.num_replicas_in_sync):
    replicate_inputs.append(
        [constant_op.constant(i, dtype=dtypes.int32),
         values.select_replica(i, args),
         values.select_replica(i, kwargs)])

  with self.scope():
    replicate_outputs = tpu.replicate(replicated_fn, replicate_inputs)

  # Remove all no ops that may have been added during 'tpu.replicate()'
  if isinstance(result[0], list):
    result[0] = [
        output for output in result[0] if tensor_util.is_tensor(output)
    ]

  # Workaround for `tpu.replicate` behaviour when single `Tensor` returned.
  replicate_outputs = [
      nest.pack_sequence_as(result[0], nest.flatten(replica_output))
      for replica_output in replicate_outputs
  ]

  device_map = self.extended._device_map  # pylint: disable=protected-access
  return values.regroup(device_map, replicate_outputs)
def testSameId(self):
  foo = object()
  device_map = values.ReplicaDeviceMap((_device_str(0), _device_str(1)))
  result = values.regroup(device_map, (("a", foo), ("b", foo)))
  self.assertIsInstance(result, tuple)
  self.assertEqual(2, len(result))
  self._is_per_replica(result[0], ["a", "b"])
  self.assertIs(foo, result[1])

  # Test select_replica(), should undo the merge done by regroup().
  result_0 = values.select_replica(0, result)
  self.assertIsInstance(result_0, tuple)
  self.assertEqual(2, len(result_0))
  self.assertEqual("a", result_0[0])
  self.assertIs(foo, result_0[1])

  result_1 = values.select_replica(1, result)
  self.assertIsInstance(result_1, tuple)
  self.assertEqual(2, len(result_1))
  self.assertEqual("b", result_1[0])
  self.assertIs(foo, result_1[1])
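# A minimal sketch (not part of the original tests), assuming the same
# two-replica setup used above with hypothetical device strings: it
# illustrates the round trip the tests rely on -- `regroup()` merges one
# nested structure per replica into a single structure of per-replica
# values, and `select_replica()` recovers each replica's original structure.
def _regroup_round_trip_sketch():
  device_map = values.ReplicaDeviceMap(("/device:CPU:0", "/device:GPU:0"))
  merged = values.regroup(device_map,
                          (("a0", {"k": "b0"}), ("a1", {"k": "b1"})))
  # Per-replica leaves are merged into PerReplica values; select_replica()
  # undoes the merge for a given replica index.
  assert values.select_replica(0, merged) == ("a0", {"k": "b0"})
  assert values.select_replica(1, merged) == ("a1", {"k": "b1"})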
def _test_iterator(self,
                   input_type,
                   dataset_fn,
                   worker_device_pairs,
                   expected_values,
                   sess=None,
                   split_batch_by=None,
                   enable_get_next_as_optional=False):
  devices = nest.flatten([ds for _, ds in worker_device_pairs])
  iterator = self._create_iterator(
      input_type, dataset_fn, worker_device_pairs, devices, split_batch_by,
      enable_get_next_as_optional)

  evaluate = lambda x: sess.run(x) if sess else self.evaluate(x)
  evaluate(control_flow_ops.group(iterator.initialize()))

  for expected_value in expected_values:
    next_element = iterator.get_next()
    computed_value = evaluate(
        [values.select_replica(r, next_element) for r in range(len(devices))])
    self.assertEqual(len(expected_value), len(computed_value))
    for i in range(len(expected_value)):
      self.assertAllEqual(expected_value[i], computed_value[i])

  with self.assertRaises(errors.OutOfRangeError):
    next_element = iterator.get_next()
    evaluate(
        [values.select_replica(r, next_element) for r in range(len(devices))])

  # After re-initializing the iterator, should be able to iterate again.
  evaluate(control_flow_ops.group(iterator.initialize()))

  for expected_value in expected_values:
    next_element = iterator.get_next()
    computed_value = evaluate(
        [values.select_replica(r, next_element) for r in range(len(devices))])
    self.assertEqual(len(expected_value), len(computed_value))
    for i in range(len(expected_value)):
      self.assertAllEqual(expected_value[i], computed_value[i])
def _test_iterator(self,
                   input_type,
                   dataset_fn,
                   worker_device_pairs,
                   expected_values,
                   sess=None,
                   split_batch_by=None):
  devices = nest.flatten([ds for _, ds in worker_device_pairs])
  device_map = values.ReplicaDeviceMap(devices)
  input_workers = input_lib.InputWorkers(device_map, worker_device_pairs)

  if input_type == "input_fn":
    input_contexts = [
        distribute_lib.InputContext() for _ in worker_device_pairs]
    input_fn = lambda _: dataset_fn()
    iterator = input_lib.InputFunctionIterator(
        input_fn, input_workers, input_contexts)
  else:
    iterator = input_lib.DatasetIterator(
        dataset_fn(), input_workers, split_batch_by)

  evaluate = lambda x: sess.run(x) if sess else self.evaluate(x)
  evaluate(control_flow_ops.group(iterator.initialize()))

  for expected_value in expected_values:
    next_element = iterator.get_next()
    computed_value = evaluate(
        [values.select_replica(r, next_element) for r in range(len(devices))])
    self.assertAllEqual(expected_value, computed_value)

  with self.assertRaises(errors.OutOfRangeError):
    next_element = iterator.get_next()
    evaluate(
        [values.select_replica(r, next_element) for r in range(len(devices))])

  # After re-initializing the iterator, should be able to iterate again.
  evaluate(control_flow_ops.group(iterator.initialize()))

  for expected_value in expected_values:
    next_element = iterator.get_next()
    computed_value = evaluate(
        [values.select_replica(r, next_element) for r in range(len(devices))])
    self.assertAllEqual(expected_value, computed_value)
def testOneDevice(self):
  device_map = values.ReplicaDeviceMap((_device_str(0),))
  result = values.regroup(device_map, (_nested_value("1"),))
  # On one device regroup() and select_replica() are basically identity.
  self.assertEqual(_nested_value("1"), result)
  self.assertEqual(_nested_value("1"), values.select_replica(0, result))

  # The one exception has to do with MirroredVariables.
  d = "/device:CPU:0"
  with ops.device(d):
    v = variable_scope.get_variable(
        name="v", initializer=1., use_resource=True)
  device_map = values.ReplicaDeviceMap((d,))
  mirrored = values.MirroredVariable(None, device_map, (v,),
                                     variable_scope.VariableAggregation.SUM)
  result = values.regroup(device_map, (v,))
  self.assertIs(mirrored, result)
def experimental_run(self, fn, input_iterator=None):
  """See base class."""
  if context.executing_eagerly():
    raise NotImplementedError("Eager mode not supported in TPUStrategy.")

  if self.extended._disable_training_loop_on_host:  # pylint: disable=protected-access
    raise NotImplementedError(
        "`experimental_run` is not compatible with "
        "`_disable_training_loop_on_host=True`")

  if input_iterator is None:
    inputs = []
  else:
    inputs = input_iterator.get_next()

  result = [None]

  def replicated_fn(replica_id, inputs):
    """Wraps user function to provide replica ID and `Tensor` inputs."""
    with _TPUReplicaContext(self, replica_id_in_sync_group=replica_id):
      if input_iterator is None:
        result[0] = fn()
      else:
        result[0] = fn(inputs)
    return result[0]

  replicate_inputs = []  # By replica.
  for i in range(self.num_replicas_in_sync):
    replicate_inputs.append(
        [constant_op.constant(i, dtype=dtypes.int32),
         values.select_replica(i, inputs)])

  with self.scope():
    replicate_outputs = tpu.replicate(replicated_fn, replicate_inputs)

  # Workaround for `tpu.replicate` behaviour when single `Tensor` returned.
  replicate_outputs = [
      nest.pack_sequence_as(result[0], nest.flatten(replica_outputs))
      for replica_outputs in replicate_outputs]

  device_map = self.extended._device_map  # pylint: disable=protected-access
  return values.regroup(device_map, replicate_outputs)
def rewrite_fn(*args):
  """The rewritten step fn running on TPU."""
  del args
  per_replica_inputs = multi_worker_iterator.get_next()
  replicate_inputs = []
  for replica_id in range(self._num_replicas_in_sync):
    select_replica = lambda x: values.select_replica(replica_id, x)  # pylint: disable=cell-var-from-loop
    replicate_inputs.append((nest.map_structure(
        select_replica, per_replica_inputs),))

  replicate_outputs = tpu.replicate(run_fn, replicate_inputs)

  # If run_fn has tensor outputs, tpu.replicate returns a list of list. We
  # will flatten it in this case. If run_fn has no tensor outputs,
  # tpu.replicate returns a list of no_ops, we will keep the output as it
  # is.
  if isinstance(replicate_outputs[0], list):
    replicate_outputs = nest.flatten(replicate_outputs)

  return replicate_outputs
def tpu_function(args, kwargs):
  """TF Function used to replicate the user computation."""
  if kwargs is None:
    kwargs = {}

  # Used to re-structure flattened output tensors from `tpu.replicate()`
  # into a structured format.
  result = [[]]

  def replicated_fn(replica_id, replica_args, replica_kwargs):
    """Wraps user function to provide replica ID and `Tensor` inputs."""
    with _TPUReplicaContext(strategy, replica_id_in_sync_group=replica_id):
      result[0] = fn(*replica_args, **replica_kwargs)
    return result[0]

  replicate_inputs = []  # By replica.
  for i in range(strategy.num_replicas_in_sync):
    replicate_inputs.append([
        constant_op.constant(i, dtype=dtypes.int32),
        values.select_replica(i, args),
        values.select_replica(i, kwargs)
    ])

  # Construct and pass `maximum_shapes` so that we could support dynamic
  # shapes using dynamic padder.
  if replicate_inputs:
    maximum_shapes = []
    flattened_list = nest.flatten(replicate_inputs[0])
    for input_tensor in flattened_list:
      if tensor_util.is_tensor(input_tensor):
        maximum_shape = input_tensor.get_shape()
      else:
        maximum_shape = tensor_shape.TensorShape(np.shape(input_tensor))
      maximum_shapes.append(maximum_shape)
    maximum_shapes = nest.pack_sequence_as(replicate_inputs[0],
                                           maximum_shapes)
  else:
    maximum_shapes = None

  with strategy.scope():
    replicate_outputs = tpu.replicate(
        replicated_fn, replicate_inputs, maximum_shapes=maximum_shapes)

  # Remove all no ops that may have been added during 'tpu.replicate()'
  if isinstance(result[0], list):
    result[0] = [
        output for output in result[0] if tensor_util.is_tensor(output)
    ]

  # Workaround for `tpu.replicate` behaviour when single `Tensor` returned.
  if result[0] is None:
    replicate_outputs = [None] * len(replicate_outputs)
  else:
    replicate_outputs = [
        nest.pack_sequence_as(result[0], nest.flatten(replica_output))
        for replica_output in replicate_outputs
    ]

  device_map = self._device_map  # pylint: disable=protected-access
  return values.regroup(device_map, replicate_outputs)
def tpu_function(args, kwargs):
  """TF Function used to replicate the user computation."""
  if kwargs is None:
    kwargs = {}

  # Remove None at the end of args as they are not replicatable
  # If there are None in the middle we can't do anything about it
  # so let those cases fail.
  # For example when Keras model predict is used they pass the targets as
  # None. We want to handle it here so all client libraries don't have to
  # do this as other strategies can handle None values better.
  while args and args[-1] is None:
    args = args[:-1]

  # Used to re-structure flattened output tensors from `tpu.replicate()`
  # into a structured format.
  result = [[]]

  def replicated_fn(replica_id, replica_args, replica_kwargs):
    """Wraps user function to provide replica ID and `Tensor` inputs."""
    with _TPUReplicaContext(strategy, replica_id_in_sync_group=replica_id):
      result[0] = fn(*replica_args, **replica_kwargs)
    return result[0]

  replicate_inputs = []  # By replica.
  for i in range(strategy.num_replicas_in_sync):
    replicate_inputs.append([
        constant_op.constant(i, dtype=dtypes.int32),
        values.select_replica(i, args),
        values.select_replica(i, kwargs)
    ])

  # Construct and pass `maximum_shapes` so that we could support dynamic
  # shapes using dynamic padder.
  if self.experimental_enable_dynamic_batch_size and replicate_inputs:
    maximum_shapes = []
    flattened_list = nest.flatten(replicate_inputs[0])
    for input_tensor in flattened_list:
      if tensor_util.is_tensor(input_tensor):
        maximum_shape = input_tensor.get_shape()
      else:
        maximum_shape = tensor_shape.TensorShape(np.shape(input_tensor))
      maximum_shapes.append(maximum_shape)
    maximum_shapes = nest.pack_sequence_as(replicate_inputs[0],
                                           maximum_shapes)
  else:
    maximum_shapes = None

  with strategy.scope():
    replicate_outputs = tpu.replicate(
        replicated_fn,
        replicate_inputs,
        device_assignment=self._device_assignment,
        maximum_shapes=maximum_shapes)

  # Remove all no ops that may have been added during 'tpu.replicate()'
  if isinstance(result[0], list):
    result[0] = [
        output for output in result[0] if tensor_util.is_tensor(output)
    ]

  # Workaround for `tpu.replicate` behaviour when single `Tensor` returned.
  if result[0] is None:
    replicate_outputs = [None] * len(replicate_outputs)
  else:
    replicate_outputs = [
        nest.pack_sequence_as(result[0], nest.flatten(replica_output))
        for replica_output in replicate_outputs
    ]

  device_map = self._device_map  # pylint: disable=protected-access
  return values.regroup(device_map, replicate_outputs)
def _test_input_iteration(self,
                          input_type,
                          api_type,
                          iteration_type,
                          dataset_fn,
                          worker_device_pairs,
                          expected_values,
                          sess=None,
                          split_batch_by=None,
                          enable_get_next_as_optional=False):
  if iteration_type == "for_loop" and not context.executing_eagerly():
    self.skipTest("unsupported test combination.")

  if api_type == "wrap_into_iterator" and iteration_type == "for_loop":
    self.skipTest("unsupported test combination.")

  if api_type == "wrap_into_dataset" and input_type == "input_fn":
    self.skipTest("unsupported test combination.")

  devices = nest.flatten([ds for _, ds in worker_device_pairs])
  device_map = values.ReplicaDeviceMap(devices)
  input_workers = input_lib.InputWorkers(device_map, worker_device_pairs)

  if api_type == "wrap_into_iterator":
    iterator = self._wrap_iterator(
        input_type, dataset_fn, input_workers, devices, split_batch_by,
        enable_get_next_as_optional)
  else:  # wrapping into a dataset:
    given_dataset = dataset_fn(distribute_lib.InputContext())
    dataset = self._wrap_dataset(input_type, given_dataset, input_workers,
                                 split_batch_by,
                                 enable_get_next_as_optional)

    if context.executing_eagerly():
      iterator = iter(dataset)
    else:
      # The dataset can be a tf.data.DatasetV1Adapter instance since we wrap
      # tf.data.DatasetV1 as a tf.data.DatasetV1Adapter instance when we
      # autoshard the dataset.
      if not isinstance(dataset, (dataset_ops.DatasetV1,
                                  dataset_ops.DatasetV1Adapter)):
        iterator = iter(dataset)
      else:
        iterator = dataset.make_one_shot_iterator()

  if iteration_type == "get_next":
    evaluate = lambda x: sess.run(x) if sess else self.evaluate(x)
    if isinstance(iterator, input_lib.DistributedIteratorV1):
      evaluate(control_flow_ops.group(iterator.initialize()))
    else:
      evaluate(control_flow_ops.group(iterator._initializer))

    for expected_value in expected_values:
      next_element = iterator.get_next()
      computed_value = evaluate(
          [values.select_replica(r, next_element) for r in range(len(devices))])
      self.assertEqual(len(expected_value), len(computed_value))
      for i in range(len(expected_value)):
        self.assertAllEqual(expected_value[i], computed_value[i])

    with self.assertRaises(errors.OutOfRangeError):
      next_element = iterator.get_next()
      evaluate(
          [values.select_replica(r, next_element) for r in range(len(devices))])

    # After re-initializing the iterator, should be able to iterate again.
    if isinstance(iterator, input_lib.DistributedIteratorV1):
      evaluate(control_flow_ops.group(iterator.initialize()))
    else:
      evaluate(control_flow_ops.group(iterator._initializer))

    for expected_value in expected_values:
      next_element = iterator.get_next()
      computed_value = evaluate(
          [values.select_replica(r, next_element) for r in range(len(devices))])
      self.assertEqual(len(expected_value), len(computed_value))
      for i in range(len(expected_value)):
        self.assertAllEqual(expected_value[i], computed_value[i])

  if iteration_type == "for_loop" and context.executing_eagerly():
    actual_values = []
    for x in dataset:
      computed_value = self.evaluate(
          [values.select_replica(r, x) for r in range(len(devices))])
      actual_values.append(computed_value)
    for i, expected_value in enumerate(expected_values):
      self.assertEqual(len(expected_value), len(actual_values[i]))
      for j in range(len(expected_value)):
        self.assertAllEqual(expected_value[j], actual_values[i][j])
def _call_for_each_replica(distribution, devices, fn, args, kwargs):
  """Run `fn` in separate threads, once per replica/worker device.

  Args:
    distribution: the DistributionStrategy object.
    devices: the devices to run `fn` on (logical device 0 for each replica).
    fn: function to run (will be run once per replica, each in its own thread).
    args: positional arguments for `fn`
    kwargs: keyword arguments for `fn`.

  Returns:
    Merged return value of `fn` across all replicas.

  Raises:
    RuntimeError: If fn() calls get_replica_context().merge_call() a different
        number of times from the available devices.
  """
  # TODO(josh11b): Add this option once we add synchronization to variable
  # creation. Until then, this is pretty unsafe to use.
  run_concurrently = False
  if not context.executing_eagerly():
    # Needed for per-thread device, etc. contexts in graph mode.
    ops.get_default_graph().switch_to_thread_local()

  coord = coordinator.Coordinator(clean_stop_exception_types=(_RequestedStop,))

  shared_variable_store = {}

  # TODO(isaprykin): Create these threads once instead of during every call.
  threads = []
  for index in range(len(devices)):
    variable_creator_fn = shared_variable_creator.make_fn(
        shared_variable_store, index)
    t = _MirroredReplicaThread(
        distribution, coord, index, devices, variable_creator_fn, fn,
        values.select_replica(index, args),
        values.select_replica(index, kwargs))
    threads.append(t)

  for t in threads:
    t.start()

  # When `fn` starts `should_run` event is set on _MirroredReplicaThread
  # (`MRT`) threads. The execution waits until
  # `MRT.has_paused` is set, which indicates that either `fn` is
  # complete or a `get_replica_context().merge_call()` is called. If `fn` is
  # complete, then `MRT.done` is set to True. Otherwise, arguments
  # of `get_replica_context().merge_call` from all paused threads are grouped
  # and the `merge_fn` is performed. Results of the
  # `get_replica_context().merge_call` are then set to `MRT.merge_result`.
  # Each such `get_replica_context().merge_call` call returns the
  # `MRT.merge_result` for that thread when `MRT.should_run` event
  # is reset again. Execution of `fn` resumes.

  try:
    with coord.stop_on_exception():
      all_done = False
      while not all_done and not coord.should_stop():
        done = []
        if run_concurrently:
          for t in threads:
            t.should_run.set()
          for t in threads:
            t.has_paused.wait()
            t.has_paused.clear()
            if coord.should_stop():
              return None
            done.append(t.done)
        else:
          for t in threads:
            t.should_run.set()
            t.has_paused.wait()
            t.has_paused.clear()
            if coord.should_stop():
              return None
            done.append(t.done)
        if coord.should_stop():
          return None
        all_done = all(done)
        if not all_done:
          if any(done):
            raise RuntimeError("Some replicas made a different number of "
                               "replica_context().merge_call() calls.")
          # get_replica_context().merge_call() case
          merge_args = values.regroup(tuple(t.merge_args for t in threads))
          merge_kwargs = values.regroup(
              tuple(t.merge_kwargs for t in threads))
          # We capture the name_scope of the MRT when we call merge_fn
          # to ensure that if we have opened a name scope in the MRT,
          # it will be respected when executing the merge function. We only
          # capture the name_scope from the first MRT and assume it is
          # the same for all other MRTs.
          mtt_captured_name_scope = threads[0].captured_name_scope
          mtt_captured_var_scope = threads[0].captured_var_scope
          # Capture and merge the control dependencies from all the threads.
          mtt_captured_control_deps = set()
          for t in threads:
            mtt_captured_control_deps.update(t.captured_control_deps)
          with ops.name_scope(mtt_captured_name_scope),\
              ops.control_dependencies(mtt_captured_control_deps), \
              variable_scope.variable_scope(mtt_captured_var_scope):
            merge_result = threads[0].merge_fn(distribution, *merge_args,
                                               **merge_kwargs)
          for r, t in enumerate(threads):
            t.merge_result = values.select_replica(r, merge_result)
  finally:
    for t in threads:
      t.should_run.set()
    coord.join(threads)

  return values.regroup(tuple(t.main_result for t in threads))
def _test_input_iteration(self,
                          input_type,
                          api_type,
                          iteration_type,
                          dataset_or_input_fn,
                          worker_device_pairs,
                          expected_values,
                          strategy,
                          sess=None,
                          split_batch_by=None,
                          input_context=None):
  if iteration_type == "for_loop" and not context.executing_eagerly():
    self.skipTest("unsupported test combination.")

  if api_type == "wrap_into_iterator" and iteration_type == "for_loop":
    self.skipTest("unsupported test combination.")

  devices = nest.flatten([ds for _, ds in worker_device_pairs])
  input_workers = input_lib.InputWorkers(worker_device_pairs)

  if api_type == "wrap_into_iterator":
    iterator = self._wrap_iterator(
        input_type,
        dataset_or_input_fn,
        input_workers,
        devices,
        split_batch_by,
        strategy,
        input_context=input_context)
  else:  # wrapping into a dataset:
    dataset = self._wrap_dataset(
        input_type,
        dataset_or_input_fn,
        input_workers,
        split_batch_by,
        strategy,
        input_context=input_context)

    if context.executing_eagerly():
      iterator = iter(dataset)
    else:
      if isinstance(dataset, input_lib.DistributedDatasetV1):
        iterator = dataset.make_initializable_iterator()
      else:
        self.skipTest("unsupported test combination")

  if iteration_type == "get_next":
    evaluate = lambda x: sess.run(x) if sess else self.evaluate(x)
    if isinstance(iterator, input_lib.DistributedIteratorV1):
      evaluate(control_flow_ops.group(iterator.initializer))

    for expected_value in expected_values:
      next_element = iterator.get_next()
      computed_value = evaluate([
          values.select_replica(r, next_element) for r in range(len(devices))
      ])
      self.assertEqual(len(expected_value), len(computed_value))
      for i in range(len(expected_value)):
        self.assertAllEqual(expected_value[i], computed_value[i])

    with self.assertRaises(errors.OutOfRangeError):
      next_element = iterator.get_next()
      evaluate([
          values.select_replica(r, next_element) for r in range(len(devices))
      ])

    # After re-initializing the iterator, should be able to iterate again.
    if isinstance(iterator, input_lib.DistributedIteratorV1):
      evaluate(control_flow_ops.group(iterator.initializer))
    else:
      evaluate(control_flow_ops.group(iterator._initializer))

    for expected_value in expected_values:
      next_element = iterator.get_next()
      computed_value = evaluate([
          values.select_replica(r, next_element) for r in range(len(devices))
      ])
      self.assertEqual(len(expected_value), len(computed_value))
      for i in range(len(expected_value)):
        self.assertAllEqual(expected_value[i], computed_value[i])

  if iteration_type == "for_loop" and context.executing_eagerly():
    actual_values = []
    for x in dataset:
      computed_value = self.evaluate(
          [values.select_replica(r, x) for r in range(len(devices))])
      actual_values.append(computed_value)
    for i, expected_value in enumerate(expected_values):
      self.assertEqual(len(expected_value), len(actual_values[i]))
      for j in range(len(expected_value)):
        self.assertAllEqual(expected_value[j], actual_values[i][j])
def testRaggedSparse(self, distribution, input_type, drop_remainder, defun):
  """Test with `RaggedTensor`s and `SparseTensor`s."""
  if not tf2.enabled():
    self.skipTest("Only V2 is supported.")

  distribution.extended.experimental_enable_get_next_as_optional = True
  global_batch_size = 8

  def dataset_fn(ctx=None):
    ctx = ctx or distribute_lib.InputContext()
    batch_size = ctx.get_per_replica_batch_size(global_batch_size)
    # Use 20 which isn't divisible by 8 to test partial batch behavior.
    row_lengths = np.mod(np.arange(20), 4).astype(np.int64)
    ragged_tensor = ragged_tensor_lib.RaggedTensor.from_row_lengths(
        np.repeat(np.arange(20, dtype=np.float32), row_lengths), row_lengths)
    dataset = dataset_ops.DatasetV2.from_tensor_slices({
        "dense": ragged_tensor.to_tensor(),
        "ragged": ragged_tensor,
        "sparse": ragged_tensor.to_sparse(),
    })
    dataset = dataset.shard(ctx.num_input_pipelines, ctx.input_pipeline_id)
    return dataset.batch(batch_size, drop_remainder=drop_remainder)

  dataset_or_input_fn = self._create_dataset_or_input_fn(
      input_type, dataset_fn)
  dataset = self._wrap_dataset(input_type, dataset_or_input_fn,
                               distribution.extended._input_workers,
                               len(distribution.extended.worker_devices),
                               distribution)
  # Assert that the tensors are rebatched and sparsity is preserved.
  per_replica_batch = defun(lambda x: next(iter(x)))(dataset)
  self.assertAllEqual(
      values.select_replica(0, per_replica_batch["dense"]),
      [[0., 0., 0.], [1., 0., 0.], [2., 2., 0.], [3., 3., 3.]])
  self.assertAllEqual(
      values.select_replica(1, per_replica_batch["dense"]),
      [[0., 0., 0.], [5., 0., 0.], [6., 6., 0.], [7., 7., 7.]])
  # Transitively check the ragged and sparse tensors by densification.
  for i in range(2):
    self.assertLen(
        values.select_replica(i, per_replica_batch["ragged"]).values, 6)
    self.assertAllEqual(
        values.select_replica(i, per_replica_batch["ragged"]).to_tensor(),
        values.select_replica(i, per_replica_batch["dense"]))
    self.assertLen(
        values.select_replica(i, per_replica_batch["sparse"]).indices, 6)
    self.assertAllEqual(
        sparse_ops.sparse_tensor_to_dense(
            values.select_replica(i, per_replica_batch["sparse"])),
        values.select_replica(i, per_replica_batch["dense"]))

  # Iterate through all the batches and sum them up.
  def sum_batch(per_replica_features):
    """Sums the `PerReplica` values in the `per_replica_features` map."""

    def map_fn(per_replica_values):
      per_replica_sums = distribution.experimental_run_v2(
          (lambda x: math_ops.reduce_sum(x.values)) if all(
              map(sparse_tensor.is_sparse, per_replica_values.values)) else
          math_ops.reduce_sum, (per_replica_values,))
      return distribution.reduce(
          reduce_util.ReduceOp.SUM, per_replica_sums, axis=None)

    return nest.map_structure(map_fn, per_replica_features)

  def _reduce(state, batch):
    sums = sum_batch(batch)
    return {name: value + sums[name] for name, value in state.items()}

  def sum_for_loop(dataset):
    sums = {"dense": 0., "ragged": 0., "sparse": 0.}
    for batch in dataset:
      sums = _reduce(sums, batch)
    return sums

  def sum_while_loop(iterator, reduce_fn):
    sums = {"dense": 0., "ragged": 0., "sparse": 0.}
    while True:
      try:
        sums = reduce_fn(sums, iterator)
      except (StopIteration, errors.OutOfRangeError):
        return sums

  sums = sum_while_loop(
      iter(dataset),
      defun(lambda state, iterator: _reduce(state, next(iterator))))
  self.assertDictEqual(sums, defun(sum_for_loop)(dataset))
  self.assertAllEqual(
      nest.flatten(sums),
      # When there's no partial batch, the sum is smaller.
      [200. if input_type == "dataset" and drop_remainder else 310.] * 3)
def _call_for_each_replica(distribution, device_map, fn, args, kwargs):
  """Run `fn` in separate threads, once per replica/worker device.

  Args:
    distribution: the DistributionStrategy object.
    device_map: the DeviceMap with the devices to run `fn` on.
    fn: function to run (will be run once per replica, each in its own thread).
    args: positional arguments for `fn`
    kwargs: keyword arguments for `fn`.

  Returns:
    Merged return value of `fn` across all replicas.

  Raises:
    RuntimeError: If fn() calls get_replica_context().merge_call() a different
        number of times from the available devices.
  """
  # TODO(josh11b): Add this option once we add synchronization to variable
  # creation. Until then, this is pretty unsafe to use.
  run_concurrently = False
  if not context.executing_eagerly():
    # Needed for per-thread device, etc. contexts in graph mode.
    ops.get_default_graph().switch_to_thread_local()

  coord = coordinator.Coordinator(clean_stop_exception_types=(_RequestedStop,))

  shared_variable_store = {}

  # TODO(isaprykin): Create these threads once instead of during every call.
  threads = []
  for index in range(device_map.num_replicas_in_graph):
    variable_creator_fn = shared_variable_creator.make_fn(
        shared_variable_store, index)
    t = _MirroredReplicaThread(
        distribution, coord, index, device_map, variable_creator_fn, fn,
        values.select_replica(index, args),
        values.select_replica(index, kwargs))
    threads.append(t)

  for t in threads:
    t.start()

  # When `fn` starts `should_run` event is set on _MirroredReplicaThread
  # (`MRT`) threads. The execution waits until
  # `MRT.has_paused` is set, which indicates that either `fn` is
  # complete or a `get_replica_context().merge_call()` is called. If `fn` is
  # complete, then `MRT.done` is set to True. Otherwise, arguments
  # of `get_replica_context().merge_call` from all paused threads are grouped
  # and the `merge_fn` is performed. Results of the
  # `get_replica_context().merge_call` are then set to `MRT.merge_result`.
  # Each such `get_replica_context().merge_call` call returns the
  # `MRT.merge_result` for that thread when `MRT.should_run` event
  # is reset again. Execution of `fn` resumes.

  try:
    with coord.stop_on_exception():
      all_done = False
      while not all_done and not coord.should_stop():
        done = []
        if run_concurrently:
          for t in threads:
            t.should_run.set()
          for t in threads:
            t.has_paused.wait()
            t.has_paused.clear()
            if coord.should_stop():
              return None
            done.append(t.done)
        else:
          for t in threads:
            t.should_run.set()
            t.has_paused.wait()
            t.has_paused.clear()
            if coord.should_stop():
              return None
            done.append(t.done)
        if coord.should_stop():
          return None
        all_done = all(done)
        if not all_done:
          if any(done):
            raise RuntimeError("Some replicas made a different number of "
                               "replica_context().merge_call() calls.")
          # get_replica_context().merge_call() case
          merge_args = values.regroup(
              device_map, tuple(t.merge_args for t in threads))
          merge_kwargs = values.regroup(
              device_map, tuple(t.merge_kwargs for t in threads))
          # We capture the name_scope of the MRT when we call merge_fn
          # to ensure that if we have opened a name scope in the MRT,
          # it will be respected when executing the merge function. We only
          # capture the name_scope from the first MRT and assume it is
          # the same for all other MRTs.
          mtt_captured_name_scope = threads[0].captured_name_scope
          # Capture and merge the control dependencies from all the threads.
          mtt_captured_control_deps = set()
          for t in threads:
            mtt_captured_control_deps.update(t.captured_control_deps)
          with ops.name_scope(mtt_captured_name_scope),\
              ops.control_dependencies(mtt_captured_control_deps):
            merge_result = threads[0].merge_fn(distribution, *merge_args,
                                               **merge_kwargs)
          for r, t in enumerate(threads):
            t.merge_result = values.select_replica(r, merge_result)
  finally:
    for t in threads:
      t.should_run.set()
    coord.join(threads)

  return values.regroup(device_map, tuple(t.main_result for t in threads))