Example #1
    def build(
            word_dict: TokenDictionary,
            subword_dict: TokenDictionary,
            subword_tokenizer: Callable[[str], List[str]] = None
    ):
        """
        Builds a tensorized lexical prefix tree for words.
        """

        root = lexical_prefix_tree(
            word_dict=word_dict,
            subword_dict=subword_dict,
            subword_tokenizer=subword_tokenizer
        )  # build traditional tree data structure by reusing existing routines

        # Perform a pre-order traversal of the tree to assign an index to each node
        max_num_children = 0
        nodes = [None]  # nodes[0] is a dummy node for OOV
        node_to_id_dict = {}
        stack = [root]

        while len(stack) > 0:
            curr = stack.pop()
            node_id = len(nodes)
            nodes.append(curr)
            node_to_id_dict[curr] = node_id
            if len(curr.children) > max_num_children:
                max_num_children = len(curr.children)

            # Push children in descending subword order so they are popped (visited) in ascending order
            for _, next_node in sorted(curr.children.items(), key=lambda t: t[0], reverse=True):
                stack.append(next_node)

        # Flatten the tree into fixed-size arrays (one row per node)
        num_nodes = len(nodes)
        children = np.full([num_nodes, max_num_children], 0, dtype=np.int64)
        prev_subword_idx = np.full([num_nodes], subword_dict.pad(), dtype=np.int64)
        word_idx = np.full([num_nodes], -1, dtype=np.int64)
        word_set_idx = np.full([num_nodes, 2], word_dict.pad(), dtype=np.int64)

        for node_id in range(1, len(nodes)):  # skip 0, which is `None`
            node = nodes[node_id]
            # Visit children in ascending order of subword index
            for i, (subword_id, child) in enumerate(sorted(node.children.items(), key=lambda t: t[0])):
                child_node_id = node_to_id_dict[child]
                children[node_id, i] = child_node_id
                prev_subword_idx[child_node_id] = subword_id

            word_idx[node_id] = node.word_idx
            if node.word_set is not None:
                word_set_idx[node_id] = node.word_set
            else:
                word_set_idx[node_id] = [0, len(word_dict) - 1]

        return TensorizedPrefixTree(
            children=torch.from_numpy(children),
            prev_subword_idx=torch.from_numpy(prev_subword_idx),
            word_idx=torch.from_numpy(word_idx),
            word_set_idx=torch.from_numpy(word_set_idx)
        )
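
The tensors returned by build() encode the whole tree: row n of children holds the child node ids of node n (zero-padded on the right), and prev_subword_idx[c] records which subword leads into child c. Below is a minimal lookup sketch under those assumptions; the step/lookup helpers are illustrative only and not part of the class, and node 1 is taken as the root because it is the first node assigned by the traversal above.

def step(tree, node_id: int, subword_id: int) -> int:
    """Advance from node_id along subword_id; return 0 (the dummy OOV node) if absent."""
    for child in tree.children[node_id].tolist():
        if child == 0:  # rows are zero-padded after the real children
            break
        if int(tree.prev_subword_idx[child]) == subword_id:
            return child
    return 0

def lookup(tree, subword_ids):
    """Follow a subword sequence from the root; a result >= 0 means it spells a full word."""
    node = 1
    for sid in subword_ids:
        node = step(tree, node, sid)
        if node == 0:
            return -1  # prefix not in the lexicon
    return int(tree.word_idx[node])
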
Example #2
    def __init__(self,
                 word_lm: FairseqLanguageModel,
                 subword_dict: TokenDictionary,
                 oov_penalty: float = 1e-4,
                 open_vocab: bool = True):
        super().__init__(word_lm.decoder.dictionary)

        self.lm_decoder: FairseqIncrementalDecoder = word_lm.decoder
        assert hasattr(self.lm_decoder, 'masked_copy_incremental_state') and \
            callable(self.lm_decoder.masked_copy_incremental_state), \
            'The wrapped decoder should implement masked_copy_incremental_state()'

        self.oov_penalty = oov_penalty
        self.open_vocab = open_vocab
        self.zero = 1e-10  # a sufficiently small value to avoid the log(0) issue

        word_dict: TokenDictionary = self.lm_decoder.dictionary
        self.word_pad_idx = word_dict.pad()
        self.word_eos_idx = word_dict.eos()
        self.word_unk_idx = word_dict.unk()

        self.subword_space_idx = subword_dict.space()
        self.subword_pad_idx = subword_dict.pad()
        self.subword_eos_idx = subword_dict.eos()
        self.subword_vocab_size = len(subword_dict)

        tokenizer: Callable[[str], List[str]] = \
            lambda x: tokenize(x, non_lang_syms=subword_dict.non_lang_syms).split(' ')
        self.tree = TensorizedPrefixTree.build(word_dict, subword_dict,
                                               tokenizer)

        assert self.tree.max_out_degree() <= self.subword_vocab_size
Example #3
 def make_dictionary():
     """construct dictionary."""
     d = TokenDictionary()
     alphabet = string.ascii_lowercase
     for token in alphabet:
         d.add_symbol(token)
     d.add_symbol('<space>')
     d.finalize(padding_factor=1)  # don't add extra padding symbols
     d.space_index = d.indices.get('<space>', -1)
     return d
Example #4
 def load_dictionary(cls, filename, non_lang_syms=None):
     """Load the dictionary from the filename
     Args:
         filename (str): the filename
         non_lang_syms (str): non_lang_syms filename
     """
     return TokenDictionary.load(filename, f_non_lang_syms=non_lang_syms)
Example #5
    def load_dictionary(cls, filename):
        """Load the dictionary from the filename

        Args:
            filename (str): the filename
        """
        return TokenDictionary.load(filename)
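
A minimal usage sketch, assuming TokenDictionary.load() follows the fairseq dictionary file convention of one "<symbol> <count>" pair per line; the file name and counts below are made up.

with open('dict.txt', 'w', encoding='utf-8') as f:
    f.write('a 104\nb 87\n<space> 512\n')

d = TokenDictionary.load('dict.txt')  # what load_dictionary() wraps
print(len(d))  # vocabulary size, special symbols included
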
Example #6
def lexical_prefix_tree(word_dict: TokenDictionary,
                        subword_dict: TokenDictionary,
                        subword_tokenizer: Callable[[str], List[str]] = None):
    """Build a lexical prefix tree for words.

    Args:
        word_dict: an instance of :class:`fairseq.data.TokenDictionary`.
        subword_dict: an instance of :class:`fairseq.data.TokenDictionary`.
        subword_tokenizer (callable): a function that takes a word string as its
            only argument and returns a list of subwords obtained by tokenizing
            the word. If None, the word is split into individual characters.

    Returns:
        root (Node): the root of the prefix tree, where each node has the fields:
            ('children': Dict[int,Node], 'word_idx': int, 'word_set': Tuple[int]).
            'children' maps subword_idx -> node, and 'word_set' is (first-1, last),
            where [first, last] is the inclusive range of word indexes in
            the word dictionary that share the same prefix at that node.
            We assume words in the word dictionary are in lexical order.
    """
    class Node(object):
        def __init__(self, children=None, word_idx=-1, word_set=None):
            # use None instead of a mutable default dict shared across instances
            self.children = children if children is not None else {}
            self.word_idx = word_idx
            self.word_set = word_set

    special_symbols = [word_dict.pad(), word_dict.eos(), word_dict.unk()]
    assert 0 in special_symbols  # to ensure widx - 1 >= 0
    root = Node({}, -1, None)
    for widx in range(len(word_dict)):
        if widx not in special_symbols:  # skip <pad>, <eos>, <unk>
            # tokenize a word into a list of subwords
            subwords = subword_tokenizer(word_dict[widx]) \
                if subword_tokenizer is not None else list(word_dict[widx])
            if any(
                    subword_dict.index(s) == subword_dict.unk()
                    for s in subwords):
                # skip words containing any unknown subwords
                continue
            children = root.children
            for i, s in enumerate(subwords):
                sidx = subword_dict.index(s)
                if sidx not in children:  # make a new node
                    children[sidx] = Node({}, -1, (widx - 1, widx))
                else:
                    children[sidx].word_set = (min(children[sidx].word_set[0],
                                                   widx - 1),
                                               max(children[sidx].word_set[1],
                                                   widx))
                if i == len(subwords) - 1:  # if word end, set word_idx
                    children[sidx].word_idx = widx
                children = children[sidx].children  # move to children
    return root
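
A toy illustration of the returned structure, assuming the TokenDictionary API used elsewhere on this page and the default character-level tokenization; the word list is made up.

subword_dict = TokenDictionary()
for ch in 'abc':
    subword_dict.add_symbol(ch)
subword_dict.finalize(padding_factor=1)

word_dict = TokenDictionary()
for w in ['ab', 'ac', 'b']:  # already in lexical order
    word_dict.add_symbol(w)
word_dict.finalize(padding_factor=1)

root = lexical_prefix_tree(word_dict, subword_dict)

# 'ab' and 'ac' share the prefix 'a', so the node reached via subword 'a'
# spans both of their word indexes in word_set = (first - 1, last), while
# its word_idx stays -1 because 'a' by itself is not a word.
a_node = root.children[subword_dict.index('a')]
print(a_node.word_idx, a_node.word_set)
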
Example #7
    def setup_task(cls, args, **kwargs):
        """Setup the task (e.g., load dictionaries).

        Args:
            args (argparse.Namespace): parsed command-line arguments
        """
        dictionary = None
        output_dictionary = None
        if args.data:
            paths = args.data.split(":")
            assert len(paths) > 0
            dict_path = os.path.join(paths[0], "dict.txt") if args.dict is None \
                else args.dict
            dictionary = TokenDictionary.load(dict_path)
            print("| dictionary: {} types".format(len(dictionary)))
            output_dictionary = dictionary
            if args.output_dictionary_size >= 0:
                output_dictionary = TruncatedDictionary(
                    dictionary, args.output_dictionary_size
                )

        # upgrade old checkpoints
        if hasattr(args, "exclude_self_target"):
            args.self_target = not args.exclude_self_target

        targets = []
        if getattr(args, "self_target", False):
            targets.append("self")
        if getattr(args, "future_target", False):
            targets.append("future")
        if getattr(args, "past_target", False):
            targets.append("past")
        if len(targets) == 0:
            # standard language modeling
            targets = ["future"]

        return cls(args, dictionary, output_dictionary, targets=targets)
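
A sketch of the argparse attributes the method reads; the task class name and the data path are placeholders, not real fairseq/espresso names.

import argparse

args = argparse.Namespace(
    data='/path/to/data-bin',   # colon-separated list of data directories
    dict=None,                  # None falls back to <data>/dict.txt
    output_dictionary_size=-1,  # negative keeps the full output dictionary
    self_target=False,
    future_target=False,
    past_target=False,          # all False -> targets defaults to ['future']
)
task = MyLanguageModelingTask.setup_task(args)  # hypothetical task class
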
Example #8
    def build_dictionary(cls, filenames, workers=1, threshold=-1, nwords=-1, padding_factor=8):
        """Build the dictionary

        Args:
            filenames (list): list of filenames
            workers (int): number of concurrent workers
            threshold (int): defines the minimum word count
            nwords (int): defines the total number of words in the final dictionary,
                including special symbols
            padding_factor (int): can be used to pad the dictionary size to be a
                multiple of 8, which is important on some hardware (e.g., Nvidia
                Tensor Cores).
        """
        d = TokenDictionary()
        for filename in filenames:
            TokenDictionary.add_file_to_dictionary(filename, d, tokenizer.tokenize_line, workers)
        d.finalize(threshold=threshold, nwords=nwords, padding_factor=padding_factor)
        return d
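
A usage sketch, assuming the classmethod lives on a task class (MyTask is a placeholder) and the filenames point to plain-text corpora.

d = MyTask.build_dictionary(
    ['train.txt', 'valid.txt'],
    workers=4,
    threshold=2,       # drop symbols seen fewer than 2 times
    padding_factor=8,  # pad the vocabulary size to a multiple of 8
)
d.save('dict.txt')     # serialized as fairseq-style "<symbol> <count>" lines
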
Example #9
 def make_dictionary(vocab, non_lang_syms=None):
     """construct dictionary."""
     non_lang_syms = non_lang_syms if non_lang_syms is not None else []  # avoid mutable default
     assert isinstance(vocab, list) and isinstance(non_lang_syms, list)
     d = TokenDictionary()
     for token in vocab:
         d.add_symbol(token)
     d.add_symbol('<space>')
     for token in non_lang_syms:
         d.add_symbol(token)
     d.finalize(padding_factor=1)  # don't add extra padding symbols
     d.space_index = d.indices.get('<space>', -1)
     return d
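
A usage sketch with toy values; the non-linguistic symbols are made up.

import string

vocab = list(string.ascii_lowercase)
d = make_dictionary(vocab, non_lang_syms=['<noise>', '<spnoise>'])

assert d.space_index == d.index('<space>')
print(len(d), d.index('<noise>'))
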