Exemple #1
0
def test_progress_increment_sequence():
    """Walk a default ``Progress`` through ready -> started -> processed -> completed.

    After each increment, both ``total`` and ``current`` are compared against a
    ``ProcessedTracker`` holding a count of one for every stage reached so far.
    """
    batch = Progress()

    # each entry: the increment to perform and the tracker fields expected afterwards
    stages = (
        (batch.increment_ready, dict(ready=1)),
        (batch.increment_started, dict(ready=1, started=1)),
        (batch.increment_processed, dict(ready=1, started=1, processed=1)),
        (batch.increment_completed, dict(ready=1, started=1, processed=1, completed=1)),
    )
    for increment, expected_counts in stages:
        increment()
        assert batch.total == ProcessedTracker(**expected_counts)
        assert batch.current == ProcessedTracker(**expected_counts)
Exemple #2
0
def test_progress_raises():
    """Check the errors raised for mismatched tracker classes and for unsupported increments."""
    # the two trackers handed to `Progress` are of different classes
    with pytest.raises(ValueError, match="instances should be of the same class"):
        Progress(ReadyCompletedTracker(), ProcessedTracker())

    p = Progress(ReadyCompletedTracker(), ReadyCompletedTracker())

    # each entry: an increment expected to be rejected and the message it must produce
    rejected = (
        (p.increment_started, "ReadyCompletedTracker` doesn't have a `started` attribute"),
        (p.increment_processed, "ReadyCompletedTracker` doesn't have a `processed` attribute"),
    )
    for increment, message in rejected:
        with pytest.raises(TypeError, match=message):
            increment()
def test_epoch_loop_progress_increment_sequence():
    """Walk a default ``Progress`` through all four increments and check both counters.

    After each increment, both ``total`` and ``current`` are compared against a
    ``Tracker`` holding a count of one for every stage reached so far.
    """
    batch = Progress()

    # each entry: the increment to perform and the tracker fields expected afterwards
    stages = (
        (batch.increment_ready, dict(ready=1)),
        (batch.increment_started, dict(ready=1, started=1)),
        (batch.increment_processed, dict(ready=1, started=1, processed=1)),
        (batch.increment_completed, dict(ready=1, started=1, processed=1, completed=1)),
    )
    for increment, expected_counts in stages:
        increment()
        assert batch.total == Tracker(**expected_counts)
        assert batch.current == Tracker(**expected_counts)
Exemple #4
0
class PredictionEpochLoop(Loop):
    """Loop that runs the predict step over the batches of one dataloader.

    Each call to :meth:`advance` fetches one batch, moves it to the device, runs the predict step with
    its surrounding hooks and, when requested, stores the result. :meth:`on_run_end` hands back the
    stored predictions together with the batch indices that were seen.
    """

    def __init__(self) -> None:
        super().__init__()
        self.return_predictions = False
        self.predictions: List[Any] = []
        self.current_batch_indices: List[int] = []
        self.batch_progress = Progress()

        self._dl_max_batches = 0
        self._num_dataloaders = 0
        self._warning_cache = WarningCache()
        self._seen_batch_indices: List[List[int]] = []

    @property
    def done(self) -> bool:
        """Whether the number of completed batches has reached the ``dl_max_batches`` of this run."""
        completed = self.batch_progress.current.completed
        return completed >= self._dl_max_batches

    @property
    def should_store_predictions(self) -> bool:
        """Whether predictions must be kept: either they are returned, or a prediction writer callback has
        ``interval.on_epoch`` set."""
        writers = self.trainer.prediction_writer_callbacks
        writer_wants_epoch_outputs = any(cb.interval.on_epoch for cb in writers)
        return self.return_predictions or writer_wants_epoch_outputs

    def connect(self, **kwargs: "Loop") -> None:
        """Always raises, since this loop has no child loops to connect."""
        raise NotImplementedError(f"{self.__class__.__name__} does not connect any child loops.")

    def reset(self) -> None:
        """Clears the stored predictions and batch indices and resets the batch progress for a new run."""
        self._seen_batch_indices = []
        self.predictions = []
        self.batch_progress.reset_on_run()

    def on_run_start(  # type: ignore[override]
        self,
        dataloader_iter: Iterator,
        dataloader_idx: int,
        dl_max_batches: int,
        num_dataloaders: int,
    ) -> None:
        """Stores the run parameters on the loop before the first batch is processed.

        Args:
            dataloader_iter: the iterator over the current dataloader
            dataloader_idx: the index of the current dataloader
            dl_max_batches: the maximum number of batches the current loader can produce
            num_dataloaders: the total number of dataloaders
        """
        void(dataloader_iter, dataloader_idx)
        self._dl_max_batches = dl_max_batches
        self._num_dataloaders = num_dataloaders
        # `should_store_predictions` reads `self.return_predictions`, which therefore has to be set by now
        if self.should_store_predictions:
            self._seen_batch_indices = self._get_batch_indices(dataloader_idx)
        else:
            self._seen_batch_indices = []

    def advance(  # type: ignore[override]
        self,
        dataloader_iter: Iterator,
        dataloader_idx: int,
        dl_max_batches: int,
        num_dataloaders: int,
    ) -> None:
        """Fetches the next batch and runs one prediction step on it.

        Args:
            dataloader_iter: the iterator over the current dataloader
            dataloader_idx: the index of the current dataloader
            dl_max_batches: the maximum number of batches the current loader can produce
            num_dataloaders: the total number of dataloaders

        Raises:
            StopIteration: if the fetched batch is ``None``
        """
        action_name = f"[{self.__class__.__name__}].predict_dataloader_idx_{dataloader_idx}_next"
        with self.trainer.profiler.profile(action_name):
            batch_idx, batch = next(dataloader_iter)

        seen = self._get_batch_indices(dataloader_idx) if self.should_store_predictions else []
        # prefetching (in the dataloader and in Lightning) can record more index lists than batches
        # processed so far, so keep only those up to and including the current batch
        self._seen_batch_indices = seen[: self.batch_progress.current.completed + 1]

        if batch is None:
            raise StopIteration

        batch = self.trainer._call_strategy_hook("batch_to_device", batch, dataloader_idx=dataloader_idx)

        self.batch_progress.increment_ready()

        self._predict_step(batch, batch_idx, dataloader_idx)

    def on_run_end(self) -> Tuple[List[Any], List[List[int]]]:
        """Returns the stored predictions and seen batch indices, and clears both on the loop."""
        results = (self.predictions, self._seen_batch_indices)
        # drop the loop's own references to free memory
        self.predictions = []
        self._seen_batch_indices = []
        return results

    def _predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
        """Runs the predict step with its hooks and progress bookkeeping, and stores the result if required.

        Args:
            batch: the current batch to run the prediction on
            batch_idx: the index of the current batch
            dataloader_idx: the index of the dataloader producing the current batch
        """
        step_kwargs = self._build_kwargs(batch, batch_idx, dataloader_idx)

        # remember which sample indices make up the current batch (empty when they cannot be inferred)
        all_indices = self._get_batch_indices(dataloader_idx)
        if all_indices:
            self.current_batch_indices = all_indices[batch_idx]
        else:
            self.current_batch_indices = []

        self.trainer._call_callback_hooks("on_predict_batch_start", batch, batch_idx, dataloader_idx)
        self.trainer._call_lightning_module_hook("on_predict_batch_start", batch, batch_idx, dataloader_idx)

        self.batch_progress.increment_started()

        # the step arguments are passed positionally, in the order `_build_kwargs` inserted them
        predictions = self.trainer._call_strategy_hook("predict_step", *step_kwargs.values())

        self.batch_progress.increment_processed()

        if predictions is None:
            self._warning_cache.warn("predict returned None if it was on purpose, ignore this warning...")

        self.trainer._call_callback_hooks("on_predict_batch_end", predictions, batch, batch_idx, dataloader_idx)
        self.trainer._call_lightning_module_hook(
            "on_predict_batch_end", predictions, batch, batch_idx, dataloader_idx
        )

        self.batch_progress.increment_completed()

        if self.should_store_predictions:
            # stored predictions are moved to the CPU device
            self.predictions.append(move_data_to_device(predictions, torch.device("cpu")))

    def _build_kwargs(self, batch: Any, batch_idx: int, dataloader_idx: int) -> Dict[str, Any]:
        """Assembles the keyword arguments for the ``predict_step``.

        Args:
            batch: the current batch to run the prediction on
            batch_idx: the index of the current batch
            dataloader_idx: the index of the dataloader producing the current batch

        Returns:
            an ordered dictionary with ``batch`` and ``batch_idx``, plus ``dataloader_idx`` when more than
            one dataloader is used
        """
        items = [("batch", batch), ("batch_idx", batch_idx)]
        if self._num_dataloaders > 1:
            items.append(("dataloader_idx", dataloader_idx))
        return OrderedDict(items)

    def _get_batch_indices(self, dataloader_idx: int) -> List[List[int]]:
        """Returns a reference to the seen batch indices if the dataloader has a batch sampler wrapped by our
        :class:`~pytorch_lightning.overrides.distributed.IndexBatchSamplerWrapper`, otherwise warns and
        returns an empty list."""
        dataloader = self.trainer.predict_dataloaders[dataloader_idx]  # type: ignore[has-type]
        # the batch_sampler is not defined in case of CombinedDataLoaders
        batch_sampler = getattr(dataloader, "batch_sampler", None)
        if not isinstance(batch_sampler, IndexBatchSamplerWrapper):
            # NOTE(review): this uses a module-level `warning_cache`, not `self._warning_cache` - confirm
            # that the module defines it and that the difference is intentional
            warning_cache.warn("Lightning couldn't infer the indices fetched for your dataloader.")
            return []
        return batch_sampler.seen_batch_indices
class EvaluationEpochLoop(Loop):
    """Loop that runs the evaluation step over the batches of one dataloader.

    For every batch it runs the validation or test step (depending on ``trainer.testing``) together with
    the matching batch hooks, and collects the step outputs, which :meth:`on_run_end` returns.
    """

    def __init__(self) -> None:
        super().__init__()
        self.predictions: Optional[PredictionCollection] = None
        self.dataloader: Optional[Iterator] = None
        self._dl_max_batches: Optional[int] = None
        self._num_dataloaders: Optional[int] = None
        self.outputs: List[STEP_OUTPUT] = []
        self.batch_progress = Progress()

    @property
    def done(self) -> bool:
        """Returns ``True`` if the number of completed batches reaches the number of dataloader batches.

        ``_dl_max_batches`` is ``None`` after :meth:`reset` until :meth:`on_run_start` sets it.
        """
        return self.batch_progress.current.completed >= self._dl_max_batches

    def connect(self, **kwargs: "Loop") -> None:
        """Always raises, since this loop has no child loops to connect."""
        raise NotImplementedError(f"{self.__class__.__name__} does not connect any child loops.")

    def reset(self) -> None:
        """Resets the loop's internal state."""
        self.predictions = PredictionCollection(self.trainer.global_rank, self.trainer.world_size)
        self._dl_max_batches = None
        self._num_dataloaders = None
        self.outputs = []

        # when restarting, the current batch progress is kept as it is
        if not self.restarting:
            self.batch_progress.current.reset()

    def on_run_start(
        self,
        dataloader_iter: Iterator,
        dataloader_idx: int,
        dl_max_batches: int,
        num_dataloaders: int,
    ) -> None:
        """Stores the batch limit and the number of dataloaders on the loop.

        Args:
            dataloader_iter: iterator over the dataloader
            dataloader_idx: index of the current dataloader
            dl_max_batches: maximum number of batches the dataloader can produce
            num_dataloaders: the total number of dataloaders
        """
        void(dataloader_iter, dataloader_idx)
        self._dl_max_batches = dl_max_batches
        self._num_dataloaders = num_dataloaders

    def advance(
        self,
        dataloader_iter: Iterator,
        dataloader_idx: int,
        dl_max_batches: int,
        num_dataloaders: int,
    ) -> None:
        """Calls the evaluation step with the corresponding hooks and updates the logger connector.

        Args:
            dataloader_iter: iterator over the dataloader
            dataloader_idx: index of the current dataloader
            dl_max_batches: maximum number of batches the dataloader can produce
            num_dataloaders: the total number of dataloaders

        Raises:
            StopIteration: If the current batch is None
        """
        void(dl_max_batches, num_dataloaders)

        batch_idx, batch = next(dataloader_iter)

        if batch is None:
            raise StopIteration

        with self.trainer.profiler.profile("evaluation_batch_to_device"):
            batch = self.trainer.accelerator.batch_to_device(batch, dataloader_idx=dataloader_idx)

        self.batch_progress.increment_ready()

        # hook
        self.on_evaluation_batch_start(batch, batch_idx, dataloader_idx)

        self.batch_progress.increment_started()

        # lightning module methods
        with self.trainer.profiler.profile("evaluation_step_and_end"):
            output = self.evaluation_step(batch, batch_idx, dataloader_idx)
            output = self.evaluation_step_end(output)

        self.batch_progress.increment_processed()

        # hook + store predictions
        self.on_evaluation_batch_end(output, batch, batch_idx, dataloader_idx)

        self.batch_progress.increment_completed()

        # log batch metrics
        self.trainer.logger_connector.update_eval_step_metrics()

        # track epoch level outputs
        self.outputs = self._track_output_for_epoch_end(self.outputs, output)

    def on_run_end(self) -> List[STEP_OUTPUT]:
        """Returns the outputs of the whole run and clears them on the loop."""
        outputs = self.outputs
        # free memory
        self.outputs = []
        return outputs

    def evaluation_step(self, batch: Any, batch_idx: int, dataloader_idx: int) -> Optional[STEP_OUTPUT]:
        """The evaluation step (validation_step or test_step depending on the trainer's state).

        Args:
            batch: The current batch to run through the step.
            batch_idx: The index of the current batch
            dataloader_idx: the index of the dataloader producing the current batch

        Returns:
            the outputs of the step
        """
        # configure step_kwargs
        step_kwargs = self._build_kwargs(batch, batch_idx, dataloader_idx)

        if self.trainer.testing:
            self.trainer.lightning_module._current_fx_name = "test_step"
            with self.trainer.profiler.profile("test_step"):
                output = self.trainer.accelerator.test_step(step_kwargs)
        else:
            self.trainer.lightning_module._current_fx_name = "validation_step"
            with self.trainer.profiler.profile("validation_step"):
                output = self.trainer.accelerator.validation_step(step_kwargs)

        return output

    def evaluation_step_end(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]:
        """Calls the `{validation/test}_step_end` hook and returns its result."""
        hook_name = "test_step_end" if self.trainer.testing else "validation_step_end"
        return self.trainer.call_hook(hook_name, *args, **kwargs)

    def on_evaluation_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
        """Calls the ``on_{validation/test}_batch_start`` hook.

        Args:
            batch: The current batch to run through the step
            batch_idx: The index of the current batch
            dataloader_idx: The index of the dataloader producing the current batch

        Raises:
            AssertionError: If the number of dataloaders is None (has not yet been set).
        """
        self.trainer.logger_connector.on_batch_start()

        assert self._num_dataloaders is not None
        self.trainer.logger_connector.on_evaluation_batch_start(
            batch, batch_idx, dataloader_idx, self._num_dataloaders
        )

        if self.trainer.testing:
            self.trainer.call_hook("on_test_batch_start", batch, batch_idx, dataloader_idx)
        else:
            self.trainer.call_hook("on_validation_batch_start", batch, batch_idx, dataloader_idx)

    def on_evaluation_batch_end(
        self,
        output: Optional[STEP_OUTPUT],
        batch: Any,
        batch_idx: int,
        dataloader_idx: int,
    ) -> None:
        """The ``on_{validation/test}_batch_end`` hook.

        Args:
            output: The output of the performed step
            batch: The input batch for the step
            batch_idx: The index of the current batch
            dataloader_idx: Index of the dataloader producing the current batch
        """
        hook_name = "on_test_batch_end" if self.trainer.testing else "on_validation_batch_end"
        self.trainer.call_hook(hook_name, output, batch, batch_idx, dataloader_idx)

        self.trainer.logger_connector.on_batch_end()

        # store predictions (test mode only) and track eval loss history
        self.store_predictions(output, batch_idx, dataloader_idx)

    def store_predictions(self, output: Optional[STEP_OUTPUT], batch_idx: int, dataloader_idx: int) -> None:
        """Stores the predictions in the prediction collection (only if running in test mode) and tracks the
        evaluation loss history.

        Args:
            output: the outputs of the current step
            batch_idx: the index of the current batch
            dataloader_idx: the index of the dataloader producing the current batch
        """
        # Add step predictions to prediction collection to write later.
        # `isinstance` is already False for `None`, so no separate `output is not None` check is needed.
        if self.predictions is not None and isinstance(output, ResultCollection) and self.trainer.testing:
            self.predictions.add(output.pop("predictions", None))

        # track debug metrics
        self.trainer.dev_debugger.track_eval_loss_history(batch_idx, dataloader_idx, output)

    def _build_kwargs(self, batch: Any, batch_idx: int, dataloader_idx: int) -> Dict[str, Union[Any, int]]:
        """Helper function to build the arguments for the current step

        Args:
            batch: The current batch to run through the step
            batch_idx: the index of the current batch
            dataloader_idx: the index of the dataloader producing the current batch

        Returns:
            the keyword arguments to pass to the step function
        """
        step_kwargs = OrderedDict([("batch", batch), ("batch_idx", batch_idx)])

        # `dataloader_idx` is only passed when several dataloaders are used, which keeps the argument
        # optional in the step. The rule is the same for validation and testing, so no branch is needed.
        if self._num_dataloaders > 1:
            step_kwargs["dataloader_idx"] = dataloader_idx

        return step_kwargs

    def _track_output_for_epoch_end(
        self,
        outputs: List[Union[ResultCollection, Dict, Tensor]],
        output: Optional[Union[ResultCollection, Dict, Tensor]],
    ) -> List[Union[ResultCollection, Dict, Tensor]]:
        """Appends a step output to ``outputs``, detaching it and/or moving it to CPU first.

        Args:
            outputs: the list collecting the epoch's outputs; it is modified in place
            output: the output of the current step; ``None`` is not appended

        Returns:
            the same ``outputs`` list that was passed in
        """
        if output is not None:
            if isinstance(output, ResultCollection):
                output = output.detach()
                if self.trainer.move_metrics_to_cpu:
                    output = output.cpu()
            elif isinstance(output, dict):
                output = recursive_detach(output, to_cpu=self.trainer.move_metrics_to_cpu)
            elif isinstance(output, Tensor) and output.is_cuda and self.trainer.move_metrics_to_cpu:
                # bare tensors are only moved to CPU here, they are not detached
                output = output.cpu()
            outputs.append(output)
        return outputs
class FitLoop(Loop[None]):
    """Loop that runs the training by iterating over epochs; each epoch is run by ``epoch_loop``.

    Args:
        min_epochs: The minimum number of epochs
        max_epochs: The maximum number of epochs, can be set -1 to turn this limit off
    """

    def __init__(self, min_epochs: int = 0, max_epochs: int = 1000) -> None:
        super().__init__()
        if max_epochs < -1:
            # Allow max_epochs to be zero, since this will be handled by fit_loop.done
            raise MisconfigurationException(
                f"`max_epochs` must be a non-negative integer or -1. You passed in {max_epochs}."
            )

        self.max_epochs = max_epochs
        self.min_epochs = min_epochs
        self.epoch_loop = TrainingEpochLoop()
        self.epoch_progress = Progress()

        self._is_fresh_start_epoch: bool = True
        self._outputs: _EPOCH_OUTPUTS_TYPE = []
        self._data_fetcher: Optional[AbstractDataFetcher] = None

    @property
    def total_batch_idx(self) -> int:
        """The current batch index across epochs (forwarded from ``epoch_loop``)."""
        return self.epoch_loop.total_batch_idx

    @property
    def batch_idx(self) -> int:
        """The current batch index within this epoch (forwarded from ``epoch_loop``)."""
        return self.epoch_loop.batch_idx

    @property
    def split_idx(self) -> int:
        """The index of the current batch split (within the current batch) for bptt."""
        return self.epoch_loop.batch_loop.split_idx

    @property
    def min_steps(self) -> Optional[int]:
        """The minimum number of steps to run (forwarded from ``epoch_loop``)."""
        # TODO(@justusschock): Why aren't we using the attribute in this class?
        return self.epoch_loop.min_steps

    @min_steps.setter
    def min_steps(self, value: Optional[int]) -> None:
        """Sets the minimum number of steps (forwards to epoch_loop)"""
        # TODO(@awaelchli): This setter is required by debugging connector (fast dev run), should be avoided
        self.epoch_loop.min_steps = value

    @property
    def max_steps(self) -> int:
        """The maximum number of steps to run (forwarded from ``epoch_loop``)."""
        return self.epoch_loop.max_steps

    @max_steps.setter
    def max_steps(self, value: int) -> None:
        """Sets the maximum number of steps (forwards to epoch_loop).

        ``None`` is still accepted with a deprecation message and stored as ``-1``.

        Raises:
            MisconfigurationException: if ``value`` is smaller than ``-1``
        """
        # TODO(@awaelchli): This setter is required by debugging connector (fast dev run), should be avoided
        if value is None:
            rank_zero_deprecation(
                "Setting `max_steps = None` is deprecated in v1.5 and will no longer be supported in v1.7."
                " Use `max_steps = -1` instead."
            )
            self.epoch_loop.max_steps = -1
            return
        if value < -1:
            raise MisconfigurationException(
                f"`max_steps` must be a non-negative integer or -1 (infinite steps). You passed in {value}."
            )
        self.epoch_loop.max_steps = value

    @property
    def running_loss(self) -> TensorRunningAccum:
        """The running loss (forwarded from the batch loop)."""
        return self.epoch_loop.batch_loop.running_loss

    @Loop.restarting.setter
    def restarting(self, restarting: bool) -> None:
        """Sets the restarting flag, but only keeps it ``True`` if the last epoch did not fully finish."""
        # if the last epoch completely finished, all current counters are equal and we are not actually
        # restarting
        current = self.epoch_progress.current
        earlier_stages = (current.ready, current.started, current.processed)
        finished_before_on_train_end = any(count != current.completed for count in earlier_stages)
        if finished_before_on_train_end:
            current.completed = current.processed
        restarting &= finished_before_on_train_end
        Loop.restarting.fset(self, restarting)  # call the parent setter

    @property
    def prefetch_batches(self) -> int:
        """``1`` if the number of training batches is infinite or ``PL_INTER_BATCH_PARALLELISM`` is ``"1"``,
        otherwise ``0``."""
        is_unsized = self.trainer.num_training_batches == float("inf")
        inter_batch_parallelism = os.getenv("PL_INTER_BATCH_PARALLELISM", "0") == "1"
        if is_unsized or inter_batch_parallelism:
            return 1
        return 0

    @property
    def _skip_backward(self) -> bool:
        """Determines whether the loop will skip backward during automatic optimization."""
        return self.epoch_loop.batch_loop.optimizer_loop._skip_backward

    @_skip_backward.setter
    def _skip_backward(self, value: bool) -> None:
        """Determines whether the loop will skip backward during automatic optimization."""
        self.epoch_loop.batch_loop.optimizer_loop._skip_backward = value

    @property
    def _results(self) -> _ResultCollection:
        """The result collection of the epoch loop (training) or of its validation loop (validating).

        Raises:
            RuntimeError: if the trainer is neither training nor validating
        """
        trainer = self.trainer
        if trainer.training:
            return self.epoch_loop._results
        if trainer.validating:
            return self.epoch_loop.val_loop._results
        raise RuntimeError("`FitLoop._results` property isn't defined. Accessed outside of scope")

    @property
    def done(self) -> bool:
        """Evaluates when to leave the loop.

        Reading this property also writes ``trainer.should_stop``: a stop request is cleared again while
        the minimum number of epochs or steps has not been met.
        """
        # TODO(@awaelchli): Move track steps inside training loop and move part of these condition inside training loop
        stop_steps = _is_max_limit_reached(self.epoch_loop.global_step, self.max_steps)
        # `processed` is increased before `on_train_epoch_end`, the hook where checkpoints are typically saved.
        # we use it here because the checkpoint data won't have `completed` increased yet
        stop_epochs = _is_max_limit_reached(self.epoch_progress.current.processed, self.max_epochs)

        should_stop = False
        if self.trainer.should_stop:
            # early stopping: a falsy minimum counts as already met
            met_min_epochs = not self.min_epochs or self.epoch_progress.current.processed >= self.min_epochs
            met_min_steps = not self.min_steps or self.epoch_loop.global_step >= self.min_steps
            if not (met_min_epochs and met_min_steps):
                log.info(
                    "Trainer was signaled to stop but required minimum epochs"
                    f" ({self.min_epochs}) or minimum steps ({self.min_steps}) has"
                    " not been met. Training will continue..."
                )
            else:
                should_stop = True
        self.trainer.should_stop = should_stop

        return stop_steps or should_stop or stop_epochs or self.trainer.num_training_batches == 0

    @property
    def skip(self) -> bool:
        """Whether we should skip the training and immediately return from the call to :meth:`run`."""
        # since `trainer.num_training_batches` depends on the `train_dataloader` but that won't be called
        # until `on_run_start`, we use `limit_train_batches` instead
        return self.done or self.trainer.limit_train_batches == 0

    def connect(self, epoch_loop: TrainingEpochLoop) -> None:  # type: ignore[override]
        """Connects a training epoch loop to this fit loop."""
        self.epoch_loop = epoch_loop

    def reset(self) -> None:
        """Resets the epoch progress for a restart; does nothing when the loop is not restarting."""
        if self.restarting:
            self.epoch_progress.reset_on_restart()

    def on_run_start(self) -> None:  # type: ignore[override]
        """Resets the dataloaders, creates the data fetcher and calls the ``on_train_start`` hooks."""
        # reset train dataloader and val dataloader
        self.trainer.reset_train_val_dataloaders(self.trainer.lightning_module)

        self._data_fetcher = _select_data_fetcher(self.trainer)(prefetch_batches=self.prefetch_batches)

        self._is_fresh_start_epoch = True
        self._results.to(device=self.trainer.lightning_module.device)

        self.trainer._call_callback_hooks("on_train_start")
        self.trainer._call_lightning_module_hook("on_train_start")
        self.trainer._call_strategy_hook("on_train_start")

    def on_advance_start(self) -> None:  # type: ignore[override]
        """Prepares the dataloader for training and calls the hooks ``on_epoch_start`` and
        ``on_train_epoch_start``"""
        model = self.trainer.lightning_module

        # reload the train dataloader if requested, except on the first epoch after `on_run_start`
        if not self._is_fresh_start_epoch and self.trainer._data_connector._should_reload_train_dl:
            log.detail(f"{self.__class__.__name__}: resetting train dataloader")
            self.trainer.reset_train_dataloader(model)
        self._is_fresh_start_epoch = False

        # reset outputs here instead of in `reset` as they are not accumulated between epochs
        self._outputs = []

        train_dataloader = self.trainer.train_dataloader
        if train_dataloader is not None:
            sampler = train_dataloader.sampler
            if callable(getattr(sampler, "set_epoch", None)):
                # set seed for distributed sampler (enables shuffling for each epoch)
                sampler.set_epoch(self.epoch_progress.current.processed)

        # changing gradient according accumulation_scheduler
        self.trainer.accumulation_scheduler.on_train_epoch_start(self.trainer, self.trainer.lightning_module)

        # stores accumulated grad fractions per batch
        self.epoch_loop.batch_loop.accumulated_loss.reset(window_length=self.trainer.accumulate_grad_batches)

        self.epoch_progress.increment_ready()

        self.trainer._logger_connector.on_epoch_start()

        self.trainer._call_callback_hooks("on_epoch_start")
        self.trainer._call_lightning_module_hook("on_epoch_start")

        self.trainer._call_callback_hooks("on_train_epoch_start")
        self.trainer._call_lightning_module_hook("on_train_epoch_start")

        self.epoch_progress.increment_started()

    def advance(self) -> None:  # type: ignore[override]
        """Runs one whole epoch."""
        log.detail(f"{self.__class__.__name__}: advancing loop")
        assert self.trainer.train_dataloader is not None
        dataloader = self.trainer.train_dataloader
        assert self._data_fetcher is not None
        # the fetcher moves batches to the device through the strategy's `batch_to_device` hook
        move_to_device = partial(self.trainer._call_strategy_hook, "batch_to_device", dataloader_idx=0)
        self._data_fetcher.setup(dataloader, batch_to_device=move_to_device)
        with self.trainer.profiler.profile("run_training_epoch"):
            self._outputs = self.epoch_loop.run(self._data_fetcher)

    def on_advance_end(self) -> None:
        """Runs the epoch-end hooks, scheduler update, progress bookkeeping and epoch-level logging.

        Raises:
            MisconfigurationException: if ``training_epoch_end`` returns anything other than ``None``
        """
        # inform logger the batch loop has finished
        self.trainer._logger_connector.epoch_end_reached()

        # get the model and call model.training_epoch_end
        model = self.trainer.lightning_module
        if is_overridden("training_epoch_end", model) and self._outputs:
            prepared_outputs = self.epoch_loop._prepare_outputs_training_epoch_end(
                self._outputs,
                lightning_module=model,
                num_optimizers=len(self.trainer.optimizers),
            )
            # run lightning module hook training_epoch_end
            # refresh the result for custom logging at the epoch level
            hook_result = self.trainer._call_lightning_module_hook("training_epoch_end", prepared_outputs)
            if hook_result is not None:
                raise MisconfigurationException(
                    "`training_epoch_end` expects a return of None. "
                    "HINT: remove the return statement in `training_epoch_end`."
                )
        # free memory
        self._outputs = []

        self.epoch_progress.increment_processed()

        # call train epoch end hooks
        self.trainer._call_callback_hooks("on_train_epoch_end")
        self.trainer._call_lightning_module_hook("on_train_epoch_end")

        self.trainer._call_callback_hooks("on_epoch_end")
        self.trainer._call_lightning_module_hook("on_epoch_end")

        self.trainer._logger_connector.on_epoch_end()

        if self.epoch_loop._num_ready_batches_reached():
            self.epoch_loop.update_lr_schedulers("epoch", update_plateau_schedulers=True)

        self.epoch_progress.increment_completed()

        # we manually decrease here because loggers expect that the same step is used when logging epoch-end metrics
        # even when the batch loop has finished
        self.epoch_loop._batches_that_stepped -= 1
        # log epoch metrics
        self.trainer._logger_connector.update_train_epoch_metrics()
        self.epoch_loop._batches_that_stepped += 1

        # if fault tolerant is enabled and process has been notified, exit.
        self.trainer._exit_gracefully_on_signal()

    def on_run_end(self) -> None:
        """Calls the ``on_train_end`` hook."""
        log.detail(f"{self.__class__.__name__}: train run ended")

        # hook
        self.trainer._call_callback_hooks("on_train_end")
        self.trainer._call_lightning_module_hook("on_train_end")
        self.trainer._call_strategy_hook("on_train_end")

    def teardown(self) -> None:
        """Tears down the data fetcher (if one was created) and the epoch loop."""
        if self._data_fetcher is not None:
            self._data_fetcher.teardown()
            self._data_fetcher = None
        self.epoch_loop.teardown()

    def _should_accumulate(self) -> bool:
        """Whether the gradients should be accumulated."""
        return self.epoch_loop._should_accumulate()
class PredictionEpochLoop(Loop):
    """Loop performing prediction on arbitrary sequentially used dataloaders.

    Each call to :meth:`advance` fetches one batch from the dataloader iterator, moves it to the device and
    runs the predict step together with its hooks. Progress is tracked in ``batch_progress``.
    """

    def __init__(self) -> None:
        super().__init__()
        self.return_predictions: bool = False
        self.predictions: List[Any] = []
        self.current_batch_indices: List[int] = []
        self.batch_progress = Progress()

        self._dl_max_batches: Optional[int] = None
        self._num_dataloaders: Optional[int] = None
        # used for every warning emitted by this loop
        self._warning_cache = WarningCache()
        self._all_batch_indices: List[int] = []

    @property
    def done(self) -> bool:
        """Ends prediction when the iteration count exceeds the total number of available batches."""
        return self.batch_progress.current.completed >= self._dl_max_batches

    @property
    def should_store_predictions(self) -> bool:
        """Whether the predictions should be stored for later usage (e.g. aggregation or returning).

        This is the case when predictions were requested to be returned or when any prediction writer
        callback has ``interval.on_epoch`` set.
        """
        any_pred = any(cb.interval.on_epoch
                       for cb in self.trainer.prediction_writer_callbacks)
        return self.return_predictions or any_pred

    def connect(self, **kwargs: "Loop") -> None:
        """Not supported: this loop has no child loops.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError(
            f"{self.__class__.__name__} does not connect any child loops.")

    def reset(self) -> None:
        """Resets the loop's internal state: stored batch indices, stored predictions and batch progress."""
        self._all_batch_indices = []
        self.predictions = []
        self.batch_progress.reset_on_run()

    def on_run_start(
        self,
        dataloader_iter: Iterator,
        dataloader_idx: int,
        dl_max_batches: int,
        num_dataloaders: int,
        return_predictions: bool = False,
    ) -> None:
        """Prepares the loop's internal state.

        Args:
            dataloader_iter: the iterator over the current dataloader
            dataloader_idx: the index of the current dataloader
            dl_max_batches: the maximum number of batches the current loader can produce
            num_dataloaders: the total number of dataloaders
            return_predictions: whether to return the obtained predictions
        """
        # these two arguments are not needed here; they are only consumed in ``advance``
        void(dataloader_iter, dataloader_idx)
        self._dl_max_batches = dl_max_batches
        self._num_dataloaders = num_dataloaders
        self.return_predictions = return_predictions

    def advance(
        self,
        dataloader_iter: Iterator,
        dataloader_idx: int,
        dl_max_batches: int,
        num_dataloaders: int,
        return_predictions: bool = False,
    ) -> None:
        """Runs one prediction step.

        Args:
            dataloader_iter: the iterator over the current dataloader
            dataloader_idx: the index of the current dataloader
            dl_max_batches: the maximum number of batches the current loader can produce
            num_dataloaders: the total number of dataloaders
            return_predictions: whether to return the obtained predictions

        Raises:
            StopIteration: If the iterator is exhausted or yields ``None`` as the batch.
        """
        batch_idx, batch = next(dataloader_iter)
        if batch is None:
            raise StopIteration

        with self.trainer.profiler.profile("predict_batch_to_device"):
            batch = self.trainer.accelerator.batch_to_device(
                batch, dataloader_idx=dataloader_idx)

        self.batch_progress.increment_ready()

        with self.trainer.profiler.profile("predict_step"):
            self._predict_step(batch, batch_idx, dataloader_idx)

    def on_run_end(self) -> Tuple[List[Any], List[int]]:
        """Returns the predictions and the corresponding batch indices."""
        predictions = self.predictions
        all_batch_indices = self._all_batch_indices
        # free memory
        self.predictions = []
        self._all_batch_indices = []
        return predictions, all_batch_indices

    def _predict_step(self, batch: Any, batch_idx: int,
                      dataloader_idx: int) -> None:
        """Runs the actual predict step together with all the necessary bookkeeping and the hooks tied to the
        predict step.

        Args:
            batch: the current batch to run the prediction on
            batch_idx: the index of the current batch
            dataloader_idx: the index of the dataloader producing the current batch
        """
        # configure step_kwargs
        step_kwargs = self._build_kwargs(batch, batch_idx, dataloader_idx)

        # extract batch_indices and store them
        self._store_batch_indices(dataloader_idx)

        model_ref = self.trainer.lightning_module

        self.trainer.call_hook("on_predict_batch_start", batch, batch_idx,
                               dataloader_idx)

        self.batch_progress.increment_started()

        model_ref._current_fx_name = "predict_step"
        predictions = self.trainer.accelerator.predict_step(step_kwargs)

        self.batch_progress.increment_processed()

        if predictions is None:
            self._warning_cache.warn(
                "predict returned None if it was on purpose, ignore this warning..."
            )

        self.trainer.call_hook("on_predict_batch_end", predictions, batch,
                               batch_idx, dataloader_idx)

        self.batch_progress.increment_completed()

        if self.should_store_predictions:
            # predictions are kept on the CPU
            self.predictions.append(
                move_data_to_device(predictions, torch.device("cpu")))

    def _build_kwargs(self, batch: Any, batch_idx: int,
                      dataloader_idx: int) -> Dict[str, Any]:
        """Assembles the keyword arguments for the ``predict_step``.

        Args:
            batch: the current batch to run the prediction on
            batch_idx: the index of the current batch
            dataloader_idx: the index of the dataloader producing the current batch

        Returns:
            the dictionary containing all the keyword arguments for the predict step
        """
        step_kwargs = OrderedDict([("batch", batch), ("batch_idx", batch_idx)])
        # ``dataloader_idx`` is only passed along when there are several dataloaders
        if self._num_dataloaders > 1:
            step_kwargs["dataloader_idx"] = dataloader_idx
        return step_kwargs

    def _store_batch_indices(self, dataloader_idx: int) -> None:
        """Records the indices of the current batch.

        When the batch sampler of the selected predict dataloader is an ``IndexBatchSamplerWrapper``, its
        ``batch_indices`` become ``current_batch_indices`` and, if predictions are being stored, are also
        appended to the indices returned by :meth:`on_run_end`. Otherwise a warning is emitted.

        Args:
            dataloader_idx: the index of the dataloader producing the current batch
        """
        batch_sampler = self.trainer.predict_dataloaders[
            dataloader_idx].batch_sampler
        if isinstance(batch_sampler, IndexBatchSamplerWrapper):
            self.current_batch_indices = batch_sampler.batch_indices
            if self.should_store_predictions:
                self._all_batch_indices.append(batch_sampler.batch_indices)
        else:
            # use the loop's own cache, like the other warning emitted by this class
            self._warning_cache.warn(
                "Lightning couldn't infer the indices fetched for your dataloader."
            )
Exemple #8
0
class ActiveLearningLoop(Loop):
    # total number of iterations of this loop; taken over from the wrapped fit loop in ``connect``
    max_epochs: int
    # wrapper around the lightning module used for prediction; created in ``on_run_start``
    inference_model: InferenceMCDropoutTask

    @requires(["baal", (_PL_GREATER_EQUAL_1_4_0, "pytorch-lightning>=1.4.0")])
    def __init__(self,
                 label_epoch_frequency: int,
                 inference_iteration: int = 2,
                 should_reset_weights: bool = True):
        """The `ActiveLearning Loop` describes the following training procedure. This loop is connected with the
        `ActiveLearningTrainer`

        Example::

            while unlabelled data or budget criteria not reached:

                if labelled data
                    trainer.fit(model, labelled data)

                if unlabelled data:
                    predictions = trainer.predict(model, unlabelled data)
                    uncertainties = heuristic(predictions)
                    request labelling for the samples with highest uncertainties under a given budget

        Args:
            label_epoch_frequency: Number of epochs to train on before requesting labelling.
            inference_iteration: Number of inferences to perform to compute uncertainty.
            should_reset_weights: Whether to reload the model weights captured in ``on_run_start`` at the end of
                each iteration (only done while unlabelled data remains).
        """
        super().__init__()
        self.label_epoch_frequency = label_epoch_frequency
        self.inference_iteration = inference_iteration
        self.should_reset_weights = should_reset_weights
        self.fit_loop: Optional[FitLoop] = None
        self.progress = Progress()
        self._model_state_dict: Optional[Dict[str, torch.Tensor]] = None
        self._datamodule_state_dict: Optional[Dict[str, Any]] = None
        self._lightning_module: Optional[flash.Task] = None

    @property
    def done(self) -> bool:
        """Whether ``max_epochs`` iterations of this loop have been completed."""
        return self.progress.current.completed >= self.max_epochs

    def connect(self, fit_loop: FitLoop):
        """Wraps ``fit_loop``.

        This loop takes over the fit loop's ``max_epochs``; the fit loop itself is limited to
        ``label_epoch_frequency`` epochs per run.
        """
        self.fit_loop = fit_loop
        self.max_epochs = self.fit_loop.max_epochs
        self.fit_loop.max_epochs = self.label_epoch_frequency

    def on_run_start(self, *args: Any, **kwargs: Any) -> None:
        """Restores the datamodule state (if any), snapshots the model weights and builds the inference model."""
        assert isinstance(self.trainer.datamodule, ActiveLearningDataModule)
        if self._datamodule_state_dict is not None:
            self.trainer.datamodule.load_state_dict(
                self._datamodule_state_dict)
        self.trainer.predict_loop._return_predictions = True
        self._lightning_module = self.trainer.lightning_module
        # deep copy so that later training does not alter the snapshot
        self._model_state_dict = deepcopy(self._lightning_module.state_dict())
        self.inference_model = InferenceMCDropoutTask(self._lightning_module,
                                                      self.inference_iteration)

    def reset(self) -> None:
        """Nothing to reset."""
        pass

    def on_advance_start(self, *args: Any, **kwargs: Any) -> None:
        """Resets the dataloaders of every stage that currently has data and marks the iteration as ready."""
        if self.trainer.datamodule.has_labelled_data:
            self._reset_dataloader_for_stage(RunningStage.TRAINING)
            self._reset_dataloader_for_stage(RunningStage.VALIDATING)
            if self.trainer.datamodule.has_test:
                self._reset_dataloader_for_stage(RunningStage.TESTING)
        if self.trainer.datamodule.has_unlabelled_data:
            self._reset_dataloader_for_stage(RunningStage.PREDICTING)
        self.progress.increment_ready()

    def advance(self, *args: Any, **kwargs: Any) -> None:
        """Runs one iteration: fit, then test, then predict on the unlabelled data and label it.

        Raises:
            StopIteration: If there is no unlabelled data left.
        """

        self.progress.increment_started()

        if self.trainer.datamodule.has_labelled_data:
            self.fit_loop.run()

        if self.trainer.datamodule.has_test:
            self._reset_testing()
            metrics = self.trainer.test_loop.run()
            if metrics:
                self.trainer.logger.log_metrics(metrics[0],
                                                step=self.trainer.global_step)

        if self.trainer.datamodule.has_unlabelled_data:
            self._reset_predicting()
            probabilities = self.trainer.predict_loop.run()
            self.trainer.datamodule.label(probabilities=probabilities)
        else:
            raise StopIteration

        self._reset_fitting()
        self.progress.increment_processed()

    def on_advance_end(self) -> None:
        """Optionally restores the initial weights and marks the iteration as completed."""
        if self.trainer.datamodule.has_unlabelled_data and self.should_reset_weights:
            # reload the weights to retrain from scratch with the new labelled data.
            self._lightning_module.load_state_dict(self._model_state_dict)
        self.progress.increment_completed()
        return super().on_advance_end()

    def on_run_end(self):
        """Saves the datamodule state, puts the trainer back into fitting mode and clears the dataloaders."""
        self._datamodule_state_dict = self.trainer.datamodule.state_dict()
        self._reset_fitting()
        self._teardown()
        return super().on_run_end()

    def on_save_checkpoint(self) -> Dict:
        """Returns the datamodule state so that it is stored in the checkpoint."""
        return {"datamodule_state_dict": self._datamodule_state_dict}

    def on_load_checkpoint(self, state_dict) -> None:
        """Takes the datamodule state out of ``state_dict`` (``None`` if absent)."""
        self._datamodule_state_dict = state_dict.pop("datamodule_state_dict",
                                                     None)

    def __getattr__(self, key):
        """Forwards attributes that are not found on this loop to the wrapped fit loop.

        Python only calls ``__getattr__`` after regular lookup has failed, so ``key`` is never an instance
        attribute here.

        Raises:
            AttributeError: If no fit loop is available or the fit loop does not have the attribute.
        """
        # Read ``fit_loop`` straight from the instance dict: going through ``self.fit_loop`` would re-enter
        # this method without end whenever the attribute has not been set yet (e.g. while ``copy`` or
        # ``pickle`` probe an instance whose ``__init__`` did not run).
        fit_loop = self.__dict__.get("fit_loop")
        if fit_loop is None:
            raise AttributeError(
                f"{self.__class__.__name__!r} object has no attribute {key!r}")
        return getattr(fit_loop, key)

    def _connect(self, model: LightningModule):
        """Connects ``model`` through the trainer API that matches the installed PyTorch Lightning version."""
        if _PL_GREATER_EQUAL_1_5_0:
            self.trainer.training_type_plugin.connect(model)
        else:
            self.trainer.accelerator.connect(model)

    def _reset_fitting(self):
        """Switches the trainer to fitting with the original module and a fresh epoch progress."""
        self.trainer.state.fn = TrainerFn.FITTING
        self.trainer.training = True
        self.trainer.lightning_module.on_train_dataloader()
        self._connect(self._lightning_module)
        self.fit_loop.epoch_progress = Progress()

    def _reset_predicting(self):
        """Switches the trainer to predicting with the inference model."""
        self.trainer.state.fn = TrainerFn.PREDICTING
        self.trainer.predicting = True
        self.trainer.lightning_module.on_predict_dataloader()
        self._connect(self.inference_model)

    def _reset_testing(self):
        """Switches the trainer to testing with the original module."""
        self.trainer.state.fn = TrainerFn.TESTING
        self.trainer.state.status = TrainerStatus.RUNNING
        self.trainer.testing = True
        self.trainer.lightning_module.on_test_dataloader()
        self._connect(self._lightning_module)

    def _reset_dataloader_for_stage(self, running_state: RunningStage):
        """Re-attaches and reloads the datamodule's dataloader for ``running_state``.

        Nothing happens when the datamodule does not override the corresponding ``*_dataloader`` method.
        """
        dataloader_name = f"{_STAGES_PREFIX[running_state]}_dataloader"
        # If the dataloader exists, we reset it.
        dataloader = (getattr(
            self.trainer.datamodule, dataloader_name) if is_overridden(
                dataloader_name, self.trainer.datamodule) else None)

        if dataloader:
            if _PL_GREATER_EQUAL_1_5_0:
                setattr(
                    self.trainer._data_connector,
                    f"_{dataloader_name}_source",
                    _DataLoaderSource(self.trainer.datamodule,
                                      dataloader_name),
                )
            else:
                setattr(
                    self.trainer.lightning_module,
                    dataloader_name,
                    _PatchDataLoader(dataloader(), running_state),
                )
            setattr(self.trainer, dataloader_name, None)
            # TODO: Resolve this within PyTorch Lightning.
            # best effort on purpose: a misconfiguration reported by the reset is ignored
            try:
                getattr(self.trainer, f"reset_{dataloader_name}")(
                    self.trainer.lightning_module)
            except MisconfigurationException:
                pass

    def _teardown(self) -> None:
        """Drops every dataloader reference held by the trainer and the lightning module."""
        self.trainer.train_dataloader = None
        self.trainer.val_dataloaders = None
        self.trainer.test_dataloaders = None
        self.trainer.predict_dataloaders = None
        # Hack
        self.trainer.lightning_module.train_dataloader = None
        self.trainer.lightning_module.val_dataloader = None
        self.trainer.lightning_module.test_dataloader = None
        self.trainer.lightning_module.predict_dataloader = None
class EvaluationEpochLoop(Loop):
    """This is the loop performing the evaluation.

    It mainly loops over the given dataloader and runs the validation or test step (depending on the trainer's current
    state).
    """

    def __init__(self) -> None:
        super().__init__()
        self.dataloader: Optional[Iterator] = None
        self._dl_max_batches: Optional[int] = None
        self._num_dataloaders: Optional[int] = None
        self.outputs: EPOCH_OUTPUT = []
        self.batch_progress = Progress()
        self.dataloader_iter: Optional[Iterator] = None

    @property
    def done(self) -> bool:
        """Returns ``True`` if the current iteration count reaches the number of dataloader batches."""
        return self.batch_progress.current.completed >= self._dl_max_batches

    def connect(self, **kwargs: "Loop") -> None:
        """Not supported: this loop has no child loops.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError(
            f"{self.__class__.__name__} does not connect any child loops.")

    def reset(self) -> None:
        """Resets the loop's internal state."""
        self._dl_max_batches = None
        self._num_dataloaders = None
        self.outputs = []

        # when restarting, the current progress is kept
        if not self.restarting:
            self.batch_progress.current.reset()

    def on_run_start(self, data_fetcher: AbstractDataFetcher,
                     dataloader_idx: int, dl_max_batches: int,
                     num_dataloaders: int) -> None:
        """Adds the passed arguments to the loop's state if necessary.

        Args:
            data_fetcher: the current data_fetcher wrapping the dataloader
            dataloader_idx: index of the current dataloader
            dl_max_batches: maximum number of batches the dataloader can produce
            num_dataloaders: the total number of dataloaders
        """
        void(dataloader_idx)
        self._dl_max_batches = dl_max_batches
        self._num_dataloaders = num_dataloaders

        self.dataloader_iter = _prepare_dataloader_iter(
            data_fetcher, self.batch_progress.current.ready)

    def advance(self, data_fetcher: AbstractDataFetcher, dataloader_idx: int,
                dl_max_batches: int, num_dataloaders: int) -> None:
        """Calls the evaluation step with the corresponding hooks and updates the logger connector.

        The batch is taken from ``self.dataloader_iter``, which is created in :meth:`on_run_start`.

        Args:
            data_fetcher: the current data_fetcher wrapping the dataloader (unused here)
            dataloader_idx: index of the current dataloader
            dl_max_batches: maximum number of batches the dataloader can produce (unused here)
            num_dataloaders: the total number of dataloaders (unused here)

        Raises:
            StopIteration: If the current batch is None
        """
        void(data_fetcher, dl_max_batches, num_dataloaders)

        batch_idx, (batch, _) = next(self.dataloader_iter)

        if batch is None:
            raise StopIteration

        # the batch is only moved here when the data fetcher did not already store it on the device
        if not self.trainer.data_connector.evaluation_data_fetcher.store_on_device:
            with self.trainer.profiler.profile("evaluation_batch_to_device"):
                batch = self.trainer.accelerator.batch_to_device(
                    batch, dataloader_idx=dataloader_idx)

        self.batch_progress.increment_ready()

        # hook
        self.on_evaluation_batch_start(batch, batch_idx, dataloader_idx)

        self.batch_progress.increment_started()

        # lightning module methods
        with self.trainer.profiler.profile("evaluation_step_and_end"):
            output = self.evaluation_step(batch, batch_idx, dataloader_idx)
            output = self.evaluation_step_end(output)

        self.batch_progress.increment_processed()

        # track loss history
        self.on_evaluation_batch_end(output, batch, batch_idx, dataloader_idx)

        self.batch_progress.increment_completed()

        # log batch metrics
        self.trainer.logger_connector.update_eval_step_metrics()

        # track epoch level outputs
        if self._should_track_batch_outputs_for_epoch_end():
            output = recursive_detach(output,
                                      to_cpu=self.trainer.move_metrics_to_cpu)
            if output is not None:
                self.outputs.append(output)

    def on_run_end(self) -> EPOCH_OUTPUT:
        """Returns the outputs of the whole run."""
        outputs = self.outputs
        # free memory
        self.outputs = []
        return outputs

    def evaluation_step(self, batch: Any, batch_idx: int,
                        dataloader_idx: int) -> Optional[STEP_OUTPUT]:
        """The evaluation step (validation_step or test_step depending on the trainer's state).

        Args:
            batch: The current batch to run through the step.
            batch_idx: The index of the current batch
            dataloader_idx: the index of the dataloader producing the current batch

        Returns:
            the outputs of the step
        """
        # configure step_kwargs
        step_kwargs = self._build_kwargs(batch, batch_idx, dataloader_idx)

        if self.trainer.testing:
            self.trainer.lightning_module._current_fx_name = "test_step"
            with self.trainer.profiler.profile("test_step"):
                output = self.trainer.accelerator.test_step(step_kwargs)
        else:
            self.trainer.lightning_module._current_fx_name = "validation_step"
            with self.trainer.profiler.profile("validation_step"):
                output = self.trainer.accelerator.validation_step(step_kwargs)

        return output

    def evaluation_step_end(self, *args: Any,
                            **kwargs: Any) -> Optional[STEP_OUTPUT]:
        """Calls the `{validation/test}_step_end` hook."""
        hook_name = "test_step_end" if self.trainer.testing else "validation_step_end"
        output = self.trainer.call_hook(hook_name, *args, **kwargs)
        return output

    def on_evaluation_batch_start(self, batch: Any, batch_idx: int,
                                  dataloader_idx: int) -> None:
        """Calls the ``on_{validation/test}_batch_start`` hook.

        Args:
            batch: The current batch to run through the step
            batch_idx: The index of the current batch
            dataloader_idx: The index of the dataloader producing the current batch

        Raises:
            AssertionError: If the number of dataloaders is None (has not yet been set).
        """
        self.trainer.logger_connector.on_batch_start()

        assert self._num_dataloaders is not None
        self.trainer.logger_connector.on_evaluation_batch_start(
            batch, batch_idx, dataloader_idx, self._num_dataloaders)

        if self.trainer.testing:
            self.trainer.call_hook("on_test_batch_start", batch, batch_idx,
                                   dataloader_idx)
        else:
            self.trainer.call_hook("on_validation_batch_start", batch,
                                   batch_idx, dataloader_idx)

    def on_evaluation_batch_end(self, output: Optional[STEP_OUTPUT],
                                batch: Any, batch_idx: int,
                                dataloader_idx: int) -> None:
        """The ``on_{validation/test}_batch_end`` hook.

        Args:
            output: The output of the performed step
            batch: The input batch for the step
            batch_idx: The index of the current batch
            dataloader_idx: Index of the dataloader producing the current batch
        """
        hook_name = "on_test_batch_end" if self.trainer.testing else "on_validation_batch_end"
        self.trainer.call_hook(hook_name, output, batch, batch_idx,
                               dataloader_idx)

        self.trainer.logger_connector.on_batch_end()

    def _build_kwargs(self, batch: Any, batch_idx: int,
                      dataloader_idx: int) -> Dict[str, Union[Any, int]]:
        """Helper function to build the arguments for the current step.

        Args:
            batch: The current batch to run through the step
            batch_idx: the index of the current batch
            dataloader_idx: the index of the dataloader producing the current batch

        Returns:
            the keyword arguments to pass to the step function
        """
        step_kwargs = OrderedDict([("batch", batch), ("batch_idx", batch_idx)])

        # ``dataloader_idx`` is optional in the step signature: it is only passed when there are several
        # dataloaders, regardless of whether the trainer is validating or testing
        if self._num_dataloaders > 1:
            step_kwargs["dataloader_idx"] = dataloader_idx

        return step_kwargs

    # NOTE(review): ``lru_cache`` on a method keys the cache on ``self`` and keeps a reference to the instance
    # until ``cache_clear`` is called (see ``teardown``) or the entry is evicted.
    @lru_cache(1)
    def _should_track_batch_outputs_for_epoch_end(self) -> bool:
        """Whether the batch outputs should be stored for later usage.

        That is the case when the lightning module overrides ``test_epoch_end`` (while testing) or
        ``validation_epoch_end`` (otherwise).
        """
        model = self.trainer.lightning_module
        if self.trainer.testing:
            return is_overridden("test_epoch_end", model)
        return is_overridden("validation_epoch_end", model)

    def teardown(self) -> None:
        """Clears the cached answer of ``_should_track_batch_outputs_for_epoch_end``."""
        # in case the model changes
        self._should_track_batch_outputs_for_epoch_end.cache_clear()
Exemple #10
0
class FitLoop(Loop[None]):
    """This Loop iterates over the epochs to run the training.

    Args:
        min_epochs: The minimum number of epochs
        max_epochs: The maximum number of epochs, can be set -1 to turn this limit off
    """
    def __init__(
        self,
        min_epochs: Optional[int] = 1,
        max_epochs: int = 1000,
    ) -> None:
        """Store the epoch limits and create the nested epoch loop and the epoch progress tracker.

        Raises:
            MisconfigurationException: If ``max_epochs`` is smaller than -1.
        """
        super().__init__()
        # zero is a legal value for max_epochs; `fit_loop.done` deals with it
        if max_epochs < -1:
            message = f"`max_epochs` must be a non-negative integer or -1. You passed in {max_epochs}."
            raise MisconfigurationException(message)

        self.max_epochs = max_epochs
        self.min_epochs = min_epochs
        self.epoch_loop = TrainingEpochLoop()
        self.epoch_progress = Progress()

        # bookkeeping that is refreshed while the loop runs
        self._is_fresh_start_epoch: bool = True
        self._outputs: _EPOCH_OUTPUTS_TYPE = []

    @property
    def current_epoch(self) -> int:
        """The ``completed`` counter of the current epoch progress."""
        current_progress = self.epoch_progress.current
        return current_progress.completed

    @current_epoch.setter
    def current_epoch(self, value: int) -> None:
        """Overwrite the ``completed`` counter of the current epoch progress."""
        current_progress = self.epoch_progress.current
        current_progress.completed = value

    @property
    def global_step(self) -> int:
        """The global step, as tracked by the epoch loop."""
        epoch_loop = self.epoch_loop
        return epoch_loop.global_step

    @global_step.setter
    def global_step(self, value: int) -> None:
        """Forward the new global step to the epoch loop."""
        epoch_loop = self.epoch_loop
        epoch_loop.global_step = value

    @property
    def total_batch_idx(self) -> int:
        """The current batch index counted across epochs, as tracked by the epoch loop."""
        epoch_loop = self.epoch_loop
        return epoch_loop.total_batch_idx

    @property
    def batch_idx(self) -> int:
        """The current batch index within this epoch, as tracked by the epoch loop."""
        epoch_loop = self.epoch_loop
        return epoch_loop.batch_idx

    @property
    def split_idx(self) -> int:
        """The index of the current batch split (within the current batch) for bptt."""
        batch_loop = self.epoch_loop.batch_loop
        return batch_loop.split_idx

    @property
    def min_steps(self) -> Optional[int]:
        """The minimum number of steps to run, as stored on the epoch loop."""
        # TODO(@justusschock): Why aren't we using the attribute in this class?
        epoch_loop = self.epoch_loop
        return epoch_loop.min_steps

    @min_steps.setter
    def min_steps(self, value: Optional[int]) -> None:
        """Forward the new minimum number of steps to the epoch loop."""
        # TODO(@awaelchli): This setter is required by debugging connector (fast dev run), should be avoided
        epoch_loop = self.epoch_loop
        epoch_loop.min_steps = value

    @property
    def max_steps(self) -> int:
        """The maximum number of steps to run, as stored on the epoch loop."""
        epoch_loop = self.epoch_loop
        return epoch_loop.max_steps

    @max_steps.setter
    def max_steps(self, value: Optional[int]) -> None:
        """Sets the maximum number of steps (forwards to epoch_loop).

        Args:
            value: the new step limit, where ``-1`` means infinite steps. ``None`` is still accepted for
                backward compatibility: it triggers a deprecation message and is stored as ``-1``.

        Raises:
            MisconfigurationException: If ``value`` is smaller than -1.
        """
        # TODO(@awaelchli): This setter is required by debugging connector (fast dev run), should be avoided
        if value is None:
            rank_zero_deprecation(
                "Setting `max_steps = None` is deprecated in v1.5 and will no longer be supported in v1.7."
                " Use `max_steps = -1` instead.")
            value = -1
        elif value < -1:
            raise MisconfigurationException(
                f"`max_steps` must be a non-negative integer or -1 (infinite steps). You passed in {value}."
            )
        self.epoch_loop.max_steps = value

    @property
    def running_loss(self) -> TensorRunningAccum:
        """The running loss kept by the batch loop of the epoch loop."""
        batch_loop = self.epoch_loop.batch_loop
        return batch_loop.running_loss

    @property
    def _skip_backward(self) -> bool:
        """Whether the loop will skip backward during automatic optimization (read from the optimizer loop)."""
        optimizer_loop = self.epoch_loop.batch_loop.optimizer_loop
        return optimizer_loop._skip_backward

    @_skip_backward.setter
    def _skip_backward(self, value: bool) -> None:
        """Set on the optimizer loop whether backward is skipped during automatic optimization."""
        optimizer_loop = self.epoch_loop.batch_loop.optimizer_loop
        optimizer_loop._skip_backward = value

    @property
    def _results(self) -> _ResultCollection:
        """The result collection that matches the trainer's current state.

        While training this is the epoch loop's collection, while validating it is the one of the epoch
        loop's validation loop.

        Raises:
            RuntimeError: If the trainer is neither training nor validating.
        """
        trainer = self.trainer
        if trainer.training:
            return self.epoch_loop._results
        if trainer.validating:
            return self.epoch_loop.val_loop._results
        message = "`FitLoop._results` property isn't defined. Accessed outside of scope"
        raise RuntimeError(message)

    @property
    def done(self) -> bool:
        """Decide whether the loop should be left.

        The loop is done when the maximum number of steps or epochs is reached, when the trainer asked to
        stop (e.g. through early stopping) and the minimum epochs/steps are met, or when there are no
        training batches.

        Note that reading this property also writes ``trainer.should_stop``: a stop request that cannot be
        honoured yet is cleared.
        """
        # TODO(@awaelchli): Move track steps inside training loop and move part of these condition inside training loop
        reached_max_steps = _is_max_limit_reached(self.global_step, self.max_steps)
        reached_max_epochs = _is_max_limit_reached(self.current_epoch, self.max_epochs)

        honour_stop_request = False
        if self.trainer.should_stop:
            # early stopping: a falsy minimum counts as already met
            if self.min_epochs:
                met_min_epochs = self.current_epoch >= self.min_epochs
            else:
                met_min_epochs = True
            if self.min_steps:
                met_min_steps = self.global_step >= self.min_steps
            else:
                met_min_steps = True

            if met_min_epochs and met_min_steps:
                honour_stop_request = True
            else:
                log.info(
                    "Trainer was signaled to stop but required minimum epochs"
                    f" ({self.min_epochs}) or minimum steps ({self.min_steps}) has"
                    " not been met. Training will continue...")
        self.trainer.should_stop = honour_stop_request

        return (reached_max_steps or honour_stop_request or reached_max_epochs
                or self.trainer.num_training_batches == 0)

    @property
    def skip(self) -> bool:
        """Whether training should be skipped, returning immediately from the call to :meth:`run`."""
        is_done = self.done
        # `trainer.num_training_batches` depends on the `train_dataloader`, which is not called before
        # `on_run_start`; that is why `limit_train_batches` is checked here instead
        return is_done or self.trainer.limit_train_batches == 0

    def connect(self, epoch_loop: TrainingEpochLoop) -> None:  # type: ignore[override]
        """Replace the training epoch loop run by this fit loop with ``epoch_loop``."""
        self.epoch_loop = epoch_loop

    def reset(self) -> None:
        """Reset the internal state of this loop; only a restart requires any work."""
        if not self.restarting:
            return
        self.epoch_progress.reset_on_restart()

    def on_run_start(self) -> None:  # type: ignore[override]
        """Reload the train/val dataloaders, move the results to the module's device and fire
        ``on_train_start`` on the callbacks, the lightning module and the strategy (in that order)."""
        trainer = self.trainer

        # reset train dataloader and val dataloader
        trainer.reset_train_val_dataloaders(trainer.lightning_module)
        self._is_fresh_start_epoch = True

        module_device = trainer.lightning_module.device
        self._results.to(device=module_device)

        hook_name = "on_train_start"
        trainer._call_callback_hooks(hook_name)
        trainer._call_lightning_module_hook(hook_name)
        trainer._call_strategy_hook(hook_name)

    def on_advance_start(self) -> None:  # type: ignore[override]
        """Get everything ready for the next epoch.

        The train dataloader is reloaded when required, the sampler is told about the new epoch, the gradient
        accumulation state is refreshed and finally the ``on_epoch_start`` and ``on_train_epoch_start`` hooks
        are called.
        """
        trainer = self.trainer
        model = trainer.lightning_module

        # the train dataloader is never reloaded on the first epoch after `on_run_start`
        if not self._is_fresh_start_epoch:
            if trainer._data_connector._should_reload_train_dl:
                log.detail(f"{self.__class__.__name__}: resetting train dataloader")
                trainer.reset_train_dataloader(model)
        self._is_fresh_start_epoch = False

        # outputs are not accumulated between epochs, hence cleared here rather than in `reset`
        self._outputs = []

        train_dataloader = trainer.train_dataloader
        if train_dataloader is not None:
            sampler_set_epoch = getattr(train_dataloader.sampler, "set_epoch", None)
            if callable(sampler_set_epoch):
                # set seed for distributed sampler (enables shuffling for each epoch)
                sampler_set_epoch(self.current_epoch)

        # changing gradient according accumulation_scheduler
        trainer.accumulation_scheduler.on_train_epoch_start(trainer, trainer.lightning_module)

        # stores accumulated grad fractions per batch
        accumulated_loss = self.epoch_loop.batch_loop.accumulated_loss
        accumulated_loss.reset(window_length=trainer.accumulate_grad_batches)

        self.epoch_progress.increment_ready()

        trainer.logger_connector.on_epoch_start()

        # for each hook the callbacks run first, then the lightning module
        for hook_name in ("on_epoch_start", "on_train_epoch_start"):
            trainer._call_callback_hooks(hook_name)
            trainer._call_lightning_module_hook(hook_name)

        self.epoch_progress.increment_started()

    def advance(self) -> None:  # type: ignore[override]
        """Run one full epoch through the inner epoch loop.

        The train dataloader is passed through the strategy's
        ``process_dataloader`` and the data connector's
        ``get_profiled_dataloader``. The epoch loop then runs on the result
        inside the ``"run_training_epoch"`` profiler section, and its return
        value is stored in ``self._outputs``.
        """
        log.detail(f"{self.__class__.__name__}: advancing loop")

        trainer = self.trainer
        train_loader = trainer.train_dataloader
        assert train_loader is not None

        processed_loader = trainer.strategy.process_dataloader(train_loader)
        data_fetcher = trainer._data_connector.get_profiled_dataloader(
            processed_loader, 0)

        with trainer.profiler.profile("run_training_epoch"):
            self._outputs = self.epoch_loop.run(data_fetcher)

    def on_advance_end(self) -> None:
        """Run the end-of-epoch hooks, scheduler update and epoch-level logging.

        Statement order matters here. In sequence this method:

        1. tells the logger connector that the end of the epoch was reached;
        2. calls ``training_epoch_end`` with the prepared outputs, but only when
           the module overrides that hook and ``self._outputs`` is non-empty;
        3. clears ``self._outputs`` and marks the epoch as processed;
        4. calls the ``on_train_epoch_end`` and then the ``on_epoch_end`` hooks
           (callbacks before the lightning module in both cases);
        5. asks the epoch loop to update the ``"epoch"`` lr schedulers when
           ``_num_ready_batches_reached()`` is true;
        6. marks the epoch as completed and logs the epoch metrics with
           ``global_step`` temporarily decremented by one;
        7. calls ``trainer._exit_gracefully_on_signal()``.

        Raises:
            MisconfigurationException: If ``training_epoch_end`` returns
                anything other than ``None``.
        """
        # inform logger the batch loop has finished
        self.trainer.logger_connector.epoch_end_reached()

        # get the model and call model.training_epoch_end
        # (skipped when the hook is not overridden or when no outputs were collected)
        model = self.trainer.lightning_module
        if is_overridden("training_epoch_end", model) and self._outputs:
            epoch_end_outputs = self.epoch_loop._prepare_outputs_training_epoch_end(
                self._outputs,
                automatic=model.automatic_optimization,
                num_optimizers=len(self.trainer.optimizers),
            )
            # run lightning module hook training_epoch_end
            # refresh the result for custom logging at the epoch level
            epoch_end_outputs = self.trainer._call_lightning_module_hook(
                "training_epoch_end", epoch_end_outputs)
            # the hook must not return a value; a non-None return is treated as a user error
            if epoch_end_outputs is not None:
                raise MisconfigurationException(
                    "`training_epoch_end` expects a return of None. "
                    "HINT: remove the return statement in `training_epoch_end`."
                )
        # free memory
        self._outputs = []

        self.epoch_progress.increment_processed()

        # call train epoch end hooks
        self.trainer._call_callback_hooks("on_train_epoch_end")
        self.trainer._call_lightning_module_hook("on_train_epoch_end")

        # then the generic epoch end hooks
        self.trainer._call_callback_hooks("on_epoch_end")
        self.trainer._call_lightning_module_hook("on_epoch_end")

        self.trainer.logger_connector.on_epoch_end()

        # update the epoch-interval lr schedulers only when the epoch loop reports
        # that the number of ready batches was reached
        if self.epoch_loop._num_ready_batches_reached():
            self.epoch_loop.update_lr_schedulers(
                "epoch", update_plateau_schedulers=True)

        self.epoch_progress.increment_completed()

        # the global step is manually decreased here due to backwards compatibility with existing loggers
        # as they expect that the same step is used when logging epoch end metrics even when the batch loop has
        # finished. this means the attribute does not exactly track the number of optimizer steps applied.
        # TODO(@carmocca): deprecate and rename so users don't get confused
        # NOTE(review): if `update_train_epoch_metrics` raises, `global_step` stays decremented - confirm
        # whether that matters for callers that handle the exception
        self.global_step -= 1
        # log epoch metrics
        self.trainer.logger_connector.update_train_epoch_metrics()
        self.global_step += 1

        # if fault tolerant is enabled and process has been notified, exit.
        self.trainer._exit_gracefully_on_signal()

    def on_run_end(self) -> None:
        """Wrap up the training run and fire the ``on_train_end`` hooks.

        ``current_epoch`` is moved back by one (never below zero) before the
        hooks run. ``on_train_end`` is then dispatched to the callbacks, the
        lightning module and the strategy, in that order, and finally
        ``strategy.on_train_end()`` is called directly.
        """
        log.detail(f"{self.__class__.__name__}: train run ended")

        # NOTE: the current_epoch is already incremented when we get here.
        # Lightning today does not increment the current epoch at the last epoch run in Trainer.fit,
        # so the counter is stepped back (clamped at zero) to simulate that behavior.
        # TODO: must be fixed by https://github.com/PyTorchLightning/pytorch-lightning/issues/5007
        previous_epoch = self.current_epoch - 1
        self.current_epoch = max(previous_epoch, 0)

        # hooks: callbacks, then the module, then the strategy
        trainer = self.trainer
        hook_name = "on_train_end"
        trainer._call_callback_hooks(hook_name)
        trainer._call_lightning_module_hook(hook_name)
        trainer._call_strategy_hook(hook_name)

        # give accelerators a chance to finish
        # NOTE(review): the strategy may already have received `on_train_end` through
        # `_call_strategy_hook` above - confirm this second, direct call is intended
        trainer.strategy.on_train_end()

    def teardown(self) -> None:
        self.epoch_loop.teardown()

    def _should_accumulate(self) -> bool:
        """Report whether gradients should be accumulated, as decided by the epoch loop."""
        accumulate = self.epoch_loop._should_accumulate()
        return accumulate