def test_execution_order():
    """Round-trip test for execution-order formatting.

    For a few sample (database, query) pairs: parse the SQL, format it in
    execution order, re-parse the execution-order string, and assert the
    resulting AST is JSON-identical to the original.

    Expects the Spider data directory as the second command-line argument
    (``sys.argv[2]``); the original argv[1]/argv[3] reads were dead code
    (immediately overwritten) and have been removed.
    """
    data_dir = sys.argv[2]
    schema_graphs = load_schema_graphs_spider(data_dir, 'spider')
    in_sqls = complex_queries[:4]
    db_names = ['flight_4', 'academic', 'baseball_1', 'voter_2']
    for db_name, in_sql in zip(db_names, in_sqls):
        schema = schema_graphs[db_name]
        ast = parse(in_sql)
        # Keep a pristine copy for the comparison below in case format()
        # mutates the AST it is given.
        ast_c = copy.deepcopy(ast)
        eo_sql = format(ast, schema, in_execution_order=True)
        # Smoke-test the execution-order tokenizer path as well.
        eo_tokens = tokenize(in_sql,
                             bu.tokenizer.tokenize,
                             schema=schema,
                             in_execution_order=True)
        print('in_sql: {}'.format(in_sql))
        print('eo_sql: {}'.format(eo_sql))
        eo_ast = eo_parse(eo_sql)
        # The execution-order round trip must preserve the parse tree.
        assert (json.dumps(ast_c, sort_keys=True)
                == json.dumps(eo_ast, sort_keys=True))
        # Formatting back to canonical order should also succeed.
        restored_sql = format(eo_ast, schema)
        print()
def test_no_join_tokenizer():
    """Tokenize a query with join conditions stripped and check that the
    resulting string still parses in execution order.

    Fixes: removed a dead ``sql`` assignment that was immediately
    overwritten, the redundant ``if True:`` wrapper, and a leftover
    ``pdb.set_trace()`` that unconditionally dropped into the debugger.
    """
    sql = 'SELECT T1.Name FROM Tourist_Attractions AS T1 JOIN VISITORS AS T2 JOIN VISITS AS T3 ON T1.Tourist_Attraction_ID  =  T3.Tourist_Attraction_ID AND T2.Tourist_ID  =  T3.Tourist_ID WHERE T2.Tourist_Details  =  "Vincent" INTERSECT SELECT T1.Name FROM Tourist_Attractions AS T1 JOIN VISITORS AS T2 JOIN VISITS AS T3 ON T1.Tourist_Attraction_ID  =  T3.Tourist_Attraction_ID AND T2.Tourist_ID  =  T3.Tourist_ID WHERE T2.Tourist_Details  =  "Marcelle"'
    print(sql)
    print(tokenize(sql, bu.tokenizer.tokenize, in_execution_order=True)[0])
    tokens = tokenize(sql,
                      bu.tokenizer.tokenize,
                      no_join_condition=True,
                      in_execution_order=True)[0]
    # Rebuild a SQL string from the join-condition-free tokens and make
    # sure it is still parseable in execution order.
    sql_njc = bu.tokenizer.convert_tokens_to_string(tokens)
    print(tokens)
    print(sql_njc)
    ast_njc = eo_parse(sql_njc)
    print(json.dumps(ast_njc, indent=4))
    print()
def test_atomic_tokenizer():
    """Print the atomic-value tokenization of each sample query.

    With ``atomic_value=True`` the tokenizer replaces literal values with
    placeholder tokens and returns the extracted constants separately.

    Fix: removed a leftover ``pdb.set_trace()`` that halted execution on
    every loop iteration.
    """
    for sql in complex_queries:
        tokens, token_types, constants = tokenize(
            sql,
            bu.tokenizer.tokenize,
            atomic_value=True,
            num_token=functional_token_index['num_token'],
            str_token=functional_token_index['str_token'])
        print(sql)
        print(tokens)
        print(token_types)
        for constant in constants:
            print(constant)
        print()
# Example #4
def sql_tokenize(sql, value_tokenize, return_token_types=False, **kwargs):
    """Tokenize a SQL query given either as a raw string or a parsed AST.

    :param sql: SQL string or a pre-parsed AST (any non-string is treated
        as an AST and passed through unchanged).
    :param value_tokenize: tokenizer applied to literal values.
    :param return_token_types: if True, return the full tokenizer output
        (tokens plus token types); otherwise return only the token list.
    :param kwargs: forwarded to ``moz_sp.tokenize``.
    """
    if not isinstance(sql, string_types):
        # Caller already supplied a parse tree.
        ast = sql
    else:
        sql = standardise_blank_spaces(sql)
        try:
            ast = moz_sp.parse(sql)
        except Exception:
            # Unparseable SQL: fall back to tokenizing the raw string.
            # NOTE(review): this path returns only a token list even when
            # return_token_types=True — confirm callers tolerate that.
            return value_tokenize(sql)

    output = moz_sp.tokenize(ast, value_tokenize, parsed=True, **kwargs)
    return output if return_token_types else output[0]
def test_tokenizer():
    """Tokenize a sample query in execution order against a Spider schema
    and print the result.

    Expects the Spider data directory as ``sys.argv[1]`` and a database
    name as ``sys.argv[2]``.

    Fixes: removed two dead ``sql`` assignments that were immediately
    overwritten and the redundant ``if True:`` wrapper.
    """
    sql = "SELECT Perpetrator_ID FROM perpetrator WHERE Year IN ('1995.0', '1994.0', '1982.0')"
    print(sql)
    data_dir = sys.argv[1]
    db_name = sys.argv[2]
    schema_graphs = load_schema_graphs_spider(data_dir, 'spider')
    schema = schema_graphs[db_name]
    tokens = tokenize(sql,
                      bu.tokenizer.tokenize,
                      in_execution_order=True,
                      schema=schema)[0]
    print(tokens)
    print()
def denormalizer_unit_test(sql_query, schema, ast=None, idx=None):
    """Denormalize *sql_query* against *schema*, tokenize the result, and
    dump diagnostics (then drop into pdb) when DEBUG is on.

    :param sql_query: SQL string or a pre-parsed AST (dict).
    :param schema: schema graph used for denormalization/tokenization.
    :param ast: optional pre-computed parse tree; parsed from *sql_query*
        when omitted.
    :param idx: unused; kept for caller compatibility.
    """
    if ast is None:
        # A dict argument is already a parse tree; anything else is SQL text.
        ast = sql_query if isinstance(sql_query, dict) else parse(sql_query)
    dn_sql_query, contains_self_join = denormalize(
        ast, schema, return_parse_tree=True)
    dn_sql_tokens = tokenize(dn_sql_query,
                             value_tokenize=bu.tokenizer.tokenize,
                             schema=schema,
                             parsed=True,
                             keep_singleton_fields=True)
    if DEBUG:
        print(sql_query)
        print(json.dumps(ast, indent=4))
        print(list(zip(*dn_sql_tokens)))
        print()
        import pdb
        pdb.set_trace()