def load_checkpoint(
    self,
    filename,
    reset_optimizer=False,
    reset_lr_scheduler=False,
    optimizer_overrides=None,
    reset_meters=False,
):
    """
    Load all training state from a checkpoint file.
    rank = 0 will load the checkpoint, and then broadcast it to all
    other ranks.
    """
    extra_state, self._optim_history, last_optim_state = None, [], None

    bexists = PathManager.isfile(filename)
    if bexists:
        if (
            self.data_parallel_rank == 0
            # TPUs don't support broadcast yet, so load checkpoints
            # on every worker for now
            or self.tpu
        ):
            state = checkpoint_utils.load_checkpoint_to_cpu(filename)
            last_optim_state = state.get("last_optimizer_state", None)

            # If doing zero_sharding, do not broadcast global optimizer
            # state. Later we will broadcast sharded states to each rank
            # to avoid memory from exploding.
            if (
                self.cfg.distributed_training.zero_sharding == "os"
                and "last_optimizer_state" in state
                and self.data_parallel_world_size > 1
            ):
                state["last_optimizer_state"] = "SHARDED"
        else:
            last_optim_state = None
            state = None

        if (
            self.data_parallel_world_size > 1
            # disable on TPUs until they support broadcast
            and not self.tpu
        ):
            state = distributed_utils.broadcast_object(
                state,
                src_rank=0,
                group=self.data_parallel_process_group,
                dist_device=self.device,
            )
            if self.data_parallel_rank > 0:
                last_optim_state = state.get("last_optimizer_state", None)

        # load model parameters
        try:
            self.get_model().load_state_dict(
                state["model"], strict=True, model_cfg=self.cfg.model
            )
            if utils.has_parameters(self.get_criterion()):
                self.get_criterion().load_state_dict(
                    state["criterion"], strict=True
                )
        except Exception:
            raise Exception(
                "Cannot load model parameters from checkpoint {}; "
                "please ensure that the architectures match.".format(filename)
            )

        extra_state = state["extra_state"]
        self._optim_history = state["optimizer_history"]

    if last_optim_state is not None and not reset_optimizer:
        # rebuild optimizer after loading model, since params may have changed
        self._build_optimizer()

        # only reload optimizer and lr_scheduler if they match
        last_optim = self._optim_history[-1]
        assert (
            last_optim["criterion_name"] == self.get_criterion().__class__.__name__
        ), "Criterion does not match; please reset the optimizer (--reset-optimizer)."
        assert (
            last_optim["optimizer_name"] == self.optimizer.__class__.__name__
        ), "Optimizer does not match; please reset the optimizer (--reset-optimizer)."

        if not reset_lr_scheduler:
            self.lr_scheduler.load_state_dict(last_optim["lr_scheduler_state"])

        if self.data_parallel_world_size > 1:
            last_optim_state = self.optimizer.broadcast_global_state_dict(
                last_optim_state
            )
        self.optimizer.load_state_dict(last_optim_state, optimizer_overrides)

        self.set_num_updates(last_optim["num_updates"])

    if extra_state is not None:
        epoch = extra_state["train_iterator"]["epoch"]
        logger.info(
            "loaded checkpoint {} (epoch {} @ {} updates)".format(
                filename, epoch, self.get_num_updates()
            )
        )

        if "previous_training_time" in extra_state:
            self._previous_training_time = extra_state["previous_training_time"]
            self._start_time = time.time()

        self.lr_step(epoch)

        if "metrics" in extra_state and not reset_meters:
            metrics.load_state_dict(extra_state["metrics"])

            # reset TimeMeters, since their start times don't make sense anymore
            for meter in metrics.get_meters("default"):
                if isinstance(meter, meters.TimeMeter):
                    meter.reset()
    else:
        logger.info("no existing checkpoint found {}".format(filename))

    return extra_state
def binarize(
    filename,
    dict,
    consumer,
    tokenize=tokenize_line,
    append_eos=True,
    reverse_order=False,
    offset=0,
    end=-1,
    already_numberized=False,
    avoid_tokenize=False,
) -> Dict[str, int]:
    nseq, ntok = 0, 0
    replaced = Counter()

    def replaced_consumer(word, idx):
        if idx == dict.unk_index and word != dict.unk_word:
            replaced.update([word])

    def replaced_consumer_from_pretrained(word, idx):
        if idx == dict.convert_tokens_to_ids(dict.unk_token) and word != dict.unk_token:
            replaced.update([word])

    with open(PathManager.get_local_path(filename), "r", encoding="utf-8") as f:
        f.seek(offset)
        # next(f) breaks f.tell(), hence readline() must be used
        line = safe_readline(f)
        while line:
            # f.tell() does not always give the byte position in the file;
            # sometimes it skips to a very large number. It is unlikely that
            # through a normal read we go from end bytes to end + 2**32 bytes
            # (4 GB), which makes it unlikely that the procedure breaks due to
            # the non-deterministic behavior of f.tell().
            if end > 0 and f.tell() > end and f.tell() < end + 2**32:
                break
            if already_numberized:
                id_strings = line.strip().split()
                id_list = [int(id_string) for id_string in id_strings]
                if reverse_order:
                    id_list.reverse()
                if append_eos:
                    id_list.append(dict.eos())
                ids = torch.IntTensor(id_list)
            elif isinstance(dict, BertTokenizer) and not isinstance(dict, ElectraTokenizer):
                line = line.strip()
                line = '{} {} {}'.format('[CLS]', line, '[SEP]')
                if avoid_tokenize is False:
                    tokenizedline = dict.tokenize(line)
                else:
                    tokenizedline = line.strip().split()
                words = dict.convert_tokens_to_ids(tokenizedline)
                nwords = len(words)
                ids = torch.IntTensor(nwords)
                for i, word in enumerate(words):
                    ids[i] = word
                    replaced_consumer_from_pretrained(tokenizedline[i], word)
            elif isinstance(dict, BartTokenizer):
                line = line.strip()
                if avoid_tokenize is False:
                    # extra space at the end will cause weird outputs
                    line = '{} {}{}'.format('<s>', line, '</s>')
                    tokenizedline = dict.tokenize(line)
                else:
                    line = '{} {} {}'.format('<s>', line, '</s>')
                    tokenizedline = line.strip().split()
                words = dict.convert_tokens_to_ids(tokenizedline)
                assert len(tokenizedline) == len(words)
                nwords = len(words)
                ids = torch.IntTensor(nwords)
                for i, word in enumerate(words):
                    ids[i] = word
                    replaced_consumer_from_pretrained(tokenizedline[i], word)
            elif isinstance(dict, ElectraTokenizer):
                line = line.strip()
                line = '{} {} {}'.format('[CLS]', line, '[SEP]')
                if avoid_tokenize is False:
                    tokenizedline = dict.tokenize(line)
                else:
                    tokenizedline = line.strip().split()
                words = dict.convert_tokens_to_ids(tokenizedline)
                nwords = len(words)
                ids = torch.IntTensor(nwords)
                for i, word in enumerate(words):
                    ids[i] = word
                    replaced_consumer_from_pretrained(tokenizedline[i], word)
            elif dict is None:
                line = line.strip()
                words = line.split()
                words = [int(item) for item in words]
                nwords = len(words)
                ids = torch.IntTensor(nwords)
                for i, word in enumerate(words):
                    ids[i] = word
            else:
                ids = dict.encode_line(
                    line=line,
                    line_tokenizer=tokenize,
                    add_if_not_exist=False,
                    consumer=replaced_consumer,
                    append_eos=append_eos,
                    reverse_order=reverse_order,
                )
            nseq += 1
            ntok += len(ids)
            consumer(ids)
            line = f.readline()
    return {
        "nseq": nseq,
        "nunk": sum(replaced.values()),
        "ntok": ntok,
        "replaced": replaced,
    }
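# Illustrative only: a minimal sketch (not part of the original sources) of how
# binarize() is typically driven, assuming fairseq's Dictionary and
# indexed_dataset builder APIs. The file names used here are hypothetical.
def _example_binarize_usage():
    from fairseq.data import Dictionary, indexed_dataset

    vocab = Dictionary.load("dict.en.txt")  # hypothetical dictionary file
    builder = indexed_dataset.make_builder(
        "train.en.bin", impl="mmap", vocab_size=len(vocab)
    )
    # each IntTensor of token ids produced by binarize() is appended to the dataset
    stats = binarize("train.en", vocab, consumer=builder.add_item, append_eos=True)
    builder.finalize("train.en.idx")
    print("nseq={nseq} ntok={ntok} nunk={nunk}".format(**stats))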
def save_state(
    filename,
    cfg: FairseqConfig,
    model_state_dict,
    criterion,
    optimizer,
    lr_scheduler,
    num_updates,
    optim_history=None,
    extra_state=None,
    task=None,
    **kwargs,
):
    from fairseq import utils

    if optim_history is None:
        optim_history = []
    if extra_state is None:
        extra_state = {}
    state_dict = {
        "cfg": cfg,
        "args": kwargs.get("args", None),
        "model": model_state_dict or {},
        "optimizer_history": optim_history
        + [
            {
                "criterion_name": criterion.__class__.__name__,
                "optimizer_name": optimizer.__class__.__name__,
                "lr_scheduler_state": lr_scheduler.state_dict(),
                "num_updates": num_updates,
            }
        ],
        "extra_state": extra_state,
        "task_state": task.state_dict() if task is not None else {},
    }
    if utils.has_parameters(criterion):
        state_dict["criterion"] = criterion.state_dict()

    if cfg is None:
        cfg = state_dict["args"]
        assert cfg is not None, "must provide cfg or args"

    if isinstance(cfg, DictConfig):
        no_save_optimizer_state = cfg.checkpoint.no_save_optimizer_state
    else:
        no_save_optimizer_state = cfg.no_save_optimizer_state
    if not no_save_optimizer_state:
        state_dict["last_optimizer_state"] = optimizer.state_dict()

    # keep everything on CPU
    state_dict = utils.move_to_cpu(state_dict)

    if PathManager.supports_rename(filename):
        # do atomic save
        with PathManager.open(filename + ".tmp", "wb") as f:
            torch_persistent_save(state_dict, f)
        PathManager.rename(filename + ".tmp", filename)
    else:
        # fallback to non-atomic save
        with PathManager.open(filename, "wb") as f:
            torch_persistent_save(state_dict, f)
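# Illustrative only: a quick sketch (not in the original sources) of the layout
# save_state() writes, handy when inspecting checkpoints by hand. The path is
# hypothetical; torch.load mirrors what load_checkpoint_to_cpu() does internally.
def _example_inspect_checkpoint(path="checkpoint_last.pt"):
    import torch

    state = torch.load(path, map_location="cpu")
    # top-level keys written by save_state(): cfg/args, model, optimizer_history,
    # extra_state, task_state and (unless disabled) last_optimizer_state
    print(sorted(state.keys()))
    print(state["optimizer_history"][-1]["num_updates"])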
def load_checkpoint(
    self,
    filename,
    reset_optimizer=False,
    reset_lr_scheduler=False,
    optimizer_overrides=None,
    reset_meters=False,
):
    """Load all training state from a checkpoint file."""
    extra_state, self._optim_history, last_optim_state = None, [], None

    bexists = PathManager.isfile(filename)
    if bexists:
        state = checkpoint_utils.load_checkpoint_to_cpu(filename)

        # load model parameters
        try:
            self.get_model().load_state_dict(
                state["model"], strict=True, args=self.args
            )
            if utils.has_parameters(self.get_criterion()):
                self.get_criterion().load_state_dict(
                    state["criterion"], strict=True
                )
        except Exception:
            raise Exception(
                "Cannot load model parameters from checkpoint {}; "
                "please ensure that the architectures match.".format(filename)
            )

        extra_state = state["extra_state"]
        self._optim_history = state["optimizer_history"]
        last_optim_state = state.get("last_optimizer_state", None)

    if last_optim_state is not None and not reset_optimizer:
        # rebuild optimizer after loading model, since params may have changed
        self._build_optimizer()

        # only reload optimizer and lr_scheduler if they match
        last_optim = self._optim_history[-1]
        assert (
            last_optim["criterion_name"] == self.get_criterion().__class__.__name__
        ), "Criterion does not match; please reset the optimizer (--reset-optimizer)."
        assert (
            last_optim["optimizer_name"] == self.optimizer.__class__.__name__
        ), "Optimizer does not match; please reset the optimizer (--reset-optimizer)."

        if not reset_lr_scheduler:
            self.lr_scheduler.load_state_dict(last_optim["lr_scheduler_state"])
        self.optimizer.load_state_dict(last_optim_state, optimizer_overrides)

        self.set_num_updates(last_optim["num_updates"])

    if extra_state is not None:
        epoch = extra_state["train_iterator"]["epoch"]
        logger.info(
            "loaded checkpoint {} (epoch {} @ {} updates)".format(
                filename, epoch, self.get_num_updates()
            )
        )

        if "previous_training_time" in extra_state:
            self._previous_training_time = extra_state["previous_training_time"]
            self._start_time = time.time()

        self.lr_step(epoch)

        if "metrics" in extra_state and not reset_meters:
            metrics.load_state_dict(extra_state["metrics"])

            # reset TimeMeters, since their start times don't make sense anymore
            for meter in metrics.get_meters("default"):
                if isinstance(meter, meters.TimeMeter):
                    meter.reset()
    else:
        logger.info("no existing checkpoint found {}".format(filename))

    return extra_state
def test_file_io(self):
    from fairseq.file_io import PathManager

    with PathManager.open(os.path.join(self._tmpdir, "test.txt"), "r") as f:
        s = f.read()
    self.assertEqual(s, self._tmpfile_contents)
def save_checkpoint(cfg: CheckpointConfig, trainer, epoch_itr, val_loss):
    from fairseq import meters

    # only one worker should attempt to create the required dir
    if cfg.distributed_rank == 0:
        os.makedirs(cfg.save_dir, exist_ok=True)

    prev_best = getattr(save_checkpoint, "best", val_loss)
    if val_loss is not None:
        best_function = max if cfg.maximize_best_checkpoint_metric else min
        save_checkpoint.best = best_function(val_loss, prev_best)

    if cfg.no_save:
        return

    trainer.consolidate_optimizer()

    if not trainer.is_data_parallel_master:
        return

    write_timer = meters.StopwatchMeter()
    write_timer.start()

    epoch = epoch_itr.epoch
    end_of_epoch = epoch_itr.end_of_epoch()
    updates = trainer.get_num_updates()

    logger.info(f"Preparing to save checkpoint for epoch {epoch} @ {updates} updates")

    def is_better(a, b):
        return a >= b if cfg.maximize_best_checkpoint_metric else a <= b

    suffix = cfg.checkpoint_suffix or ""
    checkpoint_conds = collections.OrderedDict()
    checkpoint_conds["checkpoint{}{}.pt".format(epoch, suffix)] = (
        end_of_epoch and not cfg.no_epoch_checkpoints and epoch % cfg.save_interval == 0
    )
    checkpoint_conds["checkpoint_{}_{}{}.pt".format(epoch, updates, suffix)] = (
        not end_of_epoch
        and cfg.save_interval_updates > 0
        and updates % cfg.save_interval_updates == 0
    )
    checkpoint_conds["checkpoint_best{}.pt".format(suffix)] = val_loss is not None and (
        not hasattr(save_checkpoint, "best") or is_better(val_loss, save_checkpoint.best)
    )
    if val_loss is not None and cfg.keep_best_checkpoints > 0:
        checkpoint_conds[
            "checkpoint.best_{}_{:.2f}.pt".format(cfg.best_checkpoint_metric, val_loss)
        ] = not hasattr(save_checkpoint, "best") or is_better(
            val_loss, save_checkpoint.best
        )
    checkpoint_conds[
        "checkpoint_last{}.pt".format(suffix)
    ] = not cfg.no_last_checkpoints

    extra_state = {"train_iterator": epoch_itr.state_dict(), "val_loss": val_loss}
    if hasattr(save_checkpoint, "best"):
        extra_state.update({"best": save_checkpoint.best})

    checkpoints = [
        os.path.join(cfg.save_dir, fn) for fn, cond in checkpoint_conds.items() if cond
    ]
    if len(checkpoints) > 0:
        trainer.save_checkpoint(checkpoints[0], extra_state)
        for cp in checkpoints[1:]:
            assert PathManager.copy(
                checkpoints[0], cp, overwrite=True
            ), f"Failed to copy {checkpoints[0]} to {cp}"

        write_timer.stop()
        logger.info(
            "Saved checkpoint {} (epoch {} @ {} updates, score {}) (writing took {} seconds)".format(
                checkpoints[0], epoch, updates, val_loss, write_timer.sum
            )
        )

    if not end_of_epoch and cfg.keep_interval_updates > 0:
        # remove old checkpoints; checkpoints are sorted in descending order
        checkpoints = checkpoint_paths(
            cfg.save_dir, pattern=r"checkpoint_\d+_(\d+)\.pt"
        )
        for old_chk in checkpoints[cfg.keep_interval_updates:]:
            if os.path.lexists(old_chk):
                os.remove(old_chk)

    if cfg.keep_last_epochs > 0:
        # remove old epoch checkpoints; checkpoints are sorted in descending order
        checkpoints = checkpoint_paths(cfg.save_dir, pattern=r"checkpoint(\d+)\.pt")
        for old_chk in checkpoints[cfg.keep_last_epochs:]:
            if os.path.lexists(old_chk):
                os.remove(old_chk)

    if cfg.keep_best_checkpoints > 0:
        # only keep the best N checkpoints according to validation metric
        checkpoints = checkpoint_paths(
            cfg.save_dir,
            pattern=r"checkpoint\.best_{}_(\d+\.?\d*)\.pt".format(
                cfg.best_checkpoint_metric
            ),
        )
        if not cfg.maximize_best_checkpoint_metric:
            checkpoints = checkpoints[::-1]
        for old_chk in checkpoints[cfg.keep_best_checkpoints:]:
            if os.path.lexists(old_chk):
                os.remove(old_chk)
def load_checkpoint(args, trainer, **passthrough_args):
    """
    Load a checkpoint and restore the training iterator.

    *passthrough_args* will be passed through to
    ``trainer.get_train_iterator``.
    """
    reset_optimizer = args.reset_optimizer
    reset_lr_scheduler = args.reset_lr_scheduler
    optimizer_overrides = eval(args.optimizer_overrides)
    reset_meters = args.reset_meters
    reset_dataloader = args.reset_dataloader

    if getattr(args, 'finetune_from_model', None) is not None and (
        reset_optimizer or reset_lr_scheduler or reset_meters or reset_dataloader
    ):
        raise ValueError(
            "--finetune-from-model can not be set together with either --reset-optimizer"
            " or reset_lr_scheduler or reset_meters or reset_dataloader"
        )

    suffix = getattr(args, "checkpoint_suffix", "")
    if args.restore_file == "checkpoint_last.pt":
        # default value of restore_file is 'checkpoint_last.pt'
        checkpoint_path = os.path.join(
            args.save_dir, "checkpoint_last{}.pt".format(suffix)
        )
        first_launch = not PathManager.exists(checkpoint_path)
        if getattr(args, 'finetune_from_model', None) is not None and first_launch:
            # if there is no last checkpoint to restore, start the finetune from pretrained model
            # else just use usual logic to load checkpoint, e.g. restart from last checkpoint and etc.
            if PathManager.exists(args.finetune_from_model):
                checkpoint_path = args.finetune_from_model
                reset_optimizer = True
                reset_lr_scheduler = True
                reset_meters = True
                reset_dataloader = True
                logger.info(
                    f'loading pretrained model from {checkpoint_path}: '
                    'optimizer, lr scheduler, meters, dataloader will be reset'
                )
            else:
                raise ValueError(
                    f'--finetune-from-model {args.finetune_from_model} does not exist'
                )
    elif getattr(args, "model_parallel_size", 1) > 1:
        checkpoint_path = args.restore_file.replace(".pt", suffix + ".pt")
    else:
        checkpoint_path = args.restore_file

    if args.restore_file != "checkpoint_last.pt" and getattr(
        args, 'finetune_from_model', None
    ):
        raise ValueError(
            '--finetune-from-model and --restore-file (non-default value) '
            'can not be specified together: ' + str(args)
        )

    extra_state = trainer.load_checkpoint(
        checkpoint_path,
        reset_optimizer,
        reset_lr_scheduler,
        optimizer_overrides,
        reset_meters=reset_meters,
    )

    if (
        extra_state is not None
        and "best" in extra_state
        and not reset_optimizer
        and not reset_meters
    ):
        save_checkpoint.best = extra_state["best"]

    if extra_state is not None and not reset_dataloader:
        # restore iterator from checkpoint
        itr_state = extra_state["train_iterator"]
        epoch_itrs = trainer.get_train_iterator(
            epoch=itr_state["epoch"], load_dataset=True, **passthrough_args
        )
        epoch_itrs.load_state_dict(itr_state)
    else:
        epoch_itrs = trainer.get_train_iterator(
            epoch=1, load_dataset=True, **passthrough_args
        )

    if isinstance(epoch_itrs, list):
        trainer.lr_step(epoch_itrs[0].epoch)
    else:
        trainer.lr_step(epoch_itrs.epoch)

    return extra_state, epoch_itrs
def save_checkpoint(cfg: CheckpointConfig, trainer, epoch_itr, val_loss):
    from fairseq import meters

    # only one worker should attempt to create the required dir
    if trainer.data_parallel_rank == 0:
        os.makedirs(cfg.save_dir, exist_ok=True)

    prev_best = getattr(save_checkpoint, "best", val_loss)
    if val_loss is not None:
        best_function = max if cfg.maximize_best_checkpoint_metric else min
        save_checkpoint.best = best_function(val_loss, prev_best)

    if cfg.no_save:
        return

    trainer.consolidate_optimizer()  # TODO(SS): do we need this if no_save_optimizer_state

    if not trainer.should_save_checkpoint_on_current_rank:
        if trainer.always_call_state_dict_during_save_checkpoint:
            trainer.state_dict()
        return

    write_timer = meters.StopwatchMeter()
    write_timer.start()

    epoch = epoch_itr.epoch
    end_of_epoch = epoch_itr.end_of_epoch()
    updates = trainer.get_num_updates()

    logger.info(f"Preparing to save checkpoint for epoch {epoch} @ {updates} updates")

    def is_better(a, b):
        return a >= b if cfg.maximize_best_checkpoint_metric else a <= b

    suffix = trainer.checkpoint_suffix
    checkpoint_conds = collections.OrderedDict()
    checkpoint_conds["checkpoint{}{}.pt".format(epoch, suffix)] = (
        end_of_epoch and not cfg.no_epoch_checkpoints and epoch % cfg.save_interval == 0
    )
    checkpoint_conds["checkpoint_{}_{}{}.pt".format(epoch, updates, suffix)] = (
        not end_of_epoch
        and cfg.save_interval_updates > 0
        and updates % cfg.save_interval_updates == 0
    )
    checkpoint_conds["checkpoint_best{}.pt".format(suffix)] = val_loss is not None and (
        not hasattr(save_checkpoint, "best") or is_better(val_loss, save_checkpoint.best)
    )
    if val_loss is not None and cfg.keep_best_checkpoints > 0:
        worst_best = getattr(save_checkpoint, "best", None)
        chkpts = checkpoint_paths(
            cfg.save_dir,
            pattern=r"checkpoint\.best_{}_(\d+\.?\d*)\.pt".format(
                cfg.best_checkpoint_metric
            ),
        )
        if len(chkpts) > 0:
            p = chkpts[-1] if cfg.maximize_best_checkpoint_metric else chkpts[0]
            worst_best = float(p.rsplit("_")[-1].replace(".pt", ""))
        # add random digits to resolve ties
        rand_sfx = randint(0, cfg.keep_best_checkpoints)
        checkpoint_conds[
            "checkpoint.best_{}_{:.3f}{}.pt".format(
                cfg.best_checkpoint_metric, val_loss, rand_sfx
            )
        ] = worst_best is None or is_better(val_loss, worst_best)
    checkpoint_conds[
        "checkpoint_last{}.pt".format(suffix)
    ] = not cfg.no_last_checkpoints

    extra_state = {"train_iterator": epoch_itr.state_dict(), "val_loss": val_loss}
    if hasattr(save_checkpoint, "best"):
        extra_state.update({"best": save_checkpoint.best})

    checkpoints = [
        os.path.join(cfg.save_dir, fn) for fn, cond in checkpoint_conds.items() if cond
    ]
    if len(checkpoints) > 0:
        trainer.save_checkpoint(checkpoints[0], extra_state)
        for cp in checkpoints[1:]:
            if cfg.write_checkpoints_asynchronously:
                # TODO[ioPath]: Need to implement a delayed asynchronous
                # file copying/moving feature.
                logger.warning(
                    f"ioPath is not copying {checkpoints[0]} to {cp} "
                    "since async write mode is on."
                )
            else:
                assert PathManager.copy(
                    checkpoints[0], cp, overwrite=True
                ), f"Failed to copy {checkpoints[0]} to {cp}"

        write_timer.stop()
        logger.info(
            "Saved checkpoint {} (epoch {} @ {} updates, score {}) (writing took {} seconds)".format(
                checkpoints[0], epoch, updates, val_loss, write_timer.sum
            )
        )

    if not end_of_epoch and cfg.keep_interval_updates > 0:
        # remove old checkpoints; checkpoints are sorted in descending order
        if cfg.keep_interval_updates_pattern == -1:
            checkpoints = checkpoint_paths(
                cfg.save_dir, pattern=r"checkpoint_\d+_(\d+){}\.pt".format(suffix)
            )
        else:
            checkpoints = checkpoint_paths(
                cfg.save_dir,
                pattern=r"checkpoint_\d+_(\d+){}\.pt".format(suffix),
                keep_match=True,
            )
            checkpoints = [
                x[0]
                for x in checkpoints
                if x[1] % cfg.keep_interval_updates_pattern != 0
            ]

        for old_chk in checkpoints[cfg.keep_interval_updates:]:
            if os.path.lexists(old_chk):
                os.remove(old_chk)
            elif PathManager.exists(old_chk):
                PathManager.rm(old_chk)

    if cfg.keep_last_epochs > 0:
        # remove old epoch checkpoints; checkpoints are sorted in descending order
        checkpoints = checkpoint_paths(
            cfg.save_dir, pattern=r"checkpoint(\d+){}\.pt".format(suffix)
        )
        for old_chk in checkpoints[cfg.keep_last_epochs:]:
            if os.path.lexists(old_chk):
                os.remove(old_chk)

    if cfg.keep_best_checkpoints > 0:
        # only keep the best N checkpoints according to validation metric
        checkpoints = checkpoint_paths(
            cfg.save_dir,
            pattern=r"checkpoint\.best_{}_(\d+\.?\d*){}\.pt".format(
                cfg.best_checkpoint_metric, suffix
            ),
        )
        if not cfg.maximize_best_checkpoint_metric:
            checkpoints = checkpoints[::-1]
        for old_chk in checkpoints[cfg.keep_best_checkpoints:]:
            if os.path.lexists(old_chk):
                os.remove(old_chk)
def save_checkpoint(args, trainer, epoch_itr, val_loss):
    from fairseq import distributed_utils, meters

    prev_best = getattr(save_checkpoint, "best", val_loss)
    if val_loss is not None:
        best_function = max if args.maximize_best_checkpoint_metric else min
        save_checkpoint.best = best_function(val_loss, prev_best)

    if args.no_save or not distributed_utils.is_master(args):
        return

    def is_better(a, b):
        return a >= b if args.maximize_best_checkpoint_metric else a <= b

    write_timer = meters.StopwatchMeter()
    write_timer.start()

    epoch = epoch_itr.epoch
    end_of_epoch = epoch_itr.end_of_epoch()
    updates = trainer.get_num_updates()

    checkpoint_conds = collections.OrderedDict()
    checkpoint_conds["checkpoint{}.pt".format(epoch)] = (
        end_of_epoch and not args.no_epoch_checkpoints and epoch % args.save_interval == 0
    )
    checkpoint_conds["checkpoint_{}_{}.pt".format(epoch, updates)] = (
        not end_of_epoch
        and args.save_interval_updates > 0
        and updates % args.save_interval_updates == 0
    )
    checkpoint_conds["checkpoint_best.pt"] = val_loss is not None and (
        not hasattr(save_checkpoint, "best") or is_better(val_loss, save_checkpoint.best)
    )
    if val_loss is not None and args.keep_best_checkpoints > 0:
        checkpoint_conds[
            "checkpoint.best_{}_{:.2f}.pt".format(args.best_checkpoint_metric, val_loss)
        ] = not hasattr(save_checkpoint, "best") or is_better(
            val_loss, save_checkpoint.best
        )
    checkpoint_conds["checkpoint_last.pt"] = not args.no_last_checkpoints

    extra_state = {"train_iterator": epoch_itr.state_dict(), "val_loss": val_loss}
    if hasattr(save_checkpoint, "best"):
        extra_state.update({"best": save_checkpoint.best})

    checkpoints = [
        os.path.join(args.save_dir, fn) for fn, cond in checkpoint_conds.items() if cond
    ]
    if len(checkpoints) > 0:
        trainer.save_checkpoint(checkpoints[0], extra_state)
        for cp in checkpoints[1:]:
            PathManager.copy(checkpoints[0], cp, overwrite=True)

        write_timer.stop()
        logger.info(
            "saved checkpoint {} (epoch {} @ {} updates, score {}) (writing took {} seconds)".format(
                checkpoints[0], epoch, updates, val_loss, write_timer.sum
            )
        )

    if not end_of_epoch and args.keep_interval_updates > 0:
        # remove old checkpoints; checkpoints are sorted in descending order
        checkpoints = checkpoint_paths(
            args.save_dir, pattern=r"checkpoint_\d+_(\d+)\.pt"
        )
        for old_chk in checkpoints[args.keep_interval_updates:]:
            if os.path.lexists(old_chk):
                os.remove(old_chk)

    if args.keep_last_epochs > 0:
        # remove old epoch checkpoints; checkpoints are sorted in descending order
        checkpoints = checkpoint_paths(args.save_dir, pattern=r"checkpoint(\d+)\.pt")
        for old_chk in checkpoints[args.keep_last_epochs:]:
            if os.path.lexists(old_chk):
                os.remove(old_chk)

    if args.keep_best_checkpoints > 0:
        # only keep the best N checkpoints according to validation metric
        checkpoints = checkpoint_paths(
            args.save_dir,
            pattern=r"checkpoint\.best_{}_(\d+\.?\d*)\.pt".format(
                args.best_checkpoint_metric
            ),
        )
        if not args.maximize_best_checkpoint_metric:
            checkpoints = checkpoints[::-1]
        for old_chk in checkpoints[args.keep_best_checkpoints:]:
            if os.path.lexists(old_chk):
                os.remove(old_chk)
def add_file_to_dictionary(filename, dict, tokenize):
    with PathManager.open(filename, "r", encoding="utf-8") as f:
        for line in f:
            for word in tokenize(line):
                dict.add_symbol(word)
            dict.add_symbol(dict.eos_word)
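# Illustrative only: a brief sketch (not in the original sources) of building a
# vocabulary with add_file_to_dictionary(), assuming fairseq's Dictionary and
# default whitespace tokenizer. The corpus and output paths are hypothetical.
def _example_build_dictionary():
    from fairseq.data import Dictionary
    from fairseq.tokenizer import tokenize_line

    d = Dictionary()
    add_file_to_dictionary("corpus.en", d, tokenize_line)  # hypothetical corpus file
    d.finalize(threshold=-1, nwords=-1, padding_factor=8)
    d.save("dict.en.txt")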
def main(cfg: FairseqConfig) -> None:
    if isinstance(cfg, argparse.Namespace):
        cfg = convert_namespace_to_omegaconf(cfg)

    utils.import_user_module(cfg.common)

    if distributed_utils.is_master(cfg.distributed_training) and "job_logging_cfg" in cfg:
        # make hydra logging work with ddp
        # (see https://github.com/facebookresearch/hydra/issues/1126)
        logging.config.dictConfig(OmegaConf.to_container(cfg.job_logging_cfg))

    assert (
        cfg.dataset.max_tokens is not None or cfg.dataset.batch_size is not None
    ), "Must specify batch size either with --max-tokens or --batch-size"
    metrics.reset()

    if cfg.common.log_file is not None:
        handler = logging.FileHandler(filename=cfg.common.log_file)
        logger.addHandler(handler)

    np.random.seed(cfg.common.seed)
    utils.set_torch_seed(cfg.common.seed)

    if distributed_utils.is_master(cfg.distributed_training):
        checkpoint_utils.verify_checkpoint_directory(cfg.checkpoint.save_dir)

    # Print args
    logger.info(cfg)

    if cfg.checkpoint.write_checkpoints_asynchronously:
        try:
            import iopath  # noqa: F401
        except ImportError:
            logging.exception(
                "Asynchronous checkpoint writing is specified but iopath is "
                "not installed: `pip install iopath`"
            )
            return

    # Setup task, e.g., translation, language modeling, etc.
    task = tasks.setup_task(cfg.task)

    assert cfg.criterion, "Please specify criterion to train a model"

    # Build model and criterion
    if cfg.distributed_training.ddp_backend == "fully_sharded":
        with fsdp_enable_wrap(cfg.distributed_training):
            model = fsdp_wrap(task.build_model(cfg.model))
    else:
        model = task.build_model(cfg.model)
    criterion = task.build_criterion(cfg.criterion)
    logger.info(model)
    logger.info("task: {}".format(task.__class__.__name__))
    logger.info("model: {}".format(model.__class__.__name__))
    logger.info("criterion: {}".format(criterion.__class__.__name__))
    logger.info(
        "num. shared model params: {:,} (num. trained: {:,})".format(
            sum(p.numel() for p in model.parameters() if not getattr(p, "expert", False)),
            sum(
                p.numel()
                for p in model.parameters()
                if not getattr(p, "expert", False) and p.requires_grad
            ),
        )
    )
    logger.info(
        "num. expert model params: {} (num. trained: {})".format(
            sum(p.numel() for p in model.parameters() if getattr(p, "expert", False)),
            sum(
                p.numel()
                for p in model.parameters()
                if getattr(p, "expert", False) and p.requires_grad
            ),
        )
    )

    # Load valid dataset (we load training data below, based on the latest checkpoint)
    # We load the valid dataset AFTER building the model
    data_utils.raise_if_valid_subsets_unintentionally_ignored(cfg)
    if cfg.dataset.combine_valid_subsets:
        task.load_dataset("valid", combine=True, epoch=1)
    else:
        for valid_sub_split in cfg.dataset.valid_subset.split(","):
            task.load_dataset(valid_sub_split, combine=False, epoch=1)

    # (optionally) Configure quantization
    if cfg.common.quantization_config_path is not None:
        quantizer = quantization_utils.Quantizer(
            config_path=cfg.common.quantization_config_path,
            max_epoch=cfg.optimization.max_epoch,
            max_update=cfg.optimization.max_update,
        )
    else:
        quantizer = None

    # Build trainer
    if cfg.common.model_parallel_size == 1:
        trainer = Trainer(cfg, task, model, criterion, quantizer)
    else:
        trainer = MegatronTrainer(cfg, task, model, criterion)
    logger.info(
        "training on {} devices (GPUs/TPUs)".format(
            cfg.distributed_training.distributed_world_size
        )
    )
    logger.info(
        "max tokens per device = {} and max sentences per device = {}".format(
            cfg.dataset.max_tokens,
            cfg.dataset.batch_size,
        )
    )

    # Load the latest checkpoint if one is available and restore the
    # corresponding train iterator
    extra_state, epoch_itr = checkpoint_utils.load_checkpoint(
        cfg.checkpoint,
        trainer,
        # don't cache epoch iterators for sharded datasets
        disable_iterator_cache=task.has_sharded_data("train"),
    )
    if cfg.common.tpu:
        import torch_xla.core.xla_model as xm

        xm.rendezvous("load_checkpoint")  # wait for all workers

    max_epoch = cfg.optimization.max_epoch or math.inf
    lr = trainer.get_lr()
    train_meter = meters.StopwatchMeter()
    train_meter.start()
    while epoch_itr.next_epoch_idx <= max_epoch:
        if lr <= cfg.optimization.stop_min_lr:
            logger.info(
                f"stopping training because current learning rate ({lr}) is smaller "
                "than or equal to minimum learning rate "
                f"(--stop-min-lr={cfg.optimization.stop_min_lr})"
            )
            break

        # train for one epoch
        valid_losses, should_stop = train(cfg, trainer, task, epoch_itr)
        if should_stop:
            break

        # only use first validation loss to update the learning rate
        lr = trainer.lr_step(epoch_itr.epoch, valid_losses[0])

        epoch_itr = trainer.get_train_iterator(
            epoch_itr.next_epoch_idx,
            # sharded data: get train iterator for next epoch
            load_dataset=task.has_sharded_data("train"),
            # don't cache epoch iterators for sharded datasets
            disable_iterator_cache=task.has_sharded_data("train"),
        )
    train_meter.stop()
    logger.info("done training in {:.1f} seconds".format(train_meter.sum))

    # ioPath implementation to wait for all asynchronous file writes to complete.
    if cfg.checkpoint.write_checkpoints_asynchronously:
        logger.info(
            "ioPath PathManager waiting for all asynchronous checkpoint "
            "writes to finish."
        )
        PathManager.async_close()
        logger.info("ioPath PathManager finished waiting.")
def load_checkpoint(
    self,
    filename,
    reset_optimizer=False,
    reset_lr_scheduler=False,
    optimizer_overrides=None,
    reset_meters=False,
):
    """
    Load all training state from a checkpoint file.
    rank = 0 will load the checkpoint, and then broadcast it to all
    other ranks.
    """
    extra_state, self._optim_history, last_optim_state = None, [], None

    logger.info(f"Preparing to load checkpoint {filename}")
    is_distributed = self.data_parallel_world_size > 1
    bexists = PathManager.isfile(filename)
    if bexists:
        load_on_all_ranks = (
            self.cfg.checkpoint.load_checkpoint_on_all_dp_ranks
            # TPUs don't support broadcast yet, so load checkpoints
            # on every worker for now
            or self.tpu
        )

        if load_on_all_ranks or self.data_parallel_rank == 0:
            state = checkpoint_utils.load_checkpoint_to_cpu(
                filename, load_on_all_ranks=load_on_all_ranks
            )
            last_optim_state = state.get("last_optimizer_state", None)

            # If doing zero_sharding, do not broadcast global optimizer
            # state. Later we will broadcast sharded states to each rank
            # to avoid memory from exploding.
            if (
                not load_on_all_ranks
                and self.cfg.distributed_training.zero_sharding == "os"
                and "last_optimizer_state" in state
                and is_distributed
            ):
                state["last_optimizer_state"] = "SHARDED"
        else:
            last_optim_state = None
            state = None

        if is_distributed and not load_on_all_ranks:
            state = distributed_utils.broadcast_object(
                state,
                src_rank=0,
                group=self.data_parallel_process_group,
                dist_device=self.device,
            )
            if self.data_parallel_rank > 0:
                last_optim_state = state.get("last_optimizer_state", None)

        # load model parameters
        try:
            self.model.load_state_dict(
                state["model"], strict=True, model_cfg=self.cfg.model
            )
            if utils.has_parameters(self.get_criterion()):
                self.get_criterion().load_state_dict(
                    state["criterion"], strict=True
                )
        except Exception:
            raise Exception(
                "Cannot load model parameters from checkpoint {}; "
                "please ensure that the architectures match.".format(filename)
            )

        extra_state = state["extra_state"]
        self._optim_history = state["optimizer_history"]

    if last_optim_state is not None and not reset_optimizer:
        # rebuild optimizer after loading model, since params may have changed
        self._build_optimizer()

        # only reload optimizer and lr_scheduler if they match
        last_optim = self._optim_history[-1]
        assert (
            last_optim["criterion_name"] == self.get_criterion().__class__.__name__
        ), (
            "Criterion does not match; please reset the optimizer (--reset-optimizer). "
            f"{last_optim['criterion_name']} vs {self.get_criterion().__class__.__name__}"
        )
        assert (
            last_optim["optimizer_name"] == self.optimizer.__class__.__name__
        ), (
            "Optimizer does not match; please reset the optimizer (--reset-optimizer). "
            f"{last_optim['optimizer_name']} vs {self.optimizer.__class__.__name__}"
        )

        if not reset_lr_scheduler:
            self.lr_scheduler.load_state_dict(last_optim["lr_scheduler_state"])

        if not load_on_all_ranks and is_distributed:
            last_optim_state = self.optimizer.broadcast_global_state_dict(
                last_optim_state
            )
        self.optimizer.load_state_dict(last_optim_state, optimizer_overrides)

        self.set_num_updates(last_optim["num_updates"])

    if extra_state is not None:
        itr_state = extra_state["train_iterator"]
        epoch = itr_state["epoch"]

        if "previous_training_time" in extra_state:
            self._previous_training_time = extra_state["previous_training_time"]
            self._start_time = time.time()

        self.lr_step(epoch)

        if itr_state.get("version", 1) >= 2 and itr_state["iterations_in_epoch"] == 0:
            # reset meters at start of epoch
            reset_meters = True

        if "metrics" in extra_state and not reset_meters:
            metrics.load_state_dict(extra_state["metrics"])

            # reset TimeMeters, since their start times don't make sense anymore
            for meter in metrics.get_meters("default"):
                if isinstance(meter, meters.TimeMeter):
                    meter.reset()

        logger.info(
            "Loaded checkpoint {} (epoch {} @ {} updates)".format(
                filename, epoch, self.get_num_updates()
            )
        )
    else:
        logger.info("No existing checkpoint found {}".format(filename))

    return extra_state
def main(args):
    state = checkpoint_utils.load_checkpoint_to_cpu(args.checkpoint)
    ns = state["args"]
    model = state["model"]
    ns.arch = "transformer_modular"

    if (args.encoder_attention_heads_active is None
            and args.decoder_attention_heads_active is None):
        raise ValueError(
            'Either --encoder-attention-heads-active or '
            '--decoder-attention-heads-active option must be set.')
    if args.encoder_attention_heads_active is None:
        args.encoder_attention_heads_active = args.decoder_attention_heads_active

    if args.encoder_modular_layer_indices is not None:
        ns.encoder_modular_layer_indices = "({})".format(
            args.encoder_modular_layer_indices)
        model = convert_model(model, ns, coder="encoder", att_type="self_attn")
    if args.decoder_modular_layer_indices is not None:
        ns.decoder_modular_layer_indices = "({})".format(
            args.decoder_modular_layer_indices)
        model = convert_model(model, ns, coder="decoder", att_type="self_attn")
        model = convert_model(model, ns, coder="decoder", att_type="encoder_attn")

    ctrl_enc = ModularCtrl(
        ns.encoder_embed_dim,
        ns.encoder_attention_heads,
        args.encoder_attention_heads_active,
        hidden_depth=args.ctrl_hidden_depth,
        hidden_dim=args.ctrl_hidden_dim,
        ctrl_type=args.ctrl_type)
    ns.module_ctrl_hidden_depth = args.ctrl_hidden_depth
    ns.module_ctrl_hidden_dim = args.ctrl_hidden_dim
    ns.module_ctrl_type = args.ctrl_type
    for k, v in ctrl_enc.state_dict().items():
        model["encoder.module_ctrl.{}".format(k)] = v

    if not args.share_encoder_ctrl:
        if args.decoder_attention_heads_active is None:
            raise ValueError("Missing ``decoder-attention-heads-active'' "
                             "when ``share-encoder-ctrl'' is disabled.")
        ns.share_encoder_ctrl = False
        ctrl_dec = ModularCtrl(
            ns.decoder_embed_dim,
            ns.decoder_attention_heads,
            args.decoder_attention_heads_active,
            hidden_depth=args.ctrl_hidden_depth,
            hidden_dim=args.ctrl_hidden_dim,
            ctrl_type=args.ctrl_type)
        for k, v in ctrl_dec.state_dict().items():
            model["decoder.module_ctrl.{}".format(k)] = v
    else:
        ns.share_encoder_ctrl = True

    ns.arch = "transformer_modular"
    ns.criterion = "label_smoothed_cross_entropy_modular"
    ns.task = "translation_modular"
    ns.encoder_attention_heads_active = args.encoder_attention_heads_active

    state["args"] = ns
    state["model"] = model
    for i, _ in enumerate(state["optimizer_history"]):
        state["optimizer_history"][i][
            "criterion_name"] = 'LabelSmoothedCrossEntropyModularCriterion'

    state = utils.move_to_cpu(state)
    with PathManager.open(args.save_as, "wb") as f:
        checkpoint_utils.torch_persistent_save(state, f)
def load_checkpoint(cfg: CheckpointConfig, trainer, **passthrough_args):
    """
    Load a checkpoint and restore the training iterator.

    *passthrough_args* will be passed through to
    ``trainer.get_train_iterator``.
    """
    reset_optimizer = cfg.reset_optimizer
    reset_lr_scheduler = cfg.reset_lr_scheduler
    optimizer_overrides = ast.literal_eval(cfg.optimizer_overrides)
    reset_meters = cfg.reset_meters
    reset_dataloader = cfg.reset_dataloader

    if cfg.finetune_from_model is not None and (
        reset_optimizer or reset_lr_scheduler or reset_meters or reset_dataloader
    ):
        raise ValueError(
            "--finetune-from-model can not be set together with either --reset-optimizer"
            " or reset_lr_scheduler or reset_meters or reset_dataloader"
        )

    suffix = trainer.checkpoint_suffix
    if (
        cfg.restore_file == "checkpoint_last.pt"
    ):  # default value of restore_file is 'checkpoint_last.pt'
        checkpoint_path = os.path.join(
            cfg.save_dir, "checkpoint_last{}.pt".format(suffix)
        )
        first_launch = not PathManager.exists(checkpoint_path)
        if cfg.finetune_from_model is not None and first_launch:
            # if there is no last checkpoint to restore, start the finetune from pretrained model
            # else just use usual logic to load checkpoint, e.g. restart from last checkpoint and etc.
            if PathManager.exists(cfg.finetune_from_model):
                checkpoint_path = cfg.finetune_from_model
                reset_optimizer = True
                reset_lr_scheduler = True
                reset_meters = True
                reset_dataloader = True
                logger.info(
                    f"loading pretrained model from {checkpoint_path}: "
                    "optimizer, lr scheduler, meters, dataloader will be reset"
                )
            else:
                raise ValueError(
                    f"--finetune-from-model {cfg.finetune_from_model} does not exist"
                )
    elif suffix is not None:
        checkpoint_path = cfg.restore_file.replace(".pt", suffix + ".pt")
    else:
        checkpoint_path = cfg.restore_file

    if cfg.restore_file != "checkpoint_last.pt" and cfg.finetune_from_model:
        raise ValueError(
            "--finetune-from-model and --restore-file (non-default value) "
            "can not be specified together: " + str(cfg)
        )

    extra_state = trainer.load_checkpoint(
        checkpoint_path,
        reset_optimizer,
        reset_lr_scheduler,
        optimizer_overrides,
        reset_meters=reset_meters,
    )

    if (
        extra_state is not None
        and "best" in extra_state
        and not reset_optimizer
        and not reset_meters
    ):
        save_checkpoint.best = extra_state["best"]

    if extra_state is not None and not reset_dataloader:
        # restore iterator from checkpoint
        itr_state = extra_state["train_iterator"]
        epoch_itr = trainer.get_train_iterator(
            epoch=itr_state["epoch"], load_dataset=True, **passthrough_args
        )
        epoch_itr.load_state_dict(itr_state)
    else:
        epoch_itr = trainer.get_train_iterator(
            epoch=1, load_dataset=True, **passthrough_args
        )

    trainer.lr_step(epoch_itr.epoch)

    return extra_state, epoch_itr
def exists(path):
    return PathManager.exists(path)
def load_checkpoint_to_cpu(path, arg_overrides=None, load_on_all_ranks=False):
    """Loads a checkpoint to CPU (with upgrading for backward compatibility).

    If doing single-GPU training or if the checkpoint is only being loaded by at
    most one process on each node (current default behavior is for only rank 0
    to read the checkpoint from disk), load_on_all_ranks should be False to
    avoid errors from torch.distributed not having been initialized or
    torch.distributed.barrier() hanging.

    If all processes on each node may be loading the checkpoint
    simultaneously, load_on_all_ranks should be set to True to avoid I/O
    conflicts.

    There's currently no support for > 1 but < all processes loading the
    checkpoint on each node.
    """
    local_path = PathManager.get_local_path(path)
    # The locally cached file returned by get_local_path() may be stale for
    # remote files that are periodically updated/overwritten (ex:
    # checkpoint_last.pt) - so we remove the local copy, sync across processes
    # (if needed), and then download a fresh copy.
    if local_path != path and PathManager.path_requires_pathmanager(path):
        try:
            os.remove(local_path)
        except FileNotFoundError:
            # With potentially multiple processes removing the same file, the
            # file being missing is benign (missing_ok isn't available until
            # Python 3.8).
            pass
        if load_on_all_ranks:
            torch.distributed.barrier()
        local_path = PathManager.get_local_path(path)

    with open(local_path, "rb") as f:
        state = torch.load(f, map_location=torch.device("cpu"))

    if "args" in state and state["args"] is not None and arg_overrides is not None:
        args = state["args"]
        for arg_name, arg_val in arg_overrides.items():
            setattr(args, arg_name, arg_val)

    if "cfg" in state and state["cfg"] is not None:
        # hack to be able to set Namespace in dict config. this should be removed
        # when we update to a newer omegaconf version that supports object flags,
        # or when we migrate all existing models
        from omegaconf import _utils

        old_primitive = _utils.is_primitive_type
        _utils.is_primitive_type = lambda _: True

        state["cfg"] = OmegaConf.create(state["cfg"])

        _utils.is_primitive_type = old_primitive
        OmegaConf.set_struct(state["cfg"], True)

        if arg_overrides is not None:
            overwrite_args_by_name(state["cfg"], arg_overrides)

    state = _upgrade_state_dict(state)
    return state
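# Illustrative only: a small sketch (not in the original sources) of calling
# load_checkpoint_to_cpu() with arg_overrides, which the function applies to
# either the legacy "args" Namespace or the omegaconf "cfg" before upgrading.
# The checkpoint path and override value are hypothetical.
def _example_load_to_cpu():
    state = load_checkpoint_to_cpu(
        "checkpoints/checkpoint_best.pt",
        arg_overrides={"data": "/path/to/data-bin"},  # hypothetical override
    )
    model_weights = state["model"]
    cfg_or_args = state.get("cfg") or state.get("args")
    return model_weights, cfg_or_args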
def exists(path):
    return PathManager.exists(index_file_path(path)) and PathManager.exists(
        data_file_path(path)
    )
def load_model_ensemble_and_task(
    filenames,
    arg_overrides: Optional[Dict[str, Any]] = None,
    task=None,
    strict=True,
    suffix="",
    num_shards=1,
    state=None,
):
    assert state is None or len(filenames) == 1

    from fairseq import tasks

    assert not (
        strict and num_shards > 1
    ), "Cannot load state dict with strict=True and checkpoint shards > 1"
    ensemble = []
    cfg = None
    for filename in filenames:
        orig_filename = filename
        model_shard_state = {"shard_weights": [], "shard_metadata": []}
        assert num_shards > 0
        st = time.time()
        for shard_idx in range(num_shards):
            filename = get_maybe_sharded_checkpoint_filename(
                orig_filename, suffix, shard_idx, num_shards
            )
            if not PathManager.exists(filename):
                raise IOError("Model file not found: {}".format(filename))
            if state is None:
                state = load_checkpoint_to_cpu(filename, arg_overrides)
            if "args" in state and state["args"] is not None:
                cfg = convert_namespace_to_omegaconf(state["args"])
            elif "cfg" in state and state["cfg"] is not None:
                cfg = state["cfg"]
            else:
                raise RuntimeError(
                    f"Neither args nor cfg exist in state keys = {state.keys()}"
                )

            if task is None:
                task = tasks.setup_task(cfg.task)

            if "task_state" in state:
                task.load_state_dict(state["task_state"])

            if "fsdp_metadata" in state and num_shards > 1:
                model_shard_state["shard_weights"].append(state["model"])
                model_shard_state["shard_metadata"].append(state["fsdp_metadata"])
                # check FSDP import before the code goes too far
                if not has_FSDP:
                    raise ImportError(
                        "Cannot find FullyShardedDataParallel. "
                        "Please install fairscale with: pip install fairscale"
                    )
                if shard_idx == num_shards - 1:
                    consolidated_model_state = FSDP.consolidate_shard_weights(
                        shard_weights=model_shard_state["shard_weights"],
                        shard_metadata=model_shard_state["shard_metadata"],
                    )
                    model = task.build_model(cfg.model)
                    model.load_state_dict(
                        consolidated_model_state, strict=strict, model_cfg=cfg.model
                    )
            else:
                # model parallel checkpoint or unsharded checkpoint
                model = task.build_model(cfg.model)
                model.load_state_dict(
                    state["model"], strict=strict, model_cfg=cfg.model
                )

            # reset state so it gets loaded for the next model in ensemble
            state = None
            if shard_idx % 10 == 0 and shard_idx > 0:
                elapsed = time.time() - st
                logger.info(
                    f"Loaded {shard_idx} shards in {elapsed:.2f}s, {elapsed / (shard_idx+1):.2f}s/shard"
                )

        # build model for ensemble
        ensemble.append(model)
    return ensemble, cfg, task
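# Illustrative only: a minimal sketch (not in the original sources) of loading a
# single checkpoint for inference via load_model_ensemble_and_task(). The path
# and data override are hypothetical; switching to eval mode is the usual
# fairseq inference pattern.
def _example_load_ensemble():
    models, cfg, task = load_model_ensemble_and_task(
        ["checkpoints/checkpoint_best.pt"],
        arg_overrides={"data": "/path/to/data-bin"},  # hypothetical override
    )
    for model in models:
        model.eval()  # disable dropout for inference
    return models, cfg, task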
def upgrade_state_dict_with_infoxlm_weights(
    state_dict: Dict[str, Any],
    pretrained_infoxlm_checkpoint: str,
    num_layers: int,
    shared_cross_attn: bool = False,
) -> Dict[str, Any]:
    """
    Load XLM weights into a Transformer encoder or decoder model.

    Args:
        state_dict: state dict for either TransformerEncoder or TransformerDecoder
        pretrained_infoxlm_checkpoint: checkpoint to load XLM weights from

    Raises:
        AssertionError: If architecture (num layers, attention heads, etc.)
            does not match between the current Transformer encoder or decoder
            and the pretrained_xlm_checkpoint
    """
    if not os.path.exists(pretrained_infoxlm_checkpoint):
        raise IOError(
            "Model file not found: {}".format(pretrained_infoxlm_checkpoint)
        )

    with open(PathManager.get_local_path(pretrained_infoxlm_checkpoint), "rb") as f:
        state = torch.load(f, map_location=torch.device("cpu"))
    infoxlm_state_dict = state["model"]

    for key in infoxlm_state_dict.keys():
        if 'layers' in key and int(key.split('.')[3]) > num_layers - 1:
            continue
        if not key.startswith('decoder.'):
            continue
        if 'lm_head' not in key:
            if 'in_proj_weight' in key:
                # split the fused QKV projection into separate q/k/v weights
                q, k, v = infoxlm_state_dict[key].chunk(3, dim=0)
                state_dict[key.replace('decoder.sentence_encoder.', '')
                           .replace('in_proj_weight', 'q_proj.weight')] = q
                state_dict[key.replace('decoder.sentence_encoder.', '')
                           .replace('in_proj_weight', 'k_proj.weight')] = k
                state_dict[key.replace('decoder.sentence_encoder.', '')
                           .replace('in_proj_weight', 'v_proj.weight')] = v
                if shared_cross_attn:
                    state_dict[key.replace('decoder.sentence_encoder.', '')
                               .replace('in_proj_weight', 'q_proj.weight')
                               .replace('self_attn', 'encoder_attn')] = q
                    state_dict[key.replace('decoder.sentence_encoder.', '')
                               .replace('in_proj_weight', 'k_proj.weight')
                               .replace('self_attn', 'encoder_attn')] = k
                    state_dict[key.replace('decoder.sentence_encoder.', '')
                               .replace('in_proj_weight', 'v_proj.weight')
                               .replace('self_attn', 'encoder_attn')] = v
            elif 'in_proj_bias' in key:
                # same split for the fused QKV projection bias
                q, k, v = infoxlm_state_dict[key].chunk(3, dim=0)
                state_dict[key.replace('decoder.sentence_encoder.', '')
                           .replace('in_proj_bias', 'q_proj.bias')] = q
                state_dict[key.replace('decoder.sentence_encoder.', '')
                           .replace('in_proj_bias', 'k_proj.bias')] = k
                state_dict[key.replace('decoder.sentence_encoder.', '')
                           .replace('in_proj_bias', 'v_proj.bias')] = v
                if shared_cross_attn:
                    state_dict[key.replace('decoder.sentence_encoder.', '')
                               .replace('in_proj_bias', 'q_proj.bias')
                               .replace('self_attn', 'encoder_attn')] = q
                    state_dict[key.replace('decoder.sentence_encoder.', '')
                               .replace('in_proj_bias', 'k_proj.bias')
                               .replace('self_attn', 'encoder_attn')] = k
                    state_dict[key.replace('decoder.sentence_encoder.', '')
                               .replace('in_proj_bias', 'v_proj.bias')
                               .replace('self_attn', 'encoder_attn')] = v
            elif 'emb_layer_norm' in key:
                state_dict[key.replace('decoder.sentence_encoder.emb_layer_norm',
                                       'layernorm_embedding')] = infoxlm_state_dict[key]
            elif 'embed_positions' in key:
                # truncate pretrained positions to the current model's max positions
                new_key = key.replace('decoder.sentence_encoder.', '')
                state_dict[new_key] = infoxlm_state_dict[key][:state_dict[new_key].size(0)]
            elif 'embed_tokens' in key:
                # copy pretrained embeddings into the (possibly larger) current table
                new_key = key.replace('decoder.sentence_encoder.', '')
                state_dict[new_key][:infoxlm_state_dict[key].size(0)] = infoxlm_state_dict[key]
            else:
                state_dict[key.replace('decoder.sentence_encoder.', '')] = infoxlm_state_dict[key]

    return state_dict