Beispiel #1
0
def predict(data_loader, data_table, model, model_bert, bert_config, tokenizer,
            max_seq_length,
            num_target_layers, detail=False, st_pos=0, cnt_tot=1, EG=False, beam_size=4,
            path_db=None, dset_name='test'):
    """Predict a SQL query for every example in ``data_loader`` and execute it.

    Both models are switched to eval mode, a ``DBEngine`` is opened on
    ``<path_db>/<dset_name>.db`` and each batch is decoded either directly
    (``EG=False``) or through ``model.beam_forward`` (``EG=True``).

    Args:
        data_loader: iterable of batches, passed to ``get_fields``.
        data_table: table data forwarded to ``get_fields``.
        model: prediction model; called directly, or via ``beam_forward``
            when ``EG`` is true.
        model_bert, bert_config, tokenizer: forwarded to ``get_wemb_bert``.
        max_seq_length: forwarded to ``get_wemb_bert``.
        num_target_layers: used for both ``num_out_layers_n`` and
            ``num_out_layers_h`` of ``get_wemb_bert``.
        detail, st_pos, cnt_tot: not used by this function body.
        EG: when true, use execution-guided (beam) decoding.
        beam_size: beam width forwarded to ``model.beam_forward``.
        path_db: directory holding the database file; the default ``None``
            makes ``os.path.join`` raise, so a value must be supplied.
        dset_name: base name of the ``.db`` file.

    Returns:
        A list with one dict per example holding the keys ``query``,
        ``table_id``, ``nlu``, ``sql``, ``sql_with_params`` and ``answer``.
    """
    model.eval()
    model_bert.eval()

    db_file = os.path.join(path_db, f"{dset_name}.db")
    engine = DBEngine(db_file)

    results = []
    for batch_idx, batch in enumerate(data_loader):
        nlu, nlu_t, sql_i, sql_q, sql_t, tb, hs_t, hds = get_fields(
            batch, data_table, no_hs_t=True, no_sql_t=True)
        # The gold labels are still extracted here although nothing below
        # reads them; the calls are kept so any failure they raise still occurs.
        g_sc, g_sa, g_wn, g_wc, g_wo, g_wv = get_g(sql_i)
        g_wvi_corenlp = get_g_wvi_corenlp(batch)

        (wemb_n, wemb_h, l_n, l_hpu, l_hs,
         nlu_tt, t_to_tt_idx, tt_to_t_idx) = get_wemb_bert(
            bert_config, model_bert, tokenizer, nlu_t, hds, max_seq_length,
            num_out_layers_n=num_target_layers,
            num_out_layers_h=num_target_layers)

        if EG:
            # Execution guided decoding
            (prob_sca, prob_w, prob_wn_w,
             pr_sc, pr_sa, pr_wn, pr_sql_i) = model.beam_forward(
                wemb_n, l_n, wemb_h, l_hpu, l_hs, engine, tb,
                nlu_t, nlu_tt, tt_to_t_idx, nlu,
                beam_size=beam_size)
            # sort and generate
            pr_wc, pr_wo, pr_wv, pr_sql_i = sort_and_generate_pr_w(pr_sql_i)
            # Placeholders that mirror the names bound by the non-EG branch.
            pr_wvi = None
            pr_wv_str = None
            pr_wv_str_wp = None
        else:
            # No execution guided decoding
            s_sc, s_sa, s_wn, s_wc, s_wo, s_wv = model(wemb_n, l_n, wemb_h, l_hpu, l_hs)
            pr_sc, pr_sa, pr_wn, pr_wc, pr_wo, pr_wvi = pred_sw_se(
                s_sc, s_sa, s_wn, s_wc, s_wo, s_wv)
            pr_wv_str, pr_wv_str_wp = convert_pr_wvi_to_string(
                pr_wvi, nlu_t, nlu_tt, tt_to_t_idx, nlu)
            pr_sql_i = generate_sql_i(pr_sc, pr_sa, pr_wn, pr_wc, pr_wo, pr_wv_str, nlu)

        pr_sql_q = generate_sql_q(pr_sql_i, tb)
        pr_sql_q_base = generate_sql_q_base(pr_sql_i, tb)

        per_example = zip(pr_sql_i, pr_sql_q, pr_sql_q_base)
        for b, (sql_dict, sql_text, sql_text_base) in enumerate(per_example):
            record = {
                "query": sql_dict,
                "table_id": tb[b]["id"],
                "nlu": nlu[b],
                "sql": sql_text,
                "sql_with_params": sql_text_base,
            }
            # Run the predicted query so the answer is stored with the record.
            record["answer"] = engine.execute_query(
                tb[b]["id"], Query.from_dict(sql_dict, ordered=True), lower=False)
            results.append(record)

    return results
Beispiel #2
0
    # Command-line options. The defaults (path_source, path_db, path_pred) are
    # names defined outside the visible fragment.
    parser.add_argument('--source_file', help='source file for the prediction', default=path_source)
    parser.add_argument('--db_file', help='source database for the prediction', default=path_db)
    parser.add_argument('--pred_file', help='predictions by the model', default=path_pred)
    parser.add_argument('--ordered', action='store_true',
                        help='whether the exact match should consider the order of conditions')
    args = parser.parse_args()
    # NOTE(review): this overwrites whatever --ordered parsed to with `ordered`,
    # a name not defined in the visible code (presumably a parameter of the
    # enclosing function) -- confirm that discarding the CLI flag is intended.
    args.ordered = ordered

    engine = DBEngine(args.db_file)
    # One boolean per example: predicted Query object equals the gold Query.
    exact_match = []
    with open(args.source_file) as fs, open(args.pred_file) as fp:
        # One boolean per example: executed prediction equals executed gold.
        grades = []
        # The two files are read in lockstep, one JSON object per line; zip()
        # stops at the shorter file, so surplus lines in either are ignored.
        for ls, lp in tqdm(zip(fs, fp), total=count_lines(args.source_file)):
            eg = json.loads(ls)  # gold example
            ep = json.loads(lp)  # predicted example
            qg = Query.from_dict(eg['sql'], ordered=args.ordered)
            gold = engine.execute_query(eg['table_id'], qg, lower=True)
            # Start from the recorded error (or None); it is replaced below
            # only when the prediction carries no (truthy) error.
            pred = ep.get('error', None)
            qp = None
            if not ep.get('error', None):
                try:
                    qp = Query.from_dict(ep['query'], ordered=args.ordered)
                    pred = engine.execute_query(eg['table_id'], qp, lower=True)
                except Exception as e:
                    # A prediction that cannot be built or executed is scored
                    # using the exception's repr as its "result".
                    pred = repr(e)
            correct = pred == gold
            # qp stays None when the prediction had an error or Query.from_dict
            # raised, so the comparison is then None == qg.
            match = qp == qg
            grades.append(correct)
            exact_match.append(match)

        print(json.dumps({
Beispiel #3
0
def append_cond(sel,
                cols,
                agg,
                spans,
                conds,
                sqls,
                engine,
                gold,
                table_id,
                is_final=False):
    """Try every possible extra condition on top of ``conds`` and keep the hits.

    For each column index in ``range(cols)``, each operator index (positions
    of the module-level ``cond_ops``) and each question span, a copy of
    ``conds`` extended by ``[column, operator, value]`` is turned into a query
    and executed.  Queries whose result equals ``gold`` are appended to
    ``sqls``.

    Args:
        sel: value stored under ``'sel'`` in every generated query dict.
        cols: number of columns to try as the condition column.
        agg: value stored under ``'agg'`` in every generated query dict.
        spans: indexable collection; ``spans[i]`` for ``i`` in
            ``range(len(spans))`` yields token sequences that are joined with
            a single space to form the condition value.
        conds: conditions already fixed by the caller.  Never mutated; each
            candidate works on its own deep copy.
        sqls: list of accepted query dicts.  Mutated in place.
        engine: object providing ``execute_query(table_id, query, lower=True)``.
        gold: result that a generated query has to reproduce.
        table_id: table the queries are executed against.
        is_final: unused; retained so existing callers keep working.

    Returns:
        The same ``sqls`` list that was passed in.
    """
    # The text of a span and whether it parses as a number depend on neither
    # the column nor the operator, so resolve every span once up front instead
    # of once per (column, operator) pair.
    candidates = []
    for i in range(len(spans)):
        for span in spans[i]:
            text = ' '.join(span)
            try:
                candidates.append((float(text), True))
            except ValueError:
                candidates.append((text, False))

    for cond_col in range(cols):
        for cond_op, _ in enumerate(cond_ops):
            for value, is_number in candidates:
                # Non-numeric values are only combined with operator index 0
                # (presumably equality -- cond_ops is defined elsewhere).
                if not is_number and cond_op > 0:
                    continue
                # Copy only for candidates that are actually evaluated.
                extended = deepcopy(conds)
                extended.append([cond_col, cond_op, value])
                sql = {'sel': sel, 'conds': extended, 'agg': agg}
                query = Query.from_dict(sql, ordered=True)
                if engine.execute_query(table_id, query, lower=True) == gold:
                    sqls.append(sql)
    return sqls
Beispiel #4
0
def get_answer(sql, table_id, engine):
    """Execute the query dict ``sql`` on ``table_id`` and return the result.

    The dict is converted with ``Query.from_dict(..., ordered=False)`` and run
    through ``engine.execute_query`` with ``lower=True``.
    """
    query = Query.from_dict(sql, ordered=False)
    return engine.execute_query(table_id, query, lower=True)
Beispiel #5
0
def _gen_sqls_as_float(value):
    """Return ``float(value)``, or ``None`` when ``value`` is not numeric."""
    try:
        return float(value)
    except (TypeError, ValueError):
        return None


def _gen_sqls_wants_second_cond(agg, gold, gold_ans, res):
    """Tell whether a one-condition result justifies trying a second condition.

    ``res`` must be non-empty with a non-``None`` first element.  The
    aggregation labels below follow the original inline comments
    (0 = none, 1 = MAX, 2 = MIN, 3 = COUNT, 4 = SUM, 5 = AVG).
    """
    if agg == 0:
        # No aggregation: continue only if the result contains every gold item.
        return set(gold).issubset(set(res))
    if agg not in (1, 2, 3, 4, 5):
        return False
    if gold_ans is None:
        return False
    res_ans = _gen_sqls_as_float(res[0])
    if res_ans is None:
        return False
    if agg == 5:  # AVG: no ordering constraint is applied
        return True
    if agg == 2:  # MIN: continue while the result is not above gold
        return gold_ans >= res_ans
    # MAX, COUNT, SUM: continue while the result is not below gold
    return gold_ans <= res_ans


def generate_sqls(idx, query, engine):
    """Enumerate candidate queries whose execution result equals the gold answer.

    Every combination of select column, aggregation index (positions of the
    module-level ``agg_ops``), condition column, operator index (positions of
    ``cond_ops``) and question span is executed as a one-condition query.
    Matching queries are collected, and when the intermediate result looks
    compatible with the gold answer, ``append_cond`` is asked to try a second
    condition on top.

    Args:
        idx: identifier stored under ``'id'`` in the returned dict and printed
            when no query was found.
        query: dict providing ``'table_id'``, ``'sql'`` (the gold query dict)
            and ``'question'``.
        engine: object providing ``show_table(table_id)`` and
            ``execute_query(table_id, query, lower=True)``.

    Returns:
        ``{'id': idx, 'sqls': <JSON string of the accepted query dicts>}``.
    """
    table_id = query['table_id']
    table = engine.show_table(table_id)
    gold = get_answer(query['sql'], table_id, engine)
    cols = len(table[0])

    question = query['question'].strip(" ?").split()
    # Spans are generated for at most the first 10 length groups.
    span_len = min(len(question), 10)
    spans = generate_span(question, span_len)

    # The numeric form of the gold answer never changes, so parse it once.
    # An empty or non-numeric gold answer yields None instead of raising, which
    # simply disables the second-condition search for the aggregating cases.
    gold_ans = _gen_sqls_as_float(gold[0]) if gold else None

    # Span text and "is it a number" are independent of the surrounding loops.
    candidates = []
    for i in range(span_len):
        for span in spans[i]:
            text = ' '.join(span)
            number = _gen_sqls_as_float(text)
            if number is None:
                candidates.append((text, False))
            else:
                candidates.append((number, True))

    sqls = []
    for sel in tqdm(range(cols)):
        for agg, _ in enumerate(agg_ops):
            for con1_col in range(cols):  # condition 1
                for con1_op, _ in enumerate(cond_ops):
                    for con1, is_number in candidates:
                        # Non-numeric values are only tried with operator
                        # index 0.
                        if not is_number and con1_op > 0:
                            continue
                        conds = [[con1_col, con1_op, con1]]
                        sql = {'sel': sel, 'conds': conds, 'agg': agg}
                        qg = Query.from_dict(sql, ordered=True)
                        res = engine.execute_query(table_id, qg, lower=True)
                        if res == gold:
                            sqls.append(sql)
                        if not res or res[0] is None:
                            continue
                        if _gen_sqls_wants_second_cond(agg, gold, gold_ans, res):
                            sqls = append_cond(sel, cols, agg, spans, conds,
                                               sqls, engine, gold, table_id)

    if not sqls:
        # Report examples for which nothing reproduced the gold answer.
        print(idx)
    return {'id': idx, 'sqls': json.dumps(sqls)}
Beispiel #6
0
 # Re-rank step: for every training question keep only the candidate queries
 # whose execution result equals the stored gold answer, then dump the kept
 # queries to rr_p.jsonl, one JSON list per line.
 p_sqls_path = os.path.join(root, 'data/distant_data', 'train_distant.jsonl')
 queries = extract.read_queries(query_path)
 p_sqlss = extract.read_potential_sqls(p_sqls_path)
 answer_path = './syn.txt'
 g_answers = extract.read_gold_answers(answer_path)
 print(len(g_answers))
 engine = DBEngine_s('./data_and_model/train.db')
 rr_p_sqlss = []
 for i, p_sqls in enumerate(tqdm(p_sqlss)):
     if len(p_sqls) >= 3:
         # Three or more candidates: execute each one and keep the matches.
         rr_p_sqls = []
         for p_sql in p_sqls:
             qg = Query.from_dict(p_sql['query'], ordered=True)
             res = engine.execute_query(queries[i]['table_id'], qg, lower=True)
             if res == g_answers[i]:
                 rr_p_sqls.append(p_sql['query'])
     else:
         # Fewer than three candidates are kept without executing them.
         rr_p_sqls = [query['query'] for query in p_sqls]
     if not rr_p_sqls:
         # Report questions that end up with no candidate at all.
         print(f"{i}\n")
     rr_p_sqlss.append(rr_p_sqls)
 with open('rr_p.jsonl', 'w') as f:
     for sqls in rr_p_sqlss:
         f.write(json.dumps(sqls) + '\n')
 """ engine=DBEngine_s('./data_and_model/train.db')
 sql={"sel": 0, "conds": [[3, 0, "468-473 (6)"]], "agg": 3}
 qg=Query.from_dict(sql, ordered=True)
Beispiel #7
0
    # Annotate each dataset split and verify that every annotated example can
    # be reconstructed from its output sequence before it is written out.
    for split in ['train', 'dev', 'test']:
        split_prefix = os.path.join(args.din, split)
        fsplit = split_prefix + '.jsonl'
        ftable = split_prefix + '.tables.jsonl'
        fout = os.path.join(args.dout, split) + '.jsonl'

        print('annotating {}'.format(fsplit))
        with open(fsplit) as fs, open(ftable) as ft, open(fout, 'wt') as fo:
            print('loading tables')
            # Index the table definitions by their 'id' field.
            table_rows = (json.loads(row)
                          for row in tqdm(ft, total=count_lines(ftable)))
            tables = {table['id']: table for table in table_rows}

            print('loading examples')
            n_written = 0
            for row in tqdm(fs, total=count_lines(fsplit)):
                example = json.loads(row)
                annotated = annotate_example(example, tables[example['table_id']])
                if not is_valid_example(annotated):
                    raise Exception(str(annotated))

                # Round-trip check: the query rebuilt from the output sequence
                # must equal the gold query once both are lower-cased.
                expected = Query.from_tokenized_dict(annotated['query'])
                rebuilt = Query.from_sequence(annotated['seq_output'],
                                              annotated['table'],
                                              lowercase=True)
                if expected.lower() != rebuilt.lower():
                    raise Exception('Expected:\n{}\nGot:\n{}'.format(
                        expected, rebuilt))

                fo.write(json.dumps(annotated) + '\n')
                n_written += 1
            print('wrote {} examples'.format(n_written))
Beispiel #8
0
import extract
from tqdm import tqdm
import json
import sys, os
from wikisql.lib.dbengine import DBEngine
from wikisql.lib.query import Query
from sqlnet.dbengine import DBEngine as DBEngine_s
from copy import deepcopy
import eventlet
from IPython import embed

# Scratch check: execute one hand-written query against the training database
# and print the result.
# NOTE(review): `root` is a machine-specific absolute path, while the database
# paths below are relative to the current working directory -- confirm the
# script is always started from the project root.
root = '/mnt/sda/qhz/sqlova'
query_path = os.path.join(root, 'data_and_model', 'train_tok_origin.jsonl')
# table_path is not referenced by the statements visible here.
table_path = os.path.join(root, 'data_and_model', 'train.tables.jsonl')
queries = extract.read_queries(query_path)
# Two engine implementations are opened on the same database file
# (sqlnet's as `engine`, wikisql's as `engine1`); only `engine` is used below.
engine = DBEngine_s('./data_and_model/train.db')
engine1 = DBEngine('./data_and_model/train.db')
# Query dict with a select column, two conditions and an aggregation index.
sql = {"sel": 9, "conds": [[3, 1, 8793], [5, 2, 1030]], "agg": 2}
qg = Query.from_dict(sql, ordered=True)
# The query is run on the table belonging to the example at index 15841.
res = engine.execute_query(queries[15841]['table_id'], qg, lower=True)
print(res)