def __init__(self,
                 word_lm: FairseqLanguageModel,
                 subword_dict: AsrDictionary,
                 oov_penalty: float = 1e-4,
                 open_vocab: bool = True):
        """Wrap a word-level LM and index its vocabulary by subword prefixes.

        Args:
            word_lm: word-level language model. Its decoder's dictionary is
                passed to the base-class initializer, and the decoder must
                provide a callable ``masked_copy_incremental_state``.
            subword_dict: subword dictionary. It is used to spell each word of
                the LM vocabulary when building the lexical prefix tree.
            oov_penalty: stored as ``self.oov_penalty``. Presumably the score
                penalty for out-of-vocabulary words (TODO: confirm at the use site).
            open_vocab: stored as ``self.open_vocab``. Presumably toggles
                open-vocabulary scoring (TODO: confirm at the use site).
        """
        decoder = word_lm.decoder
        super().__init__(decoder.dictionary)

        # Fail fast if the decoder cannot copy its incremental state under a mask.
        self.lm_decoder: FairseqIncrementalDecoder = decoder
        assert callable(getattr(decoder, 'masked_copy_incremental_state', None)), \
            'The wrapped decoder should implement masked_copy_incremental_state()'

        self.oov_penalty = oov_penalty
        self.open_vocab = open_vocab
        # Tiny positive floor so that a later log() never sees an exact 0.
        self.zero = 1e-10

        # Cache the special-symbol indices of the word vocabulary...
        word_dict: AsrDictionary = decoder.dictionary
        self.word_pad_idx, self.word_eos_idx, self.word_unk_idx = (
            word_dict.pad(), word_dict.eos(), word_dict.unk())

        # ...and those of the subword vocabulary, plus its size.
        self.subword_space_idx = subword_dict.space()
        self.subword_pad_idx, self.subword_eos_idx = (
            subword_dict.pad(), subword_dict.eos())
        self.subword_vocab_size = len(subword_dict)

        def _split_word(word: str) -> List[str]:
            # Spell a word as space-separated subwords, then split on spaces.
            spelled = tokenize(word, non_lang_syms=subword_dict.non_lang_syms)
            return spelled.split(' ')

        self.tree = TensorizedPrefixTree.build(word_dict, subword_dict,
                                               _split_word)

        # No tree node may branch more ways than there are subwords.
        assert self.tree.max_out_degree() <= self.subword_vocab_size
def lexical_prefix_tree(
    word_dict: AsrDictionary,
    subword_dict: AsrDictionary,
    subword_tokenizer: Callable[[str], List[str]] = None,
):
    """Build a lexical prefix tree for words.

    Args:
        word_dict: an instance of :class:`fairseq.data.AsrDictionary`.
        subword_dict: an instance of :class:`fairseq.data.AsrDictionary`.
        subword_tokenizer (callable, optional): a function that takes a word
            string as its only argument and returns a list of subwords as a
            result of tokenization. If None, a word is split into its
            individual characters.

    Returns:
        root (Node): the root of the prefix tree, where each node has the fields:
            ("children": Dict[int,Node], "word_idx": int, "word_set": Tuple[int]).
            "children" is subword_idx -> node, and "word_set" is (first-1, last),
            where [first, last] is the range of the word indexes (inclusive) in
            the word dictionary who share the same prefix at that node.
            "word_idx" is the index of the word that ends at that node, or -1
            if no word ends there. The root has word_idx -1 and word_set None.
            The special symbols (<pad>, <eos>, <unk>) and any word containing
            an unknown subword are left out of the tree.
            We assume words in the word dictionary are in lexical order.
    """
    class Node(object):
        def __init__(self, children=None, word_idx=-1, word_set=None):
            # Use a fresh dict per node. A mutable `{}` default would be shared
            # by every Node constructed without an explicit `children`.
            self.children = {} if children is None else children
            self.word_idx = word_idx
            self.word_set = word_set

    special_symbols = {word_dict.pad(), word_dict.eos(), word_dict.unk()}
    # Index 0 must be a special symbol. Every real word then has widx >= 1,
    # so the lower bound widx - 1 stored in word_set is never negative.
    assert 0 in special_symbols, \
        'word index 0 must be one of <pad>, <eos>, <unk>'
    subword_unk_idx = subword_dict.unk()

    root = Node({}, -1, None)
    for widx in range(len(word_dict)):
        if widx in special_symbols:  # skip <pad>, <eos>, <unk>
            continue
        word = word_dict[widx]
        # tokenize a word into a list of subwords
        subwords = (subword_tokenizer(word) if subword_tokenizer is not None
                    else list(word))
        # Look up each subword once, then reuse the indices for both the
        # OOV check and the insertion below.
        sidxs = [subword_dict.index(s) for s in subwords]
        if subword_unk_idx in sidxs:
            # skip words containing any unknown subwords
            continue

        node = root
        for sidx in sidxs:
            child = node.children.get(sidx)
            if child is None:  # make a new node
                child = Node({}, -1, (widx - 1, widx))
                node.children[sidx] = child
            else:  # widen the range of words sharing this prefix
                child.word_set = (
                    min(child.word_set[0], widx - 1),
                    max(child.word_set[1], widx),
                )
            node = child  # move to children
        if sidxs:  # the word ends at the last node visited
            node.word_idx = widx
    return root