Beispiel #1
0
def get_state_dict(filename, map_location='cpu'):
    """Load a checkpoint and return its weights with wrapper prefixes removed.

    Args:
        filename (str): Checkpoint location (local filepath, URL,
            ``torchvision://xxx`` or ``open-mmlab://xxx``); handed to
            ``_load_checkpoint`` unchanged.
        map_location (str): Same as :func:`torch.load`; handed to
            ``_load_checkpoint`` unchanged.

    Returns:
        OrderedDict: The entries found under the checkpoint's ``state_dict``
        key (or the checkpoint's own entries when that key is absent), with
        one leading ``module.backbone.``, ``module.`` or ``backbone.``
        removed from each key.

    Raises:
        RuntimeError: If the loaded checkpoint is not a ``dict``.
    """
    loaded = _load_checkpoint(filename, map_location)
    # An OrderedDict also passes this check because it subclasses dict.
    if not isinstance(loaded, dict):
        raise RuntimeError(
            f'No state_dict found in checkpoint file {filename}')

    # The weights are either nested under 'state_dict' or are the checkpoint.
    raw_weights = loaded['state_dict'] if 'state_dict' in loaded else loaded

    # Order matters: the combined prefix must be tried before its parts so
    # that 'module.backbone.x' becomes 'x' rather than 'backbone.x'.
    known_prefixes = ('module.backbone.', 'module.', 'backbone.')

    cleaned = OrderedDict()
    for name, value in raw_weights.items():
        for prefix in known_prefixes:
            if name.startswith(prefix):
                # Only the first matching prefix is removed.
                name = name[len(prefix):]
                break
        cleaned[name] = value
    return cleaned
Beispiel #2
0
def load_checkpoint(model,
                    filename,
                    map_location=None,
                    strict=False,
                    logger=None):
    """Read a checkpoint and load its (prefix-stripped) weights into a model.

    Args:
        model (Module): Module that receives the weights.
        filename (str): Checkpoint location (local filepath, URL,
            ``torchvision://xxx`` or ``open-mmlab://xxx``). Please refer to
            ``docs/model_zoo.md`` for details.
        map_location (str): Same as :func:`torch.load`.
        strict (bool): Whether to allow different params for the model and
            checkpoint; passed through to ``load_state_dict``.
        logger (:mod:`logging.Logger` or None): The logger for error message;
            passed through to ``load_state_dict``.

    Returns:
        dict or OrderedDict: The checkpoint exactly as loaded (not the
        prefix-stripped copy).

    Raises:
        RuntimeError: If the loaded checkpoint is not a ``dict``.
    """

    def _without_prefix(key):
        # Try the combined prefix first; only one prefix is ever removed.
        for prefix in ('module.backbone.', 'module.', 'backbone.'):
            if key.startswith(prefix):
                return key[len(prefix):]
        return key

    checkpoint = _load_checkpoint(filename, map_location)
    # dict also covers OrderedDict, which is a subclass of it.
    if not isinstance(checkpoint, dict):
        raise RuntimeError(
            f'No state_dict found in checkpoint file {filename}')

    # Weights live under 'state_dict' when present, else at the top level.
    if 'state_dict' in checkpoint:
        source_weights = checkpoint['state_dict']
    else:
        source_weights = checkpoint

    # Rebuild the mapping in the original order with cleaned-up keys.
    state_dict = OrderedDict(
        (_without_prefix(key), value)
        for key, value in source_weights.items())

    load_state_dict(model, state_dict, strict, logger)
    return checkpoint
Beispiel #3
0
def test_load_external_url():
    """Check what ``_load_checkpoint`` returns for each supported URI form."""
    resnet50_url = ('url:https://download.pytorch.org/models/'
                    'resnet50-19c8e357.pth')
    train_url = 'url:https://localhost/train.pth'
    not_a_checkpoint = 'train.pth is not a checkpoint file'

    def drop_home_env():
        # Clear both variables so the default MMCV_HOME is in effect.
        os.environ.pop(ENV_MMCV_HOME, None)
        os.environ.pop(ENV_XDG_CACHE_HOME, None)

    # modelzoo:// and torchvision:// are expected to give the same URL
    for uri in ('modelzoo://resnet50', 'torchvision://resnet50'):
        resolved = _load_checkpoint(uri)
        assert resolved == resnet50_url

    # open-mmlab:// with default MMCV_HOME
    drop_home_env()
    resolved = _load_checkpoint('open-mmlab://train')
    assert resolved == train_url

    # open-mmlab:// and openmmlab:// with a deprecated model name: a warning
    # must be emitted and the replacement model's URL returned
    deprecated_cases = (
        ('open-mmlab://train_old',
         'open-mmlab://train_old is deprecated in favor of '
         'open-mmlab://train'),
        ('openmmlab://train_old',
         'openmmlab://train_old is deprecated in favor of '
         'openmmlab://train'),
    )
    for uri, warning_pattern in deprecated_cases:
        drop_home_env()
        with pytest.warns(Warning, match=warning_pattern):
            resolved = _load_checkpoint(uri)
            assert resolved == train_url

    # open-mmlab:// with user-defined MMCV_HOME
    os.environ.pop(ENV_MMCV_HOME, None)
    os.environ[ENV_MMCV_HOME] = osp.join(
        osp.dirname(__file__), 'data/model_zoo/mmcv_home')
    resolved = _load_checkpoint('open-mmlab://train')
    assert resolved == train_url
    with pytest.raises(IOError, match=not_a_checkpoint):
        _load_checkpoint('open-mmlab://train_empty')
    local_cases = (
        ('open-mmlab://test', 'test.pth'),
        ('open-mmlab://val', 'val.pth'),
    )
    for uri, basename in local_cases:
        resolved = _load_checkpoint(uri)
        # The home directory is looked up only after the load call.
        assert resolved == 'local:' + osp.join(_get_mmcv_home(), basename)

    # http:// https://
    resolved = _load_checkpoint('http://localhost/train.pth')
    assert resolved == 'url:http://localhost/train.pth'

    # local file
    with pytest.raises(IOError, match=not_a_checkpoint):
        _load_checkpoint('train.pth')
    resolved = _load_checkpoint(osp.join(_get_mmcv_home(), 'test.pth'))
    assert resolved == 'local:' + osp.join(_get_mmcv_home(), 'test.pth')
def _torch_older_than_1_9(version):
    """Return True when ``version`` names a torch release before 1.9.

    The numeric major/minor components are compared instead of the raw
    strings, because string ordering is lexicographic and would place
    '1.10.0' before '1.9.0'.

    Args:
        version (str): Version string such as '1.8.1+cu111' or '1.10.0'.

    Returns:
        bool: True if the leading ``<major>.<minor>`` is below (1, 9). A
        string without a leading numeric ``<major>.<minor>`` is treated as
        not older, matching what the previous plain string comparison gave
        for alphabetic names.
    """
    import re

    match = re.match(r'(\d+)\.(\d+)', str(version))
    if match is None:
        return False
    return (int(match.group(1)), int(match.group(2))) < (1, 9)


def test_load_external_url():
    """Check what ``_load_checkpoint`` returns for each supported URI form."""
    # filename of checkpoint is renamed in torch1.9.0
    if _torch_older_than_1_9(TORCH_VERSION):
        resnet50_url = ('url:https://download.pytorch.org/models/resnet50-19c8e'
                        '357.pth')
    else:
        resnet50_url = ('url:https://download.pytorch.org/models/resnet50-0676b'
                        'a61.pth')

    # test modelzoo://
    url = _load_checkpoint('modelzoo://resnet50')
    assert url == resnet50_url

    # test torchvision://
    url = _load_checkpoint('torchvision://resnet50')
    assert url == resnet50_url

    # test open-mmlab:// with default MMCV_HOME
    os.environ.pop(ENV_MMCV_HOME, None)
    os.environ.pop(ENV_XDG_CACHE_HOME, None)
    url = _load_checkpoint('open-mmlab://train')
    assert url == 'url:https://localhost/train.pth'

    # test open-mmlab:// with deprecated model name
    os.environ.pop(ENV_MMCV_HOME, None)
    os.environ.pop(ENV_XDG_CACHE_HOME, None)
    with pytest.warns(Warning,
                      match='open-mmlab://train_old is deprecated in favor of '
                      'open-mmlab://train'):
        url = _load_checkpoint('open-mmlab://train_old')
        assert url == 'url:https://localhost/train.pth'

    # test openmmlab:// with deprecated model name
    os.environ.pop(ENV_MMCV_HOME, None)
    os.environ.pop(ENV_XDG_CACHE_HOME, None)
    with pytest.warns(Warning,
                      match='openmmlab://train_old is deprecated in favor of '
                      'openmmlab://train'):
        url = _load_checkpoint('openmmlab://train_old')
        assert url == 'url:https://localhost/train.pth'

    # test open-mmlab:// with user-defined MMCV_HOME
    os.environ.pop(ENV_MMCV_HOME, None)
    mmcv_home = osp.join(osp.dirname(__file__), 'data/model_zoo/mmcv_home')
    os.environ[ENV_MMCV_HOME] = mmcv_home
    url = _load_checkpoint('open-mmlab://train')
    assert url == 'url:https://localhost/train.pth'
    with pytest.raises(FileNotFoundError, match='train.pth can not be found.'):
        _load_checkpoint('open-mmlab://train_empty')
    url = _load_checkpoint('open-mmlab://test')
    assert url == f'local:{osp.join(_get_mmcv_home(), "test.pth")}'
    url = _load_checkpoint('open-mmlab://val')
    assert url == f'local:{osp.join(_get_mmcv_home(), "val.pth")}'

    # test http:// https://
    url = _load_checkpoint('http://localhost/train.pth')
    assert url == 'url:http://localhost/train.pth'

    # test local file
    with pytest.raises(FileNotFoundError, match='train.pth can not be found.'):
        _load_checkpoint('train.pth')
    url = _load_checkpoint(osp.join(_get_mmcv_home(), 'test.pth'))
    assert url == f'local:{osp.join(_get_mmcv_home(), "test.pth")}'
Beispiel #5
0
def load_checkpoint(model,
                    filename,
                    map_location='cpu',
                    strict=False,
                    logger=None,
                    force_matching=False,
                    show_converted=False,
                    ignore_prefixes=None,
                    ignore_suffixes=None):
    """Load a checkpoint into ``model`` using a per-dataset class mapping.

    Args:
        model: Object to load the weights into. If it has a ``module``
            attribute, that inner object is used instead. The resulting
            object must have a ``CLASSES`` dict keyed by dataset id.
        filename: Checkpoint location; forwarded to ``_load_checkpoint``.
        map_location: Forwarded to ``_load_checkpoint``.
        strict, logger, force_matching, show_converted, ignore_prefixes,
            ignore_suffixes: Forwarded unchanged to ``load_state_dict``.

    Returns:
        dict: The checkpoint as returned by ``_load_checkpoint``.

    Raises:
        RuntimeError: If the loaded checkpoint is not a ``dict``.
        AssertionError: If the model has no ``CLASSES`` dict, or if its
            datasets/classes are not a subset of those recorded in the
            checkpoint's ``meta['CLASSES']``.
    """
    # load checkpoint
    checkpoint = _load_checkpoint(filename, map_location)
    if not isinstance(checkpoint, dict):
        raise RuntimeError(
            f'No state_dict found in checkpoint file {filename}')

    # get state_dict from checkpoint
    if 'state_dict' in checkpoint:
        state_dict = checkpoint['state_dict']
    else:
        state_dict = checkpoint

    # strip prefix of state_dict
    # The first key decides whether stripping happens at all. An empty
    # state_dict yields '' here instead of raising IndexError.
    first_key = next(iter(state_dict), '')
    if first_key.startswith('module.'):
        # Iterate the selected state_dict (not checkpoint['state_dict'],
        # which does not exist for bare state dicts) and leave keys
        # without the prefix untouched rather than chopping 7 characters.
        state_dict = {
            (k[7:] if k.startswith('module.') else k): v
            for k, v in state_dict.items()
        }

    # extract model
    model = model.module if hasattr(model, 'module') else model

    # load model classes
    assert hasattr(model, 'CLASSES')
    assert isinstance(model.CLASSES, dict)
    model_all_classes = model.CLASSES

    # build class mapping between model.classes and checkpoint.classes
    if 'meta' in checkpoint and 'CLASSES' in checkpoint['meta']:
        checkpoint_all_classes = checkpoint['meta']['CLASSES']

        assert set(model_all_classes.keys()).issubset(checkpoint_all_classes.keys()),\
            'The model set of datasets is not a subset of checkpoint datasets: ' \
            f'{model_all_classes.keys()} vs {checkpoint_all_classes.keys()}'

        class_maps = dict()
        for dataset_id, model_dataset_classes in model_all_classes.items():
            checkpoint_dataset_classes = checkpoint_all_classes[dataset_id]
            assert set(model_dataset_classes.values()).issubset(checkpoint_dataset_classes.values()), \
                'The model set of classes is not a subset of checkpoint classes'

            # Invert the checkpoint mapping so a class value can be looked
            # up to find the key it has in the checkpoint.
            checkpoint_inv_class_map = {
                v: k
                for k, v in checkpoint_dataset_classes.items()
            }
            # model key -> checkpoint key holding the same class value
            class_maps[dataset_id] = {
                k: checkpoint_inv_class_map[v]
                for k, v in model_dataset_classes.items()
            }
    else:
        # No class metadata in the checkpoint: pass the model's own
        # CLASSES dict through as the mapping.
        class_maps = model_all_classes

    # load weights
    load_state_dict(model, state_dict, class_maps, strict, logger,
                    force_matching, show_converted, ignore_prefixes,
                    ignore_suffixes)

    return checkpoint