Example #1
    def test_recompute_hash(self):
        token1 = LatticeDecoder.Token(history=[1, 12, 203, 3004, 23455])
        token2 = LatticeDecoder.Token(history=[2, 12, 203, 3004, 23455])
        token1.recompute_hash(None)
        token2.recompute_hash(None)
        self.assertNotEqual(token1.recombination_hash,
                            token2.recombination_hash)
        token1.recompute_hash(5)
        token2.recompute_hash(5)
        self.assertNotEqual(token1.recombination_hash,
                            token2.recombination_hash)
        token1.recompute_hash(4)
        token2.recompute_hash(4)
        self.assertEqual(token1.recombination_hash, token2.recombination_hash)
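What the test pins down can be captured in a minimal sketch (recombination_hash below is a hypothetical stand-in, not the project's implementation): with a limit of None the full history is hashed, otherwise only the last limit words, so the two histories above, which differ only in their first word, collide once the limit drops to 4.

def recombination_hash(history, limit=None):
    # Hash the full history, or only its last `limit` words.
    key = tuple(history) if limit is None else tuple(history[-limit:])
    return hash(key)

assert recombination_hash([1, 12, 203, 3004, 23455], 4) == \
       recombination_hash([2, 12, 203, 3004, 23455], 4)
assert recombination_hash([1, 12, 203, 3004, 23455], 5) != \
       recombination_hash([2, 12, 203, 3004, 23455], 5)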
Example #2
    def test_recompute_total(self):
        token = LatticeDecoder.Token(history=[1, 2],
                                     ac_logprob=math.log(0.1),
                                     lat_lm_logprob=math.log(0.2),
                                     nn_lm_logprob=math.log(0.3))
        token.recompute_total(0.25, 1.0, 0.0, True)
        assert_almost_equal(token.lm_logprob,
                            math.log(0.25 * 0.3 + 0.75 * 0.2))
        assert_almost_equal(token.total_logprob,
                            math.log(0.1 * (0.25 * 0.3 + 0.75 * 0.2)))
        token.recompute_total(0.25, 1.0, 0.0, False)
        assert_almost_equal(token.lm_logprob,
                            0.25 * math.log(0.3) + 0.75 * math.log(0.2))
        assert_almost_equal(
            token.total_logprob,
            math.log(0.1) + 0.25 * math.log(0.3) + 0.75 * math.log(0.2))
        token.recompute_total(0.25, 10.0, 0.0, True)
        assert_almost_equal(token.lm_logprob,
                            math.log(0.25 * 0.3 + 0.75 * 0.2))
        assert_almost_equal(
            token.total_logprob,
            math.log(0.1) + math.log(0.25 * 0.3 + 0.75 * 0.2) * 10.0)
        token.recompute_total(0.25, 10.0, 0.0, False)
        assert_almost_equal(token.lm_logprob,
                            0.25 * math.log(0.3) + 0.75 * math.log(0.2))
        assert_almost_equal(
            token.total_logprob,
            math.log(0.1) +
            (0.25 * math.log(0.3) + 0.75 * math.log(0.2)) * 10.0)
        token.recompute_total(0.25, 10.0, -20.0, True)
        assert_almost_equal(token.lm_logprob,
                            math.log(0.25 * 0.3 + 0.75 * 0.2))
        assert_almost_equal(
            token.total_logprob,
            math.log(0.1) + math.log(0.25 * 0.3 + 0.75 * 0.2) * 10.0 - 40.0)
        token.recompute_total(0.25, 10.0, -20.0, False)
        assert_almost_equal(token.lm_logprob,
                            0.25 * math.log(0.3) + 0.75 * math.log(0.2))
        assert_almost_equal(
            token.total_logprob,
            math.log(0.1) +
            (0.25 * math.log(0.3) + 0.75 * math.log(0.2)) * 10.0 - 40.0)

        token = LatticeDecoder.Token(history=[1, 2],
                                     ac_logprob=-1000,
                                     lat_lm_logprob=-1001,
                                     nn_lm_logprob=-1002)
        token.recompute_total(0.75, 1.0, 0.0, True)
        # ln(exp(-1000) * (0.75 * exp(-1002) + 0.25 * exp(-1001)))
        assert_almost_equal(token.total_logprob, -2001.64263, decimal=4)
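Reading the assertions backwards gives the apparent contract of recompute_total(nnlm_weight, lm_scale, wi_penalty, linear_interpolation), with the argument names borrowed from the decoding_options dictionaries elsewhere in these examples: with linear interpolation the LM term is log(w * p_nn + (1 - w) * p_lat), with log-linear interpolation it is w * log(p_nn) + (1 - w) * log(p_lat), and the total is ac_logprob plus lm_scale times the LM term plus wi_penalty per history word (hence -20.0 contributing -40.0 over the two-word history). The final -2001.64263 figure also checks out; a quick verification, not the project's code:

import math

# Linear interpolation, weight 0.25 on the NN LM:
lm_linear = math.log(0.25 * 0.3 + 0.75 * 0.2)
# Log-linear interpolation of the same probabilities:
lm_loglinear = 0.25 * math.log(0.3) + 0.75 * math.log(0.2)

# ln(exp(-1000) * (0.75 * exp(-1002) + 0.25 * exp(-1001))), computed
# stably by factoring exp(-1001) out of the sum:
total = -1000 - 1001 + math.log(0.75 * math.exp(-1) + 0.25)
print(total)  # -2001.6426..., matching the assertion to four decimals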
Example #3
    def test_append_word(self):
        decoding_options = {
            'nnlm_weight': 1.0,
            'lm_scale': 1.0,
            'wi_penalty': 0.0,
            'ignore_unk': False,
            'unk_penalty': 0.0,
            'linear_interpolation': False,
            'max_tokens_per_node': 10,
            'beam': None,
            'recombination_order': None
        }

        initial_state = RecurrentState(self.network.recurrent_state_size)
        token1 = LatticeDecoder.Token(history=[self.sos_id], state=initial_state)
        token2 = LatticeDecoder.Token(history=[self.sos_id, self.yksi_id], state=initial_state)
        decoder = LatticeDecoder(self.network, decoding_options)

        self.assertSequenceEqual(token1.history, [self.sos_id])
        self.assertSequenceEqual(token2.history, [self.sos_id, self.yksi_id])
        assert_equal(token1.state.get(0), numpy.zeros(shape=(1,1,3)).astype(theano.config.floatX))
        assert_equal(token2.state.get(0), numpy.zeros(shape=(1,1,3)).astype(theano.config.floatX))
        self.assertEqual(token1.nn_lm_logprob, 0.0)
        self.assertEqual(token2.nn_lm_logprob, 0.0)

        decoder._append_word([token1, token2], self.kaksi_id)
        self.assertSequenceEqual(token1.history, [self.sos_id, self.kaksi_id])
        self.assertSequenceEqual(token2.history, [self.sos_id, self.yksi_id, self.kaksi_id])
        assert_equal(token1.state.get(0), numpy.ones(shape=(1,1,3)).astype(theano.config.floatX))
        assert_equal(token2.state.get(0), numpy.ones(shape=(1,1,3)).astype(theano.config.floatX))
        token1_nn_lm_logprob = math.log(self.sos_prob + self.kaksi_prob)
        token2_nn_lm_logprob = math.log(self.yksi_prob + self.kaksi_prob)
        self.assertAlmostEqual(token1.nn_lm_logprob, token1_nn_lm_logprob)
        self.assertAlmostEqual(token2.nn_lm_logprob, token2_nn_lm_logprob)

        decoder._append_word([token1, token2], self.eos_id)
        self.assertSequenceEqual(token1.history, [self.sos_id, self.kaksi_id, self.eos_id])
        self.assertSequenceEqual(token2.history, [self.sos_id, self.yksi_id, self.kaksi_id, self.eos_id])
        assert_equal(token1.state.get(0), numpy.ones(shape=(1,1,3)).astype(theano.config.floatX) * 2)
        assert_equal(token2.state.get(0), numpy.ones(shape=(1,1,3)).astype(theano.config.floatX) * 2)
        token1_nn_lm_logprob += math.log(self.kaksi_prob + self.eos_prob)
        token2_nn_lm_logprob += math.log(self.kaksi_prob + self.eos_prob)
        self.assertAlmostEqual(token1.nn_lm_logprob, token1_nn_lm_logprob)
        self.assertAlmostEqual(token2.nn_lm_logprob, token2_nn_lm_logprob)

        lm_scale = 2.0
        token1.recompute_total(1.0, lm_scale, -0.01)
        token2.recompute_total(1.0, lm_scale, -0.01)
        self.assertAlmostEqual(token1.total_logprob, token1_nn_lm_logprob * lm_scale - 0.03)
        self.assertAlmostEqual(token2.total_logprob, token2_nn_lm_logprob * lm_scale - 0.04)
Example #4
    def test_copy_token(self):
        history = [1, 2, 3]
        token1 = LatticeDecoder.Token(history)
        token2 = LatticeDecoder.Token.copy(token1)
        token2.history.append(4)
        self.assertSequenceEqual(token1.history, [1, 2, 3])
        self.assertSequenceEqual(token2.history, [1, 2, 3, 4])
Example #5
    def test_copy_token(self):
        history = (1, 2, 3)
        token1 = LatticeDecoder.Token(history)
        token2 = LatticeDecoder.Token.copy(token1)
        token2.history = token2.history + (4, )
        self.assertSequenceEqual(token1.history, (1, 2, 3))
        self.assertSequenceEqual(token2.history, (1, 2, 3, 4))
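Both variants pin down the same contract: mutating the copy's history must not leak into the original. With a list history that requires Token.copy to duplicate the container; with a tuple, the rebinding token2.history + (4, ) creates a new object anyway. The list case in plain Python, independent of the Token class:

history = [1, 2, 3]
copied = list(history)        # a new list object holding the same elements
copied.append(4)
assert history == [1, 2, 3]   # the original is untouched
assert copied == [1, 2, 3, 4]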
Example #6
    def __init__(self):
        self._sorted_nodes = [Lattice.Node(id) for id in range(5)]
        self._sorted_nodes[0].time = 0.0
        self._sorted_nodes[1].time = 1.0
        self._sorted_nodes[2].time = 1.0
        self._sorted_nodes[3].time = None
        self._sorted_nodes[4].time = 3.0
        self._tokens = [[LatticeDecoder.Token()], [LatticeDecoder.Token()],
                        [
                            LatticeDecoder.Token(),
                            LatticeDecoder.Token(),
                            LatticeDecoder.Token()
                        ], [LatticeDecoder.Token()], []]
        self._tokens[0][0].total_logprob = -10.0
        self._tokens[0][0].recombination_hash = 1
        self._sorted_nodes[0].best_logprob = -10.0
        self._tokens[1][0].total_logprob = -20.0
        self._tokens[1][0].recombination_hash = 1
        self._sorted_nodes[1].best_logprob = -20.0
        self._tokens[2][0].total_logprob = -30.0
        self._tokens[2][0].recombination_hash = 1
        self._tokens[2][1].total_logprob = -50.0
        self._tokens[2][1].recombination_hash = 2
        self._tokens[2][2].total_logprob = -70.0
        self._tokens[2][2].recombination_hash = 3
        self._sorted_nodes[2].best_logprob = -30.0
        self._tokens[3][0].total_logprob = -100.0
        self._tokens[3][0].recombination_hash = 1
        self._sorted_nodes[3].best_logprob = -100.0
Example #7
def decode(args):
    log_file = args.log_file
    log_level = getattr(logging, args.log_level.upper(), None)
    if not isinstance(log_level, int):
        print("Invalid logging level requested:", args.log_level)
        sys.exit(1)
    log_format = '%(asctime)s %(funcName)s: %(message)s'
    if args.log_file == '-':
        logging.basicConfig(stream=sys.stdout, format=log_format, level=log_level)
    else:
        logging.basicConfig(filename=log_file, format=log_format, level=log_level)

    if args.debug:
        theano.config.compute_test_value = 'warn'
    else:
        theano.config.compute_test_value = 'off'
    theano.config.profile = args.profile
    theano.config.profile_memory = args.profile

    with h5py.File(args.model_path, 'r') as state:
        print("Reading vocabulary from network state.")
        sys.stdout.flush()
        vocabulary = Vocabulary.from_state(state)
        print("Number of words in vocabulary:", vocabulary.num_words())
        print("Number of word classes:", vocabulary.num_classes())
        print("Building neural network.")
        sys.stdout.flush()
        architecture = Architecture.from_state(state)
        network = Network(architecture, vocabulary,
                          mode=Network.Mode(minibatch=False))
        print("Restoring neural network state.")
        sys.stdout.flush()
        network.set_state(state)

    log_scale = 1.0 if args.log_base is None else numpy.log(args.log_base)

    if args.wi_penalty is None:
        wi_penalty = None
    else:
        wi_penalty = args.wi_penalty * log_scale
    if args.unk_penalty is None:
        ignore_unk = False
        unk_penalty = None
    elif args.unk_penalty == 0:
        ignore_unk = True
        unk_penalty = None
    else:
        ignore_unk = False
        unk_penalty = args.unk_penalty
    decoding_options = {
        'nnlm_weight': args.nnlm_weight,
        'lm_scale': args.lm_scale,
        'wi_penalty': wi_penalty,
        'ignore_unk': ignore_unk,
        'unk_penalty': unk_penalty,
        'linear_interpolation': args.linear_interpolation,
        'max_tokens_per_node': args.max_tokens_per_node,
        'beam': args.beam,
        'recombination_order': args.recombination_order
    }
    logging.debug("DECODING OPTIONS")
    for option_name, option_value in decoding_options.items():
        logging.debug("%s: %s", option_name, str(option_value))

    print("Building word lattice decoder.")
    sys.stdout.flush()
    decoder = LatticeDecoder(network, decoding_options)

    # Combine paths from command line and lattice list.
    lattices = args.lattices
    lattices.extend(args.lattice_list.readlines())
    lattices = [path.strip() for path in lattices]
    # Ignore empty lines in the lattice list.
    lattices = list(filter(None, lattices))
    # Pick every Ith lattice, if --num-jobs is specified and > 1.
    if args.num_jobs < 1:
        print("Invalid number of jobs specified:", args.num_jobs)
        sys.exit(1)
    if (args.job < 0) or (args.job > args.num_jobs - 1):
        print("Invalid job specified:", args.job)
        sys.exit(1)
    lattices = lattices[args.job::args.num_jobs]

    file_type = TextFileType('r')
    for index, path in enumerate(lattices):
        logging.info("Reading word lattice: %s", path)
        lattice_file = file_type(path)
        lattice = SLFLattice(lattice_file)

        if lattice.utterance_id is not None:
            utterance_id = lattice.utterance_id
        else:
            utterance_id = os.path.basename(lattice_file.name)
        logging.info("Utterance `%s' -- %d/%d of job %d",
                     utterance_id,
                     index + 1,
                     len(lattices),
                     args.job)
        tokens = decoder.decode(lattice)

        for index in range(min(args.n_best, len(tokens))):
            line = format_token(tokens[index],
                                utterance_id,
                                vocabulary,
                                log_scale,
                                args.output)
            args.output_file.write(line + "\n")
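The slice lattices[args.job::args.num_jobs] deals the lattices out round-robin, so each lattice is handled by exactly one job and the per-job workloads differ by at most one. A quick illustration with hypothetical file names:

lattices = ['a.lat', 'b.lat', 'c.lat', 'd.lat', 'e.lat']
num_jobs = 2
print(lattices[0::num_jobs])  # job 0: ['a.lat', 'c.lat', 'e.lat']
print(lattices[1::num_jobs])  # job 1: ['b.lat', 'd.lat']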
Example #8
    def test_decode(self):
        vocabulary = Vocabulary.from_word_counts({
            'TO': 1,
            'AND': 1,
            'IT': 1,
            'BUT': 1,
            'A.': 1,
            'IN': 1,
            'A': 1,
            'AT': 1,
            'THE': 1,
            'E.': 1,
            "DIDN'T": 1,
            'ELABORATE': 1
        })
        projection_vector = tensor.ones(shape=(vocabulary.num_words(), ),
                                        dtype=theano.config.floatX)
        projection_vector *= 0.05
        network = DummyNetwork(vocabulary, projection_vector)

        decoding_options = {
            'nnlm_weight': 0.0,
            'lm_scale': None,
            'wi_penalty': None,
            'ignore_unk': False,
            'unk_penalty': None,
            'linear_interpolation': True,
            'max_tokens_per_node': None,
            'beam': None,
            'recombination_order': None
        }
        decoder = LatticeDecoder(network, decoding_options)
        tokens = decoder.decode(self.lattice)

        # Compare tokens to n-best list given by SRILM lattice-tool.
        log_scale = math.log(10)

        print()
        for token in tokens:
            print(token.ac_logprob / log_scale,
                  token.lat_lm_logprob / log_scale,
                  token.total_logprob / log_scale,
                  ' '.join(vocabulary.id_to_word[token.history]))

        all_paths = [
            "<s> IT DIDN'T ELABORATE </s>", "<s> BUT IT DIDN'T ELABORATE </s>",
            "<s> THE DIDN'T ELABORATE </s>",
            "<s> AND IT DIDN'T ELABORATE </s>", "<s> E. DIDN'T ELABORATE </s>",
            "<s> IN IT DIDN'T ELABORATE </s>", "<s> A DIDN'T ELABORATE </s>",
            "<s> AT IT DIDN'T ELABORATE </s>",
            "<s> IT IT DIDN'T ELABORATE </s>",
            "<s> TO IT DIDN'T ELABORATE </s>",
            "<s> A. IT DIDN'T ELABORATE </s>", "<s> A IT DIDN'T ELABORATE </s>"
        ]
        paths = [
            ' '.join(vocabulary.id_to_word[token.history]) for token in tokens
        ]
        self.assertListEqual(paths, all_paths)

        token = tokens[0]
        history = ' '.join(vocabulary.id_to_word[token.history])
        self.assertAlmostEqual(token.ac_logprob / log_scale,
                               -8686.28,
                               places=2)
        self.assertAlmostEqual(token.lat_lm_logprob / log_scale,
                               -94.3896,
                               places=2)
        self.assertAlmostEqual(token.nn_lm_logprob, math.log(0.1) * 4)

        token = tokens[1]
        self.assertAlmostEqual(token.ac_logprob / log_scale,
                               -8743.96,
                               places=2)
        self.assertAlmostEqual(token.lat_lm_logprob / log_scale,
                               -111.488,
                               places=2)
        self.assertAlmostEqual(token.nn_lm_logprob, math.log(0.1) * 5)

        token = tokens[-1]
        self.assertAlmostEqual(token.ac_logprob / log_scale,
                               -8696.26,
                               places=2)
        self.assertAlmostEqual(token.lat_lm_logprob / log_scale,
                               -178.00,
                               places=2)
        self.assertAlmostEqual(token.nn_lm_logprob, math.log(0.1) * 5)
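The division by log_scale = math.log(10) converts the decoder's natural-log scores into base-10 logprobs, the base SRILM's lattice-tool reports, via the identity log10(x) = ln(x) / ln(10):

import math

x = 0.001
assert math.isclose(math.log(x) / math.log(10), math.log10(x))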
Example #9
    def test_append_word(self):
        decoding_options = {
            'nnlm_weight': 1.0,
            'lm_scale': 1.0,
            'wi_penalty': 0.0,
            'ignore_unk': False,
            'unk_penalty': 0.0,
            'linear_interpolation': False,
            'max_tokens_per_node': 10,
            'beam': None,
            'recombination_order': None
        }

        initial_state = RecurrentState(self.network.recurrent_state_size)
        token1 = LatticeDecoder.Token(history=[self.sos_id],
                                      state=initial_state)
        token2 = LatticeDecoder.Token(history=[self.sos_id, self.yksi_id],
                                      state=initial_state)
        decoder = LatticeDecoder(self.network, decoding_options)

        self.assertSequenceEqual(token1.history, [self.sos_id])
        self.assertSequenceEqual(token2.history, [self.sos_id, self.yksi_id])
        assert_equal(token1.state.get(0),
                     numpy.zeros(shape=(1, 1, 3)).astype(theano.config.floatX))
        assert_equal(token2.state.get(0),
                     numpy.zeros(shape=(1, 1, 3)).astype(theano.config.floatX))
        self.assertEqual(token1.nn_lm_logprob, 0.0)
        self.assertEqual(token2.nn_lm_logprob, 0.0)

        decoder._append_word([token1, token2], self.kaksi_id)
        self.assertSequenceEqual(token1.history, [self.sos_id, self.kaksi_id])
        self.assertSequenceEqual(token2.history,
                                 [self.sos_id, self.yksi_id, self.kaksi_id])
        assert_equal(token1.state.get(0),
                     numpy.ones(shape=(1, 1, 3)).astype(theano.config.floatX))
        assert_equal(token2.state.get(0),
                     numpy.ones(shape=(1, 1, 3)).astype(theano.config.floatX))
        token1_nn_lm_logprob = math.log(self.sos_prob + self.kaksi_prob)
        token2_nn_lm_logprob = math.log(self.yksi_prob + self.kaksi_prob)
        self.assertAlmostEqual(token1.nn_lm_logprob, token1_nn_lm_logprob)
        self.assertAlmostEqual(token2.nn_lm_logprob, token2_nn_lm_logprob)

        decoder._append_word([token1, token2], self.eos_id)
        self.assertSequenceEqual(token1.history,
                                 [self.sos_id, self.kaksi_id, self.eos_id])
        self.assertSequenceEqual(
            token2.history,
            [self.sos_id, self.yksi_id, self.kaksi_id, self.eos_id])
        assert_equal(
            token1.state.get(0),
            numpy.ones(shape=(1, 1, 3)).astype(theano.config.floatX) * 2)
        assert_equal(
            token2.state.get(0),
            numpy.ones(shape=(1, 1, 3)).astype(theano.config.floatX) * 2)
        token1_nn_lm_logprob += math.log(self.kaksi_prob + self.eos_prob)
        token2_nn_lm_logprob += math.log(self.kaksi_prob + self.eos_prob)
        self.assertAlmostEqual(token1.nn_lm_logprob, token1_nn_lm_logprob)
        self.assertAlmostEqual(token2.nn_lm_logprob, token2_nn_lm_logprob)

        lm_scale = 2.0
        token1.recompute_total(1.0, lm_scale, -0.01)
        token2.recompute_total(1.0, lm_scale, -0.01)
        self.assertAlmostEqual(token1.total_logprob,
                               token1_nn_lm_logprob * lm_scale - 0.03)
        self.assertAlmostEqual(token2.total_logprob,
                               token2_nn_lm_logprob * lm_scale - 0.04)
Example #10
def decode(args):
    """A function that performs the "theanolm decode" command.

    :type args: argparse.Namespace
    :param args: a collection of command line arguments
    """

    log_file = args.log_file
    log_level = getattr(logging, args.log_level.upper(), None)
    if not isinstance(log_level, int):
        print("Invalid logging level requested:", args.log_level)
        sys.exit(1)
    log_format = '%(asctime)s %(funcName)s: %(message)s'
    if args.log_file == '-':
        logging.basicConfig(stream=sys.stdout,
                            format=log_format,
                            level=log_level)
    else:
        logging.basicConfig(filename=log_file,
                            format=log_format,
                            level=log_level)

    if args.debug:
        theano.config.compute_test_value = 'warn'
    else:
        theano.config.compute_test_value = 'off'
    theano.config.profile = args.profile
    theano.config.profile_memory = args.profile

    network = Network.from_file(args.model_path,
                                mode=Network.Mode(minibatch=False))

    log_scale = 1.0 if args.log_base is None else numpy.log(args.log_base)

    if args.wi_penalty is None:
        wi_penalty = None
    else:
        wi_penalty = args.wi_penalty * log_scale
    if args.unk_penalty is None:
        ignore_unk = False
        unk_penalty = None
    elif args.unk_penalty == 0:
        ignore_unk = True
        unk_penalty = None
    else:
        ignore_unk = False
        unk_penalty = args.unk_penalty
    decoding_options = {
        'nnlm_weight': args.nnlm_weight,
        'lm_scale': args.lm_scale,
        'wi_penalty': wi_penalty,
        'ignore_unk': ignore_unk,
        'unk_penalty': unk_penalty,
        'linear_interpolation': args.linear_interpolation,
        'max_tokens_per_node': args.max_tokens_per_node,
        'beam': args.beam,
        'recombination_order': args.recombination_order
    }
    logging.debug("DECODING OPTIONS")
    for option_name, option_value in decoding_options.items():
        logging.debug("%s: %s", option_name, str(option_value))

    print("Building word lattice decoder.")
    sys.stdout.flush()
    decoder = LatticeDecoder(network, decoding_options)

    # Combine paths from command line and lattice list.
    lattices = args.lattices
    if args.lattice_list is not None:
        lattices.extend(args.lattice_list.readlines())
    lattices = [path.strip() for path in lattices]
    # Ignore empty lines in the lattice list.
    lattices = [x for x in lattices if x]
    # Pick every Ith lattice, if --num-jobs is specified and > 1.
    if args.num_jobs < 1:
        print("Invalid number of jobs specified:", args.num_jobs)
        sys.exit(1)
    if (args.job < 0) or (args.job > args.num_jobs - 1):
        print("Invalid job specified:", args.job)
        sys.exit(1)
    lattices = lattices[args.job::args.num_jobs]

    file_type = TextFileType('r')
    for index, path in enumerate(lattices):
        logging.info("Reading word lattice: %s", path)
        lattice_file = file_type(path)
        lattice = SLFLattice(lattice_file)

        if lattice.utterance_id is not None:
            utterance_id = lattice.utterance_id
        else:
            utterance_id = os.path.basename(lattice_file.name)
        logging.info("Utterance `%s' -- %d/%d of job %d", utterance_id,
                     index + 1, len(lattices), args.job)
        tokens = decoder.decode(lattice)

        for index in range(min(args.n_best, len(tokens))):
            line = format_token(tokens[index], utterance_id,
                                network.vocabulary, log_scale, args.output)
            args.output_file.write(line + "\n")
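Compared with Example #7, this version guards against a missing lattice list and replaces list(filter(None, lattices)) with an equivalent comprehension; both spellings drop falsy entries such as empty lines:

lattices = ['a.lat', '', 'b.lat', '']
assert list(filter(None, lattices)) == [x for x in lattices if x]
assert [x for x in lattices if x] == ['a.lat', 'b.lat']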
Example #11
    def test_decode(self):
        vocabulary = Vocabulary.from_word_counts({
            'TO': 1,
            'AND': 1,
            'IT': 1,
            'BUT': 1,
            'A.': 1,
            'IN': 1,
            'A': 1,
            'AT': 1,
            'THE': 1,
            'E.': 1,
            "DIDN'T": 1,
            'ELABORATE': 1})
        projection_vector = tensor.ones(shape=(vocabulary.num_words(),),
                                        dtype=theano.config.floatX)
        projection_vector *= 0.05
        network = DummyNetwork(vocabulary, projection_vector)

        decoding_options = {
            'nnlm_weight': 0.0,
            'lm_scale': None,
            'wi_penalty': None,
            'ignore_unk': False,
            'unk_penalty': None,
            'linear_interpolation': True,
            'max_tokens_per_node': None,
            'beam': None,
            'recombination_order': None
        }
        decoder = LatticeDecoder(network, decoding_options)
        tokens = decoder.decode(self.lattice)

        # Compare tokens to n-best list given by SRILM lattice-tool.
        log_scale = math.log(10)

        print()
        for token in tokens:
            print(token.ac_logprob / log_scale,
                  token.lat_lm_logprob / log_scale,
                  token.total_logprob / log_scale,
                  ' '.join(vocabulary.id_to_word[token.history]))

        all_paths = ["<s> IT DIDN'T ELABORATE </s>",
                     "<s> BUT IT DIDN'T ELABORATE </s>",
                     "<s> THE DIDN'T ELABORATE </s>",
                     "<s> AND IT DIDN'T ELABORATE </s>",
                     "<s> E. DIDN'T ELABORATE </s>",
                     "<s> IN IT DIDN'T ELABORATE </s>",
                     "<s> A DIDN'T ELABORATE </s>",
                     "<s> AT IT DIDN'T ELABORATE </s>",
                     "<s> IT IT DIDN'T ELABORATE </s>",
                     "<s> TO IT DIDN'T ELABORATE </s>",
                     "<s> A. IT DIDN'T ELABORATE </s>",
                     "<s> A IT DIDN'T ELABORATE </s>"]
        paths = [' '.join(vocabulary.id_to_word[token.history])
                 for token in tokens]
        self.assertListEqual(paths, all_paths)

        token = tokens[0]
        history = ' '.join(vocabulary.id_to_word[token.history])
        self.assertAlmostEqual(token.ac_logprob / log_scale, -8686.28, places=2)
        self.assertAlmostEqual(token.lat_lm_logprob / log_scale, -94.3896, places=2)
        self.assertAlmostEqual(token.nn_lm_logprob, math.log(0.1) * 4)

        token = tokens[1]
        self.assertAlmostEqual(token.ac_logprob / log_scale, -8743.96, places=2)
        self.assertAlmostEqual(token.lat_lm_logprob / log_scale, -111.488, places=2)
        self.assertAlmostEqual(token.nn_lm_logprob, math.log(0.1) * 5)

        token = tokens[-1]
        self.assertAlmostEqual(token.ac_logprob / log_scale, -8696.26, places=2)
        self.assertAlmostEqual(token.lat_lm_logprob / log_scale, -178.00, places=2)
        self.assertAlmostEqual(token.nn_lm_logprob, math.log(0.1) * 5)
Example #12
    def test_decode(self):
        vocabulary = Vocabulary.from_word_counts({
            'to': 1,
            'and': 1,
            'it': 1,
            'but': 1,
            'a.': 1,
            'in': 1,
            'a': 1,
            'at': 1,
            'the': 1,
            "didn't": 1,
            'elaborate': 1
        })
        projection_vector = tensor.ones(
            shape=(vocabulary.num_shortlist_words(), ),
            dtype=theano.config.floatX)
        projection_vector *= 0.05
        network = DummyNetwork(vocabulary, projection_vector)

        decoding_options = {
            'nnlm_weight': 0.0,
            'lm_scale': None,
            'wi_penalty': None,
            'unk_penalty': None,
            'use_shortlist': False,
            'unk_from_lattice': False,
            'linear_interpolation': True,
            'max_tokens_per_node': None,
            'beam': None,
            'recombination_order': 20
        }
        decoder = LatticeDecoder(network, decoding_options)
        tokens = decoder.decode(self.lattice)[0]

        # Compare tokens to n-best list given by SRILM lattice-tool.
        log_scale = math.log(10)

        print()
        for token in tokens:
            print(token.ac_logprob / log_scale,
                  token.lat_lm_logprob / log_scale,
                  token.total_logprob / log_scale,
                  ' '.join(token.history_words(vocabulary)))

        all_paths = [
            "<s> it didn't elaborate </s>", "<s> but it didn't elaborate </s>",
            "<s> the didn't elaborate </s>",
            "<s> and it didn't elaborate </s>", "<s> e. didn't elaborate </s>",
            "<s> in it didn't elaborate </s>", "<s> a didn't elaborate </s>",
            "<s> at it didn't elaborate </s>",
            "<s> it it didn't elaborate </s>",
            "<s> to it didn't elaborate </s>",
            "<s> a. it didn't elaborate </s>", "<s> a it didn't elaborate </s>"
        ]
        paths = [' '.join(token.history_words(vocabulary)) for token in tokens]
        self.assertListEqual(paths, all_paths)

        token = tokens[0]
        history = ' '.join(token.history_words(vocabulary))
        self.assertAlmostEqual(token.ac_logprob / log_scale,
                               -8686.28,
                               places=2)
        self.assertAlmostEqual(token.lat_lm_logprob / log_scale,
                               -94.3896,
                               places=2)
        self.assertAlmostEqual(token.nn_lm_logprob, math.log(0.1) * 4)

        token = tokens[1]
        self.assertAlmostEqual(token.ac_logprob / log_scale,
                               -8743.96,
                               places=2)
        self.assertAlmostEqual(token.lat_lm_logprob / log_scale,
                               -111.488,
                               places=2)
        self.assertAlmostEqual(token.nn_lm_logprob, math.log(0.1) * 5)

        token = tokens[-1]
        self.assertAlmostEqual(token.ac_logprob / log_scale,
                               -8696.26,
                               places=2)
        self.assertAlmostEqual(token.lat_lm_logprob / log_scale,
                               -178.00,
                               places=2)
        self.assertAlmostEqual(token.nn_lm_logprob, math.log(0.1) * 5)
Example #13
def decode(args):
    """A function that performs the "theanolm decode" command.

    :type args: argparse.Namespace
    :param args: a collection of command line arguments
    """

    log_file = args.log_file
    log_level = getattr(logging, args.log_level.upper(), None)
    if not isinstance(log_level, int):
        print("Invalid logging level requested:", args.log_level,
              file=sys.stderr)
        sys.exit(1)
    log_format = '%(asctime)s %(funcName)s: %(message)s'
    if args.log_file == '-':
        logging.basicConfig(stream=sys.stdout, format=log_format, level=log_level)
    else:
        logging.basicConfig(filename=log_file, format=log_format, level=log_level)

    if args.debug:
        theano.config.compute_test_value = 'warn'
    else:
        theano.config.compute_test_value = 'off'
    theano.config.profile = args.profile
    theano.config.profile_memory = args.profile

    if (args.lattice_format == 'kaldi') or (args.output == 'kaldi'):
        if args.kaldi_vocabulary is None:
            print("Kaldi lattice vocabulary is not given.", file=sys.stderr)
            sys.exit(1)

    default_device = get_default_device(args.default_device)
    network = Network.from_file(args.model_path,
                                mode=Network.Mode(minibatch=False),
                                default_device=default_device)

    log_scale = 1.0 if args.log_base is None else numpy.log(args.log_base)
    if (args.log_base is not None) and (args.lattice_format == 'kaldi'):
        logging.info("Warning: Kaldi lattice reader doesn't support logarithm "
                     "base conversion.")

    if args.wi_penalty is None:
        wi_penalty = None
    else:
        wi_penalty = args.wi_penalty * log_scale
    decoding_options = {
        'nnlm_weight': args.nnlm_weight,
        'lm_scale': args.lm_scale,
        'wi_penalty': wi_penalty,
        'unk_penalty': args.unk_penalty,
        'use_shortlist': args.shortlist,
        'unk_from_lattice': args.unk_from_lattice,
        'linear_interpolation': args.linear_interpolation,
        'max_tokens_per_node': args.max_tokens_per_node,
        'beam': args.beam,
        'recombination_order': args.recombination_order,
        'prune_relative': args.prune_relative,
        'abs_min_max_tokens': args.abs_min_max_tokens,
        'abs_min_beam': args.abs_min_beam
    }
    logging.debug("DECODING OPTIONS")
    for option_name, option_value in decoding_options.items():
        logging.debug("%s: %s", option_name, str(option_value))

    logging.info("Building word lattice decoder.")
    decoder = LatticeDecoder(network, decoding_options)

    batch = LatticeBatch(args.lattices, args.lattice_list, args.lattice_format,
                         args.kaldi_vocabulary, args.num_jobs, args.job)
    for lattice_number, lattice in enumerate(batch):
        if lattice.utterance_id is None:
            lattice.utterance_id = str(lattice_number)
        logging.info("Utterance `%s´ -- %d of job %d",
                     lattice.utterance_id,
                     lattice_number + 1,
                     args.job)
        log_free_mem()

        final_tokens, recomb_tokens = decoder.decode(lattice)
        if (args.output == "slf") or (args.output == "kaldi"):
            rescored_lattice = RescoredLattice(lattice,
                                               final_tokens,
                                               recomb_tokens,
                                               network.vocabulary)
            rescored_lattice.lm_scale = args.lm_scale
            rescored_lattice.wi_penalty = args.wi_penalty
            if args.output == "slf":
                rescored_lattice.write_slf(args.output_file)
            else:
                assert args.output == "kaldi"
                rescored_lattice.write_kaldi(args.output_file,
                                             batch.kaldi_word_to_id)
        else:
            for token in final_tokens[:min(args.n_best, len(final_tokens))]:
                line = format_token(token,
                                    lattice.utterance_id,
                                    network.vocabulary,
                                    log_scale,
                                    args.output)
                args.output_file.write(line + "\n")
        gc.collect()
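A side note on the n-best loop: Python slices clamp to the sequence length, so final_tokens[:args.n_best] behaves identically to the min() form used above:

tokens = ['best', 'second', 'third']
assert tokens[:10] == tokens          # a slice end past the length is clamped
assert tokens[:min(10, len(tokens))] == tokens[:10]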