def test_grains_dict(self):
    uri = 'mysql://root@localhost'
    database = Database(sqlalchemy_uri=uri)
    d = database.grains_dict()
    self.assertEquals(d.get('day').function, 'DATE({col})')
    self.assertEquals(d.get('P1D').function, 'DATE({col})')
    self.assertEquals(d.get('Time Column').function, '{col}')
def test_database_schema_hive(self):
    sqlalchemy_uri = 'hive://[email protected]:10000/default?auth=NOSASL'
    model = Database(sqlalchemy_uri=sqlalchemy_uri)

    db = make_url(model.get_sqla_engine().url).database
    self.assertEquals('default', db)

    db = make_url(model.get_sqla_engine(schema='core_db').url).database
    self.assertEquals('core_db', db)
def test_database_schema_mysql(self):
    sqlalchemy_uri = 'mysql://root@localhost/superset'
    model = Database(sqlalchemy_uri=sqlalchemy_uri)

    db = make_url(model.get_sqla_engine().url).database
    self.assertEquals('superset', db)

    db = make_url(model.get_sqla_engine(schema='staging').url).database
    self.assertEquals('staging', db)
def test_database_schema_postgres(self):
    sqlalchemy_uri = 'postgresql+psycopg2://postgres.airbnb.io:5439/prod'
    model = Database(sqlalchemy_uri=sqlalchemy_uri)

    db = make_url(model.get_sqla_engine().url).database
    self.assertEquals('prod', db)

    db = make_url(model.get_sqla_engine(schema='foo').url).database
    self.assertEquals('prod', db)
def test_postgres_mixedcase_col_time_grain(self):
    uri = 'postgresql+psycopg2://uid:pwd@localhost:5432/superset'
    database = Database(sqlalchemy_uri=uri)
    pdf, time_grain = '', 'P1D'
    expression, column_name = '', 'MixedCaseCol'
    grain = database.grains_dict().get(time_grain)
    col = database.db_engine_spec.get_timestamp_column(expression, column_name)
    grain_expr = database.db_engine_spec.get_time_expr(col, pdf, time_grain, grain)
    grain_expr_expected = grain.function.replace('{col}', f'"{column_name}"')
    self.assertEqual(grain_expr, grain_expr_expected)
def test_database_impersonate_user(self):
    uri = 'mysql://root@localhost'
    example_user = '******'
    model = Database(sqlalchemy_uri=uri)

    model.impersonate_user = True
    user_name = make_url(model.get_sqla_engine(user_name=example_user).url).username
    self.assertEquals(example_user, user_name)

    model.impersonate_user = False
    user_name = make_url(model.get_sqla_engine(user_name=example_user).url).username
    self.assertNotEquals(example_user, user_name)
def create_table_for_dashboard(
    df: DataFrame,
    table_name: str,
    database: Database,
    dtype: Dict[str, Any],
    table_description: str = "",
    fetch_values_predicate: Optional[str] = None,
) -> SqlaTable:
    df.to_sql(
        table_name,
        database.get_sqla_engine(),
        if_exists="replace",
        chunksize=500,
        dtype=dtype,
        index=False,
        method="multi",
    )

    table_source = ConnectorRegistry.sources["table"]
    table = (
        db.session.query(table_source).filter_by(table_name=table_name).one_or_none()
    )
    if not table:
        table = table_source(table_name=table_name)
    if fetch_values_predicate:
        table.fetch_values_predicate = fetch_values_predicate
    table.database = database
    table.description = table_description
    db.session.merge(table)
    db.session.commit()

    return table
def test_labels_expected_on_mutated_query(self):
    query_obj = {
        "granularity": None,
        "from_dttm": None,
        "to_dttm": None,
        "groupby": ["user"],
        "metrics": [
            {
                "expressionType": "SIMPLE",
                "column": {"column_name": "user"},
                "aggregate": "COUNT_DISTINCT",
                "label": "COUNT_DISTINCT(user)",
            }
        ],
        "is_timeseries": False,
        "filter": [],
        "extras": {},
    }
    database = Database(database_name="testdb", sqlalchemy_uri="sqlite://")
    table = SqlaTable(table_name="bq_table", database=database)
    db.session.add(database)
    db.session.add(table)
    db.session.commit()

    sqlaq = table.get_sqla_query(**query_obj)
    assert sqlaq.labels_expected == ["user", "COUNT_DISTINCT(user)"]
    sql = table.database.compile_sqla_query(sqlaq.sqla_query)
    assert "COUNT_DISTINCT_user__00db1" in sql

    db.session.delete(table)
    db.session.delete(database)
    db.session.commit()
def test_set_perm_database(self):
    session = db.session
    database = Database(database_name="tmp_database", sqlalchemy_uri="sqlite://test")
    session.add(database)

    stored_db = (
        session.query(Database).filter_by(database_name="tmp_database").one()
    )
    self.assertEqual(stored_db.perm, f"[tmp_database].(id:{stored_db.id})")
    self.assertIsNotNone(
        security_manager.find_permission_view_menu("database_access", stored_db.perm)
    )

    stored_db.database_name = "tmp_database2"
    session.commit()
    stored_db = (
        session.query(Database).filter_by(database_name="tmp_database2").one()
    )
    self.assertEqual(stored_db.perm, f"[tmp_database2].(id:{stored_db.id})")
    self.assertIsNotNone(
        security_manager.find_permission_view_menu("database_access", stored_db.perm)
    )

    session.delete(stored_db)
    session.commit()
def test_table_model(session: Session) -> None:
    """
    Test basic attributes of a ``Table``.
    """
    from superset.columns.models import Column
    from superset.models.core import Database
    from superset.tables.models import Table

    engine = session.get_bind()
    Table.metadata.create_all(engine)  # pylint: disable=no-member

    table = Table(
        name="my_table",
        schema="my_schema",
        catalog="my_catalog",
        database=Database(database_name="my_database", sqlalchemy_uri="test://"),
        columns=[
            Column(
                name="ds",
                type="TIMESTAMP",
                expression="ds",
            )
        ],
    )
    session.add(table)
    session.flush()

    assert table.id == 1
    assert table.uuid is not None
    assert table.database_id == 1
    assert table.catalog == "my_catalog"
    assert table.schema == "my_schema"
    assert table.name == "my_table"
    assert [column.name for column in table.columns] == ["ds"]
def get_indexes_metadata(
    database: Database, table_name: str, schema_name: Optional[str]
) -> List[Dict[str, Any]]:
    indexes = database.get_indexes(table_name, schema_name)
    for idx in indexes:
        idx["type"] = "index"
    return indexes
def test_cascade_delete_table(app_context: None, session: Session) -> None:
    """
    Test that deleting ``Table`` also deletes its columns.
    """
    from superset.columns.models import Column
    from superset.models.core import Database
    from superset.tables.models import Table

    engine = session.get_bind()
    Table.metadata.create_all(engine)  # pylint: disable=no-member

    table = Table(
        name="my_table",
        schema="my_schema",
        catalog="my_catalog",
        database=Database(database_name="my_database", sqlalchemy_uri="sqlite://"),
        columns=[
            Column(name="longitude", expression="longitude"),
            Column(name="latitude", expression="latitude"),
        ],
    )
    session.add(table)
    session.flush()

    columns = session.query(Column).all()
    assert len(columns) == 2

    session.delete(table)
    session.flush()

    # test that columns were deleted
    columns = session.query(Column).all()
    assert len(columns) == 0
def load_data(tbl_name: str, database: Database, sample: bool = False) -> None:
    pdf = pd.read_json(get_example_data("birth_names.json.gz"))
    # TODO(bkyryliuk): move load examples data into the pytest fixture
    if database.backend == "presto":
        pdf.ds = pd.to_datetime(pdf.ds, unit="ms")
        pdf.ds = pdf.ds.dt.strftime("%Y-%m-%d %H:%M:%S")
    else:
        pdf.ds = pd.to_datetime(pdf.ds, unit="ms")
    pdf = pdf.head(100) if sample else pdf

    pdf.to_sql(
        tbl_name,
        database.get_sqla_engine(),
        if_exists="replace",
        chunksize=500,
        dtype={
            # TODO(bkyryliuk): use TIMESTAMP type for presto
            "ds": DateTime if database.backend != "presto" else String(255),
            "gender": String(16),
            "state": String(10),
            "name": String(255),
        },
        method="multi",
        index=False,
    )
    print("Done loading table!")
    print("-" * 80)
def test_delete_sqlatable(app_context: None, session: Session) -> None:
    """
    Test that deleting a ``SqlaTable`` also deletes the corresponding ``Dataset``.
    """
    from superset.columns.models import Column
    from superset.connectors.sqla.models import SqlaTable, TableColumn
    from superset.datasets.models import Dataset
    from superset.models.core import Database
    from superset.tables.models import Table

    engine = session.get_bind()
    Dataset.metadata.create_all(engine)  # pylint: disable=no-member

    columns = [
        TableColumn(column_name="ds", is_dttm=1, type="TIMESTAMP"),
    ]
    sqla_table = SqlaTable(
        table_name="old_dataset",
        columns=columns,
        metrics=[],
        database=Database(database_name="my_database", sqlalchemy_uri="sqlite://"),
    )
    session.add(sqla_table)
    session.flush()

    datasets = session.query(Dataset).all()
    assert len(datasets) == 1

    session.delete(sqla_table)
    session.flush()

    # test that dataset was also deleted
    datasets = session.query(Dataset).all()
    assert len(datasets) == 0
def schema_allows_csv_upload(database: Database, schema: Optional[str]) -> bool:
    if not database.allow_csv_upload:
        return False
    schemas = database.get_schema_access_for_csv_upload()
    if schemas:
        return schema in schemas
    return security_manager.can_access_database(database)
def at_least_one_schema_is_allowed(database: Database) -> bool:
    """
    If the user has access to the database or all datasources:
        1. if schemas_allowed_for_csv_upload is empty
            a) if database does not support schema
                user is able to upload csv without specifying schema name
            b) if database supports schema
                user is able to upload csv to any schema
        2. if schemas_allowed_for_csv_upload is not empty
            a) if database does not support schema
                this situation is impossible and upload will fail
            b) if database supports schema
                user is able to upload to schema in schemas_allowed_for_csv_upload
    elif the user does not have access to the database or all datasources:
        1. if schemas_allowed_for_csv_upload is empty
            a) if database does not support schema
                user is unable to upload csv
            b) if database supports schema
                user is unable to upload csv
        2. if schemas_allowed_for_csv_upload is not empty
            a) if database does not support schema
                this situation is impossible and user is unable to upload csv
            b) if database supports schema
                user is able to upload to schema in schemas_allowed_for_csv_upload
    """
    if security_manager.can_access_database(database):
        return True
    schemas = database.get_schema_access_for_csv_upload()
    if schemas and security_manager.get_schemas_accessible_by_user(
        database, schemas, False
    ):
        return True
    return False
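# The decision table in the docstring above boils down to: database-level access
# always wins; otherwise the user needs at least one accessible schema from the
# configured allow-list. A minimal standalone sketch of that rule, assuming plain
# boolean/set inputs in place of the security_manager and Database calls
# (hypothetical helper, not part of Superset):
def _schema_upload_allowed_sketch(
    has_database_access: bool,
    schemas_allowed: set,
    schemas_accessible: set,
) -> bool:
    # Access to the whole database (or all datasources) short-circuits the check.
    if has_database_access:
        return True
    # Otherwise at least one allowed schema must also be accessible to the user.
    return bool(schemas_allowed & schemas_accessible)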
def load_data(
    data_uri: str, dataset: SqlaTable, example_database: Database, session: Session
) -> None:
    data = request.urlopen(data_uri)  # pylint: disable=consider-using-with
    if data_uri.endswith(".gz"):
        data = gzip.open(data)
    df = pd.read_csv(data, encoding="utf-8")
    dtype = get_dtype(df, dataset)

    # convert temporal columns
    for column_name, sqla_type in dtype.items():
        if isinstance(sqla_type, (Date, DateTime)):
            df[column_name] = pd.to_datetime(df[column_name])

    # reuse session when loading data if possible, to make import atomic
    if example_database.sqlalchemy_uri == current_app.config.get(
        "SQLALCHEMY_DATABASE_URI"
    ) or not current_app.config.get("SQLALCHEMY_EXAMPLES_URI"):
        logger.info("Loading data inside the import transaction")
        connection = session.connection()
    else:
        logger.warning("Loading data outside the import transaction")
        connection = example_database.get_sqla_engine()

    df.to_sql(
        dataset.table_name,
        con=connection,
        schema=dataset.schema,
        if_exists="replace",
        chunksize=CHUNKSIZE,
        dtype=dtype,
        index=False,
        method="multi",
    )
def test_query_dao_save_metadata(app_context: None, session: Session) -> None:
    from superset.models.core import Database
    from superset.models.sql_lab import Query

    engine = session.get_bind()
    Query.metadata.create_all(engine)  # pylint: disable=no-member

    db = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
    query_obj = Query(
        client_id="foo",
        database=db,
        tab_name="test_tab",
        sql_editor_id="test_editor_id",
        sql="select * from bar",
        select_sql="select * from bar",
        executed_sql="select * from bar",
        limit=100,
        select_as_cta=False,
        rows=100,
        error_message="none",
        results_key="abc",
    )
    session.add(db)
    session.add(query_obj)

    from superset.queries.dao import QueryDAO

    query = session.query(Query).one()
    QueryDAO.save_metadata(query=query, payload={"columns": []})
    assert query.extra.get("columns", None) == []
def validate(
    cls, sql: str, schema: Optional[str], database: Database
) -> List[SQLValidationAnnotation]:
    """
    Presto supports query-validation queries by running them with a
    prepended explain.

    For example, "SELECT 1 FROM default.mytable" becomes
    "EXPLAIN (TYPE VALIDATE) SELECT 1 FROM default.mytable".
    """
    user_name = g.user.username if g.user and hasattr(g.user, "username") else None
    parsed_query = ParsedQuery(sql)
    statements = parsed_query.get_statements()

    logger.info("Validating %i statement(s)", len(statements))
    engine = database.get_sqla_engine(
        schema=schema,
        nullpool=True,
        user_name=user_name,
        source=QuerySource.SQL_LAB,
    )
    # Sharing a single connection and cursor across the
    # execution of all statements (if many)
    annotations: List[SQLValidationAnnotation] = []
    with closing(engine.raw_connection()) as conn:
        cursor = conn.cursor()
        for statement in parsed_query.get_statements():
            annotation = cls.validate_statement(statement, database, cursor, user_name)
            if annotation:
                annotations.append(annotation)
    logger.debug("Validation found %i error(s)", len(annotations))

    return annotations
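# For reference, the validation trick described in the docstring above is plain
# string prepending: each statement gets wrapped as "EXPLAIN (TYPE VALIDATE) <sql>"
# before being sent to Presto. A minimal sketch of that transformation, assuming
# bare SQL strings (hypothetical helper, not Superset's validate_statement):
def _to_presto_validation_query(statement: str) -> str:
    # Strip a trailing semicolon so the EXPLAIN wraps a single bare statement.
    sql = statement.strip().rstrip(";")
    return f"EXPLAIN (TYPE VALIDATE) {sql}"


# Example: _to_presto_validation_query("SELECT 1 FROM default.mytable;")
# returns "EXPLAIN (TYPE VALIDATE) SELECT 1 FROM default.mytable".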
def test_table(session: Session) -> "SqlaTable":
    """
    Fixture that generates an in-memory table.
    """
    from superset.connectors.sqla.models import SqlaTable, TableColumn
    from superset.models.core import Database

    engine = session.get_bind()
    SqlaTable.metadata.create_all(engine)  # pylint: disable=no-member

    columns = [
        TableColumn(column_name="ds", is_dttm=1, type="TIMESTAMP"),
        TableColumn(column_name="event_time", is_dttm=1, type="TIMESTAMP"),
        TableColumn(column_name="id", type="INTEGER"),
        TableColumn(column_name="dttm", type="INTEGER"),
        TableColumn(column_name="duration_ms", type="INTEGER"),
    ]
    return SqlaTable(
        table_name="test_table",
        columns=columns,
        metrics=[],
        main_dttm_col=None,
        database=Database(database_name="my_database", sqlalchemy_uri="sqlite://"),
    )
def test_quote_expressions(app_context: None, session: Session) -> None:
    """
    Test that expressions are quoted appropriately in columns and datasets.
    """
    from superset.connectors.sqla.models import SqlaTable, TableColumn
    from superset.datasets.models import Dataset
    from superset.models.core import Database

    engine = session.get_bind()
    Dataset.metadata.create_all(engine)  # pylint: disable=no-member

    columns = [
        TableColumn(column_name="has space", type="INTEGER"),
        TableColumn(column_name="no_need", type="INTEGER"),
    ]
    sqla_table = SqlaTable(
        table_name="old dataset",
        columns=columns,
        metrics=[],
        database=Database(database_name="my_database", sqlalchemy_uri="sqlite://"),
    )
    session.add(sqla_table)
    session.flush()

    dataset = session.query(Dataset).one()
    assert dataset.expression == '"old dataset"'
    assert dataset.columns[0].expression == '"has space"'
    assert dataset.columns[1].expression == "no_need"
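# The quoting behaviour asserted above mirrors standard SQL identifier quoting:
# names that are not plain legal identifiers (here, one containing a space) get
# wrapped in double quotes, while ordinary lowercase names pass through. A small
# illustrative sketch using SQLAlchemy's own identifier preparer with a generic
# dialect (not Superset's implementation, just the underlying rule):
from sqlalchemy.engine import default
from sqlalchemy.sql.compiler import IdentifierPreparer

preparer = IdentifierPreparer(default.DefaultDialect())
assert preparer.quote("has space") == '"has space"'
assert preparer.quote("no_need") == "no_need"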
def test_import_dataset_managed_externally(session: Session) -> None:
    """
    Test importing a dataset that is managed externally.
    """
    from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn
    from superset.datasets.commands.importers.v1.utils import import_dataset
    from superset.datasets.schemas import ImportV1DatasetSchema
    from superset.models.core import Database
    from tests.integration_tests.fixtures.importexport import dataset_config

    engine = session.get_bind()
    SqlaTable.metadata.create_all(engine)  # pylint: disable=no-member

    database = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
    session.add(database)
    session.flush()

    dataset_uuid = uuid.uuid4()
    config = copy.deepcopy(dataset_config)
    config["is_managed_externally"] = True
    config["external_url"] = "https://example.org/my_table"
    config["database_id"] = database.id

    sqla_table = import_dataset(session, config)
    assert sqla_table.is_managed_externally is True
    assert sqla_table.external_url == "https://example.org/my_table"
def test_get_metrics(mocker: MockFixture) -> None:
    """
    Tests for ``get_metrics``.
    """
    from superset.db_engine_specs.base import MetricType
    from superset.db_engine_specs.sqlite import SqliteEngineSpec
    from superset.models.core import Database

    database = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
    assert database.get_metrics("table") == [
        {
            "expression": "COUNT(*)",
            "metric_name": "count",
            "metric_type": "count",
            "verbose_name": "COUNT(*)",
        }
    ]

    class CustomSqliteEngineSpec(SqliteEngineSpec):
        @classmethod
        def get_metrics(
            cls,
            database: Database,
            inspector: Inspector,
            table_name: str,
            schema: Optional[str],
        ) -> List[MetricType]:
            return [
                {
                    "expression": "COUNT(DISTINCT user_id)",
                    "metric_name": "count_distinct_user_id",
                    "metric_type": "count_distinct",
                    "verbose_name": "COUNT(DISTINCT user_id)",
                },
            ]

    database.get_db_engine_spec_for_backend = mocker.MagicMock(  # type: ignore
        return_value=CustomSqliteEngineSpec
    )
    assert database.get_metrics("table") == [
        {
            "expression": "COUNT(DISTINCT user_id)",
            "metric_name": "count_distinct_user_id",
            "metric_type": "count_distinct",
            "verbose_name": "COUNT(DISTINCT user_id)",
        },
    ]
def get_foreign_keys_metadata(
    database: Database, table_name: str, schema_name: Optional[str]
) -> List[Dict[str, Any]]:
    foreign_keys = database.get_foreign_keys(table_name, schema_name)
    for fk in foreign_keys:
        fk["column_names"] = fk.pop("constrained_columns")
        fk["type"] = "fk"
    return foreign_keys
def setUp(self):
    super(SqlaConnectorTestCase, self).setUp()
    sqlalchemy_uri = 'sqlite:////tmp/test.db'
    database = Database(
        database_name='test_database',
        sqlalchemy_uri=sqlalchemy_uri)
    self.connection = database.get_sqla_engine().connect()
    self.datasource = SqlaTable(table_name='test_datasource',
                                database=database,
                                columns=self.columns,
                                metrics=self.metrics)
    with database.get_sqla_engine().begin() as connection:
        self.df.to_sql(self.datasource.table_name,
                       connection,
                       if_exists='replace',
                       index=False,
                       dtype={'received': Date})
def import_from_dict(session, data, sync=[]):
    """Imports databases and druid clusters from dictionary"""
    if isinstance(data, dict):
        logging.info('Importing %d %s',
                     len(data.get(DATABASES_KEY, [])),
                     DATABASES_KEY)
        for database in data.get(DATABASES_KEY, []):
            Database.import_from_dict(session, database, sync=sync)

        logging.info('Importing %d %s',
                     len(data.get(DRUID_CLUSTERS_KEY, [])),
                     DRUID_CLUSTERS_KEY)
        for datasource in data.get(DRUID_CLUSTERS_KEY, []):
            DruidCluster.import_from_dict(session, datasource, sync=sync)
        session.commit()
    else:
        logging.info('Supplied object is not a dictionary.')
def select_star(
    self, database: Database, table_name: str, schema_name: Optional[str] = None
) -> FlaskResponse:
    """Table schema info
    ---
    get:
      description: Get database select star for table
      parameters:
      - in: path
        schema:
          type: integer
        name: pk
        description: The database id
      - in: path
        schema:
          type: string
        name: table_name
        description: Table name
      - in: path
        schema:
          type: string
        name: schema_name
        description: Table schema
      responses:
        200:
          description: select star for table
          content:
            text/plain:
              schema:
                type: object
                properties:
                  result:
                    type: string
                    description: SQL select star
        400:
          $ref: '#/components/responses/400'
        401:
          $ref: '#/components/responses/401'
        404:
          $ref: '#/components/responses/404'
        422:
          $ref: '#/components/responses/422'
        500:
          $ref: '#/components/responses/500'
    """
    self.incr_stats("init", self.select_star.__name__)
    try:
        result = database.select_star(
            table_name, schema_name, latest_partition=True, show_cols=True
        )
    except NoSuchTableError:
        self.incr_stats("error", self.select_star.__name__)
        return self.response(404, message="Table not found on the database")
    self.incr_stats("success", self.select_star.__name__)
    return self.response(200, result=result)
def test_apply_limit(self):
    def compile_sqla_query(qry: Select, schema: Optional[str] = None) -> str:
        return str(
            qry.compile(
                dialect=mssql.dialect(), compile_kwargs={"literal_binds": True}
            )
        )

    database = Database(
        database_name="mssql_test",
        sqlalchemy_uri="mssql+pymssql://sa:Password_123@localhost:1433/msdb",
    )
    db.session.add(database)
    db.session.commit()

    with mock.patch.object(database, "compile_sqla_query", new=compile_sqla_query):
        test_sql = "SELECT COUNT(*) FROM FOO_TABLE"
        limited_sql = MssqlEngineSpec.apply_limit_to_sql(test_sql, 1000, database)
        expected_sql = (
            "SELECT TOP 1000 * \n"
            "FROM (SELECT COUNT(*) AS COUNT_1 FROM FOO_TABLE) AS inner_qry"
        )
        self.assertEqual(expected_sql, limited_sql)

        test_sql = "SELECT COUNT(*), SUM(id) FROM FOO_TABLE"
        limited_sql = MssqlEngineSpec.apply_limit_to_sql(test_sql, 1000, database)
        expected_sql = (
            "SELECT TOP 1000 * \n"
            "FROM (SELECT COUNT(*) AS COUNT_1, SUM(id) AS SUM_2 FROM FOO_TABLE) "
            "AS inner_qry"
        )
        self.assertEqual(expected_sql, limited_sql)

        test_sql = "SELECT COUNT(*), FOO_COL1 FROM FOO_TABLE GROUP BY FOO_COL1"
        limited_sql = MssqlEngineSpec.apply_limit_to_sql(test_sql, 1000, database)
        expected_sql = (
            "SELECT TOP 1000 * \n"
            "FROM (SELECT COUNT(*) AS COUNT_1, "
            "FOO_COL1 FROM FOO_TABLE GROUP BY FOO_COL1)"
            " AS inner_qry"
        )
        self.assertEqual(expected_sql, limited_sql)

        test_sql = "SELECT COUNT(*), COUNT(*) FROM FOO_TABLE"
        limited_sql = MssqlEngineSpec.apply_limit_to_sql(test_sql, 1000, database)
        expected_sql = (
            "SELECT TOP 1000 * \n"
            "FROM (SELECT COUNT(*) AS COUNT_1, COUNT(*) AS COUNT_2 FROM FOO_TABLE)"
            " AS inner_qry"
        )
        self.assertEqual(expected_sql, limited_sql)

    db.session.delete(database)
    db.session.commit()
def test_update_sqlatable(
    mocker: MockFixture, app_context: None, session: Session
) -> None:
    """
    Test that updating a ``SqlaTable`` also updates the corresponding ``Dataset``.
    """
    # patch session
    mocker.patch(
        "superset.security.SupersetSecurityManager.get_session", return_value=session
    )

    from superset.columns.models import Column
    from superset.connectors.sqla.models import SqlaTable, TableColumn
    from superset.datasets.models import Dataset
    from superset.models.core import Database
    from superset.tables.models import Table

    engine = session.get_bind()
    Dataset.metadata.create_all(engine)  # pylint: disable=no-member

    columns = [
        TableColumn(column_name="ds", is_dttm=1, type="TIMESTAMP"),
    ]
    sqla_table = SqlaTable(
        table_name="old_dataset",
        columns=columns,
        metrics=[],
        database=Database(database_name="my_database", sqlalchemy_uri="sqlite://"),
    )
    session.add(sqla_table)
    session.flush()

    dataset = session.query(Dataset).one()
    assert len(dataset.columns) == 1

    # add a column to the original ``SqlaTable`` instance
    sqla_table.columns.append(TableColumn(column_name="user_id", type="INTEGER"))
    session.flush()

    # check that the column was added to the dataset
    dataset = session.query(Dataset).one()
    assert len(dataset.columns) == 2

    # delete the column in the original instance
    sqla_table.columns = sqla_table.columns[1:]
    session.flush()

    # check that the column was also removed from the dataset
    dataset = session.query(Dataset).one()
    assert len(dataset.columns) == 1

    # modify the attribute in a column
    sqla_table.columns[0].is_dttm = True
    session.flush()

    # check that the dataset column was modified
    dataset = session.query(Dataset).one()
    assert dataset.columns[0].is_temporal is True
def create_test_table_context(database: Database):
    database.get_sqla_engine().execute(
        "CREATE TABLE test_table AS SELECT 1 as first, 2 as second"
    )
    database.get_sqla_engine().execute(
        "INSERT INTO test_table (first, second) VALUES (1, 2)"
    )
    database.get_sqla_engine().execute(
        "INSERT INTO test_table (first, second) VALUES (3, 4)"
    )

    yield db.session

    database.get_sqla_engine().execute("DROP TABLE test_table")
def test_alter_new_orm_column(self):
    """
    DB Eng Specs (crate): Test alter orm column
    """
    database = Database(database_name="crate", sqlalchemy_uri="crate://db")
    tbl = SqlaTable(table_name="druid_tbl", database=database)
    col = TableColumn(column_name="ts", type="TIMESTAMP", table=tbl)
    CrateEngineSpec.alter_new_orm_column(col)
    assert col.python_date_format == "epoch_ms"
def enable_csv_upload(self, database: models.Database) -> None:
    """Enables csv upload in the given database."""
    database.allow_csv_upload = True
    db.session.commit()

    add_datasource_page = self.get_resp("/databaseview/list/")
    self.assertIn("Upload a CSV", add_datasource_page)

    form_get = self.get_resp("/csvtodatabaseview/form")
    self.assertIn("CSV to Database configuration", form_get)
def build_db_for_connection_test(
    server_cert: str, extra: str, impersonate_user: bool, encrypted_extra: str
) -> Database:
    return Database(
        server_cert=server_cert,
        extra=extra,
        impersonate_user=impersonate_user,
        encrypted_extra=encrypted_extra,
    )
def test_database_connection_test_mutator(self):
    database = Database(sqlalchemy_uri="snowflake://abc")
    SnowflakeEngineSpec.mutate_db_for_connection_test(database)
    engine_params = json.loads(database.extra or "{}")

    self.assertDictEqual(
        {"engine_params": {"connect_args": {"validate_default_parameters": True}}},
        engine_params,
    )
def import_from_dict(data: Dict[str, Any], sync: Optional[List[str]] = None) -> None:
    """Imports databases and druid clusters from dictionary"""
    if not sync:
        sync = []
    if isinstance(data, dict):
        logger.info(
            "Importing %d %s", len(data.get(DATABASES_KEY, [])), DATABASES_KEY
        )
        for database in data.get(DATABASES_KEY, []):
            Database.import_from_dict(database, sync=sync)

        logger.info(
            "Importing %d %s", len(data.get(DRUID_CLUSTERS_KEY, [])), DRUID_CLUSTERS_KEY
        )
        for datasource in data.get(DRUID_CLUSTERS_KEY, []):
            DruidCluster.import_from_dict(datasource, sync=sync)
        db.session.commit()
    else:
        logger.info("Supplied object is not a dictionary.")
def test_dataset_model(app_context: None, session: Session) -> None:
    """
    Test basic attributes of a ``Dataset``.
    """
    from superset.columns.models import Column
    from superset.datasets.models import Dataset
    from superset.models.core import Database
    from superset.tables.models import Table

    engine = session.get_bind()
    Dataset.metadata.create_all(engine)  # pylint: disable=no-member

    table = Table(
        name="my_table",
        schema="my_schema",
        catalog="my_catalog",
        database=Database(database_name="my_database", sqlalchemy_uri="sqlite://"),
        columns=[
            Column(name="longitude", expression="longitude"),
            Column(name="latitude", expression="latitude"),
        ],
    )
    session.add(table)
    session.flush()

    dataset = Dataset(
        database=table.database,
        name="positions",
        expression="""
SELECT array_agg(array[longitude,latitude]) AS position
FROM my_catalog.my_schema.my_table
""",
        tables=[table],
        columns=[
            Column(
                name="position",
                expression="array_agg(array[longitude,latitude])",
            ),
        ],
    )
    session.add(dataset)
    session.flush()

    assert dataset.id == 1
    assert dataset.uuid is not None
    assert dataset.name == "positions"
    assert (
        dataset.expression
        == """
SELECT array_agg(array[longitude,latitude]) AS position
FROM my_catalog.my_schema.my_table
"""
    )
    assert [table.name for table in dataset.tables] == ["my_table"]
    assert [column.name for column in dataset.columns] == ["position"]
def apply_limit_if_exists(
    database: Database, increased_limit: Optional[int], query: Query, sql: str
) -> str:
    if query.limit and increased_limit:
        # We are fetching one more than the requested limit in order
        # to test whether there are more rows than the limit.
        # Later, the extra row will be dropped before sending
        # the results back to the user.
        sql = database.apply_limit_to_sql(sql, increased_limit, force=True)
    return sql
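# The "fetch one extra row" comment above is what lets a caller distinguish a
# result that happens to contain exactly `limit` rows from one that was truncated.
# A minimal sketch of how that extra row would typically be consumed downstream,
# assuming rows were fetched with limit + 1 (hypothetical helper, not Superset code):
from typing import List, Tuple


def _trim_and_flag(rows: List[tuple], limit: int) -> Tuple[List[tuple], bool]:
    # Anything beyond `limit` means the real result set was larger than requested.
    truncated = len(rows) > limit
    return rows[:limit], truncated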
def export_schema_to_dict(back_references):
    """Exports the supported import/export schema to a dictionary"""
    databases = [
        Database.export_schema(recursive=True, include_parent_ref=back_references)
    ]
    clusters = [
        DruidCluster.export_schema(recursive=True, include_parent_ref=back_references)
    ]
    data = dict()
    if databases:
        data[DATABASES_KEY] = databases
    if clusters:
        data[DRUID_CLUSTERS_KEY] = clusters
    return data
def test_database_schema_presto(self):
    sqlalchemy_uri = 'presto://presto.airbnb.io:8080/hive/default'
    model = Database(sqlalchemy_uri=sqlalchemy_uri)

    db = make_url(model.get_sqla_engine().url).database
    self.assertEquals('hive/default', db)

    db = make_url(model.get_sqla_engine(schema='core_db').url).database
    self.assertEquals('hive/core_db', db)

    sqlalchemy_uri = 'presto://presto.airbnb.io:8080/hive'
    model = Database(sqlalchemy_uri=sqlalchemy_uri)

    db = make_url(model.get_sqla_engine().url).database
    self.assertEquals('hive', db)

    db = make_url(model.get_sqla_engine(schema='core_db').url).database
    self.assertEquals('hive/core_db', db)
def test_database_for_various_backend(self):
    sqlalchemy_uri = 'presto://presto.airbnb.io:8080/hive/default'
    model = Database(sqlalchemy_uri=sqlalchemy_uri)
    url = make_url(model.sqlalchemy_uri)
    db = model.get_database_for_various_backend(url, None)
    assert db == 'hive/default'
    db = model.get_database_for_various_backend(url, 'raw_data')
    assert db == 'hive/raw_data'

    sqlalchemy_uri = 'redshift+psycopg2://superset:[email protected]:5439/prod'
    model = Database(sqlalchemy_uri=sqlalchemy_uri)
    url = make_url(model.sqlalchemy_uri)
    db = model.get_database_for_various_backend(url, None)
    assert db == 'prod'
    db = model.get_database_for_various_backend(url, 'test')
    assert db == 'prod'

    sqlalchemy_uri = 'postgresql+psycopg2://superset:[email protected]:5439/prod'
    model = Database(sqlalchemy_uri=sqlalchemy_uri)
    url = make_url(model.sqlalchemy_uri)
    db = model.get_database_for_various_backend(url, None)
    assert db == 'prod'
    db = model.get_database_for_various_backend(url, 'adhoc')
    assert db == 'prod'

    sqlalchemy_uri = 'hive://[email protected]:10000/raw_data'
    model = Database(sqlalchemy_uri=sqlalchemy_uri)
    url = make_url(model.sqlalchemy_uri)
    db = model.get_database_for_various_backend(url, None)
    assert db == 'raw_data'
    db = model.get_database_for_various_backend(url, 'adhoc')
    assert db == 'adhoc'

    sqlalchemy_uri = 'mysql://*****:*****@mysql.airbnb.io/superset'
    model = Database(sqlalchemy_uri=sqlalchemy_uri)
    url = make_url(model.sqlalchemy_uri)
    db = model.get_database_for_various_backend(url, None)
    assert db == 'superset'
    db = model.get_database_for_various_backend(url, 'adhoc')
    assert db == 'adhoc'