Beispiel #1
0
def test_multiple_statement_safety_belt():
    """Two-statement scripts that differ only in whitespace parse alike.

    The raw trees are expected to differ, and to compare equal once
    ``_remove_stmt_len_and_location`` has been applied to both.
    """
    compact = parse_sql('select a from x; select b from y')
    spaced = parse_sql('select a from x;\n\nselect b from y')
    # Before cleanup the two trees must not compare equal.
    assert compact != spaced
    for tree in (compact, spaced):
        _remove_stmt_len_and_location(tree)
    assert compact == spaced
Beispiel #2
0
def test_errors():
    """Check the error raised by each entry point on malformed input.

    Covers ``parse_sql``, ``parse_plpgsql`` and ``fingerprint``: each must
    raise an ``Error`` whose type name is ``ParseError`` and which reports
    the expected location and message.
    """
    # Completely unknown statement.
    with pytest.raises(Error) as caught:
        parse_sql('FooBar')
    assert caught.typename == 'ParseError'
    assert caught.value.location == 1
    assert 'syntax error ' in str(caught.value)

    # Misspelled keyword in the middle of a SELECT.
    with pytest.raises(Error) as caught:
        parse_sql('SELECT foo FRON bar')
    assert caught.typename == 'ParseError'
    failure = caught.value
    assert failure.location == 17
    message = str(failure)
    assert 'syntax error at or near "bar"' in message
    assert 'location 17' in message

    # Misspelled keyword handed to the PL/pgSQL parser.
    with pytest.raises(Error) as caught:
        parse_plpgsql('CREATE FUMCTION add (a integer, b integer)'
                      ' RETURNS integer AS $$ BEGIN RETURN a + b; END; $$'
                      ' LANGUAGE plpgsql')
    assert caught.typename == 'ParseError'
    failure = caught.value
    assert failure.location == 8
    message = str(failure)
    assert 'syntax error at or near "FUMCTION"' in message
    assert 'location 8' in message

    # Fingerprinting reports the same failure as plain parsing.
    with pytest.raises(Error) as caught:
        fingerprint('SELECT foo FRON bar')
    assert caught.typename == 'ParseError'
    failure = caught.value
    assert failure.location == 17
    message = str(failure)
    assert 'syntax error at or near "bar"' in message
    assert 'location 17' in message
Beispiel #3
0
def test_pointless_attributes_remotion():
    """Removing statement length/location makes equivalent trees equal.

    The same two statements are parsed with different spacing between
    them; the trees differ at first and match after
    ``_remove_stmt_len_and_location`` is applied to each.
    """
    one_line_tree = parse_sql('select a from x; select b from y')
    multi_line_tree = parse_sql('select a from x;\n\nselect b from y')
    assert one_line_tree != multi_line_tree

    _remove_stmt_len_and_location(one_line_tree)
    _remove_stmt_len_and_location(multi_line_tree)

    assert one_line_tree == multi_line_tree
Beispiel #4
0
def validate_import_sql(sql: str) -> str:
    """
    Check an SQL query to see if it can be safely used in an IMPORT statement
    (e.g. `FROM noaa/climate:latest IMPORT {SELECT * FROM rainfall WHERE state = 'AZ'} AS rainfall`).
    In this case, only a single SELECT statement is supported.

    When validation is not available (``_VALIDATION_SUPPORTED`` is false), a warning
    is logged and the query is returned unchanged.

    :param sql: SQL query
    :return: Canonical (formatted) form of the SQL statement
    :raises: UnsupportedSQLError if the query could not be parsed or validation failed
    """
    if not _VALIDATION_SUPPORTED:
        logging.warning("SQL validation is unsupported on Windows. SQL will be run unvalidated.")
        return sql

    try:
        tree = Node(parse_sql(sql))
    except ParseError as e:
        # Chain explicitly so the original parser error stays attached as the cause.
        raise UnsupportedSQLError("Could not parse %s: %s" % (sql, str(e))) from e
    if len(tree) != 1:
        raise UnsupportedSQLError("The query is supposed to consist of only one SELECT statement!")

    # Every node of the tree is checked against the IMPORT-specific rules.
    for node in tree.traverse():
        _validate_node(
            node,
            permitted_statements=IMPORT_SQL_PERMITTED_STATEMENTS,
            node_validators=_IMPORT_SQL_VALIDATORS,
        )
    return _emit_ast(tree)
Beispiel #5
0
    def clean_sql(self):
        """Validate the ``sql`` field: one runnable SELECT with unique column names.

        The query is parsed with pglast, required to consist of exactly one
        SELECT statement, executed with ``LIMIT 0`` against the first database
        listed in ``settings.DATABASES_DATA``, and its result columns are
        checked for duplicate names.

        Returns:
            The original (unstripped) value of ``cleaned_data["sql"]``.

        Raises:
            ValidationError: if the query does not parse, is not exactly one
                SELECT statement, fails to execute, or yields duplicate
                column names.
        """
        query = self.cleaned_data["sql"].strip().rstrip(";")
        try:
            statements = pglast.parse_sql(query)
        except pglast.parser.ParseError as e:
            raise ValidationError(e) from e

        # Require exactly one statement. An empty or comment-only query
        # parses to no statements and would otherwise fail with an
        # IndexError on statements[0] below.
        if len(statements) != 1:
            raise ValidationError("Enter a single statement")
        statement_dict = statements[0].stmt()
        if statement_dict["@"] != "SelectStmt":
            raise ValidationError("Only SELECT statements are supported")

        # Check that the query runs
        database_alias = list(settings.DATABASES_DATA.items())[0][0]
        with connections[database_alias].cursor() as cursor:
            try:
                cursor.execute(f"SELECT * FROM ({query}) sq LIMIT 0")
            except Exception as e:  # pylint: disable=broad-except
                raise ValidationError(
                    "Error running query. Please check the query runs successfully before saving."
                ) from e
            # Read the result metadata while the cursor is still open.
            columns = [x.name for x in cursor.description if x.name is not None]

        if len(set(columns)) != len(columns):
            raise ValidationError("Duplicate column names found")

        return self.cleaned_data["sql"]
Beispiel #6
0
 def __init__(self, sql: str, print=False):
     """Parse ``sql`` as a single SELECT statement and walk all its clauses.

     :param sql: SQL text; must contain exactly one statement.
     :param print: when true, pretty-print the raw statement tree.
     :raises SqlError: if the text does not contain exactly one statement,
         the statement is not a SELECT, or attributes remain in ``meta.a``
         after all clauses have been accessed.
     """
     self._sql = sql
     statements = parse_sql(sql)
     if len(statements) != 1:
         raise SqlError(f'Multiple statements are not allowed: {sql}'
                        )  # pragma: no cover
     stmt = statements[0]['RawStmt']['stmt']
     if print:
         pprint(stmt)  # pragma: no cover
     meta = extract_meta(stmt)
     self._meta = meta
     if meta.typename != 'SelectStmt':
         raise SqlError(
             'Only SELECT statements are supported')  # pragma: no cover
     # populate members
     # NOTE(review): these bare attribute accesses presumably trigger
     # properties that consume entries from meta.a -- confirm on the class.
     self.target_list
     self.from_clause
     self.where_clause
     self.sort_clause
     self.limit_count
     self.limit_offset
     self.group_clause
     # check if all clauses have been walked
     # Pop outside the assert: under ``python -O`` asserts are stripped, and
     # a pop inside one would leave 'op' behind and trip the check below.
     op = meta.a.pop('op')
     assert op == 0
     if len(meta.a) > 0:
         raise SqlError(f'Unknown syntax: {meta.a}')  # pragma: no cover
Beispiel #7
0
def parse_sql(sql: str):
    """Parse ``sql`` with pglast, converting parser failures to ``SqlError``.

    :param sql: SQL text to parse.
    :return: whatever ``pglast.parse_sql`` returns for ``sql``.
    :raises SqlError: if pglast raises ``ParseError``; the original error is
        kept as the cause.
    """
    # Imported under an alias: this wrapper has the same name as pglast's
    # function, and shadowing it locally makes the call below hard to read.
    from pglast import parse_sql as pglast_parse_sql
    from pglast.parser import ParseError  # pylint: disable=no-name-in-module
    try:
        return pglast_parse_sql(sql)
    except ParseError as error:  # pragma: no cover
        raise SqlError(str(error)) from error
Beispiel #8
0
def workhorse(args):
    """Read one statement and write either its parse tree or its pretty form.

    Input comes from ``args.infile`` or standard input; output goes to
    ``args.outfile`` or standard output. When ``args.parse_tree`` or
    ``args.plpgsql`` is set, the parse tree is dumped as JSON; otherwise the
    statement is prettified. A prettify ``Error`` prints an empty line and
    terminates through ``SystemExit``.
    """
    source = args.infile or sys.stdin
    with source:
        statement = source.read()

    if not (args.parse_tree or args.plpgsql):
        # Pretty-printing mode.
        try:
            prettified = prettify(
                statement,
                compact_lists_margin=args.compact_lists_margin,
                split_string_literals_threshold=args.split_string_literals,
                special_functions=args.special_functions,
                comma_at_eoln=args.comma_at_eoln,
                semicolon_after_last_statement=args.semicolon_after_last_statement)
        except Error as e:
            print()
            raise SystemExit(e)

        sink = args.outfile or sys.stdout
        with sink:
            sink.write(prettified)
            sink.write('\n')
        return

    # Parse-tree mode: the PL/pgSQL parser is used when explicitly requested.
    if args.plpgsql:
        tree = parse_plpgsql(statement)
    else:
        tree = parse_sql(statement)
    if args.no_location:
        _remove_stmt_len_and_location(tree)
    sink = args.outfile or sys.stdout
    with sink:
        json.dump(tree, sink, sort_keys=True, indent=2)
        sink.write('\n')
Beispiel #9
0
    def statements(self):
        """
        Drain ``self.finished_statements`` and yield each parsed statement.

        When ``self.validate`` is set, a statement that makes ``prettify``
        raise (warnings included) is logged at debug level and skipped.

        Returns:
            generator: Return value yielding fully parsed statements from the
            feed.
        """
        log = logging.getLogger()
        while self.finished_statements:
            text, remarks = self.finished_statements.pop(0)
            if self.validate:
                with warnings.catch_warnings():
                    # prettify performs a sanity check that emits warnings;
                    # promote them to exceptions so they are caught below.
                    warnings.filterwarnings("error")
                    try:
                        prettify(text, expression_level=1)
                    except Exception as failure:
                        # Ideally this would raise, but pglast does not know
                        # about REPLICA and the package cannot be patched and
                        # rebuilt here, so the statement is dropped instead.
                        log.debug("Error while parsing statement")
                        log.debug(text)
                        log.debug(failure)
                        continue

            parsed = Node(parse_sql(text)[0]).stmt
            if comments_present := bool(remarks):
                parsed.parse_tree['comments'] = remarks
            parsed.parse_tree['original_string'] = text
            if parsed.node_tag == 'CreateFunctionStmt':
                # Record the language the function is written in.
                for opt in parsed.options:
                    if opt.defname != 'language':
                        continue
                    parsed.parse_tree['language'] = opt.arg.string_value
            yield parsed
 def get_where_clause_list_for_query(query: str):
     """
         Print each node found in the WHERE clause of every statement
         in ``query``.
     """
     for raw_statement in Node(parse_sql(query)):
         where_clause = raw_statement.stmt.whereClause
         for clause_node in where_clause:
             print(str(clause_node))
 def __init__(self, query_string, tables_map: dict):
     """Normalise ``query_string`` and parse its first statement.

     The query is reformatted with sqlparse (comments stripped, reindented,
     surrounding whitespace removed), parsed with pglast, and the first
     statement is handed to ``self._parse_node`` together with
     ``tables_map``.
     """
     self.query_string = sqlparse.format(query_string, strip_comments=True, reindent=True).strip()
     self.table_alias_dict = {}
     # parse query to get where clause columns
     root = Node(parse_sql(self.query_string))
     first_statement = root[0].stmt
     # Only descend when the statement is wrapped in a class named 'Node'.
     if first_statement.__class__.__name__ == 'Node':
         self._parse_node(first_statement, tables_map)
Beispiel #12
0
    def _parse(self):
        """Parse ``self._sql`` and return the body of its single SELECT.

        Returns:
            The value stored under ``'SelectStmt'`` in the statement's parse
            tree.

        Raises:
            ParserError: if the SQL does not contain exactly one statement,
                or that statement is not a SELECT.
        """
        root_node = Node(parse_sql(self._sql))

        # Explicit check rather than ``assert``: assertions are removed under
        # ``python -O``, which would let multi-statement input through.
        if len(root_node) != 1:
            raise ParserError('Expected exactly one SQL statement.')
        statement = root_node[0].parse_tree['stmt']
        if 'SelectStmt' not in statement:
            raise ParserError('Only SELECT statements are supported.')

        return statement['SelectStmt']
Beispiel #13
0
def get_data_hash(cursor, sql):
    """Return an MD5 digest of the rows produced by ``sql``, or ``None``.

    A digest is only computed when the first parsed statement has a
    non-empty ``sortClause``; each row is cast to text by the database and
    fed to the hash in cursor order.
    """
    first_statement = pglast.parse_sql(sql)[0]
    if not first_statement.stmt()["sortClause"]:
        return None

    digest = hashlib.md5()
    cursor.execute(SQL(f"SELECT t.*::TEXT FROM ({sql}) as t"))
    for record in cursor:
        digest.update(record[0].encode("utf-8"))
    return digest.digest()
Beispiel #14
0
def extract_queried_tables_from_sql_query(query):
    """
    Returns a sorted list of (schema, table) tuples extracted from the passed
    PostgreSQL query. Only the first statement of the query is inspected. An
    empty list is returned if the query cannot be parsed (the error is logged)
    or contains no statements.

    This does not communicate with a database, and instead uses pglast to parse the
    query. However, it does not use pglast's built-in functions to extract tables -
    they're buggy in the cases where CTEs have the same names as tables in the
    search path.

    This isn't perfect though - it assumes tables without a schema are in the "public"
    schema, but "public" might not be in the search path, or it might not be the only
    schema in the search path. However, it's probably fine for our usage where "public"
    _is_ in the search path, and the only tables without a schema that we care about in
    our queries are indeed in the public schema - typically only reference dataset tables.
    """

    try:
        statements = pglast.parse_sql(query)
    except pglast.parser.ParseError as e:
        logger.error(e)
        return []

    # An empty or comment-only query yields no statements; without this guard
    # statements[0] below would raise IndexError.
    if not statements:
        return []

    tables = set()

    # Breadth-first walk over the statement's dict form. Each queue entry pairs
    # a node with the CTE names that are in scope for it.
    node_ctenames = deque()
    node_ctenames.append((statements[0](), ()))

    while node_ctenames:
        node, ctenames = node_ctenames.popleft()

        if node.get("withClause", None) is not None:
            if node["withClause"]["recursive"]:
                # Recursive WITH: every CTE name is in scope inside every CTE.
                ctenames += tuple(
                    (cte["ctename"] for cte in node["withClause"]["ctes"]))
                for cte in node["withClause"]["ctes"]:
                    node_ctenames.append((cte, ctenames))
            else:
                # Plain WITH: a CTE only sees the names of the CTEs before it.
                for cte in node["withClause"]["ctes"]:
                    node_ctenames.append((cte, ctenames))
                    ctenames += (cte["ctename"], )

        # A schema-qualified name is always a table; an unqualified one is a
        # table unless it matches a CTE name in scope.
        if node.get("@",
                    None) == "RangeVar" and (node["schemaname"] is not None or
                                             node["relname"] not in ctenames):
            tables.add((node["schemaname"] or "public", node["relname"]))

        for node_type, node_value in node.items():
            # The WITH clause was already queued above with its own scoping.
            if node_type == "withClause":
                continue
            for nested_node in node_value if isinstance(
                    node_value, tuple) else (node_value, ):
                if isinstance(nested_node, dict):
                    node_ctenames.append((nested_node, ctenames))

    return sorted(tables)
Beispiel #15
0
def test_basic():
    """Check the top-level shape returned by ``parse_sql`` and ``parse_plpgsql``.

    Both are expected to yield a one-element sequence whose item is a dict
    with a single key naming the node kind.
    """
    sql_tree = parse_sql('SELECT 1')
    assert isinstance(sql_tree, list)
    assert len(sql_tree) == 1
    first_stmt = sql_tree[0]
    assert isinstance(first_stmt, dict)
    assert first_stmt.keys() == {'RawStmt'}

    plpgsql_tree = parse_plpgsql('CREATE FUNCTION add (a integer, b integer)'
                                 ' RETURNS integer AS $$ BEGIN RETURN a + b; END; $$'
                                 ' LANGUAGE plpgsql')
    assert len(plpgsql_tree) == 1
    first_function = plpgsql_tree[0]
    assert isinstance(first_function, dict)
    assert first_function.keys() == {'PLpgSQL_function'}
def _validate_sql(sqlfile, uri):
    """Collect diagnostics for the SQL text ``sqlfile``.

    A parse failure produces a single error diagnostic located at the
    position reported by the parser. Otherwise every statement is passed to
    ``lint`` and all resulting diagnostics are gathered. ``uri`` is not used
    by this function.
    """
    try:
        statements = parse_sql(sqlfile)
    except ParseError as e:
        pos = char_pos_to_position(sqlfile, e.location)
        parse_failure = Diagnostic(Range(pos, pos),
                                   message=e.args[0],
                                   severity=DiagnosticSeverity.Error,
                                   source=type(pg_language_server).__name__)
        return [parse_failure]

    findings = []
    for statement in statements:
        findings.extend(lint(statement, None, None))
    return findings
Beispiel #17
0
def _parse_string(text):
    """
    Parse ``text`` with ``pglast`` and wrap the result as an AST node.

    If parsing yields nothing, ``pglast.node.Scalar(None)`` is returned
    instead. This is a workaround that keeps the downstream logic simple,
    since ``Context.traverse`` just skips scalar values.

    >>> _parse_string('SELECT 1')
    [1*{RawStmt}]
    >>> _parse_string('-- just a comment')
    <None>
    """
    parsed = pglast.parse_sql(text)
    if not parsed:
        return pglast.node.Scalar(None)
    return pglast.Node(parsed)
Beispiel #18
0
def _parse_column_type(typ):
    """
    Normalise a column type name by round-tripping it through pglast.

    The type is embedded in a throwaway ``CREATE TABLE`` statement, and the
    type name of its only column is formatted with ``format_type_name``.

    >>> _parse_column_type('real')
    'pg_catalog.float4'

    >>> _parse_column_type('double precision')
    'pg_catalog.float8'
    """
    statement_text = 'CREATE TABLE _(_ {0});'.format(typ)
    root = pglast.Node(pglast.parse_sql(statement_text))
    only_column = root[0].stmt.tableElts[0]
    return format_type_name(only_column.typeName)
Beispiel #19
0
async def _query_table(
    dataset: str,
    version: str,
    sql: str,
    geometry: Optional[Geometry],
) -> List[Dict[str, Any]]:
    """Validate ``sql``, pin it to one dataset version and run it.

    The URL-unquoted statement is parsed and passed through a fixed series of
    checks. Its FROM target is then replaced by ``dataset``.``version``, an
    optional geometry filter is added, and the statement is serialised back
    to SQL and executed. Parse errors and query errors are reported as HTTP
    400, insufficient privileges as HTTP 403.
    """
    # Parse the statement; anything unparsable is a client error.
    try:
        parsed = parse_sql(unquote(sql))
    except ParseError as e:
        raise HTTPException(status_code=400, detail=str(e))

    # Run the structural checks in their fixed order.
    for check in (
            _has_only_one_statement,
            _is_select_statement,
            _has_no_with_clause,
            _only_one_from_table,
            _no_subqueries,
            _no_forbidden_functions,
            _no_forbidden_value_functions,
    ):
        check(parsed)

    # Force the queried table to the current dataset version so that no
    # other table can be reached.
    range_var = parsed[0]["RawStmt"]["stmt"]["SelectStmt"]["fromClause"][0][
        "RangeVar"]
    range_var["schemaname"] = dataset
    range_var["relname"] = version

    if geometry:
        parsed = await _add_geometry_filter(parsed, geometry)

    # Serialise the modified tree back into SQL text.
    sql = RawStream()(Node(parsed))

    try:
        rows = await db.all(sql)
        response: List[Dict[str, Any]] = [dict(row) for row in rows]
    except InsufficientPrivilegeError:
        raise HTTPException(status_code=403,
                            detail="Not authorized to execute this query.")
    except (SyntaxOrAccessError, DataError) as e:
        raise HTTPException(status_code=400, detail=f"Bad request. {str(e)}")

    return response
Beispiel #20
0
def parse_column_type(typ):
    """
    Feed the column type through pglast to normalize naming

    e.g. `timestamp with time zone => timestamptz`.

    The type is embedded in a throwaway ``CREATE TABLE`` statement and the
    normalized type of its single column is returned. A pglast ``ParseError``
    is re-raised as ``RuleConfigurationException`` for ``RequireColumns``,
    with the parser error kept as the cause.

    >>> parse_column_type('integer')
    'integer'
    >>> parse_column_type('custom(type)')
    'custom(type)'
    """
    sql = 'CREATE TABLE _(_ {0});'.format(typ)

    try:
        create_table = pglast.Node(pglast.parse_sql(sql))[0].stmt
        _, normalized = _normalize_columns(create_table.tableElts)[0]
        return normalized
    except pglast.parser.ParseError as e:
        # The closing double quote was missing from this message.
        raise RuleConfigurationException(
            RequireColumns, 'unable to parse column type "%s"' % typ) from e
Beispiel #21
0
async def _add_geometry_filter(parsed_sql, geometry: Geometry):
    """Add an ``ST_Intersects`` condition on ``geom`` to a parsed SELECT.

    The condition is built as the WHERE clause of an otherwise empty SELECT,
    parsed, and grafted onto the first statement of ``parsed_sql``. If that
    statement already has a WHERE clause, both conditions are wrapped in a
    ``BoolExpr``; otherwise the new condition becomes the WHERE clause.

    ``parsed_sql`` is modified in place and also returned.
    """
    # The GeoJSON text is embedded in a SQL string literal, so any single
    # quote in it must be doubled; otherwise it would end the literal early
    # and let the remaining text be parsed as part of the filter.
    geojson_literal = geometry.json().replace("'", "''")

    # make empty select statement with where clause including filter
    # this way we can later parse it as AST
    intersect_filter = f"SELECT WHERE ST_Intersects(geom, ST_SetSRID(ST_GeomFromGeoJSON('{geojson_literal}'),4326))"

    # combine the two where clauses
    parsed_filter = parse_sql(intersect_filter)
    filter_where = parsed_filter[0]["RawStmt"]["stmt"]["SelectStmt"][
        "whereClause"]
    select_stmt = parsed_sql[0]["RawStmt"]["stmt"]["SelectStmt"]
    sql_where = select_stmt.get("whereClause", None)

    if sql_where:
        select_stmt["whereClause"] = {
            "BoolExpr": {
                # NOTE(review): 0 is presumably AND_EXPR in PostgreSQL's
                # BoolExprType -- confirm against the pglast enums.
                "boolop": 0,
                "args": [sql_where, filter_where]
            }
        }
    else:
        select_stmt["whereClause"] = filter_where

    return parsed_sql
Beispiel #22
0
def parse(sql: str, name: str = None) -> Parsed:
    """Parse ``sql`` and bundle it with a name into a ``Parsed`` value.

    When ``name`` is omitted, the string form of ``hash(sql)`` is used.
    """
    label = str(hash(sql)) if name is None else name
    tree = parse_sql(sql)
    return Parsed(label, sql, tree)
Beispiel #23
0
def test_unicode():
    """A non-ASCII column alias must survive parsing unchanged."""
    tree = parse_sql('SELECT 1 AS "Naïve"')
    select_stmt = tree[0]['RawStmt']['stmt']['SelectStmt']
    first_target = select_stmt['targetList'][0]['ResTarget']
    assert first_target['name'] == "Naïve"
Beispiel #24
0
def parse(sql):
    """Parse ``sql`` and wrap the resulting tree in an ``AcceptingNode``."""
    parse_tree = parse_sql(sql)
    return AcceptingNode(parse_tree)
Beispiel #25
0
def prepare_splitfile_sql(sql: str, image_mapper: Callable) -> Tuple[str, str]:
    """
    Transform an SQL query to prepare for it to be used in a Splitfile SQL command and validate it.
    The rules are:

      * Only basic DDL (CREATE/ALTER/DROP table) and DML (SELECT/INSERT/UPDATE/DELETE) are permitted.
      * All tables must be either non-schema qualified (the statement is run with `search_path`
        set to the single schema that a Splitgraph image is checked out into) or have schemata of
        format namespace/repository:hash_or_tag. In the second case, the schema is rewritten to point
        at a temporary mount of the Splitgraph image.

    When validation is not available (``_VALIDATION_SUPPORTED`` is false), a warning is logged
    and the result of ``_rewrite_sql_fallback`` is returned instead.

    :param sql: SQL query
    :param image_mapper: Takes in an image and gives back the schema it should be rewritten to
        (for the purposes of execution) and the canonical form of the image.
    :return: Transformed form of the SQL with substituted schema shims for Splitfile execution
        and the canonical form (with e.g. tags resolved into at-the-time full image hashes)
    :raises: UnsupportedSQLError if the query could not be parsed or validation failed
    """

    if not _VALIDATION_SUPPORTED:
        logging.warning("SQL validation is unsupported on Windows. SQL will be run unvalidated.")

        return _rewrite_sql_fallback(sql, image_mapper)

    # Avoid circular import
    from splitgraph.core.output import parse_repo_tag_or_hash

    try:
        tree = Node(parse_sql(sql))
    except ParseError as e:
        # Chain explicitly so the original parser error stays attached as the cause.
        raise UnsupportedSQLError("Could not parse %s: %s" % (sql, str(e))) from e

    # List of dict pointers (into parts of the AST) and new schema names we have
    # to rewrite them to. We need to emit two kinds of rewritten SQL: one with schemata
    # replaced with LQ shims (that we send to PostgreSQL for execution) and one with
    # schemata replaced with full image names that we store in provenance_data for reruns etc.
    # On the first pass, we rewrite the parse tree to have the first kind of schemata, then
    # get pglast to serialize it, then use this dictionary to rewrite just the interesting
    # parts of the parse tree (instead of having to re-traverse/re-crawl the tree).
    future_rewrites = []

    for node in tree.traverse():
        _validate_node(
            node,
            permitted_statements=SPLITFILE_SQL_PERMITTED_STATEMENTS,
            node_validators=_SQL_VALIDATORS,
        )

        # Only table references need rewriting.
        if not isinstance(node, Node) or node.node_tag != "RangeVar":
            continue

        if node["relname"].value in PG_CATALOG_TABLES:
            raise UnsupportedSQLError("Invalid table name %s!" % node["relname"].value)
        # Tables without a schema are left as they are.
        if "schemaname" not in node.attribute_names:
            continue

        schema_name = recover_original_schema_name(sql, node["schemaname"].value)

        # If the table name is schema-qualified, rewrite it to talk to a LQ shim
        repo, hash_or_tag = parse_repo_tag_or_hash(schema_name, default="latest")
        temporary_schema, canonical_name = image_mapper(repo, hash_or_tag)

        # We have to access the internal parse tree here to rewrite the schema.
        node._parse_tree["schemaname"] = temporary_schema
        future_rewrites.append((node._parse_tree, canonical_name))

    # First serialization: schemata point at the temporary shims.
    rewritten_sql = _emit_ast(tree)

    # Second serialization: the same tree nodes now carry the canonical image names.
    for tree_subset, canonical_name in future_rewrites:
        tree_subset["schemaname"] = canonical_name

    canonical_sql = _emit_ast(tree)
    return rewritten_sql, canonical_sql
 def get_where_clause_list_for_query(query: str):
     """Print each node of the WHERE clause of every statement in ``query``."""
     parsed_statements = Node(parse_sql(query))
     for raw_statement in parsed_statements:
         where_nodes = raw_statement.stmt.whereClause
         for where_node in where_nodes:
             print(str(where_node))