def epoch_exec_acc(model, batch_size, sql_data, table_data, db_path):
    """Return the execution accuracy of ``model`` over ``sql_data``.

    For every example the ground-truth query and the predicted query are
    both executed through ``DBEngine``; an example counts as correct when
    the two results compare equal.

    Args:
        model: Object exposing ``eval()``, ``forward(...)`` and
            ``gen_query(...)``.
        batch_size: Number of examples fed to the model per step.
        sql_data: Sequence of examples; only its length is used directly
            here, the content is consumed by ``to_batch_seq`` and
            ``to_batch_query``.
        table_data: Table metadata forwarded to ``to_batch_seq``.
        db_path: Path handed to ``DBEngine``.

    Returns:
        Number of matching executions divided by ``len(sql_data)``.

    Raises:
        ZeroDivisionError: If ``sql_data`` is empty.
    """
    engine = DBEngine(db_path)
    model.eval()
    perm = list(range(len(sql_data)))
    tot_acc_num = 0.0

    # Iterate over real batch start offsets only. The former
    # ``len(sql_data) // batch_size + 1`` iteration count produced a trailing
    # empty batch whenever the data size was an exact multiple of batch_size,
    # and that empty batch was passed to model.forward unguarded.
    for st in tqdm(range(0, len(sql_data), batch_size)):
        ed = min(st + batch_size, len(perm))

        q_seq, gt_sel_num, col_seq, col_num, ans_seq, gt_cond_seq, raw_data = \
            to_batch_seq(sql_data, table_data, perm, st, ed, ret_vis_data=True)
        raw_q_seq = [x[0] for x in raw_data]
        raw_col_seq = [x[1] for x in raw_data]
        query_gt, table_ids = to_batch_query(sql_data, perm, st, ed)
        # NOTE(review): index 2 of each ans_seq entry is used as the gold
        # selection here -- confirm against to_batch_seq's layout.
        gt_sel_seq = [x[2] for x in ans_seq]

        score = model.forward(q_seq, col_seq, col_num, gt_sel=gt_sel_seq)
        pred_queries = model.gen_query(score, q_seq, col_seq,
                                       raw_q_seq, raw_col_seq)

        for sql_gt, sql_pred, tid in zip(query_gt, pred_queries, table_ids):
            ret_gt = engine.execute(tid, sql_gt['sel'], sql_gt['agg'],
                                    sql_gt['conds'])
            try:
                ret_pred = engine.execute(tid, sql_pred['sel'],
                                          sql_pred['agg'], sql_pred['conds'])
            except Exception:
                # A predicted query that fails to execute simply counts as
                # wrong. ``Exception`` (not a bare except) keeps
                # KeyboardInterrupt/SystemExit propagating.
                ret_pred = None
            tot_acc_num += (ret_gt == ret_pred)

    return tot_acc_num / len(sql_data)
def epoch_reinforce_train(model, optimizer, batch_size, sql_data, table_data,
                          db_path):
    """Run one REINFORCE-style training epoch and return the mean reward.

    Examples are visited in a random permutation. For each batch the model
    generates queries, every predicted query is executed against the
    database, and a per-example reward is assigned:

    * ``-2`` when executing the predicted query raises,
    * ``-1`` when it executes but the result differs from the ground truth,
    * ``1`` when the result equals the ground-truth result.

    The rewards are passed to ``model.reinforce_backward`` and the optimizer
    takes one step per batch.

    Args:
        model: Object exposing ``train()``, ``generate_gt_where_seq(...)``,
            ``forward(...)``, ``gen_query(...)`` and
            ``reinforce_backward(score, rewards)``.
        optimizer: Object exposing ``zero_grad()`` and ``step()``.
        batch_size: Number of examples per optimisation step.
        sql_data: Sequence of examples.
        table_data: Table metadata forwarded to ``to_batch_seq``.
        db_path: Path handed to ``DBEngine``.

    Returns:
        Sum of all rewards divided by ``len(sql_data)``.

    Raises:
        ZeroDivisionError: If ``sql_data`` is empty.
    """
    engine = DBEngine(db_path)
    model.train()
    perm = np.random.permutation(len(sql_data))
    cum_reward = 0.0

    st = 0
    while st < len(sql_data):
        ed = min(st + batch_size, len(perm))

        q_seq, col_seq, col_num, ans_seq, query_seq, gt_cond_seq, raw_data = \
            to_batch_seq(sql_data, table_data, perm, st, ed, ret_vis_data=True)
        # The return value is not used in this function; the call is kept in
        # case it has side effects on the model.
        model.generate_gt_where_seq(q_seq, col_seq, query_seq)
        raw_q_seq = [x[0] for x in raw_data]
        raw_col_seq = [x[1] for x in raw_data]
        # Fetched once per batch; a second identical call used to follow
        # gen_query and only rebound the same names.
        query_gt, table_ids = to_batch_query(sql_data, perm, st, ed)
        gt_sel_seq = [x[1] for x in ans_seq]

        score = model.forward(q_seq, col_seq, col_num, (True, True, True),
                              reinforce=True, gt_sel=gt_sel_seq)
        pred_queries = model.gen_query(score, q_seq, col_seq, raw_q_seq,
                                       raw_col_seq, (True, True, True),
                                       reinforce=True)

        rewards = []
        for sql_gt, sql_pred, tid in zip(query_gt, pred_queries, table_ids):
            ret_gt = engine.execute(tid, sql_gt['sel'], sql_gt['agg'],
                                    sql_gt['conds'])
            try:
                ret_pred = engine.execute(tid, sql_pred['sel'],
                                          sql_pred['agg'], sql_pred['conds'])
            except Exception:
                # ``Exception`` rather than a bare except so that Ctrl-C
                # still interrupts a long training run.
                ret_pred = None

            if ret_pred is None:
                rewards.append(-2)
            elif ret_pred != ret_gt:
                rewards.append(-1)
            else:
                rewards.append(1)

        cum_reward += sum(rewards)

        optimizer.zero_grad()
        model.reinforce_backward(score, rewards)
        optimizer.step()

        st = ed

    return cum_reward / len(sql_data)
def epoch_acc(model, batch_size, sql_data, table_data, db_path):
    """Evaluate ``model`` and return three accuracy figures.

    Args:
        model: Object exposing ``eval()``, ``forward(...)``,
            ``gen_query(...)`` and ``check_acc(...)``.
        batch_size: Number of examples per evaluation step.
        sql_data: Sequence of examples.
        table_data: Table metadata forwarded to ``to_batch_seq``.
        db_path: Path handed to ``DBEngine``.

    Returns:
        A tuple ``(one_acc, tot_acc, ex_acc)``, each divided by
        ``len(sql_data)``:

        * ``one_acc`` / ``tot_acc`` -- batch size minus the ``one_err`` /
          ``tot_err`` counts reported by ``model.check_acc``, summed over
          batches;
        * ``ex_acc`` -- number of examples whose predicted query executes to
          the same result as the ground-truth query.

    Raises:
        ZeroDivisionError: If ``sql_data`` is empty.
    """
    engine = DBEngine(db_path)
    model.eval()
    perm = list(range(len(sql_data)))
    badcase = 0
    one_acc_num, tot_acc_num, ex_acc_num = 0.0, 0.0, 0.0

    # Iterate over real batch start offsets only. The former
    # ``len(sql_data) // batch_size + 1`` iteration count produced a trailing
    # empty batch whenever the data size was an exact multiple of batch_size.
    for st in tqdm(range(0, len(sql_data), batch_size)):
        ed = min(st + batch_size, len(perm))

        # Field meanings below are carried over from the original inline
        # notes:
        #   q_seq:       char-based sequence of the question
        #   gt_sel_num:  number of selected columns / aggregation functions
        #   col_seq:     char-based column names
        #   col_num:     number of headers in one table
        #   ans_seq:     (sel, number of conds, sel list in conds,
        #                 op list in conds)
        #   gt_cond_seq: ground truth of conditions
        #   raw_data:    original question, headers, sql
        q_seq, gt_sel_num, col_seq, col_num, ans_seq, gt_cond_seq, raw_data = \
            to_batch_seq(sql_data, table_data, perm, st, ed, ret_vis_data=True)
        # query_gt: ground-truth sql (data['sql']) containing sel, agg and
        # conds.
        query_gt, table_ids = to_batch_query(sql_data, perm, st, ed)
        raw_q_seq = [x[0] for x in raw_data]  # original question

        try:
            score = model.forward(q_seq, col_seq, col_num)
            # Generate the predicted queries in the ground-truth format.
            pred_queries = model.gen_query(score, q_seq, col_seq, raw_q_seq)
            one_err, tot_err = model.check_acc(raw_data, pred_queries,
                                               query_gt)
        except Exception:
            # Best effort: a batch the model cannot process is skipped and
            # contributes nothing to any accuracy counter. ``Exception``
            # (not a bare except) keeps KeyboardInterrupt propagating.
            badcase += 1
            print('badcase', badcase)
            continue

        one_acc_num += (ed - st - one_err)
        tot_acc_num += (ed - st - tot_err)

        # Execution accuracy
        for sql_gt, sql_pred, tid in zip(query_gt, pred_queries, table_ids):
            ret_gt = engine.execute(tid, sql_gt['sel'], sql_gt['agg'],
                                    sql_gt['conds'], sql_gt['cond_conn_op'])
            try:
                ret_pred = engine.execute(tid, sql_pred['sel'],
                                          sql_pred['agg'], sql_pred['conds'],
                                          sql_pred['cond_conn_op'])
            except Exception:
                # A prediction that cannot be executed counts as wrong.
                ret_pred = None
            ex_acc_num += (ret_gt == ret_pred)

    return (one_acc_num / len(sql_data),
            tot_acc_num / len(sql_data),
            ex_acc_num / len(sql_data))
def epoch_acc(model, batch_size, sql_data, table_data, db_path, db_content):
    """Evaluate ``model`` (typed-input variant) and return three accuracies.

    Args:
        model: Object exposing ``eval()``, ``forward(...)``,
            ``gen_query(...)`` and ``check_acc(...)``.
        batch_size: Number of examples per evaluation step.
        sql_data: Sequence of examples.
        table_data: Table metadata forwarded to ``to_batch_seq``.
        db_path: Path handed to ``DBEngine``.
        db_content: Forwarded unchanged to ``to_batch_seq``.

    Returns:
        A tuple ``(one_acc, tot_acc, ex_acc)``, each divided by
        ``len(sql_data)``:

        * ``one_acc`` / ``tot_acc`` -- batch size minus the ``one_err`` /
          ``tot_err`` counts reported by ``model.check_acc``, summed over
          batches;
        * ``ex_acc`` -- number of examples whose predicted query executes to
          the same result as the ground-truth query.

    Raises:
        ZeroDivisionError: If ``sql_data`` is empty.
    """
    engine = DBEngine(db_path)
    model.eval()
    perm = list(range(len(sql_data)))
    one_acc_num, tot_acc_num, ex_acc_num = 0.0, 0.0, 0.0

    # Iterate over real batch start offsets only. The former
    # ``len(sql_data) // batch_size + 1`` iteration count produced a trailing
    # empty batch whenever the data size was an exact multiple of batch_size,
    # and nothing here guards model.forward against an empty batch.
    for st in tqdm(range(0, len(sql_data), batch_size)):
        ed = min(st + batch_size, len(perm))

        q_seq, gt_sel_num, col_seq, col_num, ans_seq, gt_cond_seq, q_type, \
            col_type, raw_data = to_batch_seq(sql_data, table_data, perm, st,
                                              ed, db_content,
                                              ret_vis_data=True)
        print("q_seq:{}".format(len(q_seq)))
        raw_q_seq = [x[0] for x in raw_data]
        query_gt, table_ids = to_batch_query(sql_data, perm, st, ed)

        score = model.forward(q_seq, col_seq, col_num, q_type, col_type)
        pred_queries = model.gen_query(score, q_seq, col_seq, raw_q_seq)
        one_err, tot_err = model.check_acc(raw_data, pred_queries, query_gt)

        one_acc_num += (ed - st - one_err)
        tot_acc_num += (ed - st - tot_err)

        # Execution accuracy
        for sql_gt, sql_pred, tid in zip(query_gt, pred_queries, table_ids):
            ret_gt = engine.execute(tid, sql_gt['sel'], sql_gt['agg'],
                                    sql_gt['conds'], sql_gt['cond_conn_op'])
            try:
                ret_pred = engine.execute(tid, sql_pred['sel'],
                                          sql_pred['agg'], sql_pred['conds'],
                                          sql_pred['cond_conn_op'])
            except Exception:
                # A prediction that cannot be executed counts as wrong.
                # ``Exception`` (not a bare except) keeps KeyboardInterrupt
                # and SystemExit propagating.
                ret_pred = None
            ex_acc_num += (ret_gt == ret_pred)

    return (one_acc_num / len(sql_data),
            tot_acc_num / len(sql_data),
            ex_acc_num / len(sql_data))
def epoch_exec_acc(model, batch_size, sql_data, table_data, db_path):
    """Return the execution accuracy of ``model`` over ``sql_data``.

    For every example the ground-truth query and the predicted query are
    both executed through ``DBEngine``; an example counts as correct when
    the two results compare equal.

    Args:
        model: Object exposing ``eval()``, ``generate_gt_sel_seq(...)``,
            ``generate_gt_where_seq(...)``, ``forward(...)`` and
            ``gen_query(...)``.
        batch_size: Number of examples fed to the model per step.
        sql_data: Sequence of examples.
        table_data: Table metadata forwarded to ``to_batch_seq``.
        db_path: Path handed to ``DBEngine``.

    Returns:
        Number of matching executions divided by ``len(sql_data)``.

    Raises:
        ZeroDivisionError: If ``sql_data`` is empty.
    """
    engine = DBEngine(db_path)
    # Function-call form prints the same text on Python 2 and Python 3; the
    # former ``print 'exec acc'`` statement is a SyntaxError on Python 3.
    print('exec acc')
    model.eval()
    perm = list(range(len(sql_data)))
    tot_acc_num = 0.0

    st = 0
    while st < len(sql_data):
        ed = min(st + batch_size, len(perm))

        q_seq, col_seq, col_num, ans_seq, query_seq, gt_cond_seq, raw_data = \
            to_batch_seq(sql_data, table_data, perm, st, ed, ret_vis_data=True)
        raw_q_seq = [x[0] for x in raw_data]
        raw_col_seq = [x[1] for x in raw_data]
        # Neither return value is used in this function; both calls are kept
        # in case they have side effects on the model.
        model.generate_gt_sel_seq(q_seq, col_seq, query_seq)
        model.generate_gt_where_seq(q_seq, col_seq, query_seq)
        query_gt, table_ids = to_batch_query(sql_data, perm, st, ed)
        gt_sel_seq = [x[1] for x in ans_seq]

        score = model.forward(q_seq, col_seq, col_num, (True, True, True),
                              gt_sel=gt_sel_seq)
        pred_queries = model.gen_query(score, q_seq, col_seq, raw_q_seq,
                                       raw_col_seq, (True, True, True))

        for sql_gt, sql_pred, tid in zip(query_gt, pred_queries, table_ids):
            ret_gt = engine.execute(tid, sql_gt['sel'], sql_gt['agg'],
                                    sql_gt['conds'])
            try:
                ret_pred = engine.execute(tid, sql_pred['sel'],
                                          sql_pred['agg'], sql_pred['conds'])
            except Exception:
                # A prediction that cannot be executed counts as wrong.
                # ``Exception`` (not a bare except) keeps KeyboardInterrupt
                # and SystemExit propagating.
                ret_pred = None
            tot_acc_num += (ret_gt == ret_pred)

        st = ed

    return tot_acc_num / len(sql_data)
def epoch_exec_acc(models, batch_size, sql_data, table_data, db_path,
                   db_content, BERT=False, POS=False, ensemble='single'):
    """Return the execution accuracy of one model or an ensemble.

    Every model in ``models`` scores each batch; the list of scores is then
    handed to ``gen_query`` of the first model. The ground-truth and the
    predicted query are both executed through ``DBEngine`` and an example
    counts as correct when the two results compare equal.

    Args:
        models: Non-empty sequence of models, each exposing ``eval()``,
            ``generate_gt_where_seq(...)`` and ``forward(...)``; the first
            one must also expose ``gen_query(...)``.
        batch_size: Number of examples per evaluation step.
        sql_data: Sequence of examples.
        table_data: Table metadata forwarded to ``to_batch_seq``.
        db_path: Path handed to ``DBEngine``.
        db_content: Forwarded unchanged to ``to_batch_seq``.
        BERT: Forwarded to ``to_batch_seq`` as its ``BERT`` flag.
        POS: Forwarded to ``to_batch_seq``; when true the batch carries an
            extra leading ``q_pos`` field which is passed to ``forward`` as
            the ``q_pos`` keyword.
        ensemble: When equal to ``'mixed'`` each batch is fetched twice
            (once with ``BERT=BERT``, once with ``BERT=False``). The first
            model receives the ``BERT=False`` question sequence and every
            later model receives the ``BERT=BERT`` one. Any other value
            fetches a single batch with ``BERT=BERT`` for all models.

    Returns:
        Number of matching executions divided by ``len(sql_data)``.

    Raises:
        IndexError: If ``models`` is empty.
        ZeroDivisionError: If ``sql_data`` is empty.
    """
    engine = DBEngine(db_path)
    for net in models:
        net.eval()
    # gen_query is always invoked on the first model, for a single model as
    # well as for an ensemble.
    lead_model = models[0]

    perm = list(range(len(sql_data)))
    tot_acc_num = 0.0
    mixed = ensemble == 'mixed'

    st = 0
    while st < len(sql_data):
        ed = min(st + batch_size, len(perm))

        q_seq_bert = None
        if mixed:
            # Only the question sequence of the BERT-flagged batch is used;
            # every other field comes from the BERT=False batch fetched next.
            bert_batch = to_batch_seq(sql_data, table_data, perm, st, ed,
                                      db_content, ret_vis_data=True,
                                      BERT=BERT, POS=POS)
            # With POS the question sequence sits behind the q_pos field.
            q_seq_bert = bert_batch[1] if POS else bert_batch[0]
            batch = to_batch_seq(sql_data, table_data, perm, st, ed,
                                 db_content, ret_vis_data=True,
                                 BERT=False, POS=POS)
        else:
            batch = to_batch_seq(sql_data, table_data, perm, st, ed,
                                 db_content, ret_vis_data=True,
                                 BERT=BERT, POS=POS)

        if POS:
            (q_pos, q_seq, col_seq, col_num, ans_seq, query_seq, gt_cond_seq,
             q_type, col_type, raw_data) = batch
            pos_kwargs = {'q_pos': q_pos}
        else:
            (q_seq, col_seq, col_num, ans_seq, query_seq, gt_cond_seq,
             q_type, col_type, raw_data) = batch
            pos_kwargs = {}

        raw_q_seq = [x[0] for x in raw_data]
        raw_col_seq = [x[1] for x in raw_data]
        query_gt, table_ids = to_batch_query(sql_data, perm, st, ed)

        scores = []
        for i, net in enumerate(models):
            # In mixed mode only the models after the first consume the
            # BERT-flagged question sequence.
            q_input = q_seq_bert if (mixed and i > 0) else q_seq
            # The return value is unused here; the call is kept in case it
            # has side effects on the model.
            net.generate_gt_where_seq(q_input, col_seq, query_seq)
            scores.append(net.forward(q_input, col_seq, col_num, q_type,
                                      col_type, (True, True, True),
                                      **pos_kwargs))

        pred_queries = lead_model.gen_query(scores, q_seq, col_seq,
                                            raw_q_seq, raw_col_seq,
                                            (True, True, True))

        for sql_gt, sql_pred, tid in zip(query_gt, pred_queries, table_ids):
            ret_gt = engine.execute(tid, sql_gt['sel'], sql_gt['agg'],
                                    sql_gt['conds'])
            try:
                ret_pred = engine.execute(tid, sql_pred['sel'],
                                          sql_pred['agg'], sql_pred['conds'])
            except Exception:
                # A prediction that cannot be executed counts as wrong.
                # ``Exception`` (not a bare except) keeps KeyboardInterrupt
                # and SystemExit propagating.
                ret_pred = None
            tot_acc_num += (ret_gt == ret_pred)

        st = ed

    return tot_acc_num / len(sql_data)