# Example #1
class W2lKenLMDecoder(W2lDecoder):
    """Lexicon-constrained beam-search decoder backed by a KenLM word-level
    language model (wav2letter python bindings)."""

    def __init__(self, args, tgt_dict):
        super().__init__(args, tgt_dict)

        # Index of the silence token in the target (acoustic) dictionary.
        self.silence = tgt_dict.index(args.silence_token)

        # lexicon maps word -> list of token spellings; word_dict indexes words.
        self.lexicon = load_words(args.lexicon)
        self.word_dict = create_word_dict(self.lexicon)
        self.unk_word = self.word_dict.get_index("<unk>")

        self.lm = KenLM(args.kenlm_model, self.word_dict)
        self.trie = Trie(self.vocab_size, self.silence)

        # Score every lexicon word from the LM start state and insert each of
        # its spellings into the trie that constrains the beam search.
        start_state = self.lm.start(False)
        for word, spellings in self.lexicon.items():
            word_idx = self.word_dict.get_index(word)
            _, score = self.lm.score(start_state, word_idx)
            for spelling in spellings:
                spelling_idxs = [tgt_dict.index(token) for token in spelling]
                self.trie.insert(spelling_idxs, word_idx, score)
        # Propagate the best child score up each trie node for tighter pruning.
        self.trie.smear(SmearingMode.MAX)

        self.decoder_opts = DecoderOptions(
            args.beam,
            args.beam_threshold,
            args.lm_weight,
            args.word_score,
            args.unk_weight,
            False,  # log_add
            args.sil_weight,
            self.criterion_type,
        )

        self.decoder = WordLMDecoder(
            self.decoder_opts,
            self.trie,
            self.lm,
            self.silence,
            self.blank,
            self.unk_word,
            self.asg_transitions,
        )

    def decode(self, emissions):
        """Decode a batch of emissions (B x T x N) into n-best hypotheses.

        Returns a list of B lists, each holding up to ``self.nbest`` dicts
        with the decoded token ids and the hypothesis score.
        """
        B, T, N = emissions.size()
        hypos = []
        for b in range(B):
            # The C++ decoder reads raw memory; compute the byte offset of
            # batch element b.  Use element_size() rather than a hard-coded 4
            # so non-float32 emissions (half/double) are addressed correctly.
            emissions_ptr = (emissions.data_ptr() +
                             emissions.element_size() * b * emissions.stride(0))
            nbest_results = self.decoder.decode(emissions_ptr, T,
                                                N)[:self.nbest]
            hypos.append([{
                "tokens": self.get_tokens(result.tokens),
                "score": result.score
            } for result in nbest_results])
        return hypos
# Example #2
class W2lKenLMDecoder(W2lDecoder):
    """Lexicon-constrained beam-search decoder using a KenLM word-level
    language model (flashlight/wav2letter python bindings, LexiconDecoder API)."""

    def __init__(self, args, tgt_dict):
        super().__init__(args, tgt_dict)

        # Prefer a dedicated CTC blank token if present; otherwise fall back
        # to BOS as the silence index.
        self.silence = (tgt_dict.index("<ctc_blank>") if "<ctc_blank>"
                        in tgt_dict.indices else tgt_dict.bos())
        # lexicon maps word -> list of token spellings; word_dict indexes words.
        self.lexicon = load_words(args.lexicon)
        self.word_dict = create_word_dict(self.lexicon)
        self.unk_word = self.word_dict.get_index("<unk>")

        self.lm = KenLM(args.kenlm_model, self.word_dict)
        self.trie = Trie(self.vocab_size, self.silence)

        # Score every lexicon word from the LM start state and insert each of
        # its spellings into the trie that constrains the beam search.
        start_state = self.lm.start(False)
        for i, (word, spellings) in enumerate(self.lexicon.items()):
            word_idx = self.word_dict.get_index(word)
            _, score = self.lm.score(start_state, word_idx)
            for spelling in spellings:
                spelling_idxs = [tgt_dict.index(token) for token in spelling]
                # A spelling containing <unk> means the lexicon and the target
                # dictionary disagree — fail loudly with context.
                assert (tgt_dict.unk()
                        not in spelling_idxs), f"{spelling} {spelling_idxs}"
                self.trie.insert(spelling_idxs, word_idx, score)
        # Propagate the best child score up each trie node for tighter pruning.
        self.trie.smear(SmearingMode.MAX)

        self.decoder_opts = DecoderOptions(
            args.beam,
            int(getattr(args, "beam_size_token", len(tgt_dict))),
            args.beam_threshold,
            args.lm_weight,
            args.word_score,
            args.unk_weight,
            args.sil_weight,
            0,      # transitions weight (unused for CTC)
            False,  # log_add
            self.criterion_type,
        )

        if self.asg_transitions is None:
            # No ASG transition matrix for CTC models; pass an empty list.
            self.asg_transitions = []

        self.decoder = LexiconDecoder(
            self.decoder_opts,
            self.trie,
            self.lm,
            self.silence,
            self.blank,
            self.unk_word,
            self.asg_transitions,
            False,  # is_token_lm
        )

    def decode(self, emissions):
        """Decode a batch of emissions (B x T x N) into n-best hypotheses.

        Returns a list of B lists, each holding up to ``self.nbest`` dicts
        with decoded token ids, the hypothesis score, and the decoded words.
        """
        B, T, N = emissions.size()
        hypos = []
        for b in range(B):
            # The C++ decoder reads raw memory; compute the byte offset of
            # batch element b.  Use element_size() rather than a hard-coded 4
            # so non-float32 emissions (half/double) are addressed correctly.
            emissions_ptr = (emissions.data_ptr() +
                             emissions.element_size() * b * emissions.stride(0))
            results = self.decoder.decode(emissions_ptr, T, N)

            nbest_results = results[:self.nbest]
            hypos.append([{
                "tokens":
                self.get_tokens(result.tokens),
                "score":
                result.score,
                # Negative word indices are padding/sentinels; skip them.
                "words":
                [self.word_dict.get_entry(x) for x in result.words if x >= 0],
            } for result in nbest_results])
        return hypos
#    assert_near(lm_score, lm_score_target[i], 1e-5)
#    total_score += lm_score
#lm_state, lm_score = lm.finish(lm_state)
#total_score += lm_score
#assert_near(total_score, -19.5123, 1e-5)

# build trie
# Construct the lexicon trie used to constrain beam-search decoding, then
# probe it with the words of a test sentence.  NOTE(review): this is a script
# fragment — tokenDict, wordDict, lexicon, lm, tkn2Idx, Trie and SmearingMode
# are defined elsewhere, and the final loop appears truncated (the search
# result `node` is never checked against trie_score_target here).
sentence = ["the", "cat", "sat", "on", "the", "mat"]
sil_idx = tokenDict.getIndex("|")
unk_idx = wordDict.getIndex("<unk>")
trie = Trie(tokenDict.indexSize(), sil_idx)
start_state = lm.start(False)

# Score each lexicon word from the LM start state and insert every spelling
# of the word into the trie.
for word, spellings in lexicon.items():
    usr_idx = wordDict.getIndex(word)
    _, score = lm.score(start_state, usr_idx)
    for spelling in spellings:
        # maxReps should be 1; using 0 here to match DecoderTest bug
        spelling_idxs = tkn2Idx(spelling, tokenDict, 0)
        trie.insert(spelling_idxs, usr_idx, score)

# Propagate the best child score up each trie node (MAX smearing).
trie.smear(SmearingMode.MAX)

# Expected per-word trie scores for the test sentence.
trie_score_target = [
    -1.05971, -2.87742, -2.64553, -3.05081, -1.05971, -3.08968
]
for i in range(len(sentence)):
    word = sentence[i]
    # maxReps should be 1; using 0 here to match DecoderTest bug
    word_tensor = tkn2Idx([c for c in word], tokenDict, 0)
    node = trie.search(word_tensor)
class WlDecoder:
    """
    Wav2Letter-based decoder. Follows the official examples for the python
    bindings, see
    https://github.com/facebookresearch/wav2letter/blob/master/bindings/python/examples/decoder_example.py
    """

    def __init__(self,
                 lm_weight=2.0,
                 lexicon_path="WER_data/lexicon.txt",
                 token_path="WER_data/letters.lst",
                 lm_path="WER_data/4-gram.bin"):
        # lexicon maps word -> list of token spellings; word_dict indexes words.
        lexicon = load_words(lexicon_path)
        word_dict = create_word_dict(lexicon)

        self.token_dict = Dictionary(token_path)
        self.lm = KenLM(lm_path, word_dict)

        self.sil_idx = self.token_dict.get_index("|")
        self.unk_idx = word_dict.get_index("<unk>")
        # Register '#' as the CTC blank token in the acoustic dictionary.
        self.token_dict.add_entry("#")
        self.blank_idx = self.token_dict.get_index('#')

        self.trie = Trie(self.token_dict.index_size(), self.sil_idx)
        start_state = self.lm.start(start_with_nothing=False)

        # Score every lexicon word from the LM start state and insert each of
        # its spellings into the trie that constrains the beam search.
        for word, spellings in lexicon.items():
            usr_idx = word_dict.get_index(word)
            _, score = self.lm.score(start_state, usr_idx)
            for spelling in spellings:
                # max_reps should be 1; using 0 here to match DecoderTest bug
                spelling_idxs = tkn_to_idx(spelling,
                                           self.token_dict,
                                           max_reps=0)
                self.trie.insert(spelling_idxs, usr_idx, score)

        # Propagate the best child score up each trie node for tighter pruning.
        self.trie.smear(SmearingMode.MAX)
        self.opts = DecoderOptions(beam_size=2500,
                                   beam_threshold=100.0,
                                   lm_weight=lm_weight,
                                   word_score=2.0,
                                   unk_score=-math.inf,
                                   log_add=False,
                                   sil_weight=-1,
                                   criterion_type=CriterionType.CTC)

    def collapse(self, prediction):
        """Collapse a raw CTC token sequence into readable characters.

        Merges consecutive duplicate tokens, drops blank ('#') tokens, and
        maps the word separator ('|') to a space.
        """
        blank = '#'
        space = '|'

        deduped = []
        for p in prediction:
            if deduped and p == deduped[-1]:
                continue
            deduped.append(p)

        # Drop blanks and map the separator in a single pass (the original
        # filtered blanks twice; once is sufficient).
        return [(' ' if x == space else x) for x in deduped if x != blank]

    def predictions(self, emissions):
        """Decode a single utterance's emissions (T x N) into characters.

        Returns the collapsed list of characters for the best hypothesis.
        NOTE(review): the decoder object is rebuilt on every call; it could be
        constructed once in __init__ if profiling shows this matters.
        """
        t, n = emissions.size()

        emissions = emissions.cpu().numpy()
        decoder = WordLMDecoder(self.opts, self.trie, self.lm, self.sil_idx,
                                self.blank_idx, self.unk_idx, [])
        results = decoder.decode(emissions.ctypes.data, t, n)

        # Negative token indices are padding/sentinels; skip them.
        prediction = [
            self.token_dict.get_entry(x) for x in results[0].tokens if x >= 0
        ]
        return self.collapse(prediction)
# Example #5
    # create Kenlm language model
    # NOTE(review): fragment of a larger test function — data_path, word_dict,
    # KenLM and assert_near are defined outside this view, and the fragment
    # ends mid-section ("# get silence index").
    lm = KenLM(os.path.join(data_path, "lm.arpa"), word_dict)

    # test LM
    sentence = ["the", "cat", "sat", "on", "the", "mat"]
    # start LM with nothing, get its current state
    lm_state = lm.start(False)
    total_score = 0
    # expected per-word LM scores for the test sentence
    lm_score_target = [
        -1.05971, -4.19448, -3.33383, -2.76726, -1.16237, -4.64589
    ]
    # iterate over words in the sentence
    for i in range(len(sentence)):
        # score lm, taking current state and index of the word
        # returns new state and score for the word
        lm_state, lm_score = lm.score(lm_state,
                                      word_dict.get_index(sentence[i]))
        assert_near(lm_score, lm_score_target[i], 1e-5)
        # add score of the current word to the total sentence score
        total_score += lm_score
    # move lm to the final state, the score returned is for eos
    lm_state, lm_score = lm.finish(lm_state)
    total_score += lm_score
    assert_near(total_score, -19.5123, 1e-5)

    # build trie
    # Trie is necessary to do beam-search decoding with word-level lm
    # We restrict our search only to the words from the lexicon
    # Trie is constructed from the lexicon, each node is a token
    # path from the root to a leaf corresponds to a word spelling in the lexicon

    # get silence index