def merge_bn_in(self, bn, affine_only):
    """Fold the batch-norm layer ``bn`` into this layer's weight and bias.

    The per-output-channel factors come from ``mul_add_from_bn``.
    NOTE(review): this assumes they satisfy ``bn(y) == y * mul_factor + add_factor``
    for this layer's output ``y`` (only the affine part when ``affine_only``).

    Effect of the merge:
    - The weight is scaled by ``mul_factor``.
    - An existing bias becomes ``bias * mul_factor + add_factor``.
    - A missing bias is created as ``add_factor``.

    Args:
        bn: Batch-norm module. This method reads ``running_mean``,
            ``running_var``, ``eps`` and ``affine`` from it, and reads
            ``weight``/``bias`` when it is affine.
        affine_only: Forwarded to ``mul_add_from_bn``. It requires ``bn`` to
            be affine.

    Raises:
        ValueError: If ``affine_only`` is set but ``bn.affine`` is false.
    """
    if affine_only and not bn.affine:
        raise ValueError(
            "Affine-only merging requires BN to have affine scaling enabled.")

    if bn.affine:
        bn_weight = bn.weight.data.clone()
        bn_bias = bn.bias.data.clone()
    else:
        # A non-affine BN has no weight/bias tensors (they are None); it is
        # equivalent to gamma = 1, beta = 0.
        bn_weight = torch.ones_like(bn.running_mean)
        bn_bias = torch.zeros_like(bn.running_mean)

    mul_factor, add_factor = mul_add_from_bn(
        bn_mean=bn.running_mean,
        bn_var=bn.running_var,
        bn_eps=bn.eps,
        bn_weight=bn_weight,
        bn_bias=bn_bias,
        affine_only=affine_only)

    self.weight.data *= mul_factor.view(
        self.per_output_channel_broadcastable_shape)
    if self.bias is not None:
        # BN(W x + b) = mul * W x + (mul * b + add): the existing bias must be
        # scaled before the BN shift is added.
        self.bias.data *= mul_factor.view_as(self.bias.data)
        self.bias.data += add_factor
    else:
        self.bias = Parameter(add_factor)
def fuse_bn(self, state_dict, prefix, *, bn_eps=1e-03):
    """Fold batch-norm entries of ``state_dict`` into their preceding convs.

    Folding rule, applied to every key under ``prefix`` whose last component
    is ``running_mean``:
    - The enclosing BN entry ``<path>.<n>`` is folded into the conv entry
      ``<path>.<n-1>.conv``.
    - The conv weight is scaled per output channel by ``mul_factor``.
    - The conv bias becomes ``bias * mul_factor + add_factor``, or just
      ``add_factor`` when no bias is stored.
    - The consumed BN keys (``weight``, ``bias``, ``running_mean``,
      ``running_var``, ``num_batches_tracked``) are removed afterwards.

    ``state_dict`` is modified in place. If no ``running_mean`` key exists
    under ``prefix``, nothing is changed.

    BN entries are paired, in ``state_dict`` iteration order, with the modules
    in ``self.conv_module_to_merge``.

    Args:
        state_dict: Mapping of parameter names to tensors. It is updated in
            place.
        prefix: Only keys starting with this string are considered.
        bn_eps: Epsilon passed to ``mul_add_from_bn``. The default is the
            previously hard-coded ``1e-3``. NOTE(review): this presumably has
            to match the ``eps`` of the BN layers that produced
            ``state_dict``.

    Raises:
        AssertionError: If the number of folded BN layers differs from
            ``len(self.conv_module_to_merge)``.
    """
    keys_to_check = [k for k in state_dict.keys() if k.startswith(prefix)]
    # Nothing to fold under this prefix (e.g. the checkpoint is already fused).
    if not any(k.split('.')[-1] == 'running_mean' for k in keys_to_check):
        return

    keys_to_delete = set()
    index = 0
    for name in keys_to_check:
        *prefix_long, leaf = name.split('.')
        if leaf != 'running_mean':
            continue

        bn_prefix = '.'.join(prefix_long)
        # The BN stored at position n is folded into the conv at position n-1.
        module_number = int(prefix_long[-1])
        conv_name = '.'.join(
            prefix_long[:-1] + [str(module_number - 1), 'conv'])

        conv_mod = self.conv_module_to_merge[index]
        index += 1

        bn_keys = {
            field: '.'.join([bn_prefix, field])
            for field in ('weight', 'bias', 'running_mean', 'running_var',
                          'num_batches_tracked')
        }
        # BN statistics are dropped once they have been folded in.
        keys_to_delete.update(bn_keys.values())

        mul_factor, add_factor = mul_add_from_bn(
            bn_mean=state_dict[bn_keys['running_mean']],
            bn_var=state_dict[bn_keys['running_var']],
            bn_eps=bn_eps,
            bn_weight=state_dict[bn_keys['weight']],
            bn_bias=state_dict[bn_keys['bias']],
            affine_only=False)

        if isinstance(conv_mod, MaskedConv1d):
            conv_mod = conv_mod.conv
        mul_shape = conv_mod.per_output_channel_broadcastable_shape
        conv_weight_key = conv_name + '.weight'
        conv_bias_key = conv_name + '.bias'

        state_dict[conv_weight_key] = (
            state_dict[conv_weight_key] * mul_factor.view(mul_shape))

        # BN(W x + b) = mul * W x + (mul * b + add): a stored conv bias must be
        # scaled before the BN shift is added. Built out of place so the
        # caller's original tensor is not mutated.
        stored_bias = state_dict.get(conv_bias_key)
        if stored_bias is not None:
            fused_bias = stored_bias * mul_factor.view_as(stored_bias) + add_factor
        else:
            fused_bias = add_factor

        if conv_mod.bias is None:
            # Create the bias on the same device as the conv weight, so the
            # module stays consistent wherever it currently lives.
            fused_bias = fused_bias.to(conv_mod.weight.device)
            conv_mod.bias = nn.Parameter(fused_bias)
        # Always store the bias so loading does not report a missing key.
        state_dict[conv_bias_key] = fused_bias

    for key in keys_to_delete:
        state_dict.pop(key, None)

    assert len(self.conv_module_to_merge) == index, (
        'Folded {} batch-norm layers but conv_module_to_merge has {} entries'
        .format(index, len(self.conv_module_to_merge)))