def call_hook(self, hook_name, *args, **kwargs):
    # set hook_name to model + reset Result obj
    skip = self._reset_result_and_set_hook_fx_name(hook_name)

    # always profile hooks
    with self.profiler.profile(hook_name):

        # first call trainer hook
        if hasattr(self, hook_name):
            trainer_hook = getattr(self, hook_name)
            trainer_hook(*args, **kwargs)

        # next call hook in lightningModule
        output = None
        model_ref = self.lightning_module
        if is_overridden(hook_name, model_ref):
            hook_fx = getattr(model_ref, hook_name)
            output = hook_fx(*args, **kwargs)

        # if the PL module doesn't have the hook then call the accelerator
        # used to auto-reduce things for the user with Results obj
        elif hasattr(self.accelerator, hook_name):
            accelerator_hook = getattr(self.accelerator, hook_name)
            output = accelerator_hook(*args, **kwargs)

    if not skip:
        self._cache_logged_metrics()
    return output

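# Every snippet in this collection gates behavior on `is_overridden`. A minimal
# sketch of that check (an illustrative simplification, not the library's actual
# implementation): a hook counts as overridden when the instance's class binds a
# different function object than the given parent class.
def is_overridden_sketch(method_name: str, instance: object, parent: type) -> bool:
    if instance is None:
        return False
    instance_attr = getattr(type(instance), method_name, None)
    parent_attr = getattr(parent, method_name, None)
    if instance_attr is None or parent_attr is None:
        return False
    # inherited methods resolve to the very same function object as on the parent
    return instance_attr is not parent_attr


# usage sketch
class Base:
    def hook(self): ...

class Child(Base):
    def hook(self): ...

assert is_overridden_sketch("hook", Child(), Base) is True
assert is_overridden_sketch("hook", Base(), Base) is False
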
def on_train_epoch_end(self, epoch_output: List[List[List[Result]]]) -> None:
    # inform logger the batch loop has finished
    self.trainer.logger_connector.on_train_epoch_end()

    # prepare epoch output
    processed_epoch_output = TrainLoop._prepare_outputs(epoch_output, batch_mode=False)

    # get the model and call model.training_epoch_end
    model = self.trainer.lightning_module

    if is_overridden('training_epoch_end', model=model):
        # run training_epoch_end
        # refresh the result for custom logging at the epoch level
        model._current_fx_name = 'training_epoch_end'
        training_epoch_end_output = model.training_epoch_end(processed_epoch_output)

        if training_epoch_end_output is not None:
            raise MisconfigurationException(
                'training_epoch_end expects a return of None. '
                'HINT: remove the return statement in training_epoch_end'
            )

        # capture logging
        self.trainer.logger_connector.cache_logged_metrics()

    # call train epoch end hooks
    self._on_train_epoch_end_hook(processed_epoch_output)
    self.trainer.call_hook('on_epoch_end')

def run_sanity_check(self, ref_model):
    using_val_step = ref_model.val_dataloader is not None and is_overridden('validation_step', ref_model)
    should_sanity_check = using_val_step and self.num_sanity_val_steps > 0 and self.limit_val_batches > 0

    # run tiny validation (if validation defined)
    # to make sure program won't crash during val
    if should_sanity_check:
        stage = self._running_stage
        self.sanity_checking = True

        # hook and callback
        self.on_sanity_check_start()

        # run eval step
        _, eval_results = self.run_evaluation()

        # allow no returns from eval
        if eval_results is not None and len(eval_results) > 0:
            # when we get a list back, use only the last item
            if isinstance(eval_results, list):
                eval_results = eval_results[-1]

            _, _, _, callback_metrics, _ = self.process_dict_result(eval_results)
            self.logger_connector.callback_metrics = callback_metrics

        self.on_sanity_check_end()
        self._running_stage = stage

def attach_datamodule(
    self, model: "pl.LightningModule", datamodule: Optional["pl.LightningDataModule"] = None
) -> None:
    # If we have a datamodule, attach necessary hooks + dataloaders
    if datamodule is None:
        return

    self._train_dataloader_source = _DataLoaderSource(datamodule, "train_dataloader")
    self._val_dataloader_source = _DataLoaderSource(datamodule, "val_dataloader")
    self._test_dataloader_source = _DataLoaderSource(datamodule, "test_dataloader")
    self._predict_dataloader_source = _DataLoaderSource(datamodule, "predict_dataloader")

    # Override data transfer hooks if dataset-specific to_device logic has been defined in datamodule
    batch_transfer_hooks = ("on_before_batch_transfer", "transfer_batch_to_device", "on_after_batch_transfer")
    for hook in batch_transfer_hooks:
        if is_overridden(hook, datamodule):
            setattr(model, hook, getattr(datamodule, hook))

    self.trainer.datamodule = datamodule
    datamodule.trainer = self.trainer

    # experimental feature for Flash
    if hasattr(datamodule, "data_pipeline"):
        model.data_pipeline = datamodule.data_pipeline

def _collect_rank_zero_results(self, trainer: "pl.Trainer", results: Any) -> Optional["_SpawnOutput"]:
    rank_zero_debug("Finalizing the TPU spawn environment.")
    checkpoint_callback = trainer.checkpoint_callback
    best_model_path = checkpoint_callback.best_model_path if checkpoint_callback else None

    # requires to compute the state_dict on all processes in case Metrics are present
    state_dict = self.lightning_module.state_dict()

    # save the last weights
    weights_path = None
    if trainer.state.fn == TrainerFn.FITTING:
        weights_path = os.path.join(trainer.default_root_dir, ".temp.ckpt")
        self.checkpoint_io.save_checkpoint(state_dict, weights_path)

    # We use `local_rank` here as separate filesystems are used for each VM for TPU Pod Training
    if self.local_rank != 0:
        return None

    # adds the `callback_metrics` to the queue
    extra = _FakeQueue()
    if is_overridden("add_to_queue", self.lightning_module):
        # TODO: Remove the if in v1.7
        self.lightning_module.add_to_queue(extra)
    self.add_to_queue(trainer, extra)

    return _SpawnOutput(best_model_path, weights_path, trainer.state, results, extra)

def init_deepspeed(self):
    # deepspeed handles gradient clipping internally
    if is_overridden("configure_gradient_clipping", self.lightning_module, pl.LightningModule):
        rank_zero_warn(
            "Since DeepSpeed handles gradient clipping internally, the default"
            " `LightningModule.configure_gradient_clipping` implementation will not actually clip gradients."
            " The hook will still be called. Consider setting"
            " `Trainer(gradient_clip_val=..., gradient_clip_algorithm='norm')`"
            " which will use the internal mechanism."
        )

    if self.lightning_module.trainer.gradient_clip_algorithm == GradClipAlgorithmType.VALUE:
        raise MisconfigurationException("DeepSpeed does not support clipping gradients by value.")

    accumulation_scheduler = self.lightning_module.trainer.accumulation_scheduler

    if accumulation_scheduler.epochs != [0]:
        raise MisconfigurationException(
            "DeepSpeed currently does not support different `accumulate_grad_batches` at different epochs."
        )

    model = LightningDeepSpeedModule(pl_module=self.model, precision=self.precision_plugin.precision)

    if self.lightning_module.trainer and self.lightning_module.trainer.training:
        self._initialize_deepspeed_train(model)
    else:
        self._initialize_deepspeed_inference(model)

def test_dm_apply_batch_transfer_handler(get_module_mock):
    expected_device = torch.device('cuda', 0)

    class CustomBatch:
        def __init__(self, data):
            self.samples = data[0]
            self.targets = data[1]

    class CurrentTestDM(LightningDataModule):
        rank = 0
        transfer_batch_to_device_hook_rank = None
        on_before_batch_transfer_hook_rank = None
        on_after_batch_transfer_hook_rank = None

        def on_before_batch_transfer(self, batch, dataloader_idx):
            self.on_before_batch_transfer_hook_rank = self.rank
            self.rank += 1
            batch.samples += 1
            return batch

        def on_after_batch_transfer(self, batch, dataloader_idx):
            assert batch.samples.device == batch.targets.device == expected_device
            self.on_after_batch_transfer_hook_rank = self.rank
            self.rank += 1
            batch.targets *= 2
            return batch

        def transfer_batch_to_device(self, batch, device):
            self.transfer_batch_to_device_hook_rank = self.rank
            self.rank += 1
            batch.samples = batch.samples.to(device)
            batch.targets = batch.targets.to(device)
            return batch

    dm = CurrentTestDM()
    model = BoringModel()

    batch = CustomBatch((torch.zeros(5, 32), torch.ones(5, 1, dtype=torch.long)))

    trainer = Trainer(gpus=1)
    # running .fit() would require us to implement custom data loaders, we mock the model reference instead
    get_module_mock.return_value = model
    if is_overridden('transfer_batch_to_device', dm):
        model.transfer_batch_to_device = dm.transfer_batch_to_device

    model.on_before_batch_transfer = dm.on_before_batch_transfer
    model.transfer_batch_to_device = dm.transfer_batch_to_device
    model.on_after_batch_transfer = dm.on_after_batch_transfer

    batch_gpu = trainer.accelerator.batch_to_device(batch, expected_device)

    assert dm.on_before_batch_transfer_hook_rank == 0
    assert dm.transfer_batch_to_device_hook_rank == 1
    assert dm.on_after_batch_transfer_hook_rank == 2
    assert batch_gpu.samples.device == batch_gpu.targets.device == expected_device
    assert torch.allclose(batch_gpu.samples.cpu(), torch.ones(5, 32))
    assert torch.allclose(batch_gpu.targets.cpu(), torch.ones(5, 1, dtype=torch.long) * 2)

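# The rank assertions in the test above pin down the fixed hook order. A hedged
# sketch of the same pipeline outside Lightning (`move_to_device` is an
# illustrative name, not library API):
def move_to_device(dm, batch, device, dataloader_idx=0):
    batch = dm.on_before_batch_transfer(batch, dataloader_idx)  # runs on CPU data
    batch = dm.transfer_batch_to_device(batch, device)          # performs the .to(device)
    batch = dm.on_after_batch_transfer(batch, dataloader_idx)   # runs on device data
    return batch
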
def _add_configure_optimizers_method_to_model(self, subcommand: Optional[str]) -> None:
    """Overrides the model's :meth:`~pytorch_lightning.core.lightning.LightningModule.configure_optimizers` method
    if a single optimizer and optionally a scheduler argument groups are added to the parser as 'AUTOMATIC'."""
    parser = self._parser(subcommand)

    def get_automatic(
        class_type: Union[Type, Tuple[Type, ...]], register: Dict[str, Tuple[Union[Type, Tuple[Type, ...]], str]]
    ) -> List[str]:
        automatic = []
        for key, (base_class, link_to) in register.items():
            if not isinstance(base_class, tuple):
                base_class = (base_class,)
            if link_to == "AUTOMATIC" and any(issubclass(c, class_type) for c in base_class):
                automatic.append(key)
        return automatic

    optimizers = get_automatic(Optimizer, parser._optimizers)
    lr_schedulers = get_automatic(LRSchedulerTypeTuple, parser._lr_schedulers)

    if len(optimizers) == 0:
        return

    if len(optimizers) > 1 or len(lr_schedulers) > 1:
        raise MisconfigurationException(
            f"`{self.__class__.__name__}.add_configure_optimizers_method_to_model` expects at most one optimizer "
            f"and one lr_scheduler to be 'AUTOMATIC', but found {optimizers+lr_schedulers}. In this case the user "
            "is expected to link the argument groups and implement `configure_optimizers`, see "
            "https://pytorch-lightning.readthedocs.io/en/stable/common/lightning_cli.html"
            "#optimizers-and-learning-rate-schedulers"
        )

    optimizer_class = parser._optimizers[optimizers[0]][0]
    optimizer_init = self._get(self.config_init, optimizers[0])
    if not isinstance(optimizer_class, tuple):
        optimizer_init = _global_add_class_path(optimizer_class, optimizer_init)
    if not optimizer_init:
        # optimizers were registered automatically but not passed by the user
        return

    lr_scheduler_init = None
    if lr_schedulers:
        lr_scheduler_class = parser._lr_schedulers[lr_schedulers[0]][0]
        lr_scheduler_init = self._get(self.config_init, lr_schedulers[0])
        if not isinstance(lr_scheduler_class, tuple):
            lr_scheduler_init = _global_add_class_path(lr_scheduler_class, lr_scheduler_init)

    if is_overridden("configure_optimizers", self.model):
        _warn(
            f"`{self.model.__class__.__name__}.configure_optimizers` will be overridden by "
            f"`{self.__class__.__name__}.configure_optimizers`."
        )

    optimizer = instantiate_class(self.model.parameters(), optimizer_init)
    lr_scheduler = instantiate_class(optimizer, lr_scheduler_init) if lr_scheduler_init else None
    fn = partial(self.configure_optimizers, optimizer=optimizer, lr_scheduler=lr_scheduler)
    update_wrapper(fn, self.configure_optimizers)  # necessary for `is_overridden`
    # override the existing method
    self.model.configure_optimizers = MethodType(fn, self.model)

def add_configure_optimizers_method_to_model(self) -> None:
    """
    Adds to the model an automatically generated ``configure_optimizers`` method.

    If a single optimizer and optionally a scheduler argument groups are added to the parser as 'AUTOMATIC',
    then a `configure_optimizers` method is automatically implemented in the model class.
    """

    def get_automatic(class_type: Union[Type, Tuple[Type, ...]]) -> List[str]:
        automatic = []
        for key, (base_class, link_to) in self.parser.optimizers_and_lr_schedulers.items():
            if not isinstance(base_class, tuple):
                base_class = (base_class, )
            if link_to == 'AUTOMATIC' and any(issubclass(c, class_type) for c in base_class):
                automatic.append(key)
        return automatic

    optimizers = get_automatic(Optimizer)
    lr_schedulers = get_automatic(LRSchedulerTypeTuple)

    if len(optimizers) == 0:
        return

    if len(optimizers) > 1 or len(lr_schedulers) > 1:
        raise MisconfigurationException(
            f"`{self.__class__.__name__}.add_configure_optimizers_method_to_model` expects at most one optimizer "
            f"and one lr_scheduler to be 'AUTOMATIC', but found {optimizers+lr_schedulers}. In this case the user "
            "is expected to link the argument groups and implement `configure_optimizers`, see "
            "https://pytorch-lightning.readthedocs.io/en/stable/common/lightning_cli.html"
            "#optimizers-and-learning-rate-schedulers"
        )

    if is_overridden('configure_optimizers', self.model):
        warnings.warn(
            f"`{self.model.__class__.__name__}.configure_optimizers` will be overridden by "
            f"`{self.__class__.__name__}.add_configure_optimizers_method_to_model`."
        )

    optimizer_class = self.parser.optimizers_and_lr_schedulers[optimizers[0]][0]
    optimizer_init = self.config_init.get(optimizers[0], {})
    if not isinstance(optimizer_class, tuple):
        optimizer_init = _global_add_class_path(optimizer_class, optimizer_init)

    lr_scheduler_init = None
    if lr_schedulers:
        lr_scheduler_class = self.parser.optimizers_and_lr_schedulers[lr_schedulers[0]][0]
        lr_scheduler_init = self.config_init.get(lr_schedulers[0], {})
        if not isinstance(lr_scheduler_class, tuple):
            lr_scheduler_init = _global_add_class_path(lr_scheduler_class, lr_scheduler_init)

    def configure_optimizers(
        self: LightningModule
    ) -> Union[Optimizer, Tuple[List[Optimizer], List[LRSchedulerType]]]:
        optimizer = instantiate_class(self.parameters(), optimizer_init)
        if not lr_scheduler_init:
            return optimizer
        lr_scheduler = instantiate_class(optimizer, lr_scheduler_init)
        return [optimizer], [lr_scheduler]

    self.model.configure_optimizers = MethodType(configure_optimizers, self.model)

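# Hedged sketch of the init dicts consumed by `instantiate_class` in the two
# functions above: after `_global_add_class_path`, each carries a class path plus
# init kwargs, and `instantiate_class` prepends the positional argument (the
# model parameters, or the optimizer for a scheduler). The exact keys are an
# assumption based on the jsonargparse-style conventions these CLI helpers use.
optimizer_init_example = {
    "class_path": "torch.optim.Adam",
    "init_args": {"lr": 1e-3},
}
lr_scheduler_init_example = {
    "class_path": "torch.optim.lr_scheduler.StepLR",
    "init_args": {"step_size": 10},
}
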
def _should_add_batch_output_to_epoch_output(self) -> bool:
    """
    We add to the epoch outputs if
    1. The model defines training_epoch_end, OR
    2. The model overrides on_train_epoch_end which has `outputs` in the signature
    """
    # TODO: in v1.5 this only needs to check if training_epoch_end is overridden
    lightning_module = self.trainer.lightning_module
    if is_overridden("training_epoch_end", lightning_module):
        return True

    if is_overridden("on_train_epoch_end", lightning_module):
        model_hook_fx = getattr(lightning_module, "on_train_epoch_end")
        if is_param_in_hook_signature(model_hook_fx, "outputs"):
            return True

    return False

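# A hypothetical stand-in for `is_param_in_hook_signature`, assuming it boils
# down to ordinary signature inspection of the bound hook:
import inspect

def param_in_signature(fn, name: str) -> bool:
    return name in inspect.signature(fn).parameters

# e.g. `param_in_signature(module.on_train_epoch_end, "outputs")`
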
def _disable_zero_grad(self) -> None:
    lightning_module = self.lightning_module
    if is_overridden("optimizer_zero_grad", lightning_module):
        assert lightning_module is not None  # `is_overridden` returns False otherwise
        rank_zero_warn(
            "You have overridden the `LightningModule.optimizer_zero_grad` hook but it will be ignored since"
            " IPUs handle the zeroing of gradients internally."
        )
        lightning_module.optimizer_zero_grad = None  # type: ignore[assignment]

def _check_on_pretrain_routine(model: "pl.LightningModule") -> None:
    hooks = (("on_pretrain_routine_start", "on_fit_start"), ("on_pretrain_routine_end", "on_fit_start"))
    for hook, alternative_hook in hooks:
        if is_overridden(hook, model):
            rank_zero_deprecation(
                f"The `LightningModule.{hook}` hook was deprecated in v1.6 and"
                f" will be removed in v1.8. Please use `LightningModule.{alternative_hook}` instead."
            )

def __verify_eval_loop_configuration(self, model):
    stage = "val" if self.trainer.validating else "test"

    loader_name = f'{stage}_dataloader'
    step_name = f'{stage}_step'

    has_loader = is_overridden(loader_name, model)
    has_step = is_overridden(step_name, model)

    if has_loader and not has_step:
        rank_zero_warn(f'You passed in a {loader_name} but have no {step_name}. Skipping {stage} loop.')
    if has_step and not has_loader:
        rank_zero_warn(f'You defined a {step_name} but have no {loader_name}. Skipping {stage} loop.')

def backward(self, model: "pl.LightningModule", closure_loss: Tensor, *args: Any, **kwargs: Any) -> None:
    if is_overridden("backward", model):
        warning_cache.warn(
            "You have overridden the `LightningModule.backward` hook but it will be ignored since DeepSpeed handles"
            " the backward logic internally."
        )
    deepspeed_engine: DeepSpeedEngine = model.trainer.model
    deepspeed_engine.backward(closure_loss, *args, **kwargs)

def _check_on_keyboard_interrupt(trainer: "pl.Trainer") -> None:
    """Checks if `on_keyboard_interrupt` is overridden and sends a deprecation warning."""
    for callback in trainer.callbacks:
        if is_overridden(method_name="on_keyboard_interrupt", instance=callback):
            rank_zero_deprecation(
                "The `on_keyboard_interrupt` callback hook was deprecated in v1.5 and will be removed in v1.7."
                " Please use the `on_exception` callback hook instead."
            )

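# Hedged example of a callback that would trigger the deprecation warning above
# (`LegacyCallback` is illustrative, not from the library):
from pytorch_lightning import Callback

class LegacyCallback(Callback):
    def on_keyboard_interrupt(self, trainer, pl_module):
        print("training was interrupted")
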
def _check_setup_method(trainer: "pl.Trainer") -> None:
    for obj in [trainer.lightning_module, trainer.datamodule] + trainer.callbacks:
        if is_overridden("setup", obj) and not is_param_in_hook_signature(obj.setup, "stage"):
            raise MisconfigurationException(f"`{obj.__class__.__name__}.setup` does not have a `stage` argument.")

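# A `setup` hook that satisfies the check above simply accepts a `stage`
# parameter (hedged sketch; `SetupCallback` is an illustrative name):
from pytorch_lightning import Callback

class SetupCallback(Callback):
    def setup(self, trainer, pl_module, stage=None):
        print(f"setting up for stage={stage}")
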
def can_prepare_data(self):
    should_call_dm_prepare_data = True
    if self.trainer.datamodule is not None and is_overridden('prepare_data', self.trainer.datamodule):
        should_call_dm_prepare_data = not self.trainer.datamodule.has_prepared_data

    if self.trainer.prepare_data_per_node:
        return self.trainer.local_rank == 0 and should_call_dm_prepare_data
    return self.trainer.node_rank == 0 and self.trainer.local_rank == 0 and should_call_dm_prepare_data

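# The rank gating above, restated as a pure function (hypothetical helper,
# useful mainly to unit-test the logic in isolation):
def should_call_prepare_data(per_node: bool, node_rank: int, local_rank: int,
                             dm_needs_prepare: bool = True) -> bool:
    if per_node:
        return local_rank == 0 and dm_needs_prepare
    return node_rank == 0 and local_rank == 0 and dm_needs_prepare

# per-node preparation: local rank 0 on every node prepares
assert should_call_prepare_data(True, node_rank=1, local_rank=0)
# global preparation: only the first rank of the first node prepares
assert not should_call_prepare_data(False, node_rank=1, local_rank=0)
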
# Note: this is a nested, `contextlib.contextmanager`-style generator in the
# original source; `transformation`, `schema_fields`, `batch_size`,
# `dataloader_cls` and `calculate_shuffle_buffer_size` are expected to come
# from the enclosing scope.
def make_petastorm_reader(model, data_path, dataloader_attr, reader_worker_count, reader_pool_type,
                          should_read=True):
    from petastorm import TransformSpec, make_reader, make_batch_reader
    import horovod.torch as hvd

    is_loader_overridden = False
    if LooseVersion(pl.__version__) >= LooseVersion('1.0.0'):
        from pytorch_lightning.utilities.model_helpers import is_overridden
        is_loader_overridden = is_overridden(dataloader_attr, model)

    if not should_read or is_loader_overridden:
        yield
        return

    transform_spec = TransformSpec(transformation) if transformation else None

    # In general, make_batch_reader is faster than make_reader for reading the dataset.
    # However, we found out that make_reader performs data transformations much faster than
    # make_batch_reader with parallel worker processes. Therefore, the default reader
    # we choose is make_batch_reader unless there are data transformations.
    reader_factory_kwargs = dict()
    if transform_spec:
        reader_factory = make_reader
        reader_factory_kwargs['pyarrow_serialize'] = True
    else:
        reader_factory = make_batch_reader

    # Petastorm: read data from the store with the correct shard for this rank.
    # num_epochs=1 gives each rank a single pass per call; with num_epochs=None the
    # reader would instead yield an infinite iterator.
    with reader_factory(data_path,
                        num_epochs=1,
                        cur_shard=hvd.rank(),
                        shard_count=hvd.size(),
                        reader_pool_type=reader_pool_type,
                        workers_count=reader_worker_count,
                        hdfs_driver=PETASTORM_HDFS_DRIVER,
                        schema_fields=schema_fields,
                        transform_spec=transform_spec,
                        **reader_factory_kwargs) as reader:

        def dataloader_fn():
            return dataloader_cls(reader, batch_size=batch_size,
                                  shuffling_queue_capacity=calculate_shuffle_buffer_size())

        try:
            setattr(model, dataloader_attr, dataloader_fn)
            yield
        finally:
            setattr(model, dataloader_attr, None)

def __verify_eval_loop_configuration(trainer: "pl.Trainer", model: "pl.LightningModule", stage: str) -> None:
    loader_name = f"{stage}_dataloader"
    step_name = "validation_step" if stage == "val" else f"{stage}_step"
    trainer_method = "validate" if stage == "val" else stage
    on_eval_hook = f"on_{loader_name}"

    has_loader = getattr(trainer._data_connector, f"_{stage}_dataloader_source").is_defined()
    has_step = is_overridden(step_name, model)
    has_on_eval_dataloader = is_overridden(on_eval_hook, model)

    # ----------------------------------------------
    # verify model does not have on_eval_dataloader
    # ----------------------------------------------
    if has_on_eval_dataloader:
        rank_zero_deprecation(
            f"Method `{on_eval_hook}` is deprecated in v1.5.0 and will"
            f" be removed in v1.7.0. Please use `{loader_name}()` directly."
        )

    # -----------------------------------
    # verify model has an eval_dataloader
    # -----------------------------------
    if not has_loader:
        raise MisconfigurationException(f"No `{loader_name}()` method defined to run `Trainer.{trainer_method}`.")

    # predict_step is not required to be overridden
    if stage == "predict":
        if model.predict_step is None:
            raise MisconfigurationException("`predict_step` cannot be None to run `Trainer.predict`")
        elif not has_step and not is_overridden("forward", model):
            raise MisconfigurationException("`Trainer.predict` requires `forward` method to run.")
    else:
        # -----------------------------------
        # verify model has an eval_step
        # -----------------------------------
        if not has_step:
            raise MisconfigurationException(f"No `{step_name}()` method defined to run `Trainer.{trainer_method}`.")

def __verify_train_loop_configuration(self, model: "pl.LightningModule") -> None: # ----------------------------------- # verify model has a training step # ----------------------------------- has_training_step = is_overridden("training_step", model) if not has_training_step: raise MisconfigurationException( "No `training_step()` method defined. Lightning `Trainer` expects as minimum a" " `training_step()`, `train_dataloader()` and `configure_optimizers()` to be defined." ) # ----------------------------------- # verify model has a train dataloader # ----------------------------------- has_train_dataloader = is_overridden("train_dataloader", model) if not has_train_dataloader: raise MisconfigurationException( "No `train_dataloader()` method defined. Lightning `Trainer` expects as minimum a" " `training_step()`, `train_dataloader()` and `configure_optimizers()` to be defined." ) # ----------------------------------- # verify model has optimizer # ----------------------------------- has_optimizers = is_overridden("configure_optimizers", model) if not has_optimizers: raise MisconfigurationException( "No `configure_optimizers()` method defined. Lightning `Trainer` expects as minimum a" " `training_step()`, `train_dataloader()` and `configure_optimizers()` to be defined." ) trainer = self.trainer trainer.overriden_optimizer_step = is_overridden("optimizer_step", model) trainer.overriden_optimizer_zero_grad = is_overridden("optimizer_zero_grad", model) automatic_optimization = model.automatic_optimization going_to_accumulate_grad_batches = trainer.accumulation_scheduler.going_to_accumulate_grad_batches() has_overriden_optimization_functions = trainer.overriden_optimizer_step or trainer.overriden_optimizer_zero_grad if has_overriden_optimization_functions and going_to_accumulate_grad_batches and automatic_optimization: rank_zero_warn( "When using `Trainer(accumulate_grad_batches != 1)` and overriding" "`LightningModule.optimizer_{step,zero_grad}`, the hooks will not be called on every batch" "(rather, they are called on every optimization step)." )
def attach_datamodule(self, model, datamodule: Optional[LightningDataModule]) -> None:
    # We use datamodule if it's been provided on .fit or .test, otherwise we check model for it
    datamodule = datamodule or getattr(model, 'datamodule', None)

    # If we have a datamodule, attach necessary hooks + dataloaders
    if datamodule:

        # Override loader hooks
        if is_overridden('train_dataloader', datamodule):
            model.train_dataloader = datamodule.train_dataloader
        if is_overridden('val_dataloader', datamodule):
            model.val_dataloader = datamodule.val_dataloader
        if is_overridden('test_dataloader', datamodule):
            model.test_dataloader = datamodule.test_dataloader
        if is_overridden('predict_dataloader', datamodule):
            model.predict_dataloader = datamodule.predict_dataloader

        # Override data transfer hooks if dataset-specific to_device logic has been defined in datamodule
        if is_overridden('on_before_batch_transfer', datamodule):
            model.on_before_batch_transfer = datamodule.on_before_batch_transfer
        if is_overridden('transfer_batch_to_device', datamodule):
            model.transfer_batch_to_device = datamodule.transfer_batch_to_device
        if is_overridden('on_after_batch_transfer', datamodule):
            model.on_after_batch_transfer = datamodule.on_after_batch_transfer

        self.trainer.datamodule = datamodule
        datamodule.trainer = self.trainer

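# Hedged usage sketch for the function above: a datamodule that overrides only
# `train_dataloader`, so `attach_datamodule` patches that single hook onto the
# model and leaves every other hook untouched:
import torch
from torch.utils.data import DataLoader, TensorDataset

class TrainOnlyDM(LightningDataModule):
    def train_dataloader(self):
        return DataLoader(TensorDataset(torch.randn(8, 2)), batch_size=4)
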
def __verify_train_loop_configuration(self, model: 'pl.LightningModule') -> None:
    # -----------------------------------
    # verify model has a training step
    # -----------------------------------
    has_training_step = is_overridden('training_step', model)
    if not has_training_step:
        raise MisconfigurationException(
            'No `training_step()` method defined. Lightning `Trainer` expects as minimum a'
            ' `training_step()`, `train_dataloader()` and `configure_optimizers()` to be defined.'
        )

    # -----------------------------------
    # verify model has a train dataloader
    # -----------------------------------
    has_train_dataloader = is_overridden('train_dataloader', model)
    if not has_train_dataloader:
        raise MisconfigurationException(
            'No `train_dataloader()` method defined. Lightning `Trainer` expects as minimum a'
            ' `training_step()`, `train_dataloader()` and `configure_optimizers()` to be defined.'
        )

    # -----------------------------------
    # verify model has optimizer
    # -----------------------------------
    has_optimizers = is_overridden('configure_optimizers', model)
    if not has_optimizers:
        raise MisconfigurationException(
            'No `configure_optimizers()` method defined. Lightning `Trainer` expects as minimum a'
            ' `training_step()`, `train_dataloader()` and `configure_optimizers()` to be defined.'
        )

    trainer = self.trainer

    trainer.overriden_optimizer_step = is_overridden('optimizer_step', model)
    trainer.overriden_optimizer_zero_grad = is_overridden('optimizer_zero_grad', model)
    automatic_optimization = model.automatic_optimization
    going_to_accumulate_grad_batches = trainer.accumulation_scheduler.going_to_accumulate_grad_batches()

    has_overriden_optimization_functions = trainer.overriden_optimizer_step or trainer.overriden_optimizer_zero_grad
    if has_overriden_optimization_functions and going_to_accumulate_grad_batches and automatic_optimization:
        raise MisconfigurationException(
            'When overriding `LightningModule` optimizer_step or optimizer_zero_grad,'
            ' `accumulate_grad_batches` in `Trainer` should be 1.'
            ' It ensures optimizer_step or optimizer_zero_grad are called on every batch.'
        )

def reset_predict_dataloader(self, model) -> None:
    """Resets the predict dataloader and determines the number of batches.

    Args:
        model: The current `LightningModule`
    """
    has_loader = is_overridden('predict_dataloader', model)
    if has_loader:
        self.num_predict_batches, self.predict_dataloaders = self._reset_eval_dataloader(model, 'predict')

def __run_eval_epoch_end(self, num_dataloaders):
    model = self.trainer.lightning_module

    # with a single dataloader don't pass an array
    outputs = self.outputs

    # free memory
    self.outputs = []

    eval_results = outputs
    if num_dataloaders == 1:
        eval_results = outputs[0]

    user_reduced = False

    if self.trainer.testing:
        if is_overridden('test_epoch_end', model=model):
            model._current_fx_name = 'test_epoch_end'
            eval_results = model.test_epoch_end(eval_results)
            user_reduced = True

    else:
        if is_overridden('validation_epoch_end', model=model):
            model._current_fx_name = 'validation_epoch_end'
            eval_results = model.validation_epoch_end(eval_results)
            user_reduced = True

    # capture logging
    self.trainer.logger_connector.cache_logged_metrics()

    # deprecation warning
    if eval_results is not None and user_reduced:
        step = 'testing_epoch_end' if self.trainer.testing else 'validation_epoch_end'
        self.warning_cache.warn(
            f'The {step} should not return anything as of 9.1.'
            ' To log, use self.log(...) or self.write(...) directly in the LightningModule'
        )

    if not isinstance(eval_results, list):
        eval_results = [eval_results]

    # track deprecated metrics
    self.trainer.logger_connector.track_metrics_deprecated(eval_results)

    return eval_results

def evaluation_epoch_end(self, outputs: EPOCH_OUTPUT) -> None:
    # unset dataloader_idx in model
    self.trainer.logger_connector.evaluation_epoch_end()

    # call the model epoch end
    model = self.trainer.lightning_module

    if self.trainer.testing:
        if is_overridden('test_epoch_end', model=model):
            model._current_fx_name = 'test_epoch_end'
            model.test_epoch_end(outputs)
    else:
        if is_overridden('validation_epoch_end', model=model):
            model._current_fx_name = 'validation_epoch_end'
            model.validation_epoch_end(outputs)

    # capture logging
    self.trainer.logger_connector.cache_logged_metrics()

def __verify_eval_loop_configuration(self, model, eval_loop_name):
    step_name = f'{eval_loop_name}_step'

    # map the dataloader name
    loader_name = f'{eval_loop_name}_dataloader'
    if eval_loop_name == 'validation':
        loader_name = 'val_dataloader'

    has_loader = is_overridden(loader_name, model)
    has_step = is_overridden(step_name, model)

    if has_loader and not has_step:
        rank_zero_warn(f'You passed in a {loader_name} but have no {step_name}. Skipping {eval_loop_name} loop.')
    if has_step and not has_loader:
        rank_zero_warn(f'You defined a {step_name} but have no {loader_name}. Skipping {eval_loop_name} loop.')

def reset_predict_dataloader(self, model: Optional["pl.LightningModule"] = None) -> None:
    """Resets the predict dataloader and determines the number of batches.

    Args:
        model: The `LightningModule` if called outside of the trainer scope.
    """
    pl_module = self.lightning_module or model
    has_loader = is_overridden("predict_dataloader", pl_module)
    if has_loader:
        self.num_predict_batches, self.predict_dataloaders = self._reset_eval_dataloader("predict", model=pl_module)

def evaluation_epoch_end(self, outputs: EPOCH_OUTPUT) -> None:
    # inform logger the batch loop has finished
    self.trainer.logger_connector.epoch_end_reached()

    # call the model epoch end
    model = self.trainer.lightning_module

    # unset dataloader_idx in model
    model._current_dataloader_idx = None

    if self.trainer.testing:
        if is_overridden('test_epoch_end', model):
            model._current_fx_name = 'test_epoch_end'
            model.test_epoch_end(outputs)
    else:
        if is_overridden('validation_epoch_end', model):
            model._current_fx_name = 'validation_epoch_end'
            model.validation_epoch_end(outputs)

def _disable_zero_grad(self) -> None:
    lightning_module = self.lightning_module
    if is_overridden("optimizer_zero_grad", lightning_module):
        assert lightning_module is not None  # `is_overridden` returns False otherwise
        rank_zero_warn(
            "You have overridden `optimizer_zero_grad` which will be disabled."
            " When `HivemindStrategy(reuse_grad_buffers=True)`, the optimizer cannot call zero grad,"
            " as this would delete the gradients before they are averaged."
        )
        lightning_module.optimizer_zero_grad = None  # type: ignore[assignment]

def _validate_data_hooks(self, model):
    # Raise a MisconfigurationException since these hooks are not supported in DP mode
    # TODO: Remove this blocker once batch transfer to device is integrated in Lightning for DP mode.
    batch_transfer_hooks = ('on_before_batch_transfer', 'transfer_batch_to_device', 'on_after_batch_transfer')
    for hook in batch_transfer_hooks:
        if self.trainer.accelerator_connector.use_dp and is_overridden(hook, model):
            raise MisconfigurationException(f'Overriding `{hook}` is not supported in DP mode.')

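# Hedged example of a module that would trip the DP-mode check above
# (`DPUnsafeModel` is illustrative; `BoringModel` is the test helper used
# elsewhere in this collection):
class DPUnsafeModel(BoringModel):
    def transfer_batch_to_device(self, batch, device):
        # custom per-batch device logic is exactly what DP mode cannot support
        return batch.to(device)
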