def build(
    word_dict: AsrDictionary,
    subword_dict: AsrDictionary,
    subword_tokenizer: Callable[[str], List[str]] = None,
):
    """
    Builds a tensorized lexical prefix tree for words.
    """
    root = lexical_prefix_tree(
        word_dict=word_dict,
        subword_dict=subword_dict,
        subword_tokenizer=subword_tokenizer,
    )  # build the traditional tree data structure by reusing existing routines

    # Perform a pre-order traversal of the tree to assign an index to each node
    max_num_children = 0
    nodes = [None]  # nodes[0] is a dummy node for OOV
    node_to_id_dict = {}
    stack = [root]

    while len(stack) > 0:
        curr = stack.pop()
        node_id = len(nodes)
        nodes.append(curr)
        node_to_id_dict[curr] = node_id
        if len(curr.children) > max_num_children:
            max_num_children = len(curr.children)

        # Guarantee that children are visited in ascending order of subword index:
        # push them in descending order so the stack pops them in ascending order
        for _, next_node in sorted(curr.children.items(), key=lambda t: t[0], reverse=True):
            stack.append(next_node)

    # Construct the tensorized tree
    num_nodes = len(nodes)
    children = np.full([num_nodes, max_num_children], 0, dtype=np.int64)
    prev_subword_idx = np.full([num_nodes], subword_dict.pad(), dtype=np.int64)
    word_idx = np.full([num_nodes], -1, dtype=np.int64)
    word_set_idx = np.full([num_nodes, 2], word_dict.pad(), dtype=np.int64)

    for node_id in range(1, len(nodes)):  # skip 0, which is `None`
        node = nodes[node_id]
        # Guarantee that children are stored in ascending order of subword index
        for i, (subword_id, child) in enumerate(sorted(node.children.items(), key=lambda t: t[0])):
            child_node_id = node_to_id_dict[child]
            children[node_id, i] = child_node_id
            prev_subword_idx[child_node_id] = subword_id

        word_idx[node_id] = node.word_idx
        if node.word_set is not None:
            word_set_idx[node_id] = node.word_set
        else:
            word_set_idx[node_id] = [0, len(word_dict) - 1]

    return TensorizedPrefixTree(
        children=torch.from_numpy(children),
        prev_subword_idx=torch.from_numpy(prev_subword_idx),
        word_idx=torch.from_numpy(word_idx),
        word_set_idx=torch.from_numpy(word_set_idx),
    )
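# Hypothetical usage sketch (not from the original source): it only illustrates how
# the tensors returned by build() relate to one another; the character-level
# subword_tokenizer below is an assumption made for the example.
def _example_build_usage(word_dict: AsrDictionary, subword_dict: AsrDictionary) -> None:
    tree = build(word_dict, subword_dict, subword_tokenizer=lambda w: list(w))
    num_nodes = tree.children.size(0)
    # All per-node tensors are indexed by the node ids assigned during the
    # pre-order traversal above; node 0 is the dummy OOV node.
    assert tree.prev_subword_idx.size(0) == num_nodes  # subword on the incoming edge of each node
    assert tree.word_idx.size(0) == num_nodes  # word index associated with each node (from the prefix tree)
    assert tree.word_set_idx.size() == (num_nodes, 2)  # [min, max] word-index range reachable below each node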
def __init__(
    self,
    wordlm,
    subwordlm,
    subwordlm_weight=0.8,
    oov_penalty=1.0,
    open_vocab=True,
):
    super().__init__(wordlm.decoder.dictionary)

    assert isinstance(wordlm, FairseqLanguageModel)
    self.wordlm_decoder = wordlm.decoder
    assert (
        hasattr(self.wordlm_decoder, "masked_copy_incremental_state")
        and callable(self.wordlm_decoder.masked_copy_incremental_state)
    ), "The wrapped decoder should implement masked_copy_incremental_state()"

    assert isinstance(subwordlm, FairseqLanguageModel)
    self.subwordlm_decoder = subwordlm.decoder
    self.subwordlm_weight = subwordlm_weight
    self.log_oov_penalty = math.log(oov_penalty)
    self.open_vocab = open_vocab
    self.logzero = -10.0

    word_dict = self.wordlm_decoder.dictionary
    assert isinstance(word_dict, AsrDictionary)
    self.word_eos_idx = word_dict.eos()
    self.word_unk_idx = word_dict.unk()

    subword_dict = self.subwordlm_decoder.dictionary
    assert isinstance(subword_dict, AsrDictionary)
    self.subword_space_idx = subword_dict.space()
    self.subword_eos_idx = subword_dict.eos()
    self.subword_vocab_size = len(subword_dict)

    def tokenizer(x):
        return tokenize(x, non_lang_syms=subword_dict.non_lang_syms).split(" ")

    self.lexroot = lexical_prefix_tree(word_dict, subword_dict, tokenizer)
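# Construction sketch (hypothetical; the surrounding class header is not shown in this
# excerpt, so the class name MultiLevelLanguageModel and the way the two LMs are
# obtained are assumptions made for illustration):
#
#   wordlm = ...     # FairseqLanguageModel over a word AsrDictionary whose decoder
#                    # implements masked_copy_incremental_state()
#   subwordlm = ...  # FairseqLanguageModel over the subword AsrDictionary
#   lm = MultiLevelLanguageModel(wordlm, subwordlm, subwordlm_weight=0.8,
#                                oov_penalty=1.0, open_vocab=True)
#
# Note that oov_penalty is stored in log space (log_oov_penalty) so it can be added
# directly to log-probabilities, and logzero (-10.0) serves as a finite stand-in for
# log(0).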
def __init__(self, wordlm, subword_dict, oov_penalty=1e-4, open_vocab=True):
    super().__init__(wordlm.decoder.dictionary)

    assert isinstance(wordlm, FairseqLanguageModel)
    self.lm_decoder = wordlm.decoder
    assert hasattr(self.lm_decoder, 'masked_copy_incremental_state') and \
        callable(self.lm_decoder.masked_copy_incremental_state), \
        'The wrapped decoder should implement masked_copy_incremental_state()'

    self.oov_penalty = oov_penalty
    self.open_vocab = open_vocab
    self.zero = 1e-10  # a sufficiently small value to avoid the log(0) issue

    word_dict = self.lm_decoder.dictionary
    assert isinstance(word_dict, AsrDictionary)
    self.word_pad_idx = word_dict.pad()
    self.word_eos_idx = word_dict.eos()
    self.word_unk_idx = word_dict.unk()

    assert isinstance(subword_dict, AsrDictionary)
    self.subword_space_idx = subword_dict.space()
    self.subword_pad_idx = subword_dict.pad()
    self.subword_eos_idx = subword_dict.eos()
    self.subword_vocab_size = len(subword_dict)

    def tokenizer(x):
        return tokenize(
            x, non_lang_syms=subword_dict.non_lang_syms).split(' ')

    self.lexroot = lexical_prefix_tree(word_dict, subword_dict, tokenizer)

    def max_out_degree(node):
        if len(node.children) == 0:
            return 0
        cur_max = len(node.children)
        for _, child in node.children.items():
            cur_max = max(cur_max, max_out_degree(child))
        return cur_max

    self.max_num_children = max_out_degree(self.lexroot)
    assert self.max_num_children <= self.subword_vocab_size
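# Usage sketch (hypothetical; the class name is assumed since the class header is not
# shown in this excerpt):
#
#   lm = LookAheadWordLanguageModel(wordlm, subword_dict,
#                                   oov_penalty=1e-4, open_vocab=True)
#
# max_out_degree() recursively computes the largest number of children of any node in
# the lexical prefix tree. Each outgoing edge of a node is labeled with a distinct
# subword, so this maximum can never exceed the subword vocabulary size, which is
# exactly what the final assertion checks.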