Example #1
import os
import json

import torch

# trn (training utilities), DistillationLoss, CifarEncoder, StudentEncoder,
# DistillNet and validate_distillation are project-specific and assumed to be
# imported from the surrounding repository.


def train_student_encoded(train_set, val_set, teacher, params):
    EXP_NO = params.exp_no
    EXP_ID = params.exp_id
    STUDENT_TEMP = params.student_temp
    TEACHER_TEMP = params.teacher_temp
    ALPHA = params.alpha
    N_RULES = params.n_rules

    ROOT = "./models/{}".format(EXP_ID)
    if not os.path.exists(ROOT):
        os.mkdir(ROOT)
    ROOT = "./models/{}/{}".format(EXP_ID, EXP_NO)
    if not os.path.exists(ROOT):
        os.mkdir(ROOT)

    # Save Params
    with open(ROOT + "/params", "w") as f:
        json.dump(vars(args), f)
    STUDENT_MODEL_PATH = ROOT + "/student"

    train_opts = trn.TrainingOptions()
    train_opts.optimizer_type = trn.OptimizerType.Adam
    train_opts.learning_rate = 0.01
    train_opts.learning_rate_drop_type = trn.SchedulerType.StepLr
    train_opts.learning_rate_update_by_step = False  # Update at every epoch
    train_opts.learning_rate_drop_factor = 0.5  # Halve the learning rate
    train_opts.learning_rate_drop_step_count = params.learn_drop_epochs
    train_opts.batch_size = 128
    train_opts.n_epochs = params.n_epochs
    train_opts.use_gpu = True
    train_opts.custom_validation_func = validate_distillation
    train_opts.save_model = False
    train_opts.verbose_freq = 100
    train_opts.weight_decay = 1e-8
    train_opts.shuffle_data = True
    train_opts.regularization_method = None
    # Define loss
    dist_loss = DistillationLoss(STUDENT_TEMP, TEACHER_TEMP, ALPHA)

    # Pre-trained CIFAR encoder, passed to the student below
    encoder = CifarEncoder(n_dims=params.n_inputs)
    encoder.load_state_dict(torch.load("Networks/cifar_encoder"))
    student = StudentEncoder(n_memberships=N_RULES,
                             n_inputs=params.n_inputs,
                             n_outputs=10,
                             learnable_memberships=params.learn_ants,
                             encoder=encoder,
                             fuzzy_type=params.fuzzy_type,
                             use_sigma_scale=params.use_sigma_scale,
                             use_height_scale=params.use_height_scale)
    # Initialize the student from a large sample of the training data
    print("Initializing Student")
    train_set.shuffle_data()
    # Pull a large batch (effectively the whole training set) onto the CPU
    init_data, init_labels = train_set.get_batch(60000, 0, "cpu")
    student.initialize(init_data)
    # To resume a previous run, load the saved checkpoint instead:
    # student.load_state_dict(torch.load(STUDENT_MODEL_PATH))
    print("Done Initializing Student")
    # Optional debugging/visualization of the learned memberships:
    # student.fuzzy_layer.draw(5)
    # plt.plot(student.feature_extraction(init_data)[:, 1:2], np.zeros(init_data.shape[0]), 'o')
    # plt.show()
    device = "cuda:" + args.gpu_no
    student.to(device)
    # Define distillation network
    dist_net = DistillNet(student, teacher)
    trainer = trn.Trainer(dist_net, train_opts)
    results = trainer.train(dist_loss,
                            train_set,
                            val_set,
                            is_classification=True)
    torch.save(student.state_dict(), STUDENT_MODEL_PATH)
    trn.save_train_info(results, STUDENT_MODEL_PATH + "_train_info")

    return student
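Both examples build a DistillationLoss from STUDENT_TEMP, TEACHER_TEMP and ALPHA, but its implementation is not shown on this page. A minimal sketch of what such a loss commonly looks like (standard Hinton-style distillation; the class name, exact formulation and call signature below are assumptions, not the project's code) is:

import torch
import torch.nn.functional as F

class DistillationLossSketch(torch.nn.Module):
    """Hypothetical stand-in for DistillationLoss: blends a temperature-softened
    KL term against the teacher with hard-label cross-entropy via alpha."""

    def __init__(self, student_temp, teacher_temp, alpha):
        super().__init__()
        self.student_temp = student_temp
        self.teacher_temp = teacher_temp
        self.alpha = alpha

    def forward(self, student_logits, teacher_logits, labels):
        # Teacher's softened distribution and student's softened log-probabilities
        soft_targets = F.softmax(teacher_logits / self.teacher_temp, dim=1)
        log_student = F.log_softmax(student_logits / self.student_temp, dim=1)
        # KL term; a common convention scales it by the temperatures so its
        # gradient magnitude stays comparable with the hard-label term
        soft_loss = F.kl_div(log_student, soft_targets, reduction="batchmean")
        soft_loss = soft_loss * self.student_temp * self.teacher_temp
        # Ordinary cross-entropy against the ground-truth labels
        hard_loss = F.cross_entropy(student_logits, labels)
        return self.alpha * soft_loss + (1.0 - self.alpha) * hard_loss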
Example #2
    # (snippet begins mid-way; the loss construction mirrors Example #1)
    dist_loss = DistillationLoss(STUDENT_TEMP, TEACHER_TEMP,
                                 ALPHA)  # TODO: Search for the correct values from the paper

    # Initialize student from the full training set (-1 = all samples)
    init_data, init_labels = train_set.get_batch(-1, 0, "cpu")
    student.initialize(init_data, init_labels)
    student.to("cuda:0")
    # Define distillation network
    dist_net = DistillNet(student, teacher)
    trainer = trn.Trainer(dist_net, train_opts)
    results = trainer.train(dist_loss,
                            train_set,
                            val_set,
                            is_classification=True)
    torch.save(student.state_dict(), STUDENT_MODEL_PATH)
    trn.save_train_info(results, STUDENT_MODEL_PATH + "_train_info")

if TEST:
    # Compare teacher and distilled student accuracy on the held-out test set
    student.load_state_dict(torch.load(STUDENT_MODEL_PATH))
    teacher.to("cuda:0")
    test_batch, test_labels = test_set.get_batch(-1, 0, "cuda:0")
    teacher_preds = teacher.forward(test_batch.float())
    teacher_acc = validate_distillation([teacher_preds], test_labels)
    print("Teacher Acc:{}".format(teacher_acc))

    student.to("cuda:0")
    init_data, init_labels = train_set.get_batch(-1, 0, "cuda:0")
    student.fit_pca(init_data, init_labels)
    student_pred = student.forward(test_batch.float())
    student_acc = validate_distillation([student_pred], test_labels)
    print("Student Acc:{}".format(student_acc))