def test_dataset_rename_handler(self):
    """rename() maps extension keys; renaming to a missing key raises ValueError."""
    good = wds.Dataset(local_data).rename(image="png;jpg", cls="cls")
    count_samples_tuple(good)
    with self.assertRaises(ValueError):
        bad = wds.Dataset(local_data).rename(image="missing", cls="cls")
        count_samples_tuple(bad)
def test_float_np_vs_torch():
    """Float decoding must agree between the "rgb" (numpy) and "torchrgb" paths."""
    np_pipeline = wds.Dataset(local_data).decode("rgb").to_tuple("png;jpg", "cls")
    image, cls = next(iter(np_pipeline))
    torch_pipeline = wds.Dataset(local_data).decode("torchrgb").to_tuple("png;jpg", "cls")
    image2, cls2 = next(iter(torch_pipeline))
    # torch images are CHW; permute back to HWC before comparing pixel-for-pixel.
    assert (image == image2.permute(1, 2, 0).numpy()).all(), (image.shape, image2.shape)
    assert cls == cls2
def test_dataset_map_handler(self):
    """map() applies a callable per sample; exceptions propagate by default."""

    def passthrough(sample):
        # map() feeds each sample as a dict.
        assert isinstance(sample, dict)
        return sample

    def blowup(sample):
        raise ValueError()

    count_samples_tuple(wds.Dataset(local_data).map(passthrough))
    with self.assertRaises(ValueError):
        failing = wds.Dataset(local_data).map(blowup)
        count_samples_tuple(failing)
def test_dataset_map_dict_handler(self):
    """map_dict() transforms named fields; bad keys and raising fns propagate."""
    ok = wds.Dataset(local_data).map_dict(png=identity, cls=identity)
    count_samples_tuple(ok)

    # A key absent from the samples raises KeyError.
    with self.assertRaises(KeyError):
        missing_key = wds.Dataset(local_data).map_dict(png=identity, cls2=identity)
        count_samples_tuple(missing_key)

    def blowup(value):
        raise ValueError()

    # An exception inside a field mapper propagates by default.
    with self.assertRaises(ValueError):
        failing = wds.Dataset(local_data).map_dict(png=blowup, cls=identity)
        count_samples_tuple(failing)
def test_rgb8_np_vs_torch():
    """uint8 decoding must agree between the "rgb8" (numpy) and "torchrgb8" paths."""
    import warnings

    warnings.filterwarnings("error")
    ds = wds.Dataset(local_data).decode("rgb8").to_tuple("png;jpg", "cls")
    image, cls = next(iter(ds))
    assert isinstance(image, np.ndarray), type(image)
    assert isinstance(cls, int), type(cls)
    ds = wds.Dataset(local_data).decode("torchrgb8").to_tuple("png;jpg", "cls")
    image2, cls2 = next(iter(ds))
    assert isinstance(image2, torch.Tensor), type(image2)
    assert isinstance(cls, int), type(cls)
    # BUG FIX: ".all" was missing its call parentheses, so the assert only
    # checked the truthiness of the bound method (always True) and the pixel
    # comparison was never actually evaluated.
    assert (image == image2.permute(1, 2, 0).numpy()).all(), (image.shape, image2.shape)
    assert cls == cls2
def test_dataset_shuffle_decode_rename_extract():
    """A shuffle → decode → rename → to_tuple pipeline yields decoded samples."""
    pipeline = (
        wds.Dataset(local_data)
        .shuffle(5)
        .decode("rgb")
        .rename(image="png;jpg", cls="cls")
        .to_tuple("image", "cls")
    )
    assert count_samples_tuple(pipeline) == 47
    image, cls = next(iter(pipeline))
    assert isinstance(image, np.ndarray), image
    assert isinstance(cls, int), type(cls)
def test_dataset_eof():
    """A truncated tar stream raises tarfile.ReadError during iteration."""
    import tarfile

    with pytest.raises(tarfile.ReadError):
        # dd cuts the archive off after 10 KiB, guaranteeing a premature EOF.
        truncated = wds.Dataset(f"pipe:dd if={local_data} bs=1024 count=10").shuffle(5)
        assert count_samples(truncated) == 47
def test_rgb8():
    """The "rgb8" decoder yields uint8 numpy images and int labels."""
    pipeline = wds.Dataset(local_data).decode("rgb8").to_tuple("png;jpg", "cls")
    assert count_samples_tuple(pipeline) == 47
    image, cls = next(iter(pipeline))
    assert isinstance(image, np.ndarray), type(image)
    assert image.dtype == np.uint8, image.dtype
    assert isinstance(cls, int), type(cls)
def test_writer_pipe(tmpdir):
    """TarWriter can write through a pipe, and the result reads back."""
    with writer.TarWriter(f"pipe:cat > {tmpdir}/writer3.tar") as sink:
        sink.write(dict(__key__="a", txt="hello", cls="3"))
    os.system(f"ls -l {tmpdir}")
    readback = wds.Dataset(f"pipe:cat {tmpdir}/writer3.tar")
    for sample in readback:
        assert set(sample.keys()) == set("__key__ txt cls".split())
        break
def test_decoder():
    """decode() accepts a custom (key, sample) callable applied to each field."""

    def mydecoder(key, sample):
        return len(sample)

    pipeline = (
        wds.Dataset(remote_loc + remote_shard)
        .decode(mydecoder)
        .to_tuple("jpg;png", "json")
    )
    for sample in pipeline:
        # The custom decoder replaced the image bytes with their length.
        assert isinstance(sample[0], int)
        break
def test_handlers():
    """decode() accepts (extension, callable) pairs targeting specific fields."""

    def mydecoder(data):
        return PIL.Image.open(io.BytesIO(data)).resize((128, 128))

    pipeline = (
        wds.Dataset(remote_loc + remote_shard)
        .decode(("jpg", mydecoder))
        .to_tuple("jpg;png", "json")
    )
    for sample in pipeline:
        assert isinstance(sample[0], PIL.Image.Image)
        break
def test_tenbin_dec():
    """The default decoder handles ".ten" tensor-binary fields."""
    pipeline = wds.Dataset("webdataset_testdata/tendata.tar").decode().to_tuple("ten")
    assert count_samples_tuple(pipeline) == 100
    for sample in pipeline:
        # Each ".ten" field unpacks into a pair of float64 28x28 arrays.
        xs, ys = sample[0]
        assert xs.dtype == np.float64
        assert ys.dtype == np.float64
        assert xs.shape == (28, 28)
        assert ys.shape == (28, 28)
def test_writer(tmpdir):
    """TarWriter writes an uncompressed tar that reads back with the same keys."""
    with writer.TarWriter(f"{tmpdir}/writer.tar") as sink:
        sink.write(dict(__key__="a", txt="hello", cls="3"))
    os.system(f"ls -l {tmpdir}")
    ftype = os.popen(f"file {tmpdir}/writer.tar").read()
    assert "compress" not in ftype, ftype
    readback = wds.Dataset(f"{tmpdir}/writer.tar")
    for sample in readback:
        assert set(sample.keys()) == set("__key__ txt cls".split())
        break
def test_dataset_decode_nohandler():
    """Without an error handler, a decoder exception propagates out of iteration."""
    count = [0]

    def faulty_decoder(key, data):
        # BUG FIX: the increment used to sit after the if/else, where both
        # branches exit (raise/return), making it dead code — the counter
        # never advanced, so the decoder failed on every sample rather than
        # every other one (compare the working counter in
        # test_dataset_decode_handler). Incrementing first restores the
        # intended alternation; a ValueError is still raised (on the second
        # sample), which is all this test requires.
        count[0] += 1
        if count[0] % 2 == 0:
            raise ValueError("nothing")
        else:
            return data

    with pytest.raises(ValueError):
        ds = wds.Dataset(local_data).decode(faulty_decoder)
        count_samples_tuple(ds)
def test_opener():
    """Dataset(open_fn=...) lets the caller supply a custom URL opener."""

    def opener(url):
        print(url, file=sys.stderr)
        cmd = "curl -s '{}{}'".format(remote_loc, remote_pattern.format(url))
        print(cmd, file=sys.stderr)
        proc = subprocess.Popen(cmd, bufsize=1000000, shell=True, stdout=subprocess.PIPE)
        return proc.stdout

    pipeline = (
        wds.Dataset("{000000..000099}", open_fn=opener)
        .shuffle(100)
        .to_tuple("jpg;png", "json")
    )
    assert count_samples_tuple(pipeline, n=10) == 10
def test_writer3(tmpdir):
    """Pickled (.pth/.pyd) fields round-trip through TarWriter and decode()."""
    with writer.TarWriter(f"{tmpdir}/writer3.tar") as sink:
        sink.write(dict(__key__="a", pth=["abc"], pyd=dict(x=0)))
    os.system(f"ls -l {tmpdir}")
    os.system(f"tar tvf {tmpdir}/writer3.tar")
    ftype = os.popen(f"file {tmpdir}/writer3.tar").read()
    assert "compress" not in ftype, ftype
    readback = wds.Dataset(f"{tmpdir}/writer3.tar").decode()
    for sample in readback:
        assert set(sample.keys()) == set("__key__ pth pyd".split())
        assert isinstance(sample["pyd"], dict)
        assert sample["pyd"] == dict(x=0)
        assert isinstance(sample["pth"], list)
        assert sample["pth"] == ["abc"]
def test_unbatched():
    """batched(7).unbatched() round-trips back to individual samples."""
    from torchvision import transforms

    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
    )
    preproc = transforms.Compose(
        [
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]
    )
    pipeline = (
        wds.Dataset(remote_loc + remote_shards)
        .decode("pil")
        .to_tuple("jpg;png", "json")
        .map_tuple(preproc, identity)
        .batched(7)
        .unbatched()
    )
    for sample in pipeline:
        assert isinstance(sample[0], torch.Tensor), type(sample[0])
        assert tuple(sample[0].size()) == (3, 224, 224), sample[0].size()
        assert isinstance(sample[1], list), type(sample[1])
        break
    # The pipeline must remain picklable for DataLoader workers.
    pickle.dumps(pipeline)
def test_dataset_decode_handler():
    """With ignore_and_continue, samples whose decoder raises are skipped."""
    count = [0]
    good = [0]

    def faulty_decoder(key, data):
        if "png" not in key:
            return data
        count[0] += 1
        # Every second png fails to decode; the rest pass through.
        if count[0] % 2 == 0:
            raise ValueError("nothing")
        else:
            good[0] += 1
            return data

    pipeline = wds.Dataset(local_data).decode(
        faulty_decoder, handler=wds.ignore_and_continue
    )
    result = count_samples_tuple(pipeline)
    assert count[0] == 47
    assert good[0] == 24
    assert result == 24
def test_writer4(tmpdir):
    """Tensor (.ten) and tensor-list (.tb) fields round-trip through TarWriter."""
    with writer.TarWriter(f"{tmpdir}/writer4.tar") as sink:
        sink.write(
            dict(__key__="a", ten=np.zeros((3, 3)), tb=[np.ones(1), np.ones(2)])
        )
    os.system(f"ls -l {tmpdir}")
    os.system(f"tar tvf {tmpdir}/writer4.tar")
    ftype = os.popen(f"file {tmpdir}/writer4.tar").read()
    assert "compress" not in ftype, ftype
    readback = wds.Dataset(f"{tmpdir}/writer4.tar").decode()
    for sample in readback:
        assert set(sample.keys()) == set("__key__ tb ten".split())
        # A single tensor decodes as a one-element list of arrays.
        assert isinstance(sample["ten"], list)
        assert isinstance(sample["ten"][0], np.ndarray)
        assert sample["ten"][0].shape == (3, 3)
        assert isinstance(sample["tb"], list)
        assert len(sample["tb"]) == 2
        assert len(sample["tb"][0]) == 1
        assert len(sample["tb"][1]) == 2
        assert sample["tb"][0][0] == 1.0
def test_multi(self):
    """MultiDataset with 4 workers yields every sample of every shard exactly once."""
    for num_shards in [1, 4, 17]:
        # Distinct trailing comments make each URL unique while reading the
        # same underlying data.
        urls = [f"pipe:cat {local_data} # {i}" for i in range(num_shards)]
        base = wds.Dataset(urls).decode().shuffle(5).to_tuple("png;jpg cls")
        parallel = multi.MultiDataset(base, workers=4)
        assert count_samples_tuple(parallel) == 47 * num_shards
def test_dataloader():
    """A wds.Dataset works as a source for torch DataLoader with workers."""
    import torch

    source = wds.Dataset(remote_loc + remote_shards)
    loader = torch.utils.data.DataLoader(source, num_workers=4)
    assert count_samples_tuple(loader, n=100) == 100
def test_shard_syntax():
    """Brace shard-range URLs expand and stream correctly."""
    pipeline = (
        wds.Dataset(remote_loc + remote_shards)
        .decode()
        .to_tuple("jpg;png", "json")
        .shuffle(0)
    )
    assert count_samples_tuple(pipeline, n=10) == 10
def test_pipe():
    """A "pipe:" URL streams shards through an external curl process."""
    pipeline = (
        wds.Dataset(f"pipe:curl -s '{remote_loc}{remote_shards}'")
        .shuffle(100)
        .to_tuple("jpg;png", "json")
    )
    assert count_samples_tuple(pipeline, n=10) == 10
def test_dataset_nogrouping():
    """An empty initial_pipeline disables key-grouping: one item per tar member."""
    ungrouped = wds.Dataset(local_data, initial_pipeline=[])
    # 188 raw tar members vs 47 grouped samples elsewhere in this suite.
    assert count_samples_tuple(ungrouped) == 188
def test_dataset_missing_totuple_raises():
    """to_tuple() with fields absent from the samples raises ValueError."""
    with pytest.raises(ValueError):
        bad = wds.Dataset(local_data).to_tuple("foo", "bar")
        count_samples_tuple(bad)
def test_dataset_eof_handler():
    """ignore_and_stop turns a truncated-tar error into early, clean termination."""
    truncated = wds.Dataset(
        f"pipe:dd if={local_data} bs=1024 count=10", handler=wds.ignore_and_stop
    )
    # Fewer than the full 47 samples arrive before the cut-off.
    assert count_samples(truncated) < 47
def test_dataset():
    """The local test tar contains exactly 47 grouped samples."""
    basic = wds.Dataset(local_data)
    assert count_samples_tuple(basic) == 47
def test_dataset_pipe_cat():
    """Reading the archive through "pipe:cat" yields the same 47 samples."""
    piped = wds.Dataset(f"pipe:cat {local_data}").shuffle(5).to_tuple("png;jpg cls")
    assert count_samples_tuple(piped) == 47
def test_dataset_shuffle_extract():
    """shuffle() plus to_tuple() preserves the total sample count."""
    shuffled = wds.Dataset(local_data).shuffle(5).to_tuple("png;jpg cls")
    assert count_samples_tuple(shuffled) == 47
def test_dataset_missing_rename_raises(self):
    """rename() from fields absent in the samples raises ValueError."""
    with self.assertRaises(ValueError):
        bad = wds.Dataset(local_data).rename(x="foo", y="bar")
        count_samples_tuple(bad)