def _enable_rpc_profiler(should_profile, qualified_name, func, rpc_type, dst_worker_info):
    """Return the context manager that should wrap an outgoing RPC call.

    Args:
        should_profile: when falsy, no profiling state is touched and a
            no-op ``contextlib.suppress()`` is returned.
        qualified_name: name used for ``func`` in the profiling key; when it
            is ``None`` the name is derived from ``func`` instead.
        func: the callable being invoked. Only inspected when
            ``qualified_name`` is ``None``.
        rpc_type: forwarded as the first argument of
            ``_build_rpc_profiling_key``.
        dst_worker_info: destination worker; its ``name`` attribute becomes
            part of the profiling key.

    Returns:
        ``contextlib.suppress()`` when profiling is off, otherwise a
        ``torch.autograd.profiler.record_function`` built from the RPC
        profiling key.
    """
    if not should_profile:
        # Nothing to record: hand back a context manager that does nothing.
        return contextlib.suppress()

    # Choose the string representation according to the kind of callable
    # (builtin, script, python).
    if qualified_name is not None:
        func_name = qualified_name
    elif isinstance(func, torch.jit.ScriptFunction):
        func_name = torch._jit_internal._qualified_name(func)
    else:
        func_name = func.__qualname__

    profiling_key = _build_rpc_profiling_key(
        rpc_type,
        func_name,
        get_worker_info().name,
        dst_worker_info.name,
    )
    # Register the key with the remote profiler manager before the
    # record_function scope is created.
    RemoteProfilerManager.set_current_profiling_key(profiling_key)
    return torch.autograd.profiler.record_function(profiling_key)
def _invoke_rpc(to, func, rpc_type, args=None, kwargs=None, rpc_timeout=UNSET_RPC_TIMEOUT):
    """Send ``func`` to worker ``to`` for execution and return the future.

    The call is routed to one of three senders depending on ``func``:
    a builtin operator (when ``_find_builtin`` resolves a qualified name),
    a TorchScript function, or an arbitrary Python callable that is pickled
    as a ``PythonUDF``.

    Args:
        to: destination worker; converted with ``_to_worker_info``.
        func: the callable to run remotely.
        rpc_type: execution mode, used only to build the profiling key.
        args: positional arguments for ``func``; a falsy value means ``()``.
        kwargs: keyword arguments for ``func``; a falsy value means ``{}``.
        rpc_timeout: timeout forwarded unchanged to the underlying sender.

    Returns:
        The future produced by the underlying sender. When the autograd
        profiler is enabled, the future returned by
        ``_call_end_callbacks_on_future`` is returned instead.

    Raises:
        TypeError: if ``func`` is not callable.
    """
    if not callable(func):
        raise TypeError("function should be callable.")

    qualified_name = torch.jit._builtins._find_builtin(func)
    dst_worker_info = _to_worker_info(to)

    # TODO: profiling logic does not really belong in invoke_rpc, it should be
    # added as part of a context manager or helper (https://github.com/pytorch/pytorch/issues/36360)
    should_profile = torch.autograd._profiler_enabled()

    # Shared helper: yields a no-op context manager when profiling is off and
    # a record_function scope keyed on this RPC otherwise. It must run before
    # ``func`` is possibly swapped for its wrapped script function below, so
    # the profiling name is derived from the callable the caller passed in.
    ctx_manager = _enable_rpc_profiler(
        should_profile, qualified_name, func, rpc_type, dst_worker_info
    )

    with ctx_manager as rf:
        args = args if args else ()
        kwargs = kwargs if kwargs else {}

        is_async_exec = hasattr(func, "_wrapped_async_rpc_function")

        if is_async_exec:
            wrapped = func._wrapped_async_rpc_function
            if isinstance(wrapped, torch.jit.ScriptFunction):
                # Dispatch the underlying script function directly.
                func = wrapped

        if qualified_name is not None:
            fut = _invoke_rpc_builtin(
                dst_worker_info,
                qualified_name,
                rpc_timeout,
                *args,
                **kwargs
            )
        elif isinstance(func, torch.jit.ScriptFunction):
            fut = _invoke_rpc_torchscript(
                dst_worker_info.name,
                torch._jit_internal._qualified_name(func),
                args,
                kwargs,
                rpc_timeout,
                is_async_exec
            )
        else:
            (pickled_python_udf, tensors) = _default_pickler.serialize(
                PythonUDF(func, args, kwargs)
            )
            fut = _invoke_rpc_python_udf(
                dst_worker_info,
                pickled_python_udf,
                tensors,
                rpc_timeout,
                is_async_exec
            )
        if should_profile:
            assert torch.autograd._profiler_enabled()
            assert rf is not None
            # Schedule profiling callbacks to run when the future completes.
            # This returns a future that is completed when the original future
            # completes and the profiling callbacks have been completed as well,
            # to guarantee that fut.wait() completes the profiling. This new
            # future will contain the same value as the original future.
            fut = rf._call_end_callbacks_on_future(fut)
    return fut
def remote(to, func, args=None, kwargs=None, timeout=UNSET_RPC_TIMEOUT):
    r"""
    Make a remote call to run ``func`` on worker ``to`` and return an
    :class:`~torch.distributed.rpc.RRef` to the result value immediately.
    Worker ``to`` will be the owner of the returned
    :class:`~torch.distributed.rpc.RRef`, and the worker calling ``remote`` is
    a user. The owner manages the global reference count of its
    :class:`~torch.distributed.rpc.RRef`, and the owner
    :class:`~torch.distributed.rpc.RRef` is only destructed when globally there
    are no living references to it.

    Args:
        to (str or WorkerInfo or int): name/rank/``WorkerInfo`` of the
            destination worker.
        func (callable): a callable function, such as Python callables,
            builtin operators (e.g. :meth:`~torch.add`) and annotated
            TorchScript functions.
        args (tuple): the argument tuple for the ``func`` invocation.
        kwargs (dict): is a dictionary of keyword arguments for the ``func``
            invocation.
        timeout (float, optional): timeout in seconds for this remote call.
            If the creation of this :class:`~torch.distributed.rpc.RRef` on
            worker ``to`` is not successfully processed on this worker within
            this timeout, then the next time there is an attempt to use the
            RRef (such as ``to_here()``), a timeout will be raised indicating
            this failure. A value of 0 indicates an infinite timeout, i.e. a
            timeout error will never be raised. If not provided, the default
            value set during initialization or with ``_set_rpc_timeout`` is
            used.

    Returns:
        A user :class:`~torch.distributed.rpc.RRef` instance to the result
        value. Use the blocking API :meth:`torch.distributed.rpc.RRef.to_here`
        to retrieve the result value locally.

    .. warning ::
        Using GPU tensors as arguments or return values of ``func`` is not
        supported since we don't support sending GPU tensors over the wire.
        You need to explicitly copy GPU tensors to CPU before using them as
        arguments or return values of ``func``.

    .. warning ::
        The ``remote`` API does not copy storages of argument tensors until
        sending them over the wire, which could be done by a different thread
        depending on the RPC backend type. The caller should make sure that the
        contents of those tensors stay intact until the returned RRef is
        confirmed by the owner, which can be checked using the
        :meth:`torch.distributed.rpc.RRef.confirmed_by_owner` API.

    .. warning ::
        Errors such as timeouts for the ``remote`` API are handled on a
        best-effort basis. This means that when remote calls initiated by
        ``remote`` fail, such as with a timeout error, we take a best-effort
        approach to error handling. This means that errors are handled and set
        on the resulting RRef on an asynchronous basis. If the RRef has not
        been used by the application before this handling (such as ``to_here``
        or fork call), then future uses of the ``RRef`` will appropriately
        raise errors. However, it is possible that the user application will
        use the ``RRef`` before the errors are handled. In this case, errors
        may not be raised as they have not yet been handled.

    Example::
        Make sure that ``MASTER_ADDR`` and ``MASTER_PORT`` are set properly
        on both workers. Refer to :meth:`~torch.distributed.init_process_group`
        API for more details. For example,

        >>> export MASTER_ADDR=localhost
        >>> export MASTER_PORT=5678

        Then run the following code in two different processes:

        >>> # On worker 0:
        >>> import torch
        >>> import torch.distributed.rpc as rpc
        >>> rpc.init_rpc("worker0", rank=0, world_size=2)
        >>> rref1 = rpc.remote("worker1", torch.add, args=(torch.ones(2), 3))
        >>> rref2 = rpc.remote("worker1", torch.add, args=(torch.ones(2), 1))
        >>> x = rref1.to_here() + rref2.to_here()
        >>> rpc.shutdown()

        >>> # On worker 1:
        >>> import torch.distributed.rpc as rpc
        >>> rpc.init_rpc("worker1", rank=1, world_size=2)
        >>> rpc.shutdown()

        Below is an example of running a TorchScript function using RPC.

        >>> # On both workers:
        >>> @torch.jit.script
        >>> def my_script_add(t1, t2):
        >>>    return torch.add(t1, t2)

        >>> # On worker 0:
        >>> import torch.distributed.rpc as rpc
        >>> rpc.init_rpc("worker0", rank=0, world_size=2)
        >>> rref = rpc.remote("worker1", my_script_add, args=(torch.ones(2), 3))
        >>> rref.to_here()
        >>> rpc.shutdown()

        >>> # On worker 1:
        >>> import torch.distributed.rpc as rpc
        >>> rpc.init_rpc("worker1", rank=1, world_size=2)
        >>> rpc.shutdown()
    """
    qualified_name = torch.jit._builtins._find_builtin(func)
    dst_worker_info = _to_worker_info(to)
    should_profile = torch.autograd._profiler_enabled()

    # Shared helper: yields a no-op context manager when profiling is off and
    # a record_function scope keyed on this remote call otherwise. It must run
    # before ``func`` is possibly swapped for its wrapped script function
    # below, so the profiling name comes from the callable the caller passed.
    ctx_manager = _enable_rpc_profiler(
        should_profile, qualified_name, func, RPCExecMode.REMOTE, dst_worker_info
    )

    with ctx_manager as rf:
        args = args if args else ()
        kwargs = kwargs if kwargs else {}

        is_async_exec = hasattr(func, "_wrapped_async_rpc_function")

        if is_async_exec:
            wrapped = func._wrapped_async_rpc_function
            if isinstance(wrapped, torch.jit.ScriptFunction):
                # Dispatch the underlying script function directly.
                func = wrapped

        if qualified_name is not None:
            rref = _invoke_remote_builtin(
                dst_worker_info, qualified_name, timeout, *args, **kwargs
            )
        elif isinstance(func, torch.jit.ScriptFunction):
            rref = _invoke_remote_torchscript(
                dst_worker_info.name,
                torch._jit_internal._qualified_name(func),
                timeout,
                is_async_exec,
                *args,
                **kwargs,
            )
        else:
            (pickled_python_udf, tensors) = _default_pickler.serialize(
                PythonUDF(func, args, kwargs)
            )
            rref = _invoke_remote_python_udf(
                dst_worker_info,
                pickled_python_udf,
                tensors,
                timeout,
                is_async_exec
            )
        # attach profiling information
        if should_profile:
            assert torch.autograd._profiler_enabled()
            assert rf is not None
            # Run the profiler's end callbacks on the RRef's future and store
            # the resulting future on the RRef.
            fut = rf._call_end_callbacks_on_future(rref._get_future())
            rref._set_profiling_future(fut)

    return rref