def test_accelerator_on_reset_dataloader_hooks(tmpdir):
    """Verify that an Accelerator's ``on_reset_*_dataloader`` hooks are invoked.

    A ``CPUAccelerator`` subclass counts every call to each data-loader reset
    hook and asserts, from inside the hook, which trainer stage is active.
    The test runs two sessions:

    1. ``fit`` -> ``validate`` -> ``test`` -> ``predict``; every hook must fire once.
    2. A fresh accelerator with ``validate`` -> ``test`` -> ``predict`` only; the
       val/test/predict hooks must each fire once.
    """

    class CustomAccelerator(CPUAccelerator):
        # Class-level zeros; ``self.x += 1`` inside a hook creates a
        # per-instance counter that shadows these defaults.
        train_count: int = 0
        val_count: int = 0
        test_count: int = 0
        predict_count: int = 0

        def on_reset_train_dataloader(self, dataloader):
            self.train_count += 1
            active = self.lightning_module.trainer
            assert active.training
            return super().on_reset_train_dataloader(dataloader)

        def on_reset_val_dataloader(self, dataloader):
            self.val_count += 1
            active = self.lightning_module.trainer
            # The val loader is reset both while fitting and while validating.
            assert active.training or active.validating
            return super().on_reset_val_dataloader(dataloader)

        def on_reset_test_dataloader(self, dataloader):
            self.test_count += 1
            active = self.lightning_module.trainer
            assert active.testing
            return super().on_reset_test_dataloader(dataloader)

        def on_reset_predict_dataloader(self, dataloader):
            self.predict_count += 1
            active = self.lightning_module.trainer
            assert active.predicting
            return super().on_reset_predict_dataloader(dataloader)

    def build_session():
        # A fresh counting accelerator wired into a one-batch CPU trainer.
        accel = CustomAccelerator(PrecisionPlugin(), SingleDevicePlugin(device=torch.device("cpu")))
        runner = Trainer(default_root_dir=tmpdir, fast_dev_run=True, accelerator=accel)
        return accel, runner

    model = BoringModel()

    # Session 1: the full lifecycle exercises all four reset hooks.
    accel, runner = build_session()
    runner.fit(model)
    runner.validate(model)
    runner.test(model)
    runner.predict(model, dataloaders=model.test_dataloader())
    # Each hook ran exactly once; the val count of 1 means the val loader was
    # reset only once during the entire session.
    full_counts = (accel.train_count, accel.val_count, accel.test_count, accel.predict_count)
    assert full_counts == (1, 1, 1, 1)

    # Session 2: evaluation-only entry points on a fresh accelerator.
    accel, runner = build_session()
    runner.validate(model)
    runner.test(model)
    runner.predict(model)
    eval_counts = (accel.val_count, accel.test_count, accel.predict_count)
    assert eval_counts == (1, 1, 1)
def test_plugin_on_reset_dataloader_hooks(tmpdir):
    """Verify that a training-type Plugin's ``on_reset_*_dataloader`` hooks are invoked.

    A ``SingleDevicePlugin`` subclass counts every call to each data-loader
    reset hook and asserts, from inside the hook, which trainer stage is
    active. The test runs two sessions:

    1. ``fit`` -> ``validate`` -> ``test`` -> ``predict``; every hook must fire once.
    2. A fresh plugin with ``validate`` -> ``test`` -> ``predict`` only; the
       val/test/predict hooks must each fire once.
    """

    class CustomPlugin(SingleDevicePlugin):
        # Class-level zeros; ``self.x += 1`` inside a hook creates a
        # per-instance counter that shadows these defaults.
        train_count: int = 0
        val_count: int = 0
        test_count: int = 0
        predict_count: int = 0

        def on_reset_train_dataloader(self, dataloader):
            self.train_count += 1
            active = self.lightning_module.trainer
            assert active.training
            return super().on_reset_train_dataloader(dataloader)

        def on_reset_val_dataloader(self, dataloader):
            self.val_count += 1
            active = self.lightning_module.trainer
            # The val loader is reset both while fitting and while validating.
            assert active.training or active.validating
            return super().on_reset_val_dataloader(dataloader)

        def on_reset_test_dataloader(self, dataloader):
            self.test_count += 1
            active = self.lightning_module.trainer
            assert active.testing
            return super().on_reset_test_dataloader(dataloader)

        def on_reset_predict_dataloader(self, dataloader):
            self.predict_count += 1
            active = self.lightning_module.trainer
            assert active.predicting
            return super().on_reset_predict_dataloader(dataloader)

    def new_plugin():
        # A fresh counting plugin pinned to the CPU device.
        return CustomPlugin(device=torch.device('cpu'))

    def new_trainer(hooked):
        # A one-batch trainer that routes through the given plugin.
        return Trainer(default_root_dir=tmpdir, fast_dev_run=True, plugins=hooked)

    # Session 1: the full lifecycle exercises all four reset hooks.
    hooked = new_plugin()
    model = BoringModel()
    runner = new_trainer(hooked)
    runner.fit(model)
    runner.validate(model)
    runner.test(model)
    runner.predict(model, dataloaders=model.test_dataloader())
    # Each hook ran exactly once; the val count of 1 means the val loader was
    # reset only once during the entire session.
    full_counts = (hooked.train_count, hooked.val_count, hooked.test_count, hooked.predict_count)
    assert full_counts == (1, 1, 1, 1)

    # Session 2: evaluation-only entry points on a fresh plugin.
    hooked = new_plugin()
    runner = new_trainer(hooked)
    runner.validate(model)
    runner.test(model)
    runner.predict(model)
    eval_counts = (hooked.val_count, hooked.test_count, hooked.predict_count)
    assert eval_counts == (1, 1, 1)