def test_regular2(self):
     """Align reference ``[1, 0]`` against ``self._logp`` with ``ctc_sym=2``.

     Expects a best log-probability of ``log(0.5 * 0.5 * 0.5 * 0.2)`` and
     the alignment ``[1, 1, 0, 2]``.
     """
     reference = [1, 0]
     expected_logp = math.log(0.5 * 0.5 * 0.5 * 0.2)
     expected_alignment = [1, 1, 0, 2]

     logp, alignment = ctc_alignment(self._logp, reference, ctc_sym=2)

     torch.testing.assert_allclose(logp, expected_logp)
     self.assertEqual(expected_alignment, alignment)
 def test_regular3(self):
     """Align reference ``[0, 2, 0, 2]`` against ``self._logp`` with ``ctc_sym=1``.

     Expects a best log-probability of ``log(0.3 * 0.1 * 0.5 * 0.2)`` and
     an alignment equal to the reference itself, ``[0, 2, 0, 2]``.
     """
     reference = [0, 2, 0, 2]
     expected_logp = math.log(0.3 * 0.1 * 0.5 * 0.2)
     expected_alignment = [0, 2, 0, 2]

     logp, alignment = ctc_alignment(self._logp, reference, ctc_sym=1)

     torch.testing.assert_allclose(logp, expected_logp)
     self.assertEqual(expected_alignment, alignment)
 def test_repeated_label(self):
     """Align reference ``[2, 2]`` (same label twice) with ``ctc_sym=0``.

     Expects a best log-probability of ``log(0.2 * 0.4 * 0.5 * 0.2)`` and
     the alignment ``[2, 0, 0, 2]``, in which the two occurrences of
     label 2 are separated by symbol 0.
     """
     reference = [2, 2]
     expected_logp = math.log(0.2 * 0.4 * 0.5 * 0.2)
     expected_alignment = [2, 0, 0, 2]

     logp, alignment = ctc_alignment(self._logp, reference, ctc_sym=0)

     torch.testing.assert_allclose(logp, expected_logp)
     self.assertEqual(expected_alignment, alignment)
 def test_single_label2(self):
     """Align the one-label reference ``[2]`` with ``ctc_sym=1``.

     Expects a best log-probability of ``log(0.5 * 0.5 * 0.4 * 0.7)`` and
     the alignment ``[1, 1, 2, 1]``.
     """
     reference = [2]
     expected_logp = math.log(0.5 * 0.5 * 0.4 * 0.7)
     expected_alignment = [1, 1, 2, 1]

     logp, alignment = ctc_alignment(self._logp, reference, ctc_sym=1)

     torch.testing.assert_allclose(logp, expected_logp)
     self.assertEqual(expected_alignment, alignment)
 def test_single_label(self):
     """Align the one-label reference ``[1]`` with ``ctc_sym=0``.

     Expects a best log-probability of ``log(0.3 * 0.4 * 0.5 * 0.7)`` and
     the alignment ``[0, 0, 0, 1]``.
     """
     reference = [1]
     expected_logp = math.log(0.3 * 0.4 * 0.5 * 0.7)
     expected_alignment = [0, 0, 0, 1]

     logp, alignment = ctc_alignment(self._logp, reference, ctc_sym=0)

     torch.testing.assert_allclose(logp, expected_logp)
     self.assertEqual(expected_alignment, alignment)
 def test_empty_reference2(self):
     """Align an empty reference against ``self._logp`` with ``ctc_sym=2``.

     Expects a best log-probability of ``log(0.2 * 0.1 * 0.4 * 0.2)`` and
     an alignment made only of symbol 2: ``[2, 2, 2, 2]``.
     """
     reference = []
     expected_logp = math.log(0.2 * 0.1 * 0.4 * 0.2)
     expected_alignment = [2, 2, 2, 2]

     logp, alignment = ctc_alignment(self._logp, reference, ctc_sym=2)

     torch.testing.assert_allclose(logp, expected_logp)
     self.assertEqual(expected_alignment, alignment)
 def test_empty_reference(self):
     """Align an empty reference against ``self._logp`` with ``ctc_sym=0``.

     Expects a best log-probability of ``log(0.3 * 0.4 * 0.5 * 0.1)`` and
     an alignment made only of symbol 0: ``[0, 0, 0, 0]``.
     """
     reference = []
     expected_logp = math.log(0.3 * 0.4 * 0.5 * 0.1)
     expected_alignment = [0, 0, 0, 0]

     logp, alignment = ctc_alignment(self._logp, reference, ctc_sym=0)

     torch.testing.assert_allclose(logp, expected_logp)
     self.assertEqual(expected_alignment, alignment)