# test_prune.py (forked from Lam1360/YOLOv3-model-pruning)
from models import *
from utils.utils import *
import torch
import numpy as np
from copy import deepcopy
from test import evaluate
from terminaltables import AsciiTable  # AsciiTable is the simplest table style: borders are drawn with +, | and -
import time
from utils.prune_utils import *  # all the pruning helpers live here


class opt():
    model_def = "config/yolov3-hand.cfg"
    data_config = "config/oxfordhand.data"  # classes, train/valid set paths, class-name file, etc.
    model = 'checkpoints/yolov3_ckpt.pth'   # checkpoint produced by sparsity training
#%%
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = Darknet(opt.model_def).to(device)
model.load_state_dict(torch.load(opt.model))  # load the sparsity-trained weights

# Parse the data config file
data_config = parse_data_config(opt.data_config)
valid_path = data_config["valid"]                 # validation-set path
class_names = load_classes(data_config["names"])  # class-id -> class-name mapping

# evaluate returns per-class metrics; index 2 holds the per-class AP, so metric[2].mean() below is the mAP
eval_model = lambda model: evaluate(model, path=valid_path, iou_thres=0.5, conf_thres=0.01,
                                    nms_thres=0.5, img_size=model.img_size, batch_size=8)
obtain_num_parameters = lambda model: sum(param.nelement() for param in model.parameters())

origin_model_metric = eval_model(model)            # metrics of the sparsity-trained model (not yet pruned)
origin_nparameters = obtain_num_parameters(model)  # parameter count of the sparsity-trained model
# Returns the indices of CBL (Conv-BN-LeakyReLU) blocks, of plain Conv layers,
# and of the layers that are candidates for pruning
CBL_idx, Conv_idx, prune_idx = parse_module_defs(model.module_defs)
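# Note: in this repo parse_module_defs (utils/prune_utils.py) keeps prune_idx a subset of
# CBL_idx; layers feeding shortcut (residual) adds are excluded, since pruning them would
# break the channel-count match required by the elementwise add.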
# Gather the BN-layer weights (the gamma parameters) of all prunable CBL blocks;
# pruning decisions are based on these values
bn_weights = gather_bn_weights(model.module_list, prune_idx)

# Sort the gamma values in ascending order
sorted_bn = torch.sort(bn_weights)[0]
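# For reference, a minimal sketch of what gather_bn_weights plausibly does (a hypothetical
# re-implementation, not the one imported from utils.prune_utils): concatenate the absolute
# BN gamma values of every prunable layer into one flat tensor.
def _gather_bn_weights_sketch(module_list, prune_idx):
    # module_list[idx][1] is the BatchNorm2d inside the idx-th CBL block
    return torch.cat([module_list[idx][1].weight.data.abs().clone() for idx in prune_idx])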
# Upper bound on the threshold so that no layer loses all of its channels:
# take the maximum |gamma| of each prunable BN layer, then the minimum over those maxima.
# Any threshold below this bound leaves at least one channel in even the most heavily pruned layer.
highest_thre = []
for idx in prune_idx:
    highest_thre.append(model.module_list[idx][1].weight.data.abs().max().item())
highest_thre = min(highest_thre)

# Turn the position of highest_thre in the sorted gammas into a percentage: the pruning limit
percent_limit = (sorted_bn == highest_thre).nonzero().item() / len(bn_weights)

print(f'Threshold should be less than {highest_thre:.4f}.')      # the gamma threshold is computed automatically
print(f'The corresponding prune ratio is {percent_limit:.3f}.')  # the prune ratio you choose must not exceed this
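# Toy example of the limit: if the per-layer maxima of |gamma| are [0.9, 0.5, 0.7],
# then highest_thre = 0.5; any global threshold above 0.5 would zero out every channel
# of the second layer, so the prune ratio must stay below percent_limit.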
#%%
# Start pruning
def prune_and_eval(model, sorted_bn, percent=0.0):
    # See https://blog.csdn.net/sodalife/article/details/89461030 for a walkthrough
    model_copy = deepcopy(model)
    thre_index = int(len(sorted_bn) * percent)
    # Pruning threshold: a channel whose gamma falls below thre is pruned,
    # since a small gamma marks the channel as unimportant
    thre = sorted_bn[thre_index]
    print(f'Channels with Gamma value less than {thre:.4f} are pruned!')
    remain_num = 0
    for idx in prune_idx:
        bn_module = model_copy.module_list[idx][1]
        # Per-channel keep/prune mask (1 = keep)
        mask = obtain_bn_mask(bn_module, thre)
        # Count the surviving channels
        remain_num += int(mask.sum())
        # Multiplying the BN weights (gamma) by the mask is equivalent to pruning
        bn_module.weight.data.mul_(mask)
    # mAP of the masked model
    mAP = eval_model(model_copy)[2].mean()
    print(f'Number of channels has been reduced from {len(sorted_bn)} to {remain_num}')
    print(f'Prune ratio: {1-remain_num/len(sorted_bn):.3f}')
    print(f'mAP of the pruned model is {mAP:.4f}')
    # Return the pruning threshold
    return thre
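# For reference, obtain_bn_mask (imported from utils.prune_utils) presumably thresholds
# |gamma|; a minimal sketch under that assumption:
def _obtain_bn_mask_sketch(bn_module, thre):
    # 1.0 where the channel survives, 0.0 where it is pruned
    return bn_module.weight.data.abs().ge(thre).float()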
# Prune 85% of the prunable channels (must stay below the percent_limit printed above)
percent = 0.85
# Compute the pruning threshold, zero the masked BN weights on a copy, and report its mAP
threshold = prune_and_eval(model, sorted_bn, percent)
#%%
# Compute the per-channel keep/prune state of every CBL block's BN layer
def obtain_filters_mask(model, thre, CBL_idx, prune_idx):
    pruned = 0
    total = 0          # total channel count over all CBL blocks
    num_filters = []   # channels kept per CBL block after pruning
    filters_mask = []  # pruning mask per CBL block
    for idx in CBL_idx:
        bn_module = model.module_list[idx][1]
        # Only prune layers that are in the prune list
        if idx in prune_idx:
            mask = obtain_bn_mask(bn_module, thre).cpu().numpy()
            # Number of channels kept
            remain = int(mask.sum())
            # Number of channels pruned
            pruned = pruned + mask.shape[0] - remain
            if remain == 0:
                raise Exception("Channels would be all pruned!")
            print(f'layer index: {idx:>3d} \t total channel: {mask.shape[0]:>4d} \t '
                  f'remaining channel: {remain:>4d}')
        else:
            # Not a pruning candidate: keep every channel
            mask = np.ones(bn_module.weight.data.shape)
            remain = mask.shape[0]
        total += mask.shape[0]
        num_filters.append(remain)
        filters_mask.append(mask.copy())
    prune_ratio = pruned / total
    print(f'Prune channels: {pruned}\tPrune ratio: {prune_ratio:.3f}')
    # Returns the kept-channel counts and the channel masks, one entry per CBL block
    return num_filters, filters_mask

num_filters, filters_mask = obtain_filters_mask(model, threshold, CBL_idx, prune_idx)
#%%
# Map layer index -> pruning mask for each CBL block
CBLidx2mask = {idx: mask for idx, mask in zip(CBL_idx, filters_mask)}

# Build the pruned model (same tensor shapes, masked weights)
pruned_model = prune_model_keep_size(model, prune_idx, CBL_idx, CBLidx2mask)

# Evaluate the pruned model
eval_model(pruned_model)
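# Note: prune_model_keep_size (utils/prune_utils.py) keeps every tensor at its original
# shape; the idea is to zero the gammas of pruned channels (and, in this repo, compensate
# downstream layers for the constant activation those channels would have produced), so
# the mAP printed here should be close to the one reported by prune_and_eval above.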
#%%
# Deep-copy the network definition (module_defs), not the parameters
compact_module_defs = deepcopy(model.module_defs)

# For every CBL block, set the filter count to the number of channels kept after pruning
for idx, num in zip(CBL_idx, num_filters):
    assert compact_module_defs[idx]['type'] == 'convolutional'
    compact_module_defs[idx]['filters'] = str(num)
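# e.g. a pruned block's definition might go from {'type': 'convolutional', 'filters': '256', ...}
# to {'type': 'convolutional', 'filters': '187', ...} (illustrative numbers)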
#%%
# compact_model is the real post-pruning architecture. Note that pruned_model above only
# zeroed the weights of the pruned conv/BN/activation channels; it did not shrink the network.
compact_model = Darknet([model.hyperparams.copy()] + compact_module_defs).to(device)

# Count the parameters of the compact model
compact_nparameters = obtain_num_parameters(compact_model)

# Copy the surviving weights from the loose (masked) model into the compact model
init_weights_from_loose_model(compact_model, pruned_model, CBL_idx, Conv_idx, CBLidx2mask)
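# The copy is essentially an index-select on the masks; e.g. the surviving output-channel
# indices of a pruned CBL block can be recovered with something like
# np.argwhere(CBLidx2mask[idx])[:, 0], and the compact conv weight is the loose weight
# sliced by those indices on its output axis (and by the previous layer's surviving
# indices on its input axis). A sketch of the idea, not the exact implementation.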
#%%
# A random input for timing
random_input = torch.rand((1, 3, model.img_size, model.img_size)).to(device)

# Average forward-pass time of a model
def obtain_avg_forward_time(input, model, repeat=200):
    model.eval()
    start = time.time()
    with torch.no_grad():
        for i in range(repeat):
            output = model(input)
    avg_infer_time = (time.time() - start) / repeat
    return avg_infer_time, output
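# Note: CUDA kernels launch asynchronously, so time.time() around an un-synchronized loop
# can under-report GPU latency. A sketch of a more accurate variant (hypothetical helper,
# not used below):
def _obtain_avg_forward_time_synced(input, model, repeat=200):
    model.eval()
    if input.is_cuda:
        torch.cuda.synchronize()  # make sure pending kernels finish before starting the clock
    start = time.time()
    with torch.no_grad():
        for _ in range(repeat):
            output = model(input)
    if input.is_cuda:
        torch.cuda.synchronize()  # wait for the timed kernels to complete
    return (time.time() - start) / repeat, output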
# Inference time and output of the masked model and of the compact model
pruned_forward_time, pruned_output = obtain_avg_forward_time(random_input, pruned_model)
compact_forward_time, compact_output = obtain_avg_forward_time(random_input, compact_model)

# Compare the two outputs; a large discrepancy means something went wrong
diff = (pruned_output - compact_output).abs().gt(0.001).sum().item()
if diff > 0:
    print('Something wrong with the pruned model!')
#%%
# Evaluate the compact model on the validation set
compact_model_metric = eval_model(compact_model)

#%%
# Compare parameter count, mAP and inference time before and after pruning
metric_table = [
    ["Metric", "Before", "After"],
    ["mAP", f'{origin_model_metric[2].mean():.6f}', f'{compact_model_metric[2].mean():.6f}'],
    ["Parameters", f"{origin_nparameters}", f"{compact_nparameters}"],
    ["Inference time (s)", f'{pruned_forward_time:.4f}', f'{compact_forward_time:.4f}']
]
print(AsciiTable(metric_table).table)
#%%
# Write the cfg file of the pruned network and save its weights
pruned_cfg_name = opt.model_def.replace('/', f'/prune_{percent}_')
pruned_cfg_file = write_cfg(pruned_cfg_name, [model.hyperparams.copy()] + compact_module_defs)
print(f'Config file has been saved: {pruned_cfg_file}')

compact_model_name = opt.model.replace('/', f'/prune_{percent}_')
torch.save(compact_model.state_dict(), compact_model_name)
print(f'Compact model has been saved: {compact_model_name}')
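# To reload the compact model later (a sketch, assuming the files saved above):
#   model = Darknet(pruned_cfg_file).to(device)
#   model.load_state_dict(torch.load(compact_model_name))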