def test_async_dataloader(tmpdir):
    """Check that ``AsynchronousLoader`` can be iterated to completion on ``cuda:0``.

    Two configurations are exercised:

    1. wrapping the ``CIFAR10`` dataset object directly, and
    2. wrapping a ``DataLoader`` (``batch_size=16``) built over the same dataset.

    The test needs a CUDA device. Without one it is reported as skipped rather
    than silently passing without exercising anything.
    """
    import pytest  # local import: only needed for the explicit skip below

    if torch.cuda.device_count() == 0:
        pytest.skip("AsynchronousLoader test requires a CUDA device")

    # Only build the dataset once we know the test will actually run.
    ds = CIFAR10(tmpdir)
    device = torch.device('cuda', 0)

    # Draining each loader is the whole check: any error raised while
    # producing a batch fails the test.
    dataloader = AsynchronousLoader(ds, device=device)
    for _ in dataloader:
        pass

    dataloader = AsynchronousLoader(DataLoader(ds, batch_size=16), device=device)
    for _ in dataloader:
        pass
def test_dev_datasets(datadir):
    """Smoke-test that a ``CIFAR10`` built from ``datadir`` can be fully iterated.

    No values are checked. The test fails only if constructing the dataset or
    producing any of its items raises.
    """
    dataset = CIFAR10(data_dir=datadir)
    sample_iter = iter(dataset)
    for _sample in sample_iter:
        pass