Example #1
# Excerpt from the project's trainer class; the enclosing class name below is
# a placeholder, and `import tensorflow as tf` plus the project's `model`
# module are assumed.
class Trainer(object):
  def __init__(self):
    self.detector = model.SegLinkDetector()

    # global TF variables
    with tf.device('/cpu:0'):
      self.global_step = tf.Variable(0, trainable=False, name='global_step', dtype=tf.int64)
      tf.summary.scalar('global_step', self.global_step, collections=['brief'])

    # setup training graphs and summaries
    self._setup_train_net_multigpu()

    # when True, the training loop terminates at the next iteration
    self.should_stop = False
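
For context, the `global_step` created above is conventionally advanced by the optimizer itself rather than by hand. A minimal, self-contained sketch (not project code; the variable `w` and the dummy squared loss are made up for illustration):

import tensorflow as tf

global_step = tf.Variable(0, trainable=False, name='global_step', dtype=tf.int64)
w = tf.Variable(1.0)
loss = tf.square(w)  # dummy loss, just to have a gradient

# Passing global_step to minimize() makes the optimizer increment it by
# one every time it applies a gradient update.
train_op = tf.train.GradientDescentOptimizer(0.1).minimize(
    loss, global_step=global_step)

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  sess.run(train_op)
  print(sess.run(global_step))  # -> 1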
Example #2
# Imports inferred from the usage below; `data`, `model`, `ops`, and `utils`
# are modules of this project. (`postprocess_and_write_results_ic15`/`_ic13`
# are assumed to be defined elsewhere in the same file.)
import logging
import math
import os
import sys

import joblib
import tensorflow as tf
import tensorflow.contrib.slim as slim

import data
import model
import ops
import utils

FLAGS = tf.app.flags.FLAGS


def evaluate():
  with tf.device('/cpu:0'):
    # input data
    streams = data.input_stream(FLAGS.test_dataset)
    pstreams = data.test_preprocess(streams)
    if FLAGS.test_resize_method == 'dynamic':
      # each test image is resized to a different size
      # test batch size must be 1
      assert FLAGS.test_batch_size == 1
      batches = tf.train.batch(pstreams,
                               FLAGS.test_batch_size,
                               capacity=1000,
                               num_threads=1,
                               dynamic_pad=True)
    else:
      # resize every image to the same size
      batches = tf.train.batch(pstreams,
                               FLAGS.test_batch_size,
                               capacity=1000,
                               num_threads=1)
    image_size = tf.shape(batches['image'])[1:3]

  fetches = {}
  fetches['images'] = batches['image']
  fetches['image_name'] = batches['image_name']
  fetches['resize_size'] = batches['resize_size']
  fetches['orig_size'] = batches['orig_size']

  # detector
  detector = model.SegLinkDetector()
  all_maps = detector.build_model(batches['image'])
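  # all_maps holds one (cls_maps, lnk_maps, reg_maps) triple per feature
  # layer of the detector; the loop below decodes each layer in turn.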

  # decode local predictions
  all_nodes, all_links, all_reg = [], [], []
  for i, maps in enumerate(all_maps):
    cls_maps, lnk_maps, reg_maps = maps
    reg_maps = tf.multiply(reg_maps, data.OFFSET_VARIANCE)

    # segment classification: softmax over the two classes
    # (background / text) at every location
    cls_prob = tf.nn.softmax(tf.reshape(cls_maps, [-1, 2]))
    cls_pos_prob = cls_prob[:, model.POS_LABEL]
    cls_pos_prob_maps = tf.reshape(cls_pos_prob, tf.shape(cls_maps)[:3])
    # node label is 1 wherever the text probability reaches the threshold
    node_labels = tf.cast(tf.greater_equal(cls_pos_prob_maps, FLAGS.node_threshold),
                          tf.int32)

    # link classification
    lnk_prob = tf.nn.softmax(tf.reshape(lnk_maps, [-1, 2]))
    lnk_pos_prob = lnk_prob[:, model.POS_LABEL]
    lnk_shape = tf.shape(lnk_maps)
    lnk_pos_prob_maps = tf.reshape(lnk_pos_prob,
                                   [lnk_shape[0], lnk_shape[1], lnk_shape[2], -1])
    # link label is 1 wherever the link probability reaches the threshold
    link_labels = tf.cast(tf.greater_equal(lnk_pos_prob_maps, FLAGS.link_threshold),
                          tf.int32)

    all_nodes.append(node_labels)
    all_links.append(link_labels)
    all_reg.append(reg_maps)

    fetches['link_labels_%d' % i] = link_labels

  # decode segments and links
  segments, group_indices, segment_counts = ops.decode_segments_links(
    image_size, all_nodes, all_links, all_reg,
    anchor_sizes=list(detector.anchor_sizes))
  fetches['segments'] = segments
  fetches['group_indices'] = group_indices
  fetches['segment_counts'] = segment_counts

  # combine linked segments into whole-text rotated boxes (rboxes)
  combined_rboxes, combined_counts = ops.combine_segments(
    segments, group_indices, segment_counts)
  fetches['combined_rboxes'] = combined_rboxes
  fetches['combined_counts'] = combined_counts
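
  # The graph is now complete: `fetches` maps names to every tensor that
  # will be evaluated per test batch in the session below.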

  sess_config = tf.ConfigProto()
  with tf.Session(config=sess_config) as sess:
    # load model
    model_loader = tf.train.Saver()
    model_loader.restore(sess, FLAGS.test_model)

    batch_size = FLAGS.test_batch_size
    n_batches = int(math.ceil(FLAGS.num_test / float(batch_size)))

    # result directory
    result_dir = os.path.join(FLAGS.log_dir, 'results' + FLAGS.result_suffix)
    utils.mkdir_if_not_exist(result_dir)

    intermediate_result_path = os.path.join(FLAGS.log_dir, 'intermediate.pkl')
    if FLAGS.load_intermediate:
      all_batches = joblib.load(intermediate_result_path)
      logging.info('Intermediate result loaded from {}'.format(intermediate_result_path))
    else:
      # run all batches and store results in a list
      all_batches = []
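      # slim's QueueRunners context starts the input-pipeline threads
      # feeding tf.train.batch above and stops them on exit.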
      with slim.queues.QueueRunners(sess):
        for i in range(n_batches):
          if i % 10 == 0:
            logging.info('Evaluating batch %d/%d' % (i+1, n_batches))
          sess_outputs = sess.run(fetches)
          all_batches.append(sess_outputs)
      if FLAGS.save_intermediate:
        joblib.dump(all_batches, intermediate_result_path, compress=5)
        logging.info('Intermediate result saved to {}'.format(intermediate_result_path))

    # # visualize local rboxes (TODO)
    # if FLAGS.save_vis:
    #   vis_save_prefix = os.path.join(save_dir, 'localpred_batch_%d_' % i)
    #   pred_rboxes_counts = []
    #   for j in range(len(all_maps)):
    #     pred_rboxes_counts.append((sess_outputs['segments_det_%d' % j],
    #                               sess_outputs['segment_counts_det_%d' % j]))
    #   _visualize_layer_det(sess_outputs['images'],
    #                       pred_rboxes_counts,
    #                       vis_save_prefix)

    # # visualize joined rboxes (TODO)
    # if FLAGS.save_vis:
    #   vis_save_prefix = os.path.join(save_dir, 'batch_%d_' % i)
    #   # _visualize_linked_det(sess_outputs, save_prefix)
    #   _visualize_combined_rboxes(sess_outputs, vis_save_prefix)

    if FLAGS.result_format == 'icdar_2015_inc':
      postprocess_and_write_results_ic15(all_batches, result_dir)
    elif FLAGS.result_format == 'icdar_2013':
      postprocess_and_write_results_ic13(all_batches, result_dir)
    else:
      logging.critical('Unknown result format: {}'.format(FLAGS.result_format))
      sys.exit(1)
  
  logging.info('Evaluation done.')
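
For completeness, a minimal sketch of how a TF1 script like this is usually launched. The project's real flag definitions live elsewhere; the two below are illustrative placeholders only:

# Illustrative placeholders; the project defines its full flag set elsewhere.
tf.app.flags.DEFINE_string('test_model', '', 'checkpoint to evaluate')
tf.app.flags.DEFINE_integer('test_batch_size', 1, 'evaluation batch size')

def main(_):
  evaluate()

if __name__ == '__main__':
  tf.app.run()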