def compute_rouge_for_oracle():
    """Dump oracle retrieval records and report their ROUGE.

    The record directory for the oracle saves text for comparison
    against the reference summaries.

    :return: None; the performance figure is logged.
    """
    ir_rec_dp = join(path_parser.summary_rank, ir_config.IR_RECORDS_DIR_NAME_TF)
    if exists(ir_rec_dp):
        raise ValueError('ir_rec_dp exists: {}'.format(ir_rec_dp))
    os.mkdir(ir_rec_dp)

    for cid in tqdm(tools.get_test_cc_ids()):
        retrieved_items = ir_tools.retrieve(
            model_name=ir_config.IR_MODEL_NAME_TF,
            cid=cid,
            filter_var=ir_config.FILTER_VAR,
            filter=ir_config.FILTER,
            deduplicate=ir_config.DEDUPLICATE,
            prune=True)

        # item[-1] holds the retrieved text; one sentence per line
        summary = '\n'.join(item[-1] for item in retrieved_items)
        with open(join(ir_rec_dp, cid), mode='a', encoding='utf-8') as out_f:
            out_f.write(summary)

    performance = rouge.compute_rouge_for_ablation_study(ir_rec_dp)
    logger.info(performance)
def tune():
    """Tune the IR confidence / compression rate based on Recall ROUGE-2.

    Sweeps a range of filter values (confidence thresholds when
    ``ir_config.FILTER == 'conf'``, otherwise top-k compression sizes),
    retrieves passages for each test cluster under each value, dumps the
    resulting summaries, and appends the dev ROUGE performance for every
    filter value to a TSV result file.

    :return: None; results are written to ``ir_tune_result_fp``.
    """
    if ir_config.FILTER == 'conf':
        tune_range = np.arange(0.05, 1.05, 0.05)
    else:
        interval = 10
        tune_range = range(interval, 500 + interval, interval)

    ir_tune_dp = join(path_parser.summary_rank, ir_config.IR_TUNE_DIR_NAME_TF)
    ir_tune_result_fp = join(path_parser.tune, ir_config.IR_TUNE_DIR_NAME_TF)
    with open(ir_tune_result_fp, mode='a', encoding='utf-8') as out_f:
        headline = 'Filter\tRecall\tF1\n'
        out_f.write(headline)

    cids = tools.get_test_cc_ids()
    for filter_var in tune_range:
        # remove output from the previous filter value
        if exists(ir_tune_dp):
            shutil.rmtree(ir_tune_dp)
        os.mkdir(ir_tune_dp)

        for cid in tqdm(cids):
            retrieval_params = {
                'model_name': ir_config.IR_MODEL_NAME_TF,
                'cid': cid,
                'filter_var': filter_var,
                'filter': ir_config.FILTER,
                'deduplicate': ir_config.DEDUPLICATE,
                'prune': True,
            }
            retrieved_items = ir_tools.retrieve(**retrieval_params)  # (pid, score)

            passage_ids = [item[0] for item in retrieved_items]
            original_passages, _, _ = load_retrieved_passages(
                cid=cid, get_sents=True, passage_ids=passage_ids)
            passages = ['\n'.join(sents) for sents in original_passages]
            summary = '\n'.join(passages)
            # NOTE: removed leftover debug print(summary) that flooded stdout

            with open(join(ir_tune_dp, cid), mode='a', encoding='utf-8') as out_f:
                out_f.write(summary)

        performance = rouge.compute_rouge_for_dev(ir_tune_dp, tune_centrality=False)
        with open(ir_tune_result_fp, mode='a', encoding='utf-8') as out_f:
            if ir_config.FILTER == 'conf':
                rec = '{0:.2f}\t{1}\n'.format(filter_var, performance)
            else:
                rec = '{0}\t{1}\n'.format(filter_var, performance)
            out_f.write(rec)
def score_end2end(model_name, n_iter=None, damp=0.85, use_rel_vec=True, cc_ids=None):
    """Run initial graph scoring for each test cluster and dump sid->score maps.

    Loads the pre-built summary components (similarity matrix, relevance
    vector, sid->abs index map) for each cluster, scores the sentence graph
    via ``_score_graph_initially``, and dumps the resulting scores.

    :param model_name: model whose summary components are loaded
        (one model has only one suite of summary components but may feed
        different ranking systems).
    :param n_iter: optional iteration tag used to locate the component root.
    :param damp: damping factor for the graph scoring.
    :param use_rel_vec: if True, pass the transposed relevance vector into scoring.
    :param cc_ids: clusters to score; defaults to the test clusters.
    :return: None; scores are dumped under ``sid2score_dp``.
    """
    dp_mode = 'r'
    dp_params = {
        'model_name': model_name,
        'n_iter': n_iter,
        'mode': dp_mode,
    }
    summ_comp_root = graph_io.get_summ_comp_root(**dp_params)
    sim_mat_dp = graph_io.get_sim_mat_dp(summ_comp_root, mode=dp_mode)
    rel_vec_dp = graph_io.get_rel_vec_dp(summ_comp_root, mode=dp_mode)
    sid2abs_dp = graph_io.get_sid2abs_dp(summ_comp_root, mode=dp_mode)
    sid2score_dp = graph_io.get_sid2score_dp(summ_comp_root, mode='w')

    dps = {
        'sim_mat_dp': sim_mat_dp,
        'rel_vec_dp': rel_vec_dp,
        'sid2abs_dp': sid2abs_dp,
    }

    if not cc_ids:
        cc_ids = tools.get_test_cc_ids()

    for cid in tqdm(cc_ids):
        comp_params = {
            **dps,
            'cid': cid,
        }
        components = graph_io.load_components(**comp_params)

        # Invert sid -> abs into abs -> sid. The original loop variable was
        # named `abs`, shadowing the builtin; use `abs_idx` instead.
        abs2sid = {abs_idx: sid for sid, abs_idx in components['sid2abs'].items()}

        scoring_params = {
            'sim_mat': components['sim_mat'],
            'rel_vec': components['rel_vec'].transpose() if use_rel_vec else None,
            'cid': cid,
            'damp': damp,
            'abs2sid': abs2sid,
        }
        sid2score = _score_graph_initially(**scoring_params)
        graph_io.dump_sid2score(sid2score=sid2score, sid2score_dp=sid2score_dp, cid=cid)

    logger.info('[GRAPH RANK] Finished. Scores were dumped to: {}'.format(sid2score_dp))
def select_for_ablation_study(model_name, cos_threshold, cc_ids=None, ref_dp=None):
    """Select sentences for ablation study (e.g. model w/o centrality module).

    For DUC datasets, ``cc_ids`` and ``ref_dp`` take their defaults.
    For the TDQFS dataset, both must be specified.

    :return: ROUGE output from the ablation-study evaluation.
    """
    text_params = {
        'model_name': model_name,
        'cos_threshold': cos_threshold,
        'n_iter': None,
        'diversity_param_tuple': None,
        'extra': None,
    }
    text_dp = tools.init_text_dp(**text_params)
    rank_dp = tools.get_rank_dp(model_name)

    if not cc_ids:
        cc_ids = tools.get_test_cc_ids()

    for cid in tqdm(cc_ids):
        selector = SelectorNaive(
            text_dp=text_dp,
            cos_threshold=cos_threshold,
            max_n_summary_words=500,
            cid=cid,
            rank_fp=join(rank_dp, cid))
        selector.gen_and_dump_summary()

    logger.info(
        '[SELECT SENT] successfully dumped selected sentences to: {}'.format(
            text_dp))

    return rouge.compute_rouge_for_ablation_study(text_dp, ref_dp)
def rank_e2e():
    """Rank sentences for every test cluster with LexRank and dump the rankings.

    :raises ValueError: if the rank directory already exists.
    """
    rank_dp = tools.get_rank_dp(model_name=MODEL_NAME)
    if exists(rank_dp):
        raise ValueError('rank_dp exists: {}'.format(rank_dp))
    os.mkdir(rank_dp)

    for cluster_id in tqdm(tools.get_test_cc_ids()):
        records = _lexrank(cluster_id)
        rank_sent.dump_rank_records(
            records,
            out_fp=join(rank_dp, cluster_id),
            with_rank_idx=False)

    logger.info('Successfully dumped rankings to: {}'.format(rank_dp))
def __init__(self, tokenize_narr, query_type=None):
    """Build per-cluster loader init params holding each cluster's query.

    :param tokenize_narr: forwarded to query construction when no
        explicit query_type matches.
    :param query_type: config.TITLE, config.NARR, or None (falls back
        to the generic cid->query mapping).
    """
    # FIXME: this class may not work right; check type(narr).
    if query_type == config.TITLE:
        query_dict = dataset_parser.get_cid2title()
    elif query_type == config.NARR:
        query_dict = dataset_parser.get_cid2narr()
    else:
        query_dict = dataset_parser.get_cid2query(tokenize_narr)

    self.loader_init_params = [
        {'cid': cid, 'query': query_dict[cid]}
        for cid in tools.get_test_cc_ids()
    ]
def __init__(self, tokenize_narr, query_type, retrieve_dp):
    """Build per-cluster loader init params with query and retrieval dir.

    :param tokenize_narr: forwarded to query construction for config.QUERY.
    :param query_type: config.TITLE, config.NARR, or config.QUERY.
    :param retrieve_dp: retrieval directory attached to every cluster's params.
    :raises ValueError: on any other query_type.
    """
    if query_type == config.TITLE:
        query_dict = dataset_parser.get_cid2title()
    elif query_type == config.NARR:
        query_dict = dataset_parser.get_cid2narr()
    elif query_type == config.QUERY:
        query_dict = dataset_parser.get_cid2query(tokenize_narr)
    else:
        raise ValueError('Invalid query_type: {}'.format(query_type))

    self.loader_init_params = [
        {'cid': cid, 'query': query_dict[cid], 'retrieve_dp': retrieve_dp}
        for cid in tools.get_test_cc_ids()
    ]
def select_for_dev(
        rank_dp,
        text_dp,
        cos_threshold,
        rel_sents_dp=None,
        retrieved_dp=None,
):
    """Select sentences for final tuning; ROUGE-2 is evaluated with "-l 250".

    :return: dev ROUGE output (with centrality tuning enabled).
    """
    for cluster_id in tqdm(tools.get_test_cc_ids()):
        selector = Selector(
            text_dp=text_dp,
            cos_threshold=cos_threshold,
            max_n_summary_words=500,
            rel_sents_dp=rel_sents_dp,
            retrieved_dp=retrieved_dp,
            cid=cluster_id,
            rank_fp=join(rank_dp, cluster_id))
        selector.gen_and_dump_summary()

    logger.info(
        '[SELECT SENT] successfully dumped selected sentences to: {}'.format(
            text_dp))

    return rouge.compute_rouge_for_dev(text_dp, tune_centrality=True)
def ir_rank2records():
    """Retrieve IR records for every test cluster and dump them to disk.

    :raises ValueError: if the record directory already exists.
    """
    ir_rec_dp = join(path_parser.summary_rank, ir_config.IR_RECORDS_DIR_NAME_TF)

    if exists(ir_rec_dp):
        # BUGFIX: error message previously said 'qa_rec_dp' (copy-paste
        # from a QA sibling); it now names the actual variable.
        raise ValueError('ir_rec_dp exists: {}'.format(ir_rec_dp))
    os.mkdir(ir_rec_dp)

    cids = tools.get_test_cc_ids()
    for cid in tqdm(cids):
        retrieval_params = {
            'model_name': ir_config.IR_MODEL_NAME_TF,
            'cid': cid,
            'filter_var': ir_config.FILTER_VAR,
            'filter': ir_config.FILTER,
            'deduplicate': ir_config.DEDUPLICATE,
            'prune': True,
        }
        retrieved_items = ir_tools.retrieve(**retrieval_params)
        ir_tools.dump_retrieval(fp=join(ir_rec_dp, cid), retrieved_items=retrieved_items)
def build_test_cid_query_dicts(tokenize_narr, concat_title_narr, query_type=None):
    """Build (cid, query) dicts for all test clusters across configured years.

    :param tokenize_narr: bool, forwarded to per-year query construction.
    :param concat_title_narr: bool, forwarded to per-year query construction.
    :param query_type: optional key to select one field from each query dict.
    :return: list of {'cid': ..., 'query': ...} dicts.
    """
    query_info = dict()
    for year in config.years:
        query_params = {
            'year': year,
            'tokenize_narr': tokenize_narr,
            'concat_title_narr': concat_title_narr,
        }
        annual_query_info = dataset_parser.build_query_info(**query_params)
        # earlier years take precedence on key collisions (kept as-is)
        query_info = {
            **annual_query_info,
            **query_info,
        }

    cids = tools.get_test_cc_ids()
    test_cid_query_dicts = []
    for cid in cids:
        query = tools.get_query_w_cid(query_info, cid=cid)
        if query_type:
            query = query[query_type]
        # BUGFIX: replaced leftover debug print() with logger, consistent
        # with the logging used throughout this module.
        logger.info('query: {}'.format(query))
        test_cid_query_dicts.append({
            'cid': cid,
            'query': query,
        })

    return test_cid_query_dicts
def build_oracle_test_cid_query_dicts():
    """Build (cid, query) dicts where each query is the cluster's reference text.

    :return: list of {'cid': ..., 'query': ...} dicts.
    """
    def _get_ref(cluster_id):
        # Read reference file "<cid>_1" and flatten it to a single string.
        ref_dp = join(path_parser.data_summary_targets, config.test_year)
        ref_fp = join(ref_dp, '{}_{}'.format(cluster_id, 1))
        lines = io.open(ref_fp, encoding='utf-8').readlines()
        return ''.join(line.rstrip('\n') for line in lines)

    test_cid_query_dicts = []
    for cid in tools.get_test_cc_ids():
        ref = _get_ref(cid)
        logger.info('cid {}: {}'.format(cid, ref))
        test_cid_query_dicts.append({
            'cid': cid,
            'query': ref,
        })

    return test_cid_query_dicts
use_tdqfs = 'tdqfs' in centrality_config.QA_RECORD_DIR_NAME if use_tdqfs: sentence_dp = path_parser.data_tdqfs_sentences query_fp = path_parser.data_tdqfs_queries tdqfs_summary_target_dp = path_parser.data_tdqfs_summary_targets test_cid_query_dicts = general_tools.build_tdqfs_cid_query_dicts( query_fp=query_fp, proc=True) cc_ids = [cq_dict['cid'] for cq_dict in test_cid_query_dicts] else: test_cid_query_dicts = general_tools.build_test_cid_query_dicts( tokenize_narr=False, concat_title_narr=False, query_type=centrality_config.QUERY_TYPE) cc_ids = get_test_cc_ids() def _load_rel_scores(cid, ir_record_dp): ir_record_fp = join(ir_record_dp, cid) ir_records = io.open(ir_record_fp, encoding='utf-8').readlines() ir_scores = [float(line.split('\t')[1]) for line in ir_records] ir_rel_scores = np.array(ir_scores) return ir_rel_scores def _build_components(cid, query): sim_items = tfidf_tools.build_sim_items_e2e(cid, query, mask_intra=None, max_ns_doc=None,
def rank_end2end(model_name, diversity_param_tuple, component_name=None, n_iter=None,
                 rank_dp=None, retrieved_dp=None, rm_dialog=True, cc_ids=None):
    """Rank sentences per cluster with a diversity penalty and dump rankings.

    Loads summary components and pre-computed sid->score maps, applies the
    diversity re-ranking ('wan' algorithm), and writes one ranking file per
    cluster under ``rank_dp``.

    :param model_name: model whose scores and (by default) components are used.
    :param diversity_param_tuple: (diversity_weight, diversity_algorithm).
    :param component_name: optional override for the component model name.
    :param n_iter: optional iteration tag used to locate components / rank dir.
    :param rank_dp: output dir; derived from model_name etc. when None.
    :param retrieved_dp: load sentences from retrieval records when given.
    :param rm_dialog: only useful when retrieved_dp=None.
    :param cc_ids: clusters to rank; defaults to the test clusters.
    :raises ValueError: if rank_dp already exists, or on an unknown
        diversity_algorithm.
    :return: None; rankings are dumped to ``rank_dp``.
    """
    dp_mode = 'r'
    dp_params = {
        'n_iter': n_iter,
        'mode': dp_mode,
    }
    diversity_weight, diversity_algorithm = diversity_param_tuple
    # todo: double check this condition; added later for avoiding bug for centrality-tfidf.
    # one model has only one suit of summary components but different ranking sys
    if component_name:
        dp_params['model_name'] = component_name
    else:
        dp_params['model_name'] = model_name

    summ_comp_root = graph_io.get_summ_comp_root(**dp_params)
    sim_mat_dp = graph_io.get_sim_mat_dp(summ_comp_root, mode=dp_mode)
    rel_vec_dp = graph_io.get_rel_vec_dp(summ_comp_root, mode=dp_mode)
    sid2abs_dp = graph_io.get_sid2abs_dp(summ_comp_root, mode=dp_mode)
    sid2score_dp = graph_io.get_sid2score_dp(summ_comp_root, mode=dp_mode)

    if not rank_dp:
        rank_dp_params = {
            'model_name': model_name,
            'n_iter': n_iter,
            'diversity_param_tuple': diversity_param_tuple,
        }
        rank_dp = tools.get_rank_dp(**rank_dp_params)

    if exists(rank_dp):
        raise ValueError('rank_dp exists: {}'.format(rank_dp))
    os.mkdir(rank_dp)

    dps = {
        'sim_mat_dp': sim_mat_dp,
        'rel_vec_dp': rel_vec_dp,
        'sid2abs_dp': sid2abs_dp,
    }

    if not cc_ids:
        cc_ids = tools.get_test_cc_ids()

    for cid in tqdm(cc_ids):
        comp_params = {
            **dps,
            'cid': cid,
        }
        components = graph_io.load_components(**comp_params)
        sid2score = graph_io.load_sid2score(sid2score_dp, cid)

        # Source sentences: retrieval records if available, otherwise parse
        # the dataset (TDQFS has its own parser).
        if retrieved_dp:
            original_sents, _ = load_retrieved_sentences(retrieved_dp=retrieved_dp, cid=cid)
        else:
            if 'tdqfs' in config.test_year:
                original_sents, _ = dataset_parser.cid2sents_tdqfs(cid)
            else:
                original_sents, _ = dataset_parser.cid2sents(cid,
                    rm_dialog=rm_dialog)  # 2d lists, docs => sents

        diversity_params = {
            'sid2score': sid2score,
            'sid2abs': components['sid2abs'],
            'sim_mat': components['sim_mat'],
            'original_sents': original_sents,
        }

        if diversity_algorithm == 'wan':
            diversity_params['omega'] = diversity_weight
            rank_records = _rank_with_diversity_penalty_wan(**diversity_params)
        else:
            raise ValueError('Invalid diversity_algorithm: {}'.format(diversity_algorithm))

        logger.info('cid: {}, #rank_records: {}'.format(cid, len(rank_records)))
        rank_sent.dump_rank_records(rank_records, out_fp=join(rank_dp, cid), with_rank_idx=False)

    logger.info('[GRAPH RANK] Finished. Rankings were dumped to: {}'.format(rank_dp))
def select_end2end(model_name, n_iter=None, diversity_param_tuple=None, cos_threshold=None, max_n_summary_words=500, extra=None, rank_model_name=None, rel_sents_dp=None, retrieved_dp=None, rm_dialog=True, cc_ids=None): """ :param model_name: :param n_iter: :param sort: date or origin :param duplication: if True, duplicated generated summary $n_refs$. :param attn_weigh: bool :param para_weigh: bool :param doc_weigh: bool :param cos_threshold: 0.5, 0.6 :param max_n_summary_words: 500 :param rank_model_name: you can specify rank_model_name; default is set to model_name. :param rm_dialog: only useful when retrieved_dp=None :return: """ # make dump dir text_params = { 'model_name': model_name, 'cos_threshold': cos_threshold, 'n_iter': n_iter, 'diversity_param_tuple': diversity_param_tuple, 'extra': extra, } text_dp = tools.init_text_dp(**text_params) # date_sorted_doc_indices if not cc_ids: cc_ids = tools.get_test_cc_ids() base_selector_params = { 'text_dp': text_dp, 'cos_threshold': cos_threshold, 'max_n_summary_words': max_n_summary_words, 'rel_sents_dp': rel_sents_dp, 'retrieved_dp': retrieved_dp, } # logger.info('[SELECT SENTS] selecting sents for {} clusters'.format(len(cc_ids))) if not rank_model_name: rank_model_name = model_name rank_dp = tools.get_rank_dp(rank_model_name, n_iter=n_iter, diversity_param_tuple=diversity_param_tuple, extra=extra) for cid in tqdm(cc_ids): rank_fp = join(rank_dp, cid) selector_params = { **base_selector_params, 'cid': cid, 'rank_fp': rank_fp, 'rm_dialog': rm_dialog, } selector = Selector(**selector_params) selector.gen_and_dump_summary() logger.info( '[SELECT SENT] successfully dumped selected sentences to: {}'.format( text_dp)) output = rouge.compute_rouge_end2end(**text_params) return output
qa_config.RELEVANCE_SCORE_DIR_NAME) rank_dp = join(path_parser.summary_rank, qa_config.QA_MODEL_NAME_BERT) ir_rec_dp = join(path_parser.summary_rank, qa_config.IR_RECORDS_DIR_NAME) use_tdqfs = 'tdqfs' in qa_config.IR_RECORDS_DIR_NAME if use_tdqfs: sentence_dp = path_parser.data_tdqfs_sentences query_fp = path_parser.data_tdqfs_queries tdqfs_summary_target_dp = path_parser.data_tdqfs_summary_targets test_cid_query_dicts = general_tools.build_tdqfs_cid_query_dicts( query_fp=query_fp, proc=True) cids = [cq_dict['cid'] for cq_dict in test_cid_query_dicts] else: cids = tools.get_test_cc_ids() def init(): # parse args parser = ArgumentParser() parser.add_argument( 'n_devices', nargs='?', default=4, help='num of devices on which model will be running on') args = parser.parse_args() all_device_ids = [0, 1, 2, 3] device = all_device_ids[:int(args.n_devices)] config_meta['device'] = device