Example #1
0
 def func_encode(sample):
     """Serialize an (x, y) pair into a tf.train.Example with float features 'val' and 'label'."""
     value, label = sample
     feature_map = {
         'val': ctf.float_feature([value]),
         'label': ctf.float_feature([label])
     }
     return tf.train.Example(features=tf.train.Features(feature=feature_map))
 def _encode_func(sample):
     """Flatten a (patch, label) tensor pair and encode it through ctfd.

     sample[0] is the image patch (flattened to 1-D), sample[1] the label.
     """
     patch_flat = sample[0].numpy().flatten()
     label_value = sample[1].numpy()
     feature_map = {
         'patch': ctf.float_feature(patch_flat),
         'label': ctf.int64_feature(label_value)
     }
     return ctfd.encode(feature_map)
Example #3
0
        def dist_kl(X, Y):
            """Symmetric KL divergence between rows of X and Y.

            Each row encodes a diagonal Gaussian as [mean, log-variance];
            the first `structure_code_size` columns are the mean, the rest
            the log-variance (exponentiated here to get the covariance).
            """
            mean_x = X[:, 0:structure_code_size]
            mean_y = Y[:, 0:structure_code_size]
            cov_x = tf.exp(X[:, structure_code_size:])
            cov_y = tf.exp(Y[:, structure_code_size:])
            return ctf.fast_symmetric_kl_div(mean_x, cov_x, mean_y, cov_y)
def main(argv):
    """Compute similarity heatmaps of windows around corresponding landmarks.

    Restores an exported encoder model, extracts a region around each
    source/target landmark pair, scores every patch position in the target
    region against the source patch with several similarity measures
    (SKLD, BD, SAD, MI), and writes the rank of the true landmark position
    per method to a CSV file.
    """
    parser = argparse.ArgumentParser(description='Compute similarity heatmaps of windows around landmarks.')
    parser.add_argument('export_dir',type=str,help='Path to saved model.')
    parser.add_argument('mean', type=str, help='Path to npy file holding mean for normalization.')
    parser.add_argument('variance', type=str, help='Path to npy file holding variance for normalization.')
    parser.add_argument('source_filename', type=str,help='Image file from which to extract patch.')
    parser.add_argument('source_image_size', type=int, nargs=2, help='Size of the input image, HW.')
    parser.add_argument('source_landmarks', type=str,help='CSV file from which to extract the landmarks for source image.')
    parser.add_argument('target_filename', type=str,help='Image file for which to create the heatmap.')
    parser.add_argument('target_image_size', type=int, nargs=2, help='Size of the input image for which to create heatmap, HW.')
    parser.add_argument('target_landmarks', type=str,help='CSV file from which to extract the landmarks for target image.')
    parser.add_argument('patch_size', type=int, help='Size of image patch.')
    parser.add_argument('output', type=str)
    parser.add_argument('--method', dest='method', type=str, help='Method to use to measure similarity, one of KLD, SKLD, BD, HD, SQHD.')
    parser.add_argument('--stain_code_size', type=int, dest='stain_code_size', default=0,
        help='Optional: Size of the stain code to use, which is skipped for similarity estimation')
    parser.add_argument('--rotate', type=float, dest='angle', default=0,
        help='Optional: rotation angle to rotate target image')
    parser.add_argument('--subsampling_factor', type=int, dest='subsampling_factor', default=1, help='Factor to subsample source and target image.')
    parser.add_argument('--region_size', type=int, default=64)
    args = parser.parse_args()

    # Per-channel normalization statistics computed offline.
    mean = np.load(args.mean)
    variance = np.load(args.variance)
    stddev = [np.math.sqrt(x) for x in variance]

    def denormalize(image):
        """Undo per-channel normalization on an HWC numpy image and rescale to [0, 1]."""
        channels = [np.expand_dims(image[:,:,channel] * stddev[channel] + mean[channel],-1) for channel in range(3)]
        denormalized_image = ctfi.rescale(np.concatenate(channels, 2), 0.0, 1.0)
        return denormalized_image

    def normalize(image, name=None, num_channels=3):
        """Apply per-channel normalization to an NHWC tensor.

        NOTE(review): the concat axis equals num_channels (3 == channel axis
        for NHWC); for num_channels != 3 the axis would be wrong — confirm
        this is only ever called with 3 channels.
        """
        channels = [tf.expand_dims((image[:,:,:,channel] - mean[channel]) / stddev[channel],-1) for channel in range(num_channels)]
        return tf.concat(channels, num_channels)

    # Restore the exported graph under the 'imported' scope; weights are
    # loaded later via saver.restore inside the session.
    latest_checkpoint = tf.train.latest_checkpoint(args.export_dir)   
    saver = tf.train.import_meta_graph(latest_checkpoint + '.meta', import_scope='imported')

    config = tf.ConfigProto()
    config.allow_soft_placement=True
    #config.log_device_placement=True

    # Load image and extract patch from it and create distribution.
    source_image = tf.expand_dims(ctfi.subsample(ctfi.load(args.source_filename,height=args.source_image_size[0], width=args.source_image_size[1]),args.subsampling_factor),0)
    args.source_image_size = list(map(lambda x: int(x / args.subsampling_factor), args.source_image_size))

    #Load image for which to create the heatmap
    target_image = tf.expand_dims(ctfi.subsample(ctfi.load(args.target_filename,height=args.target_image_size[0], width=args.target_image_size[1]),args.subsampling_factor),0)
    args.target_image_size = list(map(lambda x: int(x / args.subsampling_factor), args.target_image_size))

    source_landmarks = get_landmarks(args.source_landmarks, args.subsampling_factor)
    target_landmarks = get_landmarks(args.target_landmarks, args.subsampling_factor)

    # One heatmap cell per patch position inside a region_size x region_size window.
    region_size = args.region_size
    region_center = [int(region_size / 2),int(region_size / 2)]
    num_patches = region_size**2

    # Split the region's patches into equally sized batches of at most 512
    # so each sess.run stays within memory.
    possible_splits = cutil.get_divisors(num_patches)
    num_splits = possible_splits.pop(0)

    while num_patches / num_splits > 512 and len(possible_splits) > 0:
        num_splits = possible_splits.pop(0)

    split_size = int(num_patches / num_splits)

    # offset adds a margin around the region so border patches have full support.
    offset = 64
    # NOTE(review): center_idx is computed but never used below.
    center_idx = np.prod(region_center)

    # Grid of (y, x) patch coordinates inside the padded region.
    X, Y = np.meshgrid(range(offset, region_size + offset), range(offset, region_size + offset))
    coords = np.concatenate([np.expand_dims(Y.flatten(),axis=1),np.expand_dims(X.flatten(),axis=1)],axis=1)

    coords_placeholder = tf.placeholder(tf.float32, shape=[split_size, 2])

    source_landmark_placeholder = tf.placeholder(tf.float32, shape=[1, 2])
    target_landmark_placeholder = tf.placeholder(tf.float32, shape=[1, 2])

    # Crop the padded region around each landmark (glimpse centered on the landmark).
    source_image_region = tf.image.extract_glimpse(source_image,[region_size + 2*offset, region_size+ 2*offset], source_landmark_placeholder, normalized=False, centered=False)
    target_image_region = tf.image.extract_glimpse(target_image,[region_size + 2*offset, region_size+ 2*offset], target_landmark_placeholder, normalized=False, centered=False)

    # Single source patch at the landmark vs. one target patch per coordinate.
    source_patches_placeholder = tf.map_fn(lambda x: get_patch_at(x, source_image, args.patch_size), source_landmark_placeholder, parallel_iterations=8, back_prop=False)[0]
    target_patches_placeholder = tf.squeeze(tf.map_fn(lambda x: get_patch_at(x, target_image_region, args.patch_size), coords_placeholder, parallel_iterations=8, back_prop=False))


    # as_default() does not close the session on block exit; sess.close()
    # below handles that explicitly.
    with tf.Session(config=config).as_default() as sess:
        saver.restore(sess, latest_checkpoint)

        # Rewire the imported encoder so it consumes our patches instead of
        # its original 'imported/patch:0' placeholder, yielding latent
        # log-variance and mean tensors for the patches.
        source_patches_cov, source_patches_mean = tf.contrib.graph_editor.graph_replace([sess.graph.get_tensor_by_name('imported/z_log_sigma_sq/BiasAdd:0'),sess.graph.get_tensor_by_name('imported/z_mean/BiasAdd:0')] ,{ sess.graph.get_tensor_by_name('imported/patch:0'): normalize(source_patches_placeholder) })
        # NOTE(review): MultivariateNormalDiag's second argument is the
        # stddev diagonal, but exp(log sigma^2) = sigma^2 is passed here
        # (no sqrt) — confirm whether variance-as-scale is intended.
        source_patches_distribution = tf.contrib.distributions.MultivariateNormalDiag(source_patches_mean[:,args.stain_code_size:], tf.exp(source_patches_cov[:,args.stain_code_size:]))
        
        target_patches_cov, target_patches_mean = tf.contrib.graph_editor.graph_replace([sess.graph.get_tensor_by_name('imported/z_log_sigma_sq/BiasAdd:0'),sess.graph.get_tensor_by_name('imported/z_mean/BiasAdd:0')] ,{ sess.graph.get_tensor_by_name('imported/patch:0'): normalize(target_patches_placeholder) })
        target_patches_distribution = tf.contrib.distributions.MultivariateNormalDiag(target_patches_mean[:,args.stain_code_size:], tf.exp(target_patches_cov[:,args.stain_code_size:]))

        # Four similarity measures, evaluated together per batch below.
        similarities_skld = source_patches_distribution.kl_divergence(target_patches_distribution) + target_patches_distribution.kl_divergence(source_patches_distribution)
        similarities_bd = ctf.bhattacharyya_distance(source_patches_distribution, target_patches_distribution)
        similarities_sad = tf.reduce_sum(tf.abs(source_patches_placeholder - target_patches_placeholder), axis=[1,2,3])

        source_patches_grayscale = tf.image.rgb_to_grayscale(source_patches_placeholder)
        target_patches_grayscale = tf.image.rgb_to_grayscale(target_patches_placeholder)

        # Normalized mutual information with 20 histogram bins.
        similarities_nmi = tf.map_fn(lambda x: nmi_tf(tf.squeeze(source_patches_grayscale), tf.squeeze(x), 20), target_patches_grayscale)

        with open(args.output + "_" + str(region_size) + ".csv",'wt') as outfile:
            fp = csv.DictWriter(outfile, ["method", "landmark", "min_idx", "min_idx_value", "rank", "landmark_value"])
            # Order must match the sess.run list below (and the c == 3 MI special case).
            methods = ["SKLD", "BD", "SAD", "MI"]
            fp.writeheader()
            
            # NOTE(review): results is never appended to, so the final
            # fp.writerows(results) is a no-op.
            results = []

            for k in range(len(source_landmarks)):

                # One heatmap layer per similarity method.
                heatmap_fused = np.ndarray((region_size, region_size, len(methods)))
                feed_dict={source_landmark_placeholder: [source_landmarks[k,:]], target_landmark_placeholder: [target_landmarks[k,:]] }
                
                for i in range(num_splits):
                    start = i * split_size
                    end = start + split_size
                    batch_coords = coords[start:end,:]

                    feed_dict.update({coords_placeholder: batch_coords})

                    # Evaluate all four measures for this coordinate batch;
                    # transpose so each row holds the 4 values for one patch.
                    similarity_values = np.array(sess.run([similarities_skld,similarities_bd, similarities_sad, similarities_nmi],feed_dict=feed_dict)).transpose()
                    #heatmap.extend(similarity_values)
                    for idx, val in zip(batch_coords, similarity_values):
                        heatmap_fused[idx[0] - offset, idx[1] - offset] = val

                for c in range(len(methods)):
                    heatmap = heatmap_fused[:,:,c]
                    # MI (c == 3): higher is better, so rank by argmax;
                    # the other methods are distances (lower is better).
                    if c == 3:
                        min_idx = np.unravel_index(np.argmax(heatmap),heatmap.shape)
                        min_indices = np.array(np.unravel_index(list(reversed(np.argsort(heatmap.flatten()))),heatmap.shape)).transpose().tolist()
                    else:
                        min_idx = np.unravel_index(np.argmin(heatmap),heatmap.shape)
                        min_indices = np.array(np.unravel_index(np.argsort(heatmap.flatten()),heatmap.shape)).transpose().tolist()

                    # Rank of the true landmark position (region center) in
                    # the sorted similarity ordering; 0 = perfect match.
                    landmark_value = heatmap[region_center[0], region_center[1]]
                    rank = min_indices.index(region_center)

                    fp.writerow({"method": methods[c],"landmark": k, "min_idx": min_idx, "min_idx_value": heatmap[min_idx[0], min_idx[1]],"rank": rank , "landmark_value": landmark_value})
                    #matplotlib.image.imsave(args.output + "_" + str(region_size)+ "_"+ methods[c] + "_" + str(k) + ".jpeg", heatmap, cmap='plasma')
                outfile.flush()

                print(min_idx, rank)
        
        
            fp.writerows(results)


        sess.close()
        
    return 0
 def func_encode(sample):
     """Wrap a filename string in a tf.train.Example with a single 'filename' feature."""
     feature_map = {'filename': ctf.string_feature(sample)}
     return tf.train.Example(features=tf.train.Features(feature=feature_map))
def main(argv):
    """Compute a dense similarity heatmap of a target image against one source patch.

    Restores an exported encoder, encodes a single patch cut from the
    source image, then slides over the whole target image in memory-bounded
    chunks, scoring each target patch against the source patch distribution
    with the chosen divergence (KLD/SKLD/BD/HD/SQHD), and finally plots the
    heatmap plus the best-matching patches.
    """
    parser = argparse.ArgumentParser(
        description='Compute codes and reconstructions for image.')
    parser.add_argument('export_dir', type=str, help='Path to saved model.')
    parser.add_argument(
        'mean',
        type=str,
        help='Path to npy file holding mean for normalization.')
    parser.add_argument(
        'variance',
        type=str,
        help='Path to npy file holding variance for normalization.')
    parser.add_argument('source_filename',
                        type=str,
                        help='Image file from which to extract patch.')
    parser.add_argument('source_image_size',
                        type=int,
                        nargs=2,
                        help='Size of the input image, HW.')
    parser.add_argument('offsets',
                        type=int,
                        nargs=2,
                        help='Position where to extract the patch.')
    parser.add_argument('patch_size', type=int, help='Size of image patch.')
    parser.add_argument('target_filename',
                        type=str,
                        help='Image file for which to create the heatmap.')
    parser.add_argument(
        'target_image_size',
        type=int,
        nargs=2,
        help='Size of the input image for which to create heatmap, HW.')
    parser.add_argument(
        'method',
        type=str,
        help=
        'Method to use to measure similarity, one of KLD, SKLD, BD, HD, SQHD.')
    parser.add_argument(
        '--stain_code_size',
        type=int,
        dest='stain_code_size',
        default=0,
        help=
        'Optional: Size of the stain code to use, which is skipped for similarity estimation'
    )
    parser.add_argument('--rotate',
                        type=float,
                        dest='angle',
                        default=0,
                        help='Optional: rotation angle to rotate target image')
    parser.add_argument('--subsampling_factor',
                        type=int,
                        dest='subsampling_factor',
                        default=1,
                        help='Factor to subsample source and target image.')
    args = parser.parse_args()

    # Per-channel normalization statistics computed offline.
    mean = np.load(args.mean)
    variance = np.load(args.variance)
    stddev = [np.math.sqrt(x) for x in variance]

    def denormalize(image):
        """Undo per-channel normalization on an HWC numpy image and rescale to [0, 1]."""
        channels = [
            np.expand_dims(
                image[:, :, channel] * stddev[channel] + mean[channel], -1)
            for channel in range(3)
        ]
        denormalized_image = ctfi.rescale(np.concatenate(channels, 2), 0.0,
                                          1.0)
        return denormalized_image

    def normalize(image, name=None):
        """Apply per-channel normalization to a 3-channel NHWC tensor."""
        channels = [
            tf.expand_dims(
                (image[:, :, :, channel] - mean[channel]) / stddev[channel],
                -1) for channel in range(3)
        ]
        return tf.concat(channels, 3, name=name)

    # Restore the exported graph under the 'imported' scope; weights are
    # loaded later via saver.restore once the full graph is built.
    latest_checkpoint = tf.train.latest_checkpoint(args.export_dir)
    saver = tf.train.import_meta_graph(latest_checkpoint + '.meta',
                                       import_scope='imported')

    with tf.Session(graph=tf.get_default_graph()).as_default() as sess:

        # Load image and extract patch from it and create distribution.
        source_image = ctfi.subsample(
            ctfi.load(args.source_filename,
                      height=args.source_image_size[0],
                      width=args.source_image_size[1]),
            args.subsampling_factor)
        args.source_image_size = list(
            map(lambda x: int(x / args.subsampling_factor),
                args.source_image_size))
        # The single reference patch, cropped at the given (row, col) offsets.
        patch = normalize(
            tf.expand_dims(
                tf.image.crop_to_bounding_box(source_image, args.offsets[0],
                                              args.offsets[1], args.patch_size,
                                              args.patch_size), 0))
        #patch_cov, patch_mean = tf.contrib.graph_editor.graph_replace([sess.graph.get_tensor_by_name('imported/z_covariance_lower_tri/MatrixBandPart:0'),sess.graph.get_tensor_by_name('imported/z_mean/BiasAdd:0')] ,{ sess.graph.get_tensor_by_name('imported/patch:0'): patch })
        #patch_distribution = tf.contrib.distributions.MultivariateNormalTriL(loc=patch_mean[:,args.stain_code_size:], scale_tril=patch_cov[:,args.stain_code_size:,args.stain_code_size:])

        # Rewire the imported encoder so it consumes our patch instead of
        # its original 'imported/patch:0' placeholder.
        patch_cov, patch_mean = tf.contrib.graph_editor.graph_replace([
            sess.graph.get_tensor_by_name('imported/z_log_sigma_sq/BiasAdd:0'),
            sess.graph.get_tensor_by_name('imported/z_mean/BiasAdd:0')
        ], {sess.graph.get_tensor_by_name('imported/patch:0'):
            patch})
        # scale_diag = sqrt(exp(log sigma^2)) = sigma, i.e. the stddev.
        patch_distribution = tf.contrib.distributions.MultivariateNormalDiag(
            patch_mean[:, args.stain_code_size:],
            tf.sqrt(tf.exp(patch_cov[:, args.stain_code_size:])))

        sim_vals = []

        #Load image for which to create the heatmap
        target_image = ctfi.subsample(
            ctfi.load(args.target_filename,
                      height=args.target_image_size[0],
                      width=args.target_image_size[1]),
            args.subsampling_factor)
        target_image = tf.contrib.image.rotate(target_image,
                                               np.radians(args.angle))
        args.target_image_size = list(
            map(lambda x: int(x / args.subsampling_factor),
                args.target_image_size))

        # Heatmap dimensions for 'VALID' patch extraction (one cell per
        # fully-contained patch position).
        heatmap_height = args.target_image_size[0] - (args.patch_size - 1)
        heatmap_width = args.target_image_size[1] - (args.patch_size - 1)

        # Compute byte size as: width*height*channels*sizeof(float32)
        # NOTE(review): max_patch_buffer_size and max_buffer_size_in_byte are
        # defined elsewhere in the module — presumably byte budgets; confirm.
        patch_size_in_byte = args.patch_size**2 * 3 * 4
        max_patches = int(max_patch_buffer_size / patch_size_in_byte)
        max_num_rows = int(max_patches / heatmap_width)
        max_chunk_size = int(max_buffer_size_in_byte / patch_size_in_byte)

        #Iteration over image regions that we can load
        num_iterations = int(args.target_image_size[0] / max_num_rows) + 1

        all_chunks = list()
        all_similarities = list()
        chunk_tensors = list()

        # NOTE(review): np.int is deprecated in modern NumPy; this code
        # targets an older NumPy/TF 1.x stack.
        chunk_sizes = np.zeros(num_iterations, dtype=np.int)
        chunk_sizes.fill(heatmap_width)
        for i in range(num_iterations):
            processed_rows = i * max_num_rows
            # Overlap by patch_size - 1 rows so patches spanning the region
            # boundary are still produced.
            rows_to_load = min(max_num_rows + (args.patch_size - 1),
                               args.target_image_size[0] - processed_rows)
            if rows_to_load < args.patch_size:
                break

            # Extract region for which we can compute patches
            target_image_region = tf.image.crop_to_bounding_box(
                target_image, processed_rows, 0, rows_to_load,
                args.target_image_size[1])

            # Size = (image_width - patch_size - 1) * (image_height - patch_size - 1) for 'VALID' padding and
            # image_width * image_height for 'SAME' padding
            all_image_patches = tf.unstack(
                normalize(
                    ctfi.extract_patches(target_image_region,
                                         args.patch_size,
                                         strides=[1, 1, 1, 1],
                                         padding='VALID')))

            # Pick the largest divisor of the patch count that fits the
            # chunk byte budget, so chunks tile the patches exactly.
            possible_chunk_sizes = get_divisors(len(all_image_patches))

            for size in possible_chunk_sizes:
                if size < max_chunk_size:
                    chunk_sizes[i] = size
                    break

            # Partition patches into chunks
            chunked_patches = list(
                create_chunks(all_image_patches, chunk_sizes[i]))
            chunked_patches = list(map(tf.stack, chunked_patches))
            all_chunks.append(chunked_patches)

            #last_chunk = chunked_patches.pop()
            #last_chunk_size = last_chunk.get_shape().as_list()[0]
            #padding_size = chunk_size - last_chunk_size
            #last_chunk_padded = tf.concat([last_chunk, tf.ones([padding_size, args.patch_size, args.patch_size, 3],dtype=tf.float32)],0)

            # Placeholder fed with one evaluated chunk at a time below.
            chunk_tensor = tf.placeholder(
                tf.float32,
                shape=[chunk_sizes[i], args.patch_size, args.patch_size, 3],
                name='chunk_tensor_placeholder')
            chunk_tensors.append(chunk_tensor)

            #image_patches_cov, image_patches_mean = tf.contrib.graph_editor.graph_replace([sess.graph.get_tensor_by_name('imported/z_covariance_lower_tri/MatrixBandPart:0'),sess.graph.get_tensor_by_name('imported/z_mean/BiasAdd:0')] ,{ sess.graph.get_tensor_by_name('imported/patch:0'): chunk_tensor })
            #image_patches_distributions = tf.contrib.distributions.MultivariateNormalTriL(loc=image_patches_mean[:,args.stain_code_size:], scale_tril=image_patches_cov[:,args.stain_code_size:,args.stain_code_size:])

            image_patches_cov, image_patches_mean = tf.contrib.graph_editor.graph_replace(
                [
                    sess.graph.get_tensor_by_name(
                        'imported/z_log_sigma_sq/BiasAdd:0'),
                    sess.graph.get_tensor_by_name('imported/z_mean/BiasAdd:0')
                ], {
                    sess.graph.get_tensor_by_name('imported/patch:0'):
                    chunk_tensor
                })
            image_patches_distributions = tf.contrib.distributions.MultivariateNormalDiag(
                image_patches_mean[:, args.stain_code_size:],
                tf.sqrt(tf.exp(image_patches_cov[:, args.stain_code_size:])))

            # Select the divergence; anything unrecognized falls back to
            # one-sided KL divergence.
            if args.method == 'SKLD':
                similarities = patch_distribution.kl_divergence(
                    image_patches_distributions
                ) + image_patches_distributions.kl_divergence(
                    patch_distribution)
            elif args.method == 'BD':
                similarities = ctf.bhattacharyya_distance(
                    patch_distribution, image_patches_distributions)
            elif args.method == 'SQHD':
                similarities = ctf.multivariate_squared_hellinger_distance(
                    patch_distribution, image_patches_distributions)
            elif args.method == 'HD':
                similarities = tf.sqrt(
                    ctf.multivariate_squared_hellinger_distance(
                        patch_distribution, image_patches_distributions))
            else:
                similarities = patch_distribution.kl_divergence(
                    image_patches_distributions)

            all_similarities.append(similarities)

        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())
        saver.restore(sess, latest_checkpoint)

        # Evaluate each chunk eagerly and feed it through the similarity
        # subgraph built for its region.
        for i in range(len(all_chunks)):
            for chunk in all_chunks[i]:
                #chunk_vals = sess.run(all_similarities[i], feed_dict={chunk_tensors[i]: sess.run(chunk)})
                sim_vals.extend(
                    sess.run(all_similarities[i],
                             feed_dict={chunk_tensors[i]: sess.run(chunk)}))

        print(len(sim_vals))
        sim_heatmap = np.reshape(sim_vals, [heatmap_height, heatmap_width])
        heatmap_tensor = tf.expand_dims(
            tf.expand_dims(tf.convert_to_tensor(sim_heatmap), -1), 0)
        dy, dx = tf.image.image_gradients(heatmap_tensor)
        # NOTE(review): sim_vals_normalized is computed but never used.
        sim_vals_normalized = 1.0 - ctfi.rescale(sim_heatmap, 0.0, 1.0)

        # Show the k_min most similar (lowest-distance) patch locations.
        k_min = 20
        min_indices = np.unravel_index(
            np.argsort(sim_vals)[:k_min], sim_heatmap.shape)
        fig_min, ax_min = plt.subplots(4, 5)

        for i in range(k_min):
            target_patch = tf.image.crop_to_bounding_box(
                target_image, min_indices[0][i], min_indices[1][i],
                args.patch_size, args.patch_size)
            ax_min[int(i / 5), int(i % 5)].imshow(sess.run(target_patch))
            ax_min[int(i / 5),
                   int(i % 5)].set_title('y:' + str(min_indices[0][i]) +
                                         ', x:' + str(min_indices[1][i]))

        fig, ax = plt.subplots(2, 3)
        cmap = 'plasma'

        denormalized_patch = denormalize(sess.run(patch)[0])
        # NOTE(review): max_sim_val is unused, and max_idx actually holds the
        # argmin (best-match) location despite its name.
        max_sim_val = np.max(sim_vals)
        max_idx = np.unravel_index(np.argmin(sim_heatmap), sim_heatmap.shape)

        target_image_patch = tf.image.crop_to_bounding_box(
            target_image, max_idx[0], max_idx[1], args.patch_size,
            args.patch_size)
        print(max_idx)

        print(min_indices)
        ax[1, 0].imshow(sess.run(source_image))
        ax[1, 1].imshow(sess.run(target_image))
        ax[0, 0].imshow(denormalized_patch)
        heatmap_image = ax[0, 2].imshow(sim_heatmap, cmap=cmap)
        ax[0, 1].imshow(sess.run(target_image_patch))
        #dx_image = ax[0,2].imshow(np.squeeze(sess.run(dx)), cmap='bwr')
        #dy_image = ax[1,2].imshow(np.squeeze(sess.run(dy)), cmap='bwr')
        gradient_image = ax[1, 2].imshow(np.squeeze(sess.run(dx + dy)),
                                         cmap='bwr')

        fig.colorbar(heatmap_image, ax=ax[0, 2])
        #fig.colorbar(dx_image, ax=ax[0,2])
        #fig.colorbar(dy_image, ax=ax[1,2])
        fig.colorbar(gradient_image, ax=ax[1, 2])

        plt.show()
        sess.close()
    print("Done!")
def main(argv):
    """Score corresponding landmark patches between a source and target image.

    Restores an exported encoder, extracts one patch per landmark from both
    images, computes the Bhattacharyya distance between the patch latent
    distributions, prints the best/worst matching landmark pair, and shows
    the corresponding patches.
    """
    parser = argparse.ArgumentParser(
        description='Compute codes and reconstructions for image.')
    parser.add_argument('export_dir', type=str, help='Path to saved model.')
    parser.add_argument(
        'mean',
        type=str,
        help='Path to npy file holding mean for normalization.')
    parser.add_argument(
        'variance',
        type=str,
        help='Path to npy file holding variance for normalization.')
    parser.add_argument('source_filename',
                        type=str,
                        help='Image file from which to extract patch.')
    parser.add_argument('source_image_size',
                        type=int,
                        nargs=2,
                        help='Size of the input image, HW.')
    parser.add_argument(
        'source_landmarks',
        type=str,
        help='CSV file from which to extract the landmarks for source image.')
    parser.add_argument('target_filename',
                        type=str,
                        help='Image file for which to create the heatmap.')
    parser.add_argument(
        'target_image_size',
        type=int,
        nargs=2,
        help='Size of the input image for which to create heatmap, HW.')
    parser.add_argument(
        'target_landmarks',
        type=str,
        help='CSV file from which to extract the landmarks for target image.')
    parser.add_argument('patch_size', type=int, help='Size of image patch.')
    parser.add_argument(
        '--method',
        dest='method',
        type=str,
        help=
        'Method to use to measure similarity, one of KLD, SKLD, BD, HD, SQHD.')
    parser.add_argument(
        '--stain_code_size',
        type=int,
        dest='stain_code_size',
        default=0,
        help=
        'Optional: Size of the stain code to use, which is skipped for similarity estimation'
    )
    parser.add_argument('--rotate',
                        type=float,
                        dest='angle',
                        default=0,
                        help='Optional: rotation angle to rotate target image')
    parser.add_argument('--subsampling_factor',
                        type=int,
                        dest='subsampling_factor',
                        default=1,
                        help='Factor to subsample source and target image.')
    args = parser.parse_args()

    # Per-channel normalization statistics computed offline.
    mean = np.load(args.mean)
    variance = np.load(args.variance)
    stddev = [np.math.sqrt(x) for x in variance]

    def denormalize(image):
        """Undo per-channel normalization on an HWC numpy image and rescale to [0, 1]."""
        channels = [
            np.expand_dims(
                image[:, :, channel] * stddev[channel] + mean[channel], -1)
            for channel in range(3)
        ]
        denormalized_image = ctfi.rescale(np.concatenate(channels, 2), 0.0,
                                          1.0)
        return denormalized_image

    def normalize(image, name=None, num_channels=3):
        """Apply per-channel normalization to an NHWC tensor.

        NOTE(review): the concat axis equals num_channels (3 == channel axis
        for NHWC); for num_channels != 3 the axis would be wrong — confirm
        this is only ever called with 3 channels.
        """
        channels = [
            tf.expand_dims(
                (image[:, :, :, channel] - mean[channel]) / stddev[channel],
                -1) for channel in range(num_channels)
        ]
        return tf.concat(channels, num_channels)

    # Restore the exported graph under the 'imported' scope; weights are
    # loaded later via saver.restore inside the session.
    latest_checkpoint = tf.train.latest_checkpoint(args.export_dir)
    saver = tf.train.import_meta_graph(latest_checkpoint + '.meta',
                                       import_scope='imported')

    config = tf.ConfigProto()
    config.allow_soft_placement = True
    #config.log_device_placement=True

    # Load image and extract patch from it and create distribution.
    source_image = tf.expand_dims(
        ctfi.subsample(
            ctfi.load(args.source_filename,
                      height=args.source_image_size[0],
                      width=args.source_image_size[1]),
            args.subsampling_factor), 0)
    args.source_image_size = list(
        map(lambda x: int(x / args.subsampling_factor),
            args.source_image_size))

    #Load image for which to create the heatmap
    target_image = tf.expand_dims(
        ctfi.subsample(
            ctfi.load(args.target_filename,
                      height=args.target_image_size[0],
                      width=args.target_image_size[1]),
            args.subsampling_factor), 0)
    args.target_image_size = list(
        map(lambda x: int(x / args.subsampling_factor),
            args.target_image_size))

    # One patch per landmark, from each image.
    source_landmarks = get_landmarks(args.source_landmarks,
                                     args.subsampling_factor)
    source_patches = tf.squeeze(
        tf.map_fn(lambda x: get_patch_at(x, source_image, args.patch_size),
                  source_landmarks))

    target_landmarks = get_landmarks(args.target_landmarks,
                                     args.subsampling_factor)
    target_patches = tf.squeeze(
        tf.map_fn(lambda x: get_patch_at(x, target_image, args.patch_size),
                  target_landmarks))

    # as_default() does not close the session on block exit; sess.close()
    # below handles that explicitly.
    with tf.Session(config=config).as_default() as sess:
        saver.restore(sess, latest_checkpoint)

        # Rewire the imported encoder so it consumes our patches instead of
        # its original 'imported/patch:0' placeholder, yielding latent
        # log-variance and mean tensors for the patches.
        source_patches_cov, source_patches_mean = tf.contrib.graph_editor.graph_replace(
            [
                sess.graph.get_tensor_by_name(
                    'imported/z_log_sigma_sq/BiasAdd:0'),
                sess.graph.get_tensor_by_name('imported/z_mean/BiasAdd:0')
            ], {
                sess.graph.get_tensor_by_name('imported/patch:0'):
                normalize(source_patches)
            })
        # NOTE(review): MultivariateNormalDiag's second argument is the
        # stddev diagonal, but exp(log sigma^2) = sigma^2 is passed here
        # (no sqrt, unlike the sibling script) — confirm which is intended.
        source_patches_distribution = tf.contrib.distributions.MultivariateNormalDiag(
            source_patches_mean[:, args.stain_code_size:],
            tf.exp(source_patches_cov[:, args.stain_code_size:]))

        target_patches_cov, target_patches_mean = tf.contrib.graph_editor.graph_replace(
            [
                sess.graph.get_tensor_by_name(
                    'imported/z_log_sigma_sq/BiasAdd:0'),
                sess.graph.get_tensor_by_name('imported/z_mean/BiasAdd:0')
            ], {
                sess.graph.get_tensor_by_name('imported/patch:0'):
                normalize(target_patches)
            })
        target_patches_distribution = tf.contrib.distributions.MultivariateNormalDiag(
            target_patches_mean[:, args.stain_code_size:],
            tf.exp(target_patches_cov[:, args.stain_code_size:]))

        # NOTE(review): despite the --method flag, the metric is hard-coded
        # to Bhattacharyya distance here (alternatives left commented out).
        #similarities = source_patches_distribution.kl_divergence(target_patches_distribution) + target_patches_distribution.kl_divergence(source_patches_distribution)
        #similarities = ctf.multivariate_squared_hellinger_distance(source_patches_distribution, target_patches_distribution)
        similarities = ctf.bhattacharyya_distance(source_patches_distribution,
                                                  target_patches_distribution)
        #similarities = tf.reduce_sum(tf.abs(source_patches - target_patches), axis=[1,2,3])

        # Lowest distance = best-matching landmark pair; highest = worst.
        sim_vals = sess.run(similarities)
        min_idx = np.argmin(sim_vals)
        max_idx = np.argmax(sim_vals)
        print(sim_vals)
        print(min_idx, sim_vals[min_idx])
        print(max_idx, sim_vals[max_idx])

        # Top row: source image and its best/worst patches; bottom row: target.
        fig, ax = plt.subplots(2, 3)
        ax[0, 0].imshow(sess.run(source_image[0]))
        ax[0, 1].imshow(sess.run(source_patches)[min_idx])
        ax[0, 2].imshow(sess.run(source_patches)[max_idx])
        ax[1, 0].imshow(sess.run(target_image[0]))
        ax[1, 1].imshow(sess.run(target_patches)[min_idx])
        ax[1, 2].imshow(sess.run(target_patches)[max_idx])
        plt.show()

        sess.close()

    return 0