def import_obj(cls, slc_to_import, import_time=None):
    """Inserts or overrides slc in the database.

    remote_id and import_time fields in params_dict are set to track the
    slice origin and ensure correct overrides for multiple imports.
    Slice.perm is used to find the datasources and connect them.
    """
    session = db.session
    make_transient(slc_to_import)
    slc_to_import.dashboards = []
    slc_to_import.alter_params(
        remote_id=slc_to_import.id, import_time=import_time)

    # find if the slice was already imported
    slc_to_override = None
    for slc in session.query(Slice).all():
        if ('remote_id' in slc.params_dict and
                slc.params_dict['remote_id'] == slc_to_import.id):
            slc_to_override = slc

    slc_to_import = slc_to_import.copy()
    params = slc_to_import.params_dict
    slc_to_import.datasource_id = ConnectorRegistry.get_datasource_by_name(
        session, slc_to_import.datasource_type, params['datasource_name'],
        params['schema'], params['database_name']).id

    if slc_to_override:
        slc_to_override.override(slc_to_import)
        session.flush()
        return slc_to_override.id

    session.add(slc_to_import)
    logging.info('Final slice: {}'.format(slc_to_import.to_json()))
    session.flush()
    return slc_to_import.id

def create_missing_perms():
    """Creates missing perms for datasources, schemas and metrics"""
    logging.info(
        "Fetching a set of all perms to lookup which ones are missing")
    all_pvs = set()
    for pv in sm.get_session.query(sm.permissionview_model).all():
        if pv.permission and pv.view_menu:
            all_pvs.add((pv.permission.name, pv.view_menu.name))

    def merge_pv(view_menu, perm):
        """Create permission view menu only if it doesn't exist"""
        if view_menu and perm and (view_menu, perm) not in all_pvs:
            merge_perm(sm, view_menu, perm)

    logging.info("Creating missing datasource permissions.")
    datasources = ConnectorRegistry.get_all_datasources(db.session)
    for datasource in datasources:
        merge_pv('datasource_access', datasource.get_perm())
        merge_pv('schema_access', datasource.schema_perm)

    logging.info("Creating missing database permissions.")
    databases = db.session.query(models.Database).all()
    for database in databases:
        merge_pv('database_access', database.perm)

    logging.info("Creating missing metrics permissions")
    metrics = []
    for datasource_class in ConnectorRegistry.sources.values():
        metrics += list(db.session.query(datasource_class.metric_class).all())

    for metric in metrics:
        if metric.is_restricted:
            merge_pv('metric_access', metric.perm)

def import_obj(cls, slc_to_import, slc_to_override, import_time=None):
    """Inserts or overrides slc in the database.

    remote_id and import_time fields in params_dict are set to track the
    slice origin and ensure correct overrides for multiple imports.
    Slice.perm is used to find the datasources and connect them.

    :param Slice slc_to_import: Slice object to import
    :param Slice slc_to_override: Slice to replace, id matches remote_id
    :returns: The resulting id for the imported slice
    :rtype: int
    """
    session = db.session
    make_transient(slc_to_import)
    slc_to_import.dashboards = []
    slc_to_import.alter_params(
        remote_id=slc_to_import.id, import_time=import_time)

    slc_to_import = slc_to_import.copy()
    params = slc_to_import.params_dict
    slc_to_import.datasource_id = ConnectorRegistry.get_datasource_by_name(
        session, slc_to_import.datasource_type, params['datasource_name'],
        params['schema'], params['database_name']).id

    if slc_to_override:
        slc_to_override.override(slc_to_import)
        session.flush()
        return slc_to_override.id

    session.add(slc_to_import)
    logging.info('Final slice: {}'.format(slc_to_import.to_json()))
    session.flush()
    return slc_to_import.id

def create_missing_datasource_perms(view_menu_set):
    logging.info("Creating missing datasource permissions.")
    datasources = ConnectorRegistry.get_all_datasources(db.session)
    for datasource in datasources:
        if datasource and datasource.perm not in view_menu_set:
            merge_perm(sm, 'datasource_access', datasource.get_perm())
            if datasource.schema_perm:
                merge_perm(sm, 'schema_access', datasource.schema_perm)

def __init__(
        self,
        datasource: Dict,
        queries: List[Dict],
):
    self.datasource = ConnectorRegistry.get_datasource(
        datasource.get('type'), int(datasource.get('id')), db.session)
    self.queries = list(map(lambda query_obj: QueryObject(**query_obj), queries))

def datasource_access_by_name(
        self, database, datasource_name, schema=None):
    if self.database_access(database) or self.all_datasource_access():
        return True

    schema_perm = utils.get_schema_perm(database, schema)
    if schema and self.can_access('schema_access', schema_perm):
        return True

    datasources = ConnectorRegistry.query_datasources_by_name(
        db.session, database, datasource_name, schema=schema)
    for datasource in datasources:
        if self.can_access('datasource_access', datasource.perm):
            return True
    return False

def __init__(
        self,
        datasource: Dict,
        queries: List[Dict],
        force: bool = False,
        custom_cache_timeout: int = None,
):
    self.datasource = ConnectorRegistry.get_datasource(
        datasource.get('type'), int(datasource.get('id')), db.session)
    self.queries = list(map(lambda query_obj: QueryObject(**query_obj), queries))
    self.force = force
    self.custom_cache_timeout = custom_cache_timeout
    self.enforce_numerical_metrics = True

def save(self):
    datasource = json.loads(request.form.get('data'))
    datasource_id = datasource.get('id')
    datasource_type = datasource.get('type')
    orm_datasource = ConnectorRegistry.get_datasource(
        datasource_type, datasource_id, db.session)

    if not check_ownership(orm_datasource, raise_if_false=False):
        return json_error_response(
            __(
                'You are not authorized to modify '
                'this data source configuration'),
            status='401',
        )
    orm_datasource.update_from_object(datasource)
    data = orm_datasource.data
    db.session.commit()
    return self.json_response(data)

def accessible_by_user(self, database, datasource_names, schema=None):
    if self.database_access(database) or self.all_datasource_access():
        return datasource_names

    if schema:
        schema_perm = utils.get_schema_perm(database, schema)
        if self.can_access('schema_access', schema_perm):
            return datasource_names

    user_perms = self.user_datasource_perms()
    user_datasources = ConnectorRegistry.query_datasources_by_permissions(
        db.session, database, user_perms)
    if schema:
        names = {
            d.table_name for d in user_datasources if d.schema == schema}
        return [d for d in datasource_names if d in names]
    else:
        full_names = {d.full_name for d in user_datasources}
        return [d for d in datasource_names if d in full_names]

def import_obj(
    cls,
    slc_to_import: "Slice",
    slc_to_override: Optional["Slice"],
    import_time: Optional[int] = None,
) -> int:
    """Inserts or overrides slc in the database.

    remote_id and import_time fields in params_dict are set to track the
    slice origin and ensure correct overrides for multiple imports.
    Slice.perm is used to find the datasources and connect them.

    :param Slice slc_to_import: Slice object to import
    :param Slice slc_to_override: Slice to replace, id matches remote_id
    :returns: The resulting id for the imported slice
    :rtype: int
    """
    session = db.session
    make_transient(slc_to_import)
    slc_to_import.dashboards = []
    slc_to_import.alter_params(remote_id=slc_to_import.id, import_time=import_time)

    slc_to_import = slc_to_import.copy()
    slc_to_import.reset_ownership()
    params = slc_to_import.params_dict
    slc_to_import.datasource_id = ConnectorRegistry.get_datasource_by_name(  # type: ignore
        session,
        slc_to_import.datasource_type,
        params["datasource_name"],
        params["schema"],
        params["database_name"],
    ).id

    if slc_to_override:
        slc_to_override.override(slc_to_import)
        session.flush()
        return slc_to_override.id

    session.add(slc_to_import)
    logging.info("Final slice: {}".format(slc_to_import.to_json()))
    session.flush()
    return slc_to_import.id

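# Illustrative sketch (hypothetical values): after the alter_params call in
# import_obj above, the slice's params_dict carries the provenance keys that
# later imports match against when deciding whether to override a slice.
example_params = {
    "datasource_name": "birth_names",  # export-side names, assumed values
    "schema": "public",
    "database_name": "examples",
    "remote_id": 42,                   # id of the slice on the source system
    "import_time": 1580000000,         # epoch seconds passed to import_obj
}
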
def export_dashboards(cls, dashboard_ids):
    copied_dashboards = []
    datasource_ids = set()
    for dashboard_id in dashboard_ids:
        # make sure that dashboard_id is an integer
        dashboard_id = int(dashboard_id)
        copied_dashboard = (
            db.session.query(Dashboard)
            .options(subqueryload(Dashboard.slices))
            .filter_by(id=dashboard_id)
            .first()
        )
        make_transient(copied_dashboard)
        for slc in copied_dashboard.slices:
            datasource_ids.add((slc.datasource_id, slc.datasource_type))
            # add extra params for the import
            slc.alter_params(
                remote_id=slc.id,
                datasource_name=slc.datasource.name,
                schema=slc.datasource.schema,
                database_name=slc.datasource.database.name,
            )
        copied_dashboard.alter_params(remote_id=dashboard_id)
        copied_dashboards.append(copied_dashboard)

    eager_datasources = []
    for datasource_id, datasource_type in datasource_ids:
        eager_datasource = ConnectorRegistry.get_eager_datasource(
            db.session, datasource_type, datasource_id)
        eager_datasource.alter_params(
            remote_id=eager_datasource.id,
            database_name=eager_datasource.database.name,
        )
        make_transient(eager_datasource)
        eager_datasources.append(eager_datasource)

    return json.dumps(
        {
            "dashboards": copied_dashboards,
            "datasources": eager_datasources,
        },
        cls=utils.DashboardEncoder,
        indent=4,
    )

def get_datasources_accessible_by_user(
    self,
    database: "Database",
    datasource_names: List[DatasourceName],
    schema: Optional[str] = None,
) -> List[DatasourceName]:
    """
    Return the list of SQL tables accessible by the user.

    :param database: The SQL database
    :param datasource_names: The list of eligible SQL tables w/ schema
    :param schema: The fallback SQL schema if not present in the table name
    :returns: The list of accessible SQL tables w/ schema
    """
    from superset import db

    if self.database_access(database) or self.all_datasource_access():
        return datasource_names

    if schema:
        schema_perm = self.get_schema_perm(database, schema)
        if schema_perm and self.can_access("schema_access", schema_perm):
            return datasource_names

    user_perms = self.user_view_menu_names("datasource_access")
    schema_perms = self.user_view_menu_names("schema_access")
    user_datasources = ConnectorRegistry.query_datasources_by_permissions(
        db.session, database, user_perms, schema_perms)
    if schema:
        names = {d.table_name for d in user_datasources if d.schema == schema}
        return [d for d in datasource_names if d in names]
    else:
        full_names = {d.full_name for d in user_datasources}
        return [
            d for d in datasource_names if f"[{database}].[{d}]" in full_names
        ]

def accessible_by_user(self, database, datasource_names, schema=None):
    if self.database_access(database) or self.all_datasource_access():
        return datasource_names

    schema_perm = utils.get_schema_perm(database, schema)
    if schema and utils.can_access(sm, 'schema_access', schema_perm, g.user):
        return datasource_names

    role_ids = set([role.id for role in g.user.roles])
    # TODO: cache user_perms or user_datasources
    user_pvms = (
        db.session.query(ab_models.PermissionView)
        .join(ab_models.Permission)
        .filter(ab_models.Permission.name == 'datasource_access')
        .filter(ab_models.PermissionView.role.any(
            ab_models.Role.id.in_(role_ids)))
        .all()
    )
    user_perms = set([pvm.view_menu.name for pvm in user_pvms])
    user_datasources = ConnectorRegistry.query_datasources_by_permissions(
        db.session, database, user_perms)
    full_names = set([d.full_name for d in user_datasources])
    return [d for d in datasource_names if d in full_names]

def accessible_by_user(self, database, datasource_names, schema=None):
    if self.database_access(database) or self.all_datasource_access():
        return datasource_names

    if schema:
        schema_perm = utils.get_schema_perm(database, schema)
        if self.can_access('schema_access', schema_perm):
            return datasource_names

    user_perms = self.user_datasource_perms()
    user_datasources = ConnectorRegistry.query_datasources_by_permissions(
        db.session, database, user_perms)
    if schema:
        names = {
            d.table_name for d in user_datasources if d.schema == schema
        }
        return [d for d in datasource_names if d in names]
    else:
        full_names = {d.full_name for d in user_datasources}
        return [d for d in datasource_names if d in full_names]

def get_viz(
        slice_id=None,
        form_data=None,
        datasource_type=None,
        datasource_id=None,
        force=False,
):
    if slice_id:
        slc = db.session.query(models.Slice).filter_by(id=slice_id).one()
        return slc.get_viz()
    else:
        viz_type = form_data.get('viz_type', 'table')
        datasource = ConnectorRegistry.get_datasource(
            datasource_type, datasource_id, db.session)
        viz_obj = viz.viz_types[viz_type](
            datasource,
            form_data=form_data,
            force=force,
        )
        return viz_obj

def test_get_user_datasources_gamma(self):
    Datasource = namedtuple("Datasource", ["database", "schema", "name"])

    mock_session = mock.MagicMock()
    mock_session.query.return_value.filter.return_value.all.return_value = []

    with mock.patch("superset.security_manager") as mock_security_manager:
        mock_security_manager.can_access_database.return_value = False

        with mock.patch.object(
            ConnectorRegistry, "get_all_datasources"
        ) as mock_get_all_datasources:
            mock_get_all_datasources.return_value = [
                Datasource("database1", "schema1", "table1"),
                Datasource("database1", "schema1", "table2"),
                Datasource("database2", None, "table1"),
            ]

            datasources = ConnectorRegistry.get_user_datasources(mock_session)

    assert datasources == []

def create_missing_perms(self):
    """Creates missing perms for datasources, schemas and metrics"""
    from superset import db
    from superset.models import core as models

    logging.info(
        'Fetching a set of all perms to lookup which ones are missing')
    all_pvs = set()
    for pv in self.get_session.query(self.permissionview_model).all():
        if pv.permission and pv.view_menu:
            all_pvs.add((pv.permission.name, pv.view_menu.name))

    def merge_pv(view_menu, perm):
        """Create permission view menu only if it doesn't exist"""
        if view_menu and perm and (view_menu, perm) not in all_pvs:
            self.merge_perm(view_menu, perm)

    logging.info('Creating missing datasource permissions.')
    datasources = ConnectorRegistry.get_all_datasources(db.session)
    for datasource in datasources:
        merge_pv('datasource_access', datasource.get_perm())
        merge_pv('schema_access', datasource.schema_perm)
        logging.info('see if we can add a schema perm here')
        merge_pv('pbSchema_access', datasource.schema_perm)

    logging.info('Creating missing database permissions.')
    databases = db.session.query(models.Database).all()
    for database in databases:
        merge_pv('database_access', database.perm)

    logging.info('Creating missing metrics permissions')
    metrics = []
    for datasource_class in ConnectorRegistry.sources.values():
        metrics += list(
            db.session.query(datasource_class.metric_class).all())

    for metric in metrics:
        if metric.is_restricted:
            merge_pv('metric_access', metric.perm)

def save(self):
    datasource = json.loads(request.form.get('data'))
    datasource_id = datasource.get('id')
    datasource_type = datasource.get('type')
    orm_datasource = ConnectorRegistry.get_datasource(
        datasource_type, datasource_id, db.session)

    if not check_ownership(orm_datasource, raise_if_false=False):
        return json_error_response(
            __(
                'You are not authorized to modify '
                'this data source configuration'),
            status='401',
        )

    if 'owners' in datasource:
        datasource['owners'] = (
            db.session.query(orm_datasource.owner_class)
            .filter(orm_datasource.owner_class.id.in_(datasource['owners']))
            .all()
        )
    orm_datasource.update_from_object(datasource)
    data = orm_datasource.data
    db.session.commit()
    return self.json_response(data)

def export_dashboards(cls, dashboard_ids):
    copied_dashboards = []
    datasource_ids = set()
    for dashboard_id in dashboard_ids:
        # make sure that dashboard_id is an integer
        dashboard_id = int(dashboard_id)
        copied_dashboard = (
            db.session.query(Dashboard)
            .options(subqueryload(Dashboard.slices))
            .filter_by(id=dashboard_id).first()
        )
        make_transient(copied_dashboard)
        for slc in copied_dashboard.slices:
            datasource_ids.add((slc.datasource_id, slc.datasource_type))
            # add extra params for the import
            slc.alter_params(
                remote_id=slc.id,
                datasource_name=slc.datasource.name,
                schema=slc.datasource.schema,
                database_name=slc.datasource.database.name,
            )
        copied_dashboard.alter_params(remote_id=dashboard_id)
        copied_dashboards.append(copied_dashboard)

    eager_datasources = []
    for datasource_id, datasource_type in datasource_ids:
        eager_datasource = ConnectorRegistry.get_eager_datasource(
            db.session, datasource_type, datasource_id)
        eager_datasource.alter_params(
            remote_id=eager_datasource.id,
            database_name=eager_datasource.database.name,
        )
        make_transient(eager_datasource)
        eager_datasources.append(eager_datasource)

    return json.dumps({
        'dashboards': copied_dashboards,
        'datasources': eager_datasources,
    }, cls=utils.DashboardEncoder, indent=4)

def create_missing_perms(self) -> None:
    """
    Creates missing FAB permissions for datasources, schemas and metrics.
    """
    from superset import db
    from superset.connectors.base.models import BaseMetric
    from superset.models import core as models

    logging.info("Fetching a set of all perms to lookup which ones are missing")
    all_pvs = set()
    for pv in self.get_session.query(self.permissionview_model).all():
        if pv.permission and pv.view_menu:
            all_pvs.add((pv.permission.name, pv.view_menu.name))

    def merge_pv(view_menu, perm):
        """Create permission view menu only if it doesn't exist"""
        if view_menu and perm and (view_menu, perm) not in all_pvs:
            self.add_permission_view_menu(view_menu, perm)

    logging.info("Creating missing datasource permissions.")
    datasources = ConnectorRegistry.get_all_datasources(db.session)
    for datasource in datasources:
        merge_pv("datasource_access", datasource.get_perm())
        merge_pv("schema_access", datasource.schema_perm)

    logging.info("Creating missing database permissions.")
    databases = db.session.query(models.Database).all()
    for database in databases:
        merge_pv("database_access", database.perm)

    logging.info("Creating missing metrics permissions")
    metrics: List[BaseMetric] = []
    for datasource_class in ConnectorRegistry.sources.values():
        metrics += list(db.session.query(datasource_class.metric_class).all())

    for metric in metrics:
        if metric.is_restricted:
            merge_pv("metric_access", metric.perm)

def __init__(  # pylint: disable=too-many-arguments
    self,
    datasource: DatasourceDict,
    queries: List[Dict[str, Any]],
    force: bool = False,
    custom_cache_timeout: Optional[int] = None,
    result_type: Optional[ChartDataResultType] = None,
    result_format: Optional[ChartDataResultFormat] = None,
) -> None:
    self.datasource = ConnectorRegistry.get_datasource(
        str(datasource["type"]), int(datasource["id"]), db.session)
    self.queries = [QueryObject(**query_obj) for query_obj in queries]
    self.force = force
    self.custom_cache_timeout = custom_cache_timeout
    self.result_type = result_type or ChartDataResultType.FULL
    self.result_format = result_format or ChartDataResultFormat.JSON
    self.cache_values = {
        "datasource": datasource,
        "queries": queries,
        "result_type": self.result_type,
        "result_format": self.result_format,
    }

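# A minimal usage sketch for the constructor above (assuming the enclosing
# class is QueryContext; the datasource id and query fields are hypothetical
# and depend on the deployment). It shows the expected input shape: a
# DatasourceDict plus a list of QueryObject keyword dicts.
query_context = QueryContext(
    datasource={"type": "table", "id": 1},
    queries=[{"metrics": ["count"], "filters": []}],
    result_type=ChartDataResultType.FULL,
    result_format=ChartDataResultFormat.JSON,
)
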
def external_metadata(
    self, datasource_type: str, datasource_id: int
) -> FlaskResponse:
    """Gets column info from the source system"""
    if datasource_type == "druid":
        datasource = ConnectorRegistry.get_datasource(
            datasource_type, datasource_id, db.session
        )
    elif datasource_type == "table":
        database = (
            db.session.query(Database).filter_by(id=request.args.get("db_id")).one()
        )
        table_class = ConnectorRegistry.sources["table"]
        datasource = table_class(
            database=database,
            table_name=request.args.get("table_name"),
            schema=request.args.get("schema") or None,
        )
    else:
        raise Exception(f"Unsupported datasource_type: {datasource_type}")
    external_metadata = datasource.external_metadata()
    return self.json_response(external_metadata)

def get_datasources_accessible_by_user(
        self,
        database,
        datasource_names: List[DatasourceName],
        schema: str = None,
) -> List[DatasourceName]:
    from superset import db

    if self.database_access(database) or self.all_datasource_access():
        return datasource_names

    if schema:
        schema_perm = self.get_schema_perm(database, schema)
        if self.can_access('schema_access', schema_perm):
            return datasource_names

    user_perms = self.user_datasource_perms()
    user_datasources = ConnectorRegistry.query_datasources_by_permissions(
        db.session, database, user_perms)
    if schema:
        names = {
            d.table_name for d in user_datasources if d.schema == schema}
        return [d for d in datasource_names if d in names]
    else:
        full_names = {d.full_name for d in user_datasources}
        return [d for d in datasource_names if d in full_names]

def save(self) -> FlaskResponse:
    data = request.form.get("data")
    if not isinstance(data, str):
        return json_error_response("Request missing data field.", status=500)

    datasource_dict = json.loads(data)
    datasource_id = datasource_dict.get("id")
    datasource_type = datasource_dict.get("type")
    database_id = datasource_dict["database"].get("id")
    orm_datasource = ConnectorRegistry.get_datasource(
        datasource_type, datasource_id, db.session
    )
    orm_datasource.database_id = database_id

    if "owners" in datasource_dict and orm_datasource.owner_class is not None:
        datasource_dict["owners"] = (
            db.session.query(orm_datasource.owner_class)
            .filter(orm_datasource.owner_class.id.in_(datasource_dict["owners"]))
            .all()
        )

    duplicates = [
        name
        for name, count in Counter(
            [col["column_name"] for col in datasource_dict["columns"]]
        ).items()
        if count > 1
    ]
    if duplicates:
        return json_error_response(
            f"Duplicate column name(s): {','.join(duplicates)}", status=409
        )
    orm_datasource.update_from_object(datasource_dict)
    data = orm_datasource.data
    db.session.commit()
    return self.json_response(data)

def accessible_by_user(self, database, datasource_names, schema=None):
    if self.database_access(database) or self.all_datasource_access():
        return datasource_names

    schema_perm = utils.get_schema_perm(database, schema)
    if schema and utils.can_access(
            sm, 'schema_access', schema_perm, g.user):
        return datasource_names

    role_ids = set([role.id for role in g.user.roles])
    # TODO: cache user_perms or user_datasources
    user_pvms = (
        db.session.query(ab_models.PermissionView)
        .join(ab_models.Permission)
        .filter(ab_models.Permission.name == 'datasource_access')
        .filter(ab_models.PermissionView.role.any(
            ab_models.Role.id.in_(role_ids)))
        .all()
    )
    user_perms = set([pvm.view_menu.name for pvm in user_pvms])
    user_datasources = ConnectorRegistry.query_datasources_by_permissions(
        db.session, database, user_perms)
    full_names = set([d.full_name for d in user_datasources])
    return [d for d in datasource_names if d in full_names]

def test_query_cache_key_changes_when_metric_is_updated(self):
    self.login(username="******")
    payload = get_query_context("birth_names")

    # make temporary change and revert it to refresh the changed_on property
    datasource = ConnectorRegistry.get_datasource(
        datasource_type=payload["datasource"]["type"],
        datasource_id=payload["datasource"]["id"],
        session=db.session,
    )

    datasource.metrics.append(SqlMetric(metric_name="foo", expression="select 1;"))
    db.session.commit()

    # construct baseline query_cache_key
    query_context = ChartDataQueryContextSchema().load(payload)
    query_object = query_context.queries[0]
    cache_key_original = query_context.query_cache_key(query_object)

    # wait a second since mysql records timestamps in second granularity
    time.sleep(1)

    datasource.metrics[0].expression = "select 2;"
    db.session.commit()

    # create new QueryContext with unchanged attributes, extract new query_cache_key
    query_context = ChartDataQueryContextSchema().load(payload)
    query_object = query_context.queries[0]
    cache_key_new = query_context.query_cache_key(query_object)

    datasource.metrics = []
    db.session.commit()

    # the new cache_key should be different due to updated datasource
    self.assertNotEqual(cache_key_original, cache_key_new)

def get_user_datasources(self) -> List["BaseDatasource"]:
    """
    Collect datasources which the user has explicit permissions to.

    :returns: The list of datasources
    """
    user_perms = self.user_view_menu_names("datasource_access")
    schema_perms = self.user_view_menu_names("schema_access")
    user_datasources = set()
    for datasource_class in ConnectorRegistry.sources.values():
        user_datasources.update(
            self.get_session.query(datasource_class)
            .filter(
                or_(
                    datasource_class.perm.in_(user_perms),
                    datasource_class.schema_perm.in_(schema_perms),
                )
            )
            .all()
        )

    # group all datasources by database
    all_datasources = ConnectorRegistry.get_all_datasources(self.get_session)
    datasources_by_database: Dict["Database", Set["BaseDatasource"]] = defaultdict(
        set
    )
    for datasource in all_datasources:
        datasources_by_database[datasource.database].add(datasource)

    # add datasources with implicit permission (eg, database access)
    for database, datasources in datasources_by_database.items():
        if self.can_access_database(database):
            user_datasources.update(datasources)

    return list(user_datasources)

def get_viz(
        slice_id=None,
        form_data=None,
        datasource_type=None,
        datasource_id=None,
        force=False,
):
    if slice_id:
        slc = (
            db.session.query(models.Slice)
            .filter_by(id=slice_id)
            .one()
        )
        return slc.get_viz()
    else:
        viz_type = form_data.get('viz_type', 'table')
        datasource = ConnectorRegistry.get_datasource(
            datasource_type, datasource_id, db.session)
        viz_obj = viz.viz_types[viz_type](
            datasource,
            form_data=form_data,
            force=force,
        )
        return viz_obj

def configure_data_sources(self) -> None:
    # Registering sources
    module_datasource_map = self.config["DEFAULT_MODULE_DS_MAP"]
    module_datasource_map.update(self.config["ADDITIONAL_MODULE_DS_MAP"])
    ConnectorRegistry.register_sources(module_datasource_map)

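# For reference, the mapping registered above goes from module path to a list
# of datasource model class names. A sketch of the usual defaults from
# superset/config.py; the exact entries vary by Superset version, so treat
# them as assumptions rather than a definitive listing.
from collections import OrderedDict
from typing import Dict, List

DEFAULT_MODULE_DS_MAP = OrderedDict(
    [
        ("superset.connectors.sqla.models", ["SqlaTable"]),
        ("superset.connectors.druid.models", ["DruidDatasource"]),
    ]
)
ADDITIONAL_MODULE_DS_MAP: Dict[str, List[str]] = {}
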
        return GET_FEATURE_FLAGS_FUNC(deepcopy(_feature_flags))
    return _feature_flags


def is_feature_enabled(feature):
    """Utility function for checking whether a feature is turned on"""
    return get_feature_flags().get(feature)


# Flask-Compress
if conf.get("ENABLE_FLASK_COMPRESS"):
    Compress(app)

talisman = Talisman()

if app.config["TALISMAN_ENABLED"]:
    talisman.init_app(app, **app.config["TALISMAN_CONFIG"])

# Hook that provides administrators a handle on the Flask APP
# after initialization
flask_app_mutator = app.config.get("FLASK_APP_MUTATOR")
if flask_app_mutator:
    flask_app_mutator(app)

from superset import views  # noqa

# Registering sources
module_datasource_map = app.config.get("DEFAULT_MODULE_DS_MAP")
module_datasource_map.update(app.config.get("ADDITIONAL_MODULE_DS_MAP"))
ConnectorRegistry.register_sources(module_datasource_map)

    pass

for middleware in app.config.get('ADDITIONAL_MIDDLEWARE'):
    app.wsgi_app = middleware(app.wsgi_app)


class MyIndexView(IndexView):
    @expose('/')
    def index(self):
        return redirect('/superset/welcome')


appbuilder = AppBuilder(
    app,
    db.session,
    base_template='superset/base.html',
    indexview=MyIndexView,
    security_manager_class=app.config.get("CUSTOM_SECURITY_MANAGER"))

sm = appbuilder.sm

get_session = appbuilder.get_session
results_backend = app.config.get("RESULTS_BACKEND")

# Registering sources
module_datasource_map = app.config.get("DEFAULT_MODULE_DS_MAP")
module_datasource_map.update(app.config.get("ADDITIONAL_MODULE_DS_MAP"))
ConnectorRegistry.register_sources(module_datasource_map)

from superset import views  # noqa

def get_datasource_by_id(datasource_id: int, datasource_type: str) -> BaseDatasource:
    try:
        return ConnectorRegistry.get_datasource(datasource_type, datasource_id)
    except (NoResultFound, KeyError):
        raise DatasourceNotFoundValidationError()

def create_query_object_factory() -> QueryObjectFactory:
    return QueryObjectFactory(config, ConnectorRegistry(), db.session)

def get(self, datasource_type: str, datasource_id: int) -> FlaskResponse:
    datasource = ConnectorRegistry.get_datasource(
        datasource_type, datasource_id, db.session)
    return self.json_response(datasource.data)

def _convert_to_model(self, datasource: DatasourceDict) -> BaseDatasource:
    return ConnectorRegistry.get_datasource(
        str(datasource["type"]), int(datasource["id"]), db.session)

def __init__(
    self,
    datasource: Optional[DatasourceDict] = None,
    result_type: Optional[ChartDataResultType] = None,
    annotation_layers: Optional[List[Dict[str, Any]]] = None,
    applied_time_extras: Optional[Dict[str, str]] = None,
    apply_fetch_values_predicate: bool = False,
    granularity: Optional[str] = None,
    metrics: Optional[List[Union[Dict[str, Any], str]]] = None,
    groupby: Optional[List[str]] = None,
    filters: Optional[List[Dict[str, Any]]] = None,
    time_range: Optional[str] = None,
    time_shift: Optional[str] = None,
    is_timeseries: Optional[bool] = None,
    timeseries_limit: int = 0,
    row_limit: Optional[int] = None,
    row_offset: Optional[int] = None,
    timeseries_limit_metric: Optional[Metric] = None,
    order_desc: bool = True,
    extras: Optional[Dict[str, Any]] = None,
    columns: Optional[List[str]] = None,
    orderby: Optional[List[OrderBy]] = None,
    post_processing: Optional[List[Optional[Dict[str, Any]]]] = None,
    is_rowcount: bool = False,
    **kwargs: Any,
):
    columns = columns or []
    groupby = groupby or []
    extras = extras or {}
    annotation_layers = annotation_layers or []

    self.is_rowcount = is_rowcount
    self.datasource = None
    if datasource:
        self.datasource = ConnectorRegistry.get_datasource(
            str(datasource["type"]), int(datasource["id"]), db.session)
    self.result_type = result_type
    self.apply_fetch_values_predicate = apply_fetch_values_predicate or False
    self.annotation_layers = [
        layer
        for layer in annotation_layers
        # formula annotations don't affect the payload, hence can be dropped
        if layer["annotationType"] != "FORMULA"
    ]
    self.applied_time_extras = applied_time_extras or {}
    self.granularity = granularity
    self.from_dttm, self.to_dttm = get_since_until(
        relative_start=extras.get(
            "relative_start", config["DEFAULT_RELATIVE_START_TIME"]),
        relative_end=extras.get(
            "relative_end", config["DEFAULT_RELATIVE_END_TIME"]),
        time_range=time_range,
        time_shift=time_shift,
    )
    # is_timeseries is True if time column is in either columns or groupby
    # (both are dimensions)
    self.is_timeseries = (
        is_timeseries if is_timeseries is not None
        else DTTM_ALIAS in columns + groupby
    )
    self.time_range = time_range
    self.time_shift = parse_human_timedelta(time_shift)
    self.post_processing = [
        post_proc for post_proc in post_processing or [] if post_proc
    ]

    # Support metric reference/definition in the format of
    #   1. 'metric_name'   - name of predefined metric
    #   2. { label: 'label_name' }  - legacy format for a predefined metric
    #   3. { expressionType: 'SIMPLE' | 'SQL', ... } - adhoc metric
    self.metrics = metrics and [
        x if isinstance(x, str) or is_adhoc_metric(x) else x["label"]  # type: ignore
        for x in metrics
    ]
    self.row_limit = config["ROW_LIMIT"] if row_limit is None else row_limit
    self.row_offset = row_offset or 0
    self.filter = filters or []
    self.timeseries_limit = timeseries_limit
    self.timeseries_limit_metric = timeseries_limit_metric
    self.order_desc = order_desc
    self.extras = extras

    if config["SIP_15_ENABLED"]:
        self.extras["time_range_endpoints"] = get_time_range_endpoints(
            form_data=self.extras)

    self.columns = columns
    self.groupby = groupby or []
    self.orderby = orderby or []

    # rename deprecated fields
    for field in DEPRECATED_FIELDS:
        if field.old_name in kwargs:
            logger.warning(
                "The field `%s` is deprecated, please use `%s` instead.",
                field.old_name,
                field.new_name,
            )
            value = kwargs[field.old_name]
            if value:
                if hasattr(self, field.new_name):
                    logger.warning(
                        "The field `%s` is already populated, "
                        "replacing value with contents from `%s`.",
                        field.new_name,
                        field.old_name,
                    )
                setattr(self, field.new_name, value)

    # move deprecated extras fields to extras
    for field in DEPRECATED_EXTRAS_FIELDS:
        if field.old_name in kwargs:
            logger.warning(
                "The field `%s` is deprecated and should "
                "be passed to `extras` via the `%s` property.",
                field.old_name,
                field.new_name,
            )
            value = kwargs[field.old_name]
            if value:
                if hasattr(self.extras, field.new_name):
                    logger.warning(
                        "The field `%s` is already populated in "
                        "`extras`, replacing value with contents "
                        "from `%s`.",
                        field.new_name,
                        field.old_name,
                    )
                self.extras[field.new_name] = value

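# Illustrative metric inputs matching the three accepted formats listed in
# the comment above (values are hypothetical):
example_metrics = [
    "count",                # 1. name of a predefined metric
    {"label": "sum__num"},  # 2. legacy format for a predefined metric
    {                       # 3. adhoc metric defined inline
        "expressionType": "SQL",
        "sqlExpression": "COUNT(*)",
        "label": "cnt",
    },
]
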
def external_metadata(self, datasource_type=None, datasource_id=None):
    """Gets column info from the source system"""
    orm_datasource = ConnectorRegistry.get_datasource(
        datasource_type, datasource_id, db.session)
    return self.json_response(orm_datasource.external_metadata())

def fetch_all_datasources() -> List["BaseDatasource"]:
    return ConnectorRegistry.get_all_datasources(db.session)

def __init__(  # pylint: disable=too-many-arguments,too-many-locals
    self,
    query_context: "QueryContext",
    annotation_layers: Optional[List[Dict[str, Any]]] = None,
    applied_time_extras: Optional[Dict[str, str]] = None,
    apply_fetch_values_predicate: bool = False,
    columns: Optional[List[str]] = None,
    datasource: Optional[DatasourceDict] = None,
    extras: Optional[Dict[str, Any]] = None,
    filters: Optional[List[QueryObjectFilterClause]] = None,
    granularity: Optional[str] = None,
    is_rowcount: bool = False,
    is_timeseries: Optional[bool] = None,
    metrics: Optional[List[Metric]] = None,
    order_desc: bool = True,
    orderby: Optional[List[OrderBy]] = None,
    post_processing: Optional[List[Optional[Dict[str, Any]]]] = None,
    result_type: Optional[ChartDataResultType] = None,
    row_limit: Optional[int] = None,
    row_offset: Optional[int] = None,
    series_columns: Optional[List[str]] = None,
    series_limit: int = 0,
    series_limit_metric: Optional[Metric] = None,
    time_range: Optional[str] = None,
    time_shift: Optional[str] = None,
    **kwargs: Any,
):
    columns = columns or []
    extras = extras or {}
    annotation_layers = annotation_layers or []

    self.time_offsets = kwargs.get("time_offsets", [])
    self.inner_from_dttm = kwargs.get("inner_from_dttm")
    self.inner_to_dttm = kwargs.get("inner_to_dttm")

    if series_columns:
        self.series_columns = series_columns
    elif is_timeseries and metrics:
        self.series_columns = columns
    else:
        self.series_columns = []

    self.is_rowcount = is_rowcount
    self.datasource = None
    if datasource:
        self.datasource = ConnectorRegistry.get_datasource(
            str(datasource["type"]), int(datasource["id"]), db.session)
    self.result_type = result_type or query_context.result_type
    self.apply_fetch_values_predicate = apply_fetch_values_predicate or False
    self.annotation_layers = [
        layer
        for layer in annotation_layers
        # formula annotations don't affect the payload, hence can be dropped
        if layer["annotationType"] != "FORMULA"
    ]
    self.applied_time_extras = applied_time_extras or {}
    self.granularity = granularity
    self.from_dttm, self.to_dttm = get_since_until(
        relative_start=extras.get(
            "relative_start", config["DEFAULT_RELATIVE_START_TIME"]),
        relative_end=extras.get(
            "relative_end", config["DEFAULT_RELATIVE_END_TIME"]),
        time_range=time_range,
        time_shift=time_shift,
    )
    # is_timeseries is True if time column is in either columns or groupby
    # (both are dimensions)
    self.is_timeseries = (
        is_timeseries if is_timeseries is not None else DTTM_ALIAS in columns
    )
    self.time_range = time_range
    self.time_shift = parse_human_timedelta(time_shift)
    self.post_processing = [
        post_proc for post_proc in post_processing or [] if post_proc
    ]

    # Support metric reference/definition in the format of
    #   1. 'metric_name'   - name of predefined metric
    #   2. { label: 'label_name' }  - legacy format for a predefined metric
    #   3. { expressionType: 'SIMPLE' | 'SQL', ... } - adhoc metric
    self.metrics = metrics and [
        x if isinstance(x, str) or is_adhoc_metric(x) else x["label"]  # type: ignore
        for x in metrics
    ]

    default_row_limit = (
        config["SAMPLES_ROW_LIMIT"]
        if self.result_type == ChartDataResultType.SAMPLES
        else config["ROW_LIMIT"]
    )
    self.row_limit = apply_max_row_limit(row_limit or default_row_limit)
    self.row_offset = row_offset or 0
    self.filter = filters or []
    self.series_limit = series_limit
    self.series_limit_metric = series_limit_metric
    self.order_desc = order_desc
    self.extras = extras

    if config["SIP_15_ENABLED"]:
        self.extras["time_range_endpoints"] = get_time_range_endpoints(
            form_data=self.extras)

    self.columns = columns
    self.orderby = orderby or []

    self._rename_deprecated_fields(kwargs)
    self._move_deprecated_extra_fields(kwargs)

def invalidate(self) -> Response:
    """
    Takes a list of datasources, finds the associated cache records and
    invalidates them and removes the database records
    ---
    post:
      description: >-
        Takes a list of datasources, finds the associated cache records and
        invalidates them and removes the database records
      requestBody:
        description: >-
          A list of datasources uuid or the tuples of database and datasource names
        required: true
        content:
          application/json:
            schema:
              $ref: "#/components/schemas/CacheInvalidationRequestSchema"
      responses:
        201:
          description: cache was successfully invalidated
        400:
          $ref: '#/components/responses/400'
        500:
          $ref: '#/components/responses/500'
    """
    try:
        datasources = CacheInvalidationRequestSchema().load(request.json)
    except KeyError:
        return self.response_400(message="Request is incorrect")
    except ValidationError as error:
        return self.response_400(message=str(error))

    datasource_uids = set(datasources.get("datasource_uids", []))
    for ds in datasources.get("datasources", []):
        ds_obj = ConnectorRegistry.get_datasource_by_name(
            session=db.session,
            datasource_type=ds.get("datasource_type"),
            datasource_name=ds.get("datasource_name"),
            schema=ds.get("schema"),
            database_name=ds.get("database_name"),
        )
        if ds_obj:
            datasource_uids.add(ds_obj.uid)

    cache_key_objs = (
        db.session.query(CacheKey)
        .filter(CacheKey.datasource_uid.in_(datasource_uids))
        .all()
    )
    cache_keys = [c.cache_key for c in cache_key_objs]
    if cache_key_objs:
        all_keys_deleted = cache_manager.cache.delete_many(*cache_keys)

        if not all_keys_deleted:
            # expected behavior as keys may expire and cache is not a
            # persistent storage
            logger.info(
                "Some of the cache keys were not deleted in the list %s",
                cache_keys)

        try:
            delete_stmt = CacheKey.__table__.delete().where(  # pylint: disable=no-member
                CacheKey.cache_key.in_(cache_keys))
            db.session.execute(delete_stmt)
            db.session.commit()
            self.stats_logger.gauge("invalidated_cache", len(cache_keys))
            logger.info(
                "Invalidated %s cache records for %s datasources",
                len(cache_keys),
                len(datasource_uids),
            )
        except SQLAlchemyError as ex:  # pragma: no cover
            logger.error(ex)
            db.session.rollback()
            return self.response_500(str(ex))
    db.session.commit()
    return self.response(201)

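# A sketch of a request body this endpoint accepts (names and ids are
# hypothetical), combining explicit datasource uids with (database,
# datasource) tuples that are resolved above through
# ConnectorRegistry.get_datasource_by_name:
example_body = {
    "datasource_uids": ["3__table"],
    "datasources": [
        {
            "datasource_type": "table",
            "datasource_name": "birth_names",
            "schema": "public",
            "database_name": "examples",
        }
    ],
}
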
from superset.connectors.connector_registry import ConnectorRegistry
from flask_appbuilder.security.sqla.models import User
from superset.viz import TableViz
from superset import db
from flask import g
import json

with open('d:\\home\\pivot_viz\\form_data.json') as f:
    form_data_json = f.read()

g.user = db.session.query(User).filter(User.id == 2).one()

form_data = json.loads(form_data_json)
# form_data['datasource'] is formatted as "<id>__<type>", e.g. '3__table'
datasource_id, datasource_type = form_data['datasource'].split('__')
datasource = ConnectorRegistry.get_datasource(
    datasource_type, datasource_id, db.session)

viz = TableViz(datasource, form_data)
res = viz.get_payload(force=True)
# print(res)

with open('d:\\home\\pivot_viz\\response.json', 'w', encoding='utf8') as f:
    f.write(viz.json_dumps(res))

def add(self) -> FlaskResponse:
    allowed_datasources = []
    datasources = []

    # only if gamma
    is_gamma = False
    logging.debug('-------------------------')
    for role in g.user.roles:
        if str(role) == 'Gamma':
            is_gamma = True
            logging.debug(role.permissions)
            for perm in role.permissions:
                if str(perm).startswith('datasource access on ['):
                    # 'datasource access on [DB].[DATASOURCE](id:ID)'
                    data_search = re.search(
                        r'datasource access on \[([^\]]+)\]\.\[([^\]]+)\]'
                        r'\(id:([^\)]+)\)',
                        str(perm))
                    if data_search:
                        allowed_datasources.append({
                            "connection": data_search.group(1),
                            "name": data_search.group(2),
                            "id": data_search.group(3),
                        })

    for d in ConnectorRegistry.get_all_datasources(db.session):
        if is_gamma:
            for a in allowed_datasources:
                table_name = d.short_data.get("name").split('.')[-1]
                if (table_name == a.get("name")
                        and d.short_data.get("connection") == a.get("connection")
                        and str(d.short_data.get("id")) == str(a.get("id"))):
                    if hasattr(d, 'custom_label'):
                        datasources.append({
                            "value": str(d.id) + "__" + d.type,
                            "label": d.custom_label,
                        })
                    else:
                        datasources.append({
                            "value": str(d.id) + "__" + d.type,
                            "label": repr(d),
                        })
        else:
            if hasattr(d, 'custom_label'):
                datasources.append({
                    "value": str(d.id) + "__" + d.type,
                    "label": d.custom_label,
                })
            else:
                datasources.append({
                    "value": str(d.id) + "__" + d.type,
                    "label": repr(d),
                })

    payload = {
        "datasources": sorted(datasources, key=lambda d: d["label"]),
        "common": common_bootstrap_payload(),
        "user": bootstrap_user_data(g.user),
    }
    return self.render_template(
        "superset/add_slice.html", bootstrap_data=json.dumps(payload))

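# A quick check of the permission pattern parsed above (hypothetical
# database/table names): the three capture groups yield the connection,
# the datasource name, and the datasource id.
import re

m = re.search(
    r'datasource access on \[([^\]]+)\]\.\[([^\]]+)\]\(id:([^\)]+)\)',
    'datasource access on [examples].[birth_names](id:3)')
assert m.groups() == ('examples', 'birth_names', '3')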