def __init__(self,
             config_file,
             im_size,
             config_spec=None,
             cropping="random",
             cache_dir=None,
             use_cache=False,
             dataset_name="default_dataset_name"):
    """Dataset for generating degraded images on the fly.

    Args:
        config_file: path to the data config file.
        im_size: tuple of (w, h).
        config_spec: path to the data config spec file.
        cropping: cropping mode, one of ["random", "center"].
        cache_dir: directory for cached samples (used when use_cache is True).
        use_cache: whether to cache generated samples.
        dataset_name: name used to identify the (cached) dataset.
    """
    super().__init__()
    if config_spec is None:
        config_spec = _configspec_path()
    config = read_config(config_file, config_spec)
    self.config_file = config_file
    # directory to load linear patches
    patch_dir = config['dataset_dir']
    # dictionary of boolean flags controlling how pipelines are created
    # (see data_configspec for detail).
    pipeline_configs = config['pipeline_configs']
    # dictionary of ranges of params (see data_configspec for detail).
    pipeline_param_ranges = config['pipeline_param_ranges']

    file_list = glob.glob(os.path.join(patch_dir, 'images/target/*.npy'))
    file_list = [os.path.basename(f) for f in file_list]
    file_list = [os.path.splitext(f)[0] for f in file_list]
    # sort by a hash of the name: deterministic but effectively shuffled order
    self.file_list = sorted(file_list,
                            key=lambda x: zlib.adler32(x.encode('utf-8')))

    self.pipeline_param_ranges = pipeline_param_ranges
    self.pipeline_configs = pipeline_configs
    print('Data Pipeline Configs: ', self.pipeline_configs)
    print('Data Pipeline Param Ranges: ', self.pipeline_param_ranges)

    self.data_root = patch_dir
    self.im_size = im_size
    self.cropping = cropping
    self.use_cache = use_cache
    self.cache_dir = cache_dir
    sz = "{}x{}".format(self.im_size[0], self.im_size[1]) \
        if self.im_size is not None else "None"
    self.dataset_name = "_".join([dataset_name, sz])
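# --- Usage sketch (not part of the original file) ----------------------------
# A minimal example of how this on-the-fly degradation dataset might be used.
# The enclosing class name `OnTheFlyDataset` and the config path are
# assumptions for illustration; the excerpt above only shows __init__.
#
#   from torch.utils.data import DataLoader
#
#   dataset = OnTheFlyDataset('data_configs/degrade.conf',
#                             im_size=(128, 128),
#                             cropping='random')
#   loader = DataLoader(dataset, batch_size=8, shuffle=True, num_workers=4)
#   degraded, target = next(iter(loader))   # pairs generated on the fly
# ------------------------------------------------------------------------------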
def __init__(self, config_file, config_spec=None, img_format='.bmp',
             degamma=True, color=True, blind=False, train=True):
    super(TrainDataSet, self).__init__()
    if config_spec is None:
        config_spec = self._configspec_path()
    config = read_config(config_file, config_spec)

    self.dataset_config = config['dataset_configs']
    self.dataset_dir = self.dataset_config['dataset_dir']
    self.images = list(
        filter(lambda x: img_format in x, os.listdir(self.dataset_dir)))
    self.burst_size = self.dataset_config['burst_length']
    self.patch_size = self.dataset_config['patch_size']
    self.upscale = self.dataset_config['down_sample']
    self.big_jitter = self.dataset_config['big_jitter']
    self.small_jitter = self.dataset_config['small_jitter']
    # maximum shift (in pixels) on the image before down-sampling
    self.jitter_upscale = self.big_jitter * self.upscale
    # patch size on the image before down-sampling
    self.size_upscale = self.patch_size * self.upscale + 2 * self.jitter_upscale
    # delta between the big and the small jitter, at the pre-down-sampling scale
    self.delta_upscale = (self.big_jitter - self.small_jitter) * self.upscale
    # patch size mapped back to the original image
    self.patch_size_upscale = self.patch_size * self.upscale
    # whether to remove the gamma correction
    self.degamma = degamma
    # whether to process color images
    self.color = color
    # blind estimation: the estimated noise variance is not fed to the network
    self.blind = blind
    self.train = train
    self.vertical_flip = Random_Vertical_Flip(p=0.5)
    self.horizontal_flip = Random_Horizontal_Flip(p=0.5)
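# Worked example of the size bookkeeping above (illustrative values, not taken
# from the repo's actual config): with patch_size=128, down_sample=4,
# big_jitter=16, small_jitter=2 the derived quantities become
#
#   jitter_upscale     = 16 * 4             = 64
#   size_upscale       = 128 * 4 + 2 * 64   = 640   # crop taken from the full-res image
#   delta_upscale      = (16 - 2) * 4       = 56
#   patch_size_upscale = 128 * 4            = 512
#
# i.e. a 640x640 full-resolution crop leaves room for +/-16 pixels of jitter
# (measured at the down-sampled scale) around a 128x128 training patch.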
def eval(config, args):
    train_config = config['training']
    arch_config = config['architecture']
    use_cache = train_config['use_cache']

    print('Eval Process......')
    checkpoint_dir = train_config['checkpoint_dir']
    if not os.path.exists(checkpoint_dir) or len(os.listdir(checkpoint_dir)) == 0:
        print('There is no checkpoint file in path: {}'.format(checkpoint_dir))
    # the path for saving eval images
    eval_dir = train_config['eval_dir']
    if not os.path.exists(eval_dir):
        os.mkdir(eval_dir)
    files = os.listdir(eval_dir)
    for f in files:
        os.remove(os.path.join(eval_dir, f))

    # dataset and dataloader
    data_set = TrainDataSet(train_config['dataset_configs'],
                            img_format='.bmp',
                            degamma=True,
                            color=False,
                            blind=arch_config['blind_est'],
                            train=False)
    data_loader = DataLoader(data_set,
                             batch_size=1,
                             shuffle=False,
                             num_workers=args.num_workers)
    dataset_config = read_config(train_config['dataset_configs'],
                                 _configspec_path())['dataset_configs']

    # model here
    model = KPN(color=False,
                burst_length=dataset_config['burst_length'],
                blind_est=arch_config['blind_est'],
                kernel_size=list(map(int, arch_config['kernel_size'].split())),
                sep_conv=arch_config['sep_conv'],
                channel_att=arch_config['channel_att'],
                spatial_att=arch_config['spatial_att'],
                upMode=arch_config['upMode'],
                core_bias=arch_config['core_bias'])
    if args.cuda:
        model = model.cuda()
    if args.mGPU:
        model = nn.DataParallel(model)

    # load the trained model
    ckpt = load_checkpoint(checkpoint_dir, args.checkpoint)
    model.load_state_dict(ckpt['state_dict'])
    print('The model has been loaded from epoch {}, n_iter {}.'.format(
        ckpt['epoch'], ckpt['global_iter']))
    # switch to eval mode
    model.eval()

    burst_length = dataset_config['burst_length']
    data_length = burst_length if arch_config['blind_est'] else burst_length + 1
    patch_size = dataset_config['patch_size']

    trans = transforms.ToPILImage()

    with torch.no_grad():
        psnr = 0.0
        ssim = 0.0
        # evaluate on the first 100 images only
        for i, (burst_noise, gt, white_level) in enumerate(data_loader):
            if i < 100:
                if args.cuda:
                    burst_noise = burst_noise.cuda()
                    gt = gt.cuda()
                    white_level = white_level.cuda()

                pred_i, pred = model(burst_noise,
                                     burst_noise[:, 0:burst_length, ...],
                                     white_level)

                pred_i = sRGBGamma(pred_i)
                pred = sRGBGamma(pred)
                gt = sRGBGamma(gt)
                burst_noise = sRGBGamma(burst_noise / white_level)

                psnr_t = calculate_psnr(pred.unsqueeze(1), gt.unsqueeze(1))
                ssim_t = calculate_ssim(pred.unsqueeze(1), gt.unsqueeze(1))
                psnr_noisy = calculate_psnr(burst_noise[:, 0, ...].unsqueeze(1),
                                            gt.unsqueeze(1))
                psnr += psnr_t
                ssim += ssim_t

                pred = torch.clamp(pred, 0.0, 1.0)

                if args.cuda:
                    pred = pred.cpu()
                    gt = gt.cpu()
                    burst_noise = burst_noise.cpu()

                trans(burst_noise[0, 0, ...].squeeze()).save(
                    os.path.join(eval_dir,
                                 '{}_noisy_{:.2f}dB.png'.format(i, psnr_noisy)),
                    quality=100)
                trans(pred.squeeze()).save(
                    os.path.join(eval_dir,
                                 '{}_pred_{:.2f}dB.png'.format(i, psnr_t)),
                    quality=100)
                trans(gt.squeeze()).save(
                    os.path.join(eval_dir, '{}_gt.png'.format(i)),
                    quality=100)

                print('{}-th image is OK, with PSNR: {:.2f}dB, SSIM: {:.4f}'.format(
                    i, psnr_t, ssim_t))
            else:
                break

    # average over the 100 evaluated images
    print('All images are OK, average PSNR: {:.2f}dB, SSIM: {:.4f}'.format(
        psnr / 100, ssim / 100))
def train(config, num_workers, num_threads, cuda, restart_train, mGPU):
    # torch.set_num_threads(num_threads)
    train_config = config['training']
    arch_config = config['architecture']

    batch_size = train_config['batch_size']
    lr = train_config['learning_rate']
    weight_decay = train_config['weight_decay']
    decay_step = train_config['decay_steps']
    lr_decay = train_config['lr_decay']

    n_epoch = train_config['num_epochs']
    use_cache = train_config['use_cache']

    print('Configs:', config)

    # checkpoint path
    checkpoint_dir = train_config['checkpoint_dir']
    if not os.path.exists(checkpoint_dir):
        os.makedirs(checkpoint_dir)

    # logs path
    logs_dir = train_config['logs_dir']
    if not os.path.exists(logs_dir):
        os.makedirs(logs_dir)
    shutil.rmtree(logs_dir)
    log_writer = SummaryWriter(logs_dir)

    # dataset and dataloader
    data_set = TrainDataSet(train_config['dataset_configs'],
                            img_format='.bmp',
                            degamma=True,
                            color=False,
                            blind=arch_config['blind_est'])
    data_loader = DataLoader(data_set,
                             batch_size=batch_size,
                             shuffle=True,
                             num_workers=num_workers)
    dataset_config = read_config(train_config['dataset_configs'],
                                 _configspec_path())['dataset_configs']

    # model here
    model = KPN(color=False,
                burst_length=dataset_config['burst_length'],
                blind_est=arch_config['blind_est'],
                kernel_size=list(map(int, arch_config['kernel_size'].split())),
                sep_conv=arch_config['sep_conv'],
                channel_att=arch_config['channel_att'],
                spatial_att=arch_config['spatial_att'],
                upMode=arch_config['upMode'],
                core_bias=arch_config['core_bias'])
    if cuda:
        model = model.cuda()
    if mGPU:
        model = nn.DataParallel(model)
    model.train()

    # loss function here
    loss_func = LossFunc(coeff_basic=1.0,
                         coeff_anneal=1.0,
                         gradient_L1=True,
                         alpha=arch_config['alpha'],
                         beta=arch_config['beta'])

    # optimizer here
    if train_config['optimizer'] == 'adam':
        optimizer = optim.Adam(model.parameters(), lr=lr)
    elif train_config['optimizer'] == 'sgd':
        optimizer = optim.SGD(model.parameters(),
                              lr=lr,
                              momentum=0.9,
                              weight_decay=weight_decay)
    else:
        raise ValueError(
            "Optimizer must be 'sgd' or 'adam', but received {}.".format(
                train_config['optimizer']))
    optimizer.zero_grad()

    # learning rate scheduler here
    scheduler = lr_scheduler.StepLR(optimizer, step_size=10, gamma=lr_decay)

    average_loss = MovingAverage(train_config['save_freq'])
    if not restart_train:
        try:
            checkpoint = load_checkpoint(checkpoint_dir, 'best')
            start_epoch = checkpoint['epoch']
            global_step = checkpoint['global_iter']
            best_loss = checkpoint['best_loss']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            scheduler.load_state_dict(checkpoint['lr_scheduler'])
            print('=> loaded checkpoint (epoch {}, global_step {})'.format(
                start_epoch, global_step))
        except Exception:
            start_epoch = 0
            global_step = 0
            best_loss = np.inf
            print('=> no checkpoint file to be loaded.')
    else:
        start_epoch = 0
        global_step = 0
        best_loss = np.inf
        if os.path.exists(checkpoint_dir):
            pass
            # files = os.listdir(checkpoint_dir)
            # for f in files:
            #     os.remove(os.path.join(checkpoint_dir, f))
        else:
            os.mkdir(checkpoint_dir)
        print('=> training')

    burst_length = dataset_config['burst_length']
    data_length = burst_length if arch_config['blind_est'] else burst_length + 1
    patch_size = dataset_config['patch_size']

    for epoch in range(start_epoch, n_epoch):
        epoch_start_time = time.time()
        # decay the learning rate, with a floor of 5e-6
        lr_cur = [param['lr'] for param in optimizer.param_groups]
        if lr_cur[0] > 5e-6:
            scheduler.step()
        else:
            for param in optimizer.param_groups:
                param['lr'] = 5e-6
        print('=' * 20,
              'lr={}'.format([param['lr'] for param in optimizer.param_groups]),
              '=' * 20)

        t1 = time.time()
        for step, (burst_noise, gt, white_level) in enumerate(data_loader):
            if cuda:
                burst_noise = burst_noise.cuda()
                gt = gt.cuda()
            # print('white_level', white_level, white_level.size())

            pred_i, pred = model(burst_noise,
                                 burst_noise[:, 0:burst_length, ...],
                                 white_level)

            loss_basic, loss_anneal = loss_func(sRGBGamma(pred_i),
                                                sRGBGamma(pred),
                                                sRGBGamma(gt),
                                                global_step)
            loss = loss_basic + loss_anneal
            # backward
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # update the average loss
            average_loss.update(loss)
            # calculate PSNR and SSIM
            psnr = calculate_psnr(pred.unsqueeze(1), gt.unsqueeze(1))
            ssim = calculate_ssim(pred.unsqueeze(1), gt.unsqueeze(1))

            # add scalars to tensorboardX
            log_writer.add_scalar('loss_basic', loss_basic, global_step)
            log_writer.add_scalar('loss_anneal', loss_anneal, global_step)
            log_writer.add_scalar('loss_total', loss, global_step)
            log_writer.add_scalar('psnr', psnr, global_step)
            log_writer.add_scalar('ssim', ssim, global_step)

            # print progress
            print('{:-4d}\t| epoch {:2d}\t| step {:4d}\t| loss_basic: {:.4f}\t| '
                  'loss_anneal: {:.4f}\t| loss: {:.4f}\t| PSNR: {:.2f}dB\t| '
                  'SSIM: {:.4f}\t| time: {:.2f} seconds.'.format(
                      global_step, epoch, step, loss_basic, loss_anneal, loss,
                      psnr, ssim, time.time() - t1))
            t1 = time.time()

            # global_step
            global_step += 1

            if global_step % train_config['save_freq'] == 0:
                if average_loss.get_value() < best_loss:
                    is_best = True
                    best_loss = average_loss.get_value()
                else:
                    is_best = False
                save_dict = {
                    'epoch': epoch,
                    'global_iter': global_step,
                    'state_dict': model.state_dict(),
                    'best_loss': best_loss,
                    'optimizer': optimizer.state_dict(),
                    'lr_scheduler': scheduler.state_dict()
                }
                save_checkpoint(save_dict,
                                is_best,
                                checkpoint_dir,
                                global_step,
                                max_keep=train_config['ckpt_to_keep'])

        print('Epoch {} is finished, time elapsed {:.2f} seconds.'.format(
            epoch, time.time() - epoch_start_time))
parser.add_argument('--num_threads', '-nt', default=8, type=int,
                    help='number of threads in data loader')
parser.add_argument('--cuda', '-c', action='store_true',
                    help='whether to train on the GPU')
parser.add_argument('--mGPU', '-m', action='store_true',
                    help='whether to train on multiple GPUs')
parser.add_argument('--eval', action='store_true',
                    help='whether to run in evaluation mode')
parser.add_argument('--checkpoint', '-ckpt', dest='checkpoint', type=str,
                    default='best', help='the checkpoint to eval')
args = parser.parse_args()

config = read_config(args.config_file, args.config_spec)
if args.eval:
    eval(config, args)
else:
    train(config, args.num_workers, args.num_threads, args.cuda,
          args.restart, args.mGPU)
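# Example invocations (the script name and config path are assumptions; the
# --config_file/--config_spec/--num_workers/--restart flags are implied by the
# args.* attributes used above but their definitions are not shown in this excerpt):
#
#   python train_eval_syn.py --config_file kpn_specs/kpn_config.conf --cuda
#   python train_eval_syn.py --config_file kpn_specs/kpn_config.conf --cuda --eval -ckpt best
#
# Without --eval the script trains (resuming from the 'best' checkpoint unless
# --restart is given); with --eval it loads the checkpoint named by --checkpoint
# and writes denoised results to the configured eval_dir.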
def __init__(self,
             config_file,
             config_spec=None,
             blind=False,
             cropping="random",
             cache_dir=None,
             use_cache=False,
             dataset_name="synthetic"):
    """Dataset for generating degraded images on the fly.

    Args:
        config_file: path to the data config file.
        config_spec: path to the data config spec file.
        blind: blind estimation; if True, the noise estimate is not fed to the network.
        cropping: cropping mode, one of ["random", "center"].
        cache_dir: directory for cached samples (used when use_cache is True).
        use_cache: whether to cache generated samples.
        dataset_name: name used to identify the (cached) dataset.
    """
    super().__init__()
    if config_spec is None:
        config_spec = _configspec_path()
    config = read_config(config_file, config_spec)
    # self.config_file = config_file

    # dictionary of dataset configs
    self.dataset_configs = config['dataset_configs']
    # directory to load linear patches
    patch_dir = self.dataset_configs['dataset_dir']
    # dictionary of boolean flags controlling how pipelines are created
    # (see data_configspec for detail).
    self.pipeline_configs = config['pipeline_configs']
    # dictionary of ranges of params (see data_configspec for detail).
    self.pipeline_param_ranges = config['pipeline_param_ranges']

    file_list = glob.glob(os.path.join(patch_dir, '*.pth'))
    file_list = [os.path.basename(f) for f in file_list]
    file_list = [os.path.splitext(f)[0] for f in file_list]
    self.file_list = sorted(file_list,
                            key=lambda x: zlib.adler32(x.encode('utf-8')))

    # some variables about the setting of the dataset
    self.data_root = patch_dir
    self.im_size = self.dataset_configs['patch_size']  # the size after down-sampling
    extra_for_bayer = 2  # extra size used to randomly choose the Bayer pattern
    self.big_jitter = self.dataset_configs['big_jitter']
    self.small_jitter = self.dataset_configs['small_jitter']
    self.down_sample = self.dataset_configs['down_sample']
    # image size corresponding to the original image (including big jitter)
    self.im_size_upscale = (self.im_size + 2 * self.big_jitter +
                            extra_for_bayer) * self.down_sample
    # from the big-jitter image to the real image, with extra pixels for the
    # random choice of the Bayer pattern
    self.big_restore_upscale = self.big_jitter * self.down_sample
    # the shift in pixels of the small jitter at the upscaled resolution
    self.small_restore_upscale = self.small_jitter * self.down_sample
    # from big-jitter images to small-jitter images
    self.big2small_upscale = (self.big_jitter - self.small_jitter) * self.down_sample
    # patch size plus the Bayer slack, at the upscaled resolution
    self.im_size_extra = (self.im_size + extra_for_bayer) * self.down_sample
    # blind estimation?
    self.blind = blind
    # others
    self.cropping = cropping
    self.use_cache = use_cache
    self.cache_dir = cache_dir
    sz = "{}x{}".format(self.im_size, self.im_size) \
        if self.im_size is not None else "None"
    self.dataset_name = "_".join([dataset_name, sz])

    # added by Bin Zhang
    self.burst_length = self.dataset_configs['burst_length']
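# Worked example of the size arithmetic above (illustrative values, not from
# the actual config): with patch_size=128, down_sample=4, big_jitter=16,
# small_jitter=2 and extra_for_bayer=2 the quantities become
#
#   im_size_upscale       = (128 + 2*16 + 2) * 4 = 648   # big-jitter crop at full resolution
#   big_restore_upscale   = 16 * 4               = 64
#   small_restore_upscale = 2 * 4                = 8
#   big2small_upscale     = (16 - 2) * 4         = 56
#   im_size_extra         = (128 + 2) * 4        = 520   # patch plus Bayer-phase slack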
def train():
    log_writer = SummaryWriter('./logs')

    parser = argparse.ArgumentParser()
    parser.add_argument('--restart', '-r', action='store_true')
    args = parser.parse_args()

    config = read_config('kpn_specs/att_kpn_config.conf',
                         'kpn_specs/configspec.conf')
    train_config = config['training']

    data_set = TrainDataSet(
        train_config['dataset_configs'],
        img_format='.bmp',
        degamma=True,
        color=True,
        blind=False
    )
    data_loader = DataLoader(
        dataset=data_set,
        batch_size=32,
        shuffle=True,
        num_workers=4
    )

    loss_fn = nn.L1Loss()

    model = Network(True).cuda()
    model.train()

    optimizer = optim.Adam(model.parameters(), lr=5e-5)
    if not args.restart:
        model.load_state_dict(
            load_checkpoint('./noise_models', best_or_latest='best'))

    global_iter = 0
    min_loss = np.inf
    loss_ave = MovingAverage(200)

    import os
    if not os.path.exists('./noise_models'):
        os.mkdir('./noise_models')

    for epoch in range(100):
        for step, (data, A, B) in enumerate(data_loader):
            feed = data[:, 0, ...].cuda()  # first frame of the burst as input
            gt = data[:, -1, ...].cuda()   # last frame of the burst as target
            # print(data.size())
            pred = model(feed)
            loss = loss_fn(pred, gt)
            global_iter += 1
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            log_writer.add_scalar('loss', loss, global_iter)
            loss_ave.update(loss)
            if global_iter % 200 == 0:
                loss_t = loss_ave.get_value()
                min_loss = min(min_loss, loss_t)
                is_best = min_loss == loss_t
                save_checkpoint(
                    model.state_dict(),
                    is_best=is_best,
                    checkpoint_dir='./noise_models',
                    n_iter=global_iter
                )
            print('{: 6d}, epoch {: 3d}, iter {: 4d}, loss {:.4f}'.format(
                global_iter, epoch, step, loss))