Ejemplo n.º 1
0
def test_v1_5_0_old_callback_on_save_checkpoint(tmpdir):
    """Only a legacy ``on_save_checkpoint`` signature triggers the v1.5 deprecation.

    A callback whose hook lacks the ``checkpoint`` argument must emit the
    "old signature will be removed in v1.5" deprecation when a checkpoint is
    saved. Callbacks using the three-argument form, or capturing the arguments
    through ``*args``, must not emit any ``DeprecationWarning``.
    """

    class OldSignature(Callback):
        def on_save_checkpoint(self, trainer, pl_module):  # noqa
            """Legacy hook: no ``checkpoint`` argument."""

    model = BoringModel()
    ckpt_path = tmpdir / "test.ckpt"

    # ``checkpoint_callback=False`` so that fitting itself does not save a
    # checkpoint; only the explicit ``save_checkpoint`` calls below do.
    trainer = Trainer(
        default_root_dir=tmpdir,
        checkpoint_callback=False,
        max_epochs=1,
        callbacks=[OldSignature()],
    )
    trainer.fit(model)

    legacy_warning = "old signature will be removed in v1.5"
    with pytest.deprecated_call(match=legacy_warning):
        trainer.save_checkpoint(ckpt_path)

    class NewSignature(Callback):
        def on_save_checkpoint(self, trainer, pl_module, checkpoint):
            """Current hook signature."""

    class ValidSignature1(Callback):
        def on_save_checkpoint(self, trainer, *args):
            """``pl_module`` and ``checkpoint`` captured by ``*args``."""

    class ValidSignature2(Callback):
        def on_save_checkpoint(self, *args):
            """Every hook argument captured by ``*args``."""

    accepted = (NewSignature, ValidSignature1, ValidSignature2)
    trainer.callbacks = [callback_cls() for callback_cls in accepted]
    with no_warning_call(DeprecationWarning):
        trainer.save_checkpoint(ckpt_path)
Ejemplo n.º 2
0
def test_v1_8_0_callback_on_save_checkpoint_hook(tmpdir):
    """Returning state from ``Callback.on_save_checkpoint`` is deprecated; overriding it is not.

    A callback whose ``on_save_checkpoint`` returns a value must trigger the
    v1.6 deprecation pointing users to ``Callback.state_dict``. A callback that
    overrides the hook but returns nothing must not trigger that deprecation.
    """
    import warnings

    class TestCallbackSaveHookReturn(Callback):
        def on_save_checkpoint(self, trainer, pl_module, checkpoint):
            return {"returning": "on_save_checkpoint"}

    class TestCallbackSaveHookOverride(Callback):
        def on_save_checkpoint(self, trainer, pl_module, checkpoint):
            print("overriding without returning")

    model = BoringModel()
    trainer = Trainer(
        callbacks=[TestCallbackSaveHookReturn()],
        max_epochs=1,
        fast_dev_run=True,
        enable_progress_bar=False,
        logger=False,
        default_root_dir=tmpdir,
    )
    trainer.fit(model)
    with pytest.deprecated_call(
        match="Returning a value from `TestCallbackSaveHookReturn.on_save_checkpoint` is deprecated in v1.6"
        " and will be removed in v1.8. Please override `Callback.state_dict`"
        " to return state to be saved."
    ):
        # Path join (rather than string concatenation), matching the sibling tests.
        trainer.save_checkpoint(tmpdir / "path.ckpt")

    trainer.callbacks = [TestCallbackSaveHookOverride()]
    # Previously this save asserted nothing. Record every warning so the
    # "override without returning is fine" half of the test is actually checked.
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        trainer.save_checkpoint(tmpdir / "pathok.ckpt")
    returning_warnings = [w for w in caught if "Returning a value from" in str(w.message)]
    assert not returning_warnings, returning_warnings
Ejemplo n.º 3
0
def test_v1_5_0_old_on_validation_epoch_end(tmpdir):
    """Legacy ``on_validation_epoch_end`` signatures warn; ones taking ``outputs`` do not.

    Both callback hooks and model hooks are checked, reusing a single Trainer.
    """
    removal_msg = "old signature will be removed in v1.5"
    callback_msg = "`Callback.on_validation_epoch_end` signature has changed in v1.3."
    model_msg = "`ModelHooks.on_validation_epoch_end` signature has changed in v1.3."

    callback_warning_cache.clear()

    class OldSignature(Callback):
        def on_validation_epoch_end(self, trainer, pl_module):  # noqa
            """Legacy callback hook: no ``outputs`` argument."""

    model = BoringModel()
    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, callbacks=OldSignature())
    with pytest.deprecated_call(match=removal_msg):
        trainer.fit(model)

    class OldSignatureModel(BoringModel):
        def on_validation_epoch_end(self):  # noqa
            """Legacy model hook: no ``outputs`` argument."""

    model = OldSignatureModel()
    with pytest.deprecated_call(match=removal_msg):
        trainer.fit(model)

    # Clear the cache again before the "no warning" checks. The cache presumably
    # de-duplicates already-emitted warnings.
    callback_warning_cache.clear()

    class NewSignature(Callback):
        def on_validation_epoch_end(self, trainer, pl_module, outputs):
            """Current callback hook signature."""

    trainer.callbacks = [NewSignature()]
    # NOTE(review): ``model`` is still the OldSignatureModel instance here, so a
    # model-hook deprecation may also fire during this fit. Only the Callback
    # message is checked.
    with no_deprecated_call(match=callback_msg):
        trainer.fit(model)

    class NewSignatureModel(BoringModel):
        def on_validation_epoch_end(self, outputs):
            """Current model hook signature."""

    model = NewSignatureModel()
    with no_deprecated_call(match=model_msg):
        trainer.fit(model)
Ejemplo n.º 4
0
    def __init__(self,
                 pl_trainer: pl.Trainer,
                 model: pl.LightningModule,
                 population_tasks: mp.Queue,
                 tune_hparams: Dict,
                 process_position: int,
                 global_epoch: mp.Value,
                 max_epoch: int,
                 full_parallel: bool,
                 pbt_period: int = 4,
                 pbt_monitor: str = 'val_loss',
                 logger_info=None,
                 dataloaders: Optional[Dict] = None):
        """Configure ``pl_trainer`` for population-based training and store worker state.

        The trainer's ``checkpoint_callback``, ``early_stop_callback``,
        ``callbacks``, ``logger`` and ``process_position`` attributes are
        overwritten in place.

        Args:
            pl_trainer: Trainer to configure. Its current ``logger.log_dir`` is
                used as the checkpoint directory.
            model: Module this worker trains; stored as ``self.model``.
            population_tasks: Queue passed to both ``TaskSaving`` and
                ``TaskLoading``; stored as ``self.population_tasks``.
            tune_hparams: Hyperparameter specification forwarded to ``TaskLoading``.
            process_position: Worker index. Passed to the logger as ``task`` and
                set as ``pl_trainer.process_position``.
            global_epoch: Shared epoch value given to ``EarlyStopping`` and
                ``TaskLoading``; stored as ``self.global_epoch``.
            max_epoch: Passed to ``EarlyStopping`` as ``max_global_epoch``;
                stored as ``self.max_epoch``.
            full_parallel: Forwarded to ``TaskSaving``.
            pbt_period: Forwarded to ``TaskLoading``.
            pbt_monitor: Metric name used as the checkpoint monitor and embedded
                in the checkpoint filename template.
            logger_info: Mapping with ``'save_dir'``, ``'name'`` and ``'version'``
                keys used to build the replacement TensorBoard logger. Despite
                the ``None`` default, it is required.
            dataloaders: Optional mapping of dataloaders; stored as
                ``self.dataloaders`` (an empty dict when not given).

        Raises:
            TypeError: If ``logger_info`` is ``None``.
        """
        super().__init__()
        # Fail before mutating ``pl_trainer``. Previously a missing logger_info
        # only crashed (with "'NoneType' object is not subscriptable") after the
        # trainer's callbacks had already been replaced.
        if logger_info is None:
            raise TypeError(
                "logger_info is required: expected a mapping with "
                "'save_dir', 'name' and 'version' keys")

        # Number of decimal places of the monitored metric in checkpoint names.
        monitor_precision = 32
        # e.g. '{task:03d}-{val_loss:.32f}' for the default monitor.
        checkpoint_format = '{task:03d}-{' + f'{pbt_monitor}:.{monitor_precision}f' + '}'
        # NOTE(review): this uses the log_dir of the logger the trainer came in
        # with, which is replaced further below. Confirm this is intended.
        checkpoint_filepath = os.path.join(pl_trainer.logger.log_dir,
                                           checkpoint_format)

        # For TaskSaving
        pl_trainer.checkpoint_callback = TaskSaving(
            filepath=checkpoint_filepath,
            monitor=pbt_monitor,
            population_tasks=population_tasks,
            period=1,
            full_parallel=full_parallel,
        )

        # For EarlyStopping
        pl_trainer.early_stop_callback = EarlyStopping(
            global_epoch=global_epoch, max_global_epoch=max_epoch)

        # For TaskLoading
        pl_trainer.callbacks = [
            TaskLoading(population_tasks=population_tasks,
                        global_epoch=global_epoch,
                        filepath=checkpoint_filepath,
                        monitor=pbt_monitor,
                        tune_hparams=tune_hparams,
                        pbt_period=pbt_period)
        ]

        # Replace the logger with a per-task TensorBoard logger.
        pl_trainer.logger = loggers.TensorBoardLogger(
            save_dir=logger_info['save_dir'],
            name=logger_info['name'],
            version=logger_info['version'],
            task=process_position,
        )

        pl_trainer.process_position = process_position

        self.trainer = pl_trainer
        self.model = model
        self.global_epoch = global_epoch
        self.population_tasks = population_tasks
        self.max_epoch = max_epoch
        self.dataloaders = dataloaders or {}