def test_sqllineage_sql_parser_with_weird_lookml_query():
    """Columns and aliases are recovered from a query that has type names after
    column names and a stray trailing quote."""
    weird_query = """ SELECT date DATE, platform VARCHAR(20) AS aliased_platform, country VARCHAR(20) FROM fragment_derived_view' """
    expected = ["aliased_platform", "country", "date"]
    assert sorted(SqlLineageSQLParser(weird_query).get_columns()) == expected
def test_sqllineage_sql_parser_get_tables_from_templated_query():
    """A `${view.SQL_TABLE_NAME}` template placeholder is reported as the table name."""
    templated_query = """ SELECT country, city, timestamp, measurement FROM ${my_view.SQL_TABLE_NAME} AS my_view """
    found_tables = sorted(SqlLineageSQLParser(templated_query).get_tables())
    assert found_tables == ["my_view.SQL_TABLE_NAME"]
def test_sqllineage_sql_parser_get_columns_from_templated_query():
    """Selected columns are extracted even when the FROM target is a template placeholder."""
    templated_query = """ SELECT country, city, timestamp, measurement FROM ${my_view.SQL_TABLE_NAME} AS my_view """
    found_columns = sorted(SqlLineageSQLParser(templated_query).get_columns())
    assert found_columns == ["city", "country", "measurement", "timestamp"]
def test_sqllineage_sql_parser_get_tables_from_complex_query():
    """Both schema-qualified source tables of a UNION ALL query are reported."""
    union_query = """ ( SELECT CAST(substring(e, 1, 10) AS date) AS __d_a_t_e, e AS e, u AS u, x, c, count(*) FROM schema1.foo WHERE datediff('day', substring(e, 1, 10) :: date, date :: date) <= 7 AND CAST(substring(e, 1, 10) AS date) >= date('2010-01-01') AND CAST(substring(e, 1, 10) AS date) < getdate() GROUP BY 1, 2, 3, 4, 5) UNION ALL( SELECT CAST(substring(e, 1, 10) AS date) AS date, e AS e, u AS u, x, c, count(*) FROM schema2.bar WHERE datediff('day', substring(e, 1, 10) :: date, date :: date) <= 7 AND CAST(substring(e, 1, 10) AS date) >= date('2020-08-03') AND CAST(substring(e, 1, 10) AS date) < getdate() GROUP BY 1, 2, 3, 4, 5) """
    lineage_parser = SqlLineageSQLParser(union_query)
    assert sorted(lineage_parser.get_tables()) == ["schema1.foo", "schema2.bar"]
def test_sqllineage_sql_parser_get_columns_complex_query_with_union():
    """Output columns of a UNION ALL query are reported once each; count(*) is not a column."""
    union_query = """ ( SELECT CAST(substring(e, 1, 10) AS date) AS date , e AS e, u AS u, x, c, count(*) FROM foo WHERE datediff('day', substring(e, 1, 10) :: date, date :: date) <= 7 AND CAST(substring(e, 1, 10) AS date) >= date('2010-01-01') AND CAST(substring(e, 1, 10) AS date) < getdate() GROUP BY 1, 2, 3, 4, 5) UNION ALL( SELECT CAST(substring(e, 1, 10) AS date) AS date, e AS e, u AS u, x, c, count(*) FROM bar WHERE datediff('day', substring(e, 1, 10) :: date, date :: date) <= 7 AND CAST(substring(e, 1, 10) AS date) >= date('2020-08-03') AND CAST(substring(e, 1, 10) AS date) < getdate() GROUP BY 1, 2, 3, 4, 5) """
    lineage_parser = SqlLineageSQLParser(union_query)
    assert sorted(lineage_parser.get_columns()) == ["c", "date", "e", "u", "x"]
def test_sqllineage_sql_parser_tables_from_redash_query():
    """Backtick-quoted tables joined with aliases are reported without their quotes."""
    redash_query = """SELECT name, SUM(quantity * list_price * (1 - discount)) AS total, YEAR(order_date) as order_year FROM `orders` o INNER JOIN `order_items` i ON i.order_id = o.order_id INNER JOIN `staffs` s ON s.staff_id = o.staff_id GROUP BY name, year(order_date)"""
    joined_tables = sorted(SqlLineageSQLParser(redash_query).get_tables())
    assert joined_tables == ["order_items", "orders", "staffs"]
def test_metadatasql_sql_parser_get_columns_with_more_complex_join():
    """INSERT ... SELECT with implicit aliases yields alias names (or bare column names)."""
    insert_query = """ INSERT INTO foo SELECT pl.pi pi, REGEXP_REPLACE(pl.tt, '_', ' ') pt, pl.tt pu, fp.v, fp.bs FROM bar pl JOIN baz fp ON fp.rt = pl.rt WHERE fp.dt = '2018-01-01' """
    expected = ["bs", "pi", "pt", "pu", "v"]
    assert sorted(SqlLineageSQLParser(insert_query).get_columns()) == expected
def test_metadatasql_sql_parser_get_columns_with_alias_and_count_star():
    """Explicit aliases (including one on count(*)) replace the underlying column names."""
    aliased_query = "SELECT foo.a, foo.b, bar.c as test, count(*) as count FROM foo JOIN bar ON (foo.a == bar.b);"
    expected = ["a", "b", "count", "test"]
    assert sorted(SqlLineageSQLParser(aliased_query).get_columns()) == expected
def test_sqllineage_sql_parser_get_columns_from_simple_query():
    """Table qualifiers are dropped from the reported column names."""
    simple_query = "SELECT foo.a, foo.b FROM foo;"
    assert sorted(SqlLineageSQLParser(simple_query).get_columns()) == ["a", "b"]
def test_sqllineage_sql_parser_get_columns_with_join():
    """Columns from both sides of a JOIN are reported, unqualified."""
    join_query = "SELECT foo.a, foo.b, bar.c FROM foo JOIN bar ON (foo.a == bar.b);"
    assert sorted(SqlLineageSQLParser(join_query).get_columns()) == ["a", "b", "c"]
class BigQuerySQLParser(SQLParser):
    """SQL parser that rewrites BigQuery-specific syntax before delegating to
    ``SqlLineageSQLParser``.

    ``parse_sql_query`` applies a fixed sequence of regex rewrites plus a
    ``sqlparse.format`` pass. ``get_tables`` / ``get_columns`` then forward to
    the wrapped parser built from the rewritten query.
    """

    # The parser that performs the actual table/column extraction.
    parser: SQLParser

    def __init__(self, sql_query: str) -> None:
        super().__init__(sql_query)
        self._parsed_sql_query = self.parse_sql_query(sql_query)
        self.parser = SqlLineageSQLParser(self._parsed_sql_query)

    def parse_sql_query(self, sql_query: str) -> str:
        """Return *sql_query* rewritten into a form the lineage parser accepts.

        Order matters:
        ``#`` comments are turned into ``--`` comments first, so that
        ``strip_comments=True`` in the ``sqlparse.format`` call removes them.
        The table/view-name and FROM-target escapes run after formatting, on
        the re-indented text.
        """
        sql_query = BigQuerySQLParser._parse_bigquery_comment_sign(sql_query)
        sql_query = BigQuerySQLParser._escape_keyword_from_as_field_name(sql_query)
        sql_query = BigQuerySQLParser._escape_cte_name_after_keyword_with(sql_query)
        sql_query = sqlparse.format(
            sql_query.strip(),
            reindent_aligned=True,
            strip_comments=True,
        )
        sql_query = BigQuerySQLParser._escape_table_or_view_name_at_create_statement(
            sql_query
        )
        sql_query = BigQuerySQLParser._escape_object_name_after_keyword_from(sql_query)
        return sql_query

    @staticmethod
    def _parse_bigquery_comment_sign(sql_query: str) -> str:
        """Rewrite ``#`` line comments as standard ``--`` comments.

        NOTE(review): this is a plain regex, not a tokenizer, so a ``#`` inside
        a string literal is rewritten as well.
        """
        return re.sub(r"#(.*)", r"-- \1", sql_query)

    @staticmethod
    def _escape_keyword_from_as_field_name(sql_query: str) -> str:
        """Backtick-quote ``<qualifier>.from`` so ``from`` is read as a field name.

        A word boundary is required after ``from`` so that longer identifiers
        which merely start with it (e.g. ``t.from_date``) are left alone instead
        of being split into ``<backtick>t.from<backtick>_date``.
        """
        return re.sub(r"(\w*\.from\b)", r"`\1`", sql_query, flags=re.IGNORECASE)

    @staticmethod
    def _escape_cte_name_after_keyword_with(sql_query: str) -> str:
        """
        Escape the cte name that directly follows the keyword WITH in case it is
        one of reserved words.

        A word boundary is required before ``with`` so identifiers that merely
        end in "with" (e.g. ``bandwith``) do not have the next token quoted.
        Names already starting with a backtick are not matched.
        """
        return re.sub(
            r"\b(with\s)([^`\s()]+)", r"\1`\2`", sql_query, flags=re.IGNORECASE
        )

    @staticmethod
    def _escape_table_or_view_name_at_create_statement(sql_query: str) -> str:
        """
        Reason: in case table name contains hyphens which breaks sqllineage later on

        Backtick-quotes the name in ``CREATE ... TABLE|VIEW <name> AS``. Without
        re.DOTALL, ``.`` does not cross newlines, so the whole pattern must sit
        on one line of the (already re-indented) query.
        """
        return re.sub(
            r"(create.*\s)(table\s|view\s)([^`\s()]+)(?=\sas)",
            r"\1\2`\3`",
            sql_query,
            flags=re.IGNORECASE,
        )

    @staticmethod
    def _escape_object_name_after_keyword_from(sql_query: str) -> str:
        """
        Reason: in case table name contains hyphens which breaks sqllineage later on

        Backtick-quotes the token that follows the keyword FROM.

        - The negative lookbehinds skip FROM inside datetime ``EXTRACT(<part> FROM ...)``
          calls and after a ``.`` (a qualified field literally named "from").
        - The word boundary before ``from`` keeps identifiers that merely end in
          "from" (e.g. ``valid_from AS x``) from having the next token quoted,
          which would otherwise produce invalid SQL.
        - Targets already starting with a backtick or parenthesis are not matched.
        """
        return re.sub(
            r"(?<!day\s)(?<!(date|time|hour|week|year)\s)(?<!month\s)(?<!(second|minute)\s)(?<!quarter\s)(?<!\.)\b(from\s)([^`\s()]+)",
            r"\3`\4`",
            sql_query,
            flags=re.IGNORECASE,
        )

    def get_tables(self) -> List[str]:
        """Return the tables found by the wrapped parser."""
        return self.parser.get_tables()

    def get_columns(self) -> List[str]:
        """Return the columns found by the wrapped parser."""
        return self.parser.get_columns()
def __init__(self, sql_query: str) -> None:
    """Rewrite *sql_query* and wrap the result in a ``SqlLineageSQLParser``.

    The base-class constructor receives the raw query. The rewritten form
    returned by ``self.parse_sql_query`` is stored on ``self._parsed_sql_query``
    and is the input to ``self.parser``.
    """
    # NOTE(review): this looks like a detached copy of a class __init__; if it
    # truly sits at module level, the zero-argument super() call fails at
    # runtime. Confirm where it belongs.
    super().__init__(sql_query)
    rewritten_query = self.parse_sql_query(sql_query)
    self._parsed_sql_query = rewritten_query
    self.parser = SqlLineageSQLParser(rewritten_query)