Example #1
def handle_query_error(
        msg: str,
        query: Query,
        session: Session,
        payload: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
    """Local method handling error while processing the SQL"""
    payload = payload or {}
    troubleshooting_link = config["TROUBLESHOOTING_LINK"]
    query.error_message = msg
    query.status = QueryStatus.FAILED
    query.tmp_table_name = None

    # extract DB-specific errors (e.g., an invalid column)
    errors = [
        dataclasses.asdict(error)
        for error in query.database.db_engine_spec.extract_errors(msg)
    ]
    if errors:
        query.set_extra_json_key("errors", errors)

    session.commit()
    payload.update({"status": query.status, "error": msg, "errors": errors})
    if troubleshooting_link:
        payload["link"] = troubleshooting_link
    return payload
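A minimal sketch of a call site, hedged: the wrapper function `run_statement` and its arguments are hypothetical stand-ins for the enclosing task in superset.sql_lab, which has `db_engine_spec`, `cursor`, `query`, `session`, and `payload` in scope.

def run_statement(db_engine_spec, cursor, query, session, payload):
    # Hedged sketch of a call site: on any failure, persist the error on the
    # Query row and hand back the error payload built by handle_query_error.
    try:
        data = db_engine_spec.fetch_data(cursor, query.limit)
    except Exception as ex:  # pylint: disable=broad-except
        return handle_query_error(str(ex), query, session, payload)
    return {"status": query.status, "data": data}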
Example #2
def handle_query_error(
    ex: Exception,
    query: Query,
    session: Session,
    payload: Optional[Dict[str, Any]] = None,
    prefix_message: str = "",
) -> Dict[str, Any]:
    """Local method handling error while processing the SQL"""
    payload = payload or {}
    msg = f"{prefix_message} {str(ex)}".strip()
    troubleshooting_link = config["TROUBLESHOOTING_LINK"]
    query.error_message = msg
    query.status = QueryStatus.FAILED
    query.tmp_table_name = None

    # extract DB-specific errors (e.g., an invalid column)
    if isinstance(ex, SupersetErrorException):
        errors = [ex.error]
    elif isinstance(ex, SupersetErrorsException):
        errors = ex.errors
    else:
        errors = query.database.db_engine_spec.extract_errors(str(ex))

    errors_payload = [dataclasses.asdict(error) for error in errors]
    if errors:
        query.set_extra_json_key("errors", errors_payload)

    session.commit()
    payload.update({"status": query.status, "error": msg, "errors": errors_payload})
    if troubleshooting_link:
        payload["link"] = troubleshooting_link
    return payload
Example #3
 def create_query(self) -> Query:
     # pylint: disable=line-too-long
     start_time = now_as_float()
     if self.select_as_cta:
         return Query(
             database_id=self.database_id,
             sql=self.sql,
             schema=self.schema,
             select_as_cta=True,
             ctas_method=self.create_table_as_select.ctas_method,  # type: ignore
             start_time=start_time,
             tab_name=self.tab_name,
             status=self.status,
             limit=self.limit,
             sql_editor_id=self.sql_editor_id,
             tmp_table_name=self.create_table_as_select.target_table_name,  # type: ignore
             tmp_schema_name=self.create_table_as_select.target_schema_name,  # type: ignore
             user_id=self.user_id,
             client_id=self.client_id_or_short_id,
         )
     return Query(
         database_id=self.database_id,
         sql=self.sql,
         schema=self.schema,
         select_as_cta=False,
         start_time=start_time,
         tab_name=self.tab_name,
         limit=self.limit,
         status=self.status,
         sql_editor_id=self.sql_editor_id,
         user_id=self.user_id,
         client_id=self.client_id_or_short_id,
     )
Example #4
    def handle_cursor(cls, cursor: Cursor, query: Query, session: Session) -> None:
        tracking_url = cls.get_tracking_url(cursor)
        if tracking_url:
            query.tracking_url = tracking_url

        # Adds the executed query id to the extra payload so the query can be cancelled
        query.set_extra_json_key("cancel_query", cursor.stats["queryId"])

        session.commit()
        BaseEngineSpec.handle_cursor(cursor=cursor, query=query, session=session)
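The id stashed under "cancel_query" above is what `TrinoEngineSpec.cancel_query` later reads back. A hedged sketch of that counterpart (the exact kill statement is an assumption and may differ from the real implementation):

    @classmethod
    def cancel_query(cls, cursor: Cursor, query: Query, cancel_query_id: str) -> bool:
        # Hedged sketch (inside TrinoEngineSpec): Trino exposes cancellation
        # via a system procedure; return False if the call fails.
        try:
            cursor.execute(
                f"CALL system.runtime.kill_query(query_id => '{cancel_query_id}', "
                "message => 'Query cancelled by Superset')"
            )
        except Exception:  # pylint: disable=broad-except
            return False
        return True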
Example #5
    def handle_cursor(  # pylint: disable=too-many-locals
        cls, cursor: Any, query: Query, session: Session
    ) -> None:
        """Updates progress information"""
        from pyhive import hive  # pylint: disable=no-name-in-module

        unfinished_states = (
            hive.ttypes.TOperationState.INITIALIZED_STATE,
            hive.ttypes.TOperationState.RUNNING_STATE,
        )
        polled = cursor.poll()
        last_log_line = 0
        tracking_url = None
        job_id = None
        query_id = query.id
        while polled.operationState in unfinished_states:
            query = session.query(type(query)).filter_by(id=query_id).one()
            if query.status == QueryStatus.STOPPED:
                cursor.cancel()
                break

            log = cursor.fetch_logs() or ""
            if log:
                log_lines = log.splitlines()
                progress = cls.progress(log_lines)
                logger.info(f"Query {query_id}: Progress total: {progress}")
                needs_commit = False
                if progress > query.progress:
                    query.progress = progress
                    needs_commit = True
                if not tracking_url:
                    tracking_url = cls.get_tracking_url(log_lines)
                    if tracking_url:
                        job_id = tracking_url.split("/")[-2]
                        logger.info(
                            f"Query {query_id}: Found the tracking url: {tracking_url}"
                        )
                        tracking_url = tracking_url_trans(tracking_url)
                        logger.info(
                            f"Query {query_id}: Transformation applied: {tracking_url}"
                        )
                        query.tracking_url = tracking_url
                        logger.info(f"Query {query_id}: Job id: {job_id}")
                        needs_commit = True
                if job_id and len(log_lines) > last_log_line:
                    # Wait for the job id before emitting log lines, so each
                    # one can be prefixed with it and searched for in
                    # something like Kibana
                    for l in log_lines[last_log_line:]:
                        logger.info(f"Query {query_id}: [{job_id}] {l}")
                    last_log_line = len(log_lines)
                if needs_commit:
                    session.commit()
            time.sleep(hive_poll_interval)
            polled = cursor.poll()
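Example 5 references `tracking_url_trans` and `hive_poll_interval` without defining them. Judging by Example 19, which reads the same values from the Flask config, they are plausibly module-level aliases along these lines (a sketch, not the verbatim source):

# Sketch of the module-level names Example 5 assumes, inferred from the
# config keys Example 19 reads ("TRACKING_URL_TRANSFORMER", "HIVE_POLL_INTERVAL").
from superset import app

config = app.config
tracking_url_trans = config["TRACKING_URL_TRANSFORMER"]
hive_poll_interval = config["HIVE_POLL_INTERVAL"]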
Example #6
def handle_query_error(
    msg: str, query: Query, session: Session, payload: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
    """Local method handling error while processing the SQL"""
    payload = payload or {}
    troubleshooting_link = config["TROUBLESHOOTING_LINK"]
    query.error_message = msg
    query.status = QueryStatus.FAILED
    query.tmp_table_name = None
    session.commit()
    payload.update({"status": query.status, "error": msg})
    if troubleshooting_link:
        payload["link"] = troubleshooting_link
    return payload
Example #7
 def insert_query(
     self,
     database_id: int,
     user_id: int,
     client_id: str,
     sql: str = "",
     select_sql: str = "",
     executed_sql: str = "",
     limit: int = 100,
     progress: int = 100,
     rows: int = 100,
     tab_name: str = "",
     status: str = "success",
 ) -> Query:
     database = db.session.query(Database).get(database_id)
     user = db.session.query(security_manager.user_model).get(user_id)
     query = Query(
         database=database,
         user=user,
         client_id=client_id,
         sql=sql,
         select_sql=select_sql,
         executed_sql=executed_sql,
         limit=limit,
         progress=progress,
         rows=rows,
         tab_name=tab_name,
         status=status,
     )
     db.session.add(query)
     db.session.commit()
     return query
Example #8
def test_cancel_query_success(engine_mock: mock.Mock) -> None:
    from superset.db_engine_specs.trino import TrinoEngineSpec
    from superset.models.sql_lab import Query

    query = Query()
    cursor_mock = engine_mock.return_value.__enter__.return_value
    assert TrinoEngineSpec.cancel_query(cursor_mock, query, "123") is True
Example #9
def test_cancel_query_failed(engine_mock: mock.Mock) -> None:
    from superset.db_engine_specs.trino import TrinoEngineSpec
    from superset.models.sql_lab import Query

    query = Query()
    cursor_mock = engine_mock.raiseError.side_effect = Exception()
    assert TrinoEngineSpec.cancel_query(cursor_mock, query, "123") is False
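Examples 8 and 9 receive an `engine_mock` argument whose decorator the excerpts drop. In Superset's unit tests this is typically a `mock.patch` on the engine; a hedged sketch (the exact patch target is an assumption and may differ in the source file):

from unittest import mock

# Hedged sketch of the decorator the excerpts above likely carry.
@mock.patch("sqlalchemy.engine.Engine.connect")
def test_cancel_query_success(engine_mock: mock.Mock) -> None:
    ...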
Example #10
def test_query_dao_save_metadata(session: Session) -> None:
    from superset.models.core import Database
    from superset.models.sql_lab import Query

    engine = session.get_bind()
    Query.metadata.create_all(engine)  # pylint: disable=no-member

    db = Database(database_name="my_database", sqlalchemy_uri="sqlite://")

    query_obj = Query(
        client_id="foo",
        database=db,
        tab_name="test_tab",
        sql_editor_id="test_editor_id",
        sql="select * from bar",
        select_sql="select * from bar",
        executed_sql="select * from bar",
        limit=100,
        select_as_cta=False,
        rows=100,
        error_message="none",
        results_key="abc",
    )

    session.add(db)
    session.add(query_obj)

    from superset.queries.dao import QueryDAO

    query = session.query(Query).one()
    QueryDAO.save_metadata(query=query, payload={"columns": []})
    assert query.extra.get("columns", None) == []
Example #11
def test_get_cancel_query_id(engine_mock: mock.Mock) -> None:
    from superset.db_engine_specs.snowflake import SnowflakeEngineSpec
    from superset.models.sql_lab import Query

    query = Query()
    cursor_mock = engine_mock.return_value.__enter__.return_value
    cursor_mock.fetchone.return_value = [123]
    assert SnowflakeEngineSpec.get_cancel_query_id(cursor_mock, query) == 123
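The mocked `fetchone` return value `[123]` above suggests an implementation that asks Snowflake for the last query id and takes the first column. A hedged sketch:

    @classmethod
    def get_cancel_query_id(cls, cursor: Any, query: Query) -> Optional[str]:
        # Hedged sketch (inside SnowflakeEngineSpec): ask Snowflake for the id
        # of the most recent query on this session; taking the first column of
        # the first row matches the fetchone() -> [123] mock in the test.
        cursor.execute("SELECT LAST_QUERY_ID()")
        return cursor.fetchone()[0]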
Example #12
def example_query(
        get_or_create_user: Callable[..., ContextManager[ab_models.User]]):
    with get_or_create_user("sqllab-test-user") as user:
        query = Query(client_id=shortid()[:10],
                      database=get_example_database(),
                      user=user)
        db.session.add(query)
        db.session.commit()
        yield query
        db.session.delete(query)
        db.session.commit()
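A hedged sketch of a test consuming the fixture above (the @pytest.fixture decorator is omitted in the excerpt; the test name and assertions here are hypothetical):

def test_example_query_round_trip(example_query: Query) -> None:
    # The fixture commits the Query, so it should have a primary key and be
    # bound to the example database.
    assert example_query.id is not None
    assert example_query.database == get_example_database()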
Example #13
def test_query_has_access(mocker: MockFixture) -> None:
    from superset.explore.utils import check_datasource_access
    from superset.models.sql_lab import Query

    mocker.patch(query_find_by_id, return_value=Query())
    mocker.patch(raise_for_access, return_value=True)
    mocker.patch(is_admin, return_value=False)
    mocker.patch(is_owner, return_value=False)
    mocker.patch(can_access, return_value=True)
    assert check_datasource_access(
        datasource_id=1,
        datasource_type=DatasourceType.QUERY,
    ) is True
Example #14
    def create_query(self):
        with self.create_app().app_context():
            session = db.session

            query = Query(
                sql="select 1 as foo;",
                client_id="sldkfjlk",
                database=get_example_database(),
            )

            session.add(query)
            session.commit()

            yield query

            # clean up: delete the query created above
            session.delete(query)
            session.commit()
Example #15
def test_query_no_access(mocker: MockFixture, app_context: AppContext) -> None:
    from superset.connectors.sqla.models import SqlaTable
    from superset.explore.utils import check_datasource_access
    from superset.models.core import Database
    from superset.models.sql_lab import Query

    with raises(SupersetSecurityException):
        mocker.patch(
            query_find_by_id,
            return_value=Query(database=Database(), sql="select * from foo"),
        )
        mocker.patch(query_datasources_by_name, return_value=[SqlaTable()])
        mocker.patch(is_user_admin, return_value=False)
        mocker.patch(is_owner, return_value=False)
        mocker.patch(can_access, return_value=False)
        check_datasource_access(
            datasource_id=1,
            datasource_type=DatasourceType.QUERY,
        )
Example #16
    def handle_cursor(cls, cursor: Any, query: Query, session: Session) -> None:
        """Updates progress information"""
        query_id = query.id
        poll_interval = query.database.connect_args.get(
            "poll_interval", current_app.config["PRESTO_POLL_INTERVAL"]
        )
        logger.info("Query %i: Polling the cursor for progress", query_id)
        polled = cursor.poll()
        # poll returns dict -- JSON status information or ``None``
        # if the query is done
        # https://github.com/dropbox/PyHive/blob/
        # b34bdbf51378b3979eaf5eca9e956f06ddc36ca0/pyhive/presto.py#L178
        while polled:
            # Update the object and wait for the kill signal.
            stats = polled.get("stats", {})

            query = session.query(type(query)).filter_by(id=query_id).one()
            if query.status in [QueryStatus.STOPPED, QueryStatus.TIMED_OUT]:
                cursor.cancel()
                break

            if stats:
                state = stats.get("state")

                # if already finished, then stop polling
                if state == "FINISHED":
                    break

                completed_splits = float(stats.get("completedSplits"))
                total_splits = float(stats.get("totalSplits"))
                if total_splits and completed_splits:
                    progress = 100 * (completed_splits / total_splits)
                    logger.info(
                        "Query {} progress: {} / {} "  # pylint: disable=logging-format-interpolation
                        "splits".format(query_id, completed_splits, total_splits)
                    )
                    if progress > query.progress:
                        query.progress = progress
                    session.commit()
            time.sleep(poll_interval)
            logger.info("Query %i: Polling the cursor for progress", query_id)
            polled = cursor.poll()
Example #17
 def save_metadata(query: Query, payload: Dict[str, Any]) -> None:
     # pull relevant data from payload and store in extra_json
     columns = payload.get("columns", {})
     db.session.add(query)
     query.set_extra_json_key("columns", columns)
Example #18
def test_sql_lab_insert_rls(
    mocker: MockerFixture,
    session: Session,
    app_context: None,
) -> None:
    """
    Integration test for `insert_rls`.
    """
    from flask_appbuilder.security.sqla.models import Role, User

    from superset.connectors.sqla.models import RowLevelSecurityFilter, SqlaTable
    from superset.models.core import Database
    from superset.models.sql_lab import Query
    from superset.security.manager import SupersetSecurityManager
    from superset.sql_lab import execute_sql_statement
    from superset.utils.core import RowLevelSecurityFilterType

    engine = session.connection().engine
    Query.metadata.create_all(engine)  # pylint: disable=no-member

    connection = engine.raw_connection()
    connection.execute("CREATE TABLE t (c INTEGER)")
    for i in range(10):
        connection.execute("INSERT INTO t VALUES (?)", (i, ))

    cursor = connection.cursor()

    query = Query(
        sql="SELECT c FROM t",
        client_id="abcde",
        database=Database(database_name="test_db", sqlalchemy_uri="sqlite://"),
        schema=None,
        limit=5,
        select_as_cta_used=False,
    )
    session.add(query)
    session.commit()

    admin = User(
        first_name="Alice",
        last_name="Doe",
        email="*****@*****.**",
        username="******",
        roles=[Role(name="Admin")],
    )

    # first without RLS
    with override_user(admin):
        superset_result_set = execute_sql_statement(
            sql_statement=query.sql,
            query=query,
            session=session,
            cursor=cursor,
            log_params=None,
            apply_ctas=False,
        )
    assert (superset_result_set.to_pandas_df().to_markdown() == """
|    |   c |
|---:|----:|
|  0 |   0 |
|  1 |   1 |
|  2 |   2 |
|  3 |   3 |
|  4 |   4 |""".strip())
    assert query.executed_sql == "SELECT c FROM t\nLIMIT 6"

    # now with RLS
    rls = RowLevelSecurityFilter(
        filter_type=RowLevelSecurityFilterType.REGULAR,
        tables=[SqlaTable(database_id=1, schema=None, table_name="t")],
        roles=[admin.roles[0]],
        group_key=None,
        clause="c > 5",
    )
    session.add(rls)
    session.flush()
    mocker.patch.object(SupersetSecurityManager,
                        "find_user",
                        return_value=admin)
    mocker.patch("superset.sql_lab.is_feature_enabled", return_value=True)

    with override_user(admin):
        superset_result_set = execute_sql_statement(
            sql_statement=query.sql,
            query=query,
            session=session,
            cursor=cursor,
            log_params=None,
            apply_ctas=False,
        )
    assert (superset_result_set.to_pandas_df().to_markdown() == """
|    |   c |
|---:|----:|
|  0 |   6 |
|  1 |   7 |
|  2 |   8 |
|  3 |   9 |""".strip())
    assert query.executed_sql == "SELECT c FROM t WHERE (t.c > 5)\nLIMIT 6"
Example #19
    def handle_cursor(  # pylint: disable=too-many-locals
            cls, cursor: Any, query: Query, session: Session) -> None:
        """Updates progress information"""
        # pylint: disable=import-outside-toplevel
        from pyhive import hive

        unfinished_states = (
            hive.ttypes.TOperationState.INITIALIZED_STATE,
            hive.ttypes.TOperationState.RUNNING_STATE,
        )
        polled = cursor.poll()
        last_log_line = 0
        tracking_url = None
        job_id = None
        query_id = query.id
        while polled.operationState in unfinished_states:
            # Queries don't terminate when user clicks the STOP button on SQL LAB.
            # Refresh session so that the `query.status` modified in stop_query in
            # views/core.py is reflected here.
            session.refresh(query)
            query = session.query(type(query)).filter_by(id=query_id).one()
            if query.status == QueryStatus.STOPPED:
                cursor.cancel()
                break

            try:
                log = cursor.fetch_logs() or ""
            except Exception:  # pylint: disable=broad-except
                logger.warning("Call to GetLog() failed")
                log = ""

            if log:
                log_lines = log.splitlines()
                progress = cls.progress(log_lines)
                logger.info("Query %s: Progress total: %s", str(query_id),
                            str(progress))
                needs_commit = False
                if progress > query.progress:
                    query.progress = progress
                    needs_commit = True
                if not tracking_url:
                    tracking_url = cls.get_tracking_url(log_lines)
                    if tracking_url:
                        job_id = tracking_url.split("/")[-2]
                        logger.info(
                            "Query %s: Found the tracking url: %s",
                            str(query_id),
                            tracking_url,
                        )
                        transformer = current_app.config[
                            "TRACKING_URL_TRANSFORMER"]
                        tracking_url = transformer(tracking_url)
                        logger.info(
                            "Query %s: Transformation applied: %s",
                            str(query_id),
                            tracking_url,
                        )
                        query.tracking_url = tracking_url
                        logger.info("Query %s: Job id: %s", str(query_id),
                                    str(job_id))
                        needs_commit = True
                if job_id and len(log_lines) > last_log_line:
                    # Wait for the job id before emitting log lines, so each
                    # one can be prefixed with it and searched for in
                    # something like Kibana
                    for l in log_lines[last_log_line:]:
                        logger.info("Query %s: [%s] %s", str(query_id),
                                    str(job_id), l)
                    last_log_line = len(log_lines)
                if needs_commit:
                    session.commit()
            time.sleep(current_app.config["HIVE_POLL_INTERVAL"])
            polled = cursor.poll()
Example #20
def get_one_report(id):
    o = db.session.query(SavedQuery).filter_by(id=id).first()
    desc = {}
    try:
        desc = json.loads(o.description)
    except ValueError:
        pass

    if request.method == 'GET':
        return jsonify({
            'id': o.id,
            'created_on': o.created_on.strftime('%Y-%m-%d'),
            'changed_on': o.changed_on.strftime('%Y-%m-%d'),
            'user_id': o.user_id or '',
            'db_id': o.db_id or '',
            'label': o.label or '',
            'schema': o.schema or '',
            'sql': o.sql or '',
            'description': desc,
        })

    elif request.method == 'POST':
        qjson = request.json

        sql = o.sql
        database_id = o.db_id
        schema = o.schema
        label = o.label

        session = db.session()
        mydb = session.query(models.Database).filter_by(id=database_id).first()

        # paginate
        page = qjson.get('page', 1)
        per_page = qjson.get('per_page', app.config['REPORT_PER_PAGE'])

        hkey = get_hash_key()

        # parse config: filters, fields, and sorts
        # e.g. qsort = ["ds", "desc"]
        qsort = qjson.get('sort', [])
        sort = " order by _.%s %s" % (qsort[0], qsort[1]) if qsort else ""
        # TODO: handle date conversion later, e.g.
        # filters = [{"field": "ds", "type": "range", "value1": "2016-01-01",
        #             "value2": "2017-01-03", "help": u"date field"}]
        filters = qjson.get('filterfield_set', [])

        fs = []
        for f in filters:
            if f['type'] == 'range':
                fs.append(
                    "(_.%(field)s >= '%(value1)s' and _.%(field)s < '%(value2)s')"
                    % f)
            elif f['type'] == 'like':
                fs.append("_.%(field)s like '%%%(value1)s%%'" % f)
            else:
                fs.append("_.%(field)s %(type)s '%(value1)s'")

        where = " where " + (" and ".join(fs)) if fs else ""

        # count_sql = "SELECT count(1) as num FROM (%s) _" % sql
        # sql = "SELECT * FROM (%s) _ LIMIT %s,%s" % (sql, (page-1)*per_page, per_page)
        # NOTE: the inner sql can't end with `;`; complicated sql should use
        # `select .. as ..` aliases

        count_sql = "SELECT count(1) as num FROM (%s) _ %s" % (sql, where)
        sql = "SELECT * FROM (%s) _ %s %s LIMIT %s,%s" % (sql, where, sort,
                                                          (page - 1) *
                                                          per_page, per_page)

        query = Query(
            database_id=int(database_id),
            limit=1000000,  # int(app.config.get('SQL_MAX_ROW', None)),
            sql=sql,
            schema=schema,
            select_as_cta=False,
            start_time=utils.now_as_float(),
            tab_name=label,
            status=QueryStatus.RUNNING,
            sql_editor_id=hkey[0] + hkey[1],
            tmp_table_name='',
            user_id=int(g.user.get_id()),
            client_id=hkey[2] + hkey[3],
        )
        session.add(query)

        cquery = Query(
            database_id=int(database_id),
            limit=1000000,  # int(app.config.get('SQL_MAX_ROW', None)),
            sql=count_sql,
            schema=schema,
            select_as_cta=False,
            start_time=utils.now_as_float(),
            tab_name=label,
            status=QueryStatus.RUNNING,
            sql_editor_id=hkey[0] + hkey[1],
            tmp_table_name='',
            user_id=int(g.user.get_id()),
            client_id=hkey[0] + hkey[1],
        )
        session.add(cquery)

        session.flush()
        db.session.commit()
        query_id = query.id
        cquery_id = cquery.id

        data = sql_lab.get_sql_results(query_id=query_id,
                                       return_results=True,
                                       template_params={})

        cdata = sql_lab.get_sql_results(query_id=cquery_id,
                                        return_results=True,
                                        template_params={})

        return jsonify({
            'data': data['data'],
            'id': id,
            'label': label,
            'query_id': data['query_id'],
            'limit': data['query']['limit'],
            'limit_reached': False,
            'page': page,
            'per_page': per_page,
            'pages': get_pages(cdata['data'][0]['num'], per_page),
            'total': cdata['data'][0]['num'],
            'rows': data['query']['rows'],
            'sort': qsort,
            'changed_on': data['query']['changed_on'],
            'displayfield_set': desc['displayfield_set'],
            'report_file': url_for('download_one_report',
                                   id=id,
                                   query_id=data['query_id']),
            'status': 'success',
        })
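Example 20 leans on two helpers the excerpt omits, `get_hash_key` and `get_pages`. The latter is simple enough to sketch; this is a hypothetical implementation, not the source:

import math


def get_pages(total: int, per_page: int) -> int:
    # Hypothetical helper: number of pages needed to show `total` rows.
    return math.ceil(total / per_page) if per_page else 1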
Example #21
def execute_sql_statement(  # pylint: disable=too-many-arguments,too-many-statements
    sql_statement: str,
    query: Query,
    session: Session,
    cursor: Any,
    log_params: Optional[Dict[str, Any]],
    apply_ctas: bool = False,
) -> SupersetResultSet:
    """Executes a single SQL statement"""
    database: Database = query.database
    db_engine_spec = database.db_engine_spec

    parsed_query = ParsedQuery(sql_statement)
    if is_feature_enabled("RLS_IN_SQLLAB"):
        # Insert any applicable RLS predicates
        parsed_query = ParsedQuery(
            str(
                insert_rls(
                    parsed_query._parsed[0],  # pylint: disable=protected-access
                    database.id,
                    query.schema,
                )))

    sql = parsed_query.stripped()
    # This is a test to see if the query is being
    # limited by either the dropdown or the sql.
    # We are testing to see if more rows exist than the limit.
    increased_limit = None if query.limit is None else query.limit + 1

    if not db_engine_spec.is_readonly_query(
            parsed_query) and not database.allow_dml:
        raise SupersetErrorException(
            SupersetError(
                message=__(
                    "Only SELECT statements are allowed against this database."
                ),
                error_type=SupersetErrorType.DML_NOT_ALLOWED_ERROR,
                level=ErrorLevel.ERROR,
            ))
    if apply_ctas:
        if not query.tmp_table_name:
            start_dttm = datetime.fromtimestamp(query.start_time)
            query.tmp_table_name = "tmp_{}_table_{}".format(
                query.user_id, start_dttm.strftime("%Y_%m_%d_%H_%M_%S"))
        sql = parsed_query.as_create_table(
            query.tmp_table_name,
            schema_name=query.tmp_schema_name,
            method=query.ctas_method,
        )
        query.select_as_cta_used = True

    # Do not apply limit to the CTA queries when SQLLAB_CTAS_NO_LIMIT is set to true
    if db_engine_spec.is_select_query(parsed_query) and not (
            query.select_as_cta_used and SQLLAB_CTAS_NO_LIMIT):
        if SQL_MAX_ROW and (not query.limit or query.limit > SQL_MAX_ROW):
            query.limit = SQL_MAX_ROW
        sql = apply_limit_if_exists(database, increased_limit, query, sql)

    # Hook to allow environment-specific mutation (usually comments) to the SQL
    sql = SQL_QUERY_MUTATOR(
        sql,
        user_name=get_username(),  # TODO(john-bodley): Deprecate in 3.0.
        security_manager=security_manager,
        database=database,
    )
    try:
        query.executed_sql = sql
        if log_query:
            log_query(
                query.database.sqlalchemy_uri,
                query.executed_sql,
                query.schema,
                get_username(),
                __name__,
                security_manager,
                log_params,
            )
        session.commit()
        with stats_timing("sqllab.query.time_executing_query", stats_logger):
            logger.debug("Query %d: Running query: %s", query.id, sql)
            db_engine_spec.execute(cursor, sql, async_=True)
            logger.debug("Query %d: Handling cursor", query.id)
            db_engine_spec.handle_cursor(cursor, query, session)

        with stats_timing("sqllab.query.time_fetching_results", stats_logger):
            logger.debug(
                "Query %d: Fetching data for query object: %s",
                query.id,
                str(query.to_dict()),
            )
            data = db_engine_spec.fetch_data(cursor, increased_limit)
            if query.limit is None or len(data) <= query.limit:
                query.limiting_factor = LimitingFactor.NOT_LIMITED
            else:
                # return 1 row less than the increased limit
                data = data[:-1]
    except SoftTimeLimitExceeded as ex:
        query.status = QueryStatus.TIMED_OUT

        logger.warning("Query %d: Time limit exceeded", query.id)
        logger.debug("Query %d: %s", query.id, ex)
        raise SupersetErrorException(
            SupersetError(
                message=__(
                    "The query was killed after %(sqllab_timeout)s seconds. It might "
                    "be too complex, or the database might be under heavy load.",
                    sqllab_timeout=SQLLAB_TIMEOUT,
                ),
                error_type=SupersetErrorType.SQLLAB_TIMEOUT_ERROR,
                level=ErrorLevel.ERROR,
            )) from ex
    except Exception as ex:
        # query is stopped in another thread/worker
        # stopping raises expected exceptions which we should skip
        session.refresh(query)
        if query.status == QueryStatus.STOPPED:
            raise SqlLabQueryStoppedException() from ex

        logger.error("Query %d: %s", query.id, type(ex), exc_info=True)
        logger.debug("Query %d: %s", query.id, ex)
        raise SqlLabException(db_engine_spec.extract_error_message(ex)) from ex

    logger.debug("Query %d: Fetching cursor description", query.id)
    cursor_description = cursor.description
    return SupersetResultSet(data, cursor_description, db_engine_spec)
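Example 21 delegates limit handling to an `apply_limit_if_exists` helper that the excerpt does not include. Example 25 below applies the same limit inline, so the helper plausibly looks like this (a sketch under that assumption):

def apply_limit_if_exists(
    database: Database,
    increased_limit: Optional[int],
    query: Query,
    sql: str,
) -> str:
    # Sketch mirroring the inline logic of Example 25: fetch one extra row so
    # the caller can tell whether the result was truncated by the limit.
    if query.limit and increased_limit:
        sql = database.apply_limit_to_sql(sql, increased_limit, force=True)
    return sql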
Example #22
 def test_cancel_query(self, engine_mock):
     query = Query()
     cursor_mock = engine_mock.return_value.__enter__.return_value
     assert SnowflakeEngineSpec.cancel_query(cursor_mock, query, 123) is True
Example #23
 def _to_payload_query_based(  # pylint: disable=no-self-use
         self, query: Query) -> str:
     return json.dumps({"query": query.to_dict()},
                       default=utils.json_int_dttm_ser,
                       ignore_nan=True)
Example #24
def execute_sql_statement(
    sql_statement: str,
    query: Query,
    user_name: Optional[str],
    session: Session,
    cursor: Any,
    log_params: Optional[Dict[str, Any]],
) -> SupersetResultSet:
    """Executes a single SQL statement"""
    database = query.database
    db_engine_spec = database.db_engine_spec
    parsed_query = ParsedQuery(sql_statement)
    sql = parsed_query.stripped()

    if not parsed_query.is_readonly() and not database.allow_dml:
        raise SqlLabSecurityException(
            _("Only `SELECT` statements are allowed against this database"))
    if query.select_as_cta:
        if not parsed_query.is_select():
            raise SqlLabException(
                _("Only `SELECT` statements can be used with the CREATE TABLE "
                  "feature."))
        if not query.tmp_table_name:
            start_dttm = datetime.fromtimestamp(query.start_time)
            query.tmp_table_name = "tmp_{}_table_{}".format(
                query.user_id, start_dttm.strftime("%Y_%m_%d_%H_%M_%S"))
        sql = parsed_query.as_create_table(
            query.tmp_table_name,
            schema_name=query.tmp_schema_name,
            method=query.ctas_method,
        )
        query.select_as_cta_used = True

    # Do not apply limit to the CTA queries when SQLLAB_CTAS_NO_LIMIT is set to true
    if parsed_query.is_select() and not (query.select_as_cta_used
                                         and SQLLAB_CTAS_NO_LIMIT):
        if SQL_MAX_ROW and (not query.limit or query.limit > SQL_MAX_ROW):
            query.limit = SQL_MAX_ROW
        if query.limit:
            sql = database.apply_limit_to_sql(sql, query.limit)

    # Hook to allow environment-specific mutation (usually comments) to the SQL
    if SQL_QUERY_MUTATOR:
        sql = SQL_QUERY_MUTATOR(sql, user_name, security_manager, database)

    try:
        if log_query:
            log_query(
                query.database.sqlalchemy_uri,
                query.executed_sql,
                query.schema,
                user_name,
                __name__,
                security_manager,
                log_params,
            )
        query.executed_sql = sql
        session.commit()
        with stats_timing("sqllab.query.time_executing_query", stats_logger):
            logger.debug("Query %d: Running query: %s", query.id, sql)
            db_engine_spec.execute(cursor, sql, async_=True)
            logger.debug("Query %d: Handling cursor", query.id)
            db_engine_spec.handle_cursor(cursor, query, session)

        with stats_timing("sqllab.query.time_fetching_results", stats_logger):
            logger.debug(
                "Query %d: Fetching data for query object: %s",
                query.id,
                str(query.to_dict()),
            )
            data = db_engine_spec.fetch_data(cursor, query.limit)

    except SoftTimeLimitExceeded as ex:
        logger.error("Query %d: Time limit exceeded", query.id)
        logger.debug("Query %d: %s", query.id, ex)
        raise SqlLabTimeoutException(
            "SQL Lab timeout. This environment's policy is to kill queries "
            "after {} seconds.".format(SQLLAB_TIMEOUT))
    except Exception as ex:
        logger.error("Query %d: %s", query.id, type(ex))
        logger.debug("Query %d: %s", query.id, ex)
        raise SqlLabException(db_engine_spec.extract_error_message(ex))

    logger.debug("Query %d: Fetching cursor description", query.id)
    cursor_description = cursor.description
    return SupersetResultSet(data, cursor_description, db_engine_spec)
Example #25
def execute_sql_statement(
    sql_statement: str,
    query: Query,
    user_name: Optional[str],
    session: Session,
    cursor: Any,
    log_params: Optional[Dict[str, Any]],
    apply_ctas: bool = False,
) -> SupersetResultSet:
    """Executes a single SQL statement"""
    database = query.database
    db_engine_spec = database.db_engine_spec
    parsed_query = ParsedQuery(sql_statement)
    sql = parsed_query.stripped()
    # This is a test to see if the query is being
    # limited by either the dropdown or the sql.
    # We are testing to see if more rows exist than the limit.
    increased_limit = None if query.limit is None else query.limit + 1

    if not db_engine_spec.is_readonly_query(parsed_query) and not database.allow_dml:
        raise SupersetErrorException(
            SupersetError(
                message=__("Only SELECT statements are allowed against this database."),
                error_type=SupersetErrorType.DML_NOT_ALLOWED_ERROR,
                level=ErrorLevel.ERROR,
            )
        )
    if apply_ctas:
        if not query.tmp_table_name:
            start_dttm = datetime.fromtimestamp(query.start_time)
            query.tmp_table_name = "tmp_{}_table_{}".format(
                query.user_id, start_dttm.strftime("%Y_%m_%d_%H_%M_%S")
            )
        sql = parsed_query.as_create_table(
            query.tmp_table_name,
            schema_name=query.tmp_schema_name,
            method=query.ctas_method,
        )
        query.select_as_cta_used = True

    # Do not apply limit to the CTA queries when SQLLAB_CTAS_NO_LIMIT is set to true
    if db_engine_spec.is_select_query(parsed_query) and not (
        query.select_as_cta_used and SQLLAB_CTAS_NO_LIMIT
    ):
        if SQL_MAX_ROW and (not query.limit or query.limit > SQL_MAX_ROW):
            query.limit = SQL_MAX_ROW
        if query.limit:
            # We are fetching one more than the requested limit in order
            # to test whether there are more rows than the limit.
            # Later, the extra row will be dropped before sending
            # the results back to the user.
            sql = database.apply_limit_to_sql(sql, increased_limit, force=True)

    # Hook to allow environment-specific mutation (usually comments) to the SQL
    sql = SQL_QUERY_MUTATOR(sql, user_name, security_manager, database)
    try:
        query.executed_sql = sql
        if log_query:
            log_query(
                query.database.sqlalchemy_uri,
                query.executed_sql,
                query.schema,
                user_name,
                __name__,
                security_manager,
                log_params,
            )
        session.commit()
        with stats_timing("sqllab.query.time_executing_query", stats_logger):
            logger.debug("Query %d: Running query: %s", query.id, sql)
            db_engine_spec.execute(cursor, sql, async_=True)
            logger.debug("Query %d: Handling cursor", query.id)
            db_engine_spec.handle_cursor(cursor, query, session)

        with stats_timing("sqllab.query.time_fetching_results", stats_logger):
            logger.debug(
                "Query %d: Fetching data for query object: %s",
                query.id,
                str(query.to_dict()),
            )
            data = db_engine_spec.fetch_data(cursor, increased_limit)
            if query.limit is None or len(data) <= query.limit:
                query.limiting_factor = LimitingFactor.NOT_LIMITED
            else:
                # return 1 row less than the increased limit
                data = data[:-1]
    except SoftTimeLimitExceeded as ex:
        logger.warning("Query %d: Time limit exceeded", query.id)
        logger.debug("Query %d: %s", query.id, ex)
        raise SupersetErrorException(
            SupersetError(
                message=__(
                    f"The query was killed after {SQLLAB_TIMEOUT} seconds. It might "
                    "be too complex, or the database might be under heavy load."
                ),
                error_type=SupersetErrorType.SQLLAB_TIMEOUT_ERROR,
                level=ErrorLevel.ERROR,
            )
        )
    except Exception as ex:
        logger.error("Query %d: %s", query.id, type(ex), exc_info=True)
        logger.debug("Query %d: %s", query.id, ex)
        raise SqlLabException(db_engine_spec.extract_error_message(ex))

    logger.debug("Query %d: Fetching cursor description", query.id)
    cursor_description = cursor.description
    return SupersetResultSet(data, cursor_description, db_engine_spec)
Example #26
 def test_cancel_query_failed(self, engine_mock):
     query = Query()
     cursor_mock = engine_mock.raiseError.side_effect = Exception()
     assert SnowflakeEngineSpec.cancel_query(cursor_mock, query, 123) is False
Example #27
def session_with_data(session: Session) -> Iterator[Session]:
    from superset.columns.models import Column
    from superset.connectors.sqla.models import SqlaTable, TableColumn
    from superset.datasets.models import Dataset
    from superset.models.core import Database
    from superset.models.sql_lab import Query, SavedQuery
    from superset.tables.models import Table

    engine = session.get_bind()
    SqlaTable.metadata.create_all(engine)  # pylint: disable=no-member

    db = Database(database_name="my_database", sqlalchemy_uri="sqlite://")

    columns = [
        TableColumn(column_name="a", type="INTEGER"),
    ]

    sqla_table = SqlaTable(
        table_name="my_sqla_table",
        columns=columns,
        metrics=[],
        database=db,
    )

    query_obj = Query(
        client_id="foo",
        database=db,
        tab_name="test_tab",
        sql_editor_id="test_editor_id",
        sql="select * from bar",
        select_sql="select * from bar",
        executed_sql="select * from bar",
        limit=100,
        select_as_cta=False,
        rows=100,
        error_message="none",
        results_key="abc",
    )

    saved_query = SavedQuery(database=db, sql="select * from foo")

    table = Table(
        name="my_table",
        schema="my_schema",
        catalog="my_catalog",
        database=db,
        columns=[],
    )

    dataset = Dataset(
        database=table.database,
        name="positions",
        expression="""
SELECT array_agg(array[longitude,latitude]) AS position
FROM my_catalog.my_schema.my_table
""",
        tables=[table],
        columns=[
            Column(
                name="position",
                expression="array_agg(array[longitude,latitude])",
            ),
        ],
    )

    session.add(dataset)
    session.add(table)
    session.add(saved_query)
    session.add(query_obj)
    session.add(db)
    session.add(sqla_table)
    session.flush()
    yield session
Example #28
 def test_get_cancel_query_id(self, engine_mock):
     query = Query()
     cursor_mock = engine_mock.return_value.__enter__.return_value
     cursor_mock.fetchone.return_value = [123]
     assert SnowflakeEngineSpec.get_cancel_query_id(cursor_mock, query) == 123
Example #29
def get_table_metadata(database: Database, table_name: str,
                       schema_name: Optional[str]) -> Dict:
    """
        Get table metadata information, including type, pk, fks.
        This function raises SQLAlchemyError when a schema is not found.


    :param database: The database model
    :param table_name: Table name
    :param schema_name: schema name
    :return: Dict table metadata ready for API response
    """
    keys: List = []
    columns = database.get_columns(table_name, schema_name)
    # column-name -> comment lookup (populated below for presto and hive)
    comment_dict = {}
    primary_key = database.get_pk_constraint(table_name, schema_name)
    if primary_key and primary_key.get("constrained_columns"):
        primary_key["column_names"] = primary_key.pop("constrained_columns")
        primary_key["type"] = "pk"
        keys += [primary_key]
    # get dialect name
    dialect_name = database.get_dialect().name
    if isinstance(dialect_name, bytes):
        dialect_name = dialect_name.decode()
    # get column comment, presto & hive
    if dialect_name == "presto" or dialect_name == "hive":
        db_engine_spec = database.db_engine_spec
        sql = ParsedQuery("desc {a}.{b}".format(a=schema_name,
                                                b=table_name)).stripped()
        engine = database.get_sqla_engine(schema_name)
        conn = engine.raw_connection()
        cursor = conn.cursor()
        query = Query()
        session = Session(bind=engine)
        query.executed_sql = sql
        query.__tablename__ = table_name
        session.commit()
        db_engine_spec.execute(cursor, sql, async_=False)
        data = db_engine_spec.fetch_data(cursor, query.limit)
        # parse the rows into comment_dict; hive and presto lay the
        # columns out differently
        if dialect_name == "presto":
            for d in data:
                comment_dict[d[0]] = d[3]
        else:
            for d in data:
                comment_dict[d[0]] = d[2]
        conn.commit()

    foreign_keys = get_foreign_keys_metadata(database, table_name, schema_name)
    indexes = get_indexes_metadata(database, table_name, schema_name)
    keys += foreign_keys + indexes
    payload_columns: List[Dict] = []
    for col in columns:
        dtype = get_col_type(col)
        payload_column = {
            "name": col["name"],
            "type": dtype.split("(")[0] if "(" in dtype else dtype,
            "longType": dtype,
            "keys": [k for k in keys if col["name"] in k.get("column_names")],
        }
        if len(comment_dict) > 0:
            payload_column["comment"] = comment_dict[col["name"]]
        elif dialect_name == "mysql":
            payload_column["comment"] = col["comment"]
        payload_columns.append(payload_column)
    return {
        "name": table_name,
        "columns": payload_columns,
        "selectStar": database.select_star(
            table_name,
            schema=schema_name,
            show_cols=True,
            indent=True,
            cols=columns,
            latest_partition=True,
        ),
        "primaryKey": primary_key,
        "foreignKeys": foreign_keys,
        "indexes": keys,
    }
Example #30
    def handle_cursor(  # pylint: disable=too-many-locals
            cls, cursor: Any, query: Query, session: Session) -> None:
        """Updates progress information"""
        from pyhive import hive

        unfinished_states = (
            hive.ttypes.TOperationState.INITIALIZED_STATE,
            hive.ttypes.TOperationState.RUNNING_STATE,
        )
        polled = cursor.poll()
        last_log_line = 0
        tracking_url = None
        job_id = None
        query_id = query.id
        while polled.operationState in unfinished_states:
            query = session.query(type(query)).filter_by(id=query_id).one()
            if query.status == QueryStatus.STOPPED:
                cursor.cancel()
                break

            log = cursor.fetch_logs() or ""
            if log:
                log_lines = log.splitlines()
                progress = cls.progress(log_lines)
                logger.info("Query %s: Progress total: %s", str(query_id),
                            str(progress))
                needs_commit = False
                if progress > query.progress:
                    query.progress = progress
                    needs_commit = True
                if not tracking_url:
                    tracking_url = cls.get_tracking_url(log_lines)
                    if tracking_url:
                        job_id = tracking_url.split("/")[-2]
                        logger.info(
                            "Query %s: Found the tracking url: %s",
                            str(query_id),
                            tracking_url,
                        )
                        transformer = current_app.config[
                            "TRACKING_URL_TRANSFORMER"]
                        tracking_url = transformer(tracking_url)
                        logger.info(
                            "Query %s: Transformation applied: %s",
                            str(query_id),
                            tracking_url,
                        )
                        query.tracking_url = tracking_url
                        logger.info("Query %s: Job id: %s", str(query_id),
                                    str(job_id))
                        needs_commit = True
                if job_id and len(log_lines) > last_log_line:
                    # Wait for the job id before emitting log lines, so each
                    # one can be prefixed with it and searched for in
                    # something like Kibana
                    for l in log_lines[last_log_line:]:
                        logger.info("Query %s: [%s] %s", str(query_id),
                                    str(job_id), l)
                    last_log_line = len(log_lines)
                if needs_commit:
                    session.commit()
            time.sleep(current_app.config["HIVE_POLL_INTERVAL"])
            polled = cursor.poll()