# Example 1
def export(gitapp: controller.GetInputTargetAndPredictedParameters):
    """Exports the stitched model graph for serving.

    Builds the stitch graph, then restores the latest checkpoint from
    FLAGS.restore_directory and writes a servable export to
    FLAGS.export_directory. Only the stitch metric is supported.

    Args:
        gitapp: Parameters describing how to get inputs, targets, and
            predictions.
    """
    g = tf.Graph()
    with g.as_default():
        # Export is only defined for the stitch configuration.
        assert FLAGS.metric == METRIC_STITCH

        controller.setup_stitch(gitapp)

        log_entry_points(g)

        # Expose every entry-point operation in the serving signature.
        # (Dict comprehension instead of dict() over a list comprehension.)
        signature_map = {
            op.name: op
            for op in g.get_operations() if 'entry_point' in op.name
        }

        logging.info('Exporting checkpoint at %s to %s',
                     FLAGS.restore_directory, FLAGS.export_directory)
        slim.export_for_serving(g,
                                checkpoint_dir=FLAGS.restore_directory,
                                export_dir=FLAGS.export_directory,
                                generic_signature_tensor_map=signature_map)
# Example 2
def eval_stitch(gitapp: controller.GetInputTargetAndPredictedParameters):
    """Runs the continuous evaluation loop for the stitch configuration.

    Args:
        gitapp: Parameters describing how to get inputs, targets, and
            predictions.
    """
    graph = tf.Graph()
    with graph.as_default():
        controller.setup_stitch(gitapp)

        all_summaries = tf.get_collection(tf.GraphKeys.SUMMARIES)

        def find_summary(tag):
            # First summary op whose name contains the given tag.
            return next(s for s in all_summaries if tag in s.name)

        input_summary_op = find_summary('input_error_panel')
        target_summary_op = find_summary('target_error_panel')

        log_entry_points(graph)

        # Merge the summaries to keep the graph state in sync.
        merged_summary_op = tf.summary.merge(
            [input_summary_op, target_summary_op])

        slim.evaluation.evaluation_loop(
            master=FLAGS.master,
            num_evals=0,
            checkpoint_dir=train_directory(),
            logdir=output_directory(),
            summary_op=merged_summary_op,
            eval_interval_secs=FLAGS.eval_interval_secs)
# Example 3
    def setUp(self):
        """Builds the stitch setup used by the tests in this case."""
        super(SetupStitchTest, self).setUp()

        # Rebuild the data parameters immutably with an enlarged crop size
        # (the io_parameters namedtuple cannot be mutated in place).
        io_params = self.dp.io_parameters._replace(crop_size=110 + 128)
        dp = self.dp._replace(io_parameters=io_params)
        is_train = True

        gitapp = controller.GetInputTargetAndPredictedParameters(
            dp, self.ap, 110, self.stride, self.stitch_patch_size, self.bp,
            self.core_model, self.add_head, self.shuffle, self.num_classes,
            util.softmax_cross_entropy, is_train)

        self.image_lt_dict = controller.setup_stitch(gitapp)
# Example 4
def infer(
    gitapp: controller.GetInputTargetAndPredictedParameters,
    restore_directory: str,
    output_directory: str,
    extract_patch_size: int,
    stitch_stride: int,
    infer_size: int,
    channel_whitelist: Optional[List[str]],
    simplify_error_panels: bool,
):
  """Runs inference on an image.

  Args:
    gitapp: GetInputTargetAndPredictedParameters.
    restore_directory: Where to restore the model from.
    output_directory: Where to write the generated images.
    extract_patch_size: The size of input to the model.
    stitch_stride: The stride size when running model inference.
      Equivalently, the output size of the model.
    infer_size: The number of simultaneous inferences to perform in the
      row and column dimensions.
      For example, if this is 8, inference will be performed in 8 x 8 blocks
      for a batch size of 64.
    channel_whitelist: If provided, only images for the given channels will
      be produced.
      This can be used to create simpler error panels.
    simplify_error_panels: Whether to create simplified error panels.

  Raises:
    ValueError: If
      1) The DataParameters don't contain a ReadPNGsParameters.
      2) The images must be larger than the input to the network.
      3) The graph must not contain queues.
  """
  rpp = gitapp.dp.io_parameters
  if not isinstance(rpp, data_provider.ReadPNGsParameters):
    # Format the message explicitly; passing args to ValueError in
    # logging-style never interpolates them into the message.
    raise ValueError(
        'Data provider must contain a ReadPNGsParameter, but was: %r' %
        (gitapp.dp,))

  original_crop_size = rpp.crop_size
  image_num_rows, image_num_columns = util.image_size(rpp.directory)
  logging.info('Uncropped image size is %d x %d', image_num_rows,
               image_num_columns)
  # The effective image is the uncropped image clipped to the crop size.
  image_num_rows = min(image_num_rows, original_crop_size)
  if image_num_rows < extract_patch_size:
    raise ValueError(
        'Image is too small for inference to be performed: %d vs %d' %
        (image_num_rows, extract_patch_size))
  image_num_columns = min(image_num_columns, original_crop_size)
  if image_num_columns < extract_patch_size:
    raise ValueError(
        'Image is too small for inference to be performed: %d vs %d' %
        (image_num_columns, extract_patch_size))
  logging.info('After cropping, input image size is (%d, %d)', image_num_rows,
               image_num_columns)

  # Number of whole inference blocks that fit in each dimension; partial
  # blocks at the edges are dropped.
  num_row_inferences = (image_num_rows - extract_patch_size) // (
      stitch_stride * infer_size)
  num_column_inferences = (image_num_columns - extract_patch_size) // (
      stitch_stride * infer_size)
  logging.info('Running %d x %d inferences', num_row_inferences,
               num_column_inferences)
  num_output_rows = (num_row_inferences * infer_size * stitch_stride)
  num_output_columns = (num_column_inferences * infer_size * stitch_stride)
  logging.info('Output image size is (%d, %d)', num_output_rows,
               num_output_columns)

  g = tf.Graph()
  with g.as_default():
    # Offsets fed per inference step to select the local region to run on.
    row_start = tf.placeholder(dtype=np.int32, shape=[])
    column_start = tf.placeholder(dtype=np.int32, shape=[])
    # Replace the parameters with a new set, which will cause the network to
    # run inference in just a local region.
    gitapp = gitapp._replace(
        dp=gitapp.dp._replace(
            io_parameters=rpp._replace(
                row_start=row_start,
                column_start=column_start,
                crop_size=(infer_size - 1) * stitch_stride + extract_patch_size,
            )))

    visualization_lts = controller.setup_stitch(gitapp)

    def get_statistics(tensor):
      # Fold all but the last axis into 'batch', compute per-element
      # distribution statistics, then restore the original axes.
      rc = lt.ReshapeCoder(list(tensor.axes.keys())[:-1], ['batch'])
      return rc.decode(ops.distribution_statistics(rc.encode(tensor)))

    visualize_input_lt = visualization_lts['input']
    visualize_predict_input_lt = get_statistics(
        visualization_lts['predict_input'])
    visualize_target_lt = visualization_lts['target']
    visualize_predict_target_lt = get_statistics(
        visualization_lts['predict_target'])

    # Placeholders sized for the fully-stitched outputs; the per-block
    # results are concatenated on the host and fed back through these.
    input_lt = lt.LabeledTensor(
        tf.placeholder(
            dtype=np.float32,
            shape=[
                1, num_output_rows, num_output_columns,
                len(gitapp.dp.input_z_values), 1, 2
            ]),
        axes=[
            'batch',
            'row',
            'column',
            ('z', gitapp.dp.input_z_values),
            ('channel', ['TRANSMISSION']),
            ('mask', [False, True]),
        ])
    predict_input_lt = lt.LabeledTensor(
        tf.placeholder(
            dtype=np.float32,
            shape=[
                1,
                num_output_rows,
                num_output_columns,
                len(gitapp.dp.input_z_values),
                1,
                len(visualize_predict_input_lt.axes['statistic']),
            ]),
        axes=[
            'batch',
            'row',
            'column',
            ('z', gitapp.dp.input_z_values),
            ('channel', ['TRANSMISSION']),
            visualize_predict_input_lt.axes['statistic'],
        ])
    input_error_panel_lt = visualize.error_panel_from_statistics(
        input_lt, predict_input_lt, simplify_error_panels)

    target_lt = lt.LabeledTensor(
        tf.placeholder(
            dtype=np.float32,
            shape=[
                1, num_output_rows, num_output_columns,
                len(gitapp.dp.target_z_values),
                len(gitapp.dp.target_channel_values) + 1, 2
            ]),
        axes=[
            'batch',
            'row',
            'column',
            ('z', gitapp.dp.target_z_values),
            ('channel', gitapp.dp.target_channel_values + ['NEURITE_CONFOCAL']),
            ('mask', [False, True]),
        ])
    predict_target_lt = lt.LabeledTensor(
        tf.placeholder(
            dtype=np.float32,
            shape=[
                1,
                num_output_rows,
                num_output_columns,
                len(gitapp.dp.target_z_values),
                len(gitapp.dp.target_channel_values) + 1,
                len(visualize_predict_target_lt.axes['statistic']),
            ]),
        axes=[
            'batch',
            'row',
            'column',
            ('z', gitapp.dp.target_z_values),
            ('channel', gitapp.dp.target_channel_values + ['NEURITE_CONFOCAL']),
            visualize_predict_target_lt.axes['statistic'],
        ])

    logging.info('input_lt: %r', input_lt)
    logging.info('predict_input_lt: %r', predict_input_lt)
    logging.info('target_lt: %r', target_lt)
    logging.info('predict_target_lt: %r', predict_target_lt)

    def select_channels(tensor):
      # Restrict to the whitelisted channels, if a whitelist was given.
      if channel_whitelist is not None:
        return lt.select(tensor, {'channel': channel_whitelist})
      else:
        return tensor

    target_error_panel_lt = visualize.error_panel_from_statistics(
        select_channels(target_lt), select_channels(predict_target_lt),
        simplify_error_panels)

    # There shouldn't be any queues in this configuration.
    queue_runners = g.get_collection(tf.GraphKeys.QUEUE_RUNNERS)
    if queue_runners:
      raise ValueError(
          'Graph must not have queues, but had: %r' % (queue_runners,))

    logging.info('Attempting to find restore checkpoint in %s',
                 restore_directory)
    init_fn = util.restore_model(
        restore_directory, restore_logits=True, restore_global_step=True)

    with tf.Session() as sess:
      logging.info('Generating images')
      init_fn(sess)

      input_rows = []
      predict_input_rows = []
      target_rows = []
      predict_target_rows = []
      for infer_row in range(num_row_inferences):
        input_row = []
        predict_input_row = []
        target_row = []
        predict_target_row = []
        for infer_column in range(num_column_inferences):
          rs = infer_row * infer_size * stitch_stride
          cs = infer_column * infer_size * stitch_stride
          logging.info('Running inference at offset: (%d, %d)', rs, cs)
          [inpt, predict_input, target, predict_target] = sess.run(
              [
                  visualize_input_lt,
                  visualize_predict_input_lt,
                  visualize_target_lt,
                  visualize_predict_target_lt,
              ],
              feed_dict={
                  row_start: rs,
                  column_start: cs
              })

          input_row.append(inpt)
          predict_input_row.append(predict_input)
          target_row.append(target)
          predict_target_row.append(predict_target)
        # Axis 2 is 'column'; stitch each row of blocks horizontally.
        input_rows.append(np.concatenate(input_row, axis=2))
        predict_input_rows.append(np.concatenate(predict_input_row, axis=2))
        target_rows.append(np.concatenate(target_row, axis=2))
        predict_target_rows.append(np.concatenate(predict_target_row, axis=2))

      logging.info('Stitching')
      # Axis 1 is 'row'; stack the stitched rows vertically.
      stitched_input = np.concatenate(input_rows, axis=1)
      stitched_predict_input = np.concatenate(predict_input_rows, axis=1)
      stitched_target = np.concatenate(target_rows, axis=1)
      stitched_predict_target = np.concatenate(predict_target_rows, axis=1)

      logging.info('Creating error panels')
      # NOTE(review): this dump of the raw stitched arrays to a hardcoded
      # home-directory path looks like temporary debug instrumentation --
      # consider removing it or making the path configurable.
      import pickle
      debug_arrays = {
          'input_lt': stitched_input,
          'predict_input_lt': stitched_predict_input,
          'target_lt': stitched_target,
          'predict_target_lt': stitched_predict_target,
      }
      # Use a context manager so the file handle is always closed (the
      # original opened the file and never closed it).
      with open('/home/ubuntu/resinfer.pk', 'wb') as f:
        pickle.dump(debug_arrays, f)
      [input_error_panel, target_error_panel, global_step] = sess.run(
          [
              input_error_panel_lt, target_error_panel_lt,
              tf.train.get_global_step()
          ],
          feed_dict={
              input_lt: stitched_input,
              predict_input_lt: stitched_predict_input,
              target_lt: stitched_target,
              predict_target_lt: stitched_predict_target,
          })

      # Write outputs into a per-step subdirectory, e.g. '00001234'.
      output_directory = os.path.join(output_directory, '%.8d' % global_step)
      if not gfile.Exists(output_directory):
        gfile.MakeDirs(output_directory)

      util.write_image(
          os.path.join(output_directory, 'input_error_panel.png'),
          input_error_panel[0, :, :, :])
      util.write_image(
          os.path.join(output_directory, 'target_error_panel.png'),
          target_error_panel[0, :, :, :])

      logging.info('Done generating images')