def _invoke_rpc(to, func, rpc_type, args=None, kwargs=None, rpc_timeout=UNSET_RPC_TIMEOUT):
    """Send ``func`` to the worker ``to`` and return a future for its result.

    ``func`` takes one of three transport paths:

    * a TorchScript builtin, when ``torch.jit._builtins._find_builtin``
      returns a name for it (``_invoke_rpc_builtin``);
    * a TorchScript function, when it is a ``torch.jit.ScriptFunction``
      (``_invoke_rpc_torchscript``);
    * otherwise a pickled ``PythonUDF`` (``_invoke_rpc_python_udf``).

    Args:
        to: destination worker, resolved through ``_to_worker_info``.
        func: the callable to run on the destination.
        rpc_type: RPC flavour, forwarded to ``_enable_rpc_profiler``.
        args: positional arguments for ``func``; ``None`` or empty means none.
        kwargs: keyword arguments for ``func``; ``None`` or empty means none.
        rpc_timeout: timeout forwarded to the chosen ``_invoke_rpc_*`` call.

    Returns:
        The future produced by the chosen ``_invoke_rpc_*`` call. When the
        autograd profiler is enabled, it is the future returned by
        ``_call_end_callbacks_on_future`` instead, so waiting on it also waits
        for the profiling callbacks.

    Raises:
        TypeError: if ``func`` is not callable.
    """
    if not callable(func):
        raise TypeError("function should be callable.")

    builtin_name = torch.jit._builtins._find_builtin(func)
    dst_info = _to_worker_info(to)
    profiling = torch.autograd._profiler_enabled()

    with _enable_rpc_profiler(
        profiling, builtin_name, func, rpc_type, dst_info
    ) as record_fn:
        # Treat None and empty containers alike.
        call_args = args or ()
        call_kwargs = kwargs or {}

        # Callables carrying ``_wrapped_async_rpc_function`` are sent with
        # async execution enabled. If the wrapped target is a TorchScript
        # function, that target is shipped in place of the Python wrapper.
        is_async_exec = hasattr(func, "_wrapped_async_rpc_function")
        inner = func._wrapped_async_rpc_function if is_async_exec else None
        if isinstance(inner, torch.jit.ScriptFunction):
            func = inner

        if builtin_name is not None:
            future = _invoke_rpc_builtin(
                dst_info, builtin_name, rpc_timeout, *call_args, **call_kwargs
            )
        elif isinstance(func, torch.jit.ScriptFunction):
            future = _invoke_rpc_torchscript(
                dst_info.name,
                torch._jit_internal._qualified_name(func),
                call_args,
                call_kwargs,
                rpc_timeout,
                is_async_exec,
            )
        else:
            pickled_udf, udf_tensors = _default_pickler.serialize(
                PythonUDF(func, call_args, call_kwargs)
            )
            future = _invoke_rpc_python_udf(
                dst_info, pickled_udf, udf_tensors, rpc_timeout, is_async_exec
            )

        if profiling:
            assert torch.autograd._profiler_enabled()
            assert record_fn is not None
            # This must run while the profiling context is still open. It
            # returns a future that completes only after the original future
            # and the profiling end-callbacks have both finished. That
            # guarantees wait() covers profiling, and the new future carries
            # the same value as the original.
            future = record_fn._call_end_callbacks_on_future(future)
    return future
def _rpc_profiling_context(should_profile, qualified_name, func, rpc_type, dst_worker_info):
    """Return the context manager that wraps a single RPC for profiling.

    When ``should_profile`` is false this is ``contextlib.nullcontext()``, a
    no-op that binds ``None`` in a ``with ... as`` clause. Otherwise it builds
    the RPC profiling key, registers it with
    ``RemoteProfilerManager.set_current_profiling_key`` and returns a
    ``torch.autograd.profiler.record_function`` for that key.

    Args:
        should_profile: whether the autograd profiler is currently enabled.
        qualified_name: builtin name from ``_find_builtin``, or ``None``.
        func: the callable being invoked, before any async unwrapping.
        rpc_type: RPC flavour; part of the profiling key.
        dst_worker_info: destination worker info; its ``name`` is part of the key.
    """
    if not should_profile:
        return contextlib.nullcontext()

    # Pick a string name for func based on its kind
    # (builtin, TorchScript, or plain Python callable).
    if qualified_name is not None:
        func_name = qualified_name
    elif isinstance(func, torch.jit.ScriptFunction):
        func_name = torch._jit_internal._qualified_name(func)
    else:
        # Objects that are merely callable (e.g. functools.partial, or an
        # instance defining __call__) have no __qualname__ of their own.
        # Fall back to the type's name rather than raising AttributeError
        # only when profiling is switched on.
        func_name = getattr(func, "__qualname__", type(func).__qualname__)

    rpc_profiling_key = _build_rpc_profiling_key(
        rpc_type,
        func_name,
        get_worker_info().name,
        dst_worker_info.name,
    )
    # Registered before the caller enters the returned context.
    RemoteProfilerManager.set_current_profiling_key(rpc_profiling_key)
    return torch.autograd.profiler.record_function(rpc_profiling_key)


def _invoke_rpc(to, func, rpc_type, args=None, kwargs=None, rpc_timeout=UNSET_RPC_TIMEOUT):
    """Send ``func`` to the worker ``to`` and return a future for its result.

    ``func`` is sent as a TorchScript builtin when ``_find_builtin`` returns a
    name for it, as a TorchScript function when it is a
    ``torch.jit.ScriptFunction``, and otherwise as a pickled ``PythonUDF``.

    Args:
        to: destination worker, resolved through ``_to_worker_info``.
        func: the callable to run on the destination.
        rpc_type: RPC flavour; used in the profiling key.
        args: positional arguments for ``func``; ``None`` or empty means none.
        kwargs: keyword arguments for ``func``; ``None`` or empty means none.
        rpc_timeout: timeout forwarded to the chosen ``_invoke_rpc_*`` call.

    Returns:
        The future produced by the chosen ``_invoke_rpc_*`` call. When the
        autograd profiler is enabled, it is the future returned by
        ``_call_end_callbacks_on_future``, so waiting on it also waits for the
        profiling callbacks.

    Raises:
        TypeError: if ``func`` is not callable.
    """
    if not callable(func):
        raise TypeError("function should be callable.")

    qualified_name = torch.jit._builtins._find_builtin(func)
    dst_worker_info = _to_worker_info(to)

    should_profile = torch.autograd._profiler_enabled()
    ctx_manager = _rpc_profiling_context(
        should_profile, qualified_name, func, rpc_type, dst_worker_info
    )

    with ctx_manager as rf:
        args = args if args else ()
        kwargs = kwargs if kwargs else {}

        # Callables carrying ``_wrapped_async_rpc_function`` are sent with
        # async execution enabled. A wrapped TorchScript function is shipped
        # in place of its Python wrapper.
        is_async_exec = hasattr(func, "_wrapped_async_rpc_function")
        if is_async_exec:
            wrapped = func._wrapped_async_rpc_function
            if isinstance(wrapped, torch.jit.ScriptFunction):
                func = wrapped

        if qualified_name is not None:
            fut = _invoke_rpc_builtin(
                dst_worker_info, qualified_name, rpc_timeout, *args, **kwargs
            )
        elif isinstance(func, torch.jit.ScriptFunction):
            fut = _invoke_rpc_torchscript(
                dst_worker_info.name,
                torch._jit_internal._qualified_name(func),
                args,
                kwargs,
                rpc_timeout,
                is_async_exec,
            )
        else:
            (pickled_python_udf, tensors) = _default_pickler.serialize(
                PythonUDF(func, args, kwargs)
            )
            fut = _invoke_rpc_python_udf(
                dst_worker_info, pickled_python_udf, tensors, rpc_timeout, is_async_exec
            )

        if should_profile:
            assert torch.autograd._profiler_enabled()
            assert rf is not None
            # Schedule profiling callbacks to run when the future completes.
            # This returns a future that is completed when the original future
            # completes and the profiling callbacks have been completed as well,
            # to guarantee that fut.wait() completes the profiling. This new
            # future will contain the same value as the original future.
            fut = rf._call_end_callbacks_on_future(fut)
    return fut