Ejemplo n.º 1
0
 def load_dictionary(cls, filename, non_lang_syms=None):
     """Build a dictionary by delegating to ``TokenDictionary.load``.

     Args:
         filename (str): path of the dictionary file to load
         non_lang_syms (str): filename of the non_lang_syms list; forwarded
             to the loader as ``f_non_lang_syms`` (default: None)

     Returns:
         the object returned by ``TokenDictionary.load``
     """
     token_dict = TokenDictionary.load(
         filename,
         f_non_lang_syms=non_lang_syms,
     )
     return token_dict
Ejemplo n.º 2
0
    def load_dictionary(cls, filename):
        """Build a dictionary by delegating to ``TokenDictionary.load``.

        Args:
            filename (str): path of the dictionary file to load

        Returns:
            the object returned by ``TokenDictionary.load``
        """
        token_dict = TokenDictionary.load(filename)
        return token_dict
Ejemplo n.º 3
0
    def setup_task(cls, args, **kwargs):
        """Create the task from parsed arguments, loading dictionaries if needed.

        When ``args.data`` is set, the dictionary is loaded from ``args.dict``
        if that is not None, otherwise from ``dict.txt`` inside the first
        ``:``-separated entry of ``args.data``. The output dictionary is the
        same dictionary, wrapped in ``TruncatedDictionary`` when
        ``args.output_dictionary_size`` is non-negative. When ``args.data`` is
        empty, both dictionaries stay ``None``.

        Args:
            args (argparse.Namespace): parsed command-line arguments; note
                that ``args.self_target`` is overwritten when the legacy
                ``exclude_self_target`` attribute is present

        Returns:
            the result of ``cls(args, dictionary, output_dictionary,
            targets=targets)``
        """
        vocab = out_vocab = None
        if args.data:
            data_dirs = args.data.split(":")
            assert len(data_dirs) > 0
            if args.dict is None:
                vocab_file = os.path.join(data_dirs[0], "dict.txt")
            else:
                vocab_file = args.dict
            vocab = TokenDictionary.load(vocab_file)
            print("| dictionary: {} types".format(len(vocab)))
            if args.output_dictionary_size >= 0:
                out_vocab = TruncatedDictionary(
                    vocab, args.output_dictionary_size
                )
            else:
                out_vocab = vocab

        # Old checkpoints stored the inverse flag; translate it in place.
        if hasattr(args, "exclude_self_target"):
            args.self_target = not args.exclude_self_target

        # (argument flag, target name) pairs, in the order targets are emitted.
        flag_to_target = (
            ("self_target", "self"),
            ("future_target", "future"),
            ("past_target", "past"),
        )
        targets = [
            target
            for flag, target in flag_to_target
            if getattr(args, flag, False)
        ]
        if not targets:
            # No target requested: fall back to standard language modeling.
            targets = ["future"]

        return cls(args, vocab, out_vocab, targets=targets)