def test_num_dataloader_batches(tmpdir):
    """Tests that the correct number of batches are allocated."""
    # The dataloader produced by BoringModel has 64 batches. Run once with a
    # limit above that size (the dataloader length wins) and once with a limit
    # below it (the limit wins): effective count is min(limit, len(dataloader)).
    for limit, expected in ((100, 64), (7, 7)):
        model = BoringModel()
        trainer = Trainer(
            limit_val_batches=limit,
            limit_train_batches=limit,
            max_epochs=1,
            default_root_dir=tmpdir,
        )
        trainer.fit(model)

        assert len(model.train_dataloader()) == 64
        assert len(model.val_dataloader()) == 64
        # num_val_batches is a list (one entry per val dataloader).
        assert isinstance(trainer.num_val_batches, list)
        assert trainer.num_val_batches[0] == expected
        assert trainer.num_training_batches == expected
def test_warning_with_few_workers(_, tmpdir, ckpt_path, stage):
    """ Test that error is raised if dataloader with only a few workers is used """
    model = BoringModel()

    # Force zero workers on both loaders so the "not many workers" warning fires.
    train_dl = model.train_dataloader()
    train_dl.num_workers = 0
    val_dl = model.val_dataloader()
    val_dl.num_workers = 0

    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        limit_val_batches=0.1,
        limit_train_batches=0.2,
    )

    # Non-train stages are reported with a dataloader index suffix (" 0").
    index_suffix = " 0" if stage != "train" else ""
    expected_warning = f"The dataloader, {stage} dataloader{index_suffix}, does not have many workers"

    with pytest.warns(UserWarning, match=expected_warning):
        if stage == "test":
            # "specific" means: resolve to the best checkpoint recorded so far.
            if ckpt_path == "specific":
                ckpt_path = trainer.checkpoint_callback.best_model_path
            trainer.test(model, test_dataloaders=train_dl, ckpt_path=ckpt_path)
        else:
            trainer.fit(model, train_dataloader=train_dl, val_dataloaders=val_dl)
def test_error_raised_with_float_limited_eval_batches():
    """Test that an error is raised if there are not enough batches when passed with float value of limit_eval_batches."""
    model = BoringModel()
    num_batches = len(model.val_dataloader())
    # Choose a fraction small enough that fraction * num_batches < 1, which
    # leaves zero whole batches and must be rejected.
    fraction = 1 / (num_batches + 2)

    trainer = Trainer(limit_val_batches=fraction)
    trainer._data_connector.attach_data(model)

    expected_message = rf"{fraction} \* {num_batches} < 1. Please increase the `limit_val_batches`"
    with pytest.raises(MisconfigurationException, match=expected_message):
        trainer._data_connector._reset_eval_dataloader(RunningStage.VALIDATING, model)