예제 #1
0
def test_slice():
    """A 200-entry shard list followed by ``wds.slice(29)`` yields 29 items."""
    shard_names = [str(i) for i in range(200)]
    shards = wds.SimpleShardList(shard_names)
    pipeline = wds.DataPipeline(shards, wds.slice(29))

    items = list(iter(pipeline))

    assert len(items) == 29
예제 #2
0
def test_shuffle1():
    """A single-shard list passed through ``wds.shuffle(10)`` still yields one item."""
    shards = wds.SimpleShardList(["testdata/imagenet-000000.tgz"])
    pipeline = wds.DataPipeline(shards, wds.shuffle(10))

    items = list(iter(pipeline))

    assert len(items) == 1
예제 #3
0
def test_resampled():
    """Ten shards followed by ``wds.resampled(27)`` yield exactly 27 items."""
    shard_names = [str(i) for i in range(10)]
    shards = wds.SimpleShardList(shard_names)
    pipeline = wds.DataPipeline(shards, wds.resampled(27))

    items = list(iter(pipeline))

    assert len(items) == 27
예제 #4
0
def test_nonempty2():
    """``wds.non_empty`` is expected to raise ValueError when its input is empty."""

    def emit_nothing(src):
        # Ignore the upstream shard stream entirely and hand back an
        # empty iterator, so the following stage sees no samples.
        return iter([])

    shard_names = [str(i) for i in range(10)]
    pipeline = wds.DataPipeline(
        wds.SimpleShardList(shard_names),
        emit_nothing,
        wds.non_empty,
    )

    with pytest.raises(ValueError):
        list(iter(pipeline))
예제 #5
0
def test_splitting():
    """Ten shards pass through the node and worker splitters unchanged in count.

    The test expects all ten entries to come out and the first one to
    carry the url ``"0"``.
    """
    shard_names = [str(i) for i in range(10)]
    stages = (
        wds.SimpleShardList(shard_names),
        wds.split_by_node,
        wds.split_by_worker,
    )
    pipeline = wds.DataPipeline(*stages)

    items = list(iter(pipeline))

    assert len(items) == 10
    first = items[0]
    assert first["url"] == "0"
예제 #6
0
def test_seed():
    """Changing ``seed`` on the first stage changes which shard comes first.

    The pipeline is iterated twice: once as constructed (first url
    expected to be ``"0"``), and once after assigning 17 to the ``seed``
    attribute of stage 0 (first url expected to be ``"7"``). Both passes
    are expected to yield ten entries.
    """
    shard_names = [str(i) for i in range(10)]
    pipeline = wds.DataPipeline(
        wds.SimpleShardList(shard_names),
        wds.split_by_node,
        wds.split_by_worker,
    )

    # First pass: pipeline exactly as constructed.
    first_pass = list(iter(pipeline))
    assert len(first_pass) == 10
    assert first_pass[0]["url"] == "0"

    # Second pass: same pipeline object, with the shard-list stage reseeded.
    epoch = 17
    shard_stage = pipeline.stage(0)
    shard_stage.seed = epoch
    second_pass = list(iter(pipeline))
    assert len(second_pass) == 10
    assert second_pass[0]["url"] == "7"
예제 #7
0
def test_reader2():
    """Read the same tar shard ten times with shuffling, decoding and tupling.

    Each resulting sample is expected to be a 2-element sequence of
    (numpy array, int), and 470 samples are expected in total.
    """
    shard_urls = ["testdata/imagenet-000000.tgz" for _ in range(10)]
    pipeline = wds.DataPipeline(
        wds.SimpleShardList(shard_urls),
        wds.shuffle(3),
        wds.tarfile_samples,
        wds.shuffle(100),
        wds.decode(autodecode.ImageHandler("rgb")),
        wds.to_tuple("png", "cls"),
    )

    samples = list(iter(pipeline))

    # Inspect the first sample before checking the overall count, so an
    # empty result surfaces as an IndexError here.
    first = samples[0]
    assert len(first) == 2
    assert isinstance(first[0], np.ndarray)
    assert isinstance(first[1], int)
    assert len(samples) == 470
예제 #8
0
def test_reader1():
    """Read one tar shard, decode images as "rgb", and inspect the first sample.

    The first sample is expected to expose ``__key__``, ``cls`` and ``png``
    entries, with an int class, a numpy image of shape (793, 600, 3), and
    47 samples in the shard overall.
    """
    image_decoder = wds.decode(autodecode.ImageHandler("rgb"))
    pipeline = wds.DataPipeline(
        wds.SimpleShardList("testdata/imagenet-000000.tgz"),
        wds.tarfile_samples,
        image_decoder,
    )

    samples = list(iter(pipeline))

    first = samples[0]
    field_names = list(first.keys())
    for expected_name in ("__key__", "cls", "png"):
        assert expected_name in field_names

    assert isinstance(first["cls"], int)
    assert isinstance(first["png"], np.ndarray)
    assert first["png"].shape == (793, 600, 3)
    assert len(samples) == 47
예제 #9
0
def test_composable():
    """A brace-range shard spec ``{000000..000099}`` is expected to yield 100 entries."""
    shards = wds.SimpleShardList("test-{000000..000099}.tar")
    pipeline = wds.DataPipeline(shards)

    entries = list(iter(pipeline))

    assert len(entries) == 100