def testSharedVariable(self):
  """Checks that creators sharing one store hand back the same variable."""
  num_devices = 3
  shared_variable_store = {}
  # One creator per device index, all backed by the same store.
  creator_fns = [
      shared_variable_creator.make_fn(shared_variable_store, i)
      for i in range(num_devices)
  ]

  created = []
  for creator in creator_fns:
    with variable_scope.variable_creator_scope(creator):
      created.append(variable_scope.variable(1.0, name="foo"))
  first, second, third = created

  # The variables from devices 1 and 2 must be the very object from device 0.
  self.assertIs(second, first)
  self.assertIs(third, first)
def testSharedVariable(self):
  """Variables named "foo" built via creators on one store are identical."""
  store = {}
  device_count = 3
  makers = [shared_variable_creator.make_fn(store, idx)
            for idx in range(device_count)]

  def _create_under(creator):
    # Build the "foo" variable with `creator` installed as variable creator.
    with variable_scope.variable_creator_scope(creator):
      return variable_scope.variable(1.0, name="foo")

  v0, v1, v2 = [_create_under(maker) for maker in makers]

  # Every later device should get back the object made for device 0.
  self.assertIs(v1, v0)
  self.assertIs(v2, v0)
def _call_for_each_replica(distribution, fn, args, kwargs): """Run `fn` in separate threads, once per replica/worker device. Args: distribution: the DistributionStrategy object. fn: function to run (will be run once per device, each in its own thread). args: positional arguments for `fn` kwargs: keyword arguments for `fn`. Returns: Merged return value of `fn` across all replicas. Raises: RuntimeError: If fn() calls get_replica_context().merge_call() a different number of times from the available devices. """ # TODO(josh11b): Add this option once we add synchronization to variable # creation. Until then, this is pretty unsafe to use. run_concurrently = False if not context.executing_eagerly(): # Needed for per-thread device, etc. contexts in graph mode. ops.get_default_graph().switch_to_thread_local() coord = coordinator.Coordinator( clean_stop_exception_types=(_RequestedStop, )) shared_variable_store = {} # TODO(isaprykin): Create these threads once instead of during every run() # call. threads = [] for index, d in enumerate(distribution.worker_devices): variable_creator_fn = shared_variable_creator.make_fn( shared_variable_store, index) t = MirroredStrategy._MirroredReplicaThread( # pylint: disable=protected-access distribution, coord, d, variable_creator_fn, fn, *values.select_device(d, args), **values.select_device(d, kwargs)) threads.append(t) for t in threads: t.start() # When `fn` starts `should_run` event is set on _MirroredReplicaThread # (`MRT`) threads. The execution waits until # `MRT.has_paused` is set, which indicates that either `fn` is # complete or a `get_replica_context().merge_call()` is called. If `fn` is # complete, then `MRT.done` is set to True. Otherwise, arguments # of `get_replica_context().merge_call` from all paused threads are grouped # and the `merge_fn` is performed. Results of the # `get_replica_context().merge_call` are then set to `MRT.merge_result`. 
# Each such `get_replica_context().merge_call` call returns the # `MRT.merge_result` for that thread when `MRT.should_run` event # is reset again. Execution of `fn` resumes. try: with coord.stop_on_exception(): all_done = False while not all_done and not coord.should_stop(): done = [] if run_concurrently: for t in threads: t.should_run.set() for t in threads: t.has_paused.wait() t.has_paused.clear() if coord.should_stop(): return None done.append(t.done) else: for t in threads: t.should_run.set() t.has_paused.wait() t.has_paused.clear() if coord.should_stop(): return None done.append(t.done) if coord.should_stop(): return None all_done = all(done) if not all_done: if any(done): raise RuntimeError( "Some replicas made a different number of " "replica_context().merge_call() calls.") # get_replica_context().merge_call() case merge_args = values.regroup( {t.device: t.merge_args for t in threads}) merge_kwargs = values.regroup( {t.device: t.merge_kwargs for t in threads}) # We capture the name_scope of the MRT when we call merge_fn # to ensure that if we have opened a name scope in the MRT, # it will be respected when executing the merge function. We only # capture the name_scope from the first MRT and assume it is # the same for all other MRTs. mtt_captured_name_scope = threads[0].captured_name_scope with ops.name_scope(mtt_captured_name_scope): merge_result = threads[0].merge_fn( distribution, *merge_args, **merge_kwargs) for t in threads: t.merge_result = values.select_device( t.device, merge_result) finally: for t in threads: t.should_run.set() coord.join(threads) return values.regroup({t.device: t.main_result for t in threads})
def _call_for_each_tower(distribution, fn, *args, **kwargs):
  """Run `fn` in separate threads, once per tower/worker device.

  Args:
    distribution: the DistributionStrategy object.
    fn: function to run (will be run once per device, each in its own thread).
    *args: positional arguments for `fn`. For each device, the value
      `values.select_device(d, args)` returns is forwarded to `fn`.
    **kwargs: keyword arguments for `fn`. One key is consumed here and is
      NOT forwarded to `fn`:
        `"run_concurrently"`: Boolean indicating whether executions of `fn`
           can be run concurrently (under eager execution only), defaults to
           `True`. It is forced to `False` when not executing eagerly. An
           explicit `None` is treated as `True` under eager execution.

  Returns:
    Merged return value of `fn` across all towers: each thread's
    `main_result`, regrouped by device. If the coordinator requests a stop
    before all threads finish, `None` is returned early instead. The
    `finally` clause still joins the threads in that case, and
    `coord.join` presumably re-raises any recorded exception.

  Raises:
    RuntimeError: If fn() calls get_tower_context().merge_call() a different
      number of times on different devices.
  """
  run_concurrently = kwargs.pop("run_concurrently", True)
  if not context.executing_eagerly():
    # Lots of TF library code isn't thread-safe in graph mode, and
    # there is little to be gained by turning on multithreading when
    # constructing a graph.
    run_concurrently = False
    # Needed for per-thread device, etc. contexts in graph mode.
    ops.get_default_graph().switch_to_thread_local()
  elif run_concurrently is None:
    run_concurrently = True

  # `_RequestedStop` is registered with the coordinator as a clean-stop
  # exception type.
  coord = coordinator.Coordinator(clean_stop_exception_types=(_RequestedStop,))

  # A single store shared by every tower's variable creator. Each creator is
  # additionally given its tower index.
  shared_variable_store = {}

  # TODO(isaprykin): Create these threads once instead of during every run()
  # call.
  threads = []
  for index, d in enumerate(distribution.worker_devices):
    variable_creator_fn = shared_variable_creator.make_fn(
        shared_variable_store, index)
    t = MirroredStrategy._MirroredTowerThread(  # pylint: disable=protected-access
        distribution, coord, d, variable_creator_fn, fn,
        *values.select_device(d, args), **values.select_device(d, kwargs))
    threads.append(t)

  for t in threads:
    t.start()

  # Handshake with each _MirroredTowerThread (`MTT`):
  # * Setting `MTT.should_run` lets the thread start (or resume) `fn`.
  # * The loop below then waits for `MTT.has_paused`. That event indicates
  #   the thread either finished `fn` (then `MTT.done` is True) or called
  #   `get_tower_context().merge_call()`.
  # * In the merge_call case, the `merge_args`/`merge_kwargs` of all paused
  #   threads are regrouped and the first thread's `merge_fn` is run once.
  #   Each thread's `MTT.merge_result` is then set to the value
  #   `values.select_device` picks for that thread's device.
  # * Setting `MTT.should_run` again resumes `fn`. The pending merge_call
  #   then returns `MTT.merge_result` for that thread.
  try:
    with coord.stop_on_exception():
      all_done = False
      while not all_done and not coord.should_stop():
        done = []
        if run_concurrently:
          # Resume every thread first, then wait for each one to pause.
          for t in threads:
            t.should_run.set()
          for t in threads:
            t.has_paused.wait()
            t.has_paused.clear()
            if coord.should_stop():
              return None
            done.append(t.done)
        else:
          # Resume one thread and wait for it to pause before the next.
          for t in threads:
            t.should_run.set()
            t.has_paused.wait()
            t.has_paused.clear()
            if coord.should_stop():
              return None
            done.append(t.done)
        if coord.should_stop():
          return None
        all_done = all(done)
        if not all_done:
          # Mixed state: some threads finished `fn` while others are paused
          # in merge_call. This means merge_call counts differ per tower.
          if any(done):
            raise RuntimeError("Some towers made a different number of "
                               "tower_context().merge_call() calls.")
          # get_tower_context().merge_call() case
          merge_args = values.regroup({t.device: t.merge_args for t in threads})
          merge_kwargs = values.regroup(
              {t.device: t.merge_kwargs for t in threads})
          # We capture the name_scope of the MTT when we call merge_fn
          # to ensure that if we have opened a name scope in the MTT,
          # it will be respected when executing the merge function. We only
          # capture the name_scope from the first MTT and assume it is
          # the same for all other MTTs.
          mtt_captured_name_scope = threads[0].captured_name_scope
          with ops.name_scope(mtt_captured_name_scope):
            merge_result = threads[0].merge_fn(distribution, *merge_args,
                                               **merge_kwargs)
          for t in threads:
            t.merge_result = values.select_device(t.device, merge_result)
  finally:
    # Set `should_run` on every thread so that none stays blocked waiting to
    # be resumed (presumably so they can observe the stop and exit). Then
    # join them all through the coordinator.
    for t in threads:
      t.should_run.set()
    coord.join(threads)

  return values.regroup({t.device: t.main_result for t in threads})
def _call_for_each_tower(self, fn, *args, **kwargs):
  """Run `fn` in separate threads, once per tower/worker device.

  One thread is created per entry of `self._devices`.

  Args:
    fn: function to run (will be run once per device, each in its own thread).
    *args: positional arguments for `fn`. For each device, the value
      `values.select_device(d, args)` returns is forwarded to `fn`.
    **kwargs: keyword arguments for `fn`. One key is consumed here and is
      NOT forwarded to `fn`:
        `"run_concurrently"`: Boolean indicating whether executions of `fn`
           can be run concurrently (under eager execution only), defaults to
           `True`. It is forced to `False` when not executing eagerly. An
           explicit `None` is treated as `True` under eager execution.

  Returns:
    Merged return value of `fn` across all towers: each thread's
    `main_result`, regrouped by device. If the coordinator requests a stop
    before all threads finish, `None` is returned early instead. The
    `finally` clause still joins the threads in that case, and
    `coord.join` presumably re-raises any recorded exception.

  Raises:
    RuntimeError: If fn() calls get_tower_context().merge_call() a different
      number of times on different devices.
  """
  run_concurrently = kwargs.pop("run_concurrently", True)
  if not context.executing_eagerly():
    # Lots of TF library code isn't thread-safe in graph mode, and
    # there is little to be gained by turning on multithreading when
    # constructing a graph.
    run_concurrently = False
    # Needed for per-thread device, etc. contexts in graph mode.
    ops.get_default_graph().switch_to_thread_local()
  elif run_concurrently is None:
    run_concurrently = True

  # `_RequestedStop` is registered with the coordinator as a clean-stop
  # exception type.
  coord = coordinator.Coordinator(
      clean_stop_exception_types=(_RequestedStop, ))

  # A single store shared by every tower's variable creator. Each creator is
  # additionally given its tower index.
  shared_variable_store = {}

  # TODO(isaprykin): Create these threads once instead of during every run()
  # call.
  threads = []
  for index, d in enumerate(self._devices):
    variable_creator_fn = shared_variable_creator.make_fn(
        shared_variable_store, index)
    t = MirroredStrategy._MirroredTowerThread(
        self, coord, d, variable_creator_fn, fn,
        *values.select_device(d, args), **values.select_device(d, kwargs))
    threads.append(t)

  for t in threads:
    t.start()

  # Handshake with each _MirroredTowerThread (`MTT`):
  # * Setting `MTT.should_run` lets the thread start (or resume) `fn`.
  # * The loop below then waits for `MTT.has_paused`. That event indicates
  #   the thread either finished `fn` (then `MTT.done` is True) or called
  #   `get_tower_context().merge_call()`.
  # * In the merge_call case, the `merge_args`/`merge_kwargs` of all paused
  #   threads are regrouped and the first thread's `merge_fn` is run once
  #   (with `self` as its first argument). Each thread's `MTT.merge_result`
  #   is then set to the value `values.select_device` picks for that
  #   thread's device.
  # * Setting `MTT.should_run` again resumes `fn`. The pending merge_call
  #   then returns `MTT.merge_result` for that thread.
  try:
    with coord.stop_on_exception():
      all_done = False
      while not all_done and not coord.should_stop():
        done = []
        if run_concurrently:
          # Resume every thread first, then wait for each one to pause.
          for t in threads:
            t.should_run.set()
          for t in threads:
            t.has_paused.wait()
            t.has_paused.clear()
            if coord.should_stop():
              return None
            done.append(t.done)
        else:
          # Resume one thread and wait for it to pause before the next.
          for t in threads:
            t.should_run.set()
            t.has_paused.wait()
            t.has_paused.clear()
            if coord.should_stop():
              return None
            done.append(t.done)
        if coord.should_stop():
          return None
        all_done = all(done)
        if not all_done:
          # Mixed state: some threads finished `fn` while others are paused
          # in merge_call. This means merge_call counts differ per tower.
          if any(done):
            raise RuntimeError(
                "Some towers made a different number of "
                "tower_context().merge_call() calls.")
          # get_tower_context().merge_call() case
          merge_args = values.regroup(
              {t.device: t.merge_args for t in threads})
          merge_kwargs = values.regroup(
              {t.device: t.merge_kwargs for t in threads})
          merge_result = threads[0].merge_fn(
              self, *merge_args, **merge_kwargs)
          for t in threads:
            t.merge_result = values.select_device(
                t.device, merge_result)
  finally:
    # Set `should_run` on every thread so that none stays blocked waiting to
    # be resumed (presumably so they can observe the stop and exit). Then
    # join them all through the coordinator.
    for t in threads:
      t.should_run.set()
    coord.join(threads)

  return values.regroup({t.device: t.main_result for t in threads})