Beispiel #1
0
def test_cached_path_local(text_file):
    """Check that ``cached_path`` hands back an existing local path unchanged.

    Two spellings are exercised: an absolute path (the resolved ``text_file``
    fixture) and a path relative to the current working directory (this test
    module itself).
    """
    # Absolute path: the resolved fixture must come back verbatim.
    absolute_path = str(Path(text_file).resolve())
    assert cached_path(absolute_path) == absolute_path

    # Relative path: this module expressed relative to the working directory.
    # relative_to() raises ValueError when the module is not located under cwd.
    this_module = Path(__file__).resolve()
    relative_path = str(this_module.relative_to(Path(os.getcwd())))
    assert cached_path(relative_path) == relative_path
Beispiel #2
0
def test_cached_path_missing_local(tmp_path):
    """Check that ``cached_path`` raises ``FileNotFoundError`` for absent local files."""
    candidates = (
        str(tmp_path.resolve() / "__missing_file__.txt"),  # absolute path
        "./__missing_file__.txt",  # relative path
    )
    for missing_file in candidates:
        with pytest.raises(FileNotFoundError):
            cached_path(missing_file)
 def test_load_dataset_distributed(self):
     """Run ``distributed_load_dataset`` on the dummy csv data inside a process pool.

     The dummy archive is extracted through ``cached_path`` into a temporary
     directory; the same ``(name, tmp_dir, datafiles)`` argument tuple is then
     passed to one asynchronous call and afterwards to ``num_workers`` mapped
     calls.
     """
     worker_count = 5
     with tempfile.TemporaryDirectory() as cache_root:
         dataset_name = "csv"
         archive_path = os.path.join(
             "datasets", dataset_name, "dummy", "0.0.0", "dummy_data.zip")
         extracted_dir = cached_path(
             archive_path,
             cache_dir=cache_root,
             extract_compressed_file=True,
             force_extract=True,
         )
         split_files = {
             "train": os.path.join(extracted_dir, "dummy_data/train.csv"),
             "dev": os.path.join(extracted_dir, "dummy_data/dev.csv"),
             "test": os.path.join(extracted_dir, "dummy_data/test.csv"),
         }
         worker_args = (dataset_name, cache_root, split_files)
         with Pool(processes=worker_count) as pool:
             # One asynchronous load first; it must finish within 20 seconds.
             pending = pool.apply_async(distributed_load_dataset,
                                        (worker_args, ))
             pending.get(timeout=20)
             del pending
             # Then submit the same arguments worker_count times.
             loaded = pool.map(distributed_load_dataset,
                               [worker_args] * worker_count)
             # Drop the returned datasets one at a time.
             while loaded:
                 loaded.pop()
Beispiel #4
0
def test_cached_path_extract(xz_file, tmp_path, text_file):
    """Check that extraction through ``cached_path`` reproduces ``text_file``.

    ``xz_file`` is passed to ``cached_path`` with ``extract_compressed_file=True``
    and a cache directory under ``tmp_path``; the returned file's content must
    equal the content of ``text_file``.
    """
    config = DownloadConfig(cache_dir=tmp_path / "cache", extract_compressed_file=True)
    extracted_filename = cached_path(xz_file, download_config=config)
    with open(extracted_filename) as extracted, open(text_file) as reference:
        actual_content = extracted.read()
        expected_content = reference.read()
    assert actual_content == expected_content
Beispiel #5
0
def test_cached_path_extract(compression_format, gz_file, xz_file, zstd_path, tmp_path, text_file):
    """Check that extraction through ``cached_path`` reproduces ``text_file``.

    ``compression_format`` selects which compressed fixture ("gzip", "xz" or
    "zstd") is handed to ``cached_path`` with ``extract_compressed_file=True``;
    the returned file's content must equal the content of ``text_file``.
    """
    fixture_by_format = {"gzip": gz_file, "xz": xz_file, "zstd": zstd_path}
    compressed_path = str(fixture_by_format[compression_format])
    config = DownloadConfig(cache_dir=tmp_path / "cache", extract_compressed_file=True)
    extracted_path = cached_path(compressed_path, download_config=config)
    with open(extracted_path) as extracted, open(text_file) as reference:
        actual_content = extracted.read()
        expected_content = reference.read()
    assert actual_content == expected_content
Beispiel #6
0
def test_extracted_datasets_path(default_extracted, default_cache_dir, xz_file, tmp_path, monkeypatch):
    """Check the last two directory components of the extraction location.

    ``default_extracted`` decides whether the ``datasets.config`` extraction
    settings are left alone or monkeypatched to custom values;
    ``default_cache_dir`` decides whether ``DownloadConfig`` receives a custom
    ``cache_dir``. The parent directory of the file returned by ``cached_path``
    must end with the two path components expected for that combination.
    """
    custom_cache_dir = "custom_cache"
    custom_extracted_dir = "custom_extracted_dir"
    custom_extracted_path = tmp_path / "custom_extracted_path"

    # Override the extraction settings before the download config is built.
    if not default_extracted:
        monkeypatch.setattr("datasets.config.EXTRACTED_DATASETS_DIR", custom_extracted_dir)
        monkeypatch.setattr("datasets.config.EXTRACTED_DATASETS_PATH", str(custom_extracted_path))

    if default_cache_dir:
        download_config = DownloadConfig(extract_compressed_file=True)
        if default_extracted:
            expected = ("downloads", "extracted")
        else:
            expected = custom_extracted_path.parts[-2:]
    else:
        download_config = DownloadConfig(cache_dir=tmp_path / custom_cache_dir, extract_compressed_file=True)
        if default_extracted:
            expected = (custom_cache_dir, "extracted")
        else:
            expected = (custom_cache_dir, custom_extracted_dir)

    extracted_file_path = cached_path(xz_file, download_config=download_config)
    assert Path(extracted_file_path).parent.parts[-2:] == expected
Beispiel #7
0
    def test_dataset_info_available(self, dataset, config_name):
        """Check that the dataset-info file for a builder config can be fetched.

        Builds the dataset module from the local ``datasets/<dataset>`` path,
        instantiates its builder class for ``config_name``, derives the info
        file URL under ``HF_GCP_BASE_URL`` and asserts that ``cached_path``
        returns a path that exists on disk.
        """
        with TemporaryDirectory() as cache_root:
            script_path = os.path.join("datasets", dataset)
            dataset_module = dataset_module_factory(
                script_path, cache_dir=cache_root, local_files_only=True)

            builder_cls = import_main_class(dataset_module.module_path, dataset=True)
            builder_instance: DatasetBuilder = builder_cls(
                cache_dir=cache_root,
                name=config_name,
                hash=dataset_module.hash,
            )

            relative_dir = builder_instance._relative_data_dir(with_hash=False)
            joined = os.path.join(HF_GCP_BASE_URL, relative_dir, config.DATASET_INFO_FILENAME)
            # os.path.join uses the platform separator; normalise it to "/" for the URL.
            dataset_info_url = joined.replace(os.sep, "/")

            dataset_info_path = cached_path(dataset_info_url, cache_dir=cache_root)
            self.assertTrue(os.path.exists(dataset_info_path))
Beispiel #8
0
def test_cached_path_offline():
    """Check that ``cached_path`` raises ``OfflineModeIsEnabled`` for a remote URL.

    NOTE(review): presumably offline mode is switched on by a decorator or
    fixture that is not visible in this snippet -- confirm.
    """
    remote_url = "https://huggingface.co"
    with pytest.raises(OfflineModeIsEnabled):
        cached_path(remote_url)