def test_labels_blankid_to_last(self):
        ''' unit test case for the labels_blankid_to_last interface

        Checks that:
          * calling it with ``num_class=None`` raises an AssertionError whose
            message is 'The num_class should not be None!';
          * with ``num_class=6`` and ``blank_index`` 0 or 2, the returned sparse
            labels carry the expected values, while indices and dense shape
            match the fixed expectations below.
        '''
        with self.cached_session():

            with self.assertRaises(AssertionError) as assert_err:
                ctc_utils.labels_blankid_to_last(labels=self.labels,
                                                 blank_index=0,
                                                 num_class=None)
            the_exception = assert_err.exception
            self.assertEqual(str(the_exception),
                             'The num_class should not be None!')

            # Both cases share the same sparsity pattern; only the values
            # change with blank_index.
            labels_index = np.asarray([[0, 0], [0, 1], [0, 2], [0, 3], [1, 0],
                                       [1, 1], [1, 2]])
            labels_shape = np.asarray([2, 4])
            cases = (
                (0, [0, 0, 0, 2, 0, 0, 0]),
                (2, [1, 1, 1, 2, 1, 1, 1]),
            )

            for blank_index, expected_values in cases:
                labels = ctc_utils.labels_blankid_to_last(
                    labels=tf.constant(self.labels),
                    blank_index=blank_index,
                    num_class=6)
                # Run the graph once and inspect every component of the
                # result, instead of re-evaluating the tensor per field.
                labels_val = labels.eval()
                self.assertAllEqual(labels_val.values,
                                    np.asarray(expected_values))
                self.assertAllEqual(labels_val.indices, labels_index)
                self.assertAllEqual(labels_val.dense_shape, labels_shape)
# Example #2
# 0
def ctc_data_transform(labels, logits, blank_index):
  '''
  Apply the ctc_utils ``*_blankid_to_last`` transforms for ``blank_index``.

  The logits are transformed first. The class count passed to the labels
  transform as ``num_class`` is then read from index 2 of the transformed
  logits' ``shape``.

  NOTE(review): assumes logits have rank >= 3 with classes on axis 2 and a
  statically known size there. labels_blankid_to_last rejects
  ``num_class=None`` with an AssertionError. Confirm this axis is always
  defined for callers.

  Returns:
    A ``(labels, logits)`` tuple with the transformed labels and logits.
  '''
  transformed_logits = ctc_utils.logits_blankid_to_last(
      logits=logits, blank_index=blank_index)
  transformed_labels = ctc_utils.labels_blankid_to_last(
      labels=labels,
      blank_index=blank_index,
      num_class=transformed_logits.shape[2])
  return transformed_labels, transformed_logits