def test_pax_format():
    global test_batch_size
    num_samples = 1000
    tar_file_path = os.path.join(get_dali_extra_path(), "db/webdataset/MNIST/devel-0.tar")
    pax_tar_file_path = os.path.join(get_dali_extra_path(), "db/webdataset/pax/devel-0.tar")
    index_file = generate_temp_index_file(tar_file_path)
    num_shards = 100
    for shard_id in range(num_shards):
        compare_pipelines(
            webdataset_raw_pipeline(
                tar_file_path, index_file.name, ["jpg", "cls"],
                num_shards=num_shards, shard_id=shard_id,
                batch_size=test_batch_size, device_id=0, num_threads=1,
            ),
            webdataset_raw_pipeline(
                pax_tar_file_path, None, ext=["jpg", "cls"],
                num_shards=num_shards, shard_id=shard_id,
                batch_size=test_batch_size, device_id=0, num_threads=1,
            ),
            test_batch_size,
            math.ceil(num_samples / num_shards / test_batch_size) * 2,
        )


def general_corner_case(test_batch_size=base.test_batch_size, dtypes=None,
                        missing_component_behavior="", **kwargs):
    num_samples = 1000
    tar_file_path = os.path.join(get_dali_extra_path(), "db/webdataset/MNIST/devel-0.tar")
    index_file = base.generate_temp_index_file(tar_file_path)
    extract_dir = base.generate_temp_extract(tar_file_path)
    equivalent_files = sorted(
        glob(extract_dir.name + "/*"),
        key=lambda s: int(s[s.rfind("/") + 1:s.rfind(".")]))
    compare_pipelines(
        base.webdataset_raw_pipeline(
            tar_file_path, index_file.name, ["jpg", "cls"],
            missing_component_behavior=missing_component_behavior, dtypes=dtypes,
            batch_size=test_batch_size, device_id=0, num_threads=1, **kwargs),
        base.file_reader_pipeline(
            equivalent_files, ["jpg", "cls"],
            batch_size=test_batch_size, device_id=0, num_threads=1, **kwargs),
        test_batch_size,
        math.ceil(num_samples / test_batch_size),
    )


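# Illustrative sketch, not part of the original test suite: general_corner_case is a
# parameterized helper, so individual corner cases can be written as thin wrappers that
# forward extra reader arguments through **kwargs. The read_ahead keyword used here is
# an assumption about what the underlying webdataset reader accepts.
def _example_corner_case_read_ahead():
    general_corner_case(read_ahead=True)

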
def test_return_empty():
    global test_batch_size
    num_samples = 1000
    tar_file_path = os.path.join(get_dali_extra_path(), "db/webdataset/MNIST/missing.tar")
    index_file = generate_temp_index_file(tar_file_path)
    extract_dir = generate_temp_extract(tar_file_path)
    equivalent_files = sorted(
        glob(extract_dir.name + "/*"),
        key=lambda s: int(s[s.rfind("/") + 1:s.rfind(".")]))  # noqa: 203
    compare_pipelines(
        webdataset_raw_pipeline(
            tar_file_path, index_file.name, ["jpg", "txt"],
            missing_component_behavior="empty",
            batch_size=test_batch_size, device_id=0, num_threads=1,
        ),
        file_reader_pipeline(
            equivalent_files, ["jpg", []],
            batch_size=test_batch_size, device_id=0, num_threads=1),
        test_batch_size,
        math.ceil(num_samples / test_batch_size),
    )


def test_dtypes():
    global test_batch_size
    num_samples = 100
    tar_file_path = os.path.join(get_dali_extra_path(), "db/webdataset/sample-tar/dtypes.tar")
    index_file = generate_temp_index_file(tar_file_path)
    wds_pipeline = webdataset_raw_pipeline(
        tar_file_path, index_file.name, ["float16", "int32", "float64"],
        dtypes=[dali.types.FLOAT16, dali.types.INT32, dali.types.FLOAT64],
        batch_size=test_batch_size, device_id=0, num_threads=1,
    )
    wds_pipeline.build()
    for sample_idx in range(num_samples):
        # Fetch a new batch once every test_batch_size samples, then check each sample
        # in it against its expected constant vector of 10 values.
        if sample_idx % test_batch_size == 0:
            f16, i32, f64 = wds_pipeline.run()
        assert (f16.as_array()[sample_idx % test_batch_size] == [float(sample_idx)] * 10).all()
        assert (i32.as_array()[sample_idx % test_batch_size] == [int(sample_idx)] * 10).all()
        assert (f64.as_array()[sample_idx % test_batch_size] == [float(sample_idx)] * 10).all()


def test_skip_sample():
    global test_batch_size
    num_samples = 500
    tar_file_path = os.path.join(get_dali_extra_path(), "db/webdataset/MNIST/missing.tar")
    index_file = generate_temp_index_file(tar_file_path)
    extract_dir = generate_temp_extract(tar_file_path)
    equivalent_files = list(
        filter(
            lambda s: int(s[s.rfind("/") + 1:s.rfind(".")]) < 2500,  # noqa: 203
            sorted(glob(extract_dir.name + "/*"),
                   key=lambda s: int(s[s.rfind("/") + 1:s.rfind(".")])),  # noqa: 203
        ))
    compare_pipelines(
        webdataset_raw_pipeline(
            tar_file_path, index_file.name, ["jpg", "cls"],
            missing_component_behavior="skip",
            batch_size=test_batch_size, device_id=0, num_threads=1,
        ),
        file_reader_pipeline(
            equivalent_files, ["jpg", "cls"],
            batch_size=test_batch_size, device_id=0, num_threads=1),
        test_batch_size,
        math.ceil(num_samples / test_batch_size),
    )
    wds_pipeline = webdataset_raw_pipeline(
        tar_file_path, index_file.name, ["jpg", "cls"],
        missing_component_behavior="skip",
        batch_size=test_batch_size, device_id=0, num_threads=1,
    )
    wds_pipeline.build()
    assert_equal(list(wds_pipeline.epoch_size().values())[0], num_samples)


def test_wide_sample():
    test_batch_size = 1
    num_samples = 1
    tar_file_path = os.path.join(get_dali_extra_path(), "db/webdataset/sample-tar/wide.tar")
    index_file = base.generate_temp_index_file(tar_file_path)
    extract_dir = base.generate_temp_extract(tar_file_path)
    equivalent_files = sorted(glob(extract_dir.name + "/*"))
    num_components = 1000
    compare_pipelines(
        base.webdataset_raw_pipeline(
            tar_file_path, index_file.name, [str(x) for x in range(num_components)],
            batch_size=test_batch_size, device_id=0, num_threads=1,
        ),
        base.file_reader_pipeline(
            equivalent_files, [str(x) for x in range(num_components)],
            batch_size=test_batch_size, device_id=0, num_threads=1,
        ),
        test_batch_size,
        math.ceil(num_samples / test_batch_size) * 10,
    )
    wds_pipeline = base.webdataset_raw_pipeline(
        tar_file_path, index_file.name, ["txt"],
        batch_size=test_batch_size, device_id=0, num_threads=1,
    )
    wds_pipeline.build()
    assert_equal(list(wds_pipeline.epoch_size().values())[0], num_samples)


def test_index_generation():
    global test_batch_size
    num_samples = 3000
    tar_file_paths = [
        os.path.join(get_dali_extra_path(), "db/webdataset/MNIST/devel-0.tar"),
        os.path.join(get_dali_extra_path(), "db/webdataset/MNIST/devel-1.tar"),
        os.path.join(get_dali_extra_path(), "db/webdataset/MNIST/devel-2.tar"),
    ]
    extract_dirs = [generate_temp_extract(tar_file_path) for tar_file_path in tar_file_paths]
    equivalent_files = sum(
        (sorted(glob(extract_dir.name + "/*"),
                key=lambda s: int(s[s.rfind("/") + 1:s.rfind(".")]))  # noqa: 203
         for extract_dir in extract_dirs),
        [],
    )
    num_shards = 100
    for shard_id in range(num_shards):
        compare_pipelines(
            webdataset_raw_pipeline(
                tar_file_paths, [], ["jpg", "cls"],
                missing_component_behavior="error",
                num_shards=num_shards, shard_id=shard_id,
                batch_size=test_batch_size, device_id=0, num_threads=1,
            ),
            file_reader_pipeline(
                equivalent_files, ["jpg", "cls"],
                num_shards=num_shards, shard_id=shard_id,
                batch_size=test_batch_size, device_id=0, num_threads=1,
            ),
            test_batch_size,
            math.ceil(num_samples / num_shards / test_batch_size) * 2,
        )


def general_index_error(index_file_contents,
                        tar_file_path="db/webdataset/MNIST/devel-0.tar", ext="jpg"):
    index_file = tempfile.NamedTemporaryFile()
    index_file.write(index_file_contents)
    index_file.flush()
    webdataset_pipeline = base.webdataset_raw_pipeline(
        os.path.join(get_dali_extra_path(), tar_file_path),
        index_file.name, ext,
        batch_size=1, device_id=0, num_threads=1,
    )
    webdataset_pipeline.build()
    webdataset_pipeline.run()
    webdataset_pipeline.run()


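# Illustrative sketch, not part of the original test suite: general_index_error is meant
# to be wrapped in assert_raises, so that malformed index-file contents surface as a
# RuntimeError when the pipeline is built or run. The byte string below is a stand-in
# for broken index data, not a real DALI index sample.
def _example_index_error_usage():
    assert_raises(RuntimeError, general_index_error, b"not a valid index file")

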
def paths_index_paths_error():
    webdataset_pipeline = base.webdataset_raw_pipeline(
        [
            os.path.join(get_dali_extra_path(), "db/webdataset/MNIST/devel-0.tar"),
            os.path.join(get_dali_extra_path(), "db/webdataset/MNIST/devel-1.tar"),
            os.path.join(get_dali_extra_path(), "db/webdataset/MNIST/devel-2.tar"),
        ],
        ["test.idx"],
        ["jpg", "cls"],
        batch_size=1, device_id=0, num_threads=1,
    )
    webdataset_pipeline.build()


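# Illustrative sketch, not part of the original test suite: paths_index_paths_error
# constructs a reader with three archives but only one index path, so building it is
# expected to fail. A wrapper like this one would assert that the mismatch is rejected;
# the RuntimeError expectation is an assumption.
def _example_paths_index_paths_error_usage():
    assert_raises(RuntimeError, paths_index_paths_error)

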
def test_raise_error_on_missing():
    global test_batch_size
    tar_file_path = os.path.join(get_dali_extra_path(), "db/webdataset/MNIST/missing.tar")
    index_file = generate_temp_index_file(tar_file_path)
    wds_pipeline = webdataset_raw_pipeline(
        tar_file_path, index_file.name, ["jpg", "cls"],
        missing_component_behavior="error",
        batch_size=test_batch_size, device_id=0, num_threads=1,
    )
    assert_raises(RuntimeError, wds_pipeline.build, glob="Underful sample detected")


def test_wds_sharding():
    global test_batch_size
    num_samples = 3000
    tar_file_paths = [
        os.path.join(get_dali_extra_path(), "db/webdataset/MNIST/devel-0.tar"),
        os.path.join(get_dali_extra_path(), "db/webdataset/MNIST/devel-1.tar"),
        os.path.join(get_dali_extra_path(), "db/webdataset/MNIST/devel-2.tar"),
    ]
    index_files = [generate_temp_index_file(tar_file_path) for tar_file_path in tar_file_paths]
    extract_dirs = [generate_temp_extract(tar_file_path) for tar_file_path in tar_file_paths]
    equivalent_files = sum(
        (sorted(glob(extract_dir.name + "/*"),
                key=lambda s: int(s[s.rfind("/") + 1:s.rfind(".")]))  # noqa: 203
         for extract_dir in extract_dirs),
        [],
    )
    compare_pipelines(
        webdataset_raw_pipeline(
            tar_file_paths, [index_file.name for index_file in index_files], ["jpg", "cls"],
            batch_size=test_batch_size, device_id=0, num_threads=1,
        ),
        file_reader_pipeline(
            equivalent_files, ["jpg", "cls"],
            batch_size=test_batch_size, device_id=0, num_threads=1,
        ),
        test_batch_size,
        math.ceil(num_samples / test_batch_size),
    )


def test_sharding():
    global test_batch_size
    num_samples = 1000
    tar_file_path = os.path.join(get_dali_extra_path(), "db/webdataset/MNIST/devel-0.tar")
    index_file = generate_temp_index_file(tar_file_path)
    extract_dir = generate_temp_extract(tar_file_path)
    equivalent_files = sorted(
        glob(extract_dir.name + "/*"),
        key=lambda s: int(s[s.rfind("/") + 1:s.rfind(".")]))  # noqa: 203
    num_shards = 100
    for shard_id in range(num_shards):
        compare_pipelines(
            webdataset_raw_pipeline(
                tar_file_path, index_file.name, ["jpg", "cls"],
                num_shards=num_shards, shard_id=shard_id,
                batch_size=test_batch_size, device_id=0, num_threads=1,
            ),
            file_reader_pipeline(
                equivalent_files, ["jpg", "cls"],
                num_shards=num_shards, shard_id=shard_id,
                batch_size=test_batch_size, device_id=0, num_threads=1,
            ),
            test_batch_size,
            math.ceil(num_samples / num_shards / test_batch_size) * 2,
        )