def define_model(is_resnet, is_densenet, is_senet):
    """Build the network for the selected encoder backbone.

    Each flag enables one backbone. The flags are evaluated independently and
    in order (ResNet, DenseNet, SENet), so if several are set the last enabled
    one determines the returned model.

    Args:
        is_resnet: build the ResNet-18 based model.
        is_densenet: build the DenseNet-161 based model.
        is_senet: build the SENet-154 based model.

    Returns:
        The object returned by ``net.model`` for the selected backbone.

    Raises:
        ValueError: if none of the three flags is set.
    """
    # Hard-wired switch: while it is True the ResNet path always uses the
    # Resnet18Encoder initialised from the local monodepth checkpoint, and the
    # torchvision-style ``resnet.resnet18(pretrained=True)`` path is dead code.
    use18 = True
    model = None

    if is_resnet:
        if not use18:
            original_model = resnet.resnet18(pretrained=True)
            Encoder = modules.E_resnet(original_model)
        else:
            stereoModel = Resnet18Encoder(3)
            model_dict = stereoModel.state_dict()
            encoder_dict = torch.load('./models/monodepth_resnet18_001.pth',
                                      map_location='cpu')
            # Keep only the checkpoint entries whose key also exists in the
            # encoder's own state dict.
            new_dict = {key: encoder_dict[key]
                        for key in encoder_dict if key in model_dict}
            stereoModel.load_state_dict(new_dict)
            Encoder = stereoModel
        model = net.model(Encoder, num_features=512,
                          block_channel=[64, 128, 256, 512])

    if is_densenet:
        original_model = densenet.densenet161(pretrained=True)
        Encoder = modules.E_densenet(original_model)
        model = net.model(Encoder, num_features=2208,
                          block_channel=[192, 384, 1056, 2208])

    if is_senet:
        original_model = senet.senet154(pretrained='imagenet')
        Encoder = modules.E_senet(original_model)
        model = net.model(Encoder, num_features=2048,
                          block_channel=[256, 512, 1024, 2048])

    if model is None:
        # Previously this fell through to ``return model`` with the local
        # unbound and surfaced as an opaque UnboundLocalError.
        raise ValueError(
            "No encoder selected: set one of is_resnet, is_densenet, is_senet")
    return model
def define_test_model():
    """Build the model selected by ``args.arch`` for testing.

    Reads the module-level ``args`` object (``args.arch``, ``args.load_dir``,
    ``args.load_epoch``). Supported values of ``args.arch`` are ``"Resnet"``,
    ``"Densenet"`` and ``"SEnet"``.

    Only the ``"Resnet"`` branch moves the model to CUDA and restores the
    trained weights from ``args.load_dir``; the other two branches return a
    network built from pretrained encoders only.

    Returns:
        The object returned by ``net.model`` for the selected architecture.

    Raises:
        ValueError: if ``args.arch`` is not one of the supported values
            (this includes ``"Custom"``, which has no branch here).
    """
    arch = args.arch

    if arch == "Resnet":
        # Initialise the encoder from the monodepth checkpoint, keeping only
        # the entries whose key exists in the encoder's own state dict.
        stereoModel = Resnet18Encoder(3)
        model_dict = stereoModel.state_dict()
        encoder_dict = torch.load('./models/monodepth_resnet18_001.pth',
                                  map_location='cpu')
        new_dict = {key: encoder_dict[key]
                    for key in encoder_dict if key in model_dict}
        stereoModel.load_state_dict(new_dict)
        Encoder = stereoModel
        model = net.model(Encoder, num_features=512,
                          block_channel=[64, 128, 256, 512])

        print("Loading a model...")
        # NOTE(review): the printed name omits the "original_" prefix of the
        # file that is actually loaded below -- confirm which one is intended.
        print("/model_epoch_{}.pth".format(str(args.load_epoch)))
        model = model.cuda().float()

        model_dict = torch.load(
            args.load_dir
            + "/original_model_epoch_{}.pth".format(str(args.load_epoch)))
        model.load_state_dict(model_dict)
    elif arch == "Densenet":
        # TODO: no dot bug
        original_model = densenet.densenet161(pretrained=True)
        Encoder = modules.E_densenet(original_model)
        model = net.model(Encoder, num_features=2208,
                          block_channel=[192, 384, 1056, 2208])
    elif arch == "SEnet":
        original_model = senet.senet154(pretrained='imagenet')
        Encoder = modules.E_senet(original_model)
        model = net.model(Encoder, num_features=2048,
                          block_channel=[256, 512, 1024, 2048])
    else:
        # Previously an unsupported value reached ``return model`` with the
        # local unbound and raised UnboundLocalError.
        raise ValueError(
            "Unsupported args.arch for testing: {!r}".format(arch))

    return model
def define_model(is_resnet, is_densenet, is_senet):
    """Build the network for the selected encoder backbone.

    The flags are evaluated independently and in order (ResNet-50,
    DenseNet-161, SENet-154); if several are set, the last enabled one
    determines the returned model.

    Args:
        is_resnet: build on ``resnet.resnet50(pretrained=True)``.
        is_densenet: build on ``densenet.densenet161(pretrained=True)``.
        is_senet: build on ``senet.senet154(pretrained=None)``.

    Returns:
        The object returned by ``net.model`` for the selected backbone.

    Raises:
        ValueError: if none of the three flags is set.
    """
    model = None

    if is_resnet:
        original_model = resnet.resnet50(pretrained=True)
        Encoder = modules.E_resnet(original_model)
        model = net.model(Encoder, num_features=2048,
                          block_channel=[256, 512, 1024, 2048])

    if is_densenet:
        original_model = densenet.densenet161(pretrained=True)
        Encoder = modules.E_densenet(original_model)
        model = net.model(Encoder, num_features=2208,
                          block_channel=[192, 384, 1056, 2208])

    if is_senet:
        # Unlike the other two backbones, SENet is created with
        # ``pretrained=None`` here.
        original_model = senet.senet154(pretrained=None)
        Encoder = modules.E_senet(original_model)
        model = net.model(Encoder, num_features=2048,
                          block_channel=[256, 512, 1024, 2048])

    if model is None:
        # Previously this fell through to ``return model`` with the local
        # unbound and surfaced as an opaque UnboundLocalError.
        raise ValueError(
            "No encoder selected: set one of is_resnet, is_densenet, is_senet")
    return model
def define_model(encoder='resnet'):
    """Build the network for the named encoder backbone.

    Args:
        encoder: one of ``'resnet'`` (ResNet-50), ``'densenet'``
            (DenseNet-161) or ``'senet'`` (SENet-154).

    Returns:
        The object returned by ``net.model`` for the selected backbone.

    Raises:
        ValueError: if ``encoder`` is not one of the supported names.
    """
    # Compare by value: ``is`` tests object identity, which only happens to
    # work for interned string literals and fails for names built at runtime
    # (e.g. parsed from the command line or a config file).
    if encoder == 'resnet':
        original_model = resnet.resnet50(pretrained=True)
        Encoder = modules.E_resnet(original_model)
        model = net.model(Encoder, num_features=2048,
                          block_channel=[256, 512, 1024, 2048])
    elif encoder == 'densenet':
        original_model = densenet.densenet161(pretrained=True)
        Encoder = modules.E_densenet(original_model)
        model = net.model(Encoder, num_features=2208,
                          block_channel=[192, 384, 1056, 2208])
    elif encoder == 'senet':
        original_model = senet.senet154(pretrained='imagenet')
        Encoder = modules.E_senet(original_model)
        model = net.model(Encoder, num_features=2048,
                          block_channel=[256, 512, 1024, 2048])
    else:
        # Previously an unknown name reached ``return model`` with the local
        # unbound and raised UnboundLocalError.
        raise ValueError(
            "Unknown encoder {!r}; expected 'resnet', 'densenet' or 'senet'"
            .format(encoder))
    return model
def __init__(self, type, num_classes):
    """Create the SE backbone named by ``type`` plus the extra head layers.

    Args:
        type: backbone name; one of "seresnext50", "seresnext101",
            "seresnet50", "seresnet101", "seresnet152", "senet154".
        num_classes: output size of the final ``nn.Linear(2048, ...)`` layer.

    Raises:
        Exception: if ``type`` is not one of the supported names.
    """
    super().__init__()

    # Ordered (name, constructor) pairs. The constructors are wrapped in
    # lambdas so that only the selected backbone is looked up and built.
    backbones = (
        ("seresnext50", lambda: se_resnext50_32x4d(pretrained="imagenet")),
        ("seresnext101", lambda: se_resnext101_32x4d(pretrained="imagenet")),
        ("seresnet50", lambda: se_resnet50(pretrained="imagenet")),
        ("seresnet101", lambda: se_resnet101(pretrained="imagenet")),
        ("seresnet152", lambda: se_resnet152(pretrained="imagenet")),
        ("senet154", lambda: senet154(pretrained="imagenet")),
    )
    for name, build in backbones:
        if type == name:
            self.senet = build()
            break
    else:
        raise Exception("Unsupported senet model type: '{}".format(type))

    # Every supported backbone exposes its stem as ``layer0``; keep a direct
    # handle to it on this module.
    self.layer0 = self.senet.layer0

    self.expand_channels = ExpandChannels2d(3)
    self.bn = nn.BatchNorm2d(3)

    self.avg_pool = nn.AdaptiveAvgPool2d(output_size=1)
    self.dropout = nn.Dropout(0.2)
    self.last_linear = nn.Linear(2048, num_classes)
def _build_depth_net(model_type, Encoder, num_features, block_channel,
                     parallel, semff, pcamff):
    """Instantiate ``net.TBDPNet`` or ``net.Hu`` on top of ``Encoder``."""
    if model_type == 'tbdp':
        return net.TBDPNet(Encoder, num_features=num_features,
                           block_channel=block_channel,
                           parallel=parallel, pcamff=pcamff)
    if model_type == 'hu':
        return net.Hu(Encoder, num_features=num_features,
                      block_channel=block_channel,
                      semff=semff, pcamff=pcamff)
    raise NotImplementedError("Select model type in ['tbdp', 'hu']")


def define_model(is_resnet, is_densenet, is_senet, model='tbdp',
                 parallel=False, semff=False, pcamff=False):
    """Build a TBDPNet or Hu network on the selected encoder backbone.

    The backbone flags are evaluated independently and in order (ResNet-50,
    DenseNet-161, SENet-154); if several are set, the last enabled one
    determines the returned network.

    Args:
        is_resnet: build on ``resnet.resnet50(pretrained=True)``.
        is_densenet: build on ``densenet.densenet161(pretrained=True)``.
        is_senet: build on ``senet.senet154(pretrained='imagenet')``.
        model: network type, ``'tbdp'`` (``net.TBDPNet``) or ``'hu'``
            (``net.Hu``).
        parallel: forwarded to ``net.TBDPNet`` only.
        semff: forwarded to ``net.Hu`` only.
        pcamff: forwarded to both network types.

    Returns:
        The constructed network.

    Raises:
        NotImplementedError: if ``model`` is not ``'tbdp'`` or ``'hu'``.
        ValueError: if none of the backbone flags is set.
    """
    # Validate the type before any (potentially expensive) encoder is built.
    if model not in ('tbdp', 'hu'):
        raise NotImplementedError("Select model type in ['tbdp', 'hu']")

    # The built network is kept in its own local. Rebinding ``model`` (the
    # type string) to the network made every later backbone branch compare a
    # network object against 'tbdp'/'hu' and raise NotImplementedError, and
    # made the function return the bare type string when no flag was set.
    network = None

    if is_resnet:
        original_model = resnet.resnet50(pretrained=True)
        Encoder = modules.E_resnet(original_model)
        network = _build_depth_net(model, Encoder, 2048,
                                   [256, 512, 1024, 2048],
                                   parallel, semff, pcamff)

    if is_densenet:
        original_model = densenet.densenet161(pretrained=True)
        Encoder = modules.E_densenet(original_model)
        network = _build_depth_net(model, Encoder, 2208,
                                   [192, 384, 1056, 2208],
                                   parallel, semff, pcamff)

    if is_senet:
        original_model = senet.senet154(pretrained='imagenet')
        Encoder = modules.E_senet(original_model)
        network = _build_depth_net(model, Encoder, 2048,
                                   [256, 512, 1024, 2048],
                                   parallel, semff, pcamff)

    if network is None:
        raise ValueError(
            "No encoder selected: set one of is_resnet, is_densenet, is_senet")
    return network
def load_model(model_name='resnet50',resume='Best',start_epoch=0,cn=3, save_dir='saved_models/',width=32,start=8,cls_number=10,avg_number=1,gpus=[0,1,2,3,4,5,6,7],kfold = 1,model_times=0,train=True):
    """Construct the network named by ``model_name``, optionally restore
    saved weights, and move it to the GPU(s).

    NOTE: this function uses Python 2 syntax (``print`` statements, slicing
    the result of ``dict.items()``).

    Args:
        model_name: selects the constructor; a name that matches none of the
            branches below leaves ``model`` unassigned.
        resume: ``'Best'`` searches ``save_dir`` for checkpoints; otherwise a
            non-empty value is used as the checkpoint path when
            ``avg_number == 1``.
        start_epoch: returned as-is unless a single checkpoint is resumed, in
            which case it is parsed from the checkpoint file name.
        cn: number of input channels of the replacement first convolution.
        save_dir, width, start: together with ``model_name`` and ``cn`` these
            form the checkpoint file-name prefix.
        cls_number: passed to the constructors as ``num_classes``.
        avg_number: with ``resume == 'Best'``, how many of the top-ranked
            checkpoints are averaged (1 = load a single checkpoint).
        gpus: device ids for ``DataParallel``; used only when more than one
            id is given. The default list is shared between calls but is
            only read here, never mutated.
        kfold, model_times: when ``kfold > 1`` the prefix additionally
            starts with ``str(model_times) + '_'``.
        train: when ``False``, checkpoints ranked beyond ``avg_number + 2``
            are deleted from disk.

    Returns:
        Tuple ``(model, start_epoch)``.
    """
    # ``load_dict`` is forwarded as ``pretrained``; it is fixed to None here.
    load_dict = None
    #load_dict = True if cn == 3 else None
    # Step 1: build the bare architecture. Several branches also replace the
    # stem convolution immediately so that it accepts ``cn`` input channels.
    if model_name == 'resnet50':
        model = resnet50(num_classes=cls_number,pretrained=load_dict)
    elif model_name == 'resnet101':
        model = resnet101(num_classes=cls_number,pretrained=load_dict)
    elif model_name == 'resnet152':
        model = resnet152(num_classes=cls_number,pretrained=load_dict)
    elif model_name == 'densenet161':
        model = densenet161(num_classes=cls_number,pretrained=load_dict)
    elif model_name == 'xception':
        model = xception(num_classes=cls_number,pretrained=load_dict)
        model.conv1 = nn.Conv2d(cn, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)
    elif model_name == 'inception_v3':
        model = inception_v3(num_classes=cls_number,pretrained=load_dict)
        model.Conv2d_1a_3x3.conv = nn.Conv2d(cn, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)
    elif model_name == 'seinception_v3':
        model = se_inception_v3(num_classes=cls_number)
        model.model.Conv2d_1a_3x3.conv = nn.Conv2d(cn, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)
    elif model_name == 'inception_v4':
        model = inceptionv4(num_classes=cls_number,pretrained=load_dict)
        model.features[0].conv = nn.Conv2d(cn, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)
    elif model_name == 'inceptionresnetv2':
        model = inceptionresnetv2(num_classes=cls_number,pretrained=load_dict)
        model.conv2d_1a.conv = nn.Conv2d(cn, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)
    elif model_name == 'seresnet50':
        model = se_resnet50(num_classes=cls_number,pretrained=load_dict)
    elif model_name == 'seresnet101':
        model = se_resnet101(num_classes=cls_number,pretrained=load_dict)
    elif model_name == 'seresnet152':
        model = se_resnet152(num_classes=cls_number,pretrained=load_dict)
    elif model_name == 'seresnext50':
        model = se_resnext50_32x4d(num_classes=cls_number,pretrained=load_dict)
    elif model_name == 'seresnext101':
        model = se_resnext101_32x4d(num_classes=cls_number,pretrained=load_dict)
    elif model_name == 'resnet50-101':
        model = SimpleNet()
    elif model_name == 'seresnet20':
        model = se_resnet20(num_classes=cls_number)
    elif model_name == 'seresnet32':
        model = se_resnet32(num_classes=cls_number)
    elif model_name == 'seresnet18':
        model = se_resnet18(num_classes=cls_number)
    elif model_name == 'seresnet34':
        model = se_resnet34(num_classes=cls_number)
    elif model_name == 'senet154':
        model = senet154(num_classes=cls_number,pretrained=load_dict)
        model.layer0.conv1 = nn.Conv2d(cn, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    elif model_name == 'nasnet':
        model = nasnetalarge(num_classes=cls_number,pretrained=load_dict)
        model.conv0.conv = nn.Conv2d(cn, 96, kernel_size=(3, 3), stride=(2, 2), bias=False)
    elif model_name == 'dpn98':
        model = dpn98(num_classes=cls_number,pretrained=load_dict)
    elif model_name == 'dpn107':
        model = dpn107(num_classes=cls_number,pretrained=load_dict)
    elif model_name == 'dpn92':
        model = dpn92(num_classes=cls_number,pretrained=load_dict)
    elif model_name == 'polynet':
        model = polynet(num_classes=cls_number,pretrained=load_dict)
        model.stem.conv1[0].conv = nn.Conv2d(cn, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)
    elif model_name == 'pnasnet':
        model = pnasnet5large(num_classes=cls_number,pretrained=load_dict)
        model.conv_0.conv = nn.Conv2d(cn, 96, kernel_size=(3, 3), stride=(2, 2), bias=False)
    #print(model)
    # Step 2: for the remaining families, swap in a ``cn``-channel stem.
    # Names containing '-' (e.g. 'resnet50-101') are skipped. The substring
    # tests are ordered from most to least specific ('seresnext' before
    # 'seresnet' before 'resnet') so each name hits the intended branch.
    if '-' not in model_name and load_dict != True:
        if model_name in ['dpn98',]:
            model.features.conv1_1.conv = nn.Conv2d(cn, 96, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
        elif model_name in ['dpn92',]:
            model.features.conv1_1.conv = nn.Conv2d(cn, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
        elif model_name in ['seresnet20','seresnet32']:
            model.conv1 = nn.Conv2d(cn, 16, kernel_size=3, stride=1, padding=1, bias=False)
        elif model_name in ['seresnet18','seresnet34']:
            model.conv1 = nn.Conv2d(cn, 64, kernel_size=7, stride=2, padding=3, bias=False)
        elif 'seresnext' in model_name:
            model.layer0.conv1 = nn.Conv2d(cn, 64, kernel_size=7, stride=2, padding=3, bias=False)
        elif 'seresnet' in model_name:
            model.layer0.conv1 = nn.Conv2d(cn, 64, kernel_size=7, stride=2, padding=3, bias=False)
        elif 'resnet' in model_name:
            model.conv1 = nn.Conv2d(cn, 64, kernel_size=7, stride=2, padding=3, bias=False)
            #model.fc = torch.nn.Linear(model.fc.in_features,cls_number)
        elif 'densenet' in model_name:
            model.features.conv0 = nn.Conv2d(cn, 96, kernel_size=7, stride=2, padding=3, bias=False)
            # NOTE(review): the nesting of this assignment is ambiguous in
            # the source as received; it may have been intended to run for
            # every family rather than only for densenet -- confirm.
            model.avgpool = torch.nn.AdaptiveAvgPool2d(output_size=1)
        else:
            pass
    #print(model)
    # This local shadows the function's own name inside the body; it is
    # assigned but never read.
    load_model = False
    #if model_name == 'resnet50':
    # The trailing ``and 0`` makes this condition always false, so the
    # ImageNet weight transfer below never runs.
    if load_dict != True and model_name == 'resnet50' and 0:
        base_model = resnet50(pretrained=True)
        model_dict = model.state_dict()
        new_state_dict = OrderedDict()
        # Skips the first and the last two state-dict entries (list slicing
        # of ``items()`` works on Python 2 only).
        for k, v in base_model.state_dict().items()[1:-2]:
            new_state_dict[k] = v
        model_dict.update(new_state_dict)
        model.load_state_dict(model_dict)
        print 'load imagenet'
        load_model = True
    # Step 3: locate checkpoints. File names start with
    # "<model_name>_<width>_<start>_<cn>", prefixed by "<model_times>_" when
    # k-fold training is in use.
    model_ = model_name + '_' + \
        str(width) + '_' + str(start) + '_' + str(cn)
    if kfold > 1:
        model_prefix = save_dir + str(model_times) + '_' + model_
    else:
        model_prefix = save_dir + model_
    if resume == 'Best' and avg_number >= 1:
        weight_path = glob(model_prefix + '*pth')
        # Indices of ``weight_path`` sorted by the number found between the
        # last '[' and the following ']' of each file name, largest first.
        cur_index = np.argsort(-np.array([float(cur_p.split('/')[-1].split('[')[-1].split(']')[0]) for cur_p in weight_path]))
        new_state_dict = OrderedDict()
        if len(weight_path) == 0:
            # Nothing saved yet: disable resuming.
            resume = ''
        elif avg_number == 1:
            # NOTE(review): this takes the first glob result rather than
            # ``weight_path[cur_index[0]]`` (the top-ranked file) -- confirm
            # that this is intended.
            resume = weight_path[0]
        else:
            # Average the weights of the ``avg_number`` top-ranked
            # checkpoints: accumulate a running sum per key and divide on
            # the last pass.
            for cnt,index in zip(range(avg_number),cur_index[:avg_number]):
                cur_resume = weight_path[index]
                print cur_resume
                model.load_state_dict(torch.load(cur_resume))
                for k, v in model.state_dict().items():
                    if cnt == 0:
                        new_state_dict[k] = v
                    else:
                        new_state_dict[k] = new_state_dict[k] + v
                    if cnt == avg_number - 1:
                        new_state_dict[k] = new_state_dict[k] / float(avg_number)
            model.load_state_dict(new_state_dict)
        if train == False:
            # Outside of training, delete every checkpoint ranked after the
            # first ``avg_number + 2``.
            for index in cur_index[avg_number + 2:]:
                cur_resume = weight_path[index]
                print('remove resume %s ' %cur_resume)
                os.remove(cur_resume)
    if resume != '' and avg_number == 1:
        # The epoch is the third-from-last '-'-separated field of the path.
        start_epoch = int(resume.split('-')[-3])
        #print('resuming finetune from %s'%resume)
        logging.info('resuming finetune from %s'%resume)
        model.load_state_dict(torch.load(resume))
    print('start-epoch : ',start_epoch)
    # Step 4: move to GPU; wrap in DataParallel when several ids are given.
    cuda_avail = torch.cuda.is_available()
    if cuda_avail:
        print 'cuda_avail: True'
        if len(gpus) > 1:
            model = torch.nn.DataParallel(model,device_ids=gpus).cuda()
        else:
            model = model.cuda()
    return model,start_epoch
def generate_model(opt):
    """Build the network described by ``opt`` and pick its trainable params.

    Args:
        opt: options object. Read here: ``model``, ``model_depth``,
            ``n_classes``, ``no_cuda``, ``pretrain_path`` and, depending on
            the architecture, ``resnet_shortcut``, ``sample_size``,
            ``sample_duration``, ``wide_resnet_k``, ``resnext_cardinality``;
            with a pretrain path also ``n_finetune_classes``,
            ``ft_begin_index`` and (CPU path only) ``arch``.

    Returns:
        Tuple ``(model, parameters)``. ``parameters`` comes from the
        architecture's ``get_fine_tuning_parameters`` when pretrained weights
        were loaded for a non-senet model, otherwise it is
        ``model.parameters()``.

    Raises:
        AssertionError: if ``opt.model`` or ``opt.model_depth`` is not
            supported (or, on the CPU path, ``opt.arch`` does not match the
            checkpoint).
    """
    assert opt.model in [
        'resnet', 'preresnet', 'wideresnet', 'resnext', 'densenet', 'senet'
    ]

    # Constructors follow a "<prefix><depth>" naming scheme inside each
    # architecture module, so they are looked up by name after the depth has
    # been validated.
    if opt.model == 'resnet':
        assert opt.model_depth in [10, 18, 34, 50, 101, 152, 200]
        from models.resnet import get_fine_tuning_parameters
        builder = getattr(resnet, 'resnet%d' % opt.model_depth)
        model = builder(
            num_classes=opt.n_classes,
            shortcut_type=opt.resnet_shortcut,
            sample_size=opt.sample_size,
            sample_duration=opt.sample_duration)
    elif opt.model == 'wideresnet':
        assert opt.model_depth in [50]
        from models.wide_resnet import get_fine_tuning_parameters
        model = wide_resnet.resnet50(
            num_classes=opt.n_classes,
            shortcut_type=opt.resnet_shortcut,
            k=opt.wide_resnet_k,
            sample_size=opt.sample_size,
            sample_duration=opt.sample_duration)
    elif opt.model == 'resnext':
        assert opt.model_depth in [50, 101, 152]
        from models.resnext import get_fine_tuning_parameters
        builder = getattr(resnext, 'resnet%d' % opt.model_depth)
        model = builder(
            num_classes=opt.n_classes,
            shortcut_type=opt.resnet_shortcut,
            cardinality=opt.resnext_cardinality,
            sample_size=opt.sample_size,
            sample_duration=opt.sample_duration)
    elif opt.model == 'preresnet':
        assert opt.model_depth in [18, 34, 50, 101, 152, 200]
        from models.pre_act_resnet import get_fine_tuning_parameters
        builder = getattr(pre_act_resnet, 'resnet%d' % opt.model_depth)
        model = builder(
            num_classes=opt.n_classes,
            shortcut_type=opt.resnet_shortcut,
            sample_size=opt.sample_size,
            sample_duration=opt.sample_duration)
    elif opt.model == 'densenet':
        assert opt.model_depth in [121, 169, 201, 264]
        from models.densenet import get_fine_tuning_parameters
        builder = getattr(densenet, 'densenet%d' % opt.model_depth)
        model = builder(
            num_classes=opt.n_classes,
            sample_size=opt.sample_size,
            sample_duration=opt.sample_duration)
    elif opt.model == 'senet':
        assert opt.model_depth in [50, 101, 152, 154, 5032, 10132]
        # No ``get_fine_tuning_parameters`` is imported for senet; the
        # fine-tuning paths below return ``model.parameters()`` instead.
        # NOTE(review): depth 5032 maps to ``resnext50_32x4d`` (no "se_"
        # prefix) while 10132 maps to ``se_resnext101_32x4d`` -- confirm the
        # former name against the senet module.
        senet_builders = {
            50: 'se_resnet50',
            101: 'se_resnet101',
            152: 'se_resnet152',
            154: 'senet154',
            5032: 'resnext50_32x4d',
            10132: 'se_resnext101_32x4d',
        }
        builder = getattr(senet, senet_builders[opt.model_depth])
        model = builder(num_classes=opt.n_classes, pretrained=None)

    if not opt.no_cuda:
        model = model.cuda()
        model = nn.DataParallel(model, device_ids=None)

        if opt.pretrain_path:
            print('loading pretrained model {}'.format(opt.pretrain_path))
            pretrain = torch.load(opt.pretrain_path)

            # Merge the checkpoint into the current weights, dropping every
            # key that contains "module.fc" so those entries keep the
            # freshly initialised values.
            pretrained_dict = {
                k: v for k, v in pretrain['state_dict'].items()
                if "module.fc" not in k
            }
            model_dict = model.state_dict()
            model_dict.update(pretrained_dict)
            model.load_state_dict(model_dict)

            # Replace the classification layer with one sized for the
            # fine-tuning classes; its attribute name differs per family.
            if opt.model == 'densenet':
                model.module.classifier = nn.Linear(
                    model.module.classifier.in_features,
                    opt.n_finetune_classes)
                model.module.classifier = model.module.classifier.cuda()
            elif opt.model == "senet":
                model.module.last_linear = nn.Linear(
                    model.module.last_linear.in_features,
                    opt.n_finetune_classes)
                model.module.last_linear = model.module.last_linear.cuda()
                return model, model.parameters()
            else:
                model.module.fc = nn.Linear(model.module.fc.in_features,
                                            opt.n_finetune_classes)
                model.module.fc = model.module.fc.cuda()

            parameters = get_fine_tuning_parameters(model, opt.ft_begin_index)
            return model, parameters
    else:
        if opt.pretrain_path:
            print('loading pretrained model {}'.format(opt.pretrain_path))
            pretrain = torch.load(opt.pretrain_path)
            assert opt.arch == pretrain['arch']

            model.load_state_dict(pretrain['state_dict'])

            if opt.model == 'densenet':
                model.classifier = nn.Linear(
                    model.classifier.in_features, opt.n_finetune_classes)
            elif opt.model == 'senet':
                # Mirrors the CUDA path. Without this branch senet fell into
                # the ``model.fc`` case below and then called
                # ``get_fine_tuning_parameters``, which is never imported
                # for senet.
                model.last_linear = nn.Linear(
                    model.last_linear.in_features, opt.n_finetune_classes)
                return model, model.parameters()
            else:
                model.fc = nn.Linear(model.fc.in_features,
                                     opt.n_finetune_classes)

            parameters = get_fine_tuning_parameters(model, opt.ft_begin_index)
            return model, parameters

    return model, model.parameters()
'ResNet50': lambda: E_resnet(resnet50(pretrained=True)), 'ResNet101': lambda: E_resnet(resnet101(pretrained=True)), 'ResNet152': lambda: E_resnet(resnet152(pretrained=True)), 'DenseNet121': lambda: E_densenet(densenet121(pretrained=True)), 'DenseNet161': lambda: E_densenet(densenet161(pretrained=True)), 'DenseNet169': lambda: E_densenet(densenet169(pretrained=True)), 'DenseNet201': lambda: E_densenet(densenet201(pretrained=True)), 'SENet154': lambda: E_senet(senet154(pretrained="imagenet")), 'SE_ResNet50': lambda: E_senet(se_resnet50(pretrained="imagenet")), 'SE_ResNet101': lambda: E_senet(se_resnet101(pretrained="imagenet")), 'SE_ResNet152': lambda: E_senet(se_resnet152(pretrained="imagenet")), 'SE_ResNext50_32x4d': lambda: E_senet(se_resnext50_32x4d(pretrained="imagenet")), 'SE_ResNext101_32x4d': lambda: E_senet(se_resnext101_32x4d(pretrained="imagenet")) } def get_models(args): backbone = args.backbone