def test_error_raised_with_float_limited_eval_batches():
    """Test that an error is raised if there are not enough batches when a float value of
    ``limit_eval_batches`` is passed."""
    model = BoringModel()
    dl_size = len(model.val_dataloader())
    limit_val_batches = 1 / (dl_size + 2)
    trainer = Trainer(limit_val_batches=limit_val_batches)
    trainer._data_connector.attach_data(model)
    with pytest.raises(
        MisconfigurationException,
        match=rf"{limit_val_batches} \* {dl_size} < 1. Please increase the `limit_val_batches`",
    ):
        trainer._reset_eval_dataloader(RunningStage.VALIDATING, model)

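# Context for the test above: with a float ``limit_val_batches``, the number of
# batches is computed roughly as ``int(limit * len(dataloader))``, and Lightning
# raises when that product rounds down to zero because no validation batch would
# ever run. A minimal sketch of the check (an assumed simplification for
# illustration, not the actual Trainer implementation):
def _sketch_num_eval_batches(limit: float, dl_size: int) -> int:
    num_batches = int(limit * dl_size)
    if num_batches == 0 and limit > 0.0:
        raise MisconfigurationException(
            f"{limit} * {dl_size} < 1. Please increase the `limit_val_batches` argument."
        )
    return num_batches
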
# NOTE: the parameter sets below are reconstructed so the test is runnable; the
# exact values are assumed and may differ from the real suite.
@pytest.mark.parametrize("stage, mode", [(RunningStage.VALIDATING, "val"), (RunningStage.TESTING, "test")])
@pytest.mark.parametrize("overfit_batches", [0.11, 4])
def test_overfit_batch_limits_eval(stage, mode, overfit_batches):
    model = ClassificationModel()
    dm = ClassifDataModule()
    eval_loader = getattr(dm, f"{mode}_dataloader")()
    trainer = Trainer(overfit_batches=overfit_batches)
    model.trainer = trainer
    trainer._data_connector.attach_datamodule(model, datamodule=dm)
    loader_num_batches, dataloaders = trainer._reset_eval_dataloader(stage, model=model)
    if stage == RunningStage.VALIDATING:
        # no validation batches run while overfitting
        assert loader_num_batches[0] == 0
    else:
        assert loader_num_batches[0] == len(eval_loader)
        assert isinstance(dataloaders[0].sampler, SequentialSampler)

# NOTE: the parameter sets below are reconstructed so the test is runnable; the
# exact values are assumed and may differ from the real suite.
@pytest.mark.parametrize(
    "stage, mode",
    [(RunningStage.VALIDATING, "val"), (RunningStage.TESTING, "test"), (RunningStage.PREDICTING, "predict")],
)
@pytest.mark.parametrize("limit_batches", [0.1, 10])
def test_eval_limit_batches(stage, mode, limit_batches):
    limit_eval_batches = f"limit_{mode}_batches"
    dl_hook = f"{mode}_dataloader"
    model = BoringModel()
    eval_loader = getattr(model, dl_hook)()
    trainer = Trainer(**{limit_eval_batches: limit_batches})
    model.trainer = trainer
    trainer._data_connector.attach_dataloaders(model)
    loader_num_batches, dataloaders = trainer._reset_eval_dataloader(stage, model=model)
    expected_batches = int(limit_batches * len(eval_loader)) if isinstance(limit_batches, float) else limit_batches
    assert loader_num_batches[0] == expected_batches
    assert len(dataloaders[0]) == len(eval_loader)

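# Worked example of the arithmetic asserted above, assuming BoringModel's
# default eval loader wraps ``RandomDataset(32, 64)`` with ``batch_size=1``
# (64 batches): a float limit of 0.1 yields ``int(0.1 * 64) == 6`` batches,
# while an int limit of 10 is taken verbatim.
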
def test_overfit_batch_limits(tmpdir):
    # ------------------------------------------------------
    # Make sure shuffle is correct across loaders initially
    # ------------------------------------------------------
    model = EvalModelTemplate()

    # original train loader, which should be replaced in all methods
    train_loader = model.train_dataloader()

    # make sure the val and test loaders are not shuffled
    assert isinstance(train_loader.sampler, RandomSampler)
    assert isinstance(model.val_dataloader().sampler, SequentialSampler)
    assert isinstance(model.test_dataloader().sampler, SequentialSampler)

    # ------------------------------------------------------
    # get the training loader and batch
    # ------------------------------------------------------
    # create a reference train dataloader without shuffling
    train_loader = DataLoader(model.train_dataloader().dataset, shuffle=False)
    (xa, ya) = next(iter(train_loader))
    train_loader = DataLoader(model.train_dataloader().dataset, shuffle=True)
    full_train_samples = len(train_loader)
    num_train_samples = int(0.11 * full_train_samples)

    # ------------------------------------------------------
    # set the val and test loaders
    # ------------------------------------------------------
    val_loader = DataLoader(model.val_dataloader().dataset, shuffle=False)
    test_loader = DataLoader(model.test_dataloader().dataset, shuffle=False)

    # set the model loaders
    model.train_dataloader = lambda: train_loader
    model.val_dataloader = lambda: val_loader
    model.test_dataloader = lambda: test_loader

    # ------------------------------------------------------
    # test that the train loader applies the correct limits
    # ------------------------------------------------------
    trainer = Trainer(overfit_batches=4)
    model.trainer = trainer
    trainer._data_connector.attach_dataloaders(model=model)
    trainer.reset_train_dataloader(model)
    assert trainer.num_training_batches == 4

    # make sure the loaders are the same
    (xb, yb) = next(iter(trainer.train_dataloader))
    assert torch.eq(xa, xb).all()
    assert torch.eq(ya, yb).all()

    trainer = Trainer(overfit_batches=0.11)
    model.trainer = trainer
    trainer._data_connector.attach_dataloaders(model=model)
    trainer.reset_train_dataloader(model)
    # the dataloader should have been overwritten with one using a SequentialSampler
    assert trainer.train_dataloader is not train_loader
    assert trainer.num_training_batches == num_train_samples

    # make sure the loaders are the same
    (xb, yb) = next(iter(trainer.train_dataloader))
    assert torch.eq(xa, xb).all()
    assert torch.eq(ya, yb).all()

    # ------------------------------------------------------
    # run tests for both val and test
    # ------------------------------------------------------
    for split in (RunningStage.VALIDATING, RunningStage.TESTING):
        # ------------------------------------------------------
        # test overfit_batches as percent
        # ------------------------------------------------------
        trainer = Trainer(overfit_batches=0.11)
        trainer._data_connector.attach_dataloaders(model)
        loader_num_batches, dataloaders = trainer._reset_eval_dataloader(split, model=model)
        assert loader_num_batches[0] == num_train_samples

        # make sure we turned off shuffle for the user
        assert isinstance(dataloaders[0].sampler, SequentialSampler)

        # make sure the loaders are the same
        (xb, yb) = next(iter(dataloaders[0]))
        assert torch.eq(xa, xb).all()
        assert torch.eq(ya, yb).all()

        # ------------------------------------------------------
        # test overfit_batches as int
        # ------------------------------------------------------
        trainer = Trainer(overfit_batches=1)
        trainer._data_connector.attach_dataloaders(model)
        loader_num_batches, dataloaders = trainer._reset_eval_dataloader(split, model=model)
        assert loader_num_batches[0] == 1

        trainer = Trainer(overfit_batches=5)
        trainer._data_connector.attach_dataloaders(model)
        loader_num_batches, dataloaders = trainer._reset_eval_dataloader(split, model=model)
        assert loader_num_batches[0] == 5

        # ------------------------------------------------------
        # test limit_xxx_batches as percent AND int
        # ------------------------------------------------------
        if split == RunningStage.VALIDATING:
            trainer = Trainer(limit_val_batches=0.1)
            trainer._data_connector.attach_dataloaders(model)
            loader_num_batches, dataloaders = trainer._reset_eval_dataloader(split, model=model)
            assert loader_num_batches[0] == int(0.1 * len(val_loader))

            trainer = Trainer(limit_val_batches=10)
            trainer._data_connector.attach_dataloaders(model)
            loader_num_batches, dataloaders = trainer._reset_eval_dataloader(split, model=model)
            assert loader_num_batches[0] == 10
        else:
            trainer = Trainer(limit_test_batches=0.1)
            trainer._data_connector.attach_dataloaders(model)
            loader_num_batches, dataloaders = trainer._reset_eval_dataloader(split, model=model)
            assert loader_num_batches[0] == int(0.1 * len(test_loader))

            trainer = Trainer(limit_test_batches=10)
            trainer._data_connector.attach_dataloaders(model)
            loader_num_batches, dataloaders = trainer._reset_eval_dataloader(split, model=model)
            assert loader_num_batches[0] == 10

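# The assertions above hinge on ``overfit_batches`` forcing sequential sampling
# so that train, val, and test all see the same leading samples. A minimal
# sketch of that sampler swap (an assumed simplification, not the actual
# Trainer code):
def _sketch_force_sequential(dataloader: DataLoader) -> DataLoader:
    # rebuild the loader over the same dataset, replacing any RandomSampler
    # with a SequentialSampler so the iteration order is deterministic
    return DataLoader(
        dataloader.dataset,
        batch_size=dataloader.batch_size,
        sampler=SequentialSampler(dataloader.dataset),
    )
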
# NOTE: the ``val_dl`` parametrization below is assumed for illustration (it
# relies on ``RandomDataset`` being importable alongside ``BoringModel``); the
# real suite may cover more loader configurations.
@pytest.mark.parametrize("val_dl", [DataLoader(RandomDataset(32, 64), shuffle=True)])
def test_non_sequential_sampler_warning_is_raised_for_eval_dataloader(val_dl):
    trainer = Trainer()
    model = BoringModel()
    trainer._data_connector.attach_data(model, val_dataloaders=val_dl)
    with pytest.warns(PossibleUserWarning, match="recommended .* turn this off for val/test/predict"):
        trainer._reset_eval_dataloader(RunningStage.VALIDATING, model)