Esempio n. 1
0
def test_accelerator_on_reset_dataloader_hooks(tmpdir):
    """Check that an Accelerator receives every ``on_reset_*_dataloader`` hook.

    A counting ``CPUAccelerator`` subclass is driven through two sessions:
    first fit/validate/test/predict, then (with a fresh accelerator and
    trainer) validate/test/predict only. Every hook that is exercised must
    have fired exactly once, and each hook also asserts the trainer stage it
    is invoked in.
    """

    class CustomAccelerator(CPUAccelerator):
        # Invocation counters. The class-level zeros are defaults; the first
        # ``+=`` in a hook creates the per-instance value.
        train_count: int = 0
        val_count: int = 0
        test_count: int = 0
        predict_count: int = 0

        def on_reset_train_dataloader(self, dataloader):
            self.train_count += 1
            active_trainer = self.lightning_module.trainer
            assert active_trainer.training
            return super().on_reset_train_dataloader(dataloader)

        def on_reset_val_dataloader(self, dataloader):
            self.val_count += 1
            active_trainer = self.lightning_module.trainer
            # The validation loader may be reset from within fit as well as
            # from a standalone validate call.
            assert active_trainer.training or active_trainer.validating
            return super().on_reset_val_dataloader(dataloader)

        def on_reset_test_dataloader(self, dataloader):
            self.test_count += 1
            active_trainer = self.lightning_module.trainer
            assert active_trainer.testing
            return super().on_reset_test_dataloader(dataloader)

        def on_reset_predict_dataloader(self, dataloader):
            self.predict_count += 1
            active_trainer = self.lightning_module.trainer
            assert active_trainer.predicting
            return super().on_reset_predict_dataloader(dataloader)

    def _new_session():
        # Build a fresh counting accelerator and a trainer that uses it.
        fresh_accelerator = CustomAccelerator(
            PrecisionPlugin(),
            SingleDevicePlugin(device=torch.device("cpu")),
        )
        fresh_trainer = Trainer(
            default_root_dir=tmpdir,
            fast_dev_run=True,
            accelerator=fresh_accelerator,
        )
        return fresh_accelerator, fresh_trainer

    model = BoringModel()

    # Session 1: every entry point, starting with fit.
    accelerator, trainer = _new_session()
    trainer.fit(model)
    trainer.validate(model)
    trainer.test(model)
    trainer.predict(model, dataloaders=model.test_dataloader())
    # assert that all loader hooks were called
    assert accelerator.train_count == 1
    assert accelerator.val_count == 1  # only called once during the entire session
    assert accelerator.test_count == 1
    assert accelerator.predict_count == 1

    # Session 2: no fit, so only the val/test/predict hooks are checked.
    accelerator, trainer = _new_session()
    trainer.validate(model)
    trainer.test(model)
    trainer.predict(model)
    # assert val/test/predict loader hooks were called
    assert accelerator.val_count == 1
    assert accelerator.test_count == 1
    assert accelerator.predict_count == 1
Esempio n. 2
0
def test_restore_checkpoint_after_pre_dispatch_default():
    """Verify ``restore_checkpoint_after_pre_dispatch`` is falsy by default.

    The flag is read both through the accelerator's ``training_type_plugin``
    attribute and directly on the plugin that was handed to the accelerator.
    """
    single_device = SingleDevicePlugin(torch.device("cpu"))
    cpu_accelerator = CPUAccelerator(
        training_type_plugin=single_device,
        precision_plugin=PrecisionPlugin(),
    )
    for flag_owner in (cpu_accelerator.training_type_plugin, single_device):
        assert not flag_owner.restore_checkpoint_after_pre_dispatch
Esempio n. 3
0
def test_get_nvidia_gpu_stats(tmpdir):
    """Check the keys returned by GPU ``get_device_stats`` on PyTorch < 1.8.0.

    Each expected field name must occur as a substring of at least one key of
    the returned stats.

    NOTE(review): no GPU-availability or PyTorch-version guard is visible in
    this block; presumably it is applied by decorators outside this view --
    confirm.
    """
    device = torch.device(f"cuda:{torch.cuda.current_device()}")
    data_parallel = DataParallelPlugin(parallel_devices=[device])
    gpu_accelerator = GPUAccelerator(
        training_type_plugin=data_parallel,
        precision_plugin=PrecisionPlugin(),
    )
    stat_names = gpu_accelerator.get_device_stats(device).keys()

    expected_fields = (
        "utilization.gpu",
        "memory.used",
        "memory.free",
        "utilization.memory",
    )
    # Substring match: the reported keys may carry extra text around the
    # field name.
    missing = [
        field for field in expected_fields
        if not any(field in name for name in stat_names)
    ]
    assert not missing
Esempio n. 4
0
def test_get_torch_gpu_stats(tmpdir):
    """Check the keys returned by GPU ``get_device_stats`` on PyTorch >= 1.8.0.

    Each expected field name must occur as a substring of at least one key of
    the returned stats.

    NOTE(review): no GPU-availability or PyTorch-version guard is visible in
    this block; presumably it is applied by decorators outside this view --
    confirm.
    """
    device = torch.device(f"cuda:{torch.cuda.current_device()}")
    data_parallel = DataParallelPlugin(parallel_devices=[device])
    gpu_accelerator = GPUAccelerator(
        training_type_plugin=data_parallel,
        precision_plugin=PrecisionPlugin(),
    )
    stat_names = gpu_accelerator.get_device_stats(device).keys()

    expected_fields = (
        "allocated_bytes.all.freed",
        "inactive_split.all.peak",
        "reserved_bytes.large_pool.peak",
    )
    # Substring match: the reported keys may carry extra text around the
    # field name.
    missing = [
        field for field in expected_fields
        if not any(field in name for name in stat_names)
    ]
    assert not missing
Esempio n. 5
0
def test_restore_checkpoint_after_pre_dispatch(tmpdir,
                                               restore_after_pre_dispatch):
    """Check when the checkpoint file is loaded relative to ``pre_dispatch``.

    A plugin subclass records whether ``pre_dispatch`` has run and reports
    ``restore_after_pre_dispatch`` as its
    ``restore_checkpoint_after_pre_dispatch`` value. Whenever the checkpoint
    file is loaded, "pre_dispatch already ran" must equal that flag. This is
    exercised for fit (resuming from a checkpoint) and then for test,
    validate and predict.
    """

    class TestPlugin(SingleDevicePlugin):
        # Flipped to True once ``pre_dispatch`` has completed.
        predispatched_called = False

        @property
        def restore_checkpoint_after_pre_dispatch(self) -> bool:
            return restore_after_pre_dispatch

        def pre_dispatch(self) -> None:
            super().pre_dispatch()
            self.predispatched_called = True

        def load_checkpoint_file(
                self, checkpoint_path: Union[str, Path]) -> Dict[str, Any]:
            # Loading must happen after pre_dispatch exactly when the flag
            # asks for it, and before it otherwise.
            assert self.predispatched_called == restore_after_pre_dispatch
            return super().load_checkpoint_file(checkpoint_path)

    model = BoringModel()

    # Produce a checkpoint to resume from, using a default trainer.
    seed_trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
    seed_trainer.fit(model)
    checkpoint_path = os.path.join(tmpdir, "model.pt")
    seed_trainer.save_checkpoint(checkpoint_path)

    plugin = TestPlugin(torch.device("cpu"), checkpoint_io=TorchCheckpointIO())
    accelerator = CPUAccelerator(
        training_type_plugin=plugin,
        precision_plugin=PrecisionPlugin(),
    )

    # The flag must read the same on the accelerator and on the plugin.
    assert accelerator.restore_checkpoint_after_pre_dispatch == restore_after_pre_dispatch
    assert plugin.restore_checkpoint_after_pre_dispatch == restore_after_pre_dispatch

    trainer = Trainer(
        default_root_dir=tmpdir,
        accelerator=accelerator,
        fast_dev_run=True,
        resume_from_checkpoint=checkpoint_path,
    )
    trainer.fit(model)

    for entry_point in (trainer.test, trainer.validate, trainer.predict):
        # Reset the marker so each entry point is checked independently.
        accelerator.training_type_plugin.predispatched_called = False
        entry_point(model, ckpt_path=checkpoint_path)
Esempio n. 6
0
def test_restore_checkpoint_after_pre_setup(tmpdir, restore_after_pre_setup):
    """Check when the checkpoint is loaded relative to the strategy's ``setup``.

    A strategy subclass records whether ``setup`` has run and reports
    ``restore_after_pre_setup`` as its ``restore_checkpoint_after_setup``
    value. Whenever ``load_checkpoint`` is called, "setup already ran" must
    equal that flag. This is exercised for fit and then for test, validate
    and predict, all loading the same checkpoint.
    """

    class TestPlugin(SingleDeviceStrategy):
        # Flipped to True once ``setup`` has completed.
        setup_called = False

        @property
        def restore_checkpoint_after_setup(self) -> bool:
            return restore_after_pre_setup

        def setup(self, trainer: "pl.Trainer") -> None:
            super().setup(trainer)
            self.setup_called = True

        def load_checkpoint(
                self, checkpoint_path: Union[str, Path]) -> Dict[str, Any]:
            # Loading must happen after setup exactly when the flag asks for
            # it, and before it otherwise.
            assert self.setup_called == restore_after_pre_setup
            return super().load_checkpoint(checkpoint_path)

    model = BoringModel()

    # Produce a checkpoint to load, using a default trainer.
    seed_trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
    seed_trainer.fit(model)
    checkpoint_path = os.path.join(tmpdir, "model.pt")
    seed_trainer.save_checkpoint(checkpoint_path)

    plugin = TestPlugin(
        accelerator=CPUAccelerator(),
        precision_plugin=PrecisionPlugin(),
        device=torch.device("cpu"),
        checkpoint_io=TorchCheckpointIO(),
    )
    assert plugin.restore_checkpoint_after_setup == restore_after_pre_setup

    trainer = Trainer(
        default_root_dir=tmpdir,
        strategy=plugin,
        fast_dev_run=True,
    )
    trainer.fit(model, ckpt_path=checkpoint_path)

    for entry_point in (trainer.test, trainer.validate, trainer.predict):
        # Reset the marker so each entry point is checked independently.
        plugin.setup_called = False
        entry_point(model, ckpt_path=checkpoint_path)
Esempio n. 7
0
def test_restore_checkpoint_after_pre_setup_default():
    """Verify ``restore_checkpoint_after_setup`` is falsy on a plain strategy."""
    strategy = SingleDeviceStrategy(
        accelerator=CPUAccelerator(),
        device=torch.device("cpu"),
        precision_plugin=PrecisionPlugin(),
    )
    restore_after_setup = strategy.restore_checkpoint_after_setup
    assert not restore_after_setup