def test_rgb8_np_vs_torch():
    """Check that the "rgb8" and "torchrgb8" decoders agree on the first sample.

    The "rgb8" pipeline must yield a NumPy array and the "torchrgb8" pipeline a
    torch.Tensor. After moving the tensor's leading axis to the end with
    ``permute(1, 2, 0)``, both must have the same shape and identical pixel
    values. The class labels from both pipelines must be equal ints.
    """
    ds = wds.Dataset(local_data).decode("rgb8").to_tuple("png;jpg", "cls")
    image, cls = next(iter(ds))
    assert isinstance(image, np.ndarray), type(image)
    assert isinstance(cls, int), type(cls)

    ds = wds.Dataset(local_data).decode("torchrgb8").to_tuple("png;jpg", "cls")
    image2, cls2 = next(iter(ds))
    assert isinstance(image2, torch.Tensor), type(image2)
    # Check the label from the *second* pipeline (previously re-checked `cls`).
    assert isinstance(cls2, int), type(cls2)

    converted = image2.permute(1, 2, 0).numpy()
    # Compare shapes first so a mismatch fails with a clear message instead of
    # an elementwise-comparison broadcasting error or a scalar False.
    assert image.shape == converted.shape, (image.shape, image2.shape)
    # `.all` must be *called*: the bound method object is always truthy, so the
    # previous `.all` (without parentheses) made this assertion a no-op.
    assert (image == converted).all(), (image.shape, image2.shape)
    assert cls == cls2
def test_dataset_shuffle_decode_rename_extract():
    """Shuffle, decode to "rgb", rename fields, then extract (image, cls) tuples.

    Asserts that 47 tuples come out of ``local_data`` and that the first tuple
    holds a NumPy array image and an int class label.
    """
    pipeline = wds.Dataset(local_data)
    pipeline = pipeline.shuffle(5)
    pipeline = pipeline.decode("rgb")
    # Map whichever of png/jpg is present onto the field name "image".
    pipeline = pipeline.rename(image="png;jpg", cls="cls")
    ds = pipeline.to_tuple("image", "cls")

    assert count_samples_tuple(ds) == 47

    first_image, first_cls = next(iter(ds))
    assert isinstance(first_image, np.ndarray), first_image
    assert isinstance(first_cls, int), type(first_cls)
def test_decoder():
    """Use a custom whole-sample decoder and check its output reaches to_tuple.

    The decoder replaces every field value with its ``len``, so the first
    element of the first extracted tuple must be an int.
    """

    def length_decoder(sample):
        # Map each field name to the length of its raw value.
        lengths = {}
        for field, raw in sample.items():
            lengths[field] = len(raw)
        return lengths

    shard = remote_loc + "imagenet_train-0050.tgz"
    ds = wds.Dataset(shard).decode(length_decoder).to_tuple("jpg;png", "cls")

    # Only the first tuple is inspected; an empty dataset passes vacuously.
    first = next(iter(ds), None)
    if first is not None:
        assert isinstance(first[0], int)
def test_opener():
    """Read shards through a custom curl-based opener and take 10 tuples.

    The dataset is given a bare brace pattern "{0000..0147}". The custom
    ``opener`` turns each value it receives into a full shard URL and streams
    that URL via a ``curl`` subprocess.
    """

    def opener(url):
        # `url` is presumably one brace-expanded shard id such as "0000";
        # TODO confirm against wds.Dataset's opener contract.
        print(url, file=sys.stderr)
        command = "curl -s '{}imagenet_train-{}.tgz'".format(remote_loc, url)
        process = subprocess.Popen(
            command,
            bufsize=1000000,
            shell=True,
            stdout=subprocess.PIPE,
        )
        # The pipe is returned as the shard stream; the process is not waited on.
        return process.stdout

    # NOTE(review): "jpg;png cls" is a single space-separated spec rather than
    # two arguments; confirm wds splits it into two tuple fields.
    ds = wds.Dataset("{0000..0147}", opener=opener).shuffle(100).to_tuple("jpg;png cls")
    assert count_samples_tuple(ds, n=10) == 10
def test_handlers():
    """Override the "jpg" entry of the "rgb" default handlers with a PIL resizer.

    The first extracted tuple's image must be a PIL image, showing that the
    custom "jpg" handler was used instead of the default one.
    """

    def decode_jpg_and_resize(data):
        # Decode raw bytes with PIL and resize to 128x128.
        return PIL.Image.open(io.BytesIO(data)).resize((128, 128))

    # Copy the defaults so the shared handler table is not mutated.
    handlers = dict(autodecode.default_handlers["rgb"])
    handlers["jpg"] = decode_jpg_and_resize

    shard = remote_loc + "imagenet_train-0050.tgz"
    ds = wds.Dataset(shard).decode(handlers).to_tuple("jpg;png", "cls")

    first = next(iter(ds), None)
    if first is not None:
        assert isinstance(first[0], PIL.Image.Image)
def test_dataset_decode_handler():
    """Check that decode errors are skipped by ``wds.ignore_and_continue``.

    ``faulty_decoder`` passes non-"png" fields through unchanged. It raises
    ValueError on every second "png" field it sees and returns the data
    unchanged otherwise. With 47 png fields, 24 decodes succeed, and the
    failing samples are dropped, so 24 samples are counted.
    """
    png_calls = 0
    png_successes = 0

    def faulty_decoder(key, data):
        nonlocal png_calls, png_successes
        if "png" not in key:
            return data
        png_calls += 1
        if png_calls % 2 == 0:
            raise ValueError("nothing")
        png_successes += 1
        return data

    ds = wds.Dataset(local_data).decode(faulty_decoder, handler=wds.ignore_and_continue)
    result = count_samples_tuple(ds)

    assert png_calls == 47
    assert png_successes == 24
    assert result == 24
def test_torchvision():
    """Apply a torchvision augmentation pipeline to PIL-decoded images.

    The first tuple must contain a float tensor of shape (3, 224, 224) produced
    by the transform chain, plus an int label shifted down by one.
    """
    import torch
    from torchvision import transforms

    preproc = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        # Per-channel normalization with fixed mean/std values.
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    def shift_label(label):
        # Presumably converts 1-based class ids to 0-based; TODO confirm.
        return label - 1

    shards = remote_loc + "imagenet_train-{0000..0147}.tgz"
    ds = (
        wds.Dataset(shards)
        .decode("pil")
        .to_tuple("jpg;png", "cls")
        .map_tuple(preproc, shift_label)
    )

    first = next(iter(ds), None)
    if first is not None:
        image, label = first[0], first[1]
        assert isinstance(image, torch.Tensor)
        assert tuple(image.size()) == (3, 224, 224)
        assert isinstance(label, int)