Ejemplo n.º 1
0
def test_shuffle1():
    """A single-shard list followed by ``wds.shuffle(10)`` yields one item."""
    shard_urls = ["testdata/imagenet-000000.tgz"]
    pipeline = wds.DataPipeline(
        wds.SimpleShardList(shard_urls),
        wds.shuffle(10),
    )
    items = list(iter(pipeline))
    assert len(items) == 1
Ejemplo n.º 2
0
def test_trivial_map4():
    """Feeding 1..4 through the ``adder4`` stage is expected to give 5..8."""

    def numbers():
        # Zero-argument source stage: a fresh iterator on every call.
        return iter([1, 2, 3, 4])

    pipeline = wds.DataPipeline(numbers, adder4)
    outputs = list(iter(pipeline))
    assert outputs == [5, 6, 7, 8]
Ejemplo n.º 3
0
def test_shuffle0():
    """An empty source followed by ``wds.shuffle(10)`` yields nothing."""

    def empty_source():
        return iter([])

    pipeline = wds.DataPipeline(empty_source, wds.shuffle(10))
    items = list(iter(pipeline))
    assert len(items) == 0
Ejemplo n.º 4
0
def test_resampled():
    """``wds.resampled(27)`` over ten shard names is expected to yield 27 items."""
    shard_names = list(map(str, range(10)))
    stages = [
        wds.SimpleShardList(shard_names),
        wds.resampled(27),
    ]
    pipeline = wds.DataPipeline(*stages)
    drawn = list(iter(pipeline))
    assert len(drawn) == 27
Ejemplo n.º 5
0
def test_slice():
    """``wds.slice(29)`` over 200 shard names is expected to yield 29 items."""
    shard_names = list(map(str, range(200)))
    stages = [
        wds.SimpleShardList(shard_names),
        wds.slice(29),
    ]
    pipeline = wds.DataPipeline(*stages)
    kept = list(iter(pipeline))
    assert len(kept) == 29
Ejemplo n.º 6
0
def test_nonempty2():
    """``wds.non_empty`` must raise ``ValueError`` when upstream yields nothing."""

    def discard_all(src):
        # Ignore the incoming stream and produce an empty one instead.
        return iter([])

    shard_names = list(map(str, range(10)))
    pipeline = wds.DataPipeline(
        wds.SimpleShardList(shard_names),
        discard_all,
        wds.non_empty,
    )
    with pytest.raises(ValueError):
        list(iter(pipeline))
Ejemplo n.º 7
0
def test_splitting():
    """Node and worker splitting are expected to keep all ten shards here.

    The test asserts that every shard comes through and that the first one
    has url ``"0"``.
    NOTE(review): presumably this relies on running as a single node with
    no data-loader workers -- confirm.
    """
    shard_names = list(map(str, range(10)))
    stages = [
        wds.SimpleShardList(shard_names),
        wds.split_by_node,
        wds.split_by_worker,
    ]
    pipeline = wds.DataPipeline(*stages)
    shards = list(iter(pipeline))
    assert len(shards) == 10
    first_shard = shards[0]
    assert first_shard["url"] == "0"
Ejemplo n.º 8
0
def test_seed():
    """Setting ``seed`` on the first stage is expected to change shard order.

    The pipeline is drained twice: once untouched (first url ``"0"``) and
    once after assigning 17 to ``seed`` on stage 0 (first url ``"7"``).
    Both passes must yield all ten shards.
    """
    shard_names = list(map(str, range(10)))
    pipeline = wds.DataPipeline(
        wds.SimpleShardList(shard_names),
        wds.split_by_node,
        wds.split_by_worker,
    )

    # First pass: pipeline as constructed.
    before = list(iter(pipeline))
    assert len(before) == 10
    assert before[0]["url"] == "0"

    # Second pass: same pipeline after setting the seed on its first stage.
    epoch = 17
    shard_stage = pipeline.stage(0)
    shard_stage.seed = epoch
    after = list(iter(pipeline))
    assert len(after) == 10
    assert after[0]["url"] == "7"
Ejemplo n.º 9
0
def test_reader2():
    """Read the same tar shard ten times, shuffle, decode and tuple-ize.

    The test expects 470 samples, each a 2-tuple whose first element is a
    NumPy array (from the ``png`` field) and whose second is an ``int``
    (from the ``cls`` field).
    """
    shard_urls = ["testdata/imagenet-000000.tgz"] * 10
    stages = [
        wds.SimpleShardList(shard_urls),
        wds.shuffle(3),
        wds.tarfile_samples,
        wds.shuffle(100),
        wds.decode(autodecode.ImageHandler("rgb")),
        wds.to_tuple("png", "cls"),
    ]
    pipeline = wds.DataPipeline(*stages)
    samples = list(iter(pipeline))

    first = samples[0]
    assert len(first) == 2
    assert isinstance(first[0], np.ndarray)
    assert isinstance(first[1], int)
    assert len(samples) == 470
Ejemplo n.º 10
0
def test_reader1():
    """Read one tar shard and decode its images with the ``"rgb"`` handler.

    The test expects 47 samples; the first must expose ``__key__``, ``cls``
    and ``png`` entries, with ``cls`` an ``int`` and ``png`` a NumPy array
    of shape ``(793, 600, 3)``.
    """
    stages = [
        wds.SimpleShardList("testdata/imagenet-000000.tgz"),
        wds.tarfile_samples,
        wds.decode(autodecode.ImageHandler("rgb")),
    ]
    pipeline = wds.DataPipeline(*stages)
    samples = list(iter(pipeline))

    first = samples[0]
    field_names = list(first.keys())
    for expected in ("__key__", "cls", "png"):
        assert expected in field_names
    assert isinstance(first["cls"], int)
    assert isinstance(first["png"], np.ndarray)
    assert first["png"].shape == (793, 600, 3)
    assert len(samples) == 47
Ejemplo n.º 11
0
def test_composable():
    """A brace-range shard spec ``000000..000099`` is expected to give 100 entries."""
    shard_spec = "test-{000000..000099}.tar"
    pipeline = wds.DataPipeline(wds.SimpleShardList(shard_spec))
    shards = list(iter(pipeline))
    assert len(shards) == 100
Ejemplo n.º 12
0
def test_trivial_map3():
    """A ``wds.stage`` wrapping ``iterators.map`` with +1 turns 1..4 into 2..5."""

    def numbers():
        return iter([1, 2, 3, 4])

    def increment(x):
        return x + 1

    pipeline = wds.DataPipeline(
        numbers,
        wds.stage(iterators.map, increment),
    )
    outputs = list(iter(pipeline))
    assert outputs == [2, 3, 4, 5]
Ejemplo n.º 13
0
def test_trivial_map():
    """A ``filters.map`` stage applying +1 turns 1..4 into 2..5."""

    def numbers():
        return iter([1, 2, 3, 4])

    def increment(x):
        return x + 1

    pipeline = wds.DataPipeline(
        numbers,
        filters.map(increment),
    )
    outputs = list(iter(pipeline))
    assert outputs == [2, 3, 4, 5]
Ejemplo n.º 14
0
def test_trivial():
    """A pipeline made of only a source callable yields the source's items unchanged."""

    def numbers():
        return iter([1, 2, 3, 4])

    pipeline = wds.DataPipeline(numbers)
    outputs = list(iter(pipeline))
    assert outputs == [1, 2, 3, 4]