def test_cached_path_local(text_file):
    """cached_path must return an existing local path unchanged.

    Checked for both the absolute and the relative (to cwd) form.
    """
    # Absolute path: returned verbatim.
    abs_path = str(Path(text_file).resolve())
    assert cached_path(abs_path) == abs_path
    # Relative path (this test module relative to the cwd): also returned verbatim.
    rel_path = str(Path(__file__).resolve().relative_to(Path(os.getcwd())))
    assert cached_path(rel_path) == rel_path
def test_cached_path_missing_local(tmp_path):
    """cached_path must raise FileNotFoundError for nonexistent local files.

    Checked for both an absolute and a relative missing path.
    """
    absolute_missing = str(tmp_path.resolve() / "__missing_file__.txt")
    relative_missing = "./__missing_file__.txt"
    for missing_file in (absolute_missing, relative_missing):
        with pytest.raises(FileNotFoundError):
            cached_path(missing_file)
def test_load_dataset_distributed(self):
    """Load the dummy csv dataset concurrently from several worker processes.

    First performs one warm-up load (which populates the cache), then loads
    from num_workers processes at once to exercise concurrent cache access.
    """
    num_workers = 5
    with tempfile.TemporaryDirectory() as tmp_dir:
        data_name = "csv"
        data_base_path = os.path.join("datasets", data_name, "dummy", "0.0.0", "dummy_data.zip")
        local_path = cached_path(
            data_base_path, cache_dir=tmp_dir, extract_compressed_file=True, force_extract=True
        )
        datafiles = {
            split: os.path.join(local_path, "dummy_data/" + split + ".csv")
            for split in ("train", "dev", "test")
        }
        args = data_name, tmp_dir, datafiles
        with Pool(processes=num_workers) as pool:
            # Warm-up: a single async load before fanning out to all workers.
            result = pool.apply_async(distributed_load_dataset, (args,))
            dataset = result.get(timeout=20)
            del result, dataset
            # Concurrent loads, one per worker process.
            datasets = pool.map(distributed_load_dataset, [args] * num_workers)
            while datasets:
                dataset = datasets.pop()
                del dataset
def test_cached_path_extract(xz_file, tmp_path, text_file):
    """Extracting an xz archive via cached_path yields the original text content."""
    download_config = DownloadConfig(cache_dir=tmp_path / "cache", extract_compressed_file=True)
    extracted_filename = cached_path(xz_file, download_config=download_config)
    # The extracted file must match the uncompressed reference file byte-for-byte.
    with open(text_file) as f:
        expected_file_content = f.read()
    with open(extracted_filename) as f:
        assert f.read() == expected_file_content
def test_cached_path_extract(compression_format, gz_file, xz_file, zstd_path, tmp_path, text_file):
    """cached_path extracts gzip/xz/zstd archives and the result matches the source text.

    The compression_format parameter selects which fixture archive to extract.
    """
    fixture_by_format = {"gzip": gz_file, "xz": xz_file, "zstd": zstd_path}
    input_path = str(fixture_by_format[compression_format])
    download_config = DownloadConfig(cache_dir=tmp_path / "cache", extract_compressed_file=True)
    extracted_path = cached_path(input_path, download_config=download_config)
    # Extracted content must equal the uncompressed reference text.
    with open(extracted_path) as got, open(text_file) as want:
        assert got.read() == want.read()
def test_extracted_datasets_path(default_extracted, default_cache_dir, xz_file, tmp_path, monkeypatch):
    """Extraction output lands under the expected (default or customized) directory.

    Covers the four combinations of default/custom cache dir and
    default/custom extracted dir, asserting on the last two path components
    of the extracted file's parent directory.
    """
    custom_cache_dir = "custom_cache"
    custom_extracted_dir = "custom_extracted_dir"
    custom_extracted_path = tmp_path / "custom_extracted_path"
    if default_extracted:
        # Default extracted layout: <cache>/downloads/extracted or <custom_cache>/extracted.
        expected = ("downloads" if default_cache_dir else custom_cache_dir, "extracted")
    else:
        # Override both the relative dir name and the absolute extracted path.
        monkeypatch.setattr("datasets.config.EXTRACTED_DATASETS_DIR", custom_extracted_dir)
        monkeypatch.setattr("datasets.config.EXTRACTED_DATASETS_PATH", str(custom_extracted_path))
        if default_cache_dir:
            expected = custom_extracted_path.parts[-2:]
        else:
            expected = (custom_cache_dir, custom_extracted_dir)
    if default_cache_dir:
        download_config = DownloadConfig(extract_compressed_file=True)
    else:
        download_config = DownloadConfig(cache_dir=tmp_path / custom_cache_dir, extract_compressed_file=True)
    extracted_file_path = cached_path(xz_file, download_config=download_config)
    assert Path(extracted_file_path).parent.parts[-2:] == expected
def test_dataset_info_available(self, dataset, config_name):
    """The dataset info file for (dataset, config_name) is downloadable from HF GCP.

    Builds the dataset's relative data dir via a local builder instance, forms
    the GCP URL for its dataset_info file, and checks the file can be cached.
    """
    with TemporaryDirectory() as tmp_dir:
        dataset_module = dataset_module_factory(
            os.path.join("datasets", dataset), cache_dir=tmp_dir, local_files_only=True
        )
        builder_cls = import_main_class(dataset_module.module_path, dataset=True)
        builder_instance: DatasetBuilder = builder_cls(
            cache_dir=tmp_dir,
            name=config_name,
            hash=dataset_module.hash,
        )
        # os.path.join may use a platform separator; the URL always needs "/".
        dataset_info_url = os.path.join(
            HF_GCP_BASE_URL, builder_instance._relative_data_dir(with_hash=False), config.DATASET_INFO_FILENAME
        ).replace(os.sep, "/")
        # Fixed typo: "datset_info_path" -> "dataset_info_path" (local variable only).
        dataset_info_path = cached_path(dataset_info_url, cache_dir=tmp_dir)
        self.assertTrue(os.path.exists(dataset_info_path))
def test_cached_path_offline():
    """Fetching a remote URL while offline mode is enabled must raise."""
    url = "https://huggingface.co"
    with pytest.raises(OfflineModeIsEnabled):
        cached_path(url)