def init_weights(self, pretrained=None):
    """Initialize model weights."""
    if isinstance(pretrained, str):
        logger = get_root_logger()
        state_dict = get_state_dict(pretrained)
        # The checkpoint is grouped by submodule: 'top' for the stem and
        # 'bottlenecks' for the per-stage downsample blocks.
        load_state_dict(
            self.top, state_dict['top'], strict=False, logger=logger)
        for i in range(self.num_stages):
            load_state_dict(
                self.multi_stage_mspn[i].downsample,
                state_dict['bottlenecks'],
                strict=False,
                logger=logger)
    else:
        # No pretrained checkpoint: fall back to random initialization.
        for m in self.multi_stage_mspn.modules():
            if isinstance(m, nn.Conv2d):
                kaiming_init(m)
            elif isinstance(m, nn.BatchNorm2d):
                constant_init(m, 1)
            elif isinstance(m, nn.Linear):
                normal_init(m, std=0.01)
        for m in self.top.modules():
            if isinstance(m, nn.Conv2d):
                kaiming_init(m)
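# A minimal, self-contained sketch of the pattern above: a checkpoint grouped
# by submodule, loaded piecewise with strict=False so missing or unexpected
# keys are tolerated. The module names are hypothetical, and plain
# nn.Module.load_state_dict stands in for the mmcv-style helper.
import torch.nn as nn

class TwoPartNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.top = nn.Conv2d(3, 8, 3, padding=1)
        self.body = nn.Conv2d(8, 8, 3, padding=1)

net = TwoPartNet()
ckpt = {'top': net.top.state_dict(), 'body': {}}       # grouped checkpoint
net.top.load_state_dict(ckpt['top'], strict=False)
net.body.load_state_dict(ckpt['body'], strict=False)   # partial load tolerated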
def get_net():
    # Assumes the effdet package (rwightman's efficientdet-pytorch):
    # from effdet import get_efficientdet_config, EfficientDet, DetBenchTrain
    # from effdet.efficientdet import HeadNet
    config = get_efficientdet_config('tf_efficientdet_d6')
    net = EfficientDet(config, pretrained_backbone=False)
    config.num_classes = 1
    config.image_size = 1024
    # Swap in a single-class classification head.
    net.class_net = HeadNet(
        config,
        num_outputs=config.num_classes,
        norm_kwargs=dict(eps=.001, momentum=.01))
    # COCO-pretrained weights saved as a raw state_dict.
    checkpoint = torch.load('pretrained/efficientdet_d6-51cb0132.pth')
    load_state_dict(net, checkpoint)
    return DetBenchTrain(net, config)
def get_net():
    config = get_efficientdet_config('tf_efficientdet_d6')
    net = EfficientDet(config, pretrained_backbone=False)
    config.num_classes = 1
    config.image_size = 1024
    net.class_net = HeadNet(
        config,
        num_outputs=config.num_classes,
        norm_kwargs=dict(eps=.001, momentum=.01))
    checkpoint = torch.load(
        'effdet6-baseline-1024-4x8-sa-fold0/best-checkpoint-052epoch.bin')
    load_state_dict(net, checkpoint['model_state_dict'])
    return DetBenchTrain(net, config)
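# The two get_net() variants differ only in checkpoint layout: a raw
# state_dict versus a training checkpoint nesting weights under
# 'model_state_dict'. A small hedged helper covering both layouts; the
# fallback key names besides 'model_state_dict' are assumptions, since
# different trainers use different conventions.
def load_flexible(net, path):
    ckpt = torch.load(path, map_location='cpu')
    for key in ('model_state_dict', 'state_dict', 'model'):
        if isinstance(ckpt, dict) and key in ckpt:
            ckpt = ckpt[key]
            break
    net.load_state_dict(ckpt)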
def load_checkpoint(model, filename, map_location=None, strict=False, logger=None):
    """Load checkpoint from a file or URI.

    Args:
        model (Module): Module to load checkpoint.
        filename (str): Accepts local filepath, URL, ``torchvision://xxx``,
            ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for
            details.
        map_location (str): Same as :func:`torch.load`.
        strict (bool): Whether to allow different params for the model and
            checkpoint.
        logger (:mod:`logging.Logger` or None): The logger for error message.

    Returns:
        dict or OrderedDict: The loaded checkpoint.
    """
    checkpoint = _load_checkpoint(filename, map_location)
    # OrderedDict is a subclass of dict
    if not isinstance(checkpoint, dict):
        raise RuntimeError(
            f'No state_dict found in checkpoint file {filename}')
    # get state_dict from checkpoint
    if 'state_dict' in checkpoint:
        state_dict_tmp = checkpoint['state_dict']
    else:
        state_dict_tmp = checkpoint

    state_dict = OrderedDict()
    # strip prefix of state_dict
    for k, v in state_dict_tmp.items():
        if k.startswith('module.backbone.'):
            state_dict[k[16:]] = v
        elif k.startswith('module.'):
            state_dict[k[7:]] = v
        elif k.startswith('backbone.'):
            state_dict[k[9:]] = v
        else:
            state_dict[k] = v
    # load state_dict
    load_state_dict(model, state_dict, strict, logger)
    return checkpoint
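# Quick illustrative check of the prefix-stripping rules above (note that
# len('module.backbone.') == 16, len('module.') == 7 and len('backbone.') == 9,
# so the slice offsets match the prefixes being removed):
from collections import OrderedDict

raw = OrderedDict([
    ('module.backbone.conv1.weight', 1),
    ('module.fc.weight', 2),
    ('backbone.bn1.bias', 3),
    ('head.weight', 4),
])
stripped = OrderedDict()
for k, v in raw.items():
    if k.startswith('module.backbone.'):
        stripped[k[16:]] = v
    elif k.startswith('module.'):
        stripped[k[7:]] = v
    elif k.startswith('backbone.'):
        stripped[k[9:]] = v
    else:
        stripped[k] = v
assert list(stripped) == ['conv1.weight', 'fc.weight', 'bn1.bias', 'head.weight']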
def load_pretrained(model, url, filter_fn=None, strict=False, logger=None):
    if not url:
        print('=> Warning: Pretrained model URL is empty, '
              'using random initialization.')
        return
    state_dict = load_state_dict_from_url(url, progress=False, map_location='cpu')

    input_conv = 'conv_stem'
    classifier = 'classifier'
    in_chans = getattr(model, input_conv).weight.shape[1]
    # num_classes = getattr(model, classifier).weight.shape[0]

    input_conv_weight = input_conv + '.weight'
    pretrained_in_chans = state_dict[input_conv_weight].shape[1]
    if in_chans != pretrained_in_chans:
        if in_chans == 1:
            # Collapse the pretrained stem's RGB filters into one channel
            # by summing over the input-channel dimension.
            print('=> Converting pretrained input conv {} from {} to 1 channel'
                  .format(input_conv_weight, pretrained_in_chans))
            conv1_weight = state_dict[input_conv_weight]
            state_dict[input_conv_weight] = conv1_weight.sum(dim=1, keepdim=True)
        else:
            print('=> Discarding pretrained input conv {} since input '
                  'channel count != {}'.format(input_conv_weight,
                                               pretrained_in_chans))
            del state_dict[input_conv_weight]
            strict = False

    # classifier_weight = classifier + '.weight'
    # pretrained_num_classes = state_dict[classifier_weight].shape[0]
    # if num_classes != pretrained_num_classes:
    #     print('=> Discarding pretrained classifier since num_classes != {}'
    #           .format(pretrained_num_classes))
    #     del state_dict[classifier_weight]
    #     del state_dict[classifier + '.bias']
    #     strict = False

    if filter_fn is not None:
        state_dict = filter_fn(state_dict)

    load_state_dict(model, state_dict, strict=strict, logger=get_root_logger())
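# Standalone check of the channel-collapsing trick above: summing a pretrained
# RGB stem's filters over the input-channel dimension gives a 1-channel filter
# that responds to a grayscale image exactly as the original responds to that
# image replicated across R, G and B. Shapes below are illustrative.
import torch
import torch.nn.functional as F

conv1_weight = torch.randn(32, 3, 3, 3)               # (out_ch, in_ch=3, kH, kW)
gray_weight = conv1_weight.sum(dim=1, keepdim=True)   # (out_ch, 1, kH, kW)
x_gray = torch.randn(1, 1, 8, 8)
x_rgb = x_gray.repeat(1, 3, 1, 1)                     # gray replicated to RGB
assert torch.allclose(
    F.conv2d(x_rgb, conv1_weight), F.conv2d(x_gray, gray_weight), atol=1e-4)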
def init_weights(self, pretrained=None):
    """Initialize model weights."""
    if isinstance(pretrained, str):
        logger = get_root_logger()
        state_dict_tmp = get_state_dict(pretrained)
        state_dict = OrderedDict()
        state_dict['top'] = OrderedDict()
        state_dict['bottlenecks'] = OrderedDict()
        # Remap torchvision ResNet keys to the MSPN naming scheme.
        for k, v in state_dict_tmp.items():
            if k.startswith('layer'):
                if 'downsample.0' in k:
                    state_dict['bottlenecks'][k.replace(
                        'downsample.0', 'downsample.conv')] = v
                elif 'downsample.1' in k:
                    state_dict['bottlenecks'][k.replace(
                        'downsample.1', 'downsample.bn')] = v
                else:
                    state_dict['bottlenecks'][k] = v
            elif k.startswith('conv1'):
                state_dict['top'][k.replace('conv1', 'top.0.conv')] = v
            elif k.startswith('bn1'):
                state_dict['top'][k.replace('bn1', 'top.0.bn')] = v

        load_state_dict(
            self.top, state_dict['top'], strict=False, logger=logger)
        for i in range(self.num_stages):
            load_state_dict(
                self.multi_stage_mspn[i].downsample,
                state_dict['bottlenecks'],
                strict=False,
                logger=logger)
    else:
        for m in self.multi_stage_mspn.modules():
            if isinstance(m, nn.Conv2d):
                kaiming_init(m)
            elif isinstance(m, nn.BatchNorm2d):
                constant_init(m, 1)
            elif isinstance(m, nn.Linear):
                normal_init(m, std=0.01)
        for m in self.top.modules():
            if isinstance(m, nn.Conv2d):
                kaiming_init(m)
def load_checkpoint(model, filename, map_location=None, strict=False, logger=None):
    """Load checkpoint from a file or URI.

    Args:
        model (Module): Module to load checkpoint.
        filename (str): Either a filepath or URL or modelzoo://xxxxxxx.
        map_location (str): Same as :func:`torch.load`.
        strict (bool): Whether to allow different params for the model and
            checkpoint.
        logger (:mod:`logging.Logger` or None): The logger for error message.

    Returns:
        dict or OrderedDict: The loaded checkpoint.
    """
    # load checkpoint from modelzoo or file or url
    if filename.startswith('modelzoo://'):
        import torchvision
        model_urls = dict()
        for _, name, ispkg in pkgutil.walk_packages(
                torchvision.models.__path__):
            if not ispkg:
                _zoo = import_module('torchvision.models.{}'.format(name))
                if hasattr(_zoo, 'model_urls'):
                    model_urls.update(getattr(_zoo, 'model_urls'))
        model_name = filename[11:]
        checkpoint = load_url_dist(model_urls[model_name])
    elif filename.startswith('open-mmlab://'):
        model_name = filename[13:]
        checkpoint = load_url_dist(open_mmlab_model_urls[model_name])
    elif filename.startswith(('http://', 'https://')):
        checkpoint = load_url_dist(filename)
    else:
        if not osp.isfile(filename):
            raise IOError('{} is not a checkpoint file'.format(filename))
        checkpoint = torch.load(filename, map_location=map_location)
    # get state_dict from checkpoint
    if isinstance(checkpoint, OrderedDict):
        state_dict = checkpoint
    elif isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
        state_dict = checkpoint['state_dict']
    else:
        raise RuntimeError(
            'No state_dict found in checkpoint file {}'.format(filename))
    # strip prefix of state_dict
    if list(state_dict.keys())[0].startswith('module.'):
        state_dict = {k[7:]: v for k, v in state_dict.items()}
    # TODO(xiao): init stage3 with stage2
    state_dict_ = state_dict.copy()
    for k, v in state_dict.items():
        if 'extra_heads.0' in k:
            state_dict_[k.replace('extra_heads.0', 'extra_heads.1')] = v.clone()
    state_dict = state_dict_
    # load state_dict
    if hasattr(model, 'module'):
        load_state_dict(model.module, state_dict, strict, logger)
    else:
        load_state_dict(model, state_dict, strict, logger)
    return checkpoint
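# The "extra_heads.0" -> "extra_heads.1" block above is a generic way to
# warm-start a newly added head from a trained one by cloning its keys. A
# minimal sketch of the same idea as a reusable helper (the function name and
# example prefixes are hypothetical):
def duplicate_head_keys(state_dict, src_prefix, dst_prefix):
    out = dict(state_dict)
    for k, v in state_dict.items():
        if src_prefix in k:
            out[k.replace(src_prefix, dst_prefix)] = v.clone()
    return out

# e.g. state_dict = duplicate_head_keys(state_dict, 'extra_heads.0', 'extra_heads.1')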
def init_weights(self, pretrained=None):
    if isinstance(pretrained, list):
        logger = logging.getLogger()

        def _load_checkpoint_file(filename, map_location=None):
            # Resolve modelzoo://, torchvision://, open-mmlab:// and
            # http(s):// schemes, or fall back to a local file.
            if filename.startswith('modelzoo://'):
                warnings.warn(
                    'The URL scheme of "modelzoo://" is deprecated, please '
                    'use "torchvision://" instead')
                model_urls = get_torchvision_models()
                return load_url_dist(model_urls[filename[11:]])
            if filename.startswith('torchvision://'):
                model_urls = get_torchvision_models()
                return load_url_dist(model_urls[filename[14:]])
            if filename.startswith('open-mmlab://'):
                return load_url_dist(open_mmlab_model_urls[filename[13:]])
            if filename.startswith(('http://', 'https://')):
                return load_url_dist(filename)
            if not osp.isfile(filename):
                raise IOError('{} is not a checkpoint file'.format(filename))
            return torch.load(filename, map_location=map_location)

        def _get_state_dict(checkpoint, filename):
            if isinstance(checkpoint, OrderedDict):
                return checkpoint
            if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
                return checkpoint['state_dict']
            raise RuntimeError(
                'No state_dict found in checkpoint file {}'.format(filename))

        # First checkpoint initializes the deep branch under the model's
        # native parameter names.
        checkpoint = _load_checkpoint_file(pretrained[0])
        state_dict = _get_state_dict(checkpoint, pretrained[0])
        # strip the 'module.' prefix left by (Distributed)DataParallel
        if list(state_dict.keys())[0].startswith('module.'):
            state_dict = {k[7:]: v for k, v in state_dict.items()}
        if hasattr(self, 'module'):
            load_state_dict(self.module, state_dict, False, logger)
        else:
            load_state_dict(self, state_dict, False, logger)

        # Second checkpoint initializes the shallow branch: keep only the
        # residual-stage ('layer*') weights and remap them under a
        # 'shallow_' prefix.
        checkpoint = _load_checkpoint_file(pretrained[1])
        state_dict = _get_state_dict(checkpoint, pretrained[1])
        if list(state_dict.keys())[0].startswith('module.'):
            state_dict = {
                'shallow_' + k[7:]: v
                for k, v in state_dict.items() if 'layer' in k
            }
        else:
            state_dict = {
                'shallow_' + k: v
                for k, v in state_dict.items() if 'layer' in k
            }
        if hasattr(self, 'module'):
            load_state_dict(self.module, state_dict, False, logger)
        else:
            load_state_dict(self, state_dict, False, logger)
    elif pretrained is None:
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                kaiming_init(m)
            elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
                constant_init(m, 1)
        if self.dcn is not None:
            for m in self.modules():
                if isinstance(m, Bottleneck) and hasattr(m, 'conv2_offset'):
                    constant_init(m.conv2_offset, 0)
        if self.zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    constant_init(m.norm3, 0)
                elif isinstance(m, BasicBlock):
                    constant_init(m.norm2, 0)
    else:
        raise TypeError('pretrained must be a list or None')
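# Sketch of the dual-branch merge above in isolation: one checkpoint keeps the
# model's native names, the other is filtered to its residual stages and
# remapped under a 'shallow_' prefix so both can be loaded into one model. The
# prefix and the 'layer' filter mirror the snippet above and assume
# torchvision-style ResNet keys.
def build_dual_state_dict(deep_sd, shallow_sd):
    merged = dict(deep_sd)
    merged.update({
        'shallow_' + k: v
        for k, v in shallow_sd.items() if 'layer' in k
    })
    return merged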