Example #1
def generate_and_write_preds_for_classification(
        runner: jiant_runner.JiantRunner,
        supertask: str,
        output_dir: str,
        skip_if_done: bool = False):
    """Write test predictions for classification tasks in XTREME submission format"""
    preds_pickle_path = os.path.join(output_dir, f"{supertask}_test_preds.p")
    if skip_if_done and os.path.exists(preds_pickle_path):
        print(f"Skipping cause {preds_pickle_path} exists")
        return

    test_results_dict = runner.run_test(
        task_name_list=runner.jiant_task_container.task_run_config.
        test_task_list, )
    jiant_evaluate.write_preds(
        eval_results_dict=test_results_dict,
        path=preds_pickle_path,
    )
    preds_output_dir = os.path.join(output_dir, "preds", supertask)
    os.makedirs(preds_output_dir, exist_ok=True)
    for task_name, task_results in test_results_dict.items():
        task = runner.jiant_task_container.task_dict[task_name]
        assert isinstance(task, (tasks.XnliTask, tasks.PawsXTask))
        lang = task.language
        with open(os.path.join(preds_output_dir, f"test-{lang}.tsv"),
                  "w") as f:
            for idx in task_results["preds"]:
                if supertask == "xnli":
                    pred_label = task.ID_TO_LABEL[idx]
                elif supertask == "pawsx":
                    pred_label = idx
                else:
                    raise RuntimeError()
                f.write(f"{pred_label}\n")
    print(f"Wrote {supertask} preds for {len(test_results_dict)} languages")
Example #2
def generate_and_write_preds_for_tatoeba(runner,
                                         output_dir: str,
                                         skip_if_done: bool = False):
    """Generate predictions (val) for Tateoba and write them in XTREME submission format"""
    preds_pickle_path = os.path.join(output_dir, "tatoeba_val_preds.p")
    if skip_if_done and os.path.exists(preds_pickle_path):
        print(f"Skipping cause {preds_pickle_path} exists")
        return
    val_results_dict = runner.run_val(
        task_name_list=runner.jiant_task_container.task_run_config.
        val_task_list,
        return_preds=True,
    )
    jiant_evaluate.write_preds(
        eval_results_dict=val_results_dict,
        path=preds_pickle_path,
    )
    preds_output_dir = os.path.join(output_dir, "preds", "tatoeba")
    os.makedirs(preds_output_dir, exist_ok=True)
    for task_name, task_results in val_results_dict.items():
        lang = runner.jiant_task_container.task_dict[task_name].language
        with open(os.path.join(preds_output_dir, f"test-{lang}.tsv"),
                  "w") as f:
            for idx in task_results["preds"]:
                f.write(f"{idx:d}\n")
    print(f"Wrote Tatoeba preds for {len(val_results_dict)} languages")
Example #3
def generate_and_write_preds_for_bucc2018(runner,
                                          output_dir: str,
                                          bucc_val_metrics_path: str,
                                          skip_if_done: bool = False):
    """Generate predictions (test) for Bucc2018 and write them in XTREME submission format"""
    preds_pickle_path = os.path.join(output_dir, "bucc2018_test_preds.p")
    if skip_if_done and os.path.exists(preds_pickle_path):
        print(f"Skipping cause {preds_pickle_path} exists")
        return
    else:
        print(f"{preds_pickle_path} does not exist")
    if bucc_val_metrics_path is None:
        # Recompute thresholds:
        val_results_dict = runner.run_val(
            task_name_list=runner.jiant_task_container.task_run_config.
            val_task_list,
            return_preds=True,
        )
        jiant_evaluate.write_preds(
            eval_results_dict=val_results_dict,
            path=os.path.join(output_dir, "bucc2018_val_preds.p"),
        )
        thresholds_dict = {
            task_name: task_results["metrics"].minor["best-threshold"]
            for task_name, task_results in val_results_dict.items()
        }
    else:
        val_metrics = py_io.read_json(bucc_val_metrics_path)
        thresholds_dict = {
            task_name:
            val_metrics[task_name]["metrics"]["minor"]["best-threshold"]
            for task_name in
            runner.jiant_task_container.task_run_config.val_task_list
        }

    preds_output_dir = os.path.join(output_dir, "preds", "bucc2018")
    os.makedirs(preds_output_dir, exist_ok=True)
    test_results_dict = runner.run_test(
        task_name_list=runner.jiant_task_container.task_run_config.
        test_task_list, )
    jiant_evaluate.write_preds(
        eval_results_dict=test_results_dict,
        path=preds_pickle_path,
    )
    for task_name, task_results in test_results_dict.items():
        bitext = bucc2018_lib.bucc_extract(
            cand2score=task_results["preds"],
            th=thresholds_dict[task_name],
        )
        lang = runner.jiant_task_container.task_dict[task_name].language
        with open(os.path.join(preds_output_dir, f"test-{lang}.tsv"),
                  "w") as f:
            for src, trg in bitext:
                f.write(f"{src}\t{trg}\n")
    print(f"Wrote Bucc2018 preds for {len(test_results_dict)} languages")
Example #4
def generate_and_write_preds_for_qa(runner,
                                    supertask: str,
                                    output_dir: str,
                                    phase: str,
                                    skip_if_done: bool = False):
    """Generate predictions (test) for QA tasks and write them in XTREME submission format"""
    preds_pickle_path = os.path.join(output_dir, f"{supertask}_test_preds.p")
    if skip_if_done and os.path.exists(preds_pickle_path):
        print(f"Skipping cause {preds_pickle_path} exists")
        return

    if phase == "val":
        task_name_list = runner.jiant_task_container.task_run_config.val_task_list
    elif phase == "test":
        task_name_list = runner.jiant_task_container.task_run_config.test_task_list
    else:
        raise KeyError(phase)
    task_name_list = [
        task_name for task_name in task_name_list
        if task_name.startswith(supertask)
    ]
    if phase == "val":
        test_results_dict = runner.run_val(task_name_list=task_name_list)
    elif phase == "test":
        test_results_dict = {}
        test_dataloader_dict = runner.get_test_dataloader_dict()
        for task_name in task_name_list:
            test_results_dict[task_name] = jiant_runner.run_test(
                test_dataloader=test_dataloader_dict[task_name],
                jiant_model=runner.jiant_model,
                task=runner.jiant_task_container.task_dict[task_name],
                device=runner.device,
                local_rank=runner.rparams.local_rank,
                return_preds=False,
                verbose=True,
            )
    else:
        raise KeyError(phase)

    # Generate QA preds
    tokenizer = runner.model.tokenizer
    for task_name in task_name_list:
        task_results = test_results_dict[task_name]
        task = runner.jiant_task_container.task_dict[task_name]
        logits = task_results["accumulator"].get_accumulated()
        lang = get_qa_language(supertask=supertask, task=task)
        if phase == "val":
            cached = runner.get_val_dataloader_dict([
                task_name
            ])[task_name].dataset.chunked_file_data_cache.get_all()
        elif phase == "test":
            cached = runner.get_test_dataloader_dict(
            )[task_name].dataset.chunked_file_data_cache.get_all()
        else:
            raise KeyError(phase)
        data_rows = [row["data_row"] for row in cached]
        results, predictions = squad_lib.compute_predictions_logits_v3(
            data_rows=data_rows,
            logits=logits,
            n_best_size=task.n_best_size,
            max_answer_length=task.max_answer_length,
            do_lower_case=model_resolution.resolve_is_lower_case(tokenizer),
            version_2_with_negative=task.version_2_with_negative,
            null_score_diff_threshold=task.null_score_diff_threshold,
            skip_get_final_text=(lang == "zh"),
            tokenizer=tokenizer,
        )
        test_results_dict[task_name]["preds"] = predictions

    jiant_evaluate.write_preds(
        eval_results_dict=test_results_dict,
        path=preds_pickle_path,
    )
    preds_output_dir = os.path.join(output_dir, "preds", supertask)
    os.makedirs(preds_output_dir, exist_ok=True)
    for task_name, task_results in test_results_dict.items():
        task = runner.jiant_task_container.task_dict[task_name]
        lang = get_qa_language(supertask=supertask, task=task)
        py_io.write_json(task_results["preds"],
                         os.path.join(preds_output_dir, f"test-{lang}.json"))
    print(f"Wrote {supertask} preds for {len(test_results_dict)} languages")
Example #5
def run_loop(args: RunConfiguration, checkpoint=None, tasks=None):
    is_resumed = checkpoint is not None
    quick_init_out = initialization.quick_init(args=args, verbose=True)
    print(quick_init_out.n_gpu)
    with quick_init_out.log_writer.log_context():
        jiant_task_container = container_setup.create_jiant_task_container_from_json(
            jiant_task_container_config_path=args.
            jiant_task_container_config_path,
            verbose=True,
        )
        runner = setup_runner(
            args=args,
            jiant_task_container=jiant_task_container,
            quick_init_out=quick_init_out,
            verbose=True,
        )
        if is_resumed:
            runner.load_state(checkpoint["runner_state"])
            del checkpoint["runner_state"]
        checkpoint_saver = jiant_runner.CheckpointSaver(
            metadata={"args": args.to_dict()},
            save_path=os.path.join(args.output_dir, "checkpoint.p"),
        )

        if args.do_val:
            print("EVAL_BEFORE________________________________")
            val_results_dict = runner.run_val(
                task_name_list=runner.jiant_task_container.task_run_config.
                val_task_list,
                return_preds=args.write_val_preds,
            )
            jiant_evaluate.write_val_results(
                val_results_dict=val_results_dict,
                metrics_aggregator=runner.jiant_task_container.
                metrics_aggregator,
                output_dir=args.output_dir,
                verbose=True,
            )
            if args.write_val_preds:
                jiant_evaluate.write_preds(
                    eval_results_dict=val_results_dict,
                    path=os.path.join(args.output_dir, "val_preds.p"),
                )

        if args.do_train:
            metarunner = jiant_metarunner.JiantMetarunner(
                runner=runner,
                save_every_steps=args.save_every_steps,
                eval_every_steps=args.eval_every_steps,
                save_checkpoint_every_steps=args.save_checkpoint_every_steps,
                no_improvements_for_n_evals=args.no_improvements_for_n_evals,
                checkpoint_saver=checkpoint_saver,
                output_dir=args.output_dir,
                verbose=True,
                save_best_model=True,
                save_last_model=args.do_save or args.do_save_last,
                load_best_model=True,
                log_writer=quick_init_out.log_writer,
            )
            if is_resumed:
                metarunner.load_state(checkpoint["metarunner_state"])
                del checkpoint["metarunner_state"]

            metarunner.run_train_loop()
            runner.run_perturb(tasks, metarunner)

        if args.do_val:
            val_results_dict = runner.run_val(
                task_name_list=runner.jiant_task_container.task_run_config.
                val_task_list,
                return_preds=args.write_val_preds,
            )
            jiant_evaluate.write_val_results(
                val_results_dict=val_results_dict,
                metrics_aggregator=runner.jiant_task_container.
                metrics_aggregator,
                output_dir=args.output_dir,
                verbose=True,
            )
            if args.write_val_preds:
                jiant_evaluate.write_preds(
                    eval_results_dict=val_results_dict,
                    path=os.path.join(args.output_dir, "val_preds.p"),
                )
        else:
            assert not args.write_val_preds

        if args.write_test_preds:
            test_results_dict = runner.run_test(
                task_name_list=runner.jiant_task_container.task_run_config.
                test_task_list, )
            jiant_evaluate.write_preds(
                eval_results_dict=test_results_dict,
                path=os.path.join(args.output_dir, "test_preds.p"),
            )
        print("EVAL DONE------------------------")

    if (not args.keep_checkpoint_when_done and args.save_checkpoint_every_steps
            and os.path.exists(os.path.join(args.output_dir, "checkpoint.p"))):
        os.remove(os.path.join(args.output_dir, "checkpoint.p"))

    py_io.write_file("DONE", os.path.join(args.output_dir, "done_file"))
Example #6
def run_loop(args: RunConfiguration, checkpoint=None):
    is_resumed = checkpoint is not None
    quick_init_out = initialization.quick_init(args=args, verbose=True)
    print(quick_init_out.n_gpu)
    with quick_init_out.log_writer.log_context():
        jiant_task_container = container_setup.create_jiant_task_container_from_json(
            jiant_task_container_config_path=args.
            jiant_task_container_config_path,
            verbose=True,
        )
        runner = setup_runner(
            args=args,
            jiant_task_container=jiant_task_container,
            quick_init_out=quick_init_out,
            verbose=True,
        )
        if is_resumed:
            runner.load_state(checkpoint["runner_state"])
            del checkpoint["runner_state"]

        # allow custom checkpoint name (bare filename; joined with output_dir below)
        if args.custom_checkpoint_name:
            checkpoint_name = f"{args.custom_checkpoint_name}.p"
        else:
            checkpoint_name = "checkpoint.p"

        checkpoint_saver = jiant_runner.CheckpointSaver(
            metadata={"args": args.to_dict()},
            save_path=os.path.join(args.output_dir, checkpoint_name),
        )
        if args.do_train:
            metarunner = jiant_metarunner.JiantMetarunner(
                runner=runner,
                save_every_steps=args.save_every_steps,
                eval_every_steps=args.eval_every_steps,
                save_checkpoint_every_steps=args.save_checkpoint_every_steps,
                no_improvements_for_n_evals=args.no_improvements_for_n_evals,
                checkpoint_saver=checkpoint_saver,
                output_dir=args.output_dir,
                verbose=True,
                save_best_model=args.do_save,
                load_best_model=True,
                log_writer=quick_init_out.log_writer,
            )
            if is_resumed:
                metarunner.load_state(checkpoint["metarunner_state"])
                del checkpoint["metarunner_state"]
            metarunner.run_train_loop()

        if args.do_save:
            # allow custom best model name
            if args.custom_best_name:
                best_model_name = os.path.join(args.output_dir,
                                               f"{args.custom_best_name}.p")
            else:
                best_model_name = os.path.join(args.output_dir, "model.p")

            torch.save(
                torch_utils.get_model_for_saving(
                    runner.jiant_model).state_dict(),
                best_model_name,
            )

        if args.do_val:
            val_results_dict = runner.run_val(
                task_name_list=runner.jiant_task_container.task_run_config.
                val_task_list,
                return_preds=args.write_val_preds,
            )
            jiant_evaluate.write_val_results(
                val_results_dict=val_results_dict,
                metrics_aggregator=runner.jiant_task_container.
                metrics_aggregator,
                output_dir=args.output_dir,
                verbose=True,
                val_jsonl=args.val_jsonl,
            )

            if args.args_jsonl:
                # match arguments with verbose results
                initialization.save_args(args, verbose=True, matched=True)

            if args.write_val_preds:
                if args.extract_exp_name_valpreds:
                    exp_name = os.path.basename(
                        args.jiant_task_container_config_path).split(".")[0]
                    val_fname = f"val_preds_{exp_name}.p"
                else:
                    val_fname = "val_preds.p"
                jiant_evaluate.write_preds(
                    eval_results_dict=val_results_dict,
                    path=os.path.join(args.output_dir, val_fname),
                )
        else:
            assert not args.write_val_preds

        if args.write_test_preds:
            test_results_dict = runner.run_test(
                task_name_list=runner.jiant_task_container.task_run_config.
                test_task_list, )
            jiant_evaluate.write_preds(
                eval_results_dict=test_results_dict,
                path=os.path.join(args.output_dir, "test_preds.p"),
            )

    if args.delete_checkpoint_if_done and args.save_checkpoint_every_steps:
        os.remove(os.path.join(args.output_dir, checkpoint_name))
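
A sketch of the custom-name hooks that distinguish this variant: custom_checkpoint_name and custom_best_name only change the filenames written under output_dir. How `args` is constructed (CLI, JSON config, ...) is left open here.

# Hypothetical run; the attribute assignments assume `args` is a mutable RunConfiguration.
args.do_train = True
args.do_save = True
args.custom_checkpoint_name = "xnli_ckpt"   # -> <output_dir>/xnli_ckpt.p
args.custom_best_name = "xnli_best"         # -> <output_dir>/xnli_best.p
run_loop(args)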
Example #7
def run_loop(args: RunConfiguration, checkpoint=None):
    is_resumed = checkpoint is not None
    quick_init_out = initialization.quick_init(args=args, verbose=True)
    print(quick_init_out.n_gpu)
    with quick_init_out.log_writer.log_context():
        jiant_task_container = container_setup.create_jiant_task_container_from_json(
            jiant_task_container_config_path=args.jiant_task_container_config_path, verbose=True,
        )
        runner = setup_runner(
            args=args,
            jiant_task_container=jiant_task_container,
            quick_init_out=quick_init_out,
            verbose=True,
        )
        if is_resumed:
            runner.load_state(checkpoint["runner_state"])
            del checkpoint["runner_state"]
        checkpoint_saver = jiant_runner.CheckpointSaver(
            metadata={"args": args.to_dict()},
            save_path=os.path.join(args.output_dir, "checkpoint.p"),
        )
        if args.do_train:
            metarunner = adapters_metarunner.AdaptersMetarunner(
                runner=runner,
                save_every_steps=args.save_every_steps,
                eval_every_steps=args.eval_every_steps,
                save_checkpoint_every_steps=args.save_checkpoint_every_steps,
                no_improvements_for_n_evals=args.no_improvements_for_n_evals,
                checkpoint_saver=checkpoint_saver,
                output_dir=args.output_dir,
                verbose=True,
                save_best_model=args.do_save,
                load_best_model=True,
                log_writer=quick_init_out.log_writer,
            )
            if is_resumed:
                metarunner.load_state(checkpoint["metarunner_state"])
                del checkpoint["metarunner_state"]
            metarunner.run_train_loop()

        if args.do_save:
            torch.save(
                adapters_modeling.get_optimized_state_dict_for_jiant_model_with_adapters(
                    torch_utils.get_model_for_saving(runner.jiant_model),
                ),
                os.path.join(args.output_dir, "model.p"),
            )

        if args.do_val:
            val_results_dict = runner.run_val(
                task_name_list=runner.jiant_task_container.task_run_config.val_task_list,
                return_preds=args.write_val_preds,
            )
            jiant_evaluate.write_val_results(
                val_results_dict=val_results_dict,
                metrics_aggregator=runner.jiant_task_container.metrics_aggregator,
                output_dir=args.output_dir,
                verbose=True,
            )
            if args.write_val_preds:
                jiant_evaluate.write_preds(
                    eval_results_dict=val_results_dict,
                    path=os.path.join(args.output_dir, "val_preds.p"),
                )
        else:
            assert not args.write_val_preds

        if args.write_test_preds:
            test_results_dict = runner.run_test(
                task_name_list=runner.jiant_task_container.task_run_config.test_task_list,
            )
            jiant_evaluate.write_preds(
                eval_results_dict=test_results_dict,
                path=os.path.join(args.output_dir, "test_preds.p"),
            )

    if args.delete_checkpoint_if_done and args.save_checkpoint_every_steps:
        os.remove(os.path.join(args.output_dir, "checkpoint.p"))