def test_download_repo_of_dir():
    """Download the repository referenced by the googlenet asset markdown into the hub cache."""
    cell_info = CellInfo("googlenet")
    cell_info.update('../mshub_res/assets/mindspore/gpu/1.0/googlenet_v1_cifar10.md')
    repo_url = cell_info.repo_link
    set_hub_dir('.cache')
    hub_dir = get_hub_dir()
    assert _download_repo_from_url(repo_url, hub_dir)
def test_load_weights_input_uid():
    """test load_weights with a string handle after fetching the googlenet code."""
    repo_url = 'https://gitee.com/mindspore/models/tree/master/official/cv/googlenet'
    set_hub_dir('.cache')
    hub_dir = get_hub_dir()
    _download_repo_from_url(repo_url, hub_dir)
    # Build the network from the freshly downloaded code, then load its weights.
    network = _get_network_from_cache('GoogleNet', hub_dir + '/googlenet', 10)
    load_weights(network,
                 handle='mindspore/ascend/1.0/googlenet_v1_cifar10',
                 force_reload=True)
    print(network)
def test_download_ckpt():
    """Download the checkpoint listed in the googlenet asset markdown and check it exists on disk."""
    cell_info = CellInfo("googlenet")
    cell_info.update('../mshub_res/assets/mindspore/gpu/1.0/googlenet_v1_cifar10.md')
    # The asset entry carries both the download link and its expected sha256.
    asset_entry = cell_info.asset[cell_info.asset_id]
    link, sha256 = asset_entry['asset-link'], asset_entry["asset-sha256"]
    set_hub_dir('.cache')
    saved_path = _download_file_from_url(link, hash_sha256=sha256, save_path=get_hub_dir())
    assert os.path.exists(saved_path)
def _download_repo_from_url(url, path=None):
    """
    Download a git repository, or a directory inside one, from a url.

    Args:
        url (str): A url pointing at a repository or at a directory of a repository.
        path (str): A path to store the downloaded files. Defaults to the hub
            directory returned by ``get_hub_dir()`` at call time.

    Returns:
        bool, whether the download script printed "succeed" on stdout.

    Raises:
        subprocess.CalledProcessError: If the download script exits non-zero.
    """
    # Resolve the default lazily: ``path=get_hub_dir()`` in the signature is
    # evaluated once at import time and would ignore later set_hub_dir() calls.
    if path is None:
        path = get_hub_dir()
    _create_path_if_not_exists(path)

    repo_infos = get_repo_info_from_url(url)
    git_ssh = repo_infos["git_ssh"]
    model_path = repo_infos["dst_dir"]
    branch = repo_infos["branch"]
    # A whole repository uses FULL_SHELL_PATH; a directory inside a
    # repository uses SPARSE_SHELL_PATH.
    shell_path = FULL_SHELL_PATH if repo_infos["is_repo"] else SPARSE_SHELL_PATH

    with TemporaryDirectory() as git_dir:
        cmd = ["bash", shell_path, git_dir, path, model_path, git_ssh, branch]
        # Argument list with shell=False: nothing from the url is shell-interpreted.
        out = subprocess.check_output(cmd, shell=False)
    return "succeed" in out.decode('utf-8')
def _download_file_from_url(url, hash_sha256=None, save_path=None):
    """
    Download a checkpoint weight file from the given url.

    If a file with the same name already exists in ``save_path`` and its sha256
    matches ``hash_sha256``, it is reused. Otherwise it is removed and downloaded
    again.

    Args:
        url (str): Checkpoint url; its last path component is used as the file name.
        hash_sha256 (str): Expected sha256 hex digest of the file. When omitted,
            the integrity check is skipped and an existing file is re-downloaded.
        save_path (str): Directory to save the file into. Defaults to the hub
            directory returned by ``get_hub_dir()`` at call time.

    Returns:
        str, the path of the downloaded (or reused) file.

    Raises:
        Exception: On HTTP/URL errors, a sha256 mismatch, a file larger than
            ``MAX_FILE_SIZE``, or a suffix not in ``SUFFIX_LIST``. For the
            integrity, size and suffix errors the file is removed first.
    """
    # Resolve the default lazily: ``save_path=get_hub_dir()`` in the signature
    # is evaluated once at import time and would ignore later set_hub_dir() calls.
    if save_path is None:
        save_path = get_hub_dir()

    def reporthook(block_num, block_size, total_size):
        # urlretrieve may report total_size <= 0 (unknown/empty); test it before
        # dividing so such responses don't raise ZeroDivisionError.
        if total_size > 0:
            percent = min(block_num * block_size * 100.0 / total_size, 100)
            print("\rDownloading...%5.1f%%" % percent, end="")

    def sha256_matches(file_name, expected):
        # Hash in chunks so large checkpoints are never fully loaded into memory,
        # and close the file even if reading fails.
        digest = hashlib.sha256()
        with open(file_name, 'rb') as fp:
            for chunk in iter(lambda: fp.read(1024 * 1024), b''):
                digest.update(chunk)
        return digest.hexdigest() == expected

    _create_path_if_not_exists(os.path.realpath(save_path))
    ckpt_name = os.path.basename(url.split("/")[-1])
    file_path = os.path.join(save_path, ckpt_name)

    # Reuse an existing file only if it passes the integrity check.
    if os.path.isfile(file_path):
        if hash_sha256 and sha256_matches(file_path, hash_sha256):
            print('File already exists!')
            return file_path
        print('File already exists, but sha256 checking failed. Will download again')
        _remove_path_if_exists(file_path)

    # Download the checkpoint file.
    print('Downloading data from url {}'.format(url))
    try:
        opener = urllib.request.build_opener()
        opener.addheaders = [('User-Agent', 'Mozilla/5.0')]
        urllib.request.install_opener(opener)
        urlretrieve(url, file_path, reporthook=reporthook)
    except HTTPError as e:
        # HTTPError subclasses URLError, so it must be caught first.
        raise Exception(e.code, e.msg, url) from e
    except URLError as e:
        raise Exception(e.errno, e.reason, url) from e
    print('\nDownload finished!')

    # Check file integrity; drop a corrupt file like the other failed checks do.
    if hash_sha256 and not sha256_matches(file_path, hash_sha256):
        os.remove(file_path)
        raise Exception('INTEGRITY ERROR: File: {} is not integral'.format(file_path))

    # Check file size (reported in Mb).
    file_size = os.path.getsize(file_path)
    print('File size = %.2f Mb' % (file_size / 1024 / 1024))
    if file_size > MAX_FILE_SIZE:
        os.remove(file_path)
        raise Exception('SIZE ERROR: Download file is too large,'
                        'the max size is {}Mb'.format(MAX_FILE_SIZE / 1024 / 1024))

    # Check file type.
    suffix = os.path.splitext(file_path)[1]
    if suffix not in SUFFIX_LIST:
        os.remove(file_path)
        raise Exception('SUFFIX ERROR: File: {} with Suffix: {} '
                        'can not be recognized'.format(file_path, suffix))
    return file_path
def test_download_repo():
    """Download the googlenet directory of the models repository into the hub cache."""
    set_hub_dir('.cache')
    assert _download_repo_from_url(
        'https://gitee.com/mindspore/models/tree/master/official/cv/googlenet',
        get_hub_dir(),
    )