def __init__(self, args, tgt_dict):
    """Build a lexicon-constrained KenLM beam-search decoder.

    args: namespace with decoder hyper-parameters (lexicon, kenlm_model,
        beam, beam_threshold, lm_weight, word_score, unk_weight,
        sil_weight, and optionally beam_size_token).
    tgt_dict: acoustic-model output (token) dictionary.
    """
    super().__init__(args, tgt_dict)

    # Prefer an explicit CTC blank token as the "silence" index; fall
    # back to BOS when the dictionary does not define one.
    self.silence = (
        tgt_dict.index("<ctc_blank>")
        if "<ctc_blank>" in tgt_dict.indices
        else tgt_dict.bos()
    )

    self.lexicon = load_words(args.lexicon)
    self.word_dict = create_word_dict(self.lexicon)
    self.unk_word = self.word_dict.get_index("<unk>")

    self.lm = KenLM(args.kenlm_model, self.word_dict)
    self.trie = Trie(self.vocab_size, self.silence)

    # Seed the trie: every spelling of every lexicon word is inserted
    # with its LM score from the start state so the beam search can
    # prune hypotheses by LM weight.
    start_state = self.lm.start(False)
    for word, spellings in self.lexicon.items():
        word_idx = self.word_dict.get_index(word)
        _, score = self.lm.score(start_state, word_idx)
        for spelling in spellings:
            spelling_idxs = [tgt_dict.index(token) for token in spelling]
            # A spelling mapping to <unk> means the acoustic vocabulary
            # cannot emit this word at all -- fail loudly.
            assert (
                tgt_dict.unk() not in spelling_idxs
            ), f"{spelling} {spelling_idxs}"
            self.trie.insert(spelling_idxs, word_idx, score)
    self.trie.smear(SmearingMode.MAX)

    self.decoder_opts = DecoderOptions(
        args.beam,
        int(getattr(args, "beam_size_token", len(tgt_dict))),
        args.beam_threshold,
        args.lm_weight,
        args.word_score,
        args.unk_weight,
        args.sil_weight,
        0,      # eos score
        False,  # log_add
        self.criterion_type,
    )

    # CTC models carry no ASG transition matrix; pass an empty list.
    if self.asg_transitions is None:
        self.asg_transitions = []

    self.decoder = LexiconDecoder(
        self.decoder_opts,
        self.trie,
        self.lm,
        self.silence,
        self.blank,
        self.unk_word,
        self.asg_transitions,
        False,  # is_token_lm: this is a word-level LM
    )
def __init__(self, args, tgt_dict):
    """Set up a KenLM word-LM lexicon decoder.

    Builds a lexicon trie populated with every word spelling, scored by
    the KenLM probability from the start state, then smeared for beam
    pruning, and wires the result into a WordLMDecoder.
    """
    super().__init__(args, tgt_dict)

    # The silence token is configurable for this decoder variant.
    self.silence = tgt_dict.index(args.silence_token)

    self.lexicon = load_words(args.lexicon)
    self.word_dict = create_word_dict(self.lexicon)
    self.unk_word = self.word_dict.get_index("<unk>")
    self.lm = KenLM(args.kenlm_model, self.word_dict)

    # Populate the lexicon trie with LM-scored spellings.
    self.trie = Trie(self.vocab_size, self.silence)
    initial_state = self.lm.start(False)
    for lex_word, token_seqs in self.lexicon.items():
        w_idx = self.word_dict.get_index(lex_word)
        _, lm_score = self.lm.score(initial_state, w_idx)
        for seq in token_seqs:
            idx_seq = [tgt_dict.index(tok) for tok in seq]
            self.trie.insert(idx_seq, w_idx, lm_score)
    self.trie.smear(SmearingMode.MAX)

    self.decoder_opts = DecoderOptions(
        args.beam,
        args.beam_threshold,
        args.lm_weight,
        args.word_score,
        args.unk_weight,
        False,  # log_add
        args.sil_weight,
        self.criterion_type,
    )
    self.decoder = WordLMDecoder(
        self.decoder_opts,
        self.trie,
        self.lm,
        self.silence,
        self.blank,
        self.unk_word,
        self.asg_transitions,
    )
def __init__(self, lm_weight=2.0, lexicon_path="WER_data/lexicon.txt", token_path="WER_data/letters.lst", lm_path="WER_data/4-gram.bin"):
    """Prepare a CTC lexicon decoder over a binary KenLM model.

    lm_weight: language-model weight used during beam search.
    lexicon_path / token_path / lm_path: lexicon file, acoustic token
        list, and binary KenLM model.
    """
    lex = load_words(lexicon_path)
    words = create_word_dict(lex)

    self.token_dict = Dictionary(token_path)
    self.lm = KenLM(lm_path, words)

    # "|" marks word boundaries (silence); "#" is appended as CTC blank.
    self.sil_idx = self.token_dict.get_index("|")
    self.unk_idx = words.get_index("<unk>")
    self.token_dict.add_entry("#")
    self.blank_idx = self.token_dict.get_index("#")

    # Insert every spelling of every word into the trie, scored by the
    # LM probability from the start state, then smear for pruning.
    self.trie = Trie(self.token_dict.index_size(), self.sil_idx)
    init_state = self.lm.start(start_with_nothing=False)
    for entry, spellings in lex.items():
        widx = words.get_index(entry)
        _, lm_score = self.lm.score(init_state, widx)
        for tokens in spellings:
            # max_reps should be 1; using 0 here to match DecoderTest bug
            ids = tkn_to_idx(tokens, self.token_dict, max_reps=0)
            self.trie.insert(ids, widx, lm_score)
    self.trie.smear(SmearingMode.MAX)

    self.opts = DecoderOptions(
        beam_size=2500,
        beam_threshold=100.0,
        lm_weight=lm_weight,
        word_score=2.0,
        unk_score=-math.inf,
        log_add=False,
        sil_weight=-1,
        criterion_type=CriterionType.CTC,
    )
def __init__(self, args, tgt_dict):
    """Build a beam-search decoder driven by a fairseq neural LM.

    args.kenlm_model is (despite the name) the path to a fairseq LM
    checkpoint.  With a lexicon, a trie-constrained LexiconDecoder is
    built; without one, a LexiconFreeDecoder is used.
    """
    super().__init__(args, tgt_dict)

    # BOS doubles as the silence index for this decoder.
    self.silence = tgt_dict.bos()
    # unit_lm: LM word ids coincide with lexicon enumeration order.
    self.unit_lm = getattr(args, "unit_lm", False)
    self.lexicon = load_words(args.lexicon) if args.lexicon else None
    # Maps enumerated word index -> word string (filled for unit LMs only).
    self.idx_to_wrd = {}

    checkpoint = torch.load(args.kenlm_model, map_location="cpu")
    # Newer checkpoints carry an omegaconf "cfg"; older ones a Namespace.
    if "cfg" in checkpoint and checkpoint["cfg"] is not None:
        lm_args = checkpoint["cfg"]
    else:
        lm_args = convert_namespace_to_omegaconf(checkpoint["args"])

    # Point the LM task's data path at the checkpoint's directory.
    with open_dict(lm_args.task):
        lm_args.task.data = osp.dirname(args.kenlm_model)

    task = tasks.setup_task(lm_args.task)
    model = task.build_model(lm_args.model)
    # strict=False: tolerate missing/extra keys in the stored state dict.
    model.load_state_dict(checkpoint["model"], strict=False)

    self.trie = Trie(self.vocab_size, self.silence)

    self.word_dict = task.dictionary
    self.unk_word = self.word_dict.unk()
    self.lm = FairseqLM(self.word_dict, model)

    self.decoder_opts = DecoderOptions(
        args.beam,
        int(getattr(args, "beam_size_token", len(tgt_dict))),
        args.beam_threshold,
        args.lm_weight,
        args.word_score,
        args.unk_weight,
        args.sil_weight,
        0,      # eos score
        False,  # log_add
        self.criterion_type,
    )

    if self.lexicon:
        start_state = self.lm.start(False)
        for i, (word, spellings) in enumerate(self.lexicon.items()):
            if self.unit_lm:
                # Unit LM: word id is the lexicon position, score 0.
                word_idx = i
                self.idx_to_wrd[i] = word
                score = 0
            else:
                word_idx = self.word_dict.index(word)
                _, score = self.lm.score(start_state, word_idx, no_cache=True)
            for spelling in spellings:
                spelling_idxs = [
                    tgt_dict.index(token) for token in spelling
                ]
                # Every spelling must be expressible by the acoustic
                # vocabulary -- <unk> in a spelling is a data error.
                assert (
                    tgt_dict.unk() not in spelling_idxs), f"{spelling} {spelling_idxs}"
                self.trie.insert(spelling_idxs, word_idx, score)
        self.trie.smear(SmearingMode.MAX)

        self.decoder = LexiconDecoder(
            self.decoder_opts,
            self.trie,
            self.lm,
            self.silence,
            self.blank,
            self.unk_word,
            [],  # no ASG transitions
            self.unit_lm,
        )
    else:
        # Deferred import: wav2letter is only required on this path.
        from wav2letter.decoder import LexiconFreeDecoder
        self.decoder = LexiconFreeDecoder(self.decoder_opts, self.lm, self.silence, self.blank, [])
word = sentence[i] # maxReps should be 1; using 0 here to match DecoderTest bug word_tensor = tkn2Idx([c for c in word], tokenDict, 0) node = trie.search(word_tensor) assert_near(node.maxScore, trie_score_target[i], 1e-5) beamSize = 2500 beamThreshold = 100.0 lmWeight = 2.0 wordScore = 2.0 unkScore = -math.inf logAdd = False silWeight = -1 criterionType = CriterionType.ASG opts = DecoderOptions(beamSize, beamThreshold, lmWeight, wordScore, unkScore, logAdd, silWeight, criterionType) decoder = WordLMDecoder(opts, trie, lm, sil_idx, -1, unk_idx, transitions) results = decoder.decode(emissions.ctypes.data, T, N) print(f"Decoding complete, obtained {len(results)} results") print("Showing top 5 results:") for i in range(5): prediction = [] for idx in results[i].tokens: if idx < 0: break prediction.append(tokenDict.getEntry(idx)) prediction = " ".join(prediction) print(f"score={results[i].score} prediction='{prediction}'")
# check that trie is built in consistency with c++ trie_score_target = [ -1.05971, -2.87742, -2.64553, -3.05081, -1.05971, -3.08968 ] for i in range(len(sentence)): word = sentence[i] # max_reps should be 1; using 0 here to match DecoderTest bug word_tensor = tkn_to_idx([c for c in word], token_dict, 0) node = trie.search(word_tensor) assert_near(node.max_score, trie_score_target[i], 1e-5) # Define decoder options: # DecoderOptions (beam_size, token_beam_size, beam_threshold, lm_weight, # word_score, unk_score, sil_score, # eos_score, log_add, criterion_type (ASG or CTC)) opts = DecoderOptions(2500, 25000, 100.0, 2.0, 2.0, -math.inf, -1, 0, False, CriterionType.ASG) # define lexicon beam-search decoder with word-level lm # LexiconDecoder(decoder options, trie, lm, silence index, # blank index (for CTC), unk index, # transitiona matrix, is token-level lm) decoder = LexiconDecoder(opts, trie, lm, sil_idx, -1, unk_idx, transitions, False) # run decoding # decoder.decode(emissions, Time, Ntokens) # result is a list of sorted hypothesis, 0-index is the best hypothesis # each hypothesis is a struct with "score" and "words" representation # in the hypothesis and the "tokens" representation results = decoder.decode(emissions.ctypes.data, T, N) print(f"Decoding complete, obtained {len(results)} results")
def __init__(self, args, tgt_dict):
    """Build a lexicon-constrained beam-search decoder over a fairseq
    neural LM checkpoint (loaded from args.kenlm_model despite the name).
    """
    super().__init__(args, tgt_dict)

    # NOTE(review): silence index is hard-coded to 1 rather than
    # tgt_dict.bos(); presumably this matches the experiment's token
    # dictionary layout -- confirm before reusing with another vocabulary.
    self.silence = 1
    # unit_lm: LM word ids coincide with lexicon enumeration order.
    self.unit_lm = getattr(args, "unit_lm", False)
    self.lexicon = load_words(args.lexicon) if args.lexicon else None
    # Maps enumerated word index -> word string (filled for unit LMs only).
    self.idx_to_wrd = {}

    # Load the fairseq LM checkpoint and rebuild its task/model, pointing
    # the task's data path at the checkpoint directory.
    checkpoint = torch.load(args.kenlm_model, map_location="cpu")
    lm_args = checkpoint["args"]
    lm_args.data = osp.dirname(args.kenlm_model)
    task = tasks.setup_task(lm_args)
    model = task.build_model(lm_args)
    # strict=False: tolerate missing/extra keys in the stored state dict.
    model.load_state_dict(checkpoint["model"], strict=False)

    self.trie = Trie(self.vocab_size, self.silence)
    self.word_dict = task.dictionary
    self.unk_word = self.word_dict.unk()
    self.lm = FairseqLM(self.word_dict, model)

    self.decoder_opts = DecoderOptions(
        args.beam,
        int(getattr(args, "beam_size_token", len(tgt_dict))),
        args.beam_threshold,
        args.lm_weight,
        args.word_score,
        args.unk_weight,
        args.sil_weight,
        0,      # eos score
        False,  # log_add
        self.criterion_type,
    )

    # NOTE(review): when args.lexicon is falsy, self.decoder is never
    # assigned -- callers must always supply a lexicon to this variant.
    if self.lexicon:
        start_state = self.lm.start(False)
        for i, (word, spellings) in enumerate(self.lexicon.items()):
            if self.unit_lm:
                # Unit LM: word id is the lexicon position, score 0.
                word_idx = i
                self.idx_to_wrd[i] = word
                score = 0
            else:
                word_idx = self.word_dict.index(word)
                _, score = self.lm.score(start_state, word_idx, no_cache=True)
            # Spellings (right side of the lexicon) must not map to unk.
            for spelling in spellings:
                spelling_idxs = [
                    tgt_dict.index(token) for token in spelling
                ]
                assert (
                    tgt_dict.unk() not in spelling_idxs
                ), f"{spelling} is unk: {spelling_idxs} for acoustic output"
                self.trie.insert(spelling_idxs, word_idx, score)
        self.trie.smear(SmearingMode.MAX)

        self.decoder = LexiconDecoder(
            self.decoder_opts,
            self.trie,
            self.lm,
            self.silence,
            self.blank,
            self.unk_word,
            [],  # no ASG transitions
            self.unit_lm,
        )