Beispiel #1
0
 def done_training(self):
     """Wrap up training: optionally save the last model, evaluate, restore the best weights.

     Steps, in order:
       1. If ``self.save_last_model`` is truthy, call
          ``self.save_last_model_with_metadata()``.
       2. Run a final ``self.eval_save()``.
       3. If ``self.load_best_model`` is truthy and ``self.best_state_dict`` is
          not None, load a copy of ``self.best_state_dict`` into the model. The
          copy goes into ``self.model.module`` when the model is wrapped in
          ``nn.DataParallel``, otherwise into ``self.model`` itself.
     """
     if self.save_last_model:
         self.save_last_model_with_metadata()
     self.eval_save()

     # Nothing to restore unless restoring was requested and a best snapshot exists.
     if not self.load_best_model or self.best_state_dict is None:
         return

     # NOTE(review): target_device used to be self.device; None presumably means
     # "leave the copied tensors where they are" -- confirm against copy_state_dict.
     best_weights = copy_state_dict(
         state_dict=self.best_state_dict,
         target_device=None,
     )
     is_wrapped = isinstance(self.model, nn.DataParallel)
     destination = self.model.module if is_wrapped else self.model
     destination.load_state_dict(best_weights)
Beispiel #2
0
 def done_training(self):
     """Wrap up training: evaluate one last time, then optionally restore the best weights.

     Runs ``self.eval_save()``. Then, if ``self.load_best_model`` is truthy and
     ``self.best_state_dict`` is not None, it loads a copy of
     ``self.best_state_dict`` into the model. When ``self.verbose`` is set, it
     first prints "Loading Best". The copy goes into ``self.model.module`` when
     the model is wrapped in ``nn.DataParallel``, otherwise into ``self.model``.
     """
     self.eval_save()

     # Nothing to restore unless restoring was requested and a best snapshot exists.
     if not self.load_best_model or self.best_state_dict is None:
         return

     if self.verbose:
         print("Loading Best")
     # NOTE(review): target_device used to be self.device; None presumably means
     # "leave the copied tensors where they are" -- confirm against copy_state_dict.
     best_weights = copy_state_dict(
         state_dict=self.best_state_dict,
         target_device=None,
     )
     is_wrapped = isinstance(self.model, nn.DataParallel)
     destination = self.model.module if is_wrapped else self.model
     destination.load_state_dict(best_weights)
Beispiel #3
0
 def eval_save(self):
     """Run one validation pass, track the best result, and log early-stopping state.

     On every call this method:
       * increments ``self.num_evals_since_improvement``;
       * runs ``self.runner.run_val`` on the train-val task list with
         ``use_subset=True``;
       * aggregates the major metric into a ``ValState`` and logs it under
         ``"train_val"``.

     If there is no best state yet, or the new score is an improvement, it also:
       * records the new state as ``self.best_val_state`` and logs it under
         ``"train_val_best"``;
       * saves the model via ``save_model_with_metadata(..., file_name="best_model")``
         when ``self.save_best_model`` is set;
       * snapshots the (unwrapped) model weights into ``self.best_state_dict``
         on ``CPU_DEVICE``;
       * resets the no-improvement counter to 0.

     Finally it logs an ``"early_stopping"`` entry, flushes the log writer, and
     appends the new state to ``self.val_state_history``.

     A NaN score never compares greater than anything. A stored NaN best score
     is therefore treated as replaceable by any non-NaN score; otherwise one
     NaN evaluation would block every later improvement.
     """
     import math

     self.num_evals_since_improvement += 1
     val_results_dict = self.runner.run_val(
         task_name_list=self.runner.jiant_task_container.task_run_config.train_val_task_list,
         use_subset=True,
     )
     aggregated_major = jiant_task_sampler.compute_aggregate_major_metrics_from_results_dict(
         metrics_aggregator=self.runner.jiant_task_container.metrics_aggregator,
         results_dict=val_results_dict,
     )
     val_metrics_dict = jiant_task_sampler.get_metrics_dict_from_results_dict(
         results_dict=val_results_dict,
     )
     val_state = ValState(
         score=float(aggregated_major),
         metrics=val_metrics_dict,
         train_state=self.train_state.new(),
     )
     self.log_writer.write_entry("train_val", val_state.to_dict())

     best = self.best_val_state
     improved = (
         best is None
         or val_state.score > best.score
         # `x > nan` is always False, so a NaN best must be displaced explicitly.
         or (math.isnan(best.score) and not math.isnan(val_state.score))
     )
     if improved:
         self.best_val_state = val_state.new()
         self.log_writer.write_entry("train_val_best", self.best_val_state.to_dict())
         if self.save_best_model:
             save_model_with_metadata(
                 model=self.model,
                 metadata={
                     "val_state": self.best_val_state.to_dict(),
                     "val_metrics": val_metrics_dict,
                 },
                 output_dir=self.output_dir,
                 file_name="best_model",
             )
         # Drop the previous snapshot before copying the new one, so both
         # copies are not held at the same time.
         del self.best_state_dict
         self.best_state_dict = copy_state_dict(
             state_dict=get_model_for_saving(self.model).state_dict(),
             target_device=CPU_DEVICE,
         )
         self.num_evals_since_improvement = 0

     self.log_writer.write_entry(
         "early_stopping",
         {
             "num_evals_since_improvement": self.num_evals_since_improvement,
             "train_state": self.train_state.to_dict(),
         },
     )
     self.log_writer.flush()
     self.val_state_history.append(val_state)