def __init__(self):
    self.detector = model.SegLinkDetector()

    # global TF variables
    with tf.device('/cpu:0'):
        self.global_step = tf.Variable(0, trainable=False,
                                       name='global_step', dtype=tf.int64)
    tf.summary.scalar('global_step', self.global_step, collections=['brief'])

    # setup training graphs and summaries
    self._setup_train_net_multigpu()

    # if True, the training process will be terminated in the next iteration
    self.should_stop = False
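
# --------------------------------------------------------------------------
# Illustrative sketch (not part of the original code): the decoding loop in
# evaluate() below converts raw 2-channel classification maps into binary
# on/off labels by taking a softmax over the class axis and thresholding the
# positive-class probability. A minimal, self-contained version of that step,
# assuming TensorFlow 1.x and that the positive class is channel 1
# (i.e. model.POS_LABEL == 1):

def _prob_to_labels_sketch(cls_maps, threshold):
    """Return int32 labels that are 1 where P(positive) >= threshold.

    cls_maps: [batch, height, width, 2] tensor of raw logits.
    """
    prob = tf.nn.softmax(tf.reshape(cls_maps, [-1, 2]))
    pos_prob = tf.reshape(prob[:, 1], tf.shape(cls_maps)[:3])
    return tf.cast(tf.greater_equal(pos_prob, threshold), tf.int32)
# --------------------------------------------------------------------------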
def evaluate():
    with tf.device('/cpu:0'):
        # input data
        streams = data.input_stream(FLAGS.test_dataset)
        pstreams = data.test_preprocess(streams)
        if FLAGS.test_resize_method == 'dynamic':
            # each test image is resized to a different size,
            # so the test batch size must be 1
            assert FLAGS.test_batch_size == 1
            batches = tf.train.batch(pstreams,
                                     FLAGS.test_batch_size,
                                     capacity=1000,
                                     num_threads=1,
                                     dynamic_pad=True)
        else:
            # resize every image to the same size
            batches = tf.train.batch(pstreams,
                                     FLAGS.test_batch_size,
                                     capacity=1000,
                                     num_threads=1)
        image_size = tf.shape(batches['image'])[1:3]

    fetches = {}
    fetches['images'] = batches['image']
    fetches['image_name'] = batches['image_name']
    fetches['resize_size'] = batches['resize_size']
    fetches['orig_size'] = batches['orig_size']

    # detector
    detector = model.SegLinkDetector()
    all_maps = detector.build_model(batches['image'])

    # decode local predictions
    all_nodes, all_links, all_reg = [], [], []
    for i, maps in enumerate(all_maps):
        cls_maps, lnk_maps, reg_maps = maps
        reg_maps = tf.multiply(reg_maps, data.OFFSET_VARIANCE)

        # segment classification
        cls_prob = tf.nn.softmax(tf.reshape(cls_maps, [-1, 2]))
        cls_pos_prob = cls_prob[:, model.POS_LABEL]
        cls_pos_prob_maps = tf.reshape(cls_pos_prob, tf.shape(cls_maps)[:3])
        # node status is 1 where probability is higher than threshold
        node_labels = tf.cast(
            tf.greater_equal(cls_pos_prob_maps, FLAGS.node_threshold), tf.int32)

        # link classification
        lnk_prob = tf.nn.softmax(tf.reshape(lnk_maps, [-1, 2]))
        lnk_pos_prob = lnk_prob[:, model.POS_LABEL]
        lnk_shape = tf.shape(lnk_maps)
        lnk_pos_prob_maps = tf.reshape(
            lnk_pos_prob, [lnk_shape[0], lnk_shape[1], lnk_shape[2], -1])
        # link status is 1 where probability is higher than threshold
        link_labels = tf.cast(
            tf.greater_equal(lnk_pos_prob_maps, FLAGS.link_threshold), tf.int32)

        all_nodes.append(node_labels)
        all_links.append(link_labels)
        all_reg.append(reg_maps)

        fetches['link_labels_%d' % i] = link_labels

    # decode segments and links
    segments, group_indices, segment_counts = ops.decode_segments_links(
        image_size, all_nodes, all_links, all_reg,
        anchor_sizes=list(detector.anchor_sizes))
    fetches['segments'] = segments
    fetches['group_indices'] = group_indices
    fetches['segment_counts'] = segment_counts

    # combine segments
    combined_rboxes, combined_counts = ops.combine_segments(
        segments, group_indices, segment_counts)
    fetches['combined_rboxes'] = combined_rboxes
    fetches['combined_counts'] = combined_counts

    sess_config = tf.ConfigProto()
    with tf.Session(config=sess_config) as sess:
        # load model
        model_loader = tf.train.Saver()
        model_loader.restore(sess, FLAGS.test_model)

        batch_size = FLAGS.test_batch_size
        # float() guards against integer division under Python 2
        n_batches = int(math.ceil(FLAGS.num_test / float(batch_size)))

        # result directory
        result_dir = os.path.join(FLAGS.log_dir, 'results' + FLAGS.result_suffix)
        utils.mkdir_if_not_exist(result_dir)

        intermediate_result_path = os.path.join(FLAGS.log_dir, 'intermediate.pkl')
        if FLAGS.load_intermediate:
            all_batches = joblib.load(intermediate_result_path)
            logging.info('Intermediate result loaded from {}'.format(
                intermediate_result_path))
        else:
            # run all batches and store results in a list
            all_batches = []
            with slim.queues.QueueRunners(sess):
                for i in range(n_batches):
                    if i % 10 == 0:
                        logging.info('Evaluating batch %d/%d' % (i + 1, n_batches))
                    sess_outputs = sess.run(fetches)
                    all_batches.append(sess_outputs)
            if FLAGS.save_intermediate:
                joblib.dump(all_batches, intermediate_result_path, compress=5)
                logging.info('Intermediate result saved to {}'.format(
                    intermediate_result_path))
        # visualize local rboxes (TODO)
        # if FLAGS.save_vis:
        #     vis_save_prefix = os.path.join(save_dir, 'localpred_batch_%d_' % i)
        #     pred_rboxes_counts = []
        #     for j in range(len(all_maps)):
        #         pred_rboxes_counts.append((sess_outputs['segments_det_%d' % j],
        #                                    sess_outputs['segment_counts_det_%d' % j]))
        #     _visualize_layer_det(sess_outputs['images'],
        #                          pred_rboxes_counts,
        #                          vis_save_prefix)

        # visualize joined rboxes (TODO)
        # if FLAGS.save_vis:
        #     vis_save_prefix = os.path.join(save_dir, 'batch_%d_' % i)
        #     # _visualize_linked_det(sess_outputs, save_prefix)
        #     _visualize_combined_rboxes(sess_outputs, vis_save_prefix)

        if FLAGS.result_format == 'icdar_2015_inc':
            postprocess_and_write_results_ic15(all_batches, result_dir)
        elif FLAGS.result_format == 'icdar_2013':
            postprocess_and_write_results_ic13(all_batches, result_dir)
        else:
            logging.critical('Unknown result format: {}'.format(FLAGS.result_format))
            sys.exit(1)

    logging.info('Evaluation done.')
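
# --------------------------------------------------------------------------
# Hypothetical usage sketch (not part of the original code): a TF 1.x style
# entry point that runs evaluate(). It assumes the FLAGS referenced above are
# defined elsewhere in this module via tf.app.flags.

def main(_):
    evaluate()

if __name__ == '__main__':
    tf.app.run()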