def test_slice():
    """Check that a ``wds.slice(29)`` stage yields exactly 29 items.

    The source stage is a ``SimpleShardList`` over the 200 shard names
    "0" .. "199", so more items are available than the slice lets through.
    """
    shard_names = [str(index) for index in range(200)]
    pipeline = wds.DataPipeline(
        wds.SimpleShardList(shard_names),
        wds.slice(29),
    )
    items = list(iter(pipeline))
    assert len(items) == 29
def test_shuffle1():
    """Check that shuffling a one-entry shard list still yields one item.

    The shuffle stage is created with the argument 10 while the shard list
    holds a single entry.
    """
    shard_paths = ["testdata/imagenet-000000.tgz"]
    pipeline = wds.DataPipeline(
        wds.SimpleShardList(shard_paths),
        wds.shuffle(10),
    )
    items = list(iter(pipeline))
    assert len(items) == 1
def test_resampled():
    """Check that a ``wds.resampled(27)`` stage yields exactly 27 items.

    The shard list holds only the 10 names "0" .. "9", so the expected
    count exceeds the number of distinct shards.
    """
    shard_names = [str(index) for index in range(10)]
    pipeline = wds.DataPipeline(
        wds.SimpleShardList(shard_names),
        wds.resampled(27),
    )
    items = list(iter(pipeline))
    assert len(items) == 27
def test_nonempty2():
    """Check that ``wds.non_empty`` raises ``ValueError`` on an empty stream.

    A stage that ignores its input and returns an empty iterator is placed
    before ``wds.non_empty``; consuming the pipeline must then raise.
    """

    def drop_everything(src):
        # Discard the upstream source and produce no items at all.
        return iter([])

    shard_names = [str(index) for index in range(10)]
    pipeline = wds.DataPipeline(
        wds.SimpleShardList(shard_names),
        drop_everything,
        wds.non_empty,
    )
    with pytest.raises(ValueError):
        list(iter(pipeline))
def test_splitting():
    """Check a pipeline with ``split_by_node`` and ``split_by_worker`` stages.

    Asserts that all 10 shard entries come through and that the first one
    has url "0".
    NOTE(review): presumably this holds because the test runs as a single
    node / single worker -- confirm against the split implementations.
    """
    shard_names = [str(index) for index in range(10)]
    pipeline = wds.DataPipeline(
        wds.SimpleShardList(shard_names),
        wds.split_by_node,
        wds.split_by_worker,
    )
    shards = list(iter(pipeline))
    assert len(shards) == 10
    first_shard = shards[0]
    assert first_shard["url"] == "0"
def test_seed():
    """Check that setting ``seed`` on the first stage changes shard order.

    The pipeline is consumed twice: once as constructed (first url "0"),
    and once after assigning 17 to ``seed`` on stage 0 (first url "7").
    Both passes must yield all 10 shard entries.
    """
    shard_names = [str(index) for index in range(10)]
    pipeline = wds.DataPipeline(
        wds.SimpleShardList(shard_names),
        wds.split_by_node,
        wds.split_by_worker,
    )

    # First pass: no seed has been assigned yet.
    unseeded = list(iter(pipeline))
    assert len(unseeded) == 10
    assert unseeded[0]["url"] == "0"

    # Second pass: reseed the shard-list stage, then iterate again.
    epoch_seed = 17
    pipeline.stage(0).seed = epoch_seed
    seeded = list(iter(pipeline))
    assert len(seeded) == 10
    assert seeded[0]["url"] == "7"
def test_reader2():
    """Check reading, shuffling, decoding and tuple conversion end to end.

    The same tar file is listed 10 times; the pipeline shuffles shards,
    expands them with ``tarfile_samples``, shuffles samples, decodes images
    with the "rgb" handler and converts each sample to a ("png", "cls")
    tuple. The first tuple must be (ndarray, int) and the total sample
    count must be 470.
    """
    shard_paths = ["testdata/imagenet-000000.tgz"] * 10
    pipeline = wds.DataPipeline(
        wds.SimpleShardList(shard_paths),
        wds.shuffle(3),
        wds.tarfile_samples,
        wds.shuffle(100),
        wds.decode(autodecode.ImageHandler("rgb")),
        wds.to_tuple("png", "cls"),
    )
    samples = list(iter(pipeline))

    first_sample = samples[0]
    assert len(first_sample) == 2
    assert isinstance(first_sample[0], np.ndarray)
    assert isinstance(first_sample[1], int)
    assert len(samples) == 470
def test_reader1():
    """Check reading and decoding samples from a single tar file.

    The first decoded sample must expose the keys "__key__", "cls" and
    "png", with "cls" an int and "png" an ndarray of shape (793, 600, 3).
    The file must yield 47 samples in total.
    """
    pipeline = wds.DataPipeline(
        wds.SimpleShardList("testdata/imagenet-000000.tgz"),
        wds.tarfile_samples,
        wds.decode(autodecode.ImageHandler("rgb")),
    )
    samples = list(iter(pipeline))

    first_sample = samples[0]
    field_names = list(first_sample.keys())
    for expected_key in ("__key__", "cls", "png"):
        assert expected_key in field_names

    assert isinstance(first_sample["cls"], int)
    image = first_sample["png"]
    assert isinstance(image, np.ndarray)
    assert image.shape == (793, 600, 3)
    assert len(samples) == 47
def test_composable():
    """Check that a pipeline with only a shard-list stage yields 100 items.

    The shard list is built from the brace-range spec
    "test-{000000..000099}.tar".
    """
    shard_source = wds.SimpleShardList("test-{000000..000099}.tar")
    pipeline = wds.DataPipeline(
        shard_source,
    )
    shards = list(iter(pipeline))
    assert len(shards) == 100