コード例 #1
0
    def _check_and_init_precision(self) -> PrecisionPlugin:
        """Pick the precision plugin that matches the configured flags.

        ``_validate_precision_choice`` runs first. A ``PrecisionPlugin``
        instance already stored in ``_precision_plugin_flag`` is returned
        unchanged. Otherwise the choice is made, in order, from the accelerator
        type (IPU, HPU, TPU), the strategy (DeepSpeed) and finally the precision
        flag combined with the AMP type flag.

        Side effects: ``_precision_flag`` is rewritten to ``"bf16"`` when 16-bit
        precision is requested together with the ``"cpu"`` accelerator flag, and
        ``_amp_level_flag`` falls back to ``"O2"`` on the Apex path.

        Returns:
            The selected precision plugin.

        Raises:
            RuntimeError: If none of the branches produced a plugin.
        """
        self._validate_precision_choice()

        # An explicitly supplied plugin instance always wins.
        user_plugin = self._precision_plugin_flag
        if isinstance(user_plugin, PrecisionPlugin):
            return user_plugin

        precision = self._precision_flag
        accelerator = self.accelerator

        # Accelerator-specific plugins.
        if isinstance(accelerator, IPUAccelerator):
            return IPUPrecisionPlugin(precision)  # type: ignore
        if isinstance(accelerator, HPUAccelerator):
            return HPUPrecisionPlugin(precision)  # type: ignore
        if isinstance(accelerator, TPUAccelerator):
            if precision == 32:
                return TPUPrecisionPlugin()
            if precision in (16, "bf16"):
                if precision == 16:
                    rank_zero_warn(
                        "You passed `Trainer(accelerator='tpu', precision=16)` but AMP"
                        " is not supported with TPUs. Using `precision='bf16'` instead."
                    )
                return TPUBf16PrecisionPlugin()
            # Any other precision on TPU falls through to the generic checks below.

        strategy = self.strategy
        if isinstance(strategy, DeepSpeedStrategy):
            return DeepSpeedPrecisionPlugin(
                precision,
                self._amp_type_flag,
                self._amp_level_flag,  # type: ignore
            )

        # Full and double precision need no AMP handling.
        if precision == 32:
            return PrecisionPlugin()
        if precision == 64:
            return DoublePrecisionPlugin()

        # 16-bit on CPU is switched to bfloat16, and the switch is persisted on the flag.
        if precision == 16 and self._accelerator_flag == "cpu":
            rank_zero_warn(
                "You passed `Trainer(accelerator='cpu', precision=16)` but native AMP is not supported on CPU."
                " Using `precision='bf16'` instead."
            )
            precision = "bf16"
            self._precision_flag = precision

        if precision in (16, "bf16"):
            if precision == 16:
                amp_banner = (
                    f"Using 16bit {self._amp_type_flag.value} Automatic Mixed Precision (AMP)"  # type: ignore
                )
            else:
                amp_banner = "Using bfloat16 Automatic Mixed Precision (AMP)"
            rank_zero_info(amp_banner)

            if self._amp_type_flag == AMPType.NATIVE:
                if self._accelerator_flag == "cpu":
                    device = "cpu"
                else:
                    device = "cuda"

                # The strategy type decides which native AMP plugin class is used.
                if isinstance(strategy, (DDPShardedStrategy, DDPSpawnShardedStrategy)):
                    native_plugin_cls = ShardedNativeMixedPrecisionPlugin
                elif isinstance(strategy, DDPFullyShardedStrategy):
                    native_plugin_cls = FullyShardedNativeMixedPrecisionPlugin
                else:
                    native_plugin_cls = NativeMixedPrecisionPlugin
                return native_plugin_cls(precision, device)

            if self._amp_type_flag == AMPType.APEX:
                apex_level = self._amp_level_flag or "O2"
                self._amp_level_flag = apex_level
                return ApexMixedPrecisionPlugin(apex_level)

        raise RuntimeError("No precision set")
コード例 #2
0
    def select_precision_plugin(self) -> PrecisionPlugin:
        """Return the precision plugin for the configured precision and AMP type.

        ``self.amp_type`` is first normalised through ``AMPType.from_str``.
        IPU and DeepSpeed setups get their dedicated plugins; precisions 32 and
        64 map to ``PrecisionPlugin`` and ``DoublePrecisionPlugin``. For 16 and
        ``"bf16"`` the plugin depends on TPU usage, the AMP type and whether a
        sharded training type is in use. If native AMP is requested but not
        available and Apex is installed, ``self.amp_type`` is switched to Apex
        after a warning.

        Returns:
            The selected precision plugin.

        Raises:
            MisconfigurationException: If the requested AMP backend is not
                available, if Apex is combined with a sharded training type, or
                if no branch matched the configured precision.
        """
        # set precision type
        self.amp_type = AMPType.from_str(self.amp_type)
        precision = self.precision

        if self.use_ipu:
            return IPUPrecisionPlugin(precision)

        deepspeed_requested = self._distrib_type == DistributedType.DEEPSPEED or isinstance(
            self._training_type_plugin, DeepSpeedPlugin
        )
        if deepspeed_requested:
            return DeepSpeedPrecisionPlugin(precision)

        if precision == 32:
            return PrecisionPlugin()
        if precision == 64:
            return DoublePrecisionPlugin()

        if precision in (16, "bf16"):
            if self.use_tpu:
                return TPUHalfPrecisionPlugin()

            if self.amp_type == AMPType.NATIVE:
                if _NATIVE_AMP_AVAILABLE:
                    log.info(f"Using native {precision} bit Automatic Mixed Precision")
                    # The training type decides which native AMP plugin class is used.
                    if self._is_sharded_training_type:
                        native_plugin_cls = ShardedNativeMixedPrecisionPlugin
                    elif self._is_fully_sharded_training_type:
                        native_plugin_cls = FullyShardedNativeMixedPrecisionPlugin
                    else:
                        native_plugin_cls = NativeMixedPrecisionPlugin
                    return native_plugin_cls(precision, use_cpu=self.use_cpu)

                # Native AMP is missing: fall back to Apex when possible, otherwise fail.
                unavailable_msg = (
                    "You have asked for native AMP but your PyTorch version does not support it."
                    " Consider upgrading with `pip install torch>=1.6`."
                )
                if not _APEX_AVAILABLE:
                    raise MisconfigurationException(unavailable_msg)
                self.amp_type = AMPType.APEX
                rank_zero_warn(unavailable_msg + " We will attempt to use NVIDIA Apex for this session.")

            # Re-read the attribute: it may just have been switched to Apex above.
            if self.amp_type == AMPType.APEX:
                if not _APEX_AVAILABLE:
                    raise MisconfigurationException(
                        "You have asked for Apex AMP but you have not installed it yet."
                        " Install apex first using this guide: https://github.com/NVIDIA/apex#linux"
                    )
                uses_sharding = self._is_sharded_training_type or self._is_fully_sharded_training_type
                if uses_sharding:
                    raise MisconfigurationException(
                        "Sharded Plugin is not supported with Apex AMP, please using native AMP for 16-bit precision."
                    )
                log.info("Using APEX 16bit precision.")
                return ApexMixedPrecisionPlugin(self.amp_level)

        raise MisconfigurationException(
            f"Precision {precision} is invalid. Allowed precision values: {PrecisionType.supported_types()}"
        )
コード例 #3
0
def test_tpu_invalid_raises():
    """Check that `Trainer` raises `ValueError` for two invalid pairings with a `TPUAccelerator`."""
    # A TPU strategy paired with a plain `PrecisionPlugin`.
    plugin_error = "TPUAccelerator` can only be used with a `TPUPrecisionPlugin"
    tpu_strategy_plain_plugin = TPUSpawnStrategy(accelerator=TPUAccelerator(), precision_plugin=PrecisionPlugin())
    with pytest.raises(ValueError, match=plugin_error):
        Trainer(strategy=tpu_strategy_plain_plugin, devices=8)

    # A DDP strategy paired with a TPU accelerator.
    strategy_error = "TPUAccelerator` can only be used with a `SingleTPUStrategy`"
    ddp_strategy_on_tpu = DDPStrategy(accelerator=TPUAccelerator(), precision_plugin=TPUPrecisionPlugin())
    with pytest.raises(ValueError, match=strategy_error):
        Trainer(strategy=ddp_strategy_on_tpu, devices=8)
コード例 #4
0
    def select_precision_plugin(self) -> PrecisionPlugin:
        """Return the precision plugin for precisions 64, 32 and 16.

        ``self.amp_type`` is first normalised through ``AMPType.from_str``.
        IPU and DeepSpeed setups get their dedicated plugins; 32 and 64 map to
        ``PrecisionPlugin`` and ``DoublePrecisionPlugin``. For 16-bit the plugin
        depends on TPU usage, the AMP type and whether a sharded training type
        is in use. If native AMP is requested but not available and Apex is
        installed, ``self.amp_type`` is switched to Apex after a warning.

        Returns:
            The selected precision plugin.

        Raises:
            MisconfigurationException: If native AMP is requested on CPU, if the
                requested AMP backend is not available, or if Apex is combined
                with a sharded training type.
            NotImplementedError: If no branch matched the configured precision.
        """
        # set precision type
        self.amp_type = AMPType.from_str(self.amp_type)
        precision = self.precision

        if self.use_ipu:
            return IPUPrecisionPlugin(precision)

        wants_deepspeed = self._distrib_type == DistributedType.DEEPSPEED or isinstance(
            self._training_type_plugin, DeepSpeedPlugin
        )
        if wants_deepspeed:
            return DeepSpeedPrecisionPlugin(precision)

        if precision == 32:
            return PrecisionPlugin()
        if precision == 64:
            return DoublePrecisionPlugin()

        if precision == 16:
            if self.use_tpu:
                return TPUHalfPrecisionPlugin()

            if self.amp_type == AMPType.NATIVE:
                if self.use_cpu:
                    raise MisconfigurationException(
                        "You have asked for native AMP on CPU, but AMP is only available on GPU."
                    )

                if _NATIVE_AMP_AVAILABLE:
                    log.info("Using native 16bit precision.")
                    if self._is_sharded_training_type:
                        return ShardedNativeMixedPrecisionPlugin()
                    if self._is_fully_sharded_training_type:
                        return FullyShardedNativeMixedPrecisionPlugin()
                    return NativeMixedPrecisionPlugin()

                # Native AMP is missing: fall back to Apex when possible, otherwise fail.
                no_native_msg = (
                    "You have asked for native AMP but your PyTorch version does not support it."
                    " Consider upgrading with `pip install torch>=1.6`."
                )
                if not _APEX_AVAILABLE:
                    raise MisconfigurationException(no_native_msg)
                self.amp_type = AMPType.APEX
                rank_zero_warn(no_native_msg + " We will attempt to use NVIDIA Apex for this session.")

            # Re-read the attribute: it may just have been switched to Apex above.
            if self.amp_type == AMPType.APEX:
                if not _APEX_AVAILABLE:
                    raise MisconfigurationException(
                        "You have asked for Apex AMP but you have not installed it yet."
                        " Install apex first using this guide: https://github.com/NVIDIA/apex#linux"
                    )
                sharded_in_use = self._is_sharded_training_type or self._is_fully_sharded_training_type
                if sharded_in_use:
                    raise MisconfigurationException(
                        "Sharded Plugin is not supported with Apex AMP,"
                        " please using native AMP for 16-bit precision."
                    )
                log.info("Using APEX 16bit precision.")
                return ApexMixedPrecisionPlugin(self.amp_level)

        raise NotImplementedError("We only support precisions 64, 32 and 16!")
コード例 #5
0
    def select_precision_plugin(self) -> PrecisionPlugin:
        """Return the precision plugin for the configured precision and AMP type.

        * ``precision == 32``: clears ``self.amp_type`` and returns a plain
          ``PrecisionPlugin``.
        * ``precision == 16`` on TPU: returns ``TPUHalfPrecisionPlugin``.
        * ``precision == 16`` with ``amp_type == "native"``: returns the native
          (or sharded native) mixed-precision plugin. When native AMP is not
          available, a warning is emitted and the request is redirected to Apex.
        * ``precision == 16`` with ``amp_type == "apex"``: returns
          ``ApexMixedPrecisionPlugin(self.amp_level)``.

        Every path either returns a plugin or raises; the method never falls
        off the end and implicitly returns ``None``.

        Returns:
            The selected precision plugin.

        Raises:
            MisconfigurationException: If AMP is requested on CPU, if Apex is
                required but not installed, if Apex is combined with a sharded
                training type plugin, or if ``amp_type`` is neither ``"native"``
                nor ``"apex"`` for 16-bit precision.
            NotImplementedError: If the precision is neither 32 nor 16.
        """
        if self.precision == 32:
            self.amp_type = None
            return PrecisionPlugin()

        elif self.precision == 16:
            if self.on_tpu:
                return TPUHalfPrecisionPlugin()

            if self.amp_type == "native":
                if not _NATIVE_AMP_AVAILABLE:
                    rank_zero_warn(
                        "You have asked for native AMP but your PyTorch version does not support it."
                        " Consider upgrading with `pip install torch>=1.6`."
                        " We will attempt to use NVIDIA Apex for this session."
                    )
                    if not _APEX_AVAILABLE and self.on_cpu:
                        raise MisconfigurationException(
                            "You have asked for native AMP on CPU, but AMP is only available on GPU."
                        )
                    # Redirect to the Apex branch below.
                    self.amp_type = "apex"
                elif self.on_cpu:
                    raise MisconfigurationException(
                        "You have asked for native AMP on CPU, but AMP is only available on GPU."
                    )
                else:
                    log.info("Using native 16bit precision.")
                    if isinstance(self.training_type_plugin,
                                  (DDPShardedPlugin, DDPSpawnShardedPlugin)):
                        return ShardedNativeMixedPrecisionPlugin()
                    self.amp_type = AMPType.NATIVE
                    return NativeMixedPrecisionPlugin()

            if self.amp_type == "apex":
                if not _APEX_AVAILABLE:
                    # Previously this only warned and then fell off the end of the
                    # method, handing `None` to the caller. Fail loudly instead.
                    raise MisconfigurationException(
                        "You have asked for Apex AMP but you have not installed it yet."
                        " Install apex first using this guide: https://github.com/NVIDIA/apex#linux"
                    )
                if isinstance(self.training_type_plugin,
                              (DDPShardedPlugin, DDPSpawnShardedPlugin)):
                    raise MisconfigurationException(
                        "Sharded Plugin is not supported with Apex AMP, "
                        "please using native AMP for 16-bit precision.")
                log.info("Using APEX 16bit precision.")
                self.amp_type = AMPType.APEX
                return ApexMixedPrecisionPlugin(self.amp_level)

            # Neither AMP branch matched; do not return `None` implicitly.
            raise MisconfigurationException(
                f"AMP type {self.amp_type!r} is not supported with 16-bit precision."
                " Use 'native' or 'apex'."
            )
        else:
            raise NotImplementedError("We only support precisions 32 and 16!")
コード例 #6
0
    def select_precision_plugin(self) -> PrecisionPlugin:
        """Return the precision plugin for the configured precision, AMP type and hardware.

        The checks run in this order:

        1. ``self.amp_type`` is normalised through ``AMPType.from_str`` and
           ``amp_level`` is rejected unless the AMP type is Apex.
        2. IPU: only precisions 16 and 32 are accepted.
        3. TPU: 32 gives ``TPUPrecisionPlugin``, 64 is rejected, and 16 or
           ``"bf16"`` gives ``TPUBf16PrecisionPlugin`` (with a warning for 16).
           Any other value falls through to the following checks.
        4. DeepSpeed, selected by ``_distrib_type`` or by the type of the
           training type plugin.
        5. Precisions 32 and 64 map to ``PrecisionPlugin`` and
           ``DoublePrecisionPlugin``.
        6. Precisions 16 and ``"bf16"`` map to a native or Apex AMP plugin.

        Side effects: ``self.amp_type`` is replaced by its ``AMPType`` value,
        ``self.precision`` is rewritten to ``"bf16"`` when 16-bit is requested
        on CPU, and ``self.amp_level`` defaults to ``"O2"`` on the Apex path.

        Returns:
            The selected precision plugin.

        Raises:
            MisconfigurationException: For the unsupported combinations
                described above (see the individual messages).
            RuntimeError: If no branch produced a plugin.
        """
        # set precision type
        self.amp_type = AMPType.from_str(self.amp_type)

        # validation for all plugins
        if self.amp_level is not None and self.amp_type != AMPType.APEX:
            raise MisconfigurationException(
                f"You have asked for `amp_level={self.amp_level!r}` but it's only supported with `amp_backend='apex'`."
            )

        if self.use_ipu:
            if self.precision not in (16, 32):
                raise MisconfigurationException(
                    f"`Trainer(accelerator='ipu', precision={self.precision!r})` is not supported."
                )
            return IPUPrecisionPlugin(self.precision)
        if self.use_tpu:
            if self.precision == 32:
                return TPUPrecisionPlugin()
            elif self.precision == 64:
                raise MisconfigurationException(
                    "`Trainer(accelerator='tpu', precision=64)` is not implemented."
                    " Please, open an issue in `https://github.com/PyTorchLightning/pytorch-lightning/issues`"
                    " requesting this feature.")
            elif self.precision in (16, "bf16"):
                if self.precision == 16:
                    # this is not deprecated to ease transition between accelerator environments
                    rank_zero_warn(
                        f"You passed `Trainer(accelerator='tpu', precision=16)` but {self.amp_type.value} AMP"
                        f" is not supported with TPUs. Using `precision='bf16'` instead."
                    )
                return TPUBf16PrecisionPlugin()
            # Any other precision value on TPU falls through to the checks below.

        # DeepSpeed is detected either from the distributed type or from the plugin instance.
        if self._distrib_type == _StrategyType.DEEPSPEED or isinstance(
                self._training_type_plugin, DeepSpeedStrategy):
            return DeepSpeedPrecisionPlugin(self.precision, self.amp_type,
                                            self.amp_level)

        if self.precision == 32:
            return PrecisionPlugin()
        if self.precision == 64:
            return DoublePrecisionPlugin()

        # maybe convert the precision value
        if self.precision == 16 and self.use_cpu:
            if self.amp_type == AMPType.APEX:
                # apex was explicitly passed, not a good idea to silently switch to native AMP
                raise MisconfigurationException(
                    "You passed `Trainer(accelerator='cpu', precision=16, amp_type='apex')`"
                    " but apex AMP not supported on CPU.")
            # this automatic switch is to ease transition between accelerator environments
            rank_zero_warn(
                "You passed `Trainer(accelerator='cpu', precision=16)` but native AMP is not supported on CPU."
                " Using `precision='bf16'` instead.")
            # The rewritten value is what the `in (16, "bf16")` check below sees.
            self.precision = "bf16"

        if self.precision in (16, "bf16"):
            # bfloat16 is only accepted together with the native AMP type.
            if self.precision == "bf16" and self.amp_type != AMPType.NATIVE:
                raise MisconfigurationException(
                    f"You passed `Trainer(amp_type={self.amp_type.value!r}, precision='bf16')` but it's not supported."
                    " Try using `amp_type='native'` instead.")

            rank_zero_info(
                f"Using 16bit {self.amp_type.value} Automatic Mixed Precision (AMP)"
                if self.precision ==
                16 else "Using bfloat16 Automatic Mixed Precision (AMP)")

            if self.amp_type == AMPType.NATIVE:
                # Device string passed to the native AMP plugins.
                device = "cpu" if self.use_cpu else "cuda"

                if self._is_sharded_training_type:
                    return ShardedNativeMixedPrecisionPlugin(
                        self.precision, device)
                if self._is_fully_sharded_training_type:
                    return FullyShardedNativeMixedPrecisionPlugin(
                        self.precision, device)
                return NativeMixedPrecisionPlugin(self.precision, device)

            if self.amp_type == AMPType.APEX:
                if self._is_sharded_training_type or self._is_fully_sharded_training_type:
                    raise MisconfigurationException(
                        "Sharded plugins are not supported with apex, please switch to `amp_backend='native'`."
                    )
                # Default the Apex optimisation level to "O2" when none was given.
                self.amp_level = self.amp_level or "O2"
                return ApexMixedPrecisionPlugin(self.amp_level)

        # Reached when the precision is not 32, 64, 16 or "bf16", or when the
        # AMP type is neither native nor Apex.
        raise RuntimeError("No precision set")
コード例 #7
0
            super().__init__(**kwargs)
            # Set to None so it will be overwritten by the accelerator connector.
            self._layer_sync = None

    strategy = CustomParallelStrategy()
    assert strategy._layer_sync is None
    Trainer(strategy=strategy, sync_batchnorm=True)
    assert isinstance(strategy._layer_sync, NativeSyncBatchNorm)


@pytest.mark.parametrize(
    ("plugins", "expected"),
    [
        ([LightningEnvironment(), SLURMEnvironment()], "ClusterEnvironment"),
        ([TorchCheckpointIO(), TorchCheckpointIO()], "CheckpointIO"),
        (
            [PrecisionPlugin(), DoublePrecisionPlugin(), LightningEnvironment(), SLURMEnvironment()],
            "PrecisionPlugin, ClusterEnvironment",
        ),
    ],
)
def test_plugin_only_one_instance_for_one_type(plugins, expected):
    """Passing several plugins of the same kind to `Trainer` must raise, naming the duplicated kinds."""
    duplicate_error = f"Received multiple values for {expected}"
    with pytest.raises(MisconfigurationException, match=duplicate_error):
        Trainer(plugins=plugins)
コード例 #8
0
@RunIf(skip_windows=True)
def test_sync_batchnorm_set_in_custom_strategy(tmpdir):
    """Check that `Trainer(sync_batchnorm=True)` fills in `_layer_sync` on a custom strategy."""

    class CustomParallelStrategy(DDPStrategy):
        """DDP strategy that clears `_layer_sync` after the parent initialiser has run."""

        def __init__(self, **ddp_kwargs):
            super().__init__(**ddp_kwargs)
            # Left as None on purpose: the accelerator connector is expected to overwrite it.
            self._layer_sync = None

    custom_strategy = CustomParallelStrategy()
    assert custom_strategy._layer_sync is None

    Trainer(strategy=custom_strategy, sync_batchnorm=True)

    assert isinstance(custom_strategy._layer_sync, NativeSyncBatchNorm)


@pytest.mark.parametrize(
    "plugins, expected",
    [
        ([LightningEnvironment(), SLURMEnvironment()], "ClusterEnvironment"),
        ([TorchCheckpointIO(), TorchCheckpointIO()], "CheckpointIO"),
        (
            [
                PrecisionPlugin(),
                DoublePrecisionPlugin(),
                LightningEnvironment(),
                SLURMEnvironment(),
            ],
            "PrecisionPlugin, ClusterEnvironment",
        ),
    ],
)
def test_plugin_only_one_instance_for_one_type(plugins, expected):
    """`Trainer` must reject more than one plugin of the same kind and name the kinds in the error."""
    expected_failure = pytest.raises(
        MisconfigurationException,
        match=f"Received multiple values for {expected}",
    )
    with expected_failure:
        Trainer(plugins=plugins)