Example No. 1
def prune_and_retrain(thresh):
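    """Prune the currently loaded model using the ranking method selected by the
    surrounding script's globals (`method`, `args`, `prune_bool`, ...) and the
    per-layer thresholds in `thresh`, then optionally retrain the remaining
    weights."""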
    load_model(False)
    if prune_bool:
        ############################
        # READ THE RANKS
        print("\nPruning the model\n")
        print("architecture for pruning: ", args.arch)
        if method == 'switch':
            epochs_num = 1
            num_samps_for_switch = args.switch_samps
            ranks_method = args.ranks_method
            #path =
            ########################################################
            # if ranks_method == 'shapley':
            #     combinationss = []
            #     shapley_file = open(
            #         "/home/user/Dropbox/Current_research/python_tests/results_shapley/combinations/94.34/zeroing_0.2val/shapley.txt")
            #     for line in shapley_file:
            #         line = line.strip()[1:-2]
            #         nums = line.split(",")
            #         nums_int = [int(i) for i in nums]
            #         combinationss.append(nums_int)
            #######################################################
            if ranks_method == 'integral':
                print(ranks_method)
                if args.switch_trainranks:
                    print("\nTraining switches\n")
                    ranks = script_vgg("switch_" + ranks_method, epochs_num,
                                       num_samps_for_switch)
                    combinationss = ranks['combinationss']
                else:
                    print("\nLoading switches\n")
                    ranks_path = path_main + "/methods/switches/VGG/integral/switch_data_cifar_integral_samps_%i_epochs_%i.npy" % (
                        args.switch_samps, args.switch_epochs)
                    combinationss = list(
                        np.load(ranks_path,
                                allow_pickle=True).item()['combinationss'])
            #######################################################
            elif ranks_method == 'point':
                print(ranks_method)
                if args.switch_trainranks:
                    print("Training switches\n")
                    ranks = script_vgg("switch_" + ranks_method, epochs_num)
                    combinationss = ranks['combinationss']
                else:
                    print("Loading switches")
                    ranks_path = path_main + '/methods/switches/VGG/point/switch_data_cifar_point_epochs_%i.npy' % (
                        args.switch_epochs)
                    combinationss = list(
                        np.load(ranks_path,
                                allow_pickle=True).item()['combinationss'])
            # the entries from position thresh[i] onward in each ranking are the
            # channel indices that will be zeroed (pruned); the first thresh[i]
            # entries are kept
            for i in range(len(combinationss)):
                combinationss[i] = torch.LongTensor(
                    combinationss[i][thresh[i]:].copy())

        #################################################################
        elif method == 'l1' or method == 'l2':
            magnitude_rank.setup()
            combinationss = magnitude_rank.get_ranks(method, net)
            # the first thresh[i] entries of each ranking (the channels ranked
            # worst by the magnitude criterion) are the ones that will be zeroed (pruned)
            for i in range(len(combinationss)):
                combinationss[i] = torch.LongTensor(
                    combinationss[i][:thresh[i]].copy())
            print(combinationss[1])

        ###############
        elif method == 'shapley':
            try:
                combinationss = shapley_rank.shapley_rank(
                    testval, net, "VGG",
                    os.path.split(model2load)[1], args.dataset, args.load_file,
                    args.k_num, args.shap_method, args.shap_sample_num)
            except KeyboardInterrupt:
                print('Interrupted')
                shapley_rank.file_check()
                try:
                    sys.exit(0)
                except SystemExit:
                    os._exit(0)

        ####################################################################
        elif method == 'fisher':
            # During finetuning we accumulate the gradient information that we add
            # for each batch; this gradient info is used to construct the ranking.
            net.module.reset_fisher()
            finetune()
            combinationss = []
            for i in range(14):
                fisher_rank = torch.argsort(net.module.running_fisher[i],
                                            descending=True)
                combinationss.append(fisher_rank.detach().cpu())
            # the first thresh[i] entries of each Fisher ranking are the channel
            # indices that will be zeroed (pruned)
            for i in range(len(combinationss)):
                combinationss[i] = torch.LongTensor(
                    combinationss[i][:thresh[i]])
            print(combinationss[1])
        # PRINT THE PRUNED ARCHITECTURE
        remaining = []
        for i in range(len(combinationss)):
            print(cfg['VGG15'][i], len(combinationss[i]))
            remaining.append(int(cfg['VGG15'][i]) - len(combinationss[i]))
        print(remaining)

        #############
        # PRUNE

        def zero_params():
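            # Zero the pruned channels of every conv layer and of the matching
            # batch-norm parameters/statistics. `it` counts conv weight tensors
            # seen so far, so `combinationss[it - 1]` holds the ranking for the
            # current conv/bn pair.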
            it = 0
            for name, param in net.state_dict().items():
                # print(name, param.shape)
                if "module.c" in name and "weight" in name:
                    it += 1
                    param.data[combinationss[it - 1]] = 0
                    # print(param.data)
                if "module.c" in name and "bias" in name:
                    param.data[combinationss[it - 1]] = 0
                    # print(param.data)
                if ("bn" in name) and ("weight" in name):
                    param.data[combinationss[it - 1]] = 0
                if ("bn" in name) and ("bias" in name):
                    param.data[combinationss[it - 1]] = 0
                if ("bn" in name) and ("running_mean" in name):
                    param.data[combinationss[it - 1]] = 0
                if ("bn" in name) and ("running_var" in name):
                    param.data[combinationss[it - 1]] = 0

        zero_params()

        print("After pruning")
        test(-1)

        ######################
        # GRAD
        print("Gradients for retraining")

        # def gradi1(module):
        #     module[combinationss[0]] = 0
        #     # print(module[21])
        def gradi_new(combs_num):
            # Backward hook factory: zeroes the gradient entries of the pruned
            # channels for layer `combs_num`, so those parameters are never
            # updated during retraining.
            def hook(grad):
                grad[combinationss[combs_num]] = 0

            return hook

        # Register the hooks on every pruned layer: conv/batch-norm pairs
        # c1..c13 / bn1..bn13 use ranking indices 0..12, and the first linear
        # layer l1 uses index 13.
        for i in range(13):
            conv = getattr(net.module, "c%i" % (i + 1))
            bn = getattr(net.module, "bn%i" % (i + 1))
            conv.weight.register_hook(gradi_new(i))
            conv.bias.register_hook(gradi_new(i))
            bn.weight.register_hook(gradi_new(i))
            bn.bias.register_hook(gradi_new(i))
        net.module.l1.weight.register_hook(gradi_new(13))
        net.module.l1.bias.register_hook(gradi_new(13))

    #######################################################
    # RETRAIN
    if retrain_bool:
        print("\nRetraining\n")
        net.train()
        stop = 0
        epoch = 0
        best_accuracy = 0
        early_stopping = 100
        optimizer = optim.SGD(net.parameters(),
                              lr=0.0001,
                              momentum=0.9,
                              weight_decay=5e-4)
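        # `stop` counts consecutive epochs without an accuracy improvement;
        # retraining halts once `early_stopping` such epochs have passed.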
        while (stop < early_stopping):
            epoch = epoch + 1
            for i, data in enumerate(trainloader):
                inputs, labels = data
                inputs, labels = inputs.to(device), labels.to(device)
                optimizer.zero_grad()
                outputs = net(inputs)
                loss = criterion(outputs, labels)
                loss.backward()
                #net.module.c2.weight.grad  #  just check the gradient
                optimizer.step()
                # net.c1.weight.data[1] = 0  # instead of hook
                # net.c1.bias.data[1] = 0  # instead of hook

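                # re-apply the pruning mask after the update so the pruned
                # parameters stay exactly zero throughout retraining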
                if args.prune:
                    zero_params()

            print(loss.item())
            accuracy = test(-1)
            # print(net.module.c2.weight.data)
            print("Epoch " + str(epoch) + " ended.")
            if (accuracy <= best_accuracy):
                stop = stop + 1
            else:
                if accuracy > 90.5:
                    # save_checkpoint compares accuracy with best_accuracy again internally
                    best_accuracy = save_checkpoint(epoch, accuracy,
                                                    best_accuracy, remaining)
                print("Best updated")
                stop = 0
        print(loss.item())
        accuracy = test(-1)


def get_ranks(method, path_checkpoint):
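    """Return `combinationss`, a per-layer list of channel rankings for the
    Lenet model, computed with the requested ranking method."""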
    print(f"Ranking method {method}")

    if method == 'random':
        combinationss = [
            np.random.permutation(nodesNum1),
            np.random.permutation(nodesNum2),
            np.random.permutation(nodesFc1),
            np.random.permutation(nodesFc2)
        ]

    elif method == 'fisher':
        finetune()
        combinationss = []
        for i in range(4):
            fisher_rank = np.argsort(
                net.running_fisher[i].detach().cpu().numpy())[::-1]
            combinationss.append(fisher_rank)

    elif method == 'shapley':
        load_file = args.load_file

        try:
            combinationss = shapley_rank.shapley_rank(
                evaluate, net, "Lenet",
                os.path.split(path_checkpoint)[1], dataset, load_file,
                args.k_num, args.shap_method, args.shap_sample_num,
                args.adding, args.layer)
        except KeyboardInterrupt:
            print('Interrupted')
            shapley_rank.file_check("combin")
            try:
                sys.exit(0)
            except SystemExit:
                os._exit(0)

    elif method == "switch_integral":

        #train or load
        getranks_method = args.switch_comb
        switch_data = {}
        switch_data['combinationss'] = []
        switch_data['switches'] = []
        num_samps_for_switch = args.switch_samps
        print("integral evaluation")
        epochs_num = 3
        file_path = os.path.join(
            path_main,
            'results_switch/results/switch_data_%s_9927_integral_samps_%s_epochs_%i.npy'
            % (dataset, str(num_samps_for_switch), epochs_num))
        if getranks_method == 'train':
            for layer in ["c1", "c3", "c5", "f6"]:
                best_accuracy, epoch, best_model, S = run_experiment_integral(
                    epochs_num, layer, 10, 20, 100, 25, num_samps_for_switch,
                    path)
                print(
                    "Rank for switches from most important/largest to smallest after %s epochs"
                    % str(epochs_num))
                print(S)
                print("max: %.4f, min: %.4f" % (torch.max(S), torch.min(S)))
                ranks_sorted = np.argsort(S.cpu().detach().numpy())[::-1]
                print(",".join(map(str, ranks_sorted)))
                switch_data['combinationss'].append(ranks_sorted)
                switch_data['switches'].append(S.cpu().detach().numpy())
            print('*' * 30)
            print(switch_data['combinationss'])
            combinationss = switch_data['combinationss']
            np.save(file_path, switch_data)
        elif getranks_method == 'load':
            combinationss = list(
                np.load(file_path, allow_pickle=True).item()['combinationss'])

    elif method == "switch_point":
        getranks_method = args.switch_comb
        switch_data = {}
        switch_data['combinationss'] = []
        switch_data['switches'] = []
        epochs_num = 1
        path_switches = "../methods/switches/Lenet"
        if getranks_method == 'train':
            for layer in ["c1", "c3", "c5", "f6"]:
                print(f"\nLayer: {layer}")
                best_accuracy, epoch, best_model, S = run_experiment_pointest(
                    epochs_num, layer, 10, 20, 100, 25, path_checkpoint, args)
                print(
                    "Rank for switches from most important/largest to smallest after %s epochs"
                    % str(epochs_num))
                print(S)
                print("max: %.4f, min: %.4f" % (torch.max(S), torch.min(S)))
                ranks_sorted = np.argsort(S.cpu().detach().numpy())[::-1]
                print(",".join(map(str, ranks_sorted)))
                switch_data['combinationss'].append(ranks_sorted)
                switch_data['switches'].append(S.cpu().detach().numpy())
            print(switch_data['combinationss'])
            combinationss = switch_data['combinationss']
            # save switches
            if not os.path.exists(path_switches):
                os.makedirs(path_switches)
            file_path = os.path.join(
                path_switches,
                f"switches_{dataset}_{epochs_num}_{path_checkpoint[-5:]}.npy")
            np.save(file_path, switch_data)
        elif getranks_method == 'load':
            switches_files = os.listdir(path_switches)
            for file in switches_files:
                if (file[-9:-4] == path_checkpoint[-5:]):
                    path_switches_file = os.path.join(path_switches, file)
                    combinationss = list(
                        np.load(path_switches_file,
                                allow_pickle=True).item()['combinationss'])

    elif method == "switch_point_multiple":
        file_path = os.path.join(
            path_main, 'results_switch/results/combinations_multiple_9032.npy')
        combinationss = list(np.load(file_path, allow_pickle=True))

    else:
        combinationss = magnitude_rank.get_ranks(method, net)

    return combinationss
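

# Hypothetical usage sketch (not part of the original code): both functions rely
# on module-level globals (net, args, trainloader, criterion, ...) that the full
# scripts configure, so the call below only illustrates the intended entry point
# and the assumed shape of `thresh` (one threshold per ranked layer).
if __name__ == "__main__":
    # VGG pruning/retraining: 13 conv layers plus the first linear layer.
    prune_and_retrain([10] * 14)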