Esempio n. 1
0
def get_student_model(opt):
    """Create a ``(student, student_key)`` pair for ``opt.student_arch``.

    Both networks are independent instances of the same architecture.
    For 'alexnet', 'mobilenet' and 'resnet18' each instance has its ``fc``
    attribute replaced by an empty ``nn.Sequential()``. 'resnet50' instances
    are built with ``fc_dim=8192`` and otherwise left unchanged.

    Args:
        opt: options object; only ``opt.student_arch`` is read.

    Returns:
        A 2-tuple ``(student, student_key)``, or ``(None, None)`` when the
        architecture name is not recognised.
    """
    arch = opt.student_arch

    if arch == 'alexnet':
        factory = alexnet
    elif arch == 'mobilenet':
        factory = mobilenet
    elif arch == 'resnet18':
        factory = resnet18
    elif arch == 'resnet50':
        return resnet50(fc_dim=8192), resnet50(fc_dim=8192)
    else:
        return None, None

    def build_headless():
        # An empty nn.Sequential forwards its input unchanged, so the
        # network outputs the features that would have fed ``fc``.
        net = factory()
        net.fc = nn.Sequential()
        return net

    student = build_headless()
    student_key = build_headless()
    return student, student_key
Esempio n. 2
0
 def testCreationNoClasses(self):
     """With ``num_classes=None`` the network output is the global-pool endpoint."""
     conv_defs = copy.deepcopy(mobilenet_v2.V2_DEF)
     images = tf.placeholder(tf.float32, (10, 224, 224, 16))
     output, endpoints = mobilenet.mobilenet(
         images, conv_defs=conv_defs, num_classes=None)
     self.assertIs(output, endpoints['global_pool'])
Esempio n. 3
0
 def testWithSplits(self):
   """Splitting every expanded conv's expansion adds extra Conv2D ops."""
   conv_defs = copy.deepcopy(mobilenet_v2.V2_DEF)
   conv_defs['overrides'] = {(ops.expanded_conv,): dict(split_expansion=2)}
   images = tf.placeholder(tf.float32, (10, 224, 224, 16))
   _net, _endpoints = mobilenet.mobilenet(images, conv_defs=conv_defs)
   conv_count = len(find_ops('Conv2D'))
   # All but 3 ops have 3 conv operators, the remaining 3 have one, and
   # there is one conv not accounted for by the spec:
   #   (n - 3) * 3 + 3 * 1 + 1 == 3 * n - 5
   expected = len(conv_defs['spec']) * 3 - 5
   self.assertEqual(conv_count, expected)
Esempio n. 4
0
  def testCreation(self):
    """Sanity-check the conv count and depthwise endpoints of the V2 net."""
    conv_defs = dict(mobilenet_v2.V2_DEF)
    images = tf.placeholder(tf.float32, (10, 224, 224, 16))
    _, endpoints = mobilenet.mobilenet(images, conv_defs=conv_defs)
    conv_count = len(find_ops('Conv2D'))

    # Mostly a sanity test; the constants have no deep meaning. All layers
    # but the first 2 and the last one have two convolutions, and there is
    # one extra conv that is not in the spec (logits):
    #   (n - 3) * 2 + 3 * 1 + 1 == 2 * n - 2
    self.assertEqual(conv_count, len(conv_defs['spec']) * 2 - 2)

    # The depthwise output of layers 2..16 must be exposed as an endpoint.
    expected_keys = ['layer_%d/depthwise_output' % idx for idx in range(2, 17)]
    for key in expected_keys:
      self.assertIn(key, endpoints)
Esempio n. 5
0
def getModel(model):
    """Build a CIFAR-100 network by name and move it to the GPU.

    Args:
        model: 'resnet18' or 'mobilenet'.

    Returns:
        The constructed network after ``.cuda()``.

    Any other name prints a message and terminates via ``sys.exit()``.
    """
    if model not in ('resnet18', 'mobilenet'):
        print('check the name of the model')
        sys.exit()

    if model == 'mobilenet':
        from models.mobilenet import mobilenet
        return mobilenet(num_classes=100, cifar=True).cuda()

    from models.resnet import resnet
    network = resnet(
        groups=[1, 1, 1, 1],
        depth=18,
        width=[64, 128, 256, 512],
        dataset="cifar100",
    )
    return network.cuda()
Esempio n. 6
0
def get_model(args):

    model = None
    if args.arch == 'alexnet':
        model = alexnet()
        model.fc = nn.Sequential()
        model = torch.nn.DataParallel(model).cuda()
        checkpoint = torch.load(args.weights)
        if 'model' in checkpoint:
            sd = checkpoint['model']
        else:
            sd = checkpoint['state_dict']
        sd = {k.replace('module.', ''): v for k, v in sd.items()}
        sd = {k: v for k, v in sd.items() if 'fc' not in k}
        sd = {k: v for k, v in sd.items() if 'encoder_k' not in k}
        sd = {k.replace('encoder_q.', ''): v for k, v in sd.items()}
        sd = {('module.' + k): v for k, v in sd.items()}
        msg = model.load_state_dict(sd, strict=False)
        print(model)
        print(msg)

    elif args.arch == 'pt_alexnet':
        model = models.alexnet(num_classes=16000)
        checkpoint = torch.load(args.weights)
        sd = checkpoint['state_dict']
        sd = {k.replace('module.', ''): v for k, v in sd.items()}
        msg = model.load_state_dict(sd, strict=True)
        classif = list(model.classifier.children())[:5]
        model.classifier = nn.Sequential(*classif)
        model = torch.nn.DataParallel(model).cuda()
        print(model)
        print(msg)

    elif args.arch == 'resnet18':
        model = resnet18()
        model.fc = nn.Sequential()
        model = torch.nn.DataParallel(model).cuda()
        checkpoint = torch.load(args.weights)
        if 'model' in checkpoint:
            sd = checkpoint['model']
        else:
            sd = checkpoint['state_dict']
        sd = {k.replace('module.', ''): v for k, v in sd.items()}
        sd = {k: v for k, v in sd.items() if 'fc' not in k}
        sd = {k: v for k, v in sd.items() if 'encoder_k' not in k}
        sd = {k.replace('encoder_q.', ''): v for k, v in sd.items()}
        sd = {('module.' + k): v for k, v in sd.items()}
        msg = model.load_state_dict(sd, strict=False)
        print(model)
        print(msg)

    elif args.arch == 'one_resnet50':
        model = resnet50()
        model.fc = nn.Sequential()
        checkpoint = torch.load(args.weights)
        if 'model' in checkpoint:
            sd = checkpoint['model']
        else:
            sd = checkpoint['state_dict']
        sd = {k.replace('module.', ''): v for k, v in sd.items()}
        sd = {k: v for k, v in sd.items() if 'projection' not in k}
        sd = {k: v for k, v in sd.items() if 'prediction' not in k}
        sd = {k: v for k, v in sd.items() if 'pred_' not in k}
        sd = {k: v for k, v in sd.items() if 'encoder_two' not in k}
        sd = {k.replace('encoder_one.', ''): v for k, v in sd.items()}
        sd = {k.replace('backbone.', ''): v for k, v in sd.items()}
        model.load_state_dict(sd, strict=True)
        model = torch.nn.DataParallel(model).cuda()

    elif args.arch == 'two_resnet50':
        model = resnet50()
        model.fc = nn.Sequential()
        checkpoint = torch.load(args.weights)
        if 'model' in checkpoint:
            sd = checkpoint['model']
        else:
            sd = checkpoint['state_dict']
        sd = {k.replace('module.', ''): v for k, v in sd.items()}
        sd = {k: v for k, v in sd.items() if 'projection' not in k}
        sd = {k: v for k, v in sd.items() if 'prediction' not in k}
        sd = {k: v for k, v in sd.items() if 'pred_' not in k}
        sd = {k: v for k, v in sd.items() if 'encoder_one' not in k}
        sd = {k.replace('encoder_two.', ''): v for k, v in sd.items()}
        sd = {k.replace('backbone.', ''): v for k, v in sd.items()}
        model.load_state_dict(sd, strict=True)
        model = torch.nn.DataParallel(model).cuda()

    elif args.arch == 'l_resnet18':
        model = l_resnet18()
        model.fc = nn.Sequential()
        model = torch.nn.DataParallel(model).cuda()
        checkpoint = torch.load(args.weights)
        if 'model' in checkpoint:
            sd = checkpoint['model']
        else:
            sd = checkpoint['state_dict']
        sd = {k.replace('module.', ''): v for k, v in sd.items()}
        sd = {k: v for k, v in sd.items() if 'fc' not in k}
        sd = {k: v for k, v in sd.items() if 'encoder_k' not in k}
        sd = {k.replace('encoder_q.', ''): v for k, v in sd.items()}
        sd = {('module.' + k): v for k, v in sd.items()}
        msg = model.load_state_dict(sd, strict=False)
        print(model)
        print(msg)

    elif 'teacher_' in args.arch:
        if 'resnet18' in args.arch:
            model = resnet18()
        elif 'resnet50' in args.arch:
            model = resnet50()
        model.fc = nn.Sequential()
        model = torch.nn.DataParallel(model).cuda()
        checkpoint = torch.load(args.weights)
        if 'model' in checkpoint:
            sd = checkpoint['model']
        else:
            sd = checkpoint['state_dict']

        sd = {k.replace('module.', ''): v for k, v in sd.items()}
        sd = {k: v for k, v in sd.items() if 'fc' not in k}
        sd = {k: v for k, v in sd.items() if 'predict_q' not in k}
        sd = {k: v for k, v in sd.items() if 'queue' not in k}
        new_sd = {}
        for key in sd.keys():
            if 'encoder_k' in key and 'running_' not in key:
                new_sd['module.' + key.replace('encoder_k.', '')] = sd[key]
            if 'encoder_q' in key and 'running_' in key:
                new_sd['module.' + key.replace('encoder_q.', '')] = sd[key]
        msg = model.load_state_dict(new_sd, strict=True)
        print(model)
        print(msg)

    elif args.arch == 'mobilenet':
        model = mobilenet()
        model.fc = nn.Sequential()
        model = torch.nn.DataParallel(model).cuda()
        checkpoint = torch.load(args.weights)
        msg = model.load_state_dict(checkpoint['model'], strict=False)
        print(model)
        print(msg)

    elif args.arch == 'resnet50':
        model = resnet50()
        model.fc = nn.Sequential()
        model = torch.nn.DataParallel(model).cuda()
        checkpoint = torch.load(args.weights)
        if 'model' in checkpoint:
            sd = checkpoint['model']
        else:
            sd = checkpoint['state_dict']
        sd = {k.replace('module.', ''): v for k, v in sd.items()}
        sd = {k: v for k, v in sd.items() if 'fc' not in k}
        sd = {k: v for k, v in sd.items() if 'encoder_k' not in k}
        sd = {k.replace('encoder_q.', ''): v for k, v in sd.items()}
        sd = {('module.' + k): v for k, v in sd.items()}
        msg = model.load_state_dict(sd, strict=False)
        print(model)
        print(msg)

    elif args.arch == 'byol_resnet50':
        model = byol_resnet50()
        model.fc = nn.Sequential()
        checkpoint = torch.load(args.weights)
        if 'model' in checkpoint:
            sd = checkpoint['model']
        else:
            sd = checkpoint['state_dict']
        sd = {k.replace('module.', ''): v for k, v in sd.items()}
        sd = {k: v for k, v in sd.items() if 'fc' not in k}
        sd = {k: v for k, v in sd.items() if 'encoder_k' not in k}
        sd = {k: v for k, v in sd.items() if 'predict_q' not in k}
        sd = {k: v for k, v in sd.items() if 'queue' not in k}
        sd = {k.replace('encoder_q.', ''): v for k, v in sd.items()}
        msg = model.load_state_dict(sd, strict=True)
        print(model)
        print(msg)
        model = torch.nn.DataParallel(model).cuda()

    elif args.arch == 'moco_alexnet':
        model = alexnet()
        model.fc = nn.Sequential()
        model = nn.Sequential(OrderedDict([('encoder_q', model)]))
        model = model.cuda()
        checkpoint = torch.load(args.weights)
        model.load_state_dict(checkpoint['state_dict'], strict=False)

    elif args.arch == 'moco_resnet18':
        model = resnet18().cuda()
        model = nn.Sequential(OrderedDict([('encoder_q', model)]))
        model = torch.nn.DataParallel(model).cuda()
        checkpoint = torch.load(args.weights)
        msg = model.load_state_dict(checkpoint['state_dict'], strict=False)
        print(msg)
        # model.module.encoder_q.fc = nn.Sequential()

    elif args.arch == 'moco_mobilenet':
        model = mobilenet()
        model.fc = nn.Sequential()
        model = nn.Sequential(OrderedDict([('encoder_q', model)]))
        model = torch.nn.DataParallel(model).cuda()
        checkpoint = torch.load(args.weights)
        model.load_state_dict(checkpoint['state_dict'], strict=False)

    elif args.arch == 'moco_resnet50':
        model = resnet50().cuda()
        model = nn.Sequential(OrderedDict([('encoder_q', model)]))
        model = torch.nn.DataParallel(model).cuda()
        checkpoint = torch.load(args.weights)
        model.load_state_dict(checkpoint['state_dict'], strict=False)
        model.module.encoder_q.fc = nn.Sequential()

    elif args.arch == 'resnet50w5':
        model = resnet50w5()
        model.l2norm = None
        load_weights(model, args.weights)
        model = torch.nn.DataParallel(model).cuda()

    elif args.arch == 'swav_resnet50':
        model = swav_resnet50()
        model.l2norm = None
        load_weights(model, args.weights)
        model = torch.nn.DataParallel(model).cuda()

    elif args.arch == 'sup_alexnet':
        # model = models.alexnet(pretrained=True)
        # modules = list(model.children())[:-1]
        # classifier_modules = list(model.classifier.children())[:-1]
        # modules.append(Flatten())
        # modules.append(nn.Sequential(*classifier_modules))
        # model = nn.Sequential(*modules)
        # model = model.cuda()
        ####### modified #######
        model = models.alexnet(pretrained=False)
        model.classifier = nn.Sequential()
        modules = list(model.children())
        modules.append(nn.Flatten())
        model = nn.Sequential(*modules)
        model = model.cuda()

    elif args.arch == 'sup_resnet18':
        model = models.resnet18(pretrained=True)
        model.fc = nn.Sequential()
        model = torch.nn.DataParallel(model).cuda()

    elif args.arch == 'sup_mobilenet':
        model = models.mobilenet_v2(pretrained=True)
        model.classifier = nn.Sequential()
        model = torch.nn.DataParallel(model).cuda()

    elif args.arch == 'sup_resnet50':
        model = models.resnet50(pretrained=True)
        model.fc = nn.Sequential()
        model = torch.nn.DataParallel(model).cuda()

    for param in model.parameters():
        param.requires_grad = False

    return model
Esempio n. 7
0
                                             download=False,
                                             transform=transform_test)

    trainloader = DataLoader(traindata,
                             batch_size=args.b,
                             shuffle=True,
                             num_workers=2)
    testloader = DataLoader(testdata,
                            batch_size=args.b,
                            shuffle=True,
                            num_workers=2)

    # define net
    if args.net == 'mobilenet':
        from models.mobilenet import mobilenet
        net = mobilenet(1, 100).cuda()
    elif args.net == 'mobilenetv2':
        from models.mobilenetv2 import mobilenetv2
        net = mobilenetv2(1, 100).cuda()
    elif args.net == 'shufflenet':
        from models.shufflenet import shufflenet
        net = shufflenet([4, 8, 4], 3, 1, 100).cuda()
    elif args.net == 'shufflenetv2':
        from models.shufflenetv2 import shufflenetv2
        net = shufflenetv2(100, 1).cuda()
    elif args.net == 'efficientnetb0':
        from models.efficientnet import efficientnet
        print("loading net")
        net = efficientnet(1, 1, 100, bn_momentum=0.9).cuda()
        print("loading finish")
    else:
Esempio n. 8
0
def get_model(args):
    """Build a frozen feature extractor for ``args.arch``.

    Most branches load a checkpoint from ``args.weights``; the ``sup_*``
    branches use torchvision models with ``pretrained=True`` instead. Every
    parameter of the returned model has ``requires_grad`` set to ``False``.

    Args:
        args: options object; ``args.arch`` selects the branch and
            ``args.weights`` is the checkpoint path for branches that load one.

    Returns:
        The model. Most branches wrap it in ``torch.nn.DataParallel`` and move
        it to the GPU; ``moco_alexnet`` and ``sup_alexnet`` only call
        ``.cuda()``.

    Raises:
        ValueError: if ``args.arch`` is not a supported name.
    """
    if args.arch == 'alexnet':
        model = alexnet()
        model.fc = nn.Sequential()
        model = torch.nn.DataParallel(model).cuda()
        checkpoint = torch.load(args.weights)
        msg = model.load_state_dict(checkpoint['model'], strict=False)
        print(msg)

    elif args.arch == 'pt_alexnet':
        model = models.alexnet(num_classes=16000)
        checkpoint = torch.load(args.weights)
        sd = checkpoint['state_dict']
        sd = {k.replace('module.', ''): v for k, v in sd.items()}
        msg = model.load_state_dict(sd, strict=True)
        # Keep only the first five classifier modules; with torchvision's
        # AlexNet this drops the final ReLU and the 16000-way output layer.
        classif = list(model.classifier.children())[:5]
        model.classifier = nn.Sequential(*classif)
        model = torch.nn.DataParallel(model).cuda()
        print(model)
        print(msg)

    elif args.arch == 'resnet18':
        model = resnet18()
        model.fc = nn.Sequential()
        model = torch.nn.DataParallel(model).cuda()
        checkpoint = torch.load(args.weights)
        model.load_state_dict(checkpoint['model'], strict=False)

    elif args.arch == 'mobilenet':
        model = mobilenet()
        model.fc = nn.Sequential()
        model = torch.nn.DataParallel(model).cuda()
        checkpoint = torch.load(args.weights)
        model.load_state_dict(checkpoint['model'], strict=False)

    elif args.arch == 'resnet50':
        model = resnet50()
        model.fc = nn.Sequential()
        model = torch.nn.DataParallel(model).cuda()
        checkpoint = torch.load(args.weights)
        model.load_state_dict(checkpoint['model'], strict=False)

    elif args.arch == 'moco_alexnet':
        model = alexnet()
        model.fc = nn.Sequential()
        # Wrap so the checkpoint's 'encoder_q.*' keys match without renaming.
        model = nn.Sequential(OrderedDict([('encoder_q', model)]))
        model = model.cuda()
        checkpoint = torch.load(args.weights)
        model.load_state_dict(checkpoint['state_dict'], strict=False)

    elif args.arch == 'moco_resnet18':
        model = resnet18().cuda()
        model = nn.Sequential(OrderedDict([('encoder_q', model)]))
        model = torch.nn.DataParallel(model).cuda()
        checkpoint = torch.load(args.weights)
        model.load_state_dict(checkpoint['state_dict'], strict=False)
        # The head is removed only after loading, so any fc weights in the
        # checkpoint are loaded and then discarded.
        model.module.encoder_q.fc = nn.Sequential()

    elif args.arch == 'moco_mobilenet':
        model = mobilenet()
        model.fc = nn.Sequential()
        model = nn.Sequential(OrderedDict([('encoder_q', model)]))
        model = torch.nn.DataParallel(model).cuda()
        checkpoint = torch.load(args.weights)
        model.load_state_dict(checkpoint['state_dict'], strict=False)

    elif args.arch == 'moco_resnet50':
        model = resnet50().cuda()
        model = nn.Sequential(OrderedDict([('encoder_q', model)]))
        model = torch.nn.DataParallel(model).cuda()
        checkpoint = torch.load(args.weights)
        model.load_state_dict(checkpoint['state_dict'], strict=False)
        model.module.encoder_q.fc = nn.Sequential()

    elif args.arch == 'resnet50w5':
        model = resnet50w5()
        model.l2norm = None
        load_weights(model, args.weights)
        model = torch.nn.DataParallel(model).cuda()

    elif args.arch == 'sup_alexnet':
        model = models.alexnet(pretrained=True)
        # All top-level children except the classifier, then a flatten, then
        # the classifier without its final (output) layer.
        modules = list(model.children())[:-1]
        classifier_modules = list(model.classifier.children())[:-1]
        modules.append(Flatten())
        modules.append(nn.Sequential(*classifier_modules))
        model = nn.Sequential(*modules)
        model = model.cuda()

    elif args.arch == 'sup_resnet18':
        model = models.resnet18(pretrained=True)
        model.fc = nn.Sequential()
        model = torch.nn.DataParallel(model).cuda()

    elif args.arch == 'sup_mobilenet':
        model = models.mobilenet_v2(pretrained=True)
        model.classifier = nn.Sequential()
        model = torch.nn.DataParallel(model).cuda()

    elif args.arch == 'sup_resnet50':
        model = models.resnet50(pretrained=True)
        model.fc = nn.Sequential()
        model = torch.nn.DataParallel(model).cuda()

    else:
        # Previously fell through with model=None and crashed below with an
        # unhelpful AttributeError.
        raise ValueError('unsupported architecture: %r' % (args.arch,))

    # The extractor is used frozen.
    for param in model.parameters():
        param.requires_grad = False

    return model
# Maps each supported ``args.net`` name to ``(module path, factory name)``.
# Every factory is called with no arguments.
_NETWORK_FACTORIES = {
    'vgg16': ('models.vgg', 'vgg16_bn'),
    'vgg13': ('models.vgg', 'vgg13_bn'),
    'vgg11': ('models.vgg', 'vgg11_bn'),
    'vgg19': ('models.vgg', 'vgg19_bn'),
    'densenet121': ('models.densenet', 'densenet121'),
    'densenet161': ('models.densenet', 'densenet161'),
    'densenet169': ('models.densenet', 'densenet169'),
    'densenet201': ('models.densenet', 'densenet201'),
    'googlenet': ('models.googlenet', 'googlenet'),
    'inceptionv3': ('models.inceptionv3', 'inceptionv3'),
    'inceptionv4': ('models.inceptionv4', 'inceptionv4'),
    'inceptionresnetv2': ('models.inceptionv4', 'inception_resnet_v2'),
    'xception': ('models.xception', 'xception'),
    'resnet18': ('models.resnet', 'resnet18'),
    'resnet34': ('models.resnet', 'resnet34'),
    'resnet50': ('models.resnet', 'resnet50'),
    'resnet101': ('models.resnet', 'resnet101'),
    'resnet152': ('models.resnet', 'resnet152'),
    'preactresnet18': ('models.preactresnet', 'preactresnet18'),
    'preactresnet34': ('models.preactresnet', 'preactresnet34'),
    'preactresnet50': ('models.preactresnet', 'preactresnet50'),
    'preactresnet101': ('models.preactresnet', 'preactresnet101'),
    'preactresnet152': ('models.preactresnet', 'preactresnet152'),
    'resnext50': ('models.resnext', 'resnext50'),
    'resnext101': ('models.resnext', 'resnext101'),
    'resnext152': ('models.resnext', 'resnext152'),
    'shufflenet': ('models.shufflenet', 'shufflenet'),
    'shufflenetv2': ('models.shufflenetv2', 'shufflenetv2'),
    'squeezenet': ('models.squeezenet', 'squeezenet'),
    'mobilenet': ('models.mobilenet', 'mobilenet'),
    'mobilenetv2': ('models.mobilenetv2', 'mobilenetv2'),
    'nasnet': ('models.nasnet', 'nasnet'),
    'attention56': ('models.attention', 'attention56'),
    'attention92': ('models.attention', 'attention92'),
    'seresnet18': ('models.senet', 'seresnet18'),
    'seresnet34': ('models.senet', 'seresnet34'),
    'seresnet50': ('models.senet', 'seresnet50'),
    'seresnet101': ('models.senet', 'seresnet101'),
    'seresnet152': ('models.senet', 'seresnet152'),
}


def get_network(args, use_gpu=True):
    """Build the network named by ``args.net``.

    The model module is imported lazily, so only the selected architecture's
    code is loaded.

    Args:
        args: options object; only ``args.net`` is read.
        use_gpu: when true, the network is moved to the GPU with ``.cuda()``.

    Returns:
        The constructed network.

    An unsupported name prints a message and terminates via ``sys.exit()``.
    """
    import importlib

    entry = _NETWORK_FACTORIES.get(args.net)
    if entry is None:
        print('the network name you have entered is not supported yet')
        sys.exit()

    module_name, factory_name = entry
    factory = getattr(importlib.import_module(module_name), factory_name)
    net = factory()

    if use_gpu:
        net = net.cuda()

    return net
Esempio n. 10
0
def get_network(args, use_gpu=True):
    """Build the classification network named by ``args.net``.

    Attributes read from ``args``:
        net: architecture name.
        num_classes: output size for the torchvision / EfficientNet heads and
            the project ``resnet`` / ``se_resnext`` / ``seresnet`` builders.
        bool_pretrained: pretrained flag for the torchvision and EfficientNet
            branches.
        pretrained: forwarded to the project ``resnet`` / ``se_resnext`` /
            ``seresnet`` builders.

    Args:
        use_gpu: when true, the network is moved to the GPU with ``.cuda()``.

    Returns:
        The constructed network.

    An unsupported name prints a message and terminates via ``sys.exit()``.
    """
    import importlib

    name = args.net

    # Project-local architectures built with their default constructor.
    # NOTE(review): these ignore num_classes and both pretrained flags
    # (including densenet121, unlike the torchvision densenet161/169/201).
    default_factories = {
        'densenet121': ('models.densenet', 'densenet121'),
        'googlenet': ('models.googlenet', 'googlenet'),
        'inceptionv3': ('models.inceptionv3', 'inceptionv3'),
        'inceptionv4': ('models.inceptionv4', 'inceptionv4'),
        'inceptionresnetv2': ('models.inceptionv4', 'inception_resnet_v2'),
        'xception': ('models.xception', 'xception'),
        'preactresnet18': ('models.preactresnet', 'preactresnet18'),
        'preactresnet34': ('models.preactresnet', 'preactresnet34'),
        'preactresnet50': ('models.preactresnet', 'preactresnet50'),
        'preactresnet101': ('models.preactresnet', 'preactresnet101'),
        'preactresnet152': ('models.preactresnet', 'preactresnet152'),
        'shufflenet': ('models.shufflenet', 'shufflenet'),
        'shufflenetv2': ('models.shufflenetv2', 'shufflenetv2'),
        'squeezenet': ('models.squeezenet', 'squeezenet'),
        'mobilenet': ('models.mobilenet', 'mobilenet'),
        'mobilenetv2': ('models.mobilenetv2', 'mobilenetv2'),
        'nasnet': ('models.nasnet', 'nasnet'),
        'attention56': ('models.attention', 'attention56'),
        'attention92': ('models.attention', 'attention92'),
    }
    # Project-local builders called as builder(num_classes, 2, pretrained, name).
    # NOTE(review): the meaning of the constant 2 is defined by those builders
    # and is not visible here.
    parameterized_factories = {
        'resnet18': ('models.resnet', 'resnet'),
        'resnet34': ('models.resnet', 'resnet'),
        'resnet50': ('models.resnet', 'resnet'),
        'resnet101': ('models.resnet', 'resnet'),
        'resnet152': ('models.resnet', 'resnet'),
        'se_resnext50': ('models.resnext', 'se_resnext'),
        'se_resnext101': ('models.resnext', 'se_resnext'),
        'se_resnet50': ('models.senet', 'seresnet'),
        'se_resnet101': ('models.senet', 'seresnet'),
        'se_resnet152': ('models.senet', 'seresnet'),
    }

    if name in ('vgg11', 'vgg13', 'vgg16', 'vgg19'):
        vgg_builder = getattr(torchvision.models, name + '_bn')
        net = vgg_builder(pretrained=args.bool_pretrained)
        # Index 6 is the final Linear(4096, 1000) of torchvision's VGG head.
        net.classifier[6] = nn.Linear(4096, args.num_classes, bias=True)

    elif name == 'efficientnet-b5':
        from efficientnet_pytorch import EfficientNet
        # Truthiness, matching how the same flag is passed to torchvision.
        if args.bool_pretrained:
            net = EfficientNet.from_pretrained('efficientnet-b5')
        else:
            net = EfficientNet.from_name('efficientnet-b5')
        # Read the head width instead of hard-coding it (2048 for b5).
        net._fc = nn.Linear(net._fc.in_features, args.num_classes, bias=True)

    elif name in ('densenet161', 'densenet169', 'densenet201'):
        densenet_builder = getattr(torchvision.models, name)
        net = densenet_builder(pretrained=args.bool_pretrained)
        in_features = net.classifier.in_features
        net.classifier = nn.Linear(in_features, args.num_classes, bias=True)

    elif name in parameterized_factories:
        module_name, factory_name = parameterized_factories[name]
        builder = getattr(importlib.import_module(module_name), factory_name)
        net = builder(args.num_classes, 2, args.pretrained, args.net)

    elif name in default_factories:
        module_name, factory_name = default_factories[name]
        net = getattr(importlib.import_module(module_name), factory_name)()

    else:
        print('the network name you have entered is not supported yet')
        sys.exit()

    if use_gpu:
        net = net.cuda()

    return net
Esempio n. 11
0
File: utils.py Progetto: nblt/DLDR
def get_model(args):
    """Build the classification network selected by ``args``.

    Args:
        args: namespace with ``datasets`` (``'ImageNet'``, ``'CIFAR10'``,
            ``'MNIST'`` or ``'CIFAR100'``) and ``arch`` (architecture name).

    Returns:
        The constructed network.
        * ImageNet: ``models_imagenet.__dict__[arch]()``.
        * CIFAR100 with an ``arch`` listed in the table below: the factory
          from the matching ``models.*`` module.
        * Otherwise: ``resnet.__dict__[arch](num_classes=...)``.

    Raises:
        ValueError: if ``args.datasets`` is not one of the supported names.
    """
    if args.datasets == 'ImageNet':
        return models_imagenet.__dict__[args.arch]()

    if args.datasets in ('CIFAR10', 'MNIST'):
        num_class = 10
    elif args.datasets == 'CIFAR100':
        num_class = 100
    else:
        # Previously this fell through and crashed later with an
        # UnboundLocalError on ``num_class``; fail with a clear message instead.
        raise ValueError('unsupported dataset: {!r}'.format(args.datasets))

    if args.datasets == 'CIFAR100':
        # arch name -> (module, factory). The module is imported lazily so
        # only the selected architecture's code is loaded.
        cifar100_archs = {
            'vgg16': ('models.vgg', 'vgg16_bn'),
            'vgg13': ('models.vgg', 'vgg13_bn'),
            'vgg11': ('models.vgg', 'vgg11_bn'),
            'vgg19': ('models.vgg', 'vgg19_bn'),
            'densenet121': ('models.densenet', 'densenet121'),
            'densenet161': ('models.densenet', 'densenet161'),
            'densenet169': ('models.densenet', 'densenet169'),
            'densenet201': ('models.densenet', 'densenet201'),
            'googlenet': ('models.googlenet', 'googlenet'),
            'inceptionv3': ('models.inceptionv3', 'inceptionv3'),
            'inceptionv4': ('models.inceptionv4', 'inceptionv4'),
            'inceptionresnetv2': ('models.inceptionv4', 'inception_resnet_v2'),
            'xception': ('models.xception', 'xception'),
            'resnet18': ('models.resnet', 'resnet18'),
            'resnet34': ('models.resnet', 'resnet34'),
            'resnet50': ('models.resnet', 'resnet50'),
            'resnet101': ('models.resnet', 'resnet101'),
            'resnet152': ('models.resnet', 'resnet152'),
            'preactresnet18': ('models.preactresnet', 'preactresnet18'),
            'preactresnet34': ('models.preactresnet', 'preactresnet34'),
            'preactresnet50': ('models.preactresnet', 'preactresnet50'),
            'preactresnet101': ('models.preactresnet', 'preactresnet101'),
            'preactresnet152': ('models.preactresnet', 'preactresnet152'),
            'resnext50': ('models.resnext', 'resnext50'),
            'resnext101': ('models.resnext', 'resnext101'),
            'resnext152': ('models.resnext', 'resnext152'),
            'shufflenet': ('models.shufflenet', 'shufflenet'),
            'shufflenetv2': ('models.shufflenetv2', 'shufflenetv2'),
            'squeezenet': ('models.squeezenet', 'squeezenet'),
            'mobilenet': ('models.mobilenet', 'mobilenet'),
            'mobilenetv2': ('models.mobilenetv2', 'mobilenetv2'),
            'nasnet': ('models.nasnet', 'nasnet'),
            'attention56': ('models.attention', 'attention56'),
            'attention92': ('models.attention', 'attention92'),
            'seresnet18': ('models.senet', 'seresnet18'),
            'seresnet34': ('models.senet', 'seresnet34'),
            'seresnet50': ('models.senet', 'seresnet50'),
            'seresnet101': ('models.senet', 'seresnet101'),
            'seresnet152': ('models.senet', 'seresnet152'),
            'wideresnet': ('models.wideresidual', 'wideresnet'),
            'stochasticdepth18': ('models.stochasticdepth',
                                  'stochastic_depth_resnet18'),
            'stochasticdepth34': ('models.stochasticdepth',
                                  'stochastic_depth_resnet34'),
            'stochasticdepth50': ('models.stochasticdepth',
                                  'stochastic_depth_resnet50'),
            'stochasticdepth101': ('models.stochasticdepth',
                                   'stochastic_depth_resnet101'),
            'efficientnet': ('models.efficientnet', 'efficientnet'),
        }
        spec = cifar100_archs.get(args.arch)
        if spec is not None:
            import importlib

            module_name, factory_name = spec
            factory = getattr(importlib.import_module(module_name), factory_name)
            if args.arch == 'efficientnet':
                # Only this factory takes explicit arguments; the third one is
                # the class count (num_class == 100 on this path).
                return factory(1, 1, num_class, bn_momentum=0.9)
            return factory()

    # CIFAR10/MNIST, and CIFAR100 archs not in the table above, are looked
    # up in the ``resnet`` module's namespace.
    return resnet.__dict__[args.arch](num_classes=num_class)
def _drop_fc_head(model):
    """Replace ``model.fc`` with an empty ``nn.Sequential`` and return ``model``."""
    model.fc = nn.Sequential()
    return model


def _wrap_as_encoder_q(model):
    """Nest ``model`` under the name ``encoder_q`` inside an ``nn.Sequential``."""
    return nn.Sequential(OrderedDict([('encoder_q', model)]))


def _parallel_load_checkpoint(model, weights_path, state_key):
    """Wrap ``model`` in DataParallel on CUDA and load ``checkpoint[state_key]``.

    Args:
        model: module to wrap and load into.
        weights_path: checkpoint path passed to ``torch.load``.
        state_key: key of the state dict inside the loaded checkpoint.

    Returns:
        The ``DataParallel``-wrapped model after loading.
    """
    model = torch.nn.DataParallel(model).cuda()
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(weights_path)
    # strict=False: keys missing from or unexpected in the checkpoint are ignored.
    model.load_state_dict(checkpoint[state_key], strict=False)
    return model


def get_model(args):
    """Build a backbone selected by ``args.arch`` with all parameters frozen.

    Supported ``args.arch`` values:
      * ``alexnet``, ``resnet18``, ``mobilenet``, ``resnet50``: project
        backbone with ``fc`` replaced by an empty ``nn.Sequential``; weights
        are loaded non-strictly from ``checkpoint['model']``.
      * ``moco_alexnet``, ``moco_resnet18``, ``moco_mobilenet``,
        ``moco_resnet50``: the same backbone nested as ``encoder_q`` in an
        ``nn.Sequential``; weights are loaded non-strictly from
        ``checkpoint['state_dict']``.
      * ``sup_alexnet``, ``sup_resnet18``, ``sup_mobilenet``,
        ``sup_resnet50``: ``models.*(pretrained=True)`` (presumably
        torchvision) with the classification head stripped.

    Every model is wrapped in ``torch.nn.DataParallel`` and moved to CUDA.

    Args:
        args: namespace with ``arch`` and, for non-``sup_`` archs, ``weights``
            (checkpoint path).

    Returns:
        The wrapped model with ``requires_grad=False`` on every parameter.

    Raises:
        ValueError: if ``args.arch`` is not a supported name (previously this
            crashed with an UnboundLocalError on ``model``).
    """
    arch = args.arch

    if arch == 'alexnet':
        model = _parallel_load_checkpoint(
            _drop_fc_head(alexnet()), args.weights, 'model')
    elif arch == 'resnet18':
        model = _parallel_load_checkpoint(
            _drop_fc_head(resnet18()), args.weights, 'model')
    elif arch == 'mobilenet':
        model = _parallel_load_checkpoint(
            _drop_fc_head(mobilenet()), args.weights, 'model')
    elif arch == 'resnet50':
        model = _parallel_load_checkpoint(
            _drop_fc_head(resnet50()), args.weights, 'model')

    elif arch == 'moco_alexnet':
        model = _parallel_load_checkpoint(
            _wrap_as_encoder_q(_drop_fc_head(alexnet())),
            args.weights, 'state_dict')
    elif arch == 'moco_mobilenet':
        model = _parallel_load_checkpoint(
            _wrap_as_encoder_q(_drop_fc_head(mobilenet())),
            args.weights, 'state_dict')
    elif arch in ('moco_resnet18', 'moco_resnet50'):
        backbone = resnet18() if arch == 'moco_resnet18' else resnet50()
        model = _parallel_load_checkpoint(
            _wrap_as_encoder_q(backbone), args.weights, 'state_dict')
        # For the MoCo ResNets the fc head is dropped *after* loading
        # (same order as before).
        model.module.encoder_q.fc = nn.Sequential()

    elif arch == 'sup_alexnet':
        model = models.alexnet(pretrained=True)
        # Drop the last top-level child (presumably ``classifier``) and
        # re-append the classifier minus its final layer after a Flatten.
        modules = list(model.children())[:-1]
        classifier_modules = list(model.classifier.children())[:-1]
        modules.append(Flatten())
        modules.append(nn.Sequential(*classifier_modules))
        model = torch.nn.DataParallel(nn.Sequential(*modules)).cuda()
    elif arch == 'sup_resnet18':
        model = torch.nn.DataParallel(
            _drop_fc_head(models.resnet18(pretrained=True))).cuda()
    elif arch == 'sup_mobilenet':
        model = models.mobilenet_v2(pretrained=True)
        model.classifier = nn.Sequential()
        model = torch.nn.DataParallel(model).cuda()
    elif arch == 'sup_resnet50':
        model = torch.nn.DataParallel(
            _drop_fc_head(models.resnet50(pretrained=True))).cuda()

    else:
        raise ValueError('unsupported arch: {!r}'.format(arch))

    # The backbone is used as a fixed feature extractor.
    for param in model.parameters():
        param.requires_grad = False

    return model
Esempio n. 13
0
def get_network(args):
    """Return the network named by ``args.net``.

    Args:
        args: namespace with ``net`` (architecture name) and ``gpu`` (when
            truthy, the network is moved to CUDA).

    Returns:
        The constructed network.

    Unsupported names print a message and terminate the process with exit
    status 1 (a bare ``sys.exit()`` previously reported success, status 0).
    """
    import importlib

    # net name -> (module, factory). Modules are imported lazily so only the
    # selected architecture's code is loaded; factories take no arguments.
    factories = {
        'vgg16': ('models.vgg', 'vgg16_bn'),
        'vgg13': ('models.vgg', 'vgg13_bn'),
        'vgg11': ('models.vgg', 'vgg11_bn'),
        'vgg19': ('models.vgg', 'vgg19_bn'),
        'densenet121': ('models.densenet', 'densenet121'),
        'densenet161': ('models.densenet', 'densenet161'),
        'densenet169': ('models.densenet', 'densenet169'),
        'densenet201': ('models.densenet', 'densenet201'),
        'googlenet': ('models.googlenet', 'googlenet'),
        'inceptionv3': ('models.inceptionv3', 'inceptionv3'),
        'inceptionv4': ('models.inceptionv4', 'inceptionv4'),
        'inceptionresnetv2': ('models.inceptionv4', 'inception_resnet_v2'),
        'xception': ('models.xception', 'xception'),
        'resnet18': ('models.resnet', 'resnet18'),
        'resnet34': ('models.resnet', 'resnet34'),
        'resnet50': ('models.resnet', 'resnet50'),
        'resnet101': ('models.resnet', 'resnet101'),
        'resnet152': ('models.resnet', 'resnet152'),
        'preactresnet18': ('models.preactresnet', 'preactresnet18'),
        'preactresnet34': ('models.preactresnet', 'preactresnet34'),
        'preactresnet50': ('models.preactresnet', 'preactresnet50'),
        'preactresnet101': ('models.preactresnet', 'preactresnet101'),
        'preactresnet152': ('models.preactresnet', 'preactresnet152'),
        'resnext50': ('models.resnext', 'resnext50'),
        'resnext101': ('models.resnext', 'resnext101'),
        'resnext152': ('models.resnext', 'resnext152'),
        'shufflenet': ('models.shufflenet', 'shufflenet'),
        'shufflenetv2': ('models.shufflenetv2', 'shufflenetv2'),
        'squeezenet': ('models.squeezenet', 'squeezenet'),
        'mobilenet': ('models.mobilenet', 'mobilenet'),
        'mobilenetv2': ('models.mobilenetv2', 'mobilenetv2'),
        'nasnet': ('models.nasnet', 'nasnet'),
        'attention56': ('models.attention', 'attention56'),
        'attention92': ('models.attention', 'attention92'),
        'seresnet18': ('models.senet', 'seresnet18'),
        'seresnet34': ('models.senet', 'seresnet34'),
        'seresnet50': ('models.senet', 'seresnet50'),
        'seresnet101': ('models.senet', 'seresnet101'),
        'seresnet152': ('models.senet', 'seresnet152'),
        'wideresnet': ('models.wideresidual', 'wideresnet'),
        'stochasticdepth18': ('models.stochasticdepth',
                              'stochastic_depth_resnet18'),
        'stochasticdepth34': ('models.stochasticdepth',
                              'stochastic_depth_resnet34'),
        'stochasticdepth50': ('models.stochasticdepth',
                              'stochastic_depth_resnet50'),
        'stochasticdepth101': ('models.stochasticdepth',
                               'stochastic_depth_resnet101'),
        'efficientnetb0': ('models.efficientnet', 'efficientnetb0'),
        'efficientnetb1': ('models.efficientnet', 'efficientnetb1'),
        'efficientnetb2': ('models.efficientnet', 'efficientnetb2'),
        'efficientnetb3': ('models.efficientnet', 'efficientnetb3'),
        'efficientnetb4': ('models.efficientnet', 'efficientnetb4'),
        'efficientnetb5': ('models.efficientnet', 'efficientnetb5'),
        'efficientnetb6': ('models.efficientnet', 'efficientnetb6'),
        'efficientnetb7': ('models.efficientnet', 'efficientnetb7'),
        'efficientnetl2': ('models.efficientnet', 'efficientnetl2'),
    }

    if args.net == 'eff':
        from models.efficientnet_pytorch import EfficientNet
        # NOTE(review): variant and class count (2) are hard-coded here.
        net = EfficientNet.from_pretrained('efficientnet-b7', num_classes=2)
    elif args.net in factories:
        module_name, factory_name = factories[args.net]
        net = getattr(importlib.import_module(module_name), factory_name)()
    else:
        print('the network name you have entered is not supported yet')
        sys.exit(1)

    if args.gpu:  # use_gpu
        net = net.cuda()
        print("use-gpu")

    return net
Esempio n. 14
0
def get_network(args):
    """Return the network named by ``args.net`` sized for ``args.task``.

    Args:
        args: namespace with ``task`` (``'cifar10'`` -> 10 classes,
            ``'cifar100'`` -> 100 classes), ``net`` (architecture name) and
            ``gpu`` (when truthy, the network is moved to CUDA).

    Returns:
        The constructed network.

    Raises:
        ValueError: if the selected factory needs ``num_classes`` but
            ``args.task`` is not a known task (previously an
            UnboundLocalError on ``nclass``).

    Unsupported network names print a message and terminate the process
    with exit status 1 (a bare ``sys.exit()`` previously reported success).
    """
    import importlib

    if args.task == 'cifar10':
        nclass = 10
    elif args.task == 'cifar100':
        nclass = 100
    else:
        # Only an error for factories that actually need the class count.
        nclass = None

    # net name -> (module, factory, takes num_classes). Modules are imported
    # lazily so only the selected architecture's code is loaded.
    # NOTE(review): entries with ``False`` are built with their factory's own
    # default class count regardless of ``args.task`` -- confirm intended.
    factories = {
        # Yang added none bn vggs
        'vgg16': ('models.vgg', 'vgg16', True),
        'vgg13': ('models.vgg', 'vgg13', True),
        'vgg11': ('models.vgg', 'vgg11', True),
        'vgg19': ('models.vgg', 'vgg19', True),
        'vgg16bn': ('models.vgg', 'vgg16_bn', True),
        'vgg13bn': ('models.vgg', 'vgg13_bn', True),
        'vgg11bn': ('models.vgg', 'vgg11_bn', True),
        'vgg19bn': ('models.vgg', 'vgg19_bn', True),
        'densenet121': ('models.densenet', 'densenet121', False),
        'densenet161': ('models.densenet', 'densenet161', False),
        'densenet169': ('models.densenet', 'densenet169', False),
        'densenet201': ('models.densenet', 'densenet201', False),
        'googlenet': ('models.googlenet', 'googlenet', True),
        'inceptionv3': ('models.inceptionv3', 'inceptionv3', False),
        'inceptionv4': ('models.inceptionv4', 'inceptionv4', False),
        'inceptionresnetv2': ('models.inceptionv4', 'inception_resnet_v2',
                              False),
        'xception': ('models.xception', 'xception', True),
        'scnet': ('models.sphereconvnet', 'sphereconvnet', True),
        'sphereresnet18': ('models.sphereconvnet', 'resnet18', True),
        'sphereresnet32': ('models.sphereconvnet', 'sphereresnet32', True),
        'plainresnet32': ('models.sphereconvnet', 'plainresnet32', True),
        'ynet18': ('models.ynet', 'resnet18', True),
        'ynet34': ('models.ynet', 'resnet34', True),
        'ynet50': ('models.ynet', 'resnet50', True),
        'ynet101': ('models.ynet', 'resnet101', True),
        'ynet152': ('models.ynet', 'resnet152', True),
        'resnet18': ('models.resnet', 'resnet18', True),
        'resnet34': ('models.resnet', 'resnet34', True),
        'resnet50': ('models.resnet', 'resnet50', True),
        'resnet101': ('models.resnet', 'resnet101', True),
        'resnet152': ('models.resnet', 'resnet152', True),
        'preactresnet18': ('models.preactresnet', 'preactresnet18', True),
        'preactresnet34': ('models.preactresnet', 'preactresnet34', True),
        'preactresnet50': ('models.preactresnet', 'preactresnet50', True),
        'preactresnet101': ('models.preactresnet', 'preactresnet101', True),
        'preactresnet152': ('models.preactresnet', 'preactresnet152', True),
        'resnext50': ('models.resnext', 'resnext50', True),
        'resnext101': ('models.resnext', 'resnext101', True),
        'resnext152': ('models.resnext', 'resnext152', True),
        'shufflenet': ('models.shufflenet', 'shufflenet', False),
        'shufflenetv2': ('models.shufflenetv2', 'shufflenetv2', False),
        'squeezenet': ('models.squeezenet', 'squeezenet', False),
        'mobilenet': ('models.mobilenet', 'mobilenet', True),
        'mobilenetv2': ('models.mobilenetv2', 'mobilenetv2', True),
        'nasnet': ('models.nasnet', 'nasnet', True),
        'attention56': ('models.attention', 'attention56', False),
        'attention92': ('models.attention', 'attention92', False),
        'seresnet18': ('models.senet', 'seresnet18', True),
        'seresnet34': ('models.senet', 'seresnet34', True),
        'seresnet50': ('models.senet', 'seresnet50', True),
        'seresnet101': ('models.senet', 'seresnet101', True),
        'seresnet152': ('models.senet', 'seresnet152', True),
    }

    spec = factories.get(args.net)
    if spec is None:
        print('the network name you have entered is not supported yet')
        sys.exit(1)

    module_name, factory_name, takes_nclass = spec
    if takes_nclass and nclass is None:
        raise ValueError(
            'unsupported task {!r}: cannot determine num_classes for {!r}'
            .format(args.task, args.net))

    factory = getattr(importlib.import_module(module_name), factory_name)
    net = factory(num_classes=nclass) if takes_nclass else factory()

    if args.gpu:  # use_gpu
        net = net.cuda()

    return net
Esempio n. 15
0
def get_network(args, use_gpu=True):
    """Return the network named by ``args.net``.

    Most project factories are called with ``args`` itself. A few entries
    are built differently:
      * ``vgg19``, ``resnet18``, ``alexnet``: torchvision ImageNet-pretrained
        models whose last classification layer is replaced by a fresh
        ``nn.Linear`` with ``args.nc`` outputs;
      * ``mobilenetv3_l`` / ``mobilenetv3_s``: ``mobileNetv3`` with
        ``mode='large'`` / ``mode='small'``;
      * ``lambda*``: ``LambdaResnet*(num_classes=args.nc, channels=args.cs)``;
      * SqueezeNext and ``mnasnet`` names are matched case-insensitively.

    Args:
        args: namespace with ``net`` plus whatever the chosen factory reads
            (``nc`` and ``cs`` are used directly here for some entries).
        use_gpu: move the network to CUDA before returning it.

    Returns:
        The constructed network.

    Unsupported names print a message and terminate the process with exit
    status 1 (a bare ``sys.exit()`` previously reported success, status 0).
    """
    import importlib

    # Exact net name -> (module, factory); the factory is called as factory(args).
    args_factories = {
        'mobilenet': ('models.mobilenet', 'mobilenet'),
        'mobilenetv2': ('models.mobilenetv2', 'mobilenetv2'),
        'mobilenetv3': ('models.mobilenetv3', 'mobileNetv3'),
        'vgg13': ('models.vgg', 'vgg13_bn'),
        'vgg11': ('models.vgg', 'vgg11_bn'),
        'densenet121': ('models.densenet', 'densenet121'),
        'densenet161': ('models.densenet', 'densenet161'),
        'densenet169': ('models.densenet', 'densenet169'),
        'densenet201': ('models.densenet', 'densenet201'),
        'googlenet': ('models.googlenet', 'googlenet'),
        'inceptionv3': ('models.inceptionv3', 'inceptionv3'),
        'inceptionv4': ('models.inceptionv4', 'inceptionv4'),
        'inceptionresnetv2': ('models.inceptionv4', 'inception_resnet_v2'),
        'xception': ('models.xception', 'xception'),
        'resnet34': ('models.resnet', 'resnet34'),
        'resnet50': ('models.resnet', 'resnet50'),
        'resnet101': ('models.resnet', 'resnet101'),
        'resnet152': ('models.resnet', 'resnet152'),
        'preactresnet18': ('models.preactresnet', 'preactresnet18'),
        'preactresnet34': ('models.preactresnet', 'preactresnet34'),
        'preactresnet50': ('models.preactresnet', 'preactresnet50'),
        'preactresnet101': ('models.preactresnet', 'preactresnet101'),
        'preactresnet152': ('models.preactresnet', 'preactresnet152'),
        'resnext50': ('models.resnext', 'resnext50'),
        'resnext101': ('models.resnext', 'resnext101'),
        'resnext152': ('models.resnext', 'resnext152'),
        'shufflenet': ('models.shufflenet', 'shufflenet'),
        'shufflenetv2': ('models.shufflenetv2', 'shufflenetv2'),
        'squeezenet': ('models.squeezenet', 'squeezenet'),
        'nasnet': ('models.nasnet', 'nasnetalarge'),
        'attention56': ('models.attention', 'attention56'),
        'attention92': ('models.attention', 'attention92'),
        'seresnet18': ('models.senet', 'seresnet18'),
        'seresnet34': ('models.senet', 'seresnet34'),
        'seresnet50': ('models.senet', 'seresnet50'),
        'seresnet101': ('models.senet', 'seresnet101'),
        'seresnet152': ('models.senet', 'seresnet152'),
        'efficientnet_b0': ('models.efficientnet', 'efficientnet_b0'),
        'efficientnet_b1': ('models.efficientnet', 'efficientnet_b1'),
        'efficientnet_b2': ('models.efficientnet', 'efficientnet_b2'),
        'efficientnet_b3': ('models.efficientnet', 'efficientnet_b3'),
        'efficientnet_b4': ('models.efficientnet', 'efficientnet_b4'),
        'efficientnet_b5': ('models.efficientnet', 'efficientnet_b5'),
        'efficientnet_b6': ('models.efficientnet', 'efficientnet_b6'),
        'efficientnet_b7': ('models.efficientnet', 'efficientnet_b7'),
        'mlp': ('models.mlp', 'MLPClassifier'),
    }
    # Lower-cased net name -> (module, factory); also called as factory(args).
    args_factories_ci = {
        'sqnxt_23_1x': ('models.SqueezeNext', 'SqNxt_23_1x'),
        'sqnxt_23_1xv5': ('models.SqueezeNext', 'SqNxt_23_1x_v5'),
        'sqnxt_23_2x': ('models.SqueezeNext', 'SqNxt_23_2x'),
        'sqnxt_23_2xv5': ('models.SqueezeNext', 'SqNxt_23_2x_v5'),
        # NOTE(review): 'mnasnet' builds nasnet_Mobile, not an MnasNet.
        'mnasnet': ('models.nasnet_mobile', 'nasnet_Mobile'),
    }
    lambda_factories = {
        'lambda18': 'LambdaResnet18',
        'lambda34': 'LambdaResnet34',
        'lambda50': 'LambdaResnet50',
        'lambda101': 'LambdaResnet101',
        'lambda152': 'LambdaResnet152',
    }

    spec = args_factories.get(args.net) or args_factories_ci.get(args.net.lower())
    if spec is not None:
        module_name, factory_name = spec
        net = getattr(importlib.import_module(module_name), factory_name)(args)
    elif args.net == 'vgg19':
        from torchvision.models import vgg19_bn
        import torch.nn as nn
        net = vgg19_bn(pretrained=True)
        # Re-head the pretrained model for args.nc classes.
        net.classifier[6] = nn.Linear(4096, args.nc)
    elif args.net == 'resnet18':
        from torchvision.models import resnet18
        import torch.nn as nn
        net = resnet18(pretrained=True)
        net.fc = nn.Linear(512, args.nc)
    elif args.net == 'alexnet':
        from torchvision.models import alexnet
        import torch.nn as nn
        net = alexnet(pretrained=True)
        net.classifier[6] = nn.Linear(4096, args.nc)
    elif args.net in ('mobilenetv3_l', 'mobilenetv3_s'):
        from models.mobilenetv3 import mobileNetv3
        mode = 'large' if args.net == 'mobilenetv3_l' else 'small'
        net = mobileNetv3(args, mode=mode)
    elif args.net in lambda_factories:
        factory = getattr(importlib.import_module('models._lambda'),
                          lambda_factories[args.net])
        # Every depth now receives channels=args.cs (lambda101 used to omit it).
        net = factory(num_classes=args.nc, channels=args.cs)
    else:
        print('the network name you have entered is not supported yet')
        sys.exit(1)

    if use_gpu:
        net = net.cuda()

    return net
Esempio n. 16
0
def get_network(args):
    """Build the network named by ``args.net`` and optionally move it to GPU.

    Each architecture is imported lazily from the project-local ``models``
    package, so only the requested model's module is loaded.

    Args:
        args: namespace-like object. ``net`` selects the architecture and
            ``gpu`` (truthy) makes the result go through ``.cuda()``.
            ``dict_size`` is read only for the ``hyper_resnet`` and
            ``hyper_resnet_wo_bn`` variants.

    Returns:
        The constructed network (moved with ``.cuda()`` when ``args.gpu``).

    Raises:
        SystemExit: with status 1 when ``args.net`` is not a supported name.
    """
    import importlib
    import sys

    # name -> (module path, factory attribute). Every factory listed here is
    # called without arguments.
    registry = {
        'vgg11': ('models.vgg', 'vgg11_bn'),
        'vgg13': ('models.vgg', 'vgg13_bn'),
        'vgg16': ('models.vgg', 'vgg16_bn'),
        'vgg19': ('models.vgg', 'vgg19_bn'),
        'densenet121': ('models.densenet', 'densenet121'),
        'densenet161': ('models.densenet', 'densenet161'),
        'densenet169': ('models.densenet', 'densenet169'),
        'densenet201': ('models.densenet', 'densenet201'),
        'googlenet': ('models.googlenet', 'googlenet'),
        'inceptionv3': ('models.inceptionv3', 'inceptionv3'),
        'inceptionv4': ('models.inceptionv4', 'inceptionv4'),
        'inceptionresnetv2': ('models.inceptionv4', 'inception_resnet_v2'),
        'xception': ('models.xception', 'xception'),
        'resnet18': ('models.resnet', 'resnet18'),
        'resnet34': ('models.resnet', 'resnet34'),
        'resnet50': ('models.resnet', 'resnet50'),
        'resnet101': ('models.resnet', 'resnet101'),
        'resnet152': ('models.resnet', 'resnet152'),
        'preactresnet18': ('models.preactresnet', 'preactresnet18'),
        'preactresnet34': ('models.preactresnet', 'preactresnet34'),
        'preactresnet50': ('models.preactresnet', 'preactresnet50'),
        'preactresnet101': ('models.preactresnet', 'preactresnet101'),
        'preactresnet152': ('models.preactresnet', 'preactresnet152'),
        'resnext50': ('models.resnext', 'resnext50'),
        'resnext101': ('models.resnext', 'resnext101'),
        'resnext152': ('models.resnext', 'resnext152'),
        'shufflenet': ('models.shufflenet', 'shufflenet'),
        'shufflenetv2': ('models.shufflenetv2', 'shufflenetv2'),
        'squeezenet': ('models.squeezenet', 'squeezenet'),
        'mobilenet': ('models.mobilenet', 'mobilenet'),
        'mobilenetv2': ('models.mobilenetv2', 'mobilenetv2'),
        'nasnet': ('models.nasnet', 'nasnet'),
        'attention56': ('models.attention', 'attention56'),
        'attention92': ('models.attention', 'attention92'),
        'seresnet18': ('models.senet', 'seresnet18'),
        'seresnet34': ('models.senet', 'seresnet34'),
        'seresnet50': ('models.senet', 'seresnet50'),
        'seresnet101': ('models.senet', 'seresnet101'),
        'seresnet152': ('models.senet', 'seresnet152'),
        'wideresnet': ('models.wideresidual', 'wideresnet'),
        'stochasticdepth18': ('models.stochasticdepth',
                              'stochastic_depth_resnet18'),
        'stochasticdepth34': ('models.stochasticdepth',
                              'stochastic_depth_resnet34'),
        'stochasticdepth50': ('models.stochasticdepth',
                              'stochastic_depth_resnet50'),
        'stochasticdepth101': ('models.stochasticdepth',
                               'stochastic_depth_resnet101'),
        'normal_resnet': ('models.normal_resnet', 'resnet18'),
        'normal_resnet_wo_bn': ('models.normal_resnet_wo_bn', 'resnet18'),
    }
    # Hypernet variants share one constructor and differ only in the encoder.
    hyper_encoders = {
        'hyper_resnet': 'resnet18',
        'hyper_resnet_wo_bn': 'resnet18_wobn',
    }

    if args.net in hyper_encoders:
        from models.hypernet_main import Hypernet_Main
        net = Hypernet_Main(
            encoder=hyper_encoders[args.net],
            hypernet_params={'vqvae_dict_size': args.dict_size})
    elif args.net in registry:
        module_name, factory_name = registry[args.net]
        factory = getattr(importlib.import_module(module_name), factory_name)
        net = factory()
    else:
        print('the network name you have entered is not supported yet')
        # Non-zero status: a bare sys.exit() would report success (status 0).
        sys.exit(1)

    if args.gpu:  # use_gpu
        net = net.cuda()

    return net
Esempio n. 17
0
def get_network(args, use_gpu=True, num_train=0):
    """Build the network named by ``args.net`` for ``args.dataset``.

    Args:
        args: namespace-like object. ``dataset``, ``ignoring`` and ``net``
            are always read. ``softmax`` and ``isalpha`` are read as well
            when ``ignoring`` is truthy.
        use_gpu: if truthy, ``.cuda()`` is called on the network.
        num_train: forwarded to ``resnet18_ign``. Only used when
            ``args.ignoring`` is set.

    Returns:
        The constructed network.

    Raises:
        SystemExit: with status 1 when the requested network is not
            supported. With ``args.ignoring`` set, only ``'resnet18'`` is.
    """
    import importlib
    import sys

    unsupported = 'the network name you have entered is not supported yet'

    # NOTE(review): unknown datasets give 0 classes, which is passed straight
    # to the ResNet factories below; confirm that is intended.
    num_classes = {'cifar-10': 10, 'cifar-100': 100}.get(args.dataset, 0)

    if args.ignoring:
        # Previously a non-resnet18 name left `net` unbound and crashed later
        # with UnboundLocalError; report it like any other unsupported name.
        if args.net != 'resnet18':
            print(unsupported)
            sys.exit(1)
        import torch.nn as nn
        from models.resnet_ign import resnet18_ign
        # Per-sample losses (reduction='none'), presumably so the model can
        # weight or ignore individual training examples.
        criterion = nn.CrossEntropyLoss(reduction='none')
        net = resnet18_ign(criterion, num_classes=num_classes,
                           num_train=num_train, softmax=args.softmax,
                           isalpha=args.isalpha)
    else:
        # name -> (module path, factory attribute), imported lazily.
        registry = {
            'vgg11': ('models.vgg', 'vgg11_bn'),
            'vgg13': ('models.vgg', 'vgg13_bn'),
            'vgg16': ('models.vgg', 'vgg16_bn'),
            'vgg19': ('models.vgg', 'vgg19_bn'),
            'densenet121': ('models.densenet', 'densenet121'),
            'densenet161': ('models.densenet', 'densenet161'),
            'densenet169': ('models.densenet', 'densenet169'),
            'densenet201': ('models.densenet', 'densenet201'),
            'googlenet': ('models.googlenet', 'googlenet'),
            'inceptionv3': ('models.inceptionv3', 'inceptionv3'),
            'inceptionv4': ('models.inceptionv4', 'inceptionv4'),
            'inceptionresnetv2': ('models.inceptionv4', 'inception_resnet_v2'),
            'xception': ('models.xception', 'xception'),
            'resnet18': ('models.resnet', 'resnet18'),
            'resnet34': ('models.resnet', 'resnet34'),
            'resnet50': ('models.resnet', 'resnet50'),
            'resnet101': ('models.resnet', 'resnet101'),
            'resnet152': ('models.resnet', 'resnet152'),
            'preactresnet18': ('models.preactresnet', 'preactresnet18'),
            'preactresnet34': ('models.preactresnet', 'preactresnet34'),
            'preactresnet50': ('models.preactresnet', 'preactresnet50'),
            'preactresnet101': ('models.preactresnet', 'preactresnet101'),
            'preactresnet152': ('models.preactresnet', 'preactresnet152'),
            'resnext50': ('models.resnext', 'resnext50'),
            'resnext101': ('models.resnext', 'resnext101'),
            'resnext152': ('models.resnext', 'resnext152'),
            'shufflenet': ('models.shufflenet', 'shufflenet'),
            'shufflenetv2': ('models.shufflenetv2', 'shufflenetv2'),
            'squeezenet': ('models.squeezenet', 'squeezenet'),
            'mobilenet': ('models.mobilenet', 'mobilenet'),
            'mobilenetv2': ('models.mobilenetv2', 'mobilenetv2'),
            'nasnet': ('models.nasnet', 'nasnet'),
            'attention56': ('models.attention', 'attention56'),
            'attention92': ('models.attention', 'attention92'),
            'seresnet18': ('models.senet', 'seresnet18'),
            'seresnet34': ('models.senet', 'seresnet34'),
            'seresnet50': ('models.senet', 'seresnet50'),
            'seresnet101': ('models.senet', 'seresnet101'),
            'seresnet152': ('models.senet', 'seresnet152'),
        }
        # Only the plain ResNets are given the dataset's class count; every
        # other factory is called with its own defaults.
        takes_num_classes = {'resnet18', 'resnet34', 'resnet50',
                             'resnet101', 'resnet152'}

        if args.net not in registry:
            print(unsupported)
            # Non-zero status: a bare sys.exit() would report success.
            sys.exit(1)
        module_name, factory_name = registry[args.net]
        factory = getattr(importlib.import_module(module_name), factory_name)
        if args.net in takes_num_classes:
            net = factory(num_classes=num_classes)
        else:
            net = factory()

    if use_gpu:
        net = net.cuda()

    return net
Esempio n. 18
0
def mobilenet(input_tensor,
              num_classes=1001,
              depth_multiplier=1.0,
              scope='MobilenetV2',
              conv_defs=None,
              finegrain_classification_mode=False,
              min_depth=None,
              divisible_by=None,
              **kwargs):
    """Build a MobileNet V2 network on top of ``input_tensor``.

    The graph is built in inference mode unless the caller wraps the call in
    the training scope, e.g.::

        with tf.contrib.slim.arg_scope(mobilenet_v2.training_scope()):
            logits, endpoints = mobilenet_v2.mobilenet(input_tensor)

    Args:
      input_tensor: The input tensor.
      num_classes: Number of output classes.
      depth_multiplier: Scale factor for the channel count of each layer.
        The paper calls it the width multiplier; the name is kept for
        consistency with slim's model builders.
      scope: Scope of the operator.
      conv_defs: Optional replacement for the default ``V2_DEF`` conv spec.
      finegrain_classification_mode: When True, ``conv_defs`` is deep-copied.
        If in addition ``depth_multiplier < 1``, the last layer's
        ``num_outputs`` is divided by the multiplier, so that layer stays
        large even for small multipliers. https://arxiv.org/abs/1801.04381
        suggests this helps on ImageNet-type problems. Ignored if
        final_endpoint makes the builder exit earlier.
      min_depth: If provided, every layer keeps at least this many channels
        after the depth multiplier is applied.
      divisible_by: If provided, every layer's channel count is made
        divisible by this number.
      **kwargs: Forwarded to ``lib.mobilenet``, e.g. ``prediction_fn`` or
        ``reuse``. If ``reuse`` is True, ``scope`` must be given.

    Returns:
      The logits/endpoints pair produced by ``lib.mobilenet``.

    Raises:
      ValueError: If ``multiplier`` is passed in ``kwargs``; use
        ``depth_multiplier`` instead.
    """
    conv_defs = V2_DEF if conv_defs is None else conv_defs

    if 'multiplier' in kwargs:
        raise ValueError(
            'mobilenetv2 doesn\'t support generic '
            'multiplier parameter use "depth_multiplier" instead.')

    if finegrain_classification_mode:
        # Work on a private copy so the caller's spec (or the shared default)
        # is never mutated.
        conv_defs = copy.deepcopy(conv_defs)
        if depth_multiplier < 1:
            last_layer_params = conv_defs['spec'][-1].params
            last_layer_params['num_outputs'] /= depth_multiplier

    # Forward only the depth options that were actually supplied, so a
    # default installed on lib.depth_multiplier by an outer arg_scope is not
    # overridden.
    depth_args = {
        option: value
        for option, value in (('min_depth', min_depth),
                              ('divisible_by', divisible_by))
        if value is not None
    }

    with slim.arg_scope((lib.depth_multiplier, ), **depth_args):
        return lib.mobilenet(
            input_tensor,
            num_classes=num_classes,
            conv_defs=conv_defs,
            scope=scope,
            multiplier=depth_multiplier,
            **kwargs)
Esempio n. 19
0
def get_network(args):
    """Build and return the network named by ``args.model``.

    Each architecture is imported lazily from the project-local ``models``
    package and built with its factory's default arguments. Unlike the
    other builders in this file, this one never calls ``.cuda()``.

    Args:
        args: namespace-like object whose ``model`` attribute names the
            architecture.

    Returns:
        The constructed model.

    Raises:
        SystemExit: with status 1 when ``args.model`` is not supported.
    """
    import importlib
    import sys

    # name -> (module path, factory attribute).
    registry = {
        'vgg11': ('models.vgg', 'vgg11_bn'),
        'vgg13': ('models.vgg', 'vgg13_bn'),
        'vgg16': ('models.vgg', 'vgg16_bn'),
        'vgg19': ('models.vgg', 'vgg19_bn'),
        'densenet121': ('models.densenet', 'densenet121'),
        'densenet161': ('models.densenet', 'densenet161'),
        'densenet169': ('models.densenet', 'densenet169'),
        'densenet201': ('models.densenet', 'densenet201'),
        'googlenet': ('models.googlenet', 'googlenet'),
        'inceptionv3': ('models.inceptionv3', 'inceptionv3'),
        'inceptionv4': ('models.inceptionv4', 'inceptionv4'),
        'inceptionresnetv2': ('models.inceptionv4', 'inception_resnet_v2'),
        'xception': ('models.xception', 'xception'),
        'resnet18': ('models.resnet', 'resnet18'),
        'resnet34': ('models.resnet', 'resnet34'),
        'resnet50': ('models.resnet', 'resnet50'),
        'resnet101': ('models.resnet', 'resnet101'),
        'resnet152': ('models.resnet', 'resnet152'),
        'preactresnet18': ('models.preactresnet', 'preactresnet18'),
        'preactresnet34': ('models.preactresnet', 'preactresnet34'),
        'preactresnet50': ('models.preactresnet', 'preactresnet50'),
        'preactresnet101': ('models.preactresnet', 'preactresnet101'),
        'preactresnet152': ('models.preactresnet', 'preactresnet152'),
        'resnext50': ('models.resnext', 'resnext50'),
        'resnext101': ('models.resnext', 'resnext101'),
        'resnext152': ('models.resnext', 'resnext152'),
        'shufflenet': ('models.shufflenet', 'shufflenet'),
        'shufflenetv2': ('models.shufflenetv2', 'shufflenetv2'),
        'squeezenet': ('models.squeezenet', 'squeezenet'),
        'mobilenet': ('models.mobilenet', 'mobilenet'),
        'mobilenetv2': ('models.mobilenetv2', 'mobilenetv2'),
        'nasnet': ('models.nasnet', 'nasnet'),
        'attention56': ('models.attention', 'attention56'),
        'attention92': ('models.attention', 'attention92'),
        'seresnet18': ('models.senet', 'seresnet18'),
        'seresnet34': ('models.senet', 'seresnet34'),
        'seresnet50': ('models.senet', 'seresnet50'),
        'seresnet101': ('models.senet', 'seresnet101'),
        'seresnet152': ('models.senet', 'seresnet152'),
        'wideresnet': ('models.wideresidual', 'wideresnet'),
        'stochasticdepth18': ('models.stochasticdepth',
                              'stochastic_depth_resnet18'),
        'stochasticdepth34': ('models.stochasticdepth',
                              'stochastic_depth_resnet34'),
        'stochasticdepth50': ('models.stochasticdepth',
                              'stochastic_depth_resnet50'),
        'stochasticdepth101': ('models.stochasticdepth',
                               'stochastic_depth_resnet101'),
    }

    entry = registry.get(args.model)
    if entry is None:
        print('the network name you have entered is not supported yet')
        # Non-zero status: a bare sys.exit() would report success (status 0).
        sys.exit(1)

    module_name, factory_name = entry
    factory = getattr(importlib.import_module(module_name), factory_name)
    model = factory()

    return model