def test_pax_format():
    """Compare reading ``pax/devel-0.tar`` against the MNIST ``devel-0.tar``.

    The MNIST archive is read through an index file produced by
    ``generate_temp_index_file``; the PAX archive is read with ``None`` in
    place of an index path. The two readers are compared shard by shard
    over 100 shards.
    """
    sample_count = 1000
    shard_count = 100

    reference_tar = os.path.join(get_dali_extra_path(),
                                 "db/webdataset/MNIST/devel-0.tar")
    pax_tar = os.path.join(get_dali_extra_path(),
                           "db/webdataset/pax/devel-0.tar")
    # Keep the returned object referenced for the whole test; only its
    # ``name`` is handed to the reader.
    reference_index = generate_temp_index_file(reference_tar)

    # Twice the number of batches needed to cover one shard's samples.
    iterations = math.ceil(sample_count / shard_count / test_batch_size) * 2

    for shard in range(shard_count):
        shard_args = dict(
            num_shards=shard_count,
            shard_id=shard,
            batch_size=test_batch_size,
            device_id=0,
            num_threads=1,
        )
        reference_pipe = webdataset_raw_pipeline(
            reference_tar, reference_index.name, ["jpg", "cls"], **shard_args)
        pax_pipe = webdataset_raw_pipeline(
            pax_tar, None, ext=["jpg", "cls"], **shard_args)
        compare_pipelines(reference_pipe, pax_pipe, test_batch_size,
                          iterations)
示例#2
0
def general_corner_case(test_batch_size=base.test_batch_size,
                        dtypes=None,
                        missing_component_behavior="",
                        **kwargs):
    """Compare the webdataset reader with a file reader on MNIST ``devel-0.tar``.

    Args:
        test_batch_size: Batch size given to both pipelines and to
            ``compare_pipelines``.
        dtypes: Forwarded to ``base.webdataset_raw_pipeline`` as ``dtypes``.
        missing_component_behavior: Forwarded to
            ``base.webdataset_raw_pipeline`` under the same name.
        **kwargs: Extra keyword arguments forwarded unchanged to both the
            webdataset pipeline and the file reader pipeline.
    """
    sample_count = 1000
    archive = os.path.join(get_dali_extra_path(),
                           "db/webdataset/MNIST/devel-0.tar")
    index_file = base.generate_temp_index_file(archive)
    extract_dir = base.generate_temp_extract(archive)

    def sample_number(path):
        # Integer spelled between the last "/" and the last "." of the path.
        return int(path[path.rfind("/") + 1:path.rfind(".")])

    extracted = glob(extract_dir.name + "/*")
    extracted.sort(key=sample_number)

    wds_pipe = base.webdataset_raw_pipeline(
        archive,
        index_file.name,
        ["jpg", "cls"],
        missing_component_behavior=missing_component_behavior,
        dtypes=dtypes,
        batch_size=test_batch_size,
        device_id=0,
        num_threads=1,
        **kwargs)
    files_pipe = base.file_reader_pipeline(
        extracted,
        ["jpg", "cls"],
        batch_size=test_batch_size,
        device_id=0,
        num_threads=1,
        **kwargs)

    compare_pipelines(wds_pipe, files_pipe, test_batch_size,
                      math.ceil(sample_count / test_batch_size))
def test_return_empty():
    """Exercise ``missing_component_behavior="empty"`` on ``missing.tar``.

    The webdataset reader requests the ``jpg`` and ``txt`` components and is
    compared against a file reader over the extracted archive contents that
    receives ``["jpg", []]`` as its component list.
    """
    sample_count = 1000
    archive = os.path.join(get_dali_extra_path(),
                           "db/webdataset/MNIST/missing.tar")
    index_file = generate_temp_index_file(archive)
    extract_dir = generate_temp_extract(archive)

    def sample_number(path):
        # Integer spelled between the last "/" and the last "." of the path.
        return int(path[path.rfind("/") + 1:path.rfind(".")])

    extracted = sorted(glob(extract_dir.name + "/*"), key=sample_number)

    wds_pipe = webdataset_raw_pipeline(
        archive,
        index_file.name,
        ["jpg", "txt"],
        batch_size=test_batch_size,
        device_id=0,
        num_threads=1,
        missing_component_behavior="empty",
    )
    files_pipe = file_reader_pipeline(
        extracted,
        ["jpg", []],
        batch_size=test_batch_size,
        device_id=0,
        num_threads=1,
    )

    compare_pipelines(wds_pipe, files_pipe, test_batch_size,
                      math.ceil(sample_count / test_batch_size))
def test_dtypes():
    """Check typed outputs read from ``sample-tar/dtypes.tar``.

    The ``float16``, ``int32`` and ``float64`` components are requested with
    the matching DALI dtypes. For each of the first 100 samples, every
    component is expected to equal the sample's index repeated 10 times.
    """
    num_samples = 100
    tar_file_path = os.path.join(get_dali_extra_path(),
                                 "db/webdataset/sample-tar/dtypes.tar")
    index_file = generate_temp_index_file(tar_file_path)

    wds_pipeline = webdataset_raw_pipeline(
        tar_file_path,
        index_file.name,
        ["float16", "int32", "float64"],
        dtypes=[dali.types.FLOAT16, dali.types.INT32, dali.types.FLOAT64],
        batch_size=test_batch_size,
        device_id=0,
        num_threads=1,
    )
    wds_pipeline.build()

    for batch_start in range(0, num_samples, test_batch_size):
        # Convert each output once per batch instead of once per sample.
        f16, i32, f64 = (out.as_array() for out in wds_pipeline.run())

        # The last batch may reach past num_samples; only check the samples
        # that fall inside the expected range.
        samples_in_batch = min(test_batch_size, num_samples - batch_start)
        for offset in range(samples_in_batch):
            sample_idx = batch_start + offset
            assert (f16[offset] == [float(sample_idx)] * 10).all(), (
                f"float16 mismatch at sample {sample_idx}")
            assert (i32[offset] == [int(sample_idx)] * 10).all(), (
                f"int32 mismatch at sample {sample_idx}")
            assert (f64[offset] == [float(sample_idx)] * 10).all(), (
                f"float64 mismatch at sample {sample_idx}")
def test_skip_sample():
    """Exercise ``missing_component_behavior="skip"`` on ``missing.tar``.

    The webdataset reader is compared against a file reader restricted to
    the extracted files whose numeric name is below 2500. A second,
    freshly built reader is then checked to report an epoch size of 500.
    """
    sample_count = 500
    archive = os.path.join(get_dali_extra_path(),
                           "db/webdataset/MNIST/missing.tar")
    index_file = generate_temp_index_file(archive)
    extract_dir = generate_temp_extract(archive)

    def sample_number(path):
        # Integer spelled between the last "/" and the last "." of the path.
        return int(path[path.rfind("/") + 1:path.rfind(".")])

    ordered = sorted(glob(extract_dir.name + "/*"), key=sample_number)
    equivalent_files = [
        path for path in ordered if sample_number(path) < 2500
    ]

    def make_skipping_reader():
        # Both halves of the test use an identically configured reader.
        return webdataset_raw_pipeline(
            archive,
            index_file.name,
            ["jpg", "cls"],
            missing_component_behavior="skip",
            batch_size=test_batch_size,
            device_id=0,
            num_threads=1,
        )

    compare_pipelines(
        make_skipping_reader(),
        file_reader_pipeline(
            equivalent_files,
            ["jpg", "cls"],
            batch_size=test_batch_size,
            device_id=0,
            num_threads=1,
        ),
        test_batch_size,
        math.ceil(sample_count / test_batch_size),
    )

    size_probe = make_skipping_reader()
    size_probe.build()
    epoch_sizes = list(size_probe.epoch_size().values())
    assert_equal(epoch_sizes[0], sample_count)
示例#6
0
def test_wide_sample():
    """Read a single sample with 1000 components from ``sample-tar/wide.tar``.

    A webdataset reader requesting components ``"0"`` .. ``"999"`` is
    compared with a file reader over the extracted archive. A second reader
    requesting only ``"txt"`` is then checked to report an epoch size of 1.
    """
    test_batch_size = 1
    sample_count = 1
    component_count = 1000

    archive = os.path.join(get_dali_extra_path(),
                           "db/webdataset/sample-tar/wide.tar")
    index_file = base.generate_temp_index_file(archive)
    extract_dir = base.generate_temp_extract(archive)
    extracted = sorted(glob(extract_dir.name + "/*"))

    reader_args = dict(batch_size=test_batch_size, device_id=0, num_threads=1)
    component_names = [str(number) for number in range(component_count)]

    wide_reader = base.webdataset_raw_pipeline(
        archive, index_file.name, component_names, **reader_args)
    # Hand the file reader its own copy so the two pipelines never share
    # one list object.
    files_reader = base.file_reader_pipeline(
        extracted, list(component_names), **reader_args)
    compare_pipelines(
        wide_reader,
        files_reader,
        test_batch_size,
        math.ceil(sample_count / test_batch_size) * 10,
    )

    size_probe = base.webdataset_raw_pipeline(
        archive, index_file.name, ["txt"], **reader_args)
    size_probe.build()
    epoch_sizes = list(size_probe.epoch_size().values())
    assert_equal(epoch_sizes[0], sample_count)
def test_index_generation():
    """Read three MNIST archives while passing an empty list of index paths.

    For each of 100 shards, a webdataset reader over ``devel-0.tar``,
    ``devel-1.tar`` and ``devel-2.tar`` (index paths given as ``[]``) is
    compared against a file reader over the extracted contents of the same
    archives, concatenated in archive order.
    """
    num_samples = 3000
    num_shards = 100

    tar_file_paths = [
        os.path.join(get_dali_extra_path(), relative_path)
        for relative_path in (
            "db/webdataset/MNIST/devel-0.tar",
            "db/webdataset/MNIST/devel-1.tar",
            "db/webdataset/MNIST/devel-2.tar",
        )
    ]
    extract_dirs = [
        generate_temp_extract(tar_file_path)
        for tar_file_path in tar_file_paths
    ]

    def sample_number(path):
        # File name without directory or extension, parsed as an integer.
        return int(os.path.splitext(os.path.basename(path))[0])

    # Flatten with a nested comprehension: linear in the number of files,
    # unlike repeated list concatenation through sum(..., []).
    equivalent_files = [
        path
        for extract_dir in extract_dirs
        for path in sorted(glob(extract_dir.name + "/*"), key=sample_number)
    ]

    # Twice the number of batches needed to cover one shard's samples.
    iterations = math.ceil(num_samples / num_shards / test_batch_size) * 2

    for shard_id in range(num_shards):
        compare_pipelines(
            webdataset_raw_pipeline(
                tar_file_paths,
                [],
                ["jpg", "cls"],
                missing_component_behavior="error",
                num_shards=num_shards,
                shard_id=shard_id,
                batch_size=test_batch_size,
                device_id=0,
                num_threads=1,
            ),
            file_reader_pipeline(
                equivalent_files,
                ["jpg", "cls"],
                num_shards=num_shards,
                shard_id=shard_id,
                batch_size=test_batch_size,
                device_id=0,
                num_threads=1,
            ),
            test_batch_size,
            iterations,
        )
示例#8
0
def general_index_error(index_file_contents,
                        tar_file_path="db/webdataset/MNIST/devel-0.tar",
                        ext="jpg"):
    """Build and run a webdataset pipeline against a hand-written index file.

    Args:
        index_file_contents: Data written verbatim to a temporary index
            file. The file is opened with ``NamedTemporaryFile``'s default
            (binary) mode, so this must be bytes-like.
        tar_file_path: Archive path relative to ``get_dali_extra_path()``.
        ext: Component extension(s) passed to the reader.

    Any exception raised while building or running the pipeline propagates
    to the caller.
    """
    # The context manager closes (and thereby removes) the temporary file
    # even when the pipeline raises, instead of leaving cleanup to garbage
    # collection.
    with tempfile.NamedTemporaryFile() as index_file:
        index_file.write(index_file_contents)
        # Flush so the reader, which opens the file by name, sees the data.
        index_file.flush()
        webdataset_pipeline = base.webdataset_raw_pipeline(
            os.path.join(get_dali_extra_path(), tar_file_path),
            index_file.name,
            ext,
            batch_size=1,
            device_id=0,
            num_threads=1,
        )
        webdataset_pipeline.build()
        webdataset_pipeline.run()
        webdataset_pipeline.run()
示例#9
0
 def paths_index_paths_error():
     """Build a reader given three archive paths but only one index path."""
     relative_archives = (
         "db/webdataset/MNIST/devel-0.tar",
         "db/webdataset/MNIST/devel-1.tar",
         "db/webdataset/MNIST/devel-2.tar",
     )
     archives = [
         os.path.join(get_dali_extra_path(), relative_archive)
         for relative_archive in relative_archives
     ]
     mismatched_reader = base.webdataset_raw_pipeline(
         archives,
         ["test.idx"],
         ["jpg", "cls"],
         batch_size=1,
         device_id=0,
         num_threads=1,
     )
     mismatched_reader.build()
def test_raise_error_on_missing():
    """Expect ``build`` to raise on ``missing.tar`` with behavior ``"error"``.

    The reader requests the ``jpg`` and ``cls`` components; building it must
    raise ``RuntimeError`` matching ``"Underful sample detected"``.
    """
    archive = os.path.join(get_dali_extra_path(),
                           "db/webdataset/MNIST/missing.tar")
    # Keep the returned object referenced until the assertion has run; only
    # its ``name`` is handed to the reader.
    index_file = generate_temp_index_file(archive)

    reader_args = dict(
        missing_component_behavior="error",
        batch_size=test_batch_size,
        device_id=0,
        num_threads=1,
    )
    strict_reader = webdataset_raw_pipeline(
        archive, index_file.name, ["jpg", "cls"], **reader_args)

    assert_raises(RuntimeError,
                  strict_reader.build,
                  glob="Underful sample detected")
def test_wds_sharding():
    """Read three MNIST archives, each with its own index file.

    A webdataset reader over ``devel-0.tar``, ``devel-1.tar`` and
    ``devel-2.tar`` is compared against a file reader over the extracted
    contents of the same archives, concatenated in archive order.
    """
    num_samples = 3000

    tar_file_paths = [
        os.path.join(get_dali_extra_path(), relative_path)
        for relative_path in (
            "db/webdataset/MNIST/devel-0.tar",
            "db/webdataset/MNIST/devel-1.tar",
            "db/webdataset/MNIST/devel-2.tar",
        )
    ]
    # Keep the returned objects referenced for the whole test; only their
    # ``name`` attributes are handed to the reader.
    index_files = [
        generate_temp_index_file(tar_file_path)
        for tar_file_path in tar_file_paths
    ]
    extract_dirs = [
        generate_temp_extract(tar_file_path)
        for tar_file_path in tar_file_paths
    ]

    def sample_number(path):
        # File name without directory or extension, parsed as an integer.
        return int(os.path.splitext(os.path.basename(path))[0])

    # Flatten with a nested comprehension: linear in the number of files,
    # unlike repeated list concatenation through sum(..., []).
    equivalent_files = [
        path
        for extract_dir in extract_dirs
        for path in sorted(glob(extract_dir.name + "/*"), key=sample_number)
    ]

    compare_pipelines(
        webdataset_raw_pipeline(
            tar_file_paths,
            [index_file.name for index_file in index_files],
            ["jpg", "cls"],
            batch_size=test_batch_size,
            device_id=0,
            num_threads=1,
        ),
        file_reader_pipeline(
            equivalent_files,
            ["jpg", "cls"],
            batch_size=test_batch_size,
            device_id=0,
            num_threads=1,
        ),
        test_batch_size,
        math.ceil(num_samples / test_batch_size),
    )
def test_sharding():
    """Compare sharded reads of MNIST ``devel-0.tar`` against a file reader.

    For each of 100 shards, the webdataset reader and a file reader over the
    extracted archive contents are given the same ``num_shards`` and
    ``shard_id`` and compared.
    """
    sample_count = 1000
    shard_count = 100

    archive = os.path.join(get_dali_extra_path(),
                           "db/webdataset/MNIST/devel-0.tar")
    index_file = generate_temp_index_file(archive)
    extract_dir = generate_temp_extract(archive)

    def sample_number(path):
        # Integer spelled between the last "/" and the last "." of the path.
        return int(path[path.rfind("/") + 1:path.rfind(".")])

    extracted = glob(extract_dir.name + "/*")
    extracted.sort(key=sample_number)

    # Twice the number of batches needed to cover one shard's samples.
    iterations = math.ceil(sample_count / shard_count / test_batch_size) * 2

    for shard in range(shard_count):
        shard_args = dict(
            num_shards=shard_count,
            shard_id=shard,
            batch_size=test_batch_size,
            device_id=0,
            num_threads=1,
        )
        wds_pipe = webdataset_raw_pipeline(
            archive, index_file.name, ["jpg", "cls"], **shard_args)
        files_pipe = file_reader_pipeline(
            extracted, ["jpg", "cls"], **shard_args)
        compare_pipelines(wds_pipe, files_pipe, test_batch_size, iterations)