예제 #1
0
    def initialize(self,
                   pretrain_dict: Dict[str, nn.Parameter],
                   verbosity: int = 2,
                   **kwargs):
        """Load pretrained weights, borrowing cls-branch weights for unmatched params.

        Entries that ``self.find_match_miss`` matches directly are used as-is.
        Each missed parameter name is then considered: if it does not contain
        ``fc8`` and does contain an ``fc<d>_<d>`` token, its weight is taken
        from the stripped pretrained dict. The lookup key is the stripped
        model name with that token's trailing digit replaced by ``0``
        (e.g. ``fc6_2`` -> ``fc6_0``). All other missed names are left out of
        the update.

        Args:
            pretrain_dict: Parameter name -> pretrained parameter.
            verbosity: Passed through to ``self.update_model_state_dict``.
            **kwargs: Accepted for interface compatibility; unused.

        Note:
            The lookup into the stripped pretrained dict is unguarded, so a
            missing ``fc<d>_0`` counterpart propagates the lookup error.
        """
        matched, missed = self.find_match_miss(pretrain_dict)
        weights = dict(matched)

        if len(missed) > 0:
            print("Initial seg head from cls head")
            stripped_pretrain = strip_named_params(pretrain_dict)

            def cls_branch_name(model_name):
                # Stripped model name with every fc<d>_<k> rewritten to fc<d>_0.
                return re.sub(r'(fc\d_)\d', lambda m: m.group(1) + '0',
                              strip_param_name(model_name))

            weights.update({
                name: stripped_pretrain[cls_branch_name(name)]
                for name in missed
                # fc8 is never borrowed; only fc<d>_<d> layers are eligible.
                if 'fc8' not in name and re.search(r'fc\d_\d', name) is not None
            })

        self.update_model_state_dict(weights, verbosity)
예제 #2
0
    def initialize(self,
                   pretrain_dict: Dict[str, nn.Parameter],
                   verbosity: int = 1,
                   **kwargs):
        """Copy pretrained weights whose stripped names match this model's params.

        Every key of ``self.state_dict()`` is reduced with ``strip_param_name``.
        The key is kept only when that stripped name also appears in
        ``strip_named_params(pretrain_dict)``. The resulting mapping is handed
        to ``update_model_state_dict``.

        Args:
            pretrain_dict: Parameter name -> pretrained parameter.
            verbosity: Passed through to ``update_model_state_dict``.
            **kwargs: Accepted for interface compatibility; unused.
        """
        stripped_pretrain = strip_named_params(pretrain_dict)
        name_pairs = ((name, strip_param_name(name)) for name in self.state_dict())
        weights = {
            name: stripped_pretrain[stripped]
            for name, stripped in name_pairs
            if stripped in stripped_pretrain
        }
        update_model_state_dict(self, weights, verbosity)
예제 #3
0
    def initialize(self,
                   pretrain_dict: Dict[str, nn.Parameter],
                   verbosity: int = 1,
                   **kwargs):
        """Load pretrained weights, falling back to the cls branch for unmatched params.

        For every key of ``self.state_dict()``:

        * if its stripped name exists in the stripped pretrained dict, that
          weight is used directly;
        * otherwise, if the stripped name contains ``fc8``, it is skipped;
        * otherwise the weight is borrowed from the cls branch. Every
          ``fc<d>_<k>`` in the stripped name is rewritten to ``fc<d>_0``, and
          the result is passed as a pattern to ``named_param_match_pattern``
          (full match, single match). The first returned value is used.

        Args:
            pretrain_dict: Parameter name -> pretrained parameter.
            verbosity: Passed through to ``update_model_state_dict``.
            **kwargs: Accepted for interface compatibility; unused.

        Raises:
            KeyError: If no pretrained parameter matches the cls-branch
                pattern derived for an unmatched, non-fc8 model parameter.
        """
        strip_pretrain_dict = strip_named_params(pretrain_dict)
        weight_dict = {}

        for model_param_name in self.state_dict().keys():
            strip_model_param_name = strip_param_name(model_param_name)
            if strip_model_param_name in strip_pretrain_dict:
                weight_dict[model_param_name] = strip_pretrain_dict[strip_model_param_name]
                continue
            if 'fc8' in strip_model_param_name:
                # fc8 has no cls-branch counterpart; keep its current init.
                continue

            # Using param from cls branch with same layer_ID.
            print(f"Init {strip_model_param_name} from cls head")
            pattern = re.sub(r'fc\d_\d', lambda match: match.group(0)[:-1] + '0',
                             strip_model_param_name)
            candidates = list(named_param_match_pattern(pattern,
                                                        strip_pretrain_dict,
                                                        full_match=True,
                                                        multi_match=False).values())
            if not candidates:
                # Previously an opaque IndexError from ``[0]``; report what was
                # missing, as a KeyError like the sibling initializers raise.
                raise KeyError(
                    f"No pretrained parameter matches {pattern!r} "
                    f"(needed to initialise {model_param_name!r})")
            weight_dict[model_param_name] = candidates[0]
        update_model_state_dict(self, weight_dict, verbosity)
예제 #4
0
    def initialize_from_deeplab_v1(self,
                                   pretrain_dict: Dict[str, nn.Parameter],
                                   verbosity: int = 1):
        """Initialise from DeepLab-v1 weights, collapsing branch suffixes on fc layers.

        Entries that ``self.find_match_miss`` matches directly are used as-is.
        Each missed parameter whose stripped name does not contain ``fc8`` is
        looked up in the stripped pretrained dict. The lookup uses its stripped
        name with every ``fc<d>_<k>`` reduced to ``fc<d>`` (e.g.
        ``fc6_2`` -> ``fc6``). Names without such a token are looked up
        unchanged.

        Args:
            pretrain_dict: Parameter name -> pretrained parameter.
            verbosity: Passed through to ``self.update_model_state_dict``.
        """
        matched, missed = self.find_match_miss(pretrain_dict)
        weights = dict(matched)

        if len(missed) > 0:
            print("Init cls head from deeplab_v1")
            stripped_pretrain = strip_named_params(pretrain_dict)
            # Lazily pair each missed name with its stripped form so each
            # name is stripped exactly once, in order.
            name_pairs = ((name, strip_param_name(name)) for name in missed)
            weights.update({
                name: stripped_pretrain[
                    re.sub(r'(fc\d)_\d', lambda m: m.group(1), stripped)]
                for name, stripped in name_pairs
                if 'fc8' not in stripped
            })

        self.update_model_state_dict(weights, verbosity)
예제 #5
0
    def initialize_from_deeplab_v2(self,
                                   pretrain_dict: Dict[str, nn.Parameter],
                                   verbosity: int = 1):
        """Initialise from DeepLab-v2 weights, mapping fc branches onto branch ``1``.

        For every key of ``self.state_dict()``, the stripped name is looked up
        in the stripped pretrained dict. When there is no exact hit, params
        whose (unstripped) name contains ``fc8`` are skipped. All others are
        looked up with every ``fc<d>_<k>`` rewritten to ``fc<d>_1``.

        Args:
            pretrain_dict: Parameter name -> pretrained parameter.
            verbosity: Passed through to ``update_model_state_dict``.
        """
        stripped_pretrain = strip_named_params(pretrain_dict)
        weights = {}

        for name in self.state_dict():
            stripped = strip_param_name(name)
            exact_hit = stripped in stripped_pretrain
            if not exact_hit and 'fc8' in name:
                # fc8 is never initialised from a fuzzy match.
                continue
            source_name = stripped if exact_hit else re.sub(
                r'fc\d_\d', lambda m: m.group(0)[:-1] + '1', stripped)
            weights[name] = stripped_pretrain[source_name]

        update_model_state_dict(self, weights, verbosity)