def get_vgg_model(gpu, percentage_freeze):
    """Build a pretrained VGG19-BN with a new classifier head and print its summary.

    The backbone is ``vgg19_bn(True)``. Its ``classifier`` is replaced by
    Linear(25088, 4096) -> ReLU -> Dropout -> Linear(4096, 4096) -> ReLU ->
    Dropout -> Linear(4096, num_classes), where ``num_classes`` is a
    module-level name. ``summary`` is then called on the CUDA-moved model with
    input size ``(3, height, width)``; ``height`` and ``width`` are module-level
    names too.

    NOTE(review): ``percentage_freeze`` is currently ignored. The
    parameter-freezing loop was disabled, so every parameter stays trainable.
    NOTE(review): a later ``def get_vgg_model()`` in this file rebinds the same
    name, so this definition is shadowed once the module has loaded.

    Args:
        gpu: Value interpolated into the returned run name.
        percentage_freeze: Unused (see note above).

    Returns:
        Tuple ``(model, run_name)`` where ``run_name`` is ``"vgg_19_<gpu>_adam"``.
    """
    network = vgg19_bn(True)

    # Two identical Linear/ReLU/Dropout stages, then the class projection.
    hidden_dim = 4096
    head = []
    for in_dim in (25088, hidden_dim):
        head.extend([nn.Linear(in_dim, hidden_dim), nn.ReLU(True), nn.Dropout()])
    head.append(nn.Linear(hidden_dim, num_classes))
    network.classifier = nn.Sequential(*head)

    # Leftover from the disabled freezing logic: the call is kept, its result is unused.
    get_total_trainable_params(network)

    summary(network.cuda(), (3, height, width))
    return network, "vgg_19_{}_adam".format(gpu)
def get_vgg_model():
    """Build a pretrained VGG19-BN whose classifier ends in ``num_classes`` outputs.

    The input width of the new head is taken from the first ``nn.Linear`` in
    the pretrained classifier, so it always matches what the backbone feeds it.
    The previous hard-coded width of 512 does not match a torchvision-style VGG,
    whose classifier expects 512 * 7 * 7 = 25088 features (the value the other
    VGG builder in this file uses). With that mismatch the first forward pass
    fails. If no ``nn.Linear`` can be found, the old value of 512 is used, so
    VGG variants that really flatten to 512 behave exactly as before.

    ``num_classes`` is read from module scope.

    Returns:
        The model with its ``classifier`` replaced.
    """
    model = vgg19_bn(True)

    # Discover the flattened feature width from the classifier we are replacing.
    old_head = getattr(model, 'classifier', None)
    first_linear = None
    if isinstance(old_head, nn.Module):
        first_linear = next(
            (m for m in old_head.modules() if isinstance(m, nn.Linear)), None)
    in_features = first_linear.in_features if first_linear is not None else 512

    model.classifier = nn.Sequential(
        nn.Linear(in_features, 4096),
        nn.ReLU(True),
        nn.Dropout(),
        nn.Linear(4096, 4096),
        nn.ReLU(True),
        nn.Dropout(),
        nn.Linear(4096, num_classes),
    )
    return model
示例#3
0
def get_model(cfg, pretrained=False, load_param_from_ours=False):
    """Build the classifier described by ``cfg`` and wrap it in ``DataParallel``.

    Supported ``cfg.model`` values:
      * ``'custom'``    - ``cfg.patch_size`` must be 64 or 32
      * ``'googlenet'`` - Inception v3
      * ``'vgg'``       - ``cfg.model_info`` must be 19 or 16 (batch-norm variants)
      * ``'resnet'``    - ``cfg.model_info`` must be 18, 34, 50 or 101

    Args:
        cfg: Config object. Reads ``model``, ``num_classes``, ``gpu_id``, and,
            depending on the branch, ``patch_size``, ``model_info`` and
            ``init_model_file``.
        pretrained: Forwarded to the googlenet/vgg/resnet factories.
        load_param_from_ours: If true, ``pretrained`` is forced off and weights
            are restored from ``checkpoint['model_param']`` of
            ``cfg.init_model_file``.

    Returns:
        The model moved to CUDA and wrapped in ``torch.nn.DataParallel`` over
        ``cfg.gpu_id``.

    Prints the reason and calls ``sys.exit(-1)`` when the architecture or
    its variant is not supported.
    """
    if load_param_from_ours:
        pretrained = False

    model = None
    num_classes = cfg.num_classes
    if cfg.model == 'custom':
        from models import custom_net
        if cfg.patch_size == 64:
            model = custom_net.net_64(num_classes=num_classes)
        elif cfg.patch_size == 32:
            model = custom_net.net_32(num_classes=num_classes)
        else:
            print('Do not support present patch size %s' % cfg.patch_size)
    elif cfg.model == 'googlenet':
        from models import inception_v3
        model = inception_v3.inception_v3(pretrained=pretrained, num_classes=num_classes)
    elif cfg.model == 'vgg':
        from models import vgg
        if cfg.model_info == 19:
            model = vgg.vgg19_bn(pretrained=pretrained, num_classes=num_classes)
        elif cfg.model_info == 16:
            model = vgg.vgg16_bn(pretrained=pretrained, num_classes=num_classes)
        else:
            # Without this, the only message is "not support :vgg", which
            # wrongly suggests the whole family is unsupported.
            print('Do not support present vgg depth %s' % cfg.model_info)
    elif cfg.model == 'resnet':
        from models import resnet
        if cfg.model_info == 18:
            model = resnet.resnet18(pretrained=pretrained, num_classes=num_classes)
        elif cfg.model_info == 34:
            model = resnet.resnet34(pretrained=pretrained, num_classes=num_classes)
        elif cfg.model_info == 50:
            model = resnet.resnet50(pretrained=pretrained, num_classes=num_classes)
        elif cfg.model_info == 101:
            model = resnet.resnet101(pretrained=pretrained, num_classes=num_classes)
        else:
            print('Do not support present resnet depth %s' % cfg.model_info)
    if model is None:
        print('not support :' + cfg.model)
        sys.exit(-1)

    if load_param_from_ours:
        print('loading pretrained model from {0}'.format(cfg.init_model_file))
        # The model has not been moved to CUDA yet, so restore onto the CPU.
        # This also works when the checkpoint was saved from a GPU that is not
        # visible now. NOTE: torch.load unpickles the file; only load trusted
        # checkpoints.
        checkpoint = torch.load(cfg.init_model_file, map_location='cpu')
        model.load_state_dict(checkpoint['model_param'])

    model.cuda()
    print('shift model to parallel!')
    model = torch.nn.DataParallel(model, device_ids=cfg.gpu_id)
    return model
示例#4
0
def get_network(args,cfg):
    """Instantiate the network selected by ``args.net`` and move it to CUDA.

    Most factories are called with ``pretrained=args.pretrain`` and
    ``num_classes=cfg.PARA.train.num_classes``. ``lenet5``, ``inceptionv3`` and
    ``squeezenet`` are built with their own defaults. ``resnet18`` and
    ``resnet50`` are moved to GPU ``args.gpuid``; every other network goes to
    the current CUDA device. An unknown name prints a message and calls
    ``sys.exit()``.
    """
    def class_count():
        return cfg.PARA.train.num_classes

    # Lambdas defer both the global name lookup and the construction until the
    # requested entry is selected.
    builders = {
        'lenet5': lambda: LeNet5(),
        'alexnet': lambda: alexnet(pretrained=args.pretrain, num_classes=class_count()),
        'vgg16': lambda: vgg16(pretrained=args.pretrain, num_classes=class_count()),
        'vgg13': lambda: vgg13(pretrained=args.pretrain, num_classes=class_count()),
        'vgg11': lambda: vgg11(pretrained=args.pretrain, num_classes=class_count()),
        'vgg19': lambda: vgg19(pretrained=args.pretrain, num_classes=class_count()),
        'vgg16_bn': lambda: vgg16_bn(pretrained=args.pretrain, num_classes=class_count()),
        'vgg13_bn': lambda: vgg13_bn(pretrained=args.pretrain, num_classes=class_count()),
        'vgg11_bn': lambda: vgg11_bn(pretrained=args.pretrain, num_classes=class_count()),
        'vgg19_bn': lambda: vgg19_bn(pretrained=args.pretrain, num_classes=class_count()),
        'inceptionv3': lambda: inception_v3(),
        'resnet18': lambda: resnet18(pretrained=args.pretrain, num_classes=class_count()),
        'resnet34': lambda: resnet34(pretrained=args.pretrain, num_classes=class_count()),
        'resnet50': lambda: resnet50(pretrained=args.pretrain, num_classes=class_count()),
        'resnet101': lambda: resnet101(pretrained=args.pretrain, num_classes=class_count()),
        'resnet152': lambda: resnet152(pretrained=args.pretrain, num_classes=class_count()),
        'squeezenet': lambda: squeezenet1_0(),
    }
    # Networks placed on an explicit GPU index rather than the current device.
    pinned_to_gpuid = ('resnet18', 'resnet50')

    build = builders.get(args.net)
    if build is None:
        print('the network name you have entered is not supported yet')
        sys.exit()

    net = build()
    if args.net in pinned_to_gpuid:
        return net.cuda(args.gpuid)
    return net.cuda()
示例#5
0
def get_model(model, dataset, classify=True):
    """Build a VGG / CyVGG / ResNet / CyResNet model by name.

    Args:
        model: Architecture name. One of ``vgg11/13/16/19``,
            ``cyvgg11/13/16/19`` (the batch-norm factories are used),
            ``resnet20/32/44/56`` or ``cyresnet20/32/44/56``.
        dataset: Forwarded to the factory as ``dataset=``.
        classify: Forwarded to the VGG and CyVGG factories only. The ResNet
            factories are called without it.

    Returns:
        The constructed model.

    Raises:
        ValueError: If ``model`` is not a recognised name. Previously the
            name string itself was returned, which only failed later and far
            from the cause.
    """
    # Lambdas keep construction (and the module attribute lookup) lazy, so
    # only the requested architecture is built.
    builders = {
        # VGG models
        'vgg11': lambda: vgg.vgg11_bn(dataset=dataset, classify=classify),
        'vgg13': lambda: vgg.vgg13_bn(dataset=dataset, classify=classify),
        'vgg16': lambda: vgg.vgg16_bn(dataset=dataset, classify=classify),
        'vgg19': lambda: vgg.vgg19_bn(dataset=dataset, classify=classify),
        # CyVGG models
        'cyvgg11': lambda: cyvgg.cyvgg11_bn(dataset=dataset, classify=classify),
        'cyvgg13': lambda: cyvgg.cyvgg13_bn(dataset=dataset, classify=classify),
        'cyvgg16': lambda: cyvgg.cyvgg16_bn(dataset=dataset, classify=classify),
        'cyvgg19': lambda: cyvgg.cyvgg19_bn(dataset=dataset, classify=classify),
        # ResNet models
        'resnet20': lambda: resnet.resnet20(dataset=dataset),
        'resnet32': lambda: resnet.resnet32(dataset=dataset),
        'resnet44': lambda: resnet.resnet44(dataset=dataset),
        'resnet56': lambda: resnet.resnet56(dataset=dataset),
        # CyResNet models
        'cyresnet20': lambda: cyresnet.cyresnet20(dataset=dataset),
        'cyresnet32': lambda: cyresnet.cyresnet32(dataset=dataset),
        'cyresnet44': lambda: cyresnet.cyresnet44(dataset=dataset),
        'cyresnet56': lambda: cyresnet.cyresnet56(dataset=dataset),
    }

    build = builders.get(model)
    if build is None:
        raise ValueError('unsupported model name: {!r}'.format(model))
    return build()
def highest_resolution_test(cfg, file_name, sorted_abnormal_patches, time,
                            patch_out_size, b_map, p_map):
    """Refine ``b_map`` / ``p_map`` with the high-resolution VGG19-BN model.

    Builds ``vgg.vgg19_bn`` without ImageNet weights (``num_classes`` comes from
    module scope). It then restores the weights stored under ``'model_param'``
    in the checkpoint at the module-level ``checkpoint_path_hr``, moves the
    model to CUDA, wraps it in ``DataParallel`` over ``cfg.gpu_id`` and
    switches it to eval mode. Finally it delegates to
    ``prob_map.generate_prob_map_hr``, passing the model along with every
    other argument.

    Returns:
        The ``(b_map, p_map)`` pair produced by ``generate_prob_map_hr``.
    """
    hr_model = vgg.vgg19_bn(pretrained=False, num_classes=num_classes)
    hr_state = torch.load(checkpoint_path_hr)
    hr_model.load_state_dict(hr_state['model_param'])
    hr_model.cuda()
    hr_model = torch.nn.DataParallel(hr_model, device_ids=cfg.gpu_id)
    hr_model.eval()

    refined_b, refined_p = prob_map.generate_prob_map_hr(
        cfg, file_name, sorted_abnormal_patches, hr_model, time,
        patch_out_size, b_map, p_map)
    return refined_b, refined_p
示例#7
0
def get_model(cfg, pretrained=True, load_param_from_folder=False):
    """Build the classifier described by ``cfg`` and wrap it in ``DataParallel``.

    Supported ``cfg.model`` values:
      * ``'googlenet'`` - Inception v3
      * ``'vgg'``       - ``cfg.model_info`` must be 19 or 16 (batch-norm variants)
      * ``'resnet'``    - ``cfg.model_info`` must be 18, 34, 50 or 101

    Args:
        cfg: Config object. Reads ``model``, ``num_classes``, ``gpu_id``, and,
            depending on the branch, ``model_info`` and ``init_model_file``.
        pretrained: Forwarded to the factory.
        load_param_from_folder: If true, ``pretrained`` is forced off and
            weights are restored from ``checkpoint['model_param']`` of
            ``cfg.init_model_file``.

    Returns:
        The model wrapped in ``torch.nn.DataParallel`` over ``cfg.gpu_id``.
        When CUDA is available the model is moved to the GPU first.

    Prints the reason and calls ``sys.exit(-1)`` when the architecture or
    its variant is not supported.
    """
    if load_param_from_folder:
        pretrained = False

    model = None
    num_classes = cfg.num_classes
    if cfg.model == 'googlenet':
        from models import inception_v3
        model = inception_v3.inception_v3(pretrained=pretrained, num_classes=num_classes)
    elif cfg.model == 'vgg':
        from models import vgg
        if cfg.model_info == 19:
            model = vgg.vgg19_bn(pretrained=pretrained, num_classes=num_classes)
        elif cfg.model_info == 16:
            model = vgg.vgg16_bn(pretrained=pretrained, num_classes=num_classes)
        else:
            # Without this, the only message is "not support :vgg".
            print('Do not support present vgg depth %s' % cfg.model_info)
    elif cfg.model == 'resnet':
        from models import resnet
        if cfg.model_info == 18:
            model = resnet.resnet18(pretrained=pretrained, num_classes=num_classes)
        elif cfg.model_info == 34:
            model = resnet.resnet34(pretrained=pretrained, num_classes=num_classes)
        elif cfg.model_info == 50:
            model = resnet.resnet50(pretrained=pretrained, num_classes=num_classes)
        elif cfg.model_info == 101:
            model = resnet.resnet101(pretrained=pretrained, num_classes=num_classes)
        else:
            print('Do not support present resnet depth %s' % cfg.model_info)
    if model is None:
        print('not support :' + cfg.model)
        sys.exit(-1)

    if load_param_from_folder:
        print('loading pretrained model from {0}'.format(cfg.init_model_file))
        # Restore onto the CPU; load_state_dict copies into the model's own
        # parameters anyway, and this also works when the checkpoint was saved
        # from a GPU that is not visible now. NOTE: torch.load unpickles the
        # file; only load trusted checkpoints.
        checkpoint = torch.load(cfg.init_model_file, map_location='cpu')
        model.load_state_dict(checkpoint['model_param'])

    # DataParallel does not move the wrapped module, and its forward requires
    # the parameters to already be on device_ids[0]. Move the model here, as
    # the sibling get_model does. The guard keeps CPU-only runs working:
    # without CUDA, DataParallel just calls the module directly.
    if torch.cuda.is_available():
        model.cuda()

    print('shift model to parallel!')
    model = torch.nn.DataParallel(model, device_ids=cfg.gpu_id)
    return model
示例#8
0
def get_network(args, use_gpu=False):
    """Return the batch-norm VGG selected by ``args.net``.

    ``args.net`` must be ``'vgg11'``, ``'vgg13'``, ``'vgg16'`` or ``'vgg19'``.
    The matching ``models.vgg.vggNN_bn`` factory is imported lazily and called
    with no arguments. Any other name prints a message and calls
    ``sys.exit()``. When ``use_gpu`` is truthy the network is moved to CUDA.
    """
    factory_name = {
        'vgg16': 'vgg16_bn',
        'vgg13': 'vgg13_bn',
        'vgg11': 'vgg11_bn',
        'vgg19': 'vgg19_bn',
    }.get(args.net)

    if factory_name is None:
        print('the network name you have entered is not supported yet')
        sys.exit()

    vgg_module = __import__('models.vgg', fromlist=[factory_name])
    net = getattr(vgg_module, factory_name)()

    return net.cuda() if use_gpu else net
示例#9
0
def get_network(args):
    """Return the network named by ``args.net``, moved to CUDA if ``args.gpu``.

    Supported names are ``'vgg11'``, ``'vgg19'`` (batch-norm factories from
    ``models.vgg``) and ``'resnet18'``, ``'resnet50'`` (from ``models.resnet``).
    The factory's module is imported only when selected, and the factory is
    called with no arguments. Other names print a message and call
    ``sys.exit()``.
    """
    registry = {
        'vgg11': ('models.vgg', 'vgg11_bn'),
        'vgg19': ('models.vgg', 'vgg19_bn'),
        'resnet18': ('models.resnet', 'resnet18'),
        'resnet50': ('models.resnet', 'resnet50'),
    }

    entry = registry.get(args.net)
    if entry is None:
        print('the network name you have entered is not supported yet')
        sys.exit()

    module_path, factory_name = entry
    net = getattr(__import__(module_path, fromlist=[factory_name]), factory_name)()

    if args.gpu:  # use_gpu
        net = net.cuda()
    return net
def get_network(args, use_gpu=True):
    """Return the network named by ``args.net``, optionally moved to CUDA.

    Each supported name maps to a zero-argument factory in the project's
    ``models`` package, e.g. ``'resnet50'`` -> ``models.resnet.resnet50`` and
    ``'vgg16'`` -> ``models.vgg.vgg16_bn``. Only the selected factory's module
    is imported. Unsupported names print a message and call ``sys.exit()``.
    When ``use_gpu`` is truthy the network is moved to CUDA.
    """
    registry = {}
    for depth in (11, 13, 16, 19):
        registry['vgg%d' % depth] = ('models.vgg', 'vgg%d_bn' % depth)
    # Families where key == factory name and all depths live in one module.
    for family, module_path, depths in (
            ('densenet', 'models.densenet', (121, 161, 169, 201)),
            ('resnet', 'models.resnet', (18, 34, 50, 101, 152)),
            ('preactresnet', 'models.preactresnet', (18, 34, 50, 101, 152)),
            ('resnext', 'models.resnext', (50, 101, 152)),
            ('attention', 'models.attention', (56, 92)),
            ('seresnet', 'models.senet', (18, 34, 50, 101, 152)),
    ):
        for depth in depths:
            key = '%s%d' % (family, depth)
            registry[key] = (module_path, key)
    # Single-variant architectures whose module and factory share the key.
    for arch in ('googlenet', 'inceptionv3', 'inceptionv4', 'xception',
                 'shufflenet', 'shufflenetv2', 'squeezenet', 'mobilenet',
                 'mobilenetv2', 'nasnet'):
        registry[arch] = ('models.' + arch, arch)
    registry['inceptionresnetv2'] = ('models.inceptionv4', 'inception_resnet_v2')

    entry = registry.get(args.net)
    if entry is None:
        print('the network name you have entered is not supported yet')
        sys.exit()

    module_path, factory_name = entry
    net = getattr(__import__(module_path, fromlist=[factory_name]), factory_name)()

    if use_gpu:
        net = net.cuda()

    return net
示例#11
0
def get_network(args):
    """Return the network named by ``args.net``, moved to CUDA if ``args.gpu``.

    Most names map to a zero-argument factory in the project's ``models``
    package, imported lazily for the selected name only. ``'hyper_resnet'``
    and ``'hyper_resnet_wo_bn'`` instead build
    ``models.hypernet_main.Hypernet_Main`` with encoder ``"resnet18"`` or
    ``"resnet18_wobn"`` and ``hypernet_params={'vqvae_dict_size':
    args.dict_size}``. Unsupported names print a message and call
    ``sys.exit()``.
    """
    registry = {}
    for depth in (11, 13, 16, 19):
        registry['vgg%d' % depth] = ('models.vgg', 'vgg%d_bn' % depth)
    # Families where key == factory name and all depths live in one module.
    for family, module_path, depths in (
            ('densenet', 'models.densenet', (121, 161, 169, 201)),
            ('resnet', 'models.resnet', (18, 34, 50, 101, 152)),
            ('preactresnet', 'models.preactresnet', (18, 34, 50, 101, 152)),
            ('resnext', 'models.resnext', (50, 101, 152)),
            ('attention', 'models.attention', (56, 92)),
            ('seresnet', 'models.senet', (18, 34, 50, 101, 152)),
    ):
        for depth in depths:
            key = '%s%d' % (family, depth)
            registry[key] = (module_path, key)
    for depth in (18, 34, 50, 101):
        registry['stochasticdepth%d' % depth] = (
            'models.stochasticdepth', 'stochastic_depth_resnet%d' % depth)
    # Single-variant architectures whose module and factory share the key.
    for arch in ('googlenet', 'inceptionv3', 'inceptionv4', 'xception',
                 'shufflenet', 'shufflenetv2', 'squeezenet', 'mobilenet',
                 'mobilenetv2', 'nasnet'):
        registry[arch] = ('models.' + arch, arch)
    registry['inceptionresnetv2'] = ('models.inceptionv4', 'inception_resnet_v2')
    registry['wideresnet'] = ('models.wideresidual', 'wideresnet')
    registry['normal_resnet'] = ('models.normal_resnet', 'resnet18')
    registry['normal_resnet_wo_bn'] = ('models.normal_resnet_wo_bn', 'resnet18')

    # The hypernet variants need constructor arguments, so they bypass the registry.
    hypernet_encoders = {
        'hyper_resnet': 'resnet18',
        'hyper_resnet_wo_bn': 'resnet18_wobn',
    }

    if args.net in hypernet_encoders:
        from models.hypernet_main import Hypernet_Main
        net = Hypernet_Main(
            encoder=hypernet_encoders[args.net],
            hypernet_params={'vqvae_dict_size': args.dict_size})
    elif args.net in registry:
        module_path, factory_name = registry[args.net]
        net = getattr(__import__(module_path, fromlist=[factory_name]), factory_name)()
    else:
        print('the network name you have entered is not supported yet')
        sys.exit()

    if args.gpu:  # use_gpu
        net = net.cuda()

    return net
示例#12
0
from conv_cf import HCF
import torch
import numpy as np
import cv2
from models.vgg import vgg19_bn
import torchvision.transforms as T

# Second CUDA device when CUDA is available, otherwise the CPU.
device = torch.device("cuda:1") if torch.cuda.is_available() else torch.device("cpu")

# Pretrained VGG19-BN backbone, placed on `device` at import time.
model = vgg19_bn(pretrained=True, progress=True).to(device)

# Zero placeholders for the three captured activations, allocated on `device`.
# NOTE(review): the spatial sizes 16/8/4 look like they assume a 128x128 input
# to a VGG19-BN feature stack; confirm against the callers.
conv3, conv4, conv5 = (
    torch.zeros(shape, device=device)
    for shape in ((1, 256, 16, 16), (1, 512, 8, 8), (1, 512, 4, 4))
)


def get_conv(model, extracted_roi):
    """Feed ``extracted_roi`` through ``model.features[0..52]`` and tap activations.

    Runs under ``torch.no_grad()``. Three outputs are captured: the output of
    layer index 26, the output of layer index 39, and the output of the last
    layer run (index 52). Each is written to the module-level ``conv3`` /
    ``conv4`` / ``conv5`` global as soon as it is produced.

    Returns:
        Tuple ``(conv3, conv4, conv5)``.
    """
    global conv3, conv4, conv5  # kept for code that reads the module globals

    with torch.no_grad():
        activation = extracted_roi
        for layer_idx in range(53):
            layer = model.features[layer_idx]
            activation = layer(activation)
            if layer_idx == 26:
                conv3 = activation
            elif layer_idx == 39:
                conv4 = activation
        conv5 = activation

    return conv3, conv4, conv5


def get_border_roi(x1, y1, x2, y2, frame):
示例#13
0
def get_network(args):
    """Return the network named by ``args.net``, moved to CUDA if ``args.gpu``.

    Most names map to a zero-argument factory in the project's ``models``
    package, imported lazily for the selected name only. ``'eff'`` instead loads
    ``EfficientNet.from_pretrained('efficientnet-b7', num_classes=2)`` from
    ``models.efficientnet_pytorch``. Unsupported names print a message and call
    ``sys.exit()``. When the network is moved to the GPU, ``"use-gpu"`` is
    printed.
    """
    registry = {}
    for depth in (11, 13, 16, 19):
        registry['vgg%d' % depth] = ('models.vgg', 'vgg%d_bn' % depth)
    # Families where key == factory name and all variants live in one module.
    for family, module_path, variants in (
            ('densenet', 'models.densenet', (121, 161, 169, 201)),
            ('resnet', 'models.resnet', (18, 34, 50, 101, 152)),
            ('preactresnet', 'models.preactresnet', (18, 34, 50, 101, 152)),
            ('resnext', 'models.resnext', (50, 101, 152)),
            ('attention', 'models.attention', (56, 92)),
            ('seresnet', 'models.senet', (18, 34, 50, 101, 152)),
            ('efficientnetb', 'models.efficientnet', range(8)),
    ):
        for variant in variants:
            key = '%s%d' % (family, variant)
            registry[key] = (module_path, key)
    for depth in (18, 34, 50, 101):
        registry['stochasticdepth%d' % depth] = (
            'models.stochasticdepth', 'stochastic_depth_resnet%d' % depth)
    # Single-variant architectures whose module and factory share the key.
    for arch in ('googlenet', 'inceptionv3', 'inceptionv4', 'xception',
                 'shufflenet', 'shufflenetv2', 'squeezenet', 'mobilenet',
                 'mobilenetv2', 'nasnet'):
        registry[arch] = ('models.' + arch, arch)
    registry['inceptionresnetv2'] = ('models.inceptionv4', 'inception_resnet_v2')
    registry['wideresnet'] = ('models.wideresidual', 'wideresnet')
    registry['efficientnetl2'] = ('models.efficientnet', 'efficientnetl2')

    if args.net == 'eff':
        from models.efficientnet_pytorch import EfficientNet
        net = EfficientNet.from_pretrained('efficientnet-b7', num_classes=2)
    elif args.net in registry:
        module_path, factory_name = registry[args.net]
        net = getattr(__import__(module_path, fromlist=[factory_name]), factory_name)()
    else:
        print('the network name you have entered is not supported yet')
        sys.exit()

    if args.gpu:  # use_gpu
        net = net.cuda()
        print("use-gpu")

    return net
示例#14
0
def _accuracy_pct(hits, count):
    """Return ``100 * hits / count``, or 0.0 when ``count`` is 0 (avoids ZeroDivisionError)."""
    return 100 * hits / count if count else 0.0


def test_model(modname='alexnet', pm_ch='both', bs=16):
    """Evaluate a trained 3-class model on the test pixel maps and print accuracies.

    The model named by ``modname`` is built with 2 input channels when
    ``pm_ch == 'both'`` and 1 otherwise. Its weights are restored from
    ``../../../data/two_views/saved_models/<modname>/<pm_ch>/model.ckpt`` and it
    is evaluated on ``PixelMapDataset('test_file_list.txt', pm_ch)``. Only the
    first view of each sample is used. It is zero-padded before the forward
    pass, with a larger pad for ``'inception'``.

    Printed metrics:
      * 3-class: exact label match.
      * 2-class: a sample counts when prediction and label are both < 2, or
        both equal 2.
      * Wrong-sign: exact-match accuracy over samples whose label is < 2.
    A metric whose denominator is 0 is reported as 0.00 %; previously this
    raised ZeroDivisionError.

    Args:
        modname: ``'alexnet'``, ``'densenet'``, ``'inception'``, ``'resnet'``,
            ``'squeezenet'`` or ``'vgg'``.
        pm_ch: Pixel-map channel selector, forwarded to the dataset.
        bs: Test batch size.

    Returns:
        None. Returns early, after printing a message, when ``modname`` is
        unknown or the checkpoint file does not exist.
    """
    # hyperparameters
    batch_size = bs

    # device configuration
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    # number of input channels
    nch = 2 if pm_ch == 'both' else 1

    # restore model
    if modname == 'alexnet':
        model = alexnet(num_classes=3, in_ch=nch).to(device)
    elif modname == 'densenet':
        model = DenseNet(num_classes=3, in_ch=nch).to(device)
    elif modname == 'inception':
        model = inception_v3(num_classes=3, in_ch=nch).to(device)
    elif modname == 'resnet':
        model = resnet18(num_classes=3, in_ch=nch).to(device)
    elif modname == 'squeezenet':
        model = squeezenet1_1(num_classes=3, in_ch=nch).to(device)
    elif modname == 'vgg':
        model = vgg19_bn(in_ch=nch, num_classes=3).to(device)
    else:
        print('Model {} not defined.'.format(modname))
        return

    # retrieve trained model
    load_path = '../../../data/two_views/saved_models/{}/{}'.format(
        modname, pm_ch)
    model_pathname = os.path.join(load_path, 'model.ckpt')
    if not os.path.exists(model_pathname):
        print('Trained model file {} does not exist. Abort.'.format(
            model_pathname))
        return
    # map_location lets a checkpoint saved on a GPU be restored on a CPU-only
    # host. NOTE: torch.load unpickles the file; only load trusted checkpoints.
    model.load_state_dict(torch.load(model_pathname, map_location=device))

    # load test dataset
    test_dataset = PixelMapDataset('test_file_list.txt', pm_ch)
    test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                              batch_size=batch_size,
                                              shuffle=False)

    # The padding depends only on the architecture, so build it once.
    if modname == 'inception':
        pad = nn.ZeroPad2d((0, 192, 102, 101))
    else:
        pad = nn.ZeroPad2d((0, 117, 64, 64))

    # eval mode (batchnorm uses moving mean/variance instead of mini-batch mean/variance)
    model.eval()
    with torch.no_grad():
        correct = 0
        total = 0
        correct_cc_or_bkg = 0
        ws_total = 0
        ws_correct = 0
        for view1, _view2, labels in test_loader:
            view1 = pad(view1.float().to(device))
            labels = labels.to(device)
            outputs = model(view1)
            _, predicted = torch.max(outputs.data, 1)

            # Element-wise masks replace the former per-sample Python loop.
            exact = predicted == labels
            ws_mask = labels < 2
            same_group = (ws_mask & (predicted < 2)) | ((labels == 2) & (predicted == 2))

            total += labels.size(0)
            correct += exact.sum().item()
            correct_cc_or_bkg += same_group.sum().item()
            ws_total += ws_mask.sum().item()
            ws_correct += (exact & ws_mask).sum().item()

        print('Model Performance:')
        print('Model:', modname)
        print('Channel:', pm_ch)
        print(
            '3-class Test Accuracy of the model on the test images: {}/{}, {:.2f} %'
            .format(correct, total, _accuracy_pct(correct, total)))
        print(
            '2-class Test Accuracy of the model on the test images: {}/{}, {:.2f} %'
            .format(correct_cc_or_bkg, total, _accuracy_pct(correct_cc_or_bkg, total)))
        print(
            'Wrong-sign Test Accuracy of the model on the test images: {}/{}, {:.2f} %'
            .format(ws_correct, ws_total, _accuracy_pct(ws_correct, ws_total)))
示例#15
0
def main():
    """Train and validate an image classifier configured from the command line.

    Parses arguments with the module-level ``parser`` into the global ``args``,
    builds transforms, datasets and data loaders for ``args.dataset``,
    constructs the model selected by ``args.model``/``args.depth``, optionally
    resumes from ``args.resume``, then runs Nesterov SGD for epochs
    ``args.start_epoch`` .. ``args.epochs - 1``. After every epoch the
    validation error returned by ``validate`` is compared with the global
    ``best_err1`` and a checkpoint is written through ``save_checkpoint``.

    Raises:
        Exception: for an unknown dataset or model, or an unsupported depth.
    """
    global args, best_err1
    args = parser.parse_args()

    # TensorBoard configure
    if args.tensorboard:
        configure('%s_checkpoints/%s'%(args.dataset, args.expname))

    # CUDA
    os.environ['CUDA_DEVICE_ORDER'] = "PCI_BUS_ID"
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_ids)
    if torch.cuda.is_available():
        cudnn.benchmark = True  # https://discuss.pytorch.org/t/what-does-torch-backends-cudnn-benchmark-do/5936
        kwargs = {'num_workers': 2, 'pin_memory': True}
    else:
        kwargs = {'num_workers': 2}

    # Data loading code: per-dataset channel mean/std for normalisation
    if args.dataset == 'cifar10':
        normalize = transforms.Normalize(mean=[0.4914, 0.4822, 0.4465],
                                         std=[0.2023, 0.1994, 0.2010])
    elif args.dataset == 'cifar100':
        normalize = transforms.Normalize(mean=[0.5071, 0.4865, 0.4409],
                                         std=[0.2634, 0.2528, 0.2719])
    elif args.dataset == 'cub':
        normalize = transforms.Normalize(mean=[0.4862, 0.4973, 0.4293],
                                         std=[0.2230, 0.2185, 0.2472])
    elif args.dataset == 'webvision':
        normalize = transforms.Normalize(mean=[0.49274242, 0.46481857, 0.41779366],
                                         std=[0.26831809, 0.26145372, 0.27042758])
    else:
        raise Exception('Unknown dataset: {}'.format(args.dataset))

    # Transforms (horizontal flips only when --augment is set)
    if args.augment:
        train_transform = transforms.Compose([
            transforms.RandomResizedCrop(args.train_image_size),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
    else:
        train_transform = transforms.Compose([
            transforms.RandomResizedCrop(args.train_image_size),
            transforms.ToTensor(),
            normalize,
        ])
    val_transform = transforms.Compose([
        transforms.Resize(args.test_image_size),
        transforms.CenterCrop(args.test_crop_image_size),
        transforms.ToTensor(),
        normalize
    ])

    # Datasets
    num_classes = 10    # default 10 classes
    if args.dataset == 'cifar10':
        train_dataset = datasets.CIFAR10('./data/', train=True, download=True, transform=train_transform)
        val_dataset = datasets.CIFAR10('./data/', train=False, download=True, transform=val_transform)
        num_classes = 10
    elif args.dataset == 'cifar100':
        train_dataset = datasets.CIFAR100('./data/', train=True, download=True, transform=train_transform)
        val_dataset = datasets.CIFAR100('./data/', train=False, download=True, transform=val_transform)
        num_classes = 100
    elif args.dataset == 'cub':
        # NOTE(review): machine-specific absolute paths; consider making these CLI arguments.
        train_dataset = datasets.ImageFolder('/media/ouc/30bd7817-d3a1-4e83-b7d9-5c0e373ae434/DuAngAng/datasets/CUB-200-2011/train/',
                                             transform=train_transform)
        val_dataset = datasets.ImageFolder('/media/ouc/30bd7817-d3a1-4e83-b7d9-5c0e373ae434/DuAngAng/datasets/CUB-200-2011/test/',
                                           transform=val_transform)
        num_classes = 200
    elif args.dataset == 'webvision':
        train_dataset = datasets.ImageFolder('/media/ouc/30bd7817-d3a1-4e83-b7d9-5c0e373ae434/LiuJing/WebVision/info/train',
                                             transform=train_transform)
        val_dataset = datasets.ImageFolder('/media/ouc/30bd7817-d3a1-4e83-b7d9-5c0e373ae434/LiuJing/WebVision/info/val',
                                           transform=val_transform)
        num_classes = 1000
    else:
        raise Exception('Unknown dataset: {}'.format(args.dataset))

    # Data Loader
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batchsize, shuffle=True, **kwargs)
    val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=1, shuffle=False, **kwargs)

    # Create model
    if args.model == 'AlexNet':
        model = alexnet(pretrained=False, num_classes=num_classes)
    elif args.model == 'VGG':
        # Hard-coded True, so the plain (non-BN) VGG branch below is currently unreachable.
        use_batch_normalization = True  # default use Batch Normalization
        if use_batch_normalization:
            if args.depth == 11:
                model = vgg11_bn(pretrained=False, num_classes=num_classes)
            elif args.depth == 13:
                model = vgg13_bn(pretrained=False, num_classes=num_classes)
            elif args.depth == 16:
                model = vgg16_bn(pretrained=False, num_classes=num_classes)
            elif args.depth == 19:
                model = vgg19_bn(pretrained=False, num_classes=num_classes)
            else:
                raise Exception('Unsupport VGG detph: {}, optional depths: 11, 13, 16 or 19'.format(args.depth))
        else:
            if args.depth == 11:
                model = vgg11(pretrained=False, num_classes=num_classes)
            elif args.depth == 13:
                model = vgg13(pretrained=False, num_classes=num_classes)
            elif args.depth == 16:
                model = vgg16(pretrained=False, num_classes=num_classes)
            elif args.depth == 19:
                model = vgg19(pretrained=False, num_classes=num_classes)
            else:
                raise Exception('Unsupport VGG detph: {}, optional depths: 11, 13, 16 or 19'.format(args.depth))
    elif args.model == 'Inception':
        model = inception_v3(pretrained=False, num_classes=num_classes)
    elif args.model == 'ResNet':
        if args.depth == 18:
            model = resnet18(pretrained=False, num_classes=num_classes)
        elif args.depth == 34:
            model = resnet34(pretrained=False, num_classes=num_classes)
        elif args.depth == 50:
            model = resnet50(pretrained=False, num_classes=num_classes)
        elif args.depth == 101:
            model = resnet101(pretrained=False, num_classes=num_classes)
        elif args.depth == 152:
            model = resnet152(pretrained=False, num_classes=num_classes)
        else:
            raise Exception('Unsupport ResNet detph: {}, optional depths: 18, 34, 50, 101 or 152'.format(args.depth))
    elif args.model == 'MPN-COV-ResNet':
        if args.depth == 18:
            model = mpn_cov_resnet18(pretrained=False, num_classes=num_classes)
        elif args.depth == 34:
            model = mpn_cov_resnet34(pretrained=False, num_classes=num_classes)
        elif args.depth == 50:
            model = mpn_cov_resnet50(pretrained=False, num_classes=num_classes)
        elif args.depth == 101:
            model = mpn_cov_resnet101(pretrained=False, num_classes=num_classes)
        elif args.depth == 152:
            model = mpn_cov_resnet152(pretrained=False, num_classes=num_classes)
        else:
            raise Exception('Unsupport MPN-COV-ResNet detph: {}, optional depths: 18, 34, 50, 101 or 152'.format(args.depth))
    else:
        # The format string needs a placeholder, otherwise the bad name is dropped.
        raise Exception('Unsupport model: {}'.format(args.model))

    # Get the number of model parameters
    print('Number of model parameters: {}'.format(sum([p.data.nelement() for p in model.parameters()])))

    if torch.cuda.is_available():
        model = model.cuda()

    # Optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("==> loading checkpoint '{}'".format(args.resume))
            # Without CUDA, remap tensors saved on a GPU onto the CPU; with CUDA
            # keep torch.load's default placement.
            map_location = None if torch.cuda.is_available() else 'cpu'
            checkpoint = torch.load(args.resume, map_location=map_location)
            args.start_epoch = checkpoint['epoch']
            best_err1 = checkpoint['best_err1']
            model.load_state_dict(checkpoint['state_dict'])
            print("==> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else:
            print("==> no checkpoint found at '{}'".format(args.resume))

    print(model)

    # Define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss()
    if torch.cuda.is_available():
        criterion = criterion.cuda()
    optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay,
                                nesterov=True)
    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch)

        # Train for one epoch
        train(train_loader, model, criterion, optimizer, epoch)

        # Evaluate on validation set
        err1 = validate(val_loader, model, criterion, epoch)

        # Remember best err1 and save checkpoint
        is_best = (err1 <= best_err1)
        best_err1 = min(err1, best_err1)
        print("Current best accuracy (error):", best_err1)
        save_checkpoint({
            'epoch': epoch+1,
            'state_dict': model.state_dict(),
            'best_err1': best_err1,
        }, is_best)

    print("Best accuracy (error):", best_err1)
def load_trained_model(model_name, train_set, device=torch.device('cpu')):
    """ Loads a pre-trained model from a state dict.

    Assumes that your models are saved in 'bayesian-calibration/models'
        and that your state dicts are saved in 'bayesian-calibration/models/checkpoints'

    CIFAR-family weights are read from
    ``models/checkpoints/{model_name}_{train_set}.tar`` (relative to the
    current working directory), using the ``'state_dict'`` entry. ImageNet
    models come from ``torchvision.models`` with ``pretrained=True``.

    Args:
        model_name: str ; for CIFAR one of 'resnet-110', 'alexnet', 'vgg19-bn',
            'wrn-28-10' (case- and whitespace-insensitive); for ImageNet the
            name of a ``torchvision.models`` constructor.
        train_set: str ; 'cifar10', 'cifar10imba', 'cifar100' or 'imagenet'
            (case- and whitespace-insensitive).
        device: torch.device or str; cpu by default. Checkpoint tensors are
            mapped onto this device while loading.
    Returns:
        A trained PyTorch model in eval mode, moved to ``device``.
    Raises:
        NotImplementedError: for an unsupported ``train_set`` or CIFAR
            ``model_name``. This is raised before any checkpoint file is read.
    """
    print('\nLoading pre-trained model')
    print('----| Model: {}  Train set: {}'.format(model_name, train_set))

    # Normalise once, before the prefix check, so ' CIFAR10 ' is recognised too.
    train_set = train_set.lower().strip()

    if train_set.startswith('cifar'):
        # Load local cifar-trained models
        num_classes = {'cifar100': 100,
                       'cifar10': 10,
                       'cifar10imba': 10}
        if train_set not in num_classes:
            raise NotImplementedError('Unsupported CIFAR train set: {}'.format(train_set))
        n_classes = num_classes[train_set]

        model_name = model_name.lower().strip()

        # Build the architecture first so an unknown name fails fast, without file I/O.
        if model_name == 'resnet-110':
            from models.resnet import resnet
            model = resnet(num_classes=n_classes, depth=110, block_name='BasicBlock')
        elif model_name == 'alexnet':
            from models.alexnet import alexnet
            model = alexnet(num_classes=n_classes)
        elif model_name == 'vgg19-bn':
            from models.vgg import vgg19_bn
            model = vgg19_bn(num_classes=n_classes)
        elif model_name == 'wrn-28-10':
            from models.wrn import wrn
            model = wrn(num_classes=n_classes,
                        depth=28,
                        widen_factor=10,
                        dropRate=0.3)
        else:
            raise NotImplementedError('Unsupported CIFAR model: {}'.format(model_name))

        # Load the saved state dict; map_location lets a GPU-saved checkpoint load on CPU.
        path_str = 'models/checkpoints/{}_{}.tar'.format(model_name, train_set)
        checkpoint_path = pathlib.Path(path_str).resolve()
        checkpoint = torch.load(checkpoint_path, map_location=device)
        # Every supported CIFAR checkpoint goes through _strip_parallel_model
        # (presumably removing nn.DataParallel key prefixes; see that helper).
        state_dict = _strip_parallel_model(checkpoint['state_dict'])
        model.load_state_dict(state_dict)
    elif train_set == 'imagenet':
        # Thin wrapper to load PyTorch pretrained imagenet models
        import torchvision.models as models
        model = getattr(models, model_name)(pretrained=True)
    else:
        raise NotImplementedError('Unsupported train set: {}'.format(train_set))

    model.eval()
    return model.to(device)
import torchvision.transforms as transforms

def save_image_tensor2pillow(input_tensor: torch.Tensor, filename):
    """Save a 3-D (C, H, W) image tensor to ``filename`` using PIL.

    Values are scaled by 255, rounded (+0.5 then truncation) and clamped to
    [0, 255] before the cast to uint8, so the input is expected to lie
    roughly in [0, 1]. A single-channel tensor is written as a 2-D
    (grayscale) image. Other tensors are converted to (H, W, C).
    The caller's tensor is never modified.

    Raises:
        ValueError: if ``input_tensor`` is not 3-dimensional.
    """
    if input_tensor.dim() != 3:
        raise ValueError('expected a 3-D (C, H, W) tensor, got shape {}'
                         .format(tuple(input_tensor.shape)))
    image = input_tensor.clone().detach().to(torch.device('cpu'))
    image = image.mul_(255).add_(0.5).clamp_(0, 255)
    # Do not squeeze(): a size-1 H or W (or C) would be dropped and the
    # permute below would then fail.
    if image.shape[0] == 1:
        image = image[0]                 # (1, H, W) -> (H, W) grayscale
    else:
        image = image.permute(1, 2, 0)   # (C, H, W) -> (H, W, C)
    im = Image.fromarray(image.type(torch.uint8).numpy())
    im.save(filename)

# Seed the global torch RNG (this also affects the shuffle=True test loader below).
torch.manual_seed(42)
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")

# Load the baseline VGG19-BN weights. map_location=device lets a checkpoint
# saved from a GPU run load on a CPU-only machine (the cpu fallback above).
model = vgg19_bn()
model.load_state_dict(
    torch.load("checkpoint/vgg_baseline.pth", map_location=device))
model.to(device)
model.eval()

cifar100_test_loader = get_test_dataloader(
    settings.CIFAR100_TRAIN_MEAN,
    settings.CIFAR100_TRAIN_STD,
    #settings.CIFAR100_PATH,
    num_workers=2,
    batch_size=16,
    shuffle=True
)

adversary = GradientSignAttack(
示例#18
0
def get_model(args, model_path=None):
    """Build a classification model and optionally restore trained weights.

    :param args: super arguments; reads ``scratch``, ``model`` and
        ``num_classes``, plus ``root_path``/``pretrained_models_path`` when
        not training from scratch.
    :param model_path: if not None, load already trained model parameters
        from the ``'model'`` entry of this checkpoint file.
    :return: model whose final linear layer outputs ``args.num_classes`` logits
    :raises ValueError: if ``args.model`` names an unsupported architecture
    """
    if args.scratch:  # train model from scratch
        pretrained = False
        model_dir = None
        print("=> Loading model '{}' from scratch...".format(args.model))
    else:  # train model with pretrained model
        pretrained = True
        model_dir = os.path.join(args.root_path, args.pretrained_models_path)
        print("=> Loading pretrained model '{}'...".format(args.model))

    unsupported = "Unsupported model '{}'".format(args.model)

    if args.model.startswith('resnet'):
        if args.model == 'resnet18':
            model = resnet18(pretrained=pretrained, model_dir=model_dir)
        elif args.model == 'resnet34':
            model = resnet34(pretrained=pretrained, model_dir=model_dir)
        elif args.model == 'resnet50':
            model = resnet50(pretrained=pretrained, model_dir=model_dir)
        elif args.model == 'resnet101':
            model = resnet101(pretrained=pretrained, model_dir=model_dir)
        elif args.model == 'resnet152':
            model = resnet152(pretrained=pretrained, model_dir=model_dir)
        else:
            raise ValueError(unsupported)

        # Replace the classification head to match the target class count.
        model.fc = nn.Linear(model.fc.in_features, args.num_classes)

    elif args.model.startswith('vgg'):
        if args.model == 'vgg11':
            model = vgg11(pretrained=pretrained, model_dir=model_dir)
        elif args.model == 'vgg11_bn':
            model = vgg11_bn(pretrained=pretrained, model_dir=model_dir)
        elif args.model == 'vgg13':
            model = vgg13(pretrained=pretrained, model_dir=model_dir)
        elif args.model == 'vgg13_bn':
            model = vgg13_bn(pretrained=pretrained, model_dir=model_dir)
        elif args.model == 'vgg16':
            model = vgg16(pretrained=pretrained, model_dir=model_dir)
        elif args.model == 'vgg16_bn':
            model = vgg16_bn(pretrained=pretrained, model_dir=model_dir)
        elif args.model == 'vgg19':
            model = vgg19(pretrained=pretrained, model_dir=model_dir)
        elif args.model == 'vgg19_bn':
            model = vgg19_bn(pretrained=pretrained, model_dir=model_dir)
        else:
            raise ValueError(unsupported)

        model.classifier[6] = nn.Linear(model.classifier[6].in_features, args.num_classes)

    elif args.model == 'alexnet':
        model = alexnet(pretrained=pretrained, model_dir=model_dir)
        model.classifier[6] = nn.Linear(model.classifier[6].in_features, args.num_classes)

    else:
        raise ValueError(unsupported)

    # Load already trained model parameters and go on training.
    # The model is still on the CPU here, so map the checkpoint there too;
    # this also lets GPU-saved checkpoints load on CPU-only machines.
    if model_path is not None:
        checkpoint = torch.load(model_path, map_location='cpu')
        model.load_state_dict(checkpoint['model'])

    return model
示例#19
0
def get_network(args):
    """ return given network

    Looks ``args.model`` up in a name -> (module, constructor) table and
    lazily imports the constructor from ``models.<module>``. The constructor
    is called with no arguments.

    If the name is not supported, a message is printed and the process
    exits with status 1.
    """
    import importlib

    # network name -> (module inside the ``models`` package, constructor name)
    networks = {
        'vgg16': ('vgg', 'vgg16_bn'),
        'vgg13': ('vgg', 'vgg13_bn'),
        'vgg11': ('vgg', 'vgg11_bn'),
        'vgg19': ('vgg', 'vgg19_bn'),
        'densenet121': ('densenet', 'densenet121'),
        'densenet161': ('densenet', 'densenet161'),
        'densenet169': ('densenet', 'densenet169'),
        'densenet201': ('densenet', 'densenet201'),
        'googlenet': ('googlenet', 'googlenet'),
        'inceptionv3': ('inceptionv3', 'inceptionv3'),
        'inceptionv4': ('inceptionv4', 'inceptionv4'),
        'inceptionresnetv2': ('inceptionv4', 'inception_resnet_v2'),
        'xception': ('xception', 'xception'),
        'resnet18': ('resnet', 'resnet18'),
        'resnet34': ('resnet', 'resnet34'),
        'resnet50': ('resnet', 'resnet50'),
        'resnet101': ('resnet', 'resnet101'),
        'resnet152': ('resnet', 'resnet152'),
        'preactresnet18': ('preactresnet', 'preactresnet18'),
        'preactresnet34': ('preactresnet', 'preactresnet34'),
        'preactresnet50': ('preactresnet', 'preactresnet50'),
        'preactresnet101': ('preactresnet', 'preactresnet101'),
        'preactresnet152': ('preactresnet', 'preactresnet152'),
        'resnext50': ('resnext', 'resnext50'),
        'resnext101': ('resnext', 'resnext101'),
        'resnext152': ('resnext', 'resnext152'),
        'shufflenet': ('shufflenet', 'shufflenet'),
        'shufflenetv2': ('shufflenetv2', 'shufflenetv2'),
        'squeezenet': ('squeezenet', 'squeezenet'),
        'mobilenet': ('mobilenet', 'mobilenet'),
        'mobilenetv2': ('mobilenetv2', 'mobilenetv2'),
        'nasnet': ('nasnet', 'nasnet'),
        'attention56': ('attention', 'attention56'),
        'attention92': ('attention', 'attention92'),
        'seresnet18': ('senet', 'seresnet18'),
        'seresnet34': ('senet', 'seresnet34'),
        'seresnet50': ('senet', 'seresnet50'),
        'seresnet101': ('senet', 'seresnet101'),
        'seresnet152': ('senet', 'seresnet152'),
        'wideresnet': ('wideresidual', 'wideresnet'),
        'stochasticdepth18': ('stochasticdepth', 'stochastic_depth_resnet18'),
        'stochasticdepth34': ('stochasticdepth', 'stochastic_depth_resnet34'),
        'stochasticdepth50': ('stochasticdepth', 'stochastic_depth_resnet50'),
        'stochasticdepth101': ('stochasticdepth', 'stochastic_depth_resnet101'),
    }

    entry = networks.get(args.model)
    if entry is None:
        print('the network name you have entered is not supported yet')
        # Non-zero status: a bare sys.exit() would report success to the shell.
        sys.exit(1)

    module_name, factory_name = entry
    module = importlib.import_module('models.' + module_name)
    model = getattr(module, factory_name)()

    return model
示例#20
0
def get_network(args, use_gpu=True, num_train=0):
    """ return given network

    The class count comes from ``args.dataset``: 10 for 'cifar-10', 100 for
    'cifar-100', otherwise 0. With ``args.ignoring`` only 'resnet18' is
    supported, built as ``resnet18_ign`` with a per-sample
    (``reduction='none'``) cross-entropy criterion. Otherwise ``args.net`` is
    looked up in a name -> (module, constructor) table and imported lazily
    from ``models.<module>``. Only the ``resnet*`` constructors receive
    ``num_classes``; the others are called with no arguments. The network is
    moved to the GPU when ``use_gpu`` is true.

    If the name is not supported, a message is printed and the process
    exits with status 1.
    """
    import importlib

    def _unsupported():
        # Shared exit path. Non-zero status, because bare sys.exit() signals success.
        print('the network name you have entered is not supported yet')
        sys.exit(1)

    if args.dataset == 'cifar-10':
        num_classes = 10
    elif args.dataset == 'cifar-100':
        num_classes = 100
    else:
        # NOTE(review): unknown datasets fall through with 0 classes; confirm intended.
        num_classes = 0

    if args.ignoring:
        if args.net == 'resnet18':
            from models.resnet_ign import resnet18_ign
            criterion = nn.CrossEntropyLoss(reduction='none')
            net = resnet18_ign(criterion, num_classes=num_classes, num_train=num_train,softmax=args.softmax,isalpha=args.isalpha)
        else:
            # Previously `net` was left unbound here, causing an UnboundLocalError below.
            _unsupported()

    else:
        # network name -> (module inside the ``models`` package, constructor name)
        networks = {
            'vgg16': ('vgg', 'vgg16_bn'),
            'vgg13': ('vgg', 'vgg13_bn'),
            'vgg11': ('vgg', 'vgg11_bn'),
            'vgg19': ('vgg', 'vgg19_bn'),
            'densenet121': ('densenet', 'densenet121'),
            'densenet161': ('densenet', 'densenet161'),
            'densenet169': ('densenet', 'densenet169'),
            'densenet201': ('densenet', 'densenet201'),
            'googlenet': ('googlenet', 'googlenet'),
            'inceptionv3': ('inceptionv3', 'inceptionv3'),
            'inceptionv4': ('inceptionv4', 'inceptionv4'),
            'inceptionresnetv2': ('inceptionv4', 'inception_resnet_v2'),
            'xception': ('xception', 'xception'),
            'resnet18': ('resnet', 'resnet18'),
            'resnet34': ('resnet', 'resnet34'),
            'resnet50': ('resnet', 'resnet50'),
            'resnet101': ('resnet', 'resnet101'),
            'resnet152': ('resnet', 'resnet152'),
            'preactresnet18': ('preactresnet', 'preactresnet18'),
            'preactresnet34': ('preactresnet', 'preactresnet34'),
            'preactresnet50': ('preactresnet', 'preactresnet50'),
            'preactresnet101': ('preactresnet', 'preactresnet101'),
            'preactresnet152': ('preactresnet', 'preactresnet152'),
            'resnext50': ('resnext', 'resnext50'),
            'resnext101': ('resnext', 'resnext101'),
            'resnext152': ('resnext', 'resnext152'),
            'shufflenet': ('shufflenet', 'shufflenet'),
            'shufflenetv2': ('shufflenetv2', 'shufflenetv2'),
            'squeezenet': ('squeezenet', 'squeezenet'),
            'mobilenet': ('mobilenet', 'mobilenet'),
            'mobilenetv2': ('mobilenetv2', 'mobilenetv2'),
            'nasnet': ('nasnet', 'nasnet'),
            'attention56': ('attention', 'attention56'),
            'attention92': ('attention', 'attention92'),
            'seresnet18': ('senet', 'seresnet18'),
            'seresnet34': ('senet', 'seresnet34'),
            'seresnet50': ('senet', 'seresnet50'),
            'seresnet101': ('senet', 'seresnet101'),
            'seresnet152': ('senet', 'seresnet152'),
        }

        entry = networks.get(args.net)
        if entry is None:
            _unsupported()
        module_name, factory_name = entry
        factory = getattr(importlib.import_module('models.' + module_name), factory_name)
        # Only the models.resnet constructors are given the dataset's class count.
        if module_name == 'resnet':
            net = factory(num_classes=num_classes)
        else:
            net = factory()

    if use_gpu:
        net = net.cuda()

    return net
示例#21
0
文件: utils.py 项目: nblt/DLDR
def get_model(args):
    """Build the network named ``args.arch`` for the dataset ``args.datasets``.

    'ImageNet' uses ``models_imagenet.__dict__[args.arch]()``. For 'CIFAR100'
    the architecture is first looked up in a name -> (module, constructor)
    table and imported lazily from ``models.<module>``. Those constructors
    are called without arguments, except 'efficientnet', which gets
    ``(1, 1, num_class, bn_momentum=0.9)``. Every other case falls back to
    ``resnet.__dict__[args.arch](num_classes=num_class)``.

    Raises:
        ValueError: if ``args.datasets`` is not ImageNet, CIFAR10, MNIST or CIFAR100.
    """
    import importlib

    if args.datasets == 'ImageNet':
        return models_imagenet.__dict__[args.arch]()

    if args.datasets in ('CIFAR10', 'MNIST'):
        num_class = 10
    elif args.datasets == 'CIFAR100':
        num_class = 100
    else:
        # Previously num_class was left unbound, which raised UnboundLocalError later.
        raise ValueError('Unsupported dataset: {}'.format(args.datasets))

    if args.datasets == 'CIFAR100':
        # arch name -> (module inside the ``models`` package, constructor name)
        cifar100_networks = {
            'vgg16': ('vgg', 'vgg16_bn'),
            'vgg13': ('vgg', 'vgg13_bn'),
            'vgg11': ('vgg', 'vgg11_bn'),
            'vgg19': ('vgg', 'vgg19_bn'),
            'densenet121': ('densenet', 'densenet121'),
            'densenet161': ('densenet', 'densenet161'),
            'densenet169': ('densenet', 'densenet169'),
            'densenet201': ('densenet', 'densenet201'),
            'googlenet': ('googlenet', 'googlenet'),
            'inceptionv3': ('inceptionv3', 'inceptionv3'),
            'inceptionv4': ('inceptionv4', 'inceptionv4'),
            'inceptionresnetv2': ('inceptionv4', 'inception_resnet_v2'),
            'xception': ('xception', 'xception'),
            'resnet18': ('resnet', 'resnet18'),
            'resnet34': ('resnet', 'resnet34'),
            'resnet50': ('resnet', 'resnet50'),
            'resnet101': ('resnet', 'resnet101'),
            'resnet152': ('resnet', 'resnet152'),
            'preactresnet18': ('preactresnet', 'preactresnet18'),
            'preactresnet34': ('preactresnet', 'preactresnet34'),
            'preactresnet50': ('preactresnet', 'preactresnet50'),
            'preactresnet101': ('preactresnet', 'preactresnet101'),
            'preactresnet152': ('preactresnet', 'preactresnet152'),
            'resnext50': ('resnext', 'resnext50'),
            'resnext101': ('resnext', 'resnext101'),
            'resnext152': ('resnext', 'resnext152'),
            'shufflenet': ('shufflenet', 'shufflenet'),
            'shufflenetv2': ('shufflenetv2', 'shufflenetv2'),
            'squeezenet': ('squeezenet', 'squeezenet'),
            'mobilenet': ('mobilenet', 'mobilenet'),
            'mobilenetv2': ('mobilenetv2', 'mobilenetv2'),
            'nasnet': ('nasnet', 'nasnet'),
            'attention56': ('attention', 'attention56'),
            'attention92': ('attention', 'attention92'),
            'seresnet18': ('senet', 'seresnet18'),
            'seresnet34': ('senet', 'seresnet34'),
            'seresnet50': ('senet', 'seresnet50'),
            'seresnet101': ('senet', 'seresnet101'),
            'seresnet152': ('senet', 'seresnet152'),
            'wideresnet': ('wideresidual', 'wideresnet'),
            'stochasticdepth18': ('stochasticdepth', 'stochastic_depth_resnet18'),
            'efficientnet': ('efficientnet', 'efficientnet'),
            'stochasticdepth34': ('stochasticdepth', 'stochastic_depth_resnet34'),
            'stochasticdepth50': ('stochasticdepth', 'stochastic_depth_resnet50'),
            'stochasticdepth101': ('stochasticdepth', 'stochastic_depth_resnet101'),
        }
        entry = cifar100_networks.get(args.arch)
        if entry is not None:
            module_name, factory_name = entry
            factory = getattr(importlib.import_module('models.' + module_name), factory_name)
            if args.arch == 'efficientnet':
                # num_class is 100 on this branch, matching the original literal.
                return factory(1, 1, num_class, bn_momentum=0.9)
            return factory()

    # Fallback for non-table architectures and for CIFAR10/MNIST.
    return resnet.__dict__[args.arch](num_classes=num_class)
示例#22
0
            cnt += 1
            if cnt == 20:
                logger.info("early stop")
                break


for ds in dataset:
    data_path = os.path.join(args.root, ds)
    cls = [
        x for x in os.listdir(data_path)
        if os.path.isdir(os.path.join(data_path, x))
    ]
    num_class = len(cls)
    models = {
        "vgg16": vgg.vgg16_bn(num_class),
        "vgg19": vgg.vgg19_bn(num_class),
        "densenet121": densenet.densenet121(num_class),
        "densenet161": densenet.densenet161(num_class),
        "resnet34": resnet.resnet34(num_class),
        "resnet50": resnet.resnet50(num_class),
        "resnet101": resnet.resnet101(num_class),
        "seresnet34": senet.seresnet34(num_class),
        "seresnet50": senet.seresnet50(num_class),
        "seresnet101": senet.seresnet101(num_class),
        "resnext34": resnext.resnext34(num_class),
        "resnext50": resnext.resnext50(num_class),
        "resnext101": resnext.resnext101(num_class),
        "shufflenet": shufflenet.shufflenet(num_class),
        "xception": xception.xception(num_class)
    }
    for net_name in models.keys():
示例#23
0
# Registry for the CIFAR ``get_network`` below.  It maps an ``args.net`` name
# to (module path, factory name, whether the factory is called with
# ``num_classes=``).  Modules are imported lazily, so only the chosen
# architecture's module is loaded.
_CIFAR_NETWORK_SPECS = {
    # Yang added none bn vggs
    'vgg16': ('models.vgg', 'vgg16', True),
    'vgg13': ('models.vgg', 'vgg13', True),
    'vgg11': ('models.vgg', 'vgg11', True),
    'vgg19': ('models.vgg', 'vgg19', True),

    'vgg16bn': ('models.vgg', 'vgg16_bn', True),
    'vgg13bn': ('models.vgg', 'vgg13_bn', True),
    'vgg11bn': ('models.vgg', 'vgg11_bn', True),
    'vgg19bn': ('models.vgg', 'vgg19_bn', True),

    'densenet121': ('models.densenet', 'densenet121', False),
    'densenet161': ('models.densenet', 'densenet161', False),
    'densenet169': ('models.densenet', 'densenet169', False),
    'densenet201': ('models.densenet', 'densenet201', False),
    'googlenet': ('models.googlenet', 'googlenet', True),
    'inceptionv3': ('models.inceptionv3', 'inceptionv3', False),
    'inceptionv4': ('models.inceptionv4', 'inceptionv4', False),
    'inceptionresnetv2': ('models.inceptionv4', 'inception_resnet_v2', False),
    'xception': ('models.xception', 'xception', True),
    'scnet': ('models.sphereconvnet', 'sphereconvnet', True),
    'sphereresnet18': ('models.sphereconvnet', 'resnet18', True),
    'sphereresnet32': ('models.sphereconvnet', 'sphereresnet32', True),
    'plainresnet32': ('models.sphereconvnet', 'plainresnet32', True),
    'ynet18': ('models.ynet', 'resnet18', True),
    'ynet34': ('models.ynet', 'resnet34', True),
    'ynet50': ('models.ynet', 'resnet50', True),
    'ynet101': ('models.ynet', 'resnet101', True),
    'ynet152': ('models.ynet', 'resnet152', True),

    'resnet18': ('models.resnet', 'resnet18', True),
    'resnet34': ('models.resnet', 'resnet34', True),
    'resnet50': ('models.resnet', 'resnet50', True),
    'resnet101': ('models.resnet', 'resnet101', True),
    'resnet152': ('models.resnet', 'resnet152', True),
    'preactresnet18': ('models.preactresnet', 'preactresnet18', True),
    'preactresnet34': ('models.preactresnet', 'preactresnet34', True),
    'preactresnet50': ('models.preactresnet', 'preactresnet50', True),
    'preactresnet101': ('models.preactresnet', 'preactresnet101', True),
    'preactresnet152': ('models.preactresnet', 'preactresnet152', True),
    'resnext50': ('models.resnext', 'resnext50', True),
    'resnext101': ('models.resnext', 'resnext101', True),
    'resnext152': ('models.resnext', 'resnext152', True),
    'shufflenet': ('models.shufflenet', 'shufflenet', False),
    'shufflenetv2': ('models.shufflenetv2', 'shufflenetv2', False),
    'squeezenet': ('models.squeezenet', 'squeezenet', False),
    'mobilenet': ('models.mobilenet', 'mobilenet', True),
    'mobilenetv2': ('models.mobilenetv2', 'mobilenetv2', True),
    'nasnet': ('models.nasnet', 'nasnet', True),
    'attention56': ('models.attention', 'attention56', False),
    'attention92': ('models.attention', 'attention92', False),
    'seresnet18': ('models.senet', 'seresnet18', True),
    'seresnet34': ('models.senet', 'seresnet34', True),
    'seresnet50': ('models.senet', 'seresnet50', True),
    'seresnet101': ('models.senet', 'seresnet101', True),
    'seresnet152': ('models.senet', 'seresnet152', True),
}


def get_network(args):
    """Return the network named by ``args.net`` for a CIFAR task.

    Args:
        args: namespace providing ``task`` ('cifar10' -> 10 classes,
            'cifar100' -> 100 classes), ``net`` (a key of
            ``_CIFAR_NETWORK_SPECS``) and ``gpu`` (truthy to move the
            network to CUDA).

    Returns:
        The constructed network, moved with ``.cuda()`` when ``args.gpu``
        is truthy.

    Raises:
        ValueError: if ``args.task`` is not a known task and the chosen
            network needs a class count.
        SystemExit: with status 1 if ``args.net`` is not supported.
    """
    from importlib import import_module

    # None for unknown tasks.  This only matters for factories that take
    # ``num_classes``; the others ignore the task entirely.
    nclass = {'cifar10': 10, 'cifar100': 100}.get(args.task)

    spec = _CIFAR_NETWORK_SPECS.get(args.net)
    if spec is None:
        print('the network name you have entered is not supported yet')
        # Non-zero status so callers and shells can detect the failure.
        sys.exit(1)
    module_name, factory_name, takes_num_classes = spec

    if takes_num_classes and nclass is None:
        raise ValueError(
            'unknown task {!r}: expected cifar10 or cifar100 to determine '
            'num_classes for {!r}'.format(args.task, args.net))

    factory = getattr(import_module(module_name), factory_name)
    net = factory(num_classes=nclass) if takes_num_classes else factory()

    if args.gpu:  # use_gpu
        net = net.cuda()

    return net
示例#24
0
def get_network(args):
    """Return the network named by ``args.net``, built with default arguments.

    Args:
        args: namespace providing ``net`` (one of the names in the table
            below) and ``gpu`` (truthy to move the network to CUDA).

    Returns:
        The constructed network, moved with ``.cuda()`` when ``args.gpu``
        is truthy.

    Raises:
        SystemExit: with status 1 if ``args.net`` is not supported.
    """
    from importlib import import_module

    # args.net -> (module path, factory called with no arguments).  Imports
    # are lazy, so only the selected architecture's module is loaded.
    specs = {
        'vgg11': ('models.vgg', 'vgg11_bn'),
        'vgg13': ('models.vgg', 'vgg13_bn'),
        'vgg16': ('models.vgg', 'vgg16_bn'),
        'vgg19': ('models.vgg', 'vgg19_bn'),
        'googlenet': ('models.googLeNet', 'GoogLeNet'),
        'inceptionv3': ('models.inceptionv3', 'Inceptionv3'),
        'resnet18': ('models.resnet', 'resnet18'),
        'resnet34': ('models.resnet', 'resnet34'),
        'resnet50': ('models.resnet', 'resnet50'),
        'resnet101': ('models.resnet', 'resnet101'),
        'resnet152': ('models.resnet', 'resnet152'),
        'wrn': ('models.wideresnet', 'wideresnet'),
        'densenet121': ('models.densenet', 'densenet121'),
        'densenet161': ('models.densenet', 'densenet161'),
        'densenet169': ('models.densenet', 'densenet169'),
        'densenet201': ('models.densenet', 'densenet201'),
    }

    spec = specs.get(args.net)
    if spec is None:
        print('the network name you have entered is not supported yet')
        # Non-zero status so callers and shells can detect the failure.
        sys.exit(1)

    module_name, factory_name = spec
    net = getattr(import_module(module_name), factory_name)()

    if args.gpu:
        print("use gpu")
        net = net.cuda()

    return net
示例#25
0
def get_network(key, num_cls=2, use_gpu=False):
    """Return the network registered under ``key``.

    Args:
        key: architecture name.  Supported values are 'vgg11', 'vgg13',
            'vgg16', 'vgg19' (batch-norm variants), 'efficientNetb0' ..
            'efficientNetb7', 'resnext50_32x8d', 'resnext101_32x8d',
            'resnet18', 'resnet34', 'resnet50' and 'resnet101'.
        num_cls: number of output classes.  It is passed to the VGG and
            EfficientNet factories only.
        use_gpu: move the network to CUDA when true.

    Returns:
        The constructed network.

    Raises:
        NotImplementedError: for the placeholder key 'resnext'.
        SystemExit: with status 1 for any other unsupported key.
    """
    from importlib import import_module

    vgg_factories = {
        'vgg11': 'vgg11_bn',
        'vgg13': 'vgg13_bn',
        'vgg16': 'vgg16_bn',
        'vgg19': 'vgg19_bn',
    }
    efficientnet_keys = {'efficientNetb{}'.format(i) for i in range(8)}
    resnext_keys = {'resnext50_32x8d', 'resnext101_32x8d'}
    resnet_keys = {'resnet18', 'resnet34', 'resnet50', 'resnet101'}

    if key in vgg_factories:
        factory = getattr(import_module('models.vgg'), vgg_factories[key])
        net = factory(num_cls)
    elif key == 'resnext':
        # Placeholder with no model behind it.  Raise a clear error instead
        # of falling through with ``net`` unbound.
        print('we will continue')
        raise NotImplementedError(
            "'resnext' is not implemented yet; use 'resnext50_32x8d' or "
            "'resnext101_32x8d'")
    elif key in efficientnet_keys:
        from models.torchefficient import make_model
        net = make_model(key, num_cls)
    elif key in resnext_keys:
        # NOTE(review): num_cls is not forwarded to this make_model.  Confirm
        # whether it should be.
        from models.resnext import make_model
        net = make_model(key)
    elif key in resnet_keys:
        # NOTE(review): num_cls is not forwarded to this make_model.  Confirm
        # whether it should be.
        from models.resnet import make_model
        net = make_model(key)
    else:
        print('the network name you have entered is not supported yet')
        # Non-zero status so callers and shells can detect the failure.
        sys.exit(1)
    if use_gpu:
        net = net.cuda()
    return net
示例#26
0
def train_model(modname='alexnet', pm_ch='both', bs=16, max_epochs=10,
                learning_rate=0.001):
    """Train a 3-class pixel-map classifier and save its state dict.

    Args:
        modname (string): Name of the model. Has to be one of the values:
            'alexnet', batch 64
            'densenet'
            'inception'
            'resnet', batch 16
            'squeezenet', batch 16
            'vgg'
        pm_ch (string): pixelmap channel -- 'time', 'charge', 'both', default to both
        bs (int): batch size.  The loss is also printed every ``bs`` steps.
        max_epochs (int): number of passes over the training data.
        learning_rate (float): Adam learning rate.

    Returns:
        None.  For an unknown ``modname`` a message is printed and the
        function returns before any data is loaded.  Otherwise the trained
        weights are written to
        ``../../../data/two_views/saved_models/<modname>/<pm_ch>/model.ckpt``.
    """
    # device configuration
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    num_classes = 3
    # 'both' feeds two pixel-map channels; any other value feeds one.
    nch = 2 if pm_ch == 'both' else 1

    # Constructors are deferred so the model name can be validated before
    # the dataset is read and before any model is allocated.
    builders = {
        'alexnet': lambda: alexnet(num_classes=num_classes, in_ch=nch),
        'densenet': lambda: DenseNet(num_classes=num_classes, in_ch=nch),
        'inception': lambda: inception_v3(num_classes=num_classes, in_ch=nch),
        'resnet': lambda: resnet18(num_classes=num_classes, in_ch=nch),
        'squeezenet': lambda: squeezenet1_1(num_classes=num_classes, in_ch=nch),
        'vgg': lambda: vgg19_bn(in_ch=nch, num_classes=num_classes),
    }
    if modname not in builders:
        print('Model {} not defined.'.format(modname))
        return

    ds = PixelMapDataset('training_file_list.txt', pm_ch)
    dl = torch.utils.data.DataLoader(dataset=ds, batch_size=bs, shuffle=True)

    model = builders[modname]().to(device)

    # loss and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    # Zero padding (left, right, top, bottom) applied to every batch.  It is
    # built once here instead of on every step.
    # NOTE(review): the inception padding is larger, presumably to reach that
    # network's minimum input size.  Confirm against the pixel-map shape.
    if modname == 'inception':
        pad = nn.ZeroPad2d((0, 192, 102, 101))
    else:
        pad = nn.ZeroPad2d((0, 117, 64, 64))

    # training process
    total_step = len(dl)
    for epoch in range(max_epochs):
        # The second view is not used by this single-view trainer.
        for i, (view1, _view2, local_labels) in enumerate(dl):
            view1 = pad(view1.float().to(device))
            local_labels = local_labels.to(device)

            # forward pass
            outputs = model(view1)
            loss = criterion(outputs, local_labels)

            # backward and optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if (i + 1) % bs == 0:
                print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'.format(
                    epoch + 1, max_epochs, i + 1, total_step, loss.item()))

    # save the model checkpoint
    save_path = '../../../data/two_views/saved_models/{}/{}'.format(
        modname, pm_ch)
    os.makedirs(save_path, exist_ok=True)
    torch.save(model.state_dict(), os.path.join(save_path, 'model.ckpt'))
def load_model(model_name, training_type, configs):
    """Build the requested model for ``configs.dataset`` and move it to CUDA.

    Side effect: sets ``configs.num_classes`` and ``configs.input_channels``
    from the dataset name, compared case-insensitively:

    * mnist / fashionmnist: 10 classes, 1 channel
    * cifar-100: 100 classes, 3 channels
    * imagenet: 1000 classes, 3 channels
    * anything else: 10 classes, 3 channels

    Args:
        model_name: "Resnet18", "Resnet50", "Resnet101", "VGG19", or any name
            containing "efficientnet".
        training_type: "pretrained" or "untrained".
        configs: object with a string ``dataset`` attribute.  It is mutated
            as described above.

    Returns:
        The model after ``.cuda()``.  It is wrapped in ``nn.DataParallel``
        when more than one GPU is visible.

    Raises:
        ValueError: if ``model_name`` or ``training_type`` is not supported.
    """
    dataset = configs.dataset.lower()
    # set the input channels and num_classes
    if dataset == "mnist" or dataset == "fashionmnist":
        configs.num_classes = 10
        configs.input_channels = 1
    elif dataset == "cifar-100":
        configs.num_classes = 100
        configs.input_channels = 3
    elif dataset == "imagenet":
        configs.num_classes = 1000
        configs.input_channels = 3
    else:
        configs.num_classes = 10
        configs.input_channels = 3

    model = None
    # pick model
    if model_name == "Resnet18":
        if training_type == "pretrained":
            print("Loading pretrained Resnet18")
            model = torchvision.models.resnet18(pretrained=True)
            # Replace the classification head module itself so the network
            # emits configs.num_classes logits.
            model.fc = nn.Linear(model.fc.in_features, configs.num_classes)

        elif training_type == "untrained":
            print("Loading untrained Resnet18")
            model = ResNet18(num_classes=configs.num_classes,
                             input_channels=configs.input_channels)
    elif model_name == "Resnet50":
        if training_type == "pretrained":
            print(f"Loading pretrained {model_name}")
            model = torchvision.models.resnet50(pretrained=True)
            model.fc = nn.Linear(model.fc.in_features, configs.num_classes)

        elif training_type == "untrained":
            print(f"Loading untrained {model_name}")
            model = ResNet50(num_classes=configs.num_classes,
                             input_channels=configs.input_channels)

    elif model_name == "Resnet101":
        # NOTE(review): torchvision's ResNet expects 3 input channels, so
        # configs.input_channels is not honoured here.
        if training_type == 'pretrained':
            print(f"Loading pretrained {model_name}")

            model = torchvision.models.resnet101(pretrained=True)
            model.fc = nn.Linear(model.fc.in_features, configs.num_classes)
        elif training_type == "untrained":
            print(f"Loading untrained {model_name}")

            model = torchvision.models.resnet101(
                num_classes=configs.num_classes)

    elif model_name == "VGG19":
        if training_type == "pretrained":

            print(f"Loading pretrained {model_name}")

            model = torchvision.models.vgg19(pretrained=True)
            # Custom head on the 512*7*7 flattened feature map.  The output
            # width follows the dataset's class count.
            model.classifier = nn.Sequential(
                nn.Linear(512 * 7 * 7, 512),
                nn.ReLU(True),
                nn.Dropout(),
                nn.Linear(512, 512),
                nn.ReLU(True),
                nn.Dropout(),
                nn.Linear(512, 512),
                nn.ReLU(True),
                nn.Linear(512, configs.num_classes),
            )
        elif training_type == "untrained":

            print(f"Loading untrained {model_name}")

            model = vgg19_bn(in_channels=configs.input_channels,
                             num_classes=configs.num_classes)

    elif "efficientnet" in model_name:

        if training_type == 'pretrained':
            print(f"Loading pretrained {model_name}")

            model = load_efficientnet(model_name, configs.num_classes,
                                      configs.input_channels, True)

        elif training_type == "untrained":
            print(f"Loading untrained {model_name}")

            model = load_efficientnet(model_name, configs.num_classes,
                                      configs.input_channels, False)
    else:
        print("Please provide a model")

    # Fail clearly here instead of with an UnboundLocalError further down.
    if model is None:
        raise ValueError(
            f"Unsupported model_name={model_name!r} with "
            f"training_type={training_type!r}")

    # push model to cuda
    if torch.cuda.device_count() > 1:
        print(f"Number of GPUs available are {torch.cuda.device_count()}")
        model = nn.DataParallel(model)
        print("\nModel moved to Data Parallel")
    model.cuda()

    return model