Example #1
 def save_best_averaged_checkpoint(self, args, trainer,
                                   extra_state: Dict[str, Any]):
     """
     save() should always be called before calling this function - to ensure
     that extra_state and self._averaged_params have been updated correctly.
     """
     best_averaged_checkpoint_filename = os.path.join(
         args.save_dir, constants.AVERAGED_CHECKPOINT_BEST_FILENAME)
     self.log_if_verbose(
         f"| Preparing to save new best averaged checkpoint to "
         f"{best_averaged_checkpoint_filename}.")
     checkpoint_utils.save_state(
         filename=best_averaged_checkpoint_filename,
         args=args,
         model_state_dict=self._averaged_params,
         criterion=trainer.criterion,
         optimizer=trainer.optimizer,
         lr_scheduler=trainer.lr_scheduler,
         num_updates=trainer._num_updates,
         optim_history=trainer._optim_history,
         extra_state=extra_state,
     )
     self.log_if_verbose(
         f"| Finished saving new best averaged checkpoint to "
         f"{best_averaged_checkpoint_filename}.")
Example #2
 def save_checkpoint(self, filename, extra_state):
     """Save all training state in a checkpoint file."""
     if distributed_utils.is_master(self.args):  # only save one checkpoint
         extra_state['train_meters'] = self.meters
         checkpoint_utils.save_state(
             filename, self.args, self.get_model().state_dict(), self.criterion, self.optimizer,
             self.lr_scheduler, self._num_updates, self._optim_history, extra_state,
         )
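The is_master / is_data_parallel_master guard in these examples ensures that only one process writes the checkpoint in a distributed run. A rough stand-in for that pattern in plain PyTorch terms (not fairseq's actual helpers) is sketched below:

 import torch
 import torch.distributed as dist

 def is_master() -> bool:
     """Treat rank 0 (or a non-distributed run) as the master process."""
     return not dist.is_available() or not dist.is_initialized() or dist.get_rank() == 0

 def save_checkpoint_sketch(filename, model, extra_state):
     # Only the master rank writes the file; other ranks skip to avoid
     # redundant or conflicting writes to the same path.
     if is_master():
         torch.save({"model": model.state_dict(), "extra_state": extra_state}, filename)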
Example #3
 def save_checkpoint(self, filename, extra_state):
     """Save all training state in a checkpoint file."""
     if self.is_data_parallel_master:  # only save one checkpoint
         extra_state["metrics"] = metrics.state_dict()
         checkpoint_utils.save_state(
             filename,
             self.args,
             self.get_model().state_dict(),
             self.get_criterion(),
             self.optimizer,
             self.lr_scheduler,
             self.get_num_updates(),
             self._optim_history,
             extra_state,
         )
Example #4
 def save_checkpoint(self, filename, extra_state):
     """Save all training state in a checkpoint file."""
     if distributed_utils.is_master(self.args):  # only save one checkpoint
         extra_state["train_meters"] = self.meters
         # state_dict(): implemented in nn.Module; returns a dict of the items stored
         # in the _parameters and _buffers of this module and all of its submodules
         checkpoint_utils.save_state(
             filename,
             self.args,
             self.get_model().state_dict(),
             self.get_criterion(),
             self.optimizer,
             self.lr_scheduler,
             self.get_num_updates(),
             self._optim_history,
             extra_state,
         )
Example #5
 def save_checkpoint(self, filename, extra_state):
     """Save all training state in a checkpoint file."""
     if self.is_data_parallel_master:  # only save one checkpoint
         extra_state["metrics"] = metrics.state_dict()
         extra_state["previous_training_time"] = self.cumulative_training_time()
         checkpoint_utils.save_state(
             filename,
             self.cfg,
             self.get_model().state_dict(),
             self.get_criterion(),
             self.optimizer,
             self.lr_scheduler,
             self.get_num_updates(),
             self._optim_history,
             extra_state,
         )
         logger.info(f"Finished saving checkpoint to {filename}")
Example #6
 def save_checkpoint(self, filename, extra_state):
     """Save all training state in a checkpoint file."""
     # tpu-comment: only save one checkpoint unless xla
     # torch_xla handles only saving from master in `xm.save`
     if self.xla or distributed_utils.is_master(self.args):
         extra_state['train_meters'] = self.meters
         checkpoint_utils.save_state(
             filename,
             self.args,
             self.get_model().state_dict(),
             self.get_criterion(),
             self.optimizer,
             self.lr_scheduler,
             self.get_num_updates(),
             self._optim_history,
             extra_state,
         )
         gcsfs.generic_write(
             self.checkpoint_tagger.save_to_json().encode(),
             os.path.join(os.path.dirname(filename),
                          self.checkpoint_tagger_filename))
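All of the examples above delegate the actual write to checkpoint_utils.save_state with essentially the same arguments. As a simplified sketch of what such a helper amounts to (illustrative only, not fairseq's actual implementation; the dict keys are assumptions), it collects the model, criterion, optimizer, and scheduler state together with extra_state into one dict and hands it to torch.save:

 import torch

 def save_state_sketch(filename, args, model_state_dict, criterion, optimizer,
                       lr_scheduler, num_updates, optim_history, extra_state):
     """Simplified stand-in for checkpoint_utils.save_state (illustrative only)."""
     state = {
         "args": args,
         "model": model_state_dict,
         "criterion": criterion.state_dict() if criterion is not None else None,
         "last_optimizer_state": optimizer.state_dict(),
         "lr_scheduler": lr_scheduler.state_dict(),
         "num_updates": num_updates,
         "optimizer_history": optim_history,
         "extra_state": extra_state,
     }
     torch.save(state, filename)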