예제 #1
0
    def __init__(self, args, tgt_dict):
        """Build a lexicon-constrained beam-search decoder backed by a KenLM word LM.

        Args:
            args: decoding options; reads ``lexicon``, ``kenlm_model``, ``beam``,
                ``beam_size_token`` (optional, defaults to ``len(tgt_dict)``),
                ``beam_threshold``, ``lm_weight``, ``word_score``,
                ``unk_weight`` and ``sil_weight``.
            tgt_dict: acoustic-model target dictionary used to map spelling
                tokens to indices.
        """
        super().__init__(args, tgt_dict)

        # Use "<ctc_blank>" as the silence index when the dictionary defines it,
        # otherwise fall back to the BOS index.
        self.silence = (
            tgt_dict.index("<ctc_blank>")
            if "<ctc_blank>" in tgt_dict.indices
            else tgt_dict.bos()
        )
        self.lexicon = load_words(args.lexicon)
        self.word_dict = create_word_dict(self.lexicon)
        self.unk_word = self.word_dict.get_index("<unk>")

        self.lm = KenLM(args.kenlm_model, self.word_dict)
        self.trie = Trie(self.vocab_size, self.silence)

        # Insert every lexicon spelling into the trie, labelled with the word
        # index and scored with the LM score of the word from the start state.
        start_state = self.lm.start(False)
        for word, spellings in self.lexicon.items():
            word_idx = self.word_dict.get_index(word)
            _, score = self.lm.score(start_state, word_idx)
            for spelling in spellings:
                spelling_idxs = [tgt_dict.index(token) for token in spelling]
                # Every spelling token must exist in the acoustic vocabulary.
                assert (
                    tgt_dict.unk() not in spelling_idxs
                ), f"{spelling} {spelling_idxs}"
                self.trie.insert(spelling_idxs, word_idx, score)
        # Smear word scores over trie prefixes using MAX mode.
        self.trie.smear(SmearingMode.MAX)

        self.decoder_opts = DecoderOptions(
            args.beam,  # beam_size
            int(getattr(args, "beam_size_token", len(tgt_dict))),  # token_beam_size
            args.beam_threshold,  # beam_threshold
            args.lm_weight,  # lm_weight
            args.word_score,  # word_score
            args.unk_weight,  # unk_score
            args.sil_weight,  # sil_score
            0,  # eos_score
            False,  # log_add
            self.criterion_type,  # criterion_type
        )

        # No transition matrix configured: pass an empty list to the decoder.
        if self.asg_transitions is None:
            self.asg_transitions = []

        self.decoder = LexiconDecoder(
            self.decoder_opts,
            self.trie,
            self.lm,
            self.silence,  # silence index
            self.blank,  # blank index
            self.unk_word,  # unk word index
            self.asg_transitions,  # transition matrix
            False,  # word-level (not token-level) LM
        )
예제 #2
0
    def __init__(self, args, tgt_dict):
        """Set up a lexicon-constrained WordLMDecoder scored by a KenLM word LM.

        Args:
            args: decoding options; reads ``silence_token``, ``lexicon``,
                ``kenlm_model``, ``beam``, ``beam_threshold``, ``lm_weight``,
                ``word_score``, ``unk_weight`` and ``sil_weight``.
            tgt_dict: acoustic-model target dictionary used to map spelling
                tokens to indices.
        """
        super().__init__(args, tgt_dict)

        # The silence symbol is configurable through args.silence_token.
        self.silence = tgt_dict.index(args.silence_token)

        self.lexicon = load_words(args.lexicon)
        self.word_dict = create_word_dict(self.lexicon)
        self.unk_word = self.word_dict.get_index("<unk>")

        self.lm = KenLM(args.kenlm_model, self.word_dict)
        self.trie = Trie(self.vocab_size, self.silence)

        # Fill the trie: one entry per spelling, scored from the LM start state.
        initial_lm_state = self.lm.start(False)
        for lex_word, lex_spellings in self.lexicon.items():
            lex_word_idx = self.word_dict.get_index(lex_word)
            _next_state, start_score = self.lm.score(initial_lm_state, lex_word_idx)
            for lex_spelling in lex_spellings:
                token_ids = list(map(tgt_dict.index, lex_spelling))
                self.trie.insert(token_ids, lex_word_idx, start_score)
        self.trie.smear(SmearingMode.MAX)

        # Older DecoderOptions positional layout:
        # (beam_size, beam_threshold, lm_weight, word_score, unk_score,
        #  log_add, sil_weight, criterion_type)
        self.decoder_opts = DecoderOptions(
            args.beam,
            args.beam_threshold,
            args.lm_weight,
            args.word_score,
            args.unk_weight,
            False,
            args.sil_weight,
            self.criterion_type,
        )

        # WordLMDecoder(options, trie, lm, silence, blank, unk, transitions)
        self.decoder = WordLMDecoder(
            self.decoder_opts,
            self.trie,
            self.lm,
            self.silence,
            self.blank,
            self.unk_word,
            self.asg_transitions,
        )
예제 #3
0
    def __init__(self,
                 lm_weight=2.0,
                 lexicon_path="WER_data/lexicon.txt",
                 token_path="WER_data/letters.lst",
                 lm_path="WER_data/4-gram.bin"):
        """Load lexicon, tokens and a KenLM model, build the trie and CTC options.

        Args:
            lm_weight: language-model weight stored in the decoder options.
            lexicon_path: lexicon file (word -> spellings) read by ``load_words``.
            token_path: token list used to build the token ``Dictionary``.
            lm_path: KenLM model file.
        """
        word_to_spellings = load_words(lexicon_path)
        word_vocab = create_word_dict(word_to_spellings)

        self.token_dict = Dictionary(token_path)
        self.lm = KenLM(lm_path, word_vocab)

        # "|" serves as the silence index.
        self.sil_idx = self.token_dict.get_index("|")
        self.unk_idx = word_vocab.get_index("<unk>")
        # Append "#" to the token set and use it as the blank symbol.
        self.token_dict.add_entry("#")
        self.blank_idx = self.token_dict.get_index('#')

        # Trie is sized after "#" was added, so it covers the blank as well.
        self.trie = Trie(self.token_dict.index_size(), self.sil_idx)
        lm_start = self.lm.start(start_with_nothing=False)

        for lex_word, lex_spellings in word_to_spellings.items():
            lex_word_id = word_vocab.get_index(lex_word)
            _, start_score = self.lm.score(lm_start, lex_word_id)
            for lex_spelling in lex_spellings:
                # max_reps should be 1; using 0 here to match DecoderTest bug
                token_ids = tkn_to_idx(lex_spelling, self.token_dict, max_reps=0)
                self.trie.insert(token_ids, lex_word_id, start_score)

        self.trie.smear(SmearingMode.MAX)
        self.opts = DecoderOptions(
            beam_size=2500,
            beam_threshold=100.0,
            lm_weight=lm_weight,
            word_score=2.0,
            unk_score=-math.inf,
            log_add=False,
            sil_weight=-1,
            criterion_type=CriterionType.CTC,
        )
예제 #4
0
    def __init__(self, args, tgt_dict):
        """Build a wav2letter decoder that scores words with a fairseq neural LM.

        If ``args.lexicon`` is set, a lexicon-constrained ``LexiconDecoder`` is
        built over a spelling trie; otherwise a ``LexiconFreeDecoder`` is used.

        Args:
            args: decoding options; reads ``unit_lm`` (optional), ``lexicon``,
                ``kenlm_model`` (path of a checkpoint loaded with ``torch.load``),
                ``beam``, ``beam_size_token`` (optional), ``beam_threshold``,
                ``lm_weight``, ``word_score``, ``unk_weight`` and ``sil_weight``.
            tgt_dict: acoustic-model target dictionary used to map spelling
                tokens to indices.
        """
        super().__init__(args, tgt_dict)

        self.silence = tgt_dict.bos()
        self.unit_lm = getattr(args, "unit_lm", False)

        self.lexicon = None
        if args.lexicon:
            self.lexicon = load_words(args.lexicon)
        self.idx_to_wrd = {}

        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        ckpt = torch.load(args.kenlm_model, map_location="cpu")

        # Prefer a stored "cfg"; otherwise convert the legacy "args" namespace.
        stored_cfg = ckpt["cfg"] if "cfg" in ckpt else None
        if stored_cfg is not None:
            lm_cfg = stored_cfg
        else:
            lm_cfg = convert_namespace_to_omegaconf(ckpt["args"])

        with open_dict(lm_cfg.task):
            # Point the LM task at the checkpoint's directory (presumably where
            # its dictionary files live).
            lm_cfg.task.data = osp.dirname(args.kenlm_model)

        lm_task = tasks.setup_task(lm_cfg.task)
        lm_model = lm_task.build_model(lm_cfg.model)
        lm_model.load_state_dict(ckpt["model"], strict=False)

        self.trie = Trie(self.vocab_size, self.silence)

        self.word_dict = lm_task.dictionary
        self.unk_word = self.word_dict.unk()
        self.lm = FairseqLM(self.word_dict, lm_model)

        self.decoder_opts = DecoderOptions(
            args.beam,  # beam_size
            int(getattr(args, "beam_size_token", len(tgt_dict))),  # token_beam_size
            args.beam_threshold,  # beam_threshold
            args.lm_weight,  # lm_weight
            args.word_score,  # word_score
            args.unk_weight,  # unk_score
            args.sil_weight,  # sil_score
            0,  # eos_score
            False,  # log_add
            self.criterion_type,  # criterion_type
        )

        if not self.lexicon:
            # No lexicon: decode freely over the token vocabulary.
            from wav2letter.decoder import LexiconFreeDecoder
            self.decoder = LexiconFreeDecoder(self.decoder_opts, self.lm,
                                              self.silence, self.blank, [])
            return

        lm_start = self.lm.start(False)
        for position, (word, spellings) in enumerate(self.lexicon.items()):
            if not self.unit_lm:
                word_idx = self.word_dict.index(word)
                _, score = self.lm.score(lm_start, word_idx, no_cache=True)
            else:
                # Unit LM: words are identified by their lexicon position and
                # carry no LM prior in the trie.
                word_idx = position
                score = 0
                self.idx_to_wrd[position] = word

            for spelling in spellings:
                spelling_idxs = [tgt_dict.index(tok) for tok in spelling]
                assert tgt_dict.unk() not in spelling_idxs, f"{spelling} {spelling_idxs}"
                self.trie.insert(spelling_idxs, word_idx, score)
        self.trie.smear(SmearingMode.MAX)

        self.decoder = LexiconDecoder(
            self.decoder_opts,
            self.trie,
            self.lm,
            self.silence,  # silence index
            self.blank,  # blank index
            self.unk_word,  # unk word index
            [],  # transition matrix
            self.unit_lm,  # token-level LM flag
        )
예제 #5
0
    # Check the trie's max score for the i-th sentence word against its
    # reference target (tolerance 1e-5).
    word = sentence[i]
    # maxReps should be 1; using 0 here to match DecoderTest bug
    word_tensor = tkn2Idx([c for c in word], tokenDict, 0)
    node = trie.search(word_tensor)
    assert_near(node.maxScore, trie_score_target[i], 1e-5)

# Decoder configuration, in the positional order DecoderOptions expects below.
beamSize = 2500
beamThreshold = 100.0
lmWeight = 2.0
wordScore = 2.0
unkScore = -math.inf
logAdd = False
silWeight = -1
criterionType = CriterionType.ASG

opts = DecoderOptions(beamSize, beamThreshold, lmWeight, wordScore, unkScore,
                      logAdd, silWeight, criterionType)

# WordLMDecoder(options, trie, lm, silence, blank, unk, transitions).
# The blank index is -1 here; the criterion type above is ASG.
decoder = WordLMDecoder(opts, trie, lm, sil_idx, -1, unk_idx, transitions)
results = decoder.decode(emissions.ctypes.data, T, N)

print(f"Decoding complete, obtained {len(results)} results")
print("Showing top 5 results:")
# Bound by len(results): the decoder may return fewer than 5 hypotheses,
# and indexing past the end would raise IndexError.
for i in range(min(5, len(results))):
    result = results[i]
    prediction = []
    for idx in result.tokens:
        # A negative token index ends the token sequence.
        if idx < 0:
            break
        prediction.append(tokenDict.getEntry(idx))
    prediction = " ".join(prediction)
    print(f"score={result.score} prediction='{prediction}'")
예제 #6
0
    # Verify the trie matches the C++ implementation: each sentence word's
    # max score must equal its reference value.
    trie_score_target = [
        -1.05971, -2.87742, -2.64553, -3.05081, -1.05971, -3.08968
    ]
    for i, word in enumerate(sentence):
        # max_reps should be 1; using 0 here to match DecoderTest bug
        word_tensor = tkn_to_idx(list(word), token_dict, 0)
        node = trie.search(word_tensor)
        assert_near(node.max_score, trie_score_target[i], 1e-5)

    opts = DecoderOptions(
        2500,  # beam_size
        25000,  # token_beam_size
        100.0,  # beam_threshold
        2.0,  # lm_weight
        2.0,  # word_score
        -math.inf,  # unk_score
        -1,  # sil_score
        0,  # eos_score
        False,  # log_add
        CriterionType.ASG,  # criterion_type (ASG or CTC)
    )

    # Lexicon beam-search decoder with a word-level LM.
    decoder = LexiconDecoder(
        opts,
        trie,
        lm,
        sil_idx,  # silence index
        -1,  # blank index (only used for CTC)
        unk_idx,  # unk index
        transitions,  # transition matrix
        False,  # is token-level LM
    )

    # decode(emissions pointer, time steps, number of tokens) returns sorted
    # hypotheses (index 0 is the best); each exposes a score, words and tokens.
    results = decoder.decode(emissions.ctypes.data, T, N)

    print(f"Decoding complete, obtained {len(results)} results")
예제 #7
0
    def __init__(self, args, tgt_dict):
        """Build a wav2letter lexicon decoder that scores words with a fairseq LM.

        Args:
            args: decoding options; reads ``unit_lm`` (optional), ``lexicon``,
                ``kenlm_model`` (path of a checkpoint loaded with ``torch.load``),
                ``beam``, ``beam_size_token`` (optional), ``beam_threshold``,
                ``lm_weight``, ``word_score``, ``unk_weight`` and ``sil_weight``.
            tgt_dict: acoustic-model target dictionary used to map spelling
                tokens to indices.

        NOTE(review): only the lexicon branch builds ``self.decoder`` in this
        view; confirm a decoder is also set when no lexicon is given.
        """
        super().__init__(args, tgt_dict)

        # NOTE(review): silence is hard-coded to index 1 rather than
        # tgt_dict.bos(); confirm this matches the acoustic model's vocabulary.
        self.silence = 1

        self.unit_lm = getattr(args, "unit_lm", False)

        self.lexicon = load_words(args.lexicon) if args.lexicon else None
        self.idx_to_wrd = {}

        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        checkpoint = torch.load(args.kenlm_model, map_location="cpu")
        # The LM settings are read from checkpoint["args"] (presumably a legacy
        # argparse-style namespace) and reused for both task and model setup.
        lm_args = checkpoint["args"]
        # Point the LM task at the checkpoint's directory.
        lm_args.data = osp.dirname(args.kenlm_model)
        task = tasks.setup_task(lm_args)
        model = task.build_model(lm_args)
        model.load_state_dict(checkpoint["model"], strict=False)

        self.trie = Trie(self.vocab_size, self.silence)

        self.word_dict = task.dictionary
        self.unk_word = self.word_dict.unk()
        self.lm = FairseqLM(self.word_dict, model)

        self.decoder_opts = DecoderOptions(
            args.beam,  # beam_size
            int(getattr(args, "beam_size_token", len(tgt_dict))),  # token_beam_size
            args.beam_threshold,  # beam_threshold
            args.lm_weight,  # lm_weight
            args.word_score,  # word_score
            args.unk_weight,  # unk_score
            args.sil_weight,  # sil_score
            0,  # eos_score
            False,  # log_add
            self.criterion_type,  # criterion_type
        )

        if self.lexicon:
            start_state = self.lm.start(False)
            for i, (word, spellings) in enumerate(self.lexicon.items()):
                if self.unit_lm:
                    # Unit LM: words are identified by lexicon position and
                    # carry no LM prior in the trie.
                    word_idx = i
                    self.idx_to_wrd[i] = word
                    score = 0
                else:
                    word_idx = self.word_dict.index(word)
                    _, score = self.lm.score(start_state,
                                             word_idx,
                                             no_cache=True)

                # Spellings (right-hand side of the lexicon) must not contain
                # tokens that map to unk in the acoustic-model dictionary.
                for spelling in spellings:
                    spelling_idxs = [
                        tgt_dict.index(token) for token in spelling
                    ]
                    assert (
                        tgt_dict.unk() not in spelling_idxs
                    ), f"{spelling} is unk: {spelling_idxs} for acoustic output"
                    self.trie.insert(spelling_idxs, word_idx, score)
            self.trie.smear(SmearingMode.MAX)

            self.decoder = LexiconDecoder(
                self.decoder_opts,
                self.trie,
                self.lm,
                self.silence,  # silence index
                self.blank,  # blank index
                self.unk_word,  # unk word index
                [],  # transition matrix
                self.unit_lm,  # token-level LM flag
            )