import pytest
import torch
from pytorch_lightning import Trainer, seed_everything

# The helpers below are assumed to come from the PyTorch Lightning test suite;
# exact module paths may differ between versions.
import tests.base.develop_utils as tutils
from tests.base import EvalModelTemplate
from tests.base.datamodules import MNISTDataModule
from tests.base.develop_utils import set_random_master_port
from tests.models.test_sync_batchnorm import SyncBNModule


@pytest.mark.parametrize('use_hparams', [True, False])
def test_auto_scale_batch_size_set_model_attribute(tmpdir, use_hparams):
    """Test that the new batch size gets written to the correct hyperparameter attribute."""
    tutils.reset_seed()

    hparams = EvalModelTemplate.get_default_hparams()
    before_batch_size = hparams.get('batch_size')

    class HparamsEvalModelTemplate(EvalModelTemplate):

        def dataloader(self, *args, **kwargs):
            # artificially set batch_size so we can get a dataloader
            # remove it immediately after, because we want only self.hparams.batch_size
            setattr(self, "batch_size", before_batch_size)
            dataloader = super().dataloader(*args, **kwargs)
            del self.batch_size
            return dataloader

    datamodule_model = MNISTDataModule(data_dir=tmpdir, batch_size=111)  # this datamodule should get ignored!
    datamodule_fit = MNISTDataModule(data_dir=tmpdir, batch_size=before_batch_size)

    model_class = HparamsEvalModelTemplate if use_hparams else EvalModelTemplate
    model = model_class(**hparams)
    model.datamodule = datamodule_model  # unused when another datamodule gets passed to .tune() / .fit()

    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, auto_scale_batch_size=True)
    trainer.tune(model, datamodule_fit)
    after_batch_size = model.hparams.batch_size if use_hparams else model.batch_size
    assert trainer.datamodule == datamodule_fit
    assert before_batch_size != after_batch_size
    assert after_batch_size <= len(trainer.train_dataloader.dataset)
    assert datamodule_fit.batch_size == after_batch_size
    # should be left unchanged, since it was not passed to .tune()
    assert datamodule_model.batch_size == 111
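

# --- Usage sketch (not part of the test above): how `auto_scale_batch_size`
# reads and writes the `batch_size` hyperparameter during `trainer.tune()`.
# `TuneDemoModel` and the random dataset are hypothetical stand-ins, assuming
# the same pytorch_lightning API version as the tests in this file.
def _demo_auto_scale_batch_size(tmpdir):
    import torch
    from torch import nn
    from torch.utils.data import DataLoader, TensorDataset
    from pytorch_lightning import LightningModule, Trainer

    class TuneDemoModel(LightningModule):
        def __init__(self, batch_size=2):
            super().__init__()
            self.save_hyperparameters()              # exposes self.hparams.batch_size
            self.layer = nn.Linear(32, 2)

        def training_step(self, batch, batch_idx):
            x, y = batch
            return nn.functional.cross_entropy(self.layer(x), y)

        def configure_optimizers(self):
            return torch.optim.SGD(self.parameters(), lr=0.1)

        def train_dataloader(self):
            # the batch size finder rebuilds this dataloader for every candidate size
            dataset = TensorDataset(torch.randn(256, 32), torch.randint(0, 2, (256,)))
            return DataLoader(dataset, batch_size=self.hparams.batch_size)

    model = TuneDemoModel()
    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, auto_scale_batch_size=True)
    trainer.tune(model)                              # writes the found size back to hparams
    assert model.hparams.batch_size >= 2             # scaled up from the initial value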


@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
def test_sync_batchnorm_ddp(tmpdir):
    seed_everything(234)
    set_random_master_port()

    # define datamodule and dataloader
    dm = MNISTDataModule()
    dm.prepare_data()
    dm.setup(stage=None)

    train_dataloader = dm.train_dataloader()
    model = SyncBNModule()

    bn_outputs = []

    # shuffle is False by default, so the batch order is deterministic
    for batch_idx, batch in enumerate(train_dataloader):
        x, _ = batch

        _, out_bn = model.forward(x, batch_idx)
        bn_outputs.append(out_bn)

        # stop after collecting reference outputs for 3 batches
        if batch_idx == 2:
            break

    # move the reference outputs to GPU so they can be compared against the
    # sync-BN outputs inside the DDP processes
    bn_outputs = [x.cuda() for x in bn_outputs]

    # reset the datamodule with a distributed sampler;
    # per-GPU batch size of 16, so the 2 DDP processes together see the same
    # effective batch as the single-process run above
    dm = MNISTDataModule(batch_size=16, dist_sampler=True)
    dm.prepare_data()
    dm.setup(stage=None)

    model = SyncBNModule(gpu_count=2, bn_targets=bn_outputs)

    trainer = Trainer(
        gpus=2,
        num_nodes=1,
        distributed_backend='ddp_spawn',
        max_epochs=1,
        max_steps=3,
        sync_batchnorm=True,
        num_sanity_val_steps=0,
        replace_sampler_ddp=False,
    )

    result = trainer.fit(model, dm)
    assert result == 1, "Sync batchnorm failing with DDP"
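

# --- Illustration (not part of the test above): `sync_batchnorm=True` makes
# Lightning convert BatchNorm layers to `torch.nn.SyncBatchNorm` before wrapping
# the model in DDP, so running statistics are reduced across all GPUs.
# `_SmallBNNet` is a hypothetical module used only for this demonstration;
# the manual conversion below mirrors what the Trainer flag does for you.
def _demo_sync_batchnorm_conversion():
    from torch import nn

    class _SmallBNNet(nn.Module):
        def __init__(self):
            super().__init__()
            self.block = nn.Sequential(nn.Conv2d(1, 8, 3), nn.BatchNorm2d(8), nn.ReLU())

        def forward(self, x):
            return self.block(x)

    net = _SmallBNNet()
    # manual equivalent of Trainer(..., sync_batchnorm=True):
    net = nn.SyncBatchNorm.convert_sync_batchnorm(net)
    assert isinstance(net.block[1], nn.SyncBatchNorm)  # BatchNorm2d was replaced
    return net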