Code example #1
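These test functions are excerpted from the PyTorch Lightning test suite (circa v0.7–v0.8) without their module-level imports. A sketch of the imports they collectively assume; the `tests.base` helper paths in particular are assumptions that vary across Lightning versions:

import os
import pickle

import cloudpickle
import pytest
import torch
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler

from pytorch_lightning import Callback, LightningModule, Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

# Test helpers; these module paths are assumptions and differ across versions.
from tests.base import EvalModelTemplate
import tests.base.develop_utils as tutils
import tests.base.develop_pipelines as tpipes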
def test_gradient_accumulation_scheduling(tmpdir, schedule, expected):
    """
    Test grad accumulation by the freq of optimizer updates
    """

    # test incorrect configs
    with pytest.raises(IndexError):
        assert Trainer(accumulate_grad_batches={-1: 3, 1: 4, 4: 6})
    with pytest.raises(IndexError):
        assert Trainer(accumulate_grad_batches={-2: 3})

    with pytest.raises(TypeError):
        assert Trainer(accumulate_grad_batches={})
    with pytest.raises(TypeError):
        assert Trainer(accumulate_grad_batches=[[2, 3], [4, 6]])
    with pytest.raises(TypeError):
        assert Trainer(accumulate_grad_batches={1: 2, 3.: 4})
    with pytest.raises(TypeError):
        assert Trainer(accumulate_grad_batches={1: 2.5, 3: 5})

    model = EvalModelTemplate()

    trainer = Trainer(
        accumulate_grad_batches=schedule,
        limit_train_batches=0.7,  # deliberately not divisible by accumulate_grad_batches
        limit_val_batches=0.8,
        max_epochs=4,
        default_root_dir=tmpdir,
    )

    # test optimizer call freq matches scheduler
    def _optimizer_step(epoch,
                        batch_idx,
                        optimizer,
                        optimizer_idx,
                        second_order_closure=None,
                        on_tpu=False,
                        using_native_amp=False,
                        using_lbfgs=False):
        # only test the first 12 batches in epoch
        if batch_idx < 12:
            if epoch == 0:
                # reset counter when starting epoch
                if batch_idx == expected[0] - 1:
                    model.prev_called_batch_idx = expected[0] - 1

                    # use this opportunity to test once
                    assert trainer.accumulate_grad_batches == expected[0]

                # separate check for last batch with accumulate 1 step
                if expected[0] == 1 and (batch_idx + 1) == trainer.num_training_batches:
                    assert batch_idx == model.prev_called_batch_idx
                elif (batch_idx + 1) == trainer.num_training_batches:
                    # prev_called_batch_idx - schedule + modulus remainder
                    assert batch_idx == (model.prev_called_batch_idx -
                                         expected[0] +
                                         (batch_idx + 1) % expected[0])
                else:
                    assert batch_idx == model.prev_called_batch_idx
                    model.prev_called_batch_idx += expected[0]

            elif 1 <= epoch <= 2:
                # reset counter when starting epoch
                if batch_idx == expected[1] - 1:
                    model.prev_called_batch_idx = expected[1] - 1

                    # use this opportunity to test once
                    assert trainer.accumulate_grad_batches == expected[1]

                if trainer.num_training_batches == batch_idx + 1:
                    # prev_called_batch_idx - schedule + modulus remainder
                    assert batch_idx == (model.prev_called_batch_idx -
                                         expected[1] +
                                         (batch_idx + 1) % expected[1])
                else:
                    assert batch_idx == model.prev_called_batch_idx
                    model.prev_called_batch_idx += expected[1]

            else:
                if batch_idx == expected[2] - 1:
                    model.prev_called_batch_idx = expected[2] - 1

                    # use this opportunity to test once
                    assert trainer.accumulate_grad_batches == expected[2]

                if (batch_idx + 1) == trainer.num_training_batches:
                    # prev_called_batch_idx - schedule + modulus remainder
                    assert batch_idx == (model.prev_called_batch_idx -
                                         expected[2] +
                                         (batch_idx + 1) % expected[2])
                else:
                    assert batch_idx == model.prev_called_batch_idx
                    model.prev_called_batch_idx += expected[2]

        optimizer.step()

        # clear gradients
        optimizer.zero_grad()

    # for the test
    model.optimizer_step = _optimizer_step
    model.prev_called_batch_idx = 0

    trainer.fit(model)
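The `schedule` and `expected` arguments come from a pytest parametrization that the excerpt omits. A hypothetical decorator consistent with the body, where `expected` holds the accumulation factor for epoch 0, epochs 1–2, and epochs 3+:

# Hypothetical parametrization (not part of the excerpt).
@pytest.mark.parametrize(
    ["schedule", "expected"],
    [
        ({1: 2, 3: 4}, [1, 2, 4]),  # dict schedule: accumulate 1, then 2, then 4
        (3, [3, 3, 3]),             # int schedule: constant factor of 3
    ],
)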
Code example #2
def test_optimizer_return_options():

    trainer = Trainer()
    model = EvalModelTemplate()

    # create optimizers and schedulers used throughout the test
    opt_a = torch.optim.Adam(model.parameters(), lr=0.002)
    opt_b = torch.optim.SGD(model.parameters(), lr=0.002)
    scheduler_a = torch.optim.lr_scheduler.StepLR(opt_a, 10)
    scheduler_b = torch.optim.lr_scheduler.StepLR(opt_b, 10)

    # single optimizer
    model.configure_optimizers = lambda: opt_a
    optim, lr_sched, freq = trainer.init_optimizers(model)
    assert len(optim) == 1 and len(lr_sched) == 0 and len(freq) == 0

    # opt tuple
    model.configure_optimizers = lambda: (opt_a, opt_b)
    optim, lr_sched, freq = trainer.init_optimizers(model)
    assert len(optim) == 2 and optim[0] == opt_a and optim[1] == opt_b
    assert len(lr_sched) == 0 and len(freq) == 0

    # opt list
    model.configure_optimizers = lambda: [opt_a, opt_b]
    optim, lr_sched, freq = trainer.init_optimizers(model)
    assert len(optim) == 2 and optim[0] == opt_a and optim[1] == opt_b
    assert len(lr_sched) == 0 and len(freq) == 0

    # opt tuple of 2 lists
    model.configure_optimizers = lambda: ([opt_a], [scheduler_a])
    optim, lr_sched, freq = trainer.init_optimizers(model)
    assert len(optim) == 1 and len(lr_sched) == 1 and len(freq) == 0
    assert optim[0] == opt_a
    assert lr_sched[0] == dict(scheduler=scheduler_a,
                               interval='epoch',
                               frequency=1,
                               reduce_on_plateau=False)

    # opt single dictionary
    model.configure_optimizers = lambda: {
        "optimizer": opt_a,
        "lr_scheduler": scheduler_a
    }
    optim, lr_sched, freq = trainer.init_optimizers(model)
    assert len(optim) == 1 and len(lr_sched) == 1 and len(freq) == 0
    assert optim[0] == opt_a
    assert lr_sched[0] == dict(scheduler=scheduler_a,
                               interval='epoch',
                               frequency=1,
                               reduce_on_plateau=False)

    # opt multiple dictionaries with frequencies
    model.configure_optimizers = lambda: (
        {
            "optimizer": opt_a,
            "lr_scheduler": scheduler_a,
            "frequency": 1
        },
        {
            "optimizer": opt_b,
            "lr_scheduler": scheduler_b,
            "frequency": 5
        },
    )
    optim, lr_sched, freq = trainer.init_optimizers(model)
    assert len(optim) == 2 and len(lr_sched) == 2 and len(freq) == 2
    assert optim[0] == opt_a
    assert lr_sched[0] == dict(scheduler=scheduler_a,
                               interval='epoch',
                               frequency=1,
                               reduce_on_plateau=False)
    assert freq == [1, 5]
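For context, a minimal sketch (assuming the same v0.8-era API) of a LightningModule returning the single-dictionary format that `init_optimizers` normalizes above; the class name is illustrative:

# Sketch: the dictionary return format exercised by the test above.
class DictOptimizerModule(EvalModelTemplate):
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=0.002)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)
        # init_optimizers fills in interval='epoch', frequency=1 and
        # reduce_on_plateau=False when they are not specified.
        return {"optimizer": optimizer, "lr_scheduler": scheduler}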
Code example #3
def test_overfit_batch_limits(tmpdir):
    # ------------------------------------------------------
    # Make sure shuffle is correct across loaders initially
    # ------------------------------------------------------
    model = EvalModelTemplate()

    # original train loader which should be replaced in all methods
    train_loader = model.train_dataloader()

    # train should be shuffled; val and test should not be
    assert isinstance(train_loader.sampler, RandomSampler)
    assert isinstance(model.val_dataloader().sampler, SequentialSampler)
    assert isinstance(model.test_dataloader().sampler, SequentialSampler)

    # ------------------------------------------------------
    # get the training loader and batch
    # ------------------------------------------------------
    # Create a reference train dataloader without shuffling.
    train_loader = DataLoader(model.train_dataloader().dataset, shuffle=False)
    (xa, ya) = next(iter(train_loader))
    train_loader = DataLoader(model.train_dataloader().dataset, shuffle=True)
    full_train_samples = len(train_loader)
    num_train_samples = int(0.11 * full_train_samples)

    # ------------------------------------------------------
    # set the val and test loaders
    # ------------------------------------------------------
    val_loader = DataLoader(model.val_dataloader().dataset, shuffle=False)
    test_loader = DataLoader(model.test_dataloader().dataset, shuffle=False)

    # set the model loaders
    model.train_dataloader = lambda: train_loader
    model.val_dataloader = lambda: val_loader
    model.test_dataloader = lambda: test_loader

    # ------------------------------------------------------
    # test train loader applies correct limits
    # ------------------------------------------------------
    trainer = Trainer(overfit_batches=4)
    trainer.reset_train_dataloader(model)
    assert trainer.num_training_batches == 4

    # make sure the loaders are the same
    (xb, yb) = next(iter(trainer.train_dataloader))
    assert torch.eq(xa, xb).all()
    assert torch.eq(ya, yb).all()

    trainer = Trainer(overfit_batches=0.11)
    trainer.reset_train_dataloader(model)
    # the train dataloader should have been replaced with one using a SequentialSampler
    assert trainer.train_dataloader is not train_loader
    assert trainer.num_training_batches == num_train_samples

    # make sure the loaders are the same
    (xb, yb) = next(iter(trainer.train_dataloader))
    assert torch.eq(xa, xb).all()
    assert torch.eq(ya, yb).all()

    # ------------------------------------------------------
    # run tests for both val and test
    # ------------------------------------------------------
    for split in ['val', 'test']:

        # ------------------------------------------------------
        # test overfit_batches as percent
        # ------------------------------------------------------
        loader_num_batches, dataloaders = Trainer(
            overfit_batches=0.11)._reset_eval_dataloader(model, split)
        assert loader_num_batches[0] == num_train_samples

        # make sure we turned off shuffle for the user
        assert isinstance(dataloaders[0].sampler, SequentialSampler)

        # make sure the loaders are the same
        (xb, yb) = next(iter(dataloaders[0]))
        assert torch.eq(xa, xb).all()
        assert torch.eq(ya, yb).all()

        # ------------------------------------------------------
        # test overfit_batches as int
        # ------------------------------------------------------
        loader_num_batches, dataloaders = Trainer(
            overfit_batches=1)._reset_eval_dataloader(model, split)
        assert loader_num_batches[0] == 1
        loader_num_batches, dataloaders = Trainer(
            overfit_batches=5)._reset_eval_dataloader(model, split)
        assert loader_num_batches[0] == 5

        # ------------------------------------------------------
        # test limit_xxx_batches as percent AND int
        # ------------------------------------------------------
        if split == 'val':
            loader_num_batches, dataloaders = Trainer(
                limit_val_batches=0.1)._reset_eval_dataloader(model, split)
            assert loader_num_batches[0] == int(0.1 * len(val_loader))

            loader_num_batches, dataloaders = Trainer(
                limit_val_batches=10)._reset_eval_dataloader(model, split)
            assert loader_num_batches[0] == 10
        else:
            loader_num_batches, dataloaders = Trainer(
                limit_test_batches=0.1)._reset_eval_dataloader(model, split)
            assert loader_num_batches[0] == int(0.1 * len(test_loader))

            loader_num_batches, dataloaders = Trainer(
                limit_test_batches=10)._reset_eval_dataloader(model, split)
            assert loader_num_batches[0] == 10
Code example #4
def test_model_freeze_unfreeze():

    model = EvalModelTemplate()

    model.freeze()
    model.unfreeze()
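The smoke test above only checks that the calls do not raise. A slightly stronger sketch, assuming `freeze()` disables gradients and puts the module in eval mode while `unfreeze()` reverts both (the documented LightningModule behavior):

model = EvalModelTemplate()

model.freeze()
assert not model.training  # freeze() switches to eval mode
assert all(not p.requires_grad for p in model.parameters())

model.unfreeze()
assert model.training  # unfreeze() switches back to train mode
assert all(p.requires_grad for p in model.parameters())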
Code example #5
def test_gradient_accumulation_scheduling(tmpdir):
    """
    Test grad accumulation by the freq of optimizer updates
    """

    # test incorrect configs; each bad value needs its own ``pytest.raises``
    # block, since a raise on the first line would silently skip the rest
    with pytest.raises(IndexError):
        assert Trainer(accumulate_grad_batches={0: 3, 1: 4, 4: 6})
    with pytest.raises(IndexError):
        assert Trainer(accumulate_grad_batches={-2: 3})

    with pytest.raises(TypeError):
        assert Trainer(accumulate_grad_batches={})
    with pytest.raises(TypeError):
        assert Trainer(accumulate_grad_batches=[[2, 3], [4, 6]])
    with pytest.raises(TypeError):
        assert Trainer(accumulate_grad_batches={1: 2, 3.: 4})
    with pytest.raises(TypeError):
        assert Trainer(accumulate_grad_batches={1: 2.5, 3: 5})

    # test optimizer call freq matches scheduler
    def _optimizer_step(self,
                        epoch,
                        batch_idx,
                        optimizer,
                        optimizer_idx,
                        second_order_closure=None):
        # only test the first 12 batches in epoch
        if batch_idx < 12:
            if epoch == 0:
                # reset counter when starting epoch
                if batch_idx == 0:
                    self.prev_called_batch_idx = 0

                    # use this opportunity to test once
                    assert self.trainer.accumulate_grad_batches == 1

                assert batch_idx == self.prev_called_batch_idx
                self.prev_called_batch_idx += 1

            elif 1 <= epoch <= 2:
                # reset counter when starting epoch
                if batch_idx == 1:
                    self.prev_called_batch_idx = 1

                    # use this opportunity to test once
                    assert self.trainer.accumulate_grad_batches == 2

                assert batch_idx == self.prev_called_batch_idx
                self.prev_called_batch_idx += 2

            else:
                if batch_idx == 3:
                    self.prev_called_batch_idx = 3

                    # use this opportunity to test once
                    assert self.trainer.accumulate_grad_batches == 4

                assert batch_idx == self.prev_called_batch_idx
                self.prev_called_batch_idx += 4  # accumulate_grad_batches is 4 from epoch 3 on

        optimizer.step()

        # clear gradients
        optimizer.zero_grad()

    model = EvalModelTemplate()
    schedule = {1: 2, 3: 4}

    trainer = Trainer(accumulate_grad_batches=schedule,
                      train_percent_check=0.1,
                      val_percent_check=0.1,
                      max_epochs=2,
                      default_root_dir=tmpdir)

    # for the test: bind the hook onto the model so ``self`` resolves correctly
    # when the training loop calls ``model.optimizer_step(...)``
    import types
    model.optimizer_step = types.MethodType(_optimizer_step, model)
    model.prev_called_batch_idx = 0

    trainer.fit(model)
Code example #6
def test_model_pickle(tmpdir):
    model = EvalModelTemplate()
    pickle.dumps(model)
    cloudpickle.dumps(model)
Code example #7
def test_trainer_callback_system(tmpdir):
    """Test the callback system."""

    hparams = tutils.get_default_hparams()
    model = EvalModelTemplate(hparams)

    def _check_args(trainer, pl_module):
        assert isinstance(trainer, Trainer)
        assert isinstance(pl_module, LightningModule)

    class TestCallback(Callback):
        def __init__(self):
            super().__init__()
            self.on_init_start_called = False
            self.on_init_end_called = False
            self.on_sanity_check_start_called = False
            self.on_sanity_check_end_called = False
            self.on_epoch_start_called = False
            self.on_epoch_end_called = False
            self.on_batch_start_called = False
            self.on_batch_end_called = False
            self.on_validation_batch_start_called = False
            self.on_validation_batch_end_called = False
            self.on_test_batch_start_called = False
            self.on_test_batch_end_called = False
            self.on_train_start_called = False
            self.on_train_end_called = False
            self.on_validation_start_called = False
            self.on_validation_end_called = False
            self.on_test_start_called = False
            self.on_test_end_called = False

        def on_init_start(self, trainer):
            assert isinstance(trainer, Trainer)
            self.on_init_start_called = True

        def on_init_end(self, trainer):
            assert isinstance(trainer, Trainer)
            self.on_init_end_called = True

        def on_sanity_check_start(self, trainer, pl_module):
            _check_args(trainer, pl_module)
            self.on_sanity_check_start_called = True

        def on_sanity_check_end(self, trainer, pl_module):
            _check_args(trainer, pl_module)
            self.on_sanity_check_end_called = True

        def on_epoch_start(self, trainer, pl_module):
            _check_args(trainer, pl_module)
            self.on_epoch_start_called = True

        def on_epoch_end(self, trainer, pl_module):
            _check_args(trainer, pl_module)
            self.on_epoch_end_called = True

        def on_batch_start(self, trainer, pl_module):
            _check_args(trainer, pl_module)
            self.on_batch_start_called = True

        def on_batch_end(self, trainer, pl_module):
            _check_args(trainer, pl_module)
            self.on_batch_end_called = True

        def on_validation_batch_start(self, trainer, pl_module):
            _check_args(trainer, pl_module)
            self.on_validation_batch_start_called = True

        def on_validation_batch_end(self, trainer, pl_module):
            _check_args(trainer, pl_module)
            self.on_validation_batch_end_called = True

        def on_test_batch_start(self, trainer, pl_module):
            _check_args(trainer, pl_module)
            self.on_test_batch_start_called = True

        def on_test_batch_end(self, trainer, pl_module):
            _check_args(trainer, pl_module)
            self.on_test_batch_end_called = True

        def on_train_start(self, trainer, pl_module):
            _check_args(trainer, pl_module)
            self.on_train_start_called = True

        def on_train_end(self, trainer, pl_module):
            _check_args(trainer, pl_module)
            self.on_train_end_called = True

        def on_validation_start(self, trainer, pl_module):
            _check_args(trainer, pl_module)
            self.on_validation_start_called = True

        def on_validation_end(self, trainer, pl_module):
            _check_args(trainer, pl_module)
            self.on_validation_end_called = True

        def on_test_start(self, trainer, pl_module):
            _check_args(trainer, pl_module)
            self.on_test_start_called = True

        def on_test_end(self, trainer, pl_module):
            _check_args(trainer, pl_module)
            self.on_test_end_called = True

    test_callback = TestCallback()

    trainer_options = dict(
        callbacks=[test_callback],
        max_epochs=1,
        val_percent_check=0.1,
        train_percent_check=0.2,
        progress_bar_refresh_rate=0,
    )

    assert not test_callback.on_init_start_called
    assert not test_callback.on_init_end_called
    assert not test_callback.on_sanity_check_start_called
    assert not test_callback.on_sanity_check_end_called
    assert not test_callback.on_epoch_start_called
    assert not test_callback.on_epoch_end_called
    assert not test_callback.on_batch_start_called
    assert not test_callback.on_batch_end_called
    assert not test_callback.on_validation_batch_start_called
    assert not test_callback.on_validation_batch_end_called
    assert not test_callback.on_test_batch_start_called
    assert not test_callback.on_test_batch_end_called
    assert not test_callback.on_train_start_called
    assert not test_callback.on_train_end_called
    assert not test_callback.on_validation_start_called
    assert not test_callback.on_validation_end_called
    assert not test_callback.on_test_start_called
    assert not test_callback.on_test_end_called

    # fit model
    trainer = Trainer(**trainer_options)

    assert trainer.callbacks[0] == test_callback
    assert test_callback.on_init_start_called
    assert test_callback.on_init_end_called
    assert not test_callback.on_sanity_check_start_called
    assert not test_callback.on_sanity_check_end_called
    assert not test_callback.on_epoch_start_called
    assert not test_callback.on_epoch_end_called
    assert not test_callback.on_batch_start_called
    assert not test_callback.on_batch_end_called
    assert not test_callback.on_validation_batch_start_called
    assert not test_callback.on_validation_batch_end_called
    assert not test_callback.on_test_batch_start_called
    assert not test_callback.on_test_batch_end_called
    assert not test_callback.on_train_start_called
    assert not test_callback.on_train_end_called
    assert not test_callback.on_validation_start_called
    assert not test_callback.on_validation_end_called
    assert not test_callback.on_test_start_called
    assert not test_callback.on_test_end_called

    trainer.fit(model)

    assert test_callback.on_init_start_called
    assert test_callback.on_init_end_called
    assert test_callback.on_sanity_check_start_called
    assert test_callback.on_sanity_check_end_called
    assert test_callback.on_epoch_start_called
    assert test_callback.on_epoch_end_called
    assert test_callback.on_batch_start_called
    assert test_callback.on_batch_end_called
    assert test_callback.on_validation_batch_start_called
    assert test_callback.on_validation_batch_end_called
    assert test_callback.on_train_start_called
    assert test_callback.on_train_end_called
    assert test_callback.on_validation_start_called
    assert test_callback.on_validation_end_called
    assert not test_callback.on_test_batch_start_called
    assert not test_callback.on_test_batch_end_called
    assert not test_callback.on_test_start_called
    assert not test_callback.on_test_end_called

    test_callback = TestCallback()
    trainer_options.update(callbacks=[test_callback])
    trainer = Trainer(**trainer_options)
    trainer.test(model)

    assert test_callback.on_test_batch_start_called
    assert test_callback.on_test_batch_end_called
    assert test_callback.on_test_start_called
    assert test_callback.on_test_end_called
    assert not test_callback.on_validation_start_called
    assert not test_callback.on_validation_end_called
    assert not test_callback.on_validation_batch_end_called
    assert not test_callback.on_validation_batch_start_called
Code example #8
def test_dp_resume(tmpdir):
    """Make sure DP continues training correctly."""
    hparams = EvalModelTemplate.get_default_hparams()
    model = EvalModelTemplate(**hparams)

    trainer_options = dict(
        max_epochs=1,
        gpus=2,
        distributed_backend='dp',
        default_root_dir=tmpdir,
    )

    # get logger
    logger = tutils.get_default_logger(tmpdir)

    # checkpoint callback that writes weights next to the logger's files
    checkpoint = tutils.init_checkpoint_callback(logger)

    # add these to the trainer options
    trainer_options['logger'] = logger
    trainer_options['checkpoint_callback'] = checkpoint

    # fit model
    trainer = Trainer(**trainer_options)
    trainer.is_slurm_managing_tasks = True
    result = trainer.fit(model)

    # track the epoch before saving; increment because the current epoch just finished
    # and we don't want to rerun it
    real_global_epoch = trainer.current_epoch + 1

    # correct result and ok accuracy
    assert result == 1, 'dp model failed to complete'

    # ---------------------------
    # HPC LOAD/SAVE
    # ---------------------------
    # save
    trainer.hpc_save(tmpdir, logger)

    # init new trainer
    new_logger = tutils.get_default_logger(tmpdir, version=logger.version)
    trainer_options['logger'] = new_logger
    trainer_options['checkpoint_callback'] = ModelCheckpoint(tmpdir)
    trainer_options['limit_train_batches'] = 0.5
    trainer_options['limit_val_batches'] = 0.2
    trainer_options['max_epochs'] = 1
    new_trainer = Trainer(**trainer_options)

    # set the epoch start hook so we can predict before the model does the full training
    def assert_good_acc():
        assert new_trainer.current_epoch == real_global_epoch and new_trainer.current_epoch > 0

        # if model and state loaded correctly, predictions will be good even though we
        # haven't trained with the new loaded model
        dp_model = new_trainer.model
        dp_model.eval()

        dataloader = trainer.train_dataloader
        tpipes.run_prediction(dataloader, dp_model, dp=True)

    # new model
    model = EvalModelTemplate(**hparams)
    model.on_train_start = assert_good_acc

    # fit new model which should load hpc weights
    new_trainer.fit(model)

    # test freeze on gpu
    model.freeze()
    model.unfreeze()
Code example #9
def test_trainer_callback_system(tmpdir):
    """Test the callback system."""

    hparams = EvalModelTemplate.get_default_hparams()
    model = EvalModelTemplate(**hparams)

    def _check_args(trainer, pl_module):
        assert isinstance(trainer, Trainer)
        assert isinstance(pl_module, LightningModule)

    class TestCallback(Callback):
        def __init__(self):
            super().__init__()
            self.setup_called = False
            self.teardown_called = False
            self.on_init_start_called = False
            self.on_init_end_called = False
            self.on_fit_start_called = False
            self.on_fit_end_called = False
            self.on_sanity_check_start_called = False
            self.on_sanity_check_end_called = False
            self.on_epoch_start_called = False
            self.on_epoch_end_called = False
            self.on_batch_start_called = False
            self.on_batch_end_called = False
            self.on_train_batch_start_called = False
            self.on_train_batch_end_called = False
            self.on_validation_batch_start_called = False
            self.on_validation_batch_end_called = False
            self.on_test_batch_start_called = False
            self.on_test_batch_end_called = False
            self.on_train_start_called = False
            self.on_train_end_called = False
            self.on_pretrain_routine_start_called = False
            self.on_pretrain_routine_end_called = False
            self.on_validation_start_called = False
            self.on_validation_end_called = False
            self.on_test_start_called = False
            self.on_test_end_called = False

        def setup(self, trainer, pl_module, stage: str):
            assert isinstance(trainer, Trainer)
            self.setup_called = True

        def teardown(self, trainer, pl_module, step: str):
            assert isinstance(trainer, Trainer)
            self.teardown_called = True

        def on_init_start(self, trainer):
            assert isinstance(trainer, Trainer)
            self.on_init_start_called = True

        def on_init_end(self, trainer):
            assert isinstance(trainer, Trainer)
            self.on_init_end_called = True

        def on_fit_start(self, trainer, pl_module):
            assert isinstance(trainer, Trainer)
            self.on_fit_start_called = True

        def on_fit_end(self, trainer, pl_module):
            assert isinstance(trainer, Trainer)
            self.on_fit_end_called = True

        def on_sanity_check_start(self, trainer, pl_module):
            _check_args(trainer, pl_module)
            self.on_sanity_check_start_called = True

        def on_sanity_check_end(self, trainer, pl_module):
            _check_args(trainer, pl_module)
            self.on_sanity_check_end_called = True

        def on_epoch_start(self, trainer, pl_module):
            _check_args(trainer, pl_module)
            self.on_epoch_start_called = True

        def on_epoch_end(self, trainer, pl_module):
            _check_args(trainer, pl_module)
            self.on_epoch_end_called = True

        def on_batch_start(self, trainer, pl_module):
            _check_args(trainer, pl_module)
            self.on_batch_start_called = True

        def on_batch_end(self, trainer, pl_module):
            _check_args(trainer, pl_module)
            self.on_batch_end_called = True

        def on_train_batch_start(self, trainer, pl_module, batch, batch_idx,
                                 dataloader_idx):
            _check_args(trainer, pl_module)
            self.on_train_batch_start_called = True

        def on_train_batch_end(self, trainer, pl_module, batch, batch_idx,
                               dataloader_idx):
            _check_args(trainer, pl_module)
            self.on_train_batch_end_called = True

        def on_validation_batch_start(self, trainer, pl_module, batch,
                                      batch_idx, dataloader_idx):
            _check_args(trainer, pl_module)
            self.on_validation_batch_start_called = True

        def on_validation_batch_end(self, trainer, pl_module, batch, batch_idx,
                                    dataloader_idx):
            _check_args(trainer, pl_module)
            self.on_validation_batch_end_called = True

        def on_test_batch_start(self, trainer, pl_module, batch, batch_idx,
                                dataloader_idx):
            _check_args(trainer, pl_module)
            self.on_test_batch_start_called = True

        def on_test_batch_end(self, trainer, pl_module, batch, batch_idx,
                              dataloader_idx):
            _check_args(trainer, pl_module)
            self.on_test_batch_end_called = True

        def on_train_start(self, trainer, pl_module):
            _check_args(trainer, pl_module)
            self.on_train_start_called = True

        def on_train_end(self, trainer, pl_module):
            _check_args(trainer, pl_module)
            self.on_train_end_called = True

        def on_pretrain_routine_start(self, trainer, pl_module):
            _check_args(trainer, pl_module)
            self.on_pretrain_routine_start_called = True

        def on_pretrain_routine_end(self, trainer, pl_module):
            _check_args(trainer, pl_module)
            self.on_pretrain_routine_end_called = True

        def on_validation_start(self, trainer, pl_module):
            _check_args(trainer, pl_module)
            self.on_validation_start_called = True

        def on_validation_end(self, trainer, pl_module):
            _check_args(trainer, pl_module)
            self.on_validation_end_called = True

        def on_test_start(self, trainer, pl_module):
            _check_args(trainer, pl_module)
            self.on_test_start_called = True

        def on_test_end(self, trainer, pl_module):
            _check_args(trainer, pl_module)
            self.on_test_end_called = True

    test_callback = TestCallback()

    trainer_options = dict(
        default_root_dir=tmpdir,
        callbacks=[test_callback],
        max_epochs=1,
        limit_val_batches=0.1,
        limit_train_batches=0.2,
        progress_bar_refresh_rate=0,
    )

    assert not test_callback.setup_called
    assert not test_callback.teardown_called
    assert not test_callback.on_init_start_called
    assert not test_callback.on_init_end_called
    assert not test_callback.on_fit_start_called
    assert not test_callback.on_fit_end_called
    assert not test_callback.on_sanity_check_start_called
    assert not test_callback.on_sanity_check_end_called
    assert not test_callback.on_epoch_start_called
    assert not test_callback.on_epoch_end_called
    assert not test_callback.on_batch_start_called
    assert not test_callback.on_batch_end_called
    assert not test_callback.on_train_batch_start_called
    assert not test_callback.on_train_batch_end_called
    assert not test_callback.on_validation_batch_start_called
    assert not test_callback.on_validation_batch_end_called
    assert not test_callback.on_test_batch_start_called
    assert not test_callback.on_test_batch_end_called
    assert not test_callback.on_train_start_called
    assert not test_callback.on_train_end_called
    assert not test_callback.on_pretrain_routine_start_called
    assert not test_callback.on_pretrain_routine_end_called
    assert not test_callback.on_validation_start_called
    assert not test_callback.on_validation_end_called
    assert not test_callback.on_test_start_called
    assert not test_callback.on_test_end_called

    # fit model
    trainer = Trainer(**trainer_options)

    assert trainer.callbacks[0] == test_callback
    assert test_callback.on_init_start_called
    assert test_callback.on_init_end_called
    assert not test_callback.setup_called
    assert not test_callback.teardown_called
    assert not test_callback.on_fit_start_called
    assert not test_callback.on_fit_end_called
    assert not test_callback.on_sanity_check_start_called
    assert not test_callback.on_sanity_check_end_called
    assert not test_callback.on_epoch_start_called
    assert not test_callback.on_epoch_end_called
    assert not test_callback.on_batch_start_called
    assert not test_callback.on_batch_end_called
    assert not test_callback.on_train_batch_start_called
    assert not test_callback.on_train_batch_end_called
    assert not test_callback.on_validation_batch_start_called
    assert not test_callback.on_validation_batch_end_called
    assert not test_callback.on_test_batch_start_called
    assert not test_callback.on_test_batch_end_called
    assert not test_callback.on_train_start_called
    assert not test_callback.on_train_end_called
    assert not test_callback.on_pretrain_routine_start_called
    assert not test_callback.on_pretrain_routine_end_called
    assert not test_callback.on_validation_start_called
    assert not test_callback.on_validation_end_called
    assert not test_callback.on_test_start_called
    assert not test_callback.on_test_end_called

    trainer.fit(model)

    assert test_callback.setup_called
    assert test_callback.teardown_called
    assert test_callback.on_init_start_called
    assert test_callback.on_init_end_called
    assert test_callback.on_fit_start_called
    assert test_callback.on_fit_end_called
    assert test_callback.on_sanity_check_start_called
    assert test_callback.on_sanity_check_end_called
    assert test_callback.on_epoch_start_called
    assert test_callback.on_epoch_end_called
    assert test_callback.on_batch_start_called
    assert test_callback.on_batch_end_called
    assert test_callback.on_train_batch_start_called
    assert test_callback.on_train_batch_end_called
    assert test_callback.on_validation_batch_start_called
    assert test_callback.on_validation_batch_end_called
    assert test_callback.on_train_start_called
    assert test_callback.on_train_end_called
    assert test_callback.on_pretrain_routine_start_called
    assert test_callback.on_pretrain_routine_end_called
    assert test_callback.on_validation_start_called
    assert test_callback.on_validation_end_called
    assert not test_callback.on_test_batch_start_called
    assert not test_callback.on_test_batch_end_called
    assert not test_callback.on_test_start_called
    assert not test_callback.on_test_end_called

    # reset setup teardown callback
    test_callback.teardown_called = False
    test_callback.setup_called = False

    test_callback = TestCallback()
    trainer_options.update(callbacks=[test_callback])
    trainer = Trainer(**trainer_options)
    trainer.test(model)

    assert test_callback.setup_called
    assert test_callback.teardown_called
    assert test_callback.on_test_batch_start_called
    assert test_callback.on_test_batch_end_called
    assert test_callback.on_test_start_called
    assert test_callback.on_test_end_called
    assert not test_callback.on_validation_start_called
    assert not test_callback.on_validation_end_called
    assert not test_callback.on_validation_batch_end_called
    assert not test_callback.on_validation_batch_start_called
Code example #10
def test_dataloaders_with_limit_num_batches(tmpdir, limit_train_batches,
                                            limit_val_batches,
                                            limit_test_batches):
    """Verify num_batches for train, val & test dataloaders passed with batch limit as number"""
    os.environ['PL_DEV_DEBUG'] = '1'

    model = EvalModelTemplate()
    model.val_dataloader = model.val_dataloader__multiple_mixed_length
    model.test_dataloader = model.test_dataloader__multiple_mixed_length
    model.validation_step = model.validation_step__multiple_dataloaders
    model.validation_epoch_end = model.validation_epoch_end__multiple_dataloaders
    model.test_step = model.test_step__multiple_dataloaders
    model.test_epoch_end = model.test_epoch_end__multiple_dataloaders

    # train, multiple val and multiple test dataloaders, each limited to a batch count
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        limit_train_batches=limit_train_batches,
        limit_val_batches=limit_val_batches,
        limit_test_batches=limit_test_batches,
    )
    trainer.fit(model)

    # -------------------------------------------
    # MAKE SURE THE TRAINER SET THE CORRECT VALUES
    # -------------------------------------------
    assert trainer.num_training_batches == limit_train_batches
    assert trainer.num_val_batches == [limit_val_batches] * len(trainer.val_dataloaders)
    trainer.test(ckpt_path=None)

    # when the limit exceeds the number of test batches, fall back to the loader lengths
    test_dataloader_lengths = [len(x) for x in model.test_dataloader()]
    if limit_test_batches > 1e10:
        assert trainer.num_test_batches == test_dataloader_lengths
    else:
        assert trainer.num_test_batches == [limit_test_batches] * len(trainer.test_dataloaders)

    # -------------------------------------------
    # make sure we actually saw the expected num of batches
    # -------------------------------------------
    num_val_dataloaders = len(model.val_dataloader())
    num_test_dataloaders = len(model.test_dataloader())
    if limit_train_batches > 0:

        # make sure val batches are as expected
        assert len(trainer.dev_debugger.num_seen_val_check_batches) == num_val_dataloaders
        for dataloader_idx, num_batches in trainer.dev_debugger.num_seen_val_check_batches.items():
            assert num_batches == limit_val_batches

        # make sure test batches are as expected
        assert len(trainer.dev_debugger.num_seen_test_check_batches) == num_test_dataloaders
        for dataloader_idx, num_batches in trainer.dev_debugger.num_seen_test_check_batches.items():
            if limit_test_batches > 1e10:
                assert num_batches == test_dataloader_lengths[dataloader_idx]
            else:
                assert num_batches == limit_test_batches
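As in code example #1, the three `limit_*_batches` arguments come from a pytest parametrization that the excerpt omits. Hypothetical cases consistent with the body, which special-cases limits above 1e10 and skips the seen-batch checks when `limit_train_batches` is 0:

# Hypothetical parametrization (not part of the excerpt).
@pytest.mark.parametrize(
    ["limit_train_batches", "limit_val_batches", "limit_test_batches"],
    [
        (0, 0, 0),          # zero batches everywhere
        (1, 2, 3),          # small fixed batch counts
        (1, 2, int(1e50)),  # a test limit past 1e10 falls back to len(loader)
    ],
)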
Code example #11
def test_model_pickle(tmpdir):
    import pickle

    model = EvalModelTemplate()
    pickle.dumps(model)
Code example #12
def test_cpu_slurm_save_load(tmpdir):
    """Verify model save/load/checkpoint on CPU."""
    hparams = tutils.get_default_hparams()
    model = EvalModelTemplate(hparams)

    # logger file to get meta
    logger = tutils.get_default_logger(tmpdir)
    version = logger.version

    # fit model
    trainer = Trainer(
        max_epochs=1,
        logger=logger,
        checkpoint_callback=ModelCheckpoint(tmpdir)
    )
    result = trainer.fit(model)
    real_global_step = trainer.global_step

    # training complete
    assert result == 1, 'cpu model failed to complete'

    # make a prediction with the trained model before saving
    dataloaders = model.test_dataloader()
    if not isinstance(dataloaders, list):
        dataloaders = [dataloaders]

    for dataloader in dataloaders:
        for batch in dataloader:
            break

    x, y = batch
    x = x.view(x.size(0), -1)

    model.eval()
    pred_before_saving = model(x)

    # test HPC saving
    # simulate snapshot on slurm
    saved_filepath = trainer.hpc_save(tmpdir, logger)
    assert os.path.exists(saved_filepath)

    # new logger file to get meta
    logger = tutils.get_default_logger(tmpdir, version=version)

    trainer = Trainer(
        max_epochs=1,
        logger=logger,
        checkpoint_callback=ModelCheckpoint(tmpdir),
    )
    model = EvalModelTemplate(hparams)

    # set the epoch start hook so we can predict before the model does the full training
    def assert_pred_same():
        assert trainer.global_step == real_global_step and trainer.global_step > 0

        # predict with loaded model to make sure answers are the same
        trainer.model.eval()
        new_pred = trainer.model(x)
        assert torch.all(torch.eq(pred_before_saving, new_pred)).item() == 1

    model.on_epoch_start = assert_pred_same

    # by calling fit again, we trigger training, loading weights from the cluster
    # and our hook to predict using current model before any more weight updates
    trainer.fit(model)