def test_override_user(
    app_context: AppContext,
    mocker: MockFixture,
    username: str,
    force: bool,
) -> None:
    """Verify ``override_user`` swaps ``g.user`` and restores it on exit.

    Covers three initial states of the mocked ``g``: no ``user`` attribute,
    ``user`` set to ``None``, and ``user`` set to an existing admin user.
    """
    mock_g = mocker.patch("superset.utils.core.g", spec={})
    admin = security_manager.find_user(username="******")
    user = security_manager.find_user(username)

    # No pre-existing user: the attribute is removed again after the block.
    with override_user(user, force):
        assert mock_g.user == user
    assert not hasattr(mock_g, "user")

    # Pre-existing ``None`` user: restored to ``None`` afterwards.
    mock_g.user = None
    with override_user(user, force):
        assert mock_g.user == user
    assert mock_g.user is None

    # Pre-existing admin user: only replaced when ``force`` is set.
    mock_g.user = admin
    with override_user(user, force):
        # BUG FIX: the original ``assert mock_g.user == user if force else admin``
        # parsed as ``(mock_g.user == user) if force else admin`` — when
        # ``force`` was False it asserted the truthy ``admin`` object and
        # always passed, never checking the non-forced behavior.
        assert mock_g.user == (user if force else admin)
    assert mock_g.user == admin
def get_sql_results(  # pylint: disable=too-many-arguments
    ctask: Task,
    query_id: int,
    rendered_query: str,
    return_results: bool = True,
    store_results: bool = False,
    username: Optional[str] = None,
    start_time: Optional[float] = None,
    expand_data: bool = False,
    log_params: Optional[Dict[str, Any]] = None,
) -> Optional[Dict[str, Any]]:
    """Executes the sql query returns the results."""
    # session_scope(True) presumably creates a fresh session when run as a
    # Celery task (not called directly) — TODO confirm against session_scope.
    with session_scope(not ctask.request.called_directly) as session:
        # Execute as the requesting user so permission checks apply.
        with override_user(security_manager.find_user(username)):
            try:
                return execute_sql_statements(
                    query_id,
                    rendered_query,
                    return_results,
                    store_results,
                    session=session,
                    start_time=start_time,
                    expand_data=expand_data,
                    log_params=log_params,
                )
            except Exception as ex:  # pylint: disable=broad-except
                logger.debug("Query %d: %s", query_id, ex)
                stats_logger.incr("error_sqllab_unhandled")
                query = get_query(query_id, session)
                # Convert the failure into an error payload instead of raising.
                return handle_query_error(ex, query, session)
def test_impersonate_user_trino(self, mocked_create_engine):
    """Trino impersonation should pass the username via ``connect_args``."""
    principal_user = security_manager.find_user(username="******")

    with override_user(principal_user):
        # URI without credentials: only connect_args carries the user.
        database = Database(
            database_name="test_database",
            sqlalchemy_uri="trino://localhost",
        )
        database.impersonate_user = True
        database.get_sqla_engine()
        args, kwargs = mocked_create_engine.call_args
        assert str(args[0]) == "trino://localhost"
        assert kwargs["connect_args"] == {"user": "******"}

        # URI with credentials: the URI is kept as-is, impersonation still
        # flows through connect_args.
        database = Database(
            database_name="test_database",
            sqlalchemy_uri="trino://*****:*****@localhost",
        )
        database.impersonate_user = True
        database.get_sqla_engine()
        args, kwargs = mocked_create_engine.call_args
        assert str(args[0]) == "trino://*****:*****@localhost"
        assert kwargs["connect_args"] == {"user": "******"}
def load_chart_data_into_cache(
    job_metadata: Dict[str, Any],
    form_data: Dict[str, Any],
) -> None:
    """Async task: run a chart-data query and cache the result.

    On success the async job is marked done with a URL the client can fetch
    the cached payload from; on failure the job is marked errored with the
    collected error messages, and the exception is re-raised.
    """
    # pylint: disable=import-outside-toplevel
    from superset.charts.data.commands.get_data_command import ChartDataCommand

    # Resolve the requesting user; fall back to the anonymous user when the
    # id is missing or unknown.
    user = (
        security_manager.get_user_by_id(job_metadata.get("user_id"))
        or security_manager.get_anonymous_user()
    )
    # force=False: keep an already-set g.user if one exists.
    with override_user(user, force=False):
        try:
            set_form_data(form_data)
            query_context = _create_query_context_from_form(form_data)
            command = ChartDataCommand(query_context)
            result = command.run(cache=True)
            cache_key = result["cache_key"]
            result_url = f"/api/v1/chart/data/{cache_key}"
            async_query_manager.update_job(
                job_metadata,
                async_query_manager.STATUS_DONE,
                result_url=result_url,
            )
        except SoftTimeLimitExceeded as ex:
            # Celery soft time limit hit while running the query.
            logger.warning(
                "A timeout occurred while loading chart data, error: %s", ex
            )
            raise ex
        except Exception as ex:
            # TODO: QueryContext should support SIP-40 style errors
            error = ex.message if hasattr(ex, "message") else str(ex)  # type: ignore # pylint: disable=no-member
            errors = [{"message": error}]
            async_query_manager.update_job(
                job_metadata, async_query_manager.STATUS_ERROR, errors=errors
            )
            raise ex
def _execute_query(self) -> pd.DataFrame:
    """
    Executes the actual alert SQL query template

    :return: A pandas dataframe
    :raises AlertQueryError: SQL query is not valid
    :raises AlertQueryTimeout: The SQL query received a celery soft timeout
    """
    # Render any Jinja templating in the alert's SQL first.
    sql_template = jinja_context.get_template_processor(
        database=self._report_schedule.database
    )
    rendered_sql = sql_template.process_template(self._report_schedule.sql)
    try:
        # Cap the number of returned rows before executing.
        limited_rendered_sql = self._report_schedule.database.apply_limit_to_sql(
            rendered_sql, ALERT_SQL_LIMIT
        )
        # Run as the configured thumbnail/selenium user so permission
        # checks apply to that account.
        with override_user(
            security_manager.find_user(
                username=app.config["THUMBNAIL_SELENIUM_USER"]
            )
        ):
            start = default_timer()
            df = self._report_schedule.database.get_df(sql=limited_rendered_sql)
            stop = default_timer()
            logger.info(
                "Query for %s took %.2f ms",
                self._report_schedule.name,
                (stop - start) * 1000.0,
            )
            return df
    except SoftTimeLimitExceeded as ex:
        logger.warning(
            "A timeout occurred while executing the alert query: %s", ex
        )
        raise AlertQueryTimeout() from ex
    except Exception as ex:
        # Any other failure is surfaced as an invalid-query error.
        raise AlertQueryError(message=str(ex)) from ex
def test_override_user(app_context: AppContext, username: str) -> None:
    """``override_user`` should set ``g.user`` and undo the change on exit."""
    admin = security_manager.find_user(username="******")
    target = security_manager.find_user(username)

    # Starting with no user on ``g``: the attribute disappears afterwards.
    assert not hasattr(g, "user")
    with override_user(target):
        assert g.user == target
    assert not hasattr(g, "user")

    # Starting with an existing user: it is restored afterwards.
    g.user = admin
    with override_user(target):
        assert g.user == target
    assert g.user == admin
def test_unsaved_chart_no_dataset_id() -> None:
    """An unsaved chart without a datasource id must be rejected."""
    from superset.explore.utils import check_access as check_chart_access

    with raises(DatasourceNotFoundValidationError), override_user(User()):
        check_chart_access(
            datasource_id=0,
            chart_id=0,
            datasource_type=DatasourceType.TABLE,
        )
def test_update_missing_entry(app_context: AppContext, admin: User) -> None:
    """Updating a non-existent key should be a no-op returning ``None``."""
    from superset.key_value.commands.update import UpdateKeyValueCommand

    with override_user(admin):
        command = UpdateKeyValueCommand(
            resource=RESOURCE,
            key=456,
            value=NEW_VALUE,
        )
        assert command.run() is None
def test_unsaved_chart_unknown_query_id(mocker: MockFixture) -> None:
    """An unknown saved-query id should raise ``QueryNotFoundValidationError``."""
    from superset.explore.utils import check_access as check_chart_access

    mocker.patch(query_find_by_id, return_value=None)
    with raises(QueryNotFoundValidationError), override_user(User()):
        check_chart_access(
            datasource_id=1,
            chart_id=0,
            datasource_type=DatasourceType.QUERY,
        )
def test_database_impersonate_user(self):
    """Impersonation should toggle the engine URL's username."""
    uri = "mysql://root@localhost"
    example_user = security_manager.find_user(username="******")
    model = Database(database_name="test_database", sqlalchemy_uri=uri)

    with override_user(example_user):
        # Enabled: the URL carries the impersonated user.
        model.impersonate_user = True
        effective_username = make_url(model.get_sqla_engine().url).username
        self.assertEqual(example_user.username, effective_username)

        # Disabled: the original URL username is kept.
        model.impersonate_user = False
        effective_username = make_url(model.get_sqla_engine().url).username
        self.assertNotEqual(example_user.username, effective_username)
def test_unsaved_chart_authorized_dataset(mocker: MockFixture) -> None:
    """An authorized dataset check should pass without raising."""
    from superset.connectors.sqla.models import SqlaTable
    from superset.explore.utils import check_access as check_chart_access

    mocker.patch(dataset_find_by_id, return_value=SqlaTable())
    mocker.patch(can_access_datasource, return_value=True)
    with override_user(User()):
        check_chart_access(
            datasource_id=1,
            chart_id=0,
            datasource_type=DatasourceType.TABLE,
        )
def test_create_uuid_entry(app_context: AppContext, admin: User) -> None:
    """Creating an entry should persist the value and audit the creator."""
    from superset.key_value.commands.create import CreateKeyValueCommand
    from superset.key_value.models import KeyValueEntry

    with override_user(admin):
        key = CreateKeyValueCommand(resource=RESOURCE, value=VALUE).run()

    entry = (
        db.session.query(KeyValueEntry)
        .filter_by(uuid=key.uuid)
        .autoflush(False)
        .one()
    )
    assert pickle.loads(entry.value) == VALUE
    assert entry.created_by_fk == admin.id

    # Clean up the created row.
    db.session.delete(entry)
    db.session.commit()
def test_unsaved_chart_unknown_dataset_id(
    mocker: MockFixture, app_context: AppContext
) -> None:
    """An unknown dataset id should raise ``DatasetNotFoundError``."""
    from superset.explore.utils import check_access as check_chart_access

    mocker.patch(dataset_find_by_id, return_value=None)
    with raises(DatasetNotFoundError), override_user(User()):
        check_chart_access(
            datasource_id=1,
            chart_id=0,
            datasource_type=DatasourceType.TABLE,
        )
def test_upsert_missing_entry(app_context: AppContext, admin: User) -> None:
    """Upserting a missing key should create the entry under that key."""
    from superset.key_value.commands.upsert import UpsertKeyValueCommand
    from superset.key_value.models import KeyValueEntry

    with override_user(admin):
        key = UpsertKeyValueCommand(
            resource=RESOURCE,
            key=456,
            value=NEW_VALUE,
        ).run()

    assert key is not None
    assert key.id == 456

    # Clean up the created row.
    db.session.query(KeyValueEntry).filter_by(id=456).delete()
    db.session.commit()
def test_saved_chart_unknown_chart_id(mocker: MockFixture) -> None:
    """A saved-chart check with an unknown chart id should raise."""
    from superset.connectors.sqla.models import SqlaTable
    from superset.explore.utils import check_access as check_chart_access

    mocker.patch(dataset_find_by_id, return_value=SqlaTable())
    mocker.patch(can_access_datasource, return_value=True)
    mocker.patch(chart_find_by_id, return_value=None)
    with raises(ChartNotFoundError), override_user(User()):
        check_chart_access(
            datasource_id=1,
            chart_id=1,
            datasource_type=DatasourceType.TABLE,
        )
def test_impersonate_user_hive(self, mocked_create_engine):
    """Hive impersonation should toggle the proxy-user connect arg."""
    uri = "hive://localhost"
    principal_user = security_manager.find_user(username="******")
    extra = """
    {
        "metadata_params": {},
        "engine_params": {
            "connect_args":{
                "protocol": "https",
                "username":"******",
                "password":"******"
            }
        },
        "metadata_cache_timeout": {},
        "schemas_allowed_for_file_upload": []
    }
    """
    base_connect_args = {
        "protocol": "https",
        "username": "******",
        "password": "******",
    }

    with override_user(principal_user):
        model = Database(
            database_name="test_database", sqlalchemy_uri=uri, extra=extra
        )

        # Enabled: a hive.server2.proxy.user configuration is added.
        model.impersonate_user = True
        model.get_sqla_engine()
        args, kwargs = mocked_create_engine.call_args
        assert str(args[0]) == "hive://localhost"
        assert kwargs["connect_args"] == {
            **base_connect_args,
            "configuration": {"hive.server2.proxy.user": "******"},
        }

        # Disabled: only the connect args from ``extra`` remain.
        model.impersonate_user = False
        model.get_sqla_engine()
        args, kwargs = mocked_create_engine.call_args
        assert str(args[0]) == "hive://localhost"
        assert kwargs["connect_args"] == base_connect_args
def test_saved_chart_is_admin(mocker: MockFixture, app_context: AppContext) -> None:
    """Admins may access any saved chart."""
    from superset.connectors.sqla.models import SqlaTable
    from superset.explore.utils import check_access as check_chart_access
    from superset.models.slice import Slice

    mocker.patch(dataset_find_by_id, return_value=SqlaTable())
    mocker.patch(can_access_datasource, return_value=True)
    mocker.patch(is_admin, return_value=True)
    mocker.patch(chart_find_by_id, return_value=Slice())
    with override_user(User()):
        # Should not raise for an admin user.
        check_chart_access(
            datasource_id=1,
            chart_id=1,
            datasource_type=DatasourceType.TABLE,
        )
def test_unsaved_chart_unauthorized_dataset(
    mocker: MockFixture, app_context: AppContext
) -> None:
    """Access to an unauthorized dataset must be denied."""
    from superset.connectors.sqla.models import SqlaTable
    from superset.explore.utils import check_access as check_chart_access

    mocker.patch(dataset_find_by_id, return_value=SqlaTable())
    mocker.patch(can_access_datasource, return_value=False)
    with raises(DatasetAccessDeniedError), override_user(User()):
        check_chart_access(
            datasource_id=1,
            chart_id=0,
            datasource_type=DatasourceType.TABLE,
        )
def test_update_id_entry(
    app_context: AppContext,
    admin: User,
    key_value_entry: KeyValueEntry,
) -> None:
    """Updating an existing id-keyed entry stores the value and audit fk."""
    from superset.key_value.commands.update import UpdateKeyValueCommand
    from superset.key_value.models import KeyValueEntry

    with override_user(admin):
        key = UpdateKeyValueCommand(
            resource=RESOURCE,
            key=ID_KEY,
            value=NEW_VALUE,
        ).run()

    assert key is not None
    assert key.id == ID_KEY

    entry = (
        db.session.query(KeyValueEntry)
        .filter_by(id=ID_KEY)
        .autoflush(False)
        .one()
    )
    assert pickle.loads(entry.value) == NEW_VALUE
    assert entry.changed_by_fk == admin.id
def test_saved_chart_no_access(mocker: MockFixture) -> None:
    """A non-admin, non-owner user without permission is denied."""
    from superset.connectors.sqla.models import SqlaTable
    from superset.explore.utils import check_access as check_chart_access
    from superset.models.slice import Slice

    mocker.patch(dataset_find_by_id, return_value=SqlaTable())
    mocker.patch(can_access_datasource, return_value=True)
    mocker.patch(is_admin, return_value=False)
    mocker.patch(is_owner, return_value=False)
    mocker.patch(can_access, return_value=False)
    mocker.patch(chart_find_by_id, return_value=Slice())
    with raises(ChartAccessDeniedError), override_user(User()):
        check_chart_access(
            datasource_id=1,
            chart_id=1,
            datasource_type=DatasourceType.TABLE,
        )
def update_datasources_cache(username: Optional[str]) -> None:
    """Refresh sqllab datasources cache"""
    # pylint: disable=import-outside-toplevel
    from superset import security_manager
    from superset.models.core import Database

    one_day = 24 * 60 * 60  # cache entries live for a day

    with override_user(security_manager.find_user(username)):
        for database in db.session.query(Database).all():
            if not database.allow_multi_schema_metadata_fetch:
                continue
            print("Fetching {} datasources ...".format(database.name))
            try:
                database.get_all_table_names_in_database(
                    force=True, cache=True, cache_timeout=one_day
                )
                database.get_all_view_names_in_database(
                    force=True, cache=True, cache_timeout=one_day
                )
            except Exception as ex:  # pylint: disable=broad-except
                print("{}".format(str(ex)))
def run(self) -> None:
    """Test connectivity against the supplied database properties.

    :raises DatabaseTestConnectionDriverError: If the DB driver is missing
    :raises DatabaseTestConnectionFailedError: If the connection failed
    :raises DatabaseSecurityUnsafeError: If the URI is deemed unsafe
    :raises SupersetTimeoutException: If the ping timed out (HTTP 408)
    :raises DatabaseTestConnectionUnexpectedError: For any other error
    """
    self.validate()
    uri = self._properties.get("sqlalchemy_uri", "")
    # A masked URI equal to the stored model's masked form means the user
    # did not change it — use the stored, decrypted URI instead.
    if self._model and uri == self._model.safe_sqlalchemy_uri():
        uri = self._model.sqlalchemy_uri_decrypted

    # context for error messages
    url = make_url_safe(uri)
    context = {
        "hostname": url.host,
        "password": url.password,
        "port": url.port,
        "username": url.username,
        "database": url.database,
    }
    try:
        # Build a transient Database object just for the test connection.
        database = DatabaseDAO.build_db_for_connection_test(
            server_cert=self._properties.get("server_cert", ""),
            extra=self._properties.get("extra", "{}"),
            impersonate_user=self._properties.get("impersonate_user", False),
            encrypted_extra=self._properties.get("encrypted_extra", "{}"),
        )
        database.set_sqlalchemy_uri(uri)
        database.db_engine_spec.mutate_db_for_connection_test(database)

        # Build the engine as the acting user so impersonation applies.
        with override_user(self._actor):
            engine = database.get_sqla_engine()
        event_logger.log_with_context(
            action="test_connection_attempt",
            engine=database.db_engine_spec.__name__,
        )

        def ping(engine: Engine) -> bool:
            # Open a raw connection and ping it; closing() guarantees cleanup.
            with closing(engine.raw_connection()) as conn:
                return engine.dialect.do_ping(conn)

        try:
            alive = func_timeout(
                int(
                    app.config["TEST_DATABASE_CONNECTION_TIMEOUT"].total_seconds()
                ),
                ping,
                args=(engine,),
            )
        except (sqlite3.ProgrammingError, RuntimeError):
            # SQLite can't run on a separate thread, so ``func_timeout`` fails
            # RuntimeError catches the equivalent error from duckdb.
            alive = engine.dialect.do_ping(engine)
        except FunctionTimedOut as ex:
            raise SupersetTimeoutException(
                error_type=SupersetErrorType.CONNECTION_DATABASE_TIMEOUT,
                message=(
                    "Please check your connection details and database settings, "
                    "and ensure that your database is accepting connections, "
                    "then try connecting again."
                ),
                level=ErrorLevel.ERROR,
                extra={"sqlalchemy_uri": database.sqlalchemy_uri},
            ) from ex
        except Exception:  # pylint: disable=broad-except
            # Any other ping failure is reported as a dead connection below.
            alive = False
        if not alive:
            raise DBAPIError(None, None, None)

        # Log successful connection test with engine
        event_logger.log_with_context(
            action="test_connection_success",
            engine=database.db_engine_spec.__name__,
        )
    except (NoSuchModuleError, ModuleNotFoundError) as ex:
        event_logger.log_with_context(
            action=f"test_connection_error.{ex.__class__.__name__}",
            engine=database.db_engine_spec.__name__,
        )
        raise DatabaseTestConnectionDriverError(
            message=_("Could not load database driver: {}").format(
                database.db_engine_spec.__name__
            ),
        ) from ex
    except DBAPIError as ex:
        event_logger.log_with_context(
            action=f"test_connection_error.{ex.__class__.__name__}",
            engine=database.db_engine_spec.__name__,
        )
        # check for custom errors (wrong username, wrong password, etc)
        errors = database.db_engine_spec.extract_errors(ex, context)
        raise DatabaseTestConnectionFailedError(errors) from ex
    except SupersetSecurityException as ex:
        event_logger.log_with_context(
            action=f"test_connection_error.{ex.__class__.__name__}",
            engine=database.db_engine_spec.__name__,
        )
        raise DatabaseSecurityUnsafeError(message=str(ex)) from ex
    except SupersetTimeoutException as ex:
        event_logger.log_with_context(
            action=f"test_connection_error.{ex.__class__.__name__}",
            engine=database.db_engine_spec.__name__,
        )
        # bubble up the exception to return a 408
        raise ex
    except Exception as ex:
        event_logger.log_with_context(
            action=f"test_connection_error.{ex.__class__.__name__}",
            engine=database.db_engine_spec.__name__,
        )
        errors = database.db_engine_spec.extract_errors(ex, context)
        raise DatabaseTestConnectionUnexpectedError(errors) from ex
def load_explore_json_into_cache(  # pylint: disable=too-many-locals
    job_metadata: Dict[str, Any],
    form_data: Dict[str, Any],
    response_type: Optional[str] = None,
    force: bool = False,
) -> None:
    """Async task: run an explore_json query and cache form_data for retrieval.

    On success the async job is marked done with a URL the client can fetch
    the cached payload from; on failure the job is marked errored and the
    exception is re-raised.
    """
    cache_key_prefix = "ejr-"  # ejr: explore_json request

    # Resolve the requesting user; fall back to the anonymous user when the
    # id is missing or unknown.
    user = (
        security_manager.get_user_by_id(job_metadata.get("user_id"))
        or security_manager.get_anonymous_user()
    )
    # force=False: keep an already-set g.user if one exists.
    with override_user(user, force=False):
        try:
            set_form_data(form_data)
            datasource_id, datasource_type = get_datasource_info(
                None, None, form_data
            )

            # Perform a deep copy here so that below we can cache the original
            # value of the form_data object. This is necessary since the viz
            # objects modify the form_data object. If the modified version were
            # to be cached here, it will lead to a cache miss when clients
            # attempt to retrieve the value of the completed async query.
            original_form_data = copy.deepcopy(form_data)

            viz_obj = get_viz(
                datasource_type=cast(str, datasource_type),
                datasource_id=datasource_id,
                form_data=form_data,
                force=force,
            )
            # run query & cache results
            payload = viz_obj.get_payload()
            if viz_obj.has_error(payload):
                raise SupersetVizException(errors=payload["errors"])

            # Cache the original form_data value for async retrieval
            cache_value = {
                "form_data": original_form_data,
                "response_type": response_type,
            }
            cache_key = generate_cache_key(cache_value, cache_key_prefix)
            set_and_log_cache(cache_manager.cache, cache_key, cache_value)

            result_url = f"/superset/explore_json/data/{cache_key}"
            async_query_manager.update_job(
                job_metadata,
                async_query_manager.STATUS_DONE,
                result_url=result_url,
            )
        except SoftTimeLimitExceeded as ex:
            # Celery soft time limit hit while running the query.
            logger.warning(
                "A timeout occurred while loading explore json, error: %s", ex
            )
            raise ex
        except Exception as ex:
            if isinstance(ex, SupersetVizException):
                errors = ex.errors  # pylint: disable=no-member
            else:
                error = ex.message if hasattr(ex, "message") else str(ex)  # type: ignore # pylint: disable=no-member
                errors = [error]

            async_query_manager.update_job(
                job_metadata, async_query_manager.STATUS_ERROR, errors=errors
            )
            raise ex
def test_sql_lab_insert_rls(
    mocker: MockerFixture,
    session: Session,
    app_context: None,
) -> None:
    """
    Integration test for `insert_rls`.
    """
    from flask_appbuilder.security.sqla.models import Role, User
    from superset.connectors.sqla.models import RowLevelSecurityFilter, SqlaTable
    from superset.models.core import Database
    from superset.models.sql_lab import Query
    from superset.security.manager import SupersetSecurityManager
    from superset.sql_lab import execute_sql_statement
    from superset.utils.core import RowLevelSecurityFilterType

    engine = session.connection().engine
    Query.metadata.create_all(engine)  # pylint: disable=no-member

    # Seed an in-memory table with values 0..9.
    connection = engine.raw_connection()
    connection.execute("CREATE TABLE t (c INTEGER)")
    for i in range(10):
        connection.execute("INSERT INTO t VALUES (?)", (i,))

    cursor = connection.cursor()

    query = Query(
        sql="SELECT c FROM t",
        client_id="abcde",
        database=Database(database_name="test_db", sqlalchemy_uri="sqlite://"),
        schema=None,
        limit=5,
        select_as_cta_used=False,
    )
    session.add(query)
    session.commit()

    admin = User(
        first_name="Alice",
        last_name="Doe",
        email="*****@*****.**",
        username="******",
        roles=[Role(name="Admin")],
    )

    # first without RLS
    with override_user(admin):
        superset_result_set = execute_sql_statement(
            sql_statement=query.sql,
            query=query,
            session=session,
            cursor=cursor,
            log_params=None,
            apply_ctas=False,
        )
    # Without RLS all rows up to the limit are returned.
    assert (
        superset_result_set.to_pandas_df().to_markdown()
        == """
|    |   c |
|---:|----:|
|  0 |   0 |
|  1 |   1 |
|  2 |   2 |
|  3 |   3 |
|  4 |   4 |""".strip()
    )
    assert query.executed_sql == "SELECT c FROM t\nLIMIT 6"

    # now with RLS
    rls = RowLevelSecurityFilter(
        filter_type=RowLevelSecurityFilterType.REGULAR,
        tables=[SqlaTable(database_id=1, schema=None, table_name="t")],
        roles=[admin.roles[0]],
        group_key=None,
        clause="c > 5",
    )
    session.add(rls)
    session.flush()

    # Make RLS resolve the current user and ensure the feature is enabled.
    mocker.patch.object(SupersetSecurityManager, "find_user", return_value=admin)
    mocker.patch("superset.sql_lab.is_feature_enabled", return_value=True)

    with override_user(admin):
        superset_result_set = execute_sql_statement(
            sql_statement=query.sql,
            query=query,
            session=session,
            cursor=cursor,
            log_params=None,
            apply_ctas=False,
        )
    # With RLS only rows matching ``c > 5`` are returned.
    assert (
        superset_result_set.to_pandas_df().to_markdown()
        == """
|    |   c |
|---:|----:|
|  0 |   6 |
|  1 |   7 |
|  2 |   8 |
|  3 |   9 |""".strip()
    )
    assert query.executed_sql == "SELECT c FROM t WHERE (t.c > 5)\nLIMIT 6"
def run(self) -> None:
    """Validate connection parameters for a given engine and test-connect.

    :raises InvalidEngineError: If the engine is unknown or cannot be
        configured through parameters
    :raises InvalidParametersError: If parameter validation fails
    :raises DatabaseTestConnectionFailedError: If the test connection errors
    :raises DatabaseOfflineError: If the database does not answer the ping
    """
    engine = self._properties["engine"]
    engine_specs = get_engine_specs()

    if engine in BYPASS_VALIDATION_ENGINES:
        # Skip engines that are only validated onCreate
        return

    if engine not in engine_specs:
        raise InvalidEngineError(
            SupersetError(
                message=__(
                    'Engine "%(engine)s" is not a valid engine.',
                    engine=engine,
                ),
                error_type=SupersetErrorType.GENERIC_DB_ENGINE_ERROR,
                level=ErrorLevel.ERROR,
                extra={"allowed": list(engine_specs), "provided": engine},
            ),
        )
    engine_spec = engine_specs[engine]
    if not hasattr(engine_spec, "parameters_schema"):
        raise InvalidEngineError(
            SupersetError(
                message=__(
                    'Engine "%(engine)s" cannot be configured through parameters.',
                    engine=engine,
                ),
                error_type=SupersetErrorType.GENERIC_DB_ENGINE_ERROR,
                level=ErrorLevel.ERROR,
                extra={
                    # Only engines mixing in BasicParametersMixin support
                    # parameter-based configuration.
                    "allowed": [
                        name
                        for name, engine_spec in engine_specs.items()
                        if issubclass(engine_spec, BasicParametersMixin)
                    ],
                    "provided": engine,
                },
            ),
        )

    # perform initial validation
    errors = engine_spec.validate_parameters(  # type: ignore
        self._properties.get("parameters", {})
    )
    if errors:
        event_logger.log_with_context(action="validation_error", engine=engine)
        raise InvalidParametersError(errors)

    serialized_encrypted_extra = self._properties.get("encrypted_extra", "{}")
    try:
        encrypted_extra = json.loads(serialized_encrypted_extra)
    except json.decoder.JSONDecodeError:
        # Malformed JSON: fall back to an empty dict rather than failing here.
        encrypted_extra = {}

    # try to connect
    sqlalchemy_uri = engine_spec.build_sqlalchemy_uri(  # type: ignore
        self._properties.get("parameters"),
        encrypted_extra,
    )
    # A URI matching the stored model's masked form means the user did not
    # change it — use the stored, decrypted URI instead.
    if self._model and sqlalchemy_uri == self._model.safe_sqlalchemy_uri():
        sqlalchemy_uri = self._model.sqlalchemy_uri_decrypted
    # Build a transient Database object just for the test connection.
    database = DatabaseDAO.build_db_for_connection_test(
        server_cert=self._properties.get("server_cert", ""),
        extra=self._properties.get("extra", "{}"),
        impersonate_user=self._properties.get("impersonate_user", False),
        encrypted_extra=serialized_encrypted_extra,
    )
    database.set_sqlalchemy_uri(sqlalchemy_uri)
    database.db_engine_spec.mutate_db_for_connection_test(database)

    # Build the engine as the acting user so impersonation applies.
    with override_user(self._actor):
        engine = database.get_sqla_engine()

    try:
        with closing(engine.raw_connection()) as conn:
            alive = engine.dialect.do_ping(conn)
    except Exception as ex:  # pylint: disable=broad-except
        url = make_url_safe(sqlalchemy_uri)
        # context for error messages
        context = {
            "hostname": url.host,
            "password": url.password,
            "port": url.port,
            "username": url.username,
            "database": url.database,
        }
        # check for custom errors (wrong username, wrong password, etc)
        errors = database.db_engine_spec.extract_errors(ex, context)
        raise DatabaseTestConnectionFailedError(errors) from ex

    if not alive:
        raise DatabaseOfflineError(
            SupersetError(
                message=__("Database is offline."),
                error_type=SupersetErrorType.GENERIC_DB_ENGINE_ERROR,
                level=ErrorLevel.ERROR,
            ),
        )