예제 #1
0
def LoadCfg(args, model_load_dir, load_type):
    """Return the layer-width configuration ("cfg") for ``args.model``.

    If ``model_load_dir`` is truthy, the widths are read from the weights
    profile at ``<model_load_dir without its last "/" component>/weights_profile_path``:
    the first element of ``shape`` is taken from every entry whose name ends
    with "weight". For ``args.model == "resnet"`` entries whose name contains
    "stem" or "shortcut" are skipped, and the flat list is cut into four
    slices (0:9, 9:21, 21:39, 39:48) followed by the single element at
    index 48.

    If ``model_load_dir`` is falsy, a built-in default is returned: the
    model's hard-coded widths followed by ``args.num_classes``, or an empty
    list for a model name that is not known.

    Side effects: when ``load_type == 'train'`` the loaded profile held in
    ``modelWeight.weights_dict`` is reset to ``{}`` (load path only) and the
    resulting structure is printed.
    """
    if model_load_dir:
        use_resnet_layout = args.model == "resnet"
        assert os.path.isdir(model_load_dir)
        profile_path = model_load_dir.rsplit("/", 1)[0] + "/weights_profile_path"
        profiles = modelWeight.load(profile_path)

        if use_resnet_layout:
            widths = [
                entry["shape"][0]
                for key, entry in profiles.items()
                if key.endswith("weight")
                and not ("stem" in key or "shortcut" in key)
            ]
            # Four stage slices, then the one remaining entry at index 48.
            cfg = [
                widths[0:9],
                widths[9:21],
                widths[21:39],
                widths[39:48],
                widths[48],
            ]
        else:
            cfg = [
                entry["shape"][0]
                for key, entry in profiles.items()
                if key.endswith("weight")
            ]

        if load_type == 'train':
            modelWeight.weights_dict = {}
    else:
        # Default widths per model, without the trailing num_classes entry.
        # vgg: the last two hidden widths are 512 and 128 (a 4096/4096
        # variant was used here previously).
        default_widths = {
            'vgg': [64, 64, 128, 128, 256, 256, 256,
                    512, 512, 512, 512, 512, 512, 512, 128],
            'alexnet': [96, 256, 384, 384, 256, 4096, 4096],
            'alexnet_simple': [24, 96, 192, 192, 96, 1024, 1024],
            'lenet': [6, 16, 120, 84],
            'resnet': [
                [64, 64, 256] * 3,
                [128, 128, 512] * 4,
                [256, 256, 1024] * 6,
                [512, 512, 2048] * 3,
            ],
            'dnn_2': [128],
            'dnn_4': [4096, 256, 128],
        }
        hidden = default_widths.get(args.model)
        # num_classes is only read for known models.
        cfg = [] if hidden is None else hidden + [args.num_classes]

    if load_type == 'train':
        print('Model structure:', cfg)
    return cfg
def prune():
    """Prune the "dense" layers listed in the loaded weights profile.

    Reads the profile at ``<args.model_load_dir without its last "/"
    component>/weights_profile_path``, then walks every entry whose name
    starts with "dense" and ends with "-weight". For each such layer the
    indices to remove come from ``pa.get_removeIndex_fc``; the layer whose
    name starts with ``"dense" + str(int(len(profile) / numDiv) - 1)`` is
    not pruned along axis 0 itself and only has the previous layer's
    indices removed along axis 1. Every array returned by ``makeNameList``
    for the layer is pruned, optionally written below
    ``args.model_save_dir/model/<name>``, and registered in ``modelWeight``
    with its new shape.

    Finally the pruning statistics are printed, the new profile is saved
    to ``args.model_save_dir/weights_profile_path`` and the
    ``System-Train-TrainStep-TrainNet`` directory is copied with ``cp -r``.

    Takes no parameters and returns ``None``; all inputs come from the
    module-level ``args``, ``pa`` and ``modelWeight`` objects.
    """
    # Pruning threshold for the fully connected layers.
    thre = pa.get_pruneThre_fc()
    of_weight_path = args.model_load_dir.rsplit("/",1)[0] + "/weights_profile_path"
    weights_dict = modelWeight.load(of_weight_path)

    # Start from an empty profile; pruned entries are re-added below.
    modelWeight.weights_dict = {}

    removeIndexs = []
    lastRemoveIndexs = []
    beforePrune = 0
    afterPrune = 0

    # Number of profile entries assumed per layer for each optimizer
    # (presumably weight/bias plus their optimizer-state arrays -- TODO confirm).
    if args.optimizer == 'adam':
        numDiv = 6
    elif args.optimizer == 'momentum':
        numDiv = 4
    else:
        numDiv = 2
    # Name prefix of the layer handled by the "do not prune axis 0" branch;
    # loop-invariant, so computed once.
    lastDensePrefix = "dense" + str(int(len(weights_dict)/numDiv)-1)

    for name, _profile in weights_dict.items():
        if not (name.startswith("dense") and name.endswith("-weight")):
            continue

        # Indices removed from the previous layer are removed along axis 1
        # of this layer's weight.
        lastRemoveIndexs = removeIndexs
        if name.startswith(lastDensePrefix):
            removeIndexs = []
        else:
            a, dtype, shape = name2array(name, weights_dict)
            # Indices to remove according to the pruning method.
            removeIndexs = pa.get_removeIndex_fc(a, shape, thre)

            # Never remove every index: drop the first one from the list
            # so that at least one entry along axis 0 survives.
            if len(removeIndexs) == len(a):
                removeIndexs = np.delete(removeIndexs, 0)

        # Names of all arrays belonging to the layer being pruned.
        layerName = name.split("_")[0].split("-")[0]
        nameList = makeNameList([], layerName)

        # Actual pruning.
        for position, arrayName in enumerate(nameList):
            a, dtype, shape = name2array(arrayName, weights_dict)
            if "weight" in arrayName:
                b = np.delete(a, removeIndexs, 0)
                b = np.delete(b, lastRemoveIndexs, 1)
            else:
                b = np.delete(a, removeIndexs)
            # Only the first array of each layer contributes to the totals.
            if position == 0:
                beforePrune += a.shape[0]
                afterPrune += b.shape[0]
            print(arrayName+" pruned: shape from", a.shape, "-->", b.shape)
            if args.model_save_dir:
                folder = os.path.join(args.model_save_dir, "model", arrayName)
                _SaveWeightBlob2File(b, folder, 'out')
            # Reverse lookup: dtype value -> its key in dtype_dict.
            dtypeName = list(dtype_dict.keys())[list(dtype_dict.values()).index(dtype)]
            modelWeight.add(arrayName, dtypeName, b.shape)

    print("Pruning done! Number of channel from", beforePrune, "-->", afterPrune)
    # Guard against a profile without any matching dense layer, which would
    # otherwise abort here with ZeroDivisionError before the profile is saved.
    pruneRate = 100*(beforePrune-afterPrune)/beforePrune if beforePrune else 0.0
    print("Real Pruning rate:", pruneRate, "%")
    # NOTE(review): args.model_save_dir is checked for truthiness above but
    # used unconditionally from here on -- confirm it is always set.
    weights_profile_path = os.path.join(args.model_save_dir, "weights_profile_path")
    modelWeight.save(weights_profile_path)
    # NOTE(review): the paths are interpolated into a shell command without
    # quoting; directories containing spaces or shell metacharacters break it.
    os.system('cp -r {0}/System-Train-TrainStep-TrainNet {1}/System-Train-TrainStep-TrainNet '.format(args.model_load_dir, os.path.join(args.model_save_dir, "model")))
예제 #3
0
def prune():
    """Prune the "conv" layers listed in the loaded weights profile.

    Reads the profile at ``<args.model_load_dir without its last "/"
    component>/weights_profile_path`` and iterates over its entries:

    * Entries starting with "conv" and ending with the method-specific
      suffix (``"_bn-gamma"`` for "bn" methods, ``"_weight"`` for "conv"
      methods and "random"), excluding names containing "stem" or
      "shortcut", select the indices to remove through the matching
      ``pa.get_removeIndex_*`` function. All arrays that ``makeNameList``
      returns for that layer are pruned. For every third conv index the
      matching ``conv_shortcut<stage>_<block>`` arrays are pruned as well.
    * Entries containing "stem" are copied unchanged.
    * Entries starting with "dense": the ``dense0`` weight arrays have
      columns removed according to the indices recorded for ``conv47``;
      all other dense entries are copied unchanged.

    Each resulting array is optionally written below
    ``args.model_save_dir/model/<name>`` and registered in ``modelWeight``.
    Finally the statistics are printed, the new profile is saved and the
    ``System-Train-TrainStep-TrainNet`` directory is copied with ``cp -r``.

    Takes no parameters and returns ``None``; all inputs come from the
    module-level ``args``, ``pa`` and ``modelWeight`` objects.
    """

    def _store(arrayName, dtype, blob):
        # Write the array to the save directory when one is configured and
        # register its dtype name (reverse lookup in dtype_dict) and shape.
        if args.model_save_dir:
            folder = os.path.join(args.model_save_dir, "model", arrayName)
            _SaveWeightBlob2File(blob, folder, 'out')
        modelWeight.add(
            arrayName,
            list(dtype_dict.keys())[list(dtype_dict.values()).index(dtype)],
            blob.shape)

    # Arrays with these name suffixes are pruned along axis 0 and axis 1;
    # every other array is pruned with a flat np.delete.
    weightSuffixes = ("weight", "weight-v", "weight-m", "weight-momentum")

    # Threshold for the selected pruning method (only these four methods
    # use one; `thre` is left unset for the others and never read by them).
    if args.prune_method == 'bn':
        thre = pa.get_pruneThre_bn()
    elif args.prune_method == 'conv_avg':
        thre = pa.get_pruneThre_conv_avg()
    elif args.prune_method == 'conv_all':
        thre = pa.get_pruneThre_conv_all()
    elif args.prune_method == 'conv_max':
        thre = pa.get_pruneThre_conv_max()

    of_weight_path = args.model_load_dir.rsplit("/",
                                                1)[0] + "/weights_profile_path"
    weights_dict = modelWeight.load(of_weight_path)

    # Start from an empty profile; entries are re-added as they are processed.
    modelWeight.weights_dict = {}

    fcRemoveIndexs = []
    fcDivideNum = 0
    removeIndexs = []
    lastRemoveIndexs = []
    lastRemoveIndexs_shortcut = []
    beforePrune = 0
    afterPrune = 0
    pruneName = ''

    # Suffix of the array that drives the pruning decision for a layer.
    if "bn" in args.prune_method:
        pruneName = "_bn-gamma"
    elif "conv" in args.prune_method or args.prune_method == "random":
        pruneName = "_weight"

    for name, _profile in weights_dict.items():
        if (name.startswith("conv") and name.endswith(pruneName)
                and "stem" not in name and "shortcut" not in name):
            a, dtype, shape = name2array(name, weights_dict)
            lastRemoveIndexs = removeIndexs
            # Indices to remove according to the selected pruning method.
            if args.prune_method == 'bn':
                removeIndexs = pa.get_removeIndex_bn(a, thre)
            elif args.prune_method == "conv_avg":
                removeIndexs = pa.get_removeIndex_conv_avg(a, shape, thre)
            elif args.prune_method == "conv_all":
                removeIndexs = pa.get_removeIndex_conv_all(a, shape, thre)
            elif args.prune_method == "conv_max":
                removeIndexs = pa.get_removeIndex_conv_max(a, shape, thre)
            elif args.prune_method == "random":
                removeIndexs = pa.get_removeIndex_random(shape)
            elif args.prune_method == "conv_similarity":
                removeIndexs = pa.get_removeIndex_conv_similarity(a, shape)
            elif args.prune_method == "bn_similarity":
                removeIndexs = pa.get_removeIndex_bn_similarity(a, shape)
            elif args.prune_method == "conv_threshold":
                removeIndexs = pa.get_removeIndex_conv_threshold(a,
                                                                 shape,
                                                                 threSet=0.06)

            # Never remove every index: drop the first one from the list
            # so that at least one entry along axis 0 survives.
            if len(removeIndexs) == len(a):
                removeIndexs = np.delete(removeIndexs, 0)

            # Remember what was removed from conv47 so that the dense0
            # weight columns can be pruned to match further below.
            if name == "conv47" + pruneName:
                fcRemoveIndexs = removeIndexs
                fcDivideNum = 2048

            # Names of all arrays belonging to the layer being pruned.
            name = name.split("_")[0].split("-")[0]
            nameList = []
            nameList = makeNameList(pruneName, nameList, name)

            # Actual pruning of everything except the shortcut arrays.
            # NOTE: `name` is deliberately rebound by this loop; the
            # conv-index parsing after the loop reads its final value.
            for name in nameList:
                a, dtype, shape = name2array(name, weights_dict)
                if name.endswith(weightSuffixes):
                    b = np.delete(a, removeIndexs, 0)
                    b = np.delete(b, lastRemoveIndexs, 1)
                    # Only the plain weight array counts towards the totals.
                    if name.endswith("weight"):
                        beforePrune += a.shape[0]
                        afterPrune += b.shape[0]
                else:
                    b = np.delete(a, removeIndexs)

                print(name + " pruned: shape from", a.shape, "-->", b.shape)
                _store(name, dtype, b)

            # Shortcut pruning for the resnet model.
            # addName is the numeric suffix of the shortcut layer.
            addName = ""
            # Extract the index number from the conv layer name.
            n = int(name.split("_")[0].split("-")[0].replace("conv", ""))
            if (n + 1) % 3 == 0:
                # Map every third conv index to "<stage>_<block>".
                n = int((n + 1) / 3)
                if n <= 3:
                    addName = "0_" + str(n - 1)
                elif n <= 7:
                    addName = "1_" + str(n - 4)
                elif n <= 13:
                    addName = "2_" + str(n - 8)
                elif n <= 16:
                    addName = "3_" + str(n - 14)
                shortcutLayer = "conv_shortcut" + addName
                # Names of all arrays belonging to the shortcut conv layer.
                nameList_shortcut = []
                nameList_shortcut = makeNameList(pruneName, nameList_shortcut,
                                                 shortcutLayer)

                # Actual pruning of the shortcut arrays; axis 1 uses the
                # indices removed at the previous shortcut.
                for shortcutName in nameList_shortcut:
                    a, dtype, shape = name2array(shortcutName, weights_dict)
                    if shortcutName.endswith(weightSuffixes):
                        b = np.delete(a, removeIndexs, 0)
                        b = np.delete(b, lastRemoveIndexs_shortcut, 1)
                    else:
                        b = np.delete(a, removeIndexs)
                    print(shortcutName + " pruned: shape from", a.shape, "-->",
                          b.shape)
                    _store(shortcutName, dtype, b)
                lastRemoveIndexs_shortcut = removeIndexs

        # Copy the stem layers unchanged.
        elif "stem" in name:
            a, dtype, shape = name2array(name, weights_dict)
            b = a
            print(name + " copy")
            _store(name, dtype, b)

        # Prune the first dense layer (dense0); copy the other dense arrays.
        elif name.startswith("dense"):
            if name in [
                    'dense0-weight', 'dense0-weight-v', 'dense0-weight-m',
                    'dense0-weight-momentum'
            ]:
                fcRemoveIndexsNew = []
                a, dtype, shape = name2array(name, weights_dict)
                # NOTE(review): fcDivideNum is still 0 here unless
                # "conv47" + pruneName was processed earlier in this loop,
                # in which case the division raises ZeroDivisionError --
                # confirm the profile ordering guarantees conv47 comes first.
                num = int(a.shape[1] / fcDivideNum)
                # Expand each removed conv47 index to every column group of
                # width fcDivideNum along axis 1.
                for index in fcRemoveIndexs:
                    fcRemoveIndexsNew += [
                        index + fcDivideNum * i for i in range(num)
                    ]
                b = np.delete(a, fcRemoveIndexsNew, 1)
            else:
                a, dtype, shape = name2array(name, weights_dict)
                b = a
            print(name + " pruned: shape from", a.shape, "-->", b.shape)
            _store(name, dtype, b)

    print("Pruning done! Number of channel from", beforePrune, "-->",
          afterPrune)
    # Guard against a profile without any matching conv weight, which would
    # otherwise abort here with ZeroDivisionError before the profile is saved.
    pruneRate = (100 * (beforePrune - afterPrune) / beforePrune
                 if beforePrune else 0.0)
    print("Real Pruning rate:", pruneRate, "%")
    # NOTE(review): args.model_save_dir is checked for truthiness when
    # writing arrays but used unconditionally here -- confirm it is always set.
    weights_profile_path = os.path.join(args.model_save_dir,
                                        "weights_profile_path")
    modelWeight.save(weights_profile_path)
    # NOTE(review): the paths are interpolated into a shell command without
    # quoting; directories containing spaces or shell metacharacters break it.
    os.system(
        'cp -r {0}/System-Train-TrainStep-TrainNet {1}/System-Train-TrainStep-TrainNet '
        .format(args.model_load_dir, os.path.join(args.model_save_dir,
                                                  "model")))