def create_slice(
    datasource_id: Optional[int] = None,
    datasource: Optional[SqlaTable] = None,
    name: Optional[str] = None,
    owners: Optional[List[User]] = None,
) -> Slice:
    """Build a test ``Slice``, filling in sensible defaults.

    A random name and empty owner list are used when not supplied. If
    ``datasource`` is given it is attached directly; otherwise the slice is
    linked by ``datasource_id``, creating a backing table when neither is
    provided.
    """
    slice_name = name if name is not None else random_str()
    slice_owners = owners if owners is not None else []

    if datasource:
        return Slice(
            slice_name=slice_name,
            table=datasource,
            owners=slice_owners,
            datasource_type="table",
        )

    if datasource_id is None:
        datasource_id = create_datasource_table_to_db(
            name=slice_name + "_table"
        ).id
    return Slice(
        slice_name=slice_name,
        datasource_id=datasource_id,
        owners=slice_owners,
        datasource_type="table",
    )
def test_import_slices_for_non_existent_table(self):
    """Importing a slice that references a missing table must raise."""
    with self.assertRaises(AttributeError):
        slc = self.create_slice("Import Me 3", id=10004, table_name="non_existent")
        Slice.import_obj(slc, None)
def test_import_slices_override(self):
    """A second import with the same remote id overrides the first slice."""
    original = self.create_slice("Import Me New", id=10005)
    first_id = Slice.import_obj(original, None, import_time=1990)
    original.slice_name = "Import Me New"

    first_imported = self.get_slice(first_id)
    duplicate = self.create_slice("Import Me New", id=10005)
    second_id = Slice.import_obj(duplicate, first_imported, import_time=1990)

    self.assertEqual(first_id, second_id)
    second_imported = self.get_slice(second_id)
    self.assert_slice_equals(original, second_imported)
def test_set_perm_slice(self):
    """Slice perms mirror the table's perms but only refresh when the
    slice itself is updated, not when the table is renamed.

    Fixes: `assertEquals`/`assertNotEquals` are deprecated aliases
    (removed in Python 3.12) -> `assertEqual`/`assertNotEqual`; local
    variable renamed from `slice` to `slc` to stop shadowing the builtin.
    """
    session = db.session
    database = Database(database_name="tmp_database", sqlalchemy_uri="sqlite://test")
    table = SqlaTable(table_name="tmp_perm_table", database=database)
    session.add(database)
    session.add(table)
    session.commit()

    # no schema permission
    slc = Slice(
        datasource_id=table.id,
        datasource_type="table",
        datasource_name="tmp_perm_table",
        slice_name="slice_name",
    )
    session.add(slc)
    session.commit()

    slc = session.query(Slice).filter_by(slice_name="slice_name").one()
    self.assertEqual(slc.perm, table.perm)
    self.assertEqual(slc.perm, f"[tmp_database].[tmp_perm_table](id:{table.id})")
    self.assertEqual(slc.schema_perm, table.schema_perm)
    self.assertIsNone(slc.schema_perm)

    table.schema = "tmp_perm_schema"
    table.table_name = "tmp_perm_table_v2"
    session.commit()
    # TODO(bogdan): modify slice permissions on the table update.
    self.assertNotEqual(slc.perm, table.perm)
    self.assertEqual(slc.perm, f"[tmp_database].[tmp_perm_table](id:{table.id})")
    self.assertEqual(
        table.perm, f"[tmp_database].[tmp_perm_table_v2](id:{table.id})"
    )
    # TODO(bogdan): modify slice schema permissions on the table update.
    self.assertNotEqual(slc.schema_perm, table.schema_perm)
    self.assertIsNone(slc.schema_perm)

    # updating slice refreshes the permissions
    slc.slice_name = "slice_name_v2"
    session.commit()
    self.assertEqual(slc.perm, table.perm)
    self.assertEqual(
        slc.perm, f"[tmp_database].[tmp_perm_table_v2](id:{table.id})"
    )
    self.assertEqual(slc.schema_perm, table.schema_perm)
    self.assertEqual(slc.schema_perm, "[tmp_database].[tmp_perm_schema]")

    session.delete(slc)
    session.delete(table)
    session.delete(database)
    session.commit()
def export_chart(chart: Slice) -> Iterator[Tuple[str, str]]:
    """Yield ``(file_name, yaml_content)`` for a chart, then delegate to the
    dataset export when the chart is backed by a table.

    Fixes: filter ``REMOVE_KEYS`` with a dict comprehension instead of
    ``del`` (which raises ``KeyError`` if a key is absent) and guard
    ``params`` on truthiness so a present-but-``None`` value is not fed to
    ``json.loads`` — consistent with the sibling ``_export`` implementation.
    """
    chart_slug = sanitize(chart.slice_name)
    file_name = f"charts/{chart_slug}.yaml"
    payload = chart.export_to_dict(
        recursive=False,
        include_parent_ref=False,
        include_defaults=True,
        export_uuids=True,
    )
    # TODO (betodealmeida): move this logic to export_to_dict once this
    # becomes the default export endpoint
    payload = {
        key: value for key, value in payload.items() if key not in REMOVE_KEYS
    }

    if payload.get("params"):
        try:
            payload["params"] = json.loads(payload["params"])
        except json.decoder.JSONDecodeError:
            # Best effort: leave the raw string in place if it isn't JSON.
            pass

    payload["version"] = IMPORT_EXPORT_VERSION
    if chart.table:
        payload["dataset_uuid"] = str(chart.table.uuid)

    file_content = yaml.safe_dump(payload, sort_keys=False)
    yield file_name, file_content

    if chart.table:
        yield from ExportDatasetsCommand([chart.table.id]).run()
def decode_dashboards(  # pylint: disable=too-many-return-statements
    o: Dict[str, Any]) -> Any:
    """
    Function to be passed into json.loads obj_hook parameter
    Recreates the dashboard object from a json representation.
    """
    from superset.connectors.druid.models import (
        DruidCluster,
        DruidColumn,
        DruidDatasource,
        DruidMetric,
    )

    # Marker key -> model constructor; checked in insertion order.
    constructors = {
        "__Dashboard__": Dashboard,
        "__Slice__": Slice,
        "__TableColumn__": TableColumn,
        "__SqlaTable__": SqlaTable,
        "__SqlMetric__": SqlMetric,
        "__DruidCluster__": DruidCluster,
        "__DruidColumn__": DruidColumn,
        "__DruidDatasource__": DruidDatasource,
        "__DruidMetric__": DruidMetric,
    }
    for marker, model_cls in constructors.items():
        if marker in o:
            return model_cls(**o[marker])
    if "__datetime__" in o:
        return datetime.strptime(o["__datetime__"], "%Y-%m-%dT%H:%M:%S")
    return o
def new_slice(name=None, table=None, tags=None):
    """Create a new test slice (and test table if none specified).

    Fixes ordering bug: the ``tags`` default must be applied *before*
    ``create_table(tags=tags)`` — previously, when both ``table`` and
    ``tags`` were omitted, the table was created with ``tags=None``
    instead of ``['test']``.
    """
    if name is None:
        name = "slice-%s" % uuid.uuid4()
    if tags is None:
        tags = ['test']
    if table is None:
        table = create_table(tags=tags)

    slyce = Slice(
        slice_name=name,
        datasource_type='table',
        datasource_name=table.datasource_name,
        viz_type='bubble',
        params=json.dumps(
            dict(
                tags=tags,
                database_name=table.database_name,
                datasource_name=table.datasource_name,
                schema=table.schema,
                metrics=[],
            )),
    )
    # NOTE that we don't actually import the slice here - it needs to
    # be attached to a dashboard for that to make sense
    return slyce
def test_treemap_migrate(app_context: SupersetApp) -> None:
    """Round-trip a treemap slice through upgrade and downgrade."""
    from superset.models.slice import Slice

    chart = Slice(
        viz_type=MigrateTreeMap.source_viz_type,
        datasource_type="table",
        params=treemap_form_data,
        query_context=f'{{"form_data": {treemap_form_data}}}',
    )
    chart = MigrateTreeMap.upgrade_slice(chart)
    assert chart.viz_type == MigrateTreeMap.target_viz_type

    # verify form_data
    upgraded = json.loads(chart.params)
    assert upgraded["metric"] == "sum__num"
    assert upgraded["viz_type"] == "treemap_v2"
    assert "metrics" not in upgraded
    source_form_data = json.loads(treemap_form_data)
    assert json.dumps(upgraded["form_data_bak"], sort_keys=True) == json.dumps(
        source_form_data, sort_keys=True
    )

    # verify query_context
    upgraded_context = json.loads(chart.query_context)
    assert upgraded_context["form_data"]["viz_type"] == "treemap_v2"

    # downgrade
    chart = MigrateTreeMap.downgrade_slice(chart)
    assert chart.viz_type == MigrateTreeMap.source_viz_type
    assert json.dumps(json.loads(chart.params), sort_keys=True) == json.dumps(
        source_form_data, sort_keys=True
    )
def test_data_for_slices_with_adhoc_column(self):
    # should perform sqla.model.BaseDatasource.data_for_slices() with adhoc
    # column and legacy chart
    table = self.get_table(name="birth_names")
    dashboard = self.get_dash_by_slug("births")
    chart_params = {
        "adhoc_filters": [],
        "granularity_sqla": "ds",
        "groupby": [
            "name",
            {"label": "adhoc_column", "sqlExpression": "name"},
        ],
        "metrics": ["sum__num"],
        "time_range": "No filter",
        "viz_type": "table",
    }
    chart = Slice(
        slice_name="slice with adhoc column",
        datasource_type="table",
        viz_type="table",
        params=json.dumps(chart_params),
        datasource_id=table.id,
    )
    dashboard.slices.append(chart)

    datasource_info = chart.datasource.data_for_slices([chart])
    assert "database" in datasource_info

    # clean up and auto commit
    metadata_db.session.delete(chart)
def load_random_time_series_data(only_metadata: bool = False, force: bool = False) -> None:
    """Loading random time series data from a zip file in the repo.

    Fixes the Presto branch's strftime format, which contained a stray
    ``%`` ("%H:%M%:%S") producing malformed timestamps; also hoists the
    duplicated epoch-seconds conversion out of the backend branch.
    """
    tbl_name = "random_time_series"
    database = utils.get_example_database()
    table_exists = database.has_table_by_name(tbl_name)

    if not only_metadata and (not table_exists or force):
        data = get_example_data("random_time_series.json.gz")
        pdf = pd.read_json(data)
        # Epoch seconds -> timestamps, regardless of backend.
        pdf.ds = pd.to_datetime(pdf.ds, unit="s")
        if database.backend == "presto":
            # Presto stores the column as a string; serialize explicitly.
            pdf.ds = pdf.ds.dt.strftime("%Y-%m-%d %H:%M:%S")
        pdf.to_sql(
            tbl_name,
            database.get_sqla_engine(),
            if_exists="replace",
            chunksize=500,
            dtype={
                "ds": DateTime if database.backend != "presto" else String(255)
            },
            index=False,
        )
        print("Done loading table!")
        print("-" * 80)

    print(f"Creating table [{tbl_name}] reference")
    obj = db.session.query(TBL).filter_by(table_name=tbl_name).first()
    if not obj:
        obj = TBL(table_name=tbl_name)
    obj.main_dttm_col = "ds"
    obj.database = database
    db.session.merge(obj)
    db.session.commit()
    obj.fetch_metadata()
    tbl = obj

    slice_data = {
        "granularity_sqla": "day",
        "row_limit": config["ROW_LIMIT"],
        "since": "2019-01-01",
        "until": "2019-02-01",
        "metric": "count",
        "viz_type": "cal_heatmap",
        "domain_granularity": "month",
        "subdomain_granularity": "day",
    }

    print("Creating a slice")
    slc = Slice(
        slice_name="Calendar Heatmap",
        viz_type="cal_heatmap",
        datasource_type="table",
        datasource_id=tbl.id,
        params=get_slice_json(slice_data),
    )
    merge_slice(slc)
def insert_chart(
    self,
    slice_name: str,
    owners: List[int],
    datasource_id: int,
    created_by=None,
    datasource_type: str = "table",
    description: Optional[str] = None,
    viz_type: Optional[str] = None,
    params: Optional[str] = None,
    cache_timeout: Optional[int] = None,
) -> Slice:
    """Insert a chart into the test database and return it.

    :param owners: user ids to resolve into owner objects
    :param datasource_id: id of the datasource the chart points at

    Fixes: stop shadowing the builtins ``slice`` and ``list``; build the
    owners list with a comprehension.
    """
    # Resolve owner ids to user model instances.
    obj_owners = [
        db.session.query(security_manager.user_model).get(owner)
        for owner in owners
    ]
    datasource = ConnectorRegistry.get_datasource(
        datasource_type, datasource_id, db.session
    )
    slc = Slice(
        cache_timeout=cache_timeout,
        created_by=created_by,
        datasource_id=datasource.id,
        datasource_name=datasource.name,
        datasource_type=datasource.type,
        description=description,
        owners=obj_owners,
        params=params,
        slice_name=slice_name,
        viz_type=viz_type,
    )
    db.session.add(slc)
    db.session.commit()
    return slc
def create_slice(
    self,
    name,
    ds_id=None,
    id=None,
    db_name="examples",
    table_name="wb_health_population",
):
    """Build an (unsaved) Slice whose params carry import metadata.

    When no ``ds_id`` is given, the id is looked up from ``table_name``.
    """
    params = {
        "num_period_compare": "10",
        "remote_id": id,
        "datasource_name": table_name,
        "database_name": db_name,
        "schema": "",
        # Test for trailing commas
        "metrics": ["sum__signup_attempt_email", "sum__signup_attempt_facebook"],
    }
    if table_name and not ds_id:
        existing_table = self.get_table_by_name(table_name)
        if existing_table:
            ds_id = existing_table.id
    return Slice(
        slice_name=name,
        datasource_type="table",
        viz_type="bubble",
        params=json.dumps(params),
        datasource_id=ds_id,
        id=id,
    )
def _export(model: Slice, export_related: bool = True) -> Iterator[Tuple[str, str]]:
    """Yield ``(file_name, yaml_content)`` for a chart and, when requested,
    the export of its backing dataset."""
    slug = secure_filename(model.slice_name)
    file_name = f"charts/{slug}_{model.id}.yaml"
    payload = model.export_to_dict(
        recursive=False,
        include_parent_ref=False,
        include_defaults=True,
        export_uuids=True,
    )
    # TODO (betodealmeida): move this logic to export_to_dict once this
    # becomes the default export endpoint
    for key in REMOVE_KEYS:
        payload.pop(key, None)

    if payload.get("params"):
        try:
            payload["params"] = json.loads(payload["params"])
        except json.decoder.JSONDecodeError:
            logger.info("Unable to decode `params` field: %s", payload["params"])

    payload["version"] = EXPORT_VERSION
    if model.table:
        payload["dataset_uuid"] = str(model.table.uuid)

    yield file_name, yaml.safe_dump(payload, sort_keys=False)

    if model.table and export_related:
        yield from ExportDatasetsCommand([model.table.id]).run()
def load_multi_line(only_metadata: bool = False) -> None:
    """Load the "Multi Line" example chart built from two existing charts.

    Fixes: the original indexed the query result positionally
    (``ids[0]``/``ids[1]``), but an ``IN (...)`` query carries no guaranteed
    row ordering — map slice names to ids explicitly so "Growth Rate" and
    "Trends" land in the intended params keys.
    """
    load_world_bank_health_n_pop(only_metadata)
    load_birth_names(only_metadata)
    ids_by_name = {
        row.slice_name: row.id
        for row in db.session.query(Slice).filter(
            Slice.slice_name.in_(["Growth Rate", "Trends"])
        )
    }

    slc = Slice(
        datasource_type=DatasourceType.TABLE,  # not true, but needed
        datasource_id=1,  # cannot be empty
        slice_name="Multi Line",
        viz_type="line_multi",
        params=json.dumps({
            "slice_name": "Multi Line",
            "viz_type": "line_multi",
            "line_charts": [ids_by_name["Growth Rate"]],
            "line_charts_2": [ids_by_name["Trends"]],
            "since": "1970",
            "until": "1995",
            "prefix_metric_with_slice_name": True,
            "show_legend": False,
            "x_axis_format": "%Y",
        }),
    )

    misc_dash_slices.add(slc.slice_name)
    merge_slice(slc)
def test_import_2_slices_for_same_table(self):
    """Two slices importing with a bogus ds_id both resolve to the real table."""
    table_id = self.get_table_by_name("wb_health_population").id
    # table_id != 666, import func will have to find the table
    first = self.create_slice("Import Me 1", ds_id=666, id=10002)
    first_id = Slice.import_obj(first, None)
    second = self.create_slice("Import Me 2", ds_id=666, id=10003)
    second_id = Slice.import_obj(second, None)

    for expected, imported_id in ((first, first_id), (second, second_id)):
        imported = self.get_slice(imported_id)
        self.assertEqual(table_id, imported.datasource_id)
        self.assert_slice_equals(expected, imported)
        self.assertEqual(imported.datasource.perm, imported.perm)
def create_slice(title: str, viz_type: str, table: SqlaTable,
                 slices_dict: Dict[str, str]) -> Slice:
    """Build a Slice over *table* with pretty-printed, key-sorted params."""
    serialized_params = json.dumps(slices_dict, indent=4, sort_keys=True)
    return Slice(
        slice_name=title,
        viz_type=viz_type,
        datasource_type=DatasourceType.TABLE,
        datasource_id=table.id,
        params=serialized_params,
    )
def test_import_1_slice(self):
    """An imported slice keeps its data and inherits the table's perm."""
    expected = self.create_slice("Import Me", id=10001)
    imported_id = Slice.import_obj(expected, None, import_time=1989)

    imported = self.get_slice(imported_id)
    self.assertEqual(imported.datasource.perm, imported.perm)
    self.assert_slice_equals(expected, imported)

    expected_table_id = self.get_table_by_name("wb_health_population").id
    self.assertEqual(expected_table_id, self.get_slice(imported_id).datasource_id)
def populate_dashboards(instance: Slice, dashboards: List[int]):
    """
    Mutates a Slice with the dashboards SQLA Models
    """
    session = current_app.appbuilder.get_session
    instance.dashboards = [
        session.query(Dashboard).filter_by(id=dashboard_id).one()
        for dashboard_id in dashboards
    ]
def test_saved_chart_is_admin(mocker: MockFixture, app_context: AppContext) -> None:
    """Admin users are granted access to any saved chart.

    Fixes flake8 E712: ``== True`` -> ``is True`` (assumes ``check_access``
    returns a literal bool, as the original ``== True`` comparison implies).
    """
    from superset.connectors.sqla.models import SqlaTable
    from superset.explore.form_data.utils import check_access
    from superset.models.slice import Slice

    mocker.patch(dataset_find_by_id, return_value=SqlaTable())
    mocker.patch(can_access_datasource, return_value=True)
    mocker.patch(is_user_admin, return_value=True)
    mocker.patch(chart_find_by_id, return_value=Slice())
    assert check_access(dataset_id=1, chart_id=1, actor=User()) is True
def create_slice(datasource_id: Optional[int], name: Optional[str],
                 owners: Optional[List[User]]) -> Slice:
    """Build a test ``Slice``, defaulting name/owners/datasource when None.

    Fixes: use explicit ``is None`` checks instead of ``or``-defaults, so
    falsy-but-valid caller values (e.g. an explicitly passed empty string
    name) are not silently replaced — consistent with the sibling
    ``create_slice`` helper that already uses ``is not None``.
    """
    name = name if name is not None else random_str()
    owners = owners if owners is not None else []
    if datasource_id is None:
        datasource_id = create_datasource_table_to_db(name=name + "_table").id
    return Slice(
        slice_name=name,
        datasource_id=datasource_id,
        owners=owners,
        datasource_type="table",
    )
def test_saved_chart_no_access(mocker: MockFixture, app_context: AppContext) -> None:
    """A non-admin, non-owner user without permissions is denied access.

    Fixes: mock setup was inside the ``with raises(...)`` block; only the
    call under test should be in the context so a failure during setup
    cannot be mistaken for (or hide) the expected exception.
    """
    from superset.connectors.sqla.models import SqlaTable
    from superset.explore.form_data.utils import check_access
    from superset.models.slice import Slice

    mocker.patch(dataset_find_by_id, return_value=SqlaTable())
    mocker.patch(can_access_datasource, return_value=True)
    mocker.patch(is_user_admin, return_value=False)
    mocker.patch(is_owner, return_value=False)
    mocker.patch(can_access, return_value=False)
    mocker.patch(chart_find_by_id, return_value=Slice())
    with raises(ChartAccessDeniedError):
        check_access(dataset_id=1, chart_id=1, actor=User())
def test_saved_chart_is_admin(mocker: MockFixture) -> None:
    """Admins may access any saved chart regardless of ownership."""
    from superset.connectors.sqla.models import SqlaTable
    from superset.explore.utils import check_access as check_chart_access
    from superset.models.slice import Slice

    mocker.patch(dataset_find_by_id, return_value=SqlaTable())
    mocker.patch(can_access_datasource, return_value=True)
    mocker.patch(is_admin, return_value=True)
    mocker.patch(chart_find_by_id, return_value=Slice())

    with override_user(User()):
        check_chart_access(
            chart_id=1,
            datasource_id=1,
            datasource_type=DatasourceType.TABLE,
        )
def test_saved_chart_is_owner(mocker: MockFixture, app_context: AppContext) -> None:
    """A non-admin owner of the chart is granted access."""
    from superset.connectors.sqla.models import SqlaTable
    from superset.explore.utils import check_access as check_chart_access
    from superset.models.slice import Slice

    mocker.patch(dataset_find_by_id, return_value=SqlaTable())
    mocker.patch(can_access_datasource, return_value=True)
    mocker.patch(is_user_admin, return_value=False)
    mocker.patch(is_owner, return_value=True)
    mocker.patch(chart_find_by_id, return_value=Slice())

    check_chart_access(
        chart_id=1,
        datasource_id=1,
        actor=User(),
        datasource_type=DatasourceType.TABLE,
    )
def import_chart(
    session: Session, config: Dict[str, Any], overwrite: bool = False
) -> Slice:
    """Import a chart from its config dict, matching existing charts by UUID.

    Returns the existing chart untouched unless ``overwrite`` is set.
    """
    existing = session.query(Slice).filter_by(uuid=config["uuid"]).first()
    if existing and not overwrite:
        return existing
    if existing:
        config["id"] = existing.id

    # TODO (betodealmeida): move this logic to import_from_dict
    config["params"] = json.dumps(config["params"])

    chart = Slice.import_from_dict(session, config, recursive=False)
    if chart.id is None:
        session.flush()
    return chart
def session_with_data(session: Session) -> Iterator[Session]:
    """Fixture: yield a session seeded with one Slice; roll back afterwards."""
    from superset.models.slice import Slice

    engine = session.get_bind()
    Slice.metadata.create_all(engine)  # pylint: disable=no-member

    session.add(
        Slice(
            id=1,
            datasource_id=1,
            datasource_type=DatasourceType.TABLE,
            datasource_name="tmp_perm_table",
            slice_name="slice_name",
        )
    )
    session.commit()

    yield session
    session.rollback()
def import_chart(
    slc_to_import: Slice,
    slc_to_override: Optional[Slice],
    import_time: Optional[int] = None,
) -> int:
    """Inserts or overrides slc in the database.

    remote_id and import_time fields in params_dict are set to track the
    slice origin and ensure correct overrides for multiple imports.
    Slice.perm is used to find the datasources and connect them.

    :param Slice slc_to_import: Slice object to import
    :param Slice slc_to_override: Slice to replace, id matches remote_id
    :returns: The resulting id for the imported slice
    :rtype: int
    """
    session = db.session
    # Detach the incoming object from any SQLAlchemy session/identity map so
    # it can be copied and re-added without conflicts.
    make_transient(slc_to_import)
    slc_to_import.dashboards = []
    # Record origin id and import time in params so later imports can match
    # this slice (see docstring).
    slc_to_import.alter_params(remote_id=slc_to_import.id, import_time=import_time)

    slc_to_import = slc_to_import.copy()
    slc_to_import.reset_ownership()
    params = slc_to_import.params_dict
    # Resolve the local datasource from (type, name, schema, database) carried
    # in the imported params.
    datasource = ConnectorRegistry.get_datasource_by_name(
        session,
        slc_to_import.datasource_type,
        params["datasource_name"],
        params["schema"],
        params["database_name"],
    )
    # NOTE(review): no None-check — a missing datasource will raise
    # AttributeError here; presumably relied upon by callers/tests. Confirm.
    slc_to_import.datasource_id = datasource.id  # type: ignore
    if slc_to_override:
        # Overwrite the existing slice in place and keep its id.
        slc_to_override.override(slc_to_import)
        session.flush()
        return slc_to_override.id
    session.add(slc_to_import)
    logger.info("Final slice: %s", str(slc_to_import.to_json()))
    session.flush()
    return slc_to_import.id
def decode_dashboards(o: Dict[str, Any]) -> Any:
    """
    Function to be passed into json.loads obj_hook parameter
    Recreates the dashboard object from a json representation.
    """
    # Marker key -> model constructor; checked in insertion order.
    constructors = {
        "__Dashboard__": Dashboard,
        "__Slice__": Slice,
        "__TableColumn__": TableColumn,
        "__SqlaTable__": SqlaTable,
        "__SqlMetric__": SqlMetric,
    }
    for marker, model_cls in constructors.items():
        if marker in o:
            return model_cls(**o[marker])
    if "__datetime__" in o:
        return datetime.strptime(o["__datetime__"], "%Y-%m-%dT%H:%M:%S")
    return o
def create_slice(self):
    """Fixture: yield a committed Slice bound to the ``dummy_sql_table``
    dataset, then delete it on teardown.

    Fix: local variable renamed from ``slice`` to ``slc`` so the builtin is
    no longer shadowed.
    """
    with self.create_app().app_context():
        session = db.session
        dataset = (
            session.query(SqlaTable).filter_by(table_name="dummy_sql_table").first()
        )
        slc = Slice(
            datasource_id=dataset.id,
            datasource_type=DatasourceType.TABLE,
            datasource_name="tmp_perm_table",
            slice_name="slice_name",
        )
        session.add(slc)
        session.commit()

        yield slc

        # rollback
        session.delete(slc)
        session.commit()
def decode_dashboards(o):
    """
    Function to be passed into json.loads obj_hook parameter
    Recreates the dashboard object from a json representation.
    """
    # NOTE(review): removed the unused local `import superset.models.core as
    # models` — `models` was never referenced; all names used here
    # (Dashboard, Slice, ...) must already be in scope at module level.
    # Confirm the import had no intentional side effect before merging.
    if "__Dashboard__" in o:
        return Dashboard(**o["__Dashboard__"])
    if "__Slice__" in o:
        return Slice(**o["__Slice__"])
    if "__TableColumn__" in o:
        return TableColumn(**o["__TableColumn__"])
    if "__SqlaTable__" in o:
        return SqlaTable(**o["__SqlaTable__"])
    if "__SqlMetric__" in o:
        return SqlMetric(**o["__SqlMetric__"])
    if "__datetime__" in o:
        return datetime.strptime(o["__datetime__"], "%Y-%m-%dT%H:%M:%S")
    return o
def insert_chart(
    self,
    slice_name: str,
    owners: List[int],
    datasource_id: int,
    created_by=None,
    datasource_type: str = "table",
    description: Optional[str] = None,
    viz_type: Optional[str] = None,
    params: Optional[str] = None,
    cache_timeout: Optional[int] = None,
    certified_by: Optional[str] = None,
    certification_details: Optional[str] = None,
) -> Slice:
    """Insert a chart bound to the SqlaTable ``datasource_id`` and return it.

    :param owners: user ids to resolve into owner objects

    Fixes: stop shadowing the builtins ``slice`` and ``list``; build the
    owners list with a comprehension.
    """
    # Resolve owner ids to user model instances.
    obj_owners = [
        db.session.query(security_manager.user_model).get(owner)
        for owner in owners
    ]
    datasource = (
        db.session.query(SqlaTable).filter_by(id=datasource_id).one_or_none()
    )
    slc = Slice(
        cache_timeout=cache_timeout,
        certified_by=certified_by,
        certification_details=certification_details,
        created_by=created_by,
        datasource_id=datasource.id,
        datasource_name=datasource.name,
        datasource_type=datasource.type,
        description=description,
        owners=obj_owners,
        params=params,
        slice_name=slice_name,
        viz_type=viz_type,
    )
    db.session.add(slc)
    db.session.commit()
    return slc