def advance(self, kwargs: OrderedDict) -> None:  # type: ignore[override]
    """Run the train step, plus optimization if necessary, on the next batch split.

    The next ``(split_idx, split_batch)`` pair is popped from the front of
    ``self._remaining_splits``. The split batch overwrites ``kwargs["batch"]``
    before either the optimizer loop or the manual loop is run.

    Args:
        kwargs: the kwargs passed down to the hooks. Mutated in place: its
            ``"batch"`` entry is replaced with the current split batch.
    """
    # take the next pending split and substitute it for the batch in the hook kwargs
    next_split_idx, split_batch = self._remaining_splits.pop(0)
    self.split_idx = next_split_idx
    kwargs["batch"] = split_batch

    trainer = self.trainer
    trainer._logger_connector.on_train_split_start(self.split_idx)

    # declared up front so mypy knows the type regardless of the branch taken
    result: Optional[Union[_OPTIMIZER_LOOP_OUTPUTS_TYPE, _MANUAL_LOOP_OUTPUTS_TYPE]]

    # pick the loop that performs the optimization
    if trainer.lightning_module.automatic_optimization:
        # a missing "batch_idx" entry is treated as batch index 0
        active_optimizers = _get_active_optimizers(
            trainer.optimizers,
            trainer.optimizer_frequencies,
            kwargs.get("batch_idx", 0),
        )
        result = self.optimizer_loop.run(active_optimizers, kwargs)
    else:
        result = self.manual_loop.run(kwargs)

    if not result:
        # automatic: can be empty if all optimizers skip their batches
        # manual: #9052 added support for raising `StopIteration` in the `training_step`.
        # If that happens, then `advance` doesn't finish and an empty dict is returned
        return
    self._outputs.append(result)
def advance(self, batch: Any, batch_idx: int) -> None:  # type: ignore[override]
    """Run the train step, plus optimization if necessary, on the next batch split.

    The next ``(split_idx, split_batch)`` pair is popped from the front of
    ``self._remaining_splits`` and handed to either the optimizer loop or the
    manual loop.

    Args:
        batch: the current batch to run the training on (this is not the split!).
            It is only passed to ``void`` here; the split popped from
            ``self._remaining_splits`` is what the loops receive.
        batch_idx: the index of the current batch
    """
    void(batch)

    next_split_idx, split_batch = self._remaining_splits.pop(0)
    self.split_idx = next_split_idx

    trainer = self.trainer
    trainer._logger_connector.on_train_split_start(self.split_idx)

    # declared up front so mypy knows the type regardless of the branch taken
    result: Optional[Union[_OPTIMIZER_LOOP_OUTPUTS_TYPE, _MANUAL_LOOP_OUTPUTS_TYPE]]

    # pick the loop that performs the optimization
    if trainer.lightning_module.automatic_optimization:
        active_optimizers = _get_active_optimizers(
            trainer.optimizers,
            trainer.optimizer_frequencies,
            batch_idx,
        )
        result = self.optimizer_loop.run(split_batch, active_optimizers, batch_idx)
    else:
        result = self.manual_loop.run(split_batch, batch_idx)

    if not result:
        # automatic: can be empty if all optimizers skip their batches
        # manual: #9052 added support for raising `StopIteration` in the `training_step`.
        # If that happens, then `advance` doesn't finish and an empty dict is returned
        return
    self._outputs.append(result)
def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
    """Called when the epoch begins.

    For every ``(opt_idx, optimizer)`` pair returned by ``_get_active_optimizers``,
    calls ``self.finetune_function`` and then forwards the number of param groups
    the optimizer had *before* that call, together with its param groups *after*
    the call, to ``self._store``.

    Args:
        trainer: the trainer whose optimizers, optimizer frequencies and current
            epoch are read.
        pl_module: the module passed through to ``finetune_function`` and ``_store``.
    """
    # deferred import: a module-level import here would be circular
    from pytorch_lightning.loops.utilities import _get_active_optimizers

    active_optimizers = _get_active_optimizers(trainer.optimizers, trainer.optimizer_frequencies)
    for opt_idx, optimizer in active_optimizers:
        # snapshot the group count first: `finetune_function` receives the optimizer
        # and may change its param groups
        groups_before = len(optimizer.param_groups)
        self.finetune_function(pl_module, trainer.current_epoch, optimizer, opt_idx)
        self._store(pl_module, opt_idx, groups_before, optimizer.param_groups)
def update_lr_schedulers(self, interval: str, update_plateau_schedulers: bool) -> None:
    """Update the lr schedulers based on the given interval.

    Nothing happens for ``interval == "step"`` while ``self._should_accumulate()``
    is true. Otherwise the indices of the optimizers returned by
    ``_get_active_optimizers`` for ``self.total_batch_idx`` are handed to
    ``self._update_learning_rates``.

    Args:
        interval: the scheduler interval being processed; ``"step"`` is the only
            value given special treatment here.
        update_plateau_schedulers: forwarded unchanged to ``_update_learning_rates``.
    """
    # `_should_accumulate` is only consulted for step-interval updates
    skip_step_update = interval == "step" and self._should_accumulate()
    if skip_step_update:
        return

    trainer = self.trainer
    opt_indices = [
        opt_idx
        for opt_idx, _ in _get_active_optimizers(
            trainer.optimizers, trainer.optimizer_frequencies, self.total_batch_idx
        )
    ]
    self._update_learning_rates(
        interval=interval,
        update_plateau_schedulers=update_plateau_schedulers,
        opt_indices=opt_indices,
    )