Example #1
0
 def __init__(self,
              net='xception',
              feature_layer='b3',
              num_classes=2,
              dropout_rate=0.5,
              pretrained=False):
     """Build a backbone with a texture-enhancement head and linear classifier.

     Args:
         net: backbone name. Any name containing ``'xception'`` selects the
             Xception backbone; a name whose first ``'-'``-separated token is
             ``'efficientnet'`` is passed to ``EfficientNet.from_pretrained``.
         feature_layer: key of the backbone's output dict whose channel count
             sizes the texture-enhancement module (also passed as ``escape``).
         num_classes: number of outputs of the final linear layer.
         dropout_rate: drop probability of ``self.dropout``.
         pretrained: falsy to skip, otherwise a checkpoint path whose
             ``'state_dict'`` entry is loaded into the backbone.

     Raises:
         ValueError: if ``net`` names an unsupported backbone.
     """
     super().__init__()
     self.num_classes = num_classes
     if 'xception' in net:
         self.net = xception(num_classes, escape=feature_layer)
     elif net.split('-')[0] == 'efficientnet':
         self.net = EfficientNet.from_pretrained(net,
                                                 advprop=True,
                                                 num_classes=num_classes,
                                                 escape=feature_layer)
     else:
         # Without this, self.net is never set and construction fails later
         # with a confusing AttributeError.
         raise ValueError('Unsupported backbone: {!r}'.format(net))
     self.feature_layer = feature_layer
     # Dummy forward pass to discover the channel count of feature_layer.
     with torch.no_grad():
         layers = self.net(torch.zeros(1, 3, 100, 100))
     num_features = layers[self.feature_layer].shape[1]
     if pretrained:
         checkpoint = torch.load(pretrained, map_location='cpu')
         state_dict = checkpoint['state_dict']
         # When the checkpoint holds a whole model, keep only the backbone
         # ('net.*') entries and strip the prefix so the names match
         # self.net's own parameters. (Previously the filter iterated the
         # checkpoint's top-level keys instead of the state_dict's keys.)
         keys = {
             k[len('net.'):]: v
             for k, v in state_dict.items() if k.startswith('net.')
         }
         if not keys:
             # Checkpoint saved from the backbone itself: load as-is.
             keys = state_dict
         load_state(self.net, keys)
     self.pooling = nn.AdaptiveAvgPool2d(1)
     self.texture_enhance = Texture_Enhance_v2(num_features, 1)
     self.num_features = self.texture_enhance.output_features
     self.fc = nn.Linear(self.num_features, self.num_classes)
     self.dropout = nn.Dropout(dropout_rate)
Example #2
0
 def __init__(self):
     """Pretrained Xception trunk followed by a small 2-way classifier head.

     All children of the pretrained Xception except the last one form
     ``self.resnet_lay``. The attribute name is kept even though the trunk is
     Xception, because renaming it would change the ``state_dict`` keys.
     """
     super(Myxception, self).__init__()
     from models.xception import xception
     # Previously `model = net = xception(...)`; the `net` alias was never used.
     model = xception(pretrained=True)
     self.resnet_lay = nn.Sequential(*list(model.children())[:-1])
     # 1x1 convolution from the trunk's 2048 channels down to 512.
     self.conv1_lay = nn.Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1))
     self.relu1_lay = nn.ReLU(inplace=True)
     self.drop_lay = nn.Dropout2d(0.5)
     self.global_average = nn.AdaptiveAvgPool2d((1, 1))
     self.fc_Linear_lay2 = nn.Linear(512, 2)
Example #3
0
 def __init__(self, net='xception', feature_layer='b3', attention_layer='final',
              num_classes=2, M=8, mid_dims=256, dropout_rate=0.5,
              drop_final_rate=0.5, pretrained=False, alpha=0.05,
              size=(380, 380), margin=1, inner_margin=None):
     """Build the multi-attention (MAT) model on top of a CNN backbone.

     Args:
         net: backbone name. Any name containing ``'xception'`` selects
             Xception; a name whose first ``'-'``-separated token is
             ``'efficientnet'`` is passed to ``EfficientNet.from_pretrained``.
         feature_layer: backbone output key fed to the texture-enhancement
             module.
         attention_layer: backbone output key whose channel count sizes the
             attention maps.
         num_classes: number of output classes.
         M: number of attention maps.
         mid_dims: hidden width of the projection and ensemble layers.
         dropout_rate: rate of ``self.dropout`` (``nn.Dropout2d``).
         drop_final_rate: rate of ``self.dropout_final`` (``nn.Dropout``).
         pretrained: falsy to skip, otherwise a checkpoint path whose
             ``'state_dict'`` entry is loaded into the backbone.
         alpha, margin, inner_margin: forwarded to ``Auxiliary_Loss_v2``;
             ``inner_margin`` defaults to ``[0.01, 0.02]``.
         size: (H, W) of the dummy input used to probe backbone output shapes.

     Raises:
         ValueError: if ``net`` names an unsupported backbone.
     """
     super(MAT, self).__init__()
     if inner_margin is None:
         # Fresh list per instance instead of a shared mutable default.
         inner_margin = [0.01, 0.02]
     self.num_classes = num_classes
     self.M = M
     if 'xception' in net:
         self.net = xception(num_classes)
     elif net.split('-')[0] == 'efficientnet':
         self.net = EfficientNet.from_pretrained(net,
                                                 advprop=True,
                                                 num_classes=num_classes)
     else:
         # Without this, self.net is never set and construction fails later
         # with a confusing AttributeError.
         raise ValueError('Unsupported backbone: {!r}'.format(net))
     self.feature_layer = feature_layer
     self.attention_layer = attention_layer
     # Dummy forward pass to discover the channel counts of the backbone outputs.
     with torch.no_grad():
         layers = self.net(torch.zeros(1, 3, size[0], size[1]))
     num_features = layers[self.feature_layer].shape[1]
     self.mid_dims = mid_dims
     if pretrained:
         checkpoint = torch.load(pretrained, map_location='cpu')
         state_dict = checkpoint['state_dict']
         # When the checkpoint holds a whole model, keep only the backbone
         # ('net.*') entries and strip the prefix so they match self.net's
         # parameter names. (Previously the filter iterated the checkpoint's
         # top-level keys, and with strict=False any prefixed keys would have
         # been silently ignored.)
         keys = {
             k[len('net.'):]: v
             for k, v in state_dict.items() if k.startswith('net.')
         }
         if not keys:
             # Checkpoint saved from the backbone itself: load as-is.
             keys = state_dict
         self.net.load_state_dict(keys, strict=False)
     self.attentions = AttentionMap(layers[self.attention_layer].shape[1],
                                    self.M)
     self.atp = AttentionPooling()
     self.texture_enhance = Texture_Enhance_v2(num_features, M)
     self.num_features = self.texture_enhance.output_features
     self.num_features_d = self.texture_enhance.output_features_d
     self.projection_local = nn.Sequential(
         nn.Linear(M * self.num_features, mid_dims), nn.Hardswish(),
         nn.Linear(mid_dims, mid_dims))
     self.project_final = nn.Linear(layers['final'].shape[1], mid_dims)
     self.ensemble_classifier_fc = nn.Sequential(
         nn.Linear(mid_dims * 2, mid_dims), nn.Hardswish(),
         nn.Linear(mid_dims, num_classes))
     self.auxiliary_loss = Auxiliary_Loss_v2(M, self.num_features_d,
                                             num_classes, alpha, margin,
                                             inner_margin)
     self.dropout = nn.Dropout2d(dropout_rate, inplace=True)
     self.dropout_final = nn.Dropout(drop_final_rate, inplace=True)
Example #4
0
def select_model(model_name: str,
                 pretrain: bool,
                 n_class: int,
                 onehot: int,
                 onehot2=0):
    """Construct the model registered under ``model_name``.

    Every constructor is called as ``ctor(onehot=onehot, onehot2=onehot2)``.
    ``pretrain`` and ``n_class`` are accepted for interface compatibility but
    are not used.

    Raises:
        NotImplementedError: if ``model_name`` is not a supported name; the
            message names the offending value and lists every supported one.
    """
    # Each entry returns the constructor lazily, so only the selected
    # constructor name is ever resolved.
    constructors = {
        'resnet50': lambda: ResNet50,
        'resnext': lambda: resnext50_32x4d,
        'resnet101': lambda: resnet101,
        'resnext101': lambda: resnext101_32x8d,
        'resnext101_32x16d': lambda: resnext101_32x16d,
        'densenet': lambda: DenseNet121,
        'densenet201': lambda: densenet201,
        'nest200': lambda: resnest200,
        'nest101': lambda: resnest101,
        'efficient_b2': lambda: EfficientNet_B2,
        'efficient_b5': lambda: EfficientNet_B5,
        'efficient_b6': lambda: EfficientNet_B6,
        'xception': lambda: xception,
    }
    try:
        get_constructor = constructors[model_name]
    except KeyError:
        raise NotImplementedError(
            'Unsupported model_name {!r}; please select one of {}'.format(
                model_name, sorted(constructors))) from None
    return get_constructor()(onehot=onehot, onehot2=onehot2)
Example #5
0
def get_network(args, use_gpu=True):
    """Return the network named by ``args.net``, optionally moved to the GPU.

    Most architectures are imported lazily from the project's ``models``
    package and constructed from ``args``. ``vgg19``, ``resnet18`` and
    ``alexnet`` instead come from torchvision with ``pretrained=True`` and get
    a new final ``nn.Linear`` with ``args.nc`` outputs. The ``lambda*``
    variants are built from ``args.nc`` (plus ``args.cs`` for all of them
    except ``lambda101``).

    Args:
        args: options object; always reads ``args.net``, and some branches
            also read ``args.nc`` / ``args.cs``.
        use_gpu: if true, the network is moved to the GPU with ``.cuda()``.

    Returns:
        The constructed network.

    Raises:
        SystemExit: with status 1 (after printing a message) when
            ``args.net`` is not a supported name.
    """
    if args.net == 'mobilenet':
        from models.mobilenet import mobilenet
        net = mobilenet(args)
    elif args.net == 'mobilenetv2':
        from models.mobilenetv2 import mobilenetv2
        net = mobilenetv2(args)
    elif args.net == 'vgg13':
        from models.vgg import vgg13_bn
        net = vgg13_bn(args)
    elif args.net == 'vgg11':
        from models.vgg import vgg11_bn
        net = vgg11_bn(args)
    elif args.net == 'vgg19':
        # torchvision ImageNet weights; classifier output resized to args.nc.
        from torchvision.models import vgg19_bn
        import torch.nn as nn
        net = vgg19_bn(pretrained=True)
        net.classifier[6] = nn.Linear(4096, args.nc)
    elif args.net == 'densenet121':
        from models.densenet import densenet121
        net = densenet121(args)
    elif args.net == 'densenet161':
        from models.densenet import densenet161
        net = densenet161(args)
    elif args.net == 'densenet169':
        from models.densenet import densenet169
        net = densenet169(args)
    elif args.net == 'densenet201':
        from models.densenet import densenet201
        net = densenet201(args)
    elif args.net == 'googlenet':
        from models.googlenet import googlenet
        net = googlenet(args)
    elif args.net == 'inceptionv3':
        from models.inceptionv3 import inceptionv3
        net = inceptionv3(args)
    elif args.net == 'inceptionv4':
        from models.inceptionv4 import inceptionv4
        net = inceptionv4(args)
    elif args.net == 'inceptionresnetv2':
        from models.inceptionv4 import inception_resnet_v2
        net = inception_resnet_v2(args)
    elif args.net == 'xception':
        from models.xception import xception
        net = xception(args)
    elif args.net == 'resnet18':
        # torchvision ImageNet weights; fc output resized to args.nc.
        from torchvision.models import resnet18
        import torch.nn as nn
        net = resnet18(pretrained=True)
        net.fc = nn.Linear(512, args.nc)
    elif args.net == 'resnet34':
        from models.resnet import resnet34
        net = resnet34(args)
    elif args.net == 'resnet50':
        from models.resnet import resnet50
        net = resnet50(args)
    elif args.net == 'resnet101':
        from models.resnet import resnet101
        net = resnet101(args)
    elif args.net == 'resnet152':
        from models.resnet import resnet152
        net = resnet152(args)
    elif args.net == 'preactresnet18':
        from models.preactresnet import preactresnet18
        net = preactresnet18(args)
    elif args.net == 'preactresnet34':
        from models.preactresnet import preactresnet34
        net = preactresnet34(args)
    elif args.net == 'preactresnet50':
        from models.preactresnet import preactresnet50
        net = preactresnet50(args)
    elif args.net == 'preactresnet101':
        from models.preactresnet import preactresnet101
        net = preactresnet101(args)
    elif args.net == 'preactresnet152':
        from models.preactresnet import preactresnet152
        net = preactresnet152(args)
    elif args.net == 'resnext50':
        from models.resnext import resnext50
        net = resnext50(args)
    elif args.net == 'resnext101':
        from models.resnext import resnext101
        net = resnext101(args)
    elif args.net == 'resnext152':
        from models.resnext import resnext152
        net = resnext152(args)
    elif args.net == 'shufflenet':
        from models.shufflenet import shufflenet
        net = shufflenet(args)
    elif args.net == 'shufflenetv2':
        from models.shufflenetv2 import shufflenetv2
        net = shufflenetv2(args)
    elif args.net == 'squeezenet':
        from models.squeezenet import squeezenet
        net = squeezenet(args)
    # (Duplicate, unreachable 'mobilenet' / 'mobilenetv2' branches that used
    # to sit here were removed; both names are handled at the top.)
    elif args.net == 'mobilenetv3':
        from models.mobilenetv3 import mobileNetv3
        net = mobileNetv3(args)
    elif args.net == 'mobilenetv3_l':
        from models.mobilenetv3 import mobileNetv3
        net = mobileNetv3(args, mode='large')
    elif args.net == 'mobilenetv3_s':
        from models.mobilenetv3 import mobileNetv3
        net = mobileNetv3(args, mode='small')
    elif args.net == 'nasnet':
        from models.nasnet import nasnetalarge
        net = nasnetalarge(args)
    elif args.net == 'attention56':
        from models.attention import attention56
        net = attention56(args)
    elif args.net == 'attention92':
        from models.attention import attention92
        net = attention92(args)
    elif args.net == 'seresnet18':
        from models.senet import seresnet18
        net = seresnet18(args)
    elif args.net == 'seresnet34':
        from models.senet import seresnet34
        net = seresnet34(args)
    elif args.net == 'seresnet50':
        from models.senet import seresnet50
        net = seresnet50(args)
    elif args.net == 'seresnet101':
        from models.senet import seresnet101
        net = seresnet101(args)
    elif args.net == 'seresnet152':
        from models.senet import seresnet152
        net = seresnet152(args)
    elif args.net.lower() == 'sqnxt_23_1x':
        from models.SqueezeNext import SqNxt_23_1x
        net = SqNxt_23_1x(args)
    elif args.net.lower() == 'sqnxt_23_1xv5':
        from models.SqueezeNext import SqNxt_23_1x_v5
        net = SqNxt_23_1x_v5(args)
    elif args.net.lower() == 'sqnxt_23_2x':
        from models.SqueezeNext import SqNxt_23_2x
        net = SqNxt_23_2x(args)
    elif args.net.lower() == 'sqnxt_23_2xv5':
        from models.SqueezeNext import SqNxt_23_2x_v5
        net = SqNxt_23_2x_v5(args)
    elif args.net.lower() == 'mnasnet':
        # The 'mnasnet' name is served by the NASNet-Mobile implementation.
        from models.nasnet_mobile import nasnet_Mobile
        net = nasnet_Mobile(args)
    elif args.net == 'efficientnet_b0':
        from models.efficientnet import efficientnet_b0
        net = efficientnet_b0(args)
    elif args.net == 'efficientnet_b1':
        from models.efficientnet import efficientnet_b1
        net = efficientnet_b1(args)
    elif args.net == 'efficientnet_b2':
        from models.efficientnet import efficientnet_b2
        net = efficientnet_b2(args)
    elif args.net == 'efficientnet_b3':
        from models.efficientnet import efficientnet_b3
        net = efficientnet_b3(args)
    elif args.net == 'efficientnet_b4':
        from models.efficientnet import efficientnet_b4
        net = efficientnet_b4(args)
    elif args.net == 'efficientnet_b5':
        from models.efficientnet import efficientnet_b5
        net = efficientnet_b5(args)
    elif args.net == 'efficientnet_b6':
        from models.efficientnet import efficientnet_b6
        net = efficientnet_b6(args)
    elif args.net == 'efficientnet_b7':
        from models.efficientnet import efficientnet_b7
        net = efficientnet_b7(args)
    elif args.net == 'mlp':
        from models.mlp import MLPClassifier
        net = MLPClassifier(args)
    elif args.net == 'alexnet':
        # torchvision ImageNet weights; classifier output resized to args.nc.
        from torchvision.models import alexnet
        import torch.nn as nn
        net = alexnet(pretrained=True)
        net.classifier[6] = nn.Linear(4096, args.nc)
    elif args.net == 'lambda18':
        from models._lambda import LambdaResnet18
        net = LambdaResnet18(num_classes=args.nc, channels=args.cs)
    elif args.net == 'lambda34':
        from models._lambda import LambdaResnet34
        net = LambdaResnet34(num_classes=args.nc, channels=args.cs)
    elif args.net == 'lambda50':
        from models._lambda import LambdaResnet50
        net = LambdaResnet50(num_classes=args.nc, channels=args.cs)
    elif args.net == 'lambda101':
        from models._lambda import LambdaResnet101
        # NOTE(review): unlike the other lambda variants, channels=args.cs is
        # not passed here — confirm this is intentional.
        net = LambdaResnet101(num_classes=args.nc)
    elif args.net == 'lambda152':
        from models._lambda import LambdaResnet152
        net = LambdaResnet152(num_classes=args.nc, channels=args.cs)
    else:
        print('the network name you have entered is not supported yet')
        # Non-zero status: a bare sys.exit() reports success to the shell.
        sys.exit(1)

    if use_gpu:
        net = net.cuda()

    return net
Example #6
0
def get_network(args):
    """Return the CIFAR network named by ``args.net``, optionally on the GPU.

    The number of classes is derived from ``args.task`` (10 for ``cifar10``,
    100 for ``cifar100``) and passed as ``num_classes`` to most constructors.
    Architectures are imported lazily from the project's ``models`` package.

    Args:
        args: options object with ``task``, ``net`` and ``gpu`` attributes.

    Returns:
        The constructed network, moved with ``.cuda()`` when ``args.gpu`` is
        truthy.

    Raises:
        ValueError: if ``args.task`` is not ``'cifar10'`` or ``'cifar100'``.
        SystemExit: with status 1 (after printing a message) when
            ``args.net`` is not a supported name.
    """
    task_classes = {'cifar10': 10, 'cifar100': 100}
    if args.task not in task_classes:
        # Previously nclass was left unbound here, causing an
        # UnboundLocalError later inside whichever branch used it.
        raise ValueError('Unsupported task: {!r}'.format(args.task))
    nclass = task_classes[args.task]

    # Plain VGG variants (without batch norm).
    if args.net == 'vgg16':
        from models.vgg import vgg16
        net = vgg16(num_classes=nclass)
    elif args.net == 'vgg13':
        from models.vgg import vgg13
        net = vgg13(num_classes=nclass)
    elif args.net == 'vgg11':
        from models.vgg import vgg11
        net = vgg11(num_classes=nclass)
    elif args.net == 'vgg19':
        from models.vgg import vgg19
        net = vgg19(num_classes=nclass)

    elif args.net == 'vgg16bn':
        from models.vgg import vgg16_bn
        net = vgg16_bn(num_classes=nclass)
    elif args.net == 'vgg13bn':
        from models.vgg import vgg13_bn
        net = vgg13_bn(num_classes=nclass)
    elif args.net == 'vgg11bn':
        from models.vgg import vgg11_bn
        net = vgg11_bn(num_classes=nclass)
    elif args.net == 'vgg19bn':
        from models.vgg import vgg19_bn
        net = vgg19_bn(num_classes=nclass)

    # NOTE(review): the densenet, inception, shufflenet, squeezenet and
    # attention constructors below are called without num_classes, so they
    # use their own default regardless of args.task — confirm.
    elif args.net == 'densenet121':
        from models.densenet import densenet121
        net = densenet121()
    elif args.net == 'densenet161':
        from models.densenet import densenet161
        net = densenet161()
    elif args.net == 'densenet169':
        from models.densenet import densenet169
        net = densenet169()
    elif args.net == 'densenet201':
        from models.densenet import densenet201
        net = densenet201()
    elif args.net == 'googlenet':
        from models.googlenet import googlenet
        net = googlenet(num_classes=nclass)
    elif args.net == 'inceptionv3':
        from models.inceptionv3 import inceptionv3
        net = inceptionv3()
    elif args.net == 'inceptionv4':
        from models.inceptionv4 import inceptionv4
        net = inceptionv4()
    elif args.net == 'inceptionresnetv2':
        from models.inceptionv4 import inception_resnet_v2
        net = inception_resnet_v2()
    elif args.net == 'xception':
        from models.xception import xception
        net = xception(num_classes=nclass)
    elif args.net == 'scnet':
        from models.sphereconvnet import sphereconvnet
        net = sphereconvnet(num_classes=nclass)
    elif args.net == 'sphereresnet18':
        from models.sphereconvnet import resnet18
        net = resnet18(num_classes=nclass)
    elif args.net == 'sphereresnet32':
        from models.sphereconvnet import sphereresnet32
        net = sphereresnet32(num_classes=nclass)
    elif args.net == 'plainresnet32':
        from models.sphereconvnet import plainresnet32
        net = plainresnet32(num_classes=nclass)
    elif args.net == 'ynet18':
        from models.ynet import resnet18
        net = resnet18(num_classes=nclass)
    elif args.net == 'ynet34':
        from models.ynet import resnet34
        net = resnet34(num_classes=nclass)
    elif args.net == 'ynet50':
        from models.ynet import resnet50
        net = resnet50(num_classes=nclass)
    elif args.net == 'ynet101':
        from models.ynet import resnet101
        net = resnet101(num_classes=nclass)
    elif args.net == 'ynet152':
        from models.ynet import resnet152
        net = resnet152(num_classes=nclass)

    elif args.net == 'resnet18':
        from models.resnet import resnet18
        net = resnet18(num_classes=nclass)
    elif args.net == 'resnet34':
        from models.resnet import resnet34
        net = resnet34(num_classes=nclass)
    elif args.net == 'resnet50':
        from models.resnet import resnet50
        net = resnet50(num_classes=nclass)
    elif args.net == 'resnet101':
        from models.resnet import resnet101
        net = resnet101(num_classes=nclass)
    elif args.net == 'resnet152':
        from models.resnet import resnet152
        net = resnet152(num_classes=nclass)
    elif args.net == 'preactresnet18':
        from models.preactresnet import preactresnet18
        net = preactresnet18(num_classes=nclass)
    elif args.net == 'preactresnet34':
        from models.preactresnet import preactresnet34
        net = preactresnet34(num_classes=nclass)
    elif args.net == 'preactresnet50':
        from models.preactresnet import preactresnet50
        net = preactresnet50(num_classes=nclass)
    elif args.net == 'preactresnet101':
        from models.preactresnet import preactresnet101
        net = preactresnet101(num_classes=nclass)
    elif args.net == 'preactresnet152':
        from models.preactresnet import preactresnet152
        net = preactresnet152(num_classes=nclass)
    elif args.net == 'resnext50':
        from models.resnext import resnext50
        net = resnext50(num_classes=nclass)
    elif args.net == 'resnext101':
        from models.resnext import resnext101
        net = resnext101(num_classes=nclass)
    elif args.net == 'resnext152':
        from models.resnext import resnext152
        net = resnext152(num_classes=nclass)
    elif args.net == 'shufflenet':
        from models.shufflenet import shufflenet
        net = shufflenet()
    elif args.net == 'shufflenetv2':
        from models.shufflenetv2 import shufflenetv2
        net = shufflenetv2()
    elif args.net == 'squeezenet':
        from models.squeezenet import squeezenet
        net = squeezenet()
    elif args.net == 'mobilenet':
        from models.mobilenet import mobilenet
        net = mobilenet(num_classes=nclass)
    elif args.net == 'mobilenetv2':
        from models.mobilenetv2 import mobilenetv2
        net = mobilenetv2(num_classes=nclass)
    elif args.net == 'nasnet':
        from models.nasnet import nasnet
        net = nasnet(num_classes=nclass)
    elif args.net == 'attention56':
        from models.attention import attention56
        net = attention56()
    elif args.net == 'attention92':
        from models.attention import attention92
        net = attention92()
    elif args.net == 'seresnet18':
        from models.senet import seresnet18
        net = seresnet18(num_classes=nclass)
    elif args.net == 'seresnet34':
        from models.senet import seresnet34
        net = seresnet34(num_classes=nclass)
    elif args.net == 'seresnet50':
        from models.senet import seresnet50
        net = seresnet50(num_classes=nclass)
    elif args.net == 'seresnet101':
        from models.senet import seresnet101
        net = seresnet101(num_classes=nclass)
    elif args.net == 'seresnet152':
        from models.senet import seresnet152
        net = seresnet152(num_classes=nclass)
    else:
        print('the network name you have entered is not supported yet')
        # Non-zero status: a bare sys.exit() reports success to the shell.
        sys.exit(1)

    if args.gpu:
        net = net.cuda()

    return net
Example #7
0
def load_model(model_name='resnet50', resume='Best', start_epoch=0, cn=3,
               save_dir='saved_models/', width=32, start=8, cls_number=10,
               avg_number=1, gpus=None, kfold=1, model_times=0, train=True):
    """Build ``model_name`` for ``cn``-channel input and restore saved weights.

    Args:
        model_name: architecture key (see the dispatch below).
        resume: ``'Best'`` to pick checkpoint(s) under ``save_dir`` by the
            score embedded in their filenames, ``''`` for none, or an explicit
            checkpoint path (only honoured when ``avg_number == 1``).
        start_epoch: returned unchanged unless a single checkpoint is resumed,
            in which case it is parsed from the checkpoint filename.
        cn: number of input channels; stem convolutions are rebuilt for it.
        save_dir, width, start, kfold, model_times: build the checkpoint
            filename prefix.
        cls_number: number of output classes.
        avg_number: with ``resume='Best'``, 1 loads the best-scoring
            checkpoint and N > 1 loads the element-wise mean of the best N.
        gpus: device ids for ``DataParallel`` (default ``[0, ..., 7]``).
        train: when exactly ``False`` (and ``resume='Best'``), checkpoints
            ranked below the top ``avg_number + 2`` are deleted from disk.

    Returns:
        ``(model, start_epoch)``. If CUDA is available the model is on the GPU,
        wrapped in ``DataParallel`` when more than one device id is given.

    Raises:
        ValueError: for an unsupported ``model_name``.
    """
    if gpus is None:
        gpus = [0, 1, 2, 3, 4, 5, 6, 7]

    # ImageNet weights are disabled (None); an earlier variant used
    # `True if cn == 3 else None`.
    load_dict = None

    if model_name == 'resnet50':
        model = resnet50(num_classes=cls_number, pretrained=load_dict)
    elif model_name == 'resnet101':
        model = resnet101(num_classes=cls_number, pretrained=load_dict)
    elif model_name == 'resnet152':
        model = resnet152(num_classes=cls_number, pretrained=load_dict)
    elif model_name == 'densenet161':
        model = densenet161(num_classes=cls_number, pretrained=load_dict)
    elif model_name == 'xception':
        model = xception(num_classes=cls_number, pretrained=load_dict)
        model.conv1 = nn.Conv2d(cn, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)
    elif model_name == 'inception_v3':
        model = inception_v3(num_classes=cls_number, pretrained=load_dict)
        model.Conv2d_1a_3x3.conv = nn.Conv2d(cn, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)
    elif model_name == 'seinception_v3':
        model = se_inception_v3(num_classes=cls_number)
        model.model.Conv2d_1a_3x3.conv = nn.Conv2d(cn, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)
    elif model_name == 'inception_v4':
        model = inceptionv4(num_classes=cls_number, pretrained=load_dict)
        model.features[0].conv = nn.Conv2d(cn, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)
    elif model_name == 'inceptionresnetv2':
        model = inceptionresnetv2(num_classes=cls_number, pretrained=load_dict)
        model.conv2d_1a.conv = nn.Conv2d(cn, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)
    elif model_name == 'seresnet50':
        model = se_resnet50(num_classes=cls_number, pretrained=load_dict)
    elif model_name == 'seresnet101':
        model = se_resnet101(num_classes=cls_number, pretrained=load_dict)
    elif model_name == 'seresnet152':
        model = se_resnet152(num_classes=cls_number, pretrained=load_dict)
    elif model_name == 'seresnext50':
        model = se_resnext50_32x4d(num_classes=cls_number, pretrained=load_dict)
    elif model_name == 'seresnext101':
        model = se_resnext101_32x4d(num_classes=cls_number, pretrained=load_dict)
    elif model_name == 'resnet50-101':
        model = SimpleNet()
    elif model_name == 'seresnet20':
        model = se_resnet20(num_classes=cls_number)
    elif model_name == 'seresnet32':
        model = se_resnet32(num_classes=cls_number)
    elif model_name == 'seresnet18':
        model = se_resnet18(num_classes=cls_number)
    elif model_name == 'seresnet34':
        model = se_resnet34(num_classes=cls_number)
    elif model_name == 'senet154':
        model = senet154(num_classes=cls_number, pretrained=load_dict)
        model.layer0.conv1 = nn.Conv2d(cn, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    elif model_name == 'nasnet':
        model = nasnetalarge(num_classes=cls_number, pretrained=load_dict)
        model.conv0.conv = nn.Conv2d(cn, 96, kernel_size=(3, 3), stride=(2, 2), bias=False)
    elif model_name == 'dpn98':
        model = dpn98(num_classes=cls_number, pretrained=load_dict)
    elif model_name == 'dpn107':
        model = dpn107(num_classes=cls_number, pretrained=load_dict)
    elif model_name == 'dpn92':
        model = dpn92(num_classes=cls_number, pretrained=load_dict)
    elif model_name == 'polynet':
        model = polynet(num_classes=cls_number, pretrained=load_dict)
        model.stem.conv1[0].conv = nn.Conv2d(cn, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)
    elif model_name == 'pnasnet':
        model = pnasnet5large(num_classes=cls_number, pretrained=load_dict)
        model.conv_0.conv = nn.Conv2d(cn, 96, kernel_size=(3, 3), stride=(2, 2), bias=False)
    else:
        # Previously fell through to an UnboundLocalError on `model`.
        raise ValueError('Unsupported model_name: {!r}'.format(model_name))

    # Rebuild the stem convolution for `cn` input channels and make the final
    # pooling size-agnostic. Names containing '-' (e.g. 'resnet50-101') skip this.
    if '-' not in model_name and load_dict is not True:
        if model_name in ['dpn98']:
            model.features.conv1_1.conv = nn.Conv2d(cn, 96, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
        elif model_name in ['dpn92']:
            model.features.conv1_1.conv = nn.Conv2d(cn, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
        elif model_name in ['seresnet20', 'seresnet32']:
            model.conv1 = nn.Conv2d(cn, 16, kernel_size=3, stride=1, padding=1, bias=False)
        elif model_name in ['seresnet18', 'seresnet34']:
            model.conv1 = nn.Conv2d(cn, 64, kernel_size=7, stride=2, padding=3, bias=False)
        elif 'seresnext' in model_name:
            model.layer0.conv1 = nn.Conv2d(cn, 64, kernel_size=7, stride=2, padding=3, bias=False)
        elif 'seresnet' in model_name:
            model.layer0.conv1 = nn.Conv2d(cn, 64, kernel_size=7, stride=2, padding=3, bias=False)
        elif 'resnet' in model_name:
            model.conv1 = nn.Conv2d(cn, 64, kernel_size=7, stride=2, padding=3, bias=False)
        elif 'densenet' in model_name:
            model.features.conv0 = nn.Conv2d(cn, 96, kernel_size=7, stride=2, padding=3, bias=False)

        model.avgpool = torch.nn.AdaptiveAvgPool2d(output_size=1)

    model_ = model_name + '_' + str(width) + '_' + str(start) + '_' + str(cn)
    if kfold > 1:
        model_prefix = save_dir + str(model_times) + '_' + model_
    else:
        model_prefix = save_dir + model_

    if resume == 'Best' and avg_number >= 1:
        weight_path = glob(model_prefix + '*pth')
        # Each checkpoint filename carries a score between '[' and ']';
        # cur_index lists checkpoints from highest to lowest score.
        scores = [float(p.split('/')[-1].split('[')[-1].split(']')[0])
                  for p in weight_path]
        cur_index = np.argsort(-np.array(scores))
        if len(weight_path) == 0:
            resume = ''
        elif avg_number == 1:
            # glob() order is arbitrary, so weight_path[0] was not
            # necessarily the best checkpoint; use the top-ranked one.
            resume = weight_path[cur_index[0]]
        else:
            # Mean of the top checkpoints (fewer if not enough exist).
            n_avg = min(avg_number, len(weight_path))
            new_state_dict = OrderedDict()
            for cnt, index in enumerate(cur_index[:n_avg]):
                cur_resume = weight_path[index]
                print(cur_resume)
                model.load_state_dict(torch.load(cur_resume))
                for k, v in model.state_dict().items():
                    if cnt == 0:
                        # clone(): state_dict() tensors share storage with
                        # the live parameters, which the next
                        # load_state_dict overwrites in place.
                        new_state_dict[k] = v.clone()
                    else:
                        new_state_dict[k] = new_state_dict[k] + v
            for k in new_state_dict:
                new_state_dict[k] = new_state_dict[k] / float(n_avg)
            model.load_state_dict(new_state_dict)
        # Deletion is destructive, so only trigger on an explicit False.
        if train == False:  # noqa: E712
            for index in cur_index[avg_number + 2:]:
                cur_resume = weight_path[index]
                print('remove resume %s ' % cur_resume)
                os.remove(cur_resume)
    if resume != '' and avg_number == 1:
        start_epoch = int(resume.split('-')[-3])
        logging.info('resuming finetune from %s', resume)
        model.load_state_dict(torch.load(resume))

    print('start-epoch : ', start_epoch)

    if torch.cuda.is_available():
        print('cuda_avail: True')
        if len(gpus) > 1:
            model = torch.nn.DataParallel(model, device_ids=gpus).cuda()
        else:
            model = model.cuda()
    return model, start_epoch
Example #8
0
def get_network(args):
    """Return a freshly constructed model named by ``args.model``.

    Each supported name maps to a ``(module, factory)`` pair inside the
    project's ``models`` package. Only the selected module is imported, and
    its factory is called with no arguments.

    Raises:
        SystemExit: with status 1 (after printing a message) when
            ``args.model`` is not a supported name.
    """
    import importlib

    registry = {
        'vgg16': ('models.vgg', 'vgg16_bn'),
        'vgg13': ('models.vgg', 'vgg13_bn'),
        'vgg11': ('models.vgg', 'vgg11_bn'),
        'vgg19': ('models.vgg', 'vgg19_bn'),
        'densenet121': ('models.densenet', 'densenet121'),
        'densenet161': ('models.densenet', 'densenet161'),
        'densenet169': ('models.densenet', 'densenet169'),
        'densenet201': ('models.densenet', 'densenet201'),
        'googlenet': ('models.googlenet', 'googlenet'),
        'inceptionv3': ('models.inceptionv3', 'inceptionv3'),
        'inceptionv4': ('models.inceptionv4', 'inceptionv4'),
        'inceptionresnetv2': ('models.inceptionv4', 'inception_resnet_v2'),
        'xception': ('models.xception', 'xception'),
        'resnet18': ('models.resnet', 'resnet18'),
        'resnet34': ('models.resnet', 'resnet34'),
        'resnet50': ('models.resnet', 'resnet50'),
        'resnet101': ('models.resnet', 'resnet101'),
        'resnet152': ('models.resnet', 'resnet152'),
        'preactresnet18': ('models.preactresnet', 'preactresnet18'),
        'preactresnet34': ('models.preactresnet', 'preactresnet34'),
        'preactresnet50': ('models.preactresnet', 'preactresnet50'),
        'preactresnet101': ('models.preactresnet', 'preactresnet101'),
        'preactresnet152': ('models.preactresnet', 'preactresnet152'),
        'resnext50': ('models.resnext', 'resnext50'),
        'resnext101': ('models.resnext', 'resnext101'),
        'resnext152': ('models.resnext', 'resnext152'),
        'shufflenet': ('models.shufflenet', 'shufflenet'),
        'shufflenetv2': ('models.shufflenetv2', 'shufflenetv2'),
        'squeezenet': ('models.squeezenet', 'squeezenet'),
        'mobilenet': ('models.mobilenet', 'mobilenet'),
        'mobilenetv2': ('models.mobilenetv2', 'mobilenetv2'),
        'nasnet': ('models.nasnet', 'nasnet'),
        'attention56': ('models.attention', 'attention56'),
        'attention92': ('models.attention', 'attention92'),
        'seresnet18': ('models.senet', 'seresnet18'),
        'seresnet34': ('models.senet', 'seresnet34'),
        'seresnet50': ('models.senet', 'seresnet50'),
        'seresnet101': ('models.senet', 'seresnet101'),
        'seresnet152': ('models.senet', 'seresnet152'),
        'wideresnet': ('models.wideresidual', 'wideresnet'),
        'stochasticdepth18': ('models.stochasticdepth', 'stochastic_depth_resnet18'),
        'stochasticdepth34': ('models.stochasticdepth', 'stochastic_depth_resnet34'),
        'stochasticdepth50': ('models.stochasticdepth', 'stochastic_depth_resnet50'),
        'stochasticdepth101': ('models.stochasticdepth', 'stochastic_depth_resnet101'),
    }

    try:
        module_name, factory_name = registry[args.model]
    except KeyError:
        print('the network name you have entered is not supported yet')
        # Non-zero status: a bare sys.exit() reports success to the shell.
        sys.exit(1)

    factory = getattr(importlib.import_module(module_name), factory_name)
    return factory()
Example #9
0
    # Every sub-directory of data_path counts as one class.
    cls = [
        x for x in os.listdir(data_path)
        if os.path.isdir(os.path.join(data_path, x))
    ]
    num_class = len(cls)
    # Zero-argument factories: each network is built only when it is its turn
    # to be trained. The original instantiated all of them up front and kept
    # them in memory for the whole sweep.
    models = {
        "vgg16": lambda: vgg.vgg16_bn(num_class),
        "vgg19": lambda: vgg.vgg19_bn(num_class),
        "densenet121": lambda: densenet.densenet121(num_class),
        "densenet161": lambda: densenet.densenet161(num_class),
        "resnet34": lambda: resnet.resnet34(num_class),
        "resnet50": lambda: resnet.resnet50(num_class),
        "resnet101": lambda: resnet.resnet101(num_class),
        "seresnet34": lambda: senet.seresnet34(num_class),
        "seresnet50": lambda: senet.seresnet50(num_class),
        "seresnet101": lambda: senet.seresnet101(num_class),
        "resnext34": lambda: resnext.resnext34(num_class),
        "resnext50": lambda: resnext.resnext50(num_class),
        "resnext101": lambda: resnext.resnext101(num_class),
        "shufflenet": lambda: shufflenet.shufflenet(num_class),
        "xception": lambda: xception.xception(num_class),
    }
    for net_name, build_model in models.items():
        writer = SummaryWriter('./runs/%s_%s/' % (ds, net_name))
        model = build_model()
        # logger.add returns a handler id. Removing that sink after the run
        # stops later runs from also writing into this run's log file.
        log_handler = logger.add('./log/%s_%s_{time}.log' % (ds, net_name),
                                 level="INFO")
        logger.info("net:%s\t dataset:%s\t num_class:%d" %
                    (net_name, ds, num_class))
        try:
            # NOTE(review): train() takes no arguments; presumably it reads the
            # `model` / `writer` names bound above — confirm.
            train()
        finally:
            writer.close()
            logger.remove(log_handler)
Example #10
0
def get_network(args, use_gpu=True):
    """Return the network named by ``args.net``.

    Args:
        args: Namespace-like object. Attributes read:
            ``net`` -- architecture name;
            ``num_classes`` -- output size of the torchvision / EfficientNet
            heads, also forwarded to the ``models.resnet`` /
            ``models.resnext`` / ``models.senet`` builders;
            ``bool_pretrained`` -- pretrained flag for the torchvision and
            EfficientNet branches;
            ``pretrained`` -- forwarded to the ``models.resnet`` /
            ``models.resnext`` / ``models.senet`` builders.
        use_gpu: If true, move the network to the GPU with ``.cuda()``.

    Returns:
        The constructed network.

    Raises:
        SystemExit: With exit status 1 when ``args.net`` is not supported.
    """
    import importlib

    name = args.net

    # Builders from the local ``models`` package called with no arguments:
    # network name -> (module path, factory name).
    no_arg_builders = {
        # NOTE(review): unlike densenet161/169/201 below, densenet121 comes
        # from the local package and receives neither num_classes nor the
        # pretrained flag -- confirm this is intended.
        'densenet121': ('models.densenet', 'densenet121'),
        'googlenet': ('models.googlenet', 'googlenet'),
        'inceptionv3': ('models.inceptionv3', 'inceptionv3'),
        'inceptionv4': ('models.inceptionv4', 'inceptionv4'),
        'inceptionresnetv2': ('models.inceptionv4', 'inception_resnet_v2'),
        'xception': ('models.xception', 'xception'),
        'preactresnet18': ('models.preactresnet', 'preactresnet18'),
        'preactresnet34': ('models.preactresnet', 'preactresnet34'),
        'preactresnet50': ('models.preactresnet', 'preactresnet50'),
        'preactresnet101': ('models.preactresnet', 'preactresnet101'),
        'preactresnet152': ('models.preactresnet', 'preactresnet152'),
        'shufflenet': ('models.shufflenet', 'shufflenet'),
        'shufflenetv2': ('models.shufflenetv2', 'shufflenetv2'),
        'squeezenet': ('models.squeezenet', 'squeezenet'),
        'mobilenet': ('models.mobilenet', 'mobilenet'),
        'mobilenetv2': ('models.mobilenetv2', 'mobilenetv2'),
        'nasnet': ('models.nasnet', 'nasnet'),
        'attention56': ('models.attention', 'attention56'),
        'attention92': ('models.attention', 'attention92'),
    }
    # Builders called as ``factory(args.num_classes, 2, args.pretrained, name)``.
    # The literal 2 is passed through unchanged; its meaning is defined by
    # the builders themselves.
    configurable_builders = {
        'resnet18': ('models.resnet', 'resnet'),
        'resnet34': ('models.resnet', 'resnet'),
        'resnet50': ('models.resnet', 'resnet'),
        'resnet101': ('models.resnet', 'resnet'),
        'resnet152': ('models.resnet', 'resnet'),
        'se_resnext50': ('models.resnext', 'se_resnext'),
        'se_resnext101': ('models.resnext', 'se_resnext'),
        'se_resnet50': ('models.senet', 'seresnet'),
        'se_resnet101': ('models.senet', 'seresnet'),
        'se_resnet152': ('models.senet', 'seresnet'),
    }

    if name in ('vgg11', 'vgg13', 'vgg16', 'vgg19'):
        # torchvision VGG with batch norm (vgg16 -> vgg16_bn, ...). Replace
        # the last classifier layer so the output size is args.num_classes.
        net = getattr(torchvision.models, name + '_bn')(
            pretrained=args.bool_pretrained)
        in_features = net.classifier[6].in_features
        net.classifier[6] = nn.Linear(in_features, args.num_classes, bias=True)
    elif name == 'efficientnet-b5':
        from efficientnet_pytorch import EfficientNet
        if args.bool_pretrained:
            net = EfficientNet.from_pretrained('efficientnet-b5')
        else:
            net = EfficientNet.from_name('efficientnet-b5')
        # Take the head's input width from the existing layer rather than
        # hard-coding it.
        net._fc = nn.Linear(net._fc.in_features, args.num_classes, bias=True)
    elif name in ('densenet161', 'densenet169', 'densenet201'):
        net = getattr(torchvision.models, name)(pretrained=args.bool_pretrained)
        in_features = net.classifier.in_features
        net.classifier = nn.Linear(in_features, args.num_classes, bias=True)
    elif name in no_arg_builders:
        module_name, factory_name = no_arg_builders[name]
        factory = getattr(importlib.import_module(module_name), factory_name)
        net = factory()
    elif name in configurable_builders:
        module_name, factory_name = configurable_builders[name]
        factory = getattr(importlib.import_module(module_name), factory_name)
        net = factory(args.num_classes, 2, args.pretrained, name)
    else:
        # Report on stderr and exit non-zero so calling scripts see the
        # failure; a bare sys.exit() reports success.
        print('the network name you have entered is not supported yet',
              file=sys.stderr)
        sys.exit(1)

    if use_gpu:
        net = net.cuda()

    return net
Example #11
0
def get_network(args, use_gpu=True, num_train=0):
    """Return the network named by ``args.net`` for a CIFAR dataset.

    Args:
        args: Namespace-like object. Attributes read: ``dataset``
            ('cifar-10' -> 10 classes, 'cifar-100' -> 100, anything else
            -> 0), ``net``, ``ignoring``, and -- only when ``ignoring`` is
            true -- ``softmax`` and ``isalpha``.
        use_gpu: If true, move the network to the GPU with ``.cuda()``.
        num_train: Forwarded to ``resnet18_ign`` when ``args.ignoring`` is set.

    Returns:
        The constructed network.

    Raises:
        SystemExit: With exit status 1 when ``args.net`` is not supported in
            the selected mode (with ``args.ignoring`` only 'resnet18' is).
    """
    import importlib

    if args.dataset == 'cifar-10':
        num_classes = 10
    elif args.dataset == 'cifar-100':
        num_classes = 100
    else:
        # NOTE(review): other datasets get a 0-class head; confirm intended.
        num_classes = 0

    # ``models.resnet`` builders that accept ``num_classes``.
    resnet_names = ('resnet18', 'resnet34', 'resnet50', 'resnet101',
                    'resnet152')
    # Builders from the local ``models`` package called with no arguments:
    # network name -> (module path, factory name).
    no_arg_builders = {
        'vgg16': ('models.vgg', 'vgg16_bn'),
        'vgg13': ('models.vgg', 'vgg13_bn'),
        'vgg11': ('models.vgg', 'vgg11_bn'),
        'vgg19': ('models.vgg', 'vgg19_bn'),
        'densenet121': ('models.densenet', 'densenet121'),
        'densenet161': ('models.densenet', 'densenet161'),
        'densenet169': ('models.densenet', 'densenet169'),
        'densenet201': ('models.densenet', 'densenet201'),
        'googlenet': ('models.googlenet', 'googlenet'),
        'inceptionv3': ('models.inceptionv3', 'inceptionv3'),
        'inceptionv4': ('models.inceptionv4', 'inceptionv4'),
        'inceptionresnetv2': ('models.inceptionv4', 'inception_resnet_v2'),
        'xception': ('models.xception', 'xception'),
        'preactresnet18': ('models.preactresnet', 'preactresnet18'),
        'preactresnet34': ('models.preactresnet', 'preactresnet34'),
        'preactresnet50': ('models.preactresnet', 'preactresnet50'),
        'preactresnet101': ('models.preactresnet', 'preactresnet101'),
        'preactresnet152': ('models.preactresnet', 'preactresnet152'),
        'resnext50': ('models.resnext', 'resnext50'),
        'resnext101': ('models.resnext', 'resnext101'),
        'resnext152': ('models.resnext', 'resnext152'),
        'shufflenet': ('models.shufflenet', 'shufflenet'),
        'shufflenetv2': ('models.shufflenetv2', 'shufflenetv2'),
        'squeezenet': ('models.squeezenet', 'squeezenet'),
        'mobilenet': ('models.mobilenet', 'mobilenet'),
        'mobilenetv2': ('models.mobilenetv2', 'mobilenetv2'),
        'nasnet': ('models.nasnet', 'nasnet'),
        'attention56': ('models.attention', 'attention56'),
        'attention92': ('models.attention', 'attention92'),
        'seresnet18': ('models.senet', 'seresnet18'),
        'seresnet34': ('models.senet', 'seresnet34'),
        'seresnet50': ('models.senet', 'seresnet50'),
        'seresnet101': ('models.senet', 'seresnet101'),
        'seresnet152': ('models.senet', 'seresnet152'),
    }

    net = None
    if args.ignoring:
        # Only resnet18 has an ``ignoring`` variant. It receives an
        # unreduced (per-sample) cross-entropy criterion.
        if args.net == 'resnet18':
            from models.resnet_ign import resnet18_ign
            criterion = nn.CrossEntropyLoss(reduction='none')
            net = resnet18_ign(criterion, num_classes=num_classes,
                               num_train=num_train, softmax=args.softmax,
                               isalpha=args.isalpha)
    elif args.net in resnet_names:
        resnet_module = importlib.import_module('models.resnet')
        net = getattr(resnet_module, args.net)(num_classes=num_classes)
    elif args.net in no_arg_builders:
        module_name, factory_name = no_arg_builders[args.net]
        net = getattr(importlib.import_module(module_name), factory_name)()

    if net is None:
        # Unsupported name in either mode. Exit non-zero rather than
        # falling through to an unbound ``net`` below.
        print('the network name you have entered is not supported yet',
              file=sys.stderr)
        sys.exit(1)

    if use_gpu:
        net = net.cuda()

    return net
Example #12
0
File: utils.py Project: nblt/DLDR
def get_model(args):
    """Build the model selected by ``args.arch`` for ``args.datasets``.

    * 'ImageNet': ``models_imagenet.<arch>()``.
    * 'CIFAR10' / 'MNIST': ``resnet.<arch>(num_classes=10)``.
    * 'CIFAR100': a builder from the local ``models`` package, called with
      its default arguments, when ``args.arch`` names one. Otherwise
      ``resnet.<arch>(num_classes=100)``.

    Args:
        args: Namespace-like object with ``datasets`` and ``arch`` attributes.

    Returns:
        The constructed model.

    Raises:
        ValueError: If ``args.datasets`` is not one of the names above.
        KeyError: If ``args.arch`` is not found in ``models_imagenet`` or
            ``resnet`` (whichever module is used for the lookup).
    """
    import importlib

    if args.datasets == 'ImageNet':
        return models_imagenet.__dict__[args.arch]()

    if args.datasets in ('CIFAR10', 'MNIST'):
        num_class = 10
    elif args.datasets == 'CIFAR100':
        num_class = 100
    else:
        # Fail with a clear message instead of an UnboundLocalError on
        # ``num_class`` further down.
        raise ValueError(f'unsupported dataset: {args.datasets!r}')

    if args.datasets == 'CIFAR100':
        if args.arch == 'efficientnet':
            from models.efficientnet import efficientnet
            return efficientnet(1, 1, 100, bn_momentum=0.9)

        # Builders from the local ``models`` package called with no
        # arguments: architecture -> (module path, factory name).
        cifar100_builders = {
            'vgg16': ('models.vgg', 'vgg16_bn'),
            'vgg13': ('models.vgg', 'vgg13_bn'),
            'vgg11': ('models.vgg', 'vgg11_bn'),
            'vgg19': ('models.vgg', 'vgg19_bn'),
            'densenet121': ('models.densenet', 'densenet121'),
            'densenet161': ('models.densenet', 'densenet161'),
            'densenet169': ('models.densenet', 'densenet169'),
            'densenet201': ('models.densenet', 'densenet201'),
            'googlenet': ('models.googlenet', 'googlenet'),
            'inceptionv3': ('models.inceptionv3', 'inceptionv3'),
            'inceptionv4': ('models.inceptionv4', 'inceptionv4'),
            'inceptionresnetv2': ('models.inceptionv4', 'inception_resnet_v2'),
            'xception': ('models.xception', 'xception'),
            'resnet18': ('models.resnet', 'resnet18'),
            'resnet34': ('models.resnet', 'resnet34'),
            'resnet50': ('models.resnet', 'resnet50'),
            'resnet101': ('models.resnet', 'resnet101'),
            'resnet152': ('models.resnet', 'resnet152'),
            'preactresnet18': ('models.preactresnet', 'preactresnet18'),
            'preactresnet34': ('models.preactresnet', 'preactresnet34'),
            'preactresnet50': ('models.preactresnet', 'preactresnet50'),
            'preactresnet101': ('models.preactresnet', 'preactresnet101'),
            'preactresnet152': ('models.preactresnet', 'preactresnet152'),
            'resnext50': ('models.resnext', 'resnext50'),
            'resnext101': ('models.resnext', 'resnext101'),
            'resnext152': ('models.resnext', 'resnext152'),
            'shufflenet': ('models.shufflenet', 'shufflenet'),
            'shufflenetv2': ('models.shufflenetv2', 'shufflenetv2'),
            'squeezenet': ('models.squeezenet', 'squeezenet'),
            'mobilenet': ('models.mobilenet', 'mobilenet'),
            'mobilenetv2': ('models.mobilenetv2', 'mobilenetv2'),
            'nasnet': ('models.nasnet', 'nasnet'),
            'attention56': ('models.attention', 'attention56'),
            'attention92': ('models.attention', 'attention92'),
            'seresnet18': ('models.senet', 'seresnet18'),
            'seresnet34': ('models.senet', 'seresnet34'),
            'seresnet50': ('models.senet', 'seresnet50'),
            'seresnet101': ('models.senet', 'seresnet101'),
            'seresnet152': ('models.senet', 'seresnet152'),
            'wideresnet': ('models.wideresidual', 'wideresnet'),
            'stochasticdepth18': ('models.stochasticdepth',
                                  'stochastic_depth_resnet18'),
            'stochasticdepth34': ('models.stochasticdepth',
                                  'stochastic_depth_resnet34'),
            'stochasticdepth50': ('models.stochasticdepth',
                                  'stochastic_depth_resnet50'),
            'stochasticdepth101': ('models.stochasticdepth',
                                   'stochastic_depth_resnet101'),
        }
        if args.arch in cifar100_builders:
            module_name, factory_name = cifar100_builders[args.arch]
            factory = getattr(importlib.import_module(module_name),
                              factory_name)
            return factory()

    # CIFAR10 / MNIST, and CIFAR100 architectures not listed above.
    return resnet.__dict__[args.arch](num_classes=num_class)
Example #13
0
def get_network(args):
    """Return the network named by ``args.net``.

    Args:
        args: Namespace-like object. ``args.net`` selects the architecture.
            When ``args.gpu`` is true, the network is moved to the GPU with
            ``.cuda()`` and "use-gpu" is printed.

    Returns:
        The constructed network.

    Raises:
        SystemExit: With exit status 1 when ``args.net`` is not supported.
    """
    import importlib

    # Builders from the local ``models`` package called with no arguments:
    # network name -> (module path, factory name).
    builders = {
        'vgg16': ('models.vgg', 'vgg16_bn'),
        'vgg13': ('models.vgg', 'vgg13_bn'),
        'vgg11': ('models.vgg', 'vgg11_bn'),
        'vgg19': ('models.vgg', 'vgg19_bn'),
        'densenet121': ('models.densenet', 'densenet121'),
        'densenet161': ('models.densenet', 'densenet161'),
        'densenet169': ('models.densenet', 'densenet169'),
        'densenet201': ('models.densenet', 'densenet201'),
        'googlenet': ('models.googlenet', 'googlenet'),
        'inceptionv3': ('models.inceptionv3', 'inceptionv3'),
        'inceptionv4': ('models.inceptionv4', 'inceptionv4'),
        'inceptionresnetv2': ('models.inceptionv4', 'inception_resnet_v2'),
        'xception': ('models.xception', 'xception'),
        'resnet18': ('models.resnet', 'resnet18'),
        'resnet34': ('models.resnet', 'resnet34'),
        'resnet50': ('models.resnet', 'resnet50'),
        'resnet101': ('models.resnet', 'resnet101'),
        'resnet152': ('models.resnet', 'resnet152'),
        'preactresnet18': ('models.preactresnet', 'preactresnet18'),
        'preactresnet34': ('models.preactresnet', 'preactresnet34'),
        'preactresnet50': ('models.preactresnet', 'preactresnet50'),
        'preactresnet101': ('models.preactresnet', 'preactresnet101'),
        'preactresnet152': ('models.preactresnet', 'preactresnet152'),
        'resnext50': ('models.resnext', 'resnext50'),
        'resnext101': ('models.resnext', 'resnext101'),
        'resnext152': ('models.resnext', 'resnext152'),
        'shufflenet': ('models.shufflenet', 'shufflenet'),
        'shufflenetv2': ('models.shufflenetv2', 'shufflenetv2'),
        'squeezenet': ('models.squeezenet', 'squeezenet'),
        'mobilenet': ('models.mobilenet', 'mobilenet'),
        'mobilenetv2': ('models.mobilenetv2', 'mobilenetv2'),
        'nasnet': ('models.nasnet', 'nasnet'),
        'attention56': ('models.attention', 'attention56'),
        'attention92': ('models.attention', 'attention92'),
        'seresnet18': ('models.senet', 'seresnet18'),
        'seresnet34': ('models.senet', 'seresnet34'),
        'seresnet50': ('models.senet', 'seresnet50'),
        'seresnet101': ('models.senet', 'seresnet101'),
        'seresnet152': ('models.senet', 'seresnet152'),
        'wideresnet': ('models.wideresidual', 'wideresnet'),
        'stochasticdepth18': ('models.stochasticdepth',
                              'stochastic_depth_resnet18'),
        'stochasticdepth34': ('models.stochasticdepth',
                              'stochastic_depth_resnet34'),
        'stochasticdepth50': ('models.stochasticdepth',
                              'stochastic_depth_resnet50'),
        'stochasticdepth101': ('models.stochasticdepth',
                               'stochastic_depth_resnet101'),
        'efficientnetb0': ('models.efficientnet', 'efficientnetb0'),
        'efficientnetb1': ('models.efficientnet', 'efficientnetb1'),
        'efficientnetb2': ('models.efficientnet', 'efficientnetb2'),
        'efficientnetb3': ('models.efficientnet', 'efficientnetb3'),
        'efficientnetb4': ('models.efficientnet', 'efficientnetb4'),
        'efficientnetb5': ('models.efficientnet', 'efficientnetb5'),
        'efficientnetb6': ('models.efficientnet', 'efficientnetb6'),
        'efficientnetb7': ('models.efficientnet', 'efficientnetb7'),
        'efficientnetl2': ('models.efficientnet', 'efficientnetl2'),
    }

    if args.net == 'eff':
        # NOTE(review): the class count is hard-coded to 2 here rather than
        # taken from args -- confirm it matches the dataset.
        from models.efficientnet_pytorch import EfficientNet
        net = EfficientNet.from_pretrained('efficientnet-b7', num_classes=2)
    elif args.net in builders:
        module_name, factory_name = builders[args.net]
        net = getattr(importlib.import_module(module_name), factory_name)()
    else:
        # Report on stderr and exit non-zero so calling scripts see the
        # failure; a bare sys.exit() reports success.
        print('the network name you have entered is not supported yet',
              file=sys.stderr)
        sys.exit(1)

    if args.gpu:  # use_gpu
        net = net.cuda()
        print("use-gpu")

    return net
def initialize_model(model_name,
                     num_classes,
                     feature_extract,
                     use_pretrained=True):
    """Build a classifier and return it with the input size it expects.

    For every architecture except the Fleming models, the backbone is
    created with ``pretrained=use_pretrained``, ``set_parameter_requires_grad``
    is applied with ``feature_extract``, and the final classification
    layer(s) are replaced with new ones producing ``num_classes`` outputs.

    Args:
        model_name: One of "resnet" (ResNet-18), "resnet152", "alexnet",
            "vgg" (VGG-11 with batch norm), "squeezenet" (SqueezeNet 1.0),
            "densenet" (DenseNet-121), "inception" (Inception v3),
            "xception", "fleming_v1", "fleming_v2" or "fleming_v3".
        num_classes: Number of outputs of the replaced classification layer.
            Not used by the Fleming models.
        feature_extract: Passed to ``set_parameter_requires_grad``.
        use_pretrained: Passed as ``pretrained=`` to the backbone constructor.

    Returns:
        ``(model, input_size)``. ``input_size`` is 299 for "inception" and
        "xception" and 224 otherwise.

    Raises:
        SystemExit: With exit status 1 for an unknown ``model_name``.
    """
    import sys

    # Certificate verification is disabled to avoid download errors "for
    # some users" (presumably when fetching pretrained weights). The setting
    # is process-wide and lets downloaded files be tampered with in transit,
    # so the previous value is restored once the model is built instead of
    # leaving verification off for the rest of the process.
    saved_https_context = ssl._create_default_https_context
    ssl._create_default_https_context = ssl._create_unverified_context
    try:
        if model_name == "resnet":
            # ResNet-18
            model_ft = models.resnet18(pretrained=use_pretrained)
            set_parameter_requires_grad(model_ft, feature_extract)
            num_ftrs = model_ft.fc.in_features
            model_ft.fc = nn.Linear(num_ftrs, num_classes)
            input_size = 224

        elif model_name == "resnet152":
            # ResNet-152
            model_ft = models.resnet152(pretrained=use_pretrained)
            set_parameter_requires_grad(model_ft, feature_extract)
            num_ftrs = model_ft.fc.in_features
            model_ft.fc = nn.Linear(num_ftrs, num_classes)
            input_size = 224

        elif model_name == "alexnet":
            # AlexNet: the last classifier layer is classifier[6].
            model_ft = models.alexnet(pretrained=use_pretrained)
            set_parameter_requires_grad(model_ft, feature_extract)
            num_ftrs = model_ft.classifier[6].in_features
            model_ft.classifier[6] = nn.Linear(num_ftrs, num_classes)
            input_size = 224

        elif model_name == "vgg":
            # VGG-11 with batch norm: the last classifier layer is
            # classifier[6].
            model_ft = models.vgg11_bn(pretrained=use_pretrained)
            set_parameter_requires_grad(model_ft, feature_extract)
            num_ftrs = model_ft.classifier[6].in_features
            model_ft.classifier[6] = nn.Linear(num_ftrs, num_classes)
            input_size = 224

        elif model_name == "squeezenet":
            # SqueezeNet 1.0: the classifier is a 1x1 convolution over 512
            # channels rather than a Linear layer, so it is replaced with a
            # new Conv2d, and model_ft.num_classes is updated to match.
            model_ft = models.squeezenet1_0(pretrained=use_pretrained)
            set_parameter_requires_grad(model_ft, feature_extract)
            model_ft.classifier[1] = nn.Conv2d(512,
                                               num_classes,
                                               kernel_size=(1, 1),
                                               stride=(1, 1))
            model_ft.num_classes = num_classes
            input_size = 224

        elif model_name == "densenet":
            # DenseNet-121
            model_ft = models.densenet121(pretrained=use_pretrained)
            set_parameter_requires_grad(model_ft, feature_extract)
            num_ftrs = model_ft.classifier.in_features
            model_ft.classifier = nn.Linear(num_ftrs, num_classes)
            input_size = 224

        elif model_name == "inception":
            # Inception v3 expects 299x299 inputs and has an auxiliary
            # classifier (AuxLogits) whose head is replaced as well.
            model_ft = models.inception_v3(pretrained=use_pretrained)
            set_parameter_requires_grad(model_ft, feature_extract)
            num_ftrs = model_ft.AuxLogits.fc.in_features
            model_ft.AuxLogits.fc = nn.Linear(num_ftrs, num_classes)
            num_ftrs = model_ft.fc.in_features
            model_ft.fc = nn.Linear(num_ftrs, num_classes)
            input_size = 299

        elif model_name == "xception":
            # Xception expects 299x299 inputs; only the final fc layer is
            # replaced.
            model_ft = xception.xception(pretrained=use_pretrained)
            set_parameter_requires_grad(model_ft, feature_extract)
            num_ftrs = model_ft.fc.in_features
            model_ft.fc = nn.Linear(num_ftrs, num_classes)
            input_size = 299

        # Custom models created by team Fleming.
        # NOTE(review): these are built with a hard-coded 196 classes and
        # ignore num_classes, feature_extract and use_pretrained -- confirm
        # this is intended.
        elif model_name == "fleming_v1":
            model_ft = fleming.FlemingModel_v1(num_classes=196)
            input_size = 224

        elif model_name == "fleming_v2":
            model_ft = fleming.FlemingModel_v2(num_classes=196)
            input_size = 224

        elif model_name == "fleming_v3":
            model_ft = fleming.FlemingModel_v3(num_classes=196)
            input_size = 224

        else:
            print("Invalid model name, exiting...")
            # sys.exit is always available (the exit() builtin is injected
            # by the site module and may be missing); exit non-zero because
            # this is an error.
            sys.exit(1)
    finally:
        ssl._create_default_https_context = saved_https_context

    return model_ft, input_size
Example #15
0
def get_network(args):
    """Return the network named by ``args.net``.

    Args:
        args: Namespace-like object. ``args.net`` selects the architecture.
            ``args.dict_size`` is read for the 'hyper_resnet' and
            'hyper_resnet_wo_bn' variants. When ``args.gpu`` is true, the
            network is moved to the GPU with ``.cuda()``.

    Returns:
        The constructed network.

    Raises:
        SystemExit: With exit status 1 when ``args.net`` is not supported.
    """
    import importlib

    # Builders from the local ``models`` package called with no arguments:
    # network name -> (module path, factory name).
    builders = {
        'vgg16': ('models.vgg', 'vgg16_bn'),
        'vgg13': ('models.vgg', 'vgg13_bn'),
        'vgg11': ('models.vgg', 'vgg11_bn'),
        'vgg19': ('models.vgg', 'vgg19_bn'),
        'densenet121': ('models.densenet', 'densenet121'),
        'densenet161': ('models.densenet', 'densenet161'),
        'densenet169': ('models.densenet', 'densenet169'),
        'densenet201': ('models.densenet', 'densenet201'),
        'googlenet': ('models.googlenet', 'googlenet'),
        'inceptionv3': ('models.inceptionv3', 'inceptionv3'),
        'inceptionv4': ('models.inceptionv4', 'inceptionv4'),
        'inceptionresnetv2': ('models.inceptionv4', 'inception_resnet_v2'),
        'xception': ('models.xception', 'xception'),
        'resnet18': ('models.resnet', 'resnet18'),
        'resnet34': ('models.resnet', 'resnet34'),
        'resnet50': ('models.resnet', 'resnet50'),
        'resnet101': ('models.resnet', 'resnet101'),
        'resnet152': ('models.resnet', 'resnet152'),
        'preactresnet18': ('models.preactresnet', 'preactresnet18'),
        'preactresnet34': ('models.preactresnet', 'preactresnet34'),
        'preactresnet50': ('models.preactresnet', 'preactresnet50'),
        'preactresnet101': ('models.preactresnet', 'preactresnet101'),
        'preactresnet152': ('models.preactresnet', 'preactresnet152'),
        'resnext50': ('models.resnext', 'resnext50'),
        'resnext101': ('models.resnext', 'resnext101'),
        'resnext152': ('models.resnext', 'resnext152'),
        'shufflenet': ('models.shufflenet', 'shufflenet'),
        'shufflenetv2': ('models.shufflenetv2', 'shufflenetv2'),
        'squeezenet': ('models.squeezenet', 'squeezenet'),
        'mobilenet': ('models.mobilenet', 'mobilenet'),
        'mobilenetv2': ('models.mobilenetv2', 'mobilenetv2'),
        'nasnet': ('models.nasnet', 'nasnet'),
        'attention56': ('models.attention', 'attention56'),
        'attention92': ('models.attention', 'attention92'),
        'seresnet18': ('models.senet', 'seresnet18'),
        'seresnet34': ('models.senet', 'seresnet34'),
        'seresnet50': ('models.senet', 'seresnet50'),
        'seresnet101': ('models.senet', 'seresnet101'),
        'seresnet152': ('models.senet', 'seresnet152'),
        'wideresnet': ('models.wideresidual', 'wideresnet'),
        'stochasticdepth18': ('models.stochasticdepth',
                              'stochastic_depth_resnet18'),
        'stochasticdepth34': ('models.stochasticdepth',
                              'stochastic_depth_resnet34'),
        'stochasticdepth50': ('models.stochasticdepth',
                              'stochastic_depth_resnet50'),
        'stochasticdepth101': ('models.stochasticdepth',
                               'stochastic_depth_resnet101'),
        'normal_resnet': ('models.normal_resnet', 'resnet18'),
        'normal_resnet_wo_bn': ('models.normal_resnet_wo_bn', 'resnet18'),
    }
    # Hypernetwork variants: network name -> encoder name for Hypernet_Main.
    hypernet_encoders = {
        'hyper_resnet': 'resnet18',
        'hyper_resnet_wo_bn': 'resnet18_wobn',
    }

    if args.net in builders:
        module_name, factory_name = builders[args.net]
        net = getattr(importlib.import_module(module_name), factory_name)()
    elif args.net in hypernet_encoders:
        from models.hypernet_main import Hypernet_Main
        net = Hypernet_Main(
            encoder=hypernet_encoders[args.net],
            hypernet_params={'vqvae_dict_size': args.dict_size})
    else:
        # Report on stderr and exit non-zero so calling scripts see the
        # failure; a bare sys.exit() reports success.
        print('the network name you have entered is not supported yet',
              file=sys.stderr)
        sys.exit(1)

    if args.gpu:  # use_gpu
        net = net.cuda()

    return net
def build_backbone(backbone='resnet-50',
                   layers=50,
                   output_stride=16,
                   norm_layer=None):
    """Instantiate a backbone network by family name and depth.

    Args:
        backbone: family name, one of ``'resnet'``, ``'resgroup'``,
            ``'iresnet'``, ``'iresgroup'`` or ``'xception'``.
        layers: network depth. Supported depths per family:
            resnet 50/101/152/200, resgroup 50/101/152,
            iresnet 50/101/152/200/302/404/1001, iresgroup 50/101/152.
            Ignored for ``'xception'``.
        output_stride: forwarded to ``xception.xception`` only.
        norm_layer: normalization layer forwarded to the constructor.

    Returns:
        The constructed model, or ``None`` when the family/depth
        combination is not supported.

    NOTE(review): the default ``backbone='resnet-50'`` matches no family,
    so calling with all defaults returns ``None``.
    """
    # Strings are compared with ``==``: the previous ``is`` comparisons only
    # worked for interned literals and failed for runtime-built names.
    if backbone == 'xception':
        return xception.xception(output_stride=output_stride,
                                 norm_layer=norm_layer)

    supported_depths = {
        'resnet': (50, 101, 152, 200),
        'resgroup': (50, 101, 152),
        'iresnet': (50, 101, 152, 200, 302, 404, 1001),
        'iresgroup': (50, 101, 152),
    }
    if layers not in supported_depths.get(backbone, ()):
        # Keep the historical behaviour of returning None for unsupported
        # combinations instead of raising.
        return None

    # Resolve only the selected family's module, so unrelated families are
    # never touched (mirrors the lazy branch evaluation of the original).
    if backbone == 'resnet':
        family = resnet
    elif backbone == 'resgroup':
        family = resgroup
    elif backbone == 'iresnet':
        family = iresnet
    else:
        family = iresgroup

    # Constructors follow the "<family><depth>" naming, e.g. resnet.resnet50.
    constructor = getattr(family, backbone + str(int(layers)))
    return constructor(norm_layer=norm_layer)
def get_network(args, use_gpu=True):
    """Return the network named by ``args.net``.

    Args:
        args: object exposing a ``net`` attribute with the architecture
            name (e.g. ``'vgg16'``, ``'resnet50'``, ``'seresnet18'``).
        use_gpu: when true, the network is moved with ``net.cuda()``.

    Returns:
        The instantiated network.

    Exits:
        Prints an error to stderr and calls ``sys.exit(1)`` when
        ``args.net`` is not a supported name. The previous ``sys.exit()``
        exited with status 0, which reported success to the caller.
    """
    import importlib

    # Architecture name -> (module path, factory function in that module).
    # Modules are imported lazily so only the selected architecture's
    # module has to be importable.
    registry = {
        'vgg16': ('models.vgg', 'vgg16_bn'),
        'vgg13': ('models.vgg', 'vgg13_bn'),
        'vgg11': ('models.vgg', 'vgg11_bn'),
        'vgg19': ('models.vgg', 'vgg19_bn'),
        'densenet121': ('models.densenet', 'densenet121'),
        'densenet161': ('models.densenet', 'densenet161'),
        'densenet169': ('models.densenet', 'densenet169'),
        'densenet201': ('models.densenet', 'densenet201'),
        'googlenet': ('models.googlenet', 'googlenet'),
        'inceptionv3': ('models.inceptionv3', 'inceptionv3'),
        'inceptionv4': ('models.inceptionv4', 'inceptionv4'),
        'inceptionresnetv2': ('models.inceptionv4', 'inception_resnet_v2'),
        'xception': ('models.xception', 'xception'),
        'resnet18': ('models.resnet', 'resnet18'),
        'resnet34': ('models.resnet', 'resnet34'),
        'resnet50': ('models.resnet', 'resnet50'),
        'resnet101': ('models.resnet', 'resnet101'),
        'resnet152': ('models.resnet', 'resnet152'),
        'preactresnet18': ('models.preactresnet', 'preactresnet18'),
        'preactresnet34': ('models.preactresnet', 'preactresnet34'),
        'preactresnet50': ('models.preactresnet', 'preactresnet50'),
        'preactresnet101': ('models.preactresnet', 'preactresnet101'),
        'preactresnet152': ('models.preactresnet', 'preactresnet152'),
        'resnext50': ('models.resnext', 'resnext50'),
        'resnext101': ('models.resnext', 'resnext101'),
        'resnext152': ('models.resnext', 'resnext152'),
        'shufflenet': ('models.shufflenet', 'shufflenet'),
        'shufflenetv2': ('models.shufflenetv2', 'shufflenetv2'),
        'squeezenet': ('models.squeezenet', 'squeezenet'),
        'mobilenet': ('models.mobilenet', 'mobilenet'),
        'mobilenetv2': ('models.mobilenetv2', 'mobilenetv2'),
        'nasnet': ('models.nasnet', 'nasnet'),
        'attention56': ('models.attention', 'attention56'),
        'attention92': ('models.attention', 'attention92'),
        'seresnet18': ('models.senet', 'seresnet18'),
        'seresnet34': ('models.senet', 'seresnet34'),
        'seresnet50': ('models.senet', 'seresnet50'),
        'seresnet101': ('models.senet', 'seresnet101'),
        'seresnet152': ('models.senet', 'seresnet152'),
    }

    entry = registry.get(args.net)
    if entry is None:
        print('the network name you have entered is not supported yet',
              file=sys.stderr)
        sys.exit(1)

    module_path, factory_name = entry
    factory = getattr(importlib.import_module(module_path), factory_name)
    net = factory()

    if use_gpu:
        net = net.cuda()

    return net
Example #18
0
    def __init__(self,
                 args,
                 mode='soft',
                 weight=None,
                 eta=0.3,
                 min_child_weight=1,
                 max_depth=6,
                 gamma=0):
        '''
        Build the ensemble members described by the backbone options on ``args``.

        Each of ``args.densenet``, ``args.nest200``, ``args.resnext``,
        ``args.resnext101``, ``args.resnext101_32x16d``, ``args.efficient_b2``,
        ``args.efficient_b5``, ``args.efficient_b6`` and ``args.xception`` is
        either falsy (family unused) or a whitespace-separated string made of
        groups of five tokens, one group per ensemble member:

            <session> <binary> <cat_embed> <embed_dim> <transform>

        - session: appended to ``self.session`` (presumably the checkpoint /
          session name used by ``load_finetuned`` -- TODO confirm).
        - binary: non-zero wraps the backbone in ``Binary_Model``.
        - cat_embed: passed to ``Binary_Model``; when ``binary`` is zero and
          this is non-zero the backbone is wrapped in ``Trainable_Embedding``.
        - embed_dim: embedding size passed to either wrapper.
        - transform: integer appended to ``self.transform``.

        Trailing tokens that do not form a complete group are ignored. The
        parsed token list is written back onto ``args`` (as before); a list
        already stored there is reused instead of being split again.

        mode: stored on ``self.mode``; ``'stacked'`` additionally creates
            ``self.stacked_fc`` and ``'xgb'`` creates ``self.xgb_classifier``.
        weight: optional per-model weights, stored and printed as percentages.

        XGB parameters:
            https://www.analyticsvidhya.com/blog/2016/03/complete-guide-parameter-tuning-xgboost-with-codes-python/
            eta = Makes the model more robust by shrinking the weights on each step
            min_child_weight = Used to control over-fitting. Higher values prevent a model from learning relations which might be highly specific to the particular sample selected for a tree.
            max_depth  = Used to control over-fitting as higher depth will allow model to learn relations very specific to a particular sample.
            gamma = A node is split only when the resulting split gives a positive reduction in the loss function. Gamma specifies the minimum loss reduction required to make a split.

        Raises:
            ValueError: if no ensemble member was configured on ``args``.
        '''
        super(Ensemble_Model, self).__init__()
        self.name = "Ensemble_model"
        self.models = []
        self.session = []
        self.transform = []

        # (args attribute, backbone constructor), in the original member order.
        # Lambdas keep each constructor name lazily resolved, so a family's
        # constructor is only looked up when that family is requested.
        member_specs = [
            ('densenet', lambda: densenet201(pretrained=False)),
            ('nest200', lambda: resnest200(pretrained=False)),
            ('resnext', lambda: resnext50_32x4d(pretrained=False)),
            ('resnext101', lambda: resnext101_32x8d(pretrained=False)),
            ('resnext101_32x16d', lambda: resnext101_32x16d(pretrained=False)),
            ('efficient_b2', lambda: EfficientNet_B2(pretrained=False)),
            ('efficient_b5', lambda: EfficientNet_B5(pretrained=False)),
            ('efficient_b6', lambda: EfficientNet_B6(pretrained=False)),
            ('xception', lambda: xception(pretrained=False)),
        ]

        for attr, build_backbone_model in member_specs:
            spec = getattr(args, attr)
            if not spec:
                continue
            # split() (not split(' ')) tolerates repeated spaces, which would
            # otherwise produce empty tokens and misalign the 5-token groups.
            tokens = spec.split() if isinstance(spec, str) else list(spec)
            setattr(args, attr, tokens)

            for i in range(len(tokens) // 5):
                session, binary, cat_embed, embed_dim, transform = \
                    tokens[i * 5:i * 5 + 5]
                model = build_backbone_model()
                # Wrap THIS member's backbone. Previously several families
                # wrapped the unrelated ``resnet`` variable instead.
                if int(binary):
                    model = Binary_Model(model,
                                         cat_embed=int(cat_embed),
                                         embed_dim=int(embed_dim))
                elif int(cat_embed):
                    model = Trainable_Embedding(model,
                                                embed_dim=int(embed_dim))

                self.models.append(model)
                self.session.append(session)
                self.transform.append(int(transform))

        self.num_model = len(self.models)
        self.mode = mode
        self.weight = weight
        print("Transforms !", self.transform)
        if weight is not None:
            print("Weight :", self.weight)
            total_weight = sum(self.weight)
            print([int(w / total_weight * 100) for w in self.weight])

        if self.num_model == 0:
            # Previously surfaced as a ZeroDivisionError in the line below.
            raise ValueError(
                "Ensemble_Model: no member models were configured on args")

        # Learnable per-model blending weights, initialised uniformly.
        self.w = nn.Parameter(torch.tensor([1 / self.num_model] *
                                           self.num_model).cuda(),
                              requires_grad=True)

        if mode == 'stacked':
            # Linear layer over 5 values per member, producing 5 outputs
            # (presumably 5-class logits per model -- TODO confirm).
            self.stacked_fc = nn.Linear(5 * self.num_model, 5)

        elif mode == 'xgb':
            self.xgb_classifier = xgb.XGBClassifier(
                objective="multi:softprob",
                learning_rate=eta,
                min_child_weight=min_child_weight,
                max_depth=max_depth,
                gamma=gamma,
                random_state=42)

        self.load_finetuned()