def test_training_type_plugins_registry_with_new_plugin():
    """Register a custom plugin class and check the registry's metadata and lookup.

    The entry is removed in a ``finally`` block so that a failing assertion
    cannot leave ``test_plugin`` behind in ``TrainingTypePluginsRegistry``.
    """

    class TestPlugin:
        def __init__(self, param1, param2):
            self.param1 = param1
            self.param2 = param2

    plugin_name = "test_plugin"
    plugin_description = "Test Plugin"

    TrainingTypePluginsRegistry.register(
        plugin_name,
        TestPlugin,
        description=plugin_description,
        param1="abc",
        param2=123,
    )
    try:
        assert plugin_name in TrainingTypePluginsRegistry
        assert TrainingTypePluginsRegistry[plugin_name]["description"] == plugin_description
        assert TrainingTypePluginsRegistry[plugin_name]["init_params"] == {
            "param1": "abc",
            "param2": 123,
        }
        # ``get`` is expected to hand back an instance of the registered class.
        assert isinstance(TrainingTypePluginsRegistry.get(plugin_name), TestPlugin)
    finally:
        # Always undo the registration, even when an assertion above fails.
        TrainingTypePluginsRegistry.remove(plugin_name)

    assert plugin_name not in TrainingTypePluginsRegistry
def test_custom_registered_training_plugin_to_strategy():
    """Check that a registered DDP entry with a custom ``CheckpointIO`` works as ``Trainer(strategy=...)``.

    The registration is undone in a ``finally`` block so the extra
    ``ddp_custom_checkpoint_io`` entry does not stay in
    ``TrainingTypePluginsRegistry`` after this test, whether it passes or fails.
    """

    class CustomCheckpointIO(CheckpointIO):
        def save_checkpoint(self, checkpoint, path):
            pass

        def load_checkpoint(self, path):
            pass

        def remove_checkpoint(self, path):
            pass

    custom_checkpoint_io = CustomCheckpointIO()
    strategy_name = "ddp_custom_checkpoint_io"

    # Register the DDP Plugin with your custom CheckpointIO plugin
    TrainingTypePluginsRegistry.register(
        strategy_name,
        DDPPlugin,
        description="DDP Plugin with custom checkpoint io plugin",
        checkpoint_io=custom_checkpoint_io,
    )
    try:
        trainer = Trainer(strategy=strategy_name, accelerator="cpu", devices=2)

        assert isinstance(trainer.training_type_plugin, DDPPlugin)
        assert trainer.training_type_plugin.checkpoint_io == custom_checkpoint_io
    finally:
        # Remove the entry this test added to the registry.
        TrainingTypePluginsRegistry.remove(strategy_name)
def _set_training_type_plugin(self) -> None:
    """Resolve ``self.strategy`` into ``self._training_type_plugin``.

    A string strategy is looked up in ``TrainingTypePluginsRegistry`` (when
    registered there) and is always forwarded to ``set_distributed_mode``.
    A ``Strategy`` instance is stored as-is. Any other value is left untouched.
    """
    strategy = self.strategy

    if isinstance(strategy, str):
        # Registered names yield a plugin from the registry; unregistered
        # strings only go through ``set_distributed_mode``.
        if strategy in TrainingTypePluginsRegistry:
            self._training_type_plugin = TrainingTypePluginsRegistry.get(strategy)
        self.set_distributed_mode(strategy)
        return

    if isinstance(strategy, Strategy):
        self._training_type_plugin = strategy
def handle_given_plugins(self) -> None:
    """Sort the entries of ``self.plugins`` into training-type, precision and cluster-environment slots.

    Each slot accepts at most one plugin. String entries are resolved through
    ``TrainingTypePluginsRegistry`` when registered there and are always passed
    to ``set_distributed_mode``. When no cluster environment is given, the
    result of ``select_cluster_environment()`` is used.

    Raises:
        MisconfigurationException: if a slot receives more than one plugin, or
            an entry is none of the recognised kinds.
    """
    chosen_training_type = None
    chosen_precision = None
    chosen_cluster_env = None

    for plug in self.plugins:
        if isinstance(plug, str):
            if plug in TrainingTypePluginsRegistry:
                if chosen_training_type is not None:
                    raise MisconfigurationException(
                        'You can only specify one precision and one training type plugin.'
                        ' Found more than 1 training type plugin:'
                        f' {TrainingTypePluginsRegistry[plug]["plugin"]} registered to {plug}'
                    )
                chosen_training_type = TrainingTypePluginsRegistry.get(plug)
            # Reset the distributed type as the user has overridden training type
            # via the plugins argument
            self._distrib_type = None
            self.set_distributed_mode(plug)
            continue

        if isinstance(plug, TrainingTypePlugin):
            if chosen_training_type is not None:
                raise MisconfigurationException(
                    'You can only specify one precision and one training type plugin.'
                    f' Found more than 1 training type plugin: {type(plug).__name__}'
                )
            chosen_training_type = plug
            continue

        if isinstance(plug, PrecisionPlugin):
            if chosen_precision is not None:
                raise MisconfigurationException(
                    'You can only specify one precision and one training type plugin.'
                    f' Found more than 1 precision plugin: {type(plug).__name__}'
                )
            chosen_precision = plug
            continue

        if isinstance(plug, ClusterEnvironment):
            if chosen_cluster_env is not None:
                raise MisconfigurationException(
                    'You can only specify one cluster environment. Found more than 1 cluster environment plugin'
                )
            chosen_cluster_env = plug
            continue

        raise MisconfigurationException(
            f'Found invalid type for plugin {plug}. Expected a precision or training type plugin.'
        )

    self._training_type_plugin = chosen_training_type
    self._precision_plugin = chosen_precision
    # Fall back to the automatically selected environment when none was passed.
    self._cluster_environment = chosen_cluster_env or self.select_cluster_environment()
def handle_given_plugins(self) -> None:
    """Validate ``self.plugins`` and sort them into training-type, precision, checkpoint and cluster-environment slots.

    A first pass rejects training-type plugins when ``self.strategy`` is set,
    and otherwise emits a deprecation message for each of them. A second pass
    assigns each entry to its slot, starting from any training type plugin
    already stored on ``self``. Each slot accepts at most one plugin.

    Raises:
        MisconfigurationException: if a training-type plugin is passed together
            with ``strategy``, a slot receives more than one plugin, or an
            entry is none of the recognised kinds.
    """
    # Pass 1: reject or deprecate training-type plugins passed via ``plugins``.
    for plug in self.plugins:
        if self.strategy is not None and self._is_plugin_training_type(plug):
            raise MisconfigurationException(
                f"You have passed `Trainer(strategy={self.strategy!r})`"
                f" and you can only specify one training type plugin, but you have passed {plug} as a plugin."
            )
        if self._is_plugin_training_type(plug):
            rank_zero_deprecation(
                f"Passing {plug} `strategy` to the `plugins` flag in Trainer has been deprecated"
                f" in v1.5 and will be removed in v1.7. Use `Trainer(strategy={plug})` instead."
            )

    # Pass 2: fill the slots. A falsy stored plugin is treated as "not set".
    existing = self._training_type_plugin
    training_type = existing if existing else None
    checkpoint = None
    precision = None
    cluster_environment = None

    for plug in self.plugins:
        if isinstance(plug, str):
            if plug in TrainingTypePluginsRegistry:
                if training_type is not None:
                    raise MisconfigurationException(
                        "You can only specify one precision and one training type plugin."
                        " Found more than 1 training type plugin:"
                        f' {TrainingTypePluginsRegistry[plug]["plugin"]} registered to {plug}'
                    )
                training_type = TrainingTypePluginsRegistry.get(plug)
            # Reset the distributed type as the user has overridden training type
            # via the plugins argument
            self._distrib_type = None
            self.set_distributed_mode(plug)
            continue

        if isinstance(plug, Strategy):
            if training_type is not None:
                raise MisconfigurationException(
                    "You can only specify one training type plugin."
                    f" Available: {type(training_type).__name__}, given: {type(plug).__name__}"
                )
            training_type = plug
            continue

        if isinstance(plug, PrecisionPlugin):
            if precision is not None:
                raise MisconfigurationException(
                    "You can only specify one precision plugin."
                    f" Available: {type(precision).__name__}, given: {type(plug).__name__}"
                )
            precision = plug
            continue

        if isinstance(plug, CheckpointIO):
            if checkpoint is not None:
                raise MisconfigurationException(
                    "You can only specify one checkpoint plugin."
                    f" Available: {type(checkpoint).__name__}, given: {type(plug).__name__}"
                )
            checkpoint = plug
            continue

        if isinstance(plug, ClusterEnvironment):
            if cluster_environment is not None:
                raise MisconfigurationException(
                    "You can only specify one cluster environment. Found more than 1 cluster environment plugin"
                )
            cluster_environment = plug
            continue

        raise MisconfigurationException(
            f"Found invalid type for plugin {plug}. Expected a precision or training type plugin."
        )

    self._training_type_plugin = training_type
    self._precision_plugin = precision
    self._checkpoint_io = checkpoint
    self._cluster_environment = cluster_environment