def get_and_check_clang_format():
    """
    Make sure a clang-format binary for the host platform is available.

    The binary for ``HOST_PLATFORM`` is requested through ``download_url``,
    which is handed the expected SHA1 hash from ``PLATFORM_TO_HASH`` so the
    file at ``CLANG_FORMAT_PATH`` can be checked against it. Afterwards the
    owner-execute permission bit is set on the file.

    Returns:
        ``True`` if the binary is in place and executable, ``False`` if the
        platform is unsupported or the download step raised an exception.
    """
    # A supported platform needs both an expected hash and a download URL.
    for platform_table in (PLATFORM_TO_HASH, PLATFORM_TO_CF_URL):
        if HOST_PLATFORM not in platform_table:
            print(f"Unsupported platform: {HOST_PLATFORM}")
            return False

    try:
        download_url(
            PLATFORM_TO_CF_URL[HOST_PLATFORM],
            CLANG_FORMAT_PATH,
            PLATFORM_TO_HASH[HOST_PLATFORM],
            hash_type="sha1",
        )
    except Exception as err:
        print(f"Download {CLANG_FORMAT_PATH} failed: {err}")
        print(f"Please remove {CLANG_FORMAT_PATH} and retry.")
        return False

    # Keep the existing permission bits and add owner-execute on top.
    current_mode = os.stat(CLANG_FORMAT_PATH).st_mode
    os.chmod(CLANG_FORMAT_PATH, current_mode | stat.S_IXUSR)

    print(f"Using clang-format located at {CLANG_FORMAT_PATH}")
    return True
def verify_metadata(
    meta_file: Optional[Union[str, Sequence[str]]] = None,
    filepath: Optional[PathLike] = None,
    create_dir: Optional[bool] = None,
    hash_val: Optional[str] = None,
    hash_type: Optional[str] = None,
    args_file: Optional[str] = None,
    **kwargs,
):
    """
    Validate a `metadata` file against the schema it references.

    The metadata content must carry a `schema` field holding the URL of the
    schema file; that file is downloaded to `filepath` and then used for
    validation. The schema standard follows: http://json-schema.org/.

    Args:
        meta_file: path of the metadata file to verify; if `None`, it must be
            provided in `args_file`. If a list of file paths is given, their
            contents are merged.
        filepath: file path where the downloaded schema is stored.
        create_dir: whether to create missing directories, defaults to `True`.
        hash_val: if not None, the hash value used to verify the downloaded
            schema file.
        hash_type: if not None, the hash type used to verify the downloaded
            schema file. Defaults to "md5".
        args_file: a JSON or YAML file providing default values for all the
            args of this function, so the command line inputs can be simplified.
        kwargs: other arguments for `jsonschema.validate()`. For more details:
            https://python-jsonschema.readthedocs.io/en/stable/validate/#jsonschema.validate.

    Raises:
        ValueError: when the metadata has no `schema` field, or when the
            metadata fails validation against the downloaded schema.

    """
    merged_args = _update_args(
        args=args_file,
        meta_file=meta_file,
        filepath=filepath,
        create_dir=create_dir,
        hash_val=hash_val,
        hash_type=hash_type,
        **kwargs,
    )
    _log_input_summary(tag="verify_metadata", args=merged_args)

    # Pull out the arguments handled here; whatever remains in `merged_args`
    # is forwarded to `validate` further down.
    popped = _pop_args(
        merged_args, "filepath", "meta_file", create_dir=True, hash_val=None, hash_type="md5"
    )
    filepath_, meta_file_, create_dir_, hash_val_, hash_type_ = popped

    check_parent_dir(path=filepath_, create_dir=create_dir_)

    metadata = ConfigParser.load_config_files(files=meta_file_)
    url = metadata.get("schema")
    if url is None:
        raise ValueError("must provide the `schema` field in the metadata for the URL of schema file.")

    download_url(url=url, filepath=filepath_, hash_val=hash_val_, hash_type=hash_type_, progress=True)
    schema = ConfigParser.load_config_file(filepath=filepath_)

    try:
        validate(instance=metadata, schema=schema, **merged_args)
    except ValidationError as e:  # pylint: disable=E0712
        # The full error text is very long: keep only the leading part up to
        # and including "Failed validating".
        key_info = re.compile(r".*Failed validating", re.S).findall(str(e))[0]
        raise ValueError(key_info + f" against schema `{url}`.") from e

    logger.info("metadata is verified with no error.")
def _download_from_github(repo: str, download_path: Path, filename: str, progress: bool = True):
    """
    Download a zip archive from a Github release and extract it.

    Args:
        repo: release location in the form `repo_owner/repo_name/release_tag`.
        download_path: directory receiving both the archive and the extracted files.
        filename: name of the release asset; ".zip" is appended when the name
            does not already contain it.
        progress: forwarded to `download_url` as its `progress` argument.

    Raises:
        ValueError: when `repo` does not consist of exactly three "/"-separated parts.
    """
    repo_parts = repo.split("/")
    if len(repo_parts) != 3:
        raise ValueError("if source is `github`, repo should be in the form of `repo_owner/repo_name/release_tag`.")
    repo_owner, repo_name, tag_name = repo_parts

    archive_name = filename if ".zip" in filename else filename + ".zip"
    url = _get_git_release_url(repo_owner, repo_name, tag_name=tag_name, filename=archive_name)
    filepath = download_path / archive_name

    download_url(url=url, filepath=filepath, hash_val=None, progress=progress)
    extractall(filepath=filepath, output_dir=download_path, has_base=True)
def _load_state_dict(model: nn.Module, arch: str, progress: bool):
    """
    Load pretrained weights for `arch` into `model`.

    The pretrained state dict is fetched, its keys are renamed to the naming
    used by `model`, and only the entries whose key exists in the model with an
    identical tensor shape are loaded.

    Args:
        model: network receiving the pretrained weights.
        arch: architecture name, looked up in `SE_NET_MODELS`.
        progress: forwarded to `load_state_dict_from_url` as its `progress` argument.

    Raises:
        ValueError: when `arch` has no entry in `SE_NET_MODELS`.
    """
    model_url = look_up_option(arch, SE_NET_MODELS, None)
    if model_url is None:
        raise ValueError(
            "only 'senet154', 'se_resnet50', 'se_resnet101', 'se_resnet152', 'se_resnext50_32x4d', "
            + "and se_resnext101_32x4d are supported to load pretrained weights."
        )

    # Ordered renaming rules: (pattern, replacement, squeeze the tensor first).
    # The first pattern that matches a key wins, so the order matters.
    rename_rules = (
        (re.compile(r"^(layer[1-4]\.\d\.(?:conv)\d\.)(\w*)$"), r"\1conv.\2", False),
        (re.compile(r"^(layer[1-4]\.\d\.)(?:bn)(\d\.)(\w*)$"), r"\1conv\2adn.N.\3", False),
        (re.compile(r"^(layer[1-4]\.\d\.)(?:se_module.fc1.)(\w*)$"), r"\1se_layer.fc.0.\2", True),
        (re.compile(r"^(layer[1-4]\.\d\.)(?:se_module.fc2.)(\w*)$"), r"\1se_layer.fc.2.\2", True),
        (re.compile(r"^(layer[1-4]\.\d\.)(?:downsample.0.)(\w*)$"), r"\1project.conv.\2", False),
        (re.compile(r"^(layer[1-4]\.\d\.)(?:downsample.1.)(\w*)$"), r"\1project.adn.N.\2", False),
    )

    if isinstance(model_url, dict):
        # Dict entries carry an explicit url and a local filename to load from.
        weights_file = model_url["filename"]
        download_url(model_url["url"], filepath=weights_file)
        state_dict = torch.load(weights_file, map_location=None)
    else:
        state_dict = load_state_dict_from_url(model_url, progress=progress)

    # Iterate over a snapshot of the keys because entries are added and removed.
    for old_key in list(state_dict.keys()):
        for pattern, replacement, needs_squeeze in rename_rules:
            if not pattern.match(old_key):
                continue
            if needs_squeeze:
                state_dict[old_key] = state_dict[old_key].squeeze()
            new_key = pattern.sub(replacement, old_key)
            if new_key:
                state_dict[new_key] = state_dict[old_key]
                del state_dict[old_key]
            break

    # Keep only the weights the model knows about and whose shapes agree.
    model_dict = model.state_dict()
    compatible = {
        name: tensor
        for name, tensor in state_dict.items()
        if name in model_dict and model_dict[name].shape == tensor.shape
    }
    model_dict.update(compatible)
    model.load_state_dict(model_dict)
def setUp(self):
    """Call ``download_url`` for ``FILE_URL`` / ``FILE_PATH`` with the expected hash value."""
    expected_hash = "5a3cfd4fd725c50578ddb80b517b759f"
    download_url(FILE_URL, FILE_PATH, expected_hash)
def download_url_or_skip_test(*args, **kwargs):
    """
    Run ``download_url`` and skip the test if any downloading error occurs.

    All positional and keyword arguments are forwarded to ``download_url``
    unchanged; the call is made inside the ``skip_if_downloading_fails``
    context manager.
    """
    skip_guard = skip_if_downloading_fails()
    with skip_guard:
        download_url(*args, **kwargs)
def download(
    name: Optional[str] = None,
    bundle_dir: Optional[PathLike] = None,
    source: str = "github",
    repo: Optional[str] = None,
    url: Optional[str] = None,
    progress: bool = True,
    args_file: Optional[str] = None,
):
    """
    download bundle from the specified source or url. The bundle should be a zip file and it
    will be extracted after downloading.
    This function refers to:
    https://pytorch.org/docs/stable/_modules/torch/hub.html

    Typical usage examples:

    .. code-block:: bash

        # Execute this module as a CLI entry, and download bundle:
        python -m monai.bundle download --name "bundle_name" --source "github" --repo "repo_owner/repo_name/release_tag"

        # Execute this module as a CLI entry, and download bundle via URL:
        python -m monai.bundle download --name "bundle_name" --url <url>

        # Set default args of `run` in a JSON / YAML file, help to record and simplify the command line.
        # Other args still can override the default args at runtime.
        # The content of the JSON / YAML file is a dictionary. For example:
        # {"name": "spleen", "bundle_dir": "download", "source": ""}
        # then do the following command for downloading:
        python -m monai.bundle download --args_file "args.json" --source "github"

    Args:
        name: bundle name. If `None` and `url` is `None`, it must be provided in `args_file`.
        bundle_dir: target directory to store the downloaded data.
            Default is `bundle` subfolder under `torch.hub get_dir()`.
        source: place that saved the bundle.
            If `source` is `github`, the bundle should be within the releases.
        repo: repo name. If `None` and `url` is `None`, it must be provided in `args_file`.
            If `source` is `github`, it should be in the form of `repo_owner/repo_name/release_tag`.
            For example: `Project-MONAI/MONAI-extra-test-data/0.8.1`.
        url: url to download the data. If not `None`, data will be downloaded directly
            and `source` will not be checked.
            If `name` is `None`, filename is determined by `monai.apps.utils._basename(url)`.
        progress: whether to display a progress bar.
        args_file: a JSON or YAML file to provide default values for all the args in this function.
            so that the command line inputs can be simplified.

    Raises:
        ValueError: when downloading from Github and `name` or `repo` is missing.
        NotImplementedError: when no `url` is given and `source` is not `github`.

    """
    _args = _update_args(
        args=args_file, name=name, bundle_dir=bundle_dir, source=source, repo=repo, url=url, progress=progress
    )

    _log_input_summary(tag="download", args=_args)
    name_, bundle_dir_, source_, repo_, url_, progress_ = _pop_args(
        _args, name=None, bundle_dir=None, source="github", repo=None, url=None, progress=True
    )

    bundle_dir_ = _process_bundle_dir(bundle_dir_)

    if url_ is not None:
        # Use the resolved `name_` (not the raw `name` parameter) so that a
        # bundle name supplied only through `args_file` is honoured as well.
        if name_ is not None:
            filepath = bundle_dir_ / f"{name_}.zip"
        else:
            filepath = bundle_dir_ / f"{_basename(url_)}"
        download_url(url=url_, filepath=filepath, hash_val=None, progress=progress_)
        extractall(filepath=filepath, output_dir=bundle_dir_, has_base=True)
    elif source_ == "github":
        if name_ is None or repo_ is None:
            raise ValueError(
                f"To download from source: Github, `name` and `repo` must be provided, got {name_} and {repo_}."
            )
        _download_from_github(repo=repo_, download_path=bundle_dir_, filename=name_, progress=progress_)
    else:
        raise NotImplementedError(
            f"Currently only download from provided URL in `url` or Github is implemented, got source: {source_}."
        )