Example #1
def test_ddp_sampler_error():
    """
    Make sure DDP + AMP work
    :return:
    """
    if not can_run_gpu_test():
        return

    os.environ['MASTER_PORT'] = str(np.random.randint(12000, 19000, 1)[0])

    hparams = get_hparams()
    model = LightningTestModel(hparams, force_remove_distributed_sampler=True)

    exp = get_exp(True)
    exp.save()

    trainer = Trainer(experiment=exp,
                      show_progress_bar=False,
                      max_nb_epochs=1,
                      gpus=[0, 1],
                      distributed_backend='ddp',
                      use_amp=True)

    with pytest.warns(UserWarning):
        trainer.get_dataloaders(model)

    clear_save_dir()
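
This snippet is shown without its module-level imports. A minimal, hedged set that would make it importable is sketched below; only os, numpy, pytest, and the Trainer class are certain from the code itself, while the test-helper names (LightningTestModel, can_run_gpu_test, get_hparams, get_exp, clear_save_dir) live in the project's own test suite and their module paths here are assumptions.

# Hypothetical module-level imports for Example #1.
import os

import numpy as np
import pytest
from pytorch_lightning import Trainer

# The helpers below come from the project's test utilities; the paths are
# illustrative, not the confirmed locations.
# from tests.models import LightningTestModel
# from tests.utils import can_run_gpu_test, get_hparams, get_exp, clear_save_dir
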
Example #2
def test_ddp_sampler_error():
    """
    Make sure DDP + AMP work
    :return:
    """
    if not can_run_gpu_test():
        return

    reset_seed()
    set_random_master_port()

    hparams = get_hparams()
    model = LightningTestModel(hparams, force_remove_distributed_sampler=True)

    logger = get_test_tube_logger(True)

    trainer = Trainer(logger=logger,
                      show_progress_bar=False,
                      max_nb_epochs=1,
                      gpus=[0, 1],
                      distributed_backend='ddp',
                      use_amp=True)

    with pytest.warns(UserWarning):
        trainer.get_dataloaders(model)

    clear_save_dir()
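
Example #2 replaces the inline MASTER_PORT assignment from Example #1 with the reset_seed() and set_random_master_port() helpers, and logs through a test-tube logger instead of an experiment object. A hedged sketch of what set_random_master_port presumably does, modeled directly on the inline version in Examples #1 and #3 (the real implementation lives in the test utilities):

# Sketch of the helper used above, assuming it mirrors the inline
# os.environ['MASTER_PORT'] assignment from Examples #1 and #3.
import os

import numpy as np


def set_random_master_port():
    # Pick a random port so parallel DDP test runs don't collide on the default one.
    port = np.random.randint(12000, 19000, 1)[0]
    os.environ['MASTER_PORT'] = str(port)
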
Example #3
def test_ddp_sampler_error():
    """
    Make sure DDP + AMP work
    :return:
    """
    if not torch.cuda.is_available():
        warnings.warn('test_amp_gpu_ddp cannot run. Rerun on a GPU node to run this test')
        return
    if not torch.cuda.device_count() > 1:
        warnings.warn('test_amp_gpu_ddp cannot run. Rerun on a node with 2+ GPUs to run this test')
        return

    os.environ['MASTER_PORT'] = str(np.random.randint(12000, 19000, 1)[0])

    hparams = get_hparams()
    model = LightningTestModel(hparams, force_remove_distributed_sampler=True)

    exp = get_exp(True)
    exp.save()

    trainer = Trainer(
        experiment=exp,
        progress_bar=False,
        max_nb_epochs=1,
        gpus=[0, 1],
        distributed_backend='ddp',
        use_amp=True
    )

    with pytest.warns(UserWarning):
        trainer.get_dataloaders(model)

    clear_save_dir()
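
Example #3 keeps the GPU guards inline instead of calling can_run_gpu_test(), which suggests the helper used in the other examples simply wraps the same two checks. A hedged sketch, assuming that correspondence:

# Sketch of can_run_gpu_test(), assuming it wraps the inline guards from Example #3.
import warnings

import torch


def can_run_gpu_test():
    # DDP + AMP needs CUDA and at least two visible GPUs.
    if not torch.cuda.is_available():
        warnings.warn('test cannot run. Rerun on a GPU node to run this test')
        return False
    if not torch.cuda.device_count() > 1:
        warnings.warn('test cannot run. Rerun on a node with 2+ GPUs to run this test')
        return False
    return True
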
Example #4
def test_ddp_sampler_error(tmpdir):
    """Make sure DDP + AMP work."""
    if not tutils.can_run_gpu_test():
        return

    tutils.reset_seed()
    tutils.set_random_master_port()

    hparams = tutils.get_hparams()
    model = LightningTestModel(hparams, force_remove_distributed_sampler=True)

    logger = tutils.get_test_tube_logger(tmpdir, True)

    trainer = Trainer(logger=logger,
                      show_progress_bar=False,
                      max_epochs=1,
                      gpus=[0, 1],
                      distributed_backend='ddp',
                      precision=16)

    with pytest.warns(UserWarning):
        trainer.get_dataloaders(model)
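
Example #4 is the most recent variant: it takes pytest's built-in tmpdir fixture (so no manual clear_save_dir() is needed), reaches the helpers through a tutils module, and uses the newer Trainer arguments precision=16 and max_epochs in place of use_amp=True and max_nb_epochs. A hedged sketch of get_test_tube_logger, assuming it simply roots a TestTubeLogger in the supplied directory; the exact signature and import path varied across early releases:

# Sketch of get_test_tube_logger(); the helper name and defaults are assumptions.
from pytorch_lightning.loggers import TestTubeLogger


def get_test_tube_logger(save_dir, debug=True):
    # 'debug' keeps the logger from writing full experiment metadata during tests.
    return TestTubeLogger(save_dir=save_dir, name='lightning_logs', debug=debug)
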