コード例 #1
0
def test_rgb8_np_vs_torch():
    """Check that "rgb8" and "torchrgb8" decoding agree on pixels and label.

    The "rgb8" pipeline must yield a numpy array and the "torchrgb8" pipeline
    a torch tensor. After moving the tensor's first axis last with
    ``permute(1, 2, 0)``, both images must match element-for-element, and the
    labels from the two pipelines must be equal ints.
    """
    ds = wds.Dataset(local_data).decode("rgb8").to_tuple("png;jpg", "cls")
    image, cls = next(iter(ds))
    assert isinstance(image, np.ndarray), type(image)
    assert isinstance(cls, int), type(cls)
    ds = wds.Dataset(local_data).decode("torchrgb8").to_tuple("png;jpg", "cls")
    image2, cls2 = next(iter(ds))
    assert isinstance(image2, torch.Tensor), type(image2)
    # Validate the label produced by the torch pipeline as well.
    assert isinstance(cls2, int), type(cls2)
    # Reorder the tensor's axes so its layout matches the numpy image.
    image2_np = image2.permute(1, 2, 0).numpy()
    # Check shapes first: an elementwise == between mismatched shapes is not a
    # meaningful per-pixel comparison.
    assert image.shape == image2_np.shape, (image.shape, image2.shape)
    # `.all()` must be called; a bare bound method is always truthy and would
    # make this assertion impossible to fail.
    assert (image == image2_np).all(), (image.shape, image2.shape)
    assert cls == cls2
コード例 #2
0
def test_dataset_shuffle_decode_rename_extract():
    """Check a shuffle -> decode("rgb") -> rename -> to_tuple pipeline.

    The pipeline must yield 47 samples from ``local_data``. The renamed
    "image" field must decode to a numpy array and "cls" to an int.
    """
    ds = (wds.Dataset(local_data)
          .shuffle(5)
          .decode("rgb")
          .rename(image="png;jpg", cls="cls")
          .to_tuple("image", "cls"))
    assert count_samples_tuple(ds) == 47
    image, cls = next(iter(ds))
    # Report the type rather than the whole array so failure output stays short.
    assert isinstance(image, np.ndarray), type(image)
    assert isinstance(cls, int), type(cls)
コード例 #3
0
def test_decoder():
    """Check that decode() accepts a custom function applied to whole samples.

    ``mydecoder`` replaces every field of a sample with its length, so the
    first element of the extracted ("jpg;png", "cls") tuple must be an int.
    """
    def mydecoder(sample):
        return {k: len(v) for k, v in sample.items()}

    ds = (wds.Dataset(remote_loc + "imagenet_train-0050.tgz")
          .decode(mydecoder)
          .to_tuple("jpg;png", "cls"))
    # Take the first sample explicitly: a `for ...: break` loop would let the
    # test pass vacuously if the shard produced no samples at all.
    sample = next(iter(ds), None)
    assert sample is not None, "dataset yielded no samples"
    assert isinstance(sample[0], int), type(sample[0])
コード例 #4
0
def test_opener():
    """Check that Dataset can read shards through a user-supplied opener.

    The opener logs its argument to stderr, fetches the matching
    ``imagenet_train-<url>.tgz`` shard with ``curl`` and returns the
    subprocess's stdout pipe as the byte stream. Presumably each ``url`` is
    one id expanded from the "{0000..0147}" pattern — confirm against wds.
    """
    def opener(url):
        print(url, file=sys.stderr)
        shard_cmd = "curl -s '{}imagenet_train-{}.tgz'".format(remote_loc, url)
        proc = subprocess.Popen(
            shard_cmd,
            shell=True,
            stdout=subprocess.PIPE,
            bufsize=1000000,
        )
        return proc.stdout

    shuffled = wds.Dataset("{0000..0147}", opener=opener).shuffle(100)
    ds = shuffled.to_tuple("jpg;png cls")
    assert count_samples_tuple(ds, n=10) == 10
コード例 #5
0
def test_handlers():
    """Check that decode() accepts a customized handler dictionary.

    Starts from a copy of the default "rgb" handlers and overrides the "jpg"
    entry with a decoder that returns a 128x128 resized PIL image. The first
    extracted image must therefore be a ``PIL.Image.Image``.
    """
    # Copy so the library's shared default handler table is not mutated.
    handlers = dict(autodecode.default_handlers["rgb"])

    def decode_jpg_and_resize(data):
        return PIL.Image.open(io.BytesIO(data)).resize((128, 128))

    handlers["jpg"] = decode_jpg_and_resize

    ds = (wds.Dataset(remote_loc + "imagenet_train-0050.tgz")
          .decode(handlers)
          .to_tuple("jpg;png", "cls"))

    # Take the first sample explicitly: a `for ...: break` loop would let the
    # test pass vacuously if the shard produced no samples at all.
    sample = next(iter(ds), None)
    assert sample is not None, "dataset yielded no samples"
    assert isinstance(sample[0], PIL.Image.Image), type(sample[0])
コード例 #6
0
def test_dataset_decode_handler():
    """Check that ignore_and_continue drops samples whose decoding raises.

    ``faulty_decoder`` passes non-"png" fields through untouched. For "png"
    fields it raises on every second call and passes the data through
    otherwise. With 47 png decodes, 24 succeed, and only those 24 samples
    may reach the consumer.
    """
    calls = 0
    survivors = 0

    def faulty_decoder(key, data):
        nonlocal calls, survivors
        if "png" not in key:
            return data
        calls += 1
        if calls % 2 == 0:
            raise ValueError("nothing")
        survivors += 1
        return data

    ds = wds.Dataset(local_data).decode(faulty_decoder,
                                        handler=wds.ignore_and_continue)
    result = count_samples_tuple(ds)
    assert calls == 47
    assert survivors == 24
    assert result == 24
コード例 #7
0
def test_torchvision():
    """Check that a torchvision preprocessing pipeline composes with map_tuple.

    Images are decoded to PIL, randomly cropped and resized to 224, randomly
    flipped, converted to a tensor and normalized. Labels are shifted by -1
    (presumably 1-based to 0-based class ids — confirm against the dataset).
    The first sample must be a (3, 224, 224) tensor with an int label.
    """
    import torch
    from torchvision import transforms

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    preproc = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    ds = (wds.Dataset(remote_loc + "imagenet_train-{0000..0147}.tgz")
          .decode("pil")
          .to_tuple("jpg;png", "cls")
          .map_tuple(preproc, lambda x: x - 1))
    # Take the first sample explicitly: a `for ...: break` loop would let the
    # test pass vacuously if the dataset produced no samples at all.
    sample = next(iter(ds), None)
    assert sample is not None, "dataset yielded no samples"
    assert isinstance(sample[0], torch.Tensor), type(sample[0])
    assert tuple(sample[0].size()) == (3, 224, 224), tuple(sample[0].size())
    assert isinstance(sample[1], int), type(sample[1])