Esempio n. 1
0
    #                 eos_score, log_add, criterion_type (ASG or CTC))
    opts = DecoderOptions(2500, 25000, 100.0, 2.0, 2.0, -math.inf, -1, 0,
                          False, CriterionType.ASG)

    # define lexicon beam-search decoder with word-level lm
    # LexiconDecoder(decoder options, trie, lm, silence index,
    #                blank index (for CTC), unk index,
    #                transitiona matrix, is token-level lm)
    decoder = LexiconDecoder(opts, trie, lm, sil_idx, -1, unk_idx, transitions,
                             False)
    # run decoding
    # decoder.decode(emissions, Time, Ntokens)
    # result is a list of sorted hypothesis, 0-index is the best hypothesis
    # each hypothesis is a struct with "score" and "words" representation
    # in the hypothesis and the "tokens" representation
    results = decoder.decode(emissions.ctypes.data, T, N)

    print(f"Decoding complete, obtained {len(results)} results")
    print("Showing top 5 results:")
    for i in range(min(5, len(results))):
        prediction = []
        for idx in results[i].tokens:
            if idx < 0:
                break
            prediction.append(token_dict.get_entry(idx))
        prediction = " ".join(prediction)
        print(f"score={results[i].score} prediction='{prediction}'")

    assert len(results) == 16
    hyp_score_target = [-284.0998, -284.108, -284.119, -284.127, -284.296]
    for i in range(min(5, len(results))):
Esempio n. 2
0
class W2lKenLMDecoder(W2lDecoder):
    """Lexicon-constrained beam-search decoder scored by a KenLM word LM.

    Every lexicon spelling is inserted into a token trie. Each spelling is
    weighted by the LM score of its word from the LM start state. Decoding
    is delegated to ``LexiconDecoder``.
    """

    def __init__(self, args, tgt_dict):
        """Build the lexicon trie, the KenLM wrapper and the decoder.

        Args:
            args: options namespace. Reads ``lexicon``, ``kenlm_model``,
                ``beam``, ``beam_threshold``, ``lm_weight``,
                ``word_score``, ``unk_weight``, ``sil_weight`` and,
                optionally, ``beam_size_token`` (defaults to
                ``len(tgt_dict)``).
            tgt_dict: acoustic token dictionary used to map lexicon
                spellings to token indices.

        Raises:
            AssertionError: if a lexicon spelling contains a token that
                ``tgt_dict`` maps to its unknown index.
        """
        super().__init__(args, tgt_dict)

        # Use the "<ctc_blank>" symbol as silence when the dictionary has
        # one; otherwise fall back to the BOS index.
        self.silence = (tgt_dict.index("<ctc_blank>") if "<ctc_blank>"
                        in tgt_dict.indices else tgt_dict.bos())
        self.lexicon = load_words(args.lexicon)
        self.word_dict = create_word_dict(self.lexicon)
        self.unk_word = self.word_dict.get_index("<unk>")

        self.lm = KenLM(args.kenlm_model, self.word_dict)
        self.trie = Trie(self.vocab_size, self.silence)

        # Weight every spelling with the LM score of its word, scored from
        # the LM start state, then smear the scores through the trie.
        start_state = self.lm.start(False)
        for word, spellings in self.lexicon.items():
            word_idx = self.word_dict.get_index(word)
            _, score = self.lm.score(start_state, word_idx)
            for spelling in spellings:
                spelling_idxs = [tgt_dict.index(token) for token in spelling]
                assert (tgt_dict.unk()
                        not in spelling_idxs), f"{spelling} {spelling_idxs}"
                self.trie.insert(spelling_idxs, word_idx, score)
        self.trie.smear(SmearingMode.MAX)

        self.decoder_opts = DecoderOptions(
            args.beam,
            int(getattr(args, "beam_size_token", len(tgt_dict))),
            args.beam_threshold,
            args.lm_weight,
            args.word_score,
            args.unk_weight,
            args.sil_weight,
            0,  # eos_score
            False,  # log_add
            self.criterion_type,
        )

        # The decoder takes a transition list; pass an empty one when the
        # base class provided none.
        if self.asg_transitions is None:
            self.asg_transitions = []

        self.decoder = LexiconDecoder(
            self.decoder_opts,
            self.trie,
            self.lm,
            self.silence,
            self.blank,
            self.unk_word,
            self.asg_transitions,
            False,  # word-level (not token-level) LM
        )

    def decode(self, emissions):
        """Beam-search decode a batch of emissions.

        Args:
            emissions: 3-D tensor unpacked as ``(B, T, N)``: batch,
                time steps and tokens. It is converted to a contiguous
                float32 CPU tensor before decoding.

        Returns:
            One list per batch element. Each list holds at most
            ``self.nbest`` hypothesis dicts with keys ``"tokens"``,
            ``"score"`` and ``"words"``. Negative word indices are
            dropped from ``"words"``.
        """
        # The native decoder receives only a raw address plus (T, N), so it
        # presumably reads T*N contiguous float32 values from host memory
        # (TODO confirm against the binding). These conversions are no-ops
        # when the input already has that layout. The rebinding keeps any
        # converted copy alive while its pointer is in use.
        emissions = emissions.float().cpu().contiguous()
        B, T, N = emissions.size()
        base_ptr = emissions.data_ptr()
        utterance_bytes = emissions.stride(0) * emissions.element_size()

        hypos = []
        for b in range(B):
            results = self.decoder.decode(base_ptr + b * utterance_bytes, T, N)

            nbest_results = results[:self.nbest]
            hypos.append([{
                "tokens": self.get_tokens(result.tokens),
                "score": result.score,
                "words": [
                    self.word_dict.get_entry(x) for x in result.words
                    if x >= 0
                ],
            } for result in nbest_results])
        return hypos
Esempio n. 3
0
class W2lFairseqLMDecoder(W2lDecoder):
    """Lexicon-constrained beam-search decoder scored by a fairseq neural LM.

    The LM checkpoint path is read from ``args.kenlm_model``. When
    ``args.unit_lm`` is set, words are identified by their position in the
    lexicon and the LM is passed to the decoder as token-level. Otherwise
    words are looked up in the LM task's dictionary.
    """

    def __init__(self, args, tgt_dict):
        """Load the LM checkpoint, build the lexicon trie and the decoder.

        Args:
            args: options namespace. Reads ``lexicon``, ``kenlm_model``
                (path to a fairseq LM checkpoint), ``beam``,
                ``beam_threshold``, ``lm_weight``, ``word_score``,
                ``unk_weight``, ``sil_weight`` and, optionally,
                ``unit_lm`` (default False) and ``beam_size_token``
                (default ``len(tgt_dict)``).
            tgt_dict: acoustic token dictionary used to map lexicon
                spellings to token indices.

        Raises:
            ValueError: if ``args.lexicon`` is unset or loads no entries.
                Lexicon-free decoding is not implemented by this class.
            AssertionError: if a lexicon spelling contains a token that
                ``tgt_dict`` maps to its unknown index.
        """
        super().__init__(args, tgt_dict)

        # NOTE(review): silence is hard-coded to index 1 instead of being
        # derived from tgt_dict (e.g. tgt_dict.bos()); confirm this matches
        # the acoustic dictionary in use.
        self.silence = 1

        self.unit_lm = getattr(args, "unit_lm", False)

        self.lexicon = load_words(args.lexicon) if args.lexicon else None
        if not self.lexicon:
            # Only a LexiconDecoder is built below. Without a lexicon there
            # would be no self.decoder, and decode() would fail later with
            # an opaque AttributeError, so fail fast here.
            raise ValueError(
                "W2lFairseqLMDecoder requires a non-empty lexicon "
                "(args.lexicon); lexicon-free decoding is not supported"
            )
        self.idx_to_wrd = {}

        # NOTE: torch.load unpickles the checkpoint; only load trusted files.
        checkpoint = torch.load(args.kenlm_model, map_location="cpu")
        lm_args = checkpoint["args"]
        # Point the LM task's data directory at the checkpoint's directory.
        lm_args.data = osp.dirname(args.kenlm_model)
        print(lm_args)
        task = tasks.setup_task(lm_args)
        model = task.build_model(lm_args)
        # strict=False: missing/unexpected checkpoint keys are ignored
        # instead of raising.
        model.load_state_dict(checkpoint["model"], strict=False)

        self.trie = Trie(self.vocab_size, self.silence)

        self.word_dict = task.dictionary
        self.unk_word = self.word_dict.unk()
        self.lm = FairseqLM(self.word_dict, model)

        self.decoder_opts = DecoderOptions(
            args.beam,
            int(getattr(args, "beam_size_token", len(tgt_dict))),
            args.beam_threshold,
            args.lm_weight,
            args.word_score,
            args.unk_weight,
            args.sil_weight,
            0,  # eos_score
            False,  # log_add
            self.criterion_type,
        )

        start_state = self.lm.start(False)
        for i, (word, spellings) in enumerate(self.lexicon.items()):
            if self.unit_lm:
                # Word ids are lexicon positions, mapped back to text via
                # idx_to_wrd; no LM score is attached to the trie entry.
                word_idx = i
                self.idx_to_wrd[i] = word
                score = 0
            else:
                word_idx = self.word_dict.index(word)
                _, score = self.lm.score(start_state,
                                         word_idx,
                                         no_cache=True)

            # Spellings (right-hand side of the lexicon) must not be unk.
            for spelling in spellings:
                spelling_idxs = [
                    tgt_dict.index(token) for token in spelling
                ]
                assert (
                    tgt_dict.unk() not in spelling_idxs
                ), f"{spelling} is unk: {spelling_idxs} for acoustic output"
                self.trie.insert(spelling_idxs, word_idx, score)
        self.trie.smear(SmearingMode.MAX)

        self.decoder = LexiconDecoder(
            self.decoder_opts,
            self.trie,
            self.lm,
            self.silence,
            self.blank,
            self.unk_word,
            [],  # no transition matrix
            self.unit_lm,  # token-level LM flag
        )

    def decode(self, emissions):
        """Beam-search decode a batch of emissions.

        Args:
            emissions: 3-D tensor unpacked as ``(B, T, N)``: batch,
                time steps and tokens. It is converted to a contiguous
                float32 CPU tensor before decoding.

        Returns:
            One list per batch element. Each list holds at most
            ``self.nbest`` hypothesis dicts with keys ``"tokens"``,
            ``"score"`` and ``"words"``. Negative word indices are
            dropped from ``"words"``.
        """
        # The native decoder receives only a raw address plus (T, N), so it
        # presumably reads T*N contiguous float32 values from host memory
        # (TODO confirm against the binding). These conversions are no-ops
        # when the input already has that layout. The rebinding keeps any
        # converted copy alive while its pointer is in use.
        emissions = emissions.float().cpu().contiguous()
        B, T, N = emissions.size()
        base_ptr = emissions.data_ptr()
        utterance_bytes = emissions.stride(0) * emissions.element_size()
        hypos = []

        def idx_to_word(idx):
            # unit_lm word ids are lexicon positions, not LM-dictionary ids.
            if self.unit_lm:
                return self.idx_to_wrd[idx]
            return self.word_dict[idx]

        def make_hypo(result):
            return {
                "tokens": self.get_tokens(result.tokens),
                "score": result.score,
                "words": [idx_to_word(x) for x in result.words if x >= 0],
            }

        for b in range(B):
            results = self.decoder.decode(base_ptr + b * utterance_bytes, T, N)

            nbest_results = results[:self.nbest]
            hypos.append([make_hypo(result) for result in nbest_results])
            # Clear the LM wrapper's cache before the next utterance.
            self.lm.empty_cache()

        return hypos