def shutdown(graceful=True): r""" Perform a shutdown of the RPC agent, and then destroy the RPC agent. This stops the local agent from accepting outstanding requests, and shuts down the RPC framework by terminating all RPC threads. If ``graceful=True``, this will block until all local and remote RPC processes reach this method and wait for all outstanding work to complete. Otherwise, if ``graceful=False``, this is a local shutdown, and it does not wait for other RPC processes to reach this method. .. warning:: For :class:`~torch.futures.Future` objects returned by :meth:`~torch.distributed.rpc.rpc_async`, ``future.wait()`` should not be called after ``shutdown()``. Args: graceful (bool): Whether to do a graceful shutdown or not. If True, this will 1) wait until there is no pending system messages for ``UserRRefs`` and delete them; 2) block until all local and remote RPC processes have reached this method and wait for all outstanding work to complete. Example:: Make sure that ``MASTER_ADDR`` and ``MASTER_PORT`` are set properly on both workers. Refer to :meth:`~torch.distributed.init_process_group` API for more details. For example, >>> export MASTER_ADDR=localhost >>> export MASTER_PORT=5678 Then run the following code in two different processes: >>> # On worker 0: >>> import torch >>> import torch.distributed.rpc as rpc >>> rpc.init_rpc("worker0", rank=0, world_size=2) >>> # do some work >>> result = rpc.rpc_sync("worker1", torch.add, args=(torch.ones(1), 1)) >>> # ready to shutdown >>> rpc.shutdown() >>> # On worker 1: >>> import torch.distributed.rpc as rpc >>> rpc.init_rpc("worker1", rank=1, world_size=2) >>> # wait for worker 0 to finish work, and then shutdown. >>> rpc.shutdown() """ if graceful: try: _wait_all_workers() _delete_all_user_and_unforked_owner_rrefs() _get_current_rpc_agent().join(shutdown=True) finally: # In case of errors, continue to complete the local shutdown. _finalize_shutdown() else: _finalize_shutdown()
def get_worker_info(worker_name=None):
    r"""
    Get :class:`~torch.distributed.rpc.WorkerInfo` of a given worker name.
    Use this :class:`~torch.distributed.rpc.WorkerInfo` to avoid passing an
    expensive string on every invocation.

    Args:
        worker_name (str): the string name of a worker. If ``None``, return
                           the :class:`~torch.distributed.rpc.WorkerInfo` of
                           the current worker. (default ``None``)

    Returns:
        :class:`~torch.distributed.rpc.WorkerInfo` instance for the given
        ``worker_name`` or :class:`~torch.distributed.rpc.WorkerInfo` of the
        current worker if ``worker_name`` is ``None``.
    """
    if worker_name is not None:
        return _get_current_rpc_agent().get_worker_info(worker_name)
    else:
        return _get_current_rpc_agent().get_worker_info()
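# Usage sketch (illustrative, not part of this module): resolve the
# WorkerInfo once and reuse the handle, so the worker-name string does not
# need to be looked up on every invocation. Assumes a two-worker group in
# which a peer named "worker1" exists.
#
#   import torch
#   import torch.distributed.rpc as rpc
#
#   rpc.init_rpc("worker0", rank=0, world_size=2)
#   worker1_info = rpc.get_worker_info("worker1")  # one-time name lookup
#   # rpc_sync accepts a WorkerInfo wherever a worker name is expected:
#   ret = rpc.rpc_sync(worker1_info, torch.add, args=(torch.ones(1), 1))
#   rpc.shutdown()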
def _finalize_shutdown():
    try:
        # This raises a `TORCH_CHECK()` exception if an RRef leak is detected.
        _destroy_rref_context(_ignore_rref_leak)
    finally:
        _get_current_rpc_agent().shutdown()
        # Clean up the python rpc handler in shutdown(); see comments in
        # PythonRpcHandler::cleanup(). Call it from the python API because
        # cleanup() has a python dependency: it assumes the python
        # interpreter exists.
        # Whether or not an RRef leak exception is raised, this clean-up
        # code must run to avoid a destruction segfault in Python 3.5.
        #
        # future.wait() should not be called after shutdown().
        # pythonRpcHandler is cleaned up in shutdown(); after shutdown(),
        # python objects returned from rpc python calls can no longer be
        # resolved.
        _cleanup_python_rpc_handler()
        _reset_current_rpc_agent()
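# Note on ordering (descriptive, from the code above): the RRef context is
# destroyed first, then the agent shuts down, and the python rpc handler is
# torn down last. Any python objects held by in-flight RPC results therefore
# become unresolvable once _finalize_shutdown() returns.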
def shutdown(graceful=True): r""" Perform a shutdown of the RPC agent, and then destroy the RPC agent. This stops the local agent from accepting outstanding requests, and shuts down the RPC framework by terminating all RPC threads. If ``graceful=True``, this will block until all local and remote RPC processes reach this method and wait for all outstanding work to complete. Otherwise, if ``graceful=False``, this is a local shutdown, and it does not wait for other RPC processes to reach this method. .. warning:: For :class:`~torch.futures.Future` objects returned by :meth:`~torch.distributed.rpc.rpc_async`, ``future.wait()`` should not be called after ``shutdown()``. Args: graceful (bool): Whether to do a graceful shutdown or not. If True, this will 1) wait until there is no pending system messages for ``UserRRefs`` and delete them; 2) block until all local and remote RPC processes have reached this method and wait for all outstanding work to complete. Example:: Make sure that ``MASTER_ADDR`` and ``MASTER_PORT`` are set properly on both workers. Refer to :meth:`~torch.distributed.init_process_group` API for more details. For example, >>> export MASTER_ADDR=localhost >>> export MASTER_PORT=5678 Then run the following code in two different processes: >>> # On worker 0: >>> import torch >>> import torch.distributed.rpc as rpc >>> rpc.init_rpc("worker0", rank=0, world_size=2) >>> # do some work >>> result = rpc.rpc_sync("worker1", torch.add, args=(torch.ones(1), 1)) >>> # ready to shutdown >>> rpc.shutdown() >>> # On worker 1: >>> import torch.distributed.rpc as rpc >>> rpc.init_rpc("worker1", rank=1, world_size=2) >>> # wait for worker 0 to finish work, and then shutdown. >>> rpc.shutdown() """ if graceful: _wait_all_workers() _delete_all_user_and_unforked_owner_rrefs() _get_current_rpc_agent().join(shutdown=True) try: # This raises a `TORCH_CHECK()` exception on RRef leak detected. _destroy_rref_context(_ignore_rref_leak) finally: _get_current_rpc_agent().shutdown() # clean up python rpc handler in shutdown(), see comments in # PythonRpcHandler::cleanup(), call it in python API because the # cleanup() function has python dependency, it assumes python # interpreter exists. # No matter if RRef leak exception is raised, this clean-up code # must run to avoid destruction segfault in Python 3.5. # # future.wait() should not be called after shutdown(). # pythonRpcHandler is cleaned up in shutdown(), after # shutdown(), python objects returned from rpc python call can not be # resolved. _cleanup_python_rpc_handler() _reset_current_rpc_agent()
def _all_gather(obj, timeout=UNSET_RPC_TIMEOUT):
    r"""
    This is similar to torch.distributed.all_gather(), but it uses RPC. It
    picks the worker with the smallest name (alphabetic order) as the leader.
    Then all followers send their data ``obj`` to the leader. After the leader
    has received everything, it broadcasts the results back to all followers.
    This function blocks until all workers have received the gathered results.
    """
    assert (
        _ALL_WORKER_NAMES is not None
    ), "`_ALL_WORKER_NAMES` is not initialized for `def _all_gather`."
    leader_name = sorted(_ALL_WORKER_NAMES)[0]

    self_name = _get_current_rpc_agent().get_worker_info().name

    global _all_gather_sequence_id
    with _all_gather_dict_lock:
        sequence_id = _all_gather_sequence_id
        _all_gather_sequence_id += 1

    is_leader = leader_name == self_name

    if timeout == UNSET_RPC_TIMEOUT:
        timeout = get_rpc_timeout()

    # Phase 1: Followers send their objects to the leader
    if is_leader:
        _gather_to_leader(sequence_id, self_name, obj)
    else:
        rpc_sync(
            leader_name,
            _gather_to_leader,
            args=(sequence_id, self_name, obj),
            timeout=timeout,
        )

    with _all_gather_dict_lock:
        states = _all_gather_sequence_id_to_states[sequence_id]
    states.proceed_signal.wait()

    # Phase 2: Leader broadcasts gathered results to all followers.
    # The leader's signal is the first to be unblocked, after receiving all
    # followers' data objects.
    if is_leader:
        worker_name_to_response_future_dict = dict()
        for follower_name in _ALL_WORKER_NAMES - {leader_name}:
            fut = rpc_async(
                follower_name,
                _broadcast_to_followers,
                args=(sequence_id, states.gathered_objects),
                timeout=timeout,
            )
            worker_name_to_response_future_dict[follower_name] = fut

        errors = []
        for follower_name, fut in worker_name_to_response_future_dict.items():
            try:
                fut.wait()
            except RuntimeError as ex:
                errors.append((follower_name, ex))

        if errors:
            raise RuntimeError(
                f"Followers {[e[0] for e in errors]} timed out in _all_gather "
                f"after {timeout:.2f} seconds. The first exception is {errors[0][1]}"
            )

    return states.gathered_objects
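# Usage sketch (hypothetical; `_all_gather` is an internal helper, not public
# API): every worker in the group calls it with its own payload, and each
# call returns once the leader has gathered from everyone and broadcast the
# merged result back. The return value is the round's `gathered_objects`,
# i.e. one contributed object per worker (assumed here to be keyed by worker
# name).
#
#   # On every worker, after init_rpc(...):
#   my_name = _get_current_rpc_agent().get_worker_info().name
#   gathered = _all_gather(f"ready:{my_name}")
#   # All workers now hold the same gathered result and are mutually
#   # synchronized, which is how a shutdown barrier can be built on top.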
def shutdown(graceful=True, timeout=DEFAULT_SHUTDOWN_TIMEOUT):
    r"""
    Perform a shutdown of the RPC agent, and then destroy the RPC agent. This
    stops the local agent from accepting outstanding requests, and shuts down
    the RPC framework by terminating all RPC threads. If ``graceful=True``,
    this will block until all local and remote RPC processes reach this method
    and wait for all outstanding work to complete. Otherwise, if
    ``graceful=False``, this is a local shutdown, and it does not wait for
    other RPC processes to reach this method.

    .. warning::
        For :class:`~torch.futures.Future` objects returned by
        :meth:`~torch.distributed.rpc.rpc_async`, ``future.wait()`` should not
        be called after ``shutdown()``.

    Args:
        graceful (bool): Whether to do a graceful shutdown or not. If True,
                         this will 1) wait until there are no pending system
                         messages for ``UserRRefs`` and delete them; 2) block
                         until all local and remote RPC processes have reached
                         this method and wait for all outstanding work to
                         complete.
        timeout (float): Timeout in seconds used when waiting for all RPC
                         processes to reach this method during a graceful
                         shutdown. (default: ``DEFAULT_SHUTDOWN_TIMEOUT``)

    Example::
        Make sure that ``MASTER_ADDR`` and ``MASTER_PORT`` are set properly
        on both workers. Refer to :meth:`~torch.distributed.init_process_group`
        API for more details. For example,

        export MASTER_ADDR=localhost
        export MASTER_PORT=5678

        Then run the following code in two different processes:

        >>> # On worker 0:
        >>> import torch
        >>> import torch.distributed.rpc as rpc
        >>> rpc.init_rpc("worker0", rank=0, world_size=2)
        >>> # do some work
        >>> result = rpc.rpc_sync("worker1", torch.add, args=(torch.ones(1), 1))
        >>> # ready to shutdown
        >>> rpc.shutdown()

        >>> # On worker 1:
        >>> import torch.distributed.rpc as rpc
        >>> rpc.init_rpc("worker1", rank=1, world_size=2)
        >>> # wait for worker 0 to finish work, and then shutdown.
        >>> rpc.shutdown()
    """
    if graceful:
        try:
            agent = _get_current_rpc_agent()
            if not isinstance(agent, TensorPipeAgent) or agent.is_static_group:
                _wait_all_workers(timeout)
                _delete_all_user_and_unforked_owner_rrefs()
                agent.join(shutdown=True, timeout=timeout)
            else:
                # This is a dynamic group, so we need to grab the token for
                # the operation.
                my_worker_info = agent.get_worker_info()
                my_name = my_worker_info.name
                with _group_membership_management(agent.store, my_name, False):
                    all_worker_infos = agent.get_worker_infos()
                    for worker in all_worker_infos:
                        if worker.name != my_name:
                            rpc_sync(
                                worker.name,
                                _update_group_membership,
                                args=(my_worker_info, [], {}, False),
                            )
                    agent.join(shutdown=True, timeout=timeout)
        finally:
            # In case of errors, continue to complete the local shutdown.
            _finalize_shutdown()
    else:
        _finalize_shutdown()
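# Usage note (illustrative): a local-only teardown skips the cross-worker
# barrier entirely, e.g. in error-handling paths where peers may already be
# gone:
#
#   rpc.shutdown(graceful=False)  # does not wait for other RPC processes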
def _all_gather(obj, worker_names=None, timeout=UNSET_RPC_TIMEOUT):
    r"""
    This is similar to torch.distributed.all_gather(), but it uses RPC. It
    picks the worker with the smallest name (alphabetic order) as the leader.
    Then all followers send their data ``obj`` to the leader. After the leader
    has received everything, it broadcasts the results back to all followers.
    This function blocks until all workers have received the gathered results.
    """
    if not worker_names:
        assert (
            _ALL_WORKER_NAMES is not None
        ), "`_ALL_WORKER_NAMES` is not initialized for `def _all_gather`."
        worker_names = _ALL_WORKER_NAMES
    leader_name = sorted(worker_names)[0]

    self_name = _get_current_rpc_agent().get_worker_info().name

    with _all_gather_dict_lock:
        concat_names = "".join(sorted(worker_names))
        sequence_num = _all_gather_sequence_id.get(concat_names, 0)
        _all_gather_sequence_id[concat_names] = sequence_num + 1
        sequence_id = concat_names + str(sequence_num)

    is_leader = leader_name == self_name

    if timeout == UNSET_RPC_TIMEOUT:
        # Timeout is specified by agent for RPC calls
        rpc_timeout = get_rpc_timeout()
        # No timeout for signal
        signal_timeout = None
    elif timeout == DEFAULT_SHUTDOWN_TIMEOUT:
        # No timeout for RPC
        rpc_timeout = timeout
        # No timeout for signal
        signal_timeout = None
    else:
        # Signal and RPC timeout use the same timeout
        signal_timeout = rpc_timeout = timeout

    # Phase 1: Followers send their objects to the leader
    if is_leader:
        _gather_to_leader(sequence_id, self_name, obj, worker_names)
    else:
        rpc_sync(
            leader_name,
            _gather_to_leader,
            args=(sequence_id, self_name, obj, worker_names),
            timeout=rpc_timeout,
        )

    with _all_gather_dict_lock:
        states = _all_gather_sequence_id_to_states[sequence_id]

    # Timeout is either set by the function parameter or None (which is
    # indefinite)
    states.proceed_signal.wait(timeout=signal_timeout)

    # Phase 2: Leader broadcasts gathered results to all followers.
    # The leader's signal is the first to be unblocked, after receiving all
    # followers' data objects.
    if is_leader:
        worker_name_to_response_future_dict = dict()
        for follower_name in worker_names - {leader_name}:
            fut = rpc_async(
                follower_name,
                _broadcast_to_followers,
                args=(sequence_id, states.gathered_objects),
                timeout=rpc_timeout,
            )
            worker_name_to_response_future_dict[follower_name] = fut

        errors = []
        for follower_name, fut in worker_name_to_response_future_dict.items():
            try:
                fut.wait()
            except RuntimeError as ex:
                errors.append((follower_name, ex))

        if errors:
            raise RuntimeError(
                f"Followers {[e[0] for e in errors]} timed out in _all_gather "
                f"after {rpc_timeout:.2f} seconds. The first exception is {errors[0][1]}"
            )

    # Clean up the states for this sequence_id
    with _all_gather_dict_lock:
        states = _all_gather_sequence_id_to_states.pop(sequence_id)

    return states.gathered_objects
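# Worked example of the sequence-id scheme above (illustrative values only):
# for worker_names == {"worker0", "worker1"}, concat_names is
# "worker0worker1", so successive rounds over the same group get the ids
# "worker0worker10", "worker0worker11", and so on. Keying the counter by the
# concatenated sorted names lets concurrent _all_gather rounds over different
# sub-groups proceed without colliding on shared state.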