Example #1
def test_progress_increment_sequence():
    """Test sequence for incrementing."""
    batch = Progress()

    batch.increment_ready()
    assert batch.total == ProcessedTracker(ready=1)
    assert batch.current == ProcessedTracker(ready=1)

    batch.increment_started()
    assert batch.total == ProcessedTracker(ready=1, started=1)
    assert batch.current == ProcessedTracker(ready=1, started=1)

    batch.increment_processed()
    assert batch.total == ProcessedTracker(ready=1, started=1, processed=1)
    assert batch.current == ProcessedTracker(ready=1, started=1, processed=1)

    batch.increment_completed()
    assert batch.total == ProcessedTracker(ready=1,
                                           started=1,
                                           processed=1,
                                           completed=1)
    assert batch.current == ProcessedTracker(ready=1,
                                             started=1,
                                             processed=1,
                                             completed=1)
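
The assertions above pin down a small pair of structures. The following is a minimal sketch of what ProcessedTracker and Progress could look like, inferred purely from the test's behavior; the real PyTorch Lightning definitions live in pytorch_lightning/trainer/progress.py and carry additional state-dict handling:

from dataclasses import dataclass, field


@dataclass
class ProcessedTracker:
    """Counts how many items reached each stage (sketch inferred from the test)."""
    ready: int = 0
    started: int = 0
    processed: int = 0
    completed: int = 0

    def reset(self) -> None:
        self.ready = self.started = self.processed = self.completed = 0


@dataclass
class Progress:
    """Pairs a total (lifetime) tracker with a current (this run) tracker."""
    total: ProcessedTracker = field(default_factory=ProcessedTracker)
    current: ProcessedTracker = field(default_factory=ProcessedTracker)

    def increment_ready(self) -> None:
        self.total.ready += 1
        self.current.ready += 1

    def increment_started(self) -> None:
        self.total.started += 1
        self.current.started += 1

    def increment_processed(self) -> None:
        self.total.processed += 1
        self.current.processed += 1

    def increment_completed(self) -> None:
        self.total.completed += 1
        self.current.completed += 1

    def reset_on_run(self) -> None:
        self.current.reset()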

Example #2
def test_epoch_loop_progress_increment_sequence():
    """Test the increment sequence for batch progress within an epoch."""
    batch = Progress()

    batch.increment_ready()
    assert batch.total == Tracker(ready=1)
    assert batch.current == Tracker(ready=1)

    batch.increment_started()
    assert batch.total == Tracker(ready=1, started=1)
    assert batch.current == Tracker(ready=1, started=1)

    batch.increment_processed()
    assert batch.total == Tracker(ready=1, started=1, processed=1)
    assert batch.current == Tracker(ready=1, started=1, processed=1)

    batch.increment_completed()
    assert batch.total == Tracker(ready=1, started=1, processed=1, completed=1)
    assert batch.current == Tracker(ready=1, started=1, processed=1, completed=1)
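
The reason for keeping two counters is that current can be cleared between runs while total keeps accumulating for the lifetime of the loop. A short hypothetical continuation against the ProcessedTracker/Progress sketch under Example #1 (reset_on_run is assumed to clear only the current tracker, matching its use in the loop classes below):

batch = Progress()
batch.increment_ready()
batch.reset_on_run()  # e.g. called when the loop starts a new run

assert batch.current == ProcessedTracker()       # per-run view cleared
assert batch.total == ProcessedTracker(ready=1)  # lifetime view preserved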
Example #3
class PredictionEpochLoop(Loop):
    """Loop performing prediction on arbitrary sequentially used dataloaders."""
    def __init__(self) -> None:
        super().__init__()
        self.return_predictions = False
        self.predictions: List[Any] = []
        self.current_batch_indices: List[int] = []
        self.batch_progress = Progress()

        self._dl_max_batches = 0
        self._num_dataloaders = 0
        self._warning_cache = WarningCache()
        self._seen_batch_indices: List[List[int]] = []

    @property
    def done(self) -> bool:
        """Ends prediction when the iteration count exceeds the total number of available batches."""
        return self.batch_progress.current.completed >= self._dl_max_batches

    @property
    def should_store_predictions(self) -> bool:
        """Whether the predictions should be stored for later usage (e.g. aggregation or returning)"""
        any_pred = any(cb.interval.on_epoch
                       for cb in self.trainer.prediction_writer_callbacks)
        return self.return_predictions or any_pred

    def connect(self, **kwargs: "Loop") -> None:
        raise NotImplementedError(
            f"{self.__class__.__name__} does not connect any child loops.")

    def reset(self) -> None:
        """Resets the loops internal state."""
        self._seen_batch_indices = []
        self.predictions = []
        self.batch_progress.reset_on_run()

    def on_run_start(  # type: ignore[override]
        self,
        dataloader_iter: Iterator,
        dataloader_idx: int,
        dl_max_batches: int,
        num_dataloaders: int,
    ) -> None:
        """Prepares the loops internal state.

        Args:
            dataloader_iter: the iterator over the current dataloader
            dataloader_idx: the index of the current dataloader
            dl_max_batches: the maximum number of batches the current loader can produce
            num_dataloaders: the total number of dataloaders
        """
        void(dataloader_iter, dataloader_idx)
        self._dl_max_batches = dl_max_batches
        self._num_dataloaders = num_dataloaders
        # this call requires that `self.return_predictions` is set
        self._seen_batch_indices = self._get_batch_indices(
            dataloader_idx) if self.should_store_predictions else []

    def advance(  # type: ignore[override]
        self,
        dataloader_iter: Iterator,
        dataloader_idx: int,
        dl_max_batches: int,
        num_dataloaders: int,
    ) -> None:
        """Runs one prediction step.

        Args:
            dataloader_iter: the iterator over the current dataloader
            dataloader_idx: the index of the current dataloader
            dl_max_batches: the maximum number of batches the current loader can produce
            num_dataloaders: the total number of dataloaders
        """
        action_name = f"[{self.__class__.__name__}].predict_dataloader_idx_{dataloader_idx}_next"
        with self.trainer.profiler.profile(action_name):
            batch_idx, batch = next(dataloader_iter)
        self._seen_batch_indices = self._get_batch_indices(
            dataloader_idx) if self.should_store_predictions else []
        # we need to truncate the list of batch indices due to prefetching in the dataloader and Lightning
        self._seen_batch_indices = self._seen_batch_indices[:(
            self.batch_progress.current.completed + 1)]

        if batch is None:
            raise StopIteration

        batch = self.trainer._call_strategy_hook("batch_to_device",
                                                 batch,
                                                 dataloader_idx=dataloader_idx)

        self.batch_progress.increment_ready()

        self._predict_step(batch, batch_idx, dataloader_idx)

    def on_run_end(self) -> Tuple[List[Any], List[List[int]]]:
        """Returns the predictions and the corresponding batch indices."""
        predictions, all_batch_indices = self.predictions, self._seen_batch_indices
        self.predictions, self._seen_batch_indices = [], []  # free memory
        return predictions, all_batch_indices

    def _predict_step(self, batch: Any, batch_idx: int,
                      dataloader_idx: int) -> None:
        """Runs the actual predict step together with all the necessary bookkeeping and the hooks tied to the
        predict step.

        Args:
            batch: the current batch to run the prediction on
            batch_idx: the index of the current batch
            dataloader_idx: the index of the dataloader producing the current batch
        """
        # configure step_kwargs
        step_kwargs = self._build_kwargs(batch, batch_idx, dataloader_idx)

        # extract batch_indices and store them
        batch_indices = self._get_batch_indices(dataloader_idx)
        self.current_batch_indices = batch_indices[
            batch_idx] if batch_indices else []

        self.trainer._call_callback_hooks("on_predict_batch_start", batch,
                                          batch_idx, dataloader_idx)
        self.trainer._call_lightning_module_hook("on_predict_batch_start",
                                                 batch, batch_idx,
                                                 dataloader_idx)

        self.batch_progress.increment_started()

        predictions = self.trainer._call_strategy_hook("predict_step",
                                                       *step_kwargs.values())

        self.batch_progress.increment_processed()

        if predictions is None:
            self._warning_cache.warn(
                "predict returned None. If this was on purpose, ignore this warning..."
            )

        self.trainer._call_callback_hooks("on_predict_batch_end", predictions,
                                          batch, batch_idx, dataloader_idx)
        self.trainer._call_lightning_module_hook("on_predict_batch_end",
                                                 predictions, batch, batch_idx,
                                                 dataloader_idx)

        self.batch_progress.increment_completed()

        if self.should_store_predictions:
            self.predictions.append(
                move_data_to_device(predictions, torch.device("cpu")))

    def _build_kwargs(self, batch: Any, batch_idx: int,
                      dataloader_idx: int) -> Dict[str, Any]:
        """Assembles the keyword arguments for the ``predict_step``

        Args:
            batch: the current batch to run the prediction on
            batch_idx: the index of the current batch
            dataloader_idx: the index of the dataloader producing the current batch

        Returns:
            the dictionary containing all the keyword arguments for the predict step
        """
        step_kwargs = OrderedDict([("batch", batch), ("batch_idx", batch_idx)])
        if self._num_dataloaders > 1:
            step_kwargs["dataloader_idx"] = dataloader_idx
        return step_kwargs

    def _get_batch_indices(self, dataloader_idx: int) -> List[List[int]]:
        """Returns a reference to the seen batch indices if the dataloader has a batch sampler wrapped by our
        :class:`~pytorch_lightning.overrides.distributed.IndexBatchSamplerWrapper`."""
        # the batch_sampler may not be defined in the case of CombinedDataLoaders
        batch_sampler = getattr(
            self.trainer.
            predict_dataloaders[dataloader_idx],  # type: ignore[has-type]
            "batch_sampler",
            None,
        )
        if isinstance(batch_sampler, IndexBatchSamplerWrapper):
            return batch_sampler.seen_batch_indices

        warning_cache.warn(
            "Lightning couldn't infer the indices fetched for your dataloader."
        )
        return []
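
All of the loop classes in this example plug into the same base-class contract: run() resets the loop, then repeatedly calls advance() until the done property flips (or advance raises StopIteration), and finally returns on_run_end(). A simplified sketch of that driver, assuming a stripped-down base class (the real pytorch_lightning.loops.base.Loop additionally handles restarting and state checkpointing):

from typing import Any


class Loop:
    """Simplified sketch of the driver the loop classes here plug into."""

    @property
    def done(self) -> bool:
        raise NotImplementedError

    @property
    def skip(self) -> bool:
        # subclasses may short-circuit the whole run
        return False

    def run(self, *args: Any, **kwargs: Any) -> Any:
        if self.skip:
            return None
        self.reset()
        self.on_run_start(*args, **kwargs)
        while not self.done:
            try:
                self.on_advance_start(*args, **kwargs)
                self.advance(*args, **kwargs)
                self.on_advance_end()
            except StopIteration:
                break
        return self.on_run_end()

    # hooks overridden by subclasses; the base versions do nothing
    def reset(self) -> None: ...
    def on_run_start(self, *args: Any, **kwargs: Any) -> None: ...
    def on_advance_start(self, *args: Any, **kwargs: Any) -> None: ...
    def advance(self, *args: Any, **kwargs: Any) -> None:
        raise NotImplementedError
    def on_advance_end(self) -> None: ...
    def on_run_end(self) -> Any: ...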
class TrainingEpochLoop(loops.Loop):
    """
    Runs over all batches in a dataloader (one epoch).

    Args:
        min_steps: The minimum number of steps (batches) to process
        max_steps: The maximum number of steps (batches) to process
    """
    def __init__(self, min_steps: int, max_steps: int):
        super().__init__()
        self.min_steps: int = min_steps
        self.max_steps: int = max_steps

        self.global_step: int = 0
        # manually tracking the last batch is necessary for iterable dataset support
        self.is_last_batch: Optional[bool] = None
        self.batch_progress = Progress()
        self.scheduler_progress = SchedulerProgress()

        self.batch_loop: Optional[TrainingBatchLoop] = None
        self.val_loop: Optional["loops.EvaluationLoop"] = None

        self._results = ResultCollection(training=True)
        self._epoch_output: Optional[List[List[STEP_OUTPUT]]] = None

    @property
    def total_batch_idx(self) -> int:
        """Returns the current batch index (across epochs)"""
        # use `ready` instead of `completed` in case this is accessed after `completed` has been increased
        # but before the next `ready` increase
        return self.batch_progress.total.ready - 1

    @property
    def batch_idx(self) -> int:
        """Returns the current batch index (within this epoch)"""
        # use `ready` instead of `completed` in case this is accessed after `completed` has been increased
        # but before the next `ready` increase
        return self.batch_progress.current.ready - 1

    @property
    def done(self) -> bool:
        """Returns whether the training should be stopped.
        The criteria are that the number of steps reached the max steps,
        the last batch is reached or the trainer signals to stop (e.g. by early stopping).
        """
        max_steps_reached = self.max_steps is not None and self.global_step >= self.max_steps
        return max_steps_reached or self.trainer.should_stop or self._num_training_batches_reached(
            self.is_last_batch)

    def connect(
        self,
        batch_loop: Optional[TrainingBatchLoop] = None,
        val_loop: Optional["loops.EvaluationLoop"] = None,
    ) -> None:
        """Optionally connect a custom batch or validation loop to this training epoch loop."""
        if batch_loop is not None:
            self.batch_loop = batch_loop
        if val_loop is not None:
            self.val_loop = val_loop

    def reset(self) -> None:
        """Resets the internal state of the loop for a new run"""
        self.is_last_batch = False

        # track epoch output
        self._epoch_output = [[] for _ in range(
            self.batch_loop.num_active_optimizers(self.total_batch_idx))]

        if not self.restarting:
            self.batch_progress.current.reset()
            self.scheduler_progress.current.reset()
            self.batch_loop.optim_progress.reset_on_epoch()

    def on_run_start(self, dataloader_iter: Iterator, **kwargs: Any) -> None:
        # hook
        self.trainer.logger_connector.on_epoch_start()
        self.trainer.call_hook("on_epoch_start")
        self.trainer.call_hook("on_train_epoch_start")
        self.trainer.fit_loop.epoch_progress.increment_started()

        self.dataloader_iter = _prepare_dataloader_iter(
            dataloader_iter, self.batch_idx + 1)

    def advance(self, *args: Any, **kwargs: Any) -> None:
        """Runs a single training batch.

        Args:
            dataloader_iter: the iterator over the dataloader producing the new batch

        Raises:
            StopIteration: When the epoch is canceled by the user returning -1
        """
        batch_idx, (batch, is_last) = next(self.dataloader_iter)

        if not self.trainer.data_connector.train_data_fetcher.store_on_device:
            with self.trainer.profiler.profile("training_batch_to_device"):
                batch = self.trainer.accelerator.batch_to_device(batch)

        self.batch_progress.increment_ready()

        with self.trainer.profiler.profile("run_training_batch"):
            batch_output = self.batch_loop.run(batch, batch_idx)

        self.batch_progress.increment_processed()

        self.is_last_batch = is_last

        # when returning -1 from training_step, we end the epoch early
        if batch_output.signal == -1:
            raise StopIteration

        # update non-plateau LR schedulers
        # update epoch-interval ones only when we are at the end of training epoch
        self.update_lr_schedulers("step", update_plateau_schedulers=False)
        if self._num_training_batches_reached(is_last):
            self.update_lr_schedulers("epoch", update_plateau_schedulers=False)

        batch_end_outputs = [
            opt_idx_out for opt_idx_out in batch_output.training_step_output
            if len(opt_idx_out)
        ]
        processed_batch_end_outputs = self._prepare_outputs(batch_end_outputs,
                                                            batch_mode=True)

        # hook
        self.trainer.call_hook("on_train_batch_end",
                               processed_batch_end_outputs, batch,
                               self.batch_idx, 0)
        self.trainer.call_hook("on_batch_end")
        self.trainer.logger_connector.on_batch_end()

        self.batch_progress.increment_completed()

        # figure out what to track for epoch end
        self._track_epoch_end_reduce_metrics(self._epoch_output,
                                             batch_end_outputs)

        # -----------------------------------------
        # SAVE METRICS TO LOGGERS AND PROGRESS_BAR
        # -----------------------------------------
        self.trainer.logger_connector.update_train_step_metrics()

    def on_advance_end(self):
        """Runs validation and Checkpointing if necessary.

        Raises:
            StopIteration: if :attr:`done` evaluates to ``True`` to finish this epoch
        """
        # -----------------------------------------
        # VALIDATE IF NEEDED + CHECKPOINT CALLBACK
        # -----------------------------------------
        should_check_val = self._should_check_val_fx(self.batch_idx,
                                                     self.is_last_batch)
        if should_check_val:
            self.trainer.validating = True
            self._run_validation()
            self.trainer.training = True

        # -----------------------------------------
        # SAVE LOGGERS (ie: Tensorboard, etc...)
        # -----------------------------------------
        self._save_loggers_on_train_batch_end()

        # update plateau LR scheduler after metrics are logged
        self.update_lr_schedulers("step", update_plateau_schedulers=True)

        # advance the global step according to gradient accumulation progress
        self._increment_accumulated_grad_global_step()

    def on_run_end(self) -> Optional[List[List[STEP_OUTPUT]]]:
        """Calls the on_epoch_end hook.

        Returns:
            The output of each training step for each optimizer

        Raises:
            MisconfigurationException: if ``training_epoch_end`` returns anything other than ``None``
        """
        if self.batch_progress.current.ready == 0:
            # dataloader/iterator did not produce a batch
            return

        # inform logger the batch loop has finished
        self.trainer.logger_connector.epoch_end_reached()

        # prepare epoch output
        processed_outputs = self._prepare_outputs(self._epoch_output,
                                                  batch_mode=False)

        # get the model and call model.training_epoch_end
        model = self.trainer.lightning_module

        if is_overridden("training_epoch_end", model):
            # run training_epoch_end
            # refresh the result for custom logging at the epoch level
            model._current_fx_name = "training_epoch_end"

            # lightningmodule hook
            training_epoch_end_output = model.training_epoch_end(
                processed_outputs)

            if training_epoch_end_output is not None:
                raise MisconfigurationException(
                    "training_epoch_end expects a return of None. "
                    "HINT: remove the return statement in training_epoch_end")

        self.trainer.fit_loop.epoch_progress.increment_processed()

        # call train epoch end hooks
        self.trainer.call_hook("on_train_epoch_end")
        self.trainer.call_hook("on_epoch_end")
        self.trainer.logger_connector.on_epoch_end()

        if self._num_training_batches_reached(self.is_last_batch):
            self.update_lr_schedulers("epoch", update_plateau_schedulers=True)

        epoch_output = self._epoch_output
        # free memory
        self._epoch_output = None
        return epoch_output

    def teardown(self) -> None:
        self._results.cpu()
        self.batch_loop.teardown()
        self.val_loop.teardown()

    def _run_validation(self):
        # reload dataloaders
        self.val_loop.reload_evaluation_dataloaders()

        with torch.no_grad():
            self.val_loop.run()

    def _accumulated_batches_reached(self) -> bool:
        """Determine if accumulation will be finished by the end of the current batch."""
        return self.batch_progress.current.ready % self.trainer.accumulate_grad_batches == 0

    def _num_training_batches_reached(self,
                                      is_last_batch: bool = False) -> bool:
        """Checks if we are in the last batch or if there are more batches to follow.

        Args:
            is_last_batch: Whether the current batch is the last one
        """
        return self.batch_progress.current.ready == self.trainer.num_training_batches or is_last_batch

    def _should_accumulate(self) -> bool:
        """Checks if the optimizer step should be performed or gradients should be accumulated for the current step."""
        accumulation_done = self._accumulated_batches_reached()
        is_final_batch = self._num_training_batches_reached()
        return not (accumulation_done or is_final_batch)

    def _track_epoch_end_reduce_metrics(
            self, epoch_output: List[List[STEP_OUTPUT]],
            batch_end_outputs: STEP_OUTPUT) -> None:
        """Adds the batch outputs to the epoch outputs and prepares reduction"""
        hook_overridden = is_overridden("training_epoch_end",
                                        self.trainer.lightning_module)
        if not hook_overridden:
            return

        # track the outputs to reduce at the end of the epoch
        for opt_idx, opt_outputs in enumerate(batch_end_outputs):
            # with 1 step (no tbptt) don't use a sequence at epoch end
            if isinstance(opt_outputs, list) and len(opt_outputs) == 1:
                opt_outputs = opt_outputs[0]

            epoch_output[opt_idx].append(opt_outputs)

    @staticmethod
    def _prepare_outputs(
        outputs: List[List[List["ResultCollection"]]], batch_mode: bool
    ) -> Union[List[List[List[Dict]]], List[List[Dict]], List[Dict], Dict]:
        """
        Extract required information from batch or epoch end results.

        Args:
            outputs: A 3-dimensional list of ``ResultCollection`` objects with dimensions:
                ``[optimizer outs][batch outs][tbptt steps]``.

            batch_mode: If True, ignore the batch output dimension.

        Returns:
            The cleaned outputs with ``ResultCollection`` objects converted to dictionaries.
            All list dimensions of size one will be collapsed.
        """
        processed_outputs = []
        for opt_outputs in outputs:
            # handle an edge case where an optimizer output is the empty list
            if len(opt_outputs) == 0:
                continue

            processed_batch_outputs = []

            if batch_mode:
                opt_outputs = [opt_outputs]

            for batch_outputs in opt_outputs:
                processed_tbptt_outputs = []

                if isinstance(batch_outputs, ResultCollection):
                    batch_outputs = [batch_outputs]

                for tbptt_output in batch_outputs:
                    out = {}
                    if tbptt_output.minimize is not None:
                        out["loss"] = tbptt_output.minimize.detach()
                    out.update(tbptt_output.extra)
                    processed_tbptt_outputs.append(out)

                # if there was only one tbptt step then we can collapse that dimension
                if len(processed_tbptt_outputs) == 1:
                    processed_tbptt_outputs = processed_tbptt_outputs[0]
                processed_batch_outputs.append(processed_tbptt_outputs)

            # batch_outputs should be just one dict (or a list of dicts if using tbptt) per optimizer
            if batch_mode:
                processed_batch_outputs = processed_batch_outputs[0]
            processed_outputs.append(processed_batch_outputs)

        # if there is only one optimizer then we collapse that dimension
        if len(processed_outputs) == 1:
            processed_outputs = processed_outputs[0]
        return processed_outputs

    def update_lr_schedulers(self, interval: str,
                             update_plateau_schedulers: bool) -> None:
        """updates the lr schedulers based on the given interval"""
        if interval == "step" and self._should_accumulate():
            return
        self.trainer.optimizer_connector.update_learning_rates(
            interval=interval,
            update_plateau_schedulers=update_plateau_schedulers,
            opt_indices=[
                opt_idx for opt_idx, _ in
                self.batch_loop.get_active_optimizers(self.total_batch_idx)
            ],
        )

    def _increment_accumulated_grad_global_step(self) -> None:
        """Increments global step according to grads progress"""
        if not self._should_accumulate():
            self.global_step = self.trainer.accelerator.update_global_step(
                self.batch_progress.current.ready, self.trainer.global_step)

    def _should_check_val_fx(self, batch_idx: int,
                             is_last_batch: bool) -> bool:
        """Decide if we should run validation."""
        if not self.trainer.enable_validation:
            return False

        is_val_check_epoch = (self.trainer.current_epoch +
                              1) % self.trainer.check_val_every_n_epoch == 0
        if not is_val_check_epoch:
            return False

        # val_check_batch is inf for iterable datasets with no length defined
        is_infinite_dataset = self.trainer.val_check_batch == float("inf")
        if is_last_batch and is_infinite_dataset:
            return True

        if self.trainer.should_stop:
            return True

        # TODO(@awaelchli): let training/eval loop handle logic around limit_*_batches and val_check_batch
        is_val_check_batch = is_last_batch
        if isinstance(self.trainer.limit_train_batches,
                      int) and is_infinite_dataset:
            is_val_check_batch = (batch_idx +
                                  1) % self.trainer.limit_train_batches == 0
        elif self.trainer.val_check_batch != float("inf"):
            is_val_check_batch = (batch_idx +
                                  1) % self.trainer.val_check_batch == 0
        return is_val_check_batch

    def _save_loggers_on_train_batch_end(self) -> None:
        """Flushes loggers to disk"""
        # when loggers should save to disk
        should_flush_logs = self.trainer.logger_connector.should_flush_logs
        if should_flush_logs and self.trainer.is_global_zero and self.trainer.logger is not None:
            self.trainer.logger.save()
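
To make the dimension collapsing in _prepare_outputs concrete, here is a small hypothetical illustration. It assumes the TrainingEpochLoop above is importable and substitutes a SimpleNamespace for ResultCollection, providing just the .minimize/.extra surface the method reads:

from types import SimpleNamespace

import torch

# a fake step result: .minimize holds the loss tensor, .extra the logged metrics
fake_result = SimpleNamespace(minimize=torch.tensor(0.5), extra={"acc": 0.9})

# batch-mode input is [optimizer outs][tbptt steps]; with one optimizer and one
# tbptt step, both singleton dimensions collapse down to a single dict
out = TrainingEpochLoop._prepare_outputs([[fake_result]], batch_mode=True)
# out == {"loss": tensor(0.5000), "acc": 0.9}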
class EvaluationEpochLoop(Loop):
    """
    This is the loop performing the evaluation. It mainly loops over the given dataloader and runs the validation
    or test step (depending on the trainer's current state).
    """
    def __init__(self) -> None:
        super().__init__()
        self.predictions: Optional[PredictionCollection] = None
        self.dataloader: Optional[Iterator] = None
        self._dl_max_batches: Optional[int] = None
        self._num_dataloaders: Optional[int] = None
        self.outputs: List[STEP_OUTPUT] = []
        self.batch_progress = Progress()

    @property
    def done(self) -> bool:
        """Returns ``True`` if the current iteration count reaches the number of dataloader batches."""
        return self.batch_progress.current.completed >= self._dl_max_batches

    def connect(self, **kwargs: "Loop") -> None:
        raise NotImplementedError(
            f"{self.__class__.__name__} does not connect any child loops.")

    def reset(self) -> None:
        """Resets the loop's internal state."""
        self.predictions = PredictionCollection(self.trainer.global_rank,
                                                self.trainer.world_size)
        self._dl_max_batches = None
        self._num_dataloaders = None
        self.outputs = []

        if not self.restarting:
            self.batch_progress.current.reset()

    def on_run_start(
        self,
        dataloader_iter: Iterator,
        dataloader_idx: int,
        dl_max_batches: int,
        num_dataloaders: int,
    ) -> None:
        """Adds the passed arguments to the loop's state if necessary

        Args:
            dataloader_iter: iterator over the dataloader
            dataloader_idx: index of the current dataloader
            dl_max_batches: maximum number of batches the dataloader can produce
            num_dataloaders: the total number of dataloaders
        """
        void(dataloader_iter, dataloader_idx)
        self._dl_max_batches = dl_max_batches
        self._num_dataloaders = num_dataloaders

    def advance(
        self,
        dataloader_iter: Iterator,
        dataloader_idx: int,
        dl_max_batches: int,
        num_dataloaders: int,
    ) -> None:
        """Calls the evaluation step with the corresponding hooks and updates the logger connector.

        Args:
            dataloader_iter: iterator over the dataloader
            dataloader_idx: index of the current dataloader
            dl_max_batches: maximum number of batches the dataloader can produce
            num_dataloaders: the total number of dataloaders

        Raises:
            StopIteration: If the current batch is None
        """
        void(dl_max_batches, num_dataloaders)

        batch_idx, batch = next(dataloader_iter)

        if batch is None:
            raise StopIteration

        with self.trainer.profiler.profile("evaluation_batch_to_device"):
            batch = self.trainer.accelerator.batch_to_device(
                batch, dataloader_idx=dataloader_idx)

        self.batch_progress.increment_ready()

        # hook
        self.on_evaluation_batch_start(batch, batch_idx, dataloader_idx)

        self.batch_progress.increment_started()

        # lightning module methods
        with self.trainer.profiler.profile("evaluation_step_and_end"):
            output = self.evaluation_step(batch, batch_idx, dataloader_idx)
            output = self.evaluation_step_end(output)

        self.batch_progress.increment_processed()

        # hook + store predictions
        self.on_evaluation_batch_end(output, batch, batch_idx, dataloader_idx)

        self.batch_progress.increment_completed()

        # log batch metrics
        self.trainer.logger_connector.update_eval_step_metrics()

        # track epoch level outputs
        self.outputs = self._track_output_for_epoch_end(self.outputs, output)

    def on_run_end(self) -> List[STEP_OUTPUT]:
        """Returns the outputs of the whole run"""
        outputs = self.outputs
        # free memory
        self.outputs = []
        return outputs

    def evaluation_step(self, batch: Any, batch_idx: int,
                        dataloader_idx: int) -> Optional[STEP_OUTPUT]:
        """The evaluation step (validation_step or test_step depending on the trainer's state).

        Args:
            batch: The current batch to run through the step.
            batch_idx: The index of the current batch
            dataloader_idx: the index of the dataloader producing the current batch

        Returns:
            the outputs of the step
        """
        # configure step_kwargs
        step_kwargs = self._build_kwargs(batch, batch_idx, dataloader_idx)

        if self.trainer.testing:
            self.trainer.lightning_module._current_fx_name = "test_step"
            with self.trainer.profiler.profile("test_step"):
                output = self.trainer.accelerator.test_step(step_kwargs)
        else:
            self.trainer.lightning_module._current_fx_name = "validation_step"
            with self.trainer.profiler.profile("validation_step"):
                output = self.trainer.accelerator.validation_step(step_kwargs)

        return output

    def evaluation_step_end(self, *args: Any,
                            **kwargs: Any) -> Optional[STEP_OUTPUT]:
        """Calls the `{validation/test}_step_end` hook"""
        hook_name = "test_step_end" if self.trainer.testing else "validation_step_end"
        output = self.trainer.call_hook(hook_name, *args, **kwargs)
        return output

    def on_evaluation_batch_start(self, batch: Any, batch_idx: int,
                                  dataloader_idx: int) -> None:
        """Calls the ``on_{validation/test}_batch_start`` hook.

        Args:
            batch: The current batch to run through the step
            batch_idx: The index of the current batch
            dataloader_idx: The index of the dataloader producing the current batch

        Raises:
            AssertionError: If the number of dataloaders is None (has not yet been set).
        """
        self.trainer.logger_connector.on_batch_start()

        assert self._num_dataloaders is not None
        self.trainer.logger_connector.on_evaluation_batch_start(
            batch, batch_idx, dataloader_idx, self._num_dataloaders)

        if self.trainer.testing:
            self.trainer.call_hook("on_test_batch_start", batch, batch_idx,
                                   dataloader_idx)
        else:
            self.trainer.call_hook("on_validation_batch_start", batch,
                                   batch_idx, dataloader_idx)

    def on_evaluation_batch_end(
        self,
        output: Optional[STEP_OUTPUT],
        batch: Any,
        batch_idx: int,
        dataloader_idx: int,
    ) -> None:
        """The ``on_{validation/test}_batch_end`` hook.

        Args:
            output: The output of the performed step
            batch: The input batch for the step
            batch_idx: The index of the current batch
            dataloader_idx: Index of the dataloader producing the current batch
        """
        hook_name = "on_test_batch_end" if self.trainer.testing else "on_validation_batch_end"
        self.trainer.call_hook(hook_name, output, batch, batch_idx,
                               dataloader_idx)

        self.trainer.logger_connector.on_batch_end()

        # store predictions if do_write_predictions and track eval loss history
        self.store_predictions(output, batch_idx, dataloader_idx)

    def store_predictions(self, output: Optional[STEP_OUTPUT], batch_idx: int,
                          dataloader_idx: int) -> None:
        """Stores the predictions in the prediction collection (only if running in test mode)

        Args:
            output: the outputs of the current step
            batch_idx: the index of the current batch
            dataloader_idx: the index of the dataloader producing the current batch
        """
        # Add step predictions to prediction collection to write later
        if output is not None and self.predictions is not None:
            if isinstance(output, ResultCollection) and self.trainer.testing:
                self.predictions.add(output.pop("predictions", None))

        # track debug metrics
        self.trainer.dev_debugger.track_eval_loss_history(
            batch_idx, dataloader_idx, output)

    def _build_kwargs(self, batch: Any, batch_idx: int,
                      dataloader_idx: int) -> Dict[str, Union[Any, int]]:
        """Helper function to build the arguments for the current step

        Args:
            batch: The current batch to run through the step
            batch_idx: the index of the current batch
            dataloader_idx: the index of the dataloader producing the current batch

        Returns:
            the keyword arguments to pass to the step function
        """
        # make dataloader_idx arg in validation_step optional
        step_kwargs = OrderedDict([("batch", batch), ("batch_idx", batch_idx)])

        multiple_val_loaders = not self.trainer.testing and self._num_dataloaders > 1
        multiple_test_loaders = self.trainer.testing and self._num_dataloaders > 1

        if multiple_test_loaders or multiple_val_loaders:
            step_kwargs["dataloader_idx"] = dataloader_idx

        return step_kwargs

    def _track_output_for_epoch_end(
        self,
        outputs: List[Union[ResultCollection, Dict, Tensor]],
        output: Optional[Union[ResultCollection, Dict, Tensor]],
    ) -> List[Union[ResultCollection, Dict, Tensor]]:
        if output is not None:
            if isinstance(output, ResultCollection):
                output = output.detach()
                if self.trainer.move_metrics_to_cpu:
                    output = output.cpu()
            elif isinstance(output, dict):
                output = recursive_detach(
                    output, to_cpu=self.trainer.move_metrics_to_cpu)
            elif isinstance(
                    output, Tensor
            ) and output.is_cuda and self.trainer.move_metrics_to_cpu:
                output = output.cpu()
            outputs.append(output)
        return outputs
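
A detail of _build_kwargs worth spelling out: dataloader_idx is only injected when more than one dataloader is active, which is what lets users define validation_step(self, batch, batch_idx) with a single loader and validation_step(self, batch, batch_idx, dataloader_idx) with several. A standalone, hypothetical restatement of that rule:

from collections import OrderedDict
from typing import Any, Dict


def build_eval_kwargs(batch: Any, batch_idx: int, dataloader_idx: int,
                      num_dataloaders: int) -> Dict[str, Any]:
    """Hypothetical helper mirroring the kwargs rule used above."""
    step_kwargs = OrderedDict([("batch", batch), ("batch_idx", batch_idx)])
    if num_dataloaders > 1:  # only then does the step receive dataloader_idx
        step_kwargs["dataloader_idx"] = dataloader_idx
    return step_kwargs


assert list(build_eval_kwargs("b", 0, 0, num_dataloaders=1)) == ["batch", "batch_idx"]
assert list(build_eval_kwargs("b", 0, 1, num_dataloaders=2)) == [
    "batch", "batch_idx", "dataloader_idx"
]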
class FitLoop(Loop[None]):
    """This Loop iterates over the epochs to run the training.

    Args:
        min_epochs: The minimum number of epochs
        max_epochs: The maximum number of epochs, can be set -1 to turn this limit off
    """
    def __init__(
        self,
        min_epochs: int = 0,
        max_epochs: int = 1000,
    ) -> None:
        super().__init__()
        if max_epochs < -1:
            # Allow max_epochs to be zero, since this will be handled by fit_loop.done
            raise MisconfigurationException(
                f"`max_epochs` must be a non-negative integer or -1. You passed in {max_epochs}."
            )

        self.max_epochs = max_epochs
        self.min_epochs = min_epochs
        self.epoch_loop = TrainingEpochLoop()
        self.epoch_progress = Progress()

        self._is_fresh_start_epoch: bool = True
        self._outputs: _EPOCH_OUTPUTS_TYPE = []
        self._data_fetcher: Optional[AbstractDataFetcher] = None

    @property
    def total_batch_idx(self) -> int:
        """Returns the current batch index (across epochs)"""
        return self.epoch_loop.total_batch_idx

    @property
    def batch_idx(self) -> int:
        """Returns the current batch index (within this epoch)"""
        return self.epoch_loop.batch_idx

    @property
    def split_idx(self) -> int:
        """Returns the index of the current batch split (within the current batch) for bptt."""
        return self.epoch_loop.batch_loop.split_idx

    @property
    def min_steps(self) -> Optional[int]:
        # TODO(@justusschock): Why aren't we using the attribute in this class?
        """Returns the minimum number of steps to run."""
        return self.epoch_loop.min_steps

    @min_steps.setter
    def min_steps(self, value: Optional[int]) -> None:
        """Sets the minimum number of steps (forwards to epoch_loop)"""
        # TODO(@awaelchli): This setter is required by debugging connector (fast dev run), should be avoided
        self.epoch_loop.min_steps = value

    @property
    def max_steps(self) -> int:
        """Returns the maximum number of steps to run."""
        return self.epoch_loop.max_steps

    @max_steps.setter
    def max_steps(self, value: int) -> None:
        """Sets the maximum number of steps (forwards to epoch_loop)"""
        # TODO(@awaelchli): This setter is required by debugging connector (fast dev run), should be avoided
        if value is None:
            rank_zero_deprecation(
                "Setting `max_steps = None` is deprecated in v1.5 and will no longer be supported in v1.7."
                " Use `max_steps = -1` instead.")
            value = -1
        elif value < -1:
            raise MisconfigurationException(
                f"`max_steps` must be a non-negative integer or -1 (infinite steps). You passed in {value}."
            )
        self.epoch_loop.max_steps = value

    @property
    def running_loss(self) -> TensorRunningAccum:
        """Returns the running loss."""
        return self.epoch_loop.batch_loop.running_loss

    @Loop.restarting.setter
    def restarting(self, restarting: bool) -> None:
        # if the last epoch completely finished, we are not actually restarting;
        # we can detect that by checking whether all current values are equal
        values = (
            self.epoch_progress.current.ready,
            self.epoch_progress.current.started,
            self.epoch_progress.current.processed,
        )
        finished_before_on_train_end = any(
            v != self.epoch_progress.current.completed for v in values)
        if finished_before_on_train_end:
            self.epoch_progress.current.completed = self.epoch_progress.current.processed
        restarting &= finished_before_on_train_end
        Loop.restarting.fset(self, restarting)  # call the parent setter

    @property
    def prefetch_batches(self) -> int:
        is_unsized = self.trainer.num_training_batches == float("inf")
        inter_batch_parallelism = os.getenv("PL_INTER_BATCH_PARALLELISM",
                                            "0") == "1"
        return 1 if is_unsized or inter_batch_parallelism else 0

    @property
    def _skip_backward(self) -> bool:
        """Determines whether the loop will skip backward during automatic optimization."""
        return self.epoch_loop.batch_loop.optimizer_loop._skip_backward

    @_skip_backward.setter
    def _skip_backward(self, value: bool) -> None:
        """Determines whether the loop will skip backward during automatic optimization."""
        self.epoch_loop.batch_loop.optimizer_loop._skip_backward = value

    @property
    def _results(self) -> _ResultCollection:
        if self.trainer.training:
            return self.epoch_loop._results
        if self.trainer.validating:
            return self.epoch_loop.val_loop._results
        raise RuntimeError(
            "`FitLoop._results` property isn't defined. Accessed outside of scope"
        )

    @property
    def done(self) -> bool:
        """Evaluates when to leave the loop."""
        # TODO(@awaelchli): Move track steps inside training loop and move part of these condition inside training loop
        stop_steps = _is_max_limit_reached(self.epoch_loop.global_step,
                                           self.max_steps)
        # `processed` is increased before `on_train_epoch_end`, the hook where checkpoints are typically saved.
        # we use it here because the checkpoint data won't have `completed` increased yet
        stop_epochs = _is_max_limit_reached(
            self.epoch_progress.current.processed, self.max_epochs)

        should_stop = False
        if self.trainer.should_stop:
            # early stopping
            met_min_epochs = self.epoch_progress.current.processed >= self.min_epochs if self.min_epochs else True
            met_min_steps = self.epoch_loop.global_step >= self.min_steps if self.min_steps else True
            if met_min_epochs and met_min_steps:
                should_stop = True
            else:
                log.info(
                    "Trainer was signaled to stop but required minimum epochs"
                    f" ({self.min_epochs}) or minimum steps ({self.min_steps}) has"
                    " not been met. Training will continue...")
        self.trainer.should_stop = should_stop

        return stop_steps or should_stop or stop_epochs or self.trainer.num_training_batches == 0

    @property
    def skip(self) -> bool:
        """Whether we should skip the training and immediately return from the call to :meth:`run`."""
        # since `trainer.num_training_batches` depends on the `train_dataloader` but that won't be called
        # until `on_run_start`, we use `limit_train_batches` instead
        return self.done or self.trainer.limit_train_batches == 0

    def connect(
            self,
            epoch_loop: TrainingEpochLoop) -> None:  # type: ignore[override]
        """Connects a training epoch loop to this fit loop."""
        self.epoch_loop = epoch_loop

    def reset(self) -> None:
        """Resets the internal state of this loop."""
        if self.restarting:
            self.epoch_progress.reset_on_restart()

    def on_run_start(self) -> None:  # type: ignore[override]
        """Calls the ``on_train_start`` hook."""
        # reset train dataloader and val dataloader
        self.trainer.reset_train_val_dataloaders(self.trainer.lightning_module)

        data_fetcher_cls = _select_data_fetcher(self.trainer)
        self._data_fetcher = data_fetcher_cls(
            prefetch_batches=self.prefetch_batches)

        self._is_fresh_start_epoch = True
        self._results.to(device=self.trainer.lightning_module.device)

        self.trainer._call_callback_hooks("on_train_start")
        self.trainer._call_lightning_module_hook("on_train_start")
        self.trainer._call_strategy_hook("on_train_start")

    def on_advance_start(self) -> None:  # type: ignore[override]
        """Prepares the dataloader for training and calls the hooks ``on_epoch_start`` and
        ``on_train_epoch_start``"""
        model = self.trainer.lightning_module

        # reset train dataloader
        if not self._is_fresh_start_epoch and self.trainer._data_connector._should_reload_train_dl:
            log.detail(
                f"{self.__class__.__name__}: resetting train dataloader")
            self.trainer.reset_train_dataloader(model)
        self._is_fresh_start_epoch = False

        # reset outputs here instead of in `reset` as they are not accumulated between epochs
        self._outputs = []

        if self.trainer.train_dataloader is not None and callable(
                getattr(self.trainer.train_dataloader.sampler, "set_epoch",
                        None)):
            # set seed for distributed sampler (enables shuffling for each epoch)
            self.trainer.train_dataloader.sampler.set_epoch(
                self.epoch_progress.current.processed)

        # changing gradient according accumulation_scheduler
        self.trainer.accumulation_scheduler.on_train_epoch_start(
            self.trainer, self.trainer.lightning_module)

        # stores accumulated grad fractions per batch
        self.epoch_loop.batch_loop.accumulated_loss.reset(
            window_length=self.trainer.accumulate_grad_batches)

        self.epoch_progress.increment_ready()

        self.trainer._logger_connector.on_epoch_start()

        self.trainer._call_callback_hooks("on_epoch_start")
        self.trainer._call_lightning_module_hook("on_epoch_start")

        self.trainer._call_callback_hooks("on_train_epoch_start")
        self.trainer._call_lightning_module_hook("on_train_epoch_start")

        self.epoch_progress.increment_started()

    def advance(self) -> None:  # type: ignore[override]
        """Runs one whole epoch."""
        log.detail(f"{self.__class__.__name__}: advancing loop")
        assert self.trainer.train_dataloader is not None
        dataloader = self.trainer.train_dataloader
        assert self._data_fetcher is not None
        self._data_fetcher.setup(dataloader,
                                 batch_to_device=partial(
                                     self.trainer._call_strategy_hook,
                                     "batch_to_device",
                                     dataloader_idx=0))
        with self.trainer.profiler.profile("run_training_epoch"):
            self._outputs = self.epoch_loop.run(self._data_fetcher)

    def on_advance_end(self) -> None:
        # inform logger the batch loop has finished
        self.trainer._logger_connector.epoch_end_reached()

        # get the model and call model.training_epoch_end
        model = self.trainer.lightning_module
        if is_overridden("training_epoch_end", model) and self._outputs:
            epoch_end_outputs = self.epoch_loop._prepare_outputs_training_epoch_end(
                self._outputs,
                lightning_module=model,
                num_optimizers=len(self.trainer.optimizers),
            )
            # run lightning module hook training_epoch_end
            # refresh the result for custom logging at the epoch level
            epoch_end_outputs = self.trainer._call_lightning_module_hook(
                "training_epoch_end", epoch_end_outputs)
            if epoch_end_outputs is not None:
                raise MisconfigurationException(
                    "`training_epoch_end` expects a return of None. "
                    "HINT: remove the return statement in `training_epoch_end`."
                )
        # free memory
        self._outputs = []

        self.epoch_progress.increment_processed()

        # call train epoch end hooks
        self.trainer._call_callback_hooks("on_train_epoch_end")
        self.trainer._call_lightning_module_hook("on_train_epoch_end")

        self.trainer._call_callback_hooks("on_epoch_end")
        self.trainer._call_lightning_module_hook("on_epoch_end")

        self.trainer._logger_connector.on_epoch_end()

        if self.epoch_loop._num_ready_batches_reached():
            self.epoch_loop.update_lr_schedulers(
                "epoch", update_plateau_schedulers=True)

        self.epoch_progress.increment_completed()

        # we manually decrease here because loggers expect that the same step is used when logging epoch-end metrics
        # even when the batch loop has finished
        self.epoch_loop._batches_that_stepped -= 1
        # log epoch metrics
        self.trainer._logger_connector.update_train_epoch_metrics()
        self.epoch_loop._batches_that_stepped += 1

        # if fault tolerant is enabled and process has been notified, exit.
        self.trainer._exit_gracefully_on_signal()

    def on_run_end(self) -> None:
        """Calls the ``on_train_end`` hook."""
        log.detail(f"{self.__class__.__name__}: train run ended")

        # hook
        self.trainer._call_callback_hooks("on_train_end")
        self.trainer._call_lightning_module_hook("on_train_end")
        self.trainer._call_strategy_hook("on_train_end")

    def teardown(self) -> None:
        if self._data_fetcher is not None:
            self._data_fetcher.teardown()
            self._data_fetcher = None
        self.epoch_loop.teardown()

    def _should_accumulate(self) -> bool:
        """Whether the gradients should be accumulated."""
        return self.epoch_loop._should_accumulate()
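
For orientation, fitting is a hierarchy of these loops, each level's run() driving the next one down. A rough sketch of the composition (illustrative only; the Trainer normally constructs and connects these itself, and attaches itself as .trainer on each loop):

# FitLoop.run()                   -> one iteration == one epoch
#   TrainingEpochLoop.run(...)    -> one iteration == one batch
#     TrainingBatchLoop.run(...)  -> optimizer/tbptt handling for that batch
#     EvaluationLoop.run()        -> validation, when _should_check_val_fx allows
fit_loop = FitLoop(min_epochs=1, max_epochs=3)
fit_loop.connect(epoch_loop=TrainingEpochLoop(min_steps=0, max_steps=1000))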
class PredictionEpochLoop(Loop):
    """Loop performing prediction on arbitrary sequentially used dataloaders."""
    def __init__(self) -> None:
        super().__init__()
        self.return_predictions: bool = False
        self.predictions: List[Any] = []
        self.current_batch_indices: List[int] = []
        self.batch_progress = Progress()

        self._dl_max_batches: Optional[int] = None
        self._num_dataloaders: Optional[int] = None
        self._warning_cache = WarningCache()
        self._all_batch_indices: List[List[int]] = []

    @property
    def done(self) -> bool:
        """Ends prediction when the iteration count exceeds the total number of available batches."""
        return self.batch_progress.current.completed >= self._dl_max_batches

    @property
    def should_store_predictions(self) -> bool:
        """Whether the predictions should be stored for later usage (e.g. aggregation or returning)"""
        any_pred = any(cb.interval.on_epoch
                       for cb in self.trainer.prediction_writer_callbacks)
        return self.return_predictions or any_pred

    def connect(self, **kwargs: "Loop") -> None:
        raise NotImplementedError(
            f"{self.__class__.__name__} does not connect any child loops.")

    def reset(self) -> None:
        """Resets the loops internal state."""
        self._all_batch_indices: List[List[int]] = []
        self.predictions: List[Any] = []
        self.batch_progress.reset_on_run()

    def on_run_start(
        self,
        dataloader_iter: Iterator,
        dataloader_idx: int,
        dl_max_batches: int,
        num_dataloaders: int,
        return_predictions: bool = False,
    ) -> None:
        """Prepares the loops internal state.

        Args:
            dataloader_iter: the iterator over the current dataloader
            dataloader_idx: the index of the current dataloader
            dl_max_batches: the maximum number of batches the current loader can produce
            num_dataloaders: the total number of dataloaders
            return_predictions: whether to return the obtained predictions
        """
        void(dataloader_iter, dataloader_idx)
        self._dl_max_batches = dl_max_batches
        self._num_dataloaders = num_dataloaders
        self.return_predictions = return_predictions

    def advance(
        self,
        dataloader_iter: Iterator,
        dataloader_idx: int,
        dl_max_batches: int,
        num_dataloaders: int,
        return_predictions: bool = False,
    ) -> None:
        """Runs one prediction step.

        Args:
            dataloader_iter: the iterator over the current dataloader
            dataloader_idx: the index of the current dataloader
            dl_max_batches: the maximum number of batches the current loader can produce
            num_dataloaders: the total number of dataloaders
            return_predictions: whether to return the obtained predictions
        """
        batch_idx, batch = next(dataloader_iter)
        if batch is None:
            raise StopIteration

        with self.trainer.profiler.profile("predict_batch_to_device"):
            batch = self.trainer.accelerator.batch_to_device(
                batch, dataloader_idx=dataloader_idx)

        self.batch_progress.increment_ready()

        with self.trainer.profiler.profile("predict_step"):
            self._predict_step(batch, batch_idx, dataloader_idx)

    def on_run_end(self) -> Tuple[List[Any], List[List[int]]]:
        """Returns the predictions and the corresponding batch indices."""
        predictions = self.predictions
        all_batch_indices = self._all_batch_indices
        # free memory
        self.predictions = []
        self._all_batch_indices = []
        return predictions, all_batch_indices

    def _predict_step(self, batch: Any, batch_idx: int,
                      dataloader_idx: int) -> None:
        """Runs the actual predict step together with all the necessary bookkeeping and the hooks tied to the
        predict step.

        Args:
            batch: the current batch to run the prediction on
            batch_idx: the index of the current batch
            dataloader_idx: the index of the dataloader producing the current batch
        """
        # configure step_kwargs
        step_kwargs = self._build_kwargs(batch, batch_idx, dataloader_idx)

        # extract batch_indices and store them
        self._store_batch_indices(dataloader_idx)

        model_ref = self.trainer.lightning_module

        self.trainer.call_hook("on_predict_batch_start", batch, batch_idx,
                               dataloader_idx)

        self.batch_progress.increment_started()

        model_ref._current_fx_name = "predict_step"
        predictions = self.trainer.accelerator.predict_step(step_kwargs)

        self.batch_progress.increment_processed()

        if predictions is None:
            self._warning_cache.warn(
                "predict returned None. If this was on purpose, ignore this warning..."
            )

        self.trainer.call_hook("on_predict_batch_end", predictions, batch,
                               batch_idx, dataloader_idx)

        self.batch_progress.increment_completed()

        if self.should_store_predictions:
            self.predictions.append(
                move_data_to_device(predictions, torch.device("cpu")))

    def _build_kwargs(self, batch: Any, batch_idx: int,
                      dataloader_idx: int) -> Dict[str, Any]:
        """Assembles the keyword arguments for the ``predict_step``

        Args:
            batch: the current batch to run the prediction on
            batch_idx: the index of the current batch
            dataloader_idx: the index of the dataloader producing the current batch

        Returns:
            the dictionary containing all the keyword arguments for the predict step
        """
        step_kwargs = OrderedDict([("batch", batch), ("batch_idx", batch_idx)])
        if self._num_dataloaders > 1:
            step_kwargs["dataloader_idx"] = dataloader_idx
        return step_kwargs

    def _store_batch_indices(self, dataloader_idx: int) -> None:
        """Stores the batch indices if the predictions should be stored."""
        batch_sampler = self.trainer.predict_dataloaders[
            dataloader_idx].batch_sampler
        if isinstance(batch_sampler, IndexBatchSamplerWrapper):
            self.current_batch_indices = batch_sampler.batch_indices
            if self.should_store_predictions:
                self._all_batch_indices.append(batch_sampler.batch_indices)
        else:
            warning_cache.warn(
                "Lightning couldn't infer the indices fetched for your dataloader."
            )
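
_store_batch_indices only succeeds when the prediction dataloader's batch_sampler has been wrapped in IndexBatchSamplerWrapper, which remembers the indices each batch was drawn from. A minimal sketch of the wrapper's relevant surface, inferred from its use here (the real class lives in pytorch_lightning.overrides.distributed and wraps an existing batch sampler):

from typing import Iterator, List


class IndexBatchSamplerWrapperSketch:
    """Hypothetical stand-in: records the most recently yielded batch of indices."""

    def __init__(self, batch_sampler) -> None:
        self._batch_sampler = batch_sampler
        self.batch_indices: List[int] = []

    def __iter__(self) -> Iterator[List[int]]:
        for batch in self._batch_sampler:
            # remembered so _store_batch_indices can read them per step
            self.batch_indices = batch
            yield batch

    def __len__(self) -> int:
        return len(self._batch_sampler)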
class FitLoop(Loop):
    """This Loop iterates over the epochs to run the training.

    Args:
        min_epochs: The minimum number of epochs
        max_epochs: The maximum number of epochs, can be set -1 to turn this limit off
    """

    def __init__(
        self,
        min_epochs: Optional[int] = 1,
        max_epochs: int = 1000,
    ) -> None:
        super().__init__()
        if max_epochs < -1:
            # Allow max_epochs to be zero, since this will be handled by fit_loop.done
            raise MisconfigurationException(
                f"`max_epochs` must be a non-negative integer or -1. You passed in {max_epochs}."
            )

        self.max_epochs = max_epochs
        self.min_epochs = min_epochs
        self.epoch_loop: Optional[TrainingEpochLoop] = None
        self.epoch_progress = Progress()
        self._is_fresh_start_epoch: bool = True

    @property
    def current_epoch(self) -> int:
        """Return the current epoch."""
        return self.epoch_progress.current.completed

    @current_epoch.setter
    def current_epoch(self, value: int) -> None:
        """Setter for the current epoch."""
        self.epoch_progress.current.completed = value

    @property
    def global_step(self) -> int:
        """Returns the global step."""
        return self.epoch_loop.global_step

    @global_step.setter
    def global_step(self, value: int) -> None:
        """Sets the global step (forwards to epoch_loop)"""
        self.epoch_loop.global_step = value

    @property
    def total_batch_idx(self) -> int:
        """Returns the current batch index (across epochs)"""
        return self.epoch_loop.total_batch_idx

    @property
    def batch_idx(self) -> int:
        """Returns the current batch index (within this epoch)"""
        return self.epoch_loop.batch_idx

    @property
    def split_idx(self) -> int:
        """Returns the index of the current batch split (within the current batch) for bptt."""
        return self.epoch_loop.batch_loop.split_idx

    @property
    def min_steps(self) -> int:
        # TODO(@justusschock): Why aren't we using the attribute in this class?
        """Returns the minimum numnber of steps to run."""
        return self.epoch_loop.min_steps

    @min_steps.setter
    def min_steps(self, value: int) -> None:
        """Sets the minimum number of steps (forwards to epoch_loop)"""
        # TODO(@awaelchli): This setter is required by debugging connector (fast dev run), should be avoided
        self.epoch_loop.min_steps = value

    @property
    def max_steps(self) -> int:
        """Returns the maximum number of steps to run."""
        return self.epoch_loop.max_steps

    @max_steps.setter
    def max_steps(self, value: int) -> None:
        """Sets the maximum number of steps (forwards to epoch_loop)"""
        # TODO(@awaelchli): This setter is required by debugging connector (fast dev run), should be avoided
        if value is None:
            rank_zero_deprecation(
                "Setting `max_steps = None` is deprecated in v1.5 and will no longer be supported in v1.7."
                " Use `max_steps = -1` instead."
            )
            value = -1
        elif value < -1:
            raise MisconfigurationException(
                f"`max_steps` must be a non-negative integer or -1 (infinite steps). You passed in {value}."
            )
        self.epoch_loop.max_steps = value

    @property
    def running_loss(self) -> TensorRunningAccum:
        """Returns the running loss."""
        return self.epoch_loop.batch_loop.running_loss

    @property
    def _skip_backward(self) -> bool:
        """Determines whether the loop will skip backward during automatic optimization."""
        assert self.epoch_loop.batch_loop is not None
        assert self.epoch_loop.batch_loop.optimizer_loop is not None
        return self.epoch_loop.batch_loop.optimizer_loop._skip_backward

    @_skip_backward.setter
    def _skip_backward(self, value: bool) -> None:
        """Determines whether the loop will skip backward during automatic optimization."""
        assert self.epoch_loop.batch_loop is not None
        assert self.epoch_loop.batch_loop.optimizer_loop is not None
        self.epoch_loop.batch_loop.optimizer_loop._skip_backward = value

    @property
    def _results(self) -> ResultCollection:
        if self.trainer.training:
            return self.epoch_loop._results
        if self.trainer.validating:
            return self.epoch_loop.val_loop._results
        raise RuntimeError("`FitLoop._results` property isn't defined. Accessed outside of scope")

    @property
    def done(self) -> bool:
        """Evaluates when to leave the loop.

        Returns True if trainer.should_stop was set (e.g. by early stopping) or if the maximum number of steps or epochs
        is reached.
        """
        # TODO(@awaelchli): Move track steps inside training loop and move part of these condition inside training loop
        stop_steps = _is_max_limit_reached(self.global_step, self.max_steps)
        stop_epochs = _is_max_limit_reached(self.current_epoch, self.max_epochs)

        should_stop = False
        if self.trainer.should_stop:
            # early stopping
            met_min_epochs = self.current_epoch >= self.min_epochs if self.min_epochs else True
            met_min_steps = self.global_step >= self.min_steps if self.min_steps else True
            if met_min_epochs and met_min_steps:
                should_stop = True
            else:
                log.info(
                    "Trainer was signaled to stop but required minimum epochs"
                    f" ({self.min_epochs}) or minimum steps ({self.min_steps}) has"
                    " not been met. Training will continue..."
                )
        self.trainer.should_stop = should_stop

        return stop_steps or should_stop or stop_epochs or self.trainer.num_training_batches == 0
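        # Example: with `min_epochs=5`, an early-stopping request raised at
        # epoch 2 is swallowed (and `trainer.should_stop` is reset to False
        # above) until `current_epoch >= 5`; only then can `done` return True.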

    @property
    def skip(self) -> bool:
        """Whether we should skip the training and immediately return from the call to :meth:`run`."""
        # since `trainer.num_training_batches` depends on the `train_dataloader` but that won't be called
        # until `on_run_start`, we use `limit_train_batches` instead
        return self.done or self.trainer.limit_train_batches == 0

    def connect(self, epoch_loop: TrainingEpochLoop):
        """Connects a training epoch loop to this fit loop."""
        self.epoch_loop = epoch_loop

    def reset(self) -> None:
        """Resets the internal state of this loop."""
        if self.restarting:
            self.epoch_progress.reset_on_restart()

    def on_run_start(self) -> None:
        """Calls the ``on_train_start`` hook."""
        # reset train dataloader and val dataloader
        self.trainer.reset_train_val_dataloaders(self.trainer.lightning_module)
        self._is_fresh_start_epoch = True
        self._results.to(device=self.trainer.lightning_module.device)
        self.trainer.call_hook("on_train_start")

    def on_advance_start(self) -> None:
        """Prepares the dataloader for training and calls the hooks ``on_epoch_start`` and
        ``on_train_epoch_start``."""
        model = self.trainer.lightning_module

        # reset train dataloader
        if not self._is_fresh_start_epoch and self.trainer._should_reload_dl_epoch:
            self.trainer.reset_train_dataloader(model)
        self._is_fresh_start_epoch = False

        if self.trainer.train_dataloader is not None and callable(
            getattr(self.trainer.train_dataloader.sampler, "set_epoch", None)
        ):
            # set seed for distributed sampler (enables shuffling for each epoch)
            self.trainer.train_dataloader.sampler.set_epoch(self.current_epoch)

        # adjust gradient accumulation according to the accumulation_scheduler
        self.trainer.accumulation_scheduler.on_train_epoch_start(self.trainer, self.trainer.lightning_module)

        # stores accumulated grad fractions per batch
        self.epoch_loop.batch_loop.accumulated_loss = TensorRunningAccum(
            window_length=self.trainer.accumulate_grad_batches
        )

        self.epoch_progress.increment_ready()

    def advance(self) -> None:
        """Runs one whole epoch."""
        dataloader = self.trainer.training_type_plugin.process_dataloader(self.trainer.train_dataloader)
        data_fetcher = self.trainer._data_connector.get_profiled_dataloader(dataloader)

        with self.trainer.profiler.profile("run_training_epoch"):
            self.epoch_loop.run(data_fetcher)

            # the global step is manually decreased here due to backwards compatibility with existing loggers
            # as they expect that the same step is used when logging epoch end metrics even when the batch loop has
            # finished. this means the attribute does not exactly track the number of optimizer steps applied.
            # TODO(@carmocca): deprecate and rename so users don't get confused
            self.global_step -= 1
            # log epoch metrics
            self.trainer.logger_connector.update_train_epoch_metrics()
            self.global_step += 1

    def on_advance_end(self) -> None:
        self.epoch_progress.increment_completed()

    def on_run_end(self) -> None:
        """Calls the ``on_train_end`` hook."""
        # NOTE: the current_epoch is already incremented
        # Lightning today does not increment the current epoch at the last epoch run in Trainer.fit
        # To simulate that current behavior, we decrement here.
        # TODO: must be fixed by https://github.com/PyTorchLightning/pytorch-lightning/issues/5007
        self.current_epoch = max(self.current_epoch - 1, 0)

        # hook
        self.trainer.call_hook("on_train_end")

        # give accelerators a chance to finish
        self.trainer.training_type_plugin.on_train_end()

    def teardown(self) -> None:
        self.epoch_loop.teardown()

    def _should_accumulate(self) -> bool:
        """Whether the gradients should be accumulated."""
        return self.epoch_loop._should_accumulate()
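
The ``done`` property above delegates its limit checks to ``_is_max_limit_reached``. A sketch consistent with how ``max_steps``/``max_epochs`` treat ``-1`` (or ``None``) as "no limit" (the actual helper lives in Lightning's utilities and may differ in detail):

from typing import Optional


def _is_max_limit_reached(current: int, maximum: Optional[int] = -1) -> bool:
    """Check whether an enabled limit has been reached.

    A ``maximum`` of ``None`` or ``-1`` disables the limit.
    """
    return maximum is not None and maximum != -1 and current >= maximum


assert _is_max_limit_reached(10, 10)        # limit hit
assert not _is_max_limit_reached(10, -1)    # -1 disables the limit
assert not _is_max_limit_reached(10, None)  # so does None
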
Example #9
class FitLoop(Loop):
    """
    This Loop iterates over the epochs to run the training.

    Args:
        min_epochs: The minimum number of epochs
        max_epochs: The maximum number of epochs
    """
    def __init__(self,
                 min_epochs: Optional[int] = None,
                 max_epochs: Optional[int] = None):
        super().__init__()
        self.max_epochs = max_epochs
        self.min_epochs = min_epochs
        self.epoch_loop: Optional[TrainingEpochLoop] = None
        self.epoch_progress = Progress()

    @property
    def current_epoch(self) -> int:
        """Return the current epoch"""
        return self.epoch_progress.current.completed

    @current_epoch.setter
    def current_epoch(self, value: int) -> None:
        """Setter for the current epoch"""
        self.epoch_progress.current.completed = value

    @property
    def global_step(self) -> int:
        """Returns the global step"""
        return self.epoch_loop.global_step

    @global_step.setter
    def global_step(self, value: int) -> None:
        """Sets the global step (forwards to epoch_loop)"""
        self.epoch_loop.global_step = value

    @property
    def total_batch_idx(self) -> int:
        """Returns the total number of batches already run (across all epochs)"""
        return self.epoch_loop.total_batch_idx

    @property
    def batch_idx(self) -> int:
        """Returns the number of batches already run within this epoch"""
        return self.epoch_loop.batch_progress.current.ready - 1

    @property
    def split_idx(self) -> int:
        """Returns the index of the current batch split (within the current batch) for bptt"""
        return self.epoch_loop.split_idx

    @property
    def min_steps(self) -> int:
        # TODO(@justusschock): Why aren't we using the attribute in this class?
        """Returns the minimum numnber of steps to run"""
        return self.epoch_loop.min_steps

    @min_steps.setter
    def min_steps(self, value: int) -> None:
        """Sets the minimum number of steps (forwards to epoch_loop)"""
        # TODO(@awaelchli): This setter is required by debugging connector (fast dev run), should be avoided
        self.epoch_loop.min_steps = value

    @property
    def max_steps(self) -> int:
        """Returns the maximum number of steps to run"""
        return self.epoch_loop.max_steps

    @max_steps.setter
    def max_steps(self, value: int) -> None:
        """Sets the maximum number of steps (forwards to epoch_loop)"""
        # TODO(@awaelchli): This setter is required by debugging connector (fast dev run), should be avoided
        self.epoch_loop.max_steps = value

    @property
    def running_loss(self) -> TensorRunningAccum:
        """Returns the running loss"""
        return self.epoch_loop.batch_loop.running_loss

    @property
    def _skip_backward(self) -> bool:
        """ Determines whether the loop will skip backward during automatic optimization. """
        return self.epoch_loop.batch_loop._skip_backward

    @_skip_backward.setter
    def _skip_backward(self, value: bool) -> None:
        """ Determines whether the loop will skip backward during automatic optimization. """
        self.epoch_loop.batch_loop._skip_backward = value

    @property
    def _results(self) -> ResultCollection:
        if self.trainer.training:
            return self.epoch_loop._results
        if self.trainer.validating:
            return self.epoch_loop.val_loop._results
        raise RuntimeError(
            "`FitLoop._results` property isn't defined. Accessed outside of scope"
        )

    @property
    def done(self) -> bool:
        """Evaluates when to leave the loop.

        Returns True if trainer.should_stop was set (e.g. by early stopping)
        or if the maximum number of steps or epochs is reached.
        """
        # TODO(@awaelchli): Move track steps inside training loop and move part of these condition inside training loop
        stop_steps = self.max_steps is not None and self.global_step >= self.max_steps
        stop_epochs = self.max_epochs is not None and self.current_epoch >= self.max_epochs

        should_stop = False
        if self.trainer.should_stop:
            # early stopping
            met_min_epochs = self.current_epoch >= self.min_epochs if self.min_epochs else True
            met_min_steps = self.global_step >= self.min_steps if self.min_steps else True
            if met_min_epochs and met_min_steps:
                should_stop = True
            else:
                log.info(
                    'Trainer was signaled to stop but required minimum epochs'
                    f' ({self.min_epochs}) or minimum steps ({self.min_steps}) has'
                    ' not been met. Training will continue...')
        self.trainer.should_stop = should_stop

        return stop_steps or should_stop or stop_epochs

    @property
    def skip(self) -> bool:
        """Whether we should skip the training and immediately return from the call to :meth:`run`."""
        return self.done or self.trainer.num_training_batches == 0

    def connect(self, epoch_loop: TrainingEpochLoop):
        """Connects a training epoch loop to this fit loop."""
        self.epoch_loop = epoch_loop

    def reset(self) -> None:
        """Resets the internal state of this loop"""

    def on_run_start(self) -> None:
        """Calls the ``on_train_start`` hook."""
        self._results.to(device=self.trainer.lightning_module.device)
        self.trainer.call_hook("on_train_start")

    def on_advance_start(self) -> None:
        """Prepares the dataloader for training and calls the hooks ``on_epoch_start`` and ``on_train_epoch_start``"""
        model = self.trainer.lightning_module

        # reset train dataloader
        if self.current_epoch != 0 and self.trainer._should_reload_dl_epoch:
            self.trainer.reset_train_dataloader(model)

        # TODO: specify the possible exception
        with suppress(Exception):
            # set seed for distributed sampler (enables shuffling for each epoch)
            self.trainer.train_dataloader.sampler.set_epoch(self.current_epoch)

        # changing gradient according accumulation_scheduler
        self.trainer.accumulation_scheduler.on_train_epoch_start(
            self.trainer, self.trainer.lightning_module)

        # stores accumulated grad fractions per batch
        self.epoch_loop.batch_loop.accumulated_loss = TensorRunningAccum(
            window_length=self.trainer.accumulate_grad_batches)

        self.epoch_progress.increment_ready()

    def advance(self) -> None:
        """Runs one whole epoch."""
        train_dataloader = self.trainer.accelerator.process_dataloader(
            self.trainer.train_dataloader)
        train_dataloader = self.trainer.data_connector.get_profiled_train_dataloader(
            train_dataloader)

        with self.trainer.profiler.profile("run_training_epoch"):
            # run train epoch
            epoch_output = self.epoch_loop.run(train_dataloader)

            if epoch_output is None:
                return

            # the global step is manually decreased here due to backwards compatibility with existing loggers
            # as they expect that the same step is used when logging epoch end metrics even when the batch loop has
            # finished. this means the attribute does not exactly track the number of optimizer steps applied.
            # TODO(@carmocca): deprecate and rename so users don't get confused
            self.global_step -= 1
            # log epoch metrics
            self.trainer.logger_connector.update_train_epoch_metrics()
            self.global_step += 1

    def on_advance_end(self) -> None:
        self.epoch_progress.increment_completed()

    def on_run_end(self) -> None:
        """Calls the ``on_train_end`` hook"""
        # NOTE: the iteration_count/current_epoch is already incremented
        # Lightning today does not increment the current epoch at the last epoch run in Trainer.fit
        # To simulate that current behavior, we decrement here.
        # TODO: must be fixed by https://github.com/PyTorchLightning/pytorch-lightning/issues/5007
        self.current_epoch -= 1

        # hook
        self.trainer.call_hook("on_train_end")

        # todo: TPU 8 cores hangs in flush with TensorBoard. Might do for all loggers.
        # It might be related to xla tensors blocked when moving the cpu
        # kill loggers
        if self.trainer.logger is not None:
            self.trainer.logger.finalize("success")

        # summarize profile results
        self.trainer.profiler.describe()

        # give accelerators a chance to finish
        self.trainer.accelerator.on_train_end()

    def should_accumulate(self) -> bool:
        """Whether the gradients should be accumulated"""
        return self.epoch_loop.batch_loop.should_accumulate()

    def teardown(self) -> None:
        self.epoch_loop.teardown()
Example #10
class ActiveLearningLoop(Loop):
    max_epochs: int
    inference_model: InferenceMCDropoutTask

    @requires(["baal", (_PL_GREATER_EQUAL_1_4_0, "pytorch-lightning>=1.4.0")])
    def __init__(self,
                 label_epoch_frequency: int,
                 inference_iteration: int = 2,
                 should_reset_weights: bool = True):
        """The `ActiveLearning Loop` describes the following training procedure. This loop is connected with the
        `ActiveLearningTrainer`

        Example::

            while unlabelled data or budget criteria not reached:

                if labelled data:
                    trainer.fit(model, labelled data)

                if unlabelled data:
                    predictions = trainer.predict(model, unlabelled data)
                    uncertainties = heuristic(predictions)
                    request labelling for the samples with the highest uncertainties under a given budget

        Args:
            label_epoch_frequency: Number of epochs to train for before requesting labelling.
            inference_iteration: Number of inference iterations to perform when computing uncertainty.
            should_reset_weights: Whether to reset the model weights before each new labelling iteration.
        """
        super().__init__()
        self.label_epoch_frequency = label_epoch_frequency
        self.inference_iteration = inference_iteration
        self.should_reset_weights = should_reset_weights
        self.fit_loop: Optional[FitLoop] = None
        self.progress = Progress()
        self._model_state_dict: Optional[Dict[str, torch.Tensor]] = None
        self._datamodule_state_dict: Optional[Dict[str, Any]] = None
        self._lightning_module: Optional[flash.Task] = None

    @property
    def done(self) -> bool:
        return self.progress.current.completed >= self.max_epochs

    def connect(self, fit_loop: FitLoop):
        self.fit_loop = fit_loop
        self.max_epochs = self.fit_loop.max_epochs
        self.fit_loop.max_epochs = self.label_epoch_frequency

    def on_run_start(self, *args: Any, **kwargs: Any) -> None:
        assert isinstance(self.trainer.datamodule, ActiveLearningDataModule)
        if self._datamodule_state_dict is not None:
            self.trainer.datamodule.load_state_dict(
                self._datamodule_state_dict)
        self.trainer.predict_loop._return_predictions = True
        self._lightning_module = self.trainer.lightning_module
        self._model_state_dict = deepcopy(self._lightning_module.state_dict())
        self.inference_model = InferenceMCDropoutTask(self._lightning_module,
                                                      self.inference_iteration)

    def reset(self) -> None:
        pass

    def on_advance_start(self, *args: Any, **kwargs: Any) -> None:
        if self.trainer.datamodule.has_labelled_data:
            self._reset_dataloader_for_stage(RunningStage.TRAINING)
            self._reset_dataloader_for_stage(RunningStage.VALIDATING)
            if self.trainer.datamodule.has_test:
                self._reset_dataloader_for_stage(RunningStage.TESTING)
        if self.trainer.datamodule.has_unlabelled_data:
            self._reset_dataloader_for_stage(RunningStage.PREDICTING)
        self.progress.increment_ready()

    def advance(self, *args: Any, **kwargs: Any) -> None:

        self.progress.increment_started()

        if self.trainer.datamodule.has_labelled_data:
            self.fit_loop.run()

        if self.trainer.datamodule.has_test:
            self._reset_testing()
            metrics = self.trainer.test_loop.run()
            if metrics:
                self.trainer.logger.log_metrics(metrics[0],
                                                step=self.trainer.global_step)

        if self.trainer.datamodule.has_unlabelled_data:
            self._reset_predicting()
            probabilities = self.trainer.predict_loop.run()
            self.trainer.datamodule.label(probabilities=probabilities)
        else:
            raise StopIteration

        self._reset_fitting()
        self.progress.increment_processed()

    def on_advance_end(self) -> None:
        if self.trainer.datamodule.has_unlabelled_data and self.should_reset_weights:
            # reload the weights to retrain from scratch with the new labelled data.
            self._lightning_module.load_state_dict(self._model_state_dict)
        self.progress.increment_completed()
        return super().on_advance_end()

    def on_run_end(self):
        self._datamodule_state_dict = self.trainer.datamodule.state_dict()
        self._reset_fitting()
        self._teardown()
        return super().on_run_end()

    def on_save_checkpoint(self) -> Dict:
        return {"datamodule_state_dict": self._datamodule_state_dict}

    def on_load_checkpoint(self, state_dict) -> None:
        self._datamodule_state_dict = state_dict.pop("datamodule_state_dict",
                                                     None)

    def __getattr__(self, key):
        if key not in self.__dict__:
            return getattr(self.fit_loop, key)
        return self.__dict__[key]

    def _connect(self, model: LightningModule):
        if _PL_GREATER_EQUAL_1_5_0:
            self.trainer.training_type_plugin.connect(model)
        else:
            self.trainer.accelerator.connect(model)

    def _reset_fitting(self):
        self.trainer.state.fn = TrainerFn.FITTING
        self.trainer.training = True
        self.trainer.lightning_module.on_train_dataloader()
        self._connect(self._lightning_module)
        self.fit_loop.epoch_progress = Progress()

    def _reset_predicting(self):
        self.trainer.state.fn = TrainerFn.PREDICTING
        self.trainer.predicting = True
        self.trainer.lightning_module.on_predict_dataloader()
        self._connect(self.inference_model)

    def _reset_testing(self):
        self.trainer.state.fn = TrainerFn.TESTING
        self.trainer.state.status = TrainerStatus.RUNNING
        self.trainer.testing = True
        self.trainer.lightning_module.on_test_dataloader()
        self._connect(self._lightning_module)

    def _reset_dataloader_for_stage(self, running_state: RunningStage):
        dataloader_name = f"{_STAGES_PREFIX[running_state]}_dataloader"
        # If the dataloader exists, we reset it.
        dataloader = (getattr(
            self.trainer.datamodule, dataloader_name) if is_overridden(
                dataloader_name, self.trainer.datamodule) else None)

        if dataloader:
            if _PL_GREATER_EQUAL_1_5_0:
                setattr(
                    self.trainer._data_connector,
                    f"_{dataloader_name}_source",
                    _DataLoaderSource(self.trainer.datamodule,
                                      dataloader_name),
                )
            else:
                setattr(
                    self.trainer.lightning_module,
                    dataloader_name,
                    _PatchDataLoader(dataloader(), running_state),
                )
            setattr(self.trainer, dataloader_name, None)
            # TODO: Resolve this within PyTorch Lightning.
            try:
                getattr(self.trainer, f"reset_{dataloader_name}")(
                    self.trainer.lightning_module)
            except MisconfigurationException:
                pass

    def _teardown(self) -> None:
        self.trainer.train_dataloader = None
        self.trainer.val_dataloaders = None
        self.trainer.test_dataloaders = None
        self.trainer.predict_dataloaders = None
        # Hack
        self.trainer.lightning_module.train_dataloader = None
        self.trainer.lightning_module.val_dataloader = None
        self.trainer.lightning_module.test_dataloader = None
        self.trainer.lightning_module.predict_dataloader = None
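
Stripped of Trainer plumbing, the control flow of ``ActiveLearningLoop`` reduces to the classic pool-based active-learning cycle sketched in its docstring. A self-contained toy version (every name here is a stand-in, not a Flash or Lightning API):

import random
from typing import List


def uncertainty(prediction: float) -> float:
    # toy heuristic: probabilities close to 0.5 are the most uncertain
    return 1.0 - abs(prediction - 0.5) * 2


def active_learning_cycle(labelled: List[int], unlabelled: List[int],
                          budget: int, query_size: int = 2) -> List[int]:
    while unlabelled and budget > 0:
        if labelled:
            pass  # trainer.fit(model, labelled) would run here

        # predict on the unlabelled pool and rank samples by uncertainty
        preds = {x: random.random() for x in unlabelled}
        ranked = sorted(unlabelled, key=lambda x: uncertainty(preds[x]),
                        reverse=True)

        # request labelling for the most uncertain samples under the budget
        queried = ranked[:min(query_size, budget)]
        for x in queried:
            unlabelled.remove(x)
            labelled.append(x)
        budget -= len(queried)
    return labelled


print(active_learning_cycle(labelled=[1, 2], unlabelled=[3, 4, 5, 6], budget=3))
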
class TrainingEpochLoop(loops.Loop):
    """
    Runs over all batches in a dataloader (one epoch).

    Args:
        min_steps: The minimum number of steps (batches) to process
        max_steps: The maximum number of steps (batches) to process
    """
    def __init__(self, min_steps: int, max_steps: int):
        super().__init__()
        self.min_steps: int = min_steps
        self.max_steps: int = max_steps
        self.global_step: int = 0
        # the total batch index across all epochs
        self.total_batch_idx: int = 0
        # the current split index when the batch gets split into chunks in truncated backprop through time
        self.split_idx: Optional[int] = None
        # the number of batches seen this run, updates immediately after batch_loop.run()
        # TODO: replace by progress tracking
        self.batches_seen: int = 0
        self.is_last_batch: Optional[bool] = None
        self.batch_progress = Progress()
        self.scheduler_progress = SchedulerProgress()

        self.batch_loop: Optional[TrainingBatchLoop] = None
        self.val_loop: Optional["loops.EvaluationLoop"] = None

        self._results = ResultCollection(training=True)
        self._dataloader_idx: Optional[int] = None
        self._warning_cache: WarningCache = WarningCache()
        self._epoch_output: Optional[List[List[STEP_OUTPUT]]] = None

    @property
    def batch_idx(self) -> int:
        """Returns the current batch index (within this epoch)"""
        return self.iteration_count

    @property
    def done(self) -> bool:
        """Returns whether the training should be stopped.
        The criteria are that the number of steps reached the max steps,
        the last batch is reached or the trainer signals to stop (e.g. by early stopping).
        """
        max_steps_reached = self.max_steps is not None and self.global_step >= self.max_steps
        return max_steps_reached or self.trainer.should_stop or self._num_training_batches_reached(
            self.is_last_batch)

    def connect(self,
                batch_loop: Optional[TrainingBatchLoop] = None,
                val_loop: Optional["loops.EvaluationLoop"] = None) -> None:
        """Optionally connect a custom batch or validation loop to this training epoch loop."""
        if batch_loop is not None:
            self.batch_loop = batch_loop
        if val_loop is not None:
            self.val_loop = val_loop

    def reset(self) -> None:
        """Resets the internal state of the loop for a new run"""
        self.iteration_count = 0
        self.batches_seen = 0
        self.is_last_batch = False
        self._dataloader_idx = 0

        # track epoch output
        self._epoch_output = [[] for _ in range(
            self.batch_loop.num_active_optimizers(self.total_batch_idx))]

        if self.restarting:
            self.iteration_count = self.batches_seen = self.batch_progress.current.completed
        else:
            self.batch_progress.current.reset()
            self.scheduler_progress.current.reset()
            self.batch_loop.optim_progress.reset_on_epoch()

    def on_run_start(self, *args: Any, **kwargs: Any) -> None:
        # hook
        self.trainer.logger_connector.on_epoch_start()
        self.trainer.call_hook("on_epoch_start")
        self.trainer.call_hook("on_train_epoch_start")
        self.trainer.fit_loop.epoch_progress.increment_started()

    def advance(self, dataloader_iter: Iterator, **kwargs: Any) -> None:
        """Runs a single training batch.

        Args:
            dataloader_iter: the iterator over the dataloader producing the new batch

        Raises:
            StopIteration: When the epoch is canceled by the user returning -1
        """
        _, (batch, is_last) = next(dataloader_iter)
        self.is_last_batch = is_last

        # ------------------------------------
        # TRAINING_STEP + TRAINING_STEP_END
        # ------------------------------------
        with self.trainer.profiler.profile("training_batch_to_device"):
            batch = self.trainer.accelerator.batch_to_device(
                batch, dataloader_idx=self._dataloader_idx)

        self.batch_progress.increment_ready()

        with self.trainer.profiler.profile("run_training_batch"):
            batch_output = self.batch_loop.run(batch, self.iteration_count,
                                               self._dataloader_idx)
            self.batches_seen += 1

        self.batch_progress.increment_processed()

        # when returning -1 from train_step, we end epoch early
        if batch_output.signal == -1:
            raise StopIteration

        # update non-plateau LR schedulers
        # update epoch-interval ones only when we are at the end of training epoch
        self.update_lr_schedulers("step", update_plateau_schedulers=False)
        if self._num_training_batches_reached(is_last):
            self.update_lr_schedulers("epoch", update_plateau_schedulers=False)

        batch_end_outputs = [
            opt_idx_out for opt_idx_out in batch_output.training_step_output
            if len(opt_idx_out)
        ]
        processed_batch_end_outputs = self._prepare_outputs(batch_end_outputs,
                                                            batch_mode=True)

        # hook
        self.trainer.call_hook("on_train_batch_end",
                               processed_batch_end_outputs, batch,
                               self.iteration_count, self._dataloader_idx)
        self.trainer.call_hook("on_batch_end")
        self.trainer.logger_connector.on_batch_end()

        self.batch_progress.increment_completed()

        # figure out what to track for epoch end
        self._track_epoch_end_reduce_metrics(self._epoch_output,
                                             batch_end_outputs)

        # -----------------------------------------
        # SAVE METRICS TO LOGGERS AND PROGRESS_BAR
        # -----------------------------------------
        self.trainer.logger_connector.update_train_step_metrics()

    def on_advance_end(self):
        """Runs validation and Checkpointing if necessary.

        Raises:
            StopIteration: if :attr:`done` evaluates to ``True`` to finish this epoch
        """
        # -----------------------------------------
        # VALIDATE IF NEEDED + CHECKPOINT CALLBACK
        # -----------------------------------------
        should_check_val = self._should_check_val_fx(self.iteration_count,
                                                     self.is_last_batch)
        if should_check_val:
            self.trainer.validating = True
            self._run_validation()
            self.trainer.training = True

        # -----------------------------------------
        # SAVE LOGGERS (ie: Tensorboard, etc...)
        # -----------------------------------------
        self._save_loggers_on_train_batch_end()

        # update plateau LR scheduler after metrics are logged
        self.update_lr_schedulers("step", update_plateau_schedulers=True)

        self.total_batch_idx += 1

        # progress global step according to grads progress
        self._increment_accumulated_grad_global_step()

        if self.done:
            raise StopIteration

    def on_run_end(self) -> List[List[STEP_OUTPUT]]:
        """Calls the on_epoch_end hook.

        Returns:
            The output of each training step for each optimizer

        Raises:
            MisconfigurationException: if ``training_epoch_end`` does not return ``None``.
        """
        if self.batches_seen == 0:
            # dataloader/iterator did not produce a batch
            return

        # inform logger the batch loop has finished
        self.trainer.logger_connector.epoch_end_reached()

        # prepare epoch output
        processed_outputs = self._prepare_outputs(self._epoch_output,
                                                  batch_mode=False)

        # get the model and call model.training_epoch_end
        model = self.trainer.lightning_module

        if is_overridden("training_epoch_end", model):
            # run training_epoch_end
            # refresh the result for custom logging at the epoch level
            model._current_fx_name = "training_epoch_end"

            # lightningmodule hook
            training_epoch_end_output = model.training_epoch_end(
                processed_outputs)

            if training_epoch_end_output is not None:
                raise MisconfigurationException(
                    "training_epoch_end expects a return of None. "
                    "HINT: remove the return statement in training_epoch_end")

        self.trainer.fit_loop.epoch_progress.increment_processed()

        # call train epoch end hooks
        self._on_train_epoch_end_hook(processed_outputs)
        self.trainer.call_hook("on_epoch_end")
        self.trainer.logger_connector.on_epoch_end()

        self.update_lr_schedulers("epoch", update_plateau_schedulers=True)

        epoch_output = self._epoch_output
        # free memory
        self._epoch_output = None
        return epoch_output

    def teardown(self) -> None:
        self._results.cpu()
        self.batch_loop.teardown()
        self.val_loop.teardown()

    def _run_validation(self):
        # reload dataloaders
        self.val_loop.reload_evaluation_dataloaders()

        with torch.no_grad():
            self.val_loop.run()

    def _on_train_epoch_end_hook(
            self, processed_epoch_output: List[List[STEP_OUTPUT]]) -> None:
        """Runs ``on_train_epoch_end hook``."""
        # We cannot rely on Trainer.call_hook because the signatures might be different across
        # lightning module and callback
        # As a result, we need to inspect if the module accepts `outputs` in `on_train_epoch_end`

        # This implementation is copied from Trainer.call_hook
        hook_name = "on_train_epoch_end"
        prev_fx_name = self.trainer.lightning_module._current_fx_name
        self.trainer.lightning_module._current_fx_name = hook_name

        # always profile hooks
        with self.trainer.profiler.profile(hook_name):

            # first call trainer hook
            if hasattr(self.trainer, hook_name):
                trainer_hook = getattr(self.trainer, hook_name)
                trainer_hook(processed_epoch_output)

            # next call the hook in the LightningModule
            model_ref = self.trainer.lightning_module
            if is_overridden(hook_name, model_ref):
                hook_fx = getattr(model_ref, hook_name)
                if is_param_in_hook_signature(hook_fx, "outputs"):
                    self._warning_cache.deprecation(
                        "The signature of `ModelHooks.on_train_epoch_end` has changed in v1.3."
                        " `outputs` parameter has been deprecated."
                        " Support for the old signature will be removed in v1.5"
                    )
                    model_ref.on_train_epoch_end(processed_epoch_output)
                else:
                    model_ref.on_train_epoch_end()

            # call the accelerator hook
            if hasattr(self.trainer.accelerator, hook_name):
                accelerator_hook = getattr(self.trainer.accelerator, hook_name)
                accelerator_hook()

        # restore current_fx when nested context
        self.trainer.lightning_module._current_fx_name = prev_fx_name

    def _num_training_batches_reached(self,
                                      is_last_batch: bool = False) -> bool:
        """Checks if we are in the last batch or if there are more batches to follow."""

        # TODO: Can we combine this with training_batch_loop's arg that does a similar check?
        return self.batches_seen == self.trainer.num_training_batches or is_last_batch

    def _track_epoch_end_reduce_metrics(
            self, epoch_output: List[List[STEP_OUTPUT]],
            batch_end_outputs: STEP_OUTPUT) -> None:
        """Adds the batch outputs to the epoch outputs and prepares reduction"""
        hook_overridden = self._should_add_batch_output_to_epoch_output()
        if not hook_overridden:
            return

        # track the outputs to reduce at the end of the epoch
        for opt_idx, opt_outputs in enumerate(batch_end_outputs):
            # with 1 step (no tbptt) don't use a sequence at epoch end
            if (isinstance(opt_outputs, list) and len(opt_outputs) == 1
                    and not isinstance(opt_outputs[0], ResultCollection)):
                opt_outputs = opt_outputs[0]

            epoch_output[opt_idx].append(opt_outputs)

    def _should_add_batch_output_to_epoch_output(self) -> bool:
        """
        We add to the epoch outputs if
        1. The model defines training_epoch_end OR
        2. The model overrides on_train_epoch_end which has `outputs` in the signature
        """
        # TODO: in v1.5 this only needs to check if training_epoch_end is overridden
        lightning_module = self.trainer.lightning_module
        if is_overridden("training_epoch_end", lightning_module):
            return True

        if is_overridden("on_train_epoch_end", lightning_module):
            model_hook_fx = getattr(lightning_module, "on_train_epoch_end")
            if is_param_in_hook_signature(model_hook_fx, "outputs"):
                return True

        return False

    @staticmethod
    def _prepare_outputs(
        outputs: List[List[List["ResultCollection"]]], batch_mode: bool
    ) -> Union[List[List[List[Dict]]], List[List[Dict]], List[Dict], Dict]:
        """
        Extract required information from batch or epoch end results.

        Args:
            outputs: A 3-dimensional list of ``ResultCollection`` objects with dimensions:
                ``[optimizer outs][batch outs][tbptt steps]``.

            batch_mode: If True, ignore the batch output dimension.

        Returns:
            The cleaned outputs with ``ResultCollection`` objects converted to dictionaries.
            All list dimensions of size one will be collapsed.
        """
        processed_outputs = []
        for opt_outputs in outputs:
            # handle an edge case where an optimizer output is the empty list
            if len(opt_outputs) == 0:
                continue

            processed_batch_outputs = []

            if batch_mode:
                opt_outputs = [opt_outputs]

            for batch_outputs in opt_outputs:
                processed_tbptt_outputs = []

                if isinstance(batch_outputs, ResultCollection):
                    batch_outputs = [batch_outputs]

                for tbptt_output in batch_outputs:
                    out = tbptt_output.extra
                    if tbptt_output.minimize is not None:
                        out["loss"] = tbptt_output.minimize.detach()
                    processed_tbptt_outputs.append(out)

                # if there was only one tbptt step then we can collapse that dimension
                if len(processed_tbptt_outputs) == 1:
                    processed_tbptt_outputs = processed_tbptt_outputs[0]
                processed_batch_outputs.append(processed_tbptt_outputs)

            # batch_outputs should be just one dict (or a list of dicts if using tbptt) per optimizer
            if batch_mode:
                processed_batch_outputs = processed_batch_outputs[0]
            processed_outputs.append(processed_batch_outputs)

        # if there is only one optimiser then we collapse that dimension
        if len(processed_outputs) == 1:
            processed_outputs = processed_outputs[0]
        return processed_outputs

    def update_lr_schedulers(self, interval: str,
                             update_plateau_schedulers: bool) -> None:
        """updates the lr schedulers based on the given interval"""
        if interval == "step" and self.batch_loop.should_accumulate():
            return
        self.trainer.optimizer_connector.update_learning_rates(
            interval=interval,
            update_plateau_schedulers=update_plateau_schedulers,
            opt_indices=[
                opt_idx for opt_idx, _ in
                self.batch_loop.get_active_optimizers(self.total_batch_idx)
            ],
        )

    def _increment_accumulated_grad_global_step(self) -> None:
        """increments global step"""
        num_accumulated_batches_reached = self.batch_loop._accumulated_batches_reached(
        )
        num_training_batches_reached = self._num_training_batches_reached()

        # progress global step according to grads progress
        if num_accumulated_batches_reached or num_training_batches_reached:
            self.global_step = self.trainer.accelerator.update_global_step(
                self.total_batch_idx, self.trainer.global_step)

    def _should_check_val_fx(self, batch_idx: int,
                             is_last_batch: bool) -> bool:
        """Decide if we should run validation."""
        if not self.trainer.enable_validation:
            return False

        is_val_check_epoch = (self.trainer.current_epoch +
                              1) % self.trainer.check_val_every_n_epoch == 0
        if not is_val_check_epoch:
            return False

        # val_check_batch is inf for iterable datasets with no length defined
        is_infinite_dataset = self.trainer.val_check_batch == float("inf")
        if is_last_batch and is_infinite_dataset:
            return True

        if self.trainer.should_stop:
            return True

        # TODO(@awaelchli): let training/eval loop handle logic around limit_*_batches and val_check_batch
        is_val_check_batch = is_last_batch
        if isinstance(self.trainer.limit_train_batches,
                      int) and is_infinite_dataset:
            is_val_check_batch = (batch_idx +
                                  1) % self.trainer.limit_train_batches == 0
        elif self.trainer.val_check_batch != float("inf"):
            is_val_check_batch = (batch_idx +
                                  1) % self.trainer.val_check_batch == 0
        return is_val_check_batch
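        # e.g. with `val_check_batch=4`, validation runs after batches with
        # `batch_idx` 3, 7, 11, ..., since `(batch_idx + 1) % 4 == 0` there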

    def _save_loggers_on_train_batch_end(self) -> None:
        """Flushes loggers to disk"""
        # when loggers should save to disk
        should_flush_logs = self.trainer.logger_connector.should_flush_logs
        if should_flush_logs and self.trainer.is_global_zero and self.trainer.logger is not None:
            self.trainer.logger.save()
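
``_prepare_outputs`` above collapses every singleton dimension of the ``[optimizer][batch][tbptt]`` nesting before handing results to the user. A toy illustration of just the collapsing rule, using plain dicts instead of ``ResultCollection`` (``collapse`` is a hypothetical helper, not the Lightning method):

from typing import Any


def collapse(outputs: Any) -> Any:
    # recursively unwrap any list dimension of size one, mirroring how
    # `_prepare_outputs` drops single-optimizer / single-tbptt dimensions
    if isinstance(outputs, list):
        outputs = [collapse(o) for o in outputs]
        if len(outputs) == 1:
            return outputs[0]
    return outputs


# one optimizer, two batches, one tbptt step each:
nested = [[[{"loss": 0.1}], [{"loss": 0.2}]]]
# both singleton dimensions (optimizer and tbptt) disappear
assert collapse(nested) == [{"loss": 0.1}, {"loss": 0.2}]
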
class EvaluationEpochLoop(Loop):
    """This is the loop performing the evaluation.

    It mainly loops over the given dataloader and runs the validation or test step (depending on the trainer's current
    state).
    """
    def __init__(self) -> None:
        super().__init__()
        self.dataloader: Optional[Iterator] = None
        self._dl_max_batches: Optional[int] = None
        self._num_dataloaders: Optional[int] = None
        self.outputs: EPOCH_OUTPUT = []
        self.batch_progress = Progress()
        self.dataloader_iter: Optional[Iterator] = None

    @property
    def done(self) -> bool:
        """Returns ``True`` if the current iteration count reaches the number of dataloader batches."""
        return self.batch_progress.current.completed >= self._dl_max_batches

    def connect(self, **kwargs: "Loop") -> None:
        raise NotImplementedError(
            f"{self.__class__.__name__} does not connect any child loops.")

    def reset(self) -> None:
        """Resets the loop's internal state."""
        self._dl_max_batches = None
        self._num_dataloaders = None
        self.outputs = []

        if not self.restarting:
            self.batch_progress.current.reset()

    def on_run_start(self, data_fetcher: AbstractDataFetcher,
                     dataloader_idx: int, dl_max_batches: int,
                     num_dataloaders: int) -> None:
        """Adds the passed arguments to the loop's state if necessary.

        Args:
            data_fetcher: the current data_fetcher wrapping the dataloader
            dataloader_idx: index of the current dataloader
            dl_max_batches: maximum number of batches the dataloader can produce
            num_dataloaders: the total number of dataloaders
        """
        void(dataloader_idx)
        self._dl_max_batches = dl_max_batches
        self._num_dataloaders = num_dataloaders

        self.dataloader_iter = _prepare_dataloader_iter(
            data_fetcher, self.batch_progress.current.ready)

    def advance(self, data_fetcher: AbstractDataFetcher, dataloader_idx: int,
                dl_max_batches: int, num_dataloaders: int) -> None:
        """Calls the evaluation step with the corresponding hooks and updates the logger connector.

        Args:
            data_fetcher: the fetcher wrapping the current dataloader
            dataloader_idx: index of the current dataloader
            dl_max_batches: maximum number of batches the dataloader can produce
            num_dataloaders: the total number of dataloaders

        Raises:
            StopIteration: If the current batch is None
        """
        void(data_fetcher, dl_max_batches, num_dataloaders)

        batch_idx, (batch, _) = next(self.dataloader_iter)

        if batch is None:
            raise StopIteration

        if not self.trainer.data_connector.evaluation_data_fetcher.store_on_device:
            with self.trainer.profiler.profile("evaluation_batch_to_device"):
                batch = self.trainer.accelerator.batch_to_device(
                    batch, dataloader_idx=dataloader_idx)

        self.batch_progress.increment_ready()

        # hook
        self.on_evaluation_batch_start(batch, batch_idx, dataloader_idx)

        self.batch_progress.increment_started()

        # lightning module methods
        with self.trainer.profiler.profile("evaluation_step_and_end"):
            output = self.evaluation_step(batch, batch_idx, dataloader_idx)
            output = self.evaluation_step_end(output)

        self.batch_progress.increment_processed()

        # track loss history
        self.on_evaluation_batch_end(output, batch, batch_idx, dataloader_idx)

        self.batch_progress.increment_completed()

        # log batch metrics
        self.trainer.logger_connector.update_eval_step_metrics()

        # track epoch level outputs
        if self._should_track_batch_outputs_for_epoch_end():
            output = recursive_detach(output,
                                      to_cpu=self.trainer.move_metrics_to_cpu)
            if output is not None:
                self.outputs.append(output)

    def on_run_end(self) -> EPOCH_OUTPUT:
        """Returns the outputs of the whole run."""
        outputs = self.outputs
        # free memory
        self.outputs = []
        return outputs

    def evaluation_step(self, batch: Any, batch_idx: int,
                        dataloader_idx: int) -> Optional[STEP_OUTPUT]:
        """The evaluation step (validation_step or test_step depending on the trainer's state).

        Args:
            batch: The current batch to run through the step.
            batch_idx: The index of the current batch
            dataloader_idx: the index of the dataloader producing the current batch

        Returns:
            the outputs of the step
        """
        # configure step_kwargs
        step_kwargs = self._build_kwargs(batch, batch_idx, dataloader_idx)

        if self.trainer.testing:
            self.trainer.lightning_module._current_fx_name = "test_step"
            with self.trainer.profiler.profile("test_step"):
                output = self.trainer.accelerator.test_step(step_kwargs)
        else:
            self.trainer.lightning_module._current_fx_name = "validation_step"
            with self.trainer.profiler.profile("validation_step"):
                output = self.trainer.accelerator.validation_step(step_kwargs)

        return output

    def evaluation_step_end(self, *args: Any,
                            **kwargs: Any) -> Optional[STEP_OUTPUT]:
        """Calls the `{validation/test}_step_end` hook."""
        hook_name = "test_step_end" if self.trainer.testing else "validation_step_end"
        output = self.trainer.call_hook(hook_name, *args, **kwargs)
        return output

    def on_evaluation_batch_start(self, batch: Any, batch_idx: int,
                                  dataloader_idx: int) -> None:
        """Calls the ``on_{validation/test}_batch_start`` hook.

        Args:
            batch: The current batch to run through the step
            batch_idx: The index of the current batch
            dataloader_idx: The index of the dataloader producing the current batch

        Raises:
            AssertionError: If the number of dataloaders is None (has not yet been set).
        """
        self.trainer.logger_connector.on_batch_start()

        assert self._num_dataloaders is not None
        self.trainer.logger_connector.on_evaluation_batch_start(
            batch, batch_idx, dataloader_idx, self._num_dataloaders)

        if self.trainer.testing:
            self.trainer.call_hook("on_test_batch_start", batch, batch_idx,
                                   dataloader_idx)
        else:
            self.trainer.call_hook("on_validation_batch_start", batch,
                                   batch_idx, dataloader_idx)

    def on_evaluation_batch_end(self, output: Optional[STEP_OUTPUT],
                                batch: Any, batch_idx: int,
                                dataloader_idx: int) -> None:
        """The ``on_{validation/test}_batch_end`` hook.

        Args:
            output: The output of the performed step
            batch: The input batch for the step
            batch_idx: The index of the current batch
            dataloader_idx: Index of the dataloader producing the current batch
        """
        hook_name = "on_test_batch_end" if self.trainer.testing else "on_validation_batch_end"
        self.trainer.call_hook(hook_name, output, batch, batch_idx,
                               dataloader_idx)

        self.trainer.logger_connector.on_batch_end()

    def _build_kwargs(self, batch: Any, batch_idx: int,
                      dataloader_idx: int) -> Dict[str, Union[Any, int]]:
        """Helper function to build the arguments for the current step.

        Args:
            batch: The current batch to run through the step
            batch_idx: the index of the current batch
            dataloader_idx: the index of the dataloader producing the current batch

        Returns:
            the keyword arguments to pass to the step function
        """
        # make dataloader_idx arg in validation_step optional
        step_kwargs = OrderedDict([("batch", batch), ("batch_idx", batch_idx)])

        multiple_val_loaders = not self.trainer.testing and self._num_dataloaders > 1
        multiple_test_loaders = self.trainer.testing and self._num_dataloaders > 1

        if multiple_test_loaders or multiple_val_loaders:
            step_kwargs["dataloader_idx"] = dataloader_idx

        return step_kwargs

    @lru_cache(1)
    def _should_track_batch_outputs_for_epoch_end(self) -> bool:
        """Whether the batch outputs should be stored for later usage."""
        model = self.trainer.lightning_module
        if self.trainer.testing:
            return is_overridden("test_epoch_end", model)
        return is_overridden("validation_epoch_end", model)

    def teardown(self) -> None:
        # in case the model changes
        self._should_track_batch_outputs_for_epoch_end.cache_clear()
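
Because ``_should_track_batch_outputs_for_epoch_end`` is memoized with ``lru_cache(1)``, the ``cache_clear()`` in ``teardown`` is what keeps a stale answer from leaking into the next run if the model changes. A minimal illustration of the pitfall (``Probe`` is a toy class, not Lightning code):

from functools import lru_cache


class Probe:
    def __init__(self) -> None:
        self.overrides_hook = False

    @lru_cache(1)
    def should_track(self) -> bool:
        return self.overrides_hook


probe = Probe()
assert probe.should_track() is False
probe.overrides_hook = True
# stale: the cached result survives even though the flag changed
assert probe.should_track() is False
probe.should_track.cache_clear()
assert probe.should_track() is True
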
Example #13
class FitLoop(Loop[None]):
    """This Loop iterates over the epochs to run the training.

    Args:
        min_epochs: The minimum number of epochs
        max_epochs: The maximum number of epochs, can be set -1 to turn this limit off
    """
    def __init__(
        self,
        min_epochs: Optional[int] = 1,
        max_epochs: int = 1000,
    ) -> None:
        super().__init__()
        if max_epochs < -1:
            # Allow max_epochs to be zero, since this will be handled by fit_loop.done
            raise MisconfigurationException(
                f"`max_epochs` must be a non-negative integer or -1. You passed in {max_epochs}."
            )

        self.max_epochs = max_epochs
        self.min_epochs = min_epochs
        self.epoch_loop = TrainingEpochLoop()
        self.epoch_progress = Progress()

        self._is_fresh_start_epoch: bool = True
        self._outputs: _EPOCH_OUTPUTS_TYPE = []

    @property
    def global_step(self) -> int:
        """Returns the global step."""
        return self.epoch_loop.global_step

    @global_step.setter
    def global_step(self, value: int) -> None:
        """Sets the global step (forwards to epoch_loop)"""
        self.epoch_loop.global_step = value

    @property
    def total_batch_idx(self) -> int:
        """Returns the current batch index (across epochs)"""
        return self.epoch_loop.total_batch_idx

    @property
    def batch_idx(self) -> int:
        """Returns the current batch index (within this epoch)"""
        return self.epoch_loop.batch_idx

    @property
    def split_idx(self) -> int:
        """Returns the index of the current batch split (within the current batch) for bptt."""
        return self.epoch_loop.batch_loop.split_idx

    @property
    def min_steps(self) -> Optional[int]:
        # TODO(@justusschock): Why aren't we using the attribute in this class?
        """Returns the minimum numnber of steps to run."""
        return self.epoch_loop.min_steps

    @min_steps.setter
    def min_steps(self, value: Optional[int]) -> None:
        """Sets the minimum number of steps (forwards to epoch_loop)"""
        # TODO(@awaelchli): This setter is required by debugging connector (fast dev run), should be avoided
        self.epoch_loop.min_steps = value

    @property
    def max_steps(self) -> int:
        """Returns the maximum number of steps to run."""
        return self.epoch_loop.max_steps

    @max_steps.setter
    def max_steps(self, value: int) -> None:
        """Sets the maximum number of steps (forwards to epoch_loop)"""
        # TODO(@awaelchli): This setter is required by debugging connector (fast dev run), should be avoided
        if value is None:
            rank_zero_deprecation(
                "Setting `max_steps = None` is deprecated in v1.5 and will no longer be supported in v1.7."
                " Use `max_steps = -1` instead.")
            value = -1
        elif value < -1:
            raise MisconfigurationException(
                f"`max_steps` must be a non-negative integer or -1 (infinite steps). You passed in {value}."
            )
        self.epoch_loop.max_steps = value

    @property
    def running_loss(self) -> TensorRunningAccum:
        """Returns the running loss."""
        return self.epoch_loop.batch_loop.running_loss

    @property
    def _skip_backward(self) -> bool:
        """Determines whether the loop will skip backward during automatic optimization."""
        return self.epoch_loop.batch_loop.optimizer_loop._skip_backward

    @_skip_backward.setter
    def _skip_backward(self, value: bool) -> None:
        """Determines whether the loop will skip backward during automatic optimization."""
        self.epoch_loop.batch_loop.optimizer_loop._skip_backward = value

    @property
    def _results(self) -> _ResultCollection:
        if self.trainer.training:
            return self.epoch_loop._results
        if self.trainer.validating:
            return self.epoch_loop.val_loop._results
        raise RuntimeError(
            "`FitLoop._results` property isn't defined. Accessed outside of"
            " the training or validation scope.")

    @property
    def done(self) -> bool:
        """Evaluates when to leave the loop."""
        # TODO(@awaelchli): Move step tracking inside the training loop and move part of these conditions inside the training loop
        stop_steps = _is_max_limit_reached(self.global_step, self.max_steps)
        stop_epochs = _is_max_limit_reached(
            self.epoch_progress.current.completed, self.max_epochs)

        should_stop = False
        if self.trainer.should_stop:
            # early stopping
            met_min_epochs = self.epoch_progress.current.completed >= self.min_epochs if self.min_epochs else True
            met_min_steps = self.global_step >= self.min_steps if self.min_steps else True
            if met_min_epochs and met_min_steps:
                should_stop = True
            else:
                log.info(
                    "Trainer was signaled to stop but required minimum epochs"
                    f" ({self.min_epochs}) or minimum steps ({self.min_steps}) has"
                    " not been met. Training will continue...")
        self.trainer.should_stop = should_stop

        return stop_steps or should_stop or stop_epochs or self.trainer.num_training_batches == 0
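
    # A minimal sketch of the `_is_max_limit_reached` helper used above, assuming
    # it simply treats -1 as "no limit" (consistent with the `max_epochs` and
    # `max_steps` documentation in this class):
    #
    #     def _is_max_limit_reached(current: int, maximum: int = -1) -> bool:
    #         return maximum != -1 and current >= maximum
    #
    #   _is_max_limit_reached(10, 10)  # True  -> limit reached, loop is done
    #   _is_max_limit_reached(10, -1)  # False -> no limit, keep training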

    @property
    def skip(self) -> bool:
        """Whether we should skip the training and immediately return from the call to :meth:`run`."""
        # since `trainer.num_training_batches` depends on the `train_dataloader` but that won't be called
        # until `on_run_start`, we use `limit_train_batches` instead
        return self.done or self.trainer.limit_train_batches == 0

    def connect(
            self,
            epoch_loop: TrainingEpochLoop) -> None:  # type: ignore[override]
        """Connects a training epoch loop to this fit loop."""
        self.epoch_loop = epoch_loop

    def reset(self) -> None:
        """Resets the internal state of this loop."""
        if self.restarting:
            self.epoch_progress.reset_on_restart()

    def on_run_start(self) -> None:  # type: ignore[override]
        """Calls the ``on_train_start`` hook."""
        # reset train dataloader and val dataloader
        self.trainer.reset_train_val_dataloaders(self.trainer.lightning_module)

        ft_enabled = _FaultTolerantMode.detect_current_mode().is_enabled
        if not ft_enabled and self.restarting and self.trainer.num_training_batches not in (
                0, float("inf")):
            self.trainer.accumulate_grad_batches = self.trainer.accumulation_scheduler.get_accumulate_grad_batches(
                self.trainer.current_epoch)
            expected_steps = math.ceil(self.trainer.num_training_batches /
                                       self.trainer.accumulate_grad_batches)

            # global_step is incremented during checkpointing (#11555)
            if (self.trainer.global_step - 1) % expected_steps != 0:
                rank_zero_warn(
                    "You're resuming from a checkpoint that ended mid-epoch."
                    " Training will start from the beginning of the next epoch."
                    " This can cause unreliable results if further training is done,"
                    " consider using an end of epoch checkpoint or use fault-tolerant training"
                    " to restart as if training did not stop.")

        self._is_fresh_start_epoch = True
        self._results.to(device=self.trainer.lightning_module.device)
        self.trainer._call_callback_hooks("on_train_start")
        self.trainer._call_lightning_module_hook("on_train_start")
        self.trainer._call_strategy_hook("on_train_start")

    def on_advance_start(self) -> None:  # type: ignore[override]
        """Prepares the dataloader for training and calls the hooks ``on_epoch_start`` and
        ``on_train_epoch_start``."""
        model = self.trainer.lightning_module

        # reset train dataloader
        if not self._is_fresh_start_epoch and self.trainer._data_connector._should_reload_train_dl:
            log.detail(
                f"{self.__class__.__name__}: resetting train dataloader")
            self.trainer.reset_train_dataloader(model)
        self._is_fresh_start_epoch = False

        # reset outputs here instead of in `reset` as they are not accumulated between epochs
        self._outputs = []

        if self.trainer.train_dataloader is not None and callable(
                getattr(self.trainer.train_dataloader.sampler, "set_epoch",
                        None)):
            # set seed for distributed sampler (enables shuffling for each epoch)
            self.trainer.train_dataloader.sampler.set_epoch(
                self.epoch_progress.current.completed)
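
        # The check above duck-types the sampler: any object exposing `set_epoch`
        # is reseeded. A minimal sketch of a sampler it would pick up
        # (hypothetical class, for illustration only):
        #
        #     class EpochAwareSampler:
        #         def __init__(self) -> None:
        #             self.epoch = 0
        #
        #         def set_epoch(self, epoch: int) -> None:
        #             # store the epoch so iteration order can differ per epoch,
        #             # as torch.utils.data.DistributedSampler does
        #             self.epoch = epoch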

        # change gradient accumulation according to the accumulation_scheduler
        self.trainer.accumulation_scheduler.on_train_epoch_start(
            self.trainer, self.trainer.lightning_module)

        # stores accumulated grad fractions per batch
        self.epoch_loop.batch_loop.accumulated_loss.reset(
            window_length=self.trainer.accumulate_grad_batches)

        self.epoch_progress.increment_ready()

        self.trainer.logger_connector.on_epoch_start()

        self.trainer._call_callback_hooks("on_epoch_start")
        self.trainer._call_lightning_module_hook("on_epoch_start")

        self.trainer._call_callback_hooks("on_train_epoch_start")
        self.trainer._call_lightning_module_hook("on_train_epoch_start")

        self.epoch_progress.increment_started()

    def advance(self) -> None:  # type: ignore[override]
        """Runs one whole epoch."""
        log.detail(f"{self.__class__.__name__}: advancing loop")
        assert self.trainer.train_dataloader is not None
        dataloader = self.trainer.strategy.process_dataloader(
            self.trainer.train_dataloader)
        data_fetcher = self.trainer._data_connector.get_profiled_dataloader(
            dataloader, 0)

        with self.trainer.profiler.profile("run_training_epoch"):
            self._outputs = self.epoch_loop.run(data_fetcher)

    def on_advance_end(self) -> None:
        # inform the logger that the batch loop has finished
        self.trainer.logger_connector.epoch_end_reached()

        # get the model and call model.training_epoch_end
        model = self.trainer.lightning_module
        if is_overridden("training_epoch_end", model) and self._outputs:
            epoch_end_outputs = self.epoch_loop._prepare_outputs_training_epoch_end(
                self._outputs,
                automatic=model.automatic_optimization,
                num_optimizers=len(self.trainer.optimizers),
            )
            # run lightning module hook training_epoch_end
            # refresh the result for custom logging at the epoch level
            epoch_end_outputs = self.trainer._call_lightning_module_hook(
                "training_epoch_end", epoch_end_outputs)
            if epoch_end_outputs is not None:
                raise MisconfigurationException(
                    "`training_epoch_end` expects a return of None. "
                    "HINT: remove the return statement in `training_epoch_end`."
                )
        # free memory
        self._outputs = []

        self.epoch_progress.increment_processed()

        # call train epoch end hooks
        self.trainer._call_callback_hooks("on_train_epoch_end")
        self.trainer._call_lightning_module_hook("on_train_epoch_end")

        self.trainer._call_callback_hooks("on_epoch_end")
        self.trainer._call_lightning_module_hook("on_epoch_end")

        self.trainer.logger_connector.on_epoch_end()

        if self.epoch_loop._num_ready_batches_reached():
            self.epoch_loop.update_lr_schedulers(
                "epoch", update_plateau_schedulers=True)

        self.epoch_progress.increment_completed()

        # the global step is manually decreased here due to backwards compatibility with existing loggers
        # as they expect that the same step is used when logging epoch end metrics even when the batch loop has
        # finished. this means the attribute does not exactly track the number of optimizer steps applied.
        # TODO(@carmocca): deprecate and rename so users don't get confused
        self.global_step -= 1
        # log epoch metrics
        self.trainer.logger_connector.update_train_epoch_metrics()
        self.global_step += 1

        # if fault tolerant is enabled and process has been notified, exit.
        self.trainer._exit_gracefully_on_signal()
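
    # A `training_epoch_end` that passes the return-value check above must return
    # None. A minimal sketch (hypothetical LightningModule method, assuming
    # `training_step` returned dicts with a "loss" key and `import torch`):
    #
    #     def training_epoch_end(self, outputs):
    #         avg_loss = torch.stack([out["loss"] for out in outputs]).mean()
    #         self.log("train_epoch_loss", avg_loss)
    #         # no return statement -> implicitly returns None, as required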

    def on_run_end(self) -> None:
        """Calls the ``on_train_end`` hook."""
        log.detail(f"{self.__class__.__name__}: train run ended")
        # NOTE: the current_epoch is already incremented
        # Lightning today does not increment the current epoch at the last epoch run in Trainer.fit
        # To simulate that current behavior, we decrement here.
        # TODO: must be fixed by https://github.com/PyTorchLightning/pytorch-lightning/issues/5007
        self.epoch_progress.current.completed = max(
            self.epoch_progress.current.completed - 1, 0)
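        # e.g. completed=3 after the final epoch -> the reported current epoch stays 2;
        # max(..., 0) guards the completed=0 case where training was skipped entirely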

        # hook
        self.trainer._call_callback_hooks("on_train_end")
        self.trainer._call_lightning_module_hook("on_train_end")
        self.trainer._call_strategy_hook("on_train_end")

        # give accelerators a chance to finish
        self.trainer.strategy.on_train_end()

    def teardown(self) -> None:
        self.epoch_loop.teardown()

    def _should_accumulate(self) -> bool:
        """Whether the gradients should be accumulated."""
        return self.epoch_loop._should_accumulate()
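
# A minimal usage sketch, assuming the loop-customization API of this Lightning
# version (`trainer` and `model` are placeholders, not defined in this example):
#
#     fit_loop = FitLoop(min_epochs=1, max_epochs=10)
#     fit_loop.connect(epoch_loop=TrainingEpochLoop())
#     trainer.fit_loop = fit_loop
#     trainer.fit(model)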