예제 #1
0
def test_accelerator_on_reset_dataloader_hooks(tmpdir):
    """
    Check that each ``on_reset_*_dataloader`` accelerator hook gets invoked.

    A counting ``CPUAccelerator`` subclass is driven through fit, validate,
    test and predict, after which every hook counter must equal one. A second,
    fresh accelerator then repeats the check without a preceding ``fit``.
    """

    class CountingAccelerator(CPUAccelerator):
        # Class-level zeros; ``+=`` inside a hook creates the per-instance counter.
        predict_count: int = 0
        test_count: int = 0
        val_count: int = 0
        train_count: int = 0

        def on_reset_train_dataloader(self, dataloader):
            self.train_count += 1
            assert self.lightning_module.trainer.training
            return super().on_reset_train_dataloader(dataloader)

        def on_reset_val_dataloader(self, dataloader):
            self.val_count += 1
            active_trainer = self.lightning_module.trainer
            # The validation loader may be reset while fitting or validating.
            assert active_trainer.training or active_trainer.validating
            return super().on_reset_val_dataloader(dataloader)

        def on_reset_test_dataloader(self, dataloader):
            self.test_count += 1
            assert self.lightning_module.trainer.testing
            return super().on_reset_test_dataloader(dataloader)

        def on_reset_predict_dataloader(self, dataloader):
            self.predict_count += 1
            assert self.lightning_module.trainer.predicting
            return super().on_reset_predict_dataloader(dataloader)

    def fresh_setup():
        """Return a new counting accelerator and a fast-dev-run trainer using it."""
        counting = CountingAccelerator(
            PrecisionPlugin(),
            SingleDevicePlugin(device=torch.device("cpu")),
        )
        runner = Trainer(
            default_root_dir=tmpdir,
            fast_dev_run=True,
            accelerator=counting,
        )
        return counting, runner

    model = BoringModel()

    # Full session: fit followed by the three evaluation entry points.
    accelerator, trainer = fresh_setup()
    trainer.fit(model)
    trainer.validate(model)
    trainer.test(model)
    trainer.predict(model, dataloaders=model.test_dataloader())
    # every loader hook must have fired
    assert accelerator.train_count == 1
    assert accelerator.val_count == 1  # only called once during the entire session
    assert accelerator.test_count == 1
    assert accelerator.predict_count == 1

    # Evaluation-only session on a brand-new accelerator (no fit).
    accelerator, trainer = fresh_setup()
    trainer.validate(model)
    trainer.test(model)
    trainer.predict(model)
    # val/test/predict loader hooks must have fired
    assert accelerator.val_count == 1
    assert accelerator.test_count == 1
    assert accelerator.predict_count == 1
예제 #2
0
def test_restore_checkpoint_after_pre_dispatch_default():
    """Check that ``restore_checkpoint_after_pre_dispatch`` is falsy by default.

    The flag is read both through the accelerator's ``training_type_plugin``
    attribute and directly from the plugin passed to the accelerator.
    """
    cpu_plugin = SingleDevicePlugin(torch.device("cpu"))
    cpu_accelerator = CPUAccelerator(
        training_type_plugin=cpu_plugin,
        precision_plugin=PrecisionPlugin(),
    )
    for candidate in (cpu_accelerator.training_type_plugin, cpu_plugin):
        assert not candidate.restore_checkpoint_after_pre_dispatch
예제 #3
0
    def select_training_type_plugin(self) -> TrainingTypePlugin:
        """Build the training-type plugin that matches this object's flags.

        Branches are checked in priority order: DDP2, DeepSpeed (only when DDP
        is also enabled), the DDP family, DP, Horovod, TPU, and finally a
        single-device plugin on ``cuda:<ordinal>`` or ``cpu``.

        Returns:
            The newly constructed plugin instance.
        """
        if self.use_ddp2:
            return DDP2Plugin(
                parallel_devices=self.parallel_devices,
                cluster_environment=self.cluster_environment,
            )

        if self.use_ddp and self.use_deepspeed:
            return DeepSpeedPlugin(
                num_nodes=self.num_nodes,
                cluster_environment=self.select_cluster_environment(),
                parallel_devices=self.parallel_devices,
            )

        if self.use_ddp:
            slurm_launch = self.use_ddp and self.is_slurm_managing_tasks
            elastic_launch = self.use_ddp and self.is_using_torchelastic
            spawn_mode = self._distrib_type == DistributedType.DDP_SPAWN
            cpu_spawn = self.use_ddp and self.on_cpu
            cpu_elastic = cpu_spawn and self.is_using_torchelastic
            cpu_slurm = cpu_spawn and self.is_slurm_managing_tasks
            sharded = self._distrib_type == DistributedType.DDP_SHARDED
            sharded_spawn = self._distrib_type == DistributedType.DDP_SHARDED_SPAWN

            # TODO: decouple from TE
            # ddp script mode uses the same flags as TE, so any non-empty
            # PL_IN_DDP_SUBPROCESS value clears the torchelastic flag.
            if os.environ.get("PL_IN_DDP_SUBPROCESS", False):
                elastic_launch = False

            if self.on_tpu:
                chosen_cls = TPUSpawnPlugin
            elif sharded:
                chosen_cls = DDPShardedPlugin
            elif sharded_spawn:
                chosen_cls = DDPSpawnShardedPlugin
            elif (spawn_mode or cpu_spawn) and not (
                cpu_slurm or slurm_launch or cpu_elastic or elastic_launch
            ):
                # Spawn only when neither SLURM nor torchelastic is detected.
                chosen_cls = DDPSpawnPlugin
            else:
                chosen_cls = DDPPlugin

            return chosen_cls(
                parallel_devices=self.parallel_devices,
                num_nodes=self.num_nodes,
                cluster_environment=self.cluster_environment,
                sync_batchnorm=self.sync_batchnorm,
            )

        if self.use_dp:
            return DataParallelPlugin(parallel_devices=self.parallel_devices)

        if self.use_horovod:
            return HorovodPlugin(parallel_devices=self.parallel_devices)

        if self.on_tpu:
            # A list of cores selects the single-TPU plugin; otherwise
            # ``tpu_cores`` is used as a count for ``range``.
            if isinstance(self.tpu_cores, list):
                return SingleTPUPlugin(self.tpu_id)
            return TPUSpawnPlugin(parallel_devices=list(range(self.tpu_cores)))

        root_ordinal = device_parser.determine_root_gpu_device(
            self.parallel_device_ids
        )
        device_name = f"cuda:{root_ordinal}" if self.on_gpu else "cpu"
        return SingleDevicePlugin(device=torch.device(device_name))
예제 #4
0
def test_unsupported_precision_plugins():
    """Check that setting up a CPU accelerator with mixed precision raises.

    ``setup`` must raise ``MisconfigurationException`` whose message matches
    "AMP + CPU is not supported".
    """
    mock_trainer = Mock()
    cpu_plugin = SingleDevicePlugin(torch.device("cpu"))
    cpu_accelerator = CPUAccelerator(
        training_type_plugin=cpu_plugin,
        precision_plugin=MixedPrecisionPlugin(),
    )
    expected_error = pytest.raises(
        MisconfigurationException,
        match=r"AMP \+ CPU is not supported",
    )
    with expected_error:
        cpu_accelerator.setup(trainer=mock_trainer)
def test_checkpoint_plugin_called(tmpdir):
    """Check that a custom checkpoint IO plugin is used for saving and loading.

    The plugin is wrapped in a ``MagicMock`` spy and supplied in two ways:
    first through the strategy's ``checkpoint_io`` argument, then through the
    trainer's ``plugins`` list. Both runs must produce the same save/remove
    call counts and load ``last.ckpt`` when testing.
    """
    io_spy = MagicMock(wraps=CustomCheckpointIO(), spec=CustomCheckpointIO)

    def fit_then_test(strategy_kwargs, trainer_extras):
        """Fit for two epochs, check save/remove counts, then test from the last checkpoint."""
        ckpt_callback = ModelCheckpoint(dirpath=tmpdir, save_last=True)
        module = BoringModel()
        cpu = torch.device("cpu")
        runner = Trainer(
            default_root_dir=tmpdir,
            strategy=SingleDevicePlugin(cpu, **strategy_kwargs),
            callbacks=ckpt_callback,
            max_epochs=2,
            **trainer_extras,
        )
        runner.fit(module)

        assert io_spy.save_checkpoint.call_count == 5
        assert io_spy.remove_checkpoint.call_count == 1

        runner.test(module, ckpt_path=ckpt_callback.last_model_path)

    # Run 1: the IO plugin is attached directly to the strategy.
    fit_then_test({"checkpoint_io": io_spy}, {})
    io_spy.load_checkpoint.assert_called_with(tmpdir / "last.ckpt")

    # Run 2: start from clean call records and pass the IO plugin via ``plugins``.
    io_spy.reset_mock()
    fit_then_test({}, {"plugins": [io_spy]})
    io_spy.load_checkpoint.assert_called_once()
    io_spy.load_checkpoint.assert_called_with(tmpdir / "last.ckpt")
예제 #6
0
    def select_training_type_plugin(self):
        """Build the training-type plugin that matches this object's flags.

        Branches are checked in priority order: DDP2, the DDP family, DP,
        Horovod, single TPU, and finally a single-device plugin on
        ``cuda:<root_gpu>`` or ``cpu``.

        Within the DDP family the plugin class is chosen as follows:

        * ``TPUSpawnPlugin`` when running on TPU,
        * ``DDPPlugin`` when SLURM or torchelastic is detected,
        * ``DDPSpawnPlugin`` for the spawn / CPU variants,
        * ``DDPPlugin`` otherwise.

        Returns:
            The newly constructed plugin instance.
        """
        cluster_environment = self.select_cluster_environment()
        if self.use_ddp2:
            plugin = DDP2Plugin(parallel_devices=self.parallel_devices,
                                cluster_environment=cluster_environment)
        elif self.use_ddp:
            use_slurm_ddp = self.use_ddp and self.is_slurm_managing_tasks
            use_torchelastic_ddp = self.use_ddp and self.is_using_torchelastic
            use_ddp_spawn = self._distrib_type == DistributedType.DDP_SPAWN
            use_ddp_cpu_spawn = self.use_ddp and self.on_cpu
            use_ddp_cpu_torch_elastic = use_ddp_cpu_spawn and self.is_using_torchelastic
            use_ddp_cpu_slurm = use_ddp_cpu_spawn and self.is_slurm_managing_tasks

            # ddp script mode uses the same flags as TE, so any non-empty
            # PL_IN_DDP_SUBPROCESS value clears the torchelastic flag.
            # TODO: decouple from TE
            if os.environ.get("PL_IN_DDP_SUBPROCESS", False):
                use_torchelastic_ddp = False

            # FIXME: sharded DDP selection (DDPShardedPlugin / DDPSpawnShardedPlugin
            # for the "ddp_sharded" / "ddp_sharded_spawn" backends) is disabled.

            # The TPU check must be part of this chain: as a separate ``if``
            # ahead of it, its choice was always overwritten by the branches
            # below and TPUSpawnPlugin could never be selected.
            if self.on_tpu:
                ddp_plugin_cls = TPUSpawnPlugin
            elif use_ddp_cpu_slurm or use_slurm_ddp or use_ddp_cpu_torch_elastic or use_torchelastic_ddp:
                ddp_plugin_cls = DDPPlugin
            elif use_ddp_spawn or use_ddp_cpu_spawn:
                ddp_plugin_cls = DDPSpawnPlugin
            else:
                ddp_plugin_cls = DDPPlugin

            plugin = ddp_plugin_cls(
                parallel_devices=self.parallel_devices,
                num_nodes=self.num_nodes,
                cluster_environment=cluster_environment,
                sync_batchnorm=self.sync_batchnorm,
            )
        elif self.use_dp:
            plugin = DataParallelPlugin(parallel_devices=self.parallel_devices)
        elif self.use_horovod:
            plugin = HorovodPlugin(parallel_devices=self.parallel_devices)
        elif self.on_tpu:
            plugin = SingleTPUPlugin(self.tpu_id)
        else:
            plugin = SingleDevicePlugin(device=torch.device(
                f"cuda:{self.root_gpu}" if self.on_gpu else "cpu"))
        return plugin
예제 #7
0
    def __init__(
        self,
        precision_plugin: "PrecisionPlugin | None" = None,
        training_type_plugin: "TrainingTypePlugin | None" = None,
        enable_bf16=False,
    ) -> None:
        """
        Create the accelerator and record ``ipex.DEVICE`` as its device.

        Args:
            precision_plugin: the plugin to handle precision-specific parts.
                When omitted (``None``), a new ``PrecisionPlugin()`` is created.
            training_type_plugin: the plugin to handle different training routines.
                When omitted (``None``), a new ``SingleDevicePlugin`` on
                ``torch.device(ipex.DEVICE)`` is created.
            enable_bf16: when truthy, call ``ipex.enable_auto_mixed_precision``
                with ``torch.bfloat16`` as the mixed dtype.
        """
        if enable_bf16:
            # Automatically mix precision
            ipex.enable_auto_mixed_precision(mixed_dtype=torch.bfloat16)

        # Build the defaults per call. Plugin instances placed in the signature
        # are created once at definition time and would be shared by every
        # accelerator constructed without explicit plugins.
        if precision_plugin is None:
            precision_plugin = PrecisionPlugin()
        if training_type_plugin is None:
            training_type_plugin = SingleDevicePlugin(torch.device(ipex.DEVICE))

        self.device = ipex.DEVICE

        super().__init__(precision_plugin=precision_plugin,
                         training_type_plugin=training_type_plugin)
예제 #8
0
    def select_training_type_plugin(self) -> TrainingTypePlugin:
        """Pick or build the training-type plugin for this object's flags.

        If ``distributed_backend`` is an ``Accelerator`` that already carries a
        training-type plugin, that plugin is returned as-is. Otherwise the
        branches are checked in priority order: DDP2, DeepSpeed (only when DDP
        is also enabled), the DDP family, DP, Horovod, single TPU (``tpu_cores``
        given as a list), IPU, and finally a single-device plugin on
        ``cuda:<ordinal>`` or ``cpu``.

        Returns:
            The selected plugin instance.
        """
        backend = self.distributed_backend
        if isinstance(backend, Accelerator) and backend.training_type_plugin is not None:
            return backend.training_type_plugin

        if self.use_ddp2:
            return DDP2Plugin(
                parallel_devices=self.parallel_devices,
                cluster_environment=self.cluster_environment,
            )

        if self.use_ddp and self.use_deepspeed:
            return DeepSpeedPlugin(
                cluster_environment=self.select_cluster_environment(),
                parallel_devices=self.parallel_devices,
            )

        if self.use_ddp:
            on_slurm = self.use_ddp and self.is_slurm_managing_tasks
            on_torchelastic = self.use_ddp and TorchElasticEnvironment.is_using_torchelastic()
            on_kubeflow = self.use_ddp and KubeflowEnvironment.is_using_kubeflow()
            spawn_mode = self._distrib_type == DistributedType.DDP_SPAWN
            cpu_spawn = self.use_ddp and self.on_cpu
            tpu_spawn = self.on_tpu and self._distrib_type == DistributedType.TPU_SPAWN
            cpu_torchelastic = cpu_spawn and TorchElasticEnvironment.is_using_torchelastic()
            cpu_kubeflow = cpu_spawn and KubeflowEnvironment.is_using_kubeflow()
            cpu_slurm = cpu_spawn and self.is_slurm_managing_tasks
            sharded = self._distrib_type == DistributedType.DDP_SHARDED
            sharded_spawn = self._distrib_type == DistributedType.DDP_SHARDED_SPAWN
            fully_sharded = self._distrib_type == DistributedType.DDP_FULLY_SHARDED

            # TODO: decouple from TE
            # ddp script mode uses the same flags as TE, so any non-empty
            # PL_IN_DDP_SUBPROCESS value clears the torchelastic flag.
            if os.environ.get("PL_IN_DDP_SUBPROCESS", False):
                on_torchelastic = False

            if tpu_spawn:
                chosen_cls = TPUSpawnPlugin
            elif sharded:
                chosen_cls = DDPShardedPlugin
            elif sharded_spawn:
                chosen_cls = DDPSpawnShardedPlugin
            elif (
                cpu_slurm
                or on_slurm
                or cpu_torchelastic
                or on_torchelastic
                or on_kubeflow
                or cpu_kubeflow
            ):
                # SLURM, torchelastic or Kubeflow detected: plain DDP.
                chosen_cls = DDPPlugin
            elif spawn_mode or cpu_spawn:
                chosen_cls = DDPSpawnPlugin
            elif fully_sharded:
                chosen_cls = DDPFullyShardedPlugin
            else:
                chosen_cls = DDPPlugin

            return chosen_cls(
                parallel_devices=self.parallel_devices,
                cluster_environment=self.cluster_environment,
            )

        if self.use_dp:
            return DataParallelPlugin(parallel_devices=self.parallel_devices)

        if self.use_horovod:
            return HorovodPlugin(parallel_devices=self.parallel_devices)

        if self.on_tpu and isinstance(self.tpu_cores, list):
            return SingleTPUPlugin(self.tpu_id)

        if self.on_ipu:
            return IPUPlugin(parallel_devices=self.parallel_devices)

        root_ordinal = device_parser.determine_root_gpu_device(self.parallel_device_ids)
        device_name = f"cuda:{root_ordinal}" if self.on_gpu else "cpu"
        return SingleDevicePlugin(device=torch.device(device_name))