def define_model(is_resnet, is_densenet, is_senet):
    """Build a ``net.model`` depth network around the selected encoder.

    Only one encoder is constructed. If several flags are set, SENet takes
    precedence over DenseNet, which takes precedence over ResNet -- the same
    network the previous if/if/if form returned, but without building (and
    downloading weights for) the backbones that were then discarded.

    Args:
        is_resnet: use a ResNet-18 encoder (variant chosen by ``use18``).
        is_densenet: use ``densenet.densenet161(pretrained=True)``.
        is_senet: use ``senet.senet154(pretrained='imagenet')``.

    Returns:
        The constructed ``net.model``.

    Raises:
        ValueError: if none of the flags is set.
    """
    # True:  Resnet18Encoder initialised from the monodepth checkpoint.
    # False: resnet.resnet18(pretrained=True) wrapped by modules.E_resnet.
    use18 = True

    if is_senet:
        encoder = modules.E_senet(senet.senet154(pretrained='imagenet'))
        return net.model(encoder, num_features=2048,
                         block_channel=[256, 512, 1024, 2048])
    if is_densenet:
        encoder = modules.E_densenet(densenet.densenet161(pretrained=True))
        return net.model(encoder, num_features=2208,
                         block_channel=[192, 384, 1056, 2208])
    if is_resnet:
        if use18:
            encoder = Resnet18Encoder(3)
            encoder_state = encoder.state_dict()
            checkpoint = torch.load('./models/monodepth_resnet18_001.pth',
                                    map_location='cpu')
            # Drop checkpoint entries the encoder has no parameter/buffer for.
            encoder.load_state_dict(
                {k: v for k, v in checkpoint.items() if k in encoder_state})
        else:
            encoder = modules.E_resnet(resnet.resnet18(pretrained=True))
        return net.model(encoder, num_features=512,
                         block_channel=[64, 128, 256, 512])

    raise ValueError(
        'define_model: one of is_resnet, is_densenet, is_senet must be True')
示例#2
0
def define_test_model():
    """Build the network selected by the module-level ``args.arch`` for testing.

    * ``"Resnet"``: a Resnet18Encoder initialised from
      ``./models/monodepth_resnet18_001.pth`` is wrapped in ``net.model``,
      moved to CUDA as float32, and loaded from
      ``<args.load_dir>/original_model_epoch_<args.load_epoch>.pth``.
    * ``"Densenet"`` / ``"SEnet"``: a ``net.model`` around a pretrained
      encoder; no trained checkpoint is loaded here.

    Reads ``args.arch`` and, for "Resnet", ``args.load_dir`` and
    ``args.load_epoch`` from the module-level ``args``.

    Returns:
        The constructed model.

    Raises:
        ValueError: for any other ``args.arch`` -- including ``"Custom"``,
            which is a listed option but has no builder here.
    """
    arch = args.arch

    if arch == "Resnet":
        encoder = Resnet18Encoder(3)
        encoder_state = encoder.state_dict()
        pretrained = torch.load('./models/monodepth_resnet18_001.pth',
                                map_location='cpu')
        # Keep only the checkpoint entries this encoder actually has.
        encoder.load_state_dict(
            {k: v for k, v in pretrained.items() if k in encoder_state})
        model = net.model(encoder,
                          num_features=512,
                          block_channel=[64, 128, 256, 512])

        checkpoint_path = (
            args.load_dir +
            "/original_model_epoch_{}.pth".format(str(args.load_epoch)))
        print("Loading a model...")
        # Report the file that is actually loaded below.
        print(checkpoint_path)
        model = model.cuda().float()
        model.load_state_dict(torch.load(checkpoint_path))
        return model

    if arch == "Densenet":
        # TODO(original author): "no dot bug" -- intent unclear. Note that no
        # trained checkpoint is loaded for this architecture.
        encoder = modules.E_densenet(densenet.densenet161(pretrained=True))
        return net.model(encoder,
                         num_features=2208,
                         block_channel=[192, 384, 1056, 2208])

    if arch == "SEnet":
        encoder = modules.E_senet(senet.senet154(pretrained='imagenet'))
        return net.model(encoder,
                         num_features=2048,
                         block_channel=[256, 512, 1024, 2048])

    raise ValueError(
        'Unsupported args.arch {!r}; expected "Resnet", "Densenet" or '
        '"SEnet"'.format(arch))
示例#3
0
def define_model(is_resnet, is_densenet, is_senet):
    """Build a ``net.model`` depth network around the selected encoder.

    Only one encoder is constructed. If several flags are set, SENet takes
    precedence over DenseNet, which takes precedence over ResNet -- the same
    network the previous if/if/if form returned, without building the
    backbones that were then discarded.

    Args:
        is_resnet: use ``resnet.resnet50(pretrained=True)``.
        is_densenet: use ``densenet.densenet161(pretrained=True)``.
        is_senet: use ``senet.senet154(pretrained=None)``.

    Returns:
        The constructed ``net.model``.

    Raises:
        ValueError: if none of the flags is set.
    """
    if is_senet:
        encoder = modules.E_senet(senet.senet154(pretrained=None))
        return net.model(encoder, num_features=2048,
                         block_channel=[256, 512, 1024, 2048])
    if is_densenet:
        encoder = modules.E_densenet(densenet.densenet161(pretrained=True))
        return net.model(encoder, num_features=2208,
                         block_channel=[192, 384, 1056, 2208])
    if is_resnet:
        encoder = modules.E_resnet(resnet.resnet50(pretrained=True))
        return net.model(encoder, num_features=2048,
                         block_channel=[256, 512, 1024, 2048])

    raise ValueError(
        'define_model: one of is_resnet, is_densenet, is_senet must be True')
def define_model(encoder='resnet'):
    """Build a ``net.model`` depth network for the named encoder.

    Args:
        encoder: ``'resnet'`` (ResNet-50), ``'densenet'`` (DenseNet-161) or
            ``'senet'`` (SENet-154). Compared by value, so strings built at
            runtime (e.g. from argparse) work too.

    Returns:
        The constructed ``net.model``.

    Raises:
        ValueError: for any other ``encoder`` value.
    """
    # ``==`` rather than ``is``: identity only holds for interned literals.
    if encoder == 'resnet':
        backbone = modules.E_resnet(resnet.resnet50(pretrained=True))
        return net.model(backbone, num_features=2048,
                         block_channel=[256, 512, 1024, 2048])
    if encoder == 'densenet':
        backbone = modules.E_densenet(densenet.densenet161(pretrained=True))
        return net.model(backbone, num_features=2208,
                         block_channel=[192, 384, 1056, 2208])
    if encoder == 'senet':
        backbone = modules.E_senet(senet.senet154(pretrained='imagenet'))
        return net.model(backbone, num_features=2048,
                         block_channel=[256, 512, 1024, 2048])

    raise ValueError(
        "Unknown encoder {!r}; expected 'resnet', 'densenet' or "
        "'senet'".format(encoder))
示例#5
0
    def __init__(self, type, num_classes):
        """Create the SE-family backbone and a fresh classification head.

        Args:
            type: backbone name, one of "seresnext50", "seresnext101",
                "seresnet50", "seresnet101", "seresnet152", "senet154".
                (Shadows the ``type`` builtin; the name is kept so keyword
                callers keep working.)
            num_classes: output size of ``last_linear``.

        Raises:
            ValueError: for an unsupported ``type``. ValueError subclasses
                Exception, so handlers written for a generic Exception still
                catch it.
        """
        super().__init__()

        # Lambdas defer the name lookup, so only the selected constructor is
        # resolved and called (and only its pretrained weights are fetched).
        backbones = {
            "seresnext50": lambda: se_resnext50_32x4d(pretrained="imagenet"),
            "seresnext101": lambda: se_resnext101_32x4d(pretrained="imagenet"),
            "seresnet50": lambda: se_resnet50(pretrained="imagenet"),
            "seresnet101": lambda: se_resnet101(pretrained="imagenet"),
            "seresnet152": lambda: se_resnet152(pretrained="imagenet"),
            "senet154": lambda: senet154(pretrained="imagenet"),
        }
        if type not in backbones:
            raise ValueError("Unsupported senet model type: '{}'".format(type))

        self.senet = backbones[type]()
        self.layer0 = self.senet.layer0

        self.expand_channels = ExpandChannels2d(3)
        self.bn = nn.BatchNorm2d(3)

        self.avg_pool = nn.AdaptiveAvgPool2d(output_size=1)
        self.dropout = nn.Dropout(0.2)
        # NOTE(review): 2048 assumes every backbone above ends with 2048
        # channels -- confirm before adding new types.
        self.last_linear = nn.Linear(2048, num_classes)
def define_model(is_resnet,
                 is_densenet,
                 is_senet,
                 model='tbdp',
                 parallel=False,
                 semff=False,
                 pcamff=False):
    """Build a TBDPNet or Hu depth network on top of the selected backbone.

    Args:
        is_resnet: use a ResNet-50 encoder.
        is_densenet: use a DenseNet-161 encoder.
        is_senet: use a SENet-154 encoder. If several flags are set, SENet
            takes precedence over DenseNet, which takes precedence over ResNet.
        model: head architecture -- ``'tbdp'`` (``net.TBDPNet``) or ``'hu'``
            (``net.Hu``).
        parallel: forwarded to ``net.TBDPNet`` only.
        semff: forwarded to ``net.Hu`` only.
        pcamff: forwarded to either head.

    Returns:
        The constructed network.

    Raises:
        NotImplementedError: if ``model`` is not ``'tbdp'`` or ``'hu'``.
        ValueError: if no backbone flag is set.
    """
    # Keep the architecture name separate from the network being built so the
    # two can never be confused.
    arch = model
    if arch not in ('tbdp', 'hu'):
        # Validate before downloading any pretrained backbone weights.
        raise NotImplementedError(
            "Select model type in [\'tbdp\', \'hu\']")

    if is_senet:
        encoder = modules.E_senet(senet.senet154(pretrained='imagenet'))
        num_features, block_channel = 2048, [256, 512, 1024, 2048]
    elif is_densenet:
        encoder = modules.E_densenet(densenet.densenet161(pretrained=True))
        num_features, block_channel = 2208, [192, 384, 1056, 2208]
    elif is_resnet:
        encoder = modules.E_resnet(resnet.resnet50(pretrained=True))
        num_features, block_channel = 2048, [256, 512, 1024, 2048]
    else:
        raise ValueError(
            'define_model: one of is_resnet, is_densenet, is_senet must be '
            'True')

    if arch == 'tbdp':
        return net.TBDPNet(encoder,
                           num_features=num_features,
                           block_channel=block_channel,
                           parallel=parallel,
                           pcamff=pcamff)
    return net.Hu(encoder,
                  num_features=num_features,
                  block_channel=block_channel,
                  semff=semff,
                  pcamff=pcamff)
示例#7
0
def _average_checkpoints(model, paths):
    """Return the element-wise mean of the state dicts saved at ``paths``.

    Each file is loaded into ``model`` first, so a key/shape mismatch raises
    exactly as a normal ``load_state_dict`` would. Each path is printed as it
    is loaded. ``paths`` must be non-empty.
    """
    averaged = OrderedDict()
    for cnt, path in enumerate(paths):
        print(path)
        model.load_state_dict(torch.load(path))
        for k, v in model.state_dict().items():
            if cnt == 0:
                # state_dict() tensors share storage with the live parameters;
                # clone so the next load_state_dict() cannot overwrite the sum.
                averaged[k] = v.clone()
            else:
                averaged[k] = averaged[k] + v
    for k in averaged:
        averaged[k] = averaged[k] / float(len(paths))
    return averaged


def load_model(model_name='resnet50', resume='Best', start_epoch=0, cn=3,
               save_dir='saved_models/', width=32, start=8, cls_number=10,
               avg_number=1, gpus=None, kfold=1, model_times=0, train=True):
    """Build a classifier, optionally restore weights, and move it to GPU.

    Args:
        model_name: architecture key (see the dispatch below).
        resume: ``'Best'`` to pick checkpoints under ``save_dir`` ranked by the
            number inside ``[...]`` in their file names (highest first); ``''``
            for none; otherwise a checkpoint path to load.
        start_epoch: returned unchanged unless a single checkpoint is resumed,
            in which case it is parsed from the file name.
        cn: number of input channels for the rebuilt stem convolution.
        save_dir, width, start, kfold, model_times: form the checkpoint prefix
            ``save_dir + [model_times + '_'] + model_name_width_start_cn``
            (the ``model_times`` part only when ``kfold > 1``).
        cls_number: number of output classes.
        avg_number: with ``resume='Best'``, how many top checkpoints to average
            (1 = load the single top-ranked file).
        gpus: device ids for ``DataParallel`` when more than one is given;
            defaults to ``[0, 1, ..., 7]``.
        train: when ``False`` (and ``resume='Best'``), checkpoint files ranked
            below the top ``avg_number + 2`` are deleted from disk.

    Returns:
        ``(model, start_epoch)``.

    Raises:
        ValueError: for an unknown ``model_name``.
    """
    if gpus is None:
        gpus = [0, 1, 2, 3, 4, 5, 6, 7]

    # Forwarded as ``pretrained=`` to the constructors below; kept at None.
    load_dict = None

    if model_name == 'resnet50':
        model = resnet50(num_classes=cls_number, pretrained=load_dict)
    elif model_name == 'resnet101':
        model = resnet101(num_classes=cls_number, pretrained=load_dict)
    elif model_name == 'resnet152':
        model = resnet152(num_classes=cls_number, pretrained=load_dict)
    elif model_name == 'densenet161':
        model = densenet161(num_classes=cls_number, pretrained=load_dict)
    elif model_name == 'xception':
        model = xception(num_classes=cls_number, pretrained=load_dict)
        model.conv1 = nn.Conv2d(cn, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)
    elif model_name == 'inception_v3':
        model = inception_v3(num_classes=cls_number, pretrained=load_dict)
        model.Conv2d_1a_3x3.conv = nn.Conv2d(cn, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)
    elif model_name == 'seinception_v3':
        model = se_inception_v3(num_classes=cls_number)
        model.model.Conv2d_1a_3x3.conv = nn.Conv2d(cn, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)
    elif model_name == 'inception_v4':
        model = inceptionv4(num_classes=cls_number, pretrained=load_dict)
        model.features[0].conv = nn.Conv2d(cn, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)
    elif model_name == 'inceptionresnetv2':
        model = inceptionresnetv2(num_classes=cls_number, pretrained=load_dict)
        model.conv2d_1a.conv = nn.Conv2d(cn, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)
    elif model_name == 'seresnet50':
        model = se_resnet50(num_classes=cls_number, pretrained=load_dict)
    elif model_name == 'seresnet101':
        model = se_resnet101(num_classes=cls_number, pretrained=load_dict)
    elif model_name == 'seresnet152':
        model = se_resnet152(num_classes=cls_number, pretrained=load_dict)
    elif model_name == 'seresnext50':
        model = se_resnext50_32x4d(num_classes=cls_number, pretrained=load_dict)
    elif model_name == 'seresnext101':
        model = se_resnext101_32x4d(num_classes=cls_number, pretrained=load_dict)
    elif model_name == 'resnet50-101':
        model = SimpleNet()
    elif model_name == 'seresnet20':
        model = se_resnet20(num_classes=cls_number)
    elif model_name == 'seresnet32':
        model = se_resnet32(num_classes=cls_number)
    elif model_name == 'seresnet18':
        model = se_resnet18(num_classes=cls_number)
    elif model_name == 'seresnet34':
        model = se_resnet34(num_classes=cls_number)
    elif model_name == 'senet154':
        model = senet154(num_classes=cls_number, pretrained=load_dict)
        model.layer0.conv1 = nn.Conv2d(cn, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    elif model_name == 'nasnet':
        model = nasnetalarge(num_classes=cls_number, pretrained=load_dict)
        model.conv0.conv = nn.Conv2d(cn, 96, kernel_size=(3, 3), stride=(2, 2), bias=False)
    elif model_name == 'dpn98':
        model = dpn98(num_classes=cls_number, pretrained=load_dict)
    elif model_name == 'dpn107':
        model = dpn107(num_classes=cls_number, pretrained=load_dict)
    elif model_name == 'dpn92':
        model = dpn92(num_classes=cls_number, pretrained=load_dict)
    elif model_name == 'polynet':
        model = polynet(num_classes=cls_number, pretrained=load_dict)
        model.stem.conv1[0].conv = nn.Conv2d(cn, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)
    elif model_name == 'pnasnet':
        model = pnasnet5large(num_classes=cls_number, pretrained=load_dict)
        model.conv_0.conv = nn.Conv2d(cn, 96, kernel_size=(3, 3), stride=(2, 2), bias=False)
    else:
        raise ValueError('Unknown model_name: {!r}'.format(model_name))

    # Rebuild the stem convolution for ``cn`` input channels and force global
    # average pooling (skipped for composite names such as 'resnet50-101').
    if '-' not in model_name and load_dict is not True:
        if model_name == 'dpn98':
            model.features.conv1_1.conv = nn.Conv2d(cn, 96, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
        elif model_name == 'dpn92':
            model.features.conv1_1.conv = nn.Conv2d(cn, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
        elif model_name in ['seresnet20', 'seresnet32']:
            model.conv1 = nn.Conv2d(cn, 16, kernel_size=3, stride=1, padding=1, bias=False)
        elif model_name in ['seresnet18', 'seresnet34']:
            model.conv1 = nn.Conv2d(cn, 64, kernel_size=7, stride=2, padding=3, bias=False)
        elif 'seresnext' in model_name or 'seresnet' in model_name:
            model.layer0.conv1 = nn.Conv2d(cn, 64, kernel_size=7, stride=2, padding=3, bias=False)
        elif 'resnet' in model_name:
            # NOTE(review): this substring test also matches
            # 'inceptionresnetv2'; confirm adding ``conv1`` there is intended.
            model.conv1 = nn.Conv2d(cn, 64, kernel_size=7, stride=2, padding=3, bias=False)
        elif 'densenet' in model_name:
            model.features.conv0 = nn.Conv2d(cn, 96, kernel_size=7, stride=2, padding=3, bias=False)
        model.avgpool = torch.nn.AdaptiveAvgPool2d(output_size=1)

    model_ = model_name + '_' + str(width) + '_' + str(start) + '_' + str(cn)
    if kfold > 1:
        model_prefix = save_dir + str(model_times) + '_' + model_
    else:
        model_prefix = save_dir + model_

    if resume == 'Best' and avg_number >= 1:
        weight_path = glob(model_prefix + '*pth')
        # File names embed a number in "[...]"; rank indices by it, highest first.
        cur_index = np.argsort(-np.array(
            [float(p.split('/')[-1].split('[')[-1].split(']')[0])
             for p in weight_path]))
        if len(weight_path) == 0:
            resume = ''
        elif avg_number == 1:
            # The top-ranked file, not whatever glob happened to list first.
            resume = weight_path[cur_index[0]]
        else:
            top_paths = [weight_path[i] for i in cur_index[:avg_number]]
            model.load_state_dict(_average_checkpoints(model, top_paths))
        if train == False:
            # Prune every checkpoint ranked below the top ``avg_number + 2``.
            for index in cur_index[avg_number + 2:]:
                cur_resume = weight_path[index]
                print('remove resume %s ' % cur_resume)
                os.remove(cur_resume)
    if resume != '' and avg_number == 1:
        # Assumes the epoch is the third-from-last '-'-separated field of the
        # checkpoint file name.
        start_epoch = int(resume.split('-')[-3])
        logging.info('resuming finetune from %s', resume)
        model.load_state_dict(torch.load(resume))

    print('start-epoch : ', start_epoch)

    if torch.cuda.is_available():
        print('cuda_avail: True')
        if len(gpus) > 1:
            model = torch.nn.DataParallel(model, device_ids=gpus).cuda()
        else:
            model = model.cuda()
    return model, start_epoch
示例#8
0
def generate_model(opt):
    """Build the model described by ``opt`` and optionally set it up for fine-tuning.

    Args:
        opt: options namespace. Fields read here: ``model``, ``model_depth``,
            ``n_classes``, ``sample_size``, ``sample_duration``,
            ``resnet_shortcut``, ``wide_resnet_k``, ``resnext_cardinality``,
            ``no_cuda``, ``pretrain_path``, ``n_finetune_classes``,
            ``ft_begin_index`` and, on the CPU path, ``arch``.

    Returns:
        ``(model, parameters)``. After loading ``opt.pretrain_path``,
        ``parameters`` comes from the family's ``get_fine_tuning_parameters``
        (SE-Net: ``model.parameters()``, since no such helper is imported for
        it); without a pretrain path it is ``model.parameters()``.

    Raises:
        AssertionError: for an unsupported ``opt.model`` / ``opt.model_depth``.
    """
    assert opt.model in [
        'resnet', 'preresnet', 'wideresnet', 'resnext', 'densenet', 'senet'
    ]

    if opt.model == 'resnet':
        assert opt.model_depth in [10, 18, 34, 50, 101, 152, 200]
        from models.resnet import get_fine_tuning_parameters
        # resnet.resnet10 ... resnet.resnet200, selected by depth.
        model = getattr(resnet, 'resnet{}'.format(opt.model_depth))(
            num_classes=opt.n_classes,
            shortcut_type=opt.resnet_shortcut,
            sample_size=opt.sample_size,
            sample_duration=opt.sample_duration)
    elif opt.model == 'wideresnet':
        assert opt.model_depth in [50]
        from models.wide_resnet import get_fine_tuning_parameters
        model = wide_resnet.resnet50(
            num_classes=opt.n_classes,
            shortcut_type=opt.resnet_shortcut,
            k=opt.wide_resnet_k,
            sample_size=opt.sample_size,
            sample_duration=opt.sample_duration)
    elif opt.model == 'resnext':
        assert opt.model_depth in [50, 101, 152]
        from models.resnext import get_fine_tuning_parameters
        model = getattr(resnext, 'resnet{}'.format(opt.model_depth))(
            num_classes=opt.n_classes,
            shortcut_type=opt.resnet_shortcut,
            cardinality=opt.resnext_cardinality,
            sample_size=opt.sample_size,
            sample_duration=opt.sample_duration)
    elif opt.model == 'preresnet':
        assert opt.model_depth in [18, 34, 50, 101, 152, 200]
        from models.pre_act_resnet import get_fine_tuning_parameters
        model = getattr(pre_act_resnet, 'resnet{}'.format(opt.model_depth))(
            num_classes=opt.n_classes,
            shortcut_type=opt.resnet_shortcut,
            sample_size=opt.sample_size,
            sample_duration=opt.sample_duration)
    elif opt.model == 'densenet':
        assert opt.model_depth in [121, 169, 201, 264]
        from models.densenet import get_fine_tuning_parameters
        model = getattr(densenet, 'densenet{}'.format(opt.model_depth))(
            num_classes=opt.n_classes,
            sample_size=opt.sample_size,
            sample_duration=opt.sample_duration)
    elif opt.model == 'senet':
        # Depth codes 5032 / 10132 denote the SE-ResNeXt 32x4d variants.
        senet_builders = {
            50: 'se_resnet50',
            101: 'se_resnet101',
            152: 'se_resnet152',
            154: 'senet154',
            5032: 'se_resnext50_32x4d',  # SE variant, matching the 10132 entry
            10132: 'se_resnext101_32x4d',
        }
        assert opt.model_depth in senet_builders
        model = getattr(senet, senet_builders[opt.model_depth])(
            num_classes=opt.n_classes, pretrained=None)

    if not opt.no_cuda:
        model = model.cuda()
        model = nn.DataParallel(model, device_ids=None)

        if opt.pretrain_path:
            print('loading pretrained model {}'.format(opt.pretrain_path))
            pretrain = torch.load(opt.pretrain_path)
            # Merge every pretrained tensor except the old "module.fc" head
            # into the current weights, then load the merged dict.
            model_dict = model.state_dict()
            model_dict.update({k: v for k, v in pretrain['state_dict'].items()
                               if "module.fc" not in k})
            model.load_state_dict(model_dict)

            if opt.model == 'densenet':
                model.module.classifier = nn.Linear(
                    model.module.classifier.in_features, opt.n_finetune_classes)
                model.module.classifier = model.module.classifier.cuda()
            elif opt.model == "senet":
                model.module.last_linear = nn.Linear(
                    model.module.last_linear.in_features, opt.n_finetune_classes)
                model.module.last_linear = model.module.last_linear.cuda()
                # No get_fine_tuning_parameters is imported for SE-Net.
                return model, model.parameters()
            else:
                model.module.fc = nn.Linear(model.module.fc.in_features,
                                            opt.n_finetune_classes)
                model.module.fc = model.module.fc.cuda()

            parameters = get_fine_tuning_parameters(model, opt.ft_begin_index)
            return model, parameters
    else:
        if opt.pretrain_path:
            print('loading pretrained model {}'.format(opt.pretrain_path))
            # map_location lets GPU-saved checkpoints load on a CPU-only run.
            pretrain = torch.load(opt.pretrain_path, map_location='cpu')
            assert opt.arch == pretrain['arch']

            model.load_state_dict(pretrain['state_dict'])

            if opt.model == 'densenet':
                model.classifier = nn.Linear(
                    model.classifier.in_features, opt.n_finetune_classes)
            elif opt.model == 'senet':
                # SE-Net's head is ``last_linear`` (as on the CUDA path), and
                # no get_fine_tuning_parameters is imported for it.
                model.last_linear = nn.Linear(
                    model.last_linear.in_features, opt.n_finetune_classes)
                return model, model.parameters()
            else:
                model.fc = nn.Linear(model.fc.in_features,
                                     opt.n_finetune_classes)

            parameters = get_fine_tuning_parameters(model, opt.ft_begin_index)
            return model, parameters

    return model, model.parameters()
示例#9
0
    'ResNet50':
    lambda: E_resnet(resnet50(pretrained=True)),
    'ResNet101':
    lambda: E_resnet(resnet101(pretrained=True)),
    'ResNet152':
    lambda: E_resnet(resnet152(pretrained=True)),
    'DenseNet121':
    lambda: E_densenet(densenet121(pretrained=True)),
    'DenseNet161':
    lambda: E_densenet(densenet161(pretrained=True)),
    'DenseNet169':
    lambda: E_densenet(densenet169(pretrained=True)),
    'DenseNet201':
    lambda: E_densenet(densenet201(pretrained=True)),
    'SENet154':
    lambda: E_senet(senet154(pretrained="imagenet")),
    'SE_ResNet50':
    lambda: E_senet(se_resnet50(pretrained="imagenet")),
    'SE_ResNet101':
    lambda: E_senet(se_resnet101(pretrained="imagenet")),
    'SE_ResNet152':
    lambda: E_senet(se_resnet152(pretrained="imagenet")),
    'SE_ResNext50_32x4d':
    lambda: E_senet(se_resnext50_32x4d(pretrained="imagenet")),
    'SE_ResNext101_32x4d':
    lambda: E_senet(se_resnext101_32x4d(pretrained="imagenet"))
}


def get_models(args):
    backbone = args.backbone