示例#1
0
 def __call__(self, module):
     """Load pretrained weights from ``self.checkpoint`` into ``module``.

     Two modes, selected by ``self.prefix``:

     * ``None``  -- the whole checkpoint is loaded via ``load_checkpoint``.
     * otherwise -- only the entries under ``self.prefix`` are extracted
       with ``_load_checkpoint_with_prefix`` and copied in with
       ``load_state_dict``.

     Both modes load with ``strict=False``, pass ``self.map_location`` to
     the loader, and report progress to the ``'mmcv'`` logger.

     Args:
         module: Model (or sub-module) whose parameters are populated.
     """
     # NOTE(review): imported inside the method rather than at module
     # level, presumably to avoid an import cycle with mmcv.runner.
     from mmcv.runner import (_load_checkpoint_with_prefix, load_checkpoint,
                              load_state_dict)
     logger = get_logger('mmcv')

     if self.prefix is not None:
         # Partial load: only the sub-dict under ``self.prefix``.
         print_log(f'load {self.prefix} in model from: {self.checkpoint}',
                   logger=logger)
         partial_state = _load_checkpoint_with_prefix(
             self.prefix, self.checkpoint, map_location=self.map_location)
         load_state_dict(module, partial_state, strict=False, logger=logger)
         return

     # Full load of the checkpoint into ``module``.
     print_log(f'load model from: {self.checkpoint}', logger=logger)
     load_checkpoint(
         module,
         self.checkpoint,
         map_location=self.map_location,
         strict=False,
         logger=logger)
    def _load_checkpoint(self, checkpoint, prefix=None, map_location=None):
        """Load pretrained weights into this model, resizing ``pos_embed``.

        Args:
            checkpoint: Checkpoint source handed to mmcv's loaders
                (presumably a file path or URL); it is also interpolated
                into the log message.
            prefix: If given, only the entries under this key prefix are
                extracted with ``_load_checkpoint_with_prefix``. If None,
                the whole checkpoint is read and its ``'state_dict'`` entry
                is used, or the checkpoint object itself when that key is
                absent.
            map_location: Forwarded to the mmcv loading helper.

        When the loaded ``pos_embed`` differs in shape from
        ``self.pos_embed``, it is replaced by the output of
        ``self.resize_pos_embed`` before the non-strict
        ``load_state_dict`` call.
        """
        from mmcv.runner import (_load_checkpoint,
                                 _load_checkpoint_with_prefix, load_state_dict)
        from mmcv.utils import print_log

        logger = get_root_logger()

        if prefix is not None:
            print_log(
                f'load {prefix} in model from: {checkpoint}', logger=logger)
            state_dict = _load_checkpoint_with_prefix(prefix, checkpoint,
                                                      map_location)
        else:
            print_log(f'load model from: {checkpoint}', logger=logger)
            loaded = _load_checkpoint(checkpoint, map_location, logger)
            # Weights may be wrapped under 'state_dict'; otherwise the
            # loaded object is taken to be the state dict itself.
            state_dict = (loaded['state_dict']
                          if 'state_dict' in loaded else loaded)

        if 'pos_embed' in state_dict:
            src_shape = state_dict['pos_embed'].shape
            if self.pos_embed.shape != src_shape:
                print_log(
                    f'Resize the pos_embed shape from {src_shape} '
                    f'to {self.pos_embed.shape}.',
                    logger=logger)

                # Assumes one extra (class) token in front of a square grid
                # of patch tokens -- TODO confirm for this model.
                src_grid = to_2tuple(int(np.sqrt(src_shape[1] - 1)))
                dst_grid = self.patch_embed.patches_resolution

                # NOTE(review): argument order (source grid, target grid)
                # is assumed; confirm against resize_pos_embed's signature.
                state_dict['pos_embed'] = self.resize_pos_embed(
                    state_dict['pos_embed'], src_grid, dst_grid,
                    self.interpolate_mode)

        # Non-strict: missing/unexpected keys do not raise.
        load_state_dict(self, state_dict, strict=False, logger=logger)