Example #1
File: runner.py  Project: youshyee/CEP
    def load_checkpoint(self,
                        filename,
                        map_location='cpu',
                        strict=False,
                        part=None):
        self.logger.info('load checkpoint from %s', filename)

        if not osp.isfile(filename):
            raise IOError(f'{filename} is not a checkpoint file')
        checkpoint = torch.load(filename, map_location=map_location)

        # OrderedDict is a subclass of dict
        if not isinstance(checkpoint, dict):
            raise RuntimeError(
                f'No state_dict found in checkpoint file {filename}')
        # get state_dict from checkpoint
        if 'state_dict' in checkpoint:
            state_dict = checkpoint['state_dict']
        else:
            state_dict = checkpoint
        # strip prefix of state_dict
        if list(state_dict.keys())[0].startswith('module.'):
            state_dict = {
                k[7:]: v
                for k, v in state_dict.items()
            }
        if part is not None:
            state_dict = {k: v for k, v in state_dict.items() if part in k}
        # load state_dict
        load_state_dict(self.model, state_dict, strict, self.logger)
        return checkpoint
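A minimal usage sketch of the method above; the runner instance, checkpoint path, and part value are placeholders, not taken from the source project:

# Load only the backbone weights onto CPU, tolerating missing or unexpected
# keys (strict=False). `runner` is an instance of the class defining the
# method above; the path and the 'backbone' filter are illustrative.
checkpoint = runner.load_checkpoint('work_dirs/latest.pth',
                                    map_location='cpu',
                                    strict=False,
                                    part='backbone')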
Example #2
 def init_weights(self, pretrained=None):
     if isinstance(pretrained, str):
         logger = get_root_logger()
         #load_checkpoint(self, pretrained, map_location='cpu', strict=False, logger=logger)
         revise_keys = [(r'^module\.', '')]
         checkpoint = _load_checkpoint(pretrained,
                                       map_location='cpu',
                                       logger=logger)
         # OrderedDict is a subclass of dict
         if not isinstance(checkpoint, dict):
             raise RuntimeError(
                 f'No state_dict found in checkpoint file {pretrained}')
         # get state_dict from checkpoint
         if 'state_dict' in checkpoint:
             state_dict = checkpoint['state_dict']
         elif 'model' in checkpoint:  # for our model
             state_dict = checkpoint['model']
         else:
             state_dict = checkpoint
         # strip prefix of state_dict
         for p, r in revise_keys:
             state_dict = {
                 re.sub(p, r, k): v
                 for k, v in state_dict.items()
             }
         # load state_dict
         load_state_dict(self, state_dict, strict=False, logger=logger)
Example #3
    def init_weights(self, pretrained=None):
        """Initiate the parameters either from existing checkpoint or from
        scratch."""
        trunc_normal_(self.pos_embed, std=.02)
        trunc_normal_(self.cls_token, std=.02)

        if pretrained:
            self.pretrained = pretrained
        if isinstance(self.pretrained, str):
            logger = get_root_logger()
            logger.info(f'load model from: {self.pretrained}')

            state_dict = _load_checkpoint(self.pretrained)
            if 'state_dict' in state_dict:
                state_dict = state_dict['state_dict']

            if self.attention_type == 'divided_space_time':
                # modify the key names of norm layers
                old_state_dict_keys = list(state_dict.keys())
                for old_key in old_state_dict_keys:
                    if 'norms' in old_key:
                        new_key = old_key.replace('norms.0',
                                                  'attentions.0.norm')
                        new_key = new_key.replace('norms.1', 'ffns.0.norm')
                        state_dict[new_key] = state_dict.pop(old_key)

                # copy the parameters of space attention to time attention
                old_state_dict_keys = list(state_dict.keys())
                for old_key in old_state_dict_keys:
                    if 'attentions.0' in old_key:
                        new_key = old_key.replace('attentions.0',
                                                  'attentions.1')
                        state_dict[new_key] = state_dict[old_key].clone()

            load_state_dict(self, state_dict, strict=False, logger=logger)
Example #4
 def init_from_pretrained(self, pretrained, logger):
     """ Initialization model weights from pretrained model.
     Args:
         pretrained (str): pretrained model path. (something like *.pth)
         logger (logging.Logger): output logger
     """
     logger.info(f"Loading pretrained backbone from {pretrained}")
     checkpoint = torch.load(pretrained)
     # get state_dict from checkpoint
     if isinstance(checkpoint, OrderedDict):
         state_dict = checkpoint
     elif isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
         state_dict = checkpoint['state_dict']
     else:
         raise RuntimeError(
             f'No state_dict found in checkpoint file {pretrained}')
     # strip prefix of state_dict
     if list(state_dict.keys())[0].startswith('module.'):
         state_dict = {
             k[7:]: v
             for k, v in state_dict.items()
         }
     # strip prefix of backbone
     if any([s.startswith('backbone.') for s in state_dict.keys()]):
         state_dict = {
             k[9:]: v
             for k, v in state_dict.items()
             if k.startswith('backbone.')
         }
     load_state_dict(self, state_dict, strict=False, logger=logger)
Example #5
    def init_weights(self, pretrained=None):
        if isinstance(pretrained, str):
            logger = get_root_logger()

            checkpoint = _load_checkpoint(pretrained,
                                          logger=logger,
                                          map_location='cpu')
            logger.warning(f'Load pre-trained model for '
                           f'{self.__class__.__name__} from original repo')
            if 'state_dict' in checkpoint:
                state_dict = checkpoint['state_dict']
            elif 'model' in checkpoint:
                state_dict = checkpoint['model']
            else:
                state_dict = checkpoint

            if self.convert_weights:
                # We need to convert pre-trained weights to match this
                # implementation.
                state_dict = tcformer_convert(state_dict)
            load_state_dict(self, state_dict, strict=False, logger=logger)

        elif pretrained is None:
            for m in self.modules():
                if isinstance(m, nn.Linear):
                    trunc_normal_init(m, std=.02, bias=0.)
                elif isinstance(m, nn.LayerNorm):
                    constant_init(m, 1.0)
                elif isinstance(m, nn.Conv2d):
                    fan_out = m.kernel_size[0] * m.kernel_size[
                        1] * m.out_channels
                    fan_out //= m.groups
                    normal_init(m, 0, math.sqrt(2.0 / fan_out))
        else:
            raise TypeError('pretrained must be a str or None')
Example #6
 def __call__(self, module):
     from mmcv.runner import (_load_checkpoint_with_prefix, load_checkpoint,
                              load_state_dict)
     logger = get_logger('mmcv')
     if self.prefix is None:
         print_log(f'load model from: {self.checkpoint}', logger=logger)
         load_checkpoint(module,
                         self.checkpoint,
                         map_location=self.map_location,
                         strict=False,
                         logger=logger)
     else:
         print_log(f'load {self.prefix} in model from: {self.checkpoint}',
                   logger=logger)
         state_dict = _load_checkpoint_with_prefix(
             self.prefix, self.checkpoint, map_location=self.map_location)
         load_state_dict(module, state_dict, strict=False, logger=logger)
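A hypothetical invocation of the initializer above; the class name PretrainedInit and its constructor arguments are assumptions based on the attributes (checkpoint, prefix, map_location) that __call__ reads:

# Load only the 'backbone.' sub-dict of a checkpoint into `model` (both the
# class name and the paths below are illustrative placeholders).
init = PretrainedInit(checkpoint='work_dirs/detector.pth',
                      prefix='backbone.',
                      map_location='cpu')
init(model)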
Example #7
    def init_weights(self):
        if (isinstance(self.init_cfg, dict)
                and self.init_cfg.get('type') == 'Pretrained'):
            logger = get_root_logger()
            checkpoint = CheckpointLoader.load_checkpoint(
                self.init_cfg['checkpoint'], logger=logger, map_location='cpu')

            if 'state_dict' in checkpoint:
                state_dict = checkpoint['state_dict']
            else:
                state_dict = checkpoint

            if 'pos_embed' in state_dict.keys():
                if self.pos_embed.shape != state_dict['pos_embed'].shape:
                    logger.info(msg=f'Resize the pos_embed shape from '
                                f'{state_dict["pos_embed"].shape} to '
                                f'{self.pos_embed.shape}')
                    h, w = self.img_size
                    pos_size = int(
                        math.sqrt(state_dict['pos_embed'].shape[1] - 1))
                    state_dict['pos_embed'] = self.resize_pos_embed(
                        state_dict['pos_embed'],
                        (h // self.patch_size, w // self.patch_size),
                        (pos_size, pos_size), self.interpolate_mode)

            load_state_dict(self, state_dict, strict=False, logger=logger)
        elif self.init_cfg is not None:
            super(VisionTransformer, self).init_weights()
        else:
            # We only implement the 'jax_impl' initialization implemented at
            # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py#L353  # noqa: E501
            trunc_normal_(self.pos_embed, std=.02)
            trunc_normal_(self.cls_token, std=.02)
            for n, m in self.named_modules():
                if isinstance(m, nn.Linear):
                    trunc_normal_(m.weight, std=.02)
                    if m.bias is not None:
                        if 'ffn' in n:
                            nn.init.normal_(m.bias, mean=0., std=1e-6)
                        else:
                            nn.init.constant_(m.bias, 0)
                elif isinstance(m, nn.Conv2d):
                    kaiming_init(m, mode='fan_in', bias=0.)
                elif isinstance(m, (_BatchNorm, nn.GroupNorm, nn.LayerNorm)):
                    constant_init(m, val=1.0, bias=0.)
Example #8
    def check_and_load_prev_weight(self, curr_scale):
        if curr_scale == 0:
            return
        prev_ch = self.blocks[curr_scale - 1].base_channels
        curr_ch = self.blocks[curr_scale].base_channels

        prev_in_ch = self.blocks[curr_scale - 1].in_channels
        curr_in_ch = self.blocks[curr_scale].in_channels
        if prev_ch == curr_ch and prev_in_ch == curr_in_ch:
            load_state_dict(self.blocks[curr_scale],
                            self.blocks[curr_scale - 1].state_dict(),
                            logger=get_root_logger())
            print_log('Successfully loaded pretrained model from last scale.')
        else:
            print_log(
                'Cannot load pretrained model from last scale since'
                f' prev_ch({prev_ch}) != curr_ch({curr_ch})'
                f' or prev_in_ch({prev_in_ch}) != curr_in_ch({curr_in_ch})')
Example #9
 def init_weights(self):
     logger = get_root_logger()
     if self.init_cfg is None:
         logger.warn(f'No pre-trained weights for '
                     f'{self.__class__.__name__}, '
                     f'training start from scratch')
         for m in self.modules():
             if isinstance(m, nn.Linear):
                 trunc_normal_init(m.weight, std=.02)
                 if m.bias is not None:
                     constant_init(m.bias, 0)
             elif isinstance(m, nn.LayerNorm):
                 constant_init(m.bias, 0)
                 constant_init(m.weight, 1.0)
             elif isinstance(m, nn.Conv2d):
                 fan_out = m.kernel_size[0] * m.kernel_size[
                     1] * m.out_channels
                 fan_out //= m.groups
                 normal_init(m.weight, 0, math.sqrt(2.0 / fan_out))
                 if m.bias is not None:
                     constant_init(m.bias, 0)
             elif isinstance(m, AbsolutePositionEmbedding):
                 m.init_weights()
     else:
         assert 'checkpoint' in self.init_cfg, f'Only support ' \
                                               f'specify `Pretrained` in ' \
                                               f'`init_cfg` in ' \
                                               f'{self.__class__.__name__} '
         checkpoint = _load_checkpoint(
             self.init_cfg.checkpoint, logger=logger, map_location='cpu')
         logger.warn(f'Load pre-trained model for '
                     f'{self.__class__.__name__} from original repo')
         if 'state_dict' in checkpoint:
             state_dict = checkpoint['state_dict']
         elif 'model' in checkpoint:
             state_dict = checkpoint['model']
         else:
             state_dict = checkpoint
         if self.convert_weights:
             # Because pvt backbones are not supported by mmcls,
             # so we need to convert pre-trained weights to match this
             # implementation.
             state_dict = pvt_convert(state_dict)
         load_state_dict(self, state_dict, strict=False, logger=logger)
Example #10
    def load_pretrain(self, pretrained):
        if not os.path.exists(pretrained):
            raise FileNotFoundError(f"File '{pretrained}' does not exist")
        logger.info(f"Loading pretrained model from: {pretrained}")
        state_dict = torch.load(pretrained, map_location="cpu")
        state_dict = state_dict['state_dict']

        # replace some keys
        key_replaces = [("token_learner", "token_reduction")]
        for key in list(state_dict.keys()):
            for rep in key_replaces:
                if rep[0] in key:
                    state_dict[key.replace(*rep)] = state_dict.pop(key)
        # delete some keys
        for key in ["loss.logit_scale"]:
            if key in state_dict:
                del state_dict[key]

        load_state_dict(self, state_dict, True)
Example #11
    def init_weights(self):
        """Initialize the weights in backbone."""
        logger = get_root_logger()
        if self.init_cfg is None:
            logger.warning(f'No pre-trained weights for '
                           f'{self.__class__.__name__}, '
                           f'training start from scratch')
            for m in self.modules():
                if isinstance(m, nn.Linear):
                    trunc_normal_init(m, std=.02, bias=0.)
                elif isinstance(m, nn.LayerNorm):
                    constant_init(m, 1.0)
            for bly in self.layers:
                bly._init_respostnorm()
        else:
            if self.ape:
                raise NotImplementedError
            assert 'checkpoint' in self.init_cfg, f'Only support ' \
                                                  f'specify `Pretrained` in ' \
                                                  f'`init_cfg` in ' \
                                                  f'{self.__class__.__name__} '
            ckpt = _load_checkpoint(self.init_cfg.checkpoint,
                                    logger=logger,
                                    map_location='cpu')
            if 'state_dict' in ckpt:
                state_dict = ckpt['state_dict']
            elif 'model' in ckpt:
                state_dict = ckpt['model']
            else:
                state_dict = ckpt

            # delete keys for reinitialization
            reinit_keys = ('relative_position_index', 'relative_coords_table')
            for reinit_key in reinit_keys:
                for k in list(state_dict.keys()):
                    if reinit_key in k:
                        del state_dict[k]

            load_state_dict(self, state_dict, strict=False, logger=logger)
Example #12
    def _load_checkpoint(self, checkpoint, prefix=None, map_location=None):
        from mmcv.runner import (_load_checkpoint,
                                 _load_checkpoint_with_prefix, load_state_dict)
        from mmcv.utils import print_log

        logger = get_root_logger()

        if prefix is None:
            print_log(f'load model from: {checkpoint}', logger=logger)
            checkpoint = _load_checkpoint(checkpoint, map_location, logger)
            # get state_dict from checkpoint
            if 'state_dict' in checkpoint:
                state_dict = checkpoint['state_dict']
            else:
                state_dict = checkpoint
        else:
            print_log(
                f'load {prefix} in model from: {checkpoint}', logger=logger)
            state_dict = _load_checkpoint_with_prefix(prefix, checkpoint,
                                                      map_location)

        if 'pos_embed' in state_dict.keys():
            ckpt_pos_embed_shape = state_dict['pos_embed'].shape
            if self.pos_embed.shape != ckpt_pos_embed_shape:
                print_log(
                    f'Resize the pos_embed shape from {ckpt_pos_embed_shape} '
                    f'to {self.pos_embed.shape}.',
                    logger=logger)

                ckpt_pos_embed_shape = to_2tuple(
                    int(np.sqrt(ckpt_pos_embed_shape[1] - 1)))
                pos_embed_shape = self.patch_embed.patches_resolution

                state_dict['pos_embed'] = self.resize_pos_embed(
                    state_dict['pos_embed'], ckpt_pos_embed_shape,
                    pos_embed_shape, self.interpolate_mode)

        # load state_dict
        load_state_dict(self, state_dict, strict=False, logger=logger)
Example #13
def load_checkpoint(model,
                    filename,
                    map_location=None,
                    strict=False,
                    logger=None):
    """Load checkpoint from a file or URI.
    Args:
        model (Module): Module to load checkpoint.
        filename (str): Accept local filepath, URL, ``torchvision://xxx``,
            ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for
            details.
        map_location (str): Same as :func:`torch.load`.
        strict (bool): Whether to strictly enforce that the keys of the
            checkpoint match the keys of the model.
        logger (:mod:`logging.Logger` or None): The logger for error message.
    Returns:
        dict or OrderedDict: The loaded checkpoint.
    """
    if map_location is None:
        # default to loading all tensors onto the CPU
        map_location = lambda storage, loc: storage
    checkpoint = torch.load(filename, map_location=map_location)
    # OrderedDict is a subclass of dict
    if not isinstance(checkpoint, dict):
        raise RuntimeError(
            f'No state_dict found in checkpoint file {filename}')
    # get state_dict from checkpoint
    if 'state_dict' in checkpoint:
        state_dict = checkpoint['state_dict']
    else:
        state_dict = checkpoint
    # strip prefix of state_dict
    if list(state_dict.keys())[0].startswith('module.'):
        state_dict = {k[7:]: v for k, v in state_dict.items()}

    # state_dict = {k.split('backbone.')[1]: v for k, v in state_dict.items() if k.startswith('backbone') }
    
    # load state_dict
    load_state_dict(model, state_dict, strict, logger)
    return checkpoint
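The function above can be applied to any nn.Module; a short sketch of a direct call, with a placeholder checkpoint path:

import logging

import torchvision

logger = logging.getLogger(__name__)
model = torchvision.models.resnet50()
# The model's parameters are updated in place; the full checkpoint dict is
# returned in case metadata (epoch, optimizer state, ...) is needed.
ckpt = load_checkpoint(model,
                       'checkpoints/resnet50_epoch_12.pth',
                       map_location='cpu',
                       strict=False,
                       logger=logger)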
Example #14
    def init_weights(self):
        logger = get_root_logger()
        if self.init_cfg is None:
            logger.warn(f'No pre-trained weights for '
                        f'{self.__class__.__name__}, '
                        f'training start from scratch')
            if self.use_abs_pos_embed:
                trunc_normal_(self.absolute_pos_embed, std=0.02)
            for m in self.modules():
                if isinstance(m, nn.Linear):
                    trunc_normal_init(m, std=.02, bias=0.)
                elif isinstance(m, nn.LayerNorm):
                    constant_init(m, val=1.0, bias=0.)
        else:
            assert 'checkpoint' in self.init_cfg, f'Only support ' \
                                                  f'specify `Pretrained` in ' \
                                                  f'`init_cfg` in ' \
                                                  f'{self.__class__.__name__} '
            ckpt = CheckpointLoader.load_checkpoint(
                self.init_cfg['checkpoint'], logger=logger, map_location='cpu')
            if 'state_dict' in ckpt:
                _state_dict = ckpt['state_dict']
            elif 'model' in ckpt:
                _state_dict = ckpt['model']
            else:
                _state_dict = ckpt

            state_dict = OrderedDict()
            for k, v in _state_dict.items():
                if k.startswith('backbone.'):
                    state_dict[k[9:]] = v
                else:
                    state_dict[k] = v

            # strip prefix of state_dict
            if list(state_dict.keys())[0].startswith('module.'):
                state_dict = {k[7:]: v for k, v in state_dict.items()}

            # reshape absolute position embedding
            if state_dict.get('absolute_pos_embed') is not None:
                absolute_pos_embed = state_dict['absolute_pos_embed']
                N1, L, C1 = absolute_pos_embed.size()
                N2, C2, H, W = self.absolute_pos_embed.size()
                if N1 != N2 or C1 != C2 or L != H * W:
                    logger.warning('Error in loading absolute_pos_embed, pass')
                else:
                    state_dict['absolute_pos_embed'] = absolute_pos_embed.view(
                        N2, H, W, C2).permute(0, 3, 1, 2).contiguous()

            # interpolate position bias table if needed
            relative_position_bias_table_keys = [
                k for k in state_dict.keys()
                if 'relative_position_bias_table' in k
            ]
            for table_key in relative_position_bias_table_keys:
                table_pretrained = state_dict[table_key]
                table_current = self.state_dict()[table_key]
                L1, nH1 = table_pretrained.size()
                L2, nH2 = table_current.size()
                if nH1 != nH2:
                    logger.warning(f'Error in loading {table_key}, pass')
                elif L1 != L2:
                    S1 = int(L1**0.5)
                    S2 = int(L2**0.5)
                    table_pretrained_resized = F.interpolate(
                        table_pretrained.permute(1, 0).reshape(1, nH1, S1, S1),
                        size=(S2, S2),
                        mode='bicubic')
                    state_dict[table_key] = table_pretrained_resized.view(
                        nH2, L2).permute(1, 0).contiguous()

            # load state_dict
            load_state_dict(self, state_dict, strict=False, logger=logger)
Example #15
def main():
    """Convert keys in checkpoints for VoteNet.

    There can be some breaking changes during the development of mmdetection3d,
    and this tool is used for upgrading checkpoints trained with old versions
    (before v0.6.0) to the latest one.
    """
    args = parse_args()
    checkpoint = torch.load(args.checkpoint)
    cfg = parse_config(checkpoint['meta']['config'])
    # Build the model and load checkpoint
    model = build_detector(cfg.model,
                           train_cfg=cfg.get('train_cfg'),
                           test_cfg=cfg.get('test_cfg'))
    orig_ckpt = checkpoint['state_dict']
    converted_ckpt = orig_ckpt.copy()

    if cfg['dataset_type'] == 'ScanNetDataset':
        NUM_CLASSES = 18
    elif cfg['dataset_type'] == 'SUNRGBDDataset':
        NUM_CLASSES = 10
    else:
        raise NotImplementedError

    RENAME_PREFIX = {
        'bbox_head.conv_pred.0': 'bbox_head.conv_pred.shared_convs.layer0',
        'bbox_head.conv_pred.1': 'bbox_head.conv_pred.shared_convs.layer1'
    }

    DEL_KEYS = [
        'bbox_head.conv_pred.0.bn.num_batches_tracked',
        'bbox_head.conv_pred.1.bn.num_batches_tracked'
    ]

    EXTRACT_KEYS = {
        'bbox_head.conv_pred.conv_cls.weight':
        ('bbox_head.conv_pred.conv_out.weight', [(0, 2), (-NUM_CLASSES, -1)]),
        'bbox_head.conv_pred.conv_cls.bias':
        ('bbox_head.conv_pred.conv_out.bias', [(0, 2), (-NUM_CLASSES, -1)]),
        'bbox_head.conv_pred.conv_reg.weight':
        ('bbox_head.conv_pred.conv_out.weight', [(2, -NUM_CLASSES)]),
        'bbox_head.conv_pred.conv_reg.bias':
        ('bbox_head.conv_pred.conv_out.bias', [(2, -NUM_CLASSES)])
    }

    # Delete some useless keys
    for key in DEL_KEYS:
        converted_ckpt.pop(key)

    # Rename keys with specific prefix
    RENAME_KEYS = dict()
    for old_key in converted_ckpt.keys():
        for rename_prefix in RENAME_PREFIX.keys():
            if rename_prefix in old_key:
                new_key = old_key.replace(rename_prefix,
                                          RENAME_PREFIX[rename_prefix])
                RENAME_KEYS[new_key] = old_key
    for new_key, old_key in RENAME_KEYS.items():
        converted_ckpt[new_key] = converted_ckpt.pop(old_key)

    # Extract weights and rename the keys
    for new_key, (old_key, indices) in EXTRACT_KEYS.items():
        cur_layers = orig_ckpt[old_key]
        converted_layers = []
        for (start, end) in indices:
            if end != -1:
                converted_layers.append(cur_layers[start:end])
            else:
                converted_layers.append(cur_layers[start:])
        converted_layers = torch.cat(converted_layers, 0)
        converted_ckpt[new_key] = converted_layers
        if old_key in converted_ckpt.keys():
            converted_ckpt.pop(old_key)

    # Check the converted checkpoint by loading to the model
    load_state_dict(model, converted_ckpt, strict=True)
    checkpoint['state_dict'] = converted_ckpt
    torch.save(checkpoint, args.out)
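The script relies on parse_args and parse_config helpers that are not shown. A minimal parse_args consistent with the args.checkpoint and args.out attributes used above might look like this (argument names and help texts are illustrative, not taken from the original tool):

import argparse

def parse_args():
    parser = argparse.ArgumentParser(
        description='Upgrade old VoteNet checkpoints to the current format')
    parser.add_argument('checkpoint', help='path of the checkpoint to convert')
    parser.add_argument('--out', required=True,
                        help='path to save the converted checkpoint')
    return parser.parse_args()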
Example #16
def model_from_checkpoint(
    filename: Union[Path, str],
    model_name=None,
    backbone_name=None,
    classes=None,
    is_coco=False,
    img_size=None,
    map_location=None,
    strict=False,
    revise_keys=[
        (r"^module\.", ""),
    ],
    eval_mode=True,
    logger=None,
):
    """load checkpoint through URL scheme path.
    Args:
        filename (str): checkpoint file name with given prefix
        map_location (str, optional): Same as :func:`torch.load`.
            Default: None
    Returns:
        dict or OrderedDict: The loaded checkpoint.
    """

    if isinstance(filename, Path):
        filename = str(filename)

    checkpoint = _load_checkpoint(filename=filename,
                                  map_location=map_location,
                                  logger=logger)

    if is_coco and classes:
        err_msg = "`is_coco` cannot be set to True if `classes` is passed and `not None`. `classes` has priority. `is_coco` will be ignored."
        if logger is not None:
            logger.warning(err_msg)
        else:
            print(err_msg)

    if classes is None:
        if is_coco:
            classes = CLASSES
        else:
            classes = checkpoint["meta"].get("classes", None)

    class_map = None
    if classes:
        class_map = ClassMap(classes)
        num_classes = len(class_map)

    if img_size is None:
        img_size = checkpoint["meta"].get("img_size", None)

    if model_name is None:
        model_name = checkpoint["meta"].get("model_name", None)

    model_type = None
    if model_name:
        model = model_name.split(".")
        # If model_name contains three or more components, the library and model are the second to last and last components, respectively (e.g. models.mmdet.retinanet)
        if len(model) >= 3:
            model_type = getattr(getattr(models, model[-2]), model[-1])
        # If model_name follows the default convention, the library and model are the first and second components, respectively (e.g. mmdet.retinanet)
        else:
            model_type = getattr(getattr(models, model[0]), model[1])

    if backbone_name is None:
        backbone_name = checkpoint["meta"].get("backbone_name", None)
    # keep `backbone` defined even when it cannot be resolved, so later checks are safe
    backbone = None
    if model_type and backbone_name:
        backbone = getattr(model_type.backbones, backbone_name)

    extra_args = {}
    if img_size is None:
        img_size = checkpoint["meta"].get("img_size", None)

    models_with_img_size = ("yolov5", "efficientdet")
    # if 'efficientdet' in model_name:
    if (model_name) and (any(m in model_name for m in models_with_img_size)):
        extra_args["img_size"] = img_size

    # Instantiate model
    if model_type and backbone:
        model = model_type.model(backbone=backbone(pretrained=False),
                                 num_classes=num_classes,
                                 **extra_args)
    else:
        model = None

    # OrderedDict is a subclass of dict
    if not isinstance(checkpoint, dict):
        raise RuntimeError(
            f"No state_dict found in checkpoint file {filename}")
    # get state_dict from checkpoint
    if "state_dict" in checkpoint:
        state_dict = checkpoint["state_dict"]
    else:
        state_dict = checkpoint
    # strip prefix of state_dict
    for p, r in revise_keys:
        state_dict = {re.sub(p, r, k): v for k, v in state_dict.items()}

    # load state_dict
    if model:
        load_state_dict(model, state_dict, strict, logger)
        if eval_mode:
            model.eval()

    checkpoint_and_model = {
        "model": model,
        "model_type": model_type,
        "backbone": backbone,
        "class_map": class_map,
        "img_size": img_size,
        "checkpoint": checkpoint,
    }
    return checkpoint_and_model
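A hedged usage sketch of model_from_checkpoint; the path is a placeholder, and model_name/backbone_name only need to be passed when they are not stored in the checkpoint's meta dict:

result = model_from_checkpoint('checkpoints/retinanet_r50.pth',
                               map_location='cpu')
model = result['model']            # instantiated model with loaded weights, or None
model_type = result['model_type']  # e.g. models.mmdet.retinanet (if resolvable)
class_map = result['class_map']    # ClassMap built from stored or given classes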