def validate(val_loader, model, logger, epoch, current_step, val_dataset_opt):
    """Run one validation pass: save SR outputs and log the average PSNR.

    Args:
        val_loader: iterable yielding validation batches (dicts with 'LR_path').
        model: wrapped model exposing eval()/feed_data()/val()/get_current_visuals()/train().
        logger: logger exposing print_format_results().
        epoch (int): current epoch, for logging only.
        current_step (int): current iteration; embedded in saved image names.
        val_dataset_opt (dict): validation dataset options; 'scale' is read.
    """
    print('---------- validation -------------')
    val_start_time = time.time()
    model.eval()  # Change to eval mode. It is important for BN layers.

    avg_psnr = 0.0
    idx = 0
    for val_data in val_loader:
        idx += 1
        img_path = val_data['LR_path'][0]
        img_name = os.path.splitext(os.path.basename(img_path))[0]
        img_dir = os.path.join(opt['path']['val_images'], img_name)
        util.mkdir(img_dir)

        model.feed_data(val_data, volatile=True)
        # NOTE(review): sibling code paths call model.test() here -- confirm
        # model.val() is the intended inference entry point.
        model.val()

        visuals = model.get_current_visuals()
        sr_img = util.tensor2img_np(visuals['SR'])  # uint8
        gt_img = util.tensor2img_np(visuals['HR'])  # uint8

        # SR and GT sizes can disagree by a few pixels; compare on the
        # common overlapping region.
        h_min = min(sr_img.shape[0], gt_img.shape[0])
        w_min = min(sr_img.shape[1], gt_img.shape[1])
        sr_img = sr_img[0:h_min, 0:w_min, :]
        gt_img = gt_img[0:h_min, 0:w_min, :]

        # Drop a (scale + 2)-pixel border before PSNR, the usual SR protocol.
        crop_size = val_dataset_opt['scale'] + 2
        cropped_sr_img = sr_img[crop_size:-crop_size, crop_size:-crop_size, :]
        cropped_gt_img = gt_img[crop_size:-crop_size, crop_size:-crop_size, :]

        # Save SR images for reference
        save_img_path = os.path.join(img_dir, '%s_%s.png' % (img_name, current_step))
        util.save_img_np(sr_img.squeeze(), save_img_path)

        # TODO: honor val_dataset_opt['metric_mode'] == 'y' by converting both
        # crops with util.rgb2ycbcr(..., only_y=True) before computing PSNR.
        avg_psnr += util.psnr(cropped_sr_img, cropped_gt_img)

    # Guard against an empty val_loader so we never divide by zero.
    avg_psnr = avg_psnr / max(idx, 1)
    val_elapsed = time.time() - val_start_time

    # Save to log
    print_rlt = OrderedDict()
    print_rlt['model'] = opt['model']
    print_rlt['epoch'] = epoch
    print_rlt['iters'] = current_step
    print_rlt['time'] = val_elapsed
    print_rlt['psnr'] = avg_psnr
    logger.print_format_results('val', print_rlt)

    model.train()  # change back to train mode.
    print('-----------------------------------')
# Run inference over every image in the test folder and write SR results.
idx = 0
for idx, path in enumerate(glob.glob(test_img_folder + '/*'), start=1):
    base = os.path.splitext(os.path.basename(path))[0]
    print(idx, base)

    # Load, crop to a multiple of 8, and rescale pixel values to [0, 1].
    img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
    img = modcrop(img, 8) * 1.0 / 255
    if img.ndim == 2:
        # NOTE(review): a 2-D input gets a single channel here, yet the
        # BGR->RGB index [2, 1, 0] below assumes three channels -- confirm
        # whether grayscale inputs actually occur.
        img = np.expand_dims(img, axis=2)
    # HWC BGR -> CHW RGB float tensor.
    img = torch.from_numpy(np.transpose(img[:, :, [2, 1, 0]], (2, 0, 1))).float()

    # MATLAB-compatible x4 downscale, then add the batch dimension.
    img_LR = imresize(img, 1 / 4, antialiasing=True).unsqueeze(0).cuda()

    # Load the precomputed segmentation probability map for this image.
    seg = torch.load(os.path.join(test_prob_path, base + '_bic.pth')).unsqueeze(0)
    # change probability
    # seg.fill_(0)
    # seg[:,5].fill_(1)
    seg = seg.cuda()

    result = util.tensor2img_np(model((img_LR, seg)).data.squeeze())
    util.save_img_np(result, os.path.join(save_result_path, base + '_rlt.png'))
test_results['psnr_y'] = []
test_results['ssim_y'] = []
for data in test_loader:
    # HR ground truth is optional; metrics are skipped when no HR root is set.
    need_HR = test_loader.dataset.opt['dataroot_HR'] is not None
    model.feed_data(data, volatile=True, need_HR=need_HR)

    img_path = data['LR_path'][0]
    img_name = os.path.splitext(os.path.basename(img_path))[0]

    model.test()
    visuals = model.get_current_visuals(need_HR=need_HR)
    sr_img = util.tensor2img_np(visuals['SR'])  # uint8

    if need_HR:
        # Compare on the common region, then drop a (scale + 2)-pixel border
        # before PSNR/SSIM -- the standard SR evaluation protocol.
        gt_img = util.tensor2img_np(visuals['HR'])
        h_min = min(sr_img.shape[0], gt_img.shape[0])
        w_min = min(sr_img.shape[1], gt_img.shape[1])
        sr_img = sr_img[:h_min, :w_min, :]
        gt_img = gt_img[:h_min, :w_min, :]
        scale = test_loader.dataset.opt['scale']
        crop_border = scale + 2
        cropped_sr_img = sr_img[crop_border:-crop_border, crop_border:-crop_border, :]
        cropped_gt_img = gt_img[crop_border:-crop_border, crop_border:-crop_border, :]
        psnr = util.psnr(cropped_sr_img, cropped_gt_img)
        ssim = util.ssim(cropped_sr_img, cropped_gt_img, multichannel=True)
def main():
    """Train an SR model end to end.

    Parses the option JSON given via -opt, prepares experiment directories,
    seeds RNGs, builds dataloaders/model/logger, then runs the training loop
    with periodic logging, checkpointing and validation until
    opt['train']['niter'] iterations are done.
    """
    # options
    parser = argparse.ArgumentParser()
    parser.add_argument('-opt', type=str, required=True, help='Path to option JSON file.')
    opt = option.parse(parser.parse_args().opt, is_train=True)

    util.mkdir_and_rename(opt['path']['experiments_root'])  # rename old experiments if exists
    util.mkdirs((path for key, path in opt['path'].items()
                 if not key == 'experiments_root' and not key == 'pretrain_model_G'
                 and not key == 'pretrain_model_D'))
    option.save(opt)
    opt = option.dict_to_nonedict(opt)  # Convert to NoneDict, which return None for missing key.

    # print to file and std_out simultaneously
    sys.stdout = PrintLogger(opt['path']['log'])

    # random seed
    seed = opt['train']['manual_seed']
    if seed is None:
        seed = random.randint(1, 10000)
    print("Random Seed: ", seed)
    random.seed(seed)
    torch.manual_seed(seed)

    # create train and val dataloader
    for phase, dataset_opt in opt['datasets'].items():
        if phase == 'train':
            train_set = create_dataset(dataset_opt)
            train_size = int(math.ceil(len(train_set) / dataset_opt['batch_size']))
            print('Number of train images: {:,d}, iters: {:,d}'.format(len(train_set), train_size))
            total_iters = int(opt['train']['niter'])
            total_epoches = int(math.ceil(total_iters / train_size))
            print('Total epoches needed: {:d} for iters {:,d}'.format(total_epoches, total_iters))
            train_loader = create_dataloader(train_set, dataset_opt)
        elif phase == 'val':
            val_dataset_opt = dataset_opt
            val_set = create_dataset(dataset_opt)
            val_loader = create_dataloader(val_set, dataset_opt)
            print('Number of val images in [%s]: %d' % (dataset_opt['name'], len(val_set)))
        else:
            raise NotImplementedError("Phase [%s] is not recognized." % phase)
    assert train_loader is not None

    # Create model
    model = create_model(opt)
    # create logger
    logger = Logger(opt)

    current_step = 0
    start_time = time.time()  # overall wall-clock start; never reset, so the
                              # final 'Time taken' report covers all training.
    last_time = start_time    # per-iteration timer (was aliased to start_time,
                              # which corrupted the final total-time report).
    print('---------- Start training -------------')
    for epoch in range(total_epoches):
        for i, train_data in enumerate(train_loader):
            current_step += 1
            if current_step > total_iters:
                break

            # training
            model.feed_data(train_data)
            model.optimize_parameters(current_step)
            now = time.time()
            time_elapsed = now - last_time
            last_time = now

            # log
            if current_step % opt['logger']['print_freq'] == 0:
                logs = model.get_current_log()
                print_rlt = OrderedDict()
                print_rlt['model'] = opt['model']
                print_rlt['epoch'] = epoch
                print_rlt['iters'] = current_step
                print_rlt['time'] = time_elapsed
                for k, v in logs.items():
                    print_rlt[k] = v
                print_rlt['lr'] = model.get_current_learning_rate()
                logger.print_format_results('train', print_rlt)

            # save models
            if current_step % opt['logger']['save_checkpoint_freq'] == 0:
                print('Saving the model at the end of iter %d' % (current_step))
                model.save(current_step)

            # validation
            if current_step % opt['train']['val_freq'] == 0:
                print('---------- validation -------------')
                val_start_time = time.time()
                avg_psnr = 0.0
                idx = 0
                for val_data in val_loader:
                    idx += 1
                    img_name = os.path.splitext(os.path.basename(val_data['LR_path'][0]))[0]
                    img_dir = os.path.join(opt['path']['val_images'], img_name)
                    util.mkdir(img_dir)

                    model.feed_data(val_data)
                    model.test()

                    visuals = model.get_current_visuals()
                    sr_img = util.tensor2img_np(visuals['SR'])  # uint8
                    gt_img = util.tensor2img_np(visuals['HR'])  # uint8

                    # Save SR images for reference
                    save_img_path = os.path.join(img_dir, '%s_%s.png' % (img_name, current_step))
                    util.save_img_np(sr_img.squeeze(), save_img_path)

                    # calculate PSNR on border-cropped images (SR convention).
                    crop_size = opt['scale'] + 2
                    cropped_sr_img = sr_img[crop_size:-crop_size, crop_size:-crop_size, :]
                    cropped_gt_img = gt_img[crop_size:-crop_size, crop_size:-crop_size, :]
                    avg_psnr += util.psnr(cropped_sr_img, cropped_gt_img)

                # Guard against an empty val_loader (avoids ZeroDivisionError).
                avg_psnr = avg_psnr / max(idx, 1)
                val_elapsed = time.time() - val_start_time

                # Save to log
                print_rlt = OrderedDict()
                print_rlt['model'] = opt['model']
                print_rlt['epoch'] = epoch
                print_rlt['iters'] = current_step
                print_rlt['time'] = val_elapsed
                print_rlt['psnr'] = avg_psnr
                logger.print_format_results('val', print_rlt)
                print('-----------------------------------')

                # Exclude validation time from the next iteration's timing.
                last_time = time.time()

            # update learning rate
            model.update_learning_rate()

        # The inner break only leaves the dataloader loop; stop the epoch
        # loop too once the iteration budget is exhausted.
        if current_step > total_iters:
            break

    print('Saving the final model.')
    model.save('latest')
    print('End of Training \t Time taken: %d sec' % (time.time() - start_time))