示例#1
0
def test_accelerator_on_reset_dataloader_hooks(tmpdir):
    """Ensure the ``on_reset_*_dataloader`` hooks are called on a custom Accelerator.

    Two sessions are exercised:

    1. ``fit`` -> ``validate`` -> ``test`` -> ``predict`` on one trainer: every
       hook is expected to fire exactly once (the val hook fires only once in
       total across ``fit`` and ``validate``).
    2. ``validate`` -> ``test`` -> ``predict`` on a fresh trainer: the
       val/test/predict hooks fire once each and the train hook never fires.

    Each hook additionally asserts that the trainer is in the matching stage
    at the moment the hook runs.
    """

    class CustomAccelerator(CPUAccelerator):
        # Per-hook call counters. The class-level ints are only defaults:
        # ``+=`` inside a hook rebinds an instance attribute, so every
        # instance counts independently.
        train_count: int = 0
        val_count: int = 0
        test_count: int = 0
        predict_count: int = 0

        def on_reset_train_dataloader(self, dataloader):
            self.train_count += 1
            assert self.lightning_module.trainer.training
            return super().on_reset_train_dataloader(dataloader)

        def on_reset_val_dataloader(self, dataloader):
            self.val_count += 1
            # The val loader may be reset either during fit() or during validate().
            assert self.lightning_module.trainer.training or self.lightning_module.trainer.validating
            return super().on_reset_val_dataloader(dataloader)

        def on_reset_test_dataloader(self, dataloader):
            self.test_count += 1
            assert self.lightning_module.trainer.testing
            return super().on_reset_test_dataloader(dataloader)

        def on_reset_predict_dataloader(self, dataloader):
            self.predict_count += 1
            assert self.lightning_module.trainer.predicting
            return super().on_reset_predict_dataloader(dataloader)

    def make_accelerator():
        # Fresh counting accelerator bound to a single CPU device.
        return CustomAccelerator(PrecisionPlugin(), SingleDevicePlugin(device=torch.device("cpu")))

    def make_trainer(accelerator):
        # Fresh fast_dev_run trainer that uses the given accelerator.
        return Trainer(default_root_dir=tmpdir, fast_dev_run=True, accelerator=accelerator)

    model = BoringModel()

    # Session 1: full fit/validate/test/predict cycle on a single trainer.
    accelerator = make_accelerator()
    trainer = make_trainer(accelerator)
    trainer.fit(model)
    trainer.validate(model)
    trainer.test(model)
    trainer.predict(model, dataloaders=model.test_dataloader())
    # assert that all loader hooks were called
    assert accelerator.train_count == 1
    assert accelerator.val_count == 1  # fit() + validate() reset the val loader only once in total
    assert accelerator.test_count == 1
    assert accelerator.predict_count == 1

    # Session 2: evaluation-only stages on a fresh trainer (no fit()).
    accelerator = make_accelerator()
    trainer = make_trainer(accelerator)
    trainer.validate(model)
    trainer.test(model)
    trainer.predict(model)
    # The train hook must stay silent without fit(); val/test/predict hooks fire once each.
    assert accelerator.train_count == 0
    assert accelerator.val_count == 1
    assert accelerator.test_count == 1
    assert accelerator.predict_count == 1
示例#2
0
def test_plugin_on_reset_dataloader_hooks(tmpdir):
    """Ensure the ``on_reset_*_dataloader`` hooks are called on a custom Plugin.

    Two sessions are exercised:

    1. ``fit`` -> ``validate`` -> ``test`` -> ``predict`` on one trainer: every
       hook is expected to fire exactly once (the val hook fires only once in
       total across ``fit`` and ``validate``).
    2. ``validate`` -> ``test`` -> ``predict`` on a fresh trainer: the
       val/test/predict hooks fire once each and the train hook never fires.

    Each hook additionally asserts that the trainer is in the matching stage
    at the moment the hook runs.
    """

    class CustomPlugin(SingleDevicePlugin):
        # Per-hook call counters. The class-level ints are only defaults:
        # ``+=`` inside a hook rebinds an instance attribute, so every
        # instance counts independently.
        train_count: int = 0
        val_count: int = 0
        test_count: int = 0
        predict_count: int = 0

        def on_reset_train_dataloader(self, dataloader):
            self.train_count += 1
            assert self.lightning_module.trainer.training
            return super().on_reset_train_dataloader(dataloader)

        def on_reset_val_dataloader(self, dataloader):
            self.val_count += 1
            # The val loader may be reset either during fit() or during validate().
            assert self.lightning_module.trainer.training or self.lightning_module.trainer.validating
            return super().on_reset_val_dataloader(dataloader)

        def on_reset_test_dataloader(self, dataloader):
            self.test_count += 1
            assert self.lightning_module.trainer.testing
            return super().on_reset_test_dataloader(dataloader)

        def on_reset_predict_dataloader(self, dataloader):
            self.predict_count += 1
            assert self.lightning_module.trainer.predicting
            return super().on_reset_predict_dataloader(dataloader)

    def make_trainer(plugin):
        # Fresh fast_dev_run trainer that uses the given plugin.
        return Trainer(default_root_dir=tmpdir, fast_dev_run=True, plugins=plugin)

    # Session 1: full fit/validate/test/predict cycle on a single trainer.
    plugin = CustomPlugin(device=torch.device('cpu'))
    model = BoringModel()
    trainer = make_trainer(plugin)
    trainer.fit(model)
    trainer.validate(model)
    trainer.test(model)
    trainer.predict(model, dataloaders=model.test_dataloader())
    # assert that all loader hooks were called
    assert plugin.train_count == 1
    assert plugin.val_count == 1  # fit() + validate() reset the val loader only once in total
    assert plugin.test_count == 1
    assert plugin.predict_count == 1

    # Session 2: evaluation-only stages on a fresh trainer (no fit()).
    plugin = CustomPlugin(device=torch.device('cpu'))
    trainer = make_trainer(plugin)
    trainer.validate(model)
    trainer.test(model)
    trainer.predict(model)
    # The train hook must stay silent without fit(); val/test/predict hooks fire once each.
    assert plugin.train_count == 0
    assert plugin.val_count == 1
    assert plugin.test_count == 1
    assert plugin.predict_count == 1