Ejemplo n.º 1
0
def get_and_check_clang_format():
    """
    Ensure a platform-appropriate clang-format binary exists at ``CLANG_FORMAT_PATH``.

    Downloads the binary (verifying its SHA1 hash against the expected value)
    if necessary, then marks it user-executable.

    Returns:
        bool: ``True`` on success; ``False`` if the host platform is
        unsupported or the download/verification fails.
    """
    # The platform must be present in both lookup tables; a miss in either
    # means we have no binary (or no expected hash) for this host.
    for table in (PLATFORM_TO_HASH, PLATFORM_TO_CF_URL):
        if HOST_PLATFORM not in table:
            print(f"Unsupported platform: {HOST_PLATFORM}")
            return False

    try:
        download_url(
            PLATFORM_TO_CF_URL[HOST_PLATFORM],
            CLANG_FORMAT_PATH,
            PLATFORM_TO_HASH[HOST_PLATFORM],
            hash_type="sha1",
        )
    except Exception as e:
        # A stale/corrupt file at the target path can cause a hash mismatch,
        # hence the hint to remove it before retrying.
        print(f"Download {CLANG_FORMAT_PATH} failed: {e}")
        print(f"Please remove {CLANG_FORMAT_PATH} and retry.")
        return False

    # Grant the owner execute permission on the downloaded binary.
    os.chmod(CLANG_FORMAT_PATH, os.stat(CLANG_FORMAT_PATH).st_mode | stat.S_IXUSR)
    print(f"Using clang-format located at {CLANG_FORMAT_PATH}")

    return True
Ejemplo n.º 2
0
def verify_metadata(
    meta_file: Optional[Union[str, Sequence[str]]] = None,
    filepath: Optional[PathLike] = None,
    create_dir: Optional[bool] = None,
    hash_val: Optional[str] = None,
    hash_type: Optional[str] = None,
    args_file: Optional[str] = None,
    **kwargs,
):
    """
    Verify the provided `metadata` file based on the predefined `schema`.
    `metadata` content must contain the `schema` field for the URL of schema file to download.
    The schema standard follows: http://json-schema.org/.

    Args:
        meta_file: filepath of the metadata file to verify, if `None`, must be provided in `args_file`.
            if it is a list of file paths, the content of them will be merged.
        filepath: file path to store the downloaded schema.
        create_dir: whether to create directories if not existing, default to `True`.
        hash_val: if not None, define the hash value to verify the downloaded schema file.
        hash_type: if not None, define the hash type to verify the downloaded schema file. Defaults to "md5".
        args_file: a JSON or YAML file to provide default values for all the args in this function.
            so that the command line inputs can be simplified.
        kwargs: other arguments for `jsonschema.validate()`. for more details:
            https://python-jsonschema.readthedocs.io/en/stable/validate/#jsonschema.validate.

    Raises:
        ValueError: if the metadata has no `schema` field, or if schema validation fails
            (the original ``ValidationError`` is chained as the cause).
    """

    # Merge explicit keyword arguments over defaults loaded from `args_file`.
    _args = _update_args(
        args=args_file,
        meta_file=meta_file,
        filepath=filepath,
        create_dir=create_dir,
        hash_val=hash_val,
        hash_type=hash_type,
        **kwargs,
    )
    _log_input_summary(tag="verify_metadata", args=_args)
    # _pop_args removes these keys from `_args`; whatever remains is forwarded
    # verbatim to jsonschema's `validate()` below.
    filepath_, meta_file_, create_dir_, hash_val_, hash_type_ = _pop_args(
        _args, "filepath", "meta_file", create_dir=True, hash_val=None, hash_type="md5"
    )

    check_parent_dir(path=filepath_, create_dir=create_dir_)
    metadata = ConfigParser.load_config_files(files=meta_file_)
    # The metadata must point at its own schema via a `schema` URL field.
    url = metadata.get("schema")
    if url is None:
        raise ValueError("must provide the `schema` field in the metadata for the URL of schema file.")
    download_url(url=url, filepath=filepath_, hash_val=hash_val_, hash_type=hash_type_, progress=True)
    schema = ConfigParser.load_config_file(filepath=filepath_)

    try:
        # the rest key-values in the _args are for `validate` API
        validate(instance=metadata, schema=schema, **_args)
    except ValidationError as e:  # pylint: disable=E0712
        # as the error message is very long, only extract the key information
        raise ValueError(
            re.compile(r".*Failed validating", re.S).findall(str(e))[0] + f" against schema `{url}`."
        ) from e
    logger.info("metadata is verified with no error.")
Ejemplo n.º 3
0
def _download_from_github(repo: str, download_path: Path, filename: str, progress: bool = True):
    """
    Download a release asset zip from a GitHub repository and extract it.

    Args:
        repo: repository spec in the form ``repo_owner/repo_name/release_tag``.
        download_path: directory in which to store and extract the asset.
        filename: asset name; ``.zip`` is appended if not already present.
        progress: whether to display a download progress bar.

    Raises:
        ValueError: if ``repo`` is not a three-part ``owner/name/tag`` string.
    """
    if len(repo.split("/")) != 3:
        raise ValueError("if source is `github`, repo should be in the form of `repo_owner/repo_name/release_tag`.")
    repo_owner, repo_name, tag_name = repo.split("/")
    if ".zip" not in filename:
        filename += ".zip"
    url = _get_git_release_url(repo_owner, repo_name, tag_name=tag_name, filename=filename)
    # Fix: the download target must use the asset's filename; the previous code
    # wrote to a literal "(unknown)" placeholder (a scrubbed f-string), so every
    # asset would clobber the same bogus file.
    filepath = download_path / f"{filename}"
    download_url(url=url, filepath=filepath, hash_val=None, progress=progress)
    extractall(filepath=filepath, output_dir=download_path, has_base=True)
Ejemplo n.º 4
0
def _load_state_dict(model: nn.Module, arch: str, progress: bool):
    """
    This function is used to load pretrained models.

    Looks up the checkpoint URL for ``arch`` in ``SE_NET_MODELS``, downloads it,
    remaps the upstream (torchvision-style) SENet parameter names onto this
    project's layer naming, then loads the matching entries into ``model``.

    Args:
        model: the SENet-family model to receive the pretrained weights.
        arch: architecture key, one of the names listed in the error below.
        progress: whether to display a download progress bar.

    Raises:
        ValueError: if ``arch`` is not a supported architecture name.
    """
    model_url = look_up_option(arch, SE_NET_MODELS, None)
    if model_url is None:
        raise ValueError(
            "only 'senet154', 'se_resnet50', 'se_resnet101',  'se_resnet152', 'se_resnext50_32x4d', "
            +
            "and se_resnext101_32x4d are supported to load pretrained weights."
        )

    # Patterns matching the upstream checkpoint's key layout; each capture
    # group is re-used in the re.sub replacement to build the new key name.
    pattern_conv = re.compile(r"^(layer[1-4]\.\d\.(?:conv)\d\.)(\w*)$")
    pattern_bn = re.compile(r"^(layer[1-4]\.\d\.)(?:bn)(\d\.)(\w*)$")
    pattern_se = re.compile(r"^(layer[1-4]\.\d\.)(?:se_module.fc1.)(\w*)$")
    pattern_se2 = re.compile(r"^(layer[1-4]\.\d\.)(?:se_module.fc2.)(\w*)$")
    pattern_down_conv = re.compile(
        r"^(layer[1-4]\.\d\.)(?:downsample.0.)(\w*)$")
    pattern_down_bn = re.compile(r"^(layer[1-4]\.\d\.)(?:downsample.1.)(\w*)$")

    # ``SE_NET_MODELS`` entries are either a plain URL string or a dict with
    # explicit "url"/"filename" keys (downloaded manually in the latter case).
    if isinstance(model_url, dict):
        download_url(model_url["url"], filepath=model_url["filename"])
        state_dict = torch.load(model_url["filename"], map_location=None)
    else:
        state_dict = load_state_dict_from_url(model_url, progress=progress)
    # Iterate over a snapshot of the keys: the dict is mutated in the loop.
    for key in list(state_dict.keys()):
        new_key = None
        if pattern_conv.match(key):
            new_key = re.sub(pattern_conv, r"\1conv.\2", key)
        elif pattern_bn.match(key):
            new_key = re.sub(pattern_bn, r"\1conv\2adn.N.\3", key)
        elif pattern_se.match(key):
            # upstream SE fc layers are 1x1 convs; squeeze the extra spatial
            # dims so the weights fit this project's Linear-style SE layer
            # (presumably — TODO confirm against the se_layer definition).
            state_dict[key] = state_dict[key].squeeze()
            new_key = re.sub(pattern_se, r"\1se_layer.fc.0.\2", key)
        elif pattern_se2.match(key):
            state_dict[key] = state_dict[key].squeeze()
            new_key = re.sub(pattern_se2, r"\1se_layer.fc.2.\2", key)
        elif pattern_down_conv.match(key):
            new_key = re.sub(pattern_down_conv, r"\1project.conv.\2", key)
        elif pattern_down_bn.match(key):
            new_key = re.sub(pattern_down_bn, r"\1project.adn.N.\2", key)
        if new_key:
            # rename in place: add the remapped key, drop the original.
            state_dict[new_key] = state_dict[key]
            del state_dict[key]

    model_dict = model.state_dict()
    # Keep only entries that exist in the target model with an identical shape,
    # then overlay them onto the model's current state and load strictly.
    state_dict = {
        k: v
        for k, v in state_dict.items()
        if (k in model_dict) and (model_dict[k].shape == state_dict[k].shape)
    }
    model_dict.update(state_dict)
    model.load_state_dict(model_dict)
Ejemplo n.º 5
0
 def setUp(self):
     """Download the test fixture from ``FILE_URL`` to ``FILE_PATH`` before each test."""
     # The third positional argument is the expected hash of the fixture
     # (presumably md5, download_url's usual default — TODO confirm).
     download_url(FILE_URL, FILE_PATH, "5a3cfd4fd725c50578ddb80b517b759f")
Ejemplo n.º 6
0
def download_url_or_skip_test(*args, **kwargs):
    """``download_url`` and skip the tests if any downloading error occurs.

    All arguments are forwarded verbatim to ``download_url``; the
    ``skip_if_downloading_fails`` context manager converts download errors
    into a test skip instead of a failure.
    """
    with skip_if_downloading_fails():
        download_url(*args, **kwargs)
Ejemplo n.º 7
0
def download(
    name: Optional[str] = None,
    bundle_dir: Optional[PathLike] = None,
    source: str = "github",
    repo: Optional[str] = None,
    url: Optional[str] = None,
    progress: bool = True,
    args_file: Optional[str] = None,
):
    """
    download bundle from the specified source or url. The bundle should be a zip file and it
    will be extracted after downloading.
    This function refers to:
    https://pytorch.org/docs/stable/_modules/torch/hub.html

    Typical usage examples:

    .. code-block:: bash

        # Execute this module as a CLI entry, and download bundle:
        python -m monai.bundle download --name "bundle_name" --source "github" --repo "repo_owner/repo_name/release_tag"

        # Execute this module as a CLI entry, and download bundle via URL:
        python -m monai.bundle download --name "bundle_name" --url <url>

        # Set default args of `run` in a JSON / YAML file, help to record and simplify the command line.
        # Other args still can override the default args at runtime.
        # The content of the JSON / YAML file is a dictionary. For example:
        # {"name": "spleen", "bundle_dir": "download", "source": ""}
        # then do the following command for downloading:
        python -m monai.bundle download --args_file "args.json" --source "github"

    Args:
        name: bundle name. If `None` and `url` is `None`, it must be provided in `args_file`.
        bundle_dir: target directory to store the downloaded data.
            Default is `bundle` subfolder under`torch.hub get_dir()`.
        source: place that saved the bundle.
            If `source` is `github`, the bundle should be within the releases.
        repo: repo name. If `None` and `url` is `None`, it must be provided in `args_file`.
            If `source` is `github`, it should be in the form of `repo_owner/repo_name/release_tag`.
            For example: `Project-MONAI/MONAI-extra-test-data/0.8.1`.
        url: url to download the data. If not `None`, data will be downloaded directly
            and `source` will not be checked.
            If `name` is `None`, filename is determined by `monai.apps.utils._basename(url)`.
        progress: whether to display a progress bar.
        args_file: a JSON or YAML file to provide default values for all the args in this function.
            so that the command line inputs can be simplified.

    Raises:
        ValueError: if ``source`` is ``github`` but ``name`` or ``repo`` is missing.
        NotImplementedError: if ``source`` is neither a direct ``url`` nor ``github``.
    """
    # Merge explicit keyword arguments over defaults loaded from `args_file`.
    _args = _update_args(args=args_file,
                         name=name,
                         bundle_dir=bundle_dir,
                         source=source,
                         repo=repo,
                         url=url,
                         progress=progress)

    _log_input_summary(tag="download", args=_args)
    name_, bundle_dir_, source_, repo_, url_, progress_ = _pop_args(
        _args,
        name=None,
        bundle_dir=None,
        source="github",
        repo=None,
        url=None,
        progress=True)

    bundle_dir_ = _process_bundle_dir(bundle_dir_)

    if url_ is not None:
        # Fix: use the merged `name_` (which may come from `args_file`), not the
        # raw `name` parameter — otherwise a name supplied only via `args_file`
        # was ignored and the filepath silently fell back to the URL basename.
        if name_ is not None:
            filepath = bundle_dir_ / f"{name_}.zip"
        else:
            filepath = bundle_dir_ / f"{_basename(url_)}"
        download_url(url=url_,
                     filepath=filepath,
                     hash_val=None,
                     progress=progress_)
        extractall(filepath=filepath, output_dir=bundle_dir_, has_base=True)
    elif source_ == "github":
        if name_ is None or repo_ is None:
            raise ValueError(
                f"To download from source: Github, `name` and `repo` must be provided, got {name_} and {repo_}."
            )
        _download_from_github(repo=repo_,
                              download_path=bundle_dir_,
                              filename=name_,
                              progress=progress_)
    else:
        raise NotImplementedError(
            f"Currently only download from provided URL in `url` or Github is implemented, got source: {source_}."
        )