Example #1
0
def test_training_type_plugins_registry_with_new_plugin():
    """Register a throwaway plugin class, check its registry entry, then remove it.

    The registry is shared state, so the removal happens in a ``finally`` block:
    a failing assertion must not leave ``test_plugin`` registered for later tests.
    """
    class TestPlugin:
        def __init__(self, param1, param2):
            self.param1 = param1
            self.param2 = param2

    plugin_name = "test_plugin"
    plugin_description = "Test Plugin"

    TrainingTypePluginsRegistry.register(plugin_name,
                                         TestPlugin,
                                         description=plugin_description,
                                         param1="abc",
                                         param2=123)
    try:
        assert plugin_name in TrainingTypePluginsRegistry
        assert TrainingTypePluginsRegistry[plugin_name][
            "description"] == plugin_description
        assert TrainingTypePluginsRegistry[plugin_name]["init_params"] == {
            "param1": "abc",
            "param2": 123
        }
        # `get` must build an instance of the registered class; TestPlugin's
        # __init__ requires both params, so check they were forwarded intact.
        plugin = TrainingTypePluginsRegistry.get(plugin_name)
        assert isinstance(plugin, TestPlugin)
        assert (plugin.param1, plugin.param2) == ("abc", 123)
    finally:
        # Always undo the registration, even if an assertion above failed.
        TrainingTypePluginsRegistry.remove(plugin_name)

    assert plugin_name not in TrainingTypePluginsRegistry
def test_custom_registered_training_plugin_to_strategy():
    """A registry name passed as ``Trainer(strategy=...)`` resolves to the registered plugin.

    Registers ``DDPPlugin`` together with a custom ``CheckpointIO`` under a new
    name, builds a ``Trainer`` with that name as its strategy, and checks that the
    resulting training type plugin carries the very same checkpoint IO object.
    The registration is removed in a ``finally`` block so the shared registry is
    left clean whether or not the assertions pass.
    """
    class CustomCheckpointIO(CheckpointIO):
        def save_checkpoint(self, checkpoint, path):
            pass

        def load_checkpoint(self, path):
            pass

        def remove_checkpoint(self, path):
            pass

    custom_checkpoint_io = CustomCheckpointIO()
    plugin_name = "ddp_custom_checkpoint_io"

    # Register the DDP Plugin with your custom CheckpointIO plugin
    TrainingTypePluginsRegistry.register(
        plugin_name,
        DDPPlugin,
        description="DDP Plugin with custom checkpoint io plugin",
        checkpoint_io=custom_checkpoint_io,
    )
    try:
        trainer = Trainer(strategy=plugin_name,
                          accelerator="cpu",
                          devices=2)

        assert isinstance(trainer.training_type_plugin, DDPPlugin)
        # Identity, not equality: the registered instance itself must be wired in.
        assert trainer.training_type_plugin.checkpoint_io is custom_checkpoint_io
    finally:
        # Avoid leaking the custom registration into other tests / re-runs.
        TrainingTypePluginsRegistry.remove(plugin_name)
 def _set_training_type_plugin(self) -> None:
     """Derive the training type plugin from ``self.strategy``.

     * A string strategy is instantiated through ``TrainingTypePluginsRegistry``
       when the registry knows it. Every string strategy, registered or not, is
       then forwarded to ``self.set_distributed_mode``.
     * A ``Strategy`` instance is adopted as the plugin directly.
     * Any other value leaves ``self._training_type_plugin`` untouched.
     """
     requested = self.strategy
     if isinstance(requested, str):
         if requested in TrainingTypePluginsRegistry:
             self._training_type_plugin = TrainingTypePluginsRegistry.get(requested)
         self.set_distributed_mode(requested)
     elif isinstance(requested, Strategy):
         self._training_type_plugin = requested
    def handle_given_plugins(self) -> None:
        """Sort ``self.plugins`` into training-type, precision and cluster-environment slots.

        Each entry of ``self.plugins`` is handled as follows:

        * A string registered in ``TrainingTypePluginsRegistry`` is instantiated
          via ``TrainingTypePluginsRegistry.get`` and used as the training type plugin.
        * Any string, registered or not, also clears ``self._distrib_type`` and is
          forwarded to ``self.set_distributed_mode``.
        * ``TrainingTypePlugin``, ``PrecisionPlugin`` and ``ClusterEnvironment``
          instances fill their respective slot.

        Results are written to ``self._training_type_plugin``,
        ``self._precision_plugin`` and ``self._cluster_environment``. The last one
        falls back to ``self.select_cluster_environment()`` when no (truthy)
        cluster environment plugin was given.

        Raises:
            MisconfigurationException: if more than one plugin of the same kind is
                given, or if an entry is none of the accepted kinds.
        """
        training_type = None
        precision = None
        cluster_environment = None

        for plug in self.plugins:
            if isinstance(plug, str) and plug in TrainingTypePluginsRegistry:
                if training_type is None:
                    training_type = TrainingTypePluginsRegistry.get(plug)
                else:
                    raise MisconfigurationException(
                        'You can only specify one precision and one training type plugin.'
                        ' Found more than 1 training type plugin:'
                        f' {TrainingTypePluginsRegistry[plug]["plugin"]} registered to {plug}'
                    )
            if isinstance(plug, str):
                # Reset the distributed type as the user has overridden training type
                # via the plugins argument
                self._distrib_type = None
                self.set_distributed_mode(plug)

            elif isinstance(plug, TrainingTypePlugin):
                if training_type is None:
                    training_type = plug
                else:
                    raise MisconfigurationException(
                        'You can only specify one precision and one training type plugin.'
                        f' Found more than 1 training type plugin: {type(plug).__name__}'
                    )

            elif isinstance(plug, PrecisionPlugin):
                if precision is None:
                    precision = plug
                else:
                    raise MisconfigurationException(
                        'You can only specify one precision and one training type plugin.'
                        f' Found more than 1 precision plugin: {type(plug).__name__}'
                    )

            elif isinstance(plug, ClusterEnvironment):
                if cluster_environment is None:
                    cluster_environment = plug
                else:
                    raise MisconfigurationException(
                        'You can only specify one cluster environment. Found more than 1 cluster environment plugin'
                    )

            else:
                # List every accepted kind so the user knows what is valid here.
                raise MisconfigurationException(
                    f'Found invalid type for plugin {plug}.'
                    ' Expected a precision, training type or cluster environment plugin.'
                )

        self._training_type_plugin = training_type
        self._precision_plugin = precision
        # Auto-detect the environment only when none was passed explicitly.
        self._cluster_environment = cluster_environment or self.select_cluster_environment()
    def handle_given_plugins(self) -> None:
        """Validate ``self.plugins`` and sort them into their component slots.

        The method makes two passes over ``self.plugins``.

        First pass: each plugin for which ``self._is_plugin_training_type`` is true
        is rejected when ``self.strategy`` is also set. Otherwise it triggers a
        deprecation warning, since training types belong in ``strategy``.

        Second pass, classification:

        * A string registered in ``TrainingTypePluginsRegistry`` is instantiated
          via ``TrainingTypePluginsRegistry.get`` and used as the training type.
        * Any string, registered or not, also clears ``self._distrib_type`` and is
          forwarded to ``self.set_distributed_mode``.
        * ``Strategy``, ``PrecisionPlugin``, ``CheckpointIO`` and
          ``ClusterEnvironment`` instances fill their respective slot.

        The training-type slot starts from ``self._training_type_plugin``, where a
        falsy value counts as unset. A plugin already stored there therefore
        counts toward the one-training-type limit.

        Raises:
            MisconfigurationException: if a training type is given both as
                ``strategy`` and in ``plugins``, if more than one plugin of the same
                kind is given, or if an entry is none of the accepted kinds.
        """
        for plug in self.plugins:
            if not self._is_plugin_training_type(plug):
                continue
            if self.strategy is not None:
                raise MisconfigurationException(
                    f"You have passed `Trainer(strategy={self.strategy!r})`"
                    f" and you can only specify one training type plugin, but you have passed {plug} as a plugin."
                )
            rank_zero_deprecation(
                f"Passing {plug} `strategy` to the `plugins` flag in Trainer has been deprecated"
                f" in v1.5 and will be removed in v1.7. Use `Trainer(strategy={plug})` instead."
            )

        # `or None` normalises any falsy pre-set value to "no training type yet".
        training_type = self._training_type_plugin or None
        checkpoint = None
        precision = None
        cluster_environment = None

        for plug in self.plugins:
            if isinstance(plug, str) and plug in TrainingTypePluginsRegistry:
                if training_type is None:
                    training_type = TrainingTypePluginsRegistry.get(plug)
                else:
                    raise MisconfigurationException(
                        "You can only specify one precision and one training type plugin."
                        " Found more than 1 training type plugin:"
                        f' {TrainingTypePluginsRegistry[plug]["plugin"]} registered to {plug}'
                    )
            if isinstance(plug, str):
                # Reset the distributed type as the user has overridden training type
                # via the plugins argument
                self._distrib_type = None
                self.set_distributed_mode(plug)

            elif isinstance(plug, Strategy):
                if training_type is None:
                    training_type = plug
                else:
                    raise MisconfigurationException(
                        "You can only specify one training type plugin."
                        f" Available: {type(training_type).__name__}, given: {type(plug).__name__}"
                    )
            elif isinstance(plug, PrecisionPlugin):
                if precision is None:
                    precision = plug
                else:
                    raise MisconfigurationException(
                        "You can only specify one precision plugin."
                        f" Available: {type(precision).__name__}, given: {type(plug).__name__}"
                    )
            elif isinstance(plug, CheckpointIO):
                if checkpoint is None:
                    checkpoint = plug
                else:
                    raise MisconfigurationException(
                        "You can only specify one checkpoint plugin."
                        f" Available: {type(checkpoint).__name__}, given: {type(plug).__name__}"
                    )
            elif isinstance(plug, ClusterEnvironment):
                if cluster_environment is None:
                    cluster_environment = plug
                else:
                    raise MisconfigurationException(
                        "You can only specify one cluster environment. Found more than 1 cluster environment plugin"
                    )
            else:
                # List every accepted kind so the user knows what is valid here.
                raise MisconfigurationException(
                    f"Found invalid type for plugin {plug}."
                    " Expected a precision, training type, checkpoint IO or cluster environment plugin."
                )

        self._training_type_plugin = training_type
        self._precision_plugin = precision
        self._checkpoint_io = checkpoint
        self._cluster_environment = cluster_environment