def _evaluate(
    self,
    model: LightningModule,
    dataloaders: List[DataLoader],
    max_batches: Union[int, List[int]],
    test_mode: bool = False
):
    """Run evaluation code.

    Args:
        model: The model to evaluate.
        dataloaders: A list of PyTorch dataloaders.
        max_batches: An integer or list of integers with length of the number of dataloaders. Each
            entry is the number of batches to process in the corresponding dataloader.
        test_mode: If True, run the test loop; otherwise run the validation loop.
    """
    # enable eval mode
    model.zero_grad()
    model.eval()

    # copy properties for forward overrides
    self.copy_trainer_model_properties(model)

    # disable gradients to save memory
    torch.set_grad_enabled(False)

    # bookkeeping
    outputs = []
    predictions = PredictionCollection(self.global_rank, self.world_size)

    # convert max_batches to list
    if isinstance(max_batches, int):
        max_batches = [max_batches] * len(dataloaders)

    # --------------------------
    # ON_EVAL_EPOCH_START hook
    # --------------------------
    self.__call_eval_loop_hook_start(test_mode)

    # run validation
    for dataloader_idx, dataloader in enumerate(dataloaders):
        dl_outputs = []

        # on TPU we have to wrap it under the ParallelLoader
        if self.use_tpu:
            device = xm.xla_device(self.tpu_id)
            dataloader = xla_pl.ParallelLoader(dataloader, [device])
            dataloader = dataloader.per_device_loader(device)

        # each dataloader has a max num batches
        dl_max_batches = max_batches[dataloader_idx]

        for batch_idx, batch in enumerate(dataloader):
            if batch is None:
                continue

            # stop short when running on limited batches
            if batch_idx >= dl_max_batches:
                break

            # callbacks
            if test_mode:
                self.on_test_batch_start(batch, batch_idx, dataloader_idx)
                if self.is_overridden('on_test_batch_start'):
                    model_ref = self.get_model()
                    with self.profiler.profile('on_test_batch_start'):
                        # pass the batch info to the model hook; `output` does not exist yet
                        model_ref.on_test_batch_start(batch, batch_idx, dataloader_idx)
            else:
                self.on_validation_batch_start(batch, batch_idx, dataloader_idx)
                if self.is_overridden('on_validation_batch_start'):
                    model_ref = self.get_model()
                    with self.profiler.profile('on_validation_batch_start'):
                        # pass the batch info to the model hook; `output` does not exist yet
                        model_ref.on_validation_batch_start(batch, batch_idx, dataloader_idx)

            # -----------------
            # RUN EVALUATION STEP
            # -----------------
            if self.amp_backend == AMPType.NATIVE and not self.use_tpu:
                with torch.cuda.amp.autocast():
                    output = self.evaluation_forward(model, batch, batch_idx, dataloader_idx, test_mode)
            else:
                output = self.evaluation_forward(model, batch, batch_idx, dataloader_idx, test_mode)

            is_result_obj = isinstance(output, Result)

            # track batch size for weighted average
            if is_result_obj:
                output.track_batch_size(len(batch))

            # allow only EvalResult when using structured results (from val_step)
            if is_result_obj and not isinstance(output, EvalResult):
                m = 'only EvalResults or dicts are allowed from validation_step'
                raise MisconfigurationException(m)

            # ------------------
            # EVAL STEP END
            # ------------------
            # on dp / ddp2 might still want to do something with the batch parts
            eval_step_end_hook_name = 'test_step_end' if test_mode else 'validation_step_end'
            if self.is_overridden(eval_step_end_hook_name):
                model_ref = self.get_model()
                with self.profiler.profile(eval_step_end_hook_name):
                    eval_step_end = getattr(model_ref, eval_step_end_hook_name)
                    output = eval_step_end(output)

            elif is_result_obj and (self.use_dp or self.use_ddp2):
                # result auto reduce
                output.dp_reduce()

            # callbacks (on __batch_end)
            if test_mode:
                self.on_test_batch_end(batch, batch_idx, dataloader_idx)
                if self.is_overridden('on_test_batch_end'):
                    model_ref = self.get_model()
                    with self.profiler.profile('on_test_batch_end'):
                        model_ref.on_test_batch_end(output)
            else:
                self.on_validation_batch_end(batch, batch_idx, dataloader_idx)
                if self.is_overridden('on_validation_batch_end'):
                    model_ref = self.get_model()
                    with self.profiler.profile('on_validation_batch_end'):
                        model_ref.on_validation_batch_end(output)

            # track outputs for collation
            if output is not None:
                # Add step predictions to prediction collection to write later
                do_write_predictions = is_result_obj and test_mode
                if do_write_predictions:
                    predictions.add(output.pop('predictions', None))

                dl_outputs.append(output)

            self.__eval_add_step_metrics(output, batch_idx)

            # track debug metrics
            self.dev_debugger.track_eval_loss_history(test_mode, batch_idx, dataloader_idx, output)

        outputs.append(dl_outputs)

    # ---------------------
    # EVAL_EPOCH_END
    # ---------------------
    using_eval_result = len(outputs) > 0 and len(outputs[0]) > 0 and isinstance(outputs[0][0], EvalResult)
    eval_results = self.__run_eval_epoch_end(test_mode, outputs, dataloaders, using_eval_result)

    # log callback metrics
    self.__update_callback_metrics(eval_results, using_eval_result)

    # Write predictions to disk if they're available.
    predictions.to_disk()

    # enable train mode again
    model.train()

    # re-enable gradients so training can resume
    torch.set_grad_enabled(True)

    # --------------------------
    # ON_EVAL_EPOCH_END hook
    # --------------------------
    self.__call_eval_loop_hook_end(test_mode)

    return eval_results
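# ---------------------------------------------------------------------------
# Illustrative sketch (an assumption, not part of the original module): the
# LightningModule side that `_evaluate` consumes. `validation_step` returns an
# EvalResult, which the loop above collects per dataloader and reduces in
# `__run_eval_epoch_end`. `SketchEvalModel`, `val_loss` and the layer sizes are
# hypothetical; only the hook names and the EvalResult type come from the
# surrounding code.
# ---------------------------------------------------------------------------
import torch
import torch.nn.functional as F

from pytorch_lightning import LightningModule
from pytorch_lightning.core.step_result import EvalResult


class SketchEvalModel(LightningModule):
    def __init__(self):
        super().__init__()
        # hypothetical tiny model, just enough to produce a loss
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def validation_step(self, batch, batch_idx):
        # single val dataloader, so `dataloader_idx` is omitted (see build_args below)
        x, y = batch
        loss = F.cross_entropy(self(x), y)

        # EvalResult is the structured-result path checked via `is_result_obj`
        result = EvalResult(checkpoint_on=loss)
        result.log('val_loss', loss)
        return result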
class EvaluationLoop(object):
    def __init__(self, trainer):
        self.trainer = trainer
        self.testing = False
        self.outputs = []
        self.predictions = None
        self.max_batches = None

    def is_using_eval_results(self):
        outputs = self.outputs
        using_eval_result = len(outputs) > 0 and len(outputs[0]) > 0 and isinstance(outputs[0][0], EvalResult)
        return using_eval_result

    def setup(self, model, max_batches, dataloaders):
        # copy properties for forward overrides
        self.trainer.copy_trainer_model_properties(model)

        # bookkeeping
        self.outputs = []
        self.predictions = PredictionCollection(self.trainer.global_rank, self.trainer.world_size)

        # convert max_batches to list
        if isinstance(max_batches, int):
            max_batches = [max_batches] * len(dataloaders)

        self.max_batches = max_batches

    def on_evaluation_epoch_start(self, *args, **kwargs):
        if self.testing:
            self.trainer.call_hook('on_test_epoch_start', *args, **kwargs)
        else:
            self.trainer.call_hook('on_validation_epoch_start', *args, **kwargs)

    def build_args(self, test_mode, batch, batch_idx, dataloader_idx):
        # make dataloader_idx arg in validation_step optional
        args = [batch, batch_idx]

        multiple_val_loaders = (not test_mode and len(self.trainer.val_dataloaders) > 1)
        multiple_test_loaders = (test_mode and len(self.trainer.test_dataloaders) > 1)

        if multiple_test_loaders or multiple_val_loaders:
            args.append(dataloader_idx)

        return args

    def evaluation_step(self, test_mode, batch, batch_idx, dataloader_idx):
        # configure args
        args = self.build_args(test_mode, batch, batch_idx, dataloader_idx)

        # run actual test step
        if self.testing:
            output = self.trainer.accelerator_backend.test_step(args)
        else:
            output = self.trainer.accelerator_backend.validation_step(args)

        # track batch size for weighted average
        is_result_obj = isinstance(output, Result)
        if is_result_obj:
            output.track_batch_size(len(batch))

        # allow only EvalResult when using structured results (from val_step)
        if is_result_obj and not isinstance(output, EvalResult):
            m = 'only EvalResults or dicts are allowed from validation_step'
            raise MisconfigurationException(m)

        return output

    def evaluation_step_end(self, *args, **kwargs):
        if self.testing:
            output = self.trainer.call_hook('test_step_end', *args, **kwargs)
        else:
            output = self.trainer.call_hook('validation_step_end', *args, **kwargs)
        return output

    def evaluation_epoch_end(self, num_dataloaders):
        using_eval_result = self.is_using_eval_results()

        # call the model epoch end
        eval_results = self.__run_eval_epoch_end(num_dataloaders, using_eval_result)
        return eval_results

    def log_epoch_metrics(self, eval_results):
        using_eval_result = self.is_using_eval_results()
        if using_eval_result:
            if isinstance(eval_results, list):
                for eval_result in eval_results:
                    self.trainer.callback_metrics = eval_result.callback_metrics
            else:
                self.trainer.callback_metrics = eval_results.callback_metrics
        else:
            if isinstance(eval_results, list):
                for eval_result in eval_results:
                    # with a scalar return, auto set it to "val_loss" for callbacks
                    if isinstance(eval_result, torch.Tensor):
                        flat = {'val_loss': eval_result}
                    else:
                        flat = flatten_dict(eval_result)
                    self.trainer.callback_metrics.update(flat)
            else:
                # with a scalar return, auto set it to "val_loss" for callbacks
                if isinstance(eval_results, torch.Tensor):
                    flat = {'val_loss': eval_results}
                else:
                    flat = flatten_dict(eval_results)
                self.trainer.callback_metrics.update(flat)

    def __run_eval_epoch_end(self, num_dataloaders, using_eval_result):
        model = self.trainer.get_model()

        # with a single dataloader don't pass an array
        outputs = self.outputs
        eval_results = outputs
        if num_dataloaders == 1:
            eval_results = outputs[0]

        user_reduced = False

        if self.testing:
            if self.trainer.is_overridden('test_epoch_end', model=model):
                if using_eval_result:
                    eval_results = self.__gather_epoch_end_eval_results(outputs)

                eval_results = model.test_epoch_end(eval_results)
                user_reduced = True

        else:
            if self.trainer.is_overridden('validation_epoch_end', model=model):
                if using_eval_result:
                    eval_results = self.__gather_epoch_end_eval_results(outputs)

                eval_results = model.validation_epoch_end(eval_results)
                user_reduced = True

        if using_eval_result and not user_reduced:
            eval_results = self.__auto_reduce_result_objs(outputs)

        if not isinstance(eval_results, list):
            eval_results = [eval_results]

        return eval_results

    def __gather_epoch_end_eval_results(self, outputs):
        eval_results = []
        for epoch_output in outputs:
            result = epoch_output[0].__class__.gather(epoch_output)
            if 'checkpoint_on' in result:
                result.checkpoint_on = result.checkpoint_on.mean()
            if 'early_stop_on' in result:
                result.early_stop_on = result.early_stop_on.mean()

            eval_results.append(result)

        # with 1 dataloader don't pass in a list
        if len(eval_results) == 1:
            eval_results = eval_results[0]
        return eval_results

    def __auto_reduce_result_objs(self, outputs):
        # outputs has a list of results per dataloader
        eval_results = []
        for dl_output in outputs:
            result = dl_output[0]
            result = result.__class__.reduce_on_epoch_end(dl_output)
            if 'checkpoint_on' in result:
                result.checkpoint_on = result.checkpoint_on.mean()
            if 'early_stop_on' in result:
                result.early_stop_on = result.early_stop_on.mean()
            eval_results.append(result)

        return eval_results

    def on_evaluation_batch_start(self, *args, **kwargs):
        if self.testing:
            self.trainer.call_hook('on_test_batch_start', *args, **kwargs)
        else:
            self.trainer.call_hook('on_validation_batch_start', *args, **kwargs)

    def on_evaluation_batch_end(self, *args, **kwargs):
        if self.testing:
            self.trainer.call_hook('on_test_batch_end', *args, **kwargs)
        else:
            self.trainer.call_hook('on_validation_batch_end', *args, **kwargs)

    def evaluation_batch_end_cleanup(self, output, batch_idx, dataloader_idx):
        # Add step predictions to prediction collection to write later
        if output is not None:
            do_write_predictions = isinstance(output, Result) and self.testing
            if do_write_predictions:
                self.predictions.add(output.pop('predictions', None))

        # track debug metrics
        self.trainer.dev_debugger.track_eval_loss_history(self.testing, batch_idx, dataloader_idx, output)

    def on_evaluation_epoch_end(self, eval_results, *args, **kwargs):
        # log epoch level metrics
        self.log_epoch_metrics(eval_results)

        # Write predictions to disk if they're available
        self.predictions.to_disk()

        # call the callback hook
        if self.testing:
            self.trainer.call_hook('on_test_epoch_end', *args, **kwargs)
        else:
            self.trainer.call_hook('on_validation_epoch_end', *args, **kwargs)

    def log_step_metrics(self, output, batch_idx):
        if self.trainer.running_sanity_check:
            return

        if isinstance(output, EvalResult):
            step_log_metrics = output.batch_log_metrics
            step_pbar_metrics = output.batch_pbar_metrics

            if len(step_log_metrics) > 0:
                # make the metrics appear as a different line in the same graph
                metrics_by_epoch = {}
                for k, v in step_log_metrics.items():
                    metrics_by_epoch[f'{k}/epoch_{self.trainer.current_epoch}'] = v

                self.trainer.log_metrics(metrics_by_epoch, {}, step=batch_idx)

            if len(step_pbar_metrics) > 0:
                self.trainer.add_progress_bar_metrics(step_pbar_metrics)
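# ---------------------------------------------------------------------------
# Minimal driver sketch (an assumption, not part of the original class): the
# order in which a trainer is expected to call the EvaluationLoop hooks above,
# mirroring what `_evaluate` does inline. `run_evaluation_sketch`, `trainer`
# and `model` are hypothetical stand-ins; only the EvaluationLoop method names
# come from the class itself, and the real trainer wiring lives elsewhere.
# ---------------------------------------------------------------------------
def run_evaluation_sketch(trainer, model, dataloaders, max_batches, test_mode=False):
    loop = EvaluationLoop(trainer)
    loop.testing = test_mode

    # bookkeeping: copy model properties, reset outputs/predictions, normalize max_batches
    loop.setup(model, max_batches, dataloaders)

    # on_test_epoch_start / on_validation_epoch_start
    loop.on_evaluation_epoch_start()

    for dataloader_idx, dataloader in enumerate(dataloaders):
        dl_outputs = []
        for batch_idx, batch in enumerate(dataloader):
            # stop short when running on limited batches
            if batch_idx >= loop.max_batches[dataloader_idx]:
                break

            loop.on_evaluation_batch_start(batch, batch_idx, dataloader_idx)

            # forward through the accelerator backend (test_step / validation_step),
            # then give dp/ddp2 users a chance to post-process the batch parts
            output = loop.evaluation_step(test_mode, batch, batch_idx, dataloader_idx)
            output = loop.evaluation_step_end(output)

            loop.on_evaluation_batch_end(batch, batch_idx, dataloader_idx)
            loop.evaluation_batch_end_cleanup(output, batch_idx, dataloader_idx)
            loop.log_step_metrics(output, batch_idx)

            # track outputs for collation
            if output is not None:
                dl_outputs.append(output)

        loop.outputs.append(dl_outputs)

    # reduce per-dataloader outputs, then log epoch metrics and write predictions
    eval_results = loop.evaluation_epoch_end(num_dataloaders=len(dataloaders))
    loop.on_evaluation_epoch_end(eval_results)
    return eval_results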