def test_get_len():
    """`get_len` returns the length of a sized loader and ``float("inf")`` for an iterable one."""
    sized_loader = DataLoader(RandomDataset(1, 1))
    assert get_len(sized_loader) == 1

    # The iterable-dataset loader is expected to report an infinite float length.
    iterable_length = get_len(DataLoader(RandomIterableDataset(1, 1)))
    assert isinstance(iterable_length, float)
    assert iterable_length == float("inf")
def test_has_len():
    """Check `has_len` on a sized, an empty (expected to raise) and an iterable loader.

    NOTE(review): another ``test_has_len`` (the UserWarning variant) is defined right
    after this one; if both live in the same module, that later definition shadows
    this one and pytest will never collect it.
    """
    assert has_len(DataLoader(RandomDataset(1, 1)))

    # For an empty loader, `has_len` is expected to raise, so the assert never completes.
    with pytest.raises(ValueError, match="`Dataloader` returned 0 length."):
        zero_length_loader = DataLoader(RandomDataset(0, 0))
        assert has_len(zero_length_loader)

    iterable_loader = DataLoader(RandomIterableDataset(1, 1))
    assert not has_len(iterable_loader)
def test_has_len():
    """`has_len` is truthy for sized loaders (warning when empty) and falsy for iterable ones."""
    assert has_len(DataLoader(RandomDataset(1, 1)))

    # A zero-length loader still counts as having a length, but must emit a UserWarning.
    with pytest.warns(UserWarning, match="`DataLoader` returned 0 length."):
        empty_loader = DataLoader(RandomDataset(0, 0))
        assert has_len(empty_loader)

    iterable_loader = DataLoader(RandomIterableDataset(1, 1))
    assert not has_len(iterable_loader)
def test_num_stepping_batches_iterable_dataset():
    """Test the stepping batches with iterable dataset configured with max steps.

    With a very large iterable dataset attached, ``estimated_stepping_batches`` is
    expected to equal the configured ``max_steps``.
    """
    max_steps = 1000
    trainer = Trainer(max_steps=max_steps)
    model = BoringModel()
    # Pass an integer count (as other RandomIterableDataset call sites do) rather than
    # the float ``1e10``; presumably the dataset uses ``count`` as an int bound when it
    # is iterated — TODO confirm against RandomIterableDataset.
    train_dl = DataLoader(RandomIterableDataset(size=7, count=10**10))
    trainer._data_connector.attach_data(model, train_dataloaders=train_dl)
    trainer.strategy.connect(model)
    assert trainer.estimated_stepping_batches == max_steps
def test_has_iterable_dataset():
    """`has_iterable_dataset` must not be fooled by a dataset that merely defines ``__iter__``."""
    assert has_iterable_dataset(DataLoader(RandomIterableDataset(1, 1)))
    assert not has_iterable_dataset(DataLoader(RandomDataset(1, 1)))

    class MockDatasetWithoutIterableDataset(RandomDataset):
        """A ``RandomDataset`` subclass that additionally defines ``__iter__``."""

        def __iter__(self):
            # Only needs to be iterable; the yielded value is never consumed here.
            yield 1

    # Defining __iter__ alone is expected not to make the loader count as iterable.
    assert not has_iterable_dataset(
        DataLoader(MockDatasetWithoutIterableDataset(1, 1)))
def test_dataloader_kwargs_replacement_with_iterable_dataset(mode):
    """Test that DataLoader kwargs are not replaced when using Iterable Dataset.

    ``mode`` is presumably supplied by a ``pytest.mark.parametrize`` outside this view.
    """
    dataset = RandomIterableDataset(7, 100)
    dataloader = DataLoader(dataset, batch_size=32)
    dl_kwargs = _get_dataloader_init_kwargs(dataloader, dataloader.sampler, mode=mode)

    # No sampler / batch_sampler should be forwarded for the iterable dataset.
    assert dl_kwargs["sampler"] is None
    assert dl_kwargs["batch_sampler"] is None
    # Compare the integer by value: ``is`` on ints only passes by accident of
    # CPython's small-int cache.
    assert dl_kwargs["batch_size"] == dataloader.batch_size
    # These must be the very same objects, not copies.
    assert dl_kwargs["dataset"] is dataloader.dataset
    assert dl_kwargs["collate_fn"] is dataloader.collate_fn
@RunIf(rich=True)
def test_rich_progress_bar_refresh_rate():
    """`RichProgressBar` is enabled for ``refresh_rate_per_second=1`` and disabled for 0."""
    progress_bar = RichProgressBar(refresh_rate_per_second=1)
    assert progress_bar.is_enabled
    assert not progress_bar.is_disabled
    # A refresh rate of zero turns the bar off.
    progress_bar = RichProgressBar(refresh_rate_per_second=0)
    assert not progress_bar.is_enabled
    assert progress_bar.is_disabled


@RunIf(rich=True)
@mock.patch(
    "pytorch_lightning.callbacks.progress.rich_progress.Progress.update")
@pytest.mark.parametrize(
    "dataset", [RandomDataset(32, 64), RandomIterableDataset(32, 64)])
def test_rich_progress_bar(progress_update, tmpdir, dataset):
    """Build a model whose every dataloader hook serves the parametrized ``dataset``.

    ``progress_update`` is the mock installed by ``mock.patch`` for ``Progress.update``;
    ``tmpdir`` is presumably pytest's built-in temporary-directory fixture.

    NOTE(review): the body ends right after ``model = TestModel()`` and never uses
    ``progress_update`` or ``tmpdir``; the rest of the test (running the model and
    checking the mock) appears to be cut off in this view — TODO confirm.
    """

    class TestModel(BoringModel):
        # Every hook returns a loader over the same parametrized dataset.
        def train_dataloader(self):
            return DataLoader(dataset=dataset)

        def val_dataloader(self):
            return DataLoader(dataset=dataset)

        def test_dataloader(self):
            return DataLoader(dataset=dataset)

        def predict_dataloader(self):
            return DataLoader(dataset=dataset)

    model = TestModel()
# NOTE(review): the following statement appears to be the tail of a test whose
# header lies outside this view; it relies on a ``trainer`` defined in that body.
    assert isinstance(trainer.progress_bar_callback, RichProgressBar)


@RunIf(rich=True)
def test_rich_progress_bar_refresh_rate_enabled():
    """`RichProgressBar` is enabled for ``refresh_rate=1`` and disabled for ``refresh_rate=0``."""
    progress_bar = RichProgressBar(refresh_rate=1)
    assert progress_bar.is_enabled
    assert not progress_bar.is_disabled
    # A refresh rate of zero turns the bar off.
    progress_bar = RichProgressBar(refresh_rate=0)
    assert not progress_bar.is_enabled
    assert progress_bar.is_disabled


@RunIf(rich=True)
@mock.patch("pytorch_lightning.callbacks.progress.rich_progress.Progress.update")
@pytest.mark.parametrize("dataset", [RandomDataset(32, 64), RandomIterableDataset(32, 64)])
def test_rich_progress_bar(progress_update, tmpdir, dataset):
    """Build a model whose every dataloader hook serves the parametrized ``dataset``.

    ``progress_update`` is the mock installed by ``mock.patch`` for ``Progress.update``;
    ``tmpdir`` is presumably pytest's built-in temporary-directory fixture.

    NOTE(review): the body ends right after ``model = TestModel()`` and never uses
    ``progress_update`` or ``tmpdir``, so the rest of the test appears to be cut off
    in this view. A test with this same name is also defined just before this line;
    in one module only the later definition would be collected — TODO confirm.
    """

    class TestModel(BoringModel):
        # Every hook returns a loader over the same parametrized dataset.
        def train_dataloader(self):
            return DataLoader(dataset=dataset)

        def val_dataloader(self):
            return DataLoader(dataset=dataset)

        def test_dataloader(self):
            return DataLoader(dataset=dataset)

        def predict_dataloader(self):
            return DataLoader(dataset=dataset)

    model = TestModel()
def train_dataloader(self):
    """Return a training DataLoader over a ``RandomIterableDataset`` or a ``RandomDataset``.

    ``max_steps``, ``use_infinite_dataset`` and ``data_samples_train`` are free variables
    resolved from an enclosing scope outside this view.
    """
    if not use_infinite_dataset:
        return DataLoader(RandomDataset(32, length=data_samples_train))
    # count exceeds max_steps by 100 — presumably so the data outlasts training.
    return DataLoader(RandomIterableDataset(32, count=max_steps + 100))