Пример #1
0
def test_parse_simple_select_into():
    """SELECT ... INTO reads the FROM table and writes the INTO table."""
    query = '''
        SELECT *
          INTO table0
          FROM table1;
        '''
    meta = SqlParser.parse(query)

    assert meta.in_tables == [DbTableName('table1')]
    assert meta.out_tables == [DbTableName('table0')]
Пример #2
0
def test_parse_simple_insert_into_select():
    """INSERT INTO ... SELECT: the SELECT source is input, the target is output."""
    query = '''
        INSERT INTO table1 (col0, col1, col2)
        SELECT col0, col1, col2
          FROM table0;
        '''
    meta = SqlParser.parse(query)

    assert meta.in_tables == [DbTableName('table0')]
    assert meta.out_tables == [DbTableName('table1')]
Пример #3
0
def test_parse_simple_right_outer_join():
    """Both sides of a RIGHT OUTER JOIN are reported as input tables."""
    # The join condition is qualified with the real table names; the previous
    # version referenced aliases (t1/t2) that the query never defined.
    sql_meta = SqlParser.parse(
        '''
        SELECT col0, col1, col2
          FROM table0
          RIGHT OUTER JOIN table1
            ON table0.col0 = table1.col0;
        '''
    )

    assert sql_meta.in_tables == [DbTableName('table0'), DbTableName('table1')]
    assert sql_meta.out_tables == []
Пример #4
0
def test_parse_simple_left_join():
    """Both sides of a LEFT JOIN are reported as input tables."""
    # The join condition is qualified with the real table names; the previous
    # version referenced aliases (t1/t2) that the query never defined.
    sql_meta = SqlParser.parse(
        '''
        SELECT col0, col1, col2
          FROM table0
          LEFT JOIN table1
            ON table0.col0 = table1.col0
        '''
    )

    assert sql_meta.in_tables == [DbTableName('table0'), DbTableName('table1')]
    assert sql_meta.out_tables == []
Пример #5
0
def test_parse_simple_inner_join():
    """Both sides of an INNER JOIN are reported as input tables (any order)."""
    # The join condition is qualified with the real table names; the previous
    # version referenced aliases (t1/t2) that the query never defined.
    sql_meta = SqlParser.parse(
        '''
        SELECT col0, col1, col2
          FROM table0
         INNER JOIN table1
            ON table0.col0 = table1.col0
        '''
    )

    assert set(sql_meta.in_tables) == {DbTableName('table0'), DbTableName('table1')}
    assert sql_meta.out_tables == []
Пример #6
0
 def extract_on_complete(self, task_instance) -> StepMetadata:
     """Return fixed StepMetadata with one two-column input and one output.

     ``task_instance`` is accepted but not read by this method.

     :return: StepMetadata named after ``self.operator`` whose input is the
         ``schema.extract_on_complete_input1`` table (text columns ``field1``
         and ``field2``) and whose output is ``extract_on_complete_output1``.
     """
     first_column = DbColumn(
         name='field1',
         type='text',
         description='',
         ordinal_position=1
     )
     second_column = DbColumn(
         name='field2',
         type='text',
         description='',
         ordinal_position=2
     )
     input_schema = DbTableSchema(
         schema_name='schema',
         table_name=DbTableName('extract_on_complete_input1'),
         columns=[first_column, second_column]
     )

     input_datasets = [Dataset.from_table_schema(self.source, input_schema)]
     output_datasets = [
         Dataset.from_table(self.source, "extract_on_complete_output1")
     ]
     job_context = {"extract_on_complete": "extract_on_complete"}

     return StepMetadata(
         name=get_job_name(task=self.operator),
         inputs=input_datasets,
         outputs=output_datasets,
         context=job_context
     )
Пример #7
0
def test_parse_simple_cte():
    """A CTE's name is not an input table; the tables it reads from are."""
    # The outer WHERE filters on ``cnt`` (the alias the CTE defines); the
    # previous version used ``count``, a column that does not exist.
    sql_meta = SqlParser.parse(
        '''
        WITH sum_trans as (
            SELECT user_id, COUNT(*) as cnt, SUM(amount) as balance
            FROM transactions
            WHERE created_date > '2020-01-01'
            GROUP BY user_id
        )
        INSERT INTO potential_fraud (user_id, cnt, balance)
        SELECT user_id, cnt, balance
          FROM sum_trans
          WHERE cnt > 1000 OR balance > 100000;
        '''
    )
    assert sql_meta.in_tables == [DbTableName('transactions')]
    assert sql_meta.out_tables == [DbTableName('potential_fraud')]
Пример #8
0
    def _get_table_schemas(
            self, table_names: [DbTableName]
    ) -> [DbTableSchema]:
        """Look up column metadata for ``table_names`` in information_schema.

        Connects with the operator's ``postgres_conn_id`` / ``database`` and
        groups the returned columns into one ``DbTableSchema`` per
        ``<table_schema>.<table_name>`` pair.

        :param table_names: tables to look up. Rows are matched on the bare
            ``information_schema.columns.table_name`` column.
        :return: one DbTableSchema per (schema, table) found, or ``[]`` when
            ``table_names`` is empty (no query is issued in that case).
        """
        # Avoid querying postgres by returning an empty array
        # if no table names have been provided.
        if not table_names:
            return []

        # Keeps track of the schema by table, keyed by "schema.table".
        schemas_by_table = {}

        # NOTE(review): names are compared against the unqualified
        # ``table_name`` column; a schema-qualified name such as
        # 'public.discounts' would not match -- confirm what callers pass.
        lookup_names = tuple(str(name) for name in table_names)

        hook = PostgresHook(
            postgres_conn_id=self.operator.postgres_conn_id,
            schema=self.operator.database
        )
        with closing(hook.get_conn()) as conn:
            with closing(conn.cursor()) as cursor:
                # Bind the names as a query parameter rather than splicing
                # them into the SQL text, so a name containing a quote cannot
                # break or alter the statement. psycopg2 renders a Python
                # tuple as a parenthesised value list suitable for IN.
                cursor.execute(
                    """
                    SELECT table_schema,
                           table_name,
                           column_name,
                           ordinal_position,
                           udt_name
                      FROM information_schema.columns
                     WHERE table_name IN %s;
                    """,
                    (lookup_names,)
                )
                for row in cursor.fetchall():
                    table_schema_name: str = row[_TABLE_SCHEMA]
                    table_name: DbTableName = DbTableName(row[_TABLE_NAME])
                    table_column: DbColumn = DbColumn(
                        name=row[_COLUMN_NAME],
                        type=row[_UDT_NAME],
                        ordinal_position=row[_ORDINAL_POSITION]
                    )

                    # Attempt to get table schema
                    table_key: str = f"{table_schema_name}.{table_name}"
                    table_schema: Optional[DbTableSchema] = \
                        schemas_by_table.get(table_key)

                    if table_schema is not None:
                        # Add column to existing table schema.
                        table_schema.columns.append(table_column)
                    else:
                        # Create new table schema with column.
                        schemas_by_table[table_key] = DbTableSchema(
                            schema_name=table_schema_name,
                            table_name=table_name,
                            columns=[table_column]
                        )

        return list(schemas_by_table.values())
Пример #9
0
def test_ignores_default_schema_when_non_default_schema():
    """An explicit ``schema.table`` prefix takes precedence over the default."""
    default_schema = 'public'
    meta = SqlParser.parse(
        '''
        SELECT col0, col1, col2
          FROM transactions.table0
        ''',
        default_schema
    )
    assert meta.in_tables == [DbTableName('transactions.table0')]
Пример #10
0
def test_parse_default_schema():
    """Unqualified table names receive the supplied default schema prefix."""
    default_schema = 'public'
    meta = SqlParser.parse(
        '''
        SELECT col0, col1, col2
          FROM table0
        ''',
        default_schema
    )
    assert meta.in_tables == [DbTableName('public.table0')]
Пример #11
0
def test_parse_simple_select_with_table_schema_prefix_and_extra_whitespace():
    """Surrounding whitespace does not leak into a ``schema.table`` name."""
    query = '''
        SELECT *
          FROM    schema0.table0   ;
        '''
    meta = SqlParser.parse(query)

    assert meta.in_tables == [DbTableName('schema0.table0')]
    assert meta.out_tables == []
Пример #12
0
def test_parse_simple_insert_into():
    """INSERT ... VALUES has an output table and no input tables."""
    query = '''
        INSERT INTO table0 (col0, col1, col2)
        VALUES (val0, val1, val2);
        '''
    meta = SqlParser.parse(query)

    assert meta.in_tables == []
    assert meta.out_tables == [DbTableName('table0')]
Пример #13
0
def test_parse_recursive_cte():
    """A recursive CTE's own name is excluded; only real tables are inputs."""
    query = '''
        WITH RECURSIVE subordinates AS
            (SELECT employee_id,
                manager_id,
                full_name
            FROM employees
            WHERE employee_id = 2
            UNION SELECT e.employee_id,
                e.manager_id,
                e.full_name
            FROM employees e
            INNER JOIN subordinates s ON s.employee_id = e.manager_id)
        INSERT INTO sub_employees (employee_id, manager_id, full_name)
        SELECT employee_id, manager_id, full_name FROM subordinates;
        '''
    meta = SqlParser.parse(query)
    assert meta.in_tables == [DbTableName('employees')]
    assert meta.out_tables == [DbTableName('sub_employees')]
Пример #14
0
def test_parse_simple_select():
    """A plain SELECT yields its FROM table as the only input."""
    query = '''
        SELECT *
          FROM table0;
        '''
    meta = SqlParser.parse(query)

    log.debug("sqlparser.parse() successful.")
    assert meta.in_tables == [DbTableName('table0')]
    assert meta.out_tables == []
Пример #15
0
    def _get_table(self, table: str, client: bigquery.Client) -> DbTableSchema:
        """Fetch a BigQuery table via ``client`` and convert it to DbTableSchema.

        :param table: table reference passed to ``client.get_table``.
        :param client: BigQuery client used for the lookup.
        :return: the converted schema, whose ``schema_name`` is
            ``<projectId>.<datasetId>``. Returns ``None`` (despite the
            annotation) when the table has no properties or its schema lists
            no fields.
        """
        properties = client.get_table(table)._properties
        if not properties:
            return None

        schema = properties.get('schema')
        if not schema or not schema.get('fields'):
            return None

        # Ordinal positions here are 0-based: the index of the field in the
        # schema's field list.
        columns = [
            DbColumn(name=field.get('name'),
                     type=field.get('type'),
                     description=field.get('description'),
                     ordinal_position=position)
            for position, field in enumerate(schema.get('fields'))
        ]

        reference = properties.get('tableReference')
        table_name = DbTableName(reference.get('tableId'))
        self.log.info(table_name)
        return DbTableSchema(
            schema_name=reference.get('projectId') + '.' +
            reference.get('datasetId'),
            table_name=table_name,
            columns=columns)
Пример #16
0
def test_parser_integration():
    """INSERT ... SELECT: the default schema is applied to the source table."""
    query = """
        INSERT INTO popular_orders_day_of_week (order_day_of_week, order_placed_on,orders_placed)
            SELECT EXTRACT(ISODOW FROM order_placed_on) AS order_day_of_week,
                   order_placed_on,
                   COUNT(*) AS orders_placed
              FROM top_delivery_times
             GROUP BY order_placed_on;
        """
    meta = SqlParser.parse(query, "public")
    assert meta.in_tables == [DbTableName('public.top_delivery_times')]
Пример #17
0
def _get_table(tokens, idx):
    """Read the single table name that follows position ``idx`` in ``tokens``.

    A ``schema.table`` identifier is re-joined into one dotted string. Tokens
    after the table part (for example an alias) are ignored.

    :param tokens: sqlparse token list being scanned.
    :param idx: index of the keyword that precedes the table reference.
    :return: the index of the consumed token and the table name.
    """
    next_idx, token = tokens.token_next(idx=idx)
    parts = token.flatten()
    name = next(parts).value

    # Determine if the table contains the schema
    # separated by a dot (format: 'schema.table').
    separator = next(parts, None)
    if separator is not None and separator.match(Punctuation, '.'):
        name += separator.value
        remainder = next(parts, None)
        if remainder is not None:
            name += remainder.value

    return next_idx, DbTableName(name)
Пример #18
0
def _get_tables(
        tokens,
        idx,
        default_schema: Optional[str] = None) -> Tuple[int, List[DbTableName]]:
    """Read the table(s) named right after position ``idx`` in ``tokens``.

    Handles a single identifier as well as a comma-separated identifier list
    (``FROM a, b, c`` style joins, as opposed to an explicit JOIN keyword).

    :param tokens: sqlparse token list being scanned.
    :param idx: index of the keyword that precedes the table reference.
    :param default_schema: when truthy, prefixed as ``<schema>.`` to every
        name that is not written as ``schema.table``.
    :return: the index of the consumed token and the parsed table names.
    """

    def qualify(ident: Identifier) -> str:
        # Re-join a possibly dotted ``schema.table`` identifier; names without
        # a complete dotted suffix fall back to ``default_schema``.
        parts = ident.flatten()
        name = next(parts).value
        separator = next(parts, None)
        if separator is not None and separator.match(Punctuation, '.'):
            name += separator.value
            remainder = next(parts, None)
            if remainder is not None:
                return name + remainder.value
            # A dangling dot with nothing after it is treated as unqualified.
        if default_schema:
            name = f'{default_schema}.{name}'
        return name

    idx, token = tokens.token_next(idx=idx)
    if not isinstance(token, IdentifierList):
        return idx, [DbTableName(qualify(token))]

    # Walk the identifier list: after each identifier, continue only while
    # the next token is a comma.
    names = [qualify(token.token_first(skip_ws=True, skip_cm=True))]
    position, delimiter = token.token_next(0, skip_ws=True, skip_cm=True)
    while delimiter and delimiter.value == ',':
        position, ident = token.token_next(position, skip_ws=True, skip_cm=True)
        names.append(qualify(ident))
        position, delimiter = token.token_next(position)

    return idx, [DbTableName(name) for name in names]
Пример #19
0
def test_tpcds_cte_query():
    """TPC-DS style query: the CTE name is not reported, and there are no outputs.

    Only the tables read inside the ``year_total`` CTE are expected as inputs,
    including the schema-qualified ``src.customer``.
    """
    query = """
WITH year_total AS
    (SELECT c_customer_id customer_id,
            c_first_name customer_first_name,
            c_last_name customer_last_name,
            c_preferred_cust_flag customer_preferred_cust_flag,
            c_birth_country customer_birth_country,
            c_login customer_login,
            c_email_address customer_email_address,
            d_year dyear,
            Sum(((ss_ext_list_price - ss_ext_wholesale_cost - ss_ext_discount_amt)
                + ss_ext_sales_price) / 2) year_total,
            's' sale_type
     FROM src.customer,
          store_sales,
          date_dim
     WHERE c_customer_sk = ss_customer_sk
         AND ss_sold_date_sk = d_date_sk GROUP  BY c_customer_id,
                                                   c_first_name,
                                                   c_last_name,
                                                   c_preferred_cust_flag,
                                                   c_birth_country,
                                                   c_login,
                                                   c_email_address,
                                                   d_year)
SELECT t_s_secyear.customer_id,
       t_s_secyear.customer_first_name,
       t_s_secyear.customer_last_name,
       t_s_secyear.customer_preferred_cust_flag
FROM year_total t_s_firstyear,
     year_total t_s_secyear,
     year_total t_c_firstyear,
     year_total t_c_secyear,
     year_total t_w_firstyear,
     year_total t_w_secyear
WHERE t_s_secyear.customer_id = t_s_firstyear.customer_id
    AND t_s_firstyear.customer_id = t_c_secyear.customer_id
    AND t_s_firstyear.customer_id = t_c_firstyear.customer_id
    AND t_s_firstyear.customer_id = t_w_firstyear.customer_id
    AND t_s_firstyear.customer_id = t_w_secyear.customer_id
    AND t_s_firstyear.sale_type = 's'
    AND t_c_firstyear.sale_type = 'c'
    AND t_w_firstyear.sale_type = 'w'
    AND t_s_secyear.sale_type = 's'
    AND t_c_secyear.sale_type = 'c'
    AND t_w_secyear.sale_type = 'w'
    AND t_s_firstyear.dyear = 2001
    AND t_s_secyear.dyear = 2001 + 1
    AND t_c_firstyear.dyear = 2001
    AND t_c_secyear.dyear = 2001 + 1
    AND t_w_firstyear.dyear = 2001
    AND t_w_secyear.dyear = 2001 + 1
    AND t_s_firstyear.year_total > 0
    AND t_c_firstyear.year_total > 0
    AND t_w_firstyear.year_total > 0
    AND CASE WHEN
            t_c_firstyear.year_total > 0 THEN t_c_secyear.year_total / t_c_firstyear.year_total
            ELSE NULL
        END > CASE WHEN
            t_s_firstyear.year_total > 0 THEN t_s_secyear.year_total / t_s_firstyear.year_total
            ELSE NULL
        END
    AND CASE WHEN
            t_c_firstyear.year_total > 0 THEN t_c_secyear.year_total / t_c_firstyear.year_total
            ELSE NULL
        END > CASE WHEN
            t_w_firstyear.year_total > 0 THEN t_w_secyear.year_total / t_w_firstyear.year_total
            ELSE NULL
        END
    ORDER  BY t_s_secyear.customer_id,
              t_s_secyear.customer_first_name,
              t_s_secyear.customer_last_name,
              t_s_secyear.customer_preferred_cust_flag
LIMIT 100;
"""
    sql_meta = SqlParser.parse(query)

    expected_inputs = {
        DbTableName("src.customer"),
        DbTableName("store_sales"),
        DbTableName("date_dim")
    }
    assert set(sql_meta.in_tables) == expected_inputs
    assert len(sql_meta.out_tables) == 0
Пример #20
0
import mock

from airflow.hooks.postgres_hook import PostgresHook
from airflow.operators.postgres_operator import PostgresOperator
from airflow.utils.dates import days_ago

from marquez_airflow import DAG
from marquez_airflow.models import (DbTableName, DbTableSchema, DbColumn)
from marquez_airflow.extractors import Source, Dataset, DatasetType
from marquez_airflow.extractors.postgres_extractor import PostgresExtractor

# Connection id and URI for the food_delivery Postgres database.
CONN_ID = 'food_delivery_db'
CONN_URI = 'postgres://localhost:5432/food_delivery'

# Schema of the ``discounts`` table queried by SQL below; column ordinal
# positions are 1-based, in the order listed.
DB_SCHEMA_NAME = 'public'
DB_TABLE_NAME = DbTableName('discounts')
DB_TABLE_COLUMNS = [
    DbColumn(name=column_name, type=column_type, ordinal_position=position)
    for position, (column_name, column_type) in enumerate(
        [
            ('id', 'int4'),
            ('amount_off', 'int4'),
            ('customer_email', 'varchar'),
            ('starts_on', 'timestamp'),
            ('ends_on', 'timestamp'),
        ],
        start=1,
    )
]
DB_TABLE_SCHEMA = DbTableSchema(
    schema_name=DB_SCHEMA_NAME,
    table_name=DB_TABLE_NAME,
    columns=DB_TABLE_COLUMNS,
)
# Empty list of table schemas.
NO_DB_TABLE_SCHEMA = []

SQL = f"SELECT * FROM {DB_TABLE_NAME.name};"

DAG_ID = 'email_discounts'