# Config YAML merged into ``regnet_cfg`` for each supported
# ``cfg.MODEL.BACKBONE.STR_DEPTH`` value. X variants live under ``regnetx/``
# and Y variants under ``regnety/``.
REGNET_CFG_FILES = {
    '200x': 'modeling/backbone/regnet/regnetx/RegNetX-200MF_dds_8gpu.yaml',
    '200y': 'modeling/backbone/regnet/regnety/RegNetY-200MF_dds_8gpu.yaml',
    '400x': 'modeling/backbone/regnet/regnetx/RegNetX-400MF_dds_8gpu.yaml',
    '400y': 'modeling/backbone/regnet/regnety/RegNetY-400MF_dds_8gpu.yaml',
    '800x': 'modeling/backbone/regnet/regnetx/RegNetX-800MF_dds_8gpu.yaml',
    '800y': 'modeling/backbone/regnet/regnety/RegNetY-800MF_dds_8gpu.yaml',
    '1600x': 'modeling/backbone/regnet/regnetx/RegNetX-1.6GF_dds_8gpu.yaml',
    '1600y': 'modeling/backbone/regnet/regnety/RegNetY-1.6GF_dds_8gpu.yaml',
    '3200x': 'modeling/backbone/regnet/regnetx/RegNetX-3.2GF_dds_8gpu.yaml',
    '3200y': 'modeling/backbone/regnet/regnety/RegNetY-3.2GF_dds_8gpu.yaml',
    '4000x': 'modeling/backbone/regnet/regnetx/RegNetX-4.0GF_dds_8gpu.yaml',
    '4000y': 'modeling/backbone/regnet/regnety/RegNetY-4.0GF_dds_8gpu.yaml',
    '6400x': 'modeling/backbone/regnet/regnetx/RegNetX-6.4GF_dds_8gpu.yaml',
    '6400y': 'modeling/backbone/regnet/regnety/RegNetY-6.4GF_dds_8gpu.yaml',
}


def build_regnet(cfg):
    """Build a RegNet backbone from ``cfg`` and optionally load pretrained weights.

    Reads ``cfg.MODEL.BACKBONE.{PRETRAIN, PRETRAIN_PATH, LAST_STRIDE, STR_DEPTH}``.
    ``STR_DEPTH`` selects a YAML from ``REGNET_CFG_FILES``. That YAML is merged
    into ``regnet_cfg`` before ``RegNet(last_stride, cfg)`` is constructed.

    When ``PRETRAIN`` is truthy, the source of the weights depends on
    ``PRETRAIN_PATH``:

    * If it is set, weights come from the ``'model_state'`` entry of that
      checkpoint file.
    * Otherwise, weights come from ``init_pretrained_weights(STR_DEPTH)``.

    The weights are loaded with ``strict=False``, and any missing or
    unexpected keys are logged.

    Args:
        cfg: config node exposing the ``MODEL.BACKBONE`` fields listed above.

    Returns:
        The constructed ``RegNet`` model.

    Raises:
        KeyError: if ``STR_DEPTH`` is not a key of ``REGNET_CFG_FILES``, or
            the checkpoint at ``PRETRAIN_PATH`` has no ``'model_state'`` entry.
        FileNotFoundError: if ``PRETRAIN_PATH`` does not exist.
    """
    # fmt: off
    pretrain      = cfg.MODEL.BACKBONE.PRETRAIN
    pretrain_path = cfg.MODEL.BACKBONE.PRETRAIN_PATH
    last_stride   = cfg.MODEL.BACKBONE.LAST_STRIDE
    depth         = cfg.MODEL.BACKBONE.STR_DEPTH
    # fmt: on

    try:
        cfg_file = REGNET_CFG_FILES[depth]
    except KeyError:
        # Keep KeyError for callers, but say what was wrong and what is valid.
        raise KeyError(
            f"Unsupported RegNet depth {depth!r}; "
            f"expected one of {sorted(REGNET_CFG_FILES)}"
        ) from None
    regnet_cfg.merge_from_file(cfg_file)
    model = RegNet(last_stride, cfg)

    if pretrain:
        # An explicit checkpoint path takes precedence over the default weights.
        if pretrain_path:
            try:
                # NOTE: torch.load unpickles the file; only point this at
                # trusted checkpoints.
                state_dict = torch.load(
                    pretrain_path, map_location=torch.device('cpu'))['model_state']
            except FileNotFoundError:
                logger.error(f'{pretrain_path} is not found! Please check this path.')
                raise
            except KeyError:
                logger.error("State dict keys error! Please check the state dict.")
                raise
            logger.info(f"Loading pretrained model from {pretrain_path}")
        else:
            state_dict = init_pretrained_weights(depth)

        # strict=False so partially matching checkpoints still load; report the gaps.
        incompatible = model.load_state_dict(state_dict, strict=False)
        if incompatible.missing_keys:
            logger.info(
                get_missing_parameters_message(incompatible.missing_keys)
            )
        if incompatible.unexpected_keys:
            logger.info(
                get_unexpected_parameters_message(incompatible.unexpected_keys)
            )

    return model
def init_weights(self, num_layers, pretrain):
    """Optionally load pretrained ResNet weights into this module.

    Args:
        num_layers: depth used to build the ``'resnet<num_layers>'`` key that is
            looked up in ``model_urls``.
        pretrain: when falsy, the method returns immediately without touching
            the module.

    The checkpoint is fetched with ``model_zoo.load_url`` and applied with
    ``strict=False``. Missing and unexpected keys are logged at INFO level.
    """
    if not pretrain:
        return

    weights_url = model_urls[f'resnet{num_layers}']
    remote_state = model_zoo.load_url(weights_url)
    logger.info(f"initial from {weights_url}")

    load_result = self.load_state_dict(remote_state, strict=False)
    missing = load_result.missing_keys
    if missing:
        logger.info(get_missing_parameters_message(missing))
    unexpected = load_result.unexpected_keys
    if unexpected:
        logger.info(get_unexpected_parameters_message(unexpected))
def align_and_update_state_dicts(model_state_dict, ckpt_state_dict, c2_conversion=True):
    """
    Match names between the two state-dicts, and update the values of
    model_state_dict in-place with copies of the matched tensors in
    ckpt_state_dict.

    If `c2_conversion==True`, `ckpt_state_dict` is assumed to be a Caffe2 model
    and is renamed first via ``convert_c2_detectron_names``.

    Strategy: the model may add prefixes to its keys that the pretrained
    weights do not contain, e.g. because of an extra level of nesting.

    Example: ``model.state_dict()`` might return
    ``backbone[0].body.res2.conv1.weight``, while the pretrained model contains
    ``res2.conv1.weight``. We want to match these two parameters together.

    For each model key, we therefore look among all checkpoint keys for one that
    is a complete (dot-delimited) suffix of it. If several match, the longest
    one wins.

    For example, suppose the checkpoint contains both ``res2.conv1.weight`` and
    ``conv1.weight``. Then:

    * ``backbone[0].body.conv1.weight`` is matched to ``conv1.weight``.
    * ``backbone[0].body.res2.conv1.weight`` is matched to ``res2.conv1.weight``.

    Entries whose shapes differ are skipped with a warning. Unmatched keys on
    either side are reported via logging.

    Args:
        model_state_dict: the model's state dict; matched entries are replaced
            in place with clones of the checkpoint tensors.
        ckpt_state_dict: the checkpoint's state dict.
        c2_conversion: whether to rename checkpoint keys from Caffe2/Detectron
            naming before matching.

    Raises:
        ValueError: if a single checkpoint key would be loaded into more than
            one model key.
    """
    model_keys = sorted(model_state_dict.keys())
    if c2_conversion:
        ckpt_state_dict, original_keys = convert_c2_detectron_names(ckpt_state_dict)
        # original_keys: the name in the original dict (before renaming)
    else:
        original_keys = {x: x for x in ckpt_state_dict.keys()}
    ckpt_keys = sorted(ckpt_state_dict.keys())

    def match(a, b):
        # Matched ckpt_key should be a complete (starts with '.') suffix.
        # For example, roi_heads.mesh_head.whatever_conv1 does not match conv1,
        # but matches whatever_conv1 or mesh_head.whatever_conv1.
        return a == b or a.endswith("." + b)

    if model_keys and ckpt_keys:
        # Matrix of string matches: entry (i, j) is len(ckpt_key j) if it
        # matches model_key i, else 0.
        match_matrix = [
            len(j) if match(i, j) else 0 for i in model_keys for j in ckpt_keys
        ]
        match_matrix = torch.as_tensor(match_matrix).view(len(model_keys), len(ckpt_keys))
        # Use the longest match when several checkpoint keys match.
        max_match_size, idxs = match_matrix.max(1)
        # -1 marks model keys with no match at all.
        idxs[max_match_size == 0] = -1
        best_ckpt_idx = idxs.tolist()
    else:
        # torch's max(dim) rejects a zero-size reduction dim, and with no keys
        # on one side nothing can match anyway.
        best_ckpt_idx = [-1] * len(model_keys)

    # used for logging
    max_len_model = max(len(key) for key in model_keys) if model_keys else 1
    max_len_ckpt = max(len(key) for key in ckpt_keys) if ckpt_keys else 1
    log_str_template = "{: <{}} loaded from {: <{}} of shape {}"
    logger = logging.getLogger(__name__)
    # matched_keys: matched checkpoint key --> matched model key
    matched_keys = {}
    for idx_model, idx_ckpt in enumerate(best_ckpt_idx):
        if idx_ckpt == -1:
            continue
        key_model = model_keys[idx_model]
        key_ckpt = ckpt_keys[idx_ckpt]
        value_ckpt = ckpt_state_dict[key_ckpt]
        shape_in_model = model_state_dict[key_model].shape

        if shape_in_model != value_ckpt.shape:
            logger.warning(
                "Shape of {} in checkpoint is {}, while shape of {} in model is {}."
                .format(key_ckpt, value_ckpt.shape, key_model, shape_in_model))
            logger.warning(
                "{} will not be loaded. Please double check and see if this is desired."
                .format(key_ckpt))
            continue

        # Detect ambiguity before writing so the model dict is not partially
        # updated with an entry we are about to reject.
        if key_ckpt in matched_keys:  # already added to matched_keys
            logger.error(
                "Ambiguity found for {} in checkpoint! "
                "It matches at least two keys in the model ({} and {}).".format(
                    key_ckpt, key_model, matched_keys[key_ckpt]))
            raise ValueError(
                "Cannot match one checkpoint key to multiple keys in the model."
            )

        model_state_dict[key_model] = value_ckpt.clone()
        matched_keys[key_ckpt] = key_model
        if not global_cfg.MUTE_HEADER:
            logger.info(
                log_str_template.format(
                    key_model,
                    max_len_model,
                    original_keys[key_ckpt],
                    max_len_ckpt,
                    tuple(shape_in_model),
                ))

    # Print warnings about unmatched keys on both sides; sets make the
    # membership tests O(1) instead of scanning dict views.
    matched_model_keys = set(matched_keys.values())
    unmatched_model_keys = [k for k in model_keys if k not in matched_model_keys]
    if unmatched_model_keys:
        logger.info(get_missing_parameters_message(unmatched_model_keys))

    unmatched_ckpt_keys = [k for k in ckpt_keys if k not in matched_keys]
    if unmatched_ckpt_keys:
        logger.info(
            get_unexpected_parameters_message(
                original_keys[x] for x in unmatched_ckpt_keys))