# NOTE: this section assumes the enclosing module(s) already import the names
# used below: os, torch, jiant_runner, jiant_evaluate, jiant_metarunner,
# adapters_metarunner, adapters_modeling, container_setup, initialization,
# py_io, tasks, squad_lib, bucc2018_lib, model_resolution, torch_utils,
# plus the helpers setup_runner, get_qa_language, and RunConfiguration.


def generate_and_write_preds_for_classification(
    runner: jiant_runner.JiantRunner,
    supertask: str,
    output_dir: str,
    skip_if_done: bool = False,
):
    """Write test predictions for classification tasks in XTREME submission format."""
    preds_pickle_path = os.path.join(output_dir, f"{supertask}_test_preds.p")
    if skip_if_done and os.path.exists(preds_pickle_path):
        print(f"Skipping because {preds_pickle_path} exists")
        return
    test_results_dict = runner.run_test(
        task_name_list=runner.jiant_task_container.task_run_config.test_task_list,
    )
    jiant_evaluate.write_preds(
        eval_results_dict=test_results_dict,
        path=preds_pickle_path,
    )
    preds_output_dir = os.path.join(output_dir, "preds", supertask)
    os.makedirs(preds_output_dir, exist_ok=True)
    for task_name, task_results in test_results_dict.items():
        task = runner.jiant_task_container.task_dict[task_name]
        assert isinstance(task, (tasks.XnliTask, tasks.PawsXTask))
        lang = task.language
        with open(os.path.join(preds_output_dir, f"test-{lang}.tsv"), "w") as f:
            for idx in task_results["preds"]:
                if supertask == "xnli":
                    # XNLI expects label strings; map prediction indices back to labels.
                    pred_label = task.ID_TO_LABEL[idx]
                elif supertask == "pawsx":
                    # PAWS-X expects the raw prediction index.
                    pred_label = idx
                else:
                    raise RuntimeError(f"Unsupported supertask: {supertask}")
                f.write(f"{pred_label}\n")
    print(f"Wrote {supertask} preds for {len(test_results_dict)} languages")


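# For reference, XNLI predictions above are written as label strings via
# task.ID_TO_LABEL, while PAWS-X predictions stay as raw indices. A minimal
# sketch of the mapping relied on, assuming jiant's standard XNLI label order
# (the authoritative dict lives on the XnliTask class, not here):
_XNLI_ID_TO_LABEL_SKETCH = {
    0: "contradiction",
    1: "entailment",
    2: "neutral",
}

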
def generate_and_write_preds_for_tatoeba(runner, output_dir: str, skip_if_done: bool = False):
    """Generate predictions (val) for Tatoeba and write them in XTREME submission format."""
    preds_pickle_path = os.path.join(output_dir, "tatoeba_val_preds.p")
    if skip_if_done and os.path.exists(preds_pickle_path):
        print(f"Skipping because {preds_pickle_path} exists")
        return
    val_results_dict = runner.run_val(
        task_name_list=runner.jiant_task_container.task_run_config.val_task_list,
        return_preds=True,
    )
    jiant_evaluate.write_preds(
        eval_results_dict=val_results_dict,
        path=preds_pickle_path,
    )
    preds_output_dir = os.path.join(output_dir, "preds", "tatoeba")
    os.makedirs(preds_output_dir, exist_ok=True)
    for task_name, task_results in val_results_dict.items():
        lang = runner.jiant_task_container.task_dict[task_name].language
        with open(os.path.join(preds_output_dir, f"test-{lang}.tsv"), "w") as f:
            for idx in task_results["preds"]:
                # Tatoeba predictions are retrieval indices, written as integers.
                f.write(f"{idx:d}\n")
    print(f"Wrote Tatoeba preds for {len(val_results_dict)} languages")


def generate_and_write_preds_for_bucc2018(
    runner,
    output_dir: str,
    bucc_val_metrics_path: str,
    skip_if_done: bool = False,
):
    """Generate predictions (test) for BUCC2018 and write them in XTREME submission format."""
    preds_pickle_path = os.path.join(output_dir, "bucc2018_test_preds.p")
    if skip_if_done and os.path.exists(preds_pickle_path):
        print(f"Skipping because {preds_pickle_path} exists")
        return
    else:
        print(f"{preds_pickle_path} does not exist")
    if bucc_val_metrics_path is None:
        # No cached val metrics: recompute the extraction thresholds from a val run.
        val_results_dict = runner.run_val(
            task_name_list=runner.jiant_task_container.task_run_config.val_task_list,
            return_preds=True,
        )
        jiant_evaluate.write_preds(
            eval_results_dict=val_results_dict,
            path=os.path.join(output_dir, "bucc2018_val_preds.p"),
        )
        thresholds_dict = {
            task_name: task_results["metrics"].minor["best-threshold"]
            for task_name, task_results in val_results_dict.items()
        }
    else:
        # Reuse thresholds tuned in a previous validation run.
        val_metrics = py_io.read_json(bucc_val_metrics_path)
        thresholds_dict = {
            task_name: val_metrics[task_name]["metrics"]["minor"]["best-threshold"]
            for task_name in runner.jiant_task_container.task_run_config.val_task_list
        }
    preds_output_dir = os.path.join(output_dir, "preds", "bucc2018")
    os.makedirs(preds_output_dir, exist_ok=True)
    test_results_dict = runner.run_test(
        task_name_list=runner.jiant_task_container.task_run_config.test_task_list,
    )
    jiant_evaluate.write_preds(
        eval_results_dict=test_results_dict,
        path=preds_pickle_path,
    )
    for task_name, task_results in test_results_dict.items():
        bitext = bucc2018_lib.bucc_extract(
            cand2score=task_results["preds"],
            th=thresholds_dict[task_name],
        )
        lang = runner.jiant_task_container.task_dict[task_name].language
        with open(os.path.join(preds_output_dir, f"test-{lang}.tsv"), "w") as f:
            for src, trg in bitext:
                f.write(f"{src}\t{trg}\n")
    print(f"Wrote BUCC2018 preds for {len(test_results_dict)} languages")


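# For intuition, bucc2018_lib.bucc_extract keeps candidate pairs whose score
# clears the val-tuned threshold. A simplified stand-in (NOT the library
# implementation), assuming cand2score maps (src_id, trg_id) pairs to scores,
# matching the (src, trg) unpacking in the writer loop above:
def _bucc_extract_sketch(cand2score, th):
    """Return the (src, trg) pairs whose score is at least th."""
    return [pair for pair, score in cand2score.items() if score >= th]

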
def generate_and_write_preds_for_qa(
    runner,
    supertask: str,
    output_dir: str,
    phase: str,
    skip_if_done: bool = False,
):
    """Generate predictions (test) for QA tasks and write them in XTREME submission format."""
    preds_pickle_path = os.path.join(output_dir, f"{supertask}_test_preds.p")
    if skip_if_done and os.path.exists(preds_pickle_path):
        print(f"Skipping because {preds_pickle_path} exists")
        return
    if phase == "val":
        task_name_list = runner.jiant_task_container.task_run_config.val_task_list
    elif phase == "test":
        task_name_list = runner.jiant_task_container.task_run_config.test_task_list
    else:
        raise KeyError(phase)
    # Keep only the tasks belonging to this QA supertask.
    task_name_list = [
        task_name for task_name in task_name_list if task_name.startswith(supertask)
    ]
    if phase == "val":
        test_results_dict = runner.run_val(task_name_list=task_name_list)
    elif phase == "test":
        test_results_dict = {}
        test_dataloader_dict = runner.get_test_dataloader_dict()
        for task_name in task_name_list:
            test_results_dict[task_name] = jiant_runner.run_test(
                test_dataloader=test_dataloader_dict[task_name],
                jiant_model=runner.jiant_model,
                task=runner.jiant_task_container.task_dict[task_name],
                device=runner.device,
                local_rank=runner.rparams.local_rank,
                return_preds=False,
                verbose=True,
            )
    else:
        raise KeyError(phase)

    # Generate QA preds from the accumulated logits
    tokenizer = runner.model.tokenizer
    for task_name in task_name_list:
        task_results = test_results_dict[task_name]
        task = runner.jiant_task_container.task_dict[task_name]
        logits = task_results["accumulator"].get_accumulated()
        lang = get_qa_language(supertask=supertask, task=task)
        if phase == "val":
            cached = runner.get_val_dataloader_dict(
                [task_name]
            )[task_name].dataset.chunked_file_data_cache.get_all()
        elif phase == "test":
            cached = runner.get_test_dataloader_dict()[
                task_name
            ].dataset.chunked_file_data_cache.get_all()
        else:
            raise KeyError(phase)
        data_rows = [row["data_row"] for row in cached]
        results, predictions = squad_lib.compute_predictions_logits_v3(
            data_rows=data_rows,
            logits=logits,
            n_best_size=task.n_best_size,
            max_answer_length=task.max_answer_length,
            do_lower_case=model_resolution.resolve_is_lower_case(tokenizer),
            version_2_with_negative=task.version_2_with_negative,
            null_score_diff_threshold=task.null_score_diff_threshold,
            skip_get_final_text=(lang == "zh"),  # skip whitespace-based final-text cleanup for Chinese
            tokenizer=tokenizer,
        )
        test_results_dict[task_name]["preds"] = predictions
    jiant_evaluate.write_preds(
        eval_results_dict=test_results_dict,
        path=preds_pickle_path,
    )
    preds_output_dir = os.path.join(output_dir, "preds", supertask)
    os.makedirs(preds_output_dir, exist_ok=True)
    for task_name, task_results in test_results_dict.items():
        task = runner.jiant_task_container.task_dict[task_name]
        lang = get_qa_language(supertask=supertask, task=task)
        py_io.write_json(
            task_results["preds"],
            os.path.join(preds_output_dir, f"test-{lang}.json"),
        )
    print(f"Wrote {supertask} preds for {len(test_results_dict)} languages")


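# A convenience dispatcher (a sketch, not part of the original script): routes
# a supertask name to the matching writer above. It assumes the XTREME QA
# supertasks are named "xquad", "mlqa", and "tydiqa", that QA predictions are
# generated for the test phase, and that BUCC2018 thresholds are recomputed
# (bucc_val_metrics_path=None).
def generate_and_write_preds(runner, supertask: str, output_dir: str, skip_if_done: bool = False):
    if supertask in ("xnli", "pawsx"):
        generate_and_write_preds_for_classification(
            runner=runner, supertask=supertask, output_dir=output_dir, skip_if_done=skip_if_done,
        )
    elif supertask == "tatoeba":
        generate_and_write_preds_for_tatoeba(
            runner=runner, output_dir=output_dir, skip_if_done=skip_if_done,
        )
    elif supertask == "bucc2018":
        generate_and_write_preds_for_bucc2018(
            runner=runner, output_dir=output_dir,
            bucc_val_metrics_path=None, skip_if_done=skip_if_done,
        )
    elif supertask in ("xquad", "mlqa", "tydiqa"):
        generate_and_write_preds_for_qa(
            runner=runner, supertask=supertask, output_dir=output_dir,
            phase="test", skip_if_done=skip_if_done,
        )
    else:
        raise KeyError(supertask)

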
def run_loop(args: RunConfiguration, checkpoint=None, tasks=None):
    # Note: `tasks` here is the perturbation task list, not the jiant tasks module.
    is_resumed = checkpoint is not None
    quick_init_out = initialization.quick_init(args=args, verbose=True)
    print(quick_init_out.n_gpu)
    with quick_init_out.log_writer.log_context():
        jiant_task_container = container_setup.create_jiant_task_container_from_json(
            jiant_task_container_config_path=args.jiant_task_container_config_path,
            verbose=True,
        )
        runner = setup_runner(
            args=args,
            jiant_task_container=jiant_task_container,
            quick_init_out=quick_init_out,
            verbose=True,
        )
        if is_resumed:
            runner.load_state(checkpoint["runner_state"])
            del checkpoint["runner_state"]
        checkpoint_saver = jiant_runner.CheckpointSaver(
            metadata={"args": args.to_dict()},
            save_path=os.path.join(args.output_dir, "checkpoint.p"),
        )
        if args.do_val:
            # Validation pass before training, for a pre-training baseline.
            print("Running pre-training validation")
            val_results_dict = runner.run_val(
                task_name_list=runner.jiant_task_container.task_run_config.val_task_list,
                return_preds=args.write_val_preds,
            )
            jiant_evaluate.write_val_results(
                val_results_dict=val_results_dict,
                metrics_aggregator=runner.jiant_task_container.metrics_aggregator,
                output_dir=args.output_dir,
                verbose=True,
            )
            if args.write_val_preds:
                jiant_evaluate.write_preds(
                    eval_results_dict=val_results_dict,
                    path=os.path.join(args.output_dir, "val_preds.p"),
                )
        if args.do_train:
            metarunner = jiant_metarunner.JiantMetarunner(
                runner=runner,
                save_every_steps=args.save_every_steps,
                eval_every_steps=args.eval_every_steps,
                save_checkpoint_every_steps=args.save_checkpoint_every_steps,
                no_improvements_for_n_evals=args.no_improvements_for_n_evals,
                checkpoint_saver=checkpoint_saver,
                output_dir=args.output_dir,
                verbose=True,
                save_best_model=True,
                save_last_model=args.do_save or args.do_save_last,
                load_best_model=True,
                log_writer=quick_init_out.log_writer,
            )
            if is_resumed:
                metarunner.load_state(checkpoint["metarunner_state"])
                del checkpoint["metarunner_state"]
            metarunner.run_train_loop()
            runner.run_perturb(tasks, metarunner)
        if args.do_val:
            val_results_dict = runner.run_val(
                task_name_list=runner.jiant_task_container.task_run_config.val_task_list,
                return_preds=args.write_val_preds,
            )
            jiant_evaluate.write_val_results(
                val_results_dict=val_results_dict,
                metrics_aggregator=runner.jiant_task_container.metrics_aggregator,
                output_dir=args.output_dir,
                verbose=True,
            )
            if args.write_val_preds:
                jiant_evaluate.write_preds(
                    eval_results_dict=val_results_dict,
                    path=os.path.join(args.output_dir, "val_preds.p"),
                )
        else:
            assert not args.write_val_preds
        if args.write_test_preds:
            test_results_dict = runner.run_test(
                task_name_list=runner.jiant_task_container.task_run_config.test_task_list,
            )
            jiant_evaluate.write_preds(
                eval_results_dict=test_results_dict,
                path=os.path.join(args.output_dir, "test_preds.p"),
            )
        print("Evaluation done")
    if (
        not args.keep_checkpoint_when_done
        and args.save_checkpoint_every_steps
        and os.path.exists(os.path.join(args.output_dir, "checkpoint.p"))
    ):
        os.remove(os.path.join(args.output_dir, "checkpoint.p"))
    py_io.write_file("DONE", os.path.join(args.output_dir, "done_file"))


def run_loop(args: RunConfiguration, checkpoint=None):
    is_resumed = checkpoint is not None
    quick_init_out = initialization.quick_init(args=args, verbose=True)
    print(quick_init_out.n_gpu)
    with quick_init_out.log_writer.log_context():
        jiant_task_container = container_setup.create_jiant_task_container_from_json(
            jiant_task_container_config_path=args.jiant_task_container_config_path,
            verbose=True,
        )
        runner = setup_runner(
            args=args,
            jiant_task_container=jiant_task_container,
            quick_init_out=quick_init_out,
            verbose=True,
        )
        if is_resumed:
            runner.load_state(checkpoint["runner_state"])
            del checkpoint["runner_state"]
        # Allow a custom checkpoint file name (falls back to "checkpoint.p").
        # Build the full path once so it is not joined with output_dir twice.
        if args.custom_checkpoint_name:
            checkpoint_path = os.path.join(args.output_dir, f"{args.custom_checkpoint_name}.p")
        else:
            checkpoint_path = os.path.join(args.output_dir, "checkpoint.p")
        checkpoint_saver = jiant_runner.CheckpointSaver(
            metadata={"args": args.to_dict()},
            save_path=checkpoint_path,
        )
        if args.do_train:
            metarunner = jiant_metarunner.JiantMetarunner(
                runner=runner,
                save_every_steps=args.save_every_steps,
                eval_every_steps=args.eval_every_steps,
                save_checkpoint_every_steps=args.save_checkpoint_every_steps,
                no_improvements_for_n_evals=args.no_improvements_for_n_evals,
                checkpoint_saver=checkpoint_saver,
                output_dir=args.output_dir,
                verbose=True,
                save_best_model=args.do_save,
                load_best_model=True,
                log_writer=quick_init_out.log_writer,
            )
            if is_resumed:
                metarunner.load_state(checkpoint["metarunner_state"])
                del checkpoint["metarunner_state"]
            metarunner.run_train_loop()
        if args.do_save:
            # Allow a custom best-model file name (falls back to "model.p").
            if args.custom_best_name:
                best_model_path = os.path.join(args.output_dir, f"{args.custom_best_name}.p")
            else:
                best_model_path = os.path.join(args.output_dir, "model.p")
            torch.save(
                torch_utils.get_model_for_saving(runner.jiant_model).state_dict(),
                best_model_path,
            )
        if args.do_val:
            val_results_dict = runner.run_val(
                task_name_list=runner.jiant_task_container.task_run_config.val_task_list,
                return_preds=args.write_val_preds,
            )
            jiant_evaluate.write_val_results(
                val_results_dict=val_results_dict,
                metrics_aggregator=runner.jiant_task_container.metrics_aggregator,
                output_dir=args.output_dir,
                verbose=True,
                val_jsonl=args.val_jsonl,
            )
            if args.args_jsonl:
                # Save the run arguments so they can be matched with the verbose results.
                initialization.save_args(args, verbose=True, matched=True)
            if args.write_val_preds:
                if args.extract_exp_name_valpreds:
                    # Derive the experiment name from the task-container config file name.
                    exp_name = os.path.basename(args.jiant_task_container_config_path).split(".")[0]
                    val_fname = f"val_preds_{exp_name}.p"
                else:
                    val_fname = "val_preds.p"
                jiant_evaluate.write_preds(
                    eval_results_dict=val_results_dict,
                    path=os.path.join(args.output_dir, val_fname),
                )
        else:
            assert not args.write_val_preds
        if args.write_test_preds:
            test_results_dict = runner.run_test(
                task_name_list=runner.jiant_task_container.task_run_config.test_task_list,
            )
            jiant_evaluate.write_preds(
                eval_results_dict=test_results_dict,
                path=os.path.join(args.output_dir, "test_preds.p"),
            )
    if args.delete_checkpoint_if_done and args.save_checkpoint_every_steps:
        os.remove(checkpoint_path)


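# The custom-name fallback above can be factored out; a small sketch (not in
# the original script) that mirrors the rule for both artifacts:
def _resolve_artifact_path(output_dir: str, custom_stem, default_stem: str) -> str:
    """Return <output_dir>/<custom_stem>.p when custom_stem is set, else <output_dir>/<default_stem>.p."""
    return os.path.join(output_dir, f"{custom_stem or default_stem}.p")

# e.g. _resolve_artifact_path(args.output_dir, args.custom_checkpoint_name, "checkpoint")

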
def run_loop(args: RunConfiguration, checkpoint=None):
    is_resumed = checkpoint is not None
    quick_init_out = initialization.quick_init(args=args, verbose=True)
    print(quick_init_out.n_gpu)
    with quick_init_out.log_writer.log_context():
        jiant_task_container = container_setup.create_jiant_task_container_from_json(
            jiant_task_container_config_path=args.jiant_task_container_config_path,
            verbose=True,
        )
        runner = setup_runner(
            args=args,
            jiant_task_container=jiant_task_container,
            quick_init_out=quick_init_out,
            verbose=True,
        )
        if is_resumed:
            runner.load_state(checkpoint["runner_state"])
            del checkpoint["runner_state"]
        checkpoint_saver = jiant_runner.CheckpointSaver(
            metadata={"args": args.to_dict()},
            save_path=os.path.join(args.output_dir, "checkpoint.p"),
        )
        if args.do_train:
            metarunner = adapters_metarunner.AdaptersMetarunner(
                runner=runner,
                save_every_steps=args.save_every_steps,
                eval_every_steps=args.eval_every_steps,
                save_checkpoint_every_steps=args.save_checkpoint_every_steps,
                no_improvements_for_n_evals=args.no_improvements_for_n_evals,
                checkpoint_saver=checkpoint_saver,
                output_dir=args.output_dir,
                verbose=True,
                save_best_model=args.do_save,
                load_best_model=True,
                log_writer=quick_init_out.log_writer,
            )
            if is_resumed:
                metarunner.load_state(checkpoint["metarunner_state"])
                del checkpoint["metarunner_state"]
            metarunner.run_train_loop()
        if args.do_save:
            torch.save(
                adapters_modeling.get_optimized_state_dict_for_jiant_model_with_adapters(
                    torch_utils.get_model_for_saving(runner.jiant_model),
                ),
                os.path.join(args.output_dir, "model.p"),
            )
        if args.do_val:
            val_results_dict = runner.run_val(
                task_name_list=runner.jiant_task_container.task_run_config.val_task_list,
                return_preds=args.write_val_preds,
            )
            jiant_evaluate.write_val_results(
                val_results_dict=val_results_dict,
                metrics_aggregator=runner.jiant_task_container.metrics_aggregator,
                output_dir=args.output_dir,
                verbose=True,
            )
            if args.write_val_preds:
                jiant_evaluate.write_preds(
                    eval_results_dict=val_results_dict,
                    path=os.path.join(args.output_dir, "val_preds.p"),
                )
        else:
            assert not args.write_val_preds
        if args.write_test_preds:
            test_results_dict = runner.run_test(
                task_name_list=runner.jiant_task_container.task_run_config.test_task_list,
            )
            jiant_evaluate.write_preds(
                eval_results_dict=test_results_dict,
                path=os.path.join(args.output_dir, "test_preds.p"),
            )
    if args.delete_checkpoint_if_done and args.save_checkpoint_every_steps:
        os.remove(os.path.join(args.output_dir, "checkpoint.p"))