示例#1
0
def main():
    """Prune the pretrained checkpoint of a model architecture and re-save it.

    Command line arguments:
        -a / --arch: name of a lowercase, callable entry of ``models``
            (default: resnet50).
        --save: path of the pruned checkpoint (required).

    Raises:
        RuntimeError: if the hub checkpoint directory does not contain exactly
            one ``<arch>-*.pth`` file.
    """
    # get supported models: lowercase, non-dunder, callable entries of ``models``
    model_names = sorted(name for name, entry in models.__dict__.items()
                         if name.islower() and not name.startswith("__")
                         and callable(entry))

    # get arguments
    parser = argparse.ArgumentParser()
    parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet50', choices=model_names,
                        help='model architecture: ' + ' | '.join(model_names) + ' (default: resnet50)')
    # --save is reloaded with torch.load() below, so a missing value can never
    # work; reject it up front instead of after the download and pruning.
    parser.add_argument('--save', required=True, type=str, help='pruned checkpoint')
    args = parser.parse_args()

    # if required, download pre trained model
    models.__dict__[args.arch](pretrained=True)

    # get pre trained checkpoint
    checkpoint_path = os.path.join(get_dir(), 'checkpoints')
    # Format only the file pattern: formatting the joined path would break on
    # hub directories that contain '{' or '}'.
    pattern = os.path.join(checkpoint_path, '{}-*.pth'.format(args.arch))
    files = glob(pattern)
    if len(files) != 1:
        # An explicit error survives ``python -O`` (asserts are stripped there).
        raise RuntimeError(
            'expected exactly one checkpoint matching {!r}, found {}: {}'.format(
                pattern, len(files), files))
    checkpoint_file = files[0]

    # prune and save checkpoint
    # NOTE(review): prune() is presumably what writes args.save, since the file
    # is loaded from that path right after -- confirm against prune().
    prune(checkpoint=checkpoint_file, save=args.save, sd_key=None, bs=8, topk=4)

    # add expected fields to checkpoint
    sd = torch.load(args.save)
    checkpoint = {'state_dict': sd, 'epoch': 0, 'best_prec1': 0}
    torch.save(checkpoint, args.save)
示例#2
0
def load_state_dict_from_url(
    url: str,
    model_dir: Optional[str] = None,
    map_location: Optional[Union[torch.device, str]] = None,
    file_name: Optional[str] = None,
    **kwargs: Any,
) -> OrderedDictType[str, torch.Tensor]:
    """Load a state dict through ``torch.hub`` with a fallback for one hub error.

    Args:
        url: URL of the serialized state dict.
        model_dir: Directory holding the cached file. Defaults to
            ``<hub_dir>/checkpoints``.
        map_location: Forwarded to the loader to remap storage locations.
        file_name: Name of the cached file. Defaults to the basename of ``url``.
        **kwargs: Forwarded to ``hub.load_state_dict_from_url``.

    Returns:
        The loaded state dict.

    Raises:
        RuntimeError: any ``RuntimeError`` from the hub loader other than the
            zipfile error handled below.
    """
    # This is just for compatibility with torch==1.6.0 until
    # https://github.com/pytorch/pytorch/issues/42596 is resolved
    if model_dir is None:
        model_dir = path.join(hub.get_dir(), "checkpoints")
    if file_name is None:
        file_name = path.basename(url)

    try:
        return cast(
            OrderedDictType[str, torch.Tensor],
            hub.load_state_dict_from_url(
                url,
                model_dir=model_dir,
                # Previously dropped on this path, so the caller's remapping
                # only took effect in the fallback below.
                map_location=map_location,
                file_name=file_name,
                **kwargs,
            ),
        )
    except RuntimeError as error:
        if str(error) != "Only one file(not dir) is allowed in the zipfile":
            # Not the known hub failure: propagate with the original traceback.
            raise

        # The hub call failed with the known zipfile error; load the file from
        # its expected cache location directly.
        cached_file = path.join(model_dir, file_name)
        return cast(
            OrderedDictType[str, torch.Tensor],
            torch.load(cached_file, map_location=map_location),
        )
示例#3
0
def load_file_from_url(url,
                       model_dir=None,
                       progress=True,
                       check_hash=False,
                       file_name=None):
    """Return the local path for ``url``, downloading the file if it is not cached.

    Args:
        url: URL of the file.
        model_dir: Directory for the cached file; defaults to
            ``<hub_dir>/checkpoints``.
        progress: Forwarded to ``download_url_to_file``.
        check_hash: If True, the first ``HASH_REGEX`` group found in the file
            name is passed to ``download_url_to_file`` as the hash prefix.
        file_name: Overrides the file name taken from the URL path.

    Returns:
        Path of the cached file inside ``model_dir``.
    """
    if model_dir is None:
        model_dir = os.path.join(get_dir(), 'checkpoints')

    try:
        os.makedirs(model_dir)
    except OSError as exc:
        # An already existing path is fine; anything else is unexpected.
        if exc.errno != errno.EEXIST:
            raise

    url_name = os.path.basename(urlparse(url).path)
    filename = url_name if file_name is None else file_name
    cached_file = os.path.join(model_dir, filename)

    # Cache hit: nothing to download.
    if os.path.exists(cached_file):
        return cached_file

    sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file))
    hash_prefix = None
    if check_hash:
        match = HASH_REGEX.search(filename)
        if match:
            hash_prefix = match.group(1)
    download_url_to_file(url, cached_file, hash_prefix, progress=progress)
    return cached_file
示例#4
0
def load_file_from_url(url, model_dir=None, progress=True, file_name=None):
    """Fetch a file from an http url, reusing a previously downloaded copy.

    Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py

    Args:
        url (str): URL to be downloaded.
        model_dir (str): Directory the file is stored in. If None, the
            ``checkpoints`` folder of the pytorch hub_dir is used. Default: None.
        progress (bool): Whether to show the download progress. Default: True.
        file_name (str): Name of the stored file. If None, the file name from
            the url path is used. Default: None.

    Returns:
        str: Absolute path to the downloaded (or already cached) file.
    """
    # Fall back to the pytorch hub_dir when no directory is given.
    target_dir = os.path.join(get_dir(), 'checkpoints') if model_dir is None else model_dir
    os.makedirs(target_dir, exist_ok=True)

    url_name = os.path.basename(urlparse(url).path)
    chosen_name = url_name if file_name is None else file_name
    cached_file = os.path.abspath(os.path.join(target_dir, chosen_name))

    if os.path.exists(cached_file):
        return cached_file

    print(f'Downloading: "{url}" to {cached_file}\n')
    download_url_to_file(url, cached_file, hash_prefix=None, progress=progress)
    return cached_file
def download_url(url, download_dir=None, filename=None):
    """Return the local path for ``url``, downloading the file if it is not cached.

    Args:
        url: URL of the file.
        download_dir: Directory for the cached file; defaults to
            ``<hub_dir>/checkpoints``. Created if it does not exist.
        filename: Name of the cached file; defaults to the basename of the
            URL path.

    Returns:
        Path of the cached file inside ``download_dir``.
    """
    parts = urlparse(url)
    if download_dir is None:
        hub_dir = get_dir()
        download_dir = os.path.join(hub_dir, 'checkpoints')
    if filename is None:
        filename = os.path.basename(parts.path)

    # Make sure the target directory exists before anything is written to it;
    # otherwise the first download into a fresh cache directory fails.
    os.makedirs(download_dir, exist_ok=True)

    cached_file = os.path.join(download_dir, filename)
    if not os.path.exists(cached_file):
        sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file))
        download_url_to_file(url, cached_file)
    return cached_file
示例#6
0
def get_cache_dir(child_dir=''):
    """
    Return the ``checkpoints`` directory below the torch hub dir, creating it if needed.

    A non-empty ``child_dir`` is appended as an extra path component.
    """
    # Warn when the old environment variable is still set.
    if os.getenv('TORCH_MODEL_ZOO'):
        _logger.warning('TORCH_MODEL_ZOO is deprecated, please use env TORCH_HOME instead')

    segments = [get_dir(), 'checkpoints']
    if child_dir:
        segments.append(child_dir)
    model_dir = os.path.join(*segments)
    os.makedirs(model_dir, exist_ok=True)
    return model_dir
示例#7
0
def load_custom_pretrained(model, cfg=None, load_fn=None, progress=False, check_hash=False):
    r"""Loads a custom (read non .pth) weight file

    Downloads the checkpoint file into ``<hub_dir>/checkpoints`` (where
    `hub_dir` is the directory returned by :func:`~torch.hub.get_dir`) unless it
    is already present there, then hands the cached file path to `load_fn`, or
    to the model's `load_pretrained` method when no `load_fn` is given.

    If no URL is configured, or no loader is available, a warning is logged and
    the model is left untouched.

    Args:
        model: The instantiated model to load weights into
        cfg (dict): Default pretrained model cfg; falls back to
            ``model.default_cfg`` when empty or None
        load_fn: An external stand alone fn that loads weights into provided model, otherwise a fn named
            'load_pretrained' on the model will be called if it exists
        progress (bool, optional): whether or not to display a progress bar to stderr. Default: False
        check_hash(bool, optional): If True, the filename part of the URL should follow the naming convention
            ``filename-<sha256>.ext`` where ``<sha256>`` is the first eight or more
            digits of the SHA256 hash of the contents of the file. The hash is used to
            ensure unique names and to verify the contents of the file. Default: False
    """
    # Use a default so that models without a `default_cfg` attribute reach the
    # "no pretrained weights" warning below instead of raising AttributeError.
    cfg = cfg or getattr(model, 'default_cfg', None)
    if cfg is None or not cfg.get('url', None):
        _logger.warning("No pretrained weights exist for this model. Using random initialization.")
        return
    url = cfg['url']

    # Issue warning to move data if old env is set
    if os.getenv('TORCH_MODEL_ZOO'):
        _logger.warning('TORCH_MODEL_ZOO is deprecated, please use env TORCH_HOME instead')

    hub_dir = get_dir()
    model_dir = os.path.join(hub_dir, 'checkpoints')

    os.makedirs(model_dir, exist_ok=True)

    parts = urlparse(url)
    filename = os.path.basename(parts.path)
    cached_file = os.path.join(model_dir, filename)
    if not os.path.exists(cached_file):
        _logger.info('Downloading: "{}" to {}\n'.format(url, cached_file))
        hash_prefix = None
        if check_hash:
            r = HASH_REGEX.search(filename)  # r is Optional[Match[str]]
            hash_prefix = r.group(1) if r else None
        download_url_to_file(url, cached_file, hash_prefix, progress=progress)

    if load_fn is not None:
        load_fn(model, cached_file)
    elif hasattr(model, 'load_pretrained'):
        model.load_pretrained(cached_file)
    else:
        _logger.warning("Valid function to load pretrained weights is not available, using random initialization.")
示例#8
0
文件: hub.py 项目: bernardomig/ark
def load_state_dict_from_url(url, model_dir=None, map_location=None, progress=True, check_hash=False, file_name=None):
    """A modified version of `torch.hub.load_state_dict_from_url`, which handles
    the new serialization protocol differently: the cached file is always read
    with a plain `torch.load`.
    See <https://github.com/pytorch/pytorch/issues/43106> for more information.

    Args:
        url: URL of the serialized object.
        model_dir: Cache directory; defaults to ``<hub_dir>/checkpoints``.
        map_location: Forwarded to `torch.load`.
        progress: Forwarded to `download_url_to_file`.
        check_hash: If True, the first `HASH_REGEX` group found in the file name
            is passed to `download_url_to_file` as the hash prefix.
        file_name: Overrides the file name taken from the URL path.

    Returns:
        Whatever `torch.load` returns for the cached file.
    """
    import errno
    import os
    import sys
    import warnings
    from urllib.parse import urlparse

    import torch
    from torch.hub import HASH_REGEX, download_url_to_file, get_dir

    # Warn when the old environment variable is still set.
    if os.getenv('TORCH_MODEL_ZOO'):
        warnings.warn('TORCH_MODEL_ZOO is deprecated, please use env TORCH_HOME instead')

    if model_dir is None:
        model_dir = os.path.join(get_dir(), 'checkpoints')

    try:
        os.makedirs(model_dir)
    except OSError as exc:
        # An already existing path is fine; anything else is unexpected.
        if exc.errno != errno.EEXIST:
            raise

    url_name = os.path.basename(urlparse(url).path)
    filename = url_name if file_name is None else file_name
    cached_file = os.path.join(model_dir, filename)

    needs_download = not os.path.exists(cached_file)
    if needs_download:
        sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file))
        hash_prefix = None
        if check_hash:
            match = HASH_REGEX.search(filename)
            if match:
                hash_prefix = match.group(1)
        download_url_to_file(url, cached_file, hash_prefix, progress=progress)

    return torch.load(cached_file, map_location=map_location)
示例#9
0
def _resnet(arch, block, layers, pretrained, progress, **kwargs):
    """Build a ``ResNet`` and, if ``pretrained``, load weights for ``arch``.

    'resnet101v2' weights are fetched with ``gdown`` into
    ``<hub_dir>/checkpoints/resnet101v2.pth`` and loaded non-strictly; every
    other architecture goes through ``load_state_dict_from_url``.
    """
    model = ResNet(block, layers, **kwargs)
    if not pretrained:
        return model

    if arch == 'resnet101v2':
        import gdown
        from torch.hub import get_dir
        import os

        checkpoint_dir = os.path.join(get_dir(), 'checkpoints')
        os.makedirs(checkpoint_dir, exist_ok=True)
        weights_path = os.path.join(checkpoint_dir, 'resnet101v2.pth')
        # Only download when the file is not cached yet.
        if not os.path.exists(weights_path):
            gdown.download(model_urls[arch], weights_path, quiet=False)
        weights = torch.load(weights_path)
        model.load_state_dict(weights, strict=False)
    else:
        weights = load_state_dict_from_url(model_urls[arch],
                                           progress=progress)
        model.load_state_dict(weights)
    return model
示例#10
0
    def load_onnx_from_zoo(self, model_name, sensor, version="onnx"):
        """Load a zoo model into an ONNX session, downloading the file if needed.

        The URL comes from ``self._get_zoo_model_url(model_name, sensor.name,
        version)``; the file is cached under ``<hub_dir>/checkpoints`` using the
        base name of the URL path.

        Returns:
            Whatever ``self.load_onnx_session`` returns for the cached file.
        """
        model_url = self._get_zoo_model_url(model_name, sensor.name, version)
        hub_dir = os.path.join(hub.get_dir(), "checkpoints")

        # The old guard tested ``os.path.isfile`` on a directory path, which is
        # never true, so ``makedirs`` raised FileExistsError as soon as the
        # cache directory existed.
        os.makedirs(hub_dir, exist_ok=True)

        # from https://github.com/pytorch/pytorch/blob/master/torch/hub.py
        parts = urlparse(model_url)
        filename = os.path.basename(parts.path)
        cached_file = os.path.join(hub_dir, filename)
        if not os.path.exists(cached_file):
            sys.stderr.write('Downloading: "{}" to {}\n'.format(
                model_url, cached_file))
            hub.download_url_to_file(model_url,
                                     cached_file,
                                     None,
                                     progress=True)

        return self.load_onnx_session(cached_file)
示例#11
0
 def download_model_from_zoo(self,
                             model_name,
                             sensor,
                             dst=None,
                             save_local=False,
                             version="pth"):
     """Return the local path of a zoo model file, downloading it if needed.

     Args:
         model_name: Model name; used for the URL lookup and the file name.
         sensor: Object whose ``zoo_name()`` is used for the URL lookup and
             the file name.
         dst: With ``save_local=True`` the target directory itself (required);
             otherwise an optional sub-directory of ``<hub_dir>/checkpoints``.
         save_local: Whether ``dst`` is used as-is instead of being placed
             below the hub checkpoints directory.
         version: File extension, also passed to the URL lookup.

     Returns:
         Path of ``<model_name>_<sensor.zoo_name()>.<version>`` in the target
         directory.

     Raises:
         ValueError: if ``save_local`` is True and ``dst`` is None.
     """
     model_url = self._get_zoo_model_url(model_name, sensor.zoo_name(),
                                         version)
     if save_local:
         if dst is None:
             # The default dst used to end up in os.makedirs(None).
             raise ValueError("dst must be given when save_local=True")
         model_dst = dst
     elif dst is None:
         # The default dst used to be passed to os.path.join and raise
         # TypeError; use the checkpoints directory itself instead.
         model_dst = os.path.join(hub.get_dir(), "checkpoints")
     else:
         model_dst = os.path.join(hub.get_dir(), "checkpoints", dst)
     os.makedirs(model_dst, exist_ok=True)
     cached_file = os.path.join(
         model_dst, f"{model_name}_{sensor.zoo_name()}.{version}")
     if not os.path.exists(cached_file):
         sys.stderr.write(f"Downloading: {model_url} to {cached_file}\n")
         hub.download_url_to_file(model_url, cached_file)
     return cached_file
示例#12
0
def load_file_from_url(url, model_dir=None, progress=True, file_name=None):
    """Return the absolute local path for ``url``, downloading the file if needed.

    Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py

    ``model_dir`` (default: ``<hub_dir>/checkpoints``) is joined onto
    ``ROOT_DIR``; ``file_name`` overrides the file name taken from the url path.
    """
    if model_dir is None:
        model_dir = os.path.join(get_dir(), 'checkpoints')

    # NOTE: os.path.join discards ROOT_DIR when model_dir is an absolute path.
    target_dir = os.path.join(ROOT_DIR, model_dir)
    os.makedirs(target_dir, exist_ok=True)

    url_name = os.path.basename(urlparse(url).path)
    chosen_name = url_name if file_name is None else file_name
    cached_file = os.path.abspath(os.path.join(target_dir, chosen_name))

    if os.path.exists(cached_file):
        return cached_file

    print(f'Downloading: "{url}" to {cached_file}\n')
    download_url_to_file(url,
                         cached_file,
                         hash_prefix=None,
                         progress=progress)
    return cached_file