def test_trainer_save_checkpoint_storage_options(tmpdir):
    """This test validates that storage_options argument is properly passed to ``CheckpointIO``."""
    model = BoringModel()
    trainer = Trainer(
        default_root_dir=tmpdir,
        fast_dev_run=True,
        enable_checkpointing=False,
    )
    trainer.fit(model)
    instance_path = tmpdir + "/path.ckpt"
    instance_storage_options = "my instance storage options"

    # `storage_options` must reach the CheckpointIO plugin unchanged, and default to None when omitted.
    with mock.patch("pytorch_lightning.plugins.io.torch_plugin.TorchCheckpointIO.save_checkpoint") as io_mock:
        trainer.save_checkpoint(instance_path, storage_options=instance_storage_options)
        io_mock.assert_called_with(ANY, instance_path, storage_options=instance_storage_options)
        trainer.save_checkpoint(instance_path)
        io_mock.assert_called_with(ANY, instance_path, storage_options=None)

    # Positional arguments on `Trainer.save_checkpoint` map onto the connector's keyword arguments.
    with mock.patch(
        "pytorch_lightning.trainer.connectors.checkpoint_connector.CheckpointConnector.save_checkpoint"
    ) as cc_mock:
        trainer.save_checkpoint(instance_path, True)
        cc_mock.assert_called_with(instance_path, weights_only=True, storage_options=None)
        trainer.save_checkpoint(instance_path, False, instance_storage_options)
        cc_mock.assert_called_with(instance_path, weights_only=False, storage_options=instance_storage_options)

    # The built-in CheckpointIO plugins do not support `storage_options` and must raise a
    # TypeError naming the concrete plugin class.
    for checkpoint_io in (TorchCheckpointIO(), XLACheckpointIO()):
        with pytest.raises(
            TypeError,
            match=r"`Trainer.save_checkpoint\(..., storage_options=...\)` with `storage_options` arg"
            f" is not supported for `{checkpoint_io.__class__.__name__}`. Please implement your custom `CheckpointIO`"
            " to define how you'd like to use `storage_options`.",
        ):
            checkpoint_io.save_checkpoint({}, instance_path, storage_options=instance_storage_options)
@pytest.mark.parametrize("restore_after_pre_dispatch", [True, False])
def test_restore_checkpoint_after_pre_dispatch(tmpdir, restore_after_pre_dispatch):
    """Test to ensure that if ``restore_checkpoint_after_pre_dispatch`` is True, then we only load the state after
    pre-dispatch is called."""

    class TestPlugin(SingleDevicePlugin):
        # Flipped by `pre_dispatch`; lets `load_checkpoint_file` assert ordering.
        predispatched_called = False

        def pre_dispatch(self) -> None:
            super().pre_dispatch()
            self.predispatched_called = True

        @property
        def restore_checkpoint_after_pre_dispatch(self) -> bool:
            return restore_after_pre_dispatch

        def load_checkpoint_file(self, checkpoint_path: Union[str, Path]) -> Dict[str, Any]:
            # The checkpoint must be loaded before/after pre-dispatch depending on the flag.
            assert self.predispatched_called == restore_after_pre_dispatch
            return super().load_checkpoint_file(checkpoint_path)

    # Produce a checkpoint to restore from.
    model = BoringModel()
    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
    trainer.fit(model)
    checkpoint_path = os.path.join(tmpdir, "model.pt")
    trainer.save_checkpoint(checkpoint_path)

    plugin = TestPlugin(torch.device("cpu"), checkpoint_io=TorchCheckpointIO())
    accelerator = CPUAccelerator(training_type_plugin=plugin, precision_plugin=PrecisionPlugin())

    assert accelerator.restore_checkpoint_after_pre_dispatch == restore_after_pre_dispatch
    assert plugin.restore_checkpoint_after_pre_dispatch == restore_after_pre_dispatch

    trainer = Trainer(
        default_root_dir=tmpdir,
        accelerator=accelerator,
        fast_dev_run=True,
        resume_from_checkpoint=checkpoint_path,
    )
    trainer.fit(model)
    for func in (trainer.test, trainer.validate, trainer.predict):
        # Reset the flag so each trainer entry point re-checks the ordering contract.
        accelerator.training_type_plugin.predispatched_called = False
        func(model, ckpt_path=checkpoint_path)
@pytest.mark.parametrize("restore_after_pre_setup", [True, False])
def test_restore_checkpoint_after_pre_setup(tmpdir, restore_after_pre_setup):
    """Test to ensure that if restore_checkpoint_after_setup is True, then we only load the state after pre-
    dispatch is called."""

    class TestPlugin(SingleDeviceStrategy):
        # Flipped by `setup`; lets `load_checkpoint` assert ordering.
        setup_called = False

        def setup(self, trainer: "pl.Trainer") -> None:
            super().setup(trainer)
            self.setup_called = True

        @property
        def restore_checkpoint_after_setup(self) -> bool:
            return restore_after_pre_setup

        def load_checkpoint(self, checkpoint_path: Union[str, Path]) -> Dict[str, Any]:
            # The checkpoint must be loaded before/after setup depending on the flag.
            assert self.setup_called == restore_after_pre_setup
            return super().load_checkpoint(checkpoint_path)

    # Produce a checkpoint to restore from.
    model = BoringModel()
    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
    trainer.fit(model)
    checkpoint_path = os.path.join(tmpdir, "model.pt")
    trainer.save_checkpoint(checkpoint_path)

    plugin = TestPlugin(
        accelerator=CPUAccelerator(),
        precision_plugin=PrecisionPlugin(),
        device=torch.device("cpu"),
        checkpoint_io=TorchCheckpointIO(),
    )
    assert plugin.restore_checkpoint_after_setup == restore_after_pre_setup

    trainer = Trainer(default_root_dir=tmpdir, strategy=plugin, fast_dev_run=True)
    trainer.fit(model, ckpt_path=checkpoint_path)
    for func in (trainer.test, trainer.validate, trainer.predict):
        # Reset the flag so each trainer entry point re-checks the ordering contract.
        plugin.setup_called = False
        func(model, ckpt_path=checkpoint_path)