def handle_query_error(
    msg: str,
    query: Query,
    session: Session,
    payload: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
    """Local method handling error while processing the SQL"""
    payload = payload or {}
    troubleshooting_link = config["TROUBLESHOOTING_LINK"]
    query.error_message = msg
    query.status = QueryStatus.FAILED
    query.tmp_table_name = None

    # extract DB-specific errors (invalid column, eg)
    errors = [
        dataclasses.asdict(error)
        for error in query.database.db_engine_spec.extract_errors(msg)
    ]
    if errors:
        query.set_extra_json_key("errors", errors)

    session.commit()
    payload.update({"status": query.status, "error": msg, "errors": errors})
    if troubleshooting_link:
        payload["link"] = troubleshooting_link
    return payload
def handle_query_error(
    ex: Exception,
    query: Query,
    session: Session,
    payload: Optional[Dict[str, Any]] = None,
    prefix_message: str = "",
) -> Dict[str, Any]:
    """Local method handling error while processing the SQL"""
    payload = payload or {}
    msg = f"{prefix_message} {str(ex)}".strip()
    troubleshooting_link = config["TROUBLESHOOTING_LINK"]
    query.error_message = msg
    query.status = QueryStatus.FAILED
    query.tmp_table_name = None

    # extract DB-specific errors (invalid column, eg)
    if isinstance(ex, SupersetErrorException):
        errors = [ex.error]
    elif isinstance(ex, SupersetErrorsException):
        errors = ex.errors
    else:
        errors = query.database.db_engine_spec.extract_errors(str(ex))

    errors_payload = [dataclasses.asdict(error) for error in errors]
    if errors:
        query.set_extra_json_key("errors", errors_payload)

    session.commit()
    payload.update({"status": query.status, "error": msg, "errors": errors_payload})
    if troubleshooting_link:
        payload["link"] = troubleshooting_link
    return payload
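# Hedged usage sketch (assumption, not part of the original snippets): how the
# exception-based handle_query_error above is typically reached -- from the
# except-block that wraps statement execution. `run_statement` is a hypothetical
# helper; `cursor` is any DB-API cursor.
def run_statement(cursor: Any, sql: str, query: Query, session: Session) -> Dict[str, Any]:
    payload: Dict[str, Any] = {}
    try:
        cursor.execute(sql)
        payload["data"] = cursor.fetchall()
    except Exception as ex:  # pylint: disable=broad-except
        # Marks the query FAILED, extracts engine-specific errors, commits,
        # and returns the enriched error payload.
        return handle_query_error(ex, query, session, payload)
    payload["status"] = QueryStatus.SUCCESS
    return payload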
def create_query(self) -> Query:
    # pylint: disable=line-too-long
    start_time = now_as_float()
    if self.select_as_cta:
        return Query(
            database_id=self.database_id,
            sql=self.sql,
            schema=self.schema,
            select_as_cta=True,
            ctas_method=self.create_table_as_select.ctas_method,  # type: ignore
            start_time=start_time,
            tab_name=self.tab_name,
            status=self.status,
            limit=self.limit,
            sql_editor_id=self.sql_editor_id,
            tmp_table_name=self.create_table_as_select.target_table_name,  # type: ignore
            tmp_schema_name=self.create_table_as_select.target_schema_name,  # type: ignore
            user_id=self.user_id,
            client_id=self.client_id_or_short_id,
        )
    return Query(
        database_id=self.database_id,
        sql=self.sql,
        schema=self.schema,
        select_as_cta=False,
        start_time=start_time,
        tab_name=self.tab_name,
        limit=self.limit,
        status=self.status,
        sql_editor_id=self.sql_editor_id,
        user_id=self.user_id,
        client_id=self.client_id_or_short_id,
    )
def handle_cursor(cls, cursor: Cursor, query: Query, session: Session) -> None:
    tracking_url = cls.get_tracking_url(cursor)
    if tracking_url:
        query.tracking_url = tracking_url

    # Adds the executed query id to the extra payload so the query can be cancelled
    query.set_extra_json_key("cancel_query", cursor.stats["queryId"])

    session.commit()
    BaseEngineSpec.handle_cursor(cursor=cursor, query=query, session=session)
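# Hedged sketch (assumption, not original code): the "cancel_query" key stored
# above is how a later cancellation request recovers the engine-side query id.
# `stop_trino_query` is a hypothetical caller; cancel_query itself is exercised
# by the Trino tests further down in this collection.
def stop_trino_query(query: Query, cursor: Any) -> bool:
    cancel_query_id = query.extra.get("cancel_query")
    if cancel_query_id is None:
        return False
    return TrinoEngineSpec.cancel_query(cursor, query, cancel_query_id)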
def handle_cursor( # pylint: disable=too-many-locals cls, cursor: Any, query: Query, session: Session ) -> None: """Updates progress information""" from pyhive import hive # pylint: disable=no-name-in-module unfinished_states = ( hive.ttypes.TOperationState.INITIALIZED_STATE, hive.ttypes.TOperationState.RUNNING_STATE, ) polled = cursor.poll() last_log_line = 0 tracking_url = None job_id = None query_id = query.id while polled.operationState in unfinished_states: query = session.query(type(query)).filter_by(id=query_id).one() if query.status == QueryStatus.STOPPED: cursor.cancel() break log = cursor.fetch_logs() or "" if log: log_lines = log.splitlines() progress = cls.progress(log_lines) logger.info(f"Query {query_id}: Progress total: {progress}") needs_commit = False if progress > query.progress: query.progress = progress needs_commit = True if not tracking_url: tracking_url = cls.get_tracking_url(log_lines) if tracking_url: job_id = tracking_url.split("/")[-2] logger.info( f"Query {query_id}: Found the tracking url: {tracking_url}" ) tracking_url = tracking_url_trans(tracking_url) logger.info( f"Query {query_id}: Transformation applied: {tracking_url}" ) query.tracking_url = tracking_url logger.info(f"Query {query_id}: Job id: {job_id}") needs_commit = True if job_id and len(log_lines) > last_log_line: # Wait for job id before logging things out # this allows for prefixing all log lines and becoming # searchable in something like Kibana for l in log_lines[last_log_line:]: logger.info(f"Query {query_id}: [{job_id}] {l}") last_log_line = len(log_lines) if needs_commit: session.commit() time.sleep(hive_poll_interval) polled = cursor.poll()
def handle_query_error(
    msg: str, query: Query, session: Session, payload: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
    """Local method handling error while processing the SQL"""
    payload = payload or {}
    troubleshooting_link = config["TROUBLESHOOTING_LINK"]
    query.error_message = msg
    query.status = QueryStatus.FAILED
    query.tmp_table_name = None
    session.commit()
    payload.update({"status": query.status, "error": msg})
    if troubleshooting_link:
        payload["link"] = troubleshooting_link
    return payload
def insert_query(
    self,
    database_id: int,
    user_id: int,
    client_id: str,
    sql: str = "",
    select_sql: str = "",
    executed_sql: str = "",
    limit: int = 100,
    progress: int = 100,
    rows: int = 100,
    tab_name: str = "",
    status: str = "success",
) -> Query:
    database = db.session.query(Database).get(database_id)
    user = db.session.query(security_manager.user_model).get(user_id)
    query = Query(
        database=database,
        user=user,
        client_id=client_id,
        sql=sql,
        select_sql=select_sql,
        executed_sql=executed_sql,
        limit=limit,
        progress=progress,
        rows=rows,
        tab_name=tab_name,
        status=status,
    )
    db.session.add(query)
    db.session.commit()
    return query
def test_cancel_query_success(engine_mock: mock.Mock) -> None:
    from superset.db_engine_specs.trino import TrinoEngineSpec
    from superset.models.sql_lab import Query

    query = Query()
    cursor_mock = engine_mock.return_value.__enter__.return_value
    assert TrinoEngineSpec.cancel_query(cursor_mock, query, "123") is True
def test_cancel_query_failed(engine_mock: mock.Mock) -> None:
    from superset.db_engine_specs.trino import TrinoEngineSpec
    from superset.models.sql_lab import Query

    query = Query()
    cursor_mock = engine_mock.raiseError.side_effect = Exception()
    assert TrinoEngineSpec.cancel_query(cursor_mock, query, "123") is False
def test_query_dao_save_metadata(session: Session) -> None:
    from superset.models.core import Database
    from superset.models.sql_lab import Query

    engine = session.get_bind()
    Query.metadata.create_all(engine)  # pylint: disable=no-member

    db = Database(database_name="my_database", sqlalchemy_uri="sqlite://")

    query_obj = Query(
        client_id="foo",
        database=db,
        tab_name="test_tab",
        sql_editor_id="test_editor_id",
        sql="select * from bar",
        select_sql="select * from bar",
        executed_sql="select * from bar",
        limit=100,
        select_as_cta=False,
        rows=100,
        error_message="none",
        results_key="abc",
    )

    session.add(db)
    session.add(query_obj)

    from superset.queries.dao import QueryDAO

    query = session.query(Query).one()
    QueryDAO.save_metadata(query=query, payload={"columns": []})
    assert query.extra.get("columns", None) == []
def test_get_cancel_query_id(engine_mock: mock.Mock) -> None:
    from superset.db_engine_specs.snowflake import SnowflakeEngineSpec
    from superset.models.sql_lab import Query

    query = Query()
    cursor_mock = engine_mock.return_value.__enter__.return_value
    cursor_mock.fetchone.return_value = [123]
    assert SnowflakeEngineSpec.get_cancel_query_id(cursor_mock, query) == 123
def example_query(get_or_create_user: Callable[..., ContextManager[ab_models.User]]):
    with get_or_create_user("sqllab-test-user") as user:
        query = Query(
            client_id=shortid()[:10],
            database=get_example_database(),
            user=user,
        )
        db.session.add(query)
        db.session.commit()
        yield query
        db.session.delete(query)
        db.session.commit()
def test_query_has_access(mocker: MockFixture) -> None:
    from superset.explore.utils import check_datasource_access
    from superset.models.sql_lab import Query

    mocker.patch(query_find_by_id, return_value=Query())
    mocker.patch(raise_for_access, return_value=True)
    mocker.patch(is_admin, return_value=False)
    mocker.patch(is_owner, return_value=False)
    mocker.patch(can_access, return_value=True)

    assert (
        check_datasource_access(
            datasource_id=1,
            datasource_type=DatasourceType.QUERY,
        )
        == True
    )
def create_query(self):
    with self.create_app().app_context():
        session = db.session
        query = Query(
            sql="select 1 as foo;",
            client_id="sldkfjlk",
            database=get_example_database(),
        )
        session.add(query)
        session.commit()

        yield query

        # rollback
        session.delete(query)
        session.commit()
def test_query_no_access(mocker: MockFixture, app_context: AppContext) -> None:
    from superset.connectors.sqla.models import SqlaTable
    from superset.explore.utils import check_datasource_access
    from superset.models.core import Database
    from superset.models.sql_lab import Query

    with raises(SupersetSecurityException):
        mocker.patch(
            query_find_by_id,
            return_value=Query(database=Database(), sql="select * from foo"),
        )
        mocker.patch(query_datasources_by_name, return_value=[SqlaTable()])
        mocker.patch(is_user_admin, return_value=False)
        mocker.patch(is_owner, return_value=False)
        mocker.patch(can_access, return_value=False)
        check_datasource_access(
            datasource_id=1,
            datasource_type=DatasourceType.QUERY,
        )
def handle_cursor(cls, cursor: Any, query: Query, session: Session) -> None:
    """Updates progress information"""
    query_id = query.id
    poll_interval = query.database.connect_args.get(
        "poll_interval", current_app.config["PRESTO_POLL_INTERVAL"]
    )
    logger.info("Query %i: Polling the cursor for progress", query_id)
    polled = cursor.poll()
    # poll returns dict -- JSON status information or ``None``
    # if the query is done
    # https://github.com/dropbox/PyHive/blob/
    # b34bdbf51378b3979eaf5eca9e956f06ddc36ca0/pyhive/presto.py#L178
    while polled:
        # Update the object and wait for the kill signal.
        stats = polled.get("stats", {})

        query = session.query(type(query)).filter_by(id=query_id).one()
        if query.status in [QueryStatus.STOPPED, QueryStatus.TIMED_OUT]:
            cursor.cancel()
            break

        if stats:
            state = stats.get("state")

            # if already finished, then stop polling
            if state == "FINISHED":
                break

            completed_splits = float(stats.get("completedSplits"))
            total_splits = float(stats.get("totalSplits"))
            if total_splits and completed_splits:
                progress = 100 * (completed_splits / total_splits)
                logger.info(
                    "Query {} progress: {} / {} "  # pylint: disable=logging-format-interpolation
                    "splits".format(query_id, completed_splits, total_splits)
                )
                if progress > query.progress:
                    query.progress = progress
                session.commit()
        time.sleep(poll_interval)
        logger.info("Query %i: Polling the cursor for progress", query_id)
        polled = cursor.poll()
def save_metadata(query: Query, payload: Dict[str, Any]) -> None:
    # pull relevant data from payload and store in extra_json
    columns = payload.get("columns", {})
    db.session.add(query)
    query.set_extra_json_key("columns", columns)
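# Hedged usage sketch (assumption, mirrors the QueryDAO test earlier in this
# collection): persist result-set column metadata on a Query via save_metadata and
# read it back from extra_json. `remember_result_columns` is a hypothetical helper.
def remember_result_columns(query: Query, columns: List[Dict[str, Any]]) -> None:
    QueryDAO.save_metadata(query=query, payload={"columns": columns})
    # set_extra_json_key writes under the "columns" key of query.extra_json,
    # which query.extra exposes as a dict.
    assert query.extra.get("columns") == columns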
def test_sql_lab_insert_rls( mocker: MockerFixture, session: Session, app_context: None, ) -> None: """ Integration test for `insert_rls`. """ from flask_appbuilder.security.sqla.models import Role, User from superset.connectors.sqla.models import RowLevelSecurityFilter, SqlaTable from superset.models.core import Database from superset.models.sql_lab import Query from superset.security.manager import SupersetSecurityManager from superset.sql_lab import execute_sql_statement from superset.utils.core import RowLevelSecurityFilterType engine = session.connection().engine Query.metadata.create_all(engine) # pylint: disable=no-member connection = engine.raw_connection() connection.execute("CREATE TABLE t (c INTEGER)") for i in range(10): connection.execute("INSERT INTO t VALUES (?)", (i, )) cursor = connection.cursor() query = Query( sql="SELECT c FROM t", client_id="abcde", database=Database(database_name="test_db", sqlalchemy_uri="sqlite://"), schema=None, limit=5, select_as_cta_used=False, ) session.add(query) session.commit() admin = User( first_name="Alice", last_name="Doe", email="*****@*****.**", username="******", roles=[Role(name="Admin")], ) # first without RLS with override_user(admin): superset_result_set = execute_sql_statement( sql_statement=query.sql, query=query, session=session, cursor=cursor, log_params=None, apply_ctas=False, ) assert (superset_result_set.to_pandas_df().to_markdown() == """ | | c | |---:|----:| | 0 | 0 | | 1 | 1 | | 2 | 2 | | 3 | 3 | | 4 | 4 |""".strip()) assert query.executed_sql == "SELECT c FROM t\nLIMIT 6" # now with RLS rls = RowLevelSecurityFilter( filter_type=RowLevelSecurityFilterType.REGULAR, tables=[SqlaTable(database_id=1, schema=None, table_name="t")], roles=[admin.roles[0]], group_key=None, clause="c > 5", ) session.add(rls) session.flush() mocker.patch.object(SupersetSecurityManager, "find_user", return_value=admin) mocker.patch("superset.sql_lab.is_feature_enabled", return_value=True) with override_user(admin): superset_result_set = execute_sql_statement( sql_statement=query.sql, query=query, session=session, cursor=cursor, log_params=None, apply_ctas=False, ) assert (superset_result_set.to_pandas_df().to_markdown() == """ | | c | |---:|----:| | 0 | 6 | | 1 | 7 | | 2 | 8 | | 3 | 9 |""".strip()) assert query.executed_sql == "SELECT c FROM t WHERE (t.c > 5)\nLIMIT 6"
def handle_cursor( # pylint: disable=too-many-locals cls, cursor: Any, query: Query, session: Session) -> None: """Updates progress information""" # pylint: disable=import-outside-toplevel from pyhive import hive unfinished_states = ( hive.ttypes.TOperationState.INITIALIZED_STATE, hive.ttypes.TOperationState.RUNNING_STATE, ) polled = cursor.poll() last_log_line = 0 tracking_url = None job_id = None query_id = query.id while polled.operationState in unfinished_states: # Queries don't terminate when user clicks the STOP button on SQL LAB. # Refresh session so that the `query.status` modified in stop_query in # views/core.py is reflected here. session.refresh(query) query = session.query(type(query)).filter_by(id=query_id).one() if query.status == QueryStatus.STOPPED: cursor.cancel() break try: log = cursor.fetch_logs() or "" except Exception: # pylint: disable=broad-except logger.warning("Call to GetLog() failed") log = "" if log: log_lines = log.splitlines() progress = cls.progress(log_lines) logger.info("Query %s: Progress total: %s", str(query_id), str(progress)) needs_commit = False if progress > query.progress: query.progress = progress needs_commit = True if not tracking_url: tracking_url = cls.get_tracking_url(log_lines) if tracking_url: job_id = tracking_url.split("/")[-2] logger.info( "Query %s: Found the tracking url: %s", str(query_id), tracking_url, ) transformer = current_app.config[ "TRACKING_URL_TRANSFORMER"] tracking_url = transformer(tracking_url) logger.info( "Query %s: Transformation applied: %s", str(query_id), tracking_url, ) query.tracking_url = tracking_url logger.info("Query %s: Job id: %s", str(query_id), str(job_id)) needs_commit = True if job_id and len(log_lines) > last_log_line: # Wait for job id before logging things out # this allows for prefixing all log lines and becoming # searchable in something like Kibana for l in log_lines[last_log_line:]: logger.info("Query %s: [%s] %s", str(query_id), str(job_id), l) last_log_line = len(log_lines) if needs_commit: session.commit() time.sleep(current_app.config["HIVE_POLL_INTERVAL"]) polled = cursor.poll()
def get_one_report(id):
    o = db.session.query(SavedQuery).filter_by(id=id).first()
    desc = {}
    try:
        desc = json.loads(o.description)
    except ValueError:
        pass

    if request.method == 'GET':
        return jsonify({
            'id': o.id,
            'created_on': o.created_on.strftime('%Y-%m-%d'),
            'changed_on': o.changed_on.strftime('%Y-%m-%d'),
            'user_id': o.user_id or '',
            'db_id': o.db_id or '',
            'label': o.label or '',
            'schema': o.schema or '',
            'sql': o.sql or '',
            'description': desc,
        })
    elif request.method == 'POST':
        qjson = request.json
        sql = o.sql
        database_id = o.db_id
        schema = o.schema
        label = o.label
        session = db.session()
        mydb = session.query(models.Database).filter_by(id=database_id).first()

        # paginate
        page = qjson.get('page', 1)
        per_page = qjson.get('per_page', app.config['REPORT_PER_PAGE'])
        hkey = get_hash_key()

        #
        # parse config: filters, fields and sorts
        #
        # qsort = ["ds", "desc"]
        qsort = qjson.get('sort', [])
        sort = " order by _.%s %s" % (qsort[0], qsort[1]) if qsort else ""

        # TODO: handle date conversion later
        # filters = [{"field": "ds", "type": "range", "value1": "2016-01-01",
        #             "value2": "2017-01-03", "help": u"date field"}]
        filters = qjson.get('filterfield_set', [])
        fs = []
        for f in filters:
            if f['type'] == 'range':
                fs.append(
                    "(_.%(field)s >= '%(value1)s' and _.%(field)s < '%(value2)s')" % f)
            elif f['type'] == 'like':
                fs.append("_.%(field)s like '%%%(value1)s%%'" % f)
            else:
                fs.append("_.%(field)s %(type)s '%(value1)s'" % f)
        where = " where " + (" and ".join(fs)) if fs else ""

        # count_sql = "SELECT count(1) as num FROM (%s) _" % sql
        # sql = "SELECT * FROM (%s) _ LIMIT %s,%s" % (sql, (page - 1) * per_page, per_page)
        #
        # sql can't end with `;`; complicated sql should use select .. as ..
        count_sql = "SELECT count(1) as num FROM (%s) _ %s" % (sql, where)
        sql = "SELECT * FROM (%s) _ %s %s LIMIT %s,%s" % (
            sql, where, sort, (page - 1) * per_page, per_page)

        if True:
            query = Query(
                database_id=int(database_id),
                limit=1000000,  # int(app.config.get('SQL_MAX_ROW', None)),
                sql=sql,
                schema=schema,
                select_as_cta=False,
                start_time=utils.now_as_float(),
                tab_name=label,
                status=QueryStatus.RUNNING,
                sql_editor_id=hkey[0] + hkey[1],
                tmp_table_name='',
                user_id=int(g.user.get_id()),
                client_id=hkey[2] + hkey[3],
            )
            session.add(query)
            cquery = Query(
                database_id=int(database_id),
                limit=1000000,  # int(app.config.get('SQL_MAX_ROW', None)),
                sql=count_sql,
                schema=schema,
                select_as_cta=False,
                start_time=utils.now_as_float(),
                tab_name=label,
                status=QueryStatus.RUNNING,
                sql_editor_id=hkey[0] + hkey[1],
                tmp_table_name='',
                user_id=int(g.user.get_id()),
                client_id=hkey[0] + hkey[1],
            )
            session.add(cquery)
            session.flush()
            db.session.commit()

            query_id = query.id
            cquery_id = cquery.id

            data = sql_lab.get_sql_results(
                query_id=query_id, return_results=True, template_params={})
            cdata = sql_lab.get_sql_results(
                query_id=cquery_id, return_results=True, template_params={})

            return jsonify({
                'data': data['data'],
                'id': id,
                'label': label,
                'query_id': data['query_id'],
                'limit': data['query']['limit'],
                'limit_reached': False,
                'page': page,
                'per_page': per_page,
                'pages': get_pages(cdata['data'][0]['num'], per_page),
                'total': cdata['data'][0]['num'],
                'rows': data['query']['rows'],
                'sort': qsort,
                'changed_on': data['query']['changed_on'],
                'displayfield_set': desc['displayfield_set'],
                'report_file': url_for(
                    'download_one_report', id=id, query_id=data['query_id']),
                'status': 'success',
            })
    return 'ok'
def execute_sql_statement(  # pylint: disable=too-many-arguments,too-many-statements
    sql_statement: str,
    query: Query,
    session: Session,
    cursor: Any,
    log_params: Optional[Dict[str, Any]],
    apply_ctas: bool = False,
) -> SupersetResultSet:
    """Executes a single SQL statement"""
    database: Database = query.database
    db_engine_spec = database.db_engine_spec

    parsed_query = ParsedQuery(sql_statement)
    if is_feature_enabled("RLS_IN_SQLLAB"):
        # Insert any applicable RLS predicates
        parsed_query = ParsedQuery(
            str(
                insert_rls(
                    parsed_query._parsed[0],  # pylint: disable=protected-access
                    database.id,
                    query.schema,
                )
            )
        )

    sql = parsed_query.stripped()
    # This is a test to see if the query is being
    # limited by either the dropdown or the sql.
    # We are testing to see if more rows exist than the limit.
    increased_limit = None if query.limit is None else query.limit + 1

    if not db_engine_spec.is_readonly_query(parsed_query) and not database.allow_dml:
        raise SupersetErrorException(
            SupersetError(
                message=__("Only SELECT statements are allowed against this database."),
                error_type=SupersetErrorType.DML_NOT_ALLOWED_ERROR,
                level=ErrorLevel.ERROR,
            )
        )

    if apply_ctas:
        if not query.tmp_table_name:
            start_dttm = datetime.fromtimestamp(query.start_time)
            query.tmp_table_name = "tmp_{}_table_{}".format(
                query.user_id, start_dttm.strftime("%Y_%m_%d_%H_%M_%S")
            )
        sql = parsed_query.as_create_table(
            query.tmp_table_name,
            schema_name=query.tmp_schema_name,
            method=query.ctas_method,
        )
        query.select_as_cta_used = True

    # Do not apply limit to the CTA queries when SQLLAB_CTAS_NO_LIMIT is set to true
    if db_engine_spec.is_select_query(parsed_query) and not (
        query.select_as_cta_used and SQLLAB_CTAS_NO_LIMIT
    ):
        if SQL_MAX_ROW and (not query.limit or query.limit > SQL_MAX_ROW):
            query.limit = SQL_MAX_ROW
        sql = apply_limit_if_exists(database, increased_limit, query, sql)

    # Hook to allow environment-specific mutation (usually comments) to the SQL
    sql = SQL_QUERY_MUTATOR(
        sql,
        user_name=get_username(),  # TODO(john-bodley): Deprecate in 3.0.
        security_manager=security_manager,
        database=database,
    )
    try:
        query.executed_sql = sql
        if log_query:
            log_query(
                query.database.sqlalchemy_uri,
                query.executed_sql,
                query.schema,
                get_username(),
                __name__,
                security_manager,
                log_params,
            )
        session.commit()
        with stats_timing("sqllab.query.time_executing_query", stats_logger):
            logger.debug("Query %d: Running query: %s", query.id, sql)
            db_engine_spec.execute(cursor, sql, async_=True)
            logger.debug("Query %d: Handling cursor", query.id)
            db_engine_spec.handle_cursor(cursor, query, session)

        with stats_timing("sqllab.query.time_fetching_results", stats_logger):
            logger.debug(
                "Query %d: Fetching data for query object: %s",
                query.id,
                str(query.to_dict()),
            )
            data = db_engine_spec.fetch_data(cursor, increased_limit)
            if query.limit is None or len(data) <= query.limit:
                query.limiting_factor = LimitingFactor.NOT_LIMITED
            else:
                # return 1 row less than increased_query
                data = data[:-1]
    except SoftTimeLimitExceeded as ex:
        query.status = QueryStatus.TIMED_OUT
        logger.warning("Query %d: Time limit exceeded", query.id)
        logger.debug("Query %d: %s", query.id, ex)
        raise SupersetErrorException(
            SupersetError(
                message=__(
                    "The query was killed after %(sqllab_timeout)s seconds. It might "
                    "be too complex, or the database might be under heavy load.",
                    sqllab_timeout=SQLLAB_TIMEOUT,
                ),
                error_type=SupersetErrorType.SQLLAB_TIMEOUT_ERROR,
                level=ErrorLevel.ERROR,
            )
        ) from ex
    except Exception as ex:
        # query is stopped in another thread/worker
        # stopping raises expected exceptions which we should skip
        session.refresh(query)
        if query.status == QueryStatus.STOPPED:
            raise SqlLabQueryStoppedException() from ex

        logger.error("Query %d: %s", query.id, type(ex), exc_info=True)
        logger.debug("Query %d: %s", query.id, ex)
        raise SqlLabException(db_engine_spec.extract_error_message(ex)) from ex

    logger.debug("Query %d: Fetching cursor description", query.id)
    cursor_description = cursor.description
    return SupersetResultSet(data, cursor_description, db_engine_spec)
def test_cancel_query(self, engine_mock):
    query = Query()
    cursor_mock = engine_mock.return_value.__enter__.return_value
    assert SnowflakeEngineSpec.cancel_query(cursor_mock, query, 123) is True
def _to_payload_query_based(  # pylint: disable=no-self-use
    self, query: Query
) -> str:
    return json.dumps(
        {"query": query.to_dict()},
        default=utils.json_int_dttm_ser,
        ignore_nan=True,
    )
def execute_sql_statement( sql_statement: str, query: Query, user_name: Optional[str], session: Session, cursor: Any, log_params: Optional[Dict[str, Any]], ) -> SupersetResultSet: """Executes a single SQL statement""" database = query.database db_engine_spec = database.db_engine_spec parsed_query = ParsedQuery(sql_statement) sql = parsed_query.stripped() if not parsed_query.is_readonly() and not database.allow_dml: raise SqlLabSecurityException( _("Only `SELECT` statements are allowed against this database")) if query.select_as_cta: if not parsed_query.is_select(): raise SqlLabException( _("Only `SELECT` statements can be used with the CREATE TABLE " "feature.")) if not query.tmp_table_name: start_dttm = datetime.fromtimestamp(query.start_time) query.tmp_table_name = "tmp_{}_table_{}".format( query.user_id, start_dttm.strftime("%Y_%m_%d_%H_%M_%S")) sql = parsed_query.as_create_table( query.tmp_table_name, schema_name=query.tmp_schema_name, method=query.ctas_method, ) query.select_as_cta_used = True # Do not apply limit to the CTA queries when SQLLAB_CTAS_NO_LIMIT is set to true if parsed_query.is_select() and not (query.select_as_cta_used and SQLLAB_CTAS_NO_LIMIT): if SQL_MAX_ROW and (not query.limit or query.limit > SQL_MAX_ROW): query.limit = SQL_MAX_ROW if query.limit: sql = database.apply_limit_to_sql(sql, query.limit) # Hook to allow environment-specific mutation (usually comments) to the SQL if SQL_QUERY_MUTATOR: sql = SQL_QUERY_MUTATOR(sql, user_name, security_manager, database) try: if log_query: log_query( query.database.sqlalchemy_uri, query.executed_sql, query.schema, user_name, __name__, security_manager, log_params, ) query.executed_sql = sql session.commit() with stats_timing("sqllab.query.time_executing_query", stats_logger): logger.debug("Query %d: Running query: %s", query.id, sql) db_engine_spec.execute(cursor, sql, async_=True) logger.debug("Query %d: Handling cursor", query.id) db_engine_spec.handle_cursor(cursor, query, session) with stats_timing("sqllab.query.time_fetching_results", stats_logger): logger.debug( "Query %d: Fetching data for query object: %s", query.id, str(query.to_dict()), ) data = db_engine_spec.fetch_data(cursor, query.limit) except SoftTimeLimitExceeded as ex: logger.error("Query %d: Time limit exceeded", query.id) logger.debug("Query %d: %s", query.id, ex) raise SqlLabTimeoutException( "SQL Lab timeout. This environment's policy is to kill queries " "after {} seconds.".format(SQLLAB_TIMEOUT)) except Exception as ex: logger.error("Query %d: %s", query.id, type(ex)) logger.debug("Query %d: %s", query.id, ex) raise SqlLabException(db_engine_spec.extract_error_message(ex)) logger.debug("Query %d: Fetching cursor description", query.id) cursor_description = cursor.description return SupersetResultSet(data, cursor_description, db_engine_spec)
def execute_sql_statement(
    sql_statement: str,
    query: Query,
    user_name: Optional[str],
    session: Session,
    cursor: Any,
    log_params: Optional[Dict[str, Any]],
    apply_ctas: bool = False,
) -> SupersetResultSet:
    """Executes a single SQL statement"""
    database = query.database
    db_engine_spec = database.db_engine_spec
    parsed_query = ParsedQuery(sql_statement)
    sql = parsed_query.stripped()

    # This is a test to see if the query is being
    # limited by either the dropdown or the sql.
    # We are testing to see if more rows exist than the limit.
    increased_limit = None if query.limit is None else query.limit + 1

    if not db_engine_spec.is_readonly_query(parsed_query) and not database.allow_dml:
        raise SupersetErrorException(
            SupersetError(
                message=__("Only SELECT statements are allowed against this database."),
                error_type=SupersetErrorType.DML_NOT_ALLOWED_ERROR,
                level=ErrorLevel.ERROR,
            )
        )

    if apply_ctas:
        if not query.tmp_table_name:
            start_dttm = datetime.fromtimestamp(query.start_time)
            query.tmp_table_name = "tmp_{}_table_{}".format(
                query.user_id, start_dttm.strftime("%Y_%m_%d_%H_%M_%S")
            )
        sql = parsed_query.as_create_table(
            query.tmp_table_name,
            schema_name=query.tmp_schema_name,
            method=query.ctas_method,
        )
        query.select_as_cta_used = True

    # Do not apply limit to the CTA queries when SQLLAB_CTAS_NO_LIMIT is set to true
    if db_engine_spec.is_select_query(parsed_query) and not (
        query.select_as_cta_used and SQLLAB_CTAS_NO_LIMIT
    ):
        if SQL_MAX_ROW and (not query.limit or query.limit > SQL_MAX_ROW):
            query.limit = SQL_MAX_ROW
        if query.limit:
            # We are fetching one more than the requested limit in order
            # to test whether there are more rows than the limit.
            # Later, the extra row will be dropped before sending
            # the results back to the user.
            sql = database.apply_limit_to_sql(sql, increased_limit, force=True)

    # Hook to allow environment-specific mutation (usually comments) to the SQL
    sql = SQL_QUERY_MUTATOR(sql, user_name, security_manager, database)
    try:
        query.executed_sql = sql
        if log_query:
            log_query(
                query.database.sqlalchemy_uri,
                query.executed_sql,
                query.schema,
                user_name,
                __name__,
                security_manager,
                log_params,
            )
        session.commit()
        with stats_timing("sqllab.query.time_executing_query", stats_logger):
            logger.debug("Query %d: Running query: %s", query.id, sql)
            db_engine_spec.execute(cursor, sql, async_=True)
            logger.debug("Query %d: Handling cursor", query.id)
            db_engine_spec.handle_cursor(cursor, query, session)

        with stats_timing("sqllab.query.time_fetching_results", stats_logger):
            logger.debug(
                "Query %d: Fetching data for query object: %s",
                query.id,
                str(query.to_dict()),
            )
            data = db_engine_spec.fetch_data(cursor, increased_limit)
            if query.limit is None or len(data) <= query.limit:
                query.limiting_factor = LimitingFactor.NOT_LIMITED
            else:
                # return 1 row less than increased_query
                data = data[:-1]
    except SoftTimeLimitExceeded as ex:
        logger.warning("Query %d: Time limit exceeded", query.id)
        logger.debug("Query %d: %s", query.id, ex)
        raise SupersetErrorException(
            SupersetError(
                message=__(
                    f"The query was killed after {SQLLAB_TIMEOUT} seconds. It might "
                    "be too complex, or the database might be under heavy load."
                ),
                error_type=SupersetErrorType.SQLLAB_TIMEOUT_ERROR,
                level=ErrorLevel.ERROR,
            )
        )
    except Exception as ex:
        logger.error("Query %d: %s", query.id, type(ex), exc_info=True)
        logger.debug("Query %d: %s", query.id, ex)
        raise SqlLabException(db_engine_spec.extract_error_message(ex))

    logger.debug("Query %d: Fetching cursor description", query.id)
    cursor_description = cursor.description
    return SupersetResultSet(data, cursor_description, db_engine_spec)
def test_cancel_query_failed(self, engine_mock):
    query = Query()
    cursor_mock = engine_mock.raiseError.side_effect = Exception()
    assert SnowflakeEngineSpec.cancel_query(cursor_mock, query, 123) is False
def session_with_data(session: Session) -> Iterator[Session]: from superset.columns.models import Column from superset.connectors.sqla.models import SqlaTable, TableColumn from superset.datasets.models import Dataset from superset.models.core import Database from superset.models.sql_lab import Query, SavedQuery from superset.tables.models import Table engine = session.get_bind() SqlaTable.metadata.create_all(engine) # pylint: disable=no-member db = Database(database_name="my_database", sqlalchemy_uri="sqlite://") columns = [ TableColumn(column_name="a", type="INTEGER"), ] sqla_table = SqlaTable( table_name="my_sqla_table", columns=columns, metrics=[], database=db, ) query_obj = Query( client_id="foo", database=db, tab_name="test_tab", sql_editor_id="test_editor_id", sql="select * from bar", select_sql="select * from bar", executed_sql="select * from bar", limit=100, select_as_cta=False, rows=100, error_message="none", results_key="abc", ) saved_query = SavedQuery(database=db, sql="select * from foo") table = Table( name="my_table", schema="my_schema", catalog="my_catalog", database=db, columns=[], ) dataset = Dataset( database=table.database, name="positions", expression=""" SELECT array_agg(array[longitude,latitude]) AS position FROM my_catalog.my_schema.my_table """, tables=[table], columns=[ Column( name="position", expression="array_agg(array[longitude,latitude])", ), ], ) session.add(dataset) session.add(table) session.add(saved_query) session.add(query_obj) session.add(db) session.add(sqla_table) session.flush() yield session
def test_get_cancel_query_id(self, engine_mock):
    query = Query()
    cursor_mock = engine_mock.return_value.__enter__.return_value
    cursor_mock.fetchone.return_value = [123]
    assert SnowflakeEngineSpec.get_cancel_query_id(cursor_mock, query) == 123
def get_table_metadata(
    database: Database, table_name: str, schema_name: Optional[str]
) -> Dict:
    """
    Get table metadata information, including type, pk, fks.
    This function raises SQLAlchemyError when a schema is not found.

    :param database: The database model
    :param table_name: Table name
    :param schema_name: schema name
    :return: Dict table metadata ready for API response
    """
    keys: List = []
    columns = database.get_columns(table_name, schema_name)
    # comment dict added by tsl
    comment_dict = {}
    primary_key = database.get_pk_constraint(table_name, schema_name)
    if primary_key and primary_key.get("constrained_columns"):
        primary_key["column_names"] = primary_key.pop("constrained_columns")
        primary_key["type"] = "pk"
        keys += [primary_key]

    # get dialect name
    dialect_name = database.get_dialect().name
    if isinstance(dialect_name, bytes):
        dialect_name = dialect_name.decode()

    # get column comments for presto & hive
    if dialect_name == "presto" or dialect_name == "hive":
        db_engine_spec = database.db_engine_spec
        sql = ParsedQuery("desc {a}.{b}".format(a=schema_name, b=table_name)).stripped()
        engine = database.get_sqla_engine(schema_name)
        conn = engine.raw_connection()
        cursor = conn.cursor()
        query = Query()
        session = Session(bind=engine)
        query.executed_sql = sql
        query.__tablename__ = table_name
        session.commit()
        db_engine_spec.execute(cursor, sql, async_=False)
        data = db_engine_spec.fetch_data(cursor, query.limit)
        # parse the rows into a dict (added by tsl); hive and presto return
        # the comment in different columns
        if dialect_name == "presto":
            for d in data:
                comment_dict[d[0]] = d[3]
        else:
            for d in data:
                comment_dict[d[0]] = d[2]
        conn.commit()

    foreign_keys = get_foreign_keys_metadata(database, table_name, schema_name)
    indexes = get_indexes_metadata(database, table_name, schema_name)
    keys += foreign_keys + indexes

    payload_columns: List[Dict] = []
    for col in columns:
        dtype = get_col_type(col)
        if len(comment_dict) > 0:
            payload_columns.append({
                "name": col["name"],
                "type": dtype.split("(")[0] if "(" in dtype else dtype,
                "longType": dtype,
                "keys": [k for k in keys if col["name"] in k.get("column_names")],
                "comment": comment_dict[col["name"]],
            })
        elif dialect_name == "mysql":
            payload_columns.append({
                "name": col["name"],
                "type": dtype.split("(")[0] if "(" in dtype else dtype,
                "longType": dtype,
                "keys": [k for k in keys if col["name"] in k.get("column_names")],
                "comment": col["comment"],
            })
        else:
            payload_columns.append({
                "name": col["name"],
                "type": dtype.split("(")[0] if "(" in dtype else dtype,
                "longType": dtype,
                "keys": [k for k in keys if col["name"] in k.get("column_names")],
                # "comment": col["comment"],
            })
    return {
        "name": table_name,
        "columns": payload_columns,
        "selectStar": database.select_star(
            table_name,
            schema=schema_name,
            show_cols=True,
            indent=True,
            cols=columns,
            latest_partition=True,
        ),
        "primaryKey": primary_key,
        "foreignKeys": foreign_keys,
        "indexes": keys,
    }
def handle_cursor(  # pylint: disable=too-many-locals
    cls, cursor: Any, query: Query, session: Session
) -> None:
    """Updates progress information"""
    from pyhive import hive

    unfinished_states = (
        hive.ttypes.TOperationState.INITIALIZED_STATE,
        hive.ttypes.TOperationState.RUNNING_STATE,
    )
    polled = cursor.poll()
    last_log_line = 0
    tracking_url = None
    job_id = None
    query_id = query.id
    while polled.operationState in unfinished_states:
        query = session.query(type(query)).filter_by(id=query_id).one()
        if query.status == QueryStatus.STOPPED:
            cursor.cancel()
            break

        log = cursor.fetch_logs() or ""
        if log:
            log_lines = log.splitlines()
            progress = cls.progress(log_lines)
            logger.info("Query %s: Progress total: %s", str(query_id), str(progress))
            needs_commit = False
            if progress > query.progress:
                query.progress = progress
                needs_commit = True
            if not tracking_url:
                tracking_url = cls.get_tracking_url(log_lines)
                if tracking_url:
                    job_id = tracking_url.split("/")[-2]
                    logger.info(
                        "Query %s: Found the tracking url: %s",
                        str(query_id),
                        tracking_url,
                    )
                    # apply the configured transformer to the tracking url
                    tracking_url = current_app.config["TRACKING_URL_TRANSFORMER"](
                        tracking_url
                    )
                    logger.info(
                        "Query %s: Transformation applied: %s",
                        str(query_id),
                        tracking_url,
                    )
                    query.tracking_url = tracking_url
                    logger.info("Query %s: Job id: %s", str(query_id), str(job_id))
                    needs_commit = True
            if job_id and len(log_lines) > last_log_line:
                # Wait for job id before logging things out
                # this allows for prefixing all log lines and becoming
                # searchable in something like Kibana
                for l in log_lines[last_log_line:]:
                    logger.info("Query %s: [%s] %s", str(query_id), str(job_id), l)
                last_log_line = len(log_lines)
            if needs_commit:
                session.commit()
        time.sleep(current_app.config["HIVE_POLL_INTERVAL"])
        polled = cursor.poll()