def testCountErrors(self):
    """Tests that the error counter works as expected.
    """
    truth_str = 'farm barn'
    counts = ec.CountErrors(ocr_text=truth_str, truth_text=truth_str)
    self.assertEqual(
        counts, ec.ErrorCounts(fn=0, fp=0, truth_count=9, test_count=9))
    # With a period on the end, we get a char error.
    dot_str = 'farm barn.'
    counts = ec.CountErrors(ocr_text=dot_str, truth_text=truth_str)
    self.assertEqual(
        counts, ec.ErrorCounts(fn=0, fp=1, truth_count=9, test_count=10))
    counts = ec.CountErrors(ocr_text=truth_str, truth_text=dot_str)
    self.assertEqual(
        counts, ec.ErrorCounts(fn=1, fp=0, truth_count=10, test_count=9))
    # Space is just another char.
    no_space = 'farmbarn'
    counts = ec.CountErrors(ocr_text=no_space, truth_text=truth_str)
    self.assertEqual(
        counts, ec.ErrorCounts(fn=1, fp=0, truth_count=9, test_count=8))
    counts = ec.CountErrors(ocr_text=truth_str, truth_text=no_space)
    self.assertEqual(
        counts, ec.ErrorCounts(fn=0, fp=1, truth_count=8, test_count=9))
    # Lose them all.
    counts = ec.CountErrors(ocr_text='', truth_text=truth_str)
    self.assertEqual(
        counts, ec.ErrorCounts(fn=9, fp=0, truth_count=9, test_count=0))
    counts = ec.CountErrors(ocr_text=truth_str, truth_text='')
    self.assertEqual(
        counts, ec.ErrorCounts(fn=0, fp=9, truth_count=0, test_count=9))

 def testCountWordErrors(self):
   """Tests that the error counter works as expected.
   """
   truth_str = 'farm barn'
   counts = ec.CountWordErrors(ocr_text=truth_str, truth_text=truth_str)
   self.assertEqual(
       counts, ec.ErrorCounts(
           fn=0, fp=0, truth_count=2, test_count=2))
   # With a period on the end, we get a word error.
   dot_str = 'farm barn.'
   counts = ec.CountWordErrors(ocr_text=dot_str, truth_text=truth_str)
   self.assertEqual(
       counts, ec.ErrorCounts(
           fn=1, fp=1, truth_count=2, test_count=2))
   counts = ec.CountWordErrors(ocr_text=truth_str, truth_text=dot_str)
   self.assertEqual(
       counts, ec.ErrorCounts(
           fn=1, fp=1, truth_count=2, test_count=2))
   # Space is special.
   no_space = 'farmbarn'
   counts = ec.CountWordErrors(ocr_text=no_space, truth_text=truth_str)
   self.assertEqual(
       counts, ec.ErrorCounts(
           fn=2, fp=1, truth_count=2, test_count=1))
   counts = ec.CountWordErrors(ocr_text=truth_str, truth_text=no_space)
   self.assertEqual(
       counts, ec.ErrorCounts(
           fn=1, fp=2, truth_count=1, test_count=2))
   # Lose them all.
   counts = ec.CountWordErrors(ocr_text='', truth_text=truth_str)
   self.assertEqual(
       counts, ec.ErrorCounts(
           fn=2, fp=0, truth_count=2, test_count=0))
   counts = ec.CountWordErrors(ocr_text=truth_str, truth_text='')
   self.assertEqual(
       counts, ec.ErrorCounts(
           fn=0, fp=2, truth_count=0, test_count=2))
   # With a space in ba rn, there is an extra add.
   sp_str = 'farm ba rn'
   counts = ec.CountWordErrors(ocr_text=sp_str, truth_text=truth_str)
   self.assertEqual(
       counts, ec.ErrorCounts(
           fn=1, fp=2, truth_count=2, test_count=3))
   counts = ec.CountWordErrors(ocr_text=truth_str, truth_text=sp_str)
   self.assertEqual(
       counts, ec.ErrorCounts(
           fn=2, fp=1, truth_count=3, test_count=2))
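
The assertions above pin down the counting semantics: fn is the number of truth units (chars or words) that the OCR output drops, fp is the number of spurious units it adds, and truth_count/test_count are the unit totals on each side. The following sketch reproduces those numbers with difflib; it is an illustration of the semantics only, with count_units, count_char_errors and count_word_errors as hypothetical helpers, not the actual ec.CountErrors / ec.CountWordErrors implementation.

import collections
import difflib

# Stand-in for ec.ErrorCounts, mirroring the fields asserted in the tests above.
ErrorCounts = collections.namedtuple(
    'ErrorCounts', ['fn', 'fp', 'truth_count', 'test_count'])


def count_units(ocr_units, truth_units):
    """Counts dropped (fn) and inserted (fp) units via longest matching blocks."""
    matcher = difflib.SequenceMatcher(
        None, truth_units, ocr_units, autojunk=False)
    matched = sum(block.size for block in matcher.get_matching_blocks())
    return ErrorCounts(
        fn=len(truth_units) - matched,  # truth units with no match: drops/deletions
        fp=len(ocr_units) - matched,    # OCR units with no match: adds/insertions
        truth_count=len(truth_units),
        test_count=len(ocr_units))


def count_char_errors(ocr_text, truth_text):
    return count_units(list(ocr_text), list(truth_text))


def count_word_errors(ocr_text, truth_text):
    return count_units(ocr_text.split(), truth_text.split())


print(count_char_errors('farm barn.', 'farm barn'))
# ErrorCounts(fn=0, fp=1, truth_count=9, test_count=10)
print(count_word_errors('farm ba rn', 'farm barn'))
# ErrorCounts(fn=1, fp=2, truth_count=2, test_count=3)
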
Example #3
    def SoftmaxEval(self, sess, model, num_steps):
        """Evaluate a model in softmax mode.

        Adds char, word recall and sequence error rate events to the sw summary
        writer, and returns them as well.
        TODO(rays) Add LogisticEval.

        Args:
          sess: A TensorFlow Session.
          model: The model to run in the session. Requires a VGSLImageModel or
            any other class that has a using_ctc attribute and a RunAStep(sess)
            method that returns a softmax result with corresponding labels.
          num_steps: Number of steps to evaluate for.

        Returns:
          ErrorRates named tuple.

        Raises:
          ValueError: If an unsupported number of dimensions is used.
        """
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        # Run the requested number of evaluation steps, gathering the outputs of the
        # softmax and the true labels of the evaluation examples.
        total_label_counts = ec.ErrorCounts(0, 0, 0, 0)
        total_word_counts = ec.ErrorCounts(0, 0, 0, 0)
        sequence_errors = 0
        for _ in xrange(num_steps):
            softmax_result, labels = model.RunAStep(sess)
            # Collapse softmax to same shape as labels.
            predictions = softmax_result.argmax(axis=-1)
            # Exclude batch from num_dims.
            num_dims = len(predictions.shape) - 1
            batch_size = predictions.shape[0]
            null_label = softmax_result.shape[-1] - 1
            for b in xrange(batch_size):
                if num_dims == 2:
                    # TODO(rays) Support 2-d data.
                    raise ValueError('2-d label data not supported yet!')
                elif num_dims == 1:
                    pred_batch = predictions[b, :]
                    labels_batch = labels[b, :]
                else:
                    pred_batch = [predictions[b]]
                    labels_batch = [labels[b]]
                text = self.StringFromCTC(
                    pred_batch, model.using_ctc, null_label)
                truth = self.StringFromCTC(labels_batch, False, null_label)
                # Note that recall_errs is false negatives (fn) aka drops/deletions.
                # Actual recall would be 1 - fn/truth_words.
                # Likewise precision_errs is false positives (fp) aka adds/insertions.
                # Actual precision would be 1 - fp/ocr_words.
                total_word_counts = ec.AddErrors(
                    total_word_counts, ec.CountWordErrors(text, truth))
                total_label_counts = ec.AddErrors(
                    total_label_counts, ec.CountErrors(text, truth))
                if text != truth:
                    sequence_errors += 1

        coord.request_stop()
        coord.join(threads)
        return ec.ComputeErrorRates(total_label_counts, total_word_counts,
                                    sequence_errors, num_steps * batch_size)
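
The comments inside the evaluation loop spell out how the accumulated counts map to rates: recall error is fn/truth_count, precision error is fp/test_count, and the sequence error rate is sequence_errors divided by the number of evaluated sequences (num_steps * batch_size, as passed to ec.ComputeErrorRates above). Below is a minimal sketch of that arithmetic, purely as an illustration of the relationship; it is a hypothetical helper, not the actual ec.ComputeErrorRates.

def error_rates(counts, sequence_errors, num_sequences):
    """Illustrative only: the rates implied by the comments in SoftmaxEval.

    `counts` is anything with fn/fp/truth_count/test_count fields, e.g. the
    ec.ErrorCounts values accumulated above.
    """
    def rate(errors, total):
        return float(errors) / total if total else 0.0

    return (rate(counts.fn, counts.truth_count),   # recall error: drops / truth units
            rate(counts.fp, counts.test_count),    # precision error: adds / OCR units
            rate(sequence_errors, num_sequences))  # fraction of sequences with any error

Calling it as error_rates(total_word_counts, sequence_errors, num_steps * batch_size) would, for example, give the word-level recall and precision error rates alongside the sequence error rate.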