Example #1
    def test_create_table_as(self):
        q = ParsedQuery('SELECT * FROM outer_space;')

        self.assertEqual(
            'CREATE TABLE tmp AS \nSELECT * FROM outer_space',
            q.as_create_table('tmp'))

        self.assertEqual(
            'DROP TABLE IF EXISTS tmp;\n'
            'CREATE TABLE tmp AS \nSELECT * FROM outer_space',
            q.as_create_table('tmp', overwrite=True))

        # now without a semicolon
        q = ParsedQuery('SELECT * FROM outer_space')
        self.assertEqual(
            'CREATE TABLE tmp AS \nSELECT * FROM outer_space',
            q.as_create_table('tmp'))

        # now a multi-line query
        multi_line_query = (
            'SELECT * FROM planets WHERE\n'
            "Luke_Father = 'Darth Vader'")
        q = ParsedQuery(multi_line_query)
        self.assertEqual(
            'CREATE TABLE tmp AS \nSELECT * FROM planets WHERE\n'
            "Luke_Father = 'Darth Vader'",
            q.as_create_table('tmp'),
        )
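
Later examples (#8 and #14) pass schema_name and method keyword arguments to as_create_table. The sketch below extends Example #1 to that variant; the CtasMethod import path and the exact output string are assumptions rather than verified behaviour for every Superset version.

# Hedged sketch of the as_create_table(..., schema_name=..., method=...) variant
# used in Examples #8 and #14; import path and expected output are assumptions.
from superset.sql_parse import CtasMethod, ParsedQuery

q = ParsedQuery('SELECT * FROM outer_space')
print(q.as_create_table('tmp', schema_name='scratch', method=CtasMethod.VIEW))
# expected (assumption): CREATE VIEW scratch.tmp AS \nSELECT * FROM outer_space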
Example #2
def execute_sql_statements(
    ctask, query_id, rendered_query, return_results=True, store_results=False,
    user_name=None, session=None, start_time=None,
):
    """Executes the sql query returns the results."""
    if store_results and start_time:
        # only asynchronous queries
        stats_logger.timing(
            'sqllab.query.time_pending', now_as_float() - start_time)

    query = get_query(query_id, session)
    payload = dict(query_id=query_id)
    database = query.database
    db_engine_spec = database.db_engine_spec
    db_engine_spec.patch()

    if store_results and not results_backend:
        raise SqlLabException("Results backend isn't configured.")

    # Breaking down into multiple statements
    parsed_query = ParsedQuery(rendered_query)
    statements = parsed_query.get_statements()
    logging.info(f'Executing {len(statements)} statement(s)')

    logging.info("Set query to 'running'")
    query.status = QueryStatus.RUNNING
    query.start_running_time = now_as_float()

    engine = database.get_sqla_engine(
        schema=query.schema,
        nullpool=True,
        user_name=user_name,
        source=sources.get('sql_lab', None),
    )
    # Sharing a single connection and cursor across the
    # execution of all statements (if many)
    with closing(engine.raw_connection()) as conn:
        with closing(conn.cursor()) as cursor:
            statement_count = len(statements)
            for i, statement in enumerate(statements):
                # TODO CHECK IF STOPPED
                msg = f'Running statement {i+1} out of {statement_count}'
                logging.info(msg)
                query.set_extra_json_key('progress', msg)
                session.commit()
                is_last_statement = i == len(statements) - 1
                try:
                    cdf = execute_sql_statement(
                        statement, query, user_name, session, cursor,
                        return_results=is_last_statement and return_results)
                    msg = f'Running statement {i+1} out of {statement_count}'
                except Exception as e:
                    msg = str(e)
                    if statement_count > 1:
                        msg = f'[Statement {i+1} out of {statement_count}] ' + msg
                    payload = handle_query_error(msg, query, session, payload)
                    return payload

    # Success, updating the query entry in database
    query.rows = cdf.size
    query.progress = 100
    query.set_extra_json_key('progress', None)
    query.status = QueryStatus.SUCCESS
    if query.select_as_cta:
        query.select_sql = database.select_star(
            query.tmp_table_name,
            limit=query.limit,
            schema=database.force_ctas_schema,
            show_cols=False,
            latest_partition=False)
    query.end_time = now_as_float()
    session.commit()

    payload.update({
        'status': query.status,
        'data': cdf.data if cdf.data else [],
        'columns': cdf.columns if cdf.columns else [],
        'query': query.to_dict(),
    })

    if store_results:
        key = str(uuid.uuid4())
        logging.info(f'Storing results in results backend, key: {key}')
        with stats_timing('sqllab.query.results_backend_write', stats_logger):
            json_payload = json.dumps(
                payload, default=json_iso_dttm_ser, ignore_nan=True)
            cache_timeout = database.cache_timeout
            if cache_timeout is None:
                cache_timeout = config.get('CACHE_DEFAULT_TIMEOUT', 0)
            results_backend.set(key, zlib_compress(json_payload), cache_timeout)
        query.results_key = key
    session.commit()

    if return_results:
        return payload
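
Example #2 splits the rendered SQL with ParsedQuery.get_statements() before running each piece over a shared cursor. A minimal sketch of that splitting step, assuming the superset.sql_parse import path and that trailing semicolons are stripped from each returned statement:

# Minimal sketch of the statement splitting used above; the import path and
# the semicolon-stripping behaviour are assumptions.
from superset.sql_parse import ParsedQuery

parsed_query = ParsedQuery('SELECT 1;\nSELECT 2;')
for statement in parsed_query.get_statements():
    print(statement)  # expected (assumption): 'SELECT 1', then 'SELECT 2'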
    def apply_limit_to_sql(cls, sql: str, limit: int,
                           database: "Database") -> str:
        new_sql = ParsedQuery(sql).set_alias()
        return super().apply_limit_to_sql(new_sql, limit, database)
Example #4
def execute_sql_statement(
        sql_statement, query, user_name, session,
        cursor, return_results=False):
    """Executes a single SQL statement"""
    database = query.database
    db_engine_spec = database.db_engine_spec
    parsed_query = ParsedQuery(sql_statement)
    sql = parsed_query.stripped()
    SQL_MAX_ROWS = app.config.get('SQL_MAX_ROW')

    if not parsed_query.is_readonly() and not database.allow_dml:
        raise SqlLabSecurityException(
            _('Only `SELECT` statements are allowed against this database'))
    if query.select_as_cta:
        if not parsed_query.is_select():
            raise SqlLabException(_(
                'Only `SELECT` statements can be used with the CREATE TABLE '
                'feature.'))
        if not query.tmp_table_name:
            start_dttm = datetime.fromtimestamp(query.start_time)
            query.tmp_table_name = 'tmp_{}_table_{}'.format(
                query.user_id, start_dttm.strftime('%Y_%m_%d_%H_%M_%S'))
        sql = parsed_query.as_create_table(query.tmp_table_name)
        query.select_as_cta_used = True
    if parsed_query.is_select():
        if SQL_MAX_ROWS and (not query.limit or query.limit > SQL_MAX_ROWS):
            query.limit = SQL_MAX_ROWS
        if query.limit:
            sql = database.apply_limit_to_sql(sql, query.limit)

    # Hook to allow environment-specific mutation (usually comments) to the SQL
    SQL_QUERY_MUTATOR = config.get('SQL_QUERY_MUTATOR')
    if SQL_QUERY_MUTATOR:
        sql = SQL_QUERY_MUTATOR(sql, user_name, security_manager, database)

    try:
        if log_query:
            log_query(
                query.database.sqlalchemy_uri,
                query.executed_sql,
                query.schema,
                user_name,
                __name__,
                security_manager,
            )
        query.executed_sql = sql
        with stats_timing('sqllab.query.time_executing_query', stats_logger):
            logging.info('Running query: \n{}'.format(sql))
            db_engine_spec.execute(cursor, sql, async_=True)
            logging.info('Handling cursor')
            db_engine_spec.handle_cursor(cursor, query, session)

        with stats_timing('sqllab.query.time_fetching_results', stats_logger):
            logging.debug('Fetching data for query object: {}'.format(query.to_dict()))
            data = db_engine_spec.fetch_data(cursor, query.limit)

    except SoftTimeLimitExceeded as e:
        logging.exception(e)
        raise SqlLabTimeoutException(
            "SQL Lab timeout. This environment's policy is to kill queries "
            'after {} seconds.'.format(SQLLAB_TIMEOUT))
    except Exception as e:
        logging.exception(e)
        raise SqlLabException(db_engine_spec.extract_error_message(e))

    logging.debug('Fetching cursor description')
    cursor_description = cursor.description
    return dataframe.SupersetDataFrame(data, cursor_description, db_engine_spec)
Example #5
    def validate_statement(
        cls, statement: str, database: Database, cursor: Any, user_name: str
    ) -> Optional[SQLValidationAnnotation]:
        # pylint: disable=too-many-locals
        db_engine_spec = database.db_engine_spec
        parsed_query = ParsedQuery(statement)
        sql = parsed_query.stripped()

        # Hook to allow environment-specific mutation (usually comments) to the SQL
        # pylint: disable=invalid-name
        SQL_QUERY_MUTATOR = config["SQL_QUERY_MUTATOR"]
        if SQL_QUERY_MUTATOR:
            sql = SQL_QUERY_MUTATOR(sql, user_name, security_manager, database)

        # Transform the final statement to an explain call before sending it on
        # to presto to validate
        sql = f"EXPLAIN (TYPE VALIDATE) {sql}"

        # Invoke the query against presto. NB this deliberately doesn't use the
        # engine spec's handle_cursor implementation since we don't record
        # these EXPLAIN queries done in validation as proper Query objects
        # in the superset ORM.
        from pyhive.exc import DatabaseError

        try:
            db_engine_spec.execute(cursor, sql)
            polled = cursor.poll()
            while polled:
                logger.info("polling presto for validation progress")
                stats = polled.get("stats", {})
                if stats:
                    state = stats.get("state")
                    if state == "FINISHED":
                        break
                time.sleep(0.2)
                polled = cursor.poll()
            db_engine_spec.fetch_data(cursor, MAX_ERROR_ROWS)
            return None
        except DatabaseError as db_error:
            # The pyhive presto client yields EXPLAIN (TYPE VALIDATE) responses
            # as though they were normal queries. In other words, it doesn't
            # know that errors here are not exceptional. To map this back to
            # ordinary control flow, we have to trap the category of exception
            # raised by the underlying client, match the exception arguments
            # pyhive provides against the shape of dictionary for a presto query
            # invalid error, and restructure that error as an annotation we can
            # return up.

            # If the first element in the DatabaseError is not a dictionary, but
            # is a string, return that message.
            if db_error.args and isinstance(db_error.args[0], str):
                raise PrestoSQLValidationError(db_error.args[0]) from db_error

            # Confirm the first element in the DatabaseError constructor is a
            # dictionary with error information. This is currently provided by
            # the pyhive client, but may break if their interface changes when
            # we update at some point in the future.
            if not db_error.args or not isinstance(db_error.args[0], dict):
                raise PrestoSQLValidationError(
                    "The pyhive presto client returned an unhandled " "database error."
                ) from db_error
            error_args: Dict[str, Any] = db_error.args[0]

            # Confirm the two fields we need to be able to present an annotation
            # are present in the error response -- a message, and a location.
            if "message" not in error_args:
                raise PrestoSQLValidationError(
                    "The pyhive presto client did not report an error message"
                ) from db_error
            if "errorLocation" not in error_args:
                # Pylint is confused about the type of error_args, despite the hints
                # and checks above.
                # pylint: disable=invalid-sequence-index
                message = error_args["message"] + "\n(Error location unknown)"
                # If we have a message but no error location, return the message and
                # set the location as the beginning.
                return SQLValidationAnnotation(
                    message=message, line_number=1, start_column=1, end_column=1
                )

            # pylint: disable=invalid-sequence-index
            message = error_args["message"]
            err_loc = error_args["errorLocation"]
            line_number = err_loc.get("lineNumber", None)
            start_column = err_loc.get("columnNumber", None)
            end_column = err_loc.get("columnNumber", None)

            return SQLValidationAnnotation(
                message=message,
                line_number=line_number,
                start_column=start_column,
                end_column=end_column,
            )
        except Exception as ex:
            logger.exception(f"Unexpected error running validation query: {ex}")
            raise ex
Example #6
def is_readonly(sql: str) -> bool:
    return PrestoEngineSpec.is_readonly_query(ParsedQuery(sql))
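
A short usage sketch for the helper in Example #6, consistent with the read-only behaviour shown by the is_select/is_explain/is_readonly tests later on this page; the engine-spec import path is an assumption:

# Hedged sketch: import paths are assumed; the expected outcomes mirror the
# UPDATE/EXPLAIN tests further down this page.
from superset.db_engine_specs.presto import PrestoEngineSpec
from superset.sql_parse import ParsedQuery

def is_readonly(sql: str) -> bool:
    return PrestoEngineSpec.is_readonly_query(ParsedQuery(sql))

assert is_readonly("SELECT * FROM birth_names")      # plain SELECT
assert is_readonly("EXPLAIN SELECT 1")               # EXPLAIN is read-only
assert not is_readonly("UPDATE t1 SET col1 = NULL")  # DML is not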
Example #7
def execute_sql_statements(
    query_id,
    rendered_query,
    return_results=True,
    store_results=False,
    user_name=None,
    session=None,
    start_time=None,
    expand_data=False,
    log_params=None,
):  # pylint: disable=too-many-arguments, too-many-locals, too-many-statements
    """Executes the sql query returns the results."""
    if store_results and start_time:
        # only asynchronous queries
        stats_logger.timing("sqllab.query.time_pending", now_as_float() - start_time)

    query = get_query(query_id, session)
    payload = dict(query_id=query_id)
    database = query.database
    db_engine_spec = database.db_engine_spec
    db_engine_spec.patch()

    if database.allow_run_async and not results_backend:
        raise SqlLabException("Results backend isn't configured.")

    # Breaking down into multiple statements
    parsed_query = ParsedQuery(rendered_query)
    statements = parsed_query.get_statements()
    logger.info(f"Query {query_id}: Executing {len(statements)} statement(s)")

    logger.info(f"Query {query_id}: Set query to 'running'")
    query.status = QueryStatus.RUNNING
    query.start_running_time = now_as_float()
    session.commit()

    engine = database.get_sqla_engine(
        schema=query.schema,
        nullpool=True,
        user_name=user_name,
        source=QuerySource.SQL_LAB,
    )
    # Sharing a single connection and cursor across the
    # execution of all statements (if many)
    with closing(engine.raw_connection()) as conn:
        with closing(conn.cursor()) as cursor:
            statement_count = len(statements)
            for i, statement in enumerate(statements):
                # Check if stopped
                query = get_query(query_id, session)
                if query.status == QueryStatus.STOPPED:
                    return None

                # Run statement
                msg = f"Running statement {i+1} out of {statement_count}"
                logger.info(f"Query {query_id}: {msg}")
                query.set_extra_json_key("progress", msg)
                session.commit()
                try:
                    result_set = execute_sql_statement(
                        statement, query, user_name, session, cursor, log_params
                    )
                except Exception as e:  # pylint: disable=broad-except
                    msg = str(e)
                    if statement_count > 1:
                        msg = f"[Statement {i+1} out of {statement_count}] " + msg
                    payload = handle_query_error(msg, query, session, payload)
                    return payload

        # Commit the connection so CTA queries will create the table.
        conn.commit()

    # Success, updating the query entry in database
    query.rows = result_set.size
    query.progress = 100
    query.set_extra_json_key("progress", None)
    if query.select_as_cta:
        query.select_sql = database.select_star(
            query.tmp_table_name,
            schema=query.tmp_schema_name,
            limit=query.limit,
            show_cols=False,
            latest_partition=False,
        )
    query.end_time = now_as_float()

    use_arrow_data = store_results and results_backend_use_msgpack
    data, selected_columns, all_columns, expanded_columns = _serialize_and_expand_data(
        result_set, db_engine_spec, use_arrow_data, expand_data
    )

    # TODO: data should be saved separately from metadata (likely in Parquet)
    payload.update(
        {
            "status": QueryStatus.SUCCESS,
            "data": data,
            "columns": all_columns,
            "selected_columns": selected_columns,
            "expanded_columns": expanded_columns,
            "query": query.to_dict(),
        }
    )
    payload["query"]["state"] = QueryStatus.SUCCESS

    if store_results and results_backend:
        key = str(uuid.uuid4())
        logger.info(f"Query {query_id}: Storing results in results backend, key: {key}")
        with stats_timing("sqllab.query.results_backend_write", stats_logger):
            with stats_timing(
                "sqllab.query.results_backend_write_serialization", stats_logger
            ):
                serialized_payload = _serialize_payload(
                    payload, results_backend_use_msgpack
                )
            cache_timeout = database.cache_timeout
            if cache_timeout is None:
                cache_timeout = config["CACHE_DEFAULT_TIMEOUT"]

            compressed = zlib_compress(serialized_payload)
            logger.debug(
                f"*** serialized payload size: {getsizeof(serialized_payload)}"
            )
            logger.debug(f"*** compressed payload size: {getsizeof(compressed)}")
            results_backend.set(key, compressed, cache_timeout)
        query.results_key = key

    query.status = QueryStatus.SUCCESS
    session.commit()

    if return_results:
        # since we're returning results we need to create non-arrow data
        if use_arrow_data:
            (
                data,
                selected_columns,
                all_columns,
                expanded_columns,
            ) = _serialize_and_expand_data(
                result_set, db_engine_spec, False, expand_data
            )
            payload.update(
                {
                    "data": data,
                    "columns": all_columns,
                    "selected_columns": selected_columns,
                    "expanded_columns": expanded_columns,
                }
            )
        return payload

    return None
Example #8
def execute_sql_statement(  # pylint: disable=too-many-arguments,too-many-locals
    sql_statement: str,
    query: Query,
    user_name: Optional[str],
    session: Session,
    cursor: Any,
    log_params: Optional[Dict[str, Any]],
    apply_ctas: bool = False,
) -> SupersetResultSet:
    """Executes a single SQL statement"""
    database = query.database
    db_engine_spec = database.db_engine_spec
    parsed_query = ParsedQuery(sql_statement)
    sql = parsed_query.stripped()
    # This is a test to see if the query is being
    # limited by either the dropdown or the sql.
    # We are testing to see if more rows exist than the limit.
    increased_limit = None if query.limit is None else query.limit + 1

    if not db_engine_spec.is_readonly_query(
            parsed_query) and not database.allow_dml:
        raise SupersetErrorException(
            SupersetError(
                message=__(
                    "Only SELECT statements are allowed against this database."
                ),
                error_type=SupersetErrorType.DML_NOT_ALLOWED_ERROR,
                level=ErrorLevel.ERROR,
            ))
    if apply_ctas:
        if not query.tmp_table_name:
            start_dttm = datetime.fromtimestamp(query.start_time)
            query.tmp_table_name = "tmp_{}_table_{}".format(
                query.user_id, start_dttm.strftime("%Y_%m_%d_%H_%M_%S"))
        sql = parsed_query.as_create_table(
            query.tmp_table_name,
            schema_name=query.tmp_schema_name,
            method=query.ctas_method,
        )
        query.select_as_cta_used = True

    # Do not apply limit to the CTA queries when SQLLAB_CTAS_NO_LIMIT is set to true
    if db_engine_spec.is_select_query(parsed_query) and not (
            query.select_as_cta_used and SQLLAB_CTAS_NO_LIMIT):
        if SQL_MAX_ROW and (not query.limit or query.limit > SQL_MAX_ROW):
            query.limit = SQL_MAX_ROW
        sql = apply_limit_if_exists(database, increased_limit, query, sql)

    # Hook to allow environment-specific mutation (usually comments) to the SQL
    sql = SQL_QUERY_MUTATOR(sql, user_name, security_manager, database)
    try:
        query.executed_sql = sql
        if log_query:
            log_query(
                query.database.sqlalchemy_uri,
                query.executed_sql,
                query.schema,
                user_name,
                __name__,
                security_manager,
                log_params,
            )
        session.commit()
        with stats_timing("sqllab.query.time_executing_query", stats_logger):
            logger.debug("Query %d: Running query: %s", query.id, sql)
            db_engine_spec.execute(cursor, sql, async_=True)
            logger.debug("Query %d: Handling cursor", query.id)
            db_engine_spec.handle_cursor(cursor, query, session)

        with stats_timing("sqllab.query.time_fetching_results", stats_logger):
            logger.debug(
                "Query %d: Fetching data for query object: %s",
                query.id,
                str(query.to_dict()),
            )
            data = db_engine_spec.fetch_data(cursor, increased_limit)
            if query.limit is None or len(data) <= query.limit:
                query.limiting_factor = LimitingFactor.NOT_LIMITED
            else:
                # drop the extra row that was fetched to detect the limit
                data = data[:-1]
    except SoftTimeLimitExceeded as ex:
        logger.warning("Query %d: Time limit exceeded", query.id)
        logger.debug("Query %d: %s", query.id, ex)
        raise SupersetErrorException(
            SupersetError(
                message=__(
                    "The query was killed after %(sqllab_timeout)s seconds. It might "
                    "be too complex, or the database might be under heavy load.",
                    sqllab_timeout=SQLLAB_TIMEOUT,
                ),
                error_type=SupersetErrorType.SQLLAB_TIMEOUT_ERROR,
                level=ErrorLevel.ERROR,
            )) from ex
    except Exception as ex:
        # query is stopped in another thread/worker
        # stopping raises expected exceptions which we should skip
        session.refresh(query)
        if query.status == QueryStatus.STOPPED:
            raise SqlLabQueryStoppedException() from ex

        logger.error("Query %d: %s", query.id, type(ex), exc_info=True)
        logger.debug("Query %d: %s", query.id, ex)
        raise SqlLabException(db_engine_spec.extract_error_message(ex)) from ex

    logger.debug("Query %d: Fetching cursor description", query.id)
    cursor_description = cursor.description
    return SupersetResultSet(data, cursor_description, db_engine_spec)
def execute_sql_statements(
    ctask,
    query_id,
    rendered_query,
    return_results=True,
    store_results=False,
    user_name=None,
    session=None,
    start_time=None,
):
    """Executes the sql query returns the results."""
    if store_results and start_time:
        # only asynchronous queries
        stats_logger.timing("sqllab.query.time_pending",
                            now_as_float() - start_time)

    query = get_query(query_id, session)
    payload = dict(query_id=query_id)
    database = query.database
    db_engine_spec = database.db_engine_spec
    db_engine_spec.patch()

    if store_results and not results_backend:
        raise SqlLabException("Results backend isn't configured.")

    # Breaking down into multiple statements
    parsed_query = ParsedQuery(rendered_query)
    statements = parsed_query.get_statements()
    logging.info(f"Query {query_id}: Executing {len(statements)} statement(s)")

    logging.info(f"Query {query_id}: Set query to 'running'")
    query.status = QueryStatus.RUNNING
    query.start_running_time = now_as_float()
    session.commit()

    engine = database.get_sqla_engine(
        schema=query.schema,
        nullpool=True,
        user_name=user_name,
        source=sources.get("sql_lab", None),
    )
    # Sharing a single connection and cursor across the
    # execution of all statements (if many)
    with closing(engine.raw_connection()) as conn:
        with closing(conn.cursor()) as cursor:
            statement_count = len(statements)
            for i, statement in enumerate(statements):
                # Check if stopped
                query = get_query(query_id, session)
                if query.status == QueryStatus.STOPPED:
                    return

                # Run statement
                msg = f"Running statement {i+1} out of {statement_count}"
                logging.info(f"Query {query_id}: {msg}")
                query.set_extra_json_key("progress", msg)
                session.commit()
                try:
                    cdf = execute_sql_statement(statement, query, user_name,
                                                session, cursor)
                except Exception as e:
                    msg = str(e)
                    if statement_count > 1:
                        msg = f"[Statement {i+1} out of {statement_count}] " + msg
                    payload = handle_query_error(msg, query, session, payload)
                    return payload

    # Success, updating the query entry in database
    query.rows = cdf.size
    query.progress = 100
    query.set_extra_json_key("progress", None)
    if query.select_as_cta:
        query.select_sql = database.select_star(
            query.tmp_table_name,
            limit=query.limit,
            schema=database.force_ctas_schema,
            show_cols=False,
            latest_partition=False,
        )
    query.end_time = now_as_float()

    data, selected_columns, all_columns, expanded_columns = _serialize_and_expand_data(
        cdf, db_engine_spec, store_results and results_backend_use_msgpack)

    payload.update({
        "status": QueryStatus.SUCCESS,
        "data": data,
        "columns": all_columns,
        "selected_columns": selected_columns,
        "expanded_columns": expanded_columns,
        "query": query.to_dict(),
    })
    payload["query"]["state"] = QueryStatus.SUCCESS

    if store_results:
        key = str(uuid.uuid4())
        logging.info(
            f"Query {query_id}: Storing results in results backend, key: {key}"
        )
        with stats_timing("sqllab.query.results_backend_write", stats_logger):
            with stats_timing(
                    "sqllab.query.results_backend_write_serialization",
                    stats_logger):
                serialized_payload = _serialize_payload(
                    payload, results_backend_use_msgpack)
            cache_timeout = database.cache_timeout
            if cache_timeout is None:
                cache_timeout = config.get("CACHE_DEFAULT_TIMEOUT", 0)

            compressed = zlib_compress(serialized_payload)
            logging.debug(
                f"*** serialized payload size: {getsizeof(serialized_payload)}"
            )
            logging.debug(
                f"*** compressed payload size: {getsizeof(compressed)}")
            results_backend.set(key, compressed, cache_timeout)
        query.results_key = key

    query.status = QueryStatus.SUCCESS
    session.commit()

    if return_results:
        return payload
def extract_tables(query: str) -> Set[Table]:
    """
    Helper function to extract tables referenced in a query.
    """
    return ParsedQuery(query).tables
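
The tables attribute that this helper exposes yields one entry per referenced table. A hedged sketch of its output, assuming a Superset version where the entries are Table named tuples with table and schema fields (the same fields Example #15 reads):

# Hedged sketch: assumes ParsedQuery.tables yields Table named tuples with
# .table and .schema attributes, as the migration code in Example #15 does.
from superset.sql_parse import ParsedQuery

sql = """
SELECT p.name, s.ship_count
FROM planets p
JOIN shipyards s ON s.planet_id = p.id
"""
for table in ParsedQuery(sql).tables:
    print(table.table, table.schema)  # expected: planets/shipyards, schema None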
def test_update() -> None:
    """
    Test that ``UPDATE`` is not detected as ``SELECT``.
    """
    assert ParsedQuery("UPDATE t1 SET col1 = NULL").is_select() is False
Example #12
    def test_is_explain(self):
        query = """
            -- comment
            EXPLAIN select * from table
            -- comment 2
        """
        parsed = ParsedQuery(query)
        self.assertEqual(parsed.is_explain(), True)

        query = """
            -- comment
            EXPLAIN select * from table
            where col1 = 'something'
            -- comment 2

            -- comment 3
            EXPLAIN select * from table
            where col1 = 'something'
            -- comment 4
        """
        parsed = ParsedQuery(query)
        self.assertEqual(parsed.is_explain(), True)

        query = """
            -- This is a comment
                -- this is another comment but with a space in the front
            EXPLAIN SELECT * FROM TABLE
        """
        parsed = ParsedQuery(query)
        self.assertEqual(parsed.is_explain(), True)

        query = """
            /* This is a comment
                 with stars instead */
            EXPLAIN SELECT * FROM TABLE
        """
        parsed = ParsedQuery(query)
        self.assertEqual(parsed.is_explain(), True)

        query = """
            -- comment
            select * from table
            where col1 = 'something'
            -- comment 2
        """
        parsed = ParsedQuery(query)
        self.assertEqual(parsed.is_explain(), False)
Example #13
    def test_is_valid_cvas(self):
        """A valid CVAS has a single SELECT statement"""
        query = "SELECT * FROM table"
        parsed = ParsedQuery(query, strip_comments=True)
        assert parsed.is_valid_cvas()

        query = """
            -- comment
            SELECT * FROM table
            -- comment 2
        """
        parsed = ParsedQuery(query, strip_comments=True)
        assert parsed.is_valid_cvas()

        query = """
            -- comment
            SET @value = 42;
            SELECT @value as foo;
            -- comment 2
        """
        parsed = ParsedQuery(query, strip_comments=True)
        assert not parsed.is_valid_cvas()

        query = """
            -- comment
            EXPLAIN SELECT * FROM table
            -- comment 2
        """
        parsed = ParsedQuery(query, strip_comments=True)
        assert not parsed.is_valid_ctas()

        query = """
            SELECT * FROM table;
            INSERT INTO TABLE (foo) VALUES (42);
        """
        parsed = ParsedQuery(query, strip_comments=True)
        assert not parsed.is_valid_ctas()
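
The difference between the two checks is spelled out by the error messages in Example #19: CTAS only requires that the last statement be a SELECT, while CVAS requires a single SELECT statement. A hedged sketch of that contrast, reusing the multi-statement case from the test above:

# Hedged sketch of the CTAS vs. CVAS rules described in Example #19; the
# is_valid_ctas() result for the SET/SELECT pair is an assumption based on
# the "last statement is a SELECT" rule.
from superset.sql_parse import ParsedQuery

query = """
    SET @value = 42;
    SELECT @value as foo;
"""
parsed = ParsedQuery(query, strip_comments=True)
assert parsed.is_valid_ctas()      # last statement is a SELECT (assumption)
assert not parsed.is_valid_cvas()  # more than one statement, so not a valid CVAS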
Example #14
def execute_sql_statement(
    sql_statement: str,
    query: Query,
    user_name: Optional[str],
    session: Session,
    cursor: Any,
    log_params: Optional[Dict[str, Any]],
    apply_ctas: bool = False,
) -> SupersetResultSet:
    """Executes a single SQL statement"""
    database = query.database
    db_engine_spec = database.db_engine_spec
    parsed_query = ParsedQuery(sql_statement)
    sql = parsed_query.stripped()
    # This is a test to see if the query is being
    # limited by either the dropdown or the sql.
    # We are testing to see if more rows exist than the limit.
    increased_limit = None if query.limit is None else query.limit + 1

    if not db_engine_spec.is_readonly_query(
            parsed_query) and not database.allow_dml:
        raise SqlLabSecurityException(
            _("Only `SELECT` statements are allowed against this database"))
    if apply_ctas:
        if not query.tmp_table_name:
            start_dttm = datetime.fromtimestamp(query.start_time)
            query.tmp_table_name = "tmp_{}_table_{}".format(
                query.user_id, start_dttm.strftime("%Y_%m_%d_%H_%M_%S"))
        sql = parsed_query.as_create_table(
            query.tmp_table_name,
            schema_name=query.tmp_schema_name,
            method=query.ctas_method,
        )
        query.select_as_cta_used = True

    # Do not apply limit to the CTA queries when SQLLAB_CTAS_NO_LIMIT is set to true
    if parsed_query.is_select() and not (query.select_as_cta_used
                                         and SQLLAB_CTAS_NO_LIMIT):
        if SQL_MAX_ROW and (not query.limit or query.limit > SQL_MAX_ROW):
            query.limit = SQL_MAX_ROW
        if query.limit:
            # We are fetching one more than the requested limit in order
            # to test whether there are more rows than the limit.
            # Later, the extra row will be dropped before sending
            # the results back to the user.
            sql = database.apply_limit_to_sql(sql, increased_limit, force=True)

    # Hook to allow environment-specific mutation (usually comments) to the SQL
    sql = SQL_QUERY_MUTATOR(sql, user_name, security_manager, database)
    try:
        if log_query:
            log_query(
                query.database.sqlalchemy_uri,
                query.executed_sql,
                query.schema,
                user_name,
                __name__,
                security_manager,
                log_params,
            )
        query.executed_sql = sql
        session.commit()
        with stats_timing("sqllab.query.time_executing_query", stats_logger):
            logger.debug("Query %d: Running query: %s", query.id, sql)
            db_engine_spec.execute(cursor, sql, async_=True)
            logger.debug("Query %d: Handling cursor", query.id)
            db_engine_spec.handle_cursor(cursor, query, session)

        with stats_timing("sqllab.query.time_fetching_results", stats_logger):
            logger.debug(
                "Query %d: Fetching data for query object: %s",
                query.id,
                str(query.to_dict()),
            )
            data = db_engine_spec.fetch_data(cursor, increased_limit)
            if query.limit is None or len(data) <= query.limit:
                query.limiting_factor = LimitingFactor.NOT_LIMITED
            else:
                # drop the extra row that was fetched to detect the limit
                data = data[:-1]
    except Exception as ex:
        logger.error("Query %d: %s", query.id, type(ex), exc_info=True)
        logger.debug("Query %d: %s", query.id, ex)
        raise SqlLabException(db_engine_spec.extract_error_message(ex))

    logger.debug("Query %d: Fetching cursor description", query.id)
    cursor_description = cursor.description
    return SupersetResultSet(data, cursor_description, db_engine_spec)
Example #15
def after_insert(target: SqlaTable) -> None:  # pylint: disable=too-many-locals
    """
    Copy old datasets to the new models.
    """
    session = inspect(target).session

    # get DB-specific conditional quoter for expressions that point to columns or
    # table names
    database = (
        target.database
        or session.query(Database).filter_by(id=target.database_id).first())
    if not database:
        return
    url = make_url(database.sqlalchemy_uri)
    dialect_class = url.get_dialect()
    conditional_quote = dialect_class().identifier_preparer.quote

    # create columns
    columns = []
    for column in target.columns:
        # ``is_active`` might be ``None`` at this point, but it defaults to ``True``.
        if column.is_active is False:
            continue

        extra_json = json.loads(column.extra or "{}")
        for attr in {
                "groupby", "filterable", "verbose_name", "python_date_format"
        }:
            value = getattr(column, attr)
            if value:
                extra_json[attr] = value

        columns.append(
            NewColumn(
                name=column.column_name,
                type=column.type or "Unknown",
                expression=column.expression
                or conditional_quote(column.column_name),
                description=column.description,
                is_temporal=column.is_dttm,
                is_aggregation=False,
                is_physical=column.expression is None
                or column.expression == "",
                is_spatial=False,
                is_partition=False,
                is_increase_desired=True,
                extra_json=json.dumps(extra_json) if extra_json else None,
                is_managed_externally=target.is_managed_externally,
                external_url=target.external_url,
            ), )

    # create metrics
    for metric in target.metrics:
        extra_json = json.loads(metric.extra or "{}")
        for attr in {"verbose_name", "metric_type", "d3format"}:
            value = getattr(metric, attr)
            if value:
                extra_json[attr] = value

        is_additive = (metric.metric_type
                       and metric.metric_type.lower() in ADDITIVE_METRIC_TYPES)

        columns.append(
            NewColumn(
                name=metric.metric_name,
                type="Unknown",  # figuring this out would require a type inferrer
                expression=metric.expression,
                warning_text=metric.warning_text,
                description=metric.description,
                is_aggregation=True,
                is_additive=is_additive,
                is_physical=False,
                is_spatial=False,
                is_partition=False,
                is_increase_desired=True,
                extra_json=json.dumps(extra_json) if extra_json else None,
                is_managed_externally=target.is_managed_externally,
                external_url=target.external_url,
            ), )

    # physical dataset
    tables = []
    if target.sql is None:
        physical_columns = [column for column in columns if column.is_physical]

        # create table
        table = NewTable(
            name=target.table_name,
            schema=target.schema,
            catalog=None,  # currently not supported
            database_id=target.database_id,
            columns=physical_columns,
            is_managed_externally=target.is_managed_externally,
            external_url=target.external_url,
        )
        tables.append(table)

    # virtual dataset
    else:
        # mark all columns as virtual (not physical)
        for column in columns:
            column.is_physical = False

        # find referenced tables
        parsed = ParsedQuery(target.sql)
        referenced_tables = parsed.tables

        # predicate for finding the referenced tables
        predicate = or_(*[
            and_(
                NewTable.schema == (table.schema or target.schema),
                NewTable.name == table.table,
            ) for table in referenced_tables
        ])
        tables = session.query(NewTable).filter(predicate).all()

    # create the new dataset
    dataset = NewDataset(
        sqlatable_id=target.id,
        name=target.table_name,
        expression=target.sql or conditional_quote(target.table_name),
        tables=tables,
        columns=columns,
        is_physical=target.sql is None,
        is_managed_externally=target.is_managed_externally,
        external_url=target.external_url,
    )
    session.add(dataset)
Example #16
    def sql_tables(self) -> List[Table]:
        return list(ParsedQuery(self.sql).tables)
    def estimate_statement_cost(  # pylint: disable=too-many-locals
            cls, statement: str, database, cursor,
            user_name: str) -> Dict[str, str]:
        """
        Generate a SQL query that estimates the cost of a given statement.

        :param statement: A single SQL statement
        :param database: Database instance
        :param cursor: Cursor instance
        :param user_name: Effective username
        """
        parsed_query = ParsedQuery(statement)
        sql = parsed_query.stripped()

        sql_query_mutator = config.get("SQL_QUERY_MUTATOR")
        if sql_query_mutator:
            sql = sql_query_mutator(sql, user_name, security_manager, database)

        sql = f"EXPLAIN (TYPE IO, FORMAT JSON) {sql}"
        cursor.execute(sql)

        # the output from Presto is a single column and a single row containing
        # JSON:
        #
        #   {
        #     ...
        #     "estimate" : {
        #       "outputRowCount" : 8.73265878E8,
        #       "outputSizeInBytes" : 3.41425774958E11,
        #       "cpuCost" : 3.41425774958E11,
        #       "maxMemory" : 0.0,
        #       "networkCost" : 3.41425774958E11
        #     }
        #   }
        result = json.loads(cursor.fetchone()[0])
        estimate = result["estimate"]

        def humanize(value: Any, suffix: str) -> str:
            try:
                value = int(value)
            except ValueError:
                return str(value)

            prefixes = ["K", "M", "G", "T", "P", "E", "Z", "Y"]
            prefix = ""
            to_next_prefix = 1000
            while value > to_next_prefix and prefixes:
                prefix = prefixes.pop(0)
                value //= to_next_prefix

            return f"{value} {prefix}{suffix}"

        cost = {}
        columns = [
            ("outputRowCount", "Output count", " rows"),
            ("outputSizeInBytes", "Output size", "B"),
            ("cpuCost", "CPU cost", ""),
            ("maxMemory", "Max memory", "B"),
            ("networkCost", "Network cost", ""),
        ]
        for key, label, suffix in columns:
            if key in estimate:
                cost[label] = humanize(estimate[key], suffix)

        return cost
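
The nested humanize helper above is not importable on its own; the standalone copy below reproduces it and shows the scaling it performs on the sample outputRowCount from the comment block.

# Standalone copy of the humanize() helper defined above, for illustration only.
from typing import Any

def humanize(value: Any, suffix: str) -> str:
    try:
        value = int(value)
    except ValueError:
        return str(value)

    prefixes = ["K", "M", "G", "T", "P", "E", "Z", "Y"]
    prefix = ""
    to_next_prefix = 1000
    while value > to_next_prefix and prefixes:
        prefix = prefixes.pop(0)
        value //= to_next_prefix

    return f"{value} {prefix}{suffix}"

# 873265878 rows -> 873265 ("K") -> 873 ("M"), so:
print(humanize(8.73265878e8, " rows"))  # -> "873 M rows"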
    def extract_tables(self, query):
        return ParsedQuery(query).tables
Example #19
def execute_sql_statements(  # pylint: disable=too-many-arguments, too-many-locals, too-many-statements, too-many-branches
    query_id: int,
    rendered_query: str,
    return_results: bool,
    store_results: bool,
    user_name: Optional[str],
    session: Session,
    start_time: Optional[float],
    expand_data: bool,
    log_params: Optional[Dict[str, Any]],
) -> Optional[Dict[str, Any]]:
    """Executes the sql query returns the results."""
    if store_results and start_time:
        # only asynchronous queries
        stats_logger.timing("sqllab.query.time_pending",
                            now_as_float() - start_time)

    query = get_query(query_id, session)
    payload: Dict[str, Any] = dict(query_id=query_id)
    database = query.database
    db_engine_spec = database.db_engine_spec
    db_engine_spec.patch()

    if database.allow_run_async and not results_backend:
        raise SupersetErrorException(
            SupersetError(
                message=__("Results backend is not configured."),
                error_type=SupersetErrorType.RESULTS_BACKEND_NOT_CONFIGURED_ERROR,
                level=ErrorLevel.ERROR,
            ))

    # Breaking down into multiple statements
    parsed_query = ParsedQuery(rendered_query, strip_comments=True)
    if not db_engine_spec.run_multiple_statements_as_one:
        statements = parsed_query.get_statements()
        logger.info("Query %s: Executing %i statement(s)", str(query_id),
                    len(statements))
    else:
        statements = [rendered_query]
        logger.info("Query %s: Executing query as a single statement",
                    str(query_id))

    logger.info("Query %s: Set query to 'running'", str(query_id))
    query.status = QueryStatus.RUNNING
    query.start_running_time = now_as_float()
    session.commit()

    # Should we create a table or view from the select?
    if (query.select_as_cta and query.ctas_method == CtasMethod.TABLE
            and not parsed_query.is_valid_ctas()):
        raise SupersetErrorException(
            SupersetError(
                message=
                __("CTAS (create table as select) can only be run with a query where "
                   "the last statement is a SELECT. Please make sure your query has "
                   "a SELECT as its last statement. Then, try running your query "
                   "again."),
                error_type=SupersetErrorType.INVALID_CTAS_QUERY_ERROR,
                level=ErrorLevel.ERROR,
            ))
    if (query.select_as_cta and query.ctas_method == CtasMethod.VIEW
            and not parsed_query.is_valid_cvas()):
        raise SupersetErrorException(
            SupersetError(
                message=__(
                    "CVAS (create view as select) can only be run with a query with "
                    "a single SELECT statement. Please make sure your query has only "
                    "a SELECT statement. Then, try running your query again."),
                error_type=SupersetErrorType.INVALID_CVAS_QUERY_ERROR,
                level=ErrorLevel.ERROR,
            ))

    engine = database.get_sqla_engine(
        schema=query.schema,
        nullpool=True,
        user_name=user_name,
        source=QuerySource.SQL_LAB,
    )
    # Sharing a single connection and cursor across the
    # execution of all statements (if many)
    with closing(engine.raw_connection()) as conn:
        # closing the connection closes the cursor as well
        cursor = conn.cursor()
        cancel_query_id = db_engine_spec.get_cancel_query_id(cursor, query)
        if cancel_query_id is not None:
            query.set_extra_json_key(cancel_query_key, cancel_query_id)
            session.commit()
        statement_count = len(statements)
        for i, statement in enumerate(statements):
            # Check if stopped
            session.refresh(query)
            if query.status == QueryStatus.STOPPED:
                payload.update({"status": query.status})
                return payload

            # For CTAS we create the table only on the last statement
            apply_ctas = query.select_as_cta and (
                query.ctas_method == CtasMethod.VIEW or
                (query.ctas_method == CtasMethod.TABLE
                 and i == len(statements) - 1))

            # Run statement
            msg = f"Running statement {i+1} out of {statement_count}"
            logger.info("Query %s: %s", str(query_id), msg)
            query.set_extra_json_key("progress", msg)
            session.commit()
            try:
                result_set = execute_sql_statement(
                    statement,
                    query,
                    user_name,
                    session,
                    cursor,
                    log_params,
                    apply_ctas,
                )
            except SqlLabQueryStoppedException:
                payload.update({"status": QueryStatus.STOPPED})
                return payload
            except Exception as ex:  # pylint: disable=broad-except
                msg = str(ex)
                prefix_message = (f"[Statement {i+1} out of {statement_count}]"
                                  if statement_count > 1 else "")
                payload = handle_query_error(ex, query, session, payload,
                                             prefix_message)
                return payload

        # Commit the connection so CTA queries will create the table.
        conn.commit()

    # Success, updating the query entry in database
    query.rows = result_set.size
    query.progress = 100
    query.set_extra_json_key("progress", None)
    if query.select_as_cta:
        query.select_sql = database.select_star(
            query.tmp_table_name,
            schema=query.tmp_schema_name,
            limit=query.limit,
            show_cols=False,
            latest_partition=False,
        )
    query.end_time = now_as_float()

    use_arrow_data = store_results and cast(bool, results_backend_use_msgpack)
    data, selected_columns, all_columns, expanded_columns = _serialize_and_expand_data(
        result_set, db_engine_spec, use_arrow_data, expand_data)

    # TODO: data should be saved separately from metadata (likely in Parquet)
    payload.update({
        "status": QueryStatus.SUCCESS,
        "data": data,
        "columns": all_columns,
        "selected_columns": selected_columns,
        "expanded_columns": expanded_columns,
        "query": query.to_dict(),
    })
    payload["query"]["state"] = QueryStatus.SUCCESS

    if store_results and results_backend:
        key = str(uuid.uuid4())
        logger.info("Query %s: Storing results in results backend, key: %s",
                    str(query_id), key)
        with stats_timing("sqllab.query.results_backend_write", stats_logger):
            with stats_timing(
                    "sqllab.query.results_backend_write_serialization",
                    stats_logger):
                serialized_payload = _serialize_payload(
                    payload, cast(bool, results_backend_use_msgpack))
            cache_timeout = database.cache_timeout
            if cache_timeout is None:
                cache_timeout = config["CACHE_DEFAULT_TIMEOUT"]

            compressed = zlib_compress(serialized_payload)
            logger.debug("*** serialized payload size: %i",
                         getsizeof(serialized_payload))
            logger.debug("*** compressed payload size: %i",
                         getsizeof(compressed))
            results_backend.set(key, compressed, cache_timeout)
        query.results_key = key

    query.status = QueryStatus.SUCCESS
    session.commit()

    if return_results:
        # since we're returning results we need to create non-arrow data
        if use_arrow_data:
            (
                data,
                selected_columns,
                all_columns,
                expanded_columns,
            ) = _serialize_and_expand_data(result_set, db_engine_spec, False,
                                           expand_data)
            payload.update({
                "data": data,
                "columns": all_columns,
                "selected_columns": selected_columns,
                "expanded_columns": expanded_columns,
            })
        return payload

    return None
    def test_update_not_select(self):
        sql = ParsedQuery("UPDATE t1 SET col1 = NULL")
        self.assertEqual(False, sql.is_select())
        self.assertEqual(False, sql.is_readonly())
Example #21
    def is_readonly_query(cls, parsed_query: ParsedQuery) -> bool:
        """Pessimistic readonly, 100% sure statement won't mutate anything"""
        return (parsed_query.is_select() or parsed_query.is_explain()
                or parsed_query.is_show())
    def test_explain(self):
        sql = ParsedQuery("EXPLAIN SELECT 1")

        self.assertEqual(True, sql.is_explain())
        self.assertEqual(False, sql.is_select())
        self.assertEqual(True, sql.is_readonly())
    def is_readonly_query(cls, parsed_query: ParsedQuery) -> bool:
        """Pessimistic readonly, 100% sure statement won't mutate anything"""
        return super().is_readonly_query(
            parsed_query) or parsed_query.is_show()
    def test_get_query_with_new_limit_comment_with_limit(self):
        sql = "SELECT * FROM birth_names -- SOME COMMENT WITH LIMIT 555"
        parsed = ParsedQuery(sql)
        newsql = parsed.set_or_update_query_limit(1000)
        self.assertEqual(newsql, sql + "\nLIMIT 1000")
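
The test above only covers a LIMIT that appears inside a comment. A hedged companion sketch for a query that already carries a real LIMIT clause, assuming set_or_update_query_limit rewrites the existing limit rather than appending a second one:

# Hedged sketch: assumes an existing LIMIT clause is updated in place rather
# than duplicated; not verified against every Superset version.
from superset.sql_parse import ParsedQuery

parsed = ParsedQuery("SELECT * FROM birth_names LIMIT 10")
print(parsed.set_or_update_query_limit(1000))
# expected (assumption): SELECT * FROM birth_names LIMIT 1000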
def execute_sql_statement(
    sql_statement: str,
    query: Query,
    user_name: Optional[str],
    session: Session,
    cursor: Any,
    log_params: Optional[Dict[str, Any]],
) -> SupersetResultSet:
    """Executes a single SQL statement"""
    database = query.database
    db_engine_spec = database.db_engine_spec
    parsed_query = ParsedQuery(sql_statement)
    sql = parsed_query.stripped()

    if not parsed_query.is_readonly() and not database.allow_dml:
        raise SqlLabSecurityException(
            _("Only `SELECT` statements are allowed against this database")
        )
    if query.select_as_cta:
        if not parsed_query.is_select():
            raise SqlLabException(
                _(
                    "Only `SELECT` statements can be used with the CREATE TABLE "
                    "feature."
                )
            )
        if not query.tmp_table_name:
            start_dttm = datetime.fromtimestamp(query.start_time)
            query.tmp_table_name = "tmp_{}_table_{}".format(
                query.user_id, start_dttm.strftime("%Y_%m_%d_%H_%M_%S")
            )
        sql = parsed_query.as_create_table(
            query.tmp_table_name, schema_name=query.tmp_schema_name
        )
        query.select_as_cta_used = True

    # Do not apply limit to the CTA queries when SQLLAB_CTAS_NO_LIMIT is set to true
    if parsed_query.is_select() and not (
        query.select_as_cta_used and SQLLAB_CTAS_NO_LIMIT
    ):
        if SQL_MAX_ROW and (not query.limit or query.limit > SQL_MAX_ROW):
            query.limit = SQL_MAX_ROW
        if query.limit:
            sql = database.apply_limit_to_sql(sql, query.limit)

    # Hook to allow environment-specific mutation (usually comments) to the SQL
    if SQL_QUERY_MUTATOR:
        sql = SQL_QUERY_MUTATOR(sql, user_name, security_manager, database)

    try:
        if log_query:
            log_query(
                query.database.sqlalchemy_uri,
                query.executed_sql,
                query.schema,
                user_name,
                __name__,
                security_manager,
                log_params,
            )
        query.executed_sql = sql
        session.commit()
        with stats_timing("sqllab.query.time_executing_query", stats_logger):
            logger.debug("Query %d: Running query: %s", query.id, sql)
            db_engine_spec.execute(cursor, sql, async_=True)
            logger.debug("Query %d: Handling cursor", query.id)
            db_engine_spec.handle_cursor(cursor, query, session)

        with stats_timing("sqllab.query.time_fetching_results", stats_logger):
            logger.debug(
                "Query %d: Fetching data for query object: %s",
                query.id,
                str(query.to_dict()),
            )
            data = db_engine_spec.fetch_data(cursor, query.limit)

    except SoftTimeLimitExceeded as ex:
        logger.error("Query %d: Time limit exceeded", query.id)
        logger.debug("Query %d: %s", query.id, ex)
        raise SqlLabTimeoutException(
            "SQL Lab timeout. This environment's policy is to kill queries "
            "after {} seconds.".format(SQLLAB_TIMEOUT)
        )
    except Exception as ex:
        logger.error("Query %d: %s", query.id, type(ex))
        logger.debug("Query %d: %s", query.id, ex)
        raise SqlLabException(db_engine_spec.extract_error_message(ex))

    logger.debug("Query %d: Fetching cursor description", query.id)
    cursor_description = cursor.description
    return SupersetResultSet(data, cursor_description, db_engine_spec)
Example #26
def execute_sql_statement(
        sql_statement, query, user_name, session,
        cursor, return_results=False):
    """Executes a single SQL statement"""
    database = query.database
    db_engine_spec = database.db_engine_spec
    parsed_query = ParsedQuery(sql_statement)
    sql = parsed_query.stripped()
    SQL_MAX_ROWS = app.config.get('SQL_MAX_ROW')

    if not parsed_query.is_readonly() and not database.allow_dml:
        raise SqlLabSecurityException(
            _('Only `SELECT` statements are allowed against this database'))
    if query.select_as_cta:
        if not parsed_query.is_select():
            raise SqlLabException(_(
                'Only `SELECT` statements can be used with the CREATE TABLE '
                'feature.'))
        if not query.tmp_table_name:
            start_dttm = datetime.fromtimestamp(query.start_time)
            query.tmp_table_name = 'tmp_{}_table_{}'.format(
                query.user_id, start_dttm.strftime('%Y_%m_%d_%H_%M_%S'))
        sql = parsed_query.as_create_table(query.tmp_table_name)
        query.select_as_cta_used = True
    if parsed_query.is_select():
        if SQL_MAX_ROWS and (not query.limit or query.limit > SQL_MAX_ROWS):
            query.limit = SQL_MAX_ROWS
        if query.limit:
            sql = database.apply_limit_to_sql(sql, query.limit)

    # Hook to allow environment-specific mutation (usually comments) to the SQL
    SQL_QUERY_MUTATOR = config.get('SQL_QUERY_MUTATOR')
    if SQL_QUERY_MUTATOR:
        sql = SQL_QUERY_MUTATOR(sql, user_name, security_manager, database)

    try:
        if log_query:
            log_query(
                query.database.sqlalchemy_uri,
                query.executed_sql,
                query.schema,
                user_name,
                __name__,
                security_manager,
            )
        query.executed_sql = sql
        with stats_timing('sqllab.query.time_executing_query', stats_logger):
            logging.info('Running query: \n{}'.format(sql))
            db_engine_spec.execute(cursor, sql, async_=True)
            logging.info('Handling cursor')
            db_engine_spec.handle_cursor(cursor, query, session)

        with stats_timing('sqllab.query.time_fetching_results', stats_logger):
            logging.debug('Fetching data for query object: {}'.format(query.to_dict()))
            data = db_engine_spec.fetch_data(cursor, query.limit)

    except SoftTimeLimitExceeded as e:
        logging.exception(e)
        raise SqlLabTimeoutException(
            "SQL Lab timeout. This environment's policy is to kill queries "
            'after {} seconds.'.format(SQLLAB_TIMEOUT))
    except Exception as e:
        logging.exception(e)
        raise SqlLabException(db_engine_spec.extract_error_message(e))

    logging.debug('Fetching cursor description')
    cursor_description = cursor.description
    return dataframe.SupersetDataFrame(data, cursor_description, db_engine_spec)
Example #27
def execute_sql_statements(
    ctask, query_id, rendered_query, return_results=True, store_results=False,
    user_name=None, session=None, start_time=None,
):
    """Executes the sql query returns the results."""
    if store_results and start_time:
        # only asynchronous queries
        stats_logger.timing(
            'sqllab.query.time_pending', now_as_float() - start_time)

    query = get_query(query_id, session)
    payload = dict(query_id=query_id)
    database = query.database
    db_engine_spec = database.db_engine_spec
    db_engine_spec.patch()

    if store_results and not results_backend:
        raise SqlLabException("Results backend isn't configured.")

    # Breaking down into multiple statements
    parsed_query = ParsedQuery(rendered_query)
    statements = parsed_query.get_statements()
    logging.info(f'Executing {len(statements)} statement(s)')

    logging.info("Set query to 'running'")
    query.status = QueryStatus.RUNNING
    query.start_running_time = now_as_float()

    engine = database.get_sqla_engine(
        schema=query.schema,
        nullpool=True,
        user_name=user_name,
    )
    # Sharing a single connection and cursor across the
    # execution of all statements (if many)
    with closing(engine.raw_connection()) as conn:
        with closing(conn.cursor()) as cursor:
            statement_count = len(statements)
            for i, statement in enumerate(statements):
                # TODO CHECK IF STOPPED
                msg = f'Running statement {i+1} out of {statement_count}'
                logging.info(msg)
                query.set_extra_json_key('progress', msg)
                session.commit()
                is_last_statement = i == len(statements) - 1
                try:
                    cdf = execute_sql_statement(
                        statement, query, user_name, session, cursor,
                        return_results=is_last_statement and return_results)
                    msg = f'Running statement {i+1} out of {statement_count}'
                except Exception as e:
                    msg = str(e)
                    if statement_count > 1:
                        msg = f'[Statement {i+1} out of {statement_count}] ' + msg
                    payload = handle_query_error(msg, query, session, payload)
                    return payload

    # Success, updating the query entry in database
    query.rows = cdf.size
    query.progress = 100
    query.set_extra_json_key('progress', None)
    query.status = QueryStatus.SUCCESS
    if query.select_as_cta:
        query.select_sql = database.select_star(
            query.tmp_table_name,
            limit=query.limit,
            schema=database.force_ctas_schema,
            show_cols=False,
            latest_partition=False)
    query.end_time = now_as_float()
    session.commit()

    payload.update({
        'status': query.status,
        'data': cdf.data if cdf.data else [],
        'columns': cdf.columns if cdf.columns else [],
        'query': query.to_dict(),
    })

    if store_results:
        key = str(uuid.uuid4())
        logging.info(f'Storing results in results backend, key: {key}')
        with stats_timing('sqllab.query.results_backend_write', stats_logger):
            json_payload = json.dumps(
                payload, default=json_iso_dttm_ser, ignore_nan=True)
            cache_timeout = database.cache_timeout
            if cache_timeout is None:
                cache_timeout = config.get('CACHE_DEFAULT_TIMEOUT', 0)
            results_backend.set(key, zlib_compress(json_payload), cache_timeout)
        query.results_key = key
    session.commit()

    if return_results:
        return payload
Example #28
def test_is_explain() -> None:
    """
    Test that ``EXPLAIN`` is detected correctly.
    """
    assert ParsedQuery("EXPLAIN SELECT 1").is_explain() is True
    assert ParsedQuery("EXPLAIN SELECT 1").is_select() is False

    assert (
        ParsedQuery(
            """
-- comment
EXPLAIN select * from table
-- comment 2
"""
        ).is_explain()
        is True
    )

    assert (
        ParsedQuery(
            """
-- comment
EXPLAIN select * from table
where col1 = 'something'
-- comment 2

-- comment 3
EXPLAIN select * from table
where col1 = 'something'
-- comment 4
"""
        ).is_explain()
        is True
    )

    assert (
        ParsedQuery(
            """
-- This is a comment
    -- this is another comment but with a space in the front
EXPLAIN SELECT * FROM TABLE
"""
        ).is_explain()
        is True
    )

    assert (
        ParsedQuery(
            """
/* This is a comment
     with stars instead */
EXPLAIN SELECT * FROM TABLE
"""
        ).is_explain()
        is True
    )

    assert (
        ParsedQuery(
            """
-- comment
select * from table
where col1 = 'something'
-- comment 2
"""
        ).is_explain()
        is False
    )