Example #1
    def compute_metrics_from_accumulator(
        self, task, accumulator: BaseAccumulator, tokenizer, labels
    ) -> Metrics:

        # TODO: Fix val labels cache (this is a quick hack)
        # Pull the accumulated logits for the validation set.
        logits = accumulator.get_accumulated()
        # Rebuild SQuAD-style partial examples from the cached label rows,
        # and wrap the raw logits as per-example prediction results.
        partial_examples = squad_style.data_rows_to_partial_examples(data_rows=labels)
        all_pred_results = squad_style.logits_to_pred_results_list(logits)
        # MLQA evaluation assumes a single language for both context and question.
        assert task.context_language == task.question_language
        lang = task.context_language
        # Decode start/end logits into answer strings (skip get_final_text for Chinese).
        predictions = squad_style_utils.compute_predictions_logits_v2(
            partial_examples=partial_examples,
            all_results=all_pred_results,
            n_best_size=task.n_best_size,
            max_answer_length=task.max_answer_length,
            do_lower_case=model_resolution.resolve_is_lower_case(tokenizer),
            version_2_with_negative=task.version_2_with_negative,
            null_score_diff_threshold=task.null_score_diff_threshold,
            tokenizer=tokenizer,
            skip_get_final_text=(lang == "zh"),
            verbose=True,
        )
        # Score the predictions with the MLQA evaluation logic for this language.
        dataset = read_json(task.val_path)["data"]
        results = mlqa_lib.evaluate(dataset=dataset, predictions=predictions, lang=lang)
        # Headline metric: mean of MLQA F1 and exact match.
        return Metrics(major=(results["f1"] + results["exact_match"]) / 2, minor=results)
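For orientation, here is a minimal sketch of how this metric hook might be called after a validation pass; evaluation_scheme, val_accumulator, and val_labels are illustrative placeholder names, not part of the source.

# Hypothetical call site: evaluation_scheme, val_accumulator, and val_labels
# are placeholder names for illustration only.
metrics = evaluation_scheme.compute_metrics_from_accumulator(
    task=task,
    accumulator=val_accumulator,
    tokenizer=tokenizer,
    labels=val_labels,
)
print(metrics.major)  # mean of MLQA F1 and exact match
print(metrics.minor)  # full MLQA results dict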
Example #2
from typing import List, Union

import numpy as np


def compute_predictions_logits_v3(
    data_rows: List[Union[PartialDataRow, DataRow]],
    logits: np.ndarray,
    n_best_size,
    max_answer_length,
    do_lower_case,
    version_2_with_negative,
    null_score_diff_threshold,
    tokenizer,
    skip_get_final_text=False,
    verbose=True,
):
    """Write final predictions to the json file and log-odds of null if needed."""
    # Rebuild partial examples and per-example prediction results from the raw inputs.
    partial_examples = data_rows_to_partial_examples(data_rows)
    all_pred_results = logits_to_pred_results_list(logits)
    predictions = squad_utils.compute_predictions_logits_v2(
        partial_examples=partial_examples,
        all_results=all_pred_results,
        n_best_size=n_best_size,
        max_answer_length=max_answer_length,
        do_lower_case=do_lower_case,
        version_2_with_negative=version_2_with_negative,
        null_score_diff_threshold=null_score_diff_threshold,
        tokenizer=tokenizer,
        verbose=verbose,
        skip_get_final_text=skip_get_final_text,
    )
    # Evaluate the decoded predictions with the standard SQuAD metrics (EM/F1).
    results = squad_utils.squad_evaluate(partial_examples, predictions)
    return results, predictions
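A minimal usage sketch, assuming the data rows and validation logits have already been produced elsewhere; cached_data_rows, val_logits, and tokenizer are placeholders, and the numeric arguments are common SQuAD defaults rather than values taken from the source.

# Hypothetical inputs: cached_data_rows would be the task's List[DataRow] for the
# validation split, and val_logits the accumulated start/end logits (np.ndarray).
results, predictions = compute_predictions_logits_v3(
    data_rows=cached_data_rows,
    logits=val_logits,
    n_best_size=20,
    max_answer_length=30,
    do_lower_case=False,
    version_2_with_negative=False,
    null_score_diff_threshold=0.0,
    tokenizer=tokenizer,
)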