def test_ctc_lambda_loss(self):
    """Compare ctc_lambda_loss with precomputed reference losses.

    Two scenarios are checked: a batch of two samples (the second with a
    shorter label sequence) and a batch of one built from the first sample.
    """
    # Scores laid out as batch x time x categories; 5 timesteps, 6 classes.
    probs = np.asarray(
        [[[0.633766, 0.221185, 0.0917319, 0.0129757, 0.0142857, 0.0260553],
          [0.111121, 0.588392, 0.278779, 0.0055756, 0.00569609, 0.010436],
          [0.0357786, 0.633813, 0.321418, 0.00249248, 0.00272882, 0.0037688],
          [0.0663296, 0.643849, 0.280111, 0.00283995, 0.0035545, 0.00331533],
          [0.458235, 0.396634, 0.123377, 0.00648837, 0.00903441, 0.00623107]],
         [[0.30176, 0.28562, 0.0831517, 0.0862751, 0.0816851, 0.161508],
          [0.24082, 0.397533, 0.0557226, 0.0546814, 0.0557528, 0.19549],
          [0.230246, 0.450868, 0.0389607, 0.038309, 0.0391602, 0.202456],
          [0.280884, 0.429522, 0.0326593, 0.0339046, 0.0326856, 0.190345],
          [0.423286, 0.315517, 0.0338439, 0.0393744, 0.0339315, 0.154046]]],
        dtype=np.float32)
    # The trailing -1 in the second row lies past that sample's label
    # length of 4 (see the label_length below).
    targets = np.asarray([[0, 1, 2, 1, 0], [0, 1, 1, 0, -1]])
    expected = [7.2771974, 8.057934]

    with self.cached_session():
        # Column vectors (batch, 1): 5 timesteps each; label lengths 5 and 4.
        batch_loss = loss_utils.ctc_lambda_loss(
            labels=tf.constant(targets),
            logits=tf.constant(probs),
            input_length=tf.constant(np.asarray([[5], [5]])),
            label_length=tf.constant(np.asarray([[5], [4]])),
            blank_index=0)
        batch_values = batch_loss.eval()
        self.assertEqual(batch_values.shape[0], probs.shape[0])
        self.assertAllClose(batch_values, expected, atol=1e-05)
        self.assertAllClose(
            np.mean(batch_values), np.mean(expected), atol=1e-05)

        # Batch of one: only the first sample, with 1-D length vectors.
        single_expected = [7.277198]
        single_loss = loss_utils.ctc_lambda_loss(
            labels=tf.constant(targets[:1]),
            logits=tf.constant(probs[:1]),
            input_length=tf.constant(np.asarray([5])),
            label_length=tf.constant(np.asarray([5])),
            blank_index=0)
        single_values = single_loss.eval()
        self.assertAllClose(single_values, single_expected, atol=1e-05)
        self.assertAllClose(
            np.mean(single_values), np.mean(single_expected), atol=1e-05)
def call(self, logits=None, input_length=None, labels=None, label_length=None,
         soft_labels=None):
    """Return the CTC loss computed by ``ctc_lambda_loss``.

    ``soft_labels`` is accepted but ignored. ``blank_index`` is not passed,
    so whatever default ``ctc_lambda_loss`` defines applies.
    """
    _ = soft_labels  # accepted by the signature, intentionally unused
    return ctc_lambda_loss(
        labels=labels,
        label_length=label_length,
        logits=logits,
        input_length=input_length)
def call(self, logits=None, input_length=None, labels=None, label_length=None,
         **kwargs):
    """Return the CTC loss computed by ``ctc_lambda_loss``.

    Only ``blank_index`` is read from ``kwargs`` (default 0) and forwarded.
    Any other extra keyword arguments are ignored.
    """
    loss_kwargs = dict(
        logits=logits,
        input_length=input_length,
        labels=labels,
        label_length=label_length,
        blank_index=kwargs.get('blank_index', 0),
    )
    return ctc_lambda_loss(**loss_kwargs)