def test_ddp_fp16_compress_wrap_sgd_comm_hook(tmpdir):
    """Verify that the FP16-compress wrapper around the PowerSGD hook is registered.

    A ``BoringModel`` is fitted with a ``DDPPlugin`` configured with a PowerSGD
    state, the PowerSGD hook and the FP16 compression wrapper (two GPUs are
    requested). The hook name reported by the DDP logging data must equal the
    ``__qualname__`` of the wrapped hook, and the trainer must end up in
    ``TrainerState.FINISHED``.

    Args:
        tmpdir: Directory used as the trainer's ``default_root_dir``.
    """
    model = BoringModel()

    power_sgd_state = powerSGD.PowerSGDState(process_group=None)
    ddp_plugin = DDPPlugin(
        ddp_comm_state=power_sgd_state,
        ddp_comm_hook=powerSGD.powerSGD_hook,
        ddp_comm_wrapper=default.fp16_compress_wrapper,
        sync_batchnorm=True,
    )

    trainer_kwargs = dict(
        max_epochs=1,
        gpus=2,
        plugins=[ddp_plugin],
        default_root_dir=tmpdir,
        sync_batchnorm=True,
        fast_dev_run=True,
    )
    trainer = Trainer(**trainer_kwargs)
    trainer.fit(model)

    # Read the registered hook name from the logging data of the plugin's
    # wrapped model.
    wrapped_model = trainer.accelerator.training_type_plugin._model
    trainer_comm_hook = wrapped_model.get_ddp_logging_data().comm_hook

    # The wrapper returns a new callable; its qualified name is what is expected.
    wrapped_hook = default.fp16_compress_wrapper(powerSGD.powerSGD_hook)
    expected_comm_hook = wrapped_hook.__qualname__

    assert trainer_comm_hook == expected_comm_hook
    assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"
def test_invalid_powerSGD_state(self): for start_powerSGD_iter, use_error_feedback, warm_start in product( [0, 1], [True, False], [True, False]): if not use_error_feedback and not warm_start: continue with self.assertRaisesRegex( ValueError, "Expect `start_powerSGD_iter` > 1 if `use_error_feedback` or `warm_start` is enabled, " "because PowerSGD can only be applied after the first two iterations in DDP.", ): state = powerSGD.PowerSGDState( process_group=None, matrix_approximation_rank=1, start_powerSGD_iter=start_powerSGD_iter, use_error_feedback=use_error_feedback, warm_start=warm_start, )
def test_ddp_sgd_comm_hook(tmpdir):
    """Verify that the PowerSGD communication hook is registered with DDP.

    A ``BoringModel`` is fitted with a ``DDPStrategy`` configured with a
    PowerSGD state and the PowerSGD hook (two GPUs are requested). The hook
    name reported by the DDP logging data must equal the hook's
    ``__qualname__``, and the trainer state must report ``finished``.

    Args:
        tmpdir: Directory used as the trainer's ``default_root_dir``.
    """
    model = BoringModel()

    comm_state = powerSGD.PowerSGDState(process_group=None)
    ddp_strategy = DDPStrategy(
        ddp_comm_state=comm_state,
        ddp_comm_hook=powerSGD.powerSGD_hook,
    )

    trainer_options = {
        "max_epochs": 1,
        "gpus": 2,
        "strategy": ddp_strategy,
        "default_root_dir": tmpdir,
        "sync_batchnorm": True,
        "fast_dev_run": True,
    }
    trainer = Trainer(**trainer_options)
    trainer.fit(model)

    # Compare the hook name from the strategy model's DDP logging data with
    # the qualified name of the hook that was configured.
    logging_data = trainer.strategy.model.get_ddp_logging_data()
    trainer_comm_hook = logging_data.comm_hook
    expected_comm_hook = powerSGD.powerSGD_hook.__qualname__

    assert trainer_comm_hook == expected_comm_hook
    assert trainer.state.finished, f"Training failed with {trainer.state}"
def test_ddp_sgd_comm_hook(tmpdir):
    """Verify that training finishes with the PowerSGD communication hook.

    The expected hook name (the hook's ``__qualname__``) is handed to
    ``TestDDPStrategy`` together with the PowerSGD state and hook; presumably
    the strategy performs the hook-name comparison itself (not visible from
    this test -- confirm against ``TestDDPStrategy``). This function only
    asserts that the trainer state reports ``finished``.

    Args:
        tmpdir: Directory used as the trainer's ``default_root_dir``.
    """
    model = BoringModel()

    sgd_hook = powerSGD.powerSGD_hook
    strategy = TestDDPStrategy(
        expected_ddp_comm_hook_name=sgd_hook.__qualname__,
        ddp_comm_state=powerSGD.PowerSGDState(process_group=None),
        ddp_comm_hook=sgd_hook,
    )

    trainer_config = dict(
        max_epochs=1,
        accelerator="gpu",
        devices=2,
        strategy=strategy,
        default_root_dir=tmpdir,
        sync_batchnorm=True,
        fast_dev_run=True,
        # Progress bar and model summary are switched off for this run.
        enable_progress_bar=False,
        enable_model_summary=False,
    )
    trainer = Trainer(**trainer_config)
    trainer.fit(model)

    assert trainer.state.finished, f"Training failed with {trainer.state}"
def test_ddp_fp16_compress_wrap_sgd_comm_hook(tmpdir):
    """Verify that the FP16-compress wrapper around the PowerSGD hook is registered.

    A ``BoringModel`` is fitted with a ``DDPStrategy`` configured with a
    PowerSGD state, the PowerSGD hook and the FP16 compression wrapper (two
    GPU devices are requested). The hook name reported by the DDP logging data
    must equal the ``__qualname__`` of the wrapped hook, and the trainer state
    must report ``finished``.

    Args:
        tmpdir: Directory used as the trainer's ``default_root_dir``.
    """
    model = BoringModel()

    base_hook = powerSGD.powerSGD_hook
    hook_wrapper = default.fp16_compress_wrapper
    strategy = DDPStrategy(
        ddp_comm_state=powerSGD.PowerSGDState(process_group=None),
        ddp_comm_hook=base_hook,
        ddp_comm_wrapper=hook_wrapper,
    )

    trainer = Trainer(
        max_epochs=1,
        accelerator="gpu",
        devices=2,
        strategy=strategy,
        default_root_dir=tmpdir,
        sync_batchnorm=True,
        fast_dev_run=True,
    )
    trainer.fit(model)

    # Hook name as reported by the strategy model's DDP logging data.
    ddp_logging_data = trainer.strategy.model.get_ddp_logging_data()
    trainer_comm_hook = ddp_logging_data.comm_hook

    # Wrapping the hook again yields a callable with the expected qualified name.
    expected_comm_hook = hook_wrapper(base_hook).__qualname__

    assert trainer_comm_hook == expected_comm_hook
    assert trainer.state.finished, f"Training failed with {trainer.state}"