def test_v1_5_0_old_callback_on_save_checkpoint(tmpdir):
    class OldSignature(Callback):
        def on_save_checkpoint(self, trainer, pl_module):  # noqa
            ...

    model = BoringModel()
    trainer_kwargs = {
        "default_root_dir": tmpdir,
        "checkpoint_callback": False,
        "max_epochs": 1,
    }
    filepath = tmpdir / "test.ckpt"
    trainer = Trainer(**trainer_kwargs, callbacks=[OldSignature()])
    trainer.fit(model)

    with pytest.deprecated_call(match="old signature will be removed in v1.5"):
        trainer.save_checkpoint(filepath)

    class NewSignature(Callback):
        def on_save_checkpoint(self, trainer, pl_module, checkpoint):
            ...

    class ValidSignature1(Callback):
        def on_save_checkpoint(self, trainer, *args):
            ...

    class ValidSignature2(Callback):
        def on_save_checkpoint(self, *args):
            ...

    trainer.callbacks = [NewSignature(), ValidSignature1(), ValidSignature2()]
    with no_warning_call(DeprecationWarning):
        trainer.save_checkpoint(filepath)
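# The `no_warning_call` context manager used above is a test helper rather than a
# pytest built-in. A minimal sketch of such a helper is shown below (an assumption;
# the project's actual helper may differ in name, signature, and details):
from contextlib import contextmanager
import warnings


@contextmanager
def no_warning_call(expected_warning=Warning, match=None):
    """Fail if the wrapped block emits a warning of ``expected_warning`` (optionally matching ``match``)."""
    with warnings.catch_warnings(record=True) as record:
        warnings.simplefilter("always")
        yield
    for w in record:
        if issubclass(w.category, expected_warning) and (match is None or match in str(w.message)):
            raise AssertionError(f"Unexpected warning raised: {w.message}")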
def test_v1_8_0_callback_on_save_checkpoint_hook(tmpdir):
    class TestCallbackSaveHookReturn(Callback):
        def on_save_checkpoint(self, trainer, pl_module, checkpoint):
            return {"returning": "on_save_checkpoint"}

    class TestCallbackSaveHookOverride(Callback):
        def on_save_checkpoint(self, trainer, pl_module, checkpoint):
            print("overriding without returning")

    model = BoringModel()
    trainer = Trainer(
        callbacks=[TestCallbackSaveHookReturn()],
        max_epochs=1,
        fast_dev_run=True,
        enable_progress_bar=False,
        logger=False,
        default_root_dir=tmpdir,
    )
    trainer.fit(model)
    with pytest.deprecated_call(
        match="Returning a value from `TestCallbackSaveHookReturn.on_save_checkpoint` is deprecated in v1.6"
        " and will be removed in v1.8. Please override `Callback.state_dict`"
        " to return state to be saved."
    ):
        trainer.save_checkpoint(tmpdir + "/path.ckpt")

    trainer.callbacks = [TestCallbackSaveHookOverride()]
    trainer.save_checkpoint(tmpdir + "/pathok.ckpt")
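# The deprecation message above points to `Callback.state_dict` as the replacement
# for returning a value from `on_save_checkpoint`. A minimal sketch of that pattern,
# assuming PyTorch Lightning >= 1.6 (the callback name and its `count` field are
# illustrative only):
from typing import Any, Dict

from pytorch_lightning import Callback


class CounterCallback(Callback):
    def __init__(self):
        self.count = 0

    def on_train_epoch_end(self, trainer, pl_module):
        self.count += 1

    def state_dict(self) -> Dict[str, Any]:
        # Returned state is stored in the checkpoint under this callback's key,
        # instead of being returned from `on_save_checkpoint`.
        return {"count": self.count}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        self.count = state_dict["count"]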
def test_v1_5_0_old_on_validation_epoch_end(tmpdir):
    callback_warning_cache.clear()

    class OldSignature(Callback):
        def on_validation_epoch_end(self, trainer, pl_module):  # noqa
            ...

    model = BoringModel()
    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, callbacks=OldSignature())

    with pytest.deprecated_call(match="old signature will be removed in v1.5"):
        trainer.fit(model)

    class OldSignatureModel(BoringModel):
        def on_validation_epoch_end(self):  # noqa
            ...

    model = OldSignatureModel()

    with pytest.deprecated_call(match="old signature will be removed in v1.5"):
        trainer.fit(model)

    callback_warning_cache.clear()

    class NewSignature(Callback):
        def on_validation_epoch_end(self, trainer, pl_module, outputs):
            ...

    trainer.callbacks = [NewSignature()]
    with no_deprecated_call(match="`Callback.on_validation_epoch_end` signature has changed in v1.3."):
        trainer.fit(model)

    class NewSignatureModel(BoringModel):
        def on_validation_epoch_end(self, outputs):
            ...

    model = NewSignatureModel()
    with no_deprecated_call(match="`ModelHooks.on_validation_epoch_end` signature has changed in v1.3."):
        trainer.fit(model)
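# `no_deprecated_call` is likewise a test helper. One possible definition, assuming
# the `no_warning_call` sketch shown earlier, is a partial that fixes the warning
# category to `DeprecationWarning`:
from functools import partial

no_deprecated_call = partial(no_warning_call, expected_warning=DeprecationWarning)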
def __init__(
    self,
    pl_trainer: pl.Trainer,
    model: pl.LightningModule,
    population_tasks: mp.Queue,
    tune_hparams: Dict,
    process_position: int,
    global_epoch: mp.Value,
    max_epoch: int,
    full_parallel: bool,
    pbt_period: int = 4,
    pbt_monitor: str = 'val_loss',
    logger_info=None,
    dataloaders: Optional[Dict] = None,
):
    """
    Args:
        pl_trainer: Lightning ``Trainer`` that is reconfigured in place (checkpointing,
            early stopping, task-loading callbacks, and logger).
        model: ``LightningModule`` trained by this population member.
        population_tasks: shared queue of population tasks, passed to ``TaskSaving``
            and ``TaskLoading``.
        tune_hparams: hyperparameters passed to ``TaskLoading`` for exploration.
        process_position: position/index of this worker process; also used as the
            logger ``task``.
        global_epoch: shared counter of the global epoch, used by ``EarlyStopping``
            and ``TaskLoading``.
        max_epoch: maximum global epoch, used by ``EarlyStopping``.
        full_parallel: whether all population members run fully in parallel
            (passed to ``TaskSaving``).
        pbt_period: number of epochs between PBT steps (passed to ``TaskLoading``).
        pbt_monitor: metric monitored for checkpointing and PBT decisions.
        logger_info: dict with ``save_dir``, ``name``, and ``version`` used to rebuild
            the ``TensorBoardLogger``.
        dataloaders: optional mapping of dataloaders; defaults to an empty dict.
    """
    super().__init__()

    # Set monitor and monitor_precision
    monitor_precision = 32

    # Set checkpoint dirpath
    # checkpoint_dirpath = pl_trainer.checkpoint_callback.dirpath
    # period = pl_trainer.checkpoint_callback.period

    # Formatting checkpoints
    checkpoint_format = '{task:03d}-{' + f'{pbt_monitor}:.{monitor_precision}f' + '}'
    checkpoint_filepath = os.path.join(pl_trainer.logger.log_dir, checkpoint_format)

    # For TaskSaving
    print(logger_info)
    checkpoint_dirpath = pl_trainer.logger.log_dir
    pl_trainer.checkpoint_callback = TaskSaving(
        filepath=checkpoint_filepath,
        monitor=pbt_monitor,
        population_tasks=population_tasks,
        period=1,
        full_parallel=full_parallel,
    )

    # For EarlyStopping
    pl_trainer.early_stop_callback = EarlyStopping(
        global_epoch=global_epoch,
        max_global_epoch=max_epoch,
    )

    # For TaskLoading
    pl_trainer.callbacks = [
        TaskLoading(
            population_tasks=population_tasks,
            global_epoch=global_epoch,
            filepath=checkpoint_filepath,
            monitor=pbt_monitor,
            tune_hparams=tune_hparams,
            pbt_period=pbt_period,
        )
    ]

    # Alter logger to spec.
    # if isinstance(pl_trainer.logger, pl.loggers.TensorBoardLogger):
    pl_trainer.logger = loggers.TensorBoardLogger(
        save_dir=logger_info['save_dir'],
        name=logger_info['name'],
        version=logger_info['version'],
        task=process_position,
    )

    # Set process_position
    pl_trainer.process_position = process_position
    # pl_trainer.logger._version = f'worker_{process_position}'

    # Store references on self
    self.trainer = pl_trainer
    self.model = model
    self.global_epoch = global_epoch
    self.population_tasks = population_tasks
    self.max_epoch = max_epoch
    self.dataloaders = dataloaders or {}
    print(dataloaders)
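# Hypothetical construction of one population member, assuming the class that owns
# the `__init__` above is named `PBTWorker` and that `BoringModel` is importable;
# every name, value, and hyperparameter below is illustrative only:
import multiprocessing as mp

import pytorch_lightning as pl

population_tasks = mp.Queue()
global_epoch = mp.Value("i", 0)
logger_info = {"save_dir": "lightning_logs", "name": "pbt", "version": 0}

worker = PBTWorker(
    pl_trainer=pl.Trainer(max_epochs=1),
    model=BoringModel(),
    population_tasks=population_tasks,
    tune_hparams={"lr": [1e-4, 1e-2]},
    process_position=0,
    global_epoch=global_epoch,
    max_epoch=10,
    full_parallel=False,
    pbt_period=4,
    pbt_monitor="val_loss",
    logger_info=logger_info,
)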