def test_multimode():
    import torch

    urls = [local_data] * 8
    nsamples = 47 * 8

    shardlist = wds.PytorchShardList(
        urls, verbose=True, epoch_shuffle=True, shuffle=True
    )
    os.environ["WDS_EPOCH"] = "7"
    ds = wds.WebDataset(shardlist)
    dl = torch.utils.data.DataLoader(ds, num_workers=4)
    count = count_samples_tuple(dl)
    assert count == nsamples, count
    del os.environ["WDS_EPOCH"]

    shardlist = wds.PytorchShardList(urls, verbose=True, split_by_worker=False)
    ds = wds.WebDataset(shardlist)
    dl = torch.utils.data.DataLoader(ds, num_workers=4)
    count = count_samples_tuple(dl)
    assert count == 4 * nsamples, count

    shardlist = shardlists.ResampledShards(urls)
    ds = wds.WebDataset(shardlist).slice(170)
    dl = torch.utils.data.DataLoader(ds, num_workers=4)
    count = count_samples_tuple(dl)
    assert count == 170 * 4, count


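# The tests in this section use several helpers and fixtures that are defined
# elsewhere in the test module: local_data, remote_loc, remote_shard,
# count_samples, count_samples_tuple, getkeys, and identity. The definitions
# below are only a sketch, reconstructed from how the names are used in these
# tests; they are assumptions, not the module's actual implementations, and
# are redundant if the surrounding module already provides them.


def identity(x):
    # Pass a value through unchanged; used as a no-op map_dict handler.
    return x


def getkeys(sample):
    # Set of sample keys, ignoring metadata keys such as "__key__".
    return set(k for k in sample.keys() if not k.startswith("_"))


def count_samples(source, *args, n=10000):
    # Count at most n samples from an iterable source, applying optional checks.
    count = 0
    for i, sample in enumerate(iter(source)):
        if i >= n:
            break
        for f in args:
            assert f(sample)
        count += 1
    return count


def count_samples_tuple(source, *args, n=10000):
    # Like count_samples, but also verify each sample is a tuple, list, or dict.
    count = 0
    for i, sample in enumerate(iter(source)):
        if i >= n:
            break
        assert isinstance(sample, (tuple, dict, list)), (type(sample), sample)
        for f in args:
            assert f(sample)
        count += 1
    return count

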
def test_dataset_rename_keep():
    ds = wds.WebDataset(local_data).rename(image="png", keep=False)
    sample = next(iter(ds))
    assert getkeys(sample) == set(["image"]), getkeys(sample)

    ds = wds.WebDataset(local_data).rename(image="png")
    sample = next(iter(ds))
    assert getkeys(sample) == set("cls image wnid xml".split()), getkeys(sample)


def test_dataset_rename_handler():
    ds = wds.WebDataset(local_data).rename(image="png;jpg", cls="cls")
    count_samples_tuple(ds)

    with pytest.raises(ValueError):
        ds = wds.WebDataset(local_data).rename(image="missing", cls="cls")
        count_samples_tuple(ds)


def test_dataset_rsample():
    ds = wds.WebDataset(local_data).rsample(1.0)
    assert count_samples_tuple(ds) == 47

    ds = wds.WebDataset(local_data).rsample(0.5)
    result = [count_samples_tuple(ds) for _ in range(300)]
    assert 0.3 * 47 <= np.mean(result) <= 0.7 * 47, np.mean(result)


def test_float_np_vs_torch():
    ds = wds.WebDataset(local_data).decode("rgb").to_tuple("png;jpg", "cls")
    image, cls = next(iter(ds))
    ds = wds.WebDataset(local_data).decode("torchrgb").to_tuple("png;jpg", "cls")
    image2, cls2 = next(iter(ds))
    assert (image == image2.permute(1, 2, 0).numpy()).all(), (image.shape, image2.shape)
    assert cls == cls2


def test_float_np_vs_torch_keyword_api():
    # Same comparison as test_float_np_vs_torch, using the keyword-argument API.
    ds = wds.WebDataset(local_data, extensions="png;jpg cls")
    image, cls = next(iter(ds))
    ds = wds.WebDataset(local_data, extensions="png;jpg cls", decoder="torchrgb")
    image2, cls2 = next(iter(ds))
    assert (image == image2.permute(1, 2, 0).numpy()).all(), (image.shape, image2.shape)
    assert cls == cls2


def test_dataset_map_handler():
    def f(x):
        assert isinstance(x, dict)
        return x

    def g(x):
        raise ValueError()

    ds = wds.WebDataset(local_data).map(f)
    count_samples_tuple(ds)

    with pytest.raises(ValueError):
        ds = wds.WebDataset(local_data).map(g)
        count_samples_tuple(ds)


def test_rgb8_np_vs_torch_keyword_api():
    # Same comparison as test_rgb8_np_vs_torch, using the keyword-argument API.
    ds = wds.WebDataset(local_data, extensions="png;jpg cls", decoder="rgb8")
    image, cls = next(iter(ds))
    assert isinstance(image, np.ndarray), type(image)
    assert isinstance(cls, int), type(cls)
    ds = wds.WebDataset(local_data, extensions="png;jpg cls", decoder="torchrgb8")
    image2, cls2 = next(iter(ds))
    assert isinstance(image2, torch.Tensor), type(image2)
    assert isinstance(cls, int), type(cls)
    assert (image == image2.permute(1, 2, 0).numpy()).all(), (image.shape, image2.shape)
    assert cls == cls2


def test_only1():
    ds = wds.WebDataset(local_data).decode(only="cls").to_tuple("jpg;png", "cls")
    assert count_samples_tuple(ds) == 47
    image, cls = next(iter(ds))
    assert isinstance(image, bytes)
    assert isinstance(cls, int)

    ds = wds.WebDataset(local_data).decode("l", only=["jpg", "png"]).to_tuple(
        "jpg;png", "cls"
    )
    assert count_samples_tuple(ds) == 47
    image, cls = next(iter(ds))
    assert isinstance(image, np.ndarray)
    assert isinstance(cls, bytes)


def test_dataset_map_dict_handler():
    ds = wds.WebDataset(local_data).map_dict(png=identity, cls=identity)
    count_samples_tuple(ds)

    with pytest.raises(KeyError):
        ds = wds.WebDataset(local_data).map_dict(png=identity, cls2=identity)
        count_samples_tuple(ds)

    def g(x):
        raise ValueError()

    with pytest.raises(ValueError):
        ds = wds.WebDataset(local_data).map_dict(png=g, cls=identity)
        count_samples_tuple(ds)


def test_rgb8_np_vs_torch():
    import warnings

    warnings.filterwarnings("error")
    ds = wds.WebDataset(local_data).decode("rgb8").to_tuple("png;jpg", "cls")
    image, cls = next(iter(ds))
    assert isinstance(image, np.ndarray), type(image)
    assert isinstance(cls, int), type(cls)
    ds = wds.WebDataset(local_data).decode("torchrgb8").to_tuple("png;jpg", "cls")
    image2, cls2 = next(iter(ds))
    assert isinstance(image2, torch.Tensor), type(image2)
    assert isinstance(cls, int), type(cls)
    assert (image == image2.permute(1, 2, 0).numpy()).all(), (image.shape, image2.shape)
    assert cls == cls2


def test_dataset_eof():
    import tarfile

    with pytest.raises(tarfile.ReadError):
        ds = wds.WebDataset(f"pipe:dd if={local_data} bs=1024 count=10").shuffle(5)
        assert count_samples(ds) == 47


def test_pipe():
    ds = wds.WebDataset(
        f"pipe:curl -s '{remote_loc}" + "imagenet_train-{0000..0147}.tgz'",
        extensions="jpg;png cls",
        shuffle=100,
    )
    assert count_samples(ds, n=10) == 10


def test_shard_syntax():
    ds = wds.WebDataset(
        remote_loc + "imagenet_train-{0000..0147}.tgz",
        extensions="jpg;png cls",
        shuffle=0,
    )
    assert count_samples(ds, n=10) == 10


def test_dataloader():
    import torch

    ds = wds.WebDataset(remote_loc + "imagenet_train-{0000..0147}.tgz", decoder=None)
    dl = torch.utils.data.DataLoader(ds, num_workers=4)
    assert count_samples(dl, n=100) == 100


def test_rgb8():
    ds = wds.WebDataset(local_data).decode("rgb8").to_tuple("png;jpg", "cls")
    assert count_samples_tuple(ds) == 47
    image, cls = next(iter(ds))
    assert isinstance(image, np.ndarray), type(image)
    assert image.dtype == np.uint8, image.dtype
    assert isinstance(cls, int), type(cls)


def test_dataset_shuffle_decode_rename_extract():
    ds = (
        wds.WebDataset(local_data)
        .shuffle(5)
        .decode("rgb")
        .rename(image="png;jpg", cls="cls")
        .to_tuple("image", "cls")
    )
    assert count_samples_tuple(ds) == 47
    image, cls = next(iter(ds))
    assert isinstance(image, np.ndarray), image
    assert isinstance(cls, int), type(cls)


def test_length():
    ds = wds.WebDataset(local_data)
    with pytest.raises(TypeError):
        len(ds)
    dsl = ds.with_length(1793)
    assert len(dsl) == 1793
    dsl2 = dsl.repeat(17).with_length(19)
    assert len(dsl2) == 19


def test_writer_pipe(tmpdir):
    with writer.TarWriter(f"pipe:cat > {tmpdir}/writer3.tar") as sink:
        sink.write(dict(__key__="a", txt="hello", cls="3"))
    os.system(f"ls -l {tmpdir}")
    ds = wds.WebDataset(f"pipe:cat {tmpdir}/writer3.tar")
    for sample in ds:
        assert set(sample.keys()) == set("__key__ txt cls".split())
        break


def test_log_keys(tmp_path):
    tmp_path = str(tmp_path)
    fname = tmp_path + "/test.ds.yml"
    ds = wds.WebDataset(local_data).log_keys(fname)
    result = [x for x in ds]
    assert len(result) == 47
    with open(fname) as stream:
        lines = stream.readlines()
    assert len(lines) == 47


def test_tenbin_dec_keyword_api():
    # Same tenbin decoding check as test_tenbin_dec, using the keyword-argument API.
    ds = wds.WebDataset("testdata/tendata.tar", extensions="ten")
    assert count_samples(ds) == 100
    for sample in ds:
        xs, ys = sample[0]
        assert xs.dtype == np.float64
        assert ys.dtype == np.float64
        assert xs.shape == (28, 28)
        assert ys.shape == (28, 28)


def test_decoder():
    def mydecoder(key, sample):
        return len(sample)

    ds = wds.WebDataset(remote_loc + remote_shard).decode(mydecoder).to_tuple(
        "jpg;png", "json"
    )
    for sample in ds:
        assert isinstance(sample[0], int)
        break


def test_tenbin_dec():
    ds = wds.WebDataset("testdata/tendata.tar").decode().to_tuple("ten")
    assert count_samples_tuple(ds) == 100
    for sample in ds:
        xs, ys = sample[0]
        assert xs.dtype == np.float64
        assert ys.dtype == np.float64
        assert xs.shape == (28, 28)
        assert ys.shape == (28, 28)


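# A hedged sketch (not part of the original test suite) of how a shard like
# testdata/tendata.tar could be generated. It assumes that writer.TarWriter's
# default encoder serializes values stored under the "ten" extension with the
# tenbin tensor format, which is what the two tenbin tests above decode back
# into pairs of 28x28 float64 arrays. The function name and the use of random
# data are illustrative only.
def make_tendata_shard(path="testdata/tendata.tar", n=100):
    with writer.TarWriter(path) as sink:
        for i in range(n):
            xs = np.random.uniform(size=(28, 28))
            ys = np.random.uniform(size=(28, 28))
            sink.write({"__key__": f"{i:06d}", "ten": [xs, ys]})

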
def test_with_epoch():
    ds = wds.WebDataset(local_data)
    for _ in range(10):
        assert count_samples_tuple(ds) == 47
    be = ds.with_epoch(193)
    for _ in range(10):
        assert count_samples_tuple(be) == 193
    be = ds.with_epoch(2)
    for _ in range(10):
        assert count_samples_tuple(be) == 2


def test_dataset_mock():
    obj = ("hello", "world")
    ds = wds.WebDataset(local_data).shuffle(5).to_tuple("png;jpg cls").test(
        mock_sample=obj
    )
    assert count_samples_tuple(ds) == 47
    ds.mock_length = 99
    ds.mock = True
    assert count_samples_tuple(ds) == 99
    sample = next(iter(ds))
    assert sample == obj


def test_writer(tmpdir):
    with writer.TarWriter(f"{tmpdir}/writer.tar") as sink:
        sink.write(dict(__key__="a", txt="hello", cls="3"))
    os.system(f"ls -l {tmpdir}")
    ftype = os.popen(f"file {tmpdir}/writer.tar").read()
    assert "compress" not in ftype, ftype

    ds = wds.WebDataset(f"{tmpdir}/writer.tar")
    for sample in ds:
        assert set(sample.keys()) == set("__key__ txt cls".split())
        break


def test_associate():
    with open("testdata/imagenet-extra.json") as stream:
        extra_data = simplejson.load(stream)

    def associate(key):
        return dict(MY_EXTRA_DATA=extra_data[key])

    ds = wds.WebDataset(local_data, associate=associate)
    for sample in ds:
        assert "MY_EXTRA_DATA" in sample.keys()
        break


def test_decoder_keyword_api():
    # Same custom-decoder check as test_decoder, using the keyword-argument API,
    # where the decoder receives the whole sample dict rather than key/value pairs.
    def mydecoder(sample):
        return {k: len(v) for k, v in sample.items()}

    ds = wds.WebDataset(
        remote_loc + "imagenet_train-0050.tgz",
        extensions="jpg;png cls",
        decoder=mydecoder,
    )
    for sample in ds:
        assert isinstance(sample[0], int)
        break


def test_handlers():
    def mydecoder(data):
        return PIL.Image.open(io.BytesIO(data)).resize((128, 128))

    ds = (
        wds.WebDataset(remote_loc + remote_shard)
        .decode(
            autodecode.handle_extension("jpg", mydecoder),
            autodecode.handle_extension("png", mydecoder),
        )
        .to_tuple("jpg;png", "json")
    )
    for sample in ds:
        assert isinstance(sample[0], PIL.Image.Image)
        break


def test_dataset_decode_nohandler():
    count = [0]

    def faulty_decoder(key, data):
        # Fail on every other call so that decoding eventually raises.
        count[0] += 1
        if count[0] % 2 == 0:
            raise ValueError("nothing")
        else:
            return data

    with pytest.raises(ValueError):
        ds = wds.WebDataset(local_data).decode(faulty_decoder)
        count_samples_tuple(ds)