def extract_on_complete(self, task_instance) -> TaskMetadata:
    """Build fixed task metadata: one schema-backed input dataset and one
    plain table output dataset, named after this extractor's job."""
    # The fake input table has two text columns, field1 and field2.
    input_columns = [
        DbColumn(
            name=field,
            type='text',
            description='',
            ordinal_position=position
        )
        for position, field in enumerate(['field1', 'field2'], start=1)
    ]
    input_schema = DbTableSchema(
        schema_name='schema',
        table_name=DbTableName('extract_on_complete_input1'),
        columns=input_columns
    )
    input_dataset = Dataset.from_table_schema(self.source, input_schema)
    output_dataset = Dataset.from_table(
        self.source, "extract_on_complete_output1")
    return TaskMetadata(
        name=get_job_name(task=self.operator),
        inputs=[input_dataset.to_openlineage_dataset()],
        outputs=[output_dataset.to_openlineage_dataset()],
    )
def _get_tables(
        tokens,
        idx,
        default_schema: Optional[str] = None) -> Tuple[int, List[DbTableName]]:
    """Collect the table name(s) following a table-introducing SQL keyword.

    ``tokens`` is a sqlparse statement token list and ``idx`` the position of
    the keyword (FROM/JOIN/...); returns the new token index plus the parsed
    table names. ``default_schema``, when given, is prepended to any name
    that does not already carry an explicit schema.
    """
    # Extract table identified by preceding SQL keyword at '_is_in_table'
    def parse_ident(ident: Identifier) -> str:
        # Extract table name from possible schema.table naming
        token_list = ident.flatten()
        table_name = next(token_list).value
        try:
            # Determine if the table contains the schema
            # separated by a dot (format: 'schema.table')
            dot = next(token_list)
            if dot.match(Punctuation, '.'):
                table_name += dot.value
                table_name += next(token_list).value
                # And again, to match bigquery's 'database.schema.table'
                try:
                    dot = next(token_list)
                    if dot.match(Punctuation, '.'):
                        table_name += dot.value
                        table_name += next(token_list).value
                except StopIteration:
                    # Do not insert database name if it's not specified
                    pass
            elif default_schema:
                # Next token exists but is not a dot: the name is unqualified.
                table_name = f'{default_schema}.{table_name}'
        except StopIteration:
            # Bare single-token identifier: qualify it with the default
            # schema when one was supplied.
            if default_schema:
                table_name = f'{default_schema}.{table_name}'
        # Strip backtick quoting (e.g. BigQuery-style `project.dataset.table`).
        table_name = table_name.replace('`', '')
        return table_name

    idx, token = tokens.token_next(idx=idx)
    tables = []
    if isinstance(token, IdentifierList):
        # Handle "comma separated joins" as opposed to explicit JOIN keyword
        gidx = 0
        tables.append(
            parse_ident(token.token_first(skip_ws=True, skip_cm=True)))
        gidx, punc = token.token_next(gidx, skip_ws=True, skip_cm=True)
        # Keep consuming "<ident> , <ident> , ..." pairs until the commas stop.
        while punc and punc.value == ',':
            gidx, name = token.token_next(gidx, skip_ws=True, skip_cm=True)
            tables.append(parse_ident(name))
            gidx, punc = token.token_next(gidx)
    else:
        # Single table reference.
        tables.append(parse_ident(token))
    return idx, [DbTableName(table) for table in tables]
def table_schema(source):
    """Return the canonical 'public.discounts' fixture as
    ``(source, columns, schema)``."""
    # (name, type) pairs in ordinal order; positions are 1-based.
    column_specs = [
        ('id', 'int4'),
        ('amount_off', 'int4'),
        ('customer_email', 'varchar'),
        ('starts_on', 'timestamp'),
        ('ends_on', 'timestamp'),
    ]
    columns = [
        DbColumn(name=col_name, type=col_type, ordinal_position=position)
        for position, (col_name, col_type) in enumerate(column_specs, start=1)
    ]
    schema = DbTableSchema(
        schema_name='public',
        table_name=DbTableName('discounts'),
        columns=columns,
    )
    return source, columns, schema
def _get_table_schemas(
        self, table_names: List[DbTableName]) -> List[DbTableSchema]:
    """Query the database's information schema for the columns of
    *table_names* and group them into one DbTableSchema per table.

    Returns an empty list (without opening a connection) when no table
    names are provided.
    """
    # Avoid querying postgres by returning an empty array
    # if no table names have been provided.
    if not table_names:
        return []
    # Keeps track of the schema by table.
    schemas_by_table = {}
    hook = self._get_hook()
    with closing(hook.get_conn()) as conn:
        with closing(conn.cursor()) as cursor:
            # NOTE(review): table names are interpolated into the SQL text
            # rather than bound as parameters; they originate from parsed
            # DbTableName objects, but verify they cannot carry quotes.
            table_names_as_str = ",".join(
                f"'{name.name}'" for name in table_names)
            cursor.execute(
                self._information_schema_query(table_names_as_str))
            for row in cursor.fetchall():
                table_schema_name: str = row[_TABLE_SCHEMA]
                table_name: DbTableName = DbTableName(row[_TABLE_NAME])
                table_column: DbColumn = DbColumn(
                    name=row[_COLUMN_NAME],
                    type=row[_UDT_NAME],
                    ordinal_position=row[_ORDINAL_POSITION])
                # Attempt to get table schema
                table_key: str = f"{table_schema_name}.{table_name}"
                table_schema: Optional[
                    DbTableSchema] = schemas_by_table.get(table_key)
                if table_schema:
                    # Add column to existing table schema.
                    table_schema.columns.append(table_column)
                else:
                    # Create new table schema with column.
                    schemas_by_table[table_key] = DbTableSchema(
                        schema_name=table_schema_name,
                        table_name=table_name,
                        columns=[table_column])
    return list(schemas_by_table.values())
def _get_table(self, table: str) -> Optional[DbTableSchema]:
    """Fetch *table* from BigQuery and convert its metadata to a
    DbTableSchema.

    Returns None when the API response carries no properties or no
    schema fields.
    """
    bq_table = self.client.get_table(table)
    # Raw API representation of the table; kept in its own name instead
    # of clobbering the `table` parameter as before.
    properties = bq_table._properties
    if not properties:
        return None
    fields = get_from_nullable_chain(properties, ['schema', 'fields'])
    if not fields:
        return None
    # NOTE(review): ordinal_position is 0-based here, unlike the 1-based
    # positions used by the SQL-database extractors — preserved for
    # backward compatibility; confirm before unifying.
    columns = [
        DbColumn(
            name=field.get('name'),
            type=field.get('type'),
            description=field.get('description'),
            ordinal_position=position)
        for position, field in enumerate(fields)
    ]
    table_ref = properties.get('tableReference')
    return DbTableSchema(
        schema_name=table_ref.get('projectId')
        + '.'
        + table_ref.get('datasetId'),
        table_name=DbTableName(table_ref.get('tableId')),
        columns=columns)
airflow_1_path="airflow.operators.postgres_operator.PostgresOperator", airflow_2_path="airflow.providers.postgres.operators.postgres.PostgresOperator" ) PostgresHook = safe_import_airflow( airflow_1_path="airflow.hooks.postgres_hook.PostgresHook", airflow_2_path="airflow.providers.postgres.hooks.postgres.PostgresHook" ) CONN_ID = 'food_delivery_db' CONN_URI = 'postgres://*****:*****@localhost:5432/food_delivery' CONN_URI_WITHOUT_USERPASS = '******' DB_NAME = 'food_delivery' DB_SCHEMA_NAME = 'public' DB_TABLE_NAME = DbTableName('discounts') DB_TABLE_COLUMNS = [ DbColumn( name='id', type='int4', ordinal_position=1 ), DbColumn( name='amount_off', type='int4', ordinal_position=2 ), DbColumn( name='customer_email', type='varchar', ordinal_position=3
def test_tpcds_cte_query():
    """A TPC-DS style query with a CTE: the CTE alias ('year_total') must
    not be reported as an input table, while the physical relations it
    reads from must be."""
    sql_meta = SqlParser.parse("""
WITH year_total AS (
    SELECT c_customer_id customer_id,
           c_first_name customer_first_name,
           c_last_name customer_last_name,
           c_preferred_cust_flag customer_preferred_cust_flag,
           c_birth_country customer_birth_country,
           c_login customer_login,
           c_email_address customer_email_address,
           d_year dyear,
           Sum(((ss_ext_list_price - ss_ext_wholesale_cost
                 - ss_ext_discount_amt) + ss_ext_sales_price) / 2) year_total,
           's' sale_type
    FROM src.customer, store_sales, date_dim
    WHERE c_customer_sk = ss_customer_sk
      AND ss_sold_date_sk = d_date_sk
    GROUP BY c_customer_id, c_first_name, c_last_name,
             c_preferred_cust_flag, c_birth_country, c_login,
             c_email_address, d_year)
SELECT t_s_secyear.customer_id,
       t_s_secyear.customer_first_name,
       t_s_secyear.customer_last_name,
       t_s_secyear.customer_preferred_cust_flag
FROM year_total t_s_firstyear, year_total t_s_secyear,
     year_total t_c_firstyear, year_total t_c_secyear,
     year_total t_w_firstyear, year_total t_w_secyear
WHERE t_s_secyear.customer_id = t_s_firstyear.customer_id
  AND t_s_firstyear.customer_id = t_c_secyear.customer_id
  AND t_s_firstyear.customer_id = t_c_firstyear.customer_id
  AND t_s_firstyear.customer_id = t_w_firstyear.customer_id
  AND t_s_firstyear.customer_id = t_w_secyear.customer_id
  AND t_s_firstyear.sale_type = 's'
  AND t_c_firstyear.sale_type = 'c'
  AND t_w_firstyear.sale_type = 'w'
  AND t_s_secyear.sale_type = 's'
  AND t_c_secyear.sale_type = 'c'
  AND t_w_secyear.sale_type = 'w'
  AND t_s_firstyear.dyear = 2001
  AND t_s_secyear.dyear = 2001 + 1
  AND t_c_firstyear.dyear = 2001
  AND t_c_secyear.dyear = 2001 + 1
  AND t_w_firstyear.dyear = 2001
  AND t_w_secyear.dyear = 2001 + 1
  AND t_s_firstyear.year_total > 0
  AND t_c_firstyear.year_total > 0
  AND t_w_firstyear.year_total > 0
  AND CASE WHEN t_c_firstyear.year_total > 0
           THEN t_c_secyear.year_total / t_c_firstyear.year_total
           ELSE NULL END
      > CASE WHEN t_s_firstyear.year_total > 0
             THEN t_s_secyear.year_total / t_s_firstyear.year_total
             ELSE NULL END
  AND CASE WHEN t_c_firstyear.year_total > 0
           THEN t_c_secyear.year_total / t_c_firstyear.year_total
           ELSE NULL END
      > CASE WHEN t_w_firstyear.year_total > 0
             THEN t_w_secyear.year_total / t_w_firstyear.year_total
             ELSE NULL END
ORDER BY t_s_secyear.customer_id,
         t_s_secyear.customer_first_name,
         t_s_secyear.customer_last_name,
         t_s_secyear.customer_preferred_cust_flag
LIMIT 100;
    """)
    # Only the physical tables referenced inside the CTE body survive;
    # 'year_total' itself is resolved away as a CTE name.
    assert set(sql_meta.in_tables) == {
        DbTableName("src.customer"),
        DbTableName("store_sales"),
        DbTableName("date_dim")
    }
    # A pure SELECT produces no output tables.
    assert len(sql_meta.out_tables) == 0
def test_eq_table_name():
    """Equality and qualified_name semantics of DbTableName."""
    unqualified = DbTableName('discounts')
    qualified = DbTableName('public.discounts')
    # An unqualified name never equals a schema-qualified one.
    assert unqualified != qualified
    # qualified_name is only populated when a schema part was supplied.
    assert unqualified.qualified_name is None
    assert qualified.qualified_name == 'public.discounts'