Example #1
              " | CURRICULUM LEARNING: " + str(args.curriculum_learning) +
              " | MODEL: " + str(args.type_model))
        print()

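        # Reference-model wrapper: holds the network, its binary train/val data,
        # and the training / loading helpers used below.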
        nn_model_ref = NN_Model_Ref(args, writer, device, rng, path_save_model,
                                    cipher, creator_data_binary,
                                    path_save_model_train)

        if args.retain_model_gohr_ref:
            nn_model_ref.train_general(name_input)
        else:
            try:
                if args.finetunning:
                    # Load the pretrained reference network, then fine-tune it.
                    nn_model_ref.load_nn()
                    nn_model_ref.train_from_scractch(name_input + "fine-tune")
                    #nn_model_ref.eval(["val"])
                else:
                    # No retraining: just load the pretrained reference network.
                    nn_model_ref.load_nn()
            except Exception:
                print("ERROR")
                print("NO MODEL AVAILABLE FOR THIS CONFIG")
                print("CHANGE ARGUMENT retain_model_gohr_ref")
                print()
                sys.exit(1)

        if args.create_new_data_for_ToT and args.create_new_data_for_classifier:
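            # Free the reference model's cached datasets; fresh data will be
            # generated for the ToT and for the classifier.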
            del nn_model_ref.X_train_nn_binaire, nn_model_ref.X_val_nn_binaire, nn_model_ref.Y_train_nn_binaire, nn_model_ref.Y_val_nn_binaire
            del nn_model_ref.c0l_train_nn, nn_model_ref.c0l_val_nn, nn_model_ref.c0r_train_nn, nn_model_ref.c0r_val_nn
            del nn_model_ref.c1l_train_nn, nn_model_ref.c1l_val_nn, nn_model_ref.c1r_train_nn, nn_model_ref.c1r_val_nn
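
# Swap the reference network for a linear model sized to the flattened feature inputs.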
net = NN_linear(args, X_train_f.shape[1]).to(device)

nn_model_ref.net = net
nn_model_ref.X_train_nn_binaire = X_train_f
nn_model_ref.X_val_nn_binaire = X_val_f
#nn_model_ref.Y_train_nn_binaire = X_train_f
#nn_model_ref.Y_val_nn_binaire = X_val_f
"""
args.load_nn_path = "./results/create_synth_masks_v2/speck/5/ctdata0l^ctdata1l_ctdata0r^ctdata1r^ctdata0l^ctdata1l_ctdata0l^ctdata0r_ctdata1l^ctdata1r/2020_07_21_17_26_59_603174/0.9966710913033485_bestacc.pth"
nn_model_ref.net.load_state_dict(torch.load(args.load_nn_path,
                map_location=device)['state_dict'], strict=False)
nn_model_ref.net.to(device)
nn_model_ref.net.eval()
"""
nn_model_ref.train_from_scractch("AE")

# Sweep several global sparsity levels and collect the parameters to prune.
for global_sparsity in [0, 0.2, 0.4]:
    print(global_sparsity)
    flag2 = True
    acc_retain = []
    parameters_to_prune = []
    for name, module in nn_model_ref.net.named_modules():
        # Skip the root module (empty name), the container lists, and any
        # layer whose name matches args.layers_NOT_to_prune.
        if not len(name) or name in ["layers_batch", "layers_conv"]:
            continue
        if any(layer_forbidden in name for layer_forbidden in args.layers_NOT_to_prune):
            continue
        parameters_to_prune.append((module, 'weight'))
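
    # Sketch, not part of the original snippet: the collected (module, 'weight')
    # pairs would typically be pruned globally at the chosen sparsity level,
    # e.g. with PyTorch's pruning utilities (assuming torch.nn.utils.prune is available).
    import torch.nn.utils.prune as prune
    prune.global_unstructured(
        parameters_to_prune,
        pruning_method=prune.L1Unstructured,
        amount=global_sparsity,
    )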