Example #1
0
def test_error_raised_with_float_limited_eval_batches():
    """Test that a float ``limit_val_batches`` selecting less than one validation batch raises
    ``MisconfigurationException`` when the eval dataloader is reset."""
    import re

    model = BoringModel()
    dl_size = len(model.val_dataloader())
    # dl_size / (dl_size + 2) < 1, so this fraction of the loader is below a single batch.
    limit_val_batches = 1 / (dl_size + 2)
    trainer = Trainer(limit_val_batches=limit_val_batches)
    trainer._data_connector.attach_data(model)
    # ``match`` is applied with ``re.search``: escape the expected text so the ``.`` characters in the
    # interpolated float (and in "1.") are matched literally instead of as regex wildcards.
    expected_message = re.escape(
        f"{limit_val_batches} * {dl_size} < 1. Please increase the `limit_val_batches`"
    )
    with pytest.raises(MisconfigurationException, match=expected_message):
        trainer._reset_eval_dataloader(RunningStage.VALIDATING, model)
def test_overfit_batch_limits_eval(stage, mode, overfit_batches):
    """Check eval batch counts when ``overfit_batches`` is set on the ``Trainer``.

    The validating stage must report 0 batches; any other stage must report the full length of the
    datamodule's ``{mode}_dataloader`` and serve it through a ``SequentialSampler``.
    """
    classifier = ClassificationModel()
    datamodule = ClassifDataModule()
    reference_loader = getattr(datamodule, f"{mode}_dataloader")()

    trainer = Trainer(overfit_batches=overfit_batches)
    classifier.trainer = trainer
    trainer._data_connector.attach_datamodule(classifier, datamodule=datamodule)

    num_batches, loaders = trainer._reset_eval_dataloader(stage, model=classifier)

    is_validating = stage == RunningStage.VALIDATING
    if not is_validating:
        assert num_batches[0] == len(reference_loader)
        assert isinstance(loaders[0].sampler, SequentialSampler)
    else:
        assert num_batches[0] == 0
def test_eval_limit_batches(stage, mode, limit_batches):
    """Check that ``limit_{mode}_batches`` caps the number of eval batches for ``stage``.

    A float limit is interpreted as a fraction of the loader length (truncated with ``int``); any
    other value is expected verbatim. The underlying dataloader itself keeps its full length.
    """
    boring = BoringModel()
    reference_loader = getattr(boring, f"{mode}_dataloader")()

    trainer = Trainer(**{f"limit_{mode}_batches": limit_batches})
    boring.trainer = trainer
    trainer._data_connector.attach_dataloaders(boring)
    num_batches, loaders = trainer._reset_eval_dataloader(stage, model=boring)

    if isinstance(limit_batches, float):
        wanted = int(limit_batches * len(reference_loader))
    else:
        wanted = limit_batches

    assert num_batches[0] == wanted
    assert len(loaders[0]) == len(reference_loader)
Example #4
0
def _reset_eval_loaders(split, model, **trainer_kwargs):
    """Return ``_reset_eval_dataloader(split, model=model)`` from a fresh ``Trainer(**trainer_kwargs)``.

    ``model``'s dataloaders are attached to the new trainer before the eval loaders are reset.
    """
    trainer = Trainer(**trainer_kwargs)
    trainer._data_connector.attach_dataloaders(model)
    return trainer._reset_eval_dataloader(split, model=model)


def test_overfit_batch_limits(tmpdir):
    """Check how ``overfit_batches`` and ``limit_{val,test}_batches`` limit train/val/test batches.

    With ``overfit_batches`` set, both the train loader and the eval loaders must start with the same
    batch as an unshuffled pass over the training dataset, and eval loaders must use a
    ``SequentialSampler``.
    """
    # ------------------------------------------------------
    # Make sure shuffle is correct across loaders initially
    # ------------------------------------------------------
    model = EvalModelTemplate()

    # original train loader which should be replaced in all methods
    train_loader = model.train_dataloader()

    # the train loader shuffles, the val and test loaders do not
    assert isinstance(train_loader.sampler, RandomSampler)
    assert isinstance(model.val_dataloader().sampler, SequentialSampler)
    assert isinstance(model.test_dataloader().sampler, SequentialSampler)

    # ------------------------------------------------------
    # get the training loader and batch
    # ------------------------------------------------------
    # Reference: the first batch of the training dataset read without shuffling.
    xa, ya = next(iter(DataLoader(model.train_dataloader().dataset, shuffle=False)))
    train_loader = DataLoader(model.train_dataloader().dataset, shuffle=True)
    # ``len`` of a DataLoader is its number of batches.
    num_train_batches = int(0.11 * len(train_loader))

    # ------------------------------------------------------
    # set VAL and Test loaders
    # ------------------------------------------------------
    val_loader = DataLoader(model.val_dataloader().dataset, shuffle=False)
    test_loader = DataLoader(model.test_dataloader().dataset, shuffle=False)

    # set the model loaders
    model.train_dataloader = lambda: train_loader
    model.val_dataloader = lambda: val_loader
    model.test_dataloader = lambda: test_loader

    # ------------------------------------------------------
    # test train loader applies correct limits
    # ------------------------------------------------------
    trainer = Trainer(overfit_batches=4)
    model.trainer = trainer
    trainer._data_connector.attach_dataloaders(model=model)
    trainer.reset_train_dataloader(model)
    assert trainer.num_training_batches == 4

    # make sure the loaders are the same
    xb, yb = next(iter(trainer.train_dataloader))
    assert torch.eq(xa, xb).all()
    assert torch.eq(ya, yb).all()

    trainer = Trainer(overfit_batches=0.11)
    model.trainer = trainer
    trainer._data_connector.attach_dataloaders(model=model)
    trainer.reset_train_dataloader(model)
    # The dataloader should have been overwritten with a Sequential sampler.
    assert trainer.train_dataloader is not train_loader
    assert trainer.num_training_batches == num_train_batches

    # make sure the loaders are the same
    xb, yb = next(iter(trainer.train_dataloader))
    assert torch.eq(xa, xb).all()
    assert torch.eq(ya, yb).all()

    # ------------------------------------------------------
    # run tests for both val and test
    # ------------------------------------------------------
    for split in (RunningStage.VALIDATING, RunningStage.TESTING):
        # overfit_batches as a fraction: eval loaders use the same fraction of the train batches
        loader_num_batches, dataloaders = _reset_eval_loaders(split, model, overfit_batches=0.11)
        assert loader_num_batches[0] == num_train_batches

        # make sure we turned off shuffle for the user
        assert isinstance(dataloaders[0].sampler, SequentialSampler)

        # make sure the loaders are the same
        xb, yb = next(iter(dataloaders[0]))
        assert torch.eq(xa, xb).all()
        assert torch.eq(ya, yb).all()

        # overfit_batches as an int: taken verbatim
        for num_batches in (1, 5):
            loader_num_batches, _ = _reset_eval_loaders(split, model, overfit_batches=num_batches)
            assert loader_num_batches[0] == num_batches

        # limit_xxx_batches as a fraction AND as an int
        if split == RunningStage.VALIDATING:
            limit_arg, eval_loader = "limit_val_batches", val_loader
        else:
            limit_arg, eval_loader = "limit_test_batches", test_loader

        loader_num_batches, _ = _reset_eval_loaders(split, model, **{limit_arg: 0.1})
        assert loader_num_batches[0] == int(0.1 * len(eval_loader))

        loader_num_batches, _ = _reset_eval_loaders(split, model, **{limit_arg: 10})
        assert loader_num_batches[0] == 10
Example #5
0
def test_non_sequential_sampler_warning_is_raised_for_eval_dataloader(val_dl):
    """Resetting the validation dataloader built from ``val_dl`` must emit a ``PossibleUserWarning``.

    ``val_dl`` is presumably parametrized with loaders using a non-sequential sampler (per the test
    name) — confirm against the parametrization.
    """
    trainer, model = Trainer(), BoringModel()
    warning_pattern = "recommended .* turn this off for val/test/predict"

    trainer._data_connector.attach_data(model, val_dataloaders=val_dl)
    with pytest.warns(PossibleUserWarning, match=warning_pattern):
        trainer._reset_eval_dataloader(RunningStage.VALIDATING, model)