def run_loop(args: RunConfiguration, checkpoint=None):
    """Run the full train/val/test loop described by ``args``.

    Args:
        args: RunConfiguration holding all paths and do_* flags.
        checkpoint: optional dict with "runner_state"/"metarunner_state"
            entries; when provided the run resumes from that state.

    Side effects: writes the checkpoint, best model, val/test predictions
    and val results under ``args.output_dir``.
    """
    is_resumed = checkpoint is not None
    quick_init_out = initialization.quick_init(args=args, verbose=True)
    print(quick_init_out.n_gpu)
    with quick_init_out.log_writer.log_context():
        jiant_task_container = container_setup.create_jiant_task_container_from_json(
            jiant_task_container_config_path=args.jiant_task_container_config_path,
            verbose=True,
        )
        runner = setup_runner(
            args=args,
            jiant_task_container=jiant_task_container,
            quick_init_out=quick_init_out,
            verbose=True,
        )
        if is_resumed:
            runner.load_state(checkpoint["runner_state"])
            # Free the (potentially large) state dict once consumed.
            del checkpoint["runner_state"]
        # Allow a custom checkpoint name. NOTE: checkpoint_name is a bare
        # filename; it is joined with output_dir exactly once below.
        # (Previously it was pre-joined with output_dir and then joined
        # again, producing e.g. "out/out/checkpoint.p" for relative dirs.)
        if args.custom_checkpoint_name:
            checkpoint_name = f"{args.custom_checkpoint_name}.p"
        else:
            checkpoint_name = "checkpoint.p"
        checkpoint_path = os.path.join(args.output_dir, checkpoint_name)
        checkpoint_saver = jiant_runner.CheckpointSaver(
            metadata={"args": args.to_dict()},
            save_path=checkpoint_path,
        )
        if args.do_train:
            metarunner = jiant_metarunner.JiantMetarunner(
                runner=runner,
                save_every_steps=args.save_every_steps,
                eval_every_steps=args.eval_every_steps,
                save_checkpoint_every_steps=args.save_checkpoint_every_steps,
                no_improvements_for_n_evals=args.no_improvements_for_n_evals,
                checkpoint_saver=checkpoint_saver,
                output_dir=args.output_dir,
                verbose=True,
                save_best_model=args.do_save,
                load_best_model=True,
                log_writer=quick_init_out.log_writer,
            )
            if is_resumed:
                metarunner.load_state(checkpoint["metarunner_state"])
                del checkpoint["metarunner_state"]
            metarunner.run_train_loop()
        if args.do_save:
            # Allow a custom best-model name.
            if args.custom_best_name:
                best_model_name = os.path.join(args.output_dir, f"{args.custom_best_name}.p")
            else:
                best_model_name = os.path.join(args.output_dir, "model.p")
            torch.save(
                torch_utils.get_model_for_saving(runner.jiant_model).state_dict(),
                best_model_name,
            )
        if args.do_val:
            val_results_dict = runner.run_val(
                task_name_list=runner.jiant_task_container.task_run_config.val_task_list,
                return_preds=args.write_val_preds,
            )
            jiant_evaluate.write_val_results(
                val_results_dict=val_results_dict,
                metrics_aggregator=runner.jiant_task_container.metrics_aggregator,
                output_dir=args.output_dir,
                verbose=True,
                val_jsonl=args.val_jsonl,
            )
            if args.args_jsonl:
                # Match arguments with verbose results.
                initialization.save_args(args, verbose=True, matched=True)
            if args.write_val_preds:
                if args.extract_exp_name_valpreds:
                    # Derive the experiment name from the config filename
                    # (basename without extension).
                    exp_name = os.path.basename(args.jiant_task_container_config_path).split(".")[0]
                    val_fname = f"val_preds_{exp_name}.p"
                else:
                    val_fname = "val_preds.p"
                jiant_evaluate.write_preds(
                    eval_results_dict=val_results_dict,
                    path=os.path.join(args.output_dir, val_fname),
                )
        else:
            # Predictions can only be written if validation actually ran.
            assert not args.write_val_preds
        if args.write_test_preds:
            test_results_dict = runner.run_test(
                task_name_list=runner.jiant_task_container.task_run_config.test_task_list,
            )
            jiant_evaluate.write_preds(
                eval_results_dict=test_results_dict,
                path=os.path.join(args.output_dir, "test_preds.p"),
            )
        # Guard with exists(): the checkpoint may never have been written
        # (e.g. the run finished before the first checkpoint interval).
        if (
            args.delete_checkpoint_if_done
            and args.save_checkpoint_every_steps
            and os.path.exists(checkpoint_path)
        ):
            os.remove(checkpoint_path)
def run_loop(args: RunConfiguration, checkpoint=None, tasks=None):
    """Run the train/val/test loop with an extra perturbation pass.

    Args:
        args: RunConfiguration holding all paths and do_* flags.
        checkpoint: optional dict with "runner_state"/"metarunner_state"
            entries; when provided the run resumes from that state.
        tasks: forwarded to ``runner.run_perturb`` after training.

    Side effects: evaluates before and after training, writes predictions
    and results under ``args.output_dir``, and drops a "done_file" marker.
    """
    is_resumed = checkpoint is not None
    quick_init_out = initialization.quick_init(args=args, verbose=True)
    print(quick_init_out.n_gpu)
    with quick_init_out.log_writer.log_context():
        jiant_task_container = container_setup.create_jiant_task_container_from_json(
            jiant_task_container_config_path=args.jiant_task_container_config_path,
            verbose=True,
        )
        runner = setup_runner(
            args=args,
            jiant_task_container=jiant_task_container,
            quick_init_out=quick_init_out,
            verbose=True,
        )
        if is_resumed:
            runner.load_state(checkpoint["runner_state"])
            # Free the (potentially large) state dict once consumed.
            del checkpoint["runner_state"]
        # Single source of truth for the checkpoint location (used by the
        # saver and by the cleanup at the end).
        checkpoint_path = os.path.join(args.output_dir, "checkpoint.p")
        checkpoint_saver = jiant_runner.CheckpointSaver(
            metadata={"args": args.to_dict()},
            save_path=checkpoint_path,
        )
        if args.do_val:
            # Pre-training evaluation baseline.
            print("EVAL_BEFORE________________________________")
            val_results_dict = runner.run_val(
                task_name_list=runner.jiant_task_container.task_run_config.val_task_list,
                return_preds=args.write_val_preds,
            )
            jiant_evaluate.write_val_results(
                val_results_dict=val_results_dict,
                metrics_aggregator=runner.jiant_task_container.metrics_aggregator,
                output_dir=args.output_dir,
                verbose=True,
            )
            if args.write_val_preds:
                jiant_evaluate.write_preds(
                    eval_results_dict=val_results_dict,
                    path=os.path.join(args.output_dir, "val_preds.p"),
                )
        if args.do_train:
            metarunner = jiant_metarunner.JiantMetarunner(
                runner=runner,
                save_every_steps=args.save_every_steps,
                eval_every_steps=args.eval_every_steps,
                save_checkpoint_every_steps=args.save_checkpoint_every_steps,
                no_improvements_for_n_evals=args.no_improvements_for_n_evals,
                checkpoint_saver=checkpoint_saver,
                output_dir=args.output_dir,
                verbose=True,
                save_best_model=True,
                save_last_model=args.do_save or args.do_save_last,
                load_best_model=True,
                log_writer=quick_init_out.log_writer,
            )
            if is_resumed:
                metarunner.load_state(checkpoint["metarunner_state"])
                del checkpoint["metarunner_state"]
            metarunner.run_train_loop()
            # Must stay inside the do_train branch: run_perturb needs the
            # metarunner, which only exists when training ran.
            runner.run_perturb(tasks, metarunner)
        if args.do_val:
            # Post-training evaluation (overwrites the pre-training results).
            val_results_dict = runner.run_val(
                task_name_list=runner.jiant_task_container.task_run_config.val_task_list,
                return_preds=args.write_val_preds,
            )
            jiant_evaluate.write_val_results(
                val_results_dict=val_results_dict,
                metrics_aggregator=runner.jiant_task_container.metrics_aggregator,
                output_dir=args.output_dir,
                verbose=True,
            )
            if args.write_val_preds:
                jiant_evaluate.write_preds(
                    eval_results_dict=val_results_dict,
                    path=os.path.join(args.output_dir, "val_preds.p"),
                )
        else:
            # Predictions can only be written if validation actually ran.
            assert not args.write_val_preds
        if args.write_test_preds:
            test_results_dict = runner.run_test(
                task_name_list=runner.jiant_task_container.task_run_config.test_task_list,
            )
            jiant_evaluate.write_preds(
                eval_results_dict=test_results_dict,
                path=os.path.join(args.output_dir, "test_preds.p"),
            )
        print("EVAL DONE------------------------")
        # exists() guard: the checkpoint may never have been written.
        if (
            not args.keep_checkpoint_when_done
            and args.save_checkpoint_every_steps
            and os.path.exists(checkpoint_path)
        ):
            os.remove(checkpoint_path)
        # Marker file so external schedulers can detect completion.
        py_io.write_file("DONE", os.path.join(args.output_dir, "done_file"))
def run_loop(args: RunConfiguration, checkpoint=None):
    """Run the train/val/test loop for a jiant model with adapters.

    Args:
        args: RunConfiguration holding all paths and do_* flags.
        checkpoint: optional dict with "runner_state"/"metarunner_state"
            entries; when provided the run resumes from that state.

    Side effects: writes the checkpoint, the optimized adapter state dict
    ("model.p") and val/test predictions under ``args.output_dir``.
    """
    is_resumed = checkpoint is not None
    quick_init_out = initialization.quick_init(args=args, verbose=True)
    print(quick_init_out.n_gpu)
    with quick_init_out.log_writer.log_context():
        jiant_task_container = container_setup.create_jiant_task_container_from_json(
            jiant_task_container_config_path=args.jiant_task_container_config_path,
            verbose=True,
        )
        runner = setup_runner(
            args=args,
            jiant_task_container=jiant_task_container,
            quick_init_out=quick_init_out,
            verbose=True,
        )
        if is_resumed:
            runner.load_state(checkpoint["runner_state"])
            # Free the (potentially large) state dict once consumed.
            del checkpoint["runner_state"]
        # Single source of truth for the checkpoint location (used by the
        # saver and by the cleanup at the end).
        checkpoint_path = os.path.join(args.output_dir, "checkpoint.p")
        checkpoint_saver = jiant_runner.CheckpointSaver(
            metadata={"args": args.to_dict()},
            save_path=checkpoint_path,
        )
        if args.do_train:
            metarunner = adapters_metarunner.AdaptersMetarunner(
                runner=runner,
                save_every_steps=args.save_every_steps,
                eval_every_steps=args.eval_every_steps,
                save_checkpoint_every_steps=args.save_checkpoint_every_steps,
                no_improvements_for_n_evals=args.no_improvements_for_n_evals,
                checkpoint_saver=checkpoint_saver,
                output_dir=args.output_dir,
                verbose=True,
                save_best_model=args.do_save,
                load_best_model=True,
                log_writer=quick_init_out.log_writer,
            )
            if is_resumed:
                metarunner.load_state(checkpoint["metarunner_state"])
                del checkpoint["metarunner_state"]
            metarunner.run_train_loop()
        if args.do_save:
            # Save only the adapter-optimized subset of the state dict.
            torch.save(
                adapters_modeling.get_optimized_state_dict_for_jiant_model_with_adapters(
                    torch_utils.get_model_for_saving(runner.jiant_model),
                ),
                os.path.join(args.output_dir, "model.p"),
            )
        if args.do_val:
            val_results_dict = runner.run_val(
                task_name_list=runner.jiant_task_container.task_run_config.val_task_list,
                return_preds=args.write_val_preds,
            )
            jiant_evaluate.write_val_results(
                val_results_dict=val_results_dict,
                metrics_aggregator=runner.jiant_task_container.metrics_aggregator,
                output_dir=args.output_dir,
                verbose=True,
            )
            if args.write_val_preds:
                jiant_evaluate.write_preds(
                    eval_results_dict=val_results_dict,
                    path=os.path.join(args.output_dir, "val_preds.p"),
                )
        else:
            # Predictions can only be written if validation actually ran.
            assert not args.write_val_preds
        if args.write_test_preds:
            test_results_dict = runner.run_test(
                task_name_list=runner.jiant_task_container.task_run_config.test_task_list,
            )
            jiant_evaluate.write_preds(
                eval_results_dict=test_results_dict,
                path=os.path.join(args.output_dir, "test_preds.p"),
            )
        # exists() guard (matching the other run_loop variants' cleanup):
        # avoids FileNotFoundError when the run finished before the first
        # checkpoint interval or the checkpoint was already removed.
        if (
            args.delete_checkpoint_if_done
            and args.save_checkpoint_every_steps
            and os.path.exists(checkpoint_path)
        ):
            os.remove(checkpoint_path)