def LoadCfg(args, model_load_dir, load_type):
    if model_load_dir:
        if args.model == "resnet":
            assert os.path.isdir(model_load_dir)
            of_weight_path = model_load_dir.rsplit("/", 1)[0] + "/weights_profile_path"
            cfg_temp = []
            cfg = []
            weights_dict = modelWeight.load(of_weight_path)
            for name, profile_dict in weights_dict.items():
                if name.endswith("weight") and "stem" not in name and "shortcut" not in name:
                    shape = profile_dict["shape"]
                    cfg_temp.append(shape[0])
            cfg.append(cfg_temp[0:9])
            cfg.append(cfg_temp[9:21])
            cfg.append(cfg_temp[21:39])
            cfg.append(cfg_temp[39:48])
            cfg.append(cfg_temp[48])
            if load_type == 'train':
                modelWeight.weights_dict = {}
        else:
            assert os.path.isdir(model_load_dir)
            of_weight_path = model_load_dir.rsplit("/", 1)[0] + "/weights_profile_path"
            cfg = []
            weights_dict = modelWeight.load(of_weight_path)
            for name, profile_dict in weights_dict.items():
                if name.endswith("weight"):
                    shape = profile_dict["shape"]
                    cfg.append(shape[0])
            # print(load_type, modelWeight.weights_dict)
            if load_type == 'train':
                modelWeight.weights_dict = {}
    else:
        if args.model == 'vgg':
            # cfg = [64, 64, 128, 128, 256, 256, 256, 512, 512, 512, 512, 512, 512, 4096, 4096, args.num_classes]
            cfg = [64, 64, 128, 128, 256, 256, 256, 512, 512, 512, 512, 512, 512, 512, 128, args.num_classes]
        elif args.model == 'alexnet':
            cfg = [96, 256, 384, 384, 256, 4096, 4096, args.num_classes]
        elif args.model == 'alexnet_simple':
            cfg = [24, 96, 192, 192, 96, 1024, 1024, args.num_classes]
        elif args.model == 'lenet':
            cfg = [6, 16, 120, 84, args.num_classes]
        elif args.model == "resnet":
            cfg = [[64, 64, 256, 64, 64, 256, 64, 64, 256],
                   [128, 128, 512, 128, 128, 512, 128, 128, 512, 128, 128, 512],
                   [256, 256, 1024, 256, 256, 1024, 256, 256, 1024,
                    256, 256, 1024, 256, 256, 1024, 256, 256, 1024],
                   [512, 512, 2048, 512, 512, 2048, 512, 512, 2048],
                   args.num_classes]
        elif args.model == 'dnn_2':
            cfg = [128, args.num_classes]
        elif args.model == 'dnn_4':
            cfg = [4096, 256, 128, args.num_classes]
        else:
            cfg = []
    if load_type == 'train':
        print('Model structure:', cfg)
    return cfg
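# Illustration only (a sketch, never called by the pipeline): the slice boundaries used
# above for the resnet branch follow the ResNet-50 bottleneck layout of 3+4+6+3 blocks
# with 3 convs each, i.e. 9, 12, 18 and 9 conv widths per stage, plus the final
# fully-connected width stored at index 48. The stand-in widths below are hypothetical.
def _demo_resnet_cfg_slicing():
    widths = list(range(49))                  # stand-in for the 49 recorded layer widths
    stages = [widths[0:9], widths[9:21], widths[21:39], widths[39:48], widths[48]]
    assert [len(s) for s in stages[:4]] == [9, 12, 18, 9]
    return stages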
def prune():
    # get the pruning threshold for the fully-connected layers
    thre = pa.get_pruneThre_fc()
    of_weight_path = args.model_load_dir.rsplit("/", 1)[0] + "/weights_profile_path"
    weights_dict = modelWeight.load(of_weight_path)
    modelWeight.weights_dict = {}
    removeIndexs = []
    lastRemoveIndexs = []
    beforePrune = 0
    afterPrune = 0
    dictLen = len(weights_dict)
    numDiv = 0
    if args.optimizer == 'adam':
        numDiv = 6
    elif args.optimizer == 'momentum':
        numDiv = 4
    else:
        numDiv = 2
    for name, profile_dict in weights_dict.items():
        if name.startswith("dense") and name.endswith("-weight"):
            if name.startswith("dense" + str(int(dictLen / numDiv) - 1)) and name.endswith("-weight"):
                # last dense layer: keep all of its output neurons
                lastRemoveIndexs = removeIndexs
                removeIndexs = []
            else:
                a, dtype, shape = name2array(name, weights_dict)
                lastRemoveIndexs = removeIndexs
                # get removeIndexs for the chosen pruning method
                removeIndexs = pa.get_removeIndex_fc(a, shape, thre)
                if len(removeIndexs) == len(a):
                    removeIndexs = np.delete(removeIndexs, 0)
            # list of layer names to be pruned
            name = name.split("_")[0].split("-")[0]
            nameList = []
            nameList = makeNameList(nameList, name)
            # the actual pruning
            i = 0
            for name in nameList:
                a, dtype, shape = name2array(name, weights_dict)
                if "weight" in name:
                    b = np.delete(a, removeIndexs, 0)
                    b = np.delete(b, lastRemoveIndexs, 1)
                else:
                    b = np.delete(a, removeIndexs)
                if i == 0:
                    beforePrune += a.shape[0]
                    afterPrune += b.shape[0]
                print(name + " pruned: shape from", a.shape, "-->", b.shape)
                if args.model_save_dir:
                    folder = os.path.join(args.model_save_dir, "model", name)
                    _SaveWeightBlob2File(b, folder, 'out')
                modelWeight.add(name,
                                list(dtype_dict.keys())[list(dtype_dict.values()).index(dtype)],
                                b.shape)
                i += 1
    print("Pruning done! Number of channel from", beforePrune, "-->", afterPrune)
    print("Real Pruning rate:", 100 * (beforePrune - afterPrune) / beforePrune, "%")
    weights_profile_path = os.path.join(args.model_save_dir, "weights_profile_path")
    modelWeight.save(weights_profile_path)
    os.system('cp -r {0}/System-Train-TrainStep-TrainNet {1}/System-Train-TrainStep-TrainNet'.format(
        args.model_load_dir, os.path.join(args.model_save_dir, "model")))
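# Minimal sketch of the delete pattern used above (illustration only, relies on the same
# `np` import as the rest of this module), assuming a dense weight laid out as
# (out_features, in_features): indexes of pruned output neurons remove rows (axis 0),
# while neurons already pruned from the previous layer remove the matching input columns
# (axis 1). The toy shapes and indexes are hypothetical.
def _demo_fc_prune_delete():
    w = np.arange(12).reshape(3, 4)           # 3 outputs x 4 inputs
    w = np.delete(w, [1], 0)                  # drop a pruned output row    -> shape (2, 4)
    w = np.delete(w, [0], 1)                  # drop a pruned input column  -> shape (2, 3)
    return w.shape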
def prune():
    # get the threshold for the chosen pruning method
    if args.prune_method == 'bn':
        thre = pa.get_pruneThre_bn()
    elif args.prune_method == 'conv_avg':
        thre = pa.get_pruneThre_conv_avg()
    elif args.prune_method == 'conv_all':
        thre = pa.get_pruneThre_conv_all()
    elif args.prune_method == 'conv_max':
        thre = pa.get_pruneThre_conv_max()
    of_weight_path = args.model_load_dir.rsplit("/", 1)[0] + "/weights_profile_path"
    weights_dict = modelWeight.load(of_weight_path)
    modelWeight.weights_dict = {}
    fcRemoveIndexs = []
    fcDivideNum = 0
    removeIndexs = []
    lastRemoveIndexs = []
    lastRemoveIndexs_shortcut = []
    beforePrune = 0
    afterPrune = 0
    pruneName = ''
    if "bn" in args.prune_method:
        pruneName = "_bn-gamma"
    elif "conv" in args.prune_method or args.prune_method == "random":
        pruneName = "_weight"
    for name, profile_dict in weights_dict.items():
        if name.startswith("conv") and name.endswith(pruneName) and \
                "stem" not in name and "shortcut" not in name:
            a, dtype, shape = name2array(name, weights_dict)
            lastRemoveIndexs = removeIndexs
            # get removeIndexs for the chosen pruning method
            if args.prune_method == 'bn':
                removeIndexs = pa.get_removeIndex_bn(a, thre)
            elif args.prune_method == "conv_avg":
                removeIndexs = pa.get_removeIndex_conv_avg(a, shape, thre)
            elif args.prune_method == "conv_all":
                removeIndexs = pa.get_removeIndex_conv_all(a, shape, thre)
            elif args.prune_method == "conv_max":
                removeIndexs = pa.get_removeIndex_conv_max(a, shape, thre)
            elif args.prune_method == "random":
                removeIndexs = pa.get_removeIndex_random(shape)
            elif args.prune_method == "conv_similarity":
                removeIndexs = pa.get_removeIndex_conv_similarity(a, shape)
            elif args.prune_method == "bn_similarity":
                removeIndexs = pa.get_removeIndex_bn_similarity(a, shape)
            elif args.prune_method == "conv_threshold":
                removeIndexs = pa.get_removeIndex_conv_threshold(a, shape, threSet=0.06)
            if len(removeIndexs) == len(a):
                removeIndexs = np.delete(removeIndexs, 0)
            if name == "conv47" + pruneName:
                fcRemoveIndexs = removeIndexs
                fcDivideNum = 2048
            # list of layer names to be pruned
            name = name.split("_")[0].split("-")[0]
            nameList = []
            nameList = makeNameList(pruneName, nameList, name)
            # the actual pruning for every layer except the shortcut layers
            for name in nameList:
                a, dtype, shape = name2array(name, weights_dict)
                if name.endswith("weight") or name.endswith("weight-v") or \
                        name.endswith("weight-m") or name.endswith("weight-momentum"):
                    b = np.delete(a, removeIndexs, 0)
                    b = np.delete(b, lastRemoveIndexs, 1)
                    if name.endswith("weight"):
                        beforePrune += a.shape[0]
                        afterPrune += b.shape[0]
                else:
                    b = np.delete(a, removeIndexs)
                print(name + " pruned: shape from", a.shape, "-->", b.shape)
                if args.model_save_dir:
                    folder = os.path.join(args.model_save_dir, "model", name)
                    _SaveWeightBlob2File(b, folder, 'out')
                modelWeight.add(name,
                                list(dtype_dict.keys())[list(dtype_dict.values()).index(dtype)],
                                b.shape)
            # prune the shortcut branches of the resnet model
            # addName is the numeric suffix of the shortcut layer
            addName = ""
            # get the index number from the conv layer name
            n = int(name.split("_")[0].split("-")[0].replace("conv", ""))
            if (n + 1) % 3 == 0:
                n = int((n + 1) / 3)
                if n <= 3:
                    addName = "0_" + str(n - 1)
                elif n <= 7:
                    addName = "1_" + str(n - 4)
                elif n <= 13:
                    addName = "2_" + str(n - 8)
                elif n <= 16:
                    addName = "3_" + str(n - 14)
                name = "conv_shortcut" + addName
                # list of names of the shortcut conv layers to be pruned;
                # nameList_shortcut holds every name that has to be cut
                nameList_shortcut = []
                nameList_shortcut = makeNameList(pruneName, nameList_shortcut, name)
                # the actual pruning of the resnet shortcut layers
                for name in nameList_shortcut:
                    a, dtype, shape = name2array(name, weights_dict)
                    if name.endswith("weight") or name.endswith("weight-v") or \
                            name.endswith("weight-m") or name.endswith("weight-momentum"):
                        b = np.delete(a, removeIndexs, 0)
                        b = np.delete(b, lastRemoveIndexs_shortcut, 1)
                    else:
                        b = np.delete(a, removeIndexs)
                    print(name + " pruned: shape from", a.shape, "-->", b.shape)
                    if args.model_save_dir:
                        folder = os.path.join(args.model_save_dir, "model", name)
                        _SaveWeightBlob2File(b, folder, 'out')
                    modelWeight.add(name,
                                    list(dtype_dict.keys())[list(dtype_dict.values()).index(dtype)],
                                    b.shape)
                lastRemoveIndexs_shortcut = removeIndexs
        # copy the stem layers unchanged
        elif "stem" in name:
            a, dtype, shape = name2array(name, weights_dict)
            b = a
            print(name + " copy")
            if args.model_save_dir:
                folder = os.path.join(args.model_save_dir, "model", name)
                _SaveWeightBlob2File(b, folder, 'out')
            modelWeight.add(name,
                            list(dtype_dict.keys())[list(dtype_dict.values()).index(dtype)],
                            b.shape)
        # prune the first dense layer (dense0)
        elif name.startswith("dense"):
            if name in ['dense0-weight', 'dense0-weight-v',
                        'dense0-weight-m', 'dense0-weight-momentum']:
                fcRemoveIndexsNew = []
                a, dtype, shape = name2array(name, weights_dict)
                num = int(a.shape[1] / fcDivideNum)
                for index in fcRemoveIndexs:
                    fcRemoveIndexsNew += [index + fcDivideNum * i for i in range(num)]
                b = np.delete(a, fcRemoveIndexsNew, 1)
            else:
                a, dtype, shape = name2array(name, weights_dict)
                b = a
            print(name + " pruned: shape from", a.shape, "-->", b.shape)
            if args.model_save_dir:
                folder = os.path.join(args.model_save_dir, "model", name)
                _SaveWeightBlob2File(b, folder, 'out')
            modelWeight.add(name,
                            list(dtype_dict.keys())[list(dtype_dict.values()).index(dtype)],
                            b.shape)
    print("Pruning done! Number of channel from", beforePrune, "-->", afterPrune)
    print("Real Pruning rate:", 100 * (beforePrune - afterPrune) / beforePrune, "%")
    weights_profile_path = os.path.join(args.model_save_dir, "weights_profile_path")
    modelWeight.save(weights_profile_path)
    os.system('cp -r {0}/System-Train-TrainStep-TrainNet {1}/System-Train-TrainStep-TrainNet'.format(
        args.model_load_dir, os.path.join(args.model_save_dir, "model")))
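# Illustration only (a sketch with hypothetical sizes): dense0's input is treated above as
# `num` groups of `fcDivideNum` columns, so a channel pruned from the last conv stage has
# to be deleted once per group. With ResNet-50 and global pooling the input is exactly
# 2048 wide (num == 1); the 4096-wide example below just makes the stride visible.
def _demo_dense0_index_expansion(fcRemoveIndexs=(3, 7), fcDivideNum=2048, in_features=4096):
    num = in_features // fcDivideNum          # here: 2 groups of 2048 columns
    expanded = [idx + fcDivideNum * i for idx in fcRemoveIndexs for i in range(num)]
    return sorted(expanded)                   # [3, 7, 2051, 2055]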