Example 1
def model_ensemble(args):
    """ 직접 모델과 전략 입력해주시면 됩니다! """

    MODELS = [
        ("../input/model_ensemble_checkpoint/gunmo/RD_G04_C01_KOELECTRA_BASE_V3_FINETUNED_95/checkpoint-6000/", None),
        ("../input/model_ensemble_checkpoint/suyeon/KOELECTRA_FINETUNED_TRAIN_KOELECTRA_FINETUNED_95/checkpoint-5400/", None),
        ("../input/model_ensemble_checkpoint/suyeon/ST05_AtireBM25_95/checkpoint-5000/", None),
        ("../input/model_ensemble_checkpoint/jonghun/ST101_CNN_95/checkpoint-15100/", "ST101"),
        ("../input/model_ensemble_checkpoint/jonghun/ST103_CNN_LSTM_95/checkpoint-5500/", "ST103"),
        ("../input/model_ensemble_checkpoint/jonghun/ST104_CCNN_v2_95/checkpoint-15100/", "ST104"),
        ("../input/model_ensemble_checkpoint/jonghun/ST106_LSTM_95/checkpoint-1500/", "ST106"),
    ]

    args.retriever.topk = TOPK
    args.data.max_answer_length = MAX_ANSWER_LENGTH
    args.retriever.model_name = "ATIREBM25_DPRBERT"
    args.train.do_predict = True

    datasets = get_dataset(args, is_train=False)
    retriever = get_retriever(args)

    eval_answers = datasets["validation"]
    datasets["validation"] = retriever.retrieve(
        datasets["validation"], topk=args.retriever.topk)["validation"]

    run(args, MODELS, eval_answers, datasets)
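Note: run, get_dataset, and get_retriever are defined elsewhere in this repository. As a reference for what combining these checkpoints can look like, here is a minimal, hypothetical sketch of logit-averaging across extractive readers; ensemble_span_logits and its inputs are illustrative names, not the project's actual API.

import numpy as np

def ensemble_span_logits(per_model_start, per_model_end, max_answer_length=30):
    """Hypothetical illustration: average start/end logits from several
    extractive readers, then pick the best (start, end) span per example."""
    start = np.mean(np.stack(per_model_start), axis=0)  # (n_examples, seq_len)
    end = np.mean(np.stack(per_model_end), axis=0)
    spans = []
    for s, e in zip(start, end):
        candidates = (
            (i, j)
            for i in range(len(s))
            for j in range(i, min(i + max_answer_length, len(e)))
        )
        spans.append(max(candidates, key=lambda ij: s[ij[0]] + e[ij[1]]))
    return spans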
Example 2
def predict(args):
    # Don't use wandb

    strategies = args.strategies

    for idx, strategy in enumerate(strategies):
        args = update_args(args, strategy)
        args.strategy = strategy

        checkpoint_dir = glob(p.join(args.path.checkpoint, f"{strategy}*"))
        if not checkpoint_dir:
            raise FileNotFoundError(f"No checkpoint exists for strategy {strategy}.")

        args.model.model_path = get_last_checkpoint(checkpoint_dir[0])
        if args.model.model_path is None:
            raise FileNotFoundError(f"No checkpoint exists under {checkpoint_dir[0]}.")

        args.train.output_dir = p.join(args.path.checkpoint, strategy)
        args.train.do_predict = True

        datasets = get_dataset(args, is_train=False)
        reader = get_reader(args, eval_answers=datasets["validation"])
        retriever = get_retriever(args)

        datasets["validation"] = retriever.retrieve(
            datasets["validation"], topk=args.retriever.topk)["validation"]
        reader.set_dataset(eval_dataset=datasets["validation"])

        trainer = reader.get_trainer()

        # pororo_predict is used when args.train.pororo_predictions=True
        trainer.predict(test_dataset=reader.eval_dataset,
                        test_examples=datasets["validation"])
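get_last_checkpoint (used above) is presumably the Hugging Face transformers.trainer_utils helper, which returns the checkpoint-<step> subdirectory with the highest step number. A minimal re-implementation sketch, for reference only:

import os
import re

def last_checkpoint(folder):
    # Sketch: return the checkpoint-<step> subdirectory with the largest step.
    pattern = re.compile(r"^checkpoint-(\d+)$")
    steps = [
        (int(m.group(1)), name)
        for name in os.listdir(folder)
        if (m := pattern.match(name)) and os.path.isdir(os.path.join(folder, name))
    ]
    return os.path.join(folder, max(steps)[1]) if steps else None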
Example 3
def train_retriever(args):
    strategies = args.strategies
    seeds = args.seeds[: args.run_cnt]

    topk_result = defaultdict(list)
    wandb.init(project="p-stage-3", reinit=True)
    wandb.run.name = "COMPARE RETRIEVER"

    for idx, (seed, strategy) in enumerate(product(seeds, strategies)):
        args = update_args(args, strategy)
        set_seed(seed)

        datasets = get_dataset(args, is_train=True)

        retriever = get_retriever(args)
        valid_datasets = retriever.retrieve(datasets["validation"], topk=args.retriever.topk)

        print(f"전략: {strategy} RETRIEVER: {args.model.retriever_name}")
        legend_name = "_".join([strategy, args.model.retriever_name])
        topk = args.retriever.topk

        cur_cnt, tot_cnt = 0, len(datasets["validation"])

        indexes = np.arange(tot_cnt * topk)
        print("total_cnt:", tot_cnt)
        print("valid_datasets:", valid_datasets)

        qc_dict = defaultdict(bool)
        for k in range(topk):
            fancy_index = indexes[k::topk]
            topk_dataset = valid_datasets["validation"][fancy_index]

            for question, real, pred in zip(
                topk_dataset["question"], topk_dataset["original_context"], topk_dataset["context"]
            ):
                # count a question once, the first time a retrieved passage
                # overlaps the gold passage by more than 85% (fuzz.ratio > 85)
                if fuzz.ratio(real, pred) > 85 and not qc_dict[question]:
                    qc_dict[question] = True
                    cur_cnt += 1

            topk_acc = cur_cnt / tot_cnt
            topk_result[legend_name].append(topk_acc)
            print(f"TOPK: {k + 1} ACC: {topk_acc * 100:.2f}")

    fig = get_topk_fig(args, topk_result)
    wandb.log({"retriever topk result": wandb.Image(fig)})

    if args.report is True:
        report_retriever_to_slack(args, fig)
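For reference, the loop above counts a question as solved at rank k if any of its top-k retrieved passages matches the gold passage with fuzz.ratio above 85. A self-contained sketch of the same cumulative top-k accuracy (assuming rapidfuzz; fuzzywuzzy's fuzz.ratio behaves the same):

from rapidfuzz import fuzz

def topk_accuracy(gold_contexts, retrieved, threshold=85):
    """retrieved[i] is the ranked passage list for question i.
    Returns cumulative accuracy at each rank, as in the loop above."""
    hits = [False] * len(gold_contexts)
    accs = []
    for k in range(max(len(r) for r in retrieved)):
        for i, gold in enumerate(gold_contexts):
            if not hits[i] and k < len(retrieved[i]) and fuzz.ratio(gold, retrieved[i][k]) > threshold:
                hits[i] = True
        accs.append(sum(hits) / len(hits))
    return accs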
Example 4
def get_train_dataset(args):
    args.data.dataset_name = "train_dataset"
    datasets = get_dataset(args, is_train=True)
    datasets = concatenate_datasets(
        [datasets["train"], datasets["validation"]])
    return datasets
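concatenate_datasets here is presumably the Hugging Face datasets helper, which appends rows from splits that share the same schema. A minimal usage sketch:

from datasets import Dataset, concatenate_datasets

train = Dataset.from_dict({"question": ["q1"], "context": ["c1"]})
valid = Dataset.from_dict({"question": ["q2"], "context": ["c2"]})
merged = concatenate_datasets([train, valid])
print(len(merged))  # 2 -- the validation rows are appended after the train rows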
Example 5
def train_reader(args):
    strategies = args.strategies
    seeds = args.seeds[:args.run_cnt]

    for idx, (seed, strategy) in enumerate(product(seeds, strategies)):
        wandb.init(project="p-stage-3", reinit=True)
        args = update_args(args, strategy)  # automatically adds args.save_path, args.base_path
        args.strategy, args.seed = strategy, seed
        args.info = Namespace()
        set_seed(seed)

        checkpoint_dir = glob(p.join(args.path.checkpoint, f"{strategy}*"))
        if not checkpoint_dir:
            raise FileNotFoundError(f"No checkpoint exists for strategy {strategy}.")

        args.model.model_path = get_last_checkpoint(checkpoint_dir[0])
        if args.model.model_path is None:
            raise FileNotFoundError(f"No checkpoint exists under {checkpoint_dir[0]}.")

        # run_name: strategy + alias + seed
        args.train.run_name = "_".join([strategy, args.alias, str(seed)])
        args.train.output_dir = p.join(args.path.checkpoint,
                                       args.train.run_name)
        wandb.run.name = args.train.run_name

        print("checkpoint_dir: ", args.train.output_dir)

        datasets = get_dataset(args, is_train=True)
        reader = get_reader(args, eval_answers=datasets["validation"])
        retriever = get_retriever(args)

        datasets["validation"] = retriever.retrieve(
            datasets["validation"], topk=args.retriever.topk)["validation"]
        reader.set_dataset(eval_dataset=datasets["validation"])

        trainer = reader.get_trainer()

        if args.train.do_eval:
            metric_results = trainer.evaluate()
            results = evaluation(args)

            metric_results["predictions"]["exact_match"] = results["EM"][
                "value"]
            metric_results["predictions"]["f1"] = results["F1"]["value"]

            print("EVAL RESULT")
            print(metric_results["predictions"])

        if args.train.do_eval and args.train.pororo_prediction:
            assert metric_results is not None, "trainer.evaluate() returned None."

            results = evaluation(args, prefix="pororo_")

            metric_results["pororo_predictions"]["exact_match"] = results[
                "EM"]["value"]
            metric_results["pororo_predictions"]["f1"] = results["F1"]["value"]

            print("PORORO EVAL RESULT")
            print(metric_results["pororo_predictions"])

        if args.train.do_eval and args.report:
            report_reader_to_slack(args,
                                   p.basename(__file__),
                                   metric_results["predictions"],
                                   use_pororo=False)

            if args.train.pororo_prediction:
                report_reader_to_slack(args,
                                       p.basename(__file__),
                                       metric_results["pororo_predictions"],
                                       use_pororo=True)
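evaluation is defined elsewhere; the EM and F1 values it returns are presumably the standard SQuAD-style metrics. For reference, a minimal sketch of how these are commonly computed (hypothetical helpers, not this project's code):

from collections import Counter

def exact_match(pred, gold):
    # 1.0 if the stripped strings match exactly, else 0.0
    return float(pred.strip() == gold.strip())

def f1(pred, gold):
    # token-level overlap F1, as in the SQuAD evaluation script
    pred_toks, gold_toks = pred.split(), gold.split()
    overlap = sum((Counter(pred_toks) & Counter(gold_toks)).values())
    if overlap == 0:
        return 0.0
    precision = overlap / len(pred_toks)
    recall = overlap / len(gold_toks)
    return 2 * precision * recall / (precision + recall)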
Example 6
def train_reader(args):
    strategies = args.strategies
    seeds = args.seeds[:args.run_cnt]

    for idx, (seed, strategy) in enumerate(product(seeds, strategies)):
        wandb.init(project="p-stage-3", reinit=True)
        args = update_args(args, strategy)  # automatically adds args.save_path, args.base_path
        args.strategy, args.seed = strategy, seed
        args.info = Namespace()
        set_seed(seed)

        # the code below must run before 'reader.get_trainer()'
        # run_name: strategy + alias + seed
        args.train.run_name = "_".join([strategy, args.alias, str(seed)])
        args.train.output_dir = p.join(args.path.checkpoint,
                                       args.train.run_name)
        wandb.run.name = args.train.run_name

        print("checkpoint_dir: ", args.train.output_dir)

        # run_mrc.py does not execute retrieval, so args.retriever.topk cannot be n (> 1);
        # only the top-1 passage is available. If topk > 1, the post-processing
        # function returns mis-bundled predictions.
        args.retriever.topk = 1

        datasets = get_dataset(args, is_train=True)
        reader = get_reader(args, eval_answers=datasets["validation"])

        reader.set_dataset(train_dataset=datasets["train"],
                           eval_dataset=datasets["validation"])

        trainer = reader.get_trainer()

        if args.train.do_train:
            train_results = trainer.train()
            print(train_results)

        metric_results = None

        if args.train.do_eval:
            metric_results = trainer.evaluate()
            results = evaluation(args)

            metric_results["predictions"]["exact_match"] = results["EM"][
                "value"]
            metric_results["predictions"]["f1"] = results["F1"]["value"]

            print("EVAL RESULT")
            print(metric_results["predictions"])

        if args.train.do_eval and args.train.pororo_prediction:
            assert metric_results is not None, "trainer.evaluate() returned None."

            results = evaluation(args, prefix="pororo_")

            metric_results["pororo_predictions"]["exact_match"] = results[
                "EM"]["value"]
            metric_results["pororo_predictions"]["f1"] = results["F1"]["value"]

            print("PORORO EVAL RESULT")
            print(metric_results["pororo_predictions"])

        if args.train.do_eval and args.report:
            report_reader_to_slack(args,
                                   p.basename(__file__),
                                   metric_results["predictions"],
                                   use_pororo=False)

            if args.train.pororo_prediction:
                report_reader_to_slack(args,
                                       p.basename(__file__),
                                       metric_results["pororo_predictions"],
                                       use_pororo=True)
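set_seed is imported from elsewhere; to make each (seed, strategy) run reproducible it presumably seeds every RNG in play. A typical implementation sketch, assuming PyTorch:

import random
import numpy as np
import torch

def set_seed(seed: int):
    # Seed every RNG the training loop may touch.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)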