Example #1
def iterative_test_conv(settings, network_type=2, filename=""):
    percents, iterations = generate_percentages(
        [1.0, 1.0, 1.0], 0.02, settings['pruning_percentages'])
    histories = np.zeros(iterations + 1)
    es_epochs = np.zeros(iterations + 1)

    for trial in range(TRIALS):
        if network_type == 2:
            og_network = CONV2_NETWORK(settings)
        elif network_type == 4:
            og_network = CONV4_NETWORK(settings)
        elif network_type == 6:
            og_network = CONV6_NETWORK(settings)

        #Save the initial weights of the original network
        init_weights = og_network.get_weights()

        #Build an all-ones mask (no pruning) and train the original network
        mask = prune(og_network, 1.0, 1.0, 1.0)

        _, epoch = og_network.fit_batch(x_train, y_train, mask, init_weights,
                                        x_test, y_test)
        es_epochs[0] += epoch

        #Evaluate original network and save results
        _, test_acc = og_network.evaluate_model(x_test, y_test)
        histories[0] += test_acc

        #Prune the network for x iterations, evaluate each iteration and save results
        for i in range(0, iterations):
            print("Conv %: " + str(percents[i][0]) + ", Dense %: " +
                  str(percents[i][1]) + ", Output %: " + str(percents[i][2]))
            mask = prune(og_network, percents[i][0], percents[i][1],
                         percents[i][2])
            #for w in masked_weights:
            #    print(np.count_nonzero(w==0)/np.size(w))

            if network_type == 2:
                pruned_network = CONV2_NETWORK(settings)
            elif network_type == 4:
                pruned_network = CONV4_NETWORK(settings)
            elif network_type == 6:
                pruned_network = CONV6_NETWORK(settings)

            _, epoch = pruned_network.fit_batch(x_train, y_train, mask,
                                                init_weights, x_test, y_test)
            _, test_acc = pruned_network.evaluate_model(x_test, y_test)
            histories[i + 1] += test_acc
            es_epochs[i + 1] += epoch

            og_network = pruned_network

        trial_filename = filename + "_trial-" + str(trial + 1)

        print(histories)
        print(es_epochs)
        np.savez(trial_filename + ".npz", histories=histories, es_epochs=es_epochs)
    return histories, es_epochs
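Examples #1, #3, and #5 all rely on a generate_percentages helper from tools.py that is not shown on this page. A minimal sketch of what it plausibly does, assuming each iteration scales every nonzero keep-percentage by a fixed rate until the lower bound is reached (the exact signature and rate handling in tools.py may differ):

def generate_percentages(init_percents, lower_bound, rate=0.2):
    # Hypothetical reconstruction of the tools.py helper: each iteration
    # multiplies every nonzero keep-percentage by (1 - rate); entries of 0.0
    # (layers that are never pruned) are left untouched.
    schedule = []
    current = list(init_percents)
    while max(current) * (1 - rate) > lower_bound:
        current = [p * (1 - rate) if p > 0 else p for p in current]
        schedule.append(list(current))
    return schedule, len(schedule)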
Example #2
def prune_validation(data):
    '''
    Performs complete evaluation of the decision tree algorithm with pruning

    Here are the steps:
    1. Split data into TEST and TRAINING+VALIDATION (x10 times)
        2. Split TRAINING+VALIDATION into TRAINING and VALIDATION (x9 times)
            3. For each TRAINING and VALIDATION:
                a) Train a tree using TRAINING
                b) Prune a tree using VALIDATION
                c) Test each pruned tree using TEST (9 trees x 10 test datasets = 90 measures)

    :param data: full dataset (clean_dataset OR noisy_dataset)
    :return: all_90_measures: list of [measures1, measures2, ..., measures90]
                                    where measures1 = [classification_rate, ...]
    '''

    # Shuffle and divide data
    divided_data = divide_data(data, 10)

    avg_errors = []
    all_90_measures = []

    for i in range(10):
        # Split TEST and TRAINING+VALIDATION (x10 times)
        test_data = divided_data[i]
        errors_on_this_test = []

        # Split TRAINING+VALIDATION ---> TRAINING and VALIDATION (x9 times)
        for j in range(1, 10):

            validation_data = divided_data[(i + j) % 10]
            training_data = np.concatenate([
                a for a in divided_data if not (a == test_data).all()
                and not (a == validation_data).all()
            ])

            # Train a tree
            tree = create_tree(training_data)
            decision_tree_learning(tree)

            # Prune tree on VALIDATION
            pruned_tree = prune(tree, validation_data)

            # Evaluate pruned tree on TEST (classification rate is measures[0])
            measures = evaluate(test_data, pruned_tree)
            errors_on_this_test.append(1 - measures[0])

            # Collect all measures of the pruned tree
            all_90_measures.append(measures)

        # Collect error stats
        avg_errors.append(sum(errors_on_this_test) / len(errors_on_this_test))

    total_error = sum(avg_errors) / len(avg_errors)  # overall average error, kept for inspection
    return all_90_measures
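A quick way to aggregate the returned measures, assuming (as the error bookkeeping above does) that the classification rate is the first element of each measures tuple:

# clean_dataset is assumed to be loaded elsewhere, as in the caller of
# prune_validation
all_90_measures = prune_validation(clean_dataset)
mean_rate = sum(m[0] for m in all_90_measures) / len(all_90_measures)
print("Mean classification rate over 90 pruned trees:", mean_rate)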
Example #3
def iterative_pruning_experiment():
    """
    - For multiple trials:
        - Create the original network
        - Train original network, save results
        - For multiple iterations:
            - Create mask with the pruning percentages generated by the generate_percentages function in tools.py
            - Prune the network with the pruning percentages
            - Train the pruned network, save results
            - Set the original network to be the pruned network
    """
    trials = SETTINGS['trials']
    percents = [0.0, 1.0, 1.0, 1.0]
    percentages, iterations = generate_percentages(percents,
                                                   SETTINGS['lower_bound'])
    histories = np.zeros((iterations + 1, SETTINGS['n_epochs']))
    #es_epochs = np.zeros((trials, iterations+1))
    for k in range(0, trials):
        print("TRIAL " + str(k + 1) + "/" + str(trials))
        og_network = FC_NETWORK(use_earlyStopping=SETTINGS['use_es'])
        init_weights = og_network.weights_init
        mask = prune(og_network, percents)
        acc_history, _ = og_network.fit_batch(x_train, y_train, mask,
                                              init_weights, SETTINGS, x_test,
                                              y_test)
        histories[iterations] += np.asarray(acc_history)
        #es_epochs[k, iterations] = epoch
        for i in range(0, iterations):
            S = percentages[i]

            print("Prune iteration: " + str(i + 1) + "/" + str(iterations) +
                  ", S: " + str(S))
            print("Creating the pruned network")
            mask = prune(og_network, S)
            pruned_network = FC_NETWORK(use_earlyStopping=SETTINGS['use_es'])
            acc_history, _ = pruned_network.fit_batch(x_train, y_train, mask,
                                                      init_weights, SETTINGS,
                                                      x_test, y_test)
            histories[i] += np.asarray(acc_history)
            #es_epochs[k, i] = epoch
            og_network = pruned_network

    histories = histories / trials
    #es_epochs = float(es_epochs/trials)
    np.savez("data/iterpr_lenet_20perc.npz", histories=histories)
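The prune calls in Examples #1, #3, and #5 return a binary mask rather than modifying the network in place, since the mask is handed to fit_batch together with the original initial weights. A hypothetical sketch of the magnitude-based mask construction this implies; the signature varies across the examples (separate per-layer arguments in Example #1, a list in Examples #3 and #5), and the sketch uses the list form:

import numpy as np

def prune(network, keep_percents):
    # Hypothetical sketch: for each layer, keep the largest-magnitude
    # fraction of weights and zero the rest (the usual lottery-ticket
    # magnitude criterion). Entries of 0.0 or 1.0 mark layers that are
    # left untouched in the examples above.
    mask = []
    for w, keep in zip(network.get_weights(), keep_percents):
        if keep <= 0.0 or keep >= 1.0:
            mask.append(np.ones_like(w))
            continue
        k = max(1, int(w.size * keep))               # number of weights to keep
        threshold = np.sort(np.abs(w), axis=None)[-k]
        mask.append((np.abs(w) >= threshold).astype(w.dtype))
    return mask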
Example #4
def main():

    options = opts()
    
    src = options.source
    ad = options.dest
    
    d = getLockDir(ad)
    unf = checkUnfinished(ad)

    newDir = backupDir(ad)
    if os.path.exists(newDir):
        os.removedirs(d)
        raise ValueError("path "+newDir+" exists - wait 1 minute between backups")
    
    
    all_backups = getAllBackups(ad)
    latest = latestBackup(all_backups)

    if latest:
        if options.verbose:
            print("Making hard-link copy of previous backup..")
        os.system('cp -lR "' + latest + '" "' + unf + '"')

    if options.verbose:
        vstring = "v"
    else:
        vstring = ""

    fstring = ""
    if options.filter:
        for fil in options.filter:
            fstring = fstring + ' --filter="%s" ' % fil
        
    os.system('rsync -az'+vstring + fstring + ' --force --relative --hard-links --delete "'+src+'" "'+unf+'"')
    if options.verbose:
        print("now renaming", unf, "to", newDir)
    os.renames(unf, newDir)
    os.removedirs(d)

    if options.prune:
        pruning.prune(options,ad)
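The helpers used above (getLockDir, backupDir, checkUnfinished, getAllBackups, latestBackup) are defined elsewhere in this script. Hypothetical sketches of two of them, consistent with the "wait 1 minute between backups" error, which suggests minute-resolution backup names:

import os
import time

def backupDir(dest):
    # Hypothetical: backups are named with a minute-resolution timestamp,
    # so two runs within the same minute collide on the same path.
    return os.path.join(dest, time.strftime("%Y-%m-%d_%H-%M"))

def getLockDir(dest):
    # Hypothetical: directory creation is atomic, so a directory works as
    # a cheap lock; os.removedirs(d) above releases it.
    lock = os.path.join(dest, ".lock")
    os.mkdir(lock)
    return lock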
Example #5
def one_shot_pruning_experiment():
    """
    - For multiple trials:
        - Train original network and evaluate results
        - Find mask depending on weights of original network 
        - Apply mask to a new network and disable corresponding weights
        - Train the new pruned network and evaluate results
    """
    percentages, iterations = generate_percentages([0.0, 1.0, 1.0, 1.0],
                                                   SETTINGS['lower_bound'])
    #print(percentages)
    #print(iterations)
    trials = SETTINGS['trials']
    og_networks = list()
    tot_acc = np.zeros(iterations + 1)
    tot_loss = np.zeros(iterations + 1)
    tot_epoch = np.zeros(iterations + 1)

    # Training and evaluating the unpruned network over multiple trials
    print("Training and Evaluating OG networks")
    percents = [0.0, 1.0, 1.0, 1.0]

    for i in range(0, trials):
        print("TRIAL " + str(i + 1) + "/" + str(trials))
        og_networks.append(FC_NETWORK(use_earlyStopping=SETTINGS['use_es']))
        mask = prune(og_networks[i], percents)
        _, epoch = og_networks[i].fit_batch(x_train, y_train, mask,
                                            og_networks[i].weights_init,
                                            SETTINGS, x_test, y_test)
        test_loss, test_acc = og_networks[i].evaluate_model(x_test, y_test)
        tot_acc[0] += test_acc
        tot_loss[0] += test_loss
        tot_epoch[0] += epoch
        print(epoch)
    tot_acc[0] = float(tot_acc[0] / trials)
    tot_loss[0] = float(tot_loss[0] / trials)
    tot_epoch[0] = float(tot_epoch[0] / trials)
    # Training and evaluating pruned networks of different pruning rates over multiple trials
    for j in range(1, iterations + 1):
        print("Training and Evaluating pruned networks, iteration: " + str(j) +
              "/" + str(iterations))
        print("Percentage: " + str(percentages[j - 1]))
        for og in og_networks:
            mask = prune(og, percentages[j - 1])
            pruned_network = FC_NETWORK(use_earlyStopping=SETTINGS['use_es'])
            _, epoch = pruned_network.fit_batch(x_train, y_train, mask,
                                                og.weights_init, SETTINGS,
                                                x_test, y_test)
            print(epoch)
            test_loss, test_acc = pruned_network.evaluate_model(x_test, y_test)
            tot_acc[j] += test_acc
            tot_loss[j] += test_loss
            tot_epoch[j] += epoch

        tot_acc[j] = float(tot_acc[j] / trials)
        tot_loss[j] = float(tot_loss[j] / trials)
        tot_epoch[j] = float(tot_epoch[j] / trials)
        print(tot_epoch)

    print(tot_acc)
    print(tot_loss)
    print(tot_epoch)
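one_shot_pruning_experiment prints its averages but never saves them; persisting them in the same way as Example #3 is a small addition (the file name here is illustrative):

np.savez("data/oneshot_lenet.npz",
         tot_acc=tot_acc, tot_loss=tot_loss, tot_epoch=tot_epoch)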
Example #6
file_to_test = args.test_data

Making_the_tree.min_leaf_size = int(args.min_leaf_size)

if args.absolute:
    Making_the_tree.error_function = Making_the_tree.abs_error_of_data
    pruning.error_function = pruning.abs_error_of_data
else:
    Making_the_tree.error_function = Making_the_tree.mean_squared_error_of_data
    pruning.error_function = pruning.mean_squared_error_of_data

# Reading the data
train_data = pd.read_csv(file_to_train_from)
test_data = pd.read_csv(file_to_test)
output = test_data.columns.values[-1]

# Separating data for training and pruning
a, b = train_data.shape

tree_making_data = train_data.iloc[:int(2 * a / 3), :]
pruning_data = train_data.iloc[int(2 * a / 3):, :]

tree = make_tree(tree_making_data)
pruned_tree = prune(tree, pruning_data)

time_mean_square_error_df = 'min_leaf_size delta_t mean_squared_error'.split()

out_df = prediction(test_data, tree)[["index", "prediction"]]

out_df.columns = ["Id", "output"]
out_df.to_csv('unpruned_mean_squared_wine.csv', index=False)
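Note that pruned_tree is built above but only the unpruned tree's predictions are exported. The pruned tree can be written out the same way (the output file name is illustrative):

pruned_out_df = prediction(test_data, pruned_tree)[["index", "prediction"]]
pruned_out_df.columns = ["Id", "output"]
pruned_out_df.to_csv('pruned_mean_squared_wine.csv', index=False)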
Example #7
def crossValidate(data_set):
    empty_confusion_matrix = np.zeros((4, 4))
    empty_metrics = (empty_confusion_matrix, 0.0, 0.0, 0.0, 0.0)
    unpruned_results = [[empty_metrics for _ in range(9)] for _ in range(10)]
    pruned_results = [[empty_metrics for _ in range(9)] for _ in range(10)]

    # Divide by 80 because .size counts every data point (8 columns x 10
    # folds), not the number of rows per fold
    split_size = int(data_set.size / 80)

    startTime = time.time()
    #split out the testing data
    for i in range(10):

        split_set = np.split(data_set, [i * split_size, (i + 1) * split_size])
        test_set = split_set[1]
        set_without_test = np.concatenate((split_set[0], split_set[2]), axis=0)

        #split out the validation
        for j in range(9):

            split_training_set = np.split(
                set_without_test, [j * split_size, (j + 1) * split_size])
            validation_set = split_training_set[1]
            training_set = np.concatenate(
                (split_training_set[0], split_training_set[2]), axis=0)

            tree = getTree(training_set, 0)
            #print("Depth of tree:")
            #print(tree[1])
            tree = tree[0]
            unpruned_results[i][j] = evaluate(test_set, tree)
            pruned_tree = prune(tree, validation_set)
            pruned_results[i][j] = evaluate(test_set, pruned_tree)

            #stuff for printing nicely
            percent = (float(i * 9) + float(j + 1)) / 0.9
            timeElapsed = time.time() - startTime
            timeLeft = timeElapsed / percent * (100 - percent)
            print("\r\t",
                  round(percent, 2),
                  "%\t Time elapsed: ",
                  int(timeElapsed / 3600),
                  ":",
                  int((timeElapsed / 60) % 60),
                  ":",
                  int(timeElapsed % 60),
                  "\t Time left: ",
                  int(timeLeft / 3600),
                  ":",
                  int((timeLeft / 60) % 60),
                  ":",
                  int(timeLeft % 60),
                  end="      ",
                  sep="")

    average_unpruned_results = average_metrics(unpruned_results)
    average_pruned_results = average_metrics(pruned_results)

    print("Done:")
    print()
    print("\nResults before pruning:\n")
    print_metrics(average_unpruned_results)
    print("\nResults after pruning:\n")
    print_metrics(average_pruned_results)
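average_metrics is not shown on this page. A plausible sketch, given that each metrics tuple has the shape of empty_metrics above, i.e. (confusion_matrix, accuracy, precision, recall, f1):

def average_metrics(results):
    # Hypothetical reconstruction: average each component of the
    # (confusion_matrix, accuracy, precision, recall, f1) tuples over
    # the full 10 x 9 grid of folds.
    flat = [m for row in results for m in row]
    n = len(flat)
    avg_confusion = sum(m[0] for m in flat) / n
    avg_scalars = [sum(m[k] for m in flat) / n for k in range(1, 5)]
    return (avg_confusion, *avg_scalars)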
Example #8
def plot_tree(tree, validation_data):

    width = 2000

    max_nodes_in_layer = max([len(layer) for layer in tree.node_list])

    for layer in tree.node_list:
        for i, node in enumerate(layer):
            if node.children is not None:
                node.children[0].coord[0] = node.coord[0]  - width
                node.children[0].coord[1] = node.coord[1] - 20
                node.children[1].coord[0] = node.coord[0]  + width
                node.children[1].coord[1] = node.coord[1] - 20
            else:
                labels = [sample[-1] for sample in node.dataset]
                node.label = max(set(labels), key=labels.count)


        width = width * 0.5

    for layer in tree.node_list:
        for node in layer:
            node_x = node.coord[0]
            node_y = node.coord[1]
            if node.children is not None:
                wifi, value = node.split_attribute[1][2:]
                plt.text(node_x, node_y, f"Wifi {wifi} <= {value}?", size=10,
                         ha="center", va="center",
                         bbox=dict(boxstyle="round",
                                   ec=(0.2, 0.5, 0.5),
                                   fc=(0.2, 0.8, 0.8)))
                for child in node.children:
                    xt = [node_x,  child.coord[0]]
                    yt = [node_y, child.coord[1]]
                    plt.plot(xt, yt)
            else:
                label = node.label
                plt.text(node_x, node_y, f"{label}", size=10,
                         ha="center", va="center",
                         bbox=dict(boxstyle="round",
                                 ec=(1, 0.5, 0.5),
                                 fc=(1, 0.8, 0.8),))

    plt.show()


    # Prune the tree on validation data, then lay out and plot the pruned tree
    pruned_tree = prune(tree, validation_data)

    width = 2000
    for layer in pruned_tree.node_list:
        for i, node in enumerate(layer):
            if node.children is not None:
                node.children[0].coord[0] = node.coord[0] - width
                node.children[0].coord[1] = node.coord[1] - 1  # -1 depth
                node.children[1].coord[0] = node.coord[0] + width
                node.children[1].coord[1] = node.coord[1] - 1  # -1 depth
            else:
                labels = [sample[-1] for sample in node.dataset]
                node.label = max(set(labels), key=labels.count)
        width = width * 0.5


    for layer in pruned_tree.node_list:
        for node in layer:
            node_x = node.coord[0]
            node_y = node.coord[1]
            if node.children is not None:
                wifi, value = node.split_attribute[1][2:]
                plt.text(node_x, node_y, f"Wifi {wifi} <= {value}?", size=10,
                         ha="center", va="center",
                         bbox=dict(boxstyle="round",
                                   ec=(0.2, 0.5, 0.5),
                                   fc=(0.2, 0.8, 0.8)))

                for child in node.children:
                    xt = [node_x,  child.coord[0]]
                    yt = [node_y, child.coord[1]]
                    plt.plot(xt, yt)

            else:
                label = node.label
                plt.text(node_x, node_y, f"{label}", size=10,
                    ha="center", va="center",
                    bbox=dict(boxstyle="round",
                            ec=(1, 0.5, 0.5),
                            fc=(1, 0.8, 0.8),))

    plt.show()
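A typical call, assuming a trained tree and a held-out validation fold for the pruning pass:

plot_tree(tree, validation_data)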
Example #9
print("\tCreating tree from clean dataset...")
cleanTree, cleanDepth = build_decision_tree(cleanTreeSet, 0)
print("\tDepth of cleanTree:\t", cleanDepth)

print("\t------------------------------------")
print("\tCreating tree from noisy dataset...")
noisyTree, noisyDepth = build_decision_tree(noisyTreeSet, 0)
print("\tDepth of noisyTree:\t", noisyDepth)

# If requested, print visualized trees to file
# (original and pruned version)
if fileout != "":
    visualClean = visualize(cleanTree)
    visualNoisy = visualize(noisyTree)

    prunedClean = prune(cleanTree, cleanTest)
    prunedNoisy = prune(noisyTree, noisyTest)

    visualPrunedClean = visualize(prunedClean)
    visualPrunedNoisy = visualize(prunedNoisy)
    try:
        f = open(fileout, "w+")
        f.write("**** CLEAN TREE -> NOT pruned ****\n")
        f.write(visualClean)
        f.write("\n\n\n")
        f.write("**** CLEAN TREE -> pruned ****\n")
        f.write(visualPrunedClean)
        f.write("\n\n\n")
        f.write("**** NOISY TREE -> NOT pruned) ****\n")
        f.write(visualNoisy)
        f.write("\n\n\n")
Example #10
def main():
    parser = argparse.ArgumentParser(
        description='prune face recognition model')

    # Pruning arguments
    parser.add_argument('--lr',
                        default=0.01,
                        type=float,
                        help='retraining learning rate, usually 1/10 of the training rate')
    parser.add_argument('--weight_decay',
                        default=4e-5,
                        type=float,
                        help='weight decay')
    parser.add_argument('--save_pruned_model_root',
                        default='work_space/models/pruned_model/',
                        help='folder for the pruned model definition and saved files')
    parser.add_argument('--momentum', default=0.9, type=float)

    parser.add_argument('--epoch', default=30, type=int,
                        help='number of epochs to retrain after pruning')
    parser.add_argument('--head_path', default=None, help='path to the training head')
    parser.add_argument('--device', default='cuda:0')

    parser.add_argument('--print_freq',
                        type=int,
                        default=1,
                        help='print accuracy info every N steps')
    parser.add_argument('--save_model_pt',
                        default=False,
                        action='store_true',
                        help='whether to save a .pt file')

    parser.add_argument('--batch_size', type=int, default=256)
    parser.add_argument('--embedding_size', type=int, default=512)
    parser.add_argument('--pruned_save_model_path',
                        default='work_space/pruned_model',
                        help='path to save the pruned model')
    parser.add_argument('--sensitivity_csv_path',
                        default='work_space/sensitivity_data',
                        help='path to save the CSV from pruning sensitivity analysis')

    # The following arguments must be set for every run
    parser.add_argument(
        '--mode',
        default=None,
        choices=['prune', 'quantization', 'test', 'sa', 'finetune'],
        help='prune: pruning only; quantization: quantize the model; '
        'sa: sensitivity analysis; '
        'finetune: prune and then finetune')

    parser.add_argument('--best_model_path',
                        default=None,
                        help='path to the best trained model checkpoint, to be pruned')
    parser.add_argument('--test_root_path', default=None,
                        help='root path of the test set')
    parser.add_argument('--img_list_label_path',
                        default=None,
                        help='path to the test-set pair list')
    parser.add_argument('--model',
                        default=None,
                        choices=[
                            'mobilefacenet', 'resnet34', 'mobilefacenet_y2',
                            'resnet50', 'resnet100', 'mobilefacenet_lzc',
                            'mobilenetv3', 'resnet34_lzc', 'resnet50_imagenet',
                            'mobilefacenet_y2_ljt', 'shufflefacenet_v2_ljt',
                            'resnet_50_ljt', 'resnet_100_ljt'
                        ],
                        help='which model to prune')

    parser.add_argument('--is_save',
                        default=False,
                        action='store_true',
                        help='whether to save the model file')

    parser.add_argument('--from_data_parallel',
                        action='store_true',
                        default=False,
                        help='whether the model comes from multi-GPU (DataParallel) training')

    parser.add_argument(
        '--data_source',
        choices=['lfw', 'company', 'company_zkx'],
        default=None,
        help=
        "which test set to use: company -> zy's resnet50 and resnet100, company_zkx -> zkx's mobilefacenet_y2"
    )

    parser.add_argument('--fpgm',
                        action='store_true',
                        default=False,
                        help='whether to use geometric-median (FPGM) pruning')
    parser.add_argument('--hrank',
                        action='store_true',
                        default=False,
                        help='whether to use HRank pruning')
    parser.add_argument('--rank_path',
                        default='./work_space/rank_conv/',
                        help='HRank config files')

    parser.add_argument('--yaml_path',
                        default='yaml_file/auto_yaml.yaml',
                        help='pruning config file')

    parser.add_argument('--cal_flops_and_forward',
                        default=False,
                        action='store_true',
                        help='whether to measure FLOPs and forward-pass time')

    parser.add_argument('--test_batch_size', type=int, default=256)

    # Arguments required for quantization
    parser.add_argument('--quantize-mode',
                        type=str,
                        choices=['symmetric', 'asymmetric-signed', 'unsigned'],
                        default='symmetric',
                        help='quantization mode: map weights to a symmetric, '
                        'signed-asymmetric, or unsigned-asymmetric range')

    parser.add_argument('--fp16',
                        action='store_true',
                        default=False,
                        help='use half-precision (FP16) quantization; if set, '
                        'the modes above are ignored')

    parser.add_argument('--input_size', type=int, default=112,
                        help='input image size')

    parser.add_argument('--quantized_save_model_path',
                        default='work_space/quantized_model',
                        help='path to save the quantized model')

    # Arguments required for finetuning
    parser.add_argument('--pruned_checkpoint',
                        type=str,
                        default=None,
                        help='path to the pruned model checkpoint')
    parser.add_argument('--train_data_path',
                        type=str,
                        default=None,
                        help='path to the training set used for finetuning')
    parser.add_argument('--milestones',
                        type=str,
                        default='12,15,18',
                        help='epochs at which the learning rate drops')
    parser.add_argument('--train_batch_size',
                        type=int,
                        default=64,
                        help='training batch size')
    parser.add_argument('--pin_memory', type=bool, default=True)
    parser.add_argument('--num_workers', type=int, default=4)
    parser.add_argument('--work_path',
                        type=str,
                        default='work_space/finetune',
                        help='directory for files produced during training')
    parser.add_argument('--finetune_pruned_model',
                        action='store_true',
                        default=False,
                        help='finetune the pruned model')

    args = parser.parse_args()

    if args.mode == 'prune':
        prune(args)

    elif args.mode == 'sa':
        sensitivity_analysis(args)

    elif args.mode == 'quantization':
        quantization(args)

    elif args.mode == 'finetune':
        args.work_path = os.path.join(args.work_path, get_time())
        os.mkdir(args.work_path)

        args.log_path = os.path.join(args.work_path, 'log')
        args.save_path = os.path.join(args.work_path, 'save')
        args.model_path = os.path.join(args.work_path, 'model')

        os.mkdir(args.log_path)
        os.mkdir(args.save_path)
        os.mkdir(args.model_path)

        args.log_path = os.path.join(args.log_path, get_time())
        args.milestones = list(map(int, args.milestones.split(',')))

        learner = face_learner(args)
        learner.train(args)
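An illustrative invocation of the script above; the script name and checkpoint path are placeholders:

import subprocess

subprocess.run([
    "python", "prune_main.py",          # placeholder script name
    "--mode", "prune",
    "--model", "mobilefacenet",
    "--best_model_path", "work_space/models/best.pth",  # placeholder path
    "--is_save",
], check=True)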
Example #11
def main():
    parser = argparse.ArgumentParser(
        description='prune face recognition model')

    # Pruning arguments
    parser.add_argument('--lr',
                        default=0.01,
                        type=float,
                        help='retraining learning rate, usually 1/10 of the training rate')
    parser.add_argument('--weight_decay',
                        default=4e-5,
                        type=float,
                        help='weight decay')
    parser.add_argument('--save_pruned_model_root',
                        default='work_space/models/pruned_model/',
                        help='folder for the pruned model definition and saved files')
    parser.add_argument('--momentum', default=0.9, type=float)

    parser.add_argument('--epoch', default=30, type=int,
                        help='number of epochs to retrain after pruning')
    parser.add_argument(
        '--head_path',
        default='work_space/pre_head/head/10k/head_2019-05-17-05-02_accuracy_0'
        '.0_step_111910_None.pth',
        help='path to the training head')
    parser.add_argument('--device', default='cuda:0')

    parser.add_argument('--print_freq',
                        type=int,
                        default=1,
                        help='print accuracy info every N steps')
    parser.add_argument('--save_model_pt',
                        default=False,
                        action='store_true',
                        help='whether to save a .pt file')

    parser.add_argument('--batch_size', type=int, default=100)
    parser.add_argument('--embedding_size', type=int, default=512)
    parser.add_argument('--pruned_save_model_path',
                        default='work_space/pruned_model',
                        help='path to save the pruned model')
    parser.add_argument('--sensitivity_csv_path',
                        default='work_space/sensitivity_data',
                        help='path to save the CSV from pruning sensitivity analysis')

    # The following arguments must be set for every run
    parser.add_argument(
        '--mode',
        default=None,
        choices=['prune', 'quantization', 'test', 'sa', 'finetune'],
        help='prune: pruning only; quantization: quantize the model; '
        'sa: sensitivity analysis; '
        'finetune: prune and then finetune')

    parser.add_argument('--best_model_path',
                        default=None,
                        help='path to the best trained model checkpoint, to be pruned')
    parser.add_argument('--test_root_path', default=None,
                        help='root path of the test set')
    parser.add_argument('--img_list_label_path',
                        default=None,
                        help='path to the test-set pair list')
    parser.add_argument('--model',
                        default=None,
                        choices=[
                            'mobilefacenet', 'resnet34', 'mobilefacenet_y2',
                            'resnet50', 'resnet100', 'mobilefacenet_lzc',
                            'mobilenetv3', 'resnet34_lzc'
                        ],
                        help='which model to prune')

    parser.add_argument('--is_save',
                        default=False,
                        action='store_true',
                        help='whether to save the model file')

    parser.add_argument('--from_data_parallel',
                        action='store_true',
                        default=False,
                        help='whether the model comes from multi-GPU (DataParallel) training')

    parser.add_argument(
        '--data_source',
        choices=['lfw', 'company', 'company_zkx'],
        default=None,
        help=
        "which test set to use: company -> zy's resnet50 and resnet100, company_zkx -> zkx's mobilefacenet_y2"
    )

    parser.add_argument('--fpgm',
                        action='store_true',
                        default=False,
                        help='whether to use geometric-median (FPGM) pruning')
    parser.add_argument('--hrank',
                        action='store_true',
                        default=False,
                        help='whether to use HRank pruning')
    parser.add_argument('--rank_path',
                        default='./work_space/rank_conv/',
                        help='HRank config files')

    parser.add_argument('--yaml_path',
                        default='yaml_file/auto_yaml.yaml',
                        help='pruning config file')

    parser.add_argument('--cal_flops_and_forward',
                        default=False,
                        action='store_true',
                        help='whether to measure FLOPs and forward-pass time')

    parser.add_argument('--test_batch_size', type=int, default=256)

    # Arguments required for quantization
    parser.add_argument('--quantize-mode',
                        type=str,
                        choices=['symmetric', 'asymmetric-signed', 'unsigned'],
                        default='symmetric',
                        help='quantization mode: map weights to a symmetric, '
                        'signed-asymmetric, or unsigned-asymmetric range')

    parser.add_argument('--fp16',
                        action='store_true',
                        default=False,
                        help='use half-precision (FP16) quantization; if set, '
                        'the modes above are ignored')

    parser.add_argument('--input_size', type=int, default=112,
                        help='input image size')

    parser.add_argument('--quantized_save_model_path',
                        default='work_space/quantized_model',
                        help='path to save the quantized model')

    args = parser.parse_args()

    if args.mode == 'prune':
        prune(args)

    elif args.mode == 'sa':
        sensitivity_analysis(args)

    elif args.mode == 'quantization':
        quantization(args)

    elif args.mode == 'test':
        pass