import os
import shutil
import tempfile
import uuid
import warnings
import zipfile
from typing import Optional

import requests

try:
    import tqdm
except ImportError:
    tqdm = None

# Helpers such as S3_PREFIX, check_sha1, sha1sum, replace_file,
# try_import_boto3 and _get_repo_file_url are defined elsewhere in
# the surrounding package.


def _get_data(self):
    """Load data from the file. Does nothing if data was loaded before."""
    (data_archive_name, archive_hash), (data_name, data_hash) \
        = self._data_file[self._version][self._segment]
    data_path = os.path.join(self._root, data_name)

    if not os.path.exists(data_path) or not check_sha1(data_path, data_hash):
        with tempfile.TemporaryDirectory(dir=self._root) as temp_dir:
            file_path = download(_get_repo_file_url('gluon/dataset/squad',
                                                    data_archive_name),
                                 path=temp_dir, sha1_hash=archive_hash)
            with zipfile.ZipFile(file_path, 'r') as zf:
                for member in zf.namelist():
                    filename = os.path.basename(member)
                    if filename:
                        dest = os.path.join(self._root, filename)
                        # Extract to a uniquely named temporary sibling, then
                        # atomically move it into place so concurrent readers
                        # never observe a partially written file.
                        temp_dst = dest + str(uuid.uuid4())
                        with zf.open(member) as source:
                            with open(temp_dst, 'wb') as target:
                                shutil.copyfileobj(source, target)
                        replace_file(temp_dst, dest)
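# The extract-then-rename pattern above generalizes beyond SQuAD. Below is a
# minimal sketch of the same idea using only the standard library, with
# `os.replace` standing in for this module's `replace_file` helper. The rename
# is atomic when the temporary file and the destination live on the same file
# system, which is why the temporary file is created as a sibling of `dest`.
# (`extract_atomically` is an illustrative name, not part of this module.)
def extract_atomically(archive_path, dest_root):
    """Extract every file in `archive_path` into `dest_root` atomically."""
    with zipfile.ZipFile(archive_path, 'r') as zf:
        for member in zf.namelist():
            filename = os.path.basename(member)
            if not filename:  # skip directory entries
                continue
            dest = os.path.join(dest_root, filename)
            # Write to a uniquely named sibling on the same file system ...
            temp_dst = dest + str(uuid.uuid4())
            with zf.open(member) as source, open(temp_dst, 'wb') as target:
                shutil.copyfileobj(source, target)
            # ... then publish it with a single atomic rename.
            os.replace(temp_dst, dest)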
def download(url: str, path: Optional[str] = None, overwrite: Optional[bool] = False,
             sha1_hash: Optional[str] = None, retries: Optional[int] = 5,
             verify_ssl: Optional[bool] = True) -> str:
    """Download a given URL.

    Parameters
    ----------
    url
        URL to download.
    path
        Destination path to store the downloaded file. By default stores to the
        current directory with the same name as in the url.
    overwrite
        Whether to overwrite the destination file if it already exists.
    sha1_hash
        Expected sha1 hash in hexadecimal digits. An existing file is ignored
        when a hash is specified but doesn't match.
    retries
        The number of times to attempt the download in case of failure or
        non-200 return codes.
    verify_ssl
        Verify SSL certificates.

    Returns
    -------
    fname
        The file path of the downloaded file.
    """
    is_s3 = url.startswith(S3_PREFIX)
    if is_s3:
        boto3 = try_import_boto3()
        s3 = boto3.resource('s3')
        components = url[len(S3_PREFIX):].split('/')
        if len(components) < 2:
            raise ValueError('Invalid S3 url. Received url={}'.format(url))
        s3_bucket_name = components[0]
        s3_key = '/'.join(components[1:])
    if path is None:
        fname = url.split('/')[-1]
        # Empty filenames are invalid
        assert fname, 'Can\'t construct file-name from this URL. ' \
                      'Please set the `path` option manually.'
    else:
        path = os.path.expanduser(path)
        if os.path.isdir(path):
            fname = os.path.join(path, url.split('/')[-1])
        else:
            fname = path
    assert retries >= 0, "Number of retries should be at least 0, currently it's {}".format(
        retries)

    if not verify_ssl:
        warnings.warn(
            'Unverified HTTPS request is being made (verify_ssl=False). '
            'Adding certificate verification is strongly advised.')

    if overwrite or not os.path.exists(fname) or (
            sha1_hash and not sha1sum(fname) == sha1_hash):
        dirname = os.path.dirname(os.path.abspath(os.path.expanduser(fname)))
        if not os.path.exists(dirname):
            os.makedirs(dirname, exist_ok=True)
        while retries + 1 > 0:
            # Disable pylint's too-broad-exception warning
            # pylint: disable=W0703
            try:
                print('Downloading {} from {}...'.format(fname, url))
                if is_s3:
                    response = s3.meta.client.head_object(Bucket=s3_bucket_name,
                                                          Key=s3_key)
                    total_size = int(response.get('ContentLength', 0))
                    random_uuid = str(uuid.uuid4())
                    tmp_path = '{}.{}'.format(fname, random_uuid)
                    if tqdm is not None:
                        # tqdm callback: forward the number of bytes written
                        # by boto3 to the progress bar.
                        def hook(t_obj):
                            def inner(bytes_amount):
                                t_obj.update(bytes_amount)
                            return inner
                        with tqdm.tqdm(total=total_size, unit='iB', unit_scale=True) as t:
                            s3.meta.client.download_file(s3_bucket_name, s3_key, tmp_path,
                                                         Callback=hook(t))
                    else:
                        s3.meta.client.download_file(s3_bucket_name, s3_key, tmp_path)
                else:
                    r = requests.get(url, stream=True, verify=verify_ssl)
                    if r.status_code != 200:
                        raise RuntimeError(
                            'Failed downloading url {}'.format(url))
                    # Create a uuid for the temporary file
                    random_uuid = str(uuid.uuid4())
                    total_size = int(r.headers.get('content-length', 0))
                    chunk_size = 1024
                    if tqdm is not None:
                        t = tqdm.tqdm(total=total_size, unit='iB', unit_scale=True)
                    with open('{}.{}'.format(fname, random_uuid), 'wb') as f:
                        for chunk in r.iter_content(chunk_size=chunk_size):
                            if chunk:  # filter out keep-alive new chunks
                                if tqdm is not None:
                                    t.update(len(chunk))
                                f.write(chunk)
                    if tqdm is not None:
                        t.close()
                # If the target file already exists (created by another
                # process) and matches the expected hash, delete the temporary
                # file; otherwise move the temporary file into place.
                if not os.path.exists(fname) or (
                        sha1_hash and not sha1sum(fname) == sha1_hash):
                    # Atomic operation within the same file system
                    replace_file('{}.{}'.format(fname, random_uuid), fname)
                else:
                    try:
                        os.remove('{}.{}'.format(fname, random_uuid))
                    except OSError:
                        pass
                    finally:
                        warnings.warn(
                            'File {} already exists in the file system, '
                            'so the downloaded file is deleted'.format(fname))
                if sha1_hash and not sha1sum(fname) == sha1_hash:
                    raise UserWarning(
                        'File {} is downloaded but the content hash does not match.'
                        ' The repo may be outdated or download may be incomplete. '
                        'If the "repo_url" is overridden, consider switching to '
                        'the default repo.'.format(fname))
                break
            except Exception as e:
                retries -= 1
                if retries <= 0:
                    raise e
                print('download failed due to {}, retrying, {} attempt{} left'
                      .format(repr(e), retries, 's' if retries > 1 else ''))
    return fname
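# A hedged usage sketch for `download`; the URL is a placeholder, not a real
# artifact. With no `sha1_hash` the checksum check is skipped, so an existing
# file at the destination is reused rather than re-downloaded; pass
# `overwrite=True` to force a fresh download.
if __name__ == '__main__':
    fname = download('https://example.com/datasets/squad-v1.1.zip',  # placeholder URL
                     path='./downloads', retries=3)
    print('Stored at {}'.format(fname))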