def test_labels_blankid_to_last(self):
    ''' unit test case for the labels_blankid_to_last interface

    Covers three things:
      * calling with num_class=None raises AssertionError with the message
        'The num_class should not be None!';
      * with blank_index=0 and num_class=6, the evaluated result has the
        expected values, indices and dense_shape;
      * the same with blank_index=2 and num_class=6.

    The input is the self.labels fixture, which is defined outside this
    method.
    '''
    # The two positive cases only differ in the expected values; the
    # expected indices and dense shape are the same for both.
    expected_indices = np.asarray(
        [[0, 0], [0, 1], [0, 2], [0, 3], [1, 0], [1, 1], [1, 2]])
    expected_shape = np.asarray([2, 4])
    # (blank_index, expected values) pairs.
    cases = (
        (0, np.asarray([0, 0, 0, 2, 0, 0, 0])),
        (2, np.asarray([1, 1, 1, 2, 1, 1, 1])),
    )

    with self.cached_session():
        # Only the raised error matters here, so the return value is not
        # bound to a name.
        with self.assertRaises(AssertionError) as assert_err:
            ctc_utils.labels_blankid_to_last(
                labels=self.labels, blank_index=0, num_class=None)
        self.assertEqual(
            str(assert_err.exception), 'The num_class should not be None!')

        for blank_index, expected_values in cases:
            labels = ctc_utils.labels_blankid_to_last(
                labels=tf.constant(self.labels),
                blank_index=blank_index,
                num_class=6)
            # Evaluate once and check every field of that single result,
            # rather than re-running the graph for each assertion.
            evaluated = labels.eval()
            self.assertAllEqual(evaluated.values, expected_values)
            self.assertAllEqual(evaluated.indices, expected_indices)
            self.assertAllEqual(evaluated.dense_shape, expected_shape)
def ctc_data_transform(labels, logits, blank_index):
    ''' data transform according to blank_index

    The logits are passed through ctc_utils.logits_blankid_to_last first.
    The size of axis 2 of that result is then used as num_class when the
    labels are passed through ctc_utils.labels_blankid_to_last.

    Args:
      labels: forwarded unchanged to ctc_utils.labels_blankid_to_last.
      logits: forwarded unchanged to ctc_utils.logits_blankid_to_last; the
        result must be indexable at shape[2], i.e. have at least three axes.
      blank_index: forwarded to both helpers.

    Returns:
      A (labels, logits) tuple holding the two transformed values, labels
      first.
    '''
    # NOTE(review): judging by the helper names, both calls presumably move
    # the blank class to the last class id -- confirm against ctc_utils.
    transformed_logits = ctc_utils.logits_blankid_to_last(
        logits=logits, blank_index=blank_index)
    transformed_labels = ctc_utils.labels_blankid_to_last(
        labels=labels,
        blank_index=blank_index,
        num_class=transformed_logits.shape[2])
    return transformed_labels, transformed_logits