Example #1
    def pull_all_from_table_or_query(
        config: RepoConfig,
        data_source: DataSource,
        join_key_columns: List[str],
        feature_name_columns: List[str],
        event_timestamp_column: str,
        start_date: datetime,
        end_date: datetime,
    ) -> RetrievalJob:
        assert isinstance(data_source, RedshiftSource)
        from_expression = data_source.get_table_query_string()

        field_string = ", ".join(join_key_columns + feature_name_columns +
                                 [event_timestamp_column])

        redshift_client = aws_utils.get_redshift_data_client(
            config.offline_store.region)
        s3_resource = aws_utils.get_s3_resource(config.offline_store.region)

        start_date = start_date.astimezone(tz=utc)
        end_date = end_date.astimezone(tz=utc)

        query = f"""
            SELECT {field_string}
            FROM {from_expression}
            WHERE {event_timestamp_column} BETWEEN TIMESTAMP '{start_date}' AND TIMESTAMP '{end_date}'
        """

        return RedshiftRetrievalJob(
            query=query,
            redshift_client=redshift_client,
            s3_resource=s3_resource,
            config=config,
            full_feature_names=False,
        )
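
The method above simply selects every row that falls inside the requested time window. As a rough illustration of the SQL it emits, here is a standalone sketch that builds the same query string for made-up column and table names (the names are placeholders, not taken from Feast):

from datetime import datetime, timezone

# Placeholder schema, for illustration only.
join_key_columns = ["driver_id"]
feature_name_columns = ["conv_rate", "avg_daily_trips"]
event_timestamp_column = "event_timestamp"
from_expression = "feast_driver_hourly_stats"

field_string = ", ".join(
    join_key_columns + feature_name_columns + [event_timestamp_column]
)

# The method normalizes both bounds to UTC before templating them into SQL.
start_date = datetime(2021, 6, 1, tzinfo=timezone.utc)
end_date = datetime(2021, 6, 15, tzinfo=timezone.utc)

query = f"""
    SELECT {field_string}
    FROM {from_expression}
    WHERE {event_timestamp_column} BETWEEN TIMESTAMP '{start_date}' AND TIMESTAMP '{end_date}'
"""
print(query)
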
Example #2
def bootstrap():
    # bootstrap() is called automatically by init_repo() during `feast init`.

    import pathlib
    from datetime import datetime, timedelta

    from feast.driver_test_data import create_driver_hourly_stats_df

    end_date = datetime.now().replace(microsecond=0, second=0, minute=0)
    start_date = end_date - timedelta(days=15)

    driver_entities = [1001, 1002, 1003, 1004, 1005]
    driver_df = create_driver_hourly_stats_df(driver_entities, start_date,
                                              end_date)

    aws_region = click.prompt("AWS Region (e.g. us-west-2)")
    cluster_id = click.prompt("Redshift Cluster ID")
    database = click.prompt("Redshift Database Name")
    user = click.prompt("Redshift User Name")
    s3_staging_location = click.prompt("Redshift S3 Staging Location (s3://*)")
    iam_role = click.prompt("Redshift IAM Role for S3 (arn:aws:iam::*:role/*)")

    if click.confirm(
            "Should I upload example data to Redshift (overwriting 'feast_driver_hourly_stats' table)?",
            default=True,
    ):
        client = aws_utils.get_redshift_data_client(aws_region)
        s3 = aws_utils.get_s3_resource(aws_region)

        aws_utils.execute_redshift_statement(
            client,
            cluster_id,
            database,
            user,
            "DROP TABLE IF EXISTS feast_driver_hourly_stats",
        )

        aws_utils.upload_df_to_redshift(
            client,
            cluster_id,
            database,
            user,
            s3,
            f"{s3_staging_location}/data/feast_driver_hourly_stats.parquet",
            iam_role,
            "feast_driver_hourly_stats",
            driver_df,
        )

    repo_path = pathlib.Path(__file__).parent.absolute()
    config_file = repo_path / "feature_store.yaml"

    replace_str_in_file(config_file, "%AWS_REGION%", aws_region)
    replace_str_in_file(config_file, "%REDSHIFT_CLUSTER_ID%", cluster_id)
    replace_str_in_file(config_file, "%REDSHIFT_DATABASE%", database)
    replace_str_in_file(config_file, "%REDSHIFT_USER%", user)
    replace_str_in_file(config_file, "%REDSHIFT_S3_STAGING_LOCATION%",
                        s3_staging_location)
    replace_str_in_file(config_file, "%REDSHIFT_IAM_ROLE%", iam_role)
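
The bootstrap flow relies on replace_str_in_file to substitute the %...% placeholders in the generated feature_store.yaml. The helper's implementation is not shown here, so the version below is only a minimal sketch of what it could look like:

from pathlib import Path

def replace_str_in_file(file_path, match_str, sub_str):
    # Read the template, substitute every occurrence of the placeholder,
    # and write the result back in place.
    path = Path(file_path)
    path.write_text(path.read_text().replace(match_str, sub_str))
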
Example #3
    def pull_latest_from_table_or_query(
        config: RepoConfig,
        data_source: DataSource,
        join_key_columns: List[str],
        feature_name_columns: List[str],
        event_timestamp_column: str,
        created_timestamp_column: Optional[str],
        start_date: datetime,
        end_date: datetime,
    ) -> RetrievalJob:
        assert isinstance(data_source, RedshiftSource)
        assert isinstance(config.offline_store, RedshiftOfflineStoreConfig)

        from_expression = data_source.get_table_query_string()

        partition_by_join_key_string = ", ".join(join_key_columns)
        if partition_by_join_key_string != "":
            partition_by_join_key_string = (
                "PARTITION BY " + partition_by_join_key_string
            )
        timestamp_columns = [event_timestamp_column]
        if created_timestamp_column:
            timestamp_columns.append(created_timestamp_column)
        timestamp_desc_string = " DESC, ".join(timestamp_columns) + " DESC"
        field_string = ", ".join(
            join_key_columns + feature_name_columns + timestamp_columns
        )

        redshift_client = aws_utils.get_redshift_data_client(
            config.offline_store.region
        )
        s3_resource = aws_utils.get_s3_resource(config.offline_store.region)

        query = f"""
            SELECT
                {field_string}
                {f", {repr(DUMMY_ENTITY_VAL)} AS {DUMMY_ENTITY_ID}" if not join_key_columns else ""}
            FROM (
                SELECT {field_string},
                ROW_NUMBER() OVER({partition_by_join_key_string} ORDER BY {timestamp_desc_string}) AS _feast_row
                FROM {from_expression}
                WHERE {event_timestamp_column} BETWEEN TIMESTAMP '{start_date}' AND TIMESTAMP '{end_date}'
            )
            WHERE _feast_row = 1
            """
        # When materializing a single feature view, we don't need full feature names. On demand transforms aren't materialized
        return RedshiftRetrievalJob(
            query=query,
            redshift_client=redshift_client,
            s3_resource=s3_resource,
            config=config,
            full_feature_names=False,
            on_demand_feature_views=None,
        )
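
The inner query ranks rows with ROW_NUMBER() OVER (PARTITION BY <join keys> ORDER BY <timestamps> DESC) and keeps only _feast_row = 1, i.e. the newest row per entity within the window. A rough pandas analogue of that dedup step (illustrative only, not part of Feast) looks like this:

import pandas as pd

# Toy data with two rows for driver 1001; only the newer one should survive.
df = pd.DataFrame(
    {
        "driver_id": [1001, 1001, 1002],
        "conv_rate": [0.1, 0.4, 0.9],
        "event_timestamp": pd.to_datetime(["2021-06-01", "2021-06-10", "2021-06-05"]),
    }
)

# Sort newest-first, then keep the first row per join key; this matches
# filtering the window-function result on _feast_row = 1.
latest = (
    df.sort_values("event_timestamp", ascending=False)
    .groupby("driver_id", as_index=False)
    .head(1)
)
print(latest)
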
Example #4
    def get_table_column_names_and_types(
            self, config: RepoConfig) -> Iterable[Tuple[str, str]]:
        """
        Returns a mapping of column names to types for this Redshift source.

        Args:
            config: A RepoConfig describing the feature repo
        """
        from botocore.exceptions import ClientError

        from feast.infra.offline_stores.redshift import RedshiftOfflineStoreConfig
        from feast.infra.utils import aws_utils

        assert isinstance(config.offline_store, RedshiftOfflineStoreConfig)

        client = aws_utils.get_redshift_data_client(
            config.offline_store.region)
        if self.table is not None:
            try:
                table = client.describe_table(
                    ClusterIdentifier=config.offline_store.cluster_id,
                    Database=(self.database if self.database else
                              config.offline_store.database),
                    DbUser=config.offline_store.user,
                    Table=self.table,
                    Schema=self.schema,
                )
            except ClientError as e:
                if e.response["Error"]["Code"] == "ValidationException":
                    raise RedshiftCredentialsError() from e
                raise

            # The API returns valid JSON with empty column list when the table doesn't exist
            if len(table["ColumnList"]) == 0:
                raise DataSourceNotFoundException(self.table)

            columns = table["ColumnList"]
        else:
            statement_id = aws_utils.execute_redshift_statement(
                client,
                config.offline_store.cluster_id,
                self.database
                if self.database else config.offline_store.database,
                config.offline_store.user,
                f"SELECT * FROM ({self.query}) LIMIT 1",
            )
            columns = aws_utils.get_redshift_statement_result(
                client, statement_id)["ColumnMetadata"]

        return [(column["name"], column["typeName"].upper())
                for column in columns]
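
The table branch above is backed by the Redshift Data API's DescribeTable operation. For orientation, here is a hedged sketch of the equivalent raw boto3 call, assuming credentials are already configured; the cluster, database, user, and table names are placeholders:

import boto3

client = boto3.client("redshift-data", region_name="us-west-2")
response = client.describe_table(
    ClusterIdentifier="my-cluster",         # placeholder
    Database="feast",                       # placeholder
    DbUser="admin",                         # placeholder
    Table="feast_driver_hourly_stats",      # placeholder
)

# Each entry in ColumnList carries the column name and its Redshift type name.
for column in response["ColumnList"]:
    print(column["name"], column["typeName"].upper())
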
Example #5
    def __init__(self, project_name: str, *args, **kwargs):
        super().__init__(project_name)
        self.client = aws_utils.get_redshift_data_client("us-west-2")
        self.s3 = aws_utils.get_s3_resource("us-west-2")

        self.offline_store_config = RedshiftOfflineStoreConfig(
            cluster_id="feast-integration-tests",
            region="us-west-2",
            user="******",
            database="feast",
            s3_staging_location=
            "s3://feast-integration-tests/redshift/tests/ingestion",
            iam_role="arn:aws:iam::402087665549:role/redshift_s3_access_role",
        )
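
Since RedshiftOfflineStoreConfig accepts these fields as keyword arguments, the same settings can also be kept in a plain dict and unpacked, which some test setups find convenient; the user value below is a placeholder because the original masks it:

offline_store_settings = {
    "cluster_id": "feast-integration-tests",
    "region": "us-west-2",
    "user": "admin",  # placeholder; the original masks the user name
    "database": "feast",
    "s3_staging_location": "s3://feast-integration-tests/redshift/tests/ingestion",
    "iam_role": "arn:aws:iam::402087665549:role/redshift_s3_access_role",
}
offline_store_config = RedshiftOfflineStoreConfig(**offline_store_settings)
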
Example #6
    def pull_latest_from_table_or_query(
        config: RepoConfig,
        data_source: DataSource,
        join_key_columns: List[str],
        feature_name_columns: List[str],
        event_timestamp_column: str,
        created_timestamp_column: Optional[str],
        start_date: datetime,
        end_date: datetime,
    ) -> RetrievalJob:
        assert isinstance(data_source, RedshiftSource)
        assert isinstance(config.offline_store, RedshiftOfflineStoreConfig)

        from_expression = data_source.get_table_query_string()

        partition_by_join_key_string = ", ".join(join_key_columns)
        if partition_by_join_key_string != "":
            partition_by_join_key_string = ("PARTITION BY " +
                                            partition_by_join_key_string)
        timestamp_columns = [event_timestamp_column]
        if created_timestamp_column:
            timestamp_columns.append(created_timestamp_column)
        timestamp_desc_string = " DESC, ".join(timestamp_columns) + " DESC"
        field_string = ", ".join(join_key_columns + feature_name_columns +
                                 timestamp_columns)

        redshift_client = aws_utils.get_redshift_data_client(
            config.offline_store.region)
        s3_resource = aws_utils.get_s3_resource(config.offline_store.region)

        query = f"""
            SELECT {field_string}
            FROM (
                SELECT {field_string},
                ROW_NUMBER() OVER({partition_by_join_key_string} ORDER BY {timestamp_desc_string}) AS _feast_row
                FROM {from_expression}
                WHERE {event_timestamp_column} BETWEEN TIMESTAMP '{start_date}' AND TIMESTAMP '{end_date}'
            )
            WHERE _feast_row = 1
            """
        return RedshiftRetrievalJob(
            query=query,
            redshift_client=redshift_client,
            s3_resource=s3_resource,
            config=config,
        )
Example #7
def prep_redshift_fs_and_fv(
    source_type: str,
) -> Iterator[Tuple[FeatureStore, FeatureView]]:
    client = aws_utils.get_redshift_data_client("us-west-2")
    s3 = aws_utils.get_s3_resource("us-west-2")

    df = create_dataset()

    table_name = f"test_ingestion_{source_type}_correctness_{int(time.time_ns())}_{random.randint(1000, 9999)}"

    offline_store = RedshiftOfflineStoreConfig(
        cluster_id="feast-integration-tests",
        region="us-west-2",
        user="******",
        database="feast",
        s3_staging_location=
        "s3://feast-integration-tests/redshift/tests/ingestion",
        iam_role="arn:aws:iam::402087665549:role/redshift_s3_access_role",
    )

    aws_utils.upload_df_to_redshift(
        client,
        offline_store.cluster_id,
        offline_store.database,
        offline_store.user,
        s3,
        f"{offline_store.s3_staging_location}/copy/{table_name}.parquet",
        offline_store.iam_role,
        table_name,
        df,
    )

    redshift_source = RedshiftSource(
        table=table_name if source_type == "table" else None,
        query=f"SELECT * FROM {table_name}"
        if source_type == "query" else None,
        event_timestamp_column="ts",
        created_timestamp_column="created_ts",
        date_partition_column="",
        field_mapping={
            "ts_1": "ts",
            "id": "driver_id"
        },
    )

    fv = driver_feature_view(redshift_source)
    e = Entity(
        name="driver",
        description="id for driver",
        join_key="driver_id",
        value_type=ValueType.INT32,
    )
    with tempfile.TemporaryDirectory() as repo_dir_name, \
            tempfile.TemporaryDirectory() as data_dir_name:
        config = RepoConfig(
            registry=str(Path(repo_dir_name) / "registry.db"),
            project=f"test_bq_correctness_{str(uuid.uuid4()).replace('-', '')}",
            provider="local",
            online_store=SqliteOnlineStoreConfig(
                path=str(Path(data_dir_name) / "online_store.db")),
            offline_store=offline_store,
        )
        fs = FeatureStore(config=config)
        fs.apply([fv, e])

        yield fs, fv

        fs.teardown()

    # Clean up the uploaded Redshift table
    aws_utils.execute_redshift_statement(
        client,
        offline_store.cluster_id,
        offline_store.database,
        offline_store.user,
        f"DROP TABLE {table_name}",
    )
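
prep_redshift_fs_and_fv is written as a generator so that the teardown and table cleanup after the yield run once the caller is finished with the store. A hedged sketch of consuming it, for example by wrapping it with contextlib.contextmanager (the materialize call below is illustrative, not taken from the source):

import contextlib
from datetime import datetime, timedelta

with contextlib.contextmanager(prep_redshift_fs_and_fv)("table") as (fs, fv):
    # Load the most recent two days of feature data into the online store.
    fs.materialize(
        start_date=datetime.utcnow() - timedelta(days=2),
        end_date=datetime.utcnow(),
    )
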
Example #8
    def get_historical_features(
        config: RepoConfig,
        feature_views: List[FeatureView],
        feature_refs: List[str],
        entity_df: Union[pd.DataFrame, str],
        registry: Registry,
        project: str,
        full_feature_names: bool = False,
    ) -> RetrievalJob:
        assert isinstance(config.offline_store, RedshiftOfflineStoreConfig)

        redshift_client = aws_utils.get_redshift_data_client(
            config.offline_store.region
        )
        s3_resource = aws_utils.get_s3_resource(config.offline_store.region)

        @contextlib.contextmanager
        def query_generator() -> Iterator[str]:
            table_name = offline_utils.get_temp_entity_table_name()

            entity_schema = _upload_entity_df_and_get_entity_schema(
                entity_df, redshift_client, config, s3_resource, table_name
            )

            entity_df_event_timestamp_col = offline_utils.infer_event_timestamp_from_entity_df(
                entity_schema
            )

            expected_join_keys = offline_utils.get_expected_join_keys(
                project, feature_views, registry
            )

            offline_utils.assert_expected_columns_in_entity_df(
                entity_schema, expected_join_keys, entity_df_event_timestamp_col
            )

            # Build a query context containing all information required to template the Redshift SQL query
            query_context = offline_utils.get_feature_view_query_context(
                feature_refs, feature_views, registry, project,
            )

            # Generate the Redshift SQL query from the query context
            query = offline_utils.build_point_in_time_query(
                query_context,
                left_table_query_string=table_name,
                entity_df_event_timestamp_col=entity_df_event_timestamp_col,
                query_template=MULTIPLE_FEATURE_VIEW_POINT_IN_TIME_JOIN,
                full_feature_names=full_feature_names,
            )

            yield query

            # Clean up the uploaded Redshift table
            aws_utils.execute_redshift_statement(
                redshift_client,
                config.offline_store.cluster_id,
                config.offline_store.database,
                config.offline_store.user,
                f"DROP TABLE {table_name}",
            )

        return RedshiftRetrievalJob(
            query=query_generator,
            redshift_client=redshift_client,
            s3_resource=s3_resource,
            config=config,
            full_feature_names=full_feature_names,
            on_demand_feature_views=OnDemandFeatureView.get_requested_odfvs(
                feature_refs, project, registry
            ),
            drop_columns=["entity_timestamp"]
            + [
                f"{feature_view.name}__entity_row_unique_id"
                for feature_view in feature_views
            ],
        )
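
Callers normally reach this method through the public FeatureStore API rather than invoking it directly. A hedged sketch of that entry point, assuming a repo already configured for the Redshift offline store (the feature references and entity dataframe below are illustrative):

from datetime import datetime

import pandas as pd

from feast import FeatureStore

store = FeatureStore(repo_path=".")  # assumes feature_store.yaml points at Redshift

entity_df = pd.DataFrame(
    {
        "driver_id": [1001, 1002],
        "event_timestamp": [datetime(2021, 6, 1), datetime(2021, 6, 2)],
    }
)

training_df = store.get_historical_features(
    entity_df=entity_df,
    features=["driver_stats:conv_rate", "driver_stats:avg_daily_trips"],
).to_df()
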
def test_historical_features_from_redshift_sources(provider_type,
                                                   infer_event_timestamp_col,
                                                   capsys, full_feature_names):
    client = aws_utils.get_redshift_data_client("us-west-2")
    s3 = aws_utils.get_s3_resource("us-west-2")

    offline_store = RedshiftOfflineStoreConfig(
        cluster_id="feast-integration-tests",
        region="us-west-2",
        user="******",
        database="feast",
        s3_staging_location=
        "s3://feast-integration-tests/redshift/tests/ingestion",
        iam_role="arn:aws:iam::402087665549:role/redshift_s3_access_role",
    )

    start_date = datetime.now().replace(microsecond=0, second=0, minute=0)
    (
        customer_entities,
        driver_entities,
        end_date,
        orders_df,
        start_date,
    ) = generate_entities(start_date, infer_event_timestamp_col)

    redshift_table_prefix = (
        f"test_hist_retrieval_{int(time.time_ns())}_{random.randint(1000, 9999)}"
    )

    # Stage orders_df to Redshift
    table_name = f"{redshift_table_prefix}_orders"
    entity_df_query = f"SELECT * FROM {table_name}"
    orders_context = aws_utils.temporarily_upload_df_to_redshift(
        client,
        offline_store.cluster_id,
        offline_store.database,
        offline_store.user,
        s3,
        f"{offline_store.s3_staging_location}/copy/{table_name}.parquet",
        offline_store.iam_role,
        table_name,
        orders_df,
    )

    # Stage driver_df to Redshift
    driver_df = driver_data.create_driver_hourly_stats_df(
        driver_entities, start_date, end_date)
    driver_table_name = f"{redshift_table_prefix}_driver_hourly"
    driver_context = aws_utils.temporarily_upload_df_to_redshift(
        client,
        offline_store.cluster_id,
        offline_store.database,
        offline_store.user,
        s3,
        f"{offline_store.s3_staging_location}/copy/{driver_table_name}.parquet",
        offline_store.iam_role,
        driver_table_name,
        driver_df,
    )

    # Stage customer_df to Redshift
    customer_df = driver_data.create_customer_daily_profile_df(
        customer_entities, start_date, end_date)
    customer_table_name = f"{redshift_table_prefix}_customer_profile"
    customer_context = aws_utils.temporarily_upload_df_to_redshift(
        client,
        offline_store.cluster_id,
        offline_store.database,
        offline_store.user,
        s3,
        f"{offline_store.s3_staging_location}/copy/{customer_table_name}.parquet",
        offline_store.iam_role,
        customer_table_name,
        customer_df,
    )

    with orders_context, driver_context, customer_context, \
            TemporaryDirectory() as temp_dir:
        driver_source = RedshiftSource(
            table=driver_table_name,
            event_timestamp_column="event_timestamp",
            created_timestamp_column="created",
        )
        driver_fv = create_driver_hourly_stats_feature_view(driver_source)

        customer_source = RedshiftSource(
            table=customer_table_name,
            event_timestamp_column="event_timestamp",
            created_timestamp_column="created",
        )
        customer_fv = create_customer_daily_profile_feature_view(
            customer_source)

        driver = Entity(name="driver",
                        join_key="driver_id",
                        value_type=ValueType.INT64)
        customer = Entity(name="customer_id", value_type=ValueType.INT64)

        if provider_type == "local":
            store = FeatureStore(config=RepoConfig(
                registry=os.path.join(temp_dir, "registry.db"),
                project="default",
                provider="local",
                online_store=SqliteOnlineStoreConfig(path=os.path.join(
                    temp_dir, "online_store.db"), ),
                offline_store=offline_store,
            ))
        elif provider_type == "aws":
            store = FeatureStore(config=RepoConfig(
                registry=os.path.join(temp_dir, "registry.db"),
                project="".join(
                    random.choices(string.ascii_uppercase + string.digits,
                                   k=10)),
                provider="aws",
                online_store=DynamoDBOnlineStoreConfig(region="us-west-2"),
                offline_store=offline_store,
            ))
        else:
            raise Exception(
                "Invalid provider used as part of test configuration")

        store.apply([driver, customer, driver_fv, customer_fv])

        try:
            event_timestamp = (DEFAULT_ENTITY_DF_EVENT_TIMESTAMP_COL
                               if DEFAULT_ENTITY_DF_EVENT_TIMESTAMP_COL
                               in orders_df.columns else "e_ts")
            expected_df = get_expected_training_df(
                customer_df,
                customer_fv,
                driver_df,
                driver_fv,
                orders_df,
                event_timestamp,
                full_feature_names,
            )

            job_from_sql = store.get_historical_features(
                entity_df=entity_df_query,
                features=[
                    "driver_stats:conv_rate",
                    "driver_stats:avg_daily_trips",
                    "customer_profile:current_balance",
                    "customer_profile:avg_passenger_count",
                    "customer_profile:lifetime_trip_count",
                ],
                full_feature_names=full_feature_names,
            )

            start_time = datetime.utcnow()
            actual_df_from_sql_entities = job_from_sql.to_df()
            end_time = datetime.utcnow()
            with capsys.disabled():
                print(
                    f"\nTime to execute job_from_sql.to_df() = '{end_time - start_time}'"
                )

            assert sorted(expected_df.columns) == sorted(
                actual_df_from_sql_entities.columns)
            assert_frame_equal(
                expected_df.sort_values(by=[
                    event_timestamp, "order_id", "driver_id", "customer_id"
                ]).reset_index(drop=True),
                actual_df_from_sql_entities[expected_df.columns].sort_values(
                    by=[
                        event_timestamp, "order_id", "driver_id", "customer_id"
                    ]).reset_index(drop=True),
                check_dtype=False,
            )

            table_from_sql_entities = job_from_sql.to_arrow()
            assert_frame_equal(
                actual_df_from_sql_entities.sort_values(by=[
                    event_timestamp, "order_id", "driver_id", "customer_id"
                ]).reset_index(drop=True),
                table_from_sql_entities.to_pandas().sort_values(by=[
                    event_timestamp, "order_id", "driver_id", "customer_id"
                ]).reset_index(drop=True),
            )

            timestamp_column = ("e_ts" if infer_event_timestamp_col else
                                DEFAULT_ENTITY_DF_EVENT_TIMESTAMP_COL)

            entity_df_query_with_invalid_join_key = (
                f"select order_id, driver_id, customer_id as customer, "
                f"order_is_success, {timestamp_column} FROM {table_name}")
            # Rename the join key; this should now raise an error.
            assertpy.assert_that(
                store.get_historical_features(
                    entity_df=entity_df_query_with_invalid_join_key,
                    features=[
                        "driver_stats:conv_rate",
                        "driver_stats:avg_daily_trips",
                        "customer_profile:current_balance",
                        "customer_profile:avg_passenger_count",
                        "customer_profile:lifetime_trip_count",
                    ],
                ).to_df
            ).raises(errors.FeastEntityDFMissingColumnsError).when_called_with()

            job_from_df = store.get_historical_features(
                entity_df=orders_df,
                features=[
                    "driver_stats:conv_rate",
                    "driver_stats:avg_daily_trips",
                    "customer_profile:current_balance",
                    "customer_profile:avg_passenger_count",
                    "customer_profile:lifetime_trip_count",
                ],
                full_feature_names=full_feature_names,
            )

            # Rename the join key; this should now raise an error.
            orders_df_with_invalid_join_key = orders_df.rename(
                {"customer_id": "customer"}, axis="columns")
            assertpy.assert_that(
                store.get_historical_features(
                    entity_df=orders_df_with_invalid_join_key,
                    features=[
                        "driver_stats:conv_rate",
                        "driver_stats:avg_daily_trips",
                        "customer_profile:current_balance",
                        "customer_profile:avg_passenger_count",
                        "customer_profile:lifetime_trip_count",
                    ],
                ).to_df
            ).raises(errors.FeastEntityDFMissingColumnsError).when_called_with()

            start_time = datetime.utcnow()
            actual_df_from_df_entities = job_from_df.to_df()
            end_time = datetime.utcnow()
            with capsys.disabled():
                print(
                    f"Time to execute job_from_df.to_df() = '{end_time - start_time}'\n"
                )

            assert sorted(expected_df.columns) == sorted(
                actual_df_from_df_entities.columns)
            assert_frame_equal(
                expected_df.sort_values(by=[
                    event_timestamp, "order_id", "driver_id", "customer_id"
                ]).reset_index(drop=True),
                actual_df_from_df_entities[expected_df.columns].sort_values(
                    by=[
                        event_timestamp, "order_id", "driver_id", "customer_id"
                    ]).reset_index(drop=True),
                check_dtype=False,
            )

            table_from_df_entities = job_from_df.to_arrow()
            assert_frame_equal(
                actual_df_from_df_entities.sort_values(by=[
                    event_timestamp, "order_id", "driver_id", "customer_id"
                ]).reset_index(drop=True),
                table_from_df_entities.to_pandas().sort_values(by=[
                    event_timestamp, "order_id", "driver_id", "customer_id"
                ]).reset_index(drop=True),
            )
        finally:
            store.teardown()
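
The test repeats the same sort-then-compare pattern for every dataframe pair it checks. A small helper capturing that pattern (illustrative, not part of the test suite) could read:

from pandas.testing import assert_frame_equal

def assert_frames_equal_ignoring_order(expected, actual, sort_keys):
    # Restrict to the expected columns, sort both frames identically, and
    # compare without dtype strictness, mirroring the assertions above.
    assert_frame_equal(
        expected.sort_values(by=sort_keys).reset_index(drop=True),
        actual[expected.columns].sort_values(by=sort_keys).reset_index(drop=True),
        check_dtype=False,
    )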