Example #1
0
def test_psql_quotation_marks():
    # issue83: each script below defines two dollar-quoted PostgreSQL
    # functions, so sqlparse.split must return exactly two statements.

    # regression: make sure plain $$ work
    plain_dollar = """
    CREATE OR REPLACE FUNCTION testfunc1(integer) RETURNS integer AS $$
          ....
    $$ LANGUAGE plpgsql;
    CREATE OR REPLACE FUNCTION testfunc2(integer) RETURNS integer AS $$
          ....
    $$ LANGUAGE plpgsql;"""

    # make sure $SOMETHING$ works too
    tagged_dollar = """
    CREATE OR REPLACE FUNCTION testfunc1(integer) RETURNS integer AS $PROC_1$
          ....
    $PROC_1$ LANGUAGE plpgsql;
    CREATE OR REPLACE FUNCTION testfunc2(integer) RETURNS integer AS $PROC_2$
          ....
    $PROC_2$ LANGUAGE plpgsql;"""

    for script in (plain_dollar, tagged_dollar):
        assert len(sqlparse.split(script)) == 2
Example #2
0
def test_split_quotes_with_new_line():
    # A newline inside a quoted identifier or string literal must not split
    # the statement, and the statement text must come back unchanged.
    for raw in ('select "foo\nbar"', "select 'foo\n\bar'"):
        pieces = sqlparse.split(raw)
        assert len(pieces) == 1
        assert pieces[0] == raw
Example #3
0
def sqlRstToBlockSequence(sqlText):
    """Convert an SQL script with comment blocks into a sequence of blocks.

    The text is split into segments with ``sqlparse.split``. Leading comment
    blocks of each segment become ``structure.BlockCommentBlock`` or
    ``structure.LinesCommentBlock`` entries. Whatever follows is split again
    and each statement becomes a create-table, create-view, select or
    unknown statement block. Spaces/lines before comments and after blocks
    are dropped, so some blank lines may disappear.

    :param sqlText: the SQL source text.
    :return: list of ``structure`` block objects in source order.
    """

    def _split_leading_comment(text):
        """Return (comment block, remainder), or (None, text) if none leads."""
        flags = re.MULTILINE | re.DOTALL
        found = re.match(BLOCK_COMMENT_BLOCK, text, flags)
        if found:
            block_cls = structure.BlockCommentBlock
        else:
            found = re.match(LINES_COMMENT_BLOCK, text, flags)
            if not found:
                return (None, text)
            block_cls = structure.LinesCommentBlock
        block = block_cls(found.group("comment"), found.group("before"), found.group("after"))
        return (block, found.group("rest"))

    def _classify_statement(text):
        """Wrap one SQL statement in the matching statement block type."""
        table = re.match(CREATE_TABLE_RE + ".*", text, re.IGNORECASE)
        if table:
            return structure.CreateTableStatementBlock(text, name=table.group("name"))
        view = re.match(CREATE_VIEW_RE + ".*", text, re.IGNORECASE)
        if view:
            return structure.CreateViewStatementBlock(text, name=view.group("name"))
        if re.match(SELECT_RE + ".*", text, re.IGNORECASE):
            return structure.SelectStatementBlock(text)
        return structure.UnknownStatementBlock(text)

    blocks = []
    for segment in sqlparse.split(sqlText):
        # Peel off every comment block that precedes the statement(s).
        comment, remainder = _split_leading_comment(segment)
        while comment is not None:
            blocks.append(comment)
            comment, remainder = _split_leading_comment(remainder)

        # Normally one statement is left, but split again to be safe.
        blocks.extend(_classify_statement(piece) for piece in sqlparse.split(remainder))

    return blocks
Example #4
0
    def find_statements(self):
        """Find the statements in the current buffer.

        The buffer text is split with ``sqlparse.split``. Each non-blank,
        stripped statement is searched for forward from the end of the
        previous match, so repeated statements map to successive
        occurrences. Statements that cannot be found are skipped.

        Yields:
            2-tuples (start iter, end iter).
        """
        buffer_ = self.get_buffer()
        buffer_start, buffer_end = buffer_.get_bounds()
        content = buffer_.get_text(buffer_start, buffer_end)
        iter_ = buffer_.get_start_iter()
        for stmt in sqlparse.split(content):
            needle = stmt.strip()
            if not needle:
                continue
            # FIXME: Does not work if linebreaks in buffers are not '\n'.
            #    sqlparse unifies them!
            bounds = iter_.forward_search(needle,
                                          gtk.TEXT_SEARCH_TEXT_ONLY)
            if bounds is None:
                continue
            start, end = bounds
            yield start, end
            iter_ = end
            if not iter_:
                # Nothing left to search from. A plain return ends the
                # generator; ``raise StopIteration`` here becomes a
                # RuntimeError under PEP 479 (Python 3.7+).
                return
def exec_impala_query_from_file(file_name):
  """Execute each query in an Impala query file individually.

  The file is split into statements with sqlparse. Each statement has its
  trailing ';' and comments removed and, if anything is left, is executed
  through an ImpalaBeeswaxClient. Every query and its result are written to
  ``<file_name>.log``. Execution stops at the first exception, and its
  traceback is written to that log.

  Returns True if all queries succeeded, and False on any error (including
  a missing input file).
  """
  if not os.path.exists(file_name):
    LOG.info("Error: File {0} not found".format(file_name))
    return False

  LOG.info("Beginning execution of impala SQL: {0}".format(file_name))
  is_success = True
  impala_client = ImpalaBeeswaxClient(options.impalad, use_kerberos=options.use_kerberos)
  output_file = file_name + ".log"
  # Bound up front so the error handler can always refer to it. Previously a
  # failure in connect() or while reading the file raised UnboundLocalError
  # from inside the except block.
  query = None
  with open(output_file, 'w') as out_file:
    try:
      impala_client.connect()
      # Read-only access is enough ('r+' also required write permission).
      with open(file_name, 'r') as query_file:
        queries = sqlparse.split(query_file.read())
      for query in queries:
        query = sqlparse.format(query.rstrip(';'), strip_comments=True)
        if query.strip() != "":
          result = impala_client.execute(query)
          out_file.write("{0}\n{1}\n".format(query, result))
    except Exception:
      if query is None:
        out_file.write("ERROR: failed before any query was executed\n")
      else:
        out_file.write("ERROR: {0}\n".format(query))
      traceback.print_exc(file=out_file)
      is_success = False

  if is_success:
    LOG.info("Finished execution of impala SQL: {0}".format(file_name))
  else:
    LOG.info("Error executing impala SQL: {0} See: {1}".format(
        file_name, output_file))

  return is_success
def executeSQL(connection,command,parent_conn = None):
    '''Execute a script of SQL statements on ``connection`` and commit.

    ``command`` is a sequence of SQL commands separated by ";" and possibly
    newlines. Lines whose first non-blank characters are "--" are dropped
    and carriage returns are removed. The remaining text is split into
    statements with sqlparse. Each non-blank statement runs on its own
    cursor, which is always closed, and the connection is committed once
    all of them succeed.

    connection is a DB-API connection (e.g. a MySQLdb connection).
    parent_conn, if truthy, is sent True on success. It is presumably one
    end of a multiprocessing Pipe; confirm at the caller.

    Returns True.
    '''
    lines = command.split("\n")
    # Drop whole-line "--" comments and strip carriage returns.
    lines = [re.sub('\r', '', line) for line in lines
             if line.lstrip()[0:2] != '--']
    # Re-join with newlines, not spaces. Joining with ' ' turned an inline
    # "-- comment" into a comment over every statement that followed it.
    command = '\n'.join(lines)

    for statement in sqlparse.split(command):
        # Make sure the statement actually does something.
        if not statement.strip():
            continue
        cur = connection.cursor()
        try:
            cur.execute(statement)
        finally:
            cur.close()
    connection.commit()
    if parent_conn:
        parent_conn.send(True)
    return True
Example #7
0
    def run(self, statement, pgspecial=None, exception_formatter=None,
            on_error_resume=False):
        """Execute the sql in the database and return the results.

        The input is stripped and split into individual statements with
        ``sqlparse.split``, and a trailing ';' is removed from each. When
        ``pgspecial`` is given, each statement is first offered to
        ``pgspecial.execute``. If that raises ``special.CommandNotFound``,
        the statement runs as normal sql via ``self.execute_normal_sql``.

        :param statement: A string containing one or more sql statements
        :param pgspecial: PGSpecial object
        :param exception_formatter: A callable that accepts an Exception. Its
               return value is yielded as the *status* element of a
               (None, None, None, status, query, False) tuple. If an
               exception_formatter is not supplied, psycopg2 exceptions are
               always raised.
        :param on_error_resume: Bool. If true, queries following an exception
               (assuming exception_formatter has been supplied) continue to
               execute.

        :return: Generator yielding tuples containing
                 (title, rows, headers, status, query, success)
        :raises psycopg2.OperationalError: always re-raised.
        :raises psycopg2.DatabaseError: re-raised when no exception_formatter
                is supplied.
        """

        # Remove spaces and EOL
        statement = statement.strip()
        if not statement:  # Empty string
            # NOTE(review): there is no return here, so execution falls
            # through to the split loop below, which presumably yields
            # nothing for an empty string.
            yield (None, None, None, None, statement, False)

        # Split the sql into separate queries and run each one.
        for sql in sqlparse.split(statement):
            # Remove trailing semi-colons.
            sql = sql.rstrip(';')

            try:
                if pgspecial:
                    # First try to run each query as special
                    _logger.debug('Trying a pgspecial command. sql: %r', sql)
                    # NOTE(review): this cursor is not closed in this method.
                    cur = self.conn.cursor()
                    try:
                        for result in pgspecial.execute(cur, sql):
                            # e.g. execute_from_file already appends these
                            if len(result) < 6:
                                yield result + (sql, True)
                            else:
                                yield result
                        # Handled as a special command; skip normal execution.
                        continue
                    except special.CommandNotFound:
                        # Not a special command; fall through to normal sql.
                        pass

                # Not a special command, so execute as normal sql
                yield self.execute_normal_sql(sql) + (sql, True)
            except psycopg2.DatabaseError as e:
                _logger.error("sql: %r, error: %r", sql, e)
                _logger.error("traceback: %r", traceback.format_exc())

                if (isinstance(e, psycopg2.OperationalError)
                        or not exception_formatter):
                    # Always raise operational errors, regardless of on_error
                    # specification
                    raise

                # The formatter's output fills the status slot of the tuple.
                yield None, None, None, exception_formatter(e), sql, False

                if not on_error_resume:
                    # Stop processing the remaining statements.
                    break
Example #8
0
    def prepare_sql_script(self, sql, _allow_fallback=False):
        """
        Split a multi-line SQL script into statements for cursor.execute().

        Returns a list of statements with comments stripped, to be fed to
        successive cursor.execute() calls. Few databases can run a raw
        script in one execute() call and PEP 249 doesn't cover the case, so
        this default implementation is conservative.

        If sqlparse is not installed, the ImportError propagates unless
        ``_allow_fallback`` is true. In that case a RemovedInDjango19Warning
        is issued and Django's legacy (buggy) splitter is used instead.
        """
        # Remove _allow_fallback and keep only 'return ...' in Django 1.9.
        try:
            # This import must stay inside the method because it's optional.
            import sqlparse
        except ImportError:
            if not _allow_fallback:
                raise
            # Without sqlparse, fall back to the legacy (and buggy) logic.
            warnings.warn(
                "Providing initial SQL data on a %s database will require "
                "sqlparse in Django 1.9." % self.connection.vendor,
                RemovedInDjango19Warning)
            from django.core.management.sql import _split_statements
            return _split_statements(sql)

        prepared = []
        for piece in sqlparse.split(sql):
            if piece:
                prepared.append(sqlparse.format(piece, strip_comments=True))
        return prepared
Example #9
0
    def run(self, statement, pgspecial=None):
        """Execute the sql in the database and return the results.

        The input is stripped and split into statements with
        ``sqlparse.split``, and a trailing ';' is removed from each. With
        ``pgspecial`` set, each statement is first tried as a special
        command. If that raises ``special.CommandNotFound``, the statement
        runs as normal sql.

        :param statement: A string containing one or more sql statements
        :param pgspecial: PGSpecial object
        :return: Generator yielding tuples containing
                 (title, rows, headers, status)
        """

        # Remove spaces and EOL
        statement = statement.strip()
        if not statement:  # Empty string
            yield (None, None, None, None)

        # Split the sql into separate queries and run each one.
        for sql in sqlparse.split(statement):
            # Remove trailing semi-colons.
            sql = sql.rstrip(";")

            if pgspecial:
                # First try to run each query as special
                try:
                    _logger.debug("Trying a pgspecial command. sql: %r", sql)
                    cur = self.conn.cursor()
                    for result in pgspecial.execute(cur, sql):
                        yield result
                    # Handled as a special command; go on to the next
                    # statement. A bare ``return`` here silently dropped
                    # every statement after the first special command.
                    continue
                except special.CommandNotFound:
                    pass

            yield self.execute_normal_sql(sql)
Example #10
0
 def test_if_function(self):  # see issue 33
     # IF(...) used as a function must not be mistaken for an IF block
     # that swallows the following ';'.
     statements = sqlparse.split('CREATE TEMPORARY TABLE tmp '
                                 'SELECT IF(a=1, a, b) AS o FROM one; '
                                 'SELECT t FROM two')
     self.assertEqual(2, len(statements))
Example #11
0
def run(conn, sql, config, user_namespace):
    """Execute each statement in ``sql`` and return the last one's result.

    Statements are split with sqlparse and executed through
    ``conn.session`` with ``user_namespace`` as parameters. On PostgreSQL,
    statements starting with a backslash are routed to pgspecial instead.
    A statement starting with BEGIN raises, because transactions are not
    supported. When ``config.autocommit`` is set (and the dialect is not
    mssql), a COMMIT is issued after each statement.

    Returns a ResultSet for the last statement (its DataFrame when
    ``config.autopandas`` is set), or "Connected: <name>" for blank ``sql``.
    """
    if sql.strip():
        for statement in sqlparse.split(sql):
            words = statement.split()
            if not words:
                continue  # nothing to run
            # Inspect this statement's own first keyword. The whole input
            # ``sql`` used to be checked, so only the first statement's
            # keyword ever mattered.
            first_word = words[0].lower()
            if first_word == 'begin':
                raise Exception("ipython_sql does not support transactions")
            if first_word.startswith('\\') and 'postgres' in str(conn.dialect):
                pgspecial = PGSpecial()
                _, cur, headers, _ = pgspecial.execute(
                                              conn.session.connection.cursor(),
                                              statement)[0]
                result = FakeResultProxy(cur, headers)
            else:
                txt = sqlalchemy.sql.text(statement)
                result = conn.session.execute(txt, user_namespace)
            try:
                # mssql has autocommit
                if config.autocommit and ('mssql' not in str(conn.dialect)):
                    conn.session.execute('commit')
            except sqlalchemy.exc.OperationalError:
                pass # not all engines can commit
            if result and config.feedback:
                print(interpret_rowcount(result.rowcount))
        # Returning only the last result, intentionally.
        resultset = ResultSet(result, statement, config)
        if config.autopandas:
            return resultset.DataFrame()
        else:
            return resultset
    else:
        return 'Connected: %s' % conn.name
Example #12
0
    def loadDatabaseSchema(dumpFileName, ignoreSlonyTriggers = False):
        """Build a PgDatabase model from the SQL statements in a dump file.

        The file is split into statements with sqlparse. Each statement has
        comments stripped (via ``PgDumpLoader.strip_comment``) and is
        dispatched to the first parser whose ``PgDumpLoader.PATTERN_*``
        regex matches. Unmatched statements are logged and ignored.

        :param dumpFileName: path of the dump file to read.
        :param ignoreSlonyTriggers: forwarded to CreateTriggerParser.parse.
        :return: the populated PgDatabase.
        """
        database = PgDatabase()
        logging.debug('Loading %s file', dumpFileName)

        # Read inside ``with`` so the dump file is closed. Handing the open
        # file object straight to sqlparse.split leaked the handle.
        with open(dumpFileName, 'r') as dump_file:
            statements = sqlparse.split(dump_file.read())
        logging.debug('Parsed %d statements', len(statements))
        for statement in statements:
            statement = PgDumpLoader.strip_comment(statement).strip()
            if PgDumpLoader.PATTERN_CREATE_SCHEMA.match(statement):
                CreateSchemaParser.parse(database, statement)
                continue

            match = PgDumpLoader.PATTERN_DEFAULT_SCHEMA.match(statement)
            if match:
                database.setDefaultSchema(match.group(1))
                continue

            if PgDumpLoader.PATTERN_CREATE_TABLE.match(statement):
                CreateTableParser.parse(database, statement)
                continue

            if PgDumpLoader.PATTERN_ALTER_TABLE.match(statement):
                AlterTableParser.parse(database, statement)
                continue

            if PgDumpLoader.PATTERN_CREATE_SEQUENCE.match(statement):
                CreateSequenceParser.parse(database, statement)
                continue

            if PgDumpLoader.PATTERN_ALTER_SEQUENCE.match(statement):
                AlterSequenceParser.parse(database, statement)
                continue

            if PgDumpLoader.PATTERN_CREATE_INDEX.match(statement):
                CreateIndexParser.parse(database, statement)
                continue

            if PgDumpLoader.PATTERN_CREATE_VIEW.match(statement):
                CreateViewParser.parse(database, statement)
                continue

            if PgDumpLoader.PATTERN_ALTER_VIEW.match(statement):
                AlterViewParser.parse(database, statement)
                continue

            if PgDumpLoader.PATTERN_CREATE_TRIGGER.match(statement):
                CreateTriggerParser.parse(database, statement, ignoreSlonyTriggers)
                continue

            if PgDumpLoader.PATTERN_CREATE_FUNCTION.match(statement):
                CreateFunctionParser.parse(database, statement)
                continue

            if PgDumpLoader.PATTERN_COMMENT.match(statement):
                CommentParser.parse(database, statement)
                continue

            logging.info('Ignored statement: %s', statement)

        return database
Example #13
0
    def run(self, statement):
        """Execute the sql in the database and yield the results.

        Each yielded value is a 4-tuple (title, rows, headers, status). A
        blank ``statement`` yields one all-None tuple. A statement ending in
        ``\\G`` turns on expanded output and runs the text before the marker
        as normal sql. Any other statement is first offered to
        ``special.execute`` and runs as normal sql when that raises
        ``special.CommandNotFound``.
        """
        statement = statement.strip()
        if not statement:
            yield (None, None, None, None)

        for query in sqlparse.split(statement):
            # Drop trailing semicolons before dispatching.
            query = query.rstrip(';')

            if not query.endswith('\\G'):
                try:
                    _logger.debug('Trying a dbspecial command. sql: %r', query)
                    cursor = self.conn.cursor()
                    for row in special.execute(cursor, query):
                        yield row
                except special.CommandNotFound:
                    yield self.execute_normal_sql(query)
                continue

            # \G: enable expanded output, then run the text before the marker.
            special.set_expanded_output(True)
            head = query.rsplit('\\G', 1)[0]
            yield self.execute_normal_sql(head)
Example #14
0
def execute_statements(cursor, statements):
    """Run every ';'-separated statement in ``statements`` on ``cursor``.

    Surrounding whitespace and semicolons are trimmed, and blank statements
    are skipped. Pending result sets are drained with ``cursor.nextset()``
    after each statement and once more on exit, even on error. Blank input
    returns immediately without touching the cursor. Returns None.
    """
    junk = u'%s%s' % (string.whitespace, ';')
    trimmed = statements.strip(junk)
    if not trimmed:
        return

    pending = sqlparse.split(trimmed)
    try:
        for raw in pending:
            stmt = raw.rstrip(junk)
            if stmt:
                cursor.execute(stmt)
                while cursor.nextset() is not None:
                    pass
    finally:
        while cursor.nextset() is not None:
            pass
Example #15
0
 def explain(self):
     """Show the EXPLAIN results for the selected statement.

     Uses the current selection, or the whole buffer when nothing is
     selected. Shows an error dialog if the text holds more than one
     statement, and does nothing without a connection. The queries built
     from ``self.connection.explain_statements`` run one after another. The
     last query, or the first one that fails, is passed to
     ``self.results.set_explain_results``.
     """
     self.results.assure_visible()
     buf = self.textview.get_buffer()
     bounds = buf.get_selection_bounds()
     if not bounds:
         bounds = buf.get_bounds()
     statement = buf.get_text(*bounds)
     if len(sqlparse.split(statement)) > 1:
         dialogs.error(_(u"Select a single statement to explain."))
         return
     if not self.connection:
         return
     queries = [Query(stmt, self.connection)
                for stmt in self.connection.explain_statements(statement)]
     if not queries:
         # Nothing to run; _execute_next would fail on queries.pop(0).
         return

     def _execute_next(last, queries):
         # Chain the queries: each query's 'finished' signal starts the next.
         if last is not None and last.failed:
             # Stop at the first failure and show it.
             self.results.set_explain_results(last)
             return
         q = queries.pop(0)
         if not queries:
             q.connect('finished',
                       lambda x: self.results.set_explain_results(x))
         else:
             q.connect('finished',
                       lambda x: _execute_next(x, queries))
         q.execute()
     _execute_next(None, queries)
Example #16
0
 def clean_reply(self, reply, dataset):
     """Validate a submitted solution query and return it stripped.

     Raises FormatError when the query is empty or holds more than one
     statement according to ``sqlparse.split``.
     """
     query = reply.solve_sql.strip()
     if not query:
         raise FormatError("Empty query")
     statement_count = len(sqlparse.split(query))
     if statement_count > 1:
         raise FormatError("Only one query is allowed")
     return query
Example #17
0
def execute_favorite_query(cur, arg, **_):
    """Run a saved favorite query, or list favorites when ``arg`` is empty.

    ``arg`` has the form "<name> [args...]". The optional arguments are
    shell-split and substituted into the saved query, and each resulting
    statement is executed on ``cur`` separately.

    Yields (title, rows, headers, status) tuples.
    """
    if arg == '':
        for result in list_favorite_queries():
            yield result
        # Listing is the whole job. Without this return, the empty name was
        # looked up next and a spurious "No favorite query: " was yielded.
        return

    # Parse out favorite name and optional substitution parameters.
    name, _sep, arg_str = arg.partition(' ')
    args = shlex.split(arg_str)

    query = favoritequeries.get(name)
    if query is None:
        message = "No favorite query: %s" % (name)
        yield (None, None, None, message)
    else:
        query, arg_error = subst_favorite_query_args(query, args)
        if arg_error:
            yield (None, None, None, arg_error)
        else:
            for sql in sqlparse.split(query):
                sql = sql.rstrip(';')
                title = '> %s' % (sql)
                cur.execute(sql)
                if cur.description:
                    headers = [x[0] for x in cur.description]
                    yield (title, cur, headers, None)
                else:
                    yield (title, None, None, None)
Example #18
0
 def get_schema(self, name):
     """Return the non-blank SQL statements of schema file ``name``.

     The file obtained from ``self.load_file(name, 'schemas')`` is read and
     split with sqlparse. Each statement is stripped of surrounding
     whitespace and blank ones are dropped.
     """
     fp = self.load_file(name, 'schemas')
     try:
         query = fp.read()
     finally:
         # The handle was previously never closed. This assumes load_file
         # returns a fresh file object owned by the caller; confirm.
         fp.close()
     stripped = (q.strip() for q in sqlparse.split(query))
     return [q for q in stripped if q]
Example #19
0
    def run(self, statement):
        """Execute the sql in the database and yield the results.

        Each yielded value is a 4-tuple (title, rows, headers, status). A
        blank ``statement`` yields one all-None tuple. Input starting with
        ``\\fs`` (save a favorite query) is kept whole instead of being
        split, so all of its statements are saved together. A trailing
        ``\\G`` enables expanded output. Each piece is first offered to
        ``special.execute`` and runs as normal sql when that raises
        ``special.CommandNotFound``.
        """
        statement = statement.strip()
        if not statement:
            yield (None, None, None, None)

        saving_favorite = statement.startswith('\\fs')
        pieces = [statement] if saving_favorite else sqlparse.split(statement)

        for piece in pieces:
            # Drop trailing semicolons.
            piece = piece.rstrip(';')
            if piece.endswith('\\G'):
                special.set_expanded_output(True)
                piece = piece[:-2].strip()
            try:
                _logger.debug('Trying a dbspecial command. sql: %r', piece)
                cursor = self.conn.cursor()
                for row in special.execute(cursor, piece):
                    yield row
            except special.CommandNotFound:
                yield self.execute_normal_sql(piece)
Example #20
0
def test_issue485_split_multi():
    # The ';'-terminated selects inside "DO instead ( ... )" belong to the
    # rule, so the whole text is one statement.
    p_sql = '''CREATE OR REPLACE RULE ruled_tab_2rules AS ON INSERT
TO public.ruled_tab
DO instead (
select 1;
select 2;
);'''
    statements = sqlparse.split(p_sql)
    assert 1 == len(statements)
Example #21
0
    def load_database_schema(dump_file_name, ignore_slony_triggers=False):
        """Build a PgDatabase model from the SQL statements in a dump file.

        The file is split into statements with sqlparse. Each statement has
        comments stripped (via ``PgDumpLoader.strip_comment``) and is
        dispatched to the first parser whose ``PgDumpLoader.PATTERN_*``
        regex matches. Unmatched statements are silently skipped.

        :param dump_file_name: path of the dump file to read.
        :param ignore_slony_triggers: forwarded to CreateTriggerParser.parse.
        :return: the populated PgDatabase.
        """
        database = PgDatabase()

        print("Loading file dump: %s\n" % dump_file_name)

        # Read inside ``with`` so the dump file is closed. Handing the open
        # file object straight to sqlparse.split leaked the handle.
        with open(dump_file_name, 'r') as dump_file:
            statements = sqlparse.split(dump_file.read())
        for statement in statements:
            statement = PgDumpLoader.strip_comment(statement).strip()
            if PgDumpLoader.PATTERN_CREATE_SCHEMA.match(statement):
                CreateSchemaParser.parse(database, statement)
                continue

            match = PgDumpLoader.PATTERN_DEFAULT_SCHEMA.match(statement)
            if match:
                database.setDefaultSchema(match.group(1))
                continue

            if PgDumpLoader.PATTERN_CREATE_TABLE.match(statement):
                CreateTableParser.parse(database, statement)
                continue

            if PgDumpLoader.PATTERN_ALTER_TABLE.match(statement):
                AlterTableParser.parse(database, statement)
                continue

            if PgDumpLoader.PATTERN_CREATE_SEQUENCE.match(statement):
                CreateSequenceParser.parse(database, statement)
                continue

            if PgDumpLoader.PATTERN_ALTER_SEQUENCE.match(statement):
                AlterSequenceParser.parse(database, statement)
                continue

            if PgDumpLoader.PATTERN_CREATE_INDEX.match(statement):
                CreateIndexParser.parse(database, statement)
                continue

            if PgDumpLoader.PATTERN_CREATE_VIEW.match(statement):
                CreateViewParser.parse(database, statement)
                continue

            if PgDumpLoader.PATTERN_ALTER_VIEW.match(statement):
                AlterViewParser.parse(database, statement)
                continue

            if PgDumpLoader.PATTERN_CREATE_TRIGGER.match(statement):
                CreateTriggerParser.parse(database, statement, ignore_slony_triggers)
                continue

            if PgDumpLoader.PATTERN_CREATE_FUNCTION.match(statement):
                CreateFunctionParser.parse(database, statement)
                continue

            if PgDumpLoader.PATTERN_COMMENT.match(statement):
                CommentParser.parse(database, statement)
                continue

        return database
Example #22
0
def need_completion_refresh(queries):
    """Determine whether the completer needs a refresh.

    Returns True if any statement in ``queries`` starts (case-insensitively)
    with alter, create, drop, use, or a ``\\c`` / ``\\connect`` database
    switch, and False otherwise. Previously only the first statement was
    inspected, so "select 1; create table t (x int)" never triggered a
    refresh.
    """
    refresh_words = ("alter", "create", "use", "\\c", "\\connect", "drop")
    for query in sqlparse.split(queries):
        words = query.split()
        if not words:
            continue  # blank fragment; nothing to inspect
        if words[0].lower() in refresh_words:
            return True
    return False
Example #23
0
def run(conn, sql, config):
    """Execute each statement in ``sql`` and return the last one's result.

    Blank ``sql`` returns the string "Connected: <conn.name>" instead.
    """
    if not sql.strip():
        return "Connected: %s" % conn.name
    for statement in sqlparse.split(sql):
        result = conn.session.execute(sqlalchemy.sql.text(statement))
    # Only the last statement's result is returned, intentionally.
    return ResultSet(result, statement, config)
Example #24
0
def test_issue193_splitting_function():
    # The ';' terminators inside the BEGIN ... END function body must not
    # split the CREATE FUNCTION; the trailing SELECT is a separate statement.
    sql = """   CREATE FUNCTION a(x VARCHAR(20)) RETURNS VARCHAR(20)
                BEGIN
                 DECLARE y VARCHAR(20);
                 RETURN x;
                END;
                SELECT * FROM a.b;"""
    assert 2 == len(sqlparse.split(sql))
Example #25
0
def run(conn, sql, config, user_namespace=None):
  """Execute each statement in ``sql`` and return the last one's result.

  ``user_namespace`` is passed as the parameters argument of
  ``conn.session.execute``. It defaults to a fresh empty dict on each call,
  because a shared mutable ``{}`` default would carry any mutation over
  into later calls.
  """
  if user_namespace is None:
    user_namespace = {}
  for statement in sqlparse.split(sql):
    txt = sqlalchemy.sql.text(statement)
    result = conn.session.execute(txt, user_namespace)

  # Only the last statement's result is returned, intentionally.
  return ResultSet(result, statement, config)
Example #26
0
 def execute_query(self, cursor, query):
     """Execute each statement in ``query`` on ``cursor``.

     Trailing whitespace and semicolons are trimmed and blank statements are
     skipped. Pending result sets are drained after each statement.
     """
     # ``unicode(...)`` existed only on Python 2 (NameError on Python 3).
     # rstrip accepts this ASCII-only str on both versions.
     trailing = string.whitespace + ';'
     for statement in sqlparse.split(query):
         statement = statement.rstrip(trailing)
         if statement:
             cursor.execute(statement)
             while cursor.nextset() is not None:
                 pass
Example #27
0
 def parse_statements(sql_location):
     """Return the non-empty SQL statements in the file at ``sql_location``.

     Prints a message and returns None when the path is not a readable
     regular file.
     """
     if not (os.path.isfile(sql_location) and os.access(sql_location, os.R_OK)):
         print("SQL file not found or not accessible")
         return None
     # ``with`` closes the file, which was previously left open.
     with open(sql_location, "r") as sql_file:
         content = sql_file.read()
     # Use a list comprehension rather than filter(). On Python 3, filter()
     # returns an iterator, so the old ``assert isinstance(sql, list)``
     # always failed.
     return [stmt for stmt in sqlparse.split(content) if stmt]
Example #28
0
def execute_count_statements(cursor, statements):
    """Execute count statement(s) and collect one value per statement.

    ``statements`` is a ';'-separated string. Each statement must produce
    exactly one row with exactly one column, and that value is collected.
    Blank statements contribute None. Pending result sets are drained on
    exit, even when an error is raised.

    Returns the list of values (empty for blank input). Raises
    exceptions.Error when a statement returns no result set, more than one
    column, more than one row, or no row.
    """
    counts = []
    if not statements:
        return counts

    junk = u'%s%s' % (string.whitespace, ';')
    statements = statements.strip(junk)
    if not statements:
        return counts

    try:
        for statement in sqlparse.split(statements):
            statement = statement.rstrip(junk)

            if not statement:
                counts.append(None)
                continue

            # NOTE(review): relies on the driver returning the row count
            # from execute() (as MySQLdb does); not guaranteed by PEP 249.
            row_count = cursor.execute(statement)

            # description is None when the statement produces no result
            # set; len(None) used to raise a bare TypeError here.
            if not cursor.description or len(cursor.description) > 1:
                raise exceptions.Error(
                    u'Statement should return a single value only. '
                    u'Statement was: %s' % (statement,)
                )

            if row_count > 1:
                raise exceptions.Error(
                    u'Statement should return a single row only. '
                    u'Statement was: %s' % (statement,)
                )

            if not row_count:
                raise exceptions.Error(
                    u'Statement returned an empty set. '
                    u'Statement was: %s' % (statement,)
                )

            row = cursor.fetchone()
            if row is None:
                raise exceptions.Error(
                    u'Statement returned an empty set. '
                    u'Statement was: %s' % (statement,)
                )

            counts.append(row[0])

    finally:
        while cursor.nextset() is not None:
            pass

    return counts
Example #29
0
    def run(self):
        """Execute the plugin's SQL data file statement by statement.

        The file named by the "dataFile" plugin setting, resolved relative
        to ``self._dirPath``, is split with sqlparse. Each non-blank
        statement is passed to ``self._execSql``.
        """
        sqlFile = os.path.join(self._dirPath, self._plugin_settings["dataFile"])
        self._log("Loading SQL file:" + sqlFile)
        # Read inside ``with`` so the file handle is closed promptly.
        with open(sqlFile) as sql_handle:
            sql = sql_handle.read()
        for sql_part in sqlparse.split(sql):
            if not sql_part.strip():
                continue
            self._execSql(sql_part)
Example #30
0
    def get_query(self, name):
        """Return the non-blank SQL statements of query file ``name``, cached.

        On the first call the file from ``self.load_file(name, 'queries')``
        is read, split with sqlparse and each statement stripped, with blank
        statements dropped. The list is stored in ``self.query_cache`` and
        returned from there on later calls.
        """
        if name in self.query_cache:
            return self.query_cache.get(name)

        fp = self.load_file(name, 'queries')
        try:
            query = fp.read()
        finally:
            # The handle was previously never closed. This assumes load_file
            # returns a fresh file object owned by the caller; confirm.
            fp.close()
        stripped = (q.strip() for q in sqlparse.split(query))
        ret_queries = [q for q in stripped if q]
        self.query_cache[name] = ret_queries
        return ret_queries
Example #31
0
def upgrade():
    """Replace the boolean ``images.is_public`` with a ``visibility`` enum.

    On SQLite the companion ``.sql`` script (this module's path with a
    ``.sql`` extension) is run statement by statement instead, presumably
    because SQLite cannot perform these ALTERs directly (confirm).

    Elsewhere: create the ``image_visibility`` enum (private, public,
    shared, community), add ``images.visibility`` (NOT NULL, server default
    'shared') with an index, and backfill it. is_public rows become
    'public', other rows become 'private', and private images that have a
    non-deleted member row become 'shared'. Finally the ``is_public``
    column and its index are dropped.
    """
    migrate_engine = op.get_bind()
    meta = MetaData(bind=migrate_engine)

    engine_name = migrate_engine.engine.name
    if engine_name == 'sqlite':
        # The script lives next to this module with a .sql extension.
        sql_file = os.path.splitext(__file__)[0]
        sql_file += '.sql'
        with open(sql_file, 'r') as sqlite_script:
            # Strip comments first so they are not sent as statements.
            sql = sqlparse.format(sqlite_script.read(), strip_comments=True)
            for statement in sqlparse.split(sql):
                op.execute(statement)
        return

    enum = Enum('private', 'public', 'shared', 'community', metadata=meta,
                name='image_visibility')
    enum.create()
    v_col = Column('visibility', enum, nullable=False, server_default='shared')
    op.add_column('images', v_col)

    op.create_index('visibility_image_idx', 'images', ['visibility'])

    # Reflect the table so the new column and is_public are both available.
    images = Table('images', meta, autoload=True)
    images.update(values={'visibility': 'public'}).where(
        images.c.is_public).execute()

    image_members = Table('image_members', meta, autoload=True)

    # NOTE(dharinic): Mark all the non-public images as 'private' first
    images.update().values(visibility='private').where(
        not_(images.c.is_public)).execute()
    # NOTE(dharinic): Identify 'shared' images from the above
    images.update().values(visibility='shared').where(and_(
        images.c.visibility == 'private', images.c.id.in_(select(
            [image_members.c.image_id]).distinct().where(
                not_(image_members.c.deleted))))).execute()

    # is_public is now fully represented by visibility; remove it.
    op.drop_index('ix_images_is_public', 'images')
    op.drop_column('images', 'is_public')
Example #32
0
def load_sql_script_file(sql_file_name,
                         msg='load_sql_script_file',
                         exit_if_file_not_found=True):
    """Load an SQL script from the RUIAN services SQL script directory.

    The file (UTF-8) is split into statements with sqlparse. Top-level
    comment tokens are removed from each statement, unless its first two
    DDL/keyword tokens form a prefix listed in ``ignore`` (for example
    "CREATE FUNCTION"). Those statements keep all their tokens.

    :param sql_file_name: file name relative to the scripts directory.
    :param msg: message logged at the start.
    :param exit_if_file_not_found: when true, a missing file is logged as an
        error and ``exit_app()`` is called; otherwise a warning is logged and
        None is returned.
    :return: list of statement strings.
    """
    logger.info(msg)
    ignore = {'CREATE FUNCTION'}  # extend this (upper-case prefixes)

    def _filter(statement):
        # Yield the statement's top-level tokens, dropping comments unless
        # the statement starts with one of the ``ignore`` prefixes.
        ddl = [
            t for t in statement.tokens
            if t.ttype in (tokens.DDL, tokens.Keyword)
        ]
        # Keywords are case-insensitive in SQL; compare upper-cased.
        start = ' '.join(d.value for d in ddl[:2]).upper()
        keep_comments = start in ignore
        for tok in statement.tokens:
            if keep_comments or not isinstance(tok, sqlparse.sql.Comment):
                yield tok

    sql_file_name = os.path.join(get_ruian_services_sql_scripts_path(),
                                 sql_file_name)
    logger.info("   Loading SQL commands from {0}".format(sql_file_name))
    if not os.path.exists(sql_file_name):
        if exit_if_file_not_found:
            logger.error("ERROR: File %s not found." % sql_file_name)
            exit_app()
        else:
            logger.warning("ERROR: File %s not found." % sql_file_name)
            return

    # ``with`` closes the file even if reading fails.
    with codecs.open(sql_file_name, "r", "utf-8") as in_file:
        raw = in_file.read()
    statements = []
    for stmt in sqlparse.split(raw):
        if not stmt.strip():
            # Blank fragment; sqlparse.parse may return no statement for it,
            # which would make the [0] below raise IndexError.
            continue
        sql = sqlparse.parse(stmt)[0]
        tl = sqlparse.sql.TokenList(list(_filter(sql)))
        statements.append(tl.value)
    logger.info("   Loading SQL commands - done.")
    return statements
Example #33
0
def test_split_begin_end_block():
    """A PL/SQL DECLARE/BEGIN...END block, including a nested BEGIN...END
    and CASE expressions, must be kept together, so the script between two
    plain SELECTs splits into exactly three statements."""
    script = """
        select * from dual;
        DECLARE 
        test_num number;
        begin
           dbms_output.put_line("123");
           BEGIN
              select (case when rownum > 1 then 
                    case when rownum == 1 
                      then 0 else 2 
                     end
                    else 0 
                end) a into test_num from dual;
                dbms_output.put_line(test_num);
           end;
        end;
        
        select * from dual;
    """
    assert len(sqlparse.split(script.strip())) == 3
Example #34
0
    def run(self, statement):
        """Execute the SQL in *statement* and yield the results.

        This is a generator. Empty input yields a single
        ``(None, None, None, None)``. Otherwise the input is split into
        separate queries. ``\\c <db>`` / ``use <db>`` reconnects to another
        database and yields a 4-tuple whose last element is a status message.
        Every other query is first tried with ``pgspecial.execute``, whose
        results are passed through unchanged. On KeyError it falls back to
        ``self.execute_normal_sql``.

        :raises RuntimeError: when a ``\\c`` / ``use`` command has no
            database name.
        """
        # Remove spaces and EOL
        statement = statement.strip()
        if not statement:  # Empty string
            yield (None, None, None, None)
            return  # nothing left to run

        # Split the sql into separate queries and run each one.
        for sql in sqlparse.split(statement):
            # Remove trailing semicolons.
            sql = sql.rstrip(';')

            # \c and 'use' change the connection itself, which cannot be
            # offloaded to the pgspecial lib, so they are handled here.
            if sql.startswith('\\c') or sql.lower().startswith('use'):
                _logger.debug('Database change command detected.')
                try:
                    dbname = sql.split()[1]
                except IndexError:
                    _logger.debug('Database name missing.')
                    raise RuntimeError('Database name missing.') from None
                self.connect(database=dbname)
                self.dbname = dbname
                _logger.debug('Successfully switched to DB: %r', dbname)
                yield (None, None, None,
                       'You are now connected to database "%s" as '
                       'user "%s"' % (self.dbname, self.user))
            else:
                try:  # Special command
                    _logger.debug('Trying a pgspecial command. sql: %r', sql)
                    cur = self.conn.cursor()
                    for result in pgspecial.execute(cur, sql):
                        yield result
                except KeyError:  # Regular SQL
                    yield self.execute_normal_sql(sql)
Example #35
0
    def __init__(self, owner_uri: str, query_text: str, query_execution_settings: QueryExecutionSettings, query_events: QueryEvents) -> None:
        """Create a query and split *query_text* into executable batches.

        Statements that are empty, or reduce to nothing or a bare ``;`` once
        comments are stripped, do not become batches. When execution-plan
        options are set, each batch's text is wrapped in the estimated or
        the actual (ANALYZE) EXPLAIN template. The actual-plan variant also
        disables auto-commit.
        """
        self._execution_state: ExecutionState = ExecutionState.NOT_STARTED
        self._owner_uri: str = owner_uri
        self._query_text = query_text
        self._disable_auto_commit = False
        self._current_batch_index = 0
        self._batches: List[Batch] = []
        self._execution_plan_options = query_execution_settings.execution_plan_options

        self.is_canceled = False

        plan_options = self._execution_plan_options
        statements = sqlparse.split(query_text)
        selection_data = compute_selection_data_for_batches(statements, query_text)

        for position, raw_text in enumerate(statements):
            # Comment-only / empty statements produce no batch.
            stripped = sqlparse.format(raw_text, strip_comments=True).strip()
            if stripped in ('', ';'):
                continue

            batch_sql = raw_text
            if plan_options:
                if plan_options.include_estimated_execution_plan_xml:
                    batch_sql = Query.EXPLAIN_QUERY_TEMPLATE.format(batch_sql)
                elif plan_options.include_actual_execution_plan_xml:
                    self._disable_auto_commit = True
                    batch_sql = Query.ANALYZE_EXPLAIN_QUERY_TEMPLATE.format(batch_sql)

            self._batches.append(create_batch(
                batch_sql,
                len(self.batches),
                selection_data[position],
                query_events.batch_events,
                query_execution_settings.result_set_storage_type))
Example #36
0
def run_sql(session, sql, params=None, stop_on_error=False):
    """Execute each statement in *sql* on *session*, echoing it.

    Comments are stripped and empty statements skipped. After a successful
    statement, ``session.commit()`` is called if the session has one.
    ProgrammingError / IntegrityError are reported rather than raised, and
    the session is rolled back if it supports that. Errors mentioning
    "already exists" are shown dimmed and indented. With *stop_on_error*,
    the function returns after the first such error.
    """
    for chunk in split(sql):
        statement = format(chunk, strip_comments=True).strip()
        if not statement:
            continue
        try:
            session.execute(text(statement), params=params)
            if hasattr(session, "commit"):
                session.commit()
            pretty_print(statement, dim=True)
        except (ProgrammingError, IntegrityError) as exc:
            message = str(exc.orig).strip()
            benign = "already exists" in message
            if hasattr(session, "rollback"):
                session.rollback()
            pretty_print(statement, fg=None if benign else "red", dim=True)
            secho("  " + message if benign else message, fg="red", dim=benign)
            if stop_on_error:
                return
Example #37
0
def query_check(sql=''):
    """Validate a read-only query and return a result dict.

    Comments are stripped and only the first statement is kept; it is
    stored, stripped, in ``filtered_sql``. ``bad_query`` is set when parsing
    fails or the statement does not start with select/show/explain/desc
    (any case). ``has_star`` is set when the statement contains ``*``.
    ``msg`` holds the last applicable message: invalid SQL, unsupported
    statement type, or "contains *".
    """
    result = {
        'msg': '',
        'bad_query': False,
        'filtered_sql': sql,
        'has_star': False
    }
    try:
        sql = sqlparse.format(sql, strip_comments=True)
        sql = sqlparse.split(sql)[0]
        result['filtered_sql'] = sql.strip()
    except Exception:
        result['bad_query'] = True
        result['msg'] = 'SQL语句无效'

    # Case-insensitive like the other query_check variants: "SELECT" is as
    # valid as "select".
    if re.match(r"select|show|explain|desc", sql, re.I) is None:
        result['bad_query'] = True
        result['msg'] = '不支持的语句类型'
    # Plain substring test; the old '\*' pattern was an invalid escape.
    if '*' in sql:
        result['has_star'] = True
        result['msg'] = 'SQL语句中含有 * '
    return result
Example #38
0
 def execute_query(self, query):
     """Run *query* and yield ``(rows, columns, status, statement, is_error)``.

     The whole input is first offered to ``special.execute``. If that raises
     CommandNotFound, the input is stripped. Empty input yields
     ``(None, None, None, '', False)``. Otherwise the input is split with
     sqlparse and each statement, with trailing semicolons removed, is run
     through ``self._execute_query``. A statement that becomes empty yields
     ``(None, None, None, None, False)``.
     """
     try:
         # Special commands take precedence over plain SQL.
         for rows, columns, status, statement, is_error in special.execute(self, query):
             yield rows, columns, status, statement, is_error
     except special.CommandNotFound:
         # Not a special command: execute it as normal SQL.
         query = query.strip()
         if not query:
             yield None, None, None, query, False
             return
         for part in sqlparse.split(query):
             # Drop surrounding whitespace and trailing semicolons.
             part = part.strip().rstrip(';')
             if not part:
                 yield None, None, None, None, False
                 continue
             for rows, columns, status, statement, is_error in self._execute_query(part):
                 yield rows, columns, status, statement, is_error
Example #39
0
    def _batch_DDLs(self, sql):
        """
        Validate that *sql* holds only DDL statements and queue them.

        Every statement is checked before any is queued, so a single non-DDL
        statement leaves ``self.connection._ddl_statements`` unchanged.

        :type sql: str
        :param sql: One or more SQL statements.

        :raises: :class:`ValueError` if any statement is not classified as DDL.
        """
        pending = []
        for raw in filter(None, sqlparse.split(sql)):
            candidate = raw.rstrip(";")
            if parse_utils.classify_stmt(candidate) != parse_utils.STMT_DDL:
                raise ValueError("Only DDL statements may be batched.")
            pending.append(candidate)

        self.connection._ddl_statements.extend(pending)
Example #40
0
    def parse(self, fd):
        """Parse named SQL statements from the file-like object *fd*.

        Each statement from ``sqlparse.split`` is searched with
        ``name_token``. Statements with no match, or with an empty ``name``
        group, are skipped. For each remaining statement a dict is collected
        with:

        - the name, with ``-`` replaced by ``_``;
        - the ``command`` and ``result`` groups;
        - the SQL with comments removed and whitespace collapsed to single
          spaces, with ``params_token`` matches rewritten as ``%(<group 1>)s``.

        :return: list of statement dicts, in file order.
        """
        text = fd.read().rstrip('\n')
        statements = []
        for statement in sqlparse.split(text):
            match = name_token.search(statement)
            # A statement without a recognisable header used to crash with
            # AttributeError on ``None.group``; treat it like an unnamed one.
            if match is None:
                continue
            name = match.group('name')
            if not name:
                continue

            sql = ' '.join(
                sqlparse.format(statement, strip_comments=True).split())
            sql = params_token.sub(r'%(\1)s', sql)
            statements.append({
                'name': name.replace('-', '_'),
                'command': match.group('command'),
                'result': match.group('result'),
                'sql': sql
            })

        return statements
Example #41
0
 def exec_threaded(statement, start_line):
     """Run each statement of *statement* as a Query, in order.

     Splitting uses sqlparse unless the ``sqlparse.enabled`` config option
     is false. *start_line* advances by each part's line count and is stored
     on the query as ``editor_start_line``. Blank parts are skipped.
     Execution stops after the first query whose ``failed`` flag is set.
     """
     if self.app.config.get("sqlparse.enabled", True):
         stmts = sqlparse.split(statement)
     else:
         stmts = [statement]
     for stmt in stmts:
         add_offset = len(stmt.splitlines())
         if not stmt.strip():
             start_line += add_offset
             continue
         query = Query(stmt, self.connection)
         gtk.gdk.threads_enter()
         try:
             query.set_data('editor_start_line', start_line)
             query.connect("started", self.on_query_started)
             query.connect("finished", self.on_query_finished, tag_notice)
         finally:
             # Always release the lock taken by threads_enter(), even if one
             # of the calls above raises.
             gtk.gdk.threads_leave()
         query.execute(True)
         start_line += add_offset
         if query.failed:
             # Abort the remaining statements after the first failure.
             return
    def explainPlan(self, queries, callback):
        """Run the 'explain plan' named query for every statement in *queries*.

        Does nothing if no 'explain plan' query is configured. Each raw query
        is split with sqlparse and empty pieces are dropped. Every remaining
        statement, stripped of whitespace and semicolons, is substituted into
        the template. The combined query is run through
        ``self.Command.createAndRun`` with *callback*.
        """
        query_name = 'explain plan'
        template = self.getNamedQuery(query_name)
        if not template:
            return

        explained = []
        for raw_query in queries:
            for piece in sqlparse.split(raw_query):
                if piece:
                    explained.append(template.format(piece.strip().strip(";")))

        combined = self.buildNamedQuery(query_name, explained)
        cli_args = self.buildArgs(query_name)
        cli_env = self.buildEnv()
        self.Command.createAndRun(args=cli_args,
                                  env=cli_env,
                                  callback=callback,
                                  query=combined,
                                  encoding=self.encoding,
                                  timeout=self.timeout,
                                  silenceErrors=False,
                                  stream=self.useStreams)
Example #43
0
 def save(self, request):
     """Create the execution tasks for an order and return their task id.

     If tasks already exist for the order, the task id of the first one is
     returned and nothing is created. Otherwise the order's SQL is split
     into statements with surrounding semicolons stripped, and one
     ``DbOrdersExecuteTasks`` row per statement is created. All rows share a
     new UUID-based task id.
     """
     order_id = self.validated_data['id']
     obj = models.DbOrders.objects.get(pk=order_id)
     # If tasks were already generated for this order, reuse them
     # (one query instead of exists() followed by first()).
     existing = models.DbOrdersExecuteTasks.objects.filter(
         order__id=order_id).first()
     if existing is not None:
         return existing.task_id
     # Generate the task rows.
     splitsqls = [
         sql.strip(';')
         for sql in sqlparse.split(obj.contents, encoding='utf8')
     ]
     # 32-char hex of a UUID4, identical to str(uuid4()) without dashes.
     task_id = uuid.uuid4().hex
     for sql in splitsqls:
         models.DbOrdersExecuteTasks.objects.create(
             applicant=obj.applicant,
             task_id=task_id,
             sql=sql,
             sql_type=obj.sql_type,
             file_format=obj.file_format,
             order_id=order_id)
     return task_id
Example #44
0
 def mysql_parse(self, sql, paramHost, paramPort, paramUser, paramPasswd,
                 paramDb):
     """Syntax-check every statement in *sql* by running ``EXPLAIN`` on it.

     :return: ``(True, 'ok')`` when every EXPLAIN succeeds. Otherwise
         ``(False, {statement: error message})`` for the failing statements.
     """
     conn = MySQLdb.connect(host=paramHost,
                            user=paramUser,
                            passwd=paramPasswd,
                            db=paramDb,
                            port=paramPort,
                            charset='utf8')
     errors = {}
     try:
         cursor = conn.cursor()
         try:
             for row in sqlparse.split(sql):
                 try:
                     cursor.execute('EXPLAIN ' + row)
                 except Exception as e:
                     # Collect every failure instead of stopping at the first.
                     errors[row] = str(e)
         finally:
             cursor.close()
     finally:
         # The connection used to be leaked on every call.
         conn.close()
     if errors:
         return False, errors
     return True, 'ok'
Example #45
0
def exec_sql(engine: Engine, path: Path):
    """
    Execute the SQL statements in a file, one at a time.

    A parse that produces a single statement of type ``'UNKNOWN'`` is
    skipped. According to the original author this is likely a lone comment,
    which would raise if executed by itself.

    :param engine: the engine connected to the database
    :param path: the path to the file containing your SQL statements
    """
    with engine.connect() as connection:
        log: logging.Logger = logging.getLogger(__name__)
        for statement in sqlparse.split(path.read_text().strip()):
            parsed = sqlparse.parse(statement)
            looks_like_comment = (
                len(parsed) == 1 and parsed[0].get_type() == 'UNKNOWN'
            )
            if looks_like_comment:
                continue
            log.debug(statement)
            connection.execute(statement)
Example #46
0
    def execute(self, queries, callback, stream=False):
        """Run *queries* through the command-line client.

        The configured 'before' statements are prepended. When
        ``self.safe_limit`` is set, every SELECT statement without a
        top-level LIMIT keyword gets ``LIMIT <safe_limit>;`` appended. The
        combined text is logged, added to the history (if any) and handed to
        ``self.Command.createAndRun`` with *callback*.

        :param queries: a SQL string or a list of SQL strings.
        """
        queryToRun = ''

        for query in self.getOptionsForSgdbCli()['before']:
            queryToRun += query + "\n"

        if isinstance(queries, str):
            queries = [queries]

        for rawQuery in queries:
            for query in sqlparse.split(rawQuery):
                if self.safe_limit:
                    # Quotes are swapped only for parsing; the query text
                    # itself is not changed by this step.
                    parsedTokens = sqlparse.parse(query.strip().replace(
                        "'", "\""))
                    first = (parsedTokens[0].tokens[0]
                             if parsedTokens and parsedTokens[0].tokens
                             else None)
                    # Token values keep the user's casing, so compare
                    # case-insensitively: "SELECT ... LIMIT" must match too.
                    if (first is not None
                            and first.ttype in sqlparse.tokens.Keyword
                            and first.value.lower() == 'select'):
                        hasLimit = any(
                            token.ttype in sqlparse.tokens.Keyword
                            and token.value.lower() == 'limit'
                            for parse in parsedTokens
                            for token in parse.tokens)
                        if not hasLimit:
                            if query.strip()[-1:] == ';':
                                query = query.strip()[:-1]
                            query += " LIMIT {0};".format(self.safe_limit)
                queryToRun += query + "\n"

        Log("Query: " + queryToRun)

        if self.history:
            self.history.add(queryToRun)

        self.Command.createAndRun(self.builArgs(),
                                  queryToRun,
                                  callback,
                                  options={'show_query': self.show_query},
                                  timeout=self.timeout,
                                  stream=stream)
Example #47
0
 def query_check(self, db_name=None, sql=''):
     """Check a read-only query: strip comments, split, and screen it.

     Only the first statement is kept; it is stored, stripped, in
     ``filtered_sql``. The statement must start with ``select``. It is
     flagged as ``bad_query`` if it uses ``*``, ``+`` or any function in
     ``banned_keywords``, and ``msg`` then lists every violation. Returns a
     dict with keys ``msg``, ``bad_query``, ``filtered_sql`` and
     ``has_star``.
     """
     result = {'msg': '', 'bad_query': False, 'filtered_sql': sql, 'has_star': False}
     banned_keywords = ["ascii", "char", "charindex", "concat", "concat_ws", "difference", "format",
                        "len", "nchar", "patindex", "quotename", "replace", "replicate",
                        "reverse", "right", "soundex", "space", "str", "string_agg",
                        "string_escape", "string_split", "stuff", "substring", "trim", "unicode"]
     keyword_warning = ''
     star_pattern = r"(^|,| )\*( |\(|$)"
     # Strip comments, then check and execute only the first valid statement.
     try:
         # sqlparse.format, not str.format: the old ``sql.format(...)`` never
         # stripped comments and could raise on SQL containing braces.
         sql = sqlparse.format(sql, strip_comments=True)
         sql = sqlparse.split(sql)[0]
         result['filtered_sql'] = sql.strip()
         sql_lower = sql.lower()
     except IndexError:
         # No statement left: a bad query (previously mis-flagged as has_star).
         result['bad_query'] = True
         result['msg'] = '没有有效的SQL语句'
         return result
     if re.match(r"^select", sql_lower) is None:
         result['bad_query'] = True
         result['msg'] = '仅支持^select语法!'
         return result
     if re.search(star_pattern, sql_lower) is not None:
         keyword_warning += '禁止使用 * 关键词\n'
         result['bad_query'] = True
         result['has_star'] = True
     if '+' in sql_lower:
         keyword_warning += '禁止使用 + 关键词\n'
         result['bad_query'] = True
     for keyword in banned_keywords:
         pattern = r"(^|,| |=){}( |\(|$)".format(keyword)
         if re.search(pattern, sql_lower) is not None:
             keyword_warning += '禁止使用 {} 关键词\n'.format(keyword)
             result['bad_query'] = True
     if result.get('bad_query'):
         result['msg'] = keyword_warning
     return result
Example #48
0
def explain(request):
    """Run the EXPLAIN statement posted by the client and return JSON.

    Reads ``sql_content``, ``instance_name`` and ``db_name`` from POST.
    Responds with ``status`` 1 and an error ``msg`` in two cases:
    ``sql_content`` or ``instance_name`` is missing, or the SQL does not
    start with "explain" (any case). Otherwise the first statement, with
    trailing semicolons removed, is run via ``Dao(...).mysql_query`` and its
    result is returned under ``data``.
    """
    sql_content = request.POST.get('sql_content')
    instance_name = request.POST.get('instance_name')
    db_name = request.POST.get('db_name')
    result = {'status': 0, 'msg': 'ok', 'data': []}

    # Server-side validation: both parameters are required.
    missing_params = sql_content is None or instance_name is None
    if missing_params:
        result['status'] = 1
        result['msg'] = '页面提交参数可能为空'
        return HttpResponse(json.dumps(result),
                            content_type='application/json')

    sql_content = sql_content.strip()

    # Reject anything that is not an EXPLAIN statement.
    if not re.match(r"^explain", sql_content.lower()):
        result['status'] = 1
        result['msg'] = '仅支持explain开头的语句,请检查'
        return HttpResponse(json.dumps(result),
                            content_type='application/json')

    # Only the first statement is executed.
    first_statement = sqlparse.split(sql_content)[0].rstrip(';')
    dao = Dao(instance_name=instance_name)
    result['data'] = dao.mysql_query(str(db_name), first_statement)

    payload = json.dumps(result,
                         cls=ExtendJSONEncoder,
                         bigint_as_string=True)
    return HttpResponse(payload, content_type='application/json')
Example #49
0
    def add_query(self,
                  sql,
                  model_name=None,
                  auto_begin=True,
                  bindings=None,
                  abridge_sql_log=False):
        """Split *sql* into statements and run each through the parent adapter.

        Statements that are empty once ``--`` comments are removed are
        skipped. The statements themselves are sent unchanged.

        :return: ``(connection, cursor)`` of the last executed statement.
        :raises RuntimeException: if no statement was executed at all.
        """
        connection = None
        cursor = None

        # TODO: is this sufficient? Largely copy+pasted from snowflake, so
        # there's some common behavior here we can maybe factor out into the
        # SQLAdapter?
        queries = [q.rstrip(';') for q in sqlparse.split(sql)]

        for individual_query in queries:
            # Remove ``--`` comments only to decide whether anything is left
            # to run. Cutting from ``--`` to end of line keeps code before a
            # trailing comment; the old '^.*(--.*)$' pattern deleted whole
            # lines, so "select 1 -- note" was wrongly skipped as empty.
            without_comments = re.sub(r'--.*$', '', individual_query,
                                      flags=re.MULTILINE).strip()

            if without_comments == "":
                continue

            parent = super(HiveConnectionManager, self)
            connection, cursor = parent.add_query(individual_query, model_name,
                                                  auto_begin, bindings,
                                                  abridge_sql_log)

        if cursor is None:
            raise RuntimeException(
                "Tried to run an empty query on model '{}'. If you are "
                "conditionally running\nsql, eg. in a model hook, make "
                "sure your `else` clause contains valid sql!\n\n"
                "Provided SQL:\n{}".format(model_name, sql))

        # Must sit outside the ``if`` above: nested under the raise it was
        # unreachable and the method always returned None.
        return connection, cursor
Example #50
0
 def query_check(self, db_name=None, sql=''):
     """Check a read-only query: strip comments and keep the first statement.

     The first statement is stored, stripped, in ``filtered_sql``. If no
     statement remains, ``bad_query`` is set with the "no valid SQL"
     message and the result is returned immediately. Otherwise
     ``bad_query`` is set for statements not starting with ``select`` (any
     case), and ``has_star`` for statements containing ``*``.
     """
     result = {
         'msg': '',
         'bad_query': False,
         'filtered_sql': sql,
         'has_star': False
     }
     # Strip comments, then check and execute only the first valid statement.
     try:
         sql = sqlparse.format(sql, strip_comments=True)
         sql = sqlparse.split(sql)[0]
         result['filtered_sql'] = sql.strip()
     except IndexError:
         result['bad_query'] = True
         result['msg'] = '没有有效的SQL语句'
         # Nothing left to check; returning keeps this message from being
         # overwritten by the statement-type check below.
         return result
     if re.match(r"^select", sql, re.I) is None:
         result['bad_query'] = True
         result['msg'] = '不支持的查询语法类型!'
     if '*' in sql:
         result['has_star'] = True
         result['msg'] = 'SQL语句中含有 * '
     return result
Example #51
0
def execute_sql(connection, code_to_execute, schema=None):
    """Execute every statement of *code_to_execute* in one transaction.

    When *schema* is given, ``set search_path=<schema>;`` is prefixed to each
    statement. Every statement is printed before it runs. Any exception,
    including KeyboardInterrupt, rolls the transaction back and is
    re-raised. Otherwise the transaction is committed and the count of
    executed statements is printed.
    """
    statements = sqlparse.split(code_to_execute)
    prefix = "" if schema is None else "set search_path=%s;" % schema

    trans = connection.begin()
    try:
        executed = 0
        for statement in statements:
            print(statement)
            connection.execute(prefix + statement)
            executed += 1
    except BaseException:
        trans.rollback()
        raise

    trans.commit()
    print("Executed %s statements" % executed)
Example #52
0
def init_database():
    """Execute every statement of the module-level ``sql`` script on MySQL.

    ``sql`` comes from the enclosing module; it presumably is an open file
    object (only ``.read()`` is used) — verify. Connection parameters come
    from ``config['database']``. All statements run on one connection and
    are committed together at the end. The listed pymysql errors are printed
    and re-raised. The cursor and connection are closed in every case.
    """
    db = None
    cursor = None
    try:
        db = pymysql.connect(host=config['database']['host'],
                             user=config['database']['user'],
                             password=config['database']['password'],
                             port=config['database']['port'],
                             cursorclass=pymysql.cursors.DictCursor)

        cursor = db.cursor()
        stmts = sqlparse.split(sql.read())

        for stmt in stmts:
            cursor.execute(stmt)
        db.commit()
        print('All queries were executed successfully')
    except (pymysql.err.OperationalError, pymysql.ProgrammingError,
            pymysql.InternalError, pymysql.IntegrityError) as error:
        print('Exception number: {}, value {!r}'.format(error.args[0], error))
        raise
    finally:
        # Previously the cursor/connection were closed only on success.
        if cursor is not None:
            cursor.close()
        if db is not None:
            db.close()
Example #53
0
def run(conn, sql, config, user_namespace):
    """Execute *sql* statement by statement on *conn*; return the last result.

    Blank *sql* returns a ``'Connected: <name>'`` string. An input whose
    first word is ``begin`` raises, because transactions are not supported.
    A ``commit`` is issued after every statement unless the dialect is
    mssql; OperationalError from that commit is ignored. With
    ``config.feedback`` the row count is printed. The last statement's
    result is returned as a ResultSet, or as its DataFrame when
    ``config.autopandas`` is set.
    """
    if not sql.strip():
        return 'Connected: %s' % conn.name

    for statement in sqlparse.split(sql):
        # Looks at the whole input, not just this statement.
        if sql.strip().split()[0].lower() == 'begin':
            raise Exception("ipython_sql does not support transactions")
        clause = sqlalchemy.sql.text(statement)
        result = conn.session.execute(clause, user_namespace)
        try:
            # mssql has autocommit
            if 'mssql' not in str(conn.dialect):
                conn.session.execute('commit')
        except sqlalchemy.exc.OperationalError:
            pass  # not all engines can commit
        if result and config.feedback:
            print(interpret_rowcount(result.rowcount))

    # Only the last statement's result is returned, intentionally.
    resultset = ResultSet(result, statement, config)
    return resultset.DataFrame() if config.autopandas else resultset
Example #54
0
 def from_file(filepath):
     """Read in a RawUdf from a SQL file on disk.

     The UDF name is ``<parent directory>_<file name without ".sql">``.
     Comments are stripped. Statements starting with
     ``create temp function`` (any case) are definitions; all others are
     tests. Dependencies are the ``UDF_RE`` matches in the definitions,
     excluding the UDF itself.

     :raises ValueError: if ``name`` is not among the ``UDF_RE`` matches of
         the definitions.
     """
     dirpath, basename = os.path.split(filepath)
     name = os.path.basename(dirpath) + "_" + basename.replace(".sql", "")
     with open(filepath) as f:
         text = f.read()
     sql = sqlparse.format(text, strip_comments=True)
     statements = [s for s in sqlparse.split(sql) if s.strip()]
     definitions = []
     tests = []
     for s in statements:
         if s.lower().startswith("create temp function"):
             definitions.append(s)
         else:
             tests.append(s)
     dependencies = re.findall(UDF_RE, "\n".join(definitions))
     if name not in dependencies:
         raise ValueError(
             "Expected a temporary UDF named {} to be defined in {}".format(
                 name, filepath))
     # Drop every occurrence of the UDF itself. list.remove() dropped only
     # the first, leaving a self-dependency when the name matched twice.
     return RawUdf(name, filepath, definitions, tests,
                   set(dependencies) - {name})
Example #55
0
    def post(self, request):
        """Beautify the posted ``sql_content`` and return it as JSON.

        Each statement is split off with sqlparse. A leading comment is set
        aside and re-attached after formatting. DML statements are
        re-indented with upper-case keywords. DDL statements only get
        upper-case keywords. Statements whose first token is neither are
        dropped. The results are joined by blank lines and returned under
        ``data``. If formatting raises, an error payload with ``errCode``
        400 is returned instead.
        """
        data = format_request(request)
        sqlContent = data.get('sql_content').strip()

        sqlSplit = []
        for stmt in sqlparse.split(sqlContent):
            sql = sqlparse.parse(stmt)[0]
            sql_comment = sql.token_first()
            if isinstance(sql_comment, sqlparse.sql.Comment):
                sqlSplit.append({
                    'comment': sql_comment.value,
                    'sql': sql.value.replace(sql_comment.value, '')
                })
            else:
                sqlSplit.append({'comment': '', 'sql': sql.value})

        beautifySQL_list = []
        try:
            for row in sqlSplit:
                comment = row['comment']
                sql = row['sql']
                res = sqlparse.parse(sql)
                # e.g. ('Keyword', 'DML') -> 'DML'; raises for other shapes.
                stmt_kind = res[0].tokens[0].ttype[1]
                if stmt_kind == 'DML':
                    sqlFormat = sqlparse.format(sql,
                                                keyword_case='upper',
                                                reindent=True)
                    beautifySQL_list.append(comment + sqlFormat)
                elif stmt_kind == 'DDL':
                    sqlFormat = sqlparse.format(sql, keyword_case='upper')
                    beautifySQL_list.append(comment + sqlFormat)
            beautifySQL = '\n\n'.join(beautifySQL_list)
            context = {'data': beautifySQL}
        except Exception:
            # The old ``raise OSError(err)`` made this fallback unreachable,
            # turning every formatting failure into a server error.
            context = {'errCode': 400, 'errMsg': "注释不合法, 请检查"}

        return HttpResponse(json.dumps(context))
Example #56
0
    def on_action_execute(self, widget):
        """Execute every statement in the current editor tab.

        Steps:

        1. Clear the log model and remove log notebook pages, iterating
           indices from ``get_n_pages()`` down to 1.
        2. Split the editor text with sqlparse and format each statement
           (re-indented, upper-case keywords, comments stripped).
        3. Pass every non-empty result to ``self.insert_result``.

        sqlite3.OperationalError and AttributeError are printed rather than
        raised. The log box is shown in every case.
        """
        try:
            self.logger.get_model().clear()
            # NOTE(review): index get_n_pages() is past the last page and
            # page 0 is never removed — confirm page 0 is meant to stay.
            for page_index in range(self.notebookLog.get_n_pages(), 0, -1):
                self.notebookLog.remove_page(page_index)

            editor_page = self.notebookEditor.get_nth_page(
                self.notebookEditor.get_current_page())
            _editor_text = editor_page.get_tooltip_text()  # currently unused

            source_view = editor_page.get_children()[0]
            text_buffer = source_view.get_buffer()
            text = text_buffer.get_text(text_buffer.get_start_iter(),
                                        text_buffer.get_end_iter(),
                                        True)
            text = text.rstrip('\n')

            formatted = [
                sqlparse.format(piece, reindent=True, keyword_case='upper',
                                strip_comments=True)
                for piece in sqlparse.split(text)
            ]
            for statement in filter(None, formatted):
                self.insert_result(statement)
        except sqlite3.OperationalError as error:
            print(error)
        except AttributeError as error:
            print(error)
        finally:
            self.boxNotebookLog.show_all()
Example #57
0
    def execute(self, sql, **kwargs):
        """
        Execute one or more DDL/DML SQL statements in a single transaction.

        Args:
            sql (string): SQL statement(s) separated by semicolons (;)
            kwargs (dict): optional named parameters for every statement
        Returns:
            results (list): one dataframe per statement. Statements that do
                not return rows give an empty dataframe.
        Raises:
            DatabaseError: logged and re-raised after a rollback.
        """
        frames = []
        statements = sqlparse.split(sql)
        connection = self.engine.connect()
        trans = connection.begin()
        try:
            for statement in statements:
                outcome = connection.execute(text(statement.strip(';')),
                                             **kwargs)
                if outcome.returns_rows:
                    frame = pd.DataFrame(outcome.fetchall())
                    frame.columns = outcome.keys()
                    LOGGER.info('Number of returned rows: %s',
                                str(len(frame.index)))
                else:
                    frame = pd.DataFrame()
                frames.append(frame)
            trans.commit()
        except DatabaseError as db_error:
            trans.rollback()
            LOGGER.error(db_error)
            raise
        finally:
            connection.close()
        return frames
Example #58
0
 def query_check(self, db_name=None, sql=''):
     """Check a read-only query: strip comments and keep the first statement.

     The first statement is stored, stripped, in ``filtered_sql``. If no
     statement remains, ``bad_query`` is set with the "no valid SQL"
     message and the result is returned immediately. Otherwise:

     - statements not starting with select/show/explain (any case) are
       marked ``bad_query``;
     - ``*`` sets ``has_star``;
     - SELECTs are pre-checked with ``EXPLAIN`` via ``self.query``, and any
       error makes them ``bad_query`` with the error as ``msg``.
     """
     result = {'msg': '', 'bad_query': False, 'filtered_sql': sql, 'has_star': False}
     # Strip comments, then check and execute only the first valid statement.
     try:
         sql = sqlparse.format(sql, strip_comments=True)
         sql = sqlparse.split(sql)[0]
         result['filtered_sql'] = sql.strip()
     except IndexError:
         result['bad_query'] = True
         result['msg'] = '没有有效的SQL语句'
         # Nothing left to check; returning keeps this message from being
         # overwritten by the statement-type check below.
         return result
     if re.match(r"^select|^show|^explain", sql, re.I) is None:
         result['bad_query'] = True
         result['msg'] = '不支持的查询语法类型!'
     if '*' in sql:
         result['has_star'] = True
         result['msg'] = 'SQL语句中含有 * '
     # Validate SELECT syntax with EXPLAIN before it is executed.
     if re.match(r"^select", sql, re.I):
         explain_result = self.query(db_name=db_name, sql=f"explain {sql}")
         if explain_result.error:
             result['bad_query'] = True
             result['msg'] = explain_result.error
     return result
Example #59
0
 def query_check(self, db_name=None, sql=''):
     """Check a read-only query against a statement-type whitelist.

     Comments are stripped and only the first statement is kept; it is
     stored, stripped, in ``filtered_sql``. If no statement remains, or the
     statement does not start with an entry of ``sql_whitelist`` (any case),
     ``bad_query`` is set and ``msg`` explains why.
     """
     result = {'msg': '', 'bad_query': False, 'filtered_sql': sql, 'has_star': False}
     sql_whitelist = ['select', 'explain']
     # Build "^select|^explain" from the whitelist.
     whitelist_pattern = "^" + "|^".join(sql_whitelist)
     # Strip comments, then check and execute only the first valid statement.
     try:
         # sqlparse.format, not str.format: the old ``sql.format(...)`` never
         # stripped comments and could raise on SQL containing braces.
         sql = sqlparse.format(sql, strip_comments=True)
         sql = sqlparse.split(sql)[0]
         result['filtered_sql'] = sql.strip()
     except IndexError:
         # No statement left: a bad query (previously mis-flagged as has_star).
         result['bad_query'] = True
         result['msg'] = '没有有效的SQL语句'
         return result
     # Case-insensitive, so "SELECT" passes just like "select".
     if re.match(whitelist_pattern, sql, re.I) is None:
         result['bad_query'] = True
         result['msg'] = '仅支持{}语法!'.format(','.join(sql_whitelist))
     return result
Example #60
0
    def run(self, statement):
        """Generator: execute the SQL in *statement* and yield its results.

        Results are 4-tuples ``(title, rows, headers, status)``. Empty input
        yields ``(None, None, None, None)``. Input starting with ``\\fs``
        (saving a favorite query) is kept whole; anything else is split into
        separate queries. A query ending in ``\\G`` turns on expanded output
        and runs as normal SQL. Every other query is first tried as a special
        command and falls back to normal SQL.
        """
        statement = statement.strip()
        if not statement:
            yield (None, None, None, None)

        # A favorite query is saved in one piece, never split.
        if statement.startswith('\\fs'):
            queries = [statement]
        else:
            queries = sqlparse.split(statement)

        for query in queries:
            query = query.rstrip(';')

            # \G: switch on expanded output, then run the SQL before it.
            if query.endswith('\\G'):
                special.set_expanded_output(True)
                yield self.execute_normal_sql(query.rsplit('\\G', 1)[0])
                continue

            try:
                _logger.debug('Trying a dbspecial command. sql: %r', query)
                cursor = self.conn.cursor()
                for outcome in special.execute(cursor, query):
                    yield outcome
            except special.CommandNotFound:
                yield self.execute_normal_sql(query)