def __init__(self, label_epoch_frequency: int, inference_iteration: int = 2, should_reset_weights: bool = True):
    """The `ActiveLearningLoop` describes the following training procedure. This loop is connected with the
    `ActiveLearningTrainer`.

    Example::

        while unlabelled data or budget criteria not reached:

            if labelled data:
                trainer.fit(model, labelled data)

            if unlabelled data:
                predictions = trainer.predict(model, unlabelled data)
                uncertainties = heuristic(predictions)
                request labelling for the samples with the highest uncertainties under a given budget

    Args:
        label_epoch_frequency: Number of epochs to train on before requesting labelling.
        inference_iteration: Number of inference passes to perform to compute uncertainty.
        should_reset_weights: Stored as ``self.should_reset_weights``; presumably controls whether the model
            weights are reset between labelling rounds (TODO confirm against the loop's usage).
    """
    super().__init__()
    self.label_epoch_frequency = label_epoch_frequency
    self.inference_iteration = inference_iteration
    self.should_reset_weights = should_reset_weights
    # the wrapped fit loop; left unset here
    self.fit_loop: Optional[FitLoop] = None
    self.progress = Progress()
    # state-dict / module slots, initialised empty here (presumably filled when the loop is connected)
    self._model_state_dict: Optional[Dict[str, torch.Tensor]] = None
    self._datamodule_state_dict: Optional[Dict[str, Any]] = None
    self._lightning_module: Optional[flash.Task] = None
def test_progress_increment_sequence():
    """Each increment step must be reflected identically in the total and current trackers."""
    progress = Progress()
    seen = {}
    for stage in ("ready", "started", "processed", "completed"):
        getattr(progress, f"increment_{stage}")()
        # every stage reached so far has been incremented exactly once
        seen[stage] = 1
        expected = ProcessedTracker(**seen)
        assert progress.total == expected
        assert progress.current == expected
def __init__(self) -> None:
    """Create an epoch loop with no dataloader attached and empty batch progress."""
    super().__init__()

    # dataloader-related state; unset until assigned elsewhere
    self.dataloader: Optional[Iterator] = None
    self._dl_max_batches: Optional[int] = None
    self._num_dataloaders: Optional[int] = None

    # container for step outputs (filled elsewhere) and per-batch progress tracking
    self.outputs: List[STEP_OUTPUT] = []
    self.batch_progress = Progress()
def __init__(self, min_epochs: Optional[int] = None, max_epochs: Optional[int] = None):
    """Store the epoch bounds and start a fresh epoch progress tracker.

    Args:
        min_epochs: lower bound on the number of epochs (may be ``None``).
        max_epochs: upper bound on the number of epochs (may be ``None``).
    """
    super().__init__()
    self.max_epochs = max_epochs
    self.min_epochs = min_epochs

    # the child epoch loop is not created here
    self.epoch_loop: Optional[TrainingEpochLoop] = None
    self.epoch_progress = Progress()
def __init__(self, min_epochs: Optional[int] = None, max_epochs: Optional[int] = None):
    """Store the epoch bounds, set up epoch progress and an empty dataloader-state cache.

    Args:
        min_epochs: lower bound on the number of epochs (may be ``None``).
        max_epochs: upper bound on the number of epochs (may be ``None``).
    """
    super().__init__()
    self.max_epochs = max_epochs
    self.min_epochs = min_epochs

    # no child epoch loop yet
    self.epoch_loop: Optional[TrainingEpochLoop] = None
    self.epoch_progress = Progress()

    # holds restored dataloader state until the dataloader objects themselves exist
    self._dataloader_state_dict: Dict[str, Any] = {}
def __init__(self) -> None:
    """Initialise an empty prediction epoch loop."""
    super().__init__()

    # public prediction state
    self.return_predictions = False
    self.predictions: List[Any] = []
    self.current_batch_indices: List[int] = []
    self.batch_progress = Progress()

    # private bookkeeping
    self._dl_max_batches = 0
    self._num_dataloaders = 0
    self._warning_cache = WarningCache()
    self._seen_batch_indices: List[List[int]] = []
def __init__(self, min_epochs: Optional[int] = None, max_epochs: Optional[int] = None):
    """Validate and store the epoch bounds, and set up epoch progress tracking.

    Args:
        min_epochs: lower bound on the number of epochs (may be ``None``).
        max_epochs: upper bound on the number of epochs; ``None``, ``-1`` or any value >= 0.

    Raises:
        MisconfigurationException: if ``max_epochs`` is smaller than -1.
    """
    super().__init__()
    # Allow max_epochs to be zero, since this will be handled by fit_loop.done.
    # The message therefore says "non-negative" (0 is accepted), matching the check below.
    if max_epochs is not None and max_epochs < -1:
        raise MisconfigurationException(
            f"`max_epochs` must be a non-negative integer or -1. You passed in {max_epochs}."
        )

    self.max_epochs = max_epochs
    self.min_epochs = min_epochs
    self.epoch_loop: Optional[TrainingEpochLoop] = None
    self.epoch_progress = Progress()

    # caches the loaded dataloader state until dataloader objects are available
    self._dataloader_state_dict: Dict[str, Any] = {}
def __init__(self, min_steps: int, max_steps: int):
    """Store the step bounds and create progress trackers for one training epoch.

    Args:
        min_steps: minimum number of steps.
        max_steps: maximum number of steps.
    """
    super().__init__()

    # step bounds and the step counter (starts at zero)
    self.min_steps: int = min_steps
    self.max_steps: int = max_steps
    self.global_step: int = 0

    # iterable datasets may not expose a length, so the last batch is flagged by hand
    self.is_last_batch: Optional[bool] = None

    # progress tracking
    self.batch_progress = Progress()
    self.scheduler_progress = SchedulerProgress()

    # child loops, left unset here
    self.batch_loop: Optional[TrainingBatchLoop] = None
    self.val_loop: Optional["loops.EvaluationLoop"] = None

    self._results = ResultCollection(training=True)
    self._epoch_output: Optional[List[List[STEP_OUTPUT]]] = None
def test_progress_increment(attr):
    """A single ``increment_<attr>`` call bumps only that counter in both trackers."""
    progress = Progress()
    increment = getattr(progress, f"increment_{attr}")
    increment()

    expected = ProcessedTracker(**{attr: 1})
    for tracker in (progress.total, progress.current):
        assert tracker == expected
def __init__(
    self,
    min_epochs: Optional[int] = 1,
    max_epochs: int = 1000,
) -> None:
    """Validate and store the epoch bounds.

    Args:
        min_epochs: lower bound on the number of epochs.
        max_epochs: upper bound on the number of epochs; must be >= -1.

    Raises:
        MisconfigurationException: if ``max_epochs`` is below -1.
    """
    super().__init__()

    # zero is deliberately accepted: fit_loop.done deals with it
    if max_epochs < -1:
        raise MisconfigurationException(
            f"`max_epochs` must be a non-negative integer or -1. You passed in {max_epochs}."
        )

    self.max_epochs = max_epochs
    self.min_epochs = min_epochs
    self.epoch_loop: Optional[TrainingEpochLoop] = None
    self.epoch_progress = Progress()
    self._is_fresh_start_epoch: bool = True
def __init__(self) -> None:
    """Set up optimizer-step progress and the initial state of a manual-optimization run."""
    super().__init__()

    # Manual optimization tracks neither lr schedulers nor optimizer frequencies, so a plain
    # ready/completed progress is used here rather than `OptimizationProgress`.
    self.optim_step_progress = Progress.from_defaults(ReadyCompletedTracker)

    self._done: bool = False
    self._hiddens: Optional[Any] = None
    self._output: _OUTPUTS_TYPE = {}
def __init__(self, min_steps: int, max_steps: int):
    """Store the step bounds and create the per-epoch training state.

    Args:
        min_steps: minimum number of steps.
        max_steps: maximum number of steps.
    """
    super().__init__()

    self.min_steps: int = min_steps
    self.max_steps: int = max_steps
    self.global_step: int = 0

    # batch index counted across every epoch, not reset per epoch
    self.total_batch_idx: int = 0
    self.is_last_batch: Optional[bool] = None

    self.batch_progress = Progress()
    self.scheduler_progress = SchedulerProgress()

    # child loops, left unset here
    self.batch_loop: Optional[TrainingBatchLoop] = None
    self.val_loop: Optional["loops.EvaluationLoop"] = None

    self._results = ResultCollection(training=True)
    self._dataloader_idx: Optional[int] = None
    self._warning_cache: WarningCache = WarningCache()
    self._epoch_output: Optional[List[List[STEP_OUTPUT]]] = None
def __init__(
    self,
    min_epochs: int = 0,
    max_epochs: int = 1000,
) -> None:
    """Validate the epoch bounds and build the owned training epoch loop.

    Args:
        min_epochs: lower bound on the number of epochs.
        max_epochs: upper bound on the number of epochs; must be >= -1.

    Raises:
        MisconfigurationException: if ``max_epochs`` is below -1.
    """
    super().__init__()

    # a value of zero passes on purpose: fit_loop.done handles that case
    if max_epochs < -1:
        raise MisconfigurationException(
            f"`max_epochs` must be a non-negative integer or -1. You passed in {max_epochs}."
        )

    self.max_epochs = max_epochs
    self.min_epochs = min_epochs
    # unlike older variants, this loop creates its epoch loop eagerly
    self.epoch_loop = TrainingEpochLoop()
    self.epoch_progress = Progress()

    self._is_fresh_start_epoch: bool = True
    self._outputs: _EPOCH_OUTPUTS_TYPE = []
    self._data_fetcher: Optional[AbstractDataFetcher] = None
def __init__(self, min_steps: int, max_steps: int):
    """Validate and store the step bounds and create the per-epoch training state.

    Args:
        min_steps: minimum number of steps.
        max_steps: maximum number of steps; ``None``, ``-1`` or any value >= 0.

    Raises:
        MisconfigurationException: if ``max_steps`` is smaller than -1.
    """
    super().__init__()
    self.min_steps: int = min_steps

    # Zero is accepted by this check, so the message says "non-negative" rather than "positive".
    if max_steps is not None and max_steps < -1:
        raise MisconfigurationException(
            f"`max_steps` must be a non-negative integer or -1. You passed in {max_steps}."
        )
    self.max_steps: int = max_steps
    self.global_step: int = 0

    # manually tracking which is the last batch is necessary for iterable dataset support
    self.is_last_batch: Optional[bool] = None

    self.batch_progress = Progress()
    self.scheduler_progress = SchedulerProgress()

    self.batch_loop: Optional[TrainingBatchLoop] = None
    self.val_loop: Optional["loops.EvaluationLoop"] = None

    self._results = ResultCollection(training=True)
    self._epoch_output: Optional[List[List[STEP_OUTPUT]]] = None
def test_epoch_loop_progress_increment_sequence():
    """Walk a progress object through ready -> started -> processed -> completed and check both trackers."""
    progress = Progress()
    reached = {}
    for stage in ("ready", "started", "processed", "completed"):
        getattr(progress, f"increment_{stage}")()
        reached[stage] = 1
        expected = Tracker(**reached)
        assert progress.total == expected
        assert progress.current == expected
def test_loop_progress_increment_sequence():
    """Drive a ``LoopProgress`` through batch increments, an epoch completion and an epoch reset."""
    p = LoopProgress(batch=Progress(total=Tracker(started=None)))

    def check_batch(total, current):
        assert p.batch.total == total
        assert p.batch.current == current

    def check_epoch(total, current):
        assert p.epoch.total == total
        assert p.epoch.current == current

    p.batch.increment_ready()
    check_batch(Tracker(ready=1, started=None), Tracker(ready=1))

    p.batch.increment_started()
    check_batch(Tracker(ready=1, started=None), Tracker(ready=1))

    p.batch.increment_processed()
    check_batch(Tracker(ready=1, started=None, processed=1), Tracker(ready=1, processed=1))

    p.batch.increment_completed()
    check_batch(
        Tracker(ready=1, started=None, processed=1, completed=1),
        Tracker(ready=1, processed=1, completed=1),
    )
    check_epoch(Tracker(), Tracker())

    p.increment_epoch_completed()
    check_batch(Tracker(ready=1, started=None, processed=1, completed=1), Tracker())
    check_epoch(Tracker(completed=1), Tracker())

    p.batch.increment_ready()
    check_batch(Tracker(ready=2, started=None, processed=1, completed=1), Tracker(ready=1))
    check_epoch(Tracker(completed=1), Tracker())

    p.reset_on_epoch()
    check_batch(Tracker(ready=2, started=None, processed=1, completed=1), Tracker())
    check_epoch(Tracker(completed=1), Tracker())
def __init__(self, min_steps: int, max_steps: int):
    """Store the step bounds and create the per-epoch training state.

    Args:
        min_steps: minimum number of steps.
        max_steps: maximum number of steps.
    """
    super().__init__()

    self.min_steps: int = min_steps
    self.max_steps: int = max_steps
    self.global_step: int = 0

    # batch index counted over all epochs
    self.total_batch_idx: int = 0
    # index of the current chunk when a batch is split for truncated backprop through time
    self.split_idx: Optional[int] = None
    # batches seen in this run, updated right after batch_loop.run()
    # TODO: replace by progress tracking
    self.batches_seen: int = 0
    self.is_last_batch: Optional[bool] = None

    self.batch_progress = Progress()
    self.scheduler_progress = SchedulerProgress()

    # child loops, left unset here
    self.batch_loop: Optional[TrainingBatchLoop] = None
    self.val_loop: Optional["loops.EvaluationLoop"] = None

    self._results = ResultCollection(training=True)
    self._dataloader_idx: Optional[int] = None
    self._warning_cache: WarningCache = WarningCache()
    self._epoch_output: Optional[List[List[STEP_OUTPUT]]] = None
def test_progress_raises():
    """Mismatched tracker classes and unsupported increments must raise."""
    with pytest.raises(ValueError, match="instances should be of the same class"):
        Progress(ReadyCompletedTracker(), ProcessedTracker())

    progress = Progress(ReadyCompletedTracker(), ReadyCompletedTracker())
    unsupported = (
        ("increment_started", "ReadyCompletedTracker` doesn't have a `started` attribute"),
        ("increment_processed", "ReadyCompletedTracker` doesn't have a `processed` attribute"),
    )
    for method_name, pattern in unsupported:
        with pytest.raises(TypeError, match=pattern):
            getattr(progress, method_name)()
def _reset_fitting(self):
    """Put the trainer back into fitting mode, reconnect the module and restart epoch progress."""
    trainer = self.trainer
    trainer.state.fn = TrainerFn.FITTING
    trainer.training = True
    trainer.lightning_module.on_train_dataloader()

    self._connect(self._lightning_module)
    # a brand-new Progress discards any previously counted epochs
    self.fit_loop.epoch_progress = Progress()
class PredictionEpochLoop(Loop):
    """Loop performing prediction on arbitrary sequentially used dataloaders."""

    def __init__(self) -> None:
        super().__init__()
        self.return_predictions = False
        self.predictions: List[Any] = []
        self.current_batch_indices: List[int] = []
        self.batch_progress = Progress()

        self._dl_max_batches = 0
        self._num_dataloaders = 0
        self._warning_cache = WarningCache()
        self._seen_batch_indices: List[List[int]] = []

    @property
    def done(self) -> bool:
        """Ends prediction once the number of completed batches reaches the number of available batches."""
        return self.batch_progress.current.completed >= self._dl_max_batches

    @property
    def should_store_predictions(self) -> bool:
        """Whether the predictions should be stored for later usage (e.g. aggregation or returning).

        True if ``return_predictions`` is set or any prediction writer callback has an ``on_epoch`` interval.
        """
        any_pred = any(cb.interval.on_epoch for cb in self.trainer.prediction_writer_callbacks)
        return self.return_predictions or any_pred

    def connect(self, **kwargs: "Loop") -> None:
        """This loop has no child loops to connect.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError(f"{self.__class__.__name__} does not connect any child loops.")

    def reset(self) -> None:
        """Resets the loop's internal state: seen batch indices, predictions and per-run batch progress."""
        self._seen_batch_indices = []
        self.predictions = []
        self.batch_progress.reset_on_run()

    def on_run_start(  # type: ignore[override]
        self,
        dataloader_iter: Iterator,
        dataloader_idx: int,
        dl_max_batches: int,
        num_dataloaders: int,
    ) -> None:
        """Prepares the loop's internal state.

        Args:
            dataloader_iter: the iterator over the current dataloader
            dataloader_idx: the index of the current dataloader
            dl_max_batches: the maximum number of batches the current loader can produce
            num_dataloaders: the total number of dataloaders
        """
        void(dataloader_iter, dataloader_idx)
        self._dl_max_batches = dl_max_batches
        self._num_dataloaders = num_dataloaders
        # this call requires that `self.return_predictions` is set
        self._seen_batch_indices = self._get_batch_indices(dataloader_idx) if self.should_store_predictions else []

    def advance(  # type: ignore[override]
        self,
        dataloader_iter: Iterator,
        dataloader_idx: int,
        dl_max_batches: int,
        num_dataloaders: int,
    ) -> None:
        """Runs one prediction step.

        Args:
            dataloader_iter: the iterator over the current dataloader
            dataloader_idx: the index of the current dataloader
            dl_max_batches: the maximum number of batches the current loader can produce
            num_dataloaders: the total number of dataloaders

        Raises:
            StopIteration: if the fetched batch is ``None``.
        """
        action_name = f"[{self.__class__.__name__}].predict_dataloader_idx_{dataloader_idx}_next"
        with self.trainer.profiler.profile(action_name):
            batch_idx, batch = next(dataloader_iter)
        self._seen_batch_indices = self._get_batch_indices(dataloader_idx) if self.should_store_predictions else []
        # we need to truncate the list of batch indices due to prefetching in the dataloader and Lightning
        self._seen_batch_indices = self._seen_batch_indices[: (self.batch_progress.current.completed + 1)]

        if batch is None:
            raise StopIteration

        batch = self.trainer._call_strategy_hook("batch_to_device", batch, dataloader_idx=dataloader_idx)

        self.batch_progress.increment_ready()

        self._predict_step(batch, batch_idx, dataloader_idx)

    def on_run_end(self) -> Tuple[List[Any], List[List[int]]]:
        """Returns the predictions and the corresponding batch indices, then clears both to free memory."""
        predictions, all_batch_indices = self.predictions, self._seen_batch_indices
        self.predictions, self._seen_batch_indices = [], []  # free memory
        return predictions, all_batch_indices

    def _predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
        """Runs the actual predict step together with all the necessary bookkeeping and the hooks tied to the
        predict step.

        Args:
            batch: the current batch to run the prediction on
            batch_idx: the index of the current batch
            dataloader_idx: the index of the dataloader producing the current batch
        """
        # configure step_kwargs
        step_kwargs = self._build_kwargs(batch, batch_idx, dataloader_idx)

        # extract batch_indices and store them
        batch_indices = self._get_batch_indices(dataloader_idx)
        self.current_batch_indices = batch_indices[batch_idx] if batch_indices else []

        self.trainer._call_callback_hooks("on_predict_batch_start", batch, batch_idx, dataloader_idx)
        self.trainer._call_lightning_module_hook("on_predict_batch_start", batch, batch_idx, dataloader_idx)

        self.batch_progress.increment_started()

        predictions = self.trainer._call_strategy_hook("predict_step", *step_kwargs.values())

        self.batch_progress.increment_processed()

        if predictions is None:
            self._warning_cache.warn("predict returned None if it was on purpose, ignore this warning...")

        self.trainer._call_callback_hooks("on_predict_batch_end", predictions, batch, batch_idx, dataloader_idx)
        self.trainer._call_lightning_module_hook(
            "on_predict_batch_end", predictions, batch, batch_idx, dataloader_idx
        )

        self.batch_progress.increment_completed()

        if self.should_store_predictions:
            # stored predictions are moved to CPU
            self.predictions.append(move_data_to_device(predictions, torch.device("cpu")))

    def _build_kwargs(self, batch: Any, batch_idx: int, dataloader_idx: int) -> Dict[str, Any]:
        """Assembles the keyword arguments for the ``predict_step``.

        Args:
            batch: the current batch to run the prediction on
            batch_idx: the index of the current batch
            dataloader_idx: the index of the dataloader producing the current batch

        Returns:
            the dictionary containing all the keyword arguments for the predict step; ``dataloader_idx`` is only
            included when more than one dataloader is used
        """
        step_kwargs = OrderedDict([("batch", batch), ("batch_idx", batch_idx)])
        if self._num_dataloaders > 1:
            step_kwargs["dataloader_idx"] = dataloader_idx
        return step_kwargs

    def _get_batch_indices(self, dataloader_idx: int) -> List[List[int]]:
        """Returns a reference to the seen batch indices if the dataloader has a batch sampler wrapped by our
        :class:`~pytorch_lightning.overrides.distributed.IndexBatchSamplerWrapper`, otherwise an empty list."""
        # the batch_sampler is not defined in case of CombinedDataLoaders
        batch_sampler = getattr(
            self.trainer.predict_dataloaders[dataloader_idx],  # type: ignore[has-type]
            "batch_sampler",
            None,
        )
        if isinstance(batch_sampler, IndexBatchSamplerWrapper):
            return batch_sampler.seen_batch_indices

        # use the loop's own warning cache (as `_predict_step` does) instead of an unscoped global name
        self._warning_cache.warn("Lightning couldn't infer the indices fetched for your dataloader.")
        return []
def test_progress_from_defaults():
    """``from_defaults`` builds identical total/current trackers from the given kwargs."""
    built = Progress.from_defaults(StartedTracker, completed=5)
    reference = Progress(
        total=StartedTracker(completed=5),
        current=StartedTracker(completed=5),
    )
    assert built == reference
def test_deepcopy():
    """Progress-tracking objects must survive ``copy.deepcopy``."""
    for factory in (BaseProgress, Progress, ProcessedTracker):
        deepcopy(factory())
class FitLoop(Loop):
    """
    This Loop iterates over the epochs to run the training.

    Args:
        min_epochs: The minimum number of epochs
        max_epochs: The maximum number of epochs
    """

    def __init__(self, min_epochs: Optional[int] = None, max_epochs: Optional[int] = None):
        super().__init__()
        self.max_epochs = max_epochs
        self.min_epochs = min_epochs
        self.epoch_loop: Optional[TrainingEpochLoop] = None
        self.epoch_progress = Progress()

    @property
    def current_epoch(self) -> int:
        """Return the current epoch (the number of completed epochs in the current progress)."""
        return self.epoch_progress.current.completed

    @current_epoch.setter
    def current_epoch(self, value: int) -> None:
        """Setter for the current epoch"""
        self.epoch_progress.current.completed = value

    @property
    def global_step(self) -> int:
        """Returns the global step (forwarded from epoch_loop)"""
        return self.epoch_loop.global_step

    @global_step.setter
    def global_step(self, value: int) -> None:
        """Sets the global step (forwards to epoch_loop)"""
        self.epoch_loop.global_step = value

    @property
    def total_batch_idx(self) -> int:
        """Returns the total number of batches already run (across all epochs)"""
        return self.epoch_loop.total_batch_idx

    @property
    def batch_idx(self) -> int:
        """Returns the number of batches already run within this epoch"""
        return self.epoch_loop.batch_progress.current.ready - 1

    @property
    def split_idx(self) -> int:
        """Returns the index of the current batch split (within the current batch) for bptt"""
        return self.epoch_loop.split_idx

    @property
    def min_steps(self) -> int:
        """Returns the minimum number of steps to run"""
        # TODO(@justusschock): Why aren't we using the attribute in this class?
        return self.epoch_loop.min_steps

    @min_steps.setter
    def min_steps(self, value: int) -> None:
        """Sets the minimum number of steps (forwards to epoch_loop)"""
        # TODO(@awaelchli): This setter is required by debugging connector (fast dev run), should be avoided
        self.epoch_loop.min_steps = value

    @property
    def max_steps(self) -> int:
        """Returns the maximum number of steps to run"""
        return self.epoch_loop.max_steps

    @max_steps.setter
    def max_steps(self, value: int) -> None:
        """Sets the maximum number of steps (forwards to epoch_loop)"""
        # TODO(@awaelchli): This setter is required by debugging connector (fast dev run), should be avoided
        self.epoch_loop.max_steps = value

    @property
    def running_loss(self) -> TensorRunningAccum:
        """Returns the running loss"""
        return self.epoch_loop.batch_loop.running_loss

    @property
    def _skip_backward(self) -> bool:
        """Determines whether the loop will skip backward during automatic optimization."""
        return self.epoch_loop.batch_loop._skip_backward

    @_skip_backward.setter
    def _skip_backward(self, value: bool) -> None:
        """Determines whether the loop will skip backward during automatic optimization."""
        self.epoch_loop.batch_loop._skip_backward = value

    @property
    def _results(self) -> ResultCollection:
        """The result collection of the currently active stage.

        Raises:
            RuntimeError: if the trainer is neither training nor validating.
        """
        if self.trainer.training:
            return self.epoch_loop._results
        if self.trainer.validating:
            return self.epoch_loop.val_loop._results
        raise RuntimeError("`FitLoop._results` property isn't defined. Accessed outside of scope")

    @property
    def done(self) -> bool:
        """Evaluates when to leave the loop.

        Returns True if trainer.should_stop was set (e.g. by early stopping) and the minimum epochs/steps are met,
        or if the maximum number of steps or epochs is reached.

        Note: this property has a side effect. It overwrites ``trainer.should_stop`` with whether the stop
        request was honoured, so an early stop request that arrives before the minimums are met is discarded.
        """
        # TODO(@awaelchli): Move track steps inside training loop and move part of these condition inside training loop
        stop_steps = self.max_steps is not None and self.global_step >= self.max_steps
        stop_epochs = self.max_epochs is not None and self.current_epoch >= self.max_epochs

        should_stop = False
        if self.trainer.should_stop:
            # early stopping
            met_min_epochs = self.current_epoch >= self.min_epochs if self.min_epochs else True
            met_min_steps = self.global_step >= self.min_steps if self.min_steps else True
            if met_min_epochs and met_min_steps:
                should_stop = True
            else:
                log.info(
                    'Trainer was signaled to stop but required minimum epochs'
                    f' ({self.min_epochs}) or minimum steps ({self.min_steps}) has'
                    ' not been met. Training will continue...'
                )
        self.trainer.should_stop = should_stop

        return stop_steps or should_stop or stop_epochs

    @property
    def skip(self) -> bool:
        """Whether we should skip the training and immediately return from the call to :meth:`run`.

        True when :attr:`done` already holds or the trainer has no training batches.
        """
        return self.done or self.trainer.num_training_batches == 0

    def connect(self, epoch_loop: TrainingEpochLoop) -> None:
        """Connects a training epoch loop to this fit loop."""
        self.epoch_loop = epoch_loop

    def reset(self) -> None:
        """Resets the internal state of this loop (nothing to reset here)."""

    def on_run_start(self) -> None:
        """Moves the results to the module's device and calls the ``on_train_start`` hook."""
        self._results.to(device=self.trainer.lightning_module.device)
        self.trainer.call_hook("on_train_start")

    def on_advance_start(self) -> None:
        """Prepares the dataloader for training, updates the accumulation schedule and marks the epoch ready."""
        model = self.trainer.lightning_module

        # reset train dataloader
        if self.current_epoch != 0 and self.trainer._should_reload_dl_epoch:
            self.trainer.reset_train_dataloader(model)

        # TODO: specify the possible exception
        # NOTE(review): this is a deliberate best-effort call; any failure (e.g. a sampler without
        # `set_epoch`) is silently ignored.
        with suppress(Exception):
            # set seed for distributed sampler (enables shuffling for each epoch)
            self.trainer.train_dataloader.sampler.set_epoch(self.current_epoch)

        # change gradient accumulation according to the accumulation_scheduler
        self.trainer.accumulation_scheduler.on_train_epoch_start(self.trainer, self.trainer.lightning_module)

        # stores accumulated grad fractions per batch
        self.epoch_loop.batch_loop.accumulated_loss = TensorRunningAccum(
            window_length=self.trainer.accumulate_grad_batches
        )

        self.epoch_progress.increment_ready()

    def advance(self) -> None:
        """Runs one whole epoch."""
        train_dataloader = self.trainer.accelerator.process_dataloader(self.trainer.train_dataloader)
        train_dataloader = self.trainer.data_connector.get_profiled_train_dataloader(train_dataloader)

        with self.trainer.profiler.profile("run_training_epoch"):
            # run train epoch
            epoch_output = self.epoch_loop.run(train_dataloader)

            if epoch_output is None:
                return

            # the global step is manually decreased here due to backwards compatibility with existing loggers
            # as they expect that the same step is used when logging epoch end metrics even when the batch loop has
            # finished. this means the attribute does not exactly track the number of optimizer steps applied.
            # TODO(@carmocca): deprecate and rename so users don't get confused
            self.global_step -= 1
            # log epoch metrics
            self.trainer.logger_connector.update_train_epoch_metrics()
            self.global_step += 1

    def on_advance_end(self) -> None:
        """Marks the current epoch as completed."""
        self.epoch_progress.increment_completed()

    def on_run_end(self) -> None:
        """Calls the ``on_train_end`` hook, finalizes the logger and lets the profiler/accelerator wrap up."""
        # NOTE: the iteration_count/current_epoch is already incremented
        # Lightning today does not increment the current epoch at the last epoch run in Trainer.fit
        # To simulate that current behavior, we decrement here.
        # TODO: must be fixed by https://github.com/PyTorchLightning/pytorch-lightning/issues/5007
        self.current_epoch -= 1

        # hook
        self.trainer.call_hook("on_train_end")

        # todo: TPU 8 cores hangs in flush with TensorBoard. Might do for all loggers.
        # It might be related to xla tensors blocked when moving the cpu
        # kill loggers
        if self.trainer.logger is not None:
            self.trainer.logger.finalize("success")

        # summarize profile results
        self.trainer.profiler.describe()

        # give accelerators a chance to finish
        self.trainer.accelerator.on_train_end()

    def should_accumulate(self) -> bool:
        """Whether the gradients should be accumulated"""
        return self.epoch_loop.batch_loop.should_accumulate()

    def teardown(self) -> None:
        """Tears down the connected epoch loop."""
        self.epoch_loop.teardown()
class TrainingEpochLoop(loops.Loop): """ Runs over all batches in a dataloader (one epoch). Args: min_steps: The minimum number of steps (batches) to process max_steps: The maximum number of steps (batches) to process """ def __init__(self, min_steps: int, max_steps: int): super().__init__() self.min_steps: int = min_steps self.max_steps: int = max_steps self.global_step: int = 0 # manually tracking which is the last batch is necessary for iterable dataset support self.is_last_batch: Optional[bool] = None self.batch_progress = Progress() self.scheduler_progress = SchedulerProgress() self.batch_loop: Optional[TrainingBatchLoop] = None self.val_loop: Optional["loops.EvaluationLoop"] = None self._results = ResultCollection(training=True) self._epoch_output: Optional[List[List[STEP_OUTPUT]]] = None @property def total_batch_idx(self) -> int: """Returns the current batch index (across epochs)""" # use `ready` instead of `completed` in case this is accessed after `completed` has been increased # but before the next `ready` increase return self.batch_progress.total.ready - 1 @property def batch_idx(self) -> int: """Returns the current batch index (within this epoch)""" # use `ready` instead of `completed` in case this is accessed after `completed` has been increased # but before the next `ready` increase return self.batch_progress.current.ready - 1 @property def done(self) -> bool: """Returns whether the training should be stopped. The criteria are that the number of steps reached the max steps, the last batch is reached or the trainer signals to stop (e.g. by early stopping). 
""" max_steps_reached = self.max_steps is not None and self.global_step >= self.max_steps return max_steps_reached or self.trainer.should_stop or self._num_training_batches_reached( self.is_last_batch) def connect( self, batch_loop: TrainingBatchLoop = None, val_loop: Optional["loops.EvaluationLoop"] = None, ) -> None: """Optionally connect a custom batch or validation loop to this training epoch loop.""" if batch_loop is not None: self.batch_loop = batch_loop if val_loop is not None: self.val_loop = val_loop def reset(self) -> None: """Resets the internal state of the loop for a new run""" self.is_last_batch = False # track epoch output self._epoch_output = [[] for _ in range( self.batch_loop.num_active_optimizers(self.total_batch_idx))] if not self.restarting: self.batch_progress.current.reset() self.scheduler_progress.current.reset() self.batch_loop.optim_progress.reset_on_epoch() def on_run_start(self, dataloader_iter: Iterator, **kwargs: Any) -> None: # hook self.trainer.logger_connector.on_epoch_start() self.trainer.call_hook("on_epoch_start") self.trainer.call_hook("on_train_epoch_start") self.trainer.fit_loop.epoch_progress.increment_started() self.dataloader_iter = _prepare_dataloader_iter( dataloader_iter, self.batch_idx + 1) def advance(self, *args: Any, **kwargs: Any) -> None: """Runs a single training batch. 
Args: dataloader_iter: the iterator over the dataloader producing the new batch Raises: StopIteration: When the epoch is canceled by the user returning -1 """ batch_idx, (batch, is_last) = next(self.dataloader_iter) if not self.trainer.data_connector.train_data_fetcher.store_on_device: with self.trainer.profiler.profile("training_batch_to_device"): batch = self.trainer.accelerator.batch_to_device(batch) self.batch_progress.increment_ready() with self.trainer.profiler.profile("run_training_batch"): batch_output = self.batch_loop.run(batch, batch_idx) self.batch_progress.increment_processed() self.is_last_batch = is_last # when returning -1 from train_step, we end epoch early if batch_output.signal == -1: raise StopIteration # update non-plateau LR schedulers # update epoch-interval ones only when we are at the end of training epoch self.update_lr_schedulers("step", update_plateau_schedulers=False) if self._num_training_batches_reached(is_last): self.update_lr_schedulers("epoch", update_plateau_schedulers=False) batch_end_outputs = [ opt_idx_out for opt_idx_out in batch_output.training_step_output if len(opt_idx_out) ] processed_batch_end_outputs = self._prepare_outputs(batch_end_outputs, batch_mode=True) # hook self.trainer.call_hook("on_train_batch_end", processed_batch_end_outputs, batch, self.batch_idx, 0) self.trainer.call_hook("on_batch_end") self.trainer.logger_connector.on_batch_end() self.batch_progress.increment_completed() # figure out what to track for epoch end self._track_epoch_end_reduce_metrics(self._epoch_output, batch_end_outputs) # ----------------------------------------- # SAVE METRICS TO LOGGERS AND PROGRESS_BAR # ----------------------------------------- self.trainer.logger_connector.update_train_step_metrics() def on_advance_end(self): """Runs validation and Checkpointing if necessary. 
Raises:
            StopIteration: if :attr:`done` evaluates to ``True`` to finish this epoch
        """
        # -----------------------------------------
        # VALIDATE IF NEEDED + CHECKPOINT CALLBACK
        # -----------------------------------------
        should_check_val = self._should_check_val_fx(self.batch_idx, self.is_last_batch)
        if should_check_val:
            self.trainer.validating = True
            self._run_validation()
            self.trainer.training = True

        # -----------------------------------------
        # SAVE LOGGERS (ie: Tensorboard, etc...)
        # -----------------------------------------
        self._save_loggers_on_train_batch_end()

        # update plateau LR scheduler after metrics are logged
        self.update_lr_schedulers("step", update_plateau_schedulers=True)

        # progress global step according to grads progress
        self._increment_accumulated_grad_global_step()

    def on_run_end(self) -> List[List[STEP_OUTPUT]]:
        """Finishes the training epoch.

        Prepares the collected epoch outputs, calls ``training_epoch_end`` (if overridden) and the
        ``on_train_epoch_end`` / ``on_epoch_end`` hooks, and steps the epoch-interval LR schedulers once the
        training batches are exhausted.

        Returns:
            The output of each training step for each optimizer. Returns ``None`` early (no hooks are called)
            when the dataloader did not produce a single batch.

        Raises:
            MisconfigurationException: ``training_epoch_end`` does not return ``None``
        """
        if self.batch_progress.current.ready == 0:
            # dataloader/iterator did not produce a batch
            return

        # inform logger the batch loop has finished
        self.trainer.logger_connector.epoch_end_reached()

        # prepare epoch output
        processed_outputs = self._prepare_outputs(self._epoch_output, batch_mode=False)

        # get the model and call model.training_epoch_end
        model = self.trainer.lightning_module

        if is_overridden("training_epoch_end", model):
            # run training_epoch_end
            # refresh the result for custom logging at the epoch level
            model._current_fx_name = "training_epoch_end"

            # lightningmodule hook
            training_epoch_end_output = model.training_epoch_end(processed_outputs)

            if training_epoch_end_output is not None:
                raise MisconfigurationException(
                    "training_epoch_end expects a return of None. "
                    "HINT: remove the return statement in training_epoch_end"
                )

        # the epoch counts as "processed" before the epoch-end hooks below run
        self.trainer.fit_loop.epoch_progress.increment_processed()

        # call train epoch end hooks
        self.trainer.call_hook("on_train_epoch_end")
        self.trainer.call_hook("on_epoch_end")
        self.trainer.logger_connector.on_epoch_end()

        if self._num_training_batches_reached(self.is_last_batch):
            self.update_lr_schedulers("epoch", update_plateau_schedulers=True)

        epoch_output = self._epoch_output
        # free memory
        self._epoch_output = None
        return epoch_output

    def teardown(self) -> None:
        """Moves the collected results to CPU and tears down the child batch and validation loops."""
        self._results.cpu()
        self.batch_loop.teardown()
        self.val_loop.teardown()

    def _run_validation(self):
        """Reloads the evaluation dataloaders and runs the validation loop without gradient tracking."""
        # reload dataloaders
        self.val_loop.reload_evaluation_dataloaders()

        with torch.no_grad():
            self.val_loop.run()

    def _accumulated_batches_reached(self) -> bool:
        """Determine if accumulation will be finished by the end of the current batch."""
        # true every `accumulate_grad_batches` ready batches
        return self.batch_progress.current.ready % self.trainer.accumulate_grad_batches == 0

    def _num_training_batches_reached(self, is_last_batch: bool = False) -> bool:
        """Checks if we are in the last batch or if there are more batches to follow.

        Args:
            is_last_batch: Whether the current batch is the last one

        Returns:
            ``True`` if the number of ready batches equals ``trainer.num_training_batches`` or ``is_last_batch``
            is set.
        """
        return self.batch_progress.current.ready == self.trainer.num_training_batches or is_last_batch

    def _should_accumulate(self) -> bool:
        """Checks if the optimizer step should be performed or gradients should be accumulated for the current
        step.

        Returns:
            ``True`` (keep accumulating) unless the accumulation window is full or the final training batch has
            been reached.
        """
        accumulation_done = self._accumulated_batches_reached()
        is_final_batch = self._num_training_batches_reached()
        return not (accumulation_done or is_final_batch)

    def _track_epoch_end_reduce_metrics(
        self, epoch_output: List[List[STEP_OUTPUT]], batch_end_outputs: STEP_OUTPUT
    ) -> None:
        """Adds the batch outputs to the epoch outputs and prepares reduction.

        Does nothing unless ``training_epoch_end`` is overridden on the LightningModule, since the outputs are only
        needed by that hook.

        Args:
            epoch_output: Per-optimizer lists that are appended to in place.
            batch_end_outputs: Per-optimizer outputs of the current batch.
        """
        hook_overridden = is_overridden("training_epoch_end", self.trainer.lightning_module)
        if not hook_overridden:
            return

        # track the outputs to reduce at the end of the epoch
        for opt_idx, opt_outputs in enumerate(batch_end_outputs):
            # with 1 step (no tbptt) don't use a sequence at epoch end
            if isinstance(opt_outputs, list) and len(opt_outputs) == 1:
                opt_outputs = opt_outputs[0]

            epoch_output[opt_idx].append(opt_outputs)

    @staticmethod
    def _prepare_outputs(
        outputs: List[List[List["ResultCollection"]]], batch_mode: bool
    ) -> Union[List[List[List[Dict]]], List[List[Dict]], List[Dict], Dict]:
        """
        Extract required information from batch or epoch end results.

        Args:
            outputs: A 3-dimensional list of ``ResultCollection`` objects with dimensions:
                ``[optimizer outs][batch outs][tbptt steps]``.

            batch_mode: If True, ignore the batch output dimension.

        Returns:
            The cleaned outputs with ``ResultCollection`` objects converted to dictionaries.
            All list dimensions of size one will be collapsed.
        """
        processed_outputs = []
        for opt_outputs in outputs:
            # handle an edge case where an optimizer output is the empty list
            if len(opt_outputs) == 0:
                continue

            processed_batch_outputs = []

            # in batch mode there is no batch dimension; wrap so the loop below sees exactly one "batch"
            if batch_mode:
                opt_outputs = [opt_outputs]

            for batch_outputs in opt_outputs:
                processed_tbptt_outputs = []

                # a bare ResultCollection means a single tbptt step
                if isinstance(batch_outputs, ResultCollection):
                    batch_outputs = [batch_outputs]

                for tbptt_output in batch_outputs:
                    out = {}
                    # the minimized value is exposed under "loss", detached from the graph
                    if tbptt_output.minimize is not None:
                        out["loss"] = tbptt_output.minimize.detach()
                    out.update(tbptt_output.extra)
                    processed_tbptt_outputs.append(out)

                # if there was only one tbptt step then we can collapse that dimension
                if len(processed_tbptt_outputs) == 1:
                    processed_tbptt_outputs = processed_tbptt_outputs[0]
                processed_batch_outputs.append(processed_tbptt_outputs)

            # batch_outputs should be just one dict (or a list of dicts if using tbptt) per optimizer
            if batch_mode:
                processed_batch_outputs = processed_batch_outputs[0]
            processed_outputs.append(processed_batch_outputs)

        # if there is only one optimiser then we collapse that dimension
        if len(processed_outputs) == 1:
            processed_outputs = processed_outputs[0]
        return processed_outputs

    def update_lr_schedulers(self, interval: str, update_plateau_schedulers: bool) -> None:
        """Updates the lr schedulers based on the given interval.

        Step-interval updates are skipped while gradients are still being accumulated; only the optimizers active
        for the current ``total_batch_idx`` are updated.

        Args:
            interval: The scheduler interval to update (``"step"`` or ``"epoch"`` at the call sites here).
            update_plateau_schedulers: Forwarded to ``optimizer_connector.update_learning_rates``.
        """
        if interval == "step" and self._should_accumulate():
            return
        self.trainer.optimizer_connector.update_learning_rates(
            interval=interval,
            update_plateau_schedulers=update_plateau_schedulers,
            opt_indices=[opt_idx for opt_idx, _ in self.batch_loop.get_active_optimizers(self.total_batch_idx)],
        )

    def _increment_accumulated_grad_global_step(self) -> None:
        """Increments global step according to grads progress.

        The step is only advanced when gradients are not being accumulated; the new value comes from
        ``accelerator.update_global_step``.
        """
        if not self._should_accumulate():
            self.global_step = self.trainer.accelerator.update_global_step(
                self.batch_progress.current.ready, self.trainer.global_step
            )

    def _should_check_val_fx(self, batch_idx: int, is_last_batch: bool) -> bool:
        """Decide if we should run validation.

        Args:
            batch_idx: Index of the current batch within the epoch.
            is_last_batch: Whether the current batch is the last one of the epoch.

        Returns:
            ``True`` if validation should run after the current batch.
        """
        if not self.trainer.enable_validation:
            return False

        # only every `check_val_every_n_epoch`-th epoch is eligible
        is_val_check_epoch = (self.trainer.current_epoch + 1) % self.trainer.check_val_every_n_epoch == 0
        if not is_val_check_epoch:
            return False

        # val_check_batch is inf for iterable datasets with no length defined
        is_infinite_dataset = self.trainer.val_check_batch == float("inf")
        if is_last_batch and is_infinite_dataset:
            return True

        if self.trainer.should_stop:
            return True

        # TODO(@awaelchli): let training/eval loop handle logic around limit_*_batches and val_check_batch
        is_val_check_batch = is_last_batch
        if isinstance(self.trainer.limit_train_batches, int) and is_infinite_dataset:
            is_val_check_batch = (batch_idx + 1) % self.trainer.limit_train_batches == 0
        elif self.trainer.val_check_batch != float("inf"):
            is_val_check_batch = (batch_idx + 1) % self.trainer.val_check_batch == 0
        return is_val_check_batch

    def _save_loggers_on_train_batch_end(self) -> None:
        """Flushes loggers to disk (global-zero process only, and only if a logger is attached)."""
        # when loggers should save to disk
        should_flush_logs = self.trainer.logger_connector.should_flush_logs
        if should_flush_logs and self.trainer.is_global_zero and self.trainer.logger is not None:
            self.trainer.logger.save()
class EvaluationEpochLoop(Loop):
    """
    This is the loop performing the evaluation. It mainly loops over the given dataloader and runs the validation
    or test step (depending on the trainer's current state).

    Each :meth:`advance` takes one ``(batch_idx, batch)`` pair from the dataloader iterator, moves the batch to the
    device, runs ``{validation,test}_step`` followed by ``{validation,test}_step_end``, calls the surrounding batch
    hooks and collects the step output, which :meth:`on_run_end` hands back.
    """

    def __init__(self) -> None:
        super().__init__()
        # (re)created in `reset`; receives test-time "predictions" in `store_predictions`
        self.predictions: Optional[PredictionCollection] = None
        self.dataloader: Optional[Iterator] = None
        # both are set per run in `on_run_start`
        self._dl_max_batches: Optional[int] = None
        self._num_dataloaders: Optional[int] = None
        # step outputs accumulated over the run, returned (and cleared) by `on_run_end`
        self.outputs: List[STEP_OUTPUT] = []
        self.batch_progress = Progress()

    @property
    def done(self) -> bool:
        """Returns ``True`` once the number of completed batches reaches ``dl_max_batches``."""
        return self.batch_progress.current.completed >= self._dl_max_batches

    def connect(self, **kwargs: "Loop") -> None:
        raise NotImplementedError(f"{self.__class__.__name__} does not connect any child loops.")

    def reset(self) -> None:
        """Resets the loop's internal state.

        The current batch progress is left untouched when the loop is restarting.
        """
        self.predictions = PredictionCollection(self.trainer.global_rank, self.trainer.world_size)
        self._dl_max_batches = None
        self._num_dataloaders = None
        self.outputs = []

        if not self.restarting:
            self.batch_progress.current.reset()

    def on_run_start(
        self,
        dataloader_iter: Iterator,
        dataloader_idx: int,
        dl_max_batches: int,
        num_dataloaders: int,
    ) -> None:
        """Adds the passed arguments to the loop's state if necessary

        Args:
            dataloader_iter: iterator over the dataloader
            dataloader_idx: index of the current dataloader
            dl_max_batches: maximum number of batches the dataloader can produce
            num_dataloaders: the total number of dataloaders
        """
        void(dataloader_iter, dataloader_idx)
        self._dl_max_batches = dl_max_batches
        self._num_dataloaders = num_dataloaders

    def advance(
        self,
        dataloader_iter: Iterator,
        dataloader_idx: int,
        dl_max_batches: int,
        num_dataloaders: int,
    ) -> None:
        """Calls the evaluation step with the corresponding hooks and updates the logger connector.

        Args:
            dataloader_iter: iterator over the dataloader
            dataloader_idx: index of the current dataloader
            dl_max_batches: maximum number of batches the dataloader can produce
            num_dataloaders: the total number of dataloaders

        Raises:
            StopIteration: If the current batch is None
        """
        void(dl_max_batches, num_dataloaders)

        batch_idx, batch = next(dataloader_iter)

        if batch is None:
            raise StopIteration

        with self.trainer.profiler.profile("evaluation_batch_to_device"):
            batch = self.trainer.accelerator.batch_to_device(batch, dataloader_idx=dataloader_idx)

        # progress is bumped around each phase so a partially processed batch is observable on restart
        self.batch_progress.increment_ready()

        # hook
        self.on_evaluation_batch_start(batch, batch_idx, dataloader_idx)

        self.batch_progress.increment_started()

        # lightning module methods
        with self.trainer.profiler.profile("evaluation_step_and_end"):
            output = self.evaluation_step(batch, batch_idx, dataloader_idx)
            output = self.evaluation_step_end(output)

        self.batch_progress.increment_processed()

        # hook + store predictions
        self.on_evaluation_batch_end(output, batch, batch_idx, dataloader_idx)

        self.batch_progress.increment_completed()

        # log batch metrics
        self.trainer.logger_connector.update_eval_step_metrics()

        # track epoch level outputs
        self.outputs = self._track_output_for_epoch_end(self.outputs, output)

    def on_run_end(self) -> List[STEP_OUTPUT]:
        """Returns the outputs of the whole run and clears them from the loop."""
        outputs = self.outputs
        # free memory
        self.outputs = []
        return outputs

    def evaluation_step(self, batch: Any, batch_idx: int, dataloader_idx: int) -> Optional[STEP_OUTPUT]:
        """The evaluation step (validation_step or test_step depending on the trainer's state).

        Args:
            batch: The current batch to run through the step.
            batch_idx: The index of the current batch
            dataloader_idx: the index of the dataloader producing the current batch

        Returns:
            the outputs of the step
        """
        # configure step_kwargs
        step_kwargs = self._build_kwargs(batch, batch_idx, dataloader_idx)

        if self.trainer.testing:
            self.trainer.lightning_module._current_fx_name = "test_step"
            with self.trainer.profiler.profile("test_step"):
                output = self.trainer.accelerator.test_step(step_kwargs)
        else:
            self.trainer.lightning_module._current_fx_name = "validation_step"
            with self.trainer.profiler.profile("validation_step"):
                output = self.trainer.accelerator.validation_step(step_kwargs)

        return output

    def evaluation_step_end(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]:
        """Calls the `{validation/test}_step_end` hook"""
        hook_name = "test_step_end" if self.trainer.testing else "validation_step_end"
        output = self.trainer.call_hook(hook_name, *args, **kwargs)
        return output

    def on_evaluation_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
        """Calls the ``on_{validation/test}_batch_start`` hook.

        Args:
            batch: The current batch to run through the step
            batch_idx: The index of the current batch
            dataloader_idx: The index of the dataloader producing the current batch

        Raises:
            AssertionError: If the number of dataloaders is None (has not yet been set).
        """
        self.trainer.logger_connector.on_batch_start()

        assert self._num_dataloaders is not None
        self.trainer.logger_connector.on_evaluation_batch_start(batch, batch_idx, dataloader_idx, self._num_dataloaders)

        hook_name = "on_test_batch_start" if self.trainer.testing else "on_validation_batch_start"
        self.trainer.call_hook(hook_name, batch, batch_idx, dataloader_idx)

    def on_evaluation_batch_end(
        self,
        output: Optional[STEP_OUTPUT],
        batch: Any,
        batch_idx: int,
        dataloader_idx: int,
    ) -> None:
        """The ``on_{validation/test}_batch_end`` hook.

        Args:
            output: The output of the performed step
            batch: The input batch for the step
            batch_idx: The index of the current batch
            dataloader_idx: Index of the dataloader producing the current batch
        """
        hook_name = "on_test_batch_end" if self.trainer.testing else "on_validation_batch_end"
        self.trainer.call_hook(hook_name, output, batch, batch_idx, dataloader_idx)

        self.trainer.logger_connector.on_batch_end()

        # store predictions if do_write_predictions and track eval loss history
        self.store_predictions(output, batch_idx, dataloader_idx)

    def store_predictions(self, output: Optional[STEP_OUTPUT], batch_idx: int, dataloader_idx: int) -> None:
        """Stores the predictions in the prediction collection (only if running in test mode)

        Args:
            output: the outputs of the current step
            batch_idx: the index of the current batch
            dataloader_idx: the index of the dataloader producing the current batch
        """
        # Add step predictions to prediction collection to write later
        if output is not None and self.predictions is not None:
            if isinstance(output, ResultCollection) and self.trainer.testing:
                self.predictions.add(output.pop("predictions", None))

        # track debug metrics
        self.trainer.dev_debugger.track_eval_loss_history(batch_idx, dataloader_idx, output)

    def _build_kwargs(self, batch: Any, batch_idx: int, dataloader_idx: int) -> Dict[str, Union[Any, int]]:
        """Helper function to build the arguments for the current step

        Args:
            batch: The current batch to run through the step
            batch_idx: the index of the current batch
            dataloader_idx: the index of the dataloader producing the current batch

        Returns:
            the keyword arguments to pass to the step function
        """
        # make dataloader_idx arg in validation_step/test_step optional: it is only passed when there is more
        # than one dataloader (the same rule applies in validation and test mode)
        step_kwargs = OrderedDict([("batch", batch), ("batch_idx", batch_idx)])
        if self._num_dataloaders > 1:
            step_kwargs["dataloader_idx"] = dataloader_idx
        return step_kwargs

    def _track_output_for_epoch_end(
        self,
        outputs: List[Union[ResultCollection, Dict, Tensor]],
        output: Optional[Union[ResultCollection, Dict, Tensor]],
    ) -> List[Union[ResultCollection, Dict, Tensor]]:
        """Appends a detached copy of ``output`` to ``outputs`` (optionally moved to CPU).

        ``None`` outputs are skipped. The CPU move follows ``trainer.move_metrics_to_cpu``.

        Returns:
            The same ``outputs`` list, appended to in place.
        """
        if output is not None:
            if isinstance(output, ResultCollection):
                output = output.detach()
                if self.trainer.move_metrics_to_cpu:
                    output = output.cpu()
            elif isinstance(output, dict):
                output = recursive_detach(output, to_cpu=self.trainer.move_metrics_to_cpu)
            elif isinstance(output, Tensor) and output.is_cuda and self.trainer.move_metrics_to_cpu:
                output = output.cpu()
            outputs.append(output)
        return outputs
def test_base_progress_from_defaults():
    """``Progress.from_defaults`` seeds the ``total`` and ``current`` trackers with identical values."""
    tracker_kwargs = dict(started=None, completed=5)
    actual = Progress.from_defaults(completed=5, started=None)
    assert actual == Progress(total=Tracker(**tracker_kwargs), current=Tracker(**tracker_kwargs))
class ActiveLearningLoop(Loop):
    max_epochs: int
    inference_model: InferenceMCDropoutTask

    @requires(["baal", (_PL_GREATER_EQUAL_1_4_0, "pytorch-lightning>=1.4.0")])
    def __init__(self, label_epoch_frequency: int, inference_iteration: int = 2, should_reset_weights: bool = True):
        """The `ActiveLearning Loop` describes the following training procedure. This loop is connected with the
        `ActiveLearningTrainer`

        Example::

            while unlabelled data or budget criteria not reached:

                if labelled data
                    trainer.fit(model, labelled data)

                if unlabelled data:
                    predictions = trainer.predict(model, unlabelled data)
                    uncertainties = heuristic(predictions)
                    request labellisation for the sample with highest uncertainties under a given budget

        Args:
            label_epoch_frequency: Number of epochs to train on before requesting labellisation.
            inference_iteration: Number of inferences to perform to compute uncertainty.
            should_reset_weights: If ``True``, the model weights captured at the start of the run are reloaded
                after every round that still has unlabelled data, so the model is retrained from scratch.
        """
        super().__init__()
        self.label_epoch_frequency = label_epoch_frequency
        self.inference_iteration = inference_iteration
        self.should_reset_weights = should_reset_weights
        self.fit_loop: Optional[FitLoop] = None
        self.progress = Progress()
        self._model_state_dict: Optional[Dict[str, torch.Tensor]] = None
        self._datamodule_state_dict: Optional[Dict[str, Any]] = None
        self._lightning_module: Optional[flash.Task] = None

    @property
    def done(self) -> bool:
        """``True`` once the number of completed labelling rounds reaches ``max_epochs``."""
        return self.progress.current.completed >= self.max_epochs

    def connect(self, fit_loop: FitLoop):
        """Wraps ``fit_loop``.

        The fit loop's ``max_epochs`` becomes the number of labelling rounds of this loop, and each inner fit is
        limited to ``label_epoch_frequency`` epochs.
        """
        self.fit_loop = fit_loop
        self.max_epochs = self.fit_loop.max_epochs
        self.fit_loop.max_epochs = self.label_epoch_frequency

    def on_run_start(self, *args: Any, **kwargs: Any) -> None:
        """Restores the datamodule state (if any), snapshots the initial model weights and builds the MC-dropout
        inference model."""
        assert isinstance(self.trainer.datamodule, ActiveLearningDataModule)
        if self._datamodule_state_dict is not None:
            self.trainer.datamodule.load_state_dict(self._datamodule_state_dict)
        self.trainer.predict_loop._return_predictions = True
        self._lightning_module = self.trainer.lightning_module
        # snapshot used by `on_advance_end` to reset the weights between rounds
        self._model_state_dict = deepcopy(self._lightning_module.state_dict())
        self.inference_model = InferenceMCDropoutTask(self._lightning_module, self.inference_iteration)

    def reset(self) -> None:
        pass

    def on_advance_start(self, *args: Any, **kwargs: Any) -> None:
        """Re-creates the dataloaders for every stage that currently has data."""
        if self.trainer.datamodule.has_labelled_data:
            self._reset_dataloader_for_stage(RunningStage.TRAINING)
            self._reset_dataloader_for_stage(RunningStage.VALIDATING)
        if self.trainer.datamodule.has_test:
            self._reset_dataloader_for_stage(RunningStage.TESTING)
        if self.trainer.datamodule.has_unlabelled_data:
            self._reset_dataloader_for_stage(RunningStage.PREDICTING)
        self.progress.increment_ready()

    def advance(self, *args: Any, **kwargs: Any) -> None:
        """Runs one labelling round: fit on labelled data, optionally test, then predict on and label the
        unlabelled pool.

        Raises:
            StopIteration: when there is no unlabelled data left.
        """
        self.progress.increment_started()

        if self.trainer.datamodule.has_labelled_data:
            self.fit_loop.run()

        if self.trainer.datamodule.has_test:
            self._reset_testing()
            metrics = self.trainer.test_loop.run()
            # the trainer may run without a logger; the test loop still runs, only logging is skipped
            if metrics and self.trainer.logger is not None:
                self.trainer.logger.log_metrics(metrics[0], step=self.trainer.global_step)

        if self.trainer.datamodule.has_unlabelled_data:
            self._reset_predicting()
            probabilities = self.trainer.predict_loop.run()
            self.trainer.datamodule.label(probabilities=probabilities)
        else:
            raise StopIteration

        self._reset_fitting()
        self.progress.increment_processed()

    def on_advance_end(self) -> None:
        if self.trainer.datamodule.has_unlabelled_data and self.should_reset_weights:
            # reload the weights to retrain from scratch with the new labelled data.
            self._lightning_module.load_state_dict(self._model_state_dict)
        self.progress.increment_completed()
        return super().on_advance_end()

    def on_run_end(self):
        """Saves the datamodule state, restores the trainer to fitting and clears all dataloaders."""
        self._datamodule_state_dict = self.trainer.datamodule.state_dict()
        self._reset_fitting()
        self._teardown()
        return super().on_run_end()

    def on_save_checkpoint(self) -> Dict:
        return {"datamodule_state_dict": self._datamodule_state_dict}

    def on_load_checkpoint(self, state_dict) -> None:
        self._datamodule_state_dict = state_dict.pop("datamodule_state_dict", None)

    def __getattr__(self, key):
        # Only reached when regular attribute lookup fails: unknown attributes are forwarded to the wrapped
        # ``fit_loop``. ``fit_loop`` is read from ``__dict__`` directly because ``self.fit_loop`` on an instance
        # whose ``__init__`` has not run (e.g. one rebuilt by ``copy.deepcopy`` or ``pickle``, which probe for
        # ``__setstate__``) would re-enter ``__getattr__`` and recurse until ``RecursionError``.
        if key in self.__dict__:
            return self.__dict__[key]
        fit_loop = self.__dict__.get("fit_loop")
        if fit_loop is None:
            raise AttributeError(f"{type(self).__name__!r} object has no attribute {key!r}")
        return getattr(fit_loop, key)

    def _connect(self, model: LightningModule):
        """Attaches ``model`` to the training type plugin (PL >= 1.5) or the accelerator (older versions)."""
        if _PL_GREATER_EQUAL_1_5_0:
            self.trainer.training_type_plugin.connect(model)
        else:
            self.trainer.accelerator.connect(model)

    def _reset_fitting(self):
        self.trainer.state.fn = TrainerFn.FITTING
        self.trainer.training = True
        self.trainer.lightning_module.on_train_dataloader()
        self._connect(self._lightning_module)
        # fresh epoch counter so the next inner fit runs `label_epoch_frequency` epochs again
        self.fit_loop.epoch_progress = Progress()

    def _reset_predicting(self):
        self.trainer.state.fn = TrainerFn.PREDICTING
        self.trainer.predicting = True
        self.trainer.lightning_module.on_predict_dataloader()
        # predictions are made with the MC-dropout wrapper, not the plain model
        self._connect(self.inference_model)

    def _reset_testing(self):
        self.trainer.state.fn = TrainerFn.TESTING
        self.trainer.state.status = TrainerStatus.RUNNING
        self.trainer.testing = True
        self.trainer.lightning_module.on_test_dataloader()
        self._connect(self._lightning_module)

    def _reset_dataloader_for_stage(self, running_state: RunningStage):
        """Re-points the trainer at the datamodule's ``<stage>_dataloader`` (if overridden) and resets it."""
        dataloader_name = f"{_STAGES_PREFIX[running_state]}_dataloader"
        # If the dataloader exists, we reset it.
        dataloader = (
            getattr(self.trainer.datamodule, dataloader_name)
            if is_overridden(dataloader_name, self.trainer.datamodule)
            else None
        )
        if dataloader:
            if _PL_GREATER_EQUAL_1_5_0:
                setattr(
                    self.trainer._data_connector,
                    f"_{dataloader_name}_source",
                    _DataLoaderSource(self.trainer.datamodule, dataloader_name),
                )
            else:
                setattr(
                    self.trainer.lightning_module,
                    dataloader_name,
                    _PatchDataLoader(dataloader(), running_state),
                )
            setattr(self.trainer, dataloader_name, None)
            # TODO: Resolve this within PyTorch Lightning.
            # Best-effort: a MisconfigurationException from the reset is deliberately ignored.
            try:
                getattr(self.trainer, f"reset_{dataloader_name}")(self.trainer.lightning_module)
            except MisconfigurationException:
                pass

    def _teardown(self) -> None:
        self.trainer.train_dataloader = None
        self.trainer.val_dataloaders = None
        self.trainer.test_dataloaders = None
        self.trainer.predict_dataloaders = None
        # Hack
        self.trainer.lightning_module.train_dataloader = None
        self.trainer.lightning_module.val_dataloader = None
        self.trainer.lightning_module.test_dataloader = None
        self.trainer.lightning_module.predict_dataloader = None
class FitLoop(Loop[None]):
    """This Loop iterates over the epochs to run the training.

    Each ``advance`` runs one whole epoch through the child :class:`TrainingEpochLoop`; the loop ends when one of
    the limits checked in :attr:`done` is hit.

    Args:
        min_epochs: The minimum number of epochs
        max_epochs: The maximum number of epochs, can be set -1 to turn this limit off
    """

    def __init__(
        self,
        min_epochs: int = 0,
        max_epochs: int = 1000,
    ) -> None:
        super().__init__()
        if max_epochs < -1:
            # Allow max_epochs to be zero, since this will be handled by fit_loop.done
            raise MisconfigurationException(
                f"`max_epochs` must be a non-negative integer or -1. You passed in {max_epochs}."
            )

        self.max_epochs = max_epochs
        self.min_epochs = min_epochs
        self.epoch_loop = TrainingEpochLoop()
        self.epoch_progress = Progress()

        # False after the first `on_advance_start`; reset to True in `on_run_start`
        self._is_fresh_start_epoch: bool = True
        # outputs of the current epoch, consumed by `training_epoch_end` in `on_advance_end`
        self._outputs: _EPOCH_OUTPUTS_TYPE = []
        # created in `on_run_start`, released in `teardown`
        self._data_fetcher: Optional[AbstractDataFetcher] = None

    @property
    def total_batch_idx(self) -> int:
        """Returns the current batch index (across epochs)"""
        return self.epoch_loop.total_batch_idx

    @property
    def batch_idx(self) -> int:
        """Returns the current batch index (within this epoch)"""
        return self.epoch_loop.batch_idx

    @property
    def split_idx(self) -> int:
        """Returns the index of the current batch split (within the current batch) for bptt."""
        return self.epoch_loop.batch_loop.split_idx

    @property
    def min_steps(self) -> Optional[int]:
        # TODO(@justusschock): Why aren't we using the attribute in this class?
        """Returns the minimum number of steps to run (read from ``epoch_loop``)."""
        return self.epoch_loop.min_steps

    @min_steps.setter
    def min_steps(self, value: Optional[int]) -> None:
        """Sets the minimum number of steps (forwards to epoch_loop)"""
        # TODO(@awaelchli): This setter is required by debugging connector (fast dev run), should be avoided
        self.epoch_loop.min_steps = value

    @property
    def max_steps(self) -> int:
        """Returns the maximum number of steps to run (read from ``epoch_loop``)."""
        return self.epoch_loop.max_steps

    @max_steps.setter
    def max_steps(self, value: int) -> None:
        """Sets the maximum number of steps (forwards to epoch_loop).

        ``None`` is accepted with a deprecation warning and mapped to ``-1`` (no limit).

        Raises:
            MisconfigurationException: if ``value`` is smaller than ``-1``.
        """
        # TODO(@awaelchli): This setter is required by debugging connector (fast dev run), should be avoided
        if value is None:
            rank_zero_deprecation(
                "Setting `max_steps = None` is deprecated in v1.5 and will no longer be supported in v1.7."
                " Use `max_steps = -1` instead."
            )
            value = -1
        elif value < -1:
            raise MisconfigurationException(
                f"`max_steps` must be a non-negative integer or -1 (infinite steps). You passed in {value}."
            )
        self.epoch_loop.max_steps = value

    @property
    def running_loss(self) -> TensorRunningAccum:
        """Returns the running loss."""
        return self.epoch_loop.batch_loop.running_loss

    @Loop.restarting.setter
    def restarting(self, restarting: bool) -> None:
        # if the last epoch completely finished, we are not actually restarting, we can check this to see if all
        # current values are equal
        values = (
            self.epoch_progress.current.ready,
            self.epoch_progress.current.started,
            self.epoch_progress.current.processed,
        )
        finished_before_on_train_end = any(v != self.epoch_progress.current.completed for v in values)
        if finished_before_on_train_end:
            # the interrupted epoch was processed but not completed: align `completed` with `processed`
            self.epoch_progress.current.completed = self.epoch_progress.current.processed
        restarting &= finished_before_on_train_end
        Loop.restarting.fset(self, restarting)  # call the parent setter

    @property
    def prefetch_batches(self) -> int:
        """Number of batches to prefetch: 1 if the train dataloader has no known length (``num_training_batches``
        is ``inf``) or ``PL_INTER_BATCH_PARALLELISM=1`` is set, else 0."""
        is_unsized = self.trainer.num_training_batches == float("inf")
        inter_batch_parallelism = os.getenv("PL_INTER_BATCH_PARALLELISM", "0") == "1"
        return 1 if is_unsized or inter_batch_parallelism else 0

    @property
    def _skip_backward(self) -> bool:
        """Determines whether the loop will skip backward during automatic optimization."""
        return self.epoch_loop.batch_loop.optimizer_loop._skip_backward

    @_skip_backward.setter
    def _skip_backward(self, value: bool) -> None:
        """Determines whether the loop will skip backward during automatic optimization."""
        self.epoch_loop.batch_loop.optimizer_loop._skip_backward = value

    @property
    def _results(self) -> _ResultCollection:
        """The result collection of the currently active stage (training or validation).

        Raises:
            RuntimeError: if the trainer is neither training nor validating.
        """
        if self.trainer.training:
            return self.epoch_loop._results
        if self.trainer.validating:
            return self.epoch_loop.val_loop._results
        raise RuntimeError("`FitLoop._results` property isn't defined. Accessed outside of scope")

    @property
    def done(self) -> bool:
        """Evaluates when to leave the loop.

        Stops on reaching ``max_steps`` or ``max_epochs``, when there are no training batches, or when
        ``trainer.should_stop`` is set and the minimum epochs/steps have been met. Note that this property
        overwrites ``trainer.should_stop`` with the final early-stopping decision.
        """
        # TODO(@awaelchli): Move track steps inside training loop and move part of these condition inside training loop
        stop_steps = _is_max_limit_reached(self.epoch_loop.global_step, self.max_steps)
        # `processed` is increased before `on_train_epoch_end`, the hook where checkpoints are typically saved.
        # we use it here because the checkpoint data won't have `completed` increased yet
        stop_epochs = _is_max_limit_reached(self.epoch_progress.current.processed, self.max_epochs)

        should_stop = False
        if self.trainer.should_stop:
            # early stopping
            met_min_epochs = self.epoch_progress.current.processed >= self.min_epochs if self.min_epochs else True
            met_min_steps = self.epoch_loop.global_step >= self.min_steps if self.min_steps else True
            if met_min_epochs and met_min_steps:
                should_stop = True
            else:
                log.info(
                    "Trainer was signaled to stop but required minimum epochs"
                    f" ({self.min_epochs}) or minimum steps ({self.min_steps}) has"
                    " not been met. Training will continue..."
                )
        self.trainer.should_stop = should_stop

        return stop_steps or should_stop or stop_epochs or self.trainer.num_training_batches == 0

    @property
    def skip(self) -> bool:
        """Whether we should skip the training and immediately return from the call to :meth:`run`."""
        # since `trainer.num_training_batches` depends on the `train_dataloader` but that won't be called
        # until `on_run_start`, we use `limit_train_batches` instead
        return self.done or self.trainer.limit_train_batches == 0

    def connect(self, epoch_loop: TrainingEpochLoop) -> None:  # type: ignore[override]
        """Connects a training epoch loop to this fit loop."""
        self.epoch_loop = epoch_loop

    def reset(self) -> None:
        """Resets the internal state of this loop (only the epoch progress, and only when restarting)."""
        if self.restarting:
            self.epoch_progress.reset_on_restart()

    def on_run_start(self) -> None:  # type: ignore[override]
        """Resets the train/val dataloaders, creates the data fetcher and calls the ``on_train_start`` hooks."""
        # reset train dataloader and val dataloader
        self.trainer.reset_train_val_dataloaders(self.trainer.lightning_module)

        data_fetcher_cls = _select_data_fetcher(self.trainer)
        self._data_fetcher = data_fetcher_cls(prefetch_batches=self.prefetch_batches)

        self._is_fresh_start_epoch = True
        self._results.to(device=self.trainer.lightning_module.device)

        self.trainer._call_callback_hooks("on_train_start")
        self.trainer._call_lightning_module_hook("on_train_start")
        self.trainer._call_strategy_hook("on_train_start")

    def on_advance_start(self) -> None:  # type: ignore[override]
        """Prepares the dataloader for training and calls the hooks ``on_epoch_start`` and
        ``on_train_epoch_start``"""
        model = self.trainer.lightning_module

        # reset train dataloader
        # (never on the first epoch of a run: `on_run_start` has just reset it)
        if not self._is_fresh_start_epoch and self.trainer._data_connector._should_reload_train_dl:
            log.detail(f"{self.__class__.__name__}: resetting train dataloader")
            self.trainer.reset_train_dataloader(model)
        self._is_fresh_start_epoch = False

        # reset outputs here instead of in `reset` as they are not accumulated between epochs
        self._outputs = []

        if self.trainer.train_dataloader is not None and callable(
            getattr(self.trainer.train_dataloader.sampler, "set_epoch", None)
        ):
            # set seed for distributed sampler (enables shuffling for each epoch)
            self.trainer.train_dataloader.sampler.set_epoch(self.epoch_progress.current.processed)

        # changing gradient according accumulation_scheduler
        self.trainer.accumulation_scheduler.on_train_epoch_start(self.trainer, self.trainer.lightning_module)

        # stores accumulated grad fractions per batch
        self.epoch_loop.batch_loop.accumulated_loss.reset(window_length=self.trainer.accumulate_grad_batches)

        self.epoch_progress.increment_ready()

        self.trainer._logger_connector.on_epoch_start()

        self.trainer._call_callback_hooks("on_epoch_start")
        self.trainer._call_lightning_module_hook("on_epoch_start")

        self.trainer._call_callback_hooks("on_train_epoch_start")
        self.trainer._call_lightning_module_hook("on_train_epoch_start")

        self.epoch_progress.increment_started()

    def advance(self) -> None:  # type: ignore[override]
        """Runs one whole epoch."""
        log.detail(f"{self.__class__.__name__}: advancing loop")
        assert self.trainer.train_dataloader is not None
        dataloader = self.trainer.train_dataloader

        assert self._data_fetcher is not None
        # batches are moved to the device through the strategy's `batch_to_device` hook
        self._data_fetcher.setup(
            dataloader, batch_to_device=partial(self.trainer._call_strategy_hook, "batch_to_device", dataloader_idx=0)
        )
        with self.trainer.profiler.profile("run_training_epoch"):
            self._outputs = self.epoch_loop.run(self._data_fetcher)

    def on_advance_end(self) -> None:
        """Finishes an epoch: ``training_epoch_end``, epoch-end hooks, epoch LR schedulers and epoch metric logging.

        Raises:
            MisconfigurationException: if ``training_epoch_end`` returns something other than ``None``.
        """
        # inform logger the batch loop has finished
        self.trainer._logger_connector.epoch_end_reached()

        # get the model and call model.training_epoch_end
        model = self.trainer.lightning_module
        if is_overridden("training_epoch_end", model) and self._outputs:
            epoch_end_outputs = self.epoch_loop._prepare_outputs_training_epoch_end(
                self._outputs,
                lightning_module=model,
                num_optimizers=len(self.trainer.optimizers),
            )
            # run lightning module hook training_epoch_end
            # refresh the result for custom logging at the epoch level
            epoch_end_outputs = self.trainer._call_lightning_module_hook("training_epoch_end", epoch_end_outputs)
            if epoch_end_outputs is not None:
                raise MisconfigurationException(
                    "`training_epoch_end` expects a return of None. "
                    "HINT: remove the return statement in `training_epoch_end`."
                )
        # free memory
        self._outputs = []

        # `processed` is bumped before the epoch-end hooks (see the comment in `done`)
        self.epoch_progress.increment_processed()

        # call train epoch end hooks
        self.trainer._call_callback_hooks("on_train_epoch_end")
        self.trainer._call_lightning_module_hook("on_train_epoch_end")
        self.trainer._call_callback_hooks("on_epoch_end")
        self.trainer._call_lightning_module_hook("on_epoch_end")
        self.trainer._logger_connector.on_epoch_end()

        if self.epoch_loop._num_ready_batches_reached():
            self.epoch_loop.update_lr_schedulers("epoch", update_plateau_schedulers=True)

        self.epoch_progress.increment_completed()

        # we manually decrease here because loggers expect that the same step is used when logging epoch-end metrics
        # even when the batch loop has finished
        self.epoch_loop._batches_that_stepped -= 1
        # log epoch metrics
        self.trainer._logger_connector.update_train_epoch_metrics()
        self.epoch_loop._batches_that_stepped += 1

        # if fault tolerant is enabled and process has been notified, exit.
        self.trainer._exit_gracefully_on_signal()

    def on_run_end(self) -> None:
        """Calls the ``on_train_end`` hook."""
        log.detail(f"{self.__class__.__name__}: train run ended")

        # hook
        self.trainer._call_callback_hooks("on_train_end")
        self.trainer._call_lightning_module_hook("on_train_end")
        self.trainer._call_strategy_hook("on_train_end")

    def teardown(self) -> None:
        """Releases the data fetcher and tears down the epoch loop."""
        if self._data_fetcher is not None:
            self._data_fetcher.teardown()
            self._data_fetcher = None
        self.epoch_loop.teardown()

    def _should_accumulate(self) -> bool:
        """Whether the gradients should be accumulated."""
        return self.epoch_loop._should_accumulate()
class FitLoop(Loop):
    """This Loop iterates over the epochs to run the training.

    Args:
        min_epochs: The minimum number of epochs
        max_epochs: The maximum number of epochs, can be set -1 to turn this limit off

    Raises:
        MisconfigurationException: if ``max_epochs`` is smaller than -1.
    """

    def __init__(
        self,
        min_epochs: Optional[int] = 1,
        max_epochs: int = 1000,
    ) -> None:
        super().__init__()
        if max_epochs < -1:
            # Allow max_epochs to be zero, since this will be handled by fit_loop.done
            raise MisconfigurationException(
                f"`max_epochs` must be a non-negative integer or -1. You passed in {max_epochs}."
            )

        self.max_epochs = max_epochs
        self.min_epochs = min_epochs
        # the child loop is attached later through `connect`
        self.epoch_loop: Optional[TrainingEpochLoop] = None
        self.epoch_progress = Progress()
        # True right after `on_run_start` (which already (re)loads the dataloaders), so the
        # first `on_advance_start` of a run does not reload the train dataloader a second time
        self._is_fresh_start_epoch: bool = True

    @property
    def current_epoch(self) -> int:
        """Return the current epoch (the number of completed epochs tracked by ``epoch_progress``)."""
        return self.epoch_progress.current.completed

    @current_epoch.setter
    def current_epoch(self, value: int) -> None:
        """Setter for the current epoch."""
        self.epoch_progress.current.completed = value

    @property
    def global_step(self) -> int:
        """Returns the global step (forwarded from epoch_loop)."""
        return self.epoch_loop.global_step

    @global_step.setter
    def global_step(self, value: int) -> None:
        """Sets the global step (forwards to epoch_loop)"""
        self.epoch_loop.global_step = value

    @property
    def total_batch_idx(self) -> int:
        """Returns the current batch index (across epochs)"""
        return self.epoch_loop.total_batch_idx

    @property
    def batch_idx(self) -> int:
        """Returns the current batch index (within this epoch)"""
        return self.epoch_loop.batch_idx

    @property
    def split_idx(self) -> int:
        """Returns the index of the current batch split (within the current batch) for bptt."""
        return self.epoch_loop.batch_loop.split_idx

    @property
    def min_steps(self) -> int:
        # TODO(@justusschock): Why aren't we using the attribute in this class?
        """Returns the minimum number of steps to run."""
        return self.epoch_loop.min_steps

    @min_steps.setter
    def min_steps(self, value: int) -> None:
        """Sets the minimum number of steps (forwards to epoch_loop)"""
        # TODO(@awaelchli): This setter is required by debugging connector (fast dev run), should be avoided
        self.epoch_loop.min_steps = value

    @property
    def max_steps(self) -> int:
        """Returns the maximum number of steps to run."""
        return self.epoch_loop.max_steps

    @max_steps.setter
    def max_steps(self, value: int) -> None:
        """Sets the maximum number of steps (forwards to epoch_loop).

        ``None`` is accepted for backward compatibility (with a deprecation warning) and mapped to -1.

        Raises:
            MisconfigurationException: if ``value`` is smaller than -1.
        """
        # TODO(@awaelchli): This setter is required by debugging connector (fast dev run), should be avoided
        if value is None:
            rank_zero_deprecation(
                "Setting `max_steps = None` is deprecated in v1.5 and will no longer be supported in v1.7."
                " Use `max_steps = -1` instead."
            )
            value = -1
        elif value < -1:
            raise MisconfigurationException(
                f"`max_steps` must be a non-negative integer or -1 (infinite steps). You passed in {value}."
            )
        self.epoch_loop.max_steps = value

    @property
    def running_loss(self) -> TensorRunningAccum:
        """Returns the running loss."""
        return self.epoch_loop.batch_loop.running_loss

    @property
    def _skip_backward(self) -> bool:
        """Determines whether the loop will skip backward during automatic optimization."""
        assert self.epoch_loop.batch_loop is not None
        assert self.epoch_loop.batch_loop.optimizer_loop is not None
        return self.epoch_loop.batch_loop.optimizer_loop._skip_backward

    @_skip_backward.setter
    def _skip_backward(self, value: bool) -> None:
        """Determines whether the loop will skip backward during automatic optimization."""
        assert self.epoch_loop.batch_loop is not None
        assert self.epoch_loop.batch_loop.optimizer_loop is not None
        self.epoch_loop.batch_loop.optimizer_loop._skip_backward = value

    @property
    def _results(self) -> ResultCollection:
        """Returns the result collection of the training or validation stage, whichever is active.

        Raises:
            RuntimeError: if accessed while the trainer is neither training nor validating.
        """
        if self.trainer.training:
            return self.epoch_loop._results
        if self.trainer.validating:
            return self.epoch_loop.val_loop._results
        raise RuntimeError("`FitLoop._results` property isn't defined. Accessed outside of scope")

    @property
    def done(self) -> bool:
        """Evaluates when to leave the loop.

        Returns True if trainer.should_stop was set (e.g. by early stopping) or if the maximum number of steps or
        epochs is reached. Note: this also overwrites ``trainer.should_stop``, clearing it when the minimum
        number of epochs/steps has not been met yet.
        """
        # TODO(@awaelchli): Move track steps inside training loop and move part of these condition inside training loop
        stop_steps = _is_max_limit_reached(self.global_step, self.max_steps)
        stop_epochs = _is_max_limit_reached(self.current_epoch, self.max_epochs)

        should_stop = False
        if self.trainer.should_stop:
            # early stopping
            met_min_epochs = self.current_epoch >= self.min_epochs if self.min_epochs else True
            met_min_steps = self.global_step >= self.min_steps if self.min_steps else True
            if met_min_epochs and met_min_steps:
                should_stop = True
            else:
                log.info(
                    "Trainer was signaled to stop but required minimum epochs"
                    f" ({self.min_epochs}) or minimum steps ({self.min_steps}) has"
                    " not been met. Training will continue..."
                )
        self.trainer.should_stop = should_stop

        return stop_steps or should_stop or stop_epochs or self.trainer.num_training_batches == 0

    @property
    def skip(self) -> bool:
        """Whether we should skip the training and immediately return from the call to :meth:`run`."""
        # since `trainer.num_training_batches` depends on the `train_dataloader` but that won't be called
        # until `on_run_start`, we use `limit_train_batches` instead
        return self.done or self.trainer.limit_train_batches == 0

    def connect(self, epoch_loop: TrainingEpochLoop):
        """Connects a training epoch loop to this fit loop."""
        self.epoch_loop = epoch_loop

    def reset(self) -> None:
        """Resets the internal state of this loop (only the epoch progress, and only when restarting)."""
        if self.restarting:
            self.epoch_progress.reset_on_restart()

    def on_run_start(self) -> None:
        """Reloads the train/val dataloaders, moves the results to the model device and calls ``on_train_start``."""
        # reset train dataloader and val dataloader
        self.trainer.reset_train_val_dataloaders(self.trainer.lightning_module)
        # the dataloaders were just (re)loaded, so the first epoch must not reload them again
        self._is_fresh_start_epoch = True
        self._results.to(device=self.trainer.lightning_module.device)
        self.trainer.call_hook("on_train_start")

    def on_advance_start(self) -> None:
        """Prepares the next epoch.

        Reloads the train dataloader when requested (except on the first epoch of a run), seeds the sampler's
        ``set_epoch`` if it has one, updates gradient accumulation via the accumulation scheduler, resets the
        accumulated-loss tracker and marks the epoch as ready in ``epoch_progress``.
        """
        model = self.trainer.lightning_module

        # reset train dataloader
        if not self._is_fresh_start_epoch and self.trainer._should_reload_dl_epoch:
            self.trainer.reset_train_dataloader(model)
        self._is_fresh_start_epoch = False

        if self.trainer.train_dataloader is not None and callable(
            getattr(self.trainer.train_dataloader.sampler, "set_epoch", None)
        ):
            # set seed for distributed sampler (enables shuffling for each epoch)
            self.trainer.train_dataloader.sampler.set_epoch(self.current_epoch)

        # changing gradient according accumulation_scheduler
        self.trainer.accumulation_scheduler.on_train_epoch_start(self.trainer, self.trainer.lightning_module)

        # stores accumulated grad fractions per batch
        self.epoch_loop.batch_loop.accumulated_loss = TensorRunningAccum(
            window_length=self.trainer.accumulate_grad_batches
        )

        self.epoch_progress.increment_ready()

    def advance(self) -> None:
        """Runs one whole epoch and logs the epoch-level training metrics."""
        dataloader = self.trainer.training_type_plugin.process_dataloader(self.trainer.train_dataloader)
        data_fetcher = self.trainer._data_connector.get_profiled_dataloader(dataloader)

        with self.trainer.profiler.profile("run_training_epoch"):
            self.epoch_loop.run(data_fetcher)

        # the global step is manually decreased here due to backwards compatibility with existing loggers
        # as they expect that the same step is used when logging epoch end metrics even when the batch loop has
        # finished. this means the attribute does not exactly track the number of optimizer steps applied.
        # TODO(@carmocca): deprecate and rename so users don't get confused
        self.global_step -= 1
        try:
            # log epoch metrics
            self.trainer.logger_connector.update_train_epoch_metrics()
        finally:
            # always undo the temporary decrement, even if logging raises
            self.global_step += 1

    def on_advance_end(self) -> None:
        """Marks the current epoch as completed (which advances ``current_epoch``)."""
        self.epoch_progress.increment_completed()

    def on_run_end(self) -> None:
        """Calls the ``on_train_end`` hook."""
        # NOTE: the current_epoch is already incremented
        # Lightning today does not increment the current epoch at the last epoch run in Trainer.fit
        # To simulate that current behavior, we decrement here.
        # TODO: must be fixed by https://github.com/PyTorchLightning/pytorch-lightning/issues/5007
        self.current_epoch = max(self.current_epoch - 1, 0)

        # hook
        self.trainer.call_hook("on_train_end")

        # give accelerators a chance to finish
        self.trainer.training_type_plugin.on_train_end()

    def teardown(self) -> None:
        """Tears down the connected epoch loop."""
        self.epoch_loop.teardown()

    def _should_accumulate(self) -> bool:
        """Whether the gradients should be accumulated."""
        return self.epoch_loop._should_accumulate()
class PredictionEpochLoop(Loop):
    """Loop performing prediction on arbitrary sequentially used dataloaders."""

    def __init__(self) -> None:
        super().__init__()
        self.return_predictions: bool = False
        self.predictions: List[Any] = []
        self.current_batch_indices: List[int] = []
        self.batch_progress = Progress()

        # both are populated in `on_run_start`
        self._dl_max_batches: Optional[int] = None
        self._num_dataloaders: Optional[int] = None
        # per-instance cache used for every warning this loop emits
        self._warning_cache = WarningCache()
        self._all_batch_indices: List[int] = []

    @property
    def done(self) -> bool:
        """Ends prediction when the iteration count exceeds the total number of available batches."""
        return self.batch_progress.current.completed >= self._dl_max_batches

    @property
    def should_store_predictions(self) -> bool:
        """Whether the predictions should be stored for later usage (e.g. aggregation or returning)"""
        any_pred = any(cb.interval.on_epoch for cb in self.trainer.prediction_writer_callbacks)
        return self.return_predictions or any_pred

    def connect(self, **kwargs: "Loop") -> None:
        """This loop has no child loops.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError(f"{self.__class__.__name__} does not connect any child loops.")

    def reset(self) -> None:
        """Resets the loop's internal state: stored predictions, batch indices and per-run batch progress."""
        self._all_batch_indices = []
        self.predictions = []
        self.batch_progress.reset_on_run()

    def on_run_start(
        self,
        dataloader_iter: Iterator,
        dataloader_idx: int,
        dl_max_batches: int,
        num_dataloaders: int,
        return_predictions: bool = False,
    ) -> None:
        """Prepares the loop's internal state.

        Args:
            dataloader_iter: the iterator over the current dataloader
            dataloader_idx: the index of the current dataloader
            dl_max_batches: the maximum number of batches the current loader can produce
            num_dataloaders: the total number of dataloaders
            return_predictions: whether to return the obtained predictions
        """
        # these two arguments are only consumed by `advance`
        void(dataloader_iter, dataloader_idx)
        self._dl_max_batches = dl_max_batches
        self._num_dataloaders = num_dataloaders
        self.return_predictions = return_predictions

    def advance(
        self,
        dataloader_iter: Iterator,
        dataloader_idx: int,
        dl_max_batches: int,
        num_dataloaders: int,
        return_predictions: bool = False,
    ) -> None:
        """Runs one prediction step.

        Args:
            dataloader_iter: the iterator over the current dataloader
            dataloader_idx: the index of the current dataloader
            dl_max_batches: the maximum number of batches the current loader can produce
            num_dataloaders: the total number of dataloaders
            return_predictions: whether to return the obtained predictions

        Raises:
            StopIteration: when the iterator is exhausted or yields a ``None`` batch.
        """
        batch_idx, batch = next(dataloader_iter)
        if batch is None:
            # NOTE(review): a None batch is treated like exhaustion; presumably handled by `Loop.run` — confirm
            raise StopIteration

        with self.trainer.profiler.profile("predict_batch_to_device"):
            batch = self.trainer.accelerator.batch_to_device(batch, dataloader_idx=dataloader_idx)

        self.batch_progress.increment_ready()

        with self.trainer.profiler.profile("predict_step"):
            self._predict_step(batch, batch_idx, dataloader_idx)

    def on_run_end(self) -> Tuple[List[Any], List[int]]:
        """Returns the predictions and the corresponding batch indices, then clears the internal buffers."""
        predictions = self.predictions
        all_batch_indices = self._all_batch_indices
        # free memory
        self.predictions = []
        self._all_batch_indices = []
        return predictions, all_batch_indices

    def _predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
        """Runs the actual predict step together with all the necessary bookkeeping and the hooks tied to the
        predict step.

        Args:
            batch: the current batch to run the prediction on
            batch_idx: the index of the current batch
            dataloader_idx: the index of the dataloader producing the current batch
        """
        # configure step_kwargs
        step_kwargs = self._build_kwargs(batch, batch_idx, dataloader_idx)

        # extract batch_indices and store them
        self._store_batch_indices(dataloader_idx)

        model_ref = self.trainer.lightning_module

        self.trainer.call_hook("on_predict_batch_start", batch, batch_idx, dataloader_idx)

        self.batch_progress.increment_started()

        model_ref._current_fx_name = "predict_step"
        predictions = self.trainer.accelerator.predict_step(step_kwargs)

        self.batch_progress.increment_processed()

        if predictions is None:
            self._warning_cache.warn("predict returned None if it was on purpose, ignore this warning...")

        self.trainer.call_hook("on_predict_batch_end", predictions, batch, batch_idx, dataloader_idx)

        self.batch_progress.increment_completed()

        if self.should_store_predictions:
            # keep stored predictions on the CPU
            self.predictions.append(move_data_to_device(predictions, torch.device("cpu")))

    def _build_kwargs(self, batch: Any, batch_idx: int, dataloader_idx: int) -> Dict[str, Any]:
        """Assembles the keyword arguments for the ``predict_step``

        Args:
            batch: the current batch to run the prediction on
            batch_idx: the index of the current batch
            dataloader_idx: the index of the dataloader producing the current batch

        Returns:
            the dictionary containing all the keyword arguments for the predict step
        """
        step_kwargs = OrderedDict([("batch", batch), ("batch_idx", batch_idx)])
        # `dataloader_idx` is only passed when more than one dataloader is used
        if self._num_dataloaders > 1:
            step_kwargs["dataloader_idx"] = dataloader_idx
        return step_kwargs

    def _store_batch_indices(self, dataloader_idx: int) -> None:
        """Stores the batch indices if the predictions should be stored.

        Warns (once, via this loop's warning cache) when the batch sampler is not an
        ``IndexBatchSamplerWrapper`` and the indices therefore cannot be inferred.
        """
        batch_sampler = self.trainer.predict_dataloaders[dataloader_idx].batch_sampler
        if isinstance(batch_sampler, IndexBatchSamplerWrapper):
            self.current_batch_indices = batch_sampler.batch_indices
            if self.should_store_predictions:
                self._all_batch_indices.append(batch_sampler.batch_indices)
        else:
            self._warning_cache.warn("Lightning couldn't infer the indices fetched for your dataloader.")