예제 #1
0
def generate_and_write_preds_for_tagging(runner: jiant_runner.JiantRunner,
                                         supertask: str,
                                         output_dir: str,
                                         skip_if_done: bool = False):
    """Predict on the tagging test sets and dump labels as XTREME submission files.

    For every task in the runner's test task list, the predicted labels are
    written to ``<output_dir>/preds/<supertask>/test-<lang>.tsv``: one label
    per line, with an empty line closing each example. The predictions of all
    tasks are additionally saved with ``torch.save`` to
    ``<output_dir>/<supertask>_test_preds.p``.

    Args:
        runner: Runner supplying the test dataloaders and the task container.
        supertask: Name used for the pickle file and the predictions sub-directory.
        output_dir: Directory under which all outputs are created.
        skip_if_done: When True, return immediately if the pickle already exists.
    """
    pickle_path = os.path.join(output_dir, f"{supertask}_test_preds.p")
    if skip_if_done:
        if os.path.exists(pickle_path):
            print(f"Skipping cause {pickle_path} exists")
            return

    dataloaders = runner.get_test_dataloader_dict()
    tsv_dir = os.path.join(output_dir, "preds", supertask)
    os.makedirs(tsv_dir, exist_ok=True)

    container = runner.jiant_task_container
    test_task_names = container.task_run_config.test_task_list

    all_preds = {}
    for name in test_task_names:
        tagging_task = container.task_dict[name]
        assert isinstance(tagging_task, (tasks.UdposTask, tasks.PanxTask))
        task_preds = get_preds_for_single_tagging_task(
            task=tagging_task,
            test_dataloader=dataloaders[name],
            runner=runner,
        )
        all_preds[name] = task_preds

        tsv_path = os.path.join(tsv_dir, f"test-{tagging_task.language}.tsv")
        with open(tsv_path, "w") as out:
            for tagged_example in task_preds:
                # Each entry is a (word, label) pair; only the label is emitted.
                out.writelines(f"{label}\n" for _, label in tagged_example)
                out.write("\n")

    # Saved last, so the pickle only exists once every TSV has been written.
    torch.save(all_preds, pickle_path)
    num_languages = len(test_task_names)
    print(f"Wrote {supertask} preds for {num_languages} languages")
예제 #2
0
def generate_and_write_preds_for_classification(
        runner: jiant_runner.JiantRunner,
        supertask: str,
        output_dir: str,
        skip_if_done: bool = False):
    """Write test predictions for classification tasks in XTREME submission format.

    Runs the runner's test task list, writes one
    ``<output_dir>/preds/<supertask>/test-<lang>.tsv`` file per task (one
    prediction per line), and finally stores the raw predictions via
    ``jiant_evaluate.write_preds`` at ``<output_dir>/<supertask>_test_preds.p``.

    For ``"xnli"`` each predicted index is mapped through the task's
    ``ID_TO_LABEL``; for ``"pawsx"`` the predicted index is written as-is.

    Args:
        runner: Runner used to run the test tasks and to look up task objects.
        supertask: Either ``"xnli"`` or ``"pawsx"``; also names the output files.
        output_dir: Directory under which all outputs are created.
        skip_if_done: When True, return immediately if the pickle already exists.

    Raises:
        RuntimeError: If ``supertask`` is neither ``"xnli"`` nor ``"pawsx"``.
    """
    preds_pickle_path = os.path.join(output_dir, f"{supertask}_test_preds.p")
    if skip_if_done and os.path.exists(preds_pickle_path):
        print(f"Skipping cause {preds_pickle_path} exists")
        return

    # Validate before running inference so an unsupported supertask fails
    # immediately instead of after the (expensive) test run.
    if supertask == "xnli":
        map_idx_to_label = True
    elif supertask == "pawsx":
        map_idx_to_label = False
    else:
        raise RuntimeError(
            f"Unsupported classification supertask {supertask!r}; "
            "expected 'xnli' or 'pawsx'"
        )

    test_results_dict = runner.run_test(
        task_name_list=runner.jiant_task_container.task_run_config.test_task_list,
    )
    preds_output_dir = os.path.join(output_dir, "preds", supertask)
    os.makedirs(preds_output_dir, exist_ok=True)
    for task_name, task_results in test_results_dict.items():
        task = runner.jiant_task_container.task_dict[task_name]
        assert isinstance(task, (tasks.XnliTask, tasks.PawsXTask))
        lang = task.language
        tsv_path = os.path.join(preds_output_dir, f"test-{lang}.tsv")
        with open(tsv_path, "w", encoding="utf-8") as f:
            for idx in task_results["preds"]:
                pred_label = task.ID_TO_LABEL[idx] if map_idx_to_label else idx
                f.write(f"{pred_label}\n")

    # The pickle doubles as the "done" marker checked by skip_if_done, so it is
    # written only after every TSV file has been produced; a failure above can
    # no longer leave a marker that hides an incomplete export.
    jiant_evaluate.write_preds(
        eval_results_dict=test_results_dict,
        path=preds_pickle_path,
    )
    print(f"Wrote {supertask} preds for {len(test_results_dict)} languages")