コード例 #1
0
 def extract_on_complete(self, task_instance) -> TaskMetadata:
     """Build the metadata reported for this task once it has completed.

     The result always describes one input dataset, built from a table
     schema named ``extract_on_complete_input1`` (schema name ``schema``,
     two ``text`` columns), and one output dataset built from the table
     name ``extract_on_complete_output1``. Both are bound to
     ``self.source``. ``task_instance`` is accepted but not read.

     Returns:
         A ``TaskMetadata`` named after ``self.operator`` via
         ``get_job_name`` with the input and output lists described above.
     """
     input_field_names = ('field1', 'field2')
     input_schema = DbTableSchema(
         schema_name='schema',
         table_name=DbTableName('extract_on_complete_input1'),
         # Ordinal positions are 1-based and follow the tuple order.
         columns=[
             DbColumn(
                 name=field_name,
                 type='text',
                 description='',
                 ordinal_position=position
             )
             for position, field_name in enumerate(input_field_names, start=1)
         ]
     )

     input_dataset = Dataset.from_table_schema(self.source, input_schema)
     inputs = [input_dataset.to_openlineage_dataset()]

     output_dataset = Dataset.from_table(
         self.source, "extract_on_complete_output1")
     outputs = [output_dataset.to_openlineage_dataset()]

     return TaskMetadata(
         name=get_job_name(task=self.operator),
         inputs=inputs,
         outputs=outputs,
     )
コード例 #2
0
def _get_tables(
        tokens,
        idx,
        default_schema: Optional[str] = None) -> Tuple[int, List[DbTableName]]:
    """Collect the table name(s) named by the token that follows ``idx``.

    The token after ``idx`` is either a single identifier or an
    ``IdentifierList`` (comma separated tables, i.e. an implicit join).

    Args:
        tokens: Token list that offers ``token_next``; the caller is expected
            to have positioned ``idx`` on the SQL keyword preceding the
            table reference.
        idx: Index of that keyword inside ``tokens``.
        default_schema: Schema prefixed to names that carry no schema part.
            Nothing is prefixed when it is falsy.

    Returns:
        A tuple of the index of the consumed token and the list of
        ``DbTableName`` objects found, in source order.
    """
    def parse_ident(ident: Identifier) -> str:
        # Extract table name from possible schema.table naming
        token_list = ident.flatten()
        table_name = next(token_list).value
        try:
            # Determine if the table contains the schema
            # separated by a dot (format: 'schema.table')
            dot = next(token_list)
            if dot.match(Punctuation, '.'):
                table_name += dot.value
                table_name += next(token_list).value

                # And again, to match bigquery's 'database.schema.table'
                try:
                    dot = next(token_list)
                    if dot.match(Punctuation, '.'):
                        table_name += dot.value
                        table_name += next(token_list).value
                except StopIteration:
                    # Do not insert database name if it's not specified
                    pass
            elif default_schema:
                # A token follows the name but it is not a dot (for example
                # an alias), so the name itself is unqualified.
                table_name = f'{default_schema}.{table_name}'
        except StopIteration:
            # The identifier ran out of tokens before a full 'schema.table'
            # pair was read.
            if default_schema:
                table_name = f'{default_schema}.{table_name}'

        # Backtick quoting is not part of the name.
        table_name = table_name.replace('`', '')
        return table_name

    idx, token = tokens.token_next(idx=idx)
    tables = []
    if isinstance(token, IdentifierList):
        # Handle "comma separated joins" as opposed to explicit JOIN keyword
        gidx = 0
        tables.append(
            parse_ident(token.token_first(skip_ws=True, skip_cm=True)))
        gidx, punc = token.token_next(gidx, skip_ws=True, skip_cm=True)
        while punc and punc.value == ',':
            gidx, name = token.token_next(gidx, skip_ws=True, skip_cm=True)
            if name is None:
                # A comma with nothing after it: there is no further table.
                break
            tables.append(parse_ident(name))
            # Skip comments here as well, exactly like the lookups above;
            # otherwise a comment placed after a table name is returned
            # instead of the comma and the remaining tables are dropped.
            gidx, punc = token.token_next(gidx, skip_ws=True, skip_cm=True)
    else:
        tables.append(parse_ident(token))

    return idx, [DbTableName(table) for table in tables]
コード例 #3
0
def table_schema(source):
    """Return ``source`` with the columns and schema of ``public.discounts``.

    The five columns are created in declaration order and numbered from 1.

    Returns:
        A ``(source, columns, schema)`` tuple where ``schema`` is a
        ``DbTableSchema`` wrapping the very same ``columns`` list.
    """
    table_name = DbTableName('discounts')

    # (column name, column type) pairs, in table order.
    column_specs = (
        ('id', 'int4'),
        ('amount_off', 'int4'),
        ('customer_email', 'varchar'),
        ('starts_on', 'timestamp'),
        ('ends_on', 'timestamp'),
    )
    columns = [
        DbColumn(name=column_name, type=column_type, ordinal_position=position)
        for position, (column_name, column_type)
        in enumerate(column_specs, start=1)
    ]

    schema = DbTableSchema(
        schema_name='public',
        table_name=table_name,
        columns=columns,
    )
    return source, columns, schema
コード例 #4
0
    def _get_table_schemas(self,
                           table_names: [DbTableName]) -> [DbTableSchema]:
        # Avoid querying postgres by returning an empty array
        # if no table names have been provided.
        if not table_names:
            return []

        # Keeps tack of the schema by table.
        schemas_by_table = {}

        hook = self._get_hook()
        with closing(hook.get_conn()) as conn:
            with closing(conn.cursor()) as cursor:
                table_names_as_str = ",".join(
                    map(lambda name: f"'{name.name}'", table_names))
                cursor.execute(
                    self._information_schema_query(table_names_as_str))
                for row in cursor.fetchall():
                    table_schema_name: str = row[_TABLE_SCHEMA]
                    table_name: DbTableName = DbTableName(row[_TABLE_NAME])
                    table_column: DbColumn = DbColumn(
                        name=row[_COLUMN_NAME],
                        type=row[_UDT_NAME],
                        ordinal_position=row[_ORDINAL_POSITION])

                    # Attempt to get table schema
                    table_key: str = f"{table_schema_name}.{table_name}"
                    table_schema: Optional[
                        DbTableSchema] = schemas_by_table.get(table_key)

                    if table_schema:
                        # Add column to existing table schema.
                        schemas_by_table[table_key].columns.append(
                            table_column)
                    else:
                        # Create new table schema with column.
                        schemas_by_table[table_key] = DbTableSchema(
                            schema_name=table_schema_name,
                            table_name=table_name,
                            columns=[table_column])

        return list(schemas_by_table.values())
コード例 #5
0
    def _get_table(self, table: str) -> Optional[DbTableSchema]:
        bq_table = self.client.get_table(table)
        if not bq_table._properties:
            return
        table = bq_table._properties

        fields = get_from_nullable_chain(table, ['schema', 'fields'])
        if not fields:
            return

        columns = [
            DbColumn(name=fields[i].get('name'),
                     type=fields[i].get('type'),
                     description=fields[i].get('description'),
                     ordinal_position=i) for i in range(len(fields))
        ]

        return DbTableSchema(
            schema_name=table.get('tableReference').get('projectId') + '.' +
            table.get('tableReference').get('datasetId'),
            table_name=DbTableName(table.get('tableReference').get('tableId')),
            columns=columns)
コード例 #6
0
    airflow_1_path="airflow.operators.postgres_operator.PostgresOperator",
    airflow_2_path="airflow.providers.postgres.operators.postgres.PostgresOperator"
)

# Obtain the PostgresHook class through safe_import_airflow, handing it both
# the Airflow 1 and the Airflow 2 import paths (presumably it picks whichever
# matches the installed Airflow -- confirm in safe_import_airflow).
PostgresHook = safe_import_airflow(
    airflow_1_path="airflow.hooks.postgres_hook.PostgresHook",
    airflow_2_path="airflow.providers.postgres.hooks.postgres.PostgresHook"
)

# Connection id and URI of a local Postgres database named food_delivery.
CONN_ID = 'food_delivery_db'
CONN_URI = 'postgres://*****:*****@localhost:5432/food_delivery'
# NOTE(review): the value is masked in this copy; going by the name it is
# presumably CONN_URI without the user:password part -- confirm.
CONN_URI_WITHOUT_USERPASS = '******'

# Database, schema and table names.
DB_NAME = 'food_delivery'
DB_SCHEMA_NAME = 'public'
DB_TABLE_NAME = DbTableName('discounts')
DB_TABLE_COLUMNS = [
    DbColumn(
        name='id',
        type='int4',
        ordinal_position=1
    ),
    DbColumn(
        name='amount_off',
        type='int4',
        ordinal_position=2
    ),
    DbColumn(
        name='customer_email',
        type='varchar',
        ordinal_position=3
コード例 #7
0
def test_tpcds_cte_query():
    """Check table extraction for a query built around a CTE.

    The statement defines the CTE ``year_total`` from a comma separated
    ``FROM`` list and then selects from six aliases of that CTE. Only the
    three tables read inside the CTE may be reported as inputs; the CTE
    name itself must not appear, and a plain ``SELECT`` has no outputs.
    """
    sql_meta = SqlParser.parse("""
WITH year_total AS
    (SELECT c_customer_id customer_id,
            c_first_name customer_first_name,
            c_last_name customer_last_name,
            c_preferred_cust_flag customer_preferred_cust_flag,
            c_birth_country customer_birth_country,
            c_login customer_login,
            c_email_address customer_email_address,
            d_year dyear,
            Sum(((ss_ext_list_price - ss_ext_wholesale_cost - ss_ext_discount_amt)
                + ss_ext_sales_price) / 2) year_total,
            's' sale_type
     FROM src.customer,
          store_sales,
          date_dim
     WHERE c_customer_sk = ss_customer_sk
         AND ss_sold_date_sk = d_date_sk GROUP  BY c_customer_id,
                                                   c_first_name,
                                                   c_last_name,
                                                   c_preferred_cust_flag,
                                                   c_birth_country,
                                                   c_login,
                                                   c_email_address,
                                                   d_year)
SELECT t_s_secyear.customer_id,
       t_s_secyear.customer_first_name,
       t_s_secyear.customer_last_name,
       t_s_secyear.customer_preferred_cust_flag
FROM year_total t_s_firstyear,
     year_total t_s_secyear,
     year_total t_c_firstyear,
     year_total t_c_secyear,
     year_total t_w_firstyear,
     year_total t_w_secyear
WHERE t_s_secyear.customer_id = t_s_firstyear.customer_id
    AND t_s_firstyear.customer_id = t_c_secyear.customer_id
    AND t_s_firstyear.customer_id = t_c_firstyear.customer_id
    AND t_s_firstyear.customer_id = t_w_firstyear.customer_id
    AND t_s_firstyear.customer_id = t_w_secyear.customer_id
    AND t_s_firstyear.sale_type = 's'
    AND t_c_firstyear.sale_type = 'c'
    AND t_w_firstyear.sale_type = 'w'
    AND t_s_secyear.sale_type = 's'
    AND t_c_secyear.sale_type = 'c'
    AND t_w_secyear.sale_type = 'w'
    AND t_s_firstyear.dyear = 2001
    AND t_s_secyear.dyear = 2001 + 1
    AND t_c_firstyear.dyear = 2001
    AND t_c_secyear.dyear = 2001 + 1
    AND t_w_firstyear.dyear = 2001
    AND t_w_secyear.dyear = 2001 + 1
    AND t_s_firstyear.year_total > 0
    AND t_c_firstyear.year_total > 0
    AND t_w_firstyear.year_total > 0
    AND CASE WHEN
            t_c_firstyear.year_total > 0 THEN t_c_secyear.year_total / t_c_firstyear.year_total
            ELSE NULL
        END > CASE WHEN
            t_s_firstyear.year_total > 0 THEN t_s_secyear.year_total / t_s_firstyear.year_total
            ELSE NULL
        END
    AND CASE WHEN
            t_c_firstyear.year_total > 0 THEN t_c_secyear.year_total / t_c_firstyear.year_total
            ELSE NULL
        END > CASE WHEN
            t_w_firstyear.year_total > 0 THEN t_w_secyear.year_total / t_w_firstyear.year_total
            ELSE NULL
        END
    ORDER  BY t_s_secyear.customer_id,
              t_s_secyear.customer_first_name,
              t_s_secyear.customer_last_name,
              t_s_secyear.customer_preferred_cust_flag
LIMIT 100;
""")
    # Compared as a set: the order in which tables are reported is not
    # part of the expectation. The CTE name ``year_total`` is absent.
    assert set(sql_meta.in_tables) == {
        DbTableName("src.customer"),
        DbTableName("store_sales"),
        DbTableName("date_dim")
    }
    # The statement only reads, so nothing may be reported as written.
    assert len(sql_meta.out_tables) == 0
コード例 #8
0
def test_eq_table_name():
    """A bare table name differs from its schema-qualified counterpart.

    Only a name that carries a schema part exposes a ``qualified_name``;
    for a bare name the attribute is ``None``.
    """
    bare = DbTableName('discounts')
    qualified = DbTableName('public.discounts')

    assert bare != qualified
    assert bare.qualified_name is None
    assert qualified.qualified_name == 'public.discounts'