def test_get_data_dir(mock_os_path_dirname):
    """Tests function ``get_data_dir()``."""
    from dirty_cat.datasets.utils import get_data_dir

    expected_return_value_default = Path("/user/directory/data")
    mock_os_path_dirname.return_value = "/user/directory/"
    assert get_data_dir() == expected_return_value_default

    expected_return_value_custom = expected_return_value_default / "tests"
    assert get_data_dir("tests") == expected_return_value_custom
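# NOTE: the ``mock_os_path_dirname`` fixture consumed above is defined outside
# this section. Below is a minimal sketch of such a fixture; it assumes that
# ``get_data_dir()`` builds its path from ``os.path.dirname`` inside
# ``dirty_cat.datasets.utils``, so the patch target is an assumption, not the
# library's documented API.
import pytest
from unittest import mock


@pytest.fixture
def mock_os_path_dirname():
    # Patch os.path.dirname where get_data_dir() is assumed to call it, so the
    # test above can control the directory it resolves to.
    with mock.patch("dirty_cat.datasets.utils.os.path.dirname") as patched:
        yield patched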
def test_fetch_dataset():
    from dirty_cat.datasets import fetching

    utils.MockResponse.set_with_content_length(True)
    datadir = os.path.join(datasets_utils.get_data_dir(), 'testdata')
    try:
        urlinfo = fetching.DatasetInfo(
            name='testdata',
            urlinfos=(fetching.UrlInfo(url='http://foo/data',
                                       filenames=('data',),
                                       uncompress=False,
                                       encoding='utf-8'),),
            main_file='data',
            source='http://foo/')
        fetching.fetch_dataset(urlinfo, show_progress=False)
        assert os.path.exists(os.path.join(datadir, 'data'))
        shutil.rmtree(datadir)

        # Test with zipped data.
        utils.MockResponse.set_to_zipfile(True)
        utils.MockResponse.set_with_content_length(False)
        urlinfo = fetching.DatasetInfo(
            name='testdata',
            urlinfos=(fetching.UrlInfo(url='http://foo/data.zip',
                                       filenames=('unzipped_data.txt',),
                                       uncompress=True,
                                       encoding='utf-8'),),
            main_file='unzipped_data.txt',
            source='http://foo/')
        fetching.fetch_dataset(urlinfo, show_progress=False)
        assert os.path.exists(os.path.join(datadir, 'unzipped_data.txt'))
    finally:
        if os.path.exists(datadir):
            shutil.rmtree(datadir)
def test_convert_file_to_utf8(monkeypatch):
    from dirty_cat.datasets import fetching

    datadir = os.path.join(datasets_utils.get_data_dir(), 'testdata')
    try:
        # Specify some content encoded in latin-1, and make sure the final
        # file contains the same content, but in utf-8.
        # Here, '\xe9' in latin-1 is '\xc3\xa9' in utf-8.
        with monkeypatch.context() as m:
            m.setattr(utils.MockResponse, "zipresult", False)
            m.setattr(utils.MockResponse, "_file_contents", b'\xe9')
            dataset_info = fetching.DatasetInfo(
                name='testdata',
                urlinfos=(fetching.UrlInfo(url='http://foo/data',
                                           filenames=('data',),
                                           uncompress=False,
                                           encoding='latin-1'),),
                main_file='data',
                source='http://foo/')
            info = fetching.fetch_dataset(dataset_info, show_progress=False)
            with open(info['path'], 'rb') as f:
                content = f.read()
            assert content == b'\xc3\xa9'
            os.unlink(info['path'])

            m.setattr(utils.MockResponse, "zipresult", True)
            dataset_info_with_zipfile = fetching.DatasetInfo(
                name='testdata',
                urlinfos=(fetching.UrlInfo(url='http://foo/data.zip',
                                           filenames=('unzipped_data.txt',),
                                           uncompress=True,
                                           encoding='latin-1'),),
                main_file='unzipped_data.txt',
                source='http://foo/')
            info_unzipped = fetching.fetch_dataset(
                dataset_info_with_zipfile, show_progress=False)
            with open(info_unzipped['path'], 'rb') as f:
                content_unzipped = f.read()
            assert content_unzipped == b'\xc3\xa9'
    finally:
        if os.path.exists(datadir):
            shutil.rmtree(datadir)
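# The tests above drive a ``utils.MockResponse`` test double defined outside
# this section. The outline below only mirrors the switches those tests toggle,
# inferred from the calls made here; how the double intercepts the actual HTTP
# layer is omitted, and the real helper certainly differs in detail.
class MockResponseOutline:
    # Hypothetical mirror of the attributes the tests rely on.
    with_content_length = True   # does the fake response report Content-Length?
    zipresult = False            # is the payload served as a zip archive?
    _file_contents = b' '        # raw bytes returned for the requested file

    @classmethod
    def set_with_content_length(cls, value):
        cls.with_content_length = value

    @classmethod
    def set_to_zipfile(cls, value):
        cls.zipresult = value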
def test_fetch_file_overwrite():
    utils.MockResponse.set_with_content_length(True)
    utils.MockResponse.set_to_zipfile(False)
    test_dir = datasets_utils.get_data_dir(name='test')
    from dirty_cat.datasets import fetching
    try:
        # Test that the filename is an md5 hash of the URL
        # if the URL ends with '/'.
        fil = fetching._fetch_file(url='http://foo/', data_dir=test_dir,
                                   overwrite=True, uncompress=False,
                                   show_progress=False)
        assert os.path.basename(fil) == datasets_utils.md5_hash('/')
        os.remove(fil)

        # Overwrite a non-existing file.
        fil = fetching._fetch_file(url='http://foo/testdata',
                                   data_dir=test_dir, overwrite=True,
                                   uncompress=False, show_progress=False)
        # Check that data_dir is actually used.
        assert os.path.dirname(fil) == test_dir
        assert os.path.exists(fil)
        with open(fil, 'r') as fp:
            assert fp.read() == ' '

        # Modify the content.
        with open(fil, 'w') as fp:
            fp.write('some content')

        # Don't overwrite the existing file.
        fil = fetching._fetch_file(url='http://foo/testdata',
                                   data_dir=test_dir, overwrite=False,
                                   uncompress=False, show_progress=False)
        assert os.path.exists(fil)
        with open(fil, 'r') as fp:
            assert fp.read() == 'some content'

        # Overwrite the existing file.
        fil = fetching._fetch_file(url='http://foo/testdata',
                                   data_dir=test_dir, overwrite=True,
                                   uncompress=False, show_progress=False)
        assert os.path.exists(fil)
        with open(fil, 'r') as fp:
            assert fp.read() == ' '

        # Modify the content, rename the file, pass the new name as argument,
        # and set overwrite to False.
        with open(fil, 'w') as fp:
            fp.write('some content')
        newf = 'moved_file'
        os.rename(fil, os.path.join(test_dir, newf))
        fetching._fetch_file(url='http://foo/testdata', filenames=(newf,),
                             data_dir=test_dir, overwrite=False,
                             uncompress=False, show_progress=False)
        # The original file has been removed and should not have been
        # downloaded again.
        assert not os.path.exists(fil)
        with open(os.path.join(test_dir, newf), 'r') as fp:
            assert fp.read() == 'some content'
        os.remove(os.path.join(test_dir, newf))

        # # create a zipfile with a file inside, remove the file, and
        # # zipd = os.path.join('testzip.zip')
        # with contextlib.closing(zipfile.ZipFile())
        # fetching._fetch_file(url='http://foo/', filenames=('test_filename',),
        #                      data_dir=test_dir, overwrite=False,
        #                      uncompress=False, show_progress=False)

        # Use a wrong md5 sum and check that it raises FileChangedError.
        try:
            fil = fetching._fetch_file(url='http://foo/testdata',
                                       data_dir=test_dir, overwrite=True,
                                       uncompress=False, show_progress=False,
                                       md5sum='1')
            # If no error was raised on the previous line, something is wrong:
            # a wrong md5 sum should raise an error. Raise ValueError to force
            # the except block to run anyway.
            raise ValueError
        except Exception as e:
            assert isinstance(e, fetching.FileChangedError)

        utils.MockResponse.set_with_content_length(False)
        # Content is written even when the response has no content size.
        fil = fetching._fetch_file(url='http://foo/testdata',
                                   data_dir=test_dir, overwrite=True,
                                   uncompress=False, show_progress=False)
        assert os.path.exists(fil)
        os.remove(fil)
    finally:
        shutil.rmtree(test_dir)
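# The try/except construct above for the bad md5 check is easy to misread. If
# the surrounding fixtures allow it, the same intent can be expressed with
# ``pytest.raises``; a sketch (``_check_wrong_md5_raises`` is a hypothetical
# helper, not part of dirty_cat):
import pytest


def _check_wrong_md5_raises(fetching, test_dir):
    # A wrong md5sum should make _fetch_file raise FileChangedError instead of
    # returning a path to the downloaded file.
    with pytest.raises(fetching.FileChangedError):
        fetching._fetch_file(url='http://foo/testdata', data_dir=test_dir,
                             overwrite=True, uncompress=False,
                             show_progress=False, md5sum='1')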
def get_test_data_dir() -> Path:
    from dirty_cat.datasets.utils import get_data_dir
    return get_data_dir("tests")
def fetch_openml_dataset(
    dataset_id: int,
    data_directory: Path = get_data_dir(),
) -> Dict[str, Any]:
    """
    Gets a dataset from OpenML (https://www.openml.org),
    or from the disk if already downloaded.

    Parameters
    ----------
    dataset_id: int
        The ID of the dataset to fetch.
    data_directory: Path
        Optional. A directory to save the data to.
        By default, the dirty_cat data directory.

    Returns
    -------
    Dict[str, Any]
        A dictionary containing:
          - ``description``: str
              The description of the dataset, as gathered from OpenML.
          - ``source``: str
              The dataset's URL from OpenML.
          - ``path``: pathlib.Path
              The local path leading to the dataset, saved as a CSV file.
    """
    # Make path absolute
    data_directory = data_directory.resolve()

    # Construct the paths to the gzip files containing the details
    # and the features of the dataset.
    details_gz_path = data_directory / DETAILS_DIRECTORY / f'{dataset_id}.gz'
    features_gz_path = data_directory / FEATURES_DIRECTORY / f'{dataset_id}.gz'

    if not details_gz_path.is_file() or not features_gz_path.is_file():
        # If the details file or the features file don't exist,
        # download the dataset.
        warnings.warn(
            f"Could not find the dataset {dataset_id} locally. "
            "Downloading it from OpenML; this might take a while... "
            "If it is interrupted, some files might be invalid/incomplete: "
            "if on the following run, the fetching raises errors, you can try "
            f"fixing this issue by deleting the directory {data_directory!r}.")
        _download_and_write_openml_dataset(dataset_id=dataset_id,
                                           data_directory=data_directory)
    details = _get_details(details_gz_path)

    # The file ID is required because the data file is named after this ID,
    # and not after the dataset's.
    file_id = details.file_id
    csv_path = data_directory / f'{details.name}.csv'

    data_gz_path = data_directory / DATA_DIRECTORY / f'{file_id}.gz'
    if not data_gz_path.is_file():
        # This is a double-check.
        # If the data file does not exist, download the dataset.
        _download_and_write_openml_dataset(dataset_id=dataset_id,
                                           data_directory=data_directory)

    if not csv_path.is_file():
        # If the CSV file does not exist, use the dataset
        # downloaded by ``fetch_openml()`` to construct it.
        features = _get_features(features_gz_path)
        _export_gz_data_to_csv(data_gz_path, csv_path, features)

    url = openml_url.format(ID=dataset_id)
    return {
        "description": details.description,
        "source": url,
        "path": csv_path.resolve(),
    }
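# Usage sketch for ``fetch_openml_dataset()``. The import path follows the
# tests above (``dirty_cat.datasets.fetching``) and the dataset ID is an
# arbitrary example, not a value shipped with dirty_cat.
import pandas as pd

from dirty_cat.datasets.fetching import fetch_openml_dataset

info = fetch_openml_dataset(dataset_id=42)   # hypothetical example ID
print(info["source"])                        # OpenML page for the dataset
print(info["description"])                   # description pulled from OpenML
df = pd.read_csv(info["path"])               # the locally cached CSV file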