def test_basic_breakdown_statements() -> None:
    """
    Test that multiple statements are parsed correctly.
    """
    query = ParsedQuery(
        """
SELECT * FROM birth_names;
SELECT * FROM birth_names LIMIT 1;
"""
    )
    assert query.get_statements() == [
        "SELECT * FROM birth_names",
        "SELECT * FROM birth_names LIMIT 1",
    ]
def test_kql_is_select_query(
    app_context: AppContext, kql: str, expected: bool
) -> None:
    """
    Make sure the KQL dialect considers only statements that do not start
    with "." (dot) to be SELECT statements.
    """
    from superset.db_engine_specs.kusto import KustoKqlEngineSpec
    from superset.sql_parse import ParsedQuery

    parsed_query = ParsedQuery(kql)
    is_select = KustoKqlEngineSpec.is_select_query(parsed_query)

    assert expected == is_select
def test_kql_is_readonly_query(
    app_context: AppContext, kql: str, expected: bool
) -> None:
    """
    Make sure the KQL dialect considers only SELECT statements to be read-only.
    """
    from superset.db_engine_specs.kusto import KustoKqlEngineSpec
    from superset.sql_parse import ParsedQuery

    parsed_query = ParsedQuery(kql)
    is_readonly = KustoKqlEngineSpec.is_readonly_query(parsed_query)

    assert expected == is_readonly
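# A plausible parametrization for the two KQL tests above -- a sketch only,
# based on the dot-prefix rule the docstrings describe; the actual fixtures
# live elsewhere in the suite and may differ.
import pytest

KQL_CASES = [
    ("StormEvents | take 10", True),  # tabular expression: counts as a SELECT
    (".show tables", False),          # control command: starts with a dot
]

@pytest.mark.parametrize("kql,expected", KQL_CASES)
def test_kql_dot_prefix_rule(kql: str, expected: bool) -> None:
    from superset.db_engine_specs.kusto import KustoKqlEngineSpec
    from superset.sql_parse import ParsedQuery

    assert KustoKqlEngineSpec.is_select_query(ParsedQuery(kql)) == expected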
def test_is_valid_cvas(self):
    """A valid CVAS has a single SELECT statement"""
    query = "SELECT * FROM table"
    parsed = ParsedQuery(query, strip_comments=True)
    assert parsed.is_valid_cvas()

    query = """
        -- comment
        SELECT * FROM table
        -- comment 2
    """
    parsed = ParsedQuery(query, strip_comments=True)
    assert parsed.is_valid_cvas()

    query = """
        -- comment
        SET @value = 42;
        SELECT @value as foo;
        -- comment 2
    """
    parsed = ParsedQuery(query, strip_comments=True)
    assert not parsed.is_valid_cvas()

    query = """
        -- comment
        EXPLAIN SELECT * FROM table
        -- comment 2
    """
    parsed = ParsedQuery(query, strip_comments=True)
    assert not parsed.is_valid_cvas()

    query = """
        SELECT * FROM table;
        INSERT INTO TABLE (foo) VALUES (42);
    """
    parsed = ParsedQuery(query, strip_comments=True)
    assert not parsed.is_valid_cvas()
def test_messy_breakdown_statements() -> None:
    """
    Test that messy multi-statement queries are parsed correctly.
    """
    query = ParsedQuery(
        """
SELECT 1;\t\n\n\n  \t
\t\nSELECT 2;
SELECT * FROM birth_names;;;
SELECT * FROM birth_names LIMIT 1
"""
    )
    assert query.get_statements() == [
        "SELECT 1",
        "SELECT 2",
        "SELECT * FROM birth_names",
        "SELECT * FROM birth_names LIMIT 1",
    ]
def process_statement(
    cls, statement: str, database: "Database", user_name: str
) -> str:
    """
    Process a SQL statement by stripping and mutating it.

    :param statement: A single SQL statement
    :param database: Database instance
    :param user_name: Effective username
    :return: Processed SQL statement
    """
    parsed_query = ParsedQuery(statement)
    sql = parsed_query.stripped()

    sql_query_mutator = config["SQL_QUERY_MUTATOR"]
    if sql_query_mutator:
        sql = sql_query_mutator(sql, user_name, security_manager, database)

    return sql
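# A minimal SQL_QUERY_MUTATOR sketch, assuming only the four-argument
# signature invoked above; a deployment would define something like this in
# superset_config.py (the function name here is hypothetical).
def example_sql_query_mutator(sql, user_name, security_manager, database):
    """Prepend an attribution comment; engines ignore SQL comments."""
    return f"-- executed by: {user_name}\n{sql}"

# SQL_QUERY_MUTATOR = example_sql_query_mutator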
def test_messy_breakdown_statements(self):
    multi_sql = """
SELECT 1;\t\n\n\n  \t
\t\nSELECT 2;
SELECT * FROM birth_names;;;
SELECT * FROM birth_names LIMIT 1
"""
    parsed = ParsedQuery(multi_sql)
    statements = parsed.get_statements()
    self.assertEqual(len(statements), 4)
    expected = [
        "SELECT 1",
        "SELECT 2",
        "SELECT * FROM birth_names",
        "SELECT * FROM birth_names LIMIT 1",
    ]
    self.assertEqual(statements, expected)
def test_cte_is_select() -> None:
    """
    Some CTEs are not correctly identified as SELECTs.
    """
    # `AS(` gets parsed as a function
    sql = ParsedQuery(
        """WITH foo AS(
SELECT
  FLOOR(__time TO WEEK) AS "week",
  name,
  COUNT(DISTINCT user_id) AS "unique_users"
FROM "druid"."my_table"
GROUP BY 1,2
)
SELECT
  f.week,
  f.name,
  f.unique_users
FROM foo f"""
    )
    assert sql.is_select()
def get_virtual_table_metadata(dataset: "SqlaTable") -> List[Dict[str, str]]:
    """Use the SQL parser to get virtual dataset metadata"""
    if not dataset.sql:
        raise SupersetGenericDBErrorException(
            message=_("Virtual dataset query cannot be empty"),
        )

    db_engine_spec = dataset.database.db_engine_spec
    engine = dataset.database.get_sqla_engine(schema=dataset.schema)
    sql = dataset.get_template_processor().process_template(
        dataset.sql, **dataset.template_params_dict
    )
    parsed_query = ParsedQuery(sql)
    if not db_engine_spec.is_readonly_query(parsed_query):
        raise SupersetSecurityException(
            SupersetError(
                error_type=SupersetErrorType.DATASOURCE_SECURITY_ACCESS_ERROR,
                message=_("Only `SELECT` statements are allowed"),
                level=ErrorLevel.ERROR,
            )
        )
    statements = parsed_query.get_statements()
    if len(statements) > 1:
        raise SupersetSecurityException(
            SupersetError(
                error_type=SupersetErrorType.DATASOURCE_SECURITY_ACCESS_ERROR,
                message=_("Only single queries supported"),
                level=ErrorLevel.ERROR,
            )
        )
    # TODO(villebro): refactor to use same code that's used by
    #  sql_lab.py:execute_sql_statements
    try:
        with closing(engine.raw_connection()) as conn:
            cursor = conn.cursor()
            query = dataset.database.apply_limit_to_sql(statements[0])
            db_engine_spec.execute(cursor, query)
            result = db_engine_spec.fetch_data(cursor, limit=1)
            result_set = SupersetResultSet(result, cursor.description, db_engine_spec)
            cols = result_set.columns
    except Exception as exc:
        raise SupersetGenericDBErrorException(message=str(exc)) from exc
    return cols
def validate(
    cls,
    sql: str,
    schema: str,
    database: Any,
) -> List[SQLValidationAnnotation]:
    """
    Presto supports query-validation queries by running them with a
    prepended explain. For example, "SELECT 1 FROM default.mytable"
    becomes "EXPLAIN (TYPE VALIDATE) SELECT 1 FROM default.mytable".
    """
    user_name = g.user.username if g.user else None
    parsed_query = ParsedQuery(sql)
    statements = parsed_query.get_statements()

    logging.info(f'Validating {len(statements)} statement(s)')
    engine = database.get_sqla_engine(
        schema=schema,
        nullpool=True,
        user_name=user_name,
        source=sources.get('sql_lab', None),
    )
    # Sharing a single connection and cursor across the
    # execution of all statements (if many)
    annotations: List[SQLValidationAnnotation] = []
    with closing(engine.raw_connection()) as conn:
        with closing(conn.cursor()) as cursor:
            for statement in statements:
                annotation = cls.validate_statement(
                    statement,
                    database,
                    cursor,
                    user_name,
                )
                if annotation:
                    annotations.append(annotation)
    logging.debug(f'Validation found {len(annotations)} error(s)')

    return annotations
def test_is_explain() -> None:
    """
    Test that ``EXPLAIN`` is detected correctly.
    """
    assert ParsedQuery("EXPLAIN SELECT 1").is_explain() is True
    assert ParsedQuery("EXPLAIN SELECT 1").is_select() is False

    assert (
        ParsedQuery(
            """
-- comment
EXPLAIN select * from table
-- comment 2
"""
        ).is_explain()
        is True
    )

    assert (
        ParsedQuery(
            """
-- comment
EXPLAIN select * from table
where col1 = 'something'
-- comment 2

-- comment 3
EXPLAIN select * from table
where col1 = 'something'
-- comment 4
"""
        ).is_explain()
        is True
    )

    assert (
        ParsedQuery(
            """
-- This is a comment
    -- this is another comment but with a space in the front
EXPLAIN SELECT * FROM TABLE
"""
        ).is_explain()
        is True
    )

    assert (
        ParsedQuery(
            """
/* This is a comment
     with stars instead */
EXPLAIN SELECT * FROM TABLE
"""
        ).is_explain()
        is True
    )

    assert (
        ParsedQuery(
            """
-- comment
select * from table
where col1 = 'something'
-- comment 2
"""
        ).is_explain()
        is False
    )
def estimate_statement_cost(  # pylint: disable=too-many-locals
    cls, statement: str, database: "Database", cursor: Any, user_name: str
) -> Dict[str, Any]:
    """
    Run a SQL query that estimates the cost of a given statement.

    :param statement: A single SQL statement
    :param database: Database instance
    :param cursor: Cursor instance
    :param user_name: Effective username
    :return: JSON response from Presto
    """
    parsed_query = ParsedQuery(statement)
    sql = parsed_query.stripped()

    sql_query_mutator = config["SQL_QUERY_MUTATOR"]
    if sql_query_mutator:
        sql = sql_query_mutator(sql, user_name, security_manager, database)

    sql = f"EXPLAIN (TYPE IO, FORMAT JSON) {sql}"
    cursor.execute(sql)

    # the output from Presto is a single column and a single row containing
    # JSON:
    #
    #   {
    #     ...
    #     "estimate" : {
    #       "outputRowCount" : 8.73265878E8,
    #       "outputSizeInBytes" : 3.41425774958E11,
    #       "cpuCost" : 3.41425774958E11,
    #       "maxMemory" : 0.0,
    #       "networkCost" : 3.41425774958E11
    #     }
    #   }
    result = json.loads(cursor.fetchone()[0])
    return result
def extract_tables(self, query):
    return ParsedQuery(query).tables
def execute_sql_statements(
    query_id,
    rendered_query,
    return_results=True,
    store_results=False,
    user_name=None,
    session=None,
    start_time=None,
    expand_data=False,
):  # pylint: disable=too-many-arguments, too-many-locals, too-many-statements
    """Executes the SQL query and returns the results."""
    if store_results and start_time:
        # only asynchronous queries
        stats_logger.timing("sqllab.query.time_pending", now_as_float() - start_time)

    query = get_query(query_id, session)
    payload = dict(query_id=query_id)
    database = query.database
    db_engine_spec = database.db_engine_spec
    db_engine_spec.patch()

    if database.allow_run_async and not results_backend:
        raise SqlLabException("Results backend isn't configured.")

    # Breaking down into multiple statements
    parsed_query = ParsedQuery(rendered_query)
    statements = parsed_query.get_statements()
    logger.info(f"Query {query_id}: Executing {len(statements)} statement(s)")

    logger.info(f"Query {query_id}: Set query to 'running'")
    query.status = QueryStatus.RUNNING
    query.start_running_time = now_as_float()
    session.commit()

    engine = database.get_sqla_engine(
        schema=query.schema,
        nullpool=True,
        user_name=user_name,
        source=sources.get("sql_lab", None),
    )
    # Sharing a single connection and cursor across the
    # execution of all statements (if many)
    with closing(engine.raw_connection()) as conn:
        with closing(conn.cursor()) as cursor:
            statement_count = len(statements)
            for i, statement in enumerate(statements):
                # Check if stopped
                query = get_query(query_id, session)
                if query.status == QueryStatus.STOPPED:
                    return None

                # Run statement
                msg = f"Running statement {i+1} out of {statement_count}"
                logger.info(f"Query {query_id}: {msg}")
                query.set_extra_json_key("progress", msg)
                session.commit()
                try:
                    cdf = execute_sql_statement(
                        statement, query, user_name, session, cursor
                    )
                except Exception as e:  # pylint: disable=broad-except
                    msg = str(e)
                    if statement_count > 1:
                        msg = f"[Statement {i+1} out of {statement_count}] " + msg
                    payload = handle_query_error(msg, query, session, payload)
                    return payload

    # Success, updating the query entry in database
    query.rows = cdf.size
    query.progress = 100
    query.set_extra_json_key("progress", None)
    if query.select_as_cta:
        query.select_sql = database.select_star(
            query.tmp_table_name,
            limit=query.limit,
            schema=database.force_ctas_schema,
            show_cols=False,
            latest_partition=False,
        )
    query.end_time = now_as_float()

    data, selected_columns, all_columns, expanded_columns = _serialize_and_expand_data(
        cdf, db_engine_spec, store_results and results_backend_use_msgpack, expand_data
    )

    payload.update(
        {
            "status": QueryStatus.SUCCESS,
            "data": data,
            "columns": all_columns,
            "selected_columns": selected_columns,
            "expanded_columns": expanded_columns,
            "query": query.to_dict(),
        }
    )
    payload["query"]["state"] = QueryStatus.SUCCESS

    if store_results and results_backend:
        key = str(uuid.uuid4())
        logger.info(
            f"Query {query_id}: Storing results in results backend, key: {key}"
        )
        with stats_timing("sqllab.query.results_backend_write", stats_logger):
            with stats_timing(
                "sqllab.query.results_backend_write_serialization", stats_logger
            ):
                serialized_payload = _serialize_payload(
                    payload, results_backend_use_msgpack
                )
            cache_timeout = database.cache_timeout
            if cache_timeout is None:
                cache_timeout = config["CACHE_DEFAULT_TIMEOUT"]

            compressed = zlib_compress(serialized_payload)
            logger.debug(
                f"*** serialized payload size: {getsizeof(serialized_payload)}"
            )
            logger.debug(f"*** compressed payload size: {getsizeof(compressed)}")
            results_backend.set(key, compressed, cache_timeout)
        query.results_key = key

    query.status = QueryStatus.SUCCESS
    session.commit()

    if return_results:
        return payload

    return None
def execute_sql_statement(sql_statement, query, user_name, session, cursor):
    """Executes a single SQL statement"""
    database = query.database
    db_engine_spec = database.db_engine_spec
    parsed_query = ParsedQuery(sql_statement)
    sql = parsed_query.stripped()

    if not parsed_query.is_readonly() and not database.allow_dml:
        raise SqlLabSecurityException(
            _("Only `SELECT` statements are allowed against this database")
        )
    if query.select_as_cta:
        if not parsed_query.is_select():
            raise SqlLabException(
                _(
                    "Only `SELECT` statements can be used with the CREATE TABLE "
                    "feature."
                )
            )
        if not query.tmp_table_name:
            start_dttm = datetime.fromtimestamp(query.start_time)
            query.tmp_table_name = "tmp_{}_table_{}".format(
                query.user_id, start_dttm.strftime("%Y_%m_%d_%H_%M_%S")
            )
        sql = parsed_query.as_create_table(query.tmp_table_name)
        query.select_as_cta_used = True
    if parsed_query.is_select():
        if SQL_MAX_ROW and (not query.limit or query.limit > SQL_MAX_ROW):
            query.limit = SQL_MAX_ROW
        if query.limit:
            sql = database.apply_limit_to_sql(sql, query.limit)

    # Hook to allow environment-specific mutation (usually comments) to the SQL
    if SQL_QUERY_MUTATOR:
        sql = SQL_QUERY_MUTATOR(sql, user_name, security_manager, database)

    try:
        if log_query:
            log_query(
                query.database.sqlalchemy_uri,
                query.executed_sql,
                query.schema,
                user_name,
                __name__,
                security_manager,
            )
        query.executed_sql = sql
        session.commit()
        with stats_timing("sqllab.query.time_executing_query", stats_logger):
            logger.info(f"Query {query.id}: Running query: \n{sql}")
            db_engine_spec.execute(cursor, sql, async_=True)
            logger.info(f"Query {query.id}: Handling cursor")
            db_engine_spec.handle_cursor(cursor, query, session)
        with stats_timing("sqllab.query.time_fetching_results", stats_logger):
            logger.debug(
                "Query %d: Fetching data for query object: %s",
                query.id,
                str(query.to_dict()),
            )
            data = db_engine_spec.fetch_data(cursor, query.limit)
    except SoftTimeLimitExceeded as e:
        logger.exception(f"Query {query.id}: {e}")
        raise SqlLabTimeoutException(
            "SQL Lab timeout. This environment's policy is to kill queries "
            "after {} seconds.".format(SQLLAB_TIMEOUT)
        )
    except Exception as e:
        logger.exception(f"Query {query.id}: {e}")
        raise SqlLabException(db_engine_spec.extract_error_message(e))

    logger.debug(f"Query {query.id}: Fetching cursor description")
    cursor_description = cursor.description
    return SupersetDataFrame(data, cursor_description, db_engine_spec)
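# The temporary-table naming scheme used by the CTA branch above, checked in
# isolation (illustrative values: user_id 42, query started 2020-01-01 12:00:00).
from datetime import datetime

start_dttm = datetime(2020, 1, 1, 12, 0, 0)
tmp_name = "tmp_{}_table_{}".format(42, start_dttm.strftime("%Y_%m_%d_%H_%M_%S"))
assert tmp_name == "tmp_42_table_2020_01_01_12_00_00"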
def sql_tables(self) -> List[Table]:
    return list(ParsedQuery(self.sql).tables)
def estimate_statement_cost(  # pylint: disable=too-many-locals
    cls, statement: str, database, cursor, user_name: str
) -> Dict[str, str]:
    """
    Run a SQL query that estimates the cost of a given statement.

    :param statement: A single SQL statement
    :param database: Database instance
    :param cursor: Cursor instance
    :param user_name: Effective username
    :return: Dictionary with human-readable cost estimates
    """
    parsed_query = ParsedQuery(statement)
    sql = parsed_query.stripped()

    sql_query_mutator = config["SQL_QUERY_MUTATOR"]
    if sql_query_mutator:
        sql = sql_query_mutator(sql, user_name, security_manager, database)

    sql = f"EXPLAIN (TYPE IO, FORMAT JSON) {sql}"
    cursor.execute(sql)

    # the output from Presto is a single column and a single row containing
    # JSON:
    #
    #   {
    #     ...
    #     "estimate" : {
    #       "outputRowCount" : 8.73265878E8,
    #       "outputSizeInBytes" : 3.41425774958E11,
    #       "cpuCost" : 3.41425774958E11,
    #       "maxMemory" : 0.0,
    #       "networkCost" : 3.41425774958E11
    #     }
    #   }
    result = json.loads(cursor.fetchone()[0])
    estimate = result["estimate"]

    def humanize(value: Any, suffix: str) -> str:
        try:
            value = int(value)
        except ValueError:
            return str(value)

        prefixes = ["K", "M", "G", "T", "P", "E", "Z", "Y"]
        prefix = ""
        to_next_prefix = 1000
        while value > to_next_prefix and prefixes:
            prefix = prefixes.pop(0)
            value //= to_next_prefix

        return f"{value} {prefix}{suffix}"

    cost = {}
    columns = [
        ("outputRowCount", "Output count", " rows"),
        ("outputSizeInBytes", "Output size", "B"),
        ("cpuCost", "CPU cost", ""),
        ("maxMemory", "Max memory", "B"),
        ("networkCost", "Network cost", ""),
    ]
    for key, label, suffix in columns:
        if key in estimate:
            cost[label] = humanize(estimate[key], suffix)

    return cost
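# A quick sanity check of the humanization logic above, restated as a
# standalone sketch (the helper itself is local to estimate_statement_cost).
# The sample values come from the Presto response shown in the comment.
def _humanize(value, suffix):
    # Same logic as the nested humanize() above, restated for illustration.
    try:
        value = int(value)
    except ValueError:
        return str(value)
    prefixes = ["K", "M", "G", "T", "P", "E", "Z", "Y"]
    prefix = ""
    to_next_prefix = 1000
    while value > to_next_prefix and prefixes:
        prefix = prefixes.pop(0)
        value //= to_next_prefix
    return f"{value} {prefix}{suffix}"

assert _humanize(8.73265878e8, " rows") == "873 M rows"  # outputRowCount
assert _humanize(3.41425774958e11, "B") == "341 GB"      # outputSizeInBytes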
def test_get_query_with_new_limit_comment_with_limit(self):
    sql = "SELECT * FROM birth_names -- SOME COMMENT WITH LIMIT 555"
    parsed = ParsedQuery(sql)
    newsql = parsed.set_or_update_query_limit(1000)
    self.assertEqual(newsql, sql + "\nLIMIT 1000")
def is_readonly(sql: str) -> bool:
    return PrestoEngineSpec.is_readonly_query(ParsedQuery(sql))
def test_update_not_select(self):
    sql = ParsedQuery("UPDATE t1 SET col1 = NULL")
    self.assertEqual(False, sql.is_select())
def execute_sql_statement(
    sql_statement: str,
    query: Query,
    user_name: Optional[str],
    session: Session,
    cursor: Any,
    log_params: Optional[Dict[str, Any]],
    apply_ctas: bool = False,
) -> SupersetResultSet:
    """Executes a single SQL statement"""
    database = query.database
    db_engine_spec = database.db_engine_spec
    parsed_query = ParsedQuery(sql_statement)
    sql = parsed_query.stripped()

    if not db_engine_spec.is_readonly_query(parsed_query) and not database.allow_dml:
        raise SqlLabSecurityException(
            _("Only `SELECT` statements are allowed against this database")
        )
    if apply_ctas:
        if not query.tmp_table_name:
            start_dttm = datetime.fromtimestamp(query.start_time)
            query.tmp_table_name = "tmp_{}_table_{}".format(
                query.user_id, start_dttm.strftime("%Y_%m_%d_%H_%M_%S")
            )
        sql = parsed_query.as_create_table(
            query.tmp_table_name,
            schema_name=query.tmp_schema_name,
            method=query.ctas_method,
        )
        query.select_as_cta_used = True

    # Do not apply limit to the CTA queries when SQLLAB_CTAS_NO_LIMIT is set to true
    if parsed_query.is_select() and not (
        query.select_as_cta_used and SQLLAB_CTAS_NO_LIMIT
    ):
        if SQL_MAX_ROW and (not query.limit or query.limit > SQL_MAX_ROW):
            query.limit = SQL_MAX_ROW
        if query.limit:
            sql = database.apply_limit_to_sql(sql, query.limit)

    # Hook to allow environment-specific mutation (usually comments) to the SQL
    if SQL_QUERY_MUTATOR:
        sql = SQL_QUERY_MUTATOR(sql, user_name, security_manager, database)

    try:
        if log_query:
            log_query(
                query.database.sqlalchemy_uri,
                query.executed_sql,
                query.schema,
                user_name,
                __name__,
                security_manager,
                log_params,
            )
        query.executed_sql = sql
        session.commit()
        with stats_timing("sqllab.query.time_executing_query", stats_logger):
            logger.debug("Query %d: Running query: %s", query.id, sql)
            db_engine_spec.execute(cursor, sql, async_=True)
            logger.debug("Query %d: Handling cursor", query.id)
            db_engine_spec.handle_cursor(cursor, query, session)
        with stats_timing("sqllab.query.time_fetching_results", stats_logger):
            logger.debug(
                "Query %d: Fetching data for query object: %s",
                query.id,
                str(query.to_dict()),
            )
            data = db_engine_spec.fetch_data(cursor, query.limit)
    except SoftTimeLimitExceeded as ex:
        logger.error("Query %d: Time limit exceeded", query.id)
        logger.debug("Query %d: %s", query.id, ex)
        raise SqlLabTimeoutException(
            "SQL Lab timeout. This environment's policy is to kill queries "
            "after {} seconds.".format(SQLLAB_TIMEOUT)
        )
    except Exception as ex:
        logger.error("Query %d: %s", query.id, type(ex))
        logger.debug("Query %d: %s", query.id, ex)
        raise SqlLabException(db_engine_spec.extract_error_message(ex))

    logger.debug("Query %d: Fetching cursor description", query.id)
    cursor_description = cursor.description
    return SupersetResultSet(data, cursor_description, db_engine_spec)
def execute_sql_statement(
    sql_statement: str,
    query: Query,
    user_name: Optional[str],
    session: Session,
    cursor: Any,
    log_params: Optional[Dict[str, Any]],
    apply_ctas: bool = False,
) -> SupersetResultSet:
    """Executes a single SQL statement"""
    database = query.database
    db_engine_spec = database.db_engine_spec
    parsed_query = ParsedQuery(sql_statement)
    sql = parsed_query.stripped()
    # This is a test to see if the query is being
    # limited by either the dropdown or the sql.
    # We are testing to see if more rows exist than the limit.
    increased_limit = None if query.limit is None else query.limit + 1

    if not db_engine_spec.is_readonly_query(parsed_query) and not database.allow_dml:
        raise SupersetErrorException(
            SupersetError(
                message=__("Only SELECT statements are allowed against this database."),
                error_type=SupersetErrorType.DML_NOT_ALLOWED_ERROR,
                level=ErrorLevel.ERROR,
            )
        )
    if apply_ctas:
        if not query.tmp_table_name:
            start_dttm = datetime.fromtimestamp(query.start_time)
            query.tmp_table_name = "tmp_{}_table_{}".format(
                query.user_id, start_dttm.strftime("%Y_%m_%d_%H_%M_%S")
            )
        sql = parsed_query.as_create_table(
            query.tmp_table_name,
            schema_name=query.tmp_schema_name,
            method=query.ctas_method,
        )
        query.select_as_cta_used = True

    # Do not apply limit to the CTA queries when SQLLAB_CTAS_NO_LIMIT is set to true
    if db_engine_spec.is_select_query(parsed_query) and not (
        query.select_as_cta_used and SQLLAB_CTAS_NO_LIMIT
    ):
        if SQL_MAX_ROW and (not query.limit or query.limit > SQL_MAX_ROW):
            query.limit = SQL_MAX_ROW
        if query.limit:
            # We are fetching one more than the requested limit in order
            # to test whether there are more rows than the limit.
            # Later, the extra row will be dropped before sending
            # the results back to the user.
            sql = database.apply_limit_to_sql(sql, increased_limit, force=True)

    # Hook to allow environment-specific mutation (usually comments) to the SQL
    sql = SQL_QUERY_MUTATOR(sql, user_name, security_manager, database)
    try:
        query.executed_sql = sql
        if log_query:
            log_query(
                query.database.sqlalchemy_uri,
                query.executed_sql,
                query.schema,
                user_name,
                __name__,
                security_manager,
                log_params,
            )
        session.commit()
        with stats_timing("sqllab.query.time_executing_query", stats_logger):
            logger.debug("Query %d: Running query: %s", query.id, sql)
            db_engine_spec.execute(cursor, sql, async_=True)
            logger.debug("Query %d: Handling cursor", query.id)
            db_engine_spec.handle_cursor(cursor, query, session)
        with stats_timing("sqllab.query.time_fetching_results", stats_logger):
            logger.debug(
                "Query %d: Fetching data for query object: %s",
                query.id,
                str(query.to_dict()),
            )
            data = db_engine_spec.fetch_data(cursor, increased_limit)
            if query.limit is None or len(data) <= query.limit:
                query.limiting_factor = LimitingFactor.NOT_LIMITED
            else:
                # return 1 row less than increased_limit
                data = data[:-1]
    except SoftTimeLimitExceeded as ex:
        logger.warning("Query %d: Time limit exceeded", query.id)
        logger.debug("Query %d: %s", query.id, ex)
        raise SupersetErrorException(
            SupersetError(
                message=__(
                    f"The query was killed after {SQLLAB_TIMEOUT} seconds. It might "
                    "be too complex, or the database might be under heavy load."
                ),
                error_type=SupersetErrorType.SQLLAB_TIMEOUT_ERROR,
                level=ErrorLevel.ERROR,
            )
        )
    except Exception as ex:
        logger.error("Query %d: %s", query.id, type(ex), exc_info=True)
        logger.debug("Query %d: %s", query.id, ex)
        raise SqlLabException(db_engine_spec.extract_error_message(ex))

    logger.debug("Query %d: Fetching cursor description", query.id)
    cursor_description = cursor.description
    return SupersetResultSet(data, cursor_description, db_engine_spec)
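# The "limit + 1" trick above in miniature: a self-contained sketch of the
# truncation check, not part of the module.
def trim_to_limit(data, limit):
    """Fetch was done with limit + 1; use the extra row to flag truncation."""
    if limit is None or len(data) <= limit:
        return data, False      # everything fit; not limited
    return data[:-1], True      # drop the sentinel row; results were limited

rows, limited = trim_to_limit([1, 2, 3], limit=2)  # cursor fetched limit + 1 rows
assert rows == [1, 2] and limited is True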
def execute_sql_statements(
    ctask,
    query_id,
    rendered_query,
    return_results=True,
    store_results=False,
    user_name=None,
    session=None,
    start_time=None,
):
    """Executes the SQL query and returns the results."""
    if store_results and start_time:
        # only asynchronous queries
        stats_logger.timing("sqllab.query.time_pending", now_as_float() - start_time)

    query = get_query(query_id, session)
    payload = dict(query_id=query_id)
    database = query.database
    db_engine_spec = database.db_engine_spec
    db_engine_spec.patch()

    if store_results and not results_backend:
        raise SqlLabException("Results backend isn't configured.")

    # Breaking down into multiple statements
    parsed_query = ParsedQuery(rendered_query)
    statements = parsed_query.get_statements()
    logging.info(f"Executing {len(statements)} statement(s)")

    logging.info("Set query to 'running'")
    query.status = QueryStatus.RUNNING
    query.start_running_time = now_as_float()

    engine = database.get_sqla_engine(
        schema=query.schema,
        nullpool=True,
        user_name=user_name,
        source=sources.get("sql_lab", None),
    )
    # Sharing a single connection and cursor across the
    # execution of all statements (if many)
    with closing(engine.raw_connection()) as conn:
        with closing(conn.cursor()) as cursor:
            query.connection_id = db_engine_spec.get_connection_id(cursor)
            session.commit()
            statement_count = len(statements)
            for i, statement in enumerate(statements):
                # check if the query was stopped
                session.refresh(query)
                if query.status == QueryStatus.STOPPED:
                    payload.update({"status": query.status})
                    return payload

                msg = f"Running statement {i+1} out of {statement_count}"
                logging.info(msg)
                query.set_extra_json_key("progress", msg)
                session.commit()
                try:
                    cdf = execute_sql_statement(
                        statement, query, user_name, session, cursor
                    )
                    msg = f"Running statement {i+1} out of {statement_count}"
                except Exception as e:
                    # The query can be stopped in another thread/worker, but in
                    # synchronized mode that may surface as an error here; skip
                    # the error in such a case.
                    session.refresh(query)
                    if query.status == QueryStatus.STOPPED:
                        payload.update({"status": query.status})
                        return payload
                    msg = str(e)
                    if statement_count > 1:
                        msg = f"[Statement {i+1} out of {statement_count}] " + msg
                    payload = handle_query_error(msg, query, session, payload)
                    return payload

    # Success, updating the query entry in database
    query.rows = cdf.size
    query.progress = 100
    query.set_extra_json_key("progress", None)
    query.connection_id = None
    if query.select_as_cta:
        query.select_sql = database.select_star(
            query.tmp_table_name,
            limit=query.limit,
            schema=database.force_ctas_schema,
            show_cols=False,
            latest_partition=False,
        )
    query.end_time = now_as_float()

    selected_columns = cdf.columns or []
    data = cdf.data or []
    all_columns, data, expanded_columns = db_engine_spec.expand_data(
        selected_columns, data
    )

    payload.update(
        {
            "status": QueryStatus.SUCCESS,
            "data": data,
            "columns": all_columns,
            "selected_columns": selected_columns,
            "expanded_columns": expanded_columns,
            "query": query.to_dict(),
        }
    )
    payload["query"]["state"] = QueryStatus.SUCCESS

    # go over each row, find bytes columns that start with the magic UAST
    # sequence b'\x00bgr', and replace it with a string containing the
    # UAST in JSON
    for row in payload["data"]:
        for k, v in row.items():
            if isinstance(v, bytes) and len(v) > 4 and v[0:4] == b"\x00bgr":
                try:
                    ctx = decode(v, format=0)
                    row[k] = json.dumps(ctx.load())
                except Exception:
                    pass

    if store_results:
        key = str(uuid.uuid4())
        logging.info(f"Storing results in results backend, key: {key}")
        with stats_timing("sqllab.query.results_backend_write", stats_logger):
            json_payload = json.dumps(
                payload, default=json_iso_dttm_ser, ignore_nan=True
            )
            cache_timeout = database.cache_timeout
            if cache_timeout is None:
                cache_timeout = config.get("CACHE_DEFAULT_TIMEOUT", 0)
            results_backend.set(key, zlib_compress(json_payload), cache_timeout)
        query.results_key = key

    query.status = QueryStatus.SUCCESS
    session.commit()

    if return_results:
        return payload
def validate_statement(
    cls,
    statement,
    database,
    cursor,
    user_name,
) -> Optional[SQLValidationAnnotation]:
    # pylint: disable=too-many-locals
    db_engine_spec = database.db_engine_spec
    parsed_query = ParsedQuery(statement)
    sql = parsed_query.stripped()

    # Hook to allow environment-specific mutation (usually comments) to the SQL
    # pylint: disable=invalid-name
    SQL_QUERY_MUTATOR = config.get('SQL_QUERY_MUTATOR')
    if SQL_QUERY_MUTATOR:
        sql = SQL_QUERY_MUTATOR(sql, user_name, security_manager, database)

    # Transform the final statement to an explain call before sending it on
    # to presto to validate
    sql = f'EXPLAIN (TYPE VALIDATE) {sql}'

    # Invoke the query against presto. NB this deliberately doesn't use the
    # engine spec's handle_cursor implementation since we don't record
    # these EXPLAIN queries done in validation as proper Query objects
    # in the superset ORM.
    from pyhive.exc import DatabaseError
    try:
        db_engine_spec.execute(cursor, sql)
        polled = cursor.poll()
        while polled:
            logging.info('polling presto for validation progress')
            stats = polled.get('stats', {})
            if stats:
                state = stats.get('state')
                if state == 'FINISHED':
                    break
            time.sleep(0.2)
            polled = cursor.poll()
        db_engine_spec.fetch_data(cursor, MAX_ERROR_ROWS)
        return None
    except DatabaseError as db_error:
        # The pyhive presto client yields EXPLAIN (TYPE VALIDATE) responses
        # as though they were normal queries. In other words, it doesn't
        # know that errors here are not exceptional. To map this back to
        # ordinary control flow, we have to trap the category of exception
        # raised by the underlying client, match the exception arguments
        # pyhive provides against the shape of dictionary for a presto query
        # invalid error, and restructure that error as an annotation we can
        # return up.

        # Confirm the first element in the DatabaseError constructor is a
        # dictionary with error information. This is currently provided by
        # the pyhive client, but may break if their interface changes when
        # we update at some point in the future.
        if not db_error.args or not isinstance(db_error.args[0], dict):
            raise PrestoSQLValidationError(
                'The pyhive presto client returned an unhandled '
                'database error.',
            ) from db_error
        error_args: Dict[str, Any] = db_error.args[0]

        # Confirm the two fields we need to be able to present an annotation
        # are present in the error response -- a message, and a location.
        if 'message' not in error_args:
            raise PrestoSQLValidationError(
                'The pyhive presto client did not report an error message',
            ) from db_error
        if 'errorLocation' not in error_args:
            raise PrestoSQLValidationError(
                'The pyhive presto client did not report an error location',
            ) from db_error

        # Pylint is confused about the type of error_args, despite the hints
        # and checks above.
        # pylint: disable=invalid-sequence-index
        message = error_args['message']
        err_loc = error_args['errorLocation']
        line_number = err_loc.get('lineNumber', None)
        start_column = err_loc.get('columnNumber', None)
        end_column = err_loc.get('columnNumber', None)

        return SQLValidationAnnotation(
            message=message,
            line_number=line_number,
            start_column=start_column,
            end_column=end_column,
        )
    except Exception as e:
        logging.exception(f'Unexpected error running validation query: {e}')
        raise e
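# Roughly the shape of the DatabaseError payload the handler above expects
# from pyhive -- an illustrative example, not a captured server response.
example_error_args = {
    "message": "line 1:8: Column 'nope' cannot be resolved",
    "errorLocation": {"lineNumber": 1, "columnNumber": 8},
}
# validate_statement() would map this to SQLValidationAnnotation(
#     message=..., line_number=1, start_column=8, end_column=8).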
def test_explain(self):
    sql = ParsedQuery("EXPLAIN SELECT 1")
    self.assertEqual(True, sql.is_explain())
    self.assertEqual(False, sql.is_select())
def test_update() -> None:
    """
    Test that ``UPDATE`` is not detected as ``SELECT``.
    """
    assert ParsedQuery("UPDATE t1 SET col1 = NULL").is_select() is False
def execute_sql_statements(  # pylint: disable=too-many-arguments, too-many-locals, too-many-statements, too-many-branches
    query_id: int,
    rendered_query: str,
    return_results: bool,
    store_results: bool,
    user_name: Optional[str],
    session: Session,
    start_time: Optional[float],
    expand_data: bool,
    log_params: Optional[Dict[str, Any]],
) -> Optional[Dict[str, Any]]:
    """Executes the SQL query and returns the results."""
    if store_results and start_time:
        # only asynchronous queries
        stats_logger.timing("sqllab.query.time_pending", now_as_float() - start_time)

    query = get_query(query_id, session)
    payload: Dict[str, Any] = dict(query_id=query_id)
    database = query.database
    db_engine_spec = database.db_engine_spec
    db_engine_spec.patch()

    if database.allow_run_async and not results_backend:
        raise SupersetErrorException(
            SupersetError(
                message=__("Results backend is not configured."),
                error_type=SupersetErrorType.RESULTS_BACKEND_NOT_CONFIGURED_ERROR,
                level=ErrorLevel.ERROR,
            )
        )

    # Breaking down into multiple statements
    parsed_query = ParsedQuery(rendered_query, strip_comments=True)
    if not db_engine_spec.run_multiple_statements_as_one:
        statements = parsed_query.get_statements()
        logger.info(
            "Query %s: Executing %i statement(s)", str(query_id), len(statements)
        )
    else:
        statements = [rendered_query]
        logger.info("Query %s: Executing query as a single statement", str(query_id))

    logger.info("Query %s: Set query to 'running'", str(query_id))
    query.status = QueryStatus.RUNNING
    query.start_running_time = now_as_float()
    session.commit()

    # Should we create a table or view from the select?
    if (
        query.select_as_cta
        and query.ctas_method == CtasMethod.TABLE
        and not parsed_query.is_valid_ctas()
    ):
        raise SupersetErrorException(
            SupersetError(
                message=__(
                    "CTAS (create table as select) can only be run with a query where "
                    "the last statement is a SELECT. Please make sure your query has "
                    "a SELECT as its last statement. Then, try running your query "
                    "again."
                ),
                error_type=SupersetErrorType.INVALID_CTAS_QUERY_ERROR,
                level=ErrorLevel.ERROR,
            )
        )
    if (
        query.select_as_cta
        and query.ctas_method == CtasMethod.VIEW
        and not parsed_query.is_valid_cvas()
    ):
        raise SupersetErrorException(
            SupersetError(
                message=__(
                    "CVAS (create view as select) can only be run with a query with "
                    "a single SELECT statement. Please make sure your query has only "
                    "a SELECT statement. Then, try running your query again."
                ),
                error_type=SupersetErrorType.INVALID_CVAS_QUERY_ERROR,
                level=ErrorLevel.ERROR,
            )
        )

    engine = database.get_sqla_engine(
        schema=query.schema,
        nullpool=True,
        user_name=user_name,
        source=QuerySource.SQL_LAB,
    )
    # Sharing a single connection and cursor across the
    # execution of all statements (if many)
    with closing(engine.raw_connection()) as conn:
        # closing the connection closes the cursor as well
        cursor = conn.cursor()
        statement_count = len(statements)
        for i, statement in enumerate(statements):
            # Check if stopped
            query = get_query(query_id, session)
            if query.status == QueryStatus.STOPPED:
                return None

            # For CTAS we create the table only on the last statement
            apply_ctas = query.select_as_cta and (
                query.ctas_method == CtasMethod.VIEW
                or (query.ctas_method == CtasMethod.TABLE and i == len(statements) - 1)
            )

            # Run statement
            msg = f"Running statement {i+1} out of {statement_count}"
            logger.info("Query %s: %s", str(query_id), msg)
            query.set_extra_json_key("progress", msg)
            session.commit()
            try:
                result_set = execute_sql_statement(
                    statement,
                    query,
                    user_name,
                    session,
                    cursor,
                    log_params,
                    apply_ctas,
                )
            except Exception as ex:  # pylint: disable=broad-except
                msg = str(ex)
                prefix_message = (
                    f"[Statement {i+1} out of {statement_count}]"
                    if statement_count > 1
                    else ""
                )
                payload = handle_query_error(ex, query, session, payload, prefix_message)
                return payload

        # Commit the connection so CTA queries will create the table.
        conn.commit()

    # Success, updating the query entry in database
    query.rows = result_set.size
    query.progress = 100
    query.set_extra_json_key("progress", None)
    if query.select_as_cta:
        query.select_sql = database.select_star(
            query.tmp_table_name,
            schema=query.tmp_schema_name,
            limit=query.limit,
            show_cols=False,
            latest_partition=False,
        )
    query.end_time = now_as_float()

    use_arrow_data = store_results and cast(bool, results_backend_use_msgpack)
    data, selected_columns, all_columns, expanded_columns = _serialize_and_expand_data(
        result_set, db_engine_spec, use_arrow_data, expand_data
    )

    # TODO: data should be saved separately from metadata (likely in Parquet)
    payload.update(
        {
            "status": QueryStatus.SUCCESS,
            "data": data,
            "columns": all_columns,
            "selected_columns": selected_columns,
            "expanded_columns": expanded_columns,
            "query": query.to_dict(),
        }
    )
    payload["query"]["state"] = QueryStatus.SUCCESS

    if store_results and results_backend:
        key = str(uuid.uuid4())
        logger.info(
            "Query %s: Storing results in results backend, key: %s", str(query_id), key
        )
        with stats_timing("sqllab.query.results_backend_write", stats_logger):
            with stats_timing(
                "sqllab.query.results_backend_write_serialization", stats_logger
            ):
                serialized_payload = _serialize_payload(
                    payload, cast(bool, results_backend_use_msgpack)
                )
            cache_timeout = database.cache_timeout
            if cache_timeout is None:
                cache_timeout = config["CACHE_DEFAULT_TIMEOUT"]

            compressed = zlib_compress(serialized_payload)
            logger.debug(
                "*** serialized payload size: %i", getsizeof(serialized_payload)
            )
            logger.debug("*** compressed payload size: %i", getsizeof(compressed))
            results_backend.set(key, compressed, cache_timeout)
        query.results_key = key

    query.status = QueryStatus.SUCCESS
    session.commit()

    if return_results:
        # since we're returning results we need to create non-arrow data
        if use_arrow_data:
            (
                data,
                selected_columns,
                all_columns,
                expanded_columns,
            ) = _serialize_and_expand_data(
                result_set, db_engine_spec, False, expand_data
            )
            payload.update(
                {
                    "data": data,
                    "columns": all_columns,
                    "selected_columns": selected_columns,
                    "expanded_columns": expanded_columns,
                }
            )
        return payload

    return None
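# The distinction enforced by the two checks above, in miniature (a sketch;
# per the error messages, CTAS needs only the *last* statement to be a
# SELECT, while CVAS requires a single SELECT statement).
multi = ParsedQuery("SET @value = 42; SELECT @value AS foo;")
assert multi.is_valid_ctas()      # last statement is a SELECT
assert not multi.is_valid_cvas()  # more than one statement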
def execute_sql_statements(  # pylint: disable=too-many-arguments, too-many-locals, too-many-statements
    query_id: int,
    rendered_query: str,
    return_results: bool,
    store_results: bool,
    user_name: Optional[str],
    session: Session,
    start_time: Optional[float],
    expand_data: bool,
    log_params: Optional[Dict[str, Any]],
) -> Optional[Dict[str, Any]]:
    """Executes the SQL query and returns the results."""
    if store_results and start_time:
        # only asynchronous queries
        stats_logger.timing("sqllab.query.time_pending", now_as_float() - start_time)

    query = get_query(query_id, session)
    payload: Dict[str, Any] = dict(query_id=query_id)
    database = query.database
    db_engine_spec = database.db_engine_spec
    db_engine_spec.patch()

    if database.allow_run_async and not results_backend:
        raise SqlLabException("Results backend isn't configured.")

    # Breaking down into multiple statements
    parsed_query = ParsedQuery(rendered_query)
    statements = parsed_query.get_statements()
    logger.info("Query %s: Executing %i statement(s)", str(query_id), len(statements))

    logger.info("Query %s: Set query to 'running'", str(query_id))
    query.status = QueryStatus.RUNNING
    query.start_running_time = now_as_float()
    session.commit()

    engine = database.get_sqla_engine(
        schema=query.schema,
        nullpool=True,
        user_name=user_name,
        source=QuerySource.SQL_LAB,
    )
    # Sharing a single connection and cursor across the
    # execution of all statements (if many)
    with closing(engine.raw_connection()) as conn:
        with closing(conn.cursor()) as cursor:
            statement_count = len(statements)
            for i, statement in enumerate(statements):
                # Check if stopped
                query = get_query(query_id, session)
                if query.status == QueryStatus.STOPPED:
                    return None

                # Run statement
                msg = f"Running statement {i+1} out of {statement_count}"
                logger.info("Query %s: %s", str(query_id), msg)
                query.set_extra_json_key("progress", msg)
                session.commit()
                try:
                    result_set = execute_sql_statement(
                        statement, query, user_name, session, cursor, log_params
                    )
                except Exception as ex:  # pylint: disable=broad-except
                    msg = str(ex)
                    if statement_count > 1:
                        msg = f"[Statement {i+1} out of {statement_count}] " + msg
                    payload = handle_query_error(msg, query, session, payload)
                    return payload

            # Commit the connection so CTA queries will create the table.
            conn.commit()

    # Success, updating the query entry in database
    query.rows = result_set.size
    query.progress = 100
    query.set_extra_json_key("progress", None)
    if query.select_as_cta:
        query.select_sql = database.select_star(
            query.tmp_table_name,
            schema=query.tmp_schema_name,
            limit=query.limit,
            show_cols=False,
            latest_partition=False,
        )
    query.end_time = now_as_float()

    use_arrow_data = store_results and cast(bool, results_backend_use_msgpack)
    data, selected_columns, all_columns, expanded_columns = _serialize_and_expand_data(
        result_set, db_engine_spec, use_arrow_data, expand_data
    )

    # TODO: data should be saved separately from metadata (likely in Parquet)
    payload.update(
        {
            "status": QueryStatus.SUCCESS,
            "data": data,
            "columns": all_columns,
            "selected_columns": selected_columns,
            "expanded_columns": expanded_columns,
            "query": query.to_dict(),
        }
    )
    payload["query"]["state"] = QueryStatus.SUCCESS

    if store_results and results_backend:
        key = str(uuid.uuid4())
        logger.info(
            "Query %s: Storing results in results backend, key: %s", str(query_id), key
        )
        with stats_timing("sqllab.query.results_backend_write", stats_logger):
            with stats_timing(
                "sqllab.query.results_backend_write_serialization", stats_logger
            ):
                serialized_payload = _serialize_payload(
                    payload, cast(bool, results_backend_use_msgpack)
                )
            cache_timeout = database.cache_timeout
            if cache_timeout is None:
                cache_timeout = config["CACHE_DEFAULT_TIMEOUT"]

            compressed = zlib_compress(serialized_payload)
            logger.debug(
                "*** serialized payload size: %i", getsizeof(serialized_payload)
            )
            logger.debug("*** compressed payload size: %i", getsizeof(compressed))
            results_backend.set(key, compressed, cache_timeout)
        query.results_key = key

    query.status = QueryStatus.SUCCESS
    session.commit()

    if return_results:
        # since we're returning results we need to create non-arrow data
        if use_arrow_data:
            (
                data,
                selected_columns,
                all_columns,
                expanded_columns,
            ) = _serialize_and_expand_data(result_set, db_engine_spec, False, expand_data)
            payload.update(
                {
                    "data": data,
                    "columns": all_columns,
                    "selected_columns": selected_columns,
                    "expanded_columns": expanded_columns,
                }
            )
        return payload

    return None
def get_table_metadata(
    database: Database, table_name: str, schema_name: Optional[str]
) -> Dict:
    """
    Get table metadata information, including type, pk, fks.
    This function raises SQLAlchemyError when a schema is not found.

    :param database: The database model
    :param table_name: Table name
    :param schema_name: Schema name
    :return: Dict table metadata ready for API response
    """
    keys: List = []
    columns = database.get_columns(table_name, schema_name)
    # maps column name -> column comment
    comment_dict = {}
    primary_key = database.get_pk_constraint(table_name, schema_name)
    if primary_key and primary_key.get("constrained_columns"):
        primary_key["column_names"] = primary_key.pop("constrained_columns")
        primary_key["type"] = "pk"
        keys += [primary_key]

    # get dialect name
    dialect_name = database.get_dialect().name
    if isinstance(dialect_name, bytes):
        dialect_name = dialect_name.decode()

    # fetch column comments via DESC for presto & hive
    if dialect_name == "presto" or dialect_name == "hive":
        db_engine_spec = database.db_engine_spec
        sql = ParsedQuery("desc {a}.{b}".format(a=schema_name, b=table_name)).stripped()
        engine = database.get_sqla_engine(schema_name)
        conn = engine.raw_connection()
        cursor = conn.cursor()
        query = Query()
        session = Session(bind=engine)
        query.executed_sql = sql
        query.__tablename__ = table_name
        session.commit()
        db_engine_spec.execute(cursor, sql, async_=False)
        data = db_engine_spec.fetch_data(cursor, query.limit)
        # hive and presto return the comment in different result columns
        if dialect_name == "presto":
            for d in data:
                comment_dict[d[0]] = d[3]
        else:
            for d in data:
                comment_dict[d[0]] = d[2]
        conn.commit()

    foreign_keys = get_foreign_keys_metadata(database, table_name, schema_name)
    indexes = get_indexes_metadata(database, table_name, schema_name)
    keys += foreign_keys + indexes
    payload_columns: List[Dict] = []
    for col in columns:
        dtype = get_col_type(col)
        if len(comment_dict) > 0:
            payload_columns.append(
                {
                    "name": col["name"],
                    "type": dtype.split("(")[0] if "(" in dtype else dtype,
                    "longType": dtype,
                    "keys": [k for k in keys if col["name"] in k.get("column_names")],
                    "comment": comment_dict[col["name"]],
                }
            )
        elif dialect_name == "mysql":
            payload_columns.append(
                {
                    "name": col["name"],
                    "type": dtype.split("(")[0] if "(" in dtype else dtype,
                    "longType": dtype,
                    "keys": [k for k in keys if col["name"] in k.get("column_names")],
                    "comment": col["comment"],
                }
            )
        else:
            payload_columns.append(
                {
                    "name": col["name"],
                    "type": dtype.split("(")[0] if "(" in dtype else dtype,
                    "longType": dtype,
                    "keys": [k for k in keys if col["name"] in k.get("column_names")],
                }
            )
    return {
        "name": table_name,
        "columns": payload_columns,
        "selectStar": database.select_star(
            table_name,
            schema=schema_name,
            show_cols=True,
            indent=True,
            cols=columns,
            latest_partition=True,
        ),
        "primaryKey": primary_key,
        "foreignKeys": foreign_keys,
        "indexes": keys,
    }
def extract_tables(query: str) -> Set[Table]:
    """
    Helper function to extract tables referenced in a query.
    """
    return ParsedQuery(query).tables
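# Example usage of the helper above (assuming, as elsewhere in Superset,
# that ParsedQuery.tables yields Table named tuples):
assert extract_tables("SELECT t1.*, t2.* FROM t1 JOIN t2 ON t1.a = t2.a") == {
    Table("t1"),
    Table("t2"),
}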