def test_return_empty():
    """Compare webdataset reading of missing.tar against a plain file reader.

    The webdataset pipeline reads ("jpg", "txt") components with
    ``missing_component_behavior="empty"``. The reference pipeline reads the
    extracted files, ordered by their numeric stem, with ``[]`` in place of
    the "txt" extension. Presumably ``[]`` makes the reference emit empty
    outputs for that slot; TODO confirm against file_reader_pipeline.
    """
    sample_count = 1000
    tar_path = os.path.join(get_dali_extra_path(),
                            "db/webdataset/MNIST/missing.tar")
    index_file = generate_temp_index_file(tar_path)
    extract_dir = generate_temp_extract(tar_path)

    def sample_number(path):
        # Numeric file stem: the text between the last "/" and the last ".".
        return int(path[path.rfind("/") + 1:path.rfind(".")])

    reference_files = sorted(glob(extract_dir.name + "/*"), key=sample_number)

    webdataset_pipe = webdataset_raw_pipeline(
        tar_path,
        index_file.name,
        ["jpg", "txt"],
        batch_size=test_batch_size,
        device_id=0,
        num_threads=1,
        missing_component_behavior="empty",
    )
    reference_pipe = file_reader_pipeline(reference_files, ["jpg", []],
                                          batch_size=test_batch_size,
                                          device_id=0,
                                          num_threads=1)
    iteration_count = math.ceil(sample_count / test_batch_size)
    compare_pipelines(webdataset_pipe, reference_pipe, test_batch_size,
                      iteration_count)
# Example No. 2
0
def general_corner_case(test_batch_size=base.test_batch_size,
                        dtypes=None,
                        missing_component_behavior="",
                        **kwargs):
    """Compare webdataset reading of devel-0.tar against a file reader.

    Shared driver for corner-case tests. Both pipelines read the ("jpg", "cls")
    components. The reference reads the extracted files ordered by their
    numeric stem.

    Args:
        test_batch_size: Batch size for both pipelines. The default is
            ``base.test_batch_size`` as bound when this function was defined.
        dtypes: Forwarded to the webdataset pipeline only.
        missing_component_behavior: Forwarded to the webdataset pipeline only.
        **kwargs: Extra options forwarded to BOTH pipelines.
    """
    sample_count = 1000
    tar_path = os.path.join(get_dali_extra_path(),
                            "db/webdataset/MNIST/devel-0.tar")
    index_file = base.generate_temp_index_file(tar_path)
    extract_dir = base.generate_temp_extract(tar_path)

    def sample_number(path):
        # Numeric file stem: the text between the last "/" and the last ".".
        return int(path[path.rfind("/") + 1:path.rfind(".")])

    reference_files = sorted(glob(extract_dir.name + "/*"), key=sample_number)

    webdataset_pipe = base.webdataset_raw_pipeline(
        tar_path,
        index_file.name,
        ["jpg", "cls"],
        missing_component_behavior=missing_component_behavior,
        dtypes=dtypes,
        batch_size=test_batch_size,
        device_id=0,
        num_threads=1,
        **kwargs,
    )
    reference_pipe = base.file_reader_pipeline(
        reference_files,
        ["jpg", "cls"],
        batch_size=test_batch_size,
        device_id=0,
        num_threads=1,
        **kwargs,
    )
    iteration_count = math.ceil(sample_count / test_batch_size)
    compare_pipelines(webdataset_pipe, reference_pipe, test_batch_size,
                      iteration_count)
def test_index_generation():
    """Check the webdataset reader when no index files are supplied.

    An empty list is passed as the index paths for the three MNIST devel
    tars. Presumably the reader then builds the indices itself; TODO confirm.
    For each of ``num_shards`` shards, the output is compared with a sharded
    file reader. That reader runs over the extracted contents of all three
    tars, concatenated in tar order and sorted by numeric stem within each tar.
    """
    num_samples = 3000
    tar_file_paths = [
        os.path.join(get_dali_extra_path(), "db/webdataset/MNIST/devel-0.tar"),
        os.path.join(get_dali_extra_path(), "db/webdataset/MNIST/devel-1.tar"),
        os.path.join(get_dali_extra_path(), "db/webdataset/MNIST/devel-2.tar"),
    ]

    extract_dirs = [
        generate_temp_extract(tar_file_path)
        for tar_file_path in tar_file_paths
    ]

    def sample_number(path):
        # Numeric file stem: the text between the last "/" and the last ".".
        return int(path[path.rfind("/") + 1:path.rfind(".")])

    # Flatten the per-tar sorted listings, keeping tar order. A nested
    # comprehension is linear; sum(list_of_lists, []) re-copies the growing
    # accumulator on every step.
    equivalent_files = [
        file_path
        for extract_dir in extract_dirs
        for file_path in sorted(glob(extract_dir.name + "/*"),
                                key=sample_number)
    ]

    num_shards = 100
    for shard_id in range(num_shards):
        compare_pipelines(
            webdataset_raw_pipeline(
                tar_file_paths,
                [],
                ["jpg", "cls"],
                missing_component_behavior="error",
                num_shards=num_shards,
                shard_id=shard_id,
                batch_size=test_batch_size,
                device_id=0,
                num_threads=1,
            ),
            file_reader_pipeline(
                equivalent_files,
                ["jpg", "cls"],
                num_shards=num_shards,
                shard_id=shard_id,
                batch_size=test_batch_size,
                device_id=0,
                num_threads=1,
            ),
            test_batch_size,
            math.ceil(num_samples / num_shards / test_batch_size) * 2,
        )
def test_skip_sample():
    """Check ``missing_component_behavior="skip"`` on missing.tar.

    The reference file reader sees only the extracted files whose numeric
    stem is below 2500, ordered by stem. Presumably those are the samples
    with every requested component present; TODO confirm against the dataset.
    A second, freshly built webdataset pipeline must also report an epoch
    size of 500 samples.
    """
    expected_samples = 500
    tar_path = os.path.join(get_dali_extra_path(),
                            "db/webdataset/MNIST/missing.tar")
    index_file = generate_temp_index_file(tar_path)
    extract_dir = generate_temp_extract(tar_path)

    def sample_number(path):
        # Numeric file stem: the text between the last "/" and the last ".".
        return int(path[path.rfind("/") + 1:path.rfind(".")])

    ordered = sorted(glob(extract_dir.name + "/*"), key=sample_number)
    reference_files = [path for path in ordered if sample_number(path) < 2500]

    def make_skipping_pipeline():
        return webdataset_raw_pipeline(
            tar_path,
            index_file.name,
            ["jpg", "cls"],
            missing_component_behavior="skip",
            batch_size=test_batch_size,
            device_id=0,
            num_threads=1,
        )

    compare_pipelines(
        make_skipping_pipeline(),
        file_reader_pipeline(reference_files, ["jpg", "cls"],
                             batch_size=test_batch_size,
                             device_id=0,
                             num_threads=1),
        test_batch_size,
        math.ceil(expected_samples / test_batch_size),
    )

    size_pipeline = make_skipping_pipeline()
    size_pipeline.build()
    reported_sizes = list(size_pipeline.epoch_size().values())
    assert_equal(reported_sizes[0], expected_samples)
def test_wds_sharding():
    """Read three MNIST devel tars (each with its own index) as one dataset.

    The reference file reader runs over the extracted contents of all three
    tars. Files are concatenated in tar order and sorted by numeric stem
    within each tar.
    """
    num_samples = 3000
    tar_file_paths = [
        os.path.join(get_dali_extra_path(), "db/webdataset/MNIST/devel-0.tar"),
        os.path.join(get_dali_extra_path(), "db/webdataset/MNIST/devel-1.tar"),
        os.path.join(get_dali_extra_path(), "db/webdataset/MNIST/devel-2.tar"),
    ]
    index_files = [
        generate_temp_index_file(tar_file_path)
        for tar_file_path in tar_file_paths
    ]

    extract_dirs = [
        generate_temp_extract(tar_file_path)
        for tar_file_path in tar_file_paths
    ]

    def sample_number(path):
        # Numeric file stem: the text between the last "/" and the last ".".
        return int(path[path.rfind("/") + 1:path.rfind(".")])

    # Flatten the per-tar sorted listings, keeping tar order. A nested
    # comprehension is linear; sum(list_of_lists, []) re-copies the growing
    # accumulator on every step.
    equivalent_files = [
        file_path
        for extract_dir in extract_dirs
        for file_path in sorted(glob(extract_dir.name + "/*"),
                                key=sample_number)
    ]

    compare_pipelines(
        webdataset_raw_pipeline(
            tar_file_paths,
            [index_file.name for index_file in index_files],
            ["jpg", "cls"],
            batch_size=test_batch_size,
            device_id=0,
            num_threads=1,
        ),
        file_reader_pipeline(
            equivalent_files,
            ["jpg", "cls"],
            batch_size=test_batch_size,
            device_id=0,
            num_threads=1,
        ),
        test_batch_size,
        math.ceil(num_samples / test_batch_size),
    )
# Example No. 6
0
def test_wide_sample():
    """Read the single 1000-component sample in wide.tar.

    The webdataset pipeline requests components "0" to "999". Its output is
    compared with a file reader over the extracted files in plain
    lexicographic order. The iteration argument passed to compare_pipelines
    is ``ceil(num_samples / batch) * 10``. A pipeline that selects only "txt"
    must then report an epoch size of 1.
    """
    batch = 1
    sample_total = 1
    component_total = 1000
    tar_path = os.path.join(get_dali_extra_path(),
                            "db/webdataset/sample-tar/wide.tar")
    index_file = base.generate_temp_index_file(tar_path)
    extract_dir = base.generate_temp_extract(tar_path)
    reference_files = sorted(glob(extract_dir.name + "/*"))

    compare_pipelines(
        base.webdataset_raw_pipeline(
            tar_path,
            index_file.name,
            list(map(str, range(component_total))),
            batch_size=batch,
            device_id=0,
            num_threads=1,
        ),
        base.file_reader_pipeline(
            reference_files,
            list(map(str, range(component_total))),
            batch_size=batch,
            device_id=0,
            num_threads=1,
        ),
        batch,
        math.ceil(sample_total / batch) * 10,
    )

    txt_only = base.webdataset_raw_pipeline(
        tar_path,
        index_file.name,
        ["txt"],
        batch_size=batch,
        device_id=0,
        num_threads=1,
    )
    txt_only.build()
    reported_sizes = list(txt_only.epoch_size().values())
    assert_equal(reported_sizes[0], sample_total)
def test_sharding():
    """Compare sharded webdataset reading of devel-0.tar with a file reader.

    For each of 100 shards, both pipelines receive the same
    ``num_shards``/``shard_id``. The reference reads the extracted files
    ordered by their numeric stem.
    """
    sample_total = 1000
    shard_total = 100
    tar_path = os.path.join(get_dali_extra_path(),
                            "db/webdataset/MNIST/devel-0.tar")
    index_file = generate_temp_index_file(tar_path)
    extract_dir = generate_temp_extract(tar_path)

    def sample_number(path):
        # Numeric file stem: the text between the last "/" and the last ".".
        return int(path[path.rfind("/") + 1:path.rfind(".")])

    reference_files = sorted(glob(extract_dir.name + "/*"), key=sample_number)
    # The iteration count does not depend on the shard, so compute it once.
    iteration_count = math.ceil(sample_total / shard_total / test_batch_size) * 2

    for shard in range(shard_total):
        webdataset_pipe = webdataset_raw_pipeline(
            tar_path,
            index_file.name,
            ["jpg", "cls"],
            num_shards=shard_total,
            shard_id=shard,
            batch_size=test_batch_size,
            device_id=0,
            num_threads=1,
        )
        reference_pipe = file_reader_pipeline(
            reference_files,
            ["jpg", "cls"],
            num_shards=shard_total,
            shard_id=shard,
            batch_size=test_batch_size,
            device_id=0,
            num_threads=1,
        )
        compare_pipelines(webdataset_pipe, reference_pipe, test_batch_size,
                          iteration_count)