Example no. 1
    def configure_ddp(self):
        """ Override LightningModule ddp if using model parallel.
            Sets find_unused_parameters to True.
        """

        app_state = AppState()

        if app_state.model_parallel_size is not None:
            logging.info("Configuring DDP for model parallelism.")

            # With model parallelism, multiple GPUs form a large "logical GPU"
            # this means that data parallel groups span multiple GPUs
            # and are non-trivial
            device_ids = self.determine_ddp_device_ids()
            self._model = DistributedDataParallel(
                LightningDistributedModule(self.model),
                device_ids=device_ids,
                output_device=device_ids[0],
                process_group=app_state.data_parallel_group,
                find_unused_parameters=True,
                **self._ddp_kwargs,
            )

        else:
            super().configure_ddp()
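
A minimal sketch in plain torch.distributed of the idea behind example no. 1: with model parallelism, DDP is handed an explicit data-parallel process group instead of the default world group, so gradients are all-reduced only across ranks that hold the same model shard. The function names are illustrative and assume init_process_group() has already been called and each rank has selected its CUDA device.

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel

def build_data_parallel_group(model_parallel_size: int):
    """Group the ranks that hold the same model shard into one data-parallel group."""
    world_size = dist.get_world_size()
    rank = dist.get_rank()
    data_parallel_group = None
    # Ranks i, i + model_parallel_size, i + 2 * model_parallel_size, ... share a shard.
    for i in range(model_parallel_size):
        ranks = list(range(i, world_size, model_parallel_size))
        group = dist.new_group(ranks)  # collective: every rank must create every group
        if rank in ranks:
            data_parallel_group = group
    return data_parallel_group

def wrap_for_model_parallel(module, model_parallel_size: int):
    group = build_data_parallel_group(model_parallel_size)
    return DistributedDataParallel(
        module,
        device_ids=[torch.cuda.current_device()],
        process_group=group,  # all-reduce only across data-parallel replicas
        find_unused_parameters=True,
    )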
Example no. 2
 def configure_ddp(self):
     self.pre_configure_ddp()
     self._model = DistributedDataParallel(
         LightningDistributedModule(self.model),
         device_ids=self.determine_ddp_device_ids(),
         **self._ddp_kwargs,
     )
Example no. 3
    def configure_ddp(self):
        """ Override LightningModule ddp if using model parallel.
            Sets find_unused_parameters to False to use activation-checkpoint-recomputation.
        """

        app_state = AppState()

        if app_state.model_parallel_size is not None:
            logging.info("Configuring DDP for model parallelism.")

            # With model parallelism, multiple GPUs form a large "logical GPU"
            # this means that data parallel groups span multiple GPUs
            # and are non-trivial
            # TODO: for megatron-lm self.model is a list
            self.pre_configure_ddp()
            # device_ids = self.determine_ddp_device_ids()
            self._model = DistributedDataParallel(
                LightningDistributedModule(self.model),
                process_group=parallel_state.get_data_parallel_group(),
                **self._ddp_kwargs,
            )

            if self.no_ddp_communication_hook:
                # When using custom gradient accumulation and allreduce, disable
                # DDP communication hook that works on the gradient bucket.
                # Instead, use the custom gradient function and communication hook,
                # which is defined in the master optimizer wrapper.
                self._model.require_backward_grad_sync = False
                self._model.register_comm_hook(None, noop_hook)

        else:
            super().configure_ddp()
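
A minimal sketch in plain PyTorch of the trick at the end of example no. 3: DDP's own bucketed all-reduce is switched off so that an external component (in the example, the master optimizer wrapper) can own gradient accumulation and communication. The wrapper function below is illustrative and assumes the default process group is already initialized.

import torch
from torch.nn.parallel import DistributedDataParallel
from torch.distributed.algorithms.ddp_comm_hooks.debugging_hooks import noop_hook

def wrap_without_ddp_allreduce(module: torch.nn.Module) -> DistributedDataParallel:
    ddp = DistributedDataParallel(module)
    ddp.require_backward_grad_sync = False  # skip the reducer-driven gradient sync
    ddp.register_comm_hook(state=None, hook=noop_hook)  # buckets pass through untouched
    return ddp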
Example no. 4
 def configure_ddp(self):
     self.pre_configure_ddp()
     self._model = DistributedDataParallel(
         LightningDistributedModule(self.model),
         **self._ddp_kwargs,
     )
     self._register_ddp_hooks()
Example no. 5
    def configure_ddp(self) -> None:
        self.pre_configure_ddp()
        self.model = self._setup_model(LightningDistributedModule(self.model))
        self._register_ddp_hooks()

        # set up optimizers after the wrapped module has been moved to the device
        self.setup_optimizers(self.lightning_module.trainer)
        optimizers_to_device(self.optimizers, self.root_device)
Example no. 6
 def configure_ddp(self) -> None:
     # DDP does not accept static_graph as a constructor argument with torch < 1.11
     if _TORCH_LESSER_EQUAL_1_10_2:
         log.detail(
             f"{self.__class__.__name__}: configuring DistributedDataParallel"
         )
         self.pre_configure_ddp()
         self.model = self._setup_model(
             LightningDistributedModule(self.model))  # type: ignore
         if self.root_device.type == "hpu" and self._static_graph:
             self._model._set_static_graph()  # type: ignore
         self._register_ddp_hooks()
     else:
         super().configure_ddp()
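
A short sketch in plain PyTorch of the static-graph setting toggled in examples no. 6 and no. 10: torch >= 1.11 accepts it as a constructor argument, while older releases only expose the private _set_static_graph() helper used above. The wrapper function is illustrative and assumes the default process group is already initialized.

from torch.nn.parallel import DistributedDataParallel

def wrap_with_static_graph(module, device_id):
    return DistributedDataParallel(
        module,
        device_ids=[device_id],
        static_graph=True,  # the set of used parameters never changes between iterations
    )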
Example no. 7
 def connect(self, model: "pl.LightningModule") -> None:
     TPUSpawnStrategy._validate_patched_dataloaders(model)
     self.wrapped_model = xmp.MpModelWrapper(
         LightningDistributedModule(model))
     return super().connect(model)
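
A minimal torch_xla sketch of the wrapper used in example no. 7: xmp.MpModelWrapper keeps a single copy of the model in shared memory so that each forked TPU process can move it to its own device without re-instantiating it. The module and process count below are placeholders.

import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp

WRAPPED = xmp.MpModelWrapper(torch.nn.Linear(8, 2))  # built once in the parent process

def _mp_fn(index):
    device = xm.xla_device()
    model = WRAPPED.to(device)  # each process materializes the shared model on its device
    # ... training loop goes here ...

if __name__ == "__main__":
    xmp.spawn(_mp_fn, nprocs=8, start_method="fork")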
Example no. 8
 def configure_ddp(self) -> None:
     log.detail(
         f"{self.__class__.__name__}: configuring DistributedDataParallel")
     self.pre_configure_ddp()
     self.model = self._setup_model(LightningDistributedModule(self.model))
     self._register_ddp_hooks()
Example no. 9
 def configure_ddp(self) -> None:
     self.pre_configure_ddp()
     self.model = self._setup_model(LightningDistributedModule(self.model))
     self._register_ddp_hooks()
Example no. 10
 def configure_ddp(self):
     self.pre_configure_ddp()
     self._model = self._setup_model(LightningDistributedModule(self.model))
     self._register_ddp_hooks()
     self._model._set_static_graph()
Example no. 11
 def configure_ddp(self):
     self.pre_configure_ddp()
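     # Note: delay_allreduce is not an argument of torch.nn.parallel.DistributedDataParallel;
     # this snippet presumably uses apex.parallel.DistributedDataParallel, which defers the
     # all-reduce until the backward pass has produced all gradients.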
     self._model = DistributedDataParallel(
         LightningDistributedModule(self.model),
         delay_allreduce=True,
     )