def test_ddp_post_local_sgd_comm_hook(tmpdir):
    """Check that the post-localSGD hook is the comm hook reported by the DDP-wrapped model.

    A ``DDPPlugin`` is configured with a ``PostLocalSGDState`` and the
    ``post_localSGD_hook``; after fitting, the ``comm_hook`` field of the
    wrapped model's DDP logging data must equal the hook's ``__qualname__``.

    Args:
        tmpdir: directory passed to the trainer as ``default_root_dir``.
    """
    boring_model = BoringModel()

    # Build the hook state separately so the plugin construction stays readable.
    hook_state = post_localSGD.PostLocalSGDState(
        process_group=None,
        subgroup=None,
        start_localSGD_iter=8,
    )
    ddp_plugin = DDPPlugin(
        ddp_comm_state=hook_state,
        ddp_comm_hook=post_localSGD.post_localSGD_hook,
        model_averaging_period=4,
    )

    trainer = Trainer(
        fast_dev_run=True,
        gpus=2,
        strategy=ddp_plugin,
        default_root_dir=tmpdir,
        sync_batchnorm=True,
    )
    trainer.fit(boring_model)

    # The wrapped model exposes the registered hook's name through its DDP logging data.
    ddp_logging_data = trainer.accelerator.training_type_plugin._model.get_ddp_logging_data()
    registered_hook_name = ddp_logging_data.comm_hook
    assert registered_hook_name == post_localSGD.post_localSGD_hook.__qualname__
    assert trainer.state.finished, f"Training failed with {trainer.state}"
def test_ddp_post_local_sgd_comm_hook(tmpdir):
    """Fit a model on 2 GPU devices with the post-localSGD comm hook and check training finishes.

    The expected hook name is handed to ``TestDDPStrategy`` via
    ``expected_ddp_comm_hook_name``.
    NOTE(review): ``TestDDPStrategy`` is defined outside this view; presumably
    it performs the hook-name check itself — confirm against its definition.

    Args:
        tmpdir: directory passed to the trainer as ``default_root_dir``.
    """
    boring_model = BoringModel()

    comm_hook = post_localSGD.post_localSGD_hook
    hook_state = post_localSGD.PostLocalSGDState(
        process_group=None,
        subgroup=None,
        start_localSGD_iter=8,
    )
    ddp_strategy = TestDDPStrategy(
        expected_ddp_comm_hook_name=comm_hook.__qualname__,
        ddp_comm_state=hook_state,
        ddp_comm_hook=comm_hook,
        model_averaging_period=4,
    )

    trainer = Trainer(
        fast_dev_run=True,
        accelerator="gpu",
        devices=2,
        strategy=ddp_strategy,
        default_root_dir=tmpdir,
        sync_batchnorm=True,
        enable_progress_bar=False,
        enable_model_summary=False,
    )
    trainer.fit(boring_model)

    assert trainer.state.finished, f"Training failed with {trainer.state}"
def test_post_local_sgd_model_averaging_value_error(average_parameters_mock, tmpdir):
    """Test that DDP with post-localSGD raises a ValueError when the optimizer is a ZeroRedundancyOptimizer.

    Args:
        average_parameters_mock: mock supplied from outside this view
            (NOTE(review): presumably injected by a patch decorator on the
            model-averaging routine — confirm); it must not have been called
            once ``fit`` has raised.
        tmpdir: directory passed to the trainer as ``default_root_dir``.
    """
    # Imported locally, as in the original test, so the module import does not require it.
    from torch.distributed.optim import ZeroRedundancyOptimizer

    class OptimizerModel(BoringModel):
        """BoringModel variant whose optimizer is a ZeroRedundancyOptimizer wrapping Adam."""

        def configure_optimizers(self):
            return ZeroRedundancyOptimizer(params=self.parameters(), optimizer_class=torch.optim.Adam, lr=0.01)

    model = OptimizerModel()
    strategy = DDPStrategy(
        ddp_comm_state=post_localSGD.PostLocalSGDState(
            process_group=None,
            subgroup=None,
            start_localSGD_iter=8,
        ),
        ddp_comm_hook=post_localSGD.post_localSGD_hook,
        model_averaging_period=4,
    )

    # Use the `accelerator`/`devices` pair instead of the deprecated `gpus` flag,
    # matching the other DDPStrategy-based tests in this file.
    trainer = Trainer(
        fast_dev_run=True,
        accelerator="gpu",
        devices=2,
        strategy=strategy,
        default_root_dir=tmpdir,
        sync_batchnorm=True,
    )

    with pytest.raises(ValueError, match="Currently model averaging cannot work with a distributed optimizer"):
        trainer.fit(model)

    average_parameters_mock.assert_not_called()
def test_post_local_sgd_model_averaging(average_parameters_mock, tmpdir):
    """Check that model averaging is invoked with post-localSGD DDP but not with plain DDP.

    Args:
        average_parameters_mock: mock supplied from outside this view
            (NOTE(review): presumably injected by a patch decorator on the
            model-averaging routine — confirm).
        tmpdir: directory passed to both trainers as ``default_root_dir``.
    """
    boring_model = BoringModel()

    # Settings shared by both trainer runs.
    shared_trainer_kwargs = dict(
        fast_dev_run=True,
        accelerator="gpu",
        devices=2,
        default_root_dir=tmpdir,
        sync_batchnorm=True,
    )

    # Run 1: regular DDP must leave the averaging mock untouched.
    plain_ddp_trainer = Trainer(
        strategy="ddp",
        enable_progress_bar=False,
        enable_model_summary=False,
        **shared_trainer_kwargs,
    )
    plain_ddp_trainer.fit(boring_model)
    average_parameters_mock.assert_not_called()

    # Run 2: DDP configured with the post-localSGD hook must call model averaging.
    hook_state = post_localSGD.PostLocalSGDState(
        process_group=None,
        subgroup=None,
        start_localSGD_iter=8,
    )
    post_local_sgd_strategy = DDPStrategy(
        ddp_comm_state=hook_state,
        ddp_comm_hook=post_localSGD.post_localSGD_hook,
        model_averaging_period=4,
    )
    post_local_sgd_trainer = Trainer(strategy=post_local_sgd_strategy, **shared_trainer_kwargs)
    post_local_sgd_trainer.fit(boring_model)
    average_parameters_mock.assert_called()