def epoch_acc(batch_size, table_data, sql_data, db_path, cs, out=False,
              pred_path='../submit/results_val.jsonl'):
    """Score a file of saved predictions against the ground-truth examples.

    One JSON object per line is read from ``pred_path`` (default: the path the
    original implementation hard-coded). Predicted queries are rebuilt batch
    by batch with ``genByout`` and compared with ``sql_data`` via
    ``check_acc`` and by executing both queries with ``DBEngine``.

    Args:
        batch_size: number of examples per evaluation batch.
        table_data: table lookup passed through to ``to_batch_seq``.
        sql_data: ground-truth examples; ``len(sql_data)`` is the denominator
            of every returned accuracy.
        db_path: database path handed to ``DBEngine``.
        cs: passed through to ``genByout`` unchanged.
        out: unused; kept for interface compatibility.
        pred_path: JSON-lines prediction file to score.

    Returns:
        ``(one_acc, tot_acc, ex_acc)``: the two accuracies derived from the
        error counts returned by ``check_acc`` and the execution accuracy.
    """
    sqlOutdata = []
    with open(pred_path) as inf:
        for line in inf:
            _sql = json.loads(line.strip())
            if 'error' in _sql:
                # Skip records carry no 'query' key, so they must be handled
                # before 'query' is read. Reuse the previous prediction to
                # keep rows aligned with sql_data; a leading error record has
                # no predecessor and is kept as-is.
                # NOTE(review): how genByout treats such a record is not
                # visible here.
                sqlOutdata.append(sqlOutdata[-1] if sqlOutdata else _sql)
                continue
            _sql['sql'] = _sql['query']
            _sql['question'] = _sql['nlu']
            sqlOutdata.append(_sql)
    print(len(sql_data))
    print(len(sqlOutdata))

    engine = DBEngine(db_path)
    perm = list(range(len(sqlOutdata)))
    one_acc_num, tot_acc_num, ex_acc_num = 0.0, 0.0, 0.0

    # Step by batch_size so no empty trailing batch is produced when
    # len(sql_data) is an exact multiple of batch_size.
    for st in tqdm(range(0, len(sql_data), batch_size)):
        ed = min(st + batch_size, len(perm))
        if ed <= st:
            # Fewer predictions than ground-truth examples: nothing left.
            break

        q_seq, gt_sel_num, col_seq, col_num, ans_seq, gt_cond_seq, gt_type, raw_data = \
            to_batch_seq(sql_data, table_data, perm, st, ed, ret_vis_data=True)
        # query_gt: ground-truth sql dicts (sel, agg, conds, cond_conn_op).
        query_gt, table_ids = to_batch_query(sql_data, perm, st, ed)
        raw_q_seq = [x[0] for x in raw_data]  # original question text

        pred_queriesc, allfsc = genByout(sqlOutdata[st:ed], table_ids, raw_q_seq,
                                         gt_type, 'val', cs)
        one_err, tot_err = check_acc(raw_data, pred_queriesc, query_gt, allfsc)
        one_acc_num += (ed - st - one_err)
        tot_acc_num += (ed - st - tot_err)

        # Execution accuracy: a prediction that fails to execute counts as
        # a mismatch (ret_pred = None).
        for sql_gt, sql_pred, tid in zip(query_gt, pred_queriesc, table_ids):
            ret_gt = engine.execute(tid, sql_gt['sel'], sql_gt['agg'],
                                    sql_gt['conds'], sql_gt['cond_conn_op'])
            try:
                ret_pred = engine.execute(tid, sql_pred['sel'], sql_pred['agg'],
                                          sql_pred['conds'], sql_pred['cond_conn_op'])
            except Exception:
                ret_pred = None
            ex_acc_num += (ret_gt == ret_pred)

    n_total = len(sql_data)
    return one_acc_num / n_total, tot_acc_num / n_total, ex_acc_num / n_total
def predict(data_loader, data_table, model, model_bert, bert_config, tokenizer,
            max_seq_length, num_target_layers, detail=False, st_pos=0,
            cnt_tot=1, EG=False, beam_size=4, path_db=None, dset_name='test'):
    """Predict SQL for every example in ``data_loader`` and execute it.

    Args:
        data_loader: iterable of example batches.
        data_table: table lookup used by ``get_fields``.
        model, model_bert: SQL decoder and BERT encoder (switched to eval mode).
        bert_config, tokenizer, max_seq_length, num_target_layers: BERT
            encoding settings forwarded to ``get_wemb_bert``.
        detail, st_pos, cnt_tot: unused; kept for interface compatibility.
        EG: use execution-guided beam decoding (``model.beam_forward``).
        beam_size: beam width for execution-guided decoding.
        path_db, dset_name: the database ``<path_db>/<dset_name>.db`` used for
            execution-guided decoding and for computing answers.

    Returns:
        A list of dicts with keys ``query``, ``table_id``, ``nlu``, ``sql``,
        ``sql_with_params`` and ``answer``, one per example.
    """
    import torch

    model.eval()
    model_bert.eval()
    engine = DBEngine(os.path.join(path_db, f"{dset_name}.db"))
    results = []

    # Pure inference: avoid building autograd graphs.
    with torch.no_grad():
        for iB, t in enumerate(data_loader):
            # Ground-truth labels are not needed for prediction, so get_g /
            # get_g_wvi_corenlp are not called here.
            nlu, nlu_t, sql_i, sql_q, sql_t, tb, hs_t, hds = get_fields(
                t, data_table, no_hs_t=True, no_sql_t=True)

            wemb_n, wemb_h, l_n, l_hpu, l_hs, \
                nlu_tt, t_to_tt_idx, tt_to_t_idx \
                = get_wemb_bert(bert_config, model_bert, tokenizer, nlu_t, hds,
                                max_seq_length,
                                num_out_layers_n=num_target_layers,
                                num_out_layers_h=num_target_layers)

            if not EG:
                # No execution-guided decoding.
                s_sc, s_sa, s_wn, s_wc, s_wo, s_wv = model(wemb_n, l_n, wemb_h, l_hpu, l_hs)
                pr_sc, pr_sa, pr_wn, pr_wc, pr_wo, pr_wvi = pred_sw_se(
                    s_sc, s_sa, s_wn, s_wc, s_wo, s_wv)
                pr_wv_str, pr_wv_str_wp = convert_pr_wvi_to_string(
                    pr_wvi, nlu_t, nlu_tt, tt_to_t_idx, nlu)
                pr_sql_i = generate_sql_i(pr_sc, pr_sa, pr_wn, pr_wc, pr_wo, pr_wv_str, nlu)
            else:
                # Execution-guided decoding.
                prob_sca, prob_w, prob_wn_w, pr_sc, pr_sa, pr_wn, pr_sql_i = model.beam_forward(
                    wemb_n, l_n, wemb_h, l_hpu, l_hs, engine, tb,
                    nlu_t, nlu_tt, tt_to_t_idx, nlu, beam_size=beam_size)
                pr_wc, pr_wo, pr_wv, pr_sql_i = sort_and_generate_pr_w(pr_sql_i)

            pr_sql_q = generate_sql_q(pr_sql_i, tb)
            pr_sql_q_base = generate_sql_q_base(pr_sql_i, tb)

            for b, (pr_sql_i1, pr_sql_q1, pr_sql_q1_base) in enumerate(
                    zip(pr_sql_i, pr_sql_q, pr_sql_q_base)):
                results1 = {}
                results1["query"] = pr_sql_i1
                results1["table_id"] = tb[b]["id"]
                results1["nlu"] = nlu[b]
                results1["sql"] = pr_sql_q1
                results1["sql_with_params"] = pr_sql_q1_base
                results1["answer"] = engine.execute_query(
                    tb[b]["id"], Query.from_dict(pr_sql_i1, ordered=True), lower=False)
                results.append(results1)

    return results
def predict(data_loader, data_table, model, model_bert, bert_config, tokenizer,
            max_seq_length, num_target_layers, detail=False, st_pos=0,
            cnt_tot=1, EG=False, beam_size=4, path_db=None, dset_name='test'):
    """Predict SQL (select-number / condition-connector variant) per example.

    Args:
        data_loader: iterable of example batches.
        data_table: table lookup used by ``get_fields`` (called with
            ``result=True``).
        model, model_bert: SQL decoder and BERT encoder (switched to eval mode).
        bert_config, tokenizer, max_seq_length, num_target_layers: BERT
            encoding settings forwarded to ``get_wemb_bert``.
        EG: execution-guided decoding is not implemented for this variant.
        detail, st_pos, cnt_tot, beam_size, path_db, dset_name: unused; kept
            for interface compatibility.

    Returns:
        A list of dicts with keys ``query``, ``table_id``, ``nlu`` and ``sql``.
        ``query['sel']`` is overwritten with the raw predicted select column(s).

    Raises:
        NotImplementedError: if ``EG`` is true. The previous implementation
            silently returned ``None`` in that case.
    """
    import torch

    if EG:
        raise NotImplementedError(
            "Execution-guided decoding is not supported by this predict().")

    model.eval()
    model_bert.eval()
    results = []

    # Pure inference: avoid building autograd graphs.
    with torch.no_grad():
        for iB, t in enumerate(data_loader):
            nlu, nlu_t, sql_i, sql_q, sql_t, tb, hs_t, hds = get_fields(
                t, data_table, no_hs_t=True, no_sql_t=True, result=True)

            wemb_n, wemb_h, l_n, l_hpu, l_hs, \
                nlu_tt, t_to_tt_idx, tt_to_t_idx \
                = get_wemb_bert(bert_config, model_bert, tokenizer, nlu_t, hds,
                                max_seq_length,
                                num_out_layers_n=num_target_layers,
                                num_out_layers_h=num_target_layers)

            s_sn, s_sc, s_sa, s_wn, s_wc, s_wo, s_wv, s_wco = model(
                wemb_n, l_n, wemb_h, l_hpu, l_hs)
            pr_sn, pr_sc, pr_sa, pr_wn, pr_wc, pr_wo, pr_wvi, dwpr_sc, pr_wco = pred_sw_se(
                s_sn, s_sc, s_sa, s_wn, s_wc, s_wo, s_wv, s_wco, typ=True)
            pr_wv_str, pr_wv_str_wp = convert_pr_wvi_to_string(
                pr_wvi, nlu_t, nlu_tt, tt_to_t_idx, nlu)
            pr_sql_i = generate_sql_i(pr_sn, pr_sc, pr_sa, pr_wn, pr_wc, pr_wo,
                                      pr_wv_str, pr_wco, nlu)

            # SQL text is generated before 'sel' is overwritten below.
            pr_sql_q = generate_sql_q(pr_sql_i, tb)
            for b, (pr_sql_i1, pr_sql_q1) in enumerate(zip(pr_sql_i, pr_sql_q)):
                pr_sql_i1['sel'] = pr_sc[b]
                results1 = {}
                results1["query"] = pr_sql_i1
                results1["table_id"] = tb[b]["id"]
                results1["nlu"] = nlu[b]
                results1["sql"] = pr_sql_q1
                results.append(results1)

    return results
def test(
    data_loader,
    data_table,
    model,
    model_bert,
    tokenizer,
    sql_vocab,
    max_seq_length,
    detail=False,
    st_pos=0,
    cnt_tot=1,
    EG=False,
    beam_only=True,
    beam_size=4,
    path_db=None,
    dset_name='test',
    col_pool_type='start_tok',
    aug=False,
):
    """Evaluate the seq-to-seq pointer model on one split.

    Args:
        data_loader: iterable of example batches.
        data_table: table lookup used by ``get_fields``.
        model, model_bert: pointer decoder and BERT encoder (eval mode).
        tokenizer, sql_vocab, max_seq_length: encoding settings for
            ``get_bert_output_s2s``.
        detail: print ground-truth / predicted pointers and SQL per batch.
        st_pos: batches are skipped until the running example count reaches
            this value.
        EG: use ``model.EG_forward`` (execution-guided decoding).
        beam_only: with ``EG``, selects the placeholder loss value (0 or 1).
        beam_size: beam width for ``EG_forward``.
        path_db, dset_name: database ``<path_db>/<dset_name>.db``.
        col_pool_type: forwarded to ``gen_g_pnt_idx``.
        cnt_tot, aug: unused; kept for interface compatibility.

    Returns:
        ``(acc, results)``. ``acc = [ave_loss, acc_lx, acc_x]``, each divided by
        the number of examples seen, including skipped ones. ``acc_x`` is always
        0 because execution accuracy is disabled here. ``results`` holds one dict
        per example.
    """
    model.eval()
    model_bert.eval()

    ave_loss = 0
    cnt = 0
    cnt_lx = 0
    cnt_x = 0
    results = []
    engine = DBEngine(os.path.join(path_db, f"{dset_name}.db"))

    # Evaluation only: no gradients are needed.
    with torch.no_grad():
        for iB, t in enumerate(tqdm(data_loader)):
            cnt += len(t)
            if cnt < st_pos:
                continue

            nlu, nlu_t, sql_i, sql_q, sql_t, tb, hs_t, hds = get_fields(
                t, data_table, no_hs_t=True, no_sql_t=True)
            g_wvi_corenlp = get_g_wvi_corenlp(t)

            all_encoder_layer, pooled_output, tokens, i_nlu, i_hds, i_sql_vocab, \
                l_n, l_hpu, l_hs, l_input, \
                nlu_tt, t_to_tt_idx, tt_to_t_idx \
                = get_bert_output_s2s(model_bert, tokenizer, nlu_t, hds, sql_vocab,
                                      max_seq_length, sample=False)

            try:
                g_wvi = get_g_wvi_bert_from_g_wvi_corenlp(t_to_tt_idx, g_wvi_corenlp)
            except Exception:
                # The where-value could not be located in the BERT tokens.
                # These examples are recorded as errors. They stay in `cnt`,
                # so they count as wrong.
                for b in range(len(nlu)):
                    results1 = {}
                    results1["error"] = "Skip happened"
                    results1["nlu"] = nlu[b]
                    results1["table_id"] = tb[b]["id"]
                    results.append(results1)
                continue

            g_pnt_idxs = gen_g_pnt_idx(g_wvi, sql_i, i_hds, i_sql_vocab,
                                       col_pool_type=col_pool_type)
            pnt_start_tok = i_sql_vocab[0][-2][0]
            pnt_end_tok = i_sql_vocab[0][-1][0]

            wenc_s2s = all_encoder_layer[-1]
            cls_vec = pooled_output

            if not EG:
                score = model(wenc_s2s, l_input, cls_vec, pnt_start_tok)
                loss = Loss_s2s(score, g_pnt_idxs)
                pr_pnt_idxs = pred_pnt_idxs(score, pnt_start_tok, pnt_end_tok)
            else:
                pr_pnt_idxs, p_list, pnt_list_beam = model.EG_forward(
                    wenc_s2s, l_input, cls_vec,
                    pnt_start_tok, pnt_end_tok,
                    i_sql_vocab, i_nlu, i_hds,
                    tokens, nlu, nlu_t, hds, tt_to_t_idx,
                    tb, engine,
                    beam_size, beam_only=beam_only)
                # No loss is computed in EG mode; a constant placeholder is used.
                loss = torch.tensor([0]) if beam_only else torch.tensor([1])

            pr_i_vg_list, pr_i_vg_sub_list = gen_i_vg_from_pnt_idxs(
                pr_pnt_idxs, i_sql_vocab, i_nlu, i_hds)
            pr_sql_q_s2s, pr_sql_i = gen_sql_q_from_i_vg(
                tokens, nlu, nlu_t, hds, tt_to_t_idx, pnt_start_tok, pnt_end_tok,
                pr_pnt_idxs, pr_i_vg_list, pr_i_vg_sub_list)

            g_sql_q = generate_sql_q(sql_i, tb)
            try:
                pr_sql_q = generate_sql_q(pr_sql_i, tb)
            except Exception:
                pr_sql_q = ['NA'] * len(sql_i)

            for b, pr_sql_i1 in enumerate(pr_sql_i):
                results1 = {}
                results1["query"] = pr_sql_i1
                results1["table_id"] = tb[b]["id"]
                results1["nlu"] = nlu[b]
                results.append(results1)

            # Logical-form accuracy is measured on pointer sequences.
            cnt_lx1_list = get_cnt_lx_list_s2s(g_pnt_idxs, pr_pnt_idxs)
            # Execution accuracy is disabled: every example scores 0.
            cnt_x1_list = [0] * len(t)

            ave_loss += loss.item()
            cnt_lx += sum(cnt_lx1_list)
            cnt_x += sum(cnt_x1_list)

            if detail:
                print(f"Ground T : {g_pnt_idxs}")
                print(f"Prediction: {pr_pnt_idxs}")
                print(f"Ground T : {g_sql_q}")
                print(f"Prediction: {pr_sql_q}")

    ave_loss /= cnt
    acc_lx = cnt_lx / cnt
    acc_x = cnt_x / cnt

    acc = [ave_loss, acc_lx, acc_x]
    return acc, results
def train(train_loader, train_table, model, model_bert, opt, bert_config, tokenizer,
          max_seq_length, num_target_layers, accumulate_gradients=1, check_grad=True,
          st_pos=0, opt_bert=None, path_db=None, dset_name='train', mvl=2):
    """Run one training epoch of the multi-select / multi-condition model.

    Args:
        train_loader: iterable of example batches.
        train_table: dict mapping table id to the whole table.
        model, model_bert: SQL decoder and BERT encoder (switched to train mode).
        opt, opt_bert: optimizers for the decoder and (optionally) BERT.
        bert_config, tokenizer, max_seq_length, num_target_layers: BERT
            encoding settings forwarded to ``get_wemb_bert``.
        accumulate_gradients: number of batches to accumulate before a step.
        check_grad: unused; kept for interface compatibility.
        st_pos: batches are skipped until the running example count reaches
            this value.
        path_db, dset_name: database ``<path_db>/<dset_name>/<dset_name>.db``.
        mvl: maximum where-value span length. A batch containing a longer
            ground-truth span is skipped and removed from the example count.

    Returns:
        ``(acc, aux_out)``. ``acc`` is ``[ave_loss, acc_sn, acc_sc, acc_sa,
        acc_wn, acc_wr, acc_wc, acc_wo, acc_wvi, acc_wv, acc_lx, acc_x]``, each
        divided by the example count. ``aux_out`` is the constant ``1``.
    """
    model.train()
    model_bert.train()

    ave_loss = 0
    cnt = 0  # number of examples
    cnt_sn = 0  # correct select-number predictions
    cnt_sc = 0  # correct select-column predictions
    cnt_sa = 0  # correct select-aggregation predictions
    cnt_wn = 0  # correct where-number predictions
    cnt_wr = 0  # correct where-relation predictions
    cnt_wc = 0  # correct where-column predictions
    cnt_wo = 0  # correct where-operator predictions
    cnt_wv = 0  # correct where-value predictions
    cnt_wvi = 0  # correct where-value index predictions (question tokens)
    cnt_lx = 0  # logical-form matches
    cnt_x = 0  # execution matches

    # Engine for SQL querying.
    engine = DBEngine(os.path.join(path_db, dset_name, f"{dset_name}.db"))

    for iB, t in enumerate(train_loader):
        cnt += len(t)
        if cnt < st_pos:
            continue

        # nlu: question text; nlu_t: tokenized question; sql_i: canonical
        # SQL; tb: tables; hds: headers. sql_q, sql_t and hs_t are unused.
        nlu, nlu_t, sql_i, sql_q, sql_t, tb, hs_t, hds = get_fields(
            t, train_table, no_hs_t=True, no_sql_t=True, generate_mode=False)

        try:
            g_sn, g_sc, g_sa, g_wn, g_wr, g_dwn, g_wc, g_wo, g_wv, g_wrcn, wvi_change_index = get_g(
                sql_i)
            # Indices of the where-values within the question tokens.
            g_wvi_corenlp = get_g_wvi_corenlp(t, wvi_change_index)

            wemb_n, wemb_h, l_n, l_hpu, l_hs, \
                nlu_tt, t_to_tt_idx, tt_to_t_idx, wemb_v, l_npu, l_token \
                = get_wemb_bert(bert_config, model_bert, tokenizer, nlu_t, hds,
                                max_seq_length,
                                num_out_layers_n=num_target_layers,
                                num_out_layers_h=num_target_layers,
                                num_out_layers_v=num_target_layers)

            g_wvi = g_wvi_corenlp
            if g_wvi:
                for L in g_wvi:
                    for e in L:
                        # Only where-values of at most `mvl` tokens are trained.
                        if e[1] - e[0] + 1 > mvl:
                            cnt -= len(t)
                            print('error: ', e)
                            raise RuntimeError('invalid training set')
            # Do not sort here: sorting would break the correspondence
            # between the two sequences.
            g_wvi = get_g_wvi_stidx_length_jian_yi(g_wvi)
        except Exception:
            # The batch is skipped when labels cannot be aligned with the
            # question or a where-value exceeds `mvl`.
            # NOTE(review): this also hides errors raised by get_wemb_bert.
            # `except Exception` (not a bare except) still lets Ctrl-C through.
            continue

        # score
        s_sn, s_sc, s_sa, s_wn, s_wr, s_hrpc, s_wc, s_wo, s_wv1, s_wv2, s_wv3, s_wv4 = model(
            mvl, wemb_n, l_n, wemb_h, l_hpu, l_hs, wemb_v, l_npu, l_token,
            g_sn=g_sn, g_sc=g_sc, g_sa=g_sa, g_wn=g_wn, g_dwn=g_dwn, g_wr=g_wr,
            g_wc=g_wc, g_wo=g_wo, g_wvi=g_wvi, g_wrcn=g_wrcn)

        # Calculate loss & step
        loss = Loss_sw_se(s_sn, s_sc, s_sa, s_wn, s_wr, s_hrpc, s_wc, s_wo,
                          s_wv1, s_wv2, s_wv3, s_wv4,
                          g_sn, g_sc, g_sa, g_wn, g_dwn, g_wr, g_wc, g_wo, g_wvi,
                          g_wrcn, mvl)

        # Gradient accumulation over `accumulate_gradients` batches.
        if iB % accumulate_gradients == 0:
            # At the start of a window, reset gradients.
            opt.zero_grad()
            if opt_bert:
                opt_bert.zero_grad()
            loss.backward()
            if accumulate_gradients == 1:
                opt.step()
                if opt_bert:
                    opt_bert.step()
        elif iB % accumulate_gradients == (accumulate_gradients - 1):
            # At the end of a window, step with the accumulated gradients.
            loss.backward()
            opt.step()
            if opt_bert:
                opt_bert.step()
        else:
            # In the middle of a window, only accumulate.
            loss.backward()

        # Prediction
        pr_sn, pr_sc, pr_sa, pr_wn, pr_wr, pr_hrpc, pr_wc, pr_wo, pr_wvi = pred_sw_se(
            s_sn, s_sc, s_sa, s_wn, s_wr, s_hrpc, s_wc, s_wo,
            s_wv1, s_wv2, s_wv3, s_wv4, mvl)
        pr_wvi_decode = g_wvi_decoder_stidx_length_jian_yi(pr_wvi)
        pr_wv_str, pr_wv_str_wp = convert_pr_wvi_to_string(
            pr_wvi_decode, nlu_t, nlu_tt, tt_to_t_idx)

        # During training pr_wo and pr_wvi are predicted using the
        # ground-truth columns, so the predicted columns are sorted to match
        # before building the SQL.
        pr_sc_sorted = sort_pr_wc(pr_sc, g_sc)
        pr_wc_sorted = sort_pr_wc(pr_wc, g_wc)
        pr_sql_i = generate_sql_i(pr_sc_sorted, pr_sa, pr_wn, pr_wr, pr_wc_sorted,
                                  pr_wo, pr_wv_str, nlu)

        # Calculate accuracy
        cnt_sn1_list, cnt_sc1_list, cnt_sa1_list, cnt_wn1_list, \
            cnt_wr1_list, cnt_wc1_list, cnt_wo1_list, \
            cnt_wvi1_list, cnt_wv1_list = get_cnt_sw_list(
                g_sn, g_sc, g_sa, g_wn, g_wr, g_wc, g_wo, g_wvi,
                pr_sn, pr_sc, pr_sa, pr_wn, pr_wr, pr_wc, pr_wo, pr_wvi,
                sql_i, pr_sql_i, mode='train')

        # lx stands for logical-form accuracy.
        cnt_lx1_list = get_cnt_lx_list(cnt_sn1_list, cnt_sc1_list, cnt_sa1_list,
                                       cnt_wn1_list, cnt_wr1_list, cnt_wc1_list,
                                       cnt_wo1_list, cnt_wv1_list)
        # Execution accuracy.
        cnt_x1_list, g_ans, pr_ans = get_cnt_x_list(engine, tb, g_sc, g_sa, sql_i,
                                                    pr_sc, pr_sa, pr_sql_i)

        # statistics
        ave_loss += loss.item()

        # count
        cnt_sn += sum(cnt_sn1_list)
        cnt_sc += sum(cnt_sc1_list)
        cnt_sa += sum(cnt_sa1_list)
        cnt_wn += sum(cnt_wn1_list)
        cnt_wr += sum(cnt_wr1_list)
        cnt_wc += sum(cnt_wc1_list)
        cnt_wo += sum(cnt_wo1_list)
        cnt_wvi += sum(cnt_wvi1_list)
        cnt_wv += sum(cnt_wv1_list)
        cnt_lx += sum(cnt_lx1_list)
        cnt_x += sum(cnt_x1_list)

        if iB % 200 == 0:
            logger.info(
                '%d - th data batch -> loss: %.4f; acc_sn: %.4f; acc_sc: %.4f; acc_sa: %.4f; acc_wn: %.4f; acc_wr: %.4f; acc_wc: %.4f; acc_wo: %.4f; acc_wvi: %.4f; acc_wv: %.4f; acc_lx: %.4f; acc_x %.4f;'
                % (iB, ave_loss / cnt, cnt_sn / cnt, cnt_sc / cnt, cnt_sa / cnt,
                   cnt_wn / cnt, cnt_wr / cnt, cnt_wc / cnt, cnt_wo / cnt,
                   cnt_wvi / cnt, cnt_wv / cnt, cnt_lx / cnt, cnt_x / cnt))

    ave_loss = ave_loss / cnt
    acc_sn = cnt_sn / cnt
    acc_sc = cnt_sc / cnt
    acc_sa = cnt_sa / cnt
    acc_wn = cnt_wn / cnt
    acc_wr = cnt_wr / cnt
    acc_wc = cnt_wc / cnt
    acc_wo = cnt_wo / cnt
    acc_wvi = cnt_wvi / cnt
    acc_wv = cnt_wv / cnt
    acc_lx = cnt_lx / cnt
    acc_x = cnt_x / cnt

    acc = [
        ave_loss, acc_sn, acc_sc, acc_sa, acc_wn, acc_wr, acc_wc, acc_wo,
        acc_wvi, acc_wv, acc_lx, acc_x
    ]
    aux_out = 1
    return acc, aux_out
def test(data_loader, data_table, model, model_bert, bert_config, tokenizer,
         max_seq_length, num_target_layers, detail=False, st_pos=0, cnt_tot=1,
         EG=False, beam_size=4, path_db=None, dset_name='test'):
    """Evaluate the model (``model.gen_query`` variant) on one split.

    Ground truth is used only for scoring. Predicted where-columns are NOT
    reordered with ``sort_pr_wc(pr_wc, g_wc)``; that sort is only needed
    in training, where pr_wo/pr_wvi come from the ground-truth columns.

    Args:
        data_loader: iterable of example batches.
        data_table: table lookup used by ``get_fields``.
        model, model_bert: SQL decoder and BERT encoder (eval mode).
        bert_config, tokenizer, max_seq_length, num_target_layers: BERT
            encoding settings forwarded to ``get_wemb_bert``.
        st_pos: batches are skipped until the running example count reaches
            this value.
        EG: use ``model.beam_forward`` (execution-guided decoding).
        beam_size: beam width for execution-guided decoding.
        path_db, dset_name: database ``<path_db>/<dset_name>.db``.
        detail, cnt_tot: unused; kept for interface compatibility.

    Returns:
        ``((one_acc, tot_acc, ex_acc), results)``. The accuracies are divided
        by the number of evaluated examples, including examples skipped
        because their where-value could not be located; those count as wrong.
    """
    import torch

    model.eval()
    model_bert.eval()
    print('g_scn/wc/wn/wo dev/test不监督')
    cnt = 0
    engine = DBEngine(os.path.join(path_db, f"{dset_name}.db"))
    results = []
    total = 0
    one_acc_num, tot_acc_num, ex_acc_num = 0.0, 0.0, 0.0

    # Evaluation only: no gradients are needed.
    with torch.no_grad():
        for iB, t in enumerate(data_loader):
            cnt += len(t)
            if cnt < st_pos:
                continue

            nlu, nlu_t, sql_i, sql_t, tb, hs_t, hds = get_fields(
                t, data_table, no_hs_t=True, no_sql_t=True)
            g_wvi_corenlp = get_g_wvi_corenlp(t)

            wemb_n, wemb_h, l_n, l_hpu, l_hs, \
                nlu_tt, t_to_tt_idx, tt_to_t_idx \
                = get_wemb_bert(bert_config, model_bert, tokenizer, nlu_t, hds,
                                max_seq_length,
                                num_out_layers_n=num_target_layers,
                                num_out_layers_h=num_target_layers)

            try:
                # Only used to detect examples whose where-value cannot be
                # located in the BERT tokens.
                get_g_wvi_bert_from_g_wvi_corenlp(t_to_tt_idx, g_wvi_corenlp)
            except Exception:
                for b in range(len(nlu)):
                    results1 = {}
                    results1["error"] = "Skip happened"
                    results1["nlu"] = nlu[b]
                    results1["table_id"] = tb[b]["id"]
                    results.append(results1)
                # Skipped examples count as wrongly answered.
                total += len(nlu)
                continue

            if not EG:
                # No execution-guided decoding.
                s_scn, s_sc, s_sa, s_sop, s_wn, s_wc, s_wo, s_wv = model(
                    wemb_n, l_n, wemb_h, l_hpu, l_hs)
                score = (s_scn, s_sc, s_sa, s_sop)
                pr_sql_i1 = model.gen_query(score, nlu_tt, nlu)
                pr_wn, pr_wc, pr_sop, pr_wo, pr_wvi = pred_sw_se(
                    s_sop, s_wn, s_wc, s_wo, s_wv)
                # Map predicted where-value indices to strings.
                pr_wv_str, pr_wv_str_wp = convert_pr_wvi_to_string(
                    pr_wvi, nlu_t, nlu_tt, tt_to_t_idx, nlu)
                pr_sql_i = generate_sql_i(pr_sql_i1, pr_wn, pr_wc, pr_wo,
                                          pr_wv_str, nlu)
            else:
                # Execution-guided decoding.
                prob_sca, prob_w, prob_wn_w, pr_sc, pr_sa, pr_wn, pr_sql_i = model.beam_forward(
                    wemb_n, l_n, wemb_h, l_hpu, l_hs, engine, tb,
                    nlu_t, nlu_tt, tt_to_t_idx, nlu, beam_size=beam_size)
                pr_wc, pr_wo, pr_wv, pr_sql_i = sort_and_generate_pr_w(pr_sql_i)

            # Saved for the official evaluation later.
            for b, pr_sql_i1 in enumerate(pr_sql_i):
                results1 = {}
                results1["query"] = pr_sql_i1
                results1["table_id"] = tb[b]["id"]
                results1["nlu"] = nlu[b]
                results.append(results1)

            one_err, tot_err = model.check_acc(nlu, pr_sql_i, sql_i)
            one_acc_num += (len(pr_sql_i) - one_err)
            tot_acc_num += (len(pr_sql_i) - tot_err)
            total += len(pr_sql_i)

            # Execution accuracy: a prediction that fails to execute counts
            # as a mismatch.
            table_ids = [tb1['id'] for tb1 in tb]
            for sql_gt, sql_pred, tid in zip(sql_i, pr_sql_i, table_ids):
                ret_gt = engine.execute(tid, sql_gt['sel'], sql_gt['agg'],
                                        sql_gt['conds'], sql_gt['cond_conn_op'])
                try:
                    ret_pred = engine.execute(tid, sql_pred['sel'], sql_pred['agg'],
                                              sql_pred['conds'], sql_pred['cond_conn_op'])
                except Exception:
                    ret_pred = None
                ex_acc_num += (ret_gt == ret_pred)

    return ((one_acc_num / total), (tot_acc_num / total), ex_acc_num / total), results
def train(train_loader, train_table, model, model_bert, opt, scheduler, tokenizer,
          sql_vocab, max_seq_length, accumulate_gradients=1, check_grad=False,
          st_pos=0, opt_bert=None, bert_scheduler=None, path_db=None,
          dset_name='train', col_pool_type='start_tok', aug=False):
    """Run one training epoch of the seq-to-seq pointer model.

    Each batch takes one optimizer step. ``custom_regularization`` is applied
    to the loss using a snapshot of ``model_bert`` taken at the start of the
    epoch.

    Args:
        train_loader: iterable of example batches.
        train_table: table lookup used by ``get_fields``.
        model, model_bert: pointer decoder and BERT encoder (train mode).
        opt, scheduler: decoder optimizer and its LR scheduler.
        tokenizer, sql_vocab, max_seq_length: encoding settings for
            ``get_bert_output_s2s``.
        check_grad: collect gradient statistics of ``model`` after each step.
        st_pos: batches are skipped until the running example count reaches
            this value.
        opt_bert, bert_scheduler: optional BERT optimizer and scheduler;
            skipped when ``None``.
        col_pool_type: forwarded to ``gen_g_pnt_idx``.
        accumulate_gradients, path_db, dset_name, aug: unused; kept for
            interface compatibility.

    Returns:
        ``(acc, aux_out)``. ``acc = [ave_loss, acc_lx, acc_x]``; ``acc_x`` is
        always 0 because execution accuracy is disabled. ``aux_out`` holds the
        gradient statistics from the last batch, or 1s when ``check_grad`` is
        off or no batch was trained.
    """
    model.train()
    model_bert.train()
    # Snapshot of the encoder at epoch start, passed to custom_regularization
    # as the reference model.
    model_old = deepcopy(model_bert)

    ave_loss = 0
    cnt = 0  # number of examples
    cnt_x = 0
    cnt_lx = 0  # logical-form matches
    # Defaults so aux_out is always defined.
    grad_abs_mean_mean = 1
    grad_abs_mean_sig = 1
    grad_abs_sig_mean = 1

    for iB, t in enumerate(tqdm(train_loader)):
        cnt += len(t)
        if opt_bert:
            opt_bert.zero_grad()
        opt.zero_grad()

        if cnt < st_pos:
            continue

        nlu, nlu_t, sql_i, sql_q, sql_t, tb, hs_t, hds = get_fields(
            t, train_table, no_hs_t=True, no_sql_t=True)
        # Ground-truth where-value indices under CoreNLP tokenization.
        g_wvi_corenlp = get_g_wvi_corenlp(t)

        all_encoder_layer, pooled_output, tokens, i_nlu, i_hds, i_sql_vocab, \
            l_n, l_hpu, l_hs, l_input, \
            nlu_tt, t_to_tt_idx, tt_to_t_idx \
            = get_bert_output_s2s(model_bert, tokenizer, nlu_t, hds, sql_vocab,
                                  max_seq_length, sample=True)

        try:
            g_wvi = get_g_wvi_bert_from_g_wvi_corenlp(t_to_tt_idx, g_wvi_corenlp)
        except Exception:
            # The where-value was not found in the BERT tokens, so this batch
            # is not used for training.
            continue

        g_pnt_idxs = gen_g_pnt_idx(g_wvi, sql_i, i_hds, i_sql_vocab,
                                   col_pool_type=col_pool_type)
        pnt_start_tok = i_sql_vocab[0][-2][0]
        pnt_end_tok = i_sql_vocab[0][-1][0]

        wenc_s2s = all_encoder_layer[-1]
        cls_vec = pooled_output

        score = model(wenc_s2s, l_input, cls_vec, pnt_start_tok, g_pnt_idxs=g_pnt_idxs)

        # Calculate loss & step.
        loss = Loss_s2s(score, g_pnt_idxs)
        # NOTE(review): `args` is a global defined outside this function.
        loss = custom_regularization(model_old, model_bert, args.bS, loss)
        loss.backward()
        if opt_bert:
            opt_bert.step()
        opt.step()
        if bert_scheduler:
            bert_scheduler.step()
        scheduler.step()

        if check_grad:
            named_parameters = model.named_parameters()
            mu_list, sig_list = get_mean_grad(named_parameters)
            grad_abs_mean_mean = mean(mu_list)
            grad_abs_mean_sig = std(mu_list)
            grad_abs_sig_mean = mean(sig_list)

        # Logical-form accuracy is measured on pointer sequences, so the
        # prediction does not need to be decoded into SQL here.
        pr_pnt_idxs = pred_pnt_idxs(score, pnt_start_tok, pnt_end_tok)
        cnt_lx1_list = get_cnt_lx_list_s2s(g_pnt_idxs, pr_pnt_idxs)
        # Execution accuracy is disabled: every example scores 0.
        cnt_x1_list = [0] * len(t)

        ave_loss += loss.item()
        cnt_lx += sum(cnt_lx1_list)
        cnt_x += sum(cnt_x1_list)

    ave_loss /= cnt
    acc_lx = cnt_lx / cnt
    acc_x = cnt_x / cnt

    acc = [ave_loss, acc_lx, acc_x]
    aux_out = [grad_abs_mean_mean, grad_abs_mean_sig, grad_abs_sig_mean]
    return acc, aux_out
def run_split(split):
    """Open ``<split>.db`` under ``path_db`` and parse each line of ``<split>_tok.jsonl``.

    NOTE(review): neither the engine nor the parsed records are used after
    they are created; this looks like an unfinished stub.
    """
    db_file = os.path.join(path_db, f"{split}.db")
    engine = DBEngine(db_file)
    tok_file = split + '_tok.jsonl'
    with open(tok_file) as handle:
        for raw_line in handle:
            t1 = json.loads(raw_line.strip())
def train(train_loader, train_table, model, model_bert, opt, bert_config, tokenizer,
          max_seq_length, num_target_layers, accumulate_gradients=1, check_grad=True,
          st_pos=0, opt_bert=None, path_db=None, dset_name='train'):
    """Run one training epoch of the SQLova-style model.

    Args:
        train_loader: iterable of example batches.
        train_table: table lookup used by ``get_fields``.
        model, model_bert: SQL decoder and BERT encoder (train mode).
        opt, opt_bert: optimizers for the decoder and (optionally) BERT.
        bert_config, tokenizer, max_seq_length, num_target_layers: BERT
            encoding settings forwarded to ``get_wemb_bert``.
        accumulate_gradients: number of batches to accumulate before a step.
        check_grad: unused; kept for interface compatibility.
        st_pos: batches are skipped until the running example count reaches
            this value.
        path_db, dset_name: database ``<path_db>/<dset_name>.db``.

    Returns:
        ``(acc, aux_out)``. ``acc`` is ``[ave_loss, acc_sc, acc_sa, acc_wn,
        acc_wc, acc_wo, acc_wvi, acc_wv, acc_lx, acc_x]``, each divided by the
        example count. ``aux_out`` is the constant ``1``.
    """
    model.train()
    model_bert.train()

    ave_loss = 0
    cnt = 0  # number of examples
    cnt_sc = 0  # correct select-column predictions
    cnt_sa = 0  # correct select-aggregation predictions
    cnt_wn = 0  # correct where-number predictions
    cnt_wc = 0  # correct where-column predictions
    cnt_wo = 0  # correct where-operator predictions
    cnt_wv = 0  # correct where-value predictions
    cnt_wvi = 0  # correct where-value index predictions (question tokens)
    cnt_lx = 0  # logical-form matches
    cnt_x = 0  # execution matches

    # Engine for SQL querying.
    engine = DBEngine(os.path.join(path_db, f"{dset_name}.db"))

    for iB, t in enumerate(train_loader):
        cnt += len(t)
        if cnt < st_pos:
            continue

        # nlu: question text; nlu_t: tokenized question; sql_i: canonical
        # SQL; tb: tables; hds: headers. sql_q, sql_t and hs_t are unused.
        nlu, nlu_t, sql_i, sql_q, sql_t, tb, hs_t, hds = get_fields(
            t, train_table, no_hs_t=True, no_sql_t=True)

        g_sc, g_sa, g_wn, g_wc, g_wo, g_wv = get_g(sql_i)
        # Ground-truth where-value indices under CoreNLP tokenization.
        g_wvi_corenlp = get_g_wvi_corenlp(t)

        # wemb_n: question embedding; wemb_h: header embedding;
        # l_n: question token lengths; l_hpu: header token lengths;
        # l_hs: number of headers per table.
        wemb_n, wemb_h, l_n, l_hpu, l_hs, \
            nlu_tt, t_to_tt_idx, tt_to_t_idx \
            = get_wemb_bert(bert_config, model_bert, tokenizer, nlu_t, hds,
                            max_seq_length,
                            num_out_layers_n=num_target_layers,
                            num_out_layers_h=num_target_layers)

        try:
            g_wvi = get_g_wvi_bert_from_g_wvi_corenlp(t_to_tt_idx, g_wvi_corenlp)
        except Exception:
            # The where-condition was not found in nlu_tt, so this batch is
            # not used for training.
            continue

        # score
        s_sc, s_sa, s_wn, s_wc, s_wo, s_wv = model(
            wemb_n, l_n, wemb_h, l_hpu, l_hs,
            g_sc=g_sc, g_sa=g_sa, g_wn=g_wn, g_wc=g_wc, g_wvi=g_wvi)

        # Calculate loss & step
        loss = Loss_sw_se(s_sc, s_sa, s_wn, s_wc, s_wo, s_wv,
                          g_sc, g_sa, g_wn, g_wc, g_wo, g_wvi)

        # Gradient accumulation over `accumulate_gradients` batches.
        if iB % accumulate_gradients == 0:
            # At the start of a window, reset gradients.
            opt.zero_grad()
            if opt_bert:
                opt_bert.zero_grad()
            loss.backward()
            if accumulate_gradients == 1:
                opt.step()
                if opt_bert:
                    opt_bert.step()
        elif iB % accumulate_gradients == (accumulate_gradients - 1):
            # At the end of a window, step with the accumulated gradients.
            loss.backward()
            opt.step()
            if opt_bert:
                opt_bert.step()
        else:
            # In the middle of a window, only accumulate.
            loss.backward()

        # Prediction
        pr_sc, pr_sa, pr_wn, pr_wc, pr_wo, pr_wvi = pred_sw_se(
            s_sc, s_sa, s_wn, s_wc, s_wo, s_wv)
        pr_wv_str, pr_wv_str_wp = convert_pr_wvi_to_string(
            pr_wvi, nlu_t, nlu_tt, tt_to_t_idx, nlu)

        # During training pr_wo and pr_wvi are predicted using the
        # ground-truth columns (g_wc), so pr_wc is sorted to match.
        pr_wc_sorted = sort_pr_wc(pr_wc, g_wc)
        pr_sql_i = generate_sql_i(pr_sc, pr_sa, pr_wn, pr_wc_sorted, pr_wo,
                                  pr_wv_str, nlu)

        # Calculate accuracy
        cnt_sc1_list, cnt_sa1_list, cnt_wn1_list, \
            cnt_wc1_list, cnt_wo1_list, \
            cnt_wvi1_list, cnt_wv1_list = get_cnt_sw_list(
                g_sc, g_sa, g_wn, g_wc, g_wo, g_wvi,
                pr_sc, pr_sa, pr_wn, pr_wc, pr_wo, pr_wvi,
                sql_i, pr_sql_i, mode='train')

        # lx stands for logical-form accuracy.
        cnt_lx1_list = get_cnt_lx_list(cnt_sc1_list, cnt_sa1_list, cnt_wn1_list,
                                       cnt_wc1_list, cnt_wo1_list, cnt_wv1_list)
        # Execution accuracy.
        cnt_x1_list, g_ans, pr_ans = get_cnt_x_list(engine, tb, g_sc, g_sa, sql_i,
                                                    pr_sc, pr_sa, pr_sql_i)

        # statistics
        ave_loss += loss.item()

        # count
        cnt_sc += sum(cnt_sc1_list)
        cnt_sa += sum(cnt_sa1_list)
        cnt_wn += sum(cnt_wn1_list)
        cnt_wc += sum(cnt_wc1_list)
        cnt_wo += sum(cnt_wo1_list)
        cnt_wvi += sum(cnt_wvi1_list)
        cnt_wv += sum(cnt_wv1_list)
        cnt_lx += sum(cnt_lx1_list)
        cnt_x += sum(cnt_x1_list)

    ave_loss /= cnt
    acc_sc = cnt_sc / cnt
    acc_sa = cnt_sa / cnt
    acc_wn = cnt_wn / cnt
    acc_wc = cnt_wc / cnt
    acc_wo = cnt_wo / cnt
    # Where-value *index* accuracy (previously divided cnt_wv by mistake).
    acc_wvi = cnt_wvi / cnt
    acc_wv = cnt_wv / cnt
    acc_lx = cnt_lx / cnt
    acc_x = cnt_x / cnt

    acc = [
        ave_loss, acc_sc, acc_sa, acc_wn, acc_wc, acc_wo, acc_wvi, acc_wv,
        acc_lx, acc_x
    ]
    aux_out = 1
    return acc, aux_out
def train(train_loader, train_table, model, model_bert, opt, bert_config, tokenizer,
          max_seq_length, num_target_layers, accumulate_gradients=1, check_grad=True,
          st_pos=0, opt_bert=None, path_db=None, dset_name='train'):
    """Run one training epoch over ``train_loader`` and report accuracies.

    Each batch is encoded with BERT (``get_wemb_bert``), scored by ``model``
    with the ground-truth select/where components passed in, and the loss from
    ``Loss_sw_se`` is back-propagated. Gradients are summed over
    ``accumulate_gradients`` consecutive batches before each optimizer step.
    The predictions of the same forward pass are used to count per-component,
    logical-form and execution matches.

    Args:
        train_loader: iterable of batches; each batch ``t`` is a list of raw
            examples (lines of ``*_tok.jsonl``).
        train_table: table lookup forwarded to ``get_fields``.
        model: decoder producing the select/where scores.
        model_bert: BERT encoder.
        opt: optimizer for ``model``.
        bert_config, tokenizer, max_seq_length, num_target_layers: forwarded
            to ``get_wemb_bert``.
        accumulate_gradients: number of batches per optimizer step.
        check_grad: unused; kept for interface compatibility.
        st_pos: batches are skipped while the running example count is below
            this value.
        opt_bert: optional optimizer for ``model_bert``.
        path_db, dset_name: the query engine is opened on
            ``<path_db>/<dset_name>.db``.

    Returns:
        ``(acc, aux_out)``. ``acc`` is ``[ave_loss, acc_sc, acc_sa, acc_wn,
        acc_wc, acc_wo, acc_wvi, acc_wv, acc_lx, acc_x]``, each divided by the
        number of examples counted. ``aux_out`` is always ``1``.
    """
    model.train()
    model_bert.train()

    ave_loss = 0
    cnt = 0  # number of examples counted
    cnt_sc = 0  # correct select columns
    cnt_sa = 0  # correct select aggregations
    cnt_wn = 0  # correct where-condition counts
    cnt_wc = 0  # correct where columns
    cnt_wo = 0  # correct where operators
    cnt_wv = 0  # correct where values
    cnt_wvi = 0  # correct where-value indices (on question tokens)
    cnt_lx = 0  # logical-form matches
    cnt_x = 0  # execution matches

    # Engine for SQL querying (execution accuracy).
    engine = DBEngine(os.path.join(path_db, f"{dset_name}.db"))

    for iB, t in enumerate(train_loader):
        cnt += len(t)
        if cnt < st_pos:
            continue

        # nlu: question text; nlu_t: tokenized question; sql_i: canonical SQL;
        # tb: tables with content; hds: table headers.
        # sql_q, sql_t and hs_t are not used here.
        nlu, nlu_t, sql_i, sql_q, sql_t, tb, hs_t, hds = get_fields(
            t, train_table, no_hs_t=True, no_sql_t=True)
        # Ground truth: select column / aggregation, number of where
        # conditions, where columns / operators / values.
        g_sc, g_sa, g_wn, g_wc, g_wo, g_wv = get_g(sql_i)
        # Ground-truth start/end indices of each where value in the
        # CoreNLP-tokenized question.
        g_wvi_corenlp = get_g_wvi_corenlp(t)

        # wemb_n: question embeddings; wemb_h: header embeddings (use l_hpu to
        # find the valid rows); l_n: question token lengths; l_hpu: header
        # token lengths; l_hs: number of headers per table.
        (wemb_n, wemb_h, l_n, l_hpu, l_hs,
         nlu_tt, t_to_tt_idx, tt_to_t_idx) = get_wemb_bert(
            bert_config, model_bert, tokenizer, nlu_t, hds, max_seq_length,
            num_out_layers_n=num_target_layers, num_out_layers_h=num_target_layers)

        try:
            g_wvi = get_g_wvi_bert_from_g_wvi_corenlp(t_to_tt_idx, g_wvi_corenlp)
        except Exception:
            # Raised when a where value cannot be located in nlu_tt. Example:
            # the condition value is 'block i' but the question says
            # 'block is I'; such examples carry "wvi_corenlp": null.
            # The whole batch is skipped.
            continue

        # Scores with teacher forcing (the ground-truth components are fed in).
        s_sc, s_sa, s_wn, s_wc, s_wo, s_wv = model(
            wemb_n, l_n, wemb_h, l_hpu, l_hs,
            g_sc=g_sc, g_sa=g_sa, g_wn=g_wn, g_wc=g_wc, g_wvi=g_wvi)

        loss = Loss_sw_se(s_sc, s_sa, s_wn, s_wc, s_wo, s_wv,
                          g_sc, g_sa, g_wn, g_wc, g_wo, g_wvi)

        # Gradient accumulation over `accumulate_gradients` batches.
        if iB % accumulate_gradients == 0:
            # Start of a window: clear the gradients of the previous step.
            opt.zero_grad()
            if opt_bert:
                opt_bert.zero_grad()
            loss.backward()
            if accumulate_gradients == 1:
                opt.step()
                if opt_bert:
                    opt_bert.step()
        elif iB % accumulate_gradients == (accumulate_gradients - 1):
            # End of a window: step with the accumulated gradients.
            loss.backward()
            opt.step()
            if opt_bert:
                opt_bert.step()
        else:
            # Middle of a window: only accumulate.
            loss.backward()

        # Prediction.
        pr_sc, pr_sa, pr_wn, pr_wc, pr_wo, pr_wvi = pred_sw_se(
            s_sc, s_sa, s_wn, s_wc, s_wo, s_wv)
        pr_wv_str, _ = convert_pr_wvi_to_string(
            pr_wvi, nlu_t, nlu_tt, tt_to_t_idx, nlu)
        # pr_wo and pr_wvi were predicted in the order of the ground-truth where
        # columns (g_wc was fed to the model). pr_wc must therefore be aligned
        # with that order before columns are paired with operators and values.
        # This is not needed in dev/test, where no ground truth is fed in.
        pr_wc_sorted = sort_pr_wc(pr_wc, g_wc)
        pr_sql_i = generate_sql_i(pr_sc, pr_sa, pr_wn, pr_wc_sorted, pr_wo,
                                  pr_wv_str, nlu)

        # Per-example correctness lists for each component.
        (cnt_sc1_list, cnt_sa1_list, cnt_wn1_list,
         cnt_wc1_list, cnt_wo1_list,
         cnt_wvi1_list, cnt_wv1_list) = get_cnt_sw_list(
            g_sc, g_sa, g_wn, g_wc, g_wo, g_wvi,
            pr_sc, pr_sa, pr_wn, pr_wc, pr_wo, pr_wvi,
            sql_i, pr_sql_i, mode='train')
        # lx stands for logical form accuracy.
        cnt_lx1_list = get_cnt_lx_list(cnt_sc1_list, cnt_sa1_list, cnt_wn1_list,
                                       cnt_wc1_list, cnt_wo1_list, cnt_wv1_list)
        # Execution accuracy test.
        cnt_x1_list, g_ans, pr_ans = get_cnt_x_list(
            engine, tb, g_sc, g_sa, sql_i, pr_sc, pr_sa, pr_sql_i)

        # Statistics.
        ave_loss += loss.item()
        cnt_sc += sum(cnt_sc1_list)
        cnt_sa += sum(cnt_sa1_list)
        cnt_wn += sum(cnt_wn1_list)
        cnt_wc += sum(cnt_wc1_list)
        cnt_wo += sum(cnt_wo1_list)
        cnt_wvi += sum(cnt_wvi1_list)
        cnt_wv += sum(cnt_wv1_list)
        cnt_lx += sum(cnt_lx1_list)
        cnt_x += sum(cnt_x1_list)

    ave_loss /= cnt
    acc_sc = cnt_sc / cnt
    acc_sa = cnt_sa / cnt
    acc_wn = cnt_wn / cnt
    acc_wc = cnt_wc / cnt
    acc_wo = cnt_wo / cnt
    # Where-value *index* accuracy (was mistakenly computed from cnt_wv).
    acc_wvi = cnt_wvi / cnt
    acc_wv = cnt_wv / cnt
    acc_lx = cnt_lx / cnt
    acc_x = cnt_x / cnt

    acc = [
        ave_loss, acc_sc, acc_sa, acc_wn, acc_wc, acc_wo, acc_wvi, acc_wv,
        acc_lx, acc_x
    ]

    aux_out = 1

    return acc, aux_out
def infernew(dev_loader, data_table, model, model_bert, bert_config, tokenizer, max_seq_length, num_target_layers, detail=False, path_db=None, st_pos=0, dset_name='train', EG=False, beam_size=4):
    """Predict SQL for each batch of ``dev_loader``, execute it and print the answer.

    Batched variant of ``infer`` that also passes per-example knowledge
    features to the model:
    - ``bertindex_knowledge``, replaced by ``max(l_n)`` zeros when absent;
    - ``header_knowledge``, replaced by ``max(l_hs)`` zeros when absent.

    Batches whose ground-truth where-value spans cannot be mapped onto BERT
    tokens are skipped (see the try/except below). The loader must therefore
    carry labelled examples.

    With ``EG`` the query comes from ``model.beam_forward`` (execution-guided
    decoding); otherwise a plain forward pass is used.

    Only the first example of each batch (index 0) is executed against
    ``<path_db>/<dset_name>.db`` and printed as ``Q:`` / ``A:`` / ``SQL:``.
    Any execution failure is reported as ``'Answer not found.'``.

    Returns:
        ``(pr_sql_i, pr_ans)`` of the last batch processed.

    NOTE(review): the following locals are assigned but never read:
    ``ave_loss``, ``cnt_*``, ``cnt_list``, ``count``, ``results`` (skip
    markers), ``loss`` and ``g_wv_str``. The parameters ``detail`` and
    ``st_pos`` are unused. An empty ``dev_loader`` leaves ``pr_sql_i`` and
    ``pr_ans`` unbound at the ``return``.
    NOTE(review): confirm that the execute/print step is meant to run per batch
    rather than once after the loop.
    """
    model.eval()
    model_bert.eval()

    ave_loss = 0
    cnt = 0
    cnt_sc = 0
    cnt_sa = 0
    cnt_wn = 0
    cnt_wc = 0
    cnt_wo = 0
    cnt_wv = 0
    cnt_wvi = 0
    cnt_lx = 0
    cnt_x = 0

    cnt_list = []

    engine = DBEngine(os.path.join(path_db, f"{dset_name}.db"))
    results = []
    count = 0
    for iB, t in enumerate(dev_loader):
        nlu, nlu_t, sql_i, sql_q, sql_t, tb, hs_t, hds = get_fields(
            t, data_table, no_hs_t=True, no_sql_t=True)

        g_sc, g_sa, g_wn, g_wc, g_wo, g_wv = get_g(sql_i)
        g_wvi_corenlp = get_g_wvi_corenlp(t)

        wemb_n, wemb_h, l_n, l_hpu, l_hs, \
        nlu_tt, t_to_tt_idx, tt_to_t_idx \
            = get_wemb_bert(bert_config, model_bert, tokenizer, nlu_t, hds, max_seq_length,
                            num_out_layers_n=num_target_layers, num_out_layers_h=num_target_layers)
        try:
            g_wvi = get_g_wvi_bert_from_g_wvi_corenlp(t_to_tt_idx, g_wvi_corenlp)
            g_wv_str, g_wv_str_wp = convert_pr_wvi_to_string(
                g_wvi, nlu_t, nlu_tt, tt_to_t_idx, nlu)

        except:
            # Exception happens when where-condition is not found in nlu_tt.
            # In this case the whole batch is skipped; a skip marker is
            # recorded for each example (results is never returned, though).
            for b in range(len(nlu)):
                results1 = {}
                results1["error"] = "Skip happened"
                results1["nlu"] = nlu[b]
                results1["table_id"] = tb[b]["id"]
                results.append(results1)
            continue

        # Per-example question knowledge; examples without it get a zero
        # vector as long as the longest question in the batch.
        knowledge = []
        for k in t:
            if "bertindex_knowledge" in k:
                knowledge.append(k["bertindex_knowledge"])
            else:
                knowledge.append(max(l_n) * [0])

        # Per-example header knowledge; examples without it get a zero vector
        # as long as the largest header count in the batch.
        knowledge_header = []
        for k in t:
            if "header_knowledge" in k:
                knowledge_header.append(k["header_knowledge"])
            else:
                knowledge_header.append(max(l_hs) * [0])

        if not EG:
            # No Execution guided decoding
            s_sc, s_sa, s_wn, s_wc, s_wo, s_wv = model(
                wemb_n, l_n, wemb_h, l_hpu, l_hs, knowledge=knowledge, knowledge_header=knowledge_header)

            # get loss & step (the loss is not used afterwards)
            loss = Loss_sw_se(s_sc, s_sa, s_wn, s_wc, s_wo, s_wv, g_sc, g_sa,
                              g_wn, g_wc, g_wo, g_wvi)

            # prediction
            pr_sc, pr_sa, pr_wn, pr_wc, pr_wo, pr_wvi = pred_sw_se(
                s_sc, s_sa, s_wn, s_wc, s_wo, s_wv, )
            pr_wv_str, pr_wv_str_wp = convert_pr_wvi_to_string(
                pr_wvi, nlu_t, nlu_tt, tt_to_t_idx, nlu)
            # g_sql_i = generate_sql_i(g_sc, g_sa, g_wn, g_wc, g_wo, g_wv_str, nlu)
            pr_sql_i = generate_sql_i(pr_sc, pr_sa, pr_wn, pr_wc, pr_wo,
                                      pr_wv_str, nlu)
        else:
            # Execution guided decoding
            prob_sca, prob_w, prob_wn_w, pr_sc, pr_sa, pr_wn, pr_sql_i = model.beam_forward(
                wemb_n, l_n, wemb_h, l_hpu, l_hs, engine, tb, nlu_t, nlu_tt, tt_to_t_idx, nlu,
                beam_size=beam_size, knowledge=knowledge, knowledge_header=knowledge_header)
            # sort and generate
            pr_wc, pr_wo, pr_wv, pr_sql_i = sort_and_generate_pr_w(pr_sql_i)

            # The following variables exist only for consistency with the no-EG case.
            pr_wvi = None  # not used
            pr_wv_str = None
            pr_wv_str_wp = None
            loss = torch.tensor([0])

        pr_sql_q1 = generate_sql_q(pr_sql_i, tb)
        pr_sql_q = [pr_sql_q1]

        # Execute only the first query of the batch; any failure is reported
        # as 'Answer not found.'.
        try:
            pr_ans, _ = engine.execute_return_query(tb[0]['id'], pr_sc[0],
                                                    pr_sa[0],
                                                    pr_sql_i[0]['conds'])
        except:
            pr_ans = ['Answer not found.']
            pr_sql_q = ['Answer not found.']

        # NOTE(review): `yes` is always True, so the `else` branch below is
        # unreachable; it also references `table_name`, which is not defined
        # in this function and would raise NameError if that branch ever ran.
        yes = True
        if yes:
            print(f'Q: {nlu[0]}')
            print(f'A: {pr_ans[0]}')
            print(f'SQL: {pr_sql_q}')

        else:
            print(
                f'START ============================================================= '
            )
            print(f'{hds}')
            if yes:
                print(engine.show_table(table_name))
            print(f'nlu: {nlu}')
            print(f'pr_sql_i : {pr_sql_i}')
            print(f'pr_sql_q : {pr_sql_q}')
            print(f'pr_ans: {pr_ans}')
            print(
                f'---------------------------------------------------------------------'
            )

    return pr_sql_i, pr_ans
def infer(nlu1, table_name, data_table, path_db, db_name, model, model_bert, bert_config,
          max_seq_length, num_target_layers, beam_size=4, show_table=False,
          show_answer_only=False):
    """Translate one question into SQL, execute it and print the result.

    This deliberately duplicates parts of the evaluation code (against DRY) to
    minimise the risk of introducing bugs there.

    Args:
        nlu1: the question. It is tokenized with ``list(nlu1)``, i.e. one
            token per character when it is a string.
        table_name: compared with each table's ``'name'`` in ``data_table``;
            the first match is used.
        data_table: iterable of table dicts (read keys: ``'name'``,
            ``'header'``, ``'id'``).
        path_db, db_name: the query engine is opened on
            ``<path_db>/<db_name>.db``.
        model, model_bert, bert_config, max_seq_length, num_target_layers:
            model and encoder settings forwarded to ``get_wemb_bert``/``model``.
        beam_size: unused since beam decoding was replaced by a plain forward
            pass; kept for interface compatibility.
        show_table: in verbose mode, also print the table contents.
        show_answer_only: print only question, answer and SQL.

    Returns:
        ``(pr_sql_i, pr_ans)``. ``pr_ans`` is ``['Answer not found.']`` when
        execution fails.

    Raises:
        ValueError: no table in ``data_table`` is named ``table_name``.
        EnvironmentError: decoding did not yield exactly one query.
    """
    model.eval()
    model_bert.eval()
    engine = DBEngine(os.path.join(path_db, f"{db_name}.db"))

    # Question input, batched as a list of one.
    nlu = [nlu1]
    # Tokenization history: 2020/12/02 fixed Chinese tokenization in infer;
    # 2020/12/11 dropped stanza/CoreNLP tokenization
    # (tokenize_corenlp_direct_version(client, nlu1)) in favour of splitting
    # the question into single characters.
    nlu_t = [list(nlu1)]

    # 2020/12/01: pick the table by name instead of always data_table[0].
    tb1 = next((table for table in data_table if table['name'] == table_name), None)
    if tb1 is None:
        raise ValueError(f"table {table_name!r} not found in data_table")
    tb = [tb1]
    hds = [tb1['header']]

    # BERT output.
    # NOTE(review): `tokenizer` is not a parameter; it is resolved from module
    # scope (presumably set by the calling script) - confirm.
    (wemb_n, wemb_h, l_n, l_hpu, l_hs,
     nlu_tt, t_to_tt_idx, tt_to_t_idx) = get_wemb_bert(
        bert_config, model_bert, tokenizer, nlu_t, hds, max_seq_length,
        num_out_layers_n=num_target_layers, num_out_layers_h=num_target_layers)

    # 2020/12/11: a plain forward pass replaced model.beam_forward(...).
    # Scores for the select/where components.
    s_sc, s_sa, s_wn, s_wc, s_wo, s_wv = model(wemb_n, l_n, wemb_h, l_hpu, l_hs)
    # Turn the scores into predicted values.
    pr_sc, pr_sa, pr_wn, pr_wc, pr_wo, pr_wvi = pred_sw_se(
        s_sc, s_sa, s_wn, s_wc, s_wo, s_wv)
    # Where-value strings from the predicted spans (character-level tokens here).
    pr_wv_str, _ = convert_pr_wvi_to_string(pr_wvi, nlu_t, nlu_tt, tt_to_t_idx, nlu)
    # Assemble the conditions. 2020/12/03: generate_sql_i emits agg and sel as
    # lists.
    pr_sql_i = generate_sql_i(pr_sc, pr_sa, pr_wn, pr_wc, pr_wo, pr_wv_str, nlu)
    # Split out where-column / where-operator / where-value.
    pr_wc, pr_wo, pr_wv, pr_sql_i = sort_and_generate_pr_w(pr_sql_i)
    if len(pr_sql_i) != 1:
        raise EnvironmentError(
            f"expected exactly one predicted query, got {len(pr_sql_i)}")

    # SQL text for the predicted query, wrapped in a list.
    pr_sql_q = [generate_sql_q(pr_sql_i, [tb1])]

    # Execute the query; any failure is reported rather than propagated.
    try:
        pr_ans, _ = engine.execute_return_query(tb[0]['id'], pr_sc[0], pr_sa[0],
                                                pr_sql_i[0]['conds'])
    except Exception:
        pr_ans = ['Answer not found.']
        pr_sql_q = ['Answer not found.']

    if show_answer_only:
        print(f'Q: {nlu[0]}')
        print(f'A: {pr_ans[0]}')
        print(f'SQL: {pr_sql_q}')
    else:
        print(
            f'START ============================================================= '
        )
        print(f'{hds}')
        if show_table:
            print(engine.show_table(table_name))
        print(f'nlu: {nlu}')
        print(f'pr_sql_i : {pr_sql_i}')
        print(f'pr_sql_q : {pr_sql_q}')
        print(f'pr_ans: {pr_ans}')
        print(
            f'---------------------------------------------------------------------'
        )

    return pr_sql_i, pr_ans
def test(data_loader, data_table, model, model_bert, bert_config, tokenizer, max_seq_length,
         num_target_layers, detail=False, st_pos=0, cnt_tot=1, path_db=None, dset_name='test'):
    """Evaluate the model variant that also predicts ``ao``/``ord`` and takes ``hs_type``.

    For every batch the model is run without teacher forcing, the loss is
    computed, and the predictions are compared against the ground truth.
    The comparison covers the individual components, the logical form and
    the execution result.

    Batches whose ground-truth where-value spans cannot be mapped onto BERT
    tokens are skipped. A skip marker is recorded for each of their examples.

    Args:
        data_loader: iterable of batches of raw examples.
        data_table: table lookup forwarded to ``get_fields``.
        st_pos: batches are skipped while the running example count is below
            this value.
        path_db, dset_name: the query engine is opened on
            ``<path_db>/<dset_name>.db``.
        detail, cnt_tot: unused; kept for interface compatibility.

    Returns:
        ``(acc, results, cnt_list)``, where:

        - ``acc`` is ``[ave_loss, acc_sc, acc_sa, acc_wn, acc_wc, acc_wo,
          acc_wvi, acc_wv, acc_ao, acc_ord, acc_lx, acc_x]``;
        - ``results`` holds one dict per example, either
          ``{"query", "table_id", "nlu"}`` or
          ``{"error": "Skip happened", "nlu", "table_id"}``;
        - ``cnt_list`` holds one list of per-example correctness lists per
          evaluated batch.

    NOTE(review): this file defines ``test`` more than once; a later
    definition replaces this one at import time.
    """
    model.eval()
    model_bert.eval()

    ave_loss = 0
    cnt = 0
    cnt_sc = 0
    cnt_sa = 0
    cnt_wn = 0
    cnt_wc = 0
    cnt_wo = 0
    cnt_wv = 0
    cnt_wvi = 0
    cnt_ao = 0
    cnt_ord = 0
    cnt_lx = 0
    cnt_x = 0

    cnt_list = []

    engine = DBEngine(os.path.join(path_db, f"{dset_name}.db"))
    results = []

    for t in data_loader:
        cnt += len(t)
        if cnt < st_pos:
            continue

        nlu, nlu_t, sql_i, sql_q, sql_t, tb, hs_t, hds, hs_type = get_fields(
            t, data_table, no_hs_t=True, no_sql_t=True)
        g_sc, g_sa, g_wn, g_wc, g_wo, g_wv, g_ao, g_ord = get_g(sql_i)
        g_wvi_corenlp = get_g_wvi_corenlp(t)

        (wemb_n, wemb_h, l_n, l_hpu, l_hs,
         nlu_tt, t_to_tt_idx, tt_to_t_idx) = get_wemb_bert(
            bert_config, model_bert, tokenizer, nlu_t, hds, max_seq_length,
            num_out_layers_n=num_target_layers, num_out_layers_h=num_target_layers)

        try:
            g_wvi = get_g_wvi_bert_from_g_wvi_corenlp(t_to_tt_idx, g_wvi_corenlp)
            # The decoded ground-truth strings are not used below. The call is
            # kept because a failure here also marks the batch as skipped.
            convert_pr_wvi_to_string(g_wvi, nlu_t, nlu_tt, tt_to_t_idx, nlu, hs_type)
        except Exception:
            # The where condition could not be located in nlu_tt. The batch is
            # not evaluated; its examples count as wrongly answered.
            for b in range(len(nlu)):
                results.append({
                    "error": "Skip happened",
                    "nlu": nlu[b],
                    "table_id": tb[b]["id"],
                })
            continue

        s_sc, s_sa, s_wn, s_wc, s_wo, s_wv, s_ao, s_ord = model(
            wemb_n, l_n, wemb_h, l_hpu, l_hs, hs_type=hs_type)

        loss = Loss_sw_se(s_sc, s_sa, s_wn, s_wc, s_wo, s_wv, s_ao, s_ord,
                          g_sc, g_sa, g_wn, g_wc, g_wo, g_wvi, g_ao, g_ord, hs_type)

        # Prediction.
        pr_sc, pr_sa, pr_wn, pr_wc, pr_wo, pr_wvi, pr_ao, pr_ord = pred_sw_se(
            s_sc, s_sa, s_wn, s_wc, s_wo, s_wv, s_ao, s_ord, hs_type)
        pr_wv_str, _ = convert_pr_wvi_to_string(
            pr_wvi, nlu_t, nlu_tt, tt_to_t_idx, nlu, hs_type)
        pr_sql_i = generate_sql_i(pr_sc, pr_sa, pr_wn, pr_wc, pr_wo, pr_wv_str,
                                  pr_ao, pr_ord, nlu, hs_type)

        # Saving for the official evaluation later.
        for b, pr_sql_i1 in enumerate(pr_sql_i):
            results.append({
                "query": pr_sql_i1,
                "table_id": tb[b]["id"],
                "nlu": nlu[b],
            })

        (cnt_sc1_list, cnt_sa1_list, cnt_wn1_list,
         cnt_wc1_list, cnt_wo1_list,
         cnt_wvi1_list, cnt_wv1_list, cnt_ao1_list, cnt_ord1_list) = get_cnt_sw_list(
            g_sc, g_sa, g_wn, g_wc, g_wo, g_wvi, g_ao, g_ord,
            pr_sc, pr_sa, pr_wn, pr_wc, pr_wo, pr_wvi, pr_ao, pr_ord,
            sql_i, pr_sql_i, hs_type, mode='test')

        # lx stands for logical form accuracy.
        cnt_lx1_list = get_cnt_lx_list(cnt_sc1_list, cnt_sa1_list, cnt_wn1_list,
                                       cnt_wc1_list, cnt_wo1_list, cnt_wv1_list,
                                       cnt_ao1_list, cnt_ord1_list, hs_type)
        # Execution accuracy test. (xiajing: need change)
        cnt_x1_list, g_ans, pr_ans = get_cnt_x_list(
            engine, tb, g_sc, g_sa, sql_i, pr_sc, pr_sa, pr_sql_i, pr_ao, pr_ord, hds)

        # Statistics.
        ave_loss += loss.item()
        cnt_sc += sum(cnt_sc1_list)
        cnt_sa += sum(cnt_sa1_list)
        cnt_wn += sum(cnt_wn1_list)
        cnt_wc += sum(cnt_wc1_list)
        cnt_wo += sum(cnt_wo1_list)
        cnt_wv += sum(cnt_wv1_list)
        cnt_ao += sum(cnt_ao1_list)
        cnt_ord += sum(cnt_ord1_list)
        cnt_wvi += sum(cnt_wvi1_list)
        cnt_lx += sum(cnt_lx1_list)
        cnt_x += sum(cnt_x1_list)

        cnt_list.append([
            cnt_sc1_list, cnt_sa1_list, cnt_wn1_list, cnt_wc1_list, cnt_wo1_list,
            cnt_wv1_list, cnt_ao1_list, cnt_ord1_list, cnt_lx1_list, cnt_x1_list
        ])

    ave_loss /= cnt
    acc_sc = cnt_sc / cnt
    acc_sa = cnt_sa / cnt
    acc_wn = cnt_wn / cnt
    acc_wc = cnt_wc / cnt
    acc_wo = cnt_wo / cnt
    acc_wvi = cnt_wvi / cnt
    acc_wv = cnt_wv / cnt
    acc_ao = cnt_ao / cnt
    acc_ord = cnt_ord / cnt
    acc_lx = cnt_lx / cnt
    acc_x = cnt_x / cnt

    acc = [
        ave_loss, acc_sc, acc_sa, acc_wn, acc_wc, acc_wo, acc_wvi, acc_wv,
        acc_ao, acc_ord, acc_lx, acc_x
    ]
    return acc, results, cnt_list
def test(data_loader, data_table, model, model_bert, bert_config, tokenizer, max_seq_length,
         num_target_layers, detail=False, st_pos=0, cnt_tot=1, EG=False, beam_size=4,
         path_db=None, dset_name='test'):
    """Evaluate the model on ``data_loader`` and collect the predicted queries.

    Without ``EG`` a plain forward pass is used and the loss is computed. With
    ``EG`` the queries come from ``model.beam_forward`` (execution-guided
    decoding) and the loss is reported as 0.

    Batches whose ground-truth where-value spans cannot be mapped onto BERT
    tokens are skipped. A skip marker is recorded for each of their examples.

    Args:
        data_loader: iterable of batches of raw examples.
        data_table: table lookup forwarded to ``get_fields``.
        detail: if true, ``report_detail`` is called for every batch.
        st_pos: batches are skipped while the running example count is below
            this value.
        cnt_tot: forwarded to ``report_detail`` as the first entry of the
            running counters.
        EG, beam_size: enable execution-guided decoding and set its beam width.
        path_db, dset_name: the query engine is opened on
            ``<path_db>/<dset_name>.db``.

    Returns:
        ``(acc, results, cnt_list)``, where:

        - ``acc`` is ``[ave_loss, acc_sc, acc_sa, acc_wn, acc_wc, acc_wo,
          acc_wvi, acc_wv, acc_lx, acc_x]``;
        - ``results`` holds one dict per example, either
          ``{"query", "table_id", "nlu"}`` or
          ``{"error": "Skip happened", "nlu", "table_id"}``;
        - ``cnt_list`` holds one list of per-example correctness lists per
          evaluated batch.
    """
    model.eval()
    model_bert.eval()

    ave_loss = 0
    cnt = 0
    cnt_sc = 0
    cnt_sa = 0
    cnt_wn = 0
    cnt_wc = 0
    cnt_wo = 0
    cnt_wv = 0
    cnt_wvi = 0
    cnt_lx = 0
    cnt_x = 0

    cnt_list = []

    # Initialise the database query engine.
    engine = DBEngine(os.path.join(path_db, f"{dset_name}.db"))
    # Unlike train(), the predictions are also collected here.
    results = []
    for t in data_loader:
        # t holds the raw examples of one batch.
        cnt += len(t)  # running number of examples
        if cnt < st_pos:
            continue

        # Split each question into its fields and attach its table:
        #   nlu    : the questions of the batch
        #   nlu_t  : tokenized questions (no word segmentation here)
        #   sql_i  : canonical form of the SQL query
        #   sql_q  : full SQL query text (not used)
        #   sql_t  : not used
        #   tb     : tables of the batch (not necessarily one per question, but
        #            every table a question needs is included)
        #   hs_t   : tokenized headers (not used)
        #   hds    : table headers
        nlu, nlu_t, sql_i, sql_q, sql_t, tb, hs_t, hds = get_fields(
            t, data_table, no_hs_t=True, no_sql_t=True)
        # Ground-truth sc, sa, wn, wc, wo, wv per question; multiple
        # conditions are kept in lists.
        g_sc, g_sa, g_wn, g_wc, g_wo, g_wv = get_g(sql_i)
        # Ground-truth start/end of the where values, as stored in the loader.
        g_wvi_corenlp = get_g_wvi_corenlp(t)

        # BERT outputs:
        #   wemb_n      : question embeddings
        #   wemb_h      : header embeddings
        #   l_n         : question lengths
        #   l_hpu       : token length of each header (the question and the
        #                 headers are encoded together)
        #   l_hs        : number of headers per table
        #   nlu_tt      : BERT-tokenized questions
        #   t_to_tt_idx : token -> BERT-token index map
        #   tt_to_t_idx : BERT-token -> token index map
        (wemb_n, wemb_h, l_n, l_hpu, l_hs,
         nlu_tt, t_to_tt_idx, tt_to_t_idx) = get_wemb_bert(
            bert_config, model_bert, tokenizer, nlu_t, hds, max_seq_length,
            num_out_layers_n=num_target_layers, num_out_layers_h=num_target_layers)

        try:
            g_wvi = get_g_wvi_bert_from_g_wvi_corenlp(t_to_tt_idx, g_wvi_corenlp)
            g_wv_str, _ = convert_pr_wvi_to_string(
                g_wvi, nlu_t, nlu_tt, tt_to_t_idx, nlu)
        except Exception:
            # The where condition could not be located in nlu_tt. The batch is
            # not evaluated; its examples count as wrongly answered.
            for b in range(len(nlu)):
                results.append({
                    "error": "Skip happened",
                    "nlu": nlu[b],
                    "table_id": tb[b]["id"],
                })
            continue

        # Model-specific part: scores.
        if not EG:
            # Feed the BERT outputs to the model to get the component scores.
            s_sc, s_sa, s_wn, s_wc, s_wo, s_wv = model(wemb_n, l_n, wemb_h, l_hpu, l_hs)

            # Loss.
            loss = Loss_sw_se(s_sc, s_sa, s_wn, s_wc, s_wo, s_wv,
                              g_sc, g_sa, g_wn, g_wc, g_wo, g_wvi)

            # Most likely sc / sa / wn / wc / wo / wvi.
            pr_sc, pr_sa, pr_wn, pr_wc, pr_wo, pr_wvi = pred_sw_se(
                s_sc, s_sa, s_wn, s_wc, s_wo, s_wv)
            # Where-value strings from the predicted spans.
            pr_wv_str, _ = convert_pr_wvi_to_string(
                pr_wvi, nlu_t, nlu_tt, tt_to_t_idx, nlu)
            # SQL representation built from the predictions.
            pr_sql_i = generate_sql_i(pr_sc, pr_sa, pr_wn, pr_wc, pr_wo,
                                      pr_wv_str, nlu)
        else:
            # Execution-guided decoding.
            prob_sca, prob_w, prob_wn_w, pr_sc, pr_sa, pr_wn, pr_sql_i = model.beam_forward(
                wemb_n, l_n, wemb_h, l_hpu, l_hs, engine, tb, nlu_t, nlu_tt,
                tt_to_t_idx, nlu, beam_size=beam_size)
            # Sort and generate.
            pr_wc, pr_wo, pr_wv, pr_sql_i = sort_and_generate_pr_w(pr_sql_i)

            # Placeholders for consistency with the no-EG case.
            pr_wvi = None  # not used
            pr_wv_str = None
            loss = torch.tensor([0])

        g_sql_q = generate_sql_q(sql_i, tb)
        pr_sql_q = generate_sql_q(pr_sql_i, tb)

        # Saving for the official evaluation later.
        for b, pr_sql_i1 in enumerate(pr_sql_i):
            results.append({
                "query": pr_sql_i1,
                "table_id": tb[b]["id"],
                "nlu": nlu[b],
            })

        (cnt_sc1_list, cnt_sa1_list, cnt_wn1_list,
         cnt_wc1_list, cnt_wo1_list,
         cnt_wvi1_list, cnt_wv1_list) = get_cnt_sw_list(
            g_sc, g_sa, g_wn, g_wc, g_wo, g_wvi,
            pr_sc, pr_sa, pr_wn, pr_wc, pr_wo, pr_wvi,
            sql_i, pr_sql_i, mode='test')

        # lx stands for logical form accuracy.
        cnt_lx1_list = get_cnt_lx_list(cnt_sc1_list, cnt_sa1_list, cnt_wn1_list,
                                       cnt_wc1_list, cnt_wo1_list, cnt_wv1_list)

        # Execution accuracy test.
        cnt_x1_list, g_ans, pr_ans = get_cnt_x_list(
            engine, tb, g_sc, g_sa, sql_i, pr_sc, pr_sa, pr_sql_i)

        # Statistics.
        ave_loss += loss.item()
        cnt_sc += sum(cnt_sc1_list)
        cnt_sa += sum(cnt_sa1_list)
        cnt_wn += sum(cnt_wn1_list)
        cnt_wc += sum(cnt_wc1_list)
        cnt_wo += sum(cnt_wo1_list)
        cnt_wv += sum(cnt_wv1_list)
        cnt_wvi += sum(cnt_wvi1_list)
        cnt_lx += sum(cnt_lx1_list)
        cnt_x += sum(cnt_x1_list)

        cnt_list1 = [
            cnt_sc1_list, cnt_sa1_list, cnt_wn1_list, cnt_wc1_list, cnt_wo1_list,
            cnt_wv1_list, cnt_lx1_list, cnt_x1_list
        ]
        cnt_list.append(cnt_list1)

        # Report.
        if detail:
            current_cnt = [
                cnt_tot, cnt, cnt_sc, cnt_sa, cnt_wn, cnt_wc, cnt_wo, cnt_wv,
                cnt_wvi, cnt_lx, cnt_x
            ]
            report_detail(hds, nlu, g_sc, g_sa, g_wn, g_wc, g_wo, g_wv, g_wv_str,
                          g_sql_q, g_ans, pr_sc, pr_sa, pr_wn, pr_wc, pr_wo,
                          pr_wv_str, pr_sql_q, pr_ans, cnt_list1, current_cnt)

    ave_loss /= cnt
    acc_sc = cnt_sc / cnt
    acc_sa = cnt_sa / cnt
    acc_wn = cnt_wn / cnt
    acc_wc = cnt_wc / cnt
    acc_wo = cnt_wo / cnt
    acc_wvi = cnt_wvi / cnt
    acc_wv = cnt_wv / cnt
    acc_lx = cnt_lx / cnt
    acc_x = cnt_x / cnt

    acc = [
        ave_loss, acc_sc, acc_sa, acc_wn, acc_wc, acc_wo, acc_wvi, acc_wv,
        acc_lx, acc_x
    ]
    return acc, results, cnt_list
def train(train_loader, train_table, model, model_bert, opt, bert_config, tokenizer,
          max_seq_length, num_target_layers, accumulate_gradients=1, check_grad=True,
          st_pos=0, opt_bert=None, path_db=None, dset_name='train'):
    """Run one training epoch over ``train_loader`` and report accuracies.

    Each batch is encoded with BERT (``get_wemb_bert``), scored by ``model``
    with the ground-truth select/where components passed in, and the loss from
    ``Loss_sw_se`` is back-propagated. Gradients are summed over
    ``accumulate_gradients`` consecutive batches before each optimizer step.
    The predictions of the same forward pass are used to count per-component,
    logical-form and execution matches.

    Args:
        train_loader: iterable of batches; each batch ``t`` is a list of
            ``batch_size`` raw examples.
        train_table: table lookup forwarded to ``get_fields``.
        model: decoder producing the select/where scores.
        model_bert: BERT encoder.
        opt: optimizer for ``model``.
        bert_config, tokenizer, max_seq_length, num_target_layers: forwarded
            to ``get_wemb_bert``.
        accumulate_gradients: number of batches per optimizer step.
        check_grad: unused; kept for interface compatibility.
        st_pos: batches are skipped while the running example count is below
            this value.
        opt_bert: optional optimizer for ``model_bert``.
        path_db, dset_name: the query engine is opened on
            ``<path_db>/<dset_name>.db``.

    Returns:
        ``(acc, aux_out)``. ``acc`` is ``[ave_loss, acc_sc, acc_sa, acc_wn,
        acc_wc, acc_wo, acc_wvi, acc_wv, acc_lx, acc_x]``, each divided by the
        number of examples counted. ``aux_out`` is always ``1``.
    """
    # Put both modules into training mode.
    model.train()
    model_bert.train()

    ave_loss = 0
    cnt = 0  # number of examples counted
    cnt_sc = 0  # correct select columns
    cnt_sa = 0  # correct select aggregations
    cnt_wn = 0  # correct where-condition counts
    cnt_wc = 0  # correct where columns
    cnt_wo = 0  # correct where operators
    cnt_wv = 0  # correct where values
    cnt_wvi = 0  # correct where-value indices (on question tokens)
    cnt_lx = 0  # logical-form matches
    cnt_x = 0  # execution matches

    # Initialise the database query engine.
    engine = DBEngine(os.path.join(path_db, f"{dset_name}.db"))

    for iB, t in enumerate(train_loader):
        # t holds the batch_size raw examples of this batch.
        cnt += len(t)  # running number of examples
        if cnt < st_pos:
            continue

        # Split each question into its fields and attach its table:
        #   nlu    : the questions of the batch
        #   nlu_t  : tokenized questions (no word segmentation here)
        #   sql_i  : canonical form of the SQL query
        #   sql_q  : full SQL query text (not used)
        #   sql_t  : not used
        #   tb     : tables of the batch (not necessarily one per question, but
        #            every table a question needs is included)
        #   hs_t   : tokenized headers (not used)
        #   hds    : table headers
        nlu, nlu_t, sql_i, sql_q, sql_t, tb, hs_t, hds = get_fields(
            t, train_table, no_hs_t=True, no_sql_t=True)
        # Ground-truth sc, sa, wn, wc, wo, wv per question; multiple
        # conditions are kept in lists.
        g_sc, g_sa, g_wn, g_wc, g_wo, g_wv = get_g(sql_i)
        # Ground-truth start/end of the where values under the CoreNLP
        # tokenization, as stored in the loader (marked as needing improvement).
        g_wvi_corenlp = get_g_wvi_corenlp(t)

        # BERT outputs:
        #   wemb_n      : question embeddings
        #   wemb_h      : header embeddings
        #   l_n         : question lengths
        #   l_hpu       : token length of each header (the question and the
        #                 headers are encoded together)
        #   l_hs        : number of headers per table
        #   nlu_tt      : BERT-tokenized questions
        #   t_to_tt_idx : token -> BERT-token index map
        #   tt_to_t_idx : BERT-token -> token index map
        (wemb_n, wemb_h, l_n, l_hpu, l_hs,
         nlu_tt, t_to_tt_idx, tt_to_t_idx) = get_wemb_bert(
            bert_config, model_bert, tokenizer, nlu_t, hds, max_seq_length,
            num_out_layers_n=num_target_layers, num_out_layers_h=num_target_layers)

        try:
            # Map the where-value spans onto BERT tokens; batches where this
            # fails are filtered out.
            g_wvi = get_g_wvi_bert_from_g_wvi_corenlp(t_to_tt_idx, g_wvi_corenlp)
        except Exception:
            # The where condition could not be located in nlu_tt, so this
            # batch is not used for training.
            continue

        # Feed the BERT outputs (plus the ground-truth components) to the model
        # to get the scores of the select/where components for the batch.
        s_sc, s_sa, s_wn, s_wc, s_wo, s_wv = model(
            wemb_n, l_n, wemb_h, l_hpu, l_hs,
            g_sc=g_sc, g_sa=g_sa, g_wn=g_wn, g_wc=g_wc, g_wvi=g_wvi)

        # Loss.
        loss = Loss_sw_se(s_sc, s_sa, s_wn, s_wc, s_wo, s_wv,
                          g_sc, g_sa, g_wn, g_wc, g_wo, g_wvi)

        # Gradients, accumulated over `accumulate_gradients` batches.
        if iB % accumulate_gradients == 0:
            # Start of a window: clear the gradients of the previous step.
            opt.zero_grad()
            if opt_bert:
                opt_bert.zero_grad()
            loss.backward()
            if accumulate_gradients == 1:
                opt.step()
                if opt_bert:
                    opt_bert.step()
        elif iB % accumulate_gradients == (accumulate_gradients - 1):
            # End of a window: step with the accumulated gradients.
            loss.backward()
            opt.step()
            if opt_bert:
                opt_bert.step()
        else:
            # Middle of a window: only accumulate.
            loss.backward()

        # Most likely sc / sa / wn / wc / wo / wvi.
        pr_sc, pr_sa, pr_wn, pr_wc, pr_wo, pr_wvi = pred_sw_se(
            s_sc, s_sa, s_wn, s_wc, s_wo, s_wv)
        # Where-value strings from the predicted spans.
        pr_wv_str, _ = convert_pr_wvi_to_string(
            pr_wvi, nlu_t, nlu_tt, tt_to_t_idx, nlu)
        # Sort pr_wc: pr_wo and pr_wvi were predicted using the ground-truth
        # where columns (g_wc). This is not needed in dev/test, where no ground
        # truth is fed to the model.
        pr_wc_sorted = sort_pr_wc(pr_wc, g_wc)
        # SQL representation built from the predictions.
        pr_sql_i = generate_sql_i(pr_sc, pr_sa, pr_wn, pr_wc_sorted, pr_wo,
                                  pr_wv_str, nlu)

        # Per-example correctness (1: correct, 0: wrong) for each component.
        (cnt_sc1_list, cnt_sa1_list, cnt_wn1_list,
         cnt_wc1_list, cnt_wo1_list,
         cnt_wvi1_list, cnt_wv1_list) = get_cnt_sw_list(
            g_sc, g_sa, g_wn, g_wc, g_wo, g_wvi,
            pr_sc, pr_sa, pr_wn, pr_wc, pr_wo, pr_wvi,
            sql_i, pr_sql_i, mode='train')
        # Whole-query correctness (perfect SQL: 1, otherwise 0);
        # lx stands for logical form accuracy.
        cnt_lx1_list = get_cnt_lx_list(cnt_sc1_list, cnt_sa1_list, cnt_wn1_list,
                                       cnt_wc1_list, cnt_wo1_list, cnt_wv1_list)
        # Execution results; per the original author, the most frequent source
        # of small errors.
        cnt_x1_list, g_ans, pr_ans = get_cnt_x_list(
            engine, tb, g_sc, g_sa, sql_i, pr_sc, pr_sa, pr_sql_i)

        # Statistics.
        ave_loss += loss.item()
        cnt_sc += sum(cnt_sc1_list)
        cnt_sa += sum(cnt_sa1_list)
        cnt_wn += sum(cnt_wn1_list)
        cnt_wc += sum(cnt_wc1_list)
        cnt_wo += sum(cnt_wo1_list)
        cnt_wvi += sum(cnt_wvi1_list)
        cnt_wv += sum(cnt_wv1_list)
        cnt_lx += sum(cnt_lx1_list)
        cnt_x += sum(cnt_x1_list)

    ave_loss /= cnt
    acc_sc = cnt_sc / cnt
    acc_sa = cnt_sa / cnt
    acc_wn = cnt_wn / cnt
    acc_wc = cnt_wc / cnt
    acc_wo = cnt_wo / cnt
    acc_wvi = cnt_wvi / cnt
    acc_wv = cnt_wv / cnt
    acc_lx = cnt_lx / cnt
    acc_x = cnt_x / cnt

    acc = [
        ave_loss, acc_sc, acc_sa, acc_wn, acc_wc, acc_wo, acc_wvi, acc_wv,
        acc_lx, acc_x
    ]

    aux_out = 1

    return acc, aux_out
def predict(data_loader, data_table, model, model_bert, bert_config, tokenizer, max_seq_length,
            num_target_layers, detail=False, st_pos=0, cnt_tot=1, EG=False, beam_size=4,
            path_db=None, dset_name='test'):
    """Predict one query per example of ``data_loader``.

    Without ``EG`` the predicted query is built in two steps. First
    ``model.gen_query`` is applied to the ``s_scn``/``s_sc``/``s_sa``/``s_sop``
    scores. Then ``generate_sql_i`` combines that result with the predicted
    where components. With ``EG`` the query comes from ``model.beam_forward``
    (execution-guided decoding).

    Returns:
        A flat list with one entry per example: either the predicted query, or
        ``{"error": "Skip happened", "nlu": ..., "table_id": ...}`` for
        batches whose ground-truth where-value spans cannot be mapped onto
        BERT tokens.

    ``detail``, ``st_pos`` and ``cnt_tot`` are unused; they are kept for
    interface compatibility.

    NOTE(review): the non-EG branch feeds ground-truth components (g_sc, g_sa,
    g_wn, g_wc, g_sop, g_wo, ...) into ``model`` and aligns pr_wc with g_wc,
    so labelled data is required. It also passes ``g_wvi=g_wv`` (the where
    values) rather than token indices. Confirm both are intended.
    """
    model.eval()
    model_bert.eval()

    engine = DBEngine(os.path.join(path_db, f"{dset_name}.db"))
    results = []

    for t in data_loader:
        # Get fields.
        nlu, nlu_t, sql_i, sql_q, sql_t, tb, hs_t, hds = get_fields(
            t, data_table, no_hs_t=True, no_sql_t=True)
        g_sc, g_sa, g_sop, g_wn, g_wc, g_wo, g_wv, g_sel_num_seq, g_sel_ag_seq = get_g(sql_i)
        g_wvi_corenlp = get_g_wvi_corenlp(t)

        (wemb_n, wemb_h, l_n, l_hpu, l_hs,
         nlu_tt, t_to_tt_idx, tt_to_t_idx) = get_wemb_bert(
            bert_config, model_bert, tokenizer, nlu_t, hds, max_seq_length,
            num_out_layers_n=num_target_layers, num_out_layers_h=num_target_layers)

        try:
            # The mapped indices are not used below; the call only filters out
            # batches whose where values cannot be located in the BERT tokens.
            get_g_wvi_bert_from_g_wvi_corenlp(t_to_tt_idx, g_wvi_corenlp)
        except Exception:
            for b in range(len(nlu)):
                results.append({
                    "error": "Skip happened",
                    "nlu": nlu[b],
                    "table_id": tb[b]["id"],
                })
            continue

        if not EG:
            # No execution-guided decoding.
            g_sel_seq = [x[1] for x in g_sel_ag_seq]
            s_scn, s_sc, s_sa, s_sop, s_wn, s_wc, s_wo, s_wv = model(
                wemb_n, l_n, wemb_h, l_hpu, l_hs,
                g_sc=g_sc, g_sa=g_sa, g_wn=g_wn, g_wc=g_wc, g_sop=g_sop, g_wo=g_wo,
                g_wvi=g_wv, g_sel=g_sel_seq, g_scn=g_sel_num_seq)

            # gen_query receives these four scores as a list. The original code
            # called tuple(score) but discarded the result, so a list is what it
            # has always received.
            score = [s_scn, s_sc, s_sa, s_sop]
            pr_sql_i1 = model.gen_query(score, nlu_tt, nlu)

            pr_wn, pr_wc, pr_sop, pr_wo, pr_wvi = pred_sw_se(s_sop, s_wn, s_wc, s_wo, s_wv)
            # Map the predicted spans to strings.
            pr_wv_str, _ = convert_pr_wvi_to_string(
                pr_wvi, nlu_t, nlu_tt, tt_to_t_idx, nlu)
            pr_wc_sorted = sort_pr_wc(pr_wc, g_wc)
            pr_sql_i = generate_sql_i(pr_sql_i1, pr_wn, pr_wc_sorted, pr_wo,
                                      pr_wv_str, nlu)
        else:
            # Execution-guided decoding.
            prob_sca, prob_w, prob_wn_w, pr_sc, pr_sa, pr_wn, pr_sql_i = model.beam_forward(
                wemb_n, l_n, wemb_h, l_hpu, l_hs, engine, tb, nlu_t, nlu_tt,
                tt_to_t_idx, nlu, beam_size=beam_size)
            # Sort and generate.
            pr_wc, pr_wo, pr_wv, pr_sql_i = sort_and_generate_pr_w(pr_sql_i)

        # Saving for the official evaluation later.
        results.extend(pr_sql_i)

    return results
def test(data_loader, data_table, model, model_bert, bert_config, tokenizer,
         max_seq_length, num_target_layers, detail=False, st_pos=0, cnt_tot=1,
         EG=False, beam_size=4, path_db=None, dset_name='test', mvl=2):
    """Evaluate the model on a labelled split and collect its predictions.

    For every batch the question/header embeddings are computed with
    ``get_wemb_bert``. The model scores are decoded into SQL dicts, which are
    refined with ``model.beam_forward`` when ``EG`` (execution-guided
    decoding) is True. Per-component accuracies and execution accuracy are
    then accumulated.

    Batches whose ground-truth where-value cannot be mapped onto the BERT
    tokens are not scored. One ``{"error": "Skip happened", ...}`` entry per
    example is appended to ``results``, and those examples still count in the
    denominator, so they are treated as wrongly answered.

    Args:
        st_pos: batches are skipped until the running example count reaches
            this value; skipped examples still count in the denominator.
        cnt_tot: only forwarded to ``report_detail``.
        detail: when True, ``report_detail`` is called for every batch.
        mvl: forwarded to the model, ``pred_sw_se`` and ``model.beam_forward``.
        Remaining arguments are forwarded unchanged to the project helpers.

    Returns:
        ``(acc, results, cnt_list)``:

        - ``acc`` is ``[ave_loss, acc_sn, acc_sc, acc_sa, acc_wn, acc_wr,
          acc_wc, acc_wo, acc_wvi, acc_wv, acc_lx, acc_x]``. ``ave_loss`` is
          always 0 because no loss is computed during evaluation.
        - ``results`` holds one dict per example.
        - ``cnt_list`` holds the per-batch correctness lists.
    """
    model.eval()
    model_bert.eval()

    # No loss is computed here; ave_loss stays 0 so the returned/logged
    # layout matches the training loop.
    ave_loss = 0
    cnt = 0
    cnt_sn = 0
    cnt_sc = 0
    cnt_sa = 0
    cnt_wn = 0
    cnt_wr = 0
    cnt_wc = 0
    cnt_wo = 0
    cnt_wv = 0
    cnt_wvi = 0
    cnt_lx = 0
    cnt_x = 0
    cnt_err = 0    # first value returned by beam_forward (EG only)
    cnt_still = 0  # second value returned by beam_forward (EG only)
    cnt_skip = 0   # examples whose ground-truth where-value could not be mapped

    cnt_list = []

    engine = DBEngine(os.path.join(path_db, dset_name, f"{dset_name}.db"))
    results = []
    for iB, t in enumerate(data_loader):
        cnt += len(t)
        if cnt < st_pos:
            continue

        # Get fields
        nlu, nlu_t, sql_i, sql_q, sql_t, tb, hs_t, hds = get_fields(
            t, data_table, no_hs_t=True, no_sql_t=True, generate_mode=False)

        (g_sn, g_sc, g_sa, g_wn, g_wr, g_dwn, g_wc, g_wo, g_wv, g_r_c_n,
         wvi_change_index) = get_g(sql_i)

        (wemb_n, wemb_h, l_n, l_hpu, l_hs,
         nlu_tt, t_to_tt_idx, tt_to_t_idx, wemb_v, l_npu, l_token) = get_wemb_bert(
            bert_config, model_bert, tokenizer, nlu_t, hds, max_seq_length,
            num_out_layers_n=num_target_layers,
            num_out_layers_h=num_target_layers,
            num_out_layers_v=num_target_layers)

        try:
            g_wvi_corenlp = get_g_wvi_corenlp(t, wvi_change_index)
            g_wvi = get_g_wvi_bert_from_g_wvi_corenlp(t_to_tt_idx, g_wvi_corenlp)
            g_wv_str, g_wv_str_wp = convert_pr_wvi_to_string(
                g_wvi, nlu_t, nlu_tt, tt_to_t_idx)
            # Re-encode the ground truth for the accuracy helpers (presumably
            # start/length form, per the helper's name -- confirm).
            # NOTE: accuracy is meant to be computed including such noisy data.
            g_wvi = get_g_wvi_stidx_length_jian_yi(g_wvi_corenlp)
        except Exception:
            # The where-condition could not be found in nlu_tt. These examples
            # are not scored and count as wrongly answered.
            results.extend(
                {"error": "Skip happened", "nlu": nlu[b], "table_id": tb[b]["id"]}
                for b in range(len(nlu)))
            cnt_skip += len(nlu)
            continue

        # Scores -> structured prediction -> SQL dicts (shared by both modes).
        (s_sn, s_sc, s_sa, s_wn, s_wr, s_hrpc, s_wc, s_wo,
         s_wv1, s_wv2, s_wv3, s_wv4) = model(
            mvl, wemb_n, l_n, wemb_h, l_hpu, l_hs, wemb_v, l_npu, l_token)
        (pr_sn, pr_sc, pr_sa, pr_wn, pr_wr, pr_hrpc, pr_wc, pr_wo,
         pr_wvi) = pred_sw_se(s_sn, s_sc, s_sa, s_wn, s_wr, s_hrpc, s_wc, s_wo,
                              s_wv1, s_wv2, s_wv3, s_wv4, mvl)
        pr_wvi_decode = g_wvi_decoder_stidx_length_jian_yi(pr_wvi)
        pr_wv_str, _ = convert_pr_wvi_to_string(
            pr_wvi_decode, nlu_t, nlu_tt, tt_to_t_idx)
        pr_sql_i = generate_sql_i(pr_sc, pr_sa, pr_wn, pr_wr, pr_wc, pr_wo,
                                  pr_wv_str, nlu)

        if EG:
            # Execution-guided decoding refines the greedy SQL with a beam.
            pr_sql_i, exe_error1, still_error1 = model.beam_forward(
                pr_sql_i, mvl, wemb_n, l_n, wemb_h, l_hpu, l_hs, wemb_v, l_npu,
                l_token, engine, tb, nlu_t, beam_size=beam_size)
            cnt_err += exe_error1
            cnt_still += still_error1
            # Re-derive the structured prediction from the refined SQL.
            pr_sn, pr_sc, pr_sa, pr_wn, pr_wr, pr_wc, pr_wo, _ = generate_pr(pr_sql_i)
            # No token-level where-value prediction exists after EG decoding.
            pr_wvi = None
            pr_wv_str = None

        # Saving for the official evaluation later.
        for b, pr_sql_i1 in enumerate(pr_sql_i):
            results.append({"query": pr_sql_i1, "table_id": tb[b]["id"], "nlu": nlu[b]})

        (cnt_sn1_list, cnt_sc1_list, cnt_sa1_list, cnt_wn1_list, cnt_wr1_list,
         cnt_wc1_list, cnt_wo1_list, cnt_wvi1_list, cnt_wv1_list) = get_cnt_sw_list(
            g_sn, g_sc, g_sa, g_wn, g_wr, g_wc, g_wo, g_wvi,
            pr_sn, pr_sc, pr_sa, pr_wn, pr_wr, pr_wc, pr_wo, pr_wvi,
            sql_i, pr_sql_i, mode='test')
        # lx stands for logical form accuracy
        cnt_lx1_list = get_cnt_lx_list(cnt_sn1_list, cnt_sc1_list, cnt_sa1_list,
                                       cnt_wn1_list, cnt_wr1_list, cnt_wc1_list,
                                       cnt_wo1_list, cnt_wv1_list)
        # Execution accuracy test.
        cnt_x1_list, g_ans, pr_ans = get_cnt_x_list(
            engine, tb, g_sc, g_sa, sql_i, pr_sc, pr_sa, pr_sql_i)

        # count
        cnt_sn += sum(cnt_sn1_list)
        cnt_sc += sum(cnt_sc1_list)
        cnt_sa += sum(cnt_sa1_list)
        cnt_wn += sum(cnt_wn1_list)
        cnt_wr += sum(cnt_wr1_list)
        cnt_wc += sum(cnt_wc1_list)
        cnt_wo += sum(cnt_wo1_list)
        cnt_wv += sum(cnt_wv1_list)
        cnt_wvi += sum(cnt_wvi1_list)
        cnt_lx += sum(cnt_lx1_list)
        cnt_x += sum(cnt_x1_list)

        if iB % 10 == 0:
            logger.info(
                '%d - th data batch -> loss: %.4f; acc_sn: %.4f; acc_sc: %.4f; acc_sa: %.4f; acc_wn: %.4f; acc_wr: %.4f; acc_wc: %.4f; acc_wo: %.4f; acc_wvi: %.4f; acc_wv: %.4f; acc_lx: %.4f; acc_x: %.4f; execute_error: %.4f; skip_error: %.4f; still_error: %.4f',
                iB, ave_loss / cnt, cnt_sn / cnt, cnt_sc / cnt, cnt_sa / cnt,
                cnt_wn / cnt, cnt_wr / cnt, cnt_wc / cnt, cnt_wo / cnt,
                cnt_wvi / cnt, cnt_wv / cnt, cnt_lx / cnt, cnt_x / cnt,
                cnt_err / cnt, cnt_skip / cnt, cnt_still / cnt)

        cnt_list1 = [
            cnt_sn1_list, cnt_sc1_list, cnt_sa1_list, cnt_wn1_list,
            cnt_wr1_list, cnt_wc1_list, cnt_wo1_list, cnt_wv1_list,
            cnt_lx1_list, cnt_x1_list
        ]
        cnt_list.append(cnt_list1)

        # report
        if detail:
            current_cnt = [
                cnt_tot, cnt, cnt_sn, cnt_sc, cnt_sa, cnt_wn, cnt_wr, cnt_wc,
                cnt_wo, cnt_wv, cnt_wvi, cnt_lx, cnt_x
            ]
            g_sql_q = generate_sql_q(sql_i, tb)
            pr_sql_q = generate_sql_q(pr_sql_i, tb)
            report_detail(hds, nlu, g_sn, g_sc, g_sa, g_wn, g_wr, g_wc, g_wo,
                          g_wv, g_wv_str, g_sql_q, g_ans,
                          pr_sn, pr_sc, pr_sa, pr_wn, pr_wr, pr_wc, pr_wo,
                          pr_wv_str, pr_sql_q, pr_ans,
                          cnt_list1, current_cnt)

    # Guard against an empty loader: return zeros instead of dividing by zero.
    denom = cnt if cnt > 0 else 1
    acc = [
        value / denom
        for value in (ave_loss, cnt_sn, cnt_sc, cnt_sa, cnt_wn, cnt_wr, cnt_wc,
                      cnt_wo, cnt_wvi, cnt_wv, cnt_lx, cnt_x)
    ]
    return acc, results, cnt_list
def train(train_loader, train_table, model, model_bert, opt, bert_config, tokenizer,
          max_seq_length, num_target_layers, accumulate_gradients=1, check_grad=False,
          st_pos=0, opt_bert=None, path_db=None, dset_name='train',
          col_pool_type='start_tok', aug=False):
    """Run one supervised training epoch and report running accuracies.

    Gradients are accumulated over ``accumulate_gradients`` batches before the
    optimizer(s) step. Batches whose ground-truth where-value cannot be mapped
    onto the BERT tokens are skipped (but still counted in the denominator).

    Args:
        st_pos: batches are skipped until the running example count reaches
            this value.
        opt_bert: optional second optimizer stepped alongside ``opt``.
        check_grad: when True, gradient statistics are computed with
            ``get_mean_grad`` after each backward pass.
        col_pool_type: forwarded to ``get_wemb_h_FT_Scalar_1``.
        aug: when True, execution accuracy is not computed (counted as 0).

    Returns:
        ``(acc, aux_out)``:

        - ``acc`` is ``[ave_loss, acc_sc, acc_sa, acc_wn, acc_wc, acc_wo,
          acc_wvi, acc_wv, acc_lx, acc_x]``.
        - ``aux_out`` is ``[grad_abs_mean_mean, grad_abs_mean_sig,
          grad_abs_sig_mean]`` from the last processed batch. All three are 1
          when ``check_grad`` is False or no batch was processed.
    """
    model.train()
    model_bert.train()

    ave_loss = 0
    cnt = 0  # count the # of examples
    cnt_sc = 0  # count the # of correct predictions of select column
    cnt_sa = 0  # of selected aggregation
    cnt_wn = 0  # of where number
    cnt_wc = 0  # of where column
    cnt_wo = 0  # of where operator
    cnt_wv = 0  # of where-value
    cnt_wvi = 0  # of where-value index (on question tokens)
    cnt_lx = 0  # of logical form acc
    cnt_x = 0  # of execution acc

    # Gradient statistics returned in aux_out. They are initialised here so
    # the return works even when no batch reaches the check below.
    grad_abs_mean_mean = 1
    grad_abs_mean_sig = 1
    grad_abs_sig_mean = 1

    # Engine for SQL querying.
    engine = DBEngine(os.path.join(path_db, f"{dset_name}.db"))

    for iB, t in enumerate(train_loader):
        cnt += len(t)
        if cnt < st_pos:
            continue

        # Get fields
        nlu, nlu_t, sql_i, sql_q, sql_t, tb, hs_t, hds = get_fields(
            t, train_table, no_hs_t=True, no_sql_t=True)
        g_sc, g_sa, g_wn, g_wc, g_wo, g_wv = get_g(sql_i)
        # Ground-truth where-value index under the CoreNLP tokenization scheme
        # (already computed on the train set).
        g_wvi_corenlp = get_g_wvi_corenlp(t)

        (all_encoder_layer, pooled_output, tokens, i_nlu, i_hds,
         l_n, l_hpu, l_hs,
         nlu_tt, t_to_tt_idx, tt_to_t_idx) = get_bert_output(
            model_bert, tokenizer, nlu_t, hds, max_seq_length)

        try:
            g_wvi = get_g_wvi_bert_from_g_wvi_corenlp(t_to_tt_idx, g_wvi_corenlp)
        except Exception:
            # The where-condition was not found in nlu_tt; this training batch
            # is not used.
            continue

        wemb_n = get_wemb_n(i_nlu, l_n, bert_config.hidden_size,
                            bert_config.num_hidden_layers, all_encoder_layer, 1)
        wemb_h = get_wemb_h_FT_Scalar_1(i_hds, l_hs, bert_config.hidden_size,
                                        all_encoder_layer,
                                        col_pool_type=col_pool_type)
        # wemb_h = [B, max_header_number, hS]
        cls_vec = pooled_output

        # Model scores (teacher-forced with the ground truth).
        s_sc, s_sa, s_wn, s_wc, s_wo, s_wv = model(
            wemb_n, l_n, wemb_h, l_hs, cls_vec,
            g_sc=g_sc, g_sa=g_sa, g_wn=g_wn, g_wc=g_wc, g_wo=g_wo, g_wvi=g_wvi)

        # Calculate loss & step
        loss = Loss_sw_se(s_sc, s_sa, s_wn, s_wc, s_wo, s_wv,
                          g_sc, g_sa, g_wn, g_wc, g_wo, g_wvi)

        # Gradient accumulation: zero at the start of a window, step at its end.
        if iB % accumulate_gradients == 0:
            opt.zero_grad()
            if opt_bert:
                opt_bert.zero_grad()
            loss.backward()
            if accumulate_gradients == 1:
                opt.step()
                if opt_bert:
                    opt_bert.step()
        elif iB % accumulate_gradients == (accumulate_gradients - 1):
            loss.backward()
            opt.step()
            if opt_bert:
                opt_bert.step()
        else:
            loss.backward()

        if check_grad:
            mu_list, sig_list = get_mean_grad(model.named_parameters())
            grad_abs_mean_mean = mean(mu_list)
            grad_abs_mean_sig = std(mu_list)
            grad_abs_sig_mean = mean(sig_list)

        # Prediction
        pr_sc, pr_sa, pr_wn, pr_wc, pr_wo, pr_wvi = pred_sw_se(
            s_sc, s_sa, s_wn, s_wc, s_wo, s_wv)
        pr_wv_str, _ = convert_pr_wvi_to_string(
            pr_wvi, nlu_t, nlu_tt, tt_to_t_idx, nlu)
        pr_sql_i = generate_sql_i(pr_sc, pr_sa, pr_wn, pr_wc, pr_wo,
                                  pr_wv_str, nlu)

        # Calculate accuracy
        (cnt_sc1_list, cnt_sa1_list, cnt_wn1_list, cnt_wc1_list, cnt_wo1_list,
         cnt_wvi1_list, cnt_wv1_list) = get_cnt_sw_list(
            g_sc, g_sa, g_wn, g_wc, g_wo, g_wvi,
            pr_sc, pr_sa, pr_wn, pr_wc, pr_wo, pr_wvi,
            sql_i, pr_sql_i, mode='train')
        # lx stands for logical form accuracy
        cnt_lx1_list = get_cnt_lx_list(cnt_sc1_list, cnt_sa1_list, cnt_wn1_list,
                                       cnt_wc1_list, cnt_wo1_list, cnt_wv1_list)

        # Execution accuracy test (not available for augmented data).
        if not aug:
            cnt_x1_list, _, _ = get_cnt_x_list(
                engine, tb, g_sc, g_sa, sql_i, pr_sc, pr_sa, pr_sql_i)
        else:
            cnt_x1_list = [0] * len(t)

        # statistics
        ave_loss += loss.item()

        # count
        cnt_sc += sum(cnt_sc1_list)
        cnt_sa += sum(cnt_sa1_list)
        cnt_wn += sum(cnt_wn1_list)
        cnt_wc += sum(cnt_wc1_list)
        cnt_wo += sum(cnt_wo1_list)
        cnt_wvi += sum(cnt_wvi1_list)
        cnt_wv += sum(cnt_wv1_list)
        cnt_lx += sum(cnt_lx1_list)
        cnt_x += sum(cnt_x1_list)

    # Guard against an empty loader: return zeros instead of dividing by zero.
    denom = cnt if cnt > 0 else 1
    # acc_wvi now uses the where-value-*index* counter (it previously reused
    # cnt_wv, duplicating acc_wv).
    acc = [
        value / denom
        for value in (ave_loss, cnt_sc, cnt_sa, cnt_wn, cnt_wc, cnt_wo,
                      cnt_wvi, cnt_wv, cnt_lx, cnt_x)
    ]

    aux_out = [grad_abs_mean_mean, grad_abs_mean_sig, grad_abs_sig_mean]

    return acc, aux_out
def infer(nlu1, table_name, data_table, path_db, db_name, model, model_bert,
          bert_config, max_seq_length, num_target_layers, beam_size=4,
          show_table=False, show_answer_only=False):
    """Answer one natural-language question with execution-guided decoding.

    This knowingly duplicates part of the evaluation loop (against DRY) to
    minimise the risk of introducing bugs there.

    It relies on the module-level ``client`` (used for CoreNLP tokenization)
    and ``tokenizer`` (passed to ``get_wemb_bert``) globals. Only the first
    table of ``data_table`` is used.

    Args:
        nlu1: the question string.
        table_name: passed to ``engine.show_table`` when ``show_table`` is True.
        path_db, db_name: the database file is ``<path_db>/<db_name>.db``.
        beam_size: beam width for ``model.beam_forward``.
        show_table: also print the table contents.
        show_answer_only: print only question, answer and SQL.

    Returns:
        ``(pr_sql_i, pr_ans)``: the predicted SQL dicts (exactly one) and the
        execution result. If execution fails, ``pr_ans`` is
        ``['Answer not found.']``.

    Raises:
        EnvironmentError: if decoding did not yield exactly one SQL prediction.
    """
    model.eval()
    model_bert.eval()

    engine = DBEngine(os.path.join(path_db, f"{db_name}.db"))

    # Get inputs (a batch of one question against one table).
    nlu = [nlu1]
    nlu_t = [tokenize_corenlp_direct_version(client, nlu1)]

    tb1 = data_table[0]
    tb = [tb1]
    hds = [tb1['header']]

    (wemb_n, wemb_h, l_n, l_hpu, l_hs,
     nlu_tt, t_to_tt_idx, tt_to_t_idx) = get_wemb_bert(
        bert_config, model_bert, tokenizer, nlu_t, hds, max_seq_length,
        num_out_layers_n=num_target_layers, num_out_layers_h=num_target_layers)

    _, _, _, pr_sc, pr_sa, pr_wn, pr_sql_i = model.beam_forward(
        wemb_n, l_n, wemb_h, l_hpu, l_hs, engine, tb, nlu_t, nlu_tt,
        tt_to_t_idx, nlu, beam_size=beam_size)

    # sort and generate
    _, _, _, pr_sql_i = sort_and_generate_pr_w(pr_sql_i)
    if len(pr_sql_i) != 1:
        raise EnvironmentError(
            f"expected exactly one predicted SQL, got {len(pr_sql_i)}")
    pr_sql_q = [generate_sql_q(pr_sql_i, [tb1])]

    try:
        pr_ans, _ = engine.execute_return_query(
            tb[0]['id'], pr_sc[0], pr_sa[0], pr_sql_i[0]['conds'])
    except Exception:
        # Execution failures are reported to the user rather than raised.
        pr_ans = ['Answer not found.']
        pr_sql_q = ['Answer not found.']

    if show_answer_only:
        print(f'Q: {nlu[0]}')
        print(f'A: {pr_ans[0]}')
        print(f'SQL: {pr_sql_q}')
    else:
        print('START ============================================================= ')
        print(f'{hds}')
        if show_table:
            print(engine.show_table(table_name))
        print(f'nlu: {nlu}')
        print(f'pr_sql_i : {pr_sql_i}')
        print(f'pr_sql_q : {pr_sql_q}')
        print(f'pr_ans: {pr_ans}')
        print('---------------------------------------------------------------------')

    return pr_sql_i, pr_ans
print(idx) return gen_sqls if __name__ == "__main__": root = '/mnt/sda/qhz/sqlova' query_path = os.path.join(root, 'data_and_model', 'train_tok_origin.jsonl') table_path = os.path.join(root, 'data_and_model', 'train.tables.jsonl') p_sqls_path = os.path.join(root, 'data/distant_data', 'train_distant.jsonl') queries = extract.read_queries(query_path) p_sqlss = extract.read_potential_sqls(p_sqls_path) answer_path = './syn.txt' g_answers = extract.read_gold_answers(answer_path) print(len(g_answers)) engine = DBEngine_s('./data_and_model/train.db') rr_p_sqlss = [] for i, p_sqls in enumerate(tqdm(p_sqlss)): rr_p_sqls = [] if (len(p_sqls)) < 3: rr_p_sqls = [query['query'] for query in p_sqls] else: for p_sql in p_sqls: qg = Query.from_dict(p_sql['query'], ordered=True) res = engine.execute_query(queries[i]['table_id'], qg, lower=True) if res == g_answers[i]: rr_p_sqls.append(p_sql['query']) if len(rr_p_sqls) == 0: print(f"{i}\n")
def _random_span(question_len):
    """Sample a random ``[start, end]`` token span with ``start < end``.

    Uses rejection sampling over uniform ``(start, end)`` draws, exactly as
    the original inline loop did. A question shorter than two tokens admits
    no such span; the old loop never terminated in that case, so the
    single-token span ``[0, 0]`` is returned instead.
    """
    if question_len < 2:
        return [0, 0]
    start = end = 0
    while start >= end:
        start = random.randrange(question_len)
        end = random.randrange(question_len)
    return [start, end]


def _one_hot(index, size):
    """Return a list of ``size`` zeros with a 1 at ``index``."""
    vec = [0] * size
    vec[index] = 1
    return vec


def _answer_hit(gold_answers, pred_answers):
    """Return True if any gold answer cell occurs among the executed cells."""
    return any(
        cell in pred_answers or any(pred == cell for pred in pred_answers)
        for cell in gold_answers)


def _reinforce_loss(scores, actions, rewards):
    """Mean of ``-log p(action) * reward`` over all leading axes.

    ``actions`` is a nested list of one-hot vectors matching ``scores``' shape.
    ``rewards`` holds one value per example and is broadcast over any
    per-condition axes between the batch and class axes.
    """
    probs = torch.softmax(scores, dim=-1)
    action_mask = torch.tensor(actions, dtype=torch.float, device=scores.device)
    picked = torch.sum(probs * action_mask, dim=-1)
    reward = torch.tensor(rewards, dtype=torch.float, device=scores.device)
    reward = reward.view(-1, *([1] * (picked.dim() - 1)))
    return torch.mean(-torch.log(picked) * reward)


def _sample_random_actions(header, len_question, batch_size):
    """Sample one uniformly random SQL skeleton per example for exploration.

    Returns ``(sel_cols, where_nums, sel_aggs, where_ops, where_cols,
    where_spans)``. Each example gets:

    - a select column;
    - a where-condition count in ``1..4``;
    - an aggregation index in ``0..5``;
    - for each of 4 condition slots, an operator index in ``0..2``, a where
      column and a ``[start, end]`` question-token span.
    """
    sel_cols = [random.randrange(len(header[b])) for b in range(batch_size)]
    where_nums = [random.randint(1, 4) for _ in range(batch_size)]
    sel_aggs = [random.randrange(6) for _ in range(batch_size)]
    where_ops = [[random.randrange(3) for _ in range(4)] for _ in range(batch_size)]
    where_cols = [[random.randrange(len(header[b])) for _ in range(4)]
                  for b in range(batch_size)]
    where_spans = [[_random_span(len_question[b]) for _ in range(4)]
                   for b in range(batch_size)]
    return sel_cols, where_nums, sel_aggs, where_ops, where_cols, where_spans


def _random_sql_rewards(engine, batch_data, table, sel_cols, sel_aggs,
                        where_nums, where_cols, where_ops, random_sql):
    """Execute each sampled query; reward 1 if it returns a gold answer cell, else 0."""
    rewards = []
    for b, example in enumerate(batch_data):
        conds = [[where_cols[b][j], where_ops[b][j], random_sql[b]["conds"][j][2]]
                 for j in range(where_nums[b])]
        pred_answers = engine.execute(table[b]['id'], sel_cols[b], sel_aggs[b], conds)
        rewards.append(1 if _answer_hit(example["answer"], pred_answers) else 0)
    return rewards


def train(train_loader, train_table, model, model_bert, opt, bert_config, tokenizer,
          max_seq_length, num_target_layers, accumulate_gradients=1, check_grad=True,
          st_pos=0, opt_bert=None, path_db=None, dset_name='train'):
    """Run one training epoch: supervised loss plus REINFORCE-style losses.

    For each batch, the supervised ``Loss_selectwhere_startend_v2`` is
    combined with policy-gradient losses for every SQL component. Those
    losses are computed from uniformly random SQL queries, which are executed
    and rewarded 1 when they return a gold answer cell. Random queries are
    resampled until at least one example in the batch earns a reward.

    Args:
        st_pos: batches are skipped until the running example count reaches
            this value.
        opt_bert: optional second optimizer stepped alongside ``opt``.
        check_grad: accepted for signature compatibility; unused.

    Returns:
        ``(acc, 1)``, where ``acc`` is ``[ave_loss, acc_sc, acc_sa, acc_wn,
        acc_wc, acc_wo, acc_wvi, acc_wv, acc_lx, acc_x]``.
    """
    model.train()
    model_bert.train()

    ave_loss = 0
    count = 0  # count the # of examples
    count_sc = 0  # count the # of correct predictions of select column
    count_sa = 0  # of selected aggregation
    count_wn = 0  # of where number
    count_wc = 0  # of where column
    count_wo = 0  # of where operator
    count_wv = 0  # of where-value
    count_wvi = 0  # of where-value index (on question tokens)
    count_logic_form_acc = 0  # of logical form acc
    count_execute_acc = 0  # of execution acc

    # Engine for SQL querying.
    engine = DBEngine(os.path.join(path_db, f"{dset_name}.db"))

    for batch_index, batch_data in enumerate(train_loader):
        count += len(batch_data)
        if count < st_pos:
            continue

        # Get fields
        question, question_token, sql, sql_text, sql_t, table, header_token, header \
            = get_fields(batch_data, train_table, no_hs_t=True, no_sql_t=True)
        gt_select_column, gt_select_agg, gt_wherenumber, gt_wherecolumn, g_wo, g_wv = get_gt(sql)
        # Ground-truth where-value index under the CoreNLP tokenization scheme
        # (already computed on the train set).
        gt_wherevalueindex_corenlp = get_gt_wherevalueindex_corenlp(batch_data)

        (emb_question, emb_header, len_question, len_header_token, number_header,
         question_token_bert, token_to_berttoken_index,
         berttoken_to_token_index) = get_wemb_bert(
            bert_config, model_bert, tokenizer, question_token, header,
            max_seq_length, num_out_layers_n=num_target_layers,
            num_out_layers_h=num_target_layers)

        try:
            gt_wherevalueindex = get_gt_wherevalueindex_bert_from_gt_wherevalueindex_corenlp(
                token_to_berttoken_index, gt_wherevalueindex_corenlp)
        except Exception:
            # The where-condition was not found in the BERT tokens; this
            # training batch is not used.
            continue

        # Scores; gt_wherevalueindex holds (start_index, end_index) pairs.
        # score_where_value is documented upstream as [batch, 4, question_len, 2].
        (score_select_column, score_select_agg, score_where_number,
         score_where_column, score_where_op, score_where_value,
         score_select_column_softmax, score_select_agg_softmax,
         score_where_number_softmax, score_where_column_softmax,
         score_whereop_softmax, score_where_value_softmax) = model(
            emb_question, len_question, emb_header, len_header_token, number_header,
            g_sc=gt_select_column, g_sa=gt_select_agg, g_wn=gt_wherenumber,
            g_wc=gt_wherecolumn, g_wvi=gt_wherevalueindex)

        # Supervised loss
        loss = Loss_selectwhere_startend_v2(
            score_select_column, score_select_agg, score_where_number,
            score_where_column, score_where_op, score_where_value,
            gt_select_column, gt_select_agg, gt_wherenumber, gt_wherecolumn,
            g_wo, gt_wherevalueindex)

        # Model-sampled select column / aggregation; only these two outputs
        # feed the random SQL below.
        pred_selectcolumn_random, pred_selectagg_random, _, _, _, _ = \
            pred_selectwhere_startend_random(
                score_select_column_softmax, score_select_agg_softmax,
                score_where_number_softmax, score_where_column_softmax,
                score_whereop_softmax, score_where_value_softmax)

        # Prediction
        pr_sc, pr_sa, pr_wn, pr_wc, pr_wo, pr_wvi = pred_selectwhere_startend(
            score_select_column, score_select_agg, score_where_number,
            score_where_column, score_where_op, score_where_value)
        pred_wherevalue_str, _ = convert_pred_wvi_to_string(
            pr_wvi, question_token, question_token_bert,
            berttoken_to_token_index, question)
        # Sort pr_wc during training because pr_wo and pr_wvi are predicted
        # using the ground-truth where-column (g_wc). Not needed for dev/test,
        # where the ground truth is not used during inference.
        pr_wc_sorted = sort_pred_wherecolumn(pr_wc, gt_wherecolumn)
        pred_sql_int = generate_sql_i(pr_sc, pr_sa, pr_wn, pr_wc_sorted, pr_wo,
                                      pred_wherevalue_str, question)

        # RL exploration: resample random SQL until at least one example in the
        # batch earns a positive reward.
        # NOTE(review): this never terminates if no random query can hit a gold
        # answer (e.g. the answer is absent from the table); consider a cap.
        batch_size = score_where_value_softmax.shape[0]
        while True:
            (selectcolumn_random_v2, wherenumber_random_v2, selectagg_random_v2,
             whereop_random_v2, wherecolumn_random_v2,
             pred_wherevalueindex_random_v2) = _sample_random_actions(
                header, len_question, batch_size)
            pred_wherevalue_str_random_v2, _ = convert_pred_wvi_to_string(
                pred_wherevalueindex_random_v2, question_token, question_token_bert,
                berttoken_to_token_index, question)
            random_sql_int_v2 = generate_sql_i_v2(
                pred_selectcolumn_random, pred_selectagg_random,
                wherenumber_random_v2, wherecolumn_random_v2, whereop_random_v2,
                pred_wherevalue_str_random_v2, question)
            batch_reward_all = _random_sql_rewards(
                engine, batch_data, table, selectcolumn_random_v2,
                selectagg_random_v2, wherenumber_random_v2, wherecolumn_random_v2,
                whereop_random_v2, random_sql_int_v2)
            if any(reward > 0 for reward in batch_reward_all):
                break

        # REINFORCE losses, one per SQL component.
        # NOTE(review): a sampled count n is one-hot encoded at index n - 1; if
        # score_where_number's index k means "k conditions" (as in the
        # supervised loss), this is off by one -- confirm against the model.
        RL_loss_where_number = _reinforce_loss(
            score_where_number,
            [_one_hot(num - 1, score_where_number.shape[1])
             for num in wherenumber_random_v2],
            batch_reward_all)
        RL_loss_select_column = _reinforce_loss(
            score_select_column,
            [_one_hot(col, score_select_column.shape[1])
             for col in selectcolumn_random_v2],
            batch_reward_all)
        RL_loss_select_agg = _reinforce_loss(
            score_select_agg,
            [_one_hot(agg, score_select_agg.shape[1])
             for agg in selectagg_random_v2],
            batch_reward_all)
        # Only the first sampled where-column is reinforced.
        RL_loss_where_column = _reinforce_loss(
            score_where_column,
            [_one_hot(cols[0], score_where_column.shape[1])
             for cols in wherecolumn_random_v2],
            batch_reward_all)
        # Start/end one-hots per condition slot. Dims 2 and 3 are swapped so
        # the softmax runs over question positions (dim 2 of the original).
        question_len = score_where_value.shape[2]
        RL_loss_where_value = _reinforce_loss(
            torch.transpose(score_where_value, 2, 3),
            [[[_one_hot(start, question_len), _one_hot(end, question_len)]
              for start, end in spans]
             for spans in pred_wherevalueindex_random_v2],
            batch_reward_all)
        # One operator one-hot per condition slot. Previously every slot reused
        # the first slot's operator.
        n_ops = score_where_op.shape[2]
        RL_loss_where_op = _reinforce_loss(
            score_where_op,
            [[_one_hot(op, n_ops) for op in ops] for ops in whereop_random_v2],
            batch_reward_all)

        loss += RL_loss_where_number
        loss += RL_loss_select_agg
        loss += RL_loss_select_column
        loss += RL_loss_where_op
        loss += RL_loss_where_column
        loss += RL_loss_where_value

        # Gradient accumulation: zero at the start of a window, step at its end.
        if batch_index % accumulate_gradients == 0:
            opt.zero_grad()
            if opt_bert:
                opt_bert.zero_grad()
            loss.backward()
            if accumulate_gradients == 1:
                opt.step()
                if opt_bert:
                    opt_bert.step()
        elif batch_index % accumulate_gradients == (accumulate_gradients - 1):
            loss.backward()
            opt.step()
            if opt_bert:
                opt_bert.step()
        else:
            loss.backward()

        # Calculate accuracy
        (count_sc1_list, count_sa1_list, count_wn1_list, count_wc1_list,
         count_wo1_list, count_wvi1_list, count_wv1_list) = get_count_sw_list(
            gt_select_column, gt_select_agg, gt_wherenumber, gt_wherecolumn, g_wo,
            gt_wherevalueindex,
            pr_sc, pr_sa, pr_wn, pr_wc, pr_wo, pr_wvi,
            sql, pred_sql_int, mode='train')
        # lx stands for logical form accuracy
        count_lx1_list = get_count_lx_list(count_sc1_list, count_sa1_list,
                                           count_wn1_list, count_wc1_list,
                                           count_wo1_list, count_wv1_list)
        # Execution accuracy test.
        count_x1_list, _, _ = get_count_x_list(
            engine, table, gt_select_column, gt_select_agg, sql, pr_sc, pr_sa,
            pred_sql_int)

        # statistics
        ave_loss += loss.item()

        # count
        count_sc += sum(count_sc1_list)
        count_sa += sum(count_sa1_list)
        count_wn += sum(count_wn1_list)
        count_wc += sum(count_wc1_list)
        count_wo += sum(count_wo1_list)
        count_wvi += sum(count_wvi1_list)
        count_wv += sum(count_wv1_list)
        count_logic_form_acc += sum(count_lx1_list)
        count_execute_acc += sum(count_x1_list)

    # Guard against an empty loader: return zeros instead of dividing by zero.
    denom = count if count > 0 else 1
    # acc_wvi now uses the where-value-*index* counter (it previously reused
    # count_wv, duplicating acc_wv).
    acc = [
        value / denom
        for value in (ave_loss, count_sc, count_sa, count_wn, count_wc, count_wo,
                      count_wvi, count_wv, count_logic_form_acc, count_execute_acc)
    ]

    aux_out = 1
    return acc, aux_out
def test(data_loader, data_table, model, model_bert, bert_config, tokenizer,
         max_seq_length, num_target_layers, detail=False, st_pos=0,
         cnt_tot=1, EG=False, beam_size=4, path_db=None, dset_name='test',
         mvl=2):
    """Predict a SQL query for every example in ``data_loader``.

    Both decoding modes first run the plain decoder. With ``EG`` true the
    prediction is then refined by ``model.beam_forward``, which also receives
    the SQL engine (execution-guided decoding).

    Args:
        data_loader: iterable of batches of examples.
        data_table: table lookup forwarded to ``get_fields``.
        model, model_bert: decoder and BERT encoder; both are switched to eval mode.
        bert_config, tokenizer, max_seq_length, num_target_layers: forwarded
            to ``get_wemb_bert``.
        detail, st_pos, cnt_tot: accepted for interface compatibility; unused.
        EG: enable execution-guided refinement via ``model.beam_forward``.
        beam_size: beam width passed to ``model.beam_forward``.
        path_db, dset_name: the engine is opened on
            ``<path_db>/<dset_name>/<dset_name>.db``.
        mvl: forwarded to the model, ``pred_sw_se`` and ``beam_forward``.

    Returns:
        list: one predicted SQL entry per example, in loader order.
    """
    model.eval()
    model_bert.eval()

    engine = DBEngine(os.path.join(path_db, dset_name, f"{dset_name}.db"))
    results = []

    # Inference only: disable autograd so no computation graphs are kept.
    with torch.no_grad():
        for iB, t in enumerate(data_loader):
            nlu, nlu_t, sql_i, sql_q, sql_t, tb, hs_t, hds = get_fields(
                t, data_table, no_hs_t=True, no_sql_t=True, generate_mode=True)

            wemb_n, wemb_h, l_n, l_hpu, l_hs, \
                nlu_tt, t_to_tt_idx, tt_to_t_idx, wemb_v, l_npu, l_token \
                = get_wemb_bert(bert_config, model_bert, tokenizer, nlu_t, hds,
                                max_seq_length,
                                num_out_layers_n=num_target_layers,
                                num_out_layers_h=num_target_layers,
                                num_out_layers_v=num_target_layers)

            # Plain decoding; its output is also the seed for EG refinement.
            s_sn, s_sc, s_sa, s_wn, s_wr, s_hrpc, s_wc, s_wo, \
                s_wv1, s_wv2, s_wv3, s_wv4 = model(
                    mvl, wemb_n, l_n, wemb_h, l_hpu, l_hs, wemb_v, l_npu, l_token)
            pr_sn, pr_sc, pr_sa, pr_wn, pr_wr, pr_hrpc, pr_wc, pr_wo, pr_wvi \
                = pred_sw_se(s_sn, s_sc, s_sa, s_wn, s_wr, s_hrpc, s_wc, s_wo,
                             s_wv1, s_wv2, s_wv3, s_wv4, mvl)
            pr_wvi_decode = g_wvi_decoder_stidx_length_jian_yi(pr_wvi)
            pr_wv_str, pr_wv_str_wp = convert_pr_wvi_to_string(
                pr_wvi_decode, nlu_t, nlu_tt, tt_to_t_idx)
            pr_sql_i = generate_sql_i(pr_sc, pr_sa, pr_wn, pr_wr, pr_wc, pr_wo,
                                      pr_wv_str, nlu)

            if EG:
                # Execution-guided refinement of the plain prediction.
                pr_sql_i, exe_error1, still_error1 = model.beam_forward(
                    pr_sql_i, mvl, wemb_n, l_n, wemb_h, l_hpu, l_hs, wemb_v,
                    l_npu, l_token, engine, tb, nlu_t, beam_size=beam_size)

            # Saved for the official evaluation later.
            results.extend(pr_sql_i)

    return results
def test(data_loader, data_table, model, model_bert, bert_config, tokenizer,
         max_seq_length, num_target_layers, detail=False, st_pos=0,
         cnt_tot=1, EG=False, beam_size=4, path_db=None, dset_name='test',
         col_pool_type='start_tok', aug=False):
    """Evaluate the model on ``data_loader`` and collect per-example predictions.

    Args:
        data_loader: iterable of batches of examples.
        data_table: table lookup forwarded to ``get_fields``.
        model, model_bert: decoder and BERT encoder; both are switched to eval mode.
        bert_config: its ``hidden_size`` and ``num_hidden_layers`` are read.
        tokenizer, max_seq_length: forwarded to ``get_bert_output``.
        num_target_layers: not used by this variant.
        detail: when true, ``report_detail`` is called for every batch.
        st_pos: batches are skipped while the running example count is below it.
        cnt_tot: forwarded to ``report_detail`` only.
        EG: use execution-guided decoding (``model.forward_EG``).
        beam_size: beam width for execution-guided decoding.
        path_db, dset_name: the engine is opened on ``<path_db>/<dset_name>.db``.
        col_pool_type: header pooling mode for ``get_wemb_h_FT_Scalar_1``.
        aug: when true, execution accuracy is not computed (counted as 0).

    Returns:
        ``(acc, results, cnt_list, p_list, data_list)``. ``acc`` is
        ``[ave_loss, acc_sc, acc_sa, acc_wn, acc_wc, acc_wo, acc_wvi, acc_wv,
        acc_lx, acc_x]``. ``results`` holds one dict per example (a prediction
        or a skip marker). The other three lists hold per-batch diagnostics.
    """
    model.eval()
    model_bert.eval()

    ave_loss = 0
    cnt = 0
    cnt_sc = 0
    cnt_sa = 0
    cnt_wn = 0
    cnt_wc = 0
    cnt_wo = 0
    cnt_wv = 0
    cnt_wvi = 0
    cnt_lx = 0
    cnt_x = 0

    cnt_list = []
    p_list = []  # per-batch prediction probabilities
    data_list = []  # per-batch miscellaneous data, kept for later analysis

    engine = DBEngine(os.path.join(path_db, f"{dset_name}.db"))
    results = []

    # Evaluation only: disable autograd so no computation graphs are kept.
    with torch.no_grad():
        for iB, t in enumerate(data_loader):
            cnt += len(t)
            if cnt < st_pos:
                continue

            # Get fields
            nlu, nlu_t, sql_i, sql_q, sql_t, tb, hs_t, hds = get_fields(
                t, data_table, no_hs_t=True, no_sql_t=True)

            g_sc, g_sa, g_wn, g_wc, g_wo, g_wv = get_g(sql_i)
            g_wvi_corenlp = get_g_wvi_corenlp(t)

            all_encoder_layer, pooled_output, tokens, i_nlu, i_hds, \
                l_n, l_hpu, l_hs, \
                nlu_tt, t_to_tt_idx, tt_to_t_idx \
                = get_bert_output(model_bert, tokenizer, nlu_t, hds, max_seq_length)

            try:
                g_wvi = get_g_wvi_bert_from_g_wvi_corenlp(t_to_tt_idx, g_wvi_corenlp)
                g_wv_str, g_wv_str_wp = convert_pr_wvi_to_string(
                    g_wvi, nlu_t, nlu_tt, tt_to_t_idx, nlu)
            except Exception:
                # Raised when a where-condition value is not found in nlu_tt.
                # The batch is recorded as skipped; because ``cnt`` already
                # includes it, its examples count as wrongly answered.
                for b in range(len(nlu)):
                    results.append({
                        "error": "Skip happened",
                        "nlu": nlu[b],
                        "table_id": tb[b]["id"],
                    })
                continue

            # Model-specific scoring.
            wemb_n = get_wemb_n(i_nlu, l_n, bert_config.hidden_size,
                                bert_config.num_hidden_layers, all_encoder_layer, 1)
            wemb_h = get_wemb_h_FT_Scalar_1(i_hds, l_hs, bert_config.hidden_size,
                                            all_encoder_layer,
                                            col_pool_type=col_pool_type)
            cls_vec = pooled_output

            if not EG:
                # No execution-guided decoding.
                s_sc, s_sa, s_wn, s_wc, s_wo, s_wv = model(wemb_n, l_n, wemb_h,
                                                           l_hs, cls_vec)
                loss = Loss_sw_se(s_sc, s_sa, s_wn, s_wc, s_wo, s_wv,
                                  g_sc, g_sa, g_wn, g_wc, g_wo, g_wvi)
                pr_sc, pr_sa, pr_wn, pr_wc, pr_wo, pr_wvi = pred_sw_se(
                    s_sc, s_sa, s_wn, s_wc, s_wo, s_wv)
                pr_wv_str, pr_wv_str_wp = convert_pr_wvi_to_string(
                    pr_wvi, nlu_t, nlu_tt, tt_to_t_idx, nlu)
                pr_sql_i = generate_sql_i(pr_sc, pr_sa, pr_wn, pr_wc, pr_wo,
                                          pr_wv_str, nlu)
                p_tot, p_select, p_where, p_sc, p_sa, p_wn, p_wc, p_wo, p_wvi \
                    = cal_prob(s_sc, s_sa, s_wn, s_wc, s_wo, s_wv,
                               pr_sc, pr_sa, pr_wn, pr_wc, pr_wo, pr_wvi)
            else:
                # Execution-guided decoding.
                pr_sc_best, pr_sa_best, pr_wn_based_on_prob, pr_wvi_best, \
                    pr_sql_i, p_tot, p_select, p_where, p_sc_best, p_sa_best, \
                    p_wn_best, p_wc_best, p_wo_best, p_wvi_best \
                    = model.forward_EG(wemb_n, l_n, wemb_h, l_hs, cls_vec, engine,
                                       tb, nlu_t, nlu_tt, tt_to_t_idx, nlu,
                                       beam_size=beam_size)
                pr_sc = pr_sc_best
                pr_sa = pr_sa_best
                pr_wn = pr_wn_based_on_prob
                p_sc = p_sc_best
                p_sa = p_sa_best
                p_wn = p_wn_best
                # Sort by probability (descending), then by where-column
                # index (ascending).
                pr_wc, pr_wo, pr_wv_str, pr_wvi, pr_sql_i, \
                    p_wc, p_wo, p_wvi = sort_and_generate_pr_w(
                        pr_sql_i, pr_wvi_best, p_wc_best, p_wo_best, p_wvi_best)
                # Placeholders kept for consistency with the non-EG branch.
                pr_wv_str_wp = None
                loss = torch.tensor([0])

            p_list.append([
                p_tot, p_select, p_where, p_sc, p_sa, p_wn, p_wc, p_wo, p_wvi
            ])

            g_sql_q = generate_sql_q(sql_i, tb)
            pr_sql_q = generate_sql_q(pr_sql_i, tb)

            # Saved for the official evaluation later.
            for b, pr_sql_i1 in enumerate(pr_sql_i):
                results.append({
                    "query": pr_sql_i1,
                    "table_id": tb[b]["id"],
                    "nlu": nlu[b],
                })

            cnt_sc1_list, cnt_sa1_list, cnt_wn1_list, \
                cnt_wc1_list, cnt_wo1_list, \
                cnt_wvi1_list, cnt_wv1_list = get_cnt_sw_list(
                    g_sc, g_sa, g_wn, g_wc, g_wo, g_wvi,
                    pr_sc, pr_sa, pr_wn, pr_wc, pr_wo, pr_wvi,
                    sql_i, pr_sql_i, mode='test')

            # lx stands for logical form accuracy.
            cnt_lx1_list = get_cnt_lx_list(cnt_sc1_list, cnt_sa1_list,
                                           cnt_wn1_list, cnt_wc1_list,
                                           cnt_wo1_list, cnt_wv1_list)

            # Execution accuracy test (skipped for augmented data).
            if not aug:
                cnt_x1_list, g_ans, pr_ans = get_cnt_x_list(
                    engine, tb, g_sc, g_sa, sql_i, pr_sc, pr_sa, pr_sql_i)
            else:
                cnt_x1_list = [0] * len(t)
                g_ans = ['N/A (data augmented'] * len(t)
                pr_ans = ['N/A (data augmented'] * len(t)

            # Statistics.
            ave_loss += loss.item()

            cnt_sc += sum(cnt_sc1_list)
            cnt_sa += sum(cnt_sa1_list)
            cnt_wn += sum(cnt_wn1_list)
            cnt_wc += sum(cnt_wc1_list)
            cnt_wo += sum(cnt_wo1_list)
            cnt_wv += sum(cnt_wv1_list)
            cnt_wvi += sum(cnt_wvi1_list)
            cnt_lx += sum(cnt_lx1_list)
            cnt_x += sum(cnt_x1_list)

            current_cnt = [
                cnt_tot, cnt, cnt_sc, cnt_sa, cnt_wn, cnt_wc, cnt_wo, cnt_wv,
                cnt_wvi, cnt_lx, cnt_x
            ]
            cnt_list_batch = [
                cnt_sc1_list, cnt_sa1_list, cnt_wn1_list, cnt_wc1_list,
                cnt_wo1_list, cnt_wv1_list, cnt_lx1_list, cnt_x1_list
            ]
            cnt_list.append(cnt_list_batch)

            if detail:
                report_detail(hds, nlu, g_sc, g_sa, g_wn, g_wc, g_wo, g_wv,
                              g_wv_str, g_sql_q, g_ans, pr_sc, pr_sa, pr_wn,
                              pr_wc, pr_wo, pr_wv_str, pr_sql_q, pr_ans,
                              cnt_list_batch, current_cnt)

            data_batch = []
            for b in range(len(nlu)):
                data_batch.append([
                    nlu[b], nlu_t[b], sql_i[b], g_sql_q[b], g_ans[b],
                    pr_sql_i[b], pr_sql_q[b], pr_ans[b], tb[b]
                ])
            data_list.append(data_batch)

    ave_loss /= cnt
    acc_sc = cnt_sc / cnt
    acc_sa = cnt_sa / cnt
    acc_wn = cnt_wn / cnt
    acc_wc = cnt_wc / cnt
    acc_wo = cnt_wo / cnt
    acc_wvi = cnt_wvi / cnt
    acc_wv = cnt_wv / cnt
    acc_lx = cnt_lx / cnt
    acc_x = cnt_x / cnt

    acc = [
        ave_loss, acc_sc, acc_sa, acc_wn, acc_wc, acc_wo, acc_wvi, acc_wv,
        acc_lx, acc_x
    ]
    return acc, results, cnt_list, p_list, data_list
_MAX_WHERE_CONDS = 4  # the search enumerates at most four where-conditions
_N_AGG = 6  # aggregation ids tried: 0..5
_N_OPS = 3  # where-operator ids tried: 0..2
_MAX_EXPLORE = 500000  # per-example cap on executed candidate queries


def _list_exact_match(input1, input2):
    """Return True if both lists hold the same items, compared as strings, ignoring order."""
    return sorted(str(item) for item in input1) == sorted(str(item) for item in input2)


def _contains(big_list, small_list):
    """Return True if every item of ``small_list`` occurs in ``big_list``."""
    return set(small_list).issubset(set(big_list))


def _tick_select(select_column, select_agg, where_number, n_cols):
    """Advance the (select column, aggregation, where-number) counter by one step.

    The select column is the fastest-changing digit. When it wraps around it
    carries into the aggregation id, which in turn carries into
    ``where_number``.

    Returns:
        The new ``(select_column, select_agg, where_number)``.
    """
    select_column += 1
    if select_column == n_cols:
        select_agg += 1
        select_column = 0
        if select_agg == _N_AGG:
            where_number += 1
            select_agg = 0
    return select_column, select_agg, where_number


def _tick_where_digits(first, last, cols, ops, starts, ends, len_question, n_cols):
    """Advance where-condition "digits" ``first..last-1`` like an odometer.

    Each digit ``k`` is a candidate condition with column ``cols[k]``,
    operator ``ops[k]`` and question-token span ``starts[k]:ends[k]``. The
    span end moves fastest, then the span start, then the operator, then the
    column. The lists are updated in place.

    Returns:
        True if the carry propagated out of digit ``last - 1``.
    """
    for k in range(first, last):
        ends[k] += 1
        if ends[k] < len_question + 1:
            return False
        starts[k] += 1
        ends[k] = starts[k] + 1
        if starts[k] < len_question:
            return False
        ops[k] += 1
        starts[k] = 0
        if ops[k] != _N_OPS:
            return False
        cols[k] += 1
        ops[k] = 0
        if cols[k] != n_cols:
            return False
        cols[k] = 0
    return True


def _span_value(question_tokens, question_text, start, end):
    """Return the where-value spelled by ``question_tokens[start:end]``.

    The tokens are merged back into text with ``merge_wv_t1_eng``. The result
    is converted to ``float`` when it parses as a number; otherwise the
    merged value is returned unchanged.
    """
    value = merge_wv_t1_eng(question_tokens[start:end], question_text)
    try:
        return float(value)
    except (TypeError, ValueError):
        return value


def _where_cond_reward(engine, table_id, cond, question_text, reward):
    """Add the heuristic plausibility reward of one where-condition to ``reward``.

    Adds 0.1 if some cell of the condition's column occurs in the question
    text. Adds a further 0.1 for each representation of the condition value
    (raw, ``float``, ``int``, ``str(int)``, ``str(float)``) found in that
    column. Increments are applied in place, in this fixed order.
    """
    column_cells = engine.execute(table_id, cond[0], 0, [])
    question_lower = question_text.lower()
    # Same column: does a word of the question appear in the where column?
    for cell in column_cells:
        try:
            if str(float(cell)) in question_lower:
                reward += 0.1
                break
            if str(int(cell)) in question_lower:
                reward += 0.1
                break
        except Exception:
            if str(cell) in question_lower:
                reward += 0.1
                break
    # Same column: does the where value appear in the where column?
    value = cond[2]
    if value in column_cells:
        reward += 0.1
    for convert in (float, int, lambda v: str(int(v)), lambda v: str(float(v))):
        try:
            if convert(value) in column_cells:
                reward += 0.1
        except Exception:
            pass
    return reward


def _explore_example(engine, table_id, question_tokens, question_text,
                     n_cols, gt_answer_list):
    """Brute-force search for a SQL query whose execution reproduces the answer.

    Candidates are enumerated like an odometer. First every (select column,
    aggregation) pair is tried with no where-condition, then with one, two,
    three and four conditions. Each condition ranges over (column, operator,
    question-token span). A candidate is accepted when its result matches
    ``gt_answer_list`` exactly and it passes the heuristic checks below.

    Returns:
        ``(sel, agg, conds)`` of the accepted query, or ``None`` when the
        search gives up. That happens on a ``[None]`` result, or after more
        than ``_MAX_EXPLORE`` candidates have been executed.
    """
    len_question = len(question_tokens)
    select_column, select_agg, where_number = -1, 0, 0
    cols = [0] * _MAX_WHERE_CONDS
    ops = [0] * _MAX_WHERE_CONDS
    starts = [0] * _MAX_WHERE_CONDS
    ends = [0] * _MAX_WHERE_CONDS
    # First condition of a one-condition candidate whose result contained the
    # answer. Once set, the two-condition search keeps condition 0 pinned to it.
    saved_first_cond = None
    # NOTE: these accumulate over all candidates of this example; they are
    # never reset between candidates (original behaviour).
    reward_where = False
    reward_where_conds = [0] * _MAX_WHERE_CONDS
    explore_count = 0

    while True:
        # Advance to the next candidate. The stages are deliberately not
        # elif-chained: a carry that bumps where_number falls through to the
        # next stage within the same step.
        if where_number == 0:
            select_column, select_agg, where_number = _tick_select(
                select_column, select_agg, where_number, n_cols)
        for n_conds in range(1, _MAX_WHERE_CONDS + 1):
            if where_number != n_conds:
                continue
            first = 0
            if n_conds == 2 and saved_first_cond is not None:
                cols[0], ops[0], starts[0], ends[0] = saved_first_cond
                first = 1
            if _tick_where_digits(first, n_conds, cols, ops, starts, ends,
                                  len_question, n_cols):
                select_column, select_agg, where_number = _tick_select(
                    select_column, select_agg, where_number, n_cols)

        tmp_conds = []
        if 1 <= where_number <= _MAX_WHERE_CONDS:
            tmp_conds = [
                [cols[k], ops[k],
                 _span_value(question_tokens, question_text, starts[k], ends[k])]
                for k in range(where_number)
            ]

        pred_answer_column = engine.execute(table_id, select_column,
                                            select_agg, tmp_conds)
        explore_count += 1
        if explore_count % 100000 == 0:
            print(explore_count)
        if explore_count > _MAX_EXPLORE:
            return None

        exact_match = _list_exact_match(gt_answer_list, pred_answer_column)
        # May cause a true one-condition query to be missed (original note).
        if (where_number == 1 and not exact_match
                and _contains(pred_answer_column, gt_answer_list)
                and ops[0] == 0):
            saved_first_cond = (cols[0], ops[0], starts[0], ends[0])
        if not exact_match:
            continue
        if pred_answer_column == [None]:
            return None

        print("explore sql", select_column, select_agg, tmp_conds)
        # Reject implausible matches. These are an aggregation over a textual
        # answer, or a string-valued last condition with a non-zero operator.
        if type(gt_answer_list[0]) == str and select_agg != 0:
            print("fake sql")
            continue
        if tmp_conds and type(tmp_conds[-1][2]) == str and tmp_conds[-1][1] != 0:
            print("fake sql")
            continue

        first_cell = pred_answer_column[0]
        if type(first_cell) == int or type(first_cell) == float:
            final_select_agg = select_agg
        else:
            final_select_agg = 0

        # Without aggregation, accept the select column only if it contains
        # one of the answer cells.
        final_select_column = None
        if final_select_agg == 0:
            column_cells = engine.execute(table_id, select_column, 0, [])
            if any(cell in column_cells for cell in gt_answer_list):
                final_select_column = select_column
        else:
            final_select_column = select_column

        # The where-conditions must select rows that contain the answer.
        if final_select_agg == 0:
            row_cells = engine.execute(table_id, "*", 0, tmp_conds)
            if any(cell in row_cells for cell in gt_answer_list):
                reward_where = True
        else:
            reward_where = True

        for k, cond in enumerate(tmp_conds):
            reward_where_conds[k] = _where_cond_reward(
                engine, table_id, cond, question_text, reward_where_conds[k])
        final_conds = [cond for k, cond in enumerate(tmp_conds)
                       if reward_where and reward_where_conds[k] >= 0.2]

        if final_select_column is not None and len(final_conds) == where_number:
            return final_select_column, final_select_agg, final_conds


def train(train_loader, train_table, model, model_bert, opt, bert_config,
          tokenizer, max_seq_length, num_target_layers, accumulate_gradients=1,
          check_grad=True, st_pos=0, opt_bert=None, path_db=None,
          dset_name='train'):
    """Explore SQL queries that reproduce each example's answer and save them.

    For every example, ``_explore_example`` searches for a query whose
    execution matches ``example["answer"]``. Each query it finds is written
    into the example as ``"sql"`` and ``"query"``, and the example is
    appended as one JSON line to ``gen_data.jsonl``.

    Despite its name, this function performs no optimisation: ``model`` and
    ``opt`` are not used. The accuracy counters are never updated, so the
    returned statistics are placeholders.

    Returns:
        ``(acc, aux_out)`` where ``acc`` is ``[ave_loss, acc_sc, acc_sa,
        acc_wn, acc_wc, acc_wo, acc_wvi, acc_wv, acc_lx, acc_x]`` and
        ``aux_out`` is ``1``.
    """
    ave_loss = 0
    count = 0  # number of examples seen
    count_sc = 0  # correct select-column predictions
    count_sa = 0  # correct select aggregations
    count_wn = 0  # correct where-numbers
    count_wc = 0  # correct where-columns
    count_wo = 0  # correct where-operators
    count_wv = 0  # correct where-values
    count_wvi = 0  # correct where-value indices (on question tokens)
    count_logic_form_acc = 0  # logical-form matches
    count_execute_acc = 0  # execution matches

    # Engine for SQL querying.
    engine = DBEngine(os.path.join(path_db, f"{dset_name}.db"))
    explored_data_list = []

    for batch_index, batch_data in enumerate(train_loader):
        count += len(batch_data)
        if count < st_pos:
            continue

        question, question_token, sql, sql_text, sql_t, table, header_token, header \
            = get_fields(batch_data, train_table, no_hs_t=True, no_sql_t=True)

        # NOTE(review): these BERT outputs are not used below; the call is
        # kept only to preserve the original behaviour.
        len_question_bert, len_header_token, number_header, \
            question_token_bert, token_to_berttoken_index, berttoken_to_token_index \
            = get_wemb_bert_v2(bert_config, model_bert, tokenizer,
                               question_token, header, max_seq_length,
                               num_out_layers_n=num_target_layers,
                               num_out_layers_h=num_target_layers)

        for i in range(len(batch_data)):
            print(sql[i])
            found = _explore_example(engine, table[i]['id'], question_token[i],
                                     question[i], len(header[i]),
                                     batch_data[i]["answer"])
            if found is None:
                continue

            final_select_column, final_select_agg, final_conds = found
            explored_data = {
                "sel": final_select_column,
                "agg": final_select_agg,
                "conds": final_conds,
            }
            explored_data_list.append(explored_data)
            print(len(explored_data_list))

            one_data = batch_data[i]
            one_data["sql"] = explored_data
            one_data["query"] = explored_data
            # The context manager closes the file even if serialisation fails.
            with open("gen_data.jsonl", mode="a", encoding="utf-8") as f:
                json.dump(one_data, f)
                f.write('\n')

    print("Done")

    ave_loss /= count
    acc_sc = count_sc / count
    acc_sa = count_sa / count
    acc_wn = count_wn / count
    acc_wc = count_wc / count
    acc_wo = count_wo / count
    acc_wvi = count_wvi / count  # previously (wrongly) count_wv
    acc_wv = count_wv / count
    acc_lx = count_logic_form_acc / count
    acc_x = count_execute_acc / count

    acc = [
        ave_loss, acc_sc, acc_sa, acc_wn, acc_wc, acc_wo, acc_wvi, acc_wv,
        acc_lx, acc_x
    ]
    aux_out = 1
    return acc, aux_out
from sqlnet.dbengine import DBEngine
from rl.train_rl import config
from train import *
import json

engine_train = DBEngine("train.db")
engine_dev = DBEngine("dev.db")

train_data, train_table, dev_data, dev_table, _, _ = load_wikisql(
    "./", False, -1, no_w2i=True, no_hs_tok=True)
train_loader, dev_loader = get_loader_wikisql(
    train_data, dev_data, 32, shuffle_train=False)


def process(train_data_, name, engine_):
    """Execute each example's SQL and store the result under ``"answer"``.

    Args:
        train_data_: list of examples with ``"table_id"`` and ``"sql"`` keys;
            each example is updated in place.
        name: unused.
        engine_: engine whose ``execute_query_v2`` runs the query.
    """
    for idx, example in enumerate(train_data_):
        if not idx % 100:
            print(idx)  # progress indicator
        answer = engine_.execute_query_v2(example["table_id"], example["sql"])
        if answer == [None]:
            print(None)
        train_data_[idx]["answer"] = answer
import os
from sqlnet.dbengine import DBEngine
import json
import time

# Micro-benchmark: time repeated execution of one validation-set query.
path_db = './wikisql/data/tianchi/'
dset_name = 'val'

# The engine for searching results.
engine = DBEngine(os.path.join(path_db, f"{dset_name}.db"))


def query(sql_tmp):
    """Execute the SQL of one dataset record and return the engine result.

    ``sql_tmp`` must carry ``table_id`` and a ``sql`` dict with ``sel``,
    ``agg``, ``conds`` and ``cond_conn_op`` keys.
    """
    sql = sql_tmp['sql']
    return engine.execute(sql_tmp['table_id'], sql['sel'], sql['agg'],
                          sql['conds'], sql['cond_conn_op'])


fname = os.path.join(path_db, f"{dset_name}.json")
# Only the first record is benchmarked. It is read before the clock starts,
# so file I/O and JSON parsing are excluded from the measurement.
with open(fname, encoding='utf-8') as fs:
    first_line = fs.readline()
if not first_line:
    raise SystemExit(f"{fname} contains no records to benchmark.")
record = json.loads(first_line)

total_count = 0
# perf_counter is monotonic and high-resolution, unlike time.time().
start = time.perf_counter()
while total_count < 10:
    query(record)
    total_count += 1
    print('%d times of invoking cost %.2fs.' % (total_count, time.perf_counter() - start))
def test(data_loader, data_table, model, model_bert, bert_config, tokenizer,
         max_seq_length, num_target_layers, detail=False, st_pos=0,
         cnt_tot=1, EG=False, beam_size=4, path_db=None, dset_name='test'):
    """Evaluate the model on ``data_loader`` and collect per-example predictions.

    Args:
        data_loader: iterable of batches of examples.
        data_table: table lookup forwarded to ``get_fields``.
        model, model_bert: decoder and BERT encoder; both are switched to eval mode.
        bert_config, tokenizer, max_seq_length, num_target_layers: forwarded
            to ``get_wemb_bert``.
        detail: when true, ``report_detail`` is called for every batch.
        st_pos: batches are skipped while the running example count is below it.
        cnt_tot: forwarded to ``report_detail`` only.
        EG: use execution-guided decoding (``model.beam_forward``).
        beam_size: beam width for execution-guided decoding.
        path_db, dset_name: the engine is opened on ``<path_db>/<dset_name>.db``.

    Returns:
        ``(acc, results, cnt_list)``. ``acc`` is ``[ave_loss, acc_sc,
        acc_sa, acc_wn, acc_wc, acc_wo, acc_wvi, acc_wv, acc_lx, acc_x]``.
        ``results`` holds one dict per example (a prediction or a skip
        marker). ``cnt_list`` holds per-batch correctness lists.
    """
    model.eval()
    model_bert.eval()

    ave_loss = 0
    cnt = 0
    cnt_sc = 0
    cnt_sa = 0
    cnt_wn = 0
    cnt_wc = 0
    cnt_wo = 0
    cnt_wv = 0
    cnt_wvi = 0
    cnt_lx = 0
    cnt_x = 0

    cnt_list = []

    engine = DBEngine(os.path.join(path_db, f"{dset_name}.db"))
    results = []

    # Evaluation only: disable autograd so no computation graphs are kept.
    with torch.no_grad():
        for iB, t in enumerate(data_loader):
            cnt += len(t)
            if cnt < st_pos:
                continue

            # Get fields
            nlu, nlu_t, sql_i, sql_q, sql_t, tb, hs_t, hds = get_fields(
                t, data_table, no_hs_t=True, no_sql_t=True)

            g_sc, g_sa, g_wn, g_wc, g_wo, g_wv = get_g(sql_i)
            g_wvi_corenlp = get_g_wvi_corenlp(t)

            wemb_n, wemb_h, l_n, l_hpu, l_hs, \
                nlu_tt, t_to_tt_idx, tt_to_t_idx \
                = get_wemb_bert(bert_config, model_bert, tokenizer, nlu_t, hds,
                                max_seq_length,
                                num_out_layers_n=num_target_layers,
                                num_out_layers_h=num_target_layers)

            try:
                g_wvi = get_g_wvi_bert_from_g_wvi_corenlp(t_to_tt_idx, g_wvi_corenlp)
                g_wv_str, g_wv_str_wp = convert_pr_wvi_to_string(
                    g_wvi, nlu_t, nlu_tt, tt_to_t_idx, nlu)
            except Exception:
                # Raised when a where-condition value is not found in nlu_tt.
                # The batch is recorded as skipped; because ``cnt`` already
                # includes it, its examples count as wrongly answered.
                for b in range(len(nlu)):
                    results.append({
                        "error": "Skip happened",
                        "nlu": nlu[b],
                        "table_id": tb[b]["id"],
                    })
                continue

            if not EG:
                # No execution-guided decoding.
                s_sc, s_sa, s_wn, s_wc, s_wo, s_wv = model(wemb_n, l_n, wemb_h,
                                                           l_hpu, l_hs)
                loss = Loss_sw_se(s_sc, s_sa, s_wn, s_wc, s_wo, s_wv,
                                  g_sc, g_sa, g_wn, g_wc, g_wo, g_wvi)
                pr_sc, pr_sa, pr_wn, pr_wc, pr_wo, pr_wvi = pred_sw_se(
                    s_sc, s_sa, s_wn, s_wc, s_wo, s_wv)
                pr_wv_str, pr_wv_str_wp = convert_pr_wvi_to_string(
                    pr_wvi, nlu_t, nlu_tt, tt_to_t_idx, nlu)
                pr_sql_i = generate_sql_i(pr_sc, pr_sa, pr_wn, pr_wc, pr_wo,
                                          pr_wv_str, nlu)
            else:
                # Execution-guided decoding.
                prob_sca, prob_w, prob_wn_w, pr_sc, pr_sa, pr_wn, pr_sql_i \
                    = model.beam_forward(wemb_n, l_n, wemb_h, l_hpu, l_hs,
                                         engine, tb, nlu_t, nlu_tt, tt_to_t_idx,
                                         nlu, beam_size=beam_size)
                pr_wc, pr_wo, pr_wv, pr_sql_i = sort_and_generate_pr_w(pr_sql_i)
                # Placeholders kept for consistency with the non-EG branch.
                pr_wvi = None
                pr_wv_str = None
                pr_wv_str_wp = None
                loss = torch.tensor([0])

            g_sql_q = generate_sql_q(sql_i, tb)
            pr_sql_q = generate_sql_q(pr_sql_i, tb)

            # Saved for the official evaluation later.
            for b, pr_sql_i1 in enumerate(pr_sql_i):
                results.append({
                    "query": pr_sql_i1,
                    "table_id": tb[b]["id"],
                    "nlu": nlu[b],
                })

            cnt_sc1_list, cnt_sa1_list, cnt_wn1_list, \
                cnt_wc1_list, cnt_wo1_list, \
                cnt_wvi1_list, cnt_wv1_list = get_cnt_sw_list(
                    g_sc, g_sa, g_wn, g_wc, g_wo, g_wvi,
                    pr_sc, pr_sa, pr_wn, pr_wc, pr_wo, pr_wvi,
                    sql_i, pr_sql_i, mode='test')

            # lx stands for logical form accuracy.
            cnt_lx1_list = get_cnt_lx_list(cnt_sc1_list, cnt_sa1_list,
                                           cnt_wn1_list, cnt_wc1_list,
                                           cnt_wo1_list, cnt_wv1_list)

            # Execution accuracy test.
            cnt_x1_list, g_ans, pr_ans = get_cnt_x_list(
                engine, tb, g_sc, g_sa, sql_i, pr_sc, pr_sa, pr_sql_i)

            # Statistics.
            ave_loss += loss.item()

            cnt_sc += sum(cnt_sc1_list)
            cnt_sa += sum(cnt_sa1_list)
            cnt_wn += sum(cnt_wn1_list)
            cnt_wc += sum(cnt_wc1_list)
            cnt_wo += sum(cnt_wo1_list)
            cnt_wv += sum(cnt_wv1_list)
            cnt_wvi += sum(cnt_wvi1_list)
            cnt_lx += sum(cnt_lx1_list)
            cnt_x += sum(cnt_x1_list)

            current_cnt = [
                cnt_tot, cnt, cnt_sc, cnt_sa, cnt_wn, cnt_wc, cnt_wo, cnt_wv,
                cnt_wvi, cnt_lx, cnt_x
            ]
            cnt_list1 = [
                cnt_sc1_list, cnt_sa1_list, cnt_wn1_list, cnt_wc1_list,
                cnt_wo1_list, cnt_wv1_list, cnt_lx1_list, cnt_x1_list
            ]
            cnt_list.append(cnt_list1)

            if detail:
                report_detail(hds, nlu, g_sc, g_sa, g_wn, g_wc, g_wo, g_wv,
                              g_wv_str, g_sql_q, g_ans, pr_sc, pr_sa, pr_wn,
                              pr_wc, pr_wo, pr_wv_str, pr_sql_q, pr_ans,
                              cnt_list1, current_cnt)

    ave_loss /= cnt
    acc_sc = cnt_sc / cnt
    acc_sa = cnt_sa / cnt
    acc_wn = cnt_wn / cnt
    acc_wc = cnt_wc / cnt
    acc_wo = cnt_wo / cnt
    acc_wvi = cnt_wvi / cnt
    acc_wv = cnt_wv / cnt
    acc_lx = cnt_lx / cnt
    acc_x = cnt_x / cnt

    acc = [
        ave_loss, acc_sc, acc_sa, acc_wn, acc_wc, acc_wo, acc_wvi, acc_wv,
        acc_lx, acc_x
    ]
    return acc, results, cnt_list
def train(train_loader, train_table, model, model_bert, opt, bert_config, tokenizer,
          max_seq_length, num_target_layers, accumulate_gradients=1, check_grad=True,
          st_pos=0, opt_bert=None, path_db=None, dset_name='train'):
    """Run one training epoch over ``train_loader`` and return the average loss.

    Args:
        train_loader: iterable of batches; each batch ``t`` is a list of
            examples (``len(t)`` is counted as the number of examples).
            ``train_loader.dataset`` must support ``len()``; it sizes the
            progress bar.
        train_table: table data forwarded to ``get_fields``.
        model: prediction head; called with the BERT embeddings plus the
            ground-truth fields (``g_*`` keyword arguments).
        model_bert: BERT encoder used by ``get_wemb_bert``.
        opt: optimizer for ``model``.
        bert_config, tokenizer, max_seq_length, num_target_layers: forwarded
            to ``get_wemb_bert``.
        accumulate_gradients: number of consecutive batches whose gradients
            are accumulated before each optimizer step.
        check_grad: not used by this function; kept for interface
            compatibility.
        st_pos: batches are skipped while the running example count is below
            this value (useful for resuming mid-epoch).
        opt_bert: optional optimizer for ``model_bert``; zeroed and stepped
            together with ``opt``.
        path_db, dset_name: locate ``{path_db}/{dset_name}.db`` for
            ``DBEngine``.

    Returns:
        The sum of per-batch ``loss.item()`` values divided by the number of
        examples seen (skipped examples included), or ``0.0`` if the loader
        yielded no examples.
    """
    model.train()
    model_bert.train()

    ave_loss = 0
    cnt = 0  # number of examples seen so far (including skipped ones)

    # NOTE(review): `engine` is not used anywhere in this function. It is
    # still constructed so that DB-path validation/side effects are unchanged.
    # Original author's TODO: switch this to the new engine.
    engine = DBEngine(os.path.join(path_db, f"{dset_name}.db"))

    # Progress is measured in examples, matching `pbar.update(len(t))` below.
    # The context manager guarantees the bar is closed even on error.
    with tqdm(total=len(train_loader.dataset)) as pbar:
        for iB, t in enumerate(train_loader):
            cnt += len(t)
            pbar.update(len(t))
            if cnt < st_pos:
                continue

            # nlu:   natural-language question
            # nlu_t: tokenized question
            # sql_i: canonical form of the SQL query
            # sql_t: tokenized SQL query (original note: equals nlu_t)
            # tb:    tables
            # hs_t:  tokenized headers (not used)
            # hds:   headers
            nlu, nlu_t, sql_i, sql_t, tb, hs_t, hds = get_fields(
                t, train_table, no_hs_t=True, no_sql_t=True)

            # g_sel_num_seq: ground-truth number of selected columns.
            # g_sel_ag_seq: per original note, tuples of
            #     (agg count, actual sel values, actual agg values).
            # conds: only consumed by the removed CoreNLP re-annotation code.
            (g_sc, g_sa, g_sop, g_wn, g_wc, g_wo, g_wv,
             g_sel_num_seq, g_sel_ag_seq, conds) = get_g(sql_i)

            # g_sc / g_sa must be flat 1-D lists: keep only the first entry
            # of each example.
            g_sc = [sc[0] for sc in g_sc]
            g_sa = [sa[0] for sa in g_sa]

            # Ground-truth where-value indices under the CoreNLP tokenization
            # (original note: already precomputed on the training set).
            g_wvi_corenlp = get_g_wvi_corenlp(t)

            # wemb_n: question embedding; wemb_h: header embedding
            # l_n: token length of each question; l_hpu: header token lengths
            # l_hs: number of columns (headers) of each table
            (wemb_n, wemb_h, l_n, l_hpu, l_hs,
             nlu_tt, t_to_tt_idx, tt_to_t_idx) = get_wemb_bert(
                bert_config, model_bert, tokenizer, nlu_t, hds, max_seq_length,
                num_out_layers_n=num_target_layers,
                num_out_layers_h=num_target_layers)

            try:
                g_wvi = get_g_wvi_bert_from_g_wvi_corenlp(t_to_tt_idx, g_wvi_corenlp)
            except Exception:
                # Best effort: skip batches whose where-value indices cannot be
                # mapped onto BERT tokens ("index-to-value conversion failed").
                # Narrowed from a bare `except:` so Ctrl-C still interrupts.
                print('索引转值出错')
                continue

            # score
            s_scn, s_sc, s_sa, s_sop, s_wn, s_wc, s_wo, s_wv = model(
                wemb_n, l_n, wemb_h, l_hpu, l_hs,
                g_scn=g_sel_num_seq, g_sc=g_sc, g_sa=g_sa, g_wn=g_wn,
                g_wc=g_wc, g_sop=g_sop, g_wo=g_wo, g_wvi=g_wvi)

            loss = Loss_sw_se(s_scn, s_sc, s_sa, s_sop, s_wn, s_wc, s_wo, s_wv,
                              g_sel_num_seq, g_sc, g_sa, g_sop, g_wn, g_wc, g_wo, g_wvi)

            # Gradient accumulation: clear gradients at the start of each
            # accumulation window, and step at its last batch. With
            # accumulate_gradients == 1, every batch both starts and ends a window.
            if iB % accumulate_gradients == 0:
                opt.zero_grad()
                if opt_bert:
                    opt_bert.zero_grad()
            loss.backward()
            if iB % accumulate_gradients == accumulate_gradients - 1:
                opt.step()
                if opt_bert:
                    opt_bert.step()

            ave_loss += loss.item()

    # Guard against an empty loader (previously a ZeroDivisionError).
    return ave_loss / cnt if cnt else 0.0