def test_logits_blankid_to_last(self):
    ''' unit test case for the logits_blankid_to_last interface

    Checks two things:
      1. a blank_index of 10 is rejected with a ValueError carrying the
         expected message;
      2. with blank_index=0 the evaluated result matches the hard-coded
         expected array `logits_transform`.
    '''
    with self.cached_session():
        # Out-of-range blank index must raise.
        with self.assertRaises(ValueError) as value_err:
            ctc_utils.logits_blankid_to_last(
                logits=tf.constant(self.logits), blank_index=10)
        the_exception = value_err.exception
        self.assertEqual(
            str(the_exception),
            'blank_index must be less than or equal to num_class - 1')

        logits = ctc_utils.logits_blankid_to_last(
            logits=tf.constant(self.logits), blank_index=0)
        # Expected output; presumably self.logits with class column 0 moved
        # to the last position -- confirm against the fixture definition.
        logits_transform = np.asarray(
            [[[0.221185, 0.0917319, 0.0129757, 0.0142857, 0.0260553, 0.633766],
              [0.588392, 0.278779, 0.0055756, 0.00569609, 0.010436, 0.111121],
              [0.633813, 0.321418, 0.00249248, 0.00272882, 0.0037688, 0.0357786],
              [0.643849, 0.280111, 0.00283995, 0.0035545, 0.00331533, 0.0663296],
              [0.196634, 0.123377, 0.50648837, 0.00903441, 0.00623107, 0.158235]],
             [[0.28562, 0.0831517, 0.0862751, 0.0816851, 0.161508, 0.30176],
              [0.397533, 0.0557226, 0.0546814, 0.0557528, 0.19549, 0.24082],
              [0.450868, 0.0389607, 0.038309, 0.0391602, 0.202456, 0.230246],
              [0.429522, 0.0326593, 0.0339046, 0.0326856, 0.190345, 0.280884],
              [0.315517, 0.0338439, 0.0393744, 0.0339315, 0.154046, 0.423286]]],
            dtype=np.float32)
        # Compare against the expected array; comparing the tensor with
        # itself would pass regardless of what the function returns.
        self.assertAllClose(logits.eval(), logits_transform)
def ctc_decode_blankid_to_last(logits, sequence_length, blank_id=None):
    '''
    Moves the blank_label column to the end of the logit matrix,
    and adjust the rank of sequence_length to 1
    param: logits, (B, T, C), output of ctc asr model
    param: sequence_length, (B, 1) or (B), sequence lengths
    param: blank_id, index of the blank class in logits;
                     None means 0, same to espnet.
    return: logits_return, (T, B, C), time-major with the blank class last
            sequence_length_return, (B)
            blank_id, the blank index that was actually used
    '''
    # blank_id=0 is used as default in Espnet,
    # while blank_id is set as C-1 in tf.nn.ctc_decoder
    if blank_id is None:
        blank_id = 0

    # Batch-major -> time-major before relocating the blank column.
    logits = tf.transpose(logits, [1, 0, 2])
    logits = ctc_utils.logits_blankid_to_last(logits=logits, blank_index=blank_id)

    # Flatten instead of tf.squeeze(): squeeze without an axis drops every
    # size-1 dimension, so a (1, 1) input (batch size 1) would collapse to a
    # scalar. Reshaping to [-1] yields rank 1 for both (B,) and (B, 1).
    sequence_length_return = tf.reshape(sequence_length, [-1])

    return logits, sequence_length_return, blank_id
def ctc_data_transform(labels, logits, blank_index):
    ''' data transform according blank_index

    Runs the logits through ctc_utils.logits_blankid_to_last, then runs the
    labels through ctc_utils.labels_blankid_to_last using the class count
    read from axis 2 of the transformed logits.

    param: labels, label data handed to labels_blankid_to_last
    param: logits, logit data; must expose .shape with at least three axes
                   after the transform (assumed (.., .., C) -- TODO confirm)
    param: blank_index, index of the blank class in the incoming data
    return: (labels, logits), both after the blank-id transform
    '''
    moved_logits = ctc_utils.logits_blankid_to_last(
        logits=logits,
        blank_index=blank_index,
    )

    # The class count is taken from the transformed logits, not the input.
    class_count = moved_logits.shape[2]

    moved_labels = ctc_utils.labels_blankid_to_last(
        labels=labels,
        blank_index=blank_index,
        num_class=class_count,
    )
    return moved_labels, moved_logits