Example no. 1
0
def run_loop(args: RunConfiguration):
    """Build a runner from a pre-built task container and write predictions
    for the configured supertask to ``args.output_dir``.

    Raises:
        RuntimeError: if ``args.jiant_task_container_path`` is not set.
        KeyError: if ``args.supertask`` is not one of the supported tasks.
    """
    quick_init_out = initialization.quick_init(args=args, verbose=True)
    with quick_init_out.log_writer.log_context():
        # Guard clause: a container path is required to build the task container.
        if not args.jiant_task_container_path:
            raise RuntimeError("Need `jiant_task_container_path` or individual config paths")
        jiant_task_container = container_setup.create_jiant_task_container(
            **py_io.read_json(args.jiant_task_container_path)
        )
        runner = setup_runner(
            args=args,
            jiant_task_container=jiant_task_container,
            quick_init_out=quick_init_out,
            verbose=True,
        )
    supertask = args.supertask
    output_dir = args.output_dir
    # Arguments shared by every prediction writer below.
    shared_kwargs = dict(
        runner=runner,
        output_dir=output_dir,
        skip_if_done=args.skip_if_done,
    )
    if supertask in ("xnli", "pawsx"):
        generate_and_write_preds_for_classification(
            supertask=supertask, **shared_kwargs,
        )
    elif supertask in ("udpos", "panx"):
        generate_and_write_preds_for_tagging(
            supertask=supertask, **shared_kwargs,
        )
    elif supertask in ("xquad", "mlqa"):
        generate_and_write_preds_for_qa(
            supertask=supertask, phase="test", **shared_kwargs,
        )
    elif supertask == "tydiqa":
        # TyDiQA is evaluated on the validation phase rather than test.
        generate_and_write_preds_for_qa(
            supertask="tydiqa", phase="val", **shared_kwargs,
        )
    elif supertask == "bucc2018":
        generate_and_write_preds_for_bucc2018(
            bucc_val_metrics_path=args.bucc_val_metrics_path, **shared_kwargs,
        )
    elif supertask == "tatoeba":
        generate_and_write_preds_for_tatoeba(**shared_kwargs)
    else:
        raise KeyError(supertask)
Example no. 2
0
def main(args: RunConfiguration):
    """Train on a sample task container, then evaluate and display val metrics."""
    quick_init_out = initialization.quick_init(args=args, verbose=True)
    jiant_task_container = create_sample_jiant_task_container(working_dir=args.working_dir)
    runner = setup_runner(
        args=args,
        jiant_task_container=jiant_task_container,
        quick_init_out=quick_init_out,
        verbose=True,
    )
    runner.run_train()
    val_metrics = runner.run_val(jiant_task_container.task_run_config.val_task_list)
    # Collapse each task's result dict down to its plain metrics dict for display.
    metrics_by_task = {}
    for task_name, task_result_dict in val_metrics.items():
        metrics_by_task[task_name] = task_result_dict["metrics"].to_dict()
    show_json(metrics_by_task)
Example no. 3
0
def _evaluate_and_write_val(runner, args: RunConfiguration):
    """Run validation on the configured val tasks and write aggregated results.

    Also dumps raw predictions to ``val_preds.p`` in ``args.output_dir`` when
    ``args.write_val_preds`` is set. Extracted because the original body
    repeated this block verbatim before and after training.
    """
    val_results_dict = runner.run_val(
        task_name_list=runner.jiant_task_container.task_run_config.val_task_list,
        return_preds=args.write_val_preds,
    )
    jiant_evaluate.write_val_results(
        val_results_dict=val_results_dict,
        metrics_aggregator=runner.jiant_task_container.metrics_aggregator,
        output_dir=args.output_dir,
        verbose=True,
    )
    if args.write_val_preds:
        jiant_evaluate.write_preds(
            eval_results_dict=val_results_dict,
            path=os.path.join(args.output_dir, "val_preds.p"),
        )


def run_loop(args: RunConfiguration, checkpoint=None, tasks=None):
    """Main train/eval loop.

    Optionally evaluates before training (``args.do_val``), trains via the
    metarunner (``args.do_train``) followed by a ``run_perturb`` pass, then
    evaluates again and optionally writes test predictions. Writes a
    ``done_file`` marker on completion.

    Args:
        args: run configuration (paths, do_train/do_val flags, etc.).
        checkpoint: optional dict containing "runner_state" and
            "metarunner_state" entries used to resume a previous run;
            ``None`` starts fresh.
        tasks: forwarded to ``runner.run_perturb`` after training —
            semantics depend on the runner implementation (TODO confirm).
    """
    is_resumed = checkpoint is not None
    quick_init_out = initialization.quick_init(args=args, verbose=True)
    print(quick_init_out.n_gpu)
    with quick_init_out.log_writer.log_context():
        jiant_task_container = container_setup.create_jiant_task_container_from_json(
            jiant_task_container_config_path=args.jiant_task_container_config_path,
            verbose=True,
        )
        runner = setup_runner(
            args=args,
            jiant_task_container=jiant_task_container,
            quick_init_out=quick_init_out,
            verbose=True,
        )
        if is_resumed:
            runner.load_state(checkpoint["runner_state"])
            # Drop the (potentially large) state dict once consumed.
            del checkpoint["runner_state"]
        checkpoint_saver = jiant_runner.CheckpointSaver(
            metadata={"args": args.to_dict()},
            save_path=os.path.join(args.output_dir, "checkpoint.p"),
        )

        if args.do_val:
            # Baseline evaluation before any training happens.
            print("EVAL_BEFORE________________________________")
            _evaluate_and_write_val(runner=runner, args=args)

        if args.do_train:
            metarunner = jiant_metarunner.JiantMetarunner(
                runner=runner,
                save_every_steps=args.save_every_steps,
                eval_every_steps=args.eval_every_steps,
                save_checkpoint_every_steps=args.save_checkpoint_every_steps,
                no_improvements_for_n_evals=args.no_improvements_for_n_evals,
                checkpoint_saver=checkpoint_saver,
                output_dir=args.output_dir,
                verbose=True,
                save_best_model=True,
                save_last_model=args.do_save or args.do_save_last,
                load_best_model=True,
                log_writer=quick_init_out.log_writer,
            )
            if is_resumed:
                metarunner.load_state(checkpoint["metarunner_state"])
                del checkpoint["metarunner_state"]

            metarunner.run_train_loop()
            runner.run_perturb(tasks, metarunner)

        if args.do_val:
            # Post-training evaluation (same procedure as the baseline eval).
            _evaluate_and_write_val(runner=runner, args=args)
        else:
            # Val predictions can only be produced by a val run.
            assert not args.write_val_preds

        if args.write_test_preds:
            test_results_dict = runner.run_test(
                task_name_list=runner.jiant_task_container.task_run_config.test_task_list,
            )
            jiant_evaluate.write_preds(
                eval_results_dict=test_results_dict,
                path=os.path.join(args.output_dir, "test_preds.p"),
            )
        print("EVAL DONE------------------------")

    # Clean up the periodic checkpoint unless the user asked to keep it.
    if (not args.keep_checkpoint_when_done and args.save_checkpoint_every_steps
            and os.path.exists(os.path.join(args.output_dir, "checkpoint.p"))):
        os.remove(os.path.join(args.output_dir, "checkpoint.p"))

    py_io.write_file("DONE", os.path.join(args.output_dir, "done_file"))
Example no. 4
0
def run_loop(args: RunConfiguration, checkpoint=None):
    """Train/eval loop with customizable checkpoint and best-model file names.

    Trains via the metarunner (``args.do_train``), optionally saves the model
    (``args.do_save``), evaluates (``args.do_val``), and optionally writes
    test predictions. The periodic checkpoint is deleted afterwards when
    ``args.delete_checkpoint_if_done`` is set.

    Args:
        args: run configuration (paths, flags, custom file names, etc.).
        checkpoint: optional dict containing "runner_state" and
            "metarunner_state" entries used to resume a previous run;
            ``None`` starts fresh.
    """
    is_resumed = checkpoint is not None
    quick_init_out = initialization.quick_init(args=args, verbose=True)
    print(quick_init_out.n_gpu)
    with quick_init_out.log_writer.log_context():
        jiant_task_container = container_setup.create_jiant_task_container_from_json(
            jiant_task_container_config_path=args.jiant_task_container_config_path,
            verbose=True,
        )
        runner = setup_runner(
            args=args,
            jiant_task_container=jiant_task_container,
            quick_init_out=quick_init_out,
            verbose=True,
        )
        if is_resumed:
            runner.load_state(checkpoint["runner_state"])
            # Drop the (potentially large) state dict once consumed.
            del checkpoint["runner_state"]

        # Allow a custom checkpoint file name.
        # BUG FIX: the original joined args.output_dir into checkpoint_name and
        # then joined args.output_dir AGAIN for save_path and os.remove, which
        # produced "out/out/checkpoint.p" for relative output dirs. Join once.
        if args.custom_checkpoint_name:
            checkpoint_fname = f"{args.custom_checkpoint_name}.p"
        else:
            checkpoint_fname = "checkpoint.p"
        checkpoint_path = os.path.join(args.output_dir, checkpoint_fname)

        checkpoint_saver = jiant_runner.CheckpointSaver(
            metadata={"args": args.to_dict()},
            save_path=checkpoint_path,
        )
        if args.do_train:
            metarunner = jiant_metarunner.JiantMetarunner(
                runner=runner,
                save_every_steps=args.save_every_steps,
                eval_every_steps=args.eval_every_steps,
                save_checkpoint_every_steps=args.save_checkpoint_every_steps,
                no_improvements_for_n_evals=args.no_improvements_for_n_evals,
                checkpoint_saver=checkpoint_saver,
                output_dir=args.output_dir,
                verbose=True,
                save_best_model=args.do_save,
                load_best_model=True,
                log_writer=quick_init_out.log_writer,
            )
            if is_resumed:
                metarunner.load_state(checkpoint["metarunner_state"])
                del checkpoint["metarunner_state"]
            metarunner.run_train_loop()

        if args.do_save:
            # Allow a custom best-model file name.
            if args.custom_best_name:
                best_model_path = os.path.join(args.output_dir,
                                               f"{args.custom_best_name}.p")
            else:
                best_model_path = os.path.join(args.output_dir, "model.p")

            torch.save(
                torch_utils.get_model_for_saving(runner.jiant_model).state_dict(),
                best_model_path,
            )

        if args.do_val:
            val_results_dict = runner.run_val(
                task_name_list=runner.jiant_task_container.task_run_config.val_task_list,
                return_preds=args.write_val_preds,
            )
            jiant_evaluate.write_val_results(
                val_results_dict=val_results_dict,
                metrics_aggregator=runner.jiant_task_container.metrics_aggregator,
                output_dir=args.output_dir,
                verbose=True,
                val_jsonl=args.val_jsonl,
            )

            if args.args_jsonl:
                # match arguments with verbose results
                initialization.save_args(args, verbose=True, matched=True)

            if args.write_val_preds:
                if args.extract_exp_name_valpreds:
                    # Derive an experiment name from the config file stem,
                    # e.g. "configs/exp1.json" -> "val_preds_exp1.p".
                    exp_name = os.path.basename(
                        args.jiant_task_container_config_path).split(".")[0]
                    val_fname = f"val_preds_{exp_name}.p"
                else:
                    val_fname = "val_preds.p"
                jiant_evaluate.write_preds(
                    eval_results_dict=val_results_dict,
                    path=os.path.join(args.output_dir, val_fname),
                )
        else:
            # Val predictions can only be produced by a val run.
            assert not args.write_val_preds

        if args.write_test_preds:
            test_results_dict = runner.run_test(
                task_name_list=runner.jiant_task_container.task_run_config.test_task_list,
            )
            jiant_evaluate.write_preds(
                eval_results_dict=test_results_dict,
                path=os.path.join(args.output_dir, "test_preds.p"),
            )

    if args.delete_checkpoint_if_done and args.save_checkpoint_every_steps:
        # Guard against the file never having been written (mirrors the
        # existence check used by the other run loops in this file).
        if os.path.exists(checkpoint_path):
            os.remove(checkpoint_path)
Example no. 5
0
def run_loop(args: RunConfiguration, checkpoint=None):
    """Adapter-based train/eval loop.

    Same structure as the plain run loop, but uses the adapters metarunner
    and saves only the adapter-optimized subset of the model state dict.

    Args:
        args: run configuration (paths, flags, etc.).
        checkpoint: optional dict containing "runner_state" and
            "metarunner_state" entries used to resume a previous run;
            ``None`` starts fresh.
    """
    is_resumed = checkpoint is not None
    quick_init_out = initialization.quick_init(args=args, verbose=True)
    print(quick_init_out.n_gpu)
    # Compute the checkpoint path once; it is used for saving and cleanup.
    checkpoint_path = os.path.join(args.output_dir, "checkpoint.p")
    with quick_init_out.log_writer.log_context():
        jiant_task_container = container_setup.create_jiant_task_container_from_json(
            jiant_task_container_config_path=args.jiant_task_container_config_path,
            verbose=True,
        )
        runner = setup_runner(
            args=args,
            jiant_task_container=jiant_task_container,
            quick_init_out=quick_init_out,
            verbose=True,
        )
        if is_resumed:
            runner.load_state(checkpoint["runner_state"])
            # Drop the (potentially large) state dict once consumed.
            del checkpoint["runner_state"]
        checkpoint_saver = jiant_runner.CheckpointSaver(
            metadata={"args": args.to_dict()},
            save_path=checkpoint_path,
        )
        if args.do_train:
            metarunner = adapters_metarunner.AdaptersMetarunner(
                runner=runner,
                save_every_steps=args.save_every_steps,
                eval_every_steps=args.eval_every_steps,
                save_checkpoint_every_steps=args.save_checkpoint_every_steps,
                no_improvements_for_n_evals=args.no_improvements_for_n_evals,
                checkpoint_saver=checkpoint_saver,
                output_dir=args.output_dir,
                verbose=True,
                save_best_model=args.do_save,
                load_best_model=True,
                log_writer=quick_init_out.log_writer,
            )
            if is_resumed:
                metarunner.load_state(checkpoint["metarunner_state"])
                del checkpoint["metarunner_state"]
            metarunner.run_train_loop()

        if args.do_save:
            # Save only the adapter-relevant parameters, not the full model.
            torch.save(
                adapters_modeling.get_optimized_state_dict_for_jiant_model_with_adapters(
                    torch_utils.get_model_for_saving(runner.jiant_model),
                ),
                os.path.join(args.output_dir, "model.p"),
            )

        if args.do_val:
            val_results_dict = runner.run_val(
                task_name_list=runner.jiant_task_container.task_run_config.val_task_list,
                return_preds=args.write_val_preds,
            )
            jiant_evaluate.write_val_results(
                val_results_dict=val_results_dict,
                metrics_aggregator=runner.jiant_task_container.metrics_aggregator,
                output_dir=args.output_dir,
                verbose=True,
            )
            if args.write_val_preds:
                jiant_evaluate.write_preds(
                    eval_results_dict=val_results_dict,
                    path=os.path.join(args.output_dir, "val_preds.p"),
                )
        else:
            # Val predictions can only be produced by a val run.
            assert not args.write_val_preds

        if args.write_test_preds:
            test_results_dict = runner.run_test(
                task_name_list=runner.jiant_task_container.task_run_config.test_task_list,
            )
            jiant_evaluate.write_preds(
                eval_results_dict=test_results_dict,
                path=os.path.join(args.output_dir, "test_preds.p"),
            )

    if args.delete_checkpoint_if_done and args.save_checkpoint_every_steps:
        # ROBUSTNESS FIX: the original called os.remove unconditionally and
        # raised FileNotFoundError when no checkpoint had been written (e.g.
        # do_train=False). Guard with an existence check, matching the
        # guarded cleanup used elsewhere in this file.
        if os.path.exists(checkpoint_path):
            os.remove(checkpoint_path)