def load_custom_pretrained(model, cfg=None, load_fn=None, progress=False, check_hash=False):
    r"""Download a custom (non ``.pth``) weight file and load it into ``model``.

    The checkpoint is downloaded into ``<hub_dir>/checkpoints``, where
    ``hub_dir`` is the directory returned by :func:`~torch.hub.get_dir`, just
    like torch.hub-based loaders. Instead of deserializing the file here, the
    cached file path is passed to ``load_fn`` or, if that is not given, to the
    model's own ``load_pretrained`` method. A file that is already cached is
    not downloaded again. Nothing is returned.

    Args:
        model: The instantiated model to load weights into.
        cfg (dict, optional): Pretrained model cfg. It must contain a ``'url'``
            entry. Falls back to ``model.default_cfg`` when not given.
        load_fn (callable, optional): External stand-alone fn called as
            ``load_fn(model, cached_file)``. When omitted, a method named
            ``load_pretrained`` on the model is called if it exists.
        progress (bool, optional): Whether or not to display a progress bar to
            stderr while downloading. Default: False
        check_hash (bool, optional): If True, the filename part of the URL
            should follow the naming convention ``filename-<sha256>.ext``,
            where ``<sha256>`` is the first eight or more digits of the SHA256
            hash of the contents of the file. The hash is used to verify the
            downloaded file. It is only checked when a download happens, and
            verification is skipped if the filename carries no hash.
            Default: False
    """
    # Use a None default so models without ``default_cfg`` reach the
    # "no pretrained weights" branch below instead of raising AttributeError.
    cfg = cfg or getattr(model, 'default_cfg', None)
    if cfg is None or not cfg.get('url', None):
        _logger.warning("No pretrained weights exist for this model. Using random initialization.")
        return
    url = cfg['url']

    # Issue warning to move data if old env is set
    if os.getenv('TORCH_MODEL_ZOO'):
        _logger.warning('TORCH_MODEL_ZOO is deprecated, please use env TORCH_HOME instead')

    hub_dir = get_dir()
    model_dir = os.path.join(hub_dir, 'checkpoints')
    os.makedirs(model_dir, exist_ok=True)

    parts = urlparse(url)
    filename = os.path.basename(parts.path)
    cached_file = os.path.join(model_dir, filename)
    if not os.path.exists(cached_file):
        _logger.info('Downloading: "{}" to {}\n'.format(url, cached_file))
        hash_prefix = None
        if check_hash:
            r = HASH_REGEX.search(filename)  # r is Optional[Match[str]]
            hash_prefix = r.group(1) if r else None
        download_url_to_file(url, cached_file, hash_prefix, progress=progress)

    if load_fn is not None:
        load_fn(model, cached_file)
    elif hasattr(model, 'load_pretrained'):
        model.load_pretrained(cached_file)
    else:
        _logger.warning("Valid function to load pretrained weights is not available, using random initialization.")
def get_cached_file_path(url: str,
                         save_dir: Optional[str] = None,
                         progress: bool = True,
                         check_hash: bool = False,
                         file_name: Optional[str] = None) -> str:
    r"""Ensure the file at ``url`` exists in ``save_dir`` and return its path.

    The file is downloaded only if no file with the target name is already
    present in ``save_dir``. It is not deserialized or unpacked; only its
    local path is returned.

    Args:
        url (str): URL of the object to download.
        save_dir (str, optional): Directory in which to save the object. It is
            created if missing. Defaults to ``'webcam_resources'``, a path
            relative to the current working directory.
        progress (bool): Whether or not to display a progress bar to stderr
            while downloading. Default: ``True``
        check_hash (bool): If True, the filename should follow the naming
            convention ``filename-<sha256>.ext``, where ``<sha256>`` is the
            first eight or more digits of the SHA256 hash of the contents of
            the file. The hash is used to verify the downloaded file. It is
            only checked when a download happens, and verification is skipped
            if the filename carries no hash. Default: ``False``
        file_name (str, optional): Name for the downloaded file. The filename
            from ``url`` is used if not set. Default: ``None``.

    Returns:
        str: The path to the cached file.
    """
    if save_dir is None:
        save_dir = 'webcam_resources'
    mkdir_or_exist(save_dir)

    filename = os.path.basename(urlparse(url).path)
    if file_name is not None:
        filename = file_name
    cached_file = os.path.join(save_dir, filename)

    if not os.path.exists(cached_file):
        sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file))
        hash_prefix = None
        if check_hash:
            r = HASH_REGEX.search(filename)  # r is Optional[Match[str]]
            hash_prefix = r.group(1) if r else None
        download_url_to_file(url, cached_file, hash_prefix, progress=progress)
    return cached_file
def cache_url(url, model_dir=None, progress=True):
    r"""Download the object at ``url`` into ``model_dir`` (if needed) and return its path.

    The file is not deserialized here; only the local path is returned. If a
    file with the target name is already present in ``model_dir``, nothing is
    downloaded. Only the process for which ``is_main_process()`` is true
    downloads. Every process then calls ``synchronize()`` before returning
    (presumably a distributed barrier; confirm against its definition).

    The filename part of the URL should follow the naming convention
    ``filename-<sha256>.ext``, where ``<sha256>`` is the first eight or more
    digits of the SHA256 hash of the contents of the file. The hash is used to
    verify the download. Hash prefixes shorter than 6 characters are ignored.
    If the URL has a ``#fragment``, the fragment is used as the filename.

    The default value of ``model_dir`` is ``$TORCH_HOME/models``, where
    ``$TORCH_HOME`` defaults to ``~/.torch``. The default directory can be
    overridden with the ``$TORCH_MODEL_ZOO`` environment variable, which is
    used verbatim (not ``~``-expanded).

    Args:
        url (string): URL of the object to download
        model_dir (string, optional): directory in which to save the object
        progress (bool, optional): whether or not to display a progress bar to stderr

    Returns:
        str: The path to the cached file.

    Example:
        >>> cached_file = sampling_free.utils.model_zoo.cache_url('https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth')
    """
    if model_dir is None:
        torch_home = os.path.expanduser(os.getenv('TORCH_HOME', '~/.torch'))
        model_dir = os.getenv('TORCH_MODEL_ZOO', os.path.join(torch_home, 'models'))
    # exist_ok avoids a check-then-create race when several processes call
    # this at the same time.
    os.makedirs(model_dir, exist_ok=True)

    parts = urlparse(url)
    if parts.fragment != "":
        filename = parts.fragment
    else:
        filename = os.path.basename(parts.path)
    if filename == "model_final.pkl":
        # workaround as pre-trained Caffe2 models from Detectron have all the same filename
        # so make the full path the filename by replacing / with _
        filename = parts.path.replace("/", "_")
    cached_file = os.path.join(model_dir, filename)

    if not os.path.exists(cached_file) and is_main_process():
        sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file))
        hash_prefix = HASH_REGEX.search(filename)
        if hash_prefix is not None:
            hash_prefix = hash_prefix.group(1)
            # workaround: Caffe2 models don't have a hash, but follow the R-50 convention,
            # which matches the hash PyTorch uses. So we skip the hash matching
            # if the hash_prefix is less than 6 characters
            if len(hash_prefix) < 6:
                hash_prefix = None
        _download_url_to_file(url, cached_file, hash_prefix, progress=progress)
    synchronize()
    return cached_file
def custom_cache_url(url: str, model_dir: Path = None, progress: bool = True) -> Path:
    r"""Download the object at ``url`` into ``model_dir`` (if needed) and return its path.

    The file is not deserialized here; only the local path is returned. If a
    file with the target name is already present in ``model_dir``, nothing is
    downloaded. Only the process for which ``is_main_process()`` is true
    downloads. Every process then calls ``synchronise_torch_barrier()`` before
    returning.

    The filename part of the URL should follow the naming convention
    ``filename-<sha256>.ext``, where ``<sha256>`` is the first eight or more
    digits of the SHA256 hash of the contents of the file. The hash is used to
    verify the download. Hash prefixes shorter than 6 characters are ignored.

    The default value of ``model_dir`` is ``$TORCH_HOME/models``, where
    ``$TORCH_HOME`` defaults to ``~/.torch``. The default directory can be
    overridden with the ``$TORCH_MODEL_ZOO`` environment variable.

    Args:
        url (string): URL of the object to download
        model_dir (Path or string, optional): directory in which to save the object
        progress (bool, optional): whether or not to display a progress bar to stderr

    Returns:
        Path: The path to the cached file.

    Example:
        >>> cached_file = maskrcnn_benchmark.utils.model_zoo.custom_cache_url(
        ...     'https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth')
    """
    if model_dir is None:
        model_dir = os.getenv(
            "TORCH_MODEL_ZOO",
            Path(os.path.expanduser(os.getenv("TORCH_HOME", "~/.torch"))) / "models",
        )
    # $TORCH_MODEL_ZOO yields a plain str, and callers may pass a str too.
    # Normalise so the Path methods below work in every case.
    model_dir = Path(model_dir)
    model_dir.mkdir(parents=True, exist_ok=True)

    parts = urlparse(url)
    filename = os.path.basename(parts.path)
    if filename == "model_final.pkl":
        # workaround as pre-trained Caffe2 models from Detectron have all the same filename
        # so make the full path the filename by replacing / with _
        filename = parts.path.replace("/", "_")
    cached_file = model_dir / filename

    if not cached_file.exists() and is_main_process():
        sys.stderr.write(f'Downloading: "{url}" to {cached_file}\n')
        hash_prefix = HASH_REGEX.search(filename)
        if hash_prefix is not None:
            hash_prefix = hash_prefix.group(1)
            # workaround: Caffe2 models don't have a hash, but follow the R-50 convention,
            # which matches the hash PyTorch uses. So we skip the hash matching
            # if the hash_prefix is less than 6 characters
            if len(hash_prefix) < 6:
                hash_prefix = None
        download_url_to_file(url, str(cached_file), hash_prefix, progress=progress)
    synchronise_torch_barrier()
    return cached_file
def load_state_dict_from_url(url, model_dir=None, map_location=None, progress=True, check_hash=False, file_name=None):
    r"""A modified version of ``torch.hub.load_state_dict_from_url``, which
    handles the new serialization protocol differently.

    See <https://github.com/pytorch/pytorch/issues/43106> for more information.
    In this implementation the cached file is loaded with a plain
    ``torch.load(cached_file, map_location=map_location)``.

    Args:
        url (str): URL of the object to download.
        model_dir (str, optional): Directory in which to save the object. It is
            created if missing. Defaults to ``<hub_dir>/checkpoints``, where
            ``hub_dir`` is returned by :func:`torch.hub.get_dir`.
        map_location (optional): Forwarded to ``torch.load``.
        progress (bool, optional): Whether to display a download progress bar
            on stderr. Default: True
        check_hash (bool, optional): If True, the SHA256 prefix embedded in the
            filename (``filename-<sha256>.ext``) is used to verify the download.
            Verification is skipped if no hash is found. Default: False
        file_name (str, optional): Name for the cached file. The basename of
            ``url`` is used if not set.

    Returns:
        The object returned by ``torch.load`` for the cached file.
    """
    import os
    import sys
    import warnings
    import torch
    from urllib.parse import urlparse
    from torch.hub import get_dir, download_url_to_file, HASH_REGEX

    # Issue warning to move data if old env is set
    if os.getenv('TORCH_MODEL_ZOO'):
        warnings.warn('TORCH_MODEL_ZOO is deprecated, please use env TORCH_HOME instead')

    if model_dir is None:
        hub_dir = get_dir()
        model_dir = os.path.join(hub_dir, 'checkpoints')

    # Equivalent to the old "ignore EEXIST, re-raise everything else" handling.
    os.makedirs(model_dir, exist_ok=True)

    parts = urlparse(url)
    filename = os.path.basename(parts.path)
    if file_name is not None:
        filename = file_name
    cached_file = os.path.join(model_dir, filename)
    if not os.path.exists(cached_file):
        sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file))
        hash_prefix = None
        if check_hash:
            r = HASH_REGEX.search(filename)  # r is Optional[Match[str]]
            hash_prefix = r.group(1) if r else None
        download_url_to_file(url, cached_file, hash_prefix, progress=progress)
    return torch.load(cached_file, map_location=map_location)
def download_from_url(url):
    """Download ``url`` into the local checkpoint cache (if needed) and return its path.

    The cache directory is ``$TORCH_MODEL_ZOO`` if set, otherwise
    ``<TORCH_HOME>/torch/checkpoints``, where ``TORCH_HOME`` defaults to
    ``~/.cache``. The directory is created if missing. Only the process for
    which ``is_main_process()`` is true downloads. Every process then calls
    ``synchronize()`` before returning.

    A SHA256 prefix embedded in the filename (``filename-<sha256>.ext``) is
    used to verify the download. Prefixes shorter than 6 characters are
    ignored.

    Args:
        url (str): URL of the file to download.

    Returns:
        str: The path to the cached file.
    """
    torch_home = os.path.expanduser(os.getenv("TORCH_HOME", "~/.cache"))
    # NOTE(review): with TORCH_HOME set this resolves to
    # $TORCH_HOME/torch/checkpoints (an extra "torch" level) -- confirm intended.
    model_dir = os.getenv("TORCH_MODEL_ZOO", os.path.join(torch_home, "torch", "checkpoints"))
    # exist_ok avoids a check-then-create race between concurrent processes.
    os.makedirs(model_dir, exist_ok=True)

    parts = urlparse(url)
    filename = os.path.basename(parts.path)
    cached_file = os.path.join(model_dir, filename)
    if not os.path.exists(cached_file) and is_main_process():
        sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file))
        hash_prefix = HASH_REGEX.search(filename)
        if hash_prefix is not None:
            hash_prefix = hash_prefix.group(1)
            # Very short "hashes" are not real content hashes; skip verification.
            if len(hash_prefix) < 6:
                hash_prefix = None
        download_url_to_file(url, cached_file, hash_prefix, progress=True)
    synchronize()
    return cached_file
def load_url(url, model_dir=None, map_location=None, progress=True, check_hash=False, file_name=None):
    r"""Loads the Torch serialized object at the given URL.

    If the object is already present in ``model_dir``, it is not downloaded
    again; it is deserialized and returned. Files detected by
    ``_is_legacy_zip_format`` are loaded through ``_legacy_zip_load``. All
    other files are loaded with ``torch.load``.

    The default value of ``model_dir`` is ``<torch_home>/checkpoints``, where
    ``torch_home`` is the directory returned by ``_get_torch_home()``.

    Args:
        url (str): URL of the object to download
        model_dir (str, optional): directory in which to save the object
        map_location (optional): a function or a dict specifying how to remap
            storage locations (see torch.load)
        progress (bool, optional): whether or not to display a progress bar
            to stderr. Default: True
        check_hash (bool, optional): If True, the filename part of the URL
            should follow the naming convention ``filename-<sha256>.ext``,
            where ``<sha256>`` is the first eight or more digits of the SHA256
            hash of the contents of the file. The hash is used to verify the
            downloaded file. Verification is skipped if no hash is found.
            Default: False
        file_name (str, optional): name for the downloaded file. Filename from
            ``url`` will be used if not set. Default: None.

    Example:
        >>> url = ('https://s3.amazonaws.com/pytorch/models/resnet18-5c106'
        ...        'cde.pth')
        >>> state_dict = load_url(url)
    """
    # Issue warning to move data if old env is set
    if os.getenv('TORCH_MODEL_ZOO'):
        warnings.warn(
            'TORCH_MODEL_ZOO is deprecated, please use env '
            'TORCH_HOME instead', DeprecationWarning)

    if model_dir is None:
        torch_home = _get_torch_home()
        model_dir = os.path.join(torch_home, 'checkpoints')

    mkdir_or_exist(model_dir)

    parts = urlparse(url)
    filename = os.path.basename(parts.path)
    if file_name is not None:
        filename = file_name
    cached_file = os.path.join(model_dir, filename)
    if not os.path.exists(cached_file):
        sys.stderr.write('Downloading: "{}" to {}\n'.format(
            url, cached_file))
        hash_prefix = None
        if check_hash:
            r = HASH_REGEX.search(filename)  # r is Optional[Match[str]]
            hash_prefix = r.group(1) if r else None
        download_url_to_file(
            url, cached_file, hash_prefix, progress=progress)

    if _is_legacy_zip_format(cached_file):
        return _legacy_zip_load(cached_file, model_dir, map_location)

    try:
        return torch.load(cached_file, map_location=map_location)
    except RuntimeError:
        # Old torch versions cannot read checkpoints written with the newer
        # serialization format; add a hint before re-raising.
        if digit_version(TORCH_VERSION) < digit_version('1.5.0'):
            warnings.warn(
                f'If the error is the same as "{cached_file} is a zip '
                'archive (did you mean to use torch.jit.load()?)", you can'
                ' upgrade your torch to 1.5.0 or higher (current torch '
                f'version is {TORCH_VERSION}). The error was raised '
                'because the checkpoint was saved in torch>=1.6.0 but '
                'loaded in torch<1.5.')
        raise
def load_npz_from_url(url, model_dir=None, map_location=None, progress=True, check_hash=False):
    r"""Download a NumPy file (typically ``.npz``) if not cached and load it with ``np.load``.

    If a file with the URL's basename is already present in ``model_dir``, it
    is not downloaded again.

    Args:
        url (string): URL of the file to download
        model_dir (string, optional): directory in which to save the file. It
            is created if missing. Defaults to ``<torch_home>/checkpoints``,
            where ``torch_home`` comes from ``torch.hub._get_torch_home()``.
        map_location (optional): Unused; accepted but ignored.
        progress (bool, optional): whether or not to display a progress bar
            to stderr. Default: True
        check_hash (bool, optional): If True, the filename part of the URL
            should follow the naming convention ``filename-<sha256>.ext``,
            where ``<sha256>`` is the first eight or more digits of the SHA256
            hash of the contents of the file. The hash is used to verify the
            downloaded file. Verification is skipped if the filename carries
            no hash. Default: False

    Returns:
        Whatever ``np.load`` returns for the file. For ``.npz`` archives this
        is an ``NpzFile``, which the caller should close (e.g. via ``with``).

    Example:
        >>> with load_npz_from_url('https://example.com/weights.npz') as weights:
        ...     arr = weights['arr_0']
    """
    import os
    import sys
    import warnings
    import numpy as np
    from urllib.parse import urlparse
    from torch.hub import _get_torch_home, download_url_to_file, HASH_REGEX

    # Issue warning to move data if old env is set
    if os.getenv('TORCH_MODEL_ZOO'):
        warnings.warn(
            'TORCH_MODEL_ZOO is deprecated, please use env TORCH_HOME instead')

    if model_dir is None:
        torch_home = _get_torch_home()
        model_dir = os.path.join(torch_home, 'checkpoints')

    # Equivalent to the old "ignore EEXIST, re-raise everything else" handling.
    os.makedirs(model_dir, exist_ok=True)

    parts = urlparse(url)
    filename = os.path.basename(parts.path)
    cached_file = os.path.join(model_dir, filename)
    if not os.path.exists(cached_file):
        sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file))
        hash_prefix = None
        if check_hash:
            # search() returns None when the filename has no hash; skip
            # verification then instead of crashing on .group().
            r = HASH_REGEX.search(filename)
            hash_prefix = r.group(1) if r else None
        download_url_to_file(url, cached_file, hash_prefix, progress=progress)
    return np.load(cached_file)