def select_precision_plugin(self) -> PrecisionPlugin:
        """Choose the precision plugin for the configured precision and AMP backend.

        ``self.amp_type`` is first normalised through ``AMPType.from_str``. IPU
        and DeepSpeed setups are served before the precision value itself is
        inspected; after that the choice depends on ``self.precision`` and, for
        ``16`` / ``"bf16"``, on the selected AMP backend.

        Returns:
            The precision plugin instance matching the current configuration.

        Raises:
            MisconfigurationException: If native AMP is requested but neither
                native AMP nor Apex is available, if Apex is requested but not
                installed, if Apex is combined with a sharded training type, or
                if no branch matches the configured precision / AMP backend.
        """
        self.amp_type = AMPType.from_str(self.amp_type)

        if self.use_ipu:
            return IPUPrecisionPlugin(self.precision)

        runs_deepspeed = (
            self._distrib_type == DistributedType.DEEPSPEED
            or isinstance(self._training_type_plugin, DeepSpeedPlugin)
        )
        if runs_deepspeed:
            return DeepSpeedPrecisionPlugin(self.precision)

        if self.precision == 32:
            return PrecisionPlugin()
        if self.precision == 64:
            return DoublePrecisionPlugin()

        if self.precision in (16, "bf16"):
            if self.use_tpu:
                return TPUHalfPrecisionPlugin()

            if self.amp_type == AMPType.NATIVE:
                if _NATIVE_AMP_AVAILABLE:
                    log.info(
                        f"Using native {self.precision} bit Automatic Mixed Precision"
                    )
                    # All three native plugins take the same arguments, so
                    # pick the class first and construct it once.
                    if self._is_sharded_training_type:
                        native_plugin_cls = ShardedNativeMixedPrecisionPlugin
                    elif self._is_fully_sharded_training_type:
                        native_plugin_cls = FullyShardedNativeMixedPrecisionPlugin
                    else:
                        native_plugin_cls = NativeMixedPrecisionPlugin
                    return native_plugin_cls(self.precision,
                                             use_cpu=self.use_cpu)

                # Native AMP is unavailable: fall back to Apex when it is
                # installed, otherwise give up with the same message.
                unavailable_msg = (
                    "You have asked for native AMP but your PyTorch version does not support it."
                    " Consider upgrading with `pip install torch>=1.6`.")
                if not _APEX_AVAILABLE:
                    raise MisconfigurationException(unavailable_msg)
                self.amp_type = AMPType.APEX
                unavailable_msg += " We will attempt to use NVIDIA Apex for this session."
                rank_zero_warn(unavailable_msg)

            # Reached either because Apex was requested directly or because
            # the native branch above switched the backend to Apex.
            if self.amp_type == AMPType.APEX:
                if not _APEX_AVAILABLE:
                    raise MisconfigurationException(
                        "You have asked for Apex AMP but you have not installed it yet."
                        " Install apex first using this guide: https://github.com/NVIDIA/apex#linux"
                    )
                uses_sharding = (self._is_sharded_training_type
                                 or self._is_fully_sharded_training_type)
                if uses_sharding:
                    raise MisconfigurationException(
                        "Sharded Plugin is not supported with Apex AMP, please using native AMP for 16-bit precision."
                    )
                log.info("Using APEX 16bit precision.")
                return ApexMixedPrecisionPlugin(self.amp_level)

        # No branch produced a plugin for this precision / backend combination.
        raise MisconfigurationException(
            f"Precision {self.precision} is invalid. Allowed precision values: {PrecisionType.supported_types()}"
        )
    def select_precision_plugin(self) -> PrecisionPlugin:
        """Return the precision plugin for the configured precision (64, 32 or 16).

        ``self.amp_type`` is normalised through ``AMPType.from_str`` first. IPU
        and DeepSpeed configurations are handled before the precision value is
        examined. For 16-bit precision the AMP backend decides which
        mixed-precision plugin is created.

        Returns:
            The precision plugin instance matching the current configuration.

        Raises:
            MisconfigurationException: If native AMP is requested on CPU, if
                native AMP is requested but neither native AMP nor Apex is
                available, if Apex is requested but not installed, or if Apex
                is combined with a sharded training type.
            NotImplementedError: If no branch matches the configured
                precision / AMP backend.
        """
        self.amp_type = AMPType.from_str(self.amp_type)

        if self.use_ipu:
            return IPUPrecisionPlugin(self.precision)

        deepspeed_requested = (
            self._distrib_type == DistributedType.DEEPSPEED
            or isinstance(self._training_type_plugin, DeepSpeedPlugin)
        )
        if deepspeed_requested:
            return DeepSpeedPrecisionPlugin(self.precision)

        if self.precision == 32:
            return PrecisionPlugin()
        if self.precision == 64:
            return DoublePrecisionPlugin()
        if self.precision != 16:
            raise NotImplementedError("We only support precisions 64, 32 and 16!")

        # From here on we are dealing with 16-bit precision only.
        if self.use_tpu:
            return TPUHalfPrecisionPlugin()

        if self.amp_type == AMPType.NATIVE:
            if self.use_cpu:
                raise MisconfigurationException(
                    "You have asked for native AMP on CPU, but AMP is only available on GPU."
                )

            if _NATIVE_AMP_AVAILABLE:
                log.info("Using native 16bit precision.")
                if self._is_sharded_training_type:
                    return ShardedNativeMixedPrecisionPlugin()
                if self._is_fully_sharded_training_type:
                    return FullyShardedNativeMixedPrecisionPlugin()
                return NativeMixedPrecisionPlugin()

            # Native AMP is unavailable: switch to Apex when it is installed,
            # otherwise fail with the upgrade hint.
            native_missing_msg = (
                "You have asked for native AMP but your PyTorch version does not support it."
                " Consider upgrading with `pip install torch>=1.6`.")
            if not _APEX_AVAILABLE:
                raise MisconfigurationException(native_missing_msg)
            self.amp_type = AMPType.APEX
            native_missing_msg += " We will attempt to use NVIDIA Apex for this session."
            rank_zero_warn(native_missing_msg)

        # Reached for an explicit Apex request or after the fallback above.
        if self.amp_type == AMPType.APEX:
            if not _APEX_AVAILABLE:
                raise MisconfigurationException(
                    "You have asked for Apex AMP but you have not installed it yet."
                    " Install apex first using this guide: https://github.com/NVIDIA/apex#linux"
                )
            sharding_active = (self._is_sharded_training_type
                               or self._is_fully_sharded_training_type)
            if sharding_active:
                raise MisconfigurationException(
                    "Sharded Plugin is not supported with Apex AMP,"
                    " please using native AMP for 16-bit precision.")
            log.info("Using APEX 16bit precision.")
            return ApexMixedPrecisionPlugin(self.amp_level)

        # 16-bit precision with a backend that matched neither branch.
        raise NotImplementedError("We only support precisions 64, 32 and 16!")
예제 #3
0
    def select_precision_plugin(self) -> PrecisionPlugin:
        """Return the precision plugin for the configured precision (32 or 16).

        For 32-bit precision ``self.amp_type`` is reset to ``None`` and the
        plain ``PrecisionPlugin`` is returned. For 16-bit precision the plugin
        is chosen from the TPU flag and the AMP backend held in
        ``self.amp_type`` (``"native"`` or ``"apex"``); on success
        ``self.amp_type`` is replaced with the matching ``AMPType`` member.

        Every path either returns a plugin or raises; the method never falls
        through and returns ``None``.

        Returns:
            The precision plugin instance matching the current configuration.

        Raises:
            MisconfigurationException: If native AMP is requested on CPU, if
                Apex is requested (or fallen back to) but not installed, if
                Apex is combined with a sharded training type plugin, or if
                ``self.amp_type`` is neither ``"native"`` nor ``"apex"`` for
                16-bit precision.
            NotImplementedError: If ``self.precision`` is neither 32 nor 16.
        """
        if self.precision == 32:
            self.amp_type = None
            return PrecisionPlugin()

        if self.precision != 16:
            raise NotImplementedError("We only support precisions 32 and 16!")

        if self.on_tpu:
            return TPUHalfPrecisionPlugin()

        if self.amp_type == "native":
            if not _NATIVE_AMP_AVAILABLE:
                rank_zero_warn(
                    "You have asked for native AMP but your PyTorch version does not support it."
                    " Consider upgrading with `pip install torch>=1.6`."
                    " We will attempt to use NVIDIA Apex for this session."
                )
                if not _APEX_AVAILABLE and self.on_cpu:
                    raise MisconfigurationException(
                        "You have asked for native AMP on CPU, but AMP is only available on GPU."
                    )
                # Fall through to the Apex branch below.
                self.amp_type = "apex"
            elif self.on_cpu:
                raise MisconfigurationException(
                    "You have asked for native AMP on CPU, but AMP is only available on GPU."
                )
            else:
                log.info("Using native 16bit precision.")
                is_sharded = isinstance(
                    self.training_type_plugin,
                    (DDPShardedPlugin, DDPSpawnShardedPlugin))
                # Normalise the backend on both native paths; previously the
                # sharded path returned with ``amp_type`` still a raw string.
                self.amp_type = AMPType.NATIVE
                if is_sharded:
                    return ShardedNativeMixedPrecisionPlugin()
                return NativeMixedPrecisionPlugin()

        if self.amp_type == "apex":
            if not _APEX_AVAILABLE:
                # Previously this only warned and the method then returned
                # ``None``; fail loudly so the problem surfaces here.
                raise MisconfigurationException(
                    "You have asked for Apex AMP but you have not installed it yet."
                    " Install apex first using this guide: https://github.com/NVIDIA/apex#linux"
                )
            if isinstance(self.training_type_plugin,
                          (DDPShardedPlugin, DDPSpawnShardedPlugin)):
                raise MisconfigurationException(
                    "Sharded Plugin is not supported with Apex AMP, "
                    "please using native AMP for 16-bit precision.")
            log.info("Using APEX 16bit precision.")
            self.amp_type = AMPType.APEX
            return ApexMixedPrecisionPlugin(self.amp_level)

        # 16-bit precision with a backend that is neither native nor Apex.
        raise MisconfigurationException(
            f"AMP type {self.amp_type!r} is not supported with 16-bit precision;"
            " expected 'native' or 'apex'."
        )