Example #1
def update_datahub_contact_consent(
    target_db: str,
    table: sqlalchemy.Table,
    **kwargs,
):
    """
    Updates Contacts temp table with email marketing consent data from Consent dataset.
    """
    table = get_temp_table(table, kwargs['ts_nodash'])
    update_consent_query = f"""
        UPDATE {table.schema}.{table.name} AS contacts_temp
        SET email_marketing_consent = consent.email_marketing_consent
        FROM {ConsentPipeline.fq_table_name()} AS consent
        WHERE lower(contacts_temp.email) = lower(consent.email)
    """
    engine = sqlalchemy.create_engine(
        'postgresql+psycopg2://',
        creator=PostgresHook(postgres_conn_id=target_db).get_conn,
        echo=config.DEBUG,
    )
    with engine.begin() as conn:
        conn.execute(sqlalchemy.text(update_consent_query))

    logger.info(
        'Updated Contacts temp table with email consent from Consent dataset')
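
For context, a callable like this is typically registered on a DAG via PythonOperator so that Airflow passes the run context (including ts_nodash) into **kwargs. A minimal sketch, assuming Airflow 2.x where the context is injected into the callable's **kwargs automatically (on Airflow 1.x you would also need provide_context=True); the dag object, connection id and contacts_table are hypothetical stand-ins, not taken from the examples above.

from airflow.operators.python import PythonOperator

update_consent = PythonOperator(
    task_id='update-datahub-contact-consent',
    python_callable=update_datahub_contact_consent,
    # op_kwargs supplies the explicit arguments; the Airflow context
    # (ts_nodash, task_instance, ...) lands in **kwargs at run time.
    op_kwargs={'target_db': 'datasets_db', 'table': contacts_table},  # hypothetical values
    dag=dag,
)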
Example #2
def scrape_load_and_check_data(
    target_db: str,
    table_config: TableConfig,
    pipeline_instance: "_PandasPipelineWithPollingSupport",
    **kwargs,
):
    create_temp_tables(target_db, *table_config.tables, **kwargs)

    temp_table = get_temp_table(table_config.table, suffix=kwargs['ts_nodash'])

    data_frames = pipeline_instance.__class__.data_getter()

    parsed_uri = urlparse(os.environ['AIRFLOW_CONN_DATASETS_DB'])
    host, port, dbname, user, password = (
        parsed_uri.hostname,
        parsed_uri.port or 5432,
        parsed_uri.path.strip('/'),
        parsed_uri.username,
        parsed_uri.password,
    )
    # Psycopg3 is still under active development, but crucially it supports
    # streaming data to Postgres efficiently via `cursor.copy` and the COPY protocol.
    with psycopg3.connect(
            f'host={host} port={port} dbname={dbname} user={user} password={password}'
    ) as connection:
        with connection.cursor() as cursor:
            logger.info("Starting streaming copy to DB")

            records_num = 0
            df_num = 0
            with cursor.copy(
                    f'COPY "{temp_table.schema}"."{temp_table.name}" FROM STDIN'
            ) as copy:
                for data_frame in data_frames:
                    df_num += 1
                    df_len = len(data_frame)
                    records_num += df_len

                    logger.info(
                        "Copying data frame #%s (records %s - %s)",
                        df_num,
                        records_num - df_len,
                        records_num,
                    )
                    copy.write(
                        data_frame.to_csv(
                            index=False,
                            header=False,
                            sep='\t',
                            na_rep=r'\N',
                            columns=[
                                data_column for data_column, sa_column in
                                table_config.columns
                            ],
                        ))
                    del data_frame

            logger.info("Copy complete.")
Example #3
def check_table_data(target_db: str,
                     *tables: sa.Table,
                     allow_null_columns: bool = False,
                     **kwargs):
    """Verify basic constraints on temp table data."""

    engine = sa.create_engine(
        'postgresql+psycopg2://',
        creator=PostgresHook(postgres_conn_id=target_db).get_conn,
    )

    with engine.begin() as conn:
        for table in tables:
            temp_table = get_temp_table(table, kwargs["ts_nodash"])
            _check_table(engine, conn, temp_table, table, allow_null_columns)
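
The body of _check_table is not shown in these examples. Purely as a guess from the parameter names, here is a minimal sketch of the kind of checks it might run (SQLAlchemy 1.x-style selects; the empty-table and all-NULL-column rules are assumptions, not the project's actual logic):

import sqlalchemy as sa


def _check_table(engine, conn, temp_table: sa.Table, table: sa.Table, allow_null_columns: bool):
    # Hypothetical check: the temp table must contain at least one row.
    row_count = conn.execute(
        sa.select([sa.func.count()]).select_from(temp_table)
    ).scalar()
    if row_count == 0:
        raise ValueError(f"{temp_table.name} contains no rows")

    if not allow_null_columns:
        # Hypothetical check: no column may be entirely NULL.
        for column in temp_table.columns:
            non_null = conn.execute(
                sa.select([sa.func.count(column)]).select_from(temp_table)
            ).scalar()
            if non_null == 0:
                raise ValueError(f"Column {column.name} of {temp_table.name} is entirely NULL")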
Example #4
def drop_swap_tables(target_db: str, *tables, **kwargs):
    """Delete temporary swap dataset DB tables.

    Given a dataset table `table`, deletes any related swap tables
    containing the previous version of the dataset.

    """
    engine = sa.create_engine(
        'postgresql+psycopg2://',
        creator=PostgresHook(postgres_conn_id=target_db).get_conn,
    )
    with engine.begin() as conn:
        conn.execute("SET statement_timeout = 600000")
        for table in tables:
            swap_table = get_temp_table(table, kwargs["ts_nodash"] + "_swap")
            logger.info("Removing %s", swap_table.name)
            swap_table.drop(conn, checkfirst=True)
Example #5
def drop_temp_tables(target_db: str, *tables, **kwargs):
    """Delete temporary dataset DB tables.

    Given a dataset table `table`, deletes any related temporary
    tables created during the DAG run.

    """
    engine = sa.create_engine(
        'postgresql+psycopg2://',
        creator=PostgresHook(postgres_conn_id=target_db).get_conn,
    )
    with engine.begin() as conn:
        conn.execute("SET statement_timeout = 600000")
        for table in tables:
            temp_table = get_temp_table(table, kwargs["ts_nodash"])
            logger.info("Removing %s", temp_table.name)
            temp_table.drop(conn, checkfirst=True)
Example #6
def update_table(target_db: str, target_table: sa.Table, update_query: str, **kwargs):
    """
    Run a query to update an existing table from a temporary table.
    """
    engine = sa.create_engine(
        'postgresql+psycopg2://',
        creator=PostgresHook(postgres_conn_id=target_db).get_conn,
    )
    with engine.begin() as conn:
        from_table = get_temp_table(target_table, kwargs["ts_nodash"])
        logger.info(f'Updating {target_table.name} from {from_table.name}')
        conn.execute(
            update_query.format(
                schema=engine.dialect.identifier_preparer.quote(target_table.schema),
                target_table=engine.dialect.identifier_preparer.quote(
                    target_table.name
                ),
                from_table=engine.dialect.identifier_preparer.quote(from_table.name),
            )
        )
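
For reference, update_query is expected to be a plain SQL string using the {schema}, {target_table} and {from_table} placeholders that the call above fills in. A hypothetical example (the column names are illustrative only):

update_query = """
    UPDATE {schema}.{target_table} AS target
    SET last_updated = source.last_updated
    FROM {schema}.{from_table} AS source
    WHERE target.id = source.id
"""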
Example #7
def create_temp_tables(target_db: str, *tables: sa.Table, **kwargs):
    """
    Create a temporary table for the current DAG run for each of the given dataset
    tables.

    Table names are unique for each DAG run: they use the target table name as a
    prefix and the current DAG execution timestamp as a suffix.
    """

    engine = sa.create_engine(
        'postgresql+psycopg2://',
        creator=PostgresHook(postgres_conn_id=target_db).get_conn,
    )

    with engine.begin() as conn:
        conn.execute("SET statement_timeout = 600000")
        for table in tables:
            table = get_temp_table(table, kwargs["ts_nodash"])
            logger.info("Creating schema %s if not exists", table.schema)
            conn.execute(f"CREATE SCHEMA IF NOT EXISTS {table.schema}")
            logger.info("Creating %s", table.name)
            table.create(conn, checkfirst=True)
Example #8
def swap_dataset_tables(
    target_db: str,
    *tables: sa.Table,
    use_utc_now_as_source_modified: bool = False,
    **kwargs,
):
    """Rename temporary tables to replace current dataset one.

    Given a one or more dataset tables `tables` this finds the temporary table created
    for the current DAG run and replaces existing dataset one with it.

    If a dataset table didn't exist the new table gets renamed, otherwise
    the existing dataset table is renamed to a temporary "swap" name first.

    This requires an exclusive lock for the dataset table (similar to TRUNCATE)
    but doesn't need to copy any data around (reducing the amount of time dataset
    is unavailable) and will update the table schema at the same time (since it
    will apply the new schema temporary table was created with).

    """
    engine = sa.create_engine(
        'postgresql+psycopg2://',
        creator=PostgresHook(postgres_conn_id=target_db).get_conn,
    )
    for table in tables:
        temp_table = get_temp_table(table, kwargs["ts_nodash"])

        logger.info("Moving %s to %s", temp_table.name, table.name)
        with engine.begin() as conn:
            conn.execute("SET statement_timeout = 600000")
            grantees = [
                grantee[0] for grantee in conn.execute("""
                SELECT grantee
                FROM information_schema.role_table_grants
                WHERE table_name='{table_name}'
                AND privilege_type = 'SELECT'
                AND grantor != grantee
                """.format(table_name=engine.dialect.identifier_preparer.quote(
                    table.name))).fetchall()
            ]

            conn.execute("""
                SELECT dataflow.save_and_drop_dependencies('{schema}', '{target_temp_table}');
                ALTER TABLE IF EXISTS {schema}.{target_temp_table} RENAME TO {swap_table_name};
                ALTER TABLE {schema}.{temp_table} RENAME TO {target_temp_table};
                SELECT dataflow.restore_dependencies('{schema}', '{target_temp_table}');
                """.format(
                schema=engine.dialect.identifier_preparer.quote(table.schema),
                target_temp_table=engine.dialect.identifier_preparer.quote(
                    table.name),
                swap_table_name=engine.dialect.identifier_preparer.quote(
                    temp_table.name + "_swap"),
                temp_table=engine.dialect.identifier_preparer.quote(
                    temp_table.name),
            ))
            for grantee in grantees + config.DEFAULT_DATABASE_GRANTEES:
                conn.execute(
                    'GRANT SELECT ON {schema}.{table_name} TO {grantee}'.
                    format(
                        schema=engine.dialect.identifier_preparer.quote(
                            table.schema),
                        table_name=engine.dialect.identifier_preparer.quote(
                            table.name),
                        grantee=grantee,
                    ))

            new_modified_utc = kwargs['task_instance'].xcom_pull(
                key='source-modified-date-utc')
            if new_modified_utc is None and use_utc_now_as_source_modified:
                try:
                    new_modified_utc = get_task_instance(
                        kwargs['dag'].safe_dag_id,
                        'run-fetch',
                        kwargs['execution_date'],
                    ).end_date
                except TaskNotFound:
                    new_modified_utc = datetime.datetime.utcnow()

            conn.execute(
                """
                INSERT INTO dataflow.metadata
                (table_schema, table_name, source_data_modified_utc, dataflow_swapped_tables_utc)
                VALUES (%s, %s, %s, %s)
                """,
                (
                    table.schema,
                    table.name,
                    new_modified_utc,
                    datetime.datetime.utcnow(),
                ),
            )
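
To make the rename sequence concrete, this is roughly what the main block above renders to for a hypothetical table public.contacts and a DAG run with ts_nodash 20230101T000000 (identifier quoting omitted because these names don't require it):

SELECT dataflow.save_and_drop_dependencies('public', 'contacts');
ALTER TABLE IF EXISTS public.contacts RENAME TO contacts_20230101T000000_swap;
ALTER TABLE public.contacts_20230101T000000 RENAME TO contacts;
SELECT dataflow.restore_dependencies('public', 'contacts');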
Example #9
def insert_data_into_db(
        target_db: str,
        table: Optional[sa.Table] = None,
        field_mapping: Optional[SingleTableFieldMapping] = None,
        table_config: Optional[TableConfig] = None,
        contexts: Tuple = tuple(),
        **kwargs,
):
    """Insert fetched response data into temporary DB tables.

    Goes through the stored response contents and loads individual
    records into the temporary DB table.

    DB columns are populated according to the field mapping, which is a
    list of `(response_field, column)` tuples, where the field can be
    either a string or a tuple of keys/indexes forming a path to a
    nested value.

    """
    if table_config:
        if table or field_mapping:
            raise RuntimeError(
                "You must exclusively provide either (table_config) or (table && field_mapping), not bits of both."
            )

        table_config.configure(**kwargs)
        s3 = S3Data(table_config.table_name, kwargs["ts_nodash"])

    elif table is not None and field_mapping is not None:
        warnings.warn(
            ("`table` and `field_mapping` parameters are deprecated. "
             "This pipeline should be migrated to use `table_config`/`TableConfig`."
             ),
            DeprecationWarning,
        )

        s3 = S3Data(table.name, kwargs["ts_nodash"])
        temp_table = get_temp_table(table, kwargs["ts_nodash"])

    else:
        raise RuntimeError(
            f"No complete table/field mapping configuration provided: {table}, {field_mapping}"
        )

    engine = sa.create_engine(
        'postgresql+psycopg2://',
        creator=PostgresHook(postgres_conn_id=target_db).get_conn,
    )

    count = 0
    for page, records in s3.iter_keys():
        logger.info('Processing page %s', page)
        count += 1

        with engine.begin() as conn:
            if table_config:
                for record in records:
                    for transform in table_config.transforms:
                        record = transform(record, table_config, contexts)

                    conn.execute(
                        table_config.temp_table.insert(),
                        **_get_data_to_insert(table_config.columns, record),
                    )

                    if table_config.related_table_configs:
                        _insert_related_records(conn, table_config,
                                                contexts + (record, ))

            elif table is not None and field_mapping:
                for record in records:
                    conn.execute(
                        temp_table.insert(),  # pylint: disable=E1120
                        **_get_data_to_insert(field_mapping, record),
                    )

        logger.info('Page %s ingested successfully', page)

    if count == 0:
        raise MissingDataError(
            "There are no pages of records in S3 to insert.")
Example #10
def create_temp_table_indexes(target_db: str, table_config: TableConfig,
                              **kwargs):
    """
    Applies TableConfig.indexes to a pre-existing, pre-populated table. For large
    datasets this is far more efficient than having the indexes in place while the
    data is being inserted.

    This task processes the contents of `TableConfig.indexes`.

    It only applies indexes from the root table in a TableConfig right now.
    """

    if table_config.indexes is None:
        return

    engine = sa.create_engine(
        'postgresql+psycopg2://',
        creator=PostgresHook(postgres_conn_id=target_db).get_conn,
    )

    def _get_sa_index_and_metadata(
            table_config: TableConfig,
            index: LateIndex) -> Tuple[sa.Index, str, List[str]]:
        """
        Returns a sqlalchemy.Index object for the given LateIndex object. The name looks
        arbitrary but is deterministic, reproducible from the inputs:

        1) Pipeline timestamp
        2) Table name
        3) All of the columns in the index

        The index name created is relatively user-unfriendly, but it avoids the case where the inputs exceed the max
        length of Postgres identifiers and get truncated, which has caused problems historically.
        """
        cols = [index.columns] if isinstance(index.columns,
                                             str) else index.columns
        index_name_parts = [table_config.schema, table_config.table_name
                            ] + cols
        index_hash = md5(
            '\n'.join(index_name_parts).encode('utf-8')).hexdigest()[:32]
        index_suffix = 'key' if index.unique else 'idx'
        index_type = 'unique index' if index.unique else 'index'
        index_name = f'{kwargs["ts_nodash"]}_{index_hash}_{index_suffix}'
        sa_index = Index(index_name, *cols, unique=index.unique)
        return sa_index, index_type, cols

    with engine.begin() as conn:
        conn.execute(
            "SET statement_timeout = 1800000"
        )  # 30-minute timeout on index creation - should be plenty generous

        table = get_temp_table(table_config.table, kwargs["ts_nodash"])
        indexes = []
        for index in table_config.indexes:
            sa_index, index_type, cols = _get_sa_index_and_metadata(
                table_config, index)
            table.append_constraint(sa_index)
            indexes.append((sa_index, index_type, cols))

        for sa_index, index_type, cols in indexes:
            logger.info(
                "Creating %s %s on %s (%s)",
                index_type,
                sa_index.name,
                table.fullname,
                ', '.join(cols),
            )
            sa_index.create(conn)
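
As an illustration of the naming scheme, the snippet below reproduces the index name that _get_sa_index_and_metadata would build for a hypothetical run (schema public, table contacts, indexed column email, ts_nodash 20230101T000000):

from hashlib import md5

index_hash = md5('\n'.join(['public', 'contacts', 'email']).encode('utf-8')).hexdigest()[:32]
# e.g. '20230101T000000_<32-char md5 hex digest>_idx' (or '_key' for a unique index);
# the result stays safely under Postgres's 63-character identifier limit.
index_name = f'20230101T000000_{index_hash}_idx'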
Example #11
def test_get_temp_table(table):
    assert get_temp_table(table, "temp").name == "test_table_temp"
    assert table.name == "test_table"