def __init__(self, args, tgt_dict):
    """Build a lexicon-constrained wav2letter beam-search decoder scored by KenLM.

    Args:
        args: decoding options. Reads ``lexicon`` (path given to
            ``load_words``), ``kenlm_model`` (path given to ``KenLM``),
            ``beam``, ``beam_threshold``, ``lm_weight``, ``word_score``,
            ``unk_weight``, ``sil_weight`` and, optionally,
            ``beam_size_token`` (defaults to ``len(tgt_dict)``).
        tgt_dict: acoustic-model output dictionary; lexicon spellings are
            mapped to token indices through ``tgt_dict.index``.

    Raises:
        AssertionError: if a lexicon spelling contains a token that maps to
            ``tgt_dict.unk()``.
    """
    super().__init__(args, tgt_dict)

    # Use the "<ctc_blank>" symbol as the silence index when the dictionary
    # defines it; otherwise fall back to BOS.
    self.silence = (
        tgt_dict.index("<ctc_blank>")
        if "<ctc_blank>" in tgt_dict.indices
        else tgt_dict.bos()
    )
    self.lexicon = load_words(args.lexicon)
    self.word_dict = create_word_dict(self.lexicon)
    self.unk_word = self.word_dict.get_index("<unk>")

    self.lm = KenLM(args.kenlm_model, self.word_dict)
    self.trie = Trie(self.vocab_size, self.silence)

    # Score each word once from the LM start state and attach that score to
    # every spelling of the word inserted into the trie.
    start_state = self.lm.start(False)
    for word, spellings in self.lexicon.items():
        word_idx = self.word_dict.get_index(word)
        _, score = self.lm.score(start_state, word_idx)
        for spelling in spellings:
            spelling_idxs = [tgt_dict.index(token) for token in spelling]
            assert (
                tgt_dict.unk() not in spelling_idxs
            ), f"{spelling} {spelling_idxs}"
            self.trie.insert(spelling_idxs, word_idx, score)
    # Smear word scores through the trie (MAX mode) so inner nodes carry a score.
    self.trie.smear(SmearingMode.MAX)

    # Positional order: beam_size, token_beam_size, beam_threshold, lm_weight,
    # word_score, unk_score, sil_score, eos_score, log_add, criterion_type.
    self.decoder_opts = DecoderOptions(
        args.beam,
        int(getattr(args, "beam_size_token", len(tgt_dict))),
        args.beam_threshold,
        args.lm_weight,
        args.word_score,
        args.unk_weight,
        args.sil_weight,
        0,
        False,
        self.criterion_type,
    )

    # No transition matrix configured: the decoder takes an empty list.
    if self.asg_transitions is None:
        self.asg_transitions = []

    self.decoder = LexiconDecoder(
        self.decoder_opts,
        self.trie,
        self.lm,
        self.silence,
        self.blank,
        self.unk_word,
        self.asg_transitions,
        False,  # is token-level LM: KenLM here scores whole words
    )
def __init__(self, args, tgt_dict):
    """Set up a wav2letter decoder that is scored by a fairseq neural LM.

    The LM checkpoint path is taken from ``args.kenlm_model``. Its config comes
    from the checkpoint's ``"cfg"`` entry when present. Otherwise it is
    converted from the legacy ``"args"`` namespace. With a lexicon, a
    ``LexiconDecoder`` over a spelling trie is built. Without one, a
    ``LexiconFreeDecoder`` is used.

    Args:
        args: decoding options (``lexicon``, ``kenlm_model``, ``unit_lm``,
            ``beam``, ``beam_size_token``, ``beam_threshold``, ``lm_weight``,
            ``word_score``, ``unk_weight``, ``sil_weight``).
        tgt_dict: acoustic-model output dictionary.

    Raises:
        AssertionError: if a lexicon spelling contains an unknown token.
    """
    super().__init__(args, tgt_dict)

    self.silence = tgt_dict.bos()
    self.unit_lm = getattr(args, "unit_lm", False)
    self.lexicon = load_words(args.lexicon) if args.lexicon else None
    self.idx_to_wrd = {}

    lm_path = args.kenlm_model
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(lm_path, map_location="cpu")
    saved_cfg = checkpoint["cfg"] if "cfg" in checkpoint else None
    lm_args = (
        saved_cfg
        if saved_cfg is not None
        else convert_namespace_to_omegaconf(checkpoint["args"])
    )
    # Point the LM task's data directory at the folder holding the checkpoint.
    with open_dict(lm_args.task):
        lm_args.task.data = osp.dirname(lm_path)

    task = tasks.setup_task(lm_args.task)
    model = task.build_model(lm_args.model)
    model.load_state_dict(checkpoint["model"], strict=False)

    self.trie = Trie(self.vocab_size, self.silence)

    self.word_dict = task.dictionary
    self.unk_word = self.word_dict.unk()
    self.lm = FairseqLM(self.word_dict, model)

    self.decoder_opts = DecoderOptions(
        args.beam,
        int(getattr(args, "beam_size_token", len(tgt_dict))),
        args.beam_threshold,
        args.lm_weight,
        args.word_score,
        args.unk_weight,
        args.sil_weight,
        0,
        False,
        self.criterion_type,
    )

    # Guard clause: no lexicon means no trie, so decode lexicon-free.
    if not self.lexicon:
        from wav2letter.decoder import LexiconFreeDecoder

        self.decoder = LexiconFreeDecoder(
            self.decoder_opts, self.lm, self.silence, self.blank, []
        )
        return

    start_state = self.lm.start(False)
    for position, (word, spellings) in enumerate(self.lexicon.items()):
        if self.unit_lm:
            # Unit LM: a word is identified by its lexicon position, no LM prior.
            word_idx = position
            self.idx_to_wrd[position] = word
            score = 0
        else:
            word_idx = self.word_dict.index(word)
            _, score = self.lm.score(start_state, word_idx, no_cache=True)
        for spelling in spellings:
            spelling_idxs = list(map(tgt_dict.index, spelling))
            assert tgt_dict.unk() not in spelling_idxs, f"{spelling} {spelling_idxs}"
            self.trie.insert(spelling_idxs, word_idx, score)
    self.trie.smear(SmearingMode.MAX)

    self.decoder = LexiconDecoder(
        self.decoder_opts,
        self.trie,
        self.lm,
        self.silence,
        self.blank,
        self.unk_word,
        [],
        self.unit_lm,
    )
class W2lKenLMDecoder(W2lDecoder):
    """Lexicon-constrained wav2letter beam-search decoder scored by a KenLM model."""

    def __init__(self, args, tgt_dict):
        """Load the lexicon and KenLM model and build the trie and decoder.

        Args:
            args: decoding options. Reads ``lexicon``, ``kenlm_model``,
                ``beam``, ``beam_threshold``, ``lm_weight``, ``word_score``,
                ``unk_weight``, ``sil_weight`` and, optionally,
                ``beam_size_token`` (defaults to ``len(tgt_dict)``).
            tgt_dict: acoustic-model output dictionary used to map lexicon
                spellings to token indices.

        Raises:
            AssertionError: if a lexicon spelling contains a token that maps
                to ``tgt_dict.unk()``.
        """
        super().__init__(args, tgt_dict)

        # Use "<ctc_blank>" as the silence index when present, else BOS.
        self.silence = (
            tgt_dict.index("<ctc_blank>")
            if "<ctc_blank>" in tgt_dict.indices
            else tgt_dict.bos()
        )
        self.lexicon = load_words(args.lexicon)
        self.word_dict = create_word_dict(self.lexicon)
        self.unk_word = self.word_dict.get_index("<unk>")

        self.lm = KenLM(args.kenlm_model, self.word_dict)
        self.trie = Trie(self.vocab_size, self.silence)

        # Score each word once from the LM start state and attach that score
        # to every spelling of the word inserted into the trie.
        start_state = self.lm.start(False)
        for word, spellings in self.lexicon.items():
            word_idx = self.word_dict.get_index(word)
            _, score = self.lm.score(start_state, word_idx)
            for spelling in spellings:
                spelling_idxs = [tgt_dict.index(token) for token in spelling]
                assert (
                    tgt_dict.unk() not in spelling_idxs
                ), f"{spelling} {spelling_idxs}"
                self.trie.insert(spelling_idxs, word_idx, score)
        # Smear word scores through the trie (MAX mode) so inner nodes carry a score.
        self.trie.smear(SmearingMode.MAX)

        # Positional order: beam_size, token_beam_size, beam_threshold,
        # lm_weight, word_score, unk_score, sil_score, eos_score, log_add,
        # criterion_type.
        self.decoder_opts = DecoderOptions(
            args.beam,
            int(getattr(args, "beam_size_token", len(tgt_dict))),
            args.beam_threshold,
            args.lm_weight,
            args.word_score,
            args.unk_weight,
            args.sil_weight,
            0,
            False,
            self.criterion_type,
        )

        # No transition matrix configured: the decoder takes an empty list.
        if self.asg_transitions is None:
            self.asg_transitions = []

        self.decoder = LexiconDecoder(
            self.decoder_opts,
            self.trie,
            self.lm,
            self.silence,
            self.blank,
            self.unk_word,
            self.asg_transitions,
            False,  # is token-level LM: KenLM here scores whole words
        )

    def decode(self, emissions):
        """Beam-search decode a batch of emissions.

        Args:
            emissions: rank-3 tensor ``(B, T, N)``; T frames and N tokens
                per utterance are passed to the native decoder.

        Returns:
            One list per batch element with up to ``self.nbest`` hypothesis
            dicts, each with keys ``"tokens"``, ``"score"`` and ``"words"``.
        """
        # The native decoder receives a raw pointer plus (T, N). The byte
        # offsets below assume one contiguous float32 buffer, presumably
        # read on the host. Normalise the input so that assumption holds;
        # this is a no-op for an already contiguous float32 CPU tensor.
        emissions = emissions.float().cpu().contiguous()
        B, T, N = emissions.size()
        base_ptr = emissions.data_ptr()
        utterance_bytes = emissions.element_size() * emissions.stride(0)

        hypos = []
        for b in range(B):
            results = self.decoder.decode(base_ptr + b * utterance_bytes, T, N)
            hypos.append(
                [
                    {
                        "tokens": self.get_tokens(result.tokens),
                        "score": result.score,
                        # Negative word indices are skipped (not real words).
                        "words": [
                            self.word_dict.get_entry(x)
                            for x in result.words
                            if x >= 0
                        ],
                    }
                    for result in results[: self.nbest]
                ]
            )
        return hypos
# NOTE(review): this line is a fragment of a wav2letter LexiconDecoder usage
# example whose line breaks and indentation were lost. It uses `word` and `i`
# from an enclosing scope that is not shown (presumably a loop checking trie
# scores), and it stops inside the loop that assembles `prediction` from the
# top hypotheses. Restore the original formatting before running it.
word_tensor = tkn_to_idx([c for c in word], token_dict, 0) node = trie.search(word_tensor) assert_near(node.max_score, trie_score_target[i], 1e-5) # Define decoder options: # DecoderOptions (beam_size, token_beam_size, beam_threshold, lm_weight, # word_score, unk_score, sil_score, # eos_score, log_add, criterion_type (ASG or CTC)) opts = DecoderOptions(2500, 25000, 100.0, 2.0, 2.0, -math.inf, -1, 0, False, CriterionType.ASG) # define lexicon beam-search decoder with word-level lm # LexiconDecoder(decoder options, trie, lm, silence index, # blank index (for CTC), unk index, # transition matrix, is token-level lm) decoder = LexiconDecoder(opts, trie, lm, sil_idx, -1, unk_idx, transitions, False) # run decoding # decoder.decode(emissions, Time, Ntokens) # result is a list of sorted hypotheses, 0-index is the best hypothesis # each hypothesis is a struct with "score" and "words" representation # in the hypothesis and the "tokens" representation results = decoder.decode(emissions.ctypes.data, T, N) print(f"Decoding complete, obtained {len(results)} results") print("Showing top 5 results:") for i in range(min(5, len(results))): prediction = [] for idx in results[i].tokens: if idx < 0: break prediction.append(token_dict.get_entry(idx))
def __init__(self, args, tgt_dict):
    """Build a wav2letter decoder scored by a fairseq neural language model.

    Args:
        args: decoding options. Reads ``kenlm_model`` (path of the fairseq LM
            checkpoint passed to ``torch.load``), ``lexicon`` (optional),
            ``unit_lm`` (optional, default ``False``), ``beam``,
            ``beam_threshold``, ``lm_weight``, ``word_score``, ``unk_weight``,
            ``sil_weight`` and, optionally, ``beam_size_token`` (defaults to
            ``len(tgt_dict)``).
        tgt_dict: acoustic-model output dictionary used to map lexicon
            spellings to token indices.

    Raises:
        AssertionError: if a lexicon spelling contains a token that maps to
            ``tgt_dict.unk()``.
    """
    import logging

    super().__init__(args, tgt_dict)

    # NOTE(review): the silence index is hard-coded to 1 instead of
    # ``tgt_dict.bos()``; confirm it matches the acoustic dictionary.
    self.silence = 1
    self.unit_lm = getattr(args, "unit_lm", False)
    self.lexicon = load_words(args.lexicon) if args.lexicon else None
    self.idx_to_wrd = {}

    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(args.kenlm_model, map_location="cpu")
    lm_args = checkpoint["args"]
    # Point the LM task's data directory at the folder holding the checkpoint.
    lm_args.data = osp.dirname(args.kenlm_model)
    logging.getLogger(__name__).info("fairseq LM args: %s", lm_args)

    task = tasks.setup_task(lm_args)
    model = task.build_model(lm_args)
    model.load_state_dict(checkpoint["model"], strict=False)

    self.trie = Trie(self.vocab_size, self.silence)

    self.word_dict = task.dictionary
    self.unk_word = self.word_dict.unk()
    self.lm = FairseqLM(self.word_dict, model)

    self.decoder_opts = DecoderOptions(
        args.beam,
        int(getattr(args, "beam_size_token", len(tgt_dict))),
        args.beam_threshold,
        args.lm_weight,
        args.word_score,
        args.unk_weight,
        args.sil_weight,
        0,
        False,
        self.criterion_type,
    )

    if self.lexicon:
        start_state = self.lm.start(False)
        for i, (word, spellings) in enumerate(self.lexicon.items()):
            if self.unit_lm:
                # Unit LM: a word is identified by its lexicon position and
                # gets no LM prior in the trie.
                word_idx = i
                self.idx_to_wrd[i] = word
                score = 0
            else:
                word_idx = self.word_dict.index(word)
                _, score = self.lm.score(start_state, word_idx, no_cache=True)
            # Spellings (right-hand side of the lexicon) must not contain
            # tokens unknown to the acoustic model.
            for spelling in spellings:
                spelling_idxs = [tgt_dict.index(token) for token in spelling]
                assert (
                    tgt_dict.unk() not in spelling_idxs
                ), f"{spelling} is unk: {spelling_idxs} for acoustic output"
                self.trie.insert(spelling_idxs, word_idx, score)
        self.trie.smear(SmearingMode.MAX)

        self.decoder = LexiconDecoder(
            self.decoder_opts,
            self.trie,
            self.lm,
            self.silence,
            self.blank,
            self.unk_word,
            [],
            self.unit_lm,
        )
    else:
        # Without a lexicon there is no trie to search, so decode
        # lexicon-free; otherwise no decoder would be assigned at all.
        from wav2letter.decoder import LexiconFreeDecoder

        self.decoder = LexiconFreeDecoder(
            self.decoder_opts, self.lm, self.silence, self.blank, []
        )
class W2lFairseqLMDecoder(W2lDecoder):
    """wav2letter beam-search decoder scored by a fairseq neural language model."""

    def __init__(self, args, tgt_dict):
        """Load the fairseq LM checkpoint and build the decoder.

        Args:
            args: decoding options. Reads ``kenlm_model`` (path of the fairseq
                LM checkpoint passed to ``torch.load``), ``lexicon``
                (optional), ``unit_lm`` (optional, default ``False``),
                ``beam``, ``beam_threshold``, ``lm_weight``, ``word_score``,
                ``unk_weight``, ``sil_weight`` and, optionally,
                ``beam_size_token`` (defaults to ``len(tgt_dict)``).
            tgt_dict: acoustic-model output dictionary used to map lexicon
                spellings to token indices.

        Raises:
            AssertionError: if a lexicon spelling contains a token that maps
                to ``tgt_dict.unk()``.
        """
        import logging

        super().__init__(args, tgt_dict)

        # NOTE(review): the silence index is hard-coded to 1 instead of
        # ``tgt_dict.bos()``; confirm it matches the acoustic dictionary.
        self.silence = 1
        self.unit_lm = getattr(args, "unit_lm", False)
        self.lexicon = load_words(args.lexicon) if args.lexicon else None
        self.idx_to_wrd = {}

        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        checkpoint = torch.load(args.kenlm_model, map_location="cpu")
        lm_args = checkpoint["args"]
        # Point the LM task's data directory at the folder holding the checkpoint.
        lm_args.data = osp.dirname(args.kenlm_model)
        logging.getLogger(__name__).info("fairseq LM args: %s", lm_args)

        task = tasks.setup_task(lm_args)
        model = task.build_model(lm_args)
        model.load_state_dict(checkpoint["model"], strict=False)

        self.trie = Trie(self.vocab_size, self.silence)

        self.word_dict = task.dictionary
        self.unk_word = self.word_dict.unk()
        self.lm = FairseqLM(self.word_dict, model)

        self.decoder_opts = DecoderOptions(
            args.beam,
            int(getattr(args, "beam_size_token", len(tgt_dict))),
            args.beam_threshold,
            args.lm_weight,
            args.word_score,
            args.unk_weight,
            args.sil_weight,
            0,
            False,
            self.criterion_type,
        )

        if self.lexicon:
            start_state = self.lm.start(False)
            for i, (word, spellings) in enumerate(self.lexicon.items()):
                if self.unit_lm:
                    # Unit LM: a word is identified by its lexicon position
                    # and gets no LM prior in the trie.
                    word_idx = i
                    self.idx_to_wrd[i] = word
                    score = 0
                else:
                    word_idx = self.word_dict.index(word)
                    _, score = self.lm.score(start_state, word_idx, no_cache=True)
                # Spellings (right-hand side of the lexicon) must not contain
                # tokens unknown to the acoustic model.
                for spelling in spellings:
                    spelling_idxs = [tgt_dict.index(token) for token in spelling]
                    assert (
                        tgt_dict.unk() not in spelling_idxs
                    ), f"{spelling} is unk: {spelling_idxs} for acoustic output"
                    self.trie.insert(spelling_idxs, word_idx, score)
            self.trie.smear(SmearingMode.MAX)

            self.decoder = LexiconDecoder(
                self.decoder_opts,
                self.trie,
                self.lm,
                self.silence,
                self.blank,
                self.unk_word,
                [],
                self.unit_lm,
            )
        else:
            # Without a lexicon there is no trie to search, so decode
            # lexicon-free; otherwise no decoder would be assigned and
            # decode() would fail on the missing attribute.
            from wav2letter.decoder import LexiconFreeDecoder

            self.decoder = LexiconFreeDecoder(
                self.decoder_opts, self.lm, self.silence, self.blank, []
            )

    def decode(self, emissions):
        """Beam-search decode a batch of emissions.

        Args:
            emissions: rank-3 tensor ``(B, T, N)``; T frames and N tokens
                per utterance are passed to the native decoder.

        Returns:
            One list per batch element with up to ``self.nbest`` hypothesis
            dicts with keys ``"tokens"`` and ``"score"``. The key ``"words"``
            is added only when a lexicon is loaded.
        """
        # The native decoder receives a raw pointer plus (T, N). The byte
        # offsets below assume one contiguous float32 buffer, presumably
        # read on the host. Normalise the input so that assumption holds;
        # this is a no-op for an already contiguous float32 CPU tensor.
        emissions = emissions.float().cpu().contiguous()
        B, T, N = emissions.size()
        base_ptr = emissions.data_ptr()
        utterance_bytes = emissions.element_size() * emissions.stride(0)

        def idx_to_word(idx):
            # Unit-LM word ids are lexicon positions (see __init__); otherwise
            # they index the LM dictionary.
            if self.unit_lm:
                return self.idx_to_wrd[idx]
            return self.word_dict[idx]

        def make_hypo(result):
            hypo = {"tokens": self.get_tokens(result.tokens), "score": result.score}
            # Word sequences are only reported when decoding with a lexicon;
            # negative word indices are skipped.
            if self.lexicon:
                hypo["words"] = [idx_to_word(x) for x in result.words if x >= 0]
            return hypo

        hypos = []
        for b in range(B):
            results = self.decoder.decode(base_ptr + b * utterance_bytes, T, N)
            hypos.append([make_hypo(result) for result in results[: self.nbest]])
            # Clear the LM's cache after each utterance; presumably the cached
            # LM states are only needed within a single search.
            self.lm.empty_cache()
        return hypos