def evaluate(self, net, iteration, noise_func):
    """Denoise noisy copies of ``self.images`` with ``net`` and report mean PSNR.

    For every clean image the method:

    1. builds a noisy input with ``noise_func``;
    2. runs it through the network with ``util.infer_image``;
    3. compares the prediction with ``util.clip_to_uint8`` of the clean image.

    The PSNR (dB, peak value 255) of each image is averaged. The average is
    passed through ``autosummary`` under the tag ``'PSNR_avg_psnr'`` and
    printed.

    Each prediction is saved as ``img_<iteration>_val_<idx>_pred.png``. On
    iteration 0 the clean and noisy inputs are saved too, as ``..._orig.png``
    and ``..._noisy.png``.

    Args:
        net: network handed to ``util.infer_image``.
        iteration: current iteration number; used in the saved file names and
            to decide whether the inputs are saved.
        noise_func: callable that maps a clean image to a noisy one.

    Raises:
        ValueError: if ``self.images`` is empty (the average is undefined).
        AssertionError: if a prediction's shape differs from its reference.
    """
    num_images = len(self.images)
    if num_images == 0:
        raise ValueError("evaluate() needs at least one validation image")

    total_psnr = 0.0
    for idx, orig_img in enumerate(self.images):
        noisy_img = noise_func(orig_img)
        # infer_image runs the array through the network.
        pred255 = util.infer_image(net, noisy_img)
        orig255 = util.clip_to_uint8(orig_img)
        # Compare the whole shape, including the channel axis. A channel
        # mismatch would otherwise broadcast silently and yield a bogus PSNR.
        assert pred255.shape == orig255.shape, (pred255.shape, orig255.shape)

        diff = orig255.astype(np.float32) - pred255.astype(np.float32)
        # Mean over every element, so any channel count is handled (the old
        # code hard-coded 3 channels in the denominator).
        mse = np.mean(np.square(diff))
        # Identical images give mse == 0 and therefore an infinite PSNR
        # (numpy emits a divide-by-zero RuntimeWarning).
        total_psnr += 10.0 * np.log10((255 * 255) / mse)

        util.save_image(self.submit_config, pred255,
                        "img_{0}_val_{1}_pred.png".format(iteration, idx))
        if iteration == 0:
            util.save_image(self.submit_config, orig_img,
                            "img_{0}_val_{1}_orig.png".format(iteration, idx))
            util.save_image(self.submit_config, noisy_img,
                            "img_{0}_val_{1}_noisy.png".format(iteration, idx))

    avg_psnr = total_psnr / num_images
    print('Average PSNR: %.2f' % autosummary('PSNR_avg_psnr', avg_psnr))
def evaluate(self, net, iteration, noise_func):
    """Denoise noisy copies of ``self.images`` with ``net`` and report mean PSNR.

    For every clean image the method:

    1. builds a noisy input with ``noise_func``;
    2. runs it through the network with ``util.infer_image``;
    3. compares the prediction with ``util.clip_to_uint8`` of the clean image.

    The PSNR (dB, peak value 255) of each image is averaged. The average is
    passed through ``autosummary`` under the tag ``'PSNR_avg_psnr'`` and
    printed. This variant does not write any images to disk.

    Args:
        net: network handed to ``util.infer_image``.
        iteration: current iteration number (not used for anything else here).
        noise_func: callable that maps a clean image to a noisy one.

    Raises:
        ValueError: if ``self.images`` is empty (the average is undefined).
        AssertionError: if a prediction's shape differs from its reference.
    """
    num_images = len(self.images)
    if num_images == 0:
        raise ValueError("evaluate() needs at least one validation image")

    total_psnr = 0.0
    for orig_img in self.images:
        noisy_img = noise_func(orig_img)
        # infer_image runs the numpy array through the network.
        pred255 = util.infer_image(net, noisy_img)
        orig255 = util.clip_to_uint8(orig_img)
        # Compare the whole shape, including the channel axis. A channel
        # mismatch would otherwise broadcast silently and yield a bogus PSNR.
        assert pred255.shape == orig255.shape, (pred255.shape, orig255.shape)

        diff = orig255.astype(np.float32) - pred255.astype(np.float32)
        # Mean over every element, so any channel count is handled (the old
        # code hard-coded 3 channels in the denominator).
        mse = np.mean(np.square(diff))
        # Identical images give mse == 0 and therefore an infinite PSNR
        # (numpy emits a divide-by-zero RuntimeWarning).
        total_psnr += 10.0 * np.log10((255 * 255) / mse)

    avg_psnr = total_psnr / num_images
    print('Average PSNR: %.2f' % autosummary('PSNR_avg_psnr', avg_psnr))
def infer_image(network_snapshot: str, image: str, out_image: str) -> None:
    """Run a saved network on one image file and write the result.

    Args:
        network_snapshot: path passed to ``util.load_snapshot``.
        image: path of the input image; it is converted to RGB before use.
        out_image: path the inferred RGB image is saved to.
    """
    tfutil.init_tf(config.tf_config)
    net = util.load_snapshot(network_snapshot)

    # Use a context manager so the input file handle is closed promptly,
    # rather than lingering until the Image object is garbage-collected.
    with PIL.Image.open(image) as im:
        arr = np.array(im.convert('RGB'), dtype=np.float32)

    # [H, W, RGB] -> [RGB, H, W], rescaled from [0, 255] to [-0.5, 0.5].
    reshaped = arr.transpose([2, 0, 1]) / 255.0 - 0.5
    pred255 = util.infer_image(net, reshaped)

    t = pred255.transpose([1, 2, 0])  # [RGB, H, W] -> [H, W, RGB]
    # The path is used as given (os.path.join with one argument was a no-op).
    PIL.Image.fromarray(t, 'RGB').save(out_image)
    print('Inferred image saved in', out_image)
# Ad-hoc evaluation: run a saved network over every record of args.tf_train
# and collect one absolute-error score per image.
net = util.load_snapshot(args.network_dir + "/network_169000.pickle")

# Expected record schema: a 3-element int64 shape plus two raw byte strings.
feats = {'shape': tf.FixedLenFeature([3], tf.int64),
         'data1': tf.FixedLenFeature([], tf.string),
         'data2': tf.FixedLenFeature([], tf.string)}


def _parse_image_function(example_proto):
    """Parse one serialized example into a dict of tensors keyed as in ``feats``."""
    return tf.parse_single_example(example_proto, feats)


raw_image_dataset = tf.data.TFRecordDataset(args.tf_train)
dataset = raw_image_dataset.map(_parse_image_function)
dat = dataset.make_one_shot_iterator().get_next()

# Build the decode/reshape op once. Creating it inside the loop would add new
# nodes to the graph on every iteration, growing memory and slowing each step.
target_op = tf.reshape(tf.decode_raw(dat["data2"], tf.uint8), dat["shape"])

errs = []
while True:
    try:
        target_img = target_op.eval()
    except tf.errors.OutOfRangeError:
        break  # the one-shot iterator is exhausted
    # NOTE(review): elsewhere in this code the input to util.infer_image is
    # normalised as x / 255.0 - 0.5, but here the raw decoded uint8 record is
    # passed unchanged -- confirm this is intended.
    pred_img = util.infer_image(net, target_img)
    target_img = util.clip_to_uint8(np.mean(target_img, axis=0))
    pred_img = util.clip_to_uint8(np.mean(pred_img, axis=0))
    # Subtract in floating point. The clipped images are presumably uint8, and
    # uint8 subtraction/squaring wraps around instead of going negative.
    # sqrt(d ** 2) == |d|, so this is the sum of absolute differences.
    diff = target_img.astype(np.float64) - pred_img.astype(np.float64)
    errs.append(float(np.sum(np.abs(diff))))

# Print once, instead of re-printing the whole growing list every iteration.
print(errs)