Example 1
    def advance(self, kwargs: OrderedDict) -> None:  # type: ignore[override]
        """Runs the train step together with optimization (if necessary) on the current batch split.

        Args:
            kwargs: the kwargs passed down to the hooks.
        """
        # replace the batch with the split batch
        self.split_idx, kwargs["batch"] = self._remaining_splits.pop(0)

        self.trainer._logger_connector.on_train_split_start(self.split_idx)

        outputs: Optional[Union[_OPTIMIZER_LOOP_OUTPUTS_TYPE, _MANUAL_LOOP_OUTPUTS_TYPE]] = None  # for mypy
        # choose which loop will run the optimization
        if self.trainer.lightning_module.automatic_optimization:
            optimizers = _get_active_optimizers(
                self.trainer.optimizers, self.trainer.optimizer_frequencies, kwargs.get("batch_idx", 0)
            )
            outputs = self.optimizer_loop.run(optimizers, kwargs)
        else:
            outputs = self.manual_loop.run(kwargs)
        if outputs:
            # automatic: can be empty if all optimizers skip their batches
            # manual: #9052 added support for raising `StopIteration` in the `training_step`. If that happens,
            # then `advance` doesn't finish and an empty dict is returned
            self._outputs.append(outputs)
Example 2
    def advance(self, batch: Any, batch_idx: int) -> None:  # type: ignore[override]
        """Runs the train step together with optimization (if necessary) on the current batch split.

        Args:
            batch: the current batch to run the training on (this is not the split!)
            batch_idx: the index of the current batch
        """
        void(batch)  # the full batch is unused here; the current split is popped from self._remaining_splits
        self.split_idx, split_batch = self._remaining_splits.pop(0)

        self.trainer._logger_connector.on_train_split_start(self.split_idx)

        outputs: Optional[Union[_OPTIMIZER_LOOP_OUTPUTS_TYPE, _MANUAL_LOOP_OUTPUTS_TYPE]] = None  # for mypy
        # choose which loop will run the optimization
        if self.trainer.lightning_module.automatic_optimization:
            optimizers = _get_active_optimizers(self.trainer.optimizers, self.trainer.optimizer_frequencies, batch_idx)
            outputs = self.optimizer_loop.run(split_batch, optimizers, batch_idx)
        else:
            outputs = self.manual_loop.run(split_batch, batch_idx)
        if outputs:
            # automatic: can be empty if all optimizers skip their batches
            # manual: #9052 added support for raising `StopIteration` in the `training_step`. If that happens,
            # then `advance` doesn't finish and an empty dict is returned
            self._outputs.append(outputs)
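Both versions of advance branch on automatic_optimization: when it is True, the optimizer loop drives the optimizers returned by _get_active_optimizers; otherwise the manual loop runs the user's training_step directly. A minimal sketch of a module that takes the manual path, using the standard Lightning manual-optimization API (the linear model and loss here are placeholders, not part of the examples above):

import torch
from torch import nn
import pytorch_lightning as pl


class ManualOptimModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        # setting this flag routes advance() into self.manual_loop.run(...)
        self.automatic_optimization = False
        self.layer = nn.Linear(32, 2)  # placeholder model

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()  # the optimizer returned by configure_optimizers()
        x, y = batch
        loss = nn.functional.cross_entropy(self.layer(x), y)
        opt.zero_grad()
        self.manual_backward(loss)  # use instead of loss.backward() so precision/strategy hooks still apply
        opt.step()
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)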
Example 3
    def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """Called when the epoch begins."""
        # import is here to avoid circular imports
        from pytorch_lightning.loops.utilities import _get_active_optimizers

        for opt_idx, optimizer in _get_active_optimizers(trainer.optimizers, trainer.optimizer_frequencies):
            num_param_groups = len(optimizer.param_groups)
            self.finetune_function(pl_module, trainer.current_epoch, optimizer, opt_idx)
            current_param_groups = optimizer.param_groups
            self._store(pl_module, opt_idx, num_param_groups, current_param_groups)
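In this hook, finetune_function is the user-defined part of the callback: it is called once per active optimizer with the optimizer index, and any parameter groups it adds are then captured by the self._store(...) call. A minimal sketch of a BaseFinetuning subclass, assuming the LightningModule exposes a backbone attribute (the attribute name and the epoch threshold are illustrative, not taken from the examples above):

from pytorch_lightning.callbacks import BaseFinetuning


class BackboneUnfreeze(BaseFinetuning):
    """Keeps pl_module.backbone frozen until a chosen epoch, then adds it to the optimizer."""

    def __init__(self, unfreeze_at_epoch: int = 10):
        super().__init__()
        self._unfreeze_at_epoch = unfreeze_at_epoch

    def freeze_before_training(self, pl_module):
        # start training with the backbone frozen
        self.freeze(pl_module.backbone)

    def finetune_function(self, pl_module, current_epoch, optimizer, optimizer_idx):
        # invoked by on_train_epoch_start() above, once per active optimizer
        if current_epoch == self._unfreeze_at_epoch:
            self.unfreeze_and_add_param_group(modules=pl_module.backbone, optimizer=optimizer, train_bn=True)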
Example 4
    def update_lr_schedulers(self, interval: str, update_plateau_schedulers: bool) -> None:
        """updates the lr schedulers based on the given interval."""
        if interval == "step" and self._should_accumulate():
            return
        active_optimizers = _get_active_optimizers(
            self.trainer.optimizers, self.trainer.optimizer_frequencies, self.total_batch_idx
        )
        self._update_learning_rates(
            interval=interval,
            update_plateau_schedulers=update_plateau_schedulers,
            opt_indices=[opt_idx for opt_idx, _ in active_optimizers],
        )
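All of the call sites above rely on the same contract: _get_active_optimizers(optimizers, frequencies, batch_idx) yields (opt_idx, optimizer) pairs, so callers can iterate over them (Example 3) or collect only the indices (Example 4). The sketch below approximates that behavior as inferred from these usages; it is not the library source, and the name get_active_optimizers_sketch is ours:

from typing import List, Optional, Tuple

from torch.optim import Optimizer


def get_active_optimizers_sketch(
    optimizers: List[Optimizer], frequencies: List[int], batch_idx: Optional[int] = None
) -> List[Tuple[int, Optimizer]]:
    """Approximation of the helper's contract as used in the examples above."""
    if not frequencies:
        # no optimizer frequencies configured: every optimizer is active on every batch
        return list(enumerate(optimizers))
    # with frequencies, a single optimizer is active per batch, cycling through the schedule,
    # e.g. frequencies=[2, 1] -> opt 0, opt 0, opt 1, opt 0, opt 0, opt 1, ...
    position = (batch_idx or 0) % sum(frequencies)
    cumulative = 0
    for opt_idx, freq in enumerate(frequencies):
        cumulative += freq
        if position < cumulative:
            return [(opt_idx, optimizers[opt_idx])]
    return []  # unreachable when frequencies is non-empty

Under this reading, frequencies=[2, 1] activates optimizer 0 on batch indices 0 and 1 and optimizer 1 on batch index 2, which is why update_lr_schedulers passes self.total_batch_idx and then forwards only the matching opt_indices to _update_learning_rates.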