import numpy as np
import tensorflow as tf

# `params` and the image utilities used below (trim_image, flip_image, rotate_image_*,
# reverse_*) come from the surrounding project and are assumed to be imported elsewhere.


def run_network(downscaled_image, checkpoint):
    # network for the original image, built in its own graph/session
    config = tf.ConfigProto(device_count={'GPU': 1})

    graph_or = tf.Graph()
    sess_or = tf.Session(graph=graph_or, config=config)
    with graph_or.as_default():
        input_or = tf.placeholder(tf.float32,
                                  (1, downscaled_image.shape[1], downscaled_image.shape[2], params.num_channels),
                                  name='input_or')
        _, output_or = params.network_architecture(input_or)
        saver = tf.train.Saver()
        saver.restore(sess_or, checkpoint)

    # network for the transposed image (swapped height/width); it is restored here, but this
    # function only ensembles the 0- and 180-degree orientations, so it is never run
    graph_tr = tf.Graph()
    sess_tr = tf.Session(graph=graph_tr, config=config)
    with graph_tr.as_default():
        input_tr = tf.placeholder(tf.float32,
                                  (1, downscaled_image.shape[2], downscaled_image.shape[1], params.num_channels),
                                  name='input_tr')
        _, output_tr = params.network_architecture(input_tr, reuse=False)
        saver = tf.train.Saver()
        saver.restore(sess_tr, checkpoint)

    num_images = downscaled_image.shape[0]
    cnn_output = []
    for image in downscaled_image:
        out_images = []
        # original 0
        res = trim_image(sess_or.run(output_or, {input_or: [image]})[0])
        out_images.append(res)
        # flip 0
        res = trim_image(sess_or.run(output_or, {input_or: [flip_image(image)]})[0])
        out_images.append(reverse_flip_image(res))
        # original 180
        rot180_image = rotate_image_180(image)
        res = trim_image(sess_or.run(output_or, {input_or: [rot180_image]})[0])
        out_images.append(reverse_rotate_image_180(res))
        # flip 180
        res = trim_image(sess_or.run(output_or, {input_or: [flip_image(rot180_image)]})[0])
        out_images.append(reverse_rotate_image_180(reverse_flip_image(res)))

        # `use_mean` (not defined in this excerpt) selects mean vs. median aggregation
        if use_mean:
            cnn_output.append(np.round(np.mean(np.array(out_images), axis=0)))
        else:
            cnn_output.append(np.round(np.median(np.array(out_images), axis=0)))

    cnn_output = np.array(cnn_output)
    return cnn_output
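# NOTE: the geometric helpers used by the functions in this file (flip_image, rotate_image_*,
# their reverse_* counterparts and trim_image) are not shown in this excerpt. The NumPy sketch
# below illustrates one way they could be implemented for H x W x C arrays; the actual project
# implementations may differ (in particular, trim_image is only assumed here to clip
# predictions to the valid [0, 255] pixel range).

def flip_image(image):
    # horizontal flip; applying it twice restores the original image
    return np.fliplr(image)


def reverse_flip_image(image):
    return np.fliplr(image)


def rotate_image_90(image):
    # rotate 90 degrees counter-clockwise in the H x W plane (swaps height and width)
    return np.rot90(image, k=1, axes=(0, 1))


def reverse_rotate_image_90(image):
    return np.rot90(image, k=-1, axes=(0, 1))


def rotate_image_180(image):
    # a 180-degree rotation is its own inverse
    return np.rot90(image, k=2, axes=(0, 1))


def reverse_rotate_image_180(image):
    return np.rot90(image, k=2, axes=(0, 1))


def rotate_image_270(image):
    return np.rot90(image, k=3, axes=(0, 1))


def reverse_rotate_image_270(image):
    return np.rot90(image, k=-3, axes=(0, 1))


def trim_image(image):
    # clip network predictions to the valid pixel range
    return np.clip(image, 0, 255)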
def upscale(downscaled_image, checkpoint):
    scale_factor = params.scale

    # cnn resize
    # `config` was referenced but never defined here; reuse the GPU config used by the other functions
    config = tf.ConfigProto(device_count={'GPU': 1})
    input = tf.placeholder(tf.float32,
                           (1, downscaled_image.shape[1], downscaled_image.shape[2], params.num_channels),
                           name='input')
    _, output = params.network_architecture(input)

    with tf.Session(config=config) as sess:
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver()
        print('restoring from ' + checkpoint)
        saver.restore(sess, checkpoint)

        # step 1 - apply cnn on each resized image, maybe as a batch
        cnn_output = []
        for image in downscaled_image:
            cnn_output.append(sess.run(output, feed_dict={input: [image]})[0])

    cnn_output = np.array(cnn_output)
    cnn_output = np.round(cnn_output)
    cnn_output[cnn_output > 255] = 255
    return cnn_output
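# NOTE: the "maybe as a batch" comment in upscale() above suggests a batched variant. The
# sketch below is an assumption rather than part of the original code: it declares the
# placeholder with a dynamic batch dimension (None) and pushes all images through the network
# in a single run call. It assumes a fresh default graph and that params.network_architecture
# is fully convolutional, i.e. it accepts a dynamic batch size.
def upscale_batched(downscaled_image, checkpoint):
    config = tf.ConfigProto(device_count={'GPU': 1})
    input_batch = tf.placeholder(tf.float32,
                                 (None, downscaled_image.shape[1], downscaled_image.shape[2], params.num_channels),
                                 name='input_batch')
    _, output = params.network_architecture(input_batch)
    with tf.Session(config=config) as sess:
        saver = tf.train.Saver()
        saver.restore(sess, checkpoint)
        # one forward pass over the whole batch instead of a Python loop
        cnn_output = sess.run(output, feed_dict={input_batch: downscaled_image})
    cnn_output = np.round(cnn_output)
    cnn_output[cnn_output > 255] = 255
    return cnn_output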
                          './data/validation', './data/test', SHOW_IMAGES=False)

# training
batch_size = 128
input = tf.placeholder(tf.float32,
                       (batch_size, data_reader.dim_patch_in_rows, data_reader.dim_patch_in_cols, params.num_channels),
                       name='input')
target = tf.placeholder(tf.float32,
                        (batch_size, data_reader.dim_patch_gt_rows, data_reader.dim_patch_gt_cols, params.num_channels),
                        name='target')
output_PS, output = params.network_architecture(input)
print('output shape is ', output.shape, target.shape)

# the loss combines the error of the final output and of the intermediate output_PS branch
if params.LOSS == params.L1_LOSS:
    loss = tf.reduce_mean(
        tf.reduce_mean(tf.abs(output - target)) + tf.reduce_mean(tf.abs(output_PS - target)))
if params.LOSS == params.L2_LOSS:
    loss = tf.reduce_mean(
        tf.reduce_mean(tf.square(output - target)) + tf.reduce_mean(tf.square(output_PS - target)))
# alpha = 0.7
# loss = alpha * tf.reduce_mean(tf.reduce_mean(tf.abs(output - target)) + tf.reduce_mean(tf.abs(output_PS - target))) + \
#        (1 - alpha) * tf.reduce_mean(tf.reduce_mean(tf.square(output - target)) + tf.reduce_mean(tf.square(output_PS - target)))

global_step = tf.Variable(0, trainable=False)
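# NOTE: the remainder of the training script is not part of this excerpt. A typical TF1
# continuation, given the `loss` and `global_step` defined above, is sketched here; the
# optimizer choice, learning rate and decay schedule are assumptions, not the project's
# actual settings.
learning_rate = tf.train.exponential_decay(1e-3, global_step,
                                           decay_steps=10000, decay_rate=0.9, staircase=True)
train_op = tf.train.AdamOptimizer(learning_rate).minimize(loss, global_step=global_step)
# inside the training loop, each step would feed a batch of (input, target) patches from
# data_reader and run: sess.run([train_op, loss], feed_dict={input: ..., target: ...})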
def upscale(downscaled_image, checkpoint):
    # network for the original image
    config = tf.ConfigProto(device_count={'GPU': 1})
    sess_or = tf.Session(config=config)
    input_or = tf.placeholder(tf.float32,
                              (1, downscaled_image.shape[1], downscaled_image.shape[2], params.num_channels),
                              name='input_or')
    _, output_or = params.network_architecture(input_or)

    # network for the transposed image (swapped height/width); reuse=True shares the weights
    # with the first network, so a single Saver restores both sessions
    sess_tr = tf.Session(config=config)
    input_tr = tf.placeholder(tf.float32,
                              (1, downscaled_image.shape[2], downscaled_image.shape[1], params.num_channels),
                              name='input_tr')
    _, output_tr = params.network_architecture(input_tr, reuse=True)

    saver = tf.train.Saver()
    print('restoring from ' + checkpoint)
    saver.restore(sess_or, checkpoint)
    saver.restore(sess_tr, checkpoint)

    num_images = downscaled_image.shape[0]
    cnn_output = []
    for image in downscaled_image:
        out_images = []
        # original 0
        res = trim_image(sess_or.run(output_or, {input_or: [image]})[0])
        out_images.append(res)
        # flip 0
        res = trim_image(sess_or.run(output_or, {input_or: [flip_image(image)]})[0])
        out_images.append(reverse_flip_image(res))
        # original 90 (transposed network)
        rot90_image = rotate_image_90(image)
        res = trim_image(sess_tr.run(output_tr, {input_tr: [rot90_image]})[0])
        out_images.append(reverse_rotate_image_90(res))
        # flip 90
        res = trim_image(sess_tr.run(output_tr, {input_tr: [flip_image(rot90_image)]})[0])
        out_images.append(reverse_rotate_image_90(reverse_flip_image(res)))
        # original 180
        rot180_image = rotate_image_180(image)
        res = trim_image(sess_or.run(output_or, {input_or: [rot180_image]})[0])
        out_images.append(reverse_rotate_image_180(res))
        # flip 180
        res = trim_image(sess_or.run(output_or, {input_or: [flip_image(rot180_image)]})[0])
        out_images.append(reverse_rotate_image_180(reverse_flip_image(res)))
        # original 270 (transposed network)
        rot270_image = rotate_image_270(image)
        res = trim_image(sess_tr.run(output_tr, {input_tr: [rot270_image]})[0])
        out_images.append(reverse_rotate_image_270(res))
        # flip 270
        res = trim_image(sess_tr.run(output_tr, {input_tr: [flip_image(rot270_image)]})[0])
        out_images.append(reverse_rotate_image_270(reverse_flip_image(res)))

        # aggregate the eight augmented predictions (mean or median, controlled by `use_mean`)
        if use_mean:
            cnn_output.append(np.round(np.mean(np.array(out_images), axis=0)))
        else:
            cnn_output.append(np.round(np.median(np.array(out_images), axis=0)))

    cnn_output = np.array(cnn_output)
    return cnn_output
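# NOTE: hypothetical usage sketch; the file and checkpoint paths below are illustrative
# assumptions, not part of the original project. `upscale` expects a float32 array of shape
# (num_images, height, width, params.num_channels) and returns the self-ensembled, rounded
# predictions.
if __name__ == '__main__':
    lr_images = np.load('./data/test/lr_images.npy').astype(np.float32)   # hypothetical input file
    sr_images = upscale(lr_images, './data/checkpoints/model.ckpt')       # hypothetical checkpoint
    print('upscaled shape:', sr_images.shape)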