def init_weights(self, pretrained=None):
    """Load pretrained weights from a checkpoint, if a path is given.

    Nothing is done unless ``pretrained`` is a string. The checkpoint is
    read onto the CPU, the weights are taken from its ``'state_dict'`` or
    ``'model'`` entry when present (otherwise the checkpoint itself is used),
    a leading ``module.`` prefix is stripped from every key, and the result
    is loaded non-strictly.

    Args:
        pretrained (str, optional): Checkpoint location handed to
            ``_load_checkpoint``. Defaults to None.

    Raises:
        RuntimeError: If the loaded checkpoint is not a dict.
    """
    if not isinstance(pretrained, str):
        return

    logger = get_root_logger()
    prefix_rules = [(r'^module\.', '')]
    checkpoint = _load_checkpoint(
        pretrained, map_location='cpu', logger=logger)

    # OrderedDict is a subclass of dict
    if not isinstance(checkpoint, dict):
        raise RuntimeError(
            f'No state_dict found in checkpoint file {pretrained}')

    # the weights may be nested under 'state_dict' or 'model'
    state_dict = checkpoint
    for container_key in ('state_dict', 'model'):
        if container_key in checkpoint:
            state_dict = checkpoint[container_key]
            break

    # strip prefix of state_dict
    for pattern, replacement in prefix_rules:
        state_dict = {
            re.sub(pattern, replacement, key): value
            for key, value in state_dict.items()
        }

    load_state_dict(self, state_dict, strict=False, logger=logger)
def init_weights(self):
    """Initialize the weights from scratch or from a pretrained checkpoint.

    When ``self.pretrained`` is ``None``, every ``nn.Linear``,
    ``nn.LayerNorm`` and ``nn.Conv2d`` submodule is re-initialized. When it
    is a string, it is passed to ``_load_checkpoint`` (mapped to the CPU) and
    the weights are loaded non-strictly. Any other value leaves the weights
    untouched.
    """
    if self.pretrained is None:
        for m in self.modules():
            # The init helpers receive the module itself, not ``m.weight`` /
            # ``m.bias``: they look up ``.weight`` and ``.bias`` on their
            # argument (assuming these are the mmcv helpers -- confirm
            # against this file's imports), so a bare tensor would be left
            # uninitialized without any error.
            if isinstance(m, nn.Linear):
                trunc_normal_init(m, std=.02, bias=0.)
            elif isinstance(m, nn.LayerNorm):
                constant_init(m, val=1.0, bias=0.)
            elif isinstance(m, nn.Conv2d):
                fan_out = (
                    m.kernel_size[0] * m.kernel_size[1] * m.out_channels)
                fan_out //= m.groups
                normal_init(
                    m, mean=0, std=math.sqrt(2.0 / fan_out), bias=0)
    elif isinstance(self.pretrained, str):
        logger = get_root_logger()
        checkpoint = _load_checkpoint(
            self.pretrained, logger=logger, map_location='cpu')
        if 'state_dict' in checkpoint:
            state_dict = checkpoint['state_dict']
        else:
            state_dict = checkpoint

        # second positional argument is ``strict``
        self.load_state_dict(state_dict, False)
def _load_torchvision_checkpoint(self, logger=None):
    """Initiate the parameters from torchvision pretrained checkpoint.

    The checkpoint at ``self.pretrained`` is read with ``_load_checkpoint``
    and, for every ``ConvModule`` submodule, its conv and bn parameters are
    loaded via ``self._load_conv_params`` / ``self._load_bn_params`` using
    the corresponding torchvision key names.

    Args:
        logger (logging.Logger, optional): Logger used to report checkpoint
            entries that were not loaded. When ``None``, a standard-library
            module logger is used instead. Defaults to None.
    """
    state_dict_torchvision = _load_checkpoint(self.pretrained)
    if 'state_dict' in state_dict_torchvision:
        state_dict_torchvision = state_dict_torchvision['state_dict']

    loaded_param_names = []
    for name, module in self.named_modules():
        if isinstance(module, ConvModule):
            # we use a ConvModule to wrap conv+bn+relu layers, thus the
            # name mapping is needed
            if 'downsample' in name:
                # layer{X}.{Y}.downsample.conv->layer{X}.{Y}.downsample.0
                original_conv_name = name + '.0'
                # layer{X}.{Y}.downsample.bn->layer{X}.{Y}.downsample.1
                original_bn_name = name + '.1'
            else:
                # layer{X}.{Y}.conv{n}.conv->layer{X}.{Y}.conv{n}
                original_conv_name = name
                # layer{X}.{Y}.conv{n}.bn->layer{X}.{Y}.bn{n}
                original_bn_name = name.replace('conv', 'bn')
            self._load_conv_params(module.conv, state_dict_torchvision,
                                   original_conv_name, loaded_param_names)
            self._load_bn_params(module.bn, state_dict_torchvision,
                                 original_bn_name, loaded_param_names)

    # check if any parameters in the 2d checkpoint are not loaded
    remaining_names = set(
        state_dict_torchvision.keys()) - set(loaded_param_names)
    if remaining_names:
        if logger is None:
            # ``logger`` is optional in the signature; without a fallback
            # the report below would raise AttributeError on ``None``.
            import logging
            logger = logging.getLogger(__name__)
        logger.info(
            f'These parameters in pretrained checkpoint are not loaded'
            f': {remaining_names}')
def init_weights(self, pretrained=None):
    """Load pretrained weights named by ``init_cfg`` or ``pretrained``.

    If neither ``self.init_cfg`` nor ``pretrained`` is given, a warning is
    logged and the weights are left as they are. Otherwise the checkpoint
    path is taken from ``self.init_cfg['checkpoint']`` when ``init_cfg`` is
    set, else from ``pretrained``; the weights are read from the
    checkpoint's ``'state_dict'`` or ``'model'`` entry (or the checkpoint
    itself) and loaded non-strictly.

    Args:
        pretrained (str, optional): Checkpoint location used when
            ``self.init_cfg`` is ``None``. Defaults to None.

    Raises:
        AssertionError: If ``self.init_cfg`` is set but has no
            ``'checkpoint'`` entry.
    """
    logger = get_root_logger()
    if self.init_cfg is None and pretrained is None:
        logger.warning(f'No pre-trained weights for '
                       f'{self.__class__.__name__}, '
                       f'training start from scratch')
        return

    if self.init_cfg is not None:
        # Only validate ``init_cfg`` when it exists; testing membership on
        # ``None`` would raise TypeError and make the ``pretrained``
        # fallback below unreachable.
        assert 'checkpoint' in self.init_cfg, f'Only support ' \
                                              f'specify `Pretrained` in ' \
                                              f'`init_cfg` in ' \
                                              f'{self.__class__.__name__} '
        ckpt_path = self.init_cfg['checkpoint']
    else:
        ckpt_path = pretrained

    ckpt = _load_checkpoint(ckpt_path, logger=logger, map_location='cpu')
    if 'state_dict' in ckpt:
        state_dict = ckpt['state_dict']
    elif 'model' in ckpt:
        state_dict = ckpt['model']
    else:
        state_dict = ckpt

    # second positional argument is ``strict``
    self.load_state_dict(state_dict, False)
def init_weights(self):
    """Initialize the backbone weights.

    With no ``init_cfg`` a warning is logged and ``self._init_weights`` is
    applied to every submodule. Otherwise the checkpoint named by
    ``self.init_cfg.checkpoint`` is read onto the CPU and loaded
    non-strictly; the missing and unexpected keys are logged as warnings.

    Raises:
        AssertionError: If ``init_cfg`` is set but has no ``'checkpoint'``
            entry.
    """
    logger = get_root_logger()
    cls_name = self.__class__.__name__

    if self.init_cfg is None:
        logger.warning(
            f'No pre-trained weights for {cls_name}, '
            f'training start from scratch')
        self.apply(self._init_weights)
        return

    assert 'checkpoint' in self.init_cfg, (
        f'Only support specify `Pretrained` in `init_cfg` in {cls_name} ')

    loaded = _load_checkpoint(
        self.init_cfg.checkpoint, logger=logger, map_location='cpu')

    # the weights may be nested under 'state_dict' or 'model'
    weights = loaded
    for container_key in ('state_dict', 'model'):
        if container_key in loaded:
            weights = loaded[container_key]
            break

    # second positional argument is ``strict``
    missing, unexpected = self.load_state_dict(weights, False)
    logger.warning(f'missing_keys: {missing}')
    logger.warning(f'unexpected_keys: {unexpected}')
def init_weights(self, pretrained=None):
    """Initiate the parameters either from existing checkpoint or from
    scratch.

    ``pos_embed`` and ``cls_token`` are always re-drawn first. A truthy
    ``pretrained`` overrides ``self.pretrained``. When ``self.pretrained``
    is a string, the checkpoint is loaded non-strictly; for the
    ``'divided_space_time'`` attention type the norm keys are renamed and
    the ``attentions.0`` entries are cloned to ``attentions.1`` beforehand.

    Args:
        pretrained (str, optional): Checkpoint location that replaces
            ``self.pretrained`` when truthy. Defaults to None.
    """
    for param in (self.pos_embed, self.cls_token):
        trunc_normal_(param, std=.02)

    if pretrained:
        self.pretrained = pretrained
    if not isinstance(self.pretrained, str):
        return

    logger = get_root_logger()
    logger.info(f'load model from: {self.pretrained}')

    state_dict = _load_checkpoint(self.pretrained)
    if 'state_dict' in state_dict:
        state_dict = state_dict['state_dict']

    if self.attention_type == 'divided_space_time':
        # modify the key names of norm layers; iterate over a snapshot of
        # the keys because entries are popped and re-inserted
        for key in list(state_dict.keys()):
            if 'norms' not in key:
                continue
            renamed = key.replace('norms.0', 'attentions.0.norm')
            renamed = renamed.replace('norms.1', 'ffns.0.norm')
            state_dict[renamed] = state_dict.pop(key)

        # copy the parameters of space attention to time attention
        for key in list(state_dict.keys()):
            if 'attentions.0' in key:
                time_key = key.replace('attentions.0', 'attentions.1')
                state_dict[time_key] = state_dict[key].clone()

    load_state_dict(self, state_dict, strict=False, logger=logger)
def _load_file(self, filename):
    """Read a checkpoint file and return it as a dict with a ``"model"`` key.

    ``.pkl`` files are unpickled: if they carry both ``"model"`` and
    ``"__author__"`` they are returned unchanged, otherwise they are treated
    as Caffe2 / Detectron1 weights. Names starting with ``torchvision://``,
    ``http://`` or ``https://`` go through ``_load_checkpoint``; everything
    else is delegated to the parent class. Files whose lower-cased base name
    contains ``"lpf"`` or ``"dla"`` get ``"matching_heuristics"`` set.

    Args:
        filename (str): Path or URL of the checkpoint.

    Returns:
        dict: The loaded checkpoint.
    """
    if filename.endswith(".pkl"):
        # NOTE(review): unpickling can execute arbitrary code; only load
        # .pkl files from trusted sources.
        with PathManager.open(filename, "rb") as f:
            data = pickle.load(f, encoding="latin1")

        if "model" in data and "__author__" in data:
            # file is in Detectron2 model zoo format
            self.logger.info("Reading a file from '{}'".format(data["__author__"]))
            return data

        # assume file is from Caffe2 / Detectron1 model zoo
        # Detection models have "blobs", but ImageNet models don't
        blobs = data["blobs"] if "blobs" in data else data
        weights = {
            key: value
            for key, value in blobs.items()
            if not key.endswith("_momentum")
        }
        weights.pop("weight_order", None)
        return {
            "model": weights,
            "__author__": "Caffe2",
            "matching_heuristics": True,
        }

    remote_prefixes = ("torchvision://", "http://", "https://")
    if filename.startswith(remote_prefixes):
        # load torchvision pretrained model using mmcv
        loaded = _load_checkpoint(filename)
    else:
        # load native pth checkpoint
        loaded = super()._load_file(filename)

    if "model" not in loaded:
        loaded = {"model": loaded}

    basename = os.path.basename(filename).lower()
    if any(tag in basename for tag in ("lpf", "dla")):
        loaded["matching_heuristics"] = True
    return loaded
def init_weights(self, pretrained=None):
    """Initialize the weights from scratch or from a checkpoint.

    Args:
        pretrained (str, optional): Checkpoint location. When ``None``, the
            ``nn.Linear``, ``nn.LayerNorm`` and ``nn.Conv2d`` submodules are
            re-initialized instead. Defaults to None.

    Raises:
        TypeError: If ``pretrained`` is neither a str nor None.
    """
    if pretrained is None:
        for m in self.modules():
            if isinstance(m, nn.Linear):
                trunc_normal_init(m, std=.02, bias=0.)
            elif isinstance(m, nn.LayerNorm):
                constant_init(m, 1.0)
            elif isinstance(m, nn.Conv2d):
                kernel_h, kernel_w = m.kernel_size[0], m.kernel_size[1]
                fan_out = kernel_h * kernel_w * m.out_channels
                fan_out //= m.groups
                normal_init(m, 0, math.sqrt(2.0 / fan_out))
        return

    if not isinstance(pretrained, str):
        raise TypeError('pretrained must be a str or None')

    logger = get_root_logger()
    checkpoint = _load_checkpoint(
        pretrained, logger=logger, map_location='cpu')
    logger.warning(
        f'Load pre-trained model for {self.__class__.__name__} '
        f'from original repo')

    # the weights may be nested under 'state_dict' or 'model'
    state_dict = checkpoint
    for container_key in ('state_dict', 'model'):
        if container_key in checkpoint:
            state_dict = checkpoint[container_key]
            break

    if self.convert_weights:
        # We need to convert pre-trained weights to match this
        # implementation.
        state_dict = tcformer_convert(state_dict)
    load_state_dict(self, state_dict, strict=False, logger=logger)
def _check_backbone(config, print_cfg=True):
    """Check whether the backbone loads its pretrained checkpoint, using
    ``backbone.init_cfg``.

    The checkpoint is first read with ``_load_checkpoint`` without building
    a model. The detector is then built with ``build_detector`` and
    initialized with ``model.init_weights()``. Finally, every backbone
    tensor whose name also appears in the checkpoint is asserted to be equal
    to the checkpoint tensor.

    Args:
        config (str): Config file path.
        print_cfg (bool): Whether to print progress messages. It does not
            affect the return value.

    Returns:
        str or None: ``None`` if the config declares a ``Pretrained``
            backbone ``init_cfg`` and the check passed; otherwise the config
            file path.

    Raises:
        AssertionError: If a loaded backbone tensor differs from the
            checkpoint tensor of the same name.
    """
    if print_cfg:
        print('-' * 15 + 'loading ', config)
    cfg = Config.fromfile(config)

    try:
        init_cfg = cfg.model.backbone.init_cfg
    except AttributeError:
        init_cfg = None

    if init_cfg is None or init_cfg.get('type') != 'Pretrained':
        if print_cfg:
            print(config + '\n' + '-' * 10 +
                  'config file do not have init_cfg' + '-' * 10 + '\n')
        # Returned regardless of ``print_cfg`` so callers collecting
        # results get the documented value even in quiet mode.
        return config

    checkpoint = _load_checkpoint(init_cfg.checkpoint)
    if 'state_dict' in checkpoint:
        state_dict = checkpoint['state_dict']
    else:
        state_dict = checkpoint

    model = build_detector(
        cfg.model,
        train_cfg=cfg.get('train_cfg'),
        test_cfg=cfg.get('test_cfg'))
    model.init_weights()

    checkpoint_layers = state_dict.keys()
    for name, value in model.backbone.state_dict().items():
        if name in checkpoint_layers:
            assert value.equal(state_dict[name])

    if print_cfg:
        print('-' * 10 + 'Successfully load checkpoint' + '-' * 10 +
              '\n', )
    return None
def inflate_weights(self, logger):
    """Inflate the resnet2d parameters to resnet3d pathway.

    The differences between resnet3d and resnet2d mainly lie in an extra
    axis of conv kernel. To utilize the pretrained parameters in 2d model,
    the weight of conv2d models should be inflated to fit in the shapes of
    the 3d counterpart. For pathway the ``lateral_connection`` part should
    not be inflated from 2d weights.

    Args:
        logger (logging.Logger): The logger used to print debugging
            information.
    """
    state_dict_r2d = _load_checkpoint(self.pretrained)
    if 'state_dict' in state_dict_r2d:
        state_dict_r2d = state_dict_r2d['state_dict']

    inflated_param_names = []
    for name, module in self.named_modules():
        # lateral connections are skipped; only ConvModule wrappers are
        # handled
        if 'lateral' in name or not isinstance(module, ConvModule):
            continue

        # we use a ConvModule to wrap conv+bn+relu layers, thus the
        # name mapping is needed
        if 'downsample' in name:
            # layer{X}.{Y}.downsample.conv->layer{X}.{Y}.downsample.0
            # layer{X}.{Y}.downsample.bn->layer{X}.{Y}.downsample.1
            conv_key, bn_key = name + '.0', name + '.1'
        else:
            # layer{X}.{Y}.conv{n}.conv->layer{X}.{Y}.conv{n}
            # layer{X}.{Y}.conv{n}.bn->layer{X}.{Y}.bn{n}
            conv_key, bn_key = name, name.replace('conv', 'bn')

        if conv_key + '.weight' in state_dict_r2d:
            self._inflate_conv_params(module.conv, state_dict_r2d,
                                      conv_key, inflated_param_names)
        else:
            logger.warning(
                f'Module not exist in the state_dict_r2d: {conv_key}')

        if bn_key + '.weight' in state_dict_r2d:
            self._inflate_bn_params(module.bn, state_dict_r2d, bn_key,
                                    inflated_param_names)
        else:
            logger.warning(
                f'Module not exist in the state_dict_r2d: {bn_key}')

    # check if any parameters in the 2d checkpoint are not loaded
    remaining_names = set(state_dict_r2d.keys()).difference(
        inflated_param_names)
    if remaining_names:
        logger.info(
            f'These parameters in the 2d checkpoint are not loaded: '
            f'{remaining_names}')
def init_weights(self):
    """Initialize the weights from a checkpoint or from scratch.

    When ``self.pretrained`` is a string, the checkpoint is read onto the
    CPU, optionally converted from the timm layout, its ``pos_embed`` is
    resized if its shape differs from ``self.pos_embed``, and the weights
    are loaded non-strictly. When ``self.pretrained`` is ``None``, the
    parent initialization runs first and the parameters are then
    re-initialized. Any other value leaves the weights untouched.
    """
    if isinstance(self.pretrained, str):
        logger = get_root_logger()
        checkpoint = _load_checkpoint(
            self.pretrained, logger=logger, map_location='cpu')
        if 'state_dict' in checkpoint:
            state_dict = checkpoint['state_dict']
        elif 'model' in checkpoint:
            state_dict = checkpoint['model']
        else:
            state_dict = checkpoint

        if self.pretrain_style == 'timm':
            # Because the refactor of vit is blocked by mmcls,
            # so we firstly use timm pretrain weights to train
            # downstream model.
            state_dict = vit_convert(state_dict)

        if 'pos_embed' in state_dict.keys():
            if self.pos_embed.shape != state_dict['pos_embed'].shape:
                logger.info(msg=f'Resize the pos_embed shape from '
                            f'{state_dict["pos_embed"].shape} to '
                            f'{self.pos_embed.shape}')
                h, w = self.img_size
                pos_size = int(
                    math.sqrt(state_dict['pos_embed'].shape[1] - 1))
                state_dict['pos_embed'] = self.resize_pos_embed(
                    state_dict['pos_embed'],
                    (h // self.patch_size, w // self.patch_size),
                    (pos_size, pos_size), self.interpolate_mode)

        # second positional argument is ``strict``
        self.load_state_dict(state_dict, False)

    elif self.pretrained is None:
        # The module-level helpers (``trunc_normal_init``, ``normal_init``,
        # ``constant_init``, ``kaiming_init``) look up ``.weight`` /
        # ``.bias`` on their argument (assuming these are the mmcv helpers
        # -- confirm against this file's imports). Handing them a bare
        # tensor therefore initializes nothing, so tensors are initialized
        # with tensor-level functions and modules with module-level ones.
        from mmcv.cnn.utils.weight_init import trunc_normal_

        super(VisionTransformer, self).init_weights()
        # We only implement the 'jax_impl' initialization implemented at
        # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py#L353  # noqa: E501
        trunc_normal_(self.pos_embed, std=.02)
        trunc_normal_(self.cls_token, std=.02)
        for n, m in self.named_modules():
            if isinstance(m, nn.Linear):
                trunc_normal_(m.weight, std=.02)
                if m.bias is not None:
                    if 'ffn' in n:
                        nn.init.normal_(m.bias, mean=0., std=1e-6)
                    else:
                        nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Conv2d):
                kaiming_init(m, mode='fan_in', bias=0.)
            elif isinstance(m, (_BatchNorm, nn.GroupNorm, nn.LayerNorm)):
                constant_init(m, val=1.0, bias=0.)
def init_weights(self, pretrained=None):
    """Initialize the weights from a checkpoint or from scratch.

    Args:
        pretrained (str, optional): Checkpoint location. When given, the
            checkpoint's ``pos_embed`` is resized if its shape differs from
            ``self.pos_embed`` and the weights are loaded non-strictly.
            When ``None``, the parameters are initialized from scratch.
            Defaults to None.

    Raises:
        TypeError: If ``pretrained`` is neither a str nor None.
    """
    if isinstance(pretrained, str):
        logger = get_root_logger()
        checkpoint = _load_checkpoint(pretrained, logger=logger)
        if 'state_dict' in checkpoint:
            state_dict = checkpoint['state_dict']
        elif 'model' in checkpoint:
            state_dict = checkpoint['model']
        else:
            state_dict = checkpoint

        if 'pos_embed' in state_dict.keys():
            if self.pos_embed.shape != state_dict['pos_embed'].shape:
                # Built from adjacent f-strings rather than a backslash
                # continuation inside the literal, which leaks source
                # indentation into the logged message.
                logger.info(msg=f'Resize the pos_embed shape from '
                            f'{state_dict["pos_embed"].shape} to '
                            f'{self.pos_embed.shape}')
                h, w = self.img_size
                pos_size = int(
                    math.sqrt(state_dict['pos_embed'].shape[1] - 1))
                state_dict['pos_embed'] = self.resize_pos_embed(
                    state_dict['pos_embed'], (h, w), (pos_size, pos_size),
                    self.patch_size, self.interpolate_mode)

        # second positional argument is ``strict``
        self.load_state_dict(state_dict, False)

    elif pretrained is None:
        # We only implement the 'jax_impl' initialization implemented at
        # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py#L353  # noqa: E501
        trunc_normal_(self.pos_embed, std=.02)
        trunc_normal_(self.cls_token, std=.02)
        for n, m in self.named_modules():
            # ``normal_init`` / ``constant_init`` / ``kaiming_init`` look up
            # ``.weight`` / ``.bias`` on their argument (assuming these are
            # the mmcv helpers -- confirm against this file's imports), so
            # bare tensors are initialized with ``nn.init`` and the
            # module-level helpers are given the module itself.
            if isinstance(m, Linear):
                trunc_normal_(m.weight, std=.02)
                if m.bias is not None:
                    if 'mlp' in n:
                        nn.init.normal_(m.bias, mean=0., std=1e-6)
                    else:
                        nn.init.constant_(m.bias, 0)
            elif isinstance(m, Conv2d):
                kaiming_init(m, mode='fan_in', bias=0.)
            elif isinstance(m, (_BatchNorm, nn.GroupNorm, nn.LayerNorm)):
                constant_init(m, val=1.0, bias=0.)
    else:
        raise TypeError('pretrained must be a str or None')
def init_weights(self):
    """Initialize the weights according to ``self.init_cfg``.

    * ``init_cfg`` is a dict of type ``'Pretrained'``: the checkpoint is
      read onto the CPU, its ``pos_embed`` is resized when the shape differs
      from ``self.pos_embed``, and the weights are loaded non-strictly.
    * any other non-``None`` ``init_cfg``: the parent initialization is
      used.
    * ``init_cfg`` is ``None``: the parameters are initialized from scratch.
    """
    cfg = self.init_cfg
    from_checkpoint = (
        isinstance(cfg, dict) and cfg.get('type') == 'Pretrained')

    if not from_checkpoint:
        if self.init_cfg is not None:
            super(VisionTransformer, self).init_weights()
            return

        # We only implement the 'jax_impl' initialization implemented at
        # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py#L353  # noqa: E501
        trunc_normal_(self.pos_embed, std=.02)
        trunc_normal_(self.cls_token, std=.02)
        for name, module in self.named_modules():
            if isinstance(module, nn.Linear):
                trunc_normal_(module.weight, std=.02)
                if module.bias is None:
                    continue
                if 'ffn' in name:
                    nn.init.normal_(module.bias, mean=0., std=1e-6)
                else:
                    nn.init.constant_(module.bias, 0)
            elif isinstance(module, nn.Conv2d):
                kaiming_init(module, mode='fan_in', bias=0.)
            elif isinstance(module,
                            (_BatchNorm, nn.GroupNorm, nn.LayerNorm)):
                constant_init(module, val=1.0, bias=0.)
        return

    logger = get_root_logger()
    checkpoint = _load_checkpoint(
        self.init_cfg['checkpoint'], logger=logger, map_location='cpu')
    state_dict = (
        checkpoint['state_dict']
        if 'state_dict' in checkpoint else checkpoint)

    needs_resize = (
        'pos_embed' in state_dict.keys()
        and self.pos_embed.shape != state_dict['pos_embed'].shape)
    if needs_resize:
        logger.info(msg=f'Resize the pos_embed shape from '
                    f'{state_dict["pos_embed"].shape} to '
                    f'{self.pos_embed.shape}')
        h, w = self.img_size
        pos_size = int(math.sqrt(state_dict['pos_embed'].shape[1] - 1))
        target_grid = (h // self.patch_size, w // self.patch_size)
        state_dict['pos_embed'] = self.resize_pos_embed(
            state_dict['pos_embed'], target_grid, (pos_size, pos_size),
            self.interpolate_mode)

    # second positional argument is ``strict``
    self.load_state_dict(state_dict, False)
def init_weights(self):
    """Initialize the weights from scratch or from ``init_cfg``.

    With no ``init_cfg`` a warning is logged and the ``nn.Linear``,
    ``nn.LayerNorm``, ``nn.Conv2d`` and ``AbsolutePositionEmbedding``
    submodules are initialized. Otherwise the checkpoint named by
    ``self.init_cfg.checkpoint`` is read onto the CPU, converted with
    ``pvt_convert`` when ``self.convert_weights`` is set, and loaded
    non-strictly.

    Raises:
        AssertionError: If ``init_cfg`` is set but has no ``'checkpoint'``
            entry.
    """
    logger = get_root_logger()
    if self.init_cfg is None:
        logger.warning(f'No pre-trained weights for '
                       f'{self.__class__.__name__}, '
                       f'training start from scratch')
        for m in self.modules():
            # The init helpers receive the module itself, not ``m.weight`` /
            # ``m.bias``: they look up ``.weight`` and ``.bias`` on their
            # argument (assuming these are the mmcv helpers -- confirm
            # against this file's imports), so a bare tensor would be left
            # uninitialized without any error.
            if isinstance(m, nn.Linear):
                trunc_normal_init(m, std=.02, bias=0.)
            elif isinstance(m, nn.LayerNorm):
                constant_init(m, val=1.0, bias=0.)
            elif isinstance(m, nn.Conv2d):
                fan_out = (
                    m.kernel_size[0] * m.kernel_size[1] * m.out_channels)
                fan_out //= m.groups
                normal_init(
                    m, mean=0, std=math.sqrt(2.0 / fan_out), bias=0)
            elif isinstance(m, AbsolutePositionEmbedding):
                m.init_weights()
    else:
        assert 'checkpoint' in self.init_cfg, f'Only support ' \
                                              f'specify `Pretrained` in ' \
                                              f'`init_cfg` in ' \
                                              f'{self.__class__.__name__} '
        checkpoint = _load_checkpoint(
            self.init_cfg.checkpoint, logger=logger, map_location='cpu')
        logger.warning(f'Load pre-trained model for '
                       f'{self.__class__.__name__} from original repo')
        if 'state_dict' in checkpoint:
            state_dict = checkpoint['state_dict']
        elif 'model' in checkpoint:
            state_dict = checkpoint['model']
        else:
            state_dict = checkpoint
        if self.convert_weights:
            # Because pvt backbones are not supported by mmcls,
            # so we need to convert pre-trained weights to match this
            # implementation.
            state_dict = pvt_convert(state_dict)
        load_state_dict(self, state_dict, strict=False, logger=logger)
def init_weights(self):
    """Initialize the weights according to ``self.init_cfg``.

    Every ``nn.Linear`` and ``nn.LayerNorm`` submodule is first
    re-initialized and ``self.fix_init_weight()`` is called. Then:

    * ``init_cfg`` is a dict of type ``'Pretrained'``: the checkpoint is
      read onto the CPU, passed through ``resize_rel_pos_embed`` and
      ``resize_abs_pos_embed``, and loaded non-strictly.
    * any other non-``None`` ``init_cfg``: the parent initialization is
      used.
    * ``init_cfg`` is ``None``: ``cls_token`` and the Linear / Conv2d /
      norm submodules are initialized from scratch.
    """

    def _reset_linear_and_norm(module):
        """Reset a single Linear or LayerNorm submodule."""
        if isinstance(module, nn.Linear):
            trunc_normal_(module.weight, std=.02)
            if module.bias is not None:
                nn.init.constant_(module.bias, 0)
        elif isinstance(module, nn.LayerNorm):
            nn.init.constant_(module.bias, 0)
            nn.init.constant_(module.weight, 1.0)

    self.apply(_reset_linear_and_norm)
    self.fix_init_weight()

    cfg = self.init_cfg
    from_checkpoint = (
        isinstance(cfg, dict) and cfg.get('type') == 'Pretrained')

    if from_checkpoint:
        logger = get_root_logger()
        checkpoint = _load_checkpoint(
            self.init_cfg['checkpoint'], logger=logger, map_location='cpu')
        resized = self.resize_rel_pos_embed(checkpoint)
        resized = self.resize_abs_pos_embed(resized)
        # second positional argument is ``strict``
        self.load_state_dict(resized, False)
        return

    if self.init_cfg is not None:
        super(MAE, self).init_weights()
        return

    # We only implement the 'jax_impl' initialization implemented at
    # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py#L353  # noqa: E501
    # Copyright 2019 Ross Wightman
    # Licensed under the Apache License, Version 2.0 (the "License")
    trunc_normal_(self.cls_token, std=.02)
    for name, module in self.named_modules():
        if isinstance(module, nn.Linear):
            trunc_normal_(module.weight, std=.02)
            if module.bias is None:
                continue
            if 'ffn' in name:
                nn.init.normal_(module.bias, mean=0., std=1e-6)
            else:
                nn.init.constant_(module.bias, 0)
        elif isinstance(module, nn.Conv2d):
            kaiming_init(module, mode='fan_in', bias=0.)
        elif isinstance(module, (_BatchNorm, nn.GroupNorm, nn.LayerNorm)):
            constant_init(module, val=1.0, bias=0.)
def _load_checkpoint(self, checkpoint, prefix=None, map_location=None):
    """Load a checkpoint into this module, resizing ``pos_embed`` if needed.

    Args:
        checkpoint (str): Checkpoint location.
        prefix (str, optional): When given, only the weights under this
            prefix are read, via ``_load_checkpoint_with_prefix``.
            Defaults to None.
        map_location (optional): Forwarded to the mmcv loading functions.
            Defaults to None.
    """
    from mmcv.runner import (_load_checkpoint,
                             _load_checkpoint_with_prefix, load_state_dict)
    from mmcv.utils import print_log

    logger = get_root_logger()

    if prefix is not None:
        print_log(
            f'load {prefix} in model from: {checkpoint}', logger=logger)
        state_dict = _load_checkpoint_with_prefix(prefix, checkpoint,
                                                  map_location)
    else:
        print_log(f'load model from: {checkpoint}', logger=logger)
        loaded = _load_checkpoint(checkpoint, map_location, logger)
        # get state_dict from checkpoint
        state_dict = (
            loaded['state_dict'] if 'state_dict' in loaded else loaded)

    if 'pos_embed' in state_dict.keys():
        src_shape = state_dict['pos_embed'].shape
        if self.pos_embed.shape != src_shape:
            print_log(
                f'Resize the pos_embed shape from {src_shape} '
                f'to {self.pos_embed.shape}.',
                logger=logger)

            src_grid = to_2tuple(int(np.sqrt(src_shape[1] - 1)))
            dst_grid = self.patch_embed.patches_resolution

            state_dict['pos_embed'] = self.resize_pos_embed(
                state_dict['pos_embed'], src_grid, dst_grid,
                self.interpolate_mode)

    # load state_dict
    load_state_dict(self, state_dict, strict=False, logger=logger)
def init_weights(self):
    """Initialize the weights in backbone.

    With no ``init_cfg`` a warning is logged, the ``nn.Linear`` and
    ``nn.LayerNorm`` submodules are initialized and ``_init_respostnorm`` is
    called on every entry of ``self.layers``. Otherwise the checkpoint named
    by ``self.init_cfg.checkpoint`` is read onto the CPU, the entries whose
    keys contain ``relative_position_index`` or ``relative_coords_table``
    are dropped, and the rest is loaded non-strictly.

    Raises:
        NotImplementedError: If ``init_cfg`` is set and ``self.ape`` is
            truthy.
        AssertionError: If ``init_cfg`` is set but has no ``'checkpoint'``
            entry.
    """
    logger = get_root_logger()
    cls_name = self.__class__.__name__

    if self.init_cfg is None:
        logger.warning(
            f'No pre-trained weights for {cls_name}, '
            f'training start from scratch')
        for module in self.modules():
            if isinstance(module, nn.Linear):
                trunc_normal_init(module, std=.02, bias=0.)
            elif isinstance(module, nn.LayerNorm):
                constant_init(module, 1.0)
        for stage in self.layers:
            stage._init_respostnorm()
        return

    if self.ape:
        raise NotImplementedError
    assert 'checkpoint' in self.init_cfg, (
        f'Only support specify `Pretrained` in `init_cfg` in {cls_name} ')

    ckpt = _load_checkpoint(
        self.init_cfg.checkpoint, logger=logger, map_location='cpu')

    # the weights may be nested under 'state_dict' or 'model'
    state_dict = ckpt
    for container_key in ('state_dict', 'model'):
        if container_key in ckpt:
            state_dict = ckpt[container_key]
            break

    # delete keys for reinitialization; iterate over a snapshot of the
    # keys because entries are removed while scanning
    reinit_keys = ('relative_position_index', 'relative_coords_table')
    for key in list(state_dict.keys()):
        if any(marker in key for marker in reinit_keys):
            del state_dict[key]

    load_state_dict(self, state_dict, strict=False, logger=logger)
def init_weights(self):
    """Initialize the weights from scratch or from a pretrained checkpoint.

    When ``self.pretrained`` is ``None``, every ``nn.Linear``,
    ``nn.LayerNorm`` and ``nn.Conv2d`` submodule is re-initialized. When it
    is a string, the checkpoint is read onto the CPU, converted with
    ``mit_convert`` if ``self.pretrain_style`` is ``'official'``, and loaded
    non-strictly. Any other value leaves the weights untouched.
    """
    if self.pretrained is None:
        for m in self.modules():
            # The init helpers receive the module itself, not ``m.weight`` /
            # ``m.bias``: they look up ``.weight`` and ``.bias`` on their
            # argument (assuming these are the mmcv helpers -- confirm
            # against this file's imports), so a bare tensor would be left
            # uninitialized without any error.
            if isinstance(m, nn.Linear):
                trunc_normal_init(m, std=.02, bias=0.)
            elif isinstance(m, nn.LayerNorm):
                constant_init(m, val=1.0, bias=0.)
            elif isinstance(m, nn.Conv2d):
                fan_out = (
                    m.kernel_size[0] * m.kernel_size[1] * m.out_channels)
                fan_out //= m.groups
                normal_init(
                    m, mean=0, std=math.sqrt(2.0 / fan_out), bias=0)
    elif isinstance(self.pretrained, str):
        logger = get_root_logger()
        checkpoint = _load_checkpoint(
            self.pretrained, logger=logger, map_location='cpu')
        if 'state_dict' in checkpoint:
            state_dict = checkpoint['state_dict']
        elif 'model' in checkpoint:
            state_dict = checkpoint['model']
        else:
            state_dict = checkpoint

        if self.pretrain_style == 'official':
            # Because segformer backbone is not support by mmcls,
            # so we need to convert pretrain weights to match this
            # implementation.
            state_dict = mit_convert(state_dict)

        # second positional argument is ``strict``
        self.load_state_dict(state_dict, False)
def test_resnet3d_backbone():
    """Test resnet3d backbone.

    Covers constructor argument validation, the ``norm_eval`` flag,
    inflation of the 2D ``torchvision://resnet50`` weights, frozen stages
    together with ``zero_init_residual``, and the output feature shape of
    the default, caffe-style and ``inflate_style='3x3x3'`` variants.
    """

    def _check_frozen(model, frozen_stages, last_conv_name):
        """Assert the stem and the first ``frozen_stages`` layers are frozen
        and that the bn after ``last_conv_name`` is zero-initialized."""
        assert model.conv1.bn.training is False
        for param in model.conv1.parameters():
            assert param.requires_grad is False
        for i in range(1, frozen_stages + 1):
            layer = getattr(model, f'layer{i}')
            for mod in layer.modules():
                if isinstance(mod, _BatchNorm):
                    assert mod.training is False
            for param in layer.parameters():
                assert param.requires_grad is False
        # test zero_init_residual
        for m in model.modules():
            if hasattr(m, last_conv_name):
                bn = getattr(m, last_conv_name).bn
                assert torch.equal(bn.weight, torch.zeros_like(bn.weight))
                assert torch.equal(bn.bias, torch.zeros_like(bn.bias))

    def _check_feat_shape(model, inputs, expected_shape):
        """Run ``model`` on ``inputs`` and assert the output shape."""
        # parrots 3dconv is only implemented on gpu
        if torch.__version__ == 'parrots':
            if torch.cuda.is_available():
                feat = model.cuda()(inputs.cuda())
                assert feat.shape == torch.Size(expected_shape)
        else:
            feat = model(inputs)
            assert feat.shape == torch.Size(expected_shape)

    for depth in (34, 50):
        for num_stages in (0, 5):
            with pytest.raises(AssertionError):
                # In ResNet3d: 1 <= num_stages <= 4
                ResNet3d(depth, None, num_stages=num_stages)

    for depth in (50, 34):
        with pytest.raises(AssertionError):
            # len(spatial_strides) == len(temporal_strides)
            # == len(dilations) == num_stages
            ResNet3d(
                depth,
                None,
                spatial_strides=(1, ),
                temporal_strides=(1, 1),
                dilations=(1, 1, 1),
                num_stages=4)

    for depth in (34, 50):
        with pytest.raises(TypeError):
            # pretrain must be str or None.
            invalid_model = ResNet3d(depth, ['resnet', 'bninception'])
            invalid_model.init_weights()

    # resnet3d with depth 34 / 50, no pretrained, norm_eval True
    for depth in (34, 50):
        model = ResNet3d(depth, None, pretrained2d=False, norm_eval=True)
        model.init_weights()
        model.train()
        assert check_norm_state(model.modules(), False)

    # resnet3d with depth 50, pretrained2d, norm_eval True
    resnet3d_50_pretrain = ResNet3d(
        50, 'torchvision://resnet50', norm_eval=True)
    resnet3d_50_pretrain.init_weights()
    resnet3d_50_pretrain.train()
    assert check_norm_state(resnet3d_50_pretrain.modules(), False)

    from mmcv.runner import _load_checkpoint
    chkp_2d = _load_checkpoint('torchvision://resnet50')
    for name, module in resnet3d_50_pretrain.named_modules():
        if len(name.split('.')) == 4:
            # layer.block.module.submodule
            prefix = name.split('.')[:2]
            module_type = name.split('.')[2]
            submodule_type = name.split('.')[3]

            if module_type == 'downsample':
                name2d = name.replace('conv', '0').replace('bn', '1')
            else:
                layer_id = name.split('.')[2][-1]
                name2d = prefix[0] + '.' + prefix[1] + '.' + \
                    submodule_type + layer_id

            if isinstance(module, nn.Conv3d):
                conv2d_weight = chkp_2d[name2d + '.weight']
                conv3d_weight = getattr(module, 'weight').data
                # the 2d kernel is repeated along the new axis and divided
                # by that axis' length
                assert torch.equal(
                    conv3d_weight,
                    conv2d_weight.data.unsqueeze(2).expand_as(conv3d_weight)
                    / conv3d_weight.shape[2])
                if getattr(module, 'bias') is not None:
                    conv2d_bias = chkp_2d[name2d + '.bias']
                    conv3d_bias = getattr(module, 'bias').data
                    assert torch.equal(conv2d_bias, conv3d_bias)

            elif isinstance(module, nn.BatchNorm3d):
                for pname in ['weight', 'bias', 'running_mean',
                              'running_var']:
                    param_2d = chkp_2d[name2d + '.' + pname]
                    param_3d = getattr(module, pname).data
                    # asserted per parameter; outside the loop only the
                    # last name would be compared
                    assert torch.equal(param_2d, param_3d)

    conv3d = resnet3d_50_pretrain.conv1.conv
    assert torch.equal(
        conv3d.weight,
        chkp_2d['conv1.weight'].unsqueeze(2).expand_as(conv3d.weight) /
        conv3d.weight.shape[2])
    conv3d = resnet3d_50_pretrain.layer3[2].conv2.conv
    assert torch.equal(
        conv3d.weight,
        chkp_2d['layer3.2.conv2.weight'].unsqueeze(2).expand_as(
            conv3d.weight) / conv3d.weight.shape[2])

    # resnet3d with depth 34 / 50, no pretrained, norm_eval False
    for depth in (34, 50):
        model = ResNet3d(depth, None, pretrained2d=False, norm_eval=False)
        model.init_weights()
        model.train()
        assert check_norm_state(model.modules(), True)

    # resnet3d with depth 34, no pretrained, frozen_stages, norm_eval False
    frozen_stages = 1
    resnet3d_34_frozen = ResNet3d(
        34, None, pretrained2d=False, frozen_stages=frozen_stages)
    resnet3d_34_frozen.init_weights()
    resnet3d_34_frozen.train()
    _check_frozen(resnet3d_34_frozen, frozen_stages, 'conv2')

    # resnet3d with depth 50, no pretrained, frozen_stages, norm_eval False
    resnet3d_50_frozen = ResNet3d(
        50, None, pretrained2d=False, frozen_stages=frozen_stages)
    resnet3d_50_frozen.init_weights()
    resnet3d_50_frozen.train()
    _check_frozen(resnet3d_50_frozen, frozen_stages, 'conv3')

    # resnet3d frozen with depth 34 inference
    input_shape = (1, 3, 6, 64, 64)
    imgs = _demo_inputs(input_shape)
    _check_feat_shape(resnet3d_34_frozen, imgs, [1, 512, 1, 2, 2])

    # resnet3d with depth 50 inference
    imgs = _demo_inputs(input_shape)
    _check_feat_shape(resnet3d_50_frozen, imgs, [1, 2048, 1, 2, 2])

    # resnet3d with depth 50 / 34 in caffe style inference
    for depth, channels in ((50, 2048), (34, 512)):
        model = ResNet3d(depth, None, pretrained2d=False, style='caffe')
        model.init_weights()
        model.train()
        _check_feat_shape(model, imgs, [1, channels, 1, 2, 2])

    # resnet3d with depth 50 / 34 with 3x3x3 inflate_style inference
    for depth, channels in ((50, 2048), (34, 512)):
        model = ResNet3d(
            depth, None, pretrained2d=False, inflate_style='3x3x3')
        model.init_weights()
        model.train()
        _check_feat_shape(model, imgs, [1, channels, 1, 2, 2])
def model_from_checkpoint(
    filename: Union[Path, str],
    model_name=None,
    backbone_name=None,
    classes=None,
    is_coco=False,
    img_size=None,
    map_location=None,
    strict=False,
    revise_keys=((r"^module\.", ""),),
    eval_mode=True,
    logger=None,
):
    """Load a checkpoint and, when possible, rebuild the model it belongs to.

    Any of ``model_name``, ``backbone_name``, ``classes`` and ``img_size``
    that is left as ``None`` is looked up in the checkpoint's ``"meta"``
    entry. A checkpoint without a ``"meta"`` entry is treated as having no
    metadata.

    Args:
        filename (Path | str): Checkpoint location, forwarded to
            ``_load_checkpoint``.
        model_name (str, optional): Dotted name resolved against ``models``.
            With three or more components the last two are used
            (e.g. ``models.mmdet.retinanet``), otherwise the first two
            (e.g. ``mmdet.retinanet``).
        backbone_name (str, optional): Attribute looked up on
            ``model_type.backbones``.
        classes (list, optional): Class names used to build the ``ClassMap``.
            Takes priority over ``is_coco``.
        is_coco (bool): Use ``CLASSES`` when ``classes`` is not given.
        img_size (optional): Passed to the model constructor only when
            ``model_name`` contains ``"yolov5"`` or ``"efficientdet"``.
        map_location (str, optional): Forwarded to ``_load_checkpoint``.
        strict (bool): Forwarded to ``load_state_dict``.
        revise_keys (iterable of (pattern, replacement)): Regex substitutions
            applied to every key of the state dict before loading.
        eval_mode (bool): Put the instantiated model in eval mode.
        logger (optional): Logger used for the warning and for loading.

    Returns:
        dict: ``model``, ``model_type``, ``backbone``, ``class_map``,
        ``img_size`` and ``checkpoint``. ``model``, ``model_type``,
        ``backbone`` and ``class_map`` are ``None`` when they could not be
        determined.

    Raises:
        RuntimeError: If the loaded checkpoint is not a dict.
        ValueError: If a model could be instantiated but the number of
            classes is unknown.
    """
    if isinstance(filename, Path):
        filename = str(filename)

    checkpoint = _load_checkpoint(
        filename=filename, map_location=map_location, logger=logger)

    # Validate before touching any key. OrderedDict is a subclass of dict.
    if not isinstance(checkpoint, dict):
        raise RuntimeError(
            f"No state_dict found in checkpoint file {filename}")

    # Plain state-dict checkpoints carry no metadata; fall back to an empty
    # mapping so every lookup below simply yields None.
    meta = checkpoint.get("meta") or {}

    if is_coco and classes:
        err_msg = "`is_coco` cannot be set to True if `classes` is passed and `not None`. `classes` has priority. `is_coco` will be ignored."
        if logger is not None:
            logger.warning(err_msg)
        else:
            print(err_msg)

    if classes is None:
        classes = CLASSES if is_coco else meta.get("classes", None)

    class_map = None
    num_classes = None
    if classes:
        class_map = ClassMap(classes)
        num_classes = len(class_map)

    if img_size is None:
        img_size = meta.get("img_size", None)

    if model_name is None:
        model_name = meta.get("model_name", None)

    model_type = None
    if model_name:
        name_parts = model_name.split(".")
        if len(name_parts) >= 3:
            # e.g. models.mmdet.retinanet -> library and model are the last
            # two components.
            library, variant = name_parts[-2], name_parts[-1]
        else:
            # e.g. mmdet.retinanet -> library and model are the first two.
            library, variant = name_parts[0], name_parts[1]
        model_type = getattr(getattr(models, library), variant)

    if backbone_name is None:
        backbone_name = meta.get("backbone_name", None)

    # Always bound, so the return dict can be built even when the backbone
    # cannot be resolved.
    backbone = None
    if model_type and backbone_name:
        backbone = getattr(model_type.backbones, backbone_name)

    extra_args = {}
    models_with_img_size = ("yolov5", "efficientdet")
    if model_name and any(m in model_name for m in models_with_img_size):
        extra_args["img_size"] = img_size

    # Instantiate model
    model = None
    if model_type and backbone:
        if num_classes is None:
            raise ValueError(
                "Cannot instantiate the model: the number of classes is "
                "unknown. Pass `classes`, set `is_coco=True`, or use a "
                "checkpoint whose meta contains `classes`.")
        model = model_type.model(
            backbone=backbone(pretrained=False),
            num_classes=num_classes,
            **extra_args)

    # get state_dict from checkpoint
    state_dict = checkpoint.get("state_dict", checkpoint)

    # strip prefix of state_dict
    for pattern, replacement in revise_keys:
        state_dict = {
            re.sub(pattern, replacement, k): v
            for k, v in state_dict.items()
        }

    # load state_dict; compare against None because a module that defines
    # __len__ could be falsy while still being a valid model.
    if model is not None:
        load_state_dict(model, state_dict, strict, logger)
        if eval_mode:
            model.eval()

    return {
        "model": model,
        "model_type": model_type,
        "backbone": backbone,
        "class_map": class_map,
        "img_size": img_size,
        "checkpoint": checkpoint,
    }
def test_checkpoint_loader():
    """Exercise ``CheckpointLoader`` scheme lookup, registration and ordering.

    Covers: round-tripping a checkpoint through ``save_checkpoint`` /
    ``_load_checkpoint``, resolving the loader for each filename prefix,
    registering a new scheme, the duplicate-registration error, ``force``
    re-registration, and the lookup result when one registered prefix
    (``a/b``) is a prefix of another (``a/b/c``).
    """
    import os
    import tempfile

    from mmcv.runner import CheckpointLoader, _load_checkpoint, save_checkpoint

    # A private temporary directory avoids clashing with other runs that use
    # the same file name, and is removed even when an assertion fails.
    with tempfile.TemporaryDirectory() as tmp_dir:
        checkpoint_path = os.path.join(tmp_dir, 'checkpoint.pth')
        model = Model()
        save_checkpoint(model, checkpoint_path)
        checkpoint = _load_checkpoint(checkpoint_path)
        assert 'meta' in checkpoint and 'CLASSES' not in checkpoint['meta']

    # (filename, name of the loader expected to handle it)
    cases = [
        ('http://xx.xx/xx.pth', 'load_from_http'),
        ('https://xx.xx/xx.pth', 'load_from_http'),
        ('modelzoo://xx.xx/xx.pth', 'load_from_torchvision'),
        ('torchvision://xx.xx/xx.pth', 'load_from_torchvision'),
        ('open-mmlab://xx.xx/xx.pth', 'load_from_openmmlab'),
        ('openmmlab://xx.xx/xx.pth', 'load_from_openmmlab'),
        ('mmcls://xx.xx/xx.pth', 'load_from_mmcls'),
        ('pavi://xx.xx/xx.pth', 'load_from_pavi'),
        ('s3://xx.xx/xx.pth', 'load_from_ceph'),
        # near-misses of the s3 prefix are expected to use the local loader
        ('ss3://xx.xx/xx.pth', 'load_from_local'),
        (' s3://xx.xx/xx.pth', 'load_from_local'),
    ]
    for filename, fn_name in cases:
        loader = CheckpointLoader._get_checkpoint_loader(filename)
        assert loader.__name__ == fn_name

    @CheckpointLoader.register_scheme(prefixes='ftp://')
    def load_from_ftp(filename, map_location):
        return dict(filename=filename)

    # test register_loader
    filename = 'ftp://xx.xx/xx.pth'
    loader = CheckpointLoader._get_checkpoint_loader(filename)
    assert loader.__name__ == 'load_from_ftp'

    def load_from_ftp1(filename, map_location):
        return dict(filename=filename)

    # test duplicate registered error
    with pytest.raises(KeyError):
        CheckpointLoader.register_scheme('ftp://', load_from_ftp1)

    # test force param
    CheckpointLoader.register_scheme('ftp://', load_from_ftp1, force=True)
    checkpoint = CheckpointLoader.load_checkpoint(filename)
    assert checkpoint['filename'] == filename

    # test print function name
    loader = CheckpointLoader._get_checkpoint_loader(filename)
    assert loader.__name__ == 'load_from_ftp1'

    # test sort
    @CheckpointLoader.register_scheme(prefixes='a/b')
    def load_from_ab(filename, map_location):
        return dict(filename=filename)

    @CheckpointLoader.register_scheme(prefixes='a/b/c')
    def load_from_abc(filename, map_location):
        return dict(filename=filename)

    filename = 'a/b/c/d'
    loader = CheckpointLoader._get_checkpoint_loader(filename)
    assert loader.__name__ == 'load_from_abc'
def init_weights(self):
    """Initialise the module from scratch or from ``self.pretrained``.

    * ``self.pretrained is None``: run the parent initialisation, then
      truncated-normal init for the absolute position embedding (when
      enabled) and for every ``Linear`` weight, and constant init for
      ``Linear`` biases and ``LayerNorm`` parameters.
    * ``self.pretrained`` is a string: load that checkpoint on CPU, adapt
      its keys and position tensors to this module, and load it
      non-strictly.
    * anything else: nothing is done.
    """
    if self.pretrained is None:
        super().init_weights()

        if self.use_abs_pos_embed:
            trunc_normal_init(self.absolute_pos_embed, std=0.02)

        for layer in self.modules():
            if isinstance(layer, Linear):
                trunc_normal_init(layer.weight, std=.02)
                if layer.bias is not None:
                    constant_init(layer.bias, 0)
            elif isinstance(layer, LayerNorm):
                constant_init(layer.bias, 0)
                constant_init(layer.weight, 1.0)
        return

    if not isinstance(self.pretrained, str):
        return

    logger = get_root_logger()
    ckpt = _load_checkpoint(
        self.pretrained, logger=logger, map_location='cpu')

    # The weights may be wrapped under 'state_dict' or 'model'.
    state_dict = ckpt
    for wrapper_key in ('state_dict', 'model'):
        if wrapper_key in ckpt:
            state_dict = ckpt[wrapper_key]
            break

    if self.pretrain_style == 'official':
        state_dict = swin_convert(state_dict)

    # Drop the 'module.' prefix; only the first key is inspected.
    first_key = list(state_dict.keys())[0]
    if first_key.startswith('module.'):
        state_dict = {key[7:]: value for key, value in state_dict.items()}

    # Reshape the absolute position embedding from (N, L, C) to (N, C, H, W).
    pretrained_embed = state_dict.get('absolute_pos_embed')
    if pretrained_embed is not None:
        N1, L, C1 = pretrained_embed.size()
        N2, C2, H, W = self.absolute_pos_embed.size()
        shapes_agree = N1 == N2 and C1 == C2 and L == H * W
        if shapes_agree:
            reshaped = pretrained_embed.view(N2, H, W, C2)
            state_dict['absolute_pos_embed'] = reshaped.permute(
                0, 3, 1, 2).contiguous()
        else:
            logger.warning('Error in loading absolute_pos_embed, pass')

    # Resize relative position bias tables whose length differs.
    bias_table_keys = [
        key for key in state_dict.keys()
        if 'relative_position_bias_table' in key
    ]
    for table_key in bias_table_keys:
        table_pretrained = state_dict[table_key]
        table_current = self.state_dict()[table_key]
        L1, nH1 = table_pretrained.size()
        L2, nH2 = table_current.size()

        if nH1 != nH2:
            logger.warning(f'Error in loading {table_key}, pass')
            continue
        if L1 == L2:
            continue

        # Tables are treated as square grids of side sqrt(length).
        S1 = int(L1**0.5)
        S2 = int(L2**0.5)
        grid = table_pretrained.permute(1, 0).reshape(1, nH1, S1, S1)
        resized = F.interpolate(grid, size=(S2, S2), mode='bicubic')
        state_dict[table_key] = resized.view(nH2, L2).permute(
            1, 0).contiguous()

    # load state_dict
    self.load_state_dict(state_dict, False)
def inflate_weights(model, state_dict_2d, logger=None):
    """Initialise the 3D layers of ``model`` from a 2D state dict.

    For every ``nn.Conv3d`` whose ``<name>.weight`` is present, the 2D
    kernel gets a new dimension at index 2, is expanded to the 3D kernel's
    shape and divided by the size of that dimension; the bias is copied
    when both sides have one. For every ``nn.BatchNorm3d`` whose
    ``<name>.weight`` is present, ``weight``, ``bias``, ``running_mean``
    and ``running_var`` are copied. Parameters are modified in place.

    Args:
        model: Module whose ``named_modules()`` are visited.
        state_dict_2d (dict | str): 2D state dict, a checkpoint dict holding
            one under ``'state_dict'``, or a checkpoint path passed to
            ``_load_checkpoint``. If the first key starts with ``'module.'``
            or ``'backbone.'``, that prefix is removed from every key that
            carries it.
        logger (optional): When given, receives a summary of missing keys,
            shape mismatches, copied keys and inflated keys.

    Returns:
        None
    """
    if isinstance(state_dict_2d, str):
        state_dict_2d = _load_checkpoint(state_dict_2d, map_location='cpu')
    if 'state_dict' in state_dict_2d:
        state_dict_2d = state_dict_2d['state_dict']
    assert isinstance(state_dict_2d, dict)

    # Inspect the first key only; an empty dict has no prefix to strip.
    first_key = next(iter(state_dict_2d), '')
    for prefix in ('module.', 'backbone.'):
        if first_key.startswith(prefix):
            # Strip the prefix only from keys that have it, so unprefixed
            # keys are not truncated.
            state_dict_2d = {
                (k[len(prefix):] if k.startswith(prefix) else k): v
                for k, v in state_dict_2d.items()
            }
            break

    missing_keys = []
    shape_mismatch_pairs = []
    shape_inflated_pairs = []
    copied_pairs = []

    for name, module in model.named_modules():
        weight_key = name + '.weight'
        has_weight = weight_key in state_dict_2d

        if isinstance(module, nn.Conv3d) and has_weight:
            old_weight = state_dict_2d[weight_key].data
            old_weight_shape = old_weight.shape
            assert len(old_weight_shape) in [2, 4]
            if len(old_weight_shape) == 2:
                # 2-D weight: add two trailing dims so it looks like a kernel.
                old_weight = old_weight.unsqueeze(2).unsqueeze(3)
            # Add the extra dimension that the 3D kernel has at index 2.
            old_weight = old_weight.unsqueeze(2)
            if not _is_compatible(old_weight.shape,
                                  module.weight.data.shape):
                shape_mismatch_pairs.append([
                    name,
                    list(old_weight_shape),
                    list(module.weight.size())
                ])
                continue
            # Replicate along dim 2 and rescale by its size.
            new_weight = old_weight.expand_as(
                module.weight) / module.weight.data.shape[2]
            module.weight.data.copy_(new_weight)
            shape_inflated_pairs.append([
                weight_key,
                list(old_weight_shape),
                list(module.weight.size())
            ])

            if getattr(module, 'bias', None) is not None:
                bias_key = name + '.bias'
                if bias_key in state_dict_2d:
                    new_bias = state_dict_2d[bias_key].data
                    module.bias.data.copy_(new_bias)
                    copied_pairs.append([bias_key, list(new_bias.size())])
                else:
                    # 3D conv has a bias the 2D checkpoint lacks: report it
                    # instead of raising KeyError.
                    missing_keys.append(bias_key)
        elif isinstance(module, nn.BatchNorm3d) and has_weight:
            for attr_name in ('weight', 'bias', 'running_mean',
                              'running_var'):
                attr_key = name + '.' + attr_name
                new_attr = getattr(module, attr_name)
                if new_attr is None:
                    # The module does not hold this tensor; nothing to fill.
                    continue
                if attr_key not in state_dict_2d:
                    missing_keys.append(attr_key)
                    continue
                old_attr = state_dict_2d[attr_key].data
                new_attr.data.copy_(old_attr)
                copied_pairs.append([attr_key, list(new_attr.size())])
        elif isinstance(module, (nn.Conv3d, nn.BatchNorm3d)):
            missing_keys.append(name)

    # The report is only for the logger; skip building it otherwise.
    if logger is None:
        return

    if missing_keys:
        msg = 'Missing keys in source state_dict: {}\n'.format(
            ', '.join(missing_keys))
        logger.warning(msg)

    if shape_mismatch_pairs:
        header = ['key', '2d shape', '3d shape']
        table = AsciiTable([header] + shape_mismatch_pairs)
        logger.warning('These keys have mismatched shape:\n' + table.table)

    if copied_pairs:
        header = ['key', 'shape']
        table = AsciiTable([header] + copied_pairs)
        logger.info('These keys have been copied:\n' + table.table)

    if shape_inflated_pairs:
        header = ['key', '2d shape', '3d shape']
        table = AsciiTable([header] + shape_inflated_pairs)
        logger.info('These keys have been shape inflated:\n' + table.table)
def init_weights(self):
    """Initialise the module from scratch or from ``self.init_cfg``.

    With ``self.init_cfg`` set to ``None`` a warning is logged and the
    module is initialised randomly: truncated-normal for the absolute
    position embedding (when enabled) and for ``nn.Linear`` layers,
    constant for ``nn.LayerNorm`` layers.

    Otherwise ``self.init_cfg`` must contain ``'checkpoint'``. That
    checkpoint is loaded on CPU, optionally converted with
    ``swin_converter``, reduced to its ``backbone.`` entries, adapted to
    this module's position tensors and loaded non-strictly.

    Raises:
        AssertionError: If ``init_cfg`` is set but has no ``'checkpoint'``.
    """
    logger = get_root_logger()
    if self.init_cfg is None:
        # Logger.warn is a deprecated alias; use warning().
        logger.warning(f'No pre-trained weights for '
                       f'{self.__class__.__name__}, '
                       f'training start from scratch')
        if self.use_abs_pos_embed:
            trunc_normal_(self.absolute_pos_embed, std=0.02)
        for m in self.modules():
            if isinstance(m, nn.Linear):
                trunc_normal_init(m, std=.02, bias=0.)
            elif isinstance(m, nn.LayerNorm):
                constant_init(m, 1.0)
        return

    assert 'checkpoint' in self.init_cfg, f'Only support ' \
                                          f'specify `Pretrained` in ' \
                                          f'`init_cfg` in ' \
                                          f'{self.__class__.__name__} '
    ckpt = _load_checkpoint(
        self.init_cfg.checkpoint, logger=logger, map_location='cpu')
    if 'state_dict' in ckpt:
        _state_dict = ckpt['state_dict']
    elif 'model' in ckpt:
        _state_dict = ckpt['model']
    else:
        _state_dict = ckpt
    if self.convert_weights:
        # supported loading weight from original repo,
        _state_dict = swin_converter(_state_dict)

    # Keep only the backbone entries, without their 'backbone.' prefix.
    state_dict = OrderedDict(
        (k[9:], v) for k, v in _state_dict.items()
        if k.startswith('backbone.'))
    if not state_dict:
        # No key carries the prefix: use the keys as they are rather than
        # continuing with an empty dict, which would fail on the first-key
        # lookup below.
        state_dict = OrderedDict(_state_dict)

    # strip prefix of state_dict (decided from the first key only)
    if state_dict and next(iter(state_dict)).startswith('module.'):
        state_dict = {k[7:]: v for k, v in state_dict.items()}

    # reshape absolute position embedding from (N, L, C) to (N, C, H, W)
    if state_dict.get('absolute_pos_embed') is not None:
        absolute_pos_embed = state_dict['absolute_pos_embed']
        N1, L, C1 = absolute_pos_embed.size()
        N2, C2, H, W = self.absolute_pos_embed.size()
        if N1 != N2 or C1 != C2 or L != H * W:
            logger.warning('Error in loading absolute_pos_embed, pass')
        else:
            state_dict['absolute_pos_embed'] = absolute_pos_embed.view(
                N2, H, W, C2).permute(0, 3, 1, 2).contiguous()

    # interpolate position bias table if needed
    relative_position_bias_table_keys = [
        k for k in state_dict.keys()
        if 'relative_position_bias_table' in k
    ]
    for table_key in relative_position_bias_table_keys:
        table_pretrained = state_dict[table_key]
        table_current = self.state_dict()[table_key]
        L1, nH1 = table_pretrained.size()
        L2, nH2 = table_current.size()
        if nH1 != nH2:
            # NOTE(review): the mismatched table stays in state_dict;
            # confirm the load below tolerates a shape mismatch.
            logger.warning(f'Error in loading {table_key}, pass')
        elif L1 != L2:
            # Tables are treated as square grids of side sqrt(length).
            S1 = int(L1**0.5)
            S2 = int(L2**0.5)
            table_pretrained_resized = F.interpolate(
                table_pretrained.permute(1, 0).reshape(1, nH1, S1, S1),
                size=(S2, S2),
                mode='bicubic')
            state_dict[table_key] = table_pretrained_resized.view(
                nH2, L2).permute(1, 0).contiguous()

    # load state_dict
    self.load_state_dict(state_dict, False)