def __init__(self, pathModel, nnArchitecture, nnClassCount, transCrop):
    """Load a trained DenseNet and prepare it plus the image transform.

    Args:
        pathModel: Path of the checkpoint file. The loaded object must
            provide a ``'best_model_wts'`` entry holding the state dict.
        nnArchitecture: ``'DENSE-NET-121'``, ``'DENSE-NET-169'`` or
            ``'DENSE-NET-201'``.
        nnClassCount: Not used by this method.
        transCrop: Size handed to ``transforms.Resize``.

    Raises:
        ValueError: If ``nnArchitecture`` is not one of the names above.
    """
    #---- Initialize the network
    # The positional ``False`` is presumably ``pretrained=False`` -- confirm
    # against the densenet constructors imported by this file.
    if nnArchitecture == 'DENSE-NET-121':
        model = densenet121(False).cuda()
    elif nnArchitecture == 'DENSE-NET-169':
        model = densenet169(False).cuda()
    elif nnArchitecture == 'DENSE-NET-201':
        model = densenet201(False).cuda()
    else:
        # Without this branch an unknown name left ``model`` unbound and the
        # failure surfaced one line later as an UnboundLocalError.
        raise ValueError(
            'Unsupported nnArchitecture: {!r}'.format(nnArchitecture))

    model = torch.nn.DataParallel(model).cuda()

    modelCheckpoint = torch.load(pathModel)
    # strict=False: keys that do not match the wrapped model are ignored.
    model.load_state_dict(modelCheckpoint['best_model_wts'], strict=False)

    # Keep only the feature extractor of the unwrapped network.
    self.model = model.module.features
    self.model.eval()

    #---- Initialize the weights
    # Second-to-last parameter tensor of the feature extractor.
    self.weights = list(self.model.parameters())[-2]

    #---- Initialize the image transform - resize + normalize
    normalize = transforms.Normalize([0.485, 0.456, 0.406],
                                     [0.229, 0.224, 0.225])
    self.transformSequence = transforms.Compose([
        transforms.Resize(transCrop),
        transforms.ToTensor(),
        normalize,
    ])
def get_net(name):
    """Instantiate the network called ``name`` and move it to ``device``.

    Args:
        name: Identifier of the architecture, e.g. ``'densenet169'`` or
            ``'resnet18'``.

    Returns:
        The freshly constructed network after ``.to(device)``.

    Raises:
        SystemExit: With exit status 1 when ``name`` is not recognised.
    """
    if name == 'densenet121':
        net = densenet121()
    elif name == 'densenet161':
        net = densenet161()
    elif name == 'densenet169':
        net = densenet169()
    elif name == 'googlenet':
        net = googlenet()
    elif name == 'inception_v3':
        net = inception_v3()
    elif name == 'mobilenet_v2':
        net = mobilenet_v2()
    elif name == 'resnet18':
        net = resnet18()
    elif name == 'resnet34':
        net = resnet34()
    elif name == 'resnet50':
        net = resnet50()
    elif name == 'resnet_orig':
        net = resnet_orig()
    elif name == 'vgg11_bn':
        net = vgg11_bn()
    elif name == 'vgg13_bn':
        net = vgg13_bn()
    elif name == 'vgg16_bn':
        net = vgg16_bn()
    elif name == 'vgg19_bn':
        net = vgg19_bn()
    else:
        print(f'{name} not a valid model name')
        # A non-zero status tells the calling shell/script that the run
        # failed; the previous status 0 reported success.
        sys.exit(1)
    return net.to(device)
def load_model():
    """Create a ``densenet169`` and load the weights stored beside this file.

    The checkpoint ``models/model18.pth`` is looked up relative to the
    directory containing this module, and its ``'weights'`` entry is used as
    the state dict.

    Returns:
        The network, moved to the GPU when CUDA is available.
    """
    module_dir = os.path.dirname(os.path.abspath(__file__))
    checkpoint_file = os.path.join(module_dir, 'models/model18.pth')

    net = densenet169()
    if torch.cuda.is_available():
        net = net.cuda()

    # The identity map_location hands every storage back unchanged, i.e. the
    # checkpoint is deserialised without being moved to another device.
    checkpoint = torch.load(checkpoint_file,
                            map_location=lambda storage, loc: storage)
    net.load_state_dict(checkpoint['weights'])
    return net
def __init__(self, densenet_path, resnet_path, vgg_path):
    """Assemble the three member networks and restore their weights.

    Args:
        densenet_path: File with the state dict for the DenseNet-169 member.
        resnet_path: File with the state dict for the ResNet-101 member.
        vgg_path: File with the state dict for the VGG16-BN member.
    """
    super(Ensemble, self).__init__()

    # DenseNet-169 member.
    densenet = densenet169(pretrained=True, droprate=0)
    densenet.load_state_dict(torch.load(densenet_path))
    self.densenet = densenet

    # ResNet-101 member: the final layer is swapped for a single-output
    # linear layer before the saved weights are loaded.
    resnet = resnet101()
    resnet.fc = nn.Linear(resnet.fc.in_features, 1)
    resnet.load_state_dict(torch.load(resnet_path))
    self.resnet = resnet

    # VGG16-BN member: classifier entry 6 becomes a single-output layer.
    vgg = vgg16_bn()
    vgg.classifier[6] = nn.Linear(4096, 1)
    vgg.load_state_dict(torch.load(vgg_path))
    self.vgg = vgg
print('Wt1 valid:', Wt1['valid']) class Loss(torch.nn.modules.Module): def __init__(self, Wt1, Wt0): super(Loss, self).__init__() self.Wt1 = Wt1 self.Wt0 = Wt0 def forward(self, inputs, targets, phase): loss = -(self.Wt1[phase] * targets * inputs.log() + self.Wt0[phase] * (1 - targets) * (1 - inputs).log()) return loss model = densenet169(pretrained=True) model = model.cuda() criterion = Loss(Wt1, Wt0) optimizer = torch.optim.Adam(model.parameters(), lr=0.0001) scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=1, verbose=True) # #### Train model model = train_model(model, criterion, optimizer, dataloaders, scheduler,
extClassifier.load_state_dict(best_classifier_wts) model_serial = str(datetime.now().timestamp()) # torch.save(model.state_dict(), # os.path.join(r'C:\Users\wzuo\Developer\ML for APT\models', model_serial + '.model')) # torch.save(patchModel.state_dict(), # os.path.join(r'C:\Users\wzuo\Developer\ML for APT\models', model_serial + '.patchModel')) torch.save(extClassifier.state_dict(), os.path.join(model_path, model_serial + '.clsmodel')) with open(os.path.join(model_path, model_serial + '.json'), 'w') as fp: json.dump(param_dict, fp) return extClassifier model_global = MD.densenet201(pretrained=True) model_local = MD.densenet169(pretrained=True) param_dict['model_global'] = 'densenet201' param_dict['model_local'] = 'densenet169' # todo change architecture from here #patch_model = MD.densenet121(pretrained=True) #patch_model = MD.SimpleNet() #param_dict['patch_base'] = 'densenet121' for param in model_global.parameters(): param.requires_grad = False for param in model_local.parameters(): param.requires_grad = False # Parameters of newly constructed modules have requires_grad=True by default
mask = mask - np.min(mask) mask = mask / np.max(mask) mask = cv2.resize(mask, (img.shape[1], img.shape[0])) heatmap = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET) heatmap = np.float32(heatmap) / 255 cam = heatmap + np.float32(img) cam = cam / np.max(cam) cv2.imwrite(save_dir, np.uint8(255 * cam)) if __name__ == "__main__": img_dir = "image2.png" input_shapes = (3, 320, 320) model = densenet169(input_shapes=input_shapes, num_classes=1) model.load_state_dict(torch.load('model.pth', map_location='cpu')) model.cuda() model.eval() _transform = transforms.Compose([ transforms.Resize((320, 320)), transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) ]) img = default_loader(img_dir) img = _transform(img).unsqueeze(0).cuda() cls_score, cam, grad_cams = model(img) print(cls_score)
def main():
    """Fine-tune a pretrained ``densenet169`` with a 100-way classifier.

    Each epoch runs ``train_length // mini_batch_size + 1`` training batches
    from ``data_loader_``, then evaluates on ``test_length // batch_size + 1``
    test batches, prints the mean test accuracy and saves the whole model
    under ``models/``. The learning rate is multiplied by 0.3 at the start of
    epochs 4 and 8.

    Written so that it parses on both Python 2 and Python 3: single-argument
    ``print(...)`` calls and explicit floor division.
    """
    torch.manual_seed(23)

    data_l = data_loader_(batch_size=64, proportion=0.85, shuffle=True,
                          data_add=2, onehot=False, data_size=224,
                          nb_classes=100)
    print(data_l.train_length)
    print(data_l.test_length)

    model = densenet169(pretrained=True)
    print('===============================')
    print(model)

    # Size the new head from the existing one. The previous hard-coded 2208
    # is the DenseNet-161 feature width and does not match DenseNet-169.
    # NOTE(review): assumes ``model.classifier`` is an ``nn.Linear`` -- confirm
    # for the densenet169 implementation imported by this file.
    model.classifier = nn.Linear(model.classifier.in_features, 100)
    model = model.cuda()

    # Default reduction is the mean, same as the old ``size_average=True``.
    loss = torch.nn.CrossEntropyLoss()
    loss = loss.cuda()

    # Global optimisation: every parameter of the network is trained.
    optimizer = optim.SGD(model.parameters(), lr=(0.001), momentum=0.9,
                          weight_decay=0.0005)

    # ``batch_szie`` (sic) is the attribute name used by the original code;
    # presumably it matches the loader class -- verify before renaming.
    batch_size = data_l.batch_szie
    data_aug_num = data_l.data_add
    mini_batch_size = batch_size // data_aug_num
    epochs = 1000
    print('1')

    for e in range(epochs):
        # Step decay of the learning rate.
        if e in (4, 8):
            for param_group in optimizer.param_groups:
                param_group['lr'] = param_group['lr'] * 0.3

        num_batches_train = data_l.train_length // mini_batch_size
        print(num_batches_train)

        train_acc = 0.0
        cost = 0.0
        for k in range(num_batches_train + 1):
            batch_train_data_X, batch_train_data_Y = data_l.get_train_data()
            # Move the last axis to position 1 (presumably NHWC -> NCHW).
            batch_train_data_X = batch_train_data_X.transpose(0, 3, 1, 2)
            torch_batch_train_data_X = torch.from_numpy(
                batch_train_data_X).float()
            torch_batch_train_data_Y = torch.from_numpy(
                batch_train_data_Y).long()

            cost_temp, acc_temp = train(model, loss, optimizer,
                                        torch_batch_train_data_X,
                                        torch_batch_train_data_Y)
            train_acc += acc_temp
            cost += cost_temp

            if (k + 1) % 10 == 0:
                print('now step train loss is : %f' % (cost_temp))
                print('now step train acc is : %f' % (acc_temp))
            if (k + 1) % 20 == 0:
                print('all average train loss is : %f' % (cost / (k + 1)))
                print('all average train acc is : %f' % (train_acc / (k + 1)))

        # Evaluation after every epoch.
        acc = 0.0
        num_batches_test = data_l.test_length // batch_size
        test_batches = num_batches_test + 1
        for j in range(test_batches):
            teX, teY = data_l.get_test_data()
            teX = teX.transpose(0, 3, 1, 2)
            teX = torch.from_numpy(teX).float()
            predY = predict(model, teX)
            acc += 1. * np.mean(predY == teY)

        # Divide by the number of batches actually evaluated. The previous
        # divisor was one too small, which inflated the reported accuracy
        # and divided by zero when the test set was smaller than one batch.
        print('Epoch %d ,Step %d, all test acc is : %f'
              % (e, k, acc / test_batches))

        # NOTE(review): the file name says densenet161 although the network
        # is densenet169; left unchanged because other tooling may rely on it.
        torch.save(model,
                   'models/densenet161_model_pretrained_%s_%s_%s_4.pkl'
                   % ('SGD', str(e), str(k)))

    print('train over')
data_cat = ['train', 'valid'] # data categories dataloaders = get_dataloaders(study_data, batch_size) dataset_sizes = {x: len(study_data[x]) for x in data_cat} # tai = total abnormal images, tni = total normal images tai = {x: get_count(study_data[x], 'positive') for x in data_cat} tni = {x: get_count(study_data[x], 'negative') for x in data_cat} # Find the weights of abnormal images and normal images Wt1 = {x: (tni[x] / (tni[x] + tai[x])) for x in data_cat} Wt0 = {x: (tai[x] / (tni[x] + tai[x])) for x in data_cat} # For training & testing individual models if model_type != 'ensemble': if model_type == "dense": model = densenet169(pretrained=True, droprate=droprate) elif model_type == "vgg": model = vgg16_bn(pretrained=True) model.classifier[6] = nn.Linear(4096, 1) elif model_type == "shufflenet": model = shufflenet_v2_x1_0() num_ftrs = model.fc.in_features model.fc = nn.Linear(num_ftrs, 1) else: model = resnet101() num_ftrs = model.fc.in_features model.fc = nn.Linear(num_ftrs, 1)
torch_dataset_train = data.TensorDataset(train_data, train_label) torch_dataset_val = data.TensorDataset(val_data, val_label) train_loader = data.DataLoader(dataset=torch_dataset_train, batch_size=BATCH_SIZE, shuffle=True, drop_last=True) val_loader = data.DataLoader(dataset=torch_dataset_val, batch_size=BATCH_SIZE, shuffle=False, drop_last=True) test_loader = data.DataLoader(test_data, batch_size=test_num, shuffle=False, drop_last=True) model = densenet169() model = ACSConverter(model) model = model.cuda() criterion = torch.nn.CrossEntropyLoss() optimizer = torch.optim.Adam(model.parameters(), lr=0.0001) # train and evaluate for epoch in range(NUM_EPOCHS): print(epoch) train_loss = 0 train_acc = 0 val_loss = 0 val_acc = 0 for step, (batch_x, batch_y) in enumerate(train_loader):
def generate_model(opt):
    """Build the network selected by ``opt`` and return it with its parameters.

    ``opt.model`` chooses the family and ``opt.model_depth`` the variant.
    Unless ``opt.no_cuda`` is set, the model is moved to the GPU and wrapped
    in ``nn.DataParallel``. When ``opt.pretrain_path`` is given, the
    checkpoint is loaded, the final layer is replaced by a new linear layer
    with ``opt.n_finetune_classes`` outputs, and the parameters come from the
    family's ``get_fine_tuning_parameters``.

    Returns:
        ``(model, parameters)``; ``parameters`` is ``model.parameters()``
        when no pretrained checkpoint is used.
    """
    assert opt.model in [
        'resnet', 'preresnet', 'wideresnet', 'resnext', 'densenet'
    ]

    def build(module, prefix, **extra):
        # Resolve e.g. ``resnet.resnet50`` from the requested depth and call
        # it with the arguments shared by every constructor.
        constructor = getattr(module, '%s%d' % (prefix, opt.model_depth))
        return constructor(num_classes=opt.n_classes,
                           sample_size=opt.sample_size,
                           sample_duration=opt.sample_duration,
                           **extra)

    if opt.model == 'resnet':
        assert opt.model_depth in [10, 18, 34, 50, 101, 152, 200]
        from resnet import get_fine_tuning_parameters
        model = build(resnet, 'resnet', shortcut_type=opt.resnet_shortcut)
    elif opt.model == 'wideresnet':
        assert opt.model_depth in [50]
        from models.wide_resnet import get_fine_tuning_parameters
        model = build(wide_resnet, 'resnet',
                      shortcut_type=opt.resnet_shortcut,
                      k=opt.wide_resnet_k)
    elif opt.model == 'resnext':
        assert opt.model_depth in [50, 101, 152]
        from models.resnext import get_fine_tuning_parameters
        model = build(resnext, 'resnet',
                      shortcut_type=opt.resnet_shortcut,
                      cardinality=opt.resnext_cardinality)
    elif opt.model == 'preresnet':
        assert opt.model_depth in [18, 34, 50, 101, 152, 200]
        from models.pre_act_resnet import get_fine_tuning_parameters
        model = build(pre_act_resnet, 'resnet',
                      shortcut_type=opt.resnet_shortcut)
    elif opt.model == 'densenet':
        assert opt.model_depth in [121, 169, 201, 264]
        from models.densenet import get_fine_tuning_parameters
        model = build(densenet, 'densenet')

    on_gpu = not opt.no_cuda
    if on_gpu:
        model = model.cuda()
        model = nn.DataParallel(model, device_ids=None)

    if not opt.pretrain_path:
        return model, model.parameters()

    print('loading pretrained model {}'.format(opt.pretrain_path))
    pretrain = torch.load(opt.pretrain_path)
    assert opt.arch == pretrain['arch']
    model.load_state_dict(pretrain['state_dict'])

    # The real network sits behind ``.module`` once wrapped in DataParallel.
    net = model.module if on_gpu else model
    if opt.model == 'densenet':
        head = nn.Linear(net.classifier.in_features, opt.n_finetune_classes)
        net.classifier = head.cuda() if on_gpu else head
    else:
        head = nn.Linear(net.fc.in_features, opt.n_finetune_classes)
        net.fc = head.cuda() if on_gpu else head

    parameters = get_fine_tuning_parameters(model, opt.ft_begin_index)
    return model, parameters
def _to_single_channel(model):
    """Replace ``model.conv1`` with an equivalent 1-input-channel convolution."""
    old = model.conv1
    model.conv1 = torch.nn.Conv2d(in_channels=1,
                                  out_channels=old.out_channels,
                                  kernel_size=old.kernel_size,
                                  stride=old.stride,
                                  padding=old.padding,
                                  # Conv2d expects a bool here; ``old.bias``
                                  # itself is either None or a tensor.
                                  bias=old.bias is not None)
    return model


def get_model(args):
    """Construct the network named by ``args.network``.

    Args:
        args: Object with ``network`` (architecture name) and ``class_num``
            (passed to the constructor as ``num_classes``).

    Returns:
        The constructed model. ResNet variants get their first convolution
        replaced so that it accepts a single input channel.

    Raises:
        ValueError: If ``args.network`` is not a supported name. Previously
            this case failed with an UnboundLocalError at ``return``.
    """
    network = args.network

    if network in ('vgg11', 'vgg13', 'vgg16', 'vgg19',
                   'vgg11_bn', 'vgg13_bn', 'vgg16_bn', 'vgg19_bn'):
        return getattr(vgg, network)(num_classes=args.class_num)

    if network in ('resnet18', 'resnet34', 'resnet50',
                   'resnet101', 'resnet152'):
        model = getattr(models, network)(num_classes=args.class_num)
        return _to_single_channel(model)

    if network in ('densenet121', 'densenet169',
                   'densenet161', 'densenet201'):
        return getattr(densenet, network)(num_classes=args.class_num)

    raise ValueError('Unsupported network: {!r}'.format(network))
def generate_model(opt):
    """Build the network selected by ``opt`` for scoring or feature extraction.

    ``opt.mode`` decides whether the final fully connected layer is kept
    (``'score'``) or dropped (``'feature'``); ``opt.model_name`` and
    ``opt.model_depth`` pick the architecture. Unless ``opt.no_cuda`` is set,
    the model is moved to the GPU and wrapped in ``nn.DataParallel``.

    Returns:
        The constructed (and possibly wrapped) model.
    """
    assert opt.mode in ['score', 'feature']
    last_fc = opt.mode == 'score'

    assert opt.model_name in [
        'resnet', 'preresnet', 'wideresnet', 'resnext', 'densenet'
    ]

    def build(module, prefix, **extra):
        # Resolve e.g. ``resnext.resnet101`` from the requested depth and
        # call it with the arguments shared by every constructor.
        constructor = getattr(module, '%s%d' % (prefix, opt.model_depth))
        return constructor(num_classes=opt.n_classes,
                           sample_size=opt.sample_size,
                           sample_duration=opt.sample_duration,
                           last_fc=last_fc,
                           **extra)

    if opt.model_name == 'resnet':
        assert opt.model_depth in [10, 18, 34, 50, 101, 152, 200]
        model = build(resnet, 'resnet', shortcut_type=opt.resnet_shortcut)
    elif opt.model_name == 'wideresnet':
        assert opt.model_depth in [50]
        model = build(wide_resnet, 'resnet',
                      shortcut_type=opt.resnet_shortcut,
                      k=opt.wide_resnet_k)
    elif opt.model_name == 'resnext':
        assert opt.model_depth in [50, 101, 152]
        model = build(resnext, 'resnet',
                      shortcut_type=opt.resnet_shortcut,
                      cardinality=opt.resnext_cardinality)
    elif opt.model_name == 'preresnet':
        assert opt.model_depth in [18, 34, 50, 101, 152, 200]
        model = build(pre_act_resnet, 'resnet',
                      shortcut_type=opt.resnet_shortcut)
    elif opt.model_name == 'densenet':
        assert opt.model_depth in [121, 169, 201, 264]
        model = build(densenet, 'densenet')

    if not opt.no_cuda:
        model = model.cuda()
        model = nn.DataParallel(model, device_ids=None)

    return model
def inference(self, inputs, is_training=True, reuse=False):
    """Run the forward pass of the configured network and return the logits.

    The dropout keep probability (``self.dropprob``) should be set to 1 for
    prediction.

    Args:
        inputs: Input tensor handed to the network builder.
        is_training: Forwarded to the network builder.
        reuse: Forwarded to the network builder.

    Returns:
        ``tf.squeeze`` applied to the logits produced by the network.

    Raises:
        ValueError: If ``self.model`` is not one of ``"vgg16"``, ``"res50"``,
            ``"res50_senet"`` or ``"densenet"``.
    """
    # Variable scope used when the caller did not set one explicitly.
    default_scopes = {
        "vgg16": 'vgg_16',
        "res50": 'resnet_v1_50',
        "res50_senet": 'resnet_v1_50',
        "densenet": 'densenet169',
    }
    if self.model not in default_scopes:
        # The old message referenced an undefined name and a format string
        # without a placeholder, so it never produced this ValueError.
        raise ValueError("Unknown model: %r" % (self.model,))

    if self.scope is None:
        self.scope = default_scopes[self.model]

    # Arguments common to all three network builders.
    shared = dict(inputs=inputs,
                  num_classes=self.n_classes,
                  is_training=is_training,
                  reuse=reuse,
                  scope=self.scope,
                  weight_decay=self.l2_rate,
                  use_batch_norm=self.use_bn,
                  batch_norm_decay=self.bn_decay,
                  batch_norm_epsilon=self.bn_epsilon,
                  batch_norm_scale=self.bn_scale)

    if self.model == "vgg16":
        logits = vgg16.vgg_16(dropout_keep_prob=self.dropprob, **shared)
    elif self.model == "densenet":
        logits = densenet.densenet169(dropout_keep_prob=self.dropprob,
                                      **shared)
    else:
        # "res50" and "res50_senet" differ only in the SE-module switch.
        logits = resnet.resnet_50(
            use_se_module=(self.model == "res50_senet"), **shared)

    return tf.squeeze(logits)