def configure_ddp(self):
    if self.main_rpc_process:
        self.pre_configure_ddp()

        self._model = DistributedDataParallel(
            LightningDistributedModule(self.model),
            device_ids=self.determine_ddp_device_ids(),
            process_group=mpu.get_data_parallel_group(),
            **self._ddp_kwargs,
        )
        # The plugin handles the backward pass across processes; gradient sync is
        # currently not supported for DDP + pipe parallel, so disable it in DDP.
        self._model.require_backward_grad_sync = False
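# Hedged aside, not part of the plugin above: ``require_backward_grad_sync`` is the
# flag that ``torch.nn.parallel.DistributedDataParallel.no_sync()`` toggles, so setting
# it to ``False`` permanently skips DDP's gradient all-reduce during backward and
# leaves gradient handling to the pipe engine. A minimal sketch, assuming a process
# group is already initialized and ``ddp_model``/``batch`` are supplied by the caller:
import torch
from torch.nn.parallel import DistributedDataParallel


def backward_without_grad_sync(ddp_model: DistributedDataParallel, batch: torch.Tensor) -> None:
    # With the flag off, gradients stay local to each process; no all-reduce happens.
    ddp_model.require_backward_grad_sync = False
    ddp_model(batch).sum().backward()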
def configure_ddp(self, model: LightningModule, device_ids: List[int]) -> DistributedDataParallel:
    """
    Pass through all customizations from constructor to :class:`~torch.nn.parallel.DistributedDataParallel`.
    Override to define a custom DDP implementation.

    .. note:: This requires that your DDP implementation subclasses
        :class:`~torch.nn.parallel.DistributedDataParallel` and that the original LightningModule gets wrapped by
        :class:`~pytorch_lightning.overrides.data_parallel.LightningDistributedModule`.

    The default implementation is::

        def configure_ddp(self, model, device_ids):
            model = DistributedDataParallel(
                LightningDistributedModule(model),
                device_ids=device_ids,
                **self._ddp_kwargs,
            )
            return model

    Args:
        model: the LightningModule
        device_ids: the list of devices available

    Returns:
        the model wrapped in :class:`~torch.nn.parallel.DistributedDataParallel`
    """
    # if unset, default `find_unused_parameters` `True`
    self._ddp_kwargs["find_unused_parameters"] = self._ddp_kwargs.get("find_unused_parameters", True)
    model = DistributedDataParallel(
        module=LightningDistributedModule(model),
        device_ids=device_ids,
        **self._ddp_kwargs,
    )
    return model
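# A minimal sketch of the override described in the docstring above: a custom plugin
# that keeps ``find_unused_parameters`` off instead of Lightning's ``True`` default.
# The ``DDPPlugin`` base class, its ``_ddp_kwargs`` attribute, and the import paths
# are assumptions for the Lightning version at hand; only ``LightningDistributedModule``
# and the ``configure_ddp`` signature come from the snippet above.
from typing import List

from torch.nn.parallel import DistributedDataParallel

from pytorch_lightning import LightningModule
from pytorch_lightning.overrides.data_parallel import LightningDistributedModule
from pytorch_lightning.plugins import DDPPlugin


class NoUnusedParamsDDPPlugin(DDPPlugin):

    def configure_ddp(self, model: LightningModule, device_ids: List[int]) -> DistributedDataParallel:
        # Skip the unused-parameter search unless the user explicitly asked for it;
        # it adds overhead and is unnecessary when every parameter participates in
        # each forward pass.
        self._ddp_kwargs.setdefault("find_unused_parameters", False)
        return DistributedDataParallel(
            module=LightningDistributedModule(model),
            device_ids=device_ids,
            **self._ddp_kwargs,
        )

# The subclass would then typically be handed to the Trainer, e.g.
# ``Trainer(plugins=[NoUnusedParamsDDPPlugin()])`` (argument name assumed).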
def __init__(self, module: LightningModule, *args, **kwargs):
    warnings.warn(
        "The usage of `LightningDistributedDataParallel` is deprecated since v1.2 and will be removed in v1.4."
        " From now on we recommend to directly subclass `torch.nn.parallel.DistributedDataParallel`.",
        DeprecationWarning
    )
    super().__init__(LightningDistributedModule(module), *args, **kwargs)
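# Hedged migration sketch for the deprecation above: instead of constructing the
# deprecated ``LightningDistributedDataParallel``, wrap the LightningModule explicitly
# with ``LightningDistributedModule`` and hand it to the upstream
# ``torch.nn.parallel.DistributedDataParallel``, mirroring the default implementation
# shown earlier. The helper name, the ``local_rank`` argument, and the import paths
# are illustrative assumptions.
from torch.nn.parallel import DistributedDataParallel

from pytorch_lightning import LightningModule
from pytorch_lightning.overrides.data_parallel import LightningDistributedModule


def wrap_for_ddp(lightning_module: LightningModule, local_rank: int) -> DistributedDataParallel:
    # before (deprecated since v1.2):
    #     LightningDistributedDataParallel(lightning_module, device_ids=[local_rank])
    # after: plain torch DDP around the explicitly wrapped module
    return DistributedDataParallel(
        LightningDistributedModule(lightning_module),
        device_ids=[local_rank],
    )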