def _check_and_init_precision(self) -> PrecisionPlugin:
    """Resolve the user-supplied precision flags into a concrete ``PrecisionPlugin``.

    Dispatch order matters and is intentional: an explicitly passed plugin wins,
    then accelerator-specific plugins (IPU/HPU/TPU), then strategy-specific ones
    (DeepSpeed), and finally the generic 32/64/16/bf16 handling.

    Raises:
        RuntimeError: if no branch matched the configured precision flags.
    """
    self._validate_precision_choice()

    # A fully constructed plugin passed by the user takes precedence over flags.
    if isinstance(self._precision_plugin_flag, PrecisionPlugin):
        return self._precision_plugin_flag

    # Accelerator-specific plugins.
    if isinstance(self.accelerator, IPUAccelerator):
        return IPUPrecisionPlugin(self._precision_flag)  # type: ignore
    if isinstance(self.accelerator, HPUAccelerator):
        return HPUPrecisionPlugin(self._precision_flag)  # type: ignore
    if isinstance(self.accelerator, TPUAccelerator):
        if self._precision_flag == 32:
            return TPUPrecisionPlugin()
        if self._precision_flag in (16, "bf16"):
            if self._precision_flag == 16:
                # fp16 AMP is unavailable on TPUs; fall back to bfloat16.
                rank_zero_warn(
                    "You passed `Trainer(accelerator='tpu', precision=16)` but AMP"
                    " is not supported with TPUs. Using `precision='bf16'` instead."
                )
            return TPUBf16PrecisionPlugin()

    # Strategy-specific plugin: DeepSpeed manages mixed precision itself.
    if isinstance(self.strategy, DeepSpeedStrategy):
        return DeepSpeedPrecisionPlugin(
            self._precision_flag, self._amp_type_flag, self._amp_level_flag  # type: ignore
        )

    # Plain full/double precision.
    if self._precision_flag == 32:
        return PrecisionPlugin()
    if self._precision_flag == 64:
        return DoublePrecisionPlugin()

    # Native fp16 AMP needs CUDA; on CPU, downgrade to bf16 with a warning.
    if self._precision_flag == 16 and self._accelerator_flag == "cpu":
        rank_zero_warn(
            "You passed `Trainer(accelerator='cpu', precision=16)` but native AMP is not supported on CPU."
            " Using `precision='bf16'` instead."
        )
        self._precision_flag = "bf16"

    # Mixed precision (fp16 or bf16) via native AMP or apex.
    if self._precision_flag in (16, "bf16"):
        rank_zero_info(
            f"Using 16bit {self._amp_type_flag.value} Automatic Mixed Precision (AMP)"  # type: ignore
            if self._precision_flag == 16
            else "Using bfloat16 Automatic Mixed Precision (AMP)"
        )

        if self._amp_type_flag == AMPType.NATIVE:
            device = "cpu" if self._accelerator_flag == "cpu" else "cuda"
            # Sharded strategies need their own scaler-aware AMP plugins.
            if isinstance(self.strategy, (DDPShardedStrategy, DDPSpawnShardedStrategy)):
                return ShardedNativeMixedPrecisionPlugin(self._precision_flag, device)
            if isinstance(self.strategy, DDPFullyShardedStrategy):
                return FullyShardedNativeMixedPrecisionPlugin(self._precision_flag, device)
            return NativeMixedPrecisionPlugin(self._precision_flag, device)

        if self._amp_type_flag == AMPType.APEX:
            # Apex defaults to opt-level "O2" when the user did not set one.
            self._amp_level_flag = self._amp_level_flag or "O2"
            return ApexMixedPrecisionPlugin(self._amp_level_flag)

    raise RuntimeError("No precision set")
def select_precision_plugin(self) -> PrecisionPlugin:
    """Pick the ``PrecisionPlugin`` matching the configured precision/amp flags.

    Normalizes ``amp_type`` first, validates flag combinations, then dispatches
    in priority order: IPU, TPU, DeepSpeed, plain 32/64-bit, and finally mixed
    precision (native AMP or apex).

    Raises:
        MisconfigurationException: for unsupported flag combinations.
        RuntimeError: if no branch matched the configured precision.
    """
    # set precision type
    self.amp_type = AMPType.from_str(self.amp_type)

    # validation for all plugins
    if self.amp_level is not None and self.amp_type != AMPType.APEX:
        raise MisconfigurationException(
            f"You have asked for `amp_level={self.amp_level!r}` but it's only supported with `amp_backend='apex'`."
        )

    if self.use_ipu:
        if self.precision not in (16, 32):
            raise MisconfigurationException(
                f"`Trainer(accelerator='ipu', precision={self.precision!r})` is not supported."
            )
        return IPUPrecisionPlugin(self.precision)

    if self.use_tpu:
        if self.precision == 32:
            return TPUPrecisionPlugin()
        if self.precision == 64:
            raise MisconfigurationException(
                "`Trainer(accelerator='tpu', precision=64)` is not implemented."
                " Please, open an issue in `https://github.com/PyTorchLightning/pytorch-lightning/issues`"
                " requesting this feature."
            )
        if self.precision in (16, "bf16"):
            if self.precision == 16:
                # this is not deprecated to ease transition between accelerator environments
                rank_zero_warn(
                    f"You passed `Trainer(accelerator='tpu', precision=16)` but {self.amp_type.value} AMP"
                    f" is not supported with TPUs. Using `precision='bf16'` instead."
                )
            return TPUBf16PrecisionPlugin()

    # DeepSpeed handles mixed precision internally.
    if self._distrib_type == _StrategyType.DEEPSPEED or isinstance(self._training_type_plugin, DeepSpeedStrategy):
        return DeepSpeedPrecisionPlugin(self.precision, self.amp_type, self.amp_level)

    if self.precision == 32:
        return PrecisionPlugin()
    if self.precision == 64:
        return DoublePrecisionPlugin()

    # maybe convert the precision value
    if self.precision == 16 and self.use_cpu:
        if self.amp_type == AMPType.APEX:
            # apex was explicitly passed, not a good idea to silently switch to native AMP
            raise MisconfigurationException(
                "You passed `Trainer(accelerator='cpu', precision=16, amp_type='apex')`"
                " but apex AMP not supported on CPU."
            )
        # this automatic switch is to ease transition between accelerator environments
        rank_zero_warn(
            "You passed `Trainer(accelerator='cpu', precision=16)` but native AMP is not supported on CPU."
            " Using `precision='bf16'` instead."
        )
        self.precision = "bf16"

    if self.precision in (16, "bf16"):
        # bf16 is only implemented through the native AMP backend.
        if self.precision == "bf16" and self.amp_type != AMPType.NATIVE:
            raise MisconfigurationException(
                f"You passed `Trainer(amp_type={self.amp_type.value!r}, precision='bf16')` but it's not supported."
                " Try using `amp_type='native'` instead."
            )

        rank_zero_info(
            f"Using 16bit {self.amp_type.value} Automatic Mixed Precision (AMP)"
            if self.precision == 16
            else "Using bfloat16 Automatic Mixed Precision (AMP)"
        )

        if self.amp_type == AMPType.NATIVE:
            device = "cpu" if self.use_cpu else "cuda"
            # Sharded strategies need their own scaler-aware AMP plugins.
            if self._is_sharded_training_type:
                return ShardedNativeMixedPrecisionPlugin(self.precision, device)
            if self._is_fully_sharded_training_type:
                return FullyShardedNativeMixedPrecisionPlugin(self.precision, device)
            return NativeMixedPrecisionPlugin(self.precision, device)

        if self.amp_type == AMPType.APEX:
            if self._is_sharded_training_type or self._is_fully_sharded_training_type:
                raise MisconfigurationException(
                    "Sharded plugins are not supported with apex, please switch to `amp_backend='native'`."
                )
            # Apex defaults to opt-level "O2" when the user did not set one.
            self.amp_level = self.amp_level or "O2"
            return ApexMixedPrecisionPlugin(self.amp_level)

    raise RuntimeError("No precision set")
def test_teardown():
    """`connect` must export ``XLA_USE_BF16=1`` and `teardown` must remove it."""
    precision_plugin = TPUBf16PrecisionPlugin()
    # connect() ignores its arguments here beyond passing them through, so mocks suffice.
    precision_plugin.connect(Mock(), Mock(), Mock())
    assert os.environ.get("XLA_USE_BF16") == "1"

    precision_plugin.teardown()
    assert "XLA_USE_BF16" not in os.environ