def select_training_type_plugin(self) -> TrainingTypePlugin:
    """Construct the training-type plugin that matches the current flags.

    Branches are tried in this order and the first match wins: DDP2,
    DeepSpeed (only together with DDP), the DDP family, DP, Horovod, TPU,
    and finally a single CUDA or CPU device.

    Returns:
        The newly constructed plugin instance.
    """
    if self.use_ddp2:
        return DDP2Plugin(
            parallel_devices=self.parallel_devices,
            cluster_environment=self.cluster_environment,
        )

    if self.use_ddp and self.use_deepspeed:
        return DeepSpeedPlugin(
            num_nodes=self.num_nodes,
            cluster_environment=self.select_cluster_environment(),
            parallel_devices=self.parallel_devices,
        )

    if self.use_ddp:
        slurm_managed = self.use_ddp and self.is_slurm_managing_tasks
        torchelastic_launched = self.use_ddp and self.is_using_torchelastic
        spawn_requested = self._distrib_type == DistributedType.DDP_SPAWN
        ddp_on_cpu = self.use_ddp and self.on_cpu
        cpu_torchelastic = ddp_on_cpu and self.is_using_torchelastic
        cpu_slurm = ddp_on_cpu and self.is_slurm_managing_tasks
        sharded_requested = self._distrib_type == DistributedType.DDP_SHARDED
        sharded_spawn_requested = self._distrib_type == DistributedType.DDP_SHARDED_SPAWN

        # TODO: decouple from TE
        # ddp script mode uses the same flags as TE
        if os.environ.get("PL_IN_DDP_SUBPROCESS", False):
            torchelastic_launched = False

        # Any SLURM / TorchElastic flag selects plain DDPPlugin and takes
        # precedence over the spawn variants below.
        launched_externally = (
            cpu_slurm or slurm_managed or cpu_torchelastic or torchelastic_launched
        )

        if self.on_tpu:
            plugin_cls = TPUSpawnPlugin
        elif sharded_requested:
            plugin_cls = DDPShardedPlugin
        elif sharded_spawn_requested:
            plugin_cls = DDPSpawnShardedPlugin
        elif not launched_externally and (spawn_requested or ddp_on_cpu):
            plugin_cls = DDPSpawnPlugin
        else:
            plugin_cls = DDPPlugin

        return plugin_cls(
            parallel_devices=self.parallel_devices,
            num_nodes=self.num_nodes,
            cluster_environment=self.cluster_environment,
            sync_batchnorm=self.sync_batchnorm,
        )

    if self.use_dp:
        return DataParallelPlugin(parallel_devices=self.parallel_devices)

    if self.use_horovod:
        return HorovodPlugin(parallel_devices=self.parallel_devices)

    if self.on_tpu:
        # A list of cores selects the single-TPU plugin; anything else is
        # treated as a core count and spawned across range(tpu_cores).
        if isinstance(self.tpu_cores, list):
            return SingleTPUPlugin(self.tpu_id)
        return TPUSpawnPlugin(parallel_devices=list(range(self.tpu_cores)))

    single_gpu_ordinal = device_parser.determine_root_gpu_device(self.parallel_device_ids)
    device_name = f"cuda:{single_gpu_ordinal}" if self.on_gpu else "cpu"
    return SingleDevicePlugin(device=torch.device(device_name))
def select_training_type_plugin(self):
    """Construct the training-type plugin that matches the current flags.

    The cluster environment is resolved once up front via
    ``self.select_cluster_environment()`` and shared by the DDP2 and DDP
    branches. Branches are tried in this order and the first match wins:
    DDP2, the DDP family, DP, Horovod, single TPU, and finally a single
    CUDA or CPU device.

    Within the DDP family the plugin class is picked from one chain:
    TPU first, then plain ``DDPPlugin`` when any SLURM / TorchElastic flag
    is set, then ``DDPSpawnPlugin`` for spawn or CPU runs, otherwise
    ``DDPPlugin``.

    Returns:
        The newly constructed plugin instance.
    """
    cluster_environment = self.select_cluster_environment()

    if self.use_ddp2:
        plugin = DDP2Plugin(
            parallel_devices=self.parallel_devices,
            cluster_environment=cluster_environment,
        )
    elif self.use_ddp:
        use_slurm_ddp = self.use_ddp and self.is_slurm_managing_tasks
        use_torchelastic_ddp = self.use_ddp and self.is_using_torchelastic
        use_ddp_spawn = self._distrib_type == DistributedType.DDP_SPAWN
        use_ddp_cpu_spawn = self.use_ddp and self.on_cpu
        use_ddp_cpu_torch_elastic = use_ddp_cpu_spawn and self.is_using_torchelastic
        use_ddp_cpu_slurm = use_ddp_cpu_spawn and self.is_slurm_managing_tasks

        # ddp script mode uses the same flags as TE
        # TODO: decouple from TE
        if os.environ.get("PL_IN_DDP_SUBPROCESS", False):
            use_torchelastic_ddp = False

        # FIXME: selection of DDPShardedPlugin / DDPSpawnShardedPlugin
        # (distributed_backend "ddp_sharded" / "ddp_sharded_spawn") is not
        # wired in yet.

        # The TPU check must head this chain: as a separate ``if`` its
        # result would always be overwritten by the branches below.
        if self.on_tpu:
            ddp_plugin_cls = TPUSpawnPlugin
        elif use_ddp_cpu_slurm or use_slurm_ddp or use_ddp_cpu_torch_elastic or use_torchelastic_ddp:
            ddp_plugin_cls = DDPPlugin
        elif use_ddp_spawn or use_ddp_cpu_spawn:
            ddp_plugin_cls = DDPSpawnPlugin
        else:
            ddp_plugin_cls = DDPPlugin

        plugin = ddp_plugin_cls(
            parallel_devices=self.parallel_devices,
            num_nodes=self.num_nodes,
            cluster_environment=cluster_environment,
            sync_batchnorm=self.sync_batchnorm,
        )
    elif self.use_dp:
        plugin = DataParallelPlugin(parallel_devices=self.parallel_devices)
    elif self.use_horovod:
        plugin = HorovodPlugin(parallel_devices=self.parallel_devices)
    elif self.on_tpu:
        plugin = SingleTPUPlugin(self.tpu_id)
    else:
        device_name = f"cuda:{self.root_gpu}" if self.on_gpu else "cpu"
        plugin = SingleDevicePlugin(device=torch.device(device_name))

    return plugin
def select_training_type_plugin(self) -> TrainingTypePlugin:
    """Return the training-type plugin that matches the current flags.

    If ``self.distributed_backend`` is an ``Accelerator`` that already
    carries a training-type plugin, that plugin is returned unchanged.
    Otherwise a new plugin is built; branches are tried in this order and
    the first match wins: DDP2, DeepSpeed (only together with DDP), the
    DDP family, DP, Horovod, single TPU (``tpu_cores`` given as a list),
    IPU, and finally a single CUDA or CPU device.

    Returns:
        The reused or newly constructed plugin instance.
    """
    if (
        isinstance(self.distributed_backend, Accelerator)
        and self.distributed_backend.training_type_plugin is not None
    ):
        return self.distributed_backend.training_type_plugin

    if self.use_ddp2:
        return DDP2Plugin(
            parallel_devices=self.parallel_devices,
            cluster_environment=self.cluster_environment,
        )

    if self.use_ddp and self.use_deepspeed:
        return DeepSpeedPlugin(
            cluster_environment=self.select_cluster_environment(),
            parallel_devices=self.parallel_devices,
        )

    if self.use_ddp:
        slurm_managed = self.use_ddp and self.is_slurm_managing_tasks
        torchelastic_launched = self.use_ddp and TorchElasticEnvironment.is_using_torchelastic()
        kubeflow_launched = self.use_ddp and KubeflowEnvironment.is_using_kubeflow()
        spawn_requested = self._distrib_type == DistributedType.DDP_SPAWN
        ddp_on_cpu = self.use_ddp and self.on_cpu
        tpu_spawn = self.on_tpu and self._distrib_type == DistributedType.TPU_SPAWN
        cpu_torchelastic = ddp_on_cpu and TorchElasticEnvironment.is_using_torchelastic()
        cpu_kubeflow = ddp_on_cpu and KubeflowEnvironment.is_using_kubeflow()
        cpu_slurm = ddp_on_cpu and self.is_slurm_managing_tasks
        sharded_requested = self._distrib_type == DistributedType.DDP_SHARDED
        sharded_spawn_requested = self._distrib_type == DistributedType.DDP_SHARDED_SPAWN
        fully_sharded_requested = self._distrib_type == DistributedType.DDP_FULLY_SHARDED

        # TODO: decouple from TE
        # ddp script mode uses the same flags as TE
        if os.environ.get("PL_IN_DDP_SUBPROCESS", False):
            torchelastic_launched = False

        launched_externally = (
            cpu_slurm
            or slurm_managed
            or cpu_torchelastic
            or torchelastic_launched
            or kubeflow_launched
            or cpu_kubeflow
        )

        # Ordered (condition, plugin class) pairs; the first truthy
        # condition wins and DDPPlugin is the fallback.
        # NOTE(review): fully-sharded is only reached when no SLURM /
        # TorchElastic / Kubeflow / spawn condition matched - confirm that
        # this precedence is intended.
        ordered_choices = (
            (tpu_spawn, TPUSpawnPlugin),
            (sharded_requested, DDPShardedPlugin),
            (sharded_spawn_requested, DDPSpawnShardedPlugin),
            (launched_externally, DDPPlugin),
            (spawn_requested or ddp_on_cpu, DDPSpawnPlugin),
            (fully_sharded_requested, DDPFullyShardedPlugin),
        )
        plugin_cls = next(
            (candidate for selected, candidate in ordered_choices if selected),
            DDPPlugin,
        )

        return plugin_cls(
            parallel_devices=self.parallel_devices,
            cluster_environment=self.cluster_environment,
        )

    if self.use_dp:
        return DataParallelPlugin(parallel_devices=self.parallel_devices)

    if self.use_horovod:
        return HorovodPlugin(parallel_devices=self.parallel_devices)

    if self.on_tpu and isinstance(self.tpu_cores, list):
        return SingleTPUPlugin(self.tpu_id)

    if self.on_ipu:
        return IPUPlugin(parallel_devices=self.parallel_devices)

    single_gpu_ordinal = device_parser.determine_root_gpu_device(self.parallel_device_ids)
    device_name = f"cuda:{single_gpu_ordinal}" if self.on_gpu else "cpu"
    return SingleDevicePlugin(device=torch.device(device_name))