def test_database_schema_presto(self):
    """Presto keeps the catalog and swaps only the schema segment of the URL.

    Fix: ``assertEquals`` is a deprecated alias (removed in Python 3.12);
    use ``assertEqual``.
    """
    sqlalchemy_uri = 'presto://presto.airbnb.io:8080/hive/default'
    model = Database(sqlalchemy_uri=sqlalchemy_uri)

    # Default engine keeps the catalog/schema pair from the URI.
    db = make_url(model.get_sqla_engine().url).database
    self.assertEqual('hive/default', db)

    # An explicit schema replaces the schema part but keeps the catalog.
    db = make_url(model.get_sqla_engine(schema='core_db').url).database
    self.assertEqual('hive/core_db', db)

    sqlalchemy_uri = 'presto://presto.airbnb.io:8080/hive'
    model = Database(sqlalchemy_uri=sqlalchemy_uri)

    # Catalog-only URI: no schema segment by default.
    db = make_url(model.get_sqla_engine().url).database
    self.assertEqual('hive', db)

    # An explicit schema is appended after the catalog.
    db = make_url(model.get_sqla_engine(schema='core_db').url).database
    self.assertEqual('hive/core_db', db)
def test_database_schema_postgres(self):
    """Postgres ignores the ``schema`` argument: the database name is fixed.

    Fix: ``assertEquals`` is a deprecated alias (removed in Python 3.12);
    use ``assertEqual``.
    """
    sqlalchemy_uri = 'postgresql+psycopg2://postgres.airbnb.io:5439/prod'
    model = Database(sqlalchemy_uri=sqlalchemy_uri)

    db = make_url(model.get_sqla_engine().url).database
    self.assertEqual('prod', db)

    # Passing a schema must not change the database in the URL.
    db = make_url(model.get_sqla_engine(schema='foo').url).database
    self.assertEqual('prod', db)
def test_database_schema_presto(self):
    """Presto keeps the catalog and swaps only the schema segment of the URL.

    Fix: ``assertEquals`` is a deprecated alias (removed in Python 3.12);
    use ``assertEqual``.
    """
    sqlalchemy_uri = "presto://presto.airbnb.io:8080/hive/default"
    model = Database(sqlalchemy_uri=sqlalchemy_uri)

    # Default engine keeps the catalog/schema pair from the URI.
    db = make_url(model.get_sqla_engine().url).database
    self.assertEqual("hive/default", db)

    # An explicit schema replaces the schema part but keeps the catalog.
    db = make_url(model.get_sqla_engine(schema="core_db").url).database
    self.assertEqual("hive/core_db", db)

    sqlalchemy_uri = "presto://presto.airbnb.io:8080/hive"
    model = Database(sqlalchemy_uri=sqlalchemy_uri)

    # Catalog-only URI: no schema segment by default.
    db = make_url(model.get_sqla_engine().url).database
    self.assertEqual("hive", db)

    # An explicit schema is appended after the catalog.
    db = make_url(model.get_sqla_engine(schema="core_db").url).database
    self.assertEqual("hive/core_db", db)
def test_database_schema_mysql(self):
    """MySQL replaces the database in the URL with the requested schema.

    Fix: ``assertEquals`` is a deprecated alias (removed in Python 3.12);
    use ``assertEqual``.
    """
    sqlalchemy_uri = "mysql://root@localhost/superset"
    model = Database(sqlalchemy_uri=sqlalchemy_uri)

    db = make_url(model.get_sqla_engine().url).database
    self.assertEqual("superset", db)

    # In MySQL "schema" and "database" are the same concept.
    db = make_url(model.get_sqla_engine(schema="staging").url).database
    self.assertEqual("staging", db)
def test_database_schema_postgres(self):
    """The Postgres engine URL keeps its database regardless of ``schema``."""
    uri = "postgresql+psycopg2://postgres.airbnb.io:5439/prod"
    model = Database(database_name="test_database", sqlalchemy_uri=uri)

    # With and without an explicit schema, the database stays "prod".
    self.assertEqual("prod", make_url(model.get_sqla_engine().url).database)
    self.assertEqual(
        "prod", make_url(model.get_sqla_engine(schema="foo").url).database
    )
def test_database_schema_hive(self):
    """Hive swaps the URL's database for the requested schema."""
    uri = "hive://[email protected]:10000/default?auth=NOSASL"
    model = Database(database_name="test_database", sqlalchemy_uri=uri)

    # No schema requested: database from the URI is used.
    self.assertEqual("default", make_url(model.get_sqla_engine().url).database)
    # Explicit schema replaces the database segment.
    self.assertEqual(
        "core_db", make_url(model.get_sqla_engine(schema="core_db").url).database
    )
def sql_limit_regex(self, sql, expected_sql, engine_spec_class=MySQLEngineSpec, limit=1000):
    """Assert that applying *limit* to *sql* via *engine_spec_class* yields *expected_sql*."""
    database = Database(database_name="test_database", sqlalchemy_uri="sqlite://")
    actual = engine_spec_class.apply_limit_to_sql(sql, limit, database)
    self.assertEqual(expected_sql, actual)
def build_db_for_connection_test(server_cert: str, extra: str, impersonate_user: bool, encrypted_extra: str) -> Database:
    """Construct a throwaway ``Database`` configured for a connection test."""
    return Database(
        extra=extra,
        encrypted_extra=encrypted_extra,
        server_cert=server_cert,
        impersonate_user=impersonate_user,
    )
def test_dataset_model(app_context: None, session: Session) -> None:
    """
    Test basic attributes of a ``Dataset``.
    """
    # Imports are deferred so the test session/metadata are set up first.
    from superset.columns.models import Column
    from superset.datasets.models import Dataset
    from superset.models.core import Database
    from superset.tables.models import Table

    engine = session.get_bind()
    Dataset.metadata.create_all(engine)  # pylint: disable=no-member

    # Physical table backing the dataset, with two physical columns.
    table = Table(
        name="my_table",
        schema="my_schema",
        catalog="my_catalog",
        database=Database(database_name="my_database", sqlalchemy_uri="sqlite://"),
        columns=[
            Column(name="longitude", expression="longitude"),
            Column(name="latitude", expression="latitude"),
        ],
    )
    session.add(table)
    session.flush()

    # Virtual dataset defined by a SQL expression over the table.
    dataset = Dataset(
        database=table.database,
        name="positions",
        expression=""" SELECT array_agg(array[longitude,latitude]) AS position FROM my_catalog.my_schema.my_table """,
        tables=[table],
        columns=[
            Column(
                name="position",
                expression="array_agg(array[longitude,latitude])",
            ),
        ],
    )
    session.add(dataset)
    session.flush()

    # Flushing assigns the primary key and UUID.
    assert dataset.id == 1
    assert dataset.uuid is not None
    assert dataset.name == "positions"
    assert (
        dataset.expression
        == """ SELECT array_agg(array[longitude,latitude]) AS position FROM my_catalog.my_schema.my_table """
    )
    assert [table.name for table in dataset.tables] == ["my_table"]
    assert [column.name for column in dataset.columns] == ["position"]
def test_alter_new_orm_column(self):
    """DB Eng Specs (crate): Test alter orm column"""
    crate_db = Database(database_name="crate", sqlalchemy_uri="crate://db")
    table = SqlaTable(table_name="druid_tbl", database=crate_db)
    column = TableColumn(column_name="ts", type="TIMESTAMP", table=table)

    CrateEngineSpec.alter_new_orm_column(column)

    # Crate timestamps are expected in epoch milliseconds.
    assert column.python_date_format == "epoch_ms"
def test_database_connection_test_mutator(self):
    """Snowflake's connection-test mutator must enable default-parameter validation."""
    db_model = Database(sqlalchemy_uri="snowflake://abc")
    SnowflakeEngineSpec.mutate_db_for_connection_test(db_model)

    expected = {
        "engine_params": {"connect_args": {"validate_default_parameters": True}}
    }
    self.assertDictEqual(expected, json.loads(db_model.extra or "{}"))
def test_postgres_mixedcase_col_time_grain(self):
    """Mixed-case column names must stay double-quoted inside time-grain SQL."""
    uri = 'postgresql+psycopg2://uid:pwd@localhost:5432/superset'
    database = Database(sqlalchemy_uri=uri)

    time_grain = 'P1D'
    column_name = 'MixedCaseCol'
    pdf = ''
    expression = ''

    grain = database.grains_dict().get(time_grain)
    col = database.db_engine_spec.get_timestamp_column(expression, column_name)
    grain_expr = database.db_engine_spec.get_time_expr(col, pdf, time_grain, grain)

    # The grain template should receive the quoted (case-preserving) name.
    expected = grain.function.replace('{col}', f'"{column_name}"')
    self.assertEqual(grain_expr, expected)
def test_get_sqla_engine(self, mocked_create_engine):
    """Engine-creation failures are translated via the DBAPI exception mapping."""
    model = Database(
        database_name="test_database",
        sqlalchemy_uri="mysql://root@localhost",
    )
    # Map the generic Exception onto SupersetException for this spec.
    mapping = {Exception: SupersetException}
    model.db_engine_spec.get_dbapi_exception_mapping = mock.Mock(return_value=mapping)
    mocked_create_engine.side_effect = Exception()

    with self.assertRaises(SupersetException):
        model.get_sqla_engine()
def test_set_perm_slice(self):
    """Slice permissions are copied from the table at creation/update time only.

    Fixes: ``assertEquals``/``assertNotEquals`` are deprecated aliases
    (removed in Python 3.12) → ``assertEqual``/``assertNotEqual``; the local
    ``slice`` shadowed the builtin and is renamed to ``slc``.
    """
    session = db.session
    database = Database(database_name="tmp_database", sqlalchemy_uri="sqlite://test")
    table = SqlaTable(table_name="tmp_perm_table", database=database)
    session.add(database)
    session.add(table)
    session.commit()

    # no schema permission
    slc = Slice(
        datasource_id=table.id,
        datasource_type="table",
        datasource_name="tmp_perm_table",
        slice_name="slice_name",
    )
    session.add(slc)
    session.commit()

    slc = session.query(Slice).filter_by(slice_name="slice_name").one()
    self.assertEqual(slc.perm, table.perm)
    self.assertEqual(slc.perm, f"[tmp_database].[tmp_perm_table](id:{table.id})")
    self.assertEqual(slc.schema_perm, table.schema_perm)
    self.assertIsNone(slc.schema_perm)

    table.schema = "tmp_perm_schema"
    table.table_name = "tmp_perm_table_v2"
    session.commit()
    # TODO(bogdan): modify slice permissions on the table update.
    self.assertNotEqual(slc.perm, table.perm)
    self.assertEqual(slc.perm, f"[tmp_database].[tmp_perm_table](id:{table.id})")
    self.assertEqual(
        table.perm, f"[tmp_database].[tmp_perm_table_v2](id:{table.id})"
    )
    # TODO(bogdan): modify slice schema permissions on the table update.
    self.assertNotEqual(slc.schema_perm, table.schema_perm)
    self.assertIsNone(slc.schema_perm)

    # updating slice refreshes the permissions
    slc.slice_name = "slice_name_v2"
    session.commit()
    self.assertEqual(slc.perm, table.perm)
    self.assertEqual(
        slc.perm, f"[tmp_database].[tmp_perm_table_v2](id:{table.id})"
    )
    self.assertEqual(slc.schema_perm, table.schema_perm)
    self.assertEqual(slc.schema_perm, "[tmp_database].[tmp_perm_schema]")

    session.delete(slc)
    session.delete(table)
    session.delete(database)
    session.commit()
def test_database_impersonate_user(self):
    """Impersonation swaps the engine URL's username for the requested user.

    Fix: ``assertEquals``/``assertNotEquals`` are deprecated aliases
    (removed in Python 3.12) → ``assertEqual``/``assertNotEqual``.
    """
    uri = 'mysql://root@localhost'
    example_user = '******'
    model = Database(sqlalchemy_uri=uri)

    model.impersonate_user = True
    user_name = make_url(model.get_sqla_engine(user_name=example_user).url).username
    self.assertEqual(example_user, user_name)

    model.impersonate_user = False
    user_name = make_url(model.get_sqla_engine(user_name=example_user).url).username
    self.assertNotEqual(example_user, user_name)
def test_is_time_druid_time_col(self):
    """Druid has a special __time column"""
    druid_db = Database(database_name="druid_db", sqlalchemy_uri="druid://db")
    table = SqlaTable(table_name="druid_tbl", database=druid_db)

    time_col = TableColumn(column_name="__time", type="INTEGER", table=table)
    self.assertEqual(time_col.is_dttm, None)
    # The Druid spec flags __time as a datetime column.
    DruidEngineSpec.alter_new_orm_column(time_col)
    self.assertEqual(time_col.is_dttm, True)

    # Any other column name is left non-temporal.
    other_col = TableColumn(column_name="__not_time", type="INTEGER", table=table)
    self.assertEqual(other_col.is_temporal, False)
def test_cascade_delete_dataset(app_context: None, session: Session) -> None:
    """
    Test that deleting ``Dataset`` also deletes its columns.
    """
    # Imports are deferred so the test session/metadata are set up first.
    from superset.columns.models import Column
    from superset.datasets.models import Dataset
    from superset.models.core import Database
    from superset.tables.models import Table

    engine = session.get_bind()
    Dataset.metadata.create_all(engine)  # pylint: disable=no-member

    # Physical table with two physical columns.
    table = Table(
        name="my_table",
        schema="my_schema",
        catalog="my_catalog",
        database=Database(database_name="my_database", sqlalchemy_uri="sqlite://"),
        columns=[
            Column(name="longitude", expression="longitude"),
            Column(name="latitude", expression="latitude"),
        ],
    )
    session.add(table)
    session.flush()

    # Virtual dataset with one derived column on top of the table.
    dataset = Dataset(
        name="positions",
        expression=""" SELECT array_agg(array[longitude,latitude]) AS position FROM my_catalog.my_schema.my_table """,
        database=table.database,
        tables=[table],
        columns=[
            Column(
                name="position",
                expression="array_agg(array[longitude,latitude])",
            ),
        ],
    )
    session.add(dataset)
    session.flush()

    # 2 physical table columns + 1 dataset column.
    columns = session.query(Column).all()
    assert len(columns) == 3

    session.delete(dataset)
    session.flush()

    # test that dataset columns were deleted (but not table columns)
    columns = session.query(Column).all()
    assert len(columns) == 2
def test_database_for_various_backend(self):
    """Each backend resolves (URL, schema) to a database name in its own way."""
    # (uri, expected with schema=None, schema argument, expected with schema)
    cases = [
        # Presto: schema is appended to the catalog.
        ('presto://presto.airbnb.io:8080/hive/default',
         'hive/default', 'raw_data', 'hive/raw_data'),
        # Redshift: schema argument is ignored.
        ('redshift+psycopg2://superset:[email protected]:5439/prod',
         'prod', 'test', 'prod'),
        # Postgres: schema argument is ignored.
        ('postgresql+psycopg2://superset:[email protected]:5439/prod',
         'prod', 'adhoc', 'prod'),
        # Hive: schema replaces the database.
        ('hive://[email protected]:10000/raw_data',
         'raw_data', 'adhoc', 'adhoc'),
        # MySQL: schema replaces the database.
        ('mysql://*****:*****@mysql.airbnb.io/superset',
         'superset', 'adhoc', 'adhoc'),
    ]
    for uri, default_db, schema, schema_db in cases:
        model = Database(sqlalchemy_uri=uri)
        url = make_url(model.sqlalchemy_uri)
        assert model.get_database_for_various_backend(url, None) == default_db
        assert model.get_database_for_various_backend(url, schema) == schema_db
def test_database_impersonate_user(self):
    """With impersonation on, the engine URL carries the current user's name."""
    example_user = security_manager.find_user(username="******")
    model = Database(
        database_name="test_database", sqlalchemy_uri="mysql://root@localhost"
    )

    with override_user(example_user):
        model.impersonate_user = True
        self.assertEqual(
            example_user.username,
            make_url(model.get_sqla_engine().url).username,
        )

        model.impersonate_user = False
        self.assertNotEqual(
            example_user.username,
            make_url(model.get_sqla_engine().url).username,
        )
def test_database_connection_test_mutator() -> None:
    """Snowflake's connection-test mutator must enable default-parameter validation."""
    from superset.db_engine_specs.snowflake import SnowflakeEngineSpec
    from superset.models.core import Database

    database = Database(sqlalchemy_uri="snowflake://abc")
    SnowflakeEngineSpec.mutate_db_for_connection_test(database)

    expected = {
        "engine_params": {"connect_args": {"validate_default_parameters": True}}
    }
    assert json.loads(database.extra or "{}") == expected
def test_impersonate_user_hive(self, mocked_create_engine):
    """
    With ``impersonate_user`` on, the Hive spec injects a
    ``hive.server2.proxy.user`` entry into ``connect_args``; with it off,
    the configured ``connect_args`` pass through unchanged.
    """
    uri = "hive://localhost"
    principal_user = "******"
    # Database "extra" payload with pre-configured connect_args (redacted).
    extra = """ { "metadata_params": {}, "engine_params": { "connect_args":{ "protocol": "https", "username":"******", "password":"******" } }, "metadata_cache_timeout": {}, "schemas_allowed_for_csv_upload": [] } """
    model = Database(database_name="test_database", sqlalchemy_uri=uri, extra=extra)

    # Impersonation on: proxy-user configuration is added to connect_args.
    model.impersonate_user = True
    model.get_sqla_engine(user_name=principal_user)
    call_args = mocked_create_engine.call_args
    assert str(call_args[0][0]) == "hive://localhost"
    assert call_args[1]["connect_args"] == {
        "protocol": "https",
        "username": "******",
        "password": "******",
        "configuration": {"hive.server2.proxy.user": "******"},
    }

    # Impersonation off: connect_args pass through untouched.
    model.impersonate_user = False
    model.get_sqla_engine(user_name=principal_user)
    call_args = mocked_create_engine.call_args
    assert str(call_args[0][0]) == "hive://localhost"
    assert call_args[1]["connect_args"] == {
        "protocol": "https",
        "username": "******",
        "password": "******",
    }
def test_get_metrics(mocker: MockFixture) -> None:
    """
    Tests for ``get_metrics``.
    """
    from superset.db_engine_specs.base import MetricType
    from superset.db_engine_specs.sqlite import SqliteEngineSpec
    from superset.models.core import Database

    database = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
    # The default spec exposes the single built-in COUNT(*) metric.
    assert database.get_metrics("table") == [
        {
            "expression": "COUNT(*)",
            "metric_name": "count",
            "metric_type": "count",
            "verbose_name": "COUNT(*)",
        }
    ]

    class CustomSqliteEngineSpec(SqliteEngineSpec):
        """Engine spec advertising a custom default metric."""

        @classmethod
        def get_metrics(
            cls,
            database: Database,
            inspector: Inspector,
            table_name: str,
            schema: Optional[str],
        ) -> List[MetricType]:
            return [
                {
                    "expression": "COUNT(DISTINCT user_id)",
                    "metric_name": "count_distinct_user_id",
                    "metric_type": "count_distinct",
                    "verbose_name": "COUNT(DISTINCT user_id)",
                },
            ]

    # Swap in the custom spec; the database must surface its metrics instead.
    database.get_db_engine_spec_for_backend = mocker.MagicMock(  # type: ignore
        return_value=CustomSqliteEngineSpec
    )
    assert database.get_metrics("table") == [
        {
            "expression": "COUNT(DISTINCT user_id)",
            "metric_name": "count_distinct_user_id",
            "metric_type": "count_distinct",
            "verbose_name": "COUNT(DISTINCT user_id)",
        },
    ]
def test_impersonate_user_presto(self, mocked_create_engine):
    """
    With ``impersonate_user`` on, the Presto spec rewrites the URL's username
    to the current user and records the principal in ``connect_args``; with
    it off, URL and connect_args pass through unchanged.
    """
    uri = "presto://localhost"
    principal_user = security_manager.find_user(username="******")
    # Database "extra" payload with pre-configured connect_args (redacted).
    extra = """ { "metadata_params": {}, "engine_params": { "connect_args":{ "protocol": "https", "username":"******", "password":"******" } }, "metadata_cache_timeout": {}, "schemas_allowed_for_file_upload": [] } """
    with override_user(principal_user):
        model = Database(
            database_name="test_database", sqlalchemy_uri=uri, extra=extra
        )

        # Impersonation on: username injected into URL and connect_args.
        model.impersonate_user = True
        model.get_sqla_engine()
        call_args = mocked_create_engine.call_args
        assert str(call_args[0][0]) == "presto://gamma@localhost"
        assert call_args[1]["connect_args"] == {
            "protocol": "https",
            "username": "******",
            "password": "******",
            "principal_username": "******",
        }

        # Impersonation off: original URL and connect_args untouched.
        model.impersonate_user = False
        model.get_sqla_engine()
        call_args = mocked_create_engine.call_args
        assert str(call_args[0][0]) == "presto://localhost"
        assert call_args[1]["connect_args"] == {
            "protocol": "https",
            "username": "******",
            "password": "******",
        }
def test_hybrid_perm_database(self):
    """The ``perm`` hybrid property is queryable and matches ``get_perm()``."""
    database = Database(
        database_name="tmp_database3", sqlalchemy_uri="sqlite://test"
    )
    db.session.add(database)

    database_id = (
        db.session.query(Database.id)
        .filter_by(database_name="tmp_database3")
        .scalar()
    )
    # The hybrid attribute must be usable as a filter expression.
    record = (
        db.session.query(Database)
        .filter_by(perm=f"[tmp_database3].(id:{database_id})")
        .one()
    )

    self.assertEqual(record.get_perm(), record.perm)
    self.assertEqual(record.id, database_id)
    self.assertEqual(record.database_name, "tmp_database3")

    db.session.delete(database)
    db.session.commit()
def test_update_sqlatable_metric(mocker: MockFixture, app_context: None, session: Session) -> None: """ Test that updating a ``SqlaTable`` also updates the corresponding ``Dataset``. For this test we check that updating the SQL expression in a metric belonging to a ``SqlaTable`` is reflected in the ``Dataset`` metric. """ # patch session mocker.patch("superset.security.SupersetSecurityManager.get_session", return_value=session) from superset.columns.models import Column from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn from superset.datasets.models import Dataset from superset.models.core import Database from superset.tables.models import Table engine = session.get_bind() Dataset.metadata.create_all(engine) # pylint: disable=no-member columns = [ TableColumn(column_name="ds", is_dttm=1, type="TIMESTAMP"), ] metrics = [ SqlMetric(metric_name="cnt", expression="COUNT(*)"), ] sqla_table = SqlaTable( table_name="old_dataset", columns=columns, metrics=metrics, database=Database(database_name="my_database", sqlalchemy_uri="sqlite://"), ) session.add(sqla_table) session.flush() # check that the metric was created column = session.query(Column).filter_by(is_physical=False).one() assert column.expression == "COUNT(*)" # change the metric definition sqla_table.metrics[0].expression = "MAX(ds)" session.flush() assert column.expression == "MAX(ds)"
def setUp(self):
    """Create a SQLite-backed datasource and load the fixture frame into it."""
    super(SqlaConnectorTestCase, self).setUp()
    database = Database(
        database_name='test_database',
        sqlalchemy_uri='sqlite:////tmp/test.db',
    )
    self.connection = database.get_sqla_engine().connect()
    self.datasource = SqlaTable(
        table_name='test_datasource',
        database=database,
        columns=self.columns,
        metrics=self.metrics,
    )
    # Write the fixture DataFrame, replacing any previous run's table.
    with database.get_sqla_engine().begin() as conn:
        self.df.to_sql(
            self.datasource.table_name,
            conn,
            if_exists='replace',
            index=False,
            dtype={'received': Date},
        )
def test_update_physical_sqlatable_schema( mocker: MockFixture, app_context: None, session: Session ) -> None: """ Test that updating a ``SqlaTable`` schema also updates the corresponding ``Dataset``. """ # patch session mocker.patch( "superset.security.SupersetSecurityManager.get_session", return_value=session ) mocker.patch("superset.datasets.dao.db.session", session) from superset.columns.models import Column from superset.connectors.sqla.models import SqlaTable, TableColumn from superset.datasets.models import Dataset from superset.models.core import Database from superset.tables.models import Table engine = session.get_bind() Dataset.metadata.create_all(engine) # pylint: disable=no-member columns = [ TableColumn(column_name="ds", is_dttm=1, type="TIMESTAMP"), ] sqla_table = SqlaTable( table_name="old_dataset", schema="old_schema", columns=columns, metrics=[], database=Database(database_name="my_database", sqlalchemy_uri="sqlite://"), ) session.add(sqla_table) session.flush() dataset = session.query(Dataset).one() assert dataset.tables[0].schema == "old_schema" assert dataset.tables[0].id == 1 sqla_table.schema = "new_schema" session.flush() new_dataset = session.query(Dataset).one() assert new_dataset.tables[0].schema == "new_schema" assert new_dataset.tables[0].id == 2
def test_query_no_access(mocker: MockFixture, app_context: AppContext) -> None:
    """A user who is neither admin, owner, nor granted access must be denied."""
    from superset.connectors.sqla.models import SqlaTable
    from superset.explore.utils import check_datasource_access
    from superset.models.core import Database
    from superset.models.sql_lab import Query

    with raises(SupersetSecurityException):
        mocker.patch(
            query_find_by_id,
            return_value=Query(database=Database(), sql="select * from foo"),
        )
        mocker.patch(query_datasources_by_name, return_value=[SqlaTable()])
        # Deny every access path.
        for check in (is_user_admin, is_owner, can_access):
            mocker.patch(check, return_value=False)
        check_datasource_access(
            datasource_id=1,
            datasource_type=DatasourceType.QUERY,
        )
def insert_database(
    self,
    database_name: str,
    sqlalchemy_uri: str,
    extra: str = "",
    encrypted_extra: str = "",
    server_cert: str = "",
    expose_in_sqllab: bool = False,
) -> Database:
    """Persist a new ``Database`` row with the given settings and return it."""
    new_database = Database(
        database_name=database_name,
        sqlalchemy_uri=sqlalchemy_uri,
        extra=extra,
        encrypted_extra=encrypted_extra,
        server_cert=server_cert,
        expose_in_sqllab=expose_in_sqllab,
    )
    session = db.session
    session.add(new_database)
    session.commit()
    return new_database
def setUpClass(cls):
    """Provision a test role/user plus two tables for role-scoping tests."""
    role = security_manager.role_model(name=TEST_ROLE)
    # Grant the role permission to add slices.
    role.permissions.append(
        security_manager.find_permission_view_menu('can_add', 'SliceModelView')
    )
    appbuilder.sm.add_user(
        TEST_USER, 'datasource', 'user', '*****@*****.**', role, 'general'
    )

    database = Database(database_name=TEST_DB)
    tables = [
        SqlaTable(table_name='table_for_test_role', database=database),
        SqlaTable(table_name='table_not_for_test_role', database=database),
    ]
    db.session.add_all(tables)
    db.session.commit()