def test_parse_simple_select_with_table_schema_prefix_and_extra_whitespace():
    sql_meta = SqlParser.parse(
        '''
        SELECT *
          FROM    schema0.table0   ;
        '''
    )

    assert sql_meta.in_tables == [DbTableName('schema0.table0')]
    assert sql_meta.out_tables == []
def test_parse_simple_insert_into():
    sql_meta = SqlParser.parse(
        '''
        INSERT INTO table0 (col0, col1, col2)
        VALUES (val0, val1, val2);
        '''
    )

    assert sql_meta.in_tables == []
    assert sql_meta.out_tables == [DbTableName('table0')]
def test_parse_simple_select():
    sql_meta = SqlParser.parse(
        '''
        SELECT *
          FROM table0;
        '''
    )
    log.debug("sqlparser.parse() successful.")

    assert sql_meta.in_tables == [DbTableName('table0')]
    assert sql_meta.out_tables == []
def test_parse_recursive_cte():
    sql_meta = SqlParser.parse(
        '''
        WITH RECURSIVE subordinates AS (
            SELECT employee_id, manager_id, full_name
              FROM employees
             WHERE employee_id = 2
             UNION
            SELECT e.employee_id, e.manager_id, e.full_name
              FROM employees e
             INNER JOIN subordinates s
                ON s.employee_id = e.manager_id
        )
        INSERT INTO sub_employees (employee_id, manager_id, full_name)
        SELECT employee_id, manager_id, full_name
          FROM subordinates;
        '''
    )

    assert sql_meta.in_tables == [DbTableName('employees')]
    assert sql_meta.out_tables == [DbTableName('sub_employees')]
def _get_table(
        self, table: str, client: bigquery.Client) -> Optional[DbTableSchema]:
    # Fetch table metadata from BigQuery; bail out early if the API
    # returned no properties or no schema fields.
    bq_table = client.get_table(table)
    if not bq_table._properties:
        return
    table = bq_table._properties

    if not table.get('schema') or not table.get('schema').get('fields'):
        return

    # Map each BigQuery field to a DbColumn, using the field's index
    # as its ordinal position.
    fields = table.get('schema').get('fields')
    columns = [
        DbColumn(
            name=fields[i].get('name'),
            type=fields[i].get('type'),
            description=fields[i].get('description'),
            ordinal_position=i
        ) for i in range(len(fields))
    ]
    self.log.info(DbTableName(table.get('tableReference').get('tableId')))

    return DbTableSchema(
        schema_name=table.get('tableReference').get('projectId') + '.' +
        table.get('tableReference').get('datasetId'),
        table_name=DbTableName(table.get('tableReference').get('tableId')),
        columns=columns
    )
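A minimal usage sketch for the method above. Note the assumptions: 'extractor' stands in for an instance of whatever class defines _get_table(), and the project, dataset, and table names are hypothetical; none of these appear in the source.

from google.cloud import bigquery

# Illustrative only: resolve a fully qualified table id to a DbTableSchema.
client = bigquery.Client(project='my-project')  # hypothetical project
schema = extractor._get_table('my-project.my_dataset.my_table', client)
if schema:
    print(schema.schema_name, [col.name for col in schema.columns])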
def test_parser_integration():
    sql_meta = SqlParser.parse(
        """
        INSERT INTO popular_orders_day_of_week (order_day_of_week, order_placed_on, orders_placed)
        SELECT EXTRACT(ISODOW FROM order_placed_on) AS order_day_of_week,
               order_placed_on,
               COUNT(*) AS orders_placed
          FROM top_delivery_times
         GROUP BY order_placed_on;
        """,
        "public"
    )

    assert sql_meta.in_tables == [DbTableName('public.top_delivery_times')]
def _get_table(tokens, idx):
    idx, token = tokens.token_next(idx=idx)
    token_list = token.flatten()
    table_name = next(token_list).value
    try:
        # Determine if the table contains the schema
        # separated by a dot (format: 'schema.table')
        dot = next(token_list)
        if dot.match(Punctuation, '.'):
            table_name += dot.value
            table_name += next(token_list).value
    except StopIteration:
        pass
    return idx, DbTableName(table_name)
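For context, a sketch of how this helper can be driven with sqlparse. The statement and the FROM-keyword lookup are illustrative choices, not taken from the source:

import sqlparse
from sqlparse.tokens import Keyword

# Tokenize a statement, locate the FROM keyword, then read the
# table reference that follows it using _get_table() above.
statement = sqlparse.parse("SELECT * FROM schema0.table0;")[0]
idx, _ = statement.token_next_by(m=(Keyword, 'FROM'))
idx, table = _get_table(statement, idx)
# expected: table == DbTableName('schema0.table0')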
def _get_table_schemas(
        self, table_names: List[DbTableName]) -> List[DbTableSchema]:
    # Avoid querying postgres by returning an empty array
    # if no table names have been provided.
    if not table_names:
        return []

    # Keeps track of the schema by table.
    schemas_by_table = {}

    hook = PostgresHook(
        postgres_conn_id=self.operator.postgres_conn_id,
        schema=self.operator.database)
    with closing(hook.get_conn()) as conn:
        with closing(conn.cursor()) as cursor:
            table_names_as_list = ",".join(
                map(lambda name: f"'{name}'", table_names))
            cursor.execute(f"""
                SELECT table_schema, table_name, column_name,
                       ordinal_position, udt_name
                  FROM information_schema.columns
                 WHERE table_name IN ({table_names_as_list});
            """)
            for row in cursor.fetchall():
                table_schema_name: str = row[_TABLE_SCHEMA]
                table_name: DbTableName = DbTableName(row[_TABLE_NAME])
                table_column: DbColumn = DbColumn(
                    name=row[_COLUMN_NAME],
                    type=row[_UDT_NAME],
                    ordinal_position=row[_ORDINAL_POSITION])

                # Attempt to get table schema
                table_key: str = f"{table_schema_name}.{table_name}"
                table_schema: Optional[DbTableSchema] = \
                    schemas_by_table.get(table_key)

                if table_schema:
                    # Add column to existing table schema.
                    schemas_by_table[table_key].columns.append(table_column)
                else:
                    # Create new table schema with column.
                    schemas_by_table[table_key] = DbTableSchema(
                        schema_name=table_schema_name,
                        table_name=table_name,
                        columns=[table_column])

    return list(schemas_by_table.values())
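The row indexes used above (_TABLE_SCHEMA, _TABLE_NAME, and so on) are module-level constants that do not appear in this excerpt. A plausible definition, inferred from the column ordering of the SELECT rather than taken from the source, would be:

# Indexes into rows returned by the information_schema.columns query;
# values follow the SELECT ordering above (inferred, not from the source).
_TABLE_SCHEMA = 0
_TABLE_NAME = 1
_COLUMN_NAME = 2
_ORDINAL_POSITION = 3
_UDT_NAME = 4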
def _get_tables(
        tokens,
        idx,
        default_schema: Optional[str] = None
) -> Tuple[int, List[DbTableName]]:
    # Extract tables identified by preceding SQL keyword at '_is_in_table'
    def parse_ident(ident: Identifier) -> str:
        # Extract table name from possible schema.table naming
        token_list = ident.flatten()
        table_name = next(token_list).value
        try:
            # Determine if the table contains the schema
            # separated by a dot (format: 'schema.table')
            dot = next(token_list)
            if dot.match(Punctuation, '.'):
                table_name += dot.value
                table_name += next(token_list).value
            elif default_schema:
                table_name = f'{default_schema}.{table_name}'
        except StopIteration:
            if default_schema:
                table_name = f'{default_schema}.{table_name}'
        return table_name

    idx, token = tokens.token_next(idx=idx)
    tables = []
    if isinstance(token, IdentifierList):
        # Handle "comma separated joins" as opposed to explicit JOIN keyword
        gidx = 0
        tables.append(
            parse_ident(token.token_first(skip_ws=True, skip_cm=True)))
        gidx, punc = token.token_next(gidx, skip_ws=True, skip_cm=True)
        while punc and punc.value == ',':
            gidx, name = token.token_next(gidx, skip_ws=True, skip_cm=True)
            tables.append(parse_ident(name))
            gidx, punc = token.token_next(gidx)
    else:
        tables.append(parse_ident(token))

    return idx, [DbTableName(table) for table in tables]
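A short sketch of the comma-separated-join path. The statement, the table names, and the 'public' default schema are only examples:

import sqlparse
from sqlparse.tokens import Keyword

# A comma-separated FROM list parses as an IdentifierList; each
# unqualified name is prefixed with the default schema.
statement = sqlparse.parse("SELECT * FROM discounts, orders;")[0]
idx, _ = statement.token_next_by(m=(Keyword, 'FROM'))
idx, tables = _get_tables(statement, idx, default_schema='public')
# expected: [DbTableName('public.discounts'), DbTableName('public.orders')]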
def test_tpcds_cte_query():
    sql_meta = SqlParser.parse("""
        WITH year_total AS
          (SELECT c_customer_id          customer_id,
                  c_first_name           customer_first_name,
                  c_last_name            customer_last_name,
                  c_preferred_cust_flag  customer_preferred_cust_flag,
                  c_birth_country        customer_birth_country,
                  c_login                customer_login,
                  c_email_address        customer_email_address,
                  d_year                 dyear,
                  Sum(((ss_ext_list_price - ss_ext_wholesale_cost
                        - ss_ext_discount_amt) + ss_ext_sales_price) / 2) year_total,
                  's'                    sale_type
             FROM src.customer,
                  store_sales,
                  date_dim
            WHERE c_customer_sk = ss_customer_sk
              AND ss_sold_date_sk = d_date_sk
            GROUP BY c_customer_id,
                     c_first_name,
                     c_last_name,
                     c_preferred_cust_flag,
                     c_birth_country,
                     c_login,
                     c_email_address,
                     d_year)
        SELECT t_s_secyear.customer_id,
               t_s_secyear.customer_first_name,
               t_s_secyear.customer_last_name,
               t_s_secyear.customer_preferred_cust_flag
          FROM year_total t_s_firstyear,
               year_total t_s_secyear,
               year_total t_c_firstyear,
               year_total t_c_secyear,
               year_total t_w_firstyear,
               year_total t_w_secyear
         WHERE t_s_secyear.customer_id = t_s_firstyear.customer_id
           AND t_s_firstyear.customer_id = t_c_secyear.customer_id
           AND t_s_firstyear.customer_id = t_c_firstyear.customer_id
           AND t_s_firstyear.customer_id = t_w_firstyear.customer_id
           AND t_s_firstyear.customer_id = t_w_secyear.customer_id
           AND t_s_firstyear.sale_type = 's'
           AND t_c_firstyear.sale_type = 'c'
           AND t_w_firstyear.sale_type = 'w'
           AND t_s_secyear.sale_type = 's'
           AND t_c_secyear.sale_type = 'c'
           AND t_w_secyear.sale_type = 'w'
           AND t_s_firstyear.dyear = 2001
           AND t_s_secyear.dyear = 2001 + 1
           AND t_c_firstyear.dyear = 2001
           AND t_c_secyear.dyear = 2001 + 1
           AND t_w_firstyear.dyear = 2001
           AND t_w_secyear.dyear = 2001 + 1
           AND t_s_firstyear.year_total > 0
           AND t_c_firstyear.year_total > 0
           AND t_w_firstyear.year_total > 0
           AND CASE
                 WHEN t_c_firstyear.year_total > 0 THEN
                   t_c_secyear.year_total / t_c_firstyear.year_total
                 ELSE NULL
               END > CASE
                       WHEN t_s_firstyear.year_total > 0 THEN
                         t_s_secyear.year_total / t_s_firstyear.year_total
                       ELSE NULL
                     END
           AND CASE
                 WHEN t_c_firstyear.year_total > 0 THEN
                   t_c_secyear.year_total / t_c_firstyear.year_total
                 ELSE NULL
               END > CASE
                       WHEN t_w_firstyear.year_total > 0 THEN
                         t_w_secyear.year_total / t_w_firstyear.year_total
                       ELSE NULL
                     END
         ORDER BY t_s_secyear.customer_id,
                  t_s_secyear.customer_first_name,
                  t_s_secyear.customer_last_name,
                  t_s_secyear.customer_preferred_cust_flag
         LIMIT 100;
    """)

    assert set(sql_meta.in_tables) == {
        DbTableName("src.customer"),
        DbTableName("store_sales"),
        DbTableName("date_dim")
    }
    assert len(sql_meta.out_tables) == 0
import mock

from airflow.hooks.postgres_hook import PostgresHook
from airflow.operators.postgres_operator import PostgresOperator
from airflow.utils.dates import days_ago

from marquez_airflow import DAG
from marquez_airflow.models import (
    DbTableName,
    DbTableSchema,
    DbColumn
)
from marquez_airflow.extractors import Source, Dataset, DatasetType
from marquez_airflow.extractors.postgres_extractor import PostgresExtractor

CONN_ID = 'food_delivery_db'
CONN_URI = 'postgres://localhost:5432/food_delivery'
DB_SCHEMA_NAME = 'public'
DB_TABLE_NAME = DbTableName('discounts')
DB_TABLE_COLUMNS = [
    DbColumn(name='id', type='int4', ordinal_position=1),
    DbColumn(name='amount_off', type='int4', ordinal_position=2),
    DbColumn(name='customer_email', type='varchar', ordinal_position=3),
    DbColumn(name='starts_on', type='timestamp', ordinal_position=4),
    DbColumn(name='ends_on', type='timestamp', ordinal_position=5)
]
DB_TABLE_SCHEMA = DbTableSchema(
    schema_name=DB_SCHEMA_NAME,
    table_name=DB_TABLE_NAME,
    columns=DB_TABLE_COLUMNS
)
NO_DB_TABLE_SCHEMA = []
SQL = f"SELECT * FROM {DB_TABLE_NAME.name};"

DAG_ID = 'email_discounts'
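A hedged sketch of a test these fixtures could feed, assuming PostgresExtractor wraps an operator and exposes the _get_table_schemas method shown earlier; the test body and the extract() call shape are illustrative, not taken from the source:

@mock.patch.object(PostgresExtractor, '_get_table_schemas')
def test_extract_with_mocked_schema_lookup(mock_get_table_schemas):
    # Return the canned schema instead of querying postgres.
    mock_get_table_schemas.return_value = [DB_TABLE_SCHEMA]

    dag = DAG(DAG_ID, schedule_interval='@daily',
              default_args={'start_date': days_ago(7)})
    task = PostgresOperator(
        task_id='select',
        postgres_conn_id=CONN_ID,
        sql=SQL,
        dag=dag)

    PostgresExtractor(task).extract()

    # Assumed call shape: the extractor looks up schemas for the
    # tables parsed out of SQL.
    mock_get_table_schemas.assert_called_once_with([DB_TABLE_NAME])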