def build(
    word_dict: TokenDictionary,
    subword_dict: TokenDictionary,
    subword_tokenizer: Callable[[str], List[str]] = None,
):
    """Build a tensorized lexical prefix tree for words.

    Reuses :func:`lexical_prefix_tree` to build the node-based tree, then
    flattens it into integer tensors via a pre-order traversal.
    """
    # Build the traditional tree data structure by reusing existing routines.
    tree_root = lexical_prefix_tree(
        word_dict=word_dict,
        subword_dict=subword_dict,
        subword_tokenizer=subword_tokenizer,
    )

    # Pre-order traversal assigning a dense integer id to every node.
    # Index 0 is reserved as a dummy slot for OOV.
    flat_nodes = [None]
    id_of = {}
    max_fanout = 0
    pending = [tree_root]
    while pending:
        node = pending.pop()
        id_of[node] = len(flat_nodes)
        flat_nodes.append(node)
        if len(node.children) > max_fanout:
            max_fanout = len(node.children)
        # Push children largest-subword-index first so they are popped (and
        # therefore numbered) in ascending order of the subword index.
        for subword_id in sorted(node.children, reverse=True):
            pending.append(node.children[subword_id])

    # Flatten the tree into per-node arrays.
    total = len(flat_nodes)
    children = np.zeros([total, max_fanout], dtype=np.int64)
    prev_subword_idx = np.full([total], subword_dict.pad(), dtype=np.int64)
    word_idx = np.full([total], -1, dtype=np.int64)
    word_set_idx = np.full([total, 2], word_dict.pad(), dtype=np.int64)

    for node_id in range(1, total):  # skip index 0, which is `None`
        node = flat_nodes[node_id]
        # Children are stored in ascending order of their subword index.
        for slot, subword_id in enumerate(sorted(node.children)):
            child_id = id_of[node.children[subword_id]]
            children[node_id, slot] = child_id
            prev_subword_idx[child_id] = subword_id
        word_idx[node_id] = node.word_idx
        word_set_idx[node_id] = (
            node.word_set if node.word_set is not None
            else [0, len(word_dict) - 1]
        )

    return TensorizedPrefixTree(
        children=torch.from_numpy(children),
        prev_subword_idx=torch.from_numpy(prev_subword_idx),
        word_idx=torch.from_numpy(word_idx),
        word_set_idx=torch.from_numpy(word_set_idx),
    )
def __init__(self,
             word_lm: FairseqLanguageModel,
             subword_dict: TokenDictionary,
             oov_penalty: float = 1e-4,
             open_vocab: bool = True):
    """Wrap a word-level LM so it can be driven by subword-unit inputs.

    Caches the special-symbol indices of both the word and the subword
    dictionaries and builds the tensorized lexical prefix tree used to map
    subword sequences to word indices.
    """
    super().__init__(word_lm.decoder.dictionary)

    self.lm_decoder: FairseqIncrementalDecoder = word_lm.decoder
    # The wrapped decoder must support masked copying of incremental state.
    masked_copy = getattr(self.lm_decoder, 'masked_copy_incremental_state', None)
    assert callable(masked_copy), \
        'The wrapped decoder should implement masked_copy_incremental_state()'

    self.oov_penalty = oov_penalty
    self.open_vocab = open_vocab
    self.zero = 1e-10  # a sufficiently small value to avoid the log(0) issue

    # Special-symbol indices of the word-level dictionary.
    word_dict: TokenDictionary = self.lm_decoder.dictionary
    self.word_pad_idx = word_dict.pad()
    self.word_eos_idx = word_dict.eos()
    self.word_unk_idx = word_dict.unk()

    # Special-symbol indices and size of the subword-level dictionary.
    self.subword_space_idx = subword_dict.space()
    self.subword_pad_idx = subword_dict.pad()
    self.subword_eos_idx = subword_dict.eos()
    self.subword_vocab_size = len(subword_dict)

    def tokenizer(x: str) -> List[str]:
        return tokenize(x, non_lang_syms=subword_dict.non_lang_syms).split(' ')

    self.tree = TensorizedPrefixTree.build(word_dict, subword_dict, tokenizer)

    assert self.tree.max_out_degree() <= self.subword_vocab_size
def lexical_prefix_tree(word_dict: TokenDictionary,
                        subword_dict: TokenDictionary,
                        subword_tokenizer: Callable[[str], List[str]] = None):
    """Build a lexical prefix tree for words.

    Args:
        word_dict: an instance of :class:`fairseq.data.TokenDictionary`.
        subword_dict: an instance of :class:`fairseq.data.TokenDictionary`.
        subword_tokenizer (callable): a function that takes a word string as
            its only one argument, and returns a list of subwords as a result
            of tokenization. If ``None``, a word is split into its characters.

    Return:
        root (Node): the root of the prefix tree, where each node has the
            fields: ('children': Dict[int,Node], 'word_idx': int, 'word_set':
            Tuple[int]). 'children' is subword_idx -> node, and 'word_set' is
            (first-1, last), where [first, last] is the range of the word
            indexes (inclusive) in the word dictionary who share the same
            prefix at that node. We assume words in the word dictionary are
            in lexical order.
    """

    class Node(object):
        def __init__(self, children=None, word_idx=-1, word_set=None):
            # NOTE: use None as the default instead of a mutable `{}` —
            # a shared default dict would silently merge children across
            # all Node instances created without an explicit dict.
            self.children = {} if children is None else children
            self.word_idx = word_idx
            self.word_set = word_set

    special_symbols = [word_dict.pad(), word_dict.eos(), word_dict.unk()]
    assert 0 in special_symbols  # to ensure widx - 1 >= 0
    root = Node({}, -1, None)
    for widx in range(len(word_dict)):
        if widx not in special_symbols:  # skip <pad>, <eos>, <unk>
            # tokenize a word into a list of subwords
            subwords = subword_tokenizer(word_dict[widx]) \
                if subword_tokenizer is not None else list(word_dict[widx])
            if any(subword_dict.index(s) == subword_dict.unk() for s in subwords):
                # skip words containing any unknown subwords
                continue
            children = root.children
            for i, s in enumerate(subwords):
                sidx = subword_dict.index(s)
                if sidx not in children:  # make a new node
                    children[sidx] = Node({}, -1, (widx - 1, widx))
                else:
                    # widen the word-index range covered by this prefix
                    children[sidx].word_set = (
                        min(children[sidx].word_set[0], widx - 1),
                        max(children[sidx].word_set[1], widx),
                    )
                if i == len(subwords) - 1:  # if word end, set word_idx
                    children[sidx].word_idx = widx
                children = children[sidx].children  # move to children
    return root