Example #1
def session_scope(nullpool: bool) -> Iterator[Session]:
    """Provide a transactional scope around a series of operations."""
    database_uri = app.config["SQLALCHEMY_DATABASE_URI"]
    if "sqlite" in database_uri:
        logger.warning(
            "SQLite Database support for metadata databases will be removed \
            in a future version of Superset."
        )
    if nullpool:
        engine = sa.create_engine(database_uri, poolclass=NullPool)
        session_class = sessionmaker()
        session_class.configure(bind=engine)
        session = session_class()
    else:
        session = db.session()
        session.commit()  # HACK

    try:
        yield session
        session.commit()
    except Exception as ex:
        session.rollback()
        logger.exception(ex)
        raise
    finally:
        session.close()
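A minimal usage sketch for the context manager above, assuming session_scope is
wrapped with contextlib.contextmanager (the decorator sits outside this excerpt)
and a Flask application context is active; SomeModel is a placeholder model:

    with session_scope(nullpool=True) as session:
        # Commits on clean exit; rolls back, logs, and re-raises on error.
        session.add(SomeModel(name="example"))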
Example #2
        def wrapper(*args, **kwargs):
            start_dttm = datetime.now()
            user_id = None
            if g.user:
                user_id = g.user.get_id()
            d = request.args.to_dict()
            post_data = request.form or {}
            d.update(post_data)
            d.update(kwargs)
            slice_id = d.get('slice_id', 0)
            try:
                slice_id = int(slice_id) if slice_id else 0
            except ValueError:
                slice_id = 0
            params = ""
            try:
                params = json.dumps(d)
            except Exception:
                pass
            stats_logger.incr(f.__name__)
            value = f(*args, **kwargs)

            sesh = db.session()
            log = cls(
                action=f.__name__,
                json=params,
                dashboard_id=d.get('dashboard_id') or None,
                slice_id=slice_id,
                duration_ms=(
                    datetime.now() - start_dttm).total_seconds() * 1000,
                referrer=request.referrer[:1000] if request.referrer else None,
                user_id=user_id)
            sesh.add(log)
            sesh.commit()
            return value
Example #3
class SqlMetricInlineView(CompactCRUDMixin, SupersetModelView):
    datamodel = SQLAInterface(models.SqlMetric)

    list_title = _("Metrics")
    show_title = _("Show Metric")
    add_title = _("Add Metric")
    edit_title = _("Edit Metric")

    list_columns = ["metric_name", "verbose_name", "metric_type"]
    edit_columns = [
        "metric_name",
        "description",
        "verbose_name",
        "metric_type",
        "expression",
        "table",
        "d3format",
        "warning_text",
    ]
    description_columns = {
        "expression":
        utils.markdown(
            "a valid, *aggregating* SQL expression as supported by the "
            "underlying backend. Example: `count(DISTINCT userid)`",
            True,
        ),
        "d3format":
        utils.markdown(
            "d3 formatting string as defined [here]"
            "(https://github.com/d3/d3-format/blob/master/README.md#format). "
            "For instance, this default formatting applies in the Table "
            "visualization and allow for different metric to use different "
            "formats",
            True,
        ),
    }
    add_columns = edit_columns
    page_size = 500
    label_columns = {
        "metric_name": _("Metric"),
        "description": _("Description"),
        "verbose_name": _("Verbose Name"),
        "metric_type": _("Type"),
        "expression": _("SQL Expression"),
        "table": _("Table"),
        "d3format": _("D3 Format"),
        "warning_text": _("Warning Message"),
    }

    add_form_extra_fields = {
        "table":
        QuerySelectField(
            "Table",
            query_factory=lambda: db.session().query(models.SqlaTable),
            allow_blank=True,
            widget=Select2Widget(extra_classes="readonly"),
        )
    }

    edit_form_extra_fields = add_form_extra_fields
Example #4
 def refresh_datasources(  # pylint: disable=no-self-use
         self, refresh_all: bool = True) -> FlaskResponse:
     """endpoint that refreshes druid datasources metadata"""
     session = db.session()
     DruidCluster = ConnectorRegistry.sources[  # pylint: disable=invalid-name
         "druid"].cluster_class
     for cluster in session.query(DruidCluster).all():
         cluster_name = cluster.cluster_name
         valid_cluster = True
         try:
             cluster.refresh_datasources(refresh_all=refresh_all)
         except Exception as ex:  # pylint: disable=broad-except
             valid_cluster = False
             flash(
                 "Error while processing cluster '{}'\n{}".format(
                     cluster_name, utils.error_msg_from_exception(ex)),
                 "danger",
             )
             logger.exception(ex)
         if valid_cluster:
             cluster.metadata_last_refreshed = datetime.now()
             flash(
                 _("Refreshed metadata from cluster [{}]").format(
                     cluster.cluster_name),
                 "info",
             )
     session.commit()
     return redirect("/druiddatasourcemodelview/list/")
Example #5
def get_dashboard_extra_filters(slice_id: int,
                                dashboard_id: int) -> List[Dict[str, Any]]:
    session = db.session()
    dashboard = session.query(Dashboard).filter_by(
        id=dashboard_id).one_or_none()

    # is chart in this dashboard?
    if (dashboard is None or not dashboard.json_metadata
            or not dashboard.slices
            or not any([slc
                        for slc in dashboard.slices if slc.id == slice_id])):
        return []

    try:
        # does this dashboard have default filters?
        json_metadata = json.loads(dashboard.json_metadata)
        default_filters = json.loads(
            json_metadata.get("default_filters", "null"))
        if not default_filters:
            return []

        # are default filters applicable to the given slice?
        filter_scopes = json_metadata.get("filter_scopes", {})
        layout = json.loads(dashboard.position_json or "{}")

        if (isinstance(layout, dict) and isinstance(filter_scopes, dict)
                and isinstance(default_filters, dict)):
            return build_extra_filters(layout, filter_scopes, default_filters,
                                       slice_id)
    except json.JSONDecodeError:
        pass

    return []
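A hypothetical call site for the function above; the ids are placeholders. The
function returns the dashboard's applicable default filters as a list of dicts,
or an empty list when the chart is not on the dashboard or no defaults are set:

    extra_filters = get_dashboard_extra_filters(slice_id=42, dashboard_id=7)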
Example #6
 def refresh_datasources(self, refreshAll=True):
     """endpoint that refreshes druid datasources metadata"""
     session = db.session()
     DruidCluster = ConnectorRegistry.sources['druid'].cluster_class
     for cluster in session.query(DruidCluster).all():
         cluster_name = cluster.cluster_name
         valid_cluster = True
         try:
             cluster.refresh_datasources(refreshAll=refreshAll)
         except Exception as e:
             valid_cluster = False
             flash(
                 "Error while processing cluster '{}'\n{}".format(
                     cluster_name, utils.error_msg_from_exception(e)),
                 'danger')
             logging.exception(e)
             pass
         if valid_cluster:
             cluster.metadata_last_refreshed = datetime.now()
             flash(
                 _('Refreshed metadata from cluster [{}]').format(
                     cluster.cluster_name),
                 'info')
     session.commit()
     return redirect('/druiddatasourcemodelview/list/')
Example #7
    def test_get_dashboard_changed_on(self):
        session = db.session()
        dashboard = session.query(Dashboard).filter_by(
            slug="world_health").first()

        assert dashboard.changed_on == DashboardDAO.get_dashboard_changed_on(
            dashboard)
        assert dashboard.changed_on == DashboardDAO.get_dashboard_changed_on(
            "world_health")

        old_changed_on = dashboard.changed_on

        # freezegun doesn't work for some reason, so we need to sleep here :(
        time.sleep(1)
        data = dashboard.data
        positions = data["position_json"]
        data.update({"positions": positions})
        original_data = copy.deepcopy(data)

        data.update({"foo": "bar"})
        DashboardDAO.set_dash_metadata(dashboard, data)
        session.merge(dashboard)
        session.commit()
        assert old_changed_on < DashboardDAO.get_dashboard_changed_on(
            dashboard)

        DashboardDAO.set_dash_metadata(dashboard, original_data)
        session.merge(dashboard)
        session.commit()
Example #8
 def refresh_datasources(self, refreshAll=True):
     """endpoint that refreshes druid datasources metadata"""
     session = db.session()
     DruidCluster = ConnectorRegistry.sources['druid'].cluster_class
     for cluster in session.query(DruidCluster).all():
         cluster_name = cluster.cluster_name
         valid_cluster = True
         try:
             cluster.refresh_datasources(refreshAll=refreshAll)
         except Exception as e:
             valid_cluster = False
             flash(
                 "Error while processing cluster '{}'\n{}".format(
                     cluster_name, utils.error_msg_from_exception(e)),
                 'danger')
             logging.exception(e)
             pass
         if valid_cluster:
             cluster.metadata_last_refreshed = datetime.now()
             flash(
                 _('Refreshed metadata from cluster [{}]').format(
                     cluster.cluster_name),
                 'info')
     session.commit()
     return redirect('/druiddatasourcemodelview/list/')
Example #9
    def test_get_dashboard_changed_on(self):
        self.login(username="******")
        session = db.session()
        dashboard = session.query(Dashboard).filter_by(slug="world_health").first()

        changed_on = dashboard.changed_on.replace(microsecond=0)
        assert changed_on == DashboardDAO.get_dashboard_changed_on(dashboard)
        assert changed_on == DashboardDAO.get_dashboard_changed_on("world_health")

        old_changed_on = dashboard.changed_on

        # freezegun doesn't work for some reason, so we need to sleep here :(
        time.sleep(1)
        data = dashboard.data
        positions = data["position_json"]
        data.update({"positions": positions})
        original_data = copy.deepcopy(data)

        data.update({"foo": "bar"})
        DashboardDAO.set_dash_metadata(dashboard, data)
        session.merge(dashboard)
        session.commit()
        new_changed_on = DashboardDAO.get_dashboard_changed_on(dashboard)
        assert old_changed_on.replace(microsecond=0) < new_changed_on
        assert new_changed_on == DashboardDAO.get_dashboard_and_datasets_changed_on(
            dashboard
        )
        assert new_changed_on == DashboardDAO.get_dashboard_and_slices_changed_on(
            dashboard
        )

        DashboardDAO.set_dash_metadata(dashboard, original_data)
        session.merge(dashboard)
        session.commit()
Example #10
        def wrapper(*args, **kwargs):
            start_dttm = datetime.now()
            user_id = None
            if g.user:
                user_id = g.user.get_id()
            d = request.args.to_dict()
            post_data = request.form or {}
            d.update(post_data)
            d.update(kwargs)
            slice_id = d.get('slice_id', 0)
            try:
                slice_id = int(slice_id) if slice_id else 0
            except ValueError:
                slice_id = 0
            params = ""
            try:
                params = json.dumps(d)
            except Exception:
                pass
            stats_logger.incr(f.__name__)
            value = f(*args, **kwargs)

            sesh = db.session()
            log = cls(
                action=f.__name__,
                json=params,
                dashboard_id=d.get('dashboard_id') or None,
                slice_id=slice_id,
                duration_ms=(
                    datetime.now() - start_dttm).total_seconds() * 1000,
                referrer=request.referrer[:1000] if request.referrer else None,
                user_id=user_id)
            sesh.add(log)
            sesh.commit()
            return value
Example #11
    def get(cls, id_or_slug: str) -> Dashboard:
        session = db.session()
        qry = session.query(Dashboard)
        if id_or_slug.isdigit():
            qry = qry.filter_by(id=int(id_or_slug))
        else:
            qry = qry.filter_by(slug=id_or_slug)

        return qry.one_or_none()
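A short usage sketch of the classmethod above, assuming it lives on the
Dashboard model; the argument may be a numeric id or a slug, and None is
returned when nothing matches:

    dash_by_id = Dashboard.get("42")              # looked up by primary key
    dash_by_slug = Dashboard.get("world_health")  # looked up by slug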
Example #12
def get_session(nullpool):
    if nullpool:
        engine = sqlalchemy.create_engine(
            app.config.get('SQLALCHEMY_DATABASE_URI'), poolclass=NullPool)
        session_class = sessionmaker()
        session_class.configure(bind=engine)
        return session_class()
    session = db.session()
    session.commit()  # HACK
    return session
Example #13
def get_session(nullpool):
    if nullpool:
        engine = sqlalchemy.create_engine(
            app.config.get('SQLALCHEMY_DATABASE_URI'), poolclass=NullPool)
        session_class = sessionmaker()
        session_class.configure(bind=engine)
        return session_class()
    session = db.session()
    session.commit()  # HACK
    return session
Example #14
        def wrapper(*args, **kwargs):
            user_id = None
            if g.user:
                user_id = g.user.get_id()
            d = request.form.to_dict() or {}

            # request parameters can overwrite post body
            request_params = request.args.to_dict()
            d.update(request_params)
            d.update(kwargs)

            slice_id = d.get("slice_id")
            dashboard_id = d.get("dashboard_id")

            try:
                slice_id = int(
                    slice_id or json.loads(d.get("form_data")).get("slice_id"))
            except (ValueError, TypeError):
                slice_id = 0

            stats_logger.incr(f.__name__)
            start_dttm = datetime.now()
            value = f(*args, **kwargs)
            duration_ms = (datetime.now() - start_dttm).total_seconds() * 1000

            # bulk insert
            try:
                explode_by = d.get("explode")
                records = json.loads(d.get(explode_by))
            except Exception:
                records = [d]

            referrer = request.referrer[:1000] if request.referrer else None
            logs = []
            for record in records:
                try:
                    json_string = json.dumps(record)
                except Exception:
                    json_string = None
                log = cls(
                    action=f.__name__,
                    json=json_string,
                    dashboard_id=dashboard_id,
                    slice_id=slice_id,
                    duration_ms=duration_ms,
                    referrer=referrer,
                    user_id=user_id,
                )
                logs.append(log)

            sesh = db.session()
            sesh.bulk_save_objects(logs)
            sesh.commit()
            return value
Example #15
        def wrapper(*args, **kwargs):
            user_id = None
            if g.user:
                user_id = g.user.get_id()
            d = request.form.to_dict() or {}

            # request parameters can overwrite post body
            request_params = request.args.to_dict()
            d.update(request_params)
            d.update(kwargs)

            slice_id = d.get('slice_id')
            dashboard_id = d.get('dashboard_id')

            try:
                slice_id = int(
                    slice_id or json.loads(d.get('form_data')).get('slice_id'))
            except (ValueError, TypeError):
                slice_id = 0

            stats_logger.incr(f.__name__)
            start_dttm = datetime.now()
            value = f(*args, **kwargs)
            duration_ms = (datetime.now() - start_dttm).total_seconds() * 1000

            # bulk insert
            try:
                explode_by = d.get('explode')
                records = json.loads(d.get(explode_by))
            except Exception:
                records = [d]

            referrer = request.referrer[:1000] if request.referrer else None
            logs = []
            for record in records:
                try:
                    json_string = json.dumps(record)
                except Exception:
                    json_string = None
                log = cls(
                    action=f.__name__,
                    json=json_string,
                    dashboard_id=dashboard_id,
                    slice_id=slice_id,
                    duration_ms=duration_ms,
                    referrer=referrer,
                    user_id=user_id)
                logs.append(log)

            sesh = db.session()
            sesh.bulk_save_objects(logs)
            sesh.commit()
            return value
Example #16
class DruidMetricInlineView(CompactCRUDMixin, SupersetModelView):
    datamodel = SQLAInterface(models.DruidMetric)
    include_route_methods = RouteMethod.RELATED_VIEW_SET

    list_title = _("Metrics")
    show_title = _("Show Druid Metric")
    add_title = _("Add Druid Metric")
    edit_title = _("Edit Druid Metric")

    list_columns = ["metric_name", "verbose_name", "metric_type"]
    edit_columns = [
        "metric_name",
        "description",
        "verbose_name",
        "metric_type",
        "json",
        "datasource",
        "d3format",
        "warning_text",
    ]
    add_columns = edit_columns
    page_size = 500
    validators_columns = {"json": [validate_json]}
    description_columns = {
        "metric_type":
        utils.markdown(
            "use `postagg` as the metric type if you are defining a "
            "[Druid Post Aggregation]"
            "(http://druid.io/docs/latest/querying/post-aggregations.html)",
            True,
        )
    }
    label_columns = {
        "metric_name": _("Metric"),
        "description": _("Description"),
        "verbose_name": _("Verbose Name"),
        "metric_type": _("Type"),
        "json": _("JSON"),
        "datasource": _("Druid Datasource"),
        "warning_text": _("Warning Message"),
    }

    add_form_extra_fields = {
        "datasource":
        QuerySelectField(
            "Datasource",
            query_factory=lambda: db.session().query(models.DruidDatasource),
            allow_blank=True,
            widget=Select2Widget(extra_classes="readonly"),
        )
    }

    edit_form_extra_fields = add_form_extra_fields
Example #17
def get_dashboard(dashboard_id_or_slug: str,) -> Dashboard:
    session = db.session()
    qry = session.query(Dashboard)
    if dashboard_id_or_slug.isdigit():
        qry = qry.filter_by(id=int(dashboard_id_or_slug))
    else:
        qry = qry.filter_by(slug=dashboard_id_or_slug)
    dashboard = qry.one_or_none()

    if not dashboard:
        abort(404)

    return dashboard
Example #18
def refresh_druid(datasource, merge):
    """Refresh druid datasources"""
    session = db.session()
    from superset.connectors.druid.models import DruidCluster

    for cluster in session.query(DruidCluster).all():
        try:
            cluster.refresh_datasources(datasource_name=datasource, merge_flag=merge)
        except Exception as e:
            print("Error while processing cluster '{}'\n{}".format(cluster, str(e)))
            logging.exception(e)
        cluster.metadata_last_refreshed = datetime.now()
        print("Refreshed metadata from cluster " "[" + cluster.cluster_name + "]")
    session.commit()
Example #19
def scan_druid():
    """Scan druid datasources"""
    session = db.session()
    from superset.connectors.druid.models import DruidCluster
    for cluster in session.query(DruidCluster).all():
        try:
            cluster.refresh_datasources(refreshAll=True)
        except Exception as e:
            print("Error while scanning cluster '{}'\n{}".format(
                cluster, str(e)))
            logging.exception(e)
        cluster.metadata_last_refreshed = datetime.now()
        print('Refreshed metadata from cluster '
              '[' + cluster.cluster_name + ']')
    session.commit()
Example #20
def async_etl(etl_id=0):
    """Run Etl Sync Data From DataSource to DWH."""

    msq = 'Run Async ETL Id ({}) Task'.format(etl_id)
    print(msq)
    logger.info(msq)

    # db_session = app_manager.get_db().session()
    db_session = db.session()
    try:
        db_session.query(EtlTable).filter_by(id=etl_id).one().sync()
    except Exception as e:
        print(e)
        logger.exception(e)
        raise
Example #21
def refresh_druid(datasource):
    """Refresh druid datasources"""
    session = db.session()
    from superset import models
    for cluster in session.query(models.DruidCluster).all():
        try:
            cluster.refresh_datasources(datasource_name=datasource)
        except Exception as e:
            print(
                "Error while processing cluster '{}'\n{}".format(
                    cluster, str(e)))
            logging.exception(e)
        cluster.metadata_last_refreshed = datetime.now()
        print(
            "Refreshed metadata from cluster "
            "[" + cluster.cluster_name + "]")
    session.commit()
Example #22
def refresh_druid(datasource, merge):
    """Refresh druid datasources"""
    session = db.session()
    from superset.connectors.druid.models import DruidCluster
    for cluster in session.query(DruidCluster).all():
        try:
            cluster.refresh_datasources(datasource_name=datasource,
                                        merge_flag=merge)
        except Exception as e:
            print(
                "Error while processing cluster '{}'\n{}".format(
                    cluster, str(e)))
            logging.exception(e)
        cluster.metadata_last_refreshed = datetime.now()
        print(
            'Refreshed metadata from cluster '
            '[' + cluster.cluster_name + ']')
    session.commit()
Example #23
def add_slice_to_dashboard(request,
                           args,
                           datasource_type=None,
                           datasource_id=None):
    form_data = json.loads(args.get('form_data'))
    datasource_id = args.get('datasource_id')
    datasource_type = args.get('datasource_type')
    datasource_name = args.get('datasource_name')
    viz_type = form_data.get('viz_type')

    form_data['datasource'] = str(datasource_id) + '__' + datasource_type

    # On explore, merge legacy and extra filters into the form data
    utils.convert_legacy_filters_into_adhoc(form_data)
    utils.merge_extra_filters(form_data)
    """Save or overwrite a slice"""
    slice_name = args.get('slice_name')
    action = args.get('action')
    # saving slice
    slc = models.Slice(owners=[g.user] if g.user else [])
    slc.params = json.dumps(form_data, indent=2, sort_keys=True)
    slc.datasource_name = datasource_name
    slc.viz_type = form_data['viz_type']
    slc.datasource_type = datasource_type
    slc.datasource_id = datasource_id
    slc.slice_name = slice_name
    session = db.session()
    session.add(slc)
    session.commit()

    # adding slice to dashboard
    dash = (db.session.query(models.Dashboard).filter_by(
        id=int(args.get('save_to_dashboard_id'))).one())

    dash.slices.append(slc)
    db.session.commit()
    logging.info('Slice [' + slc.slice_name +
                 '] was added to dashboard id [ ' +
                 str(args.get('save_to_dashboard_id')) + ' ]')

    return {
        'form_data': slc.form_data,
        'slice': slc.data,
    }
Example #24
def session_scope(nullpool):
    """Provide a transactional scope around a series of operations."""
    if nullpool:
        engine = sqlalchemy.create_engine(
            app.config["SQLALCHEMY_DATABASE_URI"], poolclass=NullPool)
        session_class = sessionmaker()
        session_class.configure(bind=engine)
        session = session_class()
    else:
        session = db.session()
        session.commit()  # HACK

    try:
        yield session
        session.commit()
    except Exception as e:
        session.rollback()
        logger.exception(e)
        raise
    finally:
        session.close()
Example #25
def session_scope(nullpool):
    """Provide a transactional scope around a series of operations."""
    if nullpool:
        engine = sqlalchemy.create_engine(
            app.config.get('SQLALCHEMY_DATABASE_URI'), poolclass=NullPool)
        session_class = sessionmaker()
        session_class.configure(bind=engine)
        session = session_class()
    else:
        session = db.session()
        session.commit()  # HACK

    try:
        yield session
        session.commit()
    except Exception as e:
        session.rollback()
        logging.exception(e)
        raise
    finally:
        session.close()
Example #26
 def refresh_datasources(self):
     """endpoint that refreshes druid datasources metadata"""
     session = db.session()
     DruidCluster = ConnectorRegistry.sources['druid'].cluster_class
     for cluster in session.query(DruidCluster).all():
         cluster_name = cluster.cluster_name
         try:
             cluster.refresh_datasources()
         except Exception as e:
             flash(
                 "Error while processing cluster '{}'\n{}".format(
                     cluster_name, utils.error_msg_from_exception(e)),
                 "danger")
             logging.exception(e)
             return redirect('/druidclustermodelview/list/')
         cluster.metadata_last_refreshed = datetime.now()
         flash(
             "Refreshed metadata from cluster "
             "[" + cluster.cluster_name + "]", 'info')
     session.commit()
     return redirect("/druiddatasourcemodelview/list/")
Example #27
 def refresh_datasources(self):
     """endpoint that refreshes druid datasources metadata"""
     session = db.session()
     DruidCluster = ConnectorRegistry.sources['druid'].cluster_class
     for cluster in session.query(DruidCluster).all():
         cluster_name = cluster.cluster_name
         try:
             cluster.refresh_datasources()
         except Exception as e:
             flash(
                 "Error while processing cluster '{}'\n{}".format(
                     cluster_name, utils.error_msg_from_exception(e)),
                 "danger")
             logging.exception(e)
             return redirect('/druidclustermodelview/list/')
         cluster.metadata_last_refreshed = datetime.now()
         flash(
             "Refreshed metadata from cluster "
             "[" + cluster.cluster_name + "]",
             'info')
     session.commit()
     return redirect("/druiddatasourcemodelview/list/")
Example #28
 def refresh_datasources(self):
     """endpoint that refreshes elastic datasources metadata"""
     session = db.session()
     elastic_cluster = ConnectorRegistry.sources['elastic'].cluster_class
     for cluster in session.query(elastic_cluster).all():
         cluster_name = cluster.cluster_name
         try:
             cluster.refresh_datasources()
         except Exception as e:
             flash(
                 'Error while processing cluster \'{}\'\n{}'.format(
                     cluster_name, error_msg_from_exception(e)),
                 'danger')
             logging.exception(e)
             return redirect('/elasticclustermodelview/list/')
         cluster.metadata_last_refreshed = datetime.now()
         flash(
             'Refreshed metadata from cluster '
             '[' + cluster.cluster_name + ']',
             'info')
     session.commit()
     return redirect('/elasticdatasourcemodelview/list/')
Example #29
        def wrapper(*args, **kwargs):
            start_dttm = datetime.now()
            user_id = None
            if g.user:
                user_id = g.user.get_id()
            d = request.form.to_dict() or {}
            # request parameters can overwrite post body
            request_params = request.args.to_dict()
            d.update(request_params)
            d.update(kwargs)
            slice_id = d.get('slice_id')

            try:
                slice_id = int(
                    slice_id or json.loads(d.get('form_data')).get('slice_id'))
            except (ValueError, TypeError):
                slice_id = 0

            params = ''
            try:
                params = json.dumps(d)
            except Exception:
                pass
            stats_logger.incr(f.__name__)
            value = f(*args, **kwargs)
            sesh = db.session()
            log = cls(
                action=f.__name__,
                json=params,
                dashboard_id=d.get('dashboard_id'),
                slice_id=slice_id,
                duration_ms=(datetime.now() - start_dttm).total_seconds() *
                1000,
                referrer=request.referrer[:1000] if request.referrer else None,
                user_id=user_id)
            sesh.add(log)
            sesh.commit()
            return value
Example #30
    def run(self):
        session = db.session()

        # Editor
        alpha = session.query(Role).filter(Role.name == 'Alpha').first()
        editor = session.query(Role).filter(Role.name == 'Editor').first()
        if not editor:
            editor = Role()
        editor.name = 'Editor'
        editor.permissions = alpha.permissions
        print('\nCopying Alpha role to Editor...')
        SQLAInterface(Role, session).add(editor)
        print('Generating custom Editor permissions from SQL...')
        db.engine.execute(EDITOR_SQL)
        print('Editor role created successfully.\n')

        # Viewer
        gamma = session.query(Role).filter(Role.name == 'Gamma').first()
        viewer = session.query(Role).filter(Role.name == 'Viewer').first()
        if not viewer:
            viewer = Role()
        viewer.name = 'Viewer'
        viewer.permissions = gamma.permissions
        print('Copying Gamma role to Viewer...')
        SQLAInterface(Role, session).add(viewer)
        print('Generating custom Viewer permissions from SQL...')
        db.engine.execute(VIEWER_SQL)
        print('Viewer role created successfully.')

        engine = sqlalchemy.create_engine(SQLALCHEMY_ROOT_DATABASE_URI)
        root_conn = engine.raw_connection()
        with root_conn.cursor() as cursor:
            print('\nGranting all privileges to the superset db user...')
            grant = '''
            GRANT ALL PRIVILEGES ON *.* TO 'superset'@'%';
            FLUSH PRIVILEGES;
            '''
            cursor.execute(grant)
Example #31
        def wrapper(*args, **kwargs):
            start_dttm = datetime.now()
            user_id = None
            if g.user:
                user_id = g.user.get_id()
            d = request.form.to_dict() or {}
            # request parameters can overwrite post body
            request_params = request.args.to_dict()
            d.update(request_params)
            d.update(kwargs)
            slice_id = d.get('slice_id')

            try:
                slice_id = int(
                    slice_id or json.loads(d.get('form_data')).get('slice_id'))
            except (ValueError, TypeError):
                slice_id = 0

            params = ''
            try:
                params = json.dumps(d)
            except Exception:
                pass
            stats_logger.incr(f.__name__)
            value = f(*args, **kwargs)
            sesh = db.session()
            log = cls(
                action=f.__name__,
                json=params,
                dashboard_id=d.get('dashboard_id'),
                slice_id=slice_id,
                duration_ms=(
                    datetime.now() - start_dttm).total_seconds() * 1000,
                referrer=request.referrer[:1000] if request.referrer else None,
                user_id=user_id)
            sesh.add(log)
            sesh.commit()
            return value
Example #32
    def _allow_csv_upload_databases(self) -> list:
        """ Get all databases which allow csv upload as database dto
        :returns list of database dto
        """
        databases = db.session().query(Database).filter_by(
            allow_csv_upload=True).all()
        permitted_databases: list = []
        for database in databases:
            if security_manager.database_access(database):
                permitted_databases.append(database)

        databases_json = [
            DatabaseDto(NEW_DATABASE_ID, "In a new database", [])
        ]
        for database in permitted_databases:
            databases_json.append(
                DatabaseDto(
                    database.id,
                    database.name,
                    json.loads(
                        database.extra)["schemas_allowed_for_csv_upload"],
                ))
        return databases_json
Example #33
def run_etl():

    # db_session = app_manager.get_db().session()
    db_session = db.session()

    msg = 'Scheduler Try Run ETL Async Tasks'
    print(msg)
    logger.info(msg)

    etl_tasks = db_session.query(EtlTable).with_for_update(
        skip_locked=True).filter(EtlTable.is_active.is_(True),
                                 EtlTable.is_valid.is_(True),
                                 EtlTable.is_scheduled.isnot(True),
                                 EtlTable.sync_periodic != 0,
                                 EtlTable.sync_next_time < datetime.utcnow())

    for etl_task in etl_tasks:
        etl_task.is_scheduled = True
        logger.info(etl_task)
        db_session.merge(etl_task)
        async_etl.delay(etl_id=etl_task.id)
    db_session.commit()
    return True
Example #34
 def get_dash_by_slug(self, dash_slug):
     sesh = db.session()
     return sesh.query(models.Dashboard).filter_by(
         slug=dash_slug).first()
Example #35
class TableColumnInlineView(CompactCRUDMixin, SupersetModelView):  # noqa
    datamodel = SQLAInterface(models.TableColumn)

    list_title = _('Columns')
    show_title = _('Show Column')
    add_title = _('Add Column')
    edit_title = _('Edit Column')

    can_delete = False
    list_widget = ListWidgetWithCheckboxes
    edit_columns = [
        'column_name', 'verbose_name', 'description', 'type', 'groupby',
        'filterable', 'table', 'expression', 'is_dttm', 'python_date_format',
        'database_expression'
    ]
    add_columns = edit_columns
    list_columns = [
        'column_name', 'verbose_name', 'type', 'groupby', 'filterable',
        'is_dttm'
    ]
    page_size = 500
    description_columns = {
        'is_dttm':
        _('Whether to make this column available as a '
          '[Time Granularity] option, column has to be DATETIME or '
          'DATETIME-like'),
        'filterable':
        _('Whether this column is exposed in the `Filters` section '
          'of the explore view.'),
        'type':
        _('The data type that was inferred by the database. '
          'It may be necessary to input a type manually for '
          'expression-defined columns in some cases. In most cases '
          'users should not need to alter this.'),
        'expression':
        utils.markdown(
            'a valid, *non-aggregating* SQL expression as supported by the '
            'underlying backend. Example: `substr(name, 1, 1)`', True),
        'python_date_format':
        utils.markdown(
            Markup(
                'The pattern of timestamp format, use '
                '<a href="https://docs.python.org/2/library/'
                'datetime.html#strftime-strptime-behavior">'
                'python datetime string pattern</a> '
                'expression. If time is stored in epoch '
                'format, put `epoch_s` or `epoch_ms`. Leave `Database Expression` '
                'below empty if timestamp is stored in '
                'String or Integer(epoch) type'), True),
        'database_expression':
        utils.markdown(
            'The database expression to cast internal datetime '
            'constants to database date/timestamp type according to the DBAPI. '
            'The expression should follow the pattern of '
            '%Y-%m-%d %H:%M:%S, based on different DBAPI. '
            'The string should be a python string formatter \n'
            "`Ex: TO_DATE('{}', 'YYYY-MM-DD HH24:MI:SS')` for Oracle "
            'Superset uses default expression based on DB URI if this '
            'field is blank.', True),
    }
    label_columns = {
        'column_name': _('Column'),
        'verbose_name': _('Verbose Name'),
        'description': _('Description'),
        'groupby': _('Groupable'),
        'filterable': _('Filterable'),
        'table': _('Table'),
        'expression': _('Expression'),
        'is_dttm': _('Is temporal'),
        'python_date_format': _('Datetime Format'),
        'database_expression': _('Database Expression'),
        'type': _('Type'),
    }

    add_form_extra_fields = {
        'table':
        QuerySelectField(
            'Table',
            query_factory=lambda: db.session().query(models.SqlaTable),
            allow_blank=True,
            widget=Select2Widget(extra_classes='readonly'),
        ),
    }

    edit_form_extra_fields = add_form_extra_fields
Example #36
class TableModelView(DatasourceModelView, DeleteMixin,
                     YamlExportMixin):  # noqa
    datamodel = SQLAInterface(models.SqlaTable)

    list_title = _('Tables')
    show_title = _('Show Table')
    add_title = _('Import a table definition')
    edit_title = _('Edit Table')

    list_columns = ['link', 'database_name', 'changed_by_', 'modified']
    order_columns = ['modified']
    add_columns = ['database', 'schema', 'table_name']
    edit_columns = [
        'table_name',
        'sql',
        'filter_select_enabled',
        'fetch_values_predicate',
        'database',
        'schema',
        'description',
        'owners',
        'main_dttm_col',
        'default_endpoint',
        'offset',
        'cache_timeout',
        'is_sqllab_view',
        'template_params',
    ]
    base_filters = [['id', DatasourceFilter, lambda: []]]
    show_columns = edit_columns + ['perm', 'slices']
    related_views = [TableColumnInlineView, SqlMetricInlineView]
    base_order = ('changed_on', 'desc')
    search_columns = (
        'database',
        'schema',
        'table_name',
        'owners',
        'is_sqllab_view',
    )
    description_columns = {
        'slices':
        _('The list of charts associated with this table. By '
          'altering this datasource, you may change how these associated '
          'charts behave. '
          'Also note that charts need to point to a datasource, so '
          'this form will fail at saving if removing charts from a '
          'datasource. If you want to change the datasource for a chart, '
          "overwrite the chart from the 'explore view'"),
        'offset':
        _('Timezone offset (in hours) for this datasource'),
        'table_name':
        _('Name of the table that exists in the source database'),
        'schema':
        _('Schema, as used only in some databases like Postgres, Redshift '
          'and DB2'),
        'description':
        Markup(
            'Supports <a href="https://daringfireball.net/projects/markdown/">'
            'markdown</a>'),
        'sql':
        _(
            'This field acts as a Superset view, meaning that Superset will '
            'run a query against this string as a subquery.', ),
        'fetch_values_predicate':
        _(
            'Predicate applied when fetching distinct values to '
            'populate the filter control component. Supports '
            'jinja template syntax. Applies only when '
            '`Enable Filter Select` is on.', ),
        'default_endpoint':
        _('Redirects to this endpoint when clicking on the table '
          'from the table list'),
        'filter_select_enabled':
        _("Whether to populate the filter's dropdown in the explore "
          "view's filter section with a list of distinct values fetched "
          'from the backend on the fly'),
        'is_sqllab_view':
        _("Whether the table was generated by the 'Visualize' flow "
          'in SQL Lab'),
        'template_params':
        _('A set of parameters that become available in the query using '
          'Jinja templating syntax'),
        'cache_timeout':
        _('Duration (in seconds) of the caching timeout for this table. '
          'A timeout of 0 indicates that the cache never expires. '
          'Note this defaults to the database timeout if undefined.'),
    }
    label_columns = {
        'slices': _('Associated Charts'),
        'link': _('Table'),
        'changed_by_': _('Changed By'),
        'database': _('Database'),
        'database_name': _('Database'),
        'changed_on_': _('Last Changed'),
        'filter_select_enabled': _('Enable Filter Select'),
        'schema': _('Schema'),
        'default_endpoint': _('Default Endpoint'),
        'offset': _('Offset'),
        'cache_timeout': _('Cache Timeout'),
        'table_name': _('Table Name'),
        'fetch_values_predicate': _('Fetch Values Predicate'),
        'owners': _('Owners'),
        'main_dttm_col': _('Main Datetime Column'),
        'description': _('Description'),
        'is_sqllab_view': _('SQL Lab View'),
        'template_params': _('Template parameters'),
        'modified': _('Modified'),
    }

    edit_form_extra_fields = {
        'database':
        QuerySelectField(
            'Database',
            query_factory=lambda: db.session().query(models.Database),
            widget=Select2Widget(extra_classes='readonly'),
        ),
    }

    def pre_add(self, table):
        with db.session.no_autoflush:
            table_query = db.session.query(models.SqlaTable).filter(
                models.SqlaTable.table_name == table.table_name,
                models.SqlaTable.schema == table.schema,
                models.SqlaTable.database_id == table.database.id)
            if db.session.query(table_query.exists()).scalar():
                raise Exception(get_datasource_exist_error_msg(
                    table.full_name))

        # Fail before adding if the table can't be found
        try:
            table.get_sqla_table_object()
        except Exception as e:
            logger.exception(f'Got an error in pre_add for {table.name}')
            raise Exception(
                _('Table [{}] could not be found, '
                  'please double check your '
                  'database connection, schema, and '
                  'table name, error: {}').format(table.name, str(e)))

    def post_add(self, table, flash_message=True):
        table.fetch_metadata()
        security_manager.add_permission_view_menu('datasource_access',
                                                  table.get_perm())
        if table.schema:
            security_manager.add_permission_view_menu('schema_access',
                                                      table.schema_perm)

        if flash_message:
            flash(
                _('The table was created. '
                  'As part of this two-phase configuration '
                  'process, you should now click the edit button by '
                  'the new table to configure it.'), 'info')

    def post_update(self, table):
        self.post_add(table, flash_message=False)

    def _delete(self, pk):
        DeleteMixin._delete(self, pk)

    @expose('/edit/<pk>', methods=['GET', 'POST'])
    @has_access
    def edit(self, pk):
        """Simple hack to redirect to explore view after saving"""
        resp = super(TableModelView, self).edit(pk)
        if isinstance(resp, str):
            return resp
        return redirect('/superset/explore/table/{}/'.format(pk))

    @action('refresh', __('Refresh Metadata'), __('Refresh column metadata'),
            'fa-refresh')
    def refresh(self, tables):
        if not isinstance(tables, list):
            tables = [tables]
        successes = []
        failures = []
        for t in tables:
            try:
                t.fetch_metadata()
                successes.append(t)
            except Exception:
                failures.append(t)

        if len(successes) > 0:
            success_msg = _(
                'Metadata refreshed for the following table(s): %(tables)s',
                tables=', '.join([t.table_name for t in successes]))
            flash(success_msg, 'info')
        if len(failures) > 0:
            failure_msg = _(
                'Unable to retrieve metadata for the following table(s): %(tables)s',
                tables=', '.join([t.table_name for t in failures]))
            flash(failure_msg, 'danger')

        return redirect('/tablemodelview/list/')
Example #37
class SqlMetricInlineView(CompactCRUDMixin, SupersetModelView):  # noqa
    datamodel = SQLAInterface(models.SqlMetric)

    list_title = _('Metrics')
    show_title = _('Show Metric')
    add_title = _('Add Metric')
    edit_title = _('Edit Metric')

    list_columns = ['metric_name', 'verbose_name', 'metric_type']
    edit_columns = [
        'metric_name', 'description', 'verbose_name', 'metric_type',
        'expression', 'table', 'd3format', 'is_restricted', 'warning_text'
    ]
    description_columns = {
        'expression':
        utils.markdown(
            'a valid, *aggregating* SQL expression as supported by the '
            'underlying backend. Example: `count(DISTINCT userid)`', True),
        'is_restricted':
        _('Whether access to this metric is restricted '
          'to certain roles. Only roles with the permission '
          "'metric access on XXX (the name of this metric)' "
          'are allowed to access this metric'),
        'd3format':
        utils.markdown(
            'd3 formatting string as defined [here]'
            '(https://github.com/d3/d3-format/blob/master/README.md#format). '
            'For instance, this default formatting applies in the Table '
            'visualization and allows different metrics to use different '
            'formats',
            True,
        ),
    }
    add_columns = edit_columns
    page_size = 500
    label_columns = {
        'metric_name': _('Metric'),
        'description': _('Description'),
        'verbose_name': _('Verbose Name'),
        'metric_type': _('Type'),
        'expression': _('SQL Expression'),
        'table': _('Table'),
        'd3format': _('D3 Format'),
        'is_restricted': _('Is Restricted'),
        'warning_text': _('Warning Message'),
    }

    add_form_extra_fields = {
        'table':
        QuerySelectField(
            'Table',
            query_factory=lambda: db.session().query(models.SqlaTable),
            allow_blank=True,
            widget=Select2Widget(extra_classes='readonly'),
        ),
    }

    edit_form_extra_fields = add_form_extra_fields

    def post_add(self, metric):
        if metric.is_restricted:
            security_manager.add_permission_view_menu('metric_access',
                                                      metric.get_perm())

    def post_update(self, metric):
        if metric.is_restricted:
            security_manager.add_permission_view_menu('metric_access',
                                                      metric.get_perm())
Example #38
def get_sql_results(self, query_id, return_results=True, store_results=False):
    """Executes the sql query returns the results."""
    if not self.request.called_directly:
        engine = sqlalchemy.create_engine(
            app.config.get('SQLALCHEMY_DATABASE_URI'), poolclass=NullPool)
        session_class = sessionmaker()
        session_class.configure(bind=engine)
        session = session_class()
    else:
        session = db.session()
        session.commit()  # HACK
    query = session.query(models.Query).filter_by(id=query_id).one()
    database = query.database
    db_engine_spec = database.db_engine_spec

    def handle_error(msg):
        """Local method handling error while processing the SQL"""
        query.error_message = msg
        query.status = QueryStatus.FAILED
        query.tmp_table_name = None
        session.commit()
        raise Exception(query.error_message)

    if store_results and not results_backend:
        handle_error("Results backend isn't configured.")

    # Limit enforced only for retrieving the data, not for the CTA queries.
    superset_query = SupersetQuery(query.sql)
    executed_sql = superset_query.stripped()
    if not superset_query.is_select() and not database.allow_dml:
        handle_error(
            "Only `SELECT` statements are allowed against this database")
    if query.select_as_cta:
        if not superset_query.is_select():
            handle_error(
                "Only `SELECT` statements can be used with the CREATE TABLE "
                "feature.")
        if not query.tmp_table_name:
            start_dttm = datetime.fromtimestamp(query.start_time)
            query.tmp_table_name = 'tmp_{}_table_{}'.format(
                query.user_id,
                start_dttm.strftime('%Y_%m_%d_%H_%M_%S'))
        executed_sql = superset_query.as_create_table(query.tmp_table_name)
        query.select_as_cta_used = True
    elif (
            query.limit and superset_query.is_select() and
            db_engine_spec.limit_method == LimitMethod.WRAP_SQL):
        executed_sql = database.wrap_sql_limit(executed_sql, query.limit)
        query.limit_used = True
    engine = database.get_sqla_engine(schema=query.schema)
    try:
        template_processor = get_template_processor(
            database=database, query=query)
        executed_sql = template_processor.process_template(executed_sql)
        executed_sql = db_engine_spec.sql_preprocessor(executed_sql)
    except Exception as e:
        logging.exception(e)
        msg = "Template rendering failed: " + utils.error_msg_from_exception(e)
        handle_error(msg)

    query.executed_sql = executed_sql
    logging.info("Running query: \n{}".format(executed_sql))
    try:
        result_proxy = engine.execute(query.executed_sql, schema=query.schema)
    except Exception as e:
        logging.exception(e)
        handle_error(db_engine_spec.extract_error_message(e))

    cursor = result_proxy.cursor
    query.status = QueryStatus.RUNNING
    session.flush()
    db_engine_spec.handle_cursor(cursor, query, session)

    cdf = None
    if result_proxy.cursor:
        column_names = [col[0] for col in result_proxy.cursor.description]
        column_names = dedup(column_names)
        if db_engine_spec.limit_method == LimitMethod.FETCH_MANY:
            data = result_proxy.fetchmany(query.limit)
        else:
            data = result_proxy.fetchall()
        cdf = dataframe.SupersetDataFrame(
            pd.DataFrame(data, columns=column_names))

    query.rows = result_proxy.rowcount
    query.progress = 100
    query.status = QueryStatus.SUCCESS
    if query.rows == -1 and cdf:
        # Presto doesn't provide result_proxy.row_count
        query.rows = cdf.size
    if query.select_as_cta:
        query.select_sql = '{}'.format(database.select_star(
            query.tmp_table_name,
            limit=query.limit,
            schema=database.force_ctas_schema
        ))
    query.end_time = utils.now_as_float()
    session.flush()

    payload = {
        'query_id': query.id,
        'status': query.status,
        'data': [],
    }
    payload['data'] = cdf.data if cdf else []
    payload['columns'] = cdf.columns_dict if cdf else []
    payload['query'] = query.to_dict()
    payload = json.dumps(payload, default=utils.json_iso_dttm_ser)

    if store_results:
        key = '{}'.format(uuid.uuid4())
        logging.info("Storing results in results backend, key: {}".format(key))
        results_backend.set(key, zlib.compress(payload))
        query.results_key = key

    session.flush()
    session.commit()

    if return_results:
        return payload
Example #39
def get_sql_results(self, query_id, return_results=True, store_results=False):
    """Executes the sql query returns the results."""
    if not self.request.called_directly:
        engine = sqlalchemy.create_engine(
            app.config.get('SQLALCHEMY_DATABASE_URI'), poolclass=NullPool)
        session_class = sessionmaker()
        session_class.configure(bind=engine)
        session = session_class()
    else:
        session = db.session()
        session.commit()  # HACK
    query = session.query(models.Query).filter_by(id=query_id).one()
    database = query.database
    executed_sql = query.sql.strip().strip(';')
    db_engine_spec = database.db_engine_spec

    def handle_error(msg):
        """Local method handling error while processing the SQL"""
        query.error_message = msg
        query.status = QueryStatus.FAILED
        query.tmp_table_name = None
        session.commit()
        raise Exception(query.error_message)

    # Limit enforced only for retrieving the data, not for the CTA queries.
    is_select = is_query_select(executed_sql)
    if not is_select and not database.allow_dml:
        handle_error(
            "Only `SELECT` statements are allowed against this database")
    if query.select_as_cta:
        if not is_select:
            handle_error(
                "Only `SELECT` statements can be used with the CREATE TABLE "
                "feature.")
        if not query.tmp_table_name:
            start_dttm = datetime.fromtimestamp(query.start_time)
            query.tmp_table_name = 'tmp_{}_table_{}'.format(
                query.user_id,
                start_dttm.strftime('%Y_%m_%d_%H_%M_%S'))
        executed_sql = create_table_as(
            executed_sql, query.tmp_table_name, database.force_ctas_schema)
        query.select_as_cta_used = True
    elif (
            query.limit and is_select and
            db_engine_spec.limit_method == LimitMethod.WRAP_SQL):
        executed_sql = database.wrap_sql_limit(executed_sql, query.limit)
        query.limit_used = True
    engine = database.get_sqla_engine(schema=query.schema)
    try:
        template_processor = get_template_processor(
            database=database, query=query)
        executed_sql = template_processor.process_template(executed_sql)
    except Exception as e:
        logging.exception(e)
        msg = "Template rendering failed: " + utils.error_msg_from_exception(e)
        handle_error(msg)
    try:
        query.executed_sql = executed_sql
        logging.info("Running query: \n{}".format(executed_sql))
        result_proxy = engine.execute(query.executed_sql, schema=query.schema)
    except Exception as e:
        logging.exception(e)
        handle_error(utils.error_msg_from_exception(e))

    cursor = result_proxy.cursor
    query.status = QueryStatus.RUNNING
    session.flush()
    db_engine_spec.handle_cursor(cursor, query, session)

    cdf = None
    if result_proxy.cursor:
        column_names = [col[0] for col in result_proxy.cursor.description]
        if db_engine_spec.limit_method == LimitMethod.FETCH_MANY:
            data = result_proxy.fetchmany(query.limit)
        else:
            data = result_proxy.fetchall()
        cdf = dataframe.SupersetDataFrame(
            pd.DataFrame(data, columns=column_names))

    query.rows = result_proxy.rowcount
    query.progress = 100
    query.status = QueryStatus.SUCCESS
    if query.rows == -1 and cdf:
        # Presto doesn't provide result_proxy.row_count
        query.rows = cdf.size
    if query.select_as_cta:
        query.select_sql = '{}'.format(database.select_star(
            query.tmp_table_name, limit=query.limit))
    query.end_time = utils.now_as_float()
    session.flush()

    payload = {
        'query_id': query.id,
        'status': query.status,
        'data': [],
    }
    payload['data'] = cdf.data if cdf else []
    payload['columns'] = cdf.columns_dict if cdf else []
    payload['query'] = query.to_dict()
    payload = json.dumps(payload, default=utils.json_iso_dttm_ser)

    if store_results and results_backend:
        key = '{}'.format(uuid.uuid4())
        logging.info("Storing results in results backend, key: {}".format(key))
        results_backend.set(key, zlib.compress(payload))
        query.results_key = key

    session.flush()
    session.commit()

    if return_results:
        return payload
Example #40
 def get_dash_by_slug(self, dash_slug):
     sesh = db.session()
     return sesh.query(Dashboard).filter_by(slug=dash_slug).first()
Example #41
def sync_role_definitions():
    """Inits the Superset application with security roles and such"""
    logging.info("Syncing role definition")

    # Creating default roles
    alpha = sm.add_role("Alpha")
    admin = sm.add_role("Admin")
    gamma = sm.add_role("Gamma")
    public = sm.add_role("Public")
    sql_lab = sm.add_role("sql_lab")
    granter = sm.add_role("granter")

    get_or_create_main_db()

    # Global perms
    sm.add_permission_view_menu(
        'all_datasource_access', 'all_datasource_access')
    sm.add_permission_view_menu('all_database_access', 'all_database_access')

    perms = db.session.query(ab_models.PermissionView).all()
    perms = [p for p in perms if p.permission and p.view_menu]

    logging.info("Syncing admin perms")
    for p in perms:
        # admin has all_database_access and all_datasource_access
        if is_user_defined_permission(p):
            sm.del_permission_role(admin, p)
        else:
            sm.add_permission_role(admin, p)

    logging.info("Syncing alpha perms")
    for p in perms:
        # alpha has all_database_access and all_datasource_access
        if is_user_defined_permission(p):
            sm.del_permission_role(alpha, p)
        elif (
                (
                    p.view_menu.name not in ADMIN_ONLY_VIEW_MENUES and
                    p.permission.name not in ADMIN_ONLY_PERMISSIONS
                ) or
                (p.permission.name, p.view_menu.name) in READ_ONLY_PRODUCT
        ):
            sm.add_permission_role(alpha, p)
        else:
            sm.del_permission_role(alpha, p)

    logging.info("Syncing gamma perms and public if specified")
    PUBLIC_ROLE_LIKE_GAMMA = conf.get('PUBLIC_ROLE_LIKE_GAMMA', False)
    for p in perms:
        if (
                (
                    p.view_menu.name not in ADMIN_ONLY_VIEW_MENUES and
                    p.permission.name not in ADMIN_ONLY_PERMISSIONS and
                    p.permission.name not in ALPHA_ONLY_PERMISSIONS
                ) or
                (p.permission.name, p.view_menu.name) in READ_ONLY_PRODUCT
        ):
            sm.add_permission_role(gamma, p)
            if PUBLIC_ROLE_LIKE_GAMMA:
                sm.add_permission_role(public, p)
        else:
            sm.del_permission_role(gamma, p)
            sm.del_permission_role(public, p)

    logging.info("Syncing sql_lab perms")
    for p in perms:
        if (
                p.view_menu.name in {'SQL Lab'} or
                p.permission.name in {
                    'can_sql_json', 'can_csv', 'can_search_queries'}
        ):
            sm.add_permission_role(sql_lab, p)
        else:
            sm.del_permission_role(sql_lab, p)

    logging.info("Syncing granter perms")
    for p in perms:
        if (
                p.permission.name in {
                    'can_override_role_permissions', 'can_approve'}
        ):
            sm.add_permission_role(granter, p)
        else:
            sm.del_permission_role(granter, p)

    logging.info("Making sure all data source perms have been created")
    session = db.session()
    datasources = session.query(models.SqlaTable).all()
    datasources += session.query(models.DruidDatasource).all()
    for datasource in datasources:
        perm = datasource.get_perm()
        sm.add_permission_view_menu('datasource_access', perm)
        if perm != datasource.perm:
            datasource.perm = perm

    logging.info("Making sure all database perms have been created")
    databases = session.query(models.Database).all()
    for database in databases:
        perm = database.get_perm()
        if perm != database.perm:
            database.perm = perm
        sm.add_permission_view_menu('database_access', perm)
    session.commit()

    logging.info("Making sure all metrics perms exist")
    models.init_metrics_perm()
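
A hedged wiring sketch showing how the sync could be exposed as a one-off CLI command; the `init-roles` command name and the import path are assumptions, not the project's actual entry point:

import click


@click.command('init-roles')
def init_roles():
    """Re-sync the built-in roles (Admin, Alpha, Gamma, Public, sql_lab, granter)."""
    # Import path is an assumption; point it at wherever
    # sync_role_definitions is actually defined.
    from superset.security import sync_role_definitions
    sync_role_definitions()


if __name__ == '__main__':
    init_roles()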
Exemplo n.º 42
0
def get_sql_results(self, query_id, return_results=True, store_results=False):
    """Executes the sql query returns the results."""
    if not self.request.called_directly:
        engine = sqlalchemy.create_engine(
            app.config.get('SQLALCHEMY_DATABASE_URI'), poolclass=NullPool)
        session_class = sessionmaker()
        session_class.configure(bind=engine)
        session = session_class()
    else:
        session = db.session()
        session.commit()  # HACK
    try:
        query = session.query(models.Query).filter_by(id=query_id).one()
    except Exception as e:
        logging.error("Query with id `{}` could not be retrieved".format(query_id))
        logging.error("Sleeping for a sec and retrying...")
        # Nasty hack to get around a race condition where the worker
        # cannot find the query it's supposed to run
        sleep(1)
        query = session.query(models.Query).filter_by(id=query_id).one()

    database = query.database
    db_engine_spec = database.db_engine_spec
    db_engine_spec.patch()

    def handle_error(msg):
        """Local method handling error while processing the SQL"""
        query.error_message = msg
        query.status = QueryStatus.FAILED
        query.tmp_table_name = None
        session.commit()
        raise Exception(query.error_message)

    if store_results and not results_backend:
        handle_error("Results backend isn't configured.")

    # Limit enforced only for retrieving the data, not for the CTA queries.
    superset_query = SupersetQuery(query.sql)
    executed_sql = superset_query.stripped()
    if not superset_query.is_select() and not database.allow_dml:
        handle_error(
            "Only `SELECT` statements are allowed against this database")
    if query.select_as_cta:
        if not superset_query.is_select():
            handle_error(
                "Only `SELECT` statements can be used with the CREATE TABLE "
                "feature.")
        if not query.tmp_table_name:
            start_dttm = datetime.fromtimestamp(query.start_time)
            query.tmp_table_name = 'tmp_{}_table_{}'.format(
                query.user_id,
                start_dttm.strftime('%Y_%m_%d_%H_%M_%S'))
        executed_sql = superset_query.as_create_table(query.tmp_table_name)
        query.select_as_cta_used = True
    elif (
            query.limit and superset_query.is_select() and
            db_engine_spec.limit_method == LimitMethod.WRAP_SQL):
        executed_sql = database.wrap_sql_limit(executed_sql, query.limit)
        query.limit_used = True
    try:
        template_processor = get_template_processor(
            database=database, query=query)
        executed_sql = template_processor.process_template(executed_sql)
        executed_sql = db_engine_spec.sql_preprocessor(executed_sql)
    except Exception as e:
        logging.exception(e)
        msg = "Template rendering failed: " + utils.error_msg_from_exception(e)
        handle_error(msg)

    query.executed_sql = executed_sql
    query.status = QueryStatus.RUNNING
    query.start_running_time = utils.now_as_float()
    session.merge(query)
    session.commit()
    logging.info("Set query to 'running'")

    engine = database.get_sqla_engine(schema=query.schema)
    conn = engine.raw_connection()
    cursor = conn.cursor()
    logging.info("Running query: \n{}".format(executed_sql))
    try:
        logging.info(query.executed_sql)
        cursor.execute(
            query.executed_sql, **db_engine_spec.cursor_execute_kwargs)
    except Exception as e:
        logging.exception(e)
        conn.close()
        handle_error(db_engine_spec.extract_error_message(e))

    try:
        logging.info("Handling cursor")
        db_engine_spec.handle_cursor(cursor, query, session)
        logging.info("Fetching data: {}".format(query.to_dict()))
        data = db_engine_spec.fetch_data(cursor, query.limit)
    except Exception as e:
        logging.exception(e)
        conn.close()
        handle_error(db_engine_spec.extract_error_message(e))

    conn.commit()
    conn.close()

    if query.status == utils.QueryStatus.STOPPED:
        return json.dumps({
            'query_id': query.id,
            'status': query.status,
            'query': query.to_dict(),
        }, default=utils.json_iso_dttm_ser)

    column_names = (
        [col[0] for col in cursor.description] if cursor.description else [])
    column_names = dedup(column_names)
    cdf = dataframe.SupersetDataFrame(pd.DataFrame(
        list(data), columns=column_names))

    query.rows = cdf.size
    query.progress = 100
    query.status = QueryStatus.SUCCESS
    if query.select_as_cta:
        query.select_sql = '{}'.format(database.select_star(
            query.tmp_table_name,
            limit=query.limit,
            schema=database.force_ctas_schema
        ))
    query.end_time = utils.now_as_float()
    session.merge(query)
    session.flush()

    payload = {
        'query_id': query.id,
        'status': query.status,
        'data': cdf.data if cdf.data else [],
        'columns': cdf.columns if cdf.columns else [],
        'query': query.to_dict(),
    }
    payload = json.dumps(payload, default=utils.json_iso_dttm_ser)

    if store_results:
        key = '{}'.format(uuid.uuid4())
        logging.info("Storing results in results backend, key: {}".format(key))
        # json.dumps returns text; zlib.compress needs bytes
        results_backend.set(key, zlib.compress(payload.encode('utf-8')))
        query.results_key = key

    session.merge(query)
    session.commit()

    if return_results:
        return payload
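
A hedged dispatch sketch for the function above, assuming it is registered as a bound Celery task named `get_sql_results` (so `.delay()` is available); `dispatch_query` is a hypothetical helper:

def dispatch_query(query, run_async):
    """Hypothetical helper choosing between async and blocking execution."""
    if run_async:
        # Fire and forget: the worker persists the payload via the results backend.
        get_sql_results.delay(query.id, return_results=False, store_results=True)
        return None
    # Blocking path: run in-process and return the JSON payload directly.
    return get_sql_results(query.id, return_results=True)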
Exemplo n.º 43
0
    @classmethod
    def sync_to_db_from_config(cls, druid_config, user, cluster):
        """Merges the ds config from druid_config into one stored in the db."""
        session = db.session()
        datasource = (
            session.query(cls)
            .filter_by(
                datasource_name=druid_config['name'])
            .first()
        )
        # Create a new datasource.
        if not datasource:
            datasource = cls(
                datasource_name=druid_config['name'],
                cluster=cluster,
                owner=user,
                changed_by_fk=user.id,
                created_by_fk=user.id,
            )
            session.add(datasource)

        dimensions = druid_config['dimensions']
        for dim in dimensions:
            col_obj = (
                session.query(DruidColumn)
                .filter_by(
                    datasource_name=druid_config['name'],
                    column_name=dim)
                .first()
            )
            if not col_obj:
                col_obj = DruidColumn(
                    datasource_name=druid_config['name'],
                    column_name=dim,
                    groupby=True,
                    filterable=True,
                    # TODO: fetch type from Hive.
                    type="STRING",
                    datasource=datasource,
                )
                session.add(col_obj)
        # Import Druid metrics
        for metric_spec in druid_config["metrics_spec"]:
            metric_name = metric_spec["name"]
            metric_type = metric_spec["type"]
            metric_json = json.dumps(metric_spec)

            if metric_type == "count":
                metric_type = "longSum"
                metric_json = json.dumps({
                    "type": "longSum",
                    "name": metric_name,
                    "fieldName": metric_name,
                })

            metric_obj = (
                session.query(DruidMetric)
                .filter_by(
                    datasource_name=druid_config['name'],
                    metric_name=metric_name)
            ).first()
            if not metric_obj:
                metric_obj = DruidMetric(
                    metric_name=metric_name,
                    metric_type=metric_type,
                    verbose_name="%s(%s)" % (metric_type, metric_name),
                    datasource=datasource,
                    json=metric_json,
                    description=(
                        "Imported from the airolap config dir for %s" %
                        druid_config['name']),
                )
                session.add(metric_obj)
        session.commit()
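
An illustrative `druid_config` payload showing the keys the method reads (`name`, `dimensions`, `metrics_spec`); the datasource name, dimensions, and metric values are made up, and `DruidDatasource` is assumed to be the enclosing class:

example_druid_config = {
    'name': 'clickstream_events',            # hypothetical datasource name
    'dimensions': ['country', 'device'],     # each becomes a DruidColumn
    'metrics_spec': [
        {'type': 'count', 'name': 'count'},  # rewritten to longSum above
        {'type': 'doubleSum', 'name': 'revenue', 'fieldName': 'revenue'},
    ],
}

# Usage sketch (assumed class name; user and cluster are existing objects):
# DruidDatasource.sync_to_db_from_config(example_druid_config, user, cluster)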