Code Example #1
def _extract_groundtruth_tensors(create_input_dict_fn):
    input_dict = create_input_dict_fn()
    prefetch_queue = prefetcher.prefetch(input_dict, capacity=500)
    input_dict = prefetch_queue.dequeue()
    original_image = tf.expand_dims(input_dict[fields.InputDataFields.image],
                                    0)

    tensor_dict = {'image_id': input_dict[fields.InputDataFields.source_id]}

    normalized_gt_boxlist = box_list.BoxList(
        input_dict[fields.InputDataFields.groundtruth_boxes])
    gt_boxlist = box_list_ops.scale(normalized_gt_boxlist,
                                    tf.shape(original_image)[1],
                                    tf.shape(original_image)[2])
    groundtruth_boxes = gt_boxlist.get()
    groundtruth_classes = input_dict[
        fields.InputDataFields.groundtruth_classes]
    tensor_dict['groundtruth_boxes'] = groundtruth_boxes
    tensor_dict['groundtruth_classes'] = groundtruth_classes
    tensor_dict['area'] = input_dict[fields.InputDataFields.groundtruth_area]
    tensor_dict['difficult'] = input_dict[
        fields.InputDataFields.groundtruth_difficult]

    # subset annotations
    if fields.InputDataFields.groundtruth_subset in input_dict:
        tensor_dict['groundtruth_subset'] = input_dict[
            fields.InputDataFields.groundtruth_subset]

    return tensor_dict
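In all of these snippets, box_list_ops.scale(boxlist, y_scale, x_scale) multiplies each box's y-coordinates by y_scale and its x-coordinates by x_scale; passing the image height and width, as above, converts the normalized [ymin, xmin, ymax, xmax] groundtruth boxes into absolute pixel coordinates. A minimal standalone sketch of the same arithmetic without BoxList (function name hypothetical):

import tensorflow as tf

def scale_boxes(boxes, y_scale, x_scale):
    # boxes: [N, 4] tensor of [ymin, xmin, ymax, xmax] rows.
    ymin, xmin, ymax, xmax = tf.split(boxes, num_or_size_splits=4, axis=1)
    return tf.concat([y_scale * ymin, x_scale * xmin,
                      y_scale * ymax, x_scale * xmax], axis=1)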
Code Example #2
    def _normalize_boxlist(args):
        boxes, height, width = args
        boxes = box_list_ops.scale(boxes, stride, stride)
        boxes = box_list_ops.to_normalized_coordinates(boxes, height, width)
        boxes = box_list_ops.clip_to_window(boxes, [0., 0., 1., 1.],
                                            filter_nonoverlapping=False)
        return boxes
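Here stride (and the helper itself) live in an enclosing scope: the boxes arrive in feature-map coordinates, are multiplied up to absolute image coordinates, normalized by the true image size, and clipped to the unit window. A hedged usage sketch, assuming a stride of 4 and the helper in scope:

stride = 4  # assumed output stride of the feature extractor
feature_boxes = box_list.BoxList(tf.constant([[2., 3., 10., 12.]]))
normalized = _normalize_boxlist((feature_boxes, 64., 128.))
# normalized.get() now holds boxes clipped to [0, 1] coordinates.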
Code Example #3
    def graph_fn():
      corners = tf.constant([[0, 0, 100, 200], [50, 120, 100, 140]],
                            dtype=tf.float32)
      boxes = box_list.BoxList(corners)
      boxes.add_field('extra_data', tf.constant([[1], [2]]))

      y_scale = tf.constant(1.0/100)
      x_scale = tf.constant(1.0/200)
      scaled_boxes = box_list_ops.scale(boxes, y_scale, x_scale)
      return scaled_boxes.get(), scaled_boxes.get_field('extra_data')
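This graph_fn style is the TF2-compatible counterpart of the session-based test in Example #5 below; a test harness executes the function and compares outputs. The expected values follow directly from the scale factors: [0, 0, 100, 200] * (1/100, 1/200) gives [0, 0, 1, 1], and [50, 120, 100, 140] gives [0.5, 0.6, 1.0, 0.7]. A hedged sketch of checking it in eager mode:

boxes_out, extra_out = graph_fn()
# boxes_out -> [[0, 0, 1, 1], [0.5, 0.6, 1.0, 0.7]]
# extra_out -> [[1], [2]]  (extra fields pass through scale unchanged)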
Code Example #4
 def transform_boxes(elems):
     boxes_per_image, true_image_shape = elems
     blist = box_list.BoxList(boxes_per_image)
     # First transform boxes from image space to resized-image space, since
     # the resized images may contain padding.
     blist = box_list_ops.scale(
         blist, true_image_shape[0] / resized_image_height,
         true_image_shape[1] / resized_image_width)
     # Then transform boxes from resized image space (normalized) to the
     # feature map space (absolute).
     blist = box_list_ops.to_absolute_coordinates(blist,
                                                  height,
                                                  width,
                                                  check_range=False)
     return blist.get()
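The elems tuple signature suggests this helper is applied once per image with tf.map_fn; a hedged sketch of such a call, where boxes is a [batch, N, 4] tensor and true_image_shapes holds each image's unpadded shape (both names are assumptions, not from the source):

absolute_boxes = tf.map_fn(
    transform_boxes,
    elems=(boxes, true_image_shapes),
    dtype=tf.float32)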
Code Example #5
    def test_scale(self):
        corners = tf.constant([[0, 0, 100, 200], [50, 120, 100, 140]],
                              dtype=tf.float32)
        boxes = box_list.BoxList(corners)
        boxes.add_field('extra_data', tf.constant([[1], [2]]))

        y_scale = tf.constant(1.0 / 100)
        x_scale = tf.constant(1.0 / 200)
        scaled_boxes = box_list_ops.scale(boxes, y_scale, x_scale)
        exp_output = [[0, 0, 1, 1], [0.5, 0.6, 1.0, 0.7]]
        with self.test_session() as sess:
            scaled_corners_out = sess.run(scaled_boxes.get())
            self.assertAllClose(scaled_corners_out, exp_output)
            extra_data_out = sess.run(scaled_boxes.get_field('extra_data'))
            self.assertAllEqual(extra_data_out, [[1], [2]])
Code Example #6
def _extract_prediction_tensors(model,
                                create_input_dict_fn,
                                ignore_groundtruth=False):
  """Restores the model in a tensorflow session.

  Args:
    model: model to perform predictions with.
    create_input_dict_fn: function to create input tensor dictionaries.
    ignore_groundtruth: whether groundtruth should be ignored.

  Returns:
    tensor_dict: A tensor dictionary with evaluations.
  """
  input_dict = create_input_dict_fn()
  prefetch_queue = prefetcher.prefetch(input_dict, capacity=500)
  input_dict = prefetch_queue.dequeue()
  original_image = tf.expand_dims(input_dict[fields.InputDataFields.image], 0)
  preprocessed_image = model.preprocess(tf.to_float(original_image))
  prediction_dict = model.predict(preprocessed_image)
  detections = model.postprocess(prediction_dict)

  original_image_shape = tf.shape(original_image)
  absolute_detection_boxlist = box_list_ops.to_absolute_coordinates(
      box_list.BoxList(tf.squeeze(detections['detection_boxes'], axis=0)),
      original_image_shape[1], original_image_shape[2])
  label_id_offset = 1
  tensor_dict = {
      'original_image': original_image,
      'image_id': input_dict[fields.InputDataFields.source_id],
      'detection_boxes': absolute_detection_boxlist.get(),
      'detection_scores': tf.squeeze(detections['detection_scores'], axis=0),
      'detection_classes': (
          tf.squeeze(detections['detection_classes'], axis=0) +
          label_id_offset),
  }
  if 'detection_masks' in detections:
    detection_masks = tf.squeeze(detections['detection_masks'],
                                 axis=0)
    detection_boxes = tf.squeeze(detections['detection_boxes'],
                                 axis=0)
    # TODO: This should be done in model's postprocess function ideally.
    detection_masks_reframed = ops.reframe_box_masks_to_image_masks(
        detection_masks,
        detection_boxes,
        original_image_shape[1], original_image_shape[2])
    detection_masks_reframed = tf.to_float(tf.greater(detection_masks_reframed,
                                                      0.5))

    tensor_dict['detection_masks'] = detection_masks_reframed
  # load groundtruth fields into tensor_dict
  if not ignore_groundtruth:
    normalized_gt_boxlist = box_list.BoxList(
        input_dict[fields.InputDataFields.groundtruth_boxes])
    gt_boxlist = box_list_ops.scale(normalized_gt_boxlist,
                                    tf.shape(original_image)[1],
                                    tf.shape(original_image)[2])
    groundtruth_boxes = gt_boxlist.get()
    groundtruth_classes = input_dict[fields.InputDataFields.groundtruth_classes]
    tensor_dict['groundtruth_boxes'] = groundtruth_boxes
    tensor_dict['groundtruth_classes'] = groundtruth_classes
    tensor_dict['area'] = input_dict[fields.InputDataFields.groundtruth_area]
    tensor_dict['is_crowd'] = input_dict[
        fields.InputDataFields.groundtruth_is_crowd]
    tensor_dict['difficult'] = input_dict[
        fields.InputDataFields.groundtruth_difficult]
    if 'detection_masks' in tensor_dict:
      tensor_dict['groundtruth_instance_masks'] = input_dict[
          fields.InputDataFields.groundtruth_instance_masks]
  return tensor_dict
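Downstream, a dictionary like this is evaluated once per example in a TF1 evaluation loop; a hedged sketch (checkpoint restore elided, num_eval_examples assumed):

with tf.Session() as sess:
  tf.train.start_queue_runners(sess=sess)  # feeds the prefetch queue
  for _ in range(num_eval_examples):
    result_dict = sess.run(tensor_dict)
    # result_dict maps 'detection_boxes', 'detection_scores',
    # 'groundtruth_boxes', ... to NumPy arrays for one image.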
Code Example #7
File: evaluator.py  Project: zouwen198317/motion-rcnn
def _extract_prediction_tensors(model,
                                create_input_dict_fn,
                                ignore_groundtruth=False):
    """Restores the model in a tensorflow session.

  Args:
    model: model to perform predictions with.
    create_input_dict_fn: function to create input tensor dictionaries.
    ignore_groundtruth: whether groundtruth should be ignored.

  Returns:
    tensor_dict: A tensor dictionary with evaluations.
  """
    input_dict = create_input_dict_fn()
    prefetch_queue = prefetcher.prefetch(input_dict, capacity=500)  # TODO
    input_dict = prefetch_queue.dequeue()
    original_image = tf.expand_dims(input_dict[fields.InputDataFields.image],
                                    0)

    next_image = input_dict.get(fields.InputDataFields.next_image)
    image_input = tf.to_float(original_image)
    if next_image is not None:
        next_image = tf.to_float(next_image)
        image_input = tf.concat(
            [image_input, tf.expand_dims(next_image, 0)], 3)
        depth = input_dict.get(fields.InputDataFields.groundtruth_depth)
        next_depth = input_dict.get(
            fields.InputDataFields.groundtruth_next_depth)
        image_input.set_shape([1, None, None, 6])
        if depth is not None and next_depth is not None:
            camera_intrinsics = input_dict[
                fields.InputDataFields.camera_intrinsics]
            coords = motion_util.get_3D_coords(tf.expand_dims(depth, 0),
                                               camera_intrinsics)
            next_coords = motion_util.get_3D_coords(
                tf.expand_dims(next_depth, 0), camera_intrinsics)
            image_input = tf.concat([image_input, coords, next_coords], 3)
            image_input.set_shape([1, None, None, 12])

    preprocessed_image = model.preprocess(image_input)
    prediction_dict = model.predict(preprocessed_image)
    detections = model.postprocess(prediction_dict)

    original_image_shape = tf.shape(original_image)
    absolute_detection_boxlist = box_list_ops.to_absolute_coordinates(
        box_list.BoxList(tf.squeeze(detections['detection_boxes'], axis=0)),
        original_image_shape[1], original_image_shape[2])
    label_id_offset = 1
    tensor_dict = {
        'original_image': original_image,
        'image_id': input_dict[fields.InputDataFields.source_id],
        'detection_boxes': absolute_detection_boxlist.get(),
        'detection_scores': tf.squeeze(detections['detection_scores'], axis=0),
        'detection_classes': (
            tf.squeeze(detections['detection_classes'], axis=0) +
            label_id_offset),
    }
    if 'detection_masks' in detections:
        detection_masks = tf.squeeze(detections['detection_masks'], axis=0)
        detection_boxes = tf.squeeze(detections['detection_boxes'], axis=0)
        # TODO: This should be done in model's postprocess function ideally.
        detection_masks_reframed = ops.reframe_box_masks_to_image_masks(
            detection_masks, detection_boxes, original_image_shape[1],
            original_image_shape[2])
        detection_masks_reframed = tf.to_float(
            tf.greater(detection_masks_reframed, 0.5))

        tensor_dict['detection_masks'] = detection_masks_reframed

    if 'detection_motions' in detections:
        detection_motions = tf.squeeze(detections['detection_motions'], axis=0)
        detection_motions_with_matrices = (
            motion_util.postprocess_detection_motions(detection_motions,
                                                      keep_logits=False))
        tensor_dict['detection_motions'] = detection_motions_with_matrices

    if 'camera_motion' in detections:
        camera_motion_with_matrices = tf.squeeze(
            motion_util.postprocess_camera_motion(detections['camera_motion']),
            axis=0)
        tensor_dict['camera_motion'] = camera_motion_with_matrices
        tensor_dict['groundtruth_camera_motion'] = input_dict[
            fields.InputDataFields.groundtruth_camera_motion]

    # load groundtruth fields into tensor_dict
    if not ignore_groundtruth:
        normalized_gt_boxlist = box_list.BoxList(
            input_dict[fields.InputDataFields.groundtruth_boxes])
        gt_boxlist = box_list_ops.scale(normalized_gt_boxlist,
                                        tf.shape(original_image)[1],
                                        tf.shape(original_image)[2])
        groundtruth_boxes = gt_boxlist.get()
        groundtruth_classes = input_dict[
            fields.InputDataFields.groundtruth_classes]
        tensor_dict['groundtruth_boxes'] = groundtruth_boxes
        tensor_dict['groundtruth_classes'] = groundtruth_classes
        tensor_dict['area'] = input_dict[
            fields.InputDataFields.groundtruth_area]
        tensor_dict['is_crowd'] = input_dict[
            fields.InputDataFields.groundtruth_is_crowd]
        tensor_dict['difficult'] = input_dict[
            fields.InputDataFields.groundtruth_difficult]
        if 'detection_masks' in tensor_dict:
            tensor_dict['groundtruth_instance_masks'] = input_dict[
                fields.InputDataFields.groundtruth_instance_masks]

        if 'detection_motions' in tensor_dict:
            tensor_dict['groundtruth_camera_motion'] = input_dict[
                fields.InputDataFields.groundtruth_camera_motion]
            tensor_dict['groundtruth_instance_motions'] = input_dict[
                fields.InputDataFields.groundtruth_instance_motions]
            tensor_dict['camera_intrinsics'] = input_dict[
                fields.InputDataFields.camera_intrinsics]
            if fields.InputDataFields.groundtruth_flow in input_dict:
                tensor_dict['groundtruth_flow'] = input_dict[
                    fields.InputDataFields.groundtruth_flow]
            if 'depth' not in tensor_dict:
                tensor_dict['depth'] = input_dict[
                    fields.InputDataFields.groundtruth_depth]
            else:
                tensor_dict['groundtruth_depth'] = input_dict[
                    fields.InputDataFields.groundtruth_depth]
    return tensor_dict
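For reference, the channel layout of image_input built above grows in three stages:

# [1, H, W, 3]   original RGB frame
# [1, H, W, 6]   + next RGB frame (when next_image is present)
# [1, H, W, 12]  + 3D coordinates for both frames (when depth and
#                  next_depth are both available)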
Code Example #8
def _create_losses(input_queue, create_model_fn):
  """Creates loss function for a DetectionModel.

  Args:
    input_queue: BatchQueue object holding enqueued tensor_dicts.
    create_model_fn: A function to create the DetectionModel.
  """
  detection_model = create_model_fn()
  (original_images, filenames, groundtruth_boxes_list,
   groundtruth_classes_list, groundtruth_transcriptions_list,
   groundtruth_masks_list) = _get_inputs(input_queue,
                                         detection_model.num_classes)

  images = [detection_model.preprocess(image) for image in original_images]
  images = tf.concat(images, 0)
  if any(mask is None for mask in groundtruth_masks_list):
    groundtruth_masks_list = None

  tf.summary.image('InputImage', images, max_outputs=99999)

  print('')
  print('_create_losses')
  print(original_images)
  print(images)
  print(groundtruth_boxes_list)
  print(groundtruth_classes_list)
  print(groundtruth_transcriptions_list)
  sys.stdout.flush()

  detection_model.provide_groundtruth(
      groundtruth_boxes_list,
      groundtruth_classes_list,
      groundtruth_masks_list,
      groundtruth_transcriptions_list=groundtruth_transcriptions_list)
  prediction_dict = detection_model.predict(images)
  losses_dict = detection_model.loss(prediction_dict)
  for name, loss_tensor in losses_dict.items():
    tf.summary.scalar(name, loss_tensor)
    tf.losses.add_loss(loss_tensor)
  print(losses_dict)
  sys.stdout.flush()

  # Metrics for sequence accuracy
  if prediction_dict['transcriptions'] is not None:
    tf.summary.scalar(
        'CharAccuracy',
        metrics.char_accuracy(prediction_dict['transcriptions'],
                              prediction_dict['transcriptions_groundtruth']))
    tf.summary.scalar(
        'SequenceAccuracy',
        metrics.sequence_accuracy(prediction_dict['transcriptions'],
                                  prediction_dict['transcriptions_groundtruth']))

  return

  # Everything below is unreachable; it is kept for debugging and testing
  # during training.

  # Metrics for detection
  detections = detection_model.postprocess(prediction_dict)

  original_images = original_images[0]
  filenames = filenames[0]

  original_image_shape = tf.shape(original_images)
  absolute_detection_boxlist = box_list_ops.to_absolute_coordinates(
      box_list.BoxList(tf.squeeze(detections['detection_boxes'], axis=0)),
      original_image_shape[1], original_image_shape[2])
  label_id_offset = 1
  det_boxes = absolute_detection_boxlist.get()

  det_scores = tf.squeeze(detections['detection_scores'], axis=0)
  det_classes = tf.ones_like(det_scores)
  det_transcriptions = tf.squeeze(detections['detection_transcriptions'], axis=0)

  print('')
  print('Metrics printing')
  print(groundtruth_boxes_list)
  print(groundtruth_classes_list)
  print(groundtruth_transcriptions_list)

  normalized_gt_boxlist = box_list.BoxList(groundtruth_boxes_list[0])
  gt_boxlist = box_list_ops.scale(normalized_gt_boxlist,
                                  original_image_shape[1],
                                  original_image_shape[2])
  gt_boxes = gt_boxlist.get()
  gt_classes = groundtruth_classes_list[0]
  gt_transcriptions = groundtruth_transcriptions_list[0]

  print(original_images)
  print(filenames)
  print(det_boxes)
  print(det_scores)
  print(det_classes)
  print(det_transcriptions)
  print(gt_boxes)
  print(gt_classes)
  print(gt_transcriptions)
  #images = tf.Print(images, [groundtruth_boxes_list[0], xx, tf.shape(original_images[0])], message='groundtruthboxes', summarize=10000)
  sys.stdout.flush()

  mAP = tf.py_func(
      eval_wrapper,
      [original_images, filenames, det_boxes, det_scores, det_classes,
       det_transcriptions, gt_boxes, gt_classes, gt_transcriptions,
       tf.train.get_global_step()],
      tf.float64, stateful=False)
  tf.summary.scalar('mAP', mAP)
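In a TF1 training binary, a loss builder like this is typically bound with functools.partial and called once to populate the tf.losses collection; a hedged sketch (create_model_fn and input_queue assumed in scope):

import functools

model_loss_fn = functools.partial(_create_losses,
                                  create_model_fn=create_model_fn)
model_loss_fn(input_queue)  # each loss was added via tf.losses.add_loss
total_loss = tf.losses.get_total_loss()  # scalar the optimizer minimizes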
Code Example #9
def _extract_prediction_tensors(model,
                                create_input_dict_fn,
                                ignore_groundtruth=False,
                                provide_groundtruth_to_model=False,
                                calc_loss=False):
    """Restores the model in a tensorflow session.

  Args:
    model: model to perform predictions with.
    create_input_dict_fn: function to create input tensor dictionaries.
    ignore_groundtruth: whether groundtruth should be ignored.
    provide_groundtruth_to_model: whether to use model.provide_groundtruth()

  Returns:
    tensor_dict: A tensor dictionary with evaluations.
  """
    mtl = model._mtl
    input_dict = create_input_dict_fn()
    prefetch_queue = prefetcher.prefetch(input_dict, capacity=500)
    input_dict = prefetch_queue.dequeue()

    if calc_loss or mtl.window or mtl.edgemask:
        provide_groundtruth_to_model = True

    # Get groundtruth information
    if provide_groundtruth_to_model:
        (_, groundtruth_boxes_list, groundtruth_ignore_list,
         groundtruth_classes_list, groundtruth_masks_list, _,
         window_boxes_list, window_classes_list, groundtruth_closeness_list,
         groundtruth_edgemask_list) = _get_inputs([input_dict],
                                                  model.num_classes,
                                                  with_filename=False)

        if any(mask is None for mask in groundtruth_masks_list):
            groundtruth_masks_list = None
        model.provide_groundtruth(groundtruth_boxes_list,
                                  groundtruth_classes_list,
                                  groundtruth_closeness_list,
                                  groundtruth_ignore_list,
                                  groundtruth_masks_list)
        model.provide_window(window_boxes_list, window_classes_list)
        model.provide_edgemask(groundtruth_edgemask_list)

    original_image = tf.expand_dims(input_dict[fields.InputDataFields.image],
                                    0)
    preprocessed_image = model.preprocess(tf.to_float(original_image))
    prediction_dict = model.predict(preprocessed_image)

    if mtl.window:
        prediction_dict = model.predict_with_window(prediction_dict)
    if mtl.edgemask:
        prediction_dict = model.predict_edgemask(prediction_dict)
    if mtl.refine:
        prediction_dict = model.predict_with_mtl_results(prediction_dict)

    detections = model.postprocess(prediction_dict)

    original_image_shape = tf.shape(original_image)
    absolute_detection_boxlist = box_list_ops.to_absolute_coordinates(
        box_list.BoxList(tf.squeeze(detections['detection_boxes'], axis=0)),
        original_image_shape[1], original_image_shape[2])
    label_id_offset = 1
    tensor_dict = {
        'original_image': original_image,
        'image_id': input_dict[fields.InputDataFields.source_id],
        'detection_boxes': absolute_detection_boxlist.get(),
        'detection_scores': tf.squeeze(detections['detection_scores'], axis=0),
        'detection_classes': (
            tf.squeeze(detections['detection_classes'], axis=0) +
            label_id_offset),
    }

    if 'detection_thresholds' in detections:
        tensor_dict['detection_thresholds'] = tf.squeeze(
            detections['detection_thresholds'], axis=0)
    if 'detection_masks' in detections:
        detection_masks = tf.squeeze(detections['detection_masks'], axis=0)
        detection_boxes = tf.squeeze(detections['detection_boxes'], axis=0)
        # TODO: This should be done in model's postprocess function ideally.
        detection_masks_reframed = ops.reframe_box_masks_to_image_masks(
            detection_masks, detection_boxes, original_image_shape[1],
            original_image_shape[2])
        detection_masks_reframed = tf.to_float(
            tf.greater(detection_masks_reframed, 0.5))

        tensor_dict['detection_masks'] = detection_masks_reframed
    # load groundtruth fields into tensor_dict
    if not ignore_groundtruth:
        normalized_gt_boxlist = box_list.BoxList(
            input_dict[fields.InputDataFields.groundtruth_boxes])
        gt_boxlist = box_list_ops.scale(normalized_gt_boxlist,
                                        tf.shape(original_image)[1],
                                        tf.shape(original_image)[2])
        groundtruth_boxes = gt_boxlist.get()
        groundtruth_classes = input_dict[
            fields.InputDataFields.groundtruth_classes]
        tensor_dict['groundtruth_boxes'] = groundtruth_boxes
        tensor_dict['groundtruth_classes'] = groundtruth_classes
        tensor_dict['area'] = input_dict[
            fields.InputDataFields.groundtruth_area]
        tensor_dict['difficult'] = input_dict[
            fields.InputDataFields.groundtruth_difficult]
        if 'detection_masks' in tensor_dict:
            tensor_dict['groundtruth_instance_masks'] = input_dict[
                fields.InputDataFields.groundtruth_instance_masks]

        # Subset annotations
        if fields.InputDataFields.groundtruth_subset in input_dict:
            tensor_dict['groundtruth_subset'] = input_dict[
                fields.InputDataFields.groundtruth_subset]

    if calc_loss:
        losses_dict = model.loss(prediction_dict)

        for loss_name, loss_tensor in losses_dict.items():
            loss_tensor = tf.check_numerics(loss_tensor,
                                            '%s is inf or nan.' % loss_name,
                                            name='Loss/' + loss_name)
            tensor_dict['Loss/' + loss_name] = loss_tensor

    # mtl groundtruth
    if mtl.window:
        tensor_dict['window_classes_gt'] = input_dict[
            fields.InputDataFields.window_classes]
        tensor_dict['window_classes_dt'] = prediction_dict[
            'window_class_predictions']
    if mtl.closeness:
        tensor_dict['closeness_gt'] = input_dict[
            fields.InputDataFields.groundtruth_closeness]
        tensor_dict['closeness_dt'] = prediction_dict['closeness_predictions']
    if mtl.edgemask:
        tensor_dict['edgemask_gt'] = input_dict[
            fields.InputDataFields.groundtruth_edgemask_masks]
        tensor_dict['edgemask_dt'] = prediction_dict['edgemask_predictions']

    return tensor_dict
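The tf.check_numerics wrapper above passes each loss tensor through unchanged but raises an error at run time if the tensor contains inf or NaN; a minimal demonstration:

x = tf.constant([1.0, float('nan')])
checked = tf.check_numerics(x, 'x is inf or nan.')
# Evaluating `checked` raises tf.errors.InvalidArgumentError.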
Code Example #10
def _extract_prediction_tensors(model,
                                create_input_dict_fn,
                                ignore_groundtruth=False):
    """Restores the model in a tensorflow session.

  Args:
    model: model to perform predictions with.
    create_input_dict_fn: function to create input tensor dictionaries.
    ignore_groundtruth: whether groundtruth should be ignored.

  Returns:
    tensor_dict: A tensor dictionary with evaluations.
  """
    k_shot = model._k_shot
    input_dict, thread = create_input_dict_fn()
    prefetch_queue = prefetcher.prefetch(input_dict, capacity=100)
    input_dict = prefetch_queue.dequeue()
    images = input_dict[fields.CoLocInputDataFields.supportset]
    images_list = tf.split(images, k_shot)
    float_images = tf.to_float(images)
    input_dict[fields.CoLocInputDataFields.supportset] = float_images
    preprocessed_images = [
        model.preprocess(float_image)
        for float_image in tf.split(float_images, k_shot)
    ]
    preprocessed_image = tf.concat(preprocessed_images, 0)
    prediction_dict = model.predict(preprocessed_image)
    detections = model.postprocess(prediction_dict)
    original_image_shape = tf.shape(images)

    def _absolute_boxes(normalized_boxes):
        absolute_detection_boxlist_list = [
            box_list_ops.to_absolute_coordinates(
                box_list.BoxList(tf.squeeze(k, axis=0)),
                original_image_shape[1], original_image_shape[2])
            for k in tf.split(normalized_boxes, k_shot)
        ]
        return tf.stack([db.get() for db in absolute_detection_boxlist_list])

    tensor_dict = {'original_image': images}

    if 'rpn_detection_boxes' in detections:
        tensor_dict['rpn'] = {
            'boxes': detections['rpn_detection_boxes'],
            'scores': detections['rpn_detection_scores'],
            'classes': detections['rpn_detection_classes'],
            'class_agnostic': tf.constant(True)
        }

    if 'detection_boxes' in detections:
        tensor_dict['detection'] = {
            'boxes': detections['detection_boxes'],
            'scores': detections['detection_scores'],
            'classes': detections['detection_classes'],
            'class_agnostic': tf.constant(False)
        }
    label_id_offset = 1
    if hasattr(model, '_tree_debug_tensors'):
        tensor_dict.update(model._tree_debug_tensors())

    # Convert to the absolute coordinates
    for key, val in tensor_dict.items():
        if isinstance(val, dict):
            for mkey in val:
                if 'boxes' in mkey:
                    val[mkey] = _absolute_boxes(val[mkey])
                if 'classes' in mkey:
                    val[mkey] = val[mkey] + label_id_offset

    if not ignore_groundtruth:
        groundtruth_boxes_list = []
        groundtruth_classes_list = []
        groundtruth_target_class = input_dict[
            fields.CoLocInputDataFields.groundtruth_target_class]
        for k in range(k_shot):
            normalized_gt_boxlist = box_list.BoxList(
                input_dict[fields.CoLocInputDataFields.groundtruth_boxes +
                           '_{}'.format(k)])
            gt_boxlist = box_list_ops.scale(normalized_gt_boxlist,
                                            original_image_shape[1],
                                            original_image_shape[2])
            groundtruth_boxes = gt_boxlist.get()
            groundtruth_classes = input_dict[
                fields.CoLocInputDataFields.groundtruth_classes +
                '_{}'.format(k)]
            groundtruth_boxes_list.append(groundtruth_boxes)
            groundtruth_classes_list.append(groundtruth_classes)
        tensor_dict['groundtruth'] = {
            'boxes': groundtruth_boxes_list,
            'classes': groundtruth_classes_list,
            'target_class': groundtruth_target_class,
        }
    return tensor_dict, thread
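Unlike the earlier variants, this one returns the input pipeline's thread alongside a nested dictionary; a hedged sketch of reading the groundtruth block it builds:

tensor_dict, thread = _extract_prediction_tensors(model, create_input_dict_fn)
gt = tensor_dict['groundtruth']
# gt['boxes'] and gt['classes'] are length-k_shot lists of per-shot tensors;
# gt['target_class'] identifies the episode's target class.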