class W2lFairseqLMDecoder(W2lDecoder):
    """Beam-search decoder that uses a fairseq language-model checkpoint.

    With a lexicon, every spelling of every lexicon word is inserted into a
    ``Trie`` and decoding is done by a ``LexiconDecoder`` backed by a
    ``FairseqLM``.  Without a lexicon, a ``LexiconFreeDecoder`` is built
    instead, which requires a unit language model.
    """

    def __init__(self, args, tgt_dict):
        """Load the LM checkpoint and build the flashlight decoder.

        Args:
            args: Decoding options.  Attributes read here: ``lexicon``,
                ``kenlm_model`` (path to the LM checkpoint), ``beam``,
                ``beam_threshold``, ``lm_weight``, ``sil_weight`` and, for
                lexicon decoding, ``word_score`` and ``unk_weight``.
                ``unit_lm`` and ``beam_size_token`` are optional.
            tgt_dict: Token dictionary of the acoustic model; used to map
                lexicon spellings to token indices.

        Raises:
            AssertionError: If a lexicon spelling contains a token unknown to
                ``tgt_dict``, or if no lexicon is given and ``unit_lm`` is
                not enabled.
        """
        super().__init__(args, tgt_dict)

        # Optional flag: defaults to False when ``args`` does not define it.
        self.unit_lm = getattr(args, "unit_lm", False)

        self.lexicon = load_words(args.lexicon) if args.lexicon else None
        # Only populated in unit-LM mode: lexicon position -> word string.
        self.idx_to_wrd = {}

        # NOTE(review): torch.load unpickles the file; only point this at
        # trusted checkpoints.
        checkpoint = torch.load(args.kenlm_model, map_location="cpu")

        # Prefer the structured config stored in the checkpoint; otherwise
        # convert the stored argparse namespace.
        if "cfg" in checkpoint and checkpoint["cfg"] is not None:
            lm_args = checkpoint["cfg"]
        else:
            lm_args = convert_namespace_to_omegaconf(checkpoint["args"])

        # Point the LM task's data directory at the checkpoint's directory.
        with open_dict(lm_args.task):
            lm_args.task.data = osp.dirname(args.kenlm_model)

        task = tasks.setup_task(lm_args.task)
        model = task.build_model(lm_args.model)
        model.load_state_dict(checkpoint["model"], strict=False)

        self.trie = Trie(self.vocab_size, self.silence)

        self.word_dict = task.dictionary
        self.unk_word = self.word_dict.unk()
        self.lm = FairseqLM(self.word_dict, model)

        if self.lexicon:
            start_state = self.lm.start(False)
            for i, (word, spellings) in enumerate(self.lexicon.items()):
                if self.unit_lm:
                    # Unit LM: words are identified by their lexicon position
                    # and carry no LM score in the trie.
                    word_idx = i
                    self.idx_to_wrd[i] = word
                    score = 0
                else:
                    # Word LM: score the word from the LM start state.
                    word_idx = self.word_dict.index(word)
                    _, score = self.lm.score(start_state, word_idx, no_cache=True)

                for spelling in spellings:
                    spelling_idxs = [tgt_dict.index(token) for token in spelling]
                    assert (
                        tgt_dict.unk() not in spelling_idxs
                    ), f"{spelling} {spelling_idxs}"
                    self.trie.insert(spelling_idxs, word_idx, score)
            self.trie.smear(SmearingMode.MAX)

            self.decoder_opts = LexiconDecoderOptions(
                beam_size=args.beam,
                beam_size_token=int(getattr(args, "beam_size_token", len(tgt_dict))),
                beam_threshold=args.beam_threshold,
                lm_weight=args.lm_weight,
                word_score=args.word_score,
                unk_score=args.unk_weight,
                sil_score=args.sil_weight,
                log_add=False,
                criterion_type=self.criterion_type,
            )

            self.decoder = LexiconDecoder(
                self.decoder_opts,
                self.trie,
                self.lm,
                self.silence,
                self.blank,
                self.unk_word,
                [],
                self.unit_lm,
            )
        else:
            # Check the attribute resolved with getattr above, so an ``args``
            # object without ``unit_lm`` produces this message rather than an
            # AttributeError.
            assert self.unit_lm, "lexicon free decoding can only be done with a unit language model"
            from flashlight.lib.text.decoder import (
                LexiconFreeDecoder,
                LexiconFreeDecoderOptions,
            )

            # Every target symbol becomes a one-token "word".
            d = {w: [[w]] for w in tgt_dict.symbols}
            self.word_dict = create_word_dict(d)
            # NOTE(review): this replaces the FairseqLM built above with a
            # KenLM loaded from the same path that torch.load read -- confirm
            # this is intended.
            self.lm = KenLM(args.kenlm_model, self.word_dict)
            self.decoder_opts = LexiconFreeDecoderOptions(
                beam_size=args.beam,
                beam_size_token=int(getattr(args, "beam_size_token", len(tgt_dict))),
                beam_threshold=args.beam_threshold,
                lm_weight=args.lm_weight,
                sil_score=args.sil_weight,
                log_add=False,
                criterion_type=self.criterion_type,
            )
            self.decoder = LexiconFreeDecoder(
                self.decoder_opts, self.lm, self.silence, self.blank, []
            )

    def decode(self, emissions):
        """Decode a batch of emissions into n-best hypothesis lists.

        Args:
            emissions: 3-D tensor; the first dimension is iterated as the
                batch and the remaining two sizes are passed to the decoder.
                Presumably (batch, frames, tokens) -- TODO confirm.

        Returns:
            One list per batch element holding up to ``self.nbest`` dicts
            with keys ``"tokens"`` and ``"score"``, plus ``"words"`` when a
            lexicon is in use.
        """
        B, T, N = emissions.size()
        hypos = []

        def idx_to_word(idx):
            # Unit-LM word ids are lexicon positions, not word_dict indices.
            if self.unit_lm:
                return self.idx_to_wrd[idx]
            else:
                return self.word_dict[idx]

        def make_hypo(result):
            hypo = {"tokens": self.get_tokens(result.tokens), "score": result.score}
            if self.lexicon:
                # Negative word ids are skipped.
                hypo["words"] = [idx_to_word(x) for x in result.words if x >= 0]
            return hypo

        for b in range(B):
            # Byte address of utterance ``b``; the factor 4 assumes 4-byte
            # (float32) elements.
            emissions_ptr = emissions.data_ptr() + 4 * b * emissions.stride(0)
            results = self.decoder.decode(emissions_ptr, T, N)

            nbest_results = results[: self.nbest]
            hypos.append([make_hypo(result) for result in nbest_results])
            # Drop LM state cached while decoding this utterance.
            self.lm.empty_cache()

        return hypos
class FairseqLMDecoder(BaseDecoder):
    """Beam-search decoder that uses a fairseq language-model checkpoint.

    With a lexicon, every spelling of every lexicon word is inserted into a
    ``Trie`` and decoding is done by a ``LexiconDecoder`` backed by a
    ``FairseqLM``.  Without a lexicon, a ``LexiconFreeDecoder`` is built
    instead, which requires ``self.unitlm`` to be set.
    """

    def __init__(self, cfg: DecoderConfig, tgt_dict: Dictionary) -> None:
        """Load the LM checkpoint and build the flashlight decoder.

        Args:
            cfg: Decoder configuration.  Fields read here: ``lexicon``,
                ``lmpath``, ``beam``, ``beamsizetoken``, ``beamthreshold``,
                ``lmweight``, ``silweight`` and, for lexicon decoding,
                ``wordscore`` and ``unkweight``.
            tgt_dict: Token dictionary of the acoustic model; used to map
                lexicon spellings to token indices.

        Raises:
            AssertionError: If a lexicon spelling contains a token unknown to
                ``tgt_dict``, or if no lexicon is given and ``self.unitlm``
                is falsy.
        """
        super().__init__(cfg, tgt_dict)

        self.lexicon = load_words(cfg.lexicon) if cfg.lexicon else None
        # Only populated in unit-LM mode: lexicon position -> word string.
        self.idx_to_wrd = {}

        # NOTE(review): torch.load unpickles the file; only point this at
        # trusted checkpoints.
        checkpoint = torch.load(cfg.lmpath, map_location="cpu")

        # Prefer the structured config stored in the checkpoint; otherwise
        # convert the stored argparse namespace.
        if "cfg" in checkpoint and checkpoint["cfg"] is not None:
            lm_args = checkpoint["cfg"]
        else:
            lm_args = convert_namespace_to_omegaconf(checkpoint["args"])

        # Point the LM task's data directory at the checkpoint's directory.
        with open_dict(lm_args.task):
            lm_args.task.data = osp.dirname(cfg.lmpath)

        task = tasks.setup_task(lm_args.task)
        model = task.build_model(lm_args.model)
        model.load_state_dict(checkpoint["model"], strict=False)

        self.trie = Trie(self.vocab_size, self.silence)

        self.word_dict = task.dictionary
        self.unk_word = self.word_dict.unk()
        self.lm = FairseqLM(self.word_dict, model)

        if self.lexicon:
            start_state = self.lm.start(False)
            for i, (word, spellings) in enumerate(self.lexicon.items()):
                if self.unitlm:
                    # Unit LM: words are identified by their lexicon position
                    # and carry no LM score in the trie.
                    word_idx = i
                    self.idx_to_wrd[i] = word
                    score = 0
                else:
                    # Word LM: score the word from the LM start state.
                    word_idx = self.word_dict.index(word)
                    _, score = self.lm.score(start_state, word_idx, no_cache=True)

                for spelling in spellings:
                    spelling_idxs = [tgt_dict.index(token) for token in spelling]
                    assert (
                        tgt_dict.unk() not in spelling_idxs
                    ), f"{spelling} {spelling_idxs}"
                    self.trie.insert(spelling_idxs, word_idx, score)
            self.trie.smear(SmearingMode.MAX)

            self.decoder_opts = LexiconDecoderOptions(
                beam_size=cfg.beam,
                # A falsy beamsizetoken falls back to the full token set.
                beam_size_token=cfg.beamsizetoken or len(tgt_dict),
                beam_threshold=cfg.beamthreshold,
                lm_weight=cfg.lmweight,
                word_score=cfg.wordscore,
                unk_score=cfg.unkweight,
                sil_score=cfg.silweight,
                log_add=False,
                criterion_type=self.criterion_type,
            )

            # The decoder is always handed a list, never None.
            if self.asgtransitions is None:
                self.asgtransitions = []

            self.decoder = LexiconDecoder(
                self.decoder_opts,
                self.trie,
                self.lm,
                self.silence,
                self.blank,
                self.unk_word,
                self.asgtransitions,
                self.unitlm,
            )
        else:
            assert self.unitlm, "Lexicon-free decoding requires unit LM"

            # Every target symbol becomes a one-token "word".
            d = {w: [[w]] for w in tgt_dict.symbols}
            self.word_dict = create_word_dict(d)
            # NOTE(review): this replaces the FairseqLM built above with a
            # KenLM loaded from the same path that torch.load read -- confirm
            # this is intended.
            self.lm = KenLM(cfg.lmpath, self.word_dict)
            self.decoder_opts = LexiconFreeDecoderOptions(
                beam_size=cfg.beam,
                beam_size_token=cfg.beamsizetoken or len(tgt_dict),
                beam_threshold=cfg.beamthreshold,
                lm_weight=cfg.lmweight,
                sil_score=cfg.silweight,
                log_add=False,
                criterion_type=self.criterion_type,
            )
            # NOTE(review): an empty transitions list is passed here rather
            # than self.asgtransitions -- confirm this is intended.
            self.decoder = LexiconFreeDecoder(
                self.decoder_opts, self.lm, self.silence, self.blank, []
            )

    def decode(
        self,
        emissions: torch.FloatTensor,
    ) -> List[List[Dict[str, Any]]]:
        """Decode a batch of emissions into n-best hypothesis lists.

        Args:
            emissions: 3-D tensor; the first dimension is iterated as the
                batch and the remaining two sizes are passed to the decoder.
                Presumably (batch, frames, tokens) -- TODO confirm.

        Returns:
            One list per batch element holding up to ``self.nbest`` dicts
            with keys ``"tokens"`` and ``"score"``, plus ``"words"`` when a
            lexicon is in use.  The values have mixed types, hence
            ``Dict[str, Any]``.
        """
        B, T, N = emissions.size()
        hypos = []

        def make_hypo(result: DecodeResult) -> Dict[str, Any]:
            hypo = {
                "tokens": self.get_tokens(result.tokens),
                "score": result.score,
            }
            if self.lexicon:
                # Unit-LM word ids are lexicon positions, not word_dict
                # indices; negative ids are skipped.
                hypo["words"] = [
                    self.idx_to_wrd[x] if self.unitlm else self.word_dict[x]
                    for x in result.words
                    if x >= 0
                ]
            return hypo

        for b in range(B):
            # Byte address of utterance ``b``; the factor 4 assumes 4-byte
            # (float32) elements.
            emissions_ptr = emissions.data_ptr() + 4 * b * emissions.stride(0)
            results = self.decoder.decode(emissions_ptr, T, N)

            nbest_results = results[: self.nbest]
            hypos.append([make_hypo(result) for result in nbest_results])
            # Drop LM state cached while decoding this utterance.
            self.lm.empty_cache()

        return hypos