Example 1
0
def test_restore_checkpoint_after_pre_dispatch_default():
    """Check that ``restore_checkpoint_after_pre_dispatch`` defaults to False.

    The flag must be False both when read through the accelerator and when
    read directly from the training-type plugin.
    """
    device_plugin = SingleDevicePlugin(torch.device("cpu"))
    accelerator = CPUAccelerator(
        training_type_plugin=device_plugin,
        precision_plugin=PrecisionPlugin(),
    )
    # Both access paths must expose the same (falsy) default.
    assert not accelerator.training_type_plugin.restore_checkpoint_after_pre_dispatch
    assert not device_plugin.restore_checkpoint_after_pre_dispatch
Example 2
0
def test_unsupported_precision_plugins():
    """Ensure a ``MisconfigurationException`` is raised when AMP mixed
    precision is combined with the CPU accelerator."""
    mock_trainer = Mock()
    device_plugin = SingleDevicePlugin(torch.device("cpu"))
    accelerator = CPUAccelerator(
        training_type_plugin=device_plugin,
        precision_plugin=MixedPrecisionPlugin(),
    )
    # setup() is where the accelerator validates the precision plugin.
    with pytest.raises(
            MisconfigurationException,
            match=r"AMP \+ CPU is not supported"):
        accelerator.setup(trainer=mock_trainer)
Example 3
0
def test_restore_checkpoint_after_pre_dispatch(tmpdir,
                                               restore_after_pre_dispatch):
    """Verify checkpoint-load ordering relative to ``pre_dispatch``.

    When ``restore_checkpoint_after_pre_dispatch`` is True, the checkpoint
    file must only be read after ``pre_dispatch`` has run; when False, it
    must be read before.
    """

    class TrackingPlugin(SingleDevicePlugin):
        # Set to True once pre_dispatch has run; the test also resets this
        # attribute externally between trainer entry points.
        predispatched_called = False

        def pre_dispatch(self) -> None:
            super().pre_dispatch()
            self.predispatched_called = True

        @property
        def restore_checkpoint_after_pre_dispatch(self) -> bool:
            return restore_after_pre_dispatch

        def load_checkpoint_file(
                self, checkpoint_path: Union[str, Path]) -> Dict[str, Any]:
            # The ordering contract: loading happens after pre_dispatch
            # exactly when the flag says it should.
            assert self.predispatched_called == restore_after_pre_dispatch
            return super().load_checkpoint_file(checkpoint_path)

    # Produce a checkpoint to restore from.
    model = BoringModel()
    initial_trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
    initial_trainer.fit(model)

    ckpt_path = os.path.join(tmpdir, "model.pt")
    initial_trainer.save_checkpoint(ckpt_path)

    tracking_plugin = TrackingPlugin(torch.device("cpu"),
                                     checkpoint_io=TorchCheckpointIO())
    accelerator = CPUAccelerator(
        training_type_plugin=tracking_plugin,
        precision_plugin=PrecisionPlugin(),
    )

    # The flag must be visible through both the accelerator and the plugin.
    assert accelerator.restore_checkpoint_after_pre_dispatch == restore_after_pre_dispatch
    assert tracking_plugin.restore_checkpoint_after_pre_dispatch == restore_after_pre_dispatch

    resume_trainer = Trainer(
        default_root_dir=tmpdir,
        accelerator=accelerator,
        fast_dev_run=True,
        resume_from_checkpoint=ckpt_path,
    )
    resume_trainer.fit(model)
    # Exercise every trainer entry point that restores a checkpoint.
    for entry_point in (resume_trainer.test, resume_trainer.validate,
                        resume_trainer.predict):
        accelerator.training_type_plugin.predispatched_called = False
        entry_point(model, ckpt_path=ckpt_path)
Example 4
0
def test_restore_checkpoint_after_pre_setup(tmpdir, restore_after_pre_setup):
    """Verify checkpoint-load ordering relative to ``setup``.

    When ``restore_checkpoint_after_setup`` is True, the checkpoint must
    only be loaded after ``setup`` has run; when False, before.
    """

    class TrackingStrategy(SingleDeviceStrategy):
        # Flipped to True once setup() has run; also reset externally
        # between trainer entry points below.
        setup_called = False

        def setup(self, trainer: "pl.Trainer") -> None:
            super().setup(trainer)
            self.setup_called = True

        @property
        def restore_checkpoint_after_setup(self) -> bool:
            return restore_after_pre_setup

        def load_checkpoint(
                self, checkpoint_path: Union[str, Path]) -> Dict[str, Any]:
            # Loading must happen after setup exactly when the flag says so.
            assert self.setup_called == restore_after_pre_setup
            return super().load_checkpoint(checkpoint_path)

    # Produce a checkpoint to restore from.
    model = BoringModel()
    initial_trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
    initial_trainer.fit(model)

    ckpt_path = os.path.join(tmpdir, "model.pt")
    initial_trainer.save_checkpoint(ckpt_path)

    strategy = TrackingStrategy(
        accelerator=CPUAccelerator(),
        precision_plugin=PrecisionPlugin(),
        device=torch.device("cpu"),
        checkpoint_io=TorchCheckpointIO(),
    )
    assert strategy.restore_checkpoint_after_setup == restore_after_pre_setup

    resume_trainer = Trainer(
        default_root_dir=tmpdir,
        strategy=strategy,
        fast_dev_run=True,
    )
    resume_trainer.fit(model, ckpt_path=ckpt_path)
    # Exercise every trainer entry point that restores a checkpoint.
    for entry_point in (resume_trainer.test, resume_trainer.validate,
                        resume_trainer.predict):
        strategy.setup_called = False
        entry_point(model, ckpt_path=ckpt_path)
Example 5
0
def test_restore_checkpoint_after_pre_setup_default():
    """Check that ``restore_checkpoint_after_setup`` defaults to False."""
    strategy = SingleDeviceStrategy(
        accelerator=CPUAccelerator(),
        device=torch.device("cpu"),
        precision_plugin=PrecisionPlugin(),
    )
    # Default must be falsy so checkpoints load before setup by default.
    assert not strategy.restore_checkpoint_after_setup