コード例 #1
0
    def test_dataset_rename_handler(self):
        """rename() works for present keys and raises ValueError for an absent one."""
        # "png;jpg" lists alternative source extensions (presumably the first
        # one present in a sample is used).
        renamed = wds.Dataset(local_data).rename(image="png;jpg", cls="cls")
        count_samples_tuple(renamed)

        # Iteration stays inside the block: the error may only be raised once
        # samples are actually pulled through the pipeline.
        with self.assertRaises(ValueError):
            broken = wds.Dataset(local_data).rename(image="missing", cls="cls")
            count_samples_tuple(broken)
コード例 #2
0
def test_float_np_vs_torch():
    """Check that the "rgb" and "torchrgb" decoders agree on the first sample.

    The tensor image is compared after permuting its axes with (1, 2, 0),
    i.e. it is expected to be channels-first while the NumPy array is
    channels-last.
    """

    def first_sample(decoder):
        pipeline = wds.Dataset(local_data).decode(decoder).to_tuple("png;jpg", "cls")
        return next(iter(pipeline))

    np_image, np_cls = first_sample("rgb")
    torch_image, torch_cls = first_sample("torchrgb")
    permuted = torch_image.permute(1, 2, 0).numpy()
    assert (np_image == permuted).all(), (np_image.shape, torch_image.shape)
    assert np_cls == torch_cls
コード例 #3
0
    def test_dataset_map_handler(self):
        """map() hands dict samples to the mapper and propagates its exceptions."""

        def passthrough(sample):
            # Every sample reaching the mapper must be a dict.
            assert isinstance(sample, dict)
            return sample

        def always_fails(sample):
            raise ValueError()

        count_samples_tuple(wds.Dataset(local_data).map(passthrough))

        # The mapper's ValueError must surface to the caller while iterating.
        with self.assertRaises(ValueError):
            failing = wds.Dataset(local_data).map(always_fails)
            count_samples_tuple(failing)
コード例 #4
0
    def test_dataset_map_dict_handler(self):
        """map_dict() applies per-key functions; bad keys and failing functions raise."""
        valid = wds.Dataset(local_data).map_dict(png=identity, cls=identity)
        count_samples_tuple(valid)

        # A key that the samples do not have ("cls2") is reported as KeyError.
        with self.assertRaises(KeyError):
            unknown_key = wds.Dataset(local_data).map_dict(png=identity, cls2=identity)
            count_samples_tuple(unknown_key)

        def explode(value):
            raise ValueError()

        # An exception raised by a per-key function surfaces as ValueError.
        with self.assertRaises(ValueError):
            failing = wds.Dataset(local_data).map_dict(png=explode, cls=identity)
            count_samples_tuple(failing)
コード例 #5
0
def test_rgb8_np_vs_torch():
    """Check that the "rgb8" and "torchrgb8" decoders agree on the first sample.

    Any warning emitted while decoding is promoted to an error.  The filter
    is scoped with ``warnings.catch_warnings()`` so it does not leak into
    other tests running in the same session.
    """
    import warnings

    with warnings.catch_warnings():
        warnings.simplefilter("error")
        ds = wds.Dataset(local_data).decode("rgb8").to_tuple("png;jpg", "cls")
        image, cls = next(iter(ds))
        assert isinstance(image, np.ndarray), type(image)
        assert isinstance(cls, int), type(cls)
        ds = wds.Dataset(local_data).decode("torchrgb8").to_tuple("png;jpg", "cls")
        image2, cls2 = next(iter(ds))
        assert isinstance(image2, torch.Tensor), type(image2)
        # Check the label from the torch pipeline, not the NumPy one again.
        assert isinstance(cls2, int), type(cls2)
        # `.all()` must be *called*: a bound method object is always truthy,
        # so `.all` without parentheses would make this assertion vacuous.
        assert (image == image2.permute(1, 2, 0).numpy()).all(), (image.shape,
                                                                  image2.shape)
        assert cls == cls2
コード例 #6
0
def test_dataset_shuffle_decode_rename_extract():
    """Shuffle, decode to "rgb", rename keys, then extract (image, cls) tuples."""
    ds = (
        wds.Dataset(local_data)
        .shuffle(5)
        .decode("rgb")
        .rename(image="png;jpg", cls="cls")
        .to_tuple("image", "cls")
    )
    assert count_samples_tuple(ds) == 47
    image, cls = next(iter(ds))
    # Report the type rather than the whole pixel array on failure, matching
    # the other decoder tests and keeping the failure output readable.
    assert isinstance(image, np.ndarray), type(image)
    assert isinstance(cls, int), type(cls)
コード例 #7
0
def test_dataset_eof():
    """A stream cut off after 10 KiB raises tarfile.ReadError when no handler is given."""
    import tarfile

    # dd copies only 10 blocks of 1024 bytes, so the tar stream ends mid-archive.
    truncated_url = f"pipe:dd if={local_data} bs=1024 count=10"
    with pytest.raises(tarfile.ReadError):
        ds = wds.Dataset(truncated_url).shuffle(5)
        assert count_samples(ds) == 47
コード例 #8
0
def test_rgb8():
    """The "rgb8" decoder yields uint8 NumPy images and int labels."""
    decoded = wds.Dataset(local_data).decode("rgb8")
    ds = decoded.to_tuple("png;jpg", "cls")
    total = count_samples_tuple(ds)
    assert total == 47
    first_image, first_cls = next(iter(ds))
    assert isinstance(first_image, np.ndarray), type(first_image)
    assert first_image.dtype == np.uint8, first_image.dtype
    assert isinstance(first_cls, int), type(first_cls)
コード例 #9
0
def test_writer_pipe(tmpdir):
    """Write one sample through a ``pipe:`` URL and read it back the same way."""
    with writer.TarWriter(f"pipe:cat > {tmpdir}/writer3.tar") as sink:
        sink.write(dict(__key__="a", txt="hello", cls="3"))
    os.system(f"ls -l {tmpdir}")
    ds = wds.Dataset(f"pipe:cat {tmpdir}/writer3.tar")
    samples = list(ds)
    # Exactly one sample was written; asserting the count makes an empty
    # read fail instead of silently skipping the key check.
    assert len(samples) == 1, len(samples)
    assert set(samples[0].keys()) == {"__key__", "txt", "cls"}
コード例 #10
0
def test_decoder():
    """A plain ``(key, data)`` callable passed to decode() is applied to fields."""

    def mydecoder(key, sample):
        # Replace each field's payload with its length, so decoded fields are ints.
        return len(sample)

    ds = (wds.Dataset(remote_loc + remote_shard).decode(mydecoder).to_tuple(
        "jpg;png", "json"))
    # Take the first sample explicitly so an empty stream fails the test
    # rather than skipping the assertion.
    first = next(iter(ds), None)
    assert first is not None, "dataset yielded no samples"
    assert isinstance(first[0], int)
コード例 #11
0
def test_handlers():
    """An ``("jpg", function)`` pair passed to decode() decodes "jpg" fields."""

    def mydecoder(data):
        # Open the raw payload with PIL and resize it to 128x128.
        return PIL.Image.open(io.BytesIO(data)).resize((128, 128))

    ds = (wds.Dataset(remote_loc + remote_shard).decode(
        ("jpg", mydecoder)).to_tuple("jpg;png", "json"))

    # Take the first sample explicitly so an empty stream fails the test
    # rather than skipping the assertion.
    first = next(iter(ds), None)
    assert first is not None, "dataset yielded no samples"
    assert isinstance(first[0], PIL.Image.Image)
コード例 #12
0
def test_tenbin_dec():
    """Default decode() turns each "ten" field into two float64 (28, 28) arrays."""
    shard = "webdataset_testdata/tendata.tar"
    ds = wds.Dataset(shard).decode().to_tuple("ten")
    assert count_samples_tuple(ds) == 100
    for record in ds:
        first, second = record[0]
        assert first.dtype == np.float64
        assert second.dtype == np.float64
        assert first.shape == (28, 28)
        assert second.shape == (28, 28)
コード例 #13
0
def test_writer(tmpdir):
    """Write one sample to a plain ``.tar`` file and read it back."""
    with writer.TarWriter(f"{tmpdir}/writer.tar") as sink:
        sink.write(dict(__key__="a", txt="hello", cls="3"))
    os.system(f"ls -l {tmpdir}")
    # Use the popen handle as a context manager so the pipe is closed.
    with os.popen(f"file {tmpdir}/writer.tar") as proc:
        ftype = proc.read()
    # A plain ".tar" target must not come out compressed.
    assert "compress" not in ftype, ftype

    samples = list(wds.Dataset(f"{tmpdir}/writer.tar"))
    # Exactly one sample was written; an empty read must fail the test.
    assert len(samples) == 1, len(samples)
    assert set(samples[0].keys()) == {"__key__", "txt", "cls"}
コード例 #14
0
def test_dataset_decode_nohandler():
    """Without a handler, an exception raised by the decoder propagates."""
    count = [0]

    def faulty_decoder(key, data):
        # Fail on every other call, starting with the first one.  The counter
        # is advanced *before* branching; an increment placed after the
        # raise/return would be unreachable and the decoder would never
        # alternate.
        calls = count[0]
        count[0] += 1
        if calls % 2 == 0:
            raise ValueError("nothing")
        return data

    with pytest.raises(ValueError):
        ds = wds.Dataset(local_data).decode(faulty_decoder)
        count_samples_tuple(ds)
コード例 #15
0
def test_opener():
    """A custom ``open_fn`` turns shard names into curl output streams."""

    def opener(url):
        print(url, file=sys.stderr)
        shard = remote_pattern.format(url)
        cmd = f"curl -s '{remote_loc}{shard}'"
        print(cmd, file=sys.stderr)
        proc = subprocess.Popen(cmd,
                                bufsize=1000000,
                                shell=True,
                                stdout=subprocess.PIPE)
        return proc.stdout

    ds = wds.Dataset("{000000..000099}", open_fn=opener)
    ds = ds.shuffle(100).to_tuple("jpg;png", "json")
    assert count_samples_tuple(ds, n=10) == 10
コード例 #16
0
def test_writer3(tmpdir):
    """Round-trip "pth" (a list) and "pyd" (a dict) fields through a tar file."""
    with writer.TarWriter(f"{tmpdir}/writer3.tar") as sink:
        sink.write(dict(__key__="a", pth=["abc"], pyd=dict(x=0)))
    os.system(f"ls -l {tmpdir}")
    os.system(f"tar tvf {tmpdir}/writer3.tar")
    # Use the popen handle as a context manager so the pipe is closed.
    with os.popen(f"file {tmpdir}/writer3.tar") as proc:
        ftype = proc.read()
    assert "compress" not in ftype, ftype

    samples = list(wds.Dataset(f"{tmpdir}/writer3.tar").decode())
    # Exactly one sample was written; an empty read must fail the test
    # instead of skipping every assertion below.
    assert len(samples) == 1, len(samples)
    sample = samples[0]
    assert set(sample.keys()) == {"__key__", "pth", "pyd"}
    assert isinstance(sample["pyd"], dict)
    assert sample["pyd"] == dict(x=0)
    assert isinstance(sample["pth"], list)
    assert sample["pth"] == ["abc"]
コード例 #17
0
def test_unbatched():
    """batched(7).unbatched() yields single (image tensor, json) samples again.

    Also checks that the assembled pipeline can be pickled.
    """
    from torchvision import transforms

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    preproc = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    ds = (wds.Dataset(remote_loc + remote_shards).decode("pil").to_tuple(
        "jpg;png", "json").map_tuple(preproc, identity).batched(7).unbatched())
    # Take the first sample explicitly so an empty stream fails the test
    # rather than skipping every assertion.
    sample = next(iter(ds), None)
    assert sample is not None, "dataset yielded no samples"
    assert isinstance(sample[0], torch.Tensor), type(sample[0])
    # Per-sample shape: no leading batch dimension after unbatching.
    assert tuple(sample[0].size()) == (3, 224, 224), sample[0].size()
    assert isinstance(sample[1], list), type(sample[1])
    pickle.dumps(ds)
コード例 #18
0
def test_dataset_decode_handler():
    """With ``ignore_and_continue``, samples whose decoder call raises are skipped."""
    png_calls = 0
    png_ok = 0

    def faulty_decoder(key, data):
        nonlocal png_calls, png_ok
        # Non-png fields always pass through untouched.
        if "png" in key:
            png_calls += 1
            # Every second png field (2nd, 4th, ...) fails.
            if png_calls % 2 != 0:
                png_ok += 1
                return data
            raise ValueError("nothing")
        return data

    ds = wds.Dataset(local_data).decode(faulty_decoder,
                                        handler=wds.ignore_and_continue)
    result = count_samples_tuple(ds)
    assert png_calls == 47
    assert png_ok == 24
    assert result == 24
コード例 #19
0
def test_writer4(tmpdir):
    """Round-trip NumPy arrays stored under "ten" and "tb" through a tar file."""
    with writer.TarWriter(f"{tmpdir}/writer4.tar") as sink:
        sink.write(
            dict(__key__="a",
                 ten=np.zeros((3, 3)),
                 tb=[np.ones(1), np.ones(2)]))
    os.system(f"ls -l {tmpdir}")
    os.system(f"tar tvf {tmpdir}/writer4.tar")
    # Use the popen handle as a context manager so the pipe is closed.
    with os.popen(f"file {tmpdir}/writer4.tar") as proc:
        ftype = proc.read()
    assert "compress" not in ftype, ftype

    samples = list(wds.Dataset(f"{tmpdir}/writer4.tar").decode())
    # Exactly one sample was written; an empty read must fail the test
    # instead of skipping every assertion below.
    assert len(samples) == 1, len(samples)
    sample = samples[0]
    assert set(sample.keys()) == {"__key__", "tb", "ten"}
    # The single array written under "ten" comes back as a list whose first
    # element is the (3, 3) array.
    assert isinstance(sample["ten"], list)
    assert isinstance(sample["ten"][0], np.ndarray)
    assert sample["ten"][0].shape == (3, 3)
    assert isinstance(sample["tb"], list)
    assert len(sample["tb"]) == 2
    assert len(sample["tb"][0]) == 1
    assert len(sample["tb"][1]) == 2
    assert sample["tb"][0][0] == 1.0
コード例 #20
0
 def test_multi(self):
     """MultiDataset over k copies of the local shard yields 47 * k samples."""
     for copies in (1, 4, 17):
         # "# i" is a shell comment: each URL reads the same file but the
         # suffix (presumably) keeps the URLs distinct.
         urls = [f"pipe:cat {local_data} # {i}" for i in range(copies)]
         source = wds.Dataset(urls).decode().shuffle(5).to_tuple("png;jpg cls")
         parallel = multi.MultiDataset(source, workers=4)
         assert count_samples_tuple(parallel) == 47 * copies
コード例 #21
0
def test_dataloader():
    """The dataset can be consumed through a 4-worker torch DataLoader."""
    import torch

    loader = torch.utils.data.DataLoader(
        wds.Dataset(remote_loc + remote_shards), num_workers=4
    )
    assert count_samples_tuple(loader, n=100) == 100
コード例 #22
0
def test_shard_syntax():
    """Read the multi-shard URL ``remote_loc + remote_shards`` into tuples.

    ``shuffle(0)`` is applied last (a zero-size buffer, presumably a no-op).
    """
    url = remote_loc + remote_shards
    ds = wds.Dataset(url).decode().to_tuple("jpg;png", "json").shuffle(0)
    assert count_samples_tuple(ds, n=10) == 10
コード例 #23
0
def test_pipe():
    """Stream remote shards through a ``pipe:curl`` URL."""
    url = f"pipe:curl -s '{remote_loc}{remote_shards}'"
    ds = wds.Dataset(url).shuffle(100).to_tuple("jpg;png", "json")
    assert count_samples_tuple(ds, n=10) == 10
コード例 #24
0
def test_dataset_nogrouping():
    """With an empty ``initial_pipeline`` the raw entries are counted.

    188 is 4 * 47, presumably one item per tar member rather than one per
    grouped sample; confirm against the shard layout.
    """
    ungrouped = wds.Dataset(local_data, initial_pipeline=[])
    assert count_samples_tuple(ungrouped) == 188
コード例 #25
0
def test_dataset_missing_totuple_raises():
    """to_tuple() with keys that no sample contains raises ValueError."""
    with pytest.raises(ValueError):
        count_samples_tuple(wds.Dataset(local_data).to_tuple("foo", "bar"))
コード例 #26
0
def test_dataset_eof_handler():
    """With ``ignore_and_stop``, a truncated stream yields fewer than 47 samples."""
    truncated_url = f"pipe:dd if={local_data} bs=1024 count=10"
    ds = wds.Dataset(truncated_url, handler=wds.ignore_and_stop)
    seen = count_samples(ds)
    assert seen < 47
コード例 #27
0
def test_dataset():
    """The local test shard yields 47 samples."""
    expected = 47
    assert count_samples_tuple(wds.Dataset(local_data)) == expected
コード例 #28
0
def test_dataset_pipe_cat():
    """Reading the local shard via ``pipe:cat`` yields 47 samples."""
    source = f"pipe:cat {local_data}"
    ds = wds.Dataset(source).shuffle(5).to_tuple("png;jpg cls")
    assert count_samples_tuple(ds) == 47
コード例 #29
0
def test_dataset_shuffle_extract():
    """Shuffling and tuple extraction keep all 47 samples."""
    shuffled = wds.Dataset(local_data).shuffle(5)
    tuples = shuffled.to_tuple("png;jpg cls")
    assert count_samples_tuple(tuples) == 47
コード例 #30
0
 def test_dataset_missing_rename_raises(self):
     """rename() from keys that no sample contains raises ValueError."""
     with self.assertRaises(ValueError):
         count_samples_tuple(wds.Dataset(local_data).rename(x="foo", y="bar"))