def _check_and_init_precision(self) -> PrecisionPlugin:
    """Pick the precision plugin that matches the configured flags.

    ``_validate_precision_choice`` runs first. The plugin is then resolved in
    this order: a ``PrecisionPlugin`` instance already stored in
    ``_precision_plugin_flag``, the IPU / HPU / TPU accelerators, the
    DeepSpeed strategy, plain 32- and 64-bit precision, and finally 16-bit /
    bf16 mixed precision through the native or apex AMP backend.

    Side effects:
        * ``_precision_flag`` is rewritten to ``"bf16"`` when 16-bit precision
          is requested while ``_accelerator_flag == "cpu"``.
        * ``_amp_level_flag`` falls back to ``"O2"`` when the apex backend is
          selected and the level is unset (falsy).

    Returns:
        The precision plugin to use.

    Raises:
        RuntimeError: If none of the branches matched the configured flags.
    """
    self._validate_precision_choice()

    preset_plugin = self._precision_plugin_flag
    if isinstance(preset_plugin, PrecisionPlugin):
        # an already-built plugin instance wins over every flag
        return preset_plugin

    precision = self._precision_flag

    # --- accelerator-specific plugins ---
    if isinstance(self.accelerator, IPUAccelerator):
        return IPUPrecisionPlugin(precision)  # type: ignore
    if isinstance(self.accelerator, HPUAccelerator):
        return HPUPrecisionPlugin(precision)  # type: ignore
    if isinstance(self.accelerator, TPUAccelerator):
        if precision == 32:
            return TPUPrecisionPlugin()
        if precision in (16, "bf16"):
            if precision == 16:
                rank_zero_warn(
                    "You passed `Trainer(accelerator='tpu', precision=16)` but AMP"
                    " is not supported with TPUs. Using `precision='bf16'` instead."
                )
            return TPUBf16PrecisionPlugin()
        # any other precision value on TPU continues with the generic handling below

    # --- strategy-specific plugin ---
    if isinstance(self.strategy, DeepSpeedStrategy):
        return DeepSpeedPrecisionPlugin(precision, self._amp_type_flag, self._amp_level_flag)  # type: ignore

    # --- full / double precision ---
    if precision == 32:
        return PrecisionPlugin()
    if precision == 64:
        return DoublePrecisionPlugin()

    # --- mixed precision ---
    if precision == 16 and self._accelerator_flag == "cpu":
        rank_zero_warn(
            "You passed `Trainer(accelerator='cpu', precision=16)` but native AMP is not supported on CPU."
            " Using `precision='bf16'` instead."
        )
        # keep the local copy and the stored flag in sync
        precision = self._precision_flag = "bf16"

    if precision in (16, "bf16"):
        if precision == 16:
            amp_message = f"Using 16bit {self._amp_type_flag.value} Automatic Mixed Precision (AMP)"  # type: ignore
        else:
            amp_message = "Using bfloat16 Automatic Mixed Precision (AMP)"
        rank_zero_info(amp_message)

        if self._amp_type_flag == AMPType.NATIVE:
            on_cpu = self._accelerator_flag == "cpu"
            device = "cpu" if on_cpu else "cuda"

            # choose the plugin class from the strategy, then build it once
            if isinstance(self.strategy, (DDPShardedStrategy, DDPSpawnShardedStrategy)):
                native_plugin_cls = ShardedNativeMixedPrecisionPlugin
            elif isinstance(self.strategy, DDPFullyShardedStrategy):
                native_plugin_cls = FullyShardedNativeMixedPrecisionPlugin
            else:
                native_plugin_cls = NativeMixedPrecisionPlugin
            return native_plugin_cls(precision, device)

        if self._amp_type_flag == AMPType.APEX:
            apex_level = self._amp_level_flag or "O2"
            self._amp_level_flag = apex_level
            return ApexMixedPrecisionPlugin(apex_level)

    raise RuntimeError("No precision set")
def select_precision_plugin(self) -> PrecisionPlugin:
    """Return the precision plugin for the current precision / AMP settings.

    ``self.amp_type`` is first normalised through ``AMPType.from_str``.
    Resolution order: IPU, DeepSpeed, 32-bit, 64-bit, then 16-bit / ``"bf16"``
    (TPU half precision, native AMP, apex AMP).

    When native AMP is requested but ``_NATIVE_AMP_AVAILABLE`` is false,
    ``self.amp_type`` is switched to ``AMPType.APEX`` (with a warning) if
    ``_APEX_AVAILABLE`` is true; otherwise an error is raised.

    Returns:
        The selected precision plugin.

    Raises:
        MisconfigurationException: If the requested AMP backend is not
            available, if apex is combined with a sharded training type, or
            if the precision value did not lead to a plugin.
    """
    # set precision type
    self.amp_type = AMPType.from_str(self.amp_type)

    if self.use_ipu:
        return IPUPrecisionPlugin(self.precision)

    uses_deepspeed = self._distrib_type == DistributedType.DEEPSPEED or isinstance(
        self._training_type_plugin, DeepSpeedPlugin
    )
    if uses_deepspeed:
        return DeepSpeedPrecisionPlugin(self.precision)

    precision = self.precision
    if precision == 32:
        return PrecisionPlugin()
    if precision == 64:
        return DoublePrecisionPlugin()

    if precision in (16, "bf16"):
        if self.use_tpu:
            return TPUHalfPrecisionPlugin()

        if self.amp_type == AMPType.NATIVE:
            if _NATIVE_AMP_AVAILABLE:
                log.info(f"Using native {self.precision} bit Automatic Mixed Precision")
                # choose the plugin class from the training type, then build it once
                if self._is_sharded_training_type:
                    native_plugin_cls = ShardedNativeMixedPrecisionPlugin
                elif self._is_fully_sharded_training_type:
                    native_plugin_cls = FullyShardedNativeMixedPrecisionPlugin
                else:
                    native_plugin_cls = NativeMixedPrecisionPlugin
                return native_plugin_cls(self.precision, use_cpu=self.use_cpu)

            # native AMP is unavailable: fall back to apex when possible, else fail
            msg = (
                "You have asked for native AMP but your PyTorch version does not support it."
                " Consider upgrading with `pip install torch>=1.6`."
            )
            if not _APEX_AVAILABLE:
                raise MisconfigurationException(msg)
            self.amp_type = AMPType.APEX
            msg += " We will attempt to use NVIDIA Apex for this session."
            rank_zero_warn(msg)

        # also reached after the native -> apex fallback above
        if self.amp_type == AMPType.APEX:
            if not _APEX_AVAILABLE:
                raise MisconfigurationException(
                    "You have asked for Apex AMP but you have not installed it yet."
                    " Install apex first using this guide: https://github.com/NVIDIA/apex#linux"
                )
            if self._is_sharded_training_type or self._is_fully_sharded_training_type:
                raise MisconfigurationException(
                    "Sharded Plugin is not supported with Apex AMP, please using native AMP for 16-bit precision."
                )
            log.info("Using APEX 16bit precision.")
            return ApexMixedPrecisionPlugin(self.amp_level)

    raise MisconfigurationException(
        f"Precision {self.precision} is invalid. Allowed precision values: {PrecisionType.supported_types()}"
    )
def select_precision_plugin(self) -> PrecisionPlugin:
    """Return the precision plugin for 64-, 32- or 16-bit training.

    ``self.amp_type`` is first normalised through ``AMPType.from_str``.
    Resolution order: IPU, DeepSpeed, 32-bit, 64-bit, then 16-bit (TPU half
    precision, native AMP, apex AMP).

    When native AMP is requested but ``_NATIVE_AMP_AVAILABLE`` is false,
    ``self.amp_type`` is switched to ``AMPType.APEX`` (with a warning) if
    ``_APEX_AVAILABLE`` is true; otherwise an error is raised.

    Returns:
        The selected precision plugin.

    Raises:
        MisconfigurationException: If native AMP is requested on CPU, if the
            requested AMP backend is not available, or if apex is combined
            with a sharded training type.
        NotImplementedError: If no branch produced a plugin for the
            configured precision.
    """
    # set precision type
    self.amp_type = AMPType.from_str(self.amp_type)

    if self.use_ipu:
        return IPUPrecisionPlugin(self.precision)

    uses_deepspeed = self._distrib_type == DistributedType.DEEPSPEED or isinstance(
        self._training_type_plugin, DeepSpeedPlugin
    )
    if uses_deepspeed:
        return DeepSpeedPrecisionPlugin(self.precision)

    precision = self.precision
    if precision == 32:
        return PrecisionPlugin()
    if precision == 64:
        return DoublePrecisionPlugin()

    if precision == 16:
        if self.use_tpu:
            return TPUHalfPrecisionPlugin()

        if self.amp_type == AMPType.NATIVE:
            if self.use_cpu:
                raise MisconfigurationException(
                    "You have asked for native AMP on CPU, but AMP is only available on GPU."
                )

            if _NATIVE_AMP_AVAILABLE:
                log.info("Using native 16bit precision.")
                # choose the plugin class from the training type, then build it once
                if self._is_sharded_training_type:
                    native_plugin_cls = ShardedNativeMixedPrecisionPlugin
                elif self._is_fully_sharded_training_type:
                    native_plugin_cls = FullyShardedNativeMixedPrecisionPlugin
                else:
                    native_plugin_cls = NativeMixedPrecisionPlugin
                return native_plugin_cls()

            # native AMP is unavailable: fall back to apex when possible, else fail
            msg = (
                "You have asked for native AMP but your PyTorch version does not support it."
                " Consider upgrading with `pip install torch>=1.6`."
            )
            if not _APEX_AVAILABLE:
                raise MisconfigurationException(msg)
            self.amp_type = AMPType.APEX
            msg += " We will attempt to use NVIDIA Apex for this session."
            rank_zero_warn(msg)

        # also reached after the native -> apex fallback above
        if self.amp_type == AMPType.APEX:
            if not _APEX_AVAILABLE:
                raise MisconfigurationException(
                    "You have asked for Apex AMP but you have not installed it yet."
                    " Install apex first using this guide: https://github.com/NVIDIA/apex#linux"
                )
            if self._is_sharded_training_type or self._is_fully_sharded_training_type:
                raise MisconfigurationException(
                    "Sharded Plugin is not supported with Apex AMP,"
                    " please using native AMP for 16-bit precision."
                )
            log.info("Using APEX 16bit precision.")
            return ApexMixedPrecisionPlugin(self.amp_level)

    raise NotImplementedError("We only support precisions 64, 32 and 16!")
def select_precision_plugin(self) -> PrecisionPlugin:
    """Validate the precision / AMP settings and return the matching plugin.

    ``self.amp_type`` is first normalised through ``AMPType.from_str``.
    Resolution order: IPU, TPU, DeepSpeed, 32-bit, 64-bit, then 16-bit /
    ``"bf16"`` mixed precision through the native or apex backend.

    Side effects:
        * ``self.precision`` is rewritten to ``"bf16"`` when 16-bit precision
          is requested on CPU with a non-apex backend.
        * ``self.amp_level`` falls back to ``"O2"`` when the apex backend is
          selected and the level is unset (falsy).

    Returns:
        The selected precision plugin.

    Raises:
        MisconfigurationException: For unsupported combinations:
            ``amp_level`` without the apex backend, IPU with a precision
            other than 16 or 32, TPU with precision 64, apex on CPU, bf16
            with a non-native backend, or apex with a sharded training type.
        RuntimeError: If no branch produced a plugin.
    """
    # set precision type
    self.amp_type = AMPType.from_str(self.amp_type)

    # validation for all plugins
    if self.amp_level is not None and self.amp_type != AMPType.APEX:
        raise MisconfigurationException(
            f"You have asked for `amp_level={self.amp_level!r}` but it's only supported with `amp_backend='apex'`."
        )

    # --- accelerator-specific plugins ---
    if self.use_ipu:
        if self.precision not in (16, 32):
            raise MisconfigurationException(
                f"`Trainer(accelerator='ipu', precision={self.precision!r})` is not supported."
            )
        return IPUPrecisionPlugin(self.precision)

    if self.use_tpu:
        tpu_precision = self.precision
        if tpu_precision == 32:
            return TPUPrecisionPlugin()
        if tpu_precision == 64:
            raise MisconfigurationException(
                "`Trainer(accelerator='tpu', precision=64)` is not implemented."
                " Please, open an issue in `https://github.com/PyTorchLightning/pytorch-lightning/issues`"
                " requesting this feature."
            )
        if tpu_precision in (16, "bf16"):
            if tpu_precision == 16:
                # this is not deprecated to ease transition between accelerator environments
                rank_zero_warn(
                    f"You passed `Trainer(accelerator='tpu', precision=16)` but {self.amp_type.value} AMP"
                    f" is not supported with TPUs. Using `precision='bf16'` instead."
                )
            return TPUBf16PrecisionPlugin()
        # any other precision value on TPU continues with the generic handling below

    # --- strategy-specific plugin ---
    uses_deepspeed = self._distrib_type == _StrategyType.DEEPSPEED or isinstance(
        self._training_type_plugin, DeepSpeedStrategy
    )
    if uses_deepspeed:
        return DeepSpeedPrecisionPlugin(self.precision, self.amp_type, self.amp_level)

    # --- full / double precision ---
    if self.precision == 32:
        return PrecisionPlugin()
    if self.precision == 64:
        return DoublePrecisionPlugin()

    # maybe convert the precision value
    if self.precision == 16 and self.use_cpu:
        if self.amp_type == AMPType.APEX:
            # apex was explicitly passed, not a good idea to silently switch to native AMP
            raise MisconfigurationException(
                "You passed `Trainer(accelerator='cpu', precision=16, amp_type='apex')`"
                " but apex AMP not supported on CPU."
            )
        # this automatic switch is to ease transition between accelerator environments
        rank_zero_warn(
            "You passed `Trainer(accelerator='cpu', precision=16)` but native AMP is not supported on CPU."
            " Using `precision='bf16'` instead."
        )
        self.precision = "bf16"

    # --- mixed precision ---
    if self.precision in (16, "bf16"):
        if self.precision == "bf16" and self.amp_type != AMPType.NATIVE:
            raise MisconfigurationException(
                f"You passed `Trainer(amp_type={self.amp_type.value!r}, precision='bf16')` but it's not supported."
                " Try using `amp_type='native'` instead."
            )

        if self.precision == 16:
            amp_message = f"Using 16bit {self.amp_type.value} Automatic Mixed Precision (AMP)"
        else:
            amp_message = "Using bfloat16 Automatic Mixed Precision (AMP)"
        rank_zero_info(amp_message)

        if self.amp_type == AMPType.NATIVE:
            if self.use_cpu:
                device = "cpu"
            else:
                device = "cuda"

            # choose the plugin class from the training type, then build it once
            if self._is_sharded_training_type:
                native_plugin_cls = ShardedNativeMixedPrecisionPlugin
            elif self._is_fully_sharded_training_type:
                native_plugin_cls = FullyShardedNativeMixedPrecisionPlugin
            else:
                native_plugin_cls = NativeMixedPrecisionPlugin
            return native_plugin_cls(self.precision, device)

        if self.amp_type == AMPType.APEX:
            if self._is_sharded_training_type or self._is_fully_sharded_training_type:
                raise MisconfigurationException(
                    "Sharded plugins are not supported with apex, please switch to `amp_backend='native'`."
                )
            apex_level = self.amp_level or "O2"
            self.amp_level = apex_level
            return ApexMixedPrecisionPlugin(apex_level)

    raise RuntimeError("No precision set")