def get_batch(self, batch_size):
    """Fetch a training batch and pass it through ``utils.get_dwt_images``.

    Args:
        batch_size: Forwarded to ``self.data.get_batch``.

    Returns:
        A ``(x_dwt_imgs, y_dwt_imgs)`` tuple. The inputs are converted with
        ``utils.get_dwt_images(..., self.wavelet_image_size, self.wavelet)``
        only when the first input image is taller than
        ``self.wavelet_image_size``; otherwise they are returned unchanged.
        The targets are always converted.
    """
    x_imgs, y_imgs = self.data.get_batch(batch_size)
    # Only the height of the first input image decides whether the whole
    # input batch needs converting; width/channels are not needed here.
    height = np.shape(x_imgs[0])[0]
    if height > self.wavelet_image_size:
        x_dwt_imgs = utils.get_dwt_images(x_imgs, self.wavelet_image_size, self.wavelet)
    else:
        # Inputs already at (or below) the wavelet image size are used as-is.
        x_dwt_imgs = x_imgs
    y_dwt_imgs = utils.get_dwt_images(y_imgs, self.wavelet_image_size, self.wavelet)
    return x_dwt_imgs, y_dwt_imgs
def get_test_set(self, batch_size):
    """Fetch a test batch and pass it through ``utils.get_dwt_images``.

    Args:
        batch_size: Forwarded to ``self.data.get_test_set``.

    Returns:
        ``(None, None)`` when the data source returns ``None`` for either the
        inputs or the targets. Otherwise a ``(x_dwt_imgs, y_dwt_imgs)`` tuple:
        the inputs are converted only when the first input image is taller
        than ``self.wavelet_image_size``, the targets are always converted.
    """
    x_imgs, y_imgs = self.data.get_test_set(batch_size)
    # Use identity checks: ``arr != None`` on a NumPy array is element-wise,
    # and combining the resulting arrays with ``and`` raises ValueError.
    if x_imgs is None or y_imgs is None:
        return None, None
    height = np.shape(x_imgs[0])[0]
    if height > self.wavelet_image_size:
        x_dwt_imgs = utils.get_dwt_images(x_imgs, self.wavelet_image_size, self.wavelet)
    else:
        # Inputs already at (or below) the wavelet image size are used as-is.
        x_dwt_imgs = x_imgs
    y_dwt_imgs = utils.get_dwt_images(y_imgs, self.wavelet_image_size, self.wavelet)
    return x_dwt_imgs, y_dwt_imgs
def get_test_set(self):
    """Return the test data with inputs and targets run through ``utils.get_dwt_images``.

    Returns:
        A two-element tuple: the converted inputs followed by the converted
        targets, both taken from ``self.data.get_test_set()``.
    """
    inputs, targets = self.data.get_test_set()
    # Inputs are converted first, then targets, matching the returned order.
    return tuple(utils.get_dwt_images(imgs) for imgs in (inputs, targets))
def get_batch(self, batch_size, i):
    """Fetch a batch and run both of its halves through ``utils.get_dwt_images``.

    Args:
        batch_size: Forwarded to ``self.data.get_batch``.
        i: Forwarded to ``self.data.get_batch`` (presumably a batch index — confirm).

    Returns:
        A two-element tuple ``(converted_inputs, converted_targets)``.
    """
    source_x, source_y = self.data.get_batch(batch_size, i)
    return tuple(map(utils.get_dwt_images, (source_x, source_y)))
def main(_):
    """Run inference over every image in ``FLAGS.datadir`` and score it.

    Recreates ``FLAGS.outdir`` from scratch and restores a ``WaveletSR``
    network from ``FLAGS.reusedir``. Each image is predicted with
    ``ensem_predict`` and saved into ``FLAGS.outdir`` under its original file
    name. Per-image L1 and PSNR rows, followed by an average-PSNR row, are
    written to ``psnr.csv`` inside ``FLAGS.outdir``.

    Args:
        _: Unused positional argument (presumably the argv list handed over
            by the application runner — confirm).
    """
    # Start from an empty output directory.
    if os.path.exists(FLAGS.outdir):
        shutil.rmtree(FLAGS.outdir)
    os.mkdir(FLAGS.outdir)

    img_files = sorted(os.listdir(FLAGS.datadir))
    lr_imgs, hr_imgs, _lr_pos, _hr_pos = utils.get_image_set(
        img_files,
        input_dir=FLAGS.datadir,
        ground_truth_dir=FLAGS.groundtruth,
        hr_image_size=0,
        scale=FLAGS.scale,
        postfix_len=FLAGS.postfixlen)
    hr_norm_imgs = utils.normalize_color(hr_imgs)

    network = WaveletSR(FLAGS.layers, FLAGS.featuresize, FLAGS.scale,
                        FLAGS.waveletimgsize, FLAGS.hrimgsize, channels=3)
    network.buildModel()
    network.resume(FLAGS.reusedir, global_step=FLAGS.step)

    # Size ratio used below: LR inputs are pre-converted with
    # get_dwt_images only when it exceeds 1.
    level = FLAGS.hrimgsize // (FLAGS.scale * FLAGS.waveletimgsize)

    psnr_list = []
    # ``with`` guarantees the CSV is closed even if a prediction raises.
    with open(os.path.join(FLAGS.outdir, 'psnr.csv'), 'w') as fo:
        fo.write("seq, file, L1, PSNR\n")
        for i, img_file in enumerate(img_files):
            size = np.shape(lr_imgs[i])[0]
            size_hr = np.shape(hr_imgs[i])[0]
            # NOTE(review): target_imgs is not consumed by ensem_predict; it
            # only fed a former predict(...) call. Kept in case the helpers
            # have side effects — remove if they do not.
            target_imgs = utils.get_dwt_images(
                [hr_norm_imgs[i]],
                img_size=1 + (size_hr // (FLAGS.scale * math.pow(2, level - 1))),
                wavelet=FLAGS.wavelet)
            if level > 1:
                input_imgs = utils.get_dwt_images(
                    [lr_imgs[i]], img_size=1 + (size // level), wavelet=FLAGS.wavelet)
            else:
                input_imgs = [lr_imgs[i]]

            output = ensem_predict(input_imgs, network)
            output_img = make_same_shape(hr_imgs[i], output[0])

            # Mean absolute error between the ground truth scaled by 1/255 and
            # the raw prediction. The difference must be a single argument:
            # np.abs's second positional parameter is the ``out`` buffer.
            l1_loss = np.mean(np.abs(hr_imgs[i] / 255.0 - output_img.astype(np.float32)))
            # The value printed after "var =" is the standard deviation.
            print(
                '%dth composed image, loss = %.6f, min = %.6f, max = %.6f, mean = %.6f, var = %.6f\n'
                % (i, l1_loss, np.min(output_img), np.max(output_img),
                   np.mean(output_img), math.sqrt(np.var(output_img))))

            # NOTE(review): the +0.5 presumably rounds when the image is later
            # cast to integers, but PSNR/L1 below are computed on this offset
            # float image — confirm that bias is intended.
            output_img = np.clip(output_img, 0, 1) * 255 + 0.5
            psnr = utils.psnr_np(hr_imgs[i], output_img, scale=FLAGS.scale)
            fo.write("%d, %s, %.6f, %.6f\n"
                     % (i, img_file, np.mean(np.abs(hr_imgs[i] - output_img)), psnr))
            psnr_list.append(psnr)
            tl.vis.save_image(output_img, os.path.join(FLAGS.outdir, img_file))

        fo.write("%d, Average,0, %.6f\n" % (-1, np.mean(psnr_list)))
def get_test_set(self, batch_size):
    """Fetch test data and run inputs and targets through ``utils.get_dwt_images``.

    Args:
        batch_size: Forwarded to ``self.data.get_test_set``.

    Returns:
        A two-element tuple ``(converted_inputs, converted_targets)``.
    """
    inputs, targets = self.data.get_test_set(batch_size)
    converted = []
    for imgs in (inputs, targets):
        converted.append(utils.get_dwt_images(imgs))
    return tuple(converted)