コード例 #1
0
def exec_vgg16(exec_name, pruning_params=None, exec_params=None, dataset_params=None, debug_params=None):
    """Build (or restore) a VGG-16 model and run ``common_training_code`` on it.

    If ``exec_params.best_result_save_path`` names an existing file, the model
    object is restored from it with ``torch.load``; otherwise a torchvision
    VGG-16 with pretrained weights is created. The model is moved to CUDA
    before training.

    Args:
        exec_name: Run name; printed and used to build the
            ``../saved/<exec_name>/`` output paths.
        pruning_params: Forwarded unchanged to ``common_training_code``.
        exec_params: Execution settings whose ``best_result_save_path``
            attribute is read here. May be ``None`` (the default), in which
            case no checkpoint lookup is attempted.
        dataset_params: Forwarded unchanged to ``common_training_code``.
        debug_params: Forwarded unchanged to ``common_training_code``.

    Returns:
        The ``(history, test_score)`` pair returned by ``common_training_code``.
    """
    print("*** ", exec_name)

    # exec_params defaults to None, so it must be checked before dereferencing.
    checkpoint_path = exec_params.best_result_save_path if exec_params is not None else None

    if checkpoint_path is not None and os.path.isfile(checkpoint_path):
        # The loaded object is used directly as the model, so the file is
        # expected to hold a whole model object rather than a state_dict.
        # NOTE(review): torch.load unpickles arbitrary objects -- only load
        # trusted files. On torch >= 2.6 (weights_only defaults to True),
        # loading a full module may need weights_only=False; confirm for the
        # torch version in use.
        model = torch.load(checkpoint_path)
    else:
        model = models.vgg16(pretrained=True)

    model.cuda()

    history, test_score = common_training_code(model,
                                               pruned_save_path="../saved/{}/Pruned.pth".format(exec_name),
                                               pruned_best_result_save_path="../saved/{}/pruned_best.pth".format(exec_name),
                                               sample_run=torch.zeros([1, 3, 224, 224]),
                                               pruning_params=pruning_params,
                                               exec_params=exec_params,
                                               dataset_params=dataset_params,
                                               debug_params=debug_params)
    return history, test_score
コード例 #2
0
def exec_resnet50(exec_name, pruning_params=None, exec_params=None, dataset_params=None, out_count=1000, debug_params=None):
    """Build (or restore) a ResNet-50 model and run ``common_training_code`` on it.

    If ``exec_params.best_result_save_path`` names an existing file, the model
    object is restored from it with ``torch.load``. Otherwise a torchvision
    ResNet-50 with pretrained weights is created, and its final ``fc`` layer is
    sized to ``out_count`` outputs. The model is moved to CUDA before training.

    Args:
        exec_name: Run name; printed and used to build the
            ``../saved/<exec_name>/`` output paths.
        pruning_params: Forwarded unchanged to ``common_training_code``.
        exec_params: Execution settings whose ``best_result_save_path``
            attribute is read here. May be ``None`` (the default), in which
            case no checkpoint lookup is attempted.
        dataset_params: Forwarded unchanged to ``common_training_code``.
        out_count: Number of outputs of the classification head for a freshly
            built model (not applied to a restored checkpoint). When it equals
            the pretrained head's size, the pretrained head is kept; otherwise
            ``model.fc`` is replaced by a new ``nn.Linear`` layer.
        debug_params: Forwarded unchanged to ``common_training_code``.

    Returns:
        The ``(history, test_score)`` pair returned by ``common_training_code``.
    """
    print("*** ", exec_name)

    # exec_params defaults to None, so it must be checked before dereferencing.
    checkpoint_path = exec_params.best_result_save_path if exec_params is not None else None

    if checkpoint_path is not None and os.path.isfile(checkpoint_path):
        # The loaded object is used directly as the model, so the file is
        # expected to hold a whole model object rather than a state_dict.
        # NOTE(review): torch.load unpickles arbitrary objects -- only load
        # trusted files. On torch >= 2.6 (weights_only defaults to True),
        # loading a full module may need weights_only=False; confirm for the
        # torch version in use.
        model = torch.load(checkpoint_path)
    else:
        model = models.resnet50(pretrained=True)
        # Honour out_count instead of a hard-coded class count. Replacing the
        # head discards its pretrained weights, so only do it when the size
        # actually has to change.
        if model.fc.out_features != out_count:
            model.fc = nn.Linear(model.fc.in_features, out_count)

    model.cuda()

    history, test_score = common_training_code(model,
                                               pruned_save_path="../saved/{}/Pruned.pth".format(exec_name),
                                               pruned_best_result_save_path="../saved/{}/pruned_best.pth".format(exec_name),
                                               sample_run=torch.zeros([1, 3, 224, 224]),
                                               pruning_params=pruning_params,
                                               exec_params=exec_params,
                                               dataset_params=dataset_params,
                                               debug_params=debug_params)
    return history, test_score
コード例 #3
0
def exec_dense_net(exec_name, pruning_params=None, exec_params=None, dataset_params=None, debug_params=None):
    """Build (or restore) a DenseNet-121 model and run ``common_training_code`` on it.

    If ``exec_params.best_result_save_path`` names an existing file, the model
    object is restored from it with ``torch.load``; otherwise a torchvision
    DenseNet-121 with pretrained weights is created. The model is moved to
    CUDA. When ``exec_params`` is given, it sets ``force_forward_view`` and
    ``ignore_last_conv`` to ``True`` on it before training.

    Args:
        exec_name: Run name; printed and used to build the
            ``../saved/<exec_name>/`` output paths.
        pruning_params: Forwarded unchanged to ``common_training_code``.
        exec_params: Execution settings. Its ``best_result_save_path`` is read
            here, and two flags are written onto it. May be ``None`` (the
            default), in which case neither happens.
        dataset_params: Forwarded unchanged to ``common_training_code``.
        debug_params: Forwarded unchanged to ``common_training_code``.

    Returns:
        The ``(history, test_score)`` pair returned by ``common_training_code``.
    """
    print("*** ", exec_name)

    # exec_params defaults to None, so it must be checked before dereferencing
    # (the flag-setting code below already assumed it could be None).
    checkpoint_path = exec_params.best_result_save_path if exec_params is not None else None

    if checkpoint_path is not None and os.path.isfile(checkpoint_path):
        # The loaded object is used directly as the model, so the file is
        # expected to hold a whole model object rather than a state_dict.
        # NOTE(review): torch.load unpickles arbitrary objects -- only load
        # trusted files. On torch >= 2.6 (weights_only defaults to True),
        # loading a full module may need weights_only=False; confirm for the
        # torch version in use.
        model = torch.load(checkpoint_path)
    else:
        model = models.densenet121(pretrained=True)

    model.cuda()

    if exec_params is not None:
        # DenseNet-specific switches, presumably consumed by
        # common_training_code / the pruning code (defined elsewhere).
        # NOTE(review): this mutates the caller's exec_params object, so the
        # flags persist if the same object is reused for another run.
        exec_params.force_forward_view = True
        exec_params.ignore_last_conv = True

    history, test_score = common_training_code(model,
                                               pruned_save_path="../saved/{}/Pruned.pth".format(exec_name),
                                               pruned_best_result_save_path="../saved/{}/pruned_best.pth".format(exec_name),
                                               sample_run=torch.zeros([1, 3, 224, 224]),
                                               pruning_params=pruning_params,
                                               exec_params=exec_params,
                                               dataset_params=dataset_params,
                                               debug_params=debug_params)
    return history, test_score