Exemplo n.º 1
0
def test_download_repo_of_dir():
    """Download the repo directory linked from the googlenet_v1_cifar10 markdown."""
    googlenet = CellInfo("googlenet")
    googlenet.update('../mshub_res/assets/mindspore/gpu/1.0/googlenet_v1_cifar10.md')
    repo_url = googlenet.repo_link

    set_hub_dir('.cache')
    succeeded = _download_repo_from_url(repo_url, get_hub_dir())
    assert succeeded
Exemplo n.º 2
0
def test_load_weights_input_uid():
    """Test load_weights using a hub handle on a freshly downloaded GoogleNet.

    Downloads the googlenet directory of the mindspore models repo into the
    hub dir, builds the network from that cache and loads the
    ``mindspore/ascend/1.0/googlenet_v1_cifar10`` weights into it.
    """
    set_hub_dir('.cache')
    path = get_hub_dir()

    repo_link = 'https://gitee.com/mindspore/models/tree/master/official/cv/googlenet'
    # Fail here with a clear message rather than later, when the network
    # cannot be built from a cache directory that was never populated.
    downloaded = _download_repo_from_url(repo_link, path)
    assert downloaded, 'failed to download {}'.format(repo_link)

    net = _get_network_from_cache('GoogleNet', path + '/googlenet', 10)
    load_weights(net,
                 handle='mindspore/ascend/1.0/googlenet_v1_cifar10',
                 force_reload=True)
    print(net)
Exemplo n.º 3
0
def test_download_ckpt():
    """Download the googlenet_v1_cifar10 checkpoint and check it exists on disk."""
    cell = CellInfo("googlenet")
    cell.update('../mshub_res/assets/mindspore/gpu/1.0/googlenet_v1_cifar10.md')

    # Asset entry selected by ``cell.asset_id``: download link and expected digest.
    asset = cell.asset[cell.asset_id]
    link, expected_sha256 = asset['asset-link'], asset['asset-sha256']

    set_hub_dir('.cache')
    hub_dir = get_hub_dir()

    downloaded = _download_file_from_url(link,
                                         hash_sha256=expected_sha256,
                                         save_path=hub_dir)
    assert os.path.exists(downloaded)
Exemplo n.º 4
0
def _download_repo_from_url(url, path=None):
    """
    Download a git repository, or one directory of it, from a url.

    The url is parsed with ``get_repo_info_from_url``. A whole repository is
    fetched with the ``FULL_SHELL_PATH`` script; a directory inside a
    repository is fetched with the ``SPARSE_SHELL_PATH`` script. The script
    runs with a temporary git working directory that is removed afterwards.

    Args:
        url (str): Url of the repository or of a directory inside it.
        path (str, optional): Directory to store the download in. Defaults to
            ``get_hub_dir()`` evaluated at call time.

    Returns:
        bool, True if the shell script printed "succeed" on stdout.

    Raises:
        subprocess.CalledProcessError: If the shell script exits non-zero.
    """
    # Resolve the default lazily: a ``get_hub_dir()`` default in the signature
    # is evaluated once at import time and ignores later ``set_hub_dir`` calls.
    if path is None:
        path = get_hub_dir()
    _create_path_if_not_exists(path)

    repo_infos = get_repo_info_from_url(url)
    git_ssh = repo_infos["git_ssh"]
    model_path = repo_infos["dst_dir"]
    branch = repo_infos["branch"]
    # Whole repo vs. a directory of a repo select different download scripts.
    shell_path = FULL_SHELL_PATH if repo_infos["is_repo"] else SPARSE_SHELL_PATH

    with TemporaryDirectory() as git_dir:
        cmd = ["bash", shell_path, git_dir, path, model_path, git_ssh, branch]
        # List argv with shell=False: url-derived values are passed as
        # separate arguments and never interpreted by a shell.
        out = subprocess.check_output(cmd, shell=False)
    return "succeed" in out.decode('utf-8')
Exemplo n.º 5
0
def _download_file_from_url(url, hash_sha256=None, save_path=None):
    """
    Download a checkpoint file from a url and validate it.

    If a file with the same name already exists in ``save_path`` and its
    sha256 matches ``hash_sha256``, it is reused without downloading.
    Otherwise any existing file is removed and the url is downloaded again.

    Args:
       url(string): checkpoint url. Its last path component is used as the
           local file name.
       hash_sha256(string): expected hex sha256 digest of the file. When not
           given, no integrity check is done and an existing file is always
           downloaded again.
       save_path(string): directory to save the checkpoint in. Defaults to
           ``get_hub_dir()`` evaluated at call time.

    Returns:
       string, path of the downloaded (or reused) file.

    Raises:
       Exception: on HTTP/URL errors, a sha256 mismatch, a file larger than
           ``MAX_FILE_SIZE`` bytes, or a file suffix not in ``SUFFIX_LIST``.
    """

    def reporthook(block_num, block_size, total_size):
        # urlretrieve passes total_size <= 0 when the size is unknown; check it
        # before dividing so a 0 total cannot raise ZeroDivisionError.
        if total_size > 0:
            percent = min(block_num * block_size * 100.0 / total_size, 100)
            print("\rDownloading...%5.1f%%" % percent, end="")

    def sha256sum(file_name, expected_sha256):
        # Hash in chunks so large checkpoints are not loaded into memory at
        # once; ``with`` closes the file even if reading fails.
        digest = hashlib.sha256()
        with open(file_name, 'rb') as fp:
            for chunk in iter(lambda: fp.read(1024 * 1024), b''):
                digest.update(chunk)
        return digest.hexdigest() == expected_sha256

    # Resolve the default lazily so later ``set_hub_dir`` calls are honored.
    if save_path is None:
        save_path = get_hub_dir()
    _create_path_if_not_exists(os.path.realpath(save_path))
    ckpt_name = os.path.basename(url.split("/")[-1])
    # identify file exist or not
    file_path = os.path.join(save_path, ckpt_name)
    if os.path.isfile(file_path):
        if not hash_sha256:
            print('File already exists, but no sha256 is given to verify it. Will download again')
        elif sha256sum(file_path, hash_sha256):
            print('File already exists!')
            return file_path
        else:
            print('File already exists, but sha256 checking failed. Will download again')

    _remove_path_if_exists(file_path)

    # download the checkpoint file
    print('Downloading data from url {}'.format(url))
    try:
        opener = urllib.request.build_opener()
        opener.addheaders = [('User-Agent', 'Mozilla/5.0')]
        # NOTE: install_opener replaces the process-wide default urllib opener.
        urllib.request.install_opener(opener)
        urlretrieve(url, file_path, reporthook=reporthook)
    except HTTPError as e:
        raise Exception(e.code, e.msg, url) from e
    except URLError as e:
        raise Exception(e.errno, e.reason, url) from e
    print('\nDownload finished!')

    # Check file integrity; drop a corrupt file like the size/suffix checks do.
    if hash_sha256 and not sha256sum(file_path, hash_sha256):
        os.remove(file_path)
        raise Exception('INTEGRITY ERROR: File: {} is not integral'.format(file_path))

    # Check file size (bytes; reported in Mb).
    file_size = os.path.getsize(file_path)
    print('File size = %.2f Mb' % (file_size / 1024 / 1024))
    if file_size > MAX_FILE_SIZE:
        os.remove(file_path)
        raise Exception('SIZE ERROR: Download file is too large, '
                        'the max size is {}Mb'.format(MAX_FILE_SIZE / 1024 / 1024))

    # Check file type
    suffix = os.path.splitext(file_path)[1]
    if suffix not in SUFFIX_LIST:
        os.remove(file_path)
        raise Exception('SUFFIX ERROR: File: {} with Suffix: {} '
                        'can not be recognized'.format(file_path, suffix))
    return file_path
Exemplo n.º 6
0
def test_download_repo():
    """Download the googlenet directory of the mindspore models repo into the hub dir."""
    set_hub_dir('.cache')
    succeeded = _download_repo_from_url(
        'https://gitee.com/mindspore/models/tree/master/official/cv/googlenet',
        get_hub_dir(),
    )
    assert succeeded