def denormalize(image):
    """Undo the per-channel normalization of an image and rescale it to [0, 1].

    For each of the first three channels ``c`` the value
    ``image[:, :, c] * stddev[c] + mean[c]`` is computed, using the
    module-level ``stddev`` and ``mean`` sequences. The channels are joined
    back along axis 2 and passed through ``ctfi.rescale(..., 0.0, 1.0)``.

    Args:
        image: Array indexable as ``image[:, :, c]`` (presumably HWC).

    Returns:
        The rescaled, de-normalized image.
    """
    restored = [image[:, :, c] * stddev[c] + mean[c] for c in range(3)]
    # Give every channel a trailing axis, then join them on axis 2.
    joined = np.concatenate([plane[..., np.newaxis] for plane in restored],
                            axis=2)
    return ctfi.rescale(joined, 0.0, 1.0)
def main(argv):
    """Show a single image stored in a TFRecord dataset.

    The record at index ``position`` is read from the TFRecord file
    ``dataset`` and decoded with ``ctfd.construct_decode_op`` using the
    feature spec ``{shape: [size, size, 3], key: key, dtype: float32}``.
    Its ``key`` feature is displayed after ``ctfi.rescale`` to [0, 1].

    Args:
        argv: Unused; command-line arguments are read by ``argparse`` from
            ``sys.argv``.
    """
    parser = argparse.ArgumentParser(description='Display image from dataset')
    parser.add_argument('dataset',
                        type=str,
                        help='TFRecord file containing the dataset.')
    parser.add_argument(
        'key',
        type=str,
        help='Key of feature that contains image to be displayed.')
    parser.add_argument('size', type=int, help='Size of samples in dataset.')
    parser.add_argument('position',
                        type=int,
                        help='Position of sample to plot in dataset.')
    args = parser.parse_args()

    features = [{
        'shape': [args.size, args.size, 3],
        'key': args.key,
        'dtype': tf.float32
    }]
    decode_op = ctfd.construct_decode_op(features)

    # Select the requested record *before* decoding it. The records in front
    # of it are then skipped as raw bytes instead of being parsed and
    # discarded.
    dataset = (tf.data.TFRecordDataset(args.dataset)
               .skip(args.position)
               .take(1)
               .map(decode_op))
    image = tf.data.experimental.get_single_element(dataset)[args.key]

    plt.imshow(ctfi.rescale(image.numpy(), 0.0, 1.0))
    plt.show()
def main(argv):
    """Plot how similar every patch of a target image is to one source patch.

    Steps:
    - A ``patch_size`` x ``patch_size`` patch is cropped at ``offsets`` from
      the subsampled source image. It is normalized with the given mean and
      variance, then encoded by the graph imported from the latest checkpoint
      in ``export_dir``.
    - The target image is subsampled and rotated by ``--rotate`` degrees.
    - Every VALID patch of the target is encoded in row bands and chunks.
      Their size is bounded by the module-level ``max_patch_buffer_size`` and
      ``max_buffer_size_in_byte`` budgets.
    - ``dist_kl`` between the source descriptor and each target descriptor
      forms the heatmap.
    - The (up to) 20 lowest-valued positions are shown in one figure. The
      inputs, the best match, the heatmap and its gradient are shown in
      another.

    Args:
        argv: Unused; arguments are parsed from ``sys.argv``.

    Returns:
        0 on completion.
    """
    # ``np.math`` (used previously) was only an alias of this module and was
    # removed in NumPy 2.0.
    import math

    parser = argparse.ArgumentParser(
        description='Compute codes and reconstructions for image.')
    parser.add_argument('export_dir', type=str, help='Path to saved model.')
    parser.add_argument(
        'mean',
        type=str,
        help='Path to npy file holding mean for normalization.')
    parser.add_argument(
        'variance',
        type=str,
        help='Path to npy file holding variance for normalization.')
    parser.add_argument('source_filename',
                        type=str,
                        help='Image file from which to extract patch.')
    parser.add_argument('source_image_size',
                        type=int,
                        nargs=2,
                        help='Size of the input image, HW.')
    parser.add_argument('offsets',
                        type=int,
                        nargs=2,
                        help='Position where to extract the patch.')
    parser.add_argument('patch_size', type=int, help='Size of image patch.')
    parser.add_argument('target_filename',
                        type=str,
                        help='Image file for which to create the heatmap.')
    parser.add_argument(
        'target_image_size',
        type=int,
        nargs=2,
        help='Size of the input image for which to create heatmap, HW.')
    # NOTE(review): ``method`` is parsed but never consulted below; the
    # heatmap is always computed with ``dist_kl``.
    parser.add_argument(
        'method',
        type=str,
        help=
        'Method to use to measure similarity, one of KLD, SKLD, BD, HD, SQHD.')
    parser.add_argument(
        '--stain_code_size',
        type=int,
        dest='stain_code_size',
        default=0,
        help=
        'Optional: Size of the stain code to use, which is skipped for similarity estimation'
    )
    parser.add_argument('--rotate',
                        type=float,
                        dest='angle',
                        default=0,
                        help='Optional: rotation angle to rotate target image')
    parser.add_argument('--subsampling_factor',
                        type=int,
                        dest='subsampling_factor',
                        default=1,
                        help='Factor to subsample source and target image.')
    args = parser.parse_args()

    mean = np.load(args.mean)
    variance = np.load(args.variance)
    stddev = [math.sqrt(x) for x in variance]

    def denormalize(image):
        # Invert ``normalize`` for a single HWC image and rescale to [0, 1].
        channels = [
            np.expand_dims(
                image[:, :, channel] * stddev[channel] + mean[channel], -1)
            for channel in range(3)
        ]
        return ctfi.rescale(np.concatenate(channels, 2), 0.0, 1.0)

    def normalize(image, name=None):
        # Per-channel (x - mean) / stddev on a 4-D batch of images.
        channels = [
            tf.expand_dims(
                (image[:, :, :, channel] - mean[channel]) / stddev[channel],
                -1) for channel in range(3)
        ]
        return tf.concat(channels, 3, name=name)

    latest_checkpoint = tf.train.latest_checkpoint(args.export_dir)
    saver = tf.train.import_meta_graph(latest_checkpoint + '.meta',
                                       import_scope='imported')

    with tf.Session(graph=tf.get_default_graph()).as_default() as sess:

        # Load image and extract patch from it and create distribution.
        source_image = ctfi.subsample(
            ctfi.load(args.source_filename,
                      height=args.source_image_size[0],
                      width=args.source_image_size[1]),
            args.subsampling_factor)
        args.source_image_size = [
            int(x / args.subsampling_factor) for x in args.source_image_size
        ]

        patch = normalize(
            tf.expand_dims(
                tf.image.crop_to_bounding_box(source_image, args.offsets[0],
                                              args.offsets[1],
                                              args.patch_size,
                                              args.patch_size), 0))

        patch_cov, patch_mean = tf.contrib.graph_editor.graph_replace([
            sess.graph.get_tensor_by_name('imported/z_log_sigma_sq/BiasAdd:0'),
            sess.graph.get_tensor_by_name('imported/z_mean/BiasAdd:0')
        ], {sess.graph.get_tensor_by_name('imported/patch:0'): patch})

        # Descriptor = latent mean and flattened log-sigma, both without the
        # leading stain-code dimensions.
        patch_descriptor = tf.concat([
            patch_mean[:, args.stain_code_size:],
            tf.layers.flatten(patch_cov[:, args.stain_code_size:])
        ], -1)

        sim_vals = []

        structure_code_size = (patch_mean.get_shape().as_list()[1] -
                               args.stain_code_size)

        # Load image for which to create the heatmap.
        target_image = ctfi.subsample(
            ctfi.load(args.target_filename,
                      height=args.target_image_size[0],
                      width=args.target_image_size[1]),
            args.subsampling_factor)
        args.target_image_size = [
            int(x / args.subsampling_factor) for x in args.target_image_size
        ]

        target_image = tf.contrib.image.rotate(target_image,
                                               np.radians(args.angle))

        # One heatmap entry per VALID patch position.
        heatmap_height = args.target_image_size[0] - (args.patch_size - 1)
        heatmap_width = args.target_image_size[1] - (args.patch_size - 1)

        # Compute byte size as: width*height*channels*sizeof(float32)
        patch_size_in_byte = args.patch_size**2 * 3 * 4
        max_patches = int(max_patch_buffer_size / patch_size_in_byte)
        # At least one heatmap row per band; otherwise the division below
        # raises ZeroDivisionError when a single row exceeds the buffer.
        max_num_rows = max(1, int(max_patches / heatmap_width))
        max_chunk_size = int(max_buffer_size_in_byte / patch_size_in_byte)

        # Iterate over horizontal bands of the target that fit the budget.
        num_iterations = int(args.target_image_size[0] / max_num_rows) + 1

        all_chunks = []
        all_similarities = []
        chunk_tensors = []

        chunk_sizes = np.full(num_iterations, heatmap_width, dtype=int)

        for i in range(num_iterations):
            processed_rows = i * max_num_rows
            # A band of R heatmap rows needs R + patch_size - 1 image rows.
            rows_to_load = min(max_num_rows + (args.patch_size - 1),
                               args.target_image_size[0] - processed_rows)

            if rows_to_load < args.patch_size:
                break

            # Extract region for which we can compute patches.
            target_image_region = tf.image.crop_to_bounding_box(
                target_image, processed_rows, 0, rows_to_load,
                args.target_image_size[1])

            # (h - patch_size + 1) * (w - patch_size + 1) patches ('VALID').
            all_image_patches = tf.unstack(
                normalize(
                    ctfi.extract_patches(target_image_region,
                                         args.patch_size,
                                         strides=[1, 1, 1, 1],
                                         padding='VALID')))

            # Use the first divisor of the patch count below the chunk budget
            # so every chunk has the same (placeholder) size.
            for size in get_divisors(len(all_image_patches)):
                if size < max_chunk_size:
                    chunk_sizes[i] = size
                    break

            # Partition patches into chunks.
            chunked_patches = [
                tf.stack(chunk)
                for chunk in create_chunks(all_image_patches, chunk_sizes[i])
            ]
            all_chunks.append(chunked_patches)

            chunk_tensor = tf.placeholder(
                tf.float32,
                shape=[chunk_sizes[i], args.patch_size, args.patch_size, 3],
                name='chunk_tensor_placeholder')
            chunk_tensors.append(chunk_tensor)

            image_patches_cov, image_patches_mean = (
                tf.contrib.graph_editor.graph_replace([
                    sess.graph.get_tensor_by_name(
                        'imported/z_log_sigma_sq/BiasAdd:0'),
                    sess.graph.get_tensor_by_name('imported/z_mean/BiasAdd:0')
                ], {
                    sess.graph.get_tensor_by_name('imported/patch:0'):
                    chunk_tensor
                }))
            image_patches_descriptors = tf.concat([
                image_patches_mean[:, args.stain_code_size:],
                tf.layers.flatten(image_patches_cov[:, args.stain_code_size:])
            ], -1)

            distances = dist_kl(patch_descriptor, image_patches_descriptors,
                                structure_code_size)
            # Flatten instead of squeeze: a size-1 chunk would otherwise
            # become a scalar and break ``sim_vals.extend`` below.
            all_similarities.append(tf.reshape(distances, [-1]))

        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())
        saver.restore(sess, latest_checkpoint)

        # The three lists were appended together, one entry per band.
        for chunks, chunk_tensor, similarities in zip(all_chunks,
                                                      chunk_tensors,
                                                      all_similarities):
            for chunk in chunks:
                sim_vals.extend(
                    sess.run(similarities,
                             feed_dict={chunk_tensor: sess.run(chunk)}))
                print(len(sim_vals))

        sim_heatmap = np.reshape(sim_vals, [heatmap_height, heatmap_width])
        heatmap_tensor = tf.expand_dims(
            tf.expand_dims(tf.convert_to_tensor(sim_heatmap), -1), 0)
        dy, dx = tf.image.image_gradients(heatmap_tensor)

        # Show the k positions with the smallest distance (best matches).
        k_min = 20
        min_indices = np.unravel_index(
            np.argsort(sim_vals)[:k_min], sim_heatmap.shape)

        _, ax_min = plt.subplots(4, 5)
        # Iterate over the indices actually found: there may be fewer than
        # k_min positions for small targets.
        for i, (row, col) in enumerate(zip(*min_indices)):
            target_patch = tf.image.crop_to_bounding_box(
                target_image, int(row), int(col), args.patch_size,
                args.patch_size)
            cell = ax_min[i // 5, i % 5]
            cell.imshow(sess.run(target_patch))
            cell.set_title('y:' + str(row) + ', x:' + str(col))

        fig, ax = plt.subplots(2, 3)
        cmap = 'plasma'

        denormalized_patch = denormalize(sess.run(patch)[0])
        # Smallest distance == most similar patch.
        best_idx = np.unravel_index(np.argmin(sim_heatmap), sim_heatmap.shape)
        target_image_patch = tf.image.crop_to_bounding_box(
            target_image, best_idx[0], best_idx[1], args.patch_size,
            args.patch_size)
        print(best_idx)
        print(min_indices)

        ax[1, 0].imshow(sess.run(source_image))
        ax[1, 1].imshow(sess.run(target_image))
        ax[0, 0].imshow(denormalized_patch)
        heatmap_image = ax[0, 2].imshow(sim_heatmap, cmap=cmap)
        ax[0, 1].imshow(sess.run(target_image_patch))
        gradient_image = ax[1, 2].imshow(np.squeeze(sess.run(dx + dy)),
                                         cmap='bwr')

        fig.colorbar(heatmap_image, ax=ax[0, 2])
        fig.colorbar(gradient_image, ax=ax[1, 2])

        plt.show()
        # ``as_default()`` does not close the session on exit; do it here.
        sess.close()

    print("Done!")
    return 0
def main(argv):
    """Register two oppositely rotated copies of an image via latent codes.

    Steps:
    - ``HE_level_1_cropped_512x512.png`` is loaded twice and rotated by
      +0.05*pi ("target") and -0.05*pi ("moving").
    - Each copy is warped with ``sparse_image_warp`` driven by its own 8x8
      grid of trainable source control points.
    - Patches are extracted with ``ctfi.extract_patches`` (size 32, stride
      32) and encoded by the graph imported from the latest checkpoint in
      ``export_dir``.
    - The summed symmetric KL divergence between the two sets of Gaussians
      (skipping the first ``stain_code_size`` latent dims) is minimised by
      gradient descent on both control-point sets.
    - Progress is printed and plotted live. The process exits via
      ``sys.exit(0)``.

    Args:
        argv: Unused; arguments are parsed from ``sys.argv``.
    """
    parser = argparse.ArgumentParser(
        description='Compute latent code for image patch by model inference.')
    parser.add_argument('export_dir',
                        type=str,
                        help='Path to saved model to use for inference.')

    args = parser.parse_args()

    width = 512
    height = 512
    channels = 3
    patch_size = 32  # stride equals patch size -> non-overlapping grid
    stain_code_size = 8  # leading latent dims excluded from the loss
    learning_rate = 0.05  # value used with the symmetric KL loss

    filename_target = os.path.join(git_root, 'data', 'images',
                                   'HE_level_1_cropped_512x512.png')
    image_target = tf.expand_dims(
        ctfi.load(filename_target,
                  width=width,
                  height=height,
                  channels=channels), 0)
    image_target = tf.reshape(image_target,
                              shape=[1, height, width, channels])
    image_target = tf.contrib.image.rotate(image_target, 0.05 * math.pi)

    filename_moving = os.path.join(git_root, 'data', 'images',
                                   'HE_level_1_cropped_512x512.png')
    image_moving = tf.expand_dims(
        ctfi.load(filename_moving,
                  width=width,
                  height=height,
                  channels=channels), 0)
    image_moving = tf.reshape(image_moving,
                              shape=[1, height, width, channels])
    image_moving = tf.contrib.image.rotate(image_moving, -0.05 * math.pi)

    step = tf.Variable(tf.zeros([], dtype=tf.float32))

    # 8x8 grid of control point locations spanning the image.
    X, Y = np.mgrid[0:width:8j, 0:height:8j]
    positions = np.transpose(np.vstack([X.ravel(), Y.ravel()]))
    positions = tf.expand_dims(
        tf.convert_to_tensor(positions, dtype=tf.float32), 0)

    target_source_control_point_locations = tf.Variable(positions)
    moving_source_control_point_locations = tf.Variable(positions)
    dest_control_point_locations = tf.Variable(positions)

    warped_moving = tf.Variable(image_moving)
    warped_moving, flow_moving = tf.contrib.image.sparse_image_warp(
        warped_moving,
        moving_source_control_point_locations,
        dest_control_point_locations,
        name='sparse_image_warp_moving',
        interpolation_order=1,
        regularization_weight=0.01)

    warped_target = tf.Variable(image_target)
    warped_target, flow_target = tf.contrib.image.sparse_image_warp(
        warped_target,
        target_source_control_point_locations,
        dest_control_point_locations,
        name='sparse_image_warp_target',
        interpolation_order=1,
        regularization_weight=0.01)

    patch_strides = [1, patch_size, patch_size, 1]
    warped_target_patches = normalize(
        ctfi.extract_patches(warped_target[0], patch_size,
                             strides=patch_strides))
    warped_moving_patches = normalize(
        ctfi.extract_patches(warped_moving[0], patch_size,
                             strides=patch_strides))

    latest_checkpoint = tf.train.latest_checkpoint(args.export_dir)
    saver = tf.train.import_meta_graph(latest_checkpoint + '.meta',
                                       import_scope='imported')

    with tf.Session(graph=tf.get_default_graph()).as_default() as sess:
        # Re-wire the imported encoder onto each set of warped patches.
        target_cov, target_mean = tf.contrib.graph_editor.graph_replace([
            sess.graph.get_tensor_by_name(
                'imported/z_covariance_lower_tri/MatrixBandPart:0'),
            sess.graph.get_tensor_by_name('imported/z_mean/BiasAdd:0')
        ], {
            sess.graph.get_tensor_by_name('imported/patch:0'):
            warped_target_patches
        })
        moving_cov, moving_mean = tf.contrib.graph_editor.graph_replace([
            sess.graph.get_tensor_by_name(
                'imported/z_covariance_lower_tri/MatrixBandPart:0'),
            sess.graph.get_tensor_by_name('imported/z_mean/BiasAdd:0')
        ], {
            sess.graph.get_tensor_by_name('imported/patch:0'):
            warped_moving_patches
        })

        N_target = tf.contrib.distributions.MultivariateNormalTriL(
            loc=target_mean[:, stain_code_size:],
            scale_tril=target_cov[:, stain_code_size:, stain_code_size:])
        N_mov = tf.contrib.distributions.MultivariateNormalTriL(
            loc=moving_mean[:, stain_code_size:],
            scale_tril=moving_cov[:, stain_code_size:, stain_code_size:])

        sym_kl_div = N_target.kl_divergence(N_mov) + N_mov.kl_divergence(
            N_target)
        loss = tf.reduce_sum(sym_kl_div)

        optimizer = tf.contrib.optimizer_v2.GradientDescentOptimizer(
            learning_rate=learning_rate)
        compute_gradients = optimizer.compute_gradients(
            loss,
            var_list=[
                moving_source_control_point_locations,
                target_source_control_point_locations
            ])
        apply_gradients = optimizer.apply_gradients(compute_gradients,
                                                    global_step=step)

        # Build the display ops once. Creating them inside the training loop
        # added new nodes to the graph on every iteration.
        overlay_op = warped_target + warped_moving
        diff_target_op = tf.abs(image_target - warped_target)
        diff_moving_op = tf.abs(image_moving - warped_moving)
        diff_op = tf.abs(warped_target - warped_moving)

        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())
        saver.restore(sess, latest_checkpoint)

        fig, ax = plt.subplots(3, 3)

        ax[0, 0].imshow(
            ctfi.rescale(image_target.eval(session=sess)[0], 0.0, 1.0))
        ax[0, 0].set_title('target')
        ax[0, 0].set_autoscale_on(False)

        ax[0, 1].imshow(
            ctfi.rescale((image_target + image_moving).eval(session=sess)[0],
                         0.0, 1.0))
        ax[0, 1].set_title('overlayed')
        ax[0, 1].set_autoscale_on(False)

        ax[0, 2].imshow(
            ctfi.rescale(image_moving.eval(session=sess)[0], 0.0, 1.0))
        ax[0, 2].set_title('moving')
        ax[0, 2].set_autoscale_on(False)

        plot_warped_target = ax[1, 0].imshow(
            ctfi.rescale(warped_target.eval(session=sess)[0], 0.0, 1.0))
        ax[1, 0].set_title('warped_target')
        ax[1, 0].set_autoscale_on(False)

        plot_overlayed = ax[1, 1].imshow(
            ctfi.rescale(overlay_op.eval(session=sess)[0], 0.0, 1.0))
        ax[1, 1].set_title('warped_overlayed')
        ax[1, 1].set_autoscale_on(False)

        plot_warped_moving = ax[1, 2].imshow(
            ctfi.rescale(warped_moving.eval(session=sess)[0], 0.0, 1.0))
        ax[1, 2].set_title('warped_moving')
        ax[1, 2].set_autoscale_on(False)

        plot_diff_target = ax[2, 0].imshow(
            ctfi.rescale(diff_target_op.eval(session=sess)[0], 0., 1.))
        ax[2, 0].set_title('diff_target')
        ax[2, 0].set_autoscale_on(False)

        plot_diff_overlayed = ax[2, 1].imshow(
            ctfi.rescale(diff_op.eval(session=sess)[0], 0., 1.))
        ax[2, 1].set_title('diff_overlayed')
        ax[2, 1].set_autoscale_on(False)

        plot_diff_moving = ax[2, 2].imshow(
            ctfi.rescale(diff_moving_op.eval(session=sess)[0], 0., 1.))
        ax[2, 2].set_title('diff_moving')
        ax[2, 2].set_autoscale_on(False)

        moving_source_points = moving_source_control_point_locations.eval(
            session=sess)[0]
        target_source_points = target_source_control_point_locations.eval(
            session=sess)[0]

        # 'x' markers without connecting lines. The previous "'s' fmt plus
        # marker='x'" drew the same thing but triggered matplotlib's
        # redundant-marker warning.
        plot_scatter_moving, = ax[1, 2].plot(moving_source_points[:, 0],
                                             moving_source_points[:, 1],
                                             'x',
                                             ms=5,
                                             color='orange')
        plot_scatter_target, = ax[1, 0].plot(target_source_points[:, 0],
                                             target_source_points[:, 1],
                                             'x',
                                             ms=5,
                                             color='orange')

        plot_moving_grad = ax[1, 2].quiver(
            moving_source_points[:, 0],  # X
            moving_source_points[:, 1],  # Y
            np.zeros_like(moving_source_points[:, 0]),
            np.zeros_like(moving_source_points[:, 0]),
            units='xy',
            angles='xy',
            scale_units='xy',
            scale=1)

        plot_target_grad = ax[1, 0].quiver(
            target_source_points[:, 0],  # X
            target_source_points[:, 1],  # Y
            np.zeros_like(target_source_points[:, 0]),
            np.zeros_like(target_source_points[:, 0]),
            units='xy',
            angles='xy',
            scale_units='xy',
            scale=1)

        plt.ion()
        fig.canvas.draw()
        fig.canvas.flush_events()
        plt.show()

        iterations = 5000
        print_iterations = 1
        # Summed (gradient, variable) values since the last plot update; only
        # the gradient entries are read below.
        accumulated_gradients = np.zeros_like(sess.run(compute_gradients))

        while step.value().eval(session=sess) < iterations:
            step_val = int(step.value().eval(session=sess))
            gradients = sess.run(compute_gradients)
            sess.run(apply_gradients)
            accumulated_gradients += gradients

            if step_val % print_iterations == 0 or step_val == iterations - 1:
                loss_val = loss.eval(session=sess)
                diff_moving = diff_moving_op.eval(session=sess)
                diff_target = diff_target_op.eval(session=sess)
                diff = diff_op.eval(session=sess)

                print("{0:d}\t{1:.4f}\t{2:.4f}\t{3:.4f}\t{4:.4f}".format(
                    step_val, loss_val, np.sum(diff_moving),
                    np.sum(diff_target), np.sum(diff)))

                plot_warped_target.set_data(
                    ctfi.rescale(warped_target.eval(session=sess)[0], 0., 1.))
                plot_warped_moving.set_data(
                    ctfi.rescale(warped_moving.eval(session=sess)[0], 0., 1.))
                plot_overlayed.set_data(
                    ctfi.rescale(overlay_op.eval(session=sess)[0], 0., 1.))
                plot_diff_target.set_data(ctfi.rescale(diff_target[0], 0., 1.))
                plot_diff_moving.set_data(ctfi.rescale(diff_moving[0], 0., 1.))
                plot_diff_overlayed.set_data(ctfi.rescale(diff[0], 0., 1.))

                # (gradient, variable) pairs follow var_list order:
                # 0 = moving control points, 1 = target control points.
                moving_gradients = learning_rate * np.squeeze(
                    accumulated_gradients[0][0])
                moving_points = np.squeeze(gradients[0][1])
                target_gradients = learning_rate * np.squeeze(
                    accumulated_gradients[1][0])
                target_points = np.squeeze(gradients[1][1])

                plot_scatter_moving.set_data(moving_points[:, 0],
                                             moving_points[:, 1])
                plot_scatter_target.set_data(target_points[:, 0],
                                             target_points[:, 1])

                # Colour arrows by gradient magnitude. quiver's C must have
                # one value per arrow; passing the (N, 2) gradients is
                # rejected by recent matplotlib.
                plot_moving_grad.remove()
                plot_moving_grad = ax[1, 2].quiver(
                    moving_points[:, 0],  # X
                    moving_points[:, 1],  # Y
                    moving_gradients[:, 0],
                    moving_gradients[:, 1],
                    np.linalg.norm(moving_gradients, axis=1),
                    units='xy',
                    angles='xy',
                    scale_units='xy',
                    scale=1)

                plot_target_grad.remove()
                plot_target_grad = ax[1, 0].quiver(
                    target_points[:, 0],  # X
                    target_points[:, 1],  # Y
                    target_gradients[:, 0],
                    target_gradients[:, 1],
                    np.linalg.norm(target_gradients, axis=1),
                    units='xy',
                    angles='xy',
                    scale_units='xy',
                    scale=1)

                fig.canvas.draw()
                fig.canvas.flush_events()
                plt.show()

                accumulated_gradients.fill(0)

        print("Done!")

    plt.ioff()
    plt.show()

    sys.exit(0)
def main(argv):
    """Warp a CD3 image toward an H&E image by matching latent codes.

    Steps:
    - ``CD3_level_1_cropped_512x512.png`` is warped with ``sparse_image_warp``
      driven by an 8x8 grid of trainable source control points.
    - Patches are extracted with ``ctfi.extract_patches`` (size 32, stride 16)
      from both the warped image and ``HE_level_1_cropped_512x512.png``.
    - Both patch sets are encoded by the graph imported from the latest
      checkpoint in ``export_dir``.
    - The summed symmetric KL divergence between the two sets of Gaussians
      (skipping the first ``stain_code_size`` latent dims) is minimised by
      gradient descent on the source control points.
    - Progress is printed and plotted every 100 steps. The process exits via
      ``sys.exit(0)``.

    Args:
        argv: Unused; arguments are parsed from ``sys.argv``.
    """
    parser = argparse.ArgumentParser(
        description='Compute latent code for image patch by model inference.')
    parser.add_argument('export_dir',
                        type=str,
                        help='Path to saved model to use for inference.')

    args = parser.parse_args()

    stain_code_size = 6  # leading latent dims excluded from the loss
    learning_rate = 0.05

    filename = os.path.join(git_root, 'data', 'images',
                            'HE_level_1_cropped_512x512.png')
    image = tf.expand_dims(
        ctfi.load(filename, width=512, height=512, channels=3), 0)

    target_filename = os.path.join(git_root, 'data', 'images',
                                   'CD3_level_1_cropped_512x512.png')
    image_rotated = tf.Variable(
        tf.expand_dims(
            ctfi.load(target_filename, width=512, height=512, channels=3), 0))

    step = tf.Variable(tf.zeros([], dtype=tf.float32))

    # 8x8 grid of control point locations spanning the image.
    X, Y = np.mgrid[0:512:8j, 0:512:8j]
    positions = np.transpose(np.vstack([X.ravel(), Y.ravel()]))
    positions = tf.expand_dims(
        tf.convert_to_tensor(positions, dtype=tf.float32), 0)

    source_control_point_locations = tf.Variable(positions)
    dest_control_point_locations = tf.Variable(positions)

    # The warp reads ``image_rotated`` directly. The extra tf.Variable copy
    # that used to be created here was overwritten immediately and never
    # used.
    warped_image, flow = tf.contrib.image.sparse_image_warp(
        image_rotated,
        source_control_point_locations,
        dest_control_point_locations,
        name='sparse_image_warp',
        interpolation_order=1,
        regularization_weight=0.005)

    image_patches = normalize(
        ctfi.extract_patches(image[0], 32, strides=[1, 16, 16, 1]))
    warped_patches = normalize(
        ctfi.extract_patches(warped_image[0], 32, strides=[1, 16, 16, 1]))

    latest_checkpoint = tf.train.latest_checkpoint(args.export_dir)
    saver = tf.train.import_meta_graph(latest_checkpoint + '.meta',
                                       import_scope='imported')

    with tf.Session(graph=tf.get_default_graph()).as_default() as sess:
        # Re-wire the imported encoder onto the fixed and warped patches.
        target_cov, target_mean = tf.contrib.graph_editor.graph_replace([
            sess.graph.get_tensor_by_name(
                'imported/z_covariance_lower_tri/MatrixBandPart:0'),
            sess.graph.get_tensor_by_name('imported/z_mean/BiasAdd:0')
        ], {sess.graph.get_tensor_by_name('imported/patch:0'): image_patches})
        moving_cov, moving_mean = tf.contrib.graph_editor.graph_replace([
            sess.graph.get_tensor_by_name(
                'imported/z_covariance_lower_tri/MatrixBandPart:0'),
            sess.graph.get_tensor_by_name('imported/z_mean/BiasAdd:0')
        ], {sess.graph.get_tensor_by_name('imported/patch:0'): warped_patches})

        N_target = tf.contrib.distributions.MultivariateNormalTriL(
            loc=target_mean[:, stain_code_size:],
            scale_tril=target_cov[:, stain_code_size:, stain_code_size:])
        N_mov = tf.contrib.distributions.MultivariateNormalTriL(
            loc=moving_mean[:, stain_code_size:],
            scale_tril=moving_cov[:, stain_code_size:, stain_code_size:])

        loss = tf.reduce_sum(
            N_target.kl_divergence(N_mov) + N_mov.kl_divergence(N_target))

        optimizer = tf.train.GradientDescentOptimizer(
            learning_rate=learning_rate)
        compute_gradients_source = optimizer.compute_gradients(
            loss, var_list=[source_control_point_locations])
        apply_gradients_source = optimizer.apply_gradients(
            compute_gradients_source, global_step=step)

        # Build the display ops once. Creating them inside the training loop
        # added new nodes to the graph on every refresh.
        diff_image_op = tf.abs(image - warped_image)
        diff_rotated_op = tf.abs(image_rotated - warped_image)

        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())
        saver.restore(sess, latest_checkpoint)

        fig, ax = plt.subplots(2, 3)

        ax[0, 0].imshow(ctfi.rescale(image.eval(session=sess)[0], 0.0, 1.0))
        ax[0, 0].set_title('image')
        ax[0, 0].set_autoscale_on(False)

        ax[0, 1].imshow(
            ctfi.rescale(image_rotated.eval(session=sess)[0], 0.0, 1.0))
        ax[0, 1].set_title('rotated')
        ax[0, 1].set_autoscale_on(False)

        plot_warped = ax[0, 2].imshow(
            ctfi.rescale(warped_image.eval(session=sess)[0], 0.0, 1.0))
        ax[0, 2].set_title('warped')
        ax[0, 2].set_autoscale_on(False)

        plot_diff_image = ax[1, 0].imshow(
            ctfi.rescale(diff_image_op.eval(session=sess)[0], 0., 1.))
        ax[1, 0].set_title('diff_image')
        ax[1, 0].set_autoscale_on(False)

        plot_diff_rotated = ax[1, 1].imshow(
            ctfi.rescale(diff_rotated_op.eval(session=sess)[0], 0., 1.))
        ax[1, 1].set_title('diff_rotated')
        ax[1, 1].set_autoscale_on(False)

        plot_flow = ax[1, 2].imshow(
            np.zeros_like(image[0, :, :, :].eval(session=sess)))
        ax[1, 2].set_title('flow')
        ax[1, 2].set_autoscale_on(False)

        dest_points = dest_control_point_locations.eval(session=sess)[0]
        source_points = source_control_point_locations.eval(session=sess)[0]

        # 'x' markers without connecting lines. The previous "'s' fmt plus
        # marker='x'" drew the same thing but triggered matplotlib's
        # redundant-marker warning.
        plot_scatter_source, = ax[0, 1].plot(source_points[:, 0],
                                             source_points[:, 1],
                                             'x',
                                             ms=5,
                                             color='orange')
        plot_scatter_dest, = ax[0, 2].plot(dest_points[:, 0],
                                           dest_points[:, 1],
                                           'x',
                                           ms=5,
                                           color='green')

        plot_source_grad = ax[0, 1].quiver(
            source_points[:, 0],  # X
            source_points[:, 1],  # Y
            np.zeros_like(source_points[:, 0]),
            np.zeros_like(source_points[:, 0]),
            units='xy',
            angles='xy',
            scale_units='xy',
            scale=1)

        plot_dest_grad = ax[0, 2].quiver(
            dest_points[:, 0],  # X
            dest_points[:, 1],  # Y
            np.zeros_like(dest_points[:, 0]),
            np.zeros_like(dest_points[:, 0]),
            units='xy',
            angles='xy',
            scale_units='xy',
            scale=1)

        plt.ion()
        fig.canvas.draw()
        fig.canvas.flush_events()
        plt.show()

        iterations = 100000
        while step.value().eval(session=sess) < iterations:
            step_val = int(step.value().eval(session=sess))
            gradients = sess.run(compute_gradients_source)
            sess.run(apply_gradients_source)

            if step_val % 100 == 0 or step_val == iterations - 1:
                loss_val = loss.eval(session=sess)
                grad_mean_source = np.mean(gradients[0][0])
                grad_mean_dest = 0.0  # destination points are not optimised

                # Visualise the flow field as an RGB image (x, y, 0).
                flow_field = flow.eval(session=sess)
                x, y = np.split(flow_field, 2, axis=3)
                flow_image = ctfi.rescale(
                    np.squeeze(np.concatenate([x, y, np.zeros_like(x)], 3)),
                    0.0, 1.0)

                diff_warp_rotated = diff_rotated_op.eval(session=sess)
                diff_image_warp = diff_image_op.eval(session=sess)

                print(
                    "{0:d}\t{1:.4f}\t{2:.4f}\t{3:.4f}\t{4:.4f}\t{5:.4f}\t{6:.4f}"
                    .format(step_val, loss_val, grad_mean_source,
                            grad_mean_dest, np.mean(flow_field),
                            np.sum(diff_warp_rotated),
                            np.sum(diff_image_warp)))

                plot_warped.set_data(
                    ctfi.rescale(warped_image.eval(session=sess)[0], 0., 1.))
                plot_diff_image.set_data(
                    ctfi.rescale(diff_image_warp[0], 0., 1.))
                plot_diff_rotated.set_data(
                    ctfi.rescale(diff_warp_rotated[0], 0., 1.))
                plot_flow.set_data(flow_image)

                dest_points = dest_control_point_locations.eval(
                    session=sess)[0]
                # gradients[0] is the (gradient, variable) pair for the
                # source control points.
                source_points = np.squeeze(gradients[0][1])

                plot_scatter_source.set_data(source_points[:, 0],
                                             source_points[:, 1])
                plot_scatter_dest.set_data(dest_points[:, 0],
                                           dest_points[:, 1])

                source_gradients = np.squeeze(gradients[0][0])

                # Colour arrows by gradient magnitude. quiver's C must have
                # one value per arrow; passing the (N, 2) gradients is
                # rejected by recent matplotlib.
                plot_source_grad.remove()
                plot_source_grad = ax[0, 1].quiver(
                    source_points[:, 0],  # X
                    source_points[:, 1],  # Y
                    source_gradients[:, 0],
                    source_gradients[:, 1],
                    np.linalg.norm(source_gradients, axis=1),
                    units='xy',
                    angles='xy',
                    scale_units='xy',
                    scale=1)

                fig.canvas.draw()
                fig.canvas.flush_events()
                plt.show()

        print("Done!")

    plt.ioff()
    plt.show()

    sys.exit(0)