def __init__(self, args):
    """Set up the hinge-compressible network and record its baseline cost.

    Args:
        args: indexable pair; ``args[0]`` is the option namespace (kept as
            ``self.args``) and ``args[1]`` the checkpoint object (kept as
            ``self.ckp``).

    For a hinge model outside test-only mode, weights are loaded strictly.
    Baseline FLOPs/params are measured before ``modify_network`` is applied.
    When layer balancing is enabled (hinge model, not test-only), the
    per-block regularization factors are computed and printed.
    """
    self.args = args[0]
    self.ckp = args[1]
    super(Hinge, self).__init__(self.args)

    # Training or loading phase of network compression.
    if 'hinge' in self.args.model.lower() and not self.args.test_only:
        self.load(strict=True)

    # Input resolution follows the training dataset; ImageNet size otherwise.
    for marker, dim in (('CIFAR', (3, 32, 32)), ('Tiny_ImageNet', (3, 64, 64))):
        if marker in self.args.data_train:
            self.input_dim = dim
            break
    else:
        self.input_dim = (3, 224, 224)

    # Complexity measured before modify_network() is applied.
    self.flops, self.params = get_model_complexity_info(
        self, self.input_dim, print_per_layer_stat=False)
    modify_network(self)

    balancing = ('hinge' in self.args.model.lower()
                 and not self.args.test_only
                 and self.args.layer_balancing)
    if balancing:
        # The layer-balancing regularizer is needed in both the training and
        # the loading phase.
        self.layer_reg = self.calc_regularization()
        for block_no, reg in enumerate(self.layer_reg, start=1):
            print('DenseBlock {:<2}: {:<2.4f}'.format(block_no, reg))
def calc_model_complexity(model):
    """Measure the compressed network's cost and print it against the baseline.

    ``model.get_model()`` must expose ``input_dim`` together with the baseline
    ``flops`` and ``params``. As a side effect, ``flops_compress`` and
    ``params_compress`` are stored on that object. Ratios are printed as
    percentages, FLOPs in G (1e9) and parameters in k (1e3).
    """
    net = model.get_model()
    net.flops_compress, net.params_compress = get_model_complexity_info(
        net, net.input_dim, print_per_layer_stat=False)

    giga = 10. ** 9
    kilo = 10. ** 3
    flops_pct = net.flops_compress / net.flops * 100
    params_pct = net.params_compress / net.params * 100
    print('FLOPs ratio {:.2f} = {:.4f} [G] / {:.4f} [G]; Parameter ratio {:.2f} = {:.2f} [k] / {:.2f} [k].\n\n'
          .format(flops_pct, net.flops_compress / giga, net.flops / giga,
                  params_pct, net.params_compress / kilo, net.params / kilo))
def __init__(self, args):
    """Wrap the base network for hinge compression and record its baseline cost.

    Args:
        args: indexable pair; ``args[0]`` is the option namespace (kept as
            ``self.args``) and ``args[1]`` the checkpoint object (kept as
            ``self.ckp``).

    In the training phase of a hinge model (not test-only, ``args.load``
    unset) weights are loaded non-strictly. Baseline FLOPs/params are measured
    before ``modify_network`` is applied.
    """
    self.args = args[0]
    self.ckp = args[1]
    super(Hinge, self).__init__(self.args)

    # Training phase of network compression.
    compressing = ('hinge' in self.args.model.lower()
                   and not self.args.test_only
                   and not self.args.load)
    if compressing:
        self.load(self.args, strict=False)

    dataset = self.args.data_train
    if 'CIFAR' in dataset:
        self.input_dim = (3, 32, 32)
    elif 'Tiny_ImageNet' in dataset:
        self.input_dim = (3, 64, 64)
    else:
        self.input_dim = (3, 224, 224)

    self.flops, self.params = get_model_complexity_info(
        self, self.input_dim, print_per_layer_stat=False)
    # Registered empty here; filled in elsewhere.
    self.register_buffer('running_grad_ratio', None)
    modify_network(self)
def __init__(self, args, ckp, converging):
    """Initialise the prunable network and record its baseline complexity.

    Args:
        args: option namespace, kept as ``self.args``.
        ckp: checkpoint object, kept as ``self.ckp``.
        converging: when truthy (or in test-only mode) ``self.load`` is
            skipped.

    Sets ``input_dim`` from the training dataset and stores the measured
    FLOPs/params as ``self.flops`` / ``self.params``.
    """
    self.args = args
    self.ckp = ckp
    super(Prune, self).__init__(self.args)

    # Training, or loading for searching.
    searching = not (self.args.test_only or converging)
    if searching:
        self.load(self.args, strict=True)

    dataset = self.args.data_train
    if 'CIFAR' in dataset:
        self.input_dim = (3, 32, 32)
    elif 'Tiny_ImageNet' in dataset:
        self.input_dim = (3, 64, 64)
    else:
        self.input_dim = (3, 224, 224)

    self.flops, self.params = get_model_complexity_info(
        self, self.input_dim, print_per_layer_stat=False)
def __init__(self, args):
    """Wrap the base network for hinge compression and record its baseline cost.

    Args:
        args: indexable pair; ``args[0]`` is the option namespace (kept as
            ``self.args``) and ``args[1]`` the checkpoint object (kept as
            ``self.ckp``).

    In the training phase of a hinge model (not test-only, ``args.load``
    unset) weights are loaded strictly. Baseline FLOPs/params are measured
    before ``modify_network`` is applied.
    """
    self.args = args[0]
    self.ckp = args[1]
    super(Hinge, self).__init__(self.args)

    model_name = self.args.model.lower()
    if 'hinge' in model_name and not self.args.test_only and not self.args.load:
        # Training phase of network compression.
        self.load(self.args, strict=True)

    dataset = self.args.data_train
    self.input_dim = ((3, 32, 32) if 'CIFAR' in dataset
                      else (3, 64, 64) if 'Tiny_ImageNet' in dataset
                      else (3, 224, 224))

    self.flops, self.params = get_model_complexity_info(
        self, self.input_dim, print_per_layer_stat=False)
    modify_network(self)
def __init__(self, args, ckp, converging):
    """Initialise the hinge network for searching and record its baseline cost.

    Args:
        args: option namespace, kept as ``self.args``.
        ckp: checkpoint object, kept as ``self.ckp``.
        converging: when truthy (or in test-only mode) ``self.load`` is
            skipped.

    Weights are loaded non-strictly. Baseline FLOPs/params are measured before
    ``modify_network`` is applied.
    """
    self.args = args
    self.ckp = ckp
    super(Hinge, self).__init__(self.args)

    # Training, or loading for searching.
    should_load = not self.args.test_only and not converging
    if should_load:
        self.load(self.args, strict=False)

    dims_by_dataset = {'CIFAR': (3, 32, 32), 'Tiny_ImageNet': (3, 64, 64)}
    # Checked in insertion order: CIFAR first, then Tiny_ImageNet.
    matched = [dim for tag, dim in dims_by_dataset.items()
               if tag in self.args.data_train]
    self.input_dim = matched[0] if matched else (3, 224, 224)

    self.flops, self.params = get_model_complexity_info(
        self, self.input_dim, print_per_layer_stat=False)
    # Registered empty here; filled in elsewhere.
    self.register_buffer('running_grad_ratio', None)
    modify_network(self)
def __init__(self, args, ckp, converging):
    """Initialise the hinge network for searching and record its baseline cost.

    Args:
        args: option namespace, kept as ``self.args``. Its ``pretrain``
            attribute is passed to ``self.load``.
        ckp: checkpoint object, kept as ``self.ckp``.
        converging: when truthy (or in test-only mode) weights are not loaded
            and the layer-balancing regularizer is not computed here.

    Sets ``input_dim`` from the training dataset (ImageNet size by default)
    and stores the measured FLOPs/params as ``self.flops`` / ``self.params``.
    """
    self.args = args
    self.ckp = ckp
    super(Hinge, self).__init__(self.args)
    # training or loading for searching
    if not self.args.test_only and not converging:
        self.load(self.args.pretrain, strict=True)
        # NOTE(review): confirm the layer-balancing computation is meant to
        # run only on this non-test, non-converging path.
        if self.args.layer_balancing:
            # need to calculate the layer-balancing regularizer during both training and loading phase.
            self.layer_reg = self.calc_regularization()
            # Print one regularization factor per block, numbered from 1.
            for l, reg in enumerate(self.layer_reg):
                print('Block {:<2}: regularization {:<2.4f}'.format(
                    l + 1, reg))
    # Input resolution follows the training dataset.
    if self.args.data_train.find('CIFAR') >= 0:
        self.input_dim = (3, 32, 32)
    elif self.args.data_train.find('Tiny_ImageNet') >= 0:
        self.input_dim = (3, 64, 64)
    else:
        self.input_dim = (3, 224, 224)
    self.flops, self.params = get_model_complexity_info(
        self, self.input_dim, print_per_layer_stat=False)
def algorithm(self, t, P, ratio):
    """Greedily prune groups of ``P`` filters until the FLOPs ratio reaches ``ratio``.

    Input is the pretrained network. For each of the first 11 ``BasicBlock``
    modules (later blocks are not pruned), the output filters of the block's
    first sub-module (``module[0]``) are split into consecutive groups of
    ``P``. Each group is scored as ``sum(norm_f * rank_f) / P`` over its
    filters. Here ``norm_f`` is the filter's L2 norm and ``rank_f`` its
    position when the layer's filters are sorted by ascending norm; smaller
    norm means easier to prune.

    Groups are then removed one at a time, in the order chosen by
    ``findMinGroup``. Removal continues until the compressed/original FLOPs
    ratio is no longer above ``ratio``, or no groups remain.

    Args:
        t: trainer-like object. ``t.test()`` runs before every pruning step,
            and ``t.model.get_model()`` holds the metric lists that are
            extended and passed to ``save``.
        P: number of consecutive filters per group.
        ratio: target ratio of compressed to original FLOPs.
    """
    # {'layer#group#P': importance}. Group numbers index the ORIGINAL
    # (unpruned) channel order of each layer.
    G_MWG = {}
    # Only the first 11 BasicBlocks are candidates; the last layer is not pruned.
    modules = [m for m in self.modules() if isinstance(m, BasicBlock)][0:11]
    # alive[layer][pos] = original channel index now sitting at output position
    # `pos` of that layer. Removing a group shifts every later channel of the
    # same layer, so group numbers must be translated through this mapping.
    alive = []
    for layer, module_cur in enumerate(modules):
        weight = module_cur[0].weight.data
        # Reshape to (input * k * k, output); the norm over dim 0 gives one
        # value per output filter.
        projection = weight.view(weight.shape[0], -1).t()
        Fl = torch.norm(projection, p=2, dim=0)
        # Filter indices sorted by ascending norm, e.g. tensor([0, 2, 3, 1]).
        FR = Fl.sort()[1]
        filter_num = FR.shape[0]
        factors = [0.0] * math.ceil(filter_num / P)
        norms = Fl.tolist()  # single host transfer instead of per-element indexing
        for rank, filter_idx in enumerate(FR.tolist()):
            factors[filter_idx // P] += norms[filter_idx] * rank / P
        for cluster, factor in enumerate(factors):
            G_MWG['{}#{}#{}'.format(layer, cluster, P)] = factor
        alive.append(list(range(filter_num)))

    current_ratio = 1.0
    self.flops_compress = self.flops
    while current_ratio > ratio and G_MWG:
        t.test()
        layer_num, group_num, key = findMinGroup(G_MWG)
        del G_MWG[key]
        print("prun layer :%d, group : %d" % (layer_num, group_num))

        # Translate the group's original channel ids into current positions.
        removed = range(group_num * P, group_num * P + P)
        keep = [pos for pos, ch in enumerate(alive[layer_num])
                if ch not in removed]
        if not keep:
            # Removing this group would leave the layer with no filters.
            print("skip layer :%d, group : %d (would empty the layer)"
                  % (layer_num, group_num))
            continue
        device = modules[layer_num][0].weight.device
        pindex1 = torch.tensor(keep, dtype=torch.long, device=device)
        self.compress_one_layer(layer_num, -1, pindex1=pindex1)
        alive[layer_num] = [alive[layer_num][pos] for pos in keep]

        self.flops_compress, self.params_compress = get_model_complexity_info(
            self, self.input_dim, print_per_layer_stat=False)
        print(
            'FLOPs ratio {:.2f} = {:.4f} [G] / {:.4f} [G]; Parameter ratio {:.2f} = {:.2f} [k] / {:.2f} [k].'
            .format(self.flops_compress / self.flops * 100,
                    self.flops_compress / 10.**9, self.flops / 10.**9,
                    self.params_compress / self.params * 100,
                    self.params_compress / 10.**3, self.params / 10.**3))
        current_ratio = self.flops_compress / self.flops

        net = t.model.get_model()
        net.current_ratio_list.append("{:.4f}".format(current_ratio))
        print("current_ratio_list : ")
        print(net.current_ratio_list)
        net.parameter_ratio_list.append("{:.4f}".format(
            self.params_compress / self.params))
        print("parameter ratio list : ")
        print(net.parameter_ratio_list)
        # NOTE(review): confirm saving is intended after every pruning step
        # (as here) rather than once after the loop.
        save(net.current_ratio_list, net.timer_test_list, net.sum_list,
             net.top1_err_list, net.parameter_ratio_list)

    if current_ratio > ratio:
        print('No prunable filter groups left; stopped at FLOPs ratio '
              '{:.4f} (target {:.4f}).'.format(current_ratio, ratio))