def _check_and_init_precision(self) -> PrecisionPlugin:
    self._validate_precision_choice()
    if isinstance(self._precision_plugin_flag, PrecisionPlugin):
        return self._precision_plugin_flag

    if isinstance(self.accelerator, IPUAccelerator):
        return IPUPrecisionPlugin(self._precision_flag)  # type: ignore
    if isinstance(self.accelerator, HPUAccelerator):
        return HPUPrecisionPlugin(self._precision_flag)  # type: ignore
    if isinstance(self.accelerator, TPUAccelerator):
        if self._precision_flag == 32:
            return TPUPrecisionPlugin()
        elif self._precision_flag in (16, "bf16"):
            if self._precision_flag == 16:
                rank_zero_warn(
                    "You passed `Trainer(accelerator='tpu', precision=16)` but AMP"
                    " is not supported with TPUs. Using `precision='bf16'` instead."
                )
            return TPUBf16PrecisionPlugin()
    if isinstance(self.strategy, DeepSpeedStrategy):
        return DeepSpeedPrecisionPlugin(
            self._precision_flag, self._amp_type_flag, self._amp_level_flag  # type: ignore
        )

    if self._precision_flag == 32:
        return PrecisionPlugin()
    if self._precision_flag == 64:
        return DoublePrecisionPlugin()

    if self._precision_flag == 16 and self._accelerator_flag == "cpu":
        rank_zero_warn(
            "You passed `Trainer(accelerator='cpu', precision=16)` but native AMP is not supported on CPU."
            " Using `precision='bf16'` instead."
        )
        self._precision_flag = "bf16"

    if self._precision_flag in (16, "bf16"):
        rank_zero_info(
            f"Using 16bit {self._amp_type_flag.value} Automatic Mixed Precision (AMP)"  # type: ignore
            if self._precision_flag == 16
            else "Using bfloat16 Automatic Mixed Precision (AMP)"
        )

        if self._amp_type_flag == AMPType.NATIVE:
            device = "cpu" if self._accelerator_flag == "cpu" else "cuda"
            if isinstance(self.strategy, (DDPShardedStrategy, DDPSpawnShardedStrategy)):
                return ShardedNativeMixedPrecisionPlugin(self._precision_flag, device)
            if isinstance(self.strategy, DDPFullyShardedStrategy):
                return FullyShardedNativeMixedPrecisionPlugin(self._precision_flag, device)
            return NativeMixedPrecisionPlugin(self._precision_flag, device)

        if self._amp_type_flag == AMPType.APEX:
            self._amp_level_flag = self._amp_level_flag or "O2"
            return ApexMixedPrecisionPlugin(self._amp_level_flag)

    raise RuntimeError("No precision set")
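For orientation, a minimal usage sketch of how these flags are driven from the public API (an assumption based on recent Lightning versions, where the resolved plugin is exposed as `trainer.precision_plugin`): requesting bf16 on CPU should take the native-AMP branch above and bind the plugin to the "cpu" device.

from pytorch_lightning import Trainer
from pytorch_lightning.plugins import NativeMixedPrecisionPlugin

# bf16 on CPU resolves through _check_and_init_precision to a CPU-bound native AMP plugin
trainer = Trainer(accelerator="cpu", devices=1, precision="bf16")
assert isinstance(trainer.precision_plugin, NativeMixedPrecisionPlugin)
assert trainer.precision_plugin.device == "cpu"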
def select_precision_plugin(self) -> PrecisionPlugin:
    # set precision type
    self.amp_type = AMPType.from_str(self.amp_type)

    if self.use_ipu:
        return IPUPrecisionPlugin(self.precision)

    if self._distrib_type == DistributedType.DEEPSPEED or isinstance(self._training_type_plugin, DeepSpeedPlugin):
        return DeepSpeedPrecisionPlugin(self.precision)

    if self.precision == 32:
        return PrecisionPlugin()
    if self.precision == 64:
        return DoublePrecisionPlugin()
    if self.precision in (16, "bf16"):
        if self.use_tpu:
            return TPUHalfPrecisionPlugin()

        if self.amp_type == AMPType.NATIVE:
            if not _NATIVE_AMP_AVAILABLE:
                msg = (
                    "You have asked for native AMP but your PyTorch version does not support it."
                    " Consider upgrading with `pip install torch>=1.6`."
                )
                if _APEX_AVAILABLE:
                    self.amp_type = AMPType.APEX
                    msg += " We will attempt to use NVIDIA Apex for this session."
                    rank_zero_warn(msg)
                else:
                    raise MisconfigurationException(msg)
            else:
                log.info(f"Using native {self.precision} bit Automatic Mixed Precision")
                if self._is_sharded_training_type:
                    return ShardedNativeMixedPrecisionPlugin(self.precision, use_cpu=self.use_cpu)
                if self._is_fully_sharded_training_type:
                    return FullyShardedNativeMixedPrecisionPlugin(self.precision, use_cpu=self.use_cpu)
                return NativeMixedPrecisionPlugin(self.precision, use_cpu=self.use_cpu)

        if self.amp_type == AMPType.APEX:
            if not _APEX_AVAILABLE:
                raise MisconfigurationException(
                    "You have asked for Apex AMP but you have not installed it yet."
                    " Install apex first using this guide: https://github.com/NVIDIA/apex#linux"
                )
            if self._is_sharded_training_type or self._is_fully_sharded_training_type:
                raise MisconfigurationException(
                    "Sharded Plugin is not supported with Apex AMP, please use native AMP for 16-bit precision."
                )
            log.info("Using APEX 16bit precision.")
            return ApexMixedPrecisionPlugin(self.amp_level)

    raise MisconfigurationException(
        f"Precision {self.precision} is invalid. Allowed precision values: {PrecisionType.supported_types()}"
    )
def test_cpu_amp_precision_context_manager(tmpdir):
    """Test that the autocast context manager is correctly set to CPU + bfloat16."""
    plugin = NativeMixedPrecisionPlugin("bf16", "cpu")
    assert plugin.device == "cpu"
    assert plugin.scaler is None
    context_manager = plugin.autocast_context_manager()
    assert isinstance(context_manager, torch.autocast)
    # check with str due to a bug upstream: https://github.com/pytorch/pytorch/issues/65786
    assert str(context_manager.fast_dtype) == str(torch.bfloat16)
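A companion sketch for the CUDA path (an assumption based on this plugin's constructor, reusing the imports of the test above): 16-bit native AMP is expected to attach a gradient scaler, whereas the bf16 case above must not.

def test_cuda_amp_precision_sets_scaler(tmpdir):
    """Sketch (assumed behaviour): 16-bit native AMP should create a torch.cuda.amp.GradScaler."""
    plugin = NativeMixedPrecisionPlugin(16, "cuda")
    assert plugin.device == "cuda"
    assert isinstance(plugin.scaler, torch.cuda.amp.GradScaler)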
def test_cpu_amp_precision_context_manager(tmpdir):
    """Test that the autocast context manager is correctly set to CPU + bfloat16 and that no scaler is set."""
    plugin = NativeMixedPrecisionPlugin(precision="bf16", use_cpu=True)
    assert plugin.use_cpu
    assert not hasattr(plugin, "scaler")
    context_manager = plugin.autocast_context_manager()
    assert isinstance(context_manager, torch.cpu.amp.autocast)
    assert context_manager.dtype == torch.bfloat16
def select_precision_plugin(self) -> PrecisionPlugin:
    # set precision type
    self.amp_type = AMPType.from_str(self.amp_type)

    if self.use_ipu:
        return IPUPrecisionPlugin(self.precision)

    if self._distrib_type == DistributedType.DEEPSPEED or isinstance(self._training_type_plugin, DeepSpeedPlugin):
        return DeepSpeedPrecisionPlugin(self.precision)

    if self.precision == 32:
        return PrecisionPlugin()
    if self.precision == 64:
        return DoublePrecisionPlugin()
    if self.precision == 16:
        if self.use_tpu:
            return TPUHalfPrecisionPlugin()

        if self.amp_type == AMPType.NATIVE:
            if self.use_cpu:
                raise MisconfigurationException(
                    "You have asked for native AMP on CPU, but AMP is only available on GPU."
                )
            if not _NATIVE_AMP_AVAILABLE:
                msg = (
                    "You have asked for native AMP but your PyTorch version does not support it."
                    " Consider upgrading with `pip install torch>=1.6`."
                )
                if _APEX_AVAILABLE:
                    self.amp_type = AMPType.APEX
                    msg += " We will attempt to use NVIDIA Apex for this session."
                    rank_zero_warn(msg)
                else:
                    raise MisconfigurationException(msg)
            else:
                log.info("Using native 16bit precision.")
                if self._is_sharded_training_type:
                    return ShardedNativeMixedPrecisionPlugin()
                if self._is_fully_sharded_training_type:
                    return FullyShardedNativeMixedPrecisionPlugin()
                return NativeMixedPrecisionPlugin()

        if self.amp_type == AMPType.APEX:
            if not _APEX_AVAILABLE:
                raise MisconfigurationException(
                    "You have asked for Apex AMP but you have not installed it yet."
                    " Install apex first using this guide: https://github.com/NVIDIA/apex#linux"
                )
            if self._is_sharded_training_type or self._is_fully_sharded_training_type:
                raise MisconfigurationException(
                    "Sharded Plugin is not supported with Apex AMP, please use native AMP for 16-bit precision."
                )
            log.info("Using APEX 16bit precision.")
            return ApexMixedPrecisionPlugin(self.amp_level)

    raise NotImplementedError("We only support precisions 64, 32 and 16!")
def select_precision_plugin(self) -> PrecisionPlugin:
    if self.precision == 32:
        self.amp_type = None
        return PrecisionPlugin()

    elif self.precision == 16:
        if self.on_tpu:
            return TPUHalfPrecisionPlugin()

        if self.amp_type == "native":
            if not _NATIVE_AMP_AVAILABLE:
                rank_zero_warn(
                    "You have asked for native AMP but your PyTorch version does not support it."
                    " Consider upgrading with `pip install torch>=1.6`."
                    " We will attempt to use NVIDIA Apex for this session."
                )
                if not _APEX_AVAILABLE and self.on_cpu:
                    raise MisconfigurationException(
                        "You have asked for native AMP on CPU, but AMP is only available on GPU."
                    )
                self.amp_type = "apex"
            elif self.on_cpu:
                raise MisconfigurationException(
                    "You have asked for native AMP on CPU, but AMP is only available on GPU."
                )
            else:
                log.info("Using native 16bit precision.")
                if isinstance(self.training_type_plugin, (DDPShardedPlugin, DDPSpawnShardedPlugin)):
                    return ShardedNativeMixedPrecisionPlugin()
                self.amp_type = AMPType.NATIVE
                return NativeMixedPrecisionPlugin()

        if self.amp_type == "apex":
            if not _APEX_AVAILABLE:
                rank_zero_warn(
                    "You have asked for Apex AMP but you have not installed it yet."
                    " Install apex first using this guide: https://github.com/NVIDIA/apex#linux"
                )
            else:
                if isinstance(self.training_type_plugin, (DDPShardedPlugin, DDPSpawnShardedPlugin)):
                    raise MisconfigurationException(
                        "Sharded Plugin is not supported with Apex AMP, "
                        "please use native AMP for 16-bit precision."
                    )
                log.info("Using APEX 16bit precision.")
                self.amp_type = AMPType.APEX
                return ApexMixedPrecisionPlugin(self.amp_level)
    else:
        raise NotImplementedError("We only support precisions 32 and 16!")
def select_precision_plugin(self) -> PrecisionPlugin:
    # set precision type
    self.amp_type = AMPType.from_str(self.amp_type)

    # validation for all plugins
    if self.amp_level is not None and self.amp_type != AMPType.APEX:
        raise MisconfigurationException(
            f"You have asked for `amp_level={self.amp_level!r}` but it's only supported with `amp_backend='apex'`."
        )

    if self.use_ipu:
        if self.precision not in (16, 32):
            raise MisconfigurationException(
                f"`Trainer(accelerator='ipu', precision={self.precision!r})` is not supported."
            )
        return IPUPrecisionPlugin(self.precision)
    if self.use_tpu:
        if self.precision == 32:
            return TPUPrecisionPlugin()
        elif self.precision == 64:
            raise MisconfigurationException(
                "`Trainer(accelerator='tpu', precision=64)` is not implemented."
                " Please, open an issue in `https://github.com/PyTorchLightning/pytorch-lightning/issues`"
                " requesting this feature."
            )
        elif self.precision in (16, "bf16"):
            if self.precision == 16:
                # this is not deprecated to ease transition between accelerator environments
                rank_zero_warn(
                    f"You passed `Trainer(accelerator='tpu', precision=16)` but {self.amp_type.value} AMP"
                    f" is not supported with TPUs. Using `precision='bf16'` instead."
                )
            return TPUBf16PrecisionPlugin()
    if self._distrib_type == _StrategyType.DEEPSPEED or isinstance(self._training_type_plugin, DeepSpeedStrategy):
        return DeepSpeedPrecisionPlugin(self.precision, self.amp_type, self.amp_level)

    if self.precision == 32:
        return PrecisionPlugin()
    if self.precision == 64:
        return DoublePrecisionPlugin()

    # maybe convert the precision value
    if self.precision == 16 and self.use_cpu:
        if self.amp_type == AMPType.APEX:
            # apex was explicitly passed, not a good idea to silently switch to native AMP
            raise MisconfigurationException(
                "You passed `Trainer(accelerator='cpu', precision=16, amp_type='apex')`"
                " but apex AMP is not supported on CPU."
            )
        # this automatic switch is to ease transition between accelerator environments
        rank_zero_warn(
            "You passed `Trainer(accelerator='cpu', precision=16)` but native AMP is not supported on CPU."
            " Using `precision='bf16'` instead."
        )
        self.precision = "bf16"

    if self.precision in (16, "bf16"):
        if self.precision == "bf16" and self.amp_type != AMPType.NATIVE:
            raise MisconfigurationException(
                f"You passed `Trainer(amp_type={self.amp_type.value!r}, precision='bf16')` but it's not supported."
                " Try using `amp_type='native'` instead."
            )

        rank_zero_info(
            f"Using 16bit {self.amp_type.value} Automatic Mixed Precision (AMP)"
            if self.precision == 16
            else "Using bfloat16 Automatic Mixed Precision (AMP)"
        )

        if self.amp_type == AMPType.NATIVE:
            device = "cpu" if self.use_cpu else "cuda"

            if self._is_sharded_training_type:
                return ShardedNativeMixedPrecisionPlugin(self.precision, device)
            if self._is_fully_sharded_training_type:
                return FullyShardedNativeMixedPrecisionPlugin(self.precision, device)
            return NativeMixedPrecisionPlugin(self.precision, device)

        if self.amp_type == AMPType.APEX:
            if self._is_sharded_training_type or self._is_fully_sharded_training_type:
                raise MisconfigurationException(
                    "Sharded plugins are not supported with apex, please switch to `amp_backend='native'`."
                )
            self.amp_level = self.amp_level or "O2"
            return ApexMixedPrecisionPlugin(self.amp_level)

    raise RuntimeError("No precision set")
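For reference, a hedged sketch of the Trainer flags that feed this selection in this API generation (an assumption, not taken from the code above; the apex case needs a CUDA machine with apex installed, and bf16 needs torch>=1.10):

from pytorch_lightning import Trainer

# 16-bit with NVIDIA Apex at optimization level O2 -> ApexMixedPrecisionPlugin
apex_trainer = Trainer(gpus=1, precision=16, amp_backend="apex", amp_level="O2")

# bfloat16 with native AMP (the only backend that supports bf16 here) -> NativeMixedPrecisionPlugin
bf16_trainer = Trainer(precision="bf16", amp_backend="native")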
def configure_optimizers(self):
    optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
    return optimizer


dataset = MNIST(os.getcwd(), download=False, transform=transforms.ToTensor())
train_loader = DataLoader(dataset)

# init model
autoencoder = LitAutoEncoder()

# most basic trainer, uses good defaults (auto-tensorboard, checkpoints, logs, and more)
parallel_devices = [torch.device(i) for i in range(torch.cuda.device_count())]

acc = GPUAccelerator(
    precision_plugin=NativeMixedPrecisionPlugin(),
    training_type_plugin=DDPPlugin(parallel_devices=parallel_devices, cluster_environment=LSFEnvironment()),
)

targs = {
    'max_epochs': 1,
    'num_nodes': 2,
    'accumulate_grad_batches': 1,
    'gpus': 6,
    'accelerator': acc,
    'limit_train_batches': 10,
    'limit_val_batches': 5,
    'log_every_n_steps': 1,
}
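The snippet stops at the argument dict; a likely continuation (an assumption, using the standard Trainer API) unpacks it and starts training:

trainer = Trainer(**targs)  # assumed continuation: consume the args assembled above
trainer.fit(autoencoder, train_loader)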
def test_cpu_amp_precision_throws_error(tmpdir):
    with pytest.raises(
        MisconfigurationException,
        match="To use native AMP on CPU, install PyTorch 1.10 or later.",
    ):
        NativeMixedPrecisionPlugin(use_cpu=True)