Esempio n. 1
0
def infer(model_path='./resnext50_NWPU-RESISC45.pth', data_root='./test_data',
          batch_size=256, num_workers=16):
    """Evaluate a 45-class ResNeXt-50 checkpoint on a test folder and print its accuracy.

    Args:
        model_path: path of the saved ``state_dict`` for ``resnext50(num_classes=45)``.
        data_root: directory passed to ``ClassifyData(root=...)``.
        batch_size: evaluation batch size for the ``DataLoader``.
        num_workers: number of ``DataLoader`` worker processes.

    Returns:
        The fraction of correctly classified test samples (0.0 if the loader
        yields no samples).
    """
    net = resnext50(num_classes=45).to(device)
    # map_location lets a checkpoint saved on a GPU load onto whatever ``device`` is.
    state_dict = torch.load(model_path, map_location=device)
    net.load_state_dict(state_dict)

    test_data = ClassifyData(root=data_root)
    test_loader = DataLoader(test_data,
                             batch_size=batch_size,
                             shuffle=False,
                             num_workers=num_workers)
    tbar = tqdm(test_loader)
    tbar.set_description("Testing->>")

    correct = 0
    total = 0
    net.eval()
    with torch.no_grad():
        for img, lab in tbar:
            img = img.to(device)
            logits = net(img)
            pred = logits.argmax(dim=1).cpu()
            # Accumulate sample counts instead of averaging per-batch accuracies:
            # the last batch is usually smaller, and an unweighted mean of batch
            # accuracies would give its samples too much weight.
            correct += int(pred.eq(lab).sum())
            total += img.size(0)

    test_accu = correct / total if total else 0.0
    print("test accuracy=%.6f" % test_accu)
    return test_accu
Esempio n. 2
0
def get_network(args):
    """Build the network named by ``args.net`` and optionally move it to the GPU.

    Model modules are imported lazily, so only the selected architecture's
    module is loaded.

    Args:
        args: options namespace. Reads ``net`` (network name) and ``gpu``
            (when truthy the model is moved with ``.cuda()``); the
            ``hyper_resnet`` variants also read ``dict_size``.

    Returns:
        The constructed network.

    An unsupported name prints a message and terminates via ``sys.exit(1)``.
    """
    import importlib

    # network name -> (module path, factory name); factories take no arguments.
    factories = {
        'vgg16': ('models.vgg', 'vgg16_bn'),
        'vgg13': ('models.vgg', 'vgg13_bn'),
        'vgg11': ('models.vgg', 'vgg11_bn'),
        'vgg19': ('models.vgg', 'vgg19_bn'),
        'densenet121': ('models.densenet', 'densenet121'),
        'densenet161': ('models.densenet', 'densenet161'),
        'densenet169': ('models.densenet', 'densenet169'),
        'densenet201': ('models.densenet', 'densenet201'),
        'googlenet': ('models.googlenet', 'googlenet'),
        'inceptionv3': ('models.inceptionv3', 'inceptionv3'),
        'inceptionv4': ('models.inceptionv4', 'inceptionv4'),
        'inceptionresnetv2': ('models.inceptionv4', 'inception_resnet_v2'),
        'xception': ('models.xception', 'xception'),
        'resnet18': ('models.resnet', 'resnet18'),
        'resnet34': ('models.resnet', 'resnet34'),
        'resnet50': ('models.resnet', 'resnet50'),
        'resnet101': ('models.resnet', 'resnet101'),
        'resnet152': ('models.resnet', 'resnet152'),
        'preactresnet18': ('models.preactresnet', 'preactresnet18'),
        'preactresnet34': ('models.preactresnet', 'preactresnet34'),
        'preactresnet50': ('models.preactresnet', 'preactresnet50'),
        'preactresnet101': ('models.preactresnet', 'preactresnet101'),
        'preactresnet152': ('models.preactresnet', 'preactresnet152'),
        'resnext50': ('models.resnext', 'resnext50'),
        'resnext101': ('models.resnext', 'resnext101'),
        'resnext152': ('models.resnext', 'resnext152'),
        'shufflenet': ('models.shufflenet', 'shufflenet'),
        'shufflenetv2': ('models.shufflenetv2', 'shufflenetv2'),
        'squeezenet': ('models.squeezenet', 'squeezenet'),
        'mobilenet': ('models.mobilenet', 'mobilenet'),
        'mobilenetv2': ('models.mobilenetv2', 'mobilenetv2'),
        'nasnet': ('models.nasnet', 'nasnet'),
        'attention56': ('models.attention', 'attention56'),
        'attention92': ('models.attention', 'attention92'),
        'seresnet18': ('models.senet', 'seresnet18'),
        'seresnet34': ('models.senet', 'seresnet34'),
        'seresnet50': ('models.senet', 'seresnet50'),
        'seresnet101': ('models.senet', 'seresnet101'),
        'seresnet152': ('models.senet', 'seresnet152'),
        'wideresnet': ('models.wideresidual', 'wideresnet'),
        'stochasticdepth18': ('models.stochasticdepth', 'stochastic_depth_resnet18'),
        'stochasticdepth34': ('models.stochasticdepth', 'stochastic_depth_resnet34'),
        'stochasticdepth50': ('models.stochasticdepth', 'stochastic_depth_resnet50'),
        'stochasticdepth101': ('models.stochasticdepth', 'stochastic_depth_resnet101'),
        'normal_resnet': ('models.normal_resnet', 'resnet18'),
        'normal_resnet_wo_bn': ('models.normal_resnet_wo_bn', 'resnet18'),
    }
    # Hypernetwork variants: network name -> encoder passed to Hypernet_Main.
    hyper_encoders = {
        'hyper_resnet': 'resnet18',
        'hyper_resnet_wo_bn': 'resnet18_wobn',
    }

    if args.net in factories:
        module_name, factory_name = factories[args.net]
        factory = getattr(importlib.import_module(module_name), factory_name)
        net = factory()
    elif args.net in hyper_encoders:
        from models.hypernet_main import Hypernet_Main
        net = Hypernet_Main(
            encoder=hyper_encoders[args.net],
            hypernet_params={'vqvae_dict_size': args.dict_size})
    else:
        print('the network name you have entered is not supported yet')
        # Non-zero status so shell scripts and job schedulers see the failure
        # (a bare sys.exit() reports success).
        sys.exit(1)

    if args.gpu:  # use_gpu
        net = net.cuda()

    return net
def get_network(args, use_gpu=True):
    """Build the network named by ``args.net`` and optionally move it to the GPU.

    Model modules are imported lazily, so only the selected architecture's
    module is loaded. Every factory is called with no arguments.

    Args:
        args: options namespace; only ``net`` (the network name) is read.
        use_gpu: when true, the model is moved with ``.cuda()``.

    Returns:
        The constructed network.

    An unsupported name prints a message and terminates via ``sys.exit(1)``.
    """
    import importlib

    # network name -> (module path, factory name).
    factories = {
        'vgg16': ('models.vgg', 'vgg16_bn'),
        'vgg13': ('models.vgg', 'vgg13_bn'),
        'vgg11': ('models.vgg', 'vgg11_bn'),
        'vgg19': ('models.vgg', 'vgg19_bn'),
        'densenet121': ('models.densenet', 'densenet121'),
        'densenet161': ('models.densenet', 'densenet161'),
        'densenet169': ('models.densenet', 'densenet169'),
        'densenet201': ('models.densenet', 'densenet201'),
        'googlenet': ('models.googlenet', 'googlenet'),
        'inceptionv3': ('models.inceptionv3', 'inceptionv3'),
        'inceptionv4': ('models.inceptionv4', 'inceptionv4'),
        'inceptionresnetv2': ('models.inceptionv4', 'inception_resnet_v2'),
        'xception': ('models.xception', 'xception'),
        'resnet18': ('models.resnet', 'resnet18'),
        'resnet34': ('models.resnet', 'resnet34'),
        'resnet50': ('models.resnet', 'resnet50'),
        'resnet101': ('models.resnet', 'resnet101'),
        'resnet152': ('models.resnet', 'resnet152'),
        'preactresnet18': ('models.preactresnet', 'preactresnet18'),
        'preactresnet34': ('models.preactresnet', 'preactresnet34'),
        'preactresnet50': ('models.preactresnet', 'preactresnet50'),
        'preactresnet101': ('models.preactresnet', 'preactresnet101'),
        'preactresnet152': ('models.preactresnet', 'preactresnet152'),
        'resnext50': ('models.resnext', 'resnext50'),
        'resnext101': ('models.resnext', 'resnext101'),
        'resnext152': ('models.resnext', 'resnext152'),
        'shufflenet': ('models.shufflenet', 'shufflenet'),
        'shufflenetv2': ('models.shufflenetv2', 'shufflenetv2'),
        'squeezenet': ('models.squeezenet', 'squeezenet'),
        'mobilenet': ('models.mobilenet', 'mobilenet'),
        'mobilenetv2': ('models.mobilenetv2', 'mobilenetv2'),
        'nasnet': ('models.nasnet', 'nasnet'),
        'attention56': ('models.attention', 'attention56'),
        'attention92': ('models.attention', 'attention92'),
        'seresnet18': ('models.senet', 'seresnet18'),
        'seresnet34': ('models.senet', 'seresnet34'),
        'seresnet50': ('models.senet', 'seresnet50'),
        'seresnet101': ('models.senet', 'seresnet101'),
        'seresnet152': ('models.senet', 'seresnet152'),
    }

    if args.net not in factories:
        print('the network name you have entered is not supported yet')
        # Non-zero status so shell scripts and job schedulers see the failure
        # (a bare sys.exit() reports success).
        sys.exit(1)

    module_name, factory_name = factories[args.net]
    net = getattr(importlib.import_module(module_name), factory_name)()

    if use_gpu:
        net = net.cuda()

    return net
Esempio n. 4
0
def generate_model(opt):
    """Build the 3D CNN selected by ``opt`` and optionally prepare it for fine-tuning.

    Args:
        opt: options namespace. Always reads ``model``, ``n_classes``,
            ``no_cuda`` and ``pretrain_path``. Other fields (``model_depth``,
            ``sample_size``, ``sample_duration``, ``version``, ``groups``,
            ``width_mult``, ``resnet_shortcut``, ``resnext_cardinality``,
            ``arch``, ``n_finetune_classes``, ``ft_portion``) are read only by
            the branches that need them.

    Returns:
        ``(model, parameters)``.
        - When CUDA is enabled (``not opt.no_cuda``), the model is moved to
          ``device`` and wrapped in ``nn.DataParallel``.
        - Without ``pretrain_path``, ``parameters`` is ``model.parameters()``.
        - With ``pretrain_path``, the checkpoint's ``state_dict`` is loaded
          and the classification head is replaced by a fresh
          ``n_finetune_classes``-way head. ``parameters`` is then
          ``get_fine_tuning_parameters(model, opt.ft_portion)`` from the
          chosen model's module.
    """
    assert opt.model in ['c3d', 'squeezenet', 'mobilenet', 'resnext', 'resnet',
                         'shufflenet', 'mobilenetv2', 'shufflenetv2']

    if opt.model == 'c3d':
        from models.c3d import get_fine_tuning_parameters
        model = c3d.get_model(
            num_classes=opt.n_classes,
            sample_size=opt.sample_size,
            sample_duration=opt.sample_duration)
    elif opt.model == 'squeezenet':
        from models.squeezenet import get_fine_tuning_parameters
        model = squeezenet.get_model(
            version=opt.version,
            num_classes=opt.n_classes,
            sample_size=opt.sample_size,
            sample_duration=opt.sample_duration)
    elif opt.model == 'shufflenet':
        from models.shufflenet import get_fine_tuning_parameters
        model = shufflenet.get_model(
            groups=opt.groups,
            width_mult=opt.width_mult,
            num_classes=opt.n_classes)
    elif opt.model == 'shufflenetv2':
        from models.shufflenetv2 import get_fine_tuning_parameters
        model = shufflenetv2.get_model(
            num_classes=opt.n_classes,
            sample_size=opt.sample_size,
            width_mult=opt.width_mult)
    elif opt.model == 'mobilenet':
        from models.mobilenet import get_fine_tuning_parameters
        model = mobilenet.get_model(
            num_classes=opt.n_classes,
            sample_size=opt.sample_size,
            width_mult=opt.width_mult)
    elif opt.model == 'mobilenetv2':
        from models.mobilenetv2 import get_fine_tuning_parameters
        model = mobilenetv2.get_model(
            num_classes=opt.n_classes,
            sample_size=opt.sample_size,
            width_mult=opt.width_mult)
    elif opt.model == 'resnext':
        assert opt.model_depth in [50, 101, 152]
        from models.resnext import get_fine_tuning_parameters
        # resnext50 / resnext101 / resnext152 are called with identical kwargs.
        model = getattr(resnext, 'resnext%d' % opt.model_depth)(
            num_classes=opt.n_classes,
            shortcut_type=opt.resnet_shortcut,
            cardinality=opt.resnext_cardinality,
            sample_size=opt.sample_size,
            sample_duration=opt.sample_duration)
    elif opt.model == 'resnet':
        assert opt.model_depth in [10, 18, 34, 50, 101, 152, 200]
        from models.resnet import get_fine_tuning_parameters
        # resnet10 ... resnet200 are called with identical kwargs.
        model = getattr(resnet, 'resnet%d' % opt.model_depth)(
            num_classes=opt.n_classes,
            shortcut_type=opt.resnet_shortcut,
            sample_size=opt.sample_size,
            sample_duration=opt.sample_duration)

    use_cuda = not opt.no_cuda
    if use_cuda:
        model = model.to(device)
        model = nn.DataParallel(model, device_ids=None)
        pytorch_total_params = sum(p.numel() for p in model.parameters() if
                                   p.requires_grad)
        print("Total number of trainable parameters: ", pytorch_total_params)

    if not opt.pretrain_path:
        return model, model.parameters()

    print('loading pretrained model {}'.format(opt.pretrain_path))
    # Load tensors onto the CPU first so a GPU-saved checkpoint also opens on a
    # CPU-only machine; load_state_dict copies them onto the model's device.
    pretrain = torch.load(opt.pretrain_path, map_location=torch.device('cpu'))
    assert opt.arch == pretrain['arch']
    state_dict = pretrain['state_dict']
    if not use_cuda:
        # A checkpoint saved from an nn.DataParallel model prefixes every key
        # with 'module.'; strip it so the keys match the unwrapped CPU model.
        state_dict = {
            (key[len('module.'):] if key.startswith('module.') else key): value
            for key, value in state_dict.items()
        }
    model.load_state_dict(state_dict)

    # nn.DataParallel keeps the real network under ``.module``; on the CPU path
    # the model is not wrapped, so the head is replaced on the model itself.
    net = model.module if use_cuda else model
    if opt.model in ['mobilenet', 'mobilenetv2', 'shufflenet', 'shufflenetv2']:
        net.classifier = nn.Sequential(
            nn.Dropout(0.9),
            nn.Linear(net.classifier[1].in_features, opt.n_finetune_classes))
    elif opt.model == 'squeezenet':
        net.classifier = nn.Sequential(
            nn.Dropout(p=0.5),
            nn.Conv3d(net.classifier[1].in_channels, opt.n_finetune_classes, kernel_size=1),
            nn.ReLU(inplace=True),
            nn.AvgPool3d((1, 4, 4), stride=1))
    else:
        net.fc = nn.Linear(net.fc.in_features, opt.n_finetune_classes)

    if use_cuda:
        # The new head is created on the CPU; put it on the same device as the
        # rest of the model (already-placed parameters are left as they are).
        net.to(device)

    # Same fine-tuning option on both the CPU and the CUDA path.
    parameters = get_fine_tuning_parameters(model, opt.ft_portion)
    return model, parameters
Esempio n. 5
0
def get_network(args, use_gpu=True):
    """Build the network named by ``args.net`` and optionally move it to the GPU.

    - Most project models are built as ``factory(args)`` with a lazy import of
      their module.
    - ``vgg19``, ``resnet18`` and ``alexnet`` load ``pretrained=True``
      torchvision models and replace the final layer with an ``args.nc``-way
      linear layer.
    - The ``lambda*`` models are built from ``args.nc`` (and ``args.cs``,
      except ``lambda101``).
    - SqueezeNext and ``mnasnet`` names are matched case-insensitively.

    Args:
        args: options namespace; reads ``net`` and, depending on the network,
            ``nc`` and ``cs``. Project model factories receive ``args`` itself.
        use_gpu: when true, the model is moved with ``.cuda()``.

    Returns:
        The constructed network.

    An unsupported name prints a message and terminates via ``sys.exit(1)``.
    """
    import importlib

    # network name -> (module path, factory name); called as factory(args).
    factories = {
        'mobilenet': ('models.mobilenet', 'mobilenet'),
        'mobilenetv2': ('models.mobilenetv2', 'mobilenetv2'),
        'vgg13': ('models.vgg', 'vgg13_bn'),
        'vgg11': ('models.vgg', 'vgg11_bn'),
        'densenet121': ('models.densenet', 'densenet121'),
        'densenet161': ('models.densenet', 'densenet161'),
        'densenet169': ('models.densenet', 'densenet169'),
        'densenet201': ('models.densenet', 'densenet201'),
        'googlenet': ('models.googlenet', 'googlenet'),
        'inceptionv3': ('models.inceptionv3', 'inceptionv3'),
        'inceptionv4': ('models.inceptionv4', 'inceptionv4'),
        'inceptionresnetv2': ('models.inceptionv4', 'inception_resnet_v2'),
        'xception': ('models.xception', 'xception'),
        'resnet34': ('models.resnet', 'resnet34'),
        'resnet50': ('models.resnet', 'resnet50'),
        'resnet101': ('models.resnet', 'resnet101'),
        'resnet152': ('models.resnet', 'resnet152'),
        'preactresnet18': ('models.preactresnet', 'preactresnet18'),
        'preactresnet34': ('models.preactresnet', 'preactresnet34'),
        'preactresnet50': ('models.preactresnet', 'preactresnet50'),
        'preactresnet101': ('models.preactresnet', 'preactresnet101'),
        'preactresnet152': ('models.preactresnet', 'preactresnet152'),
        'resnext50': ('models.resnext', 'resnext50'),
        'resnext101': ('models.resnext', 'resnext101'),
        'resnext152': ('models.resnext', 'resnext152'),
        'shufflenet': ('models.shufflenet', 'shufflenet'),
        'shufflenetv2': ('models.shufflenetv2', 'shufflenetv2'),
        'squeezenet': ('models.squeezenet', 'squeezenet'),
        'mobilenetv3': ('models.mobilenetv3', 'mobileNetv3'),
        'nasnet': ('models.nasnet', 'nasnetalarge'),
        'attention56': ('models.attention', 'attention56'),
        'attention92': ('models.attention', 'attention92'),
        'seresnet18': ('models.senet', 'seresnet18'),
        'seresnet34': ('models.senet', 'seresnet34'),
        'seresnet50': ('models.senet', 'seresnet50'),
        'seresnet101': ('models.senet', 'seresnet101'),
        'seresnet152': ('models.senet', 'seresnet152'),
        'efficientnet_b0': ('models.efficientnet', 'efficientnet_b0'),
        'efficientnet_b1': ('models.efficientnet', 'efficientnet_b1'),
        'efficientnet_b2': ('models.efficientnet', 'efficientnet_b2'),
        'efficientnet_b3': ('models.efficientnet', 'efficientnet_b3'),
        'efficientnet_b4': ('models.efficientnet', 'efficientnet_b4'),
        'efficientnet_b5': ('models.efficientnet', 'efficientnet_b5'),
        'efficientnet_b6': ('models.efficientnet', 'efficientnet_b6'),
        'efficientnet_b7': ('models.efficientnet', 'efficientnet_b7'),
        'mlp': ('models.mlp', 'MLPClassifier'),
    }
    # Same as above, but matched against ``args.net.lower()``.
    caseless_factories = {
        'sqnxt_23_1x': ('models.SqueezeNext', 'SqNxt_23_1x'),
        'sqnxt_23_1xv5': ('models.SqueezeNext', 'SqNxt_23_1x_v5'),
        'sqnxt_23_2x': ('models.SqueezeNext', 'SqNxt_23_2x'),
        'sqnxt_23_2xv5': ('models.SqueezeNext', 'SqNxt_23_2x_v5'),
        'mnasnet': ('models.nasnet_mobile', 'nasnet_Mobile'),
    }
    # Lambda ResNets that take both ``num_classes`` and ``channels``.
    lambda_factories = {
        'lambda18': 'LambdaResnet18',
        'lambda34': 'LambdaResnet34',
        'lambda50': 'LambdaResnet50',
        'lambda152': 'LambdaResnet152',
    }

    name = args.net
    if name in factories:
        module_name, factory_name = factories[name]
        net = getattr(importlib.import_module(module_name), factory_name)(args)
    elif name.lower() in caseless_factories:
        module_name, factory_name = caseless_factories[name.lower()]
        net = getattr(importlib.import_module(module_name), factory_name)(args)
    elif name == 'mobilenetv3_l':
        from models.mobilenetv3 import mobileNetv3
        net = mobileNetv3(args, mode='large')
    elif name == 'mobilenetv3_s':
        from models.mobilenetv3 import mobileNetv3
        net = mobileNetv3(args, mode='small')
    elif name == 'vgg19':
        from torchvision.models import vgg19_bn
        import torch.nn as nn
        net = vgg19_bn(pretrained=True)
        # Replace the final classifier layer with an args.nc-way one.
        net.classifier[6] = nn.Linear(4096, args.nc)
    elif name == 'resnet18':
        from torchvision.models import resnet18
        import torch.nn as nn
        net = resnet18(pretrained=True)
        net.fc = nn.Linear(512, args.nc)
    elif name == 'alexnet':
        from torchvision.models import alexnet
        import torch.nn as nn
        net = alexnet(pretrained=True)
        net.classifier[6] = nn.Linear(4096, args.nc)
    elif name in lambda_factories:
        lambda_module = importlib.import_module('models._lambda')
        factory = getattr(lambda_module, lambda_factories[name])
        net = factory(num_classes=args.nc, channels=args.cs)
    elif name == 'lambda101':
        from models._lambda import LambdaResnet101
        # NOTE(review): unlike every other LambdaResnet variant, ``channels``
        # is not passed here -- confirm whether ``args.cs`` should be.
        net = LambdaResnet101(num_classes=args.nc)
    else:
        print('the network name you have entered is not supported yet')
        # Non-zero status so shell scripts and job schedulers see the failure
        # (a bare sys.exit() reports success).
        sys.exit(1)

    if use_gpu:
        net = net.cuda()

    return net
Esempio n. 6
0
    # Each sub-directory of ``data_path`` counts as one class.
    cls = [
        x for x in os.listdir(data_path)
        if os.path.isdir(os.path.join(data_path, x))
    ]
    num_class = len(cls)
    # NOTE(review): every network below is instantiated up front and all of
    # them stay alive for the whole loop; building them lazily would lower
    # peak memory.
    models = {
        "vgg16": vgg.vgg16_bn(num_class),
        "vgg19": vgg.vgg19_bn(num_class),
        "densenet121": densenet.densenet121(num_class),
        "densenet161": densenet.densenet161(num_class),
        "resnet34": resnet.resnet34(num_class),
        "resnet50": resnet.resnet50(num_class),
        "resnet101": resnet.resnet101(num_class),
        "seresnet34": senet.seresnet34(num_class),
        "seresnet50": senet.seresnet50(num_class),
        "seresnet101": senet.seresnet101(num_class),
        "resnext34": resnext.resnext34(num_class),
        "resnext50": resnext.resnext50(num_class),
        "resnext101": resnext.resnext101(num_class),
        "shufflenet": shufflenet.shufflenet(num_class),
        "xception": xception.xception(num_class)
    }
    # ``writer`` and ``model`` are rebound on each iteration; presumably
    # ``train()`` reads them from this scope -- confirm against its definition.
    for net_name, model in models.items():
        writer = SummaryWriter('./runs/%s_%s/' % (ds, net_name))
        # Assumes loguru's ``logger``: ``add`` returns a handler id that
        # ``remove`` accepts. Removing the sink afterwards keeps later
        # networks' messages out of this network's log file.
        log_sink = logger.add('./log/%s_%s_{time}.log' % (ds, net_name), level="INFO")
        try:
            logger.info("net:%s\t dataset:%s\t num_class:%d" %
                        (net_name, ds, num_class))
            train()
        finally:
            writer.close()
            logger.remove(log_sink)
                                     model_name=args.model)
elif args.model == "efficientnet-b3":
    net = efficientnet.CEfficientNet(num_classes=num_classes,
                                     pretrained=args.pretrain,
                                     model_name=args.model)
elif args.model == "efficientnet-b4":
    net = efficientnet.CEfficientNet(num_classes=num_classes,
                                     pretrained=args.pretrain,
                                     model_name=args.model)
elif args.model == "efficientnet-b5":
    net = efficientnet.CEfficientNet(num_classes=num_classes,
                                     pretrained=args.pretrain,
                                     model_name=args.model)

elif args.model == "resnext50":
    net = resnext.resnext50(num_classes=num_classes)
elif args.model == "googlenet":
    net = googlenet.GoogLeNet(num_classes=num_classes)
elif args.model == "densenet":
    net = densenet.DenseNet_CIFAR(num_classes=num_classes)
else:
    raise "please check model"

# freeze
# count = 0
# for param in net.parameters():
#     count += 1
# for i, param in enumerate(net.parameters()):
#     if i <= count-1 - 10:
#         param.requires_grad = False
Esempio n. 8
0
def get_network(args, use_gpu=True, num_train=0):
    """Build the network named by ``args.net``.

    Args:
        args: namespace-like object. ``args.dataset`` selects the class count
            ('cifar-10' -> 10, 'cifar-100' -> 100, anything else -> 0) and
            ``args.net`` selects the architecture. When ``args.ignoring`` is
            truthy only 'resnet18' is available, built through
            ``models.resnet_ign.resnet18_ign`` with ``args.softmax`` and
            ``args.isalpha``.
        use_gpu: when true, the network is moved with ``net.cuda()``.
        num_train: forwarded to ``resnet18_ign`` as ``num_train``.

    Returns:
        The constructed network.

    If the requested name is not available in the selected mode, a message
    is printed and ``sys.exit()`` is called.
    """
    from importlib import import_module

    if args.dataset == 'cifar-10':
        num_classes = 10
    elif args.dataset == 'cifar-100':
        num_classes = 100
    else:
        # NOTE(review): unknown datasets yield 0 classes; only the resnet*
        # and resnet18_ign builders below consume this value.
        num_classes = 0

    # name -> (submodule of ``models``, factory name, pass ``num_classes``?)
    # Submodules are imported lazily, only for the requested name.
    registry = {
        'vgg16': ('vgg', 'vgg16_bn', False),
        'vgg13': ('vgg', 'vgg13_bn', False),
        'vgg11': ('vgg', 'vgg11_bn', False),
        'vgg19': ('vgg', 'vgg19_bn', False),
        'densenet121': ('densenet', 'densenet121', False),
        'densenet161': ('densenet', 'densenet161', False),
        'densenet169': ('densenet', 'densenet169', False),
        'densenet201': ('densenet', 'densenet201', False),
        'googlenet': ('googlenet', 'googlenet', False),
        'inceptionv3': ('inceptionv3', 'inceptionv3', False),
        'inceptionv4': ('inceptionv4', 'inceptionv4', False),
        'inceptionresnetv2': ('inceptionv4', 'inception_resnet_v2', False),
        'xception': ('xception', 'xception', False),
        'resnet18': ('resnet', 'resnet18', True),
        'resnet34': ('resnet', 'resnet34', True),
        'resnet50': ('resnet', 'resnet50', True),
        'resnet101': ('resnet', 'resnet101', True),
        'resnet152': ('resnet', 'resnet152', True),
        'preactresnet18': ('preactresnet', 'preactresnet18', False),
        'preactresnet34': ('preactresnet', 'preactresnet34', False),
        'preactresnet50': ('preactresnet', 'preactresnet50', False),
        'preactresnet101': ('preactresnet', 'preactresnet101', False),
        'preactresnet152': ('preactresnet', 'preactresnet152', False),
        'resnext50': ('resnext', 'resnext50', False),
        'resnext101': ('resnext', 'resnext101', False),
        'resnext152': ('resnext', 'resnext152', False),
        'shufflenet': ('shufflenet', 'shufflenet', False),
        'shufflenetv2': ('shufflenetv2', 'shufflenetv2', False),
        'squeezenet': ('squeezenet', 'squeezenet', False),
        'mobilenet': ('mobilenet', 'mobilenet', False),
        'mobilenetv2': ('mobilenetv2', 'mobilenetv2', False),
        'nasnet': ('nasnet', 'nasnet', False),
        'attention56': ('attention', 'attention56', False),
        'attention92': ('attention', 'attention92', False),
        'seresnet18': ('senet', 'seresnet18', False),
        'seresnet34': ('senet', 'seresnet34', False),
        'seresnet50': ('senet', 'seresnet50', False),
        'seresnet101': ('senet', 'seresnet101', False),
        'seresnet152': ('senet', 'seresnet152', False),
    }

    net = None
    if args.ignoring:
        if args.net == 'resnet18':
            from models.resnet_ign import resnet18_ign
            # reduction='none' yields per-sample losses; the criterion is
            # handed to resnet18_ign.
            criterion = nn.CrossEntropyLoss(reduction='none')
            net = resnet18_ign(criterion, num_classes=num_classes,
                               num_train=num_train, softmax=args.softmax,
                               isalpha=args.isalpha)
    elif args.net in registry:
        module_name, factory_name, takes_num_classes = registry[args.net]
        factory = getattr(import_module('models.' + module_name), factory_name)
        net = factory(num_classes=num_classes) if takes_num_classes else factory()

    # Covers unknown names in both modes; in "ignoring" mode any name other
    # than resnet18 must land here instead of leaving ``net`` unbound.
    if net is None:
        print('the network name you have entered is not supported yet')
        sys.exit()

    if use_gpu:
        net = net.cuda()

    return net
Esempio n. 9
0
def get_network(args):
    """Instantiate the architecture whose name is stored in ``args.model``.

    Every supported name maps to a factory living in a ``models.<module>``
    submodule. That submodule is imported only when its name is requested,
    and the factory is called without arguments.

    Args:
        args: object exposing a ``model`` attribute (the architecture name).

    Returns:
        Whatever the selected factory builds.

    An unknown name prints a message and terminates through ``sys.exit()``.
    """
    from importlib import import_module

    # architecture name -> (submodule of ``models``, factory attribute)
    factories = {
        'vgg16': ('vgg', 'vgg16_bn'),
        'vgg13': ('vgg', 'vgg13_bn'),
        'vgg11': ('vgg', 'vgg11_bn'),
        'vgg19': ('vgg', 'vgg19_bn'),
        'densenet121': ('densenet', 'densenet121'),
        'densenet161': ('densenet', 'densenet161'),
        'densenet169': ('densenet', 'densenet169'),
        'densenet201': ('densenet', 'densenet201'),
        'googlenet': ('googlenet', 'googlenet'),
        'inceptionv3': ('inceptionv3', 'inceptionv3'),
        'inceptionv4': ('inceptionv4', 'inceptionv4'),
        'inceptionresnetv2': ('inceptionv4', 'inception_resnet_v2'),
        'xception': ('xception', 'xception'),
        'resnet18': ('resnet', 'resnet18'),
        'resnet34': ('resnet', 'resnet34'),
        'resnet50': ('resnet', 'resnet50'),
        'resnet101': ('resnet', 'resnet101'),
        'resnet152': ('resnet', 'resnet152'),
        'preactresnet18': ('preactresnet', 'preactresnet18'),
        'preactresnet34': ('preactresnet', 'preactresnet34'),
        'preactresnet50': ('preactresnet', 'preactresnet50'),
        'preactresnet101': ('preactresnet', 'preactresnet101'),
        'preactresnet152': ('preactresnet', 'preactresnet152'),
        'resnext50': ('resnext', 'resnext50'),
        'resnext101': ('resnext', 'resnext101'),
        'resnext152': ('resnext', 'resnext152'),
        'shufflenet': ('shufflenet', 'shufflenet'),
        'shufflenetv2': ('shufflenetv2', 'shufflenetv2'),
        'squeezenet': ('squeezenet', 'squeezenet'),
        'mobilenet': ('mobilenet', 'mobilenet'),
        'mobilenetv2': ('mobilenetv2', 'mobilenetv2'),
        'nasnet': ('nasnet', 'nasnet'),
        'attention56': ('attention', 'attention56'),
        'attention92': ('attention', 'attention92'),
        'seresnet18': ('senet', 'seresnet18'),
        'seresnet34': ('senet', 'seresnet34'),
        'seresnet50': ('senet', 'seresnet50'),
        'seresnet101': ('senet', 'seresnet101'),
        'seresnet152': ('senet', 'seresnet152'),
        'wideresnet': ('wideresidual', 'wideresnet'),
        'stochasticdepth18': ('stochasticdepth', 'stochastic_depth_resnet18'),
        'stochasticdepth34': ('stochasticdepth', 'stochastic_depth_resnet34'),
        'stochasticdepth50': ('stochasticdepth', 'stochastic_depth_resnet50'),
        'stochasticdepth101': ('stochasticdepth', 'stochastic_depth_resnet101'),
    }

    choice = factories.get(args.model)
    if choice is None:
        print('the network name you have entered is not supported yet')
        sys.exit()

    module_name, factory_name = choice
    builder = getattr(import_module('models.' + module_name), factory_name)
    return builder()
Esempio n. 10
0
File: utils.py Progetto: nblt/DLDR
def get_model(args):
    """Build the model named by ``args.arch`` for the dataset ``args.datasets``.

    * 'ImageNet': ``models_imagenet.<arch>()``.
    * 'CIFAR100': ``args.arch`` is looked up in a table of ``models.*``
      factories (called without arguments; 'efficientnet' is special-cased).
      Names not in the table fall back to ``resnet.<arch>(num_classes=100)``.
    * 'CIFAR10' / 'MNIST': ``resnet.<arch>(num_classes=10)``.

    Args:
        args: object with ``datasets`` and ``arch`` attributes.

    Returns:
        The constructed model.

    Raises:
        ValueError: if ``args.datasets`` is none of the four names above.
        KeyError: if a lookup in ``models_imagenet`` / ``resnet`` fails.
    """
    from importlib import import_module

    if args.datasets == 'ImageNet':
        return models_imagenet.__dict__[args.arch]()

    if args.datasets in ('CIFAR10', 'MNIST'):
        num_class = 10
    elif args.datasets == 'CIFAR100':
        num_class = 100
    else:
        # Fail loudly; otherwise num_class would be unbound below and surface
        # as an UnboundLocalError.
        raise ValueError('unsupported dataset: %r' % (args.datasets,))

    if args.datasets != 'CIFAR100':
        return resnet.__dict__[args.arch](num_classes=num_class)

    if args.arch == 'efficientnet':
        from models.efficientnet import efficientnet
        # Third positional argument is the class count (100 on this path).
        return efficientnet(1, 1, num_class, bn_momentum=0.9)

    # architecture name -> (submodule of ``models``, factory attribute);
    # submodules are imported lazily, only for the requested name.
    registry = {
        'vgg16': ('vgg', 'vgg16_bn'),
        'vgg13': ('vgg', 'vgg13_bn'),
        'vgg11': ('vgg', 'vgg11_bn'),
        'vgg19': ('vgg', 'vgg19_bn'),
        'densenet121': ('densenet', 'densenet121'),
        'densenet161': ('densenet', 'densenet161'),
        'densenet169': ('densenet', 'densenet169'),
        'densenet201': ('densenet', 'densenet201'),
        'googlenet': ('googlenet', 'googlenet'),
        'inceptionv3': ('inceptionv3', 'inceptionv3'),
        'inceptionv4': ('inceptionv4', 'inceptionv4'),
        'inceptionresnetv2': ('inceptionv4', 'inception_resnet_v2'),
        'xception': ('xception', 'xception'),
        'resnet18': ('resnet', 'resnet18'),
        'resnet34': ('resnet', 'resnet34'),
        'resnet50': ('resnet', 'resnet50'),
        'resnet101': ('resnet', 'resnet101'),
        'resnet152': ('resnet', 'resnet152'),
        'preactresnet18': ('preactresnet', 'preactresnet18'),
        'preactresnet34': ('preactresnet', 'preactresnet34'),
        'preactresnet50': ('preactresnet', 'preactresnet50'),
        'preactresnet101': ('preactresnet', 'preactresnet101'),
        'preactresnet152': ('preactresnet', 'preactresnet152'),
        'resnext50': ('resnext', 'resnext50'),
        'resnext101': ('resnext', 'resnext101'),
        'resnext152': ('resnext', 'resnext152'),
        'shufflenet': ('shufflenet', 'shufflenet'),
        'shufflenetv2': ('shufflenetv2', 'shufflenetv2'),
        'squeezenet': ('squeezenet', 'squeezenet'),
        'mobilenet': ('mobilenet', 'mobilenet'),
        'mobilenetv2': ('mobilenetv2', 'mobilenetv2'),
        'nasnet': ('nasnet', 'nasnet'),
        'attention56': ('attention', 'attention56'),
        'attention92': ('attention', 'attention92'),
        'seresnet18': ('senet', 'seresnet18'),
        'seresnet34': ('senet', 'seresnet34'),
        'seresnet50': ('senet', 'seresnet50'),
        'seresnet101': ('senet', 'seresnet101'),
        'seresnet152': ('senet', 'seresnet152'),
        'wideresnet': ('wideresidual', 'wideresnet'),
        'stochasticdepth18': ('stochasticdepth', 'stochastic_depth_resnet18'),
        'stochasticdepth34': ('stochasticdepth', 'stochastic_depth_resnet34'),
        'stochasticdepth50': ('stochasticdepth', 'stochastic_depth_resnet50'),
        'stochasticdepth101': ('stochasticdepth', 'stochastic_depth_resnet101'),
    }

    entry = registry.get(args.arch)
    if entry is None:
        return resnet.__dict__[args.arch](num_classes=num_class)
    module_name, factory_name = entry
    return getattr(import_module('models.' + module_name), factory_name)()
Esempio n. 11
0
def get_network(args):
    """Instantiate the architecture named by ``args.net``.

    Most names map to a zero-argument factory in a ``models.<module>``
    submodule, imported only when requested. The special name 'eff' loads
    ``EfficientNet.from_pretrained('efficientnet-b7', num_classes=2)`` from
    ``models.efficientnet_pytorch``.

    Args:
        args: object with ``net`` (architecture name) and ``gpu`` (when
            truthy, the model is moved with ``.cuda()`` and "use-gpu" is
            printed).

    Returns:
        The constructed network.

    An unknown name prints a message and terminates through ``sys.exit()``.
    """
    from importlib import import_module

    # architecture name -> (submodule of ``models``, factory attribute)
    catalogue = {
        'vgg16': ('vgg', 'vgg16_bn'),
        'vgg13': ('vgg', 'vgg13_bn'),
        'vgg11': ('vgg', 'vgg11_bn'),
        'vgg19': ('vgg', 'vgg19_bn'),
        # 'efficientnet': ('effnetv2', 'effnetv2_s'),  # disabled
        'densenet121': ('densenet', 'densenet121'),
        'densenet161': ('densenet', 'densenet161'),
        'densenet169': ('densenet', 'densenet169'),
        'densenet201': ('densenet', 'densenet201'),
        'googlenet': ('googlenet', 'googlenet'),
        'inceptionv3': ('inceptionv3', 'inceptionv3'),
        'inceptionv4': ('inceptionv4', 'inceptionv4'),
        'inceptionresnetv2': ('inceptionv4', 'inception_resnet_v2'),
        'xception': ('xception', 'xception'),
        'resnet18': ('resnet', 'resnet18'),
        'resnet34': ('resnet', 'resnet34'),
        'resnet50': ('resnet', 'resnet50'),
        'resnet101': ('resnet', 'resnet101'),
        'resnet152': ('resnet', 'resnet152'),
        'preactresnet18': ('preactresnet', 'preactresnet18'),
        'preactresnet34': ('preactresnet', 'preactresnet34'),
        'preactresnet50': ('preactresnet', 'preactresnet50'),
        'preactresnet101': ('preactresnet', 'preactresnet101'),
        'preactresnet152': ('preactresnet', 'preactresnet152'),
        'resnext50': ('resnext', 'resnext50'),
        'resnext101': ('resnext', 'resnext101'),
        'resnext152': ('resnext', 'resnext152'),
        'shufflenet': ('shufflenet', 'shufflenet'),
        'shufflenetv2': ('shufflenetv2', 'shufflenetv2'),
        'squeezenet': ('squeezenet', 'squeezenet'),
        'mobilenet': ('mobilenet', 'mobilenet'),
        'mobilenetv2': ('mobilenetv2', 'mobilenetv2'),
        'nasnet': ('nasnet', 'nasnet'),
        'attention56': ('attention', 'attention56'),
        'attention92': ('attention', 'attention92'),
        'seresnet18': ('senet', 'seresnet18'),
        'seresnet34': ('senet', 'seresnet34'),
        'seresnet50': ('senet', 'seresnet50'),
        'seresnet101': ('senet', 'seresnet101'),
        'seresnet152': ('senet', 'seresnet152'),
        'wideresnet': ('wideresidual', 'wideresnet'),
        'stochasticdepth18': ('stochasticdepth', 'stochastic_depth_resnet18'),
        'stochasticdepth34': ('stochasticdepth', 'stochastic_depth_resnet34'),
        'stochasticdepth50': ('stochasticdepth', 'stochastic_depth_resnet50'),
        'stochasticdepth101': ('stochasticdepth', 'stochastic_depth_resnet101'),
        'efficientnetb0': ('efficientnet', 'efficientnetb0'),
        'efficientnetb1': ('efficientnet', 'efficientnetb1'),
        'efficientnetb2': ('efficientnet', 'efficientnetb2'),
        'efficientnetb3': ('efficientnet', 'efficientnetb3'),
        'efficientnetb4': ('efficientnet', 'efficientnetb4'),
        'efficientnetb5': ('efficientnet', 'efficientnetb5'),
        'efficientnetb6': ('efficientnet', 'efficientnetb6'),
        'efficientnetb7': ('efficientnet', 'efficientnetb7'),
        'efficientnetl2': ('efficientnet', 'efficientnetl2'),
    }

    wanted = args.net
    if wanted == 'eff':
        from models.efficientnet_pytorch import EfficientNet
        net = EfficientNet.from_pretrained('efficientnet-b7', num_classes=2)
    elif wanted in catalogue:
        module_name, factory_name = catalogue[wanted]
        net = getattr(import_module('models.' + module_name), factory_name)()
    else:
        print('the network name you have entered is not supported yet')
        sys.exit()

    if args.gpu:  # use_gpu
        net = net.cuda()
        print("use-gpu")

    return net
Esempio n. 12
0
def get_network(args):
    """Build the network named by ``args.net`` for the task ``args.task``.

    Args:
        args: object with ``task`` ('cifar10' -> 10 classes, 'cifar100' ->
            100 classes), ``net`` (architecture name) and ``gpu`` (move the
            network with ``.cuda()`` when truthy).

    Returns:
        The constructed network.

    Raises:
        ValueError: if the selected architecture takes ``num_classes`` and
            ``args.task`` is neither 'cifar10' nor 'cifar100'.

    An unknown architecture name prints a message and calls ``sys.exit()``.
    """
    from importlib import import_module

    nclass = {'cifar10': 10, 'cifar100': 100}.get(args.task)

    # name -> (submodule of ``models``, factory name, pass num_classes?)
    # Plain VGGs (no batch norm) and their *bn variants are both offered.
    # NOTE(review): entries flagged False are built without num_classes, so
    # their class count is whatever the factory defaults to; confirm this
    # matches args.task (in particular for cifar10).
    registry = {
        'vgg16': ('vgg', 'vgg16', True),
        'vgg13': ('vgg', 'vgg13', True),
        'vgg11': ('vgg', 'vgg11', True),
        'vgg19': ('vgg', 'vgg19', True),
        'vgg16bn': ('vgg', 'vgg16_bn', True),
        'vgg13bn': ('vgg', 'vgg13_bn', True),
        'vgg11bn': ('vgg', 'vgg11_bn', True),
        'vgg19bn': ('vgg', 'vgg19_bn', True),
        'densenet121': ('densenet', 'densenet121', False),
        'densenet161': ('densenet', 'densenet161', False),
        'densenet169': ('densenet', 'densenet169', False),
        'densenet201': ('densenet', 'densenet201', False),
        'googlenet': ('googlenet', 'googlenet', True),
        'inceptionv3': ('inceptionv3', 'inceptionv3', False),
        'inceptionv4': ('inceptionv4', 'inceptionv4', False),
        'inceptionresnetv2': ('inceptionv4', 'inception_resnet_v2', False),
        'xception': ('xception', 'xception', True),
        'scnet': ('sphereconvnet', 'sphereconvnet', True),
        'sphereresnet18': ('sphereconvnet', 'resnet18', True),
        'sphereresnet32': ('sphereconvnet', 'sphereresnet32', True),
        'plainresnet32': ('sphereconvnet', 'plainresnet32', True),
        'ynet18': ('ynet', 'resnet18', True),
        'ynet34': ('ynet', 'resnet34', True),
        'ynet50': ('ynet', 'resnet50', True),
        'ynet101': ('ynet', 'resnet101', True),
        'ynet152': ('ynet', 'resnet152', True),
        'resnet18': ('resnet', 'resnet18', True),
        'resnet34': ('resnet', 'resnet34', True),
        'resnet50': ('resnet', 'resnet50', True),
        'resnet101': ('resnet', 'resnet101', True),
        'resnet152': ('resnet', 'resnet152', True),
        'preactresnet18': ('preactresnet', 'preactresnet18', True),
        'preactresnet34': ('preactresnet', 'preactresnet34', True),
        'preactresnet50': ('preactresnet', 'preactresnet50', True),
        'preactresnet101': ('preactresnet', 'preactresnet101', True),
        'preactresnet152': ('preactresnet', 'preactresnet152', True),
        'resnext50': ('resnext', 'resnext50', True),
        'resnext101': ('resnext', 'resnext101', True),
        'resnext152': ('resnext', 'resnext152', True),
        'shufflenet': ('shufflenet', 'shufflenet', False),
        'shufflenetv2': ('shufflenetv2', 'shufflenetv2', False),
        'squeezenet': ('squeezenet', 'squeezenet', False),
        'mobilenet': ('mobilenet', 'mobilenet', True),
        'mobilenetv2': ('mobilenetv2', 'mobilenetv2', True),
        'nasnet': ('nasnet', 'nasnet', True),
        'attention56': ('attention', 'attention56', False),
        'attention92': ('attention', 'attention92', False),
        'seresnet18': ('senet', 'seresnet18', True),
        'seresnet34': ('senet', 'seresnet34', True),
        'seresnet50': ('senet', 'seresnet50', True),
        'seresnet101': ('senet', 'seresnet101', True),
        'seresnet152': ('senet', 'seresnet152', True),
    }

    entry = registry.get(args.net)
    if entry is None:
        print('the network name you have entered is not supported yet')
        sys.exit()

    module_name, factory_name, takes_num_classes = entry
    if takes_num_classes and nclass is None:
        # An unknown task gives no class count; fail clearly instead of
        # hitting an unbound variable.
        raise ValueError('unsupported task %r for network %r'
                         % (args.task, args.net))

    factory = getattr(import_module('models.' + module_name), factory_name)
    net = factory(num_classes=nclass) if takes_num_classes else factory()

    if args.gpu:  # use_gpu
        net = net.cuda()

    return net
Esempio n. 13
0
def generate_model(opt):
    """Build the network selected by ``opt.model`` / ``opt.model_depth``.

    Only ``'resnetl'`` (depth 10) and ``'resnext'`` (depth 50/101/152) have a
    construction branch here. Constructors are looked up on the module-level
    names ``resnetl`` / ``resnext`` by the pattern ``<prefix><depth>``.

    On the CUDA path (``not opt.no_cuda``) the model is moved to the GPU and the
    trainable-parameter count is printed. Then one of two things happens:

    * If ``opt.pretrain_path`` is set, the kernels are adapted to
      ``opt.pretrain_modality`` and the checkpoint's ``'state_dict'`` is loaded.
      The classifier / ``fc`` head is then replaced with one sized for
      ``opt.n_finetune_classes``, and the kernels are adapted to
      ``opt.modality``.
    * Otherwise the kernels are only adapted to ``opt.modality``.

    Returns:
        ``(model, parameters)`` where ``parameters`` comes from
        ``get_fine_tuning_parameters(model, opt.ft_portion)``.

    Raises:
        AssertionError: unknown model name or unsupported depth.
        NotImplementedError: the model name passes the assert but has no
            construction branch in this function.
    """
    assert opt.model in ['c3d', 'squeezenet', 'mobilenet', 'resnext', 'resnet', 'resnetl',
                         'shufflenet', 'mobilenetv2', 'shufflenetv2']

    if opt.model == 'resnetl':
        assert opt.model_depth in [10]  # only depth 10 is supported

        from models.resnetl import get_fine_tuning_parameters  # used for transfer learning

        model = resnetl.resnetl10(
            num_classes=opt.n_classes,          # number of classes
            shortcut_type=opt.resnet_shortcut,  # original note: default 'B'
            sample_size=opt.sample_size,        # original note: default 112
            sample_duration=opt.sample_duration)  # original note: default 16 input frames
    elif opt.model == 'resnext':
        assert opt.model_depth in [50, 101, 152]
        from models.resnext import get_fine_tuning_parameters

        # resnext50 / resnext101 / resnext152 share the same signature.
        build = getattr(resnext, 'resnext%d' % opt.model_depth)
        model = build(
            num_classes=opt.n_classes,
            shortcut_type=opt.resnet_shortcut,
            cardinality=opt.resnext_cardinality,
            sample_size=opt.sample_size,
            sample_duration=opt.sample_duration)
    else:
        # Previously these names fell through and failed later on an unbound
        # `model`; fail early with an explicit message instead.
        raise NotImplementedError(
            'generate_model: no construction branch for model %r' % opt.model)

    if not opt.no_cuda:
        if opt.gpus != '0':
            # Any value other than '0' switches to the process's local rank.
            opt.gpus = opt.local_rank
            torch.cuda.set_device(opt.gpus)
        model = model.cuda()

        # Count parameters that will receive gradients.
        pytorch_total_params = sum(p.numel() for p in model.parameters()
                                   if p.requires_grad)
        print("Total number of trainable parameters: ", pytorch_total_params)

        if opt.pretrain_path:  # transfer learning
            print('loading pretrained model {}'.format(opt.pretrain_path))
            pretrain = torch.load(opt.pretrain_path, map_location=torch.device('cpu'))
            # Match the first-layer kernels to the checkpoint's modality before loading.
            model = modify_kernels(opt, model, opt.pretrain_modality)
            model.load_state_dict(pretrain['state_dict'])

            # The model is not wrapped in nn.DataParallel here, so `.module` may
            # not exist. Fall back to the model itself when it is absent.
            net = getattr(model, 'module', model)

            if opt.model in ['mobilenet', 'mobilenetv2', 'shufflenet', 'shufflenetv2']:
                net.classifier = nn.Sequential(
                    nn.Dropout(0.5),
                    nn.Linear(net.classifier[1].in_features, opt.n_finetune_classes))
                net.classifier = net.classifier.cuda()
            elif opt.model == 'squeezenet':
                net.classifier = nn.Sequential(
                    nn.Dropout(p=0.5),
                    nn.Conv3d(net.classifier[1].in_channels, opt.n_finetune_classes, kernel_size=1),
                    nn.ReLU(inplace=True),
                    nn.AvgPool3d((1, 4, 4), stride=1))
                net.classifier = net.classifier.cuda()
            else:
                net.fc = nn.Linear(net.fc.in_features, opt.n_finetune_classes)
                net.fc = net.fc.cuda()

            model = modify_kernels(opt, model, opt.modality)
        else:  # not transfer learning
            model = modify_kernels(opt, model, opt.modality)
    # NOTE(review): on the no_cuda path neither pretrained loading nor
    # modify_kernels is applied (unchanged from the original) — confirm intent.

    # Computed on both paths; previously `parameters` was unbound when no_cuda
    # was set. Per the original comment, this only matters for transfer
    # learning and otherwise returns the parameters unchanged.
    parameters = get_fine_tuning_parameters(model, opt.ft_portion)
    return model, parameters
Esempio n. 14
0
def generate_model(opt):
    """Build the 3D network selected by ``opt.model`` and ``opt.model_depth``.

    Constructors are looked up by the pattern ``<prefix><depth>`` on these
    module-level names:

    * ``x_channel_resnet`` (``xcresnet<depth>``)
    * ``resnet`` (``resnet<depth>``)
    * ``resnext`` (``resnext<depth>``)
    * ``i6f_resnet`` (``i6f_resnet<depth>``)

    If ``opt.pretrain_path`` is set, the following steps run:

    1. The checkpoint is loaded and ``opt.arch == pretrain['arch']`` is asserted.
    2. Its ``'state_dict'`` is loaded into the model.
    3. ``model.fc`` is replaced with a new ``nn.Linear`` producing
       ``opt.n_finetune_classes`` outputs.
    4. The parameter groups come from
       ``get_fine_tuning_parameters(model, opt.ft_begin_index)``.

    Returns:
        ``(model, parameters)``. The four cases are:

        * CUDA + pretrain: the model is on the GPU and wrapped in
          ``nn.DataParallel``.
        * CPU + pretrain: the model is returned unwrapped.
        * No pretrain path, either device: the model is wrapped in
          ``nn.DataParallel`` and ``model.parameters()`` is returned.

    Raises:
        AssertionError: unknown model name, unsupported depth, or a checkpoint
            whose ``'arch'`` does not match ``opt.arch``.
    """
    assert opt.model in ['xcresnet', 'resnet', 'resnext', 'i6f_resnet']

    if opt.model == 'xcresnet':
        assert opt.model_depth in [10, 18, 34, 50, 101, 152]
        from models.x_channel_resnet import get_fine_tuning_parameters

        build = getattr(x_channel_resnet, 'xcresnet%d' % opt.model_depth)
        model = build(num_classes=opt.n_classes,
                      image_nums=opt.sample_duration)
    elif opt.model == 'resnet':
        assert opt.model_depth in [10, 18, 34, 50, 101, 152, 200]

        from models.resnet import get_fine_tuning_parameters

        build = getattr(resnet, 'resnet%d' % opt.model_depth)
        model = build(num_classes=opt.n_classes,
                      shortcut_type=opt.resnet_shortcut,
                      sample_size=opt.sample_size,
                      sample_duration=opt.sample_duration)
    elif opt.model == 'resnext':
        assert opt.model_depth in [50, 101, 152]

        from models.resnext import get_fine_tuning_parameters

        build = getattr(resnext, 'resnext%d' % opt.model_depth)
        model = build(num_classes=opt.n_classes,
                      shortcut_type=opt.resnet_shortcut,
                      cardinality=opt.resnext_cardinality,
                      sample_size=opt.sample_size,
                      sample_duration=opt.sample_duration)
    elif opt.model == 'i6f_resnet':
        assert opt.model_depth in [10, 18, 34, 50, 101, 152, 200]

        # NOTE(review): no get_fine_tuning_parameters is bound for this model,
        # so the pretrain branches below raise UnboundLocalError for
        # 'i6f_resnet'. Original import was left disabled:
        #from models.resnet import get_fine_tuning_parameters

        build = getattr(i6f_resnet, 'i6f_resnet%d' % opt.model_depth)
        model = build(num_classes=opt.n_classes,
                      shortcut_type=opt.resnet_shortcut,
                      sample_size=opt.sample_size,
                      sample_duration=opt.sample_duration)

    if not opt.no_cuda:
        model = model.cuda()
        #model = nn.DataParallel(model,device_ids=None)

        if opt.pretrain_path:
            print('loading pretrained model {}'.format(opt.pretrain_path))
            pretrain = torch.load(opt.pretrain_path)
            assert opt.arch == pretrain['arch']

            model.load_state_dict(pretrain['state_dict'])

            model.fc = nn.Linear(model.fc.in_features, opt.n_finetune_classes)
            model.fc = model.fc.cuda()

            # Parameter groups are taken from the unwrapped model, before the
            # DataParallel wrap.
            parameters = get_fine_tuning_parameters(model, opt.ft_begin_index)
            model = nn.DataParallel(model, device_ids=None)
            return model, parameters
    else:
        if opt.pretrain_path:
            print('loading pretrained model {}'.format(opt.pretrain_path))
            # Map tensors to CPU so a checkpoint saved from GPU can be loaded
            # on a CPU-only host.
            pretrain = torch.load(opt.pretrain_path,
                                  map_location=torch.device('cpu'))
            assert opt.arch == pretrain['arch']

            model.load_state_dict(pretrain['state_dict'])

            # The model is not DataParallel-wrapped on this path, so the head
            # lives directly on `model` (there is no `.module`).
            model.fc = nn.Linear(model.fc.in_features,
                                 opt.n_finetune_classes)

            parameters = get_fine_tuning_parameters(model, opt.ft_begin_index)
            return model, parameters
    model = nn.DataParallel(model, device_ids=None)
    return model, model.parameters()