def __init__(self, cfg: DecoderConfig, tgt_dict: Dictionary) -> None:
    """Store the target dictionary and derive criterion-specific settings.

    Args:
        cfg: Decoder configuration. Reads ``nbest``, ``unitlm`` and
            ``criterion``; for ``"asg_loss"`` also ``asgtransitions`` and
            ``maxreplabel``.
        tgt_dict: Target dictionary; must support ``len()``, ``indices``,
            ``index()``, ``bos()`` and ``eos()``.

    Raises:
        RuntimeError: If ``cfg.criterion`` is neither ``"ctc"`` nor
            ``"asg_loss"``.
    """
    self.tgt_dict = tgt_dict
    self.vocab_size = len(tgt_dict)
    self.nbest = cfg.nbest
    self.unitlm = cfg.unitlm

    if cfg.criterion == "ctc":
        self.criterion_type = CriterionType.CTC
        symbol_table = tgt_dict.indices

        # Blank: dedicated "<ctc_blank>" symbol if present, else BOS.
        if "<ctc_blank>" in symbol_table:
            self.blank = tgt_dict.index("<ctc_blank>")
        else:
            self.blank = tgt_dict.bos()

        # Silence: first of "<sep>", "|" found in the dictionary, else EOS.
        for candidate in ("<sep>", "|"):
            if candidate in symbol_table:
                self.silence = tgt_dict.index(candidate)
                break
        else:
            self.silence = tgt_dict.eos()

        self.asgtransitions = None
        return

    if cfg.criterion == "asg_loss":
        self.criterion_type = CriterionType.ASG
        # Blank and silence are both set to -1 for this criterion.
        self.blank = self.silence = -1
        self.asgtransitions = cfg.asgtransitions
        self.maxreplabel = cfg.maxreplabel
        # The transitions are expected as a flat vocab_size x vocab_size table.
        expected_len = self.vocab_size ** 2
        assert len(self.asgtransitions) == expected_len
        return

    raise RuntimeError(f"unknown criterion: {cfg.criterion}")
def __init__(self, args, token_dictionary: Dictionary,
             expansion_dictionary: Dictionary,
             expansion_strategy: ExpansionStrategy,
             device: Optional[Union[torch.device, str]] = None,
             regenerate_tokens: bool = False,
             temperature: float = 1.0):
    """Build the inference helper and pre-compute index lookup tables.

    Args:
        args: Not used by this constructor.
        token_dictionary: Dictionary used to map token strings to ids.
        expansion_dictionary: Dictionary whose ``symbols`` are the expansions.
        expansion_strategy: Supplies the dependency placeholders and
            ``expand_deps``.
        device: Forwarded unchanged to ``IterativeInference``.
        regenerate_tokens: Stored as ``self.regenerate_tokens``.
        temperature: Stored as ``self.temperature``.
    """
    self.temperature = temperature
    self.inference = IterativeInference(
        token_dictionary,
        expansion_dictionary,
        expansion_strategy,
        device,
        mask_unk=True,
    )

    # Token ids of every dependency placeholder named by the strategy.
    placeholder_ids = set()
    for placeholder in expansion_strategy.get_dependency_placeholders():
        placeholder_ids.add(token_dictionary.index(placeholder))
    self.dependency_placeholder_ids = placeholder_ids

    # Expansion id -> (left dependency token ids, right dependency token ids).
    expansion_table = {}
    for symbol in expansion_dictionary.symbols:
        expansion_id = expansion_dictionary.index(symbol)
        left_deps, right_deps = expansion_strategy.expand_deps(symbol)
        left_ids = [token_dictionary.index(dep) for dep in left_deps]
        right_ids = [token_dictionary.index(dep) for dep in right_deps]
        expansion_table[expansion_id] = (left_ids, right_ids)
    self.expansions = expansion_table

    self.regenerate_tokens = regenerate_tokens
def __init__(self, tgt_dict: Dictionary) -> None:
    """Record the target dictionary and resolve blank and silence indices.

    Args:
        tgt_dict: Target dictionary; must support ``len()``, ``indices``,
            ``index()``, ``bos()`` and ``eos()``.
    """
    self.tgt_dict = tgt_dict
    self.vocab_size = len(tgt_dict)

    known_symbols = tgt_dict.indices

    # Blank: "<ctc_blank>" when the dictionary has it, otherwise BOS.
    if "<ctc_blank>" in known_symbols:
        self.blank = tgt_dict.index("<ctc_blank>")
    else:
        self.blank = tgt_dict.bos()

    # Silence: prefer "<sep>", then "|", otherwise fall back to EOS.
    silence_symbol = next(
        (symbol for symbol in ("<sep>", "|") if symbol in known_symbols),
        None,
    )
    if silence_symbol is None:
        self.silence = tgt_dict.eos()
    else:
        self.silence = tgt_dict.index(silence_symbol)
def __init__(self, token_dictionary: Dictionary,
             expansion_dictionary: Dictionary,
             expansion_strategy: ExpansionStrategy,
             device=None, mask_unk: bool = True):
    """Store the dictionaries and build the token/expansion probability masks.

    Each mask is a float32 vector with one entry per dictionary symbol:
    ``0`` for allowed symbols and ``-inf`` for symbols to exclude. They are
    kept for use in the token and expansion softmax later.

    Args:
        token_dictionary: Dictionary of tokens.
        expansion_dictionary: Dictionary of expansions; must share its
            ``pad_index`` with ``token_dictionary``.
        expansion_strategy: Supplies the root node token and the dependency
            placeholders.
        device: Device the masks are moved to; falls back to CPU when falsy.
        mask_unk: When true, the unknown symbol is masked in both masks in
            addition to pad and eos.

    Raises:
        AssertionError: If the two dictionaries disagree on ``pad_index``.
    """
    assert token_dictionary.pad_index == expansion_dictionary.pad_index
    self.device = device or torch.device('cpu')
    self.pad_idx = token_dictionary.pad_index
    self.token_dictionary = token_dictionary
    self.expansion_dictionary = expansion_dictionary
    self.expansion_strategy = expansion_strategy
    self.root_token_id = token_dictionary.index(expansion_strategy.root_node_token())

    minus_inf = float('-inf')

    def _build_mask(size, blocked_indices):
        # 0 everywhere, -inf at every blocked index, on the resolved device.
        mask = np.zeros(shape=(size,), dtype=np.float32)
        for blocked in blocked_indices:
            mask[blocked] = minus_inf
        # Use self.device (never None) rather than the raw argument.
        return torch.from_numpy(mask).to(self.device)

    # create mask to use in the token softmax later
    blocked_tokens = []
    for dep_placeholder in expansion_strategy.get_dependency_placeholders():
        index_to_mask = token_dictionary.index(dep_placeholder)
        if index_to_mask == token_dictionary.unk_index:
            continue  # symbol not found (e.g. [subword] if no subwords) ==> skip
        blocked_tokens.append(index_to_mask)
    blocked_tokens += [token_dictionary.pad_index, token_dictionary.eos_index]
    if mask_unk:
        blocked_tokens.append(token_dictionary.unk_index)
    self.token_prob_mask = _build_mask(len(token_dictionary), blocked_tokens)

    # create mask to use in the expansion softmax later
    blocked_expansions = [expansion_dictionary.pad_index, expansion_dictionary.eos_index]
    if mask_unk:
        # Extend, do not replace: pad and eos must stay masked as well,
        # matching how the token mask is built.
        blocked_expansions.append(expansion_dictionary.unk_index)
    self.expansion_prob_mask = _build_mask(len(expansion_dictionary), blocked_expansions)
def tok_ids(tokens: List[str], dictionary: Dictionary) -> List[int]:
    """Return ``dictionary.index(token)`` for every token, in input order."""
    ids = []
    for token in tokens:
        ids.append(dictionary.index(token))
    return ids
def __init__(self, cfg: DecoderConfig, tgt_dict: Dictionary) -> None:
    """Load a language model from a checkpoint and build the decoder.

    With a non-empty lexicon, every spelling is inserted into a trie and a
    ``LexiconDecoder`` is created. Otherwise a ``LexiconFreeDecoder`` is
    created, which requires ``self.unitlm`` to be truthy.

    Args:
        cfg: Decoder configuration. Reads ``lexicon``, ``lmpath``, ``beam``,
            ``beamsizetoken``, ``beamthreshold``, ``lmweight``, ``wordscore``,
            ``unkweight`` and ``silweight``.
        tgt_dict: Target dictionary used to index spelling tokens.

    Raises:
        AssertionError: If a spelling contains a token that maps to the
            dictionary's unknown index, or if no lexicon is available and
            ``self.unitlm`` is falsy.
    """
    super().__init__(cfg, tgt_dict)

    if cfg.lexicon:
        self.lexicon = load_words(cfg.lexicon)
    else:
        self.lexicon = None
    self.idx_to_wrd = {}

    # NOTE(review): torch.load unpickles the file; only point cfg.lmpath at
    # trusted checkpoints.
    checkpoint = torch.load(cfg.lmpath, map_location="cpu")

    # Prefer the stored "cfg"; otherwise convert the stored "args" namespace.
    has_stored_cfg = "cfg" in checkpoint and checkpoint["cfg"] is not None
    if has_stored_cfg:
        lm_args = checkpoint["cfg"]
    else:
        lm_args = convert_namespace_to_omegaconf(checkpoint["args"])

    # Point the task's data directory at the folder holding the checkpoint.
    with open_dict(lm_args.task):
        lm_args.task.data = osp.dirname(cfg.lmpath)

    task = tasks.setup_task(lm_args.task)
    lm_model = task.build_model(lm_args.model)
    lm_model.load_state_dict(checkpoint["model"], strict=False)

    self.trie = Trie(self.vocab_size, self.silence)

    self.word_dict = task.dictionary
    self.unk_word = self.word_dict.unk()
    self.lm = FairseqLM(self.word_dict, lm_model)

    def shared_options():
        # Keyword arguments common to both decoder option objects.
        return dict(
            beam_size=cfg.beam,
            beam_size_token=cfg.beamsizetoken or len(tgt_dict),
            beam_threshold=cfg.beamthreshold,
            lm_weight=cfg.lmweight,
            sil_score=cfg.silweight,
            log_add=False,
            criterion_type=self.criterion_type,
        )

    if not self.lexicon:
        assert self.unitlm, "Lexicon-free decoding requires unit LM"

        # Each dictionary symbol is treated as a word spelled by itself.
        unit_lexicon = {w: [[w]] for w in tgt_dict.symbols}
        self.word_dict = create_word_dict(unit_lexicon)
        # NOTE(review): this replaces the FairseqLM built above with a KenLM
        # loaded from the same cfg.lmpath -- confirm this is intended.
        self.lm = KenLM(cfg.lmpath, self.word_dict)
        self.decoder_opts = LexiconFreeDecoderOptions(**shared_options())
        self.decoder = LexiconFreeDecoder(
            self.decoder_opts, self.lm, self.silence, self.blank, []
        )
        return

    start_state = self.lm.start(False)
    for position, (word, spellings) in enumerate(self.lexicon.items()):
        if self.unitlm:
            # Unit LM: use the lexicon position as the word id, score 0.
            word_idx = position
            self.idx_to_wrd[position] = word
            score = 0
        else:
            word_idx = self.word_dict.index(word)
            _, score = self.lm.score(start_state, word_idx, no_cache=True)

        for spelling in spellings:
            spelling_idxs = [tgt_dict.index(token) for token in spelling]
            assert tgt_dict.unk() not in spelling_idxs, \
                f"{spelling} {spelling_idxs}"
            self.trie.insert(spelling_idxs, word_idx, score)
    self.trie.smear(SmearingMode.MAX)

    self.decoder_opts = LexiconDecoderOptions(
        word_score=cfg.wordscore,
        unk_score=cfg.unkweight,
        **shared_options(),
    )

    if self.asgtransitions is None:
        self.asgtransitions = []

    self.decoder = LexiconDecoder(
        self.decoder_opts,
        self.trie,
        self.lm,
        self.silence,
        self.blank,
        self.unk_word,
        self.asgtransitions,
        self.unitlm,
    )
def __init__(self, cfg: DecoderConfig, tgt_dict: Dictionary) -> None:
    """Build a KenLM-backed decoder, with or without a lexicon.

    When ``cfg.lexicon`` is set, the lexicon is loaded, every spelling is
    inserted into a trie and a ``LexiconDecoder`` is created. Otherwise a
    ``LexiconFreeDecoder`` is created, which requires ``self.unitlm`` to be
    truthy.

    Args:
        cfg: Decoder configuration. Reads ``lexicon``, ``lmpath``, ``beam``,
            ``beamsizetoken``, ``beamthreshold``, ``lmweight``, ``wordscore``,
            ``unkweight`` and ``silweight``.
        tgt_dict: Target dictionary used to index spelling tokens.

    Raises:
        AssertionError: If a spelling contains a token that maps to the
            dictionary's unknown index, or if ``cfg.lexicon`` is unset and
            ``self.unitlm`` is falsy.
    """
    super().__init__(cfg, tgt_dict)

    def shared_options():
        # Keyword arguments common to both decoder option objects.
        return dict(
            beam_size=cfg.beam,
            beam_size_token=cfg.beamsizetoken or len(tgt_dict),
            beam_threshold=cfg.beamthreshold,
            lm_weight=cfg.lmweight,
            sil_score=cfg.silweight,
            log_add=False,
            criterion_type=self.criterion_type,
        )

    if cfg.lexicon:
        self.lexicon = load_words(cfg.lexicon)
        self.word_dict = create_word_dict(self.lexicon)
        self.unk_word = self.word_dict.get_index("<unk>")

        self.lm = KenLM(cfg.lmpath, self.word_dict)
        self.trie = Trie(self.vocab_size, self.silence)

        # Score each word from the LM start state and insert its spellings.
        initial_state = self.lm.start(False)
        for word, spellings in self.lexicon.items():
            word_idx = self.word_dict.get_index(word)
            _, score = self.lm.score(initial_state, word_idx)
            for spelling in spellings:
                spelling_idxs = [tgt_dict.index(token) for token in spelling]
                assert tgt_dict.unk() not in spelling_idxs, \
                    f"{spelling} {spelling_idxs}"
                self.trie.insert(spelling_idxs, word_idx, score)
        self.trie.smear(SmearingMode.MAX)

        self.decoder_opts = LexiconDecoderOptions(
            word_score=cfg.wordscore,
            unk_score=cfg.unkweight,
            **shared_options(),
        )

        if self.asgtransitions is None:
            self.asgtransitions = []

        self.decoder = LexiconDecoder(
            self.decoder_opts,
            self.trie,
            self.lm,
            self.silence,
            self.blank,
            self.unk_word,
            self.asgtransitions,
            self.unitlm,
        )
    else:
        assert self.unitlm, "Lexicon-free decoding requires unit LM"

        # Each dictionary symbol is treated as a word spelled by itself.
        unit_lexicon = {w: [[w]] for w in tgt_dict.symbols}
        self.word_dict = create_word_dict(unit_lexicon)
        self.lm = KenLM(cfg.lmpath, self.word_dict)
        self.decoder_opts = LexiconFreeDecoderOptions(**shared_options())
        self.decoder = LexiconFreeDecoder(
            self.decoder_opts, self.lm, self.silence, self.blank, []
        )