def rejected_tables(self, sql: str, database: "Database", schema: str) -> Set["Table"]:
    """
    Return the set of SQL tables the current user cannot access.

    :param sql: The SQL statement
    :param database: The SQL database
    :param schema: Fallback schema for tables that specify none
    :returns: The rejected tables
    """
    # pylint: disable=import-outside-toplevel
    from superset.sql_parse import Table

    rejected: Set["Table"] = set()
    for candidate in sql_parse.ParsedQuery(sql).tables:
        # Resolve the effective schema only for the access check; the set
        # itself keeps the table exactly as it was parsed from the SQL.
        resolved = Table(candidate.table, candidate.schema or schema)
        if not self.can_access_table(database, resolved):
            rejected.add(candidate)
    return rejected
def test_create_table_from_csv_if_exists_replace_with_schema(
    mock_upload_to_s3, mock_table, mock_g
):
    """``if_exists=replace`` must drop the schema-qualified table first."""
    mock_upload_to_s3.return_value = "mock-location"
    mock_table.infer.return_value = {}
    mock_g.user = True

    mock_execute = mock.MagicMock(return_value=True)
    mock_database = mock.MagicMock()
    mock_database.get_df.return_value.empty = False
    mock_database.get_sqla_engine.return_value.execute = mock_execute

    schema, table_name = "schema", "foobar"
    HiveEngineSpec.create_table_from_csv(
        "foo.csv",
        Table(table=table_name, schema=schema),
        mock_database,
        {"sep": "mock", "header": 1, "na_values": "mock"},
        {"if_exists": "replace"},
    )
    # The DROP statement must include the schema prefix.
    mock_execute.assert_any_call(f"DROP TABLE IF EXISTS {schema}.{table_name}")
def test_extract_tables_complex_cte_with_prefix() -> None:
    """
    Test that the parser handles CTEs with prefixes.
    """
    sql = """
WITH CTE__test (SalesPersonID, SalesOrderID, SalesYear)
AS (
SELECT SalesPersonID, SalesOrderID, YEAR(OrderDate) AS SalesYear
FROM SalesOrderHeader
WHERE SalesPersonID IS NOT NULL
)
SELECT SalesPersonID, COUNT(SalesOrderID) AS TotalSales, SalesYear
FROM CTE__test
GROUP BY SalesYear, SalesPersonID
ORDER BY SalesPersonID, SalesYear;
"""
    # Only the physical table should be reported, not the CTE name.
    assert extract_tables(sql) == {Table("SalesOrderHeader")}
def test_extract_tables_union() -> None:
    """
    Test that set-operation queries (``UNION``, ``INTERSECT``) work as expected.
    """
    expected = {Table("t1"), Table("t2")}
    assert extract_tables("SELECT * FROM t1 UNION SELECT * FROM t2") == expected
    assert extract_tables("SELECT * FROM t1 UNION ALL SELECT * FROM t2") == expected
    assert extract_tables("SELECT * FROM t1 INTERSECT ALL SELECT * FROM t2") == expected
def test_df_to_sql_if_exists_replace(mock_upload_to_s3, mock_g):
    """``if_exists=replace`` must issue a DROP TABLE before the upload."""
    mock_upload_to_s3.return_value = "mock-location"
    mock_g.user = True

    mock_execute = mock.MagicMock(return_value=True)
    mock_database = mock.MagicMock()
    mock_database.get_df.return_value.empty = False
    mock_database.get_sqla_engine.return_value.execute = mock_execute

    table_name = "foobar"
    HiveEngineSpec.df_to_sql(
        mock_database,
        Table(table=table_name),
        pd.DataFrame(),
        {"if_exists": "replace", "header": 1, "na_values": "mock", "sep": "mock"},
    )
    mock_execute.assert_any_call(f"DROP TABLE IF EXISTS {table_name}")
def test_complex_extract_tables3(self):
    # Deeply nested CTEs, subqueries and UNIONs — including deliberately
    # malformed trailing commas — should still yield every physical table.
    query = """SELECT somecol AS somecol
FROM (WITH bla AS (
SELECT col_a
FROM a
WHERE 1=1
AND column_of_choice NOT IN (
SELECT interesting_col
FROM b
)
),
rb AS (
SELECT yet_another_column
FROM (
SELECT a
FROM c
GROUP BY the_other_col
) not_table
LEFT JOIN bla foo
ON foo.prop = not_table.bad_col0
WHERE 1=1
GROUP BY not_table.bad_col1 ,
not_table.bad_col2 ,
ORDER BY not_table.bad_col_3 DESC ,
not_table.bad_col4 ,
not_table.bad_col5)
SELECT random_col
FROM d
WHERE 1=1
UNION ALL SELECT even_more_cols
FROM e
WHERE 1=1
UNION ALL SELECT lets_go_deeper
FROM f
WHERE 1=1
WHERE 2=2
GROUP BY last_col
LIMIT 50000;"""
    expected = {Table(name) for name in "abcdef"}
    self.assertEqual(expected, self.extract_tables(query))
def test_subselect(self):
    # Subselect joined with a second table: both schemas are reported.
    sql_with_join = """
SELECT sub.*
FROM (
SELECT *
FROM s1.t1
WHERE day_of_week = 'Friday'
) sub, s2.t2
WHERE sub.resolution = 'NONE'
"""
    self.assertEqual(
        {Table("t1", "s1"), Table("t2", "s2")}, self.extract_tables(sql_with_join)
    )

    # Subselect alone: only the inner table is reported.
    sql_single = """
SELECT sub.*
FROM (
SELECT *
FROM s1.t1
WHERE day_of_week = 'Friday'
) sub
WHERE sub.resolution = 'NONE'
"""
    self.assertEqual({Table("t1", "s1")}, self.extract_tables(sql_single))

    # Tables nested several subquery levels deep are all discovered.
    sql_nested = """
SELECT * FROM t1
WHERE s11 > ANY
(SELECT COUNT(*) /* no hint */ FROM t2
WHERE NOT EXISTS
(SELECT * FROM t3
WHERE ROW(5*t2.s1,77)=
(SELECT 50,11*s1 FROM t4)));
"""
    self.assertEqual(
        {Table("t1"), Table("t2"), Table("t3"), Table("t4")},
        self.extract_tables(sql_nested),
    )
def test_extract_tables_subselect() -> None:
    """
    Test that tables inside subselects are parsed correctly.
    """
    sql_with_join = """
SELECT sub.*
FROM (
SELECT *
FROM s1.t1
WHERE day_of_week = 'Friday'
) sub, s2.t2
WHERE sub.resolution = 'NONE'
"""
    assert extract_tables(sql_with_join) == {Table("t1", "s1"), Table("t2", "s2")}

    sql_single = """
SELECT sub.*
FROM (
SELECT *
FROM s1.t1
WHERE day_of_week = 'Friday'
) sub
WHERE sub.resolution = 'NONE'
"""
    assert extract_tables(sql_single) == {Table("t1", "s1")}

    sql_nested = """
SELECT * FROM t1
WHERE s11 > ANY (
SELECT COUNT(*) /* no hint */ FROM t2
WHERE NOT EXISTS (
SELECT * FROM t3
WHERE ROW(5*t2.s1,77)=(
SELECT 50,11*s1 FROM t4
)
)
)
"""
    assert extract_tables(sql_nested) == {
        Table("t1"),
        Table("t2"),
        Table("t3"),
        Table("t4"),
    }
def test_extract_tables_where_subquery() -> None:
    """
    Test that tables in a ``WHERE`` subquery are parsed correctly.
    """
    expected = {Table("t1"), Table("t2")}

    # scalar subquery
    assert extract_tables("""
SELECT name
FROM t1
WHERE regionkey = (SELECT max(regionkey) FROM t2)
""") == expected

    # IN subquery
    assert extract_tables("""
SELECT name
FROM t1
WHERE regionkey IN (SELECT regionkey FROM t2)
""") == expected

    # EXISTS subquery
    assert extract_tables("""
SELECT name
FROM t1
WHERE regionkey EXISTS (SELECT regionkey FROM t2)
""") == expected
def test_where_subquery(self):
    # Scalar, IN, and EXISTS subqueries in WHERE all expose both tables.
    queries = (
        """
SELECT name
FROM t1
WHERE regionkey = (SELECT max(regionkey) FROM t2)
""",
        """
SELECT name
FROM t1
WHERE regionkey IN (SELECT regionkey FROM t2)
""",
        """
SELECT name
FROM t1
WHERE regionkey EXISTS (SELECT regionkey FROM t2)
""",
    )
    for query in queries:
        self.assertEqual({Table("t1"), Table("t2")}, self.extract_tables(query))
def test_show_columns(self):
    # SHOW COLUMNS exposes the inspected table.
    self.assertEqual(
        {Table("t1")}, self.extract_tables("SHOW COLUMNS FROM t1")
    )
def test_show_tables(self):
    sql = "SHOW TABLES FROM s1 like '%order%'"
    # TODO: figure out what should code do here — the parser currently
    # reports the schema name ``s1`` as if it were a table.
    self.assertEqual({Table("s1")}, self.extract_tables(sql))
def test_select_if(self):
    # IF() expressions in the select list must not confuse table extraction.
    sql = """
SELECT IF(CARDINALITY(my_array) >= 3, my_array[3], NULL)
FROM t1 LIMIT 10
"""
    self.assertEqual({Table("t1")}, self.extract_tables(sql))
def test_select_array(self):
    # ARRAY literals in the select list must not confuse table extraction.
    sql = """
SELECT ARRAY[1, 2, 3] AS my_array
FROM t1 LIMIT 10
"""
    self.assertEqual({Table("t1")}, self.extract_tables(sql))
def test_select_in_expression(self):
    # A scalar subquery in the select list contributes its table too.
    tables = self.extract_tables("SELECT f1, (SELECT count(1) FROM t2) FROM t1")
    self.assertEqual({Table("t1"), Table("t2")}, tables)
def test_extract_tables_describe() -> None:
    """
    Test ``DESCRIBE``.
    """
    assert {Table("t1")} == extract_tables("DESCRIBE t1")
def test_show_partitions(self):
    # SHOW PARTITIONS exposes the partitioned table.
    sql = """
SHOW PARTITIONS FROM orders
WHERE ds >= '2013-01-01'
ORDER BY ds DESC;
"""
    self.assertEqual({Table("orders")}, self.extract_tables(sql))
def test_select_named_table(self):
    # An aliased table is reported under its real name, not the alias.
    tables = self.extract_tables("SELECT a.date, a.field FROM left_table a LIMIT 10")
    self.assertEqual({Table("left_table")}, tables)
def test_extract_tables_parenthesis() -> None:
    """
    Test that parenthesis are parsed correctly.
    """
    sql = "SELECT f1, (x + y) AS f2 FROM t1"
    assert extract_tables(sql) == {Table("t1")}
def test_multistatement(self):
    # Both statements contribute tables, with or without a trailing semicolon.
    expected = {Table("t1"), Table("t2")}
    for query in (
        "SELECT * FROM t1; SELECT * FROM t2",
        "SELECT * FROM t1; SELECT * FROM t2;",
    ):
        self.assertEqual(expected, self.extract_tables(query))
def test_extract_tables() -> None:
    """
    Test that referenced tables are parsed correctly from the SQL.
    """
    simple = {Table("tbname")}
    underscored = {Table("tb_name")}

    # plain, aliased, and AS-aliased selects
    assert extract_tables("SELECT * FROM tbname") == simple
    assert extract_tables("SELECT * FROM tbname foo") == simple
    assert extract_tables("SELECT * FROM tbname AS foo") == simple

    # underscore
    assert extract_tables("SELECT * FROM tb_name") == underscored

    # quotes
    assert extract_tables('SELECT * FROM "tbname"') == simple

    # unicode
    assert (
        extract_tables('SELECT * FROM "tb_name" WHERE city = "Lübeck"') == underscored
    )

    # columns
    assert extract_tables("SELECT field1, field2 FROM tb_name") == underscored
    assert extract_tables("SELECT t1.f1, t2.f2 FROM t1, t2") == {
        Table("t1"),
        Table("t2"),
    }

    # named table
    assert extract_tables("SELECT a.date, a.field FROM left_table a LIMIT 10") == {
        Table("left_table")
    }

    # reverse select
    assert extract_tables("FROM t1 SELECT field") == {Table("t1")}
def test_extract_tables_complex() -> None:
    """
    Test a few complex queries.
    """
    # nested subqueries and multiple JOINs
    sql_joins = """
SELECT sum(m_examples) AS "sum__m_example"
FROM (
SELECT COUNT(DISTINCT id_userid) AS m_examples,
some_more_info
FROM my_b_table b
JOIN my_t_table t ON b.ds=t.ds
JOIN my_l_table l ON b.uid=l.uid
WHERE b.rid IN (
SELECT other_col
FROM inner_table
)
AND l.bla IN ('x', 'y')
GROUP BY 2
ORDER BY 2 ASC
) AS "meh"
ORDER BY "sum__m_example" DESC
LIMIT 10;
"""
    assert extract_tables(sql_joins) == {
        Table("my_l_table"),
        Table("my_b_table"),
        Table("my_t_table"),
        Table("inner_table"),
    }

    # comma-joined aliased tables
    sql_aliases = """
SELECT *
FROM table_a AS a, table_b AS b, table_c as c
WHERE a.id = b.id and b.id = c.id
"""
    assert extract_tables(sql_aliases) == {
        Table("table_a"),
        Table("table_b"),
        Table("table_c"),
    }

    # CTEs, UNIONs, and intentionally malformed trailing commas
    sql_ctes = """
SELECT somecol AS somecol
FROM (
WITH bla AS (
SELECT col_a
FROM a
WHERE 1=1
AND column_of_choice NOT IN (
SELECT interesting_col
FROM b
)
),
rb AS (
SELECT yet_another_column
FROM (
SELECT a
FROM c
GROUP BY the_other_col
) not_table
LEFT JOIN bla foo
ON foo.prop = not_table.bad_col0
WHERE 1=1
GROUP BY
not_table.bad_col1 ,
not_table.bad_col2 ,
ORDER BY
not_table.bad_col_3 DESC ,
not_table.bad_col4 ,
not_table.bad_col5
)
SELECT random_col
FROM d
WHERE 1=1
UNION ALL SELECT even_more_cols
FROM e
WHERE 1=1
UNION ALL SELECT lets_go_deeper
FROM f
WHERE 1=1
WHERE 2=2
GROUP BY last_col
LIMIT 50000
)
"""
    assert extract_tables(sql_ctes) == {
        Table("a"),
        Table("b"),
        Table("c"),
        Table("d"),
        Table("e"),
        Table("f"),
    }
def test_extract_tables_join() -> None:
    """
    Test joins.
    """
    assert extract_tables("SELECT t1.*, t2.* FROM t1 JOIN t2 ON t1.a = t2.a;") == {
        Table("t1"),
        Table("t2"),
    }

    # Each join flavor over a subquery exposes both underlying tables.
    expected = {Table("left_table"), Table("right_table")}

    assert extract_tables("""
SELECT a.date, b.name
FROM left_table a
JOIN (
SELECT
CAST((b.year) as VARCHAR) date,
name
FROM right_table
) b
ON a.date = b.date
""") == expected

    assert extract_tables("""
SELECT a.date, b.name
FROM left_table a
LEFT INNER JOIN (
SELECT
CAST((b.year) as VARCHAR) date,
name
FROM right_table
) b
ON a.date = b.date
""") == expected

    assert extract_tables("""
SELECT a.date, b.name
FROM left_table a
RIGHT OUTER JOIN (
SELECT
CAST((b.year) as VARCHAR) date,
name
FROM right_table
) b
ON a.date = b.date
""") == expected

    assert extract_tables("""
SELECT a.date, b.name
FROM left_table a
FULL OUTER JOIN (
SELECT
CAST((b.year) as VARCHAR) date,
name
FROM right_table
) b
ON a.date = b.date
""") == expected
def test_reverse_select(self):
    # FROM-first syntax is still parsed for its table.
    self.assertEqual({Table("t1")}, self.extract_tables("FROM t1 SELECT field"))
def test_describe(self):
    # DESCRIBE exposes the described table.
    tables = self.extract_tables("DESCRIBE t1")
    self.assertEqual({Table("t1")}, tables)
def form_post(  # pylint: disable=too-many-locals
    self, form: ColumnarToDatabaseForm
) -> Response:
    """
    Handle a columnar (parquet) file-upload form submission.

    Validates the uploaded files (or a single zip archive of them), loads
    them into the target database via the engine spec, registers/refreshes
    the corresponding ``SqlaTable``, and redirects with a flash message on
    either success or failure.
    """
    database = form.con.data
    columnar_table = Table(table=form.name.data, schema=form.schema.data)
    files = form.columnar_file.data
    # Set of file extensions present in the upload; must collapse to one.
    file_type = {file.filename.split(".")[-1] for file in files}

    if file_type == {"zip"}:
        # A single zip archive: replace both the extension set and the file
        # list with the archive's members.
        zipfile_ob = zipfile.ZipFile(  # pylint: disable=consider-using-with
            form.columnar_file.data[0]
        )  # pylint: disable=consider-using-with
        file_type = {filename.split(".")[-1] for filename in zipfile_ob.namelist()}
        files = [
            io.BytesIO((zipfile_ob.open(filename).read(), filename)[0])
            for filename in zipfile_ob.namelist()
        ]

    if len(file_type) > 1:
        # Mixed extensions (e.g. .parquet and .csv) are rejected outright.
        message = _(
            "Multiple file extensions are not allowed for columnar uploads."
            " Please make sure all files are of the same extension.",
        )
        flash(message, "danger")
        return redirect("/columnartodatabaseview/form")

    read = pd.read_parquet
    kwargs = {
        "columns": form.usecols.data if form.usecols.data else None,
    }

    if not schema_allows_csv_upload(database, columnar_table.schema):
        # Target schema is not on the database's upload allow-list.
        message = _(
            'Database "%(database_name)s" schema "%(schema_name)s" '
            "is not allowed for columnar uploads. "
            "Please contact your Superset Admin.",
            database_name=database.database_name,
            schema_name=columnar_table.schema,
        )
        flash(message, "danger")
        return redirect("/columnartodatabaseview/form")

    if "." in columnar_table.table and columnar_table.schema:
        # Ambiguous: schema given both as a "schema.table" name and in the
        # schema field.
        message = _(
            "You cannot specify a namespace both in the name of the table: "
            '"%(columnar_table.table)s" and in the schema field: '
            '"%(columnar_table.schema)s". Please remove one',
            table=columnar_table.table,
            schema=columnar_table.schema,
        )
        flash(message, "danger")
        return redirect("/columnartodatabaseview/form")

    try:
        # Read every uploaded file and concatenate into a single DataFrame.
        chunks = [read(file, **kwargs) for file in files]
        df = pd.concat(chunks)

        # Re-fetch the database by id to get a session-bound instance.
        database = (
            db.session.query(models.Database)
            .filter_by(id=form.data.get("con").data.get("id"))
            .one()
        )

        database.db_engine_spec.df_to_sql(
            database,
            columnar_table,
            df,
            to_sql_kwargs={
                "chunksize": 1000,
                "if_exists": form.if_exists.data,
                "index": form.index.data,
                "index_label": form.index_label.data,
            },
        )

        # Connect table to the database that should be used for exploration.
        # E.g. if hive was used to upload a csv, presto will be a better option
        # to explore the table.
        expore_database = database
        explore_database_id = database.explore_database_id
        if explore_database_id:
            expore_database = (
                db.session.query(models.Database)
                .filter_by(id=explore_database_id)
                .one_or_none()
                or database
            )

        sqla_table = (
            db.session.query(SqlaTable)
            .filter_by(
                table_name=columnar_table.table,
                schema=columnar_table.schema,
                database_id=expore_database.id,
            )
            .one_or_none()
        )

        if sqla_table:
            # Table already registered: just refresh its column metadata.
            sqla_table.fetch_metadata()
        if not sqla_table:
            sqla_table = SqlaTable(table_name=columnar_table.table)
            sqla_table.database = expore_database
            # NOTE(review): database_id is set to the *upload* database's id
            # while .database points at the explore database — confirm this
            # asymmetry is intentional.
            sqla_table.database_id = database.id
            sqla_table.user_id = g.user.get_id()
            sqla_table.schema = columnar_table.schema
            sqla_table.fetch_metadata()
            db.session.add(sqla_table)

        db.session.commit()
    except Exception as ex:  # pylint: disable=broad-except
        # Any failure rolls back the session and reports via flash + stats.
        db.session.rollback()
        message = _(
            'Unable to upload Columnar file "%(filename)s" to table '
            '"%(table_name)s" in database "%(db_name)s". '
            "Error message: %(error_msg)s",
            filename=[file.filename for file in form.columnar_file.data],
            table_name=form.name.data,
            db_name=database.database_name,
            error_msg=str(ex),
        )

        flash(message, "danger")
        stats_logger.incr("failed_columnar_upload")
        return redirect("/columnartodatabaseview/form")

    # Go back to welcome page / splash screen
    message = _(
        'Columnar file "%(columnar_filename)s" uploaded to table "%(table_name)s" '
        'in database "%(db_name)s"',
        columnar_filename=[file.filename for file in form.columnar_file.data],
        table_name=str(columnar_table),
        db_name=sqla_table.database.database_name,
    )
    flash(message, "info")
    stats_logger.incr("successful_columnar_upload")
    return redirect("/tablemodelview/list/")
def test_join(self):
    self.assertEqual(
        {Table("t1"), Table("t2")},
        self.extract_tables("SELECT t1.*, t2.* FROM t1 JOIN t2 ON t1.a = t2.a;"),
    )

    # subquery + join: every join flavor exposes both underlying tables
    expected = {Table("left_table"), Table("right_table")}

    query = """
SELECT a.date, b.name
FROM left_table a
JOIN (
SELECT
CAST((b.year) as VARCHAR) date,
name
FROM right_table
) b
ON a.date = b.date
"""
    self.assertEqual(expected, self.extract_tables(query))

    query = """
SELECT a.date, b.name
FROM left_table a
LEFT INNER JOIN (
SELECT
CAST((b.year) as VARCHAR) date,
name
FROM right_table
) b
ON a.date = b.date
"""
    self.assertEqual(expected, self.extract_tables(query))

    query = """
SELECT a.date, b.name
FROM left_table a
RIGHT OUTER JOIN (
SELECT
CAST((b.year) as VARCHAR) date,
name
FROM right_table
) b
ON a.date = b.date
"""
    self.assertEqual(expected, self.extract_tables(query))

    query = """
SELECT a.date, b.name
FROM left_table a
FULL OUTER JOIN (
SELECT
CAST((b.year) as VARCHAR) date,
name
FROM right_table
) b
ON a.date = b.date
"""
    self.assertEqual(expected, self.extract_tables(query))
def test_extract_tables_show_columns_from() -> None:
    """
    Test ``SHOW COLUMNS FROM``.
    """
    assert {Table("t1")} == extract_tables("SHOW COLUMNS FROM t1")
def test_simple_select(self):
    simple = {Table("tbname")}
    schema_qualified = {Table("tbname", "schemaname")}

    # plain / aliased / AS-aliased selects
    self.assertEqual(simple, self.extract_tables("SELECT * FROM tbname"))
    self.assertEqual(simple, self.extract_tables("SELECT * FROM tbname foo"))
    self.assertEqual(simple, self.extract_tables("SELECT * FROM tbname AS foo"))

    # underscores
    self.assertEqual({Table("tb_name")}, self.extract_tables("SELECT * FROM tb_name"))

    # quotes
    self.assertEqual(simple, self.extract_tables('SELECT * FROM "tbname"'))

    # unicode encoding
    self.assertEqual(
        {Table("tb_name")},
        self.extract_tables('SELECT * FROM "tb_name" WHERE city = "Lübeck"'),
    )

    # schema
    self.assertEqual(
        schema_qualified, self.extract_tables("SELECT * FROM schemaname.tbname")
    )
    self.assertEqual(
        schema_qualified, self.extract_tables('SELECT * FROM "schemaname"."tbname"')
    )
    self.assertEqual(
        schema_qualified, self.extract_tables("SELECT * FROM schemaname.tbname foo")
    )
    self.assertEqual(
        schema_qualified,
        self.extract_tables("SELECT * FROM schemaname.tbname AS foo"),
    )
    self.assertEqual(
        {Table("tbname", "schemaname", "catalogname")},
        self.extract_tables("SELECT * FROM catalogname.schemaname.tbname"),
    )

    # Ill-defined cluster/schema/table.
    self.assertEqual(set(), self.extract_tables("SELECT * FROM schemaname."))
    self.assertEqual(
        set(), self.extract_tables("SELECT * FROM catalogname.schemaname.")
    )
    self.assertEqual(set(), self.extract_tables("SELECT * FROM catalogname.."))
    self.assertEqual(
        set(), self.extract_tables("SELECT * FROM catalogname..tbname")
    )

    # quotes
    self.assertEqual(
        {Table("tb_name")}, self.extract_tables("SELECT field1, field2 FROM tb_name")
    )
    self.assertEqual(
        {Table("t1"), Table("t2")},
        self.extract_tables("SELECT t1.f1, t2.f2 FROM t1, t2"),
    )
def raise_for_access(  # pylint: disable=too-many-arguments,too-many-locals
    self,
    database: Optional["Database"] = None,
    datasource: Optional["BaseDatasource"] = None,
    query: Optional["Query"] = None,
    query_context: Optional["QueryContext"] = None,
    table: Optional["Table"] = None,
    viz: Optional["BaseViz"] = None,
) -> None:
    """
    Raise an exception if the user cannot access the resource.

    :param database: The Superset database
    :param datasource: The Superset datasource
    :param query: The SQL Lab query
    :param query_context: The query context
    :param table: The Superset table (requires database)
    :param viz: The visualization
    :raises SupersetSecurityException: If the user cannot access the resource
    """
    # pylint: disable=import-outside-toplevel
    from superset.connectors.sqla.models import SqlaTable
    from superset.extensions import feature_flag_manager
    from superset.sql_parse import Table

    # NOTE: by Python precedence this reads (database and table) or query —
    # a query alone is enough to enter this branch.
    if database and table or query:
        if query:
            # The query carries its own database; override the argument.
            database = query.database

        database = cast("Database", database)

        # Database-wide access short-circuits all per-table checks.
        if self.can_access_database(database):
            return

        if query:
            # Parse every table referenced by the query's SQL; tables with no
            # explicit schema fall back to the query's schema.
            tables = {
                Table(table_.table, table_.schema or query.schema)
                for table_ in sql_parse.ParsedQuery(query.sql).tables
            }
        elif table:
            tables = {table}

        denied = set()

        for table_ in tables:
            # Schema-level permission grants access to every table in it.
            schema_perm = self.get_schema_perm(database, schema=table_.schema)

            if not (schema_perm and self.can_access("schema_access", schema_perm)):
                datasources = SqlaTable.query_datasources_by_name(
                    self.get_session, database, table_.table, schema=table_.schema
                )

                # Access to any matching datasource is sufficient.
                for datasource_ in datasources:
                    if self.can_access("datasource_access", datasource_.perm):
                        break
                else:
                    # No schema perm and no accessible datasource: deny.
                    denied.add(table_)

        if denied:
            raise SupersetSecurityException(
                self.get_table_access_error_object(denied)
            )

    if datasource or query_context or viz:
        # Resolve the datasource from whichever wrapper object was given.
        if query_context:
            datasource = query_context.datasource
        elif viz:
            datasource = viz.datasource

        assert datasource

        # Dashboard-based access is only consulted when RBAC is enabled or
        # the current user is a guest.
        should_check_dashboard_access = (
            feature_flag_manager.is_feature_enabled("DASHBOARD_RBAC")
            or self.is_guest_user()
        )

        if not (
            self.can_access_schema(datasource)
            or self.can_access("datasource_access", datasource.perm or "")
            or (
                should_check_dashboard_access
                and self.can_access_based_on_dashboard(datasource)
            )
        ):
            raise SupersetSecurityException(
                self.get_datasource_access_error_object(datasource)
            )