def __call__(self, module):
    """Load weights from ``self.checkpoint`` into ``module``.

    When ``self.prefix`` is ``None`` the checkpoint is handed to mmcv's
    ``load_checkpoint``. Otherwise the state dict is first obtained with
    ``_load_checkpoint_with_prefix`` and then applied via
    ``load_state_dict``. Both paths load with ``strict=False``.

    Args:
        module: The object the weights are loaded into.
    """
    from mmcv.runner import (_load_checkpoint_with_prefix, load_checkpoint,
                             load_state_dict)

    logger = get_logger('mmcv')
    prefix = self.prefix
    ckpt = self.checkpoint

    if prefix is not None:
        # Prefixed load: fetch the state dict through the prefix-aware
        # helper, then apply it to the module.
        print_log(f'load {prefix} in model from: {ckpt}', logger=logger)
        weights = _load_checkpoint_with_prefix(
            prefix, ckpt, map_location=self.map_location)
        load_state_dict(module, weights, strict=False, logger=logger)
    else:
        # Plain load: let mmcv read and apply the checkpoint in one call.
        print_log(f'load model from: {ckpt}', logger=logger)
        load_checkpoint(
            module,
            ckpt,
            map_location=self.map_location,
            strict=False,
            logger=logger)
def _load_checkpoint(self, checkpoint, prefix=None, map_location=None):
    """Load a checkpoint into this module, resizing ``pos_embed`` if needed.

    The state dict is read either directly (``prefix is None``) or through
    ``_load_checkpoint_with_prefix``. If it contains a ``'pos_embed'`` entry
    whose shape differs from ``self.pos_embed``, that entry is replaced by
    the result of ``self.resize_pos_embed`` before the state dict is applied
    with ``strict=False``.

    Args:
        checkpoint: Checkpoint reference forwarded to the mmcv loaders.
        prefix: If not ``None``, passed to ``_load_checkpoint_with_prefix``
            to obtain the state dict.
        map_location: Forwarded to the mmcv loaders.
    """
    from mmcv.runner import (_load_checkpoint, _load_checkpoint_with_prefix,
                             load_state_dict)
    from mmcv.utils import print_log

    logger = get_root_logger()

    if prefix is not None:
        print_log(
            f'load {prefix} in model from: {checkpoint}', logger=logger)
        state_dict = _load_checkpoint_with_prefix(prefix, checkpoint,
                                                  map_location)
    else:
        print_log(f'load model from: {checkpoint}', logger=logger)
        loaded = _load_checkpoint(checkpoint, map_location, logger)
        # The weights may be nested under 'state_dict' or be the loaded
        # object itself.
        state_dict = (
            loaded['state_dict'] if 'state_dict' in loaded else loaded)

    if 'pos_embed' in state_dict:
        saved_shape = state_dict['pos_embed'].shape
        if self.pos_embed.shape != saved_shape:
            print_log(
                f'Resize the pos_embed shape from {saved_shape} '
                f'to {self.pos_embed.shape}.',
                logger=logger)

            # NOTE(review): the "- 1" presumably discounts a single extra
            # (e.g. class) token, and the square root treats the remaining
            # entries as a square grid -- confirm against the model.
            side = int(np.sqrt(saved_shape[1] - 1))
            src_resolution = to_2tuple(side)
            dst_resolution = self.patch_embed.patches_resolution

            state_dict['pos_embed'] = self.resize_pos_embed(
                state_dict['pos_embed'], src_resolution, dst_resolution,
                self.interpolate_mode)

    # Apply the (possibly adjusted) weights to this module.
    load_state_dict(self, state_dict, strict=False, logger=logger)