コード例 #1
0
ファイル: utils.py プロジェクト: yuconan/nl2sql_baseline
def epoch_exec_acc(model, batch_size, sql_data, table_data, db_path):
    """Return the execution accuracy of ``model`` on ``sql_data``.

    Each batch is encoded with ``to_batch_seq``, scored with
    ``model.forward`` and decoded with ``model.gen_query``.  A prediction
    counts as correct when executing it through ``DBEngine`` yields a result
    equal to that of the ground-truth query; a predicted query that raises
    while executing counts as wrong.

    Args:
        model: model exposing ``eval``, ``forward`` and ``gen_query``.
        batch_size: number of examples per batch (must be positive).
        sql_data: examples to evaluate.
        table_data: table metadata passed through to ``to_batch_seq``.
        db_path: path of the database opened with ``DBEngine``.

    Returns:
        Fraction of examples whose prediction executes to the ground-truth
        result, or ``0.0`` when ``sql_data`` is empty.
    """
    engine = DBEngine(db_path)
    model.eval()
    perm = list(range(len(sql_data)))
    tot_acc_num = 0.0
    # Step by batch_size so every batch is non-empty: counting batches as
    # len // batch_size + 1 produces an empty trailing batch (fed to the
    # model) whenever len(sql_data) is a multiple of batch_size.
    for st in tqdm(range(0, len(sql_data), batch_size)):
        ed = min(st + batch_size, len(perm))
        q_seq, gt_sel_num, col_seq, col_num, ans_seq, gt_cond_seq, raw_data = \
            to_batch_seq(sql_data, table_data, perm, st, ed, ret_vis_data=True)
        raw_q_seq = [x[0] for x in raw_data]
        raw_col_seq = [x[1] for x in raw_data]
        query_gt, table_ids = to_batch_query(sql_data, perm, st, ed)
        gt_sel_seq = [x[2] for x in ans_seq]
        score = model.forward(q_seq, col_seq, col_num, gt_sel=gt_sel_seq)
        pred_queries = model.gen_query(score, q_seq, col_seq, raw_q_seq,
                                       raw_col_seq)

        for sql_gt, sql_pred, tid in zip(query_gt, pred_queries, table_ids):
            # NOTE(review): epoch_acc in this file also passes
            # sql_gt['cond_conn_op'] to execute(); confirm which signature
            # this DBEngine expects.
            ret_gt = engine.execute(tid, sql_gt['sel'], sql_gt['agg'],
                                    sql_gt['conds'])
            try:
                ret_pred = engine.execute(tid, sql_pred['sel'],
                                          sql_pred['agg'], sql_pred['conds'])
            except Exception:
                # A prediction that cannot be executed is scored as a miss.
                ret_pred = None
            tot_acc_num += (ret_gt == ret_pred)
    if not sql_data:
        return 0.0
    return tot_acc_num / len(sql_data)
コード例 #2
0
def epoch_reinforce_train(model, optimizer, batch_size, sql_data, table_data,
                          db_path):
    """Run one reward-driven training epoch and return the mean reward.

    Examples are visited in a random order (``np.random.permutation``).
    For every batch the model scores and decodes queries with
    ``reinforce=True``, each predicted query is executed through
    ``DBEngine``, and the per-example rewards are handed to
    ``model.reinforce_backward`` before a single optimizer step.

    Reward per example:
        -2  the predicted query raised while executing, or returned ``None``;
        -1  it executed but its result differs from the ground truth's;
         1  its result equals the ground truth's.

    Args:
        model: model exposing ``train``, ``forward``, ``gen_query``,
            ``generate_gt_where_seq`` and ``reinforce_backward``.
        optimizer: optimizer exposing ``zero_grad`` and ``step``.
        batch_size: number of examples per batch.
        sql_data: training examples.
        table_data: table metadata passed through to ``to_batch_seq``.
        db_path: path of the database opened with ``DBEngine``.

    Returns:
        Cumulative reward divided by ``len(sql_data)``, or ``0.0`` when
        ``sql_data`` is empty.
    """
    engine = DBEngine(db_path)

    model.train()
    perm = np.random.permutation(len(sql_data))
    cum_reward = 0.0
    st = 0
    while st < len(sql_data):
        ed = min(st + batch_size, len(perm))

        q_seq, col_seq, col_num, ans_seq, query_seq, gt_cond_seq, raw_data = \
            to_batch_seq(sql_data, table_data, perm, st, ed, ret_vis_data=True)
        # NOTE(review): the return value is unused; the call is kept only so
        # behavior matches the original — drop it if it is side-effect free.
        model.generate_gt_where_seq(q_seq, col_seq, query_seq)
        raw_q_seq = [x[0] for x in raw_data]
        raw_col_seq = [x[1] for x in raw_data]
        # Fetched once per batch and reused for the reward computation below.
        query_gt, table_ids = to_batch_query(sql_data, perm, st, ed)
        gt_sel_seq = [x[1] for x in ans_seq]
        score = model.forward(q_seq,
                              col_seq,
                              col_num, (True, True, True),
                              reinforce=True,
                              gt_sel=gt_sel_seq)
        pred_queries = model.gen_query(score,
                                       q_seq,
                                       col_seq,
                                       raw_q_seq,
                                       raw_col_seq, (True, True, True),
                                       reinforce=True)

        rewards = []
        for sql_gt, sql_pred, tid in zip(query_gt, pred_queries, table_ids):
            ret_gt = engine.execute(tid, sql_gt['sel'], sql_gt['agg'],
                                    sql_gt['conds'])
            try:
                ret_pred = engine.execute(tid, sql_pred['sel'],
                                          sql_pred['agg'], sql_pred['conds'])
            except Exception:
                # Non-executable predictions get the harshest reward below.
                ret_pred = None

            if ret_pred is None:
                rewards.append(-2)
            elif ret_pred != ret_gt:
                rewards.append(-1)
            else:
                rewards.append(1)

        cum_reward += sum(rewards)
        optimizer.zero_grad()
        model.reinforce_backward(score, rewards)
        optimizer.step()

        st = ed

    if not sql_data:
        return 0.0
    return cum_reward / len(sql_data)
コード例 #3
0
def epoch_acc(model, batch_size, sql_data, table_data, db_path):
    """Evaluate ``model`` on ``sql_data`` and return three accuracies.

    ``one_acc`` and ``tot_acc`` are derived from the two error counts
    returned by ``model.check_acc`` (their exact definitions live there).
    ``ex_acc`` is execution accuracy: a prediction is correct when running
    it through ``DBEngine`` gives a result equal to the ground truth's.

    A batch whose forward pass, decoding or ``check_acc`` raises is counted
    as a "badcase", reported on stdout and skipped entirely (it contributes
    to none of the three counters).

    Args:
        model: model exposing ``eval``, ``forward``, ``gen_query`` and
            ``check_acc``.
        batch_size: number of examples per batch (must be positive).
        sql_data: examples to evaluate.
        table_data: table metadata passed through to ``to_batch_seq``.
        db_path: path of the database opened with ``DBEngine``.

    Returns:
        ``(one_acc, tot_acc, ex_acc)``, each divided by ``len(sql_data)``;
        ``(0.0, 0.0, 0.0)`` when ``sql_data`` is empty.
    """
    engine = DBEngine(db_path)
    model.eval()
    perm = list(range(len(sql_data)))
    badcase = 0
    one_acc_num, tot_acc_num, ex_acc_num = 0.0, 0.0, 0.0
    # Step by batch_size so every batch is non-empty: counting batches as
    # len // batch_size + 1 produces an empty trailing batch whenever
    # len(sql_data) is a multiple of batch_size.
    for st in tqdm(range(0, len(sql_data), batch_size)):
        ed = min(st + batch_size, len(perm))
        q_seq, gt_sel_num, col_seq, col_num, ans_seq, gt_cond_seq, raw_data = \
            to_batch_seq(sql_data, table_data, perm, st, ed, ret_vis_data=True)
        # q_seq: char-based sequence of each question
        # gt_sel_num: number of selected columns and aggregation functions
        # col_seq: char-based column names
        # col_num: number of headers in each table
        # ans_seq: (sel, number of conds, sel list in conds, op list in conds)
        # gt_cond_seq: ground truth of the conditions
        # raw_data: original question, headers, sql
        query_gt, table_ids = to_batch_query(sql_data, perm, st, ed)
        # query_gt: ground-truth sql (data['sql']) with sel, agg,
        # conds (sel, op, value) and cond_conn_op
        raw_q_seq = [x[0] for x in raw_data]  # original question text
        try:
            score = model.forward(q_seq, col_seq, col_num)
            pred_queries = model.gen_query(score, q_seq, col_seq, raw_q_seq)
            one_err, tot_err = model.check_acc(raw_data, pred_queries,
                                               query_gt)
        except Exception:
            # Best effort: skip the whole batch and keep evaluating.
            badcase += 1
            print('badcase', badcase)
            continue
        one_acc_num += (ed - st - one_err)
        tot_acc_num += (ed - st - tot_err)

        # Execution accuracy
        for sql_gt, sql_pred, tid in zip(query_gt, pred_queries, table_ids):
            ret_gt = engine.execute(tid, sql_gt['sel'], sql_gt['agg'],
                                    sql_gt['conds'], sql_gt['cond_conn_op'])
            try:
                ret_pred = engine.execute(tid, sql_pred['sel'],
                                          sql_pred['agg'], sql_pred['conds'],
                                          sql_pred['cond_conn_op'])
            except Exception:
                # A prediction that cannot be executed is scored as a miss.
                ret_pred = None
            ex_acc_num += (ret_gt == ret_pred)

    if not sql_data:
        return 0.0, 0.0, 0.0
    total = len(sql_data)
    return one_acc_num / total, tot_acc_num / total, ex_acc_num / total
コード例 #4
0
ファイル: utils.py プロジェクト: qianwenyuan/typesql_ch
def epoch_acc(model, batch_size, sql_data, table_data, db_path, db_content):
    """Evaluate a type-aware ``model`` on ``sql_data``; return three accuracies.

    ``one_acc`` and ``tot_acc`` are derived from the two error counts
    returned by ``model.check_acc`` (their exact definitions live there).
    ``ex_acc`` is execution accuracy: a prediction is correct when running
    it through ``DBEngine`` gives a result equal to the ground truth's.
    Errors from the forward pass, decoding or ``check_acc`` propagate.

    Args:
        model: model exposing ``eval``, ``forward``, ``gen_query`` and
            ``check_acc``; ``forward`` also receives the question and column
            type sequences produced by ``to_batch_seq``.
        batch_size: number of examples per batch (must be positive).
        sql_data: examples to evaluate.
        table_data: table metadata passed through to ``to_batch_seq``.
        db_path: path of the database opened with ``DBEngine``.
        db_content: passed through to ``to_batch_seq``.

    Returns:
        ``(one_acc, tot_acc, ex_acc)``, each divided by ``len(sql_data)``;
        ``(0.0, 0.0, 0.0)`` when ``sql_data`` is empty.
    """
    engine = DBEngine(db_path)
    model.eval()
    perm = list(range(len(sql_data)))
    one_acc_num, tot_acc_num, ex_acc_num = 0.0, 0.0, 0.0
    # Step by batch_size so every batch is non-empty: counting batches as
    # len // batch_size + 1 feeds an empty trailing batch to the model
    # whenever len(sql_data) is a multiple of batch_size.
    for st in tqdm(range(0, len(sql_data), batch_size)):
        ed = min(st + batch_size, len(perm))
        q_seq, gt_sel_num, col_seq, col_num, ans_seq, gt_cond_seq, q_type, col_type, \
            raw_data = to_batch_seq(sql_data, table_data, perm, st, ed, db_content,
                                    ret_vis_data=True)
        raw_q_seq = [x[0] for x in raw_data]
        query_gt, table_ids = to_batch_query(sql_data, perm, st, ed)
        score = model.forward(q_seq, col_seq, col_num, q_type, col_type)
        pred_queries = model.gen_query(score, q_seq, col_seq, raw_q_seq)
        one_err, tot_err = model.check_acc(raw_data, pred_queries, query_gt)
        one_acc_num += (ed - st - one_err)
        tot_acc_num += (ed - st - tot_err)

        # Execution accuracy
        for sql_gt, sql_pred, tid in zip(query_gt, pred_queries, table_ids):
            ret_gt = engine.execute(tid, sql_gt['sel'], sql_gt['agg'],
                                    sql_gt['conds'], sql_gt['cond_conn_op'])
            try:
                ret_pred = engine.execute(tid, sql_pred['sel'],
                                          sql_pred['agg'], sql_pred['conds'],
                                          sql_pred['cond_conn_op'])
            except Exception:
                # A prediction that cannot be executed is scored as a miss.
                ret_pred = None
            ex_acc_num += (ret_gt == ret_pred)

    if not sql_data:
        return 0.0, 0.0, 0.0
    total = len(sql_data)
    return one_acc_num / total, tot_acc_num / total, ex_acc_num / total
コード例 #5
0
ファイル: utils.py プロジェクト: shanelleroman/seq2sql
def epoch_exec_acc(model, batch_size, sql_data, table_data, db_path):
    """Return the execution accuracy of ``model`` on ``sql_data``.

    Batches are taken in order.  Each one is scored with all three
    ``(True, True, True)`` flags set and decoded with ``model.gen_query``.
    A prediction counts as correct when running it through ``DBEngine``
    gives a result equal to the ground truth's; a predicted query that
    raises while executing counts as wrong.

    Args:
        model: model exposing ``eval``, ``forward``, ``gen_query``,
            ``generate_gt_sel_seq`` and ``generate_gt_where_seq``.
        batch_size: number of examples per batch.
        sql_data: examples to evaluate.
        table_data: table metadata passed through to ``to_batch_seq``.
        db_path: path of the database opened with ``DBEngine``.

    Returns:
        Fraction of examples whose prediction executes to the ground-truth
        result, or ``0.0`` when ``sql_data`` is empty.
    """
    engine = DBEngine(db_path)
    # Function-call form prints the same text on Python 2 and Python 3.
    print('exec acc')

    model.eval()
    perm = list(range(len(sql_data)))
    tot_acc_num = 0.0
    st = 0
    while st < len(sql_data):
        ed = min(st + batch_size, len(perm))
        q_seq, col_seq, col_num, ans_seq, query_seq, gt_cond_seq, raw_data = \
            to_batch_seq(sql_data, table_data, perm, st, ed, ret_vis_data=True)
        raw_q_seq = [x[0] for x in raw_data]
        raw_col_seq = [x[1] for x in raw_data]
        # NOTE(review): both return values below are unused; the calls are
        # kept only so behavior matches the original — drop them if they are
        # side-effect free.
        model.generate_gt_sel_seq(q_seq, col_seq, query_seq)
        model.generate_gt_where_seq(q_seq, col_seq, query_seq)
        query_gt, table_ids = to_batch_query(sql_data, perm, st, ed)
        gt_sel_seq = [x[1] for x in ans_seq]
        score = model.forward(q_seq, col_seq, col_num,
                              (True, True, True), gt_sel=gt_sel_seq)
        pred_queries = model.gen_query(score, q_seq, col_seq,
                                       raw_q_seq, raw_col_seq,
                                       (True, True, True))

        for sql_gt, sql_pred, tid in zip(query_gt, pred_queries, table_ids):
            ret_gt = engine.execute(tid, sql_gt['sel'], sql_gt['agg'],
                                    sql_gt['conds'])
            try:
                ret_pred = engine.execute(tid, sql_pred['sel'],
                                          sql_pred['agg'], sql_pred['conds'])
            except Exception:
                # A prediction that cannot be executed is scored as a miss.
                ret_pred = None
            tot_acc_num += (ret_gt == ret_pred)

        st = ed

    if not sql_data:
        return 0.0
    return tot_acc_num / len(sql_data)
コード例 #6
0
def epoch_exec_acc(models,
                   batch_size,
                   sql_data,
                   table_data,
                   db_path,
                   db_content,
                   BERT=False,
                   POS=False,
                   ensemble='single'):
    """Return the execution accuracy of a model or an ensemble on ``sql_data``.

    Every model scores each batch, and all scores are decoded together by
    ``models[0].gen_query``.  A prediction counts as correct when running it
    through ``DBEngine`` gives a result equal to the ground truth's; a
    predicted query that raises while executing counts as wrong.

    Args:
        models: non-empty list of models.  Each one is put in eval mode.
            With more than one model, every model's score is collected into
            a list; with one model, its score is wrapped in a one-element
            list.
        batch_size: number of examples per batch.
        sql_data: examples to evaluate.
        table_data: table metadata passed through to ``to_batch_seq``.
        db_path: path of the database opened with ``DBEngine``.
        db_content: passed through to ``to_batch_seq``.
        BERT: passed through to ``to_batch_seq``.  In ``'mixed'`` mode it
            only affects the encoding used by ``models[1:]``.
        POS: when true, ``to_batch_seq`` also yields ``q_pos``, which is
            forwarded to every model's ``forward`` as ``q_pos=``.
        ensemble: ``'mixed'`` encodes each batch twice: with ``BERT=False``
            for ``models[0]`` and with ``BERT=BERT`` for the other models.
            Any other value feeds one ``BERT=BERT`` encoding to every model.

    Returns:
        Fraction of examples whose prediction executes to the ground-truth
        result, or ``0.0`` when ``sql_data`` is empty.
    """
    engine = DBEngine(db_path)

    for member in models:
        member.eval()
    # Raises IndexError on an empty list. models[0] decodes every prediction.
    lead_model = models[0]

    perm = list(range(len(sql_data)))
    tot_acc_num = 0.0
    st = 0

    while st < len(sql_data):
        ed = min(st + batch_size, len(perm))

        q_seq_bert = None
        if ensemble == 'mixed':
            # Extra encoding (BERT=BERT) consumed by models[1:] only.
            bert_batch = to_batch_seq(sql_data, table_data, perm, st, ed,
                                      db_content, ret_vis_data=True,
                                      BERT=BERT, POS=POS)
            q_seq_bert = bert_batch[1] if POS else bert_batch[0]
            batch = to_batch_seq(sql_data, table_data, perm, st, ed,
                                 db_content, ret_vis_data=True,
                                 BERT=False, POS=POS)
        else:
            batch = to_batch_seq(sql_data, table_data, perm, st, ed,
                                 db_content, ret_vis_data=True,
                                 BERT=BERT, POS=POS)

        if POS:
            (q_pos, q_seq, col_seq, col_num, ans_seq, query_seq, gt_cond_seq,
             q_type, col_type, raw_data) = batch
        else:
            q_pos = None
            (q_seq, col_seq, col_num, ans_seq, query_seq, gt_cond_seq,
             q_type, col_type, raw_data) = batch

        raw_q_seq = [x[0] for x in raw_data]
        raw_col_seq = [x[1] for x in raw_data]
        query_gt, table_ids = to_batch_query(sql_data, perm, st, ed)

        if len(models) > 1:
            # One score per ensemble member, so gen_query sees all of them
            # rather than only the last model's.
            scores = []
            for i, member in enumerate(models):
                use_bert = ensemble == 'mixed' and i > 0
                member_q = q_seq_bert if use_bert else q_seq
                scores.append(
                    _exec_acc_score(member, member_q, col_seq, col_num,
                                    q_type, col_type, query_seq, q_pos, POS))
        else:
            scores = [
                _exec_acc_score(lead_model, q_seq, col_seq, col_num, q_type,
                                col_type, query_seq, q_pos, POS)
            ]

        pred_queries = lead_model.gen_query(scores, q_seq, col_seq, raw_q_seq,
                                            raw_col_seq, (True, True, True))
        tot_acc_num += _exec_acc_count_matches(engine, query_gt, pred_queries,
                                               table_ids)

        st = ed

    if not sql_data:
        return 0.0
    return tot_acc_num / len(sql_data)


def _exec_acc_score(model, q_seq, col_seq, col_num, q_type, col_type,
                    query_seq, q_pos, POS):
    """Run ``model.forward`` on one encoded batch and return its score.

    ``q_pos`` is passed as the ``q_pos=`` keyword only when ``POS`` is true.
    """
    # NOTE(review): the return value is unused; the call is kept only so
    # behavior matches the original — drop it if it is side-effect free.
    model.generate_gt_where_seq(q_seq, col_seq, query_seq)
    # (True, True, True) presumably enables all three sub-predictors —
    # confirm against the model's forward().
    if POS:
        return model.forward(q_seq, col_seq, col_num, q_type, col_type,
                             (True, True, True), q_pos=q_pos)
    return model.forward(q_seq, col_seq, col_num, q_type, col_type,
                         (True, True, True))


def _exec_acc_count_matches(engine, query_gt, pred_queries, table_ids):
    """Count predictions whose execution result equals the ground truth's.

    A predicted query that raises while executing counts as a miss.
    """
    matches = 0
    for sql_gt, sql_pred, tid in zip(query_gt, pred_queries, table_ids):
        ret_gt = engine.execute(tid, sql_gt['sel'], sql_gt['agg'],
                                sql_gt['conds'])
        try:
            ret_pred = engine.execute(tid, sql_pred['sel'], sql_pred['agg'],
                                      sql_pred['conds'])
        except Exception:
            ret_pred = None
        matches += (ret_gt == ret_pred)
    return matches