def _checkpoint_stem(path):
    """Return the file name of ``path`` with its directory and extension removed.

    ``'ckpt/resnet18_adv.pth'`` -> ``'resnet18_adv'``. Unlike slicing off a
    fixed four characters, this also handles ``.pt``, ``.ckpt`` or no extension.
    """
    import os

    return os.path.splitext(os.path.basename(path))[0]


def _construct_model(model_config):
    """Instantiate the architecture named by ``model_config['name']`` (no weights loaded).

    Raises:
        NotImplementedError: if the name is not one of the supported models.
        KeyError: if a key the chosen architecture needs ('sp', 'depth',
            'width') is missing.
    """
    name = model_config['name']

    def sparse_resnet_kwargs(cfg):
        # Settings shared by every sparse ResNet variant: the single 'sp'
        # value is repeated for all four entries of ``sparsities``.
        return {'relu': False, 'sparsities': [cfg['sp']] * 4, 'sparse_func': 'vol'}

    # Lambdas defer both the attribute lookups on the model modules and the
    # reads of variant-specific keys until an entry is actually selected, so
    # e.g. 'sp' is only required for the sparse models.
    builders = {
        'DenseNet121': lambda cfg: densenet.DenseNet121(),
        'spDenseNet121': lambda cfg: densenet.SparseDenseNet121(
            sparse_func='vol', sparsities=[cfg['sp']] * 4),
        'ResNet18': lambda cfg: resnet.ResNet18(),
        'ResNet34': lambda cfg: resnet.ResNet34(),
        'ResNet50': lambda cfg: resnet.ResNet50(),
        'ResNet101': lambda cfg: resnet.ResNet101(),
        'ResNet152': lambda cfg: resnet.ResNet152(),
        'spResNet18': lambda cfg: resnet.SparseResNet18(**sparse_resnet_kwargs(cfg)),
        'spResNet34': lambda cfg: resnet.SparseResNet34(**sparse_resnet_kwargs(cfg)),
        'spResNet50': lambda cfg: resnet.SparseResNet50(**sparse_resnet_kwargs(cfg)),
        'spResNet101': lambda cfg: resnet.SparseResNet101(**sparse_resnet_kwargs(cfg)),
        'spResNet152': lambda cfg: resnet.SparseResNet152(**sparse_resnet_kwargs(cfg)),
        'spWideResNet': lambda cfg: wideresnet.SparseWideResNet(
            depth=cfg['depth'], num_classes=10, widen_factor=cfg['width'],
            sp=cfg['sp'], sp_func='vol', bias=False),
        'WideResNet': lambda cfg: wideresnet.WideResNet(
            depth=cfg['depth'], widen_factor=cfg['width'], num_classes=10),
    }

    try:
        build = builders[name]
    except KeyError:
        raise NotImplementedError(
            'Unknown model name {!r}; expected one of: {}'.format(
                name, ', '.join(sorted(builders)))) from None
    return build(model_config)


def load_model(models_config, *, map_location=None):
    """Build every model listed in ``models_config['lists']`` and load its weights.

    Args:
        models_config: dict whose ``'lists'`` entry is an iterable of per-model
            config dicts. Each needs ``'name'`` (architecture) and ``'path'``
            (checkpoint file). Sparse variants (``'sp…'`` names) also need
            ``'sp'``, and the WideResNet variants need ``'depth'`` and ``'width'``.
        map_location: forwarded unchanged to ``torch.load``; e.g. ``'cpu'``
            lets GPU-saved checkpoints load on a CPU-only machine. The default
            ``None`` matches the previous behaviour.

    Returns:
        list of ``(checkpoint_stem, model)`` tuples in config order, where
        ``checkpoint_stem`` is the checkpoint's file name without directory
        or extension.

    Raises:
        NotImplementedError: for an unrecognised model name.
        KeyError: if a required config key is missing.

    NOTE(review): this file also contains a second ``def load_model(model_config)``
    directly after this one. If both live in the same module, that later
    definition replaces this one at import time and this list-based loader is
    unreachable. Consider renaming one of them (TODO confirm intended module layout).
    """
    model_list = []
    for model_config in models_config['lists']:
        # Progress output: which model is being built.
        print(model_config['name'])
        model = _construct_model(model_config)
        # torch.load unpickles the file, so only point 'path' at trusted checkpoints.
        model.load_state_dict(torch.load(model_config['path'], map_location=map_location))
        model_list.append((_checkpoint_stem(model_config['path']), model))
    return model_list
def load_model(model_config, *, map_location=None):
    """Instantiate one network from ``model_config`` and optionally load its weights.

    Args:
        model_config: dict with a ``'name'`` key selecting the architecture.
            Sparse variants (``'sp…'`` names) also need ``'sp'``, and the
            WideResNet variants need ``'depth'`` and ``'width'``. If ``'path'``
            is present, the state dict stored in that file is loaded into the
            model.
        map_location: forwarded unchanged to ``torch.load``; e.g. ``'cpu'``
            lets GPU-saved checkpoints load on a CPU-only machine. The default
            ``None`` matches the previous behaviour.

    Returns:
        The constructed model (presumably a ``torch.nn.Module``, since
        ``load_state_dict`` is called on it).

    Raises:
        NotImplementedError: for an unrecognised model name.
        KeyError: if a required config key is missing.
    """
    name = model_config['name']

    def sparse_resnet_kwargs(cfg):
        # Settings shared by every sparse ResNet variant: the single 'sp'
        # value is repeated for all four entries of ``sparsities``.
        return {'relu': False, 'sparsities': [cfg['sp']] * 4, 'sparse_func': 'vol'}

    # Lambdas defer both the attribute lookups on the model modules and the
    # reads of variant-specific keys until an entry is actually selected, so
    # e.g. 'sp' is only required for the sparse models.
    builders = {
        'DenseNet121': lambda cfg: densenet.DenseNet121(),
        'spDenseNet121': lambda cfg: densenet.SparseDenseNet121(
            sparse_func='vol', sparsities=[cfg['sp']] * 4),
        'ResNet18': lambda cfg: resnet.ResNet18(),
        'ResNet34': lambda cfg: resnet.ResNet34(),
        'ResNet50': lambda cfg: resnet.ResNet50(),
        'ResNet101': lambda cfg: resnet.ResNet101(),
        'ResNet152': lambda cfg: resnet.ResNet152(),
        'spResNet18': lambda cfg: resnet.SparseResNet18(**sparse_resnet_kwargs(cfg)),
        'spResNet34': lambda cfg: resnet.SparseResNet34(**sparse_resnet_kwargs(cfg)),
        'spResNet50': lambda cfg: resnet.SparseResNet50(**sparse_resnet_kwargs(cfg)),
        'spResNet101': lambda cfg: resnet.SparseResNet101(**sparse_resnet_kwargs(cfg)),
        'spResNet152': lambda cfg: resnet.SparseResNet152(**sparse_resnet_kwargs(cfg)),
        'spWideResNet': lambda cfg: wideresnet.SparseWideResNet(
            depth=cfg['depth'], num_classes=10, widen_factor=cfg['width'],
            sp=cfg['sp'], sp_func='vol', bias=False),
        'WideResNet': lambda cfg: wideresnet.WideResNet(
            depth=cfg['depth'], widen_factor=cfg['width'], num_classes=10),
    }

    try:
        build = builders[name]
    except KeyError:
        raise NotImplementedError(
            'Unknown model name {!r}; expected one of: {}'.format(
                name, ', '.join(sorted(builders)))) from None
    model = build(model_config)

    if 'path' in model_config:
        # torch.load unpickles the file, so only point 'path' at trusted checkpoints.
        model.load_state_dict(torch.load(model_config['path'], map_location=map_location))
    return model
name = name + '_wide' if args.is_kWTA: name = name + '_kWTA' name = name + '_' + args.which_AT name = name + '_iter' + str(args.iters) if args.is_Wide: if args.is_kWTA: load_net = wideresnet.SparseWideResNet(depth=34, num_classes=10, widen_factor=10, sp=0.1, sp_func='vol').to(device) else: load_net = wideresnet.WideResNet(depth=34, num_classes=10, widen_factor=10).to(device) else: if args.is_kWTA: load_net = resnet.SparseResNet18(sparsities=[0.1, 0.1, 0.1, 0.1], sparse_func='vol').to(device) else: load_net = resnet.ResNet18().to(device) if len(args.gpu) > 2: import torch.backends.cudnn as cudnn if device == 'cuda': load_net = torch.nn.DataParallel(load_net) cudnn.benchmark = True load_net.load_state_dict(