Example #1
    def configure_ddp(self, model: LightningModule,
                      device_ids: List[int]) -> DistributedDataParallel:
        """ Override LightningModule ddp if using model parallel.

        Args:
            model (LightningModule): the LightningModule currently being optimized
            device_ids (List[int]): the list of GPU ids.

        Returns:
            DistributedDataParallel: DDP wrapped model
        """

        app_state = AppState()

        if app_state.model_parallel_size is not None:
            logging.info("Configuring DDP for model parallelism.")
            logging.info(
                f"data_parallel_group: {app_state.data_parallel_group}")
            # with model parallelism, multiple GPUs form a large "logical GPU"
            # this means that data parallel groups span multiple GPUs
            # and are non-trivial

            model = LightningDistributedDataParallel(
                model,
                device_ids,
                output_device=device_ids[0],
                process_group=app_state.data_parallel_group)
            return model

        else:
            logging.info(
                "Did not detect model parallel using LightningModule.configure_ddp"
            )
            return LightningModule.configure_ddp(self, model, device_ids)
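The `data_parallel_group` passed to DDP above is a process group containing only the ranks that hold the same model-parallel shard, so gradient all-reduce stays within each replica set. A minimal sketch of how such groups can be built with `torch.distributed.new_group`, assuming a Megatron-style layout where adjacent ranks form a model-parallel group; `build_data_parallel_group` is a hypothetical helper, not part of NeMo's `AppState`:

    import torch.distributed as dist

    def build_data_parallel_group(model_parallel_size: int):
        """Hypothetical helper. With world_size=4 and model_parallel_size=2,
        ranks {0, 2} hold the same shard and form one data-parallel group,
        ranks {1, 3} form the other."""
        world_size = dist.get_world_size()
        rank = dist.get_rank()

        data_parallel_group = None
        for i in range(model_parallel_size):
            ranks = list(range(i, world_size, model_parallel_size))
            # every rank must participate in every new_group() call,
            # even for groups it does not belong to
            group = dist.new_group(ranks=ranks)
            if rank in ranks:
                data_parallel_group = group
        return data_parallel_group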
Example #2
 def configure_ddp(
     self, model: "LightningModule", device_ids: List[int]
 ) -> DistributedDataParallel:
     model = LightningDistributedDataParallel(
         model, device_ids=device_ids, find_unused_parameters=True
     )
     return model
Example #3
    def configure_ddp(self, model, device_ids):
        """Override to init DDP in a different way or use your own wrapper.

        :param model:
        :param device_ids:
        :return: DDP wrapped model

        Overwrite to define your own DDP implementation init.
        The only requirement is that:
        1. On a validation batch the call goes to model.validation_step.
        2. On a training batch the call goes to model.training_step.
        3. On a testing batch, the call goes to model.test_step

        .. code-block:: python

            def configure_ddp(self, model, device_ids):
                # Lightning DDP simply routes to test_step, val_step, etc...
                model = LightningDistributedDataParallel(
                    model,
                    device_ids=device_ids,
                    find_unused_parameters=True
                )
                return model


        """
        model = LightningDistributedDataParallel(model,
                                                 device_ids=device_ids,
                                                 find_unused_parameters=True)
        return model
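The three numbered requirements above are the whole contract a custom wrapper has to honour. A purely illustrative sketch of a forward that dispatches accordingly (this is not Lightning's actual implementation, and the `testing` flag is made up):

    import torch

    class RoutingWrapperSketch(torch.nn.Module):
        """Illustrates the routing contract only."""

        def __init__(self, module, testing: bool = False):
            super().__init__()
            self.module = module    # the LightningModule being wrapped
            self.testing = testing  # hypothetical flag a trainer loop would toggle

        def forward(self, *args, **kwargs):
            if self.module.training:
                return self.module.training_step(*args, **kwargs)   # requirement 2
            if self.testing:
                return self.module.test_step(*args, **kwargs)       # requirement 3
            return self.module.validation_step(*args, **kwargs)     # requirement 1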
Example #4
    def configure_ddp(
            self, model: LightningModule,
            device_ids: List[int]) -> LightningDistributedDataParallel:
        """
        Pass through all customizations from constructor to `LightningDistributedDataParallel`.
        Override to define a custom DDP implementation.

        .. note:: The only requirement is that your DDP implementation subclasses LightningDistributedDataParallel.


        The default implementation is::

            def configure_ddp(self, model, device_ids):
                model = LightningDistributedDataParallel(
                    model, device_ids=device_ids, find_unused_parameters=False
                )
                return model

        Args:
            model: the LightningModule
            device_ids: the list of devices available

        Returns:
            the model wrapped in LightningDistributedDataParallel

        """
        # if unset, default `find_unused_parameters` to `False`
        self._ddp_kwargs["find_unused_parameters"] = self._ddp_kwargs.get(
            "find_unused_parameters", False)
        model = LightningDistributedDataParallel(
            model,
            device_ids=device_ids,
            **self._ddp_kwargs,
        )
        return model
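The `self._ddp_kwargs.get(...)` line implements a simple policy: whatever the caller passed through the plugin's kwargs wins, and only a missing `find_unused_parameters` falls back to `False`. The same defaulting logic in isolation (the helper name is made up):

    def resolve_ddp_kwargs(user_kwargs: dict) -> dict:
        """Hypothetical standalone version of the defaulting done above."""
        kwargs = dict(user_kwargs)  # copy so the caller's dict is untouched
        kwargs.setdefault("find_unused_parameters", False)
        return kwargs

    assert resolve_ddp_kwargs({})["find_unused_parameters"] is False
    assert resolve_ddp_kwargs({"find_unused_parameters": True})["find_unused_parameters"] is True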
Example #5
 @contextmanager  # from contextlib
 def block_backward_sync(self, model: LightningDistributedDataParallel):
     """
     Blocks DDP gradient sync behaviour on the backward pass.
     This is useful for skipping the sync when accumulating gradients, reducing communication overhead.
     Returns: context manager with sync behaviour off
     """
     yield model.no_sync()
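What the context manager buys you: during gradient accumulation, `no_sync()` suppresses the all-reduce that `backward()` would otherwise trigger, so only the final micro-batch pays the communication cost. A usage sketch in which `ddp_model`, `batches`, and `optimizer` are placeholders:

    from contextlib import nullcontext

    accumulate_grad_batches = 4
    for i, batch in enumerate(batches):
        is_sync_step = (i + 1) % accumulate_grad_batches == 0
        # skip the gradient all-reduce on accumulation steps, sync on the last one
        with (nullcontext() if is_sync_step else ddp_model.no_sync()):
            loss = ddp_model(batch).mean()
            loss.backward()
        if is_sync_step:
            optimizer.step()
            optimizer.zero_grad()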
Example #6
 def configure_ddp(self, model, device_ids):
     model = LightningDistributedDataParallel(
         model,
         device_ids=device_ids,
         find_unused_parameters=True
     )
     return model
Example #7
 def _distributed_model(
         self,
         model: pl.LightningModule) -> LightningDistributedDataParallel:
     model = LightningDistributedDataParallel(model,
                                              device_ids=[0],
                                              find_unused_parameters=True)
     return model
Example #8
    def configure_ddp(
            self, model: LightningModule,
            device_ids: List[int]) -> LightningDistributedDataParallel:
        """
        Override to define a custom DDP implementation.

        .. note:: The only requirement is that your DDP implementation subclasses LightningDistributedDataParallel.


        The default implementation is::

            def configure_ddp(self, model, device_ids):
                model = LightningDistributedDataParallel(
                    model, device_ids=device_ids, find_unused_parameters=True
                )
                return model

        Args:
            model: the LightningModule
            device_ids: the list of devices available

        Returns:
            the model wrapped in LightningDistributedDataParallel

        """
        model = LightningDistributedDataParallel(model,
                                                 device_ids=device_ids,
                                                 find_unused_parameters=True)
        return model
Example #9
 def _configure_ddp(self, models, device_ids, ddp_args=None):
     assert len(models) == 1
     model = models[0]
     assert isinstance(model, ptl.LightningModule)
     model = LightningDistributedDataParallel(model,
                                              device_ids=device_ids,
                                              find_unused_parameters=True)
     return [model]
Example #10
 def configure_ddp(self, model, device_ids):
     self._ddp_kwargs["find_unused_parameters"] = self._ddp_kwargs.get(
         "find_unused_parameters", False)
     model = LightningDistributedDataParallel(
         model,
         device_ids=device_ids,
         **self._ddp_kwargs)
     return model
Example #11
 def configure_ddp(self, model: LightningModule,
                   device_ids: List[int]) -> DistributedDataParallel:
     logging.info(
         f'overriding ddp to set find_unused_parameters to {self._cfg.find_unused_parameters}'
     )
     model = LightningDistributedDataParallel(
         model,
         device_ids=device_ids,
         find_unused_parameters=self._cfg.find_unused_parameters)
     return model
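A sketch of the config object this snippet assumes; the real `self._cfg` (typically a Hydra/OmegaConf config in code structured like this) only needs to expose that one field, and the class name here is made up:

    from dataclasses import dataclass

    @dataclass
    class DDPOverrideConfig:
        # mirrors the single field the override above reads from self._cfg
        find_unused_parameters: bool = False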
Example #12
 def configure_ddp(self):
     # old, deprecated implementation
     with pytest.deprecated_call(
         match='`LightningDistributedDataParallel` is deprecated since v1.2 and will be removed in v1.4.'
     ):
         self._model = LightningDistributedDataParallel(
             module=self.lightning_module,
             device_ids=self.determine_ddp_device_ids(),
             **self._ddp_kwargs,
         )
         assert isinstance(self.model, torch.nn.parallel.DistributedDataParallel)
         assert isinstance(self.model.module, LightningDistributedModule)
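The assertions spell out the replacement for the deprecated wrapper: the model ends up as a stock `torch.nn.parallel.DistributedDataParallel` whose `.module` is a `LightningDistributedModule`. A sketch of that post-deprecation wrapping (the helper name is made up; only the two asserted classes come from the test):

    from torch.nn.parallel import DistributedDataParallel

    def wrap_ddp_new_style(lightning_module, device_ids, **ddp_kwargs):
        # wrap the LightningModule first, then hand it to plain torch DDP
        return DistributedDataParallel(
            module=LightningDistributedModule(lightning_module),
            device_ids=device_ids,
            **ddp_kwargs,
        )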
Example #13
    def configure_ddp(self, model, device_ids):
        r"""

        Override to init DDP in your own way or with your own wrapper.
        The only requirements are that:

        1. On a validation batch the call goes to model.validation_step.
        2. On a training batch the call goes to model.training_step.
        3. On a testing batch, the call goes to model.test_step.

        Args:
            model (LightningModule): the LightningModule currently being optimized
            device_ids (list): the list of GPU ids

        Return:
            DDP wrapped model

        Example
        -------
        .. code-block:: python

            # default implementation used in Trainer
            def configure_ddp(self, model, device_ids):
                # Lightning DDP simply routes to test_step, val_step, etc...
                model = LightningDistributedDataParallel(
                    model,
                    device_ids=device_ids,
                    find_unused_parameters=True
                )
                return model


        """
        model = LightningDistributedDataParallel(
            model,
            device_ids=device_ids,
            find_unused_parameters=True
        )
        return model
Example #14
 def on_after_manual_backward(self, model: LightningDistributedDataParallel):
     # tear down the reducer autograd hooks installed for the manual backward pass
     model.reducer_reset_hooks()
Example #15
 def on_before_manual_backward(self, model: LightningDistributedDataParallel, output: Any):
     # tell the DDP reducer which outputs participate in the upcoming backward pass
     model.reducer_prepare_for_backwards(output)
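Read together with Example #14, the two hooks bracket a manual backward call. A sketch of the implied ordering, in which `plugin`, `ddp_model`, and `output` are placeholders:

    def manual_backward_sketch(plugin, ddp_model, output):
        # tell the DDP reducer which outputs take part in this backward pass
        plugin.on_before_manual_backward(ddp_model, output)
        output.backward()  # gradients are reduced through the hooks prepared above
        plugin.on_after_manual_backward(ddp_model)  # reset the reducer hooks afterwards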