Exemple #1
0
def epoch_acc(batch_size, table_data, sql_data, db_path, cs, out=False):
    """Score the predictions stored in ``../submit/results_val.jsonl``.

    Every line of the results file is parsed as JSON; its ``query`` and
    ``nlu`` fields are mirrored to ``sql`` and ``question``.  A record that
    contains an ``error`` key is replaced by the record that precedes it.
    The predictions are then compared with ``sql_data`` batch by batch via
    ``genByout`` / ``check_acc`` and by executing both the ground-truth and
    the predicted query on the database at ``db_path``.

    Args:
        batch_size: Number of examples handled per batch.
        table_data: Table metadata forwarded to ``to_batch_seq``.
        sql_data: Ground-truth examples; its length is the denominator of
            all three returned ratios.
        db_path: Path handed to ``DBEngine``.
        cs: Forwarded unchanged to ``genByout``.
        out: Unused; kept for interface compatibility.

    Returns:
        A tuple ``(one_acc, tot_acc, ex_acc)``: the accumulated
        ``check_acc`` one-error and total-error complements and the number of
        execution matches, each divided by ``len(sql_data)``.

    Raises:
        IndexError: If the very first record of the results file carries an
            ``error`` key (there is no previous record to fall back on).
    """
    predictions = []
    with open('../submit/results_val.jsonl') as inf:
        for line in inf:
            record = json.loads(line.strip())
            record['sql'] = record['query']
            record['question'] = record['nlu']
            if 'error' in record:
                # Fall back to the previous prediction so that positions stay
                # aligned with sql_data.
                predictions.append(predictions[-1])
            else:
                predictions.append(record)
    print(len(sql_data))
    print(len(predictions))

    engine = DBEngine(db_path)
    perm = list(range(len(predictions)))
    one_acc_num, tot_acc_num, ex_acc_num = 0.0, 0.0, 0.0
    for batch_idx in tqdm(range(len(sql_data) // batch_size + 1)):
        st = batch_idx * batch_size
        ed = min(st + batch_size, len(perm))
        if st >= ed:
            # The range above always yields one extra step; it is empty when
            # the data length is a multiple of batch_size, and st can even
            # exceed ed when fewer predictions than examples were read.
            # Skipping avoids calling the helpers on an empty slice and
            # adding a negative amount to the counters.
            continue

        q_seq, gt_sel_num, col_seq, col_num, ans_seq, gt_cond_seq, gt_type, raw_data = \
            to_batch_seq(sql_data, table_data, perm, st, ed, ret_vis_data=True)

        # query_gt: ground truth of sql, data['sql'], containing sel, agg,
        # conds:{sel, op, value}
        query_gt, table_ids = to_batch_query(sql_data, perm, st, ed)
        raw_q_seq = [x[0] for x in raw_data]  # original question

        pred_queriesc, allfsc = genByout(predictions[st:ed], table_ids,
                                         raw_q_seq, gt_type, 'val', cs)

        one_err, tot_err = check_acc(raw_data, pred_queriesc, query_gt, allfsc)
        one_acc_num += (ed - st - one_err)
        tot_acc_num += (ed - st - tot_err)

        # Execution accuracy
        for sql_gt, sql_pred, tid in zip(query_gt, pred_queriesc, table_ids):
            ret_gt = engine.execute(tid, sql_gt['sel'], sql_gt['agg'],
                                    sql_gt['conds'], sql_gt['cond_conn_op'])
            try:
                ret_pred = engine.execute(tid, sql_pred['sel'],
                                          sql_pred['agg'], sql_pred['conds'],
                                          sql_pred['cond_conn_op'])
            except Exception:
                # A predicted query that cannot be executed counts as a miss;
                # KeyboardInterrupt / SystemExit are no longer swallowed.
                ret_pred = None
            ex_acc_num += (ret_gt == ret_pred)

    total = len(sql_data)
    return one_acc_num / total, tot_acc_num / total, ex_acc_num / total
Exemple #2
0
def train(train_loader,
          train_table,
          model,
          model_bert,
          opt,
          bert_config,
          tokenizer,
          max_seq_length,
          num_target_layers,
          accumulate_gradients=1,
          check_grad=True,
          st_pos=0,
          opt_bert=None,
          path_db=None,
          dset_name='train'):
    """Run one training pass mixing a supervised loss with reward-weighted terms.

    For every batch the supervised loss of ``Loss_selectwhere_startend_v2``
    is computed.  In addition, complete queries (select column, aggregation,
    number of conditions, and four column/operator/value-span triples) are
    drawn uniformly at random and executed on the database.  Sampling is
    repeated until at least one example of the batch gets reward 1, i.e. one
    of its ``"answer"`` cells appears in the executed result.  Six
    reward-weighted negative log-likelihood terms, one per sub-task, are then
    added to the loss.

    Args:
        train_loader: Iterable of batches; each batch is indexable and its
            examples expose an ``"answer"`` entry.
        train_table: Table data forwarded to ``get_fields``.
        model: Model producing the sub-task scores; put in train mode.
        model_bert: Encoder passed to ``get_wemb_bert``; put in train mode.
        opt: Optimizer stepped for ``model``.
        bert_config: Forwarded to ``get_wemb_bert``.
        tokenizer: Forwarded to ``get_wemb_bert``.
        max_seq_length: Forwarded to ``get_wemb_bert``.
        num_target_layers: Used for both ``num_out_layers_n`` and
            ``num_out_layers_h`` of ``get_wemb_bert``.
        accumulate_gradients: Number of consecutive batches whose gradients
            are accumulated before the optimizers step.
        check_grad: Unused; kept for interface compatibility.
        st_pos: Batches are skipped while the running example count is below
            this value.
        opt_bert: Optional optimizer stepped together with ``opt``.
        path_db: Directory containing ``{dset_name}.db``.
        dset_name: Stem of the database file name.

    Returns:
        A pair ``(acc, aux_out)``.  ``acc`` is the list ``[ave_loss, acc_sc,
        acc_sa, acc_wn, acc_wc, acc_wo, acc_wvi, acc_wv, acc_lx, acc_x]``,
        each accumulated value divided by the number of examples seen
        (skipped batches included).  ``aux_out`` is always ``1``.
    """
    model.train()
    model_bert.train()

    ave_loss = 0
    count = 0  # count the # of examples (skipped batches included)
    count_sc = 0  # count the # of correct predictions of select column
    count_sa = 0  # of selected aggregation
    count_wn = 0  # of where number
    count_wc = 0  # of where column
    count_wo = 0  # of where operator
    count_wv = 0  # of where-value
    count_wvi = 0  # of where-value index (on question tokens)
    count_logic_form_acc = 0  # of logical form acc
    count_execute_acc = 0  # of execution acc

    # Engine for SQL querying.
    engine = DBEngine(os.path.join(path_db, f"{dset_name}.db"))

    max_conds = 4  # condition slots sampled per example

    def equal_in(cell_, pred_answer_column):
        """Return True when some element of ``pred_answer_column`` equals ``cell_``."""
        return any(cell == cell_ for cell in pred_answer_column)

    def one_hot(index, width):
        """Return a list of ``width`` zeros with a one at ``index``."""
        row = [0] * width
        row[index] = 1
        return row

    def to_weights(rows):
        """Turn nested one-hot lists into a float tensor on ``device``."""
        return torch.tensor(rows, dtype=torch.float).to(device)

    def reward_weighted_nll(scores, action_weights, rewards):
        """Mean of ``-log p(sampled action) * reward`` over all entries.

        ``scores`` are softmax-normalised over the last axis, the probability
        of the sampled action is selected with ``action_weights`` and the
        per-example ``rewards`` are broadcast over any remaining axes.
        """
        chosen_prob = torch.sum(
            torch.softmax(scores, dim=-1) * action_weights, dim=-1)
        while rewards.dim() < chosen_prob.dim():
            rewards = rewards.unsqueeze(-1)
        return torch.mean(-torch.log(chosen_prob) * rewards)

    for batch_index, batch_data in enumerate(train_loader):
        count += len(batch_data)

        if count < st_pos:
            continue
        # Get fields
        question, question_token, sql, sql_text, sql_t, table, header_token, header \
            = get_fields(batch_data, train_table, no_hs_t=True, no_sql_t=True)

        gt_select_column, gt_select_agg, gt_wherenumber, gt_wherecolumn, g_wo, g_wv = get_gt(
            sql)
        # Ground-truth where-value index under the CoreNLP tokenization
        # scheme. It's done already on trainset.
        gt_wherevalueindex_corenlp = get_gt_wherevalueindex_corenlp(batch_data)

        emb_question, emb_header, len_question, len_header_token, number_header, \
        question_token_bert, token_to_berttoken_index, berttoken_to_token_index \
            = get_wemb_bert(bert_config, model_bert, tokenizer, question_token, header, max_seq_length,
                            num_out_layers_n=num_target_layers, num_out_layers_h=num_target_layers)

        try:
            gt_wherevalueindex = get_gt_wherevalueindex_bert_from_gt_wherevalueindex_corenlp(
                token_to_berttoken_index, gt_wherevalueindex_corenlp)
        except Exception:
            # Exception happens when where-condition is not found in nlu_tt.
            # In this case, that train example is not used.
            # During test, that example considered as wrongly answered.
            # e.g. train: 32.
            continue

        # score
        # gt_wherevalueindex = start_index, end_index
        # score_where_value: [batch,4,question_len,2]
        score_select_column, score_select_agg, score_where_number, score_where_column, score_where_op, score_where_value, \
            score_select_column_softmax, score_select_agg_softmax, score_where_number_softmax,\
            score_where_column_softmax, score_whereop_softmax, score_where_value_softmax \
                    = model(emb_question, len_question, emb_header, len_header_token, number_header,
                           g_sc=gt_select_column, g_sa=gt_select_agg, g_wn=gt_wherenumber,
                           g_wc=gt_wherecolumn, g_wvi=gt_wherevalueindex)

        # Calculate loss & step
        loss = Loss_selectwhere_startend_v2(
            score_select_column, score_select_agg, score_where_number,
            score_where_column, score_where_op, score_where_value,
            gt_select_column, gt_select_agg, gt_wherenumber, gt_wherecolumn,
            g_wo, gt_wherevalueindex)

        # Sample from the model distribution. Only the select column and the
        # aggregation are used below (as arguments of generate_sql_i_v2).
        pred_selectcolumn_random, pred_selectagg_random, _, _, _, _ = \
            pred_selectwhere_startend_random(
                score_select_column_softmax, score_select_agg_softmax,
                score_where_number_softmax, score_where_column_softmax,
                score_whereop_softmax, score_where_value_softmax)

        # Prediction
        pr_sc, pr_sa, pr_wn, pr_wc, pr_wo, pr_wvi = pred_selectwhere_startend(
            score_select_column,
            score_select_agg,
            score_where_number,
            score_where_column,
            score_where_op,
            score_where_value,
        )
        pred_wherevalue_str, _ = convert_pred_wvi_to_string(
            pr_wvi, question_token, question_token_bert,
            berttoken_to_token_index, question)

        # Sort pr_wc:
        #   Sort pr_wc when training the model as pr_wo and pr_wvi are predicted using ground-truth where-column (g_wc)
        #   In case of 'dev' or 'test', it is not necessary as the ground-truth is not used during inference.
        pr_wc_sorted = sort_pred_wherecolumn(pr_wc, gt_wherecolumn)
        pred_sql_int = generate_sql_i(pr_sc, pr_sa, pr_wn, pr_wc_sorted, pr_wo,
                                      pred_wherevalue_str, question)

        # RL: uniform random exploration, repeated until at least one example
        # of the batch earns a reward.
        batch_size = score_where_value_softmax.shape[0]
        batch_have_good_reward = False
        while not batch_have_good_reward:
            # The draw order below is kept fixed so that a seeded run
            # consumes the random stream exactly as before.
            selectcolumn_random_v2 = [
                random.randint(1, len(header[i])) - 1
                for i in range(batch_size)
            ]
            wherenumber_random_v2 = [
                random.randint(1, 4) for _ in range(batch_size)
            ]
            selectagg_random_v2 = [
                random.randint(1, 6) - 1 for _ in range(batch_size)
            ]

            whereop_random_v2 = []
            for _ in range(batch_size):
                whereop_random_v2.append(
                    [random.randint(1, 3) - 1 for _ in range(max_conds)])

            wherecolumn_random_v2 = []
            for i in range(batch_size):
                wherecolumn_random_v2.append([
                    random.randint(1, len(header[i])) - 1
                    for _ in range(max_conds)
                ])

            pred_wherevalueindex_random_v2 = []  # [batch, 4, 2]
            for i in range(batch_size):
                spans = []
                for _ in range(max_conds):
                    # Rejection-sample a span with start < end.
                    # NOTE(review): this cannot terminate when
                    # len_question[i] is 1 -- confirm that never occurs.
                    start = random.randint(1, len_question[i]) - 1
                    end = random.randint(1, len_question[i]) - 1
                    while start >= end:
                        start = random.randint(1, len_question[i]) - 1
                        end = random.randint(1, len_question[i]) - 1
                    spans.append([start, end])
                pred_wherevalueindex_random_v2.append(spans)

            pred_wherevalue_str_random_v2, _ = convert_pred_wvi_to_string(
                pred_wherevalueindex_random_v2, question_token,
                question_token_bert, berttoken_to_token_index, question)

            # Only the condition value strings of this result are read below.
            random_sql_int_v2 = generate_sql_i_v2(
                pred_selectcolumn_random, pred_selectagg_random,
                wherenumber_random_v2, wherecolumn_random_v2,
                whereop_random_v2, pred_wherevalue_str_random_v2, question)

            batch_reward_all = []
            for i in range(len(batch_data)):
                gt_answer_list = batch_data[i]["answer"]
                tmp_conds = [
                    [
                        wherecolumn_random_v2[i][j],
                        whereop_random_v2[i][j],
                        random_sql_int_v2[i]["conds"][j][2],
                    ]
                    for j in range(wherenumber_random_v2[i])
                ]

                pred_answer_column = engine.execute(table[i]['id'],
                                                    selectcolumn_random_v2[i],
                                                    selectagg_random_v2[i],
                                                    tmp_conds)
                # Reward 1 when any ground-truth answer cell is returned.
                hit = any(
                    cell in pred_answer_column
                    or equal_in(cell, pred_answer_column)
                    for cell in gt_answer_list)
                reward = 1 if hit else 0
                if reward > 0:
                    batch_have_good_reward = True
                batch_reward_all.append(reward)

        rewards = torch.tensor(batch_reward_all, dtype=torch.float).to(device)

        # RL: where number. Sampled counts are 1..4 and are marked at index
        # count - 1.
        # NOTE(review): confirm this offset matches the model's where-number
        # class layout.
        RL_loss_where_number = reward_weighted_nll(
            score_where_number,
            to_weights([
                one_hot(num - 1, score_where_number.shape[1])
                for num in wherenumber_random_v2
            ]), rewards)

        # RL: select column
        RL_loss_select_column = reward_weighted_nll(
            score_select_column,
            to_weights([
                one_hot(sel, score_select_column.shape[1])
                for sel in selectcolumn_random_v2
            ]), rewards)

        # RL: select aggregation
        RL_loss_select_agg = reward_weighted_nll(
            score_select_agg,
            to_weights([
                one_hot(agg, score_select_agg.shape[1])
                for agg in selectagg_random_v2
            ]), rewards)

        # RL: where column. Only the first sampled column of each example is
        # marked; sampled indices are never negative, so no fallback mask is
        # needed.
        RL_loss_where_column = reward_weighted_nll(
            score_where_column,
            to_weights([
                one_hot(cols[0], score_where_column.shape[1])
                for cols in wherecolumn_random_v2
            ]), rewards)

        # RL: where value. One start mask and one end mask per condition
        # slot; the scores are transposed so that the token axis comes last.
        n_tokens = score_where_value.shape[2]
        action_wherevalue = [
            [[one_hot(start, n_tokens), one_hot(end, n_tokens)]
             for start, end in spans]
            for spans in pred_wherevalueindex_random_v2
        ]
        RL_loss_where_value = reward_weighted_nll(
            torch.transpose(score_where_value, 2, 3),
            to_weights(action_wherevalue), rewards)

        # RL: where op. Each slot is marked with its own sampled operator
        # (previously every slot reused the operator of slot 0).
        n_ops = score_where_op.shape[2]
        action_whereop = [
            [one_hot(op, n_ops) for op in ops] for ops in whereop_random_v2
        ]
        RL_loss_where_op = reward_weighted_nll(
            score_where_op, to_weights(action_whereop), rewards)

        loss += RL_loss_where_number
        loss += RL_loss_select_agg
        loss += RL_loss_select_column
        loss += RL_loss_where_op
        loss += RL_loss_where_column
        loss += RL_loss_where_value

        # Calculate gradient
        if batch_index % accumulate_gradients == 0:  # mode
            # at start, perform zero_grad
            opt.zero_grad()
            if opt_bert:
                opt_bert.zero_grad()
            loss.backward()
            if accumulate_gradients == 1:
                opt.step()
                if opt_bert:
                    opt_bert.step()
        elif batch_index % accumulate_gradients == (accumulate_gradients - 1):
            # at the final, take step with accumulated gradient
            loss.backward()
            opt.step()
            if opt_bert:
                opt_bert.step()
        else:
            # at intermediate stage, just accumulates the gradients
            loss.backward()

        # Calculate accuracy
        count_sc1_list, count_sa1_list, count_wn1_list, \
        count_wc1_list, count_wo1_list, \
        count_wvi1_list, count_wv1_list = get_count_sw_list(gt_select_column, gt_select_agg, gt_wherenumber, gt_wherecolumn, g_wo, gt_wherevalueindex,
                                                                   pr_sc, pr_sa, pr_wn, pr_wc, pr_wo, pr_wvi,
                                                                   sql, pred_sql_int,
                                                                   mode='train')

        # lx stands for logical form accuracy
        count_lx1_list = get_count_lx_list(count_sc1_list, count_sa1_list,
                                           count_wn1_list, count_wc1_list,
                                           count_wo1_list, count_wv1_list)

        # Execution accuracy test.
        count_x1_list, _, _ = get_count_x_list(
            engine, table, gt_select_column, gt_select_agg, sql, pr_sc, pr_sa,
            pred_sql_int)

        # statistics
        ave_loss += loss.item()

        # count
        count_sc += sum(count_sc1_list)
        count_sa += sum(count_sa1_list)
        count_wn += sum(count_wn1_list)
        count_wc += sum(count_wc1_list)
        count_wo += sum(count_wo1_list)
        count_wvi += sum(count_wvi1_list)
        count_wv += sum(count_wv1_list)
        count_logic_form_acc += sum(count_lx1_list)
        count_execute_acc += sum(count_x1_list)

    ave_loss /= count
    acc_sc = count_sc / count
    acc_sa = count_sa / count
    acc_wn = count_wn / count
    acc_wc = count_wc / count
    acc_wo = count_wo / count
    # Where-value *index* accuracy uses its own counter (it was previously
    # computed from count_wv and therefore duplicated acc_wv).
    acc_wvi = count_wvi / count
    acc_wv = count_wv / count
    acc_lx = count_logic_form_acc / count
    acc_x = count_execute_acc / count

    acc = [
        ave_loss, acc_sc, acc_sa, acc_wn, acc_wc, acc_wo, acc_wvi, acc_wv,
        acc_lx, acc_x
    ]

    aux_out = 1

    return acc, aux_out
Exemple #3
0
import os

from sqlnet.dbengine import DBEngine
import json
import time

path_db = './wikisql/data/tianchi/'
dset_name = 'val'

# The engine for searching results
engine = DBEngine(os.path.join(path_db, f"{dset_name}.db"))


def query(sql_tmp):
    """Execute the SQL stored in one dataset record and return the engine's result."""
    sql = sql_tmp['sql']
    return engine.execute(sql_tmp['table_id'], sql['sel'], sql['agg'],
                          sql['conds'], sql['cond_conn_op'])


# Only the first record of the dataset file is used for the benchmark.
fname = os.path.join(path_db, f"{dset_name}.json")
with open(fname, encoding='utf-8') as fs:
    first_line = next(fs, None)
if first_line is None:
    # Fail with a clear message instead of a NameError further down.
    raise SystemExit(f"{fname} is empty; nothing to benchmark.")
record = json.loads(first_line)

# Time only the query invocations, using a monotonic high-resolution clock,
# so that file reading and JSON parsing do not inflate the reported cost.
total_count = 0
start = time.perf_counter()
while total_count < 10:
    res = query(record)
    total_count += 1
print('%d times of invoking cost %.2fs.' % (total_count, time.perf_counter() - start))
Exemple #4
0
def train(train_loader,
          train_table,
          model,
          model_bert,
          opt,
          bert_config,
          tokenizer,
          max_seq_length,
          num_target_layers,
          accumulate_gradients=1,
          check_grad=True,
          st_pos=0,
          opt_bert=None,
          path_db=None,
          dset_name='train'):

    ave_loss = 0
    count = 0  # count the # of examples
    count_sc = 0  # count the # of correct predictions of select column
    count_sa = 0  # of selectd aggregation
    count_wn = 0  # of where number
    count_wc = 0  # of where column
    count_wo = 0  # of where operator
    count_wv = 0  # of where-value
    count_wvi = 0  # of where-value index (on question tokens)
    count_logic_form_acc = 0  # of logical form acc
    count_execute_acc = 0  # of execution acc

    # Engine for SQL querying.
    engine = DBEngine(os.path.join(path_db, f"{dset_name}.db"))

    explored_data_list = []

    for batch_index, batch_data in enumerate(train_loader):

        count += len(batch_data)

        if count < st_pos:
            continue
        # Get fields
        question, question_token, sql, sql_text, sql_t, table, header_token, header \
            = get_fields(batch_data, train_table, no_hs_t=True, no_sql_t=True)

        len_question_bert, len_header_token, number_header, \
        question_token_bert, token_to_berttoken_index, berttoken_to_token_index \
            = get_wemb_bert_v2(bert_config, model_bert, tokenizer, question_token, header, max_seq_length,
                               num_out_layers_n=num_target_layers, num_out_layers_h=num_target_layers)

        # select column
        def equal_in(cell_, pred_answer_column):
            for cell in pred_answer_column:
                if cell == cell_:
                    return True
            return False

        # RL
        # where number

        def list_in_list(small, big):
            for cell in big:
                try:
                    cell_ = int(cell)
                    if cell_ in small:
                        return True
                    cell_ = float(cell)
                    if cell_ in small:
                        return True
                except:
                    cell_ = str(cell)
                    if cell_.lower() in small:
                        return True

            for cell in small:
                try:
                    cell_ = int(cell)
                    if cell_ in big:
                        return True
                    cell_ = float(cell)
                    if cell_ in big:
                        return True
                except:
                    cell_ = str(cell)
                    if cell_.lower() in big:
                        return True

            return False

        def list_exact_match(input1, input2):
            tmp1 = [str(item) for item in input1]
            tmp2 = [str(item) for item in input2]
            if sorted(tmp1) == sorted(tmp2):
                return True
            return False

        def contains(big_list, small_list):
            return set(small_list).issubset(set(big_list))

        for i in range(len(batch_data)):
            print(sql[i])
            explored_data = {}
            explore_count = 0
            breakall = False

            reward_where = False
            reward_where_cond1 = 0
            reward_where_cond2 = 0
            reward_where_cond3 = 0
            reward_where_cond4 = 0

            len_question = len(question_token[i])

            gt_answer_list = batch_data[i]["answer"]

            where_number_random = 0
            select_column_random = -1
            select_agg_random = 0
            col = 0
            op = 0
            start = 0
            end = 0

            col2 = 0
            op2 = 0
            start2 = 0
            end2 = 0

            col3 = 0
            op3 = 0
            start3 = 0
            end3 = 0

            col4 = 0
            op4 = 0
            start4 = 0
            end4 = 0

            saved_where_number = -1
            saved_col1 = -1
            saved_op1 = -1
            saved_start1 = -1
            saved_end1 = -1
            saved = False

            # select_column_random = 3
            # where_number_random = 1
            while True:

                final_select_agg = None
                final_select_column = None
                final_conds = []

                tmp_conds = []
                if where_number_random == 0:
                    select_column_random += 1
                    if select_column_random == len(header[i]):
                        select_agg_random += 1
                        select_column_random = 0
                    if select_agg_random == 6:
                        where_number_random += 1
                        select_agg_random = 0

                if where_number_random == 1:

                    end += 1
                    if end >= len_question + 1:
                        start += 1
                        end = start + 1
                    if start >= len_question:
                        op += 1
                        start = 0
                    if op == 3:
                        col += 1
                        op = 0
                    if col == len(header[i]):
                        select_column_random += 1
                        col = 0
                    if select_column_random == len(header[i]):
                        select_agg_random += 1
                        select_column_random = 0
                    if select_agg_random == 6:
                        where_number_random += 1
                        select_agg_random = 0

                if saved == True and where_number_random == 2:
                    col = saved_col1
                    start = saved_start1
                    end = saved_end1
                    op = saved_op1
                    end2 += 1
                    if end2 >= len_question + 1:
                        start2 += 1
                        end2 = start2 + 1
                    if start2 >= len_question:
                        op2 += 1
                        start2 = 0
                    if op2 == 3:
                        col2 += 1
                        op2 = 0
                    if col2 == len(header[i]):
                        select_column_random += 1
                        col2 = 0
                    if select_column_random == len(header[i]):
                        select_agg_random += 1
                        select_column_random = 0
                    if select_agg_random == 6:
                        where_number_random += 1
                        select_agg_random = 0

                if saved == False and where_number_random == 2:
                    end += 1
                    if end >= len_question + 1:
                        start += 1
                        end = start + 1
                    if start >= len_question:
                        op += 1
                        start = 0
                    if op == 3:
                        col += 1
                        op = 0
                    if col == len(header[i]):
                        end2 += 1
                        col = 0
                    if end2 >= len_question + 1:
                        start2 += 1
                        end2 = start2 + 1
                    if start2 >= len_question:
                        op2 += 1
                        start2 = 0
                    if op2 == 3:
                        col2 += 1
                        op2 = 0
                    if col2 == len(header[i]):
                        select_column_random += 1
                        col2 = 0
                    if select_column_random == len(header[i]):
                        select_agg_random += 1
                        select_column_random = 0
                    if select_agg_random == 6:
                        where_number_random += 1
                        select_agg_random = 0

                if where_number_random == 3:
                    # break #TODO
                    end += 1
                    if end >= len_question + 1:
                        start += 1
                        end = start + 1
                    if start >= len_question:
                        op += 1
                        start = 0
                    if op == 3:
                        col += 1
                        op = 0
                    if col == len(header[i]):
                        end2 += 1
                        col = 0
                    if end2 >= len_question + 1:
                        start2 += 1
                        end2 = start2 + 1
                    if start2 >= len_question:
                        op2 += 1
                        start2 = 0
                    if op2 == 3:
                        col2 += 1
                        op2 = 0
                    if col2 == len(header[i]):
                        end3 += 1
                        col2 = 0
                    if end3 >= len_question + 1:
                        start3 += 1
                        end3 = start3 + 1
                    if start3 >= len_question:
                        op3 += 1
                        start3 = 0
                    if op3 == 3:
                        col3 += 1
                        op3 = 0
                    if col3 == len(header[i]):
                        select_column_random += 1
                        col3 = 0
                    if select_column_random == len(header[i]):
                        select_agg_random += 1
                        select_column_random = 0
                    if select_agg_random == 6:
                        where_number_random += 1
                        select_agg_random = 0

                if where_number_random == 4:
                    end += 1
                    if end >= len_question + 1:
                        start += 1
                        end = start + 1
                    if start >= len_question:
                        op += 1
                        start = 0
                    if op == 3:
                        col += 1
                        op = 0
                    if col == len(header[i]):
                        end2 += 1
                        col = 0
                    if end2 >= len_question + 1:
                        start2 += 1
                        end2 = start2 + 1
                    if start2 >= len_question:
                        op2 += 1
                        start2 = 0
                    if op2 == 3:
                        col2 += 1
                        op2 = 0
                    if col2 == len(header[i]):
                        end3 += 1
                        col2 = 0
                    if end3 >= len_question + 1:
                        start3 += 1
                        end3 = start3 + 1
                    if start3 >= len_question:
                        op3 += 1
                        start3 = 0
                    if op3 == 3:
                        col3 += 1
                        op3 = 0
                    if col3 == len(header[i]):
                        end4 += 1
                        col3 = 0
                    if end4 >= len_question + 1:
                        start4 += 1
                        end4 = start4 + 1
                    if start4 >= len_question:
                        op4 += 1
                        start4 = 0
                    if op4 == 3:
                        col4 += 1
                        op4 = 0
                    if col4 == len(header[i]):
                        select_column_random += 1
                        col4 = 0
                    if select_column_random == len(header[i]):
                        select_agg_random += 1
                        select_column_random = 0
                    if select_agg_random == 6:
                        where_number_random += 1
                        select_agg_random = 0

                if where_number_random == 1:
                    cond = []
                    cond.append(col)
                    cond.append(op)

                    pr_wv_str = question_token[i][start:end]
                    cond_value = merge_wv_t1_eng(pr_wv_str, question[i])
                    try:
                        cond_value_ = float(cond_value)
                    except:
                        cond_value_ = cond_value
                    # if type(cond_value_) == str:  # and random.randint(1,2)==1:
                    #     op = 0
                    cond.append(cond_value_)
                    tmp_conds.append(cond)

                if where_number_random == 2:
                    cond = []
                    cond.append(col)
                    cond.append(op)
                    pr_wv_str = question_token[i][start:end]
                    cond_value = merge_wv_t1_eng(pr_wv_str, question[i])
                    try:
                        cond_value_ = float(cond_value)
                    except:
                        cond_value_ = cond_value
                    cond.append(cond_value_)
                    tmp_conds.append(cond)

                    cond = []
                    cond.append(col2)
                    cond.append(op2)
                    pr_wv_str = question_token[i][start2:end2]
                    cond_value = merge_wv_t1_eng(pr_wv_str, question[i])
                    try:
                        cond_value_ = float(cond_value)
                    except:
                        cond_value_ = cond_value
                    cond.append(cond_value_)
                    tmp_conds.append(cond)

                if where_number_random == 3:
                    cond = []
                    cond.append(col)
                    cond.append(op)
                    pr_wv_str = question_token[i][start:end]
                    cond_value = merge_wv_t1_eng(pr_wv_str, question[i])
                    try:
                        cond_value_ = float(cond_value)
                    except:
                        cond_value_ = cond_value
                    cond.append(cond_value_)
                    tmp_conds.append(cond)

                    cond = []
                    cond.append(col2)
                    cond.append(op2)
                    pr_wv_str = question_token[i][start2:end2]
                    cond_value = merge_wv_t1_eng(pr_wv_str, question[i])
                    try:
                        cond_value_ = float(cond_value)
                    except:
                        cond_value_ = cond_value
                    cond.append(cond_value_)
                    tmp_conds.append(cond)

                    cond = []
                    cond.append(col3)
                    cond.append(op3)
                    pr_wv_str = question_token[i][start3:end3]
                    cond_value = merge_wv_t1_eng(pr_wv_str, question[i])
                    try:
                        cond_value_ = float(cond_value)
                    except:
                        cond_value_ = cond_value
                    cond.append(cond_value_)
                    tmp_conds.append(cond)

                if where_number_random == 4:
                    cond = []
                    cond.append(col)
                    cond.append(op)
                    pr_wv_str = question_token[i][start:end]
                    cond_value = merge_wv_t1_eng(pr_wv_str, question[i])
                    try:
                        cond_value_ = float(cond_value)
                    except:
                        cond_value_ = cond_value
                    cond.append(cond_value_)
                    tmp_conds.append(cond)

                    cond = []
                    cond.append(col2)
                    cond.append(op2)
                    pr_wv_str = question_token[i][start2:end2]
                    cond_value = merge_wv_t1_eng(pr_wv_str, question[i])
                    try:
                        cond_value_ = float(cond_value)
                    except:
                        cond_value_ = cond_value
                    cond.append(cond_value_)
                    tmp_conds.append(cond)

                    cond = []
                    cond.append(col3)
                    cond.append(op3)
                    pr_wv_str = question_token[i][start3:end3]
                    cond_value = merge_wv_t1_eng(pr_wv_str, question[i])
                    try:
                        cond_value_ = float(cond_value)
                    except:
                        cond_value_ = cond_value
                    cond.append(cond_value_)
                    tmp_conds.append(cond)

                    cond = []
                    cond.append(col4)
                    cond.append(op4)
                    pr_wv_str = question_token[i][start4:end4]
                    cond_value = merge_wv_t1_eng(pr_wv_str, question[i])
                    try:
                        cond_value_ = float(cond_value)
                    except:
                        cond_value_ = cond_value
                    cond.append(cond_value_)
                    tmp_conds.append(cond)

                # print(select_column_random, select_agg_random, tmp_conds)
                pred_answer_column = engine.execute(table[i]['id'],
                                                    select_column_random,
                                                    select_agg_random,
                                                    tmp_conds)

                explore_count += 1
                if explore_count % 100000 == 0:
                    print(explore_count)
                if explore_count > 500000:
                    break

                exact_match = list_exact_match(gt_answer_list,
                                               pred_answer_column)
                if where_number_random==1 and not exact_match and contains(pred_answer_column,gt_answer_list)\
                        and saved_where_number==-1 and op==0: # 可能会导致1condition的错过
                    # where_number_random = 2 # 可能会导致1condition的错过
                    saved_start1 = start
                    saved_end1 = end
                    saved_col1 = col
                    saved_op1 = op
                    saved = True

                # answer in
                if exact_match:
                    if pred_answer_column == [None]:
                        break

                    print("explore sql", select_column_random,
                          select_agg_random, tmp_conds)

                    if type(gt_answer_list[0]
                            ) == str and select_agg_random != 0:
                        print("fake sql")
                    elif where_number_random == 1 and type(tmp_conds[0][2])==str and tmp_conds[0][1]!=0 or\
                        where_number_random == 2 and type(tmp_conds[1][2]) == str and tmp_conds[1][1] != 0 or\
                        where_number_random == 3 and type(tmp_conds[2][2]) == str and tmp_conds[2][1] != 0 or \
                        where_number_random == 4 and type(tmp_conds[3][2]) == str and tmp_conds[3][1] != 0:
                        print("fake sql")
                    else:
                        # print("explore answer", pred_answer_column)
                        if type(pred_answer_column[0]) == int or type(
                                pred_answer_column[0]) == float:
                            final_select_agg = select_agg_random
                        else:
                            final_select_agg = 0

                        if final_select_agg == 0:
                            pred_answer_column2 = engine.execute(
                                table[i]['id'], select_column_random, 0, [])
                            for cell in gt_answer_list:
                                if cell in pred_answer_column2 or equal_in(
                                        cell, pred_answer_column2):
                                    final_select_column = select_column_random
                                    break
                        else:
                            final_select_column = select_column_random

                        if final_select_agg == 0:
                            pred_answer_column3 = engine.execute(
                                table[i]['id'], "*", 0, tmp_conds)
                            # answer in
                            for cell in gt_answer_list:
                                if cell in pred_answer_column3 or equal_in(
                                        cell, pred_answer_column3):
                                    reward_where = True
                                    break
                        else:
                            reward_where = True

                        # same column: word in question and where column
                        if where_number_random >= 1:
                            pred_answer_column4 = engine.execute(
                                table[i]['id'], tmp_conds[0][0], 0, [])
                            for cell in pred_answer_column4:
                                try:
                                    cell_ = str(float(cell))
                                    if cell_ in question[i].lower():
                                        reward_where_cond1 += 0.1
                                        break
                                    cell_ = str(int(cell))
                                    if cell_ in question[i].lower():
                                        reward_where_cond1 += 0.1
                                        break
                                except:
                                    cell = str(cell)
                                    if cell in question[i].lower():
                                        reward_where_cond1 += 0.1
                                        break
                                # same column: where value and where column
                            value = tmp_conds[0][2]
                            if value in pred_answer_column4:
                                reward_where_cond1 += 0.1
                            try:
                                value = float(tmp_conds[0][2])
                                if value in pred_answer_column4:
                                    reward_where_cond1 += 0.1
                            except:
                                pass
                            try:
                                value = int(tmp_conds[0][2])
                                if value in pred_answer_column4:
                                    reward_where_cond1 += 0.1
                            except:
                                pass
                            try:
                                value = str(int(tmp_conds[0][2]))
                                if value in pred_answer_column4:
                                    reward_where_cond1 += 0.1
                            except:
                                pass
                            try:
                                value = str(float(tmp_conds[0][2]))
                                if value in pred_answer_column4:
                                    reward_where_cond1 += 0.1
                            except:
                                pass

                        # same column: word in question and where column
                        if where_number_random >= 2:
                            pred_answer_column4 = engine.execute(
                                table[i]['id'], tmp_conds[1][0], 0, [])
                            for cell in pred_answer_column4:
                                try:
                                    cell_ = str(float(cell))
                                    if cell_ in question[i].lower():
                                        reward_where_cond2 += 0.1
                                        break
                                    cell_ = str(int(cell))
                                    if cell_ in question[i].lower():
                                        reward_where_cond2 += 0.1
                                        break
                                except:
                                    cell = str(cell)
                                    if cell in question[i].lower():
                                        reward_where_cond2 += 0.1
                                        break
                                # same column: where value and where column
                            value = tmp_conds[1][2]
                            if value in pred_answer_column4:
                                reward_where_cond2 += 0.1
                            try:
                                value = float(tmp_conds[1][2])
                                if value in pred_answer_column4:
                                    reward_where_cond2 += 0.1
                            except:
                                pass
                            try:
                                value = int(tmp_conds[1][2])
                                if value in pred_answer_column4:
                                    reward_where_cond2 += 0.1
                            except:
                                pass
                            try:
                                value = str(int(tmp_conds[1][2]))
                                if value in pred_answer_column4:
                                    reward_where_cond2 += 0.1
                            except:
                                pass
                            try:
                                value = str(float(tmp_conds[1][2]))
                                if value in pred_answer_column4:
                                    reward_where_cond2 += 0.1
                            except:
                                pass

                        # same column: word in question and where column
                        if where_number_random >= 3:
                            pred_answer_column4 = engine.execute(
                                table[i]['id'], tmp_conds[2][0], 0, [])
                            for cell in pred_answer_column4:
                                try:
                                    cell_ = str(float(cell))
                                    if cell_ in question[i].lower():
                                        reward_where_cond3 += 0.1
                                        break
                                    cell_ = str(int(cell))
                                    if cell_ in question[i].lower():
                                        reward_where_cond3 += 0.1
                                        break
                                except:
                                    cell = str(cell)
                                    if cell in question[i].lower():
                                        reward_where_cond3 += 0.1
                                        break
                                # same column: where value and where column
                            value = tmp_conds[2][2]
                            if value in pred_answer_column4:
                                reward_where_cond3 += 0.1
                            try:
                                value = float(tmp_conds[2][2])
                                if value in pred_answer_column4:
                                    reward_where_cond3 += 0.1
                            except:
                                pass
                            try:
                                value = int(tmp_conds[2][2])
                                if value in pred_answer_column4:
                                    reward_where_cond3 += 0.1
                            except:
                                pass
                            try:
                                value = str(int(tmp_conds[2][2]))
                                if value in pred_answer_column4:
                                    reward_where_cond3 += 0.1
                            except:
                                pass
                            try:
                                value = str(float(tmp_conds[2][2]))
                                if value in pred_answer_column4:
                                    reward_where_cond3 += 0.1
                            except:
                                pass

                        # same column: word in question and where column
                        if where_number_random >= 4:
                            pred_answer_column4 = engine.execute(
                                table[i]['id'], tmp_conds[3][0], 0, [])
                            for cell in pred_answer_column4:
                                try:
                                    cell_ = str(float(cell))
                                    if cell_ in question[i].lower():
                                        reward_where_cond4 += 0.1
                                        break
                                    cell_ = str(int(cell))
                                    if cell_ in question[i].lower():
                                        reward_where_cond4 += 0.1
                                        break
                                except:
                                    cell = str(cell)
                                    if cell in question[i].lower():
                                        reward_where_cond4 += 0.1
                                        break
                                # same column: where value and where column
                            value = tmp_conds[3][2]
                            if value in pred_answer_column4:
                                reward_where_cond4 += 0.1
                            try:
                                value = float(tmp_conds[3][2])
                                if value in pred_answer_column4:
                                    reward_where_cond4 += 0.1
                            except:
                                pass
                            try:
                                value = int(tmp_conds[3][2])
                                if value in pred_answer_column4:
                                    reward_where_cond4 += 0.1
                            except:
                                pass
                            try:
                                value = str(int(tmp_conds[3][2]))
                                if value in pred_answer_column4:
                                    reward_where_cond4 += 0.1
                            except:
                                pass
                            try:
                                value = str(float(tmp_conds[3][2]))
                                if value in pred_answer_column4:
                                    reward_where_cond4 += 0.1
                            except:
                                pass
                        """ 有问题,cond op 只能强制为 = 因为 > 或 < 不在一行
                        if where_number_random >= 1 and final_select_agg==0:
                            tmp_conds2 = tmp_conds
                            tmp_conds2[0][1] = 0  # EQUAL
                            pred_answer_column5 = engine.execute(table[i]['id'], tmp_conds2[0][0], 0, tmp_conds2)
                            # same row: the answer and this cell
                            for row in table[i]["rows"]:
                                if list_in_list(pred_answer_column5, row) and list_in_list(gt_answer_list, row):
                                    reward_where_cond1 += 0.1
                                    break
    
                        if where_number_random >= 2 and final_select_agg==0:
                            tmp_conds2 = tmp_conds
                            tmp_conds2[0][1] = 0  # EQUAL
                            tmp_conds2[1][1] = 0  # EQUAL
                            pred_answer_column5 = engine.execute(table[i]['id'], tmp_conds2[1][0], 0, tmp_conds2)
                            # same row: the answer and this cell
                            for row in table[i]["rows"]:
                                if list_in_list(pred_answer_column5, row) and list_in_list(gt_answer_list, row):
                                    reward_where_cond2 += 0.1
                                    break
    
                        if where_number_random >= 3 and final_select_agg==0:
                            tmp_conds2 = tmp_conds
                            tmp_conds2[0][1] = 0  # EQUAL
                            tmp_conds2[1][1] = 0  # EQUAL
                            tmp_conds2[2][1] = 0  # EQUAL
                            pred_answer_column5 = engine.execute(table[i]['id'], tmp_conds2[2][0], 0, tmp_conds2)
                            # same row: the answer and this cell
                            for row in table[i]["rows"]:
                                if list_in_list(pred_answer_column5, row) and list_in_list(gt_answer_list, row):
                                    reward_where_cond3 += 0.1
                                    break
    
                        if where_number_random >= 4 and final_select_agg==0:
                            tmp_conds2 = tmp_conds
                            tmp_conds2[0][1] = 0  # EQUAL
                            tmp_conds2[1][1] = 0  # EQUAL
                            tmp_conds2[2][1] = 0  # EQUAL
                            tmp_conds2[3][1] = 0  # EQUAL
                            pred_answer_column5 = engine.execute(table[i]['id'], tmp_conds2[3][0], 0, tmp_conds2)
                            # same row: the answer and this cell
                            for row in table[i]["rows"]:
                                if list_in_list(pred_answer_column5, row) and list_in_list(gt_answer_list, row):
                                    reward_where_cond4 += 0.1
                                    break
                        """

                        if reward_where_cond1 >= 0.2 and reward_where == True and where_number_random >= 1:
                            final_conds.append(tmp_conds[0])
                        if reward_where_cond2 >= 0.2 and reward_where == True and where_number_random >= 2:
                            final_conds.append(tmp_conds[1])
                        if reward_where_cond3 >= 0.2 and reward_where == True and where_number_random >= 3:
                            final_conds.append(tmp_conds[2])
                        if reward_where_cond4 >= 0.2 and reward_where == True and where_number_random >= 4:
                            final_conds.append(tmp_conds[3])
                        if final_select_agg != None and final_select_column != None and (
                                where_number_random == 1 and len(final_conds)
                                == 1 or where_number_random == 2
                                and len(final_conds) == 2
                                or where_number_random == 3
                                and len(final_conds) == 3
                                or where_number_random == 4
                                and len(final_conds) == 4):
                            break
                        if final_select_agg != None and final_select_column != None and where_number_random == 0:
                            break

            if final_select_column != None:
                explored_data["sel"] = final_select_column
                explored_data["agg"] = final_select_agg
                explored_data["conds"] = final_conds
                explored_data_list.append(explored_data)
                print(len(explored_data_list))
                one_data = batch_data[i]
                one_data["sql"] = explored_data
                one_data["query"] = explored_data
                f = open("gen_data.jsonl", mode="a", encoding="utf-8")
                json.dump(one_data, f)
                f.write('\n')
                f.close()

    print("Done")
    ave_loss /= count
    acc_sc = count_sc / count
    acc_sa = count_sa / count
    acc_wn = count_wn / count
    acc_wc = count_wc / count
    acc_wo = count_wo / count
    acc_wvi = count_wv / count
    acc_wv = count_wv / count
    acc_lx = count_logic_form_acc / count
    acc_x = count_execute_acc / count

    acc = [
        ave_loss, acc_sc, acc_sa, acc_wn, acc_wc, acc_wo, acc_wvi, acc_wv,
        acc_lx, acc_x
    ]

    aux_out = 1

    return acc, aux_out
Exemple #5
0
def test(data_loader,
         data_table,
         model,
         model_bert,
         bert_config,
         tokenizer,
         max_seq_length,
         num_target_layers,
         detail=False,
         st_pos=0,
         cnt_tot=1,
         EG=False,
         beam_size=4,
         path_db=None,
         dset_name='test'):
    """Evaluate ``model`` over ``data_loader`` and collect predicted queries.

    For every batch the questions and headers are embedded with
    ``model_bert``, a query is predicted (plain decoding, or beam search
    via ``model.beam_forward`` when ``EG`` is true), and the prediction is
    compared with the ground truth both through ``model.check_acc`` and by
    executing both queries on the database ``<path_db>/<dset_name>.db``.

    Args:
        data_loader: Iterable of batches; ``len(batch)`` is added to the
            running example count.
        data_table: Table data forwarded to ``get_fields``.
        model: Prediction model; must provide ``eval``, ``check_acc`` and
            either ``__call__`` + ``gen_query`` (``EG=False``) or
            ``beam_forward`` (``EG=True``).
        model_bert: Encoder forwarded to ``get_wemb_bert``.
        bert_config: Encoder configuration forwarded to ``get_wemb_bert``.
        tokenizer: Tokenizer forwarded to ``get_wemb_bert``.
        max_seq_length: Maximum sequence length forwarded to
            ``get_wemb_bert``.
        num_target_layers: Used for both ``num_out_layers_n`` and
            ``num_out_layers_h`` of ``get_wemb_bert``.
        detail: Unused in this function.
        st_pos: Batches are skipped while the running example count is
            still below this value.
        cnt_tot: Unused in this function.
        EG: When true, decode with ``model.beam_forward``.
        beam_size: Beam width passed to ``model.beam_forward``.
        path_db: Directory containing the database file. The default
            ``None`` is not usable: ``os.path.join`` raises ``TypeError``.
        dset_name: Stem of the database file name.

    Returns:
        ``((one_acc, tot_acc, ex_acc), results)`` where the three
        accuracies are the accumulated counts divided by the number of
        evaluated predictions (all ``0.0`` when nothing was evaluated) and
        ``results`` is a list with one dict per example: either
        ``{"query", "table_id", "nlu"}`` or, for skipped batches,
        ``{"error", "nlu", "table_id"}``.
    """
    model.eval()
    model_bert.eval()

    print('g_scn/wc/wn/wo dev/test不监督')
    cnt = 0

    engine = DBEngine(os.path.join(path_db, f"{dset_name}.db"))
    results = []
    total = 0
    one_acc_num, tot_acc_num, ex_acc_num = 0.0, 0.0, 0.0
    for t in data_loader:

        cnt += len(t)
        if cnt < st_pos:
            continue
        # Get fields
        nlu, nlu_t, sql_i, sql_t, tb, hs_t, hds = get_fields(t,
                                                             data_table,
                                                             no_hs_t=True,
                                                             no_sql_t=True)

        # Only g_wc is consumed below (by sort_pr_wc); the remaining
        # outputs are unpacked to keep the expected 10-tuple shape explicit.
        g_sc, g_sa, g_sop, g_wn, g_wc, g_wo, g_wv, g_sel_num_seq, g_sel_ag_seq, conds = get_g(
            sql_i)

        g_wvi_corenlp = get_g_wvi_corenlp(t)

        wemb_n, wemb_h, l_n, l_hpu, l_hs, \
        nlu_tt, t_to_tt_idx, tt_to_t_idx \
            = get_wemb_bert(bert_config, model_bert, tokenizer, nlu_t, hds, max_seq_length,
                            num_out_layers_n=num_target_layers, num_out_layers_h=num_target_layers)

        try:
            # The converted indices are not used afterwards; the call acts
            # as a filter for batches whose where-values cannot be aligned.
            get_g_wvi_bert_from_g_wvi_corenlp(t_to_tt_idx, g_wvi_corenlp)
        except Exception:
            # Raised when a where-condition is not found in nlu_tt. The
            # whole batch is recorded as skipped in `results`.
            # NOTE(review): skipped examples are NOT added to `total`, so
            # they do not lower the returned accuracies — confirm that this
            # is the intended metric.
            for b in range(len(nlu)):
                results.append({
                    "error": "Skip happened",
                    "nlu": nlu[b],
                    "table_id": tb[b]["id"],
                })
            continue

        # model specific part
        if not EG:
            # No execution-guided decoding
            s_scn, s_sc, s_sa, s_sop, s_wn, s_wc, s_wo, s_wv = model(
                wemb_n, l_n, wemb_h, l_hpu, l_hs)

            # prediction of the select part
            score = [s_scn, s_sc, s_sa, s_sop]
            pr_sql_sel = model.gen_query(score, nlu_tt, nlu)

            # prediction of the where part
            pr_wn, pr_wc, pr_sop, pr_wo, pr_wvi = pred_sw_se(
                s_sop, s_wn, s_wc, s_wo, s_wv)
            pr_wv_str, _pr_wv_str_wp = convert_pr_wvi_to_string(
                pr_wvi, nlu_t, nlu_tt, tt_to_t_idx, nlu)  # map to strings

            pr_wc_sorted = sort_pr_wc(pr_wc, g_wc)
            pr_sql_i = generate_sql_i(pr_sql_sel, pr_wn, pr_wc_sorted, pr_wo,
                                      pr_wv_str, nlu)

        else:
            # Execution-guided decoding
            prob_sca, prob_w, prob_wn_w, pr_sc, pr_sa, pr_wn, pr_sql_i = model.beam_forward(
                wemb_n,
                l_n,
                wemb_h,
                l_hpu,
                l_hs,
                engine,
                tb,
                nlu_t,
                nlu_tt,
                tt_to_t_idx,
                nlu,
                beam_size=beam_size)
            # sort and generate
            pr_wc, pr_wo, pr_wv, pr_sql_i = sort_and_generate_pr_w(pr_sql_i)

        # Saving for the official evaluation later.
        for b, pred in enumerate(pr_sql_i):
            results.append({
                "query": pred,
                "table_id": tb[b]["id"],
                "nlu": nlu[b],
            })

        one_err, tot_err = model.check_acc(nlu, pr_sql_i, sql_i)
        one_acc_num += (len(pr_sql_i) - one_err)
        tot_acc_num += (len(pr_sql_i) - tot_err)
        total += len(pr_sql_i)

        # Execution accuracy
        table_ids = [table['id'] for table in tb]

        for sql_gt, sql_pred, tid in zip(sql_i, pr_sql_i, table_ids):
            ret_gt = engine.execute(tid, sql_gt['sel'], sql_gt['agg'],
                                    sql_gt['conds'], sql_gt['cond_conn_op'])
            try:
                ret_pred = engine.execute(tid, sql_pred['sel'],
                                          sql_pred['agg'], sql_pred['conds'],
                                          sql_pred['cond_conn_op'])
            except Exception:
                # A prediction that cannot be executed counts as a miss
                # (unless the ground-truth result is itself None).
                ret_pred = None
            ex_acc_num += (ret_gt == ret_pred)

    if total == 0:
        # Nothing was evaluated (empty loader or every batch skipped):
        # report zero accuracy instead of raising ZeroDivisionError.
        return (0.0, 0.0, 0.0), results

    return ((one_acc_num / total), (tot_acc_num / total),
            ex_acc_num / total), results