示例#1
0
    def load_dictionary(cls, filename):
        """Read a dictionary file and hand back the loaded object.

        Args:
            filename (str): path of the dictionary file to read.

        Returns:
            The object produced by ``Dictionary.load(filename)``.
        """
        loaded = Dictionary.load(filename)
        return loaded
示例#2
0
 def setup_task(cls, args, **kwargs):
     """Setup the task.

     Loads ``dict.txt`` from the first path in ``args.data`` and builds the
     task around the resulting dictionary.

     Args:
         args: task arguments. ``args.data`` is split into paths with
             ``utils.split_paths``. ``args.shuffle_instance`` is set to
             ``False`` when the attribute is missing.
         **kwargs: accepted for interface compatibility; not used here.

     Returns:
         ``cls(args, dictionary)``.

     Raises:
         ValueError: if ``args.data`` yields no paths.
     """
     # args.data may list several data paths. Split it the same way the
     # sibling tasks do and read the dictionary from the first one; joining
     # the raw multi-path string with "dict.txt" would give a bogus path.
     paths = utils.split_paths(args.data)
     if not paths:
         # Explicit check instead of `assert`, which is stripped under -O.
         raise ValueError("args.data must contain at least one data path")
     dictionary = Dictionary.load(os.path.join(paths[0], "dict.txt"))
     logger.info("dictionary: {} types".format(len(dictionary)))
     # Default only when the caller's args lack the attribute entirely.
     if not hasattr(args, "shuffle_instance"):
         args.shuffle_instance = False
     return cls(args, dictionary)
示例#3
0
    def load_dictionary(cls, args, filename, source=True):
        """Load a dictionary from *filename* and register a ``<mask>`` symbol.

        Args:
            args: task arguments; not read by this implementation.
            filename (str): path of the dictionary file to read.
            source (bool): not read by this implementation.

        Returns:
            The loaded dictionary after ``add_symbol("<mask>")`` was called on it.
        """
        vocab = Dictionary.load(filename)
        # Register the mask token on top of whatever the file provided.
        vocab.add_symbol("<mask>")
        return vocab
示例#4
0
 def setup_dictionary(cls, args, **kwargs):
     """Build the input and output dictionaries from ``args.data``.

     Returns:
         A ``(dictionary, output_dictionary)`` tuple. Both entries are
         ``None`` when ``args.data`` is falsy. Otherwise ``dictionary`` is
         loaded from ``dict.txt`` in the first data path. ``output_dictionary``
         is that same object, or a ``TruncatedDictionary`` wrapping it when
         ``args.output_dictionary_size >= 0``.
     """
     if not args.data:
         return (None, None)

     data_paths = utils.split_paths(args.data)
     assert len(data_paths) > 0
     dictionary = Dictionary.load(os.path.join(data_paths[0], "dict.txt"))
     logger.info("dictionary: {} types".format(len(dictionary)))

     truncate_to = args.output_dictionary_size
     if truncate_to >= 0:
         output_dictionary = TruncatedDictionary(dictionary, truncate_to)
     else:
         output_dictionary = dictionary
     return (dictionary, output_dictionary)
示例#5
0
    def setup_task(cls, args, **kwargs):
        """Create the task from its data config and target vocabulary.

        Reads ``args.config_yaml`` under ``args.data`` as an
        ``S2TDataConfig`` and loads the target dictionary named by its
        ``vocab_filename``. When ``args.train_subset`` is set, every
        comma-separated entry must start with ``"train"``.

        Raises:
            FileNotFoundError: if the vocabulary file does not exist.
            ValueError: if a train split is not named ``train*``.
        """
        data_cfg = S2TDataConfig(op.join(args.data, args.config_yaml))
        dict_path = op.join(args.data, data_cfg.vocab_filename)
        if not op.isfile(dict_path):
            raise FileNotFoundError(f"Dict not found: {dict_path}")
        tgt_dict = Dictionary.load(dict_path)
        logger.info(
            f"dictionary size ({data_cfg.vocab_filename}): {len(tgt_dict):,}"
        )

        train_subset = getattr(args, "train_subset", None)
        if train_subset is not None:
            split_names = train_subset.split(",")
            if any(not name.startswith("train") for name in split_names):
                raise ValueError('Train splits should be named like "train*".')
        return cls(args, tgt_dict)
示例#6
0
 def setup_task(cls, args, **kwargs):
     """Instantiate the task with the dictionary from the first data path.

     Args:
         args: task arguments. ``args.data`` is split with
             ``utils.split_paths``, and ``dict.txt`` is read from the first
             resulting path.
         **kwargs: not used here.

     Returns:
         ``cls(args, dictionary)``.
     """
     data_dirs = utils.split_paths(args.data)
     assert len(data_dirs) > 0
     dict_file = os.path.join(data_dirs[0], "dict.txt")
     dictionary = Dictionary.load(dict_file)
     logger.info("dictionary: {} types".format(len(dictionary)))
     return cls(args, dictionary)
示例#7
0
 def load_target_dictionary(self):
     """Return the label dictionary selected by ``self.cfg.labels``, if any.

     Returns:
         ``Dictionary.load`` of ``<cfg.data>/dict.<cfg.labels>.txt`` when
         ``self.cfg.labels`` is truthy, otherwise ``None``.
     """
     labels = self.cfg.labels
     if not labels:
         return None
     label_dict_path = os.path.join(self.cfg.data, f"dict.{labels}.txt")
     return Dictionary.load(label_dict_path)