示例#1
0
    def __init__(self, cfg: DecoderConfig, tgt_dict: Dictionary) -> None:
        """Store common decoder settings and resolve criterion-specific symbols.

        Args:
            cfg: decoder configuration. Reads ``nbest``, ``unitlm`` and
                ``criterion``; for ``"asg_loss"`` also ``asgtransitions`` and
                ``maxreplabel``.
            tgt_dict: target-token dictionary.

        Raises:
            RuntimeError: if ``cfg.criterion`` is neither ``"ctc"`` nor
                ``"asg_loss"``.
            AssertionError: for ASG, if ``len(cfg.asgtransitions)`` is not
                ``vocab_size ** 2`` (not checked under ``python -O``).
        """
        self.tgt_dict = tgt_dict
        self.vocab_size = len(tgt_dict)
        self.nbest = cfg.nbest
        self.unitlm = cfg.unitlm

        criterion = cfg.criterion
        if criterion == "asg_loss":
            self.criterion_type = CriterionType.ASG
            # ASG uses no blank/silence symbol here; -1 is the sentinel value.
            self.blank = self.silence = -1
            self.asgtransitions = cfg.asgtransitions
            self.maxreplabel = cfg.maxreplabel
            # Presumably a flattened vocab_size x vocab_size transition matrix.
            assert len(self.asgtransitions) == self.vocab_size**2
        elif criterion == "ctc":
            self.criterion_type = CriterionType.CTC
            known = tgt_dict.indices
            if "<ctc_blank>" in known:
                self.blank = tgt_dict.index("<ctc_blank>")
            else:
                self.blank = tgt_dict.bos()
            # The first separator symbol present wins; otherwise fall back to EOS.
            for separator in ("<sep>", "|"):
                if separator in known:
                    self.silence = tgt_dict.index(separator)
                    break
            else:
                self.silence = tgt_dict.eos()
            self.asgtransitions = None
        else:
            raise RuntimeError(f"unknown criterion: {criterion}")
示例#2
0
    def __init__(self,
                 args,
                 token_dictionary: Dictionary,
                 expansion_dictionary: Dictionary,
                 expansion_strategy: ExpansionStrategy,
                 device: Optional[Union[torch.device, str]] = None,
                 regenerate_tokens: bool = False,
                 temperature: float = 1.0):
        """Create the iterative inference helper and precompute expansion tables.

        Args:
            args: accepted for interface compatibility; not read here.
            token_dictionary: dictionary of output tokens.
            expansion_dictionary: dictionary of expansion symbols.
            expansion_strategy: provides the dependency placeholders and, for
                every expansion symbol, its left and right dependency tokens.
            device: forwarded to ``IterativeInference``.
            regenerate_tokens: stored as ``self.regenerate_tokens``.
            temperature: stored as ``self.temperature``.
        """
        self.temperature = temperature

        self.inference = IterativeInference(
            token_dictionary,
            expansion_dictionary,
            expansion_strategy,
            device,
            mask_unk=True,
        )

        to_token_id = token_dictionary.index
        self.dependency_placeholder_ids = set(
            map(to_token_id, expansion_strategy.get_dependency_placeholders()))

        # expansion id -> (left dependency token ids, right dependency token ids)
        expansions = {}
        for symbol in expansion_dictionary.symbols:
            expansion_id = expansion_dictionary.index(symbol)
            left_deps, right_deps = expansion_strategy.expand_deps(symbol)
            expansions[expansion_id] = (
                [to_token_id(dep) for dep in left_deps],
                [to_token_id(dep) for dep in right_deps],
            )
        self.expansions = expansions
        self.regenerate_tokens = regenerate_tokens
示例#3
0
    def __init__(self, tgt_dict: Dictionary) -> None:
        """Resolve the blank and silence token indices from ``tgt_dict``.

        The blank is ``<ctc_blank>`` when present, otherwise BOS. The silence
        is ``<sep>`` if present, else ``|``, else EOS.

        Args:
            tgt_dict: target-token dictionary.
        """
        self.tgt_dict = tgt_dict
        self.vocab_size = len(tgt_dict)

        known = tgt_dict.indices
        if "<ctc_blank>" in known:
            self.blank = tgt_dict.index("<ctc_blank>")
        else:
            self.blank = tgt_dict.bos()

        silence_symbol = next((s for s in ("<sep>", "|") if s in known), None)
        self.silence = (tgt_dict.eos() if silence_symbol is None
                        else tgt_dict.index(silence_symbol))
示例#4
0
    def __init__(self,
                 token_dictionary: Dictionary,
                 expansion_dictionary: Dictionary,
                 expansion_strategy: ExpansionStrategy,
                 device=None,
                 mask_unk: bool = True):
        """Store the dictionaries and build the token/expansion softmax masks.

        Each mask is a float32 vector with one entry per dictionary symbol:
        0.0 keeps the symbol, -inf removes it. Presumably the masks are added
        to the logits before the softmax; confirm at the call site.

        Args:
            token_dictionary: dictionary of output tokens. Must share its pad
                index with ``expansion_dictionary``.
            expansion_dictionary: dictionary of expansion symbols.
            expansion_strategy: provides the root node token and the dependency
                placeholder tokens (which are masked in the token mask).
            device: device for the mask tensors. ``None`` (or any falsy value)
                means CPU.
            mask_unk: if true, the unknown-symbol index is masked in both masks,
                in addition to pad and eos.

        Raises:
            AssertionError: if the two dictionaries disagree on the pad index.
        """
        assert token_dictionary.pad_index == expansion_dictionary.pad_index
        self.device = device or torch.device('cpu')
        self.pad_idx = token_dictionary.pad_index
        self.token_dictionary = token_dictionary
        self.expansion_dictionary = expansion_dictionary
        self.expansion_strategy = expansion_strategy
        self.root_token_id = token_dictionary.index(expansion_strategy.root_node_token())

        minus_inf = float('-inf')

        # Mask for the token softmax.
        token_prob_mask = np.zeros(shape=(len(token_dictionary),), dtype=np.float32)
        for dep_placeholder in expansion_strategy.get_dependency_placeholders():
            index_to_mask = token_dictionary.index(dep_placeholder)
            if index_to_mask == token_dictionary.unk_index:
                # Symbol not found (e.g. [subword] if there are no subwords):
                # skip it, so the unk entry stays governed by ``mask_unk``.
                continue
            token_prob_mask[index_to_mask] = minus_inf
        special_token_idxs = [token_dictionary.pad_index, token_dictionary.eos_index]
        if mask_unk:
            special_token_idxs.append(token_dictionary.unk_index)
        for special_token_idx in special_token_idxs:
            token_prob_mask[special_token_idx] = minus_inf
        # Move to the resolved device (CPU when no device was given).
        self.token_prob_mask = torch.from_numpy(token_prob_mask).to(self.device)

        # Mask for the expansion softmax.
        expansion_prob_mask = np.zeros(shape=(len(expansion_dictionary),), dtype=np.float32)
        special_idxs = [expansion_dictionary.pad_index, expansion_dictionary.eos_index]
        if mask_unk:
            # Extend, don't replace: pad and eos must stay masked, exactly as in
            # the token mask above. (Previously ``=`` dropped them.)
            special_idxs.append(expansion_dictionary.unk_index)
        for special_idx in special_idxs:
            expansion_prob_mask[special_idx] = minus_inf
        self.expansion_prob_mask = torch.from_numpy(expansion_prob_mask).to(self.device)
示例#5
0
def tok_ids(tokens: List[str], dictionary: Dictionary) -> List[int]:
    """Look up every token in ``dictionary`` and return the indices in order.

    Args:
        tokens: token strings to look up.
        dictionary: object exposing ``index(token) -> int``.

    Returns:
        One index per input token, in the same order.
    """
    ids = []
    for token in tokens:
        ids.append(dictionary.index(token))
    return ids
示例#6
0
    def __init__(self, cfg: DecoderConfig, tgt_dict: Dictionary) -> None:
        """Build a flashlight decoder driven by a fairseq neural language model.

        Loads the fairseq LM checkpoint at ``cfg.lmpath``, rebuilds its task
        and model, and wraps the model in ``FairseqLM``.

        If ``cfg.lexicon`` is set, every lexicon spelling is inserted into a
        trie over target tokens and a ``LexiconDecoder`` is created. Each word
        is scored by the LM from its start state, except in unit-LM mode
        (``self.unitlm``), where words are numbered by lexicon order and score 0.
        Without a lexicon, a ``LexiconFreeDecoder`` is created; this requires
        unit-LM mode.

        Args:
            cfg: decoder configuration (beam sizes/threshold, LM/word/unk/sil
                weights, LM and lexicon paths).
            tgt_dict: target-token dictionary used to spell lexicon words.

        Raises:
            AssertionError: if a lexicon spelling contains a token that maps to
                ``tgt_dict.unk()``, or if no lexicon is given and unit-LM mode
                is off (neither is checked under ``python -O``).
        """
        super().__init__(cfg, tgt_dict)

        self.lexicon = load_words(cfg.lexicon) if cfg.lexicon else None
        # Word index -> word; only filled in unit-LM mode (see loop below).
        self.idx_to_wrd = {}

        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        checkpoint = torch.load(cfg.lmpath, map_location="cpu")

        # Prefer the checkpoint's "cfg"; otherwise convert the legacy "args" namespace.
        if "cfg" in checkpoint and checkpoint["cfg"] is not None:
            lm_args = checkpoint["cfg"]
        else:
            lm_args = convert_namespace_to_omegaconf(checkpoint["args"])

        # Point the LM task's data dir at the checkpoint's directory
        # (presumably where the LM dictionary lives; confirm).
        with open_dict(lm_args.task):
            lm_args.task.data = osp.dirname(cfg.lmpath)

        task = tasks.setup_task(lm_args.task)
        model = task.build_model(lm_args.model)
        # strict=False: missing/unexpected state-dict keys are tolerated.
        model.load_state_dict(checkpoint["model"], strict=False)

        self.trie = Trie(self.vocab_size, self.silence)

        self.word_dict = task.dictionary
        self.unk_word = self.word_dict.unk()
        self.lm = FairseqLM(self.word_dict, model)

        if self.lexicon:
            start_state = self.lm.start(False)
            for i, (word, spellings) in enumerate(self.lexicon.items()):
                if self.unitlm:
                    # Unit LM: the word id is its lexicon position and it
                    # gets no LM prior; remember the id -> word mapping.
                    word_idx = i
                    self.idx_to_wrd[i] = word
                    score = 0
                else:
                    # Word LM: score the word from the LM start state.
                    word_idx = self.word_dict.index(word)
                    _, score = self.lm.score(start_state,
                                             word_idx,
                                             no_cache=True)

                for spelling in spellings:
                    spelling_idxs = [
                        tgt_dict.index(token) for token in spelling
                    ]
                    # Every spelling token must exist in the target dictionary.
                    assert tgt_dict.unk() not in spelling_idxs, \
                        f"{spelling} {spelling_idxs}"
                    self.trie.insert(spelling_idxs, word_idx, score)
            # Presumably propagates the best (max) word score to prefix nodes.
            self.trie.smear(SmearingMode.MAX)

            self.decoder_opts = LexiconDecoderOptions(
                beam_size=cfg.beam,
                # A falsy beamsizetoken means "use the whole target vocabulary".
                beam_size_token=cfg.beamsizetoken or len(tgt_dict),
                beam_threshold=cfg.beamthreshold,
                lm_weight=cfg.lmweight,
                word_score=cfg.wordscore,
                unk_score=cfg.unkweight,
                sil_score=cfg.silweight,
                log_add=False,
                criterion_type=self.criterion_type,
            )

            # The decoder is handed a list; a None value (no ASG transitions)
            # becomes an empty list.
            if self.asgtransitions is None:
                self.asgtransitions = []

            self.decoder = LexiconDecoder(
                self.decoder_opts,
                self.trie,
                self.lm,
                self.silence,
                self.blank,
                self.unk_word,
                self.asgtransitions,
                self.unitlm,
            )
        else:
            assert self.unitlm, "Lexicon-free decoding requires unit LM"

            # Lexicon-free: each target symbol is its own single-spelling word.
            d = {w: [[w]] for w in tgt_dict.symbols}
            self.word_dict = create_word_dict(d)
            # NOTE(review): this replaces the FairseqLM built above with a KenLM
            # model loaded from cfg.lmpath, the same path that was just read via
            # torch.load as a fairseq checkpoint. The loaded `model` is unused in
            # this branch. Looks like a copy from the KenLM decoder; confirm
            # this is intended.
            self.lm = KenLM(cfg.lmpath, self.word_dict)
            self.decoder_opts = LexiconFreeDecoderOptions(
                beam_size=cfg.beam,
                beam_size_token=cfg.beamsizetoken or len(tgt_dict),
                beam_threshold=cfg.beamthreshold,
                lm_weight=cfg.lmweight,
                sil_score=cfg.silweight,
                log_add=False,
                criterion_type=self.criterion_type,
            )
            self.decoder = LexiconFreeDecoder(self.decoder_opts, self.lm,
                                              self.silence, self.blank, [])
示例#7
0
    def __init__(self, cfg: DecoderConfig, tgt_dict: Dictionary) -> None:
        """Build a flashlight decoder backed by a KenLM language model.

        Without ``cfg.lexicon``, every target symbol is treated as a word and a
        ``LexiconFreeDecoder`` is created; this requires unit-LM mode. With a
        lexicon, every word is scored by the KenLM model at ``cfg.lmpath``, its
        spellings are inserted into a trie over target tokens, and a
        ``LexiconDecoder`` is created.

        Args:
            cfg: decoder configuration (beam sizes/threshold, LM/word/unk/sil
                weights, LM and lexicon paths).
            tgt_dict: target-token dictionary used to spell lexicon words.

        Raises:
            AssertionError: if no lexicon is given and unit-LM mode is off, or
                a lexicon spelling contains a token that maps to
                ``tgt_dict.unk()`` (neither is checked under ``python -O``).
        """
        super().__init__(cfg, tgt_dict)

        if not cfg.lexicon:
            assert self.unitlm, "Lexicon-free decoding requires unit LM"

            self.word_dict = create_word_dict(
                {symbol: [[symbol]] for symbol in tgt_dict.symbols})
            self.lm = KenLM(cfg.lmpath, self.word_dict)
            self.decoder_opts = LexiconFreeDecoderOptions(
                beam_size=cfg.beam,
                beam_size_token=cfg.beamsizetoken or len(tgt_dict),
                beam_threshold=cfg.beamthreshold,
                lm_weight=cfg.lmweight,
                sil_score=cfg.silweight,
                log_add=False,
                criterion_type=self.criterion_type,
            )
            self.decoder = LexiconFreeDecoder(
                self.decoder_opts, self.lm, self.silence, self.blank, [])
            return

        self.lexicon = lexicon = load_words(cfg.lexicon)
        self.word_dict = word_dict = create_word_dict(lexicon)
        self.unk_word = word_dict.get_index("<unk>")

        self.lm = KenLM(cfg.lmpath, word_dict)
        self.trie = Trie(self.vocab_size, self.silence)

        initial_state = self.lm.start(False)
        for word, spellings in lexicon.items():
            word_id = word_dict.get_index(word)
            _, word_score = self.lm.score(initial_state, word_id)
            for spelling in spellings:
                token_ids = list(map(tgt_dict.index, spelling))
                assert tgt_dict.unk() not in token_ids, \
                    f"{spelling} {token_ids}"
                self.trie.insert(token_ids, word_id, word_score)
        self.trie.smear(SmearingMode.MAX)

        self.decoder_opts = LexiconDecoderOptions(
            beam_size=cfg.beam,
            # A falsy beamsizetoken means "use the whole target vocabulary".
            beam_size_token=cfg.beamsizetoken or len(tgt_dict),
            beam_threshold=cfg.beamthreshold,
            lm_weight=cfg.lmweight,
            word_score=cfg.wordscore,
            unk_score=cfg.unkweight,
            sil_score=cfg.silweight,
            log_add=False,
            criterion_type=self.criterion_type,
        )

        # The decoder is handed a list, so a missing transition table becomes [].
        if self.asgtransitions is None:
            self.asgtransitions = []

        self.decoder = LexiconDecoder(
            self.decoder_opts,
            self.trie,
            self.lm,
            self.silence,
            self.blank,
            self.unk_word,
            self.asgtransitions,
            self.unitlm,
        )