Code example #1
    def configure_ddp(self):
        if self.main_rpc_process:
            self.pre_configure_ddp()

            self._model = DistributedDataParallel(
                LightningDistributedModule(self.model),
                device_ids=self.determine_ddp_device_ids(),
                process_group=mpu.get_data_parallel_group(),
                **self._ddp_kwargs,
            )
            # The plugin handles the backward pass across processes. Currently not supported for DDP + pipe parallelism.
            self._model.require_backward_grad_sync = False
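
For context, a minimal, self-contained sketch of the same idea outside the plugin: building DistributedDataParallel over an explicit process group so gradients are only all-reduced across the data-parallel ranks, and disabling DDP's own backward synchronization. The function and parameter names here (wrap_with_data_parallel_group, data_parallel_ranks) are illustrative and not part of the PyTorch Lightning API; the mpu.get_data_parallel_group() helper from the snippet above is replaced with a plain torch.distributed.new_group call.

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel


def wrap_with_data_parallel_group(module, data_parallel_ranks):
    # Assumes torch.distributed has already been initialized (init_process_group);
    # every rank must call new_group with the same rank list.
    group = dist.new_group(ranks=data_parallel_ranks)
    device_ids = [torch.cuda.current_device()] if torch.cuda.is_available() else None
    model = DistributedDataParallel(module, device_ids=device_ids, process_group=group)
    # As in the snippet above, let the plugin drive gradient synchronization
    # instead of DDP's automatic backward hook.
    model.require_backward_grad_sync = False
    return model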
Code example #2
    def configure_ddp(self, model: LightningModule,
                      device_ids: List[int]) -> DistributedDataParallel:
        """
        Pass through all customizations from constructor to :class:`~torch.nn.parallel.DistributedDataParallel`.
        Override to define a custom DDP implementation.

        .. note:: This requires that your DDP implementation subclasses
            :class:`~torch.nn.parallel.DistributedDataParallel` and that
            the original LightningModule gets wrapped by
            :class:`~pytorch_lightning.overrides.data_parallel.LightningDistributedModule`.

        The default implementation is::

            def configure_ddp(self, model, device_ids):
                model = DistributedDataParallel(
                    LightningDistributedModule(model),
                    device_ids=device_ids,
                    **self._ddp_kwargs,
                )
                return model

        Args:
            model: the LightningModule
            device_ids: the list of devices available

        Returns:
            the model wrapped in :class:`~torch.nn.parallel.DistributedDataParallel`

        """
        # if unset, default `find_unused_parameters` to `True`
        self._ddp_kwargs["find_unused_parameters"] = self._ddp_kwargs.get(
            "find_unused_parameters", True)
        model = DistributedDataParallel(
            module=LightningDistributedModule(model),
            device_ids=device_ids,
            **self._ddp_kwargs,
        )
        return model
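
As the docstring notes, this hook can be overridden to define a custom DDP implementation. A hypothetical sketch of such an override, following the signature documented above; the plugin class name MyDDPPlugin and its _ddp_kwargs attribute are assumptions for illustration, while LightningDistributedModule is imported from the path given in the note:

from typing import List

from torch.nn.parallel import DistributedDataParallel
from pytorch_lightning import LightningModule
from pytorch_lightning.overrides.data_parallel import LightningDistributedModule


class MyDDPPlugin:  # stand-in for the plugin/accelerator base class being subclassed
    def __init__(self, **ddp_kwargs):
        self._ddp_kwargs = ddp_kwargs

    def configure_ddp(self, model: LightningModule,
                      device_ids: List[int]) -> DistributedDataParallel:
        # Example customization: keep unused-parameter detection off,
        # unlike the default of True used above.
        self._ddp_kwargs.setdefault("find_unused_parameters", False)
        return DistributedDataParallel(
            LightningDistributedModule(model),
            device_ids=device_ids,
            **self._ddp_kwargs,
        )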
Code example #3
    def __init__(self, module: LightningModule, *args, **kwargs):
        warnings.warn(
            "The usage of `LightningDistributedDataParallel` is deprecated since v1.2 and will be removed in v1.4."
            " From now on we recommend to directly subclass `torch.nn.parallel.DistributedDataParallel`.",
            DeprecationWarning)
        super().__init__(LightningDistributedModule(module), *args, **kwargs)
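
The deprecation message recommends moving to plain torch.nn.parallel.DistributedDataParallel, which is what examples #1 and #2 already do. A rough migration sketch, assuming module is a LightningModule and the process group has already been initialized; device_ids=[0] is an illustrative value, not a requirement:

from torch.nn.parallel import DistributedDataParallel
from pytorch_lightning.overrides.data_parallel import LightningDistributedModule


def migrate_to_plain_ddp(module):
    # Before (deprecated since v1.2, removed in v1.4):
    #     model = LightningDistributedDataParallel(module, device_ids=[0])
    # After: wrap the LightningModule yourself and hand it to plain DDP.
    return DistributedDataParallel(LightningDistributedModule(module), device_ids=[0])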