def __init__(self, args, tgt_dict):
    """Build a lexicon-constrained beam-search decoder scored by KenLM.

    Loads the lexicon named by ``args.lexicon``, inserts every spelling of
    every word into a trie together with the word's LM score from the LM
    start state, and creates a ``LexiconDecoder`` over that trie.

    Args:
        args: Decoding options. Reads ``lexicon``, ``kenlm_model``, ``beam``,
            ``beam_threshold``, ``lm_weight``, ``word_score``, ``unk_weight``
            and ``sil_weight``. ``beam_size_token`` is optional and falls
            back to ``len(tgt_dict)``.
        tgt_dict: Token dictionary used to map spelling tokens to indices
            (``index``, ``indices``, ``bos`` and ``unk`` are used here).

    Raises:
        AssertionError: If a spelling contains a token that ``tgt_dict``
            maps to its unknown index.
    """
    super().__init__(args, tgt_dict)

    # Prefer the dedicated "<ctc_blank>" symbol as silence; fall back to
    # the dictionary's BOS index when that symbol is absent.
    self.silence = (
        tgt_dict.index("<ctc_blank>")
        if "<ctc_blank>" in tgt_dict.indices
        else tgt_dict.bos()
    )

    self.lexicon = load_words(args.lexicon)
    self.word_dict = create_word_dict(self.lexicon)
    self.unk_word = self.word_dict.get_index("<unk>")

    self.lm = KenLM(args.kenlm_model, self.word_dict)
    # NOTE(review): self.vocab_size is not assigned in this method;
    # presumably the base-class constructor sets it -- confirm.
    self.trie = Trie(self.vocab_size, self.silence)

    start_state = self.lm.start(False)
    for word, spellings in self.lexicon.items():
        word_idx = self.word_dict.get_index(word)
        # Only the score is kept; the returned LM state is discarded.
        _, score = self.lm.score(start_state, word_idx)
        for spelling in spellings:
            spelling_idxs = [tgt_dict.index(token) for token in spelling]
            assert (
                tgt_dict.unk() not in spelling_idxs
            ), f"{spelling} {spelling_idxs}"
            self.trie.insert(spelling_idxs, word_idx, score)
    self.trie.smear(SmearingMode.MAX)

    self.decoder_opts = LexiconDecoderOptions(
        beam_size=args.beam,
        beam_size_token=int(getattr(args, "beam_size_token", len(tgt_dict))),
        beam_threshold=args.beam_threshold,
        lm_weight=args.lm_weight,
        word_score=args.word_score,
        unk_score=args.unk_weight,
        sil_score=args.sil_weight,
        log_add=False,
        criterion_type=self.criterion_type,
    )

    # The decoder is handed an empty transition list when none were set.
    if self.asg_transitions is None:
        self.asg_transitions = []

    self.decoder = LexiconDecoder(
        self.decoder_opts,
        self.trie,
        self.lm,
        self.silence,
        self.blank,
        self.unk_word,
        self.asg_transitions,
        False,  # word-level LM, not a token-level one
    )
def __init__(self, args, tgt_dict):
    """Build a beam-search decoder scored by a fairseq language model.

    Loads an LM checkpoint from ``args.kenlm_model``, rebuilds its task and
    model, and wraps the model in ``FairseqLM``. When a lexicon is given, a
    trie of word spellings is built and a ``LexiconDecoder`` is created;
    otherwise a ``LexiconFreeDecoder`` is created.

    Args:
        args: Decoding options. Reads ``lexicon``, ``kenlm_model``, ``beam``,
            ``beam_threshold``, ``lm_weight``, ``word_score``, ``unk_weight``
            and ``sil_weight``. ``unit_lm`` (default ``False``) and
            ``beam_size_token`` (default ``len(tgt_dict)``) are optional.
        tgt_dict: Token dictionary used to map spelling tokens to indices.

    Raises:
        AssertionError: If a spelling contains a token that ``tgt_dict``
            maps to its unknown index, or if no lexicon is given and
            ``unit_lm`` is falsy.
    """
    super().__init__(args, tgt_dict)

    self.unit_lm = getattr(args, "unit_lm", False)

    self.lexicon = load_words(args.lexicon) if args.lexicon else None
    self.idx_to_wrd = {}

    # NOTE(review): torch.load relies on pickle; only load checkpoints from
    # a trusted source.
    checkpoint = torch.load(args.kenlm_model, map_location="cpu")

    # Newer checkpoints carry a "cfg" entry; older ones only have an
    # argparse-style "args" namespace that must be converted.
    if "cfg" in checkpoint and checkpoint["cfg"] is not None:
        lm_args = checkpoint["cfg"]
    else:
        lm_args = convert_namespace_to_omegaconf(checkpoint["args"])

    # Point the LM task's data directory at the checkpoint's folder.
    with open_dict(lm_args.task):
        lm_args.task.data = osp.dirname(args.kenlm_model)

    task = tasks.setup_task(lm_args.task)
    model = task.build_model(lm_args.model)
    model.load_state_dict(checkpoint["model"], strict=False)

    # NOTE(review): self.vocab_size, self.silence, self.blank and
    # self.criterion_type are not assigned in this method; presumably the
    # base-class constructor sets them -- confirm.
    self.trie = Trie(self.vocab_size, self.silence)

    self.word_dict = task.dictionary
    self.unk_word = self.word_dict.unk()
    self.lm = FairseqLM(self.word_dict, model)

    if self.lexicon:
        start_state = self.lm.start(False)
        for i, (word, spellings) in enumerate(self.lexicon.items()):
            if self.unit_lm:
                # With a unit LM the lexicon position serves as the word
                # id (remembered in idx_to_wrd) and the trie entry gets a
                # zero score instead of an LM score.
                word_idx = i
                self.idx_to_wrd[i] = word
                score = 0
            else:
                word_idx = self.word_dict.index(word)
                _, score = self.lm.score(start_state, word_idx, no_cache=True)

            for spelling in spellings:
                spelling_idxs = [tgt_dict.index(token) for token in spelling]
                assert (
                    tgt_dict.unk() not in spelling_idxs
                ), f"{spelling} {spelling_idxs}"
                self.trie.insert(spelling_idxs, word_idx, score)
        self.trie.smear(SmearingMode.MAX)

        self.decoder_opts = LexiconDecoderOptions(
            beam_size=args.beam,
            beam_size_token=int(getattr(args, "beam_size_token", len(tgt_dict))),
            beam_threshold=args.beam_threshold,
            lm_weight=args.lm_weight,
            word_score=args.word_score,
            unk_score=args.unk_weight,
            sil_score=args.sil_weight,
            log_add=False,
            criterion_type=self.criterion_type,
        )

        self.decoder = LexiconDecoder(
            self.decoder_opts,
            self.trie,
            self.lm,
            self.silence,
            self.blank,
            self.unk_word,
            [],
            self.unit_lm,
        )
    else:
        # Use the value read defensively above: touching args.unit_lm
        # directly would raise AttributeError (instead of this message)
        # when the option is missing from args.
        assert self.unit_lm, "lexicon free decoding can only be done with a unit language model"
        from flashlight.lib.text.decoder import (
            LexiconFreeDecoder,
            LexiconFreeDecoderOptions,
        )

        # Every dictionary symbol becomes a one-token "word".
        d = {w: [[w]] for w in tgt_dict.symbols}
        self.word_dict = create_word_dict(d)
        # NOTE(review): this replaces the FairseqLM built above with a KenLM
        # opened on the same path that torch.load just read as a checkpoint.
        # Confirm this is intended.
        self.lm = KenLM(args.kenlm_model, self.word_dict)
        self.decoder_opts = LexiconFreeDecoderOptions(
            beam_size=args.beam,
            beam_size_token=int(getattr(args, "beam_size_token", len(tgt_dict))),
            beam_threshold=args.beam_threshold,
            lm_weight=args.lm_weight,
            sil_score=args.sil_weight,
            log_add=False,
            criterion_type=self.criterion_type,
        )
        self.decoder = LexiconFreeDecoder(
            self.decoder_opts, self.lm, self.silence, self.blank, []
        )
def __init__(self, args, tgt_dict):
    """Build a KenLM-scored beam-search decoder, with or without a lexicon.

    With ``args.lexicon`` set, every spelling of every lexicon word is
    inserted into a trie with the word's LM score from the LM start state
    and a ``LexiconDecoder`` is created. Without a lexicon, a
    ``LexiconFreeDecoder`` is created over the symbols of ``tgt_dict``.

    Args:
        args: Decoding options. Reads ``lexicon``, ``kenlm_model``, ``beam``,
            ``beam_threshold``, ``lm_weight`` and ``sil_weight``; the lexicon
            branch also reads ``word_score`` and ``unk_weight``. ``unit_lm``
            (default ``False``) and ``beam_size_token`` (default
            ``len(tgt_dict)``) are optional.
        tgt_dict: Token dictionary used to map spelling tokens to indices.

    Raises:
        AssertionError: If a spelling contains a token that ``tgt_dict``
            maps to its unknown index, or if no lexicon is given and
            ``unit_lm`` is falsy.
    """
    super().__init__(args, tgt_dict)

    self.unit_lm = getattr(args, "unit_lm", False)

    # NOTE(review): self.vocab_size, self.silence, self.blank,
    # self.criterion_type and self.asg_transitions are not assigned in this
    # method; presumably the base-class constructor sets them -- confirm.
    if args.lexicon:
        self.lexicon = load_words(args.lexicon)
        self.word_dict = create_word_dict(self.lexicon)
        self.unk_word = self.word_dict.get_index("<unk>")

        self.lm = KenLM(args.kenlm_model, self.word_dict)
        self.trie = Trie(self.vocab_size, self.silence)

        start_state = self.lm.start(False)
        for word, spellings in self.lexicon.items():
            word_idx = self.word_dict.get_index(word)
            # Only the score is kept; the returned LM state is discarded.
            _, score = self.lm.score(start_state, word_idx)
            for spelling in spellings:
                spelling_idxs = [tgt_dict.index(token) for token in spelling]
                assert (
                    tgt_dict.unk() not in spelling_idxs
                ), f"{spelling} {spelling_idxs}"
                self.trie.insert(spelling_idxs, word_idx, score)
        self.trie.smear(SmearingMode.MAX)

        self.decoder_opts = LexiconDecoderOptions(
            beam_size=args.beam,
            beam_size_token=int(getattr(args, "beam_size_token", len(tgt_dict))),
            beam_threshold=args.beam_threshold,
            lm_weight=args.lm_weight,
            word_score=args.word_score,
            unk_score=args.unk_weight,
            sil_score=args.sil_weight,
            log_add=False,
            criterion_type=self.criterion_type,
        )

        # The decoder is handed an empty transition list when none were set.
        if self.asg_transitions is None:
            self.asg_transitions = []

        self.decoder = LexiconDecoder(
            self.decoder_opts,
            self.trie,
            self.lm,
            self.silence,
            self.blank,
            self.unk_word,
            self.asg_transitions,
            self.unit_lm,
        )
    else:
        # Use the value read defensively above: touching args.unit_lm
        # directly would raise AttributeError (instead of this message)
        # when the option is missing from args.
        assert self.unit_lm, "lexicon free decoding can only be done with a unit language model"
        from flashlight.lib.text.decoder import (
            LexiconFreeDecoder,
            LexiconFreeDecoderOptions,
        )

        # Every dictionary symbol becomes a one-token "word".
        d = {w: [[w]] for w in tgt_dict.symbols}
        self.word_dict = create_word_dict(d)
        self.lm = KenLM(args.kenlm_model, self.word_dict)
        self.decoder_opts = LexiconFreeDecoderOptions(
            beam_size=args.beam,
            beam_size_token=int(getattr(args, "beam_size_token", len(tgt_dict))),
            beam_threshold=args.beam_threshold,
            lm_weight=args.lm_weight,
            sil_score=args.sil_weight,
            log_add=False,
            criterion_type=self.criterion_type,
        )
        self.decoder = LexiconFreeDecoder(
            self.decoder_opts, self.lm, self.silence, self.blank, []
        )
def __init__(self, kenlm_args, vocab_dict, blank="<pad>", silence="|", unk="<unk>"):
    """Create a CTC beam-search decoder scored by a KenLM model.

    Args:
        kenlm_args: Mapping of decoding options. Reads ``nbest``,
            ``lexicon_path``, ``kenlm_model_path``, ``beam``,
            ``beam_threshold``, ``lm_weight`` and ``sil_weight``; the
            lexicon branch also reads ``word_score``. ``beam_size_token``
            is optional and falls back to ``len(vocab_dict)``.
        vocab_dict: Mapping from token string to token index.
        blank: Key in ``vocab_dict`` of the blank token.
        silence: Key in ``vocab_dict`` of the silence token.
        unk: Key of the unknown token in ``vocab_dict`` (and, with a
            lexicon, the entry looked up in the word dictionary).
    """
    self.vocab_size = len(vocab_dict)
    self.blank_token = vocab_dict[blank]
    self.silence_token = vocab_dict[silence]
    self.unk_token = vocab_dict[unk]
    self.nbest = kenlm_args['nbest']

    if not kenlm_args['lexicon_path']:
        # Lexicon-free decoding: every vocabulary entry is its own "word".
        single_token_words = {w: [[w]] for w in vocab_dict.keys()}
        self.word_dict = create_word_dict(single_token_words)
        self.lm = KenLM(kenlm_args['kenlm_model_path'], self.word_dict)

        if "beam_size_token" in kenlm_args:
            token_beam = kenlm_args['beam_size_token']
        else:
            token_beam = len(vocab_dict)

        self.decoder_opts = LexiconFreeDecoderOptions(
            beam_size=kenlm_args['beam'],
            beam_size_token=token_beam,
            beam_threshold=kenlm_args['beam_threshold'],
            lm_weight=kenlm_args['lm_weight'],
            sil_score=kenlm_args['sil_weight'],
            log_add=False,
            criterion_type=CriterionType.CTC,
        )
        self.decoder = LexiconFreeDecoder(
            self.decoder_opts,
            self.lm,
            self.silence_token,
            self.blank_token,
            [],
        )
        return

    # Lexicon-constrained decoding.
    known_tokens = vocab_dict.keys()
    self.lexicon = load_words(kenlm_args['lexicon_path'])
    self.word_dict = create_word_dict(self.lexicon)
    self.unk_word = self.word_dict.get_index(unk)

    self.lm = KenLM(kenlm_args['kenlm_model_path'], self.word_dict)
    self.trie = Trie(self.vocab_size, self.silence_token)

    lm_start = self.lm.start(False)
    for word, spellings in self.lexicon.items():
        word_id = self.word_dict.get_index(word)
        _, lm_score = self.lm.score(lm_start, word_id)

        for spelling in spellings:
            token_ids = []
            for token in spelling:
                # Try the upper-cased form first, then the lower-cased one;
                # anything still unmatched is replaced by the unknown token.
                for variant in (token.upper(), token.lower()):
                    if variant in known_tokens:
                        token_ids.append(vocab_dict[variant])
                        break
                else:
                    print(
                        "WARNING: The token",
                        token,
                        "not exist in your vocabulary, using <unk> token instead",
                    )
                    token_ids.append(self.unk_token)
            self.trie.insert(token_ids, word_id, lm_score)
    self.trie.smear(SmearingMode.MAX)

    if "beam_size_token" in kenlm_args:
        token_beam = kenlm_args['beam_size_token']
    else:
        token_beam = len(vocab_dict)

    self.decoder_opts = LexiconDecoderOptions(
        beam_size=kenlm_args['beam'],
        beam_size_token=token_beam,
        beam_threshold=kenlm_args['beam_threshold'],
        lm_weight=kenlm_args['lm_weight'],
        word_score=kenlm_args['word_score'],
        unk_score=-math.inf,
        sil_score=kenlm_args['sil_weight'],
        log_add=False,
        criterion_type=CriterionType.CTC,
    )
    self.decoder = LexiconDecoder(
        self.decoder_opts,
        self.trie,
        self.lm,
        self.silence_token,
        self.blank_token,
        self.unk_word,
        [],
        False,
    )
def __init__(self, cfg: DecoderConfig, tgt_dict: Dictionary) -> None:
    """Set up beam-search decoding scored by a fairseq language model.

    The LM checkpoint at ``cfg.lmpath`` is loaded, its task and model are
    rebuilt, and the model is wrapped in ``FairseqLM``. With a lexicon a
    trie of spellings is built and a ``LexiconDecoder`` is created; without
    one a ``LexiconFreeDecoder`` is created.

    Args:
        cfg: Decoder configuration. Reads ``lexicon``, ``lmpath``, ``beam``,
            ``beamsizetoken``, ``beamthreshold``, ``lmweight`` and
            ``silweight``; the lexicon branch also reads ``wordscore`` and
            ``unkweight``.
        tgt_dict: Token dictionary used to map spelling tokens to indices.

    Raises:
        AssertionError: If a spelling contains a token that ``tgt_dict``
            maps to its unknown index, or if no lexicon is given and
            ``self.unitlm`` is falsy.
    """
    super().__init__(cfg, tgt_dict)

    self.lexicon = load_words(cfg.lexicon) if cfg.lexicon else None
    self.idx_to_wrd = {}

    # NOTE(review): torch.load relies on pickle; only load checkpoints from
    # a trusted source.
    ckpt = torch.load(cfg.lmpath, map_location="cpu")

    # Prefer the checkpoint's "cfg" entry; otherwise convert its "args".
    has_cfg = "cfg" in ckpt and ckpt["cfg"] is not None
    lm_cfg = ckpt["cfg"] if has_cfg else convert_namespace_to_omegaconf(ckpt["args"])

    # Point the LM task's data directory at the checkpoint's folder.
    with open_dict(lm_cfg.task):
        lm_cfg.task.data = osp.dirname(cfg.lmpath)

    lm_task = tasks.setup_task(lm_cfg.task)
    lm_model = lm_task.build_model(lm_cfg.model)
    lm_model.load_state_dict(ckpt["model"], strict=False)

    # NOTE(review): self.vocab_size, self.silence, self.blank, self.unitlm,
    # self.criterion_type and self.asgtransitions are not assigned in this
    # method; presumably the base-class constructor sets them -- confirm.
    self.trie = Trie(self.vocab_size, self.silence)

    self.word_dict = lm_task.dictionary
    self.unk_word = self.word_dict.unk()
    self.lm = FairseqLM(self.word_dict, lm_model)

    if not self.lexicon:
        assert self.unitlm, "Lexicon-free decoding requires unit LM"

        # Every dictionary symbol becomes a one-token "word".
        unit_words = {w: [[w]] for w in tgt_dict.symbols}
        self.word_dict = create_word_dict(unit_words)
        # NOTE(review): this replaces the FairseqLM built above with a KenLM
        # opened on the same path that torch.load just read as a checkpoint.
        # Confirm this is intended.
        self.lm = KenLM(cfg.lmpath, self.word_dict)
        self.decoder_opts = LexiconFreeDecoderOptions(
            beam_size=cfg.beam,
            beam_size_token=cfg.beamsizetoken or len(tgt_dict),
            beam_threshold=cfg.beamthreshold,
            lm_weight=cfg.lmweight,
            sil_score=cfg.silweight,
            log_add=False,
            criterion_type=self.criterion_type,
        )
        self.decoder = LexiconFreeDecoder(
            self.decoder_opts,
            self.lm,
            self.silence,
            self.blank,
            [],
        )
        return

    lm_start = self.lm.start(False)
    for position, (word, spellings) in enumerate(self.lexicon.items()):
        if self.unitlm:
            # Unit LM: the lexicon position is the word id and the trie
            # entry carries a zero score.
            word_id, lm_score = position, 0
            self.idx_to_wrd[position] = word
        else:
            word_id = self.word_dict.index(word)
            _, lm_score = self.lm.score(lm_start, word_id, no_cache=True)

        for spelling in spellings:
            token_ids = [tgt_dict.index(token) for token in spelling]
            assert tgt_dict.unk() not in token_ids, f"{spelling} {token_ids}"
            self.trie.insert(token_ids, word_id, lm_score)
    self.trie.smear(SmearingMode.MAX)

    self.decoder_opts = LexiconDecoderOptions(
        beam_size=cfg.beam,
        beam_size_token=cfg.beamsizetoken or len(tgt_dict),
        beam_threshold=cfg.beamthreshold,
        lm_weight=cfg.lmweight,
        word_score=cfg.wordscore,
        unk_score=cfg.unkweight,
        sil_score=cfg.silweight,
        log_add=False,
        criterion_type=self.criterion_type,
    )

    # The decoder is handed an empty transition list when none were set.
    if self.asgtransitions is None:
        self.asgtransitions = []

    self.decoder = LexiconDecoder(
        self.decoder_opts,
        self.trie,
        self.lm,
        self.silence,
        self.blank,
        self.unk_word,
        self.asgtransitions,
        self.unitlm,
    )
def __init__(self, cfg: DecoderConfig, tgt_dict: Dictionary) -> None:
    """Set up KenLM-scored beam-search decoding, with or without a lexicon.

    With ``cfg.lexicon`` set, every spelling of every lexicon word goes into
    a trie with the word's LM score from the LM start state and a
    ``LexiconDecoder`` is created. Otherwise a ``LexiconFreeDecoder`` is
    created over the symbols of ``tgt_dict``.

    Args:
        cfg: Decoder configuration. Reads ``lexicon``, ``lmpath``, ``beam``,
            ``beamsizetoken``, ``beamthreshold``, ``lmweight`` and
            ``silweight``; the lexicon branch also reads ``wordscore`` and
            ``unkweight``.
        tgt_dict: Token dictionary used to map spelling tokens to indices.

    Raises:
        AssertionError: If a spelling contains a token that ``tgt_dict``
            maps to its unknown index, or if no lexicon is given and
            ``self.unitlm`` is falsy.
    """
    super().__init__(cfg, tgt_dict)

    # NOTE(review): self.vocab_size, self.silence, self.blank, self.unitlm,
    # self.criterion_type and self.asgtransitions are not assigned in this
    # method; presumably the base-class constructor sets them -- confirm.
    if not cfg.lexicon:
        assert self.unitlm, "Lexicon-free decoding requires unit LM"

        # Every dictionary symbol becomes a one-token "word".
        unit_words = {w: [[w]] for w in tgt_dict.symbols}
        self.word_dict = create_word_dict(unit_words)
        self.lm = KenLM(cfg.lmpath, self.word_dict)
        self.decoder_opts = LexiconFreeDecoderOptions(
            beam_size=cfg.beam,
            beam_size_token=cfg.beamsizetoken or len(tgt_dict),
            beam_threshold=cfg.beamthreshold,
            lm_weight=cfg.lmweight,
            sil_score=cfg.silweight,
            log_add=False,
            criterion_type=self.criterion_type,
        )
        self.decoder = LexiconFreeDecoder(
            self.decoder_opts,
            self.lm,
            self.silence,
            self.blank,
            [],
        )
        return

    self.lexicon = load_words(cfg.lexicon)
    self.word_dict = create_word_dict(self.lexicon)
    self.unk_word = self.word_dict.get_index("<unk>")

    self.lm = KenLM(cfg.lmpath, self.word_dict)
    self.trie = Trie(self.vocab_size, self.silence)

    lm_start = self.lm.start(False)
    for word, spellings in self.lexicon.items():
        word_id = self.word_dict.get_index(word)
        # Only the score is kept; the returned LM state is discarded.
        _, lm_score = self.lm.score(lm_start, word_id)
        for spelling in spellings:
            token_ids = [tgt_dict.index(token) for token in spelling]
            assert tgt_dict.unk() not in token_ids, f"{spelling} {token_ids}"
            self.trie.insert(token_ids, word_id, lm_score)
    self.trie.smear(SmearingMode.MAX)

    self.decoder_opts = LexiconDecoderOptions(
        beam_size=cfg.beam,
        beam_size_token=cfg.beamsizetoken or len(tgt_dict),
        beam_threshold=cfg.beamthreshold,
        lm_weight=cfg.lmweight,
        word_score=cfg.wordscore,
        unk_score=cfg.unkweight,
        sil_score=cfg.silweight,
        log_add=False,
        criterion_type=self.criterion_type,
    )

    # The decoder is handed an empty transition list when none were set.
    if self.asgtransitions is None:
        self.asgtransitions = []

    self.decoder = LexiconDecoder(
        self.decoder_opts,
        self.trie,
        self.lm,
        self.silence,
        self.blank,
        self.unk_word,
        self.asgtransitions,
        self.unitlm,
    )
word_tensor = tkn_to_idx([c for c in word], token_dict, 1) node = trie.search(word_tensor) assert_near(node.max_score, trie_score_target[i], 1e-5) # Define decoder options: # LexiconDecoderOptions (beam_size, token_beam_size, beam_threshold, lm_weight, # word_score, unk_score, sil_score, # log_add, criterion_type (ASG or CTC)) opts = LexiconDecoderOptions(2500, 25000, 100.0, 2.0, 2.0, -math.inf, -1, False, CriterionType.ASG) # define lexicon beam-search decoder with word-level lm # LexiconDecoder(decoder options, trie, lm, silence index, # blank index (for CTC), unk index, # transitiona matrix, is token-level lm) decoder = LexiconDecoder(opts, trie, lm, sil_idx, -1, unk_idx, transitions, False) # run decoding # decoder.decode(emissions, Time, Ntokens) # result is a list of sorted hypothesis, 0-index is the best hypothesis # each hypothesis is a struct with "score" and "words" representation # in the hypothesis and the "tokens" representation results = decoder.decode(emissions.ctypes.data, T, N) print(f"Decoding complete, obtained {len(results)} results") print("Showing top 5 results:") for i in range(min(5, len(results))): prediction = [] for idx in results[i].tokens: if idx < 0: break prediction.append(token_dict.get_entry(idx))
class W2lKenLMDecoder(W2lDecoder):
    """Lexicon-constrained beam-search decoder scored by a KenLM model."""

    def __init__(self, args, tgt_dict):
        """Load the lexicon and LM, build the spelling trie and the decoder.

        Args:
            args: Decoding options. Reads ``lexicon``, ``kenlm_model``,
                ``beam``, ``beam_threshold``, ``lm_weight``, ``word_score``,
                ``unk_weight`` and ``sil_weight``. ``beam_size_token`` is
                optional and falls back to ``len(tgt_dict)``.
            tgt_dict: Token dictionary used to map spelling tokens to
                indices (``index``, ``indices``, ``bos`` and ``unk`` are
                used here).

        Raises:
            AssertionError: If a spelling contains a token that ``tgt_dict``
                maps to its unknown index.
        """
        super().__init__(args, tgt_dict)

        # Prefer the dedicated "<ctc_blank>" symbol as silence; fall back to
        # the dictionary's BOS index when that symbol is absent.
        self.silence = (
            tgt_dict.index("<ctc_blank>")
            if "<ctc_blank>" in tgt_dict.indices
            else tgt_dict.bos()
        )

        self.lexicon = load_words(args.lexicon)
        self.word_dict = create_word_dict(self.lexicon)
        self.unk_word = self.word_dict.get_index("<unk>")

        self.lm = KenLM(args.kenlm_model, self.word_dict)
        # NOTE(review): self.vocab_size, self.blank, self.criterion_type and
        # self.asg_transitions are not assigned in this class; presumably
        # W2lDecoder.__init__ sets them -- confirm.
        self.trie = Trie(self.vocab_size, self.silence)

        start_state = self.lm.start(False)
        for word, spellings in self.lexicon.items():
            word_idx = self.word_dict.get_index(word)
            # Only the score is kept; the returned LM state is discarded.
            _, score = self.lm.score(start_state, word_idx)
            for spelling in spellings:
                spelling_idxs = [tgt_dict.index(token) for token in spelling]
                assert (
                    tgt_dict.unk() not in spelling_idxs
                ), f"{spelling} {spelling_idxs}"
                self.trie.insert(spelling_idxs, word_idx, score)
        self.trie.smear(SmearingMode.MAX)

        self.decoder_opts = LexiconDecoderOptions(
            beam_size=args.beam,
            beam_size_token=int(getattr(args, "beam_size_token", len(tgt_dict))),
            beam_threshold=args.beam_threshold,
            lm_weight=args.lm_weight,
            word_score=args.word_score,
            unk_score=args.unk_weight,
            sil_score=args.sil_weight,
            log_add=False,
            criterion_type=self.criterion_type,
        )

        # The decoder is handed an empty transition list when none were set.
        if self.asg_transitions is None:
            self.asg_transitions = []

        self.decoder = LexiconDecoder(
            self.decoder_opts,
            self.trie,
            self.lm,
            self.silence,
            self.blank,
            self.unk_word,
            self.asg_transitions,
            False,  # word-level LM, not a token-level one
        )

    def decode(self, emissions):
        """Decode every sequence of a batch and keep the n-best hypotheses.

        Args:
            emissions: Tensor whose ``size()`` unpacks to ``(B, T, N)``.

        Returns:
            A list with one entry per batch item. Each entry is a list of at
            most ``self.nbest`` dicts with the keys ``"tokens"``,
            ``"score"`` and ``"words"``.
        """
        B, T, N = emissions.size()
        hypos = []
        for b in range(B):
            # Raw address of the b-th sequence. The factor 4 is a hard-coded
            # element size in bytes.
            # NOTE(review): this presumably assumes float32 emissions laid
            # out contiguously within each batch item -- confirm at callers.
            emissions_ptr = emissions.data_ptr() + 4 * b * emissions.stride(0)
            results = self.decoder.decode(emissions_ptr, T, N)

            nbest_results = results[: self.nbest]
            hypos.append(
                [
                    {
                        "tokens": self.get_tokens(result.tokens),
                        "score": result.score,
                        # Negative word ids are skipped.
                        "words": [
                            self.word_dict.get_entry(x)
                            for x in result.words
                            if x >= 0
                        ],
                    }
                    for result in nbest_results
                ]
            )
        return hypos