" | CURRICULUM LEARNING: " + str(args.curriculum_learning) + " | MODEL: " + str(args.type_model)) print() nn_model_ref = NN_Model_Ref(args, writer, device, rng, path_save_model, cipher, creator_data_binary, path_save_model_train) if args.retain_model_gohr_ref: nn_model_ref.train_general(name_input) else: #nn_model_ref.load_nn() try: if args.finetunning: nn_model_ref.load_nn() nn_model_ref.train_from_scractch(name_input + "fine-tune") #nn_model_ref.eval(["val"]) else: nn_model_ref.load_nn() except: print("ERROR") print("NO MODEL AVALAIBLE FOR THIS CONFIG") print("CHANGE ARGUMENT retain_model_gohr_ref") print() sys.exit(1) if args.create_new_data_for_ToT and args.create_new_data_for_classifier: del nn_model_ref.X_train_nn_binaire, nn_model_ref.X_val_nn_binaire, nn_model_ref.Y_train_nn_binaire, nn_model_ref.Y_val_nn_binaire del nn_model_ref.c0l_train_nn, nn_model_ref.c0l_val_nn, nn_model_ref.c0r_train_nn, nn_model_ref.c0r_val_nn del nn_model_ref.c1l_train_nn, nn_model_ref.c1l_val_nn, nn_model_ref.c1r_train_nn, nn_model_ref.c1r_val_nn
# Replace the wrapper's network with a linear model whose input width is the
# number of columns of X_train_f, and point the wrapper at the new features.
net = NN_linear(args, X_train_f.shape[1]).to(device)
nn_model_ref.net = net
nn_model_ref.X_train_nn_binaire = X_train_f
nn_model_ref.X_val_nn_binaire = X_val_f
# Label replacement is intentionally left disabled:
#nn_model_ref.Y_train_nn_binaire = X_train_f
#nn_model_ref.Y_val_nn_binaire = X_val_f

# Inert string literal (never executed): a disabled snippet that loaded a
# previously saved checkpoint instead of training.
""" args.load_nn_path = "./results/create_synth_masks_v2/speck/5/ctdata0l^ctdata1l_ctdata0r^ctdata1r^ctdata0l^ctdata1l_ctdata0l^ctdata0r_ctdata1l^ctdata1r/2020_07_21_17_26_59_603174/0.9966710913033485_bestacc.pth" nn_model_ref.net.load_state_dict(torch.load(args.load_nn_path, map_location=device)['state_dict'], strict=False) nn_model_ref.net.to(device) nn_model_ref.net.eval() """

# Train the freshly attached network; "AE" is the name handed to the
# training routine (presumably a run label -- confirm).
nn_model_ref.train_from_scractch("AE")

for global_sparsity in (0, 0.2, 0.4):
    print(global_sparsity)
    parameters_to_prune = []
    acc_retain = []
    flag2 = True
    # Collect (module, 'weight') pairs for every named sub-module that is
    # neither one of the two excluded containers nor matched by a substring
    # listed in args.layers_NOT_to_prune.
    for name, module in nn_model_ref.net.named_modules():
        # An empty name is skipped (named_modules reports the top-level
        # module that way).
        if name and name not in ("layers_batch", "layers_conv"):
            flag = not any(
                layer_forbidden in name
                for layer_forbidden in args.layers_NOT_to_prune
            )
            if flag:
                parameters_to_prune.append((module, 'weight'))