Example #1
def test_basic_breakdown_statements() -> None:
    """
    Test that multiple statements are parsed correctly.
    """
    query = ParsedQuery(
        """
SELECT * FROM birth_names;
SELECT * FROM birth_names LIMIT 1;
"""
    )
    assert query.get_statements() == [
        "SELECT * FROM birth_names",
        "SELECT * FROM birth_names LIMIT 1",
    ]
Example #2
def test_kql_is_select_query(app_context: AppContext, kql: str,
                             expected: bool) -> None:
    """
    Make sure that the KQL dialect considers only statements that do not
    start with "." (dot) as SELECT statements.
    """

    from superset.db_engine_specs.kusto import KustoKqlEngineSpec
    from superset.sql_parse import ParsedQuery

    parsed_query = ParsedQuery(kql)
    is_select = KustoKqlEngineSpec.is_select_query(parsed_query)

    assert expected == is_select
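The `kql` and `expected` arguments come from a `pytest.mark.parametrize` decorator that this excerpt omits. A minimal sketch of the assumed parametrization, with illustrative KQL strings mirroring the dot rule described in the docstring (not taken from the actual suite):

import pytest

@pytest.mark.parametrize(
    "kql,expected",
    [
        ("tbl | take 100", True),   # ordinary KQL statement: counts as a SELECT
        (".show tables", False),    # dot-prefixed control command: not a SELECT
    ],
)
def test_kql_dot_rule(kql: str, expected: bool) -> None:
    # mirrors the dot rule the engine spec is expected to apply
    assert (not kql.strip().startswith(".")) == expected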
Example #3
def test_kql_is_readonly_query(
    app_context: AppContext, kql: str, expected: bool
) -> None:
    """
    Make sure that the KQL dialect considers only SELECT statements as read-only.
    """

    from superset.db_engine_specs.kusto import KustoKqlEngineSpec
    from superset.sql_parse import ParsedQuery

    parsed_query = ParsedQuery(kql)
    is_readonly = KustoKqlEngineSpec.is_readonly_query(parsed_query)

    assert expected == is_readonly
Example #4
    def test_is_valid_cvas(self):
        """A valid CVAS has a single SELECT statement"""
        query = "SELECT * FROM table"
        parsed = ParsedQuery(query, strip_comments=True)
        assert parsed.is_valid_cvas()

        query = """
            -- comment
            SELECT * FROM table
            -- comment 2
        """
        parsed = ParsedQuery(query, strip_comments=True)
        assert parsed.is_valid_cvas()

        query = """
            -- comment
            SET @value = 42;
            SELECT @value as foo;
            -- comment 2
        """
        parsed = ParsedQuery(query, strip_comments=True)
        assert not parsed.is_valid_cvas()

        query = """
            -- comment
            EXPLAIN SELECT * FROM table
            -- comment 2
        """
        parsed = ParsedQuery(query, strip_comments=True)
        assert not parsed.is_valid_cvas()

        query = """
            SELECT * FROM table;
            INSERT INTO TABLE (foo) VALUES (42);
        """
        parsed = ParsedQuery(query, strip_comments=True)
        assert not parsed.is_valid_cvas()
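For contrast, the CTAS rule is looser: per the executor error messages in Example #27 below, a CTAS only needs its last statement to be a SELECT. A hedged sketch of the distinction, not part of the original test:

query = """
    SET @value = 42;
    SELECT @value as foo;
"""
parsed = ParsedQuery(query, strip_comments=True)
# the last statement is a SELECT, so this should pass as a CTAS...
assert parsed.is_valid_ctas()
# ...but it is not a single SELECT statement, so it fails as a CVAS
assert not parsed.is_valid_cvas()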
Example #5
def test_messy_breakdown_statements() -> None:
    """
    Test that messy multiple statements are parsed correctly.
    """
    query = ParsedQuery("""
SELECT 1;\t\n\n\n  \t
\t\nSELECT 2;
SELECT * FROM birth_names;;;
SELECT * FROM birth_names LIMIT 1
""")
    assert query.get_statements() == [
        "SELECT 1",
        "SELECT 2",
        "SELECT * FROM birth_names",
        "SELECT * FROM birth_names LIMIT 1",
    ]
Example #6
    def process_statement(cls, statement: str, database: "Database",
                          user_name: str) -> str:
        """
        Process a SQL statement by stripping and mutating it.

        :param statement: A single SQL statement
        :param database: Database instance
        :param user_name: Effective username
        :return: Stripped and mutated SQL statement
        """
        parsed_query = ParsedQuery(statement)
        sql = parsed_query.stripped()
        sql_query_mutator = config["SQL_QUERY_MUTATOR"]
        if sql_query_mutator:
            sql = sql_query_mutator(sql, user_name, security_manager, database)

        return sql
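The mutator fetched from config["SQL_QUERY_MUTATOR"] is a plain callable receiving the statement, the effective username, the security manager, and the database. A hedged sketch of what such an entry in superset_config.py might look like; the audit-comment format is an illustration, not a Superset default:

def mutate_query(sql, user_name, security_manager, database):
    # prepend an attribution comment so downstream query logs show the user
    return f"-- superset user: {user_name}\n{sql}"

SQL_QUERY_MUTATOR = mutate_query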
Example #7
 def test_messy_breakdown_statements(self):
     multi_sql = """
     SELECT 1;\t\n\n\n  \t
     \t\nSELECT 2;
     SELECT * FROM birth_names;;;
     SELECT * FROM birth_names LIMIT 1
     """
     parsed = ParsedQuery(multi_sql)
     statements = parsed.get_statements()
     self.assertEqual(len(statements), 4)
     expected = [
         "SELECT 1",
         "SELECT 2",
         "SELECT * FROM birth_names",
         "SELECT * FROM birth_names LIMIT 1",
     ]
     self.assertEqual(statements, expected)
Example #8
def test_cte_is_select() -> None:
    """
    Some CTEs are not correctly identified as SELECT statements.
    """
    # `AS(` gets parsed as a function
    sql = ParsedQuery("""WITH foo AS(
SELECT
  FLOOR(__time TO WEEK) AS "week",
  name,
  COUNT(DISTINCT user_id) AS "unique_users"
FROM "druid"."my_table"
GROUP BY 1,2
)
SELECT
  f.week,
  f.name,
  f.unique_users
FROM foo f""")
    assert sql.is_select()
Example #9
def get_virtual_table_metadata(dataset: "SqlaTable") -> List[Dict[str, str]]:
    """Use SQLparser to get virtual dataset metadata"""
    if not dataset.sql:
        raise SupersetGenericDBErrorException(
            message=_("Virtual dataset query cannot be empty"),
        )

    db_engine_spec = dataset.database.db_engine_spec
    engine = dataset.database.get_sqla_engine(schema=dataset.schema)
    sql = dataset.get_template_processor().process_template(
        dataset.sql, **dataset.template_params_dict
    )
    parsed_query = ParsedQuery(sql)
    if not db_engine_spec.is_readonly_query(parsed_query):
        raise SupersetSecurityException(
            SupersetError(
                error_type=SupersetErrorType.DATASOURCE_SECURITY_ACCESS_ERROR,
                message=_("Only `SELECT` statements are allowed"),
                level=ErrorLevel.ERROR,
            )
        )
    statements = parsed_query.get_statements()
    if len(statements) > 1:
        raise SupersetSecurityException(
            SupersetError(
                error_type=SupersetErrorType.DATASOURCE_SECURITY_ACCESS_ERROR,
                message=_("Only single queries supported"),
                level=ErrorLevel.ERROR,
            )
        )
    # TODO(villebro): refactor to use same code that's used by
    #  sql_lab.py:execute_sql_statements
    try:
        with closing(engine.raw_connection()) as conn:
            cursor = conn.cursor()
            query = dataset.database.apply_limit_to_sql(statements[0])
            db_engine_spec.execute(cursor, query)
            result = db_engine_spec.fetch_data(cursor, limit=1)
            result_set = SupersetResultSet(result, cursor.description, db_engine_spec)
            cols = result_set.columns
    except Exception as exc:
        raise SupersetGenericDBErrorException(message=str(exc))
    return cols
Example #10
    def validate(
            cls,
            sql: str,
            schema: str,
            database: Any,
    ) -> List[SQLValidationAnnotation]:
        """
        Presto supports query-validation queries by running them with a
        prepended explain.

        For example, "SELECT 1 FROM default.mytable" becomes "EXPLAIN (TYPE
        VALIDATE) SELECT 1 FROM default.mytable.
        """
        user_name = g.user.username if g.user else None
        parsed_query = ParsedQuery(sql)
        statements = parsed_query.get_statements()

        logging.info(f'Validating {len(statements)} statement(s)')
        engine = database.get_sqla_engine(
            schema=schema,
            nullpool=True,
            user_name=user_name,
            source=sources.get('sql_lab', None),
        )
        # Sharing a single connection and cursor across the
        # execution of all statements (if many)
        annotations: List[SQLValidationAnnotation] = []
        with closing(engine.raw_connection()) as conn:
            with closing(conn.cursor()) as cursor:
                for statement in parsed_query.get_statements():
                    annotation = cls.validate_statement(
                        statement,
                        database,
                        cursor,
                        user_name,
                    )
                    if annotation:
                        annotations.append(annotation)
        logging.debug(f'Validation found {len(annotations)} error(s)')

        return annotations
Example #11
def test_is_explain() -> None:
    """
    Test that ``EXPLAIN`` is detected correctly.
    """
    assert ParsedQuery("EXPLAIN SELECT 1").is_explain() is True
    assert ParsedQuery("EXPLAIN SELECT 1").is_select() is False

    assert (ParsedQuery("""
-- comment
EXPLAIN select * from table
-- comment 2
""").is_explain() is True)

    assert (ParsedQuery("""
-- comment
EXPLAIN select * from table
where col1 = 'something'
-- comment 2

-- comment 3
EXPLAIN select * from table
where col1 = 'something'
-- comment 4
""").is_explain() is True)

    assert (ParsedQuery("""
-- This is a comment
    -- this is another comment but with a space in the front
EXPLAIN SELECT * FROM TABLE
""").is_explain() is True)

    assert (ParsedQuery("""
/* This is a comment
     with stars instead */
EXPLAIN SELECT * FROM TABLE
""").is_explain() is True)

    assert (ParsedQuery("""
-- comment
select * from table
where col1 = 'something'
-- comment 2
""").is_explain() is False)
Example #12
    def estimate_statement_cost(  # pylint: disable=too-many-locals
            cls, statement: str, database: "Database", cursor: Any,
            user_name: str) -> Dict[str, Any]:
        """
        Run a SQL query that estimates the cost of a given statement.

        :param statement: A single SQL statement
        :param database: Database instance
        :param cursor: Cursor instance
        :param user_name: Effective username
        :return: JSON response from Presto
        """
        parsed_query = ParsedQuery(statement)
        sql = parsed_query.stripped()

        sql_query_mutator = config["SQL_QUERY_MUTATOR"]
        if sql_query_mutator:
            sql = sql_query_mutator(sql, user_name, security_manager, database)

        sql = f"EXPLAIN (TYPE IO, FORMAT JSON) {sql}"
        cursor.execute(sql)

        # the output from Presto is a single column and a single row containing
        # JSON:
        #
        #   {
        #     ...
        #     "estimate" : {
        #       "outputRowCount" : 8.73265878E8,
        #       "outputSizeInBytes" : 3.41425774958E11,
        #       "cpuCost" : 3.41425774958E11,
        #       "maxMemory" : 0.0,
        #       "networkCost" : 3.41425774958E11
        #     }
        #   }
        result = json.loads(cursor.fetchone()[0])
        return result
Example #13
 def extract_tables(self, query):
     return ParsedQuery(query).tables
Example #14
def execute_sql_statements(
    query_id,
    rendered_query,
    return_results=True,
    store_results=False,
    user_name=None,
    session=None,
    start_time=None,
    expand_data=False,
):  # pylint: disable=too-many-arguments, too-many-locals, too-many-statements
    """Executes the sql query returns the results."""
    if store_results and start_time:
        # only asynchronous queries
        stats_logger.timing("sqllab.query.time_pending",
                            now_as_float() - start_time)

    query = get_query(query_id, session)
    payload = dict(query_id=query_id)
    database = query.database
    db_engine_spec = database.db_engine_spec
    db_engine_spec.patch()

    if database.allow_run_async and not results_backend:
        raise SqlLabException("Results backend isn't configured.")

    # Breaking down into multiple statements
    parsed_query = ParsedQuery(rendered_query)
    statements = parsed_query.get_statements()
    logger.info(f"Query {query_id}: Executing {len(statements)} statement(s)")

    logger.info(f"Query {query_id}: Set query to 'running'")
    query.status = QueryStatus.RUNNING
    query.start_running_time = now_as_float()
    session.commit()

    engine = database.get_sqla_engine(
        schema=query.schema,
        nullpool=True,
        user_name=user_name,
        source=sources.get("sql_lab", None),
    )
    # Sharing a single connection and cursor across the
    # execution of all statements (if many)
    with closing(engine.raw_connection()) as conn:
        with closing(conn.cursor()) as cursor:
            statement_count = len(statements)
            for i, statement in enumerate(statements):
                # Check if stopped
                query = get_query(query_id, session)
                if query.status == QueryStatus.STOPPED:
                    return None

                # Run statement
                msg = f"Running statement {i+1} out of {statement_count}"
                logger.info(f"Query {query_id}: {msg}")
                query.set_extra_json_key("progress", msg)
                session.commit()
                try:
                    cdf = execute_sql_statement(statement, query, user_name,
                                                session, cursor)
                except Exception as e:  # pylint: disable=broad-except
                    msg = str(e)
                    if statement_count > 1:
                        msg = f"[Statement {i+1} out of {statement_count}] " + msg
                    payload = handle_query_error(msg, query, session, payload)
                    return payload

    # Success, updating the query entry in database
    query.rows = cdf.size
    query.progress = 100
    query.set_extra_json_key("progress", None)
    if query.select_as_cta:
        query.select_sql = database.select_star(
            query.tmp_table_name,
            limit=query.limit,
            schema=database.force_ctas_schema,
            show_cols=False,
            latest_partition=False,
        )
    query.end_time = now_as_float()

    data, selected_columns, all_columns, expanded_columns = _serialize_and_expand_data(
        cdf, db_engine_spec, store_results and results_backend_use_msgpack,
        expand_data)

    payload.update({
        "status": QueryStatus.SUCCESS,
        "data": data,
        "columns": all_columns,
        "selected_columns": selected_columns,
        "expanded_columns": expanded_columns,
        "query": query.to_dict(),
    })
    payload["query"]["state"] = QueryStatus.SUCCESS

    if store_results and results_backend:
        key = str(uuid.uuid4())
        logger.info(
            f"Query {query_id}: Storing results in results backend, key: {key}"
        )
        with stats_timing("sqllab.query.results_backend_write", stats_logger):
            with stats_timing(
                    "sqllab.query.results_backend_write_serialization",
                    stats_logger):
                serialized_payload = _serialize_payload(
                    payload, results_backend_use_msgpack)
            cache_timeout = database.cache_timeout
            if cache_timeout is None:
                cache_timeout = config["CACHE_DEFAULT_TIMEOUT"]

            compressed = zlib_compress(serialized_payload)
            logger.debug(
                f"*** serialized payload size: {getsizeof(serialized_payload)}"
            )
            logger.debug(
                f"*** compressed payload size: {getsizeof(compressed)}")
            results_backend.set(key, compressed, cache_timeout)
        query.results_key = key

    query.status = QueryStatus.SUCCESS
    session.commit()

    if return_results:
        return payload

    return None
Example #15
def execute_sql_statement(sql_statement, query, user_name, session, cursor):
    """Executes a single SQL statement"""
    database = query.database
    db_engine_spec = database.db_engine_spec
    parsed_query = ParsedQuery(sql_statement)
    sql = parsed_query.stripped()

    if not parsed_query.is_readonly() and not database.allow_dml:
        raise SqlLabSecurityException(
            _("Only `SELECT` statements are allowed against this database"))
    if query.select_as_cta:
        if not parsed_query.is_select():
            raise SqlLabException(
                _("Only `SELECT` statements can be used with the CREATE TABLE "
                  "feature."))
        if not query.tmp_table_name:
            start_dttm = datetime.fromtimestamp(query.start_time)
            query.tmp_table_name = "tmp_{}_table_{}".format(
                query.user_id, start_dttm.strftime("%Y_%m_%d_%H_%M_%S"))
        sql = parsed_query.as_create_table(query.tmp_table_name)
        query.select_as_cta_used = True
    if parsed_query.is_select():
        if SQL_MAX_ROW and (not query.limit or query.limit > SQL_MAX_ROW):
            query.limit = SQL_MAX_ROW
        if query.limit:
            sql = database.apply_limit_to_sql(sql, query.limit)

    # Hook to allow environment-specific mutation (usually comments) to the SQL
    if SQL_QUERY_MUTATOR:
        sql = SQL_QUERY_MUTATOR(sql, user_name, security_manager, database)

    try:
        if log_query:
            log_query(
                query.database.sqlalchemy_uri,
                query.executed_sql,
                query.schema,
                user_name,
                __name__,
                security_manager,
            )
        query.executed_sql = sql
        session.commit()
        with stats_timing("sqllab.query.time_executing_query", stats_logger):
            logger.info(f"Query {query.id}: Running query: \n{sql}")
            db_engine_spec.execute(cursor, sql, async_=True)
            logger.info(f"Query {query.id}: Handling cursor")
            db_engine_spec.handle_cursor(cursor, query, session)

        with stats_timing("sqllab.query.time_fetching_results", stats_logger):
            logger.debug(
                "Query %d: Fetching data for query object: %s",
                query.id,
                str(query.to_dict()),
            )
            data = db_engine_spec.fetch_data(cursor, query.limit)

    except SoftTimeLimitExceeded as e:
        logger.exception(f"Query {query.id}: {e}")
        raise SqlLabTimeoutException(
            "SQL Lab timeout. This environment's policy is to kill queries "
            "after {} seconds.".format(SQLLAB_TIMEOUT))
    except Exception as e:
        logger.exception(f"Query {query.id}: {e}")
        raise SqlLabException(db_engine_spec.extract_error_message(e))

    logger.debug(f"Query {query.id}: Fetching cursor description")
    cursor_description = cursor.description
    return SupersetDataFrame(data, cursor_description, db_engine_spec)
Example #16
 def sql_tables(self) -> List[Table]:
     return list(ParsedQuery(self.sql).tables)
Example #17
    def estimate_statement_cost(  # pylint: disable=too-many-locals
            cls, statement: str, database, cursor,
            user_name: str) -> Dict[str, str]:
        """
        Generate a SQL query that estimates the cost of a given statement.

        :param statement: A single SQL statement
        :param database: Database instance
        :param cursor: Cursor instance
        :param user_name: Effective username
        """
        parsed_query = ParsedQuery(statement)
        sql = parsed_query.stripped()

        sql_query_mutator = config["SQL_QUERY_MUTATOR"]
        if sql_query_mutator:
            sql = sql_query_mutator(sql, user_name, security_manager, database)

        sql = f"EXPLAIN (TYPE IO, FORMAT JSON) {sql}"
        cursor.execute(sql)

        # the output from Presto is a single column and a single row containing
        # JSON:
        #
        #   {
        #     ...
        #     "estimate" : {
        #       "outputRowCount" : 8.73265878E8,
        #       "outputSizeInBytes" : 3.41425774958E11,
        #       "cpuCost" : 3.41425774958E11,
        #       "maxMemory" : 0.0,
        #       "networkCost" : 3.41425774958E11
        #     }
        #   }
        result = json.loads(cursor.fetchone()[0])
        estimate = result["estimate"]

        def humanize(value: Any, suffix: str) -> str:
            try:
                value = int(value)
            except ValueError:
                return str(value)

            prefixes = ["K", "M", "G", "T", "P", "E", "Z", "Y"]
            prefix = ""
            to_next_prefix = 1000
            while value > to_next_prefix and prefixes:
                prefix = prefixes.pop(0)
                value //= to_next_prefix

            return f"{value} {prefix}{suffix}"

        cost = {}
        columns = [
            ("outputRowCount", "Output count", " rows"),
            ("outputSizeInBytes", "Output size", "B"),
            ("cpuCost", "CPU cost", ""),
            ("maxMemory", "Max memory", "B"),
            ("networkCost", "Network cost", ""),
        ]
        for key, label, suffix in columns:
            if key in estimate:
                cost[label] = humanize(estimate[key], suffix)

        return cost
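Working humanize through the sample estimate in the comment above (values computed by hand, so treat them as illustrative):

# humanize(8.73265878E8, " rows")    -> "873 M rows"
# humanize(3.41425774958E11, "B")    -> "341 GB"
# humanize(0.0, "B")                 -> "0 B"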
Example #18
 def test_get_query_with_new_limit_comment_with_limit(self):
     sql = "SELECT * FROM birth_names -- SOME COMMENT WITH LIMIT 555"
     parsed = ParsedQuery(sql)
     newsql = parsed.set_or_update_query_limit(1000)
     self.assertEqual(newsql, sql + "\nLIMIT 1000")
Example #19
 def is_readonly(sql: str) -> bool:
     return PrestoEngineSpec.is_readonly_query(ParsedQuery(sql))
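A few hypothetical assertions against this helper; the statements are chosen for illustration, and the exact read-only classification may vary by Superset version:

assert is_readonly("SELECT * FROM birth_names")
assert is_readonly("EXPLAIN SELECT 1")
assert not is_readonly("UPDATE birth_names SET name = 'x' WHERE name = 'y'")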
Example #20
 def test_update_not_select(self):
     sql = ParsedQuery("UPDATE t1 SET col1 = NULL")
     self.assertEqual(False, sql.is_select())
Example #21
def execute_sql_statement(
    sql_statement: str,
    query: Query,
    user_name: Optional[str],
    session: Session,
    cursor: Any,
    log_params: Optional[Dict[str, Any]],
    apply_ctas: bool = False,
) -> SupersetResultSet:
    """Executes a single SQL statement"""
    database = query.database
    db_engine_spec = database.db_engine_spec
    parsed_query = ParsedQuery(sql_statement)
    sql = parsed_query.stripped()

    if not db_engine_spec.is_readonly_query(
            parsed_query) and not database.allow_dml:
        raise SqlLabSecurityException(
            _("Only `SELECT` statements are allowed against this database"))
    if apply_ctas:
        if not query.tmp_table_name:
            start_dttm = datetime.fromtimestamp(query.start_time)
            query.tmp_table_name = "tmp_{}_table_{}".format(
                query.user_id, start_dttm.strftime("%Y_%m_%d_%H_%M_%S"))
        sql = parsed_query.as_create_table(
            query.tmp_table_name,
            schema_name=query.tmp_schema_name,
            method=query.ctas_method,
        )
        query.select_as_cta_used = True

    # Do not apply limit to the CTA queries when SQLLAB_CTAS_NO_LIMIT is set to true
    if parsed_query.is_select() and not (query.select_as_cta_used
                                         and SQLLAB_CTAS_NO_LIMIT):
        if SQL_MAX_ROW and (not query.limit or query.limit > SQL_MAX_ROW):
            query.limit = SQL_MAX_ROW
        if query.limit:
            sql = database.apply_limit_to_sql(sql, query.limit)

    # Hook to allow environment-specific mutation (usually comments) to the SQL
    if SQL_QUERY_MUTATOR:
        sql = SQL_QUERY_MUTATOR(sql, user_name, security_manager, database)

    try:
        if log_query:
            log_query(
                query.database.sqlalchemy_uri,
                query.executed_sql,
                query.schema,
                user_name,
                __name__,
                security_manager,
                log_params,
            )
        query.executed_sql = sql
        session.commit()
        with stats_timing("sqllab.query.time_executing_query", stats_logger):
            logger.debug("Query %d: Running query: %s", query.id, sql)
            db_engine_spec.execute(cursor, sql, async_=True)
            logger.debug("Query %d: Handling cursor", query.id)
            db_engine_spec.handle_cursor(cursor, query, session)

        with stats_timing("sqllab.query.time_fetching_results", stats_logger):
            logger.debug(
                "Query %d: Fetching data for query object: %s",
                query.id,
                str(query.to_dict()),
            )
            data = db_engine_spec.fetch_data(cursor, query.limit)

    except SoftTimeLimitExceeded as ex:
        logger.error("Query %d: Time limit exceeded", query.id)
        logger.debug("Query %d: %s", query.id, ex)
        raise SqlLabTimeoutException(
            "SQL Lab timeout. This environment's policy is to kill queries "
            "after {} seconds.".format(SQLLAB_TIMEOUT))
    except Exception as ex:
        logger.error("Query %d: %s", query.id, type(ex))
        logger.debug("Query %d: %s", query.id, ex)
        raise SqlLabException(db_engine_spec.extract_error_message(ex))

    logger.debug("Query %d: Fetching cursor description", query.id)
    cursor_description = cursor.description
    return SupersetResultSet(data, cursor_description, db_engine_spec)
Example #22
def execute_sql_statement(
    sql_statement: str,
    query: Query,
    user_name: Optional[str],
    session: Session,
    cursor: Any,
    log_params: Optional[Dict[str, Any]],
    apply_ctas: bool = False,
) -> SupersetResultSet:
    """Executes a single SQL statement"""
    database = query.database
    db_engine_spec = database.db_engine_spec
    parsed_query = ParsedQuery(sql_statement)
    sql = parsed_query.stripped()
    # Determine whether the query is limited by the dropdown or by the SQL
    # itself: fetch one row more than the limit so we can tell whether more
    # rows exist beyond it.
    increased_limit = None if query.limit is None else query.limit + 1

    if not db_engine_spec.is_readonly_query(parsed_query) and not database.allow_dml:
        raise SupersetErrorException(
            SupersetError(
                message=__("Only SELECT statements are allowed against this database."),
                error_type=SupersetErrorType.DML_NOT_ALLOWED_ERROR,
                level=ErrorLevel.ERROR,
            )
        )
    if apply_ctas:
        if not query.tmp_table_name:
            start_dttm = datetime.fromtimestamp(query.start_time)
            query.tmp_table_name = "tmp_{}_table_{}".format(
                query.user_id, start_dttm.strftime("%Y_%m_%d_%H_%M_%S")
            )
        sql = parsed_query.as_create_table(
            query.tmp_table_name,
            schema_name=query.tmp_schema_name,
            method=query.ctas_method,
        )
        query.select_as_cta_used = True

    # Do not apply limit to the CTA queries when SQLLAB_CTAS_NO_LIMIT is set to true
    if db_engine_spec.is_select_query(parsed_query) and not (
        query.select_as_cta_used and SQLLAB_CTAS_NO_LIMIT
    ):
        if SQL_MAX_ROW and (not query.limit or query.limit > SQL_MAX_ROW):
            query.limit = SQL_MAX_ROW
        if query.limit:
            # We are fetching one more than the requested limit in order
            # to test whether there are more rows than the limit.
            # Later, the extra row will be dropped before sending
            # the results back to the user.
            sql = database.apply_limit_to_sql(sql, increased_limit, force=True)

    # Hook to allow environment-specific mutation (usually comments) to the SQL
    sql = SQL_QUERY_MUTATOR(sql, user_name, security_manager, database)
    try:
        query.executed_sql = sql
        if log_query:
            log_query(
                query.database.sqlalchemy_uri,
                query.executed_sql,
                query.schema,
                user_name,
                __name__,
                security_manager,
                log_params,
            )
        session.commit()
        with stats_timing("sqllab.query.time_executing_query", stats_logger):
            logger.debug("Query %d: Running query: %s", query.id, sql)
            db_engine_spec.execute(cursor, sql, async_=True)
            logger.debug("Query %d: Handling cursor", query.id)
            db_engine_spec.handle_cursor(cursor, query, session)

        with stats_timing("sqllab.query.time_fetching_results", stats_logger):
            logger.debug(
                "Query %d: Fetching data for query object: %s",
                query.id,
                str(query.to_dict()),
            )
            data = db_engine_spec.fetch_data(cursor, increased_limit)
            if query.limit is None or len(data) <= query.limit:
                query.limiting_factor = LimitingFactor.NOT_LIMITED
            else:
                # return 1 row less than increased_query
                data = data[:-1]
    except SoftTimeLimitExceeded as ex:
        logger.warning("Query %d: Time limit exceeded", query.id)
        logger.debug("Query %d: %s", query.id, ex)
        raise SupersetErrorException(
            SupersetError(
                message=__(
                    f"The query was killed after {SQLLAB_TIMEOUT} seconds. It might "
                    "be too complex, or the database might be under heavy load."
                ),
                error_type=SupersetErrorType.SQLLAB_TIMEOUT_ERROR,
                level=ErrorLevel.ERROR,
            )
        )
    except Exception as ex:
        logger.error("Query %d: %s", query.id, type(ex), exc_info=True)
        logger.debug("Query %d: %s", query.id, ex)
        raise SqlLabException(db_engine_spec.extract_error_message(ex))

    logger.debug("Query %d: Fetching cursor description", query.id)
    cursor_description = cursor.description
    return SupersetResultSet(data, cursor_description, db_engine_spec)
Example #23
def execute_sql_statements(
    ctask,
    query_id,
    rendered_query,
    return_results=True,
    store_results=False,
    user_name=None,
    session=None,
    start_time=None,
):
    """Executes the sql query returns the results."""
    if store_results and start_time:
        # only asynchronous queries
        stats_logger.timing("sqllab.query.time_pending", now_as_float() - start_time)

    query = get_query(query_id, session)
    payload = dict(query_id=query_id)
    database = query.database
    db_engine_spec = database.db_engine_spec
    db_engine_spec.patch()

    if store_results and not results_backend:
        raise SqlLabException("Results backend isn't configured.")

    # Breaking down into multiple statements
    parsed_query = ParsedQuery(rendered_query)
    statements = parsed_query.get_statements()
    logging.info(f"Executing {len(statements)} statement(s)")

    logging.info("Set query to 'running'")
    query.status = QueryStatus.RUNNING
    query.start_running_time = now_as_float()

    engine = database.get_sqla_engine(
        schema=query.schema,
        nullpool=True,
        user_name=user_name,
        source=sources.get("sql_lab", None),
    )
    # Sharing a single connection and cursor across the
    # execution of all statements (if many)
    with closing(engine.raw_connection()) as conn:
        with closing(conn.cursor()) as cursor:
            query.connection_id = db_engine_spec.get_connection_id(cursor)
            session.commit()
            statement_count = len(statements)
            for i, statement in enumerate(statements):
                # check if the query was stopped
                session.refresh(query)
                if query.status == QueryStatus.STOPPED:
                    payload.update({"status": query.status})
                    return payload

                msg = f"Running statement {i+1} out of {statement_count}"
                logging.info(msg)
                query.set_extra_json_key("progress", msg)
                session.commit()
                try:
                    cdf = execute_sql_statement(
                        statement, query, user_name, session, cursor
                    )
                    msg = f"Running statement {i+1} out of {statement_count}"
                except Exception as e:
                    # the query can be stopped in another thread/worker,
                    # but in synchronized mode that may surface as an error;
                    # skip the error in that case
                    session.refresh(query)
                    if query.status == QueryStatus.STOPPED:
                        payload.update({"status": query.status})
                        return payload

                    msg = str(e)
                    if statement_count > 1:
                        msg = f"[Statement {i+1} out of {statement_count}] " + msg
                    payload = handle_query_error(msg, query, session, payload)
                    return payload

    # Success, updating the query entry in database
    query.rows = cdf.size
    query.progress = 100
    query.set_extra_json_key("progress", None)
    query.connection_id = None
    if query.select_as_cta:
        query.select_sql = database.select_star(
            query.tmp_table_name,
            limit=query.limit,
            schema=database.force_ctas_schema,
            show_cols=False,
            latest_partition=False,
        )
    query.end_time = now_as_float()

    selected_columns = cdf.columns or []
    data = cdf.data or []
    all_columns, data, expanded_columns = db_engine_spec.expand_data(
        selected_columns, data
    )

    payload.update(
        {
            "status": QueryStatus.SUCCESS,
            "data": data,
            "columns": all_columns,
            "selected_columns": selected_columns,
            "expanded_columns": expanded_columns,
            "query": query.to_dict(),
        }
    )
    payload["query"]["state"] = QueryStatus.SUCCESS

    # go over each row, find bytes columns that start with the magic UAST
    # sequence b'\x00bgr', and replace it with a string containing the
    # UAST in JSON
    for row in payload["data"]:
        for k, v in row.items():
            if isinstance(v, bytes) and len(v) > 4 and v[0:4] == b"\x00bgr":
                try:
                    ctx = decode(v, format=0)
                    row[k] = json.dumps(ctx.load())
                except Exception:
                    pass

    if store_results:
        key = str(uuid.uuid4())
        logging.info(f"Storing results in results backend, key: {key}")
        with stats_timing("sqllab.query.results_backend_write", stats_logger):
            json_payload = json.dumps(
                payload, default=json_iso_dttm_ser, ignore_nan=True
            )
            cache_timeout = database.cache_timeout
            if cache_timeout is None:
                cache_timeout = config.get("CACHE_DEFAULT_TIMEOUT", 0)
            results_backend.set(key, zlib_compress(json_payload), cache_timeout)
        query.results_key = key

    query.status = QueryStatus.SUCCESS
    session.commit()

    if return_results:
        return payload
Example #24
    def validate_statement(
            cls,
            statement,
            database,
            cursor,
            user_name,
    ) -> Optional[SQLValidationAnnotation]:
        # pylint: disable=too-many-locals
        db_engine_spec = database.db_engine_spec
        parsed_query = ParsedQuery(statement)
        sql = parsed_query.stripped()

        # Hook to allow environment-specific mutation (usually comments) to the SQL
        # pylint: disable=invalid-name
        SQL_QUERY_MUTATOR = config.get('SQL_QUERY_MUTATOR')
        if SQL_QUERY_MUTATOR:
            sql = SQL_QUERY_MUTATOR(sql, user_name, security_manager, database)

        # Transform the final statement to an explain call before sending it on
        # to presto to validate
        sql = f'EXPLAIN (TYPE VALIDATE) {sql}'

        # Invoke the query against presto. NB this deliberately doesn't use the
        # engine spec's handle_cursor implementation since we don't record
        # these EXPLAIN queries done in validation as proper Query objects
        # in the superset ORM.
        from pyhive.exc import DatabaseError
        try:
            db_engine_spec.execute(cursor, sql)
            polled = cursor.poll()
            while polled:
                logging.info('polling presto for validation progress')
                stats = polled.get('stats', {})
                if stats:
                    state = stats.get('state')
                    if state == 'FINISHED':
                        break
                time.sleep(0.2)
                polled = cursor.poll()
            db_engine_spec.fetch_data(cursor, MAX_ERROR_ROWS)
            return None
        except DatabaseError as db_error:
            # The pyhive presto client yields EXPLAIN (TYPE VALIDATE) responses
            # as though they were normal queries. In other words, it doesn't
            # know that errors here are not exceptional. To map this back to
            # ordinary control flow, we have to trap the category of exception
            # raised by the underlying client, match the exception arguments
            # pyhive provides against the shape of dictionary for a presto query
            # invalid error, and restructure that error as an annotation we can
            # return up.

            # Confirm the first element in the DatabaseError constructor is a
            # dictionary with error information. This is currently provided by
            # the pyhive client, but may break if their interface changes when
            # we update at some point in the future.
            if not db_error.args or not isinstance(db_error.args[0], dict):
                raise PrestoSQLValidationError(
                    'The pyhive presto client returned an unhandled '
                    'database error.',
                ) from db_error
            error_args: Dict[str, Any] = db_error.args[0]

            # Confirm the two fields we need to be able to present an annotation
            # are present in the error response -- a message, and a location.
            if 'message' not in error_args:
                raise PrestoSQLValidationError(
                    'The pyhive presto client did not report an error message',
                ) from db_error
            if 'errorLocation' not in error_args:
                raise PrestoSQLValidationError(
                    'The pyhive presto client did not report an error location',
                ) from db_error

            # Pylint is confused about the type of error_args, despite the hints
            # and checks above.
            # pylint: disable=invalid-sequence-index
            message = error_args['message']
            err_loc = error_args['errorLocation']
            line_number = err_loc.get('lineNumber', None)
            start_column = err_loc.get('columnNumber', None)
            end_column = err_loc.get('columnNumber', None)

            return SQLValidationAnnotation(
                message=message,
                line_number=line_number,
                start_column=start_column,
                end_column=end_column,
            )
        except Exception as e:
            logging.exception(f'Unexpected error running validation query: {e}')
            raise e
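For reference, a hedged illustration of the error-dict shape this handler unpacks from the pyhive DatabaseError; the keys come from the parsing logic above, while the values are invented:

db_error_args = {
    "message": 'line 1:8: Column "nonexistent" cannot be resolved',
    "errorLocation": {"lineNumber": 1, "columnNumber": 8},
}
# -> SQLValidationAnnotation(message=..., line_number=1,
#                            start_column=8, end_column=8)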
Example #25
    def test_explain(self):
        sql = ParsedQuery("EXPLAIN SELECT 1")

        self.assertEqual(True, sql.is_explain())
        self.assertEqual(False, sql.is_select())
Example #26
def test_update() -> None:
    """
    Test that ``UPDATE`` is not detected as ``SELECT``.
    """
    assert ParsedQuery("UPDATE t1 SET col1 = NULL").is_select() is False
Example #27
def execute_sql_statements(  # pylint: disable=too-many-arguments, too-many-locals, too-many-statements, too-many-branches
    query_id: int,
    rendered_query: str,
    return_results: bool,
    store_results: bool,
    user_name: Optional[str],
    session: Session,
    start_time: Optional[float],
    expand_data: bool,
    log_params: Optional[Dict[str, Any]],
) -> Optional[Dict[str, Any]]:
    """Executes the sql query returns the results."""
    if store_results and start_time:
        # only asynchronous queries
        stats_logger.timing("sqllab.query.time_pending", now_as_float() - start_time)

    query = get_query(query_id, session)
    payload: Dict[str, Any] = dict(query_id=query_id)
    database = query.database
    db_engine_spec = database.db_engine_spec
    db_engine_spec.patch()

    if database.allow_run_async and not results_backend:
        raise SupersetErrorException(
            SupersetError(
                message=__("Results backend is not configured."),
                error_type=SupersetErrorType.RESULTS_BACKEND_NOT_CONFIGURED_ERROR,
                level=ErrorLevel.ERROR,
            )
        )

    # Breaking down into multiple statements
    parsed_query = ParsedQuery(rendered_query, strip_comments=True)
    if not db_engine_spec.run_multiple_statements_as_one:
        statements = parsed_query.get_statements()
        logger.info(
            "Query %s: Executing %i statement(s)", str(query_id), len(statements)
        )
    else:
        statements = [rendered_query]
        logger.info("Query %s: Executing query as a single statement", str(query_id))

    logger.info("Query %s: Set query to 'running'", str(query_id))
    query.status = QueryStatus.RUNNING
    query.start_running_time = now_as_float()
    session.commit()

    # Should we create a table or view from the select?
    if (
        query.select_as_cta
        and query.ctas_method == CtasMethod.TABLE
        and not parsed_query.is_valid_ctas()
    ):
        raise SupersetErrorException(
            SupersetError(
                message=__(
                    "CTAS (create table as select) can only be run with a query where "
                    "the last statement is a SELECT. Please make sure your query has "
                    "a SELECT as its last statement. Then, try running your query "
                    "again."
                ),
                error_type=SupersetErrorType.INVALID_CTAS_QUERY_ERROR,
                level=ErrorLevel.ERROR,
            )
        )
    if (
        query.select_as_cta
        and query.ctas_method == CtasMethod.VIEW
        and not parsed_query.is_valid_cvas()
    ):
        raise SupersetErrorException(
            SupersetError(
                message=__(
                    "CVAS (create view as select) can only be run with a query with "
                    "a single SELECT statement. Please make sure your query has only "
                    "a SELECT statement. Then, try running your query again."
                ),
                error_type=SupersetErrorType.INVALID_CVAS_QUERY_ERROR,
                level=ErrorLevel.ERROR,
            )
        )

    engine = database.get_sqla_engine(
        schema=query.schema,
        nullpool=True,
        user_name=user_name,
        source=QuerySource.SQL_LAB,
    )
    # Sharing a single connection and cursor across the
    # execution of all statements (if many)
    with closing(engine.raw_connection()) as conn:
        # closing the connection closes the cursor as well
        cursor = conn.cursor()
        statement_count = len(statements)
        for i, statement in enumerate(statements):
            # Check if stopped
            query = get_query(query_id, session)
            if query.status == QueryStatus.STOPPED:
                return None

            # For CTAS we create the table only on the last statement
            apply_ctas = query.select_as_cta and (
                query.ctas_method == CtasMethod.VIEW
                or (query.ctas_method == CtasMethod.TABLE and i == len(statements) - 1)
            )

            # Run statement
            msg = f"Running statement {i+1} out of {statement_count}"
            logger.info("Query %s: %s", str(query_id), msg)
            query.set_extra_json_key("progress", msg)
            session.commit()
            try:
                result_set = execute_sql_statement(
                    statement,
                    query,
                    user_name,
                    session,
                    cursor,
                    log_params,
                    apply_ctas,
                )
            except Exception as ex:  # pylint: disable=broad-except
                msg = str(ex)
                prefix_message = (
                    f"[Statement {i+1} out of {statement_count}]"
                    if statement_count > 1
                    else ""
                )
                payload = handle_query_error(
                    ex, query, session, payload, prefix_message
                )
                return payload

        # Commit the connection so CTA queries will create the table.
        conn.commit()

    # Success, updating the query entry in database
    query.rows = result_set.size
    query.progress = 100
    query.set_extra_json_key("progress", None)
    if query.select_as_cta:
        query.select_sql = database.select_star(
            query.tmp_table_name,
            schema=query.tmp_schema_name,
            limit=query.limit,
            show_cols=False,
            latest_partition=False,
        )
    query.end_time = now_as_float()

    use_arrow_data = store_results and cast(bool, results_backend_use_msgpack)
    data, selected_columns, all_columns, expanded_columns = _serialize_and_expand_data(
        result_set, db_engine_spec, use_arrow_data, expand_data
    )

    # TODO: data should be saved separately from metadata (likely in Parquet)
    payload.update(
        {
            "status": QueryStatus.SUCCESS,
            "data": data,
            "columns": all_columns,
            "selected_columns": selected_columns,
            "expanded_columns": expanded_columns,
            "query": query.to_dict(),
        }
    )
    payload["query"]["state"] = QueryStatus.SUCCESS

    if store_results and results_backend:
        key = str(uuid.uuid4())
        logger.info(
            "Query %s: Storing results in results backend, key: %s", str(query_id), key
        )
        with stats_timing("sqllab.query.results_backend_write", stats_logger):
            with stats_timing(
                "sqllab.query.results_backend_write_serialization", stats_logger
            ):
                serialized_payload = _serialize_payload(
                    payload, cast(bool, results_backend_use_msgpack)
                )
            cache_timeout = database.cache_timeout
            if cache_timeout is None:
                cache_timeout = config["CACHE_DEFAULT_TIMEOUT"]

            compressed = zlib_compress(serialized_payload)
            logger.debug(
                "*** serialized payload size: %i", getsizeof(serialized_payload)
            )
            logger.debug("*** compressed payload size: %i", getsizeof(compressed))
            results_backend.set(key, compressed, cache_timeout)
        query.results_key = key

    query.status = QueryStatus.SUCCESS
    session.commit()

    if return_results:
        # since we're returning results we need to create non-arrow data
        if use_arrow_data:
            (
                data,
                selected_columns,
                all_columns,
                expanded_columns,
            ) = _serialize_and_expand_data(
                result_set, db_engine_spec, False, expand_data
            )
            payload.update(
                {
                    "data": data,
                    "columns": all_columns,
                    "selected_columns": selected_columns,
                    "expanded_columns": expanded_columns,
                }
            )
        return payload

    return None
Example #28
def execute_sql_statements(  # pylint: disable=too-many-arguments, too-many-locals, too-many-statements
    query_id: int,
    rendered_query: str,
    return_results: bool,
    store_results: bool,
    user_name: Optional[str],
    session: Session,
    start_time: Optional[float],
    expand_data: bool,
    log_params: Optional[Dict[str, Any]],
) -> Optional[Dict[str, Any]]:
    """Executes the sql query returns the results."""
    if store_results and start_time:
        # only asynchronous queries
        stats_logger.timing("sqllab.query.time_pending",
                            now_as_float() - start_time)

    query = get_query(query_id, session)
    payload: Dict[str, Any] = dict(query_id=query_id)
    database = query.database
    db_engine_spec = database.db_engine_spec
    db_engine_spec.patch()

    if database.allow_run_async and not results_backend:
        raise SqlLabException("Results backend isn't configured.")

    # Breaking down into multiple statements
    parsed_query = ParsedQuery(rendered_query)
    statements = parsed_query.get_statements()
    logger.info("Query %s: Executing %i statement(s)", str(query_id),
                len(statements))

    logger.info("Query %s: Set query to 'running'", str(query_id))
    query.status = QueryStatus.RUNNING
    query.start_running_time = now_as_float()
    session.commit()

    engine = database.get_sqla_engine(
        schema=query.schema,
        nullpool=True,
        user_name=user_name,
        source=QuerySource.SQL_LAB,
    )
    # Sharing a single connection and cursor across the
    # execution of all statements (if many)
    with closing(engine.raw_connection()) as conn:
        with closing(conn.cursor()) as cursor:
            statement_count = len(statements)
            for i, statement in enumerate(statements):
                # Check if stopped
                query = get_query(query_id, session)
                if query.status == QueryStatus.STOPPED:
                    return None

                # Run statement
                msg = f"Running statement {i+1} out of {statement_count}"
                logger.info("Query %s: %s", str(query_id), msg)
                query.set_extra_json_key("progress", msg)
                session.commit()
                try:
                    result_set = execute_sql_statement(statement, query,
                                                       user_name, session,
                                                       cursor, log_params)
                except Exception as ex:  # pylint: disable=broad-except
                    msg = str(ex)
                    if statement_count > 1:
                        msg = f"[Statement {i+1} out of {statement_count}] " + msg
                    payload = handle_query_error(msg, query, session, payload)
                    return payload

        # Commit the connection so CTA queries will create the table.
        conn.commit()

    # Success, updating the query entry in database
    query.rows = result_set.size
    query.progress = 100
    query.set_extra_json_key("progress", None)
    if query.select_as_cta:
        query.select_sql = database.select_star(
            query.tmp_table_name,
            schema=query.tmp_schema_name,
            limit=query.limit,
            show_cols=False,
            latest_partition=False,
        )
    query.end_time = now_as_float()

    use_arrow_data = store_results and cast(bool, results_backend_use_msgpack)
    data, selected_columns, all_columns, expanded_columns = _serialize_and_expand_data(
        result_set, db_engine_spec, use_arrow_data, expand_data)

    # TODO: data should be saved separately from metadata (likely in Parquet)
    payload.update({
        "status": QueryStatus.SUCCESS,
        "data": data,
        "columns": all_columns,
        "selected_columns": selected_columns,
        "expanded_columns": expanded_columns,
        "query": query.to_dict(),
    })
    payload["query"]["state"] = QueryStatus.SUCCESS

    if store_results and results_backend:
        key = str(uuid.uuid4())
        logger.info("Query %s: Storing results in results backend, key: %s",
                    str(query_id), key)
        with stats_timing("sqllab.query.results_backend_write", stats_logger):
            with stats_timing(
                    "sqllab.query.results_backend_write_serialization",
                    stats_logger):
                serialized_payload = _serialize_payload(
                    payload, cast(bool, results_backend_use_msgpack))
            cache_timeout = database.cache_timeout
            if cache_timeout is None:
                cache_timeout = config["CACHE_DEFAULT_TIMEOUT"]

            compressed = zlib_compress(serialized_payload)
            logger.debug("*** serialized payload size: %i",
                         getsizeof(serialized_payload))
            logger.debug("*** compressed payload size: %i",
                         getsizeof(compressed))
            results_backend.set(key, compressed, cache_timeout)
        query.results_key = key

    query.status = QueryStatus.SUCCESS
    session.commit()

    if return_results:
        # since we're returning results we need to create non-arrow data
        if use_arrow_data:
            (
                data,
                selected_columns,
                all_columns,
                expanded_columns,
            ) = _serialize_and_expand_data(result_set, db_engine_spec, False,
                                           expand_data)
            payload.update({
                "data": data,
                "columns": all_columns,
                "selected_columns": selected_columns,
                "expanded_columns": expanded_columns,
            })
        return payload

    return None
Example #29
def get_table_metadata(database: Database, table_name: str,
                       schema_name: Optional[str]) -> Dict:
    """
        Get table metadata information, including type, pk, fks.
        This function raises SQLAlchemyError when a schema is not found.


    :param database: The database model
    :param table_name: Table name
    :param schema_name: schema name
    :return: Dict table metadata ready for API response
    """
    keys: List = []
    columns = database.get_columns(table_name, schema_name)
    # map of column name -> column comment
    comment_dict = {}
    primary_key = database.get_pk_constraint(table_name, schema_name)
    if primary_key and primary_key.get("constrained_columns"):
        primary_key["column_names"] = primary_key.pop("constrained_columns")
        primary_key["type"] = "pk"
        keys += [primary_key]
    # get dialect name
    dialect_name = database.get_dialect().name
    if isinstance(dialect_name, bytes):
        dialect_name = dialect_name.decode()
    # get column comment, presto & hive
    if dialect_name == "presto" or dialect_name == "hive":
        db_engine_spec = database.db_engine_spec
        sql = ParsedQuery("desc {a}.{b}".format(a=schema_name,
                                                b=table_name)).stripped()
        engine = database.get_sqla_engine(schema_name)
        conn = engine.raw_connection()
        cursor = conn.cursor()
        query = Query()
        session = Session(bind=engine)
        query.executed_sql = sql
        query.__tablename__ = table_name
        session.commit()
        db_engine_spec.execute(cursor, sql, async_=False)
        data = db_engine_spec.fetch_data(cursor, query.limit)
        # parse the list data into a dict of column comments; Presto and
        # Hive return the comment in different columns
        if dialect_name == "presto":
            for d in data:
                comment_dict[d[0]] = d[3]
        else:
            for d in data:
                comment_dict[d[0]] = d[2]
        conn.commit()

    foreign_keys = get_foreign_keys_metadata(database, table_name, schema_name)
    indexes = get_indexes_metadata(database, table_name, schema_name)
    keys += foreign_keys + indexes
    payload_columns: List[Dict] = []
    for col in columns:
        dtype = get_col_type(col)
        if len(comment_dict) > 0:
            payload_columns.append({
                "name": col["name"],
                "type": dtype.split("(")[0] if "(" in dtype else dtype,
                "longType": dtype,
                "keys": [k for k in keys if col["name"] in k.get("column_names")],
                "comment": comment_dict[col["name"]],
            })
        elif dialect_name == "mysql":
            payload_columns.append({
                "name": col["name"],
                "type": dtype.split("(")[0] if "(" in dtype else dtype,
                "longType": dtype,
                "keys": [k for k in keys if col["name"] in k.get("column_names")],
                "comment": col["comment"],
            })
        else:
            payload_columns.append({
                "name": col["name"],
                "type": dtype.split("(")[0] if "(" in dtype else dtype,
                "longType": dtype,
                "keys": [k for k in keys if col["name"] in k.get("column_names")],
                # "comment": col["comment"],
            })
    return {
        "name": table_name,
        "columns": payload_columns,
        "selectStar": database.select_star(
            table_name,
            schema=schema_name,
            show_cols=True,
            indent=True,
            cols=columns,
            latest_partition=True,
        ),
        "primaryKey": primary_key,
        "foreignKeys": foreign_keys,
        "indexes": keys,
    }
Example #30
def extract_tables(query: str) -> Set[Table]:
    """
    Helper function to extract tables referenced in a query.
    """
    return ParsedQuery(query).tables
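A quick usage sketch for this helper. In the Superset versions I have seen, .tables yields Table named tuples, though older releases returned plain strings, so treat the exact element type as version-dependent:

from superset.sql_parse import ParsedQuery

parsed = ParsedQuery("SELECT a FROM s.t1 JOIN t2 ON t1.id = t2.id")
print(parsed.tables)
# e.g. {Table(table='t1', schema='s'), Table(table='t2')}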