def test_strategy_registry_with_new_strategy():
    """Register a custom strategy class and check what the registry stores.

    Verifies that ``StrategyRegistry.register`` records the description, the
    extra keyword init params and the class's ``distributed_backend``; that
    ``StrategyRegistry.get`` returns an instance of the registered class; and
    that ``StrategyRegistry.remove`` drops the entry again.
    """

    class TestStrategy:

        distributed_backend = "test_strategy"

        def __init__(self, param1, param2):
            self.param1 = param1
            self.param2 = param2

    strategy_name = "test_strategy"
    strategy_description = "Test Strategy"

    StrategyRegistry.register(strategy_name,
                              TestStrategy,
                              description=strategy_description,
                              param1="abc",
                              param2=123)
    try:
        assert strategy_name in StrategyRegistry
        assert StrategyRegistry[strategy_name][
            "description"] == strategy_description
        assert StrategyRegistry[strategy_name]["init_params"] == {
            "param1": "abc",
            "param2": 123
        }
        assert StrategyRegistry[strategy_name][
            "distributed_backend"] == "test_strategy"
        assert isinstance(StrategyRegistry.get(strategy_name), TestStrategy)
    finally:
        # StrategyRegistry is shared module-level state: always unregister, so a
        # failing assertion above does not leak "test_strategy" into other tests.
        StrategyRegistry.remove(strategy_name)

    assert strategy_name not in StrategyRegistry


def test_custom_registered_strategy_to_strategy_flag():
    """A strategy registered under a custom name is usable via ``Trainer(strategy=...)``.

    Registers ``DDPStrategy`` with a custom ``CheckpointIO`` plugin passed as an
    init kwarg. Builds a ``Trainer`` from the registered name, then checks that
    the resulting strategy is a ``DDPStrategy`` carrying that exact plugin object.
    """

    class CustomCheckpointIO(CheckpointIO):
        # No-op implementations: the test only checks that this plugin instance
        # is wired through to the strategy, not that it performs any I/O.
        def save_checkpoint(self, checkpoint, path):
            pass

        def load_checkpoint(self, path):
            pass

        def remove_checkpoint(self, path):
            pass

    custom_checkpoint_io = CustomCheckpointIO()
    strategy_name = "ddp_custom_checkpoint_io"

    # Register the DDP Strategy with your custom CheckpointIO plugin
    StrategyRegistry.register(
        strategy_name,
        DDPStrategy,
        description="DDP Strategy with custom checkpoint io plugin",
        checkpoint_io=custom_checkpoint_io,
    )
    try:
        trainer = Trainer(strategy=strategy_name,
                          accelerator="cpu",
                          devices=2)

        assert isinstance(trainer.strategy, DDPStrategy)
        # Identity, not mere equality: the registered plugin object itself must
        # be the one the strategy ends up using.
        assert trainer.strategy.checkpoint_io is custom_checkpoint_io
    finally:
        # StrategyRegistry is shared module-level state: unregister so this
        # custom name does not leak into (or collide with) other tests.
        StrategyRegistry.remove(strategy_name)