import re
from typing import Dict

from torch import nn


def initialize(self, pretrain_dict: Dict[str, nn.Parameter], verbosity: int = 2, **kwargs):
    weight_dict = dict()
    matched_dict, missed_keys = self.find_match_miss(pretrain_dict)
    weight_dict.update(matched_dict)
    if len(missed_keys) > 0:
        print("Init seg head from cls head")
        strip_pretrain_dict = strip_named_params(pretrain_dict)
        for model_param_name in missed_keys:
            if 'fc8' in model_param_name:  # Don't init fc8
                continue
            if re.search(r'fc\d_\d', model_param_name) is None:
                continue
            # Fuzzy match: map fcX_Y -> fcX_0 so every seg-head branch borrows
            # the weight of the cls-head branch (suffix 0) in the pretrain dict
            strip_model_param_name = strip_param_name(model_param_name)
            fuzzy_strip_model_param_name = re.sub(
                r'(fc\d_)\d',
                lambda match: match.group(1) + '0',
                strip_model_param_name)
            weight_dict[model_param_name] = strip_pretrain_dict[
                fuzzy_strip_model_param_name]
    self.update_model_state_dict(weight_dict, verbosity)
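# find_match_miss is defined elsewhere in the repo. The sketch below is only
# the contract ASSUMED from the call sites above and in
# initialize_from_deeplab_v1: split the model's parameters into those that
# have an exact (stripped-name) counterpart in pretrain_dict and those that
# do not. The real implementation may differ.
def find_match_miss(self, pretrain_dict: Dict[str, nn.Parameter]):
    strip_pretrain_dict = strip_named_params(pretrain_dict)
    matched_dict, missed_keys = {}, []
    for model_param_name in self.state_dict().keys():
        strip_name = strip_param_name(model_param_name)
        if strip_name in strip_pretrain_dict:
            matched_dict[model_param_name] = strip_pretrain_dict[strip_name]
        else:
            missed_keys.append(model_param_name)
    return matched_dict, missed_keys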
def initialize(self, pretrain_dict: Dict[str, nn.Parameter], verbosity: int = 1, **kwargs):
    # Exact (stripped) name match only; unmatched params keep their default init
    strip_pretrain_dict = strip_named_params(pretrain_dict)
    weight_dict = {}
    for model_param_name in self.state_dict().keys():
        strip_model_param_name = strip_param_name(model_param_name)
        if strip_model_param_name in strip_pretrain_dict:
            weight_dict[model_param_name] = strip_pretrain_dict[
                strip_model_param_name]
    update_model_state_dict(self, weight_dict, verbosity)
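# strip_named_params / strip_param_name are repo helpers that are not shown in
# this excerpt. The sketches below ASSUME that "stripping" removes a wrapper
# prefix such as the 'module.' that nn.DataParallel prepends to parameter
# names; treat them as illustrations of the interface, not the actual code.
def strip_param_name(param_name: str) -> str:
    """e.g. 'module.fc6_0.weight' -> 'fc6_0.weight' (assumed behavior)."""
    prefix = 'module.'
    return param_name[len(prefix):] if param_name.startswith(prefix) else param_name


def strip_named_params(named_params: Dict[str, nn.Parameter]) -> Dict[str, nn.Parameter]:
    """Apply strip_param_name to every key of a (state) dict."""
    return {strip_param_name(name): param for name, param in named_params.items()}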
def initialize(self, pretrain_dict: Dict[str, nn.Parameter], verbosity: int = 1, **kwargs):
    strip_pretrain_dict = strip_named_params(pretrain_dict)
    weight_dict = {}
    for model_param_name in self.state_dict().keys():
        strip_model_param_name = strip_param_name(model_param_name)
        if strip_model_param_name in strip_pretrain_dict:
            weight_dict[model_param_name] = strip_pretrain_dict[strip_model_param_name]
        elif 'fc8' in strip_model_param_name:  # Don't init fc8
            continue
        else:
            # Use the param of the cls branch (suffix 0) with the same layer ID
            print(f"Init {strip_model_param_name} from cls head")
            pattern = re.sub(r'fc\d_\d',
                             lambda match: match.group(0)[:-1] + '0',
                             strip_model_param_name)
            weight = list(named_param_match_pattern(pattern, strip_pretrain_dict,
                                                    full_match=True,
                                                    multi_match=False).values())[0]
            weight_dict[model_param_name] = weight
    update_model_state_dict(self, weight_dict, verbosity)
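# named_param_match_pattern is another repo helper; this sketch is a GUESS at
# its contract from the call above: return the entries of a named-parameter
# dict whose keys match a regex, optionally requiring a full match and, with
# multi_match=False, exactly one hit.
def named_param_match_pattern(pattern: str,
                              named_params: Dict[str, nn.Parameter],
                              full_match: bool = False,
                              multi_match: bool = True) -> Dict[str, nn.Parameter]:
    match_fn = re.fullmatch if full_match else re.search
    matched = {name: param for name, param in named_params.items()
               if match_fn(pattern, name) is not None}
    if not multi_match and len(matched) != 1:  # assumed single-match guard
        raise KeyError(f"Expected exactly one match for {pattern!r}, got {len(matched)}")
    return matched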
def initialize_from_deeplab_v1(self, pretrain_dict: Dict[str, nn.Parameter], verbosity: int = 1):
    update_dict = {}
    matched_dict, missed_keys = self.find_match_miss(pretrain_dict)
    update_dict.update(matched_dict)
    if len(missed_keys) > 0:
        print("Init cls head from deeplab_v1")
        strip_pretrain_dict = strip_named_params(pretrain_dict)
        for model_param_name in missed_keys:
            strip_model_param_name = strip_param_name(model_param_name)
            if 'fc8' in strip_model_param_name:  # Don't init fc8
                continue
            # Drop the branch suffix: fcX_Y -> fcX (deeplab_v1 has a single head)
            param_name_in_pretrain = re.sub(
                r'(fc\d)_\d',
                lambda match: match.group(1),
                strip_model_param_name)
            update_dict[model_param_name] = strip_pretrain_dict[
                param_name_in_pretrain]
    self.update_model_state_dict(update_dict, verbosity)
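# A quick illustration of the two rename rules (the parameter names here are
# hypothetical examples, not taken from a real checkpoint):
#
#   >>> import re
#   >>> # deeplab_v1: drop the branch suffix entirely
#   >>> re.sub(r'(fc\d)_\d', lambda m: m.group(1), 'fc6_2.weight')
#   'fc6.weight'
#   >>> # deeplab_v2: redirect every branch to the pretrain branch with suffix 1
#   >>> re.sub(r'fc\d_\d', lambda m: m.group(0)[:-1] + '1', 'fc6_2.weight')
#   'fc6_1.weight'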
def initialize_from_deeplab_v2(self, pretrain_dict: Dict[str, nn.Parameter], verbosity: int = 1):
    strip_pretrain_dict = strip_named_params(pretrain_dict)
    weight_dict = {}
    for model_param_name in self.state_dict().keys():
        strip_model_param_name = strip_param_name(model_param_name)
        if strip_model_param_name in strip_pretrain_dict:
            # Exact match
            weight_dict[model_param_name] = strip_pretrain_dict[
                strip_model_param_name]
        else:
            if 'fc8' in model_param_name:  # Don't init fc8
                continue
            # Fuzzy match: map fcX_Y -> fcX_1 in the pretrained deeplab_v2 dict
            fuzzy_strip_model_param_name = re.sub(
                r'fc\d_\d',
                lambda match: match.group(0)[:-1] + '1',
                strip_model_param_name)
            weight_dict[model_param_name] = strip_pretrain_dict[
                fuzzy_strip_model_param_name]
    update_model_state_dict(self, weight_dict, verbosity)
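# update_model_state_dict is also defined elsewhere in the repo (the
# self.update_model_state_dict methods above presumably wrap the same logic).
# A minimal sketch of the ASSUMED behavior: merge the collected weights into
# the model's state and, depending on verbosity, report what was loaded.
def update_model_state_dict(model: nn.Module,
                            weight_dict: Dict[str, nn.Parameter],
                            verbosity: int = 1) -> None:
    state_dict = model.state_dict()
    state_dict.update(weight_dict)
    model.load_state_dict(state_dict)
    if verbosity > 0:
        print(f"Loaded {len(weight_dict)}/{len(state_dict)} params from pretrain")

# Hypothetical usage (checkpoint path and layout are assumptions):
#
#     pretrain_dict = torch.load('deeplab_v2_cls.pth', map_location='cpu')
#     model.initialize_from_deeplab_v2(pretrain_dict, verbosity=1)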