def test_ddp_post_local_sgd_comm_hook(tmpdir):
    """Test for DDP post-localSGD hook."""
    model = BoringModel()

    training_type_plugin = DDPPlugin(
        ddp_comm_state=post_localSGD.PostLocalSGDState(
            process_group=None,
            subgroup=None,
            start_localSGD_iter=8,
        ),
        ddp_comm_hook=post_localSGD.post_localSGD_hook,
        model_averaging_period=4,
    )
    trainer = Trainer(
        fast_dev_run=True,
        gpus=2,
        strategy=training_type_plugin,
        default_root_dir=tmpdir,
        sync_batchnorm=True,
    )
    trainer.fit(model)
    trainer_comm_hook = (
        trainer.accelerator.training_type_plugin._model.get_ddp_logging_data().comm_hook
    )
    expected_comm_hook = post_localSGD.post_localSGD_hook.__qualname__
    assert trainer_comm_hook == expected_comm_hook
    assert trainer.state.finished, f"Training failed with {trainer.state}"
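These examples come from the PyTorch Lightning test suite and omit their imports. A minimal preamble that would make them self-contained might look like the sketch below; the exact module paths are assumptions (BoringModel in particular has moved between releases, and DDPPlugin is the pre-1.6 spelling of what later became DDPStrategy):

import pytest
import torch
import torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook as post_localSGD
from pytorch_lightning import Trainer
from pytorch_lightning.plugins import DDPPlugin  # Example #1 (pre-1.6 API)
from pytorch_lightning.strategies import DDPStrategy  # Examples #2-#4
from tests.helpers.boring_model import BoringModel  # test-suite helper; path varies by version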
Example #2
def test_ddp_post_local_sgd_comm_hook(tmpdir):
    """Test for DDP post-localSGD hook."""
    model = BoringModel()
    strategy = TestDDPStrategy(
        expected_ddp_comm_hook_name=post_localSGD.post_localSGD_hook.__qualname__,
        ddp_comm_state=post_localSGD.PostLocalSGDState(
            process_group=None,
            subgroup=None,
            start_localSGD_iter=8,
        ),
        ddp_comm_hook=post_localSGD.post_localSGD_hook,
        model_averaging_period=4,
    )
    trainer = Trainer(
        fast_dev_run=True,
        accelerator="gpu",
        devices=2,
        strategy=strategy,
        default_root_dir=tmpdir,
        sync_batchnorm=True,
        enable_progress_bar=False,
        enable_model_summary=False,
    )
    trainer.fit(model)
    assert trainer.state.finished, f"Training failed with {trainer.state}"
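Example #2 depends on a TestDDPStrategy helper defined elsewhere in the test file. A minimal sketch of what it likely does, assuming it subclasses DDPStrategy and checks the registered comm hook with the same get_ddp_logging_data() call used in Example #1 (the attribute names and the teardown hook are assumptions; the hook-inspection API also differs between torch versions):

class TestDDPStrategy(DDPStrategy):
    """A DDPStrategy that asserts which DDP comm hook was registered (sketch)."""

    def __init__(self, expected_ddp_comm_hook_name, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.expected_ddp_comm_hook_name = expected_ddp_comm_hook_name

    def teardown(self):
        # Inspect the wrapped DistributedDataParallel module before teardown unwraps it.
        comm_hook_name = self._model.get_ddp_logging_data().comm_hook
        assert comm_hook_name == self.expected_ddp_comm_hook_name
        return super().teardown()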
Example #3
def test_post_local_sgd_model_averaging_value_error(average_parameters_mock, tmpdir):
    """Test that when using DDP with post-localSGD a ValueError is thrown when the optmizer is
    ZeroRedundancyOptimizer."""
    from torch.distributed.optim import ZeroRedundancyOptimizer

    class OptimizerModel(BoringModel):
        def configure_optimizers(self):
            return ZeroRedundancyOptimizer(params=self.parameters(), optimizer_class=torch.optim.Adam, lr=0.01)

    model = OptimizerModel()
    strategy = DDPStrategy(
        ddp_comm_state=post_localSGD.PostLocalSGDState(
            process_group=None,
            subgroup=None,
            start_localSGD_iter=8,
        ),
        ddp_comm_hook=post_localSGD.post_localSGD_hook,
        model_averaging_period=4,
    )

    trainer = Trainer(
        fast_dev_run=True,
        gpus=2,
        strategy=strategy,
        default_root_dir=tmpdir,
        sync_batchnorm=True,
    )

    with pytest.raises(ValueError, match="Currently model averaging cannot work with a distributed optimizer"):
        trainer.fit(model)

    average_parameters_mock.assert_not_called()
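Examples #3 and #4 receive an average_parameters_mock argument, which implies the tests are wrapped with a patch over the model-averaging entry point. A sketch of that decorator, assuming Lightning delegates to torch's PeriodicModelAverager (the patch target is an assumption):

from unittest import mock

# Assumed fixture setup: patch the averager so the test can assert whether
# model averaging ran.
@mock.patch("torch.distributed.algorithms.model_averaging.averagers.PeriodicModelAverager.average_parameters")
def test_post_local_sgd_model_averaging_value_error(average_parameters_mock, tmpdir):
    ...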
Example #4
def test_post_local_sgd_model_averaging(average_parameters_mock, tmpdir):
    """Test that when using DDP with post-localSGD, model averaging is called."""
    model = BoringModel()

    # test regular ddp does not call model averaging
    trainer = Trainer(
        fast_dev_run=True,
        accelerator="gpu",
        devices=2,
        strategy="ddp",
        default_root_dir=tmpdir,
        sync_batchnorm=True,
        enable_progress_bar=False,
        enable_model_summary=False,
    )

    trainer.fit(model)
    average_parameters_mock.assert_not_called()

    # test ddp with post-localSGD does call model averaging
    ddp_strategy = DDPStrategy(
        ddp_comm_state=post_localSGD.PostLocalSGDState(
            process_group=None,
            subgroup=None,
            start_localSGD_iter=8,
        ),
        ddp_comm_hook=post_localSGD.post_localSGD_hook,
        model_averaging_period=4,
    )

    trainer = Trainer(
        fast_dev_run=True,
        accelerator="gpu",
        devices=2,
        strategy=ddp_strategy,
        default_root_dir=tmpdir,
        sync_batchnorm=True,
    )

    trainer.fit(model)
    average_parameters_mock.assert_called()
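All four tests assume two CUDA devices; in the Lightning suite such tests are typically guarded by a skip helper. A plain-pytest equivalent guard, as a sketch:

import pytest
import torch

# Skip the whole module when fewer than two GPUs are available (assumed guard).
pytestmark = pytest.mark.skipif(torch.cuda.device_count() < 2, reason="requires 2 GPUs")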