Пример #1
0
def test_data_loading_and_preprocess():
  """Dump visualizations of preprocessed training batches for manual inspection.

  Reads '../data/synthtext_train.tf' through `data.input_stream` and
  `data.train_preprocess`, pulls `n_batches` shuffled batches of `batch_size`
  examples, and saves one JPEG per example (image plus its groundtruth
  rboxes outlined in yellow) under '../vis/example'.
  """
  fig = plt.figure()
  ax = fig.add_subplot(111)

  def _visualize_example(save_path, image, gt_rboxes, mean_subtracted=True):
    """Draw `image` and `gt_rboxes` on the shared axes, then save to `save_path`."""
    ax.clear()
    # convert image
    image_display = vis.convert_image_for_visualization(
        image, mean_subtracted=mean_subtracted)
    # draw image
    ax.imshow(image_display)
    # draw groundtruths
    vis.visualize_rboxes(ax, gt_rboxes,
        edgecolor='yellow', facecolor='none', verbose=False)
    # save through the figure object instead of pyplot's implicit
    # "current figure", so the output never depends on global pyplot state
    fig.savefig(save_path)

  n_batches = 10
  batch_size = 32

  save_dir = '../vis/example'
  utils.mkdir_if_not_exist(save_dir)

  streams = data.input_stream('../data/synthtext_train.tf')
  pstreams = data.train_preprocess(streams)
  batches = tf.train.shuffle_batch(pstreams, batch_size, capacity=2000, min_after_dequeue=20,
                                   num_threads=1)

  # the fetch dict is the same for every batch; build it once
  fetches = {'images': batches['image'],
             'gt_rboxes': batches['rboxes'],
             'gt_counts': batches['count']}

  try:
    with tf.Session() as sess:
      sess.run(tf.initialize_all_variables())
      tf.train.start_queue_runners(sess=sess)
      # `range` works on both Python 2 and 3 (`xrange` is Python 2 only)
      for i in range(n_batches):
        sess_outputs = sess.run(fetches)
        for j in range(batch_size):
          save_path = os.path.join(save_dir, '%04d_%d.jpg' % (i, j))
          # only the first `gt_count` rboxes of example j are drawn;
          # presumably the rest is batch padding -- TODO confirm
          gt_count = sess_outputs['gt_counts'][j]
          _visualize_example(save_path,
                             sess_outputs['images'][j],
                             sess_outputs['gt_rboxes'][j, :gt_count],
                             mean_subtracted=True)
          print('Visualization saved to %s' % save_path)
  finally:
    # release the figure; pyplot keeps figures alive until explicitly closed
    plt.close(fig)
Пример #2
0
def evaluate():
  """Run the SegLink detector over the test dataset and write detection results.

  Builds the input pipeline from `FLAGS.test_dataset`, constructs the
  detector graph, thresholds the node/link probabilities, decodes and
  combines segments, then runs every batch (or loads previously saved
  outputs when `FLAGS.load_intermediate` is set) and hands the collected
  outputs to the writer selected by `FLAGS.result_format`.

  Exits the process with status 1 when `FLAGS.result_format` is unknown.
  Raises AssertionError when `FLAGS.test_resize_method == 'dynamic'` and
  `FLAGS.test_batch_size != 1`.
  """
  with tf.device('/cpu:0'):
    # input data
    streams = data.input_stream(FLAGS.test_dataset)
    pstreams = data.test_preprocess(streams)
    if FLAGS.test_resize_method == 'dynamic':
      # each test image is resized to a different size
      # test batch size must be 1
      assert(FLAGS.test_batch_size == 1)
      batches = tf.train.batch(pstreams,
                               FLAGS.test_batch_size,
                               capacity=1000,
                               num_threads=1,
                               dynamic_pad=True)
    else:
      # resize every image to the same size
      batches = tf.train.batch(pstreams,
                               FLAGS.test_batch_size,
                               capacity=1000,
                               num_threads=1)
    # dims 1 and 2 of the image batch; assumes NHWC layout -- TODO confirm
    image_size = tf.shape(batches['image'])[1:3]

  fetches = {}
  fetches['images'] = batches['image']
  fetches['image_name'] = batches['image_name']
  fetches['resize_size'] = batches['resize_size']
  fetches['orig_size'] = batches['orig_size']

  # detector
  detector = model.SegLinkDetector()
  all_maps = detector.build_model(batches['image'])

  # decode local predictions
  all_nodes, all_links, all_reg = [], [], []
  for i, maps in enumerate(all_maps):
    cls_maps, lnk_maps, reg_maps = maps
    reg_maps = tf.multiply(reg_maps, data.OFFSET_VARIANCE)

    # segments classification
    cls_prob = tf.nn.softmax(tf.reshape(cls_maps, [-1, 2]))
    cls_pos_prob = cls_prob[:, model.POS_LABEL]
    cls_pos_prob_maps = tf.reshape(cls_pos_prob, tf.shape(cls_maps)[:3])
    # node status is 1 where probability reaches the threshold
    node_labels = tf.cast(tf.greater_equal(cls_pos_prob_maps, FLAGS.node_threshold),
                          tf.int32)

    # link classification
    lnk_prob = tf.nn.softmax(tf.reshape(lnk_maps, [-1, 2]))
    lnk_pos_prob = lnk_prob[:, model.POS_LABEL]
    lnk_shape = tf.shape(lnk_maps)
    lnk_pos_prob_maps = tf.reshape(lnk_pos_prob,
                                   [lnk_shape[0], lnk_shape[1], lnk_shape[2], -1])
    # link status is 1 where probability reaches the threshold
    link_labels = tf.cast(tf.greater_equal(lnk_pos_prob_maps, FLAGS.link_threshold),
                          tf.int32)

    all_nodes.append(node_labels)
    all_links.append(link_labels)
    all_reg.append(reg_maps)

    fetches['link_labels_%d' % i] = link_labels

  # decode segments and links
  segments, group_indices, segment_counts = ops.decode_segments_links(
    image_size, all_nodes, all_links, all_reg,
    anchor_sizes=list(detector.anchor_sizes))
  fetches['segments'] = segments
  fetches['group_indices'] = group_indices
  fetches['segment_counts'] = segment_counts

  # combine segments
  combined_rboxes, combined_counts = ops.combine_segments(
    segments, group_indices, segment_counts)
  fetches['combined_rboxes'] = combined_rboxes
  fetches['combined_counts'] = combined_counts

  sess_config = tf.ConfigProto()
  with tf.Session(config=sess_config) as sess:
    # load model
    model_loader = tf.train.Saver()
    model_loader.restore(sess, FLAGS.test_model)

    batch_size = FLAGS.test_batch_size
    # Force float division: under Python 2 `int / int` floors *before*
    # math.ceil sees it, which silently drops the last partial batch.
    n_batches = int(math.ceil(float(FLAGS.num_test) / batch_size))

    # result directory
    result_dir = os.path.join(FLAGS.log_dir, 'results' + FLAGS.result_suffix)
    utils.mkdir_if_not_exist(result_dir)

    intermediate_result_path = os.path.join(FLAGS.log_dir, 'intermediate.pkl')
    if FLAGS.load_intermediate:
      # reuse outputs from an earlier run instead of re-running the network
      all_batches = joblib.load(intermediate_result_path)
      logging.info('Intermediate result loaded from {}'.format(intermediate_result_path))
    else:
      # run all batches and store results in a list
      all_batches = []
      with slim.queues.QueueRunners(sess):
        for i in range(n_batches):
          if i % 10 == 0:
            logging.info('Evaluating batch %d/%d' % (i+1, n_batches))
          sess_outputs = sess.run(fetches)
          all_batches.append(sess_outputs)
      if FLAGS.save_intermediate:
        joblib.dump(all_batches, intermediate_result_path, compress=5)
        logging.info('Intermediate result saved to {}'.format(intermediate_result_path))

    # # visualize local rboxes (TODO)
    # if FLAGS.save_vis:
    #   vis_save_prefix = os.path.join(save_dir, 'localpred_batch_%d_' % i)
    #   pred_rboxes_counts = []
    #   for j in range(len(all_maps)):
    #     pred_rboxes_counts.append((sess_outputs['segments_det_%d' % j],
    #                               sess_outputs['segment_counts_det_%d' % j]))
    #   _visualize_layer_det(sess_outputs['images'],
    #                       pred_rboxes_counts,
    #                       vis_save_prefix)

    # # visualize joined rboxes (TODO)
    # if FLAGS.save_vis:
    #   vis_save_prefix = os.path.join(save_dir, 'batch_%d_' % i)
    #   # _visualize_linked_det(sess_outputs, save_prefix)
    #   _visualize_combined_rboxes(sess_outputs, vis_save_prefix)

    if FLAGS.result_format == 'icdar_2015_inc':
      postprocess_and_write_results_ic15(all_batches, result_dir)
    elif FLAGS.result_format == 'icdar_2013':
      postprocess_and_write_results_ic13(all_batches, result_dir)
    else:
      logging.critical('Unknown result format: {}'.format(FLAGS.result_format))
      sys.exit(1)

  logging.info('Evaluation done.')
Пример #3
0
  def _setup_train_net_multigpu(self):
    """Build the multi-GPU training graph.

    Reads `self.global_step` and the training FLAGS, and sets:
      `self.current_lr`: decayed learning rate tensor
      `self.loss`: tower losses averaged over all GPUs
      `self.train_op`: op applying the averaged gradients
      `self.brief_summary_op`, `self.detailed_summary_op`: merged summaries
        of the 'brief' and 'detailed' collections

    Exits the process with status 1 on an unknown `FLAGS.lr_policy`, a
    missing dataset file, or an unsupported `FLAGS.optimizer`.
    """
    with tf.device('/cpu:0'):
      # learning rate decay
      with tf.name_scope('lr_decay'):
        if FLAGS.lr_policy == 'staircase':
          # decayed learning rate
          lr_breakpoints = [int(o) for o in FLAGS.lr_breakpoints.split(',')]
          lr_decays = [float(o) for o in FLAGS.lr_decays.split(',')]
          assert(len(lr_breakpoints) == len(lr_decays))
          pred_fn_pairs = []
          for decay, step_limit in zip(lr_decays, lr_breakpoints):
            # the factory lambda binds `decay` now; a plain closure would
            # see only the last loop value when tf.case calls it
            fn = (lambda o: lambda: tf.constant(o, tf.float32))(decay)
            pred_fn_pairs.append((tf.less(self.global_step, step_limit), fn))
          # assumes breakpoints are listed in ascending order, so the first
          # matching predicate selects the current stage -- TODO confirm
          lr_multiplier = tf.case(pred_fn_pairs, default=(lambda: tf.constant(1.0)))
        else:
          logging.error('Unknown lr_policy: {}'.format(FLAGS.lr_policy))
          sys.exit(1)

        self.current_lr = lr_multiplier * FLAGS.base_lr
        tf.summary.scalar('lr', self.current_lr, collections=['brief'])

      # input data
      # batch_size = int(FLAGS.train_batch_size / FLAGS.n_gpu)
      with tf.name_scope('input_data'):
        batch_size = FLAGS.train_batch_size
        train_datasets = FLAGS.train_datasets.split(';')
        train_pstreams_list = []
        for i, dataset in enumerate(train_datasets):
          if not os.path.exists(dataset):
            logging.critical('Could not find dataset {}'.format(dataset))
            sys.exit(1)
          logging.info('Added training dataset #{}: {}'.format(i, dataset))
          train_streams = data.input_stream(dataset)
          train_pstreams = data.train_preprocess(train_streams)
          train_pstreams_list.append(train_pstreams)
        capacity = batch_size * 50
        min_after_dequeue = batch_size * 3
        train_batch = tf.train.shuffle_batch_join(train_pstreams_list,
                                                  batch_size,
                                                  capacity=capacity,
                                                  min_after_dequeue=min_after_dequeue)
        logging.info('Batch size {}; capacity: {}; min_after_dequeue: {}'.format(batch_size, capacity, min_after_dequeue))

        # split batch into sub-batches for each GPU
        sub_batch_size = FLAGS.train_batch_size // FLAGS.n_gpu
        if FLAGS.train_batch_size % FLAGS.n_gpu != 0:
          # the slices below cover only n_gpu * sub_batch_size examples
          logging.warning(
              'train_batch_size {} is not divisible by n_gpu {}; '
              'the trailing {} example(s) of every batch are unused'.format(
                  FLAGS.train_batch_size, FLAGS.n_gpu,
                  FLAGS.train_batch_size % FLAGS.n_gpu))
        logging.info('Batch size is {} on each of the {} GPUs'.format(sub_batch_size, FLAGS.n_gpu))
        sub_batches = []
        for i in range(FLAGS.n_gpu):
          start, stop = i * sub_batch_size, (i + 1) * sub_batch_size
          sub_batches.append({k: v[start:stop] for k, v in train_batch.items()})

      if FLAGS.optimizer == 'sgd':
        optimizer = tf.train.MomentumOptimizer(self.current_lr, FLAGS.momentum)
        logging.info('Using SGD optimizer. Momentum={}'.format(FLAGS.momentum))
      elif FLAGS.optimizer == 'adam':
        optimizer = tf.train.AdamOptimizer(self.current_lr)
        logging.info('Using ADAM optimizer.')
      elif FLAGS.optimizer == 'rmsprop':
        optimizer = tf.train.RMSPropOptimizer(self.current_lr)
        logging.info('Using RMSProp optimizer.')
      else:
        logging.critical('Unsupported optimizer {}'.format(FLAGS.optimizer))
        sys.exit(1)

      # construct towers
      tower_gradients = []
      tower_losses = []
      for i in range(FLAGS.n_gpu):
        logging.info('Setting up tower %d' % i)
        with tf.device('/gpu:%d' % i):
          # variables are shared: tower 0 creates them, later towers reuse
          with tf.variable_scope(tf.get_variable_scope(), reuse=(i > 0)):
            with tf.name_scope('tower_%d' % i):
              loss = self._tower_loss(sub_batches[i])
              # tf.get_variable_scope().reuse_variables()
              gradients = optimizer.compute_gradients(loss)
              tower_gradients.append(gradients)
              tower_losses.append(loss)

      # average loss and gradients
      self.loss = tf.truediv(tf.add_n(tower_losses), float(len(tower_losses)),
                             name='average_tower_loss')
      tf.summary.scalar('total_loss', self.loss, collections=['brief'])
      with tf.name_scope('average_gradients'):
        grads = self._average_gradients(tower_gradients)

      # update variables
      with tf.variable_scope('optimizer'):
        self.train_op = optimizer.apply_gradients(grads, global_step=self.global_step)

      # setup summaries
      # tf.global_variables() is the non-deprecated name of tf.all_variables()
      for var in tf.global_variables():
        # remove the illegal ":x" part from the variable name
        summary_name = 'parameters/' + var.name.split(':')[0]
        tf.summary.histogram(summary_name, var, collections=['detailed'])

      self.brief_summary_op = tf.summary.merge_all(key='brief')
      self.detailed_summary_op = tf.summary.merge_all(key='detailed')
Пример #4
0
def test_encode_decode_real_data():
  """Visualize groundtruth node/link encodings for one batch of real data.

  Loads one batch from `FLAGS.train_record_path`, encodes its groundtruth
  rboxes with `ops.encode_groundtruth` for each of the six detection
  layers, and saves one 2x3 figure per example (one subplot per layer)
  under '../vis/gt_link_node/'.
  """
  save_dir = '../vis/gt_link_node/'
  utils.mkdir_if_not_exist(save_dir)
  batch_size = 233

  streams = data.input_stream(FLAGS.train_record_path)
  pstreams = data.train_preprocess(streams)
  batch = tf.train.batch(pstreams, batch_size, num_threads=1, capacity=100)

  # NOTE(review): tf.pack was renamed tf.stack in TensorFlow 1.0; switch
  # once the minimum supported TF version allows it.
  image_h = tf.shape(batch['image'])[1]
  image_w = tf.shape(batch['image'])[2]
  image_size = tf.pack([image_h, image_w])

  detector = model_fctd.FctdDetector()
  all_maps = detector.build_model(batch['image'])

  det_layers = ['det_conv4_3', 'det_fc7', 'det_conv6',
                'det_conv7', 'det_conv8', 'det_pool6']

  fetches = {}
  fetches['images'] = batch['image']
  fetches['image_size'] = image_size

  for i, _ in enumerate(det_layers):
    # only the classification maps are needed here, for their spatial size
    cls_maps, _lnk_maps, _reg_maps = all_maps[i]
    map_h, map_w = tf.shape(cls_maps)[1], tf.shape(cls_maps)[2]
    map_size = tf.pack([map_h, map_w])

    # dummy "layer below" inputs; cross-layer links are disabled
    node_status_below = tf.constant([[[0]]], dtype=tf.int32)
    match_indices_below = tf.constant([[[0]]], dtype=tf.int32)
    cross_links = False # FIXME

    node_status, link_status, local_gt, _match_indices = ops.encode_groundtruth(
        batch['rboxes'],
        batch['count'],
        map_size,
        image_size,
        node_status_below,
        match_indices_below,
        region_size=detector.region_sizes[i],
        pos_scale_diff_thresh=FLAGS.pos_scale_diff_threshold,
        neg_scale_diff_thresh=FLAGS.neg_scale_diff_threshold,
        cross_links=cross_links)

    fetches['node_status_%d' % i] = node_status
    fetches['link_status_%d' % i] = link_status
    fetches['local_gt_%d' % i] = local_gt

  def _visualize_nodes_links(ax, image, node_status, link_status, image_size):
    """
    Visualize nodes and links of one example.
    ARGS
      `node_status`: int [map_h, map_w]
      `link_status`: int [map_h, map_w, n_links]
      `image_size`: int [2]
    """
    ax.clear()
    image_display = vis.convert_image_for_visualization(
        image, mean_subtracted=True)
    ax.imshow(image_display)

    vis.visualize_nodes(ax, node_status, image_size)
    vis.visualize_links(ax, link_status, image_size)

  with tf.Session() as sess:
    sess.run(tf.initialize_all_variables())
    tf.train.start_queue_runners(sess=sess)

    sess_outputs = sess.run(fetches)

    fig = plt.figure()
    try:
      # `range` works on both Python 2 and 3 (`xrange` is Python 2 only)
      for i in range(batch_size):
        fig.clear()
        for j, _ in enumerate(det_layers):
          ax = fig.add_subplot(2, 3, j+1)
          _visualize_nodes_links(ax,
                                 sess_outputs['images'][i],
                                 sess_outputs['node_status_%d' % j][i],
                                 sess_outputs['link_status_%d' % j][i],
                                 sess_outputs['image_size'])

        save_path = os.path.join(save_dir, 'gt_node_link_%04d.jpg' % i)
        # save through the figure object instead of pyplot's implicit
        # "current figure"
        fig.savefig(save_path, dpi=200)
        print('Visualization saved to %s' % save_path)
    finally:
      # release the figure; pyplot keeps figures alive until explicitly closed
      plt.close(fig)