def main(argv):
    """Plot a similarity heatmap of one source patch against a target image.

    A ``patch_size`` x ``patch_size`` patch is cropped at ``offsets`` from the
    (subsampled) source image and encoded by the imported model into a
    diagonal Gaussian built from the ``z_mean`` and ``z_log_sigma_sq``
    tensors (the first ``stain_code_size`` latent dimensions are skipped).
    Every 'VALID' sliding-window patch of the (subsampled, rotated) target
    image is encoded the same way and compared with the selected method:
    SKLD, BD, SQHD, HD, or KLD (any other value also falls back to KLD).

    The target image is processed in horizontal bands, and each band in
    fixed-size chunks, so that the number of patches materialized at once is
    bounded by ``max_patch_buffer_size`` and ``max_buffer_size_in_byte``
    (names defined outside this function). The result is shown with
    matplotlib: the 20 lowest-scoring target patches, the heatmap, the best
    match and the heatmap's image gradient.

    Args:
        argv: Unused; ``argparse`` reads the command line from ``sys.argv``.
    """
    import math  # np.math was only an alias of math and is gone in NumPy 2.0.

    parser = argparse.ArgumentParser(
        description='Compute codes and reconstructions for image.')
    parser.add_argument('export_dir', type=str, help='Path to saved model.')
    parser.add_argument(
        'mean',
        type=str,
        help='Path to npy file holding mean for normalization.')
    parser.add_argument(
        'variance',
        type=str,
        help='Path to npy file holding variance for normalization.')
    parser.add_argument('source_filename',
                        type=str,
                        help='Image file from which to extract patch.')
    parser.add_argument('source_image_size',
                        type=int,
                        nargs=2,
                        help='Size of the input image, HW.')
    parser.add_argument('offsets',
                        type=int,
                        nargs=2,
                        help='Position where to extract the patch.')
    parser.add_argument('patch_size', type=int, help='Size of image patch.')
    parser.add_argument('target_filename',
                        type=str,
                        help='Image file for which to create the heatmap.')
    parser.add_argument(
        'target_image_size',
        type=int,
        nargs=2,
        help='Size of the input image for which to create heatmap, HW.')
    parser.add_argument(
        'method',
        type=str,
        help=
        'Method to use to measure similarity, one of KLD, SKLD, BD, HD, SQHD.')
    parser.add_argument(
        '--stain_code_size',
        type=int,
        dest='stain_code_size',
        default=0,
        help=
        'Optional: Size of the stain code to use, which is skipped for similarity estimation'
    )
    parser.add_argument('--rotate',
                        type=float,
                        dest='angle',
                        default=0,
                        help='Optional: rotation angle to rotate target image')
    parser.add_argument('--subsampling_factor',
                        type=int,
                        dest='subsampling_factor',
                        default=1,
                        help='Factor to subsample source and target image.')
    args = parser.parse_args()

    mean = np.load(args.mean)
    variance = np.load(args.variance)
    stddev = [math.sqrt(x) for x in variance]

    def denormalize(image):
        """Undo the per-channel normalization of an H x W x 3 array and rescale it to [0, 1]."""
        channels = [
            np.expand_dims(image[:, :, channel] * stddev[channel] + mean[channel], -1)
            for channel in range(3)
        ]
        return ctfi.rescale(np.concatenate(channels, 2), 0.0, 1.0)

    def normalize(image, name=None):
        """Normalize each of the 3 channels of a 4-D tensor with ``mean``/``stddev``."""
        channels = [
            tf.expand_dims((image[:, :, :, channel] - mean[channel]) / stddev[channel], -1)
            for channel in range(3)
        ]
        return tf.concat(channels, 3, name=name)

    latest_checkpoint = tf.train.latest_checkpoint(args.export_dir)
    saver = tf.train.import_meta_graph(latest_checkpoint + '.meta',
                                       import_scope='imported')

    # Using the session itself as the context manager makes it the default
    # session AND closes it on exit, including when an exception is raised.
    with tf.Session(graph=tf.get_default_graph()) as sess:
        # Load the source image and extract the query patch from it.
        source_image = ctfi.subsample(
            ctfi.load(args.source_filename,
                      height=args.source_image_size[0],
                      width=args.source_image_size[1]),
            args.subsampling_factor)
        args.source_image_size = [
            int(x / args.subsampling_factor) for x in args.source_image_size
        ]
        patch = normalize(
            tf.expand_dims(
                tf.image.crop_to_bounding_box(source_image, args.offsets[0],
                                              args.offsets[1], args.patch_size,
                                              args.patch_size), 0))

        # Encode the patch. The tensor name indicates log(sigma^2), so the
        # scale (standard deviation) is sqrt(exp(log_sigma_sq)).
        patch_log_sigma_sq, patch_mean = tf.contrib.graph_editor.graph_replace(
            [
                sess.graph.get_tensor_by_name('imported/z_log_sigma_sq/BiasAdd:0'),
                sess.graph.get_tensor_by_name('imported/z_mean/BiasAdd:0'),
            ], {sess.graph.get_tensor_by_name('imported/patch:0'): patch})
        patch_distribution = tf.contrib.distributions.MultivariateNormalDiag(
            patch_mean[:, args.stain_code_size:],
            tf.sqrt(tf.exp(patch_log_sigma_sq[:, args.stain_code_size:])))

        sim_vals = []

        # Load the image for which to create the heatmap.
        target_image = ctfi.subsample(
            ctfi.load(args.target_filename,
                      height=args.target_image_size[0],
                      width=args.target_image_size[1]),
            args.subsampling_factor)
        target_image = tf.contrib.image.rotate(target_image,
                                               np.radians(args.angle))
        args.target_image_size = [
            int(x / args.subsampling_factor) for x in args.target_image_size
        ]

        heatmap_height = args.target_image_size[0] - (args.patch_size - 1)
        heatmap_width = args.target_image_size[1] - (args.patch_size - 1)

        # Byte size of one patch: width * height * channels * sizeof(float32).
        patch_size_in_byte = args.patch_size**2 * 3 * 4
        max_patches = int(max_patch_buffer_size / patch_size_in_byte)
        # At least one heatmap row per band; 0 would make the band count
        # below divide by zero.
        max_num_rows = max(1, int(max_patches / heatmap_width))
        max_chunk_size = int(max_buffer_size_in_byte / patch_size_in_byte)

        # Iterate over horizontal bands of the target image that fit the budget.
        num_iterations = int(args.target_image_size[0] / max_num_rows) + 1
        all_chunks = []
        all_similarities = []
        chunk_tensors = []
        chunk_sizes = np.full(num_iterations, heatmap_width, dtype=int)

        for i in range(num_iterations):
            processed_rows = i * max_num_rows
            rows_to_load = min(max_num_rows + (args.patch_size - 1),
                               args.target_image_size[0] - processed_rows)
            if rows_to_load < args.patch_size:
                break

            # Extract the band for which we can compute patches.
            target_image_region = tf.image.crop_to_bounding_box(
                target_image, processed_rows, 0, rows_to_load,
                args.target_image_size[1])

            # With 'VALID' padding a band is expected to yield
            # (rows_to_load - patch_size + 1) * (width - patch_size + 1)
            # patches; 'SAME' padding would yield rows_to_load * width.
            all_image_patches = tf.unstack(
                normalize(
                    ctfi.extract_patches(target_image_region,
                                         args.patch_size,
                                         strides=[1, 1, 1, 1],
                                         padding='VALID')))

            # Chunk size is a divisor of the patch count (first one below the
            # budget), matching the fixed-size placeholder created below.
            for size in get_divisors(len(all_image_patches)):
                if size < max_chunk_size:
                    chunk_sizes[i] = size
                    break

            # Partition the patches into stacked chunks.
            chunked_patches = [
                tf.stack(chunk)
                for chunk in create_chunks(all_image_patches, chunk_sizes[i])
            ]
            all_chunks.append(chunked_patches)

            chunk_tensor = tf.placeholder(
                tf.float32,
                shape=[chunk_sizes[i], args.patch_size, args.patch_size, 3],
                name='chunk_tensor_placeholder')
            chunk_tensors.append(chunk_tensor)

            image_patches_log_sigma_sq, image_patches_mean = tf.contrib.graph_editor.graph_replace(
                [
                    sess.graph.get_tensor_by_name('imported/z_log_sigma_sq/BiasAdd:0'),
                    sess.graph.get_tensor_by_name('imported/z_mean/BiasAdd:0'),
                ], {sess.graph.get_tensor_by_name('imported/patch:0'): chunk_tensor})
            image_patches_distributions = tf.contrib.distributions.MultivariateNormalDiag(
                image_patches_mean[:, args.stain_code_size:],
                tf.sqrt(tf.exp(image_patches_log_sigma_sq[:, args.stain_code_size:])))

            if args.method == 'SKLD':
                similarities = (
                    patch_distribution.kl_divergence(image_patches_distributions) +
                    image_patches_distributions.kl_divergence(patch_distribution))
            elif args.method == 'BD':
                similarities = ctf.bhattacharyya_distance(
                    patch_distribution, image_patches_distributions)
            elif args.method == 'SQHD':
                similarities = ctf.multivariate_squared_hellinger_distance(
                    patch_distribution, image_patches_distributions)
            elif args.method == 'HD':
                similarities = tf.sqrt(
                    ctf.multivariate_squared_hellinger_distance(
                        patch_distribution, image_patches_distributions))
            else:
                # 'KLD' and any unrecognized method.
                similarities = patch_distribution.kl_divergence(
                    image_patches_distributions)
            all_similarities.append(similarities)

        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())
        saver.restore(sess, latest_checkpoint)

        # Evaluate each chunk: materialize its patches, then feed them into
        # the placeholder of the band the chunk belongs to.
        for band_similarities, band_placeholder, band_chunks in zip(
                all_similarities, chunk_tensors, all_chunks):
            for chunk in band_chunks:
                sim_vals.extend(
                    sess.run(band_similarities,
                             feed_dict={band_placeholder: sess.run(chunk)}))
        print(len(sim_vals))

        sim_heatmap = np.reshape(sim_vals, [heatmap_height, heatmap_width])
        heatmap_tensor = tf.expand_dims(
            tf.expand_dims(tf.convert_to_tensor(sim_heatmap), -1), 0)
        dy, dx = tf.image.image_gradients(heatmap_tensor)

        # Show the k_min target patches with the lowest values.
        k_min = 20
        min_indices = np.unravel_index(
            np.argsort(sim_vals)[:k_min], sim_heatmap.shape)
        _, ax_min = plt.subplots(4, 5)
        for i in range(k_min):
            row, col = min_indices[0][i], min_indices[1][i]
            target_patch = tf.image.crop_to_bounding_box(
                target_image, row, col, args.patch_size, args.patch_size)
            cell = ax_min[i // 5, i % 5]
            cell.imshow(sess.run(target_patch))
            cell.set_title('y:' + str(row) + ', x:' + str(col))

        fig, ax = plt.subplots(2, 3)
        cmap = 'plasma'

        denormalized_patch = denormalize(sess.run(patch)[0])
        # All supported methods are divergences/distances, so the best match
        # is the minimum of the heatmap.
        best_idx = np.unravel_index(np.argmin(sim_heatmap), sim_heatmap.shape)
        target_image_patch = tf.image.crop_to_bounding_box(
            target_image, best_idx[0], best_idx[1], args.patch_size,
            args.patch_size)
        print(best_idx)
        print(min_indices)

        ax[1, 0].imshow(sess.run(source_image))
        ax[1, 1].imshow(sess.run(target_image))
        ax[0, 0].imshow(denormalized_patch)
        heatmap_image = ax[0, 2].imshow(sim_heatmap, cmap=cmap)
        ax[0, 1].imshow(sess.run(target_image_patch))
        gradient_image = ax[1, 2].imshow(np.squeeze(sess.run(dx + dy)),
                                         cmap='bwr')

        fig.colorbar(heatmap_image, ax=ax[0, 2])
        fig.colorbar(gradient_image, ax=ax[1, 2])

        plt.show()
    print("Done!")
def main(argv):
    """Rank each target landmark within a window of candidate positions.

    For every landmark index ``k``, the source patch at source landmark ``k``
    is compared with patches on a ``region_size`` x ``region_size`` grid of
    positions inside a window extracted (``tf.image.extract_glimpse``) around
    target landmark ``k``. Four scores are computed per position:

    * SKLD: symmetric KL divergence of the encoded diagonal Gaussians,
    * BD:   Bhattacharyya distance of the encoded diagonal Gaussians,
    * SAD:  sum of absolute pixel differences,
    * MI:   ``nmi_tf`` on the grayscale patches (higher is better).

    One CSV row per (landmark, method) is written to
    ``<output>_<region_size>.csv``. Each row holds the best grid position
    (``min_idx``), its value, the 0-based ``rank`` of the grid centre in the
    best-first ordering, and the value at the grid centre. The grid centre
    presumably coincides with the target landmark; verify against
    ``get_patch_at``.

    Args:
        argv: Unused; ``argparse`` reads the command line from ``sys.argv``.

    Returns:
        0 on completion.
    """
    import math  # np.math was only an alias of math and is gone in NumPy 2.0.

    parser = argparse.ArgumentParser(
        description='Compute similarity heatmaps of windows around landmarks.')
    parser.add_argument('export_dir', type=str, help='Path to saved model.')
    parser.add_argument('mean',
                        type=str,
                        help='Path to npy file holding mean for normalization.')
    parser.add_argument(
        'variance',
        type=str,
        help='Path to npy file holding variance for normalization.')
    parser.add_argument('source_filename',
                        type=str,
                        help='Image file from which to extract patch.')
    parser.add_argument('source_image_size',
                        type=int,
                        nargs=2,
                        help='Size of the input image, HW.')
    parser.add_argument(
        'source_landmarks',
        type=str,
        help='CSV file from which to extract the landmarks for source image.')
    parser.add_argument('target_filename',
                        type=str,
                        help='Image file for which to create the heatmap.')
    parser.add_argument(
        'target_image_size',
        type=int,
        nargs=2,
        help='Size of the input image for which to create heatmap, HW.')
    parser.add_argument(
        'target_landmarks',
        type=str,
        help='CSV file from which to extract the landmarks for target image.')
    parser.add_argument('patch_size', type=int, help='Size of image patch.')
    parser.add_argument('output', type=str)
    parser.add_argument(
        '--method',
        dest='method',
        type=str,
        help=
        'Method to use to measure similarity, one of KLD, SKLD, BD, HD, SQHD.')
    parser.add_argument(
        '--stain_code_size',
        type=int,
        dest='stain_code_size',
        default=0,
        help=
        'Optional: Size of the stain code to use, which is skipped for similarity estimation'
    )
    parser.add_argument('--rotate',
                        type=float,
                        dest='angle',
                        default=0,
                        help='Optional: rotation angle to rotate target image')
    parser.add_argument('--subsampling_factor',
                        type=int,
                        dest='subsampling_factor',
                        default=1,
                        help='Factor to subsample source and target image.')
    parser.add_argument('--region_size', type=int, default=64)
    args = parser.parse_args()

    mean = np.load(args.mean)
    variance = np.load(args.variance)
    stddev = [math.sqrt(x) for x in variance]

    def normalize(image, name=None, num_channels=3):
        """Normalize the first ``num_channels`` channels of a 4-D tensor."""
        channels = [
            tf.expand_dims((image[:, :, :, channel] - mean[channel]) / stddev[channel], -1)
            for channel in range(num_channels)
        ]
        # Concatenate along the channel axis (axis 3 of the 4-D input); the
        # channel count is not an axis index.
        return tf.concat(channels, 3, name=name)

    latest_checkpoint = tf.train.latest_checkpoint(args.export_dir)
    saver = tf.train.import_meta_graph(latest_checkpoint + '.meta',
                                       import_scope='imported')

    config = tf.ConfigProto()
    config.allow_soft_placement = True

    # Load both images as batches of one, then their landmarks.
    source_image = tf.expand_dims(
        ctfi.subsample(
            ctfi.load(args.source_filename,
                      height=args.source_image_size[0],
                      width=args.source_image_size[1]),
            args.subsampling_factor), 0)
    args.source_image_size = [
        int(x / args.subsampling_factor) for x in args.source_image_size
    ]
    target_image = tf.expand_dims(
        ctfi.subsample(
            ctfi.load(args.target_filename,
                      height=args.target_image_size[0],
                      width=args.target_image_size[1]),
            args.subsampling_factor), 0)
    args.target_image_size = [
        int(x / args.subsampling_factor) for x in args.target_image_size
    ]
    source_landmarks = get_landmarks(args.source_landmarks,
                                     args.subsampling_factor)
    target_landmarks = get_landmarks(args.target_landmarks,
                                     args.subsampling_factor)

    region_size = args.region_size
    region_center = [int(region_size / 2), int(region_size / 2)]
    num_patches = region_size**2

    # Walk the divisors of num_patches (in the order cutil.get_divisors
    # returns them) until one batch holds at most 512 patches, or the
    # divisors run out.
    possible_splits = cutil.get_divisors(num_patches)
    num_splits = possible_splits.pop(0)
    while num_patches / num_splits > 512 and possible_splits:
        num_splits = possible_splits.pop(0)
    split_size = int(num_patches / num_splits)

    # Margin in pixels between the candidate grid and the window border.
    offset = 64
    # Candidate positions in window coordinates as (row, col) pairs,
    # enumerated row-major over the region_size x region_size grid.
    cols, rows = np.meshgrid(range(offset, region_size + offset),
                             range(offset, region_size + offset))
    coords = np.stack([rows.ravel(), cols.ravel()], axis=1)

    coords_placeholder = tf.placeholder(tf.float32, shape=[split_size, 2])
    source_landmark_placeholder = tf.placeholder(tf.float32, shape=[1, 2])
    target_landmark_placeholder = tf.placeholder(tf.float32, shape=[1, 2])

    window_size = [region_size + 2 * offset, region_size + 2 * offset]
    target_image_region = tf.image.extract_glimpse(target_image,
                                                   window_size,
                                                   target_landmark_placeholder,
                                                   normalized=False,
                                                   centered=False)

    source_patches_placeholder = tf.map_fn(
        lambda x: get_patch_at(x, source_image, args.patch_size),
        source_landmark_placeholder,
        parallel_iterations=8,
        back_prop=False)[0]
    target_patches_placeholder = tf.squeeze(
        tf.map_fn(lambda x: get_patch_at(x, target_image_region, args.patch_size),
                  coords_placeholder,
                  parallel_iterations=8,
                  back_prop=False))

    # The session context manager installs it as default and always closes it.
    with tf.Session(config=config) as sess:
        saver.restore(sess, latest_checkpoint)

        # Encoded distributions. The tensor name indicates log(sigma^2), and
        # MultivariateNormalDiag expects a standard deviation as scale_diag,
        # hence sqrt(exp(.)).
        source_log_sigma_sq, source_patches_mean = tf.contrib.graph_editor.graph_replace(
            [
                sess.graph.get_tensor_by_name('imported/z_log_sigma_sq/BiasAdd:0'),
                sess.graph.get_tensor_by_name('imported/z_mean/BiasAdd:0'),
            ], {
                sess.graph.get_tensor_by_name('imported/patch:0'):
                    normalize(source_patches_placeholder)
            })
        source_patches_distribution = tf.contrib.distributions.MultivariateNormalDiag(
            source_patches_mean[:, args.stain_code_size:],
            tf.sqrt(tf.exp(source_log_sigma_sq[:, args.stain_code_size:])))

        target_log_sigma_sq, target_patches_mean = tf.contrib.graph_editor.graph_replace(
            [
                sess.graph.get_tensor_by_name('imported/z_log_sigma_sq/BiasAdd:0'),
                sess.graph.get_tensor_by_name('imported/z_mean/BiasAdd:0'),
            ], {
                sess.graph.get_tensor_by_name('imported/patch:0'):
                    normalize(target_patches_placeholder)
            })
        target_patches_distribution = tf.contrib.distributions.MultivariateNormalDiag(
            target_patches_mean[:, args.stain_code_size:],
            tf.sqrt(tf.exp(target_log_sigma_sq[:, args.stain_code_size:])))

        similarities_skld = (
            source_patches_distribution.kl_divergence(target_patches_distribution) +
            target_patches_distribution.kl_divergence(source_patches_distribution))
        similarities_bd = ctf.bhattacharyya_distance(source_patches_distribution,
                                                     target_patches_distribution)
        similarities_sad = tf.reduce_sum(
            tf.abs(source_patches_placeholder - target_patches_placeholder),
            axis=[1, 2, 3])
        source_patches_grayscale = tf.image.rgb_to_grayscale(source_patches_placeholder)
        target_patches_grayscale = tf.image.rgb_to_grayscale(target_patches_placeholder)
        similarities_nmi = tf.map_fn(
            lambda x: nmi_tf(tf.squeeze(source_patches_grayscale), tf.squeeze(x), 20),
            target_patches_grayscale)

        methods = ["SKLD", "BD", "SAD", "MI"]
        # Flat index of the grid centre, used to find its rank.
        center_flat = region_center[0] * region_size + region_center[1]

        # newline='' is required by the csv module to avoid blank rows on
        # platforms with \r\n line endings.
        with open(args.output + "_" + str(region_size) + ".csv", 'wt',
                  newline='') as outfile:
            fp = csv.DictWriter(outfile, [
                "method", "landmark", "min_idx", "min_idx_value", "rank",
                "landmark_value"
            ])
            fp.writeheader()

            for k in range(len(source_landmarks)):
                heatmap_fused = np.empty((region_size, region_size, len(methods)))
                feed_dict = {
                    source_landmark_placeholder: [source_landmarks[k, :]],
                    target_landmark_placeholder: [target_landmarks[k, :]],
                }
                for i in range(num_splits):
                    batch_coords = coords[i * split_size:(i + 1) * split_size, :]
                    feed_dict[coords_placeholder] = batch_coords
                    # Shape (split_size, len(methods)) after the transpose.
                    similarity_values = np.array(
                        sess.run([
                            similarities_skld, similarities_bd,
                            similarities_sad, similarities_nmi
                        ],
                                 feed_dict=feed_dict)).transpose()
                    heatmap_fused[batch_coords[:, 0] - offset,
                                  batch_coords[:, 1] - offset] = similarity_values

                for c, method in enumerate(methods):
                    heatmap = heatmap_fused[:, :, c]
                    order = np.argsort(heatmap.flatten())
                    if method == "MI":
                        # Mutual information: larger means more similar.
                        order = order[::-1]
                        best_flat = np.argmax(heatmap)
                    else:
                        best_flat = np.argmin(heatmap)
                    # Plain ints so the CSV holds "(r, c)" on every NumPy version.
                    min_idx = tuple(
                        int(v) for v in np.unravel_index(best_flat, heatmap.shape))
                    rank = int(np.flatnonzero(order == center_flat)[0])
                    landmark_value = heatmap[region_center[0], region_center[1]]
                    fp.writerow({
                        "method": method,
                        "landmark": k,
                        "min_idx": min_idx,
                        "min_idx_value": heatmap[min_idx[0], min_idx[1]],
                        "rank": rank,
                        "landmark_value": landmark_value
                    })
                    outfile.flush()
                    print(min_idx, rank)
    return 0
def main(argv):
    """Compare source and target patches taken at corresponding landmarks.

    Patches of ``patch_size`` are taken with ``get_patch_at`` at every source
    landmark (from the source image) and every target landmark (from the
    target image). Each patch is encoded by the imported model into a
    diagonal Gaussian built from the ``z_mean`` and ``z_log_sigma_sq``
    tensors, skipping the first ``stain_code_size`` latent dimensions. The
    Bhattacharyya distance between source and target distributions is then
    computed; this is presumably pairwise by landmark index (verify
    ``ctf.bhattacharyya_distance`` batching).

    The distances are printed, and the images plus the patches with the
    smallest and largest distance are plotted. ``--method`` and ``--rotate``
    are parsed but not used by this function.

    Args:
        argv: Unused; ``argparse`` reads the command line from ``sys.argv``.

    Returns:
        0 on completion.
    """
    import math  # np.math was only an alias of math and is gone in NumPy 2.0.

    parser = argparse.ArgumentParser(
        description='Compute codes and reconstructions for image.')
    parser.add_argument('export_dir', type=str, help='Path to saved model.')
    parser.add_argument(
        'mean',
        type=str,
        help='Path to npy file holding mean for normalization.')
    parser.add_argument(
        'variance',
        type=str,
        help='Path to npy file holding variance for normalization.')
    parser.add_argument('source_filename',
                        type=str,
                        help='Image file from which to extract patch.')
    parser.add_argument('source_image_size',
                        type=int,
                        nargs=2,
                        help='Size of the input image, HW.')
    parser.add_argument(
        'source_landmarks',
        type=str,
        help='CSV file from which to extract the landmarks for source image.')
    parser.add_argument('target_filename',
                        type=str,
                        help='Image file for which to create the heatmap.')
    parser.add_argument(
        'target_image_size',
        type=int,
        nargs=2,
        help='Size of the input image for which to create heatmap, HW.')
    parser.add_argument(
        'target_landmarks',
        type=str,
        help='CSV file from which to extract the landmarks for target image.')
    parser.add_argument('patch_size', type=int, help='Size of image patch.')
    parser.add_argument(
        '--method',
        dest='method',
        type=str,
        help=
        'Method to use to measure similarity, one of KLD, SKLD, BD, HD, SQHD.')
    parser.add_argument(
        '--stain_code_size',
        type=int,
        dest='stain_code_size',
        default=0,
        help=
        'Optional: Size of the stain code to use, which is skipped for similarity estimation'
    )
    parser.add_argument('--rotate',
                        type=float,
                        dest='angle',
                        default=0,
                        help='Optional: rotation angle to rotate target image')
    parser.add_argument('--subsampling_factor',
                        type=int,
                        dest='subsampling_factor',
                        default=1,
                        help='Factor to subsample source and target image.')
    args = parser.parse_args()

    mean = np.load(args.mean)
    variance = np.load(args.variance)
    stddev = [math.sqrt(x) for x in variance]

    def normalize(image, name=None, num_channels=3):
        """Normalize the first ``num_channels`` channels of a 4-D tensor."""
        channels = [
            tf.expand_dims((image[:, :, :, channel] - mean[channel]) / stddev[channel], -1)
            for channel in range(num_channels)
        ]
        # Concatenate along the channel axis (axis 3 of the 4-D input); the
        # channel count is not an axis index.
        return tf.concat(channels, 3, name=name)

    latest_checkpoint = tf.train.latest_checkpoint(args.export_dir)
    saver = tf.train.import_meta_graph(latest_checkpoint + '.meta',
                                       import_scope='imported')

    config = tf.ConfigProto()
    config.allow_soft_placement = True

    # Load both images as batches of one.
    source_image = tf.expand_dims(
        ctfi.subsample(
            ctfi.load(args.source_filename,
                      height=args.source_image_size[0],
                      width=args.source_image_size[1]),
            args.subsampling_factor), 0)
    args.source_image_size = [
        int(x / args.subsampling_factor) for x in args.source_image_size
    ]
    target_image = tf.expand_dims(
        ctfi.subsample(
            ctfi.load(args.target_filename,
                      height=args.target_image_size[0],
                      width=args.target_image_size[1]),
            args.subsampling_factor), 0)
    args.target_image_size = [
        int(x / args.subsampling_factor) for x in args.target_image_size
    ]

    # One patch per landmark.
    source_landmarks = get_landmarks(args.source_landmarks,
                                     args.subsampling_factor)
    source_patches = tf.squeeze(
        tf.map_fn(lambda x: get_patch_at(x, source_image, args.patch_size),
                  source_landmarks))
    target_landmarks = get_landmarks(args.target_landmarks,
                                     args.subsampling_factor)
    target_patches = tf.squeeze(
        tf.map_fn(lambda x: get_patch_at(x, target_image, args.patch_size),
                  target_landmarks))

    # The session context manager installs it as default and always closes it.
    with tf.Session(config=config) as sess:
        saver.restore(sess, latest_checkpoint)

        # The tensor name indicates log(sigma^2), and MultivariateNormalDiag
        # expects a standard deviation as scale_diag, hence sqrt(exp(.)).
        source_log_sigma_sq, source_patches_mean = tf.contrib.graph_editor.graph_replace(
            [
                sess.graph.get_tensor_by_name('imported/z_log_sigma_sq/BiasAdd:0'),
                sess.graph.get_tensor_by_name('imported/z_mean/BiasAdd:0'),
            ], {
                sess.graph.get_tensor_by_name('imported/patch:0'):
                    normalize(source_patches)
            })
        source_patches_distribution = tf.contrib.distributions.MultivariateNormalDiag(
            source_patches_mean[:, args.stain_code_size:],
            tf.sqrt(tf.exp(source_log_sigma_sq[:, args.stain_code_size:])))

        target_log_sigma_sq, target_patches_mean = tf.contrib.graph_editor.graph_replace(
            [
                sess.graph.get_tensor_by_name('imported/z_log_sigma_sq/BiasAdd:0'),
                sess.graph.get_tensor_by_name('imported/z_mean/BiasAdd:0'),
            ], {
                sess.graph.get_tensor_by_name('imported/patch:0'):
                    normalize(target_patches)
            })
        target_patches_distribution = tf.contrib.distributions.MultivariateNormalDiag(
            target_patches_mean[:, args.stain_code_size:],
            tf.sqrt(tf.exp(target_log_sigma_sq[:, args.stain_code_size:])))

        similarities = ctf.bhattacharyya_distance(source_patches_distribution,
                                                  target_patches_distribution)

        sim_vals = sess.run(similarities)
        min_idx = np.argmin(sim_vals)
        max_idx = np.argmax(sim_vals)
        print(sim_vals)
        print(min_idx, sim_vals[min_idx])
        print(max_idx, sim_vals[max_idx])

        # Evaluate images and patches once instead of once per subplot.
        source_image_val, source_patch_vals, target_image_val, target_patch_vals = sess.run(
            [source_image[0], source_patches, target_image[0], target_patches])

        _, ax = plt.subplots(2, 3)
        ax[0, 0].imshow(source_image_val)
        ax[0, 1].imshow(source_patch_vals[min_idx])
        ax[0, 2].imshow(source_patch_vals[max_idx])
        ax[1, 0].imshow(target_image_val)
        ax[1, 1].imshow(target_patch_vals[min_idx])
        ax[1, 2].imshow(target_patch_vals[max_idx])
        plt.show()
    return 0