def test_database_schema_presto(self):
        """Check the database part of the engine URL for Presto connections.

        Two URIs are covered: one whose path is ``hive/default`` and one whose
        path is just ``hive``. For both, passing ``schema='core_db'`` to
        ``get_sqla_engine`` is expected to give ``hive/core_db``.
        """
        # assertEqual is used throughout: the assertEquals alias is deprecated
        # and no longer exists in Python 3.12.
        sqlalchemy_uri = 'presto://presto.airbnb.io:8080/hive/default'
        model = Database(sqlalchemy_uri=sqlalchemy_uri)

        db = make_url(model.get_sqla_engine().url).database
        self.assertEqual('hive/default', db)

        db = make_url(model.get_sqla_engine(schema='core_db').url).database
        self.assertEqual('hive/core_db', db)

        # Same checks for a URI whose path holds a single segment.
        sqlalchemy_uri = 'presto://presto.airbnb.io:8080/hive'
        model = Database(sqlalchemy_uri=sqlalchemy_uri)

        db = make_url(model.get_sqla_engine().url).database
        self.assertEqual('hive', db)

        db = make_url(model.get_sqla_engine(schema='core_db').url).database
        self.assertEqual('hive/core_db', db)
    def test_database_schema_postgres(self):
        """Check the database part of the engine URL for a PostgreSQL connection.

        The expected database is ``prod`` both without a schema and with
        ``schema='foo'``.
        """
        sqlalchemy_uri = 'postgresql+psycopg2://postgres.airbnb.io:5439/prod'
        model = Database(sqlalchemy_uri=sqlalchemy_uri)

        # assertEqual replaces the deprecated assertEquals alias
        # (no longer available in Python 3.12).
        db = make_url(model.get_sqla_engine().url).database
        self.assertEqual('prod', db)

        db = make_url(model.get_sqla_engine(schema='foo').url).database
        self.assertEqual('prod', db)
    def test_database_schema_presto(self):
        """Check the database part of the engine URL for Presto connections.

        Each case lists a connection URI, the database expected when no schema
        is passed to ``get_sqla_engine`` and the database expected when
        ``schema="core_db"`` is passed.
        """
        cases = (
            ("presto://presto.airbnb.io:8080/hive/default", "hive/default", "hive/core_db"),
            ("presto://presto.airbnb.io:8080/hive", "hive", "hive/core_db"),
        )
        for sqlalchemy_uri, expected_default, expected_override in cases:
            model = Database(sqlalchemy_uri=sqlalchemy_uri)

            # assertEqual replaces the deprecated assertEquals alias
            # (no longer available in Python 3.12).
            db = make_url(model.get_sqla_engine().url).database
            self.assertEqual(expected_default, db)

            db = make_url(model.get_sqla_engine(schema="core_db").url).database
            self.assertEqual(expected_override, db)
    def test_database_schema_mysql(self):
        """Check the database part of the engine URL for a MySQL connection.

        Without a schema the expected database is ``superset``; with
        ``schema="staging"`` the expected database is ``staging``.
        """
        sqlalchemy_uri = "mysql://root@localhost/superset"
        model = Database(sqlalchemy_uri=sqlalchemy_uri)

        # assertEqual replaces the deprecated assertEquals alias
        # (no longer available in Python 3.12).
        db = make_url(model.get_sqla_engine().url).database
        self.assertEqual("superset", db)

        db = make_url(model.get_sqla_engine(schema="staging").url).database
        self.assertEqual("staging", db)
示例#5
0
    def test_database_schema_postgres(self):
        """Engine URLs for a PostgreSQL model are expected to name ``prod``,
        with or without a schema passed to ``get_sqla_engine``."""
        model = Database(
            database_name="test_database",
            sqlalchemy_uri="postgresql+psycopg2://postgres.airbnb.io:5439/prod",
        )

        # No schema requested.
        plain_url = make_url(model.get_sqla_engine().url)
        self.assertEqual("prod", plain_url.database)

        # Schema requested.
        schema_url = make_url(model.get_sqla_engine(schema="foo").url)
        self.assertEqual("prod", schema_url.database)
    def test_database_schema_hive(self):
        """Check the database part of the engine URL for a Hive connection.

        Without a schema the expected database is ``default``; with
        ``schema="core_db"`` the expected database is ``core_db``.
        """
        # NOTE(review): the "user@host" authority below was restored from an
        # e-mail-obfuscation placeholder and follows the "*.airbnb.io" hosts of
        # the sibling tests - confirm against the intended value. Only the
        # database segment is asserted on.
        sqlalchemy_uri = "hive://hive@hive.airbnb.io:10000/default?auth=NOSASL"
        model = Database(database_name="test_database",
                         sqlalchemy_uri=sqlalchemy_uri)
        db = make_url(model.get_sqla_engine().url).database
        self.assertEqual("default", db)

        db = make_url(model.get_sqla_engine(schema="core_db").url).database
        self.assertEqual("core_db", db)
示例#7
0
 def sql_limit_regex(
     self, sql, expected_sql, engine_spec_class=MySQLEngineSpec, limit=1000
 ):
     """Assert that applying ``limit`` to ``sql`` produces ``expected_sql``.

     The limit is applied with ``engine_spec_class.apply_limit_to_sql``, which
     is handed a ``Database`` model whose URI is ``sqlite://``.
     """
     database = Database(database_name="test_database", sqlalchemy_uri="sqlite://")
     self.assertEqual(
         expected_sql,
         engine_spec_class.apply_limit_to_sql(sql, limit, database),
     )
示例#8
0
 def build_db_for_connection_test(
     server_cert: str, extra: str, impersonate_user: bool, encrypted_extra: str
 ) -> Database:
     """Return a new ``Database`` populated with the four given settings."""
     settings = {
         "server_cert": server_cert,
         "extra": extra,
         "impersonate_user": impersonate_user,
         "encrypted_extra": encrypted_extra,
     }
     return Database(**settings)
示例#9
0
def test_dataset_model(app_context: None, session: Session) -> None:
    """
    Store a ``Table`` and a ``Dataset`` built on it, then check the dataset's fields.
    """
    from superset.columns.models import Column
    from superset.datasets.models import Dataset
    from superset.models.core import Database
    from superset.tables.models import Table

    # Create the model tables on the engine the session is bound to.
    Dataset.metadata.create_all(session.get_bind())  # pylint: disable=no-member

    # Kept in one place so the stored value can be compared with what was passed in.
    position_sql = """
SELECT array_agg(array[longitude,latitude]) AS position
FROM my_catalog.my_schema.my_table
"""

    physical_table = Table(
        name="my_table",
        schema="my_schema",
        catalog="my_catalog",
        database=Database(database_name="my_database", sqlalchemy_uri="sqlite://"),
        columns=[
            Column(name=axis, expression=axis) for axis in ("longitude", "latitude")
        ],
    )
    session.add(physical_table)
    session.flush()

    dataset = Dataset(
        database=physical_table.database,
        name="positions",
        expression=position_sql,
        tables=[physical_table],
        columns=[
            Column(
                name="position",
                expression="array_agg(array[longitude,latitude])",
            ),
        ],
    )
    session.add(dataset)
    session.flush()

    assert dataset.id == 1
    assert dataset.uuid is not None

    assert dataset.name == "positions"
    assert dataset.expression == position_sql

    assert [linked.name for linked in dataset.tables] == ["my_table"]
    assert [col.name for col in dataset.columns] == ["position"]
示例#10
0
 def test_alter_new_orm_column(self):
     """
     DB Eng Specs (crate): ``alter_new_orm_column`` on a TIMESTAMP column is
     expected to set ``python_date_format`` to ``epoch_ms``
     """
     ts_column = TableColumn(
         column_name="ts",
         type="TIMESTAMP",
         table=SqlaTable(
             table_name="druid_tbl",
             database=Database(database_name="crate", sqlalchemy_uri="crate://db"),
         ),
     )

     CrateEngineSpec.alter_new_orm_column(ts_column)

     assert ts_column.python_date_format == "epoch_ms"
示例#11
0
    def test_database_connection_test_mutator(self):
        """After ``mutate_db_for_connection_test``, the database's ``extra`` JSON
        is expected to enable ``validate_default_parameters`` in ``connect_args``."""
        expected = {
            "engine_params": {"connect_args": {"validate_default_parameters": True}}
        }

        model = Database(sqlalchemy_uri="snowflake://abc")
        SnowflakeEngineSpec.mutate_db_for_connection_test(model)

        # An empty/missing ``extra`` is read as an empty JSON object.
        self.assertDictEqual(expected, json.loads(model.extra or "{}"))
示例#12
0
 def test_postgres_mixedcase_col_time_grain(self):
     """The P1D time-grain expression for a mixed-case column is expected to equal
     the grain's function template with ``{col}`` replaced by the double-quoted
     column name."""
     column_name = 'MixedCaseCol'
     time_grain = 'P1D'
     pdf = ''
     expression = ''
     database = Database(
         sqlalchemy_uri='postgresql+psycopg2://uid:pwd@localhost:5432/superset')

     grain = database.grains_dict().get(time_grain)
     timestamp_col = database.db_engine_spec.get_timestamp_column(
         expression, column_name)
     actual = database.db_engine_spec.get_time_expr(
         timestamp_col, pdf, time_grain, grain)

     quoted_name = f'"{column_name}"'
     expected = grain.function.replace('{col}', quoted_name)
     self.assertEqual(actual, expected)
示例#13
0
 def test_get_sqla_engine(self, mocked_create_engine):
     """A plain ``Exception`` from the mocked ``create_engine`` is expected to
     surface from ``get_sqla_engine`` as ``SupersetException``."""
     model = Database(
         database_name="test_database", sqlalchemy_uri="mysql://root@localhost"
     )

     # Have the engine spec report Exception -> SupersetException as its mapping.
     exception_mapping = {Exception: SupersetException}
     model.db_engine_spec.get_dbapi_exception_mapping = mock.Mock(
         return_value=exception_mapping
     )
     mocked_create_engine.side_effect = Exception()

     self.assertRaises(SupersetException, model.get_sqla_engine)
示例#14
0
    def test_set_perm_slice(self):
        """Check how ``Slice.perm`` / ``Slice.schema_perm`` relate to the table's values.

        Three stages are asserted on: right after the slice is created, after
        the table is renamed and given a schema (the slice is expected to keep
        its old values), and after the slice itself is updated (the slice is
        expected to match the table again). The created rows are deleted at
        the end.
        """
        session = db.session
        database = Database(database_name="tmp_database",
                            sqlalchemy_uri="sqlite://test")
        table = SqlaTable(table_name="tmp_perm_table", database=database)
        session.add(database)
        session.add(table)
        session.commit()

        # no schema permission
        # (the local is called ``slc`` so the builtin ``slice`` is not shadowed)
        slc = Slice(
            datasource_id=table.id,
            datasource_type="table",
            datasource_name="tmp_perm_table",
            slice_name="slice_name",
        )
        session.add(slc)
        session.commit()

        # assertEqual / assertNotEqual replace the deprecated assertEquals /
        # assertNotEquals aliases, which no longer exist in Python 3.12.
        slc = session.query(Slice).filter_by(slice_name="slice_name").one()
        self.assertEqual(slc.perm, table.perm)
        self.assertEqual(slc.perm,
                         f"[tmp_database].[tmp_perm_table](id:{table.id})")
        self.assertEqual(slc.schema_perm, table.schema_perm)
        self.assertIsNone(slc.schema_perm)

        table.schema = "tmp_perm_schema"
        table.table_name = "tmp_perm_table_v2"
        session.commit()
        # TODO(bogdan): modify slice permissions on the table update.
        self.assertNotEqual(slc.perm, table.perm)
        self.assertEqual(slc.perm,
                         f"[tmp_database].[tmp_perm_table](id:{table.id})")
        self.assertEqual(
            table.perm, f"[tmp_database].[tmp_perm_table_v2](id:{table.id})")
        # TODO(bogdan): modify slice schema permissions on the table update.
        self.assertNotEqual(slc.schema_perm, table.schema_perm)
        self.assertIsNone(slc.schema_perm)

        # updating slice refreshes the permissions
        slc.slice_name = "slice_name_v2"
        session.commit()
        self.assertEqual(slc.perm, table.perm)
        self.assertEqual(
            slc.perm, f"[tmp_database].[tmp_perm_table_v2](id:{table.id})")
        self.assertEqual(slc.schema_perm, table.schema_perm)
        self.assertEqual(slc.schema_perm,
                         "[tmp_database].[tmp_perm_schema]")

        # Remove everything this test created.
        session.delete(slc)
        session.delete(table)
        session.delete(database)

        session.commit()
    def test_database_impersonate_user(self):
        """Check that the engine URL username depends on ``impersonate_user``.

        With the flag on, the ``user_name`` passed to ``get_sqla_engine`` is
        expected as the URL username; with the flag off it is expected not to be.
        """
        uri = 'mysql://root@localhost'
        example_user = '******'
        model = Database(sqlalchemy_uri=uri)

        # assertEqual / assertNotEqual replace the deprecated assertEquals /
        # assertNotEquals aliases, which no longer exist in Python 3.12.
        model.impersonate_user = True
        user_name = make_url(model.get_sqla_engine(user_name=example_user).url).username
        self.assertEqual(example_user, user_name)

        model.impersonate_user = False
        user_name = make_url(model.get_sqla_engine(user_name=example_user).url).username
        self.assertNotEqual(example_user, user_name)
示例#16
0
    def test_is_time_druid_time_col(self):
        """Druid's ``__time`` column: ``is_dttm`` is expected to go from ``None`` to
        ``True`` once ``DruidEngineSpec.alter_new_orm_column`` has run."""
        druid_table = SqlaTable(
            table_name="druid_tbl",
            database=Database(database_name="druid_db", sqlalchemy_uri="druid://db"),
        )

        time_col = TableColumn(
            column_name="__time", type="INTEGER", table=druid_table
        )
        self.assertEqual(time_col.is_dttm, None)
        DruidEngineSpec.alter_new_orm_column(time_col)
        self.assertEqual(time_col.is_dttm, True)

        # A column with any other name is expected not to be temporal.
        other_col = TableColumn(
            column_name="__not_time", type="INTEGER", table=druid_table
        )
        self.assertEqual(other_col.is_temporal, False)
示例#17
0
def test_cascade_delete_dataset(app_context: None, session: Session) -> None:
    """
    Deleting a ``Dataset`` is expected to remove its own columns and leave the
    table's columns in place.
    """
    from superset.columns.models import Column
    from superset.datasets.models import Dataset
    from superset.models.core import Database
    from superset.tables.models import Table

    # Create the model tables on the engine the session is bound to.
    Dataset.metadata.create_all(session.get_bind())  # pylint: disable=no-member

    def stored_column_count() -> int:
        """Count every ``Column`` row currently visible to the session."""
        return len(session.query(Column).all())

    physical_table = Table(
        name="my_table",
        schema="my_schema",
        catalog="my_catalog",
        database=Database(database_name="my_database", sqlalchemy_uri="sqlite://"),
        columns=[
            Column(name=axis, expression=axis) for axis in ("longitude", "latitude")
        ],
    )
    session.add(physical_table)
    session.flush()

    dataset = Dataset(
        name="positions",
        expression="""
SELECT array_agg(array[longitude,latitude]) AS position
FROM my_catalog.my_schema.my_table
""",
        database=physical_table.database,
        tables=[physical_table],
        columns=[
            Column(
                name="position",
                expression="array_agg(array[longitude,latitude])",
            ),
        ],
    )
    session.add(dataset)
    session.flush()

    # two table columns + one dataset column
    assert stored_column_count() == 3

    session.delete(dataset)
    session.flush()

    # the dataset column is gone, the table columns remain
    assert stored_column_count() == 2
示例#18
0
    def test_database_for_various_backend(self):
        """Check ``get_database_for_various_backend`` for several connection URIs.

        Each case lists a connection URI, a schema to request, the database
        expected when ``None`` is passed as the schema, and the database
        expected when the schema is passed.
        """
        # NOTE(review): the redshift, postgresql and hive URIs used to contain an
        # e-mail-obfuscation placeholder where the "user[:password]@host"
        # authority belongs. Without an "@", the text after "superset:" is
        # read as a non-numeric port, which make_url rejects. The authorities
        # below are restored to follow the "*.airbnb.io" pattern of the other
        # cases - confirm against the intended values. Only the database
        # segment is asserted on.
        cases = (
            ('presto://presto.airbnb.io:8080/hive/default',
             'raw_data', 'hive/default', 'hive/raw_data'),
            ('redshift+psycopg2://superset:XXXXXXXXXX@redshift.airbnb.io:5439/prod',
             'test', 'prod', 'prod'),
            ('postgresql+psycopg2://superset:XXXXXXXXXX@postgres.airbnb.io:5439/prod',
             'adhoc', 'prod', 'prod'),
            ('hive://hive@hive.airbnb.io:10000/raw_data',
             'adhoc', 'raw_data', 'adhoc'),
            ('mysql://*****:*****@mysql.airbnb.io/superset',
             'adhoc', 'superset', 'adhoc'),
        )
        for sqlalchemy_uri, schema, expected_default, expected_with_schema in cases:
            model = Database(sqlalchemy_uri=sqlalchemy_uri)
            url = make_url(model.sqlalchemy_uri)

            # The URI is the assertion message so a failure names its case.
            db = model.get_database_for_various_backend(url, None)
            assert db == expected_default, sqlalchemy_uri
            db = model.get_database_for_various_backend(url, schema)
            assert db == expected_with_schema, sqlalchemy_uri
示例#19
0
    def test_database_impersonate_user(self):
        """With ``example_user`` overridden as the current user, the engine URL
        username is expected to match that user only while ``impersonate_user``
        is on."""
        uri = "mysql://root@localhost"
        example_user = security_manager.find_user(username="******")
        model = Database(database_name="test_database", sqlalchemy_uri=uri)

        def engine_username():
            """Username component of the URL of a freshly requested engine."""
            return make_url(model.get_sqla_engine().url).username

        with override_user(example_user):
            model.impersonate_user = True
            self.assertEqual(example_user.username, engine_username())

            model.impersonate_user = False
            self.assertNotEqual(example_user.username, engine_username())
示例#20
0
def test_database_connection_test_mutator() -> None:
    """After ``mutate_db_for_connection_test``, the database's ``extra`` JSON is
    expected to enable ``validate_default_parameters`` in ``connect_args``."""
    from superset.db_engine_specs.snowflake import SnowflakeEngineSpec
    from superset.models.core import Database

    expected = {
        "engine_params": {"connect_args": {"validate_default_parameters": True}}
    }

    model = Database(sqlalchemy_uri="snowflake://abc")
    SnowflakeEngineSpec.mutate_db_for_connection_test(model)

    # An empty/missing ``extra`` is read as an empty JSON object.
    actual = json.loads(model.extra or "{}")
    assert expected == actual
示例#21
0
    def test_impersonate_user_hive(self, mocked_create_engine):
        """Inspect the arguments reaching the mocked ``create_engine`` for a Hive
        database, first with ``impersonate_user`` on and then with it off."""
        uri = "hive://localhost"
        principal_user = "******"
        extra = """
                {
                    "metadata_params": {},
                    "engine_params": {
                               "connect_args":{
                                  "protocol": "https",
                                  "username":"******",
                                  "password":"******"
                               }
                    },
                    "metadata_cache_timeout": {},
                    "schemas_allowed_for_csv_upload": []
                }
                """
        # connect_args as configured in ``extra``
        configured_connect_args = {
            "protocol": "https",
            "username": "******",
            "password": "******",
        }

        model = Database(
            database_name="test_database", sqlalchemy_uri=uri, extra=extra
        )

        # Impersonation on: a proxy-user entry is expected under "configuration".
        model.impersonate_user = True
        model.get_sqla_engine(user_name=principal_user)
        positional, keyword = mocked_create_engine.call_args
        assert str(positional[0]) == "hive://localhost"
        assert keyword["connect_args"] == {
            **configured_connect_args,
            "configuration": {"hive.server2.proxy.user": "******"},
        }

        # Impersonation off: only the configured connect_args are expected.
        model.impersonate_user = False
        model.get_sqla_engine(user_name=principal_user)
        positional, keyword = mocked_create_engine.call_args
        assert str(positional[0]) == "hive://localhost"
        assert keyword["connect_args"] == configured_connect_args
示例#22
0
def test_get_metrics(mocker: MockFixture) -> None:
    """
    Check ``Database.get_metrics`` first as-is, then with the engine spec lookup
    mocked to return a custom spec.
    """
    from superset.db_engine_specs.base import MetricType
    from superset.db_engine_specs.sqlite import SqliteEngineSpec
    from superset.models.core import Database

    expected_default = [
        {
            "expression": "COUNT(*)",
            "metric_name": "count",
            "metric_type": "count",
            "verbose_name": "COUNT(*)",
        }
    ]
    expected_custom = [
        {
            "expression": "COUNT(DISTINCT user_id)",
            "metric_name": "count_distinct_user_id",
            "metric_type": "count_distinct",
            "verbose_name": "COUNT(DISTINCT user_id)",
        },
    ]

    database = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
    assert database.get_metrics("table") == expected_default

    class CustomSqliteEngineSpec(SqliteEngineSpec):
        """Engine spec whose ``get_metrics`` returns one fixed metric."""

        @classmethod
        def get_metrics(
            cls,
            database: Database,
            inspector: Inspector,
            table_name: str,
            schema: Optional[str],
        ) -> List[MetricType]:
            # Built from a separate literal so the comparison below is by value.
            return [
                {
                    "expression": "COUNT(DISTINCT user_id)",
                    "metric_name": "count_distinct_user_id",
                    "metric_type": "count_distinct",
                    "verbose_name": "COUNT(DISTINCT user_id)",
                },
            ]

    # Make the spec lookup on this instance hand back the custom spec.
    database.get_db_engine_spec_for_backend = mocker.MagicMock(  # type: ignore
        return_value=CustomSqliteEngineSpec
    )
    assert database.get_metrics("table") == expected_custom
示例#23
0
    def test_impersonate_user_presto(self, mocked_create_engine):
        """Inspect the arguments reaching the mocked ``create_engine`` for a Presto
        database, first with ``impersonate_user`` on and then with it off, while
        ``principal_user`` is overridden as the current user."""
        uri = "presto://localhost"
        principal_user = security_manager.find_user(username="******")
        extra = """
                {
                    "metadata_params": {},
                    "engine_params": {
                               "connect_args":{
                                  "protocol": "https",
                                  "username":"******",
                                  "password":"******"
                               }
                    },
                    "metadata_cache_timeout": {},
                    "schemas_allowed_for_file_upload": []
                }
                """
        # connect_args as configured in ``extra``
        configured_connect_args = {
            "protocol": "https",
            "username": "******",
            "password": "******",
        }

        with override_user(principal_user):
            model = Database(
                database_name="test_database", sqlalchemy_uri=uri, extra=extra
            )

            # Impersonation on: the URL is expected to carry a username and
            # connect_args to gain "principal_username".
            model.impersonate_user = True
            model.get_sqla_engine()
            positional, keyword = mocked_create_engine.call_args
            assert str(positional[0]) == "presto://gamma@localhost"
            assert keyword["connect_args"] == {
                **configured_connect_args,
                "principal_username": "******",
            }

            # Impersonation off: URL and connect_args are expected as configured.
            model.impersonate_user = False
            model.get_sqla_engine()
            positional, keyword = mocked_create_engine.call_args
            assert str(positional[0]) == "presto://localhost"
            assert keyword["connect_args"] == configured_connect_args
    def test_hybrid_perm_database(self):
        """A ``Database`` is expected to be queryable by its ``perm`` value, and
        that value is expected to equal ``get_perm()``."""
        session = db.session
        database = Database(
            database_name="tmp_database3", sqlalchemy_uri="sqlite://test"
        )
        session.add(database)

        database_id = (
            session.query(Database.id)
            .filter_by(database_name="tmp_database3")
            .scalar()
        )

        # Look the row up again, this time filtering on ``perm``.
        expected_perm = f"[tmp_database3].(id:{database_id})"
        record = session.query(Database).filter_by(perm=expected_perm).one()

        self.assertEqual(record.get_perm(), record.perm)
        self.assertEqual(record.id, database_id)
        self.assertEqual(record.database_name, "tmp_database3")

        session.delete(database)
        session.commit()
示例#25
0
def test_update_sqlatable_metric(
    mocker: MockFixture, app_context: None, session: Session
) -> None:
    """
    Changing a ``SqlaTable`` metric's SQL expression is expected to show up in
    the matching non-physical ``Column``.
    """
    # Make the security manager use the test session.
    mocker.patch(
        "superset.security.SupersetSecurityManager.get_session",
        return_value=session,
    )

    from superset.columns.models import Column
    from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn
    from superset.datasets.models import Dataset
    from superset.models.core import Database
    from superset.tables.models import Table

    # Create the model tables on the engine the session is bound to.
    Dataset.metadata.create_all(session.get_bind())  # pylint: disable=no-member

    sqla_table = SqlaTable(
        table_name="old_dataset",
        columns=[TableColumn(column_name="ds", is_dttm=1, type="TIMESTAMP")],
        metrics=[SqlMetric(metric_name="cnt", expression="COUNT(*)")],
        database=Database(database_name="my_database", sqlalchemy_uri="sqlite://"),
    )
    session.add(sqla_table)
    session.flush()

    # Exactly one non-physical column is expected; it carries the metric SQL.
    metric_column = session.query(Column).filter_by(is_physical=False).one()
    assert metric_column.expression == "COUNT(*)"

    # Redefine the metric and flush.
    sqla_table.metrics[0].expression = "MAX(ds)"
    session.flush()

    assert metric_column.expression == "MAX(ds)"
 def setUp(self):
     """Open a connection to the SQLite test database, build the test
     datasource and write ``self.df`` into its table."""
     super(SqlaConnectorTestCase, self).setUp()
     database = Database(
         database_name='test_database',
         sqlalchemy_uri='sqlite:////tmp/test.db',
     )
     self.connection = database.get_sqla_engine().connect()
     self.datasource = SqlaTable(
         table_name='test_datasource',
         database=database,
         columns=self.columns,
         metrics=self.metrics,
     )

     # Replace any existing table; 'received' is stored with the Date type.
     write_options = {
         'if_exists': 'replace',
         'index': False,
         'dtype': {'received': Date},
     }
     with database.get_sqla_engine().begin() as txn:
         self.df.to_sql(self.datasource.table_name, txn, **write_options)
示例#27
0
def test_update_physical_sqlatable_schema(
    mocker: MockFixture, app_context: None, session: Session
) -> None:
    """
    Changing a ``SqlaTable`` schema is expected to point the ``Dataset`` at a
    table with the new schema and a new id.
    """
    # Make the security manager and the datasets DAO use the test session.
    mocker.patch(
        "superset.security.SupersetSecurityManager.get_session", return_value=session
    )
    mocker.patch("superset.datasets.dao.db.session", session)

    from superset.columns.models import Column
    from superset.connectors.sqla.models import SqlaTable, TableColumn
    from superset.datasets.models import Dataset
    from superset.models.core import Database
    from superset.tables.models import Table

    # Create the model tables on the engine the session is bound to.
    Dataset.metadata.create_all(session.get_bind())  # pylint: disable=no-member

    sqla_table = SqlaTable(
        table_name="old_dataset",
        schema="old_schema",
        columns=[TableColumn(column_name="ds", is_dttm=1, type="TIMESTAMP")],
        metrics=[],
        database=Database(database_name="my_database", sqlalchemy_uri="sqlite://"),
    )
    session.add(sqla_table)
    session.flush()

    dataset_before = session.query(Dataset).one()
    table_before = dataset_before.tables[0]
    assert table_before.schema == "old_schema"
    assert table_before.id == 1

    sqla_table.schema = "new_schema"
    session.flush()

    dataset_after = session.query(Dataset).one()
    table_after = dataset_after.tables[0]
    assert table_after.schema == "new_schema"
    assert table_after.id == 2
def test_query_no_access(mocker: MockFixture, app_context: AppContext) -> None:
    """Check that ``check_datasource_access`` raises for a query datasource.

    The admin, owner and access lookups are all patched to return ``False``,
    and the call is expected to raise ``SupersetSecurityException``.
    """
    from superset.connectors.sqla.models import SqlaTable
    from superset.explore.utils import check_datasource_access
    from superset.models.core import Database
    from superset.models.sql_lab import Query

    # The patches are installed before entering ``raises`` so that only the
    # call under test can satisfy the expected-exception check.
    mocker.patch(
        query_find_by_id,
        return_value=Query(database=Database(), sql="select * from foo"),
    )
    mocker.patch(query_datasources_by_name, return_value=[SqlaTable()])
    mocker.patch(is_user_admin, return_value=False)
    mocker.patch(is_owner, return_value=False)
    mocker.patch(can_access, return_value=False)

    with raises(SupersetSecurityException):
        check_datasource_access(
            datasource_id=1,
            datasource_type=DatasourceType.QUERY,
        )
示例#29
0
 def insert_database(
     self,
     database_name: str,
     sqlalchemy_uri: str,
     extra: str = "",
     encrypted_extra: str = "",
     server_cert: str = "",
     expose_in_sqllab: bool = False,
 ) -> Database:
     """Create a ``Database`` from the given fields, commit it through
     ``db.session`` and return it."""
     fields = {
         "database_name": database_name,
         "sqlalchemy_uri": sqlalchemy_uri,
         "extra": extra,
         "encrypted_extra": encrypted_extra,
         "server_cert": server_cert,
         "expose_in_sqllab": expose_in_sqllab,
     }
     new_database = Database(**fields)

     db.session.add(new_database)
     db.session.commit()
     return new_database
示例#30
0
    def setUpClass(cls):
        """Create the test role and user, then commit two tables that share one
        test database."""
        role = security_manager.role_model(name=TEST_ROLE)

        # Give the role the 'can_add' permission on 'SliceModelView'.
        add_slice_perm = security_manager.find_permission_view_menu(
            'can_add', 'SliceModelView')
        role.permissions.append(add_slice_perm)

        appbuilder.sm.add_user(
            TEST_USER, 'datasource', 'user', '*****@*****.**', role, 'general')

        database = Database(database_name=TEST_DB)
        tables = [
            SqlaTable(table_name=table_name, database=database)
            for table_name in ('table_for_test_role', 'table_not_for_test_role')
        ]

        db.session.add_all(tables)
        db.session.commit()