Пример #1
0
def test_dm_checkpoint_save_and_load(tmpdir):
    """Round-trip a datamodule's state through a saved checkpoint file.

    Fits for one epoch while expecting the deprecation warning for
    ``LightningDataModule.on_save_checkpoint``, then asserts that the
    checkpoint holds the datamodule entry under the class qualname and
    that restoring under every ``TrainerFn`` repopulates the datamodule.
    """

    class CustomBoringModel(BoringModel):
        """Model that logs the metric monitored by the checkpoint callback."""

        def validation_step(self, batch, batch_idx):
            step_output = super().validation_step(batch, batch_idx)
            self.log("early_stop_on", step_output["x"])
            return step_output

    class CustomBoringDataModule(BoringDataModule):
        """Datamodule defining both the state-dict API and checkpoint hooks."""

        def state_dict(self) -> Dict[str, Any]:
            return {"my": "state_dict"}

        def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
            # Stored on the instance so the test body can inspect it.
            self.my_state_dict = state_dict

        def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
            # Adds an extra key to the entry stored under the class qualname.
            entry = checkpoint[type(self).__qualname__]
            entry["on_save"] = "update"

        def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
            entry = checkpoint.get(type(self).__qualname__)
            # Snapshot the entry first, then strip the hook-only key from it.
            self.checkpoint_state = entry.copy()
            entry.pop("on_save")

    reset_seed()
    dm = CustomBoringDataModule()
    model = CustomBoringModel()

    checkpoint_callback = ModelCheckpoint(
        dirpath=tmpdir, monitor="early_stop_on")
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        limit_train_batches=2,
        limit_val_batches=1,
        enable_model_summary=False,
        callbacks=[checkpoint_callback],
    )

    # fit model
    deprecation_msg = (
        "`LightningDataModule.on_save_checkpoint` was deprecated in"
        " v1.6 and will be removed in v1.8. Use `state_dict` instead."
    )
    with pytest.deprecated_call(match=deprecation_msg):
        trainer.fit(model, datamodule=dm)
    assert trainer.state.finished, f"Training failed with {trainer.state}"

    saved_paths = list(trainer.checkpoint_callback.best_k_models)
    checkpoint_path = saved_paths[0]
    checkpoint = torch.load(checkpoint_path)

    dm_key = type(dm).__qualname__
    expected_entry = {"my": "state_dict", "on_save": "update"}
    assert dm_key in checkpoint
    assert checkpoint[dm_key] == expected_entry

    # Restoring must behave the same whichever trainer function is active.
    for trainer_fn in TrainerFn:
        trainer.state.fn = trainer_fn
        trainer._restore_modules_and_callbacks(checkpoint_path)
        assert dm.checkpoint_state == expected_entry
        assert dm.my_state_dict == {"my": "state_dict"}
def test_dm_checkpoint_save_and_load(tmpdir):
    """Check the datamodule checkpoint hooks during save and restore.

    Fits for one epoch, asserts that ``on_save_checkpoint`` stored the
    datamodule class name in the checkpoint, and asserts that
    ``on_load_checkpoint`` is called exactly once per restore for every
    ``TrainerFn``.
    """

    class CustomBoringModel(BoringModel):
        """Model that logs the metric monitored by the checkpoint callback."""

        def validation_step(self, batch, batch_idx):
            step_output = super().validation_step(batch, batch_idx)
            self.log("early_stop_on", step_output["x"])
            return step_output

    class CustomBoringDataModule(BoringDataModule):
        """Datamodule that writes and reads its class name in the checkpoint."""

        def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
            # The class name serves as both the key and the stored value.
            cls_name = type(self).__name__
            checkpoint[cls_name] = cls_name

        def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
            cls_name = type(self).__name__
            self.checkpoint_state = checkpoint.get(cls_name)

    reset_seed()
    dm = CustomBoringDataModule()
    model = CustomBoringModel()

    checkpoint_callback = ModelCheckpoint(
        dirpath=tmpdir, monitor="early_stop_on")
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        limit_train_batches=2,
        limit_val_batches=1,
        enable_model_summary=False,
        callbacks=[checkpoint_callback],
    )

    # fit model
    trainer.fit(model, datamodule=dm)
    assert trainer.state.finished, f"Training failed with {trainer.state}"

    saved_paths = list(trainer.checkpoint_callback.best_k_models)
    checkpoint_path = saved_paths[0]
    checkpoint = torch.load(checkpoint_path)

    dm_name = type(dm).__name__
    assert dm_name in checkpoint
    assert checkpoint[dm_name] == dm_name

    # The load hook is replaced by a mock so the call count can be checked.
    for trainer_fn in TrainerFn:
        trainer.state.fn = trainer_fn
        with mock.patch.object(dm, "on_load_checkpoint") as load_hook:
            trainer._restore_modules_and_callbacks(checkpoint_path)
            load_hook.assert_called_once()