예제 #1
0
def _parse_column_name_list_from_query(query):
    """Extract result column names and variables from a SQL script.

    The script is parsed with ``MysqlSqlFacade.parseAstFromSqlScript``;
    entries of the resulting list that are plain strings (parse errors) are
    skipped. The first statement whose AST has a ``select_item_list`` is
    processed:

    * every ``select_item`` contributes one column name: its alias if it has
      one, otherwise the referenced identifier (for a qualified ``tbl.col``
      reference, the column part), otherwise ``"field<N>"`` where N is the
      1-based position of the item in the select list;
    * every ``variable`` node of the statement is replaced by ``?``.

    Returns a ``(statement_text, columns, variables)`` tuple: the text of that
    statement with variables replaced, the column names, and the original
    variable texts in the order they were replaced back (source order,
    assuming find_child_nodes yields nodes in source order). Returns None
    when no statement has a select list.
    """
    from grt.modules import MysqlSqlFacade

    for ast in MysqlSqlFacade.parseAstFromSqlScript(query):
        if type(ast) is str:
            # parse error message, not a statement node
            continue
        select_item_list = find_child_node(trim_ast(ast), "select_item_list")
        if not select_item_list:
            continue

        columns = []
        index = 0
        for node in node_children(select_item_list):
            if node_symbol(node) != "select_item":
                continue
            index += 1
            alias = find_child_node(find_child_node(node, "select_alias"), "ident")
            if alias:
                name = node_value(alias)
            else:
                ident = find_child_node(node, "simple_ident_q")
                if ident and len(node_children(ident)) == 3:
                    # qualified "tbl.col": name the column after its last
                    # part, not after the table
                    ident = node_children(ident)[-1]
                else:
                    ident = find_child_node(find_child_node(node, "expr"), "ident")
                name = node_value(ident) if ident else "field%i" % index
            columns.append(name)

        helper = ASTHelper(query)
        stmt_begin, stmt_end = helper.get_ast_range(ast)
        statement = query[stmt_begin:stmt_end]

        # Ranges from get_ast_range are offsets into the whole script, so they
        # are shifted by stmt_begin. Replacing back-to-front keeps the offsets
        # of the not-yet-processed (earlier) variables valid while the text
        # changes length.
        variables = []
        for var in reversed(find_child_nodes(ast, "variable")):
            begin, end = helper.get_ast_range(var)
            begin -= stmt_begin
            end -= stmt_begin
            variables.append(statement[begin:end])
            statement = statement[:begin] + "?" + statement[end:]
        variables.reverse()

        return statement, columns, variables
예제 #2
0
def _parse_column_name_list_from_query(query):
    """Extract result column names and variables from a SQL script.

    The script is parsed with ``MysqlSqlFacade.parseAstFromSqlScript``;
    entries of the resulting list that are plain strings (parse errors) are
    skipped. The first statement whose AST has a ``select_item_list`` is
    processed:

    * every ``select_item`` contributes one column name: its alias if it has
      one, otherwise the referenced identifier (for a qualified ``tbl.col``
      reference, the column part), otherwise ``"field<N>"`` where N is the
      1-based position of the item in the select list;
    * every ``variable`` node of the statement is replaced by ``?``.

    Returns a ``(statement_text, columns, variables)`` tuple: the text of that
    statement with variables replaced, the column names, and the original
    variable texts in the order they were replaced back (source order,
    assuming find_child_nodes yields nodes in source order). Returns None
    when no statement has a select list.
    """
    from grt.modules import MysqlSqlFacade

    for ast in MysqlSqlFacade.parseAstFromSqlScript(query):
        if type(ast) is str:
            # parse error message, not a statement node
            continue
        select_item_list = find_child_node(trim_ast(ast), "select_item_list")
        if not select_item_list:
            continue

        columns = []
        index = 0
        for node in node_children(select_item_list):
            if node_symbol(node) != "select_item":
                continue
            index += 1
            alias = find_child_node(find_child_node(node, "select_alias"), "ident")
            if alias:
                name = node_value(alias)
            else:
                ident = find_child_node(node, "simple_ident_q")
                if ident and len(node_children(ident)) == 3:
                    # qualified "tbl.col": name the column after its last
                    # part, not after the table
                    ident = node_children(ident)[-1]
                else:
                    ident = find_child_node(find_child_node(node, "expr"), "ident")
                name = node_value(ident) if ident else "field%i" % index
            columns.append(name)

        helper = ASTHelper(query)
        stmt_begin, stmt_end = helper.get_ast_range(ast)
        statement = query[stmt_begin:stmt_end]

        # Ranges from get_ast_range are offsets into the whole script, so they
        # are shifted by stmt_begin. Replacing back-to-front keeps the offsets
        # of the not-yet-processed (earlier) variables valid while the text
        # changes length.
        variables = []
        for var in reversed(find_child_nodes(ast, "variable")):
            begin, end = helper.get_ast_range(var)
            begin -= stmt_begin
            end -= stmt_begin
            variables.append(statement[begin:end])
            statement = statement[:begin] + "?" + statement[end:]
        variables.reverse()

        return statement, columns, variables
예제 #3
0
def doReformatSQLStatement(text, return_none_if_unsupported):
    """Reformat a single SQL statement.

    text: SQL text that must parse into exactly one statement.
    return_none_if_unsupported: what to return when no formatter exists for
        the statement type -- None if true, otherwise ``text`` unchanged.

    Returns the formatter's output when one is available for the statement.
    Backquotes around identifiers in ``text`` are preserved in the output.

    Raises Exception if ``text`` does not parse into exactly one statement or
    the parser reports an error.
    """
    from grt.modules import MysqlSqlFacade

    ast_list = MysqlSqlFacade.parseAstFromSqlScript(text)
    if len(ast_list) != 1:
        raise Exception("Error parsing statement")
    ast = ast_list[0]
    if type(ast) is str:
        # the parser reports errors as message strings
        raise Exception("Error parsing statement: %s" % ast)

    def trim_ast_fix_bq(node):
        # Reduce an AST node to (symbol, value, children); the parser strips
        # backquotes from identifier values, so re-add them (escaping any
        # embedded backquote) when the source text had them.
        symbol, value, children = node[0], node[1], node[2]
        if symbol in ("ident", "ident_or_text"):
            begin = node[3] + node[4]
            end = node[3] + node[5]
            # bounds-check both neighbours before indexing text
            if (0 < begin and end < len(text)
                    and text[begin - 1] == '`' and text[end] == '`'):
                value = "`%s`" % value.replace("`", "``")
        return (symbol, value, [trim_ast_fix_bq(child) for child in children])

    formatter = formatter_for_statement_ast(ast)
    if formatter:
        return formatter(trim_ast_fix_bq(ast)).run()
    if return_none_if_unsupported:
        return None
    return text
예제 #4
0
def _parse_column_name_list_from_query(query):
    """Extract result column names and variables from a SQL script.

    The script is parsed with ``MysqlSqlFacade.parseAstFromSqlScript``;
    entries of the resulting list that are plain strings (parse errors) are
    skipped. The first statement whose AST has a ``select_item_list`` is
    processed:

    * every ``select_item`` contributes one column name: its alias if it has
      one, otherwise the referenced identifier (for a qualified ``tbl.col``
      reference, the column part), otherwise the leading word of the item's
      flattened text (e.g. ``COUNT`` for ``COUNT(*)``), or ``"field"`` if
      there is none;
    * duplicate names are made unique: the first occurrence keeps its name,
      later ones get the next numeric suffix not used by any other column
      (``a``, ``a1``, ``a2``, ...);
    * every ``variable`` node of the statement is replaced by ``?``.

    Returns a ``(statement_text, columns, variables)`` tuple: the text of that
    statement with variables replaced, the column names, and the original
    variable texts in the order they were replaced back (source order,
    assuming find_child_nodes yields nodes in source order). Returns None
    when no statement has a select list.
    """
    import re
    from grt.modules import MysqlSqlFacade

    for ast in MysqlSqlFacade.parseAstFromSqlScript(query):
        if type(ast) is str:
            # parse error message, not a statement node
            continue
        select_item_list = find_child_node(trim_ast(ast), "select_item_list")
        if not select_item_list:
            continue

        columns = []
        for node in node_children(select_item_list):
            if node_symbol(node) != "select_item":
                continue
            alias = find_child_node(find_child_node(node, "select_alias"), "ident")
            if alias:
                columns.append(node_value(alias))
                continue
            ident = find_child_node(node, "simple_ident_q")
            if ident and len(node_children(ident)) == 3:
                # qualified "tbl.col": name the column after its last part
                ident = node_children(ident)[-1]
            else:
                ident = find_child_node(find_child_node(node, "expr"), "ident")
            if ident:
                name = node_value(ident)
            else:
                name = "field"
                field = flatten_node(node)
                if field:
                    # "+" (not "*") so an expression starting with a
                    # non-word character keeps "field" instead of ""
                    m = re.match(r"[a-zA-Z0-9_]+", field)
                    if m:
                        name = m.group(0)
            columns.append(name)

        helper = ASTHelper(query)
        stmt_begin, stmt_end = helper.get_ast_range(ast)
        statement = query[stmt_begin:stmt_end]

        # Ranges from get_ast_range are offsets into the whole script, so they
        # are shifted by stmt_begin. Replacing back-to-front keeps the offsets
        # of the not-yet-processed (earlier) variables valid while the text
        # changes length.
        variables = []
        for var in reversed(find_child_nodes(ast, "variable")):
            begin, end = helper.get_ast_range(var)
            begin -= stmt_begin
            end -= stmt_begin
            variables.append(statement[begin:end])
            statement = statement[:begin] + "?" + statement[end:]
        variables.reverse()

        # Make duplicate names unique. The counter is advanced exactly once
        # per renamed duplicate, and suffixed names that already exist among
        # the columns are skipped so the result has no collisions.
        taken = set(columns)
        next_suffix = {}
        for i, name in enumerate(columns):
            if name not in next_suffix:
                next_suffix[name] = 1
                continue
            n = next_suffix[name]
            candidate = "%s%i" % (name, n)
            while candidate in taken:
                n += 1
                candidate = "%s%i" % (name, n)
            next_suffix[name] = n + 1
            taken.add(candidate)
            columns[i] = candidate

        return statement, columns, variables
예제 #5
0
def _parse_column_name_list_from_query(query):
    """Extract result column names and variables from a SQL script.

    The script is parsed with ``MysqlSqlFacade.parseAstFromSqlScript``;
    entries of the resulting list that are plain strings (parse errors) are
    skipped. The first statement whose AST has a ``select_item_list`` is
    processed:

    * every ``select_item`` contributes one column name: its alias if it has
      one, otherwise the referenced identifier (for a qualified ``tbl.col``
      reference, the column part), otherwise the leading word of the item's
      flattened text (e.g. ``COUNT`` for ``COUNT(*)``), or ``"field"`` if
      there is none;
    * duplicate names are made unique: the first occurrence keeps its name,
      later ones get the next numeric suffix not used by any other column
      (``a``, ``a1``, ``a2``, ...);
    * every ``variable`` node of the statement is replaced by ``?``.

    Returns a ``(statement_text, columns, variables)`` tuple: the text of that
    statement with variables replaced, the column names, and the original
    variable texts in the order they were replaced back (source order,
    assuming find_child_nodes yields nodes in source order). Returns None
    when no statement has a select list.
    """
    import re
    from grt.modules import MysqlSqlFacade

    for ast in MysqlSqlFacade.parseAstFromSqlScript(query):
        if type(ast) is str:
            # parse error message, not a statement node
            continue
        select_item_list = find_child_node(trim_ast(ast), "select_item_list")
        if not select_item_list:
            continue

        columns = []
        for node in node_children(select_item_list):
            if node_symbol(node) != "select_item":
                continue
            alias = find_child_node(find_child_node(node, "select_alias"), "ident")
            if alias:
                columns.append(node_value(alias))
                continue
            ident = find_child_node(node, "simple_ident_q")
            if ident and len(node_children(ident)) == 3:
                # qualified "tbl.col": name the column after its last part
                ident = node_children(ident)[-1]
            else:
                ident = find_child_node(find_child_node(node, "expr"), "ident")
            if ident:
                name = node_value(ident)
            else:
                name = "field"
                field = flatten_node(node)
                if field:
                    # "+" (not "*") so an expression starting with a
                    # non-word character keeps "field" instead of ""
                    m = re.match(r"[a-zA-Z0-9_]+", field)
                    if m:
                        name = m.group(0)
            columns.append(name)

        helper = ASTHelper(query)
        stmt_begin, stmt_end = helper.get_ast_range(ast)
        statement = query[stmt_begin:stmt_end]

        # Ranges from get_ast_range are offsets into the whole script, so they
        # are shifted by stmt_begin. Replacing back-to-front keeps the offsets
        # of the not-yet-processed (earlier) variables valid while the text
        # changes length.
        variables = []
        for var in reversed(find_child_nodes(ast, "variable")):
            begin, end = helper.get_ast_range(var)
            begin -= stmt_begin
            end -= stmt_begin
            variables.append(statement[begin:end])
            statement = statement[:begin] + "?" + statement[end:]
        variables.reverse()

        # Make duplicate names unique. The counter is advanced exactly once
        # per renamed duplicate, and suffixed names that already exist among
        # the columns are skipped so the result has no collisions.
        taken = set(columns)
        next_suffix = {}
        for i, name in enumerate(columns):
            if name not in next_suffix:
                next_suffix[name] = 1
                continue
            n = next_suffix[name]
            candidate = "%s%i" % (name, n)
            while candidate in taken:
                n += 1
                candidate = "%s%i" % (name, n)
            next_suffix[name] = n + 1
            taken.add(candidate)
            columns[i] = candidate

        return statement, columns, variables
예제 #6
0
def enbeautificate_old(editor):
    """Reformat the SQL in ``editor`` in place.

    Works on the current selection, or on the whole script when nothing is
    selected. Each statement for which formatter_for_statement_ast returns a
    formatter is replaced by that formatter's output (with identifier
    backquotes preserved); other statements are copied through unchanged.
    A summary is shown in the application status bar.

    Returns 0 on success, or 1 if the parser reported an error, in which case
    the editor contents are left untouched.
    """
    from grt.modules import MysqlSqlFacade

    text = editor.selectedText
    selection_only = bool(text)
    if not selection_only:
        text = editor.script

    def trim_ast_fix_bq(node):
        # Reduce an AST node to (symbol, value, children); the parser strips
        # backquotes from identifier values, so re-add them (escaping any
        # embedded backquote) when the source text had them.
        symbol, value, children = node[0], node[1], node[2]
        if symbol in ("ident", "ident_or_text"):
            begin = node[3] + node[4]
            end = node[3] + node[5]
            if (begin > 0 and end < len(text)
                    and text[begin - 1] == '`' and text[end] == '`'):
                value = "`%s`" % value.replace("`", "``")
        return (symbol, value, [trim_ast_fix_bq(child) for child in children])

    helper = ASTHelper(text)
    ok_count = 0
    bad_count = 0
    curpos = 0
    new_text = ""
    for ast in MysqlSqlFacade.parseAstFromSqlScript(text):
        if type(ast) is str:
            # the parser reports errors as message strings
            print(ast)
            mforms.App.get().set_status_text("Cannot format invalid SQL: %s"%ast)
            return 1

        begin, end = helper.get_ast_range(ast)
        # keep the text between statements, minus spaces right before this one
        new_text += text[curpos:begin].rstrip(" ")

        # The token range does not include the quotation char if the last
        # token is quoted, so extend the statement's range by one to cover it.
        if end < len(text) and text[end] in ("'", '"', '`'):
            stmt_end = end + 1
        else:
            stmt_end = end
        curpos = stmt_end

        formatter = formatter_for_statement_ast(ast)
        if formatter:
            ok_count += 1
            new_text += formatter(trim_ast_fix_bq(ast)).run()
        else:
            bad_count += 1
            # copy up to stmt_end (not end) so a closing quote skipped by
            # curpos is not lost from the output
            new_text += text[begin:stmt_end]
    new_text += text[curpos:]

    if selection_only:
        editor.replaceSelection(new_text)
    else:
        editor.replaceContents(new_text)

    if bad_count > 0:
        mforms.App.get().set_status_text("Formatted %i statements, %i unsupported statement types skipped."%(ok_count, bad_count))
    else:
        mforms.App.get().set_status_text("Formatted %i statements."%ok_count)

    return 0
예제 #7
0
파일: sqlide_grt.py 프로젝트: Arrjaan/Cliff
def enbeautificate(editor):
    """Reformat the SQL in ``editor`` in place.

    Works on the current selection, or on the whole script when nothing is
    selected. Each statement for which formatter_for_statement_ast returns a
    formatter is replaced by that formatter's output (with identifier
    backquotes preserved); other statements are copied through unchanged.
    A summary is shown in the application status bar.

    Returns 0 on success, or 1 if the parser reported an error, in which case
    the editor contents are left untouched.
    """
    from grt.modules import MysqlSqlFacade

    text = editor.selectedText
    selection_only = bool(text)
    if not selection_only:
        text = editor.script

    def trim_ast_fix_bq(node):
        # Reduce an AST node to (symbol, value, children). The parser strips
        # backquotes from identifier values; re-add them (escaping any
        # embedded backquote) when the source text had them, otherwise
        # quoted names such as `my col` would be emitted unquoted.
        symbol, value, children = node[0], node[1], node[2]
        if symbol in ("ident", "ident_or_text"):
            begin = node[3] + node[4]
            end = node[3] + node[5]
            if (begin > 0 and end < len(text)
                    and text[begin - 1] == '`' and text[end] == '`'):
                value = "`%s`" % value.replace("`", "``")
        return (symbol, value, [trim_ast_fix_bq(child) for child in children])

    helper = ASTHelper(text)
    ok_count = 0
    bad_count = 0
    curpos = 0
    new_text = ""
    for ast in MysqlSqlFacade.parseAstFromSqlScript(text):
        if type(ast) is str:
            # the parser reports errors as message strings
            print(ast)
            mforms.App.get().set_status_text("Cannot format invalid SQL: %s"%ast)
            return 1

        begin, end = helper.get_ast_range(ast)
        # keep the text between statements, minus spaces right before this one
        new_text += text[curpos:begin].rstrip(" ")

        # The token range does not include the quotation char if the last
        # token is quoted, so extend the statement's range by one to cover it.
        if end < len(text) and text[end] in ("'", '"', '`'):
            stmt_end = end + 1
        else:
            stmt_end = end
        curpos = stmt_end

        formatter = formatter_for_statement_ast(ast)
        if formatter:
            ok_count += 1
            new_text += formatter(trim_ast_fix_bq(ast)).run()
        else:
            bad_count += 1
            # copy up to stmt_end (not end) so a closing quote skipped by
            # curpos is not lost from the output
            new_text += text[begin:stmt_end]
    new_text += text[curpos:]

    if selection_only:
        editor.replaceSelection(new_text)
    else:
        editor.replaceContents(new_text)

    if bad_count > 0:
        mforms.App.get().set_status_text("Formatted %i statements, %i unsupported statement types skipped."%(ok_count, bad_count))
    else:
        mforms.App.get().set_status_text("Formatted %i statements."%ok_count)

    return 0