def select_precision_plugin(self) -> PrecisionPlugin:
    # set precision type
    self.amp_type = AMPType.from_str(self.amp_type)

    if self.use_ipu:
        return IPUPrecisionPlugin(self.precision)

    if self._distrib_type == DistributedType.DEEPSPEED or isinstance(self._training_type_plugin, DeepSpeedPlugin):
        return DeepSpeedPrecisionPlugin(self.precision)

    if self.precision == 32:
        return PrecisionPlugin()
    if self.precision == 64:
        return DoublePrecisionPlugin()
    if self.precision in (16, "bf16"):
        if self.use_tpu:
            return TPUHalfPrecisionPlugin()

        if self.amp_type == AMPType.NATIVE:
            if not _NATIVE_AMP_AVAILABLE:
                msg = (
                    "You have asked for native AMP but your PyTorch version does not support it."
                    " Consider upgrading with `pip install torch>=1.6`."
                )
                if _APEX_AVAILABLE:
                    self.amp_type = AMPType.APEX
                    msg += " We will attempt to use NVIDIA Apex for this session."
                    rank_zero_warn(msg)
                else:
                    raise MisconfigurationException(msg)
            else:
                log.info(f"Using native {self.precision} bit Automatic Mixed Precision")
                if self._is_sharded_training_type:
                    return ShardedNativeMixedPrecisionPlugin(self.precision, use_cpu=self.use_cpu)
                if self._is_fully_sharded_training_type:
                    return FullyShardedNativeMixedPrecisionPlugin(self.precision, use_cpu=self.use_cpu)
                return NativeMixedPrecisionPlugin(self.precision, use_cpu=self.use_cpu)

        if self.amp_type == AMPType.APEX:
            if not _APEX_AVAILABLE:
                raise MisconfigurationException(
                    "You have asked for Apex AMP but you have not installed it yet."
                    " Install apex first using this guide: https://github.com/NVIDIA/apex#linux"
                )
            if self._is_sharded_training_type or self._is_fully_sharded_training_type:
                raise MisconfigurationException(
                    "Sharded Plugin is not supported with Apex AMP, please use native AMP for 16-bit precision."
                )
            log.info("Using APEX 16bit precision.")
            return ApexMixedPrecisionPlugin(self.amp_level)

    raise MisconfigurationException(
        f"Precision {self.precision} is invalid. Allowed precision values: {PrecisionType.supported_types()}"
    )
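# A minimal standalone sketch (not the Lightning implementation) of the AMP backend
# fallback applied above when native AMP is unavailable: fall back to Apex if it is
# installed, otherwise fail. The helper name and boolean arguments are hypothetical.
import warnings


def resolve_amp_backend(requested: str, native_available: bool, apex_available: bool) -> str:
    """Return the AMP backend to use, mirroring the fallback order in the method above."""
    if requested != "native" or native_available:
        return requested
    if apex_available:
        warnings.warn("Native AMP is not supported by this PyTorch version; falling back to NVIDIA Apex.")
        return "apex"
    raise RuntimeError("Native AMP requested but neither native AMP nor Apex is available.")


assert resolve_amp_backend("native", native_available=True, apex_available=False) == "native"
assert resolve_amp_backend("native", native_available=False, apex_available=True) == "apex"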
def select_precision_plugin(self) -> PrecisionPlugin:
    # set precision type
    self.amp_type = AMPType.from_str(self.amp_type)

    if self.use_ipu:
        return IPUPrecisionPlugin(self.precision)

    if self._distrib_type == DistributedType.DEEPSPEED or isinstance(self._training_type_plugin, DeepSpeedPlugin):
        return DeepSpeedPrecisionPlugin(self.precision)

    if self.precision == 32:
        return PrecisionPlugin()
    if self.precision == 64:
        return DoublePrecisionPlugin()
    if self.precision == 16:
        if self.use_tpu:
            return TPUHalfPrecisionPlugin()

        if self.amp_type == AMPType.NATIVE:
            if self.use_cpu:
                raise MisconfigurationException(
                    "You have asked for native AMP on CPU, but AMP is only available on GPU."
                )
            if not _NATIVE_AMP_AVAILABLE:
                msg = (
                    "You have asked for native AMP but your PyTorch version does not support it."
                    " Consider upgrading with `pip install torch>=1.6`."
                )
                if _APEX_AVAILABLE:
                    self.amp_type = AMPType.APEX
                    msg += " We will attempt to use NVIDIA Apex for this session."
                    rank_zero_warn(msg)
                else:
                    raise MisconfigurationException(msg)
            else:
                log.info("Using native 16bit precision.")
                if self._is_sharded_training_type:
                    return ShardedNativeMixedPrecisionPlugin()
                if self._is_fully_sharded_training_type:
                    return FullyShardedNativeMixedPrecisionPlugin()
                return NativeMixedPrecisionPlugin()

        if self.amp_type == AMPType.APEX:
            if not _APEX_AVAILABLE:
                raise MisconfigurationException(
                    "You have asked for Apex AMP but you have not installed it yet."
                    " Install apex first using this guide: https://github.com/NVIDIA/apex#linux"
                )
            if self._is_sharded_training_type or self._is_fully_sharded_training_type:
                raise MisconfigurationException(
                    "Sharded Plugin is not supported with Apex AMP,"
                    " please use native AMP for 16-bit precision."
                )
            log.info("Using APEX 16bit precision.")
            return ApexMixedPrecisionPlugin(self.amp_level)

    raise NotImplementedError("We only support precisions 64, 32 and 16!")
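# The branch order shared by the two variants above (IPU -> DeepSpeed -> 32 -> 64 -> 16/bf16)
# can be summarized as a small standalone decision table. This is a sketch only, not the
# Lightning API: it returns the selected plugin class *names* and ignores variant-specific
# checks such as CPU restrictions and AMP availability.
def pick_precision_plugin_name(precision, use_ipu=False, use_tpu=False, use_deepspeed=False, amp_type="native"):
    if use_ipu:
        return "IPUPrecisionPlugin"
    if use_deepspeed:
        return "DeepSpeedPrecisionPlugin"
    if precision == 32:
        return "PrecisionPlugin"
    if precision == 64:
        return "DoublePrecisionPlugin"
    if precision in (16, "bf16"):
        if use_tpu:
            return "TPUHalfPrecisionPlugin"
        return "NativeMixedPrecisionPlugin" if amp_type == "native" else "ApexMixedPrecisionPlugin"
    raise ValueError(f"Precision {precision!r} is invalid.")


assert pick_precision_plugin_name(16) == "NativeMixedPrecisionPlugin"
assert pick_precision_plugin_name(16, use_tpu=True) == "TPUHalfPrecisionPlugin"
assert pick_precision_plugin_name(64) == "DoublePrecisionPlugin"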
def _check_config_and_set_final_flags(
    self,
    strategy: Optional[Union[str, Strategy]],
    accelerator: Optional[Union[str, Accelerator]],
    precision: Union[int, str],
    plugins: Optional[Union[PLUGIN_INPUT, List[PLUGIN_INPUT]]],
    amp_type: str,
    amp_level: Optional[str],
    sync_batchnorm: bool,
) -> None:
    """This method checks:

    1. strategy: the strategy, accelerator and plugins arguments can all be used to set a strategy.
    2. accelerator: if the value of the accelerator argument is a type of accelerator (instance or string),
       set self._accelerator_flag accordingly. If the value is strategy related (instance or string),
       it gets handled by 1.
    3. precision: The final value of the precision flag may be determined either by the precision argument
       or by a plugin instance.
    4. plugins: a plugin could occur as a value of the strategy argument (handled by 1) or the precision
       argument (handled by 3). We also extract the CheckpointIO and ClusterEnvironment plugins.
    """
    if plugins is not None:
        plugins = [plugins] if not isinstance(plugins, list) else plugins

    if strategy is not None:
        self._strategy_flag = strategy
        if strategy == "ddp_cpu":
            raise MisconfigurationException(
                "`Trainer(strategy='ddp_cpu')` is not a valid strategy,"
                " you can use `Trainer(strategy='ddp'|'ddp_spawn', accelerator='cpu')` instead."
            )
        if strategy == "tpu_spawn":
            raise MisconfigurationException(
                "`Trainer(strategy='tpu_spawn')` is not a valid strategy,"
                " you can use `Trainer(strategy='ddp_spawn', accelerator='tpu')` instead."
            )
        # handle duplications and conflict
        if isinstance(accelerator, Strategy) and strategy != accelerator:
            raise MisconfigurationException(
                f"Incompatible values set in `strategy` and `accelerator` arguments."
                f" Received both strategy={strategy} and accelerator={accelerator}"
            )
        if isinstance(accelerator, str) and accelerator in self._registered_strategies and strategy != accelerator:
            raise MisconfigurationException(
                f"strategy {strategy} already set through `strategy` flag,"
                f" but have also passed {accelerator} in through the accelerator flag."
            )
        if plugins:
            for plugin in plugins:
                if isinstance(plugin, Strategy):
                    raise MisconfigurationException(
                        f"You have passed `Trainer(strategy={strategy})`"
                        f" and you can only specify one strategy, but you have passed {plugin} as a plugin."
                    )
                if isinstance(plugin, str) and plugin in self._registered_strategies:
                    raise MisconfigurationException(
                        f"You have passed `Trainer(strategy={strategy})`"
                        f" and you can only specify one strategy, but you have passed {plugin} as a plugin."
                    )

    if accelerator is not None:
        if accelerator in self._accelerator_types or accelerator == "auto" or isinstance(accelerator, Accelerator):
            self._accelerator_flag = accelerator
        elif accelerator in self._registered_strategies or isinstance(accelerator, Strategy):
            rank_zero_deprecation(
                f"Passing `Trainer(accelerator={accelerator!r})` has been deprecated"
                f" in v1.5 and will be removed in v1.7. Use `Trainer(strategy={accelerator!r})` instead."
            )
            self._strategy_flag = accelerator
        elif accelerator == "ddp_cpu" and not self._strategy_flag:
            self._strategy_flag = accelerator

    if precision is not None:
        if str(precision) not in self._precision_types:
            raise MisconfigurationException(
                f"Precision {repr(precision)} is invalid. Allowed precision values: {self._precision_types}"
            )
        self._precision_flag = precision

    if plugins:
        plugins_flags_types: Dict[str, int] = Counter()
        for plugin in plugins:
            if isinstance(plugin, Strategy) or isinstance(plugin, str) and plugin in self._registered_strategies:
                self._strategy_flag = plugin
                rank_zero_deprecation(
                    f"Passing {plugin} `strategy` to the `plugins` flag in Trainer has been deprecated"
                    f" in v1.5 and will be removed in v1.7. Use `Trainer(strategy={plugin})` instead."
                )
                plugins_flags_types[Strategy.__name__] += 1
            elif isinstance(plugin, PrecisionPlugin):
                self._precision_plugin_flag = plugin
                plugins_flags_types[PrecisionPlugin.__name__] += 1
            elif isinstance(plugin, CheckpointIO):
                self.checkpoint_io = plugin
                plugins_flags_types[CheckpointIO.__name__] += 1
            elif isinstance(plugin, ClusterEnvironment):
                self._cluster_environment_flag = plugin
                plugins_flags_types[ClusterEnvironment.__name__] += 1
            elif isinstance(plugin, LayerSync):
                if sync_batchnorm and not isinstance(plugin, NativeSyncBatchNorm):
                    raise MisconfigurationException(
                        f"You set `Trainer(sync_batchnorm=True)` and provided a `{plugin.__class__.__name__}`"
                        " plugin, but this is not allowed. Choose one or the other."
                    )
                self._layer_sync = plugin
                plugins_flags_types[NativeSyncBatchNorm.__name__] += 1
            else:
                raise MisconfigurationException(
                    f"Found invalid type for plugin {plugin}. Expected one of: PrecisionPlugin, "
                    "CheckpointIO, ClusterEnvironment, LayerSync, or Strategy."
                )

        duplicated_plugin_key = [k for k, v in plugins_flags_types.items() if v > 1]
        if duplicated_plugin_key:
            raise MisconfigurationException(
                f"Received multiple values for {', '.join(duplicated_plugin_key)} flags in `plugins`."
                " Expected one value for each type at most."
            )

    # handle the case when the user passes in a strategy instance which has an accelerator, precision,
    # checkpoint io or cluster env set up
    # TODO: @awaelchli improve the error messages below
    if self._strategy_flag and isinstance(self._strategy_flag, Strategy):
        if self._strategy_flag._accelerator:
            if self._accelerator_flag:
                raise MisconfigurationException(
                    "accelerator set through both strategy class and accelerator flag, choose one"
                )
            else:
                self._accelerator_flag = self._strategy_flag._accelerator
        if self._strategy_flag._precision_plugin:
            # [RFC] handle precision plugin set up conflict?
            if self._precision_plugin_flag:
                raise MisconfigurationException("precision set through both strategy class and plugins, choose one")
            else:
                self._precision_plugin_flag = self._strategy_flag._precision_plugin
        if self._strategy_flag._checkpoint_io:
            if self.checkpoint_io:
                raise MisconfigurationException(
                    "checkpoint_io set through both strategy class and plugins, choose one"
                )
            else:
                self.checkpoint_io = self._strategy_flag._checkpoint_io
        if getattr(self._strategy_flag, "cluster_environment", None):
            if self._cluster_environment_flag:
                raise MisconfigurationException(
                    "cluster_environment set through both strategy class and plugins, choose one"
                )
            else:
                self._cluster_environment_flag = getattr(self._strategy_flag, "cluster_environment")

        if hasattr(self._strategy_flag, "parallel_devices"):
            if self._strategy_flag.parallel_devices:
                if self._strategy_flag.parallel_devices[0].type == "cpu":
                    if self._accelerator_flag and self._accelerator_flag not in ("auto", "cpu"):
                        raise MisconfigurationException(
                            f"CPU parallel_devices set through {self._strategy_flag.__class__.__name__} class,"
                            f" but accelerator set to {self._accelerator_flag}, please choose one device type"
                        )
                    self._accelerator_flag = "cpu"
                if self._strategy_flag.parallel_devices[0].type == "cuda":
                    if self._accelerator_flag and self._accelerator_flag not in ("auto", "gpu"):
                        raise MisconfigurationException(
                            f"GPU parallel_devices set through {self._strategy_flag.__class__.__name__} class,"
                            f" but accelerator set to {self._accelerator_flag}, please choose one device type"
                        )
                    self._accelerator_flag = "gpu"
                self._parallel_devices = self._strategy_flag.parallel_devices

    amp_type = amp_type if isinstance(amp_type, str) else None
    self._amp_type_flag = AMPType.from_str(amp_type)

    if amp_level is not None and self._amp_type_flag != AMPType.APEX:
        raise MisconfigurationException(
            f"You have asked for `amp_level={amp_level!r}` but it's only supported with `amp_backend='apex'`."
        )
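# The second half of the method above repeatedly applies the same "choose one" rule: a value
# may come either from the strategy instance or from the top-level flag/plugins, but not both.
# A self-contained sketch of that rule (hypothetical helper, not Lightning code):
def resolve_flag(name: str, from_strategy, from_flag):
    if from_strategy is not None and from_flag is not None:
        raise ValueError(f"{name} set through both strategy class and plugins, choose one")
    return from_strategy if from_strategy is not None else from_flag


assert resolve_flag("checkpoint_io", None, "io_from_plugins") == "io_from_plugins"
assert resolve_flag("precision", "plugin_from_strategy", None) == "plugin_from_strategy"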
def select_precision_plugin(self) -> PrecisionPlugin:
    # set precision type
    self.amp_type = AMPType.from_str(self.amp_type)

    # validation for all plugins
    if self.amp_level is not None and self.amp_type != AMPType.APEX:
        raise MisconfigurationException(
            f"You have asked for `amp_level={self.amp_level!r}` but it's only supported with `amp_backend='apex'`."
        )

    if self.use_ipu:
        if self.precision not in (16, 32):
            raise MisconfigurationException(
                f"`Trainer(accelerator='ipu', precision={self.precision!r})` is not supported."
            )
        return IPUPrecisionPlugin(self.precision)
    if self.use_tpu:
        if self.precision == 32:
            return TPUPrecisionPlugin()
        elif self.precision == 64:
            raise MisconfigurationException(
                "`Trainer(accelerator='tpu', precision=64)` is not implemented."
                " Please, open an issue in `https://github.com/PyTorchLightning/pytorch-lightning/issues`"
                " requesting this feature."
            )
        elif self.precision in (16, "bf16"):
            if self.precision == 16:
                # this is not deprecated to ease transition between accelerator environments
                rank_zero_warn(
                    f"You passed `Trainer(accelerator='tpu', precision=16)` but {self.amp_type.value} AMP"
                    f" is not supported with TPUs. Using `precision='bf16'` instead."
                )
            return TPUBf16PrecisionPlugin()
    if self._distrib_type == _StrategyType.DEEPSPEED or isinstance(self._training_type_plugin, DeepSpeedStrategy):
        return DeepSpeedPrecisionPlugin(self.precision, self.amp_type, self.amp_level)

    if self.precision == 32:
        return PrecisionPlugin()
    if self.precision == 64:
        return DoublePrecisionPlugin()

    # maybe convert the precision value
    if self.precision == 16 and self.use_cpu:
        if self.amp_type == AMPType.APEX:
            # apex was explicitly passed, not a good idea to silently switch to native AMP
            raise MisconfigurationException(
                "You passed `Trainer(accelerator='cpu', precision=16, amp_type='apex')`"
                " but apex AMP is not supported on CPU."
            )
        # this automatic switch is to ease transition between accelerator environments
        rank_zero_warn(
            "You passed `Trainer(accelerator='cpu', precision=16)` but native AMP is not supported on CPU."
            " Using `precision='bf16'` instead."
        )
        self.precision = "bf16"

    if self.precision in (16, "bf16"):
        if self.precision == "bf16" and self.amp_type != AMPType.NATIVE:
            raise MisconfigurationException(
                f"You passed `Trainer(amp_type={self.amp_type.value!r}, precision='bf16')` but it's not supported."
                " Try using `amp_type='native'` instead."
            )

        rank_zero_info(
            f"Using 16bit {self.amp_type.value} Automatic Mixed Precision (AMP)"
            if self.precision == 16
            else "Using bfloat16 Automatic Mixed Precision (AMP)"
        )

        if self.amp_type == AMPType.NATIVE:
            device = "cpu" if self.use_cpu else "cuda"

            if self._is_sharded_training_type:
                return ShardedNativeMixedPrecisionPlugin(self.precision, device)
            if self._is_fully_sharded_training_type:
                return FullyShardedNativeMixedPrecisionPlugin(self.precision, device)
            return NativeMixedPrecisionPlugin(self.precision, device)

        if self.amp_type == AMPType.APEX:
            if self._is_sharded_training_type or self._is_fully_sharded_training_type:
                raise MisconfigurationException(
                    "Sharded plugins are not supported with apex, please switch to `amp_backend='native'`."
                )

            self.amp_level = self.amp_level or "O2"
            return ApexMixedPrecisionPlugin(self.amp_level)

    raise RuntimeError("No precision set")
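# The "maybe convert the precision value" block above can be read as a small coercion rule:
# 16-bit on CPU falls back to bfloat16 under native AMP and is rejected under apex, and
# bf16 is only valid with the native backend. A standalone sketch of that rule
# (hypothetical helper, not the Lightning API):
def coerce_precision(precision, accelerator="gpu", amp_type="native"):
    if precision == 16 and accelerator == "cpu":
        if amp_type == "apex":
            raise ValueError("apex AMP is not supported on CPU")
        return "bf16"  # native 16-bit AMP is unsupported on CPU, so fall back to bfloat16
    if precision == "bf16" and amp_type != "native":
        raise ValueError("precision='bf16' is only supported with amp_type='native'")
    return precision


assert coerce_precision(16, accelerator="cpu") == "bf16"
assert coerce_precision(16, accelerator="gpu", amp_type="apex") == 16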
def _check_config_and_set_final_flags(
    self,
    strategy: Optional[Union[str, Strategy]],
    accelerator: Optional[Union[str, Accelerator]],
    precision: Union[int, str],
    plugins: Optional[Union[PLUGIN_INPUT, List[PLUGIN_INPUT]]],
    amp_type: str,
    amp_level: Optional[str],
    sync_batchnorm: bool,
) -> None:
    """This method checks:

    1. strategy: whether the strategy name is valid, and sets the internal flags if it is.
    2. accelerator: if the value of the accelerator argument is a type of accelerator (instance or string),
       set self._accelerator_flag accordingly.
    3. precision: The final value of the precision flag may be determined either by the precision argument
       or by a plugin instance.
    4. plugins: The list of plugins may contain a Precision plugin, CheckpointIO, ClusterEnvironment and
       others. Additionally, other flags such as `precision` or `sync_batchnorm` can populate the list with
       the corresponding plugin instances.
    """
    if plugins is not None:
        plugins = [plugins] if not isinstance(plugins, list) else plugins

    if isinstance(strategy, str):
        strategy = strategy.lower()

    if strategy is not None:
        self._strategy_flag = strategy
        if strategy == "ddp_cpu":
            raise MisconfigurationException(
                "`Trainer(strategy='ddp_cpu')` is not a valid strategy,"
                " you can use `Trainer(strategy='ddp'|'ddp_spawn', accelerator='cpu')` instead."
            )
        if strategy == "tpu_spawn":
            raise MisconfigurationException(
                "`Trainer(strategy='tpu_spawn')` is not a valid strategy,"
                " you can use `Trainer(strategy='ddp_spawn', accelerator='tpu')` instead."
            )

    if accelerator is not None:
        if accelerator in self._accelerator_types or accelerator == "auto" or isinstance(accelerator, Accelerator):
            self._accelerator_flag = accelerator

    if precision is not None:
        if str(precision) not in self._precision_types:
            raise MisconfigurationException(
                f"Precision {repr(precision)} is invalid. Allowed precision values: {self._precision_types}"
            )
        self._precision_flag = precision

    if plugins:
        plugins_flags_types: Dict[str, int] = Counter()
        for plugin in plugins:
            if isinstance(plugin, PrecisionPlugin):
                self._precision_plugin_flag = plugin
                plugins_flags_types[PrecisionPlugin.__name__] += 1
            elif isinstance(plugin, CheckpointIO):
                self.checkpoint_io = plugin
                plugins_flags_types[CheckpointIO.__name__] += 1
            elif isinstance(plugin, ClusterEnvironment):
                self._cluster_environment_flag = plugin
                plugins_flags_types[ClusterEnvironment.__name__] += 1
            elif isinstance(plugin, LayerSync):
                if sync_batchnorm and not isinstance(plugin, NativeSyncBatchNorm):
                    raise MisconfigurationException(
                        f"You set `Trainer(sync_batchnorm=True)` and provided a `{plugin.__class__.__name__}`"
                        " plugin, but this is not allowed. Choose one or the other."
                    )
                self._layer_sync = plugin
                plugins_flags_types[NativeSyncBatchNorm.__name__] += 1
            else:
                raise MisconfigurationException(
                    f"Found invalid type for plugin {plugin}. Expected one of: PrecisionPlugin, "
                    "CheckpointIO, ClusterEnvironment, or LayerSync."
                )

        duplicated_plugin_key = [k for k, v in plugins_flags_types.items() if v > 1]
        if duplicated_plugin_key:
            raise MisconfigurationException(
                f"Received multiple values for {', '.join(duplicated_plugin_key)} flags in `plugins`."
                " Expected one value for each type at most."
            )

    # handle the case when the user passes in a strategy instance which has an accelerator, precision,
    # checkpoint io or cluster env set up
    # TODO: @awaelchli improve the error messages below
    if self._strategy_flag and isinstance(self._strategy_flag, Strategy):
        if self._strategy_flag._accelerator:
            if self._accelerator_flag:
                raise MisconfigurationException(
                    "accelerator set through both strategy class and accelerator flag, choose one"
                )
            else:
                self._accelerator_flag = self._strategy_flag._accelerator
        if self._strategy_flag._precision_plugin:
            # [RFC] handle precision plugin set up conflict?
            if self._precision_plugin_flag:
                raise MisconfigurationException("precision set through both strategy class and plugins, choose one")
            else:
                self._precision_plugin_flag = self._strategy_flag._precision_plugin
        if self._strategy_flag._checkpoint_io:
            if self.checkpoint_io:
                raise MisconfigurationException(
                    "checkpoint_io set through both strategy class and plugins, choose one"
                )
            else:
                self.checkpoint_io = self._strategy_flag._checkpoint_io
        if getattr(self._strategy_flag, "cluster_environment", None):
            if self._cluster_environment_flag:
                raise MisconfigurationException(
                    "cluster_environment set through both strategy class and plugins, choose one"
                )
            else:
                self._cluster_environment_flag = getattr(self._strategy_flag, "cluster_environment")

        if hasattr(self._strategy_flag, "parallel_devices"):
            if self._strategy_flag.parallel_devices:
                if self._strategy_flag.parallel_devices[0].type == "cpu":
                    if self._accelerator_flag and self._accelerator_flag not in ("auto", "cpu"):
                        raise MisconfigurationException(
                            f"CPU parallel_devices set through {self._strategy_flag.__class__.__name__} class,"
                            f" but accelerator set to {self._accelerator_flag}, please choose one device type"
                        )
                    self._accelerator_flag = "cpu"
                if self._strategy_flag.parallel_devices[0].type == "cuda":
                    if self._accelerator_flag and self._accelerator_flag not in ("auto", "gpu"):
                        raise MisconfigurationException(
                            f"GPU parallel_devices set through {self._strategy_flag.__class__.__name__} class,"
                            f" but accelerator set to {self._accelerator_flag}, please choose one device type"
                        )
                    self._accelerator_flag = "gpu"
                self._parallel_devices = self._strategy_flag.parallel_devices

    amp_type = amp_type if isinstance(amp_type, str) else None
    self._amp_type_flag = AMPType.from_str(amp_type)

    if amp_level is not None and self._amp_type_flag != AMPType.APEX:
        raise MisconfigurationException(
            f"You have asked for `amp_level={amp_level!r}` but it's only supported with `amp_backend='apex'`."
        )
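# The `plugins` validation above counts how many plugins of each category were passed and
# rejects duplicates. A self-contained sketch of that check using the same Counter idea
# (the placeholder classes below are hypothetical stand-ins for real plugin types):
from collections import Counter


class FakeCheckpointIO: ...


class FakeClusterEnvironment: ...


def find_duplicated_plugin_types(plugins):
    counts = Counter(type(plugin).__name__ for plugin in plugins)
    return [name for name, n in counts.items() if n > 1]


assert find_duplicated_plugin_types([FakeCheckpointIO(), FakeClusterEnvironment()]) == []
assert find_duplicated_plugin_types([FakeCheckpointIO(), FakeCheckpointIO()]) == ["FakeCheckpointIO"]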