def test_data_utils():
    """Exercises get_file against local file:// archives.

    Creates a text file plus tar.gz and zip archives of it in the
    current directory, then checks that get_file downloads, extracts
    and hash-validates each archive.
    """
    cache_entry = 'data_utils'

    # Build the fixture files in the working directory.
    with open('test.txt', 'w') as handle:
        handle.write('Float like a butterfly, sting like a bee.')
    with tarfile.open('test.tar.gz', 'w:gz') as archive:
        archive.add('test.txt')
    with zipfile.ZipFile('test.zip', 'w') as archive:
        archive.write('test.txt')

    # --- tar.gz round trip ---------------------------------------------
    tar_url = urljoin('file://', pathname2url(os.path.abspath('test.tar.gz')))
    fetched = get_file(cache_entry, tar_url, untar=True)
    archive_path = fetched + '.tar.gz'
    sha_digest = _hash_file(archive_path)
    md5_digest = _hash_file(archive_path, algorithm='md5')
    fetched = get_file(cache_entry, tar_url, md5_hash=md5_digest, untar=True)
    fetched = get_file(archive_path, tar_url, file_hash=sha_digest,
                       extract=True)
    assert os.path.exists(archive_path)
    assert validate_file(archive_path, sha_digest)
    assert validate_file(archive_path, md5_digest)
    os.remove(archive_path)
    os.remove('test.tar.gz')

    # --- zip round trip ------------------------------------------------
    zip_url = urljoin('file://', pathname2url(os.path.abspath('test.zip')))
    sha_digest = _hash_file('test.zip')
    md5_digest = _hash_file('test.zip', algorithm='md5')
    fetched = get_file(cache_entry, zip_url, md5_hash=md5_digest, extract=True)
    fetched = get_file(cache_entry, zip_url, file_hash=sha_digest,
                       extract=True)
    assert os.path.exists(fetched)
    assert validate_file(fetched, sha_digest)
    assert validate_file(fetched, md5_digest)
    os.remove(fetched)
    os.remove('test.txt')
    os.remove('test.zip')
def test_data_utils(in_tmpdir):
    """Round-trips local archives through get_file.

    Runs inside the ``in_tmpdir`` fixture: builds a text file plus
    tar.gz and zip archives of it, serves them via file:// URLs, and
    verifies caching, extraction and hash validation.
    """
    cache_name = 'data_utils'

    # Fixture files created in the temporary working directory.
    with open('test.txt', 'w') as fh:
        fh.write('Float like a butterfly, sting like a bee.')
    with tarfile.open('test.tar.gz', 'w:gz') as tgz:
        tgz.add('test.txt')
    with zipfile.ZipFile('test.zip', 'w') as zf:
        zf.write('test.txt')

    # tar.gz: fetch, hash, re-fetch with each hash kind, then validate.
    url = urljoin('file://', pathname2url(os.path.abspath('test.tar.gz')))
    result = get_file(cache_name, url, untar=True)
    tarball = result + '.tar.gz'
    digest_sha = _hash_file(tarball)
    digest_md5 = _hash_file(tarball, algorithm='md5')
    result = get_file(cache_name, url, md5_hash=digest_md5, untar=True)
    result = get_file(tarball, url, file_hash=digest_sha, extract=True)
    assert os.path.exists(tarball)
    assert validate_file(tarball, digest_sha)
    assert validate_file(tarball, digest_md5)
    os.remove(tarball)
    os.remove('test.tar.gz')

    # zip: same round trip via extract=True.
    url = urljoin('file://', pathname2url(os.path.abspath('test.zip')))
    digest_sha = _hash_file('test.zip')
    digest_md5 = _hash_file('test.zip', algorithm='md5')
    result = get_file(cache_name, url, md5_hash=digest_md5, extract=True)
    result = get_file(cache_name, url, file_hash=digest_sha, extract=True)
    assert os.path.exists(result)
    assert validate_file(result, digest_sha)
    assert validate_file(result, digest_md5)
    os.remove(result)
    os.remove('test.txt')
    os.remove('test.zip')
def test_data_utils(in_tmpdir):
    """Tests get_file from a url, plus extraction and validation.

    Also verifies that the download cache follows the KERAS_HOME
    environment variable, reloading the backend module so it re-reads
    the environment between checks.
    """
    dirname = 'data_utils'

    # Fixture files: a text file plus tar.gz / zip archives of it,
    # created inside the tmp working dir provided by in_tmpdir.
    with open('test.txt', 'w') as text_file:
        text_file.write('Float like a butterfly, sting like a bee.')
    with tarfile.open('test.tar.gz', 'w:gz') as tar_file:
        tar_file.add('test.txt')
    with zipfile.ZipFile('test.zip', 'w') as zip_file:
        zip_file.write('test.txt')

    # file:// URL lets get_file "download" from the local filesystem.
    origin = urljoin('file://', pathname2url(os.path.abspath('test.tar.gz')))

    # Default cache location: two levels up from the cached archive
    # should be the keras home directory.
    path = get_file(dirname, origin, untar=True)
    filepath = path + '.tar.gz'
    data_keras_home = os.path.dirname(
        os.path.dirname(os.path.abspath(filepath)))
    # NOTE(review): assumes load_backend._config_path lives directly under
    # the keras home directory -- confirm against load_backend.
    assert data_keras_home == os.path.dirname(load_backend._config_path)
    os.remove(filepath)

    # Point KERAS_HOME at a local '.keras' dir and reload the backend so
    # it picks up the new environment; downloads must now land there.
    _keras_home = os.path.join(os.path.abspath('.'), '.keras')
    if not os.path.exists(_keras_home):
        os.makedirs(_keras_home)
    os.environ['KERAS_HOME'] = _keras_home
    reload_module(load_backend)
    path = get_file(dirname, origin, untar=True)
    filepath = path + '.tar.gz'
    data_keras_home = os.path.dirname(
        os.path.dirname(os.path.abspath(filepath)))
    assert data_keras_home == os.path.dirname(load_backend._config_path)

    # Restore the default keras home and reload once more.
    os.environ.pop('KERAS_HOME')
    shutil.rmtree(_keras_home)
    reload_module(load_backend)
    path = get_file(dirname, origin, untar=True)
    filepath = path + '.tar.gz'

    # Hash round trip: both sha256 (default) and md5 digests validate.
    hashval_sha256 = _hash_file(filepath)
    hashval_md5 = _hash_file(filepath, algorithm='md5')
    path = get_file(dirname, origin, md5_hash=hashval_md5, untar=True)
    path = get_file(filepath, origin, file_hash=hashval_sha256, extract=True)
    assert os.path.exists(filepath)
    assert validate_file(filepath, hashval_sha256)
    assert validate_file(filepath, hashval_md5)
    os.remove(filepath)
    os.remove('test.tar.gz')

    # Same round trip for the zip archive via extract=True.
    origin = urljoin('file://', pathname2url(os.path.abspath('test.zip')))
    hashval_sha256 = _hash_file('test.zip')
    hashval_md5 = _hash_file('test.zip', algorithm='md5')
    path = get_file(dirname, origin, md5_hash=hashval_md5, extract=True)
    path = get_file(dirname, origin, file_hash=hashval_sha256, extract=True)
    assert os.path.exists(path)
    assert validate_file(path, hashval_sha256)
    assert validate_file(path, hashval_md5)

    # Clean up the cached archive, the extracted file next to it, and
    # the local fixtures.
    os.remove(path)
    os.remove(os.path.join(os.path.dirname(path), 'test.txt'))
    os.remove('test.txt')
    os.remove('test.zip')
def get_file(fname, origin, save_path, untar=False,
             md5_hash=None, cache_subdir='datasets',
             tar_folder_name=None):
    """Downloads a file from a URL if it is not already in the cache.

    Passing the MD5 hash will verify the file after download as well
    as if it is already present in the cache.
    Usually it downloads the file to save_path/cache_subdir/fname.

    Arguments
    ---------
    fname: name of the file
    origin: original URL of the file
    save_path: path under which cache_subdir is created
    untar: boolean, whether the file should be decompressed
    md5_hash: MD5 hash of the file for verification
    cache_subdir: directory being used as the cache
    tar_folder_name: string, name of the folder inside the archive when
        it differs from the archive's own base name (e.g. pass 'def'
        when abc.tar.gz contains 'def/').

    Returns
    -------
    Path to the cache directory containing the downloaded
    (and extracted) file.
    """
    # Fall back to ~/.kapre, then /tmp/.kapre, when save_path is not
    # writable (os.access is also False for a nonexistent path).
    datadir_base = save_path
    if not os.access(datadir_base, os.W_OK):
        datadir_base = os.path.expanduser(os.path.join('~', '.kapre'))
        # BUG FIX: the placeholder was never formatted before.
        print('Given path {} is not accessible. '
              'Trying to use ~/.kapre instead..'.format(save_path))
    if not os.access(datadir_base, os.W_OK):
        print('~/.kapre is not accessible, using /tmp/kapre instead.')
        datadir_base = os.path.join('/tmp', '.kapre')
    datadir = os.path.join(datadir_base, cache_subdir)
    if not os.path.exists(datadir):
        os.makedirs(datadir)

    if untar:
        assert fname.endswith('.tar.gz'), fname
        fpath = os.path.join(datadir, fname)
        if tar_folder_name:
            untar_fpath = os.path.join(datadir, tar_folder_name)
        else:
            # BUG FIX: str.rstrip('.tar.gz') strips a *character set*
            # (e.g. 'test.tar.gz' -> 'tes'); slice the suffix off instead.
            untar_fpath = fpath[:-len('.tar.gz')]
    else:
        fpath = os.path.join(datadir, fname)

    download = False
    if os.path.exists(fpath):
        # File found; verify integrity if a hash was provided.
        if md5_hash is not None:
            if not validate_file(fpath, md5_hash):
                print('A local file was found, just checked md5 hash, but it might be '
                      'incomplete or outdated')
                download = True
    else:
        download = True

    if download:
        print('Downloading data from', origin)

        class ProgressTracker(object):
            # Holds the progress bar across callbacks. BUG FIX: the old
            # progbar=None default only rebound a local inside the hook,
            # so a fresh Progbar was created and discarded on every call.
            progbar = None

        def dl_progress(count, block_size, total_size):
            if ProgressTracker.progbar is None:
                ProgressTracker.progbar = Progbar(total_size)
            else:
                ProgressTracker.progbar.update(count * block_size)

        error_msg = 'URL fetch failure on {}: {} -- {}'
        try:
            try:
                urlretrieve(origin, fpath, dl_progress)
            # BUG FIX: HTTPError subclasses URLError, so it must be
            # caught first or its branch is unreachable.
            except HTTPError as e:
                raise Exception(error_msg.format(origin, e.code, e.msg))
            except URLError as e:
                raise Exception(error_msg.format(origin, e.errno, e.reason))
        except (Exception, KeyboardInterrupt):
            # Remove a partially-downloaded file so a retry starts clean.
            if os.path.exists(fpath):
                os.remove(fpath)
            raise
        ProgressTracker.progbar = None

    if untar:
        if not os.path.exists(untar_fpath):
            print('Untaring file...')
            tfile = tarfile.open(fpath, 'r:gz')
            try:
                tfile.extractall(path=datadir)
            except (Exception, KeyboardInterrupt):
                # Drop a partial extraction so a retry starts clean.
                if os.path.exists(untar_fpath):
                    if os.path.isfile(untar_fpath):
                        os.remove(untar_fpath)
                    else:
                        shutil.rmtree(untar_fpath)
                raise
            finally:
                # BUG FIX: the archive was previously left open when
                # extraction raised.
                tfile.close()
    return datadir
def test_data_utils(in_tmpdir):
    """Tests get_file from a url, plus extraction and validation.

    Also verifies that the download cache follows the KERAS_HOME
    environment variable, reloading the backend module ``K`` so it
    re-reads the environment between checks.
    """
    dirname = 'data_utils'

    # Fixture files: a text file plus tar.gz / zip archives of it,
    # created inside the tmp working dir provided by in_tmpdir.
    with open('test.txt', 'w') as text_file:
        text_file.write('Float like a butterfly, sting like a bee.')
    with tarfile.open('test.tar.gz', 'w:gz') as tar_file:
        tar_file.add('test.txt')
    with zipfile.ZipFile('test.zip', 'w') as zip_file:
        zip_file.write('test.txt')

    # file:// URL lets get_file "download" from the local filesystem.
    origin = urljoin('file://', pathname2url(os.path.abspath('test.tar.gz')))

    # Default cache location: two levels up from the cached archive
    # should be the keras home directory.
    path = get_file(dirname, origin, untar=True)
    filepath = path + '.tar.gz'
    data_keras_home = os.path.dirname(os.path.dirname(os.path.abspath(filepath)))
    # NOTE(review): assumes K._config_path lives directly under the keras
    # home directory -- confirm against the backend module.
    assert data_keras_home == os.path.dirname(K._config_path)
    os.remove(filepath)

    # Point KERAS_HOME at a local '.keras' dir and reload the backend so
    # it picks up the new environment; downloads must now land there.
    _keras_home = os.path.join(os.path.abspath('.'), '.keras')
    if not os.path.exists(_keras_home):
        os.makedirs(_keras_home)
    os.environ['KERAS_HOME'] = _keras_home
    reload_module(K)
    path = get_file(dirname, origin, untar=True)
    filepath = path + '.tar.gz'
    data_keras_home = os.path.dirname(os.path.dirname(os.path.abspath(filepath)))
    assert data_keras_home == os.path.dirname(K._config_path)

    # Restore the default keras home and reload once more.
    os.environ.pop('KERAS_HOME')
    shutil.rmtree(_keras_home)
    reload_module(K)
    path = get_file(dirname, origin, untar=True)
    filepath = path + '.tar.gz'

    # Hash round trip: both sha256 (default) and md5 digests validate.
    hashval_sha256 = _hash_file(filepath)
    hashval_md5 = _hash_file(filepath, algorithm='md5')
    path = get_file(dirname, origin, md5_hash=hashval_md5, untar=True)
    path = get_file(filepath, origin, file_hash=hashval_sha256, extract=True)
    assert os.path.exists(filepath)
    assert validate_file(filepath, hashval_sha256)
    assert validate_file(filepath, hashval_md5)
    os.remove(filepath)
    os.remove('test.tar.gz')

    # Same round trip for the zip archive via extract=True.
    origin = urljoin('file://', pathname2url(os.path.abspath('test.zip')))
    hashval_sha256 = _hash_file('test.zip')
    hashval_md5 = _hash_file('test.zip', algorithm='md5')
    path = get_file(dirname, origin, md5_hash=hashval_md5, extract=True)
    path = get_file(dirname, origin, file_hash=hashval_sha256, extract=True)
    assert os.path.exists(path)
    assert validate_file(path, hashval_sha256)
    assert validate_file(path, hashval_md5)

    # Clean up the cached archive, the extracted file next to it, and
    # the local fixtures.
    os.remove(path)
    os.remove(os.path.join(os.path.dirname(path), 'test.txt'))
    os.remove('test.txt')
    os.remove('test.zip')