Example No. 1
    def import_obj(cls, slc_to_import, import_time=None):
        """Inserts or overrides slc in the database.

        remote_id and import_time fields in params_dict are set to track the
        slice origin and ensure correct overrides for multiple imports.
        Slice.perm is used to find the datasources and connect them.
        """
        session = db.session
        make_transient(slc_to_import)
        slc_to_import.dashboards = []
        slc_to_import.alter_params(
            remote_id=slc_to_import.id, import_time=import_time)

        # find if the slice was already imported
        slc_to_override = None
        for slc in session.query(Slice).all():
            if ('remote_id' in slc.params_dict and
                    slc.params_dict['remote_id'] == slc_to_import.id):
                slc_to_override = slc

        slc_to_import = slc_to_import.copy()
        params = slc_to_import.params_dict
        slc_to_import.datasource_id = ConnectorRegistry.get_datasource_by_name(
            session, slc_to_import.datasource_type, params['datasource_name'],
            params['schema'], params['database_name']).id
        if slc_to_override:
            slc_to_override.override(slc_to_import)
            session.flush()
            return slc_to_override.id
        session.add(slc_to_import)
        logging.info('Final slice: {}'.format(slc_to_import.to_json()))
        session.flush()
        return slc_to_import.id
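A minimal usage sketch of the two-argument variant above, assuming a running Superset app context; Slice's module path varies across versions, and the final commit is the caller's job since import_obj() only flushes:

import time

from superset import db
from superset.models import Slice  # superset.models.slice in newer releases

slc = db.session.query(Slice).first()  # stands in for a slice deserialized from an export
new_id = Slice.import_obj(slc, import_time=int(time.time()))
db.session.commit()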
Example No. 2
def create_missing_perms():
    """Creates missing perms for datasources, schemas and metrics"""

    logging.info(
        "Fetching a set of all perms to lookup which ones are missing")
    all_pvs = set()
    for pv in sm.get_session.query(sm.permissionview_model).all():
        if pv.permission and pv.view_menu:
            all_pvs.add((pv.permission.name, pv.view_menu.name))

    def merge_pv(view_menu, perm):
        """Create permission view menu only if it doesn't exist"""
        if view_menu and perm and (view_menu, perm) not in all_pvs:
            merge_perm(sm, view_menu, perm)

    logging.info("Creating missing datasource permissions.")
    datasources = ConnectorRegistry.get_all_datasources(db.session)
    for datasource in datasources:
        merge_pv('datasource_access', datasource.get_perm())
        merge_pv('schema_access', datasource.schema_perm)

    logging.info("Creating missing database permissions.")
    databases = db.session.query(models.Database).all()
    for database in databases:
        merge_pv('database_access', database.perm)

    logging.info("Creating missing metrics permissions")
    metrics = []
    for datasource_class in ConnectorRegistry.sources.values():
        metrics += list(db.session.query(datasource_class.metric_class).all())

    for metric in metrics:
        if metric.is_restricted:
            merge_pv('metric_access', metric.perm)
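A hedged sketch of the call site, assuming this module-level variant lives in superset/security.py as it did when create_missing_perms() took no arguments; newer releases move the equivalent onto the security manager (see Examples No. 17 and No. 20):

from superset import security

security.create_missing_perms()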
Example No. 3
    def import_obj(cls, slc_to_import, slc_to_override, import_time=None):
        """Inserts or overrides slc in the database.

        remote_id and import_time fields in params_dict are set to track the
        slice origin and ensure correct overrides for multiple imports.
        Slice.perm is used to find the datasources and connect them.

        :param Slice slc_to_import: Slice object to import
        :param Slice slc_to_override: Slice to replace, id matches remote_id
        :returns: The resulting id for the imported slice
        :rtype: int
        """
        session = db.session
        make_transient(slc_to_import)
        slc_to_import.dashboards = []
        slc_to_import.alter_params(
            remote_id=slc_to_import.id, import_time=import_time)

        slc_to_import = slc_to_import.copy()
        params = slc_to_import.params_dict
        slc_to_import.datasource_id = ConnectorRegistry.get_datasource_by_name(
            session, slc_to_import.datasource_type, params['datasource_name'],
            params['schema'], params['database_name']).id
        if slc_to_override:
            slc_to_override.override(slc_to_import)
            session.flush()
            return slc_to_override.id
        session.add(slc_to_import)
        logging.info('Final slice: {}'.format(slc_to_import.to_json()))
        session.flush()
        return slc_to_import.id
Example No. 4
def create_missing_datasource_perms(view_menu_set):
    logging.info("Creating missing datasource permissions.")
    datasources = ConnectorRegistry.get_all_datasources(
        db.session)
    for datasource in datasources:
        if datasource and datasource.perm not in view_menu_set:
            merge_perm(sm, 'datasource_access', datasource.get_perm())
            if datasource.schema_perm:
                merge_perm(sm, 'schema_access', datasource.schema_perm)
Example No. 5
 def __init__(
         self,
         datasource: Dict,
         queries: List[Dict],
 ):
     self.datasource = ConnectorRegistry.get_datasource(datasource.get('type'),
                                                        int(datasource.get('id')),
                                                        db.session)
     self.queries = list(map(lambda query_obj: QueryObject(**query_obj), queries))
Example No. 6
    def datasource_access_by_name(
            self, database, datasource_name, schema=None):
        if self.database_access(database) or self.all_datasource_access():
            return True

        schema_perm = utils.get_schema_perm(database, schema)
        if schema and self.can_access('schema_access', schema_perm):
            return True

        datasources = ConnectorRegistry.query_datasources_by_name(
            db.session, database, datasource_name, schema=schema)
        for datasource in datasources:
            if self.can_access('datasource_access', datasource.perm):
                return True
        return False
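An illustrative call, assuming `self` above is the Superset security manager exposed as `security_manager` and that the standard birth_names example table exists:

from superset import db, security_manager
from superset.models.core import Database

database = db.session.query(Database).first()
if security_manager.datasource_access_by_name(database, 'birth_names', schema='main'):
    print('current user may query birth_names')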
Example No. 7
    def __init__(
            self,
            datasource: Dict,
            queries: List[Dict],
            force: bool = False,
            custom_cache_timeout: int = None,
    ):
        self.datasource = ConnectorRegistry.get_datasource(datasource.get('type'),
                                                           int(datasource.get('id')),
                                                           db.session)
        self.queries = list(map(lambda query_obj: QueryObject(**query_obj), queries))
        self.force = force
        self.custom_cache_timeout = custom_cache_timeout
        self.enforce_numerical_metrics = True
Example No. 8
    def save(self):
        datasource = json.loads(request.form.get('data'))
        datasource_id = datasource.get('id')
        datasource_type = datasource.get('type')
        orm_datasource = ConnectorRegistry.get_datasource(
            datasource_type, datasource_id, db.session)

        if not check_ownership(orm_datasource, raise_if_false=False):
            return json_error_response(
                __(
                    'You are not authorized to modify '
                    'this data source configuration'),
                status='401',
            )
        orm_datasource.update_from_object(datasource)
        data = orm_datasource.data
        db.session.commit()
        return self.json_response(data)
Example No. 9
    def accessible_by_user(self, database, datasource_names, schema=None):
        if self.database_access(database) or self.all_datasource_access():
            return datasource_names

        if schema:
            schema_perm = utils.get_schema_perm(database, schema)
            if self.can_access('schema_access', schema_perm):
                return datasource_names

        user_perms = self.user_datasource_perms()
        user_datasources = ConnectorRegistry.query_datasources_by_permissions(
            db.session, database, user_perms)
        if schema:
            names = {
                d.table_name
                for d in user_datasources if d.schema == schema}
            return [d for d in datasource_names if d in names]
        else:
            full_names = {d.full_name for d in user_datasources}
            return [d for d in datasource_names if d in full_names]
Example No. 10
    def import_obj(
        cls,
        slc_to_import: "Slice",
        slc_to_override: Optional["Slice"],
        import_time: Optional[int] = None,
    ) -> int:
        """Inserts or overrides slc in the database.

        remote_id and import_time fields in params_dict are set to track the
        slice origin and ensure correct overrides for multiple imports.
        Slice.perm is used to find the datasources and connect them.

        :param Slice slc_to_import: Slice object to import
        :param Slice slc_to_override: Slice to replace, id matches remote_id
        :returns: The resulting id for the imported slice
        :rtype: int
        """
        session = db.session
        make_transient(slc_to_import)
        slc_to_import.dashboards = []
        slc_to_import.alter_params(remote_id=slc_to_import.id,
                                   import_time=import_time)

        slc_to_import = slc_to_import.copy()
        slc_to_import.reset_ownership()
        params = slc_to_import.params_dict
        slc_to_import.datasource_id = ConnectorRegistry.get_datasource_by_name(  # type: ignore
            session,
            slc_to_import.datasource_type,
            params["datasource_name"],
            params["schema"],
            params["database_name"],
        ).id
        if slc_to_override:
            slc_to_override.override(slc_to_import)
            session.flush()
            return slc_to_override.id
        session.add(slc_to_import)
        logging.info("Final slice: {}".format(slc_to_import.to_json()))
        session.flush()
        return slc_to_import.id
Example No. 11
    def export_dashboards(cls, dashboard_ids):
        copied_dashboards = []
        datasource_ids = set()
        for dashboard_id in dashboard_ids:
            # make sure that dashboard_id is an integer
            dashboard_id = int(dashboard_id)
            copied_dashboard = (db.session.query(Dashboard).options(
                subqueryload(
                    Dashboard.slices)).filter_by(id=dashboard_id).first())
            make_transient(copied_dashboard)
            for slc in copied_dashboard.slices:
                datasource_ids.add((slc.datasource_id, slc.datasource_type))
                # add extra params for the import
                slc.alter_params(
                    remote_id=slc.id,
                    datasource_name=slc.datasource.name,
                    schema=slc.datasource.schema,
                    database_name=slc.datasource.database.name,
                )
            copied_dashboard.alter_params(remote_id=dashboard_id)
            copied_dashboards.append(copied_dashboard)

        eager_datasources = []
        for datasource_id, datasource_type in datasource_ids:
            eager_datasource = ConnectorRegistry.get_eager_datasource(
                db.session, datasource_type, datasource_id)
            eager_datasource.alter_params(
                remote_id=eager_datasource.id,
                database_name=eager_datasource.database.name,
            )
            make_transient(eager_datasource)
            eager_datasources.append(eager_datasource)

        return json.dumps(
            {
                "dashboards": copied_dashboards,
                "datasources": eager_datasources
            },
            cls=utils.DashboardEncoder,
            indent=4,
        )
Example No. 12
    def get_datasources_accessible_by_user(
        self,
        database: "Database",
        datasource_names: List[DatasourceName],
        schema: Optional[str] = None,
    ) -> List[DatasourceName]:
        """
        Return the list of SQL tables accessible by the user.

        :param database: The SQL database
        :param datasource_names: The list of eligible SQL tables w/ schema
        :param schema: The fallback SQL schema if not present in the table name
        :returns: The list of accessible SQL tables w/ schema
        """

        from superset import db

        if self.database_access(database) or self.all_datasource_access():
            return datasource_names

        if schema:
            schema_perm = self.get_schema_perm(database, schema)
            if schema_perm and self.can_access("schema_access", schema_perm):
                return datasource_names

        user_perms = self.user_view_menu_names("datasource_access")
        schema_perms = self.user_view_menu_names("schema_access")
        user_datasources = ConnectorRegistry.query_datasources_by_permissions(
            db.session, database, user_perms, schema_perms)
        if schema:
            names = {
                d.table_name
                for d in user_datasources if d.schema == schema
            }
            return [d for d in datasource_names if d in names]
        else:
            full_names = {d.full_name for d in user_datasources}
            return [
                d for d in datasource_names
                if f"[{database}].[{d}]" in full_names
            ]
Example No. 13
    def accessible_by_user(self, database, datasource_names, schema=None):
        if self.database_access(database) or self.all_datasource_access():
            return datasource_names

        schema_perm = utils.get_schema_perm(database, schema)
        if schema and utils.can_access(sm, 'schema_access', schema_perm,
                                       g.user):
            return datasource_names

        role_ids = set([role.id for role in g.user.roles])
        # TODO: cache user_perms or user_datasources
        user_pvms = (
            db.session.query(ab_models.PermissionView)
            .join(ab_models.Permission)
            .filter(ab_models.Permission.name == 'datasource_access')
            .filter(ab_models.PermissionView.role.any(
                ab_models.Role.id.in_(role_ids)))
            .all()
        )
        user_perms = set([pvm.view_menu.name for pvm in user_pvms])
        user_datasources = ConnectorRegistry.query_datasources_by_permissions(
            db.session, database, user_perms)
        full_names = set([d.full_name for d in user_datasources])
        return [d for d in datasource_names if d in full_names]
Example No. 14
    def accessible_by_user(self, database, datasource_names, schema=None):
        if self.database_access(database) or self.all_datasource_access():
            return datasource_names

        if schema:
            schema_perm = utils.get_schema_perm(database, schema)
            if self.can_access('schema_access', schema_perm):
                return datasource_names

        user_perms = self.user_datasource_perms()
        user_datasources = ConnectorRegistry.query_datasources_by_permissions(
            db.session, database, user_perms)
        if schema:
            names = {
                d.table_name
                for d in user_datasources if d.schema == schema
            }
            return [d for d in datasource_names if d in names]
        else:
            full_names = {d.full_name for d in user_datasources}
            return [d for d in datasource_names if d in full_names]
Example No. 15
def get_viz(
    slice_id=None,
    form_data=None,
    datasource_type=None,
    datasource_id=None,
    force=False,
):
    if slice_id:
        slc = (db.session.query(models.Slice).filter_by(id=slice_id).one())
        return slc.get_viz()
    else:
        viz_type = form_data.get('viz_type', 'table')
        datasource = ConnectorRegistry.get_datasource(datasource_type,
                                                      datasource_id,
                                                      db.session)
        viz_obj = viz.viz_types[viz_type](
            datasource,
            form_data=form_data,
            force=force,
        )
        return viz_obj
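A short usage sketch, assuming an app context and an existing table-type datasource with id 1; the form_data keys are illustrative:

form_data = {'viz_type': 'table', 'datasource': '1__table'}
viz_obj = get_viz(
    form_data=form_data,
    datasource_type='table',
    datasource_id=1,
    force=True,
)
payload = viz_obj.get_payload()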
Example No. 16
    def test_get_user_datasources_gamma(self):
        Datasource = namedtuple("Datasource", ["database", "schema", "name"])

        mock_session = mock.MagicMock()
        mock_session.query.return_value.filter.return_value.all.return_value = []

        with mock.patch("superset.security_manager") as mock_security_manager:
            mock_security_manager.can_access_database.return_value = False

            with mock.patch.object(
                    ConnectorRegistry,
                    "get_all_datasources") as mock_get_all_datasources:
                mock_get_all_datasources.return_value = [
                    Datasource("database1", "schema1", "table1"),
                    Datasource("database1", "schema1", "table2"),
                    Datasource("database2", None, "table1"),
                ]

                datasources = ConnectorRegistry.get_user_datasources(
                    mock_session)

        assert datasources == []
Example No. 17
    def create_missing_perms(self):
        """Creates missing perms for datasources, schemas and metrics"""
        from superset import db
        from superset.models import core as models

        logging.info(
            'Fetching a set of all perms to lookup which ones are missing')
        all_pvs = set()
        for pv in self.get_session.query(self.permissionview_model).all():
            if pv.permission and pv.view_menu:
                all_pvs.add((pv.permission.name, pv.view_menu.name))

        def merge_pv(view_menu, perm):
            """Create permission view menu only if it doesn't exist"""
            if view_menu and perm and (view_menu, perm) not in all_pvs:
                self.merge_perm(view_menu, perm)

        logging.info('Creating missing datasource permissions.')
        datasources = ConnectorRegistry.get_all_datasources(db.session)
        for datasource in datasources:
            merge_pv('datasource_access', datasource.get_perm())
            merge_pv('schema_access', datasource.schema_perm)

        logging.info('Creating missing database permissions.')
        databases = db.session.query(models.Database).all()
        for database in databases:
            merge_pv('database_access', database.perm)

        logging.info('Creating missing metrics permissions')
        metrics = []
        for datasource_class in ConnectorRegistry.sources.values():
            metrics += list(
                db.session.query(datasource_class.metric_class).all())

        for metric in metrics:
            if metric.is_restricted:
                merge_pv('metric_access', metric.perm)
Example No. 18
    def save(self):
        datasource = json.loads(request.form.get('data'))
        datasource_id = datasource.get('id')
        datasource_type = datasource.get('type')
        orm_datasource = ConnectorRegistry.get_datasource(
            datasource_type, datasource_id, db.session)

        if not check_ownership(orm_datasource, raise_if_false=False):
            return json_error_response(
                __(
                    'You are not authorized to modify '
                    'this data source configuration'),
                status='401',
            )

        if 'owners' in datasource:
            datasource['owners'] = db.session.query(orm_datasource.owner_class).filter(
                orm_datasource.owner_class.id.in_(datasource['owners'])).all()
        orm_datasource.update_from_object(datasource)
        data = orm_datasource.data
        db.session.commit()
        return self.json_response(data)
Example No. 19
    def export_dashboards(cls, dashboard_ids):
        copied_dashboards = []
        datasource_ids = set()
        for dashboard_id in dashboard_ids:
            # make sure that dashboard_id is an integer
            dashboard_id = int(dashboard_id)
            copied_dashboard = (
                db.session.query(Dashboard)
                .options(subqueryload(Dashboard.slices))
                .filter_by(id=dashboard_id).first()
            )
            make_transient(copied_dashboard)
            for slc in copied_dashboard.slices:
                datasource_ids.add((slc.datasource_id, slc.datasource_type))
                # add extra params for the import
                slc.alter_params(
                    remote_id=slc.id,
                    datasource_name=slc.datasource.name,
                    schema=slc.datasource.schema,
                    database_name=slc.datasource.database.name,
                )
            copied_dashboard.alter_params(remote_id=dashboard_id)
            copied_dashboards.append(copied_dashboard)

        eager_datasources = []
        for datasource_id, datasource_type in datasource_ids:
            eager_datasource = ConnectorRegistry.get_eager_datasource(
                db.session, datasource_type, datasource_id)
            eager_datasource.alter_params(
                remote_id=eager_datasource.id,
                database_name=eager_datasource.database.name,
            )
            make_transient(eager_datasource)
            eager_datasources.append(eager_datasource)

        return json.dumps({
            'dashboards': copied_dashboards,
            'datasources': eager_datasources,
        }, cls=utils.DashboardEncoder, indent=4)
Example No. 20
    def create_missing_perms(self) -> None:
        """
        Creates missing FAB permissions for datasources, schemas and metrics.
        """

        from superset import db
        from superset.connectors.base.models import BaseMetric
        from superset.models import core as models

        logging.info("Fetching a set of all perms to lookup which ones are missing")
        all_pvs = set()
        for pv in self.get_session.query(self.permissionview_model).all():
            if pv.permission and pv.view_menu:
                all_pvs.add((pv.permission.name, pv.view_menu.name))

        def merge_pv(view_menu, perm):
            """Create permission view menu only if it doesn't exist"""
            if view_menu and perm and (view_menu, perm) not in all_pvs:
                self.add_permission_view_menu(view_menu, perm)

        logging.info("Creating missing datasource permissions.")
        datasources = ConnectorRegistry.get_all_datasources(db.session)
        for datasource in datasources:
            merge_pv("datasource_access", datasource.get_perm())
            merge_pv("schema_access", datasource.schema_perm)

        logging.info("Creating missing database permissions.")
        databases = db.session.query(models.Database).all()
        for database in databases:
            merge_pv("database_access", database.perm)

        logging.info("Creating missing metrics permissions")
        metrics: List[BaseMetric] = []
        for datasource_class in ConnectorRegistry.sources.values():
            metrics += list(db.session.query(datasource_class.metric_class).all())

        for metric in metrics:
            if metric.is_restricted:
                merge_pv("metric_access", metric.perm)
Example No. 21
 def __init__(  # pylint: disable=too-many-arguments
     self,
     datasource: DatasourceDict,
     queries: List[Dict[str, Any]],
     force: bool = False,
     custom_cache_timeout: Optional[int] = None,
     result_type: Optional[ChartDataResultType] = None,
     result_format: Optional[ChartDataResultFormat] = None,
 ) -> None:
     self.datasource = ConnectorRegistry.get_datasource(
         str(datasource["type"]), int(datasource["id"]), db.session)
     self.queries = [QueryObject(**query_obj) for query_obj in queries]
     self.force = force
     self.custom_cache_timeout = custom_cache_timeout
     self.result_type = result_type or ChartDataResultType.FULL
     self.result_format = result_format or ChartDataResultFormat.JSON
     self.cache_values = {
         "datasource": datasource,
         "queries": queries,
         "result_type": self.result_type,
         "result_format": self.result_format,
     }
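Illustrative construction, mirroring what the chart-data endpoint passes in; the datasource id and the query fields here are assumptions:

from superset.common.query_context import QueryContext
from superset.utils.core import ChartDataResultFormat, ChartDataResultType

query_context = QueryContext(
    datasource={"type": "table", "id": 1},
    queries=[{"metrics": ["count"], "filters": []}],
    result_type=ChartDataResultType.FULL,
    result_format=ChartDataResultFormat.JSON,
)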
Example No. 22
 def external_metadata(
     self, datasource_type: str, datasource_id: int
 ) -> FlaskResponse:
     """Gets column info from the source system"""
     if datasource_type == "druid":
         datasource = ConnectorRegistry.get_datasource(
             datasource_type, datasource_id, db.session
         )
     elif datasource_type == "table":
         database = (
             db.session.query(Database).filter_by(id=request.args.get("db_id")).one()
         )
         table_class = ConnectorRegistry.sources["table"]
         datasource = table_class(
             database=database,
             table_name=request.args.get("table_name"),
             schema=request.args.get("schema") or None,
         )
     else:
         raise Exception(f"Unsupported datasource_type: {datasource_type}")
     external_metadata = datasource.external_metadata()
     return self.json_response(external_metadata)
Example No. 23
    def get_datasources_accessible_by_user(
            self, database, datasource_names: List[DatasourceName],
            schema: Optional[str] = None) -> List[DatasourceName]:
        from superset import db
        if self.database_access(database) or self.all_datasource_access():
            return datasource_names

        if schema:
            schema_perm = self.get_schema_perm(database, schema)
            if self.can_access('schema_access', schema_perm):
                return datasource_names

        user_perms = self.user_datasource_perms()
        user_datasources = ConnectorRegistry.query_datasources_by_permissions(
            db.session, database, user_perms)
        if schema:
            names = {
                d.table_name
                for d in user_datasources if d.schema == schema}
            return [d for d in datasource_names if d in names]
        else:
            full_names = {d.full_name for d in user_datasources}
            return [d for d in datasource_names if d in full_names]
Example No. 24
    def save(self) -> FlaskResponse:
        data = request.form.get("data")
        if not isinstance(data, str):
            return json_error_response("Request missing data field.", status=500)

        datasource_dict = json.loads(data)
        datasource_id = datasource_dict.get("id")
        datasource_type = datasource_dict.get("type")
        database_id = datasource_dict["database"].get("id")
        orm_datasource = ConnectorRegistry.get_datasource(
            datasource_type, datasource_id, db.session
        )
        orm_datasource.database_id = database_id

        if "owners" in datasource_dict and orm_datasource.owner_class is not None:
            datasource_dict["owners"] = (
                db.session.query(orm_datasource.owner_class)
                .filter(orm_datasource.owner_class.id.in_(datasource_dict["owners"]))
                .all()
            )

        duplicates = [
            name
            for name, count in Counter(
                [col["column_name"] for col in datasource_dict["columns"]]
            ).items()
            if count > 1
        ]
        if duplicates:
            return json_error_response(
                f"Duplicate column name(s): {','.join(duplicates)}", status=409
            )
        orm_datasource.update_from_object(datasource_dict)
        data = orm_datasource.data
        db.session.commit()

        return self.json_response(data)
Example No. 25
    def accessible_by_user(self, database, datasource_names, schema=None):
        if self.database_access(database) or self.all_datasource_access():
            return datasource_names

        schema_perm = utils.get_schema_perm(database, schema)
        if schema and utils.can_access(
                sm, 'schema_access', schema_perm, g.user):
            return datasource_names

        role_ids = set([role.id for role in g.user.roles])
        # TODO: cache user_perms or user_datasources
        user_pvms = (
            db.session.query(ab_models.PermissionView)
            .join(ab_models.Permission)
            .filter(ab_models.Permission.name == 'datasource_access')
            .filter(ab_models.PermissionView.role.any(
                ab_models.Role.id.in_(role_ids)))
            .all()
        )
        user_perms = set([pvm.view_menu.name for pvm in user_pvms])
        user_datasources = ConnectorRegistry.query_datasources_by_permissions(
            db.session, database, user_perms)
        full_names = set([d.full_name for d in user_datasources])
        return [d for d in datasource_names if d in full_names]
Example No. 26
    def test_query_cache_key_changes_when_metric_is_updated(self):
        self.login(username="******")
        payload = get_query_context("birth_names")

        # make temporary change and revert it to refresh the changed_on property
        datasource = ConnectorRegistry.get_datasource(
            datasource_type=payload["datasource"]["type"],
            datasource_id=payload["datasource"]["id"],
            session=db.session,
        )

        datasource.metrics.append(
            SqlMetric(metric_name="foo", expression="select 1;"))
        db.session.commit()

        # construct baseline query_cache_key
        query_context = ChartDataQueryContextSchema().load(payload)
        query_object = query_context.queries[0]
        cache_key_original = query_context.query_cache_key(query_object)

        # wait a second since mysql records timestamps in second granularity
        time.sleep(1)

        datasource.metrics[0].expression = "select 2;"
        db.session.commit()

        # create new QueryContext with unchanged attributes, extract new query_cache_key
        query_context = ChartDataQueryContextSchema().load(payload)
        query_object = query_context.queries[0]
        cache_key_new = query_context.query_cache_key(query_object)

        datasource.metrics = []
        db.session.commit()

        # the new cache_key should be different due to updated datasource
        self.assertNotEqual(cache_key_original, cache_key_new)
Example No. 27
    def get_user_datasources(self) -> List["BaseDatasource"]:
        """
        Collect datasources which the user has explicit permissions to.

        :returns: The list of datasources
        """

        user_perms = self.user_view_menu_names("datasource_access")
        schema_perms = self.user_view_menu_names("schema_access")
        user_datasources = set()
        for datasource_class in ConnectorRegistry.sources.values():
            user_datasources.update(
                self.get_session.query(datasource_class)
                .filter(
                    or_(
                        datasource_class.perm.in_(user_perms),
                        datasource_class.schema_perm.in_(schema_perms),
                    )
                )
                .all()
            )

        # group all datasources by database
        all_datasources = ConnectorRegistry.get_all_datasources(self.get_session)
        datasources_by_database: Dict["Database", Set["BaseDatasource"]] = defaultdict(
            set
        )
        for datasource in all_datasources:
            datasources_by_database[datasource.database].add(datasource)

        # add datasources with implicit permission (eg, database access)
        for database, datasources in datasources_by_database.items():
            if self.can_access_database(database):
                user_datasources.update(datasources)

        return list(user_datasources)
Example No. 28
def get_viz(
        slice_id=None,
        form_data=None,
        datasource_type=None,
        datasource_id=None,
        force=False,
):
    if slice_id:
        slc = (
            db.session.query(models.Slice)
            .filter_by(id=slice_id)
            .one()
        )
        return slc.get_viz()
    else:
        viz_type = form_data.get('viz_type', 'table')
        datasource = ConnectorRegistry.get_datasource(
            datasource_type, datasource_id, db.session)
        viz_obj = viz.viz_types[viz_type](
            datasource,
            form_data=form_data,
            force=force,
        )
        return viz_obj
Example No. 29
 def configure_data_sources(self) -> None:
     # Registering sources
     module_datasource_map = self.config["DEFAULT_MODULE_DS_MAP"]
     module_datasource_map.update(self.config["ADDITIONAL_MODULE_DS_MAP"])
     ConnectorRegistry.register_sources(module_datasource_map)
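For reference, the two maps read here look roughly like the following (the custom connector entry is hypothetical):

from collections import OrderedDict

DEFAULT_MODULE_DS_MAP = OrderedDict([
    ("superset.connectors.sqla.models", ["SqlaTable"]),
    ("superset.connectors.druid.models",
     ["DruidDatasource", "DruidColumn", "DruidMetric"]),
])
# a deployment's superset_config.py may register extra connectors:
ADDITIONAL_MODULE_DS_MAP = {"my_connector.models": ["MyDatasource"]}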
Example No. 30
def get_feature_flags():
    """Retrieves the feature flags from the config as a dictionary"""
    if GET_FEATURE_FLAGS_FUNC:
        return GET_FEATURE_FLAGS_FUNC(deepcopy(_feature_flags))
    return _feature_flags


def is_feature_enabled(feature):
    """Utility function for checking whether a feature is turned on"""
    return get_feature_flags().get(feature)


# Flask-Compress
if conf.get("ENABLE_FLASK_COMPRESS"):
    Compress(app)

talisman = Talisman()

if app.config["TALISMAN_ENABLED"]:
    talisman.init_app(app, **app.config["TALISMAN_CONFIG"])

# Hook that provides administrators a handle on the Flask APP
# after initialization
flask_app_mutator = app.config.get("FLASK_APP_MUTATOR")
if flask_app_mutator:
    flask_app_mutator(app)

from superset import views  # noqa

# Registering sources
module_datasource_map = app.config.get("DEFAULT_MODULE_DS_MAP")
module_datasource_map.update(app.config.get("ADDITIONAL_MODULE_DS_MAP"))
ConnectorRegistry.register_sources(module_datasource_map)
Example No. 31
for middleware in app.config.get('ADDITIONAL_MIDDLEWARE'):
    app.wsgi_app = middleware(app.wsgi_app)


class MyIndexView(IndexView):
    @expose('/')
    def index(self):
        return redirect('/superset/welcome')


appbuilder = AppBuilder(
    app,
    db.session,
    base_template='superset/base.html',
    indexview=MyIndexView,
    security_manager_class=app.config.get("CUSTOM_SECURITY_MANAGER"))

sm = appbuilder.sm

get_session = appbuilder.get_session
results_backend = app.config.get("RESULTS_BACKEND")

# Registering sources
module_datasource_map = app.config.get("DEFAULT_MODULE_DS_MAP")
module_datasource_map.update(app.config.get("ADDITIONAL_MODULE_DS_MAP"))
ConnectorRegistry.register_sources(module_datasource_map)

from superset import views  # noqa
Example No. 32
def get_datasource_by_id(datasource_id: int,
                         datasource_type: str) -> BaseDatasource:
    try:
        return ConnectorRegistry.get_datasource(datasource_type, datasource_id)
    except (NoResultFound, KeyError):
        raise DatasourceNotFoundValidationError()
Example No. 33
def create_query_object_factory() -> QueryObjectFactory:
    return QueryObjectFactory(config, ConnectorRegistry(), db.session)
Example No. 34
 def get(self, datasource_type: str, datasource_id: int) -> FlaskResponse:
     datasource = ConnectorRegistry.get_datasource(datasource_type,
                                                   datasource_id,
                                                   db.session)
     return self.json_response(datasource.data)
Example No. 35
 def _convert_to_model(self, datasource: DatasourceDict) -> BaseDatasource:
     return ConnectorRegistry.get_datasource(str(datasource["type"]),
                                             int(datasource["id"]),
                                             db.session)
Example No. 36
    def __init__(
        self,
        datasource: Optional[DatasourceDict] = None,
        result_type: Optional[ChartDataResultType] = None,
        annotation_layers: Optional[List[Dict[str, Any]]] = None,
        applied_time_extras: Optional[Dict[str, str]] = None,
        apply_fetch_values_predicate: bool = False,
        granularity: Optional[str] = None,
        metrics: Optional[List[Union[Dict[str, Any], str]]] = None,
        groupby: Optional[List[str]] = None,
        filters: Optional[List[Dict[str, Any]]] = None,
        time_range: Optional[str] = None,
        time_shift: Optional[str] = None,
        is_timeseries: Optional[bool] = None,
        timeseries_limit: int = 0,
        row_limit: Optional[int] = None,
        row_offset: Optional[int] = None,
        timeseries_limit_metric: Optional[Metric] = None,
        order_desc: bool = True,
        extras: Optional[Dict[str, Any]] = None,
        columns: Optional[List[str]] = None,
        orderby: Optional[List[OrderBy]] = None,
        post_processing: Optional[List[Optional[Dict[str, Any]]]] = None,
        is_rowcount: bool = False,
        **kwargs: Any,
    ):
        columns = columns or []
        groupby = groupby or []
        extras = extras or {}
        annotation_layers = annotation_layers or []

        self.is_rowcount = is_rowcount
        self.datasource = None
        if datasource:
            self.datasource = ConnectorRegistry.get_datasource(
                str(datasource["type"]), int(datasource["id"]), db.session)
        self.result_type = result_type
        self.apply_fetch_values_predicate = apply_fetch_values_predicate or False
        self.annotation_layers = [
            layer for layer in annotation_layers
            # formula annotations don't affect the payload, hence can be dropped
            if layer["annotationType"] != "FORMULA"
        ]
        self.applied_time_extras = applied_time_extras or {}
        self.granularity = granularity
        self.from_dttm, self.to_dttm = get_since_until(
            relative_start=extras.get("relative_start",
                                      config["DEFAULT_RELATIVE_START_TIME"]),
            relative_end=extras.get("relative_end",
                                    config["DEFAULT_RELATIVE_END_TIME"]),
            time_range=time_range,
            time_shift=time_shift,
        )
        # is_timeseries is True if time column is in either columns or groupby
        # (both are dimensions)
        self.is_timeseries = (is_timeseries if is_timeseries is not None else
                              DTTM_ALIAS in columns + groupby)
        self.time_range = time_range
        self.time_shift = parse_human_timedelta(time_shift)
        self.post_processing = [
            post_proc for post_proc in post_processing or [] if post_proc
        ]

        # Support metric reference/definition in the format of
        #   1. 'metric_name'   - name of predefined metric
        #   2. { label: 'label_name' }  - legacy format for a predefined metric
        #   3. { expressionType: 'SIMPLE' | 'SQL', ... } - adhoc metric
        self.metrics = metrics and [
            x if isinstance(x, str) or is_adhoc_metric(x) else
            x["label"]  # type: ignore
            for x in metrics
        ]

        self.row_limit = config["ROW_LIMIT"] if row_limit is None else row_limit
        self.row_offset = row_offset or 0
        self.filter = filters or []
        self.timeseries_limit = timeseries_limit
        self.timeseries_limit_metric = timeseries_limit_metric
        self.order_desc = order_desc
        self.extras = extras

        if config["SIP_15_ENABLED"]:
            self.extras["time_range_endpoints"] = get_time_range_endpoints(
                form_data=self.extras)

        self.columns = columns
        self.groupby = groupby or []
        self.orderby = orderby or []

        # rename deprecated fields
        for field in DEPRECATED_FIELDS:
            if field.old_name in kwargs:
                logger.warning(
                    "The field `%s` is deprecated, please use `%s` instead.",
                    field.old_name,
                    field.new_name,
                )
                value = kwargs[field.old_name]
                if value:
                    if hasattr(self, field.new_name):
                        logger.warning(
                            "The field `%s` is already populated, "
                            "replacing value with contents from `%s`.",
                            field.new_name,
                            field.old_name,
                        )
                    setattr(self, field.new_name, value)

        # move deprecated extras fields to extras
        for field in DEPRECATED_EXTRAS_FIELDS:
            if field.old_name in kwargs:
                logger.warning(
                    "The field `%s` is deprecated and should "
                    "be passed to `extras` via the `%s` property.",
                    field.old_name,
                    field.new_name,
                )
                value = kwargs[field.old_name]
                if value:
                    if hasattr(self.extras, field.new_name):
                        logger.warning(
                            "The field `%s` is already populated in "
                            "`extras`, replacing value with contents "
                            "from `%s`.",
                            field.new_name,
                            field.old_name,
                        )
                    self.extras[field.new_name] = value
Example No. 37
 def external_metadata(self, datasource_type=None, datasource_id=None):
     """Gets column info from the source system"""
     orm_datasource = ConnectorRegistry.get_datasource(
         datasource_type, datasource_id, db.session)
     return self.json_response(orm_datasource.external_metadata())
Example No. 38
 def fetch_all_datasources() -> List["BaseDatasource"]:
     return ConnectorRegistry.get_all_datasources(db.session)
Example No. 39
    def __init__(  # pylint: disable=too-many-arguments,too-many-locals
        self,
        query_context: "QueryContext",
        annotation_layers: Optional[List[Dict[str, Any]]] = None,
        applied_time_extras: Optional[Dict[str, str]] = None,
        apply_fetch_values_predicate: bool = False,
        columns: Optional[List[str]] = None,
        datasource: Optional[DatasourceDict] = None,
        extras: Optional[Dict[str, Any]] = None,
        filters: Optional[List[QueryObjectFilterClause]] = None,
        granularity: Optional[str] = None,
        is_rowcount: bool = False,
        is_timeseries: Optional[bool] = None,
        metrics: Optional[List[Metric]] = None,
        order_desc: bool = True,
        orderby: Optional[List[OrderBy]] = None,
        post_processing: Optional[List[Optional[Dict[str, Any]]]] = None,
        result_type: Optional[ChartDataResultType] = None,
        row_limit: Optional[int] = None,
        row_offset: Optional[int] = None,
        series_columns: Optional[List[str]] = None,
        series_limit: int = 0,
        series_limit_metric: Optional[Metric] = None,
        time_range: Optional[str] = None,
        time_shift: Optional[str] = None,
        **kwargs: Any,
    ):
        columns = columns or []
        extras = extras or {}
        annotation_layers = annotation_layers or []
        self.time_offsets = kwargs.get("time_offsets", [])
        self.inner_from_dttm = kwargs.get("inner_from_dttm")
        self.inner_to_dttm = kwargs.get("inner_to_dttm")
        if series_columns:
            self.series_columns = series_columns
        elif is_timeseries and metrics:
            self.series_columns = columns
        else:
            self.series_columns = []

        self.is_rowcount = is_rowcount
        self.datasource = None
        if datasource:
            self.datasource = ConnectorRegistry.get_datasource(
                str(datasource["type"]), int(datasource["id"]), db.session)
        self.result_type = result_type or query_context.result_type
        self.apply_fetch_values_predicate = apply_fetch_values_predicate or False
        self.annotation_layers = [
            layer for layer in annotation_layers
            # formula annotations don't affect the payload, hence can be dropped
            if layer["annotationType"] != "FORMULA"
        ]
        self.applied_time_extras = applied_time_extras or {}
        self.granularity = granularity
        self.from_dttm, self.to_dttm = get_since_until(
            relative_start=extras.get("relative_start",
                                      config["DEFAULT_RELATIVE_START_TIME"]),
            relative_end=extras.get("relative_end",
                                    config["DEFAULT_RELATIVE_END_TIME"]),
            time_range=time_range,
            time_shift=time_shift,
        )
        # is_timeseries is True if time column is in either columns or groupby
        # (both are dimensions)
        self.is_timeseries = (is_timeseries if is_timeseries is not None else
                              DTTM_ALIAS in columns)
        self.time_range = time_range
        self.time_shift = parse_human_timedelta(time_shift)
        self.post_processing = [
            post_proc for post_proc in post_processing or [] if post_proc
        ]

        # Support metric reference/definition in the format of
        #   1. 'metric_name'   - name of predefined metric
        #   2. { label: 'label_name' }  - legacy format for a predefined metric
        #   3. { expressionType: 'SIMPLE' | 'SQL', ... } - adhoc metric
        self.metrics = metrics and [
            x if isinstance(x, str) or is_adhoc_metric(x) else
            x["label"]  # type: ignore
            for x in metrics
        ]

        default_row_limit = (config["SAMPLES_ROW_LIMIT"]
                             if self.result_type == ChartDataResultType.SAMPLES
                             else config["ROW_LIMIT"])
        self.row_limit = apply_max_row_limit(row_limit or default_row_limit)
        self.row_offset = row_offset or 0
        self.filter = filters or []
        self.series_limit = series_limit
        self.series_limit_metric = series_limit_metric
        self.order_desc = order_desc
        self.extras = extras

        if config["SIP_15_ENABLED"]:
            self.extras["time_range_endpoints"] = get_time_range_endpoints(
                form_data=self.extras)

        self.columns = columns
        self.orderby = orderby or []

        self._rename_deprecated_fields(kwargs)
        self._move_deprecated_extra_fields(kwargs)
Example No. 40
    def invalidate(self) -> Response:
        """
        Takes a list of datasources, finds the associated cache records and
        invalidates them and removes the database records

        ---
        post:
          description: >-
            Takes a list of datasources, finds the associated cache records and
            invalidates them and removes the database records
          requestBody:
            description: >-
              A list of datasources uuid or the tuples of database and datasource names
            required: true
            content:
              application/json:
                schema:
                  $ref: "#/components/schemas/CacheInvalidationRequestSchema"
          responses:
            201:
              description: cache was successfully invalidated
            400:
              $ref: '#/components/responses/400'
            500:
              $ref: '#/components/responses/500'
        """
        try:
            datasources = CacheInvalidationRequestSchema().load(request.json)
        except KeyError:
            return self.response_400(message="Request is incorrect")
        except ValidationError as error:
            return self.response_400(message=str(error))
        datasource_uids = set(datasources.get("datasource_uids", []))
        for ds in datasources.get("datasources", []):
            ds_obj = ConnectorRegistry.get_datasource_by_name(
                session=db.session,
                datasource_type=ds.get("datasource_type"),
                datasource_name=ds.get("datasource_name"),
                schema=ds.get("schema"),
                database_name=ds.get("database_name"),
            )
            if ds_obj:
                datasource_uids.add(ds_obj.uid)

        cache_key_objs = (db.session.query(CacheKey).filter(
            CacheKey.datasource_uid.in_(datasource_uids)).all())
        cache_keys = [c.cache_key for c in cache_key_objs]
        if cache_key_objs:
            all_keys_deleted = cache_manager.cache.delete_many(*cache_keys)

            if not all_keys_deleted:
                # expected behavior as keys may expire and cache is not a
                # persistent storage
                logger.info(
                    "Some of the cache keys were not deleted in the list %s",
                    cache_keys)

            try:
                delete_stmt = CacheKey.__table__.delete().where(  # pylint: disable=no-member
                    CacheKey.cache_key.in_(cache_keys))
                db.session.execute(delete_stmt)
                db.session.commit()
                self.stats_logger.gauge("invalidated_cache", len(cache_keys))
                logger.info(
                    "Invalidated %s cache records for %s datasources",
                    len(cache_keys),
                    len(datasource_uids),
                )
            except SQLAlchemyError as ex:  # pragma: no cover
                logger.error(ex)
                db.session.rollback()
                return self.response_500(str(ex))
            db.session.commit()
        return self.response(201)
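A request body this endpoint accepts might look like the following Python dict (the dataset names are illustrative); the keys mirror what the handler reads from CacheInvalidationRequestSchema:

payload = {
    "datasource_uids": ["3__table"],
    "datasources": [
        {
            "datasource_name": "birth_names",
            "database_name": "examples",
            "schema": "public",
            "datasource_type": "table",
        }
    ],
}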
Example No. 41
from superset.connectors.connector_registry import ConnectorRegistry
from flask_appbuilder.security.sqla.models import User
from superset.viz import TableViz
from superset import db
from flask import g
import json

with open('d:\\home\\pivot_viz\\form_data.json') as f:
    form_data_json = f.read()

g.user = db.session.query(User).filter(User.id == 2).one()
form_data = json.loads(form_data_json)
datasource_id, datasource_type = form_data['datasource'].split('__')
datasource = ConnectorRegistry.get_datasource(datasource_type, datasource_id,
                                              db.session)
viz = TableViz(datasource, form_data)
res = viz.get_payload(force=True)
# print(res)
with open('d:\\home\\pivot_viz\\response.json', 'w', encoding='utf8') as f:
    f.write(viz.json_dumps(res))
Example No. 42
    def add(self) -> FlaskResponse:
        allowed_datasources = []
        datasources = []

        # only if gamma
        is_gamma = False
        logging.debug('-------------------------')
        for role in g.user.roles:
            if str(role) == 'Gamma':
                is_gamma = True
            logging.debug(role.permissions)
            for perm in role.permissions:
                if str(perm).startswith('datasource access on ['):
                    # 'datasource access on [DB].[DATASOURCE](id:ID)'
                    data_search = re.search(
                        r'datasource access on \[([^\]]+)\]\.\[([^\]]+)\]'
                        r'\(id:([^\)]+)\)',
                        str(perm))
                    if data_search:
                        allowed_datasources.append({
                            "connection": data_search.group(1),
                            "name": data_search.group(2),
                            "id": data_search.group(3),
                        })
        for d in ConnectorRegistry.get_all_datasources(db.session):
            label = d.custom_label if hasattr(d, 'custom_label') else repr(d)
            entry = {"value": str(d.id) + "__" + d.type, "label": label}
            if is_gamma:
                table_name = d.short_data.get("name").split('.')[-1]
                for a in allowed_datasources:
                    if (table_name == a.get("name")
                            and d.short_data.get("connection") == a.get("connection")
                            and str(d.short_data.get("id")) == str(a.get("id"))):
                        datasources.append(entry)
            else:
                datasources.append(entry)
        payload = {
            "datasources": sorted(datasources, key=lambda d: d["label"]),
            "common": common_bootstrap_payload(),
            "user": bootstrap_user_data(g.user),
        }
        return self.render_template("superset/add_slice.html",
                                    bootstrap_data=json.dumps(payload))