コード例 #1
0
def transform_table_identifiers(sql, callable):
    """Rewrite the identifier spans of ``sql`` through ``callable``.

    The script is parsed with ``MysqlSqlFacade.parseAstFromSqlScript``.  The
    character ranges to rewrite are collected by ``get_table_ident_offsets``
    (presumably the table identifiers, going by its name).

    :param sql: SQL script text.
    :param callable: function called with each collected substring
        ``sql[b:e]``; its return value replaces that substring.
    :returns: the rewritten script, or ``1`` if any statement failed to
        parse (the parser error is logged and shown in the status bar).
    """
    from grt.modules import MysqlSqlFacade

    parsed = MysqlSqlFacade.parseAstFromSqlScript(sql)

    # A plain string in the parse result is an error message, not a tree.
    # Validate everything up front so nothing is rewritten on failure.
    for entry in parsed:
        if type(entry) is str:
            log_info("copyAsDrupalQuery", entry)  # error
            message = "Cannot format invalid SQL: %s" % entry
            mforms.App.get().set_status_text(message)
            return 1

    # (begin, end) pairs, appended in place by get_table_ident_offsets.
    spans = []
    for entry in parsed:
        get_table_ident_offsets(spans, entry)

    # Copy the untouched text between spans and the transformed spans.
    pieces = []
    cursor = 0
    for start, stop in spans:
        pieces.append(sql[cursor:start])
        pieces.append(callable(sql[start:stop]))
        cursor = stop
    pieces.append(sql[cursor:])
    return "".join(pieces)
コード例 #2
0
def transform_table_identifiers(sql, callable):
    """Return ``sql`` with selected identifier ranges passed through ``callable``.

    Ranges come from ``get_table_ident_offsets``, which is run on every AST
    returned by ``MysqlSqlFacade.parseAstFromSqlScript``.

    :param sql: SQL script text.
    :param callable: function applied to each substring ``sql[begin:end]``;
        the result is spliced in place of the original substring.
    :returns: the new script text, or ``1`` when the script contains a
        statement that could not be parsed.
    """
    from grt.modules import MysqlSqlFacade

    statements = MysqlSqlFacade.parseAstFromSqlScript(sql)

    # Parse errors are reported as plain strings; bail out on the first one.
    failures = [item for item in statements if type(item) is str]
    if failures:
        first_error = failures[0]
        log_info("copyAsDrupalQuery", first_error)  # error
        mforms.App.get().set_status_text("Cannot format invalid SQL: %s" % first_error)
        return 1

    offsets = []
    for tree in statements:
        get_table_ident_offsets(offsets, tree)

    result = ""
    previous_end = 0
    for begin, end in offsets:
        result = result + sql[previous_end:begin] + callable(sql[begin:end])
        previous_end = end

    # Whatever follows the last rewritten range is copied verbatim.
    return result + sql[previous_end:]
コード例 #3
0
    def reverseEngineer(cls, connection, catalog_name, schemata_list, context):
        """Reverse engineer the objects listed in ``sqlite_master`` into a catalog.

        The catalog itself comes from ``cls.reverseEngineerCatalog``.  Each
        remaining ``sqlite_master`` row that carries SQL is then handed to
        ``MysqlSqlFacade.parseSqlScriptString`` together with that catalog.

        :param connection: connection object passed to the query helpers.
        :param catalog_name: name of the catalog to reverse engineer.
        :param schemata_list: not used by this implementation.
        :param context: option mapping; only ``"reverseEngineerTables"``
            (default ``True``) is read here.
        :returns: the catalog returned by ``cls.reverseEngineerCatalog``.
        """
        from grt.modules import MysqlSqlFacade

        grt.send_progress(0, "Reverse engineering catalog information")
        cls.check_interruption()
        catalog = cls.reverseEngineerCatalog(connection, catalog_name)

        # calculate total workload 1st
        grt.send_progress(0.1, 'Preparing...')

        get_tables = context.get("reverseEngineerTables", True)

        # 10% of the progress is for preparation.  The epsilon keeps the
        # division below safe when there are no tables at all.
        total = 1e-10
        if get_tables:
            total += len(cls.getTableNames(connection, catalog_name, ''))

        grt.send_progress(0.1, "Gathered stats")

        # Now the second pass for reverse engineering tables:
        if get_tables:
            processed = 0
            rows = cls.execute_query(connection, "SELECT * FROM sqlite_master")
            for object_type, name, tbl_name, _, sql in rows:
                # Skip views and triggers, rows without SQL, and SQLite's
                # internal objects.  This must test the row's object type;
                # testing the builtin ``type`` never matches anything.
                if (object_type in ('view', 'trigger') or not sql
                        or tbl_name.startswith('sqlite_')):
                    continue

                # Drop square brackets before parsing (presumably bracket
                # quoted identifiers the MySQL parser rejects -- TODO confirm).
                sql = sql.replace('[', '').replace(']', '')

                grt.log_debug('SQLiteReverseEngineering',
                              'Procesing this sql:\n%s;' % sql)

                MysqlSqlFacade.parseSqlScriptString(catalog, sql)

                cls.check_interruption()
                processed += 1
                # Spread the remaining 90% over the objects and clamp: more
                # rows than table names may be processed (e.g. indexes).
                progress = min(1.0, 0.1 + 0.9 * processed / total)
                grt.send_progress(progress,
                                  'Object %s reverse engineered!' % name)

        grt.send_progress(1.0, 'Reverse engineering completed!')
        return catalog
コード例 #4
0
def apply_to_keywords(editor, callable):
    """Run ``callable`` over every keyword token of the editor's SQL.

    Works on the current selection if there is one, otherwise on the whole
    script.  A token counts as a keyword when its AST node has a value and
    its symbol is not listed in ``skipped_symbols``.

    :param editor: editor object providing ``selectedText``, ``script``,
        ``replaceSelection`` and ``replaceContents``.
    :param callable: function applied to each keyword substring; its return
        value replaces the keyword.
    :returns: ``0`` on success, ``1`` if the SQL could not be parsed (the
        editor is left untouched in that case).
    """
    from grt.modules import MysqlSqlFacade

    # Symbols whose tokens are left as they are.
    skipped_symbols = [
        "ident", "ident_or_text", "TEXT_STRING", "text_string",
        "TEXT_STRING_filesystem", "TEXT_STRING_literal", "TEXT_STRING_sys",
        "part_name"
    ]

    def collect_keyword_ranges(ranges, node):
        # Node layout: (symbol, value, children, base, begin, end); begin and
        # end are relative to base.
        symbol, value, children, base, begin, end = node
        if value and symbol not in skipped_symbols:
            ranges.append((begin + base, end + base))
        for child in children:
            collect_keyword_ranges(ranges, child)

    source = editor.selectedText
    selectionOnly = True
    if not source:
        selectionOnly = False
        source = editor.script

    pieces = []
    cursor = 0
    for tree in MysqlSqlFacade.parseAstFromSqlScript(source):
        if type(tree) is str:
            # A string entry is a parser error message.
            print(tree)
            mforms.App.get().set_status_text("Cannot format invalid SQL: %s" %
                                             tree)
            return 1

        ranges = []
        collect_keyword_ranges(ranges, tree)
        for start, stop in ranges:
            pieces.append(source[cursor:start])
            pieces.append(callable(source[start:stop]))
            cursor = stop
    pieces.append(source[cursor:])
    rewritten = "".join(pieces)

    if selectionOnly:
        editor.replaceSelection(rewritten)
    else:
        editor.replaceContents(rewritten)

    mforms.App.get().set_status_text("SQL code reformatted.")
    return 0
コード例 #5
0
def apply_to_keywords(editor, callable):
    """Replace each SQL keyword in the editor with ``callable(keyword)``.

    The selected text is processed when a selection exists; otherwise the
    full script is used and replaced.

    :param editor: editor object (``selectedText``, ``script``,
        ``replaceSelection``, ``replaceContents`` are used).
    :param callable: function mapping a keyword substring to its replacement.
    :returns: ``0`` after the editor text was replaced, ``1`` when parsing
        reported an error and nothing was changed.
    """
    from grt.modules import MysqlSqlFacade

    # Node symbols that are never handed to ``callable``.
    non_keywords = ["ident", "ident_or_text", "TEXT_STRING", "text_string", "TEXT_STRING_filesystem", "TEXT_STRING_literal", "TEXT_STRING_sys",
                    "part_name"]

    def keyword_spans(node):
        """Return absolute (begin, end) spans of keyword tokens under node."""
        symbol, value, children, base, begin, end = node
        found = []
        if value and symbol not in non_keywords:
            found.append((begin + base, end + base))
        for child in children:
            found.extend(keyword_spans(child))
        return found

    selected = editor.selectedText
    selectionOnly = bool(selected)
    text = selected if selectionOnly else editor.script

    new_text = ""
    consumed = 0
    for tree in MysqlSqlFacade.parseAstFromSqlScript(text):
        if type(tree) is str:
            # error
            print(tree)
            mforms.App.get().set_status_text("Cannot format invalid SQL: %s"%tree)
            return 1
        for begin, end in keyword_spans(tree):
            new_text += text[consumed:begin] + callable(text[begin:end])
            consumed = end
    # Tail after the last keyword.
    new_text += text[consumed:]

    replace = editor.replaceSelection if selectionOnly else editor.replaceContents
    replace(new_text)

    mforms.App.get().set_status_text("SQL code reformatted.")
    return 0
コード例 #6
0
    def reverseEngineer(cls, connection, catalog_name, schemata_list, context):
        """Build a catalog from the SQL stored in ``sqlite_master``.

        ``cls.reverseEngineerCatalog`` creates the catalog; the SQL text of
        each eligible ``sqlite_master`` row is then parsed into it through
        ``MysqlSqlFacade.parseSqlScriptString``.

        :param connection: connection object passed to the query helpers.
        :param catalog_name: name of the catalog to reverse engineer.
        :param schemata_list: not used by this implementation.
        :param context: option mapping; ``"reverseEngineerTables"`` (default
            ``True``) decides whether the object pass runs at all.
        :returns: the catalog returned by ``cls.reverseEngineerCatalog``.
        """
        from grt.modules import MysqlSqlFacade

        grt.send_progress(0, "Reverse engineering catalog information")
        cls.check_interruption()
        catalog = cls.reverseEngineerCatalog(connection, catalog_name)

        # calculate total workload 1st
        grt.send_progress(0.1, 'Preparing...')

        get_tables = context.get("reverseEngineerTables", True)

        # 10% of the progress is for preparation.
        # total should not be zero to avoid DivisionByZero exceptions
        total = 1e-10
        if get_tables:
            total += len(cls.getTableNames(connection, catalog_name, ''))

        grt.send_progress(0.1, "Gathered stats")

        # Now the second pass for reverse engineering tables:
        if get_tables:
            done = 0
            for object_type, name, tbl_name, _, sql in cls.execute_query(connection, "SELECT * FROM sqlite_master"):
                # The row's own type column must be checked here; the builtin
                # ``type`` is never equal to 'view' or 'trigger', which let
                # views and triggers slip through.
                is_skipped_kind = object_type in ('view', 'trigger')
                if is_skipped_kind or not sql or tbl_name.startswith('sqlite_'):
                    continue

                # Square brackets are removed before the SQL is parsed.
                sql = sql.replace('[', '').replace(']', '')

                grt.log_debug('SQLiteReverseEngineering', 'Procesing this sql:\n%s;' % sql)

                MysqlSqlFacade.parseSqlScriptString(catalog, sql)

                cls.check_interruption()
                done += 1
                # Map the object pass onto the remaining 90% and never report
                # more than 100%, even if more rows than tables are parsed.
                fraction = min(1.0, 0.1 + 0.9 * done / total)
                grt.send_progress(fraction, 'Object %s reverse engineered!' % name)

        grt.send_progress(1.0, 'Reverse engineering completed!')
        return catalog
コード例 #7
0
ファイル: code_utils_grt.py プロジェクト: Arrjaan/Cliff
def _parse_column_name_list_from_query(query):
    """Extract column names and variables from the first SELECT-like statement.

    Statements are examined in order; the first one whose trimmed AST has a
    ``select_item_list`` node decides the result.

    :param query: SQL text to analyse.
    :returns: ``(statement, columns, variables)`` where ``statement`` is the
        statement text with every ``variable`` node replaced by ``?``,
        ``columns`` holds one name per select item (alias, else identifier,
        else ``fieldN``) and ``variables`` the replaced texts in source
        order.  ``None`` if no statement qualifies.
    """
    from grt.modules import MysqlSqlFacade

    for tree in MysqlSqlFacade.parseAstFromSqlScript(query):
        if type(tree) is str:
            # Parse errors are reported as strings; ignore them.
            continue

        # Statement nodes are 6-tuples; the unpacking fails loudly otherwise.
        _symbol, _value, _children, _base, _begin, _end = tree

        select_items = find_child_node(trim_ast(tree), "select_item_list")
        if not select_items:
            continue

        column_names = []
        position = 0
        for item in node_children(select_items):
            if node_symbol(item) != "select_item":
                continue
            position += 1
            ident_node = find_child_node(find_child_node(item, "expr"), "ident")
            alias_node = find_child_node(find_child_node(item, "select_alias"), "ident")
            if alias_node:
                column_name = node_value(alias_node)
            elif ident_node:
                column_name = node_value(ident_node)
            else:
                # Neither alias nor plain identifier: synthesize a name.
                column_name = "field%i" % position
            column_names.append(column_name)

        ast_helper = ASTHelper(query)
        stmt_begin, stmt_end = ast_helper.get_ast_range(tree)
        statement = query[stmt_begin:stmt_end]

        # Walk the variables back to front so earlier offsets stay valid
        # while the text is being shortened.
        placeholders = []
        for var_node in reversed(find_child_nodes(tree, "variable")):
            var_begin, var_end = ast_helper.get_ast_range(var_node)
            var_begin -= stmt_begin
            var_end -= stmt_begin
            placeholders.insert(0, statement[var_begin:var_end])
            statement = statement[:var_begin] + "?" + statement[var_end:]

        return statement, column_names, placeholders

    return None
コード例 #8
0
ファイル: code_utils_grt.py プロジェクト: Arrjaan/Cliff
def _parse_column_name_list_from_query(query):
    """Return the parameterized statement, column names and variable names.

    Only the first parsed statement that contains a ``select_item_list``
    node is used.

    :param query: SQL text to analyse.
    :returns: tuple ``(sql, columns, variables)``: ``sql`` is the statement
        with each ``variable`` node replaced by ``?``, ``columns`` lists a
        name per select item and ``variables`` the original variable texts
        in order of appearance.  Falls through (returning ``None``) when no
        statement has a select item list.
    """
    from grt.modules import MysqlSqlFacade

    def pick_name(item, ordinal):
        """Name a select item: alias first, then identifier, then fieldN."""
        ident = find_child_node(find_child_node(item, "expr"), "ident")
        alias = find_child_node(find_child_node(item, "select_alias"), "ident")
        if alias:
            return node_value(alias)
        if ident:
            return node_value(ident)
        return "field%i"%ordinal

    for statement_ast in MysqlSqlFacade.parseAstFromSqlScript(query):
        if type(statement_ast) is str:
            continue  # parser error message, not a tree

        # Kept from the original: requires the node to have six fields.
        _s, _v, _c, _base, _begin, _end = statement_ast

        item_list = find_child_node(trim_ast(statement_ast), "select_item_list")
        if item_list:
            columns = []
            ordinal = 0
            for item in node_children(item_list):
                if node_symbol(item) == "select_item":
                    ordinal += 1
                    columns.append(pick_name(item, ordinal))

            helper = ASTHelper(query)
            offset, statement_end = helper.get_ast_range(statement_ast)
            sql = query[offset:statement_end]

            # Replace from the last variable to the first so the positions
            # of the ones still to be handled do not shift.
            variables = []
            for variable_node in reversed(find_child_nodes(statement_ast, "variable")):
                begin, end = helper.get_ast_range(variable_node)
                begin, end = begin - offset, end - offset
                variable_text = sql[begin:end]
                sql = sql[:begin] + "?" + sql[end:]
                variables.insert(0, variable_text)

            return sql, columns, variables
コード例 #9
0
def doReformatSQLStatement(text, return_none_if_unsupported):
    """Reformat a single SQL statement.

    :param text: text containing exactly one SQL statement.
    :param return_none_if_unsupported: when true, return ``None`` instead of
        the unchanged text if no formatter exists for the statement type.
    :returns: the formatter's output, or ``text`` / ``None`` when
        ``formatter_for_statement_ast`` yields no formatter.
    :raises Exception: if parsing does not produce exactly one statement or
        the parser reports an error.
    """
    from grt.modules import MysqlSqlFacade
    ast_list = MysqlSqlFacade.parseAstFromSqlScript(text)
    if len(ast_list) != 1:
        raise Exception("Error parsing statement")
    if type(ast_list[0]) is str:
        raise Exception("Error parsing statement: %s" % ast_list[0])

    ast = ast_list[0]

    def trim_ast_fix_bq(text, node, add_rollup):
        """Reduce node to (symbol, value, children), restoring backquotes."""
        symbol = node[0]
        value = node[1]
        children = node[2]
        # put back backquotes to identifiers, if there's any
        if symbol in ("ident", "ident_or_text"):
            begin = node[3] + node[4]
            end = node[3] + node[5]
            # ``end`` may equal len(text) when the identifier is the last
            # token; check the bound before looking at text[end].
            if (begin > 0 and end < len(text)
                    and text[begin - 1] == '`' and text[end] == '`'):
                value = "`%s`" % value.replace("`", "``")
        trimmed = []
        last_index = len(children) - 1
        for i, child in enumerate(children):
            trimmed.append(trim_ast_fix_bq(text, child, add_rollup))
            # Re-insert ROLLUP after a WITH that is not followed by one.
            if (add_rollup and child[0] == "olap_opt"
                    and child[1].upper() == "WITH"
                    and (i == last_index
                         or children[i + 1][1].upper() != "ROLLUP")):
                trimmed.append(("olap_opt", "ROLLUP", []))
        return (symbol, value, trimmed)

    formatter = formatter_for_statement_ast(ast)
    if formatter:
        # workaround a bug in parser where WITH ROLLUP is turned into WITH
        add_rollup = "WITH ROLLUP" in text.upper()

        p = formatter(trim_ast_fix_bq(text, ast, add_rollup))
        return p.run()

    if return_none_if_unsupported:
        return None
    return text
コード例 #10
0
def doReformatSQLStatement(text, return_none_if_unsupported):
    """Pretty-print one SQL statement using the formatter for its type.

    :param text: text holding exactly one SQL statement.
    :param return_none_if_unsupported: if true, ``None`` is returned for
        statement types without a formatter; otherwise ``text`` is returned
        unchanged.
    :returns: result of the formatter's ``run()``, or ``text`` / ``None``.
    :raises Exception: when the parse result is not exactly one statement,
        or that single entry is a parser error message.
    """
    from grt.modules import MysqlSqlFacade
    ast_list = MysqlSqlFacade.parseAstFromSqlScript(text)
    if len(ast_list) != 1:
        raise Exception("Error parsing statement")
    if type(ast_list[0]) is str:
        raise Exception("Error parsing statement: %s" % ast_list[0])

    ast = ast_list[0]

    def is_backquoted(text, begin, end):
        # Bounds are checked first: an identifier that ends the text has
        # end == len(text), and text[end] would raise IndexError.
        return (begin > 0 and end < len(text)
                and text[begin-1] == '`' and text[end] == '`')

    def trim_ast_fix_bq(text, node, add_rollup):
        s = node[0]
        v = node[1]
        c = node[2]
        # put back backquotes to identifiers, if there's any
        if s in ("ident", "ident_or_text"):
            if is_backquoted(text, node[3] + node[4], node[3] + node[5]):
                v = "`%s`" % v.replace("`", "``")
        l = []
        for i, nc in enumerate(c):
            l.append(trim_ast_fix_bq(text, nc, add_rollup))
            with_without_rollup = (
                nc[0] == "olap_opt" and nc[1].upper() == "WITH"
                and (i == len(c)-1 or c[i+1][1].upper() != "ROLLUP"))
            if add_rollup and with_without_rollup:
                l.append(("olap_opt", "ROLLUP", []))
        return (s, v, l)

    formatter = formatter_for_statement_ast(ast)
    if not formatter:
        return None if return_none_if_unsupported else text

    # workaround a bug in parser where WITH ROLLUP is turned into WITH
    add_rollup = "WITH ROLLUP" in text.upper()

    p = formatter(trim_ast_fix_bq(text, ast, add_rollup))
    return p.run()
コード例 #11
0
ファイル: sqlide_grt.py プロジェクト: juanga16/p111mil
def doReformatSQLStatement(text, return_none_if_unsupported):
    """Reformat a single SQL statement.

    :param text: text containing exactly one SQL statement.
    :param return_none_if_unsupported: when true, return ``None`` instead of
        the unchanged text if no formatter exists for the statement type.
    :returns: the formatter's output, or ``text`` / ``None`` when
        ``formatter_for_statement_ast`` yields no formatter.
    :raises Exception: if parsing does not produce exactly one statement or
        the parser reports an error.
    """
    from grt.modules import MysqlSqlFacade
    ast_list = MysqlSqlFacade.parseAstFromSqlScript(text)
    if len(ast_list) != 1:
        raise Exception("Error parsing statement")
    if type(ast_list[0]) is str:
        raise Exception("Error parsing statement: %s" % ast_list[0])

    ast = ast_list[0]

    def trim_ast_fix_bq(text, node):
        """Reduce node to (symbol, value, children), restoring backquotes."""
        s = node[0]
        v = node[1]
        c = node[2]
        # put back backquotes to identifiers, if there's any
        if s in ("ident", "ident_or_text"):
            begin = node[3] + node[4]
            end = node[3] + node[5]
            # An identifier that ends the text has end == len(text); guard
            # the index so text[end] cannot raise IndexError.
            if (begin > 0 and end < len(text)
                    and text[begin-1] == '`' and text[end] == '`'):
                v = "`%s`" % v.replace("`", "``")
        return (s, v, [trim_ast_fix_bq(text, child) for child in c])

    formatter = formatter_for_statement_ast(ast)
    if formatter:
        p = formatter(trim_ast_fix_bq(text, ast))
        return p.run()

    if return_none_if_unsupported:
        return None
    return text
コード例 #12
0
def doReformatSQLStatement(text, return_none_if_unsupported):
    """Pretty-print one SQL statement using the formatter for its type.

    :param text: text holding exactly one SQL statement.
    :param return_none_if_unsupported: if true, ``None`` is returned for
        statement types without a formatter; otherwise ``text`` is returned
        unchanged.
    :returns: result of the formatter's ``run()``, or ``text`` / ``None``.
    :raises Exception: when the parse result is not exactly one statement,
        or that single entry is a parser error message.
    """
    from grt.modules import MysqlSqlFacade
    ast_list = MysqlSqlFacade.parseAstFromSqlScript(text)
    if len(ast_list) != 1:
        raise Exception("Error parsing statement")
    if type(ast_list[0]) is str:
        raise Exception("Error parsing statement: %s" % ast_list[0])

    ast = ast_list[0]
    text_length = len(text)

    def trim_ast_fix_bq(text, node):
        symbol, value, children = node[0], node[1], node[2]
        # put back backquotes to identifiers, if there's any
        if symbol in ("ident", "ident_or_text"):
            begin = node[3] + node[4]
            end = node[3] + node[5]
            # Both neighbours must exist before they are inspected; without
            # the upper bound an identifier at the very end of the text
            # raised IndexError.
            if 0 < begin and end < text_length:
                if text[begin-1] == '`' and text[end] == '`':
                    value = "`%s`" % value.replace("`", "``")
        trimmed = []
        for child in children:
            trimmed.append(trim_ast_fix_bq(text, child))
        return (symbol, value, trimmed)

    formatter = formatter_for_statement_ast(ast)
    if not formatter:
        return None if return_none_if_unsupported else text

    p = formatter(trim_ast_fix_bq(text, ast))
    return p.run()
コード例 #13
0
ファイル: sqlide_grt.py プロジェクト: Vescatur/DatabaseTest
def enbeautificate(editor):
    """Reformat the selected SQL statements or the one under the cursor.

    The text is split with ``MysqlSqlFacade.getSqlStatementRanges``.  Each
    statement is stripped of surrounding whitespace and leading comments,
    passed to ``doReformatSQLStatement`` and re-assembled with what was
    stripped.  Statements that cannot be reformatted are kept verbatim.

    :param editor: editor object (``selectedText``, ``currentStatement``,
        ``replaceSelection`` and ``replaceCurrentStatement`` are used).
    :returns: always ``0``; the outcome is reported in the status bar.
    """

    from grt.modules import MysqlSqlFacade

    whitespace = " \t\r\n"

    text = editor.selectedText
    selectionOnly = True
    if not text:
        selectionOnly = False
        text = editor.currentStatement

    ok_count = 0
    bad_count = 0

    prev_end = 0
    new_text = []
    # Each range is (start offset, length).
    for begin, length in MysqlSqlFacade.getSqlStatementRanges(text):
        end = begin + length
        if begin > prev_end:
            new_text.append(text[prev_end:begin])
        statement = text[begin:end]

        # Set aside the whitespace around the statement.
        stripped = statement.lstrip(whitespace)
        leading = statement[:len(statement) - len(stripped)]
        statement = stripped
        stripped = statement.rstrip(whitespace)
        trailing = statement[len(stripped):]
        statement = stripped

        # if there's a comment at the start, then skip the comment until its end
        while True:
            if statement.startswith("-- "):
                # Keep the newline only if there really was one.
                comment, newline, rest = statement.partition("\n")
                leading += comment + newline
                statement = rest
            elif statement.startswith("/*"):
                # Search after the opener so "/*/" is not taken for a
                # complete comment.
                pos = statement.find("*/", 2)
                if pos >= 0:
                    leading += statement[:pos+2]
                    statement = statement[pos+2:]
                else:
                    break
            else:
                break
        stripped = statement.lstrip(whitespace)
        leading += statement[:len(statement) - len(stripped)]
        statement = stripped
        stripped = statement.rstrip(whitespace)
        trailing += statement[len(stripped):]
        statement = stripped

        try:
            result = doReformatSQLStatement(statement, True)
        except Exception:
            # Best effort: log the failure and keep the original statement.
            # KeyboardInterrupt/SystemExit are deliberately not caught.
            import traceback
            log_error("Error reformating SQL: %s\n%s\n" % (statement, traceback.format_exc()))
            result = None
        if result:
            ok_count += 1
            if leading:
                new_text.append(leading.strip(" "))
            new_text.append(result)
            if trailing:
                new_text.append(trailing.strip(" "))
        else:
            bad_count += 1
            new_text.append(text[begin:end])
        prev_end = end
    new_text.append(text[prev_end:])

    new_text = "".join(new_text)

    if selectionOnly:
        editor.replaceSelection(new_text)
    else:
        editor.replaceCurrentStatement(new_text)

    if bad_count > 0:
        mforms.App.get().set_status_text("Formatted %i statements, %i unsupported statement types skipped."%(ok_count, bad_count))
    else:
        mforms.App.get().set_status_text("Formatted %i statements."%ok_count)

    return 0
コード例 #14
0
def _parse_column_name_list_from_query(query):
    """Extract column names and variables from the first SELECT-like statement.

    The first parsed statement whose trimmed AST contains a
    ``select_item_list`` node decides the result.

    :param query: SQL text to analyse.
    :returns: ``(statement, columns, variables)``.  ``statement`` is the
        statement text with each ``variable`` node replaced by ``?``.
        ``columns`` has one name per select item; repeated names get a
        numeric suffix (``a``, ``a1``, ``a2``, ...).  ``variables`` lists
        the replaced variable texts in source order.  ``None`` is returned
        when no statement has a select item list.
    """
    import re

    from grt.modules import MysqlSqlFacade

    ast_list = MysqlSqlFacade.parseAstFromSqlScript(query)
    for ast in ast_list:
        if type(ast) is str:
            # Parser error message, not a tree.
            continue

        trimmed_ast = trim_ast(ast)
        select_item_list = find_child_node(trimmed_ast, "select_item_list")
        if not select_item_list:
            continue

        columns = []
        variables = []
        for node in node_children(select_item_list):
            if node_symbol(node) != "select_item":
                continue
            alias = find_child_node(find_child_node(node, "select_alias"), "ident")
            if alias:
                name = node_value(alias)
            else:
                ident = find_child_node(node, "simple_ident_q")
                if ident and len(node_children(ident)) == 3:
                    # Three children: use the last one as the column
                    # (presumably a qualified "table.column" -- TODO confirm).
                    ident = node_children(ident)[-1]
                else:
                    ident = find_child_node(find_child_node(node, "expr"), "ident")
                if ident:
                    name = node_value(ident)
                else:
                    name = "field"
                    field = flatten_node(node)
                    if field:
                        # "+" rather than "*": an expression that starts
                        # with a non-word character must keep the "field"
                        # fallback instead of getting an empty name.
                        m = re.match("([a-zA-Z0-9_]+)", field)
                        if m:
                            name = m.group(1)
            columns.append(name)

        helper = ASTHelper(query)
        begin, end = helper.get_ast_range(ast)

        query = query[begin:end]
        offset = begin

        # Replace back to front so earlier offsets stay valid.
        for var in reversed(find_child_nodes(ast, "variable")):
            begin, end = helper.get_ast_range(var)
            begin -= offset
            end -= offset

            name = query[begin:end]
            query = query[:begin] + "?" + query[end:]
            variables.insert(0, name)

        # Number repeated names consecutively.  The counter is bumped once
        # per occurrence (it used to be bumped twice, skipping numbers).
        seen = {}
        for i, column in enumerate(columns):
            count = seen.get(column, 0)
            if count:
                columns[i] = "%s%i" % (column, count)
            seen[column] = count + 1

        return query, columns, variables

    return None
コード例 #15
0
def _parse_column_name_list_from_query(query):
    """Return the parameterized statement, column names and variable names.

    Only the first parsed statement containing a ``select_item_list`` node
    is used.

    :param query: SQL text to analyse.
    :returns: tuple ``(sql, columns, variables)``: ``sql`` is the statement
        with each ``variable`` node replaced by ``?``; ``columns`` holds a
        name per select item, with repeated names made distinct by a
        consecutive numeric suffix; ``variables`` holds the original
        variable texts in order of appearance.  Returns ``None`` when no
        statement has a select item list.
    """
    import re

    from grt.modules import MysqlSqlFacade

    def column_name(node):
        """Pick a name for one select item."""
        alias = find_child_node(find_child_node(node, "select_alias"), "ident")
        if alias:
            return node_value(alias)

        ident = find_child_node(node, "simple_ident_q")
        if ident and len(node_children(ident)) == 3:
            ident = node_children(ident)[-1]
        else:
            ident = find_child_node(find_child_node(node, "expr"), "ident")
        if ident:
            return node_value(ident)

        # Fall back to the leading word of the flattened expression.  The
        # pattern requires at least one character, so an expression that
        # starts with e.g. "(" yields "field" rather than an empty name.
        field = flatten_node(node)
        if field:
            m = re.match("([a-zA-Z0-9_]+)", field)
            if m:
                return m.group(1)
        return "field"

    for ast in MysqlSqlFacade.parseAstFromSqlScript(query):
        if type(ast) is str:
            continue  # parser error message

        select_item_list = find_child_node(trim_ast(ast), "select_item_list")
        if not select_item_list:
            continue

        columns = [column_name(node)
                   for node in node_children(select_item_list)
                   if node_symbol(node) == "select_item"]

        helper = ASTHelper(query)
        offset, statement_end = helper.get_ast_range(ast)
        sql = query[offset:statement_end]

        # Substitute from the last variable to the first so that pending
        # positions are not shifted by earlier replacements.
        variables = []
        for var in reversed(find_child_nodes(ast, "variable")):
            begin, end = helper.get_ast_range(var)
            begin -= offset
            end -= offset
            variables.insert(0, sql[begin:end])
            sql = sql[:begin] + "?" + sql[end:]

        # Make repeated names unique: a, a1, a2, ...  ``in``-style lookup
        # replaces dict.has_key, and each occurrence counts exactly once.
        occurrences = {}
        for i, name in enumerate(columns):
            if name in occurrences:
                columns[i] = "%s%i" % (name, occurrences[name])
                occurrences[name] += 1
            else:
                occurrences[name] = 1

        return sql, columns, variables

    return None
コード例 #16
0
ファイル: sqlide_grt.py プロジェクト: Arrjaan/Cliff
def enbeautificate(editor):
    """Reformat every SQL statement in the selection or the whole script.

    Each statement AST with a formatter (``formatter_for_statement_ast``) is
    replaced by the formatter's output; other statements are copied as is.

    :param editor: editor object (``selectedText``, ``script``,
        ``replaceSelection`` and ``replaceContents`` are used).
    :returns: ``0`` on success, ``1`` if the SQL could not be parsed (the
        editor is not modified in that case).
    """
    from grt.modules import MysqlSqlFacade

    def strip_positions(node):
        # Keep only (symbol, value, children); offsets are dropped.
        return (node[0], node[1], [strip_positions(child) for child in node[2]])

    text = editor.selectedText
    selectionOnly = True
    if not text:
        selectionOnly = False
        text = editor.script

    helper = ASTHelper(text)

    ok_count = 0
    bad_count = 0

    curpos = 0
    chunks = []
    for ast in MysqlSqlFacade.parseAstFromSqlScript(text):
        if type(ast) is str:
            # A string entry is a parser error message.
            print(ast)
            mforms.App.get().set_status_text("Cannot format invalid SQL: %s"%ast)
            return 1

        # Statement nodes are expected to have exactly six fields.
        _s, _v, _c, _base, _begin, _end = ast
        begin, end = helper.get_ast_range(ast)
        # strip spaces that would come before statement
        chunks.append(text[curpos:begin].rstrip(" "))

        # The token range does not include the quotation char if the token is quoted.
        # So extend the range by one to avoid adding part of the original token to the output.
        if end < len(text) and text[end] in ("'", '"', '`'):
            curpos = end + 1
        else:
            curpos = end

        formatter = formatter_for_statement_ast(ast)
        if formatter:
            ok_count += 1
            chunks.append(formatter(strip_positions(ast)).run())
        else:
            bad_count += 1
            chunks.append(text[begin:end])
    chunks.append(text[curpos:])
    new_text = "".join(chunks)

    if selectionOnly:
        editor.replaceSelection(new_text)
    else:
        editor.replaceContents(new_text)

    if bad_count > 0:
        mforms.App.get().set_status_text("Formatted %i statements, %i unsupported statement types skipped."%(ok_count, bad_count))
    else:
        mforms.App.get().set_status_text("Formatted %i statements."%ok_count)

    return 0
コード例 #17
0
ファイル: sqlide_grt.py プロジェクト: juanga16/p111mil
def enbeautificate_old(editor):
    """Reformat the SQL in the selection, or the whole script if none.

    Statements with a formatter (``formatter_for_statement_ast``) are
    replaced by its output, with backquotes restored on identifiers that
    were backquoted in the source; all other statements are kept verbatim.

    :param editor: editor object (``selectedText``, ``script``,
        ``replaceSelection`` and ``replaceContents`` are used).
    :returns: ``0`` on success, ``1`` when parsing failed and nothing was
        replaced.
    """
    from grt.modules import MysqlSqlFacade

    quote_chars = ('\'', '"', '`')

    def trim_ast_fix_bq(text, node):
        """Reduce node to (symbol, value, children), re-adding backquotes."""
        symbol, value, children = node[0], node[1], node[2]
        # put back backquotes to identifiers, if there's any
        if symbol in ("ident", "ident_or_text"):
            first = node[3] + node[4]
            last = node[3] + node[5]
            if first > 0 and last < len(text) and text[first-1] == '`' and text[last] == '`':
                value = "`%s`" % value.replace("`", "``")
        trimmed = []
        for child in children:
            trimmed.append(trim_ast_fix_bq(text, child))
        return (symbol, value, trimmed)

    text = editor.selectedText
    selectionOnly = bool(text)
    if not selectionOnly:
        text = editor.script

    helper = ASTHelper(text)

    ok_count = 0
    bad_count = 0

    curpos = 0
    new_text = ""
    for ast in MysqlSqlFacade.parseAstFromSqlScript(text):
        if type(ast) is str:
            # error
            print(ast)
            mforms.App.get().set_status_text("Cannot format invalid SQL: %s"%ast)
            return 1

        # Unpacking requires the statement node to have six fields.
        _s, _v, _c, _base, _begin, _end = ast
        begin, end = helper.get_ast_range(ast)
        new_text += text[curpos:begin].rstrip(" ") # strip spaces that would come before statement

        # The token range does not include the quotation char if the token is quoted.
        # So extend the range by one to avoid adding part of the original token to the output.
        possible_quote_char = text[end] if end < len(text) else None
        curpos = end + 1 if possible_quote_char in quote_chars else end

        formatter = formatter_for_statement_ast(ast)
        if formatter:
            ok_count += 1
            fmted = formatter(trim_ast_fix_bq(text, ast)).run()
        else:
            bad_count += 1
            fmted = text[begin:end]
        new_text += fmted
    new_text += text[curpos:]

    if selectionOnly:
        editor.replaceSelection(new_text)
    else:
        editor.replaceContents(new_text)

    if bad_count > 0:
        mforms.App.get().set_status_text("Formatted %i statements, %i unsupported statement types skipped."%(ok_count, bad_count))
    else:
        mforms.App.get().set_status_text("Formatted %i statements."%ok_count)

    return 0