Example 1
def test_get_data_dir(mock_os_path_dirname):
    """Tests function ``get_data_dir()``."""
    from pathlib import Path

    from dirty_cat.datasets.utils import get_data_dir

    expected_return_value_default = Path("/user/directory/data")

    mock_os_path_dirname.return_value = "/user/directory/"
    assert get_data_dir() == expected_return_value_default

    expected_return_value_custom = expected_return_value_default / "tests"
    assert get_data_dir("tests") == expected_return_value_custom
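# ``mock_os_path_dirname`` is a fixture defined elsewhere in the test
# module. A minimal sketch of what it could look like, assuming
# ``get_data_dir()`` derives its default path from ``os.path.dirname``
# (the patch target is an assumption, not confirmed by this excerpt):
import pytest
from unittest import mock

@pytest.fixture
def mock_os_path_dirname():
    # Patch ``os.path.dirname`` as seen from the utils module, and hand
    # the mock to the test so it can set the return value.
    with mock.patch("dirty_cat.datasets.utils.os.path.dirname") as mocked:
        yield mocked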
def test_fetch_dataset():
    """Tests ``fetch_dataset()`` on raw and on zipped data."""
    import os
    import shutil

    from dirty_cat.datasets import fetching

    # ``utils`` (providing ``MockResponse``) and ``datasets_utils`` come
    # from the test module's top-level imports, not shown in this excerpt.
    utils.MockResponse.set_with_content_length(True)
    datadir = os.path.join(datasets_utils.get_data_dir(), 'testdata')
    try:
        urlinfo = fetching.DatasetInfo(name='testdata',
                                       urlinfos=(fetching.UrlInfo(
                                           url='http://foo/data',
                                           filenames=('data',),
                                           uncompress=False,
                                           encoding='utf-8'),),
                                       main_file='data',
                                       source='http://foo/')
        fetching.fetch_dataset(urlinfo, show_progress=False)
        assert os.path.exists(os.path.join(datadir, 'data'))
        shutil.rmtree(datadir)

        # test with zipped data
        utils.MockResponse.set_to_zipfile(True)
        utils.MockResponse.set_with_content_length(False)
        urlinfo = fetching.DatasetInfo(name='testdata',
                                       urlinfos=(fetching.UrlInfo(
                                           url='http://foo/data.zip',
                                           filenames=('unzipped_data.txt',),
                                           uncompress=True,
                                           encoding='utf-8'),),
                                       main_file='unzipped_data.txt',
                                       source='http://foo/')
        fetching.fetch_dataset(urlinfo, show_progress=False)
        assert os.path.exists(os.path.join(datadir, 'unzipped_data.txt'))
    finally:
        if os.path.exists(datadir):
            shutil.rmtree(datadir)
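# ``utils.MockResponse`` is a helper defined in the test utilities. A
# minimal sketch of the switches the tests above flip (attribute names
# other than ``zipresult`` and ``_file_contents`` are assumptions):
class MockResponse:
    with_content_length = True   # whether to emit a Content-Length header
    zipresult = False            # serve a zip archive instead of raw bytes
    _file_contents = b' '        # payload returned by the fake download

    @classmethod
    def set_with_content_length(cls, value):
        cls.with_content_length = value

    @classmethod
    def set_to_zipfile(cls, value):
        cls.zipresult = value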
def test_convert_file_to_utf8(monkeypatch):
    """Tests that fetched files are converted to UTF-8."""
    import os
    import shutil

    from dirty_cat.datasets import fetching

    datadir = os.path.join(datasets_utils.get_data_dir(), 'testdata')
    try:
        # Specify some content encoded in latin-1, and make sure the final
        # file contains the same content re-encoded in utf-8. Here, '\xe9'
        # in latin-1 is '\xc3\xa9' in utf-8.

        with monkeypatch.context() as m:
            m.setattr(utils.MockResponse, "zipresult", False)
            m.setattr(utils.MockResponse, "_file_contents", b'\xe9')

            dataset_info = fetching.DatasetInfo(
                name='testdata',
                urlinfos=(fetching.UrlInfo(
                    url='http://foo/data',
                    filenames=('data',),
                    uncompress=False,
                    encoding='latin-1'),),
                main_file='data',
                source='http://foo/')

            info = fetching.fetch_dataset(dataset_info, show_progress=False)

            with open(info['path'], 'rb') as f:
                content = f.read()

            assert content == b'\xc3\xa9'
            os.unlink(info['path'])

            m.setattr(utils.MockResponse, "zipresult", True)

            dataset_info_with_zipfile = fetching.DatasetInfo(
                name='testdata',
                urlinfos=(fetching.UrlInfo(
                    url='http://foo/data.zip',
                    filenames=('unzipped_data.txt',),
                    uncompress=True,
                    encoding='latin-1'),),
                main_file='unzipped_data.txt',
                source='http://foo/')

            info_unzipped = fetching.fetch_dataset(
                dataset_info_with_zipfile, show_progress=False)

            with open(info_unzipped['path'], 'rb') as f:
                content_unzipped = f.read()

            assert content_unzipped == b'\xc3\xa9'

    finally:
        if os.path.exists(datadir):
            shutil.rmtree(datadir)
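# For reference, the latin-1 -> utf-8 re-encoding exercised above can be
# reproduced in isolation (a sketch, independent of fetching's actual
# conversion code):
def _reencode_latin1_to_utf8_example():
    raw = b'\xe9'                                 # 'é' encoded in latin-1
    utf8 = raw.decode('latin-1').encode('utf-8')  # decode, then re-encode
    assert utf8 == b'\xc3\xa9'                    # 'é' encoded in utf-8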
def test_fetch_file_overwrite():
    """Tests the ``overwrite`` logic of ``_fetch_file()``."""
    import os
    import shutil

    import pytest

    from dirty_cat.datasets import fetching

    utils.MockResponse.set_with_content_length(True)
    utils.MockResponse.set_to_zipfile(False)
    test_dir = datasets_utils.get_data_dir(name='test')
    try:
        # Test that the filename is the md5 hash of the URL
        # when the URL ends with '/'.
        fil = fetching._fetch_file(url='http://foo/', data_dir=test_dir,
                                   overwrite=True, uncompress=False,
                                   show_progress=False)
        assert os.path.basename(fil) == datasets_utils.md5_hash('/')
        os.remove(fil)

        # Overwrite a non-existing file.
        fil = fetching._fetch_file(url='http://foo/testdata', data_dir=test_dir,
                                   overwrite=True, uncompress=False,
                                   show_progress=False)

        # check if data_dir is actually used
        assert os.path.dirname(fil) == test_dir
        assert os.path.exists(fil)
        with open(fil, 'r') as fp:
            assert fp.read() == ' '

        # Modify content
        with open(fil, 'w') as fp:
            fp.write('some content')

        # Don't overwrite existing file.
        fil = fetching._fetch_file(url='http://foo/testdata', data_dir=test_dir,
                                   overwrite=False, uncompress=False,
                                   show_progress=False)
        assert os.path.exists(fil)
        with open(fil, 'r') as fp:
            assert fp.read() == 'some content'

        # Overwrite existing file.
        fil = fetching._fetch_file(url='http://foo/testdata', data_dir=test_dir,
                                   overwrite=True, uncompress=False,
                                   show_progress=False)
        assert os.path.exists(fil)
        with open(fil, 'r') as fp:
            assert fp.read() == ' '

        # Modify the content, rename the file, pass the new name
        # as an argument, and set overwrite to False.
        with open(fil, 'w') as fp:
            fp.write('some content')
        newf = 'moved_file'
        os.rename(fil, os.path.join(test_dir, newf))
        fetching._fetch_file(url='http://foo/testdata',
                             filenames=(newf,),
                             data_dir=test_dir,
                             overwrite=False, uncompress=False,
                             show_progress=False)
        # The original file has been removed and should not have been
        # downloaded again.
        assert not os.path.exists(fil)
        with open(os.path.join(test_dir, newf), 'r') as fp:
            assert fp.read() == 'some content'
        os.remove(os.path.join(test_dir, newf))

        # TODO: test ``_fetch_file`` against a zip archive containing a
        # file that has been removed from disk.

        # A wrong md5 sum should raise a FileChangedError.
        with pytest.raises(fetching.FileChangedError):
            fetching._fetch_file(url='http://foo/testdata',
                                 data_dir=test_dir,
                                 overwrite=True, uncompress=False,
                                 show_progress=False, md5sum='1')
        utils.MockResponse.set_with_content_length(False)
        # Content should still be written when the response carries
        # no Content-Length header.
        fil = fetching._fetch_file(url='http://foo/testdata', data_dir=test_dir,
                                   overwrite=True, uncompress=False,
                                   show_progress=False)
        assert os.path.exists(fil)
        os.remove(fil)
    finally:
        shutil.rmtree(test_dir)
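# ``datasets_utils.md5_hash`` derives a filename from text (here, the
# trailing '/' of the URL). A plausible sketch, assuming it is a thin
# wrapper around hashlib (the real helper may differ):
import hashlib

def md5_hash(string: str) -> str:
    # Hex digest of the md5 of the given text.
    return hashlib.md5(string.encode('utf-8')).hexdigest()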
Example 5
from pathlib import Path


def get_test_data_dir() -> Path:
    """Returns the data directory used for tests."""
    from dirty_cat.datasets.utils import get_data_dir
    return get_data_dir("tests")
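# A sketch of ``get_data_dir`` itself, consistent with the assertions of
# Example 1 (the real implementation may differ in details):
import os

def get_data_dir(name=None) -> Path:
    # Default to a ``data`` directory next to this module's own file;
    # an optional ``name`` selects a subdirectory.
    path = Path(os.path.dirname(__file__)) / "data"
    if name is not None:
        path = path / name
    return path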
Example 6
# Module-level imports and constants (``warnings``, ``Path``, ``Dict``,
# ``Any``, ``get_data_dir``, ``DETAILS_DIRECTORY``, ``FEATURES_DIRECTORY``,
# ``DATA_DIRECTORY``, ``openml_url``) and the private helpers
# (``_download_and_write_openml_dataset``, ``_get_details``,
# ``_get_features``, ``_export_gz_data_to_csv``) are defined elsewhere
# in the module and elided from this excerpt.
def fetch_openml_dataset(
        dataset_id: int,
        data_directory: Path = get_data_dir(),
) -> Dict[str, Any]:
    """
    Gets a dataset from OpenML (https://www.openml.org),
    or from the disk if already downloaded.

    Parameters
    ----------
    dataset_id: int
        The ID of the dataset to fetch.
    data_directory: Path
        Optional. A directory to save the data to.
        By default, the dirty_cat data directory.

    Returns
    -------
    Dict[str, Any]
        A dictionary containing:
          - ``description``: str
              The description of the dataset,
              as gathered from OpenML.
          - ``source``: str
              The dataset's URL from OpenML.
          - ``path``: pathlib.Path
              The local path leading to the dataset,
              saved as a CSV file.

    """
    # Make path absolute
    data_directory = data_directory.resolve()

    # Construct the path to the gzip file containing the details on a dataset.
    details_gz_path = data_directory / DETAILS_DIRECTORY / f'{dataset_id}.gz'
    features_gz_path = data_directory / FEATURES_DIRECTORY / f'{dataset_id}.gz'

    if not details_gz_path.is_file() or not features_gz_path.is_file():
        # If the details file or the features file don't exist,
        # download the dataset.
        warnings.warn(
            f"Could not find dataset {dataset_id} locally. "
            "Downloading it from OpenML; this might take a while... "
            "If the download is interrupted, some files might be invalid "
            "or incomplete; if fetching raises errors on the next run, "
            f"try deleting the directory {data_directory!r}.")
        _download_and_write_openml_dataset(dataset_id=dataset_id,
                                           data_directory=data_directory)
    details = _get_details(details_gz_path)

    # The file ID is required because the data file is named after this ID,
    # not after the dataset ID.
    file_id = details.file_id
    csv_path = data_directory / f'{details.name}.csv'

    data_gz_path = data_directory / DATA_DIRECTORY / f'{file_id}.gz'

    if not data_gz_path.is_file():
        # This is a double-check.
        # If the data file does not exist, download the dataset.
        _download_and_write_openml_dataset(dataset_id=dataset_id,
                                           data_directory=data_directory)

    if not csv_path.is_file():
        # If the CSV file does not exist, use the dataset
        # downloaded by ``fetch_openml()`` to construct it.
        features = _get_features(features_gz_path)
        _export_gz_data_to_csv(data_gz_path, csv_path, features)

    url = openml_url.format(ID=dataset_id)

    return {
        "description": details.description,
        "source": url,
        "path": csv_path.resolve()
    }
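
# A minimal usage sketch (the dataset ID below is illustrative; any valid
# OpenML dataset ID works):
if __name__ == "__main__":
    import pandas as pd

    info = fetch_openml_dataset(dataset_id=42803)
    print(info["source"])           # the dataset's page on OpenML
    df = pd.read_csv(info["path"])  # load the locally cached CSV
    print(df.head())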