def test_multiple_statement_safety_belt():
    """Two statement lists differing only in inter-statement whitespace differ
    as parsed, but compare equal once length/location metadata is removed."""
    compact, spaced = (
        parse_sql(text)
        for text in ('select a from x; select b from y',
                     'select a from x;\n\nselect b from y')
    )
    assert compact != spaced

    for tree in (compact, spaced):
        _remove_stmt_len_and_location(tree)
    assert compact == spaced
def test_errors():
    """parse_sql, parse_plpgsql and fingerprint all report syntax errors as a
    ParseError carrying the offending location."""

    def check(excinfo, location, *fragments):
        # Shared checks: error class name, reported location, message fragments.
        assert excinfo.typename == 'ParseError'
        assert excinfo.value.location == location
        message = str(excinfo.value)
        for fragment in fragments:
            assert fragment in message

    with pytest.raises(Error) as bad_keyword:
        parse_sql('FooBar')
    check(bad_keyword, 1, 'syntax error ')

    with pytest.raises(Error) as misspelled_from:
        parse_sql('SELECT foo FRON bar')
    check(misspelled_from, 17, 'syntax error at or near "bar"', 'location 17')

    with pytest.raises(Error) as misspelled_function:
        parse_plpgsql('CREATE FUMCTION add (a integer, b integer)'
                      ' RETURNS integer AS $$ BEGIN RETURN a + b; END; $$'
                      ' LANGUAGE plpgsql')
    check(misspelled_function, 8,
          'syntax error at or near "FUMCTION"', 'location 8')

    with pytest.raises(Error) as fingerprint_error:
        fingerprint('SELECT foo FRON bar')
    check(fingerprint_error, 17, 'syntax error at or near "bar"', 'location 17')
def test_pointless_attributes_remotion():
    """Stripping statement length/location makes whitespace-only differences
    between two parse trees disappear."""
    first_text = 'select a from x; select b from y'
    second_text = 'select a from x;\n\nselect b from y'
    first_tree = parse_sql(first_text)
    second_tree = parse_sql(second_text)
    assert first_tree != second_tree

    for tree in (first_tree, second_tree):
        _remove_stmt_len_and_location(tree)
    assert first_tree == second_tree
def validate_import_sql(sql: str) -> str:
    """
    Check an SQL query to see if it can be safely used in an IMPORT statement
    (e.g. `FROM noaa/climate:latest IMPORT {SELECT * FROM rainfall WHERE state = 'AZ'} AS rainfall`).
    In this case, only a single SELECT statement is supported.

    When SQL validation is not supported on this platform, a warning is logged
    and ``sql`` is returned unchanged.

    :param sql: SQL query
    :return: Canonical (formatted) form of the SQL statement
    :raises: UnsupportedSQLError if parsing or validation failed
    """
    if not _VALIDATION_SUPPORTED:
        logging.warning("SQL validation is unsupported on Windows. SQL will be run unvalidated.")
        return sql

    try:
        tree = Node(parse_sql(sql))
    except ParseError as e:
        # Chain the parser error so its details survive in the traceback.
        raise UnsupportedSQLError("Could not parse %s: %s" % (sql, str(e))) from e

    if len(tree) != 1:
        raise UnsupportedSQLError("The query is supposed to consist of only one SELECT statement!")

    # Every node must pass the IMPORT-specific statement whitelist and validators.
    for node in tree.traverse():
        _validate_node(
            node,
            permitted_statements=IMPORT_SQL_PERMITTED_STATEMENTS,
            node_validators=_IMPORT_SQL_VALIDATORS,
        )
    return _emit_ast(tree)
def clean_sql(self):
    """Validate the form's ``sql`` field.

    The query has surrounding whitespace and trailing semicolons removed. It
    must then:

    * parse as exactly one SELECT statement,
    * run successfully against the first database in
      ``settings.DATABASES_DATA``, and
    * produce no duplicate column names.

    :return: the original, unmodified ``sql`` value from ``cleaned_data``
    :raises ValidationError: if any of the checks above fails
    """
    query = self.cleaned_data["sql"].strip().rstrip(";")
    try:
        statements = pglast.parse_sql(query)
    except pglast.parser.ParseError as e:
        raise ValidationError(e) from e

    # Reject zero statements too. Empty or comment-only input parses to an
    # empty sequence and would otherwise crash on statements[0].
    if len(statements) != 1:
        raise ValidationError("Enter a single statement")
    statement_dict = statements[0].stmt()
    if statement_dict["@"] != "SelectStmt":
        raise ValidationError("Only SELECT statements are supported")

    # Check that the query runs
    database_alias = list(settings.DATABASES_DATA)[0]
    with connections[database_alias].cursor() as cursor:
        try:
            # The query is interpolated verbatim as a subquery. It was checked
            # above to be a single SELECT statement. LIMIT 0 returns no rows;
            # only cursor.description is used afterwards.
            cursor.execute(f"SELECT * FROM ({query}) sq LIMIT 0")
        except Exception as e:  # pylint: disable=broad-except
            raise ValidationError(
                "Error running query. Please check the query runs successfully before saving."
            ) from e
        columns = [x.name for x in cursor.description if x.name is not None]

    if len(set(columns)) != len(columns):
        raise ValidationError("Duplicate column names found")
    return self.cleaned_data["sql"]
def __init__(self, sql: str, print=False):
    """Parse *sql* as a single SELECT statement and walk every clause of it.

    :param sql: SQL text that must contain exactly one SELECT statement
    :param print: debug aid; when true, pretty-print the raw statement node
        (note: this parameter shadows the ``print`` builtin inside the method)
    :raises SqlError: if *sql* is not exactly one SELECT statement, or if it
        contains syntax that is left over after all clauses have been walked
    """
    self._sql = sql
    statements = parse_sql(sql)
    if len(statements) != 1:
        raise SqlError(f'Multiple statements are not allowed: {sql}'
                       )  # pragma: no cover
    stmt = statements[0]['RawStmt']['stmt']
    if print:
        pprint(stmt)  # pragma: no cover
    meta = extract_meta(stmt)
    self._meta = meta
    if meta.typename != 'SelectStmt':
        raise SqlError(
            'Only SELECT statements are supported')  # pragma: no cover

    # populate members -- presumably each property consumes its clause from
    # meta.a, so whatever remains afterwards is syntax not handled here
    self.target_list
    self.from_clause
    self.where_clause
    self.sort_clause
    self.limit_count
    self.limit_offset
    self.group_clause

    # check if all clauses have been walked.
    # Pop 'op' as a separate statement. Inside an `assert`, the pop would be
    # stripped under `python -O`, leaving 'op' in meta.a and making every
    # query fail the leftover-syntax check below.
    op = meta.a.pop('op')
    if op != 0:
        raise SqlError(f'Unknown syntax: op={op}')  # pragma: no cover
    if len(meta.a) > 0:
        raise SqlError(f'Unknown syntax: {meta.a}')  # pragma: no cover
def parse_sql(sql: str):
    """Parse *sql* with pglast, converting parse failures into ``SqlError``.

    :param sql: SQL text to parse
    :return: the result of ``pglast.parse_sql(sql)``
    :raises SqlError: if pglast cannot parse *sql*
    """
    # Local import: inside this body the name `parse_sql` refers to pglast's
    # function, not to this wrapper.
    from pglast import parse_sql
    from pglast.parser import ParseError  # pylint: disable=no-name-in-module
    try:
        return parse_sql(sql)
    except ParseError as error:  # pragma: no cover
        # Chain so the original pglast error remains visible as the cause.
        raise SqlError(str(error)) from error
def workhorse(args):
    """Run the command line tool on the parsed *args*.

    Reads one SQL text from ``args.infile`` (default stdin). Then either:

    * dumps its parse tree as sorted, indented JSON — when ``args.parse_tree``
      or ``args.plpgsql`` is set (the latter uses the PL/pgSQL parser), or
    * writes the prettified statement.

    Output goes to ``args.outfile`` (default stdout). A prettifier ``Error``
    is turned into ``SystemExit``.
    """
    source = args.infile or sys.stdin
    with source:
        statement = source.read()

    if not (args.parse_tree or args.plpgsql):
        try:
            prettified = prettify(
                statement,
                compact_lists_margin=args.compact_lists_margin,
                split_string_literals_threshold=args.split_string_literals,
                special_functions=args.special_functions,
                comma_at_eoln=args.comma_at_eoln,
                semicolon_after_last_statement=args.semicolon_after_last_statement)
        except Error as e:
            print()
            raise SystemExit(e)
        sink = args.outfile or sys.stdout
        with sink:
            sink.write(prettified)
            sink.write('\n')
        return

    if args.plpgsql:
        tree = parse_plpgsql(statement)
    else:
        tree = parse_sql(statement)
    if args.no_location:
        _remove_stmt_len_and_location(tree)

    sink = args.outfile or sys.stdout
    with sink:
        json.dump(tree, sink, sort_keys=True, indent=2)
        sink.write('\n')
def statements(self):
    """
    Returns:
        generator: Yields fully parsed statement nodes drained from
        ``self.finished_statements``. When ``self.validate`` is set, statements
        that pglast cannot prettify are logged at DEBUG level and skipped.
    """
    logger = logging.getLogger()

    def prettify_fails(text):
        # Warnings are escalated to errors because prettify's sanity check
        # reports problems as warnings.
        with warnings.catch_warnings():
            warnings.filterwarnings("error")
            try:
                prettify(text, expression_level=1)
            except Exception as e:
                # We should raise Error, but pglast don't know about REPLICA
                # And we are not able to fix and rebuild pglast package
                logger.debug("Error while parsing statement")
                logger.debug(text)
                logger.debug(e)
                return True
        return False

    while self.finished_statements:
        stmt, comments = self.finished_statements.pop(0)
        if self.validate and prettify_fails(stmt):
            continue

        parsed = Node(parse_sql(stmt)[0]).stmt
        if comments:
            parsed.parse_tree['comments'] = comments
        parsed.parse_tree['original_string'] = stmt

        if parsed.node_tag == 'CreateFunctionStmt':
            # Record the function's language; the last 'language' option wins.
            for option in parsed.options:
                if option.defname == 'language':
                    parsed.parse_tree['language'] = option.arg.string_value
        yield parsed
def get_where_clause_list_for_query(query: str):
    """
    Retrieve the WHERE clause items of every statement in *query*.

    Each item is printed (as before) and is also collected and returned, so
    callers can use the result instead of scraping stdout.

    :param query: SQL text, possibly containing several statements
    :return: list of ``str(node)`` for each item iterated from each
        statement's ``whereClause``
    """
    query_tree = Node(parse_sql(query))
    where_items = []
    for raw_stmt in query_tree:
        for node in raw_stmt.stmt.whereClause:
            text = str(node)
            print(text)
            where_items.append(text)
    return where_items
def __init__(self, query_string, tables_map: dict):
    """Normalise *query_string* and extract WHERE-clause column information.

    :param query_string: SQL text. It is formatted with sqlparse (comments
        stripped, reindented) before being parsed by pglast.
    :param tables_map: passed through unchanged to ``self._parse_node``
    """
    self.query_string = sqlparse.format(
        query_string, strip_comments=True, reindent=True).strip()
    self.table_alias_dict = {}
    # parse query to get where clause columns
    root = Node(parse_sql(self.query_string))
    first_stmt = root[0].stmt
    # Class-name comparison kept from the original. Presumably this skips
    # non-Node wrappers (Missing/Scalar/List) -- TODO confirm.
    if type(first_stmt).__name__ == 'Node':
        self._parse_node(first_stmt, tables_map)
def _parse(self):
    """Parse ``self._sql`` and return its ``SelectStmt`` parse tree.

    :return: the dict stored under ``'SelectStmt'`` in the statement tree
    :raises ParserError: if ``self._sql`` does not contain exactly one
        statement, or if that statement is not a SELECT
    """
    root_node = Node(parse_sql(self._sql))
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if len(root_node) != 1:
        raise ParserError('Only a single statement is supported.')
    statement = root_node[0].parse_tree['stmt']
    if 'SelectStmt' not in statement:
        raise ParserError('Only SELECT statements are supported.')
    return statement['SelectStmt']
def get_data_hash(cursor, sql):
    """Return an MD5 digest of the rows produced by *sql*, or None.

    A digest is computed only when the first statement has an ORDER BY
    (``sortClause``). Each row is cast to TEXT and fed to the hash in cursor
    order.

    :param cursor: DB-API cursor used to run the query
    :param sql: SQL text. NOTE(review): it is interpolated verbatim into the
        wrapping query, so it is presumably trusted -- confirm with callers.
    :return: the MD5 digest bytes, or None when there is no statement or the
        first statement has no ``sortClause`` (including non-SELECT statements)
    """
    statements = pglast.parse_sql(sql)
    # Empty/comment-only input yields no statements. Non-SELECT statements
    # have no "sortClause" key; .get() avoids a KeyError there.
    if not statements or not statements[0].stmt().get("sortClause"):
        return None

    hashed_data = hashlib.md5()
    cursor.execute(SQL(f"SELECT t.*::TEXT FROM ({sql}) as t"))
    for row in cursor:
        hashed_data.update(row[0].encode("utf-8"))
    return hashed_data.digest()
def extract_queried_tables_from_sql_query(query):
    """
    Returns a list of (schema, table) tuples extracted from the passed PostgreSQL query

    This does not communicate with a database, and instead uses pglast to parse the
    query. However, it does not use pglast's built-in functions to extract tables -
    they're buggy in the cases where CTEs have the same names as tables in the search
    path.

    This isn't perfect though - it assumes tables without a schema are in the "public"
    schema, but "public" might not be in the search path, or it might not be the only
    schema in the search path. However, it's probably fine for our usage where "public"
    _is_ in the search path, and the only tables without a schema that we care about in
    our queries are indeed in the public schema - typically only reference dataset tables.

    Only the first statement of the query is inspected. An empty list is returned if the
    query cannot be parsed or contains no statements.
    """
    try:
        statements = pglast.parse_sql(query)
    except pglast.parser.ParseError as e:
        logger.error(e)
        return []

    # Empty or comment-only input parses to zero statements.
    if not statements:
        return []

    tables = set()
    # Breadth-first walk over the statement's dict form. Each queue entry pairs
    # a node with the CTE names visible at that point of the tree.
    node_ctenames = deque([(statements[0](), ())])
    while node_ctenames:
        node, ctenames = node_ctenames.popleft()

        with_clause = node.get("withClause")
        if with_clause is not None:
            ctes = with_clause["ctes"]
            if with_clause["recursive"]:
                # WITH RECURSIVE: all CTE names are made visible to every CTE.
                ctenames += tuple(cte["ctename"] for cte in ctes)
                for cte in ctes:
                    node_ctenames.append((cte, ctenames))
            else:
                # Plain WITH: each CTE only sees the names declared before it.
                for cte in ctes:
                    node_ctenames.append((cte, ctenames))
                    ctenames += (cte["ctename"],)

        # A schema-qualified RangeVar is always a table. An unqualified one is a
        # table unless it names a visible CTE.
        if node.get("@") == "RangeVar" and (
                node["schemaname"] is not None or node["relname"] not in ctenames):
            tables.add((node["schemaname"] or "public", node["relname"]))

        for node_type, node_value in node.items():
            if node_type == "withClause":
                # Already queued above with the correct CTE visibility.
                continue
            children = node_value if isinstance(node_value, tuple) else (node_value,)
            for nested_node in children:
                if isinstance(nested_node, dict):
                    node_ctenames.append((nested_node, ctenames))

    return sorted(tables)
def test_basic():
    """parse_sql yields a one-element list holding a RawStmt dict, and
    parse_plpgsql yields one PLpgSQL_function dict."""
    sql_tree = parse_sql('SELECT 1')
    assert isinstance(sql_tree, list)
    assert len(sql_tree) == 1
    raw_statement = sql_tree[0]
    assert isinstance(raw_statement, dict)
    assert set(raw_statement) == {'RawStmt'}

    function_source = ('CREATE FUNCTION add (a integer, b integer)'
                       ' RETURNS integer AS $$ BEGIN RETURN a + b; END; $$'
                       ' LANGUAGE plpgsql')
    plpgsql_tree = parse_plpgsql(function_source)
    assert len(plpgsql_tree) == 1
    function_node = plpgsql_tree[0]
    assert isinstance(function_node, dict)
    assert set(function_node) == {'PLpgSQL_function'}
def _validate_sql(sqlfile, uri):
    """Return language-server diagnostics for the SQL text *sqlfile*.

    A parse failure produces a single error diagnostic at the failing
    character position. Otherwise the diagnostics reported by ``lint`` for
    each parsed statement are returned. *uri* is accepted but not used.
    """
    try:
        statements = parse_sql(sqlfile)
    except ParseError as e:
        pos = char_pos_to_position(sqlfile, e.location)
        return [
            Diagnostic(Range(pos, pos),
                       message=e.args[0],
                       severity=DiagnosticSeverity.Error,
                       source=type(pg_language_server).__name__)
        ]

    return [
        diag
        for statement in statements
        for diag in lint(statement, None, None)
    ]
def _parse_string(text):
    """
    Turn ``text`` into a pglast AST node.

    When nothing parses (e.g. comment-only input), ``pglast.node.Scalar(None)``
    is returned instead. This is a hack that keeps downstream logic simple:
    ``Context.traverse`` just ignores scalar values.

    >>> _parse_string('SELECT 1')
    [1*{RawStmt}]

    >>> _parse_string('-- just a comment')
    <None>
    """
    raw_statements = pglast.parse_sql(text)
    if not raw_statements:
        return pglast.node.Scalar(None)
    return pglast.Node(raw_statements)
def _parse_column_type(typ):
    """
    Normalise a column type name by round-tripping it through pglast.

    The type is embedded in a throwaway ``CREATE TABLE`` statement. The parsed
    type name of its single column is then formatted back to text.

    >>> _parse_column_type('real')
    'pg_catalog.float4'

    >>> _parse_column_type('double precision')
    'pg_catalog.float8'
    """
    probe = f'CREATE TABLE _(_ {typ});'
    statements = pglast.Node(pglast.parse_sql(probe))
    first_column = statements[0].stmt.tableElts[0]
    return format_type_name(first_column.typeName)
async def _query_table(
    dataset: str,
    version: str,
    sql: str,
    geometry: Optional[Geometry],
) -> List[Dict[str, Any]]:
    """Validate a URL-quoted SELECT, pin it to ``dataset.version`` and run it.

    :param dataset: schema name forced onto the query's single FROM table
    :param version: table name forced onto the query's single FROM table
    :param sql: URL-quoted SQL text
    :param geometry: optional geometry; if given, an intersection filter is
        added to the WHERE clause
    :return: result rows as dicts
    :raises HTTPException: 400 on parse, syntax/access or data errors; 403 on
        insufficient privileges. NOTE(review): the ``_…`` validators
        presumably raise HTTPException as well -- confirm.
    """
    # parse and validate SQL statement
    try:
        parsed = parse_sql(unquote(sql))
    except ParseError as e:
        raise HTTPException(status_code=400, detail=str(e))

    validators = (
        _has_only_one_statement,
        _is_select_statement,
        _has_no_with_clause,
        _only_one_from_table,
        _no_subqueries,
        _no_forbidden_functions,
        _no_forbidden_value_functions,
    )
    for validate in validators:
        validate(parsed)

    # always overwrite the table name with the current dataset version name,
    # to make sure no other table is queried
    range_var = parsed[0]["RawStmt"]["stmt"]["SelectStmt"]["fromClause"][0]["RangeVar"]
    range_var["schemaname"] = dataset
    range_var["relname"] = version

    if geometry:
        parsed = await _add_geometry_filter(parsed, geometry)

    # convert back to text
    sql = RawStream()(Node(parsed))

    try:
        rows = await db.all(sql)
        response: List[Dict[str, Any]] = [dict(row) for row in rows]
    except InsufficientPrivilegeError:
        raise HTTPException(status_code=403, detail="Not authorized to execute this query.")
    except (SyntaxOrAccessError, DataError) as e:
        raise HTTPException(status_code=400, detail=f"Bad request. {str(e)}")

    return response
def parse_column_type(typ):
    """
    Feed the column type through pglast to normalize naming
    e.g. `timestamp with time zone => timestamptz`.

    :param typ: column type as written in SQL
    :return: the normalized type name
    :raises RuleConfigurationException: if pglast cannot parse the type

    >>> parse_column_type('integer')
    'integer'

    >>> parse_column_type('custom(type)')
    'custom(type)'
    """
    sql = 'CREATE TABLE _(_ {0});'.format(typ)
    try:
        create_table = pglast.Node(pglast.parse_sql(sql))[0].stmt
        # Separate name so `typ` always holds the caller's input.
        _, normalized = _normalize_columns(create_table.tableElts)[0]
        return normalized
    except pglast.parser.ParseError as e:
        # The closing quote was missing from the message format.
        raise RuleConfigurationException(
            RequireColumns, 'unable to parse column type "%s"' % typ) from e
async def _add_geometry_filter(parsed_sql, geometry: Geometry):
    """AND an ``ST_Intersects`` filter for *geometry* into the WHERE clause.

    *parsed_sql* is mutated in place and also returned.
    """
    # Parse an otherwise empty SELECT whose WHERE holds the filter, so the
    # filter can be grafted in as an AST fragment.
    # NOTE(review): geometry.json() is interpolated into a SQL string literal;
    # this assumes the serialized geometry cannot contain a single quote.
    filter_sql = f"SELECT WHERE ST_Intersects(geom, ST_SetSRID(ST_GeomFromGeoJSON('{geometry.json()}'),4326))"
    filter_where = parse_sql(filter_sql)[0]["RawStmt"]["stmt"]["SelectStmt"]["whereClause"]

    select_stmt = parsed_sql[0]["RawStmt"]["stmt"]["SelectStmt"]
    existing_where = select_stmt.get("whereClause", None)
    if not existing_where:
        select_stmt["whereClause"] = filter_where
    else:
        # boolop 0 -- presumably AND_EXPR in PostgreSQL's BoolExprType; confirm
        select_stmt["whereClause"] = {
            "BoolExpr": {
                "boolop": 0,
                "args": [existing_where, filter_where]
            }
        }
    return parsed_sql
def parse(sql: str, name: str = None) -> Parsed:
    """Parse *sql* and bundle it with a name into a ``Parsed`` record.

    :param sql: SQL text to parse
    :param name: label for the result; defaults to ``str(hash(sql))``
    :return: ``Parsed(name, sql, parse_sql(sql))``
    """
    # NOTE(review): str hashes are salted per process (PYTHONHASHSEED), so the
    # default name is stable within one run but not across runs.
    label = str(hash(sql)) if name is None else name
    return Parsed(label, sql, parse_sql(sql))
def test_unicode():
    """A non-ASCII quoted alias survives parsing unchanged."""
    tree = parse_sql('SELECT 1 AS "Naïve"')
    select = tree[0]['RawStmt']['stmt']['SelectStmt']
    res_target = select['targetList'][0]['ResTarget']
    assert res_target['name'] == "Naïve"
def parse(sql):
    """Parse *sql* and wrap the resulting raw parse tree in an ``AcceptingNode``."""
    raw_tree = parse_sql(sql)
    return AcceptingNode(raw_tree)
def prepare_splitfile_sql(sql: str, image_mapper: Callable) -> Tuple[str, str]:
    """
    Transform an SQL query to prepare for it to be used in a Splitfile SQL command and validate it.
    The rules are:

      * Only basic DDL (CREATE/ALTER/DROP table) and DML (SELECT/INSERT/UPDATE/DELETE) are permitted.
      * All tables must be either non-schema qualified (the statement is run with `search_path`
        set to the single schema that a Splitgraph image is checked out into) or have schemata of
        format namespace/repository:hash_or_tag. In the second case, the schema is rewritten to point
        at a temporary mount of the Splitgraph image.

    When SQL validation is not supported on this platform, a warning is logged and
    ``_rewrite_sql_fallback(sql, image_mapper)`` is returned instead.

    :param sql: SQL query
    :param image_mapper: Takes in an image and gives back the schema it should be rewritten to
        (for the purposes of execution) and the canonical form of the image.
    :return: Transformed form of the SQL with substituted schema shims for Splitfile execution
        and the canonical form (with e.g. tags resolved into at-the-time full image hashes)
    :raises: UnsupportedSQLError if parsing or validation failed
    """
    if not _VALIDATION_SUPPORTED:
        logging.warning("SQL validation is unsupported on Windows. SQL will be run unvalidated.")
        return _rewrite_sql_fallback(sql, image_mapper)

    # Avoid circular import
    from splitgraph.core.output import parse_repo_tag_or_hash

    try:
        tree = Node(parse_sql(sql))
    except ParseError as e:
        # Chain the parser error so its details survive in the traceback.
        raise UnsupportedSQLError("Could not parse %s: %s" % (sql, str(e))) from e

    # List of dict pointers (into parts of the AST) and new schema names we have
    # to rewrite them to. We need to emit two kinds of rewritten SQL: one with schemata
    # replaced with LQ shims (that we send to PostgreSQL for execution) and one with
    # schemata replaced with full image names that we store in provenance_data for reruns etc.
    # On the first pass, we rewrite the parse tree to have the first kind of schemata, then
    # get pglast to serialize it, then use this dictionary to rewrite just the interesting
    # parts of the parse tree (instead of having to re-traverse/re-crawl the tree).
    future_rewrites = []

    for node in tree.traverse():
        _validate_node(
            node,
            permitted_statements=SPLITFILE_SQL_PERMITTED_STATEMENTS,
            node_validators=_SQL_VALIDATORS,
        )

        # Only table references (RangeVar nodes) are candidates for rewriting.
        if not isinstance(node, Node) or node.node_tag != "RangeVar":
            continue

        if node["relname"].value in PG_CATALOG_TABLES:
            raise UnsupportedSQLError("Invalid table name %s!" % node["relname"].value)
        if "schemaname" not in node.attribute_names:
            continue

        schema_name = recover_original_schema_name(sql, node["schemaname"].value)

        # If the table name is schema-qualified, rewrite it to talk to a LQ shim
        repo, hash_or_tag = parse_repo_tag_or_hash(schema_name, default="latest")
        temporary_schema, canonical_name = image_mapper(repo, hash_or_tag)

        # We have to access the internal parse tree here to rewrite the schema.
        node._parse_tree["schemaname"] = temporary_schema
        future_rewrites.append((node._parse_tree, canonical_name))

    # First serialization: schemata point at the temporary shims.
    rewritten_sql = _emit_ast(tree)
    # Second serialization: same tree, schemata replaced with canonical image names.
    for tree_subset, canonical_name in future_rewrites:
        tree_subset["schemaname"] = canonical_name
    canonical_sql = _emit_ast(tree)

    return rewritten_sql, canonical_sql
def get_where_clause_list_for_query(query: str):
    """
    Retrieve the WHERE clause items of every statement in *query*.

    Each item is printed (as before) and is also collected and returned, so
    callers can use the result instead of scraping stdout.

    :param query: SQL text, possibly containing several statements
    :return: list of ``str(node)`` for each item iterated from each
        statement's ``whereClause``
    """
    query_tree = Node(parse_sql(query))
    where_items = []
    for raw_stmt in query_tree:
        for node in raw_stmt.stmt.whereClause:
            text = str(node)
            print(text)
            where_items.append(text)
    return where_items