def test_ddp_fp16_compress_wrap_sgd_comm_hook(tmpdir):
    """Fit with the PowerSGD comm hook wrapped by the FP16 compress wrapper.

    A ``BoringModel`` is fitted through a ``DDPPlugin`` configured with a
    PowerSGD state, the PowerSGD hook and the FP16 compress wrapper. After
    fitting, the comm hook name reported by the wrapped model's DDP logging
    data must equal the qualified name of the wrapped hook, and the trainer
    state must be ``TrainerState.FINISHED``.
    """
    model = BoringModel()

    power_sgd_state = powerSGD.PowerSGDState(process_group=None)
    ddp_plugin = DDPPlugin(
        ddp_comm_state=power_sgd_state,
        ddp_comm_hook=powerSGD.powerSGD_hook,
        ddp_comm_wrapper=default.fp16_compress_wrapper,
        sync_batchnorm=True,
    )

    trainer_kwargs = dict(
        max_epochs=1,
        gpus=2,
        plugins=[ddp_plugin],
        default_root_dir=tmpdir,
        sync_batchnorm=True,
        fast_dev_run=True,
    )
    trainer = Trainer(**trainer_kwargs)
    trainer.fit(model)

    # Name of the hook as recorded by the model held by the training type plugin.
    wrapped_model = trainer.accelerator.training_type_plugin._model
    registered_hook_name = wrapped_model.get_ddp_logging_data().comm_hook

    # Name the hook should have once the FP16 wrapper is applied to it.
    wrapped_hook = default.fp16_compress_wrapper(powerSGD.powerSGD_hook)

    assert registered_hook_name == wrapped_hook.__qualname__
    assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"
Beispiel #2
0
 def test_invalid_powerSGD_state(self):
     """Check that ``PowerSGDState`` rejects a too-small ``start_powerSGD_iter``.

     Every combination of ``start_powerSGD_iter`` in ``{0, 1}`` with
     ``use_error_feedback`` and ``warm_start`` is tried, except the one where
     both flags are off. Each remaining combination must raise ``ValueError``
     whose message matches the pattern asserted below.
     """
     for start_powerSGD_iter, use_error_feedback, warm_start in product(
             [0, 1], [True, False], [True, False]):
         # The expected error is only about the case where at least one of
         # the two flags is enabled, so the all-off combination is skipped.
         if not use_error_feedback and not warm_start:
             continue
         # subTest makes a failure report the offending parameter combination
         # and lets the remaining combinations run instead of stopping early.
         with self.subTest(
                 start_powerSGD_iter=start_powerSGD_iter,
                 use_error_feedback=use_error_feedback,
                 warm_start=warm_start,
         ):
             with self.assertRaisesRegex(
                     ValueError,
                     "Expect `start_powerSGD_iter` > 1 if `use_error_feedback` or `warm_start` is enabled, "
                     "because PowerSGD can only be applied after the first two iterations in DDP.",
             ):
                 # Only construction is under test; the instance is not needed.
                 powerSGD.PowerSGDState(
                     process_group=None,
                     matrix_approximation_rank=1,
                     start_powerSGD_iter=start_powerSGD_iter,
                     use_error_feedback=use_error_feedback,
                     warm_start=warm_start,
                 )
Beispiel #3
0
def test_ddp_sgd_comm_hook(tmpdir):
    """Fit with the PowerSGD comm hook registered through ``DDPStrategy``.

    A ``BoringModel`` is fitted with a ``DDPStrategy`` configured with a
    PowerSGD state and the PowerSGD hook (no wrapper). After fitting, the comm
    hook name reported by the strategy model's DDP logging data must equal the
    qualified name of ``powerSGD.powerSGD_hook``, and the trainer state must
    report ``finished``.
    """
    model = BoringModel()

    power_sgd_state = powerSGD.PowerSGDState(process_group=None)
    ddp_strategy = DDPStrategy(
        ddp_comm_state=power_sgd_state,
        ddp_comm_hook=powerSGD.powerSGD_hook,
    )

    trainer_kwargs = dict(
        max_epochs=1,
        gpus=2,
        strategy=ddp_strategy,
        default_root_dir=tmpdir,
        sync_batchnorm=True,
        fast_dev_run=True,
    )
    trainer = Trainer(**trainer_kwargs)
    trainer.fit(model)

    # Hook name as recorded in the DDP logging data of the strategy's model.
    ddp_logging_data = trainer.strategy.model.get_ddp_logging_data()
    registered_hook_name = ddp_logging_data.comm_hook

    assert registered_hook_name == powerSGD.powerSGD_hook.__qualname__
    assert trainer.state.finished, f"Training failed with {trainer.state}"
Beispiel #4
0
def test_ddp_sgd_comm_hook(tmpdir):
    """Fit with the PowerSGD comm hook registered through ``TestDDPStrategy``.

    The strategy receives the qualified name of ``powerSGD.powerSGD_hook`` as
    ``expected_ddp_comm_hook_name`` together with a PowerSGD state and the hook
    itself. This test only asserts that the trainer state reports ``finished``
    after fitting.

    NOTE(review): presumably ``TestDDPStrategy`` verifies the registered hook
    name against ``expected_ddp_comm_hook_name``; its definition is not visible
    here -- confirm.
    """
    model = BoringModel()

    hook = powerSGD.powerSGD_hook
    ddp_strategy = TestDDPStrategy(
        expected_ddp_comm_hook_name=hook.__qualname__,
        ddp_comm_state=powerSGD.PowerSGDState(process_group=None),
        ddp_comm_hook=hook,
    )

    trainer_kwargs = dict(
        max_epochs=1,
        accelerator="gpu",
        devices=2,
        strategy=ddp_strategy,
        default_root_dir=tmpdir,
        sync_batchnorm=True,
        fast_dev_run=True,
        enable_progress_bar=False,
        enable_model_summary=False,
    )
    trainer = Trainer(**trainer_kwargs)
    trainer.fit(model)

    assert trainer.state.finished, f"Training failed with {trainer.state}"
Beispiel #5
0
def test_ddp_fp16_compress_wrap_sgd_comm_hook(tmpdir):
    """Fit with the PowerSGD comm hook wrapped by the FP16 compress wrapper.

    A ``BoringModel`` is fitted with a ``DDPStrategy`` configured with a
    PowerSGD state, the PowerSGD hook and the FP16 compress wrapper. After
    fitting, the comm hook name reported by the strategy model's DDP logging
    data must equal the qualified name of the wrapped hook, and the trainer
    state must report ``finished``.
    """
    model = BoringModel()

    power_sgd_state = powerSGD.PowerSGDState(process_group=None)
    ddp_strategy = DDPStrategy(
        ddp_comm_state=power_sgd_state,
        ddp_comm_hook=powerSGD.powerSGD_hook,
        ddp_comm_wrapper=default.fp16_compress_wrapper,
    )

    trainer_kwargs = dict(
        max_epochs=1,
        accelerator="gpu",
        devices=2,
        strategy=ddp_strategy,
        default_root_dir=tmpdir,
        sync_batchnorm=True,
        fast_dev_run=True,
    )
    trainer = Trainer(**trainer_kwargs)
    trainer.fit(model)

    # Hook name as recorded in the DDP logging data of the strategy's model.
    ddp_logging_data = trainer.strategy.model.get_ddp_logging_data()
    registered_hook_name = ddp_logging_data.comm_hook

    # Name the hook should have once the FP16 wrapper is applied to it.
    wrapped_hook = default.fp16_compress_wrapper(powerSGD.powerSGD_hook)

    assert registered_hook_name == wrapped_hook.__qualname__
    assert trainer.state.finished, f"Training failed with {trainer.state}"