def predict(i, input_imgs, network, target_imgs, fo):
    """Run ``network`` on a batch, print diagnostics, and compose the result.

    Prints, for the first sample of the batch only, the mean absolute error
    against ``target_imgs[0]`` and the min / max / mean / standard deviation
    of the raw network output. Then the network output is passed to
    ``utils.compose_dwt_images`` together with ``FLAGS.wavelet``.

    NOTE(review): this source defines ``predict`` more than once; if all the
    definitions live in the same module, later ones shadow this one.

    Args:
        i: Image index, used only in the printed message.
        input_imgs: Batch forwarded unchanged to ``network.predict``.
        network: Object exposing ``predict(input_imgs)``.
        target_imgs: Reference batch; only ``target_imgs[0]`` is read.
        fo: Unused. Kept so existing callers keep working. NOTE(review):
            possibly intended as a log file handle — confirm.

    Returns:
        Whatever ``utils.compose_dwt_images(out0, FLAGS.wavelet)`` returns.
    """
    out0 = network.predict(input_imgs)
    first = out0[0]
    mean_abs_err = np.mean(np.abs(target_imgs[0] - first))
    # sqrt(var) is the standard deviation; the label previously said "var".
    std = math.sqrt(np.var(first))
    print(
        '%dth output image, loss = %.6f, min = %.6f, max = %.6f, '
        'mean = %.6f, std = %.6f\n'
        % (i, mean_abs_err, np.min(first), np.max(first), np.mean(first), std))
    return utils.compose_dwt_images(out0, FLAGS.wavelet)
def ensem_predict(input_imgs, network, flip_axes=(0, 1, 2, -1), rotations=(0, 90)):
    """Predict with test-time augmentation and average the de-augmented outputs.

    For every ``(flip_axis, rotation)`` pair, the steps are:

    1. The input is transformed with ``utils.enhance_imgs``.
    2. The result is run through ``network.predict``.
    3. The prediction is composed with ``utils.compose_dwt_images`` using
       ``FLAGS.wavelet``.
    4. The composed output is mapped back with ``utils.anti_enhance_imgs``.

    The first image of each de-augmented batch is collected, and their
    element-wise mean (over the collection axis) is returned.

    Args:
        input_imgs: Batch forwarded to ``utils.enhance_imgs``.
        network: Object exposing ``predict(imgs)``.
        flip_axes: Flip settings passed to the augmentation helpers. The
            meaning of each value (e.g. ``-1``) is defined by
            ``utils.enhance_imgs`` — NOTE(review): presumably -1 means
            "no flip"; confirm. Defaults to the previously hard-coded list.
        rotations: Rotation settings passed to the augmentation helpers.
            Defaults to the previously hard-coded ``(0, 90)``.

    Returns:
        A one-element list containing the averaged image.

    Raises:
        ValueError: If ``flip_axes`` or ``rotations`` is empty (averaging
            an empty list would otherwise yield NaN with a RuntimeWarning).
    """
    # Materialize so generators are not exhausted by the nested loop below.
    flip_axes = tuple(flip_axes)
    rotations = tuple(rotations)
    if not flip_axes or not rotations:
        raise ValueError('flip_axes and rotations must both be non-empty')

    outs_list = []
    # Flip is the outer loop and rotation the inner one, matching the original order.
    for flip_axis in flip_axes:
        for rotate_rg in rotations:
            en_imgs = utils.enhance_imgs(input_imgs, rotate_rg, flip_axis)
            outs = network.predict(en_imgs)
            composed_img = utils.compose_dwt_images(outs, FLAGS.wavelet)
            anti_outs = utils.anti_enhance_imgs(composed_img, rotate_rg, flip_axis)
            outs_list.append(anti_outs[0])
    return [np.mean(outs_list, axis=0)]
def predict(input_imgs, network, step):
    """Predict high- and low-frequency parts, merge them, and compose an image.

    ``network.predict`` must return exactly two items, which are unpacked as
    (high-frequency, low-frequency). They are merged with ``concat(..., step)``.
    The merged result is then wrapped in a single-element list and handed to
    ``utils.compose_dwt_images`` with ``FLAGS.wavelet``.

    Returns:
        Whatever ``utils.compose_dwt_images`` returns.
    """
    high_part, low_part = network.predict(input_imgs)
    merged = concat(high_part, low_part, step)
    return utils.compose_dwt_images([merged], FLAGS.wavelet)
def predict(input_imgs, network):
    """Run ``network`` on ``input_imgs`` and compose its output into an image.

    The raw prediction is passed unchanged to ``utils.compose_dwt_images``
    together with ``FLAGS.wavelet``.

    Returns:
        Whatever ``utils.compose_dwt_images`` returns.
    """
    subbands = network.predict(input_imgs)
    return utils.compose_dwt_images(subbands, FLAGS.wavelet)