Example #1
0
def test_overfit_batch_limits_train(overfit_batches):
    """Check that ``overfit_batches`` limits the training loader's batch count.

    A shuffled loader over the datamodule's training dataset is installed on
    the model. After the trainer resets its train dataloader, two things must
    hold:

    * the number of training batches equals ``overfit_batches`` (an int), or
      ``int(overfit_batches * len(loader))`` (a float fraction);
    * the trainer's first batch equals the first batch of an unshuffled
      loader over the same dataset.
    """
    model = ClassificationModel()
    dm = ClassifDataModule()

    # The datamodule's own train loader is expected to shuffle.
    stock_loader = dm.train_dataloader()
    assert isinstance(stock_loader.sampler, RandomSampler)

    # Reference: the first batch in dataset order (no shuffling).
    reference_loader = DataLoader(dm.train_dataloader().dataset, shuffle=False)
    ref_x, ref_y = next(iter(reference_loader))

    # Shuffled loader handed to the trainer through the model hook.
    shuffled_loader = DataLoader(dm.train_dataloader().dataset, shuffle=True)
    total_batches = len(shuffled_loader)  # len(DataLoader) counts batches
    model.train_dataloader = lambda: shuffled_loader

    trainer = Trainer(overfit_batches=overfit_batches)
    model.trainer = trainer
    trainer._data_connector.attach_dataloaders(model=model)
    trainer.reset_train_dataloader(model)

    # A float is a fraction of the full loader; anything else is an absolute count.
    if isinstance(overfit_batches, float):
        expected_batches = int(overfit_batches * total_batches)
    else:
        expected_batches = overfit_batches
    assert trainer.num_training_batches == expected_batches

    # The trainer's loader must start with the same batch as the unshuffled reference.
    batch_x, batch_y = next(iter(trainer.train_dataloader))
    assert torch.eq(ref_x, batch_x).all()
    assert torch.eq(ref_y, batch_y).all()
def test_overfit_batch_limits_eval(stage, mode, overfit_batches):
    """Check how ``overfit_batches`` affects the evaluation loaders.

    ``mode`` names the datamodule loader method (``f"{mode}_dataloader"``)
    whose length serves as the reference for ``stage``.

    * Validating: the reset loader is expected to report zero batches.
    * Any other stage (e.g. testing): the reset loader must keep the
      reference loader's full length and use a ``SequentialSampler``.
    """
    model = ClassificationModel()
    datamodule = ClassifDataModule()
    make_eval_loader = getattr(datamodule, f"{mode}_dataloader")
    reference_loader = make_eval_loader()

    trainer = Trainer(overfit_batches=overfit_batches)
    model.trainer = trainer
    trainer._data_connector.attach_datamodule(model, datamodule=datamodule)

    num_batches, loaders = trainer._data_connector._reset_eval_dataloader(stage, model=model)

    # Compare with `==` (not `!=`), matching how the stage enum is compared elsewhere.
    if stage == RunningStage.VALIDATING:
        assert num_batches[0] == 0
        return

    assert num_batches[0] == len(reference_loader)
    assert isinstance(loaders[0].sampler, SequentialSampler)