def load_models(model_folder, gpu_id=0):
    """
    Load the coarse and/or fine segmentation models described by a model folder.

    The folder must contain an 'infer_config.py'; its general.single_scale entry
    selects which scale(s) are loaded:
      'coarse'  -> only the coarse model is loaded (fine_model is None)
      'fine'    -> only the fine model is loaded (coarse_model is None)
      'DISABLE' -> both models are loaded, coarse first, then fine

    :param model_folder: the folder containing infer_config.py and the per-scale model folders
    :param gpu_id: the gpu device id handed to load_single_model
    :return: an edict with 'infer_cfg', 'coarse_model' and 'fine_model'
    :raises ValueError: if general.single_scale is none of the values above
    """
    assert os.path.isdir(model_folder), 'Model folder does not exist: {}'.format(model_folder)

    models = edict()
    models.infer_cfg = load_config(os.path.join(model_folder, 'infer_config.py'))

    mode = models.infer_cfg.general.single_scale
    if mode == 'coarse':
        scales_to_load = ('coarse',)
    elif mode == 'fine':
        scales_to_load = ('fine',)
    elif mode == 'DISABLE':
        scales_to_load = ('coarse', 'fine')
    else:
        raise ValueError('Unsupported single scale type!')

    # each scale's sub-folder name comes from infer_cfg.<scale>.model_name
    for scale in ('coarse', 'fine'):
        if scale in scales_to_load:
            scale_cfg = getattr(models.infer_cfg, scale)
            scale_folder = os.path.join(model_folder, scale_cfg.model_name)
            setattr(models, scale + '_model', load_single_model(scale_folder, gpu_id))
        else:
            setattr(models, scale + '_model', None)

    return models
def load_seg_model(model_folder, gpu_id=0):
    """
    Load a segmentation network from the latest checkpoint of a model folder.

    :param model_folder: the folder containing 'config_infer.py' and a 'checkpoints' sub-folder
    :param gpu_id: gpu device id(s) to run the segmentation model on. Either a single int
                   (a negative value means cpu) or a sequence of ints (empty means cpu).
    :return: an edict with 'infer_cfg', 'net', 'spacing', 'max_stride' and 'interpolation'
    """
    import numbers

    assert os.path.isdir(model_folder), 'Model folder does not exist: {}'.format(model_folder)

    model = edict()

    # load inference config file
    latest_checkpoint_dir = get_checkpoint_folder(os.path.join(model_folder, 'checkpoints'), -1)
    model.infer_cfg = load_config(os.path.join(model_folder, 'config_infer.py'))

    # normalize gpu_id into a list of device ids; an empty list means run on cpu
    if isinstance(gpu_id, numbers.Integral):
        gpu_ids = [int(gpu_id)] if gpu_id >= 0 else []
    else:
        gpu_ids = [int(g) for g in gpu_id]
    use_gpu = len(gpu_ids) > 0

    # remember the caller's setting so it can be restored instead of deleted
    prev_visible = os.environ.get('CUDA_VISIBLE_DEVICES')
    if use_gpu:
        # NOTE(review): this only affects device enumeration if CUDA is not yet initialized
        os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(str(g) for g in gpu_ids)
    try:
        # load model state; on cpu, remap tensors that were saved from gpu
        chk_file = os.path.join(latest_checkpoint_dir, 'params.pth')
        if use_gpu:
            state = torch.load(chk_file)
        else:
            state = torch.load(chk_file, map_location='cpu')

        # load network module
        net_module = importlib.import_module('segmentation3d.network.' + state['net'])
        net = net_module.SegmentationNet(state['in_channels'], state['out_channels'])
        # wrap before loading so the saved DataParallel ('module.'-prefixed) keys match
        net = nn.parallel.DataParallel(net)
        net.load_state_dict(state['state_dict'])
        net.eval()
        if use_gpu:
            net = net.cuda()
    finally:
        if use_gpu:
            if prev_visible is None:
                os.environ.pop('CUDA_VISIBLE_DEVICES', None)
            else:
                os.environ['CUDA_VISIBLE_DEVICES'] = prev_visible

    model.net = net
    model.spacing = state['spacing']
    model.max_stride = state['max_stride']
    model.interpolation = state['interpolation']
    return model
def _write_train_visualizations(writer, crops, masks, outputs, step):
    """Write slice grids and overlay figures of the first sample of a batch to TensorBoard.

    Channels 0 and 1 of ``outputs`` are binarized at 0.8 for display only; the
    network output tensor itself is not modified.
    """
    # every r-th slice in [begin, end) along tensor dim 3 is shown
    begin, end, r = 2, 31, 3
    image_num = int((end - begin - 1) / r)  # images per grid row

    def slices(volume, channel):
        # one channel of sample 0; the selected slices are moved to the batch dim for make_grid
        return volume[0, channel:channel + 1, :, begin:end:r, :].permute(2, 0, 1, 3)

    def to_hwc(grid):
        return grid.cpu().detach().numpy().transpose((1, 2, 0))

    # detach so visualization never records autograd ops or touches the real output
    outputs = outputs.detach()

    grid_image = make_grid(slices(crops, 0), image_num, normalize=True)
    writer.add_image('/train/image', grid_image, step)

    grid_label = make_grid(slices(masks, 0), image_num)
    writer.add_image('/train/label', grid_label, step)

    # compare instead of assigning in place, so the values are 0/1 copies
    grid_pred1 = make_grid((slices(outputs, 0) > 0.8).to(outputs.dtype), image_num)
    writer.add_image('/train/pred0:1', grid_pred1, step)

    grid_pred2 = make_grid((slices(outputs, 1) > 0.8).to(outputs.dtype), image_num)
    writer.add_image('/train/pred1:2', grid_pred2, step)

    # overlay label and predictions on the image grid
    image_np = to_hwc(grid_image)[:, :, 0]
    gt = label2rgb(to_hwc(grid_label)[:, :, 0], image_np, bg_label=0)
    pred1 = label2rgb(to_hwc(grid_pred1)[:, :, 0], image_np, bg_label=0)
    pred2 = label2rgb(to_hwc(grid_pred2)[:, :, 0], image_np, bg_label=0)

    fig = plt.figure()
    panels = (
        (311, gt, 'label on image'),
        (312, pred1, 'pred0:1 on image'),
        (313, pred2, 'pred1:2 on image'),
    )
    for position, overlay, title in panels:
        ax = fig.add_subplot(position)
        ax.imshow(overlay)
        ax.set_title(title)
    fig.tight_layout()
    writer.add_figure('/train/results', fig, step)
    fig.clear()
    plt.close(fig)


def train(config_file):
    """ Medical image segmentation training engine

    Builds dataset, network, optimizer and loss from the config, then runs one pass
    over the sampler-driven data loader, logging the loss per batch, writing
    TensorBoard visualizations and saving a checkpoint every cfg.train.save_epochs
    epochs.

    :param config_file: the input configuration file
    :return: None
    """
    assert os.path.isfile(config_file), 'Config not found: {}'.format(config_file)

    # load config file
    cfg = load_config(config_file)

    # clean the existing folder if training from scratch
    if cfg.general.resume_epoch < 0 and os.path.isdir(cfg.general.save_dir):
        shutil.rmtree(cfg.general.save_dir)

    # enable logging
    log_file = os.path.join(cfg.general.save_dir, 'train_log.txt')
    logger = setup_logger(log_file, 'seg3d')

    # control randomness during training
    np.random.seed(cfg.general.seed)
    torch.manual_seed(cfg.general.seed)
    if cfg.general.num_gpus > 0:
        torch.cuda.manual_seed(cfg.general.seed)

    # dataset
    dataset = SegmentationDataset(
        imlist_file=cfg.general.imseg_list,
        num_classes=cfg.dataset.num_classes,
        spacing=cfg.dataset.spacing,
        crop_size=cfg.dataset.crop_size,
        default_values=cfg.dataset.default_values,
        sampling_method=cfg.dataset.sampling_method,
        random_translation=cfg.dataset.random_translation,
        interpolation=cfg.dataset.interpolation,
        crop_normalizers=cfg.dataset.crop_normalizers)

    sampler = EpochConcateSampler(dataset, cfg.train.epochs)
    data_loader = DataLoader(dataset, sampler=sampler, batch_size=cfg.train.batchsize,
                             num_workers=cfg.train.num_threads, pin_memory=True)

    net_module = importlib.import_module('segmentation3d.network.' + cfg.net.name)
    net = net_module.SegmentationNet(dataset.num_modality(), cfg.dataset.num_classes)
    max_stride = net.max_stride()
    net_module.parameters_kaiming_init(net)
    if cfg.general.num_gpus > 0:
        net = nn.parallel.DataParallel(net, device_ids=list(range(cfg.general.num_gpus)))
        net = net.cuda()

    # every crop dimension must be divisible by the network's total down-sampling stride
    assert np.all(np.array(cfg.dataset.crop_size) % max_stride == 0), 'crop size not divisible by max stride'

    # training optimizer
    opt = optim.Adam(net.parameters(), lr=cfg.train.lr, betas=cfg.train.betas)

    # load checkpoint if resume epoch >= 0 to keep training
    if cfg.general.resume_epoch >= 0:
        last_save_epoch, batch_start = load_checkpoint(cfg.general.resume_epoch, net, opt, cfg.general.save_dir)
    else:
        last_save_epoch, batch_start = 0, 0

    if cfg.loss.name == 'Focal':
        # reuse focal loss if exists
        loss_func = FocalLoss(class_num=cfg.dataset.num_classes, alpha=cfg.loss.obj_weight,
                              gamma=cfg.loss.focal_gamma)
    elif cfg.loss.name == 'Dice':
        loss_func = MultiDiceLoss(weights=cfg.loss.obj_weight, num_class=cfg.dataset.num_classes,
                                  use_gpu=cfg.general.num_gpus > 0)
    else:
        raise ValueError('Unknown loss function')

    vis_every = 1  # write TensorBoard visualizations every N batches

    writer = SummaryWriter(os.path.join(cfg.general.save_dir, 'tensorboard'))
    try:
        batch_idx = batch_start
        data_iter = iter(data_loader)

        # loop over batches; the timer includes data loading
        for i in range(len(data_loader)):
            begin_t = time.time()

            # next() works on all PyTorch versions; the .next() alias was removed
            crops, masks, frames, filenames = next(data_iter)
            print('training ', filenames)

            if cfg.general.num_gpus > 0:
                crops, masks = crops.cuda(), masks.cuda()

            # clear previous gradients
            opt.zero_grad()

            # network forward and backward
            outputs = net(crops)
            train_loss = loss_func(outputs, masks)
            train_loss.backward()

            # update weights
            opt.step()

            # save training crops for visualization
            if cfg.debug.save_inputs:
                batch_size = crops.size(0)
                save_intermediate_results(list(range(batch_size)), crops, masks, None, frames, filenames,
                                          os.path.join(cfg.general.save_dir, 'batch_{}'.format(i)))

            epoch_idx = batch_idx * cfg.train.batchsize // len(dataset)
            batch_idx += 1
            batch_duration = time.time() - begin_t
            sample_duration = batch_duration * 1.0 / cfg.train.batchsize

            if batch_idx % vis_every == 0:
                _write_train_visualizations(writer, crops, masks, outputs, batch_idx)

            # print training loss per batch
            msg = 'epoch: {}, batch: {}, train_loss: {:.4f}, time: {:.4f} s/vol'
            msg = msg.format(epoch_idx, batch_idx, train_loss.item(), sample_duration)
            logger.info(msg)

            # save checkpoint once per qualifying epoch
            if epoch_idx != 0 and (epoch_idx % cfg.train.save_epochs == 0):
                if last_save_epoch != epoch_idx:
                    save_checkpoint(net, opt, epoch_idx, batch_idx, cfg, config_file, max_stride,
                                    dataset.num_modality())
                    last_save_epoch = epoch_idx

            writer.add_scalar('Train/Loss', train_loss.item(), batch_idx)
    finally:
        writer.close()
def test(config_file):
    '''Medical image segmentation testing engine

    Loads the checkpoint of epoch cfg.test.test_epoch from cfg.general.save_dir and
    runs test_single_case on the first case of every batch of cfg.test.imseg_list.
    Per-case dice / jaccard and an 'average' row are written to
    <save_dir><save_filename>/metric.csv. When cfg.test.save is truthy, the
    prediction, image and label volumes of each case are also saved as NIfTI.

    :param config_file: the input configuration file
    :return: None
    :raises ValueError: if the test list yields no cases
    '''
    assert os.path.isfile(config_file), 'Config not found: {}'.format(config_file)
    total_metric = 0.0
    metric_dict = OrderedDict()
    metric_dict['name'] = []
    metric_dict['dice'] = []
    metric_dict['jaccard'] = []

    cfg = load_config(config_file)
    np.random.seed(cfg.general.seed)
    torch.manual_seed(cfg.general.seed)
    if cfg.general.num_gpus > 0:
        torch.cuda.manual_seed(cfg.general.seed)

    dataset = SegmentationDataset(
        imlist_file=cfg.test.imseg_list,
        num_classes=cfg.dataset.num_classes,
        spacing=cfg.dataset.spacing,
        crop_size=cfg.dataset.crop_size,
        default_values=cfg.dataset.default_values,
        sampling_method=cfg.dataset.sampling_method,
        random_translation=cfg.dataset.random_translation,
        interpolation=cfg.dataset.interpolation,
        crop_normalizers=cfg.dataset.crop_normalizers)
    testloader = DataLoader(dataset, batch_size=cfg.test.batch_size, shuffle=False,
                            num_workers=cfg.test.num_threads, pin_memory=True)
    print('dataset length', len(testloader))
    if len(testloader) == 0:
        # fail clearly instead of dividing by zero when averaging
        raise ValueError('No test cases found in {}'.format(cfg.test.imseg_list))

    net_model = importlib.import_module('segmentation3d.network.' + cfg.net.name)
    net = net_model.SegmentationNet(dataset.num_modality(), cfg.dataset.num_classes)
    net = nn.parallel.DataParallel(net, device_ids=list(range(cfg.general.num_gpus)))
    net = net.cuda()

    state = load_testmodel(cfg.test.test_epoch, net, cfg.general.save_dir)
    net.load_state_dict(state['state_dict'])
    net.eval()

    result_dir = cfg.general.save_dir + cfg.test.save_filename + '/'

    # inference only: no autograd graphs are needed
    with torch.no_grad():
        for image, label, _frame, name in testloader:
            # drop a fixed 6-char prefix and 4-char suffix from the first file name of the batch
            name = name[0][6:-4]
            print('testing patient', name)
            print('image', image.shape)
            # NOTE(review): stridex, stridey, stridez and patch_size are not defined in this
            # function; they must be module-level names defined elsewhere, or this raises NameError.
            prediction, score_map = test_single_case(net, image[0][0], label[0][0], stridex, stridey, stridez,
                                                     patch_size, num_classes=cfg.dataset.num_classes)
            image = image[0][0].numpy()
            label = label[0][0].numpy()
            print('prediction', prediction.shape)
            print('label', label.shape)

            # an empty prediction scores zero instead of being passed to the metric
            if np.sum(prediction) == 0:
                single_metric = (0, 0)
            else:
                single_metric = calculate_metric_percase(prediction, label)
            print('single_metric', single_metric)
            metric_dict['name'].append(name)
            metric_dict['dice'].append(single_metric[0])
            metric_dict['jaccard'].append(single_metric[1])
            total_metric += np.asarray(single_metric)

            if cfg.test.save:
                test_save_path_temp = os.path.join(result_dir, name)
                if not os.path.exists(test_save_path_temp):
                    os.makedirs(test_save_path_temp)
                nib.save(nib.Nifti1Image(prediction.astype(np.float32), np.eye(4)),
                         test_save_path_temp + '/' + 'pred.nii.gz')
                nib.save(nib.Nifti1Image(image.astype(np.float32), np.eye(4)),
                         test_save_path_temp + '/' + 'img.nii.gz')
                nib.save(nib.Nifti1Image(label.astype(np.float32), np.eye(4)),
                         test_save_path_temp + '/' + 'gt.nii.gz')

    # one case is evaluated per batch, so the batch count is the case count
    avg_metric = total_metric / len(testloader)

    # the result folder only exists so far if cases were saved
    os.makedirs(result_dir, exist_ok=True)
    metric_csv = pd.DataFrame(metric_dict)
    metric_csv.to_csv(result_dir + 'metric.csv', index=False)
    print('average metric is {}'.format(avg_metric))
    with open(result_dir + 'metric.csv', 'a+') as f:
        f.write('%s,%.4f,%.4f\n' % ('average', avg_metric[0], avg_metric[1]))
def train(train_config_file):
    """ Medical image segmentation training engine

    Prepares <save_dir>/<model_scale> (wiped and recreated unless resuming). It copies
    the training config into it and the package's default infer_config.py into
    save_dir. It then trains the configured network, logging the loss per batch
    and saving a checkpoint every train.save_epochs epochs.

    :param train_config_file: the input configuration file
    :return: None
    """
    assert os.path.isfile(train_config_file), 'Config not found: {}'.format(train_config_file)

    # load config file
    train_cfg = load_config(train_config_file)

    # clean the existing folder if training from scratch
    model_folder = os.path.join(train_cfg.general.save_dir, train_cfg.general.model_scale)
    if os.path.isdir(model_folder):
        if train_cfg.general.resume_epoch < 0:
            shutil.rmtree(model_folder)
            os.makedirs(model_folder)
    else:
        os.makedirs(model_folder)

    # copy training config into the model folder, inference config one level up (save_dir)
    shutil.copy(train_config_file, os.path.join(model_folder, 'train_config.py'))
    infer_config_file = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'config', 'infer_config.py')
    shutil.copy(infer_config_file, os.path.join(train_cfg.general.save_dir, 'infer_config.py'))

    # enable logging
    log_file = os.path.join(model_folder, 'train_log.txt')
    logger = setup_logger(log_file, 'seg3d')

    # control randomness during training
    np.random.seed(train_cfg.general.seed)
    torch.manual_seed(train_cfg.general.seed)
    if train_cfg.general.num_gpus > 0:
        torch.cuda.manual_seed(train_cfg.general.seed)

    # dataset
    dataset = SegmentationDataset(
        imlist_file=train_cfg.general.imseg_list,
        num_classes=train_cfg.dataset.num_classes,
        spacing=train_cfg.dataset.spacing,
        crop_size=train_cfg.dataset.crop_size,
        sampling_method=train_cfg.dataset.sampling_method,
        random_translation=train_cfg.dataset.random_translation,
        random_scale=train_cfg.dataset.random_scale,
        interpolation=train_cfg.dataset.interpolation,
        crop_normalizers=train_cfg.dataset.crop_normalizers)

    sampler = EpochConcateSampler(dataset, train_cfg.train.epochs)
    data_loader = DataLoader(dataset, sampler=sampler, batch_size=train_cfg.train.batchsize,
                             num_workers=train_cfg.train.num_threads, pin_memory=True)

    net_module = importlib.import_module('segmentation3d.network.' + train_cfg.net.name)
    net = net_module.SegmentationNet(dataset.num_modality(), train_cfg.dataset.num_classes)
    max_stride = net.max_stride()
    net_module.parameters_kaiming_init(net)
    if train_cfg.general.num_gpus > 0:
        net = nn.parallel.DataParallel(net, device_ids=list(range(train_cfg.general.num_gpus)))
        net = net.cuda()

    # every crop dimension must be divisible by the network's total down-sampling stride
    assert np.all(np.array(train_cfg.dataset.crop_size) % max_stride == 0), 'crop size not divisible by max stride'

    # training optimizer
    opt = optim.Adam(net.parameters(), lr=train_cfg.train.lr, betas=train_cfg.train.betas)

    # load checkpoint if resume epoch >= 0
    if train_cfg.general.resume_epoch >= 0:
        last_save_epoch, batch_start = load_checkpoint(train_cfg.general.resume_epoch, net, opt, model_folder)
    else:
        last_save_epoch, batch_start = 0, 0

    if train_cfg.loss.name == 'Focal':
        # reuse focal loss if exists
        loss_func = FocalLoss(class_num=train_cfg.dataset.num_classes, alpha=train_cfg.loss.obj_weight,
                              gamma=train_cfg.loss.focal_gamma, use_gpu=train_cfg.general.num_gpus > 0)
    elif train_cfg.loss.name == 'Dice':
        loss_func = MultiDiceLoss(weights=train_cfg.loss.obj_weight, num_class=train_cfg.dataset.num_classes,
                                  use_gpu=train_cfg.general.num_gpus > 0)
    elif train_cfg.loss.name == 'CE':
        loss_func = CrossEntropyLoss()
    else:
        raise ValueError('Unknown loss function')

    writer = SummaryWriter(os.path.join(model_folder, 'tensorboard'))
    try:
        batch_idx = batch_start
        data_iter = iter(data_loader)

        # loop over batches; the timer includes data loading
        for i in range(len(data_loader)):
            begin_t = time.time()

            # next() works on all PyTorch versions; the .next() alias was removed
            crops, masks, frames, filenames = next(data_iter)

            if train_cfg.general.num_gpus > 0:
                crops, masks = crops.cuda(), masks.cuda()

            # clear previous gradients
            opt.zero_grad()

            # network forward and backward
            outputs = net(crops)
            train_loss = loss_func(outputs, masks)
            train_loss.backward()

            # update weights
            opt.step()

            # save training crops for visualization
            if train_cfg.debug.save_inputs:
                batch_size = crops.size(0)
                save_intermediate_results(list(range(batch_size)), crops, masks, outputs, frames, filenames,
                                          os.path.join(model_folder, 'batch_{}'.format(i)))

            epoch_idx = batch_idx * train_cfg.train.batchsize // len(dataset)
            batch_idx += 1
            batch_duration = time.time() - begin_t
            sample_duration = batch_duration * 1.0 / train_cfg.train.batchsize

            # print training loss per batch
            msg = 'epoch: {}, batch: {}, train_loss: {:.4f}, time: {:.4f} s/vol'
            msg = msg.format(epoch_idx, batch_idx, train_loss.item(), sample_duration)
            logger.info(msg)

            # save checkpoint once per qualifying epoch
            if epoch_idx != 0 and (epoch_idx % train_cfg.train.save_epochs == 0):
                if last_save_epoch != epoch_idx:
                    save_checkpoint(net, opt, epoch_idx, batch_idx, train_cfg, max_stride, dataset.num_modality())
                    last_save_epoch = epoch_idx

            writer.add_scalar('Train/Loss', train_loss.item(), batch_idx)
    finally:
        writer.close()
def test(config_file):
    '''Medical image segmentation testing engine

    Loads the checkpoint of epoch cfg.test.test_epoch from cfg.test.model_dir.
    It thresholds channel 1 of the network output at 0.5 for the first case of
    every batch of the ABUS test set. Per-case dice / jaccard and an 'average' row
    are written to <model_dir><save_filename>/metric.csv. When cfg.test.save is
    truthy, the prediction, image and label volumes are also saved as NIfTI.

    :param config_file: the input configuration file
    :return: None
    :raises ValueError: if the test set yields no cases
    '''
    assert os.path.isfile(config_file), 'Config not found: {}'.format(config_file)
    total_metric = 0.0
    metric_dict = OrderedDict()
    metric_dict['name'] = []
    metric_dict['dice'] = []
    metric_dict['jaccard'] = []

    cfg = load_config(config_file)
    if cfg.general.num_gpus > 0:
        os.environ['CUDA_VISIBLE_DEVICES'] = cfg.general.gpu
    np.random.seed(cfg.general.seed)
    torch.manual_seed(cfg.general.seed)
    if cfg.general.num_gpus > 0:
        torch.cuda.manual_seed(cfg.general.seed)

    dataset = ABUS(base_dir=cfg.test.imseg_list, transform=transforms.Compose([ToTensor()]))
    testloader = DataLoader(dataset, batch_size=cfg.test.batch_size, shuffle=False,
                            num_workers=cfg.test.num_threads, pin_memory=True)
    print('dataset length', len(testloader))
    if len(testloader) == 0:
        # fail clearly instead of dividing by zero when averaging
        raise ValueError('No test cases found in {}'.format(cfg.test.imseg_list))

    net_model = importlib.import_module('segmentation3d.network.' + cfg.net.name)
    net = net_model.SegmentationNet(1, cfg.dataset.num_classes)
    net = nn.parallel.DataParallel(net, device_ids=list(range(cfg.general.num_gpus)))
    net = net.cuda()

    state = load_testmodel(cfg.test.test_epoch, net, cfg.test.model_dir)
    net.load_state_dict(state['state_dict'])
    net.eval()

    result_dir = cfg.test.model_dir + cfg.test.save_filename + '/'

    # inference only: no autograd graphs are needed
    with torch.no_grad():
        for sample in testloader:
            # names are collated into a list; only the first case of each batch is evaluated,
            # so record that case's name (not the whole list) in the CSV and save path
            name = sample['name'][0]
            print('testing patient', name)
            crops, masks = sample['image'].cuda(), sample['label'].cuda()
            outputs = net(crops)

            # first sample, channel 1 (original note: '0 for bh, 1 for map'), binarized at 0.5
            pred = outputs.cpu().numpy()[0, 1, :, :]
            pred = (pred > 0.5).astype(pred.dtype)

            masks = masks.cpu().numpy()[0, :, :, :]
            crops = crops.cpu().numpy()[0, 0, :, :, :]

            # an empty prediction scores zero instead of being passed to the metric
            if np.sum(pred) == 0:
                single_metric = (0, 0)
            else:
                single_metric = calculate_metric_percase(pred, masks)
            print('single_metric', single_metric)
            metric_dict['name'].append(name)
            metric_dict['dice'].append(single_metric[0])
            metric_dict['jaccard'].append(single_metric[1])
            total_metric += np.asarray(single_metric)

            if cfg.test.save:
                test_save_path_temp = os.path.join(result_dir, name)
                if not os.path.exists(test_save_path_temp):
                    os.makedirs(test_save_path_temp)
                print('test_save_path_temp', test_save_path_temp)
                nib.save(nib.Nifti1Image(pred.astype(np.float32), np.eye(4)),
                         test_save_path_temp + '/' + 'pred.nii.gz')
                nib.save(nib.Nifti1Image(crops.astype(np.float32), np.eye(4)),
                         test_save_path_temp + '/' + 'img.nii.gz')
                nib.save(nib.Nifti1Image(masks.astype(np.float32), np.eye(4)),
                         test_save_path_temp + '/' + 'gt.nii.gz')

    # one case is evaluated per batch, so the batch count is the case count
    avg_metric = total_metric / len(testloader)

    # the result folder only exists so far if cases were saved
    os.makedirs(result_dir, exist_ok=True)
    metric_csv = pd.DataFrame(metric_dict)
    metric_csv.to_csv(result_dir + 'metric.csv', index=False)
    print('average metric is {}'.format(avg_metric))
    with open(result_dir + 'metric.csv', 'a+') as f:
        f.write('%s,%.4f,%.4f\n' % ('average', avg_metric[0], avg_metric[1]))
def load_seg_model(model_folder, gpu_id=0):
    """ Load a segmentation model from the latest checkpoint of a model folder.

    :param model_folder: the folder containing the segmentation model (with a 'checkpoints' sub-folder)
    :param gpu_id: the gpu device id to run the segmentation model; a negative value runs on cpu
    :return: an edict with 'infer_cfg', 'net', 'spacing', 'max_stride', 'interpolation',
             'in_channels', 'out_channels' and 'crop_normalizers'
    :raises ValueError: if a stored crop normalizer has an unsupported type
    """
    assert os.path.isdir(model_folder), 'Model folder does not exist: {}'.format(model_folder)

    model = edict()

    # load inference config file stored alongside the latest checkpoint
    latest_checkpoint_dir = get_checkpoint_folder(os.path.join(model_folder, 'checkpoints'), -1)
    model.infer_cfg = load_config(os.path.join(latest_checkpoint_dir, 'infer_config.py'))

    chk_file = os.path.join(latest_checkpoint_dir, 'params.pth')
    use_gpu = gpu_id >= 0

    # remember the caller's setting so it can be restored instead of deleted
    prev_visible = os.environ.get('CUDA_VISIBLE_DEVICES')
    if use_gpu:
        # NOTE(review): this only affects device enumeration if CUDA is not yet initialized
        os.environ['CUDA_VISIBLE_DEVICES'] = '{}'.format(int(gpu_id))
    try:
        # on cpu, remap tensors that were saved from gpu
        if use_gpu:
            state = torch.load(chk_file)
        else:
            state = torch.load(chk_file, map_location='cpu')

        # load network module
        net_module = importlib.import_module('segmentation3d.network.' + state['net'])
        net = net_module.SegmentationNet(state['in_channels'], state['out_channels'], state['dropout'])
        # wrap before loading so the saved DataParallel ('module.'-prefixed) keys match
        net = nn.parallel.DataParallel(net)
        net.load_state_dict(state['state_dict'])
        net.eval()
        if use_gpu:
            net = net.cuda()
    finally:
        if use_gpu:
            if prev_visible is None:
                os.environ.pop('CUDA_VISIBLE_DEVICES', None)
            else:
                os.environ['CUDA_VISIBLE_DEVICES'] = prev_visible

    model.net = net
    model.spacing, model.max_stride, model.interpolation = state['spacing'], state['max_stride'], state['interpolation']
    model.in_channels, model.out_channels = state['in_channels'], state['out_channels']

    # rebuild the intensity normalizers recorded in the checkpoint (type 0: fixed, type 1: adaptive)
    model.crop_normalizers = []
    for crop_normalizer in state['crop_normalizers']:
        if crop_normalizer['type'] == 0:
            mean, stddev, clip = crop_normalizer['mean'], crop_normalizer['stddev'], crop_normalizer['clip']
            model.crop_normalizers.append(FixedNormalizer(mean, stddev, clip))
        elif crop_normalizer['type'] == 1:
            min_p, max_p, clip = crop_normalizer['min_p'], crop_normalizer['max_p'], crop_normalizer['clip']
            model.crop_normalizers.append(AdaptiveNormalizer(min_p, max_p, clip))
        else:
            raise ValueError('Unsupported normalization type.')

    return model