def test_recompute_total(self):
    """Verify recompute_total() under both interpolation modes.

    Exercises every combination of LM scale (1.0, 10.0) and word insertion
    penalty (0.0, -20.0), then checks numerical stability of linear
    interpolation with very small log probabilities.
    """
    token = LatticeDecoder.Token(history=[1, 2],
                                 ac_logprob=math.log(0.1),
                                 lat_lm_logprob=math.log(0.2),
                                 nn_lm_logprob=math.log(0.3))
    # Expected interpolated LM log probability for NN LM weight 0.25:
    # linear interpolation mixes probabilities, log-linear mixes logprobs.
    linear_lm = math.log(0.25 * 0.3 + 0.75 * 0.2)
    loglinear_lm = 0.25 * math.log(0.3) + 0.75 * math.log(0.2)
    ac = math.log(0.1)

    # LM scale 1.0, no word insertion penalty.
    token.recompute_total(0.25, 1.0, 0.0, True)
    assert_almost_equal(token.lm_logprob, linear_lm)
    assert_almost_equal(token.total_logprob, ac + linear_lm)
    token.recompute_total(0.25, 1.0, 0.0, False)
    assert_almost_equal(token.lm_logprob, loglinear_lm)
    assert_almost_equal(token.total_logprob, ac + loglinear_lm)

    # LM scale 10.0 multiplies only the LM part of the total.
    token.recompute_total(0.25, 10.0, 0.0, True)
    assert_almost_equal(token.lm_logprob, linear_lm)
    assert_almost_equal(token.total_logprob, ac + linear_lm * 10.0)
    token.recompute_total(0.25, 10.0, 0.0, False)
    assert_almost_equal(token.lm_logprob, loglinear_lm)
    assert_almost_equal(token.total_logprob, ac + loglinear_lm * 10.0)

    # Word insertion penalty -20.0 is applied per history word (2 words
    # here), i.e. -40.0 in the total.
    token.recompute_total(0.25, 10.0, -20.0, True)
    assert_almost_equal(token.lm_logprob, linear_lm)
    assert_almost_equal(token.total_logprob, ac + linear_lm * 10.0 - 40.0)
    token.recompute_total(0.25, 10.0, -20.0, False)
    assert_almost_equal(token.lm_logprob, loglinear_lm)
    assert_almost_equal(token.total_logprob, ac + loglinear_lm * 10.0 - 40.0)

    # Linear interpolation must stay stable when the probabilities underflow
    # in the probability domain.
    token = LatticeDecoder.Token(history=[1, 2],
                                 ac_logprob=-1000,
                                 lat_lm_logprob=-1001,
                                 nn_lm_logprob=-1002)
    token.recompute_total(0.75, 1.0, 0.0, True)
    # ln(exp(-1000) * (0.75 * exp(-1002) + 0.25 * exp(-1001)))
    assert_almost_equal(token.total_logprob, -2001.64263, decimal=4)
def test_recompute_hash(self):
    """Verify that recompute_hash() respects the recombination order.

    The two histories differ only in their first word, so they must hash
    differently when the full history (or the last five words) is
    considered, and identically when only the last four words are.
    """
    first = LatticeDecoder.Token(history=[1, 12, 203, 3004, 23455])
    second = LatticeDecoder.Token(history=[2, 12, 203, 3004, 23455])
    # Both None (unlimited) and 5 cover the differing first word.
    for order in (None, 5):
        first.recompute_hash(order)
        second.recompute_hash(order)
        self.assertNotEqual(first.recombination_hash,
                            second.recombination_hash)
    # Order 4 looks only at the common suffix.
    first.recompute_hash(4)
    second.recompute_hash(4)
    self.assertEqual(first.recombination_hash, second.recombination_hash)
def test_copy_token(self):
    """Verify that copying a token deep-copies a list-valued history."""
    original = LatticeDecoder.Token([1, 2, 3])
    duplicate = LatticeDecoder.Token.copy(original)
    # Appending to the copy's history must not leak into the original.
    duplicate.history.append(4)
    self.assertSequenceEqual(original.history, [1, 2, 3])
    self.assertSequenceEqual(duplicate.history, [1, 2, 3, 4])
def test_copy_token_tuple(self):
    """Verify copying a token whose history is an immutable tuple.

    This method was previously also named ``test_copy_token``, which made
    it silently replace the list-based test of the same name in the class
    body, so only one of the two ever ran. Renamed so both are collected.
    """
    history = (1, 2, 3)
    token1 = LatticeDecoder.Token(history)
    token2 = LatticeDecoder.Token.copy(token1)
    # Tuples cannot be mutated in place; rebinding the copy's history must
    # leave the original token untouched.
    token2.history = token2.history + (4, )
    self.assertSequenceEqual(token1.history, (1, 2, 3))
    self.assertSequenceEqual(token2.history, (1, 2, 3, 4))
def __init__(self):
    """Build a five-node lattice fixture with tokens attached to each node.

    Fixes the comprehension variable that shadowed the builtin ``id`` and
    replaces the long run of index-by-index assignments with data-driven
    construction producing the identical state:

    - node times: 0.0, 1.0, 1.0, None, 3.0
    - per-node tokens as (total_logprob, recombination_hash) pairs; the
      last node has no tokens and therefore no ``best_logprob``.
    """
    self._sorted_nodes = [Lattice.Node(node_id) for node_id in range(5)]
    for node, time in zip(self._sorted_nodes, (0.0, 1.0, 1.0, None, 3.0)):
        node.time = time

    token_params = [
        [(-10.0, 1)],
        [(-20.0, 1)],
        [(-30.0, 1), (-50.0, 2), (-70.0, 3)],
        [(-100.0, 1)],
        [],
    ]
    self._tokens = []
    for node, params in zip(self._sorted_nodes, token_params):
        node_tokens = []
        for total_logprob, recombination_hash in params:
            token = LatticeDecoder.Token()
            token.total_logprob = total_logprob
            token.recombination_hash = recombination_hash
            node_tokens.append(token)
        self._tokens.append(node_tokens)
        # Best logprob of a node is that of its first (best) token; nodes
        # without tokens (the last one) get no best_logprob, as before.
        if node_tokens:
            node.best_logprob = node_tokens[0].total_logprob
def test_append_word(self):
    """Verify that _append_word() extends token histories, advances the
    recurrent network state, and accumulates NN LM log probabilities.

    The sequence of decoder calls below is strictly order-dependent: each
    _append_word() mutates the tokens in place, and the expected state and
    logprob values build on the previous step.
    """
    decoding_options = {
        'nnlm_weight': 1.0,
        'lm_scale': 1.0,
        'wi_penalty': 0.0,
        'ignore_unk': False,
        'unk_penalty': 0.0,
        'linear_interpolation': False,
        'max_tokens_per_node': 10,
        'beam': None,
        'recombination_order': None
    }
    initial_state = RecurrentState(self.network.recurrent_state_size)
    token1 = LatticeDecoder.Token(history=[self.sos_id], state=initial_state)
    token2 = LatticeDecoder.Token(history=[self.sos_id, self.yksi_id],
                                  state=initial_state)
    decoder = LatticeDecoder(self.network, decoding_options)

    # Fresh tokens: recurrent state is all zeros and no NN LM mass yet.
    # NOTE(review): the (1, 1, 3) shape presumably matches the dummy
    # network's recurrent_state_size — confirm against the fixture.
    self.assertSequenceEqual(token1.history, [self.sos_id])
    self.assertSequenceEqual(token2.history, [self.sos_id, self.yksi_id])
    assert_equal(token1.state.get(0),
                 numpy.zeros(shape=(1, 1, 3)).astype(theano.config.floatX))
    assert_equal(token2.state.get(0),
                 numpy.zeros(shape=(1, 1, 3)).astype(theano.config.floatX))
    self.assertEqual(token1.nn_lm_logprob, 0.0)
    self.assertEqual(token2.nn_lm_logprob, 0.0)

    # Append one word: histories grow, the mock network advances its state
    # from zeros to ones, and each token accumulates the logprob of the
    # appended word given its previous word.
    decoder._append_word([token1, token2], self.kaksi_id)
    self.assertSequenceEqual(token1.history, [self.sos_id, self.kaksi_id])
    self.assertSequenceEqual(token2.history,
                             [self.sos_id, self.yksi_id, self.kaksi_id])
    assert_equal(token1.state.get(0),
                 numpy.ones(shape=(1, 1, 3)).astype(theano.config.floatX))
    assert_equal(token2.state.get(0),
                 numpy.ones(shape=(1, 1, 3)).astype(theano.config.floatX))
    token1_nn_lm_logprob = math.log(self.sos_prob + self.kaksi_prob)
    token2_nn_lm_logprob = math.log(self.yksi_prob + self.kaksi_prob)
    self.assertAlmostEqual(token1.nn_lm_logprob, token1_nn_lm_logprob)
    self.assertAlmostEqual(token2.nn_lm_logprob, token2_nn_lm_logprob)

    # Append the sentence end: state advances again (ones * 2) and another
    # word logprob is added on top of the running total.
    decoder._append_word([token1, token2], self.eos_id)
    self.assertSequenceEqual(token1.history,
                             [self.sos_id, self.kaksi_id, self.eos_id])
    self.assertSequenceEqual(
        token2.history,
        [self.sos_id, self.yksi_id, self.kaksi_id, self.eos_id])
    assert_equal(
        token1.state.get(0),
        numpy.ones(shape=(1, 1, 3)).astype(theano.config.floatX) * 2)
    assert_equal(
        token2.state.get(0),
        numpy.ones(shape=(1, 1, 3)).astype(theano.config.floatX) * 2)
    token1_nn_lm_logprob += math.log(self.kaksi_prob + self.eos_prob)
    token2_nn_lm_logprob += math.log(self.kaksi_prob + self.eos_prob)
    self.assertAlmostEqual(token1.nn_lm_logprob, token1_nn_lm_logprob)
    self.assertAlmostEqual(token2.nn_lm_logprob, token2_nn_lm_logprob)

    # Totals: NN LM logprob scaled by lm_scale, plus the -0.01 word
    # insertion penalty per history word (3 words for token1, 4 for
    # token2, hence -0.03 and -0.04).
    lm_scale = 2.0
    token1.recompute_total(1.0, lm_scale, -0.01)
    token2.recompute_total(1.0, lm_scale, -0.01)
    self.assertAlmostEqual(token1.total_logprob,
                           token1_nn_lm_logprob * lm_scale - 0.03)
    self.assertAlmostEqual(token2.total_logprob,
                           token2_nn_lm_logprob * lm_scale - 0.04)