Example #1
def on_train_end(self, trainer: 'pl.Trainer', pl_module: LightningModule) -> None:
    if self._make_pruning_permanent:
        rank_zero_debug("`ModelPruning.on_train_end`. Pruning is made permanent for this checkpoint")
        self.make_pruning_permanent(pl_module)
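What makes hooks like this safe in multi-GPU runs is that `rank_zero_debug` only emits on the main process, so the message appears once rather than once per device. A minimal sketch of that pattern (illustrative only; the `RANK` environment lookup and the decorator below are assumptions, not Lightning's actual implementation):

import logging
import os
from functools import wraps

log = logging.getLogger(__name__)


def rank_zero_only(fn):
    # Run the wrapped function only on the process whose global rank is 0 (illustrative).
    @wraps(fn)
    def wrapped(*args, **kwargs):
        if int(os.getenv("RANK", "0")) == 0:
            return fn(*args, **kwargs)
    return wrapped


@rank_zero_only
def rank_zero_debug(*args, **kwargs):
    # Debug-level message, emitted once rather than once per process.
    log.debug(*args, **kwargs)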
Example #2
    def _collect_rank_zero_results(self, trainer: "pl.Trainer",
                                   results: Any) -> Optional["_SpawnOutput"]:
        rank_zero_debug("Finalizing the DDP spawn environment.")
        checkpoint_callback = trainer.checkpoint_callback
        best_model_path = checkpoint_callback.best_model_path if checkpoint_callback else None

        # the state_dict must be computed on all processes in case Metrics are present
        state_dict = self.lightning_module.state_dict()

        if self.global_rank != 0:
            return

        # save the last weights
        weights_path = None
        if trainer.state.fn == TrainerFn.FITTING:
            weights_path = os.path.join(trainer.default_root_dir, ".temp.ckpt")
            self.checkpoint_io.save_checkpoint(state_dict, weights_path)

        # adds the `callback_metrics` to the queue
        extra = _FakeQueue()
        if is_overridden("add_to_queue", self.lightning_module):
            # TODO: Remove the if in v1.7
            self.lightning_module.add_to_queue(extra)
        self.add_to_queue(trainer, extra)

        return _SpawnOutput(best_model_path, weights_path, trainer.state,
                            results, extra)
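The `extra` object above only has to look like a queue so that `add_to_queue` can stash the `callback_metrics` without real inter-process communication. A list-backed stand-in along the following lines would satisfy that contract (a sketch assuming the interface mirrors `multiprocessing.SimpleQueue`; it is not claimed to be Lightning's `_FakeQueue`):

from collections import UserList
from typing import Any


class ListQueue(UserList):
    # Minimal queue-like interface over a plain list (illustrative stand-in).
    def get(self) -> Any:
        return self.pop(0)

    def put(self, item: Any) -> None:
        self.append(item)

    def empty(self) -> bool:
        return len(self) == 0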
Example #3
def test_v1_8_0_rank_zero_imports():

    import warnings

    import pytest
    from pytorch_lightning.utilities.distributed import rank_zero_debug, rank_zero_info
    from pytorch_lightning.utilities.warnings import LightningDeprecationWarning, rank_zero_deprecation, rank_zero_warn

    with pytest.deprecated_call(
        match="pytorch_lightning.utilities.distributed.rank_zero_debug has been deprecated in v1.6"
        " and will be removed in v1.8."
    ):
        rank_zero_debug("foo")
    with pytest.deprecated_call(
        match="pytorch_lightning.utilities.distributed.rank_zero_info has been deprecated in v1.6"
        " and will be removed in v1.8."
    ):
        rank_zero_info("foo")
    with pytest.deprecated_call(
        match="pytorch_lightning.utilities.warnings.rank_zero_warn has been deprecated in v1.6"
        " and will be removed in v1.8."
    ):
        rank_zero_warn("foo")
    with pytest.deprecated_call(
        match="pytorch_lightning.utilities.warnings.rank_zero_deprecation has been deprecated in v1.6"
        " and will be removed in v1.8."
    ):
        rank_zero_deprecation("foo")
    with pytest.deprecated_call(
        match="pytorch_lightning.utilities.warnings.LightningDeprecationWarning has been deprecated in v1.6"
        " and will be removed in v1.8."
    ):
        warnings.warn("foo", LightningDeprecationWarning, stacklevel=5)
Example #4
def _resolve_refresh_rate(refresh_rate: int) -> int:
    if os.getenv("COLAB_GPU") and refresh_rate == 1:
        # a small refresh rate on Colab causes crashes, so choose a higher value
        rank_zero_debug("Using a higher refresh rate on Colab. Setting it to `20`")
        refresh_rate = 20
    return refresh_rate
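Calling the helper directly shows both branches (this relies on the `_resolve_refresh_rate` definition above and assumes `COLAB_GPU` is how Colab is detected, as in the snippet):

import os

os.environ["COLAB_GPU"] = "1"          # pretend we are running on Colab
assert _resolve_refresh_rate(1) == 20  # default of 1 is bumped to avoid crashes
assert _resolve_refresh_rate(5) == 5   # explicit values are left untouched

os.environ.pop("COLAB_GPU")
assert _resolve_refresh_rate(1) == 1   # outside Colab the default passes through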
Example #5
def on_save_checkpoint(
    self, trainer: "pl.Trainer", pl_module: LightningModule, checkpoint: Dict[str, Any]
) -> Dict[str, Any]:
    if self._make_pruning_permanent:
        rank_zero_debug("`ModelPruning.on_save_checkpoint`. Pruning is made permanent for this checkpoint")
        # manually prune the weights so training can keep going with the same buffers
        checkpoint["state_dict"] = self._make_pruning_permanent_on_state_dict(pl_module)
    return checkpoint
Example #6
def on_save_checkpoint(self, trainer, pl_module: LightningModule, checkpoint: Dict[str, Any]):
    if self._make_pruning_permanent:
        rank_zero_debug("`ModelPruning.on_save_checkpoint`. Pruning is made permanent for this checkpoint.")
        prev_device = pl_module.device
        # prune a copy so training can continue with the same buffers
        copy = deepcopy(pl_module.to("cpu"))
        self.make_pruning_permanent(copy)
        checkpoint["state_dict"] = copy.state_dict()
        pl_module.to(prev_device)
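Both checkpoint hooks ultimately build on `torch.nn.utils.prune`: pruning reparametrizes a layer with `weight_orig` and `weight_mask`, and `prune.remove` folds the mask back into a plain `weight`, which is what "making pruning permanent" means here. A standalone illustration of that PyTorch mechanism (plain PyTorch, not the callback's own code):

from torch import nn
from torch.nn.utils import prune

layer = nn.Linear(4, 4)
prune.l1_unstructured(layer, name="weight", amount=0.5)  # adds weight_orig + weight_mask
print(hasattr(layer, "weight_mask"))   # True: the reparametrization is active

prune.remove(layer, "weight")          # make the pruning permanent
print(hasattr(layer, "weight_mask"))   # False: only a plain, already-masked weight remains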
Example #7
def on_validation_epoch_end(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None:
    if not trainer.sanity_checking and not self._prune_on_train_epoch_end:
        rank_zero_debug("`ModelPruning.on_validation_epoch_end`. Applying pruning")
        self._run_pruning(pl_module.current_epoch)
Example #8
def on_train_epoch_end(self, trainer: 'pl.Trainer', pl_module: LightningModule) -> None:  # type: ignore
    if self._prune_on_train_epoch_end:
        rank_zero_debug("`ModelPruning.on_train_epoch_end`. Applying pruning")
        self._run_pruning(pl_module.current_epoch)
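All of these hooks fire once the callback is attached to a Trainer. A minimal usage sketch (the constructor arguments follow the documented `ModelPruning(pruning_fn, amount=...)` pattern; the model and datamodule are hypothetical placeholders):

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelPruning

# Prune 50% of the weights (by L1 magnitude) via the epoch-end hooks shown above.
trainer = pl.Trainer(
    max_epochs=3,
    callbacks=[ModelPruning("l1_unstructured", amount=0.5)],
)
# trainer.fit(model, datamodule=dm)  # hypothetical LightningModule / DataModule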