# Imports assumed for these tests (PyTorch Lightning ~1.6 module layout).
from pytorch_lightning import Trainer
from pytorch_lightning.plugins import CheckpointIO
from pytorch_lightning.strategies import DDPStrategy, StrategyRegistry


def test_strategy_registry_with_new_strategy():
    class TestStrategy:

        distributed_backend = "test_strategy"

        def __init__(self, param1, param2):
            self.param1 = param1
            self.param2 = param2

    strategy_name = "test_strategy"
    strategy_description = "Test Strategy"

    StrategyRegistry.register(strategy_name,
                              TestStrategy,
                              description=strategy_description,
                              param1="abc",
                              param2=123)

    assert strategy_name in StrategyRegistry
    assert StrategyRegistry[strategy_name]["description"] == strategy_description
    assert StrategyRegistry[strategy_name]["init_params"] == {"param1": "abc", "param2": 123}
    assert StrategyRegistry[strategy_name]["distributed_backend"] == "test_strategy"
    assert isinstance(StrategyRegistry.get(strategy_name), TestStrategy)

    StrategyRegistry.remove(strategy_name)
    assert strategy_name not in StrategyRegistry
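
Registry entries behave like plain dicts keyed by strategy name, and ``get`` instantiates the stored class with the stored ``init_params``. A minimal sketch of that behavior, matching what the test above asserts (``MiniRegistry`` is illustrative, not Lightning's actual implementation):

class MiniRegistry(dict):
    """Toy stand-in for StrategyRegistry showing the stored entry layout."""

    def register(self, name, cls, description="", **init_params):
        # Keep the class plus everything needed to instantiate it later.
        self[name] = {
            "strategy": cls,
            "description": description,
            "init_params": init_params,
            "distributed_backend": getattr(cls, "distributed_backend", None),
        }

    def get(self, name):
        # Lookup builds a fresh instance from the stored parameters.
        entry = self[name]
        return entry["strategy"](**entry["init_params"])

    def remove(self, name):
        self.pop(name)
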
def test_custom_registered_strategy_to_strategy_flag():
    class CustomCheckpointIO(CheckpointIO):
        def save_checkpoint(self, checkpoint, path):
            pass

        def load_checkpoint(self, path):
            pass

        def remove_checkpoint(self, path):
            pass

    custom_checkpoint_io = CustomCheckpointIO()

    # Register the DDP Strategy with your custom CheckpointIO plugin
    StrategyRegistry.register(
        "ddp_custom_checkpoint_io",
        DDPStrategy,
        description="DDP Strategy with custom checkpoint io plugin",
        checkpoint_io=custom_checkpoint_io,
    )
    trainer = Trainer(strategy="ddp_custom_checkpoint_io",
                      accelerator="cpu",
                      devices=2)

    assert isinstance(trainer.strategy, DDPStrategy)
    assert trainer.strategy.checkpoint_io == custom_checkpoint_io
    def _set_strategy(self) -> None:
        if isinstance(self._strategy_flag, str) and self._strategy_flag in StrategyRegistry:
            self._strategy = StrategyRegistry.get(self._strategy_flag)
        if isinstance(self._strategy_flag, str):
            self.set_distributed_mode(self._strategy_flag)
        elif isinstance(self._strategy_flag, Strategy):
            self._strategy = self._strategy_flag
Example 4
    def _init_strategy(self) -> None:
        """Instantiate the Strategy depending on the setting of ``_strategy_flag``."""
        if isinstance(self._strategy_flag, HorovodStrategy) or self._strategy_flag == "horovod":
            # Horovod handling has to happen before the strategy is initialized because
            # HorovodStrategy needs hvd.init() to run first.
            # TODO: lazily initialize and set up the Horovod strategy `global_rank`
            self._handle_horovod()
        if isinstance(self._strategy_flag, str):
            self.strategy = StrategyRegistry.get(self._strategy_flag)
        elif isinstance(self._strategy_flag, Strategy):
            self.strategy = self._strategy_flag
        else:
            raise RuntimeError(f"{self._strategy_flag!r} is not a valid strategy type: {type(self._strategy_flag)}")
    def _init_strategy(self) -> None:
        """Instantiate the Strategy depending on the setting of ``_strategy_flag``."""
        if isinstance(self._strategy_flag, HorovodStrategy) or self._strategy_flag == "horovod":
            # Horovod handling has to happen before the strategy is initialized because
            # HorovodStrategy needs hvd.init() to run first.
            # TODO: lazily initialize and set up the Horovod strategy `global_rank`
            self._handle_horovod()
        if isinstance(self._strategy_flag, str):
            if self._strategy_flag == "ddp2":
                # TODO: remove this error in v1.8
                raise ValueError(
                    "The DDP2 strategy is no longer supported. For single-node use, we recommend `strategy='ddp'` or"
                    " `strategy='dp'` as a replacement. If you need DDP2, you will need `torch < 1.9`,"
                    " `pytorch-lightning < 1.5`, and set it as `accelerator='ddp2'`."
                )
            self.strategy = StrategyRegistry.get(self._strategy_flag)
        elif isinstance(self._strategy_flag, Strategy):
            self.strategy = self._strategy_flag
        else:
            raise RuntimeError(f"{self._strategy_flag!r} is not a valid strategy type: {type(self._strategy_flag)}")
Example 6
    def __init__(
            self,
            devices: Optional[Union[List[int], str, int]] = None,
            num_nodes: int = 1,
            accelerator: Optional[Union[str, Accelerator]] = None,
            strategy: Optional[Union[str, Strategy]] = None,
            plugins: Optional[Union[PLUGIN_INPUT, List[PLUGIN_INPUT]]] = None,
            precision: Union[int, str] = 32,
            amp_type: str = "native",
            amp_level: Optional[str] = None,
            sync_batchnorm: bool = False,
            benchmark: Optional[bool] = None,
            replace_sampler_ddp: bool = True,
            deterministic: bool = False,
            num_processes: Optional[int] = None,  # deprecated
            tpu_cores: Optional[Union[List[int], int]] = None,  # deprecated
            ipus: Optional[int] = None,  # deprecated
            gpus: Optional[Union[List[int], str, int]] = None,  # deprecated
            gpu_ids: Optional[List[int]] = None,  # TODO can be removed
    ) -> None:
        """The AcceleratorConnector parses several Trainer arguments and instantiates the Strategy including other
        components such as the Accelerator and Precision plugins.

            A. accelerator flag could be:
                1. strategy class (deprecated in 1.5 will be removed in 1.7)
                2. strategy str (deprecated in 1.5 will be removed in 1.7)
                3. accelerator class
                4. accelerator str
                5. accelerator auto

            B. strategy flag could be :
                1. strategy class
                2. strategy str registered with StrategyRegistry
                3. strategy str in _strategy_type enum which listed in each strategy as
                   backend (registed these too, and _strategy_type could be deprecated)

            C. plugins flag could be:
                1. List of str, which could contain:
                    i. strategy str
                    ii. precision str (Not supported in the old accelerator_connector version)
                    iii. checkpoint_io str (Not supported in the old accelerator_connector version)
                    iv. cluster_environment str (Not supported in the old accelerator_connector version)
                2. List of class, which could contains:
                    i. strategy class (deprecated in 1.5 will be removed in 1.7)
                    ii. precision class (should be removed, and precision flag should allow user pass classes)
                    iii. checkpoint_io class
                    iv. cluster_environment class


        priorities which to take when:
            A. Class > str
            B. Strategy > Accelerator/precision/plugins
            C. TODO When multiple flag set to the same thing
        """
        if benchmark and deterministic:
            rank_zero_warn(
                "You passed `deterministic=True` and `benchmark=True`. Note that PyTorch ignores"
                " torch.backends.cudnn.deterministic=True when torch.backends.cudnn.benchmark=True.",
            )
        self.benchmark = not deterministic if benchmark is None else benchmark
        # TODO: move to gpu accelerator
        torch.backends.cudnn.benchmark = self.benchmark
        self.replace_sampler_ddp = replace_sampler_ddp
        self._init_deterministic(deterministic)

        # 1. Parsing flags
        # Get registered strategies, built-in accelerators and precision plugins
        self._registered_strategies = StrategyRegistry.available_strategies()
        self._accelerator_types = AcceleratorRegistry.available_accelerators()
        self._precision_types = ("16", "32", "64", "bf16", "mixed")

        # Raise an exception if there are conflicts between flags
        # Set each valid flag to `self._x_flag` after validation
        # Example: If accelerator is set to a strategy type, set `self._strategy_flag = accelerator`.
        # For devices: Assign gpus, ipus, etc. to the accelerator flag and devices flag
        self._strategy_flag: Optional[Union[Strategy, str]] = None
        self._accelerator_flag: Optional[Union[Accelerator, str]] = None
        self._precision_flag: Optional[Union[int, str]] = None
        self._precision_plugin_flag: Optional[PrecisionPlugin] = None
        self._cluster_environment_flag: Optional[Union[ClusterEnvironment, str]] = None
        self._parallel_devices: List[Union[int, torch.device]] = []
        self._layer_sync: Optional[LayerSync] = NativeSyncBatchNorm() if sync_batchnorm else None
        self.checkpoint_io: Optional[CheckpointIO] = None
        self._amp_type_flag: Optional[LightningEnum] = None
        self._amp_level_flag: Optional[str] = amp_level

        self._check_config_and_set_final_flags(
            strategy=strategy,
            accelerator=accelerator,
            precision=precision,
            plugins=plugins,
            amp_type=amp_type,
            amp_level=amp_level,
            sync_batchnorm=sync_batchnorm,
        )
        self._check_device_config_and_set_final_flags(
            devices=devices,
            num_nodes=num_nodes,
            num_processes=num_processes,
            gpus=gpus,
            ipus=ipus,
            tpu_cores=tpu_cores,
        )

        # 2. Instantiate Accelerator
        # handle `auto` and `None`
        self._set_accelerator_if_ipu_strategy_is_passed()
        if self._accelerator_flag == "auto" or self._accelerator_flag is None:
            self._accelerator_flag = self._choose_accelerator()
        self._set_parallel_devices_and_init_accelerator()

        # 3. Instantiate ClusterEnvironment
        self.cluster_environment: ClusterEnvironment = self._choose_and_init_cluster_environment()

        # 4. Instantiate Strategy - Part 1
        if self._strategy_flag is None:
            self._strategy_flag = self._choose_strategy()
        # In specific cases, ignore user selection and fall back to a different strategy
        self._check_strategy_and_fallback()
        self._init_strategy()

        # 5. Instantiate Precision Plugin
        self.precision_plugin = self._check_and_init_precision()

        # 6. Instantiate Strategy - Part 2
        self._lazy_init_strategy()
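
To make the docstring's flag taxonomy concrete, here are a few hedged usage sketches of the resulting Trainer API (these are standard PyTorch Lightning calls; the device counts are illustrative):

# B.1: strategy passed as a class
trainer = Trainer(strategy=DDPStrategy(), accelerator="gpu", devices=2)

# B.2: strategy passed as a str registered with StrategyRegistry
trainer = Trainer(strategy="ddp", accelerator="gpu", devices=2)

# A.4/A.5: accelerator as a str, or "auto" to let the connector choose
trainer = Trainer(accelerator="auto", devices=2)
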
    def handle_given_plugins(self) -> None:
        for plug in self.plugins:
            if self._strategy_flag is not None and self._is_plugin_training_type(plug):
                raise MisconfigurationException(
                    f"You have passed `Trainer(strategy={self._strategy_flag!r})`"
                    f" and you can only specify one training type plugin, but you have passed {plug} as a plugin."
                )
            if self._is_plugin_training_type(plug):
                rank_zero_deprecation(
                    f"Passing {plug} as a `strategy` to the `plugins` flag in Trainer has been deprecated"
                    f" in v1.5 and will be removed in v1.7. Use `Trainer(strategy={plug})` instead."
                )

        strategy = self._strategy or None
        checkpoint = None
        precision = None
        cluster_environment = None

        for plug in self.plugins:
            if isinstance(plug, str) and plug in StrategyRegistry:
                if strategy is None:
                    strategy = StrategyRegistry.get(plug)
                else:
                    raise MisconfigurationException(
                        "You can only specify one precision and one training type plugin."
                        " Found more than 1 training type plugin:"
                        f' {StrategyRegistry[plug]["strategy"]} registered to {plug}'
                    )
            elif isinstance(plug, str):
                # Reset the distributed type as the user has overridden the training
                # type via the plugins argument
                self._strategy_type = None
                self.set_distributed_mode(plug)

            elif isinstance(plug, Strategy):
                if strategy is None:
                    strategy = plug

                else:
                    raise MisconfigurationException(
                        "You can only specify one training type plugin."
                        f" Available: {type(strategy).__name__}, given: {type(plug).__name__}"
                    )
            elif isinstance(plug, PrecisionPlugin):
                if precision is None:
                    precision = plug
                else:
                    raise MisconfigurationException(
                        "You can only specify one precision plugin."
                        f" Available: {type(precision).__name__}, given: {type(plug).__name__}"
                    )
            elif isinstance(plug, CheckpointIO):
                if checkpoint is None:
                    checkpoint = plug
                else:
                    raise MisconfigurationException(
                        "You can only specify one checkpoint plugin."
                        f" Available: {type(checkpoint).__name__}, given: {type(plug).__name__}"
                    )
            elif isinstance(plug, ClusterEnvironment):
                if cluster_environment is None:
                    cluster_environment = plug
                else:
                    raise MisconfigurationException(
                        "You can only specify one cluster environment. Found more than 1 cluster environment plugin"
                    )
            else:
                raise MisconfigurationException(
                    f"Found invalid type for plugin {plug}. Expected a precision or training type plugin."
                )

        self._strategy = strategy
        self._precision_plugin = precision
        self._checkpoint_io = checkpoint
        self._cluster_environment = cluster_environment
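
The conflict checks above surface to users through the ``plugins`` flag. A hedged sketch of calls that exercise them, reusing ``CustomCheckpointIO`` from the earlier example:

# Accepted: one CheckpointIO plugin alongside a strategy string.
trainer = Trainer(strategy="ddp", plugins=[CustomCheckpointIO()])

# Rejected: a strategy passed via `plugins` while `strategy` is also set raises
# MisconfigurationException ("...you can only specify one training type plugin...").
trainer = Trainer(strategy="ddp", plugins=[DDPStrategy()])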