def get_vgg_model(gpu, percentage_freeze):
    """Build a VGG19-BN model with a new classifier head and print its summary.

    ``vgg19_bn`` is called with ``True`` as its only positional argument
    (presumably the ``pretrained`` flag -- confirm against the imported
    factory).  The classifier is replaced by a
    25088 -> 4096 -> 4096 -> ``num_classes`` stack, where ``num_classes``,
    ``height`` and ``width`` are names resolved at module level.

    Args:
        gpu: Value interpolated into the returned run name.
        percentage_freeze: Currently unused -- the parameter-freezing loop
            that consumed it is disabled in this function.

    Returns:
        A ``(model, run_name)`` tuple, ``run_name`` being
        ``"vgg_19_<gpu>_adam"``.
    """
    net = vgg19_bn(True)

    head_layers = [
        nn.Linear(25088, 4096),
        nn.ReLU(True),
        nn.Dropout(),
        nn.Linear(4096, 4096),
        nn.ReLU(True),
        nn.Dropout(),
        nn.Linear(4096, num_classes),
    ]
    net.classifier = nn.Sequential(*head_layers)

    # The call is kept from the original flow; its result only fed the
    # (disabled) freezing logic, so it is not used further here.
    _total_trainable = get_total_trainable_params(net)

    # summary() receives the result of net.cuda().
    summary(net.cuda(), (3, height, width))

    run_name = "vgg_19_{}_adam".format(gpu)
    return net, run_name
def get_vgg_model():
    """Return a VGG19-BN model whose classifier expects 512 input features.

    ``vgg19_bn`` is called with ``True`` as its only positional argument
    (presumably ``pretrained`` -- confirm against the imported factory).  The
    classifier is replaced by a 512 -> 4096 -> 4096 -> ``num_classes`` stack;
    ``num_classes`` is resolved at module level.
    """
    net = vgg19_bn(True)

    head_layers = [
        nn.Linear(512, 4096),
        nn.ReLU(True),
        nn.Dropout(),
        nn.Linear(4096, 4096),
        nn.ReLU(True),
        nn.Dropout(),
        nn.Linear(4096, num_classes),
    ]
    net.classifier = nn.Sequential(*head_layers)
    return net
def get_model(cfg, pretrained=False, load_param_from_ours=False):
    """Instantiate the network described by ``cfg`` and wrap it in DataParallel.

    Args:
        cfg: Configuration object.  Reads ``num_classes`` and ``model``, plus
            ``patch_size`` (custom nets), ``model_info`` (VGG / ResNet depth),
            ``init_model_file`` (when restoring weights) and ``gpu_id``.
        pretrained: Forwarded to the model factory; forced to ``False`` when
            ``load_param_from_ours`` is set.
        load_param_from_ours: When true, load ``checkpoint['model_param']``
            from ``cfg.init_model_file`` into the freshly built model.

    Returns:
        The model, moved with ``.cuda()`` and wrapped in
        ``torch.nn.DataParallel`` over ``cfg.gpu_id``.

    Exits the process with status -1 when no model matches ``cfg``.
    """
    # Weights restored from our own checkpoint replace any stock weights.
    use_pretrained = False if load_param_from_ours else pretrained

    n_out = cfg.num_classes
    arch = cfg.model

    net = None
    family = None      # module providing depth-indexed factories
    variants = {}      # depth -> factory name inside ``family``

    if arch == 'custom':
        from models import custom_net
        factory_name = {64: 'net_64', 32: 'net_32'}.get(cfg.patch_size)
        if factory_name is None:
            print('Do not support present patch size %s'%cfg.patch_size)
        else:
            net = getattr(custom_net, factory_name)(num_classes=n_out)
    elif arch == 'googlenet':
        from models import inception_v3
        net = inception_v3.inception_v3(pretrained=use_pretrained,
                                        num_classes=n_out)
    elif arch == 'vgg':
        from models import vgg as family
        variants = {19: 'vgg19_bn', 16: 'vgg16_bn'}
    elif arch == 'resnet':
        from models import resnet as family
        variants = {18: 'resnet18', 34: 'resnet34',
                    50: 'resnet50', 101: 'resnet101'}

    if family is not None:
        factory_name = variants.get(cfg.model_info)
        if factory_name is not None:
            net = getattr(family, factory_name)(pretrained=use_pretrained,
                                                num_classes=n_out)

    if net is None:
        print('not support :' + cfg.model)
        sys.exit(-1)

    if load_param_from_ours:
        print('loading pretrained model from {0}'.format(cfg.init_model_file))
        checkpoint = torch.load(cfg.init_model_file)
        net.load_state_dict(checkpoint['model_param'])

    net.cuda()
    print('shift model to parallel!')
    return torch.nn.DataParallel(net, device_ids=cfg.gpu_id)
def get_network(args, cfg):
    """Build the network named by ``args.net`` and move it to the GPU.

    Args:
        args: Parsed options.  Reads ``net``; ``pretrain`` for the
            configurable architectures; ``gpuid`` for ``resnet18`` and
            ``resnet50``.
        cfg: Configuration; ``cfg.PARA.train.num_classes`` is read for the
            configurable architectures only.

    Returns:
        The constructed network after ``.cuda()``.

    Exits the process with status 1 (previously 0, which reported success to
    the calling shell) when ``args.net`` is not supported.
    """
    def configured(factory):
        # Shared call shape of every AlexNet / VGG / ResNet factory below.
        return factory(pretrained=args.pretrain,
                       num_classes=cfg.PARA.train.num_classes)

    # Zero-argument builders keep name resolution lazy: only the selected
    # architecture's factory is looked up and called.
    builders = {
        'lenet5': lambda: LeNet5(),
        'alexnet': lambda: configured(alexnet),
        'vgg16': lambda: configured(vgg16),
        'vgg13': lambda: configured(vgg13),
        'vgg11': lambda: configured(vgg11),
        'vgg19': lambda: configured(vgg19),
        'vgg16_bn': lambda: configured(vgg16_bn),
        'vgg13_bn': lambda: configured(vgg13_bn),
        'vgg11_bn': lambda: configured(vgg11_bn),
        'vgg19_bn': lambda: configured(vgg19_bn),
        # NOTE(review): inception_v3 and squeezenet1_0 are built with their
        # defaults (no pretrain / num_classes) -- confirm this is intended.
        'inceptionv3': lambda: inception_v3(),
        'resnet18': lambda: configured(resnet18),
        'resnet34': lambda: configured(resnet34),
        'resnet50': lambda: configured(resnet50),
        'resnet101': lambda: configured(resnet101),
        'resnet152': lambda: configured(resnet152),
        'squeezenet': lambda: squeezenet1_0(),
    }
    # NOTE(review): only these two are placed on args.gpuid; every other
    # network uses the default CUDA device.  Kept as in the original.
    pinned_to_gpuid = ('resnet18', 'resnet50')

    build = builders.get(args.net)
    if build is None:
        print('the network name you have entered is not supported yet')
        sys.exit(1)

    net = build()
    if args.net in pinned_to_gpuid:
        return net.cuda(args.gpuid)
    return net.cuda()
def get_model(model, dataset, classify=True):
    """Resolve a model name to a constructed network.

    VGG and CyVGG names map to the ``<name>_bn`` factory of ``vgg`` /
    ``cyvgg`` and receive ``dataset`` and ``classify``; ResNet and CyResNet
    names map to the same-named factory of ``resnet`` / ``cyresnet`` and
    receive ``dataset`` only.

    Args:
        model: Model name.  Any value that matches no known name is returned
            unchanged.
        dataset: Forwarded to the factory as ``dataset=``.
        classify: Forwarded to the VGG / CyVGG factories as ``classify=``.

    Returns:
        The constructed network, or ``model`` itself when it is not a
        recognised name.
    """
    vgg_names = ('vgg11', 'vgg13', 'vgg16', 'vgg19')
    cyvgg_names = ('cyvgg11', 'cyvgg13', 'cyvgg16', 'cyvgg19')
    resnet_names = ('resnet20', 'resnet32', 'resnet44', 'resnet56')
    cyresnet_names = ('cyresnet20', 'cyresnet32', 'cyresnet44', 'cyresnet56')

    # VGG models
    if model in vgg_names:
        factory = getattr(vgg, model + '_bn')
        return factory(dataset=dataset, classify=classify)

    # CyVGG models
    if model in cyvgg_names:
        factory = getattr(cyvgg, model + '_bn')
        return factory(dataset=dataset, classify=classify)

    # ResNet models
    if model in resnet_names:
        return getattr(resnet, model)(dataset=dataset)

    # CyResNet models
    if model in cyresnet_names:
        return getattr(cyresnet, model)(dataset=dataset)

    return model
def highest_resolution_test(cfg, file_name, sorted_abnormal_patches, time,
                            patch_out_size, b_map, p_map):
    """Run the VGG19-BN model over the given patches and update the maps.

    A ``vgg.vgg19_bn`` is built with the module-level ``num_classes``, its
    weights are restored from ``checkpoint_path_hr`` (key ``'model_param'``),
    and the model is moved to CUDA, wrapped in ``DataParallel`` over
    ``cfg.gpu_id`` and switched to eval mode.  The work itself is delegated to
    ``prob_map.generate_prob_map_hr``.

    Per the original note: input is patch coordinates, output is a heat map.

    Returns:
        The ``(b_map, p_map)`` pair produced by ``generate_prob_map_hr``.
    """
    net = vgg.vgg19_bn(pretrained=False, num_classes=num_classes)

    saved = torch.load(checkpoint_path_hr)
    net.load_state_dict(saved['model_param'])

    net.cuda()
    net = torch.nn.DataParallel(net, device_ids=cfg.gpu_id)
    net.eval()

    updated_b_map, updated_p_map = prob_map.generate_prob_map_hr(
        cfg, file_name, sorted_abnormal_patches, net, time,
        patch_out_size, b_map, p_map)
    return updated_b_map, updated_p_map
def get_model(cfg, pretrained=True, load_param_from_folder=False):
    """Instantiate the network described by ``cfg`` and wrap it in DataParallel.

    Args:
        cfg: Configuration object.  Reads ``num_classes``, ``model``,
            ``model_info`` (VGG / ResNet depth), ``init_model_file`` (when
            restoring weights) and ``gpu_id``.
        pretrained: Forwarded to the model factory; forced to ``False`` when
            ``load_param_from_folder`` is set.
        load_param_from_folder: When true, load ``checkpoint['model_param']``
            from ``cfg.init_model_file`` into the freshly built model.

    Returns:
        The model wrapped in ``torch.nn.DataParallel`` over ``cfg.gpu_id``.

    Exits the process with status -1 when no model matches ``cfg``.
    """
    # Weights restored from a checkpoint replace any stock weights.
    use_pretrained = False if load_param_from_folder else pretrained

    n_out = cfg.num_classes
    arch = cfg.model

    net = None
    family = None      # module providing depth-indexed factories
    variants = {}      # depth -> factory name inside ``family``

    if arch == 'googlenet':
        from models import inception_v3
        net = inception_v3.inception_v3(pretrained=use_pretrained,
                                        num_classes=n_out)
    elif arch == 'vgg':
        from models import vgg as family
        variants = {19: 'vgg19_bn', 16: 'vgg16_bn'}
    elif arch == 'resnet':
        from models import resnet as family
        variants = {18: 'resnet18', 34: 'resnet34',
                    50: 'resnet50', 101: 'resnet101'}

    if family is not None:
        factory_name = variants.get(cfg.model_info)
        if factory_name is not None:
            net = getattr(family, factory_name)(pretrained=use_pretrained,
                                                num_classes=n_out)

    if net is None:
        print('not support :' + cfg.model)
        sys.exit(-1)

    if load_param_from_folder:
        print('loading pretrained model from {0}'.format(cfg.init_model_file))
        checkpoint = torch.load(cfg.init_model_file)
        net.load_state_dict(checkpoint['model_param'])

    print('shift model to parallel!')
    return torch.nn.DataParallel(net, device_ids=cfg.gpu_id)
def get_network(args, use_gpu=False):
    """Return the VGG network selected by ``args.net``.

    Args:
        args: Parsed options; only ``args.net`` is read.  Supported values
            are ``vgg11``, ``vgg13``, ``vgg16`` and ``vgg19``, each mapped to
            the batch-norm factory of ``models.vgg``.
        use_gpu: When true, the network is moved with ``.cuda()``.

    Returns:
        The constructed network.

    Exits the process with status 1 (previously 0, which reported success to
    the calling shell) when ``args.net`` is not supported.
    """
    # Function-scope import: the module is only needed to resolve factories.
    import importlib

    # Command-line name -> factory exported by models.vgg.
    factories = {
        'vgg16': 'vgg16_bn',
        'vgg13': 'vgg13_bn',
        'vgg11': 'vgg11_bn',
        'vgg19': 'vgg19_bn',
    }

    factory_name = factories.get(args.net)
    if factory_name is None:
        print('the network name you have entered is not supported yet')
        sys.exit(1)

    # Import lazily so an unsupported name never touches the model package.
    net = getattr(importlib.import_module('models.vgg'), factory_name)()

    if use_gpu:
        net = net.cuda()
    return net
def get_network(args):
    """Return the network selected by ``args.net``.

    Args:
        args: Parsed options.  ``args.net`` selects one of ``vgg11``,
            ``vgg19``, ``resnet18`` or ``resnet50``; when ``args.gpu`` is
            truthy the network is moved with ``.cuda()``.

    Returns:
        The constructed network.

    Exits the process with status 1 (previously 0, which reported success to
    the calling shell) when ``args.net`` is not supported.
    """
    # Function-scope import: the module is only needed to resolve factories.
    import importlib

    # Command-line name -> (module, factory name).
    registry = {
        'vgg11': ('models.vgg', 'vgg11_bn'),
        'vgg19': ('models.vgg', 'vgg19_bn'),
        'resnet18': ('models.resnet', 'resnet18'),
        'resnet50': ('models.resnet', 'resnet50'),
    }

    if args.net not in registry:
        print('the network name you have entered is not supported yet')
        sys.exit(1)

    # Import lazily so only the selected architecture's module is loaded.
    module_path, factory_name = registry[args.net]
    net = getattr(importlib.import_module(module_path), factory_name)()

    if args.gpu:  # use_gpu
        net = net.cuda()
    return net
def get_network(args, use_gpu=True):
    """Return the network selected by ``args.net``.

    Every supported name maps to a zero-argument factory in a ``models.*``
    module.  The module is imported only once the name has been recognised,
    so selecting one architecture never loads the others.

    Args:
        args: Parsed options; only ``args.net`` is read.
        use_gpu: When true, the network is moved with ``.cuda()``.

    Returns:
        The constructed network.

    Exits the process with status 1 (previously 0, which reported success to
    the calling shell) when ``args.net`` is not supported.
    """
    # Function-scope import: the module is only needed to resolve factories.
    import importlib

    # Names whose factory is spelled differently from the option value.
    aliased = {
        'vgg16': ('models.vgg', 'vgg16_bn'),
        'vgg13': ('models.vgg', 'vgg13_bn'),
        'vgg11': ('models.vgg', 'vgg11_bn'),
        'vgg19': ('models.vgg', 'vgg19_bn'),
        'inceptionresnetv2': ('models.inceptionv4', 'inception_resnet_v2'),
    }
    # Names whose factory is spelled exactly like the option value, grouped
    # by the module that defines them.
    same_named = {
        'models.densenet': ('densenet121', 'densenet161', 'densenet169',
                            'densenet201'),
        'models.googlenet': ('googlenet',),
        'models.inceptionv3': ('inceptionv3',),
        'models.inceptionv4': ('inceptionv4',),
        'models.xception': ('xception',),
        'models.resnet': ('resnet18', 'resnet34', 'resnet50', 'resnet101',
                          'resnet152'),
        'models.preactresnet': ('preactresnet18', 'preactresnet34',
                                'preactresnet50', 'preactresnet101',
                                'preactresnet152'),
        'models.resnext': ('resnext50', 'resnext101', 'resnext152'),
        'models.shufflenet': ('shufflenet',),
        'models.shufflenetv2': ('shufflenetv2',),
        'models.squeezenet': ('squeezenet',),
        'models.mobilenet': ('mobilenet',),
        'models.mobilenetv2': ('mobilenetv2',),
        'models.nasnet': ('nasnet',),
        'models.attention': ('attention56', 'attention92'),
        'models.senet': ('seresnet18', 'seresnet34', 'seresnet50',
                         'seresnet101', 'seresnet152'),
    }

    registry = dict(aliased)
    for module_path, names in same_named.items():
        for name in names:
            registry[name] = (module_path, name)

    if args.net not in registry:
        print('the network name you have entered is not supported yet')
        sys.exit(1)

    module_path, factory_name = registry[args.net]
    net = getattr(importlib.import_module(module_path), factory_name)()

    if use_gpu:
        net = net.cuda()
    return net
def get_network(args):
    """Return the network selected by ``args.net``.

    Most names map to a zero-argument factory in a ``models.*`` module.  The
    two ``hyper_resnet*`` names build ``Hypernet_Main`` with an encoder name
    and ``args.dict_size``.  Modules are imported only after the name has
    been recognised, so selecting one architecture never loads the others.

    Args:
        args: Parsed options.  Reads ``net``, ``gpu`` and -- for the
            hypernet variants only -- ``dict_size``.

    Returns:
        The constructed network, moved with ``.cuda()`` when ``args.gpu`` is
        truthy.

    Exits the process with status 1 (previously 0, which reported success to
    the calling shell) when ``args.net`` is not supported.
    """
    # Function-scope import: the module is only needed to resolve factories.
    import importlib

    # Names whose factory is spelled differently from the option value.
    aliased = {
        'vgg16': ('models.vgg', 'vgg16_bn'),
        'vgg13': ('models.vgg', 'vgg13_bn'),
        'vgg11': ('models.vgg', 'vgg11_bn'),
        'vgg19': ('models.vgg', 'vgg19_bn'),
        'inceptionresnetv2': ('models.inceptionv4', 'inception_resnet_v2'),
        'stochasticdepth18': ('models.stochasticdepth',
                              'stochastic_depth_resnet18'),
        'stochasticdepth34': ('models.stochasticdepth',
                              'stochastic_depth_resnet34'),
        'stochasticdepth50': ('models.stochasticdepth',
                              'stochastic_depth_resnet50'),
        'stochasticdepth101': ('models.stochasticdepth',
                               'stochastic_depth_resnet101'),
        'normal_resnet': ('models.normal_resnet', 'resnet18'),
        'normal_resnet_wo_bn': ('models.normal_resnet_wo_bn', 'resnet18'),
    }
    # Names whose factory is spelled exactly like the option value, grouped
    # by the module that defines them.
    same_named = {
        'models.densenet': ('densenet121', 'densenet161', 'densenet169',
                            'densenet201'),
        'models.googlenet': ('googlenet',),
        'models.inceptionv3': ('inceptionv3',),
        'models.inceptionv4': ('inceptionv4',),
        'models.xception': ('xception',),
        'models.resnet': ('resnet18', 'resnet34', 'resnet50', 'resnet101',
                          'resnet152'),
        'models.preactresnet': ('preactresnet18', 'preactresnet34',
                                'preactresnet50', 'preactresnet101',
                                'preactresnet152'),
        'models.resnext': ('resnext50', 'resnext101', 'resnext152'),
        'models.shufflenet': ('shufflenet',),
        'models.shufflenetv2': ('shufflenetv2',),
        'models.squeezenet': ('squeezenet',),
        'models.mobilenet': ('mobilenet',),
        'models.mobilenetv2': ('mobilenetv2',),
        'models.nasnet': ('nasnet',),
        'models.attention': ('attention56', 'attention92'),
        'models.senet': ('seresnet18', 'seresnet34', 'seresnet50',
                         'seresnet101', 'seresnet152'),
        'models.wideresidual': ('wideresnet',),
    }
    # Hypernet variants take constructor arguments: option name -> encoder.
    hypernet_encoders = {
        'hyper_resnet': 'resnet18',
        'hyper_resnet_wo_bn': 'resnet18_wobn',
    }

    registry = dict(aliased)
    for module_path, names in same_named.items():
        for name in names:
            registry[name] = (module_path, name)

    if args.net in hypernet_encoders:
        from models.hypernet_main import Hypernet_Main
        net = Hypernet_Main(
            encoder=hypernet_encoders[args.net],
            hypernet_params={'vqvae_dict_size': args.dict_size})
    elif args.net in registry:
        module_path, factory_name = registry[args.net]
        net = getattr(importlib.import_module(module_path), factory_name)()
    else:
        print('the network name you have entered is not supported yet')
        sys.exit(1)

    if args.gpu:  # use_gpu
        net = net.cuda()
    return net
from conv_cf import HCF
import torch
import numpy as np
import cv2
from models.vgg import vgg19_bn
import torchvision.transforms as T

# Use the CUDA device with index 1 when CUDA is available, otherwise the CPU.
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
model = vgg19_bn(pretrained=True, progress=True)
model = model.to(device)

# Module-level placeholders for the three feature maps.  get_conv() rebinds
# these names on every call (see its ``global`` statement).
conv3 = torch.zeros((1, 256, 16, 16), device=device)
conv4 = torch.zeros((1, 512, 8, 8), device=device)
conv5 = torch.zeros((1, 512, 4, 4), device=device)


def get_conv(model, extracted_roi):
    """Run ``extracted_roi`` through ``model.features`` and collect three outputs.

    The input is passed through ``model.features[0]`` .. ``model.features[52]``
    in order, with gradient tracking disabled.  The output after index 26 is
    kept as ``conv3``, the output after index 39 as ``conv4``, and the final
    output as ``conv5``.  (Presumably these indices are the ends of the
    corresponding VGG19-BN stages -- confirm against models.vgg.)

    The ``model`` parameter shadows the module-level ``model``; the three
    results are also written to the module-level ``conv3``/``conv4``/``conv5``.

    Returns:
        Tuple ``(conv3, conv4, conv5)``.
    """
    with torch.no_grad():
        global conv3, conv4, conv5  # can be removed
        for i in range(53):
            extracted_roi = model.features[i](extracted_roi)
            if i == 26:
                conv3 = extracted_roi
            if i == 39:
                conv4 = extracted_roi
        # Whatever the last feature layer produced.
        conv5 = extracted_roi
    return conv3, conv4, conv5


def get_border_roi(x1, y1, x2, y2, frame):
def get_network(args):
    """Return the network selected by ``args.net``.

    Most names map to a zero-argument factory in a ``models.*`` module.  The
    ``eff`` name instead calls
    ``EfficientNet.from_pretrained('efficientnet-b7', num_classes=2)``.
    Modules are imported only after the name has been recognised, so
    selecting one architecture never loads the others.

    Args:
        args: Parsed options.  Reads ``net`` and ``gpu``.

    Returns:
        The constructed network, moved with ``.cuda()`` when ``args.gpu`` is
        truthy (in which case ``use-gpu`` is printed).

    Exits the process with status 1 (previously 0, which reported success to
    the calling shell) when ``args.net`` is not supported.
    """
    # Function-scope import: the module is only needed to resolve factories.
    import importlib

    # Names whose factory is spelled differently from the option value.
    aliased = {
        'vgg16': ('models.vgg', 'vgg16_bn'),
        'vgg13': ('models.vgg', 'vgg13_bn'),
        'vgg11': ('models.vgg', 'vgg11_bn'),
        'vgg19': ('models.vgg', 'vgg19_bn'),
        'inceptionresnetv2': ('models.inceptionv4', 'inception_resnet_v2'),
        'stochasticdepth18': ('models.stochasticdepth',
                              'stochastic_depth_resnet18'),
        'stochasticdepth34': ('models.stochasticdepth',
                              'stochastic_depth_resnet34'),
        'stochasticdepth50': ('models.stochasticdepth',
                              'stochastic_depth_resnet50'),
        'stochasticdepth101': ('models.stochasticdepth',
                               'stochastic_depth_resnet101'),
    }
    # Names whose factory is spelled exactly like the option value, grouped
    # by the module that defines them.
    same_named = {
        'models.densenet': ('densenet121', 'densenet161', 'densenet169',
                            'densenet201'),
        'models.googlenet': ('googlenet',),
        'models.inceptionv3': ('inceptionv3',),
        'models.inceptionv4': ('inceptionv4',),
        'models.xception': ('xception',),
        'models.resnet': ('resnet18', 'resnet34', 'resnet50', 'resnet101',
                          'resnet152'),
        'models.preactresnet': ('preactresnet18', 'preactresnet34',
                                'preactresnet50', 'preactresnet101',
                                'preactresnet152'),
        'models.resnext': ('resnext50', 'resnext101', 'resnext152'),
        'models.shufflenet': ('shufflenet',),
        'models.shufflenetv2': ('shufflenetv2',),
        'models.squeezenet': ('squeezenet',),
        'models.mobilenet': ('mobilenet',),
        'models.mobilenetv2': ('mobilenetv2',),
        'models.nasnet': ('nasnet',),
        'models.attention': ('attention56', 'attention92'),
        'models.senet': ('seresnet18', 'seresnet34', 'seresnet50',
                         'seresnet101', 'seresnet152'),
        'models.wideresidual': ('wideresnet',),
        'models.efficientnet': ('efficientnetb0', 'efficientnetb1',
                                'efficientnetb2', 'efficientnetb3',
                                'efficientnetb4', 'efficientnetb5',
                                'efficientnetb6', 'efficientnetb7',
                                'efficientnetl2'),
    }

    registry = dict(aliased)
    for module_path, names in same_named.items():
        for name in names:
            registry[name] = (module_path, name)

    if args.net == 'eff':
        # The only entry that is not a zero-argument factory call.
        from models.efficientnet_pytorch import EfficientNet
        net = EfficientNet.from_pretrained('efficientnet-b7', num_classes=2)
    elif args.net in registry:
        module_path, factory_name = registry[args.net]
        net = getattr(importlib.import_module(module_path), factory_name)()
    else:
        print('the network name you have entered is not supported yet')
        sys.exit(1)

    if args.gpu:  # use_gpu
        net = net.cuda()
        print("use-gpu")
    return net
def test_model(modname='alexnet', pm_ch='both', bs=16): # hyperparameters batch_size = bs # device configuration device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') # determine number of input channels nch = 2 if pm_ch != 'both': nch = 1 # restore model model = None if modname == 'alexnet': model = alexnet(num_classes=3, in_ch=nch).to(device) elif modname == 'densenet': model = DenseNet(num_classes=3, in_ch=nch).to(device) elif modname == 'inception': model = inception_v3(num_classes=3, in_ch=nch).to(device) elif modname == 'resnet': model = resnet18(num_classes=3, in_ch=nch).to(device) elif modname == 'squeezenet': model = squeezenet1_1(num_classes=3, in_ch=nch).to(device) elif modname == 'vgg': model = vgg19_bn(in_ch=nch, num_classes=3).to(device) else: print('Model {} not defined.'.format(modname)) return # retrieve trained model # load path load_path = '../../../data/two_views/saved_models/{}/{}'.format( modname, pm_ch) model_pathname = os.path.join(load_path, 'model.ckpt') if not os.path.exists(model_pathname): print('Trained model file {} does not exist. 
Abort.'.format( model_pathname)) return model.load_state_dict(torch.load(model_pathname)) # load test dataset test_dataset = PixelMapDataset('test_file_list.txt', pm_ch) test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False) # test the model model.eval( ) # eval mode (batchnorm uses moving mean/variance instead of mini-batch mean/variance) with torch.no_grad(): correct = 0 total = 0 correct_cc_or_bkg = 0 ws_total = 0 ws_correct = 0 for view1, view2, labels in test_loader: view1 = view1.float().to(device) if modname == 'inception': view1 = nn.ZeroPad2d((0, 192, 102, 101))(view1) else: view1 = nn.ZeroPad2d((0, 117, 64, 64))(view1) labels = labels.to(device) outputs = model(view1) _, predicted = torch.max(outputs.data, 1) total += labels.size(0) correct += (predicted == labels).sum().item() for i in range(len(predicted)): if (predicted[i] < 2 and labels[i] < 2) or (predicted[i] == 2 and labels[i] == 2): correct_cc_or_bkg += 1 if labels[i] < 2: ws_total += 1 if (predicted[i] == labels[i]): ws_correct += 1 print('Model Performance:') print('Model:', modname) print('Channel:', pm_ch) print( '3-class Test Accuracy of the model on the test images: {}/{}, {:.2f} %' .format(correct, total, 100 * correct / total)) print( '2-class Test Accuracy of the model on the test images: {}/{}, {:.2f} %' .format(correct_cc_or_bkg, total, 100 * correct_cc_or_bkg / total)) print( 'Wrong-sign Test Accuracy of the model on the test images: {}/{}, {:.2f} %' .format(ws_correct, ws_total, 100 * ws_correct / ws_total))
def main():
    """Train and validate an image classifier configured from the command line.

    Steps: parse arguments, optionally configure TensorBoard logging, set the
    CUDA environment variables, build the train/validation datasets and
    loaders for ``args.dataset``, construct the model named by ``args.model``
    (and ``args.depth``), optionally resume from ``args.resume``, then run the
    train/validate loop, saving a checkpoint after every epoch.

    Writes the module-level ``args`` and ``best_err1``.

    Raises:
        Exception: for an unknown dataset, an unsupported model, or an
            unsupported depth for the chosen model family.
    """
    global args, best_err1
    args = parser.parse_args()

    # TensorBoard configure
    if args.tensorboard:
        configure('%s_checkpoints/%s'%(args.dataset, args.expname))

    # CUDA
    os.environ['CUDA_DEVICE_ORDER'] = "PCI_BUS_ID"
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_ids)
    if torch.cuda.is_available():
        # https://discuss.pytorch.org/t/what-does-torch-backends-cudnn-benchmark-do/5936
        cudnn.benchmark = True
        kwargs = {'num_workers': 2, 'pin_memory': True}
    else:
        kwargs = {'num_workers': 2}

    # Per-dataset normalisation statistics: (mean, std).
    norm_stats = {
        'cifar10': ([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]),
        'cifar100': ([0.5071, 0.4865, 0.4409], [0.2634, 0.2528, 0.2719]),
        'cub': ([0.4862, 0.4973, 0.4293], [0.2230, 0.2185, 0.2472]),
        'webvision': ([0.49274242, 0.46481857, 0.41779366],
                      [0.26831809, 0.26145372, 0.27042758]),
    }
    if args.dataset not in norm_stats:
        raise Exception('Unknown dataset: {}'.format(args.dataset))
    mean, std = norm_stats[args.dataset]
    normalize = transforms.Normalize(mean=mean, std=std)

    # Transforms: the horizontal flip is the only augmentation toggled by
    # args.augment.
    train_steps = [transforms.RandomResizedCrop(args.train_image_size)]
    if args.augment:
        train_steps.append(transforms.RandomHorizontalFlip())
    train_steps.extend([transforms.ToTensor(), normalize])
    train_transform = transforms.Compose(train_steps)

    val_transform = transforms.Compose([
        transforms.Resize(args.test_image_size),
        transforms.CenterCrop(args.test_crop_image_size),
        transforms.ToTensor(),
        normalize
    ])

    # Datasets
    num_classes = 10  # default 10 classes
    if args.dataset == 'cifar10':
        train_dataset = datasets.CIFAR10('./data/', train=True, download=True,
                                         transform=train_transform)
        val_dataset = datasets.CIFAR10('./data/', train=False, download=True,
                                       transform=val_transform)
        num_classes = 10
    elif args.dataset == 'cifar100':
        train_dataset = datasets.CIFAR100('./data/', train=True, download=True,
                                          transform=train_transform)
        val_dataset = datasets.CIFAR100('./data/', train=False, download=True,
                                        transform=val_transform)
        num_classes = 100
    elif args.dataset == 'cub':
        train_dataset = datasets.ImageFolder(
            '/media/ouc/30bd7817-d3a1-4e83-b7d9-5c0e373ae434/DuAngAng/datasets/CUB-200-2011/train/',
            transform=train_transform)
        val_dataset = datasets.ImageFolder(
            '/media/ouc/30bd7817-d3a1-4e83-b7d9-5c0e373ae434/DuAngAng/datasets/CUB-200-2011/test/',
            transform=val_transform)
        num_classes = 200
    elif args.dataset == 'webvision':
        train_dataset = datasets.ImageFolder(
            '/media/ouc/30bd7817-d3a1-4e83-b7d9-5c0e373ae434/LiuJing/WebVision/info/train',
            transform=train_transform)
        val_dataset = datasets.ImageFolder(
            '/media/ouc/30bd7817-d3a1-4e83-b7d9-5c0e373ae434/LiuJing/WebVision/info/val',
            transform=val_transform)
        num_classes = 1000
    else:
        raise Exception('Unknown dataset: {}'.format(args.dataset))

    # Data Loader (validation runs one image per batch)
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batchsize, shuffle=True, **kwargs)
    val_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=1, shuffle=False, **kwargs)

    # Create model
    if args.model == 'AlexNet':
        model = alexnet(pretrained=False, num_classes=num_classes)
    elif args.model == 'VGG':
        use_batch_normalization = True  # default use Batch Normalization
        if use_batch_normalization:
            if args.depth == 11:
                model = vgg11_bn(pretrained=False, num_classes=num_classes)
            elif args.depth == 13:
                model = vgg13_bn(pretrained=False, num_classes=num_classes)
            elif args.depth == 16:
                model = vgg16_bn(pretrained=False, num_classes=num_classes)
            elif args.depth == 19:
                model = vgg19_bn(pretrained=False, num_classes=num_classes)
            else:
                raise Exception('Unsupport VGG detph: {}, optional depths: 11, 13, 16 or 19'.format(args.depth))
        else:
            if args.depth == 11:
                model = vgg11(pretrained=False, num_classes=num_classes)
            elif args.depth == 13:
                model = vgg13(pretrained=False, num_classes=num_classes)
            elif args.depth == 16:
                model = vgg16(pretrained=False, num_classes=num_classes)
            elif args.depth == 19:
                model = vgg19(pretrained=False, num_classes=num_classes)
            else:
                raise Exception('Unsupport VGG detph: {}, optional depths: 11, 13, 16 or 19'.format(args.depth))
    elif args.model == 'Inception':
        model = inception_v3(pretrained=False, num_classes=num_classes)
    elif args.model == 'ResNet':
        if args.depth == 18:
            model = resnet18(pretrained=False, num_classes=num_classes)
        elif args.depth == 34:
            model = resnet34(pretrained=False, num_classes=num_classes)
        elif args.depth == 50:
            model = resnet50(pretrained=False, num_classes=num_classes)
        elif args.depth == 101:
            model = resnet101(pretrained=False, num_classes=num_classes)
        elif args.depth == 152:
            model = resnet152(pretrained=False, num_classes=num_classes)
        else:
            raise Exception('Unsupport ResNet detph: {}, optional depths: 18, 34, 50, 101 or 152'.format(args.depth))
    elif args.model == 'MPN-COV-ResNet':
        if args.depth == 18:
            model = mpn_cov_resnet18(pretrained=False, num_classes=num_classes)
        elif args.depth == 34:
            model = mpn_cov_resnet34(pretrained=False, num_classes=num_classes)
        elif args.depth == 50:
            model = mpn_cov_resnet50(pretrained=False, num_classes=num_classes)
        elif args.depth == 101:
            model = mpn_cov_resnet101(pretrained=False, num_classes=num_classes)
        elif args.depth == 152:
            model = mpn_cov_resnet152(pretrained=False, num_classes=num_classes)
        else:
            raise Exception('Unsupport MPN-COV-ResNet detph: {}, optional depths: 18, 34, 50, 101 or 152'.format(args.depth))
    else:
        # The original message had no placeholder, so .format() silently
        # discarded the model name.
        raise Exception('Unsupport model: {}'.format(args.model))

    # Get the number of model parameters
    print('Number of model parameters: {}'.format(
        sum(p.data.nelement() for p in model.parameters())))

    if torch.cuda.is_available():
        model = model.cuda()

    # Optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("==> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_err1 = checkpoint['best_err1']
            model.load_state_dict(checkpoint['state_dict'])
            print("==> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else:
            print("==> no checkpoint found at '{}'".format(args.resume))

    print(model)

    # Define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss()
    if torch.cuda.is_available():
        criterion = criterion.cuda()
    optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay,
                                nesterov=True)

    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch)

        # Train for one epoch
        train(train_loader, model, criterion, optimizer, epoch)

        # Evaluate on validation set
        err1 = validate(val_loader, model, criterion, epoch)

        # Remember best err1 and save checkpoint
        is_best = (err1 <= best_err1)
        best_err1 = min(err1, best_err1)
        print("Current best accuracy (error):", best_err1)
        save_checkpoint({
            'epoch': epoch+1,
            'state_dict': model.state_dict(),
            'best_err1': best_err1,
        }, is_best)

    print("Best accuracy (error):", best_err1)
def load_trained_model(model_name, train_set, device=torch.device('cpu')):
    """Load a pre-trained model and return it in eval mode on ``device``.

    For CIFAR train sets the architecture is built from the local ``models``
    package and its weights are restored from
    ``models/checkpoints/<model_name>_<train_set>.tar`` (resolved against the
    current working directory). For ``'imagenet'`` the model is obtained from
    ``torchvision.models`` with ``pretrained=True``.

    Args:
        model_name: str; for CIFAR one of 'resnet-110', 'alexnet', 'vgg19-bn',
            'wrn-28-10' (case and surrounding whitespace are ignored); for
            ImageNet the exact attribute name in ``torchvision.models``.
        train_set: str; 'cifar10', 'cifar100', 'cifar10imba' or 'imagenet'
            (case and surrounding whitespace are ignored).
        device: device the checkpoint is mapped to and the model is moved to;
            CPU by default.

    Returns:
        The PyTorch model in eval mode, moved to ``device``.

    Raises:
        NotImplementedError: if the train set or the CIFAR model name is not
            supported.
    """
    print('\nLoading pre-trained model')
    print('----| Model: {} Train set: {}'.format(model_name, train_set))

    # Normalise *before* dispatching so that e.g. ' CIFAR10 ' is recognised.
    train_set = train_set.lower().strip()

    if train_set.startswith('cifar'):
        # Load local cifar-trained models
        num_classes = {'cifar100': 100, 'cifar10': 10, 'cifar10imba': 10}
        model_name = model_name.lower().strip()

        # Load the saved state dict
        path_str = 'models/checkpoints/{}_{}.tar'.format(model_name, train_set)
        checkpoint_path = pathlib.Path(path_str).resolve()
        # map_location lets a checkpoint saved on a GPU be restored on the
        # requested device (CPU by default) instead of failing on CPU hosts.
        checkpoint = torch.load(checkpoint_path, map_location=device)
        state_dict = checkpoint['state_dict']

        if model_name == 'resnet-110':
            from models.resnet import resnet
            model = resnet(num_classes=num_classes[train_set], depth=110,
                           block_name='BasicBlock')
        elif model_name == 'alexnet':
            from models.alexnet import alexnet
            model = alexnet(num_classes=num_classes[train_set])
        elif model_name == 'vgg19-bn':
            from models.vgg import vgg19_bn
            model = vgg19_bn(num_classes=num_classes[train_set])
        elif model_name == 'wrn-28-10':
            from models.wrn import wrn
            model = wrn(num_classes=num_classes[train_set], depth=28,
                        widen_factor=10, dropRate=0.3)
        else:
            raise NotImplementedError(
                'unsupported CIFAR model: {!r}'.format(model_name))

        # Every supported architecture goes through _strip_parallel_model
        # before its weights are loaded.
        model.load_state_dict(_strip_parallel_model(state_dict))
    elif train_set == 'imagenet':
        # Thin wrapper to load PyTorch pretrained imagenet models
        import torchvision.models as models
        model = getattr(models, model_name)(pretrained=True)
    else:
        raise NotImplementedError(
            'unsupported train set: {!r}'.format(train_set))

    model.eval()
    return model.to(device)
import torchvision.transforms as transforms def save_image_tensor2pillow(input_tensor: torch.Tensor, filename): assert (len(input_tensor.shape) == 3) input_tensor = input_tensor.clone().detach() input_tensor = input_tensor.to(torch.device('cpu')) input_tensor = input_tensor.squeeze() input_tensor = input_tensor.mul_(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).type(torch.uint8).numpy() im = Image.fromarray(input_tensor) im.save(filename) torch.manual_seed(42) use_cuda = torch.cuda.is_available() device = torch.device("cuda" if use_cuda else "cpu") model = vgg19_bn() model.load_state_dict( torch.load("checkpoint/vgg_baseline.pth")) model.to(device) model.eval() cifar100_test_loader = get_test_dataloader( settings.CIFAR100_TRAIN_MEAN, settings.CIFAR100_TRAIN_STD, #settings.CIFAR100_PATH, num_workers=2, batch_size=16, shuffle=True ) adversary = GradientSignAttack(
def get_model(args, model_path=None):
    """Build the network named by ``args.model`` with a resized output layer.

    :param args: super arguments; reads ``scratch``, ``model``, ``num_classes``
        and, when not training from scratch, ``root_path`` and
        ``pretrained_models_path``.
    :param model_path: if not None, load already trained model parameters
        from the ``'model'`` entry of this checkpoint.
    :return: model
    :raises ValueError: if ``args.model`` is not a supported architecture.
    """
    if args.scratch:
        # train model from scratch
        pretrained, model_dir = False, None
        print("=> Loading model '{}' from scratch...".format(args.model))
    else:
        # train model with pretrained model
        pretrained = True
        model_dir = os.path.join(args.root_path, args.pretrained_models_path)
        print("=> Loading pretrained model '{}'...".format(args.model))

    resnet_names = ('resnet18', 'resnet34', 'resnet50', 'resnet101',
                    'resnet152')
    classifier_names = ('vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16',
                        'vgg16_bn', 'vgg19', 'vgg19_bn', 'alexnet')

    name = args.model
    if name not in resnet_names and name not in classifier_names:
        # Previously an unmatched name left `model` unbound and surfaced as
        # an UnboundLocalError further down.
        raise ValueError("unsupported model '{}'".format(name))

    # Each supported name is also the name of its module-level constructor;
    # the whitelist above keeps this lookup restricted to those constructors.
    model = globals()[name](pretrained=pretrained, model_dir=model_dir)

    # Swap the last fully-connected layer for one with args.num_classes outputs.
    if name in resnet_names:
        model.fc = nn.Linear(model.fc.in_features, args.num_classes)
    else:
        model.classifier[6] = nn.Linear(model.classifier[6].in_features,
                                        args.num_classes)

    # Load already trained model parameters and go on training
    if model_path is not None:
        # Map to CPU so a checkpoint written on a GPU also loads on hosts
        # without one; load_state_dict copies the values into the model.
        checkpoint = torch.load(model_path, map_location='cpu')
        model.load_state_dict(checkpoint['model'])

    return model
def get_network(args):
    """Return the network selected by ``args.model``.

    The factory for the requested architecture is imported lazily from the
    project's ``models`` package and called without arguments, so only the
    module of the selected network is imported.

    Args:
        args: object with a ``model`` attribute holding the network name.

    Returns:
        Whatever the selected factory returns.

    Raises:
        SystemExit: with status 1 when ``args.model`` is not supported
            (a bare ``sys.exit()`` would report success to the shell).
    """
    import importlib  # function-scope import: only this lookup needs it

    # module -> network names whose factory has the same name as the key
    same_name = {
        'models.densenet': ('densenet121', 'densenet161', 'densenet169',
                            'densenet201'),
        'models.googlenet': ('googlenet',),
        'models.inceptionv3': ('inceptionv3',),
        'models.inceptionv4': ('inceptionv4',),
        'models.xception': ('xception',),
        'models.resnet': ('resnet18', 'resnet34', 'resnet50', 'resnet101',
                          'resnet152'),
        'models.preactresnet': ('preactresnet18', 'preactresnet34',
                                'preactresnet50', 'preactresnet101',
                                'preactresnet152'),
        'models.resnext': ('resnext50', 'resnext101', 'resnext152'),
        'models.shufflenet': ('shufflenet',),
        'models.shufflenetv2': ('shufflenetv2',),
        'models.squeezenet': ('squeezenet',),
        'models.mobilenet': ('mobilenet',),
        'models.mobilenetv2': ('mobilenetv2',),
        'models.nasnet': ('nasnet',),
        'models.attention': ('attention56', 'attention92'),
        'models.senet': ('seresnet18', 'seresnet34', 'seresnet50',
                         'seresnet101', 'seresnet152'),
        'models.wideresidual': ('wideresnet',),
    }
    registry = {name: (module, name)
                for module, names in same_name.items() for name in names}
    # keys whose factory is named differently from the key
    registry.update({
        'vgg16': ('models.vgg', 'vgg16_bn'),
        'vgg13': ('models.vgg', 'vgg13_bn'),
        'vgg11': ('models.vgg', 'vgg11_bn'),
        'vgg19': ('models.vgg', 'vgg19_bn'),
        'inceptionresnetv2': ('models.inceptionv4', 'inception_resnet_v2'),
        'stochasticdepth18': ('models.stochasticdepth',
                              'stochastic_depth_resnet18'),
        'stochasticdepth34': ('models.stochasticdepth',
                              'stochastic_depth_resnet34'),
        'stochasticdepth50': ('models.stochasticdepth',
                              'stochastic_depth_resnet50'),
        'stochasticdepth101': ('models.stochasticdepth',
                               'stochastic_depth_resnet101'),
    })

    if args.model not in registry:
        print('the network name you have entered is not supported yet')
        sys.exit(1)

    module_name, factory_name = registry[args.model]
    factory = getattr(importlib.import_module(module_name), factory_name)
    model = factory()

    return model
def get_network(args, use_gpu=True, num_train=0):
    """Return the network selected by ``args.net``.

    ``args.dataset`` determines the class count ('cifar-10' -> 10,
    'cifar-100' -> 100, anything else -> 0). When ``args.ignoring`` is set,
    only 'resnet18' is available and it is built with ``resnet18_ign`` from
    ``models.resnet_ign`` around an unreduced cross-entropy criterion.
    Otherwise the factory is imported lazily from the ``models`` package; only
    the resnet factories receive ``num_classes``, every other factory is
    called without arguments.

    Args:
        args: object providing ``dataset``, ``ignoring`` and ``net`` (plus
            ``softmax`` and ``isalpha`` when ``ignoring`` is set).
        use_gpu: when true, the network is moved with ``.cuda()``.
        num_train: forwarded to ``resnet18_ign`` as ``num_train``.

    Returns:
        The constructed network.

    Raises:
        SystemExit: with status 1 when the network name is not supported for
            the selected mode.
    """
    import importlib  # function-scope import: only this lookup needs it

    def _unsupported():
        print('the network name you have entered is not supported yet')
        sys.exit(1)

    num_classes = {'cifar-10': 10, 'cifar-100': 100}.get(args.dataset, 0)

    if args.ignoring:
        if args.net != 'resnet18':
            # Previously `net` stayed unbound here and the function crashed
            # with UnboundLocalError.
            _unsupported()
        from models.resnet_ign import resnet18_ign
        criterion = nn.CrossEntropyLoss(reduction='none')
        net = resnet18_ign(criterion, num_classes=num_classes,
                           num_train=num_train, softmax=args.softmax,
                           isalpha=args.isalpha)
    else:
        # factories that are called with num_classes=...
        resnet_names = ('resnet18', 'resnet34', 'resnet50', 'resnet101',
                        'resnet152')
        # module -> network names whose factory has the same name as the key
        same_name = {
            'models.densenet': ('densenet121', 'densenet161', 'densenet169',
                                'densenet201'),
            'models.googlenet': ('googlenet',),
            'models.inceptionv3': ('inceptionv3',),
            'models.inceptionv4': ('inceptionv4',),
            'models.xception': ('xception',),
            'models.resnet': resnet_names,
            'models.preactresnet': ('preactresnet18', 'preactresnet34',
                                    'preactresnet50', 'preactresnet101',
                                    'preactresnet152'),
            'models.resnext': ('resnext50', 'resnext101', 'resnext152'),
            'models.shufflenet': ('shufflenet',),
            'models.shufflenetv2': ('shufflenetv2',),
            'models.squeezenet': ('squeezenet',),
            'models.mobilenet': ('mobilenet',),
            'models.mobilenetv2': ('mobilenetv2',),
            'models.nasnet': ('nasnet',),
            'models.attention': ('attention56', 'attention92'),
            'models.senet': ('seresnet18', 'seresnet34', 'seresnet50',
                             'seresnet101', 'seresnet152'),
        }
        registry = {name: (module, name)
                    for module, names in same_name.items() for name in names}
        # keys whose factory is named differently from the key
        registry.update({
            'vgg16': ('models.vgg', 'vgg16_bn'),
            'vgg13': ('models.vgg', 'vgg13_bn'),
            'vgg11': ('models.vgg', 'vgg11_bn'),
            'vgg19': ('models.vgg', 'vgg19_bn'),
            'inceptionresnetv2': ('models.inceptionv4',
                                  'inception_resnet_v2'),
        })

        if args.net not in registry:
            _unsupported()

        module_name, factory_name = registry[args.net]
        factory = getattr(importlib.import_module(module_name), factory_name)
        if args.net in resnet_names:
            net = factory(num_classes=num_classes)
        else:
            net = factory()

    if use_gpu:
        net = net.cuda()

    return net
def get_model(args):
    """Return the network for ``args.arch`` on the dataset ``args.datasets``.

    * 'ImageNet': ``models_imagenet.__dict__[args.arch]()``.
    * 'CIFAR100': the architecture is looked up in the local ``models``
      package (imported lazily); names that are not listed there fall back to
      ``resnet.__dict__[args.arch](num_classes=100)``.
    * 'CIFAR10' / 'MNIST': ``resnet.__dict__[args.arch](num_classes=10)``.

    Args:
        args: object providing ``datasets`` and ``arch``.

    Returns:
        The constructed network.

    Raises:
        ValueError: if ``args.datasets`` is not one of the datasets above.
        KeyError: if the architecture is looked up in ``models_imagenet`` or
            ``resnet`` and is not defined there.
    """
    import importlib  # function-scope import: only this lookup needs it

    if args.datasets == 'ImageNet':
        return models_imagenet.__dict__[args.arch]()

    class_counts = {'CIFAR10': 10, 'MNIST': 10, 'CIFAR100': 100}
    if args.datasets not in class_counts:
        # Previously `num_class` stayed unbound and the final lookup crashed
        # with UnboundLocalError.
        raise ValueError('unsupported dataset: {!r}'.format(args.datasets))
    num_class = class_counts[args.datasets]

    if args.datasets == 'CIFAR100':
        # module -> network names whose factory has the same name as the key
        same_name = {
            'models.densenet': ('densenet121', 'densenet161', 'densenet169',
                                'densenet201'),
            'models.googlenet': ('googlenet',),
            'models.inceptionv3': ('inceptionv3',),
            'models.inceptionv4': ('inceptionv4',),
            'models.xception': ('xception',),
            'models.resnet': ('resnet18', 'resnet34', 'resnet50', 'resnet101',
                              'resnet152'),
            'models.preactresnet': ('preactresnet18', 'preactresnet34',
                                    'preactresnet50', 'preactresnet101',
                                    'preactresnet152'),
            'models.resnext': ('resnext50', 'resnext101', 'resnext152'),
            'models.shufflenet': ('shufflenet',),
            'models.shufflenetv2': ('shufflenetv2',),
            'models.squeezenet': ('squeezenet',),
            'models.mobilenet': ('mobilenet',),
            'models.mobilenetv2': ('mobilenetv2',),
            'models.nasnet': ('nasnet',),
            'models.attention': ('attention56', 'attention92'),
            'models.senet': ('seresnet18', 'seresnet34', 'seresnet50',
                             'seresnet101', 'seresnet152'),
            'models.wideresidual': ('wideresnet',),
            'models.efficientnet': ('efficientnet',),
        }
        registry = {name: (module, name)
                    for module, names in same_name.items() for name in names}
        # keys whose factory is named differently from the key
        registry.update({
            'vgg16': ('models.vgg', 'vgg16_bn'),
            'vgg13': ('models.vgg', 'vgg13_bn'),
            'vgg11': ('models.vgg', 'vgg11_bn'),
            'vgg19': ('models.vgg', 'vgg19_bn'),
            'inceptionresnetv2': ('models.inceptionv4',
                                  'inception_resnet_v2'),
            'stochasticdepth18': ('models.stochasticdepth',
                                  'stochastic_depth_resnet18'),
            'stochasticdepth34': ('models.stochasticdepth',
                                  'stochastic_depth_resnet34'),
            'stochasticdepth50': ('models.stochasticdepth',
                                  'stochastic_depth_resnet50'),
            'stochasticdepth101': ('models.stochasticdepth',
                                   'stochastic_depth_resnet101'),
        })

        if args.arch in registry:
            module_name, factory_name = registry[args.arch]
            factory = getattr(importlib.import_module(module_name),
                              factory_name)
            if args.arch == 'efficientnet':
                # The only factory here that takes arguments; the third
                # positional value was the literal 100, which equals
                # num_class in this branch.
                return factory(1, 1, num_class, bn_momentum=0.9)
            return factory()

    # Everything else is resolved against the module-level `resnet` namespace.
    return resnet.__dict__[args.arch](num_classes=num_class)
cnt += 1 if cnt == 20: logger.info("early stop") break for ds in dataset: data_path = os.path.join(args.root, ds) cls = [ x for x in os.listdir(data_path) if os.path.isdir(os.path.join(data_path, x)) ] num_class = len(cls) models = { "vgg16": vgg.vgg16_bn(num_class), "vgg19": vgg.vgg19_bn(num_class), "densenet121": densenet.densenet121(num_class), "densenet161": densenet.densenet161(num_class), "resnet34": resnet.resnet34(num_class), "resnet50": resnet.resnet50(num_class), "resnet101": resnet.resnet101(num_class), "seresnet34": senet.seresnet34(num_class), "seresnet50": senet.seresnet50(num_class), "seresnet101": senet.seresnet101(num_class), "resnext34": resnext.resnext34(num_class), "resnext50": resnext.resnext50(num_class), "resnext101": resnext.resnext101(num_class), "shufflenet": shufflenet.shufflenet(num_class), "xception": xception.xception(num_class) } for net_name in models.keys():
def get_network(args):
    """Return the network selected by ``args.net``.

    ``args.task`` selects the class count ('cifar10' -> 10, 'cifar100' -> 100).
    Factories are imported lazily from the ``models`` package. Most of them
    are called with ``num_classes=<class count>``; the densenet, inceptionv3,
    inceptionv4, inceptionresnetv2, shufflenet(v2), squeezenet and attention
    factories are called without arguments.

    Args:
        args: object providing ``task``, ``net`` and ``gpu``.

    Returns:
        The constructed network, moved with ``.cuda()`` when ``args.gpu`` is
        truthy.

    Raises:
        SystemExit: with status 1 when ``args.net`` is not supported.
        ValueError: when the selected network needs a class count but
            ``args.task`` is neither 'cifar10' nor 'cifar100'.
    """
    import importlib  # function-scope import: only this lookup needs it

    nclass = {'cifar10': 10, 'cifar100': 100}.get(args.task)

    def _expand(same_name):
        return {name: (module, name)
                for module, names in same_name.items() for name in names}

    # Factories called as factory(num_classes=nclass).
    with_classes = _expand({
        # Yang added none bn vggs
        'models.vgg': ('vgg16', 'vgg13', 'vgg11', 'vgg19'),
        'models.googlenet': ('googlenet',),
        'models.xception': ('xception',),
        'models.sphereconvnet': ('sphereresnet32', 'plainresnet32'),
        'models.resnet': ('resnet18', 'resnet34', 'resnet50', 'resnet101',
                          'resnet152'),
        'models.preactresnet': ('preactresnet18', 'preactresnet34',
                                'preactresnet50', 'preactresnet101',
                                'preactresnet152'),
        'models.resnext': ('resnext50', 'resnext101', 'resnext152'),
        'models.mobilenet': ('mobilenet',),
        'models.mobilenetv2': ('mobilenetv2',),
        'models.nasnet': ('nasnet',),
        'models.senet': ('seresnet18', 'seresnet34', 'seresnet50',
                         'seresnet101', 'seresnet152'),
    })
    # keys whose factory is named differently from the key
    with_classes.update({
        'vgg16bn': ('models.vgg', 'vgg16_bn'),
        'vgg13bn': ('models.vgg', 'vgg13_bn'),
        'vgg11bn': ('models.vgg', 'vgg11_bn'),
        'vgg19bn': ('models.vgg', 'vgg19_bn'),
        'scnet': ('models.sphereconvnet', 'sphereconvnet'),
        'sphereresnet18': ('models.sphereconvnet', 'resnet18'),
        'ynet18': ('models.ynet', 'resnet18'),
        'ynet34': ('models.ynet', 'resnet34'),
        'ynet50': ('models.ynet', 'resnet50'),
        'ynet101': ('models.ynet', 'resnet101'),
        'ynet152': ('models.ynet', 'resnet152'),
    })

    # Factories called without arguments.
    without_classes = _expand({
        'models.densenet': ('densenet121', 'densenet161', 'densenet169',
                            'densenet201'),
        'models.inceptionv3': ('inceptionv3',),
        'models.inceptionv4': ('inceptionv4',),
        'models.shufflenet': ('shufflenet',),
        'models.shufflenetv2': ('shufflenetv2',),
        'models.squeezenet': ('squeezenet',),
        'models.attention': ('attention56', 'attention92'),
    })
    without_classes['inceptionresnetv2'] = ('models.inceptionv4',
                                            'inception_resnet_v2')

    if args.net in with_classes:
        if nclass is None:
            # Previously `nclass` stayed unbound and the factory call crashed
            # with UnboundLocalError.
            raise ValueError(
                "task {!r} has no known class count for network {!r}".format(
                    args.task, args.net))
        module_name, factory_name = with_classes[args.net]
        factory = getattr(importlib.import_module(module_name), factory_name)
        net = factory(num_classes=nclass)
    elif args.net in without_classes:
        module_name, factory_name = without_classes[args.net]
        net = getattr(importlib.import_module(module_name), factory_name)()
    else:
        print('the network name you have entered is not supported yet')
        sys.exit(1)

    if args.gpu:  # use_gpu
        net = net.cuda()

    return net
def get_network(args):
    """Return the network selected by ``args.net``.

    The factory is imported lazily from the ``models`` package and called
    without arguments.

    Args:
        args: object providing ``net`` (network name) and ``gpu``.

    Returns:
        The constructed network, moved with ``.cuda()`` when ``args.gpu`` is
        truthy.

    Raises:
        SystemExit: with status 1 when ``args.net`` is not supported
            (a bare ``sys.exit()`` would report success to the shell).
    """
    import importlib  # function-scope import: only this lookup needs it

    # network name -> (module, factory name)
    registry = {
        'vgg11': ('models.vgg', 'vgg11_bn'),
        'vgg13': ('models.vgg', 'vgg13_bn'),
        'vgg16': ('models.vgg', 'vgg16_bn'),
        'vgg19': ('models.vgg', 'vgg19_bn'),
        'googlenet': ('models.googLeNet', 'GoogLeNet'),
        'inceptionv3': ('models.inceptionv3', 'Inceptionv3'),
        'resnet18': ('models.resnet', 'resnet18'),
        'resnet34': ('models.resnet', 'resnet34'),
        'resnet50': ('models.resnet', 'resnet50'),
        'resnet101': ('models.resnet', 'resnet101'),
        'resnet152': ('models.resnet', 'resnet152'),
        'wrn': ('models.wideresnet', 'wideresnet'),
        'densenet121': ('models.densenet', 'densenet121'),
        'densenet161': ('models.densenet', 'densenet161'),
        'densenet169': ('models.densenet', 'densenet169'),
        'densenet201': ('models.densenet', 'densenet201'),
    }

    if args.net not in registry:
        print('the network name you have entered is not supported yet')
        sys.exit(1)

    module_name, factory_name = registry[args.net]
    net = getattr(importlib.import_module(module_name), factory_name)()

    if args.gpu:
        print("use gpu")
        net = net.cuda()

    return net
def get_network(key, num_cls=2, use_gpu=False):
    """Return the network selected by ``key``.

    Factories are imported lazily from the ``models`` package:

    * 'vgg11' / 'vgg13' / 'vgg16' / 'vgg19' -> ``models.vgg.vggNN_bn(num_cls)``
    * 'efficientNetb0' .. 'efficientNetb7'
      -> ``models.torchefficient.make_model(key, num_cls)``
    * 'resnext50_32x8d' / 'resnext101_32x8d' -> ``models.resnext.make_model(key)``
    * 'resnet18' / 'resnet34' / 'resnet50' / 'resnet101'
      -> ``models.resnet.make_model(key)``

    Args:
        key: network name.
        num_cls: class count handed to the vgg and efficientNet factories.
            NOTE(review): it is not passed to the resnet / resnext factories;
            confirm that is intended.
        use_gpu: when true, the network is moved with ``.cuda()``.

    Returns:
        The constructed network.

    Raises:
        NotImplementedError: for the placeholder key 'resnext'.
        SystemExit: with status 1 when ``key`` is not supported.
    """
    import importlib  # function-scope import: only this lookup needs it

    vgg_factories = {
        'vgg16': 'vgg16_bn',
        'vgg13': 'vgg13_bn',
        'vgg11': 'vgg11_bn',
        'vgg19': 'vgg19_bn',
    }
    efficientnet_keys = tuple('efficientNetb{}'.format(i) for i in range(8))

    if key in vgg_factories:
        factory = getattr(importlib.import_module('models.vgg'),
                          vgg_factories[key])
        net = factory(num_cls)
    elif key in efficientnet_keys:
        net = importlib.import_module('models.torchefficient').make_model(
            key, num_cls)
    elif key in ('resnext50_32x8d', 'resnext101_32x8d'):
        net = importlib.import_module('models.resnext').make_model(key)
    elif key in ('resnet50', 'resnet18', 'resnet34', 'resnet101'):
        net = importlib.import_module('models.resnet').make_model(key)
    elif key == 'resnext':
        # This branch used to print 'we will continue' and then crash with
        # UnboundLocalError because no network was ever created.
        raise NotImplementedError(
            "network 'resnext' is not implemented yet; use "
            "'resnext50_32x8d' or 'resnext101_32x8d'")
    else:
        print('the network name you have entered is not supported yet')
        sys.exit(1)

    if use_gpu:
        net = net.cuda()

    return net
def train_model(modname='alexnet', pm_ch='both', bs=16):
    """Train one of the supported classifiers on the pixel-map dataset.

    The dataset is read from 'training_file_list.txt', the model is trained
    for 10 epochs with Adam (learning rate 0.001) and cross-entropy loss, and
    its state dict is written to
    '../../../data/two_views/saved_models/<modname>/<pm_ch>/model.ckpt'.
    Only the first view of every sample is fed to the model.

    Args:
        modname (string): Name of the model. Has to be one of the values:
            'alexnet', batch 64
            'densenet'
            'inception'
            'resnet', batch 16
            'squeezenet', batch 16
            'vgg'
            Any other name prints a message and returns without training.
        pm_ch (string): pixelmap channel -- 'time', 'charge', 'both',
            default to both
        bs (int): batch size; also used as the logging interval in steps.
    """
    # device configuration
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    # hyper parameters
    max_epochs = 10
    learning_rate = 0.001

    # two input channels when both pixel maps are used, otherwise one
    nch = 1 if pm_ch != 'both' else 2

    ds = PixelMapDataset('training_file_list.txt', pm_ch)
    dl = torch.utils.data.DataLoader(dataset=ds, batch_size=bs, shuffle=True)

    # Lazily evaluated constructors, one per supported model name.
    builders = {
        'alexnet': lambda: alexnet(num_classes=3, in_ch=nch),
        'densenet': lambda: DenseNet(num_classes=3, in_ch=nch),
        'inception': lambda: inception_v3(num_classes=3, in_ch=nch),
        'resnet': lambda: resnet18(num_classes=3, in_ch=nch),
        'squeezenet': lambda: squeezenet1_1(num_classes=3, in_ch=nch),
        'vgg': lambda: vgg19_bn(in_ch=nch, num_classes=3),
    }
    if modname not in builders:
        print('Model {} not defined.'.format(modname))
        return
    model = builders[modname]().to(device)

    # loss and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    # The zero-padding depends only on the model, so build the layer once.
    if modname == 'inception':
        pad = nn.ZeroPad2d((0, 192, 102, 101))
    else:
        pad = nn.ZeroPad2d((0, 117, 64, 64))

    # training process
    total_step = len(dl)
    for epoch in range(max_epochs):
        for step, (view1, view2, local_labels) in enumerate(dl, start=1):
            inputs = pad(view1.float().to(device))
            targets = local_labels.to(device)

            # forward pass
            outputs = model(inputs)
            loss = criterion(outputs, targets)

            # backward and optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if step % bs == 0:
                print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'.format(
                    epoch + 1, max_epochs, step, total_step, loss.item()))

    # save the model checkpoint
    save_path = '../../../data/two_views/saved_models/{}/{}'.format(
        modname, pm_ch)
    os.makedirs(save_path, exist_ok=True)
    torch.save(model.state_dict(), os.path.join(save_path, 'model.ckpt'))
def load_model(model_name, training_type, configs):
    """Build the requested network, size it for the dataset and move it to CUDA.

    `configs.num_classes` and `configs.input_channels` are overwritten from
    `configs.dataset` (case-insensitive) before the model is built:
    "mnist"/"fashionmnist" -> 10 classes, 1 channel; "cifar-100" -> 100
    classes, 3 channels; "imagenet" -> 1000 classes, 3 channels; anything
    else -> 10 classes, 3 channels.

    Args:
        model_name: "Resnet18", "Resnet50", "Resnet101", "VGG19", or any
            name containing "efficientnet".
        training_type: "pretrained" or "untrained".
        configs: object with a `dataset` string attribute; it is mutated as
            described above.

    Returns:
        The model after `.cuda()`, wrapped in `nn.DataParallel` when more
        than one CUDA device is visible.

    Raises:
        ValueError: if `model_name` or `training_type` is not supported
            (previously this surfaced later as an unbound `model` variable).
    """
    dataset = configs.dataset.lower()

    # set the input channels and num_classes
    if dataset in ("mnist", "fashionmnist"):
        configs.num_classes = 10
        configs.input_channels = 1
    elif dataset == "cifar-100":
        configs.num_classes = 100
        configs.input_channels = 3
    elif dataset == "imagenet":
        configs.num_classes = 1000
        configs.input_channels = 3
    else:
        configs.num_classes = 10
        configs.input_channels = 3

    # Reject bad arguments up front instead of failing on an unbound `model`.
    is_efficientnet = "efficientnet" in model_name
    if not is_efficientnet and model_name not in (
            "Resnet18", "Resnet50", "Resnet101", "VGG19"):
        raise ValueError(
            f"Please provide a model: unsupported model_name {model_name!r}")
    if training_type not in ("pretrained", "untrained"):
        raise ValueError(
            f"training_type must be 'pretrained' or 'untrained', "
            f"got {training_type!r}")
    pretrained = training_type == "pretrained"

    print(f"Loading {training_type} {model_name}")

    # pick model
    if model_name == "Resnet18":
        if pretrained:
            model = torchvision.models.resnet18(pretrained=True)
            # Replace the head itself; assigning to `model.fc.Linear` would
            # only attach an unused child module and keep the 1000-way output.
            model.fc = nn.Linear(model.fc.in_features, configs.num_classes)
        else:
            model = ResNet18(num_classes=configs.num_classes,
                             input_channels=configs.input_channels)

    elif model_name == "Resnet50":
        if pretrained:
            model = torchvision.models.resnet50(pretrained=True)
            model.fc = nn.Linear(model.fc.in_features, configs.num_classes)
        else:
            model = ResNet50(num_classes=configs.num_classes,
                             input_channels=configs.input_channels)

    elif model_name == "Resnet101":
        # NOTE(review): unlike the other branches this keeps torchvision's
        # default classifier head and ignores configs.num_classes /
        # configs.input_channels -- confirm whether that is intended.
        if pretrained:
            model = torchvision.models.resnet101(pretrained=True)
        else:
            model = torchvision.models.resnet101()

    elif model_name == "VGG19":
        if pretrained:
            model = torchvision.models.vgg19(pretrained=True)
            model.classifier = nn.Sequential(
                nn.Linear(512 * 7 * 7, 512),
                nn.ReLU(True),
                nn.Dropout(),
                nn.Linear(512, 512),
                nn.ReLU(True),
                nn.Dropout(),
                nn.Linear(512, 512),
                nn.ReLU(True),
                # Output width follows the dataset instead of a fixed 10.
                nn.Linear(512, configs.num_classes),
            )
        else:
            model = vgg19_bn(in_channels=configs.input_channels,
                             num_classes=configs.num_classes)

    else:
        # Only names containing "efficientnet" reach this point.
        model = load_efficientnet(model_name, configs.num_classes,
                                  configs.input_channels, pretrained)

    # push model to cuda
    if torch.cuda.device_count() > 1:
        print(f"Number of GPUs available are {torch.cuda.device_count()}")
        model = nn.DataParallel(model)
        print("\nModel moved to Data Parallel")
    model.cuda()
    return model