def test_different_not_equal_operators():
    """Both `<>` and `!=` must parse to the identical WHERE condition."""
    expected_where = [(False, 7, (0, (0, '__papers.title__', False), None), '"bar"', None)]
    queries = (
        'SELECT * FROM papers WHERE papers.title <> "bar"',
        'SELECT * FROM papers WHERE papers.title != "bar"',
    )
    for query in queries:
        assert get_sql(test_schema(), query)['where'] == expected_where
def test_joins():
    """Explicit JOIN, INNER JOIN, and implicit comma joins yield the same FROM clause."""
    expected_from = {
        'conds': [],
        'table_units': [('table_unit', '__papers__'), ('table_unit', '__coauthored__')],
    }
    queries = (
        'SELECT * FROM papers JOIN coauthored',
        'SELECT * FROM papers INNER JOIN coauthored',
        'SELECT * FROM papers, coauthored',
    )
    for query in queries:
        assert get_sql(test_schema(), query)['from'] == expected_from
def test_parse_col():
    """DISTINCT column parsing is insensitive to the optional parentheses."""
    # COUNT(DISTINCT ...) with and without inner parens
    count_distinct = (False, [(3, (0, (0, '__papers.id__', True), None))])
    for query in ('SELECT COUNT(DISTINCT(papers.id)) FROM papers',
                  'SELECT COUNT(DISTINCT papers.id) FROM papers'):
        assert get_sql(test_schema(), query)['select'] == count_distinct
    # SELECT DISTINCT ... with and without parens
    select_distinct = (True, [(0, (0, (0, '__papers.id__', False), None))])
    for query in ('SELECT DISTINCT(papers.id) FROM papers',
                  'SELECT DISTINCT papers.id FROM papers'):
        assert get_sql(test_schema(), query)['select'] == select_distinct
def single_acc(self, pred_sql, gold_sql, db, etype):
    """
    Score one predicted SQL against its gold SQL.
    @params:
        pred_sql(str): predicted SQL query
        gold_sql(str): gold SQL query
        db(str): db_id field, database name
        etype(str): 'exec' for execution accuracy, 'match' for exact-set match
    @return:
        score(float): 0 or 1, etype score
        hardness(str): one of 'easy', 'medium', 'hard', 'extra'
    """
    db_name = db
    db = os.path.join(self.database_dir, db, db + ".sqlite")
    schema = Schema(get_schema(db))
    g_sql = get_sql(schema, gold_sql)
    hardness = self.engine.eval_hardness(g_sql)
    try:
        p_sql = get_sql(schema, pred_sql)
    except Exception:
        # If p_sql is not valid, then we will use an empty sql to evaluate with the correct sql
        p_sql = {
            "except": None,
            "from": {"conds": [], "table_units": []},
            "groupBy": [],
            "having": [],
            "intersect": None,
            "limit": None,
            "orderBy": [],
            "select": [False, []],
            "union": None,
            "where": []
        }
    # Normalize values away and map columns through the foreign-key pivot map
    kmap = self.kmaps[db_name]
    g_valid_col_units = build_valid_col_units(g_sql['from']['table_units'], schema)
    g_sql = rebuild_sql_val(g_sql)
    g_sql = rebuild_sql_col(g_valid_col_units, g_sql, kmap)  # kmap: map __tab.col__ to pivot __tab.col__
    p_valid_col_units = build_valid_col_units(p_sql['from']['table_units'], schema)
    p_sql = rebuild_sql_val(p_sql)
    p_sql = rebuild_sql_col(p_valid_col_units, p_sql, kmap)
    if etype == 'exec':
        score = float(eval_exec_match(db, pred_sql, gold_sql, p_sql, g_sql))
    elif etype == 'match':
        score = float(self.engine.eval_exact_match(p_sql, g_sql))
    else:
        # fix: previously an unknown etype raised UnboundLocalError on `score`
        raise ValueError("etype must be 'exec' or 'match', got {!r}".format(etype))
    return score, hardness
def evaluate_match_per_example(g_str, p_str, db_id, db_dir, kmaps):
    """
    Evaluate exact-set match for a single (gold, predicted) SQL pair.
    @params:
        g_str(str): gold SQL query
        p_str(str): predicted SQL query
        db_id(str): database name
        db_dir(str): directory containing the sqlite databases
        kmaps(dict): db_id -> foreign-key pivot map
    @return:
        hardness(str), bool_err(bool): True when the prediction failed to parse,
        exact_score, partial_scores, p_sql, g_sql (rebuilt parse trees)
    """
    evaluator = Evaluator()
    db = os.path.join(db_dir, db_id, db_id + ".sqlite")
    schema = Schema(get_schema(db))
    g_sql = get_sql(schema, g_str)
    hardness = evaluator.eval_hardness(g_sql)
    bool_err = False
    try:
        p_sql = get_sql(schema, p_str)
    except Exception:
        # If p_sql is not valid, then we will use an empty sql to evaluate with the correct sql
        p_sql = {
            "except": None,
            "from": {"conds": [], "table_units": []},
            "groupBy": [],
            "having": [],
            "intersect": None,
            "limit": None,
            "orderBy": [],
            "select": [False, []],
            "union": None,
            "where": []
        }
        bool_err = True
    # rebuild sql for value evaluation
    kmap = kmaps[db_id]
    g_valid_col_units = build_valid_col_units(g_sql['from']['table_units'], schema)
    g_sql = rebuild_sql_val(g_sql)
    g_sql = rebuild_sql_col(g_valid_col_units, g_sql, kmap)
    p_valid_col_units = build_valid_col_units(p_sql['from']['table_units'], schema)
    p_sql = rebuild_sql_val(p_sql)
    p_sql = rebuild_sql_col(p_valid_col_units, p_sql, kmap)
    # deep-copy so eval_exact_match cannot mutate the trees we return
    exact_score = evaluator.eval_exact_match(copy.deepcopy(p_sql), copy.deepcopy(g_sql))
    partial_scores = evaluator.partial_scores
    return hardness, bool_err, exact_score, partial_scores, p_sql, g_sql
def do_score(evaluator, db_dir, kmaps, p, g):
    """
    Return the exact-set-match score for one prediction.
    @params:
        evaluator: Evaluator instance providing eval_exact_match
        db_dir(str): directory containing the sqlite databases
        kmaps(dict): db_id -> foreign-key pivot map
        p(str): predicted SQL query
        g(tuple): (gold SQL string, db_id)
    @return:
        exact_score: 1 if the rebuilt parse trees match exactly, else 0
    """
    g_str, db = g
    db_name = db
    db = os.path.join(db_dir, db, db + ".sqlite")
    schema = Schema(get_schema(db))
    g_sql = get_sql(schema, g_str)
    try:
        p_sql = get_sql(schema, p)
    except Exception:
        # If p_sql is not valid, then we will use an empty sql to evaluate with the correct sql
        p_sql = {
            "except": None,
            "from": {"conds": [], "table_units": []},
            "groupBy": [],
            "having": [],
            "intersect": None,
            "limit": None,
            "orderBy": [],
            "select": [False, []],
            "union": None,
            "where": []
        }
    # rebuild sql for value evaluation
    kmap = kmaps[db_name]
    g_valid_col_units = build_valid_col_units(g_sql['from']['table_units'], schema)
    g_sql = rebuild_sql_val(g_sql)
    g_sql = rebuild_sql_col(g_valid_col_units, g_sql, kmap)
    p_valid_col_units = build_valid_col_units(p_sql['from']['table_units'], schema)
    p_sql = rebuild_sql_val(p_sql)
    p_sql = rebuild_sql_col(p_valid_col_units, p_sql, kmap)
    exact_score = evaluator.eval_exact_match(p_sql, g_sql)
    return exact_score
def __init__(self, query, query_no_value, question, db_id, schema):
    """
    Build one dataset example: store the raw text, tokenize it with nltk,
    and parse the query into a sql tree via get_sql.
    @params:
        query(str): SQL query with values
        query_no_value(str): SQL query with values stripped
        question(str): natural-language question
        db_id(str): database name
        schema: schema object understood by get_sql
    """
    self.db_id = db_id
    self.query = query
    self.question = question
    try:
        self.question_toks = nltk.word_tokenize(self.question)
        self.query_toks = nltk.word_tokenize(self.query)
        self.query_toks_no_value = nltk.word_tokenize(query_no_value)
    except Exception:
        # Tokenization failed: null the example out, but (fix) also define the
        # token attributes so later attribute access cannot raise AttributeError.
        self.question = None
        self.query = None
        self.question_toks = None
        self.query_toks = None
        self.query_toks_no_value = None
    # NOTE(review): get_sql is still called with the original query even when
    # tokenization failed — preserved from the original code; confirm intent.
    self.sql = get_sql(schema, query)
def validity_check(self, sql: str, db: str):
    """
    Check whether the given sql query is valid, including:
    1. only use columns in tables mentioned in FROM clause
    2. comparison operator or MAX/MIN/SUM/AVG only applied to columns of type number/time
    @params:
        sql(str): SQL query
        db(str): db_id field, database name
    @return:
        flag(boolean)
    """
    raw_schema = self.schemas[db]
    table_info = self.tables[db]
    schema = SchemaID(raw_schema, table_info)
    try:
        parsed_sql = get_sql(schema, sql)
        return self.sql_check(parsed_sql, self.database[db])
    except Exception as e:
        # Any parse/check failure means the query is not valid.
        print('Runtime error occurs:', e)
        return False
def evaluate(gold, predict, db_dir, etype, kmaps):
    """
    Spider-style single-turn evaluation.
    @params:
        gold(str): path to gold file, one "<sql>\\t<db_id>" per line
        predict(str): path to prediction file, one SQL per line
        db_dir(str): directory containing the sqlite databases
        etype(str): 'all', 'exec', or 'match'
        kmaps(dict): db_id -> foreign-key pivot map
    Prints per-example errors, compound-query accuracy, and the score table.
    """
    with open(gold) as f:
        glist = [l.strip().split('\t') for l in f.readlines() if len(l.strip()) > 0]
    with open(predict) as f:
        plist = [l.strip().split('\t') for l in f.readlines() if len(l.strip()) > 0]
    # Shuffle gold/pred pairs together so printed failures appear in random order
    mixed_list = list(zip(glist, plist))
    random.shuffle(mixed_list)
    glist, plist = zip(*mixed_list)
    evaluator = Evaluator()
    levels = ['easy', 'medium', 'hard', 'extra', 'all']
    partial_types = [
        'select', 'select(no AGG)', 'from', 'where', 'where(no OP)',
        'group(no Having)', 'group', 'order', 'and/or', 'IUEN', 'keywords'
    ]
    entries = []
    scores = {}
    for level in levels:
        scores[level] = {'count': 0, 'partial': {}, 'exact': 0.}
        scores[level]['exec'] = 0
        for type_ in partial_types:
            scores[level]['partial'][type_] = {
                'acc': 0., 'rec': 0., 'f1': 0., 'acc_count': 0, 'rec_count': 0
            }
    eval_err_num = 0
    # accuracy tracked separately for compound queries (INTERSECT/UNION/EXCEPT)
    compound_correct = 0
    compound_detect = 0
    for p, g in zip(plist, glist):
        p_str = p[0]
        g_str, db = g
        db_name = db
        db = os.path.join(db_dir, db, db + ".sqlite")
        schema = Schema(get_schema(db))
        g_sql = get_sql(schema, g_str)
        assert g_sql['from']['table_units']
        hardness = evaluator.eval_hardness(g_sql)
        scores[hardness]['count'] += 1
        scores['all']['count'] += 1
        try:
            p_sql = get_sql(schema, p_str)
        except Exception:
            # If p_sql is not valid, then we will use an empty sql to evaluate with the correct sql
            p_sql = {
                "except": None,
                "from": {"conds": [], "table_units": []},
                "groupBy": [],
                "having": [],
                "intersect": None,
                "limit": None,
                "orderBy": [],
                "select": [False, []],
                "union": None,
                "where": []
            }
            eval_err_num += 1
            print("eval_err_num:{}".format(eval_err_num))
        # rebuild sql for value evaluation
        kmap = kmaps[db_name]
        g_valid_col_units = build_valid_col_units(g_sql['from']['table_units'], schema)
        g_sql = rebuild_sql_val(g_sql)
        g_sql = rebuild_sql_col(g_valid_col_units, g_sql, kmap)
        p_valid_col_units = build_valid_col_units(p_sql['from']['table_units'], schema)
        p_sql = rebuild_sql_val(p_sql)
        p_sql = rebuild_sql_col(p_valid_col_units, p_sql, kmap)
        if etype in ["all", "exec"]:
            exec_score = eval_exec_match(db, p_str, g_str, p_sql, g_sql)
            if exec_score:
                scores[hardness]['exec'] += 1
        if etype in ["all", "match"]:
            exact_score = evaluator.eval_exact_match(p_sql, g_sql)
            if g_sql['intersect'] or g_sql['union'] or g_sql['except']:
                compound_detect += 1
                compound_correct += exact_score
            partial_scores = evaluator.partial_scores
            if exact_score == 0:
                print("{} pred: {}".format(hardness, p_str))
                print("{} gold: {}".format(hardness, g_str))
                print("")
            scores[hardness]['exact'] += exact_score
            scores['all']['exact'] += exact_score
            for type_ in partial_types:
                if partial_scores[type_]['pred_total'] > 0:
                    scores[hardness]['partial'][type_]['acc'] += partial_scores[type_]['acc']
                    scores[hardness]['partial'][type_]['acc_count'] += 1
                if partial_scores[type_]['label_total'] > 0:
                    scores[hardness]['partial'][type_]['rec'] += partial_scores[type_]['rec']
                    scores[hardness]['partial'][type_]['rec_count'] += 1
                scores[hardness]['partial'][type_]['f1'] += partial_scores[type_]['f1']
                if partial_scores[type_]['pred_total'] > 0:
                    scores['all']['partial'][type_]['acc'] += partial_scores[type_]['acc']
                    scores['all']['partial'][type_]['acc_count'] += 1
                if partial_scores[type_]['label_total'] > 0:
                    scores['all']['partial'][type_]['rec'] += partial_scores[type_]['rec']
                    scores['all']['partial'][type_]['rec_count'] += 1
                scores['all']['partial'][type_]['f1'] += partial_scores[type_]['f1']
            # fix: entries.append kept inside the match branch — exact_score and
            # partial_scores are undefined when etype == "exec" (was a NameError)
            entries.append({
                'predictSQL': p_str,
                'goldSQL': g_str,
                'hardness': hardness,
                'exact': exact_score,
                'partial': partial_scores
            })
    for level in levels:
        if scores[level]['count'] == 0:
            continue
        if etype in ["all", "exec"]:
            scores[level]['exec'] /= scores[level]['count']
        if etype in ["all", "match"]:
            scores[level]['exact'] /= scores[level]['count']
            for type_ in partial_types:
                if scores[level]['partial'][type_]['acc_count'] == 0:
                    scores[level]['partial'][type_]['acc'] = 0
                else:
                    scores[level]['partial'][type_]['acc'] = scores[level]['partial'][type_]['acc'] / \
                        scores[level]['partial'][type_]['acc_count'] * 1.0
                if scores[level]['partial'][type_]['rec_count'] == 0:
                    scores[level]['partial'][type_]['rec'] = 0
                else:
                    scores[level]['partial'][type_]['rec'] = scores[level]['partial'][type_]['rec'] / \
                        scores[level]['partial'][type_]['rec_count'] * 1.0
                # NOTE(review): f1 is set to 1 when both acc and rec are 0 —
                # preserved from the original scoring convention
                if scores[level]['partial'][type_]['acc'] == 0 and scores[level]['partial'][type_]['rec'] == 0:
                    scores[level]['partial'][type_]['f1'] = 1
                else:
                    scores[level]['partial'][type_]['f1'] = \
                        2.0 * scores[level]['partial'][type_]['acc'] * scores[level]['partial'][type_]['rec'] / (
                            scores[level]['partial'][type_]['rec'] + scores[level]['partial'][type_]['acc'])
    print("compound: {} / {}".format(compound_correct, compound_detect))
    print_scores(scores, etype)
def evaluate(gold, predict, db_dir, etype, kmaps):
    # Multi-turn evaluation with n-best list rescoring.
    # gold file: blank-line-separated sessions, each turn "<sql>\t<db_id>".
    # predict file: per turn, three lines — tab-separated n-best SQL strings,
    # tab-separated scores (negated log-probs), then the question text.
    # NOTE(review): indentation below is reconstructed from a collapsed source;
    # ambiguous placements are flagged inline — confirm against the original.
    with open(gold) as f:
        glist = []
        gseq_one = []
        for l in f.readlines():
            if len(l.strip()) == 0:
                glist.append(gseq_one)
                gseq_one = []
            else:
                lstrip = l.strip().split('\t')
                gseq_one.append(lstrip)
        # NOTE(review): a trailing session without a final blank line is dropped
        #glist = [l.strip().split('\t') for l in f.readlines() if len(l.strip()) > 0]
    with open(predict) as f:
        plist = []
        pseq_one = []
        p_socre_list = []
        pseq_score_one = []
        question_list = []
        question_one = []
        while True:
            l = f.readline()
            if l == "":
                break
            if len(l.strip()) == 0:
                # blank line: close the current session
                plist.append(pseq_one)
                pseq_one = []
                p_socre_list.append(pseq_score_one)
                pseq_score_one = []
                question_list.append(question_one)
                question_one = []
            else:
                x = l.strip().split('\t')
                pseq_one.append(x)
                l2 = f.readline()
                y = l2.strip().split('\t')
                # convert negated log-probabilities into probabilities
                y = [math.exp(-float(s)) for s in y]
                assert len(x) == len(y)
                pseq_score_one.append(y)
                question_one.append(f.readline().strip())
                #print('len(x)', len(x))
        #plist = [l.strip().split('\t') for l in f.readlines() if len(l.strip()) > 0]
    # plist = [[("select product_type_code from products group by product_type_code order by count ( * ) desc limit value", "orchestra")]]
    # glist = [[("SELECT product_type_code FROM Products GROUP BY product_type_code ORDER BY count(*) DESC LIMIT 1", "customers_and_orders")]]
    evaluator = Evaluator()
    # second evaluator used only inside cmp() so partial_scores of `evaluator`
    # are not clobbered by candidate-merging comparisons
    evaluator2 = Evaluator()
    turns = ['turn 1', 'turn 2', 'turn 3', 'turn 4', 'turn >4']
    levels = ['easy', 'medium', 'hard', 'extra', 'all', 'joint_all']
    partial_types = [
        'select', 'select(no AGG)', 'where', 'where(no OP)',
        'group(no Having)', 'group', 'order', 'and/or', 'IUEN', 'keywords'
    ]
    entries = []
    scores = {}
    for turn in turns:
        scores[turn] = {'count': 0, 'exact': 0.}
        scores[turn]['exec'] = 0
    for level in levels:
        scores[level] = {'count': 0, 'partial': {}, 'exact': 0.}
        scores[level]['exec'] = 0
        for type_ in partial_types:
            scores[level]['partial'][type_] = {
                'acc': 0.,
                'rec': 0.,
                'f1': 0.,
                'acc_count': 0,
                'rec_count': 0
            }
    eval_err_num = 0
    # n1/n2/n3 appear unused below
    n1 = 0
    n2 = 0
    n3 = 0
    predict_file = open("./predict.txt", "w")
    for p, g, s, questions in zip(plist, glist, p_socre_list, question_list):
        scores['joint_all']['count'] += 1
        turn_scores = {"exec": [], "exact": []}
        predict_str = ''
        # NOTE: the loop below rebinds p/g/s to per-turn values, shadowing the
        # per-session variables of the outer loop
        for idx, pgs in enumerate(zip(p, g, s, questions)):
            p, g, s, question = pgs
            #p_str = p[0]
            #p_str = p_str.replace("value", "1")
            g_str, db = g
            db_name = db
            db = os.path.join(db_dir, db, db + ".sqlite")
            schema = Schema(get_schema(db))
            try:
                g_sql = get_sql(schema, g_str)
            except:
                # unparsable gold: skip the turn entirely
                continue
            hardness = evaluator.eval_hardness(g_sql)
            ori_idx = idx  # unused below
            if idx > 3:
                idx = ">4"
            else:
                idx += 1
            turn_id = "turn " + str(idx)
            scores[turn_id]['count'] += 1
            scores[hardness]['count'] += 1
            scores['all']['count'] += 1
            p_sql = None
            flag = False
            # accumulates (parsed_sql, (summed score, sql string)) for distinct
            # candidates; equivalent candidates have their scores merged
            p_sql_socre = []
            for p_str, s in zip(p, s):
                cur_s = s
                flag2 = False
                try:
                    p_str = p_str.replace("value", "1")
                    p_sql = get_sql(schema, p_str)
                    flag2 = True
                except:
                    pass
                if flag2:
                    # heuristic filter: reject candidates whose SELECT clause
                    # repeats a non-stop-word token, unless an aggregate appears
                    vis = set()
                    for ss in p_str.split(' '):
                        ss = ss.lower()
                        if ss == 'from':
                            break
                        if ss in stop_word:
                            continue
                        if ss in vis:
                            flag2 = False
                            for fk in [
                                    'none', 'max', 'min', 'count', 'sum', 'avg'
                            ]:
                                if fk in p_str.lower():
                                    flag2 = True
                                    break
                            if flag2:
                                break
                            # cmp() result unused here — looks like a leftover
                            # debug probe
                            if cmp(p_sql, g_sql, kmaps[db_name], evaluator2,
                                   schema):
                                pass
                            break
                        vis.add(ss)
                    if flag2 is False:
                        continue
                    # heuristic filter: reject "col = 1" conditions on a column
                    # already mentioned before FROM
                    slist = p_str.lower().split(' ')
                    for i in range(len(slist) - 2):
                        ss = slist[i]
                        if slist[i + 1] == '=' and slist[i + 2] == '1':
                            if ss in vis:
                                if cmp(p_sql, g_sql, kmaps[db_name],
                                       evaluator2, schema):
                                    pass
                                flag2 = False
                                break
                    if flag2 == False:
                        continue
                    # merge this candidate into an equivalent one (summing
                    # scores), or append it as a new distinct candidate
                    flag = False
                    for i in range(len(p_sql_socre)):
                        sql1 = p_sql_socre[i][0]
                        if cmp(sql1, p_sql, kmaps[db_name], evaluator2,
                               schema):
                            #print('+++')
                            p_sql_socre[i] = (sql1,
                                              (p_sql_socre[i][1][0] + cur_s,
                                               p_sql_socre[i][1][1]))
                            flag = True
                            # sanity checks: equivalence to gold must agree for
                            # candidates judged equivalent to each other
                            if cmp(sql1, g_sql, kmaps[db_name], evaluator2,
                                   schema):
                                assert cmp(p_sql, g_sql, kmaps[db_name],
                                           evaluator2, schema)
                            if cmp(p_sql, g_sql, kmaps[db_name], evaluator2,
                                   schema):
                                assert cmp(sql1, g_sql, kmaps[db_name],
                                           evaluator2, schema)
                            break
                    if flag == False:
                        p_sql_socre.append((p_sql, (cur_s, p_str)))
            # pick the candidate with the highest summed score
            p_sql = None
            max_socre = -100
            p_str = "error"
            for i in range(len(p_sql_socre)):
                sql1 = p_sql_socre[i][0]
                cur_s = p_sql_socre[i][1][0]
                cur_p_str = p_sql_socre[i][1][1]
                if p_sql == None or max_socre < cur_s:
                    p_sql = sql1
                    max_socre = cur_s
                    p_str = cur_p_str
            # disabled debug output
            if False and p_sql is None:
                print('p', p)
                print('s', s)
            # NOTE(review): this fallback replaces the score-based pick with the
            # longest parseable raw prediction whenever one is longer — confirm
            # this override of the voting result is intended
            for pi in p:
                if p_sql == None or len(p_str.split(' ')) < len(
                        pi.split(' ')):
                    try:
                        pi = pi.replace("value", "1")
                        p_sql = get_sql(schema, pi)
                        p_str = pi
                    except:
                        pass
            if p_sql is None:
                # If p_sql is not valid, then we will use an empty sql to evaluate with the correct sql
                p_sql = {
                    "except": None,
                    "from": {
                        "conds": [],
                        "table_units": []
                    },
                    "groupBy": [],
                    "having": [],
                    "intersect": None,
                    "limit": None,
                    "orderBy": [],
                    "select": [False, []],
                    "union": None,
                    "where": []
                }
                eval_err_num += 1
                print("eval_err_num:{}".format(eval_err_num))
            # rebuild sql for value evaluation
            kmap = kmaps[db_name]
            g_valid_col_units = build_valid_col_units(
                g_sql['from']['table_units'], schema)
            g_sql = rebuild_sql_val(g_sql)
            g_sql = rebuild_sql_col(g_valid_col_units, g_sql, kmap)
            p_valid_col_units = build_valid_col_units(
                p_sql['from']['table_units'], schema)
            p_sql = rebuild_sql_val(p_sql)
            p_sql = rebuild_sql_col(p_valid_col_units, p_sql, kmap)
            if etype in ["all", "exec"]:
                exec_score = eval_exec_match(db, p_str, g_str, p_sql, g_sql)
                if exec_score:
                    scores[hardness]['exec'] += 1
                    scores[turn_id]['exec'] += 1
                    turn_scores['exec'].append(1)
                else:
                    turn_scores['exec'].append(0)
            if etype in ["all", "match"]:
                exact_score = evaluator.eval_exact_match(p_sql, g_sql)
                partial_scores = evaluator.partial_scores
                if exact_score == 0:
                    turn_scores['exact'].append(0)
                    """
                    print('question: {}'.format(question))
                    print("{} pred: {}".format(hardness, p_str))
                    print("{} gold: {}".format(hardness, g_str))
                    print('')
                    """
                else:
                    """
                    print("Right")
                    print('question', question)
                    print("{} pred: {}".format(hardness, p_str))
                    print("{} gold: {}".format(hardness, g_str))
                    print('')
                    """
                    turn_scores['exact'].append(1)
                # NOTE(review): placement reconstructed — the chosen prediction
                # is echoed and buffered for ./predict.txt for every turn here
                print(p_str)
                predict_str += p_str + '\n'
                scores[turn_id]['exact'] += exact_score
                scores[hardness]['exact'] += exact_score
                scores['all']['exact'] += exact_score
                for type_ in partial_types:
                    if partial_scores[type_]['pred_total'] > 0:
                        scores[hardness]['partial'][type_][
                            'acc'] += partial_scores[type_]['acc']
                        scores[hardness]['partial'][type_]['acc_count'] += 1
                    if partial_scores[type_]['label_total'] > 0:
                        scores[hardness]['partial'][type_][
                            'rec'] += partial_scores[type_]['rec']
                        scores[hardness]['partial'][type_]['rec_count'] += 1
                    scores[hardness]['partial'][type_]['f1'] += partial_scores[
                        type_]['f1']
                    if partial_scores[type_]['pred_total'] > 0:
                        scores['all']['partial'][type_][
                            'acc'] += partial_scores[type_]['acc']
                        scores['all']['partial'][type_]['acc_count'] += 1
                    if partial_scores[type_]['label_total'] > 0:
                        scores['all']['partial'][type_][
                            'rec'] += partial_scores[type_]['rec']
                        scores['all']['partial'][type_]['rec_count'] += 1
                    scores['all']['partial'][type_]['f1'] += partial_scores[
                        type_]['f1']
                entries.append({
                    'predictSQL': p_str,
                    'goldSQL': g_str,
                    'hardness': hardness,
                    'exact': exact_score,
                    'partial': partial_scores
                })
        # joint (whole-session) accuracy: every turn must be correct
        if all(v == 1 for v in turn_scores["exec"]):
            scores['joint_all']['exec'] += 1
        if all(v == 1 for v in turn_scores["exact"]):
            scores['joint_all']['exact'] += 1
        predict_str += '\n'
        predict_file.write(predict_str)
    # normalize accumulated counts into averages
    for turn in turns:
        if scores[turn]['count'] == 0:
            continue
        if etype in ["all", "exec"]:
            scores[turn]['exec'] /= scores[turn]['count']
        if etype in ["all", "match"]:
            scores[turn]['exact'] /= scores[turn]['count']
    for level in levels:
        if scores[level]['count'] == 0:
            continue
        if etype in ["all", "exec"]:
            scores[level]['exec'] /= scores[level]['count']
        if etype in ["all", "match"]:
            scores[level]['exact'] /= scores[level]['count']
            for type_ in partial_types:
                if scores[level]['partial'][type_]['acc_count'] == 0:
                    scores[level]['partial'][type_]['acc'] = 0
                else:
                    scores[level]['partial'][type_]['acc'] = scores[level]['partial'][type_]['acc'] / \
                        scores[level]['partial'][type_]['acc_count'] * 1.0
                if scores[level]['partial'][type_]['rec_count'] == 0:
                    scores[level]['partial'][type_]['rec'] = 0
                else:
                    scores[level]['partial'][type_]['rec'] = scores[level]['partial'][type_]['rec'] / \
                        scores[level]['partial'][type_]['rec_count'] * 1.0
                # f1 convention preserved: 1 when both acc and rec are 0
                if scores[level]['partial'][type_]['acc'] == 0 and scores[
                        level]['partial'][type_]['rec'] == 0:
                    scores[level]['partial'][type_]['f1'] = 1
                else:
                    scores[level]['partial'][type_]['f1'] = \
                        2.0 * scores[level]['partial'][type_]['acc'] * scores[level]['partial'][type_]['rec'] / (
                            scores[level]['partial'][type_]['rec'] + scores[level]['partial'][type_]['acc'])
    print_scores(scores, etype)
    predict_file.close()
def evaluate(gold, predict, db_dir, etype, kmaps):
    """
    SParC-style multi-turn evaluation.
    @params:
        gold(str): path to gold file — blank-line-separated sessions, each turn
            one "<sql>\\t<db_id>" line
        predict(str): path to prediction file with the same session layout
        db_dir(str): directory containing the sqlite databases
        etype(str): 'all', 'exec', or 'match'
        kmaps(dict): db_id -> foreign-key pivot map
    Prints per-turn results and the final score table (including joint,
    whole-session accuracy under the 'joint_all' level).
    """
    with open(gold) as f:
        glist = []
        gseq_one = []
        for l in f.readlines():
            if len(l.strip()) == 0:
                glist.append(gseq_one)
                gseq_one = []
            else:
                lstrip = l.strip().split('\t')
                gseq_one.append(lstrip)
    with open(predict) as f:
        plist = []
        pseq_one = []
        for l in f.readlines():
            if len(l.strip()) == 0:
                plist.append(pseq_one)
                pseq_one = []
            else:
                pseq_one.append(l.strip().split('\t'))
    evaluator = Evaluator()
    turns = ['turn 1', 'turn 2', 'turn 3', 'turn 4', 'turn >4']
    levels = ['easy', 'medium', 'hard', 'extra', 'all', 'joint_all']
    partial_types = [
        'select', 'select(no AGG)', 'where', 'where(no OP)',
        'group(no Having)', 'group', 'order', 'and/or', 'IUEN', 'keywords'
    ]
    entries = []
    scores = {}
    for turn in turns:
        scores[turn] = {'count': 0, 'exact': 0.}
        scores[turn]['exec'] = 0
    for level in levels:
        scores[level] = {'count': 0, 'partial': {}, 'exact': 0.}
        scores[level]['exec'] = 0
        for type_ in partial_types:
            scores[level]['partial'][type_] = {
                'acc': 0., 'rec': 0., 'f1': 0., 'acc_count': 0, 'rec_count': 0
            }
    eval_err_num = 0
    for p, g in zip(plist, glist):
        print("----------------------interaction begin--------------")
        scores['joint_all']['count'] += 1
        turn_scores = {"exec": [], "exact": []}
        # NOTE: the loop rebinds p/g to per-turn values, shadowing the
        # per-session variables of the outer loop
        for idx, pg in enumerate(zip(p, g)):
            p, g = pg
            p_str = p[0]
            p_str = p_str.replace("value", "1")
            g_str, db = g
            db_name = db
            db = os.path.join(db_dir, db, db + ".sqlite")
            schema = Schema(get_schema(db))
            g_sql = get_sql(schema, g_str)
            hardness = evaluator.eval_hardness(g_sql)
            # bucket turns 5+ together
            if idx > 3:
                idx = ">4"
            else:
                idx += 1
            turn_id = "turn " + str(idx)
            scores[turn_id]['count'] += 1
            scores[hardness]['count'] += 1
            scores['all']['count'] += 1
            try:
                p_sql = get_sql(schema, p_str)
            except Exception:
                # If p_sql is not valid, then we will use an empty sql to evaluate with the correct sql
                p_sql = {
                    "except": None,
                    "from": {"conds": [], "table_units": []},
                    "groupBy": [],
                    "having": [],
                    "intersect": None,
                    "limit": None,
                    "orderBy": [],
                    "select": [False, []],
                    "union": None,
                    "where": []
                }
                eval_err_num += 1
                print("eval_err_num:{}".format(eval_err_num))
            # rebuild sql for value evaluation
            kmap = kmaps[db_name]
            g_valid_col_units = build_valid_col_units(
                g_sql['from']['table_units'], schema)
            g_sql = rebuild_sql_val(g_sql)
            g_sql = rebuild_sql_col(g_valid_col_units, g_sql, kmap)
            p_valid_col_units = build_valid_col_units(
                p_sql['from']['table_units'], schema)
            p_sql = rebuild_sql_val(p_sql)
            p_sql = rebuild_sql_col(p_valid_col_units, p_sql, kmap)
            if etype in ["all", "exec"]:
                exec_score = eval_exec_match(db, p_str, g_str, p_sql, g_sql)
                if exec_score:
                    scores[hardness]['exec'] += 1
                    scores[turn_id]['exec'] += 1
                    turn_scores['exec'].append(1)
                else:
                    turn_scores['exec'].append(0)
            if etype in ["all", "match"]:
                exact_score = evaluator.eval_exact_match(p_sql, g_sql)
                partial_scores = evaluator.partial_scores
                if exact_score == 0:
                    turn_scores['exact'].append(0)
                    print("{} pred: {}".format(hardness, p_str))
                    print("{} gold: {}".format(hardness, g_str))
                    print("")
                else:
                    print("correct")
                    print("{} pred: {}".format(hardness, p_str))
                    print("{} gold: {}".format(hardness, g_str))
                    print("")
                    turn_scores['exact'].append(1)
                scores[turn_id]['exact'] += exact_score
                scores[hardness]['exact'] += exact_score
                scores['all']['exact'] += exact_score
                for type_ in partial_types:
                    if partial_scores[type_]['pred_total'] > 0:
                        scores[hardness]['partial'][type_]['acc'] += partial_scores[type_]['acc']
                        scores[hardness]['partial'][type_]['acc_count'] += 1
                    if partial_scores[type_]['label_total'] > 0:
                        scores[hardness]['partial'][type_]['rec'] += partial_scores[type_]['rec']
                        scores[hardness]['partial'][type_]['rec_count'] += 1
                    scores[hardness]['partial'][type_]['f1'] += partial_scores[type_]['f1']
                    if partial_scores[type_]['pred_total'] > 0:
                        scores['all']['partial'][type_]['acc'] += partial_scores[type_]['acc']
                        scores['all']['partial'][type_]['acc_count'] += 1
                    if partial_scores[type_]['label_total'] > 0:
                        scores['all']['partial'][type_]['rec'] += partial_scores[type_]['rec']
                        scores['all']['partial'][type_]['rec_count'] += 1
                    scores['all']['partial'][type_]['f1'] += partial_scores[type_]['f1']
                # fix: entries.append kept inside the match branch — exact_score
                # and partial_scores are undefined when etype == "exec"
                entries.append({
                    'predictSQL': p_str,
                    'goldSQL': g_str,
                    'hardness': hardness,
                    'exact': exact_score,
                    'partial': partial_scores
                })
        # joint (whole-session) accuracy: every turn must be correct
        if all(v == 1 for v in turn_scores["exec"]):
            scores['joint_all']['exec'] += 1
        if all(v == 1 for v in turn_scores["exact"]):
            scores['joint_all']['exact'] += 1
            print("all correct")
    # normalize accumulated counts into averages
    for turn in turns:
        if scores[turn]['count'] == 0:
            continue
        if etype in ["all", "exec"]:
            scores[turn]['exec'] /= scores[turn]['count']
        if etype in ["all", "match"]:
            scores[turn]['exact'] /= scores[turn]['count']
    for level in levels:
        if scores[level]['count'] == 0:
            continue
        if etype in ["all", "exec"]:
            scores[level]['exec'] /= scores[level]['count']
        if etype in ["all", "match"]:
            scores[level]['exact'] /= scores[level]['count']
            for type_ in partial_types:
                if scores[level]['partial'][type_]['acc_count'] == 0:
                    scores[level]['partial'][type_]['acc'] = 0
                else:
                    scores[level]['partial'][type_]['acc'] = scores[level]['partial'][type_]['acc'] / \
                        scores[level]['partial'][type_]['acc_count'] * 1.0
                if scores[level]['partial'][type_]['rec_count'] == 0:
                    scores[level]['partial'][type_]['rec'] = 0
                else:
                    scores[level]['partial'][type_]['rec'] = scores[level]['partial'][type_]['rec'] / \
                        scores[level]['partial'][type_]['rec_count'] * 1.0
                # f1 convention preserved: 1 when both acc and rec are 0
                if scores[level]['partial'][type_]['acc'] == 0 and scores[level]['partial'][type_]['rec'] == 0:
                    scores[level]['partial'][type_]['f1'] = 1
                else:
                    scores[level]['partial'][type_]['f1'] = \
                        2.0 * scores[level]['partial'][type_]['acc'] * scores[level]['partial'][type_]['rec'] / (
                            scores[level]['partial'][type_]['rec'] + scores[level]['partial'][type_]['acc'])
    print_scores(scores, etype)
def parse_file_and_sql(filepath, schema, db_id):
    """
    Parse a text annotation file of (question, SQL) pairs into Spider-style
    example dicts.

    Expected layout per example: a numbered question line ("1. ...") and/or a
    "P:" paraphrase line, followed by a SQL query starting with SELECT/WITH
    that may span several lines.

    @params:
        filepath(str): path of the annotation file
        schema: schema object understood by get_sql
        db_id(str): database name
    @return:
        list of dicts with question/query/token/sql/db_id fields
    """
    with open(filepath, "r") as f:
        lines = list(f.readlines())
    ret = []
    i = 0
    questions = []
    has_prefix = False
    while i < len(lines):
        line = lines[i].lstrip().rstrip()
        line = line.replace("\r", "")
        line = line.replace("\n", "")
        if len(line) == 0:
            i += 1
            continue
        if ord('0') <= ord(line[0]) <= ord('9'):
            # numbered question line: remove the question number
            if len(questions) != 0:
                print(
                    '\n-----------------------------wrong indexing!-----------------------------------\n'
                )
                # fix: original was "'questions: ' + questions" — str + list
                # raised TypeError on the very error path it tried to report
                print('questions:', questions)
                sys.exit()
            index = line.find(".")
            if index != -1:
                line = line[index + 1:]
            if line != '' and len(line) != 0:
                questions.append(line.lstrip().rstrip())
            i += 1
            continue
        if line.startswith("P:"):
            # paraphrased question line
            index = line.find("P:")
            line = line[index + 2:]
            if line != '' and len(line) != 0:
                questions.append(line.lstrip().rstrip())
                has_prefix = True
        if (line.startswith("select") or line.startswith("SELECT") or line.startswith("Select") or \
                line.startswith("with") or line.startswith("With") or line.startswith("WITH")) and has_prefix:
            # collect the (possibly multi-line) SQL query
            sql = [line]
            i += 1
            while i < len(lines):
                line = lines[i].lstrip().rstrip()
                line = line.replace("\r", "")
                line = line.replace("\n", "")
                # stop at blank lines, question numbers, or lines that cannot
                # be a SQL continuation
                if len(line) == 0 or len(line.strip()) == 0 or ord('0') <= ord(line[0]) <= ord('9') or \
                        not (line[0].isalpha() or line[0] in ['(',')','=','<','>', '+', '-','!','\'','\"','%']):
                    break
                sql.append(line)
                i += 1
            sql = " ".join(sql)
            sql = sqlparse.format(sql, reindent=False, keyword_case='upper')
            # pad comparison operators and commas with spaces for tokenization
            sql = re.sub(r"(<=|>=|=|<|>|,)", r" \1 ", sql)
            # re-attach table-alias prefixes ("T1.") that padding detached
            sql = re.sub(r"(T\d+\.)\s", r"\1", sql)
            for q in questions:
                try:
                    # fix: dropped the Python-2 era q.encode("utf8") /
                    # sql.encode("utf8") calls — on Python 3 they produced
                    # bytes, made word_tokenize raise, and the except below
                    # silently discarded every single example
                    q_toks = word_tokenize(q)
                    query_toks = word_tokenize(sql)
                    query_toks_no_value = strip_query(sql)
                    sql_label = get_sql(schema, sql)
                    ret.append({
                        'question': q,
                        'question_toks': q_toks,
                        'query': sql,
                        'query_toks': query_toks,
                        'query_toks_no_value': query_toks_no_value,
                        'sql': sql_label,
                        'db_id': db_id
                    })
                except Exception:
                    # best effort: skip examples whose SQL cannot be parsed
                    pass
            questions = []
            has_prefix = False
            continue
        i += 1
    return ret
schemas[db_id] = schema return schemas, db_names, tables schemas, db_names, tables = get_schemas_from_json(table_file) with open(sql_path) as inf: sql_data = json.load(inf) sql_data_new = [] for data in sql_data: try: db_id = data["db_id"] schema = schemas[db_id] table = tables[db_id] schema = Schema(schema, table) sql = data["query"] sql_label = get_sql(schema, sql) data["sql"] = sql_label sql_data_new.append(data) except: print("db_id: ", db_id) print("sql: ", sql) with open(output_file, 'wt') as out: json.dump(sql_data_new, out, sort_keys=True, indent=4, separators=(',', ': '))
# NOTE(review): fragment from inside a per-example loop — `i` is presumably the
# current example dict and `query` its SQL string (both bound outside this view;
# confirm against the full file). Builds two ablated variants of the example:
# one with the WHERE clause removed and one with the SELECT clause removed.
sql_schema = construct_schema(db_id)
sql_schemas[db_id] = sql_schema
schema = schemas[db_id]
# shallow copy of the example so the original dict is not mutated
item = {key: value for key, value in i.items()}
rm_where_query = build_from(rm_where(query), schema, True)
add_rm_where = True
if rm_where_query == 0:
    # build_from signals failure with 0: fall back to the original query
    add_rm_where = False
    rm_where_query = query
rm_where_query_toks = word_tokenize(rm_where_query)
item['query'] = rm_where_query
item['query_toks'] = rm_where_query_toks
item['query_toks_no_value'] = rm_value_toks(rm_where_query)
try:
    item['sql'] = get_sql(sql_schema, rm_where_query)
except:
    # unparsable ablation: drop the variant and log the offending queries
    add_rm_where = False
    print(query)
    print(rm_where_query)
    print()
# fresh copy for the SELECT-removed variant
item = {key: value for key, value in i.items()}
rm_select_query = build_from(rm_select(query), schema, False)
add_rm_select = True
if rm_select_query == 0:
    add_rm_select = False
    rm_select_query = query
rm_select_query_toks = word_tokenize(rm_select_query)
item['query'] = rm_select_query
def evaluate(gold, predict, db_dir, etype, kmaps, plug_value, keep_distinct, progress_bar_for_each_datapoint):
    """Score predicted SQL against gold SQL, session by session.

    gold / predict are files of tab-separated entries where a blank line
    separates interaction sessions (gold lines carry "sql\\tdb_name").
    etype selects which metrics to compute: 'exec', 'match', or 'all'.
    kmaps maps a database name to its foreign-key column map; plug_value,
    keep_distinct and progress_bar_for_each_datapoint are forwarded to
    eval_exec_match.  Results are printed via print_scores; returns None.
    """
    # Read the gold file into a list of sessions (lists of [sql, db] pairs).
    with open(gold) as f:
        glist = []
        gseq_one = []
        for l in f.readlines():
            if len(l.strip()) == 0:
                glist.append(gseq_one)
                gseq_one = []
            else:
                lstrip = l.strip().split('\t')
                gseq_one.append(lstrip)
    # include the last session
    # this was previously ignored in the SParC evaluation script
    # which might lead to slight differences in scores
    if len(gseq_one) != 0:
        glist.append(gseq_one)
    # spider formatting indicates that there is only one "single turn"
    # do not report "turn accuracy" for SPIDER
    include_turn_acc = len(glist) > 1
    # Read predictions the same way (one predicted SQL per line).
    with open(predict) as f:
        plist = []
        pseq_one = []
        for l in f.readlines():
            if len(l.strip()) == 0:
                plist.append(pseq_one)
                pseq_one = []
            else:
                pseq_one.append(l.strip().split('\t'))
        if len(pseq_one) != 0:
            plist.append(pseq_one)
    assert len(plist) == len(glist), "number of sessions must equal"

    evaluator = Evaluator()
    turns = ['turn 1', 'turn 2', 'turn 3', 'turn 4', 'turn > 4']
    levels = ['easy', 'medium', 'hard', 'extra', 'all', 'joint_all']
    partial_types = ['select', 'select(no AGG)', 'where', 'where(no OP)',
                     'group(no Having)', 'group', 'order', 'and/or', 'IUEN', 'keywords']
    entries = []
    # Accumulators: per-turn and per-hardness-level counts plus exact/exec
    # totals; 'partial' tracks component-level accuracy/recall/F1 sums.
    scores = {}
    for turn in turns:
        scores[turn] = {'count': 0, 'exact': 0.}
        scores[turn]['exec'] = 0
    for level in levels:
        scores[level] = {'count': 0, 'partial': {}, 'exact': 0.}
        scores[level]['exec'] = 0
        for type_ in partial_types:
            scores[level]['partial'][type_] = {'acc': 0., 'rec': 0., 'f1': 0., 'acc_count': 0, 'rec_count': 0}

    for i, (p, g) in enumerate(zip(plist, glist)):
        if (i + 1) % 10 == 0:
            print('Evaluating %dth prediction' % (i + 1))
        scores['joint_all']['count'] += 1
        # Per-turn flags for this session; a session counts toward joint
        # accuracy only if every turn in it scores 1.
        turn_scores = {"exec": [], "exact": []}
        for idx, pg in enumerate(zip(p, g)):
            # NOTE: deliberately rebinds the loop variables p, g to the
            # current turn's entries (they are reassigned by the outer zip
            # on the next session).
            p, g = pg
            p_str = p[0]
            p_str = p_str.replace("value", "1")
            g_str, db = g
            db_name = db
            db = os.path.join(db_dir, db, db + ".sqlite")
            schema = Schema(get_schema(db))
            g_sql = get_sql(schema, g_str)
            hardness = evaluator.eval_hardness(g_sql)
            # Turns beyond the 4th are pooled into a single "> 4" bucket.
            if idx > 3:
                idx = "> 4"
            else:
                idx += 1
            turn_id = "turn " + str(idx)
            scores[turn_id]['count'] += 1
            scores[hardness]['count'] += 1
            scores['all']['count'] += 1
            try:
                p_sql = get_sql(schema, p_str)
            except:
                # If p_sql is not valid, use an empty SQL skeleton to
                # evaluate against the correct sql.
                p_sql = {
                    "except": None,
                    "from": {
                        "conds": [],
                        "table_units": []
                    },
                    "groupBy": [],
                    "having": [],
                    "intersect": None,
                    "limit": None,
                    "orderBy": [],
                    "select": [
                        False,
                        []
                    ],
                    "union": None,
                    "where": []
                }
            if etype in ["all", "exec"]:
                exec_score = eval_exec_match(db=db, p_str=p_str, g_str=g_str, plug_value=plug_value,
                                             keep_distinct=keep_distinct,
                                             progress_bar_for_each_datapoint=progress_bar_for_each_datapoint)
                if exec_score:
                    scores[hardness]['exec'] += 1
                    scores[turn_id]['exec'] += 1
                    scores['all']['exec'] += 1
                    turn_scores['exec'].append(1)
                else:
                    turn_scores['exec'].append(0)
            if etype in ["all", "match"]:
                # rebuild sql for value evaluation: canonicalize values and
                # map columns through the foreign-key map before comparing.
                kmap = kmaps[db_name]
                g_valid_col_units = build_valid_col_units(g_sql['from']['table_units'], schema)
                g_sql = rebuild_sql_val(g_sql)
                g_sql = rebuild_sql_col(g_valid_col_units, g_sql, kmap)
                p_valid_col_units = build_valid_col_units(p_sql['from']['table_units'], schema)
                p_sql = rebuild_sql_val(p_sql)
                p_sql = rebuild_sql_col(p_valid_col_units, p_sql, kmap)
                exact_score = evaluator.eval_exact_match(p_sql, g_sql)
                partial_scores = evaluator.partial_scores
                if exact_score == 0:
                    turn_scores['exact'].append(0)
                    # Log every mismatch for inspection.
                    print("{} pred: {}".format(hardness, p_str))
                    print("{} gold: {}".format(hardness, g_str))
                    print("")
                else:
                    turn_scores['exact'].append(1)
                scores[turn_id]['exact'] += exact_score
                scores[hardness]['exact'] += exact_score
                scores['all']['exact'] += exact_score
                # Accumulate component-wise partial scores; acc only counts
                # components the prediction produced, rec only those the
                # gold contains.
                for type_ in partial_types:
                    if partial_scores[type_]['pred_total'] > 0:
                        scores[hardness]['partial'][type_]['acc'] += partial_scores[type_]['acc']
                        scores[hardness]['partial'][type_]['acc_count'] += 1
                    if partial_scores[type_]['label_total'] > 0:
                        scores[hardness]['partial'][type_]['rec'] += partial_scores[type_]['rec']
                        scores[hardness]['partial'][type_]['rec_count'] += 1
                    scores[hardness]['partial'][type_]['f1'] += partial_scores[type_]['f1']
                    if partial_scores[type_]['pred_total'] > 0:
                        scores['all']['partial'][type_]['acc'] += partial_scores[type_]['acc']
                        scores['all']['partial'][type_]['acc_count'] += 1
                    if partial_scores[type_]['label_total'] > 0:
                        scores['all']['partial'][type_]['rec'] += partial_scores[type_]['rec']
                        scores['all']['partial'][type_]['rec_count'] += 1
                    scores['all']['partial'][type_]['f1'] += partial_scores[type_]['f1']
                entries.append({
                    'predictSQL': p_str,
                    'goldSQL': g_str,
                    'hardness': hardness,
                    'exact': exact_score,
                    'partial': partial_scores
                })
        # A session is jointly correct only if every turn scored 1.  NOTE:
        # all() on an empty list is True, so when a metric was not computed
        # its joint counter is still incremented (harmless: it isn't printed).
        if all(v == 1 for v in turn_scores["exec"]):
            scores['joint_all']['exec'] += 1
        if all(v == 1 for v in turn_scores["exact"]):
            scores['joint_all']['exact'] += 1

    # Normalize accumulated sums into averages.
    for turn in turns:
        if scores[turn]['count'] == 0:
            continue
        if etype in ["all", "exec"]:
            scores[turn]['exec'] /= scores[turn]['count']
        if etype in ["all", "match"]:
            scores[turn]['exact'] /= scores[turn]['count']
    for level in levels:
        if scores[level]['count'] == 0:
            continue
        if etype in ["all", "exec"]:
            scores[level]['exec'] /= scores[level]['count']
        if etype in ["all", "match"]:
            scores[level]['exact'] /= scores[level]['count']
        for type_ in partial_types:
            if scores[level]['partial'][type_]['acc_count'] == 0:
                scores[level]['partial'][type_]['acc'] = 0
            else:
                scores[level]['partial'][type_]['acc'] = scores[level]['partial'][type_]['acc'] / \
                    scores[level]['partial'][type_]['acc_count'] * 1.0
            if scores[level]['partial'][type_]['rec_count'] == 0:
                scores[level]['partial'][type_]['rec'] = 0
            else:
                scores[level]['partial'][type_]['rec'] = scores[level]['partial'][type_]['rec'] / \
                    scores[level]['partial'][type_]['rec_count'] * 1.0
            # Quirk inherited from the original Spider script: F1 is set to 1
            # (not 0) when both acc and rec are zero.
            if scores[level]['partial'][type_]['acc'] == 0 and scores[level]['partial'][type_]['rec'] == 0:
                scores[level]['partial'][type_]['f1'] = 1
            else:
                scores[level]['partial'][type_]['f1'] = \
                    2.0 * scores[level]['partial'][type_]['acc'] * scores[level]['partial'][type_]['rec'] / (
                        scores[level]['partial'][type_]['rec'] + scores[level]['partial'][type_]['acc'])

    print_scores(scores, etype, include_turn_acc=include_turn_acc)
def evaluate(gold, predict, db_dir, etype, kmaps, wtp):
    """Score predicted SQL against gold SQL, session by session.

    gold / predict / wtp are files where a blank line separates interaction
    sessions (gold lines carry "sql\\tdb_name", prediction lines one SQL).
    etype selects which metrics to compute: 'exec', 'match', or 'all'.
    kmaps maps a database name to its foreign-key column map.  Results are
    printed via print_scores; returns None.
    """
    def _read_sessions(path, split_cols):
        # One session per blank-line-separated group of lines.
        sessions = []
        one = []
        with open(path) as f:
            for l in f.readlines():
                if len(l.strip()) == 0:
                    sessions.append(one)
                    one = []
                else:
                    one.append(l.strip().split('\t') if split_cols else l)
        # BUG FIX (consistency with the sibling evaluate()): include the last
        # session even when the file does not end with a blank line.
        if len(one) != 0:
            sessions.append(one)
        return sessions

    glist = _read_sessions(gold, True)
    plist = _read_sessions(predict, True)
    wtlist = _read_sessions(wtp, False)

    evaluator = Evaluator()
    turns = ['turn 1', 'turn 2', 'turn 3', 'turn 4', 'turn >4']
    levels = ['easy', 'medium', 'hard', 'extra', 'all', 'joint_all']
    partial_types = [
        'select', 'select(no AGG)', 'where', 'where(no OP)', 'group(no Having)',
        'group', 'order', 'and/or', 'IUEN', 'keywords'
    ]
    entries = []
    # Accumulators: per-turn and per-hardness-level counts plus exact/exec
    # totals; 'partial' tracks component-level accuracy/recall/F1 sums.
    # (Dead locals eval_err_num / wheerr / bff / pb / err from the original
    # were never read and have been removed.)
    scores = {}
    for turn in turns:
        scores[turn] = {'count': 0, 'exact': 0., 'exec': 0}
    for level in levels:
        scores[level] = {'count': 0, 'partial': {}, 'exact': 0., 'exec': 0}
        for type_ in partial_types:
            scores[level]['partial'][type_] = {
                'acc': 0.,
                'rec': 0.,
                'f1': 0.,
                'acc_count': 0,
                'rec_count': 0
            }

    # zip() against wtlist bounds evaluation to the shortest of the three
    # files; the wtp entries themselves are not otherwise used here.
    for p, g, _wtl in zip(plist, glist, wtlist):
        scores['joint_all']['count'] += 1
        # Per-turn flags; a session counts toward joint accuracy only if
        # every turn in it scores 1.
        turn_scores = {"exec": [], "exact": []}
        for idx, (pred, gold_entry) in enumerate(zip(p, g)):
            p_str = pred[0].replace("value", "1")
            g_str, db = gold_entry
            db_name = db
            db = os.path.join(db_dir, db, db + ".sqlite")
            schema = Schema(get_schema(db))
            g_sql = get_sql(schema, g_str)
            hardness = evaluator.eval_hardness(g_sql)
            # Turns beyond the 4th are pooled into a single ">4" bucket.
            turn_id = "turn >4" if idx > 3 else "turn " + str(idx + 1)
            scores[turn_id]['count'] += 1
            scores[hardness]['count'] += 1
            scores['all']['count'] += 1
            try:
                p_sql = get_sql(schema, p_str)
            except Exception:
                # Unparsable prediction: score it against an empty SQL
                # skeleton so it counts as wrong rather than crashing.
                p_sql = {
                    "except": None,
                    "from": {
                        "conds": [],
                        "table_units": []
                    },
                    "groupBy": [],
                    "having": [],
                    "intersect": None,
                    "limit": None,
                    "orderBy": [],
                    "select": [False, []],
                    "union": None,
                    "where": []
                }
            # Rebuild both SQLs: canonicalize values and map columns through
            # the foreign-key map before comparing.
            kmap = kmaps[db_name]
            g_valid_col_units = build_valid_col_units(
                g_sql['from']['table_units'], schema)
            g_sql = rebuild_sql_val(g_sql)
            g_sql = rebuild_sql_col(g_valid_col_units, g_sql, kmap)
            p_valid_col_units = build_valid_col_units(
                p_sql['from']['table_units'], schema)
            p_sql = rebuild_sql_val(p_sql)
            p_sql = rebuild_sql_col(p_valid_col_units, p_sql, kmap)
            if etype in ["all", "exec"]:
                exec_score = eval_exec_match(db, p_str, g_str, p_sql, g_sql)
                if exec_score:
                    scores[hardness]['exec'] += 1
                    scores[turn_id]['exec'] += 1
                    # BUG FIX: the overall ('all') execution accuracy was
                    # never accumulated here (the sibling evaluate() does),
                    # so it always printed as 0.
                    scores['all']['exec'] += 1
                    turn_scores['exec'].append(1)
                else:
                    turn_scores['exec'].append(0)
            if etype in ["all", "match"]:
                exact_score = evaluator.eval_exact_match(p_sql, g_sql)
                partial_scores = evaluator.partial_scores
                turn_scores['exact'].append(1 if exact_score else 0)
                scores[turn_id]['exact'] += exact_score
                scores[hardness]['exact'] += exact_score
                scores['all']['exact'] += exact_score
                # Component-wise partial scores: acc only counts components
                # the prediction produced, rec only those the gold contains.
                for type_ in partial_types:
                    for level in (hardness, 'all'):
                        bucket = scores[level]['partial'][type_]
                        if partial_scores[type_]['pred_total'] > 0:
                            bucket['acc'] += partial_scores[type_]['acc']
                            bucket['acc_count'] += 1
                        if partial_scores[type_]['label_total'] > 0:
                            bucket['rec'] += partial_scores[type_]['rec']
                            bucket['rec_count'] += 1
                        bucket['f1'] += partial_scores[type_]['f1']
                entries.append({
                    'predictSQL': p_str,
                    'goldSQL': g_str,
                    'hardness': hardness,
                    'exact': exact_score,
                    'partial': partial_scores
                })
        # all() on an empty list is True, matching the original behavior when
        # a metric was not computed for this etype.
        if all(v == 1 for v in turn_scores["exec"]):
            scores['joint_all']['exec'] += 1
        if all(v == 1 for v in turn_scores["exact"]):
            scores['joint_all']['exact'] += 1

    # Normalize accumulated sums into averages.
    for turn in turns:
        if scores[turn]['count'] == 0:
            continue
        if etype in ["all", "exec"]:
            scores[turn]['exec'] /= scores[turn]['count']
        if etype in ["all", "match"]:
            scores[turn]['exact'] /= scores[turn]['count']
    for level in levels:
        if scores[level]['count'] == 0:
            continue
        if etype in ["all", "exec"]:
            scores[level]['exec'] /= scores[level]['count']
        if etype in ["all", "match"]:
            scores[level]['exact'] /= scores[level]['count']
        for type_ in partial_types:
            bucket = scores[level]['partial'][type_]
            bucket['acc'] = bucket['acc'] / bucket['acc_count'] if bucket['acc_count'] else 0
            bucket['rec'] = bucket['rec'] / bucket['rec_count'] if bucket['rec_count'] else 0
            if bucket['acc'] == 0 and bucket['rec'] == 0:
                # Quirk preserved from the original Spider script: F1 is set
                # to 1 (not 0) when both acc and rec are zero.
                bucket['f1'] = 1
            else:
                bucket['f1'] = \
                    2.0 * bucket['acc'] * bucket['rec'] / (bucket['rec'] + bucket['acc'])

    print_scores(scores, etype)