def test_num_dataloader_batches(tmpdir):
    """Tests that the correct number of batches are allocated."""
    # Each (limit, expected) pair covers one rule:
    #  - limit larger than the dataloader (100 vs 64): the dataloader length wins
    #  - limit smaller than the dataloader (7 vs 64): the limit wins
    for limit, expected in ((100, 64), (7, 7)):
        model = BoringModel()
        trainer = Trainer(
            limit_val_batches=limit,
            limit_train_batches=limit,
            max_epochs=1,
            default_root_dir=tmpdir,
        )
        trainer.fit(model)

        # the BoringModel dataloaders always yield 64 batches
        assert len(model.train_dataloader()) == 64
        assert len(model.val_dataloader()) == 64
        assert isinstance(trainer.num_val_batches, list)
        assert trainer.num_val_batches[0] == expected
        assert trainer.num_training_batches == expected
def test_warning_with_few_workers(_, tmpdir, ckpt_path, stage):
    """Test that error is raised if dataloader with only a few workers is used."""
    model = BoringModel()
    train_dl = model.train_dataloader()
    train_dl.num_workers = 0
    val_dl = model.val_dataloader()
    val_dl.num_workers = 0

    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        limit_val_batches=0.1,
        limit_train_batches=0.2,
    )

    # non-train stages name their dataloaders with a trailing " 0" index
    suffix = "" if stage == "train" else " 0"
    warn_match = f'The dataloader, {stage} dataloader{suffix}, does not have many workers'
    with pytest.warns(UserWarning, match=warn_match):
        if stage == 'test':
            if ckpt_path == 'specific':
                ckpt_path = trainer.checkpoint_callback.best_model_path
            trainer.test(model, test_dataloaders=train_dl, ckpt_path=ckpt_path)
        else:
            trainer.fit(model, train_dataloader=train_dl, val_dataloaders=val_dl)
def test_dataloader_source_request_from_module():
    """Test requesting a dataloader from a module works."""
    module = BoringModel()
    module.trainer = Trainer()
    module.foo = Mock(return_value=module.train_dataloader())

    source = _DataLoaderSource(module, "foo")
    assert source.is_module()

    # the hook must only run lazily, when the dataloader is actually requested
    module.foo.assert_not_called()
    requested = source.dataloader()
    assert isinstance(requested, DataLoader)
    module.foo.assert_called_once()
def test_warning_with_iterable_dataset_and_len(tmpdir):
    """Tests that a warning message is shown when an IterableDataset defines `__len__`."""
    model = BoringModel()
    original_dataset = model.train_dataloader().dataset

    class IterableWithoutLen(IterableDataset):
        def __iter__(self):
            return iter(original_dataset)

    class IterableWithLen(IterableWithoutLen):
        def __len__(self):
            return len(original_dataset)

    warn_match = 'Your `IterableDataset` has `__len__` defined.'

    # with __len__ defined: every trainer entry point should warn
    dataloader = DataLoader(IterableWithLen(), batch_size=16)
    assert has_len(dataloader)
    assert has_iterable_dataset(dataloader)
    trainer = Trainer(default_root_dir=tmpdir, max_steps=3)
    with pytest.warns(UserWarning, match=warn_match):
        trainer.validate(model, val_dataloaders=[dataloader])
    with pytest.warns(UserWarning, match=warn_match):
        trainer.fit(model, train_dataloader=dataloader, val_dataloaders=[dataloader])
    with pytest.warns(UserWarning, match=warn_match):
        trainer.test(model, test_dataloaders=[dataloader])
    with pytest.warns(UserWarning, match=warn_match):
        trainer.predict(model, dataloaders=[dataloader])

    # without __len__ defined: the same entry points run without warning
    dataloader = DataLoader(IterableWithoutLen(), batch_size=16)
    assert not has_len(dataloader)
    assert has_iterable_dataset(dataloader)
    trainer = Trainer(default_root_dir=tmpdir, max_steps=3)
    trainer.validate(model, val_dataloaders=dataloader)
    trainer.fit(model, train_dataloader=dataloader, val_dataloaders=[dataloader])
    trainer.test(model, test_dataloaders=dataloader)
    trainer.predict(model, dataloaders=dataloader)