class WlDecoder:
    """
    Wav2Letter-based decoder. Follows the official examples for the python bindings,
    see https://github.com/facebookresearch/wav2letter/blob/master/bindings/python/examples/decoder_example.py

    Performs CTC beam-search decoding restricted to the words of a lexicon
    (encoded as a token trie) and scored with a KenLM word-level language model.
    """

    def __init__(self,
                 lm_weight=2.0,
                 lexicon_path="WER_data/lexicon.txt",
                 token_path="WER_data/letters.lst",
                 lm_path="WER_data/4-gram.bin",
                 *,
                 beam_size=2500,
                 beam_threshold=100.0,
                 word_score=2.0,
                 sil_weight=-1):
        """Load lexicon, tokens and LM, build the smeared lexicon trie and decoder options.

        Args:
            lm_weight: weight of the LM score, forwarded as ``DecoderOptions.lm_weight``.
            lexicon_path: lexicon file, read with ``load_words``.
            token_path: token list used to build the token ``Dictionary``.
            lm_path: KenLM model file.
            beam_size, beam_threshold, word_score, sil_weight: forwarded unchanged to
                ``DecoderOptions``; the defaults are the values this class always used.
        """
        lexicon = load_words(lexicon_path)
        word_dict = create_word_dict(lexicon)

        self.token_dict = Dictionary(token_path)
        self.lm = KenLM(lm_path, word_dict)

        # "|" is the word separator / silence token; "#" is appended as the CTC blank.
        self.sil_idx = self.token_dict.get_index("|")
        self.unk_idx = word_dict.get_index("<unk>")
        self.token_dict.add_entry("#")
        self.blank_idx = self.token_dict.get_index('#')

        # The trie is sized after "#" was added so the blank index is in range.
        self.trie = Trie(self.token_dict.index_size(), self.sil_idx)
        start_state = self.lm.start(start_with_nothing=False)

        # Each spelling's leaf gets the word's LM score from the start state;
        # smearing (MAX) then propagates a proxy score to every trie prefix.
        for word, spellings in lexicon.items():
            usr_idx = word_dict.get_index(word)
            _, score = self.lm.score(start_state, usr_idx)
            for spelling in spellings:
                # max_reps should be 1; using 0 here to match DecoderTest bug
                spelling_idxs = tkn_to_idx(spelling, self.token_dict, max_reps=0)
                self.trie.insert(spelling_idxs, usr_idx, score)

        self.trie.smear(SmearingMode.MAX)
        self.opts = DecoderOptions(beam_size=beam_size,
                                   beam_threshold=beam_threshold,
                                   lm_weight=lm_weight,
                                   word_score=word_score,
                                   unk_score=-math.inf,
                                   log_add=False,
                                   sil_weight=sil_weight,
                                   criterion_type=CriterionType.CTC)

    def collapse(self, prediction):
        """Apply CTC collapsing to a token sequence.

        Merges runs of identical consecutive tokens, then drops the blank
        token '#' and maps the word separator '|' to a space.

        Args:
            prediction: sequence of token strings.

        Returns:
            list of token strings.
        """
        blank, space = '#', '|'
        deduped = []
        for token in prediction:
            # Repeats are merged BEFORE blanks are removed, so "a # a" stays two tokens.
            if not deduped or token != deduped[-1]:
                deduped.append(token)
        return [' ' if token == space else token for token in deduped if token != blank]

    @staticmethod
    def _prepare_emissions(emissions):
        """Return *emissions* as a dense, row-major float32 NumPy array on the CPU.

        Only the raw buffer address (``ndarray.ctypes.data``) is handed to the
        wav2letter decoder, which reinterprets it as ``T * N`` consecutive
        floats. A strided (e.g. transposed) tensor or a non-float32 dtype would
        therefore be misread silently. ``detach`` lets tensors that require
        grad be converted instead of raising.
        """
        return emissions.detach().cpu().float().contiguous().numpy()

    def predictions(self, emissions):
        """Decode one utterance and return its collapsed token sequence.

        Args:
            emissions: 2-D torch tensor of shape (T, N), passed to the decoder as
                T frames x N tokens. Presumably per-frame token scores /
                log-probabilities from a CTC model -- TODO confirm with caller.

        Returns:
            list of token strings with '#' removed and '|' mapped to ' '.
        """
        emissions = self._prepare_emissions(emissions)
        t, n = emissions.shape

        decoder = WordLMDecoder(self.opts, self.trie, self.lm, self.sil_idx,
                                self.blank_idx, self.unk_idx, [])
        # `emissions` must stay referenced until decode returns: only its address is passed.
        results = decoder.decode(emissions.ctypes.data, t, n)

        # Use the first returned hypothesis (presumably the best-scoring one);
        # negative token ids are skipped.
        prediction = [self.token_dict.get_entry(x) for x in results[0].tokens if x >= 0]
        return self.collapse(prediction)
lm_state, lm_score = lm.finish(lm_state) total_score += lm_score assert_near(total_score, -19.5123, 1e-5) # build trie # Trie is necessary to do beam-search decoding with word-level lm # We restrict our search only to the words from the lexicon # Trie is constructed from the lexicon, each node is a token # path from the root to a leaf corresponds to a word spelling in the lexicon # get silence index sil_idx = token_dict.get_index("|") # get unknown word index unk_idx = word_dict.get_index("<unk>") # create the trie, specifying how many tokens we have and silence index trie = Trie(token_dict.index_size(), sil_idx) start_state = lm.start(False) # use heuristic for the trie, called smearing: # predict lm score for each word in the lexicon, set this score to a leaf # (we predict lm score for each word as each word starts a sentence) # word score of a leaf is propagated up to the root to have some proxy score # for any intermediate path in the trie # SmearingMode defines the function how to process scores # in a node came from the children nodes: # could be max operation or logadd or none for word, spellings in lexicon.items(): usr_idx = word_dict.get_index(word) _, score = lm.score(start_state, usr_idx) for spelling in spellings: # max_reps should be 1; using 0 here to match DecoderTest bug