Пример #1
0
def _invoke_rpc(to, func, rpc_type, args=None, kwargs=None, rpc_timeout=UNSET_RPC_TIMEOUT):
    """Issue an RPC that runs ``func`` on the worker ``to`` and return its future.

    The call is routed to one of three senders depending on ``func``: a
    builtin operator (when ``_find_builtin`` resolves a name for it), a
    ``torch.jit.ScriptFunction``, or any other callable, which is pickled as
    a ``PythonUDF``.

    Args:
        to: Destination, resolved through ``_to_worker_info``.
        func: The callable to execute remotely.
        rpc_type: Forwarded to ``_enable_rpc_profiler``; not used otherwise here.
        args: Positional arguments for ``func``; a falsy value means ``()``.
        kwargs: Keyword arguments for ``func``; a falsy value means ``{}``.
        rpc_timeout: Timeout value handed to the underlying sender unchanged.

    Returns:
        The future produced by the sender. When the autograd profiler is
        enabled, it is the future returned by
        ``rf._call_end_callbacks_on_future`` instead.

    Raises:
        TypeError: If ``func`` is not callable.
    """
    if not callable(func):
        raise TypeError("function should be callable.")

    builtin_name = torch.jit._builtins._find_builtin(func)
    dst_worker_info = _to_worker_info(to)
    should_profile = torch.autograd._profiler_enabled()

    with _enable_rpc_profiler(
        should_profile, builtin_name, func, rpc_type, dst_worker_info
    ) as rf:
        call_args = args or ()
        call_kwargs = kwargs or {}

        # A function carrying this attribute is flagged as async-execution;
        # if the wrapped target is a ScriptFunction, send that target instead
        # of the wrapper.
        is_async_exec = hasattr(func, "_wrapped_async_rpc_function")
        if is_async_exec:
            inner = func._wrapped_async_rpc_function
            if isinstance(inner, torch.jit.ScriptFunction):
                func = inner

        # Precedence: builtin operator, then TorchScript, then Python UDF.
        if builtin_name is not None:
            fut = _invoke_rpc_builtin(
                dst_worker_info, builtin_name, rpc_timeout, *call_args, **call_kwargs
            )
        elif not isinstance(func, torch.jit.ScriptFunction):
            pickled_python_udf, tensors = _default_pickler.serialize(
                PythonUDF(func, call_args, call_kwargs)
            )
            fut = _invoke_rpc_python_udf(
                dst_worker_info, pickled_python_udf, tensors, rpc_timeout, is_async_exec
            )
        else:
            script_name = torch._jit_internal._qualified_name(func)
            fut = _invoke_rpc_torchscript(
                dst_worker_info.name,
                script_name,
                call_args,
                call_kwargs,
                rpc_timeout,
                is_async_exec,
            )

        if should_profile:
            assert torch.autograd._profiler_enabled()
            assert rf is not None
            # Attach the profiler's end callbacks to the future. The returned
            # future completes only after both the original future and those
            # callbacks have finished, so fut.wait() also finishes profiling;
            # it carries the same value as the original future.
            fut = rf._call_end_callbacks_on_future(fut)
    return fut
Пример #2
0
def _invoke_rpc(to,
                func,
                rpc_type,
                args=None,
                kwargs=None,
                rpc_timeout=UNSET_RPC_TIMEOUT):
    """Issue an RPC that runs ``func`` on the worker ``to`` and return its future.

    The call is routed to one of three senders depending on ``func``: a
    builtin operator (when ``_find_builtin`` resolves a name for it), a
    ``torch.jit.ScriptFunction``, or any other callable, which is pickled as
    a ``PythonUDF``. When the autograd profiler is enabled, the send is
    wrapped in a ``record_function`` scope keyed by an RPC profiling key.

    Args:
        to: Destination, resolved through ``_to_worker_info``.
        func: The callable to execute remotely.
        rpc_type: Used only to build the RPC profiling key.
        args: Positional arguments for ``func``; a falsy value means ``()``.
        kwargs: Keyword arguments for ``func``; a falsy value means ``{}``.
        rpc_timeout: Timeout value handed to the underlying sender unchanged.

    Returns:
        The future produced by the sender. When the autograd profiler is
        enabled, it is the future returned by
        ``rf._call_end_callbacks_on_future`` instead.

    Raises:
        TypeError: If ``func`` is not callable.
    """
    if not callable(func):
        raise TypeError("function should be callable.")

    builtin_name = torch.jit._builtins._find_builtin(func)
    dst_worker_info = _to_worker_info(to)

    # TODO: profiling logic does not really belong in invoke_rpc, it should be
    # added as part of a context manager or helper (https://github.com/pytorch/pytorch/issues/36360)
    should_profile = torch.autograd._profiler_enabled()

    # suppress() with no exception types suppresses nothing; it stands in as
    # a do-nothing context manager when profiling is off.
    ctx_manager = contextlib.suppress()
    if should_profile:
        # Pick the display name according to the kind of function
        # (builtin, script, python).
        if builtin_name is not None:
            func_name = builtin_name
        elif isinstance(func, torch.jit.ScriptFunction):
            func_name = torch._jit_internal._qualified_name(func)
        else:
            func_name = func.__qualname__

        profiling_key = _build_rpc_profiling_key(
            rpc_type,
            func_name,
            get_worker_info().name,
            dst_worker_info.name,
        )
        RemoteProfilerManager.set_current_profiling_key(profiling_key)
        # Mypy doesn't support re-def of a variable not in the same block (#1174)
        ctx_manager = torch.autograd.profiler.record_function(
            profiling_key)  # type: ignore[assignment]

    with ctx_manager as rf:
        call_args = args or ()
        call_kwargs = kwargs or {}

        # A function carrying this attribute is flagged as async-execution;
        # if the wrapped target is a ScriptFunction, send that target instead
        # of the wrapper.
        is_async_exec = hasattr(func, "_wrapped_async_rpc_function")
        if is_async_exec:
            inner = func._wrapped_async_rpc_function
            if isinstance(inner, torch.jit.ScriptFunction):
                func = inner

        # Precedence: builtin operator, then TorchScript, then Python UDF.
        if builtin_name is not None:
            fut = _invoke_rpc_builtin(
                dst_worker_info, builtin_name, rpc_timeout,
                *call_args, **call_kwargs)
        elif not isinstance(func, torch.jit.ScriptFunction):
            pickled_python_udf, tensors = _default_pickler.serialize(
                PythonUDF(func, call_args, call_kwargs))
            fut = _invoke_rpc_python_udf(
                dst_worker_info, pickled_python_udf, tensors,
                rpc_timeout, is_async_exec)
        else:
            script_name = torch._jit_internal._qualified_name(func)
            fut = _invoke_rpc_torchscript(
                dst_worker_info.name, script_name, call_args, call_kwargs,
                rpc_timeout, is_async_exec)

        if should_profile:
            assert torch.autograd._profiler_enabled()
            assert rf is not None
            # Attach the profiler's end callbacks to the future. The returned
            # future completes only after both the original future and those
            # callbacks have finished, so fut.wait() also finishes profiling;
            # it carries the same value as the original future.
            fut = rf._call_end_callbacks_on_future(fut)
    return fut