def load_energy(
    only_metadata: bool = False, force: bool = False, sample: bool = False
) -> None:
    """Loads an energy related dataset to use with sankey and graphs"""
    tbl_name = "energy_usage"
    database = utils.get_example_database()
    table_exists = database.has_table_by_name(tbl_name)

    if not only_metadata and (not table_exists or force):
        data = get_example_data("energy.json.gz")
        pdf = pd.read_json(data)
        pdf = pdf.head(100) if sample else pdf
        pdf.to_sql(
            tbl_name,
            database.get_sqla_engine(),
            if_exists="replace",
            chunksize=500,
            dtype={"source": String(255), "target": String(255), "value": Float()},
            index=False,
            method="multi",
        )

    print("Creating table [energy_usage] reference")
    tbl = db.session.query(TBL).filter_by(table_name=tbl_name).first()
    if not tbl:
        tbl = TBL(table_name=tbl_name)
    tbl.description = "Energy consumption"
    tbl.database = database

    if not any(col.metric_name == "sum__value" for col in tbl.metrics):
        col = str(column("value").compile(db.engine))
        tbl.metrics.append(
            SqlMetric(metric_name="sum__value", expression=f"SUM({col})")
        )

    db.session.merge(tbl)
    db.session.commit()
    tbl.fetch_metadata()

    slc = Slice(
        slice_name="Energy Sankey",
        viz_type="sankey",
        datasource_type="table",
        datasource_id=tbl.id,
        params=textwrap.dedent(
            """\
        {
            "collapsed_fieldsets": "",
            "groupby": [
                "source",
                "target"
            ],
            "metric": "sum__value",
            "row_limit": "5000",
            "slice_name": "Energy Sankey",
            "viz_type": "sankey"
        }
        """
        ),
    )
    misc_dash_slices.add(slc.slice_name)
    merge_slice(slc)

    slc = Slice(
        slice_name="Energy Force Layout",
        viz_type="graph_chart",
        datasource_type="table",
        datasource_id=tbl.id,
        params=textwrap.dedent(
            """\
        {
            "source": "source",
            "target": "target",
            "edgeLength": 400,
            "repulsion": 1000,
            "layout": "force",
            "metric": "sum__value",
            "row_limit": "5000",
            "slice_name": "Force",
            "viz_type": "graph_chart"
        }
        """
        ),
    )
    misc_dash_slices.add(slc.slice_name)
    merge_slice(slc)

    slc = Slice(
        slice_name="Heatmap",
        viz_type="heatmap",
        datasource_type="table",
        datasource_id=tbl.id,
        params=textwrap.dedent(
            """\
        {
            "all_columns_x": "source",
            "all_columns_y": "target",
            "canvas_image_rendering": "pixelated",
            "collapsed_fieldsets": "",
            "linear_color_scheme": "blue_white_yellow",
            "metric": "sum__value",
            "normalize_across": "heatmap",
            "slice_name": "Heatmap",
            "viz_type": "heatmap",
            "xscale_interval": "1",
            "yscale_interval": "1"
        }
        """
        ),
    )
    misc_dash_slices.add(slc.slice_name)
    merge_slice(slc)
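# A minimal usage sketch for load_energy (assumption: run inside a configured
# Superset app context, e.g. from a CLI hook as in load_examples_run below;
# `sample=True` keeps only the first 100 rows, which speeds up test setups):
#
#     from superset import examples
#     examples.load_energy(only_metadata=False, force=True, sample=True)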
def test_export_database_command(self, mock_g):
    mock_g.user = security_manager.find_user("admin")

    example_db = get_example_database()
    db_uuid = example_db.uuid
    command = ExportDatabasesCommand([example_db.id])
    contents = dict(command.run())

    # TODO: this list shouldn't depend on the order in which unit tests are run
    # or on the backend; for now use a stable subset
    core_files = {
        "metadata.yaml",
        "databases/examples.yaml",
        "datasets/examples/energy_usage.yaml",
        "datasets/examples/wb_health_population.yaml",
        "datasets/examples/birth_names.yaml",
    }
    expected_extra = {
        "engine_params": {},
        "metadata_cache_timeout": {},
        "metadata_params": {},
        "schemas_allowed_for_csv_upload": [],
    }
    if backend() == "presto":
        expected_extra = {
            **expected_extra,
            "engine_params": {"connect_args": {"poll_interval": 0.1}},
        }
    assert core_files.issubset(set(contents.keys()))

    if example_db.backend == "postgresql":
        ds_type = "TIMESTAMP WITHOUT TIME ZONE"
    elif example_db.backend == "hive":
        ds_type = "TIMESTAMP"
    elif example_db.backend == "presto":
        ds_type = "VARCHAR(255)"
    else:
        ds_type = "DATETIME"
    if example_db.backend == "mysql":
        big_int_type = "BIGINT(20)"
    else:
        big_int_type = "BIGINT"

    metadata = yaml.safe_load(contents["databases/examples.yaml"])
    assert metadata == {
        "allow_csv_upload": True,
        "allow_ctas": True,
        "allow_cvas": True,
        "allow_run_async": False,
        "cache_timeout": None,
        "database_name": "examples",
        "expose_in_sqllab": True,
        "extra": expected_extra,
        "sqlalchemy_uri": example_db.sqlalchemy_uri,
        "uuid": str(example_db.uuid),
        "version": "1.0.0",
    }

    metadata = yaml.safe_load(contents["datasets/examples/birth_names.yaml"])
    metadata.pop("uuid")
    metadata["columns"].sort(key=lambda x: x["column_name"])
    expected_metadata = {
        "cache_timeout": None,
        "columns": [
            {
                "column_name": "ds",
                "description": None,
                "expression": None,
                "filterable": True,
                "groupby": True,
                "is_active": True,
                "is_dttm": True,
                "python_date_format": None,
                "type": ds_type,
                "verbose_name": None,
            },
            {
                "column_name": "gender",
                "description": None,
                "expression": None,
                "filterable": True,
                "groupby": True,
                "is_active": True,
                "is_dttm": False,
                "python_date_format": None,
                "type": "STRING" if example_db.backend == "hive" else "VARCHAR(16)",
                "verbose_name": None,
            },
            {
                "column_name": "name",
                "description": None,
                "expression": None,
                "filterable": True,
                "groupby": True,
                "is_active": True,
                "is_dttm": False,
                "python_date_format": None,
                "type": "STRING" if example_db.backend == "hive" else "VARCHAR(255)",
                "verbose_name": None,
            },
            {
                "column_name": "num",
                "description": None,
                "expression": None,
                "filterable": True,
                "groupby": True,
                "is_active": True,
                "is_dttm": False,
                "python_date_format": None,
                "type": big_int_type,
                "verbose_name": None,
            },
            {
                "column_name": "num_california",
                "description": None,
                "expression": "CASE WHEN state = 'CA' THEN num ELSE 0 END",
                "filterable": True,
                "groupby": True,
                "is_active": True,
                "is_dttm": False,
                "python_date_format": None,
                "type": None,
                "verbose_name": None,
            },
            {
                "column_name": "state",
                "description": None,
                "expression": None,
                "filterable": True,
                "groupby": True,
                "is_active": True,
                "is_dttm": False,
                "python_date_format": None,
                "type": "STRING" if example_db.backend == "hive" else "VARCHAR(10)",
                "verbose_name": None,
            },
            {
                "column_name": "num_boys",
                "description": None,
                "expression": None,
                "filterable": True,
                "groupby": True,
                "is_active": True,
                "is_dttm": False,
                "python_date_format": None,
                "type": big_int_type,
                "verbose_name": None,
            },
            {
                "column_name": "num_girls",
                "description": None,
                "expression": None,
                "filterable": True,
                "groupby": True,
                "is_active": True,
                "is_dttm": False,
                "python_date_format": None,
                "type": big_int_type,
                "verbose_name": None,
            },
        ],
        "database_uuid": str(db_uuid),
        "default_endpoint": None,
        "description": "",
        "extra": None,
        "fetch_values_predicate": "123 = 123",
        "filter_select_enabled": True,
        "main_dttm_col": "ds",
        "metrics": [
            {
                "d3format": None,
                "description": None,
                "expression": "COUNT(*)",
                "extra": None,
                "metric_name": "count",
                "metric_type": "count",
                "verbose_name": "COUNT(*)",
                "warning_text": None,
            },
            {
                "d3format": None,
                "description": None,
                "expression": "SUM(num)",
                "extra": None,
                "metric_name": "sum__num",
                "metric_type": None,
                "verbose_name": None,
                "warning_text": None,
            },
        ],
        "offset": 0,
        "params": None,
        "schema": None,
        "sql": None,
        "table_name": "birth_names",
        "template_params": None,
        "version": "1.0.0",
    }
    expected_metadata["columns"].sort(key=lambda x: x["column_name"])
    assert metadata == expected_metadata
def load_country_map_data(only_metadata: bool = False, force: bool = False) -> None:
    """Loads data for the country map visualization (births in France)"""
    tbl_name = "birth_france_by_region"
    database = utils.get_example_database()
    table_exists = database.has_table_by_name(tbl_name)

    if not only_metadata and (not table_exists or force):
        csv_bytes = get_example_data(
            "birth_france_data_for_country_map.csv", is_gzip=False, make_bytes=True
        )
        data = pd.read_csv(csv_bytes, encoding="utf-8")
        data["dttm"] = datetime.datetime.now().date()
        data.to_sql(  # pylint: disable=no-member
            tbl_name,
            database.get_sqla_engine(),
            if_exists="replace",
            chunksize=500,
            dtype={
                "DEPT_ID": String(10),
                "2003": BigInteger,
                "2004": BigInteger,
                "2005": BigInteger,
                "2006": BigInteger,
                "2007": BigInteger,
                "2008": BigInteger,
                "2009": BigInteger,
                "2010": BigInteger,
                "2011": BigInteger,
                "2012": BigInteger,
                "2013": BigInteger,
                "2014": BigInteger,
                "dttm": Date(),
            },
            index=False,
        )
        print("Done loading table!")
        print("-" * 80)

    print("Creating table reference")
    table = get_table_connector_registry()
    obj = db.session.query(table).filter_by(table_name=tbl_name).first()
    if not obj:
        obj = table(table_name=tbl_name)
    obj.main_dttm_col = "dttm"
    obj.database = database
    obj.filter_select_enabled = True

    if not any(col.metric_name == "avg__2004" for col in obj.metrics):
        col = str(column("2004").compile(db.engine))
        obj.metrics.append(
            SqlMetric(metric_name="avg__2004", expression=f"AVG({col})")
        )

    db.session.merge(obj)
    db.session.commit()
    obj.fetch_metadata()
    tbl = obj

    slice_data = {
        "granularity_sqla": "",
        "since": "",
        "until": "",
        "viz_type": "country_map",
        "entity": "DEPT_ID",
        "metric": {
            "expressionType": "SIMPLE",
            "column": {"type": "INT", "column_name": "2004"},
            "aggregate": "AVG",
            "label": "Boys",
            "optionName": "metric_112342",
        },
        "row_limit": 500000,
        "select_country": "france",
    }

    print("Creating a slice")
    slc = Slice(
        slice_name="Birth in France by department in 2016",
        viz_type="country_map",
        datasource_type="table",
        datasource_id=tbl.id,
        params=get_slice_json(slice_data),
    )
    misc_dash_slices.add(slc.slice_name)
    merge_slice(slc)
def test_set_perm_sqla_table(self):
    session = db.session
    table = SqlaTable(
        schema="tmp_schema",
        table_name="tmp_perm_table",
        database=get_example_database(),
    )
    session.add(table)
    session.commit()

    stored_table = (
        session.query(SqlaTable).filter_by(table_name="tmp_perm_table").one()
    )
    self.assertEqual(
        stored_table.perm, f"[examples].[tmp_perm_table](id:{stored_table.id})"
    )
    self.assertIsNotNone(
        security_manager.find_permission_view_menu(
            "datasource_access", stored_table.perm
        )
    )
    self.assertEqual(stored_table.schema_perm, "[examples].[tmp_schema]")
    self.assertIsNotNone(
        security_manager.find_permission_view_menu(
            "schema_access", stored_table.schema_perm
        )
    )

    # table name change
    stored_table.table_name = "tmp_perm_table_v2"
    session.commit()
    stored_table = (
        session.query(SqlaTable).filter_by(table_name="tmp_perm_table_v2").one()
    )
    self.assertEqual(
        stored_table.perm, f"[examples].[tmp_perm_table_v2](id:{stored_table.id})"
    )
    self.assertIsNotNone(
        security_manager.find_permission_view_menu(
            "datasource_access", stored_table.perm
        )
    )
    # no changes in schema
    self.assertEqual(stored_table.schema_perm, "[examples].[tmp_schema]")
    self.assertIsNotNone(
        security_manager.find_permission_view_menu(
            "schema_access", stored_table.schema_perm
        )
    )

    # schema name change
    stored_table.schema = "tmp_schema_v2"
    session.commit()
    stored_table = (
        session.query(SqlaTable).filter_by(table_name="tmp_perm_table_v2").one()
    )
    self.assertEqual(
        stored_table.perm, f"[examples].[tmp_perm_table_v2](id:{stored_table.id})"
    )
    self.assertIsNotNone(
        security_manager.find_permission_view_menu(
            "datasource_access", stored_table.perm
        )
    )
    # schema permission is updated to the new schema name
    self.assertEqual(stored_table.schema_perm, "[examples].[tmp_schema_v2]")
    self.assertIsNotNone(
        security_manager.find_permission_view_menu(
            "schema_access", stored_table.schema_perm
        )
    )

    # database change
    new_db = Database(sqlalchemy_uri="some_uri", database_name="tmp_db")
    session.add(new_db)
    stored_table.database = (
        session.query(Database).filter_by(database_name="tmp_db").one()
    )
    session.commit()
    stored_table = (
        session.query(SqlaTable).filter_by(table_name="tmp_perm_table_v2").one()
    )
    self.assertEqual(
        stored_table.perm, f"[tmp_db].[tmp_perm_table_v2](id:{stored_table.id})"
    )
    self.assertIsNotNone(
        security_manager.find_permission_view_menu(
            "datasource_access", stored_table.perm
        )
    )
    # schema permission now points at the new database
    self.assertEqual(stored_table.schema_perm, "[tmp_db].[tmp_schema_v2]")
    self.assertIsNotNone(
        security_manager.find_permission_view_menu(
            "schema_access", stored_table.schema_perm
        )
    )

    # no schema
    stored_table.schema = None
    session.commit()
    stored_table = (
        session.query(SqlaTable).filter_by(table_name="tmp_perm_table_v2").one()
    )
    self.assertEqual(
        stored_table.perm, f"[tmp_db].[tmp_perm_table_v2](id:{stored_table.id})"
    )
    self.assertIsNotNone(
        security_manager.find_permission_view_menu(
            "datasource_access", stored_table.perm
        )
    )
    self.assertIsNone(stored_table.schema_perm)

    session.delete(new_db)
    session.delete(stored_table)
    session.commit()
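# The permission strings asserted above follow a fixed convention. A small
# helper sketch (hypothetical, for illustration only) that mirrors it:
def _expected_datasource_perm(database_name: str, table_name: str, table_id: int) -> str:
    # e.g. "[examples].[tmp_perm_table](id:1)"
    return f"[{database_name}].[{table_name}](id:{table_id})"


def _expected_schema_perm(database_name: str, schema: str) -> str:
    # e.g. "[examples].[tmp_schema]"; schema_perm is None when schema is None
    return f"[{database_name}].[{schema}]"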
def insert_default_dataset(self):
    return self.insert_dataset(
        "ab_permission", "", [self.get_user("admin").id], get_example_database()
    )
def _import(
    session: Session,
    configs: Dict[str, Any],
    overwrite: bool = False,
    force_data: bool = False,
) -> None:
    # import databases
    database_ids: Dict[str, int] = {}
    for file_name, config in configs.items():
        if file_name.startswith("databases/"):
            database = import_database(session, config, overwrite=overwrite)
            database_ids[str(database.uuid)] = database.id

    # import datasets
    # If database_uuid is not in the list of UUIDs it means that the examples
    # database was created before its UUID was frozen, so it has a random UUID.
    # We need to determine its ID so we can point the dataset to it.
    examples_db = get_example_database()
    dataset_info: Dict[str, Dict[str, Any]] = {}
    for file_name, config in configs.items():
        if file_name.startswith("datasets/"):
            # find the ID of the corresponding database
            if config["database_uuid"] not in database_ids:
                if examples_db is None:
                    raise Exception("Cannot find examples database")
                config["database_id"] = examples_db.id
            else:
                config["database_id"] = database_ids[config["database_uuid"]]

            try:
                dataset = import_dataset(
                    session, config, overwrite=overwrite, force_data=force_data
                )
            except MultipleResultsFound:
                # Multiple results can be found for datasets. There was a bug in
                # load-examples that resulted in datasets being loaded with a NULL
                # schema. Users could then add a new dataset with the same name in
                # the correct schema, resulting in duplicates, since the uniqueness
                # constraint was not enforced correctly in the application logic.
                # See https://github.com/apache/superset/issues/16051.
                continue

            dataset_info[str(dataset.uuid)] = {
                "datasource_id": dataset.id,
                "datasource_type": "view" if dataset.is_sqllab_view else "table",
                "datasource_name": dataset.table_name,
            }

    # import charts
    chart_ids: Dict[str, int] = {}
    for file_name, config in configs.items():
        if file_name.startswith("charts/") and config["dataset_uuid"] in dataset_info:
            # update datasource id, type, and name
            config.update(dataset_info[config["dataset_uuid"]])
            chart = import_chart(session, config, overwrite=overwrite)
            chart_ids[str(chart.uuid)] = chart.id

    # store the existing relationship between dashboards and charts
    existing_relationships = session.execute(
        select([dashboard_slices.c.dashboard_id, dashboard_slices.c.slice_id])
    ).fetchall()

    # import dashboards
    dashboard_chart_ids: List[Tuple[int, int]] = []
    for file_name, config in configs.items():
        if file_name.startswith("dashboards/"):
            try:
                config = update_id_refs(config, chart_ids, dataset_info)
            except KeyError:
                continue

            dashboard = import_dashboard(session, config, overwrite=overwrite)
            dashboard.published = True

            for uuid in find_chart_uuids(config["position"]):
                chart_id = chart_ids[uuid]
                if (dashboard.id, chart_id) not in existing_relationships:
                    dashboard_chart_ids.append((dashboard.id, chart_id))

    # set ref in the dashboard_slices table
    values = [
        {"dashboard_id": dashboard_id, "slice_id": chart_id}
        for (dashboard_id, chart_id) in dashboard_chart_ids
    ]
    # pylint: disable=no-value-for-parameter  # sqlalchemy/issues/4656
    session.execute(dashboard_slices.insert(), values)
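# Shape of the `configs` mapping that _import() consumes, reconstructed from
# the file-name prefixes it branches on (keys and values are illustrative):
#
#     configs = {
#         "databases/examples.yaml": {"uuid": "...", "database_name": "examples", ...},
#         "datasets/examples/energy_usage.yaml": {"database_uuid": "...", ...},
#         "charts/energy_sankey.yaml": {"dataset_uuid": "...", ...},
#         "dashboards/misc.yaml": {"position": {...}, ...},
#     }
#     _import(db.session, configs, overwrite=True, force_data=False)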
def quote_f(value: Optional[str]) -> Optional[str]:
    if not value:
        return value
    return get_example_database().inspector.engine.dialect.identifier_preparer.quote_identifier(
        value
    )
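# Usage sketch: quote_f() defers to the SQLAlchemy dialect of the examples
# database, so the exact quoting depends on the backend (output shown for a
# double-quote dialect such as PostgreSQL; this is an assumption):
#
#     quote_f("select")  # -> '"select"'
#     quote_f(None)      # -> None (falsy values pass through unchanged)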
def load_examples_run(
    load_test_data: bool = False,
    load_big_data: bool = False,
    only_metadata: bool = False,
    force: bool = False,
) -> None:
    if only_metadata:
        print("Loading examples metadata")
    else:
        examples_db = utils.get_example_database()
        print(f"Loading examples metadata and related data into {examples_db}")

    # pylint: disable=import-outside-toplevel
    from superset import examples

    examples.load_css_templates()

    if load_test_data:
        print("Loading energy related dataset")
        examples.load_energy(only_metadata, force)

    print("Loading [World Bank's Health Nutrition and Population Stats]")
    examples.load_world_bank_health_n_pop(only_metadata, force)

    print("Loading [Birth names]")
    examples.load_birth_names(only_metadata, force)

    if load_test_data:
        print("Loading [Tabbed dashboard]")
        examples.load_tabbed_dashboard(only_metadata)

    if not load_test_data:
        print("Loading [Random long/lat data]")
        examples.load_long_lat_data(only_metadata, force)

        print("Loading [Country Map data]")
        examples.load_country_map_data(only_metadata, force)

        print("Loading [San Francisco population polygons]")
        examples.load_sf_population_polygons(only_metadata, force)

        print("Loading [Flights data]")
        examples.load_flights(only_metadata, force)

        print("Loading [BART lines]")
        examples.load_bart_lines(only_metadata, force)

        print("Loading [Multi Line]")
        examples.load_multi_line(only_metadata)

        print("Loading [Misc Charts] dashboard")
        examples.load_misc_dashboard()

        print("Loading DECK.gl demo")
        examples.load_deck_dash()

    if load_big_data:
        print("Loading big synthetic data for tests")
        examples.load_big_data()

    # load examples that are stored as YAML config files
    examples.load_examples_from_configs(force, load_test_data)
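# Typical invocations of the loader above (assumption: exposed through the
# Superset CLI's load-examples wiring):
#
#     load_examples_run(load_test_data=True, force=True)  # CI fixtures
#     load_examples_run(only_metadata=True)               # metadata only, no data copy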
def test_run_sync_query_cta_config(self, ctas_method):
    with mock.patch(
        "superset.views.core.get_cta_schema_name",
        lambda d, u, s, sql: CTAS_SCHEMA_NAME,
    ):
        examples_db = get_example_database()
        db_id = examples_db.id
        backend = examples_db.backend
        if backend == "sqlite":
            # sqlite doesn't support schemas
            return

        tmp_table_name = f"tmp_async_22_{ctas_method.lower()}"
        quote = (
            examples_db.inspector.engine.dialect.identifier_preparer.quote_identifier
        )
        expected_full_table_name = f"{CTAS_SCHEMA_NAME}.{quote(tmp_table_name)}"
        self.drop_table_if_exists(expected_full_table_name, ctas_method, examples_db)
        name = "James"
        sql_where = f"SELECT name FROM birth_names WHERE name='{name}'"
        result = self.run_sql(
            db_id,
            sql_where,
            f"3_{ctas_method}",
            tmp_table=tmp_table_name,
            cta=True,
            ctas_method=ctas_method,
        )

        self.assertEqual(QueryStatus.SUCCESS, result["query"]["state"], msg=result)

        expected_result = []
        # TODO(bkyryliuk): refactor database specific logic into a separate class
        if backend == "presto":
            expected_result = (
                [{"rows": 1}] if ctas_method == CtasMethod.TABLE else [{"result": True}]
            )
        self.assertEqual(expected_result, result["data"])

        expected_columns = []
        # TODO(bkyryliuk): refactor database specific logic into a separate class
        if backend == "presto":
            expected_columns = [
                {
                    "name": "rows" if ctas_method == CtasMethod.TABLE else "result",
                    "type": "BIGINT" if ctas_method == CtasMethod.TABLE else "BOOLEAN",
                    "is_date": False,
                }
            ]
        self.assertEqual(expected_columns, result["columns"])

        query = self.get_query_by_id(result["query"]["serverId"])
        self.assertEqual(
            f"CREATE {ctas_method} {CTAS_SCHEMA_NAME}.{tmp_table_name} AS \n"
            "SELECT name FROM birth_names "
            "WHERE name='James'",
            query.executed_sql,
        )

        # TODO(bkyryliuk): quote table and schema names for all databases
        if backend in {"presto", "hive"}:
            assert query.select_sql == (
                f"SELECT *\nFROM {quote(CTAS_SCHEMA_NAME)}.{quote(tmp_table_name)}"
            )
        else:
            assert query.select_sql == (
                "SELECT *\n" f"FROM {CTAS_SCHEMA_NAME}.{tmp_table_name}"
            )

        time.sleep(CELERY_SHORT_SLEEP_TIME)
        results = self.run_sql(db_id, query.select_sql)
        self.assertEqual(QueryStatus.SUCCESS, results["status"], msg=result)

        self.drop_table_if_exists(
            expected_full_table_name, ctas_method, get_example_database()
        )
def load_examples_run(
    load_test_data: bool, only_metadata: bool = False, force: bool = False
) -> None:
    if only_metadata:
        print("Loading examples metadata")
    else:
        examples_db = utils.get_example_database()
        print(f"Loading examples metadata and related data into {examples_db}")

    from superset import examples

    examples.load_css_templates()

    print("Loading energy related dataset")
    examples.load_energy(only_metadata, force)

    print("Loading [World Bank's Health Nutrition and Population Stats]")
    examples.load_world_bank_health_n_pop(only_metadata, force)

    print("Loading [Birth names]")
    examples.load_birth_names(only_metadata, force)

    print("Loading [Unicode test data]")
    examples.load_unicode_test_data(only_metadata, force)

    if not load_test_data:
        print("Loading [Random time series data]")
        examples.load_random_time_series_data(only_metadata, force)

        print("Loading [Random long/lat data]")
        examples.load_long_lat_data(only_metadata, force)

        print("Loading [Country Map data]")
        examples.load_country_map_data(only_metadata, force)

        print("Loading [Multiformat time series]")
        examples.load_multiformat_time_series(only_metadata, force)

        print("Loading [Paris GeoJson]")
        examples.load_paris_iris_geojson(only_metadata, force)

        print("Loading [San Francisco population polygons]")
        examples.load_sf_population_polygons(only_metadata, force)

        print("Loading [Flights data]")
        examples.load_flights(only_metadata, force)

        print("Loading [BART lines]")
        examples.load_bart_lines(only_metadata, force)

        print("Loading [Multi Line]")
        examples.load_multi_line(only_metadata)

        print("Loading [Misc Charts] dashboard")
        examples.load_misc_dashboard()

        print("Loading DECK.gl demo")
        examples.load_deck_dash()

    print("Loading [Tabbed dashboard]")
    examples.load_tabbed_dashboard(only_metadata)
def test_run_sync_query_dont_exist(self):
    main_db = get_example_database()
    db_id = main_db.id
    sql_dont_exist = "SELECT name FROM table_dont_exist"
    result1 = self.run_sql(db_id, sql_dont_exist, "1", cta=True)
    self.assertTrue("error" in result1)
def get_birth_names_dataset() -> SqlaTable:
    example_db = get_example_database()
    return (
        db.session.query(SqlaTable)
        .filter_by(database=example_db, table_name="birth_names")
        .one()
    )
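# Usage sketch (assumption: the birth_names example has been loaded first;
# otherwise .one() raises NoResultFound):
#
#     tbl = get_birth_names_dataset()
#     assert tbl.table_name == "birth_names"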
def test_get_dataset_distinct_schema(self):
    """
    Dataset API: Test get dataset distinct schema
    """

    def pg_test_query_parameter(query_parameter, expected_response):
        uri = f"api/v1/dataset/distinct/schema?q={prison.dumps(query_parameter)}"
        rv = self.client.get(uri)
        response = json.loads(rv.data.decode("utf-8"))
        assert rv.status_code == 200
        assert response == expected_response

    example_db = get_example_database()
    datasets = []
    if example_db.backend == "postgresql":
        datasets.append(
            self.insert_dataset("ab_permission", "public", [], get_main_database())
        )
        datasets.append(
            self.insert_dataset(
                "columns", "information_schema", [], get_main_database()
            )
        )
        schema_values = [
            "",
            "admin_database",
            "information_schema",
            "public",
        ]
        expected_response = {
            "count": 4,
            "result": [{"text": val, "value": val} for val in schema_values],
        }
        self.login(username="******")
        uri = "api/v1/dataset/distinct/schema"
        rv = self.client.get(uri)
        response = json.loads(rv.data.decode("utf-8"))
        assert rv.status_code == 200
        assert response == expected_response

        # Test filter
        query_parameter = {"filter": "inf"}
        pg_test_query_parameter(
            query_parameter,
            {
                "count": 1,
                "result": [
                    {"text": "information_schema", "value": "information_schema"}
                ],
            },
        )

        query_parameter = {"page": 0, "page_size": 1}
        pg_test_query_parameter(
            query_parameter,
            {"count": 4, "result": [{"text": "", "value": ""}]},
        )

        query_parameter = {"page": 1, "page_size": 1}
        pg_test_query_parameter(
            query_parameter,
            {
                "count": 4,
                "result": [{"text": "admin_database", "value": "admin_database"}],
            },
        )

    for dataset in datasets:
        db.session.delete(dataset)
    db.session.commit()
def get_energy_usage_dataset() -> SqlaTable:
    example_db = get_example_database()
    return (
        db.session.query(SqlaTable)
        .filter_by(database=example_db, table_name="energy_usage")
        .one()
    )
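# Usage sketch, mirroring how the example loaders bind slices to a dataset
# (assumes the energy_usage example has been loaded; .one() raises otherwise):
#
#     tbl = get_energy_usage_dataset()
#     slc = Slice(slice_name="...", datasource_type="table", datasource_id=tbl.id)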
def test_select_star(self):
    self.login(username="******")
    examples_db = utils.get_example_database()
    resp = self.get_resp(f"/superset/select_star/{examples_db.id}/birth_names")
    self.assertIn("gender", resp)
def test_import_table_no_metadata(self):
    db_id = get_example_database().id
    table = self.create_table("pure_table", id=10001)
    imported_id = SqlaTable.import_obj(table, db_id, import_time=1989)
    imported = self.get_table_by_id(imported_id)
    self.assert_table_equals(table, imported)
def test_export_database_command(self, mock_g):
    mock_g.user = security_manager.find_user("admin")

    example_db = get_example_database()
    command = ExportDatabasesCommand(database_ids=[example_db.id])
    contents = dict(command.run())

    # TODO: this list shouldn't depend on the order in which unit tests are run
    # or on the backend; for now use a stable subset
    core_files = {
        "databases/examples.yaml",
        "datasets/examples/energy_usage.yaml",
        "datasets/examples/wb_health_population.yaml",
        "datasets/examples/birth_names.yaml",
    }
    expected_extra = {
        "engine_params": {},
        "metadata_cache_timeout": {},
        "metadata_params": {},
        "schemas_allowed_for_csv_upload": [],
    }
    if backend() == "presto":
        expected_extra = {"engine_params": {"connect_args": {"poll_interval": 0.1}}}

    assert core_files.issubset(set(contents.keys()))

    metadata = yaml.safe_load(contents["databases/examples.yaml"])
    assert metadata == {
        "allow_csv_upload": True,
        "allow_ctas": True,
        "allow_cvas": True,
        "allow_run_async": False,
        "cache_timeout": None,
        "database_name": "examples",
        "expose_in_sqllab": True,
        "extra": expected_extra,
        "sqlalchemy_uri": example_db.sqlalchemy_uri,
        "uuid": str(example_db.uuid),
        "version": "1.0.0",
    }

    metadata = yaml.safe_load(contents["datasets/examples/birth_names.yaml"])
    metadata.pop("uuid")
    assert metadata == {
        "table_name": "birth_names",
        "main_dttm_col": None,
        "description": "Adding a DESCRip",
        "default_endpoint": "",
        "offset": 66,
        "cache_timeout": 55,
        "schema": "",
        "sql": "",
        "params": None,
        "template_params": None,
        "filter_select_enabled": True,
        "fetch_values_predicate": None,
        "metrics": [
            {
                "metric_name": "ratio",
                "verbose_name": "Ratio Boys/Girls",
                "metric_type": None,
                "expression": "sum(sum_boys) / sum(sum_girls)",
                "description": "This represents the ratio of boys/girls",
                "d3format": ".2%",
                "extra": None,
                "warning_text": "no warning",
            },
            {
                "metric_name": "sum__num",
                "verbose_name": "Babies",
                "metric_type": None,
                "expression": "SUM(num)",
                "description": "",
                "d3format": "",
                "extra": None,
                "warning_text": "",
            },
            {
                "metric_name": "count",
                "verbose_name": "",
                "metric_type": None,
                "expression": "count(1)",
                "description": None,
                "d3format": None,
                "extra": None,
                "warning_text": None,
            },
        ],
        "columns": [
            {
                "column_name": "num_california",
                "verbose_name": None,
                "is_dttm": False,
                "is_active": None,
                "type": "NUMBER",
                "groupby": False,
                "filterable": False,
                "expression": "CASE WHEN state = 'CA' THEN num ELSE 0 END",
                "description": None,
                "python_date_format": None,
            },
            {
                "column_name": "ds",
                "verbose_name": "",
                "is_dttm": True,
                "is_active": None,
                "type": "DATETIME",
                "groupby": True,
                "filterable": True,
                "expression": "",
                "description": None,
                "python_date_format": None,
            },
            {
                "column_name": "sum_girls",
                "verbose_name": None,
                "is_dttm": False,
                "is_active": None,
                "type": "BIGINT(20)",
                "groupby": False,
                "filterable": False,
                "expression": "",
                "description": None,
                "python_date_format": None,
            },
            {
                "column_name": "gender",
                "verbose_name": None,
                "is_dttm": False,
                "is_active": None,
                "type": "VARCHAR(16)",
                "groupby": True,
                "filterable": True,
                "expression": "",
                "description": None,
                "python_date_format": None,
            },
            {
                "column_name": "state",
                "verbose_name": None,
                "is_dttm": None,
                "is_active": None,
                "type": "VARCHAR(10)",
                "groupby": True,
                "filterable": True,
                "expression": None,
                "description": None,
                "python_date_format": None,
            },
            {
                "column_name": "sum_boys",
                "verbose_name": None,
                "is_dttm": None,
                "is_active": None,
                "type": "BIGINT(20)",
                "groupby": True,
                "filterable": True,
                "expression": None,
                "description": None,
                "python_date_format": None,
            },
            {
                "column_name": "num",
                "verbose_name": None,
                "is_dttm": None,
                "is_active": None,
                "type": "BIGINT(20)",
                "groupby": True,
                "filterable": True,
                "expression": None,
                "description": None,
                "python_date_format": None,
            },
            {
                "column_name": "name",
                "verbose_name": None,
                "is_dttm": None,
                "is_active": None,
                "type": "VARCHAR(255)",
                "groupby": True,
                "filterable": True,
                "expression": None,
                "description": None,
                "python_date_format": None,
            },
        ],
        "version": "1.0.0",
        "database_uuid": str(example_db.uuid),
    }
def init():
    """Inits the Superset application"""
    utils.get_or_create_main_db()
    utils.get_example_database()
    appbuilder.add_permissions(update_perms=True)
    security_manager.sync_role_definitions()
def drop_table_if_exists(table_name: str, table_type: CtasMethod) -> None:
    """Drop table if it exists, works on any DB"""
    sql = f"DROP {table_type} IF EXISTS {table_name}"
    get_example_database().get_sqla_engine().execute(sql)
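# Usage sketch for the helper above; CtasMethod supplies the TABLE/VIEW
# keyword interpolated into the DROP statement:
#
#     drop_table_if_exists("tmp_async_22_table", CtasMethod.TABLE)
#     drop_table_if_exists("tmp_async_22_view", CtasMethod.VIEW)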
def test_extra_table_metadata(self):
    self.login("admin")
    dbid = utils.get_example_database().id
    self.get_json_resp(
        f"/superset/extra_table_metadata/{dbid}/birth_names/superset/"
    )
def get_database_by_name(database_name: str = "main") -> Database:
    if database_name == "examples":
        return get_example_database()
    else:
        raise ValueError("Database doesn't exist")
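# Usage sketch; note the helper resolves only "examples" and raises for
# anything else, including its own default of "main":
#
#     db_obj = get_database_by_name("examples")  # -> examples Database
#     get_database_by_name("main")               # -> ValueError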
def test_process_template(self):
    maindb = utils.get_example_database()
    sql = "SELECT '{{ datetime(2017, 1, 1).isoformat() }}'"
    tp = jinja_context.get_template_processor(database=maindb)
    rendered = tp.process_template(sql)
    self.assertEqual("SELECT '2017-01-01T00:00:00'", rendered)
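# A second rendering sketch with the same processor, using the
# current_username() macro exercised elsewhere in this suite (assumes a
# logged-in user in the Flask request context):
#
#     tp = jinja_context.get_template_processor(database=maindb)
#     tp.process_template("SELECT '{{ current_username() }}'")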
def test_extra_cache_keys(self, flask_g):
    flask_g.user.username = "abc"
    base_query_obj = {
        "granularity": None,
        "from_dttm": None,
        "to_dttm": None,
        "groupby": ["user"],
        "metrics": [],
        "is_timeseries": False,
        "filter": [],
    }

    # Table with Jinja callable.
    table1 = SqlaTable(
        table_name="test_has_extra_cache_keys_table",
        sql="SELECT '{{ current_username() }}' as user",
        database=get_example_database(),
    )

    query_obj = dict(**base_query_obj, extras={})
    extra_cache_keys = table1.get_extra_cache_keys(query_obj)
    self.assertTrue(table1.has_extra_cache_key_calls(query_obj))
    assert extra_cache_keys == ["abc"]

    # Table with Jinja callable disabled.
    table2 = SqlaTable(
        table_name="test_has_extra_cache_keys_disabled_table",
        sql="SELECT '{{ current_username(False) }}' as user",
        database=get_example_database(),
    )
    query_obj = dict(**base_query_obj, extras={})
    extra_cache_keys = table2.get_extra_cache_keys(query_obj)
    self.assertTrue(table2.has_extra_cache_key_calls(query_obj))
    self.assertListEqual(extra_cache_keys, [])

    # Table with no Jinja callable.
    query = "SELECT 'abc' as user"
    table3 = SqlaTable(
        table_name="test_has_no_extra_cache_keys_table",
        sql=query,
        database=get_example_database(),
    )
    query_obj = dict(**base_query_obj, extras={"where": "(user != 'abc')"})
    extra_cache_keys = table3.get_extra_cache_keys(query_obj)
    self.assertFalse(table3.has_extra_cache_key_calls(query_obj))
    self.assertListEqual(extra_cache_keys, [])

    # With Jinja callable in SQL expression.
    query_obj = dict(
        **base_query_obj, extras={"where": "(user != '{{ current_username() }}')"}
    )
    extra_cache_keys = table3.get_extra_cache_keys(query_obj)
    self.assertTrue(table3.has_extra_cache_key_calls(query_obj))
    assert extra_cache_keys == ["abc"]

    # Cleanup
    for table in [table1, table2, table3]:
        db.session.delete(table)
    db.session.commit()
def test_template_kwarg(self):
    maindb = utils.get_example_database()
    s = "{{ foo }}"
    tp = jinja_context.get_template_processor(database=maindb)
    rendered = tp.process_template(s, foo="bar")
    self.assertEqual("bar", rendered)
def test_export_dataset_command(self, mock_g):
    mock_g.user = security_manager.find_user("admin")

    example_db = get_example_database()
    example_dataset = example_db.tables[0]
    command = ExportDatasetsCommand([example_dataset.id])
    contents = dict(command.run())

    assert list(contents.keys()) == [
        "metadata.yaml",
        "datasets/examples/energy_usage.yaml",
        "databases/examples.yaml",
    ]

    metadata = yaml.safe_load(contents["datasets/examples/energy_usage.yaml"])

    # sort columns for deterministic comparison
    metadata["columns"] = sorted(metadata["columns"], key=itemgetter("column_name"))
    metadata["metrics"] = sorted(metadata["metrics"], key=itemgetter("metric_name"))

    # types are different depending on the backend
    type_map = {
        column.column_name: str(column.type) for column in example_dataset.columns
    }

    assert metadata == {
        "cache_timeout": None,
        "columns": [
            {
                "column_name": "source",
                "description": None,
                "expression": None,
                "filterable": True,
                "groupby": True,
                "is_active": True,
                "is_dttm": False,
                "python_date_format": None,
                "type": type_map["source"],
                "verbose_name": None,
            },
            {
                "column_name": "target",
                "description": None,
                "expression": None,
                "filterable": True,
                "groupby": True,
                "is_active": True,
                "is_dttm": False,
                "python_date_format": None,
                "type": type_map["target"],
                "verbose_name": None,
            },
            {
                "column_name": "value",
                "description": None,
                "expression": None,
                "filterable": True,
                "groupby": True,
                "is_active": True,
                "is_dttm": False,
                "python_date_format": None,
                "type": type_map["value"],
                "verbose_name": None,
            },
        ],
        "database_uuid": str(example_db.uuid),
        "default_endpoint": None,
        "description": "Energy consumption",
        "extra": None,
        "fetch_values_predicate": None,
        "filter_select_enabled": False,
        "main_dttm_col": None,
        "metrics": [
            {
                "d3format": None,
                "description": None,
                "expression": "COUNT(*)",
                "extra": None,
                "metric_name": "count",
                "metric_type": "count",
                "verbose_name": "COUNT(*)",
                "warning_text": None,
            },
            {
                "d3format": None,
                "description": None,
                "expression": "SUM(value)",
                "extra": None,
                "metric_name": "sum__value",
                "metric_type": None,
                "verbose_name": None,
                "warning_text": None,
            },
        ],
        "offset": 0,
        "params": None,
        "schema": None,
        "sql": None,
        "table_name": "energy_usage",
        "template_params": None,
        "uuid": str(example_dataset.uuid),
        "version": "1.0.0",
    }
def test_table_metadata(self):
    maindb = utils.get_example_database()
    data = self.get_json_resp(f"/superset/table/{maindb.id}/birth_names/null/")
    self.assertEqual(data["name"], "birth_names")
    assert len(data["columns"]) > 5
    assert data.get("selectStar").startswith("SELECT")
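# Response shape implied by the assertions above (illustrative subset only):
#
#     {
#         "name": "birth_names",
#         "columns": [{"name": "ds", ...}, ...],  # more than 5 entries
#         "selectStar": "SELECT ..."
#     }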
def load_world_bank_health_n_pop(  # pylint: disable=too-many-locals
    only_metadata: bool = False, force: bool = False
) -> None:
    """Loads the world bank health dataset, slices and a dashboard"""
    tbl_name = "wb_health_population"
    database = utils.get_example_database()
    table_exists = database.has_table_by_name(tbl_name)

    if not only_metadata and (not table_exists or force):
        data = get_example_data("countries.json")
        pdf = pd.read_json(data)
        pdf.columns = [col.replace(".", "_") for col in pdf.columns]
        pdf.year = pd.to_datetime(pdf.year)
        pdf.to_sql(
            tbl_name,
            database.get_sqla_engine(),
            if_exists="replace",
            chunksize=50,
            dtype={
                "year": DateTime(),
                "country_code": String(3),
                "country_name": String(255),
                "region": String(255),
            },
            index=False,
        )

    print("Creating table [wb_health_population] reference")
    tbl = db.session.query(TBL).filter_by(table_name=tbl_name).first()
    if not tbl:
        tbl = TBL(table_name=tbl_name)
    tbl.description = utils.readfile(os.path.join(EXAMPLES_FOLDER, "countries.md"))
    tbl.main_dttm_col = "year"
    tbl.database = database
    tbl.filter_select_enabled = True

    metrics = [
        "sum__SP_POP_TOTL",
        "sum__SH_DYN_AIDS",
        "sum__SP_RUR_TOTL_ZS",
        "sum__SP_DYN_LE00_IN",
        "sum__SP_RUR_TOTL",
    ]
    for metric in metrics:
        if not any(col.metric_name == metric for col in tbl.metrics):
            aggr_func = metric[:3]
            col = str(column(metric[5:]).compile(db.engine))
            tbl.metrics.append(
                SqlMetric(metric_name=metric, expression=f"{aggr_func}({col})")
            )

    db.session.merge(tbl)
    db.session.commit()
    tbl.fetch_metadata()

    metric = "sum__SP_POP_TOTL"
    metrics = ["sum__SP_POP_TOTL"]
    secondary_metric = {
        "aggregate": "SUM",
        "column": {
            "column_name": "SP_RUR_TOTL",
            "optionName": "_col_SP_RUR_TOTL",
            "type": "DOUBLE",
        },
        "expressionType": "SIMPLE",
        "hasCustomLabel": True,
        "label": "Rural Population",
    }

    defaults = {
        "compare_lag": "10",
        "compare_suffix": "o10Y",
        "limit": "25",
        "granularity_sqla": "year",
        "groupby": [],
        "row_limit": config["ROW_LIMIT"],
        "since": "2014-01-01",
        "until": "2014-01-02",
        "time_range": "2014-01-01 : 2014-01-02",
        "markup_type": "markdown",
        "country_fieldtype": "cca3",
        "entity": "country_code",
        "show_bubbles": True,
    }

    print("Creating slices")
    slices = [
        Slice(
            slice_name="Region Filter",
            viz_type="filter_box",
            datasource_type="table",
            datasource_id=tbl.id,
            params=get_slice_json(
                defaults,
                viz_type="filter_box",
                date_filter=False,
                filter_configs=[
                    {
                        "asc": False,
                        "clearable": True,
                        "column": "region",
                        "key": "2s98dfu",
                        "metric": "sum__SP_POP_TOTL",
                        "multiple": True,
                    },
                    {
                        "asc": False,
                        "clearable": True,
                        "key": "li3j2lk",
                        "column": "country_name",
                        "metric": "sum__SP_POP_TOTL",
                        "multiple": True,
                    },
                ],
            ),
        ),
        Slice(
            slice_name="World's Population",
            viz_type="big_number",
            datasource_type="table",
            datasource_id=tbl.id,
            params=get_slice_json(
                defaults,
                since="2000",
                viz_type="big_number",
                compare_lag="10",
                metric="sum__SP_POP_TOTL",
                compare_suffix="over 10Y",
            ),
        ),
        Slice(
            slice_name="Most Populated Countries",
            viz_type="table",
            datasource_type="table",
            datasource_id=tbl.id,
            params=get_slice_json(
                defaults,
                viz_type="table",
                metrics=["sum__SP_POP_TOTL"],
                groupby=["country_name"],
            ),
        ),
        Slice(
            slice_name="Growth Rate",
            viz_type="line",
            datasource_type="table",
            datasource_id=tbl.id,
            params=get_slice_json(
                defaults,
                viz_type="line",
                since="1960-01-01",
                metrics=["sum__SP_POP_TOTL"],
                num_period_compare="10",
                groupby=["country_name"],
            ),
        ),
        Slice(
            slice_name="% Rural",
            viz_type="world_map",
            datasource_type="table",
            datasource_id=tbl.id,
            params=get_slice_json(
                defaults,
                viz_type="world_map",
                metric="sum__SP_RUR_TOTL_ZS",
                num_period_compare="10",
                secondary_metric=secondary_metric,
            ),
        ),
        Slice(
            slice_name="Life Expectancy VS Rural %",
            viz_type="bubble",
            datasource_type="table",
            datasource_id=tbl.id,
            params=get_slice_json(
                defaults,
                viz_type="bubble",
                since="2011-01-01",
                until="2011-01-02",
                series="region",
                limit=0,
                entity="country_name",
                x="sum__SP_RUR_TOTL_ZS",
                y="sum__SP_DYN_LE00_IN",
                size="sum__SP_POP_TOTL",
                max_bubble_size="50",
                adhoc_filters=[
                    {
                        "clause": "WHERE",
                        "expressionType": "SIMPLE",
                        "filterOptionName": "2745eae5",
                        "comparator": [
                            "TCA",
                            "MNP",
                            "DMA",
                            "MHL",
                            "MCO",
                            "SXM",
                            "CYM",
                            "TUV",
                            "IMY",
                            "KNA",
                            "ASM",
                            "ADO",
                            "AMA",
                            "PLW",
                        ],
                        "operator": "NOT IN",
                        "subject": "country_code",
                    }
                ],
            ),
        ),
        Slice(
            slice_name="Rural Breakdown",
            viz_type="sunburst",
            datasource_type="table",
            datasource_id=tbl.id,
            params=get_slice_json(
                defaults,
                viz_type="sunburst",
                groupby=["region", "country_name"],
                since="2011-01-01",
                until="2011-01-01",
                metric=metric,
                secondary_metric=secondary_metric,
            ),
        ),
        Slice(
            slice_name="World's Pop Growth",
            viz_type="area",
            datasource_type="table",
            datasource_id=tbl.id,
            params=get_slice_json(
                defaults,
                since="1960-01-01",
                until="now",
                viz_type="area",
                groupby=["region"],
                metrics=metrics,
            ),
        ),
        Slice(
            slice_name="Box plot",
            viz_type="box_plot",
            datasource_type="table",
            datasource_id=tbl.id,
            params=get_slice_json(
                defaults,
                since="1960-01-01",
                until="now",
                whisker_options="Min/max (no outliers)",
                x_ticks_layout="staggered",
                viz_type="box_plot",
                groupby=["region"],
                metrics=metrics,
            ),
        ),
        Slice(
            slice_name="Treemap",
            viz_type="treemap",
            datasource_type="table",
            datasource_id=tbl.id,
            params=get_slice_json(
                defaults,
                since="1960-01-01",
                until="now",
                viz_type="treemap",
                metrics=["sum__SP_POP_TOTL"],
                groupby=["region", "country_code"],
            ),
        ),
        Slice(
            slice_name="Parallel Coordinates",
            viz_type="para",
            datasource_type="table",
            datasource_id=tbl.id,
            params=get_slice_json(
                defaults,
                since="2011-01-01",
                until="2011-01-01",
                viz_type="para",
                limit=100,
                metrics=[
                    "sum__SP_POP_TOTL",
                    "sum__SP_RUR_TOTL_ZS",
                    "sum__SH_DYN_AIDS",
                ],
                secondary_metric="sum__SP_POP_TOTL",
                series="country_name",
            ),
        ),
    ]
    misc_dash_slices.add(slices[-1].slice_name)
    for slc in slices:
        merge_slice(slc)

    print("Creating a World Bank's Data dashboard")
    dash_name = "World Bank's Data"
    slug = "world_health"
    dash = db.session.query(Dashboard).filter_by(slug=slug).first()

    if not dash:
        dash = Dashboard()
    dash.published = True
    js = textwrap.dedent(
        """\
{
    "CHART-36bfc934": {
        "children": [],
        "id": "CHART-36bfc934",
        "meta": {
            "chartId": 40,
            "height": 25,
            "sliceName": "Region Filter",
            "width": 2
        },
        "type": "CHART"
    },
    "CHART-37982887": {
        "children": [],
        "id": "CHART-37982887",
        "meta": {
            "chartId": 41,
            "height": 25,
            "sliceName": "World's Population",
            "width": 2
        },
        "type": "CHART"
    },
    "CHART-17e0f8d8": {
        "children": [],
        "id": "CHART-17e0f8d8",
        "meta": {
            "chartId": 42,
            "height": 92,
            "sliceName": "Most Populated Countries",
            "width": 3
        },
        "type": "CHART"
    },
    "CHART-2ee52f30": {
        "children": [],
        "id": "CHART-2ee52f30",
        "meta": {
            "chartId": 43,
            "height": 38,
            "sliceName": "Growth Rate",
            "width": 6
        },
        "type": "CHART"
    },
    "CHART-2d5b6871": {
        "children": [],
        "id": "CHART-2d5b6871",
        "meta": {
            "chartId": 44,
            "height": 52,
            "sliceName": "% Rural",
            "width": 7
        },
        "type": "CHART"
    },
    "CHART-0fd0d252": {
        "children": [],
        "id": "CHART-0fd0d252",
        "meta": {
            "chartId": 45,
            "height": 50,
            "sliceName": "Life Expectancy VS Rural %",
            "width": 8
        },
        "type": "CHART"
    },
    "CHART-97f4cb48": {
        "children": [],
        "id": "CHART-97f4cb48",
        "meta": {
            "chartId": 46,
            "height": 38,
            "sliceName": "Rural Breakdown",
            "width": 3
        },
        "type": "CHART"
    },
    "CHART-b5e05d6f": {
        "children": [],
        "id": "CHART-b5e05d6f",
        "meta": {
            "chartId": 47,
            "height": 50,
            "sliceName": "World's Pop Growth",
            "width": 4
        },
        "type": "CHART"
    },
    "CHART-e76e9f5f": {
        "children": [],
        "id": "CHART-e76e9f5f",
        "meta": {
            "chartId": 48,
            "height": 50,
            "sliceName": "Box plot",
            "width": 4
        },
        "type": "CHART"
    },
    "CHART-a4808bba": {
        "children": [],
        "id": "CHART-a4808bba",
        "meta": {
            "chartId": 49,
            "height": 50,
            "sliceName": "Treemap",
            "width": 8
        },
        "type": "CHART"
    },
    "COLUMN-071bbbad": {
        "children": [
            "ROW-1e064e3c",
            "ROW-afdefba9"
        ],
        "id": "COLUMN-071bbbad",
        "meta": {
            "background": "BACKGROUND_TRANSPARENT",
            "width": 9
        },
        "type": "COLUMN"
    },
    "COLUMN-fe3914b8": {
        "children": [
            "CHART-36bfc934",
            "CHART-37982887"
        ],
        "id": "COLUMN-fe3914b8",
        "meta": {
            "background": "BACKGROUND_TRANSPARENT",
            "width": 2
        },
        "type": "COLUMN"
    },
    "GRID_ID": {
        "children": [
            "ROW-46632bc2",
            "ROW-3fa26c5d",
            "ROW-812b3f13"
        ],
        "id": "GRID_ID",
        "type": "GRID"
    },
    "HEADER_ID": {
        "id": "HEADER_ID",
        "meta": {
            "text": "World's Bank Data"
        },
        "type": "HEADER"
    },
    "ROOT_ID": {
        "children": [
            "GRID_ID"
        ],
        "id": "ROOT_ID",
        "type": "ROOT"
    },
    "ROW-1e064e3c": {
        "children": [
            "COLUMN-fe3914b8",
            "CHART-2d5b6871"
        ],
        "id": "ROW-1e064e3c",
        "meta": {
            "background": "BACKGROUND_TRANSPARENT"
        },
        "type": "ROW"
    },
    "ROW-3fa26c5d": {
        "children": [
            "CHART-b5e05d6f",
            "CHART-0fd0d252"
        ],
        "id": "ROW-3fa26c5d",
        "meta": {
            "background": "BACKGROUND_TRANSPARENT"
        },
        "type": "ROW"
    },
    "ROW-46632bc2": {
        "children": [
            "COLUMN-071bbbad",
            "CHART-17e0f8d8"
        ],
        "id": "ROW-46632bc2",
        "meta": {
            "background": "BACKGROUND_TRANSPARENT"
        },
        "type": "ROW"
    },
    "ROW-812b3f13": {
        "children": [
            "CHART-a4808bba",
            "CHART-e76e9f5f"
        ],
        "id": "ROW-812b3f13",
        "meta": {
            "background": "BACKGROUND_TRANSPARENT"
        },
        "type": "ROW"
    },
    "ROW-afdefba9": {
        "children": [
            "CHART-2ee52f30",
            "CHART-97f4cb48"
        ],
        "id": "ROW-afdefba9",
        "meta": {
            "background": "BACKGROUND_TRANSPARENT"
        },
        "type": "ROW"
    },
    "DASHBOARD_VERSION_KEY": "v2"
}
    """
    )
    pos = json.loads(js)
    update_slice_ids(pos, slices)

    dash.dashboard_title = dash_name
    dash.position_json = json.dumps(pos, indent=4)
    dash.slug = slug
    dash.slices = slices[:-1]
    db.session.merge(dash)
    db.session.commit()
def test_import_csv(self):
    self.login(username="******")
    table_name = "".join(random.choice(string.ascii_uppercase) for _ in range(5))

    filename_1 = "testCSV.csv"
    test_file_1 = open(filename_1, "w+")
    test_file_1.write("a,b\n")
    test_file_1.write("john,1\n")
    test_file_1.write("paul,2\n")
    test_file_1.close()

    filename_2 = "testCSV2.csv"
    test_file_2 = open(filename_2, "w+")
    test_file_2.write("b,c,d\n")
    test_file_2.write("john,1,x\n")
    test_file_2.write("paul,2,y\n")
    test_file_2.close()

    example_db = utils.get_example_database()
    example_db.allow_csv_upload = True
    db_id = example_db.id
    db.session.commit()

    form_data = {
        "csv_file": open(filename_1, "rb"),
        "sep": ",",
        "name": table_name,
        "con": db_id,
        "if_exists": "fail",
        "index_label": "test_label",
        "mangle_dupe_cols": False,
    }
    url = "/databaseview/list/"
    add_datasource_page = self.get_resp(url)
    self.assertIn("Upload a CSV", add_datasource_page)

    url = "/csvtodatabaseview/form"
    form_get = self.get_resp(url)
    self.assertIn("CSV to Database configuration", form_get)

    try:
        # initial upload with fail mode
        resp = self.get_resp(url, data=form_data)
        self.assertIn(
            f'CSV file "{filename_1}" uploaded to table "{table_name}"', resp
        )

        # upload again with fail mode; should fail
        form_data["csv_file"] = open(filename_1, "rb")
        resp = self.get_resp(url, data=form_data)
        self.assertIn(
            f'Unable to upload CSV file "{filename_1}" to table "{table_name}"',
            resp,
        )

        # upload again with append mode
        form_data["csv_file"] = open(filename_1, "rb")
        form_data["if_exists"] = "append"
        resp = self.get_resp(url, data=form_data)
        self.assertIn(
            f'CSV file "{filename_1}" uploaded to table "{table_name}"', resp
        )

        # upload again with replace mode
        form_data["csv_file"] = open(filename_1, "rb")
        form_data["if_exists"] = "replace"
        resp = self.get_resp(url, data=form_data)
        self.assertIn(
            f'CSV file "{filename_1}" uploaded to table "{table_name}"', resp
        )

        # try to append to table from file with different schema
        form_data["csv_file"] = open(filename_2, "rb")
        form_data["if_exists"] = "append"
        resp = self.get_resp(url, data=form_data)
        self.assertIn(
            f'Unable to upload CSV file "{filename_2}" to table "{table_name}"',
            resp,
        )

        # replace table from file with different schema
        form_data["csv_file"] = open(filename_2, "rb")
        form_data["if_exists"] = "replace"
        resp = self.get_resp(url, data=form_data)
        self.assertIn(
            f'CSV file "{filename_2}" uploaded to table "{table_name}"', resp
        )
        table = (
            db.session.query(SqlaTable)
            .filter_by(table_name=table_name, database_id=db_id)
            .first()
        )
        # make sure the new column name is reflected in the table metadata
        self.assertIn("d", table.column_names)
    finally:
        os.remove(filename_1)
        os.remove(filename_2)
def _get_database_by_name(self, database_name="main"):
    if database_name == "examples":
        return get_example_database()
    else:
        raise ValueError("Database doesn't exist")
def test__normalize_prequery_result_type(
    app_context: Flask,
    mocker: MockFixture,
    row: pd.Series,
    dimension: str,
    result: Any,
) -> None:
    def _convert_dttm(
        target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None
    ) -> Optional[str]:
        if target_type.upper() == TemporalType.TIMESTAMP:
            return f"""TIME_PARSE('{dttm.isoformat(timespec="seconds")}')"""
        return None

    table = SqlaTable(table_name="foobar", database=get_example_database())
    mocker.patch.object(table.db_engine_spec, "convert_dttm", new=_convert_dttm)

    columns_by_name = {
        "foo": TableColumn(
            column_name="foo",
            is_dttm=False,
            table=table,
            type="STRING",
        ),
        "bar": TableColumn(
            column_name="bar",
            is_dttm=False,
            table=table,
            type="BOOLEAN",
        ),
        "baz": TableColumn(
            column_name="baz",
            is_dttm=False,
            table=table,
            type="INTEGER",
        ),
        "qux": TableColumn(
            column_name="qux",
            is_dttm=False,
            table=table,
            type="FLOAT",
        ),
        "quux": TableColumn(
            column_name="quux",
            is_dttm=True,
            table=table,
            type="STRING",
        ),
        "quuz": TableColumn(
            column_name="quuz",
            is_dttm=True,
            table=table,
            type="TIMESTAMP",
        ),
    }

    normalized = table._normalize_prequery_result_type(
        row,
        dimension,
        columns_by_name,
    )

    assert type(normalized) == type(result)

    if isinstance(normalized, TextClause):
        assert str(normalized) == str(result)
    else:
        assert normalized == result
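# An illustrative parametrization for the test above (hypothetical values):
# a non-temporal INTEGER passes through untouched, while the temporal
# TIMESTAMP column is rewritten through the mocked convert_dttm into a
# TIME_PARSE(...) text clause.
#
#     @pytest.mark.parametrize(
#         "row, dimension, result",
#         [
#             (pd.Series({"baz": 123}), "baz", 123),
#             (
#                 pd.Series({"quuz": pd.Timestamp("2021-12-31")}),
#                 "quuz",
#                 text("TIME_PARSE('2021-12-31T00:00:00')"),
#             ),
#         ],
#     )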