Exemplo n.º 1
0
    def loading_model(self):
        """Build the network, loss and optimizer, then restore the best checkpoint.

        Reads ``self.model_type``, ``self.pretrained``, ``self.num_classes``,
        ``self.lr``, ``self.momentum``, ``self.weight_decay`` and
        ``self.data_folder``. Sets ``self.model``, ``self.criterion``,
        ``self.optimizer``, ``self.start_epoch`` and ``self.best_prec1``, and
        enables ``cudnn.benchmark``.

        Raises:
            SystemExit: with status 1 when ``self.model_type`` is not
                ``'inceptionv4'`` or when ``model_best.pth.tar`` does not
                exist in ``self.data_folder``.
        """
        print('Loading %s model' % (self.model_type))

        if self.model_type != 'inceptionv4':
            print('no model')
            # Raise SystemExit directly: the ``exit()`` builtin is injected by
            # the ``site`` module for interactive use and it reported success
            # (status 0) for what is a failure.
            raise SystemExit(1)

        pretrained = 'imagenet' if self.pretrained else None
        self.model = inceptionv4(num_classes=1000, pretrained=pretrained)

        # Swap the 1000-way classifier for one sized to this task.
        num_ftrs = self.model.last_linear.in_features
        self.model.last_linear = nn.Linear(num_ftrs, self.num_classes)

        if self.pretrained:
            # Freeze every parameter first ...
            for param in self.model.parameters():
                param.requires_grad = False
            # ... then unfreeze only the feature blocks that come after the
            # child named "4" (the child "4" itself stays frozen).
            # NOTE(review): the freshly created ``last_linear`` is frozen by
            # the loop above and never unfrozen here. Left as-is because
            # changing the trainable set changes the optimizer's parameter
            # list and would break loading of existing checkpoints -- confirm
            # whether the head is meant to be trainable.
            past_block_4 = False
            for name, child in self.model.features.named_children():
                if past_block_4:
                    for param in child.parameters():
                        param.requires_grad = True
                if name == "4":
                    past_block_4 = True

        # Check gpu and run parallel
        gpu_available = check_gpu() > 0
        if gpu_available:
            self.model = torch.nn.DataParallel(self.model).cuda()

        # define loss function (criterion) and optimizer
        self.criterion = nn.CrossEntropyLoss()
        if gpu_available:
            self.criterion = self.criterion.cuda()

        # Only parameters that still require gradients are optimized.
        trainable = [p for p in self.model.parameters() if p.requires_grad]
        self.optimizer = optim.SGD(params=trainable,
                                   lr=self.lr,
                                   momentum=self.momentum,
                                   weight_decay=self.weight_decay)

        checkpoint_path = os.path.join(self.data_folder, 'model_best.pth.tar')
        if not os.path.isfile(checkpoint_path):
            print("=> no model best found at '{}'".format(checkpoint_path))
            raise SystemExit(1)

        print("=> loading checkpoint '{}'".format('model_best.pth.tar'))
        # NOTE: torch.load unpickles the file; only load checkpoints from a
        # trusted source. Without a GPU the tensors are remapped to the CPU so
        # a checkpoint saved on a GPU machine can still be read.
        checkpoint = torch.load(
            checkpoint_path,
            map_location=None if gpu_available else 'cpu')
        self.start_epoch = checkpoint['epoch']
        self.best_prec1 = checkpoint['best_prec1']
        self.model.load_state_dict(checkpoint['state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer'])
        print("=> loaded best model")

        cudnn.benchmark = True
def get_network(args, use_gpu=True):
    """Build and return the network selected by ``args.net``.

    The module that defines the requested architecture is imported lazily, so
    only that one model module is loaded.

    Args:
        args: object exposing a ``net`` attribute with the network name.
        use_gpu: when true, the network is moved to the GPU via ``.cuda()``.

    Returns:
        The network created by calling the matching factory with no arguments.

    Raises:
        SystemExit: with status 1 when ``args.net`` is not a supported name.
    """
    import importlib

    # network name -> (module to import, factory function in that module)
    factories = {
        'vgg16': ('models.vgg', 'vgg16_bn'),
        'vgg13': ('models.vgg', 'vgg13_bn'),
        'vgg11': ('models.vgg', 'vgg11_bn'),
        'vgg19': ('models.vgg', 'vgg19_bn'),
        'densenet121': ('models.densenet', 'densenet121'),
        'densenet161': ('models.densenet', 'densenet161'),
        'densenet169': ('models.densenet', 'densenet169'),
        'densenet201': ('models.densenet', 'densenet201'),
        'googlenet': ('models.googlenet', 'googlenet'),
        'inceptionv3': ('models.inceptionv3', 'inceptionv3'),
        'inceptionv4': ('models.inceptionv4', 'inceptionv4'),
        'inceptionresnetv2': ('models.inceptionv4', 'inception_resnet_v2'),
        'xception': ('models.xception', 'xception'),
        'resnet18': ('models.resnet', 'resnet18'),
        'resnet34': ('models.resnet', 'resnet34'),
        'resnet50': ('models.resnet', 'resnet50'),
        'resnet101': ('models.resnet', 'resnet101'),
        'resnet152': ('models.resnet', 'resnet152'),
        'preactresnet18': ('models.preactresnet', 'preactresnet18'),
        'preactresnet34': ('models.preactresnet', 'preactresnet34'),
        'preactresnet50': ('models.preactresnet', 'preactresnet50'),
        'preactresnet101': ('models.preactresnet', 'preactresnet101'),
        'preactresnet152': ('models.preactresnet', 'preactresnet152'),
        'resnext50': ('models.resnext', 'resnext50'),
        'resnext101': ('models.resnext', 'resnext101'),
        'resnext152': ('models.resnext', 'resnext152'),
        'shufflenet': ('models.shufflenet', 'shufflenet'),
        'shufflenetv2': ('models.shufflenetv2', 'shufflenetv2'),
        'squeezenet': ('models.squeezenet', 'squeezenet'),
        'mobilenet': ('models.mobilenet', 'mobilenet'),
        'mobilenetv2': ('models.mobilenetv2', 'mobilenetv2'),
        'nasnet': ('models.nasnet', 'nasnet'),
        'attention56': ('models.attention', 'attention56'),
        'attention92': ('models.attention', 'attention92'),
        'seresnet18': ('models.senet', 'seresnet18'),
        'seresnet34': ('models.senet', 'seresnet34'),
        'seresnet50': ('models.senet', 'seresnet50'),
        'seresnet101': ('models.senet', 'seresnet101'),
        'seresnet152': ('models.senet', 'seresnet152'),
    }

    entry = factories.get(args.net)
    if entry is None:
        print('the network name you have entered is not supported yet')
        # Non-zero status so shells and job schedulers see the failure.
        sys.exit(1)

    module_name, factory_name = entry
    net = getattr(importlib.import_module(module_name), factory_name)()

    if use_gpu:
        net = net.cuda()

    return net
Exemplo n.º 3
0
def get_network(args, use_gpu=True):
    """Build and return the network selected by ``args.net``.

    Most architectures come from the project's ``models`` package and are
    created as ``factory(args)``; their module is imported lazily. ``vgg19``,
    ``resnet18`` and ``alexnet`` come from torchvision with
    ``pretrained=True`` and get their final linear layer replaced by one with
    ``args.nc`` outputs. The ``lambda*`` networks are built from ``args.nc``
    and ``args.cs``.

    Args:
        args: object exposing ``net`` (network name) plus whatever the
            selected factory reads (``nc``, ``cs``, ...).
        use_gpu: when true, the network is moved to the GPU via ``.cuda()``.

    Returns:
        The constructed network.

    Raises:
        SystemExit: with status 1 when ``args.net`` is not a supported name.
    """
    import importlib

    def load(module_name, attr_name):
        # Import the model module on demand and fetch its factory.
        return getattr(importlib.import_module(module_name), attr_name)

    # Exact-match names whose factory is called as ``factory(args)``.
    args_factories = {
        'mobilenet': ('models.mobilenet', 'mobilenet'),
        'mobilenetv2': ('models.mobilenetv2', 'mobilenetv2'),
        'vgg13': ('models.vgg', 'vgg13_bn'),
        'vgg11': ('models.vgg', 'vgg11_bn'),
        'densenet121': ('models.densenet', 'densenet121'),
        'densenet161': ('models.densenet', 'densenet161'),
        'densenet169': ('models.densenet', 'densenet169'),
        'densenet201': ('models.densenet', 'densenet201'),
        'googlenet': ('models.googlenet', 'googlenet'),
        'inceptionv3': ('models.inceptionv3', 'inceptionv3'),
        'inceptionv4': ('models.inceptionv4', 'inceptionv4'),
        'inceptionresnetv2': ('models.inceptionv4', 'inception_resnet_v2'),
        'xception': ('models.xception', 'xception'),
        'resnet34': ('models.resnet', 'resnet34'),
        'resnet50': ('models.resnet', 'resnet50'),
        'resnet101': ('models.resnet', 'resnet101'),
        'resnet152': ('models.resnet', 'resnet152'),
        'preactresnet18': ('models.preactresnet', 'preactresnet18'),
        'preactresnet34': ('models.preactresnet', 'preactresnet34'),
        'preactresnet50': ('models.preactresnet', 'preactresnet50'),
        'preactresnet101': ('models.preactresnet', 'preactresnet101'),
        'preactresnet152': ('models.preactresnet', 'preactresnet152'),
        'resnext50': ('models.resnext', 'resnext50'),
        'resnext101': ('models.resnext', 'resnext101'),
        'resnext152': ('models.resnext', 'resnext152'),
        'shufflenet': ('models.shufflenet', 'shufflenet'),
        'shufflenetv2': ('models.shufflenetv2', 'shufflenetv2'),
        'squeezenet': ('models.squeezenet', 'squeezenet'),
        'nasnet': ('models.nasnet', 'nasnetalarge'),
        'attention56': ('models.attention', 'attention56'),
        'attention92': ('models.attention', 'attention92'),
        'seresnet18': ('models.senet', 'seresnet18'),
        'seresnet34': ('models.senet', 'seresnet34'),
        'seresnet50': ('models.senet', 'seresnet50'),
        'seresnet101': ('models.senet', 'seresnet101'),
        'seresnet152': ('models.senet', 'seresnet152'),
        'mlp': ('models.mlp', 'MLPClassifier'),
    }
    for variant in range(8):
        key = 'efficientnet_b%d' % variant
        args_factories[key] = ('models.efficientnet', key)

    # Names compared against ``args.net.lower()``; also ``factory(args)``.
    caseless_factories = {
        'sqnxt_23_1x': ('models.SqueezeNext', 'SqNxt_23_1x'),
        'sqnxt_23_1xv5': ('models.SqueezeNext', 'SqNxt_23_1x_v5'),
        'sqnxt_23_2x': ('models.SqueezeNext', 'SqNxt_23_2x'),
        'sqnxt_23_2xv5': ('models.SqueezeNext', 'SqNxt_23_2x_v5'),
        'mnasnet': ('models.nasnet_mobile', 'nasnet_Mobile'),
    }

    # MobileNetV3 variants: name -> ``mode`` keyword (None = factory default).
    mobilenetv3_modes = {
        'mobilenetv3': None,
        'mobilenetv3_l': 'large',
        'mobilenetv3_s': 'small',
    }

    lambda_factories = {
        'lambda18': 'LambdaResnet18',
        'lambda34': 'LambdaResnet34',
        'lambda50': 'LambdaResnet50',
        'lambda101': 'LambdaResnet101',
        'lambda152': 'LambdaResnet152',
    }

    name = args.net
    entry = args_factories.get(name)
    if entry is None:
        entry = caseless_factories.get(name.lower())

    if entry is not None:
        net = load(*entry)(args)
    elif name in mobilenetv3_modes:
        mobilenetv3 = load('models.mobilenetv3', 'mobileNetv3')
        mode = mobilenetv3_modes[name]
        net = mobilenetv3(args) if mode is None else mobilenetv3(args, mode=mode)
    elif name in ('vgg19', 'alexnet'):
        import torch.nn as nn
        import torchvision.models as tv_models
        factory = tv_models.vgg19_bn if name == 'vgg19' else tv_models.alexnet
        net = factory(pretrained=True)
        # Replace the last classifier layer, keeping its input width.
        net.classifier[6] = nn.Linear(net.classifier[6].in_features, args.nc)
    elif name == 'resnet18':
        import torch.nn as nn
        import torchvision.models as tv_models
        net = tv_models.resnet18(pretrained=True)
        net.fc = nn.Linear(net.fc.in_features, args.nc)
    elif name in lambda_factories:
        factory = load('models._lambda', lambda_factories[name])
        if name == 'lambda101':
            # NOTE(review): unlike its siblings this variant is built without
            # ``channels=args.cs``; behaviour kept -- confirm it is intended.
            net = factory(num_classes=args.nc)
        else:
            net = factory(num_classes=args.nc, channels=args.cs)
    else:
        print('the network name you have entered is not supported yet')
        # Non-zero status so shells and job schedulers see the failure.
        sys.exit(1)

    if use_gpu:
        net = net.cuda()

    return net
Exemplo n.º 4
0
def get_network(args):
    """Build and return the network selected by ``args.net``.

    The module that defines the requested architecture is imported lazily.
    The two ``hyper_resnet*`` networks are built from ``Hypernet_Main`` with
    ``args.dict_size``; every other factory is called with no arguments.

    Args:
        args: object exposing ``net`` (network name), ``gpu`` (move the
            network to the GPU when truthy) and, for the hypernet variants,
            ``dict_size``.

    Returns:
        The constructed network.

    Raises:
        SystemExit: with status 1 when ``args.net`` is not a supported name.
    """
    import importlib

    # network name -> (module to import, zero-argument factory in that module)
    factories = {
        'vgg16': ('models.vgg', 'vgg16_bn'),
        'vgg13': ('models.vgg', 'vgg13_bn'),
        'vgg11': ('models.vgg', 'vgg11_bn'),
        'vgg19': ('models.vgg', 'vgg19_bn'),
        'densenet121': ('models.densenet', 'densenet121'),
        'densenet161': ('models.densenet', 'densenet161'),
        'densenet169': ('models.densenet', 'densenet169'),
        'densenet201': ('models.densenet', 'densenet201'),
        'googlenet': ('models.googlenet', 'googlenet'),
        'inceptionv3': ('models.inceptionv3', 'inceptionv3'),
        'inceptionv4': ('models.inceptionv4', 'inceptionv4'),
        'inceptionresnetv2': ('models.inceptionv4', 'inception_resnet_v2'),
        'xception': ('models.xception', 'xception'),
        'resnet18': ('models.resnet', 'resnet18'),
        'resnet34': ('models.resnet', 'resnet34'),
        'resnet50': ('models.resnet', 'resnet50'),
        'resnet101': ('models.resnet', 'resnet101'),
        'resnet152': ('models.resnet', 'resnet152'),
        'preactresnet18': ('models.preactresnet', 'preactresnet18'),
        'preactresnet34': ('models.preactresnet', 'preactresnet34'),
        'preactresnet50': ('models.preactresnet', 'preactresnet50'),
        'preactresnet101': ('models.preactresnet', 'preactresnet101'),
        'preactresnet152': ('models.preactresnet', 'preactresnet152'),
        'resnext50': ('models.resnext', 'resnext50'),
        'resnext101': ('models.resnext', 'resnext101'),
        'resnext152': ('models.resnext', 'resnext152'),
        'shufflenet': ('models.shufflenet', 'shufflenet'),
        'shufflenetv2': ('models.shufflenetv2', 'shufflenetv2'),
        'squeezenet': ('models.squeezenet', 'squeezenet'),
        'mobilenet': ('models.mobilenet', 'mobilenet'),
        'mobilenetv2': ('models.mobilenetv2', 'mobilenetv2'),
        'nasnet': ('models.nasnet', 'nasnet'),
        'attention56': ('models.attention', 'attention56'),
        'attention92': ('models.attention', 'attention92'),
        'seresnet18': ('models.senet', 'seresnet18'),
        'seresnet34': ('models.senet', 'seresnet34'),
        'seresnet50': ('models.senet', 'seresnet50'),
        'seresnet101': ('models.senet', 'seresnet101'),
        'seresnet152': ('models.senet', 'seresnet152'),
        'wideresnet': ('models.wideresidual', 'wideresnet'),
        'stochasticdepth18': ('models.stochasticdepth',
                              'stochastic_depth_resnet18'),
        'stochasticdepth34': ('models.stochasticdepth',
                              'stochastic_depth_resnet34'),
        'stochasticdepth50': ('models.stochasticdepth',
                              'stochastic_depth_resnet50'),
        'stochasticdepth101': ('models.stochasticdepth',
                               'stochastic_depth_resnet101'),
        'normal_resnet': ('models.normal_resnet', 'resnet18'),
        'normal_resnet_wo_bn': ('models.normal_resnet_wo_bn', 'resnet18'),
    }

    # Hypernet variants: network name -> ``encoder`` passed to Hypernet_Main.
    hypernet_encoders = {
        'hyper_resnet': "resnet18",
        'hyper_resnet_wo_bn': "resnet18_wobn",
    }

    name = args.net
    if name in factories:
        module_name, factory_name = factories[name]
        net = getattr(importlib.import_module(module_name), factory_name)()
    elif name in hypernet_encoders:
        hypernet_main = importlib.import_module('models.hypernet_main')
        net = hypernet_main.Hypernet_Main(
            encoder=hypernet_encoders[name],
            hypernet_params={'vqvae_dict_size': args.dict_size})
    else:
        print('the network name you have entered is not supported yet')
        # Non-zero status so shells and job schedulers see the failure.
        sys.exit(1)

    if args.gpu:  # use_gpu
        net = net.cuda()

    return net
Exemplo n.º 5
0
def get_network(args):
    """Build and return the network selected by ``args.model``.

    The module that defines the requested architecture is imported lazily and
    its factory is called with no arguments. The network is returned as
    created; it is not moved to any device here.

    Args:
        args: object exposing a ``model`` attribute with the network name.

    Returns:
        The constructed network.

    Raises:
        SystemExit: with status 1 when ``args.model`` is not a supported name.
    """
    import importlib

    # network name -> (module to import, zero-argument factory in that module)
    factories = {
        'vgg16': ('models.vgg', 'vgg16_bn'),
        'vgg13': ('models.vgg', 'vgg13_bn'),
        'vgg11': ('models.vgg', 'vgg11_bn'),
        'vgg19': ('models.vgg', 'vgg19_bn'),
        'densenet121': ('models.densenet', 'densenet121'),
        'densenet161': ('models.densenet', 'densenet161'),
        'densenet169': ('models.densenet', 'densenet169'),
        'densenet201': ('models.densenet', 'densenet201'),
        'googlenet': ('models.googlenet', 'googlenet'),
        'inceptionv3': ('models.inceptionv3', 'inceptionv3'),
        'inceptionv4': ('models.inceptionv4', 'inceptionv4'),
        'inceptionresnetv2': ('models.inceptionv4', 'inception_resnet_v2'),
        'xception': ('models.xception', 'xception'),
        'resnet18': ('models.resnet', 'resnet18'),
        'resnet34': ('models.resnet', 'resnet34'),
        'resnet50': ('models.resnet', 'resnet50'),
        'resnet101': ('models.resnet', 'resnet101'),
        'resnet152': ('models.resnet', 'resnet152'),
        'preactresnet18': ('models.preactresnet', 'preactresnet18'),
        'preactresnet34': ('models.preactresnet', 'preactresnet34'),
        'preactresnet50': ('models.preactresnet', 'preactresnet50'),
        'preactresnet101': ('models.preactresnet', 'preactresnet101'),
        'preactresnet152': ('models.preactresnet', 'preactresnet152'),
        'resnext50': ('models.resnext', 'resnext50'),
        'resnext101': ('models.resnext', 'resnext101'),
        'resnext152': ('models.resnext', 'resnext152'),
        'shufflenet': ('models.shufflenet', 'shufflenet'),
        'shufflenetv2': ('models.shufflenetv2', 'shufflenetv2'),
        'squeezenet': ('models.squeezenet', 'squeezenet'),
        'mobilenet': ('models.mobilenet', 'mobilenet'),
        'mobilenetv2': ('models.mobilenetv2', 'mobilenetv2'),
        'nasnet': ('models.nasnet', 'nasnet'),
        'attention56': ('models.attention', 'attention56'),
        'attention92': ('models.attention', 'attention92'),
        'seresnet18': ('models.senet', 'seresnet18'),
        'seresnet34': ('models.senet', 'seresnet34'),
        'seresnet50': ('models.senet', 'seresnet50'),
        'seresnet101': ('models.senet', 'seresnet101'),
        'seresnet152': ('models.senet', 'seresnet152'),
        'wideresnet': ('models.wideresidual', 'wideresnet'),
        'stochasticdepth18': ('models.stochasticdepth',
                              'stochastic_depth_resnet18'),
        'stochasticdepth34': ('models.stochasticdepth',
                              'stochastic_depth_resnet34'),
        'stochasticdepth50': ('models.stochasticdepth',
                              'stochastic_depth_resnet50'),
        'stochasticdepth101': ('models.stochasticdepth',
                               'stochastic_depth_resnet101'),
    }

    entry = factories.get(args.model)
    if entry is None:
        print('the network name you have entered is not supported yet')
        # Non-zero status so shells and job schedulers see the failure.
        sys.exit(1)

    module_name, factory_name = entry
    model = getattr(importlib.import_module(module_name), factory_name)()

    return model
Exemplo n.º 6
0
def inception_4(pre):
    """Return the first child of an ``inceptionv4`` network.

    Args:
        pre: forwarded unchanged as the ``pretrained`` argument of
            ``inceptionv4``.

    Returns:
        Element 0 of ``children(...)`` applied to the created network.
        NOTE(review): ``children`` is defined outside this block; presumably
        it lists the network's child modules -- confirm.
    """
    network = inceptionv4(pretrained=pre)
    child_modules = children(network)
    return child_modules[0]
Exemplo n.º 7
0
def get_network(args, use_gpu=True):
    """Build and return the network selected by ``args.net``.

    Depending on the name the network comes from torchvision (VGG-BN and
    DenseNet 161/169/201, created with ``pretrained=args.bool_pretrained`` and
    a new final layer with ``args.num_classes`` outputs), from
    ``efficientnet_pytorch`` (``efficientnet-b5``), or from the project's
    ``models`` package, whose modules are imported lazily.

    Args:
        args: object exposing ``net`` (network name) plus, depending on the
            network, ``num_classes``, ``bool_pretrained`` and ``pretrained``.
        use_gpu: when true, the network is moved to the GPU via ``.cuda()``.

    Returns:
        The constructed network.

    Raises:
        SystemExit: with status 1 when ``args.net`` is not a supported name.
    """
    import importlib

    def load(module_name, attr_name):
        # Import the model module on demand and fetch its factory.
        return getattr(importlib.import_module(module_name), attr_name)

    # Project factories called with no arguments.
    plain_factories = {
        'densenet121': ('models.densenet', 'densenet121'),
        'googlenet': ('models.googlenet', 'googlenet'),
        'inceptionv3': ('models.inceptionv3', 'inceptionv3'),
        'inceptionv4': ('models.inceptionv4', 'inceptionv4'),
        'inceptionresnetv2': ('models.inceptionv4', 'inception_resnet_v2'),
        'xception': ('models.xception', 'xception'),
        'preactresnet18': ('models.preactresnet', 'preactresnet18'),
        'preactresnet34': ('models.preactresnet', 'preactresnet34'),
        'preactresnet50': ('models.preactresnet', 'preactresnet50'),
        'preactresnet101': ('models.preactresnet', 'preactresnet101'),
        'preactresnet152': ('models.preactresnet', 'preactresnet152'),
        'shufflenet': ('models.shufflenet', 'shufflenet'),
        'shufflenetv2': ('models.shufflenetv2', 'shufflenetv2'),
        'squeezenet': ('models.squeezenet', 'squeezenet'),
        'mobilenet': ('models.mobilenet', 'mobilenet'),
        'mobilenetv2': ('models.mobilenetv2', 'mobilenetv2'),
        'nasnet': ('models.nasnet', 'nasnet'),
        'attention56': ('models.attention', 'attention56'),
        'attention92': ('models.attention', 'attention92'),
    }

    # Project factories called as
    # ``factory(args.num_classes, 2, args.pretrained, args.net)``.
    configured_factories = {
        'resnet18': ('models.resnet', 'resnet'),
        'resnet34': ('models.resnet', 'resnet'),
        'resnet50': ('models.resnet', 'resnet'),
        'resnet101': ('models.resnet', 'resnet'),
        'resnet152': ('models.resnet', 'resnet'),
        'se_resnext50': ('models.resnext', 'se_resnext'),
        'se_resnext101': ('models.resnext', 'se_resnext'),
        'se_resnet50': ('models.senet', 'seresnet'),
        'se_resnet101': ('models.senet', 'seresnet'),
        'se_resnet152': ('models.senet', 'seresnet'),
    }

    name = args.net
    if name in ('vgg16', 'vgg13', 'vgg11', 'vgg19'):
        # torchvision's batch-norm variant, e.g. 'vgg16' -> vgg16_bn.
        factory = getattr(torchvision.models, name + '_bn')
        net = factory(pretrained=args.bool_pretrained)
        # Replace the last classifier layer, keeping its input width.
        net.classifier[6] = nn.Linear(
            net.classifier[6].in_features, args.num_classes, bias=True)
    elif name == 'efficientnet-b5':
        from efficientnet_pytorch import EfficientNet
        if args.bool_pretrained:
            net = EfficientNet.from_pretrained('efficientnet-b5')
        else:
            net = EfficientNet.from_name('efficientnet-b5')
        net._fc = nn.Linear(net._fc.in_features, args.num_classes, bias=True)
    elif name in ('densenet161', 'densenet169', 'densenet201'):
        factory = getattr(torchvision.models, name)
        net = factory(pretrained=args.bool_pretrained)
        in_features = net.classifier.in_features
        net.classifier = nn.Linear(in_features, args.num_classes, bias=True)
    elif name in configured_factories:
        factory = load(*configured_factories[name])
        net = factory(args.num_classes, 2, args.pretrained, args.net)
    elif name in plain_factories:
        net = load(*plain_factories[name])()
    else:
        print('the network name you have entered is not supported yet')
        # Non-zero status so shells and job schedulers see the failure.
        sys.exit(1)

    if use_gpu:
        net = net.cuda()

    return net
Exemplo n.º 8
0
def get_network(args, use_gpu=True, num_train=0):
    """Build the network selected by ``args.net``.

    Args:
        args: namespace read for ``dataset``, ``ignoring`` and ``net``; when
            ``args.ignoring`` is truthy, ``softmax`` and ``isalpha`` are read
            as well.
        use_gpu: when truthy the network is moved with ``.cuda()`` before it
            is returned.
        num_train: forwarded to ``resnet18_ign`` in the ``ignoring`` variant.

    Returns:
        The constructed network.

    Notes:
        ``num_classes`` is 10 for ``'cifar-10'``, 100 for ``'cifar-100'`` and
        0 for any other dataset. It is only passed to the ``resnet*``
        factories and to ``resnet18_ign``; every other factory is called
        without arguments.

        An unsupported network name prints a message and leaves through
        ``sys.exit()``. This now also covers the ``ignoring`` variant, which
        only supports ``'resnet18'`` and used to fail later with an
        ``UnboundLocalError`` for any other name.
    """
    import importlib

    def unsupported():
        print('the network name you have entered is not supported yet')
        sys.exit()

    if args.dataset == 'cifar-10':
        num_classes = 10
    elif args.dataset == 'cifar-100':
        num_classes = 100
    else:
        num_classes = 0

    if args.ignoring:
        if args.net != 'resnet18':
            unsupported()
        from models.resnet_ign import resnet18_ign
        # The per-sample loss (reduction='none') is handed to the model itself.
        criterion = nn.CrossEntropyLoss(reduction='none')
        net = resnet18_ign(criterion,
                           num_classes=num_classes,
                           num_train=num_train,
                           softmax=args.softmax,
                           isalpha=args.isalpha)
    else:
        # network name -> (module to import lazily, factory name in it)
        registry = {
            'vgg16': ('models.vgg', 'vgg16_bn'),
            'vgg13': ('models.vgg', 'vgg13_bn'),
            'vgg11': ('models.vgg', 'vgg11_bn'),
            'vgg19': ('models.vgg', 'vgg19_bn'),
            'densenet121': ('models.densenet', 'densenet121'),
            'densenet161': ('models.densenet', 'densenet161'),
            'densenet169': ('models.densenet', 'densenet169'),
            'densenet201': ('models.densenet', 'densenet201'),
            'googlenet': ('models.googlenet', 'googlenet'),
            'inceptionv3': ('models.inceptionv3', 'inceptionv3'),
            'inceptionv4': ('models.inceptionv4', 'inceptionv4'),
            'inceptionresnetv2': ('models.inceptionv4', 'inception_resnet_v2'),
            'xception': ('models.xception', 'xception'),
            'resnet18': ('models.resnet', 'resnet18'),
            'resnet34': ('models.resnet', 'resnet34'),
            'resnet50': ('models.resnet', 'resnet50'),
            'resnet101': ('models.resnet', 'resnet101'),
            'resnet152': ('models.resnet', 'resnet152'),
            'preactresnet18': ('models.preactresnet', 'preactresnet18'),
            'preactresnet34': ('models.preactresnet', 'preactresnet34'),
            'preactresnet50': ('models.preactresnet', 'preactresnet50'),
            'preactresnet101': ('models.preactresnet', 'preactresnet101'),
            'preactresnet152': ('models.preactresnet', 'preactresnet152'),
            'resnext50': ('models.resnext', 'resnext50'),
            'resnext101': ('models.resnext', 'resnext101'),
            'resnext152': ('models.resnext', 'resnext152'),
            'shufflenet': ('models.shufflenet', 'shufflenet'),
            'shufflenetv2': ('models.shufflenetv2', 'shufflenetv2'),
            'squeezenet': ('models.squeezenet', 'squeezenet'),
            'mobilenet': ('models.mobilenet', 'mobilenet'),
            'mobilenetv2': ('models.mobilenetv2', 'mobilenetv2'),
            'nasnet': ('models.nasnet', 'nasnet'),
            'attention56': ('models.attention', 'attention56'),
            'attention92': ('models.attention', 'attention92'),
            'seresnet18': ('models.senet', 'seresnet18'),
            'seresnet34': ('models.senet', 'seresnet34'),
            'seresnet50': ('models.senet', 'seresnet50'),
            'seresnet101': ('models.senet', 'seresnet101'),
            'seresnet152': ('models.senet', 'seresnet152'),
        }
        # Only these factories receive the dataset-dependent class count.
        takes_num_classes = {
            'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152',
        }

        if args.net not in registry:
            unsupported()
        module_name, factory_name = registry[args.net]
        factory = getattr(importlib.import_module(module_name), factory_name)
        if args.net in takes_num_classes:
            net = factory(num_classes=num_classes)
        else:
            net = factory()

    if use_gpu:
        net = net.cuda()

    return net
Exemplo n.º 9
0
Arquivo: utils.py Projeto: nblt/DLDR
def get_model(args):
    """Build the model selected by ``args.arch`` for ``args.datasets``.

    * ``'ImageNet'``: the factory is looked up by name in
      ``models_imagenet`` and called without arguments.
    * ``'CIFAR10'`` / ``'MNIST'``: the factory is looked up by name in
      ``resnet`` and called with ``num_classes=10``.
    * ``'CIFAR100'``: one of the architectures listed below is imported
      lazily from the ``models`` package; any other name falls back to the
      ``resnet`` lookup with ``num_classes=100``.

    Args:
        args: namespace read for ``datasets`` and ``arch``.

    Returns:
        The constructed model.

    Raises:
        ValueError: if ``args.datasets`` is not one of the datasets above
            (previously this surfaced as an ``UnboundLocalError`` on
            ``num_class``).
        KeyError: if ``args.arch`` is not found in the ``models_imagenet`` /
            ``resnet`` namespace used for the lookup.
    """
    import importlib

    if args.datasets == 'ImageNet':
        return models_imagenet.__dict__[args.arch]()

    if args.datasets == 'CIFAR10' or args.datasets == 'MNIST':
        num_class = 10
    elif args.datasets == 'CIFAR100':
        num_class = 100
    else:
        raise ValueError(
            'unsupported dataset {!r}; expected ImageNet, CIFAR10, MNIST '
            'or CIFAR100'.format(args.datasets))

    if args.datasets != 'CIFAR100':
        return resnet.__dict__[args.arch](num_classes=num_class)

    if args.arch == 'efficientnet':
        # The only CIFAR100 factory that takes arguments; num_class is 100
        # on this path, the value that used to be hard-coded here.
        from models.efficientnet import efficientnet
        return efficientnet(1, 1, num_class, bn_momentum=0.9)

    # architecture name -> (module to import lazily, factory name in it);
    # all of these factories are called without arguments.
    cifar100_zoo = {
        'vgg16': ('models.vgg', 'vgg16_bn'),
        'vgg13': ('models.vgg', 'vgg13_bn'),
        'vgg11': ('models.vgg', 'vgg11_bn'),
        'vgg19': ('models.vgg', 'vgg19_bn'),
        'densenet121': ('models.densenet', 'densenet121'),
        'densenet161': ('models.densenet', 'densenet161'),
        'densenet169': ('models.densenet', 'densenet169'),
        'densenet201': ('models.densenet', 'densenet201'),
        'googlenet': ('models.googlenet', 'googlenet'),
        'inceptionv3': ('models.inceptionv3', 'inceptionv3'),
        'inceptionv4': ('models.inceptionv4', 'inceptionv4'),
        'inceptionresnetv2': ('models.inceptionv4', 'inception_resnet_v2'),
        'xception': ('models.xception', 'xception'),
        'resnet18': ('models.resnet', 'resnet18'),
        'resnet34': ('models.resnet', 'resnet34'),
        'resnet50': ('models.resnet', 'resnet50'),
        'resnet101': ('models.resnet', 'resnet101'),
        'resnet152': ('models.resnet', 'resnet152'),
        'preactresnet18': ('models.preactresnet', 'preactresnet18'),
        'preactresnet34': ('models.preactresnet', 'preactresnet34'),
        'preactresnet50': ('models.preactresnet', 'preactresnet50'),
        'preactresnet101': ('models.preactresnet', 'preactresnet101'),
        'preactresnet152': ('models.preactresnet', 'preactresnet152'),
        'resnext50': ('models.resnext', 'resnext50'),
        'resnext101': ('models.resnext', 'resnext101'),
        'resnext152': ('models.resnext', 'resnext152'),
        'shufflenet': ('models.shufflenet', 'shufflenet'),
        'shufflenetv2': ('models.shufflenetv2', 'shufflenetv2'),
        'squeezenet': ('models.squeezenet', 'squeezenet'),
        'mobilenet': ('models.mobilenet', 'mobilenet'),
        'mobilenetv2': ('models.mobilenetv2', 'mobilenetv2'),
        'nasnet': ('models.nasnet', 'nasnet'),
        'attention56': ('models.attention', 'attention56'),
        'attention92': ('models.attention', 'attention92'),
        'seresnet18': ('models.senet', 'seresnet18'),
        'seresnet34': ('models.senet', 'seresnet34'),
        'seresnet50': ('models.senet', 'seresnet50'),
        'seresnet101': ('models.senet', 'seresnet101'),
        'seresnet152': ('models.senet', 'seresnet152'),
        'wideresnet': ('models.wideresidual', 'wideresnet'),
        'stochasticdepth18': ('models.stochasticdepth',
                              'stochastic_depth_resnet18'),
        'stochasticdepth34': ('models.stochasticdepth',
                              'stochastic_depth_resnet34'),
        'stochasticdepth50': ('models.stochasticdepth',
                              'stochastic_depth_resnet50'),
        'stochasticdepth101': ('models.stochasticdepth',
                               'stochastic_depth_resnet101'),
    }

    if args.arch in cifar100_zoo:
        module_name, factory_name = cifar100_zoo[args.arch]
        factory = getattr(importlib.import_module(module_name), factory_name)
        return factory()

    # Anything else is resolved by name in the ``resnet`` namespace.
    return resnet.__dict__[args.arch](num_classes=num_class)
Exemplo n.º 10
0
def get_network(args):
    """Return the network named by ``args.net``.

    Every supported name maps to a factory that is imported on demand from
    the ``models`` package and called without arguments. The one exception
    is ``'eff'``, which loads ``EfficientNet.from_pretrained`` for
    ``'efficientnet-b7'`` with two output classes.

    An unknown name prints a message and leaves through ``sys.exit()``.
    When ``args.gpu`` is truthy the network is moved with ``.cuda()`` and
    ``"use-gpu"`` is printed.
    """
    import importlib

    # network name -> (module to import lazily, factory name in it)
    catalogue = {
        'vgg16': ('models.vgg', 'vgg16_bn'),
        'vgg13': ('models.vgg', 'vgg13_bn'),
        'vgg11': ('models.vgg', 'vgg11_bn'),
        'vgg19': ('models.vgg', 'vgg19_bn'),
        'densenet121': ('models.densenet', 'densenet121'),
        'densenet161': ('models.densenet', 'densenet161'),
        'densenet169': ('models.densenet', 'densenet169'),
        'densenet201': ('models.densenet', 'densenet201'),
        'googlenet': ('models.googlenet', 'googlenet'),
        'inceptionv3': ('models.inceptionv3', 'inceptionv3'),
        'inceptionv4': ('models.inceptionv4', 'inceptionv4'),
        'inceptionresnetv2': ('models.inceptionv4', 'inception_resnet_v2'),
        'xception': ('models.xception', 'xception'),
        'resnet18': ('models.resnet', 'resnet18'),
        'resnet34': ('models.resnet', 'resnet34'),
        'resnet50': ('models.resnet', 'resnet50'),
        'resnet101': ('models.resnet', 'resnet101'),
        'resnet152': ('models.resnet', 'resnet152'),
        'preactresnet18': ('models.preactresnet', 'preactresnet18'),
        'preactresnet34': ('models.preactresnet', 'preactresnet34'),
        'preactresnet50': ('models.preactresnet', 'preactresnet50'),
        'preactresnet101': ('models.preactresnet', 'preactresnet101'),
        'preactresnet152': ('models.preactresnet', 'preactresnet152'),
        'resnext50': ('models.resnext', 'resnext50'),
        'resnext101': ('models.resnext', 'resnext101'),
        'resnext152': ('models.resnext', 'resnext152'),
        'shufflenet': ('models.shufflenet', 'shufflenet'),
        'shufflenetv2': ('models.shufflenetv2', 'shufflenetv2'),
        'squeezenet': ('models.squeezenet', 'squeezenet'),
        'mobilenet': ('models.mobilenet', 'mobilenet'),
        'mobilenetv2': ('models.mobilenetv2', 'mobilenetv2'),
        'nasnet': ('models.nasnet', 'nasnet'),
        'attention56': ('models.attention', 'attention56'),
        'attention92': ('models.attention', 'attention92'),
        'seresnet18': ('models.senet', 'seresnet18'),
        'seresnet34': ('models.senet', 'seresnet34'),
        'seresnet50': ('models.senet', 'seresnet50'),
        'seresnet101': ('models.senet', 'seresnet101'),
        'seresnet152': ('models.senet', 'seresnet152'),
        'wideresnet': ('models.wideresidual', 'wideresnet'),
        'stochasticdepth18': ('models.stochasticdepth',
                              'stochastic_depth_resnet18'),
        'stochasticdepth34': ('models.stochasticdepth',
                              'stochastic_depth_resnet34'),
        'stochasticdepth50': ('models.stochasticdepth',
                              'stochastic_depth_resnet50'),
        'stochasticdepth101': ('models.stochasticdepth',
                               'stochastic_depth_resnet101'),
        'efficientnetb0': ('models.efficientnet', 'efficientnetb0'),
        'efficientnetb1': ('models.efficientnet', 'efficientnetb1'),
        'efficientnetb2': ('models.efficientnet', 'efficientnetb2'),
        'efficientnetb3': ('models.efficientnet', 'efficientnetb3'),
        'efficientnetb4': ('models.efficientnet', 'efficientnetb4'),
        'efficientnetb5': ('models.efficientnet', 'efficientnetb5'),
        'efficientnetb6': ('models.efficientnet', 'efficientnetb6'),
        'efficientnetb7': ('models.efficientnet', 'efficientnetb7'),
        'efficientnetl2': ('models.efficientnet', 'efficientnetl2'),
    }

    requested = args.net
    if requested in catalogue:
        module_path, builder_name = catalogue[requested]
        builder = getattr(importlib.import_module(module_path), builder_name)
        net = builder()
    elif requested == 'eff':
        from models.efficientnet_pytorch import EfficientNet
        net = EfficientNet.from_pretrained('efficientnet-b7', num_classes=2)
    else:
        print('the network name you have entered is not supported yet')
        sys.exit()

    if args.gpu:
        net = net.cuda()
        print("use-gpu")

    return net
Exemplo n.º 11
0
def get_network(args):
    """Return the network named by ``args.net`` for the task ``args.task``.

    The class count is 10 for ``'cifar10'`` and 100 for ``'cifar100'``.
    It is passed as ``num_classes`` only to the factories flagged in the
    table below; the others are called without arguments and therefore
    work for any task.

    Args:
        args: namespace read for ``task``, ``net`` and ``gpu``.

    Returns:
        The constructed network, moved with ``.cuda()`` when ``args.gpu``
        is truthy.

    Raises:
        ValueError: if the selected network needs a class count but
            ``args.task`` is neither ``'cifar10'`` nor ``'cifar100'``
            (previously an ``UnboundLocalError`` on ``nclass``).

    An unsupported network name prints a message and leaves through
    ``sys.exit()``.
    """
    import importlib

    if args.task == 'cifar10':
        nclass = 10
    elif args.task == 'cifar100':
        nclass = 100
    else:
        nclass = None

    # network name -> (module to import lazily, factory name in it,
    #                  whether the factory receives num_classes=nclass)
    registry = {
        # VGG without batch norm
        'vgg16': ('models.vgg', 'vgg16', True),
        'vgg13': ('models.vgg', 'vgg13', True),
        'vgg11': ('models.vgg', 'vgg11', True),
        'vgg19': ('models.vgg', 'vgg19', True),
        # VGG with batch norm
        'vgg16bn': ('models.vgg', 'vgg16_bn', True),
        'vgg13bn': ('models.vgg', 'vgg13_bn', True),
        'vgg11bn': ('models.vgg', 'vgg11_bn', True),
        'vgg19bn': ('models.vgg', 'vgg19_bn', True),
        'densenet121': ('models.densenet', 'densenet121', False),
        'densenet161': ('models.densenet', 'densenet161', False),
        'densenet169': ('models.densenet', 'densenet169', False),
        'densenet201': ('models.densenet', 'densenet201', False),
        'googlenet': ('models.googlenet', 'googlenet', True),
        'inceptionv3': ('models.inceptionv3', 'inceptionv3', False),
        'inceptionv4': ('models.inceptionv4', 'inceptionv4', False),
        'inceptionresnetv2': ('models.inceptionv4', 'inception_resnet_v2',
                              False),
        'xception': ('models.xception', 'xception', True),
        'scnet': ('models.sphereconvnet', 'sphereconvnet', True),
        'sphereresnet18': ('models.sphereconvnet', 'resnet18', True),
        'sphereresnet32': ('models.sphereconvnet', 'sphereresnet32', True),
        'plainresnet32': ('models.sphereconvnet', 'plainresnet32', True),
        'ynet18': ('models.ynet', 'resnet18', True),
        'ynet34': ('models.ynet', 'resnet34', True),
        'ynet50': ('models.ynet', 'resnet50', True),
        'ynet101': ('models.ynet', 'resnet101', True),
        'ynet152': ('models.ynet', 'resnet152', True),
        'resnet18': ('models.resnet', 'resnet18', True),
        'resnet34': ('models.resnet', 'resnet34', True),
        'resnet50': ('models.resnet', 'resnet50', True),
        'resnet101': ('models.resnet', 'resnet101', True),
        'resnet152': ('models.resnet', 'resnet152', True),
        'preactresnet18': ('models.preactresnet', 'preactresnet18', True),
        'preactresnet34': ('models.preactresnet', 'preactresnet34', True),
        'preactresnet50': ('models.preactresnet', 'preactresnet50', True),
        'preactresnet101': ('models.preactresnet', 'preactresnet101', True),
        'preactresnet152': ('models.preactresnet', 'preactresnet152', True),
        'resnext50': ('models.resnext', 'resnext50', True),
        'resnext101': ('models.resnext', 'resnext101', True),
        'resnext152': ('models.resnext', 'resnext152', True),
        'shufflenet': ('models.shufflenet', 'shufflenet', False),
        'shufflenetv2': ('models.shufflenetv2', 'shufflenetv2', False),
        'squeezenet': ('models.squeezenet', 'squeezenet', False),
        'mobilenet': ('models.mobilenet', 'mobilenet', True),
        'mobilenetv2': ('models.mobilenetv2', 'mobilenetv2', True),
        'nasnet': ('models.nasnet', 'nasnet', True),
        'attention56': ('models.attention', 'attention56', False),
        'attention92': ('models.attention', 'attention92', False),
        'seresnet18': ('models.senet', 'seresnet18', True),
        'seresnet34': ('models.senet', 'seresnet34', True),
        'seresnet50': ('models.senet', 'seresnet50', True),
        'seresnet101': ('models.senet', 'seresnet101', True),
        'seresnet152': ('models.senet', 'seresnet152', True),
    }

    if args.net not in registry:
        print('the network name you have entered is not supported yet')
        sys.exit()

    module_name, factory_name, needs_nclass = registry[args.net]
    if needs_nclass and nclass is None:
        # Fail before importing anything, with a message naming the cause.
        raise ValueError(
            "network {!r} needs a class count, but task {!r} is not "
            "'cifar10' or 'cifar100'".format(args.net, args.task))

    factory = getattr(importlib.import_module(module_name), factory_name)
    net = factory(num_classes=nclass) if needs_nclass else factory()

    if args.gpu:
        net = net.cuda()

    return net
Exemplo n.º 12
0
    def loading_model(self):
        """Create the model, loss, optimizer and scheduler, then load checkpoints.

        Reads ``self.model_type``, ``self.pretrained``, ``self.num_classes``,
        ``self.lr``, ``self.momentum``, ``self.weight_decay``, ``self.resume``,
        ``self.evaluate`` and ``self.data_folder``. Sets ``self.model``,
        ``self.criterion``, ``self.optimizer`` and ``self.scheduler`` and,
        whenever a checkpoint is loaded, ``self.start_epoch`` and
        ``self.best_prec1``.

        For a pretrained ``'inceptionv4'`` every backbone parameter is frozen,
        then the ``features`` children that follow the child named ``"4"``
        are made trainable again. The replacement ``last_linear`` head is
        created after the freeze so that it is trainable and reaches the
        optimizer; it used to be frozen together with the backbone.

        Raises:
            SystemExit: if ``self.model_type`` is not a known model.
        """
        print('Loading %s model' % (self.model_type))
        pretrained = 'imagenet' if self.pretrained else None

        if self.model_type == 'inceptionv4':
            self.model = inceptionv4(num_classes=1000, pretrained=pretrained)
            if self.pretrained:
                # Freeze the whole pretrained network ...
                for param in self.model.parameters():
                    param.requires_grad = False
                # ... then unfreeze every features child that comes after the
                # child named "4".
                seen = []
                for name, child in self.model.features.named_children():
                    if "4" in seen:
                        for param in child.parameters():
                            param.requires_grad = True
                    seen.append(name)
            # Swap the 1000-class head for one sized to this task. Done after
            # the freeze above so the new layer keeps requires_grad=True.
            num_ftrs = self.model.last_linear.in_features
            self.model.last_linear = nn.Linear(num_ftrs, self.num_classes)

        elif self.model_type == 'iresetv2':
            self.model = inceptionresnetv2(num_classes=self.num_classes,
                                           pretrained=pretrained)
        else:
            print('no model')
            # Same exception as the site-provided exit(), without relying on
            # the interactive helper being installed.
            raise SystemExit

        cudnn.benchmark = True

        # Wrap for multi-GPU use and move model and loss when a GPU is there.
        gpu_available = check_gpu() > 0
        if gpu_available:
            self.model = torch.nn.DataParallel(self.model).cuda()

        # define loss function (criterion) and optimizer
        self.criterion = nn.CrossEntropyLoss()
        if gpu_available:
            self.criterion = self.criterion.cuda()

        if self.pretrained:
            # Frozen parameters are left out of the optimizer.
            params = [p for p in self.model.parameters() if p.requires_grad]
        else:
            params = self.model.parameters()

        self.optimizer = optim.SGD(params=params,
                                   lr=self.lr,
                                   momentum=self.momentum,
                                   weight_decay=self.weight_decay)
        self.scheduler = optim.lr_scheduler.StepLR(optimizer=self.optimizer,
                                                   step_size=10,
                                                   gamma=0.1)

        def restore(path, shown_name):
            # Load epoch, best precision, model and optimizer state from path.
            if not os.path.isfile(path):
                print("=> no checkpoint found at '{}'".format(path))
                return
            print("=> loading checkpoint '{}'".format(shown_name))
            # NOTE(review): torch.load unpickles the file; only point this at
            # checkpoints from a trusted source.
            checkpoint = torch.load(path)
            self.start_epoch = checkpoint['epoch']
            self.best_prec1 = checkpoint['best_prec1']
            self.model.load_state_dict(checkpoint['state_dict'])
            self.optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint (epoch {})".format(
                checkpoint['epoch']))

        # optionally resume from a checkpoint
        if self.resume:
            restore(self.resume, self.resume)

        if self.evaluate:
            # Evaluation always uses the best checkpoint in the data folder;
            # a missing file is now reported with this path, not self.resume.
            file_model_best = os.path.join(self.data_folder,
                                           'model_best.pth.tar')
            restore(file_model_best, 'model_best.pth.tar')
Exemplo n.º 13
0
from models.inceptionv4 import inceptionv4
from iceberg import *
from torch.optim import SGD

# SGD hyper-parameters and batch size for this run; the same values are
# spelled out in the checkpoint file name below.
lr, mom, wd = 0.0005, 0.9, 1e-4
bsize = 128

# InceptionV4 built with the argument 1000, wrapped in DataParallel and moved
# to the GPU.
model = torch.nn.DataParallel(inceptionv4(1000)).cuda()
optim = SGD(model.parameters(), lr, momentum=mom, weight_decay=wd)

checkpoint_path = (
    './data/inceptionv4_lr5e-4_mom9e-1_wd1e-4_bs{}_pretrained_model.pth'
    .format(bsize))
# NOTE(review): the meaning of the positional 10000 is defined by Iceberg,
# which is not visible here - confirm against its signature.
iceberg = Iceberg('./data/train.json', checkpoint_path, model, optim, 10000,
                  bsize)

# Preprocessing handed to the run: resize to 320, random 299 crop, random
# horizontal flip, then tensor conversion.
augmentation = transforms.Compose([
    transforms.Resize(320),
    transforms.RandomCrop(299),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
])
iceberg.run(augmentation)