class W2lKenLMDecoder(W2lDecoder):
    """Beam-search decoder that scores hypotheses with a KenLM language model.

    With ``args.lexicon`` set, every lexicon spelling is inserted into a trie
    and a ``LexiconDecoder`` is built; otherwise a ``LexiconFreeDecoder`` over
    the target-dictionary symbols is built, which requires a unit LM.

    ``self.vocab_size``, ``self.silence``, ``self.blank``,
    ``self.criterion_type``, ``self.asg_transitions``, ``self.nbest`` and
    ``self.get_tokens`` are not defined here; they are presumably provided by
    ``W2lDecoder`` (not visible in this block).
    """

    def __init__(self, args, tgt_dict):
        """Build the language model, the optional lexicon trie and the decoder.

        Parameters
        ----------
        args
            Namespace read for ``lexicon``, ``kenlm_model``, ``beam``,
            ``beam_threshold``, ``lm_weight``, ``word_score``, ``unk_weight``,
            ``sil_weight`` and, optionally, ``unit_lm`` and
            ``beam_size_token``.
        tgt_dict
            Target dictionary; must provide ``index``, ``unk``, ``symbols``
            and ``__len__``.
        """
        super().__init__(args, tgt_dict)

        self.unit_lm = getattr(args, "unit_lm", False)

        if args.lexicon:
            self.lexicon = load_words(args.lexicon)
            self.word_dict = create_word_dict(self.lexicon)
            self.unk_word = self.word_dict.get_index("<unk>")

            self.lm = KenLM(args.kenlm_model, self.word_dict)
            self.trie = Trie(self.vocab_size, self.silence)

            # Each word is scored from the LM start state and that score is
            # stored with every spelling of the word in the trie.
            start_state = self.lm.start(False)
            for word, spellings in self.lexicon.items():
                word_idx = self.word_dict.get_index(word)
                _, score = self.lm.score(start_state, word_idx)
                for spelling in spellings:
                    spelling_idxs = [tgt_dict.index(token) for token in spelling]
                    assert (
                        tgt_dict.unk() not in spelling_idxs
                    ), f"{spelling} {spelling_idxs}"
                    self.trie.insert(spelling_idxs, word_idx, score)
            self.trie.smear(SmearingMode.MAX)

            self.decoder_opts = LexiconDecoderOptions(
                beam_size=args.beam,
                beam_size_token=int(getattr(args, "beam_size_token", len(tgt_dict))),
                beam_threshold=args.beam_threshold,
                lm_weight=args.lm_weight,
                word_score=args.word_score,
                unk_score=args.unk_weight,
                sil_score=args.sil_weight,
                log_add=False,
                criterion_type=self.criterion_type,
            )

            # The decoder is handed an empty list when no transitions exist.
            if self.asg_transitions is None:
                self.asg_transitions = []

            self.decoder = LexiconDecoder(
                self.decoder_opts,
                self.trie,
                self.lm,
                self.silence,
                self.blank,
                self.unk_word,
                self.asg_transitions,
                self.unit_lm,
            )
        else:
            # Use the value read with getattr above so that a namespace
            # without ``unit_lm`` fails with this message rather than with an
            # AttributeError.
            assert self.unit_lm, "lexicon free decoding can only be done with a unit language model"
            from flashlight.lib.text.decoder import LexiconFreeDecoder, LexiconFreeDecoderOptions

            # Every target symbol becomes a "word" spelled by itself.
            d = {w: [[w]] for w in tgt_dict.symbols}
            self.word_dict = create_word_dict(d)
            self.lm = KenLM(args.kenlm_model, self.word_dict)
            self.decoder_opts = LexiconFreeDecoderOptions(
                beam_size=args.beam,
                beam_size_token=int(getattr(args, "beam_size_token", len(tgt_dict))),
                beam_threshold=args.beam_threshold,
                lm_weight=args.lm_weight,
                sil_score=args.sil_weight,
                log_add=False,
                criterion_type=self.criterion_type,
            )
            self.decoder = LexiconFreeDecoder(
                self.decoder_opts, self.lm, self.silence, self.blank, []
            )

    def get_timesteps(self, token_idxs: List[int]) -> List[int]:
        """Return the frame index at which each non-blank token run starts.

        Blank frames are skipped, and consecutive frames carrying the same
        token contribute only the index of the first frame of the run.

        Parameters
        ----------
        token_idxs : List[int]
            IDs of decoded tokens, one per frame.

        Returns
        -------
        List[int]
            Frame numbers corresponding to every non-blank token.
        """
        timesteps = []
        for i, token_idx in enumerate(token_idxs):
            if token_idx == self.blank:
                continue
            if i == 0 or token_idx != token_idxs[i - 1]:
                timesteps.append(i)
        return timesteps

    def decode(self, emissions):
        """Decode a batch of emissions into n-best hypotheses.

        Parameters
        ----------
        emissions
            Tensor whose ``size()`` unpacks to ``(B, T, N)``.

        Returns
        -------
        list
            One list per batch item holding at most ``self.nbest`` dicts with
            the keys ``"tokens"``, ``"score"``, ``"timesteps"`` and
            ``"words"``.
        """
        B, T, N = emissions.size()
        hypos = []
        for b in range(B):
            # Raw address of batch item b. The factor 4 is a byte size per
            # element -- assumes float32 emissions; TODO confirm with callers.
            emissions_ptr = emissions.data_ptr() + 4 * b * emissions.stride(0)
            results = self.decoder.decode(emissions_ptr, T, N)

            nbest_results = results[: self.nbest]
            hypos.append(
                [
                    {
                        "tokens": self.get_tokens(result.tokens),
                        "score": result.score,
                        "timesteps": self.get_timesteps(result.tokens),
                        # Negative word ids are dropped.
                        "words": [
                            self.word_dict.get_entry(x) for x in result.words if x >= 0
                        ],
                    }
                    for result in nbest_results
                ]
            )
        return hypos
class W2lFairseqLMDecoder(W2lDecoder):
    """Beam-search decoder that scores hypotheses with a fairseq LM checkpoint.

    The checkpoint at ``args.kenlm_model`` is loaded, its task and model are
    rebuilt, and the model is wrapped in ``FairseqLM``. With a lexicon, a trie
    of spellings feeds a ``LexiconDecoder``; without one, a
    ``LexiconFreeDecoder`` is built, which requires a unit LM.

    ``self.vocab_size``, ``self.silence``, ``self.blank``,
    ``self.criterion_type``, ``self.nbest`` and ``self.get_tokens`` are not
    defined here; they are presumably provided by ``W2lDecoder``.
    """

    def __init__(self, args, tgt_dict):
        """Load the language model and build the trie and decoder.

        Parameters
        ----------
        args
            Namespace read for ``lexicon``, ``kenlm_model``, ``beam``,
            ``beam_threshold``, ``lm_weight``, ``word_score``, ``unk_weight``,
            ``sil_weight`` and, optionally, ``unit_lm`` and
            ``beam_size_token``.
        tgt_dict
            Target dictionary; must provide ``index``, ``unk``, ``symbols``
            and ``__len__``.
        """
        super().__init__(args, tgt_dict)

        self.unit_lm = getattr(args, "unit_lm", False)

        self.lexicon = load_words(args.lexicon) if args.lexicon else None
        self.idx_to_wrd = {}

        # NOTE: torch.load unpickles the file; only load LM checkpoints that
        # come from a trusted source.
        checkpoint = torch.load(args.kenlm_model, map_location="cpu")

        # Prefer the stored "cfg" entry; otherwise convert the legacy "args"
        # namespace.
        if "cfg" in checkpoint and checkpoint["cfg"] is not None:
            lm_args = checkpoint["cfg"]
        else:
            lm_args = convert_namespace_to_omegaconf(checkpoint["args"])

        # Point the LM task at the directory that holds the checkpoint.
        with open_dict(lm_args.task):
            lm_args.task.data = osp.dirname(args.kenlm_model)

        task = tasks.setup_task(lm_args.task)
        model = task.build_model(lm_args.model)
        model.load_state_dict(checkpoint["model"], strict=False)

        self.trie = Trie(self.vocab_size, self.silence)

        self.word_dict = task.dictionary
        self.unk_word = self.word_dict.unk()
        self.lm = FairseqLM(self.word_dict, model)

        if self.lexicon:
            start_state = self.lm.start(False)
            for i, (word, spellings) in enumerate(self.lexicon.items()):
                if self.unit_lm:
                    # With a unit LM the word id is the lexicon position and
                    # the word carries no LM score of its own.
                    word_idx = i
                    self.idx_to_wrd[i] = word
                    score = 0
                else:
                    word_idx = self.word_dict.index(word)
                    _, score = self.lm.score(start_state, word_idx, no_cache=True)

                for spelling in spellings:
                    spelling_idxs = [tgt_dict.index(token) for token in spelling]
                    assert (
                        tgt_dict.unk() not in spelling_idxs
                    ), f"{spelling} {spelling_idxs}"
                    self.trie.insert(spelling_idxs, word_idx, score)
            self.trie.smear(SmearingMode.MAX)

            self.decoder_opts = LexiconDecoderOptions(
                beam_size=args.beam,
                beam_size_token=int(getattr(args, "beam_size_token", len(tgt_dict))),
                beam_threshold=args.beam_threshold,
                lm_weight=args.lm_weight,
                word_score=args.word_score,
                unk_score=args.unk_weight,
                sil_score=args.sil_weight,
                log_add=False,
                criterion_type=self.criterion_type,
            )

            self.decoder = LexiconDecoder(
                self.decoder_opts,
                self.trie,
                self.lm,
                self.silence,
                self.blank,
                self.unk_word,
                [],
                self.unit_lm,
            )
        else:
            # Use the value read with getattr above so that a namespace
            # without ``unit_lm`` fails with this message rather than with an
            # AttributeError.
            assert self.unit_lm, "lexicon free decoding can only be done with a unit language model"
            from flashlight.lib.text.decoder import LexiconFreeDecoder, LexiconFreeDecoderOptions

            d = {w: [[w]] for w in tgt_dict.symbols}
            self.word_dict = create_word_dict(d)
            # NOTE(review): this replaces the FairseqLM built above with a
            # KenLM opened on the same path that torch.load just read as a
            # checkpoint -- confirm this is intended.
            self.lm = KenLM(args.kenlm_model, self.word_dict)
            self.decoder_opts = LexiconFreeDecoderOptions(
                beam_size=args.beam,
                beam_size_token=int(getattr(args, "beam_size_token", len(tgt_dict))),
                beam_threshold=args.beam_threshold,
                lm_weight=args.lm_weight,
                sil_score=args.sil_weight,
                log_add=False,
                criterion_type=self.criterion_type,
            )
            self.decoder = LexiconFreeDecoder(
                self.decoder_opts, self.lm, self.silence, self.blank, []
            )

    def decode(self, emissions):
        """Decode a batch of emissions into n-best hypotheses.

        Parameters
        ----------
        emissions
            Tensor whose ``size()`` unpacks to ``(B, T, N)``.

        Returns
        -------
        list
            One list per batch item holding at most ``self.nbest`` dicts with
            the keys ``"tokens"`` and ``"score"``, plus ``"words"`` when a
            lexicon was loaded.
        """
        B, T, N = emissions.size()
        hypos = []

        def idx_to_word(idx):
            # Unit-LM word ids are lexicon positions; otherwise they index the
            # LM dictionary.
            if self.unit_lm:
                return self.idx_to_wrd[idx]
            return self.word_dict[idx]

        def make_hypo(result):
            hypo = {"tokens": self.get_tokens(result.tokens), "score": result.score}
            if self.lexicon:
                # Negative word ids are dropped.
                hypo["words"] = [idx_to_word(x) for x in result.words if x >= 0]
            return hypo

        for b in range(B):
            # Raw address of batch item b. The factor 4 is a byte size per
            # element -- assumes float32 emissions; TODO confirm with callers.
            emissions_ptr = emissions.data_ptr() + 4 * b * emissions.stride(0)
            results = self.decoder.decode(emissions_ptr, T, N)

            nbest_results = results[: self.nbest]
            hypos.append([make_hypo(result) for result in nbest_results])
            # NOTE(review): in lexicon-free mode self.lm is a KenLM instance;
            # confirm that it provides empty_cache().
            self.lm.empty_cache()

        return hypos
class FairseqLMDecoder(BaseDecoder):
    """Beam-search decoder whose language model is a fairseq checkpoint.

    The checkpoint at ``cfg.lmpath`` is loaded and wrapped in ``FairseqLM``.
    When a lexicon is configured its spellings populate a trie used by a
    ``LexiconDecoder``; otherwise a ``LexiconFreeDecoder`` is created, which
    requires ``self.unitlm`` to be truthy.

    ``self.vocab_size``, ``self.silence``, ``self.blank``, ``self.unitlm``,
    ``self.criterion_type``, ``self.asgtransitions``, ``self.nbest`` and
    ``self.get_tokens`` are not defined here; presumably ``BaseDecoder``
    supplies them.
    """

    def __init__(self, cfg: DecoderConfig, tgt_dict: Dictionary) -> None:
        """Restore the LM from ``cfg.lmpath`` and assemble the decoder."""
        super().__init__(cfg, tgt_dict)

        self.lexicon = load_words(cfg.lexicon) if cfg.lexicon else None
        self.idx_to_wrd = {}

        # Rebuild the LM task and model from the stored configuration,
        # falling back to the legacy "args" namespace when "cfg" is absent.
        lm_checkpoint = torch.load(cfg.lmpath, map_location="cpu")
        has_cfg = "cfg" in lm_checkpoint and lm_checkpoint["cfg"] is not None
        lm_cfg = (
            lm_checkpoint["cfg"]
            if has_cfg
            else convert_namespace_to_omegaconf(lm_checkpoint["args"])
        )

        with open_dict(lm_cfg.task):
            lm_cfg.task.data = osp.dirname(cfg.lmpath)

        lm_task = tasks.setup_task(lm_cfg.task)
        lm_model = lm_task.build_model(lm_cfg.model)
        lm_model.load_state_dict(lm_checkpoint["model"], strict=False)

        self.trie = Trie(self.vocab_size, self.silence)

        self.word_dict = lm_task.dictionary
        self.unk_word = self.word_dict.unk()
        self.lm = FairseqLM(self.word_dict, lm_model)

        if not self.lexicon:
            # Lexicon-free path: every target symbol is a word spelled by
            # itself.
            assert self.unitlm, "Lexicon-free decoding requires unit LM"

            symbol_lexicon = {w: [[w]] for w in tgt_dict.symbols}
            self.word_dict = create_word_dict(symbol_lexicon)
            self.lm = KenLM(cfg.lmpath, self.word_dict)
            self.decoder_opts = LexiconFreeDecoderOptions(
                beam_size=cfg.beam,
                beam_size_token=cfg.beamsizetoken or len(tgt_dict),
                beam_threshold=cfg.beamthreshold,
                lm_weight=cfg.lmweight,
                sil_score=cfg.silweight,
                log_add=False,
                criterion_type=self.criterion_type,
            )
            self.decoder = LexiconFreeDecoder(
                self.decoder_opts, self.lm, self.silence, self.blank, []
            )
            return

        # Lexicon path: register every spelling of every word in the trie.
        initial_state = self.lm.start(False)
        for position, (word, spellings) in enumerate(self.lexicon.items()):
            if self.unitlm:
                # Unit LM: the word id is the lexicon position, score is 0.
                entry_idx = position
                self.idx_to_wrd[position] = word
                entry_score = 0
            else:
                entry_idx = self.word_dict.index(word)
                _, entry_score = self.lm.score(
                    initial_state, entry_idx, no_cache=True
                )

            for spelling in spellings:
                token_ids = [tgt_dict.index(tok) for tok in spelling]
                assert tgt_dict.unk() not in token_ids, f"{spelling} {token_ids}"
                self.trie.insert(token_ids, entry_idx, entry_score)
        self.trie.smear(SmearingMode.MAX)

        self.decoder_opts = LexiconDecoderOptions(
            beam_size=cfg.beam,
            beam_size_token=cfg.beamsizetoken or len(tgt_dict),
            beam_threshold=cfg.beamthreshold,
            lm_weight=cfg.lmweight,
            word_score=cfg.wordscore,
            unk_score=cfg.unkweight,
            sil_score=cfg.silweight,
            log_add=False,
            criterion_type=self.criterion_type,
        )

        if self.asgtransitions is None:
            self.asgtransitions = []

        self.decoder = LexiconDecoder(
            self.decoder_opts,
            self.trie,
            self.lm,
            self.silence,
            self.blank,
            self.unk_word,
            self.asgtransitions,
            self.unitlm,
        )

    def decode(
        self,
        emissions: torch.FloatTensor,
    ) -> List[List[Dict[str, torch.LongTensor]]]:
        """Return, per batch item, up to ``self.nbest`` hypothesis dicts.

        Each dict carries ``"tokens"`` and ``"score"``; ``"words"`` is added
        only when a lexicon was loaded. The LM cache is emptied after every
        batch item.
        """
        B, T, N = emissions.size()
        hypos = []

        for b in range(B):
            # Address of batch item b; 4 is the per-element byte size used by
            # the original code (assumes float32 -- TODO confirm).
            item_ptr = emissions.data_ptr() + 4 * b * emissions.stride(0)
            beams = self.decoder.decode(item_ptr, T, N)

            item_hypos = []
            for beam in beams[:self.nbest]:
                entry = {
                    "tokens": self.get_tokens(beam.tokens),
                    "score": beam.score,
                }
                if self.lexicon:
                    # Unit-LM ids map through idx_to_wrd, others through the
                    # word dictionary; negative ids are skipped.
                    lookup = self.idx_to_wrd if self.unitlm else self.word_dict
                    entry["words"] = [lookup[w] for w in beam.words if w >= 0]
                item_hypos.append(entry)

            hypos.append(item_hypos)
            self.lm.empty_cache()

        return hypos
class KenLMDecoder(object):
    """Standalone CTC beam-search decoder backed by a KenLM language model.

    ``kenlm_args`` is indexed by string keys (``'nbest'``, ``'lexicon_path'``,
    ``'kenlm_model_path'``, ``'beam'``, ``'beam_threshold'``, ``'lm_weight'``,
    ``'word_score'``, ``'sil_weight'`` and optionally ``'beam_size_token'``).
    ``vocab_dict`` maps token strings to indices.
    """

    def __init__(self, kenlm_args, vocab_dict, blank="<pad>", silence="|", unk="<unk>"):
        """Resolve the special tokens and build a lexicon or lexicon-free decoder."""
        self.vocab_size = len(vocab_dict)
        self.blank_token = vocab_dict[blank]
        self.silence_token = vocab_dict[silence]
        self.unk_token = vocab_dict[unk]
        self.nbest = kenlm_args['nbest']

        if kenlm_args['lexicon_path']:
            known_tokens = vocab_dict.keys()
            self.lexicon = load_words(kenlm_args['lexicon_path'])
            self.word_dict = create_word_dict(self.lexicon)
            self.unk_word = self.word_dict.get_index(unk)

            self.lm = KenLM(kenlm_args['kenlm_model_path'], self.word_dict)
            self.trie = Trie(self.vocab_size, self.silence_token)

            initial_state = self.lm.start(False)
            for word, spellings in self.lexicon.items():
                word_idx = self.word_dict.get_index(word)
                _, word_score = self.lm.score(initial_state, word_idx)

                for spelling in spellings:
                    token_ids = []
                    for token in spelling:
                        # Look the token up in upper case first, then lower
                        # case; anything else falls back to the unk token.
                        upper, lower = token.upper(), token.lower()
                        if upper in known_tokens:
                            token_ids.append(vocab_dict[upper])
                        elif lower in known_tokens:
                            token_ids.append(vocab_dict[lower])
                        else:
                            print(
                                "WARNING: The token",
                                token,
                                "not exist in your vocabulary, using <unk> token instead",
                            )
                            token_ids.append(self.unk_token)
                    self.trie.insert(token_ids, word_idx, word_score)
            self.trie.smear(SmearingMode.MAX)

            self.decoder_opts = LexiconDecoderOptions(
                beam_size=kenlm_args['beam'],
                beam_size_token=(
                    kenlm_args['beam_size_token']
                    if "beam_size_token" in kenlm_args
                    else len(vocab_dict)
                ),
                beam_threshold=kenlm_args['beam_threshold'],
                lm_weight=kenlm_args['lm_weight'],
                word_score=kenlm_args['word_score'],
                unk_score=-math.inf,
                sil_score=kenlm_args['sil_weight'],
                log_add=False,
                criterion_type=CriterionType.CTC,
            )

            self.decoder = LexiconDecoder(
                self.decoder_opts,
                self.trie,
                self.lm,
                self.silence_token,
                self.blank_token,
                self.unk_word,
                [],
                False,
            )
        else:
            # Without a lexicon, each vocabulary token is a word spelled by
            # itself.
            token_lexicon = {w: [[w]] for w in vocab_dict.keys()}
            self.word_dict = create_word_dict(token_lexicon)
            self.lm = KenLM(kenlm_args['kenlm_model_path'], self.word_dict)
            self.decoder_opts = LexiconFreeDecoderOptions(
                beam_size=kenlm_args['beam'],
                beam_size_token=(
                    kenlm_args['beam_size_token']
                    if "beam_size_token" in kenlm_args
                    else len(vocab_dict)
                ),
                beam_threshold=kenlm_args['beam_threshold'],
                lm_weight=kenlm_args['lm_weight'],
                sil_score=kenlm_args['sil_weight'],
                log_add=False,
                criterion_type=CriterionType.CTC,
            )
            self.decoder = LexiconFreeDecoder(
                self.decoder_opts, self.lm, self.silence_token, self.blank_token, []
            )

    def get_tokens(self, idxs):
        """Collapse consecutive repeats, drop blanks, return a LongTensor."""
        collapsed = [
            key for key, _ in it.groupby(idxs) if key != self.blank_token
        ]
        return torch.LongTensor(collapsed)

    def decode(self, emissions):
        """Decode a batch and return ``(token_array, scores_array)``.

        Both results are object-dtype NumPy arrays built from the per-item
        n-best lists; the token array is transposed with axes ``(1, 0, 2)``
        and the score array with a plain transpose.
        """
        B, T, N = emissions.size()

        batch_tokens = []
        batch_scores = []
        for b in range(B):
            # Address of batch item b; 4 is the per-element byte size used by
            # the original code (assumes float32 -- TODO confirm).
            item_ptr = emissions.data_ptr() + 4 * b * emissions.stride(0)
            beams = self.decoder.decode(item_ptr, T, N)
            top_beams = beams[:self.nbest]

            batch_tokens.append([beam.tokens for beam in top_beams])
            batch_scores.append([beam.score for beam in top_beams])

        token_array = np.array(batch_tokens, dtype=object).transpose((1, 0, 2))
        score_array = np.array(batch_scores, dtype=object).transpose()
        return token_array, score_array
class KenLMDecoder(BaseDecoder):
    """Beam-search decoder that uses a KenLM model loaded from ``cfg.lmpath``.

    A configured lexicon yields a trie-constrained ``LexiconDecoder``; no
    lexicon yields a ``LexiconFreeDecoder`` and requires ``self.unitlm``.

    ``self.vocab_size``, ``self.silence``, ``self.blank``, ``self.unitlm``,
    ``self.criterion_type``, ``self.asgtransitions``, ``self.nbest`` and
    ``self.get_tokens`` are not defined here; presumably ``BaseDecoder``
    supplies them.
    """

    def __init__(self, cfg: DecoderConfig, tgt_dict: Dictionary) -> None:
        """Create the language model, the optional trie and the decoder."""
        super().__init__(cfg, tgt_dict)

        if not cfg.lexicon:
            assert self.unitlm, "Lexicon-free decoding requires unit LM"

            # Each target symbol acts as a word spelled by itself.
            symbol_lexicon = {w: [[w]] for w in tgt_dict.symbols}
            self.word_dict = create_word_dict(symbol_lexicon)
            self.lm = KenLM(cfg.lmpath, self.word_dict)
            self.decoder_opts = LexiconFreeDecoderOptions(
                beam_size=cfg.beam,
                beam_size_token=cfg.beamsizetoken or len(tgt_dict),
                beam_threshold=cfg.beamthreshold,
                lm_weight=cfg.lmweight,
                sil_score=cfg.silweight,
                log_add=False,
                criterion_type=self.criterion_type,
            )
            self.decoder = LexiconFreeDecoder(
                self.decoder_opts, self.lm, self.silence, self.blank, []
            )
            return

        self.lexicon = load_words(cfg.lexicon)
        self.word_dict = create_word_dict(self.lexicon)
        self.unk_word = self.word_dict.get_index("<unk>")

        self.lm = KenLM(cfg.lmpath, self.word_dict)
        self.trie = Trie(self.vocab_size, self.silence)

        # Score each word from the LM start state and store that score with
        # all of the word's spellings.
        initial_state = self.lm.start(False)
        for word, spellings in self.lexicon.items():
            entry_idx = self.word_dict.get_index(word)
            _, entry_score = self.lm.score(initial_state, entry_idx)
            for spelling in spellings:
                token_ids = [tgt_dict.index(tok) for tok in spelling]
                assert tgt_dict.unk() not in token_ids, f"{spelling} {token_ids}"
                self.trie.insert(token_ids, entry_idx, entry_score)
        self.trie.smear(SmearingMode.MAX)

        self.decoder_opts = LexiconDecoderOptions(
            beam_size=cfg.beam,
            beam_size_token=cfg.beamsizetoken or len(tgt_dict),
            beam_threshold=cfg.beamthreshold,
            lm_weight=cfg.lmweight,
            word_score=cfg.wordscore,
            unk_score=cfg.unkweight,
            sil_score=cfg.silweight,
            log_add=False,
            criterion_type=self.criterion_type,
        )

        if self.asgtransitions is None:
            self.asgtransitions = []

        self.decoder = LexiconDecoder(
            self.decoder_opts,
            self.trie,
            self.lm,
            self.silence,
            self.blank,
            self.unk_word,
            self.asgtransitions,
            self.unitlm,
        )

    def decode(
        self,
        emissions: torch.FloatTensor,
    ) -> List[List[Dict[str, torch.LongTensor]]]:
        """Return, per batch item, up to ``self.nbest`` hypothesis dicts.

        Every dict has the keys ``"tokens"``, ``"score"`` and ``"words"``.
        """
        B, T, N = emissions.size()
        hypos = []

        for b in range(B):
            # Address of batch item b; 4 is the per-element byte size used by
            # the original code (assumes float32 -- TODO confirm).
            item_ptr = emissions.data_ptr() + 4 * b * emissions.stride(0)
            beams = self.decoder.decode(item_ptr, T, N)

            item_hypos = []
            for beam in beams[:self.nbest]:
                item_hypos.append({
                    "tokens": self.get_tokens(beam.tokens),
                    "score": beam.score,
                    # Negative word ids are skipped.
                    "words": [
                        self.word_dict.get_entry(w) for w in beam.words if w >= 0
                    ],
                })
            hypos.append(item_hypos)

        return hypos
class W2lKenLMDecoder(W2lDecoder):
    """KenLM beam-search decoder that can spell words with an HF tokenizer.

    With ``args.lexicon`` set, every lexicon spelling is inserted into a trie
    and a ``LexiconDecoder`` is built. Spellings are converted to indices with
    ``tgt_dict.hf_tokenizer`` when ``tgt_dict`` is an ``HF2FairSeqDictWrapper``
    carrying that attribute, and with ``tgt_dict.index`` otherwise. Without a
    lexicon a ``LexiconFreeDecoder`` is built, which requires a unit LM.

    ``self.vocab_size``, ``self.silence``, ``self.blank``,
    ``self.criterion_type``, ``self.asg_transitions``, ``self.nbest`` and
    ``self.get_tokens`` are not defined here; they are presumably provided by
    ``W2lDecoder`` (not visible in this block).
    """

    def __init__(self, args, tgt_dict):
        """Build the language model, the optional lexicon trie and the decoder.

        Parameters
        ----------
        args
            Namespace read for ``lexicon``, ``kenlm_model``, ``beam``,
            ``beam_threshold``, ``lm_weight``, ``word_score``, ``unk_weight``,
            ``sil_weight`` and, optionally, ``unit_lm`` and
            ``beam_size_token``.
        tgt_dict
            Target dictionary; must provide ``index``, ``unk``, ``symbols``
            and ``__len__``.
        """
        super().__init__(args, tgt_dict)

        self.unit_lm = getattr(args, "unit_lm", False)

        if args.lexicon:
            self.lexicon = load_words(args.lexicon)
            self.word_dict = create_word_dict(self.lexicon)
            self.unk_word = self.word_dict.get_index("<unk>")

            self.lm = KenLM(args.kenlm_model, self.word_dict)
            self.trie = Trie(self.vocab_size, self.silence)

            # The tokenizer choice does not depend on the word or spelling,
            # so it is decided once instead of once per spelling.
            use_hf_tokenizer = isinstance(
                tgt_dict, HF2FairSeqDictWrapper
            ) and hasattr(tgt_dict, "hf_tokenizer")

            start_state = self.lm.start(False)
            for word, spellings in self.lexicon.items():
                word_idx = self.word_dict.get_index(word)
                _, score = self.lm.score(start_state, word_idx)
                for spelling in spellings:
                    if use_hf_tokenizer:
                        # Each spelling element may encode to several ids;
                        # they are concatenated in order.
                        spelling_idxs = [
                            idx
                            for piece in spelling
                            for idx in tgt_dict.hf_tokenizer.encode(
                                piece, add_special_tokens=False
                            )
                        ]
                    else:
                        spelling_idxs = [tgt_dict.index(token) for token in spelling]
                    assert (
                        tgt_dict.unk() not in spelling_idxs
                    ), f"{spelling} {spelling_idxs}"
                    self.trie.insert(spelling_idxs, word_idx, score)
            self.trie.smear(SmearingMode.MAX)

            self.decoder_opts = LexiconDecoderOptions(
                beam_size=args.beam,
                beam_size_token=int(getattr(args, "beam_size_token", len(tgt_dict))),
                beam_threshold=args.beam_threshold,
                lm_weight=args.lm_weight,
                word_score=args.word_score,
                unk_score=args.unk_weight,
                sil_score=args.sil_weight,
                log_add=False,
                criterion_type=self.criterion_type,
            )

            # The decoder is handed an empty list when no transitions exist.
            if self.asg_transitions is None:
                self.asg_transitions = []

            self.decoder = LexiconDecoder(
                self.decoder_opts,
                self.trie,
                self.lm,
                self.silence,
                self.blank,
                self.unk_word,
                self.asg_transitions,
                self.unit_lm,
            )
        else:
            # Use the value read with getattr above so that a namespace
            # without ``unit_lm`` fails with this message rather than with an
            # AttributeError.
            assert self.unit_lm, "lexicon free decoding can only be done with a unit language model"
            from flashlight.lib.text.decoder import LexiconFreeDecoder, LexiconFreeDecoderOptions

            # Every target symbol becomes a "word" spelled by itself.
            d = {w: [[w]] for w in tgt_dict.symbols}
            self.word_dict = create_word_dict(d)
            self.lm = KenLM(args.kenlm_model, self.word_dict)
            self.decoder_opts = LexiconFreeDecoderOptions(
                beam_size=args.beam,
                beam_size_token=int(getattr(args, "beam_size_token", len(tgt_dict))),
                beam_threshold=args.beam_threshold,
                lm_weight=args.lm_weight,
                sil_score=args.sil_weight,
                log_add=False,
                criterion_type=self.criterion_type,
            )
            self.decoder = LexiconFreeDecoder(
                self.decoder_opts, self.lm, self.silence, self.blank, []
            )

    def decode(self, emissions):
        """Decode a batch of emissions into n-best hypotheses.

        Parameters
        ----------
        emissions
            Tensor whose ``size()`` unpacks to ``(B, T, N)``.

        Returns
        -------
        list
            One list per batch item holding at most ``self.nbest`` dicts with
            the keys ``"tokens"``, ``"score"`` and ``"words"``.
        """
        B, T, N = emissions.size()
        hypos = []
        for b in range(B):
            # Raw address of batch item b. The factor 4 is a byte size per
            # element -- assumes float32 emissions; TODO confirm with callers.
            emissions_ptr = emissions.data_ptr() + 4 * b * emissions.stride(0)
            results = self.decoder.decode(emissions_ptr, T, N)

            nbest_results = results[: self.nbest]
            hypos.append(
                [
                    {
                        "tokens": self.get_tokens(result.tokens),
                        "score": result.score,
                        # Negative word ids are dropped.
                        "words": [
                            self.word_dict.get_entry(x) for x in result.words if x >= 0
                        ],
                    }
                    for result in nbest_results
                ]
            )
        return hypos
# Predict an LM score for each word in the lexicon and store it on the word's
# leaf (each word is scored as if it started a sentence).
# The leaf score is then propagated up to the root so that every intermediate
# path in the trie has a proxy score.
# SmearingMode selects how a node combines the scores coming from its child
# nodes: a max operation, logadd, or none.
for word, spellings in lexicon.items():
    usr_idx = word_dict.get_index(word)
    _, score = lm.score(start_state, usr_idx)
    for spelling in spellings:
        # NOTE(review): the comment that accompanied this call read "max_reps
        # should be 1; using 0 here to match DecoderTest bug", yet the call
        # passes 1 -- confirm which value is intended.
        spelling_idxs = tkn_to_idx(spelling, token_dict, 1)
        trie.insert(spelling_idxs, usr_idx, score)

trie.smear(SmearingMode.MAX)

# Check that the trie was built consistently with the C++ implementation.
trie_score_target = [
    -1.05971, -2.87742, -2.64553, -3.05081, -1.05971, -3.08968
]
for i, word in enumerate(sentence):
    # NOTE(review): same stale "using 0 here" comment as above; the call
    # passes 1.
    word_tensor = tkn_to_idx(list(word), token_dict, 1)
    node = trie.search(word_tensor)
    assert_near(node.max_score, trie_score_target[i], 1e-5)

# Define decoder options:
# LexiconDecoderOptions (beam_size, token_beam_size, beam_threshold, lm_weight,
# word_score, unk_score, sil_score,
class KenLMDecoder(BaseDecoder):
    """CTC beam-search decoder that uses a KenLM model from ``cfg.lmpath``.

    A configured lexicon yields a trie-constrained ``LexiconDecoder``; no
    lexicon yields a ``LexiconFreeDecoder`` and requires ``cfg.unitlm``.

    ``self.vocab_size``, ``self.silence``, ``self.blank`` and
    ``self.get_tokens`` are not defined here; presumably ``BaseDecoder``
    supplies them.
    """

    def __init__(self, cfg: FlashlightDecoderConfig, tgt_dict: Dictionary) -> None:
        """Create the language model, the optional trie and the decoder."""
        super().__init__(tgt_dict)

        self.nbest = cfg.nbest
        self.unitlm = cfg.unitlm

        if not cfg.lexicon:
            assert self.unitlm, "Lexicon-free decoding requires unit LM"

            # Each target symbol acts as a word spelled by itself.
            symbol_lexicon = {w: [[w]] for w in tgt_dict.symbols}
            self.word_dict = create_word_dict(symbol_lexicon)
            self.lm = KenLM(cfg.lmpath, self.word_dict)
            self.decoder_opts = LexiconFreeDecoderOptions(
                beam_size=cfg.beam,
                beam_size_token=cfg.beamsizetoken or len(tgt_dict),
                beam_threshold=cfg.beamthreshold,
                lm_weight=cfg.lmweight,
                sil_score=cfg.silweight,
                log_add=False,
                criterion_type=CriterionType.CTC,
            )
            self.decoder = LexiconFreeDecoder(
                self.decoder_opts, self.lm, self.silence, self.blank, []
            )
            return

        self.lexicon = load_words(cfg.lexicon)
        self.word_dict = create_word_dict(self.lexicon)
        self.unk_word = self.word_dict.get_index("<unk>")

        self.lm = KenLM(cfg.lmpath, self.word_dict)
        self.trie = Trie(self.vocab_size, self.silence)

        # Score each word from the LM start state and store that score with
        # all of the word's spellings.
        initial_state = self.lm.start(False)
        for word, spellings in self.lexicon.items():
            entry_idx = self.word_dict.get_index(word)
            _, entry_score = self.lm.score(initial_state, entry_idx)
            for spelling in spellings:
                token_ids = [tgt_dict.index(tok) for tok in spelling]
                assert (
                    tgt_dict.unk() not in token_ids
                ), f"{word} {spelling} {token_ids}"
                self.trie.insert(token_ids, entry_idx, entry_score)
        self.trie.smear(SmearingMode.MAX)

        self.decoder_opts = LexiconDecoderOptions(
            beam_size=cfg.beam,
            beam_size_token=cfg.beamsizetoken or len(tgt_dict),
            beam_threshold=cfg.beamthreshold,
            lm_weight=cfg.lmweight,
            word_score=cfg.wordscore,
            unk_score=cfg.unkweight,
            sil_score=cfg.silweight,
            log_add=False,
            criterion_type=CriterionType.CTC,
        )

        self.decoder = LexiconDecoder(
            self.decoder_opts,
            self.trie,
            self.lm,
            self.silence,
            self.blank,
            self.unk_word,
            [],
            self.unitlm,
        )

    def get_timesteps(self, token_idxs: List[int]) -> List[int]:
        """Return the starting frame of every run of a non-blank token.

        Parameters
        ----------
        token_idxs : List[int]
            IDs of decoded tokens, one per frame.

        Returns
        -------
        List[int]
            Frame numbers corresponding to every non-blank token; a run of
            identical consecutive tokens contributes only its first frame.
        """
        starts = []
        for frame, token in enumerate(token_idxs):
            is_blank = token == self.blank
            begins_run = frame == 0 or token != token_idxs[frame - 1]
            if not is_blank and begins_run:
                starts.append(frame)
        return starts

    def decode(
        self,
        emissions: torch.FloatTensor,
    ) -> List[List[Dict[str, torch.LongTensor]]]:
        """Return, per batch item, up to ``self.nbest`` hypothesis dicts.

        Every dict has the keys ``"tokens"``, ``"score"``, ``"timesteps"``
        and ``"words"``.
        """
        B, T, N = emissions.size()
        hypos = []

        for b in range(B):
            # Address of batch item b; 4 is the per-element byte size used by
            # the original code (assumes float32 -- TODO confirm).
            item_ptr = emissions.data_ptr() + 4 * b * emissions.stride(0)
            beams = self.decoder.decode(item_ptr, T, N)

            item_hypos = []
            for beam in beams[:self.nbest]:
                item_hypos.append({
                    "tokens": self.get_tokens(beam.tokens),
                    "score": beam.score,
                    "timesteps": self.get_timesteps(beam.tokens),
                    # Negative word ids are skipped.
                    "words": [
                        self.word_dict.get_entry(w) for w in beam.words if w >= 0
                    ],
                })
            hypos.append(item_hypos)

        return hypos
class W2lKenLMDecoder(W2lDecoder):
    """Lexicon-based KenLM beam-search decoder with an overridden silence id.

    The silence index is taken from ``"<ctc_blank>"`` when the target
    dictionary contains it and from ``tgt_dict.bos()`` otherwise. A lexicon is
    always loaded from ``args.lexicon``; there is no lexicon-free path here.

    ``self.vocab_size``, ``self.blank``, ``self.criterion_type``,
    ``self.asg_transitions``, ``self.nbest`` and ``self.get_tokens`` are not
    defined here; they are presumably provided by ``W2lDecoder``.
    """

    def __init__(self, args, tgt_dict):
        """Build the language model, the lexicon trie and the decoder.

        Parameters
        ----------
        args
            Namespace read for ``lexicon``, ``kenlm_model``, ``beam``,
            ``beam_threshold``, ``lm_weight``, ``word_score``, ``unk_weight``,
            ``sil_weight`` and, optionally, ``beam_size_token``.
        tgt_dict
            Target dictionary; must provide ``index``, ``indices``, ``bos``,
            ``unk`` and ``__len__``.
        """
        super().__init__(args, tgt_dict)

        # Replace whatever silence index the base class set.
        self.silence = (
            tgt_dict.index("<ctc_blank>")
            if "<ctc_blank>" in tgt_dict.indices
            else tgt_dict.bos()
        )
        self.lexicon = load_words(args.lexicon)
        self.word_dict = create_word_dict(self.lexicon)
        self.unk_word = self.word_dict.get_index("<unk>")

        self.lm = KenLM(args.kenlm_model, self.word_dict)
        self.trie = Trie(self.vocab_size, self.silence)

        # Each word is scored from the LM start state and that score is
        # stored with every spelling of the word in the trie.
        start_state = self.lm.start(False)
        for word, spellings in self.lexicon.items():
            word_idx = self.word_dict.get_index(word)
            _, score = self.lm.score(start_state, word_idx)
            for spelling in spellings:
                spelling_idxs = [tgt_dict.index(token) for token in spelling]
                assert (tgt_dict.unk() not in spelling_idxs), f"{spelling} {spelling_idxs}"
                self.trie.insert(spelling_idxs, word_idx, score)
        self.trie.smear(SmearingMode.MAX)

        self.decoder_opts = LexiconDecoderOptions(
            beam_size=args.beam,
            beam_size_token=int(getattr(args, "beam_size_token", len(tgt_dict))),
            beam_threshold=args.beam_threshold,
            lm_weight=args.lm_weight,
            word_score=args.word_score,
            unk_score=args.unk_weight,
            sil_score=args.sil_weight,
            log_add=False,
            criterion_type=self.criterion_type,
        )

        # The decoder is handed an empty list when no transitions exist.
        if self.asg_transitions is None:
            self.asg_transitions = []

        self.decoder = LexiconDecoder(
            self.decoder_opts,
            self.trie,
            self.lm,
            self.silence,
            self.blank,
            self.unk_word,
            self.asg_transitions,
            False,
        )

    def decode(self, emissions):
        """Decode a batch of emissions into n-best hypotheses.

        Parameters
        ----------
        emissions
            Tensor whose ``size()`` unpacks to ``(B, T, N)``.

        Returns
        -------
        list
            One list per batch item holding at most ``self.nbest`` dicts with
            the keys ``"tokens"``, ``"score"`` and ``"words"``.
        """
        B, T, N = emissions.size()
        hypos = []
        for b in range(B):
            # Raw address of batch item b. The factor 4 is a byte size per
            # element -- assumes float32 emissions; TODO confirm with callers.
            emissions_ptr = emissions.data_ptr() + 4 * b * emissions.stride(0)
            results = self.decoder.decode(emissions_ptr, T, N)

            nbest_results = results[:self.nbest]
            hypos.append([{
                "tokens": self.get_tokens(result.tokens),
                "score": result.score,
                # Negative word ids are dropped.
                "words": [self.word_dict.get_entry(x) for x in result.words if x >= 0],
            } for result in nbest_results])
        return hypos