def epoch_acc(batch_size, table_data, sql_data, db_path, cs, out=False,
              results_path='../submit/results_val.jsonl'):
    """Score a file of predicted SQL against ``sql_data``.

    Each non-blank line of ``results_path`` is a JSON object carrying ``query``
    and ``nlu`` keys, aligned line-by-line with ``sql_data``.  A line that
    contains an ``error`` key is replaced by the previous prediction so that the
    alignment is preserved.

    Args:
        batch_size: number of examples scored per step.
        table_data: passed through to ``to_batch_seq``.
        sql_data: ground-truth examples.
        db_path: database file opened with ``DBEngine``.
        cs: passed through to ``genByout``.
        out: unused; kept for interface compatibility.
        results_path: prediction file (defaults to the previously hard-coded path).

    Returns:
        ``(one_acc, tot_acc, ex_acc)``: ``batch_len - err`` summed over batches
        for the two error counts returned by ``check_acc``, and the number of
        predictions whose execution result equals the gold result; each is
        divided by ``len(sql_data)``.  Examples without a prediction line are
        counted as wrong.

    Raises:
        ValueError: if the first prediction record is an error record (there is
            no previous prediction to fall back to).
    """
    sqlOutdata = []
    with open(results_path, encoding='utf-8') as inf:
        for line in inf:
            line = line.strip()
            if not line:
                continue
            _sql = json.loads(line)
            if 'error' in _sql:
                # Failed prediction: reuse the previous one to keep indices aligned.
                # Checked before reading 'query'/'nlu', which such records may lack.
                if not sqlOutdata:
                    raise ValueError(
                        f"first prediction in {results_path} is an error record; "
                        "no previous prediction to fall back to")
                sqlOutdata.append(sqlOutdata[-1])
                continue
            _sql['sql'] = _sql['query']
            _sql['question'] = _sql['nlu']
            sqlOutdata.append(_sql)
    print(len(sql_data))
    print(len(sqlOutdata))

    n_total = len(sql_data)
    if n_total == 0:
        return 0.0, 0.0, 0.0

    engine = DBEngine(db_path)
    perm = list(range(len(sqlOutdata)))
    # Only examples that have a matching prediction line can be scored; this also
    # avoids the empty trailing batch produced by ``n // batch_size + 1``.
    n_scored = min(n_total, len(perm))
    one_acc_num, tot_acc_num, ex_acc_num = 0.0, 0.0, 0.0
    for st in tqdm(range(0, n_scored, batch_size)):
        ed = min(st + batch_size, n_scored)
        q_seq, gt_sel_num, col_seq, col_num, ans_seq, gt_cond_seq, gt_type, raw_data = \
            to_batch_seq(sql_data, table_data, perm, st, ed, ret_vis_data=True)
        # query_gt: ground-truth SQL dicts (sel, agg, conds, cond_conn_op).
        query_gt, table_ids = to_batch_query(sql_data, perm, st, ed)
        raw_q_seq = [x[0] for x in raw_data]  # original question text
        pred_queriesc, allfsc = genByout(sqlOutdata[st:ed], table_ids, raw_q_seq,
                                         gt_type, 'val', cs)
        one_err, tot_err = check_acc(raw_data, pred_queriesc, query_gt, allfsc)
        one_acc_num += ed - st - one_err
        tot_acc_num += ed - st - tot_err

        # Execution accuracy: compare result sets of gold and predicted SQL.
        for sql_gt, sql_pred, tid in zip(query_gt, pred_queriesc, table_ids):
            ret_gt = engine.execute(tid, sql_gt['sel'], sql_gt['agg'],
                                    sql_gt['conds'], sql_gt['cond_conn_op'])
            try:
                ret_pred = engine.execute(tid, sql_pred['sel'], sql_pred['agg'],
                                          sql_pred['conds'], sql_pred['cond_conn_op'])
            except Exception:
                # A malformed prediction simply counts as an execution miss.
                ret_pred = None
            ex_acc_num += ret_gt == ret_pred

    return one_acc_num / n_total, tot_acc_num / n_total, ex_acc_num / n_total
def _rl_onehot(index, size):
    """Return a list of ``size`` zeros with a single 1 at position ``index``."""
    row = [0] * size
    row[index] = 1
    return row


def _rl_random_span(length):
    """Sample a token span ``[start, end]`` with ``0 <= start < end < length``.

    Every valid pair is equally likely (same distribution as rejection sampling
    on independent draws).  A question shorter than two tokens has no such
    pair, so ``[0, 0]`` is returned instead of looping forever.
    """
    if length < 2:
        return [0, 0]
    start, end = sorted(random.sample(range(length), 2))
    return [start, end]


def _rl_sample_random_actions(header, len_question, batch_size, max_conds=4, n_agg=6, n_op=3):
    """Draw one uniformly random SQL sketch per example for RL exploration.

    Returns six per-example lists: select column, select aggregation, where
    number (1..max_conds), where operators, where columns and where-value token
    spans; the last three hold ``max_conds`` entries per example.
    """
    select_column = [random.randrange(len(header[i])) for i in range(batch_size)]
    where_number = [random.randint(1, max_conds) for _ in range(batch_size)]
    select_agg = [random.randrange(n_agg) for _ in range(batch_size)]
    where_op = [[random.randrange(n_op) for _ in range(max_conds)]
                for _ in range(batch_size)]
    where_column = [[random.randrange(len(header[i])) for _ in range(max_conds)]
                    for i in range(batch_size)]
    where_value_index = [[_rl_random_span(len_question[i]) for _ in range(max_conds)]
                         for i in range(batch_size)]
    return select_column, select_agg, where_number, where_op, where_column, where_value_index


def _rl_answer_hit(gt_answer_list, pred_answer_column):
    """True if any ground-truth answer cell occurs in the executed result."""
    return any(cell in pred_answer_column for cell in gt_answer_list)


def _rl_policy_gradient_loss(scores, actions, rewards):
    """REINFORCE-style loss ``mean(-log(sum(softmax(scores) * actions)) * reward)``.

    Softmax and sum run over the last dimension; ``rewards`` has one value per
    batch example and is broadcast over any remaining dimensions.
    """
    picked = torch.sum(torch.softmax(scores, dim=-1) * actions, dim=-1)
    rewards = rewards.view(-1, *([1] * (picked.dim() - 1)))
    return torch.mean(-torch.log(picked) * rewards)


def train(train_loader, train_table, model, model_bert, opt, bert_config, tokenizer,
          max_seq_length, num_target_layers, accumulate_gradients=1, check_grad=True,
          st_pos=0, opt_bert=None, path_db=None, dset_name='train'):
    """Train ``model``/``model_bert`` for one epoch with supervised + RL losses.

    For every batch the supervised loss ``Loss_selectwhere_startend_v2`` is
    combined with REINFORCE-style terms.  Random SQL sketches are sampled and
    executed against ``<path_db>/<dset_name>.db`` until at least one example in
    the batch returns a ground-truth answer cell.  Those hits (reward 1, else 0)
    weight the negative log-probability of each sampled sketch slot.

    Args:
        train_loader: iterable of batches; each batch is a list of example dicts
            that carry an ``"answer"`` list.
        train_table: table lookup passed to ``get_fields``.
        model, model_bert: networks put in training mode and optimized.
        opt: optimizer stepped for ``model``.
        bert_config, tokenizer, max_seq_length, num_target_layers: passed to
            ``get_wemb_bert``.
        accumulate_gradients: number of batches to accumulate before stepping.
        check_grad: unused; kept for interface compatibility.
        st_pos: batches are skipped while the running example count is below it.
        opt_bert: optional optimizer for ``model_bert``.
        path_db, dset_name: locate the SQLite database used for execution.

    Returns:
        ``(acc, aux_out)`` where ``acc`` is ``[ave_loss, acc_sc, acc_sa, acc_wn,
        acc_wc, acc_wo, acc_wvi, acc_wv, acc_lx, acc_x]`` and ``aux_out`` is 1.
    """
    model.train()
    model_bert.train()

    ave_loss = 0
    count = 0  # number of examples seen (including skipped batches)
    count_sc = 0  # correct select column
    count_sa = 0  # correct select aggregation
    count_wn = 0  # correct where number
    count_wc = 0  # correct where column
    count_wo = 0  # correct where operator
    count_wv = 0  # correct where value
    count_wvi = 0  # correct where-value index (on question tokens)
    count_logic_form_acc = 0  # logical-form accuracy
    count_execute_acc = 0  # execution accuracy

    # Engine for SQL querying.
    engine = DBEngine(os.path.join(path_db, f"{dset_name}.db"))

    for batch_index, batch_data in enumerate(train_loader):
        count += len(batch_data)
        if count < st_pos:
            continue

        # Get fields
        question, question_token, sql, sql_text, sql_t, table, header_token, header \
            = get_fields(batch_data, train_table, no_hs_t=True, no_sql_t=True)
        gt_select_column, gt_select_agg, gt_wherenumber, gt_wherecolumn, g_wo, g_wv = get_gt(sql)
        # Ground-truth where-value index under the CoreNLP tokenization
        # (already computed on the train set).
        gt_wherevalueindex_corenlp = get_gt_wherevalueindex_corenlp(batch_data)

        emb_question, emb_header, len_question, len_header_token, number_header, \
            question_token_bert, token_to_berttoken_index, berttoken_to_token_index \
            = get_wemb_bert(bert_config, model_bert, tokenizer, question_token, header,
                            max_seq_length, num_out_layers_n=num_target_layers,
                            num_out_layers_h=num_target_layers)
        try:
            gt_wherevalueindex = get_gt_wherevalueindex_bert_from_gt_wherevalueindex_corenlp(
                token_to_berttoken_index, gt_wherevalueindex_corenlp)
        except Exception:
            # Raised when a where-condition value cannot be mapped onto the BERT
            # tokens of the question; the whole batch is then skipped.
            continue

        # Scores; gt_wherevalueindex holds (start_index, end_index) pairs and
        # score_where_value is [batch, 4, question_len, 2].
        score_select_column, score_select_agg, score_where_number, score_where_column, \
            score_where_op, score_where_value, \
            score_select_column_softmax, score_select_agg_softmax, score_where_number_softmax, \
            score_where_column_softmax, score_whereop_softmax, score_where_value_softmax \
            = model(emb_question, len_question, emb_header, len_header_token, number_header,
                    g_sc=gt_select_column, g_sa=gt_select_agg, g_wn=gt_wherenumber,
                    g_wc=gt_wherecolumn, g_wvi=gt_wherevalueindex)

        # Supervised loss.
        loss = Loss_selectwhere_startend_v2(
            score_select_column, score_select_agg, score_where_number, score_where_column,
            score_where_op, score_where_value,
            gt_select_column, gt_select_agg, gt_wherenumber, gt_wherecolumn, g_wo,
            gt_wherevalueindex)

        # Model-sampled sketch; only its select column/agg feed generate_sql_i_v2 below.
        pred_selectcolumn_random, pred_selectagg_random, pred_wherenumber_random, \
            pred_wherecolumn_random, pred_whereop_random, pred_wherevalueindex_random = \
            pred_selectwhere_startend_random(
                score_select_column_softmax, score_select_agg_softmax,
                score_where_number_softmax, score_where_column_softmax,
                score_whereop_softmax, score_where_value_softmax)

        # Greedy prediction used for the accuracy statistics.
        pr_sc, pr_sa, pr_wn, pr_wc, pr_wo, pr_wvi = pred_selectwhere_startend(
            score_select_column, score_select_agg, score_where_number,
            score_where_column, score_where_op, score_where_value)
        pred_wherevalue_str, _ = convert_pred_wvi_to_string(
            pr_wvi, question_token, question_token_bert, berttoken_to_token_index, question)
        # pr_wo and pr_wvi were predicted using the ground-truth where-columns
        # (g_wc), so pr_wc is sorted to match during training.
        pr_wc_sorted = sort_pred_wherecolumn(pr_wc, gt_wherecolumn)
        pred_sql_int = generate_sql_i(pr_sc, pr_sa, pr_wn, pr_wc_sorted, pr_wo,
                                      pred_wherevalue_str, question)

        # RL exploration: resample random sketches until at least one example
        # executes to a result that contains a ground-truth answer cell.
        # NOTE(review): this loop has no upper bound; a batch whose answers can
        # never be hit by random sampling will spin forever.
        batch_size = score_where_value_softmax.shape[0]
        batch_reward_all = []
        while not any(batch_reward_all):
            (selectcolumn_random_v2, selectagg_random_v2, wherenumber_random_v2,
             whereop_random_v2, wherecolumn_random_v2,
             pred_wherevalueindex_random_v2) = _rl_sample_random_actions(
                header, len_question, batch_size)
            pred_wherevalue_str_random_v2, _ = convert_pred_wvi_to_string(
                pred_wherevalueindex_random_v2, question_token, question_token_bert,
                berttoken_to_token_index, question)
            random_sql_int_v2 = generate_sql_i_v2(
                pred_selectcolumn_random, pred_selectagg_random, wherenumber_random_v2,
                wherecolumn_random_v2, whereop_random_v2, pred_wherevalue_str_random_v2,
                question)

            batch_reward_all = []
            for i in range(len(batch_data)):
                # Only the where-value strings come from random_sql_int_v2.
                tmp_conds = [[wherecolumn_random_v2[i][j], whereop_random_v2[i][j],
                              random_sql_int_v2[i]["conds"][j][2]]
                             for j in range(wherenumber_random_v2[i])]
                pred_answer_column = engine.execute(
                    table[i]['id'], selectcolumn_random_v2[i], selectagg_random_v2[i], tmp_conds)
                hit = _rl_answer_hit(batch_data[i]["answer"], pred_answer_column)
                batch_reward_all.append(1 if hit else 0)

        reward_tensor = torch.tensor(batch_reward_all, dtype=torch.float).to(device)

        # Where number.
        # NOTE(review): where-number k is encoded at class index k - 1; confirm
        # this matches the class layout used by the model's where-number head.
        n_wn = score_where_number.shape[1]
        onehot_wn = [_rl_onehot(wn - 1, n_wn) for wn in wherenumber_random_v2]
        RL_loss_where_number = _rl_policy_gradient_loss(
            score_where_number, torch.tensor(onehot_wn, dtype=torch.float).to(device),
            reward_tensor)

        # Select column.
        n_sc = score_select_column.shape[1]
        onehot_sc = [_rl_onehot(sel, n_sc) for sel in selectcolumn_random_v2]
        RL_loss_select_column = _rl_policy_gradient_loss(
            score_select_column, torch.tensor(onehot_sc, dtype=torch.float).to(device),
            reward_tensor)

        # Select aggregation.
        n_sa = score_select_agg.shape[1]
        onehot_sa = [_rl_onehot(agg, n_sa) for agg in selectagg_random_v2]
        RL_loss_select_agg = _rl_policy_gradient_loss(
            score_select_agg, torch.tensor(onehot_sa, dtype=torch.float).to(device),
            reward_tensor)

        # Where column: only the first sampled where-column is reinforced; a -1
        # column gets a flat 0.01 row so the log never sees an all-zero sum.
        n_wc = score_where_column.shape[1]
        onehot_wc = [_rl_onehot(cols[0], n_wc) if cols[0] != -1 else [0.01] * n_wc
                     for cols in wherecolumn_random_v2]
        RL_loss_where_column = _rl_policy_gradient_loss(
            score_where_column, torch.tensor(onehot_wc, dtype=torch.float).to(device),
            reward_tensor)

        # Where value: one-hot start/end positions per condition slot; the score
        # tensor is transposed so question positions are on the last axis.
        value_len = score_where_value.shape[2]
        value_filler = [0.01] * value_len
        action_wherevalue = []
        for spans in pred_wherevalueindex_random_v2:
            action = []
            for k in range(4):
                if k < len(spans):
                    start, end = spans[k]
                    action.append([_rl_onehot(start, value_len), _rl_onehot(end, value_len)])
                else:
                    action.append([value_filler, value_filler])
            action_wherevalue.append(action)
        RL_loss_where_value = _rl_policy_gradient_loss(
            torch.transpose(score_where_value, 2, 3),
            torch.tensor(action_wherevalue, dtype=torch.float).to(device),
            reward_tensor)

        # Where operator: slot k is reinforced with the operator sampled for
        # slot k (previously every slot used the first operator).
        n_op = score_where_op.shape[2]
        action_whereop = [[_rl_onehot(ops[k], n_op) if k < len(ops) else [0.01] * n_op
                           for k in range(4)]
                          for ops in whereop_random_v2]
        RL_loss_where_op = _rl_policy_gradient_loss(
            score_where_op, torch.tensor(action_whereop, dtype=torch.float).to(device),
            reward_tensor)

        loss = loss + RL_loss_where_number
        loss = loss + RL_loss_select_agg
        loss = loss + RL_loss_select_column
        loss = loss + RL_loss_where_op
        loss = loss + RL_loss_where_column
        loss = loss + RL_loss_where_value

        # Calculate gradient (with optional accumulation over several batches).
        if batch_index % accumulate_gradients == 0:
            # At the start of an accumulation window: reset gradients.
            opt.zero_grad()
            if opt_bert:
                opt_bert.zero_grad()
            loss.backward()
            if accumulate_gradients == 1:
                opt.step()
                if opt_bert:
                    opt_bert.step()
        elif batch_index % accumulate_gradients == (accumulate_gradients - 1):
            # At the end of the window: step with the accumulated gradient.
            loss.backward()
            opt.step()
            if opt_bert:
                opt_bert.step()
        else:
            # In the middle of the window: just accumulate.
            loss.backward()

        # Calculate accuracy
        count_sc1_list, count_sa1_list, count_wn1_list, \
            count_wc1_list, count_wo1_list, \
            count_wvi1_list, count_wv1_list = get_count_sw_list(
                gt_select_column, gt_select_agg, gt_wherenumber, gt_wherecolumn, g_wo,
                gt_wherevalueindex, pr_sc, pr_sa, pr_wn, pr_wc, pr_wo, pr_wvi,
                sql, pred_sql_int, mode='train')
        # lx stands for logical form accuracy.
        count_lx1_list = get_count_lx_list(count_sc1_list, count_sa1_list, count_wn1_list,
                                           count_wc1_list, count_wo1_list, count_wv1_list)
        # Execution accuracy test.
        count_x1_list, g_ans, pr_ans = get_count_x_list(
            engine, table, gt_select_column, gt_select_agg, sql, pr_sc, pr_sa, pred_sql_int)

        # statistics
        ave_loss += loss.item()
        count_sc += sum(count_sc1_list)
        count_sa += sum(count_sa1_list)
        count_wn += sum(count_wn1_list)
        count_wc += sum(count_wc1_list)
        count_wo += sum(count_wo1_list)
        count_wvi += sum(count_wvi1_list)
        count_wv += sum(count_wv1_list)
        count_logic_form_acc += sum(count_lx1_list)
        count_execute_acc += sum(count_x1_list)

    n = max(count, 1)  # guard against an empty loader
    ave_loss /= n
    acc = [
        ave_loss,
        count_sc / n,
        count_sa / n,
        count_wn / n,
        count_wc / n,
        count_wo / n,
        count_wvi / n,  # where-value *index* accuracy (previously used count_wv)
        count_wv / n,
        count_logic_form_acc / n,
        count_execute_acc / n,
    ]
    aux_out = 1
    return acc, aux_out
import json
import os
import time

from sqlnet.dbengine import DBEngine

# Micro-benchmark: time repeated DBEngine.execute calls on the first record of
# the validation set.
path_db = './wikisql/data/tianchi/'
dset_name = 'val'
n_runs = 10

# The engine for searching results
engine = DBEngine(os.path.join(path_db, f"{dset_name}.db"))


def query(sql_tmp):
    """Execute the SQL stored in a dataset record and return the engine's result."""
    sql = sql_tmp['sql']
    return engine.execute(sql_tmp['table_id'], sql['sel'], sql['agg'],
                          sql['conds'], sql['cond_conn_op'])


fname = os.path.join(path_db, f"{dset_name}.json")
with open(fname, encoding='utf-8') as fs:
    first_line = fs.readline()
if not first_line.strip():
    raise ValueError(f"{fname} has no record to benchmark")
record = json.loads(first_line)

# Time only the queries, not the file read.
total_count = 0
start = time.perf_counter()
while total_count < n_runs:
    query(record)
    total_count += 1
print('%d times of invoking cost %.2fs.' % (total_count, time.perf_counter() - start))
def train(train_loader, train_table, model, model_bert, opt, bert_config, tokenizer, max_seq_length, num_target_layers, accumulate_gradients=1, check_grad=True, st_pos=0, opt_bert=None, path_db=None, dset_name='train'): ave_loss = 0 count = 0 # count the # of examples count_sc = 0 # count the # of correct predictions of select column count_sa = 0 # of selectd aggregation count_wn = 0 # of where number count_wc = 0 # of where column count_wo = 0 # of where operator count_wv = 0 # of where-value count_wvi = 0 # of where-value index (on question tokens) count_logic_form_acc = 0 # of logical form acc count_execute_acc = 0 # of execution acc # Engine for SQL querying. engine = DBEngine(os.path.join(path_db, f"{dset_name}.db")) explored_data_list = [] for batch_index, batch_data in enumerate(train_loader): count += len(batch_data) if count < st_pos: continue # Get fields question, question_token, sql, sql_text, sql_t, table, header_token, header \ = get_fields(batch_data, train_table, no_hs_t=True, no_sql_t=True) len_question_bert, len_header_token, number_header, \ question_token_bert, token_to_berttoken_index, berttoken_to_token_index \ = get_wemb_bert_v2(bert_config, model_bert, tokenizer, question_token, header, max_seq_length, num_out_layers_n=num_target_layers, num_out_layers_h=num_target_layers) # select column def equal_in(cell_, pred_answer_column): for cell in pred_answer_column: if cell == cell_: return True return False # RL # where number def list_in_list(small, big): for cell in big: try: cell_ = int(cell) if cell_ in small: return True cell_ = float(cell) if cell_ in small: return True except: cell_ = str(cell) if cell_.lower() in small: return True for cell in small: try: cell_ = int(cell) if cell_ in big: return True cell_ = float(cell) if cell_ in big: return True except: cell_ = str(cell) if cell_.lower() in big: return True return False def list_exact_match(input1, input2): tmp1 = [str(item) for item in input1] tmp2 = [str(item) for item in input2] if 
sorted(tmp1) == sorted(tmp2): return True return False def contains(big_list, small_list): return set(small_list).issubset(set(big_list)) for i in range(len(batch_data)): print(sql[i]) explored_data = {} explore_count = 0 breakall = False reward_where = False reward_where_cond1 = 0 reward_where_cond2 = 0 reward_where_cond3 = 0 reward_where_cond4 = 0 len_question = len(question_token[i]) gt_answer_list = batch_data[i]["answer"] where_number_random = 0 select_column_random = -1 select_agg_random = 0 col = 0 op = 0 start = 0 end = 0 col2 = 0 op2 = 0 start2 = 0 end2 = 0 col3 = 0 op3 = 0 start3 = 0 end3 = 0 col4 = 0 op4 = 0 start4 = 0 end4 = 0 saved_where_number = -1 saved_col1 = -1 saved_op1 = -1 saved_start1 = -1 saved_end1 = -1 saved = False # select_column_random = 3 # where_number_random = 1 while True: final_select_agg = None final_select_column = None final_conds = [] tmp_conds = [] if where_number_random == 0: select_column_random += 1 if select_column_random == len(header[i]): select_agg_random += 1 select_column_random = 0 if select_agg_random == 6: where_number_random += 1 select_agg_random = 0 if where_number_random == 1: end += 1 if end >= len_question + 1: start += 1 end = start + 1 if start >= len_question: op += 1 start = 0 if op == 3: col += 1 op = 0 if col == len(header[i]): select_column_random += 1 col = 0 if select_column_random == len(header[i]): select_agg_random += 1 select_column_random = 0 if select_agg_random == 6: where_number_random += 1 select_agg_random = 0 if saved == True and where_number_random == 2: col = saved_col1 start = saved_start1 end = saved_end1 op = saved_op1 end2 += 1 if end2 >= len_question + 1: start2 += 1 end2 = start2 + 1 if start2 >= len_question: op2 += 1 start2 = 0 if op2 == 3: col2 += 1 op2 = 0 if col2 == len(header[i]): select_column_random += 1 col2 = 0 if select_column_random == len(header[i]): select_agg_random += 1 select_column_random = 0 if select_agg_random == 6: where_number_random += 1 select_agg_random = 0 
if saved == False and where_number_random == 2: end += 1 if end >= len_question + 1: start += 1 end = start + 1 if start >= len_question: op += 1 start = 0 if op == 3: col += 1 op = 0 if col == len(header[i]): end2 += 1 col = 0 if end2 >= len_question + 1: start2 += 1 end2 = start2 + 1 if start2 >= len_question: op2 += 1 start2 = 0 if op2 == 3: col2 += 1 op2 = 0 if col2 == len(header[i]): select_column_random += 1 col2 = 0 if select_column_random == len(header[i]): select_agg_random += 1 select_column_random = 0 if select_agg_random == 6: where_number_random += 1 select_agg_random = 0 if where_number_random == 3: # break #TODO end += 1 if end >= len_question + 1: start += 1 end = start + 1 if start >= len_question: op += 1 start = 0 if op == 3: col += 1 op = 0 if col == len(header[i]): end2 += 1 col = 0 if end2 >= len_question + 1: start2 += 1 end2 = start2 + 1 if start2 >= len_question: op2 += 1 start2 = 0 if op2 == 3: col2 += 1 op2 = 0 if col2 == len(header[i]): end3 += 1 col2 = 0 if end3 >= len_question + 1: start3 += 1 end3 = start3 + 1 if start3 >= len_question: op3 += 1 start3 = 0 if op3 == 3: col3 += 1 op3 = 0 if col3 == len(header[i]): select_column_random += 1 col3 = 0 if select_column_random == len(header[i]): select_agg_random += 1 select_column_random = 0 if select_agg_random == 6: where_number_random += 1 select_agg_random = 0 if where_number_random == 4: end += 1 if end >= len_question + 1: start += 1 end = start + 1 if start >= len_question: op += 1 start = 0 if op == 3: col += 1 op = 0 if col == len(header[i]): end2 += 1 col = 0 if end2 >= len_question + 1: start2 += 1 end2 = start2 + 1 if start2 >= len_question: op2 += 1 start2 = 0 if op2 == 3: col2 += 1 op2 = 0 if col2 == len(header[i]): end3 += 1 col2 = 0 if end3 >= len_question + 1: start3 += 1 end3 = start3 + 1 if start3 >= len_question: op3 += 1 start3 = 0 if op3 == 3: col3 += 1 op3 = 0 if col3 == len(header[i]): end4 += 1 col3 = 0 if end4 >= len_question + 1: start4 += 1 end4 = start4 + 1 if 
start4 >= len_question: op4 += 1 start4 = 0 if op4 == 3: col4 += 1 op4 = 0 if col4 == len(header[i]): select_column_random += 1 col4 = 0 if select_column_random == len(header[i]): select_agg_random += 1 select_column_random = 0 if select_agg_random == 6: where_number_random += 1 select_agg_random = 0 if where_number_random == 1: cond = [] cond.append(col) cond.append(op) pr_wv_str = question_token[i][start:end] cond_value = merge_wv_t1_eng(pr_wv_str, question[i]) try: cond_value_ = float(cond_value) except: cond_value_ = cond_value # if type(cond_value_) == str: # and random.randint(1,2)==1: # op = 0 cond.append(cond_value_) tmp_conds.append(cond) if where_number_random == 2: cond = [] cond.append(col) cond.append(op) pr_wv_str = question_token[i][start:end] cond_value = merge_wv_t1_eng(pr_wv_str, question[i]) try: cond_value_ = float(cond_value) except: cond_value_ = cond_value cond.append(cond_value_) tmp_conds.append(cond) cond = [] cond.append(col2) cond.append(op2) pr_wv_str = question_token[i][start2:end2] cond_value = merge_wv_t1_eng(pr_wv_str, question[i]) try: cond_value_ = float(cond_value) except: cond_value_ = cond_value cond.append(cond_value_) tmp_conds.append(cond) if where_number_random == 3: cond = [] cond.append(col) cond.append(op) pr_wv_str = question_token[i][start:end] cond_value = merge_wv_t1_eng(pr_wv_str, question[i]) try: cond_value_ = float(cond_value) except: cond_value_ = cond_value cond.append(cond_value_) tmp_conds.append(cond) cond = [] cond.append(col2) cond.append(op2) pr_wv_str = question_token[i][start2:end2] cond_value = merge_wv_t1_eng(pr_wv_str, question[i]) try: cond_value_ = float(cond_value) except: cond_value_ = cond_value cond.append(cond_value_) tmp_conds.append(cond) cond = [] cond.append(col3) cond.append(op3) pr_wv_str = question_token[i][start3:end3] cond_value = merge_wv_t1_eng(pr_wv_str, question[i]) try: cond_value_ = float(cond_value) except: cond_value_ = cond_value cond.append(cond_value_) 
tmp_conds.append(cond) if where_number_random == 4: cond = [] cond.append(col) cond.append(op) pr_wv_str = question_token[i][start:end] cond_value = merge_wv_t1_eng(pr_wv_str, question[i]) try: cond_value_ = float(cond_value) except: cond_value_ = cond_value cond.append(cond_value_) tmp_conds.append(cond) cond = [] cond.append(col2) cond.append(op2) pr_wv_str = question_token[i][start2:end2] cond_value = merge_wv_t1_eng(pr_wv_str, question[i]) try: cond_value_ = float(cond_value) except: cond_value_ = cond_value cond.append(cond_value_) tmp_conds.append(cond) cond = [] cond.append(col3) cond.append(op3) pr_wv_str = question_token[i][start3:end3] cond_value = merge_wv_t1_eng(pr_wv_str, question[i]) try: cond_value_ = float(cond_value) except: cond_value_ = cond_value cond.append(cond_value_) tmp_conds.append(cond) cond = [] cond.append(col4) cond.append(op4) pr_wv_str = question_token[i][start4:end4] cond_value = merge_wv_t1_eng(pr_wv_str, question[i]) try: cond_value_ = float(cond_value) except: cond_value_ = cond_value cond.append(cond_value_) tmp_conds.append(cond) # print(select_column_random, select_agg_random, tmp_conds) pred_answer_column = engine.execute(table[i]['id'], select_column_random, select_agg_random, tmp_conds) explore_count += 1 if explore_count % 100000 == 0: print(explore_count) if explore_count > 500000: break exact_match = list_exact_match(gt_answer_list, pred_answer_column) if where_number_random==1 and not exact_match and contains(pred_answer_column,gt_answer_list)\ and saved_where_number==-1 and op==0: # 可能会导致1condition的错过 # where_number_random = 2 # 可能会导致1condition的错过 saved_start1 = start saved_end1 = end saved_col1 = col saved_op1 = op saved = True # answer in if exact_match: if pred_answer_column == [None]: break print("explore sql", select_column_random, select_agg_random, tmp_conds) if type(gt_answer_list[0] ) == str and select_agg_random != 0: print("fake sql") elif where_number_random == 1 and type(tmp_conds[0][2])==str and 
tmp_conds[0][1]!=0 or\ where_number_random == 2 and type(tmp_conds[1][2]) == str and tmp_conds[1][1] != 0 or\ where_number_random == 3 and type(tmp_conds[2][2]) == str and tmp_conds[2][1] != 0 or \ where_number_random == 4 and type(tmp_conds[3][2]) == str and tmp_conds[3][1] != 0: print("fake sql") else: # print("explore answer", pred_answer_column) if type(pred_answer_column[0]) == int or type( pred_answer_column[0]) == float: final_select_agg = select_agg_random else: final_select_agg = 0 if final_select_agg == 0: pred_answer_column2 = engine.execute( table[i]['id'], select_column_random, 0, []) for cell in gt_answer_list: if cell in pred_answer_column2 or equal_in( cell, pred_answer_column2): final_select_column = select_column_random break else: final_select_column = select_column_random if final_select_agg == 0: pred_answer_column3 = engine.execute( table[i]['id'], "*", 0, tmp_conds) # answer in for cell in gt_answer_list: if cell in pred_answer_column3 or equal_in( cell, pred_answer_column3): reward_where = True break else: reward_where = True # same column: word in question and where column if where_number_random >= 1: pred_answer_column4 = engine.execute( table[i]['id'], tmp_conds[0][0], 0, []) for cell in pred_answer_column4: try: cell_ = str(float(cell)) if cell_ in question[i].lower(): reward_where_cond1 += 0.1 break cell_ = str(int(cell)) if cell_ in question[i].lower(): reward_where_cond1 += 0.1 break except: cell = str(cell) if cell in question[i].lower(): reward_where_cond1 += 0.1 break # same column: where value and where column value = tmp_conds[0][2] if value in pred_answer_column4: reward_where_cond1 += 0.1 try: value = float(tmp_conds[0][2]) if value in pred_answer_column4: reward_where_cond1 += 0.1 except: pass try: value = int(tmp_conds[0][2]) if value in pred_answer_column4: reward_where_cond1 += 0.1 except: pass try: value = str(int(tmp_conds[0][2])) if value in pred_answer_column4: reward_where_cond1 += 0.1 except: pass try: value = 
str(float(tmp_conds[0][2])) if value in pred_answer_column4: reward_where_cond1 += 0.1 except: pass # same column: word in question and where column if where_number_random >= 2: pred_answer_column4 = engine.execute( table[i]['id'], tmp_conds[1][0], 0, []) for cell in pred_answer_column4: try: cell_ = str(float(cell)) if cell_ in question[i].lower(): reward_where_cond2 += 0.1 break cell_ = str(int(cell)) if cell_ in question[i].lower(): reward_where_cond2 += 0.1 break except: cell = str(cell) if cell in question[i].lower(): reward_where_cond2 += 0.1 break # same column: where value and where column value = tmp_conds[1][2] if value in pred_answer_column4: reward_where_cond2 += 0.1 try: value = float(tmp_conds[1][2]) if value in pred_answer_column4: reward_where_cond2 += 0.1 except: pass try: value = int(tmp_conds[1][2]) if value in pred_answer_column4: reward_where_cond2 += 0.1 except: pass try: value = str(int(tmp_conds[1][2])) if value in pred_answer_column4: reward_where_cond2 += 0.1 except: pass try: value = str(float(tmp_conds[1][2])) if value in pred_answer_column4: reward_where_cond2 += 0.1 except: pass # same column: word in question and where column if where_number_random >= 3: pred_answer_column4 = engine.execute( table[i]['id'], tmp_conds[2][0], 0, []) for cell in pred_answer_column4: try: cell_ = str(float(cell)) if cell_ in question[i].lower(): reward_where_cond3 += 0.1 break cell_ = str(int(cell)) if cell_ in question[i].lower(): reward_where_cond3 += 0.1 break except: cell = str(cell) if cell in question[i].lower(): reward_where_cond3 += 0.1 break # same column: where value and where column value = tmp_conds[2][2] if value in pred_answer_column4: reward_where_cond3 += 0.1 try: value = float(tmp_conds[2][2]) if value in pred_answer_column4: reward_where_cond3 += 0.1 except: pass try: value = int(tmp_conds[2][2]) if value in pred_answer_column4: reward_where_cond3 += 0.1 except: pass try: value = str(int(tmp_conds[2][2])) if value in pred_answer_column4: 
reward_where_cond3 += 0.1 except: pass try: value = str(float(tmp_conds[2][2])) if value in pred_answer_column4: reward_where_cond3 += 0.1 except: pass # same column: word in question and where column if where_number_random >= 4: pred_answer_column4 = engine.execute( table[i]['id'], tmp_conds[3][0], 0, []) for cell in pred_answer_column4: try: cell_ = str(float(cell)) if cell_ in question[i].lower(): reward_where_cond4 += 0.1 break cell_ = str(int(cell)) if cell_ in question[i].lower(): reward_where_cond4 += 0.1 break except: cell = str(cell) if cell in question[i].lower(): reward_where_cond4 += 0.1 break # same column: where value and where column value = tmp_conds[3][2] if value in pred_answer_column4: reward_where_cond4 += 0.1 try: value = float(tmp_conds[3][2]) if value in pred_answer_column4: reward_where_cond4 += 0.1 except: pass try: value = int(tmp_conds[3][2]) if value in pred_answer_column4: reward_where_cond4 += 0.1 except: pass try: value = str(int(tmp_conds[3][2])) if value in pred_answer_column4: reward_where_cond4 += 0.1 except: pass try: value = str(float(tmp_conds[3][2])) if value in pred_answer_column4: reward_where_cond4 += 0.1 except: pass """ 有问题,cond op 只能强制为 = 因为 > 或 < 不在一行 if where_number_random >= 1 and final_select_agg==0: tmp_conds2 = tmp_conds tmp_conds2[0][1] = 0 # EQUAL pred_answer_column5 = engine.execute(table[i]['id'], tmp_conds2[0][0], 0, tmp_conds2) # same row: the answer and this cell for row in table[i]["rows"]: if list_in_list(pred_answer_column5, row) and list_in_list(gt_answer_list, row): reward_where_cond1 += 0.1 break if where_number_random >= 2 and final_select_agg==0: tmp_conds2 = tmp_conds tmp_conds2[0][1] = 0 # EQUAL tmp_conds2[1][1] = 0 # EQUAL pred_answer_column5 = engine.execute(table[i]['id'], tmp_conds2[1][0], 0, tmp_conds2) # same row: the answer and this cell for row in table[i]["rows"]: if list_in_list(pred_answer_column5, row) and list_in_list(gt_answer_list, row): reward_where_cond2 += 0.1 break if 
where_number_random >= 3 and final_select_agg==0: tmp_conds2 = tmp_conds tmp_conds2[0][1] = 0 # EQUAL tmp_conds2[1][1] = 0 # EQUAL tmp_conds2[2][1] = 0 # EQUAL pred_answer_column5 = engine.execute(table[i]['id'], tmp_conds2[2][0], 0, tmp_conds2) # same row: the answer and this cell for row in table[i]["rows"]: if list_in_list(pred_answer_column5, row) and list_in_list(gt_answer_list, row): reward_where_cond3 += 0.1 break if where_number_random >= 4 and final_select_agg==0: tmp_conds2 = tmp_conds tmp_conds2[0][1] = 0 # EQUAL tmp_conds2[1][1] = 0 # EQUAL tmp_conds2[2][1] = 0 # EQUAL tmp_conds2[3][1] = 0 # EQUAL pred_answer_column5 = engine.execute(table[i]['id'], tmp_conds2[3][0], 0, tmp_conds2) # same row: the answer and this cell for row in table[i]["rows"]: if list_in_list(pred_answer_column5, row) and list_in_list(gt_answer_list, row): reward_where_cond4 += 0.1 break """ if reward_where_cond1 >= 0.2 and reward_where == True and where_number_random >= 1: final_conds.append(tmp_conds[0]) if reward_where_cond2 >= 0.2 and reward_where == True and where_number_random >= 2: final_conds.append(tmp_conds[1]) if reward_where_cond3 >= 0.2 and reward_where == True and where_number_random >= 3: final_conds.append(tmp_conds[2]) if reward_where_cond4 >= 0.2 and reward_where == True and where_number_random >= 4: final_conds.append(tmp_conds[3]) if final_select_agg != None and final_select_column != None and ( where_number_random == 1 and len(final_conds) == 1 or where_number_random == 2 and len(final_conds) == 2 or where_number_random == 3 and len(final_conds) == 3 or where_number_random == 4 and len(final_conds) == 4): break if final_select_agg != None and final_select_column != None and where_number_random == 0: break if final_select_column != None: explored_data["sel"] = final_select_column explored_data["agg"] = final_select_agg explored_data["conds"] = final_conds explored_data_list.append(explored_data) print(len(explored_data_list)) one_data = batch_data[i] 
one_data["sql"] = explored_data one_data["query"] = explored_data f = open("gen_data.jsonl", mode="a", encoding="utf-8") json.dump(one_data, f) f.write('\n') f.close() print("Done") ave_loss /= count acc_sc = count_sc / count acc_sa = count_sa / count acc_wn = count_wn / count acc_wc = count_wc / count acc_wo = count_wo / count acc_wvi = count_wv / count acc_wv = count_wv / count acc_lx = count_logic_form_acc / count acc_x = count_execute_acc / count acc = [ ave_loss, acc_sc, acc_sa, acc_wn, acc_wc, acc_wo, acc_wvi, acc_wv, acc_lx, acc_x ] aux_out = 1 return acc, aux_out
def test(data_loader, data_table, model, model_bert, bert_config, tokenizer,
         max_seq_length, num_target_layers, detail=False, st_pos=0, cnt_tot=1,
         EG=False, beam_size=4, path_db=None, dset_name='test'):
    """Evaluate ``model`` on ``data_loader`` and collect per-example predictions.

    Args:
        data_loader: iterable of batches; each batch is passed to ``get_fields``
            and ``get_g_wvi_corenlp`` and must support ``len()``.
        data_table: table lookup passed to ``get_fields``.
        model: SQL decoder. Called directly when ``EG`` is false, or through
            ``model.beam_forward`` when ``EG`` is true; must also provide
            ``gen_query`` and ``check_acc``.
        model_bert, bert_config, tokenizer, max_seq_length, num_target_layers:
            forwarded to ``get_wemb_bert`` to embed questions and headers.
        detail, cnt_tot: unused; kept so existing call sites keep working.
        st_pos: batches are skipped while the running example count is below
            this value.
        EG: if true, use execution-guided beam decoding.
        beam_size: beam width for execution-guided decoding.
        path_db: directory containing the ``{dset_name}.db`` database.
        dset_name: stem of the database file to open.

    Returns:
        ``((one_acc, tot_acc, ex_acc), results)``. ``one_acc`` and ``tot_acc``
        are ``1 - err / total`` for the two error counts returned by
        ``model.check_acc`` (their exact meaning is defined there). ``ex_acc``
        is the fraction of examples whose predicted query executes to the same
        result as the gold query. ``total`` includes examples that had to be
        skipped, which therefore count as wrong. All three are ``0.0`` if no
        example was evaluated. ``results`` holds one dict per evaluated example:
        either ``{"query", "table_id", "nlu"}`` or, for skipped batches,
        ``{"error": "Skip happened", "nlu", "table_id"}``.
    """
    model.eval()
    model_bert.eval()
    print('g_scn/wc/wn/wo dev/test不监督')

    cnt = 0  # running example count, used only for the st_pos resume check
    engine = DBEngine(os.path.join(path_db, f"{dset_name}.db"))
    results = []
    total = 0  # denominator for all three accuracies
    one_acc_num, tot_acc_num, ex_acc_num = 0.0, 0.0, 0.0

    for t in data_loader:
        cnt += len(t)
        if cnt < st_pos:
            continue

        # Get fields
        nlu, nlu_t, sql_i, sql_t, tb, hs_t, hds = get_fields(
            t, data_table, no_hs_t=True, no_sql_t=True)
        g_sc, g_sa, g_sop, g_wn, g_wc, g_wo, g_wv, g_sel_num_seq, g_sel_ag_seq, conds = get_g(
            sql_i)
        g_wvi_corenlp = get_g_wvi_corenlp(t)

        wemb_n, wemb_h, l_n, l_hpu, l_hs, \
            nlu_tt, t_to_tt_idx, tt_to_t_idx \
            = get_wemb_bert(bert_config, model_bert, tokenizer, nlu_t, hds,
                            max_seq_length,
                            num_out_layers_n=num_target_layers,
                            num_out_layers_h=num_target_layers)

        try:
            # Called only to detect examples whose gold where-value cannot be
            # located in the BERT tokenization; the mapping itself is not used.
            get_g_wvi_bert_from_g_wvi_corenlp(t_to_tt_idx, g_wvi_corenlp)
        except Exception:
            # The where-condition was not found in nlu_tt. The batch is not
            # decoded, but its examples still enter the denominator so they
            # count as wrongly answered rather than silently disappearing.
            results.extend(
                {"error": "Skip happened", "nlu": nlu[b], "table_id": tb[b]["id"]}
                for b in range(len(nlu))
            )
            total += len(nlu)
            continue

        if not EG:
            # Plain decoding: score every sub-task, then assemble the query.
            s_scn, s_sc, s_sa, s_sop, s_wn, s_wc, s_wo, s_wv = model(
                wemb_n, l_n, wemb_h, l_hpu, l_hs)
            score = [s_scn, s_sc, s_sa, s_sop]
            pr_sql_i1 = model.gen_query(score, nlu_tt, nlu)
            pr_wn, pr_wc, pr_sop, pr_wo, pr_wvi = pred_sw_se(
                s_sop, s_wn, s_wc, s_wo, s_wv)
            # Map predicted where-value token spans back to strings.
            pr_wv_str, _ = convert_pr_wvi_to_string(
                pr_wvi, nlu_t, nlu_tt, tt_to_t_idx, nlu)
            pr_wc_sorted = sort_pr_wc(pr_wc, g_wc)
            pr_sql_i = generate_sql_i(
                pr_sql_i1, pr_wn, pr_wc_sorted, pr_wo, pr_wv_str, nlu)
        else:
            # Execution-guided beam decoding.
            prob_sca, prob_w, prob_wn_w, pr_sc, pr_sa, pr_wn, pr_sql_i = model.beam_forward(
                wemb_n, l_n, wemb_h, l_hpu, l_hs, engine, tb, nlu_t, nlu_tt,
                tt_to_t_idx, nlu, beam_size=beam_size)
            # Sort the where-conditions and regenerate the queries.
            pr_wc, pr_wo, pr_wv, pr_sql_i = sort_and_generate_pr_w(pr_sql_i)

        # Save predictions for the official evaluation later.
        for b, pr_sql_i1 in enumerate(pr_sql_i):
            results.append({
                "query": pr_sql_i1,
                "table_id": tb[b]["id"],
                "nlu": nlu[b],
            })

        # Logical-form accuracy.
        one_err, tot_err = model.check_acc(nlu, pr_sql_i, sql_i)
        one_acc_num += (len(pr_sql_i) - one_err)
        tot_acc_num += (len(pr_sql_i) - tot_err)
        total += len(pr_sql_i)

        # Execution accuracy: compare executed results of gold vs. predicted SQL.
        table_ids = [table["id"] for table in tb]
        for sql_gt, sql_pred, tid in zip(sql_i, pr_sql_i, table_ids):
            ret_gt = engine.execute(tid, sql_gt['sel'], sql_gt['agg'],
                                    sql_gt['conds'], sql_gt['cond_conn_op'])
            try:
                ret_pred = engine.execute(tid, sql_pred['sel'], sql_pred['agg'],
                                          sql_pred['conds'], sql_pred['cond_conn_op'])
            except Exception:
                # A predicted query that fails to execute counts as a miss.
                ret_pred = None
            ex_acc_num += (ret_gt == ret_pred)

    if total == 0:
        # Nothing was evaluated (empty loader or st_pos past the end).
        return (0.0, 0.0, 0.0), results
    return (one_acc_num / total, tot_acc_num / total, ex_acc_num / total), results