def build(
            word_dict: AsrDictionary,
            subword_dict: AsrDictionary,
            subword_tokenizer: Callable[[str], List[str]] = None
    ):
        """Flatten the word-level lexical prefix tree into index tensors.

        The pointer-based tree produced by ``lexical_prefix_tree`` is numbered
        in pre-order, visiting each node's children in ascending subword-index
        order. Row 0 of every output is a dummy slot reserved for OOV, so the
        root receives id 1.

        Args:
            word_dict: word dictionary. ``word_dict.pad()`` fills the dummy row
                of ``word_set_idx``, and ``len(word_dict)`` bounds the default
                word range.
            subword_dict: subword dictionary. ``subword_dict.pad()`` is the
                ``prev_subword_idx`` of rows that no edge points to (the dummy
                row and the root).
            subword_tokenizer: optional callable that splits a word into
                subword strings. It is forwarded unchanged to
                ``lexical_prefix_tree``.

        Returns:
            A ``TensorizedPrefixTree`` built from four int64 tensors, one row
            per node:
              * ``children`` (num_nodes, max_num_children): ids of the node's
                children in ascending subword order, right-padded with 0.
              * ``prev_subword_idx`` (num_nodes,): the subword on the edge
                from the parent into this node.
              * ``word_idx`` (num_nodes,): the node's ``word_idx``; -1 on the
                dummy row.
              * ``word_set_idx`` (num_nodes, 2): the node's ``word_set``, or
                ``[0, len(word_dict) - 1]`` when it has none.
        """
        # Reuse the existing routine to obtain the pointer-based trie first.
        root = lexical_prefix_tree(
            word_dict=word_dict,
            subword_dict=subword_dict,
            subword_tokenizer=subword_tokenizer
        )

        def by_subword(item):
            # Sort key for (subword_id, child) pairs taken from node.children.
            return item[0]

        # Pre-order numbering. Children are pushed in descending subword order
        # so that the stack pops, and therefore numbers, them in ascending order.
        ordered = [None]  # slot 0: dummy node for OOV
        pending = [root]
        while pending:
            node = pending.pop()
            ordered.append(node)
            pending.extend(
                child for _, child in sorted(node.children.items(), key=by_subword, reverse=True)
            )
        id_of = {node: idx for idx, node in enumerate(ordered[1:], start=1)}

        num_nodes = len(ordered)
        widest = max((len(node.children) for node in ordered[1:]), default=0)
        children = np.zeros((num_nodes, widest), dtype=np.int64)
        prev_subword_idx = np.full(num_nodes, subword_dict.pad(), dtype=np.int64)
        word_idx = np.full(num_nodes, -1, dtype=np.int64)
        word_set_idx = np.full((num_nodes, 2), word_dict.pad(), dtype=np.int64)

        # Fill every real node's row. Row 0 (the dummy) keeps its defaults.
        for node_id, node in enumerate(ordered[1:], start=1):
            # Same ascending subword order as the numbering pass above.
            for slot, (subword_id, child) in enumerate(sorted(node.children.items(), key=by_subword)):
                child_id = id_of[child]
                children[node_id, slot] = child_id
                prev_subword_idx[child_id] = subword_id
            word_idx[node_id] = node.word_idx
            word_set_idx[node_id] = (
                node.word_set if node.word_set is not None else [0, len(word_dict) - 1]
            )

        return TensorizedPrefixTree(
            children=torch.from_numpy(children),
            prev_subword_idx=torch.from_numpy(prev_subword_idx),
            word_idx=torch.from_numpy(word_idx),
            word_set_idx=torch.from_numpy(word_set_idx)
        )
    def __init__(
        self, wordlm, subwordlm, subwordlm_weight=0.8, oov_penalty=1.0, open_vocab=True,
    ):
        """Set up a wrapper around a word-level LM and a subword-level LM.

        Args:
            wordlm: word-level ``FairseqLanguageModel``. Its decoder must
                expose a callable ``masked_copy_incremental_state``, and its
                decoder's dictionary must be an ``AsrDictionary``.
            subwordlm: subword-level ``FairseqLanguageModel`` whose decoder
                dictionary must be an ``AsrDictionary``.
            subwordlm_weight: stored as ``self.subwordlm_weight``.
            oov_penalty: stored in natural-log form as
                ``self.log_oov_penalty``. ``math.log`` raises ``ValueError``
                unless it is positive.
            open_vocab: stored as ``self.open_vocab``.
        """
        super().__init__(wordlm.decoder.dictionary)

        assert isinstance(wordlm, FairseqLanguageModel)
        word_decoder = wordlm.decoder
        self.wordlm_decoder = word_decoder
        assert (
            hasattr(word_decoder, "masked_copy_incremental_state")
            and callable(word_decoder.masked_copy_incremental_state)
        ), "The wrapped decoder should implement masked_copy_incremental_state()"

        assert isinstance(subwordlm, FairseqLanguageModel)
        subword_decoder = subwordlm.decoder
        self.subwordlm_decoder = subword_decoder
        self.subwordlm_weight = subwordlm_weight
        self.log_oov_penalty = math.log(oov_penalty)  # kept in log space
        self.open_vocab = open_vocab
        self.logzero = -10.0

        # Special indices from the word vocabulary.
        wdict = word_decoder.dictionary
        assert isinstance(wdict, AsrDictionary)
        self.word_eos_idx, self.word_unk_idx = wdict.eos(), wdict.unk()

        # Special indices and size of the subword vocabulary.
        sdict = subword_decoder.dictionary
        assert isinstance(sdict, AsrDictionary)
        self.subword_space_idx = sdict.space()
        self.subword_eos_idx = sdict.eos()
        self.subword_vocab_size = len(sdict)

        def split_subwords(text):
            # Passes the subword dictionary's non_lang_syms to tokenize(),
            # then splits the result on single spaces.
            return tokenize(text, non_lang_syms=sdict.non_lang_syms).split(" ")

        self.lexroot = lexical_prefix_tree(wdict, sdict, split_subwords)
# Example #3
# 0
    def __init__(self,
                 wordlm,
                 subword_dict,
                 oov_penalty=1e-4,
                 open_vocab=True):
        """Set up a word-LM wrapper driven by a lexical prefix tree over subwords.

        Args:
            wordlm: word-level ``FairseqLanguageModel``. Its decoder must
                expose a callable ``masked_copy_incremental_state``, and its
                decoder's dictionary must be an ``AsrDictionary``.
            subword_dict: ``AsrDictionary`` of subwords. It is used to split
                words and to build ``self.lexroot``.
            oov_penalty: stored as ``self.oov_penalty``.
            open_vocab: stored as ``self.open_vocab``.

        After construction, ``self.max_num_children`` holds the largest number
        of children of any node in ``self.lexroot``. It is asserted not to
        exceed the subword vocabulary size.
        """
        super().__init__(wordlm.decoder.dictionary)

        assert isinstance(wordlm, FairseqLanguageModel)
        self.lm_decoder = wordlm.decoder
        assert (
            hasattr(self.lm_decoder, 'masked_copy_incremental_state') and
            callable(self.lm_decoder.masked_copy_incremental_state)
        ), 'The wrapped decoder should implement masked_copy_incremental_state()'
        self.oov_penalty = oov_penalty
        self.open_vocab = open_vocab
        self.zero = 1e-10  # a sufficiently small value to avoid the log(0) issue

        word_dict = self.lm_decoder.dictionary
        assert isinstance(word_dict, AsrDictionary)
        self.word_pad_idx = word_dict.pad()
        self.word_eos_idx = word_dict.eos()
        self.word_unk_idx = word_dict.unk()

        assert isinstance(subword_dict, AsrDictionary)
        self.subword_space_idx = subword_dict.space()
        self.subword_pad_idx = subword_dict.pad()
        self.subword_eos_idx = subword_dict.eos()
        self.subword_vocab_size = len(subword_dict)

        def tokenizer(x):
            return tokenize(
                x, non_lang_syms=subword_dict.non_lang_syms).split(' ')

        self.lexroot = lexical_prefix_tree(word_dict, subword_dict, tokenizer)

        def max_out_degree(root):
            # Largest number of children of any node reachable from ``root``.
            # This walks iteratively: a recursive walk descends one Python
            # frame per trie level, so a word spelled with very many subwords
            # could exceed the interpreter's recursion limit.
            best = 0
            pending = [root]
            while pending:
                node = pending.pop()
                best = max(best, len(node.children))
                pending.extend(node.children.values())
            return best

        self.max_num_children = max_out_degree(self.lexroot)
        assert self.max_num_children <= self.subword_vocab_size