Code example #1
File: utils.py  Project: yuconan/nl2sql_baseline
def epoch_exec_acc(model, batch_size, sql_data, table_data, db_path):
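    """Return execution accuracy: the fraction of examples whose predicted SQL
    executes to the same result as the gold SQL."""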
    engine = DBEngine(db_path)
    model.eval()
    perm = list(range(len(sql_data)))
    tot_acc_num = 0.0
    for st in tqdm(range(len(sql_data) // batch_size + 1)):
        ed = (st + 1) * batch_size if (st + 1) * batch_size < len(
            perm) else len(perm)
        st = st * batch_size
        q_seq, gt_sel_num, col_seq, col_num, ans_seq, gt_cond_seq, raw_data = \
            to_batch_seq(sql_data, table_data, perm, st, ed, ret_vis_data=True)
        raw_q_seq = [x[0] for x in raw_data]
        raw_col_seq = [x[1] for x in raw_data]
        query_gt, table_ids = to_batch_query(sql_data, perm, st, ed)
        gt_sel_seq = [x[2] for x in ans_seq]
        score = model.forward(q_seq, col_seq, col_num, gt_sel=gt_sel_seq)
        pred_queries = model.gen_query(score, q_seq, col_seq, raw_q_seq,
                                       raw_col_seq)

        for idx, (sql_gt, sql_pred,
                  tid) in enumerate(zip(query_gt, pred_queries, table_ids)):
            ret_gt = engine.execute(tid, sql_gt['sel'], sql_gt['agg'],
                                    sql_gt['conds'])
            try:
                ret_pred = engine.execute(tid, sql_pred['sel'],
                                          sql_pred['agg'], sql_pred['conds'])
            except Exception:
                ret_pred = None
            tot_acc_num += (ret_gt == ret_pred)
    return tot_acc_num / len(sql_data)
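
The execution-accuracy check shared by most of these examples reduces to the sketch below (the helper name exec_match is hypothetical; the DBEngine.execute signature matches the calls above):

def exec_match(engine, table_id, sql_gt, sql_pred):
    # the gold query is assumed to execute; the predicted query may be malformed
    ret_gt = engine.execute(table_id, sql_gt['sel'], sql_gt['agg'], sql_gt['conds'])
    try:
        ret_pred = engine.execute(table_id, sql_pred['sel'], sql_pred['agg'], sql_pred['conds'])
    except Exception:
        ret_pred = None  # a query that fails to execute never matches
    return ret_gt == ret_pred
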
Code example #2
File: evaluate.py  Project: lukovnikov/WikiSQL
def eval_one_qelos(db_file, pred_file, source_file):
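    """Score a prediction file against its source file: execution accuracy
    (same query result) and logical-form accuracy (same parsed query)."""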
    engine = DBEngine(db_file)
    exact_match = []
    with open(source_file) as fs, open(pred_file) as fp:
        grades = []
        for ls, lp in tqdm(zip(fs, fp), total=count_lines(source_file)):
            eg = json.loads(ls)
            ep = json.loads(lp)
            qg = Query.from_dict(eg['sql'])
            gold = engine.execute_query(eg['table_id'], qg, lower=True)
            qp = None
            try:
                qp = Query.from_dict(ep)
                pred = engine.execute_query(eg['table_id'], qp, lower=True)
            except Exception as e:
                pred = repr(e)
            correct = pred == gold
            match = qp == qg
            grades.append(correct)
            exact_match.append(match)
        result = {
            'ex_accuracy': sum(grades) / len(grades),
            'lf_accuracy': sum(exact_match) / len(exact_match),
        }
        return result
Code example #3
def hello():
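    # Web endpoint (presumably a Flask handler, given request.args): grades the
    # 'pred' queries against the 'gold' queries on the chosen dataset's database.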
    pred_json = json.loads(request.args.get('pred'))
    gold_json = json.loads(request.args.get('gold'))
    dataset = request.args.get('dataset')

    engine = DBEngine(os.path.join(DATABASE_PATH, "{}.db".format(dataset)))

    exact_match = []
    grades = []

    for lp, ls in tqdm(zip(pred_json, gold_json), total=len(gold_json)):
        eg = ls
        ep = lp
        qg = Query.from_dict(eg['sql'])
        gold = engine.execute_query(eg['table_id'], qg, lower=True)
        pred = ep['error']
        qp = None
        if not ep['error']:
            try:
                qp = Query.from_dict(ep['query'])
                pred = engine.execute_query(eg['table_id'], qp, lower=True)
            except Exception as e:
                pred = repr(e)
        correct = pred == gold
        match = qp == qg
        grades.append(correct)
        exact_match.append(match)

    ex_accuracy = sum(grades) / len(grades)
    lf_accuracy = sum(exact_match) / len(exact_match)
    return json.dumps({"ex_accuracy": ex_accuracy, "lf_accuracy": lf_accuracy})
Code example #4
def epoch_reinforce_train(model, optimizer, batch_size, sql_data, table_data,
                          db_path):
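    """One epoch of REINFORCE training: reward +1 when the predicted query
    executes to the gold result, -1 for a wrong result, -2 when it fails to execute."""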
    engine = DBEngine(db_path)

    model.train()
    perm = np.random.permutation(len(sql_data))
    cum_reward = 0.0
    st = 0
    while st < len(sql_data):
        ed = st + batch_size if st + batch_size < len(perm) else len(perm)

        q_seq, col_seq, col_num, ans_seq, query_seq, gt_cond_seq, raw_data = \
            to_batch_seq(sql_data, table_data, perm, st, ed, ret_vis_data=True)
        gt_where_seq = model.generate_gt_where_seq(q_seq, col_seq, query_seq)
        raw_q_seq = [x[0] for x in raw_data]
        raw_col_seq = [x[1] for x in raw_data]
        query_gt, table_ids = to_batch_query(sql_data, perm, st, ed)
        gt_sel_seq = [x[1] for x in ans_seq]
        score = model.forward(q_seq,
                              col_seq,
                              col_num, (True, True, True),
                              reinforce=True,
                              gt_sel=gt_sel_seq)
        pred_queries = model.gen_query(score,
                                       q_seq,
                                       col_seq,
                                       raw_q_seq,
                                       raw_col_seq, (True, True, True),
                                       reinforce=True)

        query_gt, table_ids = to_batch_query(sql_data, perm, st, ed)
        rewards = []
        for idx, (sql_gt, sql_pred,
                  tid) in enumerate(zip(query_gt, pred_queries, table_ids)):
            ret_gt = engine.execute(tid, sql_gt['sel'], sql_gt['agg'],
                                    sql_gt['conds'])
            try:
                ret_pred = engine.execute(tid, sql_pred['sel'],
                                          sql_pred['agg'], sql_pred['conds'])
            except Exception:
                ret_pred = None

            if ret_pred is None:
                rewards.append(-2)
            elif ret_pred != ret_gt:
                rewards.append(-1)
            else:
                rewards.append(1)

        cum_reward += (sum(rewards))
        optimizer.zero_grad()
        model.reinforce_backward(score, rewards)
        optimizer.step()

        st = ed

    return cum_reward / len(sql_data)
Code example #5
def main(anno_file_name, col_headers, raw_args=None, verbose=True):
    parser = argparse.ArgumentParser(description='evaluate.py')
    opts.translate_opts(parser)
    opt = parser.parse_args(raw_args)
    torch.cuda.set_device(opt.gpu)
    opt.db_file = os.path.join(opt.data_path, '{}.db'.format(opt.split))
    opt.pre_word_vecs = os.path.join(opt.data_path, 'embedding')
    dummy_parser = argparse.ArgumentParser(description='train.py')
    opts.model_opts(dummy_parser)
    opts.train_opts(dummy_parser)
    dummy_opt = dummy_parser.parse_known_args([])[0]
    opt.anno = anno_file_name

    engine = DBEngine(opt.db_file)

    js_list = table.IO.read_anno_json(opt.anno)

    prev_best = (None, None)
    sql_query = []
    for fn_model in glob.glob(opt.model_path):

        opt.model = fn_model

        translator = Translator(opt, dummy_opt.__dict__)
        data = table.IO.TableDataset(js_list, translator.fields, None, False)
        test_data = table.IO.OrderedIterator(dataset=data,
                                             device=opt.gpu,
                                             batch_size=opt.batch_size,
                                             train=False,
                                             sort=True,
                                             sort_within_batch=False)

        # inference
        r_list = []
        for batch in test_data:
            r_list += translator.translate(batch)
        r_list.sort(key=lambda x: x.idx)
        pred = r_list[-1]
        sql_pred = {
            'agg': pred.agg,
            'sel': pred.sel,
            'conds': pred.recover_cond_to_gloss(js_list[-1])
        }
        if verbose:
            print('\n sql_pred: ', sql_pred, '\n')
            print('\n col_headers: ', col_headers, '\n')
        sql_query = Query(sql_pred['sel'], sql_pred['agg'], sql_pred['conds'])
        try:
            ans_pred = engine.execute_query(js_list[-1]['table_id'],
                                            Query.from_dict(sql_pred),
                                            lower=True,
                                            verbose=verbose)
        except Exception as e:
            ans_pred = None
    return sql_query.get_complete_query(col_headers), ans_pred
Code example #6
def epoch_acc(model, batch_size, sql_data, table_data, db_path):
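    """Return three accuracies over sql_data: partial match and full logical-form
    match (as judged by model.check_acc) plus execution match against the database."""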
    engine = DBEngine(db_path)
    model.eval()
    perm = list(range(len(sql_data)))
    badcase = 0
    one_acc_num, tot_acc_num, ex_acc_num = 0.0, 0.0, 0.0
    for st in tqdm(range(len(sql_data) // batch_size + 1)):
        ed = (st + 1) * batch_size if (st + 1) * batch_size < len(
            perm) else len(perm)
        st = st * batch_size
        q_seq, gt_sel_num, col_seq, col_num, ans_seq, gt_cond_seq, raw_data = \
            to_batch_seq(sql_data, table_data, perm, st, ed, ret_vis_data=True)
        # q_seq: char-based sequence of question
        # gt_sel_num: number of selected columns and aggregation functions, new added field
        # col_seq: char-based column name
        # col_num: number of headers in one table
        # ans_seq: (sel, number of conds, sel list in conds, op list in conds)
        # gt_cond_seq: ground truth of conditions
        # raw_data: ori question, headers, sql
        query_gt, table_ids = to_batch_query(sql_data, perm, st, ed)
        # query_gt: ground truth of sql, data['sql'], containing sel, agg, conds:{sel, op, value}
        raw_q_seq = [x[0] for x in raw_data]  # original question
        try:
            score = model.forward(q_seq, col_seq, col_num)
            pred_queries = model.gen_query(score, q_seq, col_seq, raw_q_seq)
            # generate predicted format
            one_err, tot_err = model.check_acc(raw_data, pred_queries,
                                               query_gt)
        except Exception:
            badcase += 1
            print('badcase', badcase)
            continue
        one_acc_num += (ed - st - one_err)
        tot_acc_num += (ed - st - tot_err)

        # Execution Accuracy
        for sql_gt, sql_pred, tid in zip(query_gt, pred_queries, table_ids):
            ret_gt = engine.execute(tid, sql_gt['sel'], sql_gt['agg'],
                                    sql_gt['conds'], sql_gt['cond_conn_op'])
            try:
                ret_pred = engine.execute(tid, sql_pred['sel'],
                                          sql_pred['agg'], sql_pred['conds'],
                                          sql_pred['cond_conn_op'])
            except Exception:
                ret_pred = None
            ex_acc_num += (ret_gt == ret_pred)
    return one_acc_num / len(sql_data), tot_acc_num / len(
        sql_data), ex_acc_num / len(sql_data)
Code example #7
File: evaluate.py  Project: epaes90/nl2sql-1
def main():
    dummy_parser = argparse.ArgumentParser(description='train.py')
    opts.model_opts(dummy_parser)
    opts.train_opts(dummy_parser)
    dummy_opt = dummy_parser.parse_known_args([])[0]

    engine = DBEngine(opt.db_file)

    with codecs.open(opt.source_file, "r", "utf-8") as corpus_file:
        sql_list = [json.loads(line)['sql'] for line in corpus_file]

    js_list = table.IO.read_anno_json(opt.anno)

    prev_best = (None, None)
    for fn_model in glob.glob(opt.model_path):

        opt.model = fn_model

        translator = Translator(opt, dummy_opt.__dict__)
        data = table.IO.TableDataset(js_list, translator.fields, None, False)
        test_data = table.IO.OrderedIterator(dataset=data,
                                             device=opt.gpu,
                                             batch_size=opt.batch_size,
                                             train=False,
                                             sort=True,
                                             sort_within_batch=False)

        # inference
        if opt.beam_search:
            print('Using execution guidance for inference.')
        r_list = []

        for batch in test_data:
            r_list += translator.translate(batch, js_list, sql_list)

        r_list.sort(key=lambda x: x.idx)

        assert len(r_list) == len(
            js_list), 'len(r_list) != len(js_list): {} != {}'.format(
                len(r_list), len(js_list))

        # evaluation
        for pred, gold, sql_gold in zip(r_list, js_list, sql_list):
            pred.eval(gold, sql_gold, engine)
        print('Results:')
        for metric_name in ('all', 'exe'):
            c_correct = sum((x.correct[metric_name] for x in r_list))
            print('{}: {} / {} = {:.2%}'.format(metric_name, c_correct,
                                                len(r_list),
                                                c_correct / len(r_list)))
            if metric_name == 'all' and (prev_best[0] is None
                                         or c_correct > prev_best[1]):
                prev_best = (fn_model, c_correct)

    if (opt.split == 'dev') and (prev_best[0] is not None):
        with codecs.open(os.path.join(opt.data_path, 'dev_best.txt'),
                         'w',
                         encoding='utf-8') as f_out:
            f_out.write('{}\n'.format(prev_best[0]))
Code example #8
File: evaluate_ours.py  Project: jzl0166/NLIDB
def main(argv):
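    """Grade parsed predicted SQL against parsed gold SQL by execution (ex) and
    logical-form (lf) accuracy, printing both as JSON."""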
    del argv  # Unused.

    db_file = join(FLAGS.data_root, FLAGS.db_file)
    parsed_std_sql_file = join(FLAGS.data_root, FLAGS.parsed_std_sql_file)
    parsed_pred_sql_file = join(FLAGS.data_root, FLAGS.parsed_pred_sql_file)

    engine = DBEngine(db_file)
    exact_match = []

    with open(parsed_std_sql_file) as fs, open(parsed_pred_sql_file) as fp:
        grades = []
        for ls, lp in tqdm(zip(fs, fp),
                           total=count_lines(parsed_std_sql_file)):
            eg = json.loads(ls)
            ep = json.loads(lp)

            try:
                qg = Query.from_dict(eg['sql'])
                gold = engine.execute_query(eg['table_id'], qg, lower=True)
            except Exception as e:
                gold = repr(e)

            qp = None
            # the upstream script gated this on ep['error']; this variant always executes
            try:
                qp = Query.from_dict(ep['sql'])
                pred = engine.execute_query(eg['table_id'], qp, lower=True)
            except Exception as e:
                pred = repr(e)
            correct = pred == gold
            match = qp == qg
            if pred == gold and qp != qg:
                print(qp)
                print(qg)
            grades.append(correct)
            exact_match.append(match)
        print(
            json.dumps(
                {
                    'ex_accuracy': sum(grades) / len(grades),
                    'lf_accuracy': sum(exact_match) / len(exact_match),
                },
                indent=2))
Code example #9
File: utils.py  Project: qianwenyuan/typesql_ch
def epoch_acc(model, batch_size, sql_data, table_data, db_path, db_content):
    engine = DBEngine(db_path)
    model.eval()
    perm = list(range(len(sql_data)))
    badcase = 0
    one_acc_num, tot_acc_num, ex_acc_num = 0.0, 0.0, 0.0
    for st in tqdm(range(len(sql_data) // batch_size + 1)):
        ed = (st + 1) * batch_size if (st + 1) * batch_size < len(
            perm) else len(perm)
        st = st * batch_size
        q_seq, gt_sel_num, col_seq, col_num, ans_seq, gt_cond_seq, q_type, col_type, \
            raw_data = to_batch_seq(sql_data, table_data, perm, st, ed, db_content, ret_vis_data=True)
        print("q_seq:{}".format(len(q_seq)))
        raw_q_seq = [x[0] for x in raw_data]
        #raw_col_seq = [x[1] for x in raw_data]
        query_gt, table_ids = to_batch_query(sql_data, perm, st, ed)
        #gt_sel_seq = [x[1] for x in ans_seq]
        #try:
        score = model.forward(q_seq, col_seq, col_num, q_type, col_type)
        pred_queries = model.gen_query(score, q_seq, col_seq, raw_q_seq)
        one_err, tot_err = model.check_acc(raw_data, pred_queries, query_gt)
        #except:
        #    badcase += 1
        #    print 'badcase', badcase
        #    continue
        one_acc_num += (ed - st - one_err)
        tot_acc_num += (ed - st - tot_err)

        # Execution Accuracy
        for sql_gt, sql_pred, tid in zip(query_gt, pred_queries, table_ids):
            ret_gt = engine.execute(tid, sql_gt['sel'], sql_gt['agg'],
                                    sql_gt['conds'], sql_gt['cond_conn_op'])
            try:
                ret_pred = engine.execute(tid, sql_pred['sel'],
                                          sql_pred['agg'], sql_pred['conds'],
                                          sql_pred['cond_conn_op'])
            except Exception:
                ret_pred = None
            ex_acc_num += (ret_gt == ret_pred)
    return one_acc_num / len(sql_data), tot_acc_num / len(
        sql_data), ex_acc_num / len(sql_data)
Code example #10
File: utils.py  Project: shanelleroman/seq2sql
def epoch_exec_acc(model, batch_size, sql_data, table_data, db_path):
    engine = DBEngine(db_path)
    print('exec acc')

    model.eval()
    perm = list(range(len(sql_data)))
    tot_acc_num = 0.0
    acc_of_log = 0.0
    st = 0
    while st < len(sql_data):
        ed = st+batch_size if st+batch_size < len(perm) else len(perm)
        q_seq, col_seq, col_num, ans_seq, query_seq, gt_cond_seq, raw_data = \
            to_batch_seq(sql_data, table_data, perm, st, ed, ret_vis_data=True)
        raw_q_seq = [x[0] for x in raw_data]
        raw_col_seq = [x[1] for x in raw_data]
        model.generate_gt_sel_seq(q_seq, col_seq, query_seq)
        gt_where_seq = model.generate_gt_where_seq(q_seq, col_seq, query_seq)
        query_gt, table_ids = to_batch_query(sql_data, perm, st, ed)
        gt_sel_seq = [x[1] for x in ans_seq]
        # print 'gt_sel_seq', gt_sel_seq
        score = model.forward(q_seq, col_seq, col_num,
                (True, True, True), gt_sel=gt_sel_seq)
        pred_queries = model.gen_query(score, q_seq, col_seq,
                raw_q_seq, raw_col_seq, (True, True, True))

        for idx, (sql_gt, sql_pred, tid) in enumerate(
                zip(query_gt, pred_queries, table_ids)):
            ret_gt = engine.execute(tid,
                    sql_gt['sel'], sql_gt['agg'], sql_gt['conds'])
            try:
                ret_pred = engine.execute(tid,
                        sql_pred['sel'], sql_pred['agg'], sql_pred['conds'])
            except Exception:
                ret_pred = None
            tot_acc_num += (ret_gt == ret_pred)
        
        st = ed

    return tot_acc_num / len(sql_data)
Code example #11
File: utils.py  Project: yuconan/nl2sql_baseline
def predict_test(model, batch_size, sql_data, table_data, db_path,
                 output_path):
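    """Run the model over sql_data and write one predicted query per line,
    JSON-encoded, to output_path."""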
    engine = DBEngine(db_path)
    model.eval()
    perm = list(range(len(sql_data)))
    fw = open(output_path, 'w', encoding='utf-8')
    for st in tqdm(range(len(sql_data) // batch_size + 1)):
        ed = (st + 1) * batch_size if (st + 1) * batch_size < len(
            perm) else len(perm)
        st = st * batch_size
        q_seq, col_seq, col_num, raw_q_seq, table_ids = to_batch_seq_test(
            sql_data, table_data, perm, st, ed)
        score = model.forward(q_seq, col_seq, col_num)
        sql_preds = model.gen_query(score, q_seq, col_seq, raw_q_seq)
        for sql_pred in sql_preds:
            # json.dumps returns str in Python 3; write it directly with a newline
            fw.write(json.dumps(sql_pred, ensure_ascii=False) + '\n')
    fw.close()
Code example #12
import json
from tqdm import tqdm
import os
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from lib.query import Query
from lib.dbengine import DBEngine

if __name__ == '__main__':
    for split in ['train', 'dev', 'test']:
        print('checking {}'.format(split))
        engine = DBEngine('data/{}.db'.format(split))
        n_lines = 0
        with open('data/{}.jsonl'.format(split)) as f:
            for l in f:
                n_lines += 1
        with open('data/{}.jsonl'.format(split)) as f:
            for l in tqdm(f, total=n_lines):
                d = json.loads(l)
                query = Query.from_dict(d['sql'])

                # make sure it's executable
                result = engine.execute_query(d['table_id'], query)
                if result:
                    for a, b, c in d['sql']['conds']:
                        if str(c).lower() not in d['question'].lower():
                            raise Exception(
                                'Could not find condition {} in question {} for query {}'
                                .format(c, d['question'], query))
                else:
                    pass  # body truncated in the source listing
Code example #13
File: evaluate.py  Project: TooTouch/SPARTA
import pickle
import json
from argparse import ArgumentParser

from tqdm import tqdm
from lib.query import Query
from lib.dbengine import DBEngine
from lib.common import count_lines


if __name__ == '__main__':
    parser = ArgumentParser()
    parser.add_argument('source_file', help='source file for the prediction')
    parser.add_argument('db_file', help='source database for the prediction')
    parser.add_argument('pred_file', help='predictions by the model')
    parser.add_argument('--ordered', action='store_true', help='whether the exact match should consider the order of conditions')

    parser.add_argument('--topk', type=int, default=3, help='k of top_k')
    
    args = parser.parse_args()

    engine = DBEngine(args.db_file)
   
    temp = []
    
    with open(args.source_file) as fs, open(args.pred_file) as fp:
        grades = []
        exact_match = []
        
        for ls, lp in tqdm(zip(fs, fp), total=count_lines(args.source_file)):            
            eg = json.loads(ls)
            qg = Query.from_dict(eg['sql'], ordered=args.ordered)
            gold = engine.execute_query(eg['table_id'], qg, lower=True)
            
            pred_topk = []
            qp_topk = []
            
Code example #14
File: Translator.py  Project: epaes90/nl2sql-1
    def translate(self, batch, js_list=[], sql_list=[]):
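        # Decode one batch into ParseResult objects: encode question and table,
        # predict agg/sel/layout, then decode condition columns and value spans.
        # With opt.beam_search, candidate condition columns are re-ranked by
        # executing the partial query against the database (execution guidance).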
        q, q_len = batch.src
        tbl, tbl_len = batch.tbl
        ent, tbl_split, tbl_mask = batch.ent, batch.tbl_split, batch.tbl_mask
        # encoding
        q_enc, q_all, tbl_enc, q_ht, batch_size = self.model.enc(
            q, q_len, ent, tbl, tbl_len, tbl_split, tbl_mask)

        # (1) decoding
        agg_pred = cpu_vector(argmax(self.model.agg_classifier(q_ht).data))
        sel_pred = cpu_vector(
            argmax(self.model.sel_match(q_ht, tbl_enc, tbl_mask).data))
        lay_pred = argmax(self.model.lay_classifier(q_ht).data)
        engine = DBEngine(self.opt.db_file)
        indices = cpu_vector(batch.indices.data)
        # get layout op tokens
        op_batch_list = []
        op_idx_batch_list = []
        if self.opt.gold_layout:
            lay_pred = batch.lay.data
            cond_op, cond_op_len = batch.cond_op
            cond_op_len_list = cond_op_len.view(-1).tolist()
            for i, len_it in enumerate(cond_op_len_list):
                if len_it == 0:
                    op_idx_batch_list.append([])
                    op_batch_list.append([])
                else:
                    idx_list = cond_op.data[0:len_it,
                                            i].contiguous().view(-1).tolist()
                    op_idx_batch_list.append([
                        int(self.fields['cond_op'].vocab.itos[it])
                        for it in idx_list
                    ])
                    op_batch_list.append(idx_list)
        else:
            lay_batch_list = lay_pred.view(-1).tolist()
            for lay_it in lay_batch_list:
                tk_list = self.fields['lay'].vocab.itos[lay_it].split(' ')
                if (len(tk_list) == 0) or (tk_list[0] == ''):
                    op_idx_batch_list.append([])
                    op_batch_list.append([])
                else:
                    op_idx_batch_list.append(
                        [int(op_str) for op_str in tk_list])
                    op_batch_list.append([
                        self.fields['cond_op'].vocab.stoi[op_str]
                        for op_str in tk_list
                    ])
            # -> (num_cond, batch)
            cond_op = v_eval(
                add_pad(
                    op_batch_list,
                    self.fields['cond_op'].vocab.stoi[table.IO.PAD_WORD]).t())
            cond_op_len = torch.LongTensor([len(it) for it in op_batch_list])
        # emb_op -> (num_cond, batch, emb_size)
        if self.model.opt.layout_encode == 'rnn':
            emb_op = table.Models.encode_unsorted_batch(
                self.model.lay_encoder, cond_op, cond_op_len.clamp(min=1))
        else:
            emb_op = self.model.cond_embedding(cond_op)

        # (2) decoding
        self.model.cond_decoder.attn.applyMaskBySeqBatch(q)
        cond_state = self.model.cond_decoder.init_decoder_state(q_all, q_enc)
        cond_col_list, cond_span_l_list, cond_span_r_list = [], [], []
        t = 0
        need_back_track = [False] * batch_size
        for emb_op_t in emb_op:
            t += 1
            emb_op_t = emb_op_t.unsqueeze(0)
            cond_context, cond_state, _ = self.model.cond_decoder(
                emb_op_t, q_all, cond_state)
            # cond col -> (1, batch)
            cond_col_all = self.model.cond_col_match(cond_context, tbl_enc,
                                                     tbl_mask).data
            cond_col = argmax(cond_col_all)
            # add to this after beam search: cond_col_list.append(cpu_vector(cond_col))
            # emb_col
            batch_index = torch.LongTensor(
                range(batch_size)).unsqueeze_(0).cuda().expand(
                    cond_col.size(0), cond_col.size(1))
            emb_col = tbl_enc[cond_col, batch_index, :]
            cond_context, cond_state, _ = self.model.cond_decoder(
                emb_col, q_all, cond_state)
            # cond span
            q_mask = v_eval(
                q.data.eq(self.model.pad_word_index).transpose(0, 1))
            cond_span_l_batch_all = self.model.cond_span_l_match(
                cond_context, q_all, q_mask).data
            cond_span_l_batch = argmax(cond_span_l_batch_all)
            # add to this after beam search: cond_span_l_list.append(cpu_vector(cond_span_l))
            # emb_span_l: (1, batch, hidden_size)
            emb_span_l = q_all[cond_span_l_batch, batch_index, :]
            cond_span_r_batch = argmax(
                self.model.cond_span_r_match(cond_context, q_all, q_mask,
                                             emb_span_l).data)
            # add to this after beam search: cond_span_r_list.append(cpu_vector(cond_span_r))
            if self.opt.beam_search:
                # for now just go through the next col in cond
                k = min(self.opt.beam_size, cond_col_all.size()[2])
                top_col_idx = cond_col_all.topk(k)[1]
                for b in range(batch_size):
                    if t > len(op_idx_batch_list[b]) or need_back_track[b]:
                        continue
                    idx = indices[b]
                    agg = agg_pred[b]
                    sel = sel_pred[b]
                    cond = []
                    for i in range(t):
                        op = op_idx_batch_list[b][i]
                        if i < t - 1:
                            col = cond_col_list[i][b]
                            span_l = cond_span_l_list[i][b]
                            span_r = cond_span_r_list[i][b]
                        else:
                            col = cond_col[0, b]
                            span_l = cond_span_l_batch[0, b]
                            span_r = cond_span_r_batch[0, b]
                        cond.append((col, op, (span_l, span_r)))
                    pred = ParseResult(idx, agg, sel, cond)
                    pred.eval(js_list[idx], sql_list[idx], engine)
                    n_test = 0
                    while pred.exception_raised and n_test < top_col_idx.size(
                    )[2] - 1:
                        n_test += 1
                        if n_test > self.opt.beam_size:
                            need_back_track[b] = True
                            break
                        cond_col[0, b] = top_col_idx[0, b, n_test]
                        emb_col = tbl_enc[cond_col, batch_index, :]
                        cond_context, cond_state, _ = self.model.cond_decoder(
                            emb_col, q_all, cond_state)
                        # cond span
                        q_mask = v_eval(
                            q.data.eq(self.model.pad_word_index).transpose(
                                0, 1))
                        cond_span_l_batch_all = self.model.cond_span_l_match(
                            cond_context, q_all, q_mask).data
                        cond_span_l_batch = argmax(cond_span_l_batch_all)
                        # emb_span_l: (1, batch, hidden_size)
                        emb_span_l = q_all[cond_span_l_batch, batch_index, :]
                        cond_span_r_batch = argmax(
                            self.model.cond_span_r_match(
                                cond_context, q_all, q_mask, emb_span_l).data)
                        # run the new query over database
                        col = cond_col[0, b]
                        span_l = cond_span_l_batch[0, b]
                        span_r = cond_span_r_batch[0, b]
                        cond.pop()
                        cond.append((col, op, (span_l, span_r)))
                        pred = ParseResult(idx, agg, sel, cond)
                        pred.eval(js_list[idx], sql_list[idx], engine)
            cond_col_list.append(cpu_vector(cond_col))
            cond_span_l_list.append(cpu_vector(cond_span_l_batch))
            cond_span_r_list.append(cpu_vector(cond_span_r_batch))
            # emb_span_r: (1, batch, hidden_size)
            emb_span_r = q_all[cond_span_r_batch, batch_index, :]
            emb_span = self.model.span_merge(
                torch.cat([emb_span_l, emb_span_r], 2))
            cond_context, cond_state, _ = self.model.cond_decoder(
                emb_span, q_all, cond_state)

        # (3) recover output
        r_list = []
        for b in range(batch_size):
            idx = indices[b]
            agg = agg_pred[b]
            sel = sel_pred[b]
            cond = []
            for i in range(len(op_batch_list[b])):
                col = cond_col_list[i][b]
                op = op_idx_batch_list[b][i]
                span_l = cond_span_l_list[i][b]
                span_r = cond_span_r_list[i][b]
                cond.append((col, op, (span_l, span_r)))
            r_list.append(ParseResult(idx, agg, sel, cond))

        return r_list
Code example #15
File: utils.py  Project: wronnyhuang/SQLNet_inference
def infer_exec(model, batch_size, sql_data, table_data, db_path):
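    """Interactive demo: for each batch, print a sample table and the model's
    predicted SQL, then prompt the user for their own question about the table."""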
    engine = DBEngine(db_path)

    model.eval()
    perm = list(range(len(sql_data)))
    tot_acc_num = 0.0
    acc_of_log = 0.0
    st = 0
    while st < len(sql_data):
        ed = st + batch_size if st + batch_size < len(perm) else len(perm)
        q_seq, col_seq, col_num, ans_seq, query_seq, gt_cond_seq, raw_data = \
            to_batch_seq(sql_data, table_data, perm, st, ed, ret_vis_data=True)
        raw_q_seq = [x[0] for x in raw_data]
        raw_col_seq = [x[1] for x in raw_data]
        query_gt, table_ids = to_batch_query(sql_data, perm, st, ed)
        gt_sel_seq = [x[1] for x in ans_seq]

        score = model.forward(q_seq,
                              col_seq,
                              col_num, (True, True, True),
                              gt_sel=gt_sel_seq)
        pred_queries = model.gen_query(score, q_seq, col_seq, raw_q_seq,
                                       raw_col_seq, (True, True, True))
        for idx, (sql_gt, sql_pred,
                  tid) in enumerate(zip(query_gt, pred_queries, table_ids)):
            print_table(table_data, tid)
            raw_query = engine.get_query_raw(tid, sql_pred['sel'],
                                             sql_pred['agg'],
                                             sql_pred['conds'], table_data)
            # ret_gt = engine.execute(tid, sql_gt['sel'], sql_gt['agg'], sql_gt['conds'])
            # try:
            #     ret_pred = engine.execute(tid, sql_pred['sel'], sql_pred['agg'], sql_pred['conds'])
            # except:
            #     ret_pred = None
            # isCorrect = (ret_gt == ret_pred)
            print "==> Here's an example"
            print 'English: ', raw_q_seq[0]
            print 'SQL: ', raw_query
            # print 'Execution: ', str(ret_pred[0]).encode('utf-8') if ret_pred else 'null'
            # print 'Correct: ', isCorrect
            print '\n'
            break

        # INFERENCE TIME!
        print('==> Your turn, type a question about this table')
        # assuming Python 3, input_tokenize_wrapper is taken to return str, so no decoding is needed
        raw_q_seq, q_seq = input_tokenize_wrapper()
        raw_q_seq, q_seq = [raw_q_seq, raw_q_seq], [q_seq, q_seq]
        score = model.forward(q_seq,
                              col_seq,
                              col_num, (True, True, True),
                              gt_sel=gt_sel_seq)
        pred_queries = model.gen_query(score, q_seq, col_seq, raw_q_seq,
                                       raw_col_seq, (True, True, True))
        for idx, (sql_pred, tid) in enumerate(zip(pred_queries, table_ids)):
            raw_query = engine.get_query_raw(tid,
                                             sql_pred['sel'],
                                             sql_pred['agg'],
                                             sql_pred['conds'],
                                             table_data,
                                             lower=True)
            # try:
            #     ret_pred = engine.execute(tid, sql_pred['sel'], sql_pred['agg'], sql_pred['conds'])
            # except:
            #     ret_pred = '<Failed>'
            print('ENGLISH: ', raw_q_seq[0])
            print('SQL: ', raw_query)
            # print('Execution: ', str(ret_pred[0]) if ret_pred else 'null')
            print('\n\n')
            break
        st += 1
        sleep(5)
Code example #16
def main():
    dummy_parser = argparse.ArgumentParser(description='train.py')
    opts.model_opts(dummy_parser)
    opts.train_opts(dummy_parser)
    dummy_opt = dummy_parser.parse_known_args([])[0]

    engine = DBEngine(opt.db_file)

    with codecs.open(opt.source_file, "r", "utf-8") as corpus_file:
        sql_list = [json.loads(line)['sql'] for line in corpus_file]

    js_list = table.IO.read_anno_json(opt.anno)

    prev_best = (None, None)
    print(opt.split, opt.model_path)

    num_models = 0

    f_out = open('Two-stream-' + opt.unseen_table + '-out-case', 'w')

    for fn_model in glob.glob(opt.model_path):
        num_models += 1
        sys.stdout.flush()
        print(fn_model)
        print(opt.anno)
        opt.model = fn_model

        translator = table.Translator(opt, dummy_opt.__dict__)
        data = table.IO.TableDataset(js_list, translator.fields, None, False)
        #torch.save(data, open( 'data.pt', 'wb'))
        test_data = table.IO.OrderedIterator(dataset=data,
                                             device=opt.gpu,
                                             batch_size=opt.batch_size,
                                             train=False,
                                             sort=True,
                                             sort_within_batch=False)

        # inference
        r_list = []
        for batch in test_data:
            r_list += translator.translate(batch)
        r_list.sort(key=lambda x: x.idx)
        assert len(r_list) == len(
            js_list), 'len(r_list) != len(js_list): {} != {}'.format(
                len(r_list), len(js_list))
        # evaluation
        error_cases = []
        for pred, gold, sql_gold in zip(r_list, js_list, sql_list):
            error_cases.append(pred.eval(opt.split, gold, sql_gold, engine))
#            error_cases.append(pred.eval(opt.split, gold, sql_gold))
        print('Results:')
        for metric_name in ('all', 'exe', 'agg', 'sel', 'where', 'col', 'span',
                            'lay', 'BIO', 'BIO_col'):
            c_correct = sum((x.correct[metric_name] for x in r_list))
            print('{}: {} / {} = {:.2%}'.format(metric_name, c_correct,
                                                len(r_list),
                                                c_correct / len(r_list)))
            if metric_name == 'all':
                all_acc = c_correct
            if metric_name == 'exe':
                exe_acc = c_correct
        if prev_best[0] is None or all_acc + exe_acc > prev_best[1] + prev_best[2]:
            prev_best = (fn_model, all_acc, exe_acc)

#        random.shuffle(error_cases)
        for error_case in error_cases:
            if len(error_case) == 0:
                continue
            json.dump(error_case, f_out)
            f_out.write('\n')


#            print('table_id:\t', error_case['table_id'])
#            print('question_id:\t',error_case['question_id'])
#            print('question:\t', error_case['question'])
#            print('table_head:\t', error_case['table_head'])
#            print('table_content:\t', error_case['table_content'])
#            print()

#            print(error_case['BIO'])
#            print(error_case['BIO_col'])
#            print()

#            print('gold:','agg:',error_case['gold']['agg'],'sel:',error_case['predict']['sel'])
#            for i in range(len(error_case['gold']['conds'])):
#                print(error_case['gold']['conds'][i])

#           print('predict:','agg:',error_case['predict']['agg'],'sel:',error_case['predict']['sel'])
#           for i in range(len(error_case['predict']['conds'])):
#               print(error_case['predict']['conds'][i])
#           print('\n\n')

    print(prev_best)
    if (opt.split == 'dev') and (prev_best[0] is not None) and num_models != 1:
        if opt.unseen_table == 'full':
            with codecs.open(os.path.join(opt.save_path, 'dev_best.txt'),
                             'w',
                             encoding='utf-8') as f_out:
                f_out.write('{}\n'.format(prev_best[0]))
        else:
            with codecs.open(os.path.join(
                    opt.save_path, 'dev_best_' + opt.unseen_table + '.txt'),
                             'w',
                             encoding='utf-8') as f_out:
                f_out.write('{}\n'.format(prev_best[0]))
Code example #17
from lib.dbengine import DBEngine
from lib.table import Table
#from lib.query import Query
import json
import jsonlines

# create the database engine
DB = DBEngine("data/train.db")

# load the tables
tables = [json.loads(x) for x in open("data/train.tables.jsonl")]
print(len(tables))

# create a Table object from the first entry
table = Table(tables[0]['id'], tables[0]['header'], tables[0]['types'], tables[0]['rows'])
print(table)

# sample a few queries
print(table.generate_queries(DB.conn, n=5, max_tries=5, lower=True))
queries = table.generate_queries(DB.conn, n=5, max_tries=5, lower=True)
print(str(queries[0][0]))


def augment_data():
    new_train_path = "data/augmented_train.jsonl"
    error_cnt = 0
    with open(new_train_path, 'w') as f:
        for i in range(len(tables)):
            table_new = Table(tables[i]['id'], tables[i]['header'], tables[i]['types'], tables[i]['rows'])
            try:
                queries = table_new.generate_queries(DB.conn, n=10, max_tries=100, lower=True)
                for query in queries:
                    # print(query[1])
                    info = {}
Code example #18
def epoch_exec_acc(models,
                   batch_size,
                   sql_data,
                   table_data,
                   db_path,
                   db_content,
                   BERT=False,
                   POS=False,
                   ensemble='single'):
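    """Execution accuracy for one or more models. ensemble='mixed' pairs a
    BERT-tokenized batch with a standard batch; otherwise one batch format is used."""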
    engine = DBEngine(db_path)

    if len(models) > 1:
        models_eval = list()
        for nn in models:
            nn.eval()
            models_eval.append(nn)
    else:
        model = models[0]
        model.eval()

    perm = list(range(len(sql_data)))
    tot_acc_num = 0.0
    acc_of_log = 0.0
    st = 0

    while st < len(sql_data):
        ed = st + batch_size if st + batch_size < len(perm) else len(perm)

        if ensemble == 'mixed':
            if POS:
                q_pos, q_seq_bert, col_seq, col_num, ans_seq, query_seq, gt_cond_seq, q_type, col_type, \
                    raw_data = to_batch_seq(sql_data, table_data, perm, st, ed, db_content, ret_vis_data=True, BERT=BERT, POS=POS)
                q_pos, q_seq, col_seq, col_num, ans_seq, query_seq, gt_cond_seq, q_type, col_type, \
                    raw_data = to_batch_seq(sql_data, table_data, perm, st, ed, db_content, ret_vis_data=True, BERT=False, POS=POS)
            else:
                q_seq_bert, col_seq, col_num, ans_seq, query_seq, gt_cond_seq, q_type, col_type, \
                    raw_data = to_batch_seq(sql_data, table_data, perm, st, ed, db_content, ret_vis_data=True, BERT=BERT, POS=POS)
                q_seq, col_seq, col_num, ans_seq, query_seq, gt_cond_seq, q_type, col_type, \
                    raw_data = to_batch_seq(sql_data, table_data, perm, st, ed, db_content, ret_vis_data=True, BERT=False, POS=POS)
        else:
            if POS:
                q_pos, q_seq, col_seq, col_num, ans_seq, query_seq, gt_cond_seq, q_type, col_type, \
                    raw_data = to_batch_seq(sql_data, table_data, perm, st, ed, db_content, ret_vis_data=True, BERT=BERT, POS=POS)
            else:
                q_seq, col_seq, col_num, ans_seq, query_seq, gt_cond_seq, q_type, col_type, \
                    raw_data = to_batch_seq(sql_data, table_data, perm, st, ed, db_content, ret_vis_data=True, BERT=BERT, POS=POS)

        raw_q_seq = [x[0] for x in raw_data]
        raw_col_seq = [x[1] for x in raw_data]
        query_gt, table_ids = to_batch_query(sql_data, perm, st, ed)
        gt_sel_seq = [x[1] for x in ans_seq]
        gt_agg_seq = [x[0] for x in ans_seq]

        if len(models) > 1:
            scores = list()
            for i, model in enumerate(models_eval):

                if ensemble == 'mixed':
                    if i == 0:
                        gt_where_seq = model.generate_gt_where_seq(
                            q_seq, col_seq, query_seq)
                        if POS:
                            score = model.forward(q_seq,
                                                  col_seq,
                                                  col_num,
                                                  q_type,
                                                  col_type, (True, True, True),
                                                  q_pos=q_pos)
                        else:
                            score = model.forward(q_seq, col_seq, col_num,
                                                  q_type, col_type,
                                                  (True, True, True))
                    else:
                        gt_where_seq = model.generate_gt_where_seq(
                            q_seq_bert, col_seq, query_seq)
                        if POS:
                            score = model.forward(q_seq_bert,
                                                  col_seq,
                                                  col_num,
                                                  q_type,
                                                  col_type, (True, True, True),
                                                  q_pos=q_pos)
                        else:
                            score = model.forward(q_seq_bert, col_seq, col_num,
                                                  q_type, col_type,
                                                  (True, True, True))
                else:
                    gt_where_seq = model.generate_gt_where_seq(
                        q_seq, col_seq, query_seq)
                    if POS:
                        score = model.forward(q_seq,
                                              col_seq,
                                              col_num,
                                              q_type,
                                              col_type, (True, True, True),
                                              q_pos=q_pos)
                    else:
                        score = model.forward(q_seq, col_seq, col_num, q_type,
                                              col_type, (True, True, True))
                scores.append(score)  # collect each model's score inside the loop
            model = models_eval[0]
            pred_queries = model.gen_query(scores, q_seq, col_seq, raw_q_seq,
                                           raw_col_seq, (True, True, True))
            for idx, (sql_gt, sql_pred,
                      tid) in enumerate(zip(query_gt, pred_queries,
                                            table_ids)):
                ret_gt = engine.execute(tid, sql_gt['sel'], sql_gt['agg'],
                                        sql_gt['conds'])
                try:
                    ret_pred = engine.execute(tid, sql_pred['sel'],
                                              sql_pred['agg'],
                                              sql_pred['conds'])
                except Exception:
                    ret_pred = None
                tot_acc_num += (ret_gt == ret_pred)
        else:
            gt_where_seq = model.generate_gt_where_seq(q_seq, col_seq,
                                                       query_seq)
            if POS:
                score = [
                    model.forward(q_seq,
                                  col_seq,
                                  col_num,
                                  q_type,
                                  col_type, (True, True, True),
                                  q_pos=q_pos)
                ]
            else:
                score = [
                    model.forward(q_seq, col_seq, col_num, q_type, col_type,
                                  (True, True, True))
                ]
            pred_queries = model.gen_query(score, q_seq, col_seq, raw_q_seq,
                                           raw_col_seq, (True, True, True))
            for idx, (sql_gt, sql_pred,
                      tid) in enumerate(zip(query_gt, pred_queries,
                                            table_ids)):
                ret_gt = engine.execute(tid, sql_gt['sel'], sql_gt['agg'],
                                        sql_gt['conds'])
                try:
                    ret_pred = engine.execute(tid, sql_pred['sel'],
                                              sql_pred['agg'],
                                              sql_pred['conds'])
                except Exception:
                    ret_pred = None
                tot_acc_num += (ret_gt == ret_pred)

        st = ed

    return tot_acc_num / len(sql_data)
Code example #19
import json
import argparse
from tqdm import tqdm
from lib.query import Query
from lib.dbengine import DBEngine
from lib.common import count_lines

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--source_file")
    parser.add_argument("--db_file")
    parser.add_argument("--pred_file")
    parser.add_argument("--ordered", action='store_true')
    args = parser.parse_args()

    engine = DBEngine(args.db_file)
    ex_acc_list = []
    lf_acc_list = []
    with open(args.source_file) as sf, open(args.pred_file) as pf:
        for source_line, pred_line in tqdm(zip(sf, pf), total=count_lines(args.source_file)):
            # load the gold and predicted examples for this line
            gold_example = json.loads(source_line)
            pred_example = json.loads(pred_line)

            # logical form (lf) and execution result (ex) for the gold example
            lf_gold_query = Query.from_dict(gold_example['sql'], ordered=args.ordered)
            ex_gold = engine.execute_query(gold_example['table_id'], lf_gold_query, lower=True)

            # lf and ex for the predicted example, unless it was marked as an error
            lf_pred_query = None
            ex_pred = pred_example.get('error', None)