def cal_dpsnr_dssim(raw_frame, cmp_frame, enhanced_t):
    """Return the PSNR gain (delta-PSNR) of the enhanced frame over the
    compressed frame, both measured against the raw reference frame.

    Parameters
    ----------
    raw_frame : np.ndarray
        Ground-truth frame; values expected in [0, 1] (data_range=1.0).
    cmp_frame : np.ndarray
        Compressed/degraded frame; same range as raw_frame.
    enhanced_t : torch.Tensor
        Enhanced frame with a leading batch dim of 1 (squeezed here);
        detached and moved to CPU before comparison.

    Returns
    -------
    float
        PSNR(enhanced, raw) - PSNR(compressed, raw). Positive means the
        enhancement improved quality.
    """
    # Dead commented-out SSIM computation removed; the function only ever
    # returned dpsnr (the name is kept for call-site compatibility).
    enhanced = torch.squeeze(enhanced_t, 0).detach().cpu().numpy()
    dpsnr = (utils.cal_psnr(enhanced, raw_frame, data_range=1.0)
             - utils.cal_psnr(cmp_frame, raw_frame, data_range=1.0))
    return dpsnr
def test(self, sess, test_files, ckpt_dir, save_dir): """ Test MAP denoising Parameters ---------- sess - Tensorflow session test_files - list of filenames of images to test ckpt_dir - checkpoint directory containing the pretrained model save_dir - directory into which the noisy and estimate images will be saved Returns ------- """ # init variables tf.initialize_all_variables().run() assert len(test_files) != 0, 'No testing data!' load_model_status, _ = self.load(sess, ckpt_dir) assert load_model_status == True, '[!] Load weights FAILED...' print(" [*] Load weights SUCCESS...") psnr_sum = 0 print("[*] " + 'noise variance: ' + str(self.stddev**2) + " start testing...") for idx in range(len(test_files)): if self.channels == 1: clean_image = load_images(test_files[idx]).astype( np.float32) / 255.0 else: clean_image = load_images_rgb(test_files[idx]).astype( np.float32) / 255.0 output_clean_image, noisy_image = sess.run([self.Dv, self.v], feed_dict={ self.v_ph: clean_image, self.is_training_ph: False }) groundtruth = np.clip(255 * clean_image, 0, 255).astype('uint8') noisyimage = np.clip(255 * noisy_image, 0, 255).astype('uint8') outputimage = np.clip(255 * output_clean_image, 0, 255).astype('uint8') # calculate PSNR psnr = cal_psnr(groundtruth, outputimage) print("img%d PSNR: %.2f" % (idx, psnr)) psnr_sum += psnr save_images(path.join(save_dir, 'noisy%d.png' % idx), noisyimage) save_images(path.join(save_dir, 'denoised%d.png' % idx), outputimage) avg_psnr = psnr_sum / len(test_files) print("--- Average PSNR %.2f ---" % avg_psnr)
def validation(img, name, save_imgs=False, save_dir=None):
    """Run one image through the downscale/upscale pipeline and score it.

    Uses the module-level ``kernel_generation_net``, ``downsampler_net``
    and ``upscale_net`` (all switched to eval mode). Computes PSNR on the
    RGB image and SSIM on the Y channel, both with a ``SCALE``-pixel
    border cropped off.

    Parameters
    ----------
    img : torch.Tensor - input batch; first image is used for metrics.
        (assumed NCHW with values in [0, 1] — the code multiplies by 255
        and transposes (0, 2, 3, 1); confirm against the caller.)
    name : str - base filename used when saving images.
    save_imgs : bool - when True (and save_dir given) saves orig/down/recon PNGs.
    save_dir : str or None - destination directory for saved images.

    Returns
    -------
    (psnr, ssim) tuple for the first image of the batch.
    """
    kernel_generation_net.eval()
    downsampler_net.eval()
    upscale_net.eval()
    kernels, offsets_h, offsets_v = kernel_generation_net(img)
    downscaled_img = downsampler_net(img, kernels, offsets_h, offsets_v,
                                     OFFSET_UNIT)
    # Simulate 8-bit storage of the downscaled image: clamp, quantize to
    # integers, then feed the re-normalized result to the upscaler.
    downscaled_img = torch.clamp(downscaled_img, 0, 1)
    downscaled_img = torch.round(downscaled_img * 255)
    reconstructed_img = upscale_net(downscaled_img / 255.0)
    # Move everything to uint8 HWC numpy arrays for metrics/saving.
    img = img * 255
    img = img.data.cpu().numpy().transpose(0, 2, 3, 1)
    img = np.uint8(img)
    reconstructed_img = torch.clamp(reconstructed_img, 0, 1) * 255
    reconstructed_img = reconstructed_img.data.cpu().numpy().transpose(
        0, 2, 3, 1)
    reconstructed_img = np.uint8(reconstructed_img)
    downscaled_img = downscaled_img.data.cpu().numpy().transpose(0, 2, 3, 1)
    downscaled_img = np.uint8(downscaled_img)
    # Only the first image of the batch is evaluated.
    orig_img = img[0, ...].squeeze()
    downscaled_img = downscaled_img[0, ...].squeeze()
    recon_img = reconstructed_img[0, ...].squeeze()
    if save_imgs and save_dir:
        img = Image.fromarray(orig_img)
        img.save(os.path.join(save_dir, name + '_orig.png'))
        img = Image.fromarray(downscaled_img)
        img.save(os.path.join(save_dir, name + '_down.png'))
        img = Image.fromarray(recon_img)
        img.save(os.path.join(save_dir, name + '_recon.png'))
    # PSNR on RGB with a SCALE-pixel border shaved off each side.
    psnr = utils.cal_psnr(orig_img[SCALE:-SCALE, SCALE:-SCALE, ...],
                          recon_img[SCALE:-SCALE, SCALE:-SCALE, ...],
                          benchmark=BENCHMARK)
    # SSIM on the luma (Y) channel only, same border crop.
    orig_img_y = rgb2ycbcr(orig_img)[:, :, 0]
    recon_img_y = rgb2ycbcr(recon_img)[:, :, 0]
    orig_img_y = orig_img_y[SCALE:-SCALE, SCALE:-SCALE, ...]
    recon_img_y = recon_img_y[SCALE:-SCALE, SCALE:-SCALE, ...]
    ssim = utils.calc_ssim(recon_img_y, orig_img_y)
    return psnr, ssim
def evaluate(self, sess, iter_num, test_data, sample_dir, summary_writer): """ Evaluate denoising Parameters ---------- sess - Tensorfow session iter_num - Iteration number test_data - list of array of different size, 4-D, pixel value range is 0-255 sample_dir - evalutation dataset folder name (found in ./data) summary_writer - Tensorflow SummaryWriter Returns ------- """ # assert test_data value range is 0-255 print("[*] Evaluating...") psnr_sum = 0 for idx in range(len(test_data)): clean_image = test_data[idx].astype(np.float32) / 255.0 output_clean_image, noisy_image, psnr_summary = sess.run( [self.Dv, self.v, self.summary_psnr], feed_dict={ self.v_ph: clean_image, self.is_training_ph: False }) summary_writer.add_summary(psnr_summary, iter_num) groundtruth = np.clip(test_data[idx], 0, 255).astype('uint8') noisyimage = np.clip(255 * noisy_image, 0, 255).astype('uint8') outputimage = np.clip(255 * output_clean_image, 0, 255).astype('uint8') # calculate PSNR psnr = cal_psnr(groundtruth, outputimage) # print("img%d PSNR: %.2f" % (idx + 1, psnr)) psnr_sum += psnr save_images( path.join(sample_dir, 'test%d_%d.png' % (idx + 1, iter_num)), groundtruth, noisyimage, outputimage) avg_psnr = psnr_sum / len(test_data) print("--- Test ---- Average PSNR %.2f ---" % avg_psnr)
def evaluate(self, iter_num, evaln_data, evalc_data, sample_dir, summary_merged, summary_writer): """ -i- evaln_data : list, of 4D array of different size. Each array is a noisy image for evaluation, value range 0-255. -i- evalc_data : list, of 4D array of different size. Each array is a clean image for evaluation, value range 0-255. """ # assert eval_data value range is 0-255 print("[*] Evaluating...") psnr_sum = 0 for idx in range(len(evaln_data)): noisy_image = evaln_data[idx].astype(np.float32) / 255.0 clean_image = evalc_data[idx].astype(np.float32) / 255.0 output_image, psnr_summary = self.sess.run( [self.Y, summary_merged], feed_dict={ self.X: noisy_image, self.Y_: clean_image, self.is_training: False }) summary_writer.add_summary(psnr_summary, iter_num) groundtruth = np.clip(evalc_data[idx], 0, 255).astype('uint8') noisy_img = np.clip(evaln_data[idx], 0, 255).astype('uint8') output_img = np.clip(255 * output_image, 0, 255).astype('uint8') # calculate PSNR psnr = cal_psnr(groundtruth, output_img) print("img%d PSNR: %.2f" % (idx + 1, psnr)) psnr_sum += psnr filename = 'test%d_%d.png' % (idx + 1, iter_num) filename = os.path.join(sample_dir, filename) save_images(filename, groundtruth, noisy_img, output_img) avg_psnr = psnr_sum / len(evaln_data) print("--- Test ---- Average PSNR %.2f ---" % avg_psnr)
def _decompose_channels(img, in_button, n_imgrow, n_imgcol, window_len):
    """Split a 2-D image into `in_button` frequency channels.

    in_button == 1: the image itself as a single channel.
    in_button == 2: [mean-filtered (smooth), residue = img - smooth].
    in_button == 3: [twice-filtered, once - twice, img - once].
    Returns an (n_imgrow, n_imgcol, in_button) array.
    """
    if in_button == 1:
        return img.reshape([n_imgrow, n_imgcol, 1])
    if in_button == 2:
        smooth = mean_filter(img, kernelsize=window_len)  # low frequency
        return np.stack([smooth, img - smooth], axis=2)
    if in_button == 3:
        once = mean_filter(img, kernelsize=window_len)
        # NOTE: the inner mean_filter keeps its default kernel size, exactly
        # as in the original implementation.
        twice = mean_filter(mean_filter(img), kernelsize=window_len)
        return np.stack([twice, once - twice, img - once], axis=2)
    raise ValueError("in_button must be 1, 2 or 3, got %r" % (in_button,))


def experiment(data, label, n_sample=200, n_test=40, n_imgrow=300,
               n_imgcol=300, shuffle_button=3, in_button=3, window_len=7):
    """Run one denoising experiment: shuffle/split the data, build
    frequency-channel inputs, train an FRCNN model, and report PSNR.

    Parameters
    ----------
    data, label : arrays of shape (n_sample, n_imgrow, n_imgcol);
        the training target is the difference data - label (noise pattern).
    n_sample / n_test : total samples and number held out for testing.
    n_imgrow, n_imgcol : image dimensions.
    shuffle_button : 1 = shuffle everything, 2 = no shuffle,
        3 = shuffle only the training/validation part.
    in_button : number of frequency channels (1, 2 or 3).
    window_len : mean-filter kernel size.

    Returns
    -------
    (ori_psnr, dnd_psnr) : per-test-image PSNR arrays before and after
        denoising. Results are also saved as .mat files under ./tobedown.
    """
    # ---- choose the sample ordering -------------------------------------
    if shuffle_button == 1:
        order = nr.permutation(n_sample)
        print("[*] shuffle the whole dataset")
    elif shuffle_button == 2:
        order = range(n_sample)
        print("[*] do not shuffle")
    elif shuffle_button == 3:
        order = np.concatenate((nr.permutation(n_sample - n_test),
                                range(n_sample - n_test, n_sample)), axis=0)
        print("[*] shuffle the training and validation only")
    else:
        # BUG FIX: the original printed a warning and then crashed with a
        # NameError on the undefined `order`; fail fast with a clear error.
        raise ValueError("[*] shuffle button not confirmed: %r"
                         % (shuffle_button,))
    shuffledata = data[order, :, :]
    shufflelabel = label[order, :, :]

    # ---- split into training and test sets ------------------------------
    in_data = shuffledata[0:(n_sample - n_test)]
    in_label = shufflelabel[0:(n_sample - n_test)]
    t_data = shuffledata[(n_sample - n_test):n_sample, :]
    t_label = shufflelabel[(n_sample - n_test):n_sample, :]

    # ---- build channel-decomposed training tensors ----------------------
    train_data = np.zeros([len(in_data), n_imgrow, n_imgcol, in_button])
    train_label = np.zeros([len(in_label), n_imgrow, n_imgcol, in_button])
    t0 = time.time()
    for i in range(len(in_data)):
        train_data[i] = _decompose_channels(in_data[i], in_button,
                                            n_imgrow, n_imgcol, window_len)
        # The network learns the noise pattern: target = noisy - clean.
        train_label[i] = _decompose_channels(in_data[i] - in_label[i],
                                             in_button, n_imgrow, n_imgcol,
                                             window_len)
    print("[*] train data ready")
    t1 = time.time()
    print("Total time running: %s seconds" % (str(t1 - t0)))

    # ---- build channel-decomposed test tensors (same scheme) ------------
    t0 = time.time()
    test_data = np.zeros([len(t_data), n_imgrow, n_imgcol, in_button])
    test_label = np.zeros([len(t_label), n_imgrow, n_imgcol, in_button])
    for i in range(len(t_data)):
        test_data[i] = _decompose_channels(t_data[i], in_button,
                                           n_imgrow, n_imgcol, window_len)
        test_label[i] = _decompose_channels(t_data[i] - t_label[i],
                                            in_button, n_imgrow, n_imgcol,
                                            window_len)
    print("[*] test data ready")
    t1 = time.time()
    print("Total time running: %s seconds" % (str(t1 - t0)))

    # ---- train and run the model ----------------------------------------
    CNNclass = FRCNN_model(image_size=[n_imgrow, n_imgcol],
                           in_channel=in_button)
    model = CNNclass.build_model()
    model, hist = CNNclass.train_model(model, train_data, train_label)
    denoised = CNNclass.test_model(model, test_data)

    # BUG FIX: the original passed a list to file.write (TypeError) and
    # referenced `output.close` without calling it, leaking the handle.
    with open('log.txt', 'w+') as output:
        output.write(str(hist.history['loss']))

    # ---- per-image PSNR before and after denoising ----------------------
    ori_psnr = np.zeros([n_test, 1])
    dnd_psnr = np.zeros([n_test, 1])
    for i in range(n_test):
        noisy_img = t_data[i]
        # Summing channels reassembles the full-band denoised image.
        denoised_img = (denoised[i, :, :, :].sum(axis=2)).reshape(
            [n_imgrow, n_imgcol])
        real_img = (t_data[i] - t_label[i]).reshape([n_imgrow, n_imgcol])
        ori_psnr[i] = cal_psnr(real_img, noisy_img)
        dnd_psnr[i] = cal_psnr(real_img, denoised_img)
    print("---> original PSNR is %.4f dB" % np.mean(ori_psnr))
    print("---> denoised PSNR is %.4f dB" % np.mean(dnd_psnr))
    # NOTE: 'w+' truncates, so this overwrites the loss history written
    # above — preserved from the original behavior.
    with open('log.txt', 'w+') as output:
        output.write("---> original PSNR is %.4f dB\n" % np.mean(ori_psnr))
        output.write("---> denoised PSNR is %.4f dB\n" % np.mean(dnd_psnr))

    # ---- save experiment results as .mat files --------------------------
    savingpath = ('./tobedown/tobedown_in' + str(in_button) + '_winlen'
                  + str(window_len))
    if not os.path.exists(savingpath):
        os.makedirs(savingpath)
    postfix = str(in_button) + '_' + str(window_len)
    sio.savemat(os.path.join(savingpath, 'denoised' + postfix + '.mat'),
                {'denoised' + postfix:
                 denoised.sum(axis=3).reshape([n_test, n_imgrow, n_imgcol])})
    sio.savemat(os.path.join(savingpath, 'denoised_ch1' + postfix + '.mat'),
                {'denoised1' + postfix:
                 denoised[:, :, :, 0].reshape([n_test, n_imgrow, n_imgcol])})
    if in_button > 1:
        sio.savemat(os.path.join(savingpath, 'denoised_ch2' + postfix + '.mat'),
                    {'denoised2' + postfix:
                     denoised[:, :, :, 1].reshape([n_test, n_imgrow, n_imgcol])})
    if in_button > 2:
        sio.savemat(os.path.join(savingpath, 'denoised_ch3' + postfix + '.mat'),
                    {'denoised3' + postfix:
                     denoised[:, :, :, 2].reshape([n_test, n_imgrow, n_imgcol])})
    sio.savemat(os.path.join(savingpath, 'noisy' + postfix + '.mat'),
                {'noisy' + postfix: t_data})
    sio.savemat(os.path.join(savingpath, 'real' + postfix + '.mat'),
                {'real' + postfix:
                 (t_data - t_label).reshape([n_test, n_imgrow, n_imgcol])})
    sio.savemat(os.path.join(savingpath, 'ori_psnr' + postfix + '.mat'),
                {'ori_psnr' + postfix: ori_psnr})
    sio.savemat(os.path.join(savingpath, 'dnd_psnr' + postfix + '.mat'),
                {'dnd_psnr' + postfix: dnd_psnr})
    return ori_psnr, dnd_psnr
def val_loop(stsr, val_loader, val_dataset, epoch):
    """Run a full validation pass and write averaged PSNR/SSIM metrics
    for every network output to vimeo_record.txt in `save_dir`.

    Parameters
    ----------
    stsr : the space-time super-resolution network (switched to eval mode).
    val_loader : iterable yielding validation batches.
    val_dataset : dataset backing val_loader (used only for its length).
    epoch : unused; kept for call-site compatibility.
    """
    avg_PSNR_TS = 0
    avg_PSNR_ST = 0
    avg_PSNR_MERGE = 0
    avg_PSNR_RESIDUAL = 0
    avg_PSNR_HR = 0
    avg_PSNR_LR = 0
    avg_SSIM_TS = 0
    avg_SSIM_ST = 0
    avg_SSIM_MERGE = 0
    avg_SSIM_RESIDUAL = 0
    avg_SSIM_HR = 0
    avg_SSIM_LR = 0
    stsr.eval()
    with torch.no_grad():
        for vid, val_data in enumerate(val_loader):
            if args.forward_MsMt:
                HR = val_data['HR'].to(device)
                # Synthesize the LR triplet by downsampling the HR frames.
                LR = torch.stack(
                    [nn_down(HR[:, 0]), nn_down(HR[:, 1]), nn_down(HR[:, 2])],
                    dim=1)
                LR = LR.clamp(0, 1).detach()
                GT = HR[:, 1]
                I_L_2, I_H_1, I_H_3, I_TS_2, I_ST_2, I_F_2, mask_1, mask_2, \
                    I_R_basic, I_R_2 = stsr(LR[:, 0], LR[:, 2])
            else:
                ST = val_data['ST'].to(device)
                TS = val_data['TS'].to(device)
                GT = val_data['GT'].to(device)
                I_L_2, I_H_1, I_H_3, I_TS_2, I_ST_2, I_F_2, mask_1, mask_2, \
                    I_R_basic, I_R_2 = stsr(ST, TS)
            B, C, H, W = GT.size()
            for b_id in range(B):
                # BUG FIX: compare center-frame outputs against GT (which
                # equals HR[:, 1] in the forward_MsMt branch) so the other
                # branch no longer hits a NameError on the undefined HR/LR.
                avg_PSNR_TS += utils.cal_psnr(I_TS_2[b_id], GT[b_id]).item()
                avg_PSNR_ST += utils.cal_psnr(I_ST_2[b_id], GT[b_id]).item()
                avg_PSNR_MERGE += utils.cal_psnr(I_F_2[b_id], GT[b_id]).item()
                avg_PSNR_RESIDUAL += utils.cal_psnr(I_R_2[b_id],
                                                    GT[b_id]).item()
                avg_SSIM_TS += utils.cal_ssim(I_TS_2[b_id], GT[b_id])
                avg_SSIM_ST += utils.cal_ssim(I_ST_2[b_id], GT[b_id])
                avg_SSIM_MERGE += utils.cal_ssim(I_F_2[b_id], GT[b_id])
                avg_SSIM_RESIDUAL += utils.cal_ssim(I_R_2[b_id], GT[b_id])
                if args.forward_MsMt:
                    # The HR/LR reference frames only exist in this branch.
                    avg_PSNR_HR += (utils.cal_psnr(I_H_1[b_id],
                                                   HR[b_id, 0]).item()
                                    + utils.cal_psnr(I_H_3[b_id],
                                                     HR[b_id, 2]).item())
                    avg_PSNR_LR += utils.cal_psnr(I_L_2[b_id],
                                                  LR[b_id, 1]).item()
                    avg_SSIM_HR += (utils.cal_ssim(I_H_1[b_id], HR[b_id, 0])
                                    + utils.cal_ssim(I_H_3[b_id], HR[b_id, 2]))
                    avg_SSIM_LR += utils.cal_ssim(I_L_2[b_id], LR[b_id, 1])
    # BUG FIX: context manager guarantees the record file is closed even if
    # a formatting error occurs.
    with open(os.path.join(save_dir, 'vimeo_record.txt'), 'w') as f:
        print('PSNR_HR: {}'.format(avg_PSNR_HR / len(val_dataset) / 2), file=f)
        print('PSNR_LR: {}'.format(avg_PSNR_LR / len(val_dataset)), file=f)
        print('PSNR_TS: {}'.format(avg_PSNR_TS / len(val_dataset)), file=f)
        print('PSNR_ST: {}'.format(avg_PSNR_ST / len(val_dataset)), file=f)
        print('PSNR_MERGE: {}'.format(avg_PSNR_MERGE / len(val_dataset)),
              file=f)
        print('PSNR_REFINE: {}'.format(avg_PSNR_RESIDUAL / len(val_dataset)),
              file=f)
        print('SSIM_HR: {}'.format(avg_SSIM_HR / len(val_dataset) / 2), file=f)
        print('SSIM_LR: {}'.format(avg_SSIM_LR / len(val_dataset)), file=f)
        print('SSIM_TS: {}'.format(avg_SSIM_TS / len(val_dataset)), file=f)
        print('SSIM_ST: {}'.format(avg_SSIM_ST / len(val_dataset)), file=f)
        print('SSIM_MERGE: {}'.format(avg_SSIM_MERGE / len(val_dataset)),
              file=f)
        print('SSIM_REFINE: {}'.format(avg_SSIM_RESIDUAL / len(val_dataset)),
              file=f)
avg_err, avg_psnr = 0, 0 acc_rec = 0 acc_f_diff = 0 start_time = time.time() for z, data in enumerate(tqdm(trainloader)): ori_v = torch.autograd.Variable(data['ori'], requires_grad=False).cuda() de_v = torch.autograd.Variable(data['de'], requires_grad=False).cuda() residual = ori_v - de_v reconstruction, features = featExNets(residual) if epoch == 0: ori_psnr += utils.cal_psnr(ori_v.cpu().data.numpy(), de_v.cpu().data.numpy(), data_range=1.0).item() # epoch 0 to 4 we use real residual patterns to train upSamplingNets and refineNets # epoch 5 to - we use approximated residual patterns to train upSamplingNets and refineNets if epoch >= 5: c = 1 # weight for loss pick = [] pre_kmmodel = utils.load_obj(opt.logging_root + '/kmeans/kmmodel_%d' % (epoch - 1)) pre_centerPatch = utils.load_obj(opt.logging_root + '/kmeans/centerPatch_%d' % (epoch - 1))
def test(model):
    """Evaluate the multi-exit enhancement model over all test QPs/QFs.

    For each quality level in ``order_QPorQF``, runs every test frame
    through all five model exits, measures the delta-PSNR of the exit
    matched to that quality level, logs per-frame and per-QP results to
    the module-level ``fp_each`` / ``fp_ave`` files, and prints the
    5-QP average at the end. Frames detected as flat by ``isplane`` are
    skipped and excluded from the averages.
    """
    with torch.no_grad():
        raw_path = os.path.join(dir_test, "RAISE_raw_" + suffix_data_path)
        dpsnr_sum_5QP = 0.0
        for QPorQF in order_QPorQF:  # test order, not the output order
            cmp_path = os.path.join(
                dir_test, "RAISE_" + tab + str(QPorQF) + "_" + suffix_data_path)
            dpsnr_ave = 0.0
            time_total = 0.0
            nfs_test_final = nfs_test_used  # valid (non-plain) frame count
            for ite_frame in range(nfs_test_used):
                # Load one Y-channel frame, crop to the test window, and
                # normalize to [0, 1].
                raw_frame = utils.y_import(
                    raw_path, height_frame, width_frame, nfs=1,
                    startfrm=ite_frame).astype(
                        np.float32)[:, start_height:start_height + height_test,
                                    start_width:start_width + width_test] / 255
                cmp_frame = utils.y_import(
                    cmp_path, height_frame, width_frame, nfs=1,
                    startfrm=ite_frame).astype(
                        np.float32)[:, start_height:start_height + height_test,
                                    start_width:start_width + width_test] / 255
                if isplane(raw_frame
                           ):  # plain frame => no need to enhance => invalid
                    nfs_test_final -= 1
                    continue
                cmp_t, raw_t = torch.from_numpy(cmp_frame).to(
                    dev), torch.from_numpy(raw_frame).to(
                        dev)  # turn them to tensors and move to GPU
                cmp_t = cmp_t.view(
                    1, 1, height_test, width_test
                )  # batch_size * height * width => batch_size * channel * height * width
                # Time all five exits together; the per-QP output is chosen
                # below but every exit is evaluated for logging.
                start_time = time.time()
                enh_1 = model(cmp_t, 1)  # enhanced img from the shallowest output
                enh_2 = model(cmp_t, 2)
                enh_3 = model(cmp_t, 3)
                enh_4 = model(cmp_t, 4)
                enh_5 = model(cmp_t, 5)  # enhanced img from the deepest output
                # Map each quality level to its matching exit: the first
                # (presumably easiest) QP uses the shallowest exit, the last
                # the deepest — confirm against the training setup.
                if QPorQF == order_QPorQF[0]:
                    enhanced_cmp_t = enh_1
                elif QPorQF == order_QPorQF[1]:
                    enhanced_cmp_t = enh_2
                elif QPorQF == order_QPorQF[2]:
                    enhanced_cmp_t = enh_3
                elif QPorQF == order_QPorQF[3]:
                    enhanced_cmp_t = enh_4
                elif QPorQF == order_QPorQF[4]:
                    enhanced_cmp_t = enh_5
                time_total += time.time() - start_time
                if opt_output:  # save frame as png
                    func_output(ite_frame, QPorQF, cmp_t, out=0)
                    func_output(ite_frame, QPorQF, enh_1, out=1)
                    func_output(ite_frame, QPorQF, enh_2, out=2)
                    func_output(ite_frame, QPorQF, enh_3, out=3)
                    func_output(ite_frame, QPorQF, enh_4, out=4)
                    func_output(ite_frame, QPorQF, enh_5, out=5)
                # delta-PSNR of the exit selected for this quality level
                dpsnr = cal_dpsnr_dssim(raw_frame, cmp_frame, enhanced_cmp_t)
                print("\rframe %4d|%4d - dpsnr %.3f - %s %2d " %
                      (ite_frame + 1, nfs_test_used, dpsnr, tab, QPorQF),
                      end="",
                      flush=True)
                dpsnr_ave += dpsnr
                # delta-PSNR of every exit, logged per frame
                dp1 = cal_dpsnr_dssim(raw_frame, cmp_frame, enh_1)
                dp2 = cal_dpsnr_dssim(raw_frame, cmp_frame, enh_2)
                dp3 = cal_dpsnr_dssim(raw_frame, cmp_frame, enh_3)
                dp4 = cal_dpsnr_dssim(raw_frame, cmp_frame, enh_4)
                dp5 = cal_dpsnr_dssim(raw_frame, cmp_frame, enh_5)
                fp_each.write("frame %d - %s %2d - ori psnr: %.3f - dpsnr from o1 to o5: %.3f, %.3f, %.3f, %.3f, %.3f\n" %\
                    (ite_frame, tab, QPorQF, utils.cal_psnr(cmp_frame, raw_frame, data_range=1.0), dp1, dp2, dp3, dp4, dp5))
                fp_each.flush()
            # Per-QP averages over the valid frames only.
            dpsnr_ave = dpsnr_ave / nfs_test_final
            fps = nfs_test_final / time_total
            print(
                "\r=== dpsnr: {:.3f} - {:s} {:2d} - fps {:.1f} (no early-exit) ==="
                .format(dpsnr_ave, tab, QPorQF, fps) + 10 * " ",
                flush=True)
            fp_ave.write(
                "=== dpsnr: %.3f - %s %2d - fps %.1f (no early-exit) ===\n" %
                (dpsnr_ave, tab, QPorQF, fps))
            fp_ave.flush()
            dpsnr_sum_5QP += dpsnr_ave
        # Overall average across the five quality levels.
        print("=== dpsnr: %.3f ===" % (dpsnr_sum_5QP / 5, ), flush=True)
        fp_ave.write("=== dpsnr: %.3f ===\n" % (dpsnr_sum_5QP / 5))
        fp_ave.flush()
'/kmeans/centerPatch_%d' % i) ori_psnr, ori_ssim = 0, 0 avg_err, avg_psnr, avg_ssim = 0, 0, 0 avg_f_diff = 0 start_time = time.time() for _, data in enumerate(testloader): ori_v = torch.autograd.Variable(data['ori'], requires_grad=False).cuda() de_v = torch.autograd.Variable(data['de'], requires_grad=False).cuda() residual = ori_v - de_v ori_psnr += utils.cal_psnr(ori_v.cpu().data.numpy(), de_v.cpu().data.numpy(), data_range=1.0).item() / len(testset) ori_ssim += utils.cal_ssim( ori_v.squeeze().cpu().data.numpy().transpose(1, 2, 0), de_v.squeeze().cpu().data.numpy().transpose(1, 2, 0), data_range=1.0, multichannel=True).item() / len(testset) _, features = featExNets(residual) pick = [] patchResFeat = features.squeeze().permute( 1, 2, 0).contiguous().view( -1, features.size()[1]).cpu().detach().data.numpy() prediction = pre_kmmodel.predict(patchResFeat.astype(np.float64))
def val_loop(stsr, val_loader, val_dataset, epoch):
    """Validate `stsr` over `val_loader`, print per-branch average PSNR,
    and return the mean PSNR of the refined residual output (I_R_2).
    The `epoch` argument is accepted for call-site compatibility.
    """
    psnr_ts = psnr_st = psnr_merge = psnr_residual = 0
    psnr_hr = psnr_lr = 0
    stsr.eval()
    with torch.no_grad():
        for vid, val_data in enumerate(tqdm(val_loader)):
            if args.train_MsMt:
                # Build the LR triplet from the HR frames on the fly.
                HR = val_data['HR'].to(device)
                down = [nn_down(HR[:, k]) for k in range(3)]
                LR = torch.stack(down, dim=1).clamp(0, 1).detach()
                GT = HR[:, 1]
                outputs = stsr(LR[:, 0], LR[:, 2])
            else:
                ST = val_data['ST'].to(device)
                TS = val_data['TS'].to(device)
                GT = val_data['GT'].to(device)
                outputs = stsr(ST, TS)
            (I_L_2, I_H_1, I_H_3, I_TS_2, I_ST_2, I_F_2,
             mask_1, mask_2, I_R_basic, I_R_2) = outputs
            for b in range(GT.size(0)):
                # Accumulate PSNR of each branch against the center frame.
                psnr_merge += utils.cal_psnr(I_R_basic[b], GT[b]).item()
                psnr_residual += utils.cal_psnr(I_R_2[b], GT[b]).item()
                psnr_ts += utils.cal_psnr(I_TS_2[b], GT[b]).item()
                psnr_st += utils.cal_psnr(I_ST_2[b], GT[b]).item()
                if args.train_MsMt:
                    # Side outputs compared against their HR/LR references.
                    psnr_hr += utils.cal_psnr(I_H_1[b], HR[b, 0]).item()
                    psnr_hr += utils.cal_psnr(I_H_3[b], HR[b, 2]).item()
                    psnr_lr += utils.cal_psnr(I_L_2[b], LR[b, 1]).item()
    denom = len(val_dataset)
    log = {
        'PSNR_TS': psnr_ts / denom,
        'PSNR_ST': psnr_st / denom,
        'PSNR_MERGE': psnr_merge / denom,
        'PSNR_RESIDUAL': psnr_residual / denom
    }
    if args.train_MsMt:
        log['PSNR_HR'] = psnr_hr / denom / 2.
        log['PSNR_LR'] = psnr_lr / denom
    print(log)
    return psnr_residual / denom
def train(self):
    """Main training loop.

    Resumes from the model's persisted global step, then for each
    iteration: picks a random scale, fetches a batch, runs one optimizer
    step, and — on their respective frequencies — prints progress, logs
    loss/PSNR/lr summaries, saves a checkpoint, and runs validation for
    every scale in use.
    """
    print("Begin training...")
    whole_time = 0  # cumulative wall-clock seconds across iterations
    init_global_step = self.model.global_step.eval(self.sess)
    for _ in range(init_global_step, self.args.num_iter):
        start_time = time.time()
        # For Track 2, 3, and 4: multi-scale => multi-degradation
        if self.args.degrade:
            self.data.scale_list = self.data.degra_list
        # randomly select scale in scale_list
        idx_scale = np.random.choice(len(self.data.scale_list))
        scale = self.data.scale_list[idx_scale]
        # get batch data and scale
        train_in_imgs, train_tar_imgs = self.data.get_batch(
            batch_size=self.args.num_batch, idx_scale=idx_scale)
        # train the network
        feed_dict = {self.model.input: train_in_imgs,
                     self.model.target: train_tar_imgs,
                     self.model.flag_scale: scale}
        _, loss, lr, output, global_step = self.sess.run(
            [self.model.train_op, self.model.loss,
             self.model.learning_rate, self.model.output,
             self.model.global_step],
            feed_dict=feed_dict)
        # check the duration of each iteration
        end_time = time.time()
        duration = end_time - start_time
        whole_time += duration
        mean_duration = whole_time / (global_step - init_global_step)
        # ---- print loss and duration of training ----
        if global_step % self.args.print_freq == 0:
            # NOTE(review): '%d' truncates the float loss to an integer in
            # this message — confirm that is intentional.
            print('Loss: %d, Duration: %d / %d (%.3f sec/batch)' %
                  (loss, global_step, self.args.num_iter, mean_duration))
        # ---- log the loss, PSNR, and lr of training ----
        if global_step % self.args.log_freq == 0:
            # calculate PSNR, averaged over the batch
            psnr = 0
            for (out_img, tar_img) in zip(output, train_tar_imgs):
                psnr += cal_psnr(out_img, tar_img, scale) / self.args.num_batch
            # write summary
            summaries_dict = {}
            summaries_dict['loss'] = loss
            summaries_dict['PSNR'] = psnr
            summaries_dict['lr'] = lr
            self.logger.write(summaries_dict, global_step,
                              is_train=True, idx_scale=idx_scale)
        # ---- save the trained model ----
        if global_step % self.args.save_freq == 0:
            # save the trained model
            self.model.save(self.sess)
        # ---- log the PSNR of validation ----
        if global_step % self.args.valid_freq == 0:
            # validation for all scale used
            for idx_scale, scale in enumerate(self.data.scale_list):
                if self.args.degrade:
                    scale = 4
                valid_in_imgs = self.data.dataset[idx_scale][-self.args.num_valid:]
                # NOTE(review): this branch tests `is_degrade` while the
                # branches above test `degrade` — verify both flags exist
                # and the difference is intentional.
                if self.args.is_degrade:
                    valid_tar_imgs = self.data.dataset[
                        idx_scale + len(self.data.scale_list)][-self.args.num_valid:]
                else:
                    valid_tar_imgs = self.data.dataset[-1][-self.args.num_valid:]
                # inference validation images & calculate PSNR
                psnr = 0
                for (in_img, tar_img) in zip(valid_in_imgs, valid_tar_imgs):
                    tar_img = mod_crop(tar_img, scale)
                    out_img = chop_forward(in_img, self.sess, self.model,
                                           scale=scale, shave=10)
                    psnr += cal_psnr(out_img, tar_img, scale) / self.args.num_valid
                # write summary
                summaries_dict = {}
                summaries_dict['PSNR'] = psnr
                self.logger.write(summaries_dict, global_step,
                                  is_train=False, idx_scale=idx_scale)
    print("Training is done!")