Example #1
0
    def test_recompute_total(self):
        """Check combined LM and total log-probabilities for linear and
        log-linear interpolation, different LM scales, and word insertion
        penalties.
        """
        acoustic = math.log(0.1)
        token = LatticeDecoder.Token(history=[1, 2],
                                     ac_logprob=acoustic,
                                     lat_lm_logprob=math.log(0.2),
                                     nn_lm_logprob=math.log(0.3))
        # Expected interpolated LM log-probabilities for NNLM weight 0.25.
        linear_lm = math.log(0.25 * 0.3 + 0.75 * 0.2)
        loglinear_lm = 0.25 * math.log(0.3) + 0.75 * math.log(0.2)

        # LM scale 1, no word insertion penalty.
        token.recompute_total(0.25, 1.0, 0.0, True)
        assert_almost_equal(token.lm_logprob, linear_lm)
        assert_almost_equal(token.total_logprob, acoustic + linear_lm)
        token.recompute_total(0.25, 1.0, 0.0, False)
        assert_almost_equal(token.lm_logprob, loglinear_lm)
        assert_almost_equal(token.total_logprob, acoustic + loglinear_lm)

        # LM scale 10 multiplies only the LM part of the total.
        token.recompute_total(0.25, 10.0, 0.0, True)
        assert_almost_equal(token.lm_logprob, linear_lm)
        assert_almost_equal(token.total_logprob, acoustic + linear_lm * 10.0)
        token.recompute_total(0.25, 10.0, 0.0, False)
        assert_almost_equal(token.lm_logprob, loglinear_lm)
        assert_almost_equal(token.total_logprob,
                            acoustic + loglinear_lm * 10.0)

        # Word insertion penalty -20 is applied once per history word.
        token.recompute_total(0.25, 10.0, -20.0, True)
        assert_almost_equal(token.lm_logprob, linear_lm)
        assert_almost_equal(token.total_logprob,
                            acoustic + linear_lm * 10.0 - 40.0)
        token.recompute_total(0.25, 10.0, -20.0, False)
        assert_almost_equal(token.lm_logprob, loglinear_lm)
        assert_almost_equal(token.total_logprob,
                            acoustic + loglinear_lm * 10.0 - 40.0)

        # Extreme log-probabilities must not underflow during linear
        # interpolation.
        token = LatticeDecoder.Token(history=[1, 2],
                                     ac_logprob=-1000,
                                     lat_lm_logprob=-1001,
                                     nn_lm_logprob=-1002)
        token.recompute_total(0.75, 1.0, 0.0, True)
        # ln(exp(-1000) * (0.75 * exp(-1002) + 0.25 * exp(-1001)))
        assert_almost_equal(token.total_logprob, -2001.64263, decimal=4)
Example #2
0
 def test_recompute_hash(self):
     """Verify that recombination hashes compare equal exactly when the
     compared suffix of the token histories is identical.
     """
     first = LatticeDecoder.Token(history=[1, 12, 203, 3004, 23455])
     second = LatticeDecoder.Token(history=[2, 12, 203, 3004, 23455])
     # The histories differ in the first word, so hashing the full
     # history (order None) or the last five words must differ.
     for order in (None, 5):
         first.recompute_hash(order)
         second.recompute_hash(order)
         self.assertNotEqual(first.recombination_hash,
                             second.recombination_hash)
     # The last four words are identical, so the hashes coincide.
     first.recompute_hash(4)
     second.recompute_hash(4)
     self.assertEqual(first.recombination_hash, second.recombination_hash)
Example #3
0
 def test_copy_token(self):
     """Copying a token must deep-copy its mutable history list."""
     original = LatticeDecoder.Token([1, 2, 3])
     clone = LatticeDecoder.Token.copy(original)
     # Mutating the copy's history must leave the original untouched.
     clone.history.append(4)
     self.assertSequenceEqual(original.history, [1, 2, 3])
     self.assertSequenceEqual(clone.history, [1, 2, 3, 4])
 def test_copy_token(self):
     """Copying a token with an immutable (tuple) history keeps the
     original history intact when the copy's history is rebound.
     """
     original = LatticeDecoder.Token((1, 2, 3))
     clone = LatticeDecoder.Token.copy(original)
     # Rebinding the copy's history must leave the original untouched.
     clone.history = clone.history + (4, )
     self.assertSequenceEqual(original.history, (1, 2, 3))
     self.assertSequenceEqual(clone.history, (1, 2, 3, 4))
Example #5
0
 def __init__(self):
     """Build a five-node lattice fixture with per-node tokens.

     Node 3 has no time stamp and node 4 has no tokens, so code under
     test can be exercised against missing data.
     """
     self._sorted_nodes = [Lattice.Node(node_id) for node_id in range(5)]
     for node, time in zip(self._sorted_nodes, (0.0, 1.0, 1.0, None, 3.0)):
         node.time = time
     # (total_logprob, recombination_hash) pairs, grouped by node.
     token_data = [[(-10.0, 1)],
                   [(-20.0, 1)],
                   [(-30.0, 1), (-50.0, 2), (-70.0, 3)],
                   [(-100.0, 1)],
                   []]
     self._tokens = []
     for node_tokens in token_data:
         tokens = []
         for total_logprob, recombination_hash in node_tokens:
             token = LatticeDecoder.Token()
             token.total_logprob = total_logprob
             token.recombination_hash = recombination_hash
             tokens.append(token)
         self._tokens.append(tokens)
     # The best logprob of each node that has tokens is the logprob of
     # its first (highest-scoring) token.
     for index in range(4):
         best = self._tokens[index][0].total_logprob
         self._sorted_nodes[index].best_logprob = best
Example #6
0
    def test_append_word(self):
        """_append_word() should extend each token's history, advance its
        recurrent state, and accumulate the NN LM log-probability.
        """
        decoding_options = {
            'nnlm_weight': 1.0,
            'lm_scale': 1.0,
            'wi_penalty': 0.0,
            'ignore_unk': False,
            'unk_penalty': 0.0,
            'linear_interpolation': False,
            'max_tokens_per_node': 10,
            'beam': None,
            'recombination_order': None
        }
        decoder = LatticeDecoder(self.network, decoding_options)

        initial_state = RecurrentState(self.network.recurrent_state_size)
        token1 = LatticeDecoder.Token(history=[self.sos_id],
                                      state=initial_state)
        token2 = LatticeDecoder.Token(history=[self.sos_id, self.yksi_id],
                                      state=initial_state)

        # Freshly created tokens carry a zero state and zero NN LM logprob.
        zeros = numpy.zeros(shape=(1, 1, 3)).astype(theano.config.floatX)
        self.assertSequenceEqual(token1.history, [self.sos_id])
        self.assertSequenceEqual(token2.history, [self.sos_id, self.yksi_id])
        assert_equal(token1.state.get(0), zeros)
        assert_equal(token2.state.get(0), zeros)
        self.assertEqual(token1.nn_lm_logprob, 0.0)
        self.assertEqual(token2.nn_lm_logprob, 0.0)

        # Appending a word updates history, state, and NN LM logprob.
        decoder._append_word([token1, token2], self.kaksi_id)
        ones = numpy.ones(shape=(1, 1, 3)).astype(theano.config.floatX)
        self.assertSequenceEqual(token1.history, [self.sos_id, self.kaksi_id])
        self.assertSequenceEqual(token2.history,
                                 [self.sos_id, self.yksi_id, self.kaksi_id])
        assert_equal(token1.state.get(0), ones)
        assert_equal(token2.state.get(0), ones)
        expected_lp1 = math.log(self.sos_prob + self.kaksi_prob)
        expected_lp2 = math.log(self.yksi_prob + self.kaksi_prob)
        self.assertAlmostEqual(token1.nn_lm_logprob, expected_lp1)
        self.assertAlmostEqual(token2.nn_lm_logprob, expected_lp2)

        # A second append accumulates on top of the previous step.
        decoder._append_word([token1, token2], self.eos_id)
        self.assertSequenceEqual(token1.history,
                                 [self.sos_id, self.kaksi_id, self.eos_id])
        self.assertSequenceEqual(
            token2.history,
            [self.sos_id, self.yksi_id, self.kaksi_id, self.eos_id])
        assert_equal(token1.state.get(0), ones * 2)
        assert_equal(token2.state.get(0), ones * 2)
        eos_step = math.log(self.kaksi_prob + self.eos_prob)
        expected_lp1 += eos_step
        expected_lp2 += eos_step
        self.assertAlmostEqual(token1.nn_lm_logprob, expected_lp1)
        self.assertAlmostEqual(token2.nn_lm_logprob, expected_lp2)

        # With penalty -0.01 per history word, the totals subtract 0.03
        # and 0.04 for the three- and four-word histories respectively.
        lm_scale = 2.0
        token1.recompute_total(1.0, lm_scale, -0.01)
        token2.recompute_total(1.0, lm_scale, -0.01)
        self.assertAlmostEqual(token1.total_logprob,
                               expected_lp1 * lm_scale - 0.03)
        self.assertAlmostEqual(token2.total_logprob,
                               expected_lp2 * lm_scale - 0.04)