def _register_comm_hook(self, state: object, hook: callable):
    r"""
    Registers a communication hook, which gives users a flexible way to specify
    how DDP aggregates gradients across multiple workers.

    This hook is useful for researchers trying out new ideas. For example, it
    can be used to implement algorithms like GossipGrad and gradient
    compression, which involve different communication strategies for
    parameter syncs while running Distributed DataParallel training.

    Arguments:
        state (object): state is passed to the hook and can be used to maintain
                        and update any state information that users would like to
                        maintain as part of the training process. Examples: error
                        feedback in gradient compression, peers to communicate with
                        next in GossipGrad, etc.
        hook (callable): is defined as:
                         hook(state: object, bucket: dist._GradBucket) -> torch.futures.Future:

                         This function is called once the bucket is ready. The
                         hook can perform whatever processing is needed and return
                         a Future indicating completion of any async work (ex: allreduce).
                         If the hook doesn't perform any communication, it can also
                         just return a completed Future. The Future should hold the
                         new value of the grad bucket's tensors. Once a bucket is ready,
                         the c10d reducer calls this hook, uses the tensors returned
                         by the Future, and copies the grads to individual parameters.

                         We also provide an API called ``get_future`` to retrieve a
                         Future associated with the completion of ``c10d.ProcessGroup.work``.

    .. warning ::
        The grad bucket's tensors will not be predivided by world_size. The user is
        responsible for dividing by the world_size in case of operations like allreduce.

    .. warning ::
        The DDP communication hook can only be registered once and should be
        registered before calling backward.

    .. warning ::
        The Future object that the hook returns should contain a result that has the
        same shape as the tensors inside the grad bucket.

    .. warning ::
        The DDP communication hook does not support single-process multiple-device
        mode. Grad bucket tensors should consist of only a single tensor.

    .. warning ::
        The ``get_future`` API supports only the NCCL backend and will return a
        ``torch._C.Future``, which is an internal type and should be used with
        caution. It can still be used by the ``_register_comm_hook`` API, but it is
        subject to some subtle differences compared to ``torch.futures.Future``.

    .. warning ::
        The DDP communication hook is experimental and subject to change.

    Example::
        Below is an example of a noop hook that returns the same tensors:

        >>> def noop(state: object, bucket: dist._GradBucket) -> torch.futures.Future:
        >>>     fut = torch.futures.Future()
        >>>     fut.set_result(bucket.get_tensors())
        >>>     return fut

        >>> ddp._register_comm_hook(state=None, hook=noop)

    Example::
        Below is an example of a Parallel SGD algorithm where gradients are encoded
        before allreduce, and then decoded after allreduce.

        >>> def encode_and_decode(state: object, bucket: dist._GradBucket) -> torch.futures.Future:
        >>>     tensors = [t / process_group.world_size for t in bucket.get_tensors()]
        >>>     encoded_tensors = encode(tensors)  # encode gradients
        >>>     fut = process_group.allreduce(encoded_tensors).get_future()
        >>>     # Define the then callback to decode. (Named decode_callback to
        >>>     # avoid shadowing the user-defined decode function.)
        >>>     def decode_callback(fut):
        >>>         decoded_tensors = decode(fut.value())  # decode gradients
        >>>         return decoded_tensors
        >>>     return fut.then(decode_callback)

        >>> ddp._register_comm_hook(state=None, hook=encode_and_decode)
    """
    self._check_comm_hook(hook)
    dist._register_comm_hook(self.reducer, state, hook)
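The hook contract described above — ``hook(state, bucket)`` returns a Future holding the bucket's new tensor values — can be exercised without torch or an initialized process group. The sketch below is illustrative only: ``FakeGradBucket`` is a hypothetical stand-in for ``dist._GradBucket``, and ``concurrent.futures.Future`` stands in for ``torch.futures.Future``, so plain Python lists play the role of tensors.

```python
# Minimal sketch of the DDP comm-hook contract: (state, bucket) -> Future.
# FakeGradBucket is a hypothetical stand-in for dist._GradBucket, and
# concurrent.futures.Future stands in for torch.futures.Future.
from concurrent.futures import Future


class FakeGradBucket:
    """Stand-in for dist._GradBucket: holds the bucketed gradient 'tensors'."""

    def __init__(self, tensors):
        self._tensors = tensors

    def get_tensors(self):
        return self._tensors


def noop_hook(state, bucket):
    # A hook that performs no communication: return an already-completed
    # Future whose result is the bucket's tensors, unchanged.
    fut = Future()
    fut.set_result(bucket.get_tensors())
    return fut


bucket = FakeGradBucket([[1.0, 2.0], [3.0]])
result = noop_hook(None, bucket).result()
print(result)  # [[1.0, 2.0], [3.0]]
```

In real DDP usage the reducer, not user code, calls the hook once a bucket is ready, and the returned Future's result is copied back into the parameters' grads.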