import json
import os

import torch

# Project-local modules (trn, DistillationLoss, CifarEncoder, StudentEncoder,
# DistillNet, validate_distillation) are imported elsewhere in this file.


def train_student_encoded(train_set, val_set, teacher, params):
    EXP_NO = params.exp_no
    EXP_ID = params.exp_id
    STUDENT_TEMP = params.student_temp
    TEACHER_TEMP = params.teacher_temp
    ALPHA = params.alpha
    N_RULES = params.n_rules

    # Create ./models/<exp_id>/<exp_no> if it does not exist yet
    ROOT = "./models/{}/{}".format(EXP_ID, EXP_NO)
    os.makedirs(ROOT, exist_ok=True)

    # Save params
    with open(ROOT + "/params", "w") as f:
        json.dump(vars(params), f)

    STUDENT_MODEL_PATH = ROOT + "/student"

    train_opts = trn.TrainingOptions()
    train_opts.optimizer_type = trn.OptimizerType.Adam
    train_opts.learning_rate = 0.01
    train_opts.learning_rate_drop_type = trn.SchedulerType.StepLr
    train_opts.learning_rate_update_by_step = False  # Update at every epoch
    train_opts.learning_rate_drop_factor = 0.5  # Halve the learning rate
    train_opts.learning_rate_drop_step_count = params.learn_drop_epochs
    train_opts.batch_size = 128
    train_opts.n_epochs = params.n_epochs
    train_opts.use_gpu = True
    train_opts.custom_validation_func = validate_distillation
    train_opts.save_model = False
    train_opts.verbose_freq = 100
    train_opts.weight_decay = 1e-8
    train_opts.shuffle_data = True
    train_opts.regularization_method = None

    # Define loss
    dist_loss = DistillationLoss(STUDENT_TEMP, TEACHER_TEMP, ALPHA)

    # Build the student on top of a pre-trained CIFAR encoder
    encoder = CifarEncoder(n_dims=params.n_inputs)
    encoder.load_state_dict(torch.load("Networks/cifar_encoder"))
    student = StudentEncoder(n_memberships=N_RULES,
                             n_inputs=params.n_inputs,
                             n_outputs=10,
                             learnable_memberships=params.learn_ants,
                             encoder=encoder,
                             fuzzy_type=params.fuzzy_type,
                             use_sigma_scale=params.use_sigma_scale,
                             use_height_scale=params.use_height_scale)

    # Initialize the student's memberships from a large batch of training data
    print("Initializing Student")
    train_set.shuffle_data()
    init_data, init_labels = train_set.get_batch(60000, 0, "cpu")
    student.initialize(init_data)
    # student.load_state_dict(torch.load(STUDENT_MODEL_PATH))
    print("Done Initializing Student")

    # Debugging aids for inspecting the fuzzy layer:
    # student.fuzzy_layer.draw(5)
    # plt.plot(student.feature_extraction(init_data)[:, 1:2],
    #          np.zeros(init_data.shape[0]), 'o')
    # plt.show()

    device = "cuda:" + params.gpu_no
    student.to(device)

    # Define distillation network and train
    dist_net = DistillNet(student, teacher)
    trainer = trn.Trainer(dist_net, train_opts)
    results = trainer.train(dist_loss, train_set, val_set, is_classification=True)

    torch.save(student.state_dict(), STUDENT_MODEL_PATH)
    trn.save_train_info(results, STUDENT_MODEL_PATH + "_train_info")
    return student
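
# For reference, a minimal sketch of what DistillationLoss above might compute,
# assuming the standard Hinton-style formulation: a temperature-softened KL term
# between student and teacher logits, blended with hard-label cross entropy via
# alpha. The real class is defined elsewhere in this repo; the name and
# signature below are illustrative only.
import torch.nn.functional as F
from torch import nn


class DistillationLossSketch(nn.Module):
    def __init__(self, student_temp, teacher_temp, alpha):
        super().__init__()
        self.student_temp = student_temp
        self.teacher_temp = teacher_temp
        self.alpha = alpha

    def forward(self, student_logits, teacher_logits, labels):
        # Soften each distribution with its respective temperature.
        log_student = F.log_softmax(student_logits / self.student_temp, dim=1)
        soft_teacher = F.softmax(teacher_logits / self.teacher_temp, dim=1)
        soft_loss = F.kl_div(log_student, soft_teacher, reduction="batchmean")
        hard_loss = F.cross_entropy(student_logits, labels)
        # alpha trades off matching the teacher against fitting the labels;
        # some recipes additionally scale the soft term by temperature squared.
        return self.alpha * soft_loss + (1.0 - self.alpha) * hard_loss
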
# Fragment of a parallel training routine. The setup that precedes it
# (TrainingOptions, temperatures, student construction, STUDENT_MODEL_PATH)
# mirrors train_student_encoded above and is omitted here; TEST is a
# module-level flag.
dist_loss = DistillationLoss(STUDENT_TEMP, TEACHER_TEMP, ALPHA)  # TODO: Search for the correct values from the paper

# Initialize student (a batch size of -1 requests the full dataset)
init_data, init_labels = train_set.get_batch(-1, 0, "cpu")
student.initialize(init_data, init_labels)
student.to("cuda:0")

# Define distillation network and train
dist_net = DistillNet(student, teacher)
trainer = trn.Trainer(dist_net, train_opts)
results = trainer.train(dist_loss, train_set, val_set, is_classification=True)

torch.save(student.state_dict(), STUDENT_MODEL_PATH)
trn.save_train_info(results, STUDENT_MODEL_PATH + "_train_info")

if TEST:
    # Reload the saved checkpoint and compare teacher vs. student accuracy
    student.load_state_dict(torch.load(STUDENT_MODEL_PATH))

    teacher.to("cuda:0")
    test_batch, test_labels = test_set.get_batch(-1, 0, "cuda:0")
    teacher_preds = teacher.forward(test_batch.float())
    teacher_acc = validate_distillation([teacher_preds], test_labels)
    print("Teacher Acc:{}".format(teacher_acc))

    # Fit the student's PCA on the training data before evaluating it
    student.to("cuda:0")
    init_data, init_labels = train_set.get_batch(-1, 0, "cuda:0")
    student.fit_pca(init_data, init_labels)
    student_pred = student.forward(test_batch.float())
    student_acc = validate_distillation([student_pred], test_labels)
    print("Student Acc:{}".format(student_acc))
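
# For reference, a minimal sketch of what validate_distillation might compute,
# assuming it reports top-1 accuracy of the first network output against
# integer class labels (consistent with the single-element lists passed above).
# The real helper is defined elsewhere in this repo; the name below is
# illustrative only.
def validate_distillation_sketch(outputs, labels):
    # outputs: list whose first element is an (N, n_classes) logits tensor.
    preds = torch.argmax(outputs[0], dim=1)
    return (preds == labels).float().mean().item()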