def get_runner_state(self):
    """Return a serializable snapshot of this runner: model weights and optimizer state.

    TODO: Add fp16 (issue #1186)
    """
    model_to_save = torch_utils.get_model_for_saving(self.jiant_model)
    return {
        "model": model_to_save.state_dict(),
        "optimizer": self.optimizer_scheduler.optimizer.state_dict(),
    }
def save_model_with_metadata(model: nn.Module, metadata: dict, output_dir: str, file_name="model"):
    """Save a model's state dict and a metadata JSON side by side in ``output_dir``.

    Writes ``{file_name}.p`` (torch-serialized state dict, after unwrapping any
    DataParallel-style container via ``torch_utils.get_model_for_saving``) and
    ``{file_name}.metadata.json``.
    """
    weights_path = os.path.join(output_dir, f"{file_name}.p")
    metadata_path = os.path.join(output_dir, f"{file_name}.metadata.json")
    state_dict = torch_utils.get_model_for_saving(model).state_dict()
    torch.save(state_dict, weights_path)
    py_io.write_json(metadata, metadata_path)
def save_model(self):
    """Override to save only optimized parameters.

    Saves the adapter-optimized subset of the model's state dict to
    ``{output_dir}/model__{global_steps}.p`` and an empty metadata JSON
    alongside it.
    """
    file_name = f"model__{self.train_state.global_steps:09d}"
    state_dict = adapters_modeling.get_optimized_state_dict_for_jiant_model_with_adapters(
        torch_utils.get_model_for_saving(self.model)
    )
    torch.save(state_dict, os.path.join(self.output_dir, f"{file_name}.p"))
    # BUG FIX: previously passed the string "{}", which serializes as a JSON
    # *string literal* ("\"{}\"") rather than an empty JSON object. Pass an
    # actual empty dict so the metadata file contains `{}`.
    py_io.write_json({}, os.path.join(self.output_dir, f"{file_name}.metadata.json"))
def save_model_with_metadata(model: nn.Module, metadata: dict, output_dir: str, file_name="model"):
    """Save only the adapter-optimized parameters plus a metadata JSON.

    Writes ``{file_name}.p`` containing the optimized-state-dict subset and
    ``{file_name}.metadata.json`` next to it in ``output_dir``.
    """
    unwrapped_model = torch_utils.get_model_for_saving(model)
    optimized_state = adapters_modeling.get_optimized_state_dict_for_jiant_model_with_adapters(
        unwrapped_model
    )
    torch.save(optimized_state, os.path.join(output_dir, f"{file_name}.p"))
    py_io.write_json(metadata, os.path.join(output_dir, f"{file_name}.metadata.json"))
def save_model_with_metadata(
    model_or_state_dict: Union[nn.Module, dict],
    output_dir: str,
    file_name="model",
    metadata: Optional[dict] = None,
):
    """Save model weights (and, optionally, metadata) under ``output_dir``.

    Accepts either a ready-made state dict or a module, which is unwrapped via
    ``torch_utils.get_model_for_saving`` before taking its state dict. The
    metadata JSON is written only when ``metadata`` is provided.
    """
    state_dict = (
        model_or_state_dict
        if isinstance(model_or_state_dict, dict)
        else torch_utils.get_model_for_saving(model_or_state_dict).state_dict()
    )
    torch.save(state_dict, os.path.join(output_dir, f"{file_name}.p"))
    if metadata is not None:
        py_io.write_json(metadata, os.path.join(output_dir, f"{file_name}.metadata.json"))
def eval_save(self):
    """Run validation, track the best score so far, and log early-stopping state.

    Increments the no-improvement counter first; if this evaluation sets a new
    best score, the counter is reset, the best state is (optionally) saved to
    disk, and a CPU copy of the model state dict is cached on the metarunner.
    """
    self.num_evals_since_improvement += 1
    # Validate on a subset of the train-val tasks for speed.
    val_results_dict = self.runner.run_val(
        task_name_list=self.runner.jiant_task_container.task_run_config.train_val_task_list,
        use_subset=True,
    )
    # Single scalar used for best-model comparison.
    aggregated_major = jiant_task_sampler.compute_aggregate_major_metrics_from_results_dict(
        metrics_aggregator=self.runner.jiant_task_container.metrics_aggregator,
        results_dict=val_results_dict,
    )
    val_metrics_dict = jiant_task_sampler.get_metrics_dict_from_results_dict(
        results_dict=val_results_dict,
    )
    val_state = ValState(
        score=float(aggregated_major),
        metrics=val_metrics_dict,
        train_state=self.train_state.new(),
    )
    self.log_writer.write_entry("train_val", val_state.to_dict())
    # Higher score is better; first evaluation always becomes the best.
    if self.best_val_state is None or val_state.score > self.best_val_state.score:
        self.best_val_state = val_state.new()
        self.log_writer.write_entry("train_val_best", self.best_val_state.to_dict())
        if self.save_best_model:
            save_model_with_metadata(
                model=self.model,
                metadata={
                    "val_state": self.best_val_state.to_dict(),
                    "val_metrics": val_metrics_dict,
                },
                output_dir=self.output_dir,
                file_name="best_model",
            )
        # Drop the previous cached copy before making a new one, to avoid
        # holding two full state dicts in memory at once.
        del self.best_state_dict
        self.best_state_dict = copy_state_dict(
            state_dict=get_model_for_saving(self.model).state_dict(),
            target_device=CPU_DEVICE,
        )
        self.num_evals_since_improvement = 0
    # Always record the early-stopping counter, improved or not.
    self.log_writer.write_entry(
        "early_stopping",
        {
            "num_evals_since_improvement": self.num_evals_since_improvement,
            "train_state": self.train_state.to_dict(),
        },
    )
    self.log_writer.flush()
    self.val_state_history.append(val_state)
def load_state(self, runner_state):
    """Restore model and optimizer from a snapshot produced by ``get_runner_state``."""
    model = torch_utils.get_model_for_saving(self.jiant_model)
    model.load_state_dict(runner_state["model"])
    optimizer = self.optimizer_scheduler.optimizer
    optimizer.load_state_dict(runner_state["optimizer"])
def run_loop(args: RunConfiguration, checkpoint=None):
    """End-to-end train/val/test driver: setup, (resumed) training, saving, and evaluation.

    When ``checkpoint`` is provided, runner and metarunner state are restored
    from it before training continues.
    """
    is_resumed = checkpoint is not None
    quick_init_out = initialization.quick_init(args=args, verbose=True)
    print(quick_init_out.n_gpu)
    with quick_init_out.log_writer.log_context():
        jiant_task_container = container_setup.create_jiant_task_container_from_json(
            jiant_task_container_config_path=args.jiant_task_container_config_path,
            verbose=True,
        )
        runner = setup_runner(
            args=args,
            jiant_task_container=jiant_task_container,
            quick_init_out=quick_init_out,
            verbose=True,
        )
        if is_resumed:
            runner.load_state(checkpoint["runner_state"])
            # Free the restored state immediately; it can be large.
            del checkpoint["runner_state"]
        checkpoint_saver = jiant_runner.CheckpointSaver(
            metadata={"args": args.to_dict()},
            save_path=os.path.join(args.output_dir, "checkpoint.p"),
        )
        if args.do_train:
            metarunner = jiant_metarunner.JiantMetarunner(
                runner=runner,
                save_every_steps=args.save_every_steps,
                eval_every_steps=args.eval_every_steps,
                save_checkpoint_every_steps=args.save_checkpoint_every_steps,
                no_improvements_for_n_evals=args.no_improvements_for_n_evals,
                checkpoint_saver=checkpoint_saver,
                output_dir=args.output_dir,
                verbose=True,
                save_best_model=args.do_save,
                load_best_model=True,
                log_writer=quick_init_out.log_writer,
            )
            if is_resumed:
                metarunner.load_state(checkpoint["metarunner_state"])
                del checkpoint["metarunner_state"]
            metarunner.run_train_loop()
        if args.do_save:
            torch.save(
                torch_utils.get_model_for_saving(runner.jiant_model).state_dict(),
                os.path.join(args.output_dir, "model.p"),
            )
        if args.do_val:
            val_results_dict = runner.run_val(
                task_name_list=runner.jiant_task_container.task_run_config.val_task_list,
                return_preds=args.write_val_preds,
            )
            jiant_evaluate.write_val_results(
                val_results_dict=val_results_dict,
                metrics_aggregator=runner.jiant_task_container.metrics_aggregator,
                output_dir=args.output_dir,
                verbose=True,
            )
            if args.write_val_preds:
                jiant_evaluate.write_preds(
                    eval_results_dict=val_results_dict,
                    path=os.path.join(args.output_dir, "val_preds.p"),
                )
        else:
            # Predictions can only be written if validation actually ran.
            assert not args.write_val_preds
        if args.write_test_preds:
            test_results_dict = runner.run_test(
                task_name_list=runner.jiant_task_container.task_run_config.test_task_list,
            )
            jiant_evaluate.write_preds(
                eval_results_dict=test_results_dict,
                path=os.path.join(args.output_dir, "test_preds.p"),
            )
        # Clean up the rolling checkpoint once the run is complete, if asked.
        if (
            args.delete_checkpoint_if_done
            and args.save_checkpoint_every_steps
            and os.path.exists(os.path.join(args.output_dir, "checkpoint.p"))
        ):
            os.remove(os.path.join(args.output_dir, "checkpoint.p"))
    # Marker file signaling that the full run finished successfully.
    py_io.write_file("DONE", os.path.join(args.output_dir, "done_file"))
def run_loop(args: RunConfiguration, checkpoint=None):
    """End-to-end train/val/test driver with custom checkpoint/model naming.

    When ``checkpoint`` is provided, runner and metarunner state are restored
    from it before training continues. ``args.custom_checkpoint_name`` and
    ``args.custom_best_name`` override the default "checkpoint.p" / "model.p"
    file names under ``args.output_dir``.
    """
    is_resumed = checkpoint is not None
    quick_init_out = initialization.quick_init(args=args, verbose=True)
    print(quick_init_out.n_gpu)
    with quick_init_out.log_writer.log_context():
        jiant_task_container = container_setup.create_jiant_task_container_from_json(
            jiant_task_container_config_path=args.jiant_task_container_config_path,
            verbose=True,
        )
        runner = setup_runner(
            args=args,
            jiant_task_container=jiant_task_container,
            quick_init_out=quick_init_out,
            verbose=True,
        )
        if is_resumed:
            runner.load_state(checkpoint["runner_state"])
            # Free the restored state immediately; it can be large.
            del checkpoint["runner_state"]
        # Allow custom checkpoint name; the result is a full path under output_dir.
        if args.custom_checkpoint_name:
            checkpoint_path = os.path.join(args.output_dir, f"{args.custom_checkpoint_name}.p")
        else:
            checkpoint_path = os.path.join(args.output_dir, "checkpoint.p")
        checkpoint_saver = jiant_runner.CheckpointSaver(
            metadata={"args": args.to_dict()},
            # BUG FIX: checkpoint_path already includes output_dir; the previous
            # os.path.join(args.output_dir, checkpoint_name) duplicated the
            # directory prefix in the save path.
            save_path=checkpoint_path,
        )
        if args.do_train:
            metarunner = jiant_metarunner.JiantMetarunner(
                runner=runner,
                save_every_steps=args.save_every_steps,
                eval_every_steps=args.eval_every_steps,
                save_checkpoint_every_steps=args.save_checkpoint_every_steps,
                no_improvements_for_n_evals=args.no_improvements_for_n_evals,
                checkpoint_saver=checkpoint_saver,
                output_dir=args.output_dir,
                verbose=True,
                save_best_model=args.do_save,
                load_best_model=True,
                log_writer=quick_init_out.log_writer,
            )
            if is_resumed:
                metarunner.load_state(checkpoint["metarunner_state"])
                del checkpoint["metarunner_state"]
            metarunner.run_train_loop()
        if args.do_save:
            # Allow custom best model name.
            if args.custom_best_name:
                best_model_path = os.path.join(args.output_dir, f"{args.custom_best_name}.p")
            else:
                best_model_path = os.path.join(args.output_dir, "model.p")
            torch.save(
                torch_utils.get_model_for_saving(runner.jiant_model).state_dict(),
                best_model_path,
            )
        if args.do_val:
            val_results_dict = runner.run_val(
                task_name_list=runner.jiant_task_container.task_run_config.val_task_list,
                return_preds=args.write_val_preds,
            )
            jiant_evaluate.write_val_results(
                val_results_dict=val_results_dict,
                metrics_aggregator=runner.jiant_task_container.metrics_aggregator,
                output_dir=args.output_dir,
                verbose=True,
                val_jsonl=args.val_jsonl,
            )
            if args.args_jsonl:
                # Match arguments with verbose results.
                initialization.save_args(args, verbose=True, matched=True)
            if args.write_val_preds:
                # Optionally derive the predictions file name from the task
                # container config's base name (e.g. "mnli.json" -> "val_preds_mnli.p").
                if args.extract_exp_name_valpreds:
                    exp_name = os.path.basename(args.jiant_task_container_config_path).split(".")[0]
                    val_fname = f"val_preds_{exp_name}.p"
                else:
                    val_fname = "val_preds.p"
                jiant_evaluate.write_preds(
                    eval_results_dict=val_results_dict,
                    path=os.path.join(args.output_dir, val_fname),
                )
        else:
            # Predictions can only be written if validation actually ran.
            assert not args.write_val_preds
        if args.write_test_preds:
            test_results_dict = runner.run_test(
                task_name_list=runner.jiant_task_container.task_run_config.test_task_list,
            )
            jiant_evaluate.write_preds(
                eval_results_dict=test_results_dict,
                path=os.path.join(args.output_dir, "test_preds.p"),
            )
        # Clean up the rolling checkpoint once the run is complete, if asked.
        # BUG FIX: (1) removed the double os.path.join that duplicated
        # output_dir; (2) guard with os.path.exists so a run that never wrote a
        # checkpoint does not crash with FileNotFoundError (matches the sibling
        # run_loop variant).
        if (
            args.delete_checkpoint_if_done
            and args.save_checkpoint_every_steps
            and os.path.exists(checkpoint_path)
        ):
            os.remove(checkpoint_path)