def test_unsupported_precision_plugins():
    """Setting up a CPU accelerator with a mixed-precision plugin must raise a configuration error."""
    fake_trainer = Mock()
    cpu_plugin = SingleDevicePlugin(torch.device("cpu"))
    accelerator = CPUAccelerator(
        training_type_plugin=cpu_plugin,
        precision_plugin=MixedPrecisionPlugin(),
    )

    # AMP is rejected on CPU; the error text is matched as a regex, hence the escaped "+".
    with pytest.raises(MisconfigurationException, match=r"AMP \+ CPU is not supported"):
        accelerator.setup(trainer=fake_trainer)
def test_restore_checkpoint_after_pre_dispatch_default():
    """By default ``restore_checkpoint_after_pre_dispatch`` is falsy, seen via the accelerator and the plugin."""
    plugin = SingleDevicePlugin(torch.device("cpu"))
    accelerator = CPUAccelerator(training_type_plugin=plugin, precision_plugin=PrecisionPlugin())

    # Check through the accelerator first, then on the plugin object that was passed in.
    for owner in (accelerator.training_type_plugin, plugin):
        assert not owner.restore_checkpoint_after_pre_dispatch
def test_restore_checkpoint_after_pre_dispatch(tmpdir, restore_after_pre_dispatch):
    """Check that the checkpoint file is read before or after ``pre_dispatch`` as requested.

    When ``restore_checkpoint_after_pre_dispatch`` is True, the checkpoint must only be loaded once
    ``pre_dispatch`` has run. When it is False, the checkpoint must be loaded before ``pre_dispatch``.
    """

    class TestPlugin(SingleDevicePlugin):
        # Flipped to True by pre_dispatch; the test resets it before each evaluation stage.
        predispatched_called = False

        def pre_dispatch(self) -> None:
            super().pre_dispatch()
            self.predispatched_called = True

        @property
        def restore_checkpoint_after_pre_dispatch(self) -> bool:
            # Mirrors the test argument so both orderings can be exercised.
            return restore_after_pre_dispatch

        def load_checkpoint_file(self, checkpoint_path: Union[str, Path]) -> Dict[str, Any]:
            # pre_dispatch must already have run exactly when a post-pre-dispatch restore was requested.
            assert self.predispatched_called == restore_after_pre_dispatch
            return super().load_checkpoint_file(checkpoint_path)

    # Produce a checkpoint to restore from.
    boring_model = BoringModel()
    first_trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
    first_trainer.fit(boring_model)
    ckpt = os.path.join(tmpdir, "model.pt")
    first_trainer.save_checkpoint(ckpt)

    plugin = TestPlugin(torch.device("cpu"), checkpoint_io=TorchCheckpointIO())
    accelerator = CPUAccelerator(training_type_plugin=plugin, precision_plugin=PrecisionPlugin())
    assert accelerator.restore_checkpoint_after_pre_dispatch == restore_after_pre_dispatch
    assert plugin.restore_checkpoint_after_pre_dispatch == restore_after_pre_dispatch

    resumed_trainer = Trainer(
        default_root_dir=tmpdir,
        accelerator=accelerator,
        fast_dev_run=True,
        resume_from_checkpoint=ckpt,
    )
    resumed_trainer.fit(boring_model)

    stages = (resumed_trainer.test, resumed_trainer.validate, resumed_trainer.predict)
    for stage in stages:
        # Reset so each stage independently verifies the load / pre_dispatch ordering.
        accelerator.training_type_plugin.predispatched_called = False
        stage(boring_model, ckpt_path=ckpt)
def test_restore_checkpoint_after_pre_setup(tmpdir, restore_after_pre_setup):
    """Test that the strategy loads the checkpoint relative to ``setup`` as requested.

    If ``restore_checkpoint_after_setup`` is True, the checkpoint must only be loaded after ``setup`` has been
    called. If it is False, the checkpoint must be loaded before ``setup`` is called.
    """

    class TestStrategy(SingleDeviceStrategy):
        # Set to True once ``setup`` has run; reset before each evaluation stage below.
        setup_called = False

        def setup(self, trainer: "pl.Trainer") -> None:
            super().setup(trainer)
            self.setup_called = True

        @property
        def restore_checkpoint_after_setup(self) -> bool:
            # Mirrors the test argument so both orderings can be exercised.
            return restore_after_pre_setup

        def load_checkpoint(self, checkpoint_path: Union[str, Path]) -> Dict[str, Any]:
            # ``setup`` must already have run exactly when a post-setup restore was requested.
            assert self.setup_called == restore_after_pre_setup
            return super().load_checkpoint(checkpoint_path)

    # Produce a checkpoint to restore from.
    model = BoringModel()
    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
    trainer.fit(model)

    checkpoint_path = os.path.join(tmpdir, "model.pt")
    trainer.save_checkpoint(checkpoint_path)

    strategy = TestStrategy(
        accelerator=CPUAccelerator(),
        precision_plugin=PrecisionPlugin(),
        device=torch.device("cpu"),
        checkpoint_io=TorchCheckpointIO(),
    )
    assert strategy.restore_checkpoint_after_setup == restore_after_pre_setup

    trainer = Trainer(default_root_dir=tmpdir, strategy=strategy, fast_dev_run=True)
    trainer.fit(model, ckpt_path=checkpoint_path)
    for func in (trainer.test, trainer.validate, trainer.predict):
        # Reset so each stage independently verifies the load / setup ordering.
        strategy.setup_called = False
        func(model, ckpt_path=checkpoint_path)
def test_auto_device_count(_):
    """Each accelerator class reports the expected ``auto_device_count``.

    NOTE(review): the non-CPU counts presumably come from mocks applied outside this view (the unused
    ``_`` argument suggests a patch decorator) — confirm.
    NOTE(review): another function named ``test_auto_device_count`` (the ``GPUAccelerator`` variant) is
    defined right after this one and rebinds the module-level name, so pytest never collects this
    definition. Consider renaming one of them.
    """
    expected_counts = (
        (CPUAccelerator, 1),
        (CUDAAccelerator, 2),
        (TPUAccelerator, 8),
        (IPUAccelerator, 4),
    )
    for accelerator_cls, expected in expected_counts:
        assert accelerator_cls.auto_device_count() == expected
def test_auto_device_count(device_count_mock):
    """Each accelerator class reports the expected ``auto_device_count``.

    NOTE(review): ``device_count_mock`` is presumably injected by a patch decorator outside this view — confirm.
    NOTE(review): this definition reuses the name of the ``CUDAAccelerator`` variant defined just before it,
    so that earlier test is shadowed and never collected. Consider giving the two tests distinct names.
    """
    expected_counts = (
        (CPUAccelerator, 1),
        (GPUAccelerator, 2),
        (TPUAccelerator, 8),
        (IPUAccelerator, 4),
    )
    for accelerator_cls, expected in expected_counts:
        assert accelerator_cls.auto_device_count() == expected
def test_availability():
    """The CPU accelerator reports itself as available."""
    available = CPUAccelerator.is_available()
    assert available
def test_restore_checkpoint_after_pre_setup_default():
    """A plain ``SingleDeviceStrategy`` has a falsy ``restore_checkpoint_after_setup`` by default."""
    strategy = SingleDeviceStrategy(
        accelerator=CPUAccelerator(),
        device=torch.device("cpu"),
        precision_plugin=PrecisionPlugin(),
    )
    assert not strategy.restore_checkpoint_after_setup