class W2lKenLMDecoder(W2lDecoder):
    def __init__(self, args, tgt_dict):
        super().__init__(args, tgt_dict)

        self.silence = tgt_dict.index(args.silence_token)
        self.lexicon = load_words(args.lexicon)
        self.word_dict = create_word_dict(self.lexicon)
        self.unk_word = self.word_dict.get_index("<unk>")
        self.lm = KenLM(args.kenlm_model, self.word_dict)
        self.trie = Trie(self.vocab_size, self.silence)

        # Build the lexicon trie: insert every spelling of every word,
        # scored with the LM score of that word from the LM start state.
        start_state = self.lm.start(False)
        for word, spellings in self.lexicon.items():
            word_idx = self.word_dict.get_index(word)
            _, score = self.lm.score(start_state, word_idx)
            for spelling in spellings:
                spelling_idxs = [tgt_dict.index(token) for token in spelling]
                self.trie.insert(spelling_idxs, word_idx, score)
        self.trie.smear(SmearingMode.MAX)

        self.decoder_opts = DecoderOptions(
            args.beam,
            args.beam_threshold,
            args.lm_weight,
            args.word_score,
            args.unk_weight,
            False,
            args.sil_weight,
            self.criterion_type,
        )
        self.decoder = WordLMDecoder(
            self.decoder_opts,
            self.trie,
            self.lm,
            self.silence,
            self.blank,
            self.unk_word,
            self.asg_transitions,
        )

    def decode(self, emissions):
        B, T, N = emissions.size()
        hypos = []
        for b in range(B):
            # Hand the decoder a raw pointer into the float32 emissions
            # tensor; the factor of 4 is sizeof(float32) in bytes.
            emissions_ptr = emissions.data_ptr() + 4 * b * emissions.stride(0)
            nbest_results = self.decoder.decode(emissions_ptr, T, N)[: self.nbest]
            hypos.append(
                [
                    {"tokens": self.get_tokens(result.tokens), "score": result.score}
                    for result in nbest_results
                ]
            )
        return hypos
class W2lKenLMDecoder(W2lDecoder):
    def __init__(self, args, tgt_dict):
        super().__init__(args, tgt_dict)

        # Use <ctc_blank> as the silence token if the target dictionary has
        # one; otherwise fall back to the BOS index.
        self.silence = (
            tgt_dict.index("<ctc_blank>")
            if "<ctc_blank>" in tgt_dict.indices
            else tgt_dict.bos()
        )
        self.lexicon = load_words(args.lexicon)
        self.word_dict = create_word_dict(self.lexicon)
        self.unk_word = self.word_dict.get_index("<unk>")
        self.lm = KenLM(args.kenlm_model, self.word_dict)
        self.trie = Trie(self.vocab_size, self.silence)

        # Build the lexicon trie; every spelling must be expressible in the
        # target dictionary, hence the assert on the unk index.
        start_state = self.lm.start(False)
        for i, (word, spellings) in enumerate(self.lexicon.items()):
            word_idx = self.word_dict.get_index(word)
            _, score = self.lm.score(start_state, word_idx)
            for spelling in spellings:
                spelling_idxs = [tgt_dict.index(token) for token in spelling]
                assert (
                    tgt_dict.unk() not in spelling_idxs
                ), f"{spelling} {spelling_idxs}"
                self.trie.insert(spelling_idxs, word_idx, score)
        self.trie.smear(SmearingMode.MAX)

        self.decoder_opts = DecoderOptions(
            args.beam,
            int(getattr(args, "beam_size_token", len(tgt_dict))),
            args.beam_threshold,
            args.lm_weight,
            args.word_score,
            args.unk_weight,
            args.sil_weight,
            0,
            False,
            self.criterion_type,
        )

        if self.asg_transitions is None:
            N = 768
            # self.asg_transitions = torch.FloatTensor(N, N).zero_()
            self.asg_transitions = []

        self.decoder = LexiconDecoder(
            self.decoder_opts,
            self.trie,
            self.lm,
            self.silence,
            self.blank,
            self.unk_word,
            self.asg_transitions,
            False,
        )

    def decode(self, emissions):
        B, T, N = emissions.size()
        hypos = []
        for b in range(B):
            # Raw pointer into the float32 emissions tensor (4 bytes/element).
            emissions_ptr = emissions.data_ptr() + 4 * b * emissions.stride(0)
            results = self.decoder.decode(emissions_ptr, T, N)
            nbest_results = results[: self.nbest]
            hypos.append(
                [
                    {
                        "tokens": self.get_tokens(result.tokens),
                        "score": result.score,
                        "words": [
                            self.word_dict.get_entry(x)
                            for x in result.words
                            if x >= 0
                        ],
                    }
                    for result in nbest_results
                ]
            )
        return hypos
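# A minimal usage sketch (hypothetical values; the real option set comes from
# the fairseq CLI): `args` must provide every attribute the constructor reads,
# and `emissions` must be a contiguous float32 CPU tensor of shape (B, T, N)
# holding per-frame log-probabilities, since decode() passes a raw data_ptr().
#
# from argparse import Namespace
#
# args = Namespace(
#     criterion="ctc", nbest=1, beam=50, beam_threshold=25.0,
#     lm_weight=2.0, word_score=1.0, unk_weight=float("-inf"), sil_weight=0.0,
#     kenlm_model="4-gram.bin", lexicon="lexicon.txt",
# )
# decoder = W2lKenLMDecoder(args, tgt_dict)
# hypos = decoder.decode(emissions)
# print(" ".join(hypos[0][0]["words"]))  # best hypothesis, first utterance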
class WlDecoder:
    """
    Wav2Letter-based decoder. Follows the official examples for the python bindings,
    see https://github.com/facebookresearch/wav2letter/blob/master/bindings/python/examples/decoder_example.py
    """

    def __init__(
        self,
        lm_weight=2.0,
        lexicon_path="WER_data/lexicon.txt",
        token_path="WER_data/letters.lst",
        lm_path="WER_data/4-gram.bin",
    ):
        lexicon = load_words(lexicon_path)
        word_dict = create_word_dict(lexicon)
        self.token_dict = Dictionary(token_path)
        self.lm = KenLM(lm_path, word_dict)
        self.sil_idx = self.token_dict.get_index("|")
        self.unk_idx = word_dict.get_index("<unk>")
        self.token_dict.add_entry("#")
        self.blank_idx = self.token_dict.get_index("#")
        self.trie = Trie(self.token_dict.index_size(), self.sil_idx)

        start_state = self.lm.start(start_with_nothing=False)
        for word, spellings in lexicon.items():
            usr_idx = word_dict.get_index(word)
            _, score = self.lm.score(start_state, usr_idx)
            for spelling in spellings:
                # max_reps should be 1; using 0 here to match DecoderTest bug
                spelling_idxs = tkn_to_idx(spelling, self.token_dict, max_reps=0)
                self.trie.insert(spelling_idxs, usr_idx, score)
        self.trie.smear(SmearingMode.MAX)

        self.opts = DecoderOptions(
            beam_size=2500,
            beam_threshold=100.0,
            lm_weight=lm_weight,
            word_score=2.0,
            unk_score=-math.inf,
            log_add=False,
            sil_weight=-1,
            criterion_type=CriterionType.CTC,
        )

    def collapse(self, prediction):
        # CTC collapse: merge consecutive repeats, drop "#" blanks, and map
        # the "|" word separator to a space.
        result = []
        for p in prediction:
            if result and p == result[-1]:
                continue
            result.append(p)
        blank = "#"
        space = "|"
        return [(x if x != space else " ") for x in result if x != blank]

    def predictions(self, emissions):
        t, n = emissions.size()
        emissions = emissions.cpu().numpy()
        decoder = WordLMDecoder(
            self.opts, self.trie, self.lm, self.sil_idx, self.blank_idx, self.unk_idx, []
        )
        results = decoder.decode(emissions.ctypes.data, t, n)
        prediction = [
            self.token_dict.get_entry(x) for x in results[0].tokens if x >= 0
        ]
        return self.collapse(prediction)
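# A tiny illustration of WlDecoder.collapse (made-up input): consecutive
# repeats are merged first, then "#" blanks are dropped and "|" becomes a
# space, so a raw CTC token sequence turns into readable text.
#
# >>> wl = WlDecoder.__new__(WlDecoder)  # bypass __init__; collapse keeps no state
# >>> "".join(wl.collapse(["t", "t", "#", "h", "e", "|", "#", "c", "a", "t"]))
# 'the cat'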
def assert_near(x, y, tol):
    assert abs(x - y) <= tol


# load test files
T, N = load_TN(os.path.join(testing_data_path, "TN.bin"))
emissions = load_emissions(os.path.join(testing_data_path, "emission.bin"))
transitions = load_transitions(os.path.join(testing_data_path, "transition.bin"))
lexicon = load_words(os.path.join(testing_data_path, "words.lst"))
word_dict = create_word_dict(lexicon)
token_dict = Dictionary(os.path.join(testing_data_path, "letters.lst"))
token_dict.add_entry("1")
lm = KenLM(os.path.join(testing_data_path, "lm.arpa"), word_dict)

# test LM
# sentence = ["the", "cat", "sat", "on", "the", "mat"]
# lm_state = lm.start(False)
# total_score = 0
# lm_score_target = [-1.05971, -4.19448, -3.33383, -2.76726, -1.16237, -4.64589]
# for i in range(len(sentence)):
#     lm_state, lm_score = lm.score(lm_state, word_dict.get_index(sentence[i]))
#     assert_near(lm_score, lm_score_target[i], 1e-5)
#     total_score += lm_score
# lm_state, lm_score = lm.finish(lm_state)
# total_score += lm_score
# assert_near(total_score, -19.5123, 1e-5)
emissions = load_emissions(os.path.join(data_path, "emission.bin"))

# load transitions (from ASG loss optimization) [Ntokens, Ntokens]
transitions = load_transitions(os.path.join(data_path, "transition.bin"))

# load the lexicon file, which defines the spelling of words;
# the format is a word followed by its token spelling, separated by spaces,
# e.g. for letter tokens with ASG loss:
# ann a n 1 |
lexicon = load_words(os.path.join(data_path, "words.lst"))

# read the lexicon and store it in the w2l dictionary
word_dict = create_word_dict(lexicon)

# create a w2l dict with the token set (letters in this example)
token_dict = Dictionary(os.path.join(data_path, "letters.lst"))
# add the repetition symbol, since we have an ASG acoustic model
token_dict.add_entry("1")

# create the KenLM language model
lm = KenLM(os.path.join(data_path, "lm.arpa"), word_dict)

# test LM
sentence = ["the", "cat", "sat", "on", "the", "mat"]
# start the LM with nothing, get its current state
lm_state = lm.start(False)
total_score = 0
lm_score_target = [-1.05971, -4.19448, -3.33383, -2.76726, -1.16237, -4.64589]
# iterate over the words in the sentence
for i in range(len(sentence)):
    # score the LM, taking the current state and the index of the word;
    # returns the new state and the score for the word
    lm_state, lm_score = lm.score(lm_state, word_dict.get_index(sentence[i]))
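    # (completion, mirroring the commented-out LM test in the snippet above)
    # check each word's score against the reference and accumulate the total
    assert_near(lm_score, lm_score_target[i], 1e-5)
    total_score += lm_score

# finish the LM state at the end of the sentence and check the total score
lm_state, lm_score = lm.finish(lm_state)
total_score += lm_score
assert_near(total_score, -19.5123, 1e-5)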