def load_checkpoint(self, filename, map_location='cpu', strict=False, part=None):
    self.logger.info('load checkpoint from %s', filename)
    if not osp.isfile(filename):
        raise IOError(f'{filename} is not a checkpoint file')
    checkpoint = torch.load(filename, map_location=map_location)
    # OrderedDict is a subclass of dict
    if not isinstance(checkpoint, dict):
        raise RuntimeError(
            f'No state_dict found in checkpoint file {filename}')
    # get state_dict from checkpoint
    if 'state_dict' in checkpoint:
        state_dict = checkpoint['state_dict']
    else:
        state_dict = checkpoint
    # strip the 'module.' prefix added by (Distributed)DataParallel
    if list(state_dict.keys())[0].startswith('module.'):
        state_dict = {k[7:]: v for k, v in state_dict.items()}
    # optionally keep only parameters whose name contains `part`
    if part is not None:
        state_dict = {k: v for k, v in state_dict.items() if part in k}
    # load state_dict
    load_state_dict(self.model, state_dict, strict, self.logger)
    return checkpoint

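# A minimal, self-contained sketch of the key manipulation `load_checkpoint`
# performs, on a toy state dict (the parameter names below are illustrative,
# not from the original code): strip the 7-character 'module.' prefix, then
# filter by the `part` substring.
import torch

toy = {'module.backbone.conv.weight': torch.zeros(1),
       'module.head.fc.weight': torch.zeros(1)}
stripped = {k[7:]: v for k, v in toy.items()}  # drop 'module.'
backbone_only = {k: v for k, v in stripped.items() if 'backbone' in k}
assert list(backbone_only) == ['backbone.conv.weight']
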
def init_weights(self, pretrained=None):
    if isinstance(pretrained, str):
        logger = get_root_logger()
        revise_keys = [(r'^module\.', '')]
        checkpoint = _load_checkpoint(pretrained, map_location='cpu',
                                      logger=logger)
        # OrderedDict is a subclass of dict
        if not isinstance(checkpoint, dict):
            raise RuntimeError(
                f'No state_dict found in checkpoint file {pretrained}')
        # get state_dict from checkpoint
        if 'state_dict' in checkpoint:
            state_dict = checkpoint['state_dict']
        elif 'model' in checkpoint:  # for our model
            state_dict = checkpoint['model']
        else:
            state_dict = checkpoint
        # strip prefixes of state_dict keys
        for p, r in revise_keys:
            state_dict = {re.sub(p, r, k): v for k, v in state_dict.items()}
        # load state_dict
        load_state_dict(self, state_dict, strict=False, logger=logger)

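# Self-contained sketch of the `revise_keys` mechanism above: each
# (pattern, replacement) pair is applied to every key with re.sub. Because
# the pattern is anchored with '^', only a leading 'module.' is stripped.
# The keys below are illustrative.
import re

revise_keys = [(r'^module\.', '')]
keys = ['module.layer1.weight', 'not_module.layer2.weight']
for p, r in revise_keys:
    keys = [re.sub(p, r, k) for k in keys]
assert keys == ['layer1.weight', 'not_module.layer2.weight']
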
def init_weights(self, pretrained=None):
    """Initialize the parameters either from an existing checkpoint or from
    scratch."""
    trunc_normal_(self.pos_embed, std=.02)
    trunc_normal_(self.cls_token, std=.02)
    if pretrained:
        self.pretrained = pretrained
    if isinstance(self.pretrained, str):
        logger = get_root_logger()
        logger.info(f'load model from: {self.pretrained}')
        state_dict = _load_checkpoint(self.pretrained)
        if 'state_dict' in state_dict:
            state_dict = state_dict['state_dict']
        if self.attention_type == 'divided_space_time':
            # modify the key names of norm layers
            old_state_dict_keys = list(state_dict.keys())
            for old_key in old_state_dict_keys:
                if 'norms' in old_key:
                    new_key = old_key.replace('norms.0', 'attentions.0.norm')
                    new_key = new_key.replace('norms.1', 'ffns.0.norm')
                    state_dict[new_key] = state_dict.pop(old_key)
            # copy the parameters of space attention to time attention
            old_state_dict_keys = list(state_dict.keys())
            for old_key in old_state_dict_keys:
                if 'attentions.0' in old_key:
                    new_key = old_key.replace('attentions.0', 'attentions.1')
                    state_dict[new_key] = state_dict[old_key].clone()
        load_state_dict(self, state_dict, strict=False, logger=logger)

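# Toy sketch of the divided-space-time remapping above: 'norms.*' keys are
# renamed, then every space-attention tensor is cloned into the corresponding
# time-attention slot so both branches start from the same pretrained
# weights. Key names here are illustrative.
import torch

sd = {'layers.0.norms.0.weight': torch.ones(1),
      'layers.0.attentions.0.qkv.weight': torch.ones(2)}
for k in list(sd):
    if 'norms' in k:
        sd[k.replace('norms.0', 'attentions.0.norm')] = sd.pop(k)
for k in list(sd):
    if 'attentions.0' in k:
        sd[k.replace('attentions.0', 'attentions.1')] = sd[k].clone()
assert 'layers.0.attentions.1.qkv.weight' in sd
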
def init_from_pretrained(self, pretrained, logger):
    """Initialize model weights from a pretrained model.

    Args:
        pretrained (str): pretrained model path (something like *.pth).
        logger (logging.Logger): output logger.
    """
    logger.info(f"Loading pretrained backbone from {pretrained}")
    checkpoint = torch.load(pretrained, map_location='cpu')
    # get state_dict from checkpoint
    if isinstance(checkpoint, OrderedDict):
        state_dict = checkpoint
    elif isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
        state_dict = checkpoint['state_dict']
    else:
        raise RuntimeError(
            f'No state_dict found in checkpoint file {pretrained}')
    # strip the 'module.' prefix
    if list(state_dict.keys())[0].startswith('module.'):
        state_dict = {k[7:]: v for k, v in state_dict.items()}
    # strip the 'backbone.' prefix and keep only backbone weights
    if any(s.startswith('backbone.') for s in state_dict.keys()):
        state_dict = {
            k[9:]: v
            for k, v in state_dict.items() if k.startswith('backbone.')
        }
    load_state_dict(self, state_dict, strict=False, logger=logger)

def init_weights(self, pretrained=None):
    if isinstance(pretrained, str):
        logger = get_root_logger()
        checkpoint = _load_checkpoint(pretrained, logger=logger,
                                      map_location='cpu')
        logger.warning(f'Load pre-trained model for '
                       f'{self.__class__.__name__} from original repo')
        if 'state_dict' in checkpoint:
            state_dict = checkpoint['state_dict']
        elif 'model' in checkpoint:
            state_dict = checkpoint['model']
        else:
            state_dict = checkpoint
        if self.convert_weights:
            # We need to convert pre-trained weights to match this
            # implementation.
            state_dict = tcformer_convert(state_dict)
        load_state_dict(self, state_dict, strict=False, logger=logger)
    elif pretrained is None:
        for m in self.modules():
            if isinstance(m, nn.Linear):
                trunc_normal_init(m, std=.02, bias=0.)
            elif isinstance(m, nn.LayerNorm):
                constant_init(m, 1.0)
            elif isinstance(m, nn.Conv2d):
                fan_out = (m.kernel_size[0] * m.kernel_size[1] *
                           m.out_channels)
                fan_out //= m.groups
                normal_init(m, 0, math.sqrt(2.0 / fan_out))
    else:
        raise TypeError('pretrained must be a str or None')

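# Self-contained sketch of the Conv2d fan-out initialization used above
# (He/Kaiming, fan-out mode): std = sqrt(2 / fan_out) with
# fan_out = kH * kW * out_channels / groups. The layer sizes are illustrative.
import math
import torch.nn as nn

m = nn.Conv2d(16, 32, kernel_size=3)
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels // m.groups
nn.init.normal_(m.weight, mean=0.0, std=math.sqrt(2.0 / fan_out))
nn.init.zeros_(m.bias)
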
def __call__(self, module):
    from mmcv.runner import (_load_checkpoint_with_prefix, load_checkpoint,
                             load_state_dict)
    logger = get_logger('mmcv')
    if self.prefix is None:
        print_log(f'load model from: {self.checkpoint}', logger=logger)
        load_checkpoint(module,
                        self.checkpoint,
                        map_location=self.map_location,
                        strict=False,
                        logger=logger)
    else:
        print_log(f'load {self.prefix} in model from: {self.checkpoint}',
                  logger=logger)
        state_dict = _load_checkpoint_with_prefix(
            self.prefix, self.checkpoint, map_location=self.map_location)
        load_state_dict(module, state_dict, strict=False, logger=logger)

def init_weights(self):
    if (isinstance(self.init_cfg, dict)
            and self.init_cfg.get('type') == 'Pretrained'):
        logger = get_root_logger()
        checkpoint = CheckpointLoader.load_checkpoint(
            self.init_cfg['checkpoint'], logger=logger, map_location='cpu')
        if 'state_dict' in checkpoint:
            state_dict = checkpoint['state_dict']
        else:
            state_dict = checkpoint
        if 'pos_embed' in state_dict.keys():
            if self.pos_embed.shape != state_dict['pos_embed'].shape:
                logger.info(f'Resize the pos_embed shape from '
                            f'{state_dict["pos_embed"].shape} to '
                            f'{self.pos_embed.shape}')
                h, w = self.img_size
                pos_size = int(
                    math.sqrt(state_dict['pos_embed'].shape[1] - 1))
                state_dict['pos_embed'] = self.resize_pos_embed(
                    state_dict['pos_embed'],
                    (h // self.patch_size, w // self.patch_size),
                    (pos_size, pos_size), self.interpolate_mode)
        load_state_dict(self, state_dict, strict=False, logger=logger)
    elif self.init_cfg is not None:
        super(VisionTransformer, self).init_weights()
    else:
        # We only implement the 'jax_impl' initialization implemented at
        # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py#L353  # noqa: E501
        trunc_normal_(self.pos_embed, std=.02)
        trunc_normal_(self.cls_token, std=.02)
        for n, m in self.named_modules():
            if isinstance(m, nn.Linear):
                trunc_normal_(m.weight, std=.02)
                if m.bias is not None:
                    if 'ffn' in n:
                        nn.init.normal_(m.bias, mean=0., std=1e-6)
                    else:
                        nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Conv2d):
                kaiming_init(m, mode='fan_in', bias=0.)
            elif isinstance(m, (_BatchNorm, nn.GroupNorm, nn.LayerNorm)):
                constant_init(m, val=1.0, bias=0.)

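# A minimal sketch of the positional-embedding resize performed above,
# assuming the usual ViT layout (class token + flattened square patch grid);
# shapes are illustrative, and bicubic interpolation mirrors a common choice
# for `interpolate_mode`. The internals of `resize_pos_embed` may differ.
import math
import torch
import torch.nn.functional as F

pos_embed = torch.randn(1, 1 + 14 * 14, 768)      # checkpoint: 14x14 grid
cls_tok, grid = pos_embed[:, :1], pos_embed[:, 1:]
src = int(math.sqrt(grid.shape[1]))
grid = grid.reshape(1, src, src, -1).permute(0, 3, 1, 2)
grid = F.interpolate(grid, size=(24, 24), mode='bicubic', align_corners=False)
grid = grid.permute(0, 2, 3, 1).reshape(1, 24 * 24, -1)
resized = torch.cat([cls_tok, grid], dim=1)       # matches a 24x24 model grid
assert resized.shape == (1, 1 + 24 * 24, 768)
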
def check_and_load_prev_weight(self, curr_scale):
    if curr_scale == 0:
        return
    prev_ch = self.blocks[curr_scale - 1].base_channels
    curr_ch = self.blocks[curr_scale].base_channels
    prev_in_ch = self.blocks[curr_scale - 1].in_channels
    curr_in_ch = self.blocks[curr_scale].in_channels
    if prev_ch == curr_ch and prev_in_ch == curr_in_ch:
        load_state_dict(self.blocks[curr_scale],
                        self.blocks[curr_scale - 1].state_dict(),
                        logger=get_root_logger())
        print_log('Successfully loaded pretrained model from last scale.')
    else:
        print_log(
            'Cannot load pretrained model from last scale since'
            f' prev_ch({prev_ch}) != curr_ch({curr_ch})'
            f' or prev_in_ch({prev_in_ch}) != curr_in_ch({curr_in_ch})')

def init_weights(self):
    logger = get_root_logger()
    if self.init_cfg is None:
        logger.warning(f'No pre-trained weights for '
                       f'{self.__class__.__name__}, '
                       f'training starts from scratch')
        for m in self.modules():
            if isinstance(m, nn.Linear):
                trunc_normal_init(m.weight, std=.02)
                if m.bias is not None:
                    constant_init(m.bias, 0)
            elif isinstance(m, nn.LayerNorm):
                constant_init(m.bias, 0)
                constant_init(m.weight, 1.0)
            elif isinstance(m, nn.Conv2d):
                fan_out = (m.kernel_size[0] * m.kernel_size[1] *
                           m.out_channels)
                fan_out //= m.groups
                normal_init(m.weight, 0, math.sqrt(2.0 / fan_out))
                if m.bias is not None:
                    constant_init(m.bias, 0)
            elif isinstance(m, AbsolutePositionEmbedding):
                m.init_weights()
    else:
        assert 'checkpoint' in self.init_cfg, (
            f'Only support specify `Pretrained` in `init_cfg` in '
            f'{self.__class__.__name__}')
        checkpoint = _load_checkpoint(
            self.init_cfg.checkpoint, logger=logger, map_location='cpu')
        logger.warning(f'Load pre-trained model for '
                       f'{self.__class__.__name__} from original repo')
        if 'state_dict' in checkpoint:
            state_dict = checkpoint['state_dict']
        elif 'model' in checkpoint:
            state_dict = checkpoint['model']
        else:
            state_dict = checkpoint
        if self.convert_weights:
            # pvt backbones are not supported by mmcls, so we need to
            # convert pre-trained weights to match this implementation
            state_dict = pvt_convert(state_dict)
        load_state_dict(self, state_dict, strict=False, logger=logger)

def load_pretrain(self, pretrained):
    if not os.path.exists(pretrained):
        raise FileNotFoundError(f"File '{pretrained}' does not exist")
    logger = get_root_logger()
    logger.info(f"Loading pretrained model from: {pretrained}")
    state_dict = torch.load(pretrained, map_location="cpu")
    state_dict = state_dict['state_dict']
    # replace some keys
    key_replaces = [("token_learner", "token_reduction")]
    for key in list(state_dict.keys()):
        for rep in key_replaces:
            if rep[0] in key:
                state_dict[key.replace(*rep)] = state_dict.pop(key)
    # delete some keys
    for key in ["loss.logit_scale"]:
        if key in state_dict:
            del state_dict[key]
    load_state_dict(self, state_dict, strict=True)

def init_weights(self):
    """Initialize the weights in backbone."""
    logger = get_root_logger()
    if self.init_cfg is None:
        logger.warning(f'No pre-trained weights for '
                       f'{self.__class__.__name__}, '
                       f'training starts from scratch')
        for m in self.modules():
            if isinstance(m, nn.Linear):
                trunc_normal_init(m, std=.02, bias=0.)
            elif isinstance(m, nn.LayerNorm):
                constant_init(m, 1.0)
        for bly in self.layers:
            bly._init_respostnorm()
    else:
        if self.ape:
            raise NotImplementedError
        assert 'checkpoint' in self.init_cfg, (
            f'Only support specify `Pretrained` in `init_cfg` in '
            f'{self.__class__.__name__}')
        ckpt = _load_checkpoint(self.init_cfg.checkpoint,
                                logger=logger,
                                map_location='cpu')
        if 'state_dict' in ckpt:
            state_dict = ckpt['state_dict']
        elif 'model' in ckpt:
            state_dict = ckpt['model']
        else:
            state_dict = ckpt
        # delete keys that will be reinitialized
        reinit_keys = ('relative_position_index', 'relative_coords_table')
        for reinit_key in reinit_keys:
            for k in list(state_dict.keys()):
                if reinit_key in k:
                    del state_dict[k]
        load_state_dict(self, state_dict, strict=False, logger=logger)

def _load_checkpoint(self, checkpoint, prefix=None, map_location=None):
    from mmcv.runner import (_load_checkpoint, _load_checkpoint_with_prefix,
                             load_state_dict)
    from mmcv.utils import print_log
    logger = get_root_logger()
    if prefix is None:
        print_log(f'load model from: {checkpoint}', logger=logger)
        checkpoint = _load_checkpoint(checkpoint, map_location, logger)
        # get state_dict from checkpoint
        if 'state_dict' in checkpoint:
            state_dict = checkpoint['state_dict']
        else:
            state_dict = checkpoint
    else:
        print_log(f'load {prefix} in model from: {checkpoint}',
                  logger=logger)
        state_dict = _load_checkpoint_with_prefix(prefix, checkpoint,
                                                  map_location)
    if 'pos_embed' in state_dict.keys():
        ckpt_pos_embed_shape = state_dict['pos_embed'].shape
        if self.pos_embed.shape != ckpt_pos_embed_shape:
            print_log(
                f'Resize the pos_embed shape from {ckpt_pos_embed_shape} '
                f'to {self.pos_embed.shape}.', logger=logger)
            ckpt_pos_embed_shape = to_2tuple(
                int(np.sqrt(ckpt_pos_embed_shape[1] - 1)))
            pos_embed_shape = self.patch_embed.patches_resolution
            state_dict['pos_embed'] = self.resize_pos_embed(
                state_dict['pos_embed'], ckpt_pos_embed_shape,
                pos_embed_shape, self.interpolate_mode)
    # load state_dict
    load_state_dict(self, state_dict, strict=False, logger=logger)

def load_checkpoint(model, filename, map_location=None, strict=False,
                    logger=None):
    """Load checkpoint from a file or URI.

    Args:
        model (Module): Module to load checkpoint.
        filename (str): Accept local filepath, URL, ``torchvision://xxx``,
            ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for
            details.
        map_location (str): Same as :func:`torch.load`.
        strict (bool): Whether to allow different params for the model and
            checkpoint.
        logger (:mod:`logging.Logger` or None): The logger for error message.

    Returns:
        dict or OrderedDict: The loaded checkpoint.
    """
    if map_location is None:
        # by default, load all tensors onto the CPU
        map_location = lambda storage, loc: storage
    checkpoint = torch.load(filename, map_location=map_location)
    # OrderedDict is a subclass of dict
    if not isinstance(checkpoint, dict):
        raise RuntimeError(
            f'No state_dict found in checkpoint file {filename}')
    # get state_dict from checkpoint
    if 'state_dict' in checkpoint:
        state_dict = checkpoint['state_dict']
    else:
        state_dict = checkpoint
    # strip the 'module.' prefix
    if list(state_dict.keys())[0].startswith('module.'):
        state_dict = {k[7:]: v for k, v in state_dict.items()}
    # load state_dict
    load_state_dict(model, state_dict, strict, logger)
    return checkpoint

def init_weights(self):
    logger = get_root_logger()
    if self.init_cfg is None:
        logger.warning(f'No pre-trained weights for '
                       f'{self.__class__.__name__}, '
                       f'training starts from scratch')
        if self.use_abs_pos_embed:
            trunc_normal_(self.absolute_pos_embed, std=0.02)
        for m in self.modules():
            if isinstance(m, nn.Linear):
                trunc_normal_init(m, std=.02, bias=0.)
            elif isinstance(m, nn.LayerNorm):
                constant_init(m, val=1.0, bias=0.)
    else:
        assert 'checkpoint' in self.init_cfg, (
            f'Only support specify `Pretrained` in `init_cfg` in '
            f'{self.__class__.__name__}')
        ckpt = CheckpointLoader.load_checkpoint(
            self.init_cfg['checkpoint'], logger=logger, map_location='cpu')
        if 'state_dict' in ckpt:
            _state_dict = ckpt['state_dict']
        elif 'model' in ckpt:
            _state_dict = ckpt['model']
        else:
            _state_dict = ckpt
        # strip the 'backbone.' prefix
        state_dict = OrderedDict()
        for k, v in _state_dict.items():
            if k.startswith('backbone.'):
                state_dict[k[9:]] = v
            else:
                state_dict[k] = v
        # strip the 'module.' prefix
        if list(state_dict.keys())[0].startswith('module.'):
            state_dict = {k[7:]: v for k, v in state_dict.items()}
        # reshape absolute position embedding
        if state_dict.get('absolute_pos_embed') is not None:
            absolute_pos_embed = state_dict['absolute_pos_embed']
            N1, L, C1 = absolute_pos_embed.size()
            N2, C2, H, W = self.absolute_pos_embed.size()
            if N1 != N2 or C1 != C2 or L != H * W:
                logger.warning('Error in loading absolute_pos_embed, pass')
            else:
                state_dict['absolute_pos_embed'] = absolute_pos_embed.view(
                    N2, H, W, C2).permute(0, 3, 1, 2).contiguous()
        # interpolate position bias table if needed
        relative_position_bias_table_keys = [
            k for k in state_dict.keys()
            if 'relative_position_bias_table' in k
        ]
        for table_key in relative_position_bias_table_keys:
            table_pretrained = state_dict[table_key]
            table_current = self.state_dict()[table_key]
            L1, nH1 = table_pretrained.size()
            L2, nH2 = table_current.size()
            if nH1 != nH2:
                logger.warning(f'Error in loading {table_key}, pass')
            elif L1 != L2:
                S1 = int(L1**0.5)
                S2 = int(L2**0.5)
                table_pretrained_resized = F.interpolate(
                    table_pretrained.permute(1, 0).reshape(1, nH1, S1, S1),
                    size=(S2, S2),
                    mode='bicubic')
                state_dict[table_key] = table_pretrained_resized.view(
                    nH2, L2).permute(1, 0).contiguous()
        # load state_dict
        load_state_dict(self, state_dict, strict=False, logger=logger)

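# Self-contained sketch of the relative-position-bias-table interpolation
# above: an (L1, nH) table is treated as one S1 x S1 grid per attention head
# and resized to S2 x S2 with bicubic interpolation. The sizes below are
# illustrative (e.g. window 7 -> 13x13 table, window 12 -> 23x23 table).
import torch
import torch.nn.functional as F

S1, S2, nH = 13, 23, 4
table = torch.randn(S1 * S1, nH)
resized = F.interpolate(
    table.permute(1, 0).reshape(1, nH, S1, S1),
    size=(S2, S2), mode='bicubic')
table_new = resized.view(nH, S2 * S2).permute(1, 0).contiguous()
assert table_new.shape == (S2 * S2, nH)
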
def main():
    """Convert keys in checkpoints for VoteNet.

    There can be some breaking changes during the development of
    mmdetection3d, and this tool is used for upgrading checkpoints trained
    with old versions (before v0.6.0) to the latest one.
    """
    args = parse_args()
    checkpoint = torch.load(args.checkpoint)
    cfg = parse_config(checkpoint['meta']['config'])
    # Build the model and load checkpoint
    model = build_detector(cfg.model,
                           train_cfg=cfg.get('train_cfg'),
                           test_cfg=cfg.get('test_cfg'))
    orig_ckpt = checkpoint['state_dict']
    converted_ckpt = orig_ckpt.copy()

    if cfg['dataset_type'] == 'ScanNetDataset':
        NUM_CLASSES = 18
    elif cfg['dataset_type'] == 'SUNRGBDDataset':
        NUM_CLASSES = 10
    else:
        raise NotImplementedError

    RENAME_PREFIX = {
        'bbox_head.conv_pred.0': 'bbox_head.conv_pred.shared_convs.layer0',
        'bbox_head.conv_pred.1': 'bbox_head.conv_pred.shared_convs.layer1'
    }
    DEL_KEYS = [
        'bbox_head.conv_pred.0.bn.num_batches_tracked',
        'bbox_head.conv_pred.1.bn.num_batches_tracked'
    ]
    EXTRACT_KEYS = {
        'bbox_head.conv_pred.conv_cls.weight':
        ('bbox_head.conv_pred.conv_out.weight', [(0, 2), (-NUM_CLASSES, -1)]),
        'bbox_head.conv_pred.conv_cls.bias':
        ('bbox_head.conv_pred.conv_out.bias', [(0, 2), (-NUM_CLASSES, -1)]),
        'bbox_head.conv_pred.conv_reg.weight':
        ('bbox_head.conv_pred.conv_out.weight', [(2, -NUM_CLASSES)]),
        'bbox_head.conv_pred.conv_reg.bias':
        ('bbox_head.conv_pred.conv_out.bias', [(2, -NUM_CLASSES)])
    }

    # Delete some useless keys
    for key in DEL_KEYS:
        converted_ckpt.pop(key)

    # Rename keys with specific prefix
    RENAME_KEYS = dict()
    for old_key in converted_ckpt.keys():
        for rename_prefix in RENAME_PREFIX.keys():
            if rename_prefix in old_key:
                new_key = old_key.replace(rename_prefix,
                                          RENAME_PREFIX[rename_prefix])
                RENAME_KEYS[new_key] = old_key
    for new_key, old_key in RENAME_KEYS.items():
        converted_ckpt[new_key] = converted_ckpt.pop(old_key)

    # Extract weights and rename the keys
    for new_key, (old_key, indices) in EXTRACT_KEYS.items():
        cur_layers = orig_ckpt[old_key]
        converted_layers = []
        for (start, end) in indices:
            if end != -1:
                converted_layers.append(cur_layers[start:end])
            else:
                converted_layers.append(cur_layers[start:])
        converted_layers = torch.cat(converted_layers, 0)
        converted_ckpt[new_key] = converted_layers
        if old_key in converted_ckpt.keys():
            converted_ckpt.pop(old_key)

    # Check the converted checkpoint by loading to the model
    load_state_dict(model, converted_ckpt, strict=True)
    checkpoint['state_dict'] = converted_ckpt
    torch.save(checkpoint, args.out)

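# Toy sketch of the EXTRACT_KEYS slicing above: each (start, end) pair
# selects a row range of a fused output tensor, with end == -1 meaning
# "to the end"; the selected pieces are concatenated into the new parameter.
# The fused layout (2 objectness rows, regression rows, NUM_CLASSES class
# rows) mirrors the conversion, with illustrative sizes.
import torch

NUM_CLASSES = 18
fused = torch.arange(2 + 6 + NUM_CLASSES).float()  # conv_out.bias stand-in
indices = [(0, 2), (-NUM_CLASSES, -1)]             # objectness + class rows
parts = [fused[s:e] if e != -1 else fused[s:] for s, e in indices]
conv_cls_bias = torch.cat(parts, 0)
assert conv_cls_bias.shape[0] == 2 + NUM_CLASSES
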
def model_from_checkpoint(
    filename: Union[Path, str],
    model_name=None,
    backbone_name=None,
    classes=None,
    is_coco=False,
    img_size=None,
    map_location=None,
    strict=False,
    revise_keys=[
        (r"^module\.", ""),
    ],
    eval_mode=True,
    logger=None,
):
    """Load a model and its weights from a local checkpoint file or URL.

    Args:
        filename (str): checkpoint file name, local path or URL.
        map_location (str, optional): Same as :func:`torch.load`.
            Default: None.

    Returns:
        dict: The loaded model, its metadata and the raw checkpoint.
    """
    if isinstance(filename, Path):
        filename = str(filename)
    checkpoint = _load_checkpoint(filename=filename,
                                  map_location=map_location,
                                  logger=logger)

    if is_coco and classes:
        err_msg = ("`is_coco` cannot be set to True if `classes` is passed "
                   "and not None. `classes` has priority; `is_coco` will "
                   "be ignored.")
        if logger is not None:
            logger.warning(err_msg)
        else:
            print(err_msg)

    if classes is None:
        if is_coco:
            classes = CLASSES
        else:
            classes = checkpoint["meta"].get("classes", None)

    class_map = None
    if classes:
        class_map = ClassMap(classes)
    num_classes = len(class_map) if class_map is not None else None

    if img_size is None:
        img_size = checkpoint["meta"].get("img_size", None)

    if model_name is None:
        model_name = checkpoint["meta"].get("model_name", None)
    model_type = None
    if model_name:
        name_parts = model_name.split(".")
        if len(name_parts) >= 3:
            # If model_name contains three or more components, the library
            # and model are the second to last and last components,
            # respectively (e.g. models.mmdet.retinanet)
            model_type = getattr(getattr(models, name_parts[-2]),
                                 name_parts[-1])
        else:
            # If model_name follows the default convention, the library and
            # model are the first and second components, respectively
            # (e.g. mmdet.retinanet)
            model_type = getattr(getattr(models, name_parts[0]),
                                 name_parts[1])

    if backbone_name is None:
        backbone_name = checkpoint["meta"].get("backbone_name", None)
    backbone = None
    if model_type and backbone_name:
        backbone = getattr(model_type.backbones, backbone_name)

    extra_args = {}
    models_with_img_size = ("yolov5", "efficientdet")
    if model_name and any(m in model_name for m in models_with_img_size):
        extra_args["img_size"] = img_size

    # Instantiate model
    if model_type and backbone:
        model = model_type.model(backbone=backbone(pretrained=False),
                                 num_classes=num_classes,
                                 **extra_args)
    else:
        model = None

    # OrderedDict is a subclass of dict
    if not isinstance(checkpoint, dict):
        raise RuntimeError(
            f"No state_dict found in checkpoint file {filename}")
    # get state_dict from checkpoint
    if "state_dict" in checkpoint:
        state_dict = checkpoint["state_dict"]
    else:
        state_dict = checkpoint
    # strip prefixes of state_dict keys
    for p, r in revise_keys:
        state_dict = {re.sub(p, r, k): v for k, v in state_dict.items()}
    # load state_dict
    if model:
        load_state_dict(model, state_dict, strict, logger)
        if eval_mode:
            model.eval()

    checkpoint_and_model = {
        "model": model,
        "model_type": model_type,
        "backbone": backbone,
        "class_map": class_map,
        "img_size": img_size,
        "checkpoint": checkpoint,
    }
    return checkpoint_and_model

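# Toy sketch of the `model_name` resolution above: dotted names resolve to a
# library attribute and a model attribute on a `models` package.
# SimpleNamespace objects stand in for the real package here; the names are
# illustrative.
from types import SimpleNamespace

models = SimpleNamespace(mmdet=SimpleNamespace(retinanet="retinanet_type"))
for model_name in ("mmdet.retinanet", "models.mmdet.retinanet"):
    parts = model_name.split(".")
    lib, name = (parts[-2], parts[-1]) if len(parts) >= 3 else (parts[0],
                                                                parts[1])
    assert getattr(getattr(models, lib), name) == "retinanet_type"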