Example #1
import os

from octis.dataset.dataset import Dataset
# get_data_home and _pkl_filepath live in the OCTIS downloader module;
# the import path assumes the standard OCTIS package layout.
from octis.dataset.downloader import get_data_home, _pkl_filepath


def test_load_20ng():
    data_home = get_data_home(data_home=None)
    cache_path = _pkl_filepath(data_home, "20NewsGroup" + ".pkz")
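    # Drop any stale cache so the first fetch exercises the download path.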
    if os.path.exists(cache_path):
        os.remove(cache_path)

    dataset = Dataset()
    dataset.fetch_dataset("20NewsGroup")
    assert len(dataset.get_corpus()) == 16309
    assert len(dataset.get_labels()) == 16309
    assert os.path.exists(cache_path)

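    # A second fetch should load from the cache written by the first call.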
    dataset = Dataset()
    dataset.fetch_dataset("20NewsGroup")
    assert len(dataset.get_corpus()) == 16309
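
The test deletes and then re-creates a ".pkz" cache file. As the loader in Example #2 shows, that file is a zlib-compressed pickle. Below is a minimal sketch of the round-trip; the file name and payload are illustrative, with keys that follow the cache structure used in Example #2:

import codecs
import pickle

# Illustrative payload; the real cache is written by the OCTIS downloader.
cache = {"corpus": ["a b c", "d e"],
         "vocabulary": ["a", "b", "c", "d", "e"],
         "metadata": {}, "labels": ["pos", "neg"]}

# Write: pickle the dict, then zlib-compress it (the .pkz convention).
with open("demo.pkz", "wb") as f:
    f.write(codecs.encode(pickle.dumps(cache), "zlib_codec"))

# Read back: decompress, then unpickle, mirroring fetch_dataset below.
with open("demo.pkz", "rb") as f:
    restored = pickle.loads(codecs.decode(f.read(), "zlib_codec"))

assert restored == cache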
Example #2
    def fetch_dataset(self,
                      dataset_name,
                      data_home=None,
                      download_if_missing=True):
        """Load the filenames and data from a dataset.
        Parameters
        ----------
        dataset_name: name of the dataset to download or retrieve
        data_home : optional, default: None
            Specify a download and cache folder for the datasets. If None,
            all data is stored in '~/octis' subfolders.
        download_if_missing : optional, True by default
            If False, raise an IOError if the data is not locally available
            instead of trying to download the data from the source site.
        """

        data_home = get_data_home(data_home=data_home)
        cache_path = _pkl_filepath(data_home, dataset_name + ".pkz")
        dataset_home = join(data_home, dataset_name)
        cache = None
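        # Try the zlib-compressed pickle (.pkz) cache first.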
        if exists(cache_path):
            try:
                with open(cache_path, 'rb') as f:
                    compressed_content = f.read()
                uncompressed_content = codecs.decode(compressed_content,
                                                     'zlib_codec')
                cache = pickle.loads(uncompressed_content)
            except Exception as e:
                print(80 * '_')
                print('Cache loading failed')
                print(80 * '_')
                print(e)

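        # Cache miss or unreadable cache: download the dataset, or fail fast.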
        if cache is None:
            if download_if_missing:
                cache = download_dataset(dataset_name,
                                         target_dir=dataset_home,
                                         cache_path=cache_path)
            else:
                raise IOError(dataset_name + ' dataset not found')
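
        # Populate the dataset from the (possibly freshly written) cache.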
        self.is_cached = True
        self.__corpus = [d.split() for d in cache["corpus"]]
        self.__vocabulary = cache["vocabulary"]
        self.__metadata = cache["metadata"]
        self.dataset_path = cache_path
        self.__labels = cache["labels"]
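
For reference, a minimal usage sketch of the method above. It assumes the OCTIS package layout (Dataset in octis.dataset.dataset) and reuses the dataset name from Example #1; the empty data_home path is illustrative:

from octis.dataset.dataset import Dataset

dataset = Dataset()
# The first call downloads "20NewsGroup" and writes the .pkz cache;
# later calls decode the cached pickle instead of downloading.
dataset.fetch_dataset("20NewsGroup")
print(len(dataset.get_corpus()), len(dataset.get_labels()))

# With downloading disabled, a cache miss raises IOError.
try:
    Dataset().fetch_dataset("20NewsGroup", data_home="/tmp/empty-octis",
                            download_if_missing=False)
except IOError as err:
    print(err)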