Example #1
 def _register_ddp_hooks(self) -> None:
     log.detail(f"{self.__class__.__name__}: registering ddp hooks")
     if self.root_device.type == "cuda" and self._is_single_process_single_device:
         register_ddp_comm_hook(
             model=self.model,
             ddp_comm_state=self._ddp_comm_state,
             ddp_comm_hook=self._ddp_comm_hook,
             ddp_comm_wrapper=self._ddp_comm_wrapper,
         )
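
For context, the ddp_comm_* attributes read here are normally filled in through the strategy's constructor. A minimal usage sketch, assuming PyTorch Lightning >= 1.6 (where the class is DDPStrategy; older releases expose the same keyword arguments on DDPPlugin), that enables fp16 gradient compression:

import pytorch_lightning as pl
from pytorch_lightning.strategies import DDPStrategy
from torch.distributed.algorithms.ddp_comm_hooks import default_hooks as default

# The hook is forwarded to _register_ddp_hooks() and, from there, to
# DistributedDataParallel.register_comm_hook() on the wrapped model.
trainer = pl.Trainer(
    accelerator="gpu",
    devices=2,
    strategy=DDPStrategy(ddp_comm_hook=default.fp16_compress_hook),
)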
Example #2
 def _register_ddp_hooks(self) -> None:
     # currently, DDP communication hooks only work with NCCL backend and SPSD (single process single device) mode
     # https://github.com/pytorch/pytorch/blob/v1.8.0/torch/nn/parallel/distributed.py#L1080-L1084
     if _TORCH_GREATER_EQUAL_1_8 and self.on_gpu and self._is_single_process_single_device:
         register_ddp_comm_hook(
             model=self._model,
             ddp_comm_state=self._ddp_comm_state,
             ddp_comm_hook=self._ddp_comm_hook,
             ddp_comm_wrapper=self._ddp_comm_wrapper,
         )
Example #3
 def _register_ddp_hooks(self) -> None:
     # In 1.8, DDP communication hooks only work with NCCL backend and SPSD (single process single device) mode
     # Since 1.9, DDP communication hooks can work on all backends.
     if _TORCH_GREATER_EQUAL_1_9 or (_TORCH_GREATER_EQUAL_1_8
                                     and self.on_gpu and
                                     self._is_single_process_single_device):
         register_ddp_comm_hook(
             model=self._model,
             ddp_comm_state=self._ddp_comm_state,
             ddp_comm_hook=self._ddp_comm_hook,
             ddp_comm_wrapper=self._ddp_comm_wrapper,
         )
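
When the hook needs a state object, all three arguments come into play. A hedged sketch combining PowerSGD compression with the fp16 communication wrapper; the keyword names mirror the register_ddp_comm_hook parameters above and are assumed to be forwarded unchanged by DDPStrategy:

from pytorch_lightning.strategies import DDPStrategy
from torch.distributed.algorithms.ddp_comm_hooks import default_hooks as default
from torch.distributed.algorithms.ddp_comm_hooks import powerSGD_hook as powerSGD

strategy = DDPStrategy(
    ddp_comm_state=powerSGD.PowerSGDState(
        process_group=None,           # None -> the default process group
        matrix_approximation_rank=1,  # rank of the low-rank gradient approximation
        start_powerSGD_iter=5000,     # plain allreduce for the first 5000 steps
    ),
    ddp_comm_hook=powerSGD.powerSGD_hook,
    ddp_comm_wrapper=default.fp16_compress_wrapper,
)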
Example #4
    def _register_ddp_hooks(self) -> None:
        # In 1.8, DDP communication hooks only work with NCCL backend and SPSD (single process single device) mode
        # Since 1.9, DDP communication hooks can work on all backends.
        if _TORCH_GREATER_EQUAL_1_9 or (_TORCH_GREATER_EQUAL_1_8
                                        and self.on_gpu and
                                        self._is_single_process_single_device):
            register_ddp_comm_hook(
                model=self._model,
                ddp_comm_state=self._ddp_comm_state,
                ddp_comm_hook=self._ddp_comm_hook,
                ddp_comm_wrapper=self._ddp_comm_wrapper,
            )

            if (_TORCH_GREATER_EQUAL_1_10 and isinstance(
                    self._ddp_comm_state, post_localSGD.PostLocalSGDState)
                    and self.lightning_module.trainer.state.fn
                    == TrainerFn.FITTING):
                self._reinit_optimizers_with_post_localSGD(
                    self._ddp_comm_state.start_localSGD_iter)
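
The optimizer re-initialization in that last branch boils down, roughly, to wrapping the configured optimizer in torch's PostLocalSGDOptimizer so that parameters are averaged periodically once start_localSGD_iter steps have passed. The helper below is a hypothetical standalone illustration of that idea, not Lightning's actual implementation, and assumes torch >= 1.10 with torch.distributed already initialized:

import torch
from torch.distributed.algorithms.model_averaging.averagers import PeriodicModelAverager
from torch.distributed.optim import PostLocalSGDOptimizer

def wrap_with_post_local_sgd(optimizer: torch.optim.Optimizer, start_localSGD_iter: int):
    # Hypothetical helper: average model parameters every 4 optimizer steps,
    # but only after the globally synchronized warmup phase has finished.
    averager = PeriodicModelAverager(period=4, warmup_steps=start_localSGD_iter)
    return PostLocalSGDOptimizer(optimizer, averager)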
Example #5
    def _register_ddp_hooks(self) -> None:
        log.detail(f"{self.__class__.__name__}: registering ddp hooks")
        # In 1.8, DDP communication hooks only work with NCCL backend and SPSD (single process single device) mode
        # Since 1.9, DDP communication hooks can work on all backends.
        if _TORCH_GREATER_EQUAL_1_9 or (self.root_device.type == "cuda" and
                                        self._is_single_process_single_device):
            register_ddp_comm_hook(
                model=self.model,
                ddp_comm_state=self._ddp_comm_state,
                ddp_comm_hook=self._ddp_comm_hook,
                ddp_comm_wrapper=self._ddp_comm_wrapper,
            )

            if _TORCH_GREATER_EQUAL_1_10 and self.lightning_module.trainer.state.fn == TrainerFn.FITTING:
                import torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook as post_localSGD

                if isinstance(self._ddp_comm_state,
                              post_localSGD.PostLocalSGDState):
                    self._enable_model_averaging()
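
Finally, a sketch of the user-facing configuration that reaches this post-localSGD branch. model_averaging_period is the Lightning-side knob consumed by _enable_model_averaging / _reinit_optimizers_with_post_localSGD; the keyword name is taken from the Lightning docs, so check it against your installed version:

from pytorch_lightning.strategies import DDPStrategy
from torch.distributed.algorithms.ddp_comm_hooks import post_localSGD_hook as post_localSGD

strategy = DDPStrategy(
    ddp_comm_state=post_localSGD.PostLocalSGDState(
        process_group=None,        # default process group for the global allreduce
        subgroup=None,             # None -> intra-node subgroup for local SGD
        start_localSGD_iter=8,     # regular DDP allreduce for the first 8 steps
    ),
    ddp_comm_hook=post_localSGD.post_localSGD_hook,
    model_averaging_period=4,      # average parameters every 4 optimizer steps
)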