import math

# Imports assume the wav2letter python bindings; newer releases expose the
# same classes under flashlight.lib.text.
from wav2letter.common import Dictionary, create_word_dict, load_words, tkn_to_idx
from wav2letter.decoder import (
    CriterionType,
    DecoderOptions,
    KenLM,
    SmearingMode,
    Trie,
)


class KenLMDecoder:  # class name assumed; the original snippet shows only __init__
    def __init__(
        self,
        lm_weight=2.0,
        lexicon_path="WER_data/lexicon.txt",
        token_path="WER_data/letters.lst",
        lm_path="WER_data/4-gram.bin",
    ):
        # Load the lexicon (word -> token spellings) and build the dictionaries.
        lexicon = load_words(lexicon_path)
        word_dict = create_word_dict(lexicon)
        self.token_dict = Dictionary(token_path)

        # KenLM n-gram language model scored over the word dictionary.
        self.lm = KenLM(lm_path, word_dict)

        self.sil_idx = self.token_dict.get_index("|")
        self.unk_idx = word_dict.get_index("<unk>")

        # Add the CTC blank token to the token dictionary.
        self.token_dict.add_entry("#")
        self.blank_idx = self.token_dict.get_index("#")

        # Insert each word's spellings into the trie with the word's
        # start-of-sentence LM score, then smear the scores up the trie.
        self.trie = Trie(self.token_dict.index_size(), self.sil_idx)
        start_state = self.lm.start(start_with_nothing=False)
        for word, spellings in lexicon.items():
            usr_idx = word_dict.get_index(word)
            _, score = self.lm.score(start_state, usr_idx)
            for spelling in spellings:
                # max_reps should be 1; using 0 here to match DecoderTest bug
                spelling_idxs = tkn_to_idx(spelling, self.token_dict, max_reps=0)
                self.trie.insert(spelling_idxs, usr_idx, score)
        self.trie.smear(SmearingMode.MAX)

        self.opts = DecoderOptions(
            beam_size=2500,
            beam_threshold=100.0,
            lm_weight=lm_weight,
            word_score=2.0,
            unk_score=-math.inf,
            log_add=False,
            sil_weight=-1,
            criterion_type=CriterionType.CTC,
        )
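# Minimal decode sketch (assumptions: the wrapper class is named KenLMDecoder
# as above, and LexiconDecoder is constructed as in the wav2letter DecoderTest;
# `emissions` here is placeholder data standing in for real acoustic-model
# output).
import numpy as np
from wav2letter.decoder import LexiconDecoder

dec = KenLMDecoder()
T, N = 10, dec.token_dict.index_size()           # frames x tokens
emissions = np.zeros((T, N), dtype=np.float32)   # placeholder CTC scores
transitions = np.zeros(N * N, dtype=np.float32)  # unused for CTC; zeros
decoder = LexiconDecoder(
    dec.opts, dec.trie, dec.lm,
    dec.sil_idx, dec.blank_idx, dec.unk_idx,
    transitions,
    False,  # the LM scores words, not tokens
)
results = decoder.decode(emissions.ctypes.data, T, N)  # hypotheses, best first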
# Test-style check (assumes module-level lexicon, word_dict, token_dict, trie,
# lm, sentence and assert_near built/defined as in the decoder setup above).
start_state = lm.start(False)

# Use a heuristic for the trie, called smearing:
# predict the LM score for each word in the lexicon and set this score on the
# word's leaf (each word is scored as if it starts a sentence).
# The word score of a leaf is propagated up towards the root, so that every
# intermediate path in the trie has a proxy score.
# SmearingMode defines how a node combines the scores coming from its
# children: max, logadd, or none.
for word, spellings in lexicon.items():
    usr_idx = word_dict.get_index(word)
    _, score = lm.score(start_state, usr_idx)
    for spelling in spellings:
        # max_reps should be 1; using 0 here to match DecoderTest bug
        spelling_idxs = tkn_to_idx(spelling, token_dict, 0)
        trie.insert(spelling_idxs, usr_idx, score)

trie.smear(SmearingMode.MAX)

# Check that the trie is built consistently with the C++ implementation.
trie_score_target = [-1.05971, -2.87742, -2.64553, -3.05081, -1.05971, -3.08968]
for i in range(len(sentence)):
    word = sentence[i]
    # max_reps should be 1; using 0 here to match DecoderTest bug
    word_tensor = tkn_to_idx([c for c in word], token_dict, 0)
    node = trie.search(word_tensor)
    assert_near(node.max_score, trie_score_target[i], 1e-5)
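# Smearing in action (a sketch; assumes trie.search returns the internal node
# for a prefix path, as in the C++ Trie): after SmearingMode.MAX, intermediate
# nodes also carry max_score -- the best leaf (word) score below them -- which
# the beam search uses as a proxy to prune unpromising prefixes early.
prefix = [c for c in sentence[0][:2]]  # first two letters of a known word
prefix_idxs = tkn_to_idx(prefix, token_dict, 0)
prefix_node = trie.search(prefix_idxs)
if prefix_node is not None:
    # upper bound on the LM start score of any lexicon word with this prefix
    print(prefix_node.max_score)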