예제 #1
0
 def merge_bn_in(self, bn, affine_only):
     """Fold the batch-norm module ``bn`` into this layer's weight and bias.

     The multiplicative and additive factors are obtained from
     ``mul_add_from_bn``. The weight is scaled in place, and the additive
     factor is either accumulated into the existing bias or installed as a
     new bias parameter when the layer has none.

     Args:
         bn: Batch-norm module providing ``running_mean``, ``running_var``,
             ``eps``, ``weight``, ``bias`` and ``affine``.
         affine_only: Forwarded unchanged to ``mul_add_from_bn``; when true,
             ``bn`` must have ``affine`` enabled.

     Raises:
         Exception: If ``affine_only`` is requested but ``bn.affine`` is false.
     """
     # Guard clause: affine-only merging cannot work without affine scaling.
     if affine_only and not bn.affine:
         raise Exception(
             "Affine-only merging requires BN to have affine scaling enabled."
         )

     # Clone the affine tensors so the helper never sees bn's own storage.
     scale, shift = mul_add_from_bn(
         bn_mean=bn.running_mean,
         bn_var=bn.running_var,
         bn_eps=bn.eps,
         bn_weight=bn.weight.data.clone(),
         bn_bias=bn.bias.data.clone(),
         affine_only=affine_only,
     )

     # Reshape the multiplicative factor so it broadcasts over the weight.
     broadcast_shape = self.per_output_channel_broadcastable_shape
     self.weight.data *= scale.view(broadcast_shape)

     if self.bias is None:
         # No bias yet: the additive factor becomes the bias.
         self.bias = Parameter(shift)
     else:
         self.bias.data += shift
예제 #2
0
    def fuse_bn(self, state_dict, prefix, bn_eps=1e-03):
        """Fold batch-norm entries of ``state_dict`` into the preceding conv entries.

        Every key that starts with ``prefix`` and whose last dotted component
        is ``running_mean`` identifies one batch-norm module. Its parent path
        must end in an integer module index; the matching conv is looked up
        at index - 1 under the sub-key ``conv``. The conv weight in
        ``state_dict`` is multiplied by the factor returned by
        ``mul_add_from_bn``, the additive factor is merged into the conv
        bias, and the batch-norm entries are removed from ``state_dict``.

        Conv modules are consumed from ``self.conv_module_to_merge`` in the
        order the batch-norm keys are encountered.

        Args:
            state_dict: Mapping of parameter names to tensors; modified in place.
            prefix: Only keys starting with this string are inspected.
            bn_eps: Epsilon forwarded to ``mul_add_from_bn``. The state dict
                does not carry it, so it must match the batch-norm modules
                that produced the checkpoint. Defaults to ``1e-03``.

        Raises:
            KeyError: If a batch-norm ``weight``/``bias``/``running_var`` entry
                or the conv ``weight`` entry is missing from ``state_dict``.
            AssertionError: If the number of fused batch-norm modules differs
                from ``len(self.conv_module_to_merge)``.
        """
        # Snapshot the batch-norm marker keys first: state_dict gains conv
        # bias entries inside the loop, so it must not be iterated live.
        running_mean_keys = [
            key for key in state_dict
            if key.startswith(prefix) and key.split('.')[-1] == 'running_mean'
        ]

        keys_to_delete = set()
        index = 0
        for name in running_mean_keys:
            prefix_long = name.split('.')[:-1]
            bn_prefix = '.'.join(prefix_long)
            module_number = int(prefix_long[-1])
            # The conv sits one module index before its batch norm.
            conv_name = '.'.join(
                prefix_long[:-1] + [str(module_number - 1), 'conv'])

            conv_mod = self.conv_module_to_merge[index]
            index += 1

            bn_weight_key = bn_prefix + '.weight'
            bn_bias_key = bn_prefix + '.bias'
            bn_running_mean_key = bn_prefix + '.running_mean'
            bn_running_var_key = bn_prefix + '.running_var'
            bn_num_batches_tracked_key = bn_prefix + '.num_batches_tracked'
            # Statistics are dropped once they have been folded into the conv.
            keys_to_delete.update((
                bn_bias_key,
                bn_weight_key,
                bn_running_mean_key,
                bn_running_var_key,
                bn_num_batches_tracked_key,
            ))

            mul_factor, add_factor = mul_add_from_bn(
                bn_mean=state_dict[bn_running_mean_key],
                bn_var=state_dict[bn_running_var_key],
                bn_eps=bn_eps,
                bn_weight=state_dict[bn_weight_key],
                bn_bias=state_dict[bn_bias_key],
                affine_only=False)

            # MaskedConv1d wraps the actual conv in its ``conv`` attribute.
            if isinstance(conv_mod, MaskedConv1d):
                conv_mod = conv_mod.conv
            mul_shape = conv_mod.per_output_channel_broadcastable_shape
            conv_weight_key = conv_name + '.weight'
            conv_bias_key = conv_name + '.bias'
            state_dict[conv_weight_key] = (
                state_dict[conv_weight_key] * mul_factor.view(mul_shape))

            if conv_mod.bias is None:
                # NOTE(review): the factor is moved to CUDA whenever CUDA is
                # available, regardless of where conv_mod lives -- confirm
                # this is intended for CPU-only use on GPU machines.
                if torch.cuda.is_available():
                    add_factor = add_factor.cuda()
                conv_mod.bias = nn.Parameter(add_factor)
                # Add it to the dict anyway to avoid a missing-key error.
                state_dict[conv_bias_key] = add_factor
            elif conv_bias_key in state_dict:
                state_dict[conv_bias_key] += add_factor
            else:
                state_dict[conv_bias_key] = add_factor

        # Some of the collected keys (e.g. num_batches_tracked) may be absent.
        for key in keys_to_delete:
            state_dict.pop(key, None)

        assert len(self.conv_module_to_merge) == index, (
            "number of fused batch-norm modules does not match "
            "conv_module_to_merge")