Example no. 1
 def on_train_end(self, trainer: "pl.Trainer",
                  pl_module: LightningModule) -> None:
     if self._make_pruning_permanent:
         rank_zero_debug(
             "`ModelPruning.on_train_end`. Pruning is made permanent for this checkpoint"
         )
         self.make_pruning_permanent(pl_module)
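All of these snippets log through Lightning's rank-zero helpers; `rank_zero_debug` messages only become visible once the `pytorch_lightning` logger is set to the DEBUG level, while `rank_zero_info` is typically shown at the default INFO level. A minimal setup sketch using the standard `logging` module (assuming the default logger configuration is otherwise acceptable for your run):

    import logging

    # Raise the verbosity of the Lightning logger so messages emitted through
    # `rank_zero_debug` (DEBUG level) are actually printed.
    logging.basicConfig(level=logging.INFO)
    logging.getLogger("pytorch_lightning").setLevel(logging.DEBUG)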
Example no. 2
    def _collect_rank_zero_results(self, trainer: "pl.Trainer",
                                   results: Any) -> Optional["_SpawnOutput"]:
        rank_zero_debug("Finalizing the DDP spawn environment.")
        checkpoint_callback = trainer.checkpoint_callback
        best_model_path = checkpoint_callback.best_model_path if checkpoint_callback else None

        # the state_dict must be computed on all processes in case Metrics are present
        state_dict = trainer.lightning_module.state_dict()

        if self._strategy.global_rank != 0:
            return None

        # save the last weights
        weights_path = None
        if trainer.state.fn == TrainerFn.FITTING:
            weights_path = os.path.join(trainer.default_root_dir, ".temp.ckpt")
            self._strategy.checkpoint_io.save_checkpoint(
                state_dict, weights_path)

        # adds the `callback_metrics` to the queue
        extra = _FakeQueue()
        if is_overridden("add_to_queue", trainer.lightning_module):
            # TODO: Remove the if in v1.7
            trainer.lightning_module.add_to_queue(extra)
        self.add_to_queue(trainer, extra)

        return _SpawnOutput(best_model_path, weights_path, trainer.state,
                            results, extra)
Example no. 3
 def _resolve_refresh_rate(refresh_rate: int) -> int:
     if os.getenv("COLAB_GPU") and refresh_rate == 1:
         # smaller refresh rate on colab causes crashes, choose a higher value
         rank_zero_debug(
             "Using a higher refresh rate on Colab. Setting it to `20`")
         refresh_rate = 20
     return refresh_rate
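For context, `_resolve_refresh_rate` is applied to the value the user passes to the progress-bar callback. A minimal usage sketch; the `TQDMProgressBar(refresh_rate=...)` argument is taken from the public callback API and should be checked against the installed Lightning version:

    import pytorch_lightning as pl
    from pytorch_lightning.callbacks import TQDMProgressBar

    # On Google Colab (where the COLAB_GPU environment variable is set), a user-supplied
    # refresh_rate of 1 would be bumped to 20 by the resolution logic above.
    trainer = pl.Trainer(callbacks=[TQDMProgressBar(refresh_rate=1)])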
Example no. 4
    def _collect_rank_zero_results(self, trainer: "pl.Trainer",
                                   results: Any) -> Optional["_WorkerOutput"]:
        rank_zero_debug("Collecting results from rank 0 process.")
        checkpoint_callback = trainer.checkpoint_callback
        best_model_path = (
            checkpoint_callback.best_model_path if checkpoint_callback
            and hasattr(checkpoint_callback, "best_model_path") else None)

        # the state_dict must be computed on all processes in case Metrics are present
        state_dict = trainer.lightning_module.state_dict()

        # save the last weights
        weights_path = None
        if trainer.state.fn == TrainerFn.FITTING:
            weights_path = os.path.join(trainer.default_root_dir, ".temp.ckpt")
            self._strategy.checkpoint_io.save_checkpoint(
                state_dict, weights_path)

        # We use `local_rank` here as separate filesystems are used for each VM for TPU Pod Training
        if self._strategy.local_rank != 0:
            return None

        # adds the `callback_metrics` to the queue
        extra = _FakeQueue()
        self.add_to_queue(trainer, extra)

        return _WorkerOutput(best_model_path, weights_path, trainer.state,
                             results, extra)
Example no. 5
 def auto_device_count() -> int:
     """Returns the number of HPU devices when the devices is set to auto."""
     try:
         return torch_hpu.device_count()
     except (AttributeError, NameError):
         rank_zero_debug(
             "HPU `auto_device_count` failed, returning default count of 8."
         )
         return 8
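The `rank_zero_debug`/`rank_zero_info` helpers used throughout these examples boil down to logging calls guarded so that only the process with global rank 0 emits them. A simplified, illustrative stand-in, not Lightning's actual implementation; the environment-variable lookup is an assumption:

    import logging
    import os
    from functools import wraps
    from typing import Any, Callable, Optional

    log = logging.getLogger(__name__)

    def rank_zero_only_sketch(fn: Callable) -> Callable:
        """Run `fn` only on global rank 0; other ranks get a no-op that returns None."""
        @wraps(fn)
        def wrapped(*args: Any, **kwargs: Any) -> Optional[Any]:
            # Assumption: the launcher exports RANK (or LOCAL_RANK); the real helpers
            # consult the strategy/cluster environment rather than raw environment variables.
            rank = int(os.environ.get("RANK", os.environ.get("LOCAL_RANK", "0")))
            return fn(*args, **kwargs) if rank == 0 else None
        return wrapped

    @rank_zero_only_sketch
    def rank_zero_debug_sketch(message: str) -> None:
        log.debug(message)

    @rank_zero_only_sketch
    def rank_zero_info_sketch(message: str) -> None:
        log.info(message)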
Example no. 6
 def on_save_checkpoint(self, trainer: "pl.Trainer",
                        pl_module: "pl.LightningModule",
                        checkpoint: Dict[str, Any]) -> Optional[dict]:
     if self._make_pruning_permanent:
         rank_zero_debug(
             "`ModelPruning.on_save_checkpoint`. Pruning is made permanent for this checkpoint"
         )
         # manually prune the weights so training can keep going with the same buffers
         checkpoint["state_dict"] = self._make_pruning_permanent_on_state_dict(pl_module)
Example no. 7
    def done(self) -> bool:
        """Evaluates when to leave the loop."""
        if self.trainer.num_training_batches == 0:
            rank_zero_info("`Trainer.fit` stopped: No training batches.")
            return True

        # TODO(@awaelchli): Move track steps inside training loop and move part of these conditions inside training loop
        stop_steps = _is_max_limit_reached(self.epoch_loop.global_step, self.max_steps)
        if stop_steps:
            rank_zero_info(f"`Trainer.fit` stopped: `max_steps={self.max_steps!r}` reached.")
            return True

        # `processed` is increased before `on_train_epoch_end`, the hook where checkpoints are typically saved.
        # we use it here because the checkpoint data won't have `completed` increased yet
        stop_epochs = _is_max_limit_reached(self.epoch_progress.current.processed, self.max_epochs)
        if stop_epochs:
            # in case they are not equal, override so `trainer.current_epoch` has the expected value
            self.epoch_progress.current.completed = self.epoch_progress.current.processed
            rank_zero_info(f"`Trainer.fit` stopped: `max_epochs={self.max_epochs!r}` reached.")
            return True

        if self.trainer.should_stop:
            # early stopping
            met_min_epochs = self.epoch_progress.current.processed >= self.min_epochs if self.min_epochs else True
            met_min_steps = self.epoch_loop.global_step >= self.min_steps if self.min_steps else True
            if met_min_epochs and met_min_steps:
                self.trainer.should_stop = True
                rank_zero_debug("`Trainer.fit` stopped: `trainer.should_stop` was set.")
                return True
            else:
                rank_zero_info(
                    f"Trainer was signaled to stop but the required `min_epochs={self.min_epochs!r}` or"
                    f" `min_steps={self.min_steps!r}` has not been met. Training will continue..."
                )
                # reset the flag; the stop request is deferred until the minimum requirements are met
                self.trainer.should_stop = False
        return False
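`done` delegates the step and epoch checks to `_is_max_limit_reached`; judging from the call sites above, a sentinel of `-1` (or `None`) means "no limit". A minimal stand-in with those semantics, sketched here for illustration (the real helper lives in Lightning's loop utilities and may differ in detail):

    from typing import Optional

    def is_max_limit_reached(current: int, maximum: Optional[int] = -1) -> bool:
        # A finite, non-negative limit counts as reached once `current` meets or exceeds it;
        # -1/None is treated as "unlimited".
        return maximum is not None and maximum != -1 and current >= maximum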
Example no. 8
 def get_device_stats(self, device: Union[str,
                                          torch.device]) -> Dict[str, Any]:
     """HPU device stats aren't supported yet."""
     rank_zero_debug("HPU device stats aren't supported yet.")
     return {}
Example no. 9
 def on_validation_epoch_end(self, trainer: "pl.Trainer",
                             pl_module: "pl.LightningModule") -> None:
     if not trainer.sanity_checking and not self._prune_on_train_epoch_end:
         rank_zero_debug(
             "`ModelPruning.on_validation_epoch_end`. Applying pruning")
         self._run_pruning(pl_module.current_epoch)
Example no. 10
 def on_train_epoch_end(self, trainer: "pl.Trainer",
                        pl_module: LightningModule) -> None:
     if self._prune_on_train_epoch_end:
         rank_zero_debug(
             "`ModelPruning.on_train_epoch_end`. Applying pruning")
         self._run_pruning(pl_module.current_epoch)
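The `ModelPruning` hooks shown in examples 1, 6, 9 and 10 fire once the callback is attached to a `Trainer`. A minimal wiring sketch; the argument names follow the `ModelPruning` signature of 1.6-era Lightning releases and should be checked against the installed version:

    import pytorch_lightning as pl
    from pytorch_lightning.callbacks import ModelPruning

    pruning = ModelPruning(
        pruning_fn="l1_unstructured",   # torch.nn.utils.prune function applied to the weights
        amount=0.5,                     # fraction of weights pruned on each pruning run
        make_pruning_permanent=True,    # enables the on_train_end / on_save_checkpoint paths above
        prune_on_train_epoch_end=True,  # prune in on_train_epoch_end rather than on_validation_epoch_end
    )

    trainer = pl.Trainer(max_epochs=10, callbacks=[pruning])
    # trainer.fit(model, datamodule=datamodule)  # model / datamodule supplied by the user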