def entity_constraint_extraction(self, el_result, vb=0):
    """Collect the anchor predicates of every linked entity.

    For each entity-linking item, ``self.driver.query_pred_given_object`` is
    asked for the predicates that take ``el_item.entity.id`` as object;
    predicates starting with 'common' or 'type' are dropped.

    :param el_result: iterable of entity-linking items exposing
        ``.entity.id`` and ``.name``.
    :param vb: verbosity; at ``vb >= 1`` the anchors of every entity are
        logged (previously only at exactly ``vb == 1``).
    :return: (ec_dict, total_size): ec_dict maps each el_item to its list of
        kept predicates; total_size is the number of
        <obj_mid, anchor_pred> pairs.
    """
    skipped_prefixes = ('common', 'type')  # str.startswith accepts a tuple
    ec_dict = {}  # <el_item, [anchor predicate]>
    total_size = 0
    for el_item in el_result:
        mid = el_item.entity.id
        raw_anchor_pred_list = self.driver.query_pred_given_object(mid)
        # filter common / type predicates
        anchor_pred_list = [pred for pred in raw_anchor_pred_list
                            if not pred.startswith(skipped_prefixes)]
        ec_dict[el_item] = anchor_pred_list
        total_size += len(anchor_pred_list)
    # Threshold check, so that higher verbosity levels do not hide this detail.
    if vb >= 1:
        for el_item, p_list in ec_dict.items():
            LogInfo.begin_track('%d anchors for obj = %s (%s): ', len(p_list),
                                el_item.entity.id.encode('utf-8'),
                                el_item.name.encode('utf-8'))
            for p in p_list:
                LogInfo.logs(p)
            LogInfo.end_track()
    LogInfo.logs('In total %d candidate objects and '
                 '%d <obj_mid, anchor_pred> extracted.', len(ec_dict), total_size)
    return ec_dict, total_size
def forward(self, path_wd_hidden, path_kb_hidden, path_len,
            focus_wd_hidden, focus_kb_hidden, reuse=None):
    """Encode a skeleton (focus + path) with the BiRNN encoder and average it.

    The focus representation is prepended to the path sequence as time
    stamp 0, the ``self.path_max_len + 1`` stamps are fed to
    ``self.rnn_encoder.encode`` with sequence length ``path_len + 1``, and the
    stacked outputs are summed over time and divided by that length.

    Feature choice follows ``self.data_source``: 'kb' uses the KB hidden
    vectors, 'word' the word hidden vectors, anything else concatenates
    both (KB first) along the last axis.

    :return: sk_hidden, (batch, dim_sk_hidden) per the shape comments.
    """
    LogInfo.begin_track('SkBiRNNModule forward: ')
    with tf.variable_scope('SkBiRNNModule', reuse=reuse):
        source = self.data_source
        if source == 'kb':
            path_feat, focus_feat = path_kb_hidden, focus_kb_hidden
        elif source == 'word':
            path_feat, focus_feat = path_wd_hidden, focus_wd_hidden
        else:
            # (batch, path_max_len, dim_item_hidden + dim_kb_hidden)
            path_feat = tf.concat([path_kb_hidden, path_wd_hidden],
                                  axis=-1, name='use_path_hidden')
            # (batch, dim_item_hidden + dim_kb_hidden)
            focus_feat = tf.concat([focus_kb_hidden, focus_wd_hidden],
                                   axis=-1, name='use_focus_hidden')

        # focus goes in front of the path: (batch, path_max_len + 1, dim_use)
        seq_input = tf.concat([tf.expand_dims(focus_feat, axis=1), path_feat],
                              axis=1, name='use_path_emb_input')
        show_tensor(seq_input)
        seq_len = path_len + 1
        n_stamps = self.path_max_len + 1
        step_inputs = tf.unstack(seq_input, num=n_stamps, axis=1,
                                 name='birnn_inputs')
        encoded = self.rnn_encoder.encode(inputs=step_inputs,
                                          sequence_length=seq_len,
                                          reuse=reuse)
        # (batch, path_max_len + 1, dim_sk_hidden)
        step_outputs = tf.stack(encoded.outputs, axis=1, name='rnn_outputs')

        # BiRNN mode: plain average = sum over stamps / (path_len + 1).
        # NOTE(review): assumes the encoder emits zeros past sequence_length — confirm.
        summed = tf.reduce_sum(step_outputs, axis=1,
                               name='sum_sk_hidden')  # (batch, dim_sk_hidden)
        len_col = tf.cast(tf.expand_dims(seq_len, axis=1), dtype=tf.float32,
                          name='use_path_len_mat')  # (batch, 1) as float32
        sk_hidden = tf.div(summed, len_col,
                           name='sk_hidden')  # (batch, dim_sk_hidden)
    LogInfo.end_track()
    return sk_hidden
def print_att_matrix(word_surf_list, path_surf_list, att_mat):
    """Log an attention matrix as an aligned table.

    Rows are question words, columns are skeleton paths. Rows / columns past
    the real words / paths are labelled ``<EMPTY>``; a ``|`` marks the right
    edge of the real paths and a dashed line is logged below the last real
    word.

    :param word_surf_list: surfaces of the words in the question.
    :param path_surf_list: surfaces of the skeleton paths.
    :param att_mat: 2-D numpy array of attention values (uses ``.shape[1]``).
    """
    word_sz = len(word_surf_list)
    path_sz = len(path_surf_list)
    header = [' ' * 10]  # same width as the '%10s' row labels
    for col_idx in range(att_mat.shape[1]):
        # Every column cell is 9 chars wide to match '%9.6f' below; the old
        # ' <EMPTY>' literal was only 8 chars and shifted the following columns.
        surf = path_surf_list[col_idx] if col_idx < path_sz else '<EMPTY>'
        header.append('%9s' % surf)
        if col_idx == path_sz - 1:
            header.append('|')
    LogInfo.logs(' '.join(header))
    for row_idx, att_row in enumerate(att_mat):  # scan each row of the raw attention matrix
        wd = word_surf_list[row_idx] if row_idx < word_sz else '<EMPTY>'
        show_str_list = ['%10s' % wd]
        for col_idx, att_val in enumerate(att_row):
            show_str_list.append('%9.6f' % att_val)
            if col_idx == path_sz - 1:
                show_str_list.append('|')
        show_str = ' '.join(show_str_list)
        LogInfo.logs(show_str)
        if row_idx == word_sz - 1:
            LogInfo.logs('-' * len(show_str))  # split useful and non-useful parts
    # NOTE(review): there is no matching begin_track in this function; presumably
    # the caller opens the track closed here — confirm before changing.
    LogInfo.end_track()
def query_sparql(self, sparql_query):
    """Answer a SPARQL query through the in-memory cache.

    The query is normalised by ``self.shrink_query`` and used as the cache
    key. On a miss ``self.kernel_query`` runs the query; non-None results are
    stored in ``self.sparql_dict`` and appended to ``self.sparql_buffer``.

    Logging: a request track is opened for every real query, or for every
    query when ``self.vb >= 2``; cache hits are reported at ``vb >= 1`` and
    the raw result is logged at ``vb >= 3`` inside an open track.

    :return: ``query_ret or []`` — a falsy result (e.g. None) becomes an
        empty list.
    """
    key = self.shrink_query(sparql_query)
    cached = key in self.sparql_dict
    # going to perform a real query, or just want to show more
    open_track = not cached or self.vb >= 2
    if open_track:
        LogInfo.begin_track('[%s] SPARQL Request:', show_time())
        LogInfo.logs(key)
    if cached:
        if self.vb >= 1:
            if open_track:
                LogInfo.logs('SPARQL hit!')
            else:
                LogInfo.logs('[%s] SPARQL hit!', show_time())
        query_ret = self.sparql_dict[key]
    else:
        query_ret = self.kernel_query(key=key)
        if query_ret is not None:  # ignore schemas returning None
            self.sparql_dict[key] = query_ret
            self.sparql_buffer.append((key, query_ret))
    if open_track:
        if self.vb >= 3:
            LogInfo.logs(query_ret)
        LogInfo.end_track()
    # won't send None back, but use empty list instead
    return query_ret or []
def show_el_detail(sc):
    """Log the entity-linking details stored in a scored schema.

    Each 'Main' / 'Entity' raw path of ``sc`` is paired with a prominent
    type: the domain of its first predicate for 'Main', the range of its
    last predicate for 'Entity'; other categories are skipped. For every such
    linking the linking data, that type, its raw EL score and its
    concatenated EL features are logged.

    :param sc: schema providing ``run_info`` ('el_feats_concat',
        'el_raw_score'), ``input_np_dict['el_len']`` and ``raw_paths`` as
        (category, gl_data, pred_seq) triples.
    """
    feats_rows = sc.run_info['el_feats_concat'].tolist()
    raw_scores = sc.run_info['el_raw_score'].tolist()
    n_links = sc.input_np_dict['el_len']

    linked = []  # [(gl_data, prominent type)]
    for category, gl_data, pred_seq in sc.raw_paths:
        if category == 'Main':
            linked.append((gl_data, get_domain(pred_seq[0])))
        elif category == 'Entity':
            linked.append((gl_data, get_range(pred_seq[-1])))
    assert len(linked) == n_links

    for pos in range(n_links):
        gl_data, prominent_tp = linked[pos]
        LogInfo.begin_track('Entity %d / %d:', pos + 1, n_links)
        LogInfo.logs(gl_data.display())
        LogInfo.logs(
            'Prominent type: [%s] <--- (Ignore this line if type is not used.)',
            prominent_tp)
        LogInfo.logs('raw_score = %.6f', raw_scores[pos])
        feat_text = ' '.join('%6.3f' % val for val in feats_rows[pos])
        LogInfo.logs('el_feats_concat = %s', feat_text)
        LogInfo.end_track()
def load_annotations_bio(word_dict, q_max_len):
    """Read SimpQ mention annotations and encode them as BIO tag arrays.

    Reads the module-level ``anno_fp``. Each tab-separated line starts with
    q_idx, st, ed, jaccard; its last field holds the space-separated
    sentence. Only lines with jaccard == 1.0 are kept. The mention span
    [st, ed) is tagged B (0) on its first token and I (1) on the rest; all
    other tokens are O (2). Sentences longer than ``q_max_len`` are dropped,
    and the rest are right-padded with word id 0 / tag O to ``q_max_len``.

    :param word_dict: maps lower-cased tokens to word ids; a token missing
        from it raises KeyError.
    :param q_max_len: maximum sentence length kept and padded to.
    :return: [v (ds, q_max_len), v_len (ds,), tag (ds, q_max_len)] as int32
        numpy arrays.
    """
    LogInfo.begin_track('Load SimpQ-mention annotation from [%s]:', anno_fp)
    raw_tup_list = []  # [(v, v_len, tag)]
    with codecs.open(anno_fp, 'r', 'utf-8') as br:
        for line in br.readlines():
            spt = line.strip().split('\t')
            _, st, ed = [int(x) for x in spt[:3]]  # q_idx is parsed but unused
            jac = float(spt[3])
            if jac != 1.0:
                continue  # only pick the most accurate sentences
            tok_list = spt[-1].lower().split(' ')
            v_len = len(tok_list)
            v = [word_dict[tok] for tok in tok_list]  # TODO: make sure all word exists
            # 0: B, 1: I, 2: O
            tag = [2] * st + [0] + [1] * (ed - st - 1) + [2] * (v_len - ed)
            assert len(tag) == len(v)
            raw_tup_list.append((v, v_len, tag))

    q_size = len(raw_tup_list)
    LogInfo.logs('%d high-quality annotation loaded.', q_size)
    if q_size > 0:  # np.max / np.percentile raise on empty input
        v_len_list = [tup[1] for tup in raw_tup_list]
        LogInfo.logs('maximum length = %d (%.6f on avg)',
                     np.max(v_len_list), np.mean(v_len_list))
        for pos in (25, 50, 75, 90, 95, 99, 99.9):
            LogInfo.logs('Percentile = %.1f%%: %.6f', pos, np.percentile(v_len_list, pos))

    # A list comprehension keeps len() working on Python 3, where filter() is lazy.
    filt_tup_list = [tup for tup in raw_tup_list if tup[1] <= q_max_len]
    LogInfo.logs('%d / %d sentence filtered by [q_max_len=%d].',
                 len(filt_tup_list), q_size, q_max_len)
    for v, _, tag in filt_tup_list:
        # pad in place: the lists inside the tuples are extended
        v += [0] * (q_max_len - len(v))
        tag += [2] * (q_max_len - len(tag))

    v_list, v_len_list, tag_list = [[tup[i] for tup in filt_tup_list] for i in range(3)]
    np_data_list = [
        np.array(v_list, dtype='int32'),  # (ds, q_max_len)
        np.array(v_len_list, dtype='int32'),  # (ds, )
        np.array(tag_list, dtype='int32'),  # (ds, q_max_len)
    ]
    for idx, np_data in enumerate(np_data_list):
        LogInfo.logs('np-%d: %s', idx, np_data.shape)
    LogInfo.end_track()
    return np_data_list
def show_el_detail_without_type(sc, qa, schema_dataset):
    """Log entity-linking features of a schema, without type information.

    Logs ``run_info['el_final_feats']``, then for every raw path its linking
    data followed by either its individual EL features or an IGNORED marker
    when its ``el_mask`` entry is 0.

    ``qa`` and ``schema_dataset`` are only asserted to be non-None.
    """
    assert qa is not None
    assert schema_dataset is not None

    def fmt(values):
        return ' '.join(['%6.3f' % val for val in values])

    final_feats = sc.run_info['el_final_feats'].tolist()
    LogInfo.logs('el_final_feats = [%s]', fmt(final_feats))
    np_inputs = sc.input_np_dict
    mask = np_inputs['el_mask']
    n_paths = np_inputs['path_size']
    indv_feats = np_inputs['el_indv_feats']
    linkings = [gl_data for _cat, gl_data, _preds in sc.raw_paths]
    assert n_paths == len(linkings)

    for pos in range(n_paths):
        msk = mask[pos]
        gl_data = linkings[pos]
        LogInfo.begin_track('Entity %d / %d:', pos + 1, n_paths)
        LogInfo.logs(gl_data.display())
        if msk != 0.:
            LogInfo.logs('local_feats = [%s]', fmt(indv_feats[pos]))
        else:
            LogInfo.logs('[Mask = 0, IGNORED.]')
        LogInfo.end_track()
def post_process(self, eval_dl, detail_fp):
    """Compute chunk-level F1 and the average log-likelihood of BIO tagging.

    ``self.eval_detail_list`` is read as four parallel sequences
    (v_len, gold_tag, pred_tag, log_lik). Tags are truncated to v_len and
    turned into chunk sets with ``self.produce_chunk``; the first 5 cases are
    logged. ``eval_dl`` and ``detail_fp`` are not used here.

    :return: (f1, avg_log_lik). Precision / recall fall back to 0 when there
        are no predicted / gold chunks, and avg_log_lik is 0 when there is no
        data, instead of raising ZeroDivisionError.
    """
    total_pred = 0  # total number of predicted chunks
    total_gold = 0  # total number of gold chunks
    total_correct = 0
    sum_log_lik = 0.
    data_size = 0
    for v_len, gold_tag, pred_tag, log_lik in zip(*self.eval_detail_list):
        data_size += 1
        gold_tag = gold_tag[:v_len]
        pred_tag = pred_tag[:v_len]
        gold_chunk_set = self.produce_chunk(gold_tag)
        pred_chunk_set = self.produce_chunk(pred_tag)
        total_pred += len(pred_chunk_set)
        total_gold += len(gold_chunk_set)
        total_correct += len(pred_chunk_set & gold_chunk_set)
        sum_log_lik += log_lik
        if data_size <= 5:
            LogInfo.begin_track('Check case-%d:', data_size)
            LogInfo.logs('seq_len = %d', v_len)
            LogInfo.logs('Gold: %s --> %s', gold_tag.tolist(), gold_chunk_set)
            LogInfo.logs('Pred: %s --> %s', pred_tag.tolist(), pred_chunk_set)
            LogInfo.logs('log_lik = %.6f', log_lik)
            LogInfo.end_track()
    p = 1. * total_correct / total_pred if total_pred > 0 else 0.
    r = 1. * total_correct / total_gold if total_gold > 0 else 0.
    f1 = 2. * p * r / (p + r) if (p + r) > 0. else 0.
    avg_log_lik = sum_log_lik / data_size if data_size > 0 else 0.
    return f1, avg_log_lik
def load_cands(self):
    """Load questions, candidates and numpy inputs from ``self.dump_fp``.

    Does nothing when ``self.np_data_list`` is non-empty and both
    ``self.q_cand_dict`` and ``self.q_words_dict`` are set. When the dump
    file does not exist, ``self.prepare_all_data()`` is called instead.

    The dump is read sequentially from one handle: q_list, q_words_dict and
    q_cand_dict via cPickle, then ``self.array_num`` arrays via ``np.load``,
    which are appended to ``self.np_data_list``.
    """
    already_loaded = (len(self.np_data_list) > 0
                      and self.q_cand_dict is not None
                      and self.q_words_dict is not None)
    if already_loaded:
        return
    if not os.path.isfile(self.dump_fp):
        self.prepare_all_data()
        return

    LogInfo.begin_track('Loading candidates & np_data from [%s] ...', self.dump_fp)
    with open(self.dump_fp, 'rb') as br:
        self.q_list = cPickle.load(br)
        LogInfo.logs('q_list loaded for %d questions.', len(self.q_list))
        self.q_words_dict = cPickle.load(br)
        LogInfo.logs('q_words_dict loaded for %d questions.', len(self.q_words_dict))
        self.q_cand_dict = cPickle.load(br)
        LogInfo.logs('q_cand_dict loaded.')

        sizes = np.array([len(cands) for cands in self.q_cand_dict.values()])
        LogInfo.begin_track('Show candidate size distribution:')
        for pct in (25, 50, 75, 90, 95, 99, 99.9, 100):
            LogInfo.logs('Percentile = %.1f%%: %.6f', pct, np.percentile(sizes, pct))
        LogInfo.end_track()

        for arr_idx in range(self.array_num):
            arr = np.load(br)
            self.np_data_list.append(arr)
            LogInfo.logs('np-data-%d loaded: %s', arr_idx, arr.shape)
    LogInfo.end_track()
def type_filtering(el_result, tl_result, sparql_driver, is_type_extend=True, vb=0):
    """Keep only the type linkings that are topically consistent with the entities.

    The predicates relevant to every linked entity are collected with
    ``collect_relevant_predicate``; ``prepare_topical_consistent_types``
    derives the consistent types from them, and a type linking is kept when
    its ``entity.id`` is among those types.

    :param is_type_extend: forwarded as ``is_type_extended``.
    :param vb: verbosity; a track with the predicate count is shown at vb >= 1.
    :return: list of kept type linkings, in their original order.
    """
    if vb >= 1:
        LogInfo.begin_track('Type Filtering:')
    relevant_preds = set()
    for el in el_result:
        mid = el.entity.id
        relevant_preds |= collect_relevant_predicate(mid, sparql_driver)
    if vb >= 1:
        LogInfo.logs('%d relevant predicates collected.', len(relevant_preds))
    topical_consistent_types = prepare_topical_consistent_types(
        relevant_pred_set=relevant_preds, is_type_extended=is_type_extend, vb=vb)
    # A list comprehension always yields a list, so len() also works on
    # Python 3 (where filter() returns a lazy iterator).
    filt_tl_result = [tl for tl in tl_result
                      if tl.entity.id in topical_consistent_types]
    LogInfo.logs('Type Filter: %d / %d types are kept.',
                 len(filt_tl_result), len(tl_result))
    if vb >= 1:
        LogInfo.end_track()
    return filt_tl_result
def construct_gather_linkings(el_result, tl_result, tml_result, tml_comp_result):
    """Put all entity / type / time linkings together as LinkData items.

    Order: entity linkings (category 'Entity', comparison '==', features
    ``link_feat``), then type linkings ('Type', '==', no features), then
    time linkings ('Time', comparison taken pairwise from
    ``tml_comp_result``, no features). Every item carries a display string
    with its token span [start, end) and surface score. The gathered list is
    logged before being returned.
    """
    def describe(prefix, link):
        # display line shared by entity and type linkings
        return '%s: [%d, %d) %s (%s) %.6f' % (
            prefix, link.tokens[0].index, link.tokens[-1].index + 1,
            link.entity.id.encode('utf-8'), link.name.encode('utf-8'),
            link.surface_score)

    gather_linkings = []
    for el in el_result:
        assert hasattr(el, 'link_feat')
        gather_linkings.append(
            LinkData(el, 'Entity', '==', describe('E', el), el.link_feat))
    for tl in tl_result:
        gather_linkings.append(LinkData(tl, 'Type', '==', describe('T', tl), []))
    for tml, comp in zip(tml_result, tml_comp_result):
        tm_disp = 'Tm: [%d, %d) %s %s %.6f' % (
            tml.tokens[0].index, tml.tokens[-1].index + 1, comp,
            tml.entity.sparql_name().encode('utf-8'), tml.surface_score)
        gather_linkings.append(LinkData(tml, 'Time', comp, tm_disp, []))

    LogInfo.begin_track('%d E + %d T + %d Tm = %d links.', len(el_result),
                        len(tl_result), len(tml_result), len(gather_linkings))
    for link_data in gather_linkings:
        LogInfo.logs(link_data.display)
    LogInfo.end_track()
    return gather_linkings
def pick_one_search(spec_linkings, conflict_matrix, tag_set, av_combs, spec, vb=None):
    """
    Work for T/Tm/Ord, since only one of them can be selected, no need for DFS.

    For every state in ``av_combs`` and every linking in ``spec_linkings``
    whose position is not blocked by a conflict, build the extended state and
    keep it when its tag (tag elements joined by '|') is in ``tag_set``.

    :param av_combs: list of (gl_data_indices, tag_elements, visit_arr) states.
    :param spec: 'T', 'Tm' or 'Ord'. Type linkings add 'T:<value>' as tag
        element; Tm / Ord add the bare spec.
    :param vb: verbosity (accepted tags are logged at vb >= 1). The body used
        to read ``vb`` as a free name; when None, a module-level ``vb`` is
        used if present, otherwise 0.
    :return: list of extended states, same tuple layout as ``av_combs``.
    """
    assert spec in ('T', 'Tm', 'Ord')
    if vb is None:
        vb = globals().get('vb', 0)
    LogInfo.begin_track('Searching at %s level ...', spec)
    spec_available_combs = []
    for gl_data_indices, tag_elements, visit_arr in av_combs:
        for gl_data in spec_linkings:
            gl_pos = gl_data.gl_pos
            if visit_arr[gl_pos] != 0:  # cannot be visited due to conflict
                continue
            tag_elem = spec if spec in ('Tm', 'Ord') else 'T:%s' % gl_data.value
            new_tag_elements = list(tag_elements) + [tag_elem]
            tag = '|'.join(new_tag_elements)
            if tag not in tag_set:
                continue
            if vb >= 1:
                LogInfo.logs(tag)
            # new state after applying this linking (only built for accepted tags)
            new_visit_arr = list(visit_arr)
            for conf_idx in conflict_matrix[gl_pos]:
                new_visit_arr[conf_idx] += 1
            new_gl_data_indices = list(gl_data_indices) + [gl_pos]
            spec_available_combs.append(
                (new_gl_data_indices, new_tag_elements, new_visit_arr))
    LogInfo.end_track()
    return spec_available_combs
def main():
    """Run ``work`` on a fixed set of WebQ experiments alongside Yih et al.'s results.

    Yih's test predictions are read from a tab-separated file (exactly one
    key and one value per line) into a dict, which is passed to ``work``
    for every (experiment suffix, data suffix, best epoch) triple.
    """
    qa_list = load_webq()
    yih_ret_fp = 'codalab/WebQ/acl2015-msr-stagg/test_predict.txt'
    with codecs.open(yih_ret_fp, 'r', 'utf-8') as br:
        pairs = [raw.strip().split('\t') for raw in br.readlines()]
    yih_ret_dict = {}
    for key, val in pairs:
        yih_ret_dict[key] = val
    LogInfo.logs('Yih result collected.')

    shared_data = '180508_K03_Fhalf'
    experiments = [
        ('180514_strict/all__full__180508_K03_Fhalf__depSimulate/'
         'NFix-20__wUpd_RH_qwOnly_compact__b32__fbpFalse', shared_data, 10),
        ('180516_strict/all__full__180508_K03_Fhalf__Lemmatize/'
         'NFix-20__wUpd_RH_qwOnly_compact__b32__fbpFalse', shared_data, 15),
        ('180516_strict/all__full__180508_K03_Fhalf__Lemmatize/'
         'NFix-20__wUpd_BH_qwOnly_compact__b32__fbpFalse', shared_data, 12),
    ]
    for exp_suf, data_suf, best_epoch in experiments:
        LogInfo.begin_track('Dealing with [%s], epoch = %03d:', exp_suf, best_epoch)
        work(exp_dir='runnings/WebQ/' + exp_suf,
             data_dir='runnings/candgen_WebQ/' + data_suf,
             best_epoch=best_epoch, qa_list=qa_list, yih_ret_dict=yih_ret_dict)
        LogInfo.end_track()
def get_gradient_tf_list(self, score_tf):
    """Build per-item gradients of the score w.r.t. every global variable.

    For each variable in ``tf.global_variables()`` and every cell
    (row_idx, item_idx) of ``score_tf`` over ``self.batch_size`` rows and
    ``self.list_len`` items, takes ``tf.gradients(score_tf[row, item], var)[0]``
    and stacks the results.

    :return: list with one tensor per global variable, shaped
        (batch_size, list_len, "var_shape").
    """
    LogInfo.begin_track('LambdaRank genearating gradients ... ')
    # Hoisted out of the loop: tf.gradients / tf.stack add ops, not variables.
    all_vars = tf.global_variables()
    var_num = len(all_vars)
    grad_tf_list = []  # the return value
    for scan, var in enumerate(all_vars, 1):
        LogInfo.begin_track('Variable %d / %d %s: ', scan, var_num,
                            var.get_shape().as_list())
        per_row_grad_tf_list = []
        for row_idx in range(self.batch_size):
            LogInfo.begin_track('row_idx = %d / %d: ', row_idx + 1, self.batch_size)
            local_grad_tf_list = []
            for item_idx in range(self.list_len):
                if (item_idx + 1) % 50 == 0:
                    LogInfo.logs('item_idx = %d / %d', item_idx + 1, self.list_len)
                # NOTE(review): tf.gradients yields None for a variable the score
                # does not depend on, which tf.stack cannot handle.
                local_grad_tf = tf.gradients(score_tf[row_idx, item_idx], var)[0]  # ("var_shape", )
                local_grad_tf_list.append(local_grad_tf)
            # per_row_grad_tf: (list_len, "var_shape")
            per_row_grad_tf_list.append(tf.stack(local_grad_tf_list, axis=0))
            LogInfo.end_track()  # row track
        grad_tf = tf.stack(per_row_grad_tf_list, axis=0)
        grad_tf_list.append(grad_tf)
        # grad_tf: (batch_size, list_len, "var_shape")
        LogInfo.logs('grad_tf: %s', grad_tf.get_shape().as_list())
        LogInfo.end_track()  # variable track (was never closed before)
    LogInfo.end_track()  # outer track
    return grad_tf_list
def analyze_output(data_dir, sort_item):
    """ Check the rank distribution in a global perspective.

    Reads ``<data_dir>/lexicon_validate/srt_output.<sort_item>.txt``; on each
    tab-separated line the last two fields are F1 and rank. A rank of -1 is
    replaced by 2111222333 so that it adds almost nothing to the MRR. Ranks
    are grouped into 4 cumulative tiers: F1 = 1.0, F1 >= 0.5, F1 >= 0.1,
    F1 >= 1e-6. For each tier the case count, MRR and rank percentiles are
    logged; empty tiers only report their count, since np.percentile raises
    on empty input.
    """
    tier_ths = (1.0, 0.5, 0.1, 1e-6)
    rank_matrix = [[] for _ in tier_ths]
    fp = '%s/lexicon_validate/srt_output.%s.txt' % (data_dir, sort_item)
    with codecs.open(fp, 'r', 'utf-8') as br:
        for line in br.readlines():
            spt = line.strip().split('\t')
            f1 = float(spt[-2])
            rank = int(spt[-1])
            if rank == -1:
                rank = 2111222333
            if f1 == 1.0:
                rank_matrix[0].append(rank)
            if f1 >= 0.5:
                rank_matrix[1].append(rank)
            if f1 >= 0.1:
                rank_matrix[2].append(rank)
            if f1 >= 1e-6:
                rank_matrix[3].append(rank)
    for ths, rank_list in zip(tier_ths, rank_matrix):
        LogInfo.begin_track('Show stat for F1 >= %.6f:', ths)
        rank_arr = np.array(rank_list)
        case_size = len(rank_arr)
        LogInfo.logs('Total cases = %d.', case_size)
        if case_size > 0:
            LogInfo.logs('MRR = %.6f', np.mean(1. / rank_arr))
            for pos in (50, 60, 70, 80, 90, 95, 99, 99.9, 100):
                LogInfo.logs('Percentile = %.1f%%: %.6f', pos, np.percentile(rank_arr, pos))
        LogInfo.end_track()
def load_all_data(self):
    """Load the smart schemas (pickle dump if present, txt otherwise) and log meta statistics.

    Does nothing when ``self.smart_q_cand_dict`` already holds candidates.
    The statistics cover question counts, active word / mid / path
    dictionary sizes, length limits, the candidate-size distribution and the
    question-length distribution. Empty distributions skip the mean and
    percentile lines, since np.percentile raises on empty input.
    """
    # Truthiness covers an empty dict and a dict that is still None
    # (len(None) would raise TypeError).
    if self.smart_q_cand_dict:  # already loaded
        return
    if os.path.isfile(self.dump_fp):
        self.load_smart_schemas_from_pickle()
    else:  # no dump, read from txt
        self.load_smart_schemas_from_txt()

    percentiles = (25, 50, 75, 90, 95, 99, 99.9, 100)
    LogInfo.begin_track('Meta statistics:')
    LogInfo.logs('Total questions = %d', len(self.q_idx_list))
    LogInfo.logs('T / v / t questions = %s', [len(lst) for lst in self.spt_q_idx_lists])
    LogInfo.logs(
        'Active Word / Mid / Path = %d / %d / %d (with PAD, START, UNK)',
        len(self.active_dicts['word']), len(self.active_dicts['mid']),
        len(self.active_dicts['path']))
    LogInfo.logs('path_max_size = %d, qw_max_len = %d, pw_max_len = %d.',
                 self.path_max_size, self.qw_max_len, self.pw_max_len)

    cand_size_dist = np.array([len(v) for v in self.smart_q_cand_dict.values()])
    if len(cand_size_dist) > 0:
        LogInfo.logs('Total schemas = %d, avg = %.6f.',
                     np.sum(cand_size_dist), np.mean(cand_size_dist))
        for pos in percentiles:
            LogInfo.logs('cand_size @ %.1f%%: %.6f', pos, np.percentile(cand_size_dist, pos))

    qlen_dist = np.array([len(qa['tokens']) for qa in self.qa_list])
    if len(qlen_dist) > 0:
        LogInfo.logs('Avg question length = %.6f.', np.mean(qlen_dist))
        for pos in percentiles:
            LogInfo.logs('question_len @ %.1f%%: %.6f', pos, np.percentile(qlen_dist, pos))
    LogInfo.end_track()
def load_smart_cands(self):
    """Load smart candidates and lookup tables, then show meta statistics.

    Does nothing when ``self.smart_q_cand_dict`` is already set. Without the
    dump file ``self.load_smart_schemas_from_txt()`` is used. Otherwise the
    dump is unpickled in this order: smart_q_cand_dict, path_idx_dict,
    entity_idx_dict, type_idx_dict, pw_voc_inputs, pw_voc_length,
    pw_voc_domain, entity_type_matrix. ``pw_max_len`` and ``q_idx_list`` are
    then derived from the loaded data.
    """
    if self.smart_q_cand_dict is not None:  # already loaded
        return
    if os.path.isfile(self.dump_fp):
        LogInfo.begin_track('Loading smart_candidates from [%s] ...', self.dump_fp)
        with open(self.dump_fp, 'rb') as br:
            LogInfo.begin_track('Loading smart_q_cand_dict ... ')
            self.smart_q_cand_dict = cPickle.load(br)
            LogInfo.logs('Candidates for %d questions loaded.', len(self.smart_q_cand_dict))
            sizes = np.array([len(cands) for cands in self.smart_q_cand_dict.values()])
            LogInfo.logs('Total schemas = %d, avg = %.6f.', np.sum(sizes), np.mean(sizes))
            for pct in (25, 50, 75, 90, 95, 99, 99.9, 100):
                LogInfo.logs('Percentile = %.1f%%: %.6f', pct, np.percentile(sizes, pct))
            LogInfo.end_track()

            self.path_idx_dict = cPickle.load(br)
            self.entity_idx_dict = cPickle.load(br)
            self.type_idx_dict = cPickle.load(br)
            LogInfo.logs('Active E/T/Path dict loaded.')

            self.pw_voc_inputs = cPickle.load(br)  # (path_voc, pw_max_len)
            self.pw_voc_length = cPickle.load(br)  # (path_voc,)
            self.pw_voc_domain = cPickle.load(br)  # (path_voc,)
            self.entity_type_matrix = cPickle.load(br)  # (entity_voc, type_voc)
            self.pw_max_len = self.pw_voc_inputs.shape[1]
            LogInfo.logs('path word & entity_type lookup tables loaded.')
        self.q_idx_list = sorted(self.smart_q_cand_dict.keys())
        LogInfo.end_track()  # end of loading
    else:  # no dump, read from txt
        self.load_smart_schemas_from_txt()
    self.meta_stat()  # show meta statistics
def batch_schema_f1_query(self, q_id, states, level):
    """
    perform F1 query for each schema in the state.

    The lookup key joins q_id, the schema's SPARQL string and its auxiliary
    time / ordinal / aggregation parts with '|'. ``self.query_srv`` answers
    with (ans_size, p, r, f1), which are written onto the schema.
    :param q_id: WebQ-xxx / CompQ-xxx
    :param states: [(schema, visit_arr)]
    :param level: coarse / typed / timed / ordinal
    :return: filtered states (as a list) where each schema returns at least one answer.
    """
    Tt.start('%s_F1' % level)
    state_num = len(states)
    LogInfo.begin_track('Calculating F1 for %d %s schemas:', state_num, level)
    for idx, (sc, _) in enumerate(states):
        if idx % 100 == 0:
            LogInfo.logs('Current: %d / %d', idx, state_num)
        sparql_str = sc.build_sparql(simple_time_match=self.simple_time_match)
        tm_comp, tm_value, ord_comp, ord_rank, agg = sc.build_aux_for_sparql()
        # won't specify forever if there is no time constraint
        allow_forever = self.allow_forever if tm_comp != 'None' else ''
        q_sc_key = '|'.join([q_id, sparql_str, tm_comp, tm_value, allow_forever,
                             ord_comp, ord_rank, agg])
        if self.vb >= 2:
            LogInfo.begin_track('Checking schema %d / %d:', idx, state_num)
            LogInfo.logs(sc.disp_raw_path())
            LogInfo.logs('var_types: %s', sc.var_types)
            LogInfo.logs(sparql_str)
        Tt.start('query_q_sc_stat')
        sc.ans_size, sc.p, sc.r, sc.f1 = self.query_srv.query_q_sc_stat(q_sc_key)
        Tt.record('query_q_sc_stat')
        if self.vb >= 2:
            LogInfo.logs('Answers = %d, P = %.6f, R = %.6f, F1 = %.6f',
                         sc.ans_size, sc.p, sc.r, sc.f1)
            LogInfo.end_track()
    # A list comprehension keeps len() working on Python 3 (filter() is lazy there).
    filt_states = [tup for tup in states if tup[0].ans_size > 0]
    LogInfo.end_track('%d / %d %s schemas kept with ans_size > 0.',
                      len(filt_states), state_num, level)
    Tt.record('%s_F1' % level)
    return filt_states
def forward(self, item_wd_embedding, item_len, reuse=None):
    """Encode item word embeddings with the BiRNN encoder and average over time.

    :param item_wd_embedding: tensor unstacked along axis 1 into
        ``self.item_max_len`` time stamps.
    :param item_len: per-item lengths, used as the RNN sequence_length and
        as the averaging denominator (clipped to at least 1).
    :return: item_wd_hidden, (data_size, n_hidden_emb) per the shape comments.
    """
    LogInfo.begin_track('ItemBiRNNModule forward: ')
    with tf.variable_scope('ItemBiRNNModule', reuse=reuse):
        n_stamps = self.item_max_len
        show_tensor(item_wd_embedding)
        # a list of n_stamps tensors, each (batch, n_emb)
        step_inputs = tf.unstack(item_wd_embedding, num=n_stamps, axis=1,
                                 name='birnn_inputs')
        encoded = self.rnn_encoder.encode(inputs=step_inputs,
                                          sequence_length=item_len,
                                          reuse=reuse)
        # (data_size, q_len, n_hidden_emb)
        step_outputs = tf.stack(encoded.outputs, axis=1, name='birnn_outputs')
        LogInfo.logs('birnn_output = %s', step_outputs.get_shape().as_list())
        summed = tf.reduce_sum(step_outputs, axis=1)  # (data_size, n_hidden_emb)
        len_col = tf.cast(tf.expand_dims(item_len, axis=1),
                          dtype=tf.float32)  # (data_size, 1) as float
        # max(len, 1) avoids dividing by 0
        item_wd_hidden = tf.div(summed, tf.maximum(len_col, 1),
                                name='item_wd_hidden')  # (data_size, n_hidden_emb)
        LogInfo.logs('item_wd_hidden = %s', item_wd_hidden.get_shape().as_list())
    LogInfo.end_track()
    return item_wd_hidden
def __init__(self, use_sparql_cache=True, data_mode='Ordinal', sc_mode='Skeleton',
             root_path='/home/kangqi/workspace/PythonProject',
             cache_dir='runnings/compQA/cache'):
    """Set up the QA data, the candidate generator and the loss calculator.

    :param use_sparql_cache: forwarded to CandidateGenerator.
    :param data_mode: 'Ordinal' (load_complex_questions_ordinal_only) or
        'ComplexQuestions' (load_complex_questions); the loaded data is split
        into train / test question lists.
    :param sc_mode: 'Skeleton' or 'Sk+Ordinal', stored as ``self.sc_mode``.
    :param root_path: not used by the live code.
    :param cache_dir: not used by the live code.
    """
    LogInfo.begin_track('Initializing InputGenerator ... ')
    assert data_mode in ('Ordinal', 'ComplexQuestions')
    assert sc_mode in ('Skeleton', 'Sk+Ordinal')
    self.data_mode = data_mode
    self.sc_mode = sc_mode

    if self.data_mode == 'Ordinal':
        qa_loader = load_complex_questions_ordinal_only
    elif self.data_mode == 'ComplexQuestions':
        qa_loader = load_complex_questions
    else:
        qa_loader = None  # unreachable while the assert above is active
    if qa_loader is None:
        LogInfo.logs('Unknown data mode: %s', self.data_mode)
    else:
        self.qa_data = qa_loader()
        self.train_qa_list, self.test_qa_list = self.qa_data

    self.cand_gen = CandidateGenerator(use_sparql_cache=use_sparql_cache)
    self.loss_calc = LossCalculator(driver=self.cand_gen.driver)
    LogInfo.end_track()
def main(args):
    """Generate SimpleQuestions candidates for every SimpQ question.

    For question ``q_idx``, the schema and link files are written to
    ``<output_dir>/data/<k*1000>-<k*1000+999>/<q_idx>_schema`` and
    ``<q_idx>_links``. A question is skipped when its schema file exists.
    """
    qa_list = load_simpq()
    wd_emb_util = WordEmbeddingUtil(wd_emb=args.word_emb, dim_emb=args.dim_emb)
    lex_name = args.lex_name
    cand_gen = SimpleQCandidateGenerator(fb_subset=args.fb_subset,
                                         lex_name=lex_name,
                                         wd_emb_util=wd_emb_util,
                                         vb=args.verbose)
    for q_idx, qa in enumerate(qa_list):
        LogInfo.begin_track('Entering Q %d / %d [%s]:',
                            q_idx, len(qa_list), qa['utterance'].encode('utf-8'))
        # Floor division: with '/', Python 3 gives sub_idx == q_idx (a float),
        # which puts every question into its own wrongly named directory.
        sub_idx = q_idx // 1000 * 1000
        sub_dir = '%s/data/%d-%d' % (args.output_dir, sub_idx, sub_idx + 999)
        if not os.path.exists(sub_dir):
            os.makedirs(sub_dir)
        opt_sc_fp = '%s/%d_schema' % (sub_dir, q_idx)
        link_fp = '%s/%d_links' % (sub_dir, q_idx)
        if os.path.isfile(opt_sc_fp):
            LogInfo.end_track('Skip this question, already saved.')
            continue
        Tt.start('single_q')
        cand_gen.single_question_candgen(q_idx=q_idx, qa=qa,
                                         link_fp=link_fp, opt_sc_fp=opt_sc_fp)
        Tt.record('single_q')
        LogInfo.end_track()  # End of Q
def start_entity_search(entity_linkings, conflict_matrix, tag_set):
    """Start the M/E-level DFS with each entity linking as the main focus.

    For every main focus, the linkings that conflict with it are marked in a
    fresh visit array. Then, for each of its types (``get_entity_type``), the
    tag element 'M:<type>' is pushed and ``entity_search_dfs`` is run. That
    function appends results to the shared ``entity_available_combs`` list.
    The tag element is popped again afterwards. The visit array and index
    list are shared by all types of one focus.

    :return: entity_available_combs, as filled by entity_search_dfs.
    """
    LogInfo.begin_track('Searching at M/E level ...')
    entity_available_combs = []  # the return value
    el_size = len(entity_linkings)
    gl_size = len(conflict_matrix)
    for mf_idx, main_focus in enumerate(entity_linkings):
        focus_pos = main_focus.gl_pos
        visit_arr = [0] * gl_size
        for conf_idx in conflict_matrix[focus_pos]:
            visit_arr[conf_idx] += 1
        # create the initial state of search
        gl_data_indices = [focus_pos]
        tag_elements = []
        type_list = get_entity_type(main_focus.value)
        tp_total = len(type_list)
        for tp_idx, tp in enumerate(type_list):
            state_marker = ['M%d/%d-(t%d/%d)' % (mf_idx + 1, el_size, tp_idx + 1, tp_total)]
            tag_elements.append('M:%s' % tp)
            entity_search_dfs(entity_linkings=entity_linkings,
                              conflict_matrix=conflict_matrix,
                              tag_set=tag_set,
                              cur_el_idx=-1,
                              gl_data_indices=gl_data_indices,
                              tag_elements=tag_elements,
                              visit_arr=visit_arr,
                              entity_available_combs=entity_available_combs,
                              state_marker=state_marker)
            tag_elements.pop()
    LogInfo.end_track()
    return entity_available_combs
def make_combination(gather_linkings, sparql_driver, vb):
    """
    Given the E/T/Tm linkings, return all the possible combinations of query structure.
    The only restriction: multiple linkings with overlapped mentions can't be used together
    (and at most one Type linking may be used, see Step 1).
    ** Used in either WebQ or CompQ, not SimpQ **
    :param gather_linkings: list of named_tuple (detail, category, comparison, display)
    :param sparql_driver: the sparql query engine
    :param vb: verbose
    :return: ground_comb_list, the list filled by search_start for path lengths
        1 and 2; per the inline note its items are (comb, path_len, sparql_query_ret).
    """
    sz = len(gather_linkings)
    # A list comprehension keeps len() working on Python 3, where filter() is lazy.
    el_size = len([x for x in gather_linkings if x.category == 'Entity'])

    # Step 1: Prepare conflict matrix.
    # conflict_matrix[i] lists every j that cannot be used together with i:
    # their mentions overlap, or both are Type linkings.
    # 180205: the Type/Type restriction was added for saving time -- there
    # should be only one type constraint in the schema; don't make the task
    # even more complex.
    conflict_matrix = []
    for i in range(sz):
        local_conf_list = []
        for j in range(sz):
            if is_overlap(gather_linkings[i].detail, gather_linkings[j].detail):
                local_conf_list.append(j)
            elif gather_linkings[i].category == 'Type' \
                    and gather_linkings[j].category == 'Type':
                local_conf_list.append(j)
        conflict_matrix.append(local_conf_list)

    # Step 2: start combination searching
    LogInfo.begin_track(
        'Starting searching combination (total links = %d, entities = %d):',
        len(gather_linkings), el_size)
    ground_comb_list = []  # [ (comb, path_len, sparql_query_ret) ]
    for path_len in (1, 2):
        for mf_idx, main_focus in enumerate(gather_linkings):
            if main_focus.category != 'Entity':
                continue
            # indicating how many conflicts at the particular searching state
            visit_arr = [0] * sz
            state_marker = ['Path-%d||F%d/%d' % (path_len, mf_idx + 1, el_size)]
            cur_comb = [(0, mf_idx)]  # indicating the focus entity
            for conf_idx in conflict_matrix[mf_idx]:
                visit_arr[conf_idx] += 1
            search_start(path_len=path_len,
                         gather_linkings=gather_linkings,
                         sparql_driver=sparql_driver,
                         cur_idx=-1,
                         cur_comb=cur_comb,
                         conflict_matrix=conflict_matrix,
                         visit_arr=visit_arr,
                         ground_comb_list=ground_comb_list,
                         state_marker=state_marker,
                         vb=vb)
    LogInfo.end_track()
    return ground_comb_list
def main(args):
    """Generate SMART candidates for WebQ / CompQ questions.

    With ``args.linking_only`` every question is processed and no
    QueryService is created. Otherwise only the questions
    [group_idx * 100, group_idx * 100 + 100) are processed, using the
    per-group SPARQL and q_sc_stat cache files. Per-question outputs go to
    ``<output_dir>/data/<k*100>-<k*100+99>/``; a question is skipped when
    its schema file already exists.
    """
    assert args.data_name in ('WebQ', 'CompQ')
    if args.data_name == 'WebQ':
        qa_list = load_webq()
    else:
        qa_list = load_compq()

    if args.linking_only:
        query_srv = None
        q_start = 0
        q_end = len(qa_list)
    else:
        group_idx = args.group_idx
        q_start = group_idx * 100
        q_end = group_idx * 100 + 100
        if args.data_name == 'CompQ':
            sparql_cache_fp = 'runnings/acl18_cache/group_cache/sparql.g%02d.cache' % group_idx
            q_sc_cache_fp = 'runnings/acl18_cache/group_cache/q_sc_stat.g%02d.cache' % group_idx
        else:
            sparql_cache_fp = 'runnings/acl18_cache/group_cache_%s/sparql.g%02d.cache' % (
                args.data_name, group_idx)
            q_sc_cache_fp = 'runnings/acl18_cache/group_cache_%s/q_sc_stat.g%02d.cache' % (
                args.data_name, group_idx)
        query_srv = QueryService(sparql_cache_fp=sparql_cache_fp,
                                 qsc_cache_fp=q_sc_cache_fp, vb=1)

    wd_emb_util = WordEmbeddingUtil(wd_emb=args.word_emb, dim_emb=args.dim_emb)
    smart_cand_gen = SMARTCandidateGenerator(data_name=args.data_name,
                                             lex_name=args.lex_name,
                                             wd_emb_util=wd_emb_util,
                                             query_srv=query_srv,
                                             allow_forever=args.allow_forever,
                                             vb=args.verbose,
                                             simple_type_match=args.simple_type,
                                             simple_time_match=args.simple_time)

    for q_idx, qa in enumerate(qa_list):
        if q_idx < q_start or q_idx >= q_end:
            continue
        LogInfo.begin_track('Entering Q %d / %d [%s]:',
                            q_idx, len(qa_list), qa['utterance'].encode('utf-8'))
        # Floor division: '/' is true division on Python 3 and would give a
        # wrong sub-directory.
        sub_idx = q_idx // 100 * 100
        sub_dir = '%s/data/%d-%d' % (args.output_dir, sub_idx, sub_idx + 99)
        if not os.path.exists(sub_dir):
            os.makedirs(sub_dir)
        opt_sc_fp = '%s/%d_schema' % (sub_dir, q_idx)
        link_fp = '%s/%d_links' % (sub_dir, q_idx)
        if os.path.isfile(opt_sc_fp):
            LogInfo.end_track('Skip this question, already saved.')
            continue
        Tt.start('single_q')
        smart_cand_gen.single_question_candgen(q_idx=q_idx, qa=qa,
                                               link_fp=link_fp,
                                               opt_sc_fp=opt_sc_fp,
                                               linking_only=args.linking_only)
        Tt.record('single_q')
        LogInfo.end_track()  # End of Q
def save_dicts(self):
    """Pickle ``self.word_dict`` and then ``self.mid_dict`` into ``self.dict_fp``.

    Both objects go into one file, so a reader must unpickle them in the
    same order from one handle.
    """
    LogInfo.begin_track('Saving actual word/mid dict into [%s] ...', self.dict_fp)
    # Binary mode: pickle output is bytes. Text mode breaks binary protocols
    # on some platforms and fails outright on Python 3.
    with open(self.dict_fp, 'wb') as bw:
        cPickle.dump(self.word_dict, bw)
        LogInfo.logs('%d w_dict saved.', len(self.word_dict))
        cPickle.dump(self.mid_dict, bw)
        LogInfo.logs('%d e_dict saved.', len(self.mid_dict))
    LogInfo.end_track()
def save_init_emb(self):
    """Save the initial word and mid embedding matrices with ``np.savez``.

    The arrays are stored under the keys ``word_init_emb`` and
    ``mid_init_emb`` in ``self.init_mat_fp``; their shapes are logged.
    """
    target_fp = self.init_mat_fp
    LogInfo.begin_track('Saving initial embedding matrix into [%s] ...', target_fp)
    word_emb = self.word_init_emb
    mid_emb = self.mid_init_emb
    np.savez(target_fp, word_init_emb=word_emb, mid_init_emb=mid_emb)
    for label, mat in (('Word', word_emb), ('Mid', mid_emb)):
        LogInfo.logs(label + ' embedding: %s saved.', mat.shape)
    LogInfo.end_track()
def main_default(qa_list, linking_wrapper, linking_cache, sparql_driver, simpq_sp_dict, args):
    """Run candidate generation for the questions in [args.q_start, args.q_end).

    For each question, three files are written into
    ``<output_dir>/data/<k*100>-<k*100+99>/``:
    - ``<q_idx>_links``: the pickled linkings;
    - ``<q_idx>_schema``: one JSON object per line;
    - ``<q_idx>_ans``: one JSON object per line, aligned with the schema lines.
    Each file is first written with a '.tmp' suffix and then moved into
    place. A question is skipped when all three files already exist.
    """
    q_size = len(qa_list)
    for q_idx, qa in enumerate(qa_list):
        if q_idx < args.q_start or q_idx >= args.q_end:
            continue
        LogInfo.begin_track('Entering Q %d / %d [%s]:',
                            q_idx, q_size, qa['utterance'].encode('utf-8'))
        # Floor division: '/' is true division on Python 3.
        sub_idx = q_idx // 100 * 100
        sub_dir = '%s/data/%d-%d' % (args.output_dir, sub_idx, sub_idx + 99)
        if not os.path.exists(sub_dir):
            os.makedirs(sub_dir)
        save_schema_fp = '%s/%d_schema' % (sub_dir, q_idx)
        save_ans_fp = '%s/%d_ans' % (sub_dir, q_idx)
        save_link_fp = '%s/%d_links' % (sub_dir, q_idx)
        if os.path.isfile(save_schema_fp) \
                and os.path.isfile(save_ans_fp) \
                and os.path.isfile(save_link_fp):
            LogInfo.end_track('Skip this question, already saved.')
            continue

        target_value = qa['targetValue']
        gather_linkings, schema_json_list, schema_ans_list = \
            process_single_question(data_name=args.data_name,
                                    q=qa['utterance'],
                                    q_idx=q_idx,
                                    target_value=target_value,
                                    linking_wrapper=linking_wrapper,
                                    linking_cache=linking_cache,
                                    ground_truth=args.ground_truth,
                                    sparql_driver=sparql_driver,
                                    simpq_sp_dict=simpq_sp_dict,
                                    vb=args.verbose)
        # pickle output is binary data
        with open(save_link_fp + '.tmp', 'wb') as bw_link:
            cPickle.dump(gather_linkings, bw_link)
        # Context managers close both files even if json.dump raises.
        with open(save_schema_fp + '.tmp', 'w') as bw_sc, \
                open(save_ans_fp + '.tmp', 'w') as bw_ans:
            for schema_json, predict_list in zip(schema_json_list, schema_ans_list):
                json.dump(predict_list, bw_ans)
                bw_ans.write('\n')
                json.dump(schema_json, bw_sc)
                bw_sc.write('\n')
        LogInfo.logs('Schema/Answer/Link saving complete.')
        shutil.move(save_schema_fp + '.tmp', save_schema_fp)
        shutil.move(save_link_fp + '.tmp', save_link_fp)
        shutil.move(save_ans_fp + '.tmp', save_ans_fp)
        # release large per-question results before the next question
        del gather_linkings
        del schema_ans_list
        del schema_json_list
        LogInfo.end_track()  # End of Q
def answer_collection(args):
    """Collect the predicted answers of the selected schemas into one file.

    Reads ``<result_dir>/test_schema_<epoch>.txt``, whose tab-separated lines
    hold q_idx and the selected schema's row index. For each selected schema
    it fetches the matching line of
    ``<data_dir>/<k*100>-<k*100+99>/<q_idx>_ans``. It then writes one line
    per question, "utterance<TAB>json(answers)", to
    ``<result_dir>/test_predict_<epoch>.txt``. Questions without a selected
    schema get an empty answer list.

    When both ``args.q_st`` and ``args.q_ed`` are -1, a default range is
    used: WebQ [3778, 5810), CompQ [1300, 2100). CompQ answer lines are
    lower-cased before parsing.
    """
    assert args.data_name in ('WebQ', 'CompQ')
    result_dir = args.result_dir
    data_dir = args.data_dir
    epoch = args.epoch
    st = args.q_st
    ed = args.q_ed
    if args.data_name == 'WebQ':
        qa_list = load_webq()
        if st == ed == -1:
            st = 3778
            ed = 5810
    else:
        qa_list = load_compq()
        if st == ed == -1:
            st = 1300
            ed = 2100
    select_schema_fp = '%s/test_schema_%03d.txt' % (result_dir, epoch)
    select_dict = {}  # <q_idx, sc.ori_idx>
    test_predict_fp = '%s/test_predict_%03d.txt' % (result_dir, epoch)
    LogInfo.begin_track('Collecting answers from %s ...', select_schema_fp)
    LogInfo.logs('Original schema data: %s', data_dir)
    with open(select_schema_fp, 'r') as br:
        for line in br.readlines():
            spt = line.strip().split('\t')
            select_dict[int(spt[0])] = int(spt[1])

    save_tups = []
    for scan_idx, q_idx in enumerate(range(st, ed)):
        if scan_idx % 50 == 0:
            LogInfo.logs('%d questions scanned.', scan_idx)
        utterance = qa_list[q_idx]['utterance']
        sc_idx = select_dict.get(q_idx, -1)
        if sc_idx == -1:
            ans_list = []  # no answer at all
        else:
            # Floor division: '/' is true division on Python 3 and would
            # produce a non-existent sub-directory name.
            div = q_idx // 100
            sub_dir = '%d-%d' % (div * 100, div * 100 + 99)
            ans_fp = '%s/%s/%d_ans' % (data_dir, sub_dir, q_idx)
            # sc_idx is 0-based, linecache line numbers are 1-based
            ans_line = linecache.getline(ans_fp, lineno=sc_idx + 1).strip()
            LogInfo.logs('q_idx=%d, lineno=%d, content=[%s]', q_idx, sc_idx + 1, ans_line)
            if args.data_name == 'CompQ':
                ans_line = ans_line.lower()
            ans_list = json.loads(ans_line)
        save_tups.append((utterance, ans_list))

    with codecs.open(test_predict_fp, 'w', 'utf-8') as bw:
        for utterance, ans_list in save_tups:
            bw.write(utterance + '\t')
            json.dump(ans_list, bw)
            bw.write('\n')
    LogInfo.end_track('Predicting results saved to %s.', test_predict_fp)
def renew_data_list(self):
    """
    Target: construct the np_data_list, storing all <q, sc> data.

    For every question of ``self.mode`` that has candidates in
    ``self.dataset.q_cand_dict``, each candidate schema contributes one row.
    ``self.cand_tup_list`` holds the matching ``(q_idx, sc)`` pairs in the
    same row order.  Each entry of ``self.np_data_list`` is an int32 array
    built from one of the tensors returned by
    ``self.dataset.get_schema_tensor_inputs(sc)``.

    np_data_list contains (order as produced by get_schema_tensor_inputs):
        1. q_words      (data_size, q_max_len)
        2. q_words_len  (data_size, )
        3. sc_len       (data_size, )
        4. preds        (data_size, sc_max_len, path_max_len)
        5. preds_len    (data_size, sc_max_len)
        6. pwords       (data_size, sc_max_len, pword_max_len)
        7. pwords_len   (data_size, sc_max_len)

    Also sets ``self.question_size`` to the number of questions that have
    candidates, and calls ``self.update_statistics()``.
    """
    if self.verbose >= 1:
        LogInfo.begin_track(
            '[CompqSingleDataLoader] prepare data for [%s] ...', self.mode)
    q_idx_list = get_q_range_by_mode(data_name=self.dataset.data_name,
                                     mode=self.mode)
    # A list comprehension (instead of filter()) always yields a real list,
    # so the len() calls below also work on Python 3.
    filt_q_idx_list = [q for q in q_idx_list
                       if q in self.dataset.q_cand_dict]
    self.question_size = len(filt_q_idx_list)  # Got all the related questions

    # One pool per tensor input:
    # [ qwords, qwords_len, sc_len, preds, preds_len, pwords, pwords_len ]
    emb_pools = [[] for _ in range(self.dataset.array_num)]

    # Different from the original complex DataLoader:
    # no schemas share the same qwords, so each <q, sc> pair is a full row.
    self.cand_tup_list = []
    for scan_idx, q_idx in enumerate(filt_q_idx_list):
        if self.verbose >= 1 and scan_idx % self.proc_ob_num == 0:
            LogInfo.logs('%d / %d prepared.', scan_idx, len(filt_q_idx_list))
        cand_list = self.dataset.q_cand_dict[q_idx]
        for sc in cand_list:
            self.cand_tup_list.append((q_idx, sc))
            # Copy this schema's inputs (qwords, preds, pwords, ...) into the
            # per-tensor pools of the data loader.
            sc_tensor_inputs = self.dataset.get_schema_tensor_inputs(sc)
            for local_list, sc_tensor in zip(emb_pools, sc_tensor_inputs):
                local_list.append(sc_tensor)

    # Finally: merge inputs together and produce the final np_data_list.
    self.np_data_list = [np.array(target_list, dtype='int32')
                         for target_list in emb_pools]
    self.update_statistics()
    assert len(self.cand_tup_list) == self.np_data_list[0].shape[0]

    if self.verbose >= 1:
        LogInfo.logs('%d <q, sc> data collected.',
                     self.np_data_list[0].shape[0])
        LogInfo.end_track()
def post_process(self, eval_dl, detail_fp, result_fp):
    """Rank candidates per question, compute average F1, and optionally
    dump a detailed report.

    Each candidate in ``eval_dl.eval_sc_tup_list`` gets a ``run_info`` dict
    whose values come from the row at the same position in
    ``self.eval_detail_dict`` (one entry per name in
    ``self.concern_name_list``).  Candidates are grouped by question and
    sorted by ``run_info['full_score']`` in descending order.  A question
    scores the F1 of its top-ranked candidate.

    :param eval_dl: evaluation data loader; reads eval_sc_tup_list,
        total_questions and, when detail_fp is given, schema_dataset.
    :param detail_fp: if not None, path of a text report listing, per
        question, the top-20 candidates plus any candidate achieving the best
        F1 for that question.
    :param result_fp: not used by this method.
    :return: ``[avg_f1]``, where avg_f1 is the summed top-1 F1 divided by
        ``eval_dl.total_questions``.  Questions without candidates therefore
        count as 0.  The grouped, sorted candidates are also stored in
        ``self.ret_q_score_dict``.
    """
    import sys  # sys.maxsize: "never summarize" numpy print threshold

    if len(self.eval_detail_dict) == 0:
        self.eval_detail_dict = {k: [] for k in self.concern_name_list}

    # Put all output results into sc.run_info and group candidates by q_idx.
    ret_q_score_dict = {}
    for scan_idx, (q_idx, cand) in enumerate(eval_dl.eval_sc_tup_list):
        cand.run_info = {
            k: self.eval_detail_dict[k][scan_idx]
            for k in self.concern_name_list
        }
        ret_q_score_dict.setdefault(q_idx, []).append(cand)

    f1_list = []
    for score_list in ret_q_score_dict.values():
        # Sort by score DESC; the list is reused (already sorted) below.
        score_list.sort(key=lambda x: x.run_info['full_score'], reverse=True)
        f1_list.append(score_list[0].f1 if score_list else 0.)
    LogInfo.logs('[%3s] Predict %d out of %d questions.',
                 self.name, len(f1_list), eval_dl.total_questions)
    ret_metric = np.sum(f1_list).astype('float32') / eval_dl.total_questions

    if detail_fp is not None:
        schema_dataset = eval_dl.schema_dataset
        # Save current numpy print options: np.set_printoptions() called with
        # no arguments changes nothing, so they must be restored explicitly.
        old_print_opts = np.get_printoptions()
        with open(detail_fp, 'w') as bw:
            LogInfo.redirect(bw)
            try:
                # Newer numpy rejects threshold=np.nan; sys.maxsize likewise
                # disables array summarization.
                np.set_printoptions(threshold=sys.maxsize)
                LogInfo.logs('Avg_f1 = %.6f', ret_metric)
                for q_idx in sorted(ret_q_score_dict.keys()):
                    qa = schema_dataset.qa_list[q_idx]
                    q = qa['utterance']
                    LogInfo.begin_track('Q-%04d [%s]:',
                                        q_idx, q.encode('utf-8'))
                    srt_list = ret_q_score_dict[q_idx]  # already sorted
                    best_label_f1 = np.max([sc.f1 for sc in srt_list])
                    # Floor at 1e-6: when every candidate has F1 = 0, no extra
                    # candidates beyond the top 20 are printed.
                    best_label_f1 = max(best_label_f1, 0.000001)
                    for rank, sc in enumerate(srt_list):
                        if rank < 20 or sc.f1 == best_label_f1:
                            LogInfo.begin_track(
                                '#-%04d [F1 = %.6f] [row_in_file = %d]',
                                rank + 1, sc.f1, sc.ori_idx)
                            LogInfo.logs('full_score: %.6f',
                                         sc.run_info['full_score'])
                            show_overall_detail(sc)
                            LogInfo.end_track()
                    LogInfo.end_track()
                LogInfo.logs('Avg_f1 = %.6f', ret_metric)
            finally:
                # Always undo global state, even if the dump failed.
                np.set_printoptions(**old_print_opts)  # reset output format
                LogInfo.stop_redirect()

    self.ret_q_score_dict = ret_q_score_dict
    return [ret_metric]