def test_default_mmcv_home():
    """Check the fallback MMCV home and the bundled model-zoo table.

    With both ``ENV_MMCV_HOME`` and ``ENV_XDG_CACHE_HOME`` removed from the
    environment, ``_get_mmcv_home()`` must equal the user-expanded
    ``DEFAULT_CACHE_DIR/mmcv`` and ``get_external_models()`` must equal the
    content of ``model_zoo/open_mmlab.json`` inside the mmcv package.
    """
    # Drop both overrides so only the built-in default can apply.
    for env_name in (ENV_MMCV_HOME, ENV_XDG_CACHE_HOME):
        os.environ.pop(env_name, None)

    resolved_home = _get_mmcv_home()
    default_home = os.path.expanduser(os.path.join(DEFAULT_CACHE_DIR, 'mmcv'))
    assert resolved_home == default_home

    external_urls = get_external_models()
    bundled_json = osp.join(mmcv.__path__[0], 'model_zoo/open_mmlab.json')
    assert external_urls == mmcv.load(bundled_json)
def test_get_external_models():
    """Check ``get_external_models()`` when ``ENV_MMCV_HOME`` is overridden.

    The variable is pointed at the ``data/model_zoo/mmcv_home/`` fixture
    directory next to this test file, and the returned mapping is compared
    against the expected name -> location table below.
    """
    os.environ.pop(ENV_MMCV_HOME, None)
    tests_dir = osp.dirname(__file__)
    fixture_home = osp.join(tests_dir, 'data/model_zoo/mmcv_home/')
    os.environ[ENV_MMCV_HOME] = fixture_home

    actual_urls = get_external_models()
    expected_urls = {
        'train': 'https://localhost/train.pth',
        'test': 'test.pth',
        'val': 'val.pth',
        'train_empty': 'train.pth',
    }
    assert actual_urls == expected_urls
def _load_checkpoint(filename, map_location=None): """Load checkpoint from somewhere (modelzoo, file, url). Args: filename (str): Accept local filepath, URL, ``torchvision://xxx``, ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for details. map_location (str | None): Same as :func:`torch.load`. Default: None. Returns: dict | OrderedDict: The loaded checkpoint. It can be either an OrderedDict storing model weights or a dict containing other information, which depends on the checkpoint. """ if filename.startswith('modelzoo://'): warnings.warn('The URL scheme of "modelzoo://" is deprecated, please ' 'use "torchvision://" instead') model_urls = get_torchvision_models() model_name = filename[11:] checkpoint = load_url_dist(model_urls[model_name]) elif filename.startswith('torchvision://'): model_urls = get_torchvision_models() model_name = filename[14:] checkpoint = load_url_dist(model_urls[model_name]) elif filename.startswith('open-mmlab://'): model_urls = get_external_models() model_name = filename[13:] deprecated_urls = get_deprecated_model_names() if model_name in deprecated_urls: warnings.warn(f'open-mmlab://{model_name} is deprecated in favor ' f'of open-mmlab://{deprecated_urls[model_name]}') model_name = deprecated_urls[model_name] model_url = model_urls[model_name] # check if is url if model_url.startswith(('http://', 'https://')): checkpoint = load_url_dist(model_url, map_location=map_location) else: filename = osp.join(_get_mmcv_home(), model_url) if not osp.isfile(filename): raise IOError(f'{filename} is not a checkpoint file') checkpoint = torch.load(filename, map_location=map_location) elif filename.startswith(('http://', 'https://')): checkpoint = load_url_dist(filename, map_location=map_location) else: if not osp.isfile(filename): raise IOError(f'{filename} is not a checkpoint file') checkpoint = torch.load(filename, map_location=map_location) return checkpoint