Beispiel #1
0
def test_sync_batchnorm_ddp(tmpdir):
    """Fit ``SyncBNModule`` on two GPUs through a sync-batchnorm ``DDPSpawnStrategy``.

    A first ``SyncBNModule`` is run on the first three batches of a default
    ``MNISTDataModule`` and the batchnorm outputs it returns are collected and
    moved to CUDA.  Those tensors are passed as ``bn_targets`` to a second
    ``SyncBNModule(gpu_count=2)``, which is then fitted with ``max_steps=3``.
    How the targets are used is decided inside ``SyncBNModule`` (not visible
    here).  After fitting, ``model.bn_layer`` must be a ``_BatchNorm`` that is
    not a ``SyncBatchNorm``.
    """
    seed_everything(234)
    set_random_main_port()

    # reference pass: default datamodule, no distributed sampler
    reference_dm = MNISTDataModule()
    reference_dm.prepare_data()
    reference_dm.setup(stage=None)

    loader = reference_dm.train_dataloader()
    reference_model = SyncBNModule()

    reference_outputs = []

    # shuffle is false by default
    for step, (inputs, _) in enumerate(loader):
        _, bn_out = reference_model.forward(inputs, step)
        reference_outputs.append(bn_out)

        # only the first 3 steps are needed
        if step == 2:
            break

    bn_targets = [out.cuda() for out in reference_outputs]

    # fresh datamodule for the distributed run
    # batch-size = 16 because 2 GPUs in DDP
    ddp_dm = MNISTDataModule(batch_size=16, dist_sampler=True)
    ddp_dm.prepare_data()
    ddp_dm.setup(stage=None)

    model = SyncBNModule(gpu_count=2, bn_targets=bn_targets)

    gpu_devices = [torch.device("cuda", index) for index in range(2)]
    strategy = DDPSpawnStrategy(
        parallel_devices=gpu_devices,
        num_nodes=1,
        sync_batchnorm=True,
        cluster_environment=LightningEnvironment(),
        find_unused_parameters=True,
    )

    trainer = Trainer(
        default_root_dir=tmpdir,
        gpus=2,
        num_nodes=1,
        strategy=strategy,
        max_epochs=1,
        max_steps=3,
        sync_batchnorm=True,
        num_sanity_val_steps=0,
        replace_sampler_ddp=False,
    )
    trainer.fit(model, ddp_dm)

    # the strategy is responsible for tearing down the batchnorm wrappers
    bn_layer = model.bn_layer
    batchnorm = torch.nn.modules.batchnorm
    assert not isinstance(bn_layer, batchnorm.SyncBatchNorm)
    assert isinstance(bn_layer, batchnorm._BatchNorm)
Beispiel #2
0
def test_sync_batchnorm_ddp(tmpdir):
    """Fit ``SyncBNModule`` on two GPUs with ``ddp_spawn`` and sync batchnorm.

    Batchnorm outputs of a first ``SyncBNModule`` on the first three batches
    of a default ``MNISTDataModule`` are collected, moved to CUDA and passed
    as ``bn_targets`` to a second ``SyncBNModule(gpu_count=2)``.  That model
    is fitted with ``max_steps=3`` through a ``DDPSpawnPlugin``; the test
    passes when ``trainer.state.finished`` is truthy afterwards.  What
    ``SyncBNModule`` does with the targets is not visible in this test.
    """
    seed_everything(234)
    set_random_master_port()

    # reference pass: default datamodule, no distributed sampler
    reference_dm = MNISTDataModule()
    reference_dm.prepare_data()
    reference_dm.setup(stage=None)

    loader = reference_dm.train_dataloader()
    reference_model = SyncBNModule()

    collected = []

    # shuffle is false by default
    for step, (inputs, _) in enumerate(loader):
        _, bn_out = reference_model.forward(inputs, step)
        collected.append(bn_out)

        # only the first 3 steps are needed
        if step == 2:
            break

    bn_targets = [tensor.cuda() for tensor in collected]

    # fresh datamodule for the distributed run
    # batch-size = 16 because 2 GPUs in DDP
    ddp_dm = MNISTDataModule(batch_size=16, dist_sampler=True)
    ddp_dm.prepare_data()
    ddp_dm.setup(stage=None)

    model = SyncBNModule(gpu_count=2, bn_targets=bn_targets)

    spawn_plugin = DDPSpawnPlugin(
        parallel_devices=[torch.device("cuda", index) for index in range(2)],
        num_nodes=1,
        sync_batchnorm=True,
        cluster_environment=LightningEnvironment(),
        find_unused_parameters=True,
    )

    trainer_options = {
        "default_root_dir": tmpdir,
        "gpus": 2,
        "num_nodes": 1,
        "accelerator": "ddp_spawn",
        "max_epochs": 1,
        "max_steps": 3,
        "sync_batchnorm": True,
        "num_sanity_val_steps": 0,
        "replace_sampler_ddp": False,
        "plugins": [spawn_plugin],
    }
    trainer = Trainer(**trainer_options)

    trainer.fit(model, ddp_dm)
    assert trainer.state.finished, "Sync batchnorm failing with DDP"
def test_sync_batchnorm_ddp(tmpdir):
    """Fit ``SyncBNModule`` on two GPUs with ``ddp_spawn`` and sync batchnorm.

    Batchnorm outputs of a first ``SyncBNModule`` on the first three batches
    of a default ``MNISTDataModule`` are collected, moved to CUDA and passed
    as ``bn_targets`` to a second ``SyncBNModule(gpu_count=2)``.  That model
    is fitted with ``max_steps=3``; the test passes when the trainer state
    equals ``TrainerState.FINISHED`` afterwards.  What ``SyncBNModule`` does
    with the targets is not visible in this test.

    Args:
        tmpdir: pytest temporary-directory fixture, used as the Trainer's
            ``default_root_dir`` so the run does not fall back to the
            Trainer's default root directory.
    """
    seed_everything(234)
    set_random_master_port()

    # define datamodule and dataloader
    dm = MNISTDataModule()
    dm.prepare_data()
    dm.setup(stage=None)

    train_dataloader = dm.train_dataloader()
    model = SyncBNModule()

    bn_outputs = []

    # shuffle is false by default
    for batch_idx, batch in enumerate(train_dataloader):
        x, _ = batch

        _, out_bn = model.forward(x, batch_idx)
        bn_outputs.append(out_bn)

        # get 3 steps
        if batch_idx == 2:
            break

    bn_outputs = [x.cuda() for x in bn_outputs]

    # reset datamodule
    # batch-size = 16 because 2 GPUs in DDP
    dm = MNISTDataModule(batch_size=16, dist_sampler=True)
    dm.prepare_data()
    dm.setup(stage=None)

    model = SyncBNModule(gpu_count=2, bn_targets=bn_outputs)

    # default_root_dir was previously missing, leaving the ``tmpdir`` fixture
    # unused; pass it so Trainer output stays inside the per-test directory.
    trainer = Trainer(default_root_dir=tmpdir,
                      gpus=2,
                      num_nodes=1,
                      accelerator='ddp_spawn',
                      max_epochs=1,
                      max_steps=3,
                      sync_batchnorm=True,
                      num_sanity_val_steps=0,
                      replace_sampler_ddp=False,
                      plugins=[DDPPlugin(find_unused_parameters=True)])

    trainer.fit(model, dm)
    assert trainer.state == TrainerState.FINISHED, "Sync batchnorm failing with DDP"