예제 #1
0
파일: test_kd.py 프로젝트: NeelayS/KD_Lib
def test_DML():
    """Smoke-test DML on a two-student cohort with logging, plotting and saving on.

    Builds a ResNet50 and a ResNet18 student, trains the cohort for one
    epoch, then calls ``evaluate`` and ``get_parameters``.

    The log directory and the saved student checkpoint are placed in a
    temporary directory rather than the current working directory (the
    previous ``logdir="."`` / ``"./student.pt"``), so running the test no
    longer leaves files behind.
    """
    import os
    import tempfile

    student_params = [4, 4, 4, 4, 4]
    student_model_1 = ResNet50(student_params, 1, 10)
    student_model_2 = ResNet18(student_params, 1, 10)

    student_cohort = (student_model_1, student_model_2)

    s_optimizer_1 = optim.SGD(student_model_1.parameters(), 0.01)
    s_optimizer_2 = optim.SGD(student_model_2.parameters(), 0.01)

    student_optimizers = (s_optimizer_1, s_optimizer_2)

    with tempfile.TemporaryDirectory() as out_dir:
        distiller = DML(
            student_cohort,
            train_loader,
            test_loader,
            student_optimizers,
            log=True,
            logdir=out_dir,
        )

        distiller.train_students(
            epochs=1,
            plot_losses=True,
            save_model=True,
            save_model_path=os.path.join(out_dir, "student.pt"),
        )
        distiller.evaluate()
        distiller.get_parameters()

        # Release references while still inside the ``with`` block.
        # NOTE(review): if DML's logger keeps files open in out_dir, directory
        # cleanup could fail on Windows — confirm the writer is closed/flushed.
        del student_model_1, student_model_2, distiller, s_optimizer_1, s_optimizer_2
예제 #2
0
파일: test_kd.py 프로젝트: NeelayS/KD_Lib
def test_NoisyTeacher():
    """Smoke test for the NoisyTeacher distillation experiment.

    A ResNet50 teacher and a ResNet18 student are each trained for one
    epoch with plotting and checkpointing disabled. The test then calls
    ``evaluate(teacher=False)`` and ``get_parameters()``.
    """
    teacher = ResNet50([4, 4, 8, 4, 4], 1, 10)
    student = ResNet18([4, 4, 4, 4, 4], 1, 10)

    teacher_opt = optim.SGD(teacher.parameters(), 0.01)
    student_opt = optim.SGD(student.parameters(), 0.01)

    # Experiment-specific keyword settings, forwarded unchanged.
    noise_settings = {"alpha": 0.4, "noise_variance": 0.2, "device": "cpu"}
    experiment = NoisyTeacher(
        teacher,
        student,
        train_loader,
        test_loader,
        teacher_opt,
        student_opt,
        **noise_settings,
    )

    quick_run = {"epochs": 1, "plot_losses": False, "save_model": False}
    experiment.train_teacher(**quick_run)
    experiment.train_student(**quick_run)
    experiment.evaluate(teacher=False)
    experiment.get_parameters()

    del teacher, student, experiment, teacher_opt, student_opt
예제 #3
0
파일: test_kd.py 프로젝트: NeelayS/KD_Lib
def test_RCO():
    """Smoke test for the RCO distiller.

    A ResNet50 teacher and a ResNet18 student are each trained for a single
    epoch without plots or saved models. The test then calls ``evaluate()``
    and ``get_parameters()``.
    """
    teacher = ResNet50([4, 4, 8, 4, 4], 1, 10)
    student = ResNet18([4, 4, 4, 4, 4], 1, 10)

    teacher_opt = optim.SGD(teacher.parameters(), 0.01)
    student_opt = optim.SGD(student.parameters(), 0.01)

    rco = RCO(teacher, student, train_loader, test_loader, teacher_opt, student_opt)

    for train_step in (rco.train_teacher, rco.train_student):
        train_step(epochs=1, plot_losses=False, save_model=False)
    rco.evaluate()
    rco.get_parameters()

    del teacher, student, rco, teacher_opt, student_opt
예제 #4
0
def test_resnet():
    """Check that each ResNet variant can be constructed from one parameter list."""
    params = [4, 4, 8, 8, 16]
    for build in (ResNet18, ResNet34, ResNet50, ResNet101, ResNet152):
        build(params)
예제 #5
0
def test_dynamic_quantization():
    """Smoke test for Dynamic_Quantizer on a small ResNet18.

    The layer-type set ``{torch.nn.Linear}`` is passed to the quantizer.
    After ``quantize()``, the test queries model sizes and performance
    statistics.
    """
    # NOTE(review): meaning of the 4th positional ResNet18 argument (True)
    # is not visible here — confirm against the ResNet18 signature.
    net = ResNet18([4, 4, 8, 4, 4], 1, 10, True)
    target_layers = {torch.nn.Linear}
    quantizer = Dynamic_Quantizer(net, test_loader, target_layers)
    q_net = quantizer.quantize()

    quantizer.get_model_sizes()
    quantizer.get_performance_statistics()

    del net, quantizer, q_net
예제 #6
0
def test_lottery_tickets():
    """Smoke test for LotteryTicketsPruner on a small ResNet18.

    Runs ``prune`` with 2 iterations, 1 training epoch, a 50% prune
    percentage and model saving enabled.
    """
    model = ResNet18([4, 4, 8, 4, 4], 1, 10)
    pruner = LotteryTicketsPruner(model, train_loader, test_loader)

    # NOTE(review): save_models=True presumably writes checkpoints relative
    # to the CWD — confirm where the pruner puts them.
    prune_config = dict(
        num_iterations=2,
        train_epochs=1,
        save_models=True,
        prune_percent=50,
    )
    pruner.prune(**prune_config)

    del model, pruner
예제 #7
0
def test_weight_threshold_pruning():
    """Smoke test for WeightThresholdPruner on a small ResNet18.

    Prunes for 2 iterations with threshold 0.1 and model saving enabled.
    It then evaluates, and gathers verbose pruning statistics for, the
    checkpoint file ``pruned_model_iteration_0.pt``.
    """
    model = ResNet18([4, 4, 8, 4, 4], 1, 10)
    pruner = WeightThresholdPruner(model, train_loader, test_loader)
    pruner.prune(num_iterations=2, train_epochs=1, save_models=True, threshold=0.1)

    # Relative path, so it resolves against the CWD. Presumably written by
    # prune() above because save_models=True — confirm naming in the pruner.
    first_checkpoint = "pruned_model_iteration_0.pt"
    pruner.evaluate(model_path=first_checkpoint)
    pruner.get_pruning_statistics(model_path=first_checkpoint, verbose=True)

    del model, pruner
예제 #8
0
def test_DML():
    """Smoke test for DML with a ResNet50 + ResNet18 cohort.

    Each student gets its own SGD optimizer (lr 0.01). The cohort is trained
    for one epoch without plots or saved models, then ``evaluate()`` and
    ``get_parameters()`` are called.
    """
    cfg = [4, 4, 4, 4, 4]
    cohort = (ResNet50(cfg, 1, 10), ResNet18(cfg, 1, 10))
    optimizers = tuple(optim.SGD(net.parameters(), 0.01) for net in cohort)

    distiller = DML(cohort, train_loader, test_loader, optimizers)
    distiller.train_students(epochs=1, plot_losses=False, save_model=False)
    distiller.evaluate()
    distiller.get_parameters()
예제 #9
0
def test_mean_teacher():
    """Smoke test for MeanTeacher.

    Both networks are built with ``mean=True``. The teacher and the student
    are each trained for one epoch without plots or saved models, followed by
    ``evaluate()`` and ``get_parameters()``.
    """
    teacher = ResNet50([4, 4, 8, 4, 4], 1, 10, mean=True)
    student = ResNet18([4, 4, 4, 4, 4], 1, 10, mean=True)

    mt = MeanTeacher(
        teacher,
        student,
        train_loader,
        test_loader,
        optim.SGD(teacher.parameters(), 0.01),
        optim.SGD(student.parameters(), 0.01),
    )

    for step in (mt.train_teacher, mt.train_student):
        step(epochs=1, plot_losses=False, save_model=False)
    mt.evaluate()
    mt.get_parameters()
예제 #10
0
def test_attention():
    """Smoke test for the Attention distiller.

    The teacher and the student are each trained for one epoch without plots
    or saved models. The test then calls ``evaluate(teacher=False)`` and
    ``get_parameters()``.
    """
    # NOTE(review): the trailing positional True on both ResNets is not
    # explained here — confirm its meaning in the ResNet constructors.
    teacher = ResNet50([4, 4, 8, 4, 4], 1, 10, True)
    student = ResNet18([4, 4, 4, 4, 4], 1, 10, True)

    teacher_opt = optim.SGD(teacher.parameters(), 0.01)
    student_opt = optim.SGD(student.parameters(), 0.01)

    distiller = Attention(
        teacher, student, train_loader, test_loader, teacher_opt, student_opt
    )

    one_epoch = {"epochs": 1, "plot_losses": False, "save_model": False}
    distiller.train_teacher(**one_epoch)
    distiller.train_student(**one_epoch)
    distiller.evaluate(teacher=False)
    distiller.get_parameters()
예제 #11
0
def test_resnet():

    params = [4, 4, 8, 8, 16]
    model = ResNet18(params)