def test_recompute_hash(self):
    """Tokens whose histories differ only in the oldest word should get
    different hashes for full-history recombination, and equal hashes once
    the recombination order excludes the oldest word.
    """
    first = LatticeDecoder.Token(history=[1, 12, 203, 3004, 23455])
    second = LatticeDecoder.Token(history=[2, 12, 203, 3004, 23455])
    # (recombination order, whether the two hashes should be equal)
    for order, should_match in ((None, False), (5, False), (4, True)):
        first.recompute_hash(order)
        second.recompute_hash(order)
        if should_match:
            self.assertEqual(first.recombination_hash,
                             second.recombination_hash)
        else:
            self.assertNotEqual(first.recombination_hash,
                                second.recombination_hash)
def test_recompute_total(self):
    """Check the total log probability under linear and log-linear
    interpolation, with different LM scales and word insertion penalties.
    """
    token = LatticeDecoder.Token(history=[1, 2],
                                 ac_logprob=math.log(0.1),
                                 lat_lm_logprob=math.log(0.2),
                                 nn_lm_logprob=math.log(0.3))
    # Interpolated LM log probabilities for NN LM weight 0.25.
    linear_lm = math.log(0.25 * 0.3 + 0.75 * 0.2)
    loglinear_lm = 0.25 * math.log(0.3) + 0.75 * math.log(0.2)
    acoustic = math.log(0.1)

    token.recompute_total(0.25, 1.0, 0.0, True)
    assert_almost_equal(token.lm_logprob, linear_lm)
    assert_almost_equal(token.total_logprob, acoustic + linear_lm)

    token.recompute_total(0.25, 1.0, 0.0, False)
    assert_almost_equal(token.lm_logprob, loglinear_lm)
    assert_almost_equal(token.total_logprob, acoustic + loglinear_lm)

    token.recompute_total(0.25, 10.0, 0.0, True)
    assert_almost_equal(token.lm_logprob, linear_lm)
    assert_almost_equal(token.total_logprob, acoustic + linear_lm * 10.0)

    token.recompute_total(0.25, 10.0, 0.0, False)
    assert_almost_equal(token.lm_logprob, loglinear_lm)
    assert_almost_equal(token.total_logprob, acoustic + loglinear_lm * 10.0)

    # The -20.0 insertion penalty is applied per history word; the
    # history contains two words, hence -40.0 in total.
    token.recompute_total(0.25, 10.0, -20.0, True)
    assert_almost_equal(token.lm_logprob, linear_lm)
    assert_almost_equal(token.total_logprob,
                        acoustic + linear_lm * 10.0 - 40.0)

    token.recompute_total(0.25, 10.0, -20.0, False)
    assert_almost_equal(token.lm_logprob, loglinear_lm)
    assert_almost_equal(token.total_logprob,
                        acoustic + loglinear_lm * 10.0 - 40.0)

    # Linear interpolation must remain numerically stable with extreme
    # log probabilities.
    token = LatticeDecoder.Token(history=[1, 2],
                                 ac_logprob=-1000,
                                 lat_lm_logprob=-1001,
                                 nn_lm_logprob=-1002)
    token.recompute_total(0.75, 1.0, 0.0, True)
    # ln(exp(-1000) * (0.75 * exp(-1002) + 0.25 * exp(-1001)))
    assert_almost_equal(token.total_logprob, -2001.64263, decimal=4)
def test_append_word(self):
    """Appending words must extend each token's history, advance its
    recurrent state, and accumulate the NN LM log probability.
    """
    decoding_options = {
        'nnlm_weight': 1.0,
        'lm_scale': 1.0,
        'wi_penalty': 0.0,
        'ignore_unk': False,
        'unk_penalty': 0.0,
        'linear_interpolation': False,
        'max_tokens_per_node': 10,
        'beam': None,
        'recombination_order': None}
    initial_state = RecurrentState(self.network.recurrent_state_size)
    token1 = LatticeDecoder.Token(history=[self.sos_id],
                                  state=initial_state)
    token2 = LatticeDecoder.Token(history=[self.sos_id, self.yksi_id],
                                  state=initial_state)
    decoder = LatticeDecoder(self.network, decoding_options)

    def state_filled_with(value):
        # Expected recurrent state contents after ``value`` appended
        # words (the asserts below show the dummy network increments
        # the state by one on each step).
        return numpy.ones(shape=(1, 1, 3)).astype(
            theano.config.floatX) * value

    self.assertSequenceEqual(token1.history, [self.sos_id])
    self.assertSequenceEqual(token2.history, [self.sos_id, self.yksi_id])
    assert_equal(token1.state.get(0), state_filled_with(0))
    assert_equal(token2.state.get(0), state_filled_with(0))
    self.assertEqual(token1.nn_lm_logprob, 0.0)
    self.assertEqual(token2.nn_lm_logprob, 0.0)

    decoder._append_word([token1, token2], self.kaksi_id)
    self.assertSequenceEqual(token1.history, [self.sos_id, self.kaksi_id])
    self.assertSequenceEqual(token2.history,
                             [self.sos_id, self.yksi_id, self.kaksi_id])
    assert_equal(token1.state.get(0), state_filled_with(1))
    assert_equal(token2.state.get(0), state_filled_with(1))
    logprob1 = math.log(self.sos_prob + self.kaksi_prob)
    logprob2 = math.log(self.yksi_prob + self.kaksi_prob)
    self.assertAlmostEqual(token1.nn_lm_logprob, logprob1)
    self.assertAlmostEqual(token2.nn_lm_logprob, logprob2)

    decoder._append_word([token1, token2], self.eos_id)
    self.assertSequenceEqual(token1.history,
                             [self.sos_id, self.kaksi_id, self.eos_id])
    self.assertSequenceEqual(
        token2.history,
        [self.sos_id, self.yksi_id, self.kaksi_id, self.eos_id])
    assert_equal(token1.state.get(0), state_filled_with(2))
    assert_equal(token2.state.get(0), state_filled_with(2))
    logprob1 += math.log(self.kaksi_prob + self.eos_prob)
    logprob2 += math.log(self.kaksi_prob + self.eos_prob)
    self.assertAlmostEqual(token1.nn_lm_logprob, logprob1)
    self.assertAlmostEqual(token2.nn_lm_logprob, logprob2)

    lm_scale = 2.0
    token1.recompute_total(1.0, lm_scale, -0.01)
    token2.recompute_total(1.0, lm_scale, -0.01)
    # The word insertion penalty (-0.01) is applied once per history
    # word: three words for token1, four for token2.
    self.assertAlmostEqual(token1.total_logprob,
                           logprob1 * lm_scale - 0.03)
    self.assertAlmostEqual(token2.total_logprob,
                           logprob2 * lm_scale - 0.04)
def test_copy_token(self):
    """Copying a token must copy the history list, so extending the
    copy's history leaves the original untouched.
    """
    original = LatticeDecoder.Token([1, 2, 3])
    duplicate = LatticeDecoder.Token.copy(original)
    duplicate.history.append(4)
    self.assertSequenceEqual(original.history, [1, 2, 3])
    self.assertSequenceEqual(duplicate.history, [1, 2, 3, 4])
def test_copy_token(self):
    """With an immutable (tuple) history, rebinding the copy's history
    must not affect the original token.
    """
    original = LatticeDecoder.Token((1, 2, 3))
    duplicate = LatticeDecoder.Token.copy(original)
    duplicate.history = duplicate.history + (4,)
    self.assertSequenceEqual(original.history, (1, 2, 3))
    self.assertSequenceEqual(duplicate.history, (1, 2, 3, 4))
def __init__(self):
    """Build a small dummy fixture: five sorted lattice nodes with tokens
    assigned to the first four, for exercising pruning/recombination.
    """
    self._sorted_nodes = [Lattice.Node(node_id) for node_id in range(5)]
    # Node 3 has no time stamp; the others are time-ordered.
    for node, time in zip(self._sorted_nodes, (0.0, 1.0, 1.0, None, 3.0)):
        node.time = time
    # One token at nodes 0, 1 and 3; three tokens at node 2; none at 4.
    self._tokens = [[LatticeDecoder.Token()],
                    [LatticeDecoder.Token()],
                    [LatticeDecoder.Token(),
                     LatticeDecoder.Token(),
                     LatticeDecoder.Token()],
                    [LatticeDecoder.Token()],
                    []]
    # (node index, token index, total logprob, recombination hash)
    for node_index, token_index, logprob, rehash in (
            (0, 0, -10.0, 1),
            (1, 0, -20.0, 1),
            (2, 0, -30.0, 1),
            (2, 1, -50.0, 2),
            (2, 2, -70.0, 3),
            (3, 0, -100.0, 1)):
        token = self._tokens[node_index][token_index]
        token.total_logprob = logprob
        token.recombination_hash = rehash
    # Highest total logprob among each node's tokens (node 4 has no
    # tokens and gets no best_logprob, as in the code under test).
    self._sorted_nodes[0].best_logprob = -10.0
    self._sorted_nodes[1].best_logprob = -20.0
    self._sorted_nodes[2].best_logprob = -30.0
    self._sorted_nodes[3].best_logprob = -100.0
def decode(args):
    """Perform the "theanolm decode" command: load the network state,
    build a lattice decoder, and write the n-best lists for each input
    lattice.

    :type args: argparse.Namespace
    :param args: a collection of command line arguments
    """
    log_file = args.log_file
    log_level = getattr(logging, args.log_level.upper(), None)
    if not isinstance(log_level, int):
        print("Invalid logging level requested:", args.log_level)
        sys.exit(1)
    log_format = '%(asctime)s %(funcName)s: %(message)s'
    if args.log_file == '-':
        logging.basicConfig(stream=sys.stdout, format=log_format,
                            level=log_level)
    else:
        logging.basicConfig(filename=log_file, format=log_format,
                            level=log_level)

    if args.debug:
        theano.config.compute_test_value = 'warn'
    else:
        theano.config.compute_test_value = 'off'
    theano.config.profile = args.profile
    theano.config.profile_memory = args.profile

    with h5py.File(args.model_path, 'r') as state:
        print("Reading vocabulary from network state.")
        sys.stdout.flush()
        vocabulary = Vocabulary.from_state(state)
        print("Number of words in vocabulary:", vocabulary.num_words())
        print("Number of word classes:", vocabulary.num_classes())
        print("Building neural network.")
        sys.stdout.flush()
        architecture = Architecture.from_state(state)
        network = Network(architecture, vocabulary,
                          mode=Network.Mode(minibatch=False))
        print("Restoring neural network state.")
        sys.stdout.flush()
        network.set_state(state)

    log_scale = 1.0 if args.log_base is None else numpy.log(args.log_base)

    if args.wi_penalty is None:
        wi_penalty = None
    else:
        wi_penalty = args.wi_penalty * log_scale
    # A zero UNK penalty is a special value: it requests ignoring
    # out-of-vocabulary words instead of penalizing them.
    if args.unk_penalty is None:
        ignore_unk = False
        unk_penalty = None
    elif args.unk_penalty == 0:
        ignore_unk = True
        unk_penalty = None
    else:
        ignore_unk = False
        unk_penalty = args.unk_penalty

    decoding_options = {
        'nnlm_weight': args.nnlm_weight,
        'lm_scale': args.lm_scale,
        'wi_penalty': wi_penalty,
        'ignore_unk': ignore_unk,
        'unk_penalty': unk_penalty,
        'linear_interpolation': args.linear_interpolation,
        'max_tokens_per_node': args.max_tokens_per_node,
        'beam': args.beam,
        'recombination_order': args.recombination_order}
    logging.debug("DECODING OPTIONS")
    for option_name, option_value in decoding_options.items():
        logging.debug("%s: %s", option_name, str(option_value))

    print("Building word lattice decoder.")
    sys.stdout.flush()
    decoder = LatticeDecoder(network, decoding_options)

    # Combine paths from command line and lattice list. Fix: the lattice
    # list is optional, so guard against None before reading it.
    lattices = args.lattices
    if args.lattice_list is not None:
        lattices.extend(args.lattice_list.readlines())
    lattices = [path.strip() for path in lattices]
    # Ignore empty lines in the lattice list.
    lattices = [path for path in lattices if path]

    # Pick every Ith lattice, if --num-jobs is specified and > 1.
    if args.num_jobs < 1:
        print("Invalid number of jobs specified:", args.num_jobs)
        sys.exit(1)
    if (args.job < 0) or (args.job > args.num_jobs - 1):
        print("Invalid job specified:", args.job)
        sys.exit(1)
    lattices = lattices[args.job::args.num_jobs]

    file_type = TextFileType('r')
    for lattice_number, path in enumerate(lattices):
        logging.info("Reading word lattice: %s", path)
        lattice_file = file_type(path)
        lattice = SLFLattice(lattice_file)
        if lattice.utterance_id is not None:
            utterance_id = lattice.utterance_id
        else:
            utterance_id = os.path.basename(lattice_file.name)
        logging.info("Utterance `%s' -- %d/%d of job %d",
                     utterance_id, lattice_number + 1, len(lattices),
                     args.job)
        tokens = decoder.decode(lattice)
        # Use a distinct loop variable so the n-best index does not
        # shadow the lattice counter.
        for rank in range(min(args.n_best, len(tokens))):
            line = format_token(tokens[rank], utterance_id, vocabulary,
                                log_scale, args.output)
            args.output_file.write(line + "\n")
def test_decode(self):
    """Decode the test lattice and compare the result to the n-best list
    produced by SRILM lattice-tool.
    """
    vocabulary = Vocabulary.from_word_counts({
        'TO': 1, 'AND': 1, 'IT': 1, 'BUT': 1, 'A.': 1, 'IN': 1, 'A': 1,
        'AT': 1, 'THE': 1, 'E.': 1, "DIDN'T": 1, 'ELABORATE': 1})
    projection_vector = tensor.ones(shape=(vocabulary.num_words(),),
                                    dtype=theano.config.floatX)
    projection_vector *= 0.05
    network = DummyNetwork(vocabulary, projection_vector)
    decoding_options = {'nnlm_weight': 0.0,
                        'lm_scale': None,
                        'wi_penalty': None,
                        'ignore_unk': False,
                        'unk_penalty': None,
                        'linear_interpolation': True,
                        'max_tokens_per_node': None,
                        'beam': None,
                        'recombination_order': None}
    decoder = LatticeDecoder(network, decoding_options)
    tokens = decoder.decode(self.lattice)

    # Compare tokens to the n-best list given by SRILM lattice-tool.
    log_scale = math.log(10)
    print()
    for token in tokens:
        print(token.ac_logprob / log_scale,
              token.lat_lm_logprob / log_scale,
              token.total_logprob / log_scale,
              ' '.join(vocabulary.id_to_word[token.history]))

    expected_paths = [
        "<s> IT DIDN'T ELABORATE </s>",
        "<s> BUT IT DIDN'T ELABORATE </s>",
        "<s> THE DIDN'T ELABORATE </s>",
        "<s> AND IT DIDN'T ELABORATE </s>",
        "<s> E. DIDN'T ELABORATE </s>",
        "<s> IN IT DIDN'T ELABORATE </s>",
        "<s> A DIDN'T ELABORATE </s>",
        "<s> AT IT DIDN'T ELABORATE </s>",
        "<s> IT IT DIDN'T ELABORATE </s>",
        "<s> TO IT DIDN'T ELABORATE </s>",
        "<s> A. IT DIDN'T ELABORATE </s>",
        "<s> A IT DIDN'T ELABORATE </s>"]
    actual_paths = [' '.join(vocabulary.id_to_word[token.history])
                    for token in tokens]
    self.assertListEqual(actual_paths, expected_paths)

    best = tokens[0]
    history = ' '.join(vocabulary.id_to_word[best.history])
    self.assertAlmostEqual(best.ac_logprob / log_scale,
                           -8686.28, places=2)
    self.assertAlmostEqual(best.lat_lm_logprob / log_scale,
                           -94.3896, places=2)
    self.assertAlmostEqual(best.nn_lm_logprob, math.log(0.1) * 4)

    runner_up = tokens[1]
    self.assertAlmostEqual(runner_up.ac_logprob / log_scale,
                           -8743.96, places=2)
    self.assertAlmostEqual(runner_up.lat_lm_logprob / log_scale,
                           -111.488, places=2)
    self.assertAlmostEqual(runner_up.nn_lm_logprob, math.log(0.1) * 5)

    worst = tokens[-1]
    self.assertAlmostEqual(worst.ac_logprob / log_scale,
                           -8696.26, places=2)
    self.assertAlmostEqual(worst.lat_lm_logprob / log_scale,
                           -178.00, places=2)
    self.assertAlmostEqual(worst.nn_lm_logprob, math.log(0.1) * 5)
def test_append_word(self):
    """Each _append_word() call should push the word onto every token's
    history, step the recurrent state, and add the word's NN LM log
    probability.
    """
    decoding_options = {
        'nnlm_weight': 1.0,
        'lm_scale': 1.0,
        'wi_penalty': 0.0,
        'ignore_unk': False,
        'unk_penalty': 0.0,
        'linear_interpolation': False,
        'max_tokens_per_node': 10,
        'beam': None,
        'recombination_order': None}
    initial_state = RecurrentState(self.network.recurrent_state_size)
    first = LatticeDecoder.Token(history=[self.sos_id],
                                 state=initial_state)
    second = LatticeDecoder.Token(history=[self.sos_id, self.yksi_id],
                                  state=initial_state)
    decoder = LatticeDecoder(self.network, decoding_options)

    def expected_state(step):
        # Recurrent state contents expected after ``step`` appended
        # words (the asserts below show the dummy network advances the
        # state by one per step).
        values = numpy.ones(shape=(1, 1, 3)).astype(theano.config.floatX)
        return values * step

    self.assertSequenceEqual(first.history, [self.sos_id])
    self.assertSequenceEqual(second.history, [self.sos_id, self.yksi_id])
    assert_equal(first.state.get(0), expected_state(0))
    assert_equal(second.state.get(0), expected_state(0))
    self.assertEqual(first.nn_lm_logprob, 0.0)
    self.assertEqual(second.nn_lm_logprob, 0.0)

    decoder._append_word([first, second], self.kaksi_id)
    self.assertSequenceEqual(first.history, [self.sos_id, self.kaksi_id])
    self.assertSequenceEqual(second.history,
                             [self.sos_id, self.yksi_id, self.kaksi_id])
    assert_equal(first.state.get(0), expected_state(1))
    assert_equal(second.state.get(0), expected_state(1))
    first_nn_lm_logprob = math.log(self.sos_prob + self.kaksi_prob)
    second_nn_lm_logprob = math.log(self.yksi_prob + self.kaksi_prob)
    self.assertAlmostEqual(first.nn_lm_logprob, first_nn_lm_logprob)
    self.assertAlmostEqual(second.nn_lm_logprob, second_nn_lm_logprob)

    decoder._append_word([first, second], self.eos_id)
    self.assertSequenceEqual(first.history,
                             [self.sos_id, self.kaksi_id, self.eos_id])
    self.assertSequenceEqual(
        second.history,
        [self.sos_id, self.yksi_id, self.kaksi_id, self.eos_id])
    assert_equal(first.state.get(0), expected_state(2))
    assert_equal(second.state.get(0), expected_state(2))
    first_nn_lm_logprob += math.log(self.kaksi_prob + self.eos_prob)
    second_nn_lm_logprob += math.log(self.kaksi_prob + self.eos_prob)
    self.assertAlmostEqual(first.nn_lm_logprob, first_nn_lm_logprob)
    self.assertAlmostEqual(second.nn_lm_logprob, second_nn_lm_logprob)

    lm_scale = 2.0
    first.recompute_total(1.0, lm_scale, -0.01)
    second.recompute_total(1.0, lm_scale, -0.01)
    # Insertion penalty -0.01 applies per history word: 3 words for the
    # first token, 4 for the second.
    self.assertAlmostEqual(first.total_logprob,
                           first_nn_lm_logprob * lm_scale - 0.03)
    self.assertAlmostEqual(second.total_logprob,
                           second_nn_lm_logprob * lm_scale - 0.04)
def decode(args):
    """A function that performs the "theanolm decode" command.

    Configures logging and Theano, loads the network, builds a lattice
    decoder, and writes an n-best list for every input lattice.

    :type args: argparse.Namespace
    :param args: a collection of command line arguments
    """
    # Validate the requested logging level before configuring logging.
    log_file = args.log_file
    log_level = getattr(logging, args.log_level.upper(), None)
    if not isinstance(log_level, int):
        print("Invalid logging level requested:", args.log_level)
        sys.exit(1)
    log_format = '%(asctime)s %(funcName)s: %(message)s'
    if args.log_file == '-':
        # '-' means log to standard output instead of a file.
        logging.basicConfig(stream=sys.stdout, format=log_format, level=log_level)
    else:
        logging.basicConfig(filename=log_file, format=log_format, level=log_level)
    if args.debug:
        theano.config.compute_test_value = 'warn'
    else:
        theano.config.compute_test_value = 'off'
    theano.config.profile = args.profile
    theano.config.profile_memory = args.profile
    network = Network.from_file(args.model_path, mode=Network.Mode(minibatch=False))
    # Conversion factor for log probabilities when --log-base is given;
    # 1.0 means natural logarithms are used as-is.
    log_scale = 1.0 if args.log_base is None else numpy.log(args.log_base)
    if args.wi_penalty is None:
        wi_penalty = None
    else:
        wi_penalty = args.wi_penalty * log_scale
    # A zero --unk-penalty is a special value: set ignore_unk instead of
    # applying a penalty of zero.
    if args.unk_penalty is None:
        ignore_unk = False
        unk_penalty = None
    elif args.unk_penalty == 0:
        ignore_unk = True
        unk_penalty = None
    else:
        ignore_unk = False
        unk_penalty = args.unk_penalty
    decoding_options = {
        'nnlm_weight': args.nnlm_weight,
        'lm_scale': args.lm_scale,
        'wi_penalty': wi_penalty,
        'ignore_unk': ignore_unk,
        'unk_penalty': unk_penalty,
        'linear_interpolation': args.linear_interpolation,
        'max_tokens_per_node': args.max_tokens_per_node,
        'beam': args.beam,
        'recombination_order': args.recombination_order}
    logging.debug("DECODING OPTIONS")
    for option_name, option_value in decoding_options.items():
        logging.debug("%s: %s", option_name, str(option_value))
    print("Building word lattice decoder.")
    sys.stdout.flush()
    decoder = LatticeDecoder(network, decoding_options)
    # Combine paths from command line and lattice list.
    lattices = args.lattices
    if args.lattice_list is not None:
        lattices.extend(args.lattice_list.readlines())
    lattices = [path.strip() for path in lattices]
    # Ignore empty lines in the lattice list.
    lattices = [x for x in lattices if x]
    # Pick every Ith lattice, if --num-jobs is specified and > 1.
    if args.num_jobs < 1:
        print("Invalid number of jobs specified:", args.num_jobs)
        sys.exit(1)
    if (args.job < 0) or (args.job > args.num_jobs - 1):
        print("Invalid job specified:", args.job)
        sys.exit(1)
    lattices = lattices[args.job::args.num_jobs]
    file_type = TextFileType('r')
    for index, path in enumerate(lattices):
        logging.info("Reading word lattice: %s", path)
        lattice_file = file_type(path)
        lattice = SLFLattice(lattice_file)
        # Fall back to the file name when the lattice has no utterance ID.
        if lattice.utterance_id is not None:
            utterance_id = lattice.utterance_id
        else:
            utterance_id = os.path.basename(lattice_file.name)
        logging.info("Utterance `%s' -- %d/%d of job %d",
                     utterance_id, index + 1, len(lattices), args.job)
        tokens = decoder.decode(lattice)
        # NOTE(review): the inner loop reuses ``index``, shadowing the
        # lattice counter; harmless here because the counter is not read
        # after this point, but worth renaming.
        for index in range(min(args.n_best, len(tokens))):
            line = format_token(tokens[index], utterance_id,
                                network.vocabulary, log_scale, args.output)
            args.output_file.write(line + "\n")
def test_decode(self):
    """Decode the test lattice and verify the n-best output against the
    reference produced by SRILM lattice-tool.
    """
    word_counts = {'TO': 1, 'AND': 1, 'IT': 1, 'BUT': 1, 'A.': 1,
                   'IN': 1, 'A': 1, 'AT': 1, 'THE': 1, 'E.': 1,
                   "DIDN'T": 1, 'ELABORATE': 1}
    vocabulary = Vocabulary.from_word_counts(word_counts)
    projection_vector = tensor.ones(shape=(vocabulary.num_words(),),
                                    dtype=theano.config.floatX)
    projection_vector *= 0.05
    network = DummyNetwork(vocabulary, projection_vector)
    options = {
        'nnlm_weight': 0.0,
        'lm_scale': None,
        'wi_penalty': None,
        'ignore_unk': False,
        'unk_penalty': None,
        'linear_interpolation': True,
        'max_tokens_per_node': None,
        'beam': None,
        'recombination_order': None}
    decoder = LatticeDecoder(network, options)
    tokens = decoder.decode(self.lattice)

    # Compare tokens to n-best list given by SRILM lattice-tool.
    log_scale = math.log(10)
    print()
    for token in tokens:
        print(token.ac_logprob / log_scale,
              token.lat_lm_logprob / log_scale,
              token.total_logprob / log_scale,
              ' '.join(vocabulary.id_to_word[token.history]))

    reference_paths = [
        "<s> IT DIDN'T ELABORATE </s>",
        "<s> BUT IT DIDN'T ELABORATE </s>",
        "<s> THE DIDN'T ELABORATE </s>",
        "<s> AND IT DIDN'T ELABORATE </s>",
        "<s> E. DIDN'T ELABORATE </s>",
        "<s> IN IT DIDN'T ELABORATE </s>",
        "<s> A DIDN'T ELABORATE </s>",
        "<s> AT IT DIDN'T ELABORATE </s>",
        "<s> IT IT DIDN'T ELABORATE </s>",
        "<s> TO IT DIDN'T ELABORATE </s>",
        "<s> A. IT DIDN'T ELABORATE </s>",
        "<s> A IT DIDN'T ELABORATE </s>"]
    paths = [' '.join(vocabulary.id_to_word[token.history])
             for token in tokens]
    self.assertListEqual(paths, reference_paths)

    def check_logprobs(token, ac, lat_lm, num_words):
        # Acoustic and lattice LM references are base-10 log
        # probabilities; the expected NN LM score is log(0.1) per word.
        self.assertAlmostEqual(token.ac_logprob / log_scale, ac,
                               places=2)
        self.assertAlmostEqual(token.lat_lm_logprob / log_scale, lat_lm,
                               places=2)
        self.assertAlmostEqual(token.nn_lm_logprob,
                               math.log(0.1) * num_words)

    history = ' '.join(vocabulary.id_to_word[tokens[0].history])
    check_logprobs(tokens[0], -8686.28, -94.3896, 4)
    check_logprobs(tokens[1], -8743.96, -111.488, 5)
    check_logprobs(tokens[-1], -8696.26, -178.00, 5)
def test_decode(self):
    """Decode the test lattice and check the n-best list against the
    reference produced by SRILM lattice-tool.
    """
    vocabulary = Vocabulary.from_word_counts({
        'to': 1, 'and': 1, 'it': 1, 'but': 1, 'a.': 1, 'in': 1, 'a': 1,
        'at': 1, 'the': 1, "didn't": 1, 'elaborate': 1})
    projection_vector = tensor.ones(
        shape=(vocabulary.num_shortlist_words(),),
        dtype=theano.config.floatX)
    projection_vector *= 0.05
    network = DummyNetwork(vocabulary, projection_vector)
    decoder = LatticeDecoder(network, {
        'nnlm_weight': 0.0,
        'lm_scale': None,
        'wi_penalty': None,
        'unk_penalty': None,
        'use_shortlist': False,
        'unk_from_lattice': False,
        'linear_interpolation': True,
        'max_tokens_per_node': None,
        'beam': None,
        'recombination_order': 20})
    # decode() returns final tokens and recombined tokens; only the
    # final tokens are examined here.
    tokens = decoder.decode(self.lattice)[0]

    # Compare tokens to the n-best list given by SRILM lattice-tool.
    log_scale = math.log(10)
    print()
    for token in tokens:
        print(token.ac_logprob / log_scale,
              token.lat_lm_logprob / log_scale,
              token.total_logprob / log_scale,
              ' '.join(token.history_words(vocabulary)))

    expected_paths = [
        "<s> it didn't elaborate </s>",
        "<s> but it didn't elaborate </s>",
        "<s> the didn't elaborate </s>",
        "<s> and it didn't elaborate </s>",
        "<s> e. didn't elaborate </s>",
        "<s> in it didn't elaborate </s>",
        "<s> a didn't elaborate </s>",
        "<s> at it didn't elaborate </s>",
        "<s> it it didn't elaborate </s>",
        "<s> to it didn't elaborate </s>",
        "<s> a. it didn't elaborate </s>",
        "<s> a it didn't elaborate </s>"]
    actual_paths = [' '.join(token.history_words(vocabulary))
                    for token in tokens]
    self.assertListEqual(actual_paths, expected_paths)

    best = tokens[0]
    history = ' '.join(best.history_words(vocabulary))
    self.assertAlmostEqual(best.ac_logprob / log_scale,
                           -8686.28, places=2)
    self.assertAlmostEqual(best.lat_lm_logprob / log_scale,
                           -94.3896, places=2)
    self.assertAlmostEqual(best.nn_lm_logprob, math.log(0.1) * 4)

    runner_up = tokens[1]
    self.assertAlmostEqual(runner_up.ac_logprob / log_scale,
                           -8743.96, places=2)
    self.assertAlmostEqual(runner_up.lat_lm_logprob / log_scale,
                           -111.488, places=2)
    self.assertAlmostEqual(runner_up.nn_lm_logprob, math.log(0.1) * 5)

    worst = tokens[-1]
    self.assertAlmostEqual(worst.ac_logprob / log_scale,
                           -8696.26, places=2)
    self.assertAlmostEqual(worst.lat_lm_logprob / log_scale,
                           -178.00, places=2)
    self.assertAlmostEqual(worst.nn_lm_logprob, math.log(0.1) * 5)
def decode(args):
    """A function that performs the "theanolm decode" command.

    Configures logging and Theano, loads the network, decodes each
    lattice in the batch, and writes either a rescored lattice (SLF or
    Kaldi format) or an n-best list per utterance.

    :type args: argparse.Namespace
    :param args: a collection of command line arguments
    """
    # Validate the requested logging level before configuring logging.
    log_file = args.log_file
    log_level = getattr(logging, args.log_level.upper(), None)
    if not isinstance(log_level, int):
        print("Invalid logging level requested:", args.log_level,
              file=sys.stderr)
        sys.exit(1)
    log_format = '%(asctime)s %(funcName)s: %(message)s'
    if args.log_file == '-':
        # '-' means log to standard output instead of a file.
        logging.basicConfig(stream=sys.stdout, format=log_format, level=log_level)
    else:
        logging.basicConfig(filename=log_file, format=log_format, level=log_level)
    if args.debug:
        theano.config.compute_test_value = 'warn'
    else:
        theano.config.compute_test_value = 'off'
    theano.config.profile = args.profile
    theano.config.profile_memory = args.profile
    # Kaldi input or output requires a word-to-ID mapping file.
    if (args.lattice_format == 'kaldi') or (args.output == 'kaldi'):
        if args.kaldi_vocabulary is None:
            print("Kaldi lattice vocabulary is not given.", file=sys.stderr)
            sys.exit(1)
    default_device = get_default_device(args.default_device)
    network = Network.from_file(args.model_path,
                                mode=Network.Mode(minibatch=False),
                                default_device=default_device)
    # Conversion factor for log probabilities when --log-base is given;
    # 1.0 means natural logarithms are used as-is.
    log_scale = 1.0 if args.log_base is None else numpy.log(args.log_base)
    if (args.log_base is not None) and (args.lattice_format == 'kaldi'):
        logging.info("Warning: Kaldi lattice reader doesn't support logarithm "
                     "base conversion.")
    if args.wi_penalty is None:
        wi_penalty = None
    else:
        wi_penalty = args.wi_penalty * log_scale
    decoding_options = {
        'nnlm_weight': args.nnlm_weight,
        'lm_scale': args.lm_scale,
        'wi_penalty': wi_penalty,
        'unk_penalty': args.unk_penalty,
        'use_shortlist': args.shortlist,
        'unk_from_lattice': args.unk_from_lattice,
        'linear_interpolation': args.linear_interpolation,
        'max_tokens_per_node': args.max_tokens_per_node,
        'beam': args.beam,
        'recombination_order': args.recombination_order,
        'prune_relative': args.prune_relative,
        'abs_min_max_tokens': args.abs_min_max_tokens,
        'abs_min_beam': args.abs_min_beam}
    logging.debug("DECODING OPTIONS")
    for option_name, option_value in decoding_options.items():
        logging.debug("%s: %s", option_name, str(option_value))
    logging.info("Building word lattice decoder.")
    decoder = LatticeDecoder(network, decoding_options)
    # LatticeBatch combines the command-line paths and the lattice list,
    # and selects this job's share of the lattices.
    batch = LatticeBatch(args.lattices, args.lattice_list,
                         args.lattice_format, args.kaldi_vocabulary,
                         args.num_jobs, args.job)
    for lattice_number, lattice in enumerate(batch):
        # Fall back to the lattice's position when it has no utterance ID.
        if lattice.utterance_id is None:
            lattice.utterance_id = str(lattice_number)
        logging.info("Utterance `%s´ -- %d of job %d",
                     lattice.utterance_id, lattice_number + 1, args.job)
        log_free_mem()
        final_tokens, recomb_tokens = decoder.decode(lattice)
        if (args.output == "slf") or (args.output == "kaldi"):
            # Write a rescored lattice instead of an n-best list.
            rescored_lattice = RescoredLattice(lattice, final_tokens,
                                               recomb_tokens,
                                               network.vocabulary)
            rescored_lattice.lm_scale = args.lm_scale
            rescored_lattice.wi_penalty = args.wi_penalty
            if args.output == "slf":
                rescored_lattice.write_slf(args.output_file)
            else:
                assert args.output == "kaldi"
                rescored_lattice.write_kaldi(args.output_file,
                                             batch.kaldi_word_to_id)
        else:
            # Write the n best tokens of this utterance.
            for token in final_tokens[:min(args.n_best, len(final_tokens))]:
                line = format_token(token, lattice.utterance_id,
                                    network.vocabulary, log_scale,
                                    args.output)
                args.output_file.write(line + "\n")
        # Collect garbage after each lattice — presumably to keep memory
        # usage bounded across a long batch.
        gc.collect()