def test_dm_checkpoint_save_and_load(tmpdir):
    """Round-trip a datamodule's state through a Trainer checkpoint.

    Verifies that ``state_dict`` output plus ``on_save_checkpoint`` mutations land
    under the datamodule's ``__qualname__`` key in the saved checkpoint, that the
    deprecated ``on_save_checkpoint`` hook emits its deprecation warning, and that
    restoring via ``_restore_modules_and_callbacks`` feeds the stored state back
    through ``on_load_checkpoint`` / ``load_state_dict`` for every ``TrainerFn``.
    """

    class CustomBoringModel(BoringModel):
        def validation_step(self, batch, batch_idx):
            # Log a monitorable metric so ModelCheckpoint has something to track.
            out = super().validation_step(batch, batch_idx)
            self.log("early_stop_on", out["x"])
            return out

    class CustomBoringDataModule(BoringDataModule):
        def state_dict(self) -> Dict[str, Any]:
            return {"my": "state_dict"}

        def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
            self.my_state_dict = state_dict

        def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
            # Deprecated hook: augments the state already written by state_dict().
            checkpoint[self.__class__.__qualname__].update({"on_save": "update"})

        def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
            # Copy before popping so the recorded state keeps the "on_save" entry.
            self.checkpoint_state = checkpoint.get(self.__class__.__qualname__).copy()
            checkpoint[self.__class__.__qualname__].pop("on_save")

    reset_seed()
    datamodule = CustomBoringDataModule()
    boring_model = CustomBoringModel()
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        limit_train_batches=2,
        limit_val_batches=1,
        enable_model_summary=False,
        callbacks=[ModelCheckpoint(dirpath=tmpdir, monitor="early_stop_on")],
    )

    # fit model — the deprecated hook must raise its removal warning
    with pytest.deprecated_call(
        match="`LightningDataModule.on_save_checkpoint` was deprecated in"
        " v1.6 and will be removed in v1.8. Use `state_dict` instead."
    ):
        trainer.fit(boring_model, datamodule=datamodule)
    assert trainer.state.finished, f"Training failed with {trainer.state}"

    # The saved checkpoint carries state_dict() output merged with the hook's update.
    best_ckpt_path = next(iter(trainer.checkpoint_callback.best_k_models))
    saved_ckpt = torch.load(best_ckpt_path)
    dm_key = datamodule.__class__.__qualname__
    assert dm_key in saved_ckpt
    assert saved_ckpt[dm_key] == {"my": "state_dict", "on_save": "update"}

    # Restoring must repopulate the datamodule's state regardless of the trainer fn.
    for fn in TrainerFn:
        trainer.state.fn = fn
        trainer._restore_modules_and_callbacks(best_ckpt_path)
        assert datamodule.checkpoint_state == {"my": "state_dict", "on_save": "update"}
        assert datamodule.my_state_dict == {"my": "state_dict"}
def test_dm_checkpoint_hooks_save_and_load(tmpdir):
    """Exercise the legacy ``on_save_checkpoint``/``on_load_checkpoint`` hooks.

    Verifies that whatever ``on_save_checkpoint`` writes under the datamodule's
    class name ends up in the checkpoint file, and that
    ``_restore_modules_and_callbacks`` invokes ``on_load_checkpoint`` exactly
    once for every ``TrainerFn``.

    NOTE(review): this function previously shared the name
    ``test_dm_checkpoint_save_and_load`` with the test defined above it, so the
    later definition shadowed the earlier one and only one of the two was ever
    collected by pytest. Renamed to make both tests run.
    """

    class CustomBoringModel(BoringModel):
        def validation_step(self, batch, batch_idx):
            # Log a monitorable metric so ModelCheckpoint has something to track.
            out = super().validation_step(batch, batch_idx)
            self.log("early_stop_on", out["x"])
            return out

    class CustomBoringDataModule(BoringDataModule):
        def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
            checkpoint[self.__class__.__name__] = self.__class__.__name__

        def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
            self.checkpoint_state = checkpoint.get(self.__class__.__name__)

    reset_seed()
    dm = CustomBoringDataModule()
    model = CustomBoringModel()
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        limit_train_batches=2,
        limit_val_batches=1,
        enable_model_summary=False,
        callbacks=[ModelCheckpoint(dirpath=tmpdir, monitor="early_stop_on")],
    )

    # fit model
    trainer.fit(model, datamodule=dm)
    assert trainer.state.finished, f"Training failed with {trainer.state}"

    # The save hook's payload must be persisted under the class-name key.
    checkpoint_path = list(trainer.checkpoint_callback.best_k_models.keys())[0]
    checkpoint = torch.load(checkpoint_path)
    assert dm.__class__.__name__ in checkpoint
    assert checkpoint[dm.__class__.__name__] == dm.__class__.__name__

    # The load hook must fire exactly once per restore, for every trainer fn.
    for trainer_fn in TrainerFn:
        trainer.state.fn = trainer_fn
        with mock.patch.object(dm, "on_load_checkpoint") as dm_mock:
            trainer._restore_modules_and_callbacks(checkpoint_path)
            dm_mock.assert_called_once()