def save_best_averaged_checkpoint(self, args, trainer, extra_state: Dict[str, Any]):
    """
    save() should always be called before calling this function - to ensure
    that extra_state and self._averaged_params have been updated correctly.
    """
    best_averaged_checkpoint_filename = os.path.join(
        args.save_dir, constants.AVERAGED_CHECKPOINT_BEST_FILENAME
    )
    self.log_if_verbose(
        f"| Preparing to save new best averaged checkpoint to "
        f"{best_averaged_checkpoint_filename}."
    )
    checkpoint_utils.save_state(
        filename=best_averaged_checkpoint_filename,
        args=args,
        model_state_dict=self._averaged_params,
        criterion=trainer.criterion,
        optimizer=trainer.optimizer,
        lr_scheduler=trainer.lr_scheduler,
        num_updates=trainer._num_updates,
        optim_history=trainer._optim_history,
        extra_state=extra_state,
    )
    self.log_if_verbose(
        f"| Finished saving new best averaged checkpoint to "
        f"{best_averaged_checkpoint_filename}."
    )
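The docstring above encodes an ordering contract: save() must run first so that extra_state and self._averaged_params are current before the averaged weights are written out. A minimal usage sketch of that call order follows; the checkpoint_manager name, the save() signature, and the extra_state contents are illustrative assumptions, not taken from the original source.

# Hypothetical caller; `checkpoint_manager`, `args`, `trainer`, and the
# contents of `extra_state` are illustrative, and save()'s signature is
# assumed to mirror save_best_averaged_checkpoint's.
extra_state = {"epoch": 12, "val_loss": 1.87}

# 1. save() refreshes extra_state and self._averaged_params ...
checkpoint_manager.save(args, trainer, extra_state)

# 2. ... so the averaged parameters written here are up to date.
checkpoint_manager.save_best_averaged_checkpoint(args, trainer, extra_state)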
def save_checkpoint(self, filename, extra_state):
    """Save all training state in a checkpoint file."""
    if distributed_utils.is_master(self.args):  # only save one checkpoint
        extra_state['train_meters'] = self.meters
        checkpoint_utils.save_state(
            filename,
            self.args,
            self.get_model().state_dict(),
            self.criterion,
            self.optimizer,
            self.lr_scheduler,
            self._num_updates,
            self._optim_history,
            extra_state,
        )
def save_checkpoint(self, filename, extra_state):
    """Save all training state in a checkpoint file."""
    if self.is_data_parallel_master:  # only save one checkpoint
        extra_state["metrics"] = metrics.state_dict()
        checkpoint_utils.save_state(
            filename,
            self.args,
            self.get_model().state_dict(),
            self.get_criterion(),
            self.optimizer,
            self.lr_scheduler,
            self.get_num_updates(),
            self._optim_history,
            extra_state,
        )
def save_checkpoint(self, filename, extra_state):
    """Save all training state in a checkpoint file."""
    if distributed_utils.is_master(self.args):  # only save one checkpoint
        extra_state["train_meters"] = self.meters
        # state_dict(): implemented in nn.Module; returns a dict of the items
        # stored in _parameters and _buffers for this module and its submodules
        checkpoint_utils.save_state(
            filename,
            self.args,
            self.get_model().state_dict(),
            self.get_criterion(),
            self.optimizer,
            self.lr_scheduler,
            self.get_num_updates(),
            self._optim_history,
            extra_state,
        )
def save_checkpoint(self, filename, extra_state):
    """Save all training state in a checkpoint file."""
    if self.is_data_parallel_master:  # only save one checkpoint
        extra_state["metrics"] = metrics.state_dict()
        extra_state["previous_training_time"] = self.cumulative_training_time()
        checkpoint_utils.save_state(
            filename,
            self.cfg,
            self.get_model().state_dict(),
            self.get_criterion(),
            self.optimizer,
            self.lr_scheduler,
            self.get_num_updates(),
            self._optim_history,
            extra_state,
        )
        logger.info(f"Finished saving checkpoint to {filename}")
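Every variant delegates the actual serialization to checkpoint_utils.save_state. As a rough sketch of the contract implied by the call sites, such a helper bundles its arguments into one serializable dict and hands it to torch.save. The function below is an illustrative reimplementation under that assumption, not fairseq's actual code, which differs in detail.

import torch

def save_state_sketch(filename, args, model_state_dict, criterion,
                      optimizer, lr_scheduler, num_updates,
                      optim_history, extra_state):
    # Illustrative only: bundle everything the call sites pass into a single
    # dict. The real checkpoint_utils.save_state is more involved (e.g. it
    # manages an optimizer_history list and the last optimizer state).
    state = {
        "args": args,
        "model": model_state_dict,
        "optimizer_history": list(optim_history) + [
            {
                "criterion_name": criterion.__class__.__name__,
                "lr_scheduler_state": lr_scheduler.state_dict(),
                "num_updates": num_updates,
            }
        ],
        "last_optimizer_state": optimizer.state_dict(),
        "extra_state": extra_state,
    }
    torch.save(state, filename)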
def save_checkpoint(self, filename, extra_state):
    """Save all training state in a checkpoint file."""
    # tpu-comment: only save one checkpoint unless xla
    # torch_xla handles only saving from master in `xm.save`
    if self.xla or distributed_utils.is_master(self.args):
        extra_state['train_meters'] = self.meters
        checkpoint_utils.save_state(
            filename,
            self.args,
            self.get_model().state_dict(),
            self.get_criterion(),
            self.optimizer,
            self.lr_scheduler,
            self.get_num_updates(),
            self._optim_history,
            extra_state,
        )
        gcsfs.generic_write(
            self.checkpoint_tagger.save_to_json().encode(),
            os.path.join(
                os.path.dirname(filename), self.checkpoint_tagger_filename
            ),
        )
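The last variant relaxes the master-only guard under XLA because torch_xla's xm.save performs its own rendezvous and, by default, writes only from the master ordinal, so every process may call it safely. A minimal sketch of that dispatch, assuming torch_xla is installed; the save_on_tpu name and use_xla flag are illustrative.

import torch
import torch_xla.core.xla_model as xm

def save_on_tpu(state, filename, use_xla):
    # Illustrative dispatch between TPU and CPU/GPU checkpointing.
    if use_xla:
        # xm.save synchronizes all processes and writes only from the
        # master ordinal, so it does not need an external master guard.
        xm.save(state, filename)
    else:
        torch.save(state, filename)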