Example #1
def test_ddp_post_local_sgd_comm_hook(tmpdir):
    """Test for DDP post-localSGD hook."""
    model = BoringModel()

    training_type_plugin = DDPStrategy(
        ddp_comm_state=post_localSGD.PostLocalSGDState(
            process_group=None,
            subgroup=None,
            start_localSGD_iter=8,
        ),
        ddp_comm_hook=post_localSGD.post_localSGD_hook,
        model_averaging_period=4,
    )
    trainer = Trainer(
        fast_dev_run=True,
        gpus=2,
        strategy=training_type_plugin,
        default_root_dir=tmpdir,
        sync_batchnorm=True,
    )
    trainer.fit(model)
    trainer_comm_hook = trainer.strategy.model.get_ddp_logging_data().comm_hook
    expected_comm_hook = post_localSGD.post_localSGD_hook.__qualname__
    assert trainer_comm_hook == expected_comm_hook
    assert trainer.state.finished, f"Training failed with {trainer.state}"
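For orientation, the ddp_comm_state / ddp_comm_hook / model_averaging_period arguments above map onto PyTorch's own post-localSGD APIs. A minimal sketch of the equivalent plain-PyTorch setup, assuming a process group has already been initialized (e.g. via torchrun); the Linear module is only a placeholder:

import torch
from torch.nn.parallel import DistributedDataParallel
from torch.distributed.algorithms.ddp_comm_hooks import post_localSGD_hook as post_localSGD
from torch.distributed.algorithms.model_averaging.averagers import PeriodicModelAverager

# Placeholder model; assumes torch.distributed.init_process_group() has already run.
ddp_model = DistributedDataParallel(torch.nn.Linear(32, 2))

# Same hook and state the DDPStrategy above configures.
state = post_localSGD.PostLocalSGDState(process_group=None, subgroup=None, start_localSGD_iter=8)
ddp_model.register_comm_hook(state, post_localSGD.post_localSGD_hook)

# model_averaging_period=4 corresponds to a periodic parameter averager.
averager = PeriodicModelAverager(period=4, warmup_steps=8)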
Example #2
def test_ddp_configure_ddp():
    """Tests with ddp strategy."""
    model = BoringModel()
    ddp_strategy = DDPStrategy()
    trainer = Trainer(
        max_epochs=1,
        strategy=ddp_strategy,
    )
    # test that the model is wrapped when fitting
    trainer.state.fn = TrainerFn.FITTING
    trainer.strategy.connect(model)
    trainer.lightning_module.trainer = trainer
    trainer.strategy.setup_environment()
    assert isinstance(trainer.model, LightningModule)
    trainer.strategy.setup(trainer)
    # in DDPStrategy configure_ddp(), the model is wrapped in DistributedDataParallel
    assert isinstance(trainer.model, DistributedDataParallel)

    trainer = Trainer(
        max_epochs=1,
        strategy=ddp_strategy,
    )
    # test that the model is not wrapped when the TrainerFn is not FITTING
    trainer.state.fn = TrainerFn.VALIDATING
    trainer.strategy.connect(model)
    trainer.lightning_module.trainer = trainer
    trainer.strategy.setup_environment()
    trainer.strategy.setup(trainer)
    # in DDPStrategy configure_ddp(), the model is still a LightningModule
    assert isinstance(trainer.model, LightningModule)
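The wrapping asserted above is ultimately just torch's DistributedDataParallel around the LightningModule. A self-contained sketch of that relationship using a throwaway single-process gloo group (illustration only, not how Lightning launches processes):

import os
import torch
import torch.distributed as dist
from torch import nn
from torch.nn.parallel import DistributedDataParallel

# Throwaway single-process process group, CPU/gloo, purely for illustration.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

module = nn.Linear(32, 2)
wrapped = DistributedDataParallel(module)   # what configure_ddp() does during fitting
assert wrapped.module is module             # the original module stays accessible

dist.destroy_process_group()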
Example #3
def test_tpu_invalid_raises():
    training_type_plugin = TPUSpawnStrategy(accelerator=TPUAccelerator(),
                                            precision_plugin=Mock())
    with pytest.raises(
            ValueError,
            match="TPUAccelerator` can only be used with a `TPUPrecisionPlugin"
    ):
        Trainer(strategy=training_type_plugin)

    training_type_plugin = DDPStrategy(accelerator=TPUAccelerator(),
                                       precision_plugin=TPUPrecisionPlugin())
    with pytest.raises(
            ValueError,
            match="TPUAccelerator` can only be used with a `SingleTPUStrategy`"
    ):
        Trainer(strategy=training_type_plugin)
Example #4
def test_ddp_fp16_compress_comm_hook(tmpdir):
    """Test for DDP FP16 compress hook."""
    model = BoringModel()
    training_type_plugin = DDPStrategy(ddp_comm_hook=default.fp16_compress_hook)
    trainer = Trainer(
        max_epochs=1,
        gpus=2,
        strategy=training_type_plugin,
        default_root_dir=tmpdir,
        sync_batchnorm=True,
        fast_dev_run=True,
    )
    trainer.fit(model)
    trainer_comm_hook = trainer.strategy.model.get_ddp_logging_data().comm_hook
    expected_comm_hook = default.fp16_compress_hook.__qualname__
    assert trainer_comm_hook == expected_comm_hook
    assert trainer.state.finished, f"Training failed with {trainer.state}"
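The hook under test here is PyTorch's stock FP16 compression hook; outside Lightning it is registered directly on a DistributedDataParallel model. A rough equivalent, assuming an initialized process group:

import torch
from torch.nn.parallel import DistributedDataParallel
from torch.distributed.algorithms.ddp_comm_hooks import default_hooks as default

# Placeholder model; assumes the process group is already initialized (e.g. via torchrun).
ddp_model = DistributedDataParallel(torch.nn.Linear(32, 2))

# Gradients are cast to FP16 before the all-reduce and decompressed afterwards.
ddp_model.register_comm_hook(state=None, hook=default.fp16_compress_hook)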
Example #5
def test_tpu_invalid_raises_set_precision_with_strategy():
    accelerator = TPUAccelerator()
    training_type_plugin = TPUSpawnStrategy(accelerator=accelerator,
                                            precision_plugin=object())
    with pytest.raises(
            ValueError,
            match=
            "`TPUAccelerator` can only be used with a `TPUPrecisionPlugin`"):
        Trainer(strategy=training_type_plugin)

    accelerator = TPUAccelerator()
    training_type_plugin = DDPStrategy(accelerator=accelerator,
                                       precision_plugin=TPUPrecisionPlugin())
    with pytest.raises(
            ValueError,
            match=
            "The `TPUAccelerator` can only be used with a `SingleTPUStrategy` or `TPUSpawnStrategy"
    ):
        Trainer(strategy=training_type_plugin)
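Conversely, the pairing these error messages point to (TPUAccelerator together with TPUPrecisionPlugin and a TPU strategy) is the one expected to construct without raising. A sketch inferred from the messages above, using the same names as the snippets in this example and not verified on TPU hardware:

# Hypothetical valid pairing implied by the error messages above.
training_type_plugin = TPUSpawnStrategy(accelerator=TPUAccelerator(),
                                        precision_plugin=TPUPrecisionPlugin())
trainer = Trainer(strategy=training_type_plugin)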
Example #6
def test_ddp_fp16_compress_wrap_sgd_comm_hook(tmpdir):
    """Test for DDP FP16 compress wrapper for SGD hook."""
    model = BoringModel()
    training_type_plugin = DDPStrategy(
        ddp_comm_state=powerSGD.PowerSGDState(process_group=None),
        ddp_comm_hook=powerSGD.powerSGD_hook,
        ddp_comm_wrapper=default.fp16_compress_wrapper,
    )
    trainer = Trainer(
        max_epochs=1,
        gpus=2,
        strategy=training_type_plugin,
        default_root_dir=tmpdir,
        sync_batchnorm=True,
        fast_dev_run=True,
    )
    trainer.fit(model)
    trainer_comm_hook = trainer.strategy.model.get_ddp_logging_data().comm_hook
    expected_comm_hook = default.fp16_compress_wrapper(powerSGD.powerSGD_hook).__qualname__
    assert trainer_comm_hook == expected_comm_hook
    assert trainer.state.finished, f"Training failed with {trainer.state}"
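As with the FP16 hook above, the PowerSGD hook and the FP16 wrapper come straight from torch.distributed's comm-hook modules and can be registered on a bare DistributedDataParallel model. A minimal sketch, assuming an initialized process group:

import torch
from torch.nn.parallel import DistributedDataParallel
from torch.distributed.algorithms.ddp_comm_hooks import default_hooks as default
from torch.distributed.algorithms.ddp_comm_hooks import powerSGD_hook as powerSGD

# Placeholder model; assumes the process group is already initialized.
ddp_model = DistributedDataParallel(torch.nn.Linear(32, 2))

# Low-rank PowerSGD gradient compression, with communication wrapped in FP16.
state = powerSGD.PowerSGDState(process_group=None)
ddp_model.register_comm_hook(state, default.fp16_compress_wrapper(powerSGD.powerSGD_hook))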
Example #7
@RunIf(min_gpus=2)
@mock.patch.dict(
    os.environ,
    {
        "CUDA_VISIBLE_DEVICES": "0,1",
        "SLURM_NTASKS": "2",
        "SLURM_JOB_NAME": "SOME_NAME",
        "SLURM_NODEID": "0",
        "SLURM_PROCID": "1",
        "SLURM_LOCALID": "1",
    },
)
@mock.patch("pytorch_lightning.plugins.DDPStrategy.setup_distributed",
            autospec=True)
@pytest.mark.parametrize("strategy", ["ddp", DDPStrategy()])
def test_strategy_choice_ddp_slurm(setup_distributed_mock, strategy):
    trainer = Trainer(fast_dev_run=True, strategy=strategy, gpus=2)
    assert trainer._accelerator_connector._is_slurm_managing_tasks()
    assert isinstance(trainer.accelerator, GPUAccelerator)
    assert isinstance(trainer.strategy, DDPStrategy)
    assert isinstance(trainer.strategy.cluster_environment, SLURMEnvironment)
    assert trainer.strategy.cluster_environment.local_rank() == 1
    assert trainer.strategy.local_rank == 1


@mock.patch.dict(
    os.environ,
    {
        "CUDA_VISIBLE_DEVICES": "0,1",
        "SLURM_NTASKS": "2",