Example #1
import os
from unittest import mock
from unittest.mock import ANY

import pytest

from pytorch_lightning import Trainer
# Import paths below follow the PL layout this test targets (it patches
# ``pytorch_lightning.plugins.io.torch_plugin``); adjust them to your
# installed version. ``BoringModel`` ships with the PL test helpers.
from pytorch_lightning.plugins.io import TorchCheckpointIO, XLACheckpointIO
from tests.helpers.boring_model import BoringModel


def test_trainer_save_checkpoint_storage_options(tmpdir):
    """Validate that the ``storage_options`` argument is passed through to ``CheckpointIO``."""
    model = BoringModel()
    trainer = Trainer(
        default_root_dir=tmpdir,
        fast_dev_run=True,
        enable_checkpointing=False,
    )
    trainer.fit(model)
    instance_path = os.path.join(tmpdir, "path.ckpt")
    instance_storage_options = "my instance storage options"

    with mock.patch(
            "pytorch_lightning.plugins.io.torch_plugin.TorchCheckpointIO.save_checkpoint"
    ) as io_mock:
        trainer.save_checkpoint(instance_path,
                                storage_options=instance_storage_options)
        io_mock.assert_called_with(ANY,
                                   instance_path,
                                   storage_options=instance_storage_options)
        trainer.save_checkpoint(instance_path)
        io_mock.assert_called_with(ANY, instance_path, storage_options=None)

    with mock.patch(
            "pytorch_lightning.trainer.connectors.checkpoint_connector.CheckpointConnector.save_checkpoint"
    ) as cc_mock:
        trainer.save_checkpoint(instance_path, True)
        cc_mock.assert_called_with(instance_path,
                                   weights_only=True,
                                   storage_options=None)
        trainer.save_checkpoint(instance_path, False, instance_storage_options)
        cc_mock.assert_called_with(instance_path,
                                   weights_only=False,
                                   storage_options=instance_storage_options)

    torch_checkpoint_io = TorchCheckpointIO()
    with pytest.raises(
        TypeError,
        match=r"`Trainer.save_checkpoint\(..., storage_options=...\)` with `storage_options` arg"
        f" is not supported for `{torch_checkpoint_io.__class__.__name__}`. Please implement your custom `CheckpointIO`"
        " to define how you'd like to use `storage_options`.",
    ):
        torch_checkpoint_io.save_checkpoint({}, instance_path, storage_options=instance_storage_options)

    xla_checkpoint_io = XLACheckpointIO()
    with pytest.raises(
        TypeError,
        match=r"`Trainer.save_checkpoint\(..., storage_options=...\)` with `storage_options` arg"
        f" is not supported for `{xla_checkpoint_io.__class__.__name__}`. Please implement your custom `CheckpointIO`"
        " to define how you'd like to use `storage_options`.",
    ):
        xla_checkpoint_io.save_checkpoint({}, instance_path, storage_options=instance_storage_options)
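
The TypeError raised at the end of this test is the built-in plugins signalling that they ignore ``storage_options``. A minimal sketch of the custom ``CheckpointIO`` the error message asks for, assuming the same PL version as above (the class name and the option handling are illustrative, not part of the test):

from typing import Any, Dict, Optional

from pytorch_lightning.plugins.io import TorchCheckpointIO


class StorageOptionsCheckpointIO(TorchCheckpointIO):  # hypothetical name
    """Consumes ``storage_options`` instead of rejecting them."""

    def save_checkpoint(self, checkpoint: Dict[str, Any], path, storage_options: Optional[Any] = None) -> None:
        if storage_options is not None:
            ...  # hypothetical: forward credentials/bucket/retry settings to your storage backend
        # Delegate the actual file write to the plain local implementation.
        super().save_checkpoint(checkpoint, path)

Registering the plugin, e.g. via ``Trainer(plugins=[StorageOptionsCheckpointIO()])``, lets ``trainer.save_checkpoint(path, storage_options=...)`` route the options here instead of raising.
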
Example #2
import os
from pathlib import Path
from typing import Any, Dict, Union

import pytest
import torch

from pytorch_lightning import Trainer
# Import paths follow the PL 1.4/1.5-era accelerator/plugin API this test
# exercises; adjust them to your installed version.
from pytorch_lightning.accelerators import CPUAccelerator
from pytorch_lightning.plugins import PrecisionPlugin, SingleDevicePlugin, TorchCheckpointIO
from tests.helpers.boring_model import BoringModel


@pytest.mark.parametrize("restore_after_pre_dispatch", [True, False])
def test_restore_checkpoint_after_pre_dispatch(tmpdir, restore_after_pre_dispatch):
    """Ensure that when ``restore_checkpoint_after_pre_dispatch`` is True, the state is only loaded after
    ``pre_dispatch`` has been called."""
    class TestPlugin(SingleDevicePlugin):
        predispatched_called = False

        def pre_dispatch(self) -> None:
            super().pre_dispatch()
            self.predispatched_called = True

        @property
        def restore_checkpoint_after_pre_dispatch(self) -> bool:
            return restore_after_pre_dispatch

        def load_checkpoint_file(
                self, checkpoint_path: Union[str, Path]) -> Dict[str, Any]:
            assert self.predispatched_called == restore_after_pre_dispatch
            return super().load_checkpoint_file(checkpoint_path)

    model = BoringModel()
    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
    trainer.fit(model)

    checkpoint_path = os.path.join(tmpdir, "model.pt")
    trainer.save_checkpoint(checkpoint_path)

    plugin = TestPlugin(torch.device("cpu"), checkpoint_io=TorchCheckpointIO())
    accelerator = CPUAccelerator(training_type_plugin=plugin,
                                 precision_plugin=PrecisionPlugin())

    assert accelerator.restore_checkpoint_after_pre_dispatch == restore_after_pre_dispatch
    assert plugin.restore_checkpoint_after_pre_dispatch == restore_after_pre_dispatch

    trainer = Trainer(default_root_dir=tmpdir,
                      accelerator=accelerator,
                      fast_dev_run=True,
                      resume_from_checkpoint=checkpoint_path)
    trainer.fit(model)
    for func in (trainer.test, trainer.validate, trainer.predict):
        accelerator.training_type_plugin.predispatched_called = False
        func(model, ckpt_path=checkpoint_path)
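
A strategy opts into this behaviour when its model only becomes fully usable during ``pre_dispatch`` (for example, weights that are materialized or moved to the device late). A minimal sketch against the same PL-era API as the test above (the class name is made up):

class DeferredRestorePlugin(SingleDevicePlugin):  # hypothetical
    """Asks the Trainer to restore the checkpoint only after ``pre_dispatch``."""

    @property
    def restore_checkpoint_after_pre_dispatch(self) -> bool:
        # The Trainer will then call ``load_checkpoint_file`` after
        # ``pre_dispatch`` instead of before it.
        return True
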
Example #3
import os
from pathlib import Path
from typing import Any, Dict, Union

import pytest
import torch

import pytorch_lightning as pl
from pytorch_lightning import Trainer
# Import paths follow the PL 1.6-era Strategy API this test exercises;
# adjust them to your installed version.
from pytorch_lightning.accelerators import CPUAccelerator
from pytorch_lightning.plugins import PrecisionPlugin, TorchCheckpointIO
from pytorch_lightning.strategies import SingleDeviceStrategy
from tests.helpers.boring_model import BoringModel


@pytest.mark.parametrize("restore_after_pre_setup", [True, False])
def test_restore_checkpoint_after_pre_setup(tmpdir, restore_after_pre_setup):
    """Test to ensure that if ``restore_checkpoint_after_setup`` is True, the state is only loaded after
    ``setup`` is called."""
    class TestPlugin(SingleDeviceStrategy):
        setup_called = False

        def setup(self, trainer: "pl.Trainer") -> None:
            super().setup(trainer)
            self.setup_called = True

        @property
        def restore_checkpoint_after_setup(self) -> bool:
            return restore_after_pre_setup

        def load_checkpoint(
                self, checkpoint_path: Union[str, Path]) -> Dict[str, Any]:
            assert self.setup_called == restore_after_pre_setup
            return super().load_checkpoint(checkpoint_path)

    model = BoringModel()
    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
    trainer.fit(model)

    checkpoint_path = os.path.join(tmpdir, "model.pt")
    trainer.save_checkpoint(checkpoint_path)

    plugin = TestPlugin(
        accelerator=CPUAccelerator(),
        precision_plugin=PrecisionPlugin(),
        device=torch.device("cpu"),
        checkpoint_io=TorchCheckpointIO(),
    )
    assert plugin.restore_checkpoint_after_setup == restore_after_pre_setup

    trainer = Trainer(default_root_dir=tmpdir,
                      strategy=plugin,
                      fast_dev_run=True)
    trainer.fit(model, ckpt_path=checkpoint_path)
    for func in (trainer.test, trainer.validate, trainer.predict):
        plugin.setup_called = False
        func(model, ckpt_path=checkpoint_path)
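
This is the same check as Example #2, ported to the renamed Strategy API: the ``pre_dispatch`` responsibilities were folded into ``setup`` and the property became ``restore_checkpoint_after_setup``. Outside of a test, the override plus normal Trainer usage would look roughly like this sketch (same PL 1.6-era assumptions; ``LateRestoreStrategy`` is a made-up name):

class LateRestoreStrategy(SingleDeviceStrategy):  # hypothetical
    @property
    def restore_checkpoint_after_setup(self) -> bool:
        return True  # ``load_checkpoint`` runs only once ``setup`` has finished


strategy = LateRestoreStrategy(
    accelerator=CPUAccelerator(),
    precision_plugin=PrecisionPlugin(),
    device=torch.device("cpu"),
    checkpoint_io=TorchCheckpointIO(),
)
trainer = Trainer(strategy=strategy, fast_dev_run=True)
trainer.fit(BoringModel(), ckpt_path="model.pt")  # resume from an earlier save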