def main(args):
    qa_list = load_simpq()
    wd_emb_util = WordEmbeddingUtil(wd_emb=args.word_emb, dim_emb=args.dim_emb)
    lex_name = args.lex_name
    # lex_name = 'filter-%s-xs-lex-score' % args.fb_subset
    cand_gen = SimpleQCandidateGenerator(fb_subset=args.fb_subset, lex_name=lex_name,
                                         wd_emb_util=wd_emb_util, vb=args.verbose)

    for q_idx, qa in enumerate(qa_list):
        LogInfo.begin_track('Entering Q %d / %d [%s]:', q_idx, len(qa_list),
                            qa['utterance'].encode('utf-8'))
        sub_idx = q_idx / 1000 * 1000
        sub_dir = '%s/data/%d-%d' % (args.output_dir, sub_idx, sub_idx + 999)
        if not os.path.exists(sub_dir):
            os.makedirs(sub_dir)
        opt_sc_fp = '%s/%d_schema' % (sub_dir, q_idx)
        link_fp = '%s/%d_links' % (sub_dir, q_idx)
        if os.path.isfile(opt_sc_fp):
            LogInfo.end_track('Skip this question, already saved.')
            continue
        Tt.start('single_q')
        cand_gen.single_question_candgen(q_idx=q_idx, qa=qa,
                                         link_fp=link_fp, opt_sc_fp=opt_sc_fp)
        Tt.record('single_q')
        LogInfo.end_track()     # End of Q
def main(args):
    assert args.data_name in ('WebQ', 'CompQ')
    if args.data_name == 'WebQ':
        qa_list = load_webq()
    else:
        qa_list = load_compq()

    if args.linking_only:
        query_srv = None
        q_start = 0
        q_end = len(qa_list)
    else:
        group_idx = args.group_idx
        q_start = group_idx * 100
        q_end = group_idx * 100 + 100
        if args.data_name == 'CompQ':
            sparql_cache_fp = 'runnings/acl18_cache/group_cache/sparql.g%02d.cache' % group_idx
            q_sc_cache_fp = 'runnings/acl18_cache/group_cache/q_sc_stat.g%02d.cache' % group_idx
        else:
            sparql_cache_fp = 'runnings/acl18_cache/group_cache_%s/sparql.g%02d.cache' % (args.data_name, group_idx)
            q_sc_cache_fp = 'runnings/acl18_cache/group_cache_%s/q_sc_stat.g%02d.cache' % (args.data_name, group_idx)
        query_srv = QueryService(sparql_cache_fp=sparql_cache_fp,
                                 qsc_cache_fp=q_sc_cache_fp, vb=1)

    wd_emb_util = WordEmbeddingUtil(wd_emb=args.word_emb, dim_emb=args.dim_emb)
    smart_cand_gen = SMARTCandidateGenerator(data_name=args.data_name, lex_name=args.lex_name,
                                             wd_emb_util=wd_emb_util, query_srv=query_srv,
                                             allow_forever=args.allow_forever, vb=args.verbose,
                                             simple_type_match=args.simple_type,
                                             simple_time_match=args.simple_time)

    for q_idx, qa in enumerate(qa_list):
        if q_idx < q_start or q_idx >= q_end:
            continue
        # if q_idx != 1302:
        #     continue
        LogInfo.begin_track('Entering Q %d / %d [%s]:', q_idx, len(qa_list),
                            qa['utterance'].encode('utf-8'))
        sub_idx = q_idx / 100 * 100
        sub_dir = '%s/data/%d-%d' % (args.output_dir, sub_idx, sub_idx + 99)
        if not os.path.exists(sub_dir):
            os.makedirs(sub_dir)
        # save_ans_fp = '%s/%d_ans' % (sub_dir, q_idx)
        opt_sc_fp = '%s/%d_schema' % (sub_dir, q_idx)
        link_fp = '%s/%d_links' % (sub_dir, q_idx)
        if os.path.isfile(opt_sc_fp):
            LogInfo.end_track('Skip this question, already saved.')
            continue
        Tt.start('single_q')
        smart_cand_gen.single_question_candgen(q_idx=q_idx, qa=qa,
                                               link_fp=link_fp, opt_sc_fp=opt_sc_fp,
                                               linking_only=args.linking_only)
        Tt.record('single_q')
        LogInfo.end_track()     # End of Q
def batch_schema_f1_query(self, q_id, states, level):
    """
    Perform the F1 query for each schema in the states.
    :param q_id: WebQ-xxx / CompQ-xxx
    :param states: [(schema, visit_arr)]
    :param level: coarse / typed / timed / ordinal
    :return: filtered states where each schema returns at least one answer.
    """
    Tt.start('%s_F1' % level)
    LogInfo.begin_track('Calculating F1 for %d %s schemas:', len(states), level)
    for idx, (sc, _) in enumerate(states):
        if idx % 100 == 0:
            LogInfo.logs('Current: %d / %d', idx, len(states))
        sparql_str = sc.build_sparql(simple_time_match=self.simple_time_match)
        tm_comp, tm_value, ord_comp, ord_rank, agg = sc.build_aux_for_sparql()
        allow_forever = self.allow_forever if tm_comp != 'None' else ''
        # won't specify "forever" if there are no time constraints
        q_sc_key = '|'.join([q_id, sparql_str, tm_comp, tm_value,
                             allow_forever, ord_comp, ord_rank, agg])
        if self.vb >= 2:
            LogInfo.begin_track('Checking schema %d / %d:', idx, len(states))
            LogInfo.logs(sc.disp_raw_path())
            LogInfo.logs('var_types: %s', sc.var_types)
            LogInfo.logs(sparql_str)
        Tt.start('query_q_sc_stat')
        sc.ans_size, sc.p, sc.r, sc.f1 = self.query_srv.query_q_sc_stat(q_sc_key)
        Tt.record('query_q_sc_stat')
        if self.vb >= 2:
            LogInfo.logs('Answers = %d, P = %.6f, R = %.6f, F1 = %.6f',
                         sc.ans_size, sc.p, sc.r, sc.f1)
            LogInfo.end_track()
    filt_states = filter(lambda _tup: _tup[0].ans_size > 0, states)
    LogInfo.end_track('%d / %d %s schemas kept with ans_size > 0.',
                      len(filt_states), len(states), level)
    Tt.record('%s_F1' % level)
    return filt_states
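# A minimal, self-contained sketch of the cache key built in batch_schema_f1_query above:
# the key is the '|'-join of the question id, the schema's SPARQL string, and the auxiliary
# time / ordinal / aggregation fields, so identical (question, schema) evaluations hit the
# same cache entry in the query service. All field values below are invented placeholders.
def demo_q_sc_key():
    q_id = 'CompQ_42'                                   # hypothetical question id
    sparql_str = 'SELECT DISTINCT ?o1 WHERE { ... }'    # placeholder SPARQL string
    tm_comp, tm_value, allow_forever = 'None', '', ''   # no time constraint
    ord_comp, ord_rank, agg = 'None', '', 'None'
    q_sc_key = '|'.join([q_id, sparql_str, tm_comp, tm_value,
                         allow_forever, ord_comp, ord_rank, agg])
    print(q_sc_key)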
def read_schemas_from_single_file(q_idx, schema_fp, gather_linkings, sc_max_len):
    """ Read schemas from file, separate them into several groups. """
    general_list = []
    strict_list = []        # 2-hop must be mediator
    elegant_list = []       # 2-hop allows pred1.range == pred2.domain
    coherent_list = []      # 2-hop allows pred1.range \in pred2.domain ()
    global_lists = [strict_list, elegant_list, coherent_list, general_list]
    used_pred_name_dict = {}
    with codecs.open(schema_fp, 'r', 'utf-8') as br:
        lines = br.readlines()
    for ori_idx, line in enumerate(lines):
        Tt.start('read_single_line')
        sc = CompqSchema.read_schema_from_json(q_idx, json_line=line,
                                               gather_linkings=gather_linkings,
                                               use_ans_type_dist=False,
                                               placeholder_policy='ActiveOnly',
                                               full_constr=True, fix_dir=True)
        Tt.record('read_single_line')
        sc.ori_idx = ori_idx
        # Tt.start('construct_path')
        # sc.construct_path_list()      # create the path_list on-the-fly
        # Tt.record('construct_path')
        sc_len_list[len(sc.raw_paths)] += 1
        for _, _, pred_seq in sc.raw_paths:
            for pred in pred_seq:
                if pred not in used_pred_name_dict:
                    used_pred_name_dict[pred] = get_pred_name(pred)
        if len(sc.raw_paths) > sc_max_len:
            continue
        Tt.start('classification')
        sc_class = schema_classification(sc)
        global_lists[sc_class].append(sc)
        Tt.record('classification')
        # if q_idx == 1353:
        #     LogInfo.logs('Q-idx = %4d, Line = %4d, category = %d', q_idx, ori_idx, sc_class)
    for i in range(3):
        global_lists[i + 1] += global_lists[i]
    return global_lists, used_pred_name_dict
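# A self-contained sketch of the grouping trick at the end of read_schemas_from_single_file:
# after "global_lists[i + 1] += global_lists[i]", the four lists become cumulative, i.e.
# strict is contained in elegant, elegant in coherent, and coherent in general. The schema
# objects here are plain strings purely for illustration.
def demo_cumulative_groups():
    strict_list = ['sc_strict']
    elegant_list = ['sc_elegant']
    coherent_list = ['sc_coherent']
    general_list = ['sc_general']
    global_lists = [strict_list, elegant_list, coherent_list, general_list]
    for i in range(3):
        global_lists[i + 1] += global_lists[i]
    print([len(lst) for lst in global_lists])   # [1, 2, 3, 4]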
def optimize(self, data_loader, epoch_idx, ob_batch_num=10, summary_writer=None):
    if data_loader is None:
        return -1.
    LogInfo.logs('Debug mode: %s', self.debug_mode)
    Tt.start('optimize')
    Tt.start('prepare')
    self.prepare_data(data_loader=data_loader)
    # run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
    # run_metadata = tf.RunMetadata()
    # tf.reset_default_graph()
    Tt.record('prepare')

    scan_size = 0
    ret_loss = 0.
    for batch_idx in range(data_loader.n_batch):
        # point = (epoch_idx - 1) * data_loader.n_batch + batch_idx
        Tt.start('allocate')
        local_data_list, _ = data_loader.get_next_batch()
        local_size = len(local_data_list[0])    # the first dimension is always batch size
        fd = {input_tf: local_data for input_tf, local_data
              in zip(self.optm_input_tf_list, local_data_list)}
        Tt.record('allocate')

        Tt.start('running')
        if self.debug_mode == 'Loss':
            local_loss, summary = self.sess.run(
                [self.avg_loss, self.optm_summary], feed_dict=fd,
                # options=run_options,
                # run_metadata=run_metadata
            )
        elif self.debug_mode == 'Grad':
            _, local_loss, summary = self.sess.run(
                [self.clipped_gradients, self.avg_loss, self.optm_summary], feed_dict=fd,
                # options=run_options,
                # run_metadata=run_metadata
            )
        else:   # Update or Raw
            _, local_loss, summary = self.sess.run(
                [self.optm_step, self.avg_loss, self.optm_summary], feed_dict=fd,
                # options=run_options,
                # run_metadata=run_metadata
            )
        Tt.record('running')

        Tt.start('display')
        ret_loss = (ret_loss * scan_size + local_loss * local_size) / (scan_size + local_size)
        scan_size += local_size
        if (batch_idx + 1) % ob_batch_num == 0:
            LogInfo.logs('[optm-%s-B%d/%d] avg_loss = %.6f, scanned = %d/%d',
                         data_loader.mode, batch_idx + 1, data_loader.n_batch,
                         ret_loss, scan_size, len(data_loader))
        # if summary_writer is not None:
        #     summary_writer.add_summary(summary, point)
        #     if batch_idx == 0:
        #         summary_writer.add_run_metadata(run_metadata, 'epoch-%d' % epoch_idx)
        Tt.record('display')

    Tt.record('optimize')
    return ret_loss
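# A minimal sketch of the size-weighted running average used for ret_loss in optimize() above:
# each batch's mean loss is folded in proportionally to its batch size, so the final value
# equals the mean loss over all scanned examples. The numbers below are made up.
def demo_running_loss():
    batches = [(0.9, 32), (0.5, 32), (0.7, 16)]     # (local_loss, local_size) per batch
    scan_size = 0
    ret_loss = 0.
    for local_loss, local_size in batches:
        ret_loss = (ret_loss * scan_size + local_loss * local_size) / (scan_size + local_size)
        scan_size += local_size
    print(round(ret_loss, 6))   # 0.7, identical to the plain weighted mean over 80 examples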
def perform_linking(self, q_idx, tok_list, entity_only=False):
    tok_list = map(lambda x: x.lower(), tok_list)       # all steps ignore cases
    Tt.start('entity')
    el_list = self.q_links_dict.get(q_idx, [])
    Tt.record('entity')
    if not entity_only:
        Tt.start('type')
        tl_list = self.type_linking(tok_list)
        Tt.record('type')
        Tt.start('time')
        tml_list = self.time_linking(tok_list)
        Tt.record('time')
        Tt.start('ordinal')
        ord_list = self.ordinal_linking(tok_list)
        Tt.record('ordinal')
        gather_linkings = el_list + tl_list + tml_list + ord_list
    else:       # For SimpleQuestions
        gather_linkings = el_list
    for idx in range(len(gather_linkings)):
        gather_linkings[idx].gl_pos = idx
    return gather_linkings
def single_question_candgen(self, q_idx, qa, link_fp, opt_sc_fp):
    # ==================== Linking first ==================== #
    Tt.start('linking')
    if os.path.isfile(link_fp):
        gather_linkings = []
        with codecs.open(link_fp, 'r', 'utf-8') as br:
            for line in br.readlines():
                tup_list = json.loads(line.strip())
                ld_dict = {k: v for k, v in tup_list}
                gather_linkings.append(LinkData(**ld_dict))
        LogInfo.logs('Read %d links from file.', len(gather_linkings))
    else:
        tok_list = [tok.token for tok in qa['tokens']]
        gather_linkings = self.global_linker.perform_linking(q_idx=q_idx,
                                                             tok_list=tok_list,
                                                             entity_only=True)
    LogInfo.begin_track('Show %d E links:', len(gather_linkings))
    if self.vb >= 1:
        for gl in gather_linkings:
            LogInfo.logs(gl.display().encode('utf-8'))
    LogInfo.end_track()
    Tt.record('linking')

    # ==================== Save linking results ==================== #
    if not os.path.isfile(link_fp):
        with codecs.open(link_fp + '.tmp', 'w', 'utf-8') as bw:
            for gl in gather_linkings:
                bw.write(json.dumps(gl.serialize()) + '\n')
        shutil.move(link_fp + '.tmp', link_fp)
        LogInfo.logs('%d link data saved to file.', len(gather_linkings))

    # ==================== Simple predicate finding ==================== #
    gold_entity, gold_pred, _ = qa['targetValue']
    sc_list = []
    for gl_data in gather_linkings:
        entity = gl_data.value
        pred_set = self.subj_pred_dict.get(entity, set([]))
        for pred in pred_set:
            sc = CompqSchema()
            sc.hops = 1
            sc.aggregate = False
            sc.main_pred_seq = [pred]
            sc.inv_main_pred_seq = [inverse_predicate(p) for p in sc.main_pred_seq]     # [!p1, !p2]
            sc.inv_main_pred_seq.reverse()      # [!p2, !p1]
            sc.raw_paths = [('Main', gl_data, [pred])]
            sc.ans_size = 1
            if entity == gold_entity and pred == gold_pred:
                sc.f1 = sc.p = sc.r = 1.
            else:
                sc.f1 = sc.p = sc.r = 0.
            sc_list.append(sc)

    # ==================== Save schema results ==================== #
    # p, r, f1, ans_size, hops, raw_paths, (agg)
    # raw_paths: (category, gl_pos, gl_mid, pred_seq)
    with codecs.open(opt_sc_fp + '.tmp', 'w', 'utf-8') as bw:
        for sc in sc_list:
            sc_info_dict = {k: getattr(sc, k) for k in ('p', 'r', 'f1', 'ans_size', 'hops')}
            if sc.aggregate is not None:
                sc_info_dict['agg'] = sc.aggregate
            opt_raw_paths = []
            for cate, gl, pred_seq in sc.raw_paths:
                opt_raw_paths.append((cate, gl.gl_pos, gl.value, pred_seq))
            sc_info_dict['raw_paths'] = opt_raw_paths
            bw.write(json.dumps(sc_info_dict) + '\n')
    shutil.move(opt_sc_fp + '.tmp', opt_sc_fp)
    LogInfo.logs('%d schemas successfully saved into [%s].', len(sc_list), opt_sc_fp)
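# For reference, a sketch of what one line of the saved "<q_idx>_schema" file could look like,
# following the dict assembled right above (p / r / f1 / ans_size / hops, optional 'agg', and
# raw_paths as (category, gl_pos, gl_mid, pred_seq) tuples). The mid and predicate strings are
# invented placeholders, not real Freebase identifiers.
import json

def demo_schema_line():
    sc_info_dict = {
        'p': 1.0, 'r': 1.0, 'f1': 1.0, 'ans_size': 1, 'hops': 1, 'agg': False,
        'raw_paths': [('Main', 0, 'm.placeholder_entity', ['placeholder.predicate'])],
    }
    print(json.dumps(sc_info_dict))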
if __name__ == '__main__':
    LogInfo.begin_track('[kangqi.task.compQA.candgen_acl18.simpq_candgen] ... ')
    _args = parser.parse_args()
    main(_args)
    LogInfo.end_track('All Done.')
    Tt.display()
def working_in_data(data_dir, file_list, qa_list, sc_max_len=3):
    save_fp = data_dir + '/log.schema_check'
    bw = open(save_fp, 'w')
    LogInfo.redirect(bw)
    LogInfo.begin_track('Working in %s, with schemas in %s:', data_dir, file_list)
    with open(data_dir + '/' + file_list, 'r') as br:
        lines = br.readlines()

    used_pred_tup_list = []
    stat_tup_list = []
    for line_idx, line in enumerate(lines):
        if line_idx % 100 == 0:
            LogInfo.logs('Scanning %d / %d ...', line_idx, len(lines))
        line = line.strip()
        q_idx = int(line.split('/')[2].split('_')[0])
        # if q_idx != 1353:
        #     continue
        schema_fp = data_dir + '/' + line
        link_fp = schema_fp.replace('schema', 'links')
        Tt.start('read_linkings')
        gather_linkings = []
        with codecs.open(link_fp, 'r', 'utf-8') as br:
            for gl_line in br.readlines():
                tup_list = json.loads(gl_line.strip())
                ld_dict = {k: v for k, v in tup_list}
                gather_linkings.append(LinkData(**ld_dict))
        Tt.record('read_linkings')
        Tt.start('read_schema')
        global_lists, used_pred_name_dict = \
            read_schemas_from_single_file(q_idx, schema_fp, gather_linkings, sc_max_len)
        Tt.record('read_schema')
        stat_tup_list.append((q_idx, global_lists))
        used_pred_tup_list.append((q_idx, used_pred_name_dict))

    stat_tup_list.sort(key=lambda _tup: max_f1(_tup[1][ELEGANT]) - max_f1(_tup[1][STRICT]),
                       reverse=True)
    LogInfo.logs('sc_len distribution: %s', sc_len_list)
    LogInfo.logs('Rank\tQ_idx\tstri.\teleg.\tcohe.\tgene.'
                 '\tstri._F1\teleg._F1\tcohe._F1\tgene._F1\tUtterance')
    for rank, stat_tup in enumerate(stat_tup_list):
        q_idx, global_lists = stat_tup
        size_list = ['%4d' % len(cand_list) for cand_list in global_lists]
        max_f1_list = ['%8.6f' % max_f1(cand_list) for cand_list in global_lists]
        size_str = '\t'.join(size_list)
        max_f1_str = '\t'.join(max_f1_list)
        LogInfo.logs('%4d\t%4d\t%s\t%s\t%s', rank + 1, q_idx, size_str, max_f1_str,
                     qa_list[q_idx]['utterance'].encode('utf-8'))

    q_size = len(stat_tup_list)
    f1_upperbound_list = []
    avg_cand_size_list = []
    found_entity_list = []
    for index in range(4):
        local_max_f1_list = []
        local_cand_size_list = []
        for _, global_lists in stat_tup_list:
            sc_list = global_lists[index]
            local_max_f1_list.append(max_f1(sc_list))
            local_cand_size_list.append(len(sc_list))
        local_found_entity = len(filter(lambda y: y > 0., local_max_f1_list))
        f1_upperbound_list.append(1.0 * sum(local_max_f1_list) / q_size)
        avg_cand_size_list.append(1.0 * sum(local_cand_size_list) / q_size)
        found_entity_list.append(local_found_entity)
    LogInfo.logs(' strict: avg size = %.6f, upp F1 = %.6f',
                 avg_cand_size_list[STRICT], f1_upperbound_list[STRICT])
    for name, index in zip(('strict', 'elegant', 'coherent', 'general'),
                           (STRICT, ELEGANT, COHERENT, GENERAL)):
        avg_size = avg_cand_size_list[index]
        ratio = 1. * avg_size / avg_cand_size_list[STRICT] if avg_cand_size_list[STRICT] > 0 else 0.
        found_entity = found_entity_list[index]
        upp_f1 = f1_upperbound_list[index]
        gain = upp_f1 - f1_upperbound_list[STRICT]
        LogInfo.logs('%8s: avg size = %.6f (%.2fx), found entity = %d, '
                     'upp F1 = %.6f, gain = %.6f',
                     name, avg_size, ratio, found_entity, upp_f1, gain)
    LogInfo.end_track()
    bw.close()
    LogInfo.stop_redirect()

    used_pred_fp = data_dir + '/log.used_pred_stat'
    bw = codecs.open(used_pred_fp, 'w', 'utf-8')
    LogInfo.redirect(bw)
    ratio_list = []
    for q_idx, used_pred_name_dict in used_pred_tup_list:
        unique_id_size = len(set(used_pred_name_dict.keys()))
        unique_name_size = len(set(used_pred_name_dict.values()))
        ratio = 1. * unique_name_size / unique_id_size if unique_id_size > 0 else 1.
        ratio_list.append(ratio)
        LogInfo.logs('Q = %4d, unique id = %d, unique name = %d, ratio = %.4f',
                     q_idx, unique_id_size, unique_name_size, ratio)
    avg_ratio = sum(ratio_list) / len(ratio_list)
    LogInfo.logs('avg_ratio = %.4f', avg_ratio)
    bw.close()
    LogInfo.stop_redirect()
def coarse_search(self, path_len, entity_linkings, conflict_matrix,
                  cur_el_idx, cur_comb, visit_arr,
                  coarse_state_list, state_marker, aggregate):
    if self.vb >= 1:
        LogInfo.begin_track('[%s] (%s)', '||'.join(state_marker),
                            datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
    fuzzy_query = self.build_fuzzy_query(path_len, entity_linkings, cur_comb)
    if self.vb >= 1:
        LogInfo.logs('Fuzzy query: %s', fuzzy_query)
    st = time.time()
    Tt.start('query_sparql')
    query_ret = self.query_service.query_sparql(fuzzy_query)
    Tt.record('query_sparql')
    if self.vb >= 4:
        LogInfo.logs('Recv >>>: %s', query_ret)
    filt_query_ret = predicate_filtering(query_ret=query_ret, path_len=path_len)
    if self.vb >= 1:
        LogInfo.logs('Filt_Query_Ret = %d / %d (%.3fs)',
                     len(filt_query_ret), len(query_ret), time.time() - st)
    if len(filt_query_ret) == 0:
        if self.vb >= 1:
            LogInfo.end_track()
        return      # no need to search deeper

    coarse_schemas = self.build_coarse_schemas(path_len=path_len,
                                               el_linkings=entity_linkings,
                                               cur_comb=cur_comb,
                                               filt_query_ret=filt_query_ret,
                                               aggregate=aggregate)
    for sc in coarse_schemas:
        coarse_state_list.append((sc, list(visit_arr)))     # save coarse schemas into the output state

    # Ready to search deeper
    el_size = len(entity_linkings)
    for nxt_el_idx in range(cur_el_idx + 1, el_size):
        gl_pos = entity_linkings[nxt_el_idx].gl_pos
        if visit_arr[gl_pos] != 0:      # cannot be visited due to conflict
            continue
        for conf_idx in conflict_matrix[gl_pos]:        # ready to enter the next state
            visit_arr[conf_idx] += 1
        for attach_idx in range(1, path_len + 1):       # enumerate each possible attach position
            nxt_comb = list(cur_comb)
            nxt_comb.append((attach_idx, nxt_el_idx))
            state_marker.append('%d/%d-%d' % (nxt_el_idx + 1, el_size, attach_idx))
            self.coarse_search(path_len=path_len, entity_linkings=entity_linkings,
                               conflict_matrix=conflict_matrix,
                               cur_el_idx=nxt_el_idx, cur_comb=nxt_comb,
                               visit_arr=visit_arr,
                               coarse_state_list=coarse_state_list,
                               state_marker=state_marker, aggregate=aggregate)
            del state_marker[-1]
        for conf_idx in conflict_matrix[gl_pos]:        # return back
            visit_arr[conf_idx] -= 1
    # Ends of DFS

    if self.vb >= 1:
        LogInfo.end_track()
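# A self-contained sketch of the visit_arr bookkeeping used by coarse_search above:
# before recursing with a newly attached linking, every slot that conflicts with it
# (per conflict_matrix) is incremented; after the recursive call returns, the same
# slots are decremented, restoring the state for the next sibling branch. The tiny
# conflict matrix here (3 linkings, with 0 and 1 overlapping) is invented for illustration.
def demo_visit_arr_backtracking():
    conflict_matrix = [[0, 1], [0, 1], [2]]     # linkings 0 and 1 conflict with each other
    visit_arr = [0, 0, 0]

    def dfs(cur_idx, picked):
        print(picked, visit_arr)
        for nxt in range(cur_idx + 1, len(conflict_matrix)):
            if visit_arr[nxt] != 0:             # blocked by a conflicting, already-picked linking
                continue
            for conf_idx in conflict_matrix[nxt]:
                visit_arr[conf_idx] += 1        # enter the next state
            dfs(nxt, picked + [nxt])
            for conf_idx in conflict_matrix[nxt]:
                visit_arr[conf_idx] -= 1        # backtrack
    dfs(-1, [])     # enumerates [], [0], [0, 2], [1], [1, 2], [2]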
def single_question_candgen(self, q_idx, qa, link_fp, opt_sc_fp, linking_only):
    # ==================== Linking first ==================== #
    Tt.start('linking')
    if os.path.isfile(link_fp):
        gather_linkings = []
        with codecs.open(link_fp, 'r', 'utf-8') as br:
            for line in br.readlines():
                tup_list = json.loads(line.strip())
                ld_dict = {k: v for k, v in tup_list}
                gather_linkings.append(LinkData(**ld_dict))
        LogInfo.logs('Read %d links from file.', len(gather_linkings))
    else:
        tok_list = [tok.token for tok in qa['tokens']]
        gather_linkings = self.global_linker.perform_linking(q_idx=q_idx, tok_list=tok_list)
    el_size = len(filter(lambda x: x.category == 'Entity', gather_linkings))
    tl_size = len(filter(lambda x: x.category == 'Type', gather_linkings))
    tml_size = len(filter(lambda x: x.category == 'Time', gather_linkings))
    ord_size = len(filter(lambda x: x.category == 'Ordinal', gather_linkings))
    LogInfo.begin_track('Show %d E + %d T + %d Tm + %d Ord = %d linkings:',
                        el_size, tl_size, tml_size, ord_size, len(gather_linkings))
    if self.vb >= 1:
        for gl in gather_linkings:
            LogInfo.logs(gl.display())
    LogInfo.end_track()
    Tt.record('linking')

    # ==================== Save linking results ==================== #
    if not os.path.isfile(link_fp):
        with codecs.open(link_fp + '.tmp', 'w', 'utf-8') as bw:
            for gl in gather_linkings:
                bw.write(json.dumps(gl.serialize()) + '\n')
        shutil.move(link_fp + '.tmp', link_fp)
        LogInfo.logs('%d link data saved to file.', len(gather_linkings))
    if linking_only:
        return

    # ==================== Prepare necessary states for searching ==================== #
    q_id = '%s_%d' % (self.data_name, q_idx)
    conflict_matrix = construct_conflict_matrix(gather_linkings)
    entity_linkings = filter(lambda x: x.category == 'Entity', gather_linkings)
    type_linkings = filter(lambda x: x.category == 'Type', gather_linkings)
    time_linkings = filter(lambda x: x.category == 'Time', gather_linkings)
    ordinal_linkings = filter(lambda x: x.category == 'Ordinal', gather_linkings)

    # ==================== Searching start ==================== #
    aggregate = False
    """ 180308: always ignore aggregation, since we've found the f****d up data distribution """
    # lower_tok_list = [tok.token.lower() for tok in qa['tokens']]
    # lower_tok_str = ' '.join(lower_tok_list)
    # aggregate = lower_tok_str.startswith('how many ')
    # # apply COUNT(*) to all candidate schemas if we found "how many" at the beginning of a question

    """ ================ Step 1: coarse linking, using entities only ================ """
    LogInfo.begin_track('Coarse level searching (total entities = %d):', el_size)
    entity_linkings.sort(key=lambda _el: _el.value)
    """ We sort the linking data to identify potential duplicate SPARQL queries and save time. """
    Tt.start('coarse_comb')
    coarse_states = self.cand_searcher.find_coarse_schemas(entity_linkings=entity_linkings,
                                                           conflict_matrix=conflict_matrix,
                                                           aggregate=aggregate)     # [(schema, visit_arr)]
    LogInfo.logs('%d coarse schemas retrieved from scratch.', len(coarse_states))
    Tt.record('coarse_comb')
    coarse_states = self.batch_schema_f1_query(q_id=q_id, states=coarse_states, level='coarse')
    LogInfo.end_track('Coarse level ended, resulting in %d schemas.', len(coarse_states))

    """ ================ Step 2: adding type information ================ """
    LogInfo.begin_track('Type level searching (total types = %d):', tl_size)
    Tt.start('typed_comb')
    typed_states = []
    for idx, (coarse_sc, visit_arr) in enumerate(coarse_states):
        if idx % 100 == 0:
            LogInfo.logs('Current: %d / %d', idx, len(coarse_states))
        typed_states += self.cand_searcher.find_typed_schemas(
            type_linkings=type_linkings, conflict_matrix=conflict_matrix,
            start_schema=coarse_sc, visit_arr=visit_arr,
            simple_type_match=self.simple_type_match)
    LogInfo.logs('%d typed schemas retrieved from %d coarse schemas.',
                 len(typed_states), len(coarse_states))
    Tt.record('typed_comb')
    typed_states = self.batch_schema_f1_query(q_id=q_id, states=typed_states, level='typed')
    typed_states = coarse_states + typed_states     # don't forget those schemas without type constraints
    LogInfo.end_track('Typed level ended, resulting in %d schemas.', len(typed_states))

    """ ================ Step 3: adding time information ================ """
    LogInfo.begin_track('Time level searching (total times = %d):', tml_size)
    Tt.start('timed_comb')
    timed_states = []
    for idx, (typed_sc, visit_arr) in enumerate(typed_states):
        if idx % 100 == 0:
            LogInfo.logs('Current: %d / %d', idx, len(typed_states))
        timed_states += self.cand_searcher.find_timed_schemas(
            time_linkings=time_linkings, conflict_matrix=conflict_matrix,
            start_schema=typed_sc, visit_arr=visit_arr,
            simple_time_match=self.simple_time_match)   # True: degenerates into Bao
    LogInfo.logs('%d timed schemas retrieved from %d typed schemas.',
                 len(timed_states), len(typed_states))
    Tt.record('timed_comb')
    timed_states = self.batch_schema_f1_query(q_id=q_id, states=timed_states, level='timed')
    timed_states = typed_states + timed_states      # don't forget the previous schemas
    LogInfo.end_track('Time level ended, resulting in %d schemas.', len(timed_states))

    """ ================ Step 4: ordinal information as the final step ================ """
    LogInfo.begin_track('Ordinal level searching (total ordinals = %d):', ord_size)
    Tt.start('ord_comb')
    final_states = []
    for idx, (timed_sc, visit_arr) in enumerate(timed_states):
        if idx % 100 == 0:
            LogInfo.logs('Current: %d / %d', idx, len(timed_states))
        final_states += self.cand_searcher.find_ordinal_schemas(ordinal_linkings=ordinal_linkings,
                                                                start_schema=timed_sc,
                                                                visit_arr=visit_arr)
    LogInfo.logs('%d ordinal schemas retrieved from %d timed schemas.',
                 len(final_states), len(timed_states))
    Tt.record('ord_comb')
    final_states = self.batch_schema_f1_query(q_id=q_id, states=final_states, level='ordinal')
    final_states = timed_states + final_states
    LogInfo.end_track('Ordinal level ended, we finally collected %d schemas.', len(final_states))

    final_schemas = [tup[0] for tup in final_states]
    self.query_srv.save_buffer()

    # ==================== Save schema results ==================== #
    # p, r, f1, ans_size, hops, raw_paths, (agg)
    # raw_paths: (category, gl_pos, gl_mid, pred_seq)
    with codecs.open(opt_sc_fp + '.tmp', 'w', 'utf-8') as bw:
        for sc in final_schemas:
            sc_info_dict = {k: getattr(sc, k) for k in ('p', 'r', 'f1', 'ans_size', 'hops')}
            if sc.aggregate is not None:
                sc_info_dict['agg'] = sc.aggregate
            opt_raw_paths = []
            for cate, gl, pred_seq in sc.raw_paths:
                opt_raw_paths.append((cate, gl.gl_pos, gl.value, pred_seq))
            sc_info_dict['raw_paths'] = opt_raw_paths
            bw.write(json.dumps(sc_info_dict) + '\n')
    shutil.move(opt_sc_fp + '.tmp', opt_sc_fp)
    LogInfo.logs('%d schemas successfully saved into [%s].', len(final_schemas), opt_sc_fp)

    del coarse_states
    del typed_states
    del timed_states
    del final_states
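# A condensed, self-contained sketch of the four-stage cascade coded above: each level
# expands the surviving states of the previous level, keeps only expansions that still
# return answers (the batch F1 query), and then unions them with the previous level so
# that schemas without the extra constraint are never lost. The expand / f1_filter
# callables here are stand-ins for find_*_schemas and batch_schema_f1_query.
def demo_staged_cascade(initial_states, expanders, f1_filter):
    states = f1_filter(initial_states)          # coarse level
    for expand in expanders:                    # e.g. typed -> timed -> ordinal
        expanded = []
        for state in states:
            expanded += expand(state)
        states = states + f1_filter(expanded)   # keep unconstrained schemas as well
    return states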
def main(args):
    LogInfo.begin_track('Learning starts ... ')

    # ==== Loading Necessary Util ==== #
    LogInfo.begin_track('Loading Utils ... ')
    wd_emb_util = WordEmbeddingUtil(wd_emb=args.word_emb, dim_emb=args.dim_emb)
    LogInfo.end_track()

    # ==== Loading Dataset ==== #
    data_config = literal_eval(args.data_config)    # including data_name, dir, max_length and others
    data_config['wd_emb_util'] = wd_emb_util
    # data_config['kb_emb_util'] = kb_emb_util
    data_config['verbose'] = args.verbose
    dataset = QScDataset(**data_config)
    dataset.load_size()     # load size info

    # ==== Build Model First ==== #
    LogInfo.begin_track('Building Model and Session ... ')
    gpu_options = tf.GPUOptions(allow_growth=True,
                                per_process_gpu_memory_fraction=args.gpu_fraction)
    sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options,
                                            intra_op_parallelism_threads=8))
    rm_config = literal_eval(args.rm_config)        # Relation Matching
    rm_name = rm_config['name']
    del rm_config['name']
    assert rm_name in ('Compact', 'Separated')
    rm_config['n_words'] = dataset.word_size
    rm_config['n_mids'] = dataset.mid_size
    rm_config['dim_emb'] = args.dim_emb
    rm_config['q_max_len'] = dataset.q_max_len
    rm_config['sc_max_len'] = dataset.sc_max_len
    rm_config['path_max_len'] = dataset.path_max_len
    rm_config['pword_max_len'] = dataset.path_max_len * dataset.item_max_len
    rm_config['verbose'] = args.verbose
    if rm_name == 'Compact':
        LogInfo.logs('RelationMatchingKernel: Compact')
        rm_kernel = CompactRelationMatchingKernel(**rm_config)
    else:
        LogInfo.logs('RelationMatchingKernel: Separated')
        rm_kernel = SeparatedRelationMatchingKernel(**rm_config)
    el_kernel = EntityLinkingKernel(e_max_size=dataset.e_max_size,
                                    e_feat_len=dataset.e_feat_len,
                                    verbose=args.verbose)

    model_config = literal_eval(args.model_config)
    model_config['sess'] = sess
    model_config['objective'] = args.eval_mode      # relation_only / normal
    model_config['relation_kernel'] = rm_kernel
    model_config['entity_kernel'] = el_kernel
    model_config['extra_len'] = dataset.extra_len
    model_config['verbose'] = args.verbose
    compq_model = CompqModel(**model_config)

    LogInfo.begin_track('Showing final parameters: ')
    for var in tf.global_variables():
        LogInfo.logs('%s: %s', var.name, var.get_shape().as_list())
    LogInfo.end_track()

    saver = tf.train.Saver()
    LogInfo.begin_track('Parameter initializing ... ')
    start_epoch = 0
    best_valid_f1 = 0.
    resume_flag = False
    model_dir = None
    if args.resume_model_name not in ('', 'None'):
        model_dir = '%s/%s' % (args.output_dir, args.resume_model_name)
        if os.path.exists(model_dir):
            resume_flag = True
    if resume_flag:
        start_epoch, best_valid_f1 = load_model(saver=saver, sess=sess, model_dir=model_dir)
    else:
        dataset.load_init_emb()     # loading parameters for embedding initialization
        LogInfo.logs('Running global_variables_initializer ...')
        sess.run(tf.global_variables_initializer(),
                 feed_dict={rm_kernel.w_embedding_init: dataset.word_init_emb,
                            rm_kernel.m_embedding_init: dataset.mid_init_emb})
    LogInfo.end_track('Start Epoch = %d', start_epoch)
    LogInfo.end_track('Model Built.')
    tf.get_default_graph().finalize()

    # ==== Constructing Data_Loader ==== #
    LogInfo.begin_track('Creating DataLoader ... ')
    dataset.load_cands()        # first loading all the candidates
    if args.eval_mode == 'relation_only':
        ro_change = 0
        for cand_list in dataset.q_cand_dict.values():
            ro_change += add_relation_only_metric(cand_list)    # for "RelationOnly" evaluation
        LogInfo.logs('RelationOnly F1 change: %d schemas affected.', ro_change)

    optm_dl_config = {'dataset': dataset, 'mode': 'train',
                      'batch_size': args.optm_batch_size,
                      'proc_ob_num': 5000, 'verbose': args.verbose}
    eval_dl_config = dict(optm_dl_config)
    spt = args.dl_neg_mode.split('-')       # Neg-${POOR_CONTRIB}-${POOR_MAX_SAMPLE}
    optm_dl_config['poor_contribution'] = int(spt[1])
    optm_dl_config['poor_max_sample'] = int(spt[2])
    optm_dl_config['shuffle'] = False
    optm_train_data = CompqPairDataLoader(**optm_dl_config)

    eval_dl_config['batch_size'] = args.eval_batch_size
    eval_data_group = []
    for mode in ('train', 'valid', 'test'):
        eval_dl_config['mode'] = mode
        eval_data = CompqSingleDataLoader(**eval_dl_config)
        eval_data.renew_data_list()
        eval_data_group.append(eval_data)
    (eval_train_data, eval_valid_data, eval_test_data) = eval_data_group
    LogInfo.end_track()     # End of loading data & dataset

    # ==== Free memories ==== #
    for item in (wd_emb_util, dataset.wd_emb_util, data_config):
        del item

    # ==== Ready for learning ==== #
    LogInfo.begin_track('Learning start ... ')
    output_dir = args.output_dir
    if not os.path.exists(output_dir + '/detail'):
        os.makedirs(output_dir + '/detail')
    if not os.path.exists(output_dir + '/result'):
        os.makedirs(output_dir + '/result')
    if os.path.isdir(output_dir + '/TB'):
        shutil.rmtree(output_dir + '/TB')
    tf.summary.FileWriter(output_dir + '/TB/optm', sess.graph)      # saving model graph information
    # optm_summary_writer = tf.summary.FileWriter(output_dir + '/TB/optm', sess.graph)
    # eval_train_summary_writer = tf.summary.FileWriter(output_dir + '/TB/eval_train', sess.graph)
    # eval_valid_summary_writer = tf.summary.FileWriter(output_dir + '/TB/eval_valid', sess.graph)
    # eval_test_summary_writer = tf.summary.FileWriter(output_dir + '/TB/eval_test', sess.graph)
    # LogInfo.logs('TensorBoard writer defined.')       # TensorBoard information

    status_fp = output_dir + '/status.csv'
    with open(status_fp, 'a') as bw:
        bw.write('%s\t%s\t%s\t%s\t%s\t%s\t%s\n' % (
            'Epoch', 'T_loss', 'T_F1', 'v_F1', 'Status', 't_F1',
            datetime.now().strftime('%Y-%m-%d_%H:%M:%S')))

    patience = args.max_patience
    for epoch in range(start_epoch + 1, args.max_epoch + 1):
        if patience == 0:
            LogInfo.logs('Early stopping at epoch = %d.', epoch)
            break
        LogInfo.begin_track('Epoch %d / %d', epoch, args.max_epoch)
        update_flag = False

        LogInfo.begin_track('Optimizing ...')
        train_loss = compq_model.optimize(optm_train_data, epoch, ob_batch_num=1)
        LogInfo.end_track('T_loss = %.6f', train_loss)

        LogInfo.begin_track('Eval-Training ...')
        train_f1 = compq_model.evaluate(eval_train_data, epoch, ob_batch_num=50,
                                        detail_fp=output_dir + '/detail/train_%03d.txt' % epoch)
        LogInfo.end_track('T_F1 = %.6f', train_f1)

        LogInfo.begin_track('Eval-Validating ...')
        valid_f1 = compq_model.evaluate(eval_valid_data, epoch, ob_batch_num=50,
                                        detail_fp=output_dir + '/detail/valid_%03d.txt' % epoch)
        LogInfo.logs('v_F1 = %.6f', valid_f1)
        if valid_f1 > best_valid_f1:
            best_valid_f1 = valid_f1
            update_flag = True
            patience = args.max_patience
        else:
            patience -= 1
        LogInfo.logs('Model %s, best v_F1 = %.6f [patience = %d]',
                     'updated' if update_flag else 'stayed', valid_f1, patience)
        LogInfo.end_track()

        LogInfo.begin_track('Eval-Testing ... ')
        test_f1 = compq_model.evaluate(eval_test_data, epoch, ob_batch_num=20,
                                       detail_fp=output_dir + '/detail/test_%03d.txt' % epoch,
                                       result_fp=output_dir + '/result/test_schema_%03d.txt' % epoch)
        LogInfo.end_track('t_F1 = %.6f', test_f1)

        with open(status_fp, 'a') as bw:
            bw.write('%d\t%8.6f\t%8.6f\t%8.6f\t%s\t%8.6f\t%s\n' % (
                epoch, train_loss, train_f1, valid_f1,
                'UPDATE' if update_flag else str(patience), test_f1,
                datetime.now().strftime('%Y-%m-%d_%H:%M:%S')))

        save_epoch_dir = '%s/model_epoch_%d' % (output_dir, epoch)
        save_best_dir = '%s/model_best' % output_dir
        if args.save_epoch:
            delete_dir(save_epoch_dir)
            save_model(saver=saver, sess=sess, model_dir=save_epoch_dir,
                       epoch=epoch, valid_metric=valid_f1)
            if update_flag and args.save_best:      # just create a symbolic link
                delete_dir(save_best_dir)
                os.symlink(save_epoch_dir, save_best_dir)   # symlink at directory level
        elif update_flag and args.save_best:
            delete_dir(save_best_dir)
            save_model(saver=saver, sess=sess, model_dir=save_best_dir,
                       epoch=epoch, valid_metric=valid_f1)

        LogInfo.end_track()     # End of epoch

    LogInfo.end_track()     # End of learning
    Tt.display()
    LogInfo.end_track('All Done.')