Example #1
def compute_rouge_for_oracle():
    """
        The record directory (rec dp) for the oracle run saves retrieved text for comparison against the reference.

    :return:
    """
    ir_rec_dp = join(path_parser.summary_rank,
                     ir_config.IR_RECORDS_DIR_NAME_TF)

    if exists(ir_rec_dp):
        raise ValueError('ir_rec_dp exists: {}'.format(ir_rec_dp))
    os.mkdir(ir_rec_dp)

    cids = tools.get_test_cc_ids()
    for cid in tqdm(cids):
        retrieval_params = {
            'model_name': ir_config.IR_MODEL_NAME_TF,
            'cid': cid,
            'filter_var': ir_config.FILTER_VAR,
            'filter': ir_config.FILTER,
            'deduplicate': ir_config.DEDUPLICATE,
            'prune': True,
        }

        retrieved_items = ir_tools.retrieve(**retrieval_params)
        summary = '\n'.join([item[-1] for item in retrieved_items])
        with open(join(ir_rec_dp, cid), mode='a', encoding='utf-8') as out_f:
            out_f.write(summary)

    performance = rouge.compute_rouge_for_ablation_study(ir_rec_dp)
    logger.info(performance)
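This assumes each retrieved item is a tuple whose last element is the sentence text (the other fields are a pid and a score, per the comments in the examples below); a minimal sketch of that assumed shape:

# Hypothetical record shape; compute_rouge_for_oracle only relies on item[-1].
retrieved_items = [
    ('passage_0', 0.91, 'First retrieved sentence.'),
    ('passage_7', 0.84, 'Second retrieved sentence.'),
]
summary = '\n'.join(item[-1] for item in retrieved_items)
print(summary)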
Example #2
def tune():
    """
        Tune the IR confidence threshold / compression rate based on ROUGE-2 recall.
    :return:
    """
    if ir_config.FILTER == 'conf':
        tune_range = np.arange(0.05, 1.05, 0.05)
    else:
        interval = 10
        tune_range = range(interval, 500 + interval, interval)

    ir_tune_dp = join(path_parser.summary_rank, ir_config.IR_TUNE_DIR_NAME_TF)
    ir_tune_result_fp = join(path_parser.tune, ir_config.IR_TUNE_DIR_NAME_TF)
    with open(ir_tune_result_fp, mode='a', encoding='utf-8') as out_f:
        headline = 'Filter\tRecall\tF1\n'
        out_f.write(headline)

    cids = tools.get_test_cc_ids()
    for filter_var in tune_range:
        if exists(ir_tune_dp):  # remove previous output
            shutil.rmtree(ir_tune_dp)
        os.mkdir(ir_tune_dp)

        for cid in tqdm(cids):
            retrieval_params = {
                'model_name': ir_config.IR_MODEL_NAME_TF,
                'cid': cid,
                'filter_var': filter_var,
                'filter': ir_config.FILTER,
                'deduplicate': ir_config.DEDUPLICATE,
                'prune': True,
            }

            retrieved_items = ir_tools.retrieve(**retrieval_params)  # pid, score

            passage_ids = [item[0] for item in retrieved_items]
            original_passages, _, _ = load_retrieved_passages(
                cid=cid, get_sents=True, passage_ids=passage_ids)
            passages = ['\n'.join(sents) for sents in original_passages]
            summary = '\n'.join(passages)
            print(summary)
            with open(join(ir_tune_dp, cid), mode='a',
                      encoding='utf-8') as out_f:
                out_f.write(summary)

        performance = rouge.compute_rouge_for_dev(ir_tune_dp,
                                                  tune_centrality=False)
        with open(ir_tune_result_fp, mode='a', encoding='utf-8') as out_f:
            if ir_config.FILTER == 'conf':
                rec = '{0:.2f}\t{1}\n'.format(filter_var, performance)
            else:
                rec = '{0}\t{1}\n'.format(filter_var, performance)

            out_f.write(rec)
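For reference, the two tuning grids built at the top of tune() can be inspected in isolation:

import numpy as np

# 'conf' mode sweeps confidence thresholds 0.05, 0.10, ..., up to 1.0
conf_grid = [round(v, 2) for v in np.arange(0.05, 1.05, 0.05)]

# any other filter sweeps a size grid: 10, 20, ..., 500
interval = 10
count_grid = list(range(interval, 500 + interval, interval))

print(conf_grid)
print(count_grid)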
Example #3
def score_end2end(model_name, n_iter=None, damp=0.85, use_rel_vec=True, cc_ids=None):
    dp_mode = 'r'
    dp_params = {
        'model_name': model_name,  # one model has only one suite of summary components but different ranking systems
        'n_iter': n_iter,
        'mode': dp_mode,
    }

    summ_comp_root = graph_io.get_summ_comp_root(**dp_params)
    sim_mat_dp = graph_io.get_sim_mat_dp(summ_comp_root, mode=dp_mode)
    rel_vec_dp = graph_io.get_rel_vec_dp(summ_comp_root, mode=dp_mode)
    sid2abs_dp = graph_io.get_sid2abs_dp(summ_comp_root, mode=dp_mode)

    sid2score_dp = graph_io.get_sid2score_dp(summ_comp_root, mode='w')

    dps = {
        'sim_mat_dp': sim_mat_dp,
        'rel_vec_dp': rel_vec_dp,
        'sid2abs_dp': sid2abs_dp,
    }

    if not cc_ids:
        cc_ids = tools.get_test_cc_ids()
    
    for cid in tqdm(cc_ids):
        comp_params = {
            **dps,
            'cid': cid,
        }
        components = graph_io.load_components(**comp_params)
        # logger.info('[GRAPH RANK 1/2] successfully loaded components')

        # invert sid2abs (avoid shadowing the builtin `abs`)
        abs2sid = {abs_pos: sid for sid, abs_pos in components['sid2abs'].items()}

        scoring_params = {
            'sim_mat': components['sim_mat'],
            'rel_vec': components['rel_vec'].transpose() if use_rel_vec else None,
            # 'rel_vec': components['rel_vec'] if use_rel_vec else None,
            'cid': cid,
            'damp': damp,
            'abs2sid': abs2sid,
            # 'rm_dialog': rm_dialog,
        }

        sid2score = _score_graph_initially(**scoring_params)
        graph_io.dump_sid2score(sid2score=sid2score, sid2score_dp=sid2score_dp, cid=cid)

        # logger.info('[GRAPH RANK 2/2] successfully completed initial scoring')

    logger.info('[GRAPH RANK] Finished. Scores were dumped to: {}'.format(sid2score_dp))
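_score_graph_initially is not shown in this example; below is a minimal sketch of damped, relevance-biased PageRank of the kind its parameters (sim_mat, rel_vec, damp) suggest. This is an assumption about the helper, not its confirmed implementation:

import numpy as np

def biased_pagerank(sim_mat, rel_vec, damp=0.85, max_iter=100, tol=1e-6):
    # column-normalize similarities into a stochastic transition matrix
    trans = sim_mat / np.maximum(sim_mat.sum(axis=0, keepdims=True), 1e-12)
    # normalize query relevance into a teleport distribution
    bias = rel_vec / max(rel_vec.sum(), 1e-12)
    scores = np.full(sim_mat.shape[0], 1.0 / sim_mat.shape[0])
    for _ in range(max_iter):
        updated = damp * trans.dot(scores) + (1.0 - damp) * bias
        if np.abs(updated - scores).sum() < tol:
            return updated
        scores = updated
    return scores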
Example #4
def select_for_ablation_study(model_name,
                              cos_threshold,
                              cc_ids=None,
                              ref_dp=None):
    """
        For ablation study.

        Typically for evaluation of model w/o centrality module.

        For DUC datasets,  cc_ids and ref_dp will be defaultly set.
        For TDQFS dataset, cc_ids and ref_dp need to be specified.

    :return:
    """
    text_params = {
        'model_name': model_name,
        'cos_threshold': cos_threshold,
        'n_iter': None,
        'diversity_param_tuple': None,
        'extra': None,
    }
    text_dp = tools.init_text_dp(**text_params)
    rank_dp = tools.get_rank_dp(model_name)

    base_selector_params = {
        'text_dp': text_dp,
        'cos_threshold': cos_threshold,
        'max_n_summary_words': 500,
    }

    if not cc_ids:
        cc_ids = tools.get_test_cc_ids()

    for cid in tqdm(cc_ids):
        rank_fp = join(rank_dp, cid)
        selector_params = {
            **base_selector_params,
            'cid': cid,
            'rank_fp': rank_fp,
        }

        selector = SelectorNaive(**selector_params)
        selector.gen_and_dump_summary()

    logger.info(
        '[SELECT SENT] successfully dumped selected sentences to: {}'.format(
            text_dp))
    output = rouge.compute_rouge_for_ablation_study(text_dp, ref_dp)
    return output
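Usage sketch, following the docstring's DUC/TDQFS split (the model name and reference path here are hypothetical):

# DUC-style call: cc_ids and ref_dp fall back to their defaults.
select_for_ablation_study(model_name='my_model', cos_threshold=0.6)

# TDQFS-style call: both must be passed explicitly.
select_for_ablation_study(model_name='my_model',
                          cos_threshold=0.6,
                          cc_ids=tdqfs_cc_ids,
                          ref_dp='path/to/tdqfs_refs')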
Example #5
def rank_e2e():
    rank_dp = tools.get_rank_dp(model_name=MODEL_NAME)

    if exists(rank_dp):
        raise ValueError('rank_dp exists: {}'.format(rank_dp))
    os.mkdir(rank_dp)

    cc_ids = tools.get_test_cc_ids()
    for cid in tqdm(cc_ids):
        rank_records = _lexrank(cid)
        rank_sent.dump_rank_records(rank_records,
                                    out_fp=join(rank_dp, cid),
                                    with_rank_idx=False)

    logger.info('Successfully dumped rankings to: {}'.format(rank_dp))
Example #6
    def __init__(self, tokenize_narr, query_type=None):
        # fixme: this class may not work right; check type(narr).
        if query_type == config.TITLE:
            query_dict = dataset_parser.get_cid2title()
        elif query_type == config.NARR:
            query_dict = dataset_parser.get_cid2narr()
        else:
            query_dict = dataset_parser.get_cid2query(tokenize_narr)

        cids = tools.get_test_cc_ids()

        self.loader_init_params = []
        for cid in cids:
            query = query_dict[cid]
            self.loader_init_params.append({
                'cid': cid,
                'query': query,
            })
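A sketch of how loader_init_params might be consumed downstream; the enclosing class and its loaders are not shown, so the class name here is hypothetical:

dataset = QueryDataset(tokenize_narr=False)  # hypothetical name for the enclosing class
for params in dataset.loader_init_params:
    # one loader configuration per test cluster
    print(params['cid'], params['query'])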
Example #7
    def __init__(self, tokenize_narr, query_type, retrieve_dp):
        if query_type == config.TITLE:
            query_dict = dataset_parser.get_cid2title()
        elif query_type == config.NARR:
            query_dict = dataset_parser.get_cid2narr()
        elif query_type == config.QUERY:
            query_dict = dataset_parser.get_cid2query(tokenize_narr)
        else:
            raise ValueError('Invalid query_type: {}'.format(query_type))

        cids = tools.get_test_cc_ids()

        self.loader_init_params = []
        for cid in cids:
            query = query_dict[cid]
            self.loader_init_params.append({
                'cid': cid,
                'query': query,
                'retrieve_dp': retrieve_dp,
            })
Example #8
def select_for_dev(
    rank_dp,
    text_dp,
    cos_threshold,
    rel_sents_dp=None,
    retrieved_dp=None,
):
    """
        For final tuning.

        Rouge-2 are evaluated with "-l 250"
    :return:
    """
    # make dump dir
    cc_ids = tools.get_test_cc_ids()

    base_selector_params = {
        'text_dp': text_dp,
        'cos_threshold': cos_threshold,
        'max_n_summary_words': 500,
        'rel_sents_dp': rel_sents_dp,
        'retrieved_dp': retrieved_dp,
    }

    for cid in tqdm(cc_ids):
        rank_fp = join(rank_dp, cid)
        selector_params = {
            **base_selector_params,
            'cid': cid,
            'rank_fp': rank_fp,
        }

        selector = Selector(**selector_params)
        selector.gen_and_dump_summary()

    logger.info(
        '[SELECT SENT] successfully dumped selected sentences to: {}'.format(
            text_dp))
    output = rouge.compute_rouge_for_dev(text_dp, tune_centrality=True)
    return output
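A tuning-loop sketch over cosine thresholds (the directory names here are hypothetical):

for cos_threshold in (0.5, 0.6):
    performance = select_for_dev(rank_dp='rank/dev_model',
                                 text_dp='text/dev_model-{}'.format(cos_threshold),
                                 cos_threshold=cos_threshold)
    print(cos_threshold, performance)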
Example #9
def ir_rank2records():
    ir_rec_dp = join(path_parser.summary_rank,
                     ir_config.IR_RECORDS_DIR_NAME_TF)

    if exists(ir_rec_dp):
        raise ValueError('ir_rec_dp exists: {}'.format(ir_rec_dp))
    os.mkdir(ir_rec_dp)

    cids = tools.get_test_cc_ids()
    for cid in tqdm(cids):
        retrieval_params = {
            'model_name': ir_config.IR_MODEL_NAME_TF,
            'cid': cid,
            'filter_var': ir_config.FILTER_VAR,
            'filter': ir_config.FILTER,
            'deduplicate': ir_config.DEDUPLICATE,
            'prune': True,
        }

        retrieved_items = ir_tools.retrieve(**retrieval_params)
        ir_tools.dump_retrieval(fp=join(ir_rec_dp, cid),
                                retrieved_items=retrieved_items)
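Example #12's _load_rel_scores reads these records back with line.split('\t')[1], so dump_retrieval is assumed to write one tab-separated item per line, with the pid first and the score second; a parsing sketch under that assumption:

with open('d30001t', encoding='utf-8') as rec_f:  # file name is illustrative
    for line in rec_f:
        fields = line.rstrip('\n').split('\t')
        pid, score = fields[0], float(fields[1])  # further fields, if any, are ignored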
Example #10
def build_test_cid_query_dicts(tokenize_narr, concat_title_narr, query_type=None):
    """

    :param tokenize_narr: bool
    :param concat_title_narr: bool
    :return:
    """
    query_info = dict()
    for year in config.years:
        query_params = {
            'year': year,
            'tokenize_narr': tokenize_narr,
            'concat_title_narr': concat_title_narr,
        }

        annual_query_info = dataset_parser.build_query_info(**query_params)
        query_info = {
            **annual_query_info,
            **query_info,
        }

    cids = tools.get_test_cc_ids()
    test_cid_query_dicts = []

    for cid in cids:
        query = tools.get_query_w_cid(query_info, cid=cid)

        if query_type:
            query = query[query_type]

        print('query: {}'.format(query))
        test_cid_query_dicts.append({
            'cid': cid,
            'query': query,
        })

    return test_cid_query_dicts
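The returned structure is a list of {'cid', 'query'} dicts; a shape sketch with illustrative values:

test_cid_query_dicts = [
    {'cid': 'd30001t', 'query': 'Describe the events ...'},
    {'cid': 'd30002t', 'query': 'Summarize reactions to ...'},
]
for cq_dict in test_cid_query_dicts:
    print(cq_dict['cid'], cq_dict['query'])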
Example #11
def build_oracle_test_cid_query_dicts():
    def _get_ref(cid):
        REF_DP = join(path_parser.data_summary_targets, config.test_year)
        fp = join(REF_DP, '{}_{}'.format(cid, 1))
        ref = ''
        lines = io.open(fp, encoding='utf-8').readlines()
        for line in lines:
            ref += line.rstrip('\n')

        return ref

    test_cid_query_dicts = []
    cids = tools.get_test_cc_ids()

    for cid in cids:
        ref = _get_ref(cid)
        logger.info('cid {}: {}'.format(cid, ref))

        test_cid_query_dicts.append({
            'cid': cid,
            'query': ref,
        })

    return test_cid_query_dicts
Example #12
use_tdqfs = 'tdqfs' in centrality_config.QA_RECORD_DIR_NAME

if use_tdqfs:
    sentence_dp = path_parser.data_tdqfs_sentences
    query_fp = path_parser.data_tdqfs_queries
    tdqfs_summary_target_dp = path_parser.data_tdqfs_summary_targets

    test_cid_query_dicts = general_tools.build_tdqfs_cid_query_dicts(
        query_fp=query_fp, proc=True)
    cc_ids = [cq_dict['cid'] for cq_dict in test_cid_query_dicts]
else:
    test_cid_query_dicts = general_tools.build_test_cid_query_dicts(
        tokenize_narr=False,
        concat_title_narr=False,
        query_type=centrality_config.QUERY_TYPE)
    cc_ids = get_test_cc_ids()


def _load_rel_scores(cid, ir_record_dp):
    # each IR record line is tab-separated; field 1 holds the relevance score
    ir_record_fp = join(ir_record_dp, cid)
    ir_records = io.open(ir_record_fp, encoding='utf-8').readlines()
    ir_scores = [float(line.split('\t')[1]) for line in ir_records]
    ir_rel_scores = np.array(ir_scores)
    return ir_rel_scores


def _build_components(cid, query):
    sim_items = tfidf_tools.build_sim_items_e2e(cid,
                                                query,
                                                mask_intra=None,
                                                max_ns_doc=None,
Example #13
def rank_end2end(model_name,
                 diversity_param_tuple,
                 component_name=None,
                 n_iter=None,
                 rank_dp=None,
                 retrieved_dp=None,
                 rm_dialog=True,
                 cc_ids=None):
    """

    :param model_name:
    :param diversity_param_tuple:
    :param component_name:
    :param n_iter:
    :param rank_dp:
    :param retrieved_dp:
    :param rm_dialog: only useful when retrieved_dp=None
    :return:
    """
    dp_mode = 'r'
    dp_params = {
        'n_iter': n_iter,
        'mode': dp_mode,
    }

    diversity_weight, diversity_algorithm = diversity_param_tuple

    # todo: double-check this condition; added later to avoid a bug with centrality-tfidf.
    # one model has only one suite of summary components but different ranking systems
    if component_name:
        dp_params['model_name'] = component_name
    else:
        dp_params['model_name'] = model_name

    summ_comp_root = graph_io.get_summ_comp_root(**dp_params)
    sim_mat_dp = graph_io.get_sim_mat_dp(summ_comp_root, mode=dp_mode)
    rel_vec_dp = graph_io.get_rel_vec_dp(summ_comp_root, mode=dp_mode)
    sid2abs_dp = graph_io.get_sid2abs_dp(summ_comp_root, mode=dp_mode)
    sid2score_dp = graph_io.get_sid2score_dp(summ_comp_root, mode=dp_mode)

    if not rank_dp:
        rank_dp_params = {
            'model_name': model_name,
            'n_iter': n_iter,
            'diversity_param_tuple': diversity_param_tuple,
        }

        rank_dp = tools.get_rank_dp(**rank_dp_params)

    if exists(rank_dp):
        raise ValueError('rank_dp exists: {}'.format(rank_dp))
    os.mkdir(rank_dp)

    dps = {
        'sim_mat_dp': sim_mat_dp,
        'rel_vec_dp': rel_vec_dp,
        'sid2abs_dp': sid2abs_dp,
    }

    if not cc_ids:
        cc_ids = tools.get_test_cc_ids()
    
    for cid in tqdm(cc_ids):
        # logger.info('cid: {}'.format(cid))
        comp_params = {
            **dps,
            'cid': cid,
        }
        components = graph_io.load_components(**comp_params)
        # logger.info('[GRAPH RANK 1/2] successfully loaded components')
        sid2score = graph_io.load_sid2score(sid2score_dp, cid)

        if retrieved_dp:
            original_sents, _ = load_retrieved_sentences(retrieved_dp=retrieved_dp, cid=cid)
        else:
            if 'tdqfs' in config.test_year:
                original_sents, _ = dataset_parser.cid2sents_tdqfs(cid)
            else:
                original_sents, _ = dataset_parser.cid2sents(cid, rm_dialog=rm_dialog)  # 2d lists, docs => sents

        diversity_params = {
            'sid2score': sid2score,
            'sid2abs': components['sid2abs'],
            'sim_mat': components['sim_mat'],
            'original_sents': original_sents,
        }

        if diversity_algorithm == 'wan':
            diversity_params['omega'] = diversity_weight
            rank_records = _rank_with_diversity_penalty_wan(**diversity_params)
        else:
            raise ValueError('Invalid diversity_algorithm: {}'.format(diversity_algorithm))

        logger.info('cid: {}, #rank_records: {}'.format(cid, len(rank_records)))
        rank_sent.dump_rank_records(rank_records, out_fp=join(rank_dp, cid), with_rank_idx=False)

    logger.info('[GRAPH RANK] Finished. Rankings were dumped to: {}'.format(rank_dp))
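_rank_with_diversity_penalty_wan is not shown in this example; below is a greedy sketch of the Wan-style diversity penalty that its name and omega parameter suggest. This is an assumption about the helper, not its confirmed implementation:

import numpy as np

def rank_with_diversity_penalty(scores, sim_mat, omega):
    scores = np.array(scores, dtype=float)
    remaining = list(range(len(scores)))
    ranked = []
    while remaining:
        best = max(remaining, key=lambda i: scores[i])
        ranked.append(best)
        remaining.remove(best)
        for j in remaining:
            # penalize sentences similar to the one just selected
            scores[j] -= omega * sim_mat[best][j] * scores[best]
    return ranked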
Example #14
def select_end2end(model_name,
                   n_iter=None,
                   diversity_param_tuple=None,
                   cos_threshold=None,
                   max_n_summary_words=500,
                   extra=None,
                   rank_model_name=None,
                   rel_sents_dp=None,
                   retrieved_dp=None,
                   rm_dialog=True,
                   cc_ids=None):
    """

    :param model_name:
    :param n_iter:
    :param sort: date or origin
    :param duplication: if True, duplicated generated summary $n_refs$.
    :param attn_weigh: bool
    :param para_weigh: bool
    :param doc_weigh: bool
    :param cos_threshold: 0.5, 0.6
    :param max_n_summary_words: 500
    :param rank_model_name: you can specify rank_model_name; default is set to model_name.
    :param rm_dialog: only useful when retrieved_dp=None
    :return:
    """
    # make dump dir
    text_params = {
        'model_name': model_name,
        'cos_threshold': cos_threshold,
        'n_iter': n_iter,
        'diversity_param_tuple': diversity_param_tuple,
        'extra': extra,
    }

    text_dp = tools.init_text_dp(**text_params)
    # date_sorted_doc_indices
    if not cc_ids:
        cc_ids = tools.get_test_cc_ids()

    base_selector_params = {
        'text_dp': text_dp,
        'cos_threshold': cos_threshold,
        'max_n_summary_words': max_n_summary_words,
        'rel_sents_dp': rel_sents_dp,
        'retrieved_dp': retrieved_dp,
    }

    # logger.info('[SELECT SENTS] selecting sents for {} clusters'.format(len(cc_ids)))
    if not rank_model_name:
        rank_model_name = model_name

    rank_dp = tools.get_rank_dp(rank_model_name,
                                n_iter=n_iter,
                                diversity_param_tuple=diversity_param_tuple,
                                extra=extra)

    for cid in tqdm(cc_ids):
        rank_fp = join(rank_dp, cid)
        selector_params = {
            **base_selector_params,
            'cid': cid,
            'rank_fp': rank_fp,
            'rm_dialog': rm_dialog,
        }

        selector = Selector(**selector_params)
        selector.gen_and_dump_summary()

    logger.info(
        '[SELECT SENT] successfully dumped selected sentences to: {}'.format(
            text_dp))

    output = rouge.compute_rouge_end2end(**text_params)
    return output
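A call sketch combining a diversity-penalized ranker with selection (the model names here are hypothetical):

output = select_end2end(model_name='my_model',
                        diversity_param_tuple=(5.0, 'wan'),
                        cos_threshold=0.6,
                        rank_model_name='my_rank_model')
print(output)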
Example #15
                     qa_config.RELEVANCE_SCORE_DIR_NAME)
rank_dp = join(path_parser.summary_rank, qa_config.QA_MODEL_NAME_BERT)
ir_rec_dp = join(path_parser.summary_rank, qa_config.IR_RECORDS_DIR_NAME)

use_tdqfs = 'tdqfs' in qa_config.IR_RECORDS_DIR_NAME

if use_tdqfs:
    sentence_dp = path_parser.data_tdqfs_sentences
    query_fp = path_parser.data_tdqfs_queries
    tdqfs_summary_target_dp = path_parser.data_tdqfs_summary_targets

    test_cid_query_dicts = general_tools.build_tdqfs_cid_query_dicts(
        query_fp=query_fp, proc=True)
    cids = [cq_dict['cid'] for cq_dict in test_cid_query_dicts]
else:
    cids = tools.get_test_cc_ids()


def init():
    # parse args
    parser = ArgumentParser()
    parser.add_argument(
        'n_devices',
        nargs='?',
        default=4,
        help='number of devices the model will run on')

    args = parser.parse_args()
    all_device_ids = [0, 1, 2, 3]
    device = all_device_ids[:int(args.n_devices)]
    config_meta['device'] = device
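Usage note: invoked as "python script.py 2", init() keeps device ids [0, 1]; with no argument, the default of 4 keeps all of [0, 1, 2, 3].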