def _check_and_init_precision(self) -> PrecisionPlugin:
    self._validate_precision_choice()
    if isinstance(self._precision_plugin_flag, PrecisionPlugin):
        return self._precision_plugin_flag

    if isinstance(self.accelerator, IPUAccelerator):
        return IPUPrecisionPlugin(self._precision_flag)  # type: ignore
    if isinstance(self.accelerator, HPUAccelerator):
        return HPUPrecisionPlugin(self._precision_flag)  # type: ignore
    if isinstance(self.accelerator, TPUAccelerator):
        if self._precision_flag == 32:
            return TPUPrecisionPlugin()
        elif self._precision_flag in (16, "bf16"):
            if self._precision_flag == 16:
                rank_zero_warn(
                    "You passed `Trainer(accelerator='tpu', precision=16)` but AMP"
                    " is not supported with TPUs. Using `precision='bf16'` instead."
                )
            return TPUBf16PrecisionPlugin()
    if isinstance(self.strategy, DeepSpeedStrategy):
        return DeepSpeedPrecisionPlugin(
            self._precision_flag, self._amp_type_flag, self._amp_level_flag  # type: ignore
        )

    if self._precision_flag == 32:
        return PrecisionPlugin()
    if self._precision_flag == 64:
        return DoublePrecisionPlugin()

    if self._precision_flag == 16 and self._accelerator_flag == "cpu":
        rank_zero_warn(
            "You passed `Trainer(accelerator='cpu', precision=16)` but native AMP is not supported on CPU."
            " Using `precision='bf16'` instead."
        )
        self._precision_flag = "bf16"

    if self._precision_flag in (16, "bf16"):
        rank_zero_info(
            f"Using 16bit {self._amp_type_flag.value} Automatic Mixed Precision (AMP)"  # type: ignore
            if self._precision_flag == 16
            else "Using bfloat16 Automatic Mixed Precision (AMP)"
        )

        if self._amp_type_flag == AMPType.NATIVE:
            device = "cpu" if self._accelerator_flag == "cpu" else "cuda"
            if isinstance(self.strategy, (DDPShardedStrategy, DDPSpawnShardedStrategy)):
                return ShardedNativeMixedPrecisionPlugin(self._precision_flag, device)
            if isinstance(self.strategy, DDPFullyShardedStrategy):
                return FullyShardedNativeMixedPrecisionPlugin(self._precision_flag, device)
            return NativeMixedPrecisionPlugin(self._precision_flag, device)

        if self._amp_type_flag == AMPType.APEX:
            self._amp_level_flag = self._amp_level_flag or "O2"
            return ApexMixedPrecisionPlugin(self._amp_level_flag)

    raise RuntimeError("No precision set")
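A minimal usage sketch of how this resolution surfaces through the public Trainer API. The `pytorch_lightning.plugins` import path and the `trainer.precision_plugin` attribute name are assumptions about the Lightning version the connector above ships with, not part of the snippet itself.

# Hedged sketch, assuming a Lightning install that contains the connector shown above.
from pytorch_lightning import Trainer
from pytorch_lightning.plugins import DoublePrecisionPlugin

# precision=64 falls through the accelerator-specific branches above and
# resolves to DoublePrecisionPlugin, independent of the available hardware.
trainer = Trainer(accelerator="cpu", precision=64)
assert isinstance(trainer.precision_plugin, DoublePrecisionPlugin)  # attribute name assumed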
def test_tpu_invalid_raises():
    accelerator = TPUAccelerator(object(), TPUSpawnPlugin())
    with pytest.raises(ValueError, match="TPUAccelerator` can only be used with a `TPUPrecisionPlugin"):
        accelerator.setup(object())

    accelerator = TPUAccelerator(TPUPrecisionPlugin(), object())
    with pytest.raises(
        ValueError, match="TPUAccelerator` can only be used with a `SingleTPUPlugin` or `TPUSpawnPlugi"
    ):
        accelerator.setup(object())
def test_tpu_invalid_raises():
    strategy = TPUSpawnStrategy(accelerator=TPUAccelerator(), precision_plugin=PrecisionPlugin())
    with pytest.raises(ValueError, match="TPUAccelerator` can only be used with a `TPUPrecisionPlugin"):
        Trainer(strategy=strategy, devices=8)

    strategy = DDPStrategy(accelerator=TPUAccelerator(), precision_plugin=TPUPrecisionPlugin())
    with pytest.raises(ValueError, match="TPUAccelerator` can only be used with a `SingleTPUStrategy`"):
        Trainer(strategy=strategy, devices=8)
def test_tpu_invalid_raises():
    training_type_plugin = TPUSpawnStrategy(accelerator=TPUAccelerator(), precision_plugin=Mock())
    with pytest.raises(ValueError, match="TPUAccelerator` can only be used with a `TPUPrecisionPlugin"):
        Trainer(strategy=training_type_plugin)

    training_type_plugin = DDPStrategy(accelerator=TPUAccelerator(), precision_plugin=TPUPrecisionPlugin())
    with pytest.raises(ValueError, match="TPUAccelerator` can only be used with a `SingleTPUStrategy`"):
        Trainer(strategy=training_type_plugin)
def test_tpu_invalid_raises_set_precision_with_strategy():
    accelerator = TPUAccelerator()
    training_type_plugin = TPUSpawnStrategy(accelerator=accelerator, precision_plugin=object())
    with pytest.raises(ValueError, match="`TPUAccelerator` can only be used with a `TPUPrecisionPlugin`"):
        Trainer(strategy=training_type_plugin)

    accelerator = TPUAccelerator()
    training_type_plugin = DDPStrategy(accelerator=accelerator, precision_plugin=TPUPrecisionPlugin())
    with pytest.raises(
        ValueError, match="The `TPUAccelerator` can only be used with a `SingleTPUStrategy` or `TPUSpawnStrategy"
    ):
        Trainer(strategy=training_type_plugin)
def select_precision_plugin(self) -> PrecisionPlugin:
    # set precision type
    self.amp_type = AMPType.from_str(self.amp_type)

    # validation for all plugins
    if self.amp_level is not None and self.amp_type != AMPType.APEX:
        raise MisconfigurationException(
            f"You have asked for `amp_level={self.amp_level!r}` but it's only supported with `amp_backend='apex'`."
        )

    if self.use_ipu:
        if self.precision not in (16, 32):
            raise MisconfigurationException(
                f"`Trainer(accelerator='ipu', precision={self.precision!r})` is not supported."
            )
        return IPUPrecisionPlugin(self.precision)
    if self.use_tpu:
        if self.precision == 32:
            return TPUPrecisionPlugin()
        elif self.precision == 64:
            raise MisconfigurationException(
                "`Trainer(accelerator='tpu', precision=64)` is not implemented."
                " Please, open an issue in `https://github.com/PyTorchLightning/pytorch-lightning/issues`"
                " requesting this feature."
            )
        elif self.precision in (16, "bf16"):
            if self.precision == 16:
                # this is not deprecated to ease transition between accelerator environments
                rank_zero_warn(
                    f"You passed `Trainer(accelerator='tpu', precision=16)` but {self.amp_type.value} AMP"
                    f" is not supported with TPUs. Using `precision='bf16'` instead."
                )
            return TPUBf16PrecisionPlugin()
    if self._distrib_type == _StrategyType.DEEPSPEED or isinstance(self._training_type_plugin, DeepSpeedStrategy):
        return DeepSpeedPrecisionPlugin(self.precision, self.amp_type, self.amp_level)

    if self.precision == 32:
        return PrecisionPlugin()
    if self.precision == 64:
        return DoublePrecisionPlugin()

    # maybe convert the precision value
    if self.precision == 16 and self.use_cpu:
        if self.amp_type == AMPType.APEX:
            # apex was explicitly passed, not a good idea to silently switch to native AMP
            raise MisconfigurationException(
                "You passed `Trainer(accelerator='cpu', precision=16, amp_type='apex')`"
                " but apex AMP not supported on CPU."
            )
        # this automatic switch is to ease transition between accelerator environments
        rank_zero_warn(
            "You passed `Trainer(accelerator='cpu', precision=16)` but native AMP is not supported on CPU."
            " Using `precision='bf16'` instead."
        )
        self.precision = "bf16"

    if self.precision in (16, "bf16"):
        if self.precision == "bf16" and self.amp_type != AMPType.NATIVE:
            raise MisconfigurationException(
                f"You passed `Trainer(amp_type={self.amp_type.value!r}, precision='bf16')` but it's not supported."
                " Try using `amp_type='native'` instead."
            )

        rank_zero_info(
            f"Using 16bit {self.amp_type.value} Automatic Mixed Precision (AMP)"
            if self.precision == 16
            else "Using bfloat16 Automatic Mixed Precision (AMP)"
        )

        if self.amp_type == AMPType.NATIVE:
            device = "cpu" if self.use_cpu else "cuda"

            if self._is_sharded_training_type:
                return ShardedNativeMixedPrecisionPlugin(self.precision, device)
            if self._is_fully_sharded_training_type:
                return FullyShardedNativeMixedPrecisionPlugin(self.precision, device)
            return NativeMixedPrecisionPlugin(self.precision, device)

        if self.amp_type == AMPType.APEX:
            if self._is_sharded_training_type or self._is_fully_sharded_training_type:
                raise MisconfigurationException(
                    "Sharded plugins are not supported with apex, please switch to `amp_backend='native'`."
                )
            self.amp_level = self.amp_level or "O2"
            return ApexMixedPrecisionPlugin(self.amp_level)

    raise RuntimeError("No precision set")
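The `amp_level` validation at the top of this older `select_precision_plugin` can be exercised through the Trainer constructor. A hedged sketch, assuming the `MisconfigurationException` import path used by this Lightning version:

# Hedged sketch: passing amp_level with the native backend is expected to hit the
# first validation branch above and raise MisconfigurationException.
import pytest
from pytorch_lightning import Trainer
from pytorch_lightning.utilities.exceptions import MisconfigurationException

with pytest.raises(MisconfigurationException, match="only supported with `amp_backend='apex'`"):
    Trainer(amp_backend="native", amp_level="O2", precision=16)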