def validate(val_loader, model, logger, epoch, current_step, val_dataset_opt):
    """Run one validation pass: save SR outputs and log the average PSNR.

    Args:
        val_loader: iterable of validation batches containing 'LR_path' and
            the tensors the model consumes via feed_data().
        model: wrapped model exposing eval/feed_data/val/get_current_visuals.
        logger: project logger with print_format_results().
        epoch: current epoch number (logging only).
        current_step: current training iteration (filenames and logging).
        val_dataset_opt: validation dataset options; 'scale' sets the border
            crop used before computing PSNR.
    """
    print('---------- validation -------------')
    val_start_time = time.time()
    model.eval()  # Change to eval mode. It is important for BN layers.
    avg_psnr = 0.0
    idx = 0
    for val_data in val_loader:
        idx += 1
        img_path = val_data['LR_path'][0]
        img_name = os.path.splitext(os.path.basename(img_path))[0]
        # NOTE(review): relies on a module-level `opt` that is not a
        # parameter of this function — confirm it exists in this module.
        img_dir = os.path.join(opt['path']['val_images'], img_name)
        util.mkdir(img_dir)

        model.feed_data(val_data, volatile=True)
        model.val()

        visuals = model.get_current_visuals()
        sr_img = util.tensor2img_np(visuals['SR'])  # uint8
        gt_img = util.tensor2img_np(visuals['HR'])  # uint8

        # SR and GT may differ slightly in size; compare the common area only.
        h_min = min(sr_img.shape[0], gt_img.shape[0])
        w_min = min(sr_img.shape[1], gt_img.shape[1])
        sr_img = sr_img[0:h_min, 0:w_min, :]
        gt_img = gt_img[0:h_min, 0:w_min, :]

        # Strip a (scale + 2)-pixel border before computing PSNR.
        crop_size = val_dataset_opt['scale'] + 2
        cropped_sr_img = sr_img[crop_size:-crop_size, crop_size:-crop_size, :]
        cropped_gt_img = gt_img[crop_size:-crop_size, crop_size:-crop_size, :]

        # Save SR images for reference
        save_img_path = os.path.join(img_dir,
                                     '%s_%s.png' % (img_name, current_step))
        util.save_img_np(sr_img.squeeze(), save_img_path)

        avg_psnr += util.psnr(cropped_sr_img, cropped_gt_img)

    # Guard against an empty val_loader (was a ZeroDivisionError).
    avg_psnr = avg_psnr / max(idx, 1)
    val_elapsed = time.time() - val_start_time

    # Save to log
    print_rlt = OrderedDict()
    print_rlt['model'] = opt['model']
    print_rlt['epoch'] = epoch
    print_rlt['iters'] = current_step
    print_rlt['time'] = val_elapsed
    print_rlt['psnr'] = avg_psnr
    logger.print_format_results('val', print_rlt)

    model.train()  # change back to train mode.
    print('-----------------------------------')
def validate(val_loader, opt, model, current_step, epoch, logger):
    """Run one validation pass: save predictions, log mean PSNR and LPIPS.

    Args:
        val_loader: iterable of validation batches with 'LR' and 'LR_path'.
        opt: option dict (paths, train flags, experiment/model names).
        model: wrapped model exposing gen_code/feed_data/test/get_loss/save.
        current_step: current training iteration (filenames and logging).
        epoch: current epoch number (logging only).
        logger: project logger with print_format_results().
    """
    print('---------- validation -------------')
    start_time = time.time()
    avg_psnr = 0.0
    avg_lpips = 0.0
    idx = 0
    # The code factory depends only on opt — pick it once, not per image.
    tensor_type = torch.zeros if opt['train']['zero_code'] else torch.randn
    for val_data in val_loader:
        idx += 1
        img_name = os.path.splitext(os.path.basename(
            val_data['LR_path'][0]))[0]
        img_dir = os.path.join(opt['path']['val_images'], img_name)
        util.mkdir(img_dir)

        code = model.gen_code(val_data['LR'].shape[0],
                              val_data['LR'].shape[2],
                              val_data['LR'].shape[3],
                              tensor_type=tensor_type)
        model.feed_data(val_data, code=code)
        model.test()

        visuals = model.get_current_visuals()
        sr_img = util.tensor2img(visuals['HR_pred'])  # uint8
        gt_img = util.tensor2img(visuals['HR'])  # uint8

        # Save generated images for reference
        save_img_path = os.path.join(
            img_dir,
            '{:s}_{:s}_{:d}.png'.format(opt['name'], img_name, current_step))
        util.save_img(sr_img, save_img_path)

        # calculate PSNR / LPIPS (removed the old no-op self-assignments)
        avg_psnr += util.psnr(sr_img, gt_img)
        avg_lpips += torch.sum(model.get_loss(level=-1))

    if current_step == 0:
        # Snapshot the untrained model once, not once per validation image.
        print('Saving the model at the end of iter {:d}.'.format(current_step))
        model.save(current_step)

    # Guard against an empty val_loader (was a ZeroDivisionError).
    idx = max(idx, 1)
    avg_psnr = avg_psnr / idx
    avg_lpips = avg_lpips / idx
    time_elapsed = time.time() - start_time

    # Save to log
    print_rlt = OrderedDict()
    print_rlt['model'] = opt['model']
    print_rlt['epoch'] = epoch
    print_rlt['iters'] = current_step
    print_rlt['time'] = time_elapsed
    print_rlt['psnr'] = avg_psnr
    print_rlt['lpips'] = avg_lpips
    logger.print_format_results('val', print_rlt)
    print('-----------------------------------')
def validate(val_loader, opt, model, current_step, epoch, logger):
    """Validate the model: write one sample per image, log mean PSNR/LPIPS."""
    print('---------- validation -------------')
    t0 = time.time()
    psnr_sum = 0.0
    lpips_sum = 0.0
    n_imgs = 0
    for val_data in val_loader:
        n_imgs += 1
        lr = val_data['LR']
        stem = os.path.splitext(os.path.basename(val_data['LR_path'][0]))[0]
        out_dir = os.path.join(opt['path']['val_images'], stem)
        util.mkdir(out_dir)

        # One Gaussian latent code at the LR spatial resolution.
        latent = torch.randn(lr.shape[0],
                             int(opt['network_G']['in_code_nc']),
                             lr.shape[2], lr.shape[3])
        model.feed_data(val_data, code=[latent])
        model.test()

        visuals = model.get_current_visuals()
        sr_img = util.tensor2img(visuals['SR'])  # uint8
        gt_img = util.tensor2img(visuals['HR'])  # uint8

        # Save SR images for reference; arch/run tags parsed from exp name.
        name_parts = opt['name'].split("_")
        arch_name = name_parts[2]
        run_index = name_parts[3]
        save_img_path = os.path.join(
            out_dir, 'HyperRIM_{:s}_{:s}_{:s}_x2_{:d}.png'.format(
                arch_name, run_index, stem, current_step))
        util.save_img(sr_img, save_img_path)

        # calculate PSNR and accumulate the model's LPIPS-style loss
        psnr_sum += util.psnr(sr_img, gt_img)
        lpips_sum += torch.sum(model.get_loss(level=-1))
        if current_step == 0:
            print('Saving the model at the end of iter {:d}.'.format(
                current_step))
            model.save(current_step)

    mean_psnr = psnr_sum / n_imgs
    mean_lpips = lpips_sum / n_imgs
    elapsed = time.time() - t0

    # Save to log
    print_rlt = OrderedDict()
    print_rlt['model'] = opt['model']
    print_rlt['epoch'] = epoch
    print_rlt['iters'] = current_step
    print_rlt['time'] = elapsed
    print_rlt['psnr'] = mean_psnr
    print_rlt['lpips'] = mean_lpips
    logger.print_format_results('val', print_rlt)
    print('-----------------------------------')
def get_sim(root_path, ori_path, n_samples):
    """Average PSNR/SSIM between each reference image and its generated samples.

    Args:
        root_path: prefix (directory path, trailing separator expected by the
            original callers) of generated samples named '<stem>_<j>.png'.
        ori_path: glob pattern matching the reference images.
        n_samples: number of generated samples per reference image.

    Returns:
        Tuple (mean_psnr, mean_ssim) over every (reference, sample) pair.

    Raises:
        ZeroDivisionError: if the glob matches nothing or n_samples is 0.
    """
    img_list = sorted(glob.glob(ori_path))
    total_psnr = 0.
    total_ssim = 0.
    count = 0
    for img_path in img_list:
        # Portable stem extraction; the old v.split("/") broke on Windows.
        # (splitext strips only the final extension, which is the intent.)
        stem = os.path.splitext(os.path.basename(img_path))[0]
        ref_np = load_image(img_path)
        for j in range(n_samples):
            sample_np = load_image(root_path + stem + "_" + str(j) + ".png")
            total_psnr += util.psnr(ref_np, sample_np)
            total_ssim += util.ssim(ref_np, sample_np, multichannel=True)
            count += 1
    return total_psnr / count, total_ssim / count
def main():
    """Training entry point, driven by the JSON option file passed via -opt.

    Parses options, builds train/val dataloaders, then trains for
    opt['train']['niter'] iterations with periodic logging, checkpointing
    and Y-channel PSNR validation.
    """
    # options
    parser = argparse.ArgumentParser()
    parser.add_argument('-opt', type=str, required=True,
                        help='Path to option JSON file.')
    opt = option.parse(parser.parse_args().opt, is_train=True)
    util.mkdir_and_rename(
        opt['path']['experiments_root'])  # rename old experiments if exists
    # Create every configured path except the experiment root (handled above)
    # and the pretrained-model paths (expected to already exist).
    util.mkdirs((path for key, path in opt['path'].items()
                 if not key == 'experiments_root' and \
                 not key == 'pretrain_model_G' and not key == 'pretrain_model_D'))
    option.save(opt)
    opt = option.dict_to_nonedict(
        opt)  # Convert to NoneDict, which return None for missing key.

    # print to file and std_out simultaneously
    sys.stdout = PrintLogger(opt['path']['log'])

    # random seed
    seed = opt['train']['manual_seed']
    if seed is None:
        seed = random.randint(1, 10000)
    print("Random Seed: ", seed)
    random.seed(seed)
    torch.manual_seed(seed)

    # create train and val dataloader
    for phase, dataset_opt in opt['datasets'].items():
        if phase == 'train':
            train_set = create_dataset(dataset_opt)
            train_size = int(
                math.ceil(len(train_set) / dataset_opt['batch_size']))
            print('Number of train images: {:,d}, iters: {:,d}'.format(
                len(train_set), train_size))
            total_iters = int(opt['train']['niter'])
            total_epoches = int(math.ceil(total_iters / train_size))
            print('Total epoches needed: {:d} for iters {:,d}'.format(
                total_epoches, total_iters))
            train_loader = create_dataloader(train_set, dataset_opt)
        elif phase == 'val':
            val_dataset_opt = dataset_opt
            val_set = create_dataset(dataset_opt)
            val_loader = create_dataloader(val_set, dataset_opt)
            print('Number of val images in [{:s}]: {:d}'.format(
                dataset_opt['name'], len(val_set)))
        else:
            raise NotImplementedError(
                'Phase [{:s}] is not recognized.'.format(phase))
    assert train_loader is not None

    # Create model
    model = create_model(opt)
    # create logger
    logger = Logger(opt)

    current_step = 0
    start_time = time.time()
    print('---------- Start training -------------')
    for epoch in range(total_epoches):
        for i, train_data in enumerate(train_loader):
            current_step += 1
            if current_step > total_iters:
                break
            # training
            model.feed_data(train_data)
            model.optimize_parameters(current_step)

            time_elapsed = time.time() - start_time
            start_time = time.time()

            # log
            if current_step % opt['logger']['print_freq'] == 0:
                logs = model.get_current_log()
                print_rlt = OrderedDict()
                print_rlt['model'] = opt['model']
                print_rlt['epoch'] = epoch
                print_rlt['iters'] = current_step
                print_rlt['time'] = time_elapsed
                for k, v in logs.items():
                    print_rlt[k] = v
                print_rlt['lr'] = model.get_current_learning_rate()
                logger.print_format_results('train', print_rlt)

            # save models
            if current_step % opt['logger']['save_checkpoint_freq'] == 0:
                print('Saving the model at the end of iter {:d}.'.format(
                    current_step))
                model.save(current_step)

            # validation
            if current_step % opt['train']['val_freq'] == 0:
                print('---------- validation -------------')
                start_time = time.time()
                avg_psnr = 0.0
                idx = 0
                for val_data in val_loader:
                    idx += 1
                    img_name = os.path.splitext(
                        os.path.basename(val_data['LR_path'][0]))[0]
                    img_dir = os.path.join(opt['path']['val_images'],
                                           img_name)
                    util.mkdir(img_dir)

                    model.feed_data(val_data)
                    model.test()

                    visuals = model.get_current_visuals()
                    sr_img = util.tensor2img(visuals['SR'])  # uint8
                    gt_img = util.tensor2img(visuals['HR'])  # uint8

                    # Save SR images for reference
                    save_img_path = os.path.join(img_dir, '{:s}_{:d}.png'.format(\
                        img_name, current_step))
                    util.save_img(sr_img, save_img_path)

                    # calculate PSNR
                    # crop_scale (when set) overrides the model scale as the
                    # border width stripped before computing metrics.
                    if opt['crop_scale'] is not None:
                        crop_size = opt['crop_scale']
                    else:
                        crop_size = opt['scale']
                    if crop_size <= 0:
                        cropped_sr_img = sr_img.copy()
                        cropped_gt_img = gt_img.copy()
                    else:
                        if len(gt_img.shape) < 3:  # grayscale: no channel axis
                            cropped_sr_img = sr_img[crop_size:-crop_size,
                                                    crop_size:-crop_size]
                            cropped_gt_img = gt_img[crop_size:-crop_size,
                                                    crop_size:-crop_size]
                        else:
                            cropped_sr_img = sr_img[crop_size:-crop_size,
                                                    crop_size:-crop_size, :]
                            cropped_gt_img = gt_img[crop_size:-crop_size,
                                                    crop_size:-crop_size, :]
                    #avg_psnr += util.psnr(cropped_sr_img, cropped_gt_img)
                    cropped_sr_img_y = bgr2ycbcr(cropped_sr_img, only_y=True)
                    cropped_gt_img_y = bgr2ycbcr(cropped_gt_img, only_y=True)
                    avg_psnr += util.psnr(
                        cropped_sr_img_y, cropped_gt_img_y)  ##########only y channel
                avg_psnr = avg_psnr / idx
                time_elapsed = time.time() - start_time
                # Save to log
                print_rlt = OrderedDict()
                print_rlt['model'] = opt['model']
                print_rlt['epoch'] = epoch
                print_rlt['iters'] = current_step
                print_rlt['time'] = time_elapsed
                print_rlt['psnr'] = avg_psnr
                logger.print_format_results('val', print_rlt)
                print('-----------------------------------')

            # update learning rate
            model.update_learning_rate()

    print('Saving the final model.')
    model.save('latest')
    print('End of training.')
img_name = os.path.splitext(os.path.basename(img_path))[0] model.test() # test visuals = model.get_current_visuals(need_HR=need_HR) sr_img = util.tensor2img(visuals['SR']) # uint8 if need_HR: # load GT image and calculate psnr gt_img = util.tensor2img(visuals['HR']) crop_border = test_loader.dataset.opt['scale'] cropped_sr_img = sr_img[crop_border:-crop_border, crop_border:-crop_border, :] cropped_gt_img = gt_img[crop_border:-crop_border, crop_border:-crop_border, :] psnr = util.psnr(cropped_sr_img, cropped_gt_img) ssim = util.ssim(cropped_sr_img, cropped_gt_img, multichannel=True) test_results['psnr'].append(psnr) test_results['ssim'].append(ssim) if gt_img.shape[2] == 3: # RGB image cropped_sr_img_y = rgb2ycbcr(cropped_sr_img, only_y=True) cropped_gt_img_y = rgb2ycbcr(cropped_gt_img, only_y=True) psnr_y = util.psnr(cropped_sr_img_y, cropped_gt_img_y) ssim_y = util.ssim(cropped_sr_img_y, cropped_gt_img_y, multichannel=False) test_results['psnr_y'].append(psnr_y) test_results['ssim_y'].append(ssim_y) print('{:20s} - PSNR: {:.4f} dB; SSIM: {:.4f}; PSNR_Y: {:.4f} dB; SSIM_Y: {:.4f}.'\ .format(img_name, psnr, ssim, psnr_y, ssim_y)) else:
def validate(val_loader, opt, model, current_step, epoch, logger):
    """Run one validation pass with 2x/4x latent codes; log average PSNR.

    Args:
        val_loader: iterable of validation batches with 'LR' and 'LR_path'.
        opt: option dict (paths, train code flags, network/scale settings).
        model: wrapped model exposing feed_data/test/get_current_visuals/save.
        current_step: current training iteration (filenames and logging).
        epoch: current epoch number (logging only).
        logger: project logger with print_format_results().
    """
    print('---------- validation -------------')
    start_time = time.time()
    avg_psnr = 0.0
    idx = 0
    # The factory (zeros / uniform / Gaussian) depends only on opt, so pick
    # it once instead of repeating the three near-identical branches per
    # image as the original did.
    if 'zero_code' in opt['train'] and opt['train']['zero_code']:
        make_code = torch.zeros
    elif 'rand_code' in opt['train'] and opt['train']['rand_code']:
        make_code = torch.rand
    else:
        make_code = torch.randn
    for val_data in val_loader:
        idx += 1
        img_name = os.path.splitext(os.path.basename(
            val_data['LR_path'][0]))[0]
        img_dir = os.path.join(opt['path']['val_images'], img_name)
        util.mkdir(img_dir)

        # Codes at 2x and 4x the LR spatial resolution, one per level.
        n = val_data['LR'].shape[0]
        c = int(opt['network_G']['in_code_nc'])
        h = val_data['LR'].shape[2]
        w = val_data['LR'].shape[3]
        code_val_0 = make_code(n, c, h * 2, w * 2)
        code_val_1 = make_code(n, c, h * 4, w * 4)

        model.feed_data(val_data, code=[code_val_0, code_val_1])
        model.test()

        visuals = model.get_current_visuals()
        sr_img = util.tensor2img(visuals['SR'])  # uint8
        gt_img = util.tensor2img(visuals['HR'])  # uint8

        # Save SR images for reference
        save_img_path = os.path.join(
            img_dir, 'caffe_{:s}_x4_{:d}.png'.format(img_name, current_step))
        util.save_img(sr_img, save_img_path)

        # calculate PSNR on the border-cropped images
        crop_size = opt['scale']
        cropped_sr_img = sr_img[crop_size:-crop_size, crop_size:-crop_size, :]
        cropped_gt_img = gt_img[crop_size:-crop_size, crop_size:-crop_size, :]
        avg_psnr += util.psnr(cropped_sr_img, cropped_gt_img)

    if current_step == 0:
        # Snapshot the untrained model once, not once per validation image.
        print('Saving the model at the end of iter {:d}.'.format(current_step))
        model.save(current_step)

    # Guard against an empty val_loader (was a ZeroDivisionError).
    avg_psnr = avg_psnr / max(idx, 1)
    time_elapsed = time.time() - start_time

    # Save to log
    print_rlt = OrderedDict()
    print_rlt['model'] = opt['model']
    print_rlt['epoch'] = epoch
    print_rlt['iters'] = current_step
    print_rlt['time'] = time_elapsed
    print_rlt['psnr'] = avg_psnr
    logger.print_format_results('val', print_rlt)
    print('-----------------------------------')
# For generating multiple samples of the same input image for run_index in range(multiple): code = model.gen_code(data['LR'].shape[0], data['LR'].shape[2], data['LR'].shape[3]) model.feed_data(data, code=code, need_HR=need_HR) model.test() img_path = data['LR_path'][0] img_name = os.path.splitext(os.path.basename(img_path))[0] visuals = model.get_current_visuals(need_HR=need_HR) sr_img = util.tensor2img(visuals['HR_pred']) # uint8 if need_HR: # load target image and calculate metric scores gt_img = util.tensor2img(visuals['HR']) psnr = util.psnr(sr_img, gt_img) ssim = util.ssim(sr_img, gt_img, multichannel=True) lpips = torch.sum(model.get_loss(level=-1)) test_results['psnr'].append(psnr) test_results['ssim'].append(ssim) test_results['lpips'].append(lpips) print('{:20s} - LPIPS: {:.4f}; PSNR: {:.4f} dB; SSIM: {:.4f}.'. format(img_name, lpips, psnr, ssim)) else: print(img_name) save_img_path = os.path.join(dataset_dir, img_name + '_%d.png' % run_index) util.save_img(sr_img, save_img_path) if need_HR: # metrics
def main():
    """Training entry point, driven by the JSON option file passed via -opt.

    Parses options, builds train/val dataloaders, then trains for
    opt['train']['niter'] iterations with periodic logging, checkpointing
    and PSNR/SSIM validation (Y-channel metrics for RGB images).
    """
    # Settings
    parser = argparse.ArgumentParser()
    parser.add_argument('-opt', type=str, required=True,
                        help='Path to option JSON file.')
    opt = option.parse(parser.parse_args().opt)  # load settings and initialize settings
    util.mkdir_and_rename(opt['path']['experiments_root'])  # rename old experiments if exists
    # Create every configured path except the experiment root (handled above)
    # and the saved-model path.
    util.mkdirs((path for key, path in opt['path'].items()
                 if not key == 'experiments_root' and \
                 not key == 'saved_model'))
    option.save(opt)
    opt = option.dict_to_nonedict(opt)  # Convert to NoneDict, which return None for missing key.

    # Redirect all writes to the "txt" file
    sys.stdout = PrintLogger(opt['path']['log'])

    # create train and val dataloader
    for phase, dataset_opt in opt['datasets'].items():
        if phase == 'train':
            train_set = create_dataset(dataset_opt)
            train_size = int(math.ceil(len(train_set) / dataset_opt['batch_size']))
            print('Number of train images: {:,d}, iters: {:,d}'.format(len(train_set), train_size))
            total_iters = int(opt['train']['niter'])
            total_epoches = int(math.ceil(total_iters / train_size))
            print('Total epoches needed: {:d} for iters {:,d}'.format(total_epoches, total_iters))
            train_loader = create_dataloader(train_set, dataset_opt)
        elif phase == 'val':
            val_set = create_dataset(dataset_opt)
            val_loader = create_dataloader(val_set, dataset_opt)
            print('Number of val images in [{:s}]: {:d}'.format(dataset_opt['name'], len(val_set)))
        else:
            raise NotImplementedError('Phase [{:s}] is not recognized.'.format(phase))
    assert train_loader is not None

    # Create model
    model = create_model(opt)
    # Create logger
    logger = Logger(opt)

    current_step = 0
    start_time = time.time()
    print('---------- Start training -------------')
    for epoch in range(total_epoches):
        for i, train_data in enumerate(train_loader):
            current_step += 1
            if current_step > total_iters:
                break
            # training
            model.feed_data(train_data)
            model.optimize_parameters(current_step)

            time_elapsed = time.time() - start_time
            start_time = time.time()

            # log
            if current_step % opt['logger']['print_freq'] == 0:
                logs = model.get_current_log()
                print_rlt = OrderedDict()
                print_rlt['model'] = opt['model']
                print_rlt['epoch'] = epoch
                print_rlt['iters'] = current_step
                print_rlt['time'] = time_elapsed
                for k, v in logs.items():
                    print_rlt[k] = v
                print_rlt['lr'] = model.get_current_learning_rate()
                logger.print_format_results('train', print_rlt)

            # save models
            if current_step % opt['logger']['save_checkpoint_freq'] == 0:
                print('Saving the model at the end of iter {:d}.'.format(current_step))
                model.save(current_step)

            # validation
            if current_step % opt['train']['val_freq'] == 0:
                print('---------- validation -------------')
                start_time = time.time()
                avg_psnr = 0.0
                avg_ssim = 0.0
                idx = 0
                for val_data in val_loader:
                    idx += 1
                    img_name = os.path.splitext(os.path.basename(val_data['GT_path'][0]))[0]
                    img_dir = os.path.join(opt['path']['val_images'], img_name)
                    util.mkdir(img_dir)

                    model.feed_data(val_data)
                    model.test()

                    visuals = model.get_current_visuals()
                    out_img = util.tensor2img(visuals['Output'])
                    gt_img = util.tensor2img(visuals['ground_truth'])  # uint8

                    # Save output images for reference
                    save_img_path = os.path.join(img_dir, '{:s}_{:d}.png'.format(\
                        img_name, current_step))
                    util.save_img(out_img, save_img_path)

                    # calculate PSNR
                    # Grayscale images get an explicit channel axis so the
                    # 3-dim border-crop slicing below always applies.
                    if len(gt_img.shape) == 2:
                        gt_img = np.expand_dims(gt_img, axis=2)
                        out_img = np.expand_dims(out_img, axis=2)
                    crop_border = opt['scale']
                    cropped_out_img = out_img[crop_border:-crop_border, crop_border:-crop_border, :]
                    cropped_gt_img = gt_img[crop_border:-crop_border, crop_border:-crop_border, :]
                    if gt_img.shape[2] == 3:  # RGB image: metrics on the Y channel
                        cropped_out_img_y = bgr2ycbcr(cropped_out_img, only_y=True)
                        cropped_gt_img_y = bgr2ycbcr(cropped_gt_img, only_y=True)
                        avg_psnr += util.psnr(cropped_out_img_y, cropped_gt_img_y)
                        avg_ssim += util.ssim(cropped_out_img_y, cropped_gt_img_y, multichannel=False)
                    else:
                        avg_psnr += util.psnr(cropped_out_img, cropped_gt_img)
                        avg_ssim += util.ssim(cropped_out_img, cropped_gt_img, multichannel=True)
                avg_psnr = avg_psnr / idx
                avg_ssim = avg_ssim / idx
                time_elapsed = time.time() - start_time
                # Save to log
                print_rlt = OrderedDict()
                print_rlt['model'] = opt['model']
                print_rlt['epoch'] = epoch
                print_rlt['iters'] = current_step
                print_rlt['time'] = time_elapsed
                print_rlt['psnr'] = avg_psnr
                print_rlt['ssim'] = avg_ssim
                logger.print_format_results('val', print_rlt)
                print('-----------------------------------')

            # update learning rate
            model.update_learning_rate()

    print('Saving the final model.')
    model.save('latest')
    print('End of training.')
def main():
    """IMLE-style training entry point ("month/day" schedule).

    Parses the -opt JSON file and builds the dataloaders.  Each "month"
    draws one large batch and (re)samples latent codes for it; each "day"
    optimizes on a slice of that batch with the matching code slice.
    """
    # options
    parser = argparse.ArgumentParser()
    parser.add_argument('-opt', type=str, required=True,
                        help='Path to option JSON file.')
    opt = option.parse(parser.parse_args().opt, is_train=True)
    util.mkdir_and_rename(
        opt['path']['experiments_root'])  # rename old experiments if exists
    util.mkdirs((path for key, path in opt['path'].items()
                 if not key == 'experiments_root' and \
                 not key == 'pretrain_model_G' and not key == 'pretrain_model_D'))
    option.save(opt)
    opt = option.dict_to_nonedict(
        opt)  # Convert to NoneDict, which return None for missing key.

    # print to file and std_out simultaneously
    sys.stdout = PrintLogger(opt['path']['log'])

    # random seed
    seed = opt['train']['manual_seed']
    if seed is None:
        seed = random.randint(1, 10000)
    print("Random Seed: ", seed)
    random.seed(seed)
    torch.manual_seed(seed)

    # create train and val dataloader
    for phase, dataset_opt in opt['datasets'].items():
        if phase == 'train':
            train_set = create_dataset(dataset_opt)
            train_size = int(
                math.ceil(len(train_set) / dataset_opt['batch_size']))
            print('Number of train images: {:,d}, iters: {:,d}'.format(
                len(train_set), train_size))
            total_iters = int(opt['train']['niter'])
            total_epoches = int(math.ceil(total_iters / train_size))
            print('Total epoches needed: {:d} for iters {:,d}'.format(
                total_epoches, total_iters))
            train_loader = create_dataloader(train_set, dataset_opt)
            # month = one big sampled batch; day = one optimization slice
            batch_size_per_month = dataset_opt['batch_size']
            batch_size_per_day = int(
                opt['datasets']['train']['batch_size_per_day'])
            num_month = int(opt['train']['num_month'])
            num_day = int(opt['train']['num_day'])
            # BUGFIX: was `use_dci = false ...` — `false` is an undefined
            # name in Python, so any config without 'use_dci' crashed with
            # a NameError here.
            use_dci = False if 'use_dci' not in opt['train'] else \
                opt['train']['use_dci']
        elif phase == 'val':
            val_dataset_opt = dataset_opt
            val_set = create_dataset(dataset_opt)
            val_loader = create_dataloader(val_set, dataset_opt)
            print('Number of val images in [{:s}]: {:d}'.format(
                dataset_opt['name'], len(val_set)))
        else:
            raise NotImplementedError(
                'Phase [{:s}] is not recognized.'.format(phase))
    assert train_loader is not None

    # Create model
    model = create_model(opt)
    # create logger
    logger = Logger(opt)

    current_step = 0
    start_time = time.time()
    print('---------- Start training -------------')
    for epoch in range(num_month):
        for i, train_data in enumerate(train_loader):
            # get the code (DCI nearest-neighbor search or plain sampling)
            if use_dci:
                cur_month_code = get_code_for_data(model, train_data, opt)
            else:
                cur_month_code = get_code(model, train_data, opt)
            for j in range(num_day):
                current_step += 1
                if current_step > total_iters:
                    break
                # get the sliced data; wrap around the month batch when the
                # day slice runs past its end
                cur_day_batch_start_idx = (
                    j * batch_size_per_day) % batch_size_per_month
                cur_day_batch_end_idx = \
                    cur_day_batch_start_idx + batch_size_per_day
                if cur_day_batch_end_idx > batch_size_per_month:
                    cur_day_batch_idx = np.hstack(
                        (np.arange(cur_day_batch_start_idx,
                                   batch_size_per_month),
                         np.arange(cur_day_batch_end_idx -
                                   batch_size_per_month)))
                else:
                    cur_day_batch_idx = slice(cur_day_batch_start_idx,
                                              cur_day_batch_end_idx)

                cur_day_train_data = {
                    'LR': train_data['LR'][cur_day_batch_idx],
                    'HR': train_data['HR'][cur_day_batch_idx]
                }
                code = cur_month_code[cur_day_batch_idx]

                # training
                model.feed_data(cur_day_train_data, code=code)
                model.optimize_parameters(current_step)

                time_elapsed = time.time() - start_time
                start_time = time.time()

                # log
                if current_step % opt['logger']['print_freq'] == 0:
                    logs = model.get_current_log()
                    print_rlt = OrderedDict()
                    print_rlt['model'] = opt['model']
                    print_rlt['epoch'] = epoch
                    print_rlt['iters'] = current_step
                    print_rlt['time'] = time_elapsed
                    for k, v in logs.items():
                        print_rlt[k] = v
                    print_rlt['lr'] = model.get_current_learning_rate()
                    logger.print_format_results('train', print_rlt)

                # save models
                if current_step % opt['logger']['save_checkpoint_freq'] == 0:
                    print('Saving the model at the end of iter {:d}.'.format(
                        current_step))
                    model.save(current_step)

                # validation
                if current_step % opt['train']['val_freq'] == 0:
                    print('---------- validation -------------')
                    start_time = time.time()
                    avg_psnr = 0.0
                    idx = 0
                    for val_data in val_loader:
                        idx += 1
                        img_name = os.path.splitext(
                            os.path.basename(val_data['LR_path'][0]))[0]
                        img_dir = os.path.join(opt['path']['val_images'],
                                               img_name)
                        util.mkdir(img_dir)

                        # validation code: zeros / uniform / Gaussian,
                        # matching the configured training code type
                        if 'zero_code' in opt['train'] and opt['train'][
                                'zero_code']:
                            code_val = torch.zeros(
                                val_data['LR'].shape[0],
                                int(opt['network_G']['in_code_nc']),
                                val_data['LR'].shape[2],
                                val_data['LR'].shape[3])
                        elif 'rand_code' in opt['train'] and opt['train'][
                                'rand_code']:
                            code_val = torch.rand(
                                val_data['LR'].shape[0],
                                int(opt['network_G']['in_code_nc']),
                                val_data['LR'].shape[2],
                                val_data['LR'].shape[3])
                        else:
                            code_val = torch.randn(
                                val_data['LR'].shape[0],
                                int(opt['network_G']['in_code_nc']),
                                val_data['LR'].shape[2],
                                val_data['LR'].shape[3])

                        model.feed_data(val_data, code=code_val)
                        model.test()

                        visuals = model.get_current_visuals()
                        sr_img = util.tensor2img(visuals['SR'])  # uint8
                        gt_img = util.tensor2img(visuals['HR'])  # uint8

                        # Save SR images for reference
                        run_index = opt['name'].split("_")[2]
                        save_img_path = os.path.join(
                            img_dir, 'srim_{:s}_{:s}_{:d}.png'.format(
                                run_index, img_name, current_step))
                        util.save_img(sr_img, save_img_path)

                        # calculate PSNR on the border-cropped images
                        crop_size = opt['scale']
                        cropped_sr_img = sr_img[crop_size:-crop_size,
                                                crop_size:-crop_size, :]
                        cropped_gt_img = gt_img[crop_size:-crop_size,
                                                crop_size:-crop_size, :]
                        avg_psnr += util.psnr(cropped_sr_img, cropped_gt_img)

                    avg_psnr = avg_psnr / idx
                    time_elapsed = time.time() - start_time
                    # Save to log
                    print_rlt = OrderedDict()
                    print_rlt['model'] = opt['model']
                    print_rlt['epoch'] = epoch
                    print_rlt['iters'] = current_step
                    print_rlt['time'] = time_elapsed
                    print_rlt['psnr'] = avg_psnr
                    logger.print_format_results('val', print_rlt)
                    print('-----------------------------------')

                # update learning rate
                model.update_learning_rate()

    print('Saving the final model.')
    model.save('latest')
    print('End of training.')