def predict(data_loader, data_table, model, model_bert, bert_config, tokenizer,
            max_seq_length, num_target_layers, detail=False, st_pos=0, cnt_tot=1,
            EG=False, beam_size=4, path_db=None, dset_name='test'):
    """Run the SQLova model over *data_loader* and return predicted SQL + answers.

    For each batch, encodes the question/headers with BERT, decodes a SQL query
    (greedy, or execution-guided beam search when ``EG`` is True), renders it to
    a query string, executes it against the ``{dset_name}.db`` SQLite database
    under ``path_db``, and collects one result dict per example.

    Args:
        data_loader: iterable of batches (project DataLoader).
        data_table: table metadata used by ``get_fields``.
        model: SQLova decoder module (must expose ``beam_forward`` for EG mode).
        model_bert, bert_config, tokenizer: BERT encoder components.
        max_seq_length: max BERT input length.
        num_target_layers: number of BERT output layers fed to the decoder.
        detail, st_pos, cnt_tot: kept for interface compatibility; unused here.
        EG: if True, use execution-guided beam decoding.
        beam_size: beam width for EG decoding.
        path_db: directory containing ``{dset_name}.db``.
        dset_name: dataset split name used to locate the database file.

    Returns:
        list of dicts with keys ``query``, ``table_id``, ``nlu``, ``sql``,
        ``sql_with_params``, ``answer``.
    """
    model.eval()
    model_bert.eval()

    engine = DBEngine(os.path.join(path_db, f"{dset_name}.db"))
    results = []
    for iB, t in enumerate(data_loader):
        nlu, nlu_t, sql_i, sql_q, sql_t, tb, hs_t, hds = get_fields(
            t, data_table, no_hs_t=True, no_sql_t=True)
        # Gold labels are extracted for interface parity; prediction below
        # does not use them.
        g_sc, g_sa, g_wn, g_wc, g_wo, g_wv = get_g(sql_i)
        g_wvi_corenlp = get_g_wvi_corenlp(t)

        wemb_n, wemb_h, l_n, l_hpu, l_hs, \
        nlu_tt, t_to_tt_idx, tt_to_t_idx \
            = get_wemb_bert(bert_config, model_bert, tokenizer, nlu_t, hds,
                            max_seq_length,
                            num_out_layers_n=num_target_layers,
                            num_out_layers_h=num_target_layers)

        if not EG:
            # No execution-guided decoding: greedy argmax over module scores.
            s_sc, s_sa, s_wn, s_wc, s_wo, s_wv = model(wemb_n, l_n, wemb_h, l_hpu, l_hs)
            pr_sc, pr_sa, pr_wn, pr_wc, pr_wo, pr_wvi = pred_sw_se(
                s_sc, s_sa, s_wn, s_wc, s_wo, s_wv, )
            pr_wv_str, pr_wv_str_wp = convert_pr_wvi_to_string(
                pr_wvi, nlu_t, nlu_tt, tt_to_t_idx, nlu)
            pr_sql_i = generate_sql_i(pr_sc, pr_sa, pr_wn, pr_wc, pr_wo, pr_wv_str, nlu)
        else:
            # Execution-guided decoding: beam search pruned by executing
            # candidate queries against the DB engine.
            prob_sca, prob_w, prob_wn_w, pr_sc, pr_sa, pr_wn, pr_sql_i = \
                model.beam_forward(wemb_n, l_n, wemb_h, l_hpu, l_hs, engine, tb,
                                   nlu_t, nlu_tt, tt_to_t_idx, nlu,
                                   beam_size=beam_size)
            # sort and generate
            pr_wc, pr_wo, pr_wv, pr_sql_i = sort_and_generate_pr_w(pr_sql_i)
            # Following variables are just for consistency with no-EG case.
            pr_wvi = None  # not used
            pr_wv_str = None
            pr_wv_str_wp = None

        pr_sql_q = generate_sql_q(pr_sql_i, tb)
        pr_sql_q_base = generate_sql_q_base(pr_sql_i, tb)

        for b, (pr_sql_i1, pr_sql_q1, pr_sql_q1_base) in enumerate(
                zip(pr_sql_i, pr_sql_q, pr_sql_q_base)):
            results1 = {}
            results1["query"] = pr_sql_i1
            results1["table_id"] = tb[b]["id"]
            results1["nlu"] = nlu[b]
            results1["sql"] = pr_sql_q1
            results1["sql_with_params"] = pr_sql_q1_base
            # Execute the predicted query to obtain the denotation (answer).
            rr = engine.execute_query(tb[b]["id"],
                                      Query.from_dict(pr_sql_i1, ordered=True),
                                      lower=False)
            results1["answer"] = rr
            results.append(results1)

    return results
# Filter candidate SQL queries by executing them against the training DB and
# keeping only those whose execution result matches the gold answer.
queries = extract.read_queries(query_path)
p_sqlss = extract.read_potential_sqls(p_sqls_path)
answer_path = './syn.txt'
g_answers = extract.read_gold_answers(answer_path)
print(len(g_answers))

engine = DBEngine_s('./data_and_model/train.db')

rr_p_sqlss = []
for i, p_sqls in enumerate(tqdm(p_sqlss)):
    rr_p_sqls = []
    if len(p_sqls) < 3:
        # Too few candidates to filter reliably — keep them all.
        rr_p_sqls = [query['query'] for query in p_sqls]
    else:
        # Execution-guided filtering: keep only candidates whose result
        # matches the gold answer for this example.
        for p_sql in p_sqls:
            qg = Query.from_dict(p_sql['query'], ordered=True)
            res = engine.execute_query(queries[i]['table_id'], qg, lower=True)
            if res == g_answers[i]:
                rr_p_sqls.append(p_sql['query'])
    if len(rr_p_sqls) == 0:
        # No candidate survived — log the example index for inspection.
        print(f"{i}\n")
    rr_p_sqlss.append(rr_p_sqls)

# Persist the surviving candidate lists, one JSON array per line.
with open('rr_p.jsonl', 'w') as f:
    for sqls in rr_p_sqlss:
        f.write(json.dumps(sqls))
        f.write('\n')

# Commented-out single-query debug snippet, kept for manual sanity checks.
"""
engine=DBEngine_s('./data_and_model/train.db')
sql={"sel": 0, "conds": [[3, 0, "468-473 (6)"]], "agg": 3}
qg=Query.from_dict(sql, ordered=True)
res = engine.execute_query("1-10007452-3", qg, lower=True)
print(res)
"""