def compute_metrics_from_accumulator(
    self, task, accumulator: BaseAccumulator, tokenizer, labels
) -> Metrics:
    """Decode accumulated QA logits into SQuAD-style metrics.

    Runs the SQuAD prediction decoding over the accumulated start/end logits
    and returns a Metrics object whose major score is the mean of F1 and
    exact-match, with the full results dict kept as the minor metrics.
    """
    accumulated_logits = accumulator.get_accumulated()
    eval_results, _predictions = squad_style.compute_predictions_logits_v3(
        data_rows=labels,
        logits=accumulated_logits,
        n_best_size=task.n_best_size,
        max_answer_length=task.max_answer_length,
        do_lower_case=model_resolution.resolve_is_lower_case(tokenizer),
        version_2_with_negative=task.version_2_with_negative,
        null_score_diff_threshold=task.null_score_diff_threshold,
        tokenizer=tokenizer,
    )
    # Major metric is the simple average of F1 and exact-match.
    major_score = (eval_results["f1"] + eval_results["exact"]) / 2
    return Metrics(major=major_score, minor=eval_results)
def compute_metrics_from_accumulator(
    self, task, accumulator: BaseAccumulator, tokenizer, labels
) -> Metrics:
    """Decode accumulated QA logits into SQuAD-style metrics for TyDiQA/XQuAD.

    Like the plain SQuAD variant, but language-aware: for Chinese ("zh") the
    final-text reconstruction step is skipped when decoding predictions.
    Major metric is the mean of F1 and exact-match; the full results dict is
    kept as the minor metrics.

    Raises:
        TypeError: if `task` is not a TyDiQATask or XquadTask.
    """
    # Explicit check instead of `assert`: asserts are stripped under `-O`,
    # which would let an unsupported task type slip through silently.
    if not isinstance(task, (tasks.TyDiQATask, tasks.XquadTask)):
        raise TypeError(
            f"Expected TyDiQATask or XquadTask, got {type(task).__name__}"
        )
    logits = accumulator.get_accumulated()
    lang = task.language
    results, predictions = squad_style.compute_predictions_logits_v3(
        data_rows=labels,
        logits=logits,
        n_best_size=task.n_best_size,
        max_answer_length=task.max_answer_length,
        do_lower_case=model_resolution.resolve_is_lower_case(tokenizer),
        version_2_with_negative=task.version_2_with_negative,
        null_score_diff_threshold=task.null_score_diff_threshold,
        # presumably the final-text heuristic is whitespace-based and does not
        # apply to Chinese — TODO confirm against squad_style implementation
        skip_get_final_text=(lang == "zh"),
        tokenizer=tokenizer,
    )
    return Metrics(major=(results["f1"] + results["exact"]) / 2, minor=results)
def generate_and_write_preds_for_qa(
    runner, supertask: str, output_dir: str, phase: str, skip_if_done: bool = False
):
    """Generate predictions for QA tasks and write them in XTREME submission format.

    Args:
        runner: jiant runner holding the model, device and task container.
        supertask: supertask prefix (e.g. "xquad"); only tasks whose name
            starts with this prefix are processed.
        output_dir: directory for the preds pickle and per-language JSON files.
        phase: "val" or "test".
        skip_if_done: if True and the preds pickle already exists, do nothing.

    Raises:
        KeyError: if `phase` is neither "val" nor "test".
    """
    preds_pickle_path = os.path.join(output_dir, f"{supertask}_test_preds.p")
    if skip_if_done and os.path.exists(preds_pickle_path):
        print(f"Skipping cause {preds_pickle_path} exists")
        return
    if phase == "val":
        task_name_list = runner.jiant_task_container.task_run_config.val_task_list
    elif phase == "test":
        task_name_list = runner.jiant_task_container.task_run_config.test_task_list
    else:
        raise KeyError(phase)
    task_name_list = [
        task_name for task_name in task_name_list if task_name.startswith(supertask)
    ]
    test_dataloader_dict = None
    if phase == "val":
        test_results_dict = runner.run_val(task_name_list=task_name_list)
    else:  # phase == "test" — already validated above
        test_results_dict = {}
        test_dataloader_dict = runner.get_test_dataloader_dict()
        for task_name in task_name_list:
            test_results_dict[task_name] = jiant_runner.run_test(
                test_dataloader=test_dataloader_dict[task_name],
                jiant_model=runner.jiant_model,
                task=runner.jiant_task_container.task_dict[task_name],
                device=runner.device,
                local_rank=runner.rparams.local_rank,
                return_preds=False,
                verbose=True,
            )
    # Generate QA preds
    # NOTE(review): run_test above uses runner.jiant_model while the tokenizer
    # comes from runner.model — confirm both refer to the same model object.
    tokenizer = runner.model.tokenizer
    for task_name in task_name_list:
        task_results = test_results_dict[task_name]
        task = runner.jiant_task_container.task_dict[task_name]
        logits = task_results["accumulator"].get_accumulated()
        lang = get_qa_language(supertask=supertask, task=task)
        if phase == "val":
            cached = runner.get_val_dataloader_dict([task_name])[
                task_name
            ].dataset.chunked_file_data_cache.get_all()
        else:
            # Reuse the dataloader dict built above instead of rebuilding the
            # whole dict once per task (the original called
            # runner.get_test_dataloader_dict() inside this loop).
            cached = test_dataloader_dict[
                task_name
            ].dataset.chunked_file_data_cache.get_all()
        data_rows = [row["data_row"] for row in cached]
        results, predictions = squad_lib.compute_predictions_logits_v3(
            data_rows=data_rows,
            logits=logits,
            n_best_size=task.n_best_size,
            max_answer_length=task.max_answer_length,
            do_lower_case=model_resolution.resolve_is_lower_case(tokenizer),
            version_2_with_negative=task.version_2_with_negative,
            null_score_diff_threshold=task.null_score_diff_threshold,
            skip_get_final_text=(lang == "zh"),
            tokenizer=tokenizer,
        )
        test_results_dict[task_name]["preds"] = predictions
    jiant_evaluate.write_preds(
        eval_results_dict=test_results_dict,
        path=preds_pickle_path,
    )
    preds_output_dir = os.path.join(output_dir, "preds", supertask)
    os.makedirs(preds_output_dir, exist_ok=True)
    for task_name, task_results in test_results_dict.items():
        task = runner.jiant_task_container.task_dict[task_name]
        lang = get_qa_language(supertask=supertask, task=task)
        py_io.write_json(
            task_results["preds"], os.path.join(preds_output_dir, f"test-{lang}.json")
        )
    print(f"Wrote {supertask} preds for {len(test_results_dict)} languages")