# Example 1
def test_precision_supported_types():
    """Check the canonical list of precision values and membership tests."""
    expected = ["16", "32", "64", "bf16", "mixed"]
    assert PrecisionType.supported_types() == expected
    # Both the int and the string spelling of 16 are accepted.
    for valid in (16, "16"):
        assert PrecisionType.supported_type(valid)
    # Anything outside the list is rejected.
    for invalid in (1, "invalid"):
        assert not PrecisionType.supported_type(invalid)
    def __init__(
        self,
        num_processes,
        devices,
        tpu_cores,
        ipus,
        accelerator,
        strategy: Optional[Union[str, Strategy]],
        gpus,
        gpu_ids,
        num_nodes,
        sync_batchnorm,
        benchmark,
        replace_sampler_ddp,
        deterministic: bool,
        precision,
        amp_type,
        amp_level,
        plugins,
    ):
        """Store all accelerator/device/precision flags, then resolve them into
        a training-type plugin, accelerator, and cluster environment.

        Raises:
            MisconfigurationException: If ``precision`` is not one of the
                values accepted by ``PrecisionType``.
        """
        # initialization
        self._device_type = _AcceleratorType.CPU
        self._distrib_type = None
        self._accelerator_type = None

        # Normalize a string strategy name to lower case; a Strategy instance
        # is stored as-is.
        self.strategy = strategy.lower() if isinstance(strategy, str) else strategy
        # TODO: Rename this to something else once all the distributed flags are moved to strategy
        self.distributed_backend = accelerator

        self._init_deterministic(deterministic)

        self.num_processes = num_processes
        self.devices = devices
        # `gpus` is the input passed to the Trainer, whereas `gpu_ids` is a list of parsed gpu ids.
        self.gpus = gpus
        self.parallel_device_ids = gpu_ids
        self.tpu_cores = tpu_cores
        self.ipus = ipus
        self.num_nodes = num_nodes
        self.sync_batchnorm = sync_batchnorm
        self.benchmark = benchmark
        self.replace_sampler_ddp = replace_sampler_ddp

        # Fail fast on an unsupported precision value before any plugin setup.
        if not PrecisionType.supported_type(precision):
            raise MisconfigurationException(
                f"Precision {repr(precision)} is invalid. Allowed precision values: {PrecisionType.supported_types()}"
            )
        self.precision = precision
        self.amp_type = amp_type.lower() if isinstance(amp_type, str) else None
        self.amp_level = amp_level

        self._precision_plugin: Optional[PrecisionPlugin] = None
        self._training_type_plugin: Optional[Strategy] = None
        self._cluster_environment: Optional[ClusterEnvironment] = None
        self._checkpoint_io: Optional[CheckpointIO] = None

        # Normalize `plugins` to a list: None -> [], a bare string or a single
        # non-sequence plugin object -> one-element list.
        plugins = plugins if plugins is not None else []
        if isinstance(plugins, str):
            plugins = [plugins]
        if not isinstance(plugins, Sequence):
            plugins = [plugins]
        self.plugins = plugins

        # NOTE(review): the resolution steps below appear order-dependent —
        # each call presumably relies on state set by the earlier ones; keep
        # the sequence intact.
        self._handle_accelerator_and_strategy()

        self._validate_accelerator_and_devices()

        self._warn_if_devices_flag_ignored()

        self.select_accelerator_type()

        # An explicit strategy takes precedence over inferring the
        # distributed mode from the other flags.
        if self.strategy is not None:
            self._set_training_type_plugin()
        else:
            self.set_distributed_mode()

        self.handle_given_plugins()
        self._set_distrib_type_if_training_type_plugin_passed()

        self._cluster_environment = self.select_cluster_environment()

        self.update_device_type_if_ipu_plugin()
        self.update_device_type_if_training_type_plugin_passed()

        self._validate_accelerator_type()
        self._set_devices_if_none()

        self.training_type_plugin = self.final_training_type_plugin()
        self.accelerator = self.training_type_plugin.accelerator
        self._check_plugin_compatibility()

        # benchmarking
        # TODO: should this be moved to GPU accelerator?
        torch.backends.cudnn.benchmark = self.benchmark
        # Removed redundant second assignment of self.replace_sampler_ddp
        # (identical value already stored above).