def select_precision_plugin(self) -> PrecisionPlugin:
        # set precision type
        self.amp_type = AMPType.from_str(self.amp_type)

        if self.use_ipu:
            return IPUPrecisionPlugin(self.precision)

        if self._distrib_type == DistributedType.DEEPSPEED or isinstance(
                self._training_type_plugin, DeepSpeedPlugin):
            return DeepSpeedPrecisionPlugin(self.precision)

        if self.precision == 32:
            return PrecisionPlugin()
        if self.precision == 64:
            return DoublePrecisionPlugin()
        if self.precision in (16, "bf16"):
            if self.use_tpu:
                return TPUHalfPrecisionPlugin()

            if self.amp_type == AMPType.NATIVE:
                if not _NATIVE_AMP_AVAILABLE:
                    msg = (
                        "You have asked for native AMP but your PyTorch version does not support it."
                        " Consider upgrading with `pip install torch>=1.6`.")
                    if _APEX_AVAILABLE:
                        self.amp_type = AMPType.APEX
                        msg += " We will attempt to use NVIDIA Apex for this session."
                        rank_zero_warn(msg)
                    else:
                        raise MisconfigurationException(msg)
                else:
                    log.info(
                        f"Using native {self.precision} bit Automatic Mixed Precision"
                    )
                    if self._is_sharded_training_type:
                        return ShardedNativeMixedPrecisionPlugin(
                            self.precision, use_cpu=self.use_cpu)
                    if self._is_fully_sharded_training_type:
                        return FullyShardedNativeMixedPrecisionPlugin(
                            self.precision, use_cpu=self.use_cpu)
                    return NativeMixedPrecisionPlugin(self.precision,
                                                      use_cpu=self.use_cpu)

            if self.amp_type == AMPType.APEX:
                if not _APEX_AVAILABLE:
                    raise MisconfigurationException(
                        "You have asked for Apex AMP but you have not installed it yet."
                        " Install apex first using this guide: https://github.com/NVIDIA/apex#linux"
                    )
                if self._is_sharded_training_type or self._is_fully_sharded_training_type:
                    raise MisconfigurationException(
                        "Sharded Plugin is not supported with Apex AMP, please using native AMP for 16-bit precision."
                    )
                log.info("Using APEX 16bit precision.")
                return ApexMixedPrecisionPlugin(self.amp_level)

        raise MisconfigurationException(
            f"Precision {self.precision} is invalid. Allowed precision values: {PrecisionType.supported_types()}"
        )
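A minimal usage sketch of how these branches are reached, assuming the PyTorch Lightning 1.x Trainer flags (precision, gpus, amp_backend, amp_level); the GPU lines require CUDA to be available at runtime.

from pytorch_lightning import Trainer

trainer = Trainer(precision=32)  # -> PrecisionPlugin (full precision, the default)
trainer = Trainer(precision=64)  # -> DoublePrecisionPlugin
trainer = Trainer(precision=16, gpus=1)  # -> NativeMixedPrecisionPlugin (or a sharded variant)
trainer = Trainer(precision=16, gpus=1, amp_backend="apex", amp_level="O2")  # -> ApexMixedPrecisionPlugin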
    def select_precision_plugin(self) -> PrecisionPlugin:
        # set precision type
        self.amp_type = AMPType.from_str(self.amp_type)

        if self.use_ipu:
            return IPUPrecisionPlugin(self.precision)

        if self._distrib_type == DistributedType.DEEPSPEED or isinstance(
                self._training_type_plugin, DeepSpeedPlugin):
            return DeepSpeedPrecisionPlugin(self.precision)

        if self.precision == 32:
            return PrecisionPlugin()
        if self.precision == 64:
            return DoublePrecisionPlugin()
        if self.precision == 16:
            if self.use_tpu:
                return TPUHalfPrecisionPlugin()

            if self.amp_type == AMPType.NATIVE:
                if self.use_cpu:
                    raise MisconfigurationException(
                        "You have asked for native AMP on CPU, but AMP is only available on GPU."
                    )
                if not _NATIVE_AMP_AVAILABLE:
                    msg = (
                        "You have asked for native AMP but your PyTorch version does not support it."
                        " Consider upgrading with `pip install torch>=1.6`.")
                    if _APEX_AVAILABLE:
                        self.amp_type = AMPType.APEX
                        msg += " We will attempt to use NVIDIA Apex for this session."
                        rank_zero_warn(msg)
                    else:
                        raise MisconfigurationException(msg)
                else:
                    log.info("Using native 16bit precision.")
                    if self._is_sharded_training_type:
                        return ShardedNativeMixedPrecisionPlugin()
                    if self._is_fully_sharded_training_type:
                        return FullyShardedNativeMixedPrecisionPlugin()
                    return NativeMixedPrecisionPlugin()

            if self.amp_type == AMPType.APEX:
                if not _APEX_AVAILABLE:
                    raise MisconfigurationException(
                        "You have asked for Apex AMP but you have not installed it yet."
                        " Install apex first using this guide: https://github.com/NVIDIA/apex#linux"
                    )
                if self._is_sharded_training_type or self._is_fully_sharded_training_type:
                    raise MisconfigurationException(
                        "Sharded Plugin is not supported with Apex AMP,"
                        " please using native AMP for 16-bit precision.")
                log.info("Using APEX 16bit precision.")
                return ApexMixedPrecisionPlugin(self.amp_level)

        raise NotImplementedError("We only support precisions 64, 32 and 16!")
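The native-AMP-to-Apex fallback in the branch above boils down to a simple pattern. The helper below is only an illustrative sketch of that pattern; its name and signature are not part of the library.

import warnings

def resolve_amp_backend(native_available: bool, apex_available: bool) -> str:
    # Prefer native AMP; fall back to Apex when it is installed; otherwise fail
    # loudly instead of silently training in full precision.
    if native_available:
        return "native"
    if apex_available:
        warnings.warn("Native AMP is unavailable; falling back to NVIDIA Apex for this session.")
        return "apex"
    raise RuntimeError("16-bit precision was requested, but neither native AMP nor Apex is available.")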
Example #3
    def _check_config_and_set_final_flags(
        self,
        strategy: Optional[Union[str, Strategy]],
        accelerator: Optional[Union[str, Accelerator]],
        precision: Union[int, str],
        plugins: Optional[Union[PLUGIN_INPUT, List[PLUGIN_INPUT]]],
        amp_type: str,
        amp_level: Optional[str],
        sync_batchnorm: bool,
    ) -> None:
        """This method checks:

        1. strategy: strategy, accelerator and plugin can all be set to strategies
        2. accelerator: if the value of the accelerator argument is a type of accelerator (instance or string),
            set self._accelerator_flag accordingly. If the value is strategy related (instance or string),
            it gets handled by 1.
        3. precision: The final value of the precision flag may be determined either by the precision argument or
            by a plugin instance.
        4. plugins: a plugin could occur as a value of the strategy argument (handled by 1), or the precision
            argument (handled by 3). We also extract the CheckpointIO and ClusterEnvironment plugins.
        """
        if plugins is not None:
            plugins = [plugins] if not isinstance(plugins, list) else plugins

        if strategy is not None:
            self._strategy_flag = strategy
            if strategy == "ddp_cpu":
                raise MisconfigurationException(
                    "`Trainer(strategy='ddp_cpu')` is not a valid strategy,"
                    " you can use `Trainer(strategy='ddp'|'ddp_spawn', accelerator='cpu')` instead."
                )
            if strategy == "tpu_spawn":
                raise MisconfigurationException(
                    "`Trainer(strategy='tpu_spawn')` is not a valid strategy,"
                    " you can use `Trainer(strategy='ddp_spawn', accelerator='tpu')` instead."
                )
            # handle duplications and conflict
            if isinstance(accelerator, Strategy) and strategy != accelerator:
                raise MisconfigurationException(
                    f"Incompatible values set in `strategy` and `accelerator` arguments."
                    f"Received both strategy={strategy} and accelerator={accelerator}"
                )
            if (
                isinstance(accelerator, str)
                and accelerator in self._registered_strategies
                and strategy != accelerator
            ):
                raise MisconfigurationException(
                    f"strategy {strategy} already set through the `strategy` flag,"
                    f" but {accelerator} was also passed through the `accelerator` flag."
                )
            if plugins:
                for plugin in plugins:
                    if isinstance(plugin, Strategy):
                        raise MisconfigurationException(
                            f"You have passed `Trainer(strategy={strategy})`"
                            f" and you can only specify one strategy, but you have passed {plugin} as a plugin."
                        )
                    if isinstance(plugin, str) and plugin in self._registered_strategies:
                        raise MisconfigurationException(
                            f"You have passed `Trainer(strategy={strategy})`"
                            f" and you can only specify one strategy, but you have passed {plugin} as a plugin."
                        )

        if accelerator is not None:
            if accelerator in self._accelerator_types or accelerator == "auto" or isinstance(
                    accelerator, Accelerator):
                self._accelerator_flag = accelerator
            elif accelerator in self._registered_strategies or isinstance(
                    accelerator, Strategy):
                rank_zero_deprecation(
                    f"Passing `Trainer(accelerator={accelerator!r})` has been deprecated"
                    f" in v1.5 and will be removed in v1.7. Use `Trainer(strategy={accelerator!r})` instead."
                )
                self._strategy_flag = accelerator
            elif accelerator == "ddp_cpu" and not self._strategy_flag:
                self._strategy_flag = accelerator

        if precision is not None:
            if str(precision) not in self._precision_types:
                raise MisconfigurationException(
                    f"Precision {repr(precision)} is invalid. Allowed precision values: {self._precision_types}"
                )
            self._precision_flag = precision

        if plugins:
            plugins_flags_types: Dict[str, int] = Counter()
            for plugin in plugins:
                if isinstance(plugin, Strategy) or (
                        isinstance(plugin, str) and plugin in self._registered_strategies):
                    self._strategy_flag = plugin
                    rank_zero_deprecation(
                        f"Passing {plugin} `strategy` to the `plugins` flag in Trainer has been deprecated"
                        f" in v1.5 and will be removed in v1.7. Use `Trainer(strategy={plugin})` instead."
                    )
                    plugins_flags_types[Strategy.__name__] += 1

                elif isinstance(plugin, PrecisionPlugin):
                    self._precision_plugin_flag = plugin
                    plugins_flags_types[PrecisionPlugin.__name__] += 1
                elif isinstance(plugin, CheckpointIO):
                    self.checkpoint_io = plugin
                    plugins_flags_types[CheckpointIO.__name__] += 1
                elif isinstance(plugin, ClusterEnvironment):
                    self._cluster_environment_flag = plugin
                    plugins_flags_types[ClusterEnvironment.__name__] += 1
                elif isinstance(plugin, LayerSync):
                    if sync_batchnorm and not isinstance(
                            plugin, NativeSyncBatchNorm):
                        raise MisconfigurationException(
                            f"You set `Trainer(sync_batchnorm=True)` and provided a `{plugin.__class__.__name__}`"
                            " plugin, but this is not allowed. Choose one or the other."
                        )
                    self._layer_sync = plugin
                    plugins_flags_types[NativeSyncBatchNorm.__name__] += 1
                else:
                    raise MisconfigurationException(
                        f"Found invalid type for plugin {plugin}. Expected one of: PrecisionPlugin, "
                        "CheckpointIO, ClusterEnviroment, LayerSync, or Strategy."
                    )

            duplicated_plugin_key = [
                k for k, v in plugins_flags_types.items() if v > 1
            ]
            if duplicated_plugin_key:
                raise MisconfigurationException(
                    f"Received multiple values for {', '.join(duplicated_plugin_key)} flags in `plugins`."
                    " Expected one value for each type at most.")

        # handle the case when the user passes in a strategy instance which has an accelerator, precision,
        # checkpoint io or cluster env set up
        # TODO: @awaelchli improve the error messages below
        if self._strategy_flag and isinstance(self._strategy_flag, Strategy):
            if self._strategy_flag._accelerator:
                if self._accelerator_flag:
                    raise MisconfigurationException(
                        "accelerator set through both strategy class and accelerator flag, choose one"
                    )
                else:
                    self._accelerator_flag = self._strategy_flag._accelerator
            if self._strategy_flag._precision_plugin:
                # [RFC] handle precision plugin set up conflict?
                if self._precision_plugin_flag:
                    raise MisconfigurationException(
                        "precision set through both strategy class and plugins, choose one"
                    )
                else:
                    self._precision_plugin_flag = self._strategy_flag._precision_plugin
            if self._strategy_flag._checkpoint_io:
                if self.checkpoint_io:
                    raise MisconfigurationException(
                        "checkpoint_io set through both strategy class and plugins, choose one"
                    )
                else:
                    self.checkpoint_io = self._strategy_flag._checkpoint_io
            if getattr(self._strategy_flag, "cluster_environment", None):
                if self._cluster_environment_flag:
                    raise MisconfigurationException(
                        "cluster_environment set through both strategy class and plugins, choose one"
                    )
                else:
                    self._cluster_environment_flag = getattr(
                        self._strategy_flag, "cluster_environment")

            if hasattr(self._strategy_flag, "parallel_devices"):
                if self._strategy_flag.parallel_devices:
                    if self._strategy_flag.parallel_devices[0].type == "cpu":
                        if self._accelerator_flag and self._accelerator_flag not in (
                                "auto", "cpu"):
                            raise MisconfigurationException(
                                f"CPU parallel_devices set through {self._strategy_flag.__class__.__name__} class,"
                                f" but accelerator set to {self._accelerator_flag}, please choose one device type"
                            )
                        self._accelerator_flag = "cpu"
                    if self._strategy_flag.parallel_devices[0].type == "cuda":
                        if self._accelerator_flag and self._accelerator_flag not in (
                                "auto", "gpu"):
                            raise MisconfigurationException(
                                f"GPU parallel_devices set through {self._strategy_flag.__class__.__name__} class,"
                                f" but accelerator set to {self._accelerator_flag}, please choose one device type"
                            )
                        self._accelerator_flag = "gpu"
                    self._parallel_devices = self._strategy_flag.parallel_devices

        amp_type = amp_type if isinstance(amp_type, str) else None
        self._amp_type_flag = AMPType.from_str(amp_type)

        if amp_level is not None and self._amp_type_flag != AMPType.APEX:
            raise MisconfigurationException(
                f"You have asked for `amp_level={amp_level!r}` but it's only supported with `amp_backend='apex'`."
            )
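The duplicate-plugin check near the middle of the method counts plugins by category and rejects any category passed more than once. A stripped-down sketch of that pattern (with an illustrative helper name, not the library's own):

from collections import Counter
from typing import List

def find_duplicated_plugin_types(plugins: List[object]) -> List[str]:
    # Count plugins by their concrete type and report every type seen more than once.
    counts = Counter(type(plugin).__name__ for plugin in plugins)
    return [name for name, count in counts.items() if count > 1]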
    def select_precision_plugin(self) -> PrecisionPlugin:
        # set precision type
        self.amp_type = AMPType.from_str(self.amp_type)

        # validation for all plugins
        if self.amp_level is not None and self.amp_type != AMPType.APEX:
            raise MisconfigurationException(
                f"You have asked for `amp_level={self.amp_level!r}` but it's only supported with `amp_backend='apex'`."
            )

        if self.use_ipu:
            if self.precision not in (16, 32):
                raise MisconfigurationException(
                    f"`Trainer(accelerator='ipu', precision={self.precision!r})` is not supported."
                )
            return IPUPrecisionPlugin(self.precision)
        if self.use_tpu:
            if self.precision == 32:
                return TPUPrecisionPlugin()
            elif self.precision == 64:
                raise MisconfigurationException(
                    "`Trainer(accelerator='tpu', precision=64)` is not implemented."
                    " Please, open an issue in `https://github.com/PyTorchLightning/pytorch-lightning/issues`"
                    " requesting this feature.")
            elif self.precision in (16, "bf16"):
                if self.precision == 16:
                    # this is not deprecated to ease transition between accelerator environments
                    rank_zero_warn(
                        f"You passed `Trainer(accelerator='tpu', precision=16)` but {self.amp_type.value} AMP"
                        f" is not supported with TPUs. Using `precision='bf16'` instead."
                    )
                return TPUBf16PrecisionPlugin()

        if self._distrib_type == _StrategyType.DEEPSPEED or isinstance(
                self._training_type_plugin, DeepSpeedStrategy):
            return DeepSpeedPrecisionPlugin(self.precision, self.amp_type,
                                            self.amp_level)

        if self.precision == 32:
            return PrecisionPlugin()
        if self.precision == 64:
            return DoublePrecisionPlugin()

        # maybe convert the precision value
        if self.precision == 16 and self.use_cpu:
            if self.amp_type == AMPType.APEX:
                # apex was explicitly passed, not a good idea to silently switch to native AMP
                raise MisconfigurationException(
                    "You passed `Trainer(accelerator='cpu', precision=16, amp_type='apex')`"
                    " but apex AMP not supported on CPU.")
            # this automatic switch is to ease transition between accelerator environments
            rank_zero_warn(
                "You passed `Trainer(accelerator='cpu', precision=16)` but native AMP is not supported on CPU."
                " Using `precision='bf16'` instead.")
            self.precision = "bf16"

        if self.precision in (16, "bf16"):
            if self.precision == "bf16" and self.amp_type != AMPType.NATIVE:
                raise MisconfigurationException(
                    f"You passed `Trainer(amp_type={self.amp_type.value!r}, precision='bf16')` but it's not supported."
                    " Try using `amp_type='native'` instead.")

            rank_zero_info(
                f"Using 16bit {self.amp_type.value} Automatic Mixed Precision (AMP)"
                if self.precision == 16
                else "Using bfloat16 Automatic Mixed Precision (AMP)")

            if self.amp_type == AMPType.NATIVE:
                device = "cpu" if self.use_cpu else "cuda"

                if self._is_sharded_training_type:
                    return ShardedNativeMixedPrecisionPlugin(
                        self.precision, device)
                if self._is_fully_sharded_training_type:
                    return FullyShardedNativeMixedPrecisionPlugin(
                        self.precision, device)
                return NativeMixedPrecisionPlugin(self.precision, device)

            if self.amp_type == AMPType.APEX:
                if self._is_sharded_training_type or self._is_fully_sharded_training_type:
                    raise MisconfigurationException(
                        "Sharded plugins are not supported with apex, please switch to `amp_backend='native'`."
                    )
                self.amp_level = self.amp_level or "O2"
                return ApexMixedPrecisionPlugin(self.amp_level)

        raise RuntimeError("No precision set")
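A usage sketch of the CPU branch above, assuming PyTorch Lightning 1.5+ with a PyTorch build that supports bfloat16: requesting 16-bit native AMP on a CPU accelerator is promoted to bf16 with a warning, while Apex on CPU is rejected outright.

from pytorch_lightning import Trainer

trainer = Trainer(accelerator="cpu", precision=16)      # warns, then runs with precision="bf16"
trainer = Trainer(accelerator="cpu", precision="bf16")  # uses bfloat16 mixed precision directly
# Trainer(accelerator="cpu", precision=16, amp_backend="apex")  # raises MisconfigurationException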
Example #5
    def _check_config_and_set_final_flags(
        self,
        strategy: Optional[Union[str, Strategy]],
        accelerator: Optional[Union[str, Accelerator]],
        precision: Union[int, str],
        plugins: Optional[Union[PLUGIN_INPUT, List[PLUGIN_INPUT]]],
        amp_type: str,
        amp_level: Optional[str],
        sync_batchnorm: bool,
    ) -> None:
        """This method checks:

        1. strategy: whether the strategy name is valid, and sets the internal flags if it is.
        2. accelerator: if the value of the accelerator argument is a type of accelerator (instance or string),
            set self._accelerator_flag accordingly.
        3. precision: The final value of the precision flag may be determined either by the precision argument or
            by a plugin instance.
        4. plugins: The list of plugins may contain a Precision plugin, CheckpointIO, ClusterEnvironment and others.
            Additionally, other flags such as `precision` or `sync_batchnorm` can populate the list with the
            corresponding plugin instances.
        """
        if plugins is not None:
            plugins = [plugins] if not isinstance(plugins, list) else plugins

        if isinstance(strategy, str):
            strategy = strategy.lower()

        if strategy is not None:
            self._strategy_flag = strategy
            if strategy == "ddp_cpu":
                raise MisconfigurationException(
                    "`Trainer(strategy='ddp_cpu')` is not a valid strategy,"
                    " you can use `Trainer(strategy='ddp'|'ddp_spawn', accelerator='cpu')` instead."
                )
            if strategy == "tpu_spawn":
                raise MisconfigurationException(
                    "`Trainer(strategy='tpu_spawn')` is not a valid strategy,"
                    " you can use `Trainer(strategy='ddp_spawn', accelerator='tpu')` instead."
                )

        if accelerator is not None:
            if accelerator in self._accelerator_types or accelerator == "auto" or isinstance(accelerator, Accelerator):
                self._accelerator_flag = accelerator

        if precision is not None:
            if str(precision) not in self._precision_types:
                raise MisconfigurationException(
                    f"Precision {repr(precision)} is invalid. Allowed precision values: {self._precision_types}"
                )
            self._precision_flag = precision

        if plugins:
            plugins_flags_types: Dict[str, int] = Counter()
            for plugin in plugins:
                if isinstance(plugin, PrecisionPlugin):
                    self._precision_plugin_flag = plugin
                    plugins_flags_types[PrecisionPlugin.__name__] += 1
                elif isinstance(plugin, CheckpointIO):
                    self.checkpoint_io = plugin
                    plugins_flags_types[CheckpointIO.__name__] += 1
                elif isinstance(plugin, ClusterEnvironment):
                    self._cluster_environment_flag = plugin
                    plugins_flags_types[ClusterEnvironment.__name__] += 1
                elif isinstance(plugin, LayerSync):
                    if sync_batchnorm and not isinstance(plugin, NativeSyncBatchNorm):
                        raise MisconfigurationException(
                            f"You set `Trainer(sync_batchnorm=True)` and provided a `{plugin.__class__.__name__}`"
                            " plugin, but this is not allowed. Choose one or the other."
                        )
                    self._layer_sync = plugin
                    plugins_flags_types[NativeSyncBatchNorm.__name__] += 1
                else:
                    raise MisconfigurationException(
                        f"Found invalid type for plugin {plugin}. Expected one of: PrecisionPlugin, "
                        "CheckpointIO, ClusterEnviroment, or LayerSync."
                    )

            duplicated_plugin_key = [k for k, v in plugins_flags_types.items() if v > 1]
            if duplicated_plugin_key:
                raise MisconfigurationException(
                    f"Received multiple values for {', '.join(duplicated_plugin_key)} flags in `plugins`."
                    " Expected one value for each type at most."
                )

        # handle the case when the user passes in a strategy instance which has an accelerator, precision,
        # checkpoint io or cluster env set up
        # TODO: @awaelchli improve the error messages below
        if self._strategy_flag and isinstance(self._strategy_flag, Strategy):
            if self._strategy_flag._accelerator:
                if self._accelerator_flag:
                    raise MisconfigurationException(
                        "accelerator set through both strategy class and accelerator flag, choose one"
                    )
                else:
                    self._accelerator_flag = self._strategy_flag._accelerator
            if self._strategy_flag._precision_plugin:
                # [RFC] handle precision plugin set up conflict?
                if self._precision_plugin_flag:
                    raise MisconfigurationException("precision set through both strategy class and plugins, choose one")
                else:
                    self._precision_plugin_flag = self._strategy_flag._precision_plugin
            if self._strategy_flag._checkpoint_io:
                if self.checkpoint_io:
                    raise MisconfigurationException(
                        "checkpoint_io set through both strategy class and plugins, choose one"
                    )
                else:
                    self.checkpoint_io = self._strategy_flag._checkpoint_io
            if getattr(self._strategy_flag, "cluster_environment", None):
                if self._cluster_environment_flag:
                    raise MisconfigurationException(
                        "cluster_environment set through both strategy class and plugins, choose one"
                    )
                else:
                    self._cluster_environment_flag = getattr(self._strategy_flag, "cluster_environment")

            if hasattr(self._strategy_flag, "parallel_devices"):
                if self._strategy_flag.parallel_devices:
                    if self._strategy_flag.parallel_devices[0].type == "cpu":
                        if self._accelerator_flag and self._accelerator_flag not in ("auto", "cpu"):
                            raise MisconfigurationException(
                                f"CPU parallel_devices set through {self._strategy_flag.__class__.__name__} class,"
                                f" but accelerator set to {self._accelerator_flag}, please choose one device type"
                            )
                        self._accelerator_flag = "cpu"
                    if self._strategy_flag.parallel_devices[0].type == "cuda":
                        if self._accelerator_flag and self._accelerator_flag not in ("auto", "gpu"):
                            raise MisconfigurationException(
                                f"GPU parallel_devices set through {self._strategy_flag.__class__.__name__} class,"
                                f" but accelerator set to {self._accelerator_flag}, please choose one device type"
                            )
                        self._accelerator_flag = "gpu"
                    self._parallel_devices = self._strategy_flag.parallel_devices

        amp_type = amp_type if isinstance(amp_type, str) else None
        self._amp_type_flag = AMPType.from_str(amp_type)

        if amp_level is not None and self._amp_type_flag != AMPType.APEX:
            raise MisconfigurationException(
                f"You have asked for `amp_level={amp_level!r}` but it's only supported with `amp_backend='apex'`."
            )
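A sketch of the strategy-instance conflict handling at the end of the method, assuming the PyTorch Lightning 1.6-style classes and a CUDA machine: setting the same component both on the strategy instance and through the top-level flags is rejected.

from pytorch_lightning import Trainer
from pytorch_lightning.plugins import NativeMixedPrecisionPlugin
from pytorch_lightning.strategies import DDPStrategy

strategy = DDPStrategy(precision_plugin=NativeMixedPrecisionPlugin(16, "cuda"))

# The strategy already carries a precision plugin, so passing another one through
# `plugins` would raise: "precision set through both strategy class and plugins, choose one".
# Trainer(strategy=strategy, plugins=[NativeMixedPrecisionPlugin(16, "cuda")])

# Letting the strategy instance supply the precision plugin on its own is fine.
trainer = Trainer(strategy=strategy, accelerator="gpu", devices=2)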