def predict(data_loader, data_table, model, model_bert, bert_config, tokenizer,
            max_seq_length, num_target_layers, detail=False, st_pos=0, cnt_tot=1,
            EG=False, beam_size=4, path_db=None, dset_name='test'):
    model.eval()
    model_bert.eval()

    engine = DBEngine(os.path.join(path_db, f"{dset_name}.db"))
    results = []
    for iB, t in enumerate(data_loader):
        nlu, nlu_t, sql_i, sql_q, sql_t, tb, hs_t, hds = get_fields(t, data_table, no_hs_t=True, no_sql_t=True)

        g_sc, g_sa, g_wn, g_wc, g_wo, g_wv = get_g(sql_i)
        g_wvi_corenlp = get_g_wvi_corenlp(t)

        wemb_n, wemb_h, l_n, l_hpu, l_hs, \
        nlu_tt, t_to_tt_idx, tt_to_t_idx \
            = get_wemb_bert(bert_config, model_bert, tokenizer, nlu_t, hds, max_seq_length,
                            num_out_layers_n=num_target_layers, num_out_layers_h=num_target_layers)

        if not EG:
            # No execution-guided decoding
            s_sc, s_sa, s_wn, s_wc, s_wo, s_wv = model(wemb_n, l_n, wemb_h, l_hpu, l_hs)
            pr_sc, pr_sa, pr_wn, pr_wc, pr_wo, pr_wvi = pred_sw_se(s_sc, s_sa, s_wn, s_wc, s_wo, s_wv)
            pr_wv_str, pr_wv_str_wp = convert_pr_wvi_to_string(pr_wvi, nlu_t, nlu_tt, tt_to_t_idx, nlu)
            pr_sql_i = generate_sql_i(pr_sc, pr_sa, pr_wn, pr_wc, pr_wo, pr_wv_str, nlu)
        else:
            # Execution-guided decoding
            prob_sca, prob_w, prob_wn_w, pr_sc, pr_sa, pr_wn, pr_sql_i \
                = model.beam_forward(wemb_n, l_n, wemb_h, l_hpu, l_hs, engine, tb,
                                     nlu_t, nlu_tt, tt_to_t_idx, nlu, beam_size=beam_size)
            # Sort conditions and regenerate the where-clause predictions
            pr_wc, pr_wo, pr_wv, pr_sql_i = sort_and_generate_pr_w(pr_sql_i)
            # The following variables are kept only for consistency with the no-EG case.
            pr_wvi = None  # not used
            pr_wv_str = None
            pr_wv_str_wp = None

        pr_sql_q = generate_sql_q(pr_sql_i, tb)
        pr_sql_q_base = generate_sql_q_base(pr_sql_i, tb)

        for b, (pr_sql_i1, pr_sql_q1, pr_sql_q1_base) in enumerate(zip(pr_sql_i, pr_sql_q, pr_sql_q_base)):
            results1 = {}
            results1["query"] = pr_sql_i1
            results1["table_id"] = tb[b]["id"]
            results1["nlu"] = nlu[b]
            results1["sql"] = pr_sql_q1
            results1["sql_with_params"] = pr_sql_q1_base
            rr = engine.execute_query(tb[b]["id"], Query.from_dict(pr_sql_i1, ordered=True), lower=False)
            results1["answer"] = rr
            results.append(results1)

    return results
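# Usage sketch (an assumption, not the repo's driver code): run predict() over a
# prepared loader and dump the executed predictions to JSONL. The wrapper name,
# argument values, and output path are illustrative placeholders; only predict()
# itself comes from the code above, and `json` is assumed to be imported in
# this module.
def dump_predictions(data_loader, data_table, model, model_bert, bert_config,
                     tokenizer, out_path='results_test.jsonl'):
    results = predict(data_loader, data_table, model, model_bert, bert_config,
                      tokenizer, max_seq_length=222, num_target_layers=2,
                      EG=True, beam_size=4,
                      path_db='./data_and_model', dset_name='test')
    with open(out_path, 'w') as f:
        for r in results:
            f.write(json.dumps(r) + '\n')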
parser.add_argument('--source_file', help='source file for the prediction', default=path_source)
parser.add_argument('--db_file', help='source database for the prediction', default=path_db)
parser.add_argument('--pred_file', help='predictions by the model', default=path_pred)
parser.add_argument('--ordered', action='store_true',
                    help='whether the exact match should consider the order of conditions')
args = parser.parse_args()
args.ordered = ordered

engine = DBEngine(args.db_file)
exact_match = []
with open(args.source_file) as fs, open(args.pred_file) as fp:
    grades = []
    for ls, lp in tqdm(zip(fs, fp), total=count_lines(args.source_file)):
        eg = json.loads(ls)
        ep = json.loads(lp)
        qg = Query.from_dict(eg['sql'], ordered=args.ordered)
        gold = engine.execute_query(eg['table_id'], qg, lower=True)
        pred = ep.get('error', None)
        qp = None
        if not ep.get('error', None):
            try:
                qp = Query.from_dict(ep['query'], ordered=args.ordered)
                pred = engine.execute_query(eg['table_id'], qp, lower=True)
            except Exception as e:
                pred = repr(e)
        correct = pred == gold
        match = qp == qg
        grades.append(correct)
        exact_match.append(match)
    # Report execution accuracy and logical-form accuracy, as in the official
    # WikiSQL evaluation script.
    print(json.dumps({
        'ex_accuracy': sum(grades) / len(grades),
        'lf_accuracy': sum(exact_match) / len(exact_match),
    }, indent=2))
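# Usage sketch (an assumption): the evaluation block above is intended to be run
# as a script; with the flags it declares, an invocation could look roughly like
#   python predict.py --source_file test_tok.jsonl --db_file test.db \
#       --pred_file results_test.jsonl --ordered
# The script name and file names are placeholders; the flags themselves come from
# the argparse calls above. It prints execution accuracy (computed from `grades`)
# and logical-form accuracy (computed from `exact_match`).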
def append_cond(sel, cols, agg, spans, conds, sqls, engine, gold, table_id, is_final=False):
    """Extend the partial condition list `conds` with one more condition and keep
    every candidate SQL whose execution result matches the gold answer."""
    span_len = len(spans)
    for con_col in range(cols):
        con2_col = con_col
        #print(f"con2_col:{con2_col}")
        for op_i, _ in enumerate(cond_ops):
            con2_op = op_i
            #print(f"con2_op:{con2_op}")
            for i in range(span_len):
                for span in spans[i]:
                    _conds = deepcopy(conds)
                    con2 = ' '.join(span)
                    try:
                        con2 = float(con2)
                    except ValueError:
                        # Non-numeric values are only tried with the equality operator.
                        if con2_op > 0:
                            continue
                    _conds.append([con2_col, con2_op, con2])
                    sql = {'sel': sel, 'conds': _conds, 'agg': agg}
                    qg = Query.from_dict(sql, ordered=True)
                    res = engine.execute_query(table_id, qg, lower=True)
                    if res == gold:
                        #print("res is gold")
                        sqls.append(sql)
    """
    Disabled: recursive extension to a third condition, gated by aggregation type.

    if not is_final:
        if res and res[0] is not None:
            if agg == 0 and set(gold).issubset(set(res)):  # agg = none
                print(f"agg:{agg},con3")
                sqls = append_cond(sel, cols, agg, spans, conds, sqls, engine, gold, table_id, True)
            elif agg == 1:  # agg = MAX
                try:
                    gold_ans = float(gold[0])
                except ValueError:
                    continue
                try:
                    res_ans = float(res[0])
                except ValueError:
                    continue
                if gold_ans <= res_ans:
                    print(f"agg={agg},cond3,gold_ans:{gold_ans},res_ans:{res_ans}")
                    sqls = append_cond(sel, cols, agg, spans, conds, sqls, engine, gold, table_id, True)
            elif agg == 2:  # agg = MIN
                try:
                    gold_ans = float(gold[0])
                except ValueError:
                    continue
                try:
                    res_ans = float(res[0])
                except ValueError:
                    continue
                if gold_ans >= res_ans:
                    print(f"agg:{agg},con3")
                    sqls = append_cond(sel, cols, agg, spans, conds, sqls, engine, gold, table_id, True)
            elif agg == 3:  # agg = COUNT
                try:
                    gold_ans = float(gold[0])
                except ValueError:
                    continue
                try:
                    res_ans = float(res[0])
                except ValueError:
                    continue
                if gold_ans <= res_ans:
                    print(f"agg:{agg},con3")
                    sqls = append_cond(sel, cols, agg, spans, conds, sqls, engine, gold, table_id, True)
            elif agg == 4:  # agg = SUM
                try:
                    gold_ans = float(gold[0])
                except ValueError:
                    continue
                try:
                    res_ans = float(res[0])
                except ValueError:
                    continue
                if gold_ans <= res_ans:
                    print(f"agg:{agg},con3")
                    sqls = append_cond(sel, cols, agg, spans, conds, sqls, engine, gold, table_id, True)
            elif agg == 5:  # agg = AVG
                try:
                    gold_ans = float(gold[0])
                except ValueError:
                    continue
                try:
                    res_ans = float(res[0])
                except ValueError:
                    continue
                print(f"agg:{agg},con3")
                sqls = append_cond(sel, cols, agg, spans, conds, sqls, engine, gold, table_id, True)
    """
    return sqls
def get_answer(sql, table_id, engine):
    qg = Query.from_dict(sql, ordered=False)
    gold = engine.execute_query(table_id, qg, lower=True)
    return gold
def generate_sqls(idx, query, engine):
    """Generate all candidate SQLs for this query whose execution result matches
    its gold answer."""
    table_id = query['table_id']
    table = engine.show_table(table_id)
    gold_sql = query['sql']
    gold = get_answer(gold_sql, table_id, engine)
    cols = len(table[0])
    sqls = []
    question = query['question'].strip(" ?").split()
    #print(question)
    #print(gold_sql)
    #print(gold)
    span_len = len(question) if len(question) < 10 else 10
    spans = generate_span(question, span_len)
    for col in tqdm(range(cols)):
        sel = col
        #print(f"sel:{sel}")
        for agg_i, _ in enumerate(agg_ops):
            agg = agg_i
            # at most 3 conditions
            #print(f"agg:{agg}")
            for con_col in range(cols):  # condition 1
                con1_col = con_col
                #print(f"con1_col:{con1_col}")
                for op_i, _ in enumerate(cond_ops):
                    con1_op = op_i
                    #print(f"con1_op:{con1_op}")
                    for i in range(span_len):
                        for span in spans[i]:
                            con1 = ' '.join(span)
                            try:
                                con1 = float(con1)
                            except ValueError:
                                if con1_op > 0:
                                    continue
                            conds = [[con1_col, con1_op, con1]]
                            sql = {'sel': sel, 'conds': conds, 'agg': agg}
                            qg = Query.from_dict(sql, ordered=True)
                            res = engine.execute_query(table_id, qg, lower=True)
                            if res == gold:
                                sqls.append(sql)
                                #print("res is gold")
                            if res and res[0] is not None:
                                if agg == 0 and set(gold).issubset(set(res)):  # agg = none
                                    #print(f"agg={agg},cond2")
                                    sqls = append_cond(sel, cols, agg, spans, conds, sqls, engine, gold, table_id)
                                elif agg == 1:  # agg = MAX
                                    try:
                                        gold_ans = float(gold[0])
                                    except ValueError:
                                        continue
                                    try:
                                        res_ans = float(res[0])
                                    except ValueError:
                                        continue
                                    if gold_ans <= res_ans:
                                        #print(res)
                                        #print(f"agg={agg},cond2,gold_ans:{gold_ans},res_ans:{res_ans}")
                                        sqls = append_cond(sel, cols, agg, spans, conds, sqls, engine, gold, table_id)
                                elif agg == 2:  # agg = MIN
                                    try:
                                        gold_ans = float(gold[0])
                                    except ValueError:
                                        continue
                                    try:
                                        res_ans = float(res[0])
                                    except ValueError:
                                        continue
                                    if gold_ans >= res_ans:
                                        #print(f"agg={agg},cond2")
                                        sqls = append_cond(sel, cols, agg, spans, conds, sqls, engine, gold, table_id)
                                elif agg == 3:  # agg = COUNT
                                    try:
                                        gold_ans = float(gold[0])
                                    except ValueError:
                                        continue
                                    try:
                                        res_ans = float(res[0])
                                    except ValueError:
                                        continue
                                    if gold_ans <= res_ans:
                                        #print(f"agg={agg},cond2")
                                        sqls = append_cond(sel, cols, agg, spans, conds, sqls, engine, gold, table_id)
                                elif agg == 4:  # agg = SUM
                                    try:
                                        gold_ans = float(gold[0])
                                    except ValueError:
                                        continue
                                    try:
                                        res_ans = float(res[0])
                                    except ValueError:
                                        continue
                                    if gold_ans <= res_ans:
                                        #print(f"agg={agg},cond2")
                                        sqls = append_cond(sel, cols, agg, spans, conds, sqls, engine, gold, table_id)
                                elif agg == 5:  # agg = AVG
                                    try:
                                        gold_ans = float(gold[0])
                                    except ValueError:
                                        continue
                                    try:
                                        res_ans = float(res[0])
                                    except ValueError:
                                        continue
                                    #print(f"agg={agg},cond2")
                                    sqls = append_cond(sel, cols, agg, spans, conds, sqls, engine, gold, table_id)
    gen_sqls = {}
    gen_sqls['id'] = idx
    gen_sqls['sqls'] = json.dumps(sqls)
    if len(sqls) == 0:
        print(idx)
    return gen_sqls
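# Usage sketch (an assumption, not the repo's exact pipeline): enumerate candidate
# SQLs for every training question and write them to the file that the filtering
# step below reads (`train_distant.jsonl`). The wrapper name and loop structure
# are illustrative; `queries` and `engine` are assumed to be built as in the
# driver code below.
def dump_distant_sqls(queries, engine, out_path='train_distant.jsonl'):
    with open(out_path, 'w') as f:
        for idx, query in enumerate(tqdm(queries)):
            gen = generate_sqls(idx, query, engine)
            f.write(json.dumps(gen) + '\n')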
p_sqls_path = os.path.join(root, 'data/distant_data', 'train_distant.jsonl')
queries = extract.read_queries(query_path)
p_sqlss = extract.read_potential_sqls(p_sqls_path)

answer_path = './syn.txt'
g_answers = extract.read_gold_answers(answer_path)
print(len(g_answers))

engine = DBEngine_s('./data_and_model/train.db')

rr_p_sqlss = []
for i, p_sqls in enumerate(tqdm(p_sqlss)):
    rr_p_sqls = []
    if len(p_sqls) < 3:
        rr_p_sqls = [query['query'] for query in p_sqls]
    else:
        # Keep only candidate SQLs whose execution result matches the gold answer.
        for p_sql in p_sqls:
            qg = Query.from_dict(p_sql['query'], ordered=True)
            res = engine.execute_query(queries[i]['table_id'], qg, lower=True)
            if res == g_answers[i]:
                rr_p_sqls.append(p_sql['query'])
    if len(rr_p_sqls) == 0:
        print(f"{i}\n")
    rr_p_sqlss.append(rr_p_sqls)

with open('rr_p.jsonl', 'w') as f:
    for sqls in rr_p_sqlss:
        f.write(json.dumps(sqls))
        f.write('\n')

# Ad-hoc debugging snippet, kept disabled:
"""
engine = DBEngine_s('./data_and_model/train.db')
sql = {"sel": 0, "conds": [[3, 0, "468-473 (6)"]], "agg": 3}
qg = Query.from_dict(sql, ordered=True)
import extract
from tqdm import tqdm
import json
import sys, os
from wikisql.lib.dbengine import DBEngine
from wikisql.lib.query import Query
from sqlnet.dbengine import DBEngine as DBEngine_s
from copy import deepcopy
import eventlet
from IPython import embed

root = '/mnt/sda/qhz/sqlova'
query_path = os.path.join(root, 'data_and_model', 'train_tok_origin.jsonl')
table_path = os.path.join(root, 'data_and_model', 'train.tables.jsonl')
queries = extract.read_queries(query_path)

engine = DBEngine_s('./data_and_model/train.db')
engine1 = DBEngine('./data_and_model/train.db')

sql = {"sel": 9, "conds": [[3, 1, 8793], [5, 2, 1030]], "agg": 2}
qg = Query.from_dict(sql, ordered=True)
res = engine.execute_query(queries[15841]['table_id'], qg, lower=True)
print(res)