def save_state(
    filename,
    args,
    model_state_dict,
    criterion,
    optimizer,
    lr_scheduler,
    num_updates,
    optim_history=None,
    extra_state=None,
):
    from fairseq import utils

    if optim_history is None:
        optim_history = []
    if extra_state is None:
        extra_state = {}
    state_dict = {
        "args": args,
        "model": model_state_dict if model_state_dict else {},
        "optimizer_history": optim_history + [
            {
                "criterion_name": criterion.__class__.__name__,
                "optimizer_name": optimizer.__class__.__name__,
                "lr_scheduler_state": lr_scheduler.state_dict(),
                "num_updates": num_updates,
            }
        ],
        "extra_state": extra_state,
    }
    if utils.has_parameters(criterion):
        state_dict["criterion"] = criterion.state_dict()
    if not args.no_save_optimizer_state:
        state_dict["last_optimizer_state"] = convert_state_dict_type(
            optimizer.state_dict()
        )
    with PathManager.open(filename, "wb") as f:
        torch_persistent_save(state_dict, f)
def __init__(self, args):
    super().__init__(args)
    self.eos = DEFAULT_EOS
    self.gpu = getattr(args, "gpu", False)
    self.args = args
    self.load_model_vocab(args)

    if (
        getattr(self.model.decoder.layers[0].encoder_attn, 'pre_decision_ratio', None)
        is not None
    ):
        self.speech_segment_size *= (
            self.model.decoder.layers[0].encoder_attn.pre_decision_ratio
        )

    args.global_cmvn = None
    if args.config:
        with open(os.path.join(args.data_bin, args.config), "r") as f:
            config = yaml.load(f, Loader=yaml.BaseLoader)
        if "global_cmvn" in config:
            args.global_cmvn = np.load(config["global_cmvn"]["stats_npz_path"])

    if args.global_stats:
        with PathManager.open(args.global_stats, "r") as f:
            global_cmvn = json.loads(f.read())
            self.global_cmvn = {
                "mean": global_cmvn["mean"],
                "std": global_cmvn["stddev"],
            }

    self.feature_extractor = OnlineFeatureExtractor(args)

    self.max_len = args.max_len
    self.force_finish = args.force_finish

    torch.set_grad_enabled(False)
def load(cls, f, f_non_lang_syms=None):
    """Loads the dictionary from a text file with the format:

    ```
    <symbol0> <count0>
    <symbol1> <count1>
    ...
    ```

    Identifies the space symbol if it exists, by obtaining its index
    (space_index=-1 if no space symbol).

    Loads non_lang_syms from another text file, if it exists, with one
    symbol per line.
    """
    d = super().load(f)
    d.space_index = d.indices.get(d.space_word, -1)

    if f_non_lang_syms is not None:
        assert isinstance(f_non_lang_syms, str)
        try:
            with PathManager.open(f_non_lang_syms, "r", encoding="utf-8") as fd:
                non_lang_syms = [x.rstrip() for x in fd.readlines()]
        except FileNotFoundError as fnfe:
            raise fnfe
        except UnicodeError:
            raise Exception(
                "Incorrect encoding detected in {}, please "
                "rebuild the dataset".format(f_non_lang_syms)
            )

        for sym in non_lang_syms:
            assert d.index(sym) != d.unk(), \
                "{} in {} is not in the dictionary".format(sym, f_non_lang_syms)
        d.non_lang_syms = non_lang_syms

    return d
def save_state(
    filename,
    cfg: FairseqConfig,
    model_state_dict,
    criterion,
    optimizer,
    lr_scheduler,
    num_updates,
    optim_history=None,
    extra_state=None,
    **kwargs,
):
    from fairseq import utils

    if optim_history is None:
        optim_history = []
    if extra_state is None:
        extra_state = {}
    state_dict = {
        "cfg": cfg,
        "args": kwargs.get("args", None),
        "model": model_state_dict or {},
        "optimizer_history": optim_history + [
            {
                "criterion_name": criterion.__class__.__name__,
                "optimizer_name": optimizer.__class__.__name__,
                "lr_scheduler_state": lr_scheduler.state_dict(),
                "num_updates": num_updates,
            }
        ],
        "extra_state": extra_state,
    }
    if utils.has_parameters(criterion):
        state_dict["criterion"] = criterion.state_dict()

    if cfg is None:
        cfg = state_dict["args"]
        assert cfg is not None, "must provide cfg or args"

    if isinstance(cfg, DictConfig):
        no_save_optimizer_state = cfg.checkpoint.no_save_optimizer_state
    else:
        no_save_optimizer_state = cfg.no_save_optimizer_state
    if not no_save_optimizer_state:
        state_dict["last_optimizer_state"] = optimizer.state_dict()

    # keep everything on CPU
    state_dict = utils.move_to_cpu(state_dict)

    if PathManager.supports_rename(filename):
        # do atomic save
        with PathManager.open(filename + ".tmp", "wb") as f:
            torch_persistent_save(state_dict, f)
        PathManager.rename(filename + ".tmp", filename)
    else:
        # fallback to non-atomic save
        with PathManager.open(filename, "wb") as f:
            torch_persistent_save(state_dict, f)
def main(): parser = argparse.ArgumentParser( description="Tool to average the params of input checkpoints to " "produce a new checkpoint", ) # fmt: off parser.add_argument('--inputs', required=True, nargs='+', help='Input checkpoint file paths.') parser.add_argument( '--output', required=True, metavar='FILE', help= 'Write the new checkpoint containing the averaged weights to this path.' ) num_group = parser.add_mutually_exclusive_group() num_group.add_argument( '--num-epoch-checkpoints', type=int, help= 'if set, will try to find checkpoints with names checkpoint_xx.pt in the path specified by input, ' 'and average last this many of them.') num_group.add_argument( '--num-update-checkpoints', type=int, help= 'if set, will try to find checkpoints with names checkpoint.jj_ee_xx.pt in the path specified by input, ' 'and average last this many of them.') parser.add_argument( '--checkpoint-upper-bound', type=int, help= 'when using --num-epoch-checkpoints, this will set an upper bound on which epoch to use, ' 'when using --num-update-checkpoints, this will set an upper bound on which update to use' 'e.g., with --num-epoch-checkpoints=10 --checkpoint-upper-bound=50, checkpoints 41-50 would be averaged.' 'e.g., with --num-update-checkpoints=10 --checkpoint-upper-bound=50000, checkpoints 40500-50000 would be averaged assuming --save-interval-updates 500' ) # fmt: on args = parser.parse_args() print(args) num = None is_update_based = False if args.num_update_checkpoints is not None: num = args.num_update_checkpoints is_update_based = True elif args.num_epoch_checkpoints is not None: num = args.num_epoch_checkpoints assert args.checkpoint_upper_bound is None or ( args.num_epoch_checkpoints is not None or args.num_update_checkpoints is not None ), "--checkpoint-upper-bound requires --num-epoch-checkpoints or --num-update-checkpoints" assert ( args.num_epoch_checkpoints is None or args.num_update_checkpoints is None ), "Cannot combine --num-epoch-checkpoints and --num-update-checkpoints" if num is not None: args.inputs = last_n_checkpoints( args.inputs, num, is_update_based, upper_bound=args.checkpoint_upper_bound, ) print("averaging checkpoints: ", args.inputs) new_state = average_checkpoints(args.inputs) with PathManager.open(args.output, "wb") as f: torch.save(new_state, f) print("Finished writing averaged checkpoint to {}".format(args.output))
def test_file_io(self):
    from fairseq.file_io import PathManager

    with PathManager.open(os.path.join(self._tmpdir, "test.txt"), "r") as f:
        s = f.read()
    self.assertEqual(s, self._tmpfile_contents)
def add_from_file(self, f, tgt_first=False, bos_id_tgt=None, pad_id_tgt=None,
                  eos_id_tgt=None, unk_id_tgt=None):
    """
    Loads a pre-existing dictionary from a text file and adds its symbols
    to this instance.
    """
    if isinstance(f, str):
        try:
            with PathManager.open(f, "r", encoding="utf-8") as fd:
                self.add_from_file(fd, tgt_first=tgt_first,
                                   bos_id_tgt=bos_id_tgt, pad_id_tgt=pad_id_tgt,
                                   eos_id_tgt=eos_id_tgt, unk_id_tgt=unk_id_tgt)
        except FileNotFoundError as fnfe:
            raise fnfe
        except UnicodeError:
            raise Exception("Incorrect encoding detected in {}, please "
                            "rebuild the dataset".format(f))
        return

    lines = f.readlines()
    indices_start_line = self._load_meta(lines)

    if tgt_first:
        i = 0
        for line in lines[indices_start_line:]:
            while i in [bos_id_tgt, pad_id_tgt, eos_id_tgt, unk_id_tgt]:
                if i == bos_id_tgt:
                    self.add_symbol(self.bos_word)
                    i += 1
                elif i == pad_id_tgt:
                    self.add_symbol(self.pad_word)
                    i += 1
                elif i == eos_id_tgt:
                    self.add_symbol(self.eos_word)
                    i += 1
                elif i == unk_id_tgt:
                    self.add_symbol(self.unk_word)
                    i += 1
            try:
                line, field = line.rstrip().rsplit(" ", 1)
                if field == "#fairseq:overwrite":
                    overwrite = True
                    line, field = line.rsplit(" ", 1)
                else:
                    overwrite = False
                count = int(field)
                word = line
                # if word in self and not overwrite:
                #     raise RuntimeError(
                #         "Duplicate word found when loading Dictionary: '{}'. "
                #         "Duplicate words can overwrite earlier ones by adding the "
                #         "#fairseq:overwrite flag at the end of the corresponding row "
                #         "in the dictionary file. If using the Camembert model, please "
                #         "download an updated copy of the model file."
                #         .format(word)
                #     )
                self.add_symbol(word, n=count, overwrite=overwrite)
                i += 1
            except ValueError:
                raise ValueError(
                    "Incorrect dictionary format, expected '<token> <cnt> [flags]'"
                )
    else:
        for line in lines[indices_start_line:]:
            try:
                line, field = line.rstrip().rsplit(" ", 1)
                if field == "#fairseq:overwrite":
                    overwrite = True
                    line, field = line.rsplit(" ", 1)
                else:
                    overwrite = False
                count = int(field)
                word = line
                if word in self and not overwrite:
                    raise RuntimeError(
                        "Duplicate word found when loading Dictionary: '{}'. "
                        "Duplicate words can overwrite earlier ones by adding the "
                        "#fairseq:overwrite flag at the end of the corresponding row "
                        "in the dictionary file. If using the Camembert model, please "
                        "download an updated copy of the model file.".format(word)
                    )
                self.add_symbol(word, n=count, overwrite=overwrite)
            except ValueError:
                raise ValueError(
                    "Incorrect dictionary format, expected '<token> <cnt> [flags]'"
                )
def main():
    parser = argparse.ArgumentParser(
        description='Tool to average the params of input checkpoints to '
        'produce a new checkpoint')
    # fmt: off
    parser.add_argument('--input-directory', type=str, required=True,
                        help='Input directory containing model checkpoints.')
    parser.add_argument('--output', default=None, type=str,
                        help='Write the new checkpoint containing the averaged weights to this path.')
    num_group = parser.add_mutually_exclusive_group()
    num_group.add_argument('--num-epoch-checkpoints', type=int,
                           help='if set, will try to find checkpoints with names checkpoint_xx.pt in the '
                                'path specified by input, and average last this many of them.')
    num_group.add_argument('--num-update-checkpoints', type=int,
                           help='if set, will try to find checkpoints with names checkpoint_ee_xx.pt in the '
                                'path specified by input, and average last this many of them.')
    parser.add_argument('--checkpoint-upper-bound', type=int,
                        help='when using --num-epoch-checkpoints, this will set an upper bound on which epoch to use; '
                             'when using --num-update-checkpoints, this will set an upper bound on which update to use. '
                             'e.g., with --num-epoch-checkpoints=10 --checkpoint-upper-bound=50, checkpoints 41-50 would be averaged; '
                             'e.g., with --num-update-checkpoints=10 --checkpoint-upper-bound=50000, checkpoints 40500-50000 would '
                             'be averaged assuming --save-interval-updates 500')
    # fmt: on
    args = parser.parse_args()
    print(args)

    num = None
    is_update_based = False
    if args.num_update_checkpoints is not None:
        num = args.num_update_checkpoints
        is_update_based = True
    elif args.num_epoch_checkpoints is not None:
        num = args.num_epoch_checkpoints

    assert args.checkpoint_upper_bound is None or \
        (args.num_epoch_checkpoints is not None or args.num_update_checkpoints is not None), \
        '--checkpoint-upper-bound requires --num-epoch-checkpoints or --num-update-checkpoints'
    assert args.num_epoch_checkpoints is None or args.num_update_checkpoints is None, \
        'Cannot combine --num-epoch-checkpoints and --num-update-checkpoints'

    if num is not None:
        args.inputs = last_n_checkpoints(
            args.input_directory,
            num,
            is_update_based,
            upper_bound=args.checkpoint_upper_bound,
        )
        print('averaging checkpoints: ', args.inputs)

    if args.output is None and args.num_epoch_checkpoints is not None:
        range_checkpoints = sorted([
            int(re.search(r".*checkpoint(.*)\.pt$", input_file).groups()[0])
            for input_file in args.inputs
        ])
        start = range_checkpoints[0]
        end = range_checkpoints[-1]
        args.output = os.path.join(
            args.input_directory, "checkpoint_average_%s_%s.pt" % (start, end))
    assert args.output is not None, "No output file specified, nor could it be deduced"

    new_state = average_checkpoints(args.inputs)
    with PathManager.open(args.output, 'wb') as f:
        torch.save(new_state, f)
    print('Finished writing averaged checkpoint to {}'.format(args.output))
def add_file_to_dictionary(filename, dict, tokenize):
    with PathManager.open(filename, "r", encoding="utf-8") as f:
        for line in f:
            for word in tokenize(line):
                dict.add_symbol(word)
            dict.add_symbol(dict.eos_word)
def main(args):
    state = checkpoint_utils.load_checkpoint_to_cpu(args.checkpoint)
    ns = state["args"]
    model = state["model"]

    ns.arch = "transformer_modular"

    if (args.encoder_attention_heads_active is None
            and args.decoder_attention_heads_active is None):
        raise ValueError(
            'Either --encoder-attention-heads-active or '
            '--decoder-attention-heads-active option must be set.')
    if args.encoder_attention_heads_active is None:
        args.encoder_attention_heads_active = args.decoder_attention_heads_active

    if args.encoder_modular_layer_indices is not None:
        ns.encoder_modular_layer_indices = "({})".format(
            args.encoder_modular_layer_indices)
        model = convert_model(model, ns, coder="encoder", att_type="self_attn")
    if args.decoder_modular_layer_indices is not None:
        ns.decoder_modular_layer_indices = "({})".format(
            args.decoder_modular_layer_indices)
        model = convert_model(model, ns, coder="decoder", att_type="self_attn")
        model = convert_model(model, ns, coder="decoder", att_type="encoder_attn")

    ctrl_enc = ModularCtrl(
        ns.encoder_embed_dim,
        ns.encoder_attention_heads,
        args.encoder_attention_heads_active,
        hidden_depth=args.ctrl_hidden_depth,
        hidden_dim=args.ctrl_hidden_dim,
        ctrl_type=args.ctrl_type)
    ns.module_ctrl_hidden_depth = args.ctrl_hidden_depth
    ns.module_ctrl_hidden_dim = args.ctrl_hidden_dim
    ns.module_ctrl_type = args.ctrl_type
    for k, v in ctrl_enc.state_dict().items():
        model["encoder.module_ctrl.{}".format(k)] = v

    if not args.share_encoder_ctrl:
        if args.decoder_attention_heads_active is None:
            raise ValueError("Missing ``decoder-attention-heads-active'' "
                             "when ``share-encoder-ctrl'' is disabled.")
        ns.share_encoder_ctrl = False
        ctrl_dec = ModularCtrl(
            ns.decoder_embed_dim,
            ns.decoder_attention_heads,
            args.decoder_attention_heads_active,
            hidden_depth=args.ctrl_hidden_depth,
            hidden_dim=args.ctrl_hidden_dim,
            ctrl_type=args.ctrl_type)
        for k, v in ctrl_dec.state_dict().items():
            model["decoder.module_ctrl.{}".format(k)] = v
    else:
        ns.share_encoder_ctrl = True

    ns.arch = "transformer_modular"
    ns.criterion = "label_smoothed_cross_entropy_modular"
    ns.task = "translation_modular"
    ns.encoder_attention_heads_active = args.encoder_attention_heads_active

    state["args"] = ns
    state["model"] = model
    for i, _ in enumerate(state["optimizer_history"]):
        state["optimizer_history"][i]["criterion_name"] = \
            'LabelSmoothedCrossEntropyModularCriterion'

    state = utils.move_to_cpu(state)
    with PathManager.open(args.save_as, "wb") as f:
        checkpoint_utils.torch_persistent_save(state, f)