Ejemplo n.º 1
0
def qa_rank2records_in_batch():
    """
        Dump QA retrieval records for a whole sweep of filter values.

        For every value in the sweep a new directory is created under
        ``path_parser.summary_rank`` and one record file per cid (taken from
        the module-level ``cids``) is dumped into it.

    :return: None
    :raises ValueError: if the record directory for a value already exists.
    """
    step = 10
    if qa_config.FILTER == 'conf':
        # values starting at 0.05 in steps of 0.05, stopping before 1.05
        sweep = np.arange(0.05, 1.05, 0.05)
    elif qa_config.ir_config.FILTER == 'topK':
        # multiples of 10 below the IR stage's FILTER_VAR + 10
        sweep = range(step, qa_config.ir_config.FILTER_VAR + step, step)
    else:
        sweep = range(40, 150 + step, step)

    for filter_var in tqdm(sweep):
        dir_name = qa_config.QA_RECORD_DIR_NAME_PATTERN.format(
            qa_config.QA_MODEL_NAME_BERT, filter_var, qa_config.FILTER)
        out_dp = join(path_parser.summary_rank, dir_name)

        # refuse to reuse a directory left over from an earlier run
        if exists(out_dp):
            raise ValueError('qa_rec_dp exists: {}'.format(out_dp))
        os.mkdir(out_dp)

        for cid in cids:
            hits = ir_tools.retrieve(
                model_name=qa_config.QA_MODEL_NAME_BERT,
                cid=cid,
                filter_var=filter_var,
                filter=qa_config.FILTER,
                deduplicate=None,
            )
            # one record file per cid, named after the cid itself
            ir_tools.dump_retrieval(fp=join(out_dp, cid),
                                    retrieved_items=hits)
Ejemplo n.º 2
0
def rank2records_in_batch():
    """
        Dump ensemble retrieval records for filter values 20, 30, ..., 150.

        Each value gets its own new directory under
        ``path_parser.summary_rank``; inside it, one record file is dumped
        per cid from the module-level ``cids``.

    :return: None
    :raises ValueError: if the record directory for a value already exists.
    """
    step = 10

    for filter_var in tqdm(range(20, 150 + step, step)):
        dir_name = ensemble_config.QA_RECORD_DIR_NAME_PATTERN.format(
            ensemble_config.MODEL_NAME,
            filter_var,
            ensemble_config.FILTER,
        )
        out_dp = join(path_parser.summary_rank, dir_name)

        # refuse to reuse a directory left over from an earlier run
        if exists(out_dp):
            raise ValueError('qa_rec_dp exists: {}'.format(out_dp))
        os.mkdir(out_dp)

        for cid in cids:
            hits = ir_tools.retrieve(
                model_name=ensemble_config.MODEL_NAME,
                cid=cid,
                filter_var=filter_var,
                filter=ensemble_config.FILTER,
                deduplicate=None,
            )
            ir_tools.dump_retrieval(fp=join(out_dp, cid),
                                    retrieved_items=hits)
Ejemplo n.º 3
0
def compute_rouge_for_oracle():
    """
        Write one summary file per test cid and log the resulting ROUGE.

        The record dir for oracle saves text for comparing against the
        reference. Each summary is the last field of every retrieved item,
        joined by newlines.

    :return: None
    :raises ValueError: if the record dir already exists.
    """
    out_dp = join(path_parser.summary_rank,
                  ir_config.IR_RECORDS_DIR_NAME_TF)

    if exists(out_dp):
        raise ValueError('ir_rec_dp exists: {}'.format(out_dp))
    os.mkdir(out_dp)

    for cid in tqdm(tools.get_test_cc_ids()):
        hits = ir_tools.retrieve(
            model_name=ir_config.IR_MODEL_NAME_TF,
            cid=cid,
            filter_var=ir_config.FILTER_VAR,
            filter=ir_config.FILTER,
            deduplicate=ir_config.DEDUPLICATE,
            prune=True,
        )
        # the text of an item is its last element
        text = '\n'.join(hit[-1] for hit in hits)
        with open(join(out_dp, cid), mode='a', encoding='utf-8') as out_f:
            out_f.write(text)

    scores = rouge.compute_rouge_for_ablation_study(out_dp)
    logger.info(scores)
Ejemplo n.º 4
0
def tune():
    """
        Sweep the IR filter value and record the ROUGE result of each setting.

        Intended for tuning IR confidence / compression rate based on
        Recall Rouge 2. For every candidate value the scratch directory is
        wiped and rebuilt, one summary file per test cid is written into it,
        and a tab-separated row (filter value, performance) is appended to
        the report file.

    :return: None
    """
    if ir_config.FILTER == 'conf':
        # values starting at 0.05 in steps of 0.05, stopping before 1.05
        candidates = np.arange(0.05, 1.05, 0.05)
    else:
        step = 10
        candidates = range(step, 500 + step, step)

    scratch_dp = join(path_parser.summary_rank, ir_config.IR_TUNE_DIR_NAME_TF)
    report_fp = join(path_parser.tune, ir_config.IR_TUNE_DIR_NAME_TF)

    # The report is opened in append mode, so a header is added on every run.
    with open(report_fp, mode='a', encoding='utf-8') as report:
        report.write('Filter\tRecall\tF1\n')

    cids = tools.get_test_cc_ids()
    for filter_var in candidates:
        # start every setting from an empty scratch dir
        if exists(scratch_dp):
            shutil.rmtree(scratch_dp)
        os.mkdir(scratch_dp)

        for cid in tqdm(cids):
            # per the original author's note, items are (pid, score)
            hits = ir_tools.retrieve(
                model_name=ir_config.IR_MODEL_NAME_TF,
                cid=cid,
                filter_var=filter_var,
                filter=ir_config.FILTER,
                deduplicate=ir_config.DEDUPLICATE,
                prune=True,
            )

            pids = [hit[0] for hit in hits]
            sents_per_passage, _, _ = load_retrieved_passages(
                cid=cid, get_sents=True, passage_ids=pids)

            # sentences joined per passage, then passages joined, all by '\n'
            summary = '\n'.join(
                ['\n'.join(sents) for sents in sents_per_passage])
            # NOTE(review): looks like leftover debug output - confirm before
            # removing, since it is part of the current stdout behavior.
            print(summary)
            with open(join(scratch_dp, cid), mode='a',
                      encoding='utf-8') as out_f:
                out_f.write(summary)

        performance = rouge.compute_rouge_for_dev(scratch_dp,
                                                  tune_centrality=False)
        with open(report_fp, mode='a', encoding='utf-8') as report:
            # float thresholds are printed with two decimals, counts as-is
            row_fmt = ('{0:.2f}\t{1}\n' if ir_config.FILTER == 'conf'
                       else '{0}\t{1}\n')
            report.write(row_fmt.format(filter_var, performance))
Ejemplo n.º 5
0
def tune():
    """
        Sweep the QA filter value and record the ROUGE result of each setting.

        Intended for tuning QA confidence / compression rate / topK based on
        Recall Rouge 2. For every candidate value the scratch directory is
        wiped and rebuilt, one summary file per cid (from the module-level
        ``cids``) is written into it, and a tab-separated row (filter value,
        performance) is appended to the report file.

    :return: None
    """
    step = 10
    if qa_config.FILTER == 'conf':
        # values starting at 0.05 in steps of 0.05, stopping before 1.05
        candidates = np.arange(0.05, 1.05, 0.05)
    elif qa_config.ir_config.FILTER == 'topK':
        # multiples of 10 below the IR stage's FILTER_VAR + 10
        candidates = range(step, qa_config.ir_config.FILTER_VAR + step, step)
    else:
        candidates = range(step, 200 + step, step)

    scratch_dp = join(path_parser.summary_rank,
                      qa_config.QA_TUNE_DIR_NAME_BERT)
    report_fp = join(path_parser.tune, qa_config.QA_TUNE_DIR_NAME_BERT)

    # The report is opened in append mode, so a header is added on every run.
    with open(report_fp, mode='a', encoding='utf-8') as report:
        report.write('Filter\tRecall\tF1\n')

    for filter_var in candidates:
        # start every setting from an empty scratch dir
        if exists(scratch_dp):
            shutil.rmtree(scratch_dp)
        os.mkdir(scratch_dp)

        for cid in tqdm(cids):
            hits = ir_tools.retrieve(
                model_name=qa_config.QA_MODEL_NAME_BERT,
                cid=cid,
                filter_var=filter_var,
                filter=qa_config.FILTER,
                deduplicate=None,
            )
            # the text of an item is its last element
            summary = '\n'.join(hit[-1] for hit in hits)
            with open(join(scratch_dp, cid), mode='a',
                      encoding='utf-8') as out_f:
                out_f.write(summary)

        performance = rouge.compute_rouge_for_dev(scratch_dp,
                                                  tune_centrality=False)
        with open(report_fp, mode='a', encoding='utf-8') as report:
            # float thresholds are printed with two decimals, counts as-is
            row_fmt = ('{0:.2f}\t{1}\n' if qa_config.FILTER == 'conf'
                       else '{0}\t{1}\n')
            report.write(row_fmt.format(filter_var, performance))
Ejemplo n.º 6
0
def rank2records():
    """
        Dump the ensemble retrieval result of every cid into a new record dir.

        The directory is ``path_parser.summary_rank`` joined with
        ``ensemble_config.QA_RECORD_DIR_NAME``; each cid from the module-level
        ``cids`` gets one file named after the cid.

    :return: None
    :raises ValueError: if the record dir already exists.
    """
    out_dp = join(path_parser.summary_rank, ensemble_config.QA_RECORD_DIR_NAME)

    # refuse to reuse a directory left over from an earlier run
    if exists(out_dp):
        raise ValueError('rec_dp exists: {}'.format(out_dp))
    os.mkdir(out_dp)

    for cid in tqdm(cids):
        hits = ir_tools.retrieve(
            model_name=ensemble_config.MODEL_NAME,
            cid=cid,
            filter_var=ensemble_config.FILTER_VAR,
            filter=ensemble_config.FILTER,
            deduplicate=None,
        )
        ir_tools.dump_retrieval(fp=join(out_dp, cid), retrieved_items=hits)
Ejemplo n.º 7
0
def ir_rank2records():
    """
        Dump the IR retrieval result of every cid into a new record dir.

        The directory is ``path_parser.summary_rank`` joined with
        ``ir_config.IR_RECORDS_DIR_NAME_TF``; each cid from the module-level
        ``cids`` gets one file named after the cid.

    :return: None
    :raises ValueError: if the record dir already exists.
    """
    ir_rec_dp = join(path_parser.summary_rank,
                     ir_config.IR_RECORDS_DIR_NAME_TF)

    if exists(ir_rec_dp):
        # The message names the path variable that is actually checked here
        # (it previously said 'qa_rec_dp', a copy-paste leftover).
        raise ValueError('ir_rec_dp exists: {}'.format(ir_rec_dp))
    os.mkdir(ir_rec_dp)

    for cid in tqdm(cids):
        retrieval_params = {
            'model_name': ir_config.IR_MODEL_NAME_TF,
            'cid': cid,
            'filter_var': ir_config.FILTER_VAR,
            'filter': ir_config.FILTER,
            'deduplicate': ir_config.DEDUPLICATE,
            'prune': True,
        }

        retrieved_items = ir_tools.retrieve(**retrieval_params)
        # one record file per cid, named after the cid itself
        ir_tools.dump_retrieval(fp=join(ir_rec_dp, cid),
                                retrieved_items=retrieved_items)
Ejemplo n.º 8
0
def ir_rank2records():
    """
        Dump the IR retrieval result of every test cid into a new record dir.

        The cids are read from the ``'cid'`` entry of each dict in the
        module-level ``test_cid_query_dicts``. Retrieval runs with pruning
        switched off.

    :return: None
    :raises AssertionError: if the record dir already exists (the check is
        an ``assert``, so it is skipped when Python runs with ``-O``).
    """
    out_dp = join(path_parser.summary_rank,
                  ir_config.IR_RECORDS_DIR_NAME_TF)
    assert not exists(out_dp), f'ir_rec_dp exists: {out_dp}'
    os.mkdir(out_dp)

    test_cids = [entry['cid'] for entry in test_cid_query_dicts]
    for cid in tqdm(test_cids):
        hits = ir_tools.retrieve(
            model_name=ir_config.IR_MODEL_NAME_TF,
            cid=cid,
            filter_var=ir_config.FILTER_VAR,
            filter=ir_config.FILTER,
            deduplicate=ir_config.DEDUPLICATE,
            prune=False,
        )
        # one record file per cid, named after the cid itself
        ir_tools.dump_retrieval(fp=join(out_dp, cid),
                                retrieved_items=hits)