Exemplo n.º 1
0
def load_model(model_config):
    """Construct the (untrained) network described by ``model_config``.

    Args:
        model_config: Mapping with a ``'name'`` key selecting the
            architecture. Sparse variants (names starting with ``'sp'``)
            also read ``'sp'``, the sparsity ratio used for every stage;
            ``'spWideResNet'`` and ``'WideResNet'`` also read ``'depth'``
            and ``'width'``.

    Returns:
        The newly constructed model. No weights are loaded here.

    Raises:
        NotImplementedError: If ``model_config['name']`` is not a supported
            architecture; the message includes the rejected name.
        KeyError: If a key required by the selected architecture is missing.
    """
    modelname = model_config['name']

    def sparse_resnet(ctor):
        # All sparse ResNets share one recipe: ReLU disabled, the 'vol'
        # sparsity function, and the same ratio for each of the 4 stages.
        sp = model_config['sp']
        return ctor(relu=False, sparsities=[sp, sp, sp, sp], sparse_func='vol')

    def sparse_densenet():
        sp = model_config['sp']
        return densenet.SparseDenseNet121(sparse_func='vol',
                                          sparsities=[sp, sp, sp, sp])

    def sparse_wideresnet():
        sp = model_config['sp']
        return wideresnet.SparseWideResNet(depth=model_config['depth'],
                                           num_classes=10,
                                           widen_factor=model_config['width'],
                                           sp=sp,
                                           sp_func='vol',
                                           bias=False)

    def dense_wideresnet():
        return wideresnet.WideResNet(depth=model_config['depth'],
                                     widen_factor=model_config['width'],
                                     num_classes=10)

    # Zero-argument factories (hence the lambdas): only the selected entry
    # is ever called, so only that model's module attribute is looked up
    # and only that model's config keys are read.
    builders = {
        'DenseNet121': lambda: densenet.DenseNet121(),
        'spDenseNet121': sparse_densenet,
        'ResNet18': lambda: resnet.ResNet18(),
        'ResNet34': lambda: resnet.ResNet34(),
        'ResNet50': lambda: resnet.ResNet50(),
        'ResNet101': lambda: resnet.ResNet101(),
        'ResNet152': lambda: resnet.ResNet152(),
        'spResNet18': lambda: sparse_resnet(resnet.SparseResNet18),
        'spResNet34': lambda: sparse_resnet(resnet.SparseResNet34),
        'spResNet50': lambda: sparse_resnet(resnet.SparseResNet50),
        'spResNet101': lambda: sparse_resnet(resnet.SparseResNet101),
        'spResNet152': lambda: sparse_resnet(resnet.SparseResNet152),
        'spWideResNet': sparse_wideresnet,
        'WideResNet': dense_wideresnet,
    }

    builder = builders.get(modelname)
    if builder is None:
        raise NotImplementedError(
            'Unsupported model name: {!r}'.format(modelname))
    return builder()
Exemplo n.º 2
0
def load_model(models_config, map_location=None):
    """Build every model listed in ``models_config`` and load its weights.

    Args:
        models_config: Mapping whose ``'lists'`` entry is an iterable of
            per-model configs. Each config needs ``'name'`` (architecture)
            and ``'path'`` (checkpoint file); sparse variants (``'sp…'``)
            also read ``'sp'``, and the wide-ResNet variants read
            ``'depth'`` and ``'width'``.
        map_location: Passed through to ``torch.load``. The default ``None``
            is torch's own default, so existing callers are unaffected;
            e.g. ``'cpu'`` lets GPU-saved checkpoints load on a CPU-only
            machine.

    Returns:
        A list of ``(modelfile, model)`` tuples in config order, where
        ``modelfile`` is the checkpoint file name with its directory and
        extension removed (``'ckpt/net.pth'`` -> ``'net'``).

    Raises:
        NotImplementedError: If a config names an unsupported architecture;
            the message includes the rejected name.
    """
    import os

    def build(modelname, cfg):
        # Return a fresh, untrained instance of the architecture `modelname`.
        def sparse_resnet(ctor):
            # Sparse ResNets share one recipe: ReLU disabled, 'vol' sparsity,
            # and the same ratio for each of the four stages.
            sp = cfg['sp']
            return ctor(relu=False, sparsities=[sp, sp, sp, sp],
                        sparse_func='vol')

        def sparse_densenet():
            sp = cfg['sp']
            return densenet.SparseDenseNet121(sparse_func='vol',
                                              sparsities=[sp, sp, sp, sp])

        def sparse_wideresnet():
            sp = cfg['sp']
            return wideresnet.SparseWideResNet(depth=cfg['depth'],
                                               num_classes=10,
                                               widen_factor=cfg['width'],
                                               sp=sp,
                                               sp_func='vol',
                                               bias=False)

        def dense_wideresnet():
            return wideresnet.WideResNet(depth=cfg['depth'],
                                         widen_factor=cfg['width'],
                                         num_classes=10)

        # Lazy factories: only the selected entry is called, so only that
        # model's module attribute is looked up and its config keys read.
        builders = {
            'DenseNet121': lambda: densenet.DenseNet121(),
            'spDenseNet121': sparse_densenet,
            'ResNet18': lambda: resnet.ResNet18(),
            'ResNet34': lambda: resnet.ResNet34(),
            'ResNet50': lambda: resnet.ResNet50(),
            'ResNet101': lambda: resnet.ResNet101(),
            'ResNet152': lambda: resnet.ResNet152(),
            'spResNet18': lambda: sparse_resnet(resnet.SparseResNet18),
            'spResNet34': lambda: sparse_resnet(resnet.SparseResNet34),
            'spResNet50': lambda: sparse_resnet(resnet.SparseResNet50),
            'spResNet101': lambda: sparse_resnet(resnet.SparseResNet101),
            'spResNet152': lambda: sparse_resnet(resnet.SparseResNet152),
            'spWideResNet': sparse_wideresnet,
            'WideResNet': dense_wideresnet,
        }
        factory = builders.get(modelname)
        if factory is None:
            raise NotImplementedError(
                'Unsupported model name: {!r}'.format(modelname))
        return factory()

    model_list = []
    for model_config in models_config['lists']:
        modelname = model_config['name']
        print(modelname)
        model = build(modelname, model_config)

        path = model_config['path']
        # SECURITY: torch.load unpickles the file, which can execute
        # arbitrary code -- only load checkpoints from trusted sources.
        model.load_state_dict(torch.load(path, map_location=map_location))
        # splitext strips an extension of any length ('.pt', '.pth', '.ckpt').
        modelfile = os.path.splitext(os.path.basename(path))[0]

        model_list.append((modelfile, model))
    return model_list
    batch_size=args.batch_size // 2,
    is_preprocessing=args.make_inputs_with_same_data)

# Compose the run-name suffix from the parsed `args`, in this fixed order:
# optional '_wide', optional '_kWTA', then '_<which_AT>' and '_iter<iters>'.
name = ''.join([
    '_wide' if args.is_Wide else '',
    '_kWTA' if args.is_kWTA else '',
    '_' + args.which_AT,
    '_iter' + str(args.iters),
])

# Build the network selected by args.is_Wide / args.is_kWTA and move it to
# `device`: a WideResNet (depth 34, widen factor 10) or a ResNet18, using the
# sparse 'vol' variant with sparsity 0.1 when args.is_kWTA is set.
# NOTE(review): unlike load_model(), relu=False / bias=False are not passed
# here -- confirm the constructor defaults match the weights used with load_net.
if args.is_Wide and args.is_kWTA:
    load_net = wideresnet.SparseWideResNet(depth=34,
                                           num_classes=10,
                                           widen_factor=10,
                                           sp=0.1,
                                           sp_func='vol')
elif args.is_Wide:
    load_net = wideresnet.WideResNet(depth=34,
                                     num_classes=10,
                                     widen_factor=10)
elif args.is_kWTA:
    load_net = resnet.SparseResNet18(sparsities=[0.1, 0.1, 0.1, 0.1],
                                     sparse_func='vol')
else:
    load_net = resnet.ResNet18()
# nn.Module.to() moves parameters in place and returns the module itself.
load_net = load_net.to(device)

if len(args.gpu) > 2:
    import torch.backends.cudnn as cudnn