class EvaluationLoop(object):

    def __init__(self, trainer):
        self.trainer = trainer
        self.outputs = []
        self.step_metrics = []
        self.predictions = None
        self.max_batches = None
        self.warning_cache = WarningCache()
        self.num_dataloaders = None

    def on_trainer_init(self):
        self.trainer.num_val_batches = []
        self.trainer.num_sanity_val_batches = []
        self.trainer.num_test_batches = []
        self.trainer.test_dataloaders = None
        self.trainer.val_dataloaders = None
        self.trainer.running_sanity_check = False

        # when .test() is called, it sets this
        self.trainer.tested_ckpt_path = None

        # when true, prints test results
        self.trainer.verbose_test = True

    def get_evaluation_dataloaders(self, max_batches):
        model = self.trainer.lightning_module

        # select dataloaders
        if self.trainer.testing:
            self.trainer.reset_test_dataloader(model)
            dataloaders = self.trainer.test_dataloaders
            new_max_batches = self.trainer.num_test_batches
        else:
            # val
            in_sanity_check = self.trainer.running_sanity_check
            should_reload_every_epoch = self.trainer.reload_dataloaders_every_epoch
            if (self.trainer.val_dataloaders is None or should_reload_every_epoch) and not in_sanity_check:
                self.trainer.reset_val_dataloader(model)
            dataloaders = self.trainer.val_dataloaders
            new_max_batches = self.trainer.num_val_batches

        if max_batches is None:
            max_batches = new_max_batches

        return dataloaders, max_batches

    def should_skip_evaluation(self, max_batches):
        return sum(max_batches) == 0

    def on_evaluation_start(self, *args, **kwargs):
        if self.trainer.testing:
            self.trainer.call_hook('on_test_start', *args, **kwargs)
        else:
            self.trainer.call_hook('on_validation_start', *args, **kwargs)

    def on_evaluation_model_eval(self, *_, **__):
        model_ref = self.trainer.lightning_module
        if self.trainer.testing:
            model_ref.on_test_model_eval()
        else:
            model_ref.on_validation_model_eval()

    def on_evaluation_model_train(self, *_, **__):
        model_ref = self.trainer.lightning_module
        if self.trainer.testing:
            model_ref.on_test_model_train()
        else:
            model_ref.on_validation_model_train()

    def on_evaluation_end(self, *args, **kwargs):
        if self.trainer.testing:
            self.trainer.call_hook('on_test_end', *args, **kwargs)
        else:
            self.trainer.call_hook('on_validation_end', *args, **kwargs)

    def reload_evaluation_dataloaders(self):
        model = self.trainer.lightning_module
        if self.trainer.testing:
            self.trainer.reset_test_dataloader(model)
        else:
            self.trainer.reset_val_dataloader(model)

    def setup(self, model, max_batches, dataloaders):
        # bookkeeping
        self.outputs = []
        self.predictions = PredictionCollection(self.trainer.global_rank, self.trainer.world_size)

        # convert max_batches to list
        if isinstance(max_batches, int):
            max_batches = [max_batches] * len(dataloaders)

        self.max_batches = max_batches
        self.num_dataloaders = self._get_num_dataloaders(dataloaders)
        self._predictions = [[] for _ in range(self.num_dataloaders)]

    def on_evaluation_epoch_start(self, *args, **kwargs):
        if self.trainer.testing:
            self.trainer.call_hook('on_test_epoch_start', *args, **kwargs)
        else:
            self.trainer.call_hook('on_validation_epoch_start', *args, **kwargs)

    def _build_args(self, batch, batch_idx, dataloader_idx):
        # make dataloader_idx arg in validation_step optional
        args = [batch, batch_idx]

        multiple_val_loaders = (
            not self.trainer.testing and self._get_num_dataloaders(self.trainer.val_dataloaders) > 1
        )
        multiple_test_loaders = (
            self.trainer.testing and self._get_num_dataloaders(self.trainer.test_dataloaders) > 1
        )

        if multiple_test_loaders or multiple_val_loaders:
            args.append(dataloader_idx)

        return args

    def _get_num_dataloaders(self, dataloaders):
        # case where user does:
        # return dl1, dl2
        length = len(dataloaders)
        if len(dataloaders) > 0 and isinstance(dataloaders[0], (list, tuple)):
            length = len(dataloaders[0])
        return length

    def evaluation_step(self, batch, batch_idx, dataloader_idx):
        # configure args
        args = self._build_args(batch, batch_idx, dataloader_idx)

        model_ref = self.trainer.lightning_module
        model_ref._results = Result()

        if self.trainer.testing:
            model_ref._current_fx_name = "test_step"
            with self.trainer.profiler.profile("test_step"):
                output = self.trainer.accelerator.test_step(args)
        else:
            model_ref._current_fx_name = "validation_step"
            with self.trainer.profiler.profile("validation_step"):
                output = self.trainer.accelerator.validation_step(args)

        # capture any logged information
        self.trainer.logger_connector.cache_logged_metrics()

        # track batch size for weighted average
        is_result_obj = isinstance(output, Result)
        if is_result_obj:
            output.track_batch_size(batch)

        return output

    def evaluation_step_end(self, *args, **kwargs):
        if self.trainer.testing:
            output = self.trainer.call_hook('test_step_end', *args, **kwargs)
        else:
            output = self.trainer.call_hook('validation_step_end', *args, **kwargs)
        return output

    def evaluation_epoch_end(self):
        # unset dataloader_idx in model
        self.trainer.logger_connector.evaluation_epoch_end(self.trainer.testing)

        # call the model epoch end
        deprecated_results = self.__run_eval_epoch_end(self.num_dataloaders)

        # enable returning anything
        for i, r in enumerate(deprecated_results):
            if not isinstance(r, (dict, Result, torch.Tensor)):
                deprecated_results[i] = []

        return deprecated_results

    def log_epoch_metrics_on_evaluation_end(self):
        # get the final loop results
        eval_loop_results = self.trainer.logger_connector.get_evaluate_epoch_results()
        return eval_loop_results

    def __run_eval_epoch_end(self, num_dataloaders):
        model = self.trainer.lightning_module

        # with a single dataloader don't pass an array
        outputs = self.outputs

        eval_results = outputs
        if num_dataloaders == 1:
            eval_results = outputs[0]

        user_reduced = False

        if self.trainer.testing:
            if is_overridden('test_epoch_end', model=model):
                model._current_fx_name = 'test_epoch_end'
                eval_results = model.test_epoch_end(eval_results)
                user_reduced = True
        else:
            if is_overridden('validation_epoch_end', model=model):
                model._current_fx_name = 'validation_epoch_end'
                eval_results = model.validation_epoch_end(eval_results)
                user_reduced = True

        # capture logging
        self.trainer.logger_connector.cache_logged_metrics()

        # deprecation warning
        if eval_results is not None and user_reduced:
            step = 'testing_epoch_end' if self.trainer.testing else 'validation_epoch_end'
            self.warning_cache.warn(
                f'The {step} should not return anything as of 9.1.'
                ' To log, use self.log(...) or self.write(...) directly in the LightningModule'
            )

        if not isinstance(eval_results, list):
            eval_results = [eval_results]

        # track deprecated metrics
        self.trainer.logger_connector.track_metrics_deprecated(eval_results)

        return eval_results

    def __gather_epoch_end_eval_results(self, outputs):
        eval_results = []
        for epoch_output in outputs:
            result = epoch_output[0].__class__.gather(epoch_output)
            if 'checkpoint_on' in result:
                result.checkpoint_on = result.checkpoint_on.mean()
            if 'early_stop_on' in result:
                result.early_stop_on = result.early_stop_on.mean()
            eval_results.append(result)

        # with 1 dataloader don't pass in a list
        if len(eval_results) == 1:
            eval_results = eval_results[0]
        return eval_results

    def __auto_reduce_result_objs(self, outputs):
        # outputs has a list of results per dataloader
        eval_results = []
        for dl_output in outputs:
            result = dl_output[0]
            result = result.__class__.reduce_on_epoch_end(dl_output)
            if 'checkpoint_on' in result:
                result.checkpoint_on = result.checkpoint_on.mean()
            if 'early_stop_on' in result:
                result.early_stop_on = result.early_stop_on.mean()
            eval_results.append(result)

        return eval_results

    def on_predict_epoch_end(self):
        self.trainer._progress_bar_callback.on_test_end(self.trainer, self.trainer.lightning_module)

        results = self._predictions

        def _convert_to_numpy(v):
            return v.cpu().numpy()

        results = apply_to_collection(results, torch.Tensor, _convert_to_numpy)

        return results, None

    def on_evaluation_batch_start(self, batch, batch_idx, dataloader_idx):
        # set dataloader_idx to model and track batch_size
        self.trainer.logger_connector.on_evaluation_batch_start(
            self.trainer.testing, batch, dataloader_idx, self.num_dataloaders
        )

        if self.trainer.testing:
            self.trainer.call_hook('on_test_batch_start', batch, batch_idx, dataloader_idx)
        else:
            self.trainer.call_hook('on_validation_batch_start', batch, batch_idx, dataloader_idx)

    def on_evaluation_batch_end(self, output, batch, batch_idx, dataloader_idx):
        if self.trainer.testing:
            self.trainer.call_hook('on_test_batch_end', output, batch, batch_idx, dataloader_idx)
        else:
            self.trainer.call_hook('on_validation_batch_end', output, batch, batch_idx, dataloader_idx)

        # store predictions if do_write_predictions and track eval loss history
        self.store_predictions(output, batch_idx, dataloader_idx)

    def store_predictions(self, output, batch_idx, dataloader_idx):
        # Add step predictions to prediction collection to write later
        if output is not None:
            do_write_predictions = isinstance(output, Result) and self.trainer.testing
            if do_write_predictions:
                self.predictions.add(output.pop('predictions', None))

        # track debug metrics
        self.trainer.dev_debugger.track_eval_loss_history(batch_idx, dataloader_idx, output)

    def on_evaluation_epoch_end(self, *args, **kwargs):
        # call the callback hook
        if self.trainer.testing:
            self.trainer.call_hook('on_test_epoch_end', *args, **kwargs)
        else:
            self.trainer.call_hook('on_validation_epoch_end', *args, **kwargs)

        self.trainer.call_hook('on_epoch_end')

    def log_evaluation_step_metrics(self, output, batch_idx):
        if self.trainer.running_sanity_check:
            return

        step_log_metrics = {}
        step_pbar_metrics = {}

        self.__log_result_step_metrics(step_log_metrics, step_pbar_metrics, batch_idx)

    def __log_result_step_metrics(self, step_log_metrics, step_pbar_metrics, batch_idx):
        cached_results = self.trainer.logger_connector.cached_results
        cached_batch_pbar_metrics, cached_batch_log_metrics = cached_results.update_logger_connector()

        step_log_metrics.update(cached_batch_log_metrics)
        step_pbar_metrics.update(cached_batch_pbar_metrics)

        if len(step_log_metrics) > 0:
            # make the metrics appear as a different line in the same graph
            metrics_by_epoch = {}
            for k, v in step_log_metrics.items():
                metrics_by_epoch[f'{k}/epoch_{self.trainer.current_epoch}'] = v

            self.trainer.logger_connector.log_metrics(metrics_by_epoch, {}, step=batch_idx)

        if len(step_pbar_metrics) > 0:
            self.trainer.logger_connector.add_progress_bar_metrics(step_pbar_metrics)
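# A minimal, standalone sketch (not Lightning API) of the key renaming done in
# __log_result_step_metrics above: suffixing each step-level metric key with the
# current epoch makes every epoch show up as its own line in the logger's graph.
# The names `step_log_metrics` and `current_epoch` are illustrative assumptions.
def rename_step_metrics_by_epoch(step_log_metrics: dict, current_epoch: int) -> dict:
    return {f"{k}/epoch_{current_epoch}": v for k, v in step_log_metrics.items()}

assert rename_step_metrics_by_epoch({"val_loss": 0.25}, 3) == {"val_loss/epoch_3": 0.25}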
class EvaluationLoop(object):

    def __init__(self, trainer):
        self.trainer = trainer
        self.testing = False
        self.outputs = []
        self.predictions = None
        self.max_batches = None

    def on_trainer_init(self):
        self.trainer.num_val_batches = []
        self.trainer.num_sanity_val_batches = []
        self.trainer.num_test_batches = []
        self.trainer.test_dataloaders = None
        self.trainer.val_dataloaders = None
        self.trainer.running_sanity_check = False
        self.trainer.testing = False

        # when .test() is called, it sets this
        self.trainer.tested_ckpt_path = None

        # when true, prints test results
        self.trainer.verbose_test = True

    def get_evaluation_dataloaders(self, max_batches):
        model = self.trainer.get_model()

        # select dataloaders
        if self.testing:
            self.trainer.reset_test_dataloader(model)
            dataloaders = self.trainer.test_dataloaders
            new_max_batches = self.trainer.num_test_batches
        else:
            # val
            in_sanity_check = self.trainer.running_sanity_check
            should_reload_every_epoch = self.trainer.reload_dataloaders_every_epoch
            if (self.trainer.val_dataloaders is None or should_reload_every_epoch) and not in_sanity_check:
                self.trainer.reset_val_dataloader(model)
            dataloaders = self.trainer.val_dataloaders
            new_max_batches = self.trainer.num_val_batches

        if max_batches is None:
            max_batches = new_max_batches

        return dataloaders, max_batches

    def should_skip_evaluation(self, dataloaders, max_batches):
        # skip when dataloaders aren't defined
        if dataloaders is None:
            return True

        # enable disabling validation step with limit_val_batches = 0
        should_skip = sum(max_batches) == 0
        if should_skip:
            return True

        return False

    def on_evaluation_start(self, *args, **kwargs):
        if self.testing:
            self.trainer.call_hook('on_test_start', *args, **kwargs)
        else:
            self.trainer.call_hook('on_validation_start', *args, **kwargs)

    def on_evaluation_end(self, *args, **kwargs):
        if self.testing:
            self.trainer.call_hook('on_test_end', *args, **kwargs)
        else:
            self.trainer.call_hook('on_validation_end', *args, **kwargs)

    def reload_evaluation_dataloaders(self):
        model = self.trainer.get_model()
        if self.testing:
            self.trainer.reset_test_dataloader(model)
        else:
            self.trainer.reset_val_dataloader(model)

    def is_using_eval_results(self):
        outputs = self.outputs
        using_eval_result = len(outputs) > 0 and len(outputs[0]) > 0 and isinstance(outputs[0][0], EvalResult)
        return using_eval_result

    def setup(self, model, max_batches, dataloaders):
        # copy properties for forward overrides
        self.trainer.model_connector.copy_trainer_model_properties(model)

        # bookkeeping
        self.outputs = []
        self.predictions = PredictionCollection(self.trainer.global_rank, self.trainer.world_size)

        # convert max_batches to list
        if isinstance(max_batches, int):
            max_batches = [max_batches] * len(dataloaders)

        self.max_batches = max_batches

    def on_evaluation_epoch_start(self, *args, **kwargs):
        if self.testing:
            self.trainer.call_hook('on_test_epoch_start', *args, **kwargs)
        else:
            self.trainer.call_hook('on_validation_epoch_start', *args, **kwargs)

    def build_args(self, test_mode, batch, batch_idx, dataloader_idx):
        # make dataloader_idx arg in validation_step optional
        args = [batch, batch_idx]

        multiple_val_loaders = (not test_mode and len(self.trainer.val_dataloaders) > 1)
        multiple_test_loaders = (test_mode and len(self.trainer.test_dataloaders) > 1)

        if multiple_test_loaders or multiple_val_loaders:
            args.append(dataloader_idx)

        return args

    def evaluation_step(self, test_mode, batch, batch_idx, dataloader_idx):
        # configure args
        args = self.build_args(test_mode, batch, batch_idx, dataloader_idx)

        # run actual test step
        if self.testing:
            output = self.trainer.accelerator_backend.test_step(args)
        else:
            output = self.trainer.accelerator_backend.validation_step(args)

        # track batch size for weighted average
        is_result_obj = isinstance(output, Result)
        if is_result_obj:
            output.track_batch_size(len(batch))

        # allow only EvalResult when using structured results (from val_step)
        if is_result_obj and not isinstance(output, EvalResult):
            m = 'only EvalResults or dicts are allowed from validation_step'
            raise MisconfigurationException(m)

        return output

    def evaluation_step_end(self, *args, **kwargs):
        if self.testing:
            output = self.trainer.call_hook('test_step_end', *args, **kwargs)
        else:
            output = self.trainer.call_hook('validation_step_end', *args, **kwargs)
        return output

    def evaluation_epoch_end(self, num_dataloaders):
        using_eval_result = self.is_using_eval_results()

        # call the model epoch end
        eval_results = self.__run_eval_epoch_end(num_dataloaders, using_eval_result)
        return eval_results

    def log_epoch_metrics(self, eval_results, test_mode):
        using_eval_result = self.is_using_eval_results()
        eval_loop_results = self.trainer.logger_connector.on_evaluation_epoch_end(
            eval_results, using_eval_result, test_mode
        )
        return eval_loop_results

    def __run_eval_epoch_end(self, num_dataloaders, using_eval_result):
        model = self.trainer.get_model()

        # with a single dataloader don't pass an array
        outputs = self.outputs

        eval_results = outputs
        if num_dataloaders == 1:
            eval_results = outputs[0]

        user_reduced = False

        if self.testing:
            if is_overridden('test_epoch_end', model=model):
                if using_eval_result:
                    eval_results = self.__gather_epoch_end_eval_results(outputs)
                eval_results = model.test_epoch_end(eval_results)
                user_reduced = True
        else:
            if is_overridden('validation_epoch_end', model=model):
                if using_eval_result:
                    eval_results = self.__gather_epoch_end_eval_results(outputs)
                eval_results = model.validation_epoch_end(eval_results)
                user_reduced = True

        if using_eval_result and not user_reduced:
            eval_results = self.__auto_reduce_result_objs(outputs)

        if not isinstance(eval_results, list):
            eval_results = [eval_results]

        return eval_results

    def __gather_epoch_end_eval_results(self, outputs):
        eval_results = []
        for epoch_output in outputs:
            result = epoch_output[0].__class__.gather(epoch_output)
            if 'checkpoint_on' in result:
                result.checkpoint_on = result.checkpoint_on.mean()
            if 'early_stop_on' in result:
                result.early_stop_on = result.early_stop_on.mean()
            eval_results.append(result)

        # with 1 dataloader don't pass in a list
        if len(eval_results) == 1:
            eval_results = eval_results[0]
        return eval_results

    def __auto_reduce_result_objs(self, outputs):
        # outputs has a list of results per dataloader
        eval_results = []
        for dl_output in outputs:
            result = dl_output[0]
            result = result.__class__.reduce_on_epoch_end(dl_output)
            if 'checkpoint_on' in result:
                result.checkpoint_on = result.checkpoint_on.mean()
            if 'early_stop_on' in result:
                result.early_stop_on = result.early_stop_on.mean()
            eval_results.append(result)

        return eval_results

    def on_evaluation_batch_start(self, *args, **kwargs):
        # reset the result of the PL module
        model = self.trainer.get_model()
        model._results = Result()
        model._current_fx_name = 'evaluation_step'

        if self.testing:
            self.trainer.call_hook('on_test_batch_start', *args, **kwargs)
        else:
            self.trainer.call_hook('on_validation_batch_start', *args, **kwargs)

    def on_evaluation_batch_end(self, *args, **kwargs):
        if self.testing:
            self.trainer.call_hook('on_test_batch_end', *args, **kwargs)
        else:
            self.trainer.call_hook('on_validation_batch_end', *args, **kwargs)

    def evaluation_batch_end_cleanup(self, output, batch_idx, dataloader_idx):
        # Add step predictions to prediction collection to write later
        if output is not None:
            do_write_predictions = isinstance(output, Result) and self.testing
            if do_write_predictions:
                self.predictions.add(output.pop('predictions', None))

        # track debug metrics
        self.trainer.dev_debugger.track_eval_loss_history(self.testing, batch_idx, dataloader_idx, output)

    def on_evaluation_epoch_end(self, *args, **kwargs):
        # call the callback hook
        if self.testing:
            self.trainer.call_hook('on_test_epoch_end', *args, **kwargs)
        else:
            self.trainer.call_hook('on_validation_epoch_end', *args, **kwargs)

    def log_evaluation_step_metrics(self, output, batch_idx):
        if self.trainer.running_sanity_check:
            return

        results = self.trainer.get_model()._results
        self.__log_result_step_metrics(results, batch_idx)

        # TODO: deprecate at 1.0
        if isinstance(output, EvalResult):
            self.__log_result_step_metrics(output, batch_idx)

    def __log_result_step_metrics(self, output, batch_idx):
        step_log_metrics = output.batch_log_metrics
        step_pbar_metrics = output.batch_pbar_metrics

        if len(step_log_metrics) > 0:
            # make the metrics appear as a different line in the same graph
            metrics_by_epoch = {}
            for k, v in step_log_metrics.items():
                metrics_by_epoch[f'{k}/epoch_{self.trainer.current_epoch}'] = v

            self.trainer.logger_connector.log_metrics(metrics_by_epoch, {}, step=batch_idx)

        if len(step_pbar_metrics) > 0:
            self.trainer.logger_connector.add_progress_bar_metrics(step_pbar_metrics)
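# Standalone sketch (illustrative names, not Lightning API) of the pattern used by
# build_args above: dataloader_idx is only appended when more than one dataloader is
# configured, so a single-dataloader validation_step keeps the (batch, batch_idx) signature.
def build_step_args(batch, batch_idx, dataloader_idx, num_dataloaders):
    args = [batch, batch_idx]
    if num_dataloaders > 1:
        args.append(dataloader_idx)
    return args

assert build_step_args("b", 0, 0, num_dataloaders=1) == ["b", 0]
assert build_step_args("b", 0, 2, num_dataloaders=3) == ["b", 0, 2]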
class EvaluationEpochLoop(Loop):
    """This is the loop performing the evaluation.

    It mainly loops over the given dataloader and runs the validation or test step
    (depending on the trainer's current state).
    """

    def __init__(self) -> None:
        super().__init__()
        self.predictions: Optional[PredictionCollection] = None
        self.dataloader: Optional[Iterator] = None
        self._dl_max_batches: Optional[int] = None
        self._num_dataloaders: Optional[int] = None
        self.outputs: List[STEP_OUTPUT] = []
        self.batch_progress = Progress()

    @property
    def done(self) -> bool:
        """Returns ``True`` if the current iteration count reaches the number of dataloader batches."""
        return self.batch_progress.current.completed >= self._dl_max_batches

    def connect(self, **kwargs: "Loop") -> None:
        raise NotImplementedError(f"{self.__class__.__name__} does not connect any child loops.")

    def reset(self) -> None:
        """Resets the loop's internal state."""
        self.predictions = PredictionCollection(self.trainer.global_rank, self.trainer.world_size)
        self._dl_max_batches = None
        self._num_dataloaders = None
        self.outputs = []

        if not self.restarting:
            self.batch_progress.current.reset()

    def on_run_start(
        self, dataloader_iter: Iterator, dataloader_idx: int, dl_max_batches: int, num_dataloaders: int
    ) -> None:
        """Adds the passed arguments to the loop's state if necessary.

        Args:
            dataloader_iter: iterator over the dataloader
            dataloader_idx: index of the current dataloader
            dl_max_batches: maximum number of batches the dataloader can produce
            num_dataloaders: the total number of dataloaders
        """
        void(dataloader_iter, dataloader_idx)
        self._dl_max_batches = dl_max_batches
        self._num_dataloaders = num_dataloaders

    def advance(
        self, dataloader_iter: Iterator, dataloader_idx: int, dl_max_batches: int, num_dataloaders: int
    ) -> None:
        """Calls the evaluation step with the corresponding hooks and updates the logger connector.

        Args:
            dataloader_iter: iterator over the dataloader
            dataloader_idx: index of the current dataloader
            dl_max_batches: maximum number of batches the dataloader can produce
            num_dataloaders: the total number of dataloaders

        Raises:
            StopIteration: If the current batch is None
        """
        void(dl_max_batches, num_dataloaders)

        batch_idx, batch = next(dataloader_iter)

        if batch is None:
            raise StopIteration

        with self.trainer.profiler.profile("evaluation_batch_to_device"):
            batch = self.trainer.accelerator.batch_to_device(batch, dataloader_idx=dataloader_idx)

        self.batch_progress.increment_ready()

        # hook
        self.on_evaluation_batch_start(batch, batch_idx, dataloader_idx)

        self.batch_progress.increment_started()

        # lightning module methods
        with self.trainer.profiler.profile("evaluation_step_and_end"):
            output = self.evaluation_step(batch, batch_idx, dataloader_idx)
            output = self.evaluation_step_end(output)

        self.batch_progress.increment_processed()

        # hook + store predictions
        self.on_evaluation_batch_end(output, batch, batch_idx, dataloader_idx)

        self.batch_progress.increment_completed()

        # log batch metrics
        self.trainer.logger_connector.update_eval_step_metrics()

        # track epoch level outputs
        self.outputs = self._track_output_for_epoch_end(self.outputs, output)

    def on_run_end(self) -> List[STEP_OUTPUT]:
        """Returns the outputs of the whole run."""
        outputs = self.outputs
        # free memory
        self.outputs = []
        return outputs

    def evaluation_step(self, batch: Any, batch_idx: int, dataloader_idx: int) -> Optional[STEP_OUTPUT]:
        """The evaluation step (validation_step or test_step depending on the trainer's state).

        Args:
            batch: The current batch to run through the step.
            batch_idx: The index of the current batch
            dataloader_idx: the index of the dataloader producing the current batch

        Returns:
            the outputs of the step
        """
        # configure step_kwargs
        step_kwargs = self._build_kwargs(batch, batch_idx, dataloader_idx)

        if self.trainer.testing:
            self.trainer.lightning_module._current_fx_name = "test_step"
            with self.trainer.profiler.profile("test_step"):
                output = self.trainer.accelerator.test_step(step_kwargs)
        else:
            self.trainer.lightning_module._current_fx_name = "validation_step"
            with self.trainer.profiler.profile("validation_step"):
                output = self.trainer.accelerator.validation_step(step_kwargs)

        return output

    def evaluation_step_end(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]:
        """Calls the `{validation/test}_step_end` hook."""
        hook_name = "test_step_end" if self.trainer.testing else "validation_step_end"
        output = self.trainer.call_hook(hook_name, *args, **kwargs)
        return output

    def on_evaluation_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
        """Calls the ``on_{validation/test}_batch_start`` hook.

        Args:
            batch: The current batch to run through the step
            batch_idx: The index of the current batch
            dataloader_idx: The index of the dataloader producing the current batch

        Raises:
            AssertionError: If the number of dataloaders is None (has not yet been set).
        """
        self.trainer.logger_connector.on_batch_start()

        assert self._num_dataloaders is not None
        self.trainer.logger_connector.on_evaluation_batch_start(batch, batch_idx, dataloader_idx, self._num_dataloaders)

        if self.trainer.testing:
            self.trainer.call_hook("on_test_batch_start", batch, batch_idx, dataloader_idx)
        else:
            self.trainer.call_hook("on_validation_batch_start", batch, batch_idx, dataloader_idx)

    def on_evaluation_batch_end(
        self, output: Optional[STEP_OUTPUT], batch: Any, batch_idx: int, dataloader_idx: int
    ) -> None:
        """The ``on_{validation/test}_batch_end`` hook.

        Args:
            output: The output of the performed step
            batch: The input batch for the step
            batch_idx: The index of the current batch
            dataloader_idx: Index of the dataloader producing the current batch
        """
        hook_name = "on_test_batch_end" if self.trainer.testing else "on_validation_batch_end"
        self.trainer.call_hook(hook_name, output, batch, batch_idx, dataloader_idx)

        self.trainer.logger_connector.on_batch_end()

        # store predictions if do_write_predictions and track eval loss history
        self.store_predictions(output, batch_idx, dataloader_idx)

    def store_predictions(self, output: Optional[STEP_OUTPUT], batch_idx: int, dataloader_idx: int) -> None:
        """Stores the predictions in the prediction collection (only if running in test mode).

        Args:
            output: the outputs of the current step
            batch_idx: the index of the current batch
            dataloader_idx: the index of the dataloader producing the current batch
        """
        # Add step predictions to prediction collection to write later
        if output is not None and self.predictions is not None:
            if isinstance(output, ResultCollection) and self.trainer.testing:
                self.predictions.add(output.pop("predictions", None))

        # track debug metrics
        self.trainer.dev_debugger.track_eval_loss_history(batch_idx, dataloader_idx, output)

    def _build_kwargs(self, batch: Any, batch_idx: int, dataloader_idx: int) -> Dict[str, Union[Any, int]]:
        """Helper function to build the arguments for the current step.

        Args:
            batch: The current batch to run through the step
            batch_idx: the index of the current batch
            dataloader_idx: the index of the dataloader producing the current batch

        Returns:
            the keyword arguments to pass to the step function
        """
        # make dataloader_idx arg in validation_step optional
        step_kwargs = OrderedDict([("batch", batch), ("batch_idx", batch_idx)])

        multiple_val_loaders = not self.trainer.testing and self._num_dataloaders > 1
        multiple_test_loaders = self.trainer.testing and self._num_dataloaders > 1

        if multiple_test_loaders or multiple_val_loaders:
            step_kwargs["dataloader_idx"] = dataloader_idx

        return step_kwargs

    def _track_output_for_epoch_end(
        self,
        outputs: List[Union[ResultCollection, Dict, Tensor]],
        output: Optional[Union[ResultCollection, Dict, Tensor]],
    ) -> List[Union[ResultCollection, Dict, Tensor]]:
        if output is not None:
            if isinstance(output, ResultCollection):
                output = output.detach()
                if self.trainer.move_metrics_to_cpu:
                    output = output.cpu()
            elif isinstance(output, dict):
                output = recursive_detach(output, to_cpu=self.trainer.move_metrics_to_cpu)
            elif isinstance(output, Tensor) and output.is_cuda and self.trainer.move_metrics_to_cpu:
                output = output.cpu()
            outputs.append(output)
        return outputs
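# Hedged sketch of the protocol EvaluationEpochLoop implements. A generic driver
# (roughly what a Loop base class `run` could look like; this is an assumption for
# illustration, not Lightning's actual implementation) resets state, then calls
# `advance` until `done` is True or the iterator is exhausted, and finally collects
# the accumulated outputs from `on_run_end`.
def run_loop(loop, *args, **kwargs):
    loop.reset()
    loop.on_run_start(*args, **kwargs)
    while not loop.done:
        try:
            loop.advance(*args, **kwargs)
        except StopIteration:
            break
    return loop.on_run_end()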
class EvaluationLoop(object):

    def __init__(self, trainer: 'pl.Trainer'):
        self.trainer: 'pl.Trainer' = trainer
        self.outputs: EPOCH_OUTPUT = []
        self.predictions: Optional[PredictionCollection] = None
        self.max_batches: Optional[List[Union[int, float]]] = None
        self.warning_cache = WarningCache()
        self.num_dataloaders: Optional[int] = None

    def on_trainer_init(self) -> None:
        self.trainer.num_sanity_val_batches = []
        self.trainer.num_test_batches = []
        self.trainer.num_val_batches = []
        self.trainer.test_dataloaders = None
        self.trainer.val_dataloaders = None

        # .validate() and .test() set this when they load a checkpoint
        self.trainer.validated_ckpt_path = None
        self.trainer.tested_ckpt_path = None

        # when true, print evaluation results in .validate() and .test()
        self.trainer.verbose_evaluate = True

    def get_evaluation_dataloaders(self) -> Tuple[Optional[List[DataLoader]], List[Union[int, float]]]:
        model = self.trainer.lightning_module

        # select dataloaders
        if self.trainer.testing:
            self.trainer.reset_test_dataloader(model)
            dataloaders = self.trainer.test_dataloaders
            max_batches = self.trainer.num_test_batches
        else:
            # val
            if self.trainer.val_dataloaders is None or self.trainer.reload_dataloaders_every_epoch:
                self.trainer.reset_val_dataloader(model)
            if self.trainer.sanity_checking:
                self.trainer.num_sanity_val_batches = [
                    min(self.trainer.num_sanity_val_steps, val_batches)
                    for val_batches in self.trainer.num_val_batches
                ]
                max_batches = self.trainer.num_sanity_val_batches
            else:
                max_batches = self.trainer.num_val_batches
            dataloaders = self.trainer.val_dataloaders
        return dataloaders, max_batches

    def should_skip_evaluation(self, max_batches: List[Union[int, float]]) -> bool:
        return sum(max_batches) == 0

    def on_evaluation_start(self, *args: Any, **kwargs: Any) -> None:
        if self.trainer.testing:
            self.trainer.call_hook('on_test_start', *args, **kwargs)
        else:
            self.trainer.call_hook('on_validation_start', *args, **kwargs)

    def on_evaluation_model_eval(self) -> None:
        model_ref = self.trainer.lightning_module
        if self.trainer.testing:
            model_ref.on_test_model_eval()
        else:
            model_ref.on_validation_model_eval()

    def on_evaluation_model_train(self) -> None:
        model_ref = self.trainer.lightning_module
        if self.trainer.testing:
            model_ref.on_test_model_train()
        else:
            model_ref.on_validation_model_train()

    def on_evaluation_end(self, *args: Any, **kwargs: Any) -> None:
        if self.trainer.testing:
            self.trainer.call_hook('on_test_end', *args, **kwargs)
        else:
            self.trainer.call_hook('on_validation_end', *args, **kwargs)

        if self.trainer.state.fn != TrainerFn.FITTING:
            # summarize profile results
            self.trainer.profiler.describe()

    def reload_evaluation_dataloaders(self) -> None:
        model = self.trainer.lightning_module
        if self.trainer.testing:
            self.trainer.reset_test_dataloader(model)
        else:
            self.trainer.reset_val_dataloader(model)

    def setup(self, max_batches: List[Union[int, float]], dataloaders: List[DataLoader]) -> None:
        # bookkeeping
        self.outputs = []
        self.predictions = PredictionCollection(self.trainer.global_rank, self.trainer.world_size)

        # convert max_batches to list
        if isinstance(max_batches, int):
            max_batches = [max_batches] * len(dataloaders)

        self.max_batches = max_batches
        self.num_dataloaders = self._get_num_dataloaders(dataloaders)

    def on_evaluation_epoch_start(self, *args: Any, **kwargs: Any) -> None:
        self.trainer.call_hook('on_epoch_start', *args, **kwargs)

        if self.trainer.testing:
            self.trainer.call_hook('on_test_epoch_start', *args, **kwargs)
        else:
            self.trainer.call_hook('on_validation_epoch_start', *args, **kwargs)

    def _build_args(self, batch: Any, batch_idx: int, dataloader_idx: int) -> List[Union[Any, int]]:
        # make dataloader_idx arg in validation_step optional
        args = [batch, batch_idx]

        multiple_val_loaders = (
            not self.trainer.testing and self._get_num_dataloaders(self.trainer.val_dataloaders) > 1
        )
        multiple_test_loaders = (
            self.trainer.testing and self._get_num_dataloaders(self.trainer.test_dataloaders) > 1
        )

        if multiple_test_loaders or multiple_val_loaders:
            args.append(dataloader_idx)

        return args

    def _get_num_dataloaders(self, dataloaders: Optional[List[DataLoader]]) -> int:
        # case where user does:
        # return dl1, dl2
        if dataloaders is not None:
            length = len(dataloaders)
            if len(dataloaders) > 0 and isinstance(dataloaders[0], (list, tuple)):
                length = len(dataloaders[0])
            return length
        else:
            return 0

    def evaluation_step(self, batch: Any, batch_idx: int, dataloader_idx: int) -> Optional[STEP_OUTPUT]:
        # configure args
        args = self._build_args(batch, batch_idx, dataloader_idx)

        model_ref = self.trainer.lightning_module
        model_ref._results = Result()

        if self.trainer.testing:
            model_ref._current_fx_name = "test_step"
            with self.trainer.profiler.profile("test_step"):
                output = self.trainer.accelerator.test_step(args)
        else:
            model_ref._current_fx_name = "validation_step"
            with self.trainer.profiler.profile("validation_step"):
                output = self.trainer.accelerator.validation_step(args)

        # capture any logged information
        self.trainer.logger_connector.cache_logged_metrics()

        # track batch size for weighted average
        if isinstance(output, Result):
            output.track_batch_size(batch)

        return output

    def evaluation_step_end(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]:
        if self.trainer.testing:
            output = self.trainer.call_hook('test_step_end', *args, **kwargs)
        else:
            output = self.trainer.call_hook('validation_step_end', *args, **kwargs)
        return output

    def evaluation_epoch_end(self, outputs: EPOCH_OUTPUT) -> None:
        # unset dataloader_idx in model
        self.trainer.logger_connector.evaluation_epoch_end()

        # call the model epoch end
        model = self.trainer.lightning_module

        if self.trainer.testing:
            if is_overridden('test_epoch_end', model=model):
                model._current_fx_name = 'test_epoch_end'
                model.test_epoch_end(outputs)
        else:
            if is_overridden('validation_epoch_end', model=model):
                model._current_fx_name = 'validation_epoch_end'
                model.validation_epoch_end(outputs)

        # capture logging
        self.trainer.logger_connector.cache_logged_metrics()

    def on_evaluation_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
        # set dataloader_idx to model and track batch_size
        self.trainer.logger_connector.on_evaluation_batch_start(batch, dataloader_idx, self.num_dataloaders)

        if self.trainer.testing:
            self.trainer.call_hook('on_test_batch_start', batch, batch_idx, dataloader_idx)
        else:
            self.trainer.call_hook('on_validation_batch_start', batch, batch_idx, dataloader_idx)

    def on_evaluation_batch_end(
        self, output: Optional[STEP_OUTPUT], batch: Any, batch_idx: int, dataloader_idx: int
    ) -> None:
        if self.trainer.testing:
            self.trainer.call_hook('on_test_batch_end', output, batch, batch_idx, dataloader_idx)
        else:
            self.trainer.call_hook('on_validation_batch_end', output, batch, batch_idx, dataloader_idx)

        # store predictions if do_write_predictions and track eval loss history
        self.store_predictions(output, batch_idx, dataloader_idx)

    def store_predictions(self, output: Optional[STEP_OUTPUT], batch_idx: int, dataloader_idx: int) -> None:
        # Add step predictions to prediction collection to write later
        if output is not None and self.predictions is not None:
            if isinstance(output, Result) and self.trainer.testing:
                self.predictions.add(output.pop('predictions', None))

        # track debug metrics
        self.trainer.dev_debugger.track_eval_loss_history(batch_idx, dataloader_idx, output)

    def on_evaluation_epoch_end(self, outputs: Union[List[List[Dict]], List[Dict]]) -> None:
        model_ref = self.trainer.lightning_module
        hook_name = "on_test_epoch_end" if self.trainer.testing else "on_validation_epoch_end"

        self.trainer._reset_result_and_set_hook_fx_name(hook_name)

        with self.trainer.profiler.profile(hook_name):
            if hasattr(self.trainer, hook_name):
                on_evaluation_epoch_end_hook = getattr(self.trainer, hook_name)
                on_evaluation_epoch_end_hook(outputs)

            if is_overridden(hook_name, model_ref):
                model_hook_fx = getattr(model_ref, hook_name)
                if is_param_in_hook_signature(model_hook_fx, "outputs"):
                    model_hook_fx(outputs)
                else:
                    self.warning_cache.warn(
                        f"`ModelHooks.{hook_name}` signature has changed in v1.3. `outputs` parameter has been added."
                        " Support for the old signature will be removed in v1.5",
                        DeprecationWarning
                    )
                    model_hook_fx()

        self.trainer._cache_logged_metrics()

        self.trainer.call_hook('on_epoch_end')

    def log_evaluation_step_metrics(self, batch_idx: int) -> None:
        if self.trainer.sanity_checking:
            return

        cached_results = self.trainer.logger_connector.cached_results
        if cached_results is not None:
            cached_batch_pbar_metrics, cached_batch_log_metrics = cached_results.update_logger_connector()

            if len(cached_batch_log_metrics) > 0:
                # make the metrics appear as a different line in the same graph
                metrics_by_epoch = {}
                for k, v in cached_batch_log_metrics.items():
                    metrics_by_epoch[f'{k}/epoch_{self.trainer.current_epoch}'] = v

                self.trainer.logger_connector.log_metrics(metrics_by_epoch, {}, step=batch_idx)

            if len(cached_batch_pbar_metrics) > 0:
                self.trainer.logger_connector.add_progress_bar_metrics(cached_batch_pbar_metrics)
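# Standalone sketch of the corner case _get_num_dataloaders above handles: when several
# dataloaders are returned together as a tuple or list, they may arrive wrapped as
# [(dl1, dl2)], so the count comes from the inner collection. Names are illustrative only.
def count_dataloaders(dataloaders):
    if dataloaders is None:
        return 0
    if len(dataloaders) > 0 and isinstance(dataloaders[0], (list, tuple)):
        return len(dataloaders[0])
    return len(dataloaders)

assert count_dataloaders([("dl1", "dl2")]) == 2
assert count_dataloaders(["dl1", "dl2", "dl3"]) == 3
assert count_dataloaders(None) == 0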
class EvaluationLoop(object):

    def __init__(self, trainer):
        self.trainer = trainer
        self.testing = False
        self.outputs = []
        self.predictions = None
        self.max_batches = None

    def is_using_eval_results(self):
        outputs = self.outputs
        using_eval_result = len(outputs) > 0 and len(outputs[0]) > 0 and isinstance(outputs[0][0], EvalResult)
        return using_eval_result

    def setup(self, model, max_batches, dataloaders):
        # enable eval mode
        model.zero_grad()
        model.eval()

        # copy properties for forward overrides
        self.trainer.copy_trainer_model_properties(model)

        # disable gradients to save memory
        torch.set_grad_enabled(False)

        # bookkeeping
        self.outputs = []
        self.predictions = PredictionCollection(self.trainer.global_rank, self.trainer.world_size)

        # convert max_batches to list
        if isinstance(max_batches, int):
            max_batches = [max_batches] * len(dataloaders)

        self.max_batches = max_batches

    def on_evaluation_epoch_start(self, *args, **kwargs):
        if self.testing:
            self.trainer.call_hook('on_test_epoch_start', *args, **kwargs)
        else:
            self.trainer.call_hook('on_validation_epoch_start', *args, **kwargs)

    def build_args(self, test_mode, batch, batch_idx, dataloader_idx):
        # make dataloader_idx arg in validation_step optional
        args = [batch, batch_idx]

        multiple_val_loaders = (not test_mode and len(self.trainer.val_dataloaders) > 1)
        multiple_test_loaders = (test_mode and len(self.trainer.test_dataloaders) > 1)

        if multiple_test_loaders or multiple_val_loaders:
            args.append(dataloader_idx)

        return args

    def evaluation_step(self, test_mode, batch, batch_idx, dataloader_idx):
        # configure args
        args = self.build_args(test_mode, batch, batch_idx, dataloader_idx)

        # run actual test step
        if self.testing:
            output = self.trainer.accelerator_backend.test_step(args)
        else:
            output = self.trainer.accelerator_backend.validation_step(args)

        # track batch size for weighted average
        is_result_obj = isinstance(output, Result)
        if is_result_obj:
            output.track_batch_size(len(batch))

        # allow only EvalResult when using structured results (from val_step)
        if is_result_obj and not isinstance(output, EvalResult):
            m = 'only EvalResults or dicts are allowed from validation_step'
            raise MisconfigurationException(m)

        return output

    def evaluation_step_end(self, *args, **kwargs):
        if self.testing:
            output = self.trainer.call_hook('test_step_end', *args, **kwargs)
        else:
            output = self.trainer.call_hook('validation_step_end', *args, **kwargs)
        return output

    def on_evaluation_batch_start(self, *args, **kwargs):
        if self.testing:
            self.trainer.call_hook('on_test_batch_start', *args, **kwargs)
        else:
            self.trainer.call_hook('on_validation_batch_start', *args, **kwargs)

    def on_evaluation_batch_end(self, *args, **kwargs):
        if self.testing:
            self.trainer.call_hook('on_test_batch_end', *args, **kwargs)
        else:
            self.trainer.call_hook('on_validation_batch_end', *args, **kwargs)

    def evaluation_batch_end_cleanup(self, output, batch_idx, dataloader_idx):
        # Add step predictions to prediction collection to write later
        if output is not None:
            do_write_predictions = isinstance(output, Result) and self.testing
            if do_write_predictions:
                self.predictions.add(output.pop('predictions', None))

        # track debug metrics
        self.trainer.dev_debugger.track_eval_loss_history(self.testing, batch_idx, dataloader_idx, output)

    def on_evaluation_epoch_end(self, *args, **kwargs):
        if self.testing:
            self.trainer.call_hook('on_test_epoch_end', *args, **kwargs)
        else:
            self.trainer.call_hook('on_validation_epoch_end', *args, **kwargs)

    def log_metrics(self, output, batch_idx):
        if self.trainer.running_sanity_check:
            return

        if isinstance(output, EvalResult):
            step_log_metrics = output.batch_log_metrics
            step_pbar_metrics = output.batch_pbar_metrics

            if len(step_log_metrics) > 0:
                # make the metrics appear as a different line in the same graph
                metrics_by_epoch = {}
                for k, v in step_log_metrics.items():
                    metrics_by_epoch[f'{k}/epoch_{self.trainer.current_epoch}'] = v

                self.trainer.log_metrics(metrics_by_epoch, {}, step=batch_idx)

            if len(step_pbar_metrics) > 0:
                self.trainer.add_progress_bar_metrics(step_pbar_metrics)
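# Minimal PyTorch sketch of what setup() above arranges around evaluation: eval mode
# (affects dropout and batch norm) plus disabled gradients to save memory, restored
# afterwards. `model` and `loader` are placeholder names for illustration.
import torch


def run_eval(model, loader):
    model.eval()
    torch.set_grad_enabled(False)
    try:
        return [model(batch) for batch in loader]
    finally:
        torch.set_grad_enabled(True)
        model.train()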
def _evaluate(
    self,
    model: LightningModule,
    dataloaders: List[DataLoader],
    max_batches: Union[int, List[int]],
    test_mode: bool = False
):
    """Run evaluation code.

    Args:
        model: The model to evaluate.
        dataloaders: A list of PyTorch dataloaders.
        max_batches: An integer or list of integers with length of the number of dataloaders. Each
            entry is the number of batches to process in the corresponding dataloader.
        test_mode:
    """
    # enable eval mode
    model.zero_grad()
    model.eval()

    # copy properties for forward overrides
    self.copy_trainer_model_properties(model)

    # disable gradients to save memory
    torch.set_grad_enabled(False)

    # bookkeeping
    outputs = []
    predictions = PredictionCollection(self.global_rank, self.world_size)

    # convert max_batches to list
    if isinstance(max_batches, int):
        max_batches = [max_batches] * len(dataloaders)

    # --------------------------
    # ON_EVAL_EPOCH_START hook
    # --------------------------
    self.__call_eval_loop_hook_start(test_mode)

    # run validation
    for dataloader_idx, dataloader in enumerate(dataloaders):
        dl_outputs = []

        # on TPU we have to wrap it under the ParallelLoader
        if self.use_tpu:
            device = xm.xla_device(self.tpu_id)
            dataloader = xla_pl.ParallelLoader(dataloader, [device])
            dataloader = dataloader.per_device_loader(device)

        # each dataloader has a max num batches
        dl_max_batches = max_batches[dataloader_idx]

        for batch_idx, batch in enumerate(dataloader):
            if batch is None:
                continue

            # stop short when running on limited batches
            if batch_idx >= dl_max_batches:
                break

            # callbacks
            if test_mode:
                self.on_test_batch_start(batch, batch_idx, dataloader_idx)
                if self.is_overridden('on_test_batch_start'):
                    model_ref = self.get_model()
                    with self.profiler.profile('on_test_batch_start'):
                        model_ref.on_test_batch_start(batch)
            else:
                self.on_validation_batch_start(batch, batch_idx, dataloader_idx)
                if self.is_overridden('on_validation_batch_start'):
                    model_ref = self.get_model()
                    with self.profiler.profile('on_validation_batch_start'):
                        model_ref.on_validation_batch_start(batch)

            # -----------------
            # RUN EVALUATION STEP
            # -----------------
            if self.amp_backend == AMPType.NATIVE and not self.use_tpu:
                with torch.cuda.amp.autocast():
                    output = self.evaluation_forward(model, batch, batch_idx, dataloader_idx, test_mode)
            else:
                output = self.evaluation_forward(model, batch, batch_idx, dataloader_idx, test_mode)

            is_result_obj = isinstance(output, Result)

            # track batch size for weighted average
            if is_result_obj:
                output.track_batch_size(len(batch))

            # allow only EvalResult when using structured results (from val_step)
            if is_result_obj and not isinstance(output, EvalResult):
                m = 'only EvalResults or dicts are allowed from validation_step'
                raise MisconfigurationException(m)

            # ------------------
            # EVAL STEP END
            # ------------------
            # on dp / ddp2 might still want to do something with the batch parts
            eval_step_end_hook_name = 'test_step_end' if test_mode else 'validation_step_end'
            if self.is_overridden(eval_step_end_hook_name):
                model_ref = self.get_model()
                with self.profiler.profile(eval_step_end_hook_name):
                    eval_step_end = getattr(model_ref, eval_step_end_hook_name)
                    output = eval_step_end(output)
            elif is_result_obj and (self.use_dp or self.use_ddp2):
                # result auto reduce
                output.dp_reduce()

            # callbacks (on __batch_end)
            if test_mode:
                self.on_test_batch_end(batch, batch_idx, dataloader_idx)
                if self.is_overridden('on_test_batch_end'):
                    model_ref = self.get_model()
                    with self.profiler.profile('on_test_batch_end'):
                        model_ref.on_test_batch_end(output)
            else:
                self.on_validation_batch_end(batch, batch_idx, dataloader_idx)
                if self.is_overridden('on_validation_batch_end'):
                    model_ref = self.get_model()
                    with self.profiler.profile('on_validation_batch_end'):
                        model_ref.on_validation_batch_end(output)

            # track outputs for collation
            if output is not None:
                # Add step predictions to prediction collection to write later
                do_write_predictions = is_result_obj and test_mode
                if do_write_predictions:
                    predictions.add(output.pop('predictions', None))

                dl_outputs.append(output)

            self.__eval_add_step_metrics(output, batch_idx)

            # track debug metrics
            self.dev_debugger.track_eval_loss_history(test_mode, batch_idx, dataloader_idx, output)

        outputs.append(dl_outputs)

    # ---------------------
    # EVAL_EPOCH_END
    # ---------------------
    using_eval_result = len(outputs) > 0 and len(outputs[0]) > 0 and isinstance(outputs[0][0], EvalResult)
    eval_results = self.__run_eval_epoch_end(test_mode, outputs, dataloaders, using_eval_result)

    # log callback metrics
    self.__update_callback_metrics(eval_results, using_eval_result)

    # Write predictions to disk if they're available.
    predictions.to_disk()

    # enable train mode again
    model.train()

    # enable gradients again
    torch.set_grad_enabled(True)

    # --------------------------
    # ON_EVAL_EPOCH_END hook
    # --------------------------
    self.__call_eval_loop_hook_end(test_mode)

    return eval_results
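# Standalone sketch of the max_batches handling in _evaluate above: a single int is
# broadcast to one limit per dataloader, and each loader stops early once its limit is
# reached. Plain Python with illustrative names.
def normalize_max_batches(max_batches, num_dataloaders):
    if isinstance(max_batches, int):
        return [max_batches] * num_dataloaders
    return max_batches


def limited(dataloader, dl_max_batches):
    for batch_idx, batch in enumerate(dataloader):
        if batch_idx >= dl_max_batches:
            break
        yield batch_idx, batch


assert normalize_max_batches(2, 3) == [2, 2, 2]
assert [i for i, _ in limited(range(10), 4)] == [0, 1, 2, 3]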
class EvaluationLoop(object):

    def __init__(self, trainer: 'pl.Trainer'):
        self.trainer: 'pl.Trainer' = trainer
        self.outputs: EPOCH_OUTPUT = []
        self.predictions: Optional[PredictionCollection] = None
        self.max_batches: Optional[List[Union[int, float]]] = None
        self.warning_cache = WarningCache()
        self.num_dataloaders: Optional[int] = None

    def on_trainer_init(self) -> None:
        self.trainer.num_sanity_val_batches = []
        self.trainer.num_test_batches = []
        self.trainer.num_val_batches = []
        self.trainer.test_dataloaders = None
        self.trainer.val_dataloaders = None

        # .validate() and .test() set this when they load a checkpoint
        self.trainer.validated_ckpt_path = None
        self.trainer.tested_ckpt_path = None

        # when true, print evaluation results in .validate() and .test()
        self.trainer.verbose_evaluate = True

    def get_evaluation_dataloaders(self) -> Tuple[Optional[List[DataLoader]], List[Union[int, float]]]:
        model = self.trainer.lightning_module

        # select dataloaders
        if self.trainer.testing:
            self.trainer.reset_test_dataloader(model)
            dataloaders = self.trainer.test_dataloaders
            max_batches = self.trainer.num_test_batches
        else:
            # val
            if self.trainer.val_dataloaders is None or self.trainer.reload_dataloaders_every_epoch:
                self.trainer.reset_val_dataloader(model)
            if self.trainer.sanity_checking:
                self.trainer.num_sanity_val_batches = [
                    min(self.trainer.num_sanity_val_steps, val_batches)
                    for val_batches in self.trainer.num_val_batches
                ]
                max_batches = self.trainer.num_sanity_val_batches
            else:
                max_batches = self.trainer.num_val_batches
            dataloaders = self.trainer.val_dataloaders
        return dataloaders, max_batches

    def should_skip_evaluation(self, max_batches: List[Union[int, float]]) -> bool:
        return sum(max_batches) == 0

    def on_evaluation_start(self, *args: Any, **kwargs: Any) -> None:
        self.should_track_batch_outputs_for_epoch_end: bool = self._should_track_batch_outputs_for_epoch_end()

        if self.trainer.testing:
            self.trainer.call_hook('on_test_start', *args, **kwargs)
        else:
            self.trainer.call_hook('on_validation_start', *args, **kwargs)

    def on_evaluation_model_eval(self) -> None:
        model_ref = self.trainer.lightning_module
        if self.trainer.testing:
            model_ref.on_test_model_eval()
        else:
            model_ref.on_validation_model_eval()

    def on_evaluation_model_train(self) -> None:
        model_ref = self.trainer.lightning_module
        if self.trainer.testing:
            model_ref.on_test_model_train()
        else:
            model_ref.on_validation_model_train()

    def on_evaluation_end(self, *args: Any, **kwargs: Any) -> None:
        if self.trainer.testing:
            self.trainer.call_hook('on_test_end', *args, **kwargs)
        else:
            self.trainer.call_hook('on_validation_end', *args, **kwargs)

        if self.trainer.state.fn != TrainerFn.FITTING:
            # summarize profile results
            self.trainer.profiler.describe()

    def reload_evaluation_dataloaders(self) -> None:
        model = self.trainer.lightning_module
        if self.trainer.testing:
            self.trainer.reset_test_dataloader(model)
        else:
            self.trainer.reset_val_dataloader(model)

    def setup(self, max_batches: List[Union[int, float]], dataloaders: List[DataLoader]) -> None:
        # bookkeeping
        self.outputs = []
        self.predictions = PredictionCollection(self.trainer.global_rank, self.trainer.world_size)

        # convert max_batches to list
        if isinstance(max_batches, int):
            max_batches = [max_batches] * len(dataloaders)

        self.max_batches = max_batches
        self.num_dataloaders = self._get_num_dataloaders(dataloaders)

    def on_evaluation_epoch_start(self, *args: Any, **kwargs: Any) -> None:
        self.trainer.call_hook('on_epoch_start', *args, **kwargs)

        if self.trainer.testing:
            self.trainer.call_hook('on_test_epoch_start', *args, **kwargs)
        else:
            self.trainer.call_hook('on_validation_epoch_start', *args, **kwargs)

    def _build_kwargs(self, batch: Any, batch_idx: int, dataloader_idx: int) -> Dict[str, Union[Any, int]]:
        # make dataloader_idx arg in validation_step optional
        step_kwargs = OrderedDict([('batch', batch), ('batch_idx', batch_idx)])

        multiple_val_loaders = (
            not self.trainer.testing and self._get_num_dataloaders(self.trainer.val_dataloaders) > 1
        )
        multiple_test_loaders = (
            self.trainer.testing and self._get_num_dataloaders(self.trainer.test_dataloaders) > 1
        )

        if multiple_test_loaders or multiple_val_loaders:
            step_kwargs['dataloader_idx'] = dataloader_idx

        return step_kwargs

    def _get_num_dataloaders(self, dataloaders: Optional[List[DataLoader]]) -> int:
        # case where user does:
        # return dl1, dl2
        if dataloaders is not None:
            length = len(dataloaders)
            if len(dataloaders) > 0 and isinstance(dataloaders[0], (list, tuple)):
                length = len(dataloaders[0])
            return length
        else:
            return 0

    def evaluation_step(self, batch: Any, batch_idx: int, dataloader_idx: int) -> Optional[STEP_OUTPUT]:
        # configure step_kwargs
        step_kwargs = self._build_kwargs(batch, batch_idx, dataloader_idx)

        model_ref = self.trainer.lightning_module
        model_ref._results = Result()

        if self.trainer.testing:
            model_ref._current_fx_name = "test_step"
            with self.trainer.profiler.profile("test_step"):
                output = self.trainer.accelerator.test_step(step_kwargs)
        else:
            model_ref._current_fx_name = "validation_step"
            with self.trainer.profiler.profile("validation_step"):
                output = self.trainer.accelerator.validation_step(step_kwargs)

        # capture any logged information
        self.trainer.logger_connector.cache_logged_metrics()

        # track batch size for weighted average
        if isinstance(output, Result):
            output.track_batch_size(batch)

        return output

    def evaluation_step_end(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]:
        if self.trainer.testing:
            output = self.trainer.call_hook('test_step_end', *args, **kwargs)
        else:
            output = self.trainer.call_hook('validation_step_end', *args, **kwargs)
        return output

    def _should_track_batch_outputs_for_epoch_end(self) -> bool:
        model = self.trainer.lightning_module
        if self.trainer.testing:
            return is_overridden('test_epoch_end', model=model)
        else:
            return is_overridden('validation_epoch_end', model=model)

    def evaluation_epoch_end(self, outputs: EPOCH_OUTPUT) -> None:
        # unset dataloader_idx in model
        self.trainer.logger_connector.evaluation_epoch_end()

        # call the model epoch end
        model = self.trainer.lightning_module

        if self.trainer.testing:
            if is_overridden('test_epoch_end', model=model):
                model._current_fx_name = 'test_epoch_end'
                model.test_epoch_end(outputs)
        else:
            if is_overridden('validation_epoch_end', model=model):
                model._current_fx_name = 'validation_epoch_end'
                model.validation_epoch_end(outputs)

        # capture logging
        self.trainer.logger_connector.cache_logged_metrics()

    def on_evaluation_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
        # set dataloader_idx to model and track batch_size
        self.trainer.logger_connector.on_evaluation_batch_start(batch, dataloader_idx, self.num_dataloaders)

        if self.trainer.testing:
            self.trainer.call_hook('on_test_batch_start', batch, batch_idx, dataloader_idx)
        else:
            self.trainer.call_hook('on_validation_batch_start', batch, batch_idx, dataloader_idx)

    def on_evaluation_batch_end(
        self, output: Optional[STEP_OUTPUT], batch: Any, batch_idx: int, dataloader_idx: int
    ) -> None:
        if self.trainer.testing:
            self.trainer.call_hook('on_test_batch_end', output, batch, batch_idx, dataloader_idx)
        else:
            self.trainer.call_hook('on_validation_batch_end', output, batch, batch_idx, dataloader_idx)

        # store predictions if do_write_predictions and track eval loss history
        self.store_predictions(output, batch_idx, dataloader_idx)

    def store_predictions(self, output: Optional[STEP_OUTPUT], batch_idx: int, dataloader_idx: int) -> None:
        # Add step predictions to prediction collection to write later
        if output is not None and self.predictions is not None:
            if isinstance(output, Result) and self.trainer.testing:
                self.predictions.add(output.pop('predictions', None))

        # track debug metrics
        self.trainer.dev_debugger.track_eval_loss_history(batch_idx, dataloader_idx, output)

    def on_evaluation_epoch_end(self) -> None:
        model_ref = self.trainer.lightning_module
        hook_name = "on_test_epoch_end" if self.trainer.testing else "on_validation_epoch_end"

        self.trainer._reset_result_and_set_hook_fx_name(hook_name)

        with self.trainer.profiler.profile(hook_name):
            if hasattr(self.trainer, hook_name):
                on_evaluation_epoch_end_hook = getattr(self.trainer, hook_name)
                on_evaluation_epoch_end_hook()

            if is_overridden(hook_name, model_ref):
                model_hook_fx = getattr(model_ref, hook_name)
                model_hook_fx()

        self.trainer._cache_logged_metrics()

        self.trainer.call_hook('on_epoch_end')
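# Standalone sketch (illustrative, not Lightning API) of the name-based dispatch used in
# on_evaluation_epoch_end above: the hook name is chosen from the trainer state, then
# looked up with getattr on whichever objects actually define it.
def call_hook_by_name(objects, testing):
    hook_name = "on_test_epoch_end" if testing else "on_validation_epoch_end"
    for obj in objects:
        hook = getattr(obj, hook_name, None)
        if callable(hook):
            hook()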
class EvaluationLoop(object): def __init__(self, trainer): self.trainer = trainer self.outputs = [] self.step_metrics = [] self.predictions = None self.max_batches = None self.warning_cache = WarningCache() self.num_dataloaders = None def on_trainer_init(self): self.trainer.num_sanity_val_batches = [] self.trainer.num_test_batches = [] self.trainer.num_val_batches = [] self.trainer.test_dataloaders = None self.trainer.val_dataloaders = None # .validate() and .test() set this when they load a checkpoint self.trainer.validated_ckpt_path = None self.trainer.tested_ckpt_path = None # when true, print evaluation results in .validate() and .test() self.trainer.verbose_evaluate = True def get_evaluation_dataloaders(self): model = self.trainer.lightning_module # select dataloaders if self.trainer.testing: self.trainer.reset_test_dataloader(model) dataloaders = self.trainer.test_dataloaders max_batches = self.trainer.num_test_batches else: # val if self.trainer.val_dataloaders is None or self.trainer.reload_dataloaders_every_epoch: self.trainer.reset_val_dataloader(model) if self.trainer.sanity_checking: self.trainer.num_sanity_val_batches = [ min(self.trainer.num_sanity_val_steps, val_batches) for val_batches in self.trainer.num_val_batches ] max_batches = self.trainer.num_sanity_val_batches else: max_batches = self.trainer.num_val_batches dataloaders = self.trainer.val_dataloaders return dataloaders, max_batches def should_skip_evaluation(self, max_batches): return sum(max_batches) == 0 def on_evaluation_start(self, *args, **kwargs): if self.trainer.testing: self.trainer.call_hook('on_test_start', *args, **kwargs) else: self.trainer.call_hook('on_validation_start', *args, **kwargs) def on_evaluation_model_eval(self, *_, **__): model_ref = self.trainer.lightning_module if self.trainer.testing: model_ref.on_test_model_eval() else: model_ref.on_validation_model_eval() def on_evaluation_model_train(self, *_, **__): model_ref = self.trainer.lightning_module if self.trainer.testing: model_ref.on_test_model_train() else: model_ref.on_validation_model_train() def on_evaluation_end(self, *args, **kwargs): if self.trainer.testing: self.trainer.call_hook('on_test_end', *args, **kwargs) else: self.trainer.call_hook('on_validation_end', *args, **kwargs) if self.trainer.state != TrainerState.FITTING: # summarize profile results self.trainer.profiler.describe() def reload_evaluation_dataloaders(self): model = self.trainer.lightning_module if self.trainer.testing: self.trainer.reset_test_dataloader(model) else: self.trainer.reset_val_dataloader(model) def setup(self, model, max_batches, dataloaders): # bookkeeping self.outputs = [] self.predictions = PredictionCollection(self.trainer.global_rank, self.trainer.world_size) # convert max_batches to list if isinstance(max_batches, int): max_batches = [max_batches] * len(dataloaders) self.max_batches = max_batches self.num_dataloaders = self._get_num_dataloaders(dataloaders) self._predictions = [[] for _ in range(self.num_dataloaders)] def on_evaluation_epoch_start(self, *args, **kwargs): self.trainer.call_hook('on_epoch_start', *args, **kwargs) if self.trainer.testing: self.trainer.call_hook('on_test_epoch_start', *args, **kwargs) else: self.trainer.call_hook('on_validation_epoch_start', *args, **kwargs) def _build_args(self, batch, batch_idx, dataloader_idx): # make dataloader_idx arg in validation_step optional args = [batch, batch_idx] multiple_val_loaders = ( not self.trainer.testing and self._get_num_dataloaders(self.trainer.val_dataloaders) > 1) 
        multiple_test_loaders = (self.trainer.testing and self._get_num_dataloaders(self.trainer.test_dataloaders) > 1)

        if multiple_test_loaders or multiple_val_loaders:
            args.append(dataloader_idx)

        return args

    def _get_num_dataloaders(self, dataloaders):
        # case where user does:
        # return dl1, dl2
        length = len(dataloaders)
        if len(dataloaders) > 0 and isinstance(dataloaders[0], (list, tuple)):
            length = len(dataloaders[0])
        return length

    def evaluation_step(self, batch, batch_idx, dataloader_idx):
        # configure args
        args = self._build_args(batch, batch_idx, dataloader_idx)

        model_ref = self.trainer.lightning_module
        model_ref._results = Result()

        if self.trainer.testing:
            model_ref._current_fx_name = "test_step"
            with self.trainer.profiler.profile("test_step"):
                output = self.trainer.accelerator.test_step(args)
        else:
            model_ref._current_fx_name = "validation_step"
            with self.trainer.profiler.profile("validation_step"):
                output = self.trainer.accelerator.validation_step(args)

        # capture any logged information
        self.trainer.logger_connector.cache_logged_metrics()

        # track batch size for weighted average
        is_result_obj = isinstance(output, Result)
        if is_result_obj:
            output.track_batch_size(batch)

        return output

    def evaluation_step_end(self, *args, **kwargs):
        if self.trainer.testing:
            output = self.trainer.call_hook('test_step_end', *args, **kwargs)
        else:
            output = self.trainer.call_hook('validation_step_end', *args, **kwargs)
        return output

    def evaluation_epoch_end(self, outputs):
        # unset dataloader_idx in model
        self.trainer.logger_connector.evaluation_epoch_end()

        # call the model epoch end
        model = self.trainer.lightning_module

        if self.trainer.testing:
            if is_overridden('test_epoch_end', model=model):
                model._current_fx_name = 'test_epoch_end'
                model.test_epoch_end(outputs)
        else:
            if is_overridden('validation_epoch_end', model=model):
                model._current_fx_name = 'validation_epoch_end'
                model.validation_epoch_end(outputs)

        # capture logging
        self.trainer.logger_connector.cache_logged_metrics()

    def __gather_epoch_end_eval_results(self, outputs):
        eval_results = []
        for epoch_output in outputs:
            result = epoch_output[0].__class__.gather(epoch_output)
            eval_results.append(result)

        # with 1 dataloader don't pass in a list
        if len(eval_results) == 1:
            eval_results = eval_results[0]
        return eval_results

    def __auto_reduce_result_objs(self, outputs):
        # outputs has a list of results per dataloader
        eval_results = []
        for dl_output in outputs:
            result = dl_output[0]
            result = result.__class__.reduce_on_epoch_end(dl_output)
            eval_results.append(result)

        return eval_results

    def on_predict_epoch_end(self):
        self.trainer._progress_bar_callback.on_test_end(self.trainer, self.trainer.lightning_module)

        results = self._predictions

        def _convert_to_numpy(v):
            return v.cpu().numpy()

        results = apply_to_collection(results, torch.Tensor, _convert_to_numpy)

        return results, None

    def on_evaluation_batch_start(self, batch, batch_idx, dataloader_idx):
        # set dataloader_idx to model and track batch_size
        self.trainer.logger_connector.on_evaluation_batch_start(batch, dataloader_idx, self.num_dataloaders)

        if self.trainer.testing:
            self.trainer.call_hook('on_test_batch_start', batch, batch_idx, dataloader_idx)
        else:
            self.trainer.call_hook('on_validation_batch_start', batch, batch_idx, dataloader_idx)

    def on_evaluation_batch_end(self, output, batch, batch_idx, dataloader_idx):
        if self.trainer.testing:
            self.trainer.call_hook('on_test_batch_end', output, batch, batch_idx, dataloader_idx)
        else:
            self.trainer.call_hook('on_validation_batch_end', output, batch, batch_idx, dataloader_idx)
        # store predictions if do_write_predictions and track eval loss history
        self.store_predictions(output, batch_idx, dataloader_idx)

    def store_predictions(self, output, batch_idx, dataloader_idx):
        # Add step predictions to prediction collection to write later
        if output is not None:
            do_write_predictions = isinstance(output, Result) and self.trainer.testing
            if do_write_predictions:
                self.predictions.add(output.pop('predictions', None))

        # track debug metrics
        self.trainer.dev_debugger.track_eval_loss_history(batch_idx, dataloader_idx, output)

    def on_evaluation_epoch_end(self, outputs: Union[List[List[Dict]], List[Dict]]) -> None:
        model_ref = self.trainer.lightning_module
        hook_name = "on_test_epoch_end" if self.trainer.testing else "on_validation_epoch_end"

        self.trainer._reset_result_and_set_hook_fx_name(hook_name)

        with self.trainer.profiler.profile(hook_name):
            if hasattr(self.trainer, hook_name):
                on_evaluation_epoch_end_hook = getattr(self.trainer, hook_name)
                on_evaluation_epoch_end_hook(outputs)

            if is_overridden(hook_name, model_ref):
                model_hook_fx = getattr(model_ref, hook_name)
                if is_param_in_hook_signature(model_hook_fx, "outputs"):
                    model_hook_fx(outputs)
                else:
                    self.warning_cache.warn(
                        f"`ModelHooks.{hook_name}` signature has changed in v1.3. `outputs` parameter has been added."
                        " Support for the old signature will be removed in v1.5", DeprecationWarning
                    )
                    model_hook_fx()

        self.trainer._cache_logged_metrics()

        self.trainer.call_hook('on_epoch_end')

    def log_evaluation_step_metrics(self, output, batch_idx):
        if self.trainer.sanity_checking:
            return

        step_log_metrics = {}
        step_pbar_metrics = {}

        self.__log_result_step_metrics(step_log_metrics, step_pbar_metrics, batch_idx)

    def __log_result_step_metrics(self, step_log_metrics, step_pbar_metrics, batch_idx):
        cached_results = self.trainer.logger_connector.cached_results
        cached_batch_pbar_metrics, cached_batch_log_metrics = cached_results.update_logger_connector()

        step_log_metrics.update(cached_batch_log_metrics)
        step_pbar_metrics.update(cached_batch_pbar_metrics)

        if len(step_log_metrics) > 0:
            # make the metrics appear as a different line in the same graph
            metrics_by_epoch = {}
            for k, v in step_log_metrics.items():
                metrics_by_epoch[f'{k}/epoch_{self.trainer.current_epoch}'] = v

            self.trainer.logger_connector.log_metrics(metrics_by_epoch, {}, step=batch_idx)

        if len(step_pbar_metrics) > 0:
            self.trainer.logger_connector.add_progress_bar_metrics(step_pbar_metrics)
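# --- Illustrative sketch (not part of the Lightning source above) ---
# A minimal, standalone example of how a hook's signature can be inspected before
# deciding whether to pass `outputs`, mirroring the deprecation branch handled by
# is_param_in_hook_signature in on_evaluation_epoch_end above. accepts_param and the
# two hook functions are hypothetical stand-ins, not Lightning APIs.
import inspect


def accepts_param(fn, name: str) -> bool:
    # True when `name` is an explicit parameter of fn or fn takes **kwargs
    params = inspect.signature(fn).parameters
    return name in params or any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values())


def old_style_hook(self):  # pre-v1.3 style: no `outputs` argument
    pass


def new_style_hook(self, outputs):  # v1.3 style: receives `outputs`
    pass


assert not accepts_param(old_style_hook, "outputs")
assert accepts_param(new_style_hook, "outputs")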
class EvaluationLoop(object):

    def __init__(self, trainer):
        self.trainer = trainer
        self.testing = False
        self.outputs = []
        self.predictions = None
        self.max_batches = None

    def is_using_eval_results(self):
        outputs = self.outputs
        using_eval_result = len(outputs) > 0 and len(outputs[0]) > 0 and isinstance(outputs[0][0], EvalResult)
        return using_eval_result

    def setup(self, model, max_batches, dataloaders):
        # copy properties for forward overrides
        self.trainer.copy_trainer_model_properties(model)

        # bookkeeping
        self.outputs = []
        self.predictions = PredictionCollection(self.trainer.global_rank, self.trainer.world_size)

        # convert max_batches to list
        if isinstance(max_batches, int):
            max_batches = [max_batches] * len(dataloaders)

        self.max_batches = max_batches

    def on_evaluation_epoch_start(self, *args, **kwargs):
        if self.testing:
            self.trainer.call_hook('on_test_epoch_start', *args, **kwargs)
        else:
            self.trainer.call_hook('on_validation_epoch_start', *args, **kwargs)

    def build_args(self, test_mode, batch, batch_idx, dataloader_idx):
        # make dataloader_idx arg in validation_step optional
        args = [batch, batch_idx]

        multiple_val_loaders = (not test_mode and len(self.trainer.val_dataloaders) > 1)
        multiple_test_loaders = (test_mode and len(self.trainer.test_dataloaders) > 1)

        if multiple_test_loaders or multiple_val_loaders:
            args.append(dataloader_idx)

        return args

    def evaluation_step(self, test_mode, batch, batch_idx, dataloader_idx):
        # configure args
        args = self.build_args(test_mode, batch, batch_idx, dataloader_idx)

        # run actual test step
        if self.testing:
            output = self.trainer.accelerator_backend.test_step(args)
        else:
            output = self.trainer.accelerator_backend.validation_step(args)

        # track batch size for weighted average
        is_result_obj = isinstance(output, Result)
        if is_result_obj:
            output.track_batch_size(len(batch))

        # allow only EvalResult when using structured results (from val_step)
        if is_result_obj and not isinstance(output, EvalResult):
            m = 'only EvalResults or dicts are allowed from validation_step'
            raise MisconfigurationException(m)

        return output

    def evaluation_step_end(self, *args, **kwargs):
        if self.testing:
            output = self.trainer.call_hook('test_step_end', *args, **kwargs)
        else:
            output = self.trainer.call_hook('validation_step_end', *args, **kwargs)
        return output

    def evaluation_epoch_end(self, num_dataloaders):
        using_eval_result = self.is_using_eval_results()

        # call the model epoch end
        eval_results = self.__run_eval_epoch_end(num_dataloaders, using_eval_result)
        return eval_results

    def log_epoch_metrics(self, eval_results):
        using_eval_result = self.is_using_eval_results()
        if using_eval_result:
            if isinstance(eval_results, list):
                for eval_result in eval_results:
                    self.trainer.callback_metrics = eval_result.callback_metrics
            else:
                self.trainer.callback_metrics = eval_results.callback_metrics
        else:
            if isinstance(eval_results, list):
                for eval_result in eval_results:
                    # with a scalar return, auto set it to "val_loss" for callbacks
                    if isinstance(eval_result, torch.Tensor):
                        flat = {'val_loss': eval_result}
                    else:
                        flat = flatten_dict(eval_result)
                    self.trainer.callback_metrics.update(flat)
            else:
                # with a scalar return, auto set it to "val_loss" for callbacks
                if isinstance(eval_results, torch.Tensor):
                    flat = {'val_loss': eval_results}
                else:
                    flat = flatten_dict(eval_results)
                self.trainer.callback_metrics.update(flat)

    def __run_eval_epoch_end(self, num_dataloaders, using_eval_result):
        model = self.trainer.get_model()

        # with a single dataloader don't pass an array
        outputs = self.outputs

        eval_results = outputs
        if num_dataloaders == 1:
            eval_results = outputs[0]

        user_reduced = False

        if self.testing:
            if self.trainer.is_overridden('test_epoch_end', model=model):
                if using_eval_result:
                    eval_results = self.__gather_epoch_end_eval_results(outputs)

                eval_results = model.test_epoch_end(eval_results)
                user_reduced = True
        else:
            if self.trainer.is_overridden('validation_epoch_end', model=model):
                if using_eval_result:
                    eval_results = self.__gather_epoch_end_eval_results(outputs)

                eval_results = model.validation_epoch_end(eval_results)
                user_reduced = True

        if using_eval_result and not user_reduced:
            eval_results = self.__auto_reduce_result_objs(outputs)

        if not isinstance(eval_results, list):
            eval_results = [eval_results]

        return eval_results

    def __gather_epoch_end_eval_results(self, outputs):
        eval_results = []
        for epoch_output in outputs:
            result = epoch_output[0].__class__.gather(epoch_output)

            if 'checkpoint_on' in result:
                result.checkpoint_on = result.checkpoint_on.mean()

            if 'early_stop_on' in result:
                result.early_stop_on = result.early_stop_on.mean()

            eval_results.append(result)

        # with 1 dataloader don't pass in a list
        if len(eval_results) == 1:
            eval_results = eval_results[0]

        return eval_results

    def __auto_reduce_result_objs(self, outputs):
        # outputs has a list of results per dataloader
        eval_results = []
        for dl_output in outputs:
            result = dl_output[0]
            result = result.__class__.reduce_on_epoch_end(dl_output)

            if 'checkpoint_on' in result:
                result.checkpoint_on = result.checkpoint_on.mean()

            if 'early_stop_on' in result:
                result.early_stop_on = result.early_stop_on.mean()

            eval_results.append(result)

        return eval_results

    def on_evaluation_batch_start(self, *args, **kwargs):
        if self.testing:
            self.trainer.call_hook('on_test_batch_start', *args, **kwargs)
        else:
            self.trainer.call_hook('on_validation_batch_start', *args, **kwargs)

    def on_evaluation_batch_end(self, *args, **kwargs):
        if self.testing:
            self.trainer.call_hook('on_test_batch_end', *args, **kwargs)
        else:
            self.trainer.call_hook('on_validation_batch_end', *args, **kwargs)

    def evaluation_batch_end_cleanup(self, output, batch_idx, dataloader_idx):
        # Add step predictions to prediction collection to write later
        if output is not None:
            do_write_predictions = isinstance(output, Result) and self.testing
            if do_write_predictions:
                self.predictions.add(output.pop('predictions', None))

        # track debug metrics
        self.trainer.dev_debugger.track_eval_loss_history(self.testing, batch_idx, dataloader_idx, output)

    def on_evaluation_epoch_end(self, eval_results, *args, **kwargs):
        # log epoch level metrics
        self.log_epoch_metrics(eval_results)

        # Write predictions to disk if they're available
        self.predictions.to_disk()

        # call the callback hook
        if self.testing:
            self.trainer.call_hook('on_test_epoch_end', *args, **kwargs)
        else:
            self.trainer.call_hook('on_validation_epoch_end', *args, **kwargs)

    def log_step_metrics(self, output, batch_idx):
        if self.trainer.running_sanity_check:
            return

        if isinstance(output, EvalResult):
            step_log_metrics = output.batch_log_metrics
            step_pbar_metrics = output.batch_pbar_metrics

            if len(step_log_metrics) > 0:
                # make the metrics appear as a different line in the same graph
                metrics_by_epoch = {}
                for k, v in step_log_metrics.items():
                    metrics_by_epoch[f'{k}/epoch_{self.trainer.current_epoch}'] = v

                self.trainer.log_metrics(metrics_by_epoch, {}, step=batch_idx)

            if len(step_pbar_metrics) > 0:
                self.trainer.add_progress_bar_metrics(step_pbar_metrics)
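# --- Illustrative sketch (not part of the Lightning source above) ---
# A minimal, standalone example of the max_batches normalisation performed in setup():
# an int limit is broadcast to one entry per dataloader, while a per-dataloader list is
# used as-is. normalize_max_batches is a hypothetical helper, not a Lightning API.
from typing import List, Sequence, Union


def normalize_max_batches(max_batches: Union[int, List[int]], dataloaders: Sequence) -> List[int]:
    if isinstance(max_batches, int):
        return [max_batches] * len(dataloaders)
    return max_batches


loaders = ["val_loader_a", "val_loader_b"]  # stand-ins for real DataLoader objects
assert normalize_max_batches(5, loaders) == [5, 5]
assert normalize_max_batches([3, 7], loaders) == [3, 7]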