Example 1
def load_custom_pretrained(model, cfg=None, load_fn=None, progress=False, check_hash=False):
    r"""Loads a custom (read non .pth) weight file
    Downloads checkpoint file into cache-dir like torch.hub based loaders, but calls
    a passed in custom load fun, or the `load_pretrained` model member fn.
    If the object is already present in `model_dir`, it's deserialized and returned.
    The default value of `model_dir` is ``<hub_dir>/checkpoints`` where
    `hub_dir` is the directory returned by :func:`~torch.hub.get_dir`.
    Args:
        model: The instantiated model to load weights into
        cfg (dict): Default pretrained model cfg
        load_fn: An external stand-alone fn that loads weights into the provided model; otherwise a fn named
            'load_pretrained' on the model will be called if it exists
        progress (bool, optional): whether or not to display a progress bar to stderr. Default: False
        check_hash(bool, optional): If True, the filename part of the URL should follow the naming convention
            ``filename-<sha256>.ext`` where ``<sha256>`` is the first eight or more
            digits of the SHA256 hash of the contents of the file. The hash is used to
            ensure unique names and to verify the contents of the file. Default: False
    """
    cfg = cfg or getattr(model, 'default_cfg', None)
    if cfg is None or not cfg.get('url', None):
        _logger.warning("No pretrained weights exist for this model. Using random initialization.")
        return
    url = cfg['url']

    # Issue warning to move data if old env is set
    if os.getenv('TORCH_MODEL_ZOO'):
        _logger.warning('TORCH_MODEL_ZOO is deprecated, please use env TORCH_HOME instead')

    hub_dir = get_dir()
    model_dir = os.path.join(hub_dir, 'checkpoints')

    os.makedirs(model_dir, exist_ok=True)

    parts = urlparse(url)
    filename = os.path.basename(parts.path)
    cached_file = os.path.join(model_dir, filename)
    if not os.path.exists(cached_file):
        _logger.info('Downloading: "{}" to {}\n'.format(url, cached_file))
        hash_prefix = None
        if check_hash:
            r = HASH_REGEX.search(filename)  # r is Optional[Match[str]]
            hash_prefix = r.group(1) if r else None
        download_url_to_file(url, cached_file, hash_prefix, progress=progress)

    if load_fn is not None:
        load_fn(model, cached_file)
    elif hasattr(model, 'load_pretrained'):
        model.load_pretrained(cached_file)
    else:
        _logger.warning("Valid function to load pretrained weights is not available, using random initialization.")
Example 2
def get_cached_file_path(url: str,
                         save_dir: Optional[str] = None,
                         progress: bool = True,
                         check_hash: bool = False,
                         file_name: Optional[str] = None) -> str:
    r"""Loads the Torch serialized object at the given URL.

    If downloaded file is a zip file, it will be automatically decompressed

    If the object is already present in `model_dir`, it's deserialized and
    returned.
    The default value of ``model_dir`` is ``<hub_dir>/checkpoints`` where
    ``hub_dir`` is the directory returned by :func:`~torch.hub.get_dir`.

    Args:
        url (str): URL of the object to download
        save_dir (str, optional): directory in which to save the object
        progress (bool): whether or not to display a progress bar
            to stderr. Default: ``True``
        check_hash(bool): If True, the filename part of the URL
            should follow the naming convention ``filename-<sha256>.ext``
            where ``<sha256>`` is the first eight or more digits of the
            SHA256 hash of the contents of the file. The hash is used to
            ensure unique names and to verify the contents of the file.
            Default: ``False``
        file_name (str, optional): name for the downloaded file. Filename
            from ``url`` will be used if not set. Default: ``None``.

    Returns:
        str: The path to the cached file.
    """
    if save_dir is None:
        save_dir = 'webcam_resources'

    mkdir_or_exist(save_dir)

    parts = urlparse(url)
    filename = os.path.basename(parts.path)
    if file_name is not None:
        filename = file_name
    cached_file = os.path.join(save_dir, filename)
    if not os.path.exists(cached_file):
        sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file))
        hash_prefix = None
        if check_hash:
            r = HASH_REGEX.search(filename)  # r is Optional[Match[str]]
            hash_prefix = r.group(1) if r else None
        download_url_to_file(url, cached_file, hash_prefix, progress=progress)
    return cached_file
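A usage sketch for `get_cached_file_path` (the URL is the ResNet-18 checkpoint cited in the other examples; `torch` is assumed to be imported):

ckpt_url = 'https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth'
# The first call downloads into ./checkpoints; later calls reuse the cached copy.
path = get_cached_file_path(ckpt_url, save_dir='checkpoints', check_hash=True)
state_dict = torch.load(path, map_location='cpu')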
Example 3
def cache_url(url, model_dir=None, progress=True):
    r"""Loads the Torch serialized object at the given URL.
    If the object is already present in `model_dir`, it's deserialized and
    returned. The filename part of the URL should follow the naming convention
    ``filename-<sha256>.ext`` where ``<sha256>`` is the first eight or more
    digits of the SHA256 hash of the contents of the file. The hash is used to
    ensure unique names and to verify the contents of the file.
    The default value of `model_dir` is ``$TORCH_HOME/models`` where
    ``$TORCH_HOME`` defaults to ``~/.torch``. The default directory can be
    overridden with the ``$TORCH_MODEL_ZOO`` environment variable.
    Args:
        url (string): URL of the object to download
        model_dir (string, optional): directory in which to save the object
        progress (bool, optional): whether or not to display a progress bar to stderr
    Example:
        >>> cached_file = sampling_free.utils.model_zoo.cache_url('https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth')
    """
    if model_dir is None:
        torch_home = os.path.expanduser(os.getenv('TORCH_HOME', '~/.torch'))
        model_dir = os.getenv('TORCH_MODEL_ZOO',
                              os.path.join(torch_home, 'models'))
    if not os.path.exists(model_dir):
        os.makedirs(model_dir)
    parts = urlparse(url)
    if parts.fragment != "":
        filename = parts.fragment
    else:
        filename = os.path.basename(parts.path)
    if filename == "model_final.pkl":
        # workaround as pre-trained Caffe2 models from Detectron have all the same filename
        # so make the full path the filename by replacing / with _
        filename = parts.path.replace("/", "_")
    cached_file = os.path.join(model_dir, filename)
    if not os.path.exists(cached_file) and is_main_process():
        sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file))
        hash_prefix = HASH_REGEX.search(filename)
        if hash_prefix is not None:
            hash_prefix = hash_prefix.group(1)
            # workaround: Caffe2 models don't have a hash, but follow the R-50 convention,
            # which matches the hash PyTorch uses. So we skip the hash matching
            # if the hash_prefix is less than 6 characters
            if len(hash_prefix) < 6:
                hash_prefix = None
        _download_url_to_file(url, cached_file, hash_prefix, progress=progress)
    synchronize()
    return cached_file
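The `model_final.pkl` special case above exists because every pre-trained Detectron checkpoint shares that basename, so the full URL path is turned into the filename to keep distinct models from overwriting each other in the cache. A small illustration (the URLs are placeholders in the Detectron layout, not verified links):

from urllib.parse import urlparse

urls = [
    'https://example.org/detectron/35857345/model_final.pkl',
    'https://example.org/detectron/36761737/model_final.pkl',
]
for url in urls:
    filename = urlparse(url).path.replace('/', '_')  # same trick as cache_url above
    print(filename)
# _detectron_35857345_model_final.pkl
# _detectron_36761737_model_final.pkl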
Example 4
def custom_cache_url(url: str, model_dir: Path = None, progress: bool = True) -> Path:
    r"""Loads the Torch serialized object at the given URL.
    If the object is already present in `model_dir`, it's deserialized and
    returned. The filename part of the URL should follow the naming convention
    ``filename-<sha256>.ext`` where ``<sha256>`` is the first eight or more
    digits of the SHA256 hash of the contents of the file. The hash is used to
    ensure unique names and to verify the contents of the file.
    The default value of `model_dir` is ``$TORCH_HOME/models`` where
    ``$TORCH_HOME`` defaults to ``~/.torch``. The default directory can be
    overridden with the ``$TORCH_MODEL_ZOO`` environment variable.
    Args:
        url (string): URL of the object to download
        model_dir (string, optional): directory in which to save the object
        progress (bool, optional): whether or not to display a progress bar to stderr
    Example:
        >>> cached_file = maskrcnn_benchmark.utils.model_zoo.custom_cache_url(
        ...     'https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth')
    """
    if model_dir is None:
        default_dir = Path(os.path.expanduser(os.getenv("TORCH_HOME", "~/.torch"))) / "models"
        # Wrap in Path so an env-var override (a str) still supports .exists() / .mkdir() below.
        model_dir = Path(os.getenv("TORCH_MODEL_ZOO", default_dir))
    if not model_dir.exists():
        model_dir.mkdir(parents=True)
    parts = urlparse(url)
    filename = os.path.basename(parts.path)
    if filename == "model_final.pkl":
        # workaround as pre-trained Caffe2 models from Detectron have all the same filename
        # so make the full path the filename by replacing / with _
        filename = parts.path.replace("/", "_")
    cached_file = model_dir / filename
    if not cached_file.exists() and is_main_process():
        sys.stderr.write(f'Downloading: "{url}" to {cached_file}\n')
        hash_prefix = HASH_REGEX.search(filename)
        if hash_prefix is not None:
            hash_prefix = hash_prefix.group(1)
            # workaround: Caffe2 models don't have a hash, but follow the R-50 convention,
            # which matches the hash PyTorch uses. So we skip the hash matching
            # if the hash_prefix is less than 6 characters
            if len(hash_prefix) < 6:
                hash_prefix = None
        download_url_to_file(url, str(cached_file), hash_prefix, progress=progress)
    synchronise_torch_barrier()
    return cached_file
Example 5
def load_state_dict_from_url(url, model_dir=None, map_location=None, progress=True, check_hash=False, file_name=None):
    r"""A modified version of `torch.hub.load_state_dict_from_url`, which handles the new 
    serialization protocol diferently. 
    See <https://github.com/pytorch/pytorch/issues/43106> for more information.
    """
    import os
    import sys
    import warnings
    import errno
    import torch
    from urllib.parse import urlparse
    from torch.hub import get_dir, download_url_to_file, HASH_REGEX

    # Issue warning to move data if old env is set
    if os.getenv('TORCH_MODEL_ZOO'):
        warnings.warn('TORCH_MODEL_ZOO is deprecated, please use env TORCH_HOME instead')

    if model_dir is None:
        hub_dir = get_dir()
        model_dir = os.path.join(hub_dir, 'checkpoints')

    try:
        os.makedirs(model_dir)
    except OSError as e:
        if e.errno == errno.EEXIST:
            # Directory already exists, ignore.
            pass
        else:
            # Unexpected OSError, re-raise.
            raise

    parts = urlparse(url)
    filename = os.path.basename(parts.path)
    if file_name is not None:
        filename = file_name
    cached_file = os.path.join(model_dir, filename)
    if not os.path.exists(cached_file):
        sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file))
        hash_prefix = None
        if check_hash:
            r = HASH_REGEX.search(filename)  # r is Optional[Match[str]]
            hash_prefix = r.group(1) if r else None
        download_url_to_file(url, cached_file, hash_prefix, progress=progress)

    return torch.load(cached_file, map_location=map_location)
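A usage sketch for the modified loader above (the URL is the ResNet-18 checkpoint used in the other examples; `model` stands in for any nn.Module whose keys match the checkpoint):

ckpt_url = 'https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth'
# Fetch the checkpoint into <hub_dir>/checkpoints and load it onto the CPU.
state_dict = load_state_dict_from_url(ckpt_url, map_location='cpu', check_hash=True)
model.load_state_dict(state_dict)  # `model` is assumed to be a compatible nn.Module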
Example 6
def download_from_url(url):
    torch_home = os.path.expanduser(os.getenv("TORCH_HOME", "~/.cache"))
    model_dir = os.getenv("TORCH_MODEL_ZOO",
                          os.path.join(torch_home, "torch", "checkpoints"))

    if not os.path.exists(model_dir):
        os.makedirs(model_dir)
    parts = urlparse(url)
    filename = os.path.basename(parts.path)

    cached_file = os.path.join(model_dir, filename)
    if not os.path.exists(cached_file) and is_main_process():
        sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file))
        hash_prefix = HASH_REGEX.search(filename)
        if hash_prefix is not None:
            hash_prefix = hash_prefix.group(1)
            if len(hash_prefix) < 6:
                hash_prefix = None
        download_url_to_file(url, cached_file, hash_prefix, progress=True)
    synchronize()
    return cached_file
Example 7
    def load_url(url,
                 model_dir=None,
                 map_location=None,
                 progress=True,
                 check_hash=False,
                 file_name=None):
        r"""Loads the Torch serialized object at the given URL.

        If downloaded file is a zip file, it will be automatically decompressed

        If the object is already present in `model_dir`, it's deserialized and
        returned.
        The default value of ``model_dir`` is ``<hub_dir>/checkpoints`` where
        ``hub_dir`` is the directory returned by :func:`~torch.hub.get_dir`.

        Args:
            url (str): URL of the object to download
            model_dir (str, optional): directory in which to save the object
            map_location (optional): a function or a dict specifying how to
                remap storage locations (see torch.load)
            progress (bool, optional): whether or not to display a progress bar
                to stderr. Default: True
            check_hash(bool, optional): If True, the filename part of the URL
                should follow the naming convention ``filename-<sha256>.ext``
                where ``<sha256>`` is the first eight or more digits of the
                SHA256 hash of the contents of the file. The hash is used to
                ensure unique names and to verify the contents of the file.
                Default: False
            file_name (str, optional): name for the downloaded file. Filename
                from ``url`` will be used if not set. Default: None.

        Example:
            >>> url = ('https://s3.amazonaws.com/pytorch/models/resnet18-5c106'
            ...        'cde.pth')
            >>> state_dict = torch.hub.load_state_dict_from_url(url)
        """
        # Issue warning to move data if old env is set
        if os.getenv('TORCH_MODEL_ZOO'):
            warnings.warn(
                'TORCH_MODEL_ZOO is deprecated, please use env '
                'TORCH_HOME instead', DeprecationWarning)

        if model_dir is None:
            torch_home = _get_torch_home()
            model_dir = os.path.join(torch_home, 'checkpoints')

        mkdir_or_exist(model_dir)

        parts = urlparse(url)
        filename = os.path.basename(parts.path)
        if file_name is not None:
            filename = file_name
        cached_file = os.path.join(model_dir, filename)
        if not os.path.exists(cached_file):
            sys.stderr.write('Downloading: "{}" to {}\n'.format(
                url, cached_file))
            hash_prefix = None
            if check_hash:
                r = HASH_REGEX.search(filename)  # r is Optional[Match[str]]
                hash_prefix = r.group(1) if r else None
            download_url_to_file(
                url, cached_file, hash_prefix, progress=progress)

        if _is_legacy_zip_format(cached_file):
            return _legacy_zip_load(cached_file, model_dir, map_location)

        try:
            return torch.load(cached_file, map_location=map_location)
        except RuntimeError as error:
            if digit_version(TORCH_VERSION) < digit_version('1.5.0'):
                warnings.warn(
                    f'If the error is the same as "{cached_file} is a zip '
                    'archive (did you mean to use torch.jit.load()?)", you can'
                    ' upgrade your torch to 1.5.0 or higher (current torch '
                    f'version is {TORCH_VERSION}). The error was raised '
                    'because the checkpoint was saved in torch>=1.6.0 but '
                    'loaded in torch<1.5.')
            raise error
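Examples 2 and 7 both rely on a `mkdir_or_exist` helper that is not shown here; below is a minimal stand-in with the behaviour these snippets assume (mmcv's actual helper may differ in details such as the default mode):

import os

def mkdir_or_exist(dir_name, mode=0o777):
    # Create the directory (and any missing parents); do nothing if it already exists.
    if dir_name == '':
        return
    dir_name = os.path.expanduser(dir_name)
    os.makedirs(dir_name, mode=mode, exist_ok=True)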
Example 8
def load_npz_from_url(url,
                      model_dir=None,
                      map_location=None,
                      progress=True,
                      check_hash=False):
    r"""Loads the Torch serialized object at the given URL.

    If the downloaded file is an ``.npz`` archive, :func:`numpy.load` returns
    a lazily decompressed ``NpzFile`` object.

    If the object is already present in `model_dir`, it's deserialized and
    returned.
    The default value of `model_dir` is ``$TORCH_HOME/checkpoints`` where
    environment variable ``$TORCH_HOME`` defaults to ``$XDG_CACHE_HOME/torch``.
    ``$XDG_CACHE_HOME`` follows the X Desktop Group specification of the Linux
    filesystem layout, with a default value ``~/.cache`` if not set.

    Args:
        url (string): URL of the object to download
        model_dir (string, optional): directory in which to save the object
        map_location (optional): accepted for API compatibility with
            ``torch.hub.load_state_dict_from_url``; unused here, since the file is
            loaded with :func:`numpy.load`
        progress (bool, optional): whether or not to display a progress bar to stderr.
            Default: True
        check_hash(bool, optional): If True, the filename part of the URL should follow the naming convention
            ``filename-<sha256>.ext`` where ``<sha256>`` is the first eight or more
            digits of the SHA256 hash of the contents of the file. The hash is used to
            ensure unique names and to verify the contents of the file.
            Default: False

    Example:
        >>> arrays = load_npz_from_url(url)  # ``url`` should point to an ``.npz`` or ``.npy`` file

    """
    # Issue warning to move data if old env is set
    if os.getenv('TORCH_MODEL_ZOO'):
        warnings.warn(
            'TORCH_MODEL_ZOO is deprecated, please use env TORCH_HOME instead')

    if model_dir is None:
        torch_home = _get_torch_home()
        model_dir = os.path.join(torch_home, 'checkpoints')

    try:
        os.makedirs(model_dir)
    except OSError as e:
        if e.errno == errno.EEXIST:
            # Directory already exists, ignore.
            pass
        else:
            # Unexpected OSError, re-raise.
            raise

    parts = urlparse(url)
    filename = os.path.basename(parts.path)
    cached_file = os.path.join(model_dir, filename)
    if not os.path.exists(cached_file):
        sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file))
        hash_prefix = None
        if check_hash:
            r = HASH_REGEX.search(filename)  # r is Optional[Match[str]]
            hash_prefix = r.group(1) if r else None
        download_url_to_file(url, cached_file, hash_prefix, progress=progress)

    return np.load(cached_file)
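Every example above extracts a hash prefix from the ``filename-<sha256>.ext`` naming convention via ``torch.hub.HASH_REGEX`` and hands it to ``download_url_to_file`` for verification. A sketch of what that check amounts to (the exact regex and chunk size in torch.hub may differ across versions):

import hashlib
import re

HASH_REGEX = re.compile(r'-([a-f0-9]*)\.')  # matches a run of hex digits between a '-' and a '.'

def verify_download(path, filename):
    match = HASH_REGEX.search(filename)
    if match is None:
        return True  # no hash embedded in the name, nothing to verify
    prefix = match.group(1)
    sha256 = hashlib.sha256()
    with open(path, 'rb') as f:
        for chunk in iter(lambda: f.read(8192), b''):
            sha256.update(chunk)
    return sha256.hexdigest().startswith(prefix)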