def test_accelerator_on_reset_dataloader_hooks(tmpdir):
    """Ensure data-loader hooks are called using an Accelerator."""

    class CustomAccelerator(CPUAccelerator):
        train_count: int = 0
        val_count: int = 0
        test_count: int = 0
        predict_count: int = 0

        def on_reset_train_dataloader(self, dataloader):
            self.train_count += 1
            assert self.lightning_module.trainer.training
            return super().on_reset_train_dataloader(dataloader)

        def on_reset_val_dataloader(self, dataloader):
            self.val_count += 1
            assert self.lightning_module.trainer.training or self.lightning_module.trainer.validating
            return super().on_reset_val_dataloader(dataloader)

        def on_reset_test_dataloader(self, dataloader):
            self.test_count += 1
            assert self.lightning_module.trainer.testing
            return super().on_reset_test_dataloader(dataloader)

        def on_reset_predict_dataloader(self, dataloader):
            self.predict_count += 1
            assert self.lightning_module.trainer.predicting
            return super().on_reset_predict_dataloader(dataloader)

    model = BoringModel()
    accelerator = CustomAccelerator(PrecisionPlugin(), SingleDevicePlugin(device=torch.device("cpu")))
    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, accelerator=accelerator)
    trainer.fit(model)
    trainer.validate(model)
    trainer.test(model)
    trainer.predict(model, dataloaders=model.test_dataloader())

    # assert that all loader hooks were called
    assert accelerator.train_count == 1
    assert accelerator.val_count == 1  # only called once during the entire session
    assert accelerator.test_count == 1
    assert accelerator.predict_count == 1

    accelerator = CustomAccelerator(PrecisionPlugin(), SingleDevicePlugin(device=torch.device("cpu")))
    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, accelerator=accelerator)
    trainer.validate(model)
    trainer.test(model)
    trainer.predict(model)

    # assert val/test/predict loader hooks were called
    assert accelerator.val_count == 1
    assert accelerator.test_count == 1
    assert accelerator.predict_count == 1
def test_restore_checkpoint_after_pre_dispatch_default():
    """Assert default for restore_checkpoint_after_pre_dispatch is False."""
    plugin = SingleDevicePlugin(torch.device("cpu"))
    accelerator = CPUAccelerator(training_type_plugin=plugin, precision_plugin=PrecisionPlugin())
    assert not accelerator.training_type_plugin.restore_checkpoint_after_pre_dispatch
    assert not plugin.restore_checkpoint_after_pre_dispatch
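# A minimal sketch, not part of the test above, of how a plugin could opt in to the behaviour
# whose default the test checks: override the restore_checkpoint_after_pre_dispatch property on
# a TrainingTypePlugin subclass. The class name below is illustrative only.
class DelayedRestoreSingleDevicePlugin(SingleDevicePlugin):
    @property
    def restore_checkpoint_after_pre_dispatch(self) -> bool:
        # Ask the trainer to restore checkpoint weights only after pre-dispatch has run.
        return True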
def select_training_type_plugin(self) -> TrainingTypePlugin:
    if self.use_ddp2:
        plugin = DDP2Plugin(parallel_devices=self.parallel_devices, cluster_environment=self.cluster_environment)
    elif self.use_ddp and self.use_deepspeed:
        plugin = DeepSpeedPlugin(
            num_nodes=self.num_nodes,
            cluster_environment=self.select_cluster_environment(),
            parallel_devices=self.parallel_devices,
        )
    elif self.use_ddp:
        use_slurm_ddp = self.use_ddp and self.is_slurm_managing_tasks
        use_torchelastic_ddp = self.use_ddp and self.is_using_torchelastic
        use_ddp_spawn = self._distrib_type == DistributedType.DDP_SPAWN
        use_ddp_cpu_spawn = self.use_ddp and self.on_cpu
        use_ddp_cpu_torch_elastic = use_ddp_cpu_spawn and self.is_using_torchelastic
        use_ddp_cpu_slurm = use_ddp_cpu_spawn and self.is_slurm_managing_tasks
        use_ddp_sharded = self._distrib_type == DistributedType.DDP_SHARDED
        use_ddp_sharded_spawn = self._distrib_type == DistributedType.DDP_SHARDED_SPAWN

        # TODO: decouple from TE
        # ddp script mode uses the same flags as TE
        if os.environ.get("PL_IN_DDP_SUBPROCESS", False):
            use_torchelastic_ddp = False

        if self.on_tpu:
            ddp_plugin_cls = TPUSpawnPlugin
        elif use_ddp_sharded:
            ddp_plugin_cls = DDPShardedPlugin
        elif use_ddp_sharded_spawn:
            ddp_plugin_cls = DDPSpawnShardedPlugin
        elif use_ddp_cpu_slurm or use_slurm_ddp or use_ddp_cpu_torch_elastic or use_torchelastic_ddp:
            ddp_plugin_cls = DDPPlugin
        elif use_ddp_spawn or use_ddp_cpu_spawn:
            ddp_plugin_cls = DDPSpawnPlugin
        else:
            ddp_plugin_cls = DDPPlugin

        plugin = ddp_plugin_cls(
            parallel_devices=self.parallel_devices,
            num_nodes=self.num_nodes,
            cluster_environment=self.cluster_environment,
            sync_batchnorm=self.sync_batchnorm,
        )
    elif self.use_dp:
        plugin = DataParallelPlugin(parallel_devices=self.parallel_devices)
    elif self.use_horovod:
        plugin = HorovodPlugin(parallel_devices=self.parallel_devices)
    elif self.on_tpu:
        if isinstance(self.tpu_cores, list):
            plugin = SingleTPUPlugin(self.tpu_id)
        else:
            plugin = TPUSpawnPlugin(parallel_devices=list(range(self.tpu_cores)))
    else:
        single_gpu_ordinal = device_parser.determine_root_gpu_device(self.parallel_device_ids)
        plugin = SingleDevicePlugin(device=torch.device(f"cuda:{single_gpu_ordinal}" if self.on_gpu else "cpu"))
    return plugin
def test_unsupported_precision_plugins():
    """Test error messages are raised for unsupported precision plugins with CPU."""
    trainer = Mock()
    accelerator = CPUAccelerator(
        training_type_plugin=SingleDevicePlugin(torch.device("cpu")), precision_plugin=MixedPrecisionPlugin()
    )
    with pytest.raises(MisconfigurationException, match=r"AMP \+ CPU is not supported"):
        accelerator.setup(trainer=trainer)
def test_checkpoint_plugin_called(tmpdir):
    """Ensure that the custom checkpoint IO plugin and torch checkpoint IO plugin are called when saving/loading."""
    checkpoint_plugin = CustomCheckpointIO()
    checkpoint_plugin = MagicMock(wraps=checkpoint_plugin, spec=CustomCheckpointIO)

    ck = ModelCheckpoint(dirpath=tmpdir, save_last=True)

    model = BoringModel()
    device = torch.device("cpu")
    trainer = Trainer(
        default_root_dir=tmpdir,
        strategy=SingleDevicePlugin(device, checkpoint_io=checkpoint_plugin),
        callbacks=ck,
        max_epochs=2,
    )
    trainer.fit(model)

    assert checkpoint_plugin.save_checkpoint.call_count == 5
    assert checkpoint_plugin.remove_checkpoint.call_count == 1

    trainer.test(model, ckpt_path=ck.last_model_path)
    checkpoint_plugin.load_checkpoint.assert_called_with(tmpdir / "last.ckpt")

    checkpoint_plugin.reset_mock()
    ck = ModelCheckpoint(dirpath=tmpdir, save_last=True)

    model = BoringModel()
    device = torch.device("cpu")
    trainer = Trainer(
        default_root_dir=tmpdir,
        strategy=SingleDevicePlugin(device),
        plugins=[checkpoint_plugin],
        callbacks=ck,
        max_epochs=2,
    )
    trainer.fit(model)

    assert checkpoint_plugin.save_checkpoint.call_count == 5
    assert checkpoint_plugin.remove_checkpoint.call_count == 1

    trainer.test(model, ckpt_path=ck.last_model_path)
    checkpoint_plugin.load_checkpoint.assert_called_once()
    checkpoint_plugin.load_checkpoint.assert_called_with(tmpdir / "last.ckpt")
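# The test above uses a CustomCheckpointIO class that is not defined in this excerpt. Below is a
# minimal sketch of what such a plugin could look like, assuming the CheckpointIO base class
# exported from pytorch_lightning.plugins; exact method signatures vary between Lightning
# versions, so treat this as an illustration rather than the test's actual helper.
import os
from typing import Any, Dict, Optional

import torch
from pytorch_lightning.plugins import CheckpointIO


class CustomCheckpointIO(CheckpointIO):
    def save_checkpoint(self, checkpoint: Dict[str, Any], path, storage_options: Optional[Any] = None) -> None:
        # Delegate to torch.save, mirroring the default torch-based checkpoint IO.
        torch.save(checkpoint, path)

    def load_checkpoint(self, path, storage_options: Optional[Any] = None) -> Dict[str, Any]:
        return torch.load(path)

    def remove_checkpoint(self, path) -> None:
        os.remove(path)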
def select_training_type_plugin(self):
    cluster_environment = self.select_cluster_environment()
    if self.use_ddp2:
        plugin = DDP2Plugin(parallel_devices=self.parallel_devices, cluster_environment=cluster_environment)
    elif self.use_ddp:
        use_slurm_ddp = self.use_ddp and self.is_slurm_managing_tasks
        use_torchelastic_ddp = self.use_ddp and self.is_using_torchelastic
        use_ddp_spawn = self._distrib_type == DistributedType.DDP_SPAWN
        use_ddp_cpu_spawn = self.use_ddp and self.on_cpu
        use_ddp_cpu_torch_elastic = use_ddp_cpu_spawn and self.is_using_torchelastic
        use_ddp_cpu_slurm = use_ddp_cpu_spawn and self.is_slurm_managing_tasks
        # use_ddp_sharded = self.distributed_backend == "ddp_sharded"
        # use_ddp_sharded_spawn = self.distributed_backend == "ddp_sharded_spawn"

        if self.on_tpu:
            ddp_plugin_cls = TPUSpawnPlugin

        # ddp script mode uses the same flags as TE
        # TODO: decouple from TE
        if os.environ.get("PL_IN_DDP_SUBPROCESS", False):
            use_torchelastic_ddp = False

        # fixme
        # if use_ddp_sharded:
        #     ddp_plugin_cls = DDPShardedPlugin
        # elif use_ddp_sharded_spawn:
        #     ddp_plugin_cls = DDPSpawnShardedPlugin
        if use_ddp_cpu_slurm or use_slurm_ddp or use_ddp_cpu_torch_elastic or use_torchelastic_ddp:
            ddp_plugin_cls = DDPPlugin
        elif use_ddp_spawn or use_ddp_cpu_spawn:
            ddp_plugin_cls = DDPSpawnPlugin
        else:
            ddp_plugin_cls = DDPPlugin

        plugin = ddp_plugin_cls(
            parallel_devices=self.parallel_devices,
            num_nodes=self.num_nodes,
            cluster_environment=cluster_environment,
            sync_batchnorm=self.sync_batchnorm,
        )
    elif self.use_dp:
        plugin = DataParallelPlugin(parallel_devices=self.parallel_devices)
    elif self.use_horovod:
        plugin = HorovodPlugin(parallel_devices=self.parallel_devices)
    elif self.on_tpu:
        plugin = SingleTPUPlugin(self.tpu_id)
    else:
        plugin = SingleDevicePlugin(device=torch.device(f"cuda:{self.root_gpu}" if self.on_gpu else "cpu"))
    return plugin
def __init__(
    self,
    precision_plugin: PrecisionPlugin = PrecisionPlugin(),
    training_type_plugin: TrainingTypePlugin = SingleDevicePlugin(torch.device(ipex.DEVICE)),
    enable_bf16=False,
) -> None:
    """
    Args:
        precision_plugin: the plugin to handle precision-specific parts
        training_type_plugin: the plugin to handle different training routines
        enable_bf16: if True, enable IPEX automatic mixed precision with bfloat16
    """
    if enable_bf16:
        # Automatically mix precision
        ipex.enable_auto_mixed_precision(mixed_dtype=torch.bfloat16)

    self.device = ipex.DEVICE
    super().__init__(precision_plugin=precision_plugin, training_type_plugin=training_type_plugin)
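# Hypothetical usage of the accelerator whose constructor is shown above. The class name
# IPEXAccelerator is an assumption (the excerpt only shows __init__), and this presumes a
# Lightning version whose Trainer accepts an Accelerator instance via the `accelerator` flag,
# as in the tests earlier in this section.
accelerator = IPEXAccelerator(enable_bf16=True)
trainer = Trainer(accelerator=accelerator, max_epochs=1)
trainer.fit(BoringModel())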
def select_training_type_plugin(self) -> TrainingTypePlugin:
    if (
        isinstance(self.distributed_backend, Accelerator)
        and self.distributed_backend.training_type_plugin is not None
    ):
        plugin = self.distributed_backend.training_type_plugin
    elif self.use_ddp2:
        plugin = DDP2Plugin(
            parallel_devices=self.parallel_devices,
            cluster_environment=self.cluster_environment,
        )
    elif self.use_ddp and self.use_deepspeed:
        plugin = DeepSpeedPlugin(
            cluster_environment=self.select_cluster_environment(), parallel_devices=self.parallel_devices
        )
    elif self.use_ddp:
        use_slurm_ddp = self.use_ddp and self.is_slurm_managing_tasks
        use_torchelastic_ddp = self.use_ddp and TorchElasticEnvironment.is_using_torchelastic()
        use_kubeflow_ddp = self.use_ddp and KubeflowEnvironment.is_using_kubeflow()
        use_ddp_spawn = self._distrib_type == DistributedType.DDP_SPAWN
        use_ddp_cpu_spawn = self.use_ddp and self.on_cpu
        use_tpu_spawn = self.on_tpu and self._distrib_type == DistributedType.TPU_SPAWN
        use_ddp_cpu_torch_elastic = use_ddp_cpu_spawn and TorchElasticEnvironment.is_using_torchelastic()
        use_ddp_cpu_kubeflow = use_ddp_cpu_spawn and KubeflowEnvironment.is_using_kubeflow()
        use_ddp_cpu_slurm = use_ddp_cpu_spawn and self.is_slurm_managing_tasks
        use_ddp_sharded = self._distrib_type == DistributedType.DDP_SHARDED
        use_ddp_sharded_spawn = self._distrib_type == DistributedType.DDP_SHARDED_SPAWN
        use_ddp_fully_sharded = self._distrib_type == DistributedType.DDP_FULLY_SHARDED

        # TODO: decouple from TE
        # ddp script mode uses the same flags as TE
        if os.environ.get("PL_IN_DDP_SUBPROCESS", False):
            use_torchelastic_ddp = False

        if use_tpu_spawn:
            ddp_plugin_cls = TPUSpawnPlugin
        elif use_ddp_sharded:
            ddp_plugin_cls = DDPShardedPlugin
        elif use_ddp_sharded_spawn:
            ddp_plugin_cls = DDPSpawnShardedPlugin
        elif (
            use_ddp_cpu_slurm
            or use_slurm_ddp
            or use_ddp_cpu_torch_elastic
            or use_torchelastic_ddp
            or use_kubeflow_ddp
            or use_ddp_cpu_kubeflow
        ):
            ddp_plugin_cls = DDPPlugin
        elif use_ddp_spawn or use_ddp_cpu_spawn:
            ddp_plugin_cls = DDPSpawnPlugin
        elif use_ddp_fully_sharded:
            ddp_plugin_cls = DDPFullyShardedPlugin
        else:
            ddp_plugin_cls = DDPPlugin

        plugin = ddp_plugin_cls(
            parallel_devices=self.parallel_devices,
            cluster_environment=self.cluster_environment,
        )
    elif self.use_dp:
        plugin = DataParallelPlugin(parallel_devices=self.parallel_devices)
    elif self.use_horovod:
        plugin = HorovodPlugin(parallel_devices=self.parallel_devices)
    elif self.on_tpu and isinstance(self.tpu_cores, list):
        plugin = SingleTPUPlugin(self.tpu_id)
    elif self.on_ipu:
        plugin = IPUPlugin(parallel_devices=self.parallel_devices)
    else:
        single_gpu_ordinal = device_parser.determine_root_gpu_device(self.parallel_device_ids)
        plugin = SingleDevicePlugin(device=torch.device(f"cuda:{single_gpu_ordinal}" if self.on_gpu else "cpu"))
    return plugin