def generate_and_write_preds_for_tagging(
    runner: jiant_runner.JiantRunner,
    supertask: str,
    output_dir: str,
    skip_if_done: bool = False,
):
    """Generate and write test predictions for tagging tasks in XTREME submission format

    For every task in the runner's test task list (each is asserted to be a
    ``UdposTask`` or ``PanxTask``), predictions are written to
    ``<output_dir>/preds/<supertask>/test-<lang>.tsv``. Each file has one
    predicted label per line and a blank line after each example. The full set
    of predictions is also saved with ``torch.save`` to
    ``<output_dir>/<supertask>_test_preds.p``, as a dict keyed by task name.

    Args:
        runner: runner supplying the test dataloaders and the task container.
        supertask: name used for the output subdirectory and the pickle filename.
        output_dir: root directory for all outputs.
        skip_if_done: when True and the pickle file already exists, print a
            message and return without doing anything else.
    """
    pickle_path = os.path.join(output_dir, f"{supertask}_test_preds.p")
    if skip_if_done and os.path.exists(pickle_path):
        print(f"Skipping cause {pickle_path} exists")
        return

    dataloaders_by_task = runner.get_test_dataloader_dict()
    lang_dir = os.path.join(output_dir, "preds", supertask)
    os.makedirs(lang_dir, exist_ok=True)

    container = runner.jiant_task_container
    task_names = container.task_run_config.test_task_list
    all_preds = {}
    for name in task_names:
        tagging_task = container.task_dict[name]
        assert isinstance(tagging_task, (tasks.UdposTask, tasks.PanxTask))
        example_pred_list = get_preds_for_single_tagging_task(
            task=tagging_task,
            test_dataloader=dataloaders_by_task[name],
            runner=runner,
        )
        all_preds[name] = example_pred_list

        tsv_path = os.path.join(lang_dir, f"test-{tagging_task.language}.tsv")
        with open(tsv_path, "w") as out_file:
            for sentence_preds in example_pred_list:
                # Entries are (word, label) pairs; only the label is emitted.
                out_file.writelines(f"{tag}\n" for _, tag in sentence_preds)
                # A blank line separates consecutive examples.
                out_file.write("\n")

    torch.save(all_preds, pickle_path)
    print(f"Wrote {supertask} preds for {len(task_names)} languages")
def generate_and_write_preds_for_classification(
    runner: jiant_runner.JiantRunner,
    supertask: str,
    output_dir: str,
    skip_if_done: bool = False,
):
    """Write test predictions for classification tasks in XTREME submission format

    The function works in three steps:

    1. Run ``runner.run_test`` over the runner's test task list.
    2. Save the raw results with ``jiant_evaluate.write_preds`` to
       ``<output_dir>/<supertask>_test_preds.p``.
    3. Write one ``test-<lang>.tsv`` per task under
       ``<output_dir>/preds/<supertask>/``, with one predicted label per line.

    For ``supertask == "xnli"`` each predicted index is mapped through
    ``task.ID_TO_LABEL``. For ``"pawsx"`` the prediction is written as-is.

    Args:
        runner: runner used to run the test tasks and look up task objects.
        supertask: either ``"xnli"`` or ``"pawsx"``; also names the output
            subdirectory and the pickle filename.
        output_dir: root directory for all outputs.
        skip_if_done: when True and the pickle file already exists, print a
            message and return without running anything.

    Raises:
        RuntimeError: if ``supertask`` is not ``"xnli"`` or ``"pawsx"`` and a
            prediction has to be written.
        AssertionError: if a test task is not an ``XnliTask`` or ``PawsXTask``.
    """
    preds_pickle_path = os.path.join(output_dir, f"{supertask}_test_preds.p")
    if skip_if_done and os.path.exists(preds_pickle_path):
        print(f"Skipping cause {preds_pickle_path} exists")
        return
    test_results_dict = runner.run_test(
        task_name_list=runner.jiant_task_container.task_run_config.test_task_list,
    )
    jiant_evaluate.write_preds(
        eval_results_dict=test_results_dict,
        path=preds_pickle_path,
    )
    preds_output_dir = os.path.join(output_dir, "preds", supertask)
    os.makedirs(preds_output_dir, exist_ok=True)
    for task_name, task_results in test_results_dict.items():
        task = runner.jiant_task_container.task_dict[task_name]
        assert isinstance(task, (tasks.XnliTask, tasks.PawsXTask))
        lang = task.language
        # Explicit encoding so the submission files do not depend on the
        # platform's default locale encoding.
        tsv_path = os.path.join(preds_output_dir, f"test-{lang}.tsv")
        with open(tsv_path, "w", encoding="utf-8") as f:
            for idx in task_results["preds"]:
                if supertask == "xnli":
                    # XNLI predictions are class indices; emit the label name.
                    pred_label = task.ID_TO_LABEL[idx]
                elif supertask == "pawsx":
                    pred_label = idx
                else:
                    raise RuntimeError(
                        f"Unsupported classification supertask {supertask!r}; "
                        f"expected 'xnli' or 'pawsx'"
                    )
                f.write(f"{pred_label}\n")
    print(f"Wrote {supertask} preds for {len(test_results_dict)} languages")