Exemplo n.º 1
0
def predict(data_loader, data_table, model, model_bert, bert_config, tokenizer,
            max_seq_length,
            num_target_layers, detail=False, st_pos=0, cnt_tot=1, EG=False, beam_size=4,
            path_db=None, dset_name='test'):
    """Predict a SQL query for every example in ``data_loader`` and execute it.

    Both ``model`` and ``model_bert`` are switched to eval mode. Each batch is
    encoded with ``get_wemb_bert`` and decoded either greedily (``EG=False``)
    or with execution-guided beam search (``EG=True``, via
    ``model.beam_forward``, which receives the DB engine and ``beam_size``).
    Every predicted query is then run against ``<path_db>/<dset_name>.db``.

    Args:
        data_loader: iterable of batches understood by ``get_fields``.
        data_table: table lookup passed to ``get_fields``.
        model, model_bert: SQL decoder and BERT encoder.
        bert_config, tokenizer, max_seq_length, num_target_layers:
            forwarded to ``get_wemb_bert``.
        detail, st_pos, cnt_tot: accepted for interface compatibility; not
            used by this function.
        EG: if True, use execution-guided decoding.
        beam_size: beam width for execution-guided decoding.
        path_db: directory containing the ``<dset_name>.db`` database file.
        dset_name: database file stem (default ``'test'``).

    Returns:
        A list with one dict per example, with keys ``"query"`` (predicted
        query dict), ``"table_id"``, ``"nlu"`` (the question), ``"sql"``,
        ``"sql_with_params"`` and ``"answer"`` (the execution result).

    Raises:
        TypeError: if ``path_db`` is not given.
    """
    model.eval()
    model_bert.eval()

    if path_db is None:
        # os.path.join(None, ...) would fail with an unhelpful message.
        raise TypeError(
            "predict() requires path_db: the directory holding the "
            f"'{dset_name}.db' file"
        )
    engine = DBEngine(os.path.join(path_db, f"{dset_name}.db"))

    results = []
    for t in data_loader:
        # Only the question, its tokens, the tables and the headers are needed
        # for prediction. Gold-label fields were extracted here previously but
        # never used.
        nlu, nlu_t, _sql_i, _sql_q, _sql_t, tb, _hs_t, hds = get_fields(
            t, data_table, no_hs_t=True, no_sql_t=True)
        wemb_n, wemb_h, l_n, l_hpu, l_hs, \
        nlu_tt, t_to_tt_idx, tt_to_t_idx \
            = get_wemb_bert(bert_config, model_bert, tokenizer, nlu_t, hds, max_seq_length,
                            num_out_layers_n=num_target_layers, num_out_layers_h=num_target_layers)

        if not EG:
            # Greedy decoding (no execution guidance).
            s_sc, s_sa, s_wn, s_wc, s_wo, s_wv = model(wemb_n, l_n, wemb_h, l_hpu, l_hs)
            pr_sc, pr_sa, pr_wn, pr_wc, pr_wo, pr_wvi = pred_sw_se(s_sc, s_sa, s_wn, s_wc, s_wo, s_wv)
            pr_wv_str, _ = convert_pr_wvi_to_string(pr_wvi, nlu_t, nlu_tt, tt_to_t_idx, nlu)
            pr_sql_i = generate_sql_i(pr_sc, pr_sa, pr_wn, pr_wc, pr_wo, pr_wv_str, nlu)
        else:
            # Execution-guided decoding: the beam search is given the DB engine.
            *_, pr_sql_i = model.beam_forward(wemb_n, l_n, wemb_h, l_hpu, l_hs, engine, tb,
                                              nlu_t, nlu_tt, tt_to_t_idx, nlu,
                                              beam_size=beam_size)
            # Sort the where-conditions and regenerate the query dicts.
            *_, pr_sql_i = sort_and_generate_pr_w(pr_sql_i)

        pr_sql_q = generate_sql_q(pr_sql_i, tb)
        pr_sql_q_base = generate_sql_q_base(pr_sql_i, tb)

        for b, (pr_sql_i1, pr_sql_q1, pr_sql_q1_base) in enumerate(zip(pr_sql_i, pr_sql_q, pr_sql_q_base)):
            table_id = tb[b]["id"]
            answer = engine.execute_query(table_id, Query.from_dict(pr_sql_i1, ordered=True), lower=False)
            results.append({
                "query": pr_sql_i1,
                "table_id": table_id,
                "nlu": nlu[b],
                "sql": pr_sql_q1,
                "sql_with_params": pr_sql_q1_base,
                "answer": answer,
            })

    return results
Exemplo n.º 2
0
 # Filter candidate SQL queries by execution result and write the survivors
 # to 'rr_p.jsonl', one JSON list of query dicts per line (one line per
 # question, in input order).
 queries = extract.read_queries(query_path)
 p_sqlss = extract.read_potential_sqls(p_sqls_path)
 answer_path = './syn.txt'
 g_answers = extract.read_gold_answers(answer_path)
 print(len(g_answers))
 engine = DBEngine_s('./data_and_model/train.db')

 # A question with fewer candidates than this keeps all of them unfiltered.
 # Otherwise only the candidates whose execution result equals the gold
 # answer are kept.
 min_candidates_to_filter = 3
 rr_p_sqlss = []
 for i, p_sqls in enumerate(tqdm(p_sqlss)):
     if len(p_sqls) < min_candidates_to_filter:
         rr_p_sqls = [p_sql['query'] for p_sql in p_sqls]
     else:
         # Per-question values, hoisted out of the candidate loop.
         table_id = queries[i]['table_id']
         gold_answer = g_answers[i]
         rr_p_sqls = []
         for p_sql in p_sqls:
             qg = Query.from_dict(p_sql['query'], ordered=True)
             res = engine.execute_query(table_id, qg, lower=True)
             if res == gold_answer:
                 rr_p_sqls.append(p_sql['query'])
     if not rr_p_sqls:
         # Report questions left with no candidate. tqdm.write keeps the
         # progress bar intact, whereas print() would break it.
         tqdm.write(f"{i}\n")
     rr_p_sqlss.append(rr_p_sqls)

 with open('rr_p.jsonl', 'w', encoding='utf-8') as f:
     for sqls in rr_p_sqlss:
         f.write(json.dumps(sqls) + '\n')