Example #1
0
 def get_runner_state(self):
     """Collect the model and optimizer state dicts for checkpointing."""
     # TODO: Add fp16  (issue #1186)
     model_state = torch_utils.get_model_for_saving(self.jiant_model).state_dict()
     optimizer_state = self.optimizer_scheduler.optimizer.state_dict()
     return {
         "model": model_state,
         "optimizer": optimizer_state,
     }
Example #2
0
def save_model_with_metadata(model: nn.Module,
                             metadata: dict,
                             output_dir: str,
                             file_name="model"):
    """Save model weights to <output_dir>/<file_name>.p and metadata as JSON alongside."""
    weights_path = os.path.join(output_dir, f"{file_name}.p")
    metadata_path = os.path.join(output_dir, f"{file_name}.metadata.json")
    state_dict = torch_utils.get_model_for_saving(model).state_dict()
    torch.save(state_dict, weights_path)
    py_io.write_json(metadata, metadata_path)
Example #3
0
 def save_model(self):
     """Override to save only optimized parameters"""
     file_name = f"model__{self.train_state.global_steps:09d}"
     torch.save(
         adapters_modeling.
         get_optimized_state_dict_for_jiant_model_with_adapters(
             torch_utils.get_model_for_saving(self.model)),
         os.path.join(self.output_dir, f"{file_name}.p"),
     )
     # Write an empty JSON object. The previous code passed the string "{}",
     # which serializes to the JSON string literal '"{}"' rather than an
     # empty dict, breaking metadata readers that expect an object.
     py_io.write_json(
         {}, os.path.join(self.output_dir, f"{file_name}.metadata.json"))
Example #4
0
def save_model_with_metadata(model: nn.Module,
                             metadata: dict,
                             output_dir: str,
                             file_name="model"):
    """Save the adapter-optimized parameter subset plus metadata JSON to output_dir."""
    raw_model = torch_utils.get_model_for_saving(model)
    optimized_state = (
        adapters_modeling
        .get_optimized_state_dict_for_jiant_model_with_adapters(raw_model)
    )
    torch.save(optimized_state, os.path.join(output_dir, f"{file_name}.p"))
    py_io.write_json(metadata,
                     os.path.join(output_dir, f"{file_name}.metadata.json"))
Example #5
0
def save_model_with_metadata(
    model_or_state_dict: Union[nn.Module, dict],
    output_dir: str,
    file_name="model",
    metadata: Optional[dict] = None,
):
    """Persist model weights (from a module or a pre-extracted state dict).

    Writes <output_dir>/<file_name>.p; when ``metadata`` is given, also writes
    <output_dir>/<file_name>.metadata.json.
    """
    if not isinstance(model_or_state_dict, dict):
        state_dict = torch_utils.get_model_for_saving(model_or_state_dict).state_dict()
    else:
        state_dict = model_or_state_dict

    weights_path = os.path.join(output_dir, f"{file_name}.p")
    torch.save(state_dict, weights_path)
    if metadata is not None:
        metadata_path = os.path.join(output_dir, f"{file_name}.metadata.json")
        py_io.write_json(metadata, metadata_path)
Example #6
0
 def eval_save(self):
     """Run a validation pass, log metrics, and track/save the best model.

     Increments the early-stopping counter on every call and resets it
     (optionally saving the model and caching a CPU copy of its state dict)
     whenever the aggregate validation score improves on the best seen so far.
     """
     self.num_evals_since_improvement += 1
     # Validate on the train tasks' validation splits; use_subset=True keeps
     # this cheap enough to run repeatedly during training.
     val_results_dict = self.runner.run_val(
         task_name_list=self.runner.jiant_task_container.task_run_config.
         train_val_task_list,
         use_subset=True,
     )
     # Collapse per-task results into a single scalar used for model selection.
     aggregated_major = jiant_task_sampler.compute_aggregate_major_metrics_from_results_dict(
         metrics_aggregator=self.runner.jiant_task_container.
         metrics_aggregator,
         results_dict=val_results_dict,
     )
     val_metrics_dict = jiant_task_sampler.get_metrics_dict_from_results_dict(
         results_dict=val_results_dict, )
     val_state = ValState(
         score=float(aggregated_major),
         metrics=val_metrics_dict,
         train_state=self.train_state.new(),
     )
     self.log_writer.write_entry("train_val", val_state.to_dict())
     # New best score: snapshot it, optionally persist the model, and reset
     # the early-stopping counter.
     if self.best_val_state is None or val_state.score > self.best_val_state.score:
         self.best_val_state = val_state.new()
         self.log_writer.write_entry("train_val_best",
                                     self.best_val_state.to_dict())
         if self.save_best_model:
             save_model_with_metadata(
                 model=self.model,
                 metadata={
                     "val_state": self.best_val_state.to_dict(),
                     "val_metrics": val_metrics_dict,
                 },
                 output_dir=self.output_dir,
                 file_name="best_model",
             )
         # Drop the previous cached copy before cloning the new one so both
         # never coexist; the clone is kept on CPU.
         del self.best_state_dict
         self.best_state_dict = copy_state_dict(
             state_dict=get_model_for_saving(self.model).state_dict(),
             target_device=CPU_DEVICE,
         )
         self.num_evals_since_improvement = 0
     self.log_writer.write_entry(
         "early_stopping",
         {
             "num_evals_since_improvement":
             self.num_evals_since_improvement,
             "train_state": self.train_state.to_dict(),
         },
     )
     self.log_writer.flush()
     self.val_state_history.append(val_state)
Example #7
0
 def load_state(self, runner_state):
     """Restore model and optimizer from a saved runner-state dict."""
     model = torch_utils.get_model_for_saving(self.jiant_model)
     model.load_state_dict(runner_state["model"])
     optimizer = self.optimizer_scheduler.optimizer
     optimizer.load_state_dict(runner_state["optimizer"])
Example #8
0
def run_loop(args: RunConfiguration, checkpoint=None):
    """Run the full train / validate / test loop for a jiant experiment.

    Args:
        args: run configuration (paths, do_train/do_val/do_save flags, etc.).
        checkpoint: optional checkpoint dict to resume from; its
            "runner_state" and "metarunner_state" entries are consumed
            (and deleted) during setup.

    Side effects: writes model/checkpoint/prediction files under
    args.output_dir and a "done_file" marker on completion.
    """
    is_resumed = checkpoint is not None
    quick_init_out = initialization.quick_init(args=args, verbose=True)
    print(quick_init_out.n_gpu)
    with quick_init_out.log_writer.log_context():
        jiant_task_container = container_setup.create_jiant_task_container_from_json(
            jiant_task_container_config_path=args.
            jiant_task_container_config_path,
            verbose=True,
        )
        runner = setup_runner(
            args=args,
            jiant_task_container=jiant_task_container,
            quick_init_out=quick_init_out,
            verbose=True,
        )
        if is_resumed:
            # Restore runner state, then free it so the checkpoint dict only
            # carries what is still unconsumed.
            runner.load_state(checkpoint["runner_state"])
            del checkpoint["runner_state"]
        checkpoint_saver = jiant_runner.CheckpointSaver(
            metadata={"args": args.to_dict()},
            save_path=os.path.join(args.output_dir, "checkpoint.p"),
        )
        if args.do_train:
            metarunner = jiant_metarunner.JiantMetarunner(
                runner=runner,
                save_every_steps=args.save_every_steps,
                eval_every_steps=args.eval_every_steps,
                save_checkpoint_every_steps=args.save_checkpoint_every_steps,
                no_improvements_for_n_evals=args.no_improvements_for_n_evals,
                checkpoint_saver=checkpoint_saver,
                output_dir=args.output_dir,
                verbose=True,
                save_best_model=args.do_save,
                load_best_model=True,
                log_writer=quick_init_out.log_writer,
            )
            if is_resumed:
                metarunner.load_state(checkpoint["metarunner_state"])
                del checkpoint["metarunner_state"]
            metarunner.run_train_loop()

        if args.do_save:
            # Save the final model weights (separately from any "best model"
            # the metarunner may have saved during training).
            torch.save(
                torch_utils.get_model_for_saving(
                    runner.jiant_model).state_dict(),
                os.path.join(args.output_dir, "model.p"),
            )

        if args.do_val:
            val_results_dict = runner.run_val(
                task_name_list=runner.jiant_task_container.task_run_config.
                val_task_list,
                return_preds=args.write_val_preds,
            )
            jiant_evaluate.write_val_results(
                val_results_dict=val_results_dict,
                metrics_aggregator=runner.jiant_task_container.
                metrics_aggregator,
                output_dir=args.output_dir,
                verbose=True,
            )
            if args.write_val_preds:
                jiant_evaluate.write_preds(
                    eval_results_dict=val_results_dict,
                    path=os.path.join(args.output_dir, "val_preds.p"),
                )
        else:
            # Predictions can only be written if validation actually ran.
            assert not args.write_val_preds

        if args.write_test_preds:
            test_results_dict = runner.run_test(
                task_name_list=runner.jiant_task_container.task_run_config.
                test_task_list, )
            jiant_evaluate.write_preds(
                eval_results_dict=test_results_dict,
                path=os.path.join(args.output_dir, "test_preds.p"),
            )

    # Clean up the intermediate checkpoint once the run finished successfully;
    # the exists() guard avoids raising when no checkpoint was ever written.
    if (args.delete_checkpoint_if_done and args.save_checkpoint_every_steps
            and os.path.exists(os.path.join(args.output_dir, "checkpoint.p"))):
        os.remove(os.path.join(args.output_dir, "checkpoint.p"))

    py_io.write_file("DONE", os.path.join(args.output_dir, "done_file"))
def run_loop(args: RunConfiguration, checkpoint=None):
    """Run the full train / validate / test loop, with custom file naming.

    Args:
        args: run configuration (paths, flags, custom_checkpoint_name,
            custom_best_name, and val/args jsonl options).
        checkpoint: optional checkpoint dict to resume from; its
            "runner_state" and "metarunner_state" entries are consumed
            (and deleted) during setup.

    Fixes vs. the previous revision:
      * the checkpoint path is built exactly once and reused, so the file
        CheckpointSaver writes and the file deleted at the end are the same
        (the old code joined args.output_dir a second time when constructing
        the saver's save_path, producing a doubled path for relative dirs);
      * the final os.remove is guarded by os.path.exists — matching the
        non-custom variant of this loop — so a checkpoint that was never
        written no longer raises FileNotFoundError.
    """
    is_resumed = checkpoint is not None
    quick_init_out = initialization.quick_init(args=args, verbose=True)
    print(quick_init_out.n_gpu)
    with quick_init_out.log_writer.log_context():
        jiant_task_container = container_setup.create_jiant_task_container_from_json(
            jiant_task_container_config_path=args.
            jiant_task_container_config_path,
            verbose=True,
        )
        runner = setup_runner(
            args=args,
            jiant_task_container=jiant_task_container,
            quick_init_out=quick_init_out,
            verbose=True,
        )
        if is_resumed:
            runner.load_state(checkpoint["runner_state"])
            del checkpoint["runner_state"]

        # Allow a custom checkpoint name; build the full path once here and
        # reuse it for both saving and final deletion.
        if args.custom_checkpoint_name:
            checkpoint_path = os.path.join(
                args.output_dir, f"{args.custom_checkpoint_name}.p")
        else:
            checkpoint_path = os.path.join(args.output_dir, "checkpoint.p")

        checkpoint_saver = jiant_runner.CheckpointSaver(
            metadata={"args": args.to_dict()},
            save_path=checkpoint_path,
        )
        if args.do_train:
            metarunner = jiant_metarunner.JiantMetarunner(
                runner=runner,
                save_every_steps=args.save_every_steps,
                eval_every_steps=args.eval_every_steps,
                save_checkpoint_every_steps=args.save_checkpoint_every_steps,
                no_improvements_for_n_evals=args.no_improvements_for_n_evals,
                checkpoint_saver=checkpoint_saver,
                output_dir=args.output_dir,
                verbose=True,
                save_best_model=args.do_save,
                load_best_model=True,
                log_writer=quick_init_out.log_writer,
            )
            if is_resumed:
                metarunner.load_state(checkpoint["metarunner_state"])
                del checkpoint["metarunner_state"]
            metarunner.run_train_loop()

        if args.do_save:
            # Allow a custom name for the saved final model.
            if args.custom_best_name:
                best_model_path = os.path.join(
                    args.output_dir, f"{args.custom_best_name}.p")
            else:
                best_model_path = os.path.join(args.output_dir, "model.p")

            torch.save(
                torch_utils.get_model_for_saving(
                    runner.jiant_model).state_dict(),
                best_model_path,
            )

        if args.do_val:
            val_results_dict = runner.run_val(
                task_name_list=runner.jiant_task_container.task_run_config.
                val_task_list,
                return_preds=args.write_val_preds,
            )
            jiant_evaluate.write_val_results(
                val_results_dict=val_results_dict,
                metrics_aggregator=runner.jiant_task_container.
                metrics_aggregator,
                output_dir=args.output_dir,
                verbose=True,
                val_jsonl=args.val_jsonl,
            )

            if args.args_jsonl:
                # Match arguments with verbose results.
                initialization.save_args(args, verbose=True, matched=True)

            if args.write_val_preds:
                # Optionally tag the predictions file with the experiment
                # name derived from the task-container config's basename.
                if args.extract_exp_name_valpreds:
                    exp_name = os.path.basename(
                        args.jiant_task_container_config_path).split(".")[0]
                    val_fname = f"val_preds_{exp_name}.p"
                else:
                    val_fname = "val_preds.p"
                jiant_evaluate.write_preds(
                    eval_results_dict=val_results_dict,
                    path=os.path.join(args.output_dir, val_fname),
                )
        else:
            # Predictions can only be written if validation actually ran.
            assert not args.write_val_preds

        if args.write_test_preds:
            test_results_dict = runner.run_test(
                task_name_list=runner.jiant_task_container.task_run_config.
                test_task_list, )
            jiant_evaluate.write_preds(
                eval_results_dict=test_results_dict,
                path=os.path.join(args.output_dir, "test_preds.p"),
            )

    # Clean up the intermediate checkpoint once the run finished; the
    # exists() guard prevents FileNotFoundError if it was never written.
    if (args.delete_checkpoint_if_done and args.save_checkpoint_every_steps
            and os.path.exists(checkpoint_path)):
        os.remove(checkpoint_path)