Code Example #1
File: core.py  Project: nyu-mll/CNLI-generalization
def compute_metrics_from_accumulator(
    self, task, accumulator: BaseAccumulator, tokenizer, labels
) -> Metrics:
    # Collect the logits accumulated across evaluation batches.
    logits = accumulator.get_accumulated()
    # Decode the logits into text predictions and SQuAD-style scores (EM / F1).
    results, predictions = squad_style.compute_predictions_logits_v3(
        data_rows=labels,
        logits=logits,
        n_best_size=task.n_best_size,
        max_answer_length=task.max_answer_length,
        do_lower_case=model_resolution.resolve_is_lower_case(tokenizer),
        version_2_with_negative=task.version_2_with_negative,
        null_score_diff_threshold=task.null_score_diff_threshold,
        tokenizer=tokenizer,
    )
    # The headline ("major") metric is the mean of F1 and exact match.
    return Metrics(major=(results["f1"] + results["exact"]) / 2, minor=results)
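As a quick illustration of the aggregation on the final line (the scores below are hypothetical, not output from this project):

# Hypothetical scores illustrating the "major" metric aggregation:
results = {"exact": 71.0, "f1": 80.0}
major = (results["f1"] + results["exact"]) / 2  # 75.5, the single headline scalar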
Code Example #2
def compute_metrics_from_accumulator(
    self, task, accumulator: BaseAccumulator, tokenizer, labels
) -> Metrics:
    # Collect the logits accumulated across evaluation batches.
    logits = accumulator.get_accumulated()
    # This variant is specific to the multilingual QA tasks.
    assert isinstance(task, (tasks.TyDiQATask, tasks.XquadTask))
    lang = task.language
    results, predictions = squad_style.compute_predictions_logits_v3(
        data_rows=labels,
        logits=logits,
        n_best_size=task.n_best_size,
        max_answer_length=task.max_answer_length,
        do_lower_case=model_resolution.resolve_is_lower_case(tokenizer),
        version_2_with_negative=task.version_2_with_negative,
        null_score_diff_threshold=task.null_score_diff_threshold,
        # Chinese is not whitespace-segmented, so the whitespace-based
        # final-text alignment step is skipped for it.
        skip_get_final_text=(lang == "zh"),
        tokenizer=tokenizer,
    )
    # The headline ("major") metric is the mean of F1 and exact match.
    return Metrics(major=(results["f1"] + results["exact"]) / 2, minor=results)
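Both excerpts assume the surrounding module imports. A plausible header is sketched below; the paths follow jiant's package layout but are assumptions, not copied from the source files, and the imports for `BaseAccumulator` and `Metrics` are elided because their location varies between forks:

# Assumed imports (paths follow jiant's layout; treat as a sketch):
import jiant.shared.model_resolution as model_resolution
import jiant.tasks as tasks
import jiant.tasks.lib.templates.squad_style.core as squad_style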
Code Example #3
File: xtreme_submission.py  Project: yzpang/jiant
def generate_and_write_preds_for_qa(runner,
                                    supertask: str,
                                    output_dir: str,
                                    phase: str,
                                    skip_if_done: bool = False):
    """Generate predictions (test) for QA tasks and write them in XTREME submission format"""
    preds_pickle_path = os.path.join(output_dir, f"{supertask}_test_preds.p")
    if skip_if_done and os.path.exists(preds_pickle_path):
        print(f"Skipping cause {preds_pickle_path} exists")
        return

    if phase == "val":
        task_name_list = runner.jiant_task_container.task_run_config.val_task_list
    elif phase == "test":
        task_name_list = runner.jiant_task_container.task_run_config.test_task_list
    else:
        raise KeyError(phase)
    task_name_list = [
        task_name for task_name in task_name_list
        if task_name.startswith(supertask)
    ]
    if phase == "val":
        test_results_dict = runner.run_val(task_name_list=task_name_list)
    elif phase == "test":
        test_results_dict = {}
        test_dataloader_dict = runner.get_test_dataloader_dict()
        for task_name in task_name_list:
            test_results_dict[task_name] = jiant_runner.run_test(
                test_dataloader=test_dataloader_dict[task_name],
                jiant_model=runner.jiant_model,
                task=runner.jiant_task_container.task_dict[task_name],
                device=runner.device,
                local_rank=runner.rparams.local_rank,
                return_preds=False,
                verbose=True,
            )
    else:
        raise KeyError(phase)

    # Generate QA preds
    tokenizer = runner.model.tokenizer
    for task_name in task_name_list:
        task_results = test_results_dict[task_name]
        task = runner.jiant_task_container.task_dict[task_name]
        logits = task_results["accumulator"].get_accumulated()
        lang = get_qa_language(supertask=supertask, task=task)
        if phase == "val":
            cached = runner.get_val_dataloader_dict([
                task_name
            ])[task_name].dataset.chunked_file_data_cache.get_all()
        elif phase == "test":
            cached = runner.get_test_dataloader_dict(
            )[task_name].dataset.chunked_file_data_cache.get_all()
        else:
            raise KeyError(phase)
        data_rows = [row["data_row"] for row in cached]
        results, predictions = squad_lib.compute_predictions_logits_v3(
            data_rows=data_rows,
            logits=logits,
            n_best_size=task.n_best_size,
            max_answer_length=task.max_answer_length,
            do_lower_case=model_resolution.resolve_is_lower_case(tokenizer),
            version_2_with_negative=task.version_2_with_negative,
            null_score_diff_threshold=task.null_score_diff_threshold,
            skip_get_final_text=(lang == "zh"),
            tokenizer=tokenizer,
        )
        test_results_dict[task_name]["preds"] = predictions

    # Pickle the raw results, then emit one JSON file per language in the
    # XTREME submission layout.
    jiant_evaluate.write_preds(
        eval_results_dict=test_results_dict,
        path=preds_pickle_path,
    )
    preds_output_dir = os.path.join(output_dir, "preds", supertask)
    os.makedirs(preds_output_dir, exist_ok=True)
    for task_name, task_results in test_results_dict.items():
        task = runner.jiant_task_container.task_dict[task_name]
        lang = get_qa_language(supertask=supertask, task=task)
        py_io.write_json(task_results["preds"],
                         os.path.join(preds_output_dir, f"test-{lang}.json"))
    print(f"Wrote {supertask} preds for {len(test_results_dict)} languages")