示例#1
0
def load_model(model_config):
    """Construct a freshly initialised model described by ``model_config``.

    Args:
        model_config: mapping with a ``'name'`` key selecting the
            architecture. Names prefixed with ``sp`` build the sparse
            variants (sparse function ``'vol'``) and also read ``'sp'``,
            which is used as the sparsity of every stage. The wide-ResNet
            variants additionally read ``'depth'`` and ``'width'``; for them
            ``'num_classes'`` is optional and defaults to 10.

    Returns:
        The constructed model. No weights are loaded here.

    Raises:
        NotImplementedError: if ``'name'`` is not a supported architecture.
        KeyError: if a key required by the selected architecture is missing.
    """
    modelname = model_config['name']

    def sparsities():
        # The same sparsity ratio is applied to each of the four stages.
        sp = model_config['sp']
        return [sp, sp, sp, sp]

    def sparse_resnet(constructor):
        return constructor(relu=False,
                           sparsities=sparsities(),
                           sparse_func='vol')

    def wide_resnet_kwargs():
        return dict(depth=model_config['depth'],
                    widen_factor=model_config['width'],
                    num_classes=model_config.get('num_classes', 10))

    # Lambdas keep attribute lookups on the model modules lazy, so only the
    # selected architecture is ever touched.
    builders = {
        'DenseNet121': lambda: densenet.DenseNet121(),
        'spDenseNet121': lambda: densenet.SparseDenseNet121(
            sparse_func='vol', sparsities=sparsities()),
        'ResNet18': lambda: resnet.ResNet18(),
        'ResNet34': lambda: resnet.ResNet34(),
        'ResNet50': lambda: resnet.ResNet50(),
        'ResNet101': lambda: resnet.ResNet101(),
        'ResNet152': lambda: resnet.ResNet152(),
        'spResNet18': lambda: sparse_resnet(resnet.SparseResNet18),
        'spResNet34': lambda: sparse_resnet(resnet.SparseResNet34),
        'spResNet50': lambda: sparse_resnet(resnet.SparseResNet50),
        'spResNet101': lambda: sparse_resnet(resnet.SparseResNet101),
        'spResNet152': lambda: sparse_resnet(resnet.SparseResNet152),
        'WideResNet': lambda: wideresnet.WideResNet(**wide_resnet_kwargs()),
        'spWideResNet': lambda: wideresnet.SparseWideResNet(
            sp=model_config['sp'], sp_func='vol', bias=False,
            **wide_resnet_kwargs()),
    }

    try:
        build = builders[modelname]
    except (KeyError, TypeError):
        # TypeError covers unhashable names; both mean "not supported".
        raise NotImplementedError(
            'unsupported model name: {!r}'.format(modelname)) from None
    return build()
示例#2
0
def load_model(models_config):
    """Build every model listed in ``models_config`` and load its weights.

    Args:
        models_config: mapping whose ``'lists'`` entry is an iterable of
            per-model config dicts. Each has a ``'name'`` selecting the
            architecture and a ``'path'`` to a saved state dict. Entries may
            also carry the extra keys their architecture needs: ``'sp'`` for
            the ``sp*`` sparse variants; ``'depth'``/``'width'`` and an
            optional ``'num_classes'`` (default 10) for the wide-ResNets.

    Returns:
        List of ``(modelfile, model)`` tuples in config order, where
        ``modelfile`` is the checkpoint file name without its directory and
        extension.

    Raises:
        NotImplementedError: if an entry's ``'name'`` is not supported.
    """
    import os

    def build(model_config, modelname):
        # Construct the untrained architecture named by ``modelname``.
        def sparsities():
            # The same sparsity ratio is applied to each of the four stages.
            sp = model_config['sp']
            return [sp, sp, sp, sp]

        def sparse_resnet(constructor):
            return constructor(relu=False,
                               sparsities=sparsities(),
                               sparse_func='vol')

        def wide_resnet_kwargs():
            return dict(depth=model_config['depth'],
                        widen_factor=model_config['width'],
                        num_classes=model_config.get('num_classes', 10))

        # Lambdas keep attribute lookups lazy: only the chosen class is used.
        builders = {
            'DenseNet121': lambda: densenet.DenseNet121(),
            'spDenseNet121': lambda: densenet.SparseDenseNet121(
                sparse_func='vol', sparsities=sparsities()),
            'ResNet18': lambda: resnet.ResNet18(),
            'ResNet34': lambda: resnet.ResNet34(),
            'ResNet50': lambda: resnet.ResNet50(),
            'ResNet101': lambda: resnet.ResNet101(),
            'ResNet152': lambda: resnet.ResNet152(),
            'spResNet18': lambda: sparse_resnet(resnet.SparseResNet18),
            'spResNet34': lambda: sparse_resnet(resnet.SparseResNet34),
            'spResNet50': lambda: sparse_resnet(resnet.SparseResNet50),
            'spResNet101': lambda: sparse_resnet(resnet.SparseResNet101),
            'spResNet152': lambda: sparse_resnet(resnet.SparseResNet152),
            'spWideResNet': lambda: wideresnet.SparseWideResNet(
                sp=model_config['sp'], sp_func='vol', bias=False,
                **wide_resnet_kwargs()),
            'WideResNet': lambda: wideresnet.WideResNet(
                **wide_resnet_kwargs()),
        }
        try:
            constructor = builders[modelname]
        except (KeyError, TypeError):
            raise NotImplementedError(
                'unsupported model name: {!r}'.format(modelname)) from None
        return constructor()

    model_list = []
    for model_config in models_config['lists']:
        modelname = model_config['name']
        print(modelname)
        model = build(model_config, modelname)

        path = model_config['path']
        # The model was just built without being moved to a device, and
        # load_state_dict copies values into its existing parameters. Loading
        # onto the CPU therefore gives the same result, and it also works for
        # GPU-saved checkpoints on machines without CUDA.
        model.load_state_dict(torch.load(path, map_location='cpu'))
        # Strip directory and extension whatever the extension's length
        # (slicing off a fixed 4 characters mangles e.g. "model.pt").
        modelfile = os.path.splitext(os.path.basename(path))[0]

        model_list.append((modelfile, model))
    return model_list
示例#3
0
                                               num_iter=20,
                                               use_tqdm=True)
# Report the MIM result computed by the preceding attack run.
print("MIM:", adv_err)

# DeepFool runs one image per batch, limited to n_test=1000 shuffled images.
test_loader = DataLoader(cifar_test, shuffle=True, batch_size=1)
adv_err, adv_loss = training.epoch_adversarial(
    test_loader,
    model,
    attack=attack.deepfool,
    epsilon=eps,
    n_test=1000,
    device=device,
    num_iter=20,
    use_tqdm=True,
)
print("Deepfool:", adv_err)

# Switch to the sparse ResNet-18 (sparsity 0.1 in all four stages) on GPU 0
# and load its 80-epoch CIFAR checkpoint.
device = torch.device('cuda:0')
model = resnet.SparseResNet18(sparsities=[0.1] * 4, sparse_func='vol')
model = model.to(device)
print("model loading --------k=0.1")
model.load_state_dict(torch.load('models/spresnet18_0.1_cifar_80epochs.pth'))

# Clean (non-adversarial) test error in evaluation mode, 400 images per batch.
test_loader = DataLoader(cifar_test, shuffle=True, batch_size=400)
model.eval()
test_err, test_loss = training.epoch(
    test_loader, model, device=device, use_tqdm=True)
print("test", test_err)

adv_err, adv_loss = training.epoch_adversarial(
    test_loader,
    model,
    attack=attack.pgd_linf_untargeted,
name = name + '_iter' + str(args.iters)

# Select the architecture that matches the checkpoint: a depth-34,
# widen-factor-10 wide ResNet or a ResNet-18, each either dense or sparse
# (sparsity 0.1, sparse function 'vol') when args.is_kWTA is set.
if args.is_Wide:
    if args.is_kWTA:
        load_net = wideresnet.SparseWideResNet(depth=34,
                                               num_classes=10,
                                               widen_factor=10,
                                               sp=0.1,
                                               sp_func='vol').to(device)
    else:
        load_net = wideresnet.WideResNet(depth=34,
                                         num_classes=10,
                                         widen_factor=10).to(device)
else:
    if args.is_kWTA:
        load_net = resnet.SparseResNet18(sparsities=[0.1, 0.1, 0.1, 0.1],
                                         sparse_func='vol').to(device)
    else:
        load_net = resnet.ResNet18().to(device)

# NOTE(review): presumably len(args.gpu) > 2 means several GPU ids were given
# (e.g. "0,1"). Also, this string comparison only matches a device given
# exactly as 'cuda' (not e.g. 'cuda:0'). Confirm both against how args.gpu
# and device are defined.
if len(args.gpu) > 2:
    import torch.backends.cudnn as cudnn
    if device == 'cuda':
        # Wrap before loading, presumably so the parameter names match a
        # checkpoint saved from a DataParallel model. Confirm.
        load_net = torch.nn.DataParallel(load_net)
        cudnn.benchmark = True

load_net.load_state_dict(
    torch.load('models/resnet18_cifar' + name + '.pth', map_location=device))

# Unwrap only if the net was actually wrapped above. Unwrapping an unwrapped
# network by taking its first child would replace the whole model with its
# first submodule.
if isinstance(load_net, torch.nn.DataParallel):
    load_net = load_net.module