def testNested(self):
  # regroup() on two nested per-device values must keep the nesting
  # (tuple -> list -> dict) and merge only the leaves.
  per_device_input = {
      _device_str(0): _nested_value("1"),
      _device_str(1): _nested_value("2"),
  }
  result = values.regroup(per_device_input)

  # Outer level: a 3-tuple whose first and last entries are merged leaves.
  self.assertIsInstance(result, tuple)
  self.assertEqual(3, len(result))
  self._is_per_device(result[0], ["a1", "a2"])
  self._is_per_device(result[2], ["h1", "h2"])

  # Middle level: a 3-element list with a dict in the centre.
  middle = result[1]
  self.assertIsInstance(middle, list)
  self.assertEqual(3, len(middle))
  self._is_per_device(middle[0], ["b1", "b2"])
  self._is_per_device(middle[2], ["g1", "g2"])

  # Innermost level: a dict keyed by "c" and "e".
  inner = middle[1]
  self.assertIsInstance(inner, dict)
  self.assertEqual({"c", "e"}, set(inner.keys()))
  self._is_per_device(inner["c"], ["d1", "d2"])
  self._is_per_device(inner["e"], ["f1", "f2"])

  # Also test that we can undo the merge using select_device()
  for index, suffix in enumerate(("1", "2")):
    self.assertEqual(_nested_value(suffix),
                     values.select_device(_device_str(index), result))

  # select_device_mirrored() should fail due to non-mirrored values
  for index in range(2):
    with self.assertRaises(TypeError):
      values.select_device_mirrored(_device_str(index), result)
def _update(self, var, fn, *args, **kwargs):
  """Applies `fn` to `var`, device by device unless a TPU context is active.

  Args:
    var: The `values.TPUMirroredVariable` to update.
    fn: Update callable. It receives the variable (or one per-device
      component of it) followed by the remaining arguments.
    *args: Positional arguments forwarded to `fn`.
    **kwargs: Keyword arguments forwarded to `fn`.

  Returns:
    The direct result of `fn` when a TPU context encloses the call, otherwise
    the per-device results regrouped as `values.Mirrored`.
  """
  # TODO(jhseu): Consider supporting grouped==False.
  assert isinstance(var, values.TPUMirroredVariable)
  tpu_context = values._enclosing_tpu_context()  # pylint: disable=protected-access
  if tpu_context is not None:
    return fn(var, *args, **kwargs)

  # Otherwise, we revert to MirroredStrategy behavior and update each variable
  # directly.
  per_device_updates = {}
  for device, component in var._index.items():  # pylint: disable=protected-access
    scope_name = "update_%d" % self._device_index.get(device)
    with ops.device(device), distribute_lib.UpdateContext(
        device), ops.name_scope(scope_name):
      # If args and kwargs are not mirrored, the value is returned as is.
      device_args = values.select_device_mirrored(device, args)
      device_kwargs = values.select_device_mirrored(device, kwargs)
      per_device_updates[device] = fn(component, *device_args,
                                      **device_kwargs)

  # Make a single control dependency to keep the variables mirrored. If one
  # assignment is fetched, then run all assignments.
  ordered_devices = sorted(per_device_updates)
  grouped = control_flow_ops.tuple(
      [per_device_updates[key] for key in ordered_devices])
  for device, grouped_update in zip(ordered_devices, grouped):
    per_device_updates[device] = grouped_update
  return values.regroup(per_device_updates, values.Mirrored)
def testWrapClass(self):
  # Normally a mirrored value would be the same across devices, but
  # for a test it is convenient to be able to tell the values apart.
  per_device_input = {
      _device_str(0): _nested_value("1"),
      _device_str(1): _nested_value("2"),
  }
  result = values.regroup(per_device_input, values.Mirrored)

  # Outer tuple.
  self.assertIsInstance(result, tuple)
  self.assertEqual(3, len(result))
  self._is_per_device(result[0], ["a1", "a2"], values.Mirrored)
  self._is_per_device(result[2], ["h1", "h2"], values.Mirrored)

  # List nested in the tuple.
  middle = result[1]
  self.assertIsInstance(middle, list)
  self.assertEqual(3, len(middle))
  self._is_per_device(middle[0], ["b1", "b2"], values.Mirrored)
  self._is_per_device(middle[2], ["g1", "g2"], values.Mirrored)

  # Dict nested in the list.
  inner = middle[1]
  self.assertIsInstance(inner, dict)
  self.assertEqual({"c", "e"}, set(inner.keys()))
  self._is_per_device(inner["c"], ["d1", "d2"], values.Mirrored)
  self._is_per_device(inner["e"], ["f1", "f2"], values.Mirrored)

  # Also test that we can undo the merge using select_device()
  for index, suffix in enumerate(("1", "2")):
    self.assertEqual(_nested_value(suffix),
                     values.select_device(_device_str(index), result))

  # Values are marked as mirrored, so select_device_mirrored() is allowed.
  for index, suffix in enumerate(("1", "2")):
    self.assertEqual(
        _nested_value(suffix),
        values.select_device_mirrored(_device_str(index), result))
def testMirroredContainer(self):
  # regroup() on the components of a mirrored container must hand back the
  # very same container object.
  no_gpu = context.num_gpus() < 1
  if no_gpu and context.executing_eagerly():
    self.skipTest("A GPU is not available for this test in eager mode.")

  v, devices, mirrored = _make_mirrored()
  per_device = dict(zip(devices, v))
  self.assertIs(mirrored, values.regroup(per_device))
def testNested(self):
  # Merges two nested values, one per device, then checks every level of the
  # merged structure and the inverse operation.
  device_0 = _device_str(0)
  first_value = _nested_value("1")
  device_1 = _device_str(1)
  second_value = _nested_value("2")
  result = values.regroup({device_0: first_value, device_1: second_value})

  self.assertIsInstance(result, tuple)
  self.assertEqual(3, len(result))
  head, body, tail = result[0], result[1], result[2]
  self._is_per_device(head, ["a1", "a2"])
  self._is_per_device(tail, ["h1", "h2"])

  self.assertIsInstance(body, list)
  self.assertEqual(3, len(body))
  self._is_per_device(body[0], ["b1", "b2"])
  self._is_per_device(body[2], ["g1", "g2"])

  mapping = body[1]
  self.assertIsInstance(mapping, dict)
  self.assertEqual({"c", "e"}, set(mapping.keys()))
  self._is_per_device(mapping["c"], ["d1", "d2"])
  self._is_per_device(mapping["e"], ["f1", "f2"])

  # Also test that we can undo the merge using select_device()
  self.assertEqual(_nested_value("1"),
                   values.select_device(_device_str(0), result))
  self.assertEqual(_nested_value("2"),
                   values.select_device(_device_str(1), result))

  # select_device_mirrored() should fail due to non-mirrored values
  for device_index in (0, 1):
    with self.assertRaises(TypeError):
      values.select_device_mirrored(_device_str(device_index), result)
def testNamedTupleEstimatorSpec(self):
  # Regrouping one `EstimatorSpec` per device must yield a single
  # `EstimatorSpec` whose fields hold the per-device components, and
  # select_device() must recover each original spec.
  with context.graph_mode(), ops.Graph().as_default():
    created_estimator_specs = {}
    to_regroup = {}

    for device_id in range(3):
      spec = model_fn_lib.EstimatorSpec(
          mode=model_fn_lib.ModeKeys.TRAIN,
          loss=constant_op.constant(device_id / 2),
          train_op=array_ops.identity(constant_op.constant(device_id)))
      created_estimator_specs[device_id] = spec
      to_regroup[_device_str(device_id)] = spec

    merged_estimator_spec = values.regroup(to_regroup)

    # assertIsInstance/assertEqual replace assertTrue(isinstance(...)) and
    # the deprecated assertEquals alias (removed in Python 3.12).
    self.assertIsInstance(merged_estimator_spec, model_fn_lib.EstimatorSpec)
    self.assertEqual(model_fn_lib.ModeKeys.TRAIN, merged_estimator_spec.mode)
    for device_id in range(3):
      d = _device_str(device_id)
      expected_spec = created_estimator_specs[device_id]
      self.assertEqual(expected_spec.loss, merged_estimator_spec.loss.get(d))
      self.assertEqual(expected_spec.train_op,
                       merged_estimator_spec.train_op.get(d))
      # Scaffold is populated by `EstimatorSpec.__new__`.
      self.assertEqual(expected_spec.scaffold,
                       merged_estimator_spec.scaffold.get(d))
      # Also test that we can undo the merge using select_device()
      self.assertEqual(
          expected_spec,
          values.select_device(_device_str(device_id), merged_estimator_spec))
def testOneDevice(self):
  only_device = _device_str(0)
  result = values.regroup({only_device: _nested_value("1")})

  # On one device regroup() and select_device() are basically identity.
  self.assertEqual(_nested_value("1"), result)
  self.assertEqual(_nested_value("1"),
                   values.select_device(_device_str(0), result))

  # The one exception has to do with MirroredVariables.
  d = "/device:CPU:0"
  with ops.device(d):
    v = variable_scope.get_variable(
        name="v", initializer=1., use_resource=True)
    index = {d: v}
  mirrored = values.MirroredVariable(index, v)
  regrouped = values.regroup(index)
  self.assertIs(mirrored, regrouped)
def testNamedTupleEstimatorSpec(self):
  # Builds one `EstimatorSpec` per device, regroups them, and checks that the
  # result is still an `EstimatorSpec` whose fields can be queried per device.
  with context.graph_mode(), ops.Graph().as_default():
    created_estimator_specs = {}
    to_regroup = {}

    for device_id in range(3):
      loss = constant_op.constant(device_id / 2)
      train_op = array_ops.identity(constant_op.constant(device_id))
      spec = model_fn_lib.EstimatorSpec(
          mode=model_fn_lib.ModeKeys.TRAIN, loss=loss, train_op=train_op)
      created_estimator_specs[device_id] = spec
      to_regroup[_device_str(device_id)] = spec

    merged_estimator_spec = values.regroup(to_regroup)

    # Use the non-deprecated unittest assertions: `assertEquals` is an alias
    # that was removed in Python 3.12, and assertIsInstance reports the
    # offending type on failure.
    self.assertIsInstance(merged_estimator_spec, model_fn_lib.EstimatorSpec)
    self.assertEqual(model_fn_lib.ModeKeys.TRAIN, merged_estimator_spec.mode)
    for device_id in range(3):
      d = _device_str(device_id)
      original = created_estimator_specs[device_id]
      self.assertEqual(original.loss, merged_estimator_spec.loss.get(d))
      self.assertEqual(original.train_op,
                       merged_estimator_spec.train_op.get(d))
      # Scaffold is populated by `EstimatorSpec.__new__`.
      self.assertEqual(original.scaffold,
                       merged_estimator_spec.scaffold.get(d))
      # Also test that we can undo the merge using select_device()
      self.assertEqual(
          original,
          values.select_device(_device_str(device_id), merged_estimator_spec))
def testWrapClass(self):
  # Normally a mirrored value would be the same across devices, but
  # for a test it is convenient to be able to tell the values apart.
  wrap_class = values.Mirrored
  result = values.regroup(
      {
          _device_str(0): _nested_value("1"),
          _device_str(1): _nested_value("2"),
      },
      wrap_class)

  self.assertIsInstance(result, tuple)
  self.assertEqual(3, len(result))
  head, body, tail = result[0], result[1], result[2]
  self._is_per_device(head, ["a1", "a2"], wrap_class)
  self._is_per_device(tail, ["h1", "h2"], wrap_class)

  self.assertIsInstance(body, list)
  self.assertEqual(3, len(body))
  self._is_per_device(body[0], ["b1", "b2"], wrap_class)
  self._is_per_device(body[2], ["g1", "g2"], wrap_class)

  mapping = body[1]
  self.assertIsInstance(mapping, dict)
  self.assertEqual({"c", "e"}, set(mapping.keys()))
  self._is_per_device(mapping["c"], ["d1", "d2"], wrap_class)
  self._is_per_device(mapping["e"], ["f1", "f2"], wrap_class)

  # Also test that we can undo the merge using select_device()
  self.assertEqual(_nested_value("1"),
                   values.select_device(_device_str(0), result))
  self.assertEqual(_nested_value("2"),
                   values.select_device(_device_str(1), result))

  # Values are marked as mirrored, so select_device_mirrored() is allowed.
  self.assertEqual(_nested_value("1"),
                   values.select_device_mirrored(_device_str(0), result))
  self.assertEqual(_nested_value("2"),
                   values.select_device_mirrored(_device_str(1), result))
def _run_steps_on_dataset(self, fn, iterator, iterations,
                          initial_loop_values=None):
  """Builds a while loop that calls `fn` on `iterations` iterator elements.

  Args:
    fn: Step function, called as `fn(ctx, *inputs)` where `ctx` is the
      `values.MultiStepContext` created here and `inputs` is one element of
      `iterator` (wrapped in a 1-tuple if it is not already a tuple).
    iterator: Object whose `get_next()` yields the inputs for each step.
    iterations: Loop bound; the loop runs while the counter is below it.
    initial_loop_values: Optional nested structure of initial loop variables.
      It is flattened and passed to the loop after the counter. Defaults to
      an empty structure.

  Returns:
    The `values.MultiStepContext`, with `run_op` set to a group of the loop
    results and the last-step outputs repacked into their original structure.
  """
  if initial_loop_values is None:
    initial_loop_values = {}
  initial_loop_values = nest.flatten(initial_loop_values)

  ctx = values.MultiStepContext()

  def body(i, *args):
    """A wrapper around `fn` to create the while loop body."""
    del args
    fn_inputs = iterator.get_next()
    if not isinstance(fn_inputs, tuple):
      fn_inputs = (fn_inputs,)
    fn_result = fn(ctx, *fn_inputs)
    for (name, output) in ctx.last_step_outputs.items():
      # Convert all outputs to tensors, potentially from `DistributedValues`.
      ctx.last_step_outputs[name] = self.unwrap(output)
    flat_last_step_outputs = nest.flatten(ctx.last_step_outputs)
    with ops.control_dependencies([fn_result]):
      return [i + 1] + flat_last_step_outputs

  # We capture the control_flow_context at this point, before we run `fn`
  # inside a while_loop. This is useful in cases where we might need to exit
  # these contexts and get back to the outer context to do some things, for
  # e.g. create an op which should be evaluated only once at the end of the
  # loop on the host. One such usage is in creating metrics' value op.
  self._outer_control_flow_context = (
      ops.get_default_graph()._get_control_flow_context())  # pylint: disable=protected-access
  try:
    cond = lambda i, *args: i < iterations
    i = constant_op.constant(0)
    loop_result = control_flow_ops.while_loop(
        cond,
        body, [i] + initial_loop_values,
        name="",
        parallel_iterations=1,
        back_prop=False,
        swap_memory=False,
        return_same_structure=True)
  finally:
    # Always drop the captured context, even if building the loop raised, so
    # a stale context never lingers on the strategy object.
    del self._outer_control_flow_context

  ctx.run_op = control_flow_ops.group(loop_result)

  # Convert the last_step_outputs from a list to the original dict structure
  # of last_step_outputs.
  last_step_tensor_outputs = loop_result[1:]
  last_step_tensor_outputs_dict = nest.pack_sequence_as(
      ctx.last_step_outputs, last_step_tensor_outputs)

  for (name, aggregation) in ctx._last_step_outputs_aggregations.items():  # pylint: disable=protected-access
    output = last_step_tensor_outputs_dict[name]
    # Outputs recorded with aggregation NONE keep one value per device and
    # are regrouped into a PerDevice container; any other aggregation must
    # have produced exactly one value, which is used directly.
    if aggregation is variables_lib.VariableAggregation.NONE:
      last_step_tensor_outputs_dict[name] = values.regroup(
          dict(zip(self._devices, output)), values.PerDevice)
    else:
      assert len(output) == 1
      last_step_tensor_outputs_dict[name] = output[0]

  ctx._set_last_step_outputs(last_step_tensor_outputs_dict)  # pylint: disable=protected-access
  return ctx
def _update_non_slot(self, colocate_with, fn, *args, **kwargs):
  """Runs `fn` once under each device listed in `colocate_with`.

  Args:
    colocate_with: List of devices to run `fn` on.
    fn: Callable invoked with the per-device selection of `args`/`kwargs`.
    *args: Positional arguments; passed through
      `values.select_device_mirrored` for each device.
    **kwargs: Keyword arguments; passed through
      `values.select_device_mirrored` for each device.

  Returns:
    The per-device results of `fn`, regrouped as `values.Mirrored`.
  """
  assert isinstance(colocate_with, list)
  # TODO(josh11b): In eager mode, use one thread per device.
  results = {}
  for device in colocate_with:
    scope_name = "update_%d" % self._device_index.get(device)
    with ops.device(device), distribute_lib.UpdateContext(
        device), ops.name_scope(scope_name):
      device_args = values.select_device_mirrored(device, args)
      device_kwargs = values.select_device_mirrored(device, kwargs)
      results[device] = fn(*device_args, **device_kwargs)
  return values.regroup(results, values.Mirrored)
def _run_steps_on_dataset(self, fn, iterator, iterations,
                          initial_loop_values=None):
  """Builds a while loop that calls `fn` on `iterations` iterator elements.

  Args:
    fn: Step function, called as `fn(ctx, inputs)` with the
      `values.MultiStepContext` created here and one `iterator` element.
    iterator: Object whose `get_next()` yields the input for each step.
    iterations: Loop bound; the loop runs while the counter is below it.
    initial_loop_values: Optional nested structure of initial loop variables,
      flattened and passed to the loop after the counter.

  Returns:
    The `values.MultiStepContext`, with `run_op` and the last-step outputs
    populated.
  """
  flat_initial_values = nest.flatten(
      {} if initial_loop_values is None else initial_loop_values)

  ctx = values.MultiStepContext()

  def body(i, *args):
    """A wrapper around `fn` to create the while loop body."""
    del args
    step_result = fn(ctx, iterator.get_next())
    for (name, output) in ctx.last_step_outputs.items():
      # Convert all outputs to tensors, potentially from `DistributedValues`.
      ctx.last_step_outputs[name] = self.unwrap(output)
    flat_outputs = nest.flatten(ctx.last_step_outputs)
    with ops.control_dependencies([step_result]):
      return [i + 1] + flat_outputs

  cond = lambda i, *args: i < iterations
  counter = constant_op.constant(0)
  loop_result = control_flow_ops.while_loop(
      cond,
      body,
      [counter] + flat_initial_values,
      name="",
      parallel_iterations=1,
      back_prop=False,
      swap_memory=False,
      return_same_structure=True)

  ctx.run_op = control_flow_ops.group(loop_result)

  # Everything after the counter is a last-step output; restore the original
  # dict structure of last_step_outputs.
  packed_outputs = nest.pack_sequence_as(ctx.last_step_outputs,
                                         loop_result[1:])

  for (name, aggregation) in ctx._last_step_outputs_aggregations.items():  # pylint: disable=protected-access
    output = packed_outputs[name]
    if aggregation is not variables_lib.VariableAggregation.NONE:
      # An aggregated output consists of exactly one value; use it directly.
      assert len(output) == 1
      packed_outputs[name] = output[0]
    else:
      # A non-aggregated output has one value per device.
      packed_outputs[name] = values.regroup(
          dict(zip(self._devices, output)), values.PerDevice)

  ctx._set_last_step_outputs(packed_outputs)  # pylint: disable=protected-access
  return ctx
def _update(self, var, fn, *args, **kwargs):
  """Applies `fn` to each per-device component of `var`.

  Args:
    var: The `values.DistributedVariable` to update.
    fn: Update callable; receives one component of `var` followed by the
      per-device selection of `args` and `kwargs`.
    *args: Positional arguments for `fn`.
    **kwargs: Keyword arguments for `fn`.

  Returns:
    The per-device results of `fn`, regrouped as `values.Mirrored`.
  """
  # TODO(josh11b): In eager mode, use one thread per device.
  assert isinstance(var, values.DistributedVariable)
  results = {}
  for device, component in var._index.items():  # pylint: disable=protected-access
    scope_name = "update_%d" % self._device_index.get(device)
    with ops.device(device), distribute_lib.UpdateContext(
        device), ops.name_scope(scope_name):
      # If args and kwargs are not mirrored, the value is returned as is.
      device_args = values.select_device_mirrored(device, args)
      device_kwargs = values.select_device_mirrored(device, kwargs)
      results[device] = fn(component, *device_args, **device_kwargs)
  return values.regroup(results, values.Mirrored)
def _update(self, var, fn, *args, **kwargs):
  """Applies `fn` to each per-device component of a mirrored variable.

  Args:
    var: The `values.MirroredVariable` to update.
    fn: Update callable; receives one component of `var` followed by the
      per-device selection of `args` and `kwargs`.
    *args: Positional arguments for `fn`.
    **kwargs: Keyword arguments for `fn`.

  Returns:
    The per-device results of `fn`, regrouped as `values.Mirrored`.
  """
  # TODO(josh11b): Also support TowerLocalVariables here? If so, args and
  # kwargs don't need to be mirrored.
  assert isinstance(var, values.MirroredVariable)
  # TODO(josh11b): In eager mode, use one thread per device.
  results = {}
  for device, component in var._index.items():  # pylint: disable=protected-access
    scope_name = "update_%d" % self._device_index.get(device)
    with ops.device(device), distribute_lib.UpdateContext(
        device), ops.name_scope(scope_name):
      device_args = values.select_device_mirrored(device, args)
      device_kwargs = values.select_device_mirrored(device, kwargs)
      results[device] = fn(component, *device_args, **device_kwargs)
  return values.regroup(results, values.Mirrored)
def _update(self, var, fn, *args, **kwargs):
  """Runs `fn` on every component of `var` under that component's device.

  Args:
    var: The `values.MirroredVariable` whose components are updated.
    fn: Callable taking a component of `var` plus the per-device selection of
      `args` and `kwargs`.
    *args: Positional arguments for `fn`.
    **kwargs: Keyword arguments for `fn`.

  Returns:
    A `values.Mirrored` regrouping of the results, one per device.
  """
  # TODO(josh11b): Also support TowerLocalVariables here? If so, args and
  # kwargs don't need to be mirrored.
  assert isinstance(var, values.MirroredVariable)
  # TODO(josh11b): In eager mode, use one thread per device.
  update_by_device = {}
  for dev, component in var._index.items():  # pylint: disable=protected-access
    update_name = "update_%d" % self._device_index.get(dev)
    with ops.device(dev), distribute_lib.UpdateContext(dev), ops.name_scope(
        update_name):
      selected_args = values.select_device_mirrored(dev, args)
      selected_kwargs = values.select_device_mirrored(dev, kwargs)
      update_by_device[dev] = fn(component, *selected_args,
                                 **selected_kwargs)
  return values.regroup(update_by_device, values.Mirrored)
def _run_steps_on_dataset(self, fn, iterator, iterations,
                          initial_loop_values=None):
  """Builds a while loop that calls `fn` on `iterations` iterator elements.

  Args:
    fn: Step function, called as `fn(ctx, *inputs)` where `ctx` is the
      `values.MultiStepContext` created here and `inputs` is one element of
      `iterator` (wrapped in a 1-tuple if it is not already a tuple).
    iterator: Object whose `get_next()` yields the inputs for each step.
    iterations: Loop bound; the loop runs while the counter is below it.
    initial_loop_values: Optional nested structure of initial loop variables,
      flattened and passed to the loop after the counter.

  Returns:
    The `values.MultiStepContext`, with `run_op` and the last-step outputs
    populated.
  """
  flat_initial_values = nest.flatten(
      {} if initial_loop_values is None else initial_loop_values)

  ctx = values.MultiStepContext()

  def body(i, *args):
    """A wrapper around `fn` to create the while loop body."""
    del args
    step_inputs = iterator.get_next()
    if not isinstance(step_inputs, tuple):
      step_inputs = (step_inputs,)
    step_result = fn(ctx, *step_inputs)
    for (name, output) in ctx.last_step_outputs.items():
      # Convert all outputs to tensors, potentially from `DistributedValues`.
      ctx.last_step_outputs[name] = self.unwrap(output)
    flat_outputs = nest.flatten(ctx.last_step_outputs)
    with ops.control_dependencies([step_result]):
      return [i + 1] + flat_outputs

  cond = lambda i, *args: i < iterations
  counter = constant_op.constant(0)
  loop_result = control_flow_ops.while_loop(
      cond,
      body,
      [counter] + flat_initial_values,
      name="",
      parallel_iterations=1,
      back_prop=False,
      swap_memory=False,
      return_same_structure=True)

  ctx.run_op = control_flow_ops.group(loop_result)

  # Everything after the counter is a last-step output; restore the original
  # dict structure of last_step_outputs.
  packed_outputs = nest.pack_sequence_as(ctx.last_step_outputs,
                                         loop_result[1:])

  for (name, aggregation) in ctx._last_step_outputs_aggregations.items():  # pylint: disable=protected-access
    output = packed_outputs[name]
    if aggregation is not variables_lib.VariableAggregation.NONE:
      # An aggregated output consists of exactly one value; use it directly.
      assert len(output) == 1
      packed_outputs[name] = output[0]
    else:
      # A non-aggregated output has one value per device.
      packed_outputs[name] = values.regroup(
          dict(zip(self._devices, output)), values.PerDevice)

  ctx._set_last_step_outputs(packed_outputs)  # pylint: disable=protected-access
  return ctx
def testSameId(self):
  # An object that is identical on every device must pass through regroup()
  # unwrapped, while differing leaves are merged.
  foo = object()
  per_device_input = {
      _device_str(0): ("a", foo),
      _device_str(1): ("b", foo),
  }
  result = values.regroup(per_device_input)

  self.assertIsInstance(result, tuple)
  self.assertEqual(2, len(result))
  self._is_per_device(result[0], ["a", "b"])
  self.assertIs(foo, result[1])

  # Test select_device(), should undo the merge done by regroup().
  for index, expected_first in enumerate(("a", "b")):
    selected = values.select_device(_device_str(index), result)
    self.assertIsInstance(selected, tuple)
    self.assertEqual(2, len(selected))
    self.assertEqual(expected_first, selected[0])
    self.assertIs(foo, selected[1])
def _update(self, var, fn, *args, **kwargs):
  """Updates a TPU mirrored variable, either directly or per device.

  Args:
    var: The `values.TPUMirroredVariable` to update.
    fn: Update callable. Called once with `var` itself when a TPU context is
      active, otherwise once per component of `var`.
    *args: Positional arguments for `fn`.
    **kwargs: Keyword arguments for `fn`.

  Returns:
    `fn`'s own result when a TPU context is active; otherwise the per-device
    results regrouped as `values.Mirrored`.
  """
  # TODO(jhseu): Consider supporting grouped==False.
  assert isinstance(var, values.TPUMirroredVariable)
  in_tpu_context = values._enclosing_tpu_context() is not None  # pylint: disable=protected-access
  if in_tpu_context:
    return fn(var, *args, **kwargs)

  # Otherwise, we revert to MirroredStrategy behavior and update each variable
  # directly.
  update_by_device = {}
  for dev, component in var._index.items():  # pylint: disable=protected-access
    update_name = "update_%d" % self._device_index.get(dev)
    with ops.device(dev), distribute_lib.UpdateContext(dev), ops.name_scope(
        update_name):
      # If args and kwargs are not mirrored, the value is returned as is.
      selected_args = values.select_device_mirrored(dev, args)
      selected_kwargs = values.select_device_mirrored(dev, kwargs)
      update_by_device[dev] = fn(component, *selected_args,
                                 **selected_kwargs)

  # Make a single control dependency to keep the variables mirrored. If one
  # assignment is fetched, then run all assignments.
  devices_in_order = sorted(update_by_device)
  tied_updates = control_flow_ops.tuple(
      [update_by_device[key] for key in devices_in_order])
  for position, dev in enumerate(devices_in_order):
    update_by_device[dev] = tied_updates[position]
  return values.regroup(update_by_device, values.Mirrored)
def _call_for_each_replica(distribution, fn, args, kwargs):
  """Runs `fn` once per worker device, each call on its own thread.

  Args:
    distribution: the DistributionStrategy object.
    fn: function to run (will be run once per device, each in its own thread).
    args: positional arguments for `fn`
    kwargs: keyword arguments for `fn`.

  Returns:
    Merged return value of `fn` across all replicas, or `None` if the
    coordinator asks to stop while the replicas are being driven.

  Raises:
    RuntimeError: If fn() calls get_replica_context().merge_call() a different
        number of times from the available devices.
  """
  # TODO(josh11b): Add this option once we add synchronization to variable
  # creation. Until then, this is pretty unsafe to use.
  run_concurrently = False
  if not context.executing_eagerly():
    # Needed for per-thread device, etc. contexts in graph mode.
    ops.get_default_graph().switch_to_thread_local()

  coord = coordinator.Coordinator(
      clean_stop_exception_types=(_RequestedStop,))

  shared_variable_store = {}

  # TODO(isaprykin): Create these threads once instead of during every run()
  # call.
  threads = []
  for index, device in enumerate(distribution.worker_devices):
    variable_creator_fn = shared_variable_creator.make_fn(
        shared_variable_store, index)
    replica_thread = MirroredStrategy._MirroredReplicaThread(  # pylint: disable=protected-access
        distribution, coord, device, variable_creator_fn, fn,
        *values.select_device(device, args),
        **values.select_device(device, kwargs))
    threads.append(replica_thread)

  for replica_thread in threads:
    replica_thread.start()

  # Handshake with the _MirroredReplicaThread (`MRT`) threads:
  # setting `MRT.should_run` lets a thread run `fn`. We then wait for
  # `MRT.has_paused`, which signals that `fn` either finished (`MRT.done` is
  # True) or reached a `get_replica_context().merge_call()`. In the latter
  # case the merge arguments of all paused threads are regrouped, `merge_fn`
  # is run once, and each thread gets its share in `MRT.merge_result`, which
  # `merge_call` returns once `MRT.should_run` is set again.
  try:
    with coord.stop_on_exception():
      while not coord.should_stop():
        done = []
        if run_concurrently:
          # Release every thread first, then collect them one by one.
          for replica_thread in threads:
            replica_thread.should_run.set()
          for replica_thread in threads:
            replica_thread.has_paused.wait()
            replica_thread.has_paused.clear()
            if coord.should_stop():
              return None
            done.append(replica_thread.done)
        else:
          # Release and collect the threads strictly one at a time.
          for replica_thread in threads:
            replica_thread.should_run.set()
            replica_thread.has_paused.wait()
            replica_thread.has_paused.clear()
            if coord.should_stop():
              return None
            done.append(replica_thread.done)
        if coord.should_stop():
          return None
        if all(done):
          break
        if any(done):
          raise RuntimeError("Some replicas made a different number of "
                             "replica_context().merge_call() calls.")
        # get_replica_context().merge_call() case
        merge_args = values.regroup(
            {thread.device: thread.merge_args for thread in threads})
        merge_kwargs = values.regroup(
            {thread.device: thread.merge_kwargs for thread in threads})
        # We capture the name_scope of the MRT when we call merge_fn
        # to ensure that if we have opened a name scope in the MRT,
        # it will be respected when executing the merge function. We only
        # capture the name_scope from the first MRT and assume it is
        # the same for all other MRTs.
        first_thread = threads[0]
        captured_name_scope = first_thread.captured_name_scope
        with ops.name_scope(captured_name_scope):
          merge_result = first_thread.merge_fn(distribution, *merge_args,
                                               **merge_kwargs)
        for replica_thread in threads:
          replica_thread.merge_result = values.select_device(
              replica_thread.device, merge_result)
  finally:
    # Unblock every thread so that it can finish, then wait for all of them.
    for replica_thread in threads:
      replica_thread.should_run.set()
    coord.join(threads)

  return values.regroup(
      {thread.device: thread.main_result for thread in threads})
def testMirroredContainer(self):
  # Regrouping the per-device components produced by _make_mirrored() must
  # return the existing mirrored container, not a new wrapper.
  if context.num_gpus() < 1 and context.executing_eagerly():
    self.skipTest("A GPU is not available for this test in eager mode.")
  components, devices, mirrored = _make_mirrored()
  component_by_device = dict(zip(devices, components))
  regrouped = values.regroup(component_by_device)
  self.assertIs(mirrored, regrouped)
def _call_for_each_tower(distribution, fn, *args, **kwargs):
  """Runs `fn` once per worker device, each call on its own thread.

  Args:
    distribution: the DistributionStrategy object.
    fn: function to run (will be run once per device, each in its own thread).
    *args: positional arguments for `fn`
    **kwargs: keyword arguments for `fn`.
        `"run_concurrently"`: Boolean indicating whether executions of `fn`
           can be run concurrently (under eager execution only), defaults to
           `True`.

  Returns:
    Merged return value of `fn` across all towers, or `None` if the
    coordinator asks to stop while the towers are being driven.

  Raises:
    RuntimeError: If fn() calls get_tower_context().merge_call() a different
        number of times from the available devices.
  """
  run_concurrently = kwargs.pop("run_concurrently", True)
  if not context.executing_eagerly():
    # Lots of TF library code isn't thread-safe in graph mode, and
    # there is little to be gained by turning on multithreading when
    # constructing a graph.
    run_concurrently = False
    # Needed for per-thread device, etc. contexts in graph mode.
    ops.get_default_graph().switch_to_thread_local()
  elif run_concurrently is None:
    run_concurrently = True

  coord = coordinator.Coordinator(
      clean_stop_exception_types=(_RequestedStop,))

  shared_variable_store = {}

  # TODO(isaprykin): Create these threads once instead of during every run()
  # call.
  threads = []
  for index, device in enumerate(distribution.worker_devices):
    variable_creator_fn = shared_variable_creator.make_fn(
        shared_variable_store, index)
    tower_thread = MirroredStrategy._MirroredTowerThread(  # pylint: disable=protected-access
        distribution, coord, device, variable_creator_fn, fn,
        *values.select_device(device, args),
        **values.select_device(device, kwargs))
    threads.append(tower_thread)

  for tower_thread in threads:
    tower_thread.start()

  # Handshake with the _MirroredTowerThread (`MTT`) threads:
  # setting `MTT.should_run` lets a thread run `fn`. We then wait for
  # `MTT.has_paused`, which signals that `fn` either finished (`MTT.done` is
  # True) or reached a `get_tower_context().merge_call()`. In the latter case
  # the merge arguments of all paused threads are regrouped, `merge_fn` is
  # run once, and each thread gets its share in `MTT.merge_result`, which
  # `merge_call` returns once `MTT.should_run` is set again.
  try:
    with coord.stop_on_exception():
      while not coord.should_stop():
        done = []
        if run_concurrently:
          # Release every thread first, then collect them one by one.
          for tower_thread in threads:
            tower_thread.should_run.set()
          for tower_thread in threads:
            tower_thread.has_paused.wait()
            tower_thread.has_paused.clear()
            if coord.should_stop():
              return None
            done.append(tower_thread.done)
        else:
          # Release and collect the threads strictly one at a time.
          for tower_thread in threads:
            tower_thread.should_run.set()
            tower_thread.has_paused.wait()
            tower_thread.has_paused.clear()
            if coord.should_stop():
              return None
            done.append(tower_thread.done)
        if coord.should_stop():
          return None
        if all(done):
          break
        if any(done):
          raise RuntimeError("Some towers made a different number of "
                             "tower_context().merge_call() calls.")
        # get_tower_context().merge_call() case
        merge_args = values.regroup(
            {thread.device: thread.merge_args for thread in threads})
        merge_kwargs = values.regroup(
            {thread.device: thread.merge_kwargs for thread in threads})
        # We capture the name_scope of the MTT when we call merge_fn
        # to ensure that if we have opened a name scope in the MTT,
        # it will be respected when executing the merge function. We only
        # capture the name_scope from the first MTT and assume it is
        # the same for all other MTTs.
        first_thread = threads[0]
        captured_name_scope = first_thread.captured_name_scope
        with ops.name_scope(captured_name_scope):
          merge_result = first_thread.merge_fn(distribution, *merge_args,
                                               **merge_kwargs)
        for tower_thread in threads:
          tower_thread.merge_result = values.select_device(
              tower_thread.device, merge_result)
  finally:
    # Unblock every thread so that it can finish, then wait for all of them.
    for tower_thread in threads:
      tower_thread.should_run.set()
    coord.join(threads)

  return values.regroup(
      {thread.device: thread.main_result for thread in threads})
def _call_for_each_tower(self, fn, *args, **kwargs):
  """Runs `fn` once per device in `self._devices`, each on its own thread.

  Args:
    fn: function to run (will be run once per device, each in its own thread).
    *args: positional arguments for `fn`
    **kwargs: keyword arguments for `fn`.
        `"run_concurrently"`: Boolean indicating whether executions of `fn`
           can be run concurrently (under eager execution only), defaults to
           `True`.

  Returns:
    Merged return value of `fn` across all towers, or `None` if the
    coordinator asks to stop while the towers are being driven.

  Raises:
    RuntimeError: If fn() calls get_tower_context().merge_call() a different
        number of times for when called for different devices.
  """
  run_concurrently = kwargs.pop("run_concurrently", True)
  if not context.executing_eagerly():
    # Lots of TF library code isn't thread-safe in graph mode, and
    # there is little to be gained by turning on multithreading when
    # constructing a graph.
    run_concurrently = False
    # Needed for per-thread device, etc. contexts in graph mode.
    ops.get_default_graph().switch_to_thread_local()
  elif run_concurrently is None:
    run_concurrently = True

  coord = coordinator.Coordinator(
      clean_stop_exception_types=(_RequestedStop,))

  shared_variable_store = {}

  # TODO(isaprykin): Create these threads once instead of during every run()
  # call.
  threads = []
  for index, device in enumerate(self._devices):
    variable_creator_fn = shared_variable_creator.make_fn(
        shared_variable_store, index)
    tower_thread = MirroredStrategy._MirroredTowerThread(
        self, coord, device, variable_creator_fn, fn,
        *values.select_device(device, args),
        **values.select_device(device, kwargs))
    threads.append(tower_thread)

  for tower_thread in threads:
    tower_thread.start()

  # Handshake with the _MirroredTowerThread (`MTT`) threads:
  # setting `MTT.should_run` lets a thread run `fn`. We then wait for
  # `MTT.has_paused`, which signals that `fn` either finished (`MTT.done` is
  # True) or reached a `get_tower_context().merge_call()`. In the latter case
  # the merge arguments of all paused threads are regrouped, `merge_fn` is
  # run once, and each thread gets its share in `MTT.merge_result`, which
  # `merge_call` returns once `MTT.should_run` is set again.
  try:
    with coord.stop_on_exception():
      while not coord.should_stop():
        done = []
        if run_concurrently:
          # Release every thread first, then collect them one by one.
          for tower_thread in threads:
            tower_thread.should_run.set()
          for tower_thread in threads:
            tower_thread.has_paused.wait()
            tower_thread.has_paused.clear()
            if coord.should_stop():
              return None
            done.append(tower_thread.done)
        else:
          # Release and collect the threads strictly one at a time.
          for tower_thread in threads:
            tower_thread.should_run.set()
            tower_thread.has_paused.wait()
            tower_thread.has_paused.clear()
            if coord.should_stop():
              return None
            done.append(tower_thread.done)
        if coord.should_stop():
          return None
        if all(done):
          break
        if any(done):
          raise RuntimeError("Some towers made a different number of "
                             "tower_context().merge_call() calls.")
        # get_tower_context().merge_call() case
        merge_args = values.regroup(
            {thread.device: thread.merge_args for thread in threads})
        merge_kwargs = values.regroup(
            {thread.device: thread.merge_kwargs for thread in threads})
        merge_result = threads[0].merge_fn(self, *merge_args, **merge_kwargs)
        for tower_thread in threads:
          tower_thread.merge_result = values.select_device(
              tower_thread.device, merge_result)
  finally:
    # Unblock every thread so that it can finish, then wait for all of them.
    for tower_thread in threads:
      tower_thread.should_run.set()
    coord.join(threads)

  return values.regroup(
      {thread.device: thread.main_result for thread in threads})