Exemple #1
0
def shutdown(graceful=True):
    r"""
    Shut down the local RPC agent and then destroy it. After this call the
    local agent no longer accepts outstanding requests and the RPC framework
    is torn down by terminating all RPC threads.

    With ``graceful=True`` the call blocks until every local and remote RPC
    process has reached this method, and waits for all outstanding work to
    complete. With ``graceful=False`` only a local shutdown is performed and
    other RPC processes are not waited for.

    .. warning::
        For :class:`~torch.futures.Future` objects returned by
        :meth:`~torch.distributed.rpc.rpc_async`, ``future.wait()`` should not
        be called after ``shutdown()``.

    Args:
        graceful (bool): Whether to do a graceful shutdown or not. If True,
                         this will 1) wait until there is no pending system
                         messages for ``UserRRefs`` and delete them; 2) block
                         until all local and remote RPC processes have reached
                         this method and wait for all outstanding work to
                         complete.

    Example::
        Make sure that ``MASTER_ADDR`` and ``MASTER_PORT`` are set properly
        on both workers. Refer to :meth:`~torch.distributed.init_process_group`
        API for more details. For example, in a shell:

        $ export MASTER_ADDR=localhost
        $ export MASTER_PORT=5678

        Then run the following code in two different processes:

        >>> # On worker 0:
        >>> import torch
        >>> import torch.distributed.rpc as rpc
        >>> rpc.init_rpc("worker0", rank=0, world_size=2)
        >>> # do some work
        >>> result = rpc.rpc_sync("worker1", torch.add, args=(torch.ones(1), 1))
        >>> # ready to shutdown
        >>> rpc.shutdown()

        >>> # On worker 1:
        >>> import torch.distributed.rpc as rpc
        >>> rpc.init_rpc("worker1", rank=1, world_size=2)
        >>> # wait for worker 0 to finish work, and then shutdown.
        >>> rpc.shutdown()
    """
    # Local-only shutdown: skip every coordination step.
    if not graceful:
        _finalize_shutdown()
        return

    try:
        # Coordinate with the other workers, drop RRefs, then join the agent.
        _wait_all_workers()
        _delete_all_user_and_unforked_owner_rrefs()
        _get_current_rpc_agent().join(shutdown=True)
    finally:
        # The local teardown must happen even if a step above raised.
        _finalize_shutdown()
Exemple #2
0
def get_worker_info(worker_name=None):
    r"""
    Look up the :class:`~torch.distributed.rpc.WorkerInfo` for a worker name.
    Use this :class:`~torch.distributed.rpc.WorkerInfo` to avoid passing an
    expensive string on every invocation.

    Args:
        worker_name (str): the string name of a worker. If ``None``, return
                           the info of the current worker. (default ``None``)

    Returns:
        :class:`~torch.distributed.rpc.WorkerInfo` instance for the given
        ``worker_name`` or :class:`~torch.distributed.rpc.WorkerInfo` of the
        current worker if ``worker_name`` is ``None``.
    """
    agent = _get_current_rpc_agent()
    if worker_name is None:
        # No name given: the agent reports its own worker info.
        return agent.get_worker_info()
    return agent.get_worker_info(worker_name)
Exemple #3
0
def _finalize_shutdown():
    """
    Run the local part of an RPC shutdown: destroy the RRef context, shut the
    current agent down, clean up the Python RPC handler and reset the current
    agent. The last three steps run even when destroying the RRef context
    raises.
    """
    try:
        # Raises a `TORCH_CHECK()` exception when an RRef leak is detected.
        _destroy_rref_context(_ignore_rref_leak)
    finally:
        local_agent = _get_current_rpc_agent()
        local_agent.shutdown()
        # The Python RPC handler is cleaned up from the Python API (see the
        # comments in PythonRpcHandler::cleanup()) because cleanup() has a
        # Python dependency and assumes the interpreter still exists.
        # This must run whether or not the RRef leak exception was raised, to
        # avoid a destruction segfault in Python 3.5.
        #
        # Consequence: future.wait() should not be called after shutdown().
        # Once the handler is cleaned up, Python objects returned from RPC
        # Python calls can no longer be resolved.
        _cleanup_python_rpc_handler()
        _reset_current_rpc_agent()
Exemple #4
0
def shutdown(graceful=True):
    r"""
    Perform a shutdown of the RPC agent, and then destroy the RPC agent. This
    stops the local agent from accepting outstanding requests, and shuts
    down the RPC framework by terminating all RPC threads. If ``graceful=True``,
    this will block until all local and remote RPC processes reach this method
    and wait for all outstanding work to complete. Otherwise, if
    ``graceful=False``, this is a local shutdown, and it does not wait for other
    RPC processes to reach this method.

    The local teardown (RRef context destruction, agent shutdown, Python RPC
    handler cleanup and agent reset) is always attempted, even when one of the
    graceful coordination steps raises; the exception still propagates to the
    caller afterwards.

    .. warning::
        For :class:`~torch.futures.Future` objects returned by
        :meth:`~torch.distributed.rpc.rpc_async`, ``future.wait()`` should not
        be called after ``shutdown()``.

    Args:
        graceful (bool): Whether to do a graceful shutdown or not. If True,
                         this will 1) wait until there is no pending system
                         messages for ``UserRRefs`` and delete them; 2) block
                         until all local and remote RPC processes have reached
                         this method and wait for all outstanding work to
                         complete.

    Example::
        Make sure that ``MASTER_ADDR`` and ``MASTER_PORT`` are set properly
        on both workers. Refer to :meth:`~torch.distributed.init_process_group`
        API for more details. For example, in a shell:

        $ export MASTER_ADDR=localhost
        $ export MASTER_PORT=5678

        Then run the following code in two different processes:

        >>> # On worker 0:
        >>> import torch
        >>> import torch.distributed.rpc as rpc
        >>> rpc.init_rpc("worker0", rank=0, world_size=2)
        >>> # do some work
        >>> result = rpc.rpc_sync("worker1", torch.add, args=(torch.ones(1), 1))
        >>> # ready to shutdown
        >>> rpc.shutdown()

        >>> # On worker 1:
        >>> import torch.distributed.rpc as rpc
        >>> rpc.init_rpc("worker1", rank=1, world_size=2)
        >>> # wait for worker 0 to finish work, and then shutdown.
        >>> rpc.shutdown()
    """
    try:
        if graceful:
            _wait_all_workers()
            _delete_all_user_and_unforked_owner_rrefs()
            _get_current_rpc_agent().join(shutdown=True)
    finally:
        # In case of errors in the graceful phase, continue to complete the
        # local shutdown so the agent and handler are not left alive.
        try:
            # This raises a `TORCH_CHECK()` exception on RRef leak detected.
            _destroy_rref_context(_ignore_rref_leak)
        finally:
            _get_current_rpc_agent().shutdown()
            # clean up python rpc handler in shutdown(), see comments in
            # PythonRpcHandler::cleanup(), call it in python API because the
            # cleanup() function has python dependency, it assumes python
            # interpreter exists.
            # No matter if RRef leak exception is raised, this clean-up code
            # must run to avoid destruction segfault in Python 3.5.
            #
            # future.wait() should not be called after shutdown().
            # pythonRpcHandler is cleaned up in shutdown(), after
            # shutdown(), python objects returned from rpc python call can not
            # be resolved.
            _cleanup_python_rpc_handler()
            _reset_current_rpc_agent()
Exemple #5
0
def _all_gather(obj, timeout=UNSET_RPC_TIMEOUT):
    r"""
    This is similar to torch.distributed.all_gather(), but is using RPC. It
    picks the worker with the smallest name (alphabetic order) as the leader.
    Then all followers send their data ``obj`` to the leader. After the leader
    has received all, it will broadcast the results back to all followers. This
    function blocks until all workers have received the gathered results.

    Args:
        obj: the local object contributed to the gather.
        timeout: timeout used for the RPCs issued by this function. When left
            at ``UNSET_RPC_TIMEOUT``, the value of ``get_rpc_timeout()`` is
            used instead.

    Returns:
        The ``gathered_objects`` held by this round's state entry.

    Raises:
        RuntimeError: on the leader, if waiting on any follower's broadcast
            future raises ``RuntimeError``.
    """
    assert (_ALL_WORKER_NAMES is not None
            ), "`_ALL_WORKER_NAMES` is not initialized for `def _all_gather`."
    leader_name = sorted(_ALL_WORKER_NAMES)[0]

    self_name = _get_current_rpc_agent().get_worker_info().name

    # Each call consumes one sequence id; the lock keeps the
    # read-and-increment atomic with respect to other users of the lock.
    global _all_gather_sequence_id
    with _all_gather_dict_lock:
        sequence_id = _all_gather_sequence_id
        _all_gather_sequence_id += 1

    is_leader = leader_name == self_name
    if timeout == UNSET_RPC_TIMEOUT:
        timeout = get_rpc_timeout()

    # Phase 1: Followers send their objects to the leader
    if is_leader:
        _gather_to_leader(sequence_id, self_name, obj)
    else:
        rpc_sync(
            leader_name,
            _gather_to_leader,
            args=(sequence_id, self_name, obj),
            timeout=timeout,
        )

    with _all_gather_dict_lock:
        states = _all_gather_sequence_id_to_states[sequence_id]
    states.proceed_signal.wait()

    # Phase 2: Leader broadcast gathered results to all followers
    # Leader's signal is the first to be unblocked, after receiving all
    # followers' data objects.
    if is_leader:
        worker_name_to_response_future_dict = {}
        for follower_name in _ALL_WORKER_NAMES - {leader_name}:
            fut = rpc_async(follower_name,
                            _broadcast_to_followers,
                            args=(sequence_id, states.gathered_objects),
                            timeout=timeout)
            worker_name_to_response_future_dict[follower_name] = fut

        # Wait on every future before reporting, so that all failing
        # followers are listed rather than only the first one.
        errors = []
        for follower_name, fut in worker_name_to_response_future_dict.items():
            try:
                fut.wait()
            except RuntimeError as ex:
                errors.append((follower_name, ex))

        if errors:
            raise RuntimeError(
                f"Followers {[e[0] for e in errors]} timed out in _all_gather "
                f"after {timeout:.2f} seconds. The first exception is {errors[0][1]}"
            )

    # This round is finished: remove its entry so the module-level state dict
    # does not grow by one entry per call. `states` still references the
    # entry, so the result stays available after removal.
    with _all_gather_dict_lock:
        _all_gather_sequence_id_to_states.pop(sequence_id, None)
    return states.gathered_objects
Exemple #6
0
def shutdown(graceful=True, timeout=DEFAULT_SHUTDOWN_TIMEOUT):
    r"""
    Shut down the local RPC agent and then destroy it. After this call the
    local agent no longer accepts outstanding requests and the RPC framework
    is torn down by terminating all RPC threads.

    With ``graceful=True`` the call blocks until every local and remote RPC
    process has reached this method, and waits for all outstanding work to
    complete. With ``graceful=False`` only a local shutdown is performed and
    other RPC processes are not waited for.

    .. warning::
        For :class:`~torch.futures.Future` objects returned by
        :meth:`~torch.distributed.rpc.rpc_async`, ``future.wait()`` should not
        be called after ``shutdown()``.

    Args:
        graceful (bool): Whether to do a graceful shutdown or not. If True,
                         this will 1) wait until there is no pending system
                         messages for ``UserRRefs`` and delete them; 2) block
                         until all local and remote RPC processes have reached
                         this method and wait for all outstanding work to
                         complete.
        timeout: Only used when ``graceful`` is true. It is handed to the
                 wait-for-all-workers step (static groups) and to the
                 agent's ``join`` call. The per-peer membership update RPCs
                 issued for dynamic groups are sent without it.

    Example::
        Make sure that ``MASTER_ADDR`` and ``MASTER_PORT`` are set properly
        on both workers. Refer to :meth:`~torch.distributed.init_process_group`
        API for more details. For example, in a shell:

        $ export MASTER_ADDR=localhost
        $ export MASTER_PORT=5678

        Then run the following code in two different processes:

        >>> # On worker 0:
        >>> import torch
        >>> import torch.distributed.rpc as rpc
        >>> rpc.init_rpc("worker0", rank=0, world_size=2)
        >>> # do some work
        >>> result = rpc.rpc_sync("worker1", torch.add, args=(torch.ones(1), 1))
        >>> # ready to shutdown
        >>> rpc.shutdown()

        >>> # On worker 1:
        >>> import torch.distributed.rpc as rpc
        >>> rpc.init_rpc("worker1", rank=1, world_size=2)
        >>> # wait for worker 0 to finish work, and then shutdown.
        >>> rpc.shutdown()
    """
    # Local-only shutdown: skip every coordination step.
    if not graceful:
        _finalize_shutdown()
        return

    try:
        agent = _get_current_rpc_agent()
        in_dynamic_group = (
            isinstance(agent, TensorPipeAgent) and not agent.is_static_group
        )
        if in_dynamic_group:
            # Dynamic group: the membership token must be held while this
            # worker announces itself to the peers and joins.
            self_info = agent.get_worker_info()
            self_name = self_info.name
            with _group_membership_management(agent.store, self_name, False):
                for peer in agent.get_worker_infos():
                    if peer.name == self_name:
                        continue
                    # NOTE(review): the trailing ([], {}, False) arguments are
                    # passed through unchanged; their meaning is defined by
                    # `_update_group_membership` -- confirm there.
                    rpc_sync(
                        peer.name,
                        _update_group_membership,
                        args=(self_info, [], {}, False),
                    )
                agent.join(shutdown=True, timeout=timeout)
        else:
            # Static group (or a non-TensorPipe agent): coordinate with all
            # workers, drop RRefs, then join the agent.
            _wait_all_workers(timeout)
            _delete_all_user_and_unforked_owner_rrefs()
            agent.join(shutdown=True, timeout=timeout)
    finally:
        # The local teardown must happen even if a step above raised.
        _finalize_shutdown()
Exemple #7
0
def _all_gather(obj, worker_names=None, timeout=UNSET_RPC_TIMEOUT):
    r"""
    RPC-based counterpart of torch.distributed.all_gather(). The worker whose
    name sorts first (alphabetic order) acts as the leader. Every follower
    sends its ``obj`` to the leader; once the leader has received all of them
    it broadcasts the gathered results back to the followers. The call blocks
    until the local worker is signalled to proceed (subject to the signal
    timeout described below).

    Args:
        obj: the local object contributed to the gather.
        worker_names: names of the participating workers. When empty or
            ``None``, ``_ALL_WORKER_NAMES`` is used.
        timeout: ``UNSET_RPC_TIMEOUT`` uses ``get_rpc_timeout()`` for the
            RPCs and waits on the proceed signal without a timeout;
            ``DEFAULT_SHUTDOWN_TIMEOUT`` is passed to the RPCs as-is and the
            signal is again waited on without a timeout; any other value is
            used for both the RPCs and the signal wait.

    Returns:
        The ``gathered_objects`` of this round's state entry, which is
        removed from the module-level state dict before returning.

    Raises:
        RuntimeError: on the leader, if waiting on any follower's broadcast
            future raises ``RuntimeError``.
    """
    if not worker_names:
        assert (
            _ALL_WORKER_NAMES is not None
        ), "`_ALL_WORKER_NAMES` is not initialized for `def _all_gather`."
        worker_names = _ALL_WORKER_NAMES
    leader_name = sorted(worker_names)[0]

    self_name = _get_current_rpc_agent().get_worker_info().name

    # The sequence id is the concatenated sorted names plus a per-group
    # counter, so each group of workers advances its own counter.
    with _all_gather_dict_lock:
        group_key = "".join(sorted(worker_names))
        round_number = _all_gather_sequence_id.get(group_key, 0)
        _all_gather_sequence_id[group_key] = round_number + 1
        sequence_id = group_key + str(round_number)

    is_leader = leader_name == self_name

    if timeout == UNSET_RPC_TIMEOUT:
        # RPC timeout comes from the agent; the signal wait is unbounded.
        rpc_timeout, signal_timeout = get_rpc_timeout(), None
    elif timeout == DEFAULT_SHUTDOWN_TIMEOUT:
        # RPCs receive the shutdown default; the signal wait is unbounded.
        rpc_timeout, signal_timeout = timeout, None
    else:
        # A caller-supplied timeout bounds both the RPCs and the signal wait.
        rpc_timeout = signal_timeout = timeout

    # Phase 1: every worker reports its object to the leader (the leader
    # records its own contribution with a direct call).
    gather_args = (sequence_id, self_name, obj, worker_names)
    if is_leader:
        _gather_to_leader(*gather_args)
    else:
        rpc_sync(
            leader_name,
            _gather_to_leader,
            args=gather_args,
            timeout=rpc_timeout,
        )

    with _all_gather_dict_lock:
        states = _all_gather_sequence_id_to_states[sequence_id]

    # Waits for at most `signal_timeout`, or indefinitely when it is None.
    # NOTE(review): the result of wait() is not inspected, so a timed-out
    # wait falls through to phase 2 -- confirm this is intended.
    states.proceed_signal.wait(timeout=signal_timeout)

    # Phase 2: the leader, unblocked once all followers' objects have
    # arrived, broadcasts the gathered results to every follower.
    if is_leader:
        pending_responses = {
            follower_name: rpc_async(
                follower_name,
                _broadcast_to_followers,
                args=(sequence_id, states.gathered_objects),
                timeout=rpc_timeout,
            )
            for follower_name in worker_names - {leader_name}
        }

        # Wait on every future before reporting, so that all failing
        # followers are listed rather than only the first one.
        errors = []
        for follower_name, fut in pending_responses.items():
            try:
                fut.wait()
            except RuntimeError as ex:
                errors.append((follower_name, ex))

        if errors:
            raise RuntimeError(
                f"Followers {[e[0] for e in errors]} timed out in _all_gather "
                f"after {rpc_timeout:.2f} seconds. The first exception is {errors[0][1]}"
            )

    # The round is over: remove its state entry and return what it gathered.
    with _all_gather_dict_lock:
        finished = _all_gather_sequence_id_to_states.pop(sequence_id)
    return finished.gathered_objects