Exemplo n.º 1
0
 def save_model(self):
     """Hand the current model to ``save_model_with_metadata`` as a step checkpoint.

     The file name is ``model__`` followed by ``self.train_state.global_steps``
     formatted as a zero-padded, nine-digit integer. An empty metadata dict is
     passed, and ``self.output_dir`` is used as the output directory.
     """
     global_step = self.train_state.global_steps
     checkpoint_name = f"model__{global_step:09d}"
     save_model_with_metadata(
         model=self.model,
         metadata={},
         output_dir=self.output_dir,
         file_name=checkpoint_name,
     )
Exemplo n.º 2
0
 def save_last_model_with_metadata(self):
     """Hand ``self.model`` to ``save_model_with_metadata`` under the name "last_model".

     The metadata has a single ``"train_state"`` entry holding
     ``self.train_state.to_dict()``; ``self.output_dir`` is used as the output
     directory.
     """
     last_model_metadata = {"train_state": self.train_state.to_dict()}
     save_model_with_metadata(
         model_or_state_dict=self.model,
         output_dir=self.output_dir,
         file_name="last_model",
         metadata=last_model_metadata,
     )
Exemplo n.º 3
0
 def save_best_model_with_metadata(self, val_metrics_dict):
     """Hand ``self.best_state_dict`` to ``save_model_with_metadata`` as "best_model".

     Both metadata entries are taken from ``self.best_val_state``:
     ``"val_state"`` is its ``to_dict()`` result and ``"val_metrics"`` is its
     ``metrics`` attribute. ``self.output_dir`` is used as the output directory.

     Args:
         val_metrics_dict: Accepted but not read by this method.
     """
     best_val_state = self.best_val_state
     best_model_metadata = {
         "val_state": best_val_state.to_dict(),
         "val_metrics": best_val_state.metrics,
     }
     save_model_with_metadata(
         model_or_state_dict=self.best_state_dict,
         output_dir=self.output_dir,
         file_name="best_model",
         metadata=best_model_metadata,
     )
Exemplo n.º 4
0
 def eval_save(self):
     """Run a validation pass, log it, and update the best-model bookkeeping.

     Steps, in order:

     1. Increment ``self.num_evals_since_improvement``.
     2. Call ``self.runner.run_val`` on the task run config's
        ``train_val_task_list`` with ``use_subset=True``.
     3. Build a ``ValState`` from the aggregated major metric (cast to
        ``float``), the metrics dict, and ``self.train_state.new()``, and log
        it under ``"train_val"``.
     4. If there is no best state yet, or the new score is strictly greater
        than the best score: store ``val_state.new()`` as the best state, log
        it under ``"train_val_best"``, call ``save_model_with_metadata`` with
        file name ``"best_model"`` when ``self.save_best_model`` is truthy,
        replace ``self.best_state_dict`` with a fresh copy of the model's
        state dict, and reset the counter to 0.
     5. Log an ``"early_stopping"`` entry, flush the log writer, and append
        the new ``ValState`` to ``self.val_state_history``.
     """
     self.num_evals_since_improvement += 1

     runner = self.runner
     train_val_tasks = runner.jiant_task_container.task_run_config.train_val_task_list
     results = runner.run_val(
         task_name_list=train_val_tasks,
         use_subset=True,
     )
     major_score = jiant_task_sampler.compute_aggregate_major_metrics_from_results_dict(
         metrics_aggregator=runner.jiant_task_container.metrics_aggregator,
         results_dict=results,
     )
     metrics = jiant_task_sampler.get_metrics_dict_from_results_dict(
         results_dict=results,
     )
     current = ValState(
         score=float(major_score),
         metrics=metrics,
         train_state=self.train_state.new(),
     )
     self.log_writer.write_entry("train_val", current.to_dict())

     previous_best = self.best_val_state
     improved = previous_best is None or current.score > previous_best.score
     if improved:
         self.best_val_state = current.new()
         self.log_writer.write_entry(
             "train_val_best", self.best_val_state.to_dict()
         )
         if self.save_best_model:
             best_model_metadata = {
                 "val_state": self.best_val_state.to_dict(),
                 "val_metrics": metrics,
             }
             save_model_with_metadata(
                 model=self.model,
                 metadata=best_model_metadata,
                 output_dir=self.output_dir,
                 file_name="best_model",
             )
         # The old snapshot attribute is removed before the replacement is
         # built -- presumably so the previous copy can be released first.
         del self.best_state_dict
         model_for_saving = get_model_for_saving(self.model)
         self.best_state_dict = copy_state_dict(
             state_dict=model_for_saving.state_dict(),
             target_device=CPU_DEVICE,
         )
         self.num_evals_since_improvement = 0

     early_stopping_entry = {
         "num_evals_since_improvement": self.num_evals_since_improvement,
         "train_state": self.train_state.to_dict(),
     }
     self.log_writer.write_entry("early_stopping", early_stopping_entry)
     self.log_writer.flush()
     self.val_state_history.append(current)