Code Example #1
def predict(args):
    # Don't use wandb

    strategies = args.strategies

    for idx, strategy in enumerate(strategies):
        args = update_args(args, strategy)
        args.strategy = strategy

        checkpoint_dir = glob(p.join(args.path.checkpoint, f"{strategy}*"))
        if not checkpoint_dir:
            raise FileNotFoundError(
                f"{strategy} 전략에 대한 checkpoint가 존재하지 않습니다.")

        args.model.model_path = get_last_checkpoint(checkpoint_dir[0])
        if args.model.model_path is None:
            raise FileNotFoundError(
                f"{checkpoint_dir[0]} 경로에 체크포인트가 존재하지 않습니다.")

        args.train.output_dir = p.join(args.path.checkpoint, strategy)
        args.train.do_predict = True

        datasets = get_dataset(args, is_train=False)
        reader = get_reader(args, eval_answers=datasets["validation"])
        retriever = get_retriever(args)

        datasets["validation"] = retriever.retrieve(
            datasets["validation"], topk=args.retriever.topk)["validation"]
        reader.set_dataset(eval_dataset=datasets["validation"])

        trainer = reader.get_trainer()

        # uses pororo_predict when args.train.pororo_prediction is True
        trainer.predict(test_dataset=reader.eval_dataset,
                        test_examples=datasets["validation"])
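
These snippets are excerpts from a larger project, so project-internal helpers such as update_args, get_dataset, get_reader, and get_retriever are assumed to come from the repo's own modules. The remaining names resolve to the standard library and Hugging Face; a minimal sketch of the assumed imports:

import os.path as p
from glob import glob

from transformers.trainer_utils import get_last_checkpoint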
Code Example #2
    def test_valid_model(self, args=args):
        for seed, strategy in [(SEED, strategy) for strategy in strategies]:
            args = update_args(args, strategy)
            args.strategy, args.seed = strategy, seed
            set_seed(seed)
            try:
                get_reader_model(args)
            except Exception:
                assert False, "Model does not exist on Hugging Face, or the path is invalid."
Code Example #3
    def test_valid_dataset(self, args=args):
        for seed, strategy in [(SEED, strategy) for strategy in strategies]:
            args = update_args(args, strategy)
            args.strategy, args.seed = strategy, seed
            set_seed(seed)
            try:
                prepare_dataset(args, is_train=True)
            except KeyError:
                assert False, "Dataset does not exist."
Code Example #4
def run(args, models, eval_answers, datasets):
    """Ensemble을 수행합니다.
    1. Soft Voting Use Offset
    2. Soft Voting Use Span
    3. Hard Voting Use Offset
    """

    soft_offset_predictions = defaultdict(dict)
    soft_span_predictions = defaultdict(dict)
    hard_offset_predictions = defaultdict(dict)

    for model_path, strategy in models:
        args.model_name_or_path = model_path
        args.model.reader_name = "DPR"

        if strategy is not None:
            args = update_args(args, strategy)

        args.retriever.topk = TOPK

        reader = get_reader(args, eval_answers=eval_answers)
        reader.set_dataset(eval_dataset=datasets["validation"])

        trainer = reader.get_trainer()

        logit_list, (contexts, document_ids,
                     question_ids) = trainer.get_logits_with_keys(
                         reader.eval_dataset,
                         datasets["validation"],
                         keys=["context", "context_id", "id"])

        # standardize each model's logits into a common -1 ~ 1 range
        logit_list = logit_list_standardization(logit_list)

        soft_voting_use_offset(soft_offset_predictions, logit_list, contexts,
                               document_ids, question_ids)
        hard_voting_use_offset(hard_offset_predictions, logit_list, contexts,
                               document_ids, question_ids)
        soft_voting_use_span(soft_span_predictions, logit_list, contexts,
                             document_ids, question_ids)

    offset_postprocess(soft_offset_predictions)
    offset_postprocess(hard_offset_predictions)
    span_postprocess(soft_span_predictions)

    filename = "soft_offset_predictions.json"
    save_offset_ensemble(args, soft_offset_predictions, filename)

    filename = "hard_offset_predictions.json"
    save_offset_ensemble(args, hard_offset_predictions, filename)

    filename = "soft_span_predictions.json"
    save_span_ensemble(args, soft_span_predictions, filename)
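
The "Logit Standardization, -1 ~ 1" step rescales each model's logits into a common [-1, 1] range so that a model with unusually large logits cannot dominate the soft vote. The repo's logit_list_standardization is not shown here; below is a minimal sketch of one plausible implementation, assuming simple per-array min-max scaling (standardize_logits is a hypothetical stand-in):

import numpy as np

def standardize_logits(logits: np.ndarray) -> np.ndarray:
    # Hypothetical stand-in for logit_list_standardization:
    # min-max scale a logit array into [-1, 1].
    lo, hi = logits.min(), logits.max()
    if hi == lo:  # guard against a constant array
        return np.zeros_like(logits)
    return 2 * (logits - lo) / (hi - lo) - 1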
Code Example #5
def train_retriever(args):
    strategies = args.strategies
    seeds = args.seeds[: args.run_cnt]

    topk_result = defaultdict(list)
    wandb.init(project="p-stage-3", reinit=True)
    wandb.run.name = "COMPARE RETRIEVER"

    for idx, (seed, strategy) in enumerate(product(seeds, strategies)):
        args = update_args(args, strategy)
        set_seed(seed)

        datasets = get_dataset(args, is_train=True)

        retriever = get_retriever(args)
        valid_datasets = retriever.retrieve(datasets["validation"], topk=args.retriever.topk)

        print(f"전략: {strategy} RETRIEVER: {args.model.retriever_name}")
        legend_name = "_".join([strategy, args.model.retriever_name])
        topk = args.retriever.topk

        cur_cnt, tot_cnt = 0, len(datasets["validation"])

        indexes = np.arange(tot_cnt * topk)
        print("total_cnt:", tot_cnt)
        print("valid_datasets:", valid_datasets)

        qc_dict = defaultdict(bool)
        for rank in range(topk):
            # rows rank, rank + topk, ...: the rank-(rank + 1) retrieved context of every question
            fancy_index = indexes[rank::topk]
            topk_dataset = valid_datasets["validation"][fancy_index]

            for question, real, pred in zip(
                topk_dataset["question"], topk_dataset["original_context"], topk_dataset["context"]
            ):
                # count the question as answered if the retrieved context matches the gold context (fuzzy ratio > 85)
                if fuzz.ratio(real, pred) > 85 and not qc_dict[question]:
                    qc_dict[question] = True
                    cur_cnt += 1

            topk_acc = cur_cnt / tot_cnt
            topk_result[legend_name].append(topk_acc)
            print(f"TOPK: {idx + 1} ACC: {topk_acc * 100:.2f}")

    fig = get_topk_fig(args, topk_result)
    wandb.log({"retriever topk result": wandb.Image(fig)})

    if args.report is True:
        report_retriever_to_slack(args, fig)
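
The rank-wise slicing in the loop above assumes the retriever flattens its results question by question, so question j's rank-i context sits at flat position j * topk + i. A minimal sketch (with hypothetical sizes) of how indexes[rank::topk] then selects one retrieval rank across all questions:

import numpy as np

topk, tot_cnt = 3, 4                 # 3 contexts per question, 4 questions
indexes = np.arange(tot_cnt * topk)  # flat positions 0..11
for rank in range(topk):
    # rank 0 -> [0 3 6 9], rank 1 -> [1 4 7 10], rank 2 -> [2 5 8 11]
    print(rank, indexes[rank::topk])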
Code Example #6
    def test_strategies_with_dataset(self, args=args):
        """
        (Constraint)
            - num_train_epoch 1
            - random seed 1
            - dataset fragment (rows : 100)
        (Caution)
            ERROR가 표시된다면, 상위 단위 테스트 결과를 확인하세요.
        """
        for seed, strategy in [(SEED, strategy) for strategy in strategies]:
            wandb.init(project="p-stage-3-test", reinit=True)
            args = update_args(args, strategy)
            args.strategy, args.seed = strategy, seed
            set_seed(seed)

            datasets = prepare_dataset(args, is_train=True)
            model, tokenizer = get_reader_model(args)
            train_dataset, post_processing_function = preprocess_dataset(
                args, datasets, tokenizer, is_train=True)

            train_dataset = train_dataset.select(range(100))  # select 100

            data_collator = DataCollatorWithPadding(
                tokenizer, pad_to_multiple_of=8 if args.train.fp16 else None)

            args.train.do_train = True
            args.train.run_name = "_".join(
                [strategy, args.alias, str(seed), "test"])
            wandb.run.name = args.train.run_name

            # TRAIN MRC
            args.train.num_train_epochs = 1.0  # fix epoch 1
            trainer = QuestionAnsweringTrainer(
                model=model,
                args=args.train,  # training_args
                custom_args=args,
                train_dataset=train_dataset,
                tokenizer=tokenizer,
                data_collator=data_collator,
                post_process_function=post_processing_function,
                compute_metrics=compute_metrics,
            )

            trainer.train()
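
A note on the collator above: pad_to_multiple_of=8 rounds each padded batch's sequence length up to a multiple of 8, the shape NVIDIA tensor cores handle most efficiently under fp16; when fp16 is off it stays None, so no extra padding is added.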
Code Example #7
def train_reader(args):
    strategies = args.strategies
    seeds = args.seeds[:args.run_cnt]

    for idx, (seed, strategy) in enumerate(product(seeds, strategies)):
        wandb.init(project="p-stage-3", reinit=True)
        args = update_args(args,
                           strategy)  # auto add args.save_path, args.base_path
        args.strategy, args.seed = strategy, seed
        args.info = Namespace()
        set_seed(seed)

        checkpoint_dir = glob(p.join(args.path.checkpoint, f"{strategy}*"))
        if not checkpoint_dir:
            raise FileNotFoundError(
                f"{strategy} 전략에 대한 checkpoint가 존재하지 않습니다.")

        args.model.model_path = get_last_checkpoint(checkpoint_dir[0])
        if args.model.model_path is None:
            raise FileNotFoundError(
                f"{checkpoint_dir[0]} 경로에 체크포인트가 존재하지 않습니다.")

        # run_name: strategy + alias + seed
        args.train.run_name = "_".join([strategy, args.alias, str(seed)])
        args.train.output_dir = p.join(args.path.checkpoint,
                                       args.train.run_name)
        wandb.run.name = args.train.run_name

        print("checkpoint_dir: ", args.train.output_dir)

        datasets = get_dataset(args, is_train=True)
        reader = get_reader(args, eval_answers=datasets["validation"])
        retriever = get_retriever(args)

        datasets["validation"] = retriever.retrieve(
            datasets["validation"], topk=args.retriever.topk)["validation"]
        reader.set_dataset(eval_dataset=datasets["validation"])

        trainer = reader.get_trainer()

        if args.train.do_eval:
            metric_results = trainer.evaluate()
            results = evaluation(args)

            metric_results["predictions"]["exact_match"] = results["EM"][
                "value"]
            metric_results["predictions"]["f1"] = results["F1"]["value"]

            print("EVAL RESULT")
            print(metric_results["predictions"])

        if args.train.do_eval and args.train.pororo_prediction:
            assert metric_results is not None, "trainer.evaluate() returned None."

            results = evaluation(args, prefix="pororo_")

            metric_results["pororo_predictions"]["exact_match"] = results[
                "EM"]["value"]
            metric_results["pororo_predictions"]["f1"] = results["F1"]["value"]

            print("PORORO EVAL RESULT")
            print(metric_results["pororo_predictions"])

        if args.train.do_eval and args.report:
            report_reader_to_slack(args,
                                   p.basename(__file__),
                                   metric_results["predictions"],
                                   use_pororo=False)

            if args.train.pororo_prediction:
                report_reader_to_slack(args,
                                       p.basename(__file__),
                                       metric_results["pororo_predictions"],
                                       use_pororo=True)
Code Example #8
def train_reader(args):
    strategies = args.strategies
    seeds = args.seeds[:args.run_cnt]

    for idx, (seed, strategy) in enumerate(product(seeds, strategies)):
        wandb.init(project="p-stage-3", reinit=True)
        args = update_args(args,
                           strategy)  # auto add args.save_path, args.base_path
        args.strategy, args.seed = strategy, seed
        args.info = Namespace()
        set_seed(seed)

        # the code below must run before 'reader.get_trainer()'
        # run_name: strategy + alias + seed
        args.train.run_name = "_".join([strategy, args.alias, str(seed)])
        args.train.output_dir = p.join(args.path.checkpoint,
                                       args.train.run_name)
        wandb.run.name = args.train.run_name

        print("checkpoint_dir: ", args.train.output_dir)

        # run_mrc.py does not execute a retrieve step, so only the top-1 context is
        # available and args.retriever.topk cannot be n (> 1). If topk > 1, the
        # post-processing function returns mis-bundled predictions.
        args.retriever.topk = 1

        datasets = get_dataset(args, is_train=True)
        reader = get_reader(args, eval_answers=datasets["validation"])

        reader.set_dataset(train_dataset=datasets["train"],
                           eval_dataset=datasets["validation"])

        trainer = reader.get_trainer()

        if args.train.do_train:
            train_results = trainer.train()
            print(train_results)

        metric_results = None

        if args.train.do_eval:
            metric_results = trainer.evaluate()
            results = evaluation(args)

            metric_results["predictions"]["exact_match"] = results["EM"][
                "value"]
            metric_results["predictions"]["f1"] = results["F1"]["value"]

            print("EVAL RESULT")
            print(metric_results["predictions"])

        if args.train.do_eval and args.train.pororo_prediction:
            assert metric_results is not None, "trainer.evaluate() returned None."

            results = evaluation(args, prefix="pororo_")

            metric_results["pororo_predictions"]["exact_match"] = results[
                "EM"]["value"]
            metric_results["pororo_predictions"]["f1"] = results["F1"]["value"]

            print("PORORO EVAL RESULT")
            print(metric_results["pororo_predictions"])

        if args.train.do_eval and args.report:
            report_reader_to_slack(args,
                                   p.basename(__file__),
                                   metric_results["predictions"],
                                   use_pororo=False)

            if args.train.pororo_prediction:
                report_reader_to_slack(args,
                                       p.basename(__file__),
                                       metric_results["pororo_predictions"],
                                       use_pororo=True)
Code Example #9
    def test_valid_strategy(self, args=args):
        for strategy in strategies:
            try:
                update_args(args, strategy)
            except FileNotFoundError:
                assert False, "Please check that the strategy name is correct."