def test_strategy_registry_with_new_strategy():
    class TestStrategy:

        distributed_backend = "test_strategy"

        def __init__(self, param1, param2):
            self.param1 = param1
            self.param2 = param2

    strategy_name = "test_strategy"
    strategy_description = "Test Strategy"

    StrategyRegistry.register(strategy_name, TestStrategy, description=strategy_description, param1="abc", param2=123)

    assert strategy_name in StrategyRegistry
    assert StrategyRegistry[strategy_name]["description"] == strategy_description
    assert StrategyRegistry[strategy_name]["init_params"] == {"param1": "abc", "param2": 123}
    assert StrategyRegistry[strategy_name]["distributed_backend"] == "test_strategy"
    assert isinstance(StrategyRegistry.get(strategy_name), TestStrategy)

    StrategyRegistry.remove(strategy_name)
    assert strategy_name not in StrategyRegistry
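# A minimal sketch of a registry with the shape this test exercises. This is an
# illustrative assumption, not Lightning's actual ``StrategyRegistry`` implementation:
# it models only the behavior asserted above (description, init_params, the
# ``distributed_backend`` metadata, ``get``, and ``remove``).
class _SketchStrategyRegistry(dict):
    def register(self, name, strategy_cls, description="", **init_params):
        # Store the class plus enough metadata to re-instantiate it later.
        self[name] = {
            "strategy": strategy_cls,
            "description": description,
            "init_params": init_params,
            "distributed_backend": getattr(strategy_cls, "distributed_backend", None),
        }

    def get(self, name):
        # Instantiate the registered class with the captured init parameters.
        data = self[name]
        return data["strategy"](**data["init_params"])

    def remove(self, name):
        del self[name]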
def test_custom_registered_strategy_to_strategy_flag():
    class CustomCheckpointIO(CheckpointIO):
        def save_checkpoint(self, checkpoint, path):
            pass

        def load_checkpoint(self, path):
            pass

        def remove_checkpoint(self, path):
            pass

    custom_checkpoint_io = CustomCheckpointIO()

    # Register the DDP Strategy with your custom CheckpointIO plugin
    StrategyRegistry.register(
        "ddp_custom_checkpoint_io",
        DDPStrategy,
        description="DDP Strategy with custom checkpoint io plugin",
        checkpoint_io=custom_checkpoint_io,
    )
    trainer = Trainer(strategy="ddp_custom_checkpoint_io", accelerator="cpu", devices=2)

    assert isinstance(trainer.strategy, DDPStrategy)
    assert trainer.strategy.checkpoint_io == custom_checkpoint_io
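# Registering a strategy is not the only way to attach a custom ``CheckpointIO``:
# Lightning also accepts the plugin instance directly through ``Trainer(plugins=...)``.
# A minimal sketch, assuming a module-level ``CustomCheckpointIO`` like the class
# defined inside the test above. Illustrative, not an authoritative API reference.
def test_custom_checkpoint_io_via_plugins_flag():
    custom_checkpoint_io = CustomCheckpointIO()
    trainer = Trainer(strategy="ddp", accelerator="cpu", devices=2, plugins=[custom_checkpoint_io])

    assert trainer.strategy.checkpoint_io == custom_checkpoint_io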
def _set_strategy(self) -> None:
    if isinstance(self._strategy_flag, str) and self._strategy_flag in StrategyRegistry:
        self._strategy = StrategyRegistry.get(self._strategy_flag)
    if isinstance(self._strategy_flag, str):
        self.set_distributed_mode(self._strategy_flag)
    elif isinstance(self._strategy_flag, Strategy):
        self._strategy = self._strategy_flag
def _init_strategy(self) -> None:
    """Instantiate the Strategy given the setting of ``_strategy_flag``."""
    if isinstance(self._strategy_flag, HorovodStrategy) or self._strategy_flag == "horovod":
        # Horovod handling has to happen before the strategy is initialized because
        # HorovodStrategy requires hvd.init() to run first.
        # TODO: lazily initialize and set up the Horovod strategy's `global_rank`
        self._handle_horovod()
    if isinstance(self._strategy_flag, str):
        self.strategy = StrategyRegistry.get(self._strategy_flag)
    elif isinstance(self._strategy_flag, Strategy):
        self.strategy = self._strategy_flag
    else:
        raise RuntimeError(f"{self._strategy_flag} is not a valid strategy type: {type(self._strategy_flag)}")
def _init_strategy(self) -> None:
    """Instantiate the Strategy given the setting of ``_strategy_flag``."""
    if isinstance(self._strategy_flag, HorovodStrategy) or self._strategy_flag == "horovod":
        # Horovod handling has to happen before the strategy is initialized because
        # HorovodStrategy requires hvd.init() to run first.
        # TODO: lazily initialize and set up the Horovod strategy's `global_rank`
        self._handle_horovod()
    if isinstance(self._strategy_flag, str):
        if self._strategy_flag == "ddp2":
            # TODO: remove this error in v1.8
            raise ValueError(
                "The DDP2 strategy is no longer supported. For single-node use, we recommend `strategy='ddp'` or"
                " `strategy='dp'` as a replacement. If you need DDP2, you will need `torch < 1.9`,"
                " `pytorch-lightning < 1.5`, and set it as `accelerator='ddp2'`."
            )
        self.strategy = StrategyRegistry.get(self._strategy_flag)
    elif isinstance(self._strategy_flag, Strategy):
        self.strategy = self._strategy_flag
    else:
        raise RuntimeError(f"{self._strategy_flag} is not a valid strategy type: {type(self._strategy_flag)}")
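# A hedged sketch of how the ddp2 guard above surfaces to users, in the same pytest
# style as the tests earlier in this section. Illustrative only: the expected message
# is quoted from the method above, not verified against every Lightning version.
import pytest


def test_ddp2_strategy_raises():
    with pytest.raises(ValueError, match="The DDP2 strategy is no longer supported"):
        Trainer(strategy="ddp2")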
def __init__(
    self,
    devices: Optional[Union[List[int], str, int]] = None,
    num_nodes: int = 1,
    accelerator: Optional[Union[str, Accelerator]] = None,
    strategy: Optional[Union[str, Strategy]] = None,
    plugins: Optional[Union[PLUGIN_INPUT, List[PLUGIN_INPUT]]] = None,
    precision: Union[int, str] = 32,
    amp_type: str = "native",
    amp_level: Optional[str] = None,
    sync_batchnorm: bool = False,
    benchmark: Optional[bool] = None,
    replace_sampler_ddp: bool = True,
    deterministic: bool = False,
    num_processes: Optional[int] = None,  # deprecated
    tpu_cores: Optional[Union[List[int], int]] = None,  # deprecated
    ipus: Optional[int] = None,  # deprecated
    gpus: Optional[Union[List[int], str, int]] = None,  # deprecated
    gpu_ids: Optional[List[int]] = None,  # TODO can be removed
) -> None:
    """The AcceleratorConnector parses several Trainer arguments and instantiates the Strategy,
    including other components such as the Accelerator and Precision plugins.

    A. The accelerator flag could be:
        1. strategy class (deprecated in 1.5, will be removed in 1.7)
        2. strategy str (deprecated in 1.5, will be removed in 1.7)
        3. accelerator class
        4. accelerator str
        5. accelerator auto

    B. The strategy flag could be:
        1. strategy class
        2. strategy str registered with StrategyRegistry
        3. strategy str in the _strategy_type enum, which is listed in each strategy as its backend
           (these are registered too, and _strategy_type could be deprecated)

    C. The plugins flag could be:
        1. List of str, which could contain:
            i. strategy str
            ii. precision str (not supported in the old accelerator_connector version)
            iii. checkpoint_io str (not supported in the old accelerator_connector version)
            iv. cluster_environment str (not supported in the old accelerator_connector version)
        2. List of classes, which could contain:
            i. strategy class (deprecated in 1.5, will be removed in 1.7)
            ii. precision class (should be removed, and the precision flag should allow users to pass classes)
            iii. checkpoint_io class
            iv. cluster_environment class

    Priorities when multiple flags conflict:
        A. Class > str
        B. Strategy > Accelerator/precision/plugins
        C. TODO: when multiple flags are set to the same thing
    """
    if benchmark and deterministic:
        rank_zero_warn(
            "You passed `deterministic=True` and `benchmark=True`. Note that PyTorch ignores"
            " torch.backends.cudnn.deterministic=True when torch.backends.cudnn.benchmark=True.",
        )
    self.benchmark = not deterministic if benchmark is None else benchmark
    # TODO: move to gpu accelerator
    torch.backends.cudnn.benchmark = self.benchmark
    self.replace_sampler_ddp = replace_sampler_ddp
    self._init_deterministic(deterministic)

    # 1. Parsing flags
    # Get registered strategies, built-in accelerators and precision plugins
    self._registered_strategies = StrategyRegistry.available_strategies()
    self._accelerator_types = AcceleratorRegistry.available_accelerators()
    self._precision_types = ("16", "32", "64", "bf16", "mixed")

    # Raise an exception if there are conflicts between flags.
    # Set each valid flag to `self._x_flag` after validation.
    # Example: If accelerator is set to a strategy type, set `self._strategy_flag = accelerator`.
    # For devices: assign gpus, ipus, etc. to the accelerator flag and devices flag.
    self._strategy_flag: Optional[Union[Strategy, str]] = None
    self._accelerator_flag: Optional[Union[Accelerator, str]] = None
    self._precision_flag: Optional[Union[int, str]] = None
    self._precision_plugin_flag: Optional[PrecisionPlugin] = None
    self._cluster_environment_flag: Optional[Union[ClusterEnvironment, str]] = None
    self._parallel_devices: List[Union[int, torch.device]] = []
    self._layer_sync: Optional[LayerSync] = NativeSyncBatchNorm() if sync_batchnorm else None
    self.checkpoint_io: Optional[CheckpointIO] = None
    self._amp_type_flag: Optional[LightningEnum] = None
    self._amp_level_flag: Optional[str] = amp_level

    self._check_config_and_set_final_flags(
        strategy=strategy,
        accelerator=accelerator,
        precision=precision,
        plugins=plugins,
        amp_type=amp_type,
        amp_level=amp_level,
        sync_batchnorm=sync_batchnorm,
    )
    self._check_device_config_and_set_final_flags(
        devices=devices,
        num_nodes=num_nodes,
        num_processes=num_processes,
        gpus=gpus,
        ipus=ipus,
        tpu_cores=tpu_cores,
    )

    # 2. Instantiate Accelerator
    # handle `auto` and `None`
    self._set_accelerator_if_ipu_strategy_is_passed()
    if self._accelerator_flag == "auto" or self._accelerator_flag is None:
        self._accelerator_flag = self._choose_accelerator()
    self._set_parallel_devices_and_init_accelerator()

    # 3. Instantiate ClusterEnvironment
    self.cluster_environment: ClusterEnvironment = self._choose_and_init_cluster_environment()

    # 4. Instantiate Strategy - Part 1
    if self._strategy_flag is None:
        self._strategy_flag = self._choose_strategy()
    # In specific cases, ignore user selection and fall back to a different strategy
    self._check_strategy_and_fallback()
    self._init_strategy()

    # 5. Instantiate Precision Plugin
    self.precision_plugin = self._check_and_init_precision()

    # 6. Instantiate Strategy - Part 2
    self._lazy_init_strategy()
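# A hedged usage sketch of the precedence rules documented in the docstring above
# ("Class > str", "Strategy > Accelerator/precision/plugins"). Both calls below request
# DDP over two CPU processes; when string flags and a Strategy instance conflict, the
# instance wins. Illustrative only, not an authoritative API reference.
from pytorch_lightning import Trainer
from pytorch_lightning.strategies import DDPStrategy

trainer_from_str = Trainer(strategy="ddp", accelerator="cpu", devices=2)
trainer_from_cls = Trainer(strategy=DDPStrategy(), accelerator="cpu", devices=2)
assert type(trainer_from_str.strategy) is type(trainer_from_cls.strategy)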
def handle_given_plugins(self) -> None:

    for plug in self.plugins:
        if self._strategy_flag is not None and self._is_plugin_training_type(plug):
            raise MisconfigurationException(
                f"You have passed `Trainer(strategy={self._strategy_flag!r})`"
                f" and you can only specify one training type plugin, but you have passed {plug} as a plugin."
            )
        if self._is_plugin_training_type(plug):
            rank_zero_deprecation(
                f"Passing {plug} `strategy` to the `plugins` flag in Trainer has been deprecated"
                f" in v1.5 and will be removed in v1.7. Use `Trainer(strategy={plug})` instead."
            )

    strategy = self._strategy or None
    checkpoint = None
    precision = None
    cluster_environment = None

    for plug in self.plugins:
        if isinstance(plug, str) and plug in StrategyRegistry:
            if strategy is None:
                strategy = StrategyRegistry.get(plug)
            else:
                raise MisconfigurationException(
                    "You can only specify one precision and one training type plugin."
                    " Found more than 1 training type plugin:"
                    f' {StrategyRegistry[plug]["strategy"]} registered to {plug}'
                )
        if isinstance(plug, str):
            # Reset the distributed type as the user has overridden training type
            # via the plugins argument
            self._strategy_type = None
            self.set_distributed_mode(plug)

        elif isinstance(plug, Strategy):
            if strategy is None:
                strategy = plug
            else:
                raise MisconfigurationException(
                    "You can only specify one training type plugin."
                    f" Available: {type(strategy).__name__}, given: {type(plug).__name__}"
                )
        elif isinstance(plug, PrecisionPlugin):
            if precision is None:
                precision = plug
            else:
                raise MisconfigurationException(
                    "You can only specify one precision plugin."
                    f" Available: {type(precision).__name__}, given: {type(plug).__name__}"
                )
        elif isinstance(plug, CheckpointIO):
            if checkpoint is None:
                checkpoint = plug
            else:
                raise MisconfigurationException(
                    "You can only specify one checkpoint plugin."
                    f" Available: {type(checkpoint).__name__}, given: {type(plug).__name__}"
                )
        elif isinstance(plug, ClusterEnvironment):
            if cluster_environment is None:
                cluster_environment = plug
            else:
                raise MisconfigurationException(
                    "You can only specify one cluster environment. Found more than 1 cluster environment plugin"
                )
        else:
            raise MisconfigurationException(
                f"Found invalid type for plugin {plug}. Expected a precision or training type plugin."
            )

    self._strategy = strategy
    self._precision_plugin = precision
    self._checkpoint_io = checkpoint
    self._cluster_environment = cluster_environment
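# A hedged sketch of the behavior enforced above, in the same pytest style as the tests
# earlier in this section. It assumes a Lightning version (>=1.5, <1.7) in which the
# `plugins` flag still accepts training type plugins; the expected messages are quoted
# from the method above.
import pytest


def test_strategy_via_plugins_flag_is_deprecated():
    # Passing a strategy through `plugins` warns and points to the `strategy` flag.
    with pytest.deprecated_call(match="has been deprecated in v1.5"):
        Trainer(plugins=["ddp"])

    # Combining both flags is rejected outright.
    with pytest.raises(MisconfigurationException, match="only specify one training type plugin"):
        Trainer(strategy="ddp", plugins=[DDPStrategy()])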