示例#1
0
def test_different_not_equal_operators():
    """The '<>' and '!=' operators must parse to the same WHERE structure."""
    expected = [(False, 7, (0, (0, '__papers.title__', False), None),
                 '"bar"', None)]
    parsed_angle = get_sql(test_schema(),
                           'SELECT * FROM papers WHERE papers.title <> "bar"')
    parsed_bang = get_sql(test_schema(),
                          'SELECT * FROM papers WHERE papers.title != "bar"')
    assert parsed_angle['where'] == expected
    assert parsed_bang['where'] == expected
示例#2
0
def test_joins():
    """JOIN, INNER JOIN and comma-separated tables yield the same FROM clause."""
    expected = {
        'conds': [],
        'table_units': [('table_unit', '__papers__'),
                        ('table_unit', '__coauthored__')]
    }
    queries = (
        'SELECT * FROM papers JOIN coauthored',
        'SELECT * FROM papers INNER JOIN coauthored',
        'SELECT * FROM papers, coauthored',
    )
    for query in queries:
        assert get_sql(test_schema(), query)['from'] == expected
示例#3
0
def test_parse_col():
    """DISTINCT with or without parentheses parses identically.

    Covers DISTINCT both inside an aggregate (COUNT) and at SELECT level.
    """
    expected = (False, [(3, (0, (0, '__papers.id__', True), None))])
    for query in ('SELECT COUNT(DISTINCT(papers.id)) FROM papers',
                  'SELECT COUNT(DISTINCT papers.id) FROM papers'):
        assert get_sql(test_schema(), query)['select'] == expected

    expected = (True, [(0, (0, (0, '__papers.id__', False), None))])
    for query in ('SELECT DISTINCT(papers.id) FROM papers',
                  'SELECT DISTINCT papers.id FROM papers'):
        assert get_sql(test_schema(), query)['select'] == expected
示例#4
0
 def single_acc(self, pred_sql, gold_sql, db, etype):
     """Score one (predicted, gold) SQL pair on a single database.

         @params:
             pred_sql(str): predicted SQL query
             gold_sql(str): gold SQL query
             db(str): db_id field, database name
             etype(str): 'exec' (execution accuracy) or 'match' (exact set match)
         @return:
             score(float): 0 or 1, etype score
             hardness(str): one of 'easy', 'medium', 'hard', 'extra'
         @raise:
             ValueError: if etype is neither 'exec' nor 'match'
     """
     db_name = db
     db = os.path.join(self.database_dir, db, db + ".sqlite")
     schema = Schema(get_schema(db))
     g_sql = get_sql(schema, gold_sql)
     hardness = self.engine.eval_hardness(g_sql)
     try:
         p_sql = get_sql(schema, pred_sql)
     except Exception:
         # If p_sql is not valid, then we will use an empty sql to evaluate
         # with the correct sql (yields a score of 0 for this example).
         p_sql = {
             "except": None,
             "from": {
                 "conds": [],
                 "table_units": []
             },
             "groupBy": [],
             "having": [],
             "intersect": None,
             "limit": None,
             "orderBy": [],
             "select": [False, []],
             "union": None,
             "where": []
         }
     kmap = self.kmaps[db_name]
     # Normalize literal values and map columns to their pivot columns so
     # that semantically equivalent queries compare equal.
     g_valid_col_units = build_valid_col_units(g_sql['from']['table_units'],
                                               schema)
     g_sql = rebuild_sql_val(g_sql)
     g_sql = rebuild_sql_col(
         g_valid_col_units, g_sql,
         kmap)  # kmap: map __tab.col__ to pivot __tab.col__
     p_valid_col_units = build_valid_col_units(p_sql['from']['table_units'],
                                               schema)
     p_sql = rebuild_sql_val(p_sql)
     p_sql = rebuild_sql_col(p_valid_col_units, p_sql, kmap)
     if etype == 'exec':
         score = float(eval_exec_match(db, pred_sql, gold_sql, p_sql,
                                       g_sql))
     elif etype == 'match':
         score = float(self.engine.eval_exact_match(p_sql, g_sql))
     else:
         # Previously an unrecognized etype crashed later with
         # UnboundLocalError on `score`; fail loudly and clearly instead.
         raise ValueError("etype must be 'exec' or 'match', got %r" % etype)
     return score, hardness
示例#5
0
def evaluate_match_per_example(g_str, p_str, db_id, db_dir, kmaps):
    """Evaluate exact-set match for a single (gold, predicted) SQL pair.

    @params:
        g_str(str): gold SQL query
        p_str(str): predicted SQL query
        db_id(str): database name
        db_dir(str): directory containing <db_id>/<db_id>.sqlite files
        kmaps(dict): db_id -> foreign-key column map
    @return:
        hardness(str): one of 'easy', 'medium', 'hard', 'extra'
        bool_err(bool): True if the prediction failed to parse
        exact_score: exact set match score from the evaluator
        partial_scores(dict): per-clause partial match scores
        p_sql(dict): rebuilt predicted SQL structure
        g_sql(dict): rebuilt gold SQL structure
    """
    evaluator = Evaluator()

    db = os.path.join(db_dir, db_id, db_id + ".sqlite")
    schema = Schema(get_schema(db))
    g_sql = get_sql(schema, g_str)
    hardness = evaluator.eval_hardness(g_sql)
    bool_err = False

    try:
        p_sql = get_sql(schema, p_str)
    except Exception:
        # If p_sql is not valid, then we will use an empty sql to evaluate
        # with the correct sql (scores 0 against any non-trivial gold).
        p_sql = {
            "except": None,
            "from": {
                "conds": [],
                "table_units": []
            },
            "groupBy": [],
            "having": [],
            "intersect": None,
            "limit": None,
            "orderBy": [],
            "select": [False, []],
            "union": None,
            "where": []
        }
        bool_err = True

    # rebuild sql for value evaluation: normalize literal values and map
    # columns to their pivot (foreign-key canonical) columns
    kmap = kmaps[db_id]
    g_valid_col_units = build_valid_col_units(g_sql['from']['table_units'],
                                              schema)
    g_sql = rebuild_sql_val(g_sql)
    g_sql = rebuild_sql_col(g_valid_col_units, g_sql, kmap)
    p_valid_col_units = build_valid_col_units(p_sql['from']['table_units'],
                                              schema)
    p_sql = rebuild_sql_val(p_sql)
    p_sql = rebuild_sql_col(p_valid_col_units, p_sql, kmap)

    # deepcopy so the returned p_sql/g_sql are not mutated by the matcher
    exact_score = evaluator.eval_exact_match(copy.deepcopy(p_sql),
                                             copy.deepcopy(g_sql))
    partial_scores = evaluator.partial_scores

    return hardness, bool_err, exact_score, partial_scores, p_sql, g_sql
示例#6
0
def do_score(evaluator, db_dir, kmaps, p, g):
    """Exact-set-match score for one predicted SQL string against gold.

    @params:
        evaluator: Evaluator instance (its partial_scores side state is updated)
        db_dir(str): directory containing <db>/<db>.sqlite databases
        kmaps(dict): db_id -> foreign-key column map
        p(str): predicted SQL query
        g(tuple): (gold SQL string, database name)
    @return:
        exact_score: exact set match score
    """
    g_str, db = g
    db_name = db
    db = os.path.join(db_dir, db, db + ".sqlite")
    schema = Schema(get_schema(db))
    g_sql = get_sql(schema, g_str)

    try:
        p_sql = get_sql(schema, p)
    except Exception:
        # If p_sql is not valid, then we will use an empty sql to evaluate
        # with the correct sql
        p_sql = {
            "except": None,
            "from": {
                "conds": [],
                "table_units": []
            },
            "groupBy": [],
            "having": [],
            "intersect": None,
            "limit": None,
            "orderBy": [],
            "select": [
                False,
                []
            ],
            "union": None,
            "where": []
        }

    # rebuild sql for value evaluation
    kmap = kmaps[db_name]
    g_valid_col_units = build_valid_col_units(g_sql['from']['table_units'], schema)
    g_sql = rebuild_sql_val(g_sql)
    g_sql = rebuild_sql_col(g_valid_col_units, g_sql, kmap)
    p_valid_col_units = build_valid_col_units(p_sql['from']['table_units'], schema)
    p_sql = rebuild_sql_val(p_sql)
    p_sql = rebuild_sql_col(p_valid_col_units, p_sql, kmap)

    exact_score = evaluator.eval_exact_match(p_sql, g_sql)
    return exact_score
示例#7
0
    def __init__(self, query, query_no_value, question, db_id, schema):
        """Build one example: tokenize the question/query and parse the SQL.

        @params:
            query(str): SQL query with literal values
            query_no_value(str): SQL query with values replaced by placeholders
            question(str): natural-language question
            db_id(str): database name
            schema: schema object accepted by get_sql
        """
        self.db_id = db_id
        self.query = query

        self.question = question
        try:
            self.question_toks = nltk.word_tokenize(self.question)
            self.query_toks = nltk.word_tokenize(self.query)
            self.query_toks_no_value = nltk.word_tokenize(query_no_value)
        except Exception:
            # Tokenization failed: mark the example unusable.
            # NOTE(review): the *_toks attributes stay undefined in this
            # branch — downstream code must check question/query for None
            # before touching them.
            self.question = None
            self.query = None

        # Parse from the local `query` argument, so this runs even when
        # self.query was reset to None above.
        self.sql = get_sql(schema, query)
示例#8
0
 def validity_check(self, sql: str, db: str):
     """ Check whether the given sql query is valid, including:
     1. only use columns in tables mentioned in FROM clause
     2. comparison operator or MAX/MIN/SUM/AVG only applied to columns of type number/time
     @params:
         sql(str): SQL query
         db(str): db_id field, database name
     @return:
         flag(boolean)
     """
     raw_schema = self.schemas[db]
     table = self.tables[db]
     wrapped_schema = SchemaID(raw_schema, table)
     try:
         # Parsing and the semantic check may both raise; either failure
         # means the query is invalid.
         parsed = get_sql(wrapped_schema, sql)
         return self.sql_check(parsed, self.database[db])
     except Exception as e:
         print('Runtime error occurs:', e)
         return False
示例#9
0
def evaluate(gold, predict, db_dir, etype, kmaps):
    """Evaluate a prediction file against a gold file and print scores.

    Both files contain one tab-separated record per non-empty line; gold
    lines are "SQL<TAB>db_id", prediction lines carry the predicted SQL
    in the first field.

    @params:
        gold(str): path to gold file
        predict(str): path to prediction file
        db_dir(str): directory containing <db_id>/<db_id>.sqlite databases
        etype(str): 'all', 'exec' or 'match'
        kmaps(dict): db_id -> foreign-key column map
    """
    with open(gold) as f:
        glist = [
            l.strip().split('\t') for l in f.readlines() if len(l.strip()) > 0
        ]

    with open(predict) as f:
        plist = [
            l.strip().split('\t') for l in f.readlines() if len(l.strip()) > 0
        ]
    # Shuffle gold and predictions together (pairing is preserved); only
    # the order of the per-example debug prints is randomized.
    mixed_list = list(zip(glist, plist))
    random.shuffle(mixed_list)
    glist, plist = zip(*mixed_list)
    evaluator = Evaluator()

    levels = ['easy', 'medium', 'hard', 'extra', 'all']
    partial_types = [
        'select', 'select(no AGG)', 'from', 'where', 'where(no OP)',
        'group(no Having)', 'group', 'order', 'and/or', 'IUEN', 'keywords'
    ]
    entries = []
    scores = {}

    for level in levels:
        scores[level] = {'count': 0, 'partial': {}, 'exact': 0.}
        scores[level]['exec'] = 0
        for type_ in partial_types:
            scores[level]['partial'][type_] = {
                'acc': 0.,
                'rec': 0.,
                'f1': 0.,
                'acc_count': 0,
                'rec_count': 0
            }

    eval_err_num = 0
    compound_correct = 0
    compound_detect = 0
    for p, g in zip(plist, glist):
        p_str = p[0]
        g_str, db = g
        db_name = db
        db = os.path.join(db_dir, db, db + ".sqlite")
        schema = Schema(get_schema(db))
        g_sql = get_sql(schema, g_str)
        assert g_sql['from']['table_units']
        hardness = evaluator.eval_hardness(g_sql)
        scores[hardness]['count'] += 1
        scores['all']['count'] += 1

        try:
            p_sql = get_sql(schema, p_str)
        except Exception:
            # If p_sql is not valid, then we will use an empty sql to
            # evaluate with the correct sql
            p_sql = {
                "except": None,
                "from": {
                    "conds": [],
                    "table_units": []
                },
                "groupBy": [],
                "having": [],
                "intersect": None,
                "limit": None,
                "orderBy": [],
                "select": [False, []],
                "union": None,
                "where": []
            }
            eval_err_num += 1
            print("eval_err_num:{}".format(eval_err_num))

        # rebuild sql for value evaluation
        kmap = kmaps[db_name]
        g_valid_col_units = build_valid_col_units(g_sql['from']['table_units'],
                                                  schema)
        g_sql = rebuild_sql_val(g_sql)
        g_sql = rebuild_sql_col(g_valid_col_units, g_sql, kmap)
        p_valid_col_units = build_valid_col_units(p_sql['from']['table_units'],
                                                  schema)
        p_sql = rebuild_sql_val(p_sql)
        p_sql = rebuild_sql_col(p_valid_col_units, p_sql, kmap)

        if etype in ["all", "exec"]:
            exec_score = eval_exec_match(db, p_str, g_str, p_sql, g_sql)
            if exec_score:
                scores[hardness]['exec'] += 1

        if etype in ["all", "match"]:
            exact_score = evaluator.eval_exact_match(p_sql, g_sql)
            # Track accuracy on compound (INTERSECT/UNION/EXCEPT) queries
            # separately; reported at the end.
            if g_sql['intersect'] or g_sql['union'] or g_sql['except']:
                compound_detect += 1
                compound_correct += exact_score
            partial_scores = evaluator.partial_scores
            if exact_score == 0:
                print("{} pred: {}".format(hardness, p_str))
                print("{} gold: {}".format(hardness, g_str))
                print("")
            scores[hardness]['exact'] += exact_score
            scores['all']['exact'] += exact_score
            # Accumulate partial clause scores into both the hardness
            # bucket and the overall 'all' bucket.
            for type_ in partial_types:
                for lvl in (hardness, 'all'):
                    bucket = scores[lvl]['partial'][type_]
                    if partial_scores[type_]['pred_total'] > 0:
                        bucket['acc'] += partial_scores[type_]['acc']
                        bucket['acc_count'] += 1
                    if partial_scores[type_]['label_total'] > 0:
                        bucket['rec'] += partial_scores[type_]['rec']
                        bucket['rec_count'] += 1
                    bucket['f1'] += partial_scores[type_]['f1']

            entries.append({
                'predictSQL': p_str,
                'goldSQL': g_str,
                'hardness': hardness,
                'exact': exact_score,
                'partial': partial_scores
            })

    # Convert accumulated sums into averages per difficulty level.
    for level in levels:
        if scores[level]['count'] == 0:
            continue
        if etype in ["all", "exec"]:
            scores[level]['exec'] /= scores[level]['count']

        if etype in ["all", "match"]:
            scores[level]['exact'] /= scores[level]['count']
            for type_ in partial_types:
                part = scores[level]['partial'][type_]
                part['acc'] = (part['acc'] / part['acc_count'] * 1.0
                               if part['acc_count'] else 0)
                part['rec'] = (part['rec'] / part['rec_count'] * 1.0
                               if part['rec_count'] else 0)
                if part['acc'] == 0 and part['rec'] == 0:
                    # Both sides empty: clause is vacuously correct.
                    part['f1'] = 1
                else:
                    part['f1'] = (2.0 * part['acc'] * part['rec'] /
                                  (part['rec'] + part['acc']))

    print("compound: {} / {}".format(compound_correct, compound_detect))
    print_scores(scores, etype)
示例#10
0
def evaluate(gold, predict, db_dir, etype, kmaps):
    """Evaluate multi-turn (interaction-based) predictions with rescoring.

    The gold file holds interactions of "SQL<TAB>db_id" lines separated by
    blank lines.  The prediction file holds, per turn, three lines: a
    tab-separated beam of candidate SQL strings, a tab-separated line of
    negative log-scores for those candidates, and the question text.
    Equivalent candidates (per `cmp`) are merged with their scores summed,
    heuristically filtered, and the best survivor is evaluated.  Selected
    predictions are also written to ./predict.txt.

    @params:
        gold(str): path to gold file
        predict(str): path to prediction file
        db_dir(str): directory containing <db_id>/<db_id>.sqlite databases
        etype(str): 'all', 'exec' or 'match'
        kmaps(dict): db_id -> foreign-key column map
    """
    with open(gold) as f:
        glist = []
        gseq_one = []
        for l in f.readlines():
            if len(l.strip()) == 0:
                # Blank line closes the current interaction.
                glist.append(gseq_one)
                gseq_one = []
            else:
                lstrip = l.strip().split('\t')
                gseq_one.append(lstrip)

    with open(predict) as f:
        plist = []
        pseq_one = []

        p_score_list = []
        pseq_score_one = []

        question_list = []
        question_one = []

        while True:
            l = f.readline()
            if l == "":
                break

            if len(l.strip()) == 0:
                # Blank line closes the current interaction.
                plist.append(pseq_one)
                pseq_one = []

                p_score_list.append(pseq_score_one)
                pseq_score_one = []

                question_list.append(question_one)
                question_one = []
            else:
                x = l.strip().split('\t')
                pseq_one.append(x)

                # Second line: per-candidate negative log-scores,
                # converted to probabilities via exp(-score).
                l2 = f.readline()
                y = l2.strip().split('\t')
                y = [math.exp(-float(s)) for s in y]
                assert len(x) == len(y)
                pseq_score_one.append(y)

                # Third line: the natural-language question.
                question_one.append(f.readline().strip())

    evaluator = Evaluator()
    evaluator2 = Evaluator()

    turns = ['turn 1', 'turn 2', 'turn 3', 'turn 4', 'turn >4']
    levels = ['easy', 'medium', 'hard', 'extra', 'all', 'joint_all']
    partial_types = [
        'select', 'select(no AGG)', 'where', 'where(no OP)',
        'group(no Having)', 'group', 'order', 'and/or', 'IUEN', 'keywords'
    ]
    entries = []
    scores = {}

    for turn in turns:
        scores[turn] = {'count': 0, 'exact': 0.}
        scores[turn]['exec'] = 0

    for level in levels:
        scores[level] = {'count': 0, 'partial': {}, 'exact': 0.}
        scores[level]['exec'] = 0
        for type_ in partial_types:
            scores[level]['partial'][type_] = {
                'acc': 0.,
                'rec': 0.,
                'f1': 0.,
                'acc_count': 0,
                'rec_count': 0
            }

    eval_err_num = 0
    # NOTE(review): not closed on exceptions; left as-is to keep the
    # original control flow (closed at the end of the function).
    predict_file = open("./predict.txt", "w")
    for p, g, s, questions in zip(plist, glist, p_score_list, question_list):
        scores['joint_all']['count'] += 1
        turn_scores = {"exec": [], "exact": []}
        predict_str = ''
        for idx, pgs in enumerate(zip(p, g, s, questions)):
            p, g, s, question = pgs
            g_str, db = g
            db_name = db
            db = os.path.join(db_dir, db, db + ".sqlite")
            schema = Schema(get_schema(db))
            try:
                g_sql = get_sql(schema, g_str)
            except Exception:
                # Unparsable gold SQL: skip this turn entirely.
                continue
            hardness = evaluator.eval_hardness(g_sql)
            if idx > 3:
                idx = ">4"
            else:
                idx += 1
            turn_id = "turn " + str(idx)
            scores[turn_id]['count'] += 1
            scores[hardness]['count'] += 1
            scores['all']['count'] += 1

            p_sql = None
            flag = False
            p_sql_score = []

            # Parse, filter, and deduplicate the candidate beam, summing
            # scores of candidates that are semantically equivalent.
            for p_str, cand_score in zip(p, s):
                cur_s = cand_score
                flag2 = False

                try:
                    p_str = p_str.replace("value", "1")
                    p_sql = get_sql(schema, p_str)
                    flag2 = True
                except Exception:
                    pass
                if flag2:
                    # Heuristic: reject candidates that repeat a
                    # non-stopword before FROM, unless an aggregate
                    # keyword appears anywhere in the candidate.
                    vis = set()
                    for ss in p_str.split(' '):
                        ss = ss.lower()
                        if ss == 'from':
                            break
                        if ss in stop_word:
                            continue
                        if ss in vis:
                            flag2 = False
                            for fk in [
                                    'none', 'max', 'min', 'count', 'sum', 'avg'
                            ]:
                                if fk in p_str.lower():
                                    flag2 = True
                                    break
                            if flag2:
                                break
                            if cmp(p_sql, g_sql, kmaps[db_name], evaluator2,
                                   schema):
                                pass
                            break
                        vis.add(ss)

                    if not flag2:
                        continue
                    # Heuristic: reject candidates comparing an already
                    # selected column to the placeholder value ("col = 1").
                    slist = p_str.lower().split(' ')
                    for i in range(len(slist) - 2):
                        ss = slist[i]
                        if slist[i + 1] == '=' and slist[i + 2] == '1':
                            if ss in vis:
                                if cmp(p_sql, g_sql, kmaps[db_name],
                                       evaluator2, schema):
                                    pass
                                flag2 = False
                                break
                    if not flag2:
                        continue
                    # Merge with an equivalent candidate already kept,
                    # accumulating its score.
                    flag = False
                    for i in range(len(p_sql_score)):
                        sql1 = p_sql_score[i][0]
                        if cmp(sql1, p_sql, kmaps[db_name], evaluator2,
                               schema):
                            p_sql_score[i] = (sql1,
                                              (p_sql_score[i][1][0] + cur_s,
                                               p_sql_score[i][1][1]))
                            flag = True

                            # Sanity: equivalence to gold must agree for
                            # both members of the merged pair.
                            if cmp(sql1, g_sql, kmaps[db_name], evaluator2,
                                   schema):
                                assert cmp(p_sql, g_sql, kmaps[db_name],
                                           evaluator2, schema)
                            if cmp(p_sql, g_sql, kmaps[db_name], evaluator2,
                                   schema):
                                assert cmp(sql1, g_sql, kmaps[db_name],
                                           evaluator2, schema)
                            break
                    if not flag:
                        p_sql_score.append((p_sql, (cur_s, p_str)))
            # Pick the candidate with the highest accumulated score.
            p_sql = None
            max_score = -100
            p_str = "error"
            for i in range(len(p_sql_score)):
                sql1 = p_sql_score[i][0]
                cur_s = p_sql_score[i][1][0]
                cur_p_str = p_sql_score[i][1][1]

                if p_sql is None or max_score < cur_s:
                    p_sql = sql1
                    max_score = cur_s
                    p_str = cur_p_str

            if False and p_sql is None:
                # Disabled fallback: take the longest parsable candidate.
                print('p', p)
                print('s', s)
                for pi in p:
                    if p_sql is None or len(p_str.split(' ')) < len(
                            pi.split(' ')):
                        try:
                            pi = pi.replace("value", "1")
                            p_sql = get_sql(schema, pi)
                            p_str = pi
                        except Exception:
                            pass

            if p_sql is None:
                # If p_sql is not valid, then we will use an empty sql to
                # evaluate with the correct sql
                p_sql = {
                    "except": None,
                    "from": {
                        "conds": [],
                        "table_units": []
                    },
                    "groupBy": [],
                    "having": [],
                    "intersect": None,
                    "limit": None,
                    "orderBy": [],
                    "select": [False, []],
                    "union": None,
                    "where": []
                }
                eval_err_num += 1
                print("eval_err_num:{}".format(eval_err_num))

            # rebuild sql for value evaluation
            kmap = kmaps[db_name]
            g_valid_col_units = build_valid_col_units(
                g_sql['from']['table_units'], schema)
            g_sql = rebuild_sql_val(g_sql)
            g_sql = rebuild_sql_col(g_valid_col_units, g_sql, kmap)
            p_valid_col_units = build_valid_col_units(
                p_sql['from']['table_units'], schema)
            p_sql = rebuild_sql_val(p_sql)
            p_sql = rebuild_sql_col(p_valid_col_units, p_sql, kmap)

            if etype in ["all", "exec"]:
                exec_score = eval_exec_match(db, p_str, g_str, p_sql, g_sql)
                if exec_score:
                    scores[hardness]['exec'] += 1
                    scores[turn_id]['exec'] += 1
                    turn_scores['exec'].append(1)
                else:
                    turn_scores['exec'].append(0)

            if etype in ["all", "match"]:
                exact_score = evaluator.eval_exact_match(p_sql, g_sql)
                partial_scores = evaluator.partial_scores
                turn_scores['exact'].append(0 if exact_score == 0 else 1)

                print(p_str)

                predict_str += p_str + '\n'

                scores[turn_id]['exact'] += exact_score
                scores[hardness]['exact'] += exact_score
                scores['all']['exact'] += exact_score
                # Accumulate partial clause scores into both the hardness
                # bucket and the overall 'all' bucket.
                for type_ in partial_types:
                    for lvl in (hardness, 'all'):
                        bucket = scores[lvl]['partial'][type_]
                        if partial_scores[type_]['pred_total'] > 0:
                            bucket['acc'] += partial_scores[type_]['acc']
                            bucket['acc_count'] += 1
                        if partial_scores[type_]['label_total'] > 0:
                            bucket['rec'] += partial_scores[type_]['rec']
                            bucket['rec_count'] += 1
                        bucket['f1'] += partial_scores[type_]['f1']

                entries.append({
                    'predictSQL': p_str,
                    'goldSQL': g_str,
                    'hardness': hardness,
                    'exact': exact_score,
                    'partial': partial_scores
                })

        # Joint (interaction-level) credit only if every turn was right.
        if all(v == 1 for v in turn_scores["exec"]):
            scores['joint_all']['exec'] += 1

        if all(v == 1 for v in turn_scores["exact"]):
            scores['joint_all']['exact'] += 1

        predict_str += '\n'
        predict_file.write(predict_str)

    # Convert accumulated sums into averages.
    for turn in turns:
        if scores[turn]['count'] == 0:
            continue
        if etype in ["all", "exec"]:
            scores[turn]['exec'] /= scores[turn]['count']

        if etype in ["all", "match"]:
            scores[turn]['exact'] /= scores[turn]['count']

    for level in levels:
        if scores[level]['count'] == 0:
            continue
        if etype in ["all", "exec"]:
            scores[level]['exec'] /= scores[level]['count']

        if etype in ["all", "match"]:
            scores[level]['exact'] /= scores[level]['count']
            for type_ in partial_types:
                part = scores[level]['partial'][type_]
                part['acc'] = (part['acc'] / part['acc_count'] * 1.0
                               if part['acc_count'] else 0)
                part['rec'] = (part['rec'] / part['rec_count'] * 1.0
                               if part['rec_count'] else 0)
                if part['acc'] == 0 and part['rec'] == 0:
                    # Both sides empty: clause is vacuously correct.
                    part['f1'] = 1
                else:
                    part['f1'] = (2.0 * part['acc'] * part['rec'] /
                                  (part['rec'] + part['acc']))

    print_scores(scores, etype)
    predict_file.close()
示例#11
0
def evaluate(gold, predict, db_dir, etype, kmaps):
    """Evaluate predicted SQL against gold SQL, interaction by interaction.

    Args:
        gold: path to the gold file; one "<sql>\t<db_id>" line per turn,
            interactions separated by blank lines.
        predict: path to the prediction file with the same layout.
        db_dir: directory containing "<db_id>/<db_id>.sqlite" databases.
        etype: which metrics to compute -- "all", "exec" or "match".
        kmaps: per-database foreign-key maps used to canonicalize columns.

    Scores are accumulated per turn, per hardness level and jointly per
    interaction, then reported through print_scores.  Returns None.
    """
    # Read gold interactions: each one is a list of [sql, db_id] pairs.
    with open(gold) as f:
        glist = []
        gseq_one = []
        for l in f.readlines():
            if len(l.strip()) == 0:
                glist.append(gseq_one)
                gseq_one = []
            else:
                lstrip = l.strip().split('\t')
                gseq_one.append(lstrip)
        # BUGFIX: keep the final interaction when the file does not end
        # with a blank line (it was silently dropped before).
        if len(gseq_one) != 0:
            glist.append(gseq_one)

    with open(predict) as f:
        plist = []
        pseq_one = []
        for l in f.readlines():
            if len(l.strip()) == 0:
                plist.append(pseq_one)
                pseq_one = []
            else:
                pseq_one.append(l.strip().split('\t'))
        # BUGFIX: same for the final predicted interaction.
        if len(pseq_one) != 0:
            plist.append(pseq_one)

    evaluator = Evaluator()

    turns = ['turn 1', 'turn 2', 'turn 3', 'turn 4', 'turn >4']
    levels = ['easy', 'medium', 'hard', 'extra', 'all', 'joint_all']
    partial_types = [
        'select', 'select(no AGG)', 'where', 'where(no OP)',
        'group(no Having)', 'group', 'order', 'and/or', 'IUEN', 'keywords'
    ]
    entries = []
    scores = {}

    # Per-turn accumulators (counts plus exact-match / execution sums).
    for turn in turns:
        scores[turn] = {'count': 0, 'exact': 0.}
        scores[turn]['exec'] = 0

    # Per-hardness accumulators, including per-component partial scores.
    for level in levels:
        scores[level] = {'count': 0, 'partial': {}, 'exact': 0.}
        scores[level]['exec'] = 0
        for type_ in partial_types:
            scores[level]['partial'][type_] = {
                'acc': 0.,
                'rec': 0.,
                'f1': 0.,
                'acc_count': 0,
                'rec_count': 0
            }

    eval_err_num = 0
    for p, g in zip(plist, glist):
        print("----------------------interaction begin--------------")
        scores['joint_all']['count'] += 1
        turn_scores = {"exec": [], "exact": []}
        # NOTE: p and g are rebound to single turns inside this loop; the
        # zip iterator above was created first, so iteration is unaffected.
        for idx, pg in enumerate(zip(p, g)):
            p, g = pg
            p_str = p[0]
            # "value" placeholders are plugged with a literal 1 for parsing.
            p_str = p_str.replace("value", "1")
            g_str, db = g
            db_name = db
            db = os.path.join(db_dir, db, db + ".sqlite")
            schema = Schema(get_schema(db))
            g_sql = get_sql(schema, g_str)
            hardness = evaluator.eval_hardness(g_sql)
            # Turns beyond the 4th are bucketed together as "turn >4".
            if idx > 3:
                idx = ">4"
            else:
                idx += 1
            turn_id = "turn " + str(idx)
            scores[turn_id]['count'] += 1
            scores[hardness]['count'] += 1
            scores['all']['count'] += 1

            try:
                p_sql = get_sql(schema, p_str)
            except Exception:
                # If p_sql is not valid, then we will use an empty sql to evaluate with the correct sql
                p_sql = {
                    "except": None,
                    "from": {
                        "conds": [],
                        "table_units": []
                    },
                    "groupBy": [],
                    "having": [],
                    "intersect": None,
                    "limit": None,
                    "orderBy": [],
                    "select": [False, []],
                    "union": None,
                    "where": []
                }
                eval_err_num += 1
                print("eval_err_num:{}".format(eval_err_num))

            # rebuild sql for value evaluation
            kmap = kmaps[db_name]
            g_valid_col_units = build_valid_col_units(
                g_sql['from']['table_units'], schema)
            g_sql = rebuild_sql_val(g_sql)
            g_sql = rebuild_sql_col(g_valid_col_units, g_sql, kmap)
            p_valid_col_units = build_valid_col_units(
                p_sql['from']['table_units'], schema)
            p_sql = rebuild_sql_val(p_sql)
            p_sql = rebuild_sql_col(p_valid_col_units, p_sql, kmap)

            if etype in ["all", "exec"]:
                exec_score = eval_exec_match(db, p_str, g_str, p_sql, g_sql)
                if exec_score:
                    scores[hardness]['exec'] += 1
                    scores[turn_id]['exec'] += 1
                    # BUGFIX: the overall "all" execution count was never
                    # incremented, so all-level exec accuracy was always 0
                    # (the sibling evaluator does increment it).
                    scores['all']['exec'] += 1
                    turn_scores['exec'].append(1)
                else:
                    turn_scores['exec'].append(0)

            if etype in ["all", "match"]:
                exact_score = evaluator.eval_exact_match(p_sql, g_sql)
                partial_scores = evaluator.partial_scores
                if exact_score == 0:
                    turn_scores['exact'].append(0)
                    print("{} pred: {}".format(hardness, p_str))
                    print("{} gold: {}".format(hardness, g_str))
                    print("")
                else:
                    print("correct")
                    print("{} pred: {}".format(hardness, p_str))
                    print("{} gold: {}".format(hardness, g_str))
                    print("")
                    turn_scores['exact'].append(1)
                scores[turn_id]['exact'] += exact_score
                scores[hardness]['exact'] += exact_score
                scores['all']['exact'] += exact_score
                # Accumulate per-component (partial) precision/recall/F1.
                for type_ in partial_types:
                    if partial_scores[type_]['pred_total'] > 0:
                        scores[hardness]['partial'][type_][
                            'acc'] += partial_scores[type_]['acc']
                        scores[hardness]['partial'][type_]['acc_count'] += 1
                    if partial_scores[type_]['label_total'] > 0:
                        scores[hardness]['partial'][type_][
                            'rec'] += partial_scores[type_]['rec']
                        scores[hardness]['partial'][type_]['rec_count'] += 1
                    scores[hardness]['partial'][type_]['f1'] += partial_scores[
                        type_]['f1']
                    if partial_scores[type_]['pred_total'] > 0:
                        scores['all']['partial'][type_][
                            'acc'] += partial_scores[type_]['acc']
                        scores['all']['partial'][type_]['acc_count'] += 1
                    if partial_scores[type_]['label_total'] > 0:
                        scores['all']['partial'][type_][
                            'rec'] += partial_scores[type_]['rec']
                        scores['all']['partial'][type_]['rec_count'] += 1
                    scores['all']['partial'][type_]['f1'] += partial_scores[
                        type_]['f1']

                entries.append({
                    'predictSQL': p_str,
                    'goldSQL': g_str,
                    'hardness': hardness,
                    'exact': exact_score,
                    'partial': partial_scores
                })

        # An interaction is jointly correct only if every turn is correct.
        if all(v == 1 for v in turn_scores["exec"]):
            scores['joint_all']['exec'] += 1

        if all(v == 1 for v in turn_scores["exact"]):
            scores['joint_all']['exact'] += 1
            print("all correct")

    # Normalize accumulated sums into per-turn accuracies.
    for turn in turns:
        if scores[turn]['count'] == 0:
            continue
        if etype in ["all", "exec"]:
            scores[turn]['exec'] /= scores[turn]['count']

        if etype in ["all", "match"]:
            scores[turn]['exact'] /= scores[turn]['count']

    # Normalize per-level accuracies and derive partial P/R/F1.
    for level in levels:
        if scores[level]['count'] == 0:
            continue
        if etype in ["all", "exec"]:
            scores[level]['exec'] /= scores[level]['count']

        if etype in ["all", "match"]:
            scores[level]['exact'] /= scores[level]['count']
            for type_ in partial_types:
                if scores[level]['partial'][type_]['acc_count'] == 0:
                    scores[level]['partial'][type_]['acc'] = 0
                else:
                    scores[level]['partial'][type_]['acc'] = scores[level]['partial'][type_]['acc'] / \
                                                             scores[level]['partial'][type_]['acc_count'] * 1.0
                if scores[level]['partial'][type_]['rec_count'] == 0:
                    scores[level]['partial'][type_]['rec'] = 0
                else:
                    scores[level]['partial'][type_]['rec'] = scores[level]['partial'][type_]['rec'] / \
                                                             scores[level]['partial'][type_]['rec_count'] * 1.0
                # F1 defined as 1 when both components are 0 (nothing to
                # predict and nothing predicted for this component).
                if scores[level]['partial'][type_]['acc'] == 0 and scores[
                        level]['partial'][type_]['rec'] == 0:
                    scores[level]['partial'][type_]['f1'] = 1
                else:
                    scores[level]['partial'][type_]['f1'] = \
                        2.0 * scores[level]['partial'][type_]['acc'] * scores[level]['partial'][type_]['rec'] / (
                        scores[level]['partial'][type_]['rec'] + scores[level]['partial'][type_]['acc'])

    print_scores(scores, etype)
示例#12
0
def parse_file_and_sql(filepath, schema, db_id):
    """Parse a raw annotation file into (question, SQL) training examples.

    Expected file layout per example: a numbered question line ("1. ..."),
    optionally a paraphrase line prefixed with "P:", then the SQL query
    (which may span several lines).

    Args:
        filepath: path to the raw annotation text file.
        schema: Schema object passed to get_sql for parsing each query.
        db_id: database id recorded in every produced example.

    Returns:
        List of dicts with question/query tokens and the parsed 'sql' label.
        Examples whose SQL fails to tokenize or parse are skipped.
    """
    with open(filepath, "r") as f:
        lines = list(f.readlines())
    ret = []
    i = 0
    questions = []
    has_prefix = False
    while i < len(lines):
        line = lines[i].lstrip().rstrip()
        line = line.replace("\r", "")
        line = line.replace("\n", "")
        if len(line) == 0:
            i += 1
            continue
        if ord('0') <= ord(line[0]) <= ord('9'):
            # A new numbered question: the previous ones must already have
            # been consumed, otherwise the file's indexing is broken.
            if len(questions) != 0:
                print(
                    '\n-----------------------------wrong indexing!-----------------------------------\n'
                )
                # BUGFIX: questions is a list; str + list raised TypeError.
                print('questions: ' + str(questions))
                sys.exit()
            index = line.find(".")
            if index != -1:
                line = line[index + 1:]  # drop the question number
            if line != '' and len(line) != 0:
                questions.append(line.lstrip().rstrip())
            i += 1
            continue
        # BUGFIX: the "P:" and SQL branches below were dedented outside the
        # while loop, so any ordinary line never advanced `i` (infinite
        # loop) and `line` could be read before assignment.
        if line.startswith("P:"):
            # Paraphrase line: strip the "P:" prefix and keep the text.
            index = line.find("P:")
            line = line[index + 2:]
            if line != '' and len(line) != 0:
                questions.append(line.lstrip().rstrip())
            has_prefix = True
        if (line.startswith("select") or line.startswith("SELECT") or line.startswith("Select") or \
            line.startswith("with") or line.startswith("With") or line.startswith("WITH")) and has_prefix:
            # Collect the (possibly multi-line) SQL query; stop at a blank
            # line, a new question number, or a non-SQL-looking character.
            sql = [line]
            i += 1
            while i < len(lines):
                line = lines[i].lstrip().rstrip()
                line = line.replace("\r", "")
                line = line.replace("\n", "")
                if len(line) == 0 or len(line.strip()) == 0 or ord('0') <= ord(line[0]) <= ord('9') or \
                   not (line[0].isalpha() or line[0] in ['(',')','=','<','>', '+', '-','!','\'','\"','%']):
                    break
                sql.append(line)
                i += 1
            sql = " ".join(sql)
            sql = sqlparse.format(sql, reindent=False, keyword_case='upper')
            # Pad comparison operators and commas with spaces for tokenizing.
            sql = re.sub(r"(<=|>=|=|<|>|,)", r" \1 ", sql)
            # Re-join alias references split as "T1. col" -> "T1.col".
            sql = re.sub(r"(T\d+\.)\s", r"\1", sql)
            for q in questions:
                try:
                    # NOTE(review): looks like Python 2-era code -- on
                    # Python 3, encode() yields bytes and the tokenizers
                    # below will raise (caught and skipped); confirm the
                    # intended runtime.
                    q = q.encode("utf8")
                    sql = sql.encode("utf8")
                    q_toks = word_tokenize(q)
                    query_toks = word_tokenize(sql)
                    query_toks_no_value = strip_query(sql)
                    sql_label = get_sql(schema, sql)
                    ret.append({
                        'question': q,
                        'question_toks': q_toks,
                        'query': sql,
                        'query_toks': query_toks,
                        'query_toks_no_value': query_toks_no_value,
                        'sql': sql_label,
                        'db_id': db_id
                    })
                except Exception as e:
                    # Skip examples whose SQL fails to tokenize/parse.
                    pass
            questions = []
            has_prefix = False
            # `i` already points at the line that terminated the SQL;
            # re-examine it as a potential next question.
            continue
        i += 1

    return ret
示例#13
0
        schemas[db_id] = schema

    return schemas, db_names, tables


# Attach a parsed 'sql' label to every example in sql_data.  Examples whose
# query fails to parse are dropped with a diagnostic instead of aborting.
schemas, db_names, tables = get_schemas_from_json(table_file)

with open(sql_path) as inf:
    sql_data = json.load(inf)

sql_data_new = []
for data in sql_data:
    # Bind these before the try so the error report below cannot itself
    # fail on an unbound name (as the previous version could).
    db_id = data.get("db_id")
    sql = data.get("query")
    try:
        schema = Schema(schemas[db_id], tables[db_id])
        data["sql"] = get_sql(schema, sql)
        sql_data_new.append(data)
    except Exception as e:
        # BUGFIX: bare `except:` also swallowed SystemExit and
        # KeyboardInterrupt and reported nothing about the failure.
        print("db_id: ", db_id)
        print("sql: ", sql)
        print("error: ", e)

with open(output_file, 'wt') as out:
    json.dump(sql_data_new,
              out,
              sort_keys=True,
              indent=4,
              separators=(',', ': '))
            sql_schema = construct_schema(db_id)
            sql_schemas[db_id] = sql_schema
        schema = schemas[db_id]
        item = {key: value for key, value in i.items()}
        rm_where_query = build_from(rm_where(query), schema, True)
        add_rm_where = True
        if rm_where_query == 0:
            add_rm_where = False
            rm_where_query = query

        rm_where_query_toks = word_tokenize(rm_where_query)
        item['query'] = rm_where_query
        item['query_toks'] = rm_where_query_toks
        item['query_toks_no_value'] = rm_value_toks(rm_where_query)
        try:
            item['sql'] = get_sql(sql_schema, rm_where_query)
        except:
            add_rm_where = False
            print(query)
            print(rm_where_query)
            print()

        item = {key: value for key, value in i.items()}
        rm_select_query = build_from(rm_select(query), schema, False)
        add_rm_select = True
        if rm_select_query == 0:
            add_rm_select = False
            rm_select_query = query

        rm_select_query_toks = word_tokenize(rm_select_query)
        item['query'] = rm_select_query
示例#15
0
def evaluate(gold, predict, db_dir, etype, kmaps, plug_value, keep_distinct, progress_bar_for_each_datapoint):
    """Evaluate predicted SQL against gold SQL, session by session.

    gold / predict are paths to files holding one "<sql>\t<db_id>" line per
    turn, with a blank line between sessions.  db_dir contains the
    per-database sqlite files ("<db_id>/<db_id>.sqlite").  etype selects
    which metrics to compute ("all", "exec" or "match"); kmaps are
    per-database foreign-key maps used to canonicalize columns for
    exact-match scoring.  plug_value, keep_distinct and
    progress_bar_for_each_datapoint are forwarded to eval_exec_match.
    Results are reported through print_scores; returns None.
    """

    # Read gold sessions: each session is a list of [sql, db_id] pairs.
    with open(gold) as f:
        glist = []
        gseq_one = []
        for l in f.readlines():
            if len(l.strip()) == 0:
                glist.append(gseq_one)
                gseq_one = []
            else:
                lstrip = l.strip().split('\t')
                gseq_one.append(lstrip)

        # include the last session
        # this was previously ignored in the SParC evaluation script
        # which might lead to slight differences in scores
        if len(gseq_one) != 0:
            glist.append(gseq_one)

    # spider formatting indicates that there is only one "single turn"
    # do not report "turn accuracy" for SPIDER
    include_turn_acc = len(glist) > 1

    with open(predict) as f:
        plist = []
        pseq_one = []
        for l in f.readlines():
            if len(l.strip()) == 0:
                plist.append(pseq_one)
                pseq_one = []
            else:
                pseq_one.append(l.strip().split('\t'))

        if len(pseq_one) != 0:
            plist.append(pseq_one)

    assert len(plist) == len(glist), "number of sessions must equal"

    evaluator = Evaluator()
    turns = ['turn 1', 'turn 2', 'turn 3', 'turn 4', 'turn > 4']
    levels = ['easy', 'medium', 'hard', 'extra', 'all', 'joint_all']

    partial_types = ['select', 'select(no AGG)', 'where', 'where(no OP)', 'group(no Having)',
                     'group', 'order', 'and/or', 'IUEN', 'keywords']
    entries = []
    scores = {}

    # Per-turn accumulators (counts plus exact-match / execution sums).
    for turn in turns:
        scores[turn] = {'count': 0, 'exact': 0.}
        scores[turn]['exec'] = 0

    # Per-hardness accumulators, including per-component partial scores.
    for level in levels:
        scores[level] = {'count': 0, 'partial': {}, 'exact': 0.}
        scores[level]['exec'] = 0
        for type_ in partial_types:
            scores[level]['partial'][type_] = {'acc': 0., 'rec': 0., 'f1': 0.,'acc_count':0,'rec_count':0}

    for i, (p, g) in enumerate(zip(plist, glist)):
        if (i + 1) % 10 == 0:
            print('Evaluating %dth prediction' % (i + 1))
        scores['joint_all']['count'] += 1
        turn_scores = {"exec": [], "exact": []}
        # NOTE: p and g are rebound to single turns inside this loop; the
        # zip iterator above was created first, so iteration is unaffected.
        for idx, pg in enumerate(zip(p, g)):
            p, g = pg
            p_str = p[0]
            # "value" placeholders are plugged with a literal 1 for parsing.
            p_str = p_str.replace("value", "1")
            g_str, db = g
            db_name = db
            db = os.path.join(db_dir, db, db + ".sqlite")
            schema = Schema(get_schema(db))
            g_sql = get_sql(schema, g_str)
            hardness = evaluator.eval_hardness(g_sql)
            # Turns beyond the 4th are bucketed together as "turn > 4".
            if idx > 3:
                idx = "> 4"
            else:
                idx += 1
            turn_id = "turn " + str(idx)
            scores[turn_id]['count'] += 1
            scores[hardness]['count'] += 1
            scores['all']['count'] += 1

            try:
                p_sql = get_sql(schema, p_str)
            except:
                # If p_sql is not valid, then we will use an empty sql to evaluate with the correct sql
                p_sql = {
                "except": None,
                "from": {
                    "conds": [],
                    "table_units": []
                },
                "groupBy": [],
                "having": [],
                "intersect": None,
                "limit": None,
                "orderBy": [],
                "select": [
                    False,
                    []
                ],
                "union": None,
                "where": []
                }

            if etype in ["all", "exec"]:
                # Execution accuracy: compare result sets on the database.
                exec_score = eval_exec_match(db=db, p_str=p_str, g_str=g_str, plug_value=plug_value,
                                             keep_distinct=keep_distinct, progress_bar_for_each_datapoint=progress_bar_for_each_datapoint)
                if exec_score:
                    scores[hardness]['exec'] += 1
                    scores[turn_id]['exec'] += 1
                    scores['all']['exec'] += 1
                    turn_scores['exec'].append(1)
                else:
                    turn_scores['exec'].append(0)

            if etype in ["all", "match"]:
                # rebuild sql for value evaluation
                kmap = kmaps[db_name]
                g_valid_col_units = build_valid_col_units(g_sql['from']['table_units'], schema)
                g_sql = rebuild_sql_val(g_sql)
                g_sql = rebuild_sql_col(g_valid_col_units, g_sql, kmap)
                p_valid_col_units = build_valid_col_units(p_sql['from']['table_units'], schema)
                p_sql = rebuild_sql_val(p_sql)
                p_sql = rebuild_sql_col(p_valid_col_units, p_sql, kmap)
                exact_score = evaluator.eval_exact_match(p_sql, g_sql)
                partial_scores = evaluator.partial_scores
                if exact_score == 0:
                    turn_scores['exact'].append(0)
                    print("{} pred: {}".format(hardness, p_str))
                    print("{} gold: {}".format(hardness, g_str))
                    print("")
                else:
                    turn_scores['exact'].append(1)
                scores[turn_id]['exact'] += exact_score
                scores[hardness]['exact'] += exact_score
                scores['all']['exact'] += exact_score
                # Accumulate per-component (partial) precision/recall/F1.
                for type_ in partial_types:
                    if partial_scores[type_]['pred_total'] > 0:
                        scores[hardness]['partial'][type_]['acc'] += partial_scores[type_]['acc']
                        scores[hardness]['partial'][type_]['acc_count'] += 1
                    if partial_scores[type_]['label_total'] > 0:
                        scores[hardness]['partial'][type_]['rec'] += partial_scores[type_]['rec']
                        scores[hardness]['partial'][type_]['rec_count'] += 1
                    scores[hardness]['partial'][type_]['f1'] += partial_scores[type_]['f1']
                    if partial_scores[type_]['pred_total'] > 0:
                        scores['all']['partial'][type_]['acc'] += partial_scores[type_]['acc']
                        scores['all']['partial'][type_]['acc_count'] += 1
                    if partial_scores[type_]['label_total'] > 0:
                        scores['all']['partial'][type_]['rec'] += partial_scores[type_]['rec']
                        scores['all']['partial'][type_]['rec_count'] += 1
                    scores['all']['partial'][type_]['f1'] += partial_scores[type_]['f1']

                entries.append({
                    'predictSQL': p_str,
                    'goldSQL': g_str,
                    'hardness': hardness,
                    'exact': exact_score,
                    'partial': partial_scores
                })

        # A session is jointly correct only if every turn is correct.
        if all(v == 1 for v in turn_scores["exec"]):
            scores['joint_all']['exec'] += 1

        if all(v == 1 for v in turn_scores["exact"]):
            scores['joint_all']['exact'] += 1

    # Normalize accumulated sums into per-turn accuracies.
    for turn in turns:
        if scores[turn]['count'] == 0:
            continue
        if etype in ["all", "exec"]:
            scores[turn]['exec'] /= scores[turn]['count']

        if etype in ["all", "match"]:
            scores[turn]['exact'] /= scores[turn]['count']

    # Normalize per-level accuracies and derive partial P/R/F1.
    for level in levels:
        if scores[level]['count'] == 0:
            continue
        if etype in ["all", "exec"]:
            scores[level]['exec'] /= scores[level]['count']

        if etype in ["all", "match"]:
            scores[level]['exact'] /= scores[level]['count']
            for type_ in partial_types:
                if scores[level]['partial'][type_]['acc_count'] == 0:
                    scores[level]['partial'][type_]['acc'] = 0
                else:
                    scores[level]['partial'][type_]['acc'] = scores[level]['partial'][type_]['acc'] / \
                                                             scores[level]['partial'][type_]['acc_count'] * 1.0
                if scores[level]['partial'][type_]['rec_count'] == 0:
                    scores[level]['partial'][type_]['rec'] = 0
                else:
                    scores[level]['partial'][type_]['rec'] = scores[level]['partial'][type_]['rec'] / \
                                                             scores[level]['partial'][type_]['rec_count'] * 1.0
                # F1 defined as 1 when both components are 0 (nothing to
                # predict and nothing predicted for this component).
                if scores[level]['partial'][type_]['acc'] == 0 and scores[level]['partial'][type_]['rec'] == 0:
                    scores[level]['partial'][type_]['f1'] = 1
                else:
                    scores[level]['partial'][type_]['f1'] = \
                        2.0 * scores[level]['partial'][type_]['acc'] * scores[level]['partial'][type_]['rec'] / (
                        scores[level]['partial'][type_]['rec'] + scores[level]['partial'][type_]['acc'])

    print_scores(scores, etype, include_turn_acc=include_turn_acc)
示例#16
0
def evaluate(gold, predict, db_dir, etype, kmaps, wtp):
    """Evaluate predicted SQL against gold SQL, session by session.

    Same layout as the sibling evaluators: gold / predict hold one
    "<sql>\t<db_id>" line per turn with blank lines between sessions;
    db_dir contains "<db_id>/<db_id>.sqlite" databases; etype selects the
    metrics ("all", "exec" or "match"); kmaps are per-database foreign-key
    maps.  wtp is a third file read in the same blank-line-separated
    format.  Results are reported through print_scores; returns None.

    NOTE(review): sessions after the last blank line are dropped (the
    trailing gseq_one/pseq_one/wt_one are never appended), and
    scores['all']['exec'] is never incremented, so all-level exec accuracy
    always reads 0 -- both are fixed in a sibling evaluator; confirm
    whether this copy should match.
    """
    # Read gold sessions: each session is a list of [sql, db_id] pairs.
    with open(gold) as f:
        glist = []
        gseq_one = []
        for l in f.readlines():
            if len(l.strip()) == 0:
                glist.append(gseq_one)
                gseq_one = []
            else:
                lstrip = l.strip().split('\t')
                gseq_one.append(lstrip)
    with open(predict) as f:
        plist = []
        pseq_one = []
        for l in f.readlines():
            if len(l.strip()) == 0:
                plist.append(pseq_one)
                pseq_one = []
            else:
                pseq_one.append(l.strip().split('\t'))
    # Read the wtp file as raw lines, grouped per session.
    with open(wtp) as f:
        wtlist = []
        wt_one = []
        for l in f.readlines():
            if len(l.strip()) == 0:
                wtlist.append(wt_one)
                wt_one = []
            else:
                wt_one.append(l)

    evaluator = Evaluator()

    turns = ['turn 1', 'turn 2', 'turn 3', 'turn 4', 'turn >4']
    levels = ['easy', 'medium', 'hard', 'extra', 'all', 'joint_all']
    partial_types = [
        'select', 'select(no AGG)', 'where', 'where(no OP)',
        'group(no Having)', 'group', 'order', 'and/or', 'IUEN', 'keywords'
    ]
    entries = []
    scores = {}

    # Per-turn accumulators (counts plus exact-match / execution sums).
    for turn in turns:
        scores[turn] = {'count': 0, 'exact': 0.}
        scores[turn]['exec'] = 0

    # Per-hardness accumulators, including per-component partial scores.
    for level in levels:
        scores[level] = {'count': 0, 'partial': {}, 'exact': 0.}
        scores[level]['exec'] = 0
        for type_ in partial_types:
            scores[level]['partial'][type_] = {
                'acc': 0.,
                'rec': 0.,
                'f1': 0.,
                'acc_count': 0,
                'rec_count': 0
            }
    eval_err_num = 0
    # NOTE(review): wheerr is never updated or read after this point.
    wheerr = 0
    # NOTE(review): wtl is unpacked from wtlist but unused in the loop body;
    # zipping with wtlist only truncates to the shortest of the three lists.
    for p, g, wtl in zip(plist, glist, wtlist):
        scores['joint_all']['count'] += 1
        turn_scores = {"exec": [], "exact": []}
        # bff/pb track whether the most recent turn matched, and its SQL.
        bff = 0
        pb = ''
        # NOTE: p and g are rebound to single turns inside this loop; the
        # zip iterator above was created first, so iteration is unaffected.
        for idx, pg in enumerate(zip(p, g)):
            p, g = pg

            p_str = p[0]
            # "value" placeholders are plugged with a literal 1 for parsing.
            p_str = p_str.replace("value", "1")
            g_str, db = g
            db_name = db
            db = os.path.join(db_dir, db, db + ".sqlite")
            schema = Schema(get_schema(db))
            g_sql = get_sql(schema, g_str)
            hardness = evaluator.eval_hardness(g_sql)
            # Turns beyond the 4th are bucketed together as "turn >4".
            if idx > 3:
                idx = ">4"
            else:
                idx += 1
            turn_id = "turn " + str(idx)
            scores[turn_id]['count'] += 1
            scores[hardness]['count'] += 1
            scores['all']['count'] += 1
            # NOTE(review): err is set below but never read afterwards.
            err = 0
            try:
                p_sql = get_sql(schema, p_str)
            except Exception as e:
                # Unparseable prediction: score it as an empty SQL.
                p_sql = {
                    "except": None,
                    "from": {
                        "conds": [],
                        "table_units": []
                    },
                    "groupBy": [],
                    "having": [],
                    "intersect": None,
                    "limit": None,
                    "orderBy": [],
                    "select": [False, []],
                    "union": None,
                    "where": []
                }
                eval_err_num += 1
                err = 1

            # Rebuild both SQLs with canonical columns/values for scoring.
            kmap = kmaps[db_name]
            g_valid_col_units = build_valid_col_units(
                g_sql['from']['table_units'], schema)
            g_sql = rebuild_sql_val(g_sql)
            g_sql = rebuild_sql_col(g_valid_col_units, g_sql, kmap)
            p_valid_col_units = build_valid_col_units(
                p_sql['from']['table_units'], schema)
            p_sql = rebuild_sql_val(p_sql)
            p_sql = rebuild_sql_col(p_valid_col_units, p_sql, kmap)

            if etype in ["all", "exec"]:
                # Execution accuracy on the actual database.
                exec_score = eval_exec_match(db, p_str, g_str, p_sql, g_sql)
                if exec_score:

                    scores[hardness]['exec'] += 1
                    scores[turn_id]['exec'] += 1
                    turn_scores['exec'].append(1)
                else:
                    turn_scores['exec'].append(0)

            if etype in ["all", "match"]:

                exact_score = evaluator.eval_exact_match(p_sql, g_sql)
                partial_scores = evaluator.partial_scores
                if exact_score == 0:

                    turn_scores['exact'].append(0)

                    bff = 0
                else:

                    bff = 1
                    pb = p_str

                    turn_scores['exact'].append(1)
                scores[turn_id]['exact'] += exact_score
                scores[hardness]['exact'] += exact_score
                scores['all']['exact'] += exact_score
                # Accumulate per-component (partial) precision/recall/F1.
                for type_ in partial_types:

                    if partial_scores[type_]['pred_total'] > 0:
                        scores[hardness]['partial'][type_][
                            'acc'] += partial_scores[type_]['acc']
                        scores[hardness]['partial'][type_]['acc_count'] += 1
                    if partial_scores[type_]['label_total'] > 0:
                        scores[hardness]['partial'][type_][
                            'rec'] += partial_scores[type_]['rec']
                        scores[hardness]['partial'][type_]['rec_count'] += 1
                    scores[hardness]['partial'][type_]['f1'] += partial_scores[
                        type_]['f1']
                    if partial_scores[type_]['pred_total'] > 0:
                        scores['all']['partial'][type_][
                            'acc'] += partial_scores[type_]['acc']
                        scores['all']['partial'][type_]['acc_count'] += 1
                    if partial_scores[type_]['label_total'] > 0:
                        scores['all']['partial'][type_][
                            'rec'] += partial_scores[type_]['rec']
                        scores['all']['partial'][type_]['rec_count'] += 1
                    scores['all']['partial'][type_]['f1'] += partial_scores[
                        type_]['f1']

                entries.append({
                    'predictSQL': p_str,
                    'goldSQL': g_str,
                    'hardness': hardness,
                    'exact': exact_score,
                    'partial': partial_scores
                })

        # A session is jointly correct only if every turn is correct.
        if all(v == 1 for v in turn_scores["exec"]):
            scores['joint_all']['exec'] += 1

        if all(v == 1 for v in turn_scores["exact"]):
            scores['joint_all']['exact'] += 1

    # Normalize accumulated sums into per-turn accuracies.
    for turn in turns:
        if scores[turn]['count'] == 0:
            continue
        if etype in ["all", "exec"]:
            scores[turn]['exec'] /= scores[turn]['count']

        if etype in ["all", "match"]:
            scores[turn]['exact'] /= scores[turn]['count']

    # Normalize per-level accuracies and derive partial P/R/F1.
    for level in levels:
        if scores[level]['count'] == 0:
            continue
        if etype in ["all", "exec"]:
            scores[level]['exec'] /= scores[level]['count']

        if etype in ["all", "match"]:
            scores[level]['exact'] /= scores[level]['count']
            for type_ in partial_types:

                if scores[level]['partial'][type_]['acc_count'] == 0:
                    scores[level]['partial'][type_]['acc'] = 0
                else:
                    scores[level]['partial'][type_]['acc'] = scores[level]['partial'][type_]['acc'] / \
                                                             scores[level]['partial'][type_]['acc_count'] * 1.0
                if scores[level]['partial'][type_]['rec_count'] == 0:
                    scores[level]['partial'][type_]['rec'] = 0
                else:
                    scores[level]['partial'][type_]['rec'] = scores[level]['partial'][type_]['rec'] / \
                                                             scores[level]['partial'][type_]['rec_count'] * 1.0
                # F1 defined as 1 when both components are 0 (nothing to
                # predict and nothing predicted for this component).
                if scores[level]['partial'][type_]['acc'] == 0 and scores[
                        level]['partial'][type_]['rec'] == 0:
                    scores[level]['partial'][type_]['f1'] = 1
                else:
                    scores[level]['partial'][type_]['f1'] = \
                        2.0 * scores[level]['partial'][type_]['acc'] * scores[level]['partial'][type_]['rec'] / (
                        scores[level]['partial'][type_]['rec'] + scores[level]['partial'][type_]['acc'])
    print_scores(scores, etype)