def _set_strategy(self) -> None:
    """Resolve the strategy from ``self._strategy_flag``.

    * ``str`` flag: when the name is registered in ``StrategyRegistry`` the
      strategy is fetched from there; in every case the name is then passed
      on to ``self.set_distributed_mode``.
    * ``Strategy`` instance: stored directly as ``self._strategy``.
    * Any other value: nothing is changed.
    """
    flag = self._strategy_flag
    if isinstance(flag, str):
        if flag in StrategyRegistry:
            self._strategy = StrategyRegistry.get(flag)
        # Deliberately runs for registered and unregistered names alike.
        self.set_distributed_mode(flag)
    elif isinstance(flag, Strategy):
        self._strategy = flag
def test_strategy_registry_with_new_strategy():
    """Registering a custom strategy exposes its metadata and builds it with the stored init params.

    The registration is removed in a ``finally`` clause so that a failing assertion does
    not leave ``"test_strategy"`` behind in the shared ``StrategyRegistry`` (which would
    otherwise leak into later tests or re-runs).
    """

    class TestStrategy:

        distributed_backend = "test_strategy"

        def __init__(self, param1, param2):
            self.param1 = param1
            self.param2 = param2

    strategy_name = "test_strategy"
    strategy_description = "Test Strategy"

    StrategyRegistry.register(strategy_name,
                              TestStrategy,
                              description=strategy_description,
                              param1="abc",
                              param2=123)

    try:
        assert strategy_name in StrategyRegistry
        assert StrategyRegistry[strategy_name][
            "description"] == strategy_description
        assert StrategyRegistry[strategy_name]["init_params"] == {
            "param1": "abc",
            "param2": 123
        }
        assert StrategyRegistry[strategy_name][
            "distributed_backend"] == "test_strategy"

        strategy = StrategyRegistry.get(strategy_name)
        assert isinstance(strategy, TestStrategy)
        # The registered init params must be forwarded to the constructor.
        assert strategy.param1 == "abc"
        assert strategy.param2 == 123
    finally:
        # Always unregister, even if an assertion above failed.
        StrategyRegistry.remove(strategy_name)

    assert strategy_name not in StrategyRegistry
Ejemplo n.º 3
0
 def _init_strategy(self) -> None:
     """Instantiate the Strategy given depending on the setting of ``_strategy_flag``.

     A ``str`` flag is resolved through ``StrategyRegistry.get``; a ``Strategy``
     instance is used as-is. When the flag requests Horovod, ``_handle_horovod``
     is called first.

     Raises:
         RuntimeError: if ``_strategy_flag`` is neither a ``str`` nor a ``Strategy``.
     """
     if isinstance(self._strategy_flag, HorovodStrategy) or self._strategy_flag == "horovod":
         # handle horovod has to happen before initialize strategy because HorovodStrategy needs hvd.init() first.
         # TODO lazy initialized and setup horovod strategy `global_rank`
         self._handle_horovod()
     if isinstance(self._strategy_flag, str):
         self.strategy = StrategyRegistry.get(self._strategy_flag)
     elif isinstance(self._strategy_flag, Strategy):
         self.strategy = self._strategy_flag
     else:
         # Report the offending flag, not ``self.strategy``: this branch never assigns
         # ``self.strategy``, so formatting it would show a stale value (or raise
         # AttributeError if it was never set) instead of describing the bad input.
         raise RuntimeError(
             f"{self._strategy_flag!r} is not valid type: {type(self._strategy_flag).__name__}"
         )
 def _init_strategy(self) -> None:
     """Instantiate the Strategy given depending on the setting of ``_strategy_flag``.

     A ``str`` flag is resolved through ``StrategyRegistry.get`` (the removed
     ``"ddp2"`` name is rejected first); a ``Strategy`` instance is used as-is.
     When the flag requests Horovod, ``_handle_horovod`` is called first.

     Raises:
         ValueError: if ``_strategy_flag`` is ``"ddp2"``.
         RuntimeError: if ``_strategy_flag`` is neither a ``str`` nor a ``Strategy``.
     """
     if isinstance(self._strategy_flag,
                   HorovodStrategy) or self._strategy_flag == "horovod":
         # handle horovod has to happen before initialize strategy because HorovodStrategy needs hvd.init() first.
         # TODO lazy initialized and setup horovod strategy `global_rank`
         self._handle_horovod()
     if isinstance(self._strategy_flag, str):
         if self._strategy_flag == "ddp2":
             # TODO: remove this error in v1.8
             raise ValueError(
                 "The DDP2 strategy is no longer supported. For single-node use, we recommend `strategy='ddp'` or"
                 " `strategy='dp'` as a replacement. If you need DDP2, you will need `torch < 1.9`,"
                 " `pytorch-lightning < 1.5`, and set it as `accelerator='ddp2'`."
             )
         self.strategy = StrategyRegistry.get(self._strategy_flag)
     elif isinstance(self._strategy_flag, Strategy):
         self.strategy = self._strategy_flag
     else:
         # Report the offending flag, not ``self.strategy``: this branch never assigns
         # ``self.strategy``, so formatting it would show a stale value (or raise
         # AttributeError if it was never set) instead of describing the bad input.
         raise RuntimeError(
             f"{self._strategy_flag!r} is not valid type: {type(self._strategy_flag).__name__}"
         )
    def handle_given_plugins(self) -> None:
        """Sort ``self.plugins`` into strategy, precision, checkpoint-IO and cluster-environment slots.

        The results are written to ``self._strategy``, ``self._precision_plugin``,
        ``self._checkpoint_io`` and ``self._cluster_environment``. ``self._strategy``
        keeps its current value when it is already set (truthy); the other slots are
        ``None`` unless a matching plugin is given. A string plugin naming a registered
        strategy is resolved through ``StrategyRegistry``. Every string plugin also
        resets ``self._strategy_type`` and is passed to ``set_distributed_mode``.

        Raises:
            MisconfigurationException: if a training-type plugin is passed while
                ``strategy`` is also set, if more than one plugin of the same kind is
                given, or if a plugin has an unsupported type.
        """
        for plug in self.plugins:
            # Evaluated once: drives both the conflict check and the deprecation warning.
            is_training_type = self._is_plugin_training_type(plug)
            if self._strategy_flag is not None and is_training_type:
                raise MisconfigurationException(
                    f"You have passed `Trainer(strategy={self._strategy_flag!r})`"
                    f" and you can only specify one training type plugin, but you have passed {plug} as a plugin."
                )
            if is_training_type:
                rank_zero_deprecation(
                    f"Passing {plug} `strategy` to the `plugins` flag in Trainer has been deprecated"
                    f" in v1.5 and will be removed in v1.7. Use `Trainer(strategy={plug})` instead."
                )

        strategy = self._strategy or None
        checkpoint = None
        precision = None
        cluster_environment = None

        for plug in self.plugins:
            if isinstance(plug, str) and plug in StrategyRegistry:
                if strategy is None:
                    strategy = StrategyRegistry.get(plug)
                else:
                    raise MisconfigurationException(
                        "You can only specify one precision and one training type plugin."
                        " Found more than 1 training type plugin:"
                        f' {StrategyRegistry[plug]["strategy"]} registered to {plug}'
                    )
            # Intentionally ``if`` (not ``elif``): registry-resolved names must also
            # go through ``set_distributed_mode`` below.
            if isinstance(plug, str):
                # Reset the distributed type as the user has overridden training type
                # via the plugins argument
                self._strategy_type = None
                self.set_distributed_mode(plug)

            elif isinstance(plug, Strategy):
                if strategy is None:
                    strategy = plug

                else:
                    raise MisconfigurationException(
                        "You can only specify one training type plugin."
                        f" Available: {type(strategy).__name__}, given: {type(plug).__name__}"
                    )
            elif isinstance(plug, PrecisionPlugin):
                if precision is None:
                    precision = plug
                else:
                    raise MisconfigurationException(
                        "You can only specify one precision plugin."
                        f" Available: {type(precision).__name__}, given: {type(plug).__name__}"
                    )
            elif isinstance(plug, CheckpointIO):
                if checkpoint is None:
                    checkpoint = plug
                else:
                    raise MisconfigurationException(
                        "You can only specify one checkpoint plugin."
                        f" Available: {type(checkpoint).__name__}, given: {type(plug).__name__}"
                    )
            elif isinstance(plug, ClusterEnvironment):
                if cluster_environment is None:
                    cluster_environment = plug
                else:
                    raise MisconfigurationException(
                        "You can only specify one cluster environment. Found more than 1 cluster environment plugin"
                    )
            else:
                # List every accepted kind: the loop also takes checkpoint-IO and
                # cluster-environment plugins, not just precision/training-type ones.
                raise MisconfigurationException(
                    f"Found invalid type for plugin {plug}. Expected a precision, checkpoint IO,"
                    " cluster environment or training type plugin."
                )

        self._strategy = strategy
        self._precision_plugin = precision
        self._checkpoint_io = checkpoint
        self._cluster_environment = cluster_environment