Example #1
0
    def _register_comm_hook(self, state: object, hook: callable) -> None:
        r"""
        Register a communication hook which is an enhancement that provides a
        flexible hook to users where they can specify how DDP aggregates gradients
        across multiple workers.

        This hook would be very useful for researchers to try out new ideas. For
        example, this hook can be used to implement several algorithms like GossipGrad
        and gradient compression which involve different communication strategies for
        parameter syncs while running Distributed DataParallel training.

        Arguments:
            state (object): state is passed to the hook and can be used to maintain
                            and update any state information that users would like to
                            maintain as part of the training process. Examples: error
                            feedback in gradient compression, peers to communicate with
                            next in GossipGrad etc.
            hook (callable): is defined as:
                             hook(state: object, bucket: dist.GradBucket) -> torch.futures.Future:

                             This function is called once the bucket is ready. The
                             hook can perform whatever processing is needed and return
                             a Future indicating completion of any async work (ex: allreduce).
                             If the hook doesn't perform any communication, it can also
                             just return a completed Future. The Future should hold the
                             new value of grad bucket's tensors. Once a bucket is ready,
                             c10d reducer would call this hook and use the tensors returned
                             by the Future and copy grads to individual parameters.

        .. warning ::
            DDP communication hook can only be registered once and should be registered
            before calling backward.

        .. warning ::
            The torch.futures.Future object that hook returns should contain a result that
            has the same shape with the tensors inside GradBucket bucket.

        .. warning ::
            DDP communication hook is experimental and subject to change.

        Example::
            Below is an example of a noop hook that returns back the same tensors:

            >>> def noop(state: object, bucket: dist.GradBucket) -> torch.futures.Future:
            >>>     fut = torch.futures.Future()
            >>>     fut.set_result(bucket.get_tensors())
            >>>     return fut

            >>> ddp._register_comm_hook(state=None, hook=noop)

        """
        self._check_comm_hook(hook)
        dist._register_comm_hook(self.reducer, state, hook)
Example #2
0
    def _register_comm_hook(self, state: object, hook: callable) -> None:
        r"""
        Register a communication hook which is an enhancement that provides a
        flexible hook to users where they can specify how DDP aggregates gradients
        across multiple workers.

        This hook would be very useful for researchers to try out new ideas. For
        example, this hook can be used to implement several algorithms like GossipGrad
        and gradient compression which involve different communication strategies for
        parameter syncs while running Distributed DataParallel training.

        Arguments:
            state (object): state is passed to the hook and can be used to maintain
                            and update any state information that users would like to
                            maintain as part of the training process. Examples: error
                            feedback in gradient compression, peers to communicate with
                            next in GossipGrad etc.
            hook (callable): is defined as:
                             hook(state: object, bucket: dist._GradBucket) -> torch.futures.Future:

                             This function is called once the bucket is ready. The
                             hook can perform whatever processing is needed and return
                             a Future indicating completion of any async work (ex: allreduce).
                             If the hook doesn't perform any communication, it can also
                             just return a completed Future. The Future should hold the
                             new value of grad bucket's tensors. Once a bucket is ready,
                             c10d reducer would call this hook and use the tensors returned
                             by the Future and copy grads to individual parameters.

                             We also provide an API called ``get_future`` to retrieve a
                             Future associated with the completion of ``c10d.ProcessGroup.work``.

        .. warning ::
            Grad bucket's tensors will not be predivided by world_size. User is responsible
            to divide by the world_size in case of operations like allreduce.

        .. warning ::
            DDP communication hook can only be registered once and should be registered
            before calling backward.

        .. warning ::
            The Future object that hook returns should contain a result that has the same
            shape with the tensors inside grad bucket.

        .. warning ::
            DDP communication hook does not support single-process multiple-device mode.
            Gradbucket tensors should consist of only a single tensor.

        .. warning ::
            ``get_future`` API supports only NCCL backend and will return a ``torch._C.Future``
            which is an internal type and should be used with caution. It can still be used by
            ``_register_comm_hook`` API, but it is subject to some subtle differences compared
            to ``torch.futures.Future``.

        .. warning ::
            DDP communication hook is experimental and subject to change.

        Example::
            Below is an example of a noop hook that returns back the same tensors:

            >>> def noop(state: object, bucket: dist._GradBucket) -> torch.futures.Future:
            >>>     fut = torch.futures.Future()
            >>>     fut.set_result(bucket.get_tensors())
            >>>     return fut

            >>> ddp._register_comm_hook(state=None, hook=noop)

        Example::
            Below is an example of a Parallel SGD algorithm where gradients are encoded before
            allreduce, and then decoded after allreduce.

            >>> def encode_and_decode(state: object, bucket: dist._GradBucket) -> torch.futures.Future:
            >>>     tensors = [t / process_group.world_size for t in bucket.get_tensors()]
            >>>     encoded_tensors = encode(tensors) # encode gradients
            >>>     fut = process_group.allreduce(encoded_tensors).get_future()
            >>>     # Define the then callback to decode. Named distinctly so it does
            >>>     # not shadow the user's decode() helper called inside it.
            >>>     def decode_callback(fut):
            >>>         decoded_tensors = decode(fut.value()) # decode gradients
            >>>         return decoded_tensors
            >>>     return fut.then(decode_callback)

            >>> ddp._register_comm_hook(state=None, hook=encode_and_decode)

        """
        self._check_comm_hook(hook)
        dist._register_comm_hook(self.reducer, state, hook)