def main(argv):
    """Score landmark correspondences between two images with several similarity measures.

    For every landmark index ``k``, the patch at source landmark ``k`` is compared
    against patches taken at every position of a ``region_size`` x ``region_size``
    window cut (via ``tf.image.extract_glimpse``) from the target image around
    target landmark ``k``. Four measures are evaluated per position:

    * ``SKLD`` - symmetric KL divergence between the model's diagonal Gaussians,
    * ``BD``   - ``ctf.bhattacharyya_distance`` between those Gaussians,
    * ``SAD``  - sum of absolute pixel differences,
    * ``MI``   - ``nmi_tf`` on grayscale patches (larger is better; for the
      other three measures smaller is better).

    For each landmark and measure one row is written to
    ``<output>_<region_size>.csv`` holding the best position (``min_idx``), the
    value there, the rank of the window centre (the position of the target
    landmark) among all positions, and the value at the window centre.

    Args:
        argv: Unused; arguments are parsed from ``sys.argv`` by ``argparse``.

    Returns:
        0 on completion.
    """
    import math  # ``np.math`` was only an alias of ``math`` and was removed in NumPy 2.0.

    parser = argparse.ArgumentParser(description='Compute similarity heatmaps of windows around landmarks.')
    parser.add_argument('export_dir',type=str,help='Path to saved model.')
    parser.add_argument('mean', type=str, help='Path to npy file holding mean for normalization.')
    parser.add_argument('variance', type=str, help='Path to npy file holding variance for normalization.')
    parser.add_argument('source_filename', type=str,help='Image file from which to extract patch.')
    parser.add_argument('source_image_size', type=int, nargs=2, help='Size of the input image, HW.')
    parser.add_argument('source_landmarks', type=str,help='CSV file from which to extract the landmarks for source image.')
    parser.add_argument('target_filename', type=str,help='Image file for which to create the heatmap.')
    parser.add_argument('target_image_size', type=int, nargs=2, help='Size of the input image for which to create heatmap, HW.')
    parser.add_argument('target_landmarks', type=str,help='CSV file from which to extract the landmarks for target image.')
    parser.add_argument('patch_size', type=int, help='Size of image patch.')
    parser.add_argument('output', type=str)
    # --method and --rotate are parsed for CLI compatibility but not used below.
    parser.add_argument('--method', dest='method', type=str, help='Method to use to measure similarity, one of KLD, SKLD, BD, HD, SQHD.')
    parser.add_argument('--stain_code_size', type=int, dest='stain_code_size', default=0,
        help='Optional: Size of the stain code to use, which is skipped for similarity estimation')
    parser.add_argument('--rotate', type=float, dest='angle', default=0,
        help='Optional: rotation angle to rotate target image')
    parser.add_argument('--subsampling_factor', type=int, dest='subsampling_factor', default=1, help='Factor to subsample source and target image.')
    parser.add_argument('--region_size', type=int, default=64)
    args = parser.parse_args()

    mean = np.load(args.mean)
    variance = np.load(args.variance)
    stddev = [math.sqrt(x) for x in variance]

    def normalize(image, name=None, num_channels=3):
        """Standardize the first ``num_channels`` channels of a 4-D channels-last tensor.

        ``name`` is accepted for call compatibility but is not used.
        """
        channels = [
            tf.expand_dims((image[:, :, :, channel] - mean[channel]) / stddev[channel], -1)
            for channel in range(num_channels)
        ]
        # Join on the last (channel) axis so the result is correct for any num_channels.
        return tf.concat(channels, -1)

    latest_checkpoint = tf.train.latest_checkpoint(args.export_dir)
    saver = tf.train.import_meta_graph(latest_checkpoint + '.meta', import_scope='imported')

    config = tf.ConfigProto()
    config.allow_soft_placement = True

    # Load and subsample both images, adding a leading batch dimension.
    source_image = tf.expand_dims(ctfi.subsample(ctfi.load(args.source_filename, height=args.source_image_size[0], width=args.source_image_size[1]), args.subsampling_factor), 0)
    args.source_image_size = list(map(lambda x: int(x / args.subsampling_factor), args.source_image_size))

    target_image = tf.expand_dims(ctfi.subsample(ctfi.load(args.target_filename, height=args.target_image_size[0], width=args.target_image_size[1]), args.subsampling_factor), 0)
    args.target_image_size = list(map(lambda x: int(x / args.subsampling_factor), args.target_image_size))

    source_landmarks = get_landmarks(args.source_landmarks, args.subsampling_factor)
    target_landmarks = get_landmarks(args.target_landmarks, args.subsampling_factor)

    region_size = args.region_size
    # The heatmap cell at the window centre maps to the glimpse centre, i.e. the
    # target landmark position.
    region_center = [int(region_size / 2), int(region_size / 2)]
    num_patches = region_size**2

    # Take the first divisor (in the order cutil.get_divisors returns them) that
    # brings a batch down to at most 512 patches.
    possible_splits = cutil.get_divisors(num_patches)
    num_splits = possible_splits.pop(0)
    while num_patches / num_splits > 512 and len(possible_splits) > 0:
        num_splits = possible_splits.pop(0)
    split_size = int(num_patches / num_splits)

    # Margin (in pixels) added on every side of the window inside the glimpse so that
    # patches at the window border can still be extracted; presumably it must be
    # >= patch_size // 2 - TODO confirm against get_patch_at.
    offset = 64

    X, Y = np.meshgrid(range(offset, region_size + offset), range(offset, region_size + offset))
    # (row, col) glimpse coordinates passed to get_patch_at, in row-major order.
    coords = np.concatenate([np.expand_dims(Y.flatten(), axis=1), np.expand_dims(X.flatten(), axis=1)], axis=1)

    coords_placeholder = tf.placeholder(tf.float32, shape=[split_size, 2])

    source_landmark_placeholder = tf.placeholder(tf.float32, shape=[1, 2])
    target_landmark_placeholder = tf.placeholder(tf.float32, shape=[1, 2])

    glimpse_size = [region_size + 2 * offset, region_size + 2 * offset]
    source_image_region = tf.image.extract_glimpse(source_image, glimpse_size, source_landmark_placeholder, normalized=False, centered=False)
    target_image_region = tf.image.extract_glimpse(target_image, glimpse_size, target_landmark_placeholder, normalized=False, centered=False)

    source_patches_placeholder = tf.map_fn(lambda x: get_patch_at(x, source_image, args.patch_size), source_landmark_placeholder, parallel_iterations=8, back_prop=False)[0]
    target_patches_placeholder = tf.squeeze(tf.map_fn(lambda x: get_patch_at(x, target_image_region, args.patch_size), coords_placeholder, parallel_iterations=8, back_prop=False))

    # Using the Session itself as the context manager closes it even on error.
    with tf.Session(config=config) as sess:
        saver.restore(sess, latest_checkpoint)

        log_sigma_sq_tensor = sess.graph.get_tensor_by_name('imported/z_log_sigma_sq/BiasAdd:0')
        z_mean_tensor = sess.graph.get_tensor_by_name('imported/z_mean/BiasAdd:0')
        patch_input_tensor = sess.graph.get_tensor_by_name('imported/patch:0')

        # NOTE(review): exp(z_log_sigma_sq) is passed as ``scale_diag`` (a standard
        # deviation). If the outputs really are log-variances this should be
        # exp(0.5 * x). It is left unchanged to match existing results - confirm.
        source_patches_cov, source_patches_mean = tf.contrib.graph_editor.graph_replace([log_sigma_sq_tensor, z_mean_tensor], {patch_input_tensor: normalize(source_patches_placeholder)})
        source_patches_distribution = tf.contrib.distributions.MultivariateNormalDiag(source_patches_mean[:, args.stain_code_size:], tf.exp(source_patches_cov[:, args.stain_code_size:]))

        target_patches_cov, target_patches_mean = tf.contrib.graph_editor.graph_replace([log_sigma_sq_tensor, z_mean_tensor], {patch_input_tensor: normalize(target_patches_placeholder)})
        target_patches_distribution = tf.contrib.distributions.MultivariateNormalDiag(target_patches_mean[:, args.stain_code_size:], tf.exp(target_patches_cov[:, args.stain_code_size:]))

        similarities_skld = source_patches_distribution.kl_divergence(target_patches_distribution) + target_patches_distribution.kl_divergence(source_patches_distribution)
        similarities_bd = ctf.bhattacharyya_distance(source_patches_distribution, target_patches_distribution)
        similarities_sad = tf.reduce_sum(tf.abs(source_patches_placeholder - target_patches_placeholder), axis=[1, 2, 3])

        source_patches_grayscale = tf.image.rgb_to_grayscale(source_patches_placeholder)
        target_patches_grayscale = tf.image.rgb_to_grayscale(target_patches_placeholder)

        similarities_nmi = tf.map_fn(lambda x: nmi_tf(tf.squeeze(source_patches_grayscale), tf.squeeze(x), 20), target_patches_grayscale)

        methods = ["SKLD", "BD", "SAD", "MI"]
        # Flat (row-major) index of the window centre within a heatmap.
        center_flat = np.ravel_multi_index(tuple(region_center), (region_size, region_size))

        # newline='' is required by the csv module so rows are not doubled on Windows.
        with open(args.output + "_" + str(region_size) + ".csv", 'w', newline='') as outfile:
            writer = csv.DictWriter(outfile, ["method", "landmark", "min_idx", "min_idx_value", "rank", "landmark_value"])
            writer.writeheader()

            for k in range(len(source_landmarks)):
                # One channel per method; every cell is written by the batches below.
                heatmap_fused = np.empty((region_size, region_size, len(methods)))
                feed_dict = {source_landmark_placeholder: [source_landmarks[k, :]], target_landmark_placeholder: [target_landmarks[k, :]]}

                for i in range(num_splits):
                    start = i * split_size
                    batch_coords = coords[start:start + split_size, :]
                    feed_dict[coords_placeholder] = batch_coords

                    # Shape (split_size, len(methods)) after the transpose.
                    similarity_values = np.array(sess.run([similarities_skld, similarities_bd, similarities_sad, similarities_nmi], feed_dict=feed_dict)).transpose()
                    heatmap_fused[batch_coords[:, 0] - offset, batch_coords[:, 1] - offset] = similarity_values

                for c, method in enumerate(methods):
                    heatmap = heatmap_fused[:, :, c]
                    if method == "MI":
                        # Mutual information: larger means more similar.
                        best_flat = np.argmax(heatmap)
                        order = np.argsort(heatmap.flatten())[::-1]
                    else:
                        best_flat = np.argmin(heatmap)
                        order = np.argsort(heatmap.flatten())

                    # Plain ints keep the CSV cell as "(r, c)" under NumPy 2 as well.
                    min_idx = tuple(int(v) for v in np.unravel_index(best_flat, heatmap.shape))
                    landmark_value = heatmap[region_center[0], region_center[1]]
                    # Position of the landmark cell in the best-first ordering.
                    rank = int(np.flatnonzero(order == center_flat)[0])

                    writer.writerow({"method": method, "landmark": k, "min_idx": min_idx, "min_idx_value": heatmap[min_idx[0], min_idx[1]], "rank": rank, "landmark_value": landmark_value})
                outfile.flush()

                # Reports the values of the last method evaluated (MI).
                print(min_idx, rank)

    return 0
# Example #2
# 0
def main(argv):
    """Plot a dense symmetric-KL similarity heatmap between two images.

    Both images are loaded, subsampled, and cropped/padded to a common size: the
    element-wise maximum of their subsampled sizes. For every pixel position, the
    patch that ``get_patch_at`` extracts there from the source image is compared
    with the patch at the same position in the target image. The comparison uses
    the symmetric KL divergence of the model's diagonal Gaussians.

    The two images, the resulting heatmap and a channel-averaged squared pixel
    difference map are then shown with matplotlib.

    Args:
        argv: Unused; arguments are parsed from ``sys.argv`` by ``argparse``.

    Returns:
        0 on completion.
    """
    import math  # ``np.math`` was only an alias of ``math`` and was removed in NumPy 2.0.

    parser = argparse.ArgumentParser(
        description='Compute codes and reconstructions for image.')
    parser.add_argument('export_dir', type=str, help='Path to saved model.')
    parser.add_argument(
        'mean',
        type=str,
        help='Path to npy file holding mean for normalization.')
    parser.add_argument(
        'variance',
        type=str,
        help='Path to npy file holding variance for normalization.')
    parser.add_argument('source_filename',
                        type=str,
                        help='Image file from which to extract patch.')
    parser.add_argument('source_image_size',
                        type=int,
                        nargs=2,
                        help='Size of the input image, HW.')
    parser.add_argument('target_filename',
                        type=str,
                        help='Image file for which to create the heatmap.')
    parser.add_argument(
        'target_image_size',
        type=int,
        nargs=2,
        help='Size of the input image for which to create heatmap, HW.')
    parser.add_argument('patch_size', type=int, help='Size of image patch.')
    # --method and --rotate are parsed for CLI compatibility but not used below.
    parser.add_argument(
        '--method',
        dest='method',
        type=str,
        help=
        'Method to use to measure similarity, one of KLD, SKLD, BD, HD, SQHD.')
    parser.add_argument(
        '--stain_code_size',
        type=int,
        dest='stain_code_size',
        default=0,
        help=
        'Optional: Size of the stain code to use, which is skipped for similarity estimation'
    )
    parser.add_argument('--rotate',
                        type=float,
                        dest='angle',
                        default=0,
                        help='Optional: rotation angle to rotate target image')
    parser.add_argument('--subsampling_factor',
                        type=int,
                        dest='subsampling_factor',
                        default=1,
                        help='Factor to subsample source and target image.')
    args = parser.parse_args()

    mean = np.load(args.mean)
    variance = np.load(args.variance)
    stddev = [math.sqrt(x) for x in variance]

    def normalize(image, name=None, num_channels=3):
        """Standardize the first ``num_channels`` channels of a 4-D channels-last tensor.

        ``name`` is accepted for call compatibility but is not used.
        """
        channels = [
            tf.expand_dims(
                (image[:, :, :, channel] - mean[channel]) / stddev[channel],
                -1) for channel in range(num_channels)
        ]
        # Join on the last (channel) axis so the result is correct for any num_channels.
        return tf.concat(channels, -1)

    latest_checkpoint = tf.train.latest_checkpoint(args.export_dir)
    saver = tf.train.import_meta_graph(latest_checkpoint + '.meta',
                                       import_scope='imported')

    config = tf.ConfigProto()
    # NOTE(review): FULL_TRACE is requested, but no RunMetadata is passed to
    # sess.run, so the trace is never read here. Consider dropping it.
    run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
    run_options.report_tensor_allocations_upon_oom = True

    # Load both images and subsample them.
    source_image = ctfi.subsample(
        ctfi.load(args.source_filename,
                  height=args.source_image_size[0],
                  width=args.source_image_size[1]), args.subsampling_factor)
    args.source_image_size = list(
        map(lambda x: int(x / args.subsampling_factor),
            args.source_image_size))

    target_image = ctfi.subsample(
        ctfi.load(args.target_filename,
                  height=args.target_image_size[0],
                  width=args.target_image_size[1]), args.subsampling_factor)
    args.target_image_size = list(
        map(lambda x: int(x / args.subsampling_factor),
            args.target_image_size))

    # Common (H, W): the larger of the two subsampled sizes in each dimension.
    heatmap_size = list(
        map(lambda v: max(v[0], v[1]),
            zip(args.source_image_size, args.target_image_size)))

    source_image = tf.expand_dims(
        tf.image.resize_image_with_crop_or_pad(source_image, heatmap_size[0],
                                               heatmap_size[1]), 0)
    target_image = tf.expand_dims(
        tf.image.resize_image_with_crop_or_pad(target_image, heatmap_size[0],
                                               heatmap_size[1]), 0)

    num_patches = np.prod(heatmap_size, axis=0)

    # Take the first divisor (in the order cutil.get_divisors returns them) that
    # brings a batch down to at most 500 patches.
    possible_splits = cutil.get_divisors(num_patches)
    num_splits = possible_splits.pop(0)
    while num_patches / num_splits > 500 and len(possible_splits) > 0:
        num_splits = possible_splits.pop(0)
    split_size = int(num_patches / num_splits)

    # (row, col) of every pixel, in row-major order.
    X, Y = np.meshgrid(range(heatmap_size[1]), range(heatmap_size[0]))
    coords = np.concatenate([
        np.expand_dims(Y.flatten(), axis=1),
        np.expand_dims(X.flatten(), axis=1)
    ],
                            axis=1)

    coords_placeholder = tf.placeholder(tf.float32, shape=[split_size, 2])

    source_patches_placeholder = tf.squeeze(
        tf.map_fn(lambda x: get_patch_at(x, source_image, args.patch_size),
                  coords_placeholder,
                  parallel_iterations=8,
                  back_prop=False))
    target_patches_placeholder = tf.squeeze(
        tf.map_fn(lambda x: get_patch_at(x, target_image, args.patch_size),
                  coords_placeholder,
                  parallel_iterations=8,
                  back_prop=False))

    # Every cell is written by the batched loop below.
    heatmap = np.empty(heatmap_size)

    # Using the Session itself as the context manager closes it even on error.
    with tf.Session(graph=tf.get_default_graph(), config=config) as sess:
        log_sigma_sq_tensor = sess.graph.get_tensor_by_name(
            'imported/z_log_sigma_sq/BiasAdd:0')
        z_mean_tensor = sess.graph.get_tensor_by_name(
            'imported/z_mean/BiasAdd:0')
        patch_input_tensor = sess.graph.get_tensor_by_name('imported/patch:0')

        # NOTE(review): exp(z_log_sigma_sq) is passed as ``scale_diag`` (a standard
        # deviation). If the outputs really are log-variances this should be
        # exp(0.5 * x). It is left unchanged to match existing results - confirm.
        source_patches_cov, source_patches_mean = tf.contrib.graph_editor.graph_replace(
            [log_sigma_sq_tensor, z_mean_tensor],
            {patch_input_tensor: normalize(source_patches_placeholder)})
        source_patches_distribution = tf.contrib.distributions.MultivariateNormalDiag(
            source_patches_mean[:, args.stain_code_size:],
            tf.exp(source_patches_cov[:, args.stain_code_size:]))

        target_patches_cov, target_patches_mean = tf.contrib.graph_editor.graph_replace(
            [log_sigma_sq_tensor, z_mean_tensor],
            {patch_input_tensor: normalize(target_patches_placeholder)})
        target_patches_distribution = tf.contrib.distributions.MultivariateNormalDiag(
            target_patches_mean[:, args.stain_code_size:],
            tf.exp(target_patches_cov[:, args.stain_code_size:]))

        # Symmetric KL divergence; lower means more similar.
        similarity = source_patches_distribution.kl_divergence(
            target_patches_distribution
        ) + target_patches_distribution.kl_divergence(
            source_patches_distribution)

        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())
        saver.restore(sess, latest_checkpoint)

        for i in range(num_splits):
            start = i * split_size
            batch_coords = coords[start:start + split_size, :]
            similarity_values = sess.run(
                similarity,
                feed_dict={coords_placeholder: batch_coords},
                options=run_options)
            heatmap[batch_coords[:, 0], batch_coords[:, 1]] = similarity_values

        # Channel-averaged squared pixel difference between the aligned images.
        # All three tensors are fetched in a single run so the image pipeline runs once.
        squared_difference = tf.reduce_mean(
            tf.squared_difference(source_image, target_image), axis=3)
        source_np, target_np, squared_difference_np = sess.run(
            [source_image, target_image, squared_difference])
        heatmap_sqdiff = squared_difference_np[0]

        fig_images, ax_images = plt.subplots(1, 2)
        ax_images[0].imshow(source_np[0])
        ax_images[1].imshow(target_np[0])

        fig_similarities, ax_similarities = plt.subplots(1, 2)
        heatmap_skld_plot = ax_similarities[0].imshow(heatmap, cmap='plasma')
        heatmap_sqdiff_plot = ax_similarities[1].imshow(heatmap_sqdiff,
                                                        cmap='plasma')

        fig_similarities.colorbar(heatmap_skld_plot, ax=ax_similarities[0])
        fig_similarities.colorbar(heatmap_sqdiff_plot, ax=ax_similarities[1])

        plt.show()

    return 0