def infer(nlu1, table_name, data_table, path_db, db_name,
          model, model_bert, bert_config, max_seq_length, num_target_layers,
          beam_size=4, show_table=False, show_answer_only=False):
    """Run single-question inference with execution-guided beam decoding.

    Deliberately duplicated from the evaluation path (against DRY) to
    minimise the risk of introducing bugs into the shared train/eval code.

    Args:
        nlu1: natural-language question (a single string).
        table_name: table name, used only for ``engine.show_table``.
        data_table: list of table dicts; the FIRST entry is used for decoding.
        path_db: directory containing the SQLite database file.
        db_name: database file name (without the ``.db`` suffix).
        model: decoder model exposing ``beam_forward``.
        model_bert: BERT encoder module.
        bert_config: configuration passed to ``get_wemb_bert``.
        max_seq_length: maximum BERT input sequence length.
        num_target_layers: number of BERT output layers fed to the decoder.
        beam_size: beam width for execution-guided decoding.
        show_table: when True, also print the table contents.
        show_answer_only: when True, print only question / answer / SQL.

    Returns:
        (pr_sql_i, pr_ans): predicted SQL structure(s) and the execution
        answer(s), or ``'Answer not found.'`` placeholders on failure.

    Raises:
        EnvironmentError: if decoding did not yield exactly one query.
    """
    model.eval()
    model_bert.eval()

    engine = DBEngine(os.path.join(path_db, f"{db_name}.db"))

    # Wrap the single question in batch-shaped lists.
    nlu = [nlu1]
    # nlu_t1 = tokenize_corenlp(client, nlu1)
    # NOTE(review): relies on module-level `client` and `tokenizer` globals
    # that are not parameters of this function — confirm they are initialised
    # before calling.
    nlu_t1 = tokenize_corenlp_direct_version(client, nlu1)
    nlu_t = [nlu_t1]

    tb1 = data_table[0]
    hds1 = tb1['header']
    tb = [tb1]
    hds = [hds1]

    # BERT encoding of question + headers.
    wemb_n, wemb_h, l_n, l_hpu, l_hs, \
        nlu_tt, t_to_tt_idx, tt_to_t_idx \
        = get_wemb_bert(bert_config, model_bert, tokenizer, nlu_t, hds,
                        max_seq_length,
                        num_out_layers_n=num_target_layers,
                        num_out_layers_h=num_target_layers)

    # Execution-guided beam decoding.
    prob_sca, prob_w, prob_wn_w, pr_sc, pr_sa, pr_wn, pr_sql_i = \
        model.beam_forward(wemb_n, l_n, wemb_h, l_hpu, l_hs, engine, tb,
                           nlu_t, nlu_tt, tt_to_t_idx, nlu,
                           beam_size=beam_size)

    # Sort conditions and regenerate the where-clause components.
    pr_wc, pr_wo, pr_wv, pr_sql_i = sort_and_generate_pr_w(pr_sql_i)

    if len(pr_sql_i) != 1:  # a single question must yield exactly one query
        raise EnvironmentError

    pr_sql_q1 = generate_sql_q(pr_sql_i, [tb1])
    pr_sql_q = [pr_sql_q1]

    try:
        pr_ans, _ = engine.execute_return_query(tb[0]['id'], pr_sc[0],
                                                pr_sa[0],
                                                pr_sql_i[0]['conds'])
    except Exception:  # best-effort: fall back to a placeholder answer
        pr_ans = ['Answer not found.']
        pr_sql_q = ['Answer not found.']

    if show_answer_only:
        print(f'Q: {nlu[0]}')
        print(f'A: {pr_ans[0]}')
        print(f'SQL: {pr_sql_q}')
    else:
        print(
            f'START ============================================================= '
        )
        print(f'{hds}')
        if show_table:
            print(engine.show_table(table_name))
        print(f'nlu: {nlu}')
        print(f'pr_sql_i : {pr_sql_i}')
        print(f'pr_sql_q : {pr_sql_q}')
        print(f'pr_ans: {pr_ans}')
        print(
            f'---------------------------------------------------------------------'
        )

    return pr_sql_i, pr_ans
def infernew(dev_loader, data_table, model, model_bert, bert_config, tokenizer,
             max_seq_length, num_target_layers, detail=False, path_db=None,
             st_pos=0, dset_name='train', EG=False, beam_size=4):
    """Run inference over a data loader and execute the predicted query.

    Args:
        dev_loader: iterable of batches (consumed via ``get_fields``).
        data_table: table metadata aligned with the loader's examples.
        model: decoder model (plain forward or ``beam_forward`` when ``EG``).
        model_bert: BERT encoder module.
        bert_config: configuration passed to ``get_wemb_bert``.
        tokenizer: BERT tokenizer.
        max_seq_length: maximum BERT input sequence length.
        num_target_layers: number of BERT output layers fed to the decoder.
        detail: unused; kept for interface compatibility.
        path_db: directory containing ``{dset_name}.db``.
        st_pos: unused; kept for interface compatibility.
        dset_name: database file name (without the ``.db`` suffix).
        EG: when True, use execution-guided beam decoding.
        beam_size: beam width for execution-guided decoding.

    Returns:
        (pr_sql_i, pr_ans) for the first batch that completes decoding.

    NOTE(review): the trailing ``return`` sits inside the batch loop, so only
    the first successfully decoded batch is processed — confirm this is
    intended before relying on full-loader coverage.
    """
    model.eval()
    model_bert.eval()

    engine = DBEngine(os.path.join(path_db, f"{dset_name}.db"))
    results = []  # per-example error records for skipped batches

    for iB, t in enumerate(dev_loader):
        nlu, nlu_t, sql_i, sql_q, sql_t, tb, hs_t, hds = get_fields(
            t, data_table, no_hs_t=True, no_sql_t=True)
        g_sc, g_sa, g_wn, g_wc, g_wo, g_wv = get_g(sql_i)
        g_wvi_corenlp = get_g_wvi_corenlp(t)

        # BERT encoding of question + headers.
        wemb_n, wemb_h, l_n, l_hpu, l_hs, \
            nlu_tt, t_to_tt_idx, tt_to_t_idx \
            = get_wemb_bert(bert_config, model_bert, tokenizer, nlu_t, hds,
                            max_seq_length,
                            num_out_layers_n=num_target_layers,
                            num_out_layers_h=num_target_layers)

        try:
            g_wvi = get_g_wvi_bert_from_g_wvi_corenlp(t_to_tt_idx,
                                                      g_wvi_corenlp)
            g_wv_str, g_wv_str_wp = convert_pr_wvi_to_string(
                g_wvi, nlu_t, nlu_tt, tt_to_t_idx, nlu)
        except Exception:
            # The where-condition could not be located in nlu_tt. During
            # training such an example is dropped; during test it counts as
            # wrongly answered. Record the failure and skip the batch.
            for b in range(len(nlu)):
                results.append({
                    "error": "Skip happened",
                    "nlu": nlu[b],
                    "table_id": tb[b]["id"],
                })
            continue

        # Per-example knowledge vectors; fall back to all-zeros when absent.
        knowledge = [
            k["bertindex_knowledge"] if "bertindex_knowledge" in k
            else max(l_n) * [0]
            for k in t
        ]
        knowledge_header = [
            k["header_knowledge"] if "header_knowledge" in k
            else max(l_hs) * [0]
            for k in t
        ]

        if not EG:
            # No execution-guided decoding: plain forward pass.
            s_sc, s_sa, s_wn, s_wc, s_wo, s_wv = model(
                wemb_n, l_n, wemb_h, l_hpu, l_hs,
                knowledge=knowledge, knowledge_header=knowledge_header)

            # Loss is computed for parity with the training path; it is not
            # otherwise used in this function.
            loss = Loss_sw_se(s_sc, s_sa, s_wn, s_wc, s_wo, s_wv,
                              g_sc, g_sa, g_wn, g_wc, g_wo, g_wvi)

            # Turn scores into discrete predictions.
            pr_sc, pr_sa, pr_wn, pr_wc, pr_wo, pr_wvi = pred_sw_se(
                s_sc, s_sa, s_wn, s_wc, s_wo, s_wv)
            pr_wv_str, pr_wv_str_wp = convert_pr_wvi_to_string(
                pr_wvi, nlu_t, nlu_tt, tt_to_t_idx, nlu)
            # g_sql_i = generate_sql_i(g_sc, g_sa, g_wn, g_wc, g_wo, g_wv_str, nlu)
            pr_sql_i = generate_sql_i(pr_sc, pr_sa, pr_wn, pr_wc, pr_wo,
                                      pr_wv_str, nlu)
        else:
            # Execution-guided beam decoding.
            prob_sca, prob_w, prob_wn_w, pr_sc, pr_sa, pr_wn, pr_sql_i = \
                model.beam_forward(wemb_n, l_n, wemb_h, l_hpu, l_hs, engine,
                                   tb, nlu_t, nlu_tt, tt_to_t_idx, nlu,
                                   beam_size=beam_size,
                                   knowledge=knowledge,
                                   knowledge_header=knowledge_header)
            # Sort conditions and regenerate the where-clause components.
            pr_wc, pr_wo, pr_wv, pr_sql_i = sort_and_generate_pr_w(pr_sql_i)
            # The following exist only for consistency with the no-EG case.
            pr_wvi = None
            pr_wv_str = None
            pr_wv_str_wp = None
            loss = torch.tensor([0])

        pr_sql_q1 = generate_sql_q(pr_sql_i, tb)
        pr_sql_q = [pr_sql_q1]

        try:
            pr_ans, _ = engine.execute_return_query(tb[0]['id'], pr_sc[0],
                                                    pr_sa[0],
                                                    pr_sql_i[0]['conds'])
        except Exception:  # best-effort: fall back to a placeholder
            pr_ans = ['Answer not found.']
            pr_sql_q = ['Answer not found.']

        # (The original guarded this with `yes = True; if yes:` whose dead
        # `else` branch referenced an undefined `table_name`; the always-taken
        # branch is kept, the unreachable one removed.)
        print(f'Q: {nlu[0]}')
        print(f'A: {pr_ans[0]}')
        print(f'SQL: {pr_sql_q}')

        return pr_sql_i, pr_ans
def infer(nlu1, table_name, data_table, path_db, db_name,
          model, model_bert, bert_config, max_seq_length, num_target_layers,
          beam_size=4, show_table=False, show_answer_only=False):
    """Single-question inference: character-tokenised, plain-forward variant.

    Change log carried over from the original (translated):
      * 2020/12/01 — select the table matching ``table_name`` instead of
        always using ``data_table[0]``.
      * 2020/12/02 / 2020/12/11 — Chinese word segmentation via
        Stanza/CoreNLP dropped in favour of full per-character splitting.
      * 2020/12/03 — agg/sel emitted as lists (handled in ``generate_sql_i``).
      * 2020/12/11 — plain ``model(...)`` forward pass replaces
        execution-guided ``model.beam_forward``.

    NOTE(review): this redefines ``infer`` and shadows the earlier definition
    in this module.

    Args:
        nlu1: natural-language question (a single string).
        table_name: name of the target table within ``data_table``.
        data_table: list of table dicts, each with ``'name'`` and ``'header'``.
        path_db: directory containing the SQLite database file.
        db_name: database file name (without the ``.db`` suffix).
        model: decoder model (plain forward pass).
        model_bert: BERT encoder module.
        bert_config: configuration passed to ``get_wemb_bert``.
        max_seq_length: maximum BERT input sequence length.
        num_target_layers: number of BERT output layers fed to the decoder.
        beam_size: unused in this variant; kept for interface compatibility.
        show_table: when True, also print the table contents.
        show_answer_only: when True, print only question / answer / SQL.

    Returns:
        (pr_sql_i, pr_ans): predicted SQL structure(s) and the execution
        answer(s), or ``'Answer not found.'`` placeholders on failure.

    Raises:
        ValueError: if ``table_name`` matches no table in ``data_table``.
        EnvironmentError: if decoding did not yield exactly one query.
    """
    model.eval()
    model_bert.eval()

    engine = DBEngine(os.path.join(path_db, f"{db_name}.db"))

    nlu = [nlu1]  # batch of one question
    # nlu_t1 = tokenize_corenlp_direct_version(client, nlu1)
    nlu_t1 = list(nlu1)  # per-character tokenisation (2020/12/11 change)
    nlu_t = [nlu_t1]

    # tb1 = data_table[0]
    # Pick the table whose name matches the question's table (2020/12/01).
    tb1 = None
    for candidate_table in data_table:
        if candidate_table['name'] == table_name:
            tb1 = candidate_table
            break
    if tb1 is None:
        # Previously an unmatched name left `tb1` unbound and crashed later
        # with a NameError; fail fast with a clear message instead.
        raise ValueError(f"table {table_name!r} not found in data_table")

    hds1 = tb1['header']
    tb = [tb1]
    hds = [hds1]

    # BERT encoding of question + headers.
    wemb_n, wemb_h, l_n, l_hpu, l_hs, \
        nlu_tt, t_to_tt_idx, tt_to_t_idx \
        = get_wemb_bert(bert_config, model_bert, tokenizer, nlu_t, hds,
                        max_seq_length,
                        num_out_layers_n=num_target_layers,
                        num_out_layers_h=num_target_layers)

    # 2020/12/11 change: plain forward pass instead of execution-guided
    # beam search (model.beam_forward).
    s_sc, s_sa, s_wn, s_wc, s_wo, s_wv = model(wemb_n, l_n, wemb_h,
                                               l_hpu, l_hs)

    # Turn scores into discrete predictions.
    pr_sc, pr_sa, pr_wn, pr_wc, pr_wo, pr_wvi = pred_sw_se(
        s_sc, s_sa, s_wn, s_wc, s_wo, s_wv)

    # Recover where-values as strings (per-character tokens).
    pr_wv_str, pr_wv_str_wp = convert_pr_wvi_to_string(
        pr_wvi, nlu_t, nlu_tt, tt_to_t_idx, nlu)

    # Assemble the predicted SQL structure (conds, agg/sel as lists).
    pr_sql_i = generate_sql_i(pr_sc, pr_sa, pr_wn, pr_wc, pr_wo,
                              pr_wv_str, nlu)

    # Split out where-col / where-op / where-val.
    pr_wc, pr_wo, pr_wv, pr_sql_i = sort_and_generate_pr_w(pr_sql_i)

    if len(pr_sql_i) != 1:  # a single question must yield exactly one query
        raise EnvironmentError

    pr_sql_q1 = generate_sql_q(pr_sql_i, [tb1])  # render the SQL text
    pr_sql_q = [pr_sql_q1]  # list-wrapped: infer may produce several queries

    # Execute the generated SQL.
    try:
        pr_ans, _ = engine.execute_return_query(tb[0]['id'], pr_sc[0],
                                                pr_sa[0],
                                                pr_sql_i[0]['conds'])
    except Exception:  # best-effort: fall back to a placeholder answer
        pr_ans = ['Answer not found.']
        pr_sql_q = ['Answer not found.']

    if show_answer_only:
        print(f'Q: {nlu[0]}')
        print(f'A: {pr_ans[0]}')
        print(f'SQL: {pr_sql_q}')
    else:
        print(
            f'START ============================================================= '
        )
        print(f'{hds}')
        if show_table:
            print(engine.show_table(table_name))
        print(f'nlu: {nlu}')
        print(f'pr_sql_i : {pr_sql_i}')
        print(f'pr_sql_q : {pr_sql_q}')
        print(f'pr_ans: {pr_ans}')
        print(
            f'---------------------------------------------------------------------'
        )

    return pr_sql_i, pr_ans