def load_checkpoint(
    self,
    filename,
    reset_optimizer=False,
    reset_lr_scheduler=False,
    optimizer_overrides=None,
    reset_meters=False,
):
    """Load all training state from a checkpoint file."""
    extra_state, self._optim_history, last_optim_state = None, [], None

    try:
        from fairseq.fb_pathmgr import fb_pathmgr
        bexists = fb_pathmgr.isfile(filename)
    except (ModuleNotFoundError, ImportError):
        bexists = os.path.exists(filename)

    if bexists:
        state = checkpoint_utils.load_checkpoint_to_cpu(filename)

        # load model parameters
        try:
            self.get_model().load_state_dict(state['model'], strict=True, args=self.args)
            if utils.has_parameters(self.get_criterion()):
                self.get_criterion().load_state_dict(state['criterion'], strict=True)
        except Exception:
            raise Exception(
                'Cannot load model parameters from checkpoint {}; '
                'please ensure that the architectures match.'.format(filename)
            )

        extra_state = state['extra_state']
        self._optim_history = state['optimizer_history']
        last_optim_state = state.get('last_optimizer_state', None)

    if last_optim_state is not None and not reset_optimizer:
        # rebuild optimizer after loading model, since params may have changed
        self._build_optimizer()

        # only reload optimizer and lr_scheduler if they match
        last_optim = self._optim_history[-1]
        assert last_optim['criterion_name'] == self.get_criterion().__class__.__name__, \
            'Criterion does not match; please reset the optimizer (--reset-optimizer).'
        assert last_optim['optimizer_name'] == self.optimizer.__class__.__name__, \
            'Optimizer does not match; please reset the optimizer (--reset-optimizer).'

        if not reset_lr_scheduler:
            self.lr_scheduler.load_state_dict(last_optim['lr_scheduler_state'])
        self.optimizer.load_state_dict(last_optim_state, optimizer_overrides)

        self.set_num_updates(last_optim['num_updates'])

    if extra_state is not None:
        epoch = extra_state['train_iterator']['epoch']
        print('| loaded checkpoint {} (epoch {} @ {} updates)'.format(
            filename, epoch, self.get_num_updates()))

        self.lr_step(epoch)

        if 'train_meters' in extra_state and not reset_meters:
            self.meters.update(extra_state['train_meters'])
            del extra_state['train_meters']

            # reset TimeMeters, since their start times don't make sense anymore
            for meter in self.meters.values():
                if isinstance(meter, TimeMeter):
                    meter.reset()
    else:
        print('| no existing checkpoint found {}'.format(filename))

    return extra_state
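For reference, here is a minimal sketch of the checkpoint dict this method consumes, inferred from the keys it reads above; the placeholder values (class names, counts) are illustrative assumptions, not the complete fairseq checkpoint schema:

    # Sketch of the expected checkpoint layout; values are placeholders.
    example_checkpoint = {
        'model': {},                    # model state_dict (parameter name -> tensor)
        'criterion': {},                # criterion state_dict, if the criterion has parameters
        'optimizer_history': [          # one entry per run that trained this model
            {
                'criterion_name': 'CrossEntropyCriterion',   # illustrative
                'optimizer_name': 'FairseqAdam',             # illustrative
                'lr_scheduler_state': {},
                'num_updates': 50000,
            },
        ],
        'last_optimizer_state': {},     # optimizer state_dict from the latest run
        'extra_state': {
            'train_iterator': {'epoch': 10},
            'train_meters': {},         # optional; ignored when reset_meters=True
        },
    }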
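As a usage note, `optimizer_overrides` is a plain dict merged into the restored optimizer state, which lets a resumed run change hyperparameters such as the learning rate without resetting the optimizer entirely. A hedged sketch, assuming an already-constructed trainer object and a checkpoint path that are illustrative here:

    # Resume from a checkpoint but override the stored learning rate.
    # `trainer` and the checkpoint path are assumptions for illustration.
    extra_state = trainer.load_checkpoint(
        'checkpoints/checkpoint_last.pt',
        reset_optimizer=False,
        optimizer_overrides={'lr': 0.0005},
    )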
def load_checkpoint(
    self,
    filename,
    reset_optimizer=False,
    reset_lr_scheduler=False,
    optimizer_overrides=None,
    reset_meters=False,
):
    """Load all training state from a checkpoint file."""

    def rename_state_dict_keys(source, old_name, new_name, rename_all=True):
        # Return a copy of `source` with `old_name` replaced by `new_name` in
        # matching keys. With rename_all=True, the replacement is applied to
        # every key; otherwise only pure encoder keys are renamed, and keys
        # that also mention 'decoder' (e.g. decoder cross-attention params)
        # are left untouched. Requires `from collections import OrderedDict`
        # at module scope.
        new_state_dict = OrderedDict()
        for key, value in source.items():
            if rename_all:
                key = key.replace(old_name, new_name)
            elif 'encoder' in key and 'decoder' not in key:
                key = key.replace(old_name, new_name)
            new_state_dict[key] = value
        return new_state_dict

    extra_state, self._optim_history, last_optim_state = None, [], None

    try:
        from fairseq.fb_pathmgr import fb_pathmgr
        bexists = fb_pathmgr.isfile(filename)
    except (ModuleNotFoundError, ImportError):
        bexists = os.path.exists(filename)

    if bexists:
        state = checkpoint_utils.load_checkpoint_to_cpu(filename)

        # load model parameters
        try:
            self.get_model().load_state_dict(state['model'], strict=True, args=self.args)
            if utils.has_parameters(self.get_criterion()):
                self.get_criterion().load_state_dict(state['criterion'], strict=True)
        except Exception:
            raise Exception(
                'Cannot load model parameters from checkpoint {}; '
                'please ensure that the architectures match.'.format(filename)
            )

        extra_state = state['extra_state']
        self._optim_history = state['optimizer_history']
        last_optim_state = state.get('last_optimizer_state', None)

    if last_optim_state is not None and not reset_optimizer:
        # rebuild optimizer after loading model, since params may have changed
        self._build_optimizer()

        # only reload optimizer and lr_scheduler if they match
        last_optim = self._optim_history[-1]
        assert last_optim['criterion_name'] == self.get_criterion().__class__.__name__, \
            'Criterion does not match; please reset the optimizer (--reset-optimizer).'
        assert last_optim['optimizer_name'] == self.optimizer.__class__.__name__, \
            'Optimizer does not match; please reset the optimizer (--reset-optimizer).'

        if not reset_lr_scheduler:
            self.lr_scheduler.load_state_dict(last_optim['lr_scheduler_state'])
        self.optimizer.load_state_dict(last_optim_state, optimizer_overrides)

        self.set_num_updates(last_optim['num_updates'])

    if extra_state is not None:
        epoch = extra_state['train_iterator']['epoch']
        print('| loaded mt checkpoint {} (epoch {} @ {} updates)'.format(
            filename, epoch, self.get_num_updates()))

        self.lr_step(epoch)

        if 'train_meters' in extra_state and not reset_meters:
            self.meters.update(extra_state['train_meters'])
            del extra_state['train_meters']

            # reset TimeMeters, since their start times don't make sense anymore
            for meter in self.meters.values():
                if isinstance(meter, TimeMeter):
                    meter.reset()
    else:
        print('| no existing checkpoint found {}'.format(filename))

    # for audio translation, optionally initialize parts of the model from
    # separately pre-trained checkpoints, loaded non-strictly so that only
    # matching parameter names are copied
    if self.args.task == 'audio_translation' and self.args.audio_pt is not None:
        print('| loaded audio checkpoint {}'.format(self.args.audio_pt))
        model_dict = torch.load(self.args.audio_pt)
        # the pre-trained audio model's 'encoder.*' weights map onto this
        # model's 'audio_encoder.*' parameters
        model_dict['model'] = rename_state_dict_keys(
            model_dict['model'], 'encoder', 'audio_encoder', rename_all=False)
        self.model.load_state_dict(model_dict['model'], strict=False)

    if self.args.task == 'audio_translation' and self.args.mt_pt is not None:
        print('| loaded mt checkpoint {}'.format(self.args.mt_pt))
        model_dict = torch.load(self.args.mt_pt)
        self.model.load_state_dict(model_dict['model'], strict=False)

    return extra_state
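Below is a small, self-contained sketch of what the renaming helper does to typical transformer parameter names; the toy keys are illustrative, and the helper is copied to module scope only so the demo runs on its own. With rename_all=False, pure encoder keys move under the new prefix while decoder-side keys, including decoder cross-attention parameters whose names mention 'encoder', are left untouched:

    from collections import OrderedDict

    def rename_state_dict_keys(source, old_name, new_name, rename_all=True):
        # same helper as above, hoisted to module scope so the demo is runnable
        new_state_dict = OrderedDict()
        for key, value in source.items():
            if rename_all:
                key = key.replace(old_name, new_name)
            elif 'encoder' in key and 'decoder' not in key:
                key = key.replace(old_name, new_name)
            new_state_dict[key] = value
        return new_state_dict

    # toy state dict with fairseq-style parameter names (illustrative)
    toy = OrderedDict([
        ('encoder.layers.0.self_attn.k_proj.weight', 0),
        ('decoder.layers.0.encoder_attn.k_proj.weight', 1),
        ('decoder.embed_tokens.weight', 2),
    ])

    renamed = rename_state_dict_keys(toy, 'encoder', 'audio_encoder', rename_all=False)
    print(list(renamed))
    # ['audio_encoder.layers.0.self_attn.k_proj.weight',  <- pure encoder key renamed
    #  'decoder.layers.0.encoder_attn.k_proj.weight',     <- cross-attention key kept
    #  'decoder.embed_tokens.weight']                     <- decoder key kept

Excluding keys that also mention 'decoder' is what keeps the pre-trained audio encoder from clobbering the decoder's encoder_attn parameters when the renamed state dict is loaded with strict=False.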