def main(argv):
    """Show an image together with its finite-difference gradients.

    The file name is read from the command line via ``argparse`` (``argv`` is
    accepted for the caller's convenience but not consulted). The process
    exits with status -1 when ``ctfi.is_image`` rejects the file.

    Panel layout (2x2):
        [0, 0] the image            [0, 1] dx + dy
        [1, 0] dx                   [1, 1] dy
    """
    parser = argparse.ArgumentParser(
        description='Display image readable with tensorflow.')
    parser.add_argument('filename', type=str, help='Image file to display.')
    args = parser.parse_args()

    if not ctfi.is_image(args.filename):
        sys.exit(-1)

    # image_gradients expects a batch dimension.
    image = tf.expand_dims(ctfi.load(args.filename, channels=3), 0)

    # tf.image.image_gradients returns the pair in (dy, dx) order; unpacking
    # it the other way round mislabels the two gradient panels.
    dy, dx = tf.image.image_gradients(image)

    fig, ax = plt.subplots(2, 2)
    ax[0, 0].imshow(image[0].numpy())
    # Convert the *sum* to numpy, not only the second operand.
    ax[0, 1].imshow((dx[0] + dy[0]).numpy())
    ax[1, 0].imshow(dx[0].numpy())
    ax[1, 1].imshow(dy[0].numpy())
    plt.show()
def main(argv):
    """Load the sample tile, split it into patches and display the first one.

    The tile is read from ``<git_root>/data/images/tile_8_14.jpeg`` at
    1024x1024 with three channels; ``ctfi.extract_patches`` is called with 64
    as its second argument. ``argv`` is not used.
    """
    tile_path = os.path.join(git_root, 'data', 'images', 'tile_8_14.jpeg')
    tile = ctfi.load(tile_path, width=1024, height=1024, channels=3)

    all_patches = ctfi.extract_patches(tile, 64)
    first_patch = all_patches[0, :, :, :]

    plt.subplots()
    plt.imshow(first_patch.numpy())
    plt.show()
def main(argv):
    """Display the sample tile, or random noise when the tile is not an image.

    The tile is read from ``<git_root>/data/images/tile_8_14.jpeg``. When
    ``ctfi.is_image`` rejects that path, a random 1024x1024x3 array is shown
    instead. ``argv`` is not used.
    """
    filename = os.path.join(git_root, 'data', 'images', 'tile_8_14.jpeg')

    if ctfi.is_image(filename):
        # Using eager execution: the loaded tensor exposes .numpy().
        image = ctfi.load(filename, width=1024, height=1024,
                          channels=3).numpy()
    else:
        # Already a numpy array; calling .numpy() on it would raise
        # AttributeError, so the conversion happens only in the branch above.
        image = np.random.rand(1024, 1024, 3)

    fig, ax = plt.subplots()
    plt.imshow(image)
    plt.show()
def main(argv):
    """Run a saved model on one image and print (optionally store) its latent code.

    Command line: ``export_dir filename [--output PATH]``. The input may be an
    image file (per ``ctfi.is_image``) or a numpy file (per
    ``cutil.is_numpy_format``); anything else terminates the process with
    exit status 3. ``argv`` is not consulted; ``argparse`` reads ``sys.argv``.

    NOTE(review): the input is brought to model size with ``np.resize``, which
    repeats or truncates the flattened data rather than interpolating the
    image. Confirm that this is intended.
    """
    parser = argparse.ArgumentParser(
        description='Compute latent code for image patch by model inference.')
    parser.add_argument('export_dir', type=str,
                        help='Path to saved model to use for inference.')
    parser.add_argument('filename', type=str,
                        help='Image file or numpy array to run inference on.')
    parser.add_argument('--output', type=str,
                        help='Where to store the output.')
    args = parser.parse_args()

    model_dir = args.export_dir
    input_path = args.filename

    predict_fn = predictor.from_saved_model(model_dir)

    # Patch size and latent space size are derived from the model identifier.
    patch_size = ctfsm.determine_patch_size(model_dir)
    latent_space_size = ctfsm.determine_latent_space_size(model_dir)

    # Accept either an image file or a stored numpy array.
    if ctfi.is_image(input_path):
        image = ctfi.load(input_path).numpy()
    elif cutil.is_numpy_format(input_path):
        image = np.load(input_path)
    else:
        sys.exit(3)

    # Shape the data as a batch of one patch of the size the model expects.
    image = np.resize(image, [patch_size, patch_size, 3])
    batch = np.expand_dims(image, 0)

    # The model also requires 'moving' and 'embedding' inputs; random
    # placeholders are supplied since only the fixed code is read back.
    moving = np.random.rand(1, patch_size, patch_size, 3)
    embedding = np.random.rand(1, 1, 1, latent_space_size)
    pred = predict_fn({
        'fixed': batch,
        'moving': moving,
        'embedding': embedding
    })

    latent_code = pred['latent_code_fixed']
    print(latent_code)

    if not args.output:
        return

    record = {
        'filename': input_path,
        'model': model_dir,
        'latent_code': latent_code.tolist()
    }
    with open(args.output, 'w') as f:
        json.dump(record, f)
def main(argv):
    """Display a single image file given on the command line.

    Exits with status -1 when ``ctfi.is_image`` reports False for the file.
    ``argv`` is not consulted; ``argparse`` reads ``sys.argv`` itself.
    """
    cli = argparse.ArgumentParser(
        description='Display image readable with tensorflow.')
    cli.add_argument('filename', type=str, help='Image file to display.')
    options = cli.parse_args()

    is_valid = ctfi.is_image(options.filename)
    if is_valid == False:
        sys.exit(-1)

    loaded = ctfi.load(options.filename, channels=3)

    plt.subplots()
    plt.imshow(loaded.numpy())
    plt.show()
def main(argv):
    """Plot one image-valued feature of one sample from a TFRecord dataset.

    Command line: ``dataset key size position``. The feature named ``key`` is
    decoded as a float32 tensor of shape ``[size, size, 3]``, the sample at
    index ``position`` is selected, and it is rescaled to [0, 1] for display.
    ``argv`` is not consulted; ``argparse`` reads ``sys.argv`` itself.
    """
    parser = argparse.ArgumentParser(
        description='Display image from dataset')
    parser.add_argument('dataset', type=str, help='Image file to display.')
    parser.add_argument(
        'key', type=str,
        help='Key of feature that contains image to be displayed.')
    parser.add_argument('size', type=int,
                        help='Size of samples in dataset.')
    parser.add_argument('position', type=int,
                        help='Position of sample to plot in dataset.')
    args = parser.parse_args()

    feature_spec = {
        'shape': [args.size, args.size, 3],
        'key': args.key,
        'dtype': tf.float32
    }
    decode_op = ctfd.construct_decode_op([feature_spec])

    records = tf.data.TFRecordDataset(args.dataset)
    decoded = records.map(decode_op, num_parallel_calls=8)

    # Skip to the requested sample and pull exactly that one element.
    selected = decoded.skip(args.position).take(1)
    sample = tf.data.experimental.get_single_element(selected)
    image = sample[args.key]

    plt.imshow(ctfi.rescale(image.numpy(), 0.0, 1.0))
    plt.show()
def main(argv):
    """Load the sample tile, subsample it by a factor of 2 and display it.

    Relies on eager execution: the subsampled tensor is converted with
    ``.numpy()`` rather than evaluated inside a session. ``argv`` is not used.
    """
    tile_path = os.path.join(git_root, 'data', 'images', 'tile_8_14.jpeg')
    tile = ctfi.load(tile_path, width=1024, height=1024, channels=3)
    reduced = ctfi.subsample(tile, 2)

    plt.subplots()
    plt.imshow(reduced.numpy())
    plt.show()
def denormalize(image):
    """Undo per-channel normalization and rescale the result to [0, 1].

    Each of the first three channels (third axis of ``image``) is multiplied
    by ``stddev[channel]`` and shifted by ``mean[channel]``; ``stddev`` and
    ``mean`` are taken from the enclosing scope. The channels are then joined
    along axis 2 and passed through ``ctfi.rescale(..., 0.0, 1.0)``.
    """
    restored = []
    for c in range(3):
        plane = image[:, :, c] * stddev[c] + mean[c]
        # Append a trailing axis so the planes can be concatenated.
        restored.append(plane[..., np.newaxis])
    merged = np.concatenate(restored, 2)
    return ctfi.rescale(merged, 0.0, 1.0)
def main(argv):
    """Encode an image patch-wise with a checkpointed model and save the results.

    Command line: ``export_dir mean variance filename H W patch_size stride
    codes_out reconstructions_out``. The image is split into patches, the
    patches are normalized with the stored mean/variance and substituted for
    the imported graph's ``patch`` tensor. The resulting ``code`` and
    ``logits`` tensors are evaluated; the stitched reconstruction is shown
    and both arrays are written with ``np.save``.
    ``argv`` is not consulted; ``argparse`` reads ``sys.argv`` itself.
    """
    parser = argparse.ArgumentParser(
        description='Compute codes and reconstructions for image.')
    parser.add_argument('export_dir', type=str, help='Path to saved model.')
    parser.add_argument(
        'mean', type=str,
        help='Path to npy file holding mean for normalization.')
    parser.add_argument(
        'variance', type=str,
        help='Path to npy file holding variance for normalization.')
    parser.add_argument(
        'filename', type=str,
        help='Image file or numpy array to run inference on.')
    parser.add_argument('image_size', type=int, nargs=2,
                        help='Size of the image, HW.')
    parser.add_argument('patch_size', type=int,
                        help='Size of image patches.')
    parser.add_argument('stride', type=int, help='Size of stride.')
    parser.add_argument('codes_out', type=str,
                        help='Where to store the numpy array of codes.')
    parser.add_argument(
        'reconstructions_out', type=str,
        help='Where to store the numpy array of reconstructions.')
    args = parser.parse_args()

    mean = np.load(args.mean)
    variance = np.load(args.variance)
    stddev = [np.math.sqrt(v) for v in variance]

    def denormalize(image):
        """Map a normalized HxWx3 array back to [0, 1] image values."""
        planes = []
        for c in range(3):
            plane = image[:, :, c] * stddev[c] + mean[c]
            planes.append(np.expand_dims(plane, -1))
        return ctfi.rescale(np.concatenate(planes, 2), 0.0, 1.0)

    def normalize(image, name=None):
        """Standardize each of the three channels of a 4-D tensor."""
        planes = []
        for c in range(3):
            plane = (image[:, :, :, c] - mean[c]) / stddev[c]
            planes.append(tf.expand_dims(plane, -1))
        return tf.concat(planes, 3, name=name)

    latest_checkpoint = tf.train.latest_checkpoint(args.export_dir)
    saver = tf.train.import_meta_graph(latest_checkpoint + '.meta',
                                       import_scope='imported')

    patch_strides = [1, args.stride, args.stride, 1]

    with tf.Session(graph=tf.get_default_graph()).as_default() as sess:
        graph = sess.graph

        image = ctfi.load(args.filename,
                          height=args.image_size[0],
                          width=args.image_size[1])
        raw_patches = ctfi.extract_patches(image, args.patch_size,
                                           strides=patch_strides)
        patches = normalize(raw_patches)

        # Rewire the imported graph: patches -> code -> logits.
        codes = tf.contrib.graph_editor.graph_replace(
            graph.get_tensor_by_name('imported/code:0'),
            {graph.get_tensor_by_name('imported/patch:0'): patches})
        reconstructions = tf.contrib.graph_editor.graph_replace(
            graph.get_tensor_by_name('imported/logits:0'),
            {graph.get_tensor_by_name('imported/code:0'): codes})

        saver.restore(sess, latest_checkpoint)

        codes_npy = sess.run(codes)
        reconstruction_values = sess.run(reconstructions)
        reconstructions_npy = np.array(
            list(map(denormalize, reconstruction_values)))

        stitched = ctfi.stitch_patches(reconstructions, patch_strides,
                                       args.image_size)
        plt.imshow(denormalize(sess.run(stitched)))
        plt.show()

        np.save(args.codes_out, codes_npy)
        np.save(args.reconstructions_out, reconstructions_npy)
        print("Done!")
def main(argv):
    """Plot latent-space traversals of a checkpointed model for one image.

    Command line: ``export_dir filename``. The image is loaded at 32x32,
    normalized with the module-level ``normalize`` and encoded. Each of the
    18 code dimensions is then shifted by 21 offsets evenly spaced over
    [-11, 11] while the other dimensions stay fixed. The decoded images are
    shown in an 18x21 grid (one row per dimension).
    ``argv`` is not consulted; ``argparse`` reads ``sys.argv`` itself.
    """
    parser = argparse.ArgumentParser(
        description='Plot latent space traversals for model.')
    parser.add_argument('export_dir', type=str, help='Path to saved model.')
    parser.add_argument(
        'filename', type=str,
        help='Image file or numpy array to run inference on.')
    args = parser.parse_args()

    latest_checkpoint = tf.train.latest_checkpoint(args.export_dir)
    saver = tf.train.import_meta_graph(latest_checkpoint + '.meta',
                                       import_scope='imported')

    loaded = ctfi.load(args.filename, width=32, height=32)
    image = normalize(tf.expand_dims(loaded, 0))

    plots = 21
    fig_traversal, ax_traversal = plt.subplots(18, plots)

    with tf.Session(graph=tf.get_default_graph()).as_default() as sess:
        graph = sess.graph

        embedding = tf.contrib.graph_editor.graph_replace(
            graph.get_tensor_by_name('imported/code:0'),
            {graph.get_tensor_by_name('imported/patch:0'): image})

        # Column vector of offsets; padding it to width 18 places the offsets
        # in column i and zeros everywhere else.
        offsets = tf.expand_dims(tf.lin_space(-11.0, 11.0, plots), -1)
        per_dimension = []
        for i in range(0, 18):
            per_dimension.append(tf.pad(offsets, [[0, 0], [i, 17 - i]]))
        shifts = tf.concat(per_dimension, 0)

        codes = tf.tile(embedding, [plots * 18, 1]) + shifts
        sess.run(shifts)

        reconstructions = tf.contrib.graph_editor.graph_replace(
            graph.get_tensor_by_name('imported/logits:0'),
            {graph.get_tensor_by_name('imported/code:0'): codes})

        saver.restore(sess, latest_checkpoint)
        images = list(map(denormalize, sess.run(reconstructions)))

        for i in range(18 * plots):
            row, col = divmod(i, plots)
            ax_traversal[row, col].imshow(images[i])
        plt.show()
def main(argv):
    """Apply the configured rotation layer to the encoder input image and show it.

    The image ``<git_root>/data/images/encoder_input.png`` is loaded at 32x32
    and a random value from ``np.random.rand(1, 1)`` is used as the angle
    tensor. The layer is built by ``ctfm.parse_component`` from the
    module-level ``rotation_layer_conf``. ``argv`` is not used.
    """
    with tf.Session(graph=tf.get_default_graph()).as_default() as sess:
        image_path = os.path.join(git_root, 'data', 'images',
                                  'encoder_input.png')
        loaded = ctfi.load(image_path, width=32, height=32, channels=3)
        image = tf.expand_dims(loaded, 0, name='image_tensor')

        random_angle = np.random.rand(1, 1)
        angle = tf.convert_to_tensor(random_angle, dtype=tf.float32,
                                     name='angle_tensor')

        tensors = {'image_tensor': image, 'angle': angle}
        component = ctfm.parse_component(tensors, rotation_layer_conf,
                                         tensors)

        # Element 2 of the parsed component is the callable applied here.
        rotate = component[2]
        rotated_image = rotate(angle)

        rotated_values = sess.run(rotated_image)
        plt.imshow(rotated_values[0])
        plt.show()
def _split_patches(features):
    """Split a sample's image into patches and repeat its label per patch.

    Reads ``features['image']`` and ``features['label']``; the patch size
    comes from ``args.patch_size`` in the enclosing scope. Returns a
    ``(patches, labels)`` tuple where ``labels`` has one row per patch.
    """
    image_patches = ctfi.extract_patches(features['image'], args.patch_size)

    # Shape the single label as [1, 1], then tile it once per patch.
    single_label = tf.reshape(features['label'], [1])
    label_row = tf.expand_dims(single_label, 0)
    patch_count = tf.shape(image_patches)[0]
    repeated_labels = tf.tile(label_row, tf.stack([patch_count, 1]))

    return (image_patches, repeated_labels)
def main(argv):
    """Rank landmark positions by patch similarity and write the results to CSV.

    For every landmark pair, the patch at the source landmark is compared
    with each patch of a ``region_size`` x ``region_size`` window around the
    target landmark using four measures:

    * SKLD - symmetric KL divergence between the encoded distributions
    * BD   - Bhattacharyya distance (``ctf.bhattacharyya_distance``)
    * SAD  - sum of absolute pixel differences
    * MI   - value returned by ``nmi_tf`` on the grayscale patches

    Per landmark and measure one CSV row is written to
    ``<output>_<region_size>.csv``. It holds the best position, its value,
    the value at the region center and the rank of the region center in the
    sorted heatmap. MI is ranked descending, the others ascending.

    ``argv`` is not consulted; ``argparse`` reads ``sys.argv`` itself.

    Returns:
        0 on completion.
    """
    parser = argparse.ArgumentParser(
        description='Compute similarity heatmaps of windows around landmarks.')
    parser.add_argument('export_dir', type=str, help='Path to saved model.')
    parser.add_argument(
        'mean', type=str,
        help='Path to npy file holding mean for normalization.')
    parser.add_argument(
        'variance', type=str,
        help='Path to npy file holding variance for normalization.')
    parser.add_argument('source_filename', type=str,
                        help='Image file from which to extract patch.')
    parser.add_argument('source_image_size', type=int, nargs=2,
                        help='Size of the input image, HW.')
    parser.add_argument(
        'source_landmarks', type=str,
        help='CSV file from which to extract the landmarks for source image.')
    parser.add_argument('target_filename', type=str,
                        help='Image file for which to create the heatmap.')
    parser.add_argument(
        'target_image_size', type=int, nargs=2,
        help='Size of the input image for which to create heatmap, HW.')
    parser.add_argument(
        'target_landmarks', type=str,
        help='CSV file from which to extract the landmarks for target image.')
    parser.add_argument('patch_size', type=int, help='Size of image patch.')
    parser.add_argument('output', type=str)
    parser.add_argument(
        '--method', dest='method', type=str,
        help='Method to use to measure similarity, one of KLD, SKLD, BD, HD, SQHD.')
    parser.add_argument(
        '--stain_code_size', type=int, dest='stain_code_size', default=0,
        help='Optional: Size of the stain code to use, which is skipped for similarity estimation')
    parser.add_argument(
        '--rotate', type=float, dest='angle', default=0,
        help='Optional: rotation angle to rotate target image')
    parser.add_argument(
        '--subsampling_factor', type=int, dest='subsampling_factor',
        default=1, help='Factor to subsample source and target image.')
    parser.add_argument('--region_size', type=int, default=64)
    args = parser.parse_args()

    mean = np.load(args.mean)
    variance = np.load(args.variance)
    stddev = [np.math.sqrt(x) for x in variance]

    def normalize(image, name=None, num_channels=3):
        """Standardize the first ``num_channels`` channels of a 4-D tensor."""
        channels = [
            tf.expand_dims(
                (image[:, :, :, channel] - mean[channel]) / stddev[channel],
                -1) for channel in range(num_channels)
        ]
        # Always join along the channel axis (3). Using num_channels as the
        # axis is only correct by coincidence when num_channels == 3.
        return tf.concat(channels, 3, name=name)

    latest_checkpoint = tf.train.latest_checkpoint(args.export_dir)
    saver = tf.train.import_meta_graph(latest_checkpoint + '.meta',
                                       import_scope='imported')

    config = tf.ConfigProto()
    config.allow_soft_placement = True

    # Load the source image (from which the reference patches are taken).
    source_image = tf.expand_dims(
        ctfi.subsample(
            ctfi.load(args.source_filename,
                      height=args.source_image_size[0],
                      width=args.source_image_size[1]),
            args.subsampling_factor), 0)
    args.source_image_size = list(
        map(lambda x: int(x / args.subsampling_factor),
            args.source_image_size))

    # Load the target image (in which the search regions are evaluated).
    target_image = tf.expand_dims(
        ctfi.subsample(
            ctfi.load(args.target_filename,
                      height=args.target_image_size[0],
                      width=args.target_image_size[1]),
            args.subsampling_factor), 0)
    args.target_image_size = list(
        map(lambda x: int(x / args.subsampling_factor),
            args.target_image_size))

    source_landmarks = get_landmarks(args.source_landmarks,
                                     args.subsampling_factor)
    target_landmarks = get_landmarks(args.target_landmarks,
                                     args.subsampling_factor)

    region_size = args.region_size
    region_center = [int(region_size / 2), int(region_size / 2)]
    num_patches = region_size**2

    # Evaluate the region in equally sized batches: walk through the divisors
    # of the patch count until a batch holds at most 512 patches.
    possible_splits = cutil.get_divisors(num_patches)
    num_splits = possible_splits.pop(0)
    while num_patches / num_splits > 512 and len(possible_splits) > 0:
        num_splits = possible_splits.pop(0)
    split_size = int(num_patches / num_splits)

    # The glimpse is padded by `offset` pixels on every side, so coordinates
    # of the region inside the glimpse start at `offset`.
    offset = 64
    X, Y = np.meshgrid(range(offset, region_size + offset),
                       range(offset, region_size + offset))
    coords = np.concatenate([
        np.expand_dims(Y.flatten(), axis=1),
        np.expand_dims(X.flatten(), axis=1)
    ], axis=1)

    coords_placeholder = tf.placeholder(tf.float32, shape=[split_size, 2])
    source_landmark_placeholder = tf.placeholder(tf.float32, shape=[1, 2])
    target_landmark_placeholder = tf.placeholder(tf.float32, shape=[1, 2])

    target_image_region = tf.image.extract_glimpse(
        target_image,
        [region_size + 2 * offset, region_size + 2 * offset],
        target_landmark_placeholder,
        normalized=False,
        centered=False)

    source_patches_placeholder = tf.map_fn(
        lambda x: get_patch_at(x, source_image, args.patch_size),
        source_landmark_placeholder,
        parallel_iterations=8,
        back_prop=False)[0]
    target_patches_placeholder = tf.squeeze(
        tf.map_fn(
            lambda x: get_patch_at(x, target_image_region, args.patch_size),
            coords_placeholder,
            parallel_iterations=8,
            back_prop=False))

    with tf.Session(config=config).as_default() as sess:
        saver.restore(sess, latest_checkpoint)

        source_patches_cov, source_patches_mean = tf.contrib.graph_editor.graph_replace(
            [
                sess.graph.get_tensor_by_name(
                    'imported/z_log_sigma_sq/BiasAdd:0'),
                sess.graph.get_tensor_by_name('imported/z_mean/BiasAdd:0')
            ], {
                sess.graph.get_tensor_by_name('imported/patch:0'):
                normalize(source_patches_placeholder)
            })
        source_patches_distribution = tf.contrib.distributions.MultivariateNormalDiag(
            source_patches_mean[:, args.stain_code_size:],
            tf.exp(source_patches_cov[:, args.stain_code_size:]))

        target_patches_cov, target_patches_mean = tf.contrib.graph_editor.graph_replace(
            [
                sess.graph.get_tensor_by_name(
                    'imported/z_log_sigma_sq/BiasAdd:0'),
                sess.graph.get_tensor_by_name('imported/z_mean/BiasAdd:0')
            ], {
                sess.graph.get_tensor_by_name('imported/patch:0'):
                normalize(target_patches_placeholder)
            })
        target_patches_distribution = tf.contrib.distributions.MultivariateNormalDiag(
            target_patches_mean[:, args.stain_code_size:],
            tf.exp(target_patches_cov[:, args.stain_code_size:]))

        similarities_skld = (
            source_patches_distribution.kl_divergence(
                target_patches_distribution) +
            target_patches_distribution.kl_divergence(
                source_patches_distribution))
        similarities_bd = ctf.bhattacharyya_distance(
            source_patches_distribution, target_patches_distribution)
        similarities_sad = tf.reduce_sum(
            tf.abs(source_patches_placeholder - target_patches_placeholder),
            axis=[1, 2, 3])

        source_patches_grayscale = tf.image.rgb_to_grayscale(
            source_patches_placeholder)
        target_patches_grayscale = tf.image.rgb_to_grayscale(
            target_patches_placeholder)
        similarities_nmi = tf.map_fn(
            lambda x: nmi_tf(tf.squeeze(source_patches_grayscale),
                             tf.squeeze(x), 20),
            target_patches_grayscale)

        # newline='' lets the csv module control line endings itself, as the
        # csv documentation requires for files handed to a writer.
        with open(args.output + "_" + str(region_size) + ".csv", 'wt',
                  newline='') as outfile:
            fp = csv.DictWriter(outfile, [
                "method", "landmark", "min_idx", "min_idx_value", "rank",
                "landmark_value"
            ])
            methods = ["SKLD", "BD", "SAD", "MI"]
            fp.writeheader()

            for k in range(len(source_landmarks)):
                heatmap_fused = np.ndarray(
                    (region_size, region_size, len(methods)))
                feed_dict = {
                    source_landmark_placeholder: [source_landmarks[k, :]],
                    target_landmark_placeholder: [target_landmarks[k, :]]
                }

                # Fill the heatmap batch by batch.
                for i in range(num_splits):
                    start = i * split_size
                    end = start + split_size
                    batch_coords = coords[start:end, :]
                    feed_dict.update({coords_placeholder: batch_coords})
                    # One row per patch, one column per method.
                    similarity_values = np.array(
                        sess.run([
                            similarities_skld, similarities_bd,
                            similarities_sad, similarities_nmi
                        ], feed_dict=feed_dict)).transpose()
                    for idx, val in zip(batch_coords, similarity_values):
                        heatmap_fused[idx[0] - offset,
                                      idx[1] - offset] = val

                for c in range(len(methods)):
                    heatmap = heatmap_fused[:, :, c]
                    if c == 3:
                        # MI: larger is better, so rank in descending order.
                        min_idx = np.unravel_index(np.argmax(heatmap),
                                                   heatmap.shape)
                        min_indices = np.array(
                            np.unravel_index(
                                list(reversed(np.argsort(
                                    heatmap.flatten()))),
                                heatmap.shape)).transpose().tolist()
                    else:
                        min_idx = np.unravel_index(np.argmin(heatmap),
                                                   heatmap.shape)
                        min_indices = np.array(
                            np.unravel_index(
                                np.argsort(heatmap.flatten()),
                                heatmap.shape)).transpose().tolist()

                    landmark_value = heatmap[region_center[0],
                                             region_center[1]]
                    # Position of the region center in the sorted order.
                    rank = min_indices.index(region_center)
                    fp.writerow({
                        "method": methods[c],
                        "landmark": k,
                        "min_idx": min_idx,
                        "min_idx_value": heatmap[min_idx[0], min_idx[1]],
                        "rank": rank,
                        "landmark_value": landmark_value
                    })
                    outfile.flush()
                    print(min_idx, rank)

        sess.close()

    return 0
def main(argv):
    """Compute and plot a similarity heatmap of one source patch over a target image.

    A patch is cropped from the source image at ``offsets`` and encoded into
    a descriptor (mean and log-variance outputs of the imported model, with
    the first ``stain_code_size`` entries dropped). The target image is cut
    into all ``patch_size`` x ``patch_size`` patches (stride 1, 'VALID'), in
    row bands small enough for the module-level buffer limits
    ``max_patch_buffer_size`` and ``max_buffer_size_in_byte``. Every target
    patch is encoded the same way and compared with the source descriptor
    through ``dist_kl``. The resulting values are reshaped into a heatmap and
    plotted together with the 20 lowest-valued patches.

    ``argv`` is not consulted; ``argparse`` reads ``sys.argv`` itself.

    Returns:
        0 on completion.
    """
    parser = argparse.ArgumentParser(
        description='Compute codes and reconstructions for image.')
    parser.add_argument('export_dir', type=str, help='Path to saved model.')
    parser.add_argument(
        'mean', type=str,
        help='Path to npy file holding mean for normalization.')
    parser.add_argument(
        'variance', type=str,
        help='Path to npy file holding variance for normalization.')
    parser.add_argument('source_filename', type=str,
                        help='Image file from which to extract patch.')
    parser.add_argument('source_image_size', type=int, nargs=2,
                        help='Size of the input image, HW.')
    parser.add_argument('offsets', type=int, nargs=2,
                        help='Position where to extract the patch.')
    parser.add_argument('patch_size', type=int, help='Size of image patch.')
    parser.add_argument('target_filename', type=str,
                        help='Image file for which to create the heatmap.')
    parser.add_argument(
        'target_image_size', type=int, nargs=2,
        help='Size of the input image for which to create heatmap, HW.')
    parser.add_argument(
        'method', type=str,
        help='Method to use to measure similarity, one of KLD, SKLD, BD, HD, SQHD.')
    parser.add_argument(
        '--stain_code_size', type=int, dest='stain_code_size', default=0,
        help='Optional: Size of the stain code to use, which is skipped for similarity estimation')
    parser.add_argument(
        '--rotate', type=float, dest='angle', default=0,
        help='Optional: rotation angle to rotate target image')
    parser.add_argument(
        '--subsampling_factor', type=int, dest='subsampling_factor',
        default=1, help='Factor to subsample source and target image.')
    args = parser.parse_args()

    mean = np.load(args.mean)
    variance = np.load(args.variance)
    stddev = [np.math.sqrt(x) for x in variance]

    def denormalize(image):
        """Map a normalized HxWx3 array back to [0, 1] image values."""
        channels = [
            np.expand_dims(
                image[:, :, channel] * stddev[channel] + mean[channel], -1)
            for channel in range(3)
        ]
        denormalized_image = ctfi.rescale(np.concatenate(channels, 2), 0.0,
                                          1.0)
        return denormalized_image

    def normalize(image, name=None):
        """Standardize each of the three channels of a 4-D tensor."""
        channels = [
            tf.expand_dims(
                (image[:, :, :, channel] - mean[channel]) / stddev[channel],
                -1) for channel in range(3)
        ]
        return tf.concat(channels, 3, name=name)

    latest_checkpoint = tf.train.latest_checkpoint(args.export_dir)
    saver = tf.train.import_meta_graph(latest_checkpoint + '.meta',
                                       import_scope='imported')

    with tf.Session(graph=tf.get_default_graph()).as_default() as sess:
        # Load the source image and extract the reference patch from it.
        source_image = ctfi.subsample(
            ctfi.load(args.source_filename,
                      height=args.source_image_size[0],
                      width=args.source_image_size[1]),
            args.subsampling_factor)
        args.source_image_size = list(
            map(lambda x: int(x / args.subsampling_factor),
                args.source_image_size))

        patch = normalize(
            tf.expand_dims(
                tf.image.crop_to_bounding_box(source_image, args.offsets[0],
                                              args.offsets[1],
                                              args.patch_size,
                                              args.patch_size), 0))

        patch_cov, patch_mean = tf.contrib.graph_editor.graph_replace([
            sess.graph.get_tensor_by_name(
                'imported/z_log_sigma_sq/BiasAdd:0'),
            sess.graph.get_tensor_by_name('imported/z_mean/BiasAdd:0')
        ], {sess.graph.get_tensor_by_name('imported/patch:0'): patch})

        # Descriptor = [mean, log-variance output], stain part removed.
        patch_descriptor = tf.concat([
            patch_mean[:, args.stain_code_size:],
            tf.layers.flatten(patch_cov[:, args.stain_code_size:])
        ], -1)
        sim_vals = []

        structure_code_size = patch_mean.get_shape().as_list(
        )[1] - args.stain_code_size

        # Load the image for which to create the heatmap.
        target_image = ctfi.subsample(
            ctfi.load(args.target_filename,
                      height=args.target_image_size[0],
                      width=args.target_image_size[1]),
            args.subsampling_factor)
        args.target_image_size = list(
            map(lambda x: int(x / args.subsampling_factor),
                args.target_image_size))
        target_image = tf.contrib.image.rotate(target_image,
                                               np.radians(args.angle))

        heatmap_height = args.target_image_size[0] - (args.patch_size - 1)
        heatmap_width = args.target_image_size[1] - (args.patch_size - 1)

        # Compute byte size as: width*height*channels*sizeof(float32)
        patch_size_in_byte = args.patch_size**2 * 3 * 4
        max_patches = int(max_patch_buffer_size / patch_size_in_byte)
        # NOTE(review): this becomes 0 when the buffer holds less than one
        # heatmap row, which makes the division below fail.
        max_num_rows = int(max_patches / heatmap_width)
        max_chunk_size = int(max_buffer_size_in_byte / patch_size_in_byte)

        # Iterate over the image row bands that fit into the buffer.
        num_iterations = int(args.target_image_size[0] / max_num_rows) + 1

        all_chunks = list()
        all_similarities = list()
        chunk_tensors = list()

        # Builtin int replaces np.int, which was only an alias for it and has
        # been removed from NumPy.
        chunk_sizes = np.zeros(num_iterations, dtype=int)
        chunk_sizes.fill(heatmap_width)

        for i in range(num_iterations):
            processed_rows = i * max_num_rows
            rows_to_load = min(max_num_rows + (args.patch_size - 1),
                               args.target_image_size[0] - processed_rows)

            # A band thinner than one patch yields no patches.
            if rows_to_load < args.patch_size:
                break

            # Extract the region for which we can compute patches.
            target_image_region = tf.image.crop_to_bounding_box(
                target_image, processed_rows, 0, rows_to_load,
                args.target_image_size[1])

            # Size = (image_width - patch_size - 1) * (image_height - patch_size - 1)
            # for 'VALID' padding and image_width * image_height for 'SAME'
            # padding.
            all_image_patches = tf.unstack(
                normalize(
                    ctfi.extract_patches(target_image_region,
                                         args.patch_size,
                                         strides=[1, 1, 1, 1],
                                         padding='VALID')))

            # Pick the first divisor of the patch count below the limit so
            # that all chunks of this band have the same size.
            possible_chunk_sizes = get_divisors(len(all_image_patches))
            for size in possible_chunk_sizes:
                if size < max_chunk_size:
                    chunk_sizes[i] = size
                    break

            # Partition patches into chunks.
            chunked_patches = list(
                create_chunks(all_image_patches, chunk_sizes[i]))
            chunked_patches = list(map(tf.stack, chunked_patches))
            all_chunks.append(chunked_patches)

            chunk_tensor = tf.placeholder(
                tf.float32,
                shape=[chunk_sizes[i], args.patch_size, args.patch_size, 3],
                name='chunk_tensor_placeholder')
            chunk_tensors.append(chunk_tensor)

            image_patches_cov, image_patches_mean = tf.contrib.graph_editor.graph_replace(
                [
                    sess.graph.get_tensor_by_name(
                        'imported/z_log_sigma_sq/BiasAdd:0'),
                    sess.graph.get_tensor_by_name(
                        'imported/z_mean/BiasAdd:0')
                ], {
                    sess.graph.get_tensor_by_name('imported/patch:0'):
                    chunk_tensor
                })
            image_patches_descriptors = tf.concat([
                image_patches_mean[:, args.stain_code_size:],
                tf.layers.flatten(
                    image_patches_cov[:, args.stain_code_size:])
            ], -1)

            distances = dist_kl(patch_descriptor, image_patches_descriptors,
                                structure_code_size)
            similarities = tf.squeeze(distances)
            all_similarities.append(similarities)

        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())
        saver.restore(sess, latest_checkpoint)

        # Evaluate each chunk and feed its values into the matching
        # placeholder of its band.
        for i in range(len(all_chunks)):
            for chunk in all_chunks[i]:
                sim_vals.extend(
                    sess.run(all_similarities[i],
                             feed_dict={chunk_tensors[i]: sess.run(chunk)}))

        print(len(sim_vals))
        sim_heatmap = np.reshape(sim_vals, [heatmap_height, heatmap_width])
        heatmap_tensor = tf.expand_dims(
            tf.expand_dims(tf.convert_to_tensor(sim_heatmap), -1), 0)
        dy, dx = tf.image.image_gradients(heatmap_tensor)

        # Show the k patches with the lowest values in a 4x5 grid.
        k_min = 20
        min_indices = np.unravel_index(
            np.argsort(sim_vals)[:k_min], sim_heatmap.shape)
        _, ax_min = plt.subplots(4, 5)
        for i in range(k_min):
            target_patch = tf.image.crop_to_bounding_box(
                target_image, int(min_indices[0][i]),
                int(min_indices[1][i]), args.patch_size, args.patch_size)
            row, col = divmod(i, 5)
            ax_min[row, col].imshow(sess.run(target_patch))
            ax_min[row, col].set_title('y:' + str(min_indices[0][i]) +
                                       ', x:' + str(min_indices[1][i]))

        fig, ax = plt.subplots(2, 3)
        cmap = 'plasma'

        denormalized_patch = denormalize(sess.run(patch)[0])
        # Position of the smallest heatmap value.
        best_idx = np.unravel_index(np.argmin(sim_heatmap),
                                    sim_heatmap.shape)
        target_image_patch = tf.image.crop_to_bounding_box(
            target_image, best_idx[0], best_idx[1], args.patch_size,
            args.patch_size)
        print(best_idx)
        print(min_indices)

        ax[1, 0].imshow(sess.run(source_image))
        ax[1, 1].imshow(sess.run(target_image))
        ax[0, 0].imshow(denormalized_patch)
        heatmap_image = ax[0, 2].imshow(sim_heatmap, cmap=cmap)
        ax[0, 1].imshow(sess.run(target_image_patch))
        gradient_image = ax[1, 2].imshow(np.squeeze(sess.run(dx + dy)),
                                         cmap='bwr')

        fig.colorbar(heatmap_image, ax=ax[0, 2])
        fig.colorbar(gradient_image, ax=ax[1, 2])

        plt.show()
        sess.close()

    print("Done!")
    return 0
def main(argv):
    """Plot interpolations between two images in pixel space and in latent space.

    Command line: ``export_dir mean variance source_image target_image
    image_size stain_code_size``. Nine blend factors evenly spaced over
    [0, 1] are used. The 4x9 figure shows, per row:

    0. pixel-wise blend of the two normalized inputs
    1. decoded codes with blended stain part and the source structure part
    2. decoded codes with the target stain part and blended structure part
    3. decoded codes with the full code blended

    The first ``stain_code_size`` code entries form the stain part, the rest
    the structure part. ``argv`` is not consulted; ``argparse`` reads
    ``sys.argv`` itself.
    """
    parser = argparse.ArgumentParser(
        description='Plot latent space traversals for model.')
    parser.add_argument('export_dir', type=str, help='Path to saved model.')
    parser.add_argument(
        'mean', type=str,
        help='Path to npy file holding mean for normalization.')
    parser.add_argument(
        'variance', type=str,
        help='Path to npy file holding variance for normalization.')
    parser.add_argument(
        'source_image', type=str,
        help='Source image file or numpy array to run inference on.')
    parser.add_argument(
        'target_image', type=str,
        help='Target image file or numpy array to run inference on.')
    parser.add_argument(
        'image_size', type=int,
        help='Size of the images, has to be expected input size of model.')
    parser.add_argument('stain_code_size', type=int,
                        help='Size of the stain code.')
    args = parser.parse_args()

    mean = np.load(args.mean)
    variance = np.load(args.variance)
    stddev = [np.math.sqrt(v) for v in variance]

    def denormalize(image):
        """Map a normalized HxWx3 array back to [0, 1] image values."""
        planes = []
        for c in range(3):
            plane = image[:, :, c] * stddev[c] + mean[c]
            planes.append(np.expand_dims(plane, -1))
        return ctfi.rescale(np.concatenate(planes, 2), 0.0, 1.0)

    def normalize(image, name=None):
        """Standardize each of the three channels of a 4-D tensor."""
        planes = []
        for c in range(3):
            plane = (image[:, :, :, c] - mean[c]) / stddev[c]
            planes.append(tf.expand_dims(plane, -1))
        return tf.concat(planes, 3, name=name)

    latest_checkpoint = tf.train.latest_checkpoint(args.export_dir)
    saver = tf.train.import_meta_graph(latest_checkpoint + '.meta',
                                       import_scope='imported')

    def load_normalized(path):
        """Load an image at model size and return it as a normalized batch of one."""
        raw = ctfi.load(path, width=args.image_size, height=args.image_size)
        return normalize(tf.expand_dims(raw, 0))

    source_image = load_normalized(args.source_image)
    target_image = load_normalized(args.target_image)

    num_plots = 9
    fig, ax = plt.subplots(4, num_plots)
    weights = np.linspace(0.0, 1.0, num=num_plots)

    def blend(a, b, factor):
        """Linear interpolation from ``a`` (factor 0) to ``b`` (factor 1)."""
        return (1.0 - factor) * a + factor * b

    with tf.Session(graph=tf.get_default_graph()).as_default() as sess:
        graph = sess.graph

        def encode(patch):
            """Code tensor of the imported graph with ``patch`` as input."""
            return tf.contrib.graph_editor.graph_replace(
                graph.get_tensor_by_name('imported/code:0'),
                {graph.get_tensor_by_name('imported/patch:0'): patch})

        def decode(codes):
            """Logits tensor of the imported graph with ``codes`` as code."""
            return tf.contrib.graph_editor.graph_replace(
                graph.get_tensor_by_name('imported/logits:0'),
                {graph.get_tensor_by_name('imported/code:0'): codes})

        embedding_source = encode(source_image)
        embedding_target = encode(target_image)

        split = args.stain_code_size
        source_stain = embedding_source[:, :split]
        source_structure = embedding_source[:, split:]
        target_stain = embedding_target[:, :split]
        target_structure = embedding_target[:, split:]

        stain_rows = []
        for factor in weights:
            stain_rows.append(
                tf.concat([
                    blend(source_stain, target_stain, factor),
                    source_structure
                ], -1))
        codes_stain = tf.concat(stain_rows, 0)

        structure_rows = []
        for factor in weights:
            structure_rows.append(
                tf.concat([
                    target_stain,
                    blend(source_structure, target_structure, factor)
                ], -1))
        codes_structure = tf.concat(structure_rows, 0)

        codes_full = tf.concat([
            blend(embedding_source, embedding_target, factor)
            for factor in weights
        ], 0)

        reconstructions_stain = decode(codes_stain)
        reconstructions_structure = decode(codes_structure)
        reconstructions_full = decode(codes_full)

        saver.restore(sess, latest_checkpoint)

        reconstruction_images_full = [
            denormalize(img) for img in sess.run(reconstructions_full)
        ]
        reconstruction_images_stain = [
            denormalize(img) for img in sess.run(reconstructions_stain)
        ]
        reconstruction_images_structure = [
            denormalize(img) for img in sess.run(reconstructions_structure)
        ]

        pixel_blend = tf.concat([
            blend(source_image, target_image, factor) for factor in weights
        ], 0)
        interpolated_images = [
            denormalize(img) for img in sess.run(pixel_blend)
        ]

        rows = (interpolated_images, reconstruction_images_stain,
                reconstruction_images_structure, reconstruction_images_full)
        for i in range(num_plots):
            ax[0, i].imshow(rows[0][i])
            ax[1, i].imshow(rows[1][i])
            ax[2, i].imshow(rows[2][i])
            ax[3, i].imshow(rows[3][i])
        plt.show()
def _subsampling_op(features):
    """Replace ``features['patch']`` by its subsampled version (factor 2).

    The mapping is modified in place and also returned.
    """
    subsampled_patch = ctfi.subsample(features['patch'], 2)
    features['patch'] = subsampled_patch
    return features
def main(argv):
    """Show an image next to its patch-wise reconstruction by a saved model.

    The image is cut into non-overlapping ``patch_size`` patches (the stride
    equals the patch size), every patch is normalised with the supplied
    per-channel mean / variance, pushed through the imported graph
    (``imported/patch:0`` -> ``imported/logits:0``), stitched back together
    and displayed beside the input image.

    Command-line arguments are read from ``sys.argv`` by ``argparse``.

    Args:
        argv: Not read by this function.
    """
    # ``np.math`` was only an alias of the standard ``math`` module and has
    # been removed from recent NumPy releases; use the module directly.
    import math

    parser = argparse.ArgumentParser(
        description='Plot image and its reconstruction.')
    parser.add_argument('export_dir', type=str, help='Path to saved model.')
    parser.add_argument(
        'mean',
        type=str,
        help='Path to npy file holding mean for normalization.')
    parser.add_argument(
        'variance',
        type=str,
        help='Path to npy file holding variance for normalization.')
    parser.add_argument('filename',
                        type=str,
                        help='Image file or numpy array to run inference on.')
    parser.add_argument('image_size',
                        type=int,
                        nargs=2,
                        help='Size of the image, HW.')
    parser.add_argument('patch_size', type=int, help='Size of image patches.')
    args = parser.parse_args()

    mean = np.load(args.mean)
    variance = np.load(args.variance)
    stddev = [math.sqrt(x) for x in variance]

    def denormalize(image):
        """Undo the per-channel normalisation of an HWC array and rescale it."""
        channels = [
            np.expand_dims(
                image[:, :, channel] * stddev[channel] + mean[channel], -1)
            for channel in range(3)
        ]
        denormalized_image = np.concatenate(channels, 2)
        return ctfi.rescale(denormalized_image, 0.0, 1.0)

    def normalize(image, name=None):
        """Normalise the first three channels of a rank-4 tensor."""
        channels = [
            tf.expand_dims(
                (image[:, :, :, channel] - mean[channel]) / stddev[channel],
                -1) for channel in range(3)
        ]
        return tf.concat(channels, 3, name=name)

    latest_checkpoint = tf.train.latest_checkpoint(args.export_dir)
    if latest_checkpoint is None:
        # latest_checkpoint() returns None when nothing is found; fail with a
        # readable message instead of a TypeError on ``None + '.meta'``.
        parser.error('No checkpoint found in "{}".'.format(args.export_dir))
    saver = tf.train.import_meta_graph(latest_checkpoint + '.meta',
                                       import_scope='imported')

    # Entering the session object itself (rather than ``as_default()``) makes
    # it the default session *and* closes it when the block is left, even if
    # graph construction, evaluation or plotting raises.
    with tf.Session(graph=tf.get_default_graph()) as sess:
        graph = sess.graph

        image = ctfi.load(args.filename,
                          height=args.image_size[0],
                          width=args.image_size[1])
        strides = [1, args.patch_size, args.patch_size, 1]
        patches = normalize(
            ctfi.extract_patches(image, args.patch_size, strides=strides))

        # Re-create the logits computation with our patches substituted for
        # the model's original input tensor.
        reconstructions = tf.contrib.graph_editor.graph_replace(
            graph.get_tensor_by_name('imported/logits:0'),
            {graph.get_tensor_by_name('imported/patch:0'): patches})
        reconstructed_image = tf.squeeze(
            ctfi.stitch_patches(reconstructions, strides, args.image_size))

        sess.run(tf.global_variables_initializer())
        saver.restore(sess, latest_checkpoint)

        # One run evaluates both tensors, so the image is loaded only once.
        image_eval, reconstructed_image_eval = sess.run(
            [image, reconstructed_image])

        _, ax = plt.subplots(1, 2)
        ax[0].imshow(image_eval)
        ax[1].imshow(denormalize(reconstructed_image_eval))
        plt.show()

    print("Done!")
def main(argv):
    """Interactively register two fixed sample images with a sparse warp.

    Two 512x512 images are loaded from ``git_root/data/images``. The second
    one (held in ``image_rotated``) is warped with
    ``tf.contrib.image.sparse_image_warp`` using an 8x8 grid of control
    points. Only the *source* control point locations are handed to the
    optimiser; they are updated by plain gradient descent on the summed
    symmetric KL divergence between the latent Gaussians that the imported
    model produces for patches of ``image`` and for patches of the warped
    image. A 2x3 matplotlib figure is refreshed every 100 steps.

    Command-line arguments are read from ``sys.argv`` by ``argparse``. The
    function finishes by calling ``sys.exit(0)``.

    Args:
        argv: Not read by this function.
    """
    # NOTE(review): the description string below looks copied from another
    # script; this function performs registration, not latent-code export.
    parser = argparse.ArgumentParser(
        description='Compute latent code for image patch by model inference.')
    parser.add_argument('export_dir',
                        type=str,
                        help='Path to saved model to use for inference.')
    args = parser.parse_args()

    # Fixed (reference) image, with a leading batch axis added.
    filename = os.path.join(git_root, 'data', 'images',
                            'HE_level_1_cropped_512x512.png')
    image = tf.expand_dims(
        ctfi.load(filename, width=512, height=512, channels=3), 0)

    # Moving image. Despite the name it is loaded from a different file, not
    # produced by rotating ``image``.
    target_filename = os.path.join(git_root, 'data', 'images',
                                   'CD3_level_1_cropped_512x512.png')
    image_rotated = tf.Variable(
        tf.expand_dims(
            ctfi.load(target_filename, width=512, height=512, channels=3), 0))

    # Float step counter; incremented through ``global_step=step`` below.
    step = tf.Variable(tf.zeros([], dtype=tf.float32))

    # The complex step ``8j`` asks np.mgrid for 8 evenly spaced samples per
    # axis (end point included), i.e. 64 control points in total.
    X, Y = np.mgrid[0:512:8j, 0:512:8j]
    positions = np.transpose(np.vstack([X.ravel(), Y.ravel()]))
    positions = tf.expand_dims(
        tf.convert_to_tensor(positions, dtype=tf.float32), 0)

    # Both point sets start identical.
    source_control_point_locations = tf.Variable(positions)
    dest_control_point_locations = tf.Variable(positions)

    # NOTE(review): this variable is rebound by the very next statement, so it
    # only adds an otherwise unused variable to the graph.
    warped_image = tf.Variable(image_rotated)
    warped_image, flow = tf.contrib.image.sparse_image_warp(
        image_rotated,
        source_control_point_locations,
        dest_control_point_locations,
        name='sparse_image_warp',
        interpolation_order=1,
        regularization_weight=0.005,
        #num_boundary_points=1
    )

    # ``normalize`` is not defined inside this function; it has to come from
    # the enclosing module scope.
    # Patches of size 32 taken with stride 16.
    image_patches = normalize(
        ctfi.extract_patches(image[0], 32, strides=[1, 16, 16, 1]))
    warped_patches = normalize(
        ctfi.extract_patches(warped_image[0], 32, strides=[1, 16, 16, 1]))

    learning_rate = 0.05

    # NOTE(review): latest_checkpoint() yields None when no checkpoint exists,
    # in which case the concatenation below raises TypeError.
    latest_checkpoint = tf.train.latest_checkpoint(args.export_dir)
    saver = tf.train.import_meta_graph(latest_checkpoint + '.meta',
                                       import_scope='imported')

    # ``as_default()`` installs the session as default but does not close it
    # on exit; the process ends via sys.exit(0) at the bottom instead.
    with tf.Session(graph=tf.get_default_graph()).as_default() as sess:
        # Copy the encoder sub-graph twice: once fed with the reference
        # patches, once fed with the patches of the warped image.
        target_cov, target_mean = tf.contrib.graph_editor.graph_replace(
            [
                sess.graph.get_tensor_by_name(
                    'imported/z_covariance_lower_tri/MatrixBandPart:0'),
                sess.graph.get_tensor_by_name('imported/z_mean/BiasAdd:0')
            ],
            {sess.graph.get_tensor_by_name('imported/patch:0'): image_patches})
        moving_cov, moving_mean = tf.contrib.graph_editor.graph_replace([
            sess.graph.get_tensor_by_name(
                'imported/z_covariance_lower_tri/MatrixBandPart:0'),
            sess.graph.get_tensor_by_name('imported/z_mean/BiasAdd:0')
        ], {sess.graph.get_tensor_by_name('imported/patch:0'): warped_patches})

        # The first 6 latent dimensions are dropped from both distributions
        # (presumably the stain part of the code -- TODO confirm).
        N_target = tf.contrib.distributions.MultivariateNormalTriL(
            loc=target_mean[:, 6:], scale_tril=target_cov[:, 6:, 6:])
        N_mov = tf.contrib.distributions.MultivariateNormalTriL(
            loc=moving_mean[:, 6:], scale_tril=moving_cov[:, 6:, 6:])

        #h_squared = ctf.multivariate_squared_hellinger_distance(N_target, N_mov)
        #hellinger = tf.sqrt(h_squared)

        # Symmetric KL divergence, summed over all patches.
        loss = tf.reduce_sum(
            N_target.kl_divergence(N_mov) + N_mov.kl_divergence(N_target))

        # The SciPy optimiser is constructed, but its ``minimize`` call in the
        # loop below is commented out, so it never drives the optimisation.
        scipy_options = {'maxiter': 10000, 'disp': True, 'iprint': 10}
        scipy_optimizer = tf.contrib.opt.ScipyOptimizerInterface(
            loss,
            var_list=[source_control_point_locations],
            method='SLSQP',
            options=scipy_options)

        # Gradient descent on the source control points only.
        optimizer = tf.train.GradientDescentOptimizer(
            learning_rate=learning_rate)
        compute_gradients_source = optimizer.compute_gradients(
            loss, var_list=[source_control_point_locations])
        apply_gradients_source = optimizer.apply_gradients(
            compute_gradients_source, global_step=step)

        # Initialise everything, then overwrite the imported model's
        # variables with the checkpoint values.
        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())
        saver.restore(sess, latest_checkpoint)

        # ---- initial figure: 2 rows x 3 columns -------------------------
        fig, ax = plt.subplots(2, 3)
        ax[0, 0].imshow(ctfi.rescale(image.eval(session=sess)[0], 0.0, 1.0))
        ax[0, 0].set_title('image')
        ax[0, 0].set_autoscale_on(False)
        #ax[0,0].plot([200],[200],'s',marker='x', ms=10, color='red')

        ax[0, 1].imshow(
            ctfi.rescale(image_rotated.eval(session=sess)[0], 0.0, 1.0))
        ax[0, 1].set_title('rotated')
        ax[0, 1].set_autoscale_on(False)

        plot_warped = ax[0, 2].imshow(
            ctfi.rescale(warped_image.eval(session=sess)[0], 0.0, 1.0))
        ax[0, 2].set_title('warped')
        ax[0, 2].set_autoscale_on(False)

        plot_diff_image = ax[1, 0].imshow(
            ctfi.rescale(
                tf.abs(image - warped_image).eval(session=sess)[0], 0., 1.))
        ax[1, 0].set_title('diff_image')
        ax[1, 0].set_autoscale_on(False)

        plot_diff_rotated = ax[1, 1].imshow(
            ctfi.rescale(
                tf.abs(image_rotated - warped_image).eval(session=sess)[0],
                0., 1.))
        ax[1, 1].set_title('diff_rotated')
        ax[1, 1].set_autoscale_on(False)

        # Placeholder image for the flow panel; replaced inside the loop.
        plot_flow = ax[1, 2].imshow(
            np.zeros_like(image[0, :, :, :].eval(session=sess)))
        #flow_mesh_x, flow_mesh_y = np.meshgrid(np.arange(0, 1024 * 10, 10), np.arange(0, 1024 * 10, 10))
        #plot_flow = ax[1,2].quiver(
        #    flow_mesh_x, # X
        #    flow_mesh_y, # Y
        #    np.zeros_like(flow_mesh_x),
        #    np.zeros_like(flow_mesh_y),
        #    units='xy',angles='xy', scale_units='xy', scale=10)
        ax[1, 2].set_title('flow')
        ax[1, 2].set_autoscale_on(False)

        dest_points = dest_control_point_locations.eval(session=sess)[0]
        source_points = source_control_point_locations.eval(session=sess)[0]

        # NOTE(review): the format string 's' and marker='x' both specify a
        # marker; confirm which one the installed matplotlib honours.
        plot_scatter_source, = ax[0, 1].plot(source_points[:, 0],
                                             source_points[:, 1],
                                             's',
                                             marker='x',
                                             ms=5,
                                             color='orange')
        plot_scatter_dest, = ax[0, 2].plot(dest_points[:, 0],
                                           dest_points[:, 1],
                                           's',
                                           marker='x',
                                           ms=5,
                                           color='green')

        # Zero-length arrows as initial gradient display.
        plot_source_grad = ax[0, 1].quiver(
            source_points[:, 0],  # X
            source_points[:, 1],  # Y
            np.zeros_like(source_points[:, 0]),
            np.zeros_like(source_points[:, 0]),
            units='xy',
            angles='xy',
            scale_units='xy',
            scale=1)

        # Created once and never updated: the code that would refresh it in
        # the loop is commented out.
        plot_dest_grad = ax[0, 2].quiver(
            dest_points[:, 0],  # X
            dest_points[:, 1],  # Y
            np.zeros_like(dest_points[:, 0]),
            np.zeros_like(dest_points[:, 0]),
            units='xy',
            angles='xy',
            scale_units='xy',
            scale=1)

        # Interactive mode so plt.show() does not block inside the loop.
        plt.ion()
        fig.canvas.draw()
        fig.canvas.flush_events()
        plt.show()

        #gradients = (tf.zeros_like(source_control_point_locations),tf.zeros_like(source_control_point_locations))
        iterations = 100000
        # ---- optimisation loop ------------------------------------------
        while step.value().eval(session=sess) < iterations:
            step_val = int(step.value().eval(session=sess))
            #scipy_optimizer.minimize(sess)
            # ``gradients`` is a list with one (gradient, variable value)
            # pair, fetched before the update is applied.
            gradients = sess.run(compute_gradients_source)
            # Separate run: applies the update and increments ``step``.
            sess.run(apply_gradients_source)

            # Report and redraw every 100 steps and on the last step.
            if step_val % 100 == 0 or step_val == iterations - 1:
                loss_val = loss.eval(session=sess)
                grad_mean_source = np.mean(gradients[0][0])
                grad_mean_dest = 0.0  # np.mean(gradients[1][0])
                flow_field = flow.eval(session=sess)
                # Show the two flow components as the first two colour
                # channels of an image, with a zero third channel.
                x, y = np.split(flow_field, 2, axis=3)
                flow_image = ctfi.rescale(
                    np.squeeze(np.concatenate([x, y, np.zeros_like(x)], 3)),
                    0.0, 1.0)

                diff_warp_rotated = tf.abs(image_rotated -
                                           warped_image).eval(session=sess)
                diff_image_warp = tf.abs(image -
                                         warped_image).eval(session=sess)

                # Columns: step, loss, mean source gradient, mean dest
                # gradient (constant 0), mean flow, sum |rotated - warped|,
                # sum |image - warped|.
                print(
                    "{0:d}\t{1:.4f}\t{2:.4f}\t{3:.4f}\t{4:.4f}\t{5:.4f}\t{6:.4f}"
                    .format(step_val, loss_val, grad_mean_source,
                            grad_mean_dest, np.mean(flow_field),
                            np.sum(diff_warp_rotated),
                            np.sum(diff_image_warp)))

                plot_warped.set_data(
                    ctfi.rescale(warped_image.eval(session=sess)[0], 0., 1.))
                plot_diff_image.set_data(
                    ctfi.rescale(diff_image_warp[0], 0., 1.))
                plot_diff_rotated.set_data(
                    ctfi.rescale(diff_warp_rotated[0], 0., 1.))
                plot_flow.set_data(flow_image)
                #plot_flow.set_UVC(x,y, flow_field)

                dest_points = dest_control_point_locations.eval(
                    session=sess)[0]
                # Variable value returned alongside the gradient, i.e. the
                # source points as they were when the gradient was computed.
                source_points = np.squeeze(gradients[0][1])

                plot_scatter_source.set_data(source_points[:, 0],
                                             source_points[:, 1])
                plot_scatter_dest.set_data(dest_points[:, 0],
                                           dest_points[:, 1])

                source_gradients = np.squeeze(gradients[0][0])
                #dest_gradients = np.squeeze(gradients_dest[0][0])

                # The quiver artist is recreated instead of updated in place.
                # NOTE(review): the colour argument ``source_gradients`` has
                # two values per arrow; confirm the installed matplotlib
                # accepts that.
                plot_source_grad.remove()
                plot_source_grad = ax[0, 1].quiver(
                    source_points[:, 0],  # X
                    source_points[:, 1],  # Y
                    source_gradients[:, 0],
                    source_gradients[:, 1],
                    source_gradients,
                    units='xy',
                    angles='xy',
                    scale_units='xy',
                    scale=1)

                #grid_plot = plot_grid(ax[0,1],source_points[:,0],source_points[:,1])

                #plot_dest_grad.remove()
                #plot_dest_grad = ax[0,2].quiver(
                #    dest_points[:,0], # X
                #    dest_points[:,1], # Y
                #    dest_gradients[:,0],
                #    dest_gradients[:,1],
                #    dest_gradients,
                #    units='xy',angles='xy', scale_units='xy', scale=1)

                # https://stackoverflow.com/questions/48911643/set-uvc-equivilent-for-a-3d-quiver-plot-in-matplotlib
                # new_segs = [ [ [x,y,z], [u,v,w] ] for x,y,z,u,v,w in zip(*segs.tolist()) ]
                # quivers.set_segments(new_segs)

                #plot_source_grad.set_UVC(
                #    source_gradients[:,0],
                #    source_gradients[:,1],
                #    source_gradients)
                #plot_dest_grad.set_UVC(
                #    dest_gradients[:,0],
                #    dest_gradients[:,1],
                #    dest_gradients)

                fig.canvas.draw()
                fig.canvas.flush_events()
                plt.show()

        print("Done!")
        # Back to blocking mode so the final figure stays open until closed.
        plt.ioff()
        plt.show()
        sys.exit(0)
def main(argv):
    """Compare latent distributions of landmark patches from two images.

    Patches are extracted around the landmarks of a source and a target
    image, normalised, and encoded by the imported model into diagonal
    Gaussians (``imported/z_mean/BiasAdd:0`` and
    ``imported/z_log_sigma_sq/BiasAdd:0``, with the first
    ``--stain_code_size`` dimensions skipped). The Bhattacharyya distance
    between corresponding source / target distributions is printed, and the
    patch pairs with the smallest and the largest distance are plotted next
    to the two images.

    Command-line arguments are read from ``sys.argv`` by ``argparse``.

    Args:
        argv: Not read by this function.

    Returns:
        int: Always ``0``.
    """
    # ``np.math`` was only an alias of the standard ``math`` module and has
    # been removed from recent NumPy releases; use the module directly.
    import math

    parser = argparse.ArgumentParser(
        description='Compute codes and reconstructions for image.')
    parser.add_argument('export_dir', type=str, help='Path to saved model.')
    parser.add_argument(
        'mean',
        type=str,
        help='Path to npy file holding mean for normalization.')
    parser.add_argument(
        'variance',
        type=str,
        help='Path to npy file holding variance for normalization.')
    parser.add_argument('source_filename',
                        type=str,
                        help='Image file from which to extract patch.')
    parser.add_argument('source_image_size',
                        type=int,
                        nargs=2,
                        help='Size of the input image, HW.')
    parser.add_argument(
        'source_landmarks',
        type=str,
        help='CSV file from which to extract the landmarks for source image.')
    parser.add_argument('target_filename',
                        type=str,
                        help='Image file for which to create the heatmap.')
    parser.add_argument(
        'target_image_size',
        type=int,
        nargs=2,
        help='Size of the input image for which to create heatmap, HW.')
    parser.add_argument(
        'target_landmarks',
        type=str,
        help='CSV file from which to extract the landmarks for target image.')
    parser.add_argument('patch_size', type=int, help='Size of image patch.')
    # NOTE: --method and --rotate are accepted so existing command lines keep
    # working, but neither value is read below; the Bhattacharyya distance is
    # always used and the target image is never rotated.
    parser.add_argument(
        '--method',
        dest='method',
        type=str,
        help=
        'Method to use to measure similarity, one of KLD, SKLD, BD, HD, SQHD.')
    parser.add_argument(
        '--stain_code_size',
        type=int,
        dest='stain_code_size',
        default=0,
        help=
        'Optional: Size of the stain code to use, which is skipped for similarity estimation'
    )
    parser.add_argument('--rotate',
                        type=float,
                        dest='angle',
                        default=0,
                        help='Optional: rotation angle to rotate target image')
    parser.add_argument('--subsampling_factor',
                        type=int,
                        dest='subsampling_factor',
                        default=1,
                        help='Factor to subsample source and target image.')
    args = parser.parse_args()

    mean = np.load(args.mean)
    variance = np.load(args.variance)
    stddev = [math.sqrt(x) for x in variance]

    def normalize(image, name=None, num_channels=3):
        """Normalise the first ``num_channels`` channels of a rank-4 tensor."""
        channels = [
            tf.expand_dims(
                (image[:, :, :, channel] - mean[channel]) / stddev[channel],
                -1) for channel in range(num_channels)
        ]
        # Join along the last (channel) axis. The axis used to be
        # ``num_channels`` itself, which is only correct when that is 3.
        return tf.concat(channels, -1, name=name)

    latest_checkpoint = tf.train.latest_checkpoint(args.export_dir)
    if latest_checkpoint is None:
        # latest_checkpoint() returns None when nothing is found; fail with a
        # readable message instead of a TypeError on ``None + '.meta'``.
        parser.error('No checkpoint found in "{}".'.format(args.export_dir))
    saver = tf.train.import_meta_graph(latest_checkpoint + '.meta',
                                       import_scope='imported')

    config = tf.ConfigProto()
    config.allow_soft_placement = True
    #config.log_device_placement=True

    # Load the source image, subsample it and add a batch axis.
    source_image = tf.expand_dims(
        ctfi.subsample(
            ctfi.load(args.source_filename,
                      height=args.source_image_size[0],
                      width=args.source_image_size[1]),
            args.subsampling_factor), 0)
    args.source_image_size = [
        int(x / args.subsampling_factor) for x in args.source_image_size
    ]

    # Same for the target image.
    target_image = tf.expand_dims(
        ctfi.subsample(
            ctfi.load(args.target_filename,
                      height=args.target_image_size[0],
                      width=args.target_image_size[1]),
            args.subsampling_factor), 0)
    args.target_image_size = [
        int(x / args.subsampling_factor) for x in args.target_image_size
    ]

    # One patch per landmark, for each image.
    source_landmarks = get_landmarks(args.source_landmarks,
                                     args.subsampling_factor)
    source_patches = tf.squeeze(
        tf.map_fn(lambda x: get_patch_at(x, source_image, args.patch_size),
                  source_landmarks))

    target_landmarks = get_landmarks(args.target_landmarks,
                                     args.subsampling_factor)
    target_patches = tf.squeeze(
        tf.map_fn(lambda x: get_patch_at(x, target_image, args.patch_size),
                  target_landmarks))

    # Entering the session object itself (rather than ``as_default()``) makes
    # it the default session *and* closes it when the block is left, even if
    # something in between raises.
    with tf.Session(config=config) as sess:
        saver.restore(sess, latest_checkpoint)

        graph = sess.graph
        stain = args.stain_code_size

        def encode(patches):
            """Encode patches into a diagonal Gaussian over the non-stain code."""
            log_sigma_sq, code_mean = tf.contrib.graph_editor.graph_replace(
                [
                    graph.get_tensor_by_name(
                        'imported/z_log_sigma_sq/BiasAdd:0'),
                    graph.get_tensor_by_name('imported/z_mean/BiasAdd:0')
                ], {
                    graph.get_tensor_by_name('imported/patch:0'):
                    normalize(patches)
                })
            # NOTE(review): exp(log sigma^2) is passed as ``scale_diag``; if
            # the tensor really is a log-variance this is a variance, not a
            # standard deviation. Kept as in the original -- confirm intent.
            return tf.contrib.distributions.MultivariateNormalDiag(
                code_mean[:, stain:], tf.exp(log_sigma_sq[:, stain:]))

        source_patches_distribution = encode(source_patches)
        target_patches_distribution = encode(target_patches)

        #similarities = source_patches_distribution.kl_divergence(target_patches_distribution) + target_patches_distribution.kl_divergence(source_patches_distribution)
        #similarities = ctf.multivariate_squared_hellinger_distance(source_patches_distribution, target_patches_distribution)
        similarities = ctf.bhattacharyya_distance(source_patches_distribution,
                                                  target_patches_distribution)
        #similarities = tf.reduce_sum(tf.abs(source_patches - target_patches), axis=[1,2,3])

        sim_vals = sess.run(similarities)
        min_idx = np.argmin(sim_vals)
        max_idx = np.argmax(sim_vals)

        print(sim_vals)
        print(min_idx, sim_vals[min_idx])
        print(max_idx, sim_vals[max_idx])

        # Evaluate everything needed for plotting in a single run instead of
        # re-extracting the patches for every subplot.
        (source_image_eval, source_patches_eval, target_image_eval,
         target_patches_eval) = sess.run(
             [source_image[0], source_patches, target_image[0],
              target_patches])

        _, ax = plt.subplots(2, 3)
        ax[0, 0].imshow(source_image_eval)
        ax[0, 1].imshow(source_patches_eval[min_idx])
        ax[0, 2].imshow(source_patches_eval[max_idx])
        ax[1, 0].imshow(target_image_eval)
        ax[1, 1].imshow(target_patches_eval[min_idx])
        ax[1, 2].imshow(target_patches_eval[max_idx])
        plt.show()

    return 0
def main(argv):
    """Match ORB keypoints of two images using latent-code descriptors.

    Keypoints are detected with OpenCV ORB on both images, sorted by response
    and truncated to ``num_keypoints``. A patch is extracted around every
    keypoint, encoded by the imported model in batches of 1000, and the
    resulting mean / log-sigma-squared vectors (without the first
    ``--stain_code_size`` entries) form the descriptor. Distances between all
    source and target descriptors are computed with
    ``ctf.fast_symmetric_kl_div``; each source keypoint is paired with its
    closest target keypoint and the ``num_matches`` best pairs are drawn with
    ``cv2.drawMatches``.

    Command-line arguments are read from ``sys.argv`` by ``argparse``.

    Args:
        argv: Not read by this function.

    Returns:
        int: Always ``0``.
    """
    parser = argparse.ArgumentParser(
        description='Register images using keypoints.')
    parser.add_argument('export_dir', type=str, help='Path to saved model.')
    parser.add_argument(
        'mean',
        type=str,
        help='Path to npy file holding mean for normalization.')
    parser.add_argument(
        'variance',
        type=str,
        help='Path to npy file holding variance for normalization.')
    parser.add_argument('patch_size', type=int, help='Size of image patch.')
    parser.add_argument('source_filename',
                        type=str,
                        help='Image file from which to extract patch.')
    parser.add_argument('source_image_size',
                        type=int,
                        nargs=2,
                        help='Size of the input image, HW.')
    parser.add_argument('target_filename',
                        type=str,
                        help='Image file for which to create the heatmap.')
    parser.add_argument(
        'target_image_size',
        type=int,
        nargs=2,
        help='Size of the input image for which to create heatmap, HW.')
    parser.add_argument('num_keypoints',
                        type=int,
                        help='Number of keypoints to detect.')
    parser.add_argument('num_matches',
                        type=int,
                        help='Number of matches to keep.')
    parser.add_argument(
        '--stain_code_size',
        type=int,
        dest='stain_code_size',
        default=0,
        help=
        'Optional: Size of the stain code to use, which is skipped for similarity estimation'
    )
    # --leaf_size and --num_neighbours are only referenced by the
    # commented-out nearest-neighbour code further down.
    parser.add_argument(
        '--leaf_size',
        type=int,
        dest='leaf_size',
        default=30,
        help='Number of elements to keep in leaf nodes of search tree.')
    parser.add_argument(
        '--method',
        type=str,
        dest='method',
        default='SKLD',
        help=
        'Method to use to measure similarity, one of KLD, SKLD, BD, HD, SQHD.')
    parser.add_argument('--num_neighbours',
                        type=int,
                        dest='num_neighbours',
                        default=1,
                        help='k for kNN')
    parser.add_argument('--subsampling_factor',
                        type=int,
                        dest='subsampling_factor',
                        default=1,
                        help='Factor to subsample source and target image.')
    args = parser.parse_args()

    mean = np.load(args.mean)
    variance = np.load(args.variance)
    # NOTE(review): ``np.math`` is an alias of the stdlib ``math`` module that
    # newer NumPy releases no longer provide -- confirm the pinned version.
    stddev = [np.math.sqrt(x) for x in variance]

    def denormalize(image):
        """Undo per-channel normalisation of an HWC array and rescale it."""
        channels = [
            np.expand_dims(
                image[:, :, channel] * stddev[channel] + mean[channel], -1)
            for channel in range(3)
        ]
        denormalized_image = ctfi.rescale(np.concatenate(channels, 2), 0.0,
                                          1.0)
        return denormalized_image

    def normalize(image, name=None):
        """Normalise the first three channels of a rank-4 tensor."""
        channels = [
            tf.expand_dims(
                (image[:, :, :, channel] - mean[channel]) / stddev[channel],
                -1) for channel in range(3)
        ]
        return tf.concat(channels, 3, name=name)

    def get_patch_at(keypoint, image):
        """Extract one ``patch_size`` x ``patch_size`` glimpse at ``keypoint``.

        The offset is given in pixels (``normalized=False``) relative to the
        image corner (``centered=False``).
        """
        return tf.image.extract_glimpse(image,
                                        [args.patch_size, args.patch_size],
                                        [keypoint],
                                        normalized=False,
                                        centered=False)

    # NOTE(review): latest_checkpoint() yields None when no checkpoint exists,
    # in which case the concatenation below raises TypeError.
    latest_checkpoint = tf.train.latest_checkpoint(args.export_dir)
    saver = tf.train.import_meta_graph(latest_checkpoint + '.meta',
                                       import_scope='imported')

    # ``as_default()`` does not close the session on exit; it is closed
    # explicitly with sess.close() near the end.
    with tf.Session(graph=tf.get_default_graph()).as_default() as sess:
        #saver.restore(sess,latest_checkpoint)

        # Load both images, add a batch axis and subsample them; the uint8
        # copies are what OpenCV works on.
        source_image = ctfi.subsample(
            tf.expand_dims(
                ctfi.load(args.source_filename,
                          height=args.source_image_size[0],
                          width=args.source_image_size[1]), 0),
            args.subsampling_factor)
        im_source = (sess.run(source_image[0]) * 255).astype(np.uint8)

        target_image = ctfi.subsample(
            tf.expand_dims(
                ctfi.load(args.target_filename,
                          height=args.target_image_size[0],
                          width=args.target_image_size[1]), 0),
            args.subsampling_factor)
        im_target = (sess.run(target_image[0]) * 255).astype(np.uint8)

        # Only the ORB keypoints are used below; the ORB descriptors
        # (``*_descriptors_cv``) are not read again in this function.
        orb = cv2.ORB_create(20000)
        source_keypoints, source_descriptors_cv = orb.detectAndCompute(
            im_source, None)
        target_keypoints, target_descriptors_cv = orb.detectAndCompute(
            im_target, None)

        #for keypoint in source_keypoints:
        #    keypoint.pt = (keypoint.pt[1], keypoint.pt[0])
        #for keypoint in target_keypoints:
        #    keypoint.pt = (keypoint.pt[1], keypoint.pt[0])

        # Only referenced by the commented-out debug plot below.
        patch_kp_0 = get_patch_at(source_keypoints[0].pt, source_image)
        #plt.imshow(sess.run(patch_kp_0)[0])
        #plt.show()

        #source_keypoints.sort(key = lambda x: x.response, reverse=False)
        #target_keypoints.sort(key = lambda x: x.response, reverse=False)

        # The two helpers below are only reachable through the commented-out
        # filter_keypoints() calls.
        def remove_overlapping(x, keypoints):
            # NOTE(review): removes items from ``keypoints`` while iterating
            # over it, which makes the loop skip elements.
            for p in keypoints:
                if p != x and x.overlap(x, p) > 0.8:
                    keypoints.remove(p)
            return keypoints

        def filter_keypoints(keypoints):
            # Walks the list from the back, pruning neighbours that overlap
            # the current keypoint.
            i = 0
            while i < len(keypoints):
                end_idx = len(keypoints) - 1 - i
                p = keypoints[end_idx]
                keypoints = remove_overlapping(p, keypoints)
                i += 1
            return keypoints

        #source_keypoints = filter_keypoints(source_keypoints)
        #target_keypoints = filter_keypoints(target_keypoints)

        # Strongest responses first, then keep at most num_keypoints.
        # NOTE(review): assumes detectAndCompute returned mutable lists
        # (``.sort``) -- verify for the installed OpenCV version.
        source_keypoints.sort(key=lambda x: x.response, reverse=True)
        target_keypoints.sort(key=lambda x: x.response, reverse=True)

        source_keypoints = source_keypoints[:args.num_keypoints]
        target_keypoints = target_keypoints[:args.num_keypoints]

        source_descriptors_eval = []
        target_descriptors_eval = []

        #source_patches = normalize(tf.concat(list(map(lambda x: get_patch_at(x, source_image), source_keypoints)),0))
        #target_patches = normalize(tf.concat(list(map(lambda x: get_patch_at(x, target_image), target_keypoints)),0))

        # Fixed batch of 1000 patches per encoder run.
        patches_placeholder = tf.placeholder(
            tf.float32, shape=[1000, args.patch_size, args.patch_size, 3])
        #source_patches_placeholder = tf.placeholder(tf.float32,shape=[1000, args.patch_size, args.patch_size, 3])
        #target_patches_placeholder = tf.placeholder(tf.float32,shape=[1000, args.patch_size, args.patch_size, 3])

        # Copy of the encoder fed from the placeholder. ``tf_cov`` is the
        # tensor named z_log_sigma_sq; dist_kl() below exponentiates it.
        tf_cov, tf_mean = tf.contrib.graph_editor.graph_replace([
            sess.graph.get_tensor_by_name('imported/z_log_sigma_sq/BiasAdd:0'),
            sess.graph.get_tensor_by_name('imported/z_mean/BiasAdd:0')
        ], {
            sess.graph.get_tensor_by_name('imported/patch:0'):
            patches_placeholder
        })
        #source_cov, source_mean = tf.contrib.graph_editor.graph_replace([sess.graph.get_tensor_by_name('imported/z_covariance_lower_tri/MatrixBandPart:0'),sess.graph.get_tensor_by_name('imported/z_mean/BiasAdd:0')] ,{ sess.graph.get_tensor_by_name('imported/patch:0'): source_patches_placeholder })
        #target_cov, target_mean = tf.contrib.graph_editor.graph_replace([sess.graph.get_tensor_by_name('imported/z_covariance_lower_tri/MatrixBandPart:0'),sess.graph.get_tensor_by_name('imported/z_mean/BiasAdd:0')] ,{ sess.graph.get_tensor_by_name('imported/patch:0'): target_patches_placeholder })

        batch, latent_code_size = tf_mean.get_shape().as_list()
        #batch, latent_code_size = target_mean.get_shape().as_list()
        structure_code_size = latent_code_size - args.stain_code_size

        # Descriptor = [mean | log sigma^2], each without the stain entries,
        # so its first ``structure_code_size`` columns are the mean.
        descriptors = tf.concat([
            tf_mean[:, args.stain_code_size:],
            tf.layers.flatten(tf_cov[:, args.stain_code_size:])
        ], -1)
        #source_descriptors = tf.concat([source_mean[:,args.stain_code_size:], tf.layers.flatten(source_cov[:,args.stain_code_size:,args.stain_code_size:])], -1)
        #target_descriptors = tf.concat([target_mean[:,args.stain_code_size:], tf.layers.flatten(target_cov[:,args.stain_code_size:,args.stain_code_size:])], -1)

        # ---- candidate metrics -------------------------------------------
        # These closures are only used to pick ``metric`` below, and
        # ``metric`` itself is not used afterwards. They rely on
        # get_mean_and_cov, get_mean_and_cov_tf,
        # multivariate_squared_hellinger_distance and bhattacharyya_distance
        # from module scope.
        def multi_kl_div(X, Y):
            # Closed-form KL divergence between two Gaussians:
            # 0.5 * (tr(Sy^-1 Sx) + d^T Sy^-1 d - k + ln(det Sy / det Sx)).
            X_mean, X_cov = get_mean_and_cov(X, structure_code_size)
            Y_mean, Y_cov = get_mean_and_cov(Y, structure_code_size)
            Y_cov_inv = np.linalg.inv(Y_cov)
            trace_term = np.matrix.trace(np.matmul(Y_cov_inv, X_cov))
            diff_mean = np.expand_dims(Y_mean - X_mean, axis=-1)
            middle_term = np.matmul(np.transpose(diff_mean),
                                    np.matmul(Y_cov_inv, diff_mean))
            determinant_term = np.log(
                np.linalg.det(Y_cov) / np.linalg.det(X_cov))
            value = 0.5 * (trace_term + middle_term - structure_code_size +
                           determinant_term)
            return np.squeeze(value)

        def multi_kl_div_tf(X, Y):
            # TensorFlow counterpart of multi_kl_div().
            X_mean, X_cov = get_mean_and_cov_tf(X, structure_code_size)
            Y_mean, Y_cov = get_mean_and_cov_tf(Y, structure_code_size)
            Y_cov_inv = tf.linalg.inv(Y_cov)
            trace_term = tf.linalg.trace(tf.matmul(Y_cov_inv, X_cov))
            diff_mean = tf.expand_dims(Y_mean - X_mean, axis=-1)
            middle_term = tf.matmul(diff_mean,
                                    tf.matmul(Y_cov_inv, diff_mean),
                                    transpose_a=True)
            determinant_term = tf.log(
                tf.linalg.det(Y_cov) / tf.linalg.det(X_cov))
            value = 0.5 * (trace_term + middle_term - structure_code_size +
                           determinant_term)
            return tf.squeeze(value)

        def sym_kl_div(X, Y):
            # Symmetrised KL: KL(X||Y) + KL(Y||X).
            return multi_kl_div(X, Y) + multi_kl_div(Y, X)

        def sym_kl_div_tf(X, Y):
            return multi_kl_div_tf(X, Y) + multi_kl_div_tf(Y, X)

        def sqhd(X, Y):
            return multivariate_squared_hellinger_distance(
                X, Y, structure_code_size)

        def bd(X, Y):
            X_mean, X_cov = get_mean_and_cov(X, structure_code_size)
            Y_mean, Y_cov = get_mean_and_cov(Y, structure_code_size)
            return bhattacharyya_distance(X_mean, X_cov, Y_mean, Y_cov)

        def centroid_distance(X, Y):
            # Euclidean distance between the two means; covariances ignored.
            X_mean, X_cov = get_mean_and_cov(X, structure_code_size)
            Y_mean, Y_cov = get_mean_and_cov(Y, structure_code_size)
            return np.linalg.norm(X_mean - Y_mean)

        # Patch extraction for a batch of 1000 coordinates.
        coords = tf.placeholder(tf.float32, shape=[1000, 2])
        source_patches = tf.map_fn(lambda x: get_patch_at(x, source_image),
                                   coords)
        target_patches = tf.map_fn(lambda x: get_patch_at(x, target_image),
                                   coords)

        # Computation of distance metric
        descriptor_length = descriptors.get_shape().as_list()[1]

        def cdist_tf(X, Y):
            # Pairwise Euclidean distance via |x|^2 + |y|^2 - 2 x.y; only
            # referenced by the commented-out tf_dist_op line below.
            X_mean = X[:, args.stain_code_size:]
            Y_mean = Y[:, args.stain_code_size:]
            diff_means_einsum = tf.sqrt(
                tf.einsum('ij,ij->i', X_mean, X_mean)[:, None] +
                tf.einsum('ij,ij->i', Y_mean, Y_mean) -
                2 * tf.matmul(X_mean, Y_mean, transpose_b=True))
            return diff_means_einsum

        def dist_kl(X, Y):
            # Split each descriptor back into mean and exp(log sigma^2) and
            # delegate to ctf.fast_symmetric_kl_div.
            X_mean = X[:, 0:structure_code_size]
            X_cov = tf.exp(X[:, structure_code_size:])
            Y_mean = Y[:, 0:structure_code_size]
            Y_cov = tf.exp(Y[:, structure_code_size:])
            return ctf.fast_symmetric_kl_div(X_mean, X_cov, Y_mean, Y_cov)

        #tf_dist_op = cdist_tf(tf_src_descs, tf_trgt_descs)

        # Initialise, then load the checkpoint values into the imported model.
        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())
        saver.restore(sess, latest_checkpoint)

        # Encode the keypoint patches in chunks of 1000.
        # NOTE(review): int(num_keypoints / 1000) truncates, so a remainder is
        # never encoded and fewer than 1000 keypoints give zero iterations.
        # The placeholders also require every chunk to hold exactly 1000
        # keypoints, which fails if ORB found fewer than requested.
        for i in range(int(args.num_keypoints / 1000)):
            start = i * 1000
            end = (i + 1) * 1000
            #source_patches = sess.run(normalize(tf.concat(list(map(lambda x: get_patch_at(x, source_image), source_keypoints[start:end])),0)))
            #target_patches = sess.run(normalize(tf.concat(list(map(lambda x: get_patch_at(x, target_image), target_keypoints[start:end])),0)))
            #source_coords = tf.convert_to_tensor(np.array([list(key_point.pt) for key_point in source_keypoints]))
            # The two keypoint coordinates are swapped before being used as
            # glimpse offsets (presumably (x, y) -> (row, col) -- TODO confirm).
            source_coords_np = np.array(
                [[key_point.pt[1], key_point.pt[0]]
                 for key_point in source_keypoints[start:end]])
            target_coords_np = np.array(
                [[key_point.pt[1], key_point.pt[0]]
                 for key_point in target_keypoints[start:end]])
            # Inner run extracts the patches, outer run encodes them.
            # NOTE(review): normalize() is not applied to the patches on this
            # path, unlike in the commented-out variant above.
            source_descriptors_eval.extend(
                sess.run(descriptors,
                         feed_dict={
                             patches_placeholder:
                             np.squeeze(
                                 sess.run(
                                     source_patches,
                                     feed_dict={coords: source_coords_np}))
                         }))
            target_descriptors_eval.extend(
                sess.run(descriptors,
                         feed_dict={
                             patches_placeholder:
                             np.squeeze(
                                 sess.run(
                                     target_patches,
                                     feed_dict={coords: target_coords_np}))
                         }))
            #source_descriptors_eval.extend(sess.run(source_descriptors, feed_dict={source_patches_placeholder : source_patches}))
            #target_descriptors_eval.extend(sess.run(target_descriptors, feed_dict={target_patches_placeholder : target_patches}))

        #sess.close()
        #with tf.Session(graph=tf.get_default_graph()).as_default() as sess:

        # Placeholders sized for all num_keypoints descriptors at once; the
        # arrays fed below must therefore have exactly that many rows.
        tf_src_descs = tf.placeholder(
            tf.float32, shape=[args.num_keypoints, descriptor_length])
        tf_trgt_descs = tf.placeholder(
            tf.float32, shape=[args.num_keypoints, descriptor_length])

        dist_op = dist_kl(tf_src_descs, tf_trgt_descs)

        #sess.run(tf.global_variables_initializer())
        #matches = match_descriptors(source_descriptors, target_descriptors, metric=lambda x,y: sym_kl_div(x,y), cross_check=True)

        # ``metric`` is selected here but not used afterwards: the distances
        # below always come from dist_op, whatever --method says.
        if args.method == 'SKLD':
            metric = sym_kl_div
        elif args.method == 'SQHD':
            metric = sqhd
        elif args.method == 'BD':
            metric = bd
        elif args.method == 'CD':
            metric = centroid_distance
        else:
            metric = sym_kl_div

        distances = sess.run(dist_op,
                             feed_dict={
                                 tf_src_descs:
                                 np.array(source_descriptors_eval),
                                 tf_trgt_descs:
                                 np.array(target_descriptors_eval)
                             })

        # For every row of ``distances`` take the column with the smallest
        # value (assumes rows index source and columns target descriptors --
        # TODO confirm against ctf.fast_symmetric_kl_div). ``indices`` gets a
        # trailing axis of length 1, so each entry of ``min_distances`` is a
        # length-1 array.
        indices = np.expand_dims(np.argmin(distances, axis=1), 1)
        min_distances = [
            distances[i, indices[i]] for i in range(len(indices))
        ]

        #knn_source = sklearn.neighbors.NearestNeighbors(n_neighbors=5, radius=1.0, algorithm='ball_tree', leaf_size=args.leaf_size, metric=metric)
        #knn_source.fit(target_descriptors_eval)
        #distances, indices = knn_source.kneighbors(source_descriptors_eval, n_neighbors=args.num_neighbours)

        # (source index, array of target indices, array of distances)
        matches = list(zip(range(len(indices)), indices, min_distances))

        # Sort matches by score
        matches.sort(key=lambda x: np.min(x[2]), reverse=False)
        matches = matches[:args.num_matches]

        def create_dmatch(queryIdx, trainIdx, distance):
            # Only referenced by the commented-out cv_matches line below.
            dmatch = cv2.DMatch(queryIdx, trainIdx, 0, distance)
            return dmatch

        def create_cv_matches(match):
            # One cv2.DMatch per target index stored in the match tuple.
            items = []
            for i in range(len(match[1])):
                items.append(
                    cv2.DMatch(match[0], match[1][i], 0, match[2][i]))
            return items

        all_cv_matches = []
        for match in matches:
            all_cv_matches.extend(create_cv_matches(match))

        sess.close()

    #cv_matches = list(map(lambda x: create_dmatch(x[0], x[1], x[2]),matches))

    # Draw top matches
    imMatches = cv2.drawMatches(im_source, source_keypoints, im_target,
                                target_keypoints, all_cv_matches, None)
    # ``fix`` (presumably meant to be ``fig``) is not used afterwards.
    fix, ax = plt.subplots(1)
    ax.imshow(imMatches)
    plt.show()
    print("Detected keypoints!")
    return 0
def main(argv):
    """Create a TFRecords dataset of labelled image patches.

    Reads a TFRecords dataset of image filenames, loads every image, derives
    its label from the path components of the filename, splits the image
    into patches of ``patch_size`` and writes ``num_samples`` serialized
    ``(patch, label)`` examples to ``output_dataset``.

    Unless ``--no_filter`` is given, each patch is classified by its total
    variation per pixel. Patches at or below the threshold are relabelled
    with the extra class index ``len(labels)`` and mixed into the output
    with a sampling weight of 0.05 (0.95 for the remaining patches).

    Args:
        argv: Unused; the command line is parsed from ``sys.argv`` by
            ``argparse``.
    """
    parser = argparse.ArgumentParser(
        description=
        'Create tfrecords dataset holding patches of images specified by filename in input dataset.'
    )
    parser.add_argument('input_dataset',
                        type=str,
                        help='Path to dataset holding image filenames')
    parser.add_argument('output_dataset',
                        type=str,
                        help='Path where to store the output dataset')
    parser.add_argument(
        'patch_size',
        type=int,
        help='Patch size which to use in the preprocessed dataset')
    parser.add_argument('num_samples',
                        type=int,
                        help='Size of output dataset')
    parser.add_argument(
        'labels',
        type=lambda s: s.split(','),
        help="Comma separated list of labels to find in filenames.")
    parser.add_argument('--image_size',
                        type=int,
                        dest='image_size',
                        help='Image size for files pointed to by filename')
    parser.add_argument(
        '--no_filter',
        dest='no_filter',
        action='store_true',
        default=False,
        help='Whether to apply total image variation filtering.')
    parser.add_argument(
        '--threshold',
        type=float,
        dest='threshold',
        help='Threshold for filtering the samples according to variation.')
    parser.add_argument('--subsampling_factor',
                        type=int,
                        dest='subsampling_factor',
                        default=1,
                        help='Subsampling factor to use to downsample images.')
    args = parser.parse_args()

    # Maps a label string to its index in args.labels.
    labels_table = tf.contrib.lookup.index_table_from_tensor(
        mapping=args.labels)

    filename_dataset = tf.data.TFRecordDataset(
        args.input_dataset,
        num_parallel_reads=8).map(_decode_example_filename).shuffle(100000)

    # tf.case expects callables; each one yields the label string of its
    # branch, and the default branch yields the string 'None'.
    functions = [
        tf.Variable(label, name='const_' + label).value
        for label in args.labels
    ]
    false_fn = tf.Variable('None', name='none_label').value

    def _extract_label(filename):
        """Return the label matching a path component of `filename`."""
        path_parts = tf.string_split([filename], '/').values
        match = [
            tf.math.reduce_any(tf.strings.regex_full_match(path_parts, label))
            for label in args.labels
        ]
        pred_fn_pairs = list(zip(match, functions))
        return tf.case(pred_fn_pairs, default=false_fn, exclusive=True)

    # Only pass width/height to the loader when an image size was requested.
    load_kwargs = {'channels': 3}
    if args.image_size is not None:
        load_kwargs['width'] = args.image_size
        load_kwargs['height'] = args.image_size

    def _load_sample(feature):
        """Load the image and look up the label index for its filename."""
        return {
            'image': ctfi.load(feature['filename'], **load_kwargs),
            'label': labels_table.lookup(_extract_label(feature['filename']))
        }

    images_dataset = filename_dataset.map(_load_sample)

    if args.subsampling_factor > 1:
        images_dataset = images_dataset.map(
            lambda feature: {
                'image':
                ctfi.subsample(feature['image'], args.subsampling_factor),
                'label':
                feature['label']
            })

    # Keep only samples whose label index is greater than -1.
    images_dataset = images_dataset.filter(
        lambda features: features['label'] > -1).shuffle(100)

    def _split_patches(features):
        """Split an image into patches and repeat its label per patch."""
        patches = ctfi.extract_patches(features['image'], args.patch_size)
        labels = tf.expand_dims(tf.reshape(features['label'], [1]), 0)
        labels = tf.tile(labels, tf.stack([tf.shape(patches)[0], 1]))
        return (patches, labels)

    patches_dataset = images_dataset.map(_split_patches).apply(
        tf.data.experimental.unbatch())
    patches_dataset = patches_dataset.map(lambda patch, label: {
        'patch': patch,
        'label': label
    })

    threshold = args.threshold if args.threshold is not None else 0.08

    # Flags each patch by its total variation per pixel.
    # See: https://www.tensorflow.org/versions/r1.12/api_docs/python/tf/image/total_variation
    def add_background_info(sample):
        variation = tf.image.total_variation(sample['patch'])
        num_pixels = sample['patch'].get_shape().num_elements()
        var_per_pixel = (variation / num_pixels)
        flagged = dict(sample)
        flagged['no_background'] = var_per_pixel > threshold
        return flagged

    if args.no_filter:
        dataset = patches_dataset
    else:
        dataset = patches_dataset.map(add_background_info)

        def change_label(sample):
            """Relabel a low-variation patch with the extra class index."""
            return {
                'patch':
                sample['patch'],
                'label':
                tf.reshape(
                    tf.convert_to_tensor(len(args.labels), dtype=tf.int64),
                    [1])
            }

        filtered_elements_dataset = dataset.filter(
            lambda sample: tf.logical_not(sample['no_background'])).map(
                change_label)
        filtered_dataset = dataset.filter(
            lambda sample: sample['no_background']).map(lambda sample: {
                'patch': sample['patch'],
                'label': sample['label']
            })
        dataset = tf.data.experimental.sample_from_datasets(
            [filtered_dataset, filtered_elements_dataset],
            weights=[0.95, 0.05])

    dataset = dataset.map(lambda sample: (sample['patch'], sample['label']))
    dataset = dataset.take(args.num_samples).shuffle(100000)

    def _encode_func(sample):
        """Encode one (patch, label) pair as an example proto."""
        patch_np = sample[0].numpy().flatten()
        label_np = sample[1].numpy()
        return ctfd.encode({
            'patch': ctf.float_feature(patch_np),
            'label': ctf.int64_feature(label_np)
        })

    # The context manager closes the writer even if iteration or encoding
    # raises part-way through.
    with tf.io.TFRecordWriter(args.output_dataset) as writer:
        # Make file readable for all users
        cutil.publish(args.output_dataset)

        # Iterate over whole dataset and write serialized examples to file.
        # See: https://www.tensorflow.org/versions/r1.12/api_docs/python/tf/contrib/eager/Iterator
        for sample in tfe.Iterator(dataset):
            writer.write(_encode_func(sample).SerializeToString())
        writer.flush()
def main(argv):
    """Register two rotated copies of an image by optimizing control points.

    Loads the same 512x512 image twice and rotates the copies by +0.05*pi
    ("target") and -0.05*pi ("moving"). Both are warped with
    ``tf.contrib.image.sparse_image_warp`` from an 8x8 grid of trainable
    source control points to a shared grid of destination points. 32x32
    patches of both warped images are fed through the encoder imported from
    the latest checkpoint in ``export_dir``. The loss is the summed
    symmetric KL divergence between the resulting multivariate normal
    distributions, with the first ``stain_code_size`` latent dimensions
    skipped. Gradient descent updates the source control points while the
    images, difference images and accumulated gradients are plotted live.

    Args:
        argv: Unused; the command line is parsed from ``sys.argv`` by
            ``argparse``.

    Raises:
        ValueError: If no checkpoint is found in ``export_dir``.
    """
    # NOTE(review): the description text looks copied from another script.
    parser = argparse.ArgumentParser(
        description='Compute latent code for image patch by model inference.')
    parser.add_argument('export_dir',
                        type=str,
                        help='Path to saved model to use for inference.')
    args = parser.parse_args()

    width = 512
    height = 512
    channels = 3
    image_shape = [1, height, width, channels]

    filename_target = os.path.join(git_root, 'data', 'images',
                                   'HE_level_1_cropped_512x512.png')
    image_target = tf.expand_dims(
        ctfi.load(filename_target,
                  width=width,
                  height=height,
                  channels=channels), 0)
    image_target = tf.reshape(image_target, shape=image_shape)
    image_target = tf.contrib.image.rotate(image_target, 0.05 * math.pi)

    filename_moving = os.path.join(git_root, 'data', 'images',
                                   'HE_level_1_cropped_512x512.png')
    image_moving = tf.expand_dims(
        ctfi.load(filename_moving,
                  width=width,
                  height=height,
                  channels=channels), 0)
    image_moving = tf.reshape(image_moving, shape=image_shape)
    image_moving = tf.contrib.image.rotate(image_moving, -0.05 * math.pi)

    # Step counter incremented by apply_gradients.
    step = tf.Variable(tf.zeros([], dtype=tf.float32))

    # 8x8 grid of control points spanning the image.
    X, Y = np.mgrid[0:width:8j, 0:height:8j]
    positions = np.transpose(np.vstack([X.ravel(), Y.ravel()]))
    positions = tf.expand_dims(
        tf.convert_to_tensor(positions, dtype=tf.float32), 0)

    target_source_control_point_locations = tf.Variable(positions)
    moving_source_control_point_locations = tf.Variable(positions)
    dest_control_point_locations = tf.Variable(positions)

    warped_moving = tf.Variable(image_moving)
    warped_moving, flow_moving = tf.contrib.image.sparse_image_warp(
        warped_moving,
        moving_source_control_point_locations,
        dest_control_point_locations,
        name='sparse_image_warp_moving',
        interpolation_order=1,
        regularization_weight=0.01)
    warped_target = tf.Variable(image_target)
    warped_target, flow_target = tf.contrib.image.sparse_image_warp(
        warped_target,
        target_source_control_point_locations,
        dest_control_point_locations,
        name='sparse_image_warp_target',
        interpolation_order=1,
        regularization_weight=0.01)

    warped_target_patches = normalize(
        ctfi.extract_patches(warped_target[0], 32, strides=[1, 32, 32, 1]))
    warped_moving_patches = normalize(
        ctfi.extract_patches(warped_moving[0], 32, strides=[1, 32, 32, 1]))

    # Learning rates noted by the original author for alternative losses:
    # 0.05 (h_squared, sym_kl, bhattacharyya), 1 (hellinger), 0.005 (ssd).
    learning_rate = 0.05  # sym_kl

    latest_checkpoint = tf.train.latest_checkpoint(args.export_dir)
    if latest_checkpoint is None:
        raise ValueError('No checkpoint found in ' + str(args.export_dir))
    saver = tf.train.import_meta_graph(latest_checkpoint + '.meta',
                                       import_scope='imported')

    with tf.Session(graph=tf.get_default_graph()).as_default() as sess:
        graph = sess.graph
        encoder_outputs = [
            graph.get_tensor_by_name(
                'imported/z_covariance_lower_tri/MatrixBandPart:0'),
            graph.get_tensor_by_name('imported/z_mean/BiasAdd:0')
        ]
        patch_input = graph.get_tensor_by_name('imported/patch:0')

        # Re-create the encoder outputs with the warped patches substituted
        # for the imported model's patch input.
        target_cov, target_mean = tf.contrib.graph_editor.graph_replace(
            encoder_outputs, {patch_input: warped_target_patches})
        moving_cov, moving_mean = tf.contrib.graph_editor.graph_replace(
            encoder_outputs, {patch_input: warped_moving_patches})

        # Leading latent dimensions excluded from the distributions.
        stain_code_size = 8

        N_target = tf.contrib.distributions.MultivariateNormalTriL(
            loc=target_mean[:, stain_code_size:],
            scale_tril=target_cov[:, stain_code_size:, stain_code_size:])
        N_mov = tf.contrib.distributions.MultivariateNormalTriL(
            loc=moving_mean[:, stain_code_size:],
            scale_tril=moving_cov[:, stain_code_size:, stain_code_size:])

        sym_kl_div = N_target.kl_divergence(N_mov) + N_mov.kl_divergence(
            N_target)
        loss = tf.reduce_sum(sym_kl_div)

        optimizer = tf.contrib.optimizer_v2.GradientDescentOptimizer(
            learning_rate=learning_rate)
        compute_gradients = optimizer.compute_gradients(
            loss,
            var_list=[
                moving_source_control_point_locations,
                target_source_control_point_locations
            ])
        apply_gradients = optimizer.apply_gradients(compute_gradients,
                                                    global_step=step)

        # Build every tensor evaluated inside the optimization loop once,
        # so that the graph does not grow with each iteration.
        overlay_op = image_target + image_moving
        warped_overlay_op = warped_target + warped_moving
        diff_target_op = tf.abs(image_target - warped_target)
        diff_warped_op = tf.abs(warped_target - warped_moving)
        diff_moving_op = tf.abs(image_moving - warped_moving)

        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())
        saver.restore(sess, latest_checkpoint)

        fig, ax = plt.subplots(3, 3)

        def _show(axis, tensor, title):
            """Plot the first batch element of `tensor`, rescaled."""
            handle = axis.imshow(ctfi.rescale(sess.run(tensor)[0], 0.0, 1.0))
            axis.set_title(title)
            axis.set_autoscale_on(False)
            return handle

        def _quiver(axis, points, vectors, *colors):
            """Draw one arrow per control point in data coordinates."""
            return axis.quiver(
                points[:, 0],  # X
                points[:, 1],  # Y
                vectors[:, 0],
                vectors[:, 1],
                *colors,
                units='xy',
                angles='xy',
                scale_units='xy',
                scale=1)

        _show(ax[0, 0], image_target, 'target')
        _show(ax[0, 1], overlay_op, 'overlayed')
        _show(ax[0, 2], image_moving, 'moving')
        plot_warped_target = _show(ax[1, 0], warped_target, 'warped_target')
        plot_overlayed = _show(ax[1, 1], warped_overlay_op,
                               'warped_overlayed')
        plot_warped_moving = _show(ax[1, 2], warped_moving, 'warped_moving')
        plot_diff_target = _show(ax[2, 0], diff_target_op, 'diff_target')
        plot_diff_overlayed = _show(ax[2, 1], diff_warped_op,
                                    'diff_overlayed')
        plot_diff_moving = _show(ax[2, 2], diff_moving_op, 'diff_moving')

        moving_source_points = sess.run(
            moving_source_control_point_locations)[0]
        target_source_points = sess.run(
            target_source_control_point_locations)[0]

        # Markers only; no connecting line between control points.
        plot_scatter_moving, = ax[1, 2].plot(moving_source_points[:, 0],
                                             moving_source_points[:, 1],
                                             linestyle='None',
                                             marker='x',
                                             ms=5,
                                             color='orange')
        plot_scatter_target, = ax[1, 0].plot(target_source_points[:, 0],
                                             target_source_points[:, 1],
                                             linestyle='None',
                                             marker='x',
                                             ms=5,
                                             color='orange')

        plot_moving_grad = _quiver(ax[1, 2], moving_source_points,
                                   np.zeros_like(moving_source_points))
        plot_target_grad = _quiver(ax[1, 0], target_source_points,
                                   np.zeros_like(target_source_points))

        plt.ion()
        fig.canvas.draw()
        fig.canvas.flush_events()
        plt.show()

        iterations = 5000
        print_iterations = 1
        # Has the same nested layout as the fetched (gradient, variable)
        # pairs; only the gradient entries are read when plotting.
        accumulated_gradients = np.zeros_like(sess.run(compute_gradients))

        step_val = int(sess.run(step))
        while step_val < iterations:
            gradients = sess.run(compute_gradients)
            sess.run(apply_gradients)
            accumulated_gradients += gradients

            if step_val % print_iterations == 0 or step_val == iterations - 1:
                (loss_val, diff_moving, diff_target, diff, warped_target_val,
                 warped_moving_val, warped_overlay_val) = sess.run([
                     loss, diff_moving_op, diff_target_op, diff_warped_op,
                     warped_target, warped_moving, warped_overlay_op
                 ])

                print("{0:d}\t{1:.4f}\t{2:.4f}\t{3:.4f}\t{4:.4f}".format(
                    step_val, loss_val, np.sum(diff_moving),
                    np.sum(diff_target), np.sum(diff)))

                plot_warped_target.set_data(
                    ctfi.rescale(warped_target_val[0], 0., 1.))
                plot_warped_moving.set_data(
                    ctfi.rescale(warped_moving_val[0], 0., 1.))
                plot_overlayed.set_data(
                    ctfi.rescale(warped_overlay_val[0], 0., 1.))
                plot_diff_target.set_data(ctfi.rescale(diff_target[0], 0., 1.))
                plot_diff_moving.set_data(ctfi.rescale(diff_moving[0], 0., 1.))
                plot_diff_overlayed.set_data(ctfi.rescale(diff[0], 0., 1.))

                # Index [i][0] is the gradient, [i][1] the variable value
                # fetched together with it (i.e. before the update).
                moving_gradients = learning_rate * np.squeeze(
                    accumulated_gradients[0][0])
                moving_points = np.squeeze(gradients[0][1])
                target_gradients = learning_rate * np.squeeze(
                    accumulated_gradients[1][0])
                target_points = np.squeeze(gradients[1][1])

                plot_scatter_moving.set_data(moving_points[:, 0],
                                             moving_points[:, 1])
                plot_scatter_target.set_data(target_points[:, 0],
                                             target_points[:, 1])

                # NOTE(review): the colour argument has two values per
                # arrow; recent matplotlib versions may reject this size
                # mismatch - confirm the intended colouring.
                plot_moving_grad.remove()
                plot_moving_grad = _quiver(ax[1, 2], moving_points,
                                           moving_gradients, moving_gradients)
                plot_target_grad.remove()
                plot_target_grad = _quiver(ax[1, 0], target_points,
                                           target_gradients, target_gradients)

                fig.canvas.draw()
                fig.canvas.flush_events()
                plt.show()

                accumulated_gradients.fill(0)

            step_val = int(sess.run(step))

        print("Done!")
        plt.ioff()
        plt.show()

    sys.exit(0)
def main(argv):
    """Plot a patch-similarity heatmap between a source and a target image.

    Both images are loaded, subsampled and cropped/padded to a common size.
    For every pixel position a patch is taken from each image with
    ``get_patch_at``, normalized with the given per-channel mean and
    variance, and encoded by the model imported from the latest checkpoint
    in ``export_dir``. The heatmap value at a position is the symmetric KL
    divergence between the two diagonal normal distributions built from the
    encoder outputs, with the first ``stain_code_size`` latent dimensions
    skipped. It is plotted next to the per-pixel mean squared difference of
    the two images.

    ``--method`` and ``--rotate`` are parsed but not used by this function.

    Args:
        argv: Unused; the command line is parsed from ``sys.argv`` by
            ``argparse``.

    Returns:
        0 on completion.

    Raises:
        ValueError: If no checkpoint is found in ``export_dir``.
    """
    import math  # local import: only the stddev computation needs it

    parser = argparse.ArgumentParser(
        description='Compute codes and reconstructions for image.')
    parser.add_argument('export_dir', type=str, help='Path to saved model.')
    parser.add_argument(
        'mean', type=str, help='Path to npy file holding mean for normalization.')
    parser.add_argument(
        'variance',
        type=str,
        help='Path to npy file holding variance for normalization.')
    parser.add_argument('source_filename',
                        type=str,
                        help='Image file from which to extract patch.')
    parser.add_argument('source_image_size',
                        type=int,
                        nargs=2,
                        help='Size of the input image, HW.')
    parser.add_argument('target_filename',
                        type=str,
                        help='Image file for which to create the heatmap.')
    parser.add_argument(
        'target_image_size',
        type=int,
        nargs=2,
        help='Size of the input image for which to create heatmap, HW.')
    parser.add_argument('patch_size', type=int, help='Size of image patch.')
    parser.add_argument(
        '--method',
        dest='method',
        type=str,
        help=
        'Method to use to measure similarity, one of KLD, SKLD, BD, HD, SQHD.')
    parser.add_argument(
        '--stain_code_size',
        type=int,
        dest='stain_code_size',
        default=0,
        help=
        'Optional: Size of the stain code to use, which is skipped for similarity estimation'
    )
    parser.add_argument('--rotate',
                        type=float,
                        dest='angle',
                        default=0,
                        help='Optional: rotation angle to rotate target image')
    parser.add_argument('--subsampling_factor',
                        type=int,
                        dest='subsampling_factor',
                        default=1,
                        help='Factor to subsample source and target image.')
    args = parser.parse_args()

    mean = np.load(args.mean)
    variance = np.load(args.variance)
    stddev = [math.sqrt(x) for x in variance]

    def normalize(image, name=None, num_channels=3):
        """Standardize each channel of a 4-D image batch."""
        channels = [
            tf.expand_dims(
                (image[:, :, :, channel] - mean[channel]) / stddev[channel],
                -1) for channel in range(num_channels)
        ]
        # Concatenate along the channel axis of the 4-D tensor, independent
        # of how many channels there are.
        return tf.concat(channels, 3)

    latest_checkpoint = tf.train.latest_checkpoint(args.export_dir)
    if latest_checkpoint is None:
        raise ValueError('No checkpoint found in ' + str(args.export_dir))
    saver = tf.train.import_meta_graph(latest_checkpoint + '.meta',
                                       import_scope='imported')

    config = tf.ConfigProto()
    run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
    run_options.report_tensor_allocations_upon_oom = True

    # Load and subsample the source image.
    source_image = ctfi.subsample(
        ctfi.load(args.source_filename,
                  height=args.source_image_size[0],
                  width=args.source_image_size[1]), args.subsampling_factor)
    source_size = [
        x // args.subsampling_factor for x in args.source_image_size
    ]

    # Load and subsample the image for which to create the heatmap.
    target_image = ctfi.subsample(
        ctfi.load(args.target_filename,
                  height=args.target_image_size[0],
                  width=args.target_image_size[1]), args.subsampling_factor)
    target_size = [
        x // args.subsampling_factor for x in args.target_image_size
    ]

    # Both images are cropped/padded to the larger extent per dimension.
    heatmap_size = [max(s, t) for s, t in zip(source_size, target_size)]

    source_image = tf.expand_dims(
        tf.image.resize_image_with_crop_or_pad(source_image, heatmap_size[0],
                                               heatmap_size[1]), 0)
    target_image = tf.expand_dims(
        tf.image.resize_image_with_crop_or_pad(target_image, heatmap_size[0],
                                               heatmap_size[1]), 0)

    # Walk through the divisors of the patch count until a batch holds at
    # most 500 patches or no divisor is left.
    num_patches = np.prod(heatmap_size, axis=0)
    possible_splits = cutil.get_divisors(num_patches)
    num_splits = possible_splits.pop(0)
    while num_patches / num_splits > 500 and len(possible_splits) > 0:
        num_splits = possible_splits.pop(0)
    split_size = int(num_patches // num_splits)

    # One (row, column) coordinate per heatmap position.
    X, Y = np.meshgrid(range(heatmap_size[1]), range(heatmap_size[0]))
    coords = np.stack([Y.flatten(), X.flatten()], axis=1)

    coords_placeholder = tf.placeholder(tf.float32, shape=[split_size, 2])

    source_patches_placeholder = tf.squeeze(
        tf.map_fn(lambda x: get_patch_at(x, source_image, args.patch_size),
                  coords_placeholder,
                  parallel_iterations=8,
                  back_prop=False))
    target_patches_placeholder = tf.squeeze(
        tf.map_fn(lambda x: get_patch_at(x, target_image, args.patch_size),
                  coords_placeholder,
                  parallel_iterations=8,
                  back_prop=False))

    heatmap = np.zeros(heatmap_size)

    # Used as a context manager the session is closed on exit, even when an
    # exception is raised.
    with tf.Session(graph=tf.get_default_graph(), config=config) as sess:
        graph = sess.graph
        encoder_outputs = [
            graph.get_tensor_by_name('imported/z_log_sigma_sq/BiasAdd:0'),
            graph.get_tensor_by_name('imported/z_mean/BiasAdd:0')
        ]
        patch_input = graph.get_tensor_by_name('imported/patch:0')

        # NOTE(review): exp(z_log_sigma_sq) is passed as the diagonal scale
        # below; if that tensor is a log-variance, the scale would need
        # exp(0.5 * x) - confirm against the model definition.
        source_patches_cov, source_patches_mean = tf.contrib.graph_editor.graph_replace(
            encoder_outputs,
            {patch_input: normalize(source_patches_placeholder)})
        source_patches_distribution = tf.contrib.distributions.MultivariateNormalDiag(
            source_patches_mean[:, args.stain_code_size:],
            tf.exp(source_patches_cov[:, args.stain_code_size:]))

        target_patches_cov, target_patches_mean = tf.contrib.graph_editor.graph_replace(
            encoder_outputs,
            {patch_input: normalize(target_patches_placeholder)})
        target_patches_distribution = tf.contrib.distributions.MultivariateNormalDiag(
            target_patches_mean[:, args.stain_code_size:],
            tf.exp(target_patches_cov[:, args.stain_code_size:]))

        # Symmetric KL divergence between source and target distributions.
        similarity = source_patches_distribution.kl_divergence(
            target_patches_distribution
        ) + target_patches_distribution.kl_divergence(
            source_patches_distribution)

        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())
        saver.restore(sess, latest_checkpoint)

        for i in range(num_splits):
            start = i * split_size
            batch_coords = coords[start:start + split_size, :]
            similarity_values = sess.run(
                similarity,
                feed_dict={coords_placeholder: batch_coords},
                options=run_options)
            heatmap[batch_coords[:, 0], batch_coords[:, 1]] = similarity_values

        # Mean squared difference over the channel axis.
        heatmap_sad = sess.run(
            tf.reduce_mean(tf.squared_difference(source_image, target_image),
                           axis=3))[0]

        fig_images, ax_images = plt.subplots(1, 2)
        ax_images[0].imshow(sess.run(source_image)[0])
        ax_images[1].imshow(sess.run(target_image)[0])

        fig_similarities, ax_similarities = plt.subplots(1, 2)
        heatmap_skld_plot = ax_similarities[0].imshow(heatmap, cmap='plasma')
        heatmap_sad_plot = ax_similarities[1].imshow(heatmap_sad,
                                                     cmap='plasma')
        fig_similarities.colorbar(heatmap_skld_plot, ax=ax_similarities[0])
        fig_similarities.colorbar(heatmap_sad_plot, ax=ax_similarities[1])

        plt.show()

    return 0