def save_images(self, directory: str,
                pairs: List[Tuple[List[str], List[lt.LabeledTensor]]]):
  """Evaluates labeled tensors and writes each result as an image file.

  The filename lists and tensor lists of all pairs are flattened in order,
  the tensors are passed to `self.eval` in a single call, and the i-th
  result is written under the i-th filename.

  Args:
    directory: Output directory, relative to the path held in the
      TEST_TMPDIR environment variable. Created (including any missing
      parents) if it does not exist yet.
    pairs: Sequence of (filenames, tensors) pairs. Filenames and tensors are
      matched up positionally after flattening; surplus entries on either
      side are silently ignored by `zip`.

  Raises:
    KeyError: If the TEST_TMPDIR environment variable is not set.
  """
  filenames = list(itertools.chain.from_iterable(p[0] for p in pairs))
  image_lts = list(itertools.chain.from_iterable(p[1] for p in pairs))

  images = self.eval(image_lts)

  path = os.path.join(os.environ['TEST_TMPDIR'], directory)
  if not gfile.Exists(path):
    # MakeDirs (rather than MkDir) so that a nested `directory` such as
    # 'a/b' works even when the intermediate directories are missing.
    gfile.MakeDirs(path)

  for filename, image in zip(filenames, images):
    util.write_image(os.path.join(path, filename), image)
def infer(
    gitapp: controller.GetInputTargetAndPredictedParameters,
    restore_directory: str,
    output_directory: str,
    extract_patch_size: int,
    stitch_stride: int,
    infer_size: int,
    channel_whitelist: Optional[List[str]],
    simplify_error_panels: bool,
    debug_dump_path: Optional[str] = '/home/ubuntu/resinfer.pk',
):
  """Runs inference on an image.

  The image is processed in blocks: for every block the model is run on a
  local crop, the per-block results are concatenated into full-size arrays,
  and those stitched arrays are then fed back into the graph to render the
  input and target error panels, which are written as PNGs into a
  subdirectory of `output_directory` named after the restored global step.

  Args:
    gitapp: GetInputTargetAndPredictedParameters.
    restore_directory: Where to restore the model from.
    output_directory: Where to write the generated images.
    extract_patch_size: The size of input to the model.
    stitch_stride: The stride size when running model inference.
      Equivalently, the output size of the model.
    infer_size: The number of simultaneous inferences to perform in the
      row and column dimensions.
      For example, if this is 8, inference will be performed in 8 x 8 blocks
      for a batch size of 64.
    channel_whitelist: If provided, only images for the given channels will
      be produced.
      This can be used to create simpler error panels.
    simplify_error_panels: Whether to create simplified error panels.
    debug_dump_path: If not None, the stitched input / prediction / target
      arrays are pickled to this path before the error panels are rendered.
      The default preserves the historical hard-coded location; pass None to
      skip the dump.

  Raises:
    ValueError: If
      1) The DataParameters don't contain a ReadPNGsParameters.
      2) The images must be larger than the input to the network.
      3) The graph must not contain queues.
  """
  import pickle  # Only needed for the optional debug dump below.

  rpp = gitapp.dp.io_parameters
  if not isinstance(rpp, data_provider.ReadPNGsParameters):
    # The argument is wrapped in a 1-tuple because `gitapp.dp` supports
    # `_replace` and may therefore be a namedtuple, which `%` would unpack.
    raise ValueError(
        'Data provider must contain a ReadPNGsParameter, but was: %r' %
        (gitapp.dp,))

  original_crop_size = rpp.crop_size
  image_num_rows, image_num_columns = util.image_size(rpp.directory)
  logging.info('Uncropped image size is %d x %d', image_num_rows,
               image_num_columns)

  # The usable region is limited by both the image and the configured crop.
  image_num_rows = min(image_num_rows, original_crop_size)
  if image_num_rows < extract_patch_size:
    raise ValueError(
        'Image is too small for inference to be performed: %d vs %d' %
        (image_num_rows, extract_patch_size))
  image_num_columns = min(image_num_columns, original_crop_size)
  if image_num_columns < extract_patch_size:
    raise ValueError(
        'Image is too small for inference to be performed: %d vs %d' %
        (image_num_columns, extract_patch_size))
  logging.info('After cropping, input image size is (%d, %d)', image_num_rows,
               image_num_columns)

  # Each inference call covers infer_size x infer_size patches, i.e. it
  # advances by stitch_stride * infer_size pixels in each dimension.
  num_row_inferences = (image_num_rows - extract_patch_size) // (
      stitch_stride * infer_size)
  num_column_inferences = (image_num_columns - extract_patch_size) // (
      stitch_stride * infer_size)
  logging.info('Running %d x %d inferences', num_row_inferences,
               num_column_inferences)
  if num_row_inferences < 1 or num_column_inferences < 1:
    # Without this check the failure only surfaces much later, as an opaque
    # ValueError from np.concatenate on an empty list.
    raise ValueError(
        'Image is too small for inference to be performed: only %d x %d '
        'inference blocks fit' % (num_row_inferences, num_column_inferences))

  num_output_rows = (num_row_inferences * infer_size * stitch_stride)
  num_output_columns = (num_column_inferences * infer_size * stitch_stride)
  logging.info('Output image size is (%d, %d)', num_output_rows,
               num_output_columns)

  g = tf.Graph()
  with g.as_default():
    row_start = tf.placeholder(dtype=np.int32, shape=[])
    column_start = tf.placeholder(dtype=np.int32, shape=[])
    # Replace the parameters with a new set, which will cause the network to
    # run inference in just a local region.
    gitapp = gitapp._replace(
        dp=gitapp.dp._replace(
            io_parameters=rpp._replace(
                row_start=row_start,
                column_start=column_start,
                crop_size=(infer_size - 1) * stitch_stride +
                extract_patch_size,
            )))

    visualization_lts = controller.setup_stitch(gitapp)

    def get_statistics(tensor):
      # Collapse every axis except the last into a single 'batch' axis,
      # compute the distribution statistics, then restore the original axes.
      rc = lt.ReshapeCoder(list(tensor.axes.keys())[:-1], ['batch'])
      return rc.decode(ops.distribution_statistics(rc.encode(tensor)))

    visualize_input_lt = visualization_lts['input']
    visualize_predict_input_lt = get_statistics(
        visualization_lts['predict_input'])
    visualize_target_lt = visualization_lts['target']
    visualize_predict_target_lt = get_statistics(
        visualization_lts['predict_target'])

    # Full-size placeholders; these are fed with the stitched numpy arrays
    # once all per-block inferences have been run.
    input_lt = lt.LabeledTensor(
        tf.placeholder(
            dtype=np.float32,
            shape=[
                1, num_output_rows, num_output_columns,
                len(gitapp.dp.input_z_values), 1, 2
            ]),
        axes=[
            'batch',
            'row',
            'column',
            ('z', gitapp.dp.input_z_values),
            ('channel', ['TRANSMISSION']),
            ('mask', [False, True]),
        ])
    predict_input_lt = lt.LabeledTensor(
        tf.placeholder(
            dtype=np.float32,
            shape=[
                1,
                num_output_rows,
                num_output_columns,
                len(gitapp.dp.input_z_values),
                1,
                len(visualize_predict_input_lt.axes['statistic']),
            ]),
        axes=[
            'batch',
            'row',
            'column',
            ('z', gitapp.dp.input_z_values),
            ('channel', ['TRANSMISSION']),
            visualize_predict_input_lt.axes['statistic'],
        ])
    input_error_panel_lt = visualize.error_panel_from_statistics(
        input_lt, predict_input_lt, simplify_error_panels)

    target_lt = lt.LabeledTensor(
        tf.placeholder(
            dtype=np.float32,
            shape=[
                1, num_output_rows, num_output_columns,
                len(gitapp.dp.target_z_values),
                len(gitapp.dp.target_channel_values) + 1, 2
            ]),
        axes=[
            'batch',
            'row',
            'column',
            ('z', gitapp.dp.target_z_values),
            ('channel',
             gitapp.dp.target_channel_values + ['NEURITE_CONFOCAL']),
            ('mask', [False, True]),
        ])
    predict_target_lt = lt.LabeledTensor(
        tf.placeholder(
            dtype=np.float32,
            shape=[
                1,
                num_output_rows,
                num_output_columns,
                len(gitapp.dp.target_z_values),
                len(gitapp.dp.target_channel_values) + 1,
                len(visualize_predict_target_lt.axes['statistic']),
            ]),
        axes=[
            'batch',
            'row',
            'column',
            ('z', gitapp.dp.target_z_values),
            ('channel',
             gitapp.dp.target_channel_values + ['NEURITE_CONFOCAL']),
            visualize_predict_target_lt.axes['statistic'],
        ])

    logging.info('input_lt: %r', input_lt)
    logging.info('predict_input_lt: %r', predict_input_lt)
    logging.info('target_lt: %r', target_lt)
    logging.info('predict_target_lt: %r', predict_target_lt)

    def select_channels(tensor):
      if channel_whitelist is None:
        return tensor
      return lt.select(tensor, {'channel': channel_whitelist})

    target_error_panel_lt = visualize.error_panel_from_statistics(
        select_channels(target_lt), select_channels(predict_target_lt),
        simplify_error_panels)

    # There shouldn't be any queues in this configuration.
    queue_runners = g.get_collection(tf.GraphKeys.QUEUE_RUNNERS)
    if queue_runners:
      raise ValueError(
          'Graph must not have queues, but had: %r' % (queue_runners,))

    logging.info('Attempting to find restore checkpoint in %s',
                 restore_directory)
    init_fn = util.restore_model(
        restore_directory, restore_logits=True, restore_global_step=True)

    with tf.Session() as sess:
      logging.info('Generating images')
      init_fn(sess)

      input_rows = []
      predict_input_rows = []
      target_rows = []
      predict_target_rows = []
      for infer_row in range(num_row_inferences):
        input_row = []
        predict_input_row = []
        target_row = []
        predict_target_row = []
        for infer_column in range(num_column_inferences):
          rs = infer_row * infer_size * stitch_stride
          cs = infer_column * infer_size * stitch_stride
          logging.info('Running inference at offset: (%d, %d)', rs, cs)
          [inpt, predict_input, target, predict_target] = sess.run(
              [
                  visualize_input_lt,
                  visualize_predict_input_lt,
                  visualize_target_lt,
                  visualize_predict_target_lt,
              ],
              feed_dict={
                  row_start: rs,
                  column_start: cs
              })

          input_row.append(inpt)
          predict_input_row.append(predict_input)
          target_row.append(target)
          predict_target_row.append(predict_target)

        # Axis 2 is 'column': join the blocks of one block-row side by side.
        input_rows.append(np.concatenate(input_row, axis=2))
        predict_input_rows.append(np.concatenate(predict_input_row, axis=2))
        target_rows.append(np.concatenate(target_row, axis=2))
        predict_target_rows.append(
            np.concatenate(predict_target_row, axis=2))

      logging.info('Stitching')
      # Axis 1 is 'row': stack the block-rows on top of each other.
      stitched_input = np.concatenate(input_rows, axis=1)
      stitched_predict_input = np.concatenate(predict_input_rows, axis=1)
      stitched_target = np.concatenate(target_rows, axis=1)
      stitched_predict_target = np.concatenate(predict_target_rows, axis=1)

      logging.info('Creating error panels')
      if debug_dump_path is not None:
        # Debug aid: persist the raw stitched arrays. The context manager
        # guarantees the file is flushed and closed even if pickling fails.
        with open(debug_dump_path, 'wb') as dump_file:
          pickle.dump(
              {
                  'input_lt': stitched_input,
                  'predict_input_lt': stitched_predict_input,
                  'target_lt': stitched_target,
                  'predict_target_lt': stitched_predict_target,
              }, dump_file)

      [input_error_panel, target_error_panel, global_step] = sess.run(
          [
              input_error_panel_lt, target_error_panel_lt,
              tf.train.get_global_step()
          ],
          feed_dict={
              input_lt: stitched_input,
              predict_input_lt: stitched_predict_input,
              target_lt: stitched_target,
              predict_target_lt: stitched_predict_target,
          })

      # Results go into a per-checkpoint subdirectory named after the
      # zero-padded global step.
      output_directory = os.path.join(output_directory, '%.8d' % global_step)
      if not gfile.Exists(output_directory):
        gfile.MakeDirs(output_directory)

      util.write_image(
          os.path.join(output_directory, 'input_error_panel.png'),
          input_error_panel[0, :, :, :])
      util.write_image(
          os.path.join(output_directory, 'target_error_panel.png'),
          target_error_panel[0, :, :, :])

      logging.info('Done generating images')