# NOTE: import paths follow the pytorch_lightning 1.x test suite and vary between
# versions (e.g. ``SingleDeviceStrategy`` replaced ``SingleDevicePlugin`` in 1.6).
import os
from pathlib import Path
from typing import Any, Dict, Union

import pytest
import torch

import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.accelerators import CPUAccelerator, GPUAccelerator
from pytorch_lightning.plugins import DataParallelPlugin, PrecisionPlugin, SingleDevicePlugin, TorchCheckpointIO
from pytorch_lightning.strategies import SingleDeviceStrategy
from tests.helpers.boring_model import BoringModel


def test_accelerator_on_reset_dataloader_hooks(tmpdir):
    """Ensure the data loader reset hooks are called when using an Accelerator."""

    class CustomAccelerator(CPUAccelerator):
        train_count: int = 0
        val_count: int = 0
        test_count: int = 0
        predict_count: int = 0

        def on_reset_train_dataloader(self, dataloader):
            self.train_count += 1
            assert self.lightning_module.trainer.training
            return super().on_reset_train_dataloader(dataloader)

        def on_reset_val_dataloader(self, dataloader):
            # the val loader is also reset during fit (sanity checking), not just validate
            self.val_count += 1
            assert self.lightning_module.trainer.training or self.lightning_module.trainer.validating
            return super().on_reset_val_dataloader(dataloader)

        def on_reset_test_dataloader(self, dataloader):
            self.test_count += 1
            assert self.lightning_module.trainer.testing
            return super().on_reset_test_dataloader(dataloader)

        def on_reset_predict_dataloader(self, dataloader):
            self.predict_count += 1
            assert self.lightning_module.trainer.predicting
            return super().on_reset_predict_dataloader(dataloader)

    model = BoringModel()
    accelerator = CustomAccelerator(PrecisionPlugin(), SingleDevicePlugin(device=torch.device("cpu")))
    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, accelerator=accelerator)
    trainer.fit(model)
    trainer.validate(model)
    trainer.test(model)
    trainer.predict(model, dataloaders=model.test_dataloader())

    # assert that all the data loader hooks were called
    assert accelerator.train_count == 1
    assert accelerator.val_count == 1  # only called once during the entire session
    assert accelerator.test_count == 1
    assert accelerator.predict_count == 1

    accelerator = CustomAccelerator(PrecisionPlugin(), SingleDevicePlugin(device=torch.device("cpu")))
    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, accelerator=accelerator)
    trainer.validate(model)
    trainer.test(model)
    trainer.predict(model)

    # assert that the val/test/predict loader hooks were called
    assert accelerator.val_count == 1
    assert accelerator.test_count == 1
    assert accelerator.predict_count == 1

def test_restore_checkpoint_after_pre_dispatch_default():
    """Assert that the default for ``restore_checkpoint_after_pre_dispatch`` is False."""
    plugin = SingleDevicePlugin(torch.device("cpu"))
    accelerator = CPUAccelerator(training_type_plugin=plugin, precision_plugin=PrecisionPlugin())
    assert not accelerator.training_type_plugin.restore_checkpoint_after_pre_dispatch
    assert not plugin.restore_checkpoint_after_pre_dispatch

# guard for machines without CUDA (the Lightning test suite gates this with its own GPU marker)
@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires a GPU")
def test_get_nvidia_gpu_stats(tmpdir):
    """Test GPU get_device_stats with PyTorch < 1.8.0, where the stats are polled from nvidia-smi."""
    current_device = torch.device(f"cuda:{torch.cuda.current_device()}")
    gpu_accel = GPUAccelerator(
        training_type_plugin=DataParallelPlugin(parallel_devices=[current_device]),
        precision_plugin=PrecisionPlugin(),
    )
    gpu_stats = gpu_accel.get_device_stats(current_device)
    fields = ["utilization.gpu", "memory.used", "memory.free", "utilization.memory"]

    for f in fields:
        assert any(f in h for h in gpu_stats.keys())

@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires a GPU")
def test_get_torch_gpu_stats(tmpdir):
    """Test GPU get_device_stats with PyTorch >= 1.8.0, where the stats come from ``torch.cuda.memory_stats``."""
    current_device = torch.device(f"cuda:{torch.cuda.current_device()}")
    gpu_accel = GPUAccelerator(
        training_type_plugin=DataParallelPlugin(parallel_devices=[current_device]),
        precision_plugin=PrecisionPlugin(),
    )
    gpu_stats = gpu_accel.get_device_stats(current_device)
    fields = ["allocated_bytes.all.freed", "inactive_split.all.peak", "reserved_bytes.large_pool.peak"]

    for f in fields:
        assert any(f in h for h in gpu_stats.keys())

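# Not part of the tests above: a minimal usage sketch showing where ``get_device_stats``
# is consumed in practice. The ``DeviceStatsMonitor`` callback (available in Lightning
# >= 1.5, an assumption about the version in use) polls these stats every batch and
# sends them to the trainer's logger.
def example_device_stats_monitor(tmpdir):
    from pytorch_lightning.callbacks import DeviceStatsMonitor

    model = BoringModel()
    trainer = Trainer(
        default_root_dir=tmpdir,
        fast_dev_run=True,
        gpus=1,  # requires a CUDA device, like the two tests above
        callbacks=[DeviceStatsMonitor()],  # logs the same stat fields asserted on above
    )
    trainer.fit(model)
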
@pytest.mark.parametrize("restore_after_pre_dispatch", [True, False])
def test_restore_checkpoint_after_pre_dispatch(tmpdir, restore_after_pre_dispatch):
    """Test that when ``restore_checkpoint_after_pre_dispatch`` is True, the checkpoint state is only loaded after
    pre-dispatch has run."""

    class TestPlugin(SingleDevicePlugin):
        predispatched_called = False

        def pre_dispatch(self) -> None:
            super().pre_dispatch()
            self.predispatched_called = True

        @property
        def restore_checkpoint_after_pre_dispatch(self) -> bool:
            return restore_after_pre_dispatch

        def load_checkpoint_file(self, checkpoint_path: Union[str, Path]) -> Dict[str, Any]:
            # the checkpoint must be loaded after pre-dispatch iff the flag is set
            assert self.predispatched_called == restore_after_pre_dispatch
            return super().load_checkpoint_file(checkpoint_path)

    model = BoringModel()
    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
    trainer.fit(model)

    checkpoint_path = os.path.join(tmpdir, "model.pt")
    trainer.save_checkpoint(checkpoint_path)

    plugin = TestPlugin(torch.device("cpu"), checkpoint_io=TorchCheckpointIO())
    accelerator = CPUAccelerator(training_type_plugin=plugin, precision_plugin=PrecisionPlugin())

    assert accelerator.restore_checkpoint_after_pre_dispatch == restore_after_pre_dispatch
    assert plugin.restore_checkpoint_after_pre_dispatch == restore_after_pre_dispatch

    trainer = Trainer(
        default_root_dir=tmpdir, accelerator=accelerator, fast_dev_run=True, resume_from_checkpoint=checkpoint_path
    )
    trainer.fit(model)

    for func in (trainer.test, trainer.validate, trainer.predict):
        accelerator.training_type_plugin.predispatched_called = False
        func(model, ckpt_path=checkpoint_path)

@pytest.mark.parametrize("restore_after_pre_setup", [True, False])
def test_restore_checkpoint_after_pre_setup(tmpdir, restore_after_pre_setup):
    """Test that when ``restore_checkpoint_after_setup`` is True, the checkpoint state is only loaded after
    ``setup`` has run."""

    class TestPlugin(SingleDeviceStrategy):
        setup_called = False

        def setup(self, trainer: "pl.Trainer") -> None:
            super().setup(trainer)
            self.setup_called = True

        @property
        def restore_checkpoint_after_setup(self) -> bool:
            return restore_after_pre_setup

        def load_checkpoint(self, checkpoint_path: Union[str, Path]) -> Dict[str, Any]:
            # the checkpoint must be loaded after setup iff the flag is set
            assert self.setup_called == restore_after_pre_setup
            return super().load_checkpoint(checkpoint_path)

    model = BoringModel()
    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
    trainer.fit(model)

    checkpoint_path = os.path.join(tmpdir, "model.pt")
    trainer.save_checkpoint(checkpoint_path)

    plugin = TestPlugin(
        accelerator=CPUAccelerator(),
        precision_plugin=PrecisionPlugin(),
        device=torch.device("cpu"),
        checkpoint_io=TorchCheckpointIO(),
    )
    assert plugin.restore_checkpoint_after_setup == restore_after_pre_setup

    trainer = Trainer(default_root_dir=tmpdir, strategy=plugin, fast_dev_run=True)
    trainer.fit(model, ckpt_path=checkpoint_path)

    for func in (trainer.test, trainer.validate, trainer.predict):
        plugin.setup_called = False
        func(model, ckpt_path=checkpoint_path)

def test_restore_checkpoint_after_pre_setup_default():
    """Assert that the default for ``restore_checkpoint_after_setup`` is False."""
    plugin = SingleDeviceStrategy(
        accelerator=CPUAccelerator(), device=torch.device("cpu"), precision_plugin=PrecisionPlugin()
    )
    assert not plugin.restore_checkpoint_after_setup
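

# Not part of the tests above: a sketch of why ``restore_checkpoint_after_setup`` exists.
# Strategies that shard or re-partition the model (DeepSpeed, for example) cannot load a
# checkpoint until ``setup`` has materialized the final parameters, so they override the
# property to return True. This mirrors the ``TestPlugin`` pattern above; illustrative only.
class DeferredRestoreStrategy(SingleDeviceStrategy):
    @property
    def restore_checkpoint_after_setup(self) -> bool:
        # defer checkpoint loading until after ``setup`` has run
        return True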