Example n. 1
0
    def compute_metrics_from_accumulator(
        self, task, accumulator: BaseAccumulator, tokenizer, labels
    ) -> Metrics:
        """Compute MLQA-style QA metrics from accumulated span logits.

        Decodes answer-span predictions from the accumulated logits, then
        scores them with the MLQA evaluator against the task's val data.
        Major metric is the mean of F1 and exact-match.
        """
        # TODO: Fix val labels cache — this is a quick hack.
        accumulated_logits = accumulator.get_accumulated()
        partial_examples = squad_style.data_rows_to_partial_examples(data_rows=labels)
        pred_results = squad_style.logits_to_pred_results_list(accumulated_logits)
        # Context and question are expected to share a single language.
        assert task.context_language == task.question_language
        language = task.context_language
        predictions = squad_style_utils.compute_predictions_logits_v2(
            partial_examples=partial_examples,
            all_results=pred_results,
            n_best_size=task.n_best_size,
            max_answer_length=task.max_answer_length,
            do_lower_case=model_resolution.resolve_is_lower_case(tokenizer),
            version_2_with_negative=task.version_2_with_negative,
            null_score_diff_threshold=task.null_score_diff_threshold,
            tokenizer=tokenizer,
            # Skip final-text recovery for Chinese.
            skip_get_final_text=(language == "zh"),
            verbose=True,
        )
        dataset = read_json(task.val_path)["data"]
        results = mlqa_lib.evaluate(dataset=dataset, predictions=predictions, lang=language)
        average_score = (results["f1"] + results["exact_match"]) / 2
        return Metrics(major=average_score, minor=results)
Example n. 2
0
 def compute_metrics_from_accumulator(
     self, task, accumulator: BaseAccumulator, tokenizer, labels
 ) -> Metrics:
     """Compute SQuAD-style QA metrics from accumulated span logits.

     Decodes predictions directly from the label data rows and logits,
     then averages F1 and exact-match into the major metric.
     """
     accumulated_logits = accumulator.get_accumulated()
     results, _predictions = squad_style.compute_predictions_logits_v3(
         data_rows=labels,
         logits=accumulated_logits,
         n_best_size=task.n_best_size,
         max_answer_length=task.max_answer_length,
         do_lower_case=model_resolution.resolve_is_lower_case(tokenizer),
         version_2_with_negative=task.version_2_with_negative,
         null_score_diff_threshold=task.null_score_diff_threshold,
         tokenizer=tokenizer,
     )
     average_score = (results["f1"] + results["exact"]) / 2
     return Metrics(major=average_score, minor=results)
Example n. 3
0
 def compute_metrics_from_accumulator(
     self, task, accumulator: BaseAccumulator, tokenizer, labels
 ) -> Metrics:
     """Compute QA metrics for TyDiQA/XQuAD from accumulated span logits.

     Major metric is the mean of F1 and exact-match; the full results
     dict is returned as the minor metric.
     """
     accumulated_logits = accumulator.get_accumulated()
     # Only these task types carry the single `language` attribute read below.
     assert isinstance(task, (tasks.TyDiQATask, tasks.XquadTask))
     language = task.language
     results, _predictions = squad_style.compute_predictions_logits_v3(
         data_rows=labels,
         logits=accumulated_logits,
         n_best_size=task.n_best_size,
         max_answer_length=task.max_answer_length,
         do_lower_case=model_resolution.resolve_is_lower_case(tokenizer),
         version_2_with_negative=task.version_2_with_negative,
         null_score_diff_threshold=task.null_score_diff_threshold,
         # Skip final-text recovery for Chinese.
         skip_get_final_text=(language == "zh"),
         tokenizer=tokenizer,
     )
     average_score = (results["f1"] + results["exact"]) / 2
     return Metrics(major=average_score, minor=results)
Example n. 4
0
    def tokenize(self, tokenizer):
        """Tokenize passage and question, projecting the answer span onto tokens.

        The passage is lower-cased first when the tokenizer is uncased, so
        character offsets are computed against the exact text the tokenizer sees.
        Returns a TokenizedExample carrying the token-level answer span and the
        char-to-token index map.
        """
        if model_resolution.resolve_is_lower_case(tokenizer=tokenizer):
            passage_text = self.passage.lower()
        else:
            passage_text = self.passage
        passage_tokens = tokenizer.tokenize(passage_text)
        aligner = TokenAligner(source=passage_text, target=passage_tokens)
        # Map the character-level answer span to an inclusive token span.
        char_span = self.answer_char_span
        answer_token_span = aligner.project_char_to_token_span(
            char_span[0], char_span[1], inclusive=True
        )
        return TokenizedExample(
            guid=self.guid,
            passage=passage_tokens,
            question=tokenizer.tokenize(self.question),
            answer_str=self.answer,
            passage_str=passage_text,
            answer_token_span=answer_token_span,
            token_idx_to_char_idx_map=aligner.source_char_idx_to_target_token_idx.T,
        )
Example n. 5
0
def generate_and_write_preds_for_qa(runner,
                                    supertask: str,
                                    output_dir: str,
                                    phase: str,
                                    skip_if_done: bool = False):
    """Generate predictions (val or test) for QA tasks and write them in XTREME submission format.

    Runs inference for every task whose name starts with `supertask`,
    recomputes SQuAD-style answer predictions from the accumulated logits,
    pickles the full results, and writes one `test-{lang}.json` per task.

    Args:
        runner: jiant runner holding the model, device, and task container.
        supertask: task-name prefix selecting which tasks to process.
        output_dir: directory receiving the preds pickle and per-language JSONs.
        phase: either "val" or "test".
        skip_if_done: if True, return early when the preds pickle already exists.

    Raises:
        KeyError: if `phase` is not "val" or "test".
    """
    # Validate `phase` once up front instead of repeating the same
    # if/elif/else ladder three times below.
    if phase not in ("val", "test"):
        raise KeyError(phase)

    preds_pickle_path = os.path.join(output_dir, f"{supertask}_test_preds.p")
    if skip_if_done and os.path.exists(preds_pickle_path):
        print(f"Skipping cause {preds_pickle_path} exists")
        return

    task_run_config = runner.jiant_task_container.task_run_config
    if phase == "val":
        all_task_names = task_run_config.val_task_list
    else:
        all_task_names = task_run_config.test_task_list
    task_name_list = [
        task_name for task_name in all_task_names
        if task_name.startswith(supertask)
    ]

    if phase == "val":
        test_results_dict = runner.run_val(task_name_list=task_name_list)
    else:
        test_results_dict = {}
        test_dataloader_dict = runner.get_test_dataloader_dict()
        for task_name in task_name_list:
            test_results_dict[task_name] = jiant_runner.run_test(
                test_dataloader=test_dataloader_dict[task_name],
                jiant_model=runner.jiant_model,
                task=runner.jiant_task_container.task_dict[task_name],
                device=runner.device,
                local_rank=runner.rparams.local_rank,
                return_preds=False,
                verbose=True,
            )

    # Generate QA preds from the accumulated logits.
    tokenizer = runner.model.tokenizer
    if phase == "test":
        # Hoisted out of the per-task loop: the test dataloader dict is
        # loop-invariant, so rebuilding it for every task was redundant work.
        cached_test_dataloader_dict = runner.get_test_dataloader_dict()
    for task_name in task_name_list:
        task_results = test_results_dict[task_name]
        task = runner.jiant_task_container.task_dict[task_name]
        logits = task_results["accumulator"].get_accumulated()
        lang = get_qa_language(supertask=supertask, task=task)
        if phase == "val":
            # Val dataloaders are built per task (the getter takes a task list),
            # so this call cannot be hoisted the same way.
            cached = runner.get_val_dataloader_dict([
                task_name
            ])[task_name].dataset.chunked_file_data_cache.get_all()
        else:
            cached = cached_test_dataloader_dict[
                task_name
            ].dataset.chunked_file_data_cache.get_all()
        data_rows = [row["data_row"] for row in cached]
        results, predictions = squad_lib.compute_predictions_logits_v3(
            data_rows=data_rows,
            logits=logits,
            n_best_size=task.n_best_size,
            max_answer_length=task.max_answer_length,
            do_lower_case=model_resolution.resolve_is_lower_case(tokenizer),
            version_2_with_negative=task.version_2_with_negative,
            null_score_diff_threshold=task.null_score_diff_threshold,
            skip_get_final_text=(lang == "zh"),
            tokenizer=tokenizer,
        )
        test_results_dict[task_name]["preds"] = predictions

    jiant_evaluate.write_preds(
        eval_results_dict=test_results_dict,
        path=preds_pickle_path,
    )
    preds_output_dir = os.path.join(output_dir, "preds", supertask)
    os.makedirs(preds_output_dir, exist_ok=True)
    for task_name, task_results in test_results_dict.items():
        task = runner.jiant_task_container.task_dict[task_name]
        lang = get_qa_language(supertask=supertask, task=task)
        py_io.write_json(task_results["preds"],
                         os.path.join(preds_output_dir, f"test-{lang}.json"))
    print(f"Wrote {supertask} preds for {len(test_results_dict)} languages")