Exemplo n.º 1
0
def bootstrap():
    """Write the sample driver and customer parquet files for a new repo.

    Called automatically from init_repo() during `feast init`. The files are
    written into a ``data`` directory located beside this module; the
    directory is created if it does not exist yet.
    """
    import pathlib
    from datetime import datetime, timedelta

    from feast.driver_test_data import (
        create_customer_daily_profile_df,
        create_driver_hourly_stats_df,
    )

    data_dir = pathlib.Path(__file__).parent.absolute() / "data"
    data_dir.mkdir(exist_ok=True)

    # The generated window is the 15 days ending at the start of the
    # current hour (minutes, seconds and microseconds are zeroed).
    window_end = datetime.now().replace(microsecond=0, second=0, minute=0)
    window_start = window_end - timedelta(days=15)

    # (output file name, generator, entity ids) -- processed in this order.
    dataset_specs = (
        (
            "driver_hourly_stats.parquet",
            create_driver_hourly_stats_df,
            [1001, 1002, 1003],
        ),
        (
            "customer_daily_profile.parquet",
            create_customer_daily_profile_df,
            [201, 202, 203],
        ),
    )
    for file_name, make_frame, entity_ids in dataset_specs:
        frame = make_frame(entity_ids, window_start, window_end)
        frame.to_parquet(
            path=str(data_dir / file_name),
            allow_truncated_timestamps=True,
        )
Exemplo n.º 2
0
def construct_universal_datasets(
        entities: Dict[str, List[Any]], start_time: datetime,
        end_time: datetime) -> Dict[str, pd.DataFrame]:
    """Build the test data frames for the given entities and time window.

    Args:
        entities: Mapping that must contain the keys "customer" and "driver",
            each holding the entity values handed to the generators.
        start_time: Start of the window passed to every generator.
        end_time: End of the window passed to every generator.

    Returns:
        A dict with the keys "customer", "driver", "orders", "global" and
        "entity". The "entity" frame is a column subset of the "orders" frame.
    """
    datasets: Dict[str, pd.DataFrame] = {}

    customer_ids = entities["customer"]
    datasets["customer"] = driver_test_data.create_customer_daily_profile_df(
        customer_ids, start_time, end_time)

    driver_ids = entities["driver"]
    datasets["driver"] = driver_test_data.create_driver_hourly_stats_df(
        driver_ids, start_time, end_time)

    orders = driver_test_data.create_orders_df(
        customers=customer_ids,
        drivers=driver_ids,
        start_date=start_time,
        end_date=end_time,
        order_count=20,
    )
    datasets["orders"] = orders

    datasets["global"] = driver_test_data.create_global_daily_stats_df(
        start_time, end_time)

    # The entity frame keeps only the join keys and the event timestamp.
    entity_columns = ["customer_id", "driver_id", "order_id", "event_timestamp"]
    datasets["entity"] = orders[entity_columns]

    return datasets
Exemplo n.º 3
0
def construct_universal_datasets(entities: UniversalEntities,
                                 start_time: datetime,
                                 end_time: datetime) -> UniversalDatasets:
    """Build every test data frame for the given entities and time window.

    Args:
        entities: Provides ``customer_vals``, ``driver_vals`` and
            ``location_vals``, which are handed to the generators.
        start_time: Start of the window passed to every generator.
        end_time: End of the window passed to every generator.

    Returns:
        A ``UniversalDatasets`` holding the customer, driver, location,
        orders, global and field-mapping frames, plus an entity frame that is
        a column subset of the orders frame.
    """
    customer_vals = entities.customer_vals
    customers = driver_test_data.create_customer_daily_profile_df(
        customer_vals, start_time, end_time)

    driver_vals = entities.driver_vals
    drivers = driver_test_data.create_driver_hourly_stats_df(
        driver_vals, start_time, end_time)

    location_vals = entities.location_vals
    locations = driver_test_data.create_location_stats_df(
        location_vals, start_time, end_time)

    orders = driver_test_data.create_orders_df(
        customers=customer_vals,
        drivers=driver_vals,
        locations=location_vals,
        start_date=start_time,
        end_date=end_time,
        order_count=20,
    )
    global_stats = driver_test_data.create_global_daily_stats_df(
        start_time, end_time)
    field_mapping = driver_test_data.create_field_mapping_df(
        start_time, end_time)

    # The entity frame keeps only the join keys and the event timestamp.
    entity_columns = [
        "customer_id",
        "driver_id",
        "order_id",
        "origin_id",
        "destination_id",
        "event_timestamp",
    ]

    return UniversalDatasets(
        customer_df=customers,
        driver_df=drivers,
        location_df=locations,
        orders_df=orders,
        global_df=global_stats,
        field_mapping_df=field_mapping,
        entity_df=orders[entity_columns],
    )
Exemplo n.º 4
0
def test_historical_features_from_bigquery_sources(
    provider_type, infer_event_timestamp_col
):
    """Compare BigQuery historical retrieval with a locally computed frame.

    Orders, driver stats and customer profiles are staged into a per-run
    BigQuery dataset. Features are then retrieved twice -- once with a SQL
    entity query and once with the in-memory ``orders_df`` -- and each result
    is compared against the output of ``get_expected_training_df``.

    Args:
        provider_type: One of "local", "gcp" or "gcp_custom_offline_config".
            Any other value raises, but only after the sources were staged.
        infer_event_timestamp_col: Forwarded unchanged to
            ``generate_entities``.
    """
    # A fresh list is built from this tuple for every retrieval call.
    requested_features = (
        "driver_stats:conv_rate",
        "driver_stats:avg_daily_trips",
        "customer_profile:current_balance",
        "customer_profile:avg_passenger_count",
        "customer_profile:lifetime_trip_count",
    )

    (
        customer_entities,
        driver_entities,
        end_date,
        orders_df,
        start_date,
    ) = generate_entities(
        datetime.now().replace(microsecond=0, second=0, minute=0),
        infer_event_timestamp_col,
    )

    # Nanosecond clock plus a random suffix gives each run its own dataset.
    bigquery_dataset = (
        f"test_hist_retrieval_{int(time.time_ns())}_{random.randint(1000, 9999)}"
    )

    with BigQueryDataSet(bigquery_dataset), TemporaryDirectory() as temp_dir:
        gcp_project = bigquery.Client().project

        # Orders: staged as a table and also addressed through a SQL query.
        orders_table_id = f"{bigquery_dataset}.orders"
        stage_orders_bigquery(orders_df, orders_table_id)
        entity_df_query = f"SELECT * FROM {gcp_project}.{orders_table_id}"

        # Driver feature view
        driver_df = driver_data.create_driver_hourly_stats_df(
            driver_entities, start_date, end_date
        )
        driver_table_id = f"{gcp_project}.{bigquery_dataset}.driver_hourly"
        stage_driver_hourly_stats_bigquery_source(driver_df, driver_table_id)
        driver_fv = create_driver_hourly_stats_feature_view(
            BigQuerySource(
                table_ref=driver_table_id,
                event_timestamp_column="datetime",
                created_timestamp_column="created",
            )
        )

        # Customer feature view (its source has an empty created-timestamp column)
        customer_df = driver_data.create_customer_daily_profile_df(
            customer_entities, start_date, end_date
        )
        customer_table_id = f"{gcp_project}.{bigquery_dataset}.customer_profile"
        stage_customer_daily_profile_bigquery_source(customer_df, customer_table_id)
        customer_fv = create_customer_daily_profile_feature_view(
            BigQuerySource(
                table_ref=customer_table_id,
                event_timestamp_column="datetime",
                created_timestamp_column="",
            )
        )

        driver = Entity(name="driver", join_key="driver_id", value_type=ValueType.INT64)
        customer = Entity(name="customer_id", value_type=ValueType.INT64)

        uses_custom_dataset = provider_type == "gcp_custom_offline_config"
        registry_path = os.path.join(temp_dir, "registry.db")

        if provider_type == "local":
            repo_config = RepoConfig(
                registry=registry_path,
                project="default",
                provider="local",
                online_store=SqliteOnlineStoreConfig(
                    path=os.path.join(temp_dir, "online_store.db"),
                ),
                offline_store=BigQueryOfflineStoreConfig(type="bigquery"),
            )
        elif provider_type == "gcp" or uses_custom_dataset:
            # Both gcp flavours run under a random 10-character project name.
            random_project = "".join(
                random.choices(string.ascii_uppercase + string.digits, k=10)
            )
            if uses_custom_dataset:
                offline_store = BigQueryOfflineStoreConfig(
                    type="bigquery", dataset="foo"
                )
            else:
                offline_store = BigQueryOfflineStoreConfig(type="bigquery")
            repo_config = RepoConfig(
                registry=registry_path,
                project=random_project,
                provider="gcp",
                offline_store=offline_store,
            )
        else:
            raise Exception("Invalid provider used as part of test configuration")

        store = FeatureStore(config=repo_config)
        store.apply([driver, customer, driver_fv, customer_fv])

        if DEFAULT_ENTITY_DF_EVENT_TIMESTAMP_COL in orders_df.columns:
            event_timestamp = DEFAULT_ENTITY_DF_EVENT_TIMESTAMP_COL
        else:
            event_timestamp = "e_ts"

        def _ordered(frame):
            # Sort into a deterministic row order so frames compare positionally.
            return frame.sort_values(
                by=[event_timestamp, "order_id", "driver_id", "customer_id"]
            ).reset_index(drop=True)

        expected_df = get_expected_training_df(
            customer_df, customer_fv, driver_df, driver_fv, orders_df, event_timestamp,
        )

        # Retrieval driven by the SQL entity query.
        job_from_sql = store.get_historical_features(
            entity_df=entity_df_query,
            feature_refs=list(requested_features),
        )
        actual_df_from_sql_entities = job_from_sql.to_df()
        assert_frame_equal(
            _ordered(expected_df),
            _ordered(actual_df_from_sql_entities),
            check_dtype=False,
        )

        # Retrieval driven by the in-memory entity frame.
        job_from_df = store.get_historical_features(
            entity_df=orders_df,
            feature_refs=list(requested_features),
        )

        # The generated query must reference the dataset from the offline_store
        # config when one is given, otherwise the default `feast` dataset.
        expected_fragment = (
            "foo.entity_df" if uses_custom_dataset else "feast.entity_df"
        )
        assertpy.assert_that(job_from_df.query).contains(expected_fragment)

        actual_df_from_df_entities = job_from_df.to_df()
        assert_frame_equal(
            _ordered(expected_df),
            _ordered(actual_df_from_df_entities),
            check_dtype=False,
        )
Exemplo n.º 5
0
def test_historical_features_from_parquet_sources(infer_event_timestamp_col):
    """Compare parquet-backed historical retrieval with a locally computed frame.

    Driver stats and customer profiles are staged as parquet sources inside a
    temporary directory, features are retrieved for ``orders_df`` through a
    local-provider store, and the result is compared against the output of
    ``get_expected_training_df``.

    Args:
        infer_event_timestamp_col: Forwarded unchanged to
            ``generate_entities``.
    """
    (
        customer_entities,
        driver_entities,
        end_date,
        orders_df,
        start_date,
    ) = generate_entities(
        datetime.now().replace(microsecond=0, second=0, minute=0),
        infer_event_timestamp_col,
    )

    with TemporaryDirectory() as temp_dir:
        # Driver feature view
        driver_df = driver_data.create_driver_hourly_stats_df(
            driver_entities, start_date, end_date
        )
        driver_fv = create_driver_hourly_stats_feature_view(
            stage_driver_hourly_stats_parquet_source(temp_dir, driver_df)
        )

        # Customer feature view
        customer_df = driver_data.create_customer_daily_profile_df(
            customer_entities, start_date, end_date
        )
        customer_fv = create_customer_daily_profile_feature_view(
            stage_customer_daily_profile_parquet_source(temp_dir, customer_df)
        )

        driver = Entity(name="driver", join_key="driver_id", value_type=ValueType.INT64)
        customer = Entity(name="customer_id", value_type=ValueType.INT64)

        repo_config = RepoConfig(
            registry=os.path.join(temp_dir, "registry.db"),
            project="default",
            provider="local",
            online_store=SqliteOnlineStoreConfig(
                path=os.path.join(temp_dir, "online_store.db")
            ),
        )
        store = FeatureStore(config=repo_config)
        store.apply([driver, customer, driver_fv, customer_fv])

        job = store.get_historical_features(
            entity_df=orders_df,
            feature_refs=[
                "driver_stats:conv_rate",
                "driver_stats:avg_daily_trips",
                "customer_profile:current_balance",
                "customer_profile:avg_passenger_count",
                "customer_profile:lifetime_trip_count",
            ],
        )
        actual_df = job.to_df()

        if DEFAULT_ENTITY_DF_EVENT_TIMESTAMP_COL in orders_df.columns:
            event_timestamp = DEFAULT_ENTITY_DF_EVENT_TIMESTAMP_COL
        else:
            event_timestamp = "e_ts"

        def _ordered(frame):
            # Sort into a deterministic row order so frames compare positionally.
            return frame.sort_values(
                by=[event_timestamp, "order_id", "driver_id", "customer_id"]
            ).reset_index(drop=True)

        expected_df = get_expected_training_df(
            customer_df, customer_fv, driver_df, driver_fv, orders_df, event_timestamp,
        )
        assert_frame_equal(_ordered(expected_df), _ordered(actual_df))
Exemplo n.º 6
0
def test_historical_features_from_bigquery_sources(provider_type,
                                                   infer_event_timestamp_col,
                                                   capsys):
    """Compare BigQuery historical retrieval with a locally computed frame.

    Orders, driver stats and customer profiles are staged into a per-run
    BigQuery dataset. Features are retrieved with a SQL entity query and with
    the in-memory ``orders_df``; each result is compared against the output
    of ``get_expected_training_df`` and against its own Arrow conversion. The
    test also checks that an entity frame with a renamed join key raises
    ``FeastEntityDFMissingColumnsError``.

    Args:
        provider_type: One of "local", "gcp" or "gcp_custom_offline_config".
            Any other value raises, but only after the sources were staged.
        infer_event_timestamp_col: Forwarded to ``generate_entities``; when
            truthy the entity timestamp column is taken to be "e_ts".
        capsys: Pytest capture fixture, disabled while timings are printed.
    """
    # A fresh list is built from this tuple for every retrieval call.
    requested_features = (
        "driver_stats:conv_rate",
        "driver_stats:avg_daily_trips",
        "customer_profile:current_balance",
        "customer_profile:avg_passenger_count",
        "customer_profile:lifetime_trip_count",
    )

    (
        customer_entities,
        driver_entities,
        end_date,
        orders_df,
        start_date,
    ) = generate_entities(
        datetime.now().replace(microsecond=0, second=0, minute=0),
        infer_event_timestamp_col,
    )

    # Nanosecond clock plus a random suffix gives each run its own dataset.
    bigquery_dataset = (
        f"test_hist_retrieval_{int(time.time_ns())}_{random.randint(1000, 9999)}"
    )

    with BigQueryDataSet(bigquery_dataset), TemporaryDirectory() as temp_dir:
        gcp_project = bigquery.Client().project

        # Orders: staged as a table and also addressed through a SQL query.
        table_id = f"{bigquery_dataset}.orders"
        stage_orders_bigquery(orders_df, table_id)
        entity_df_query = f"SELECT * FROM {gcp_project}.{table_id}"

        # Driver feature view
        driver_df = driver_data.create_driver_hourly_stats_df(
            driver_entities, start_date, end_date)
        driver_table_id = f"{gcp_project}.{bigquery_dataset}.driver_hourly"
        stage_driver_hourly_stats_bigquery_source(driver_df, driver_table_id)
        driver_fv = create_driver_hourly_stats_feature_view(
            BigQuerySource(
                table_ref=driver_table_id,
                event_timestamp_column="datetime",
                created_timestamp_column="created",
            ))

        # Customer feature view (its source has an empty created-timestamp column)
        customer_df = driver_data.create_customer_daily_profile_df(
            customer_entities, start_date, end_date)
        customer_table_id = f"{gcp_project}.{bigquery_dataset}.customer_profile"
        stage_customer_daily_profile_bigquery_source(customer_df,
                                                     customer_table_id)
        customer_fv = create_customer_daily_profile_feature_view(
            BigQuerySource(
                table_ref=customer_table_id,
                event_timestamp_column="datetime",
                created_timestamp_column="",
            ))

        driver = Entity(name="driver",
                        join_key="driver_id",
                        value_type=ValueType.INT64)
        customer = Entity(name="customer_id", value_type=ValueType.INT64)

        uses_custom_dataset = provider_type == "gcp_custom_offline_config"
        registry_path = os.path.join(temp_dir, "registry.db")

        if provider_type == "local":
            repo_config = RepoConfig(
                registry=registry_path,
                project="default",
                provider="local",
                online_store=SqliteOnlineStoreConfig(
                    path=os.path.join(temp_dir, "online_store.db")),
                offline_store=BigQueryOfflineStoreConfig(
                    type="bigquery", dataset=bigquery_dataset),
            )
        elif provider_type == "gcp" or uses_custom_dataset:
            # Both gcp flavours run under a random 10-character project name.
            random_project = "".join(
                random.choices(string.ascii_uppercase + string.digits, k=10))
            offline_dataset = "foo" if uses_custom_dataset else bigquery_dataset
            repo_config = RepoConfig(
                registry=registry_path,
                project=random_project,
                provider="gcp",
                offline_store=BigQueryOfflineStoreConfig(
                    type="bigquery", dataset=offline_dataset),
            )
        else:
            raise Exception(
                "Invalid provider used as part of test configuration")

        store = FeatureStore(config=repo_config)
        store.apply([driver, customer, driver_fv, customer_fv])

        if DEFAULT_ENTITY_DF_EVENT_TIMESTAMP_COL in orders_df.columns:
            event_timestamp = DEFAULT_ENTITY_DF_EVENT_TIMESTAMP_COL
        else:
            event_timestamp = "e_ts"

        expected_df = get_expected_training_df(
            customer_df,
            customer_fv,
            driver_df,
            driver_fv,
            orders_df,
            event_timestamp,
        )

        def _ordered(frame):
            # Sort into a deterministic row order so frames compare positionally.
            return frame.sort_values(
                by=[event_timestamp, "order_id", "driver_id", "customer_id"
                    ]).reset_index(drop=True)

        def _check_against_expected(actual_df):
            # Same column set first, then same content with the columns
            # arranged in the expected frame's order.
            assert sorted(expected_df.columns) == sorted(actual_df.columns)
            assert_frame_equal(
                _ordered(expected_df),
                _ordered(actual_df[expected_df.columns]),
                check_dtype=False,
            )

        # ---- Retrieval driven by the SQL entity query ----
        job_from_sql = store.get_historical_features(
            entity_df=entity_df_query,
            feature_refs=list(requested_features),
        )

        started = datetime.utcnow()
        actual_df_from_sql_entities = job_from_sql.to_df()
        finished = datetime.utcnow()
        with capsys.disabled():
            print(
                f"\nTime to execute job_from_sql.to_df() = '{(finished - started)}'"
            )

        _check_against_expected(actual_df_from_sql_entities)

        # The Arrow result must convert to the same frame as to_df().
        table_from_sql_entities = job_from_sql.to_arrow()
        assert_frame_equal(actual_df_from_sql_entities,
                           table_from_sql_entities.to_pandas())

        if infer_event_timestamp_col:
            timestamp_column = "e_ts"
        else:
            timestamp_column = DEFAULT_ENTITY_DF_EVENT_TIMESTAMP_COL

        # With customer_id aliased to another name the join key is missing,
        # so the retrieval call is expected to raise.
        entity_df_query_with_invalid_join_key = (
            f"select order_id, driver_id, customer_id as customer, "
            f"order_is_success, {timestamp_column}, FROM {gcp_project}.{table_id}"
        )
        assertpy.assert_that(store.get_historical_features).raises(
            errors.FeastEntityDFMissingColumnsError).when_called_with(
                entity_df=entity_df_query_with_invalid_join_key,
                feature_refs=list(requested_features),
            )

        # ---- Retrieval driven by the in-memory entity frame ----
        job_from_df = store.get_historical_features(
            entity_df=orders_df,
            feature_refs=list(requested_features),
        )

        # Same missing-join-key check, this time with a renamed frame column.
        orders_df_with_invalid_join_key = orders_df.rename(
            {"customer_id": "customer"}, axis="columns")
        assertpy.assert_that(store.get_historical_features).raises(
            errors.FeastEntityDFMissingColumnsError).when_called_with(
                entity_df=orders_df_with_invalid_join_key,
                feature_refs=list(requested_features),
            )

        # The generated query must reference the dataset configured on the
        # offline store.
        if uses_custom_dataset:
            expected_fragment = "foo.entity_df"
        else:
            expected_fragment = f"{bigquery_dataset}.entity_df"
        assertpy.assert_that(job_from_df.query).contains(expected_fragment)

        started = datetime.utcnow()
        actual_df_from_df_entities = job_from_df.to_df()
        finished = datetime.utcnow()
        with capsys.disabled():
            print(
                f"Time to execute job_from_df.to_df() = '{(finished - started)}'\n"
            )

        _check_against_expected(actual_df_from_df_entities)

        table_from_df_entities = job_from_df.to_arrow()
        assert_frame_equal(actual_df_from_df_entities,
                           table_from_df_entities.to_pandas())
Exemplo n.º 7
0
def test_historical_features_from_bigquery_sources():
    """Compare BigQuery historical retrieval with a locally computed frame.

    Orders, driver stats and customer profiles are staged into a BigQuery
    dataset named after the current epoch second. Features are retrieved with
    a SQL entity query and with the in-memory ``orders_df``; each result is
    compared against the output of ``get_expected_training_df``.
    """
    # A fresh list is built from this tuple for every retrieval call.
    requested_features = (
        "driver_stats:conv_rate",
        "driver_stats:avg_daily_trips",
        "customer_profile:current_balance",
        "customer_profile:avg_passenger_count",
        "customer_profile:lifetime_trip_count",
    )

    def _ordered(frame):
        # Sort into a deterministic row order so frames compare positionally.
        return frame.sort_values(by=[
            ENTITY_DF_EVENT_TIMESTAMP_COL,
            "order_id",
            "driver_id",
            "customer_id",
        ]).reset_index(drop=True)

    (
        customer_entities,
        driver_entities,
        end_date,
        orders_df,
        start_date,
    ) = generate_entities(
        datetime.now().replace(microsecond=0, second=0, minute=0))

    bigquery_dataset = f"test_hist_retrieval_{int(time.time())}"

    with BigQueryDataSet(bigquery_dataset), TemporaryDirectory() as temp_dir:
        gcp_project = bigquery.Client().project

        # Orders: staged as a table and also addressed through a SQL query.
        orders_table_id = f"{bigquery_dataset}.orders"
        stage_orders_bigquery(orders_df, orders_table_id)
        entity_df_query = f"SELECT * FROM {gcp_project}.{orders_table_id}"

        # Driver feature view
        driver_df = driver_data.create_driver_hourly_stats_df(
            driver_entities, start_date, end_date)
        driver_table_id = f"{gcp_project}.{bigquery_dataset}.driver_hourly"
        stage_driver_hourly_stats_bigquery_source(driver_df, driver_table_id)
        driver_fv = create_driver_hourly_stats_feature_view(
            BigQuerySource(
                table_ref=driver_table_id,
                event_timestamp_column="datetime",
                created_timestamp_column="created",
            ))

        # Customer feature view
        customer_df = driver_data.create_customer_daily_profile_df(
            customer_entities, start_date, end_date)
        customer_table_id = f"{gcp_project}.{bigquery_dataset}.customer_profile"
        stage_customer_daily_profile_bigquery_source(customer_df,
                                                     customer_table_id)
        customer_fv = create_customer_daily_profile_feature_view(
            BigQuerySource(
                table_ref=customer_table_id,
                event_timestamp_column="datetime",
                created_timestamp_column="created",
            ))

        driver = Entity(name="driver", value_type=ValueType.INT64)
        customer = Entity(name="customer", value_type=ValueType.INT64)

        online_store = OnlineStoreConfig(local=LocalOnlineStoreConfig(
            path=os.path.join(temp_dir, "online_store.db")))
        repo_config = RepoConfig(
            registry=os.path.join(temp_dir, "registry.db"),
            project="default",
            provider="gcp",
            online_store=online_store,
        )
        store = FeatureStore(config=repo_config)
        store.apply([driver, customer, driver_fv, customer_fv])

        expected_df = get_expected_training_df(
            customer_df,
            customer_fv,
            driver_df,
            driver_fv,
            orders_df,
        )

        # Retrieval driven by the SQL entity query.
        job_from_sql = store.get_historical_features(
            entity_df=entity_df_query,
            feature_refs=list(requested_features),
        )
        actual_df_from_sql_entities = job_from_sql.to_df()
        assert_frame_equal(
            _ordered(expected_df),
            _ordered(actual_df_from_sql_entities),
            check_dtype=False,
        )

        # Retrieval driven by the in-memory entity frame.
        job_from_df = store.get_historical_features(
            entity_df=orders_df,
            feature_refs=list(requested_features),
        )
        actual_df_from_df_entities = job_from_df.to_df()
        assert_frame_equal(
            _ordered(expected_df),
            _ordered(actual_df_from_df_entities),
            check_dtype=False,
        )
Exemplo n.º 8
0
class Environment:
    """Shared test fixture data plus lazily created sources and feature views.

    The data frames below are class attributes: they are generated once, when
    the class body is executed. The accessor methods create their data source
    through ``data_source_creator`` on first use and store the result on the
    instance.
    """

    name: str
    test_repo_config: TestRepoConfig
    feature_store: FeatureStore
    data_source: DataSource
    data_source_creator: DataSourceCreator

    # Feature data covers the 7 days ending at the start of the current hour.
    end_date = datetime.now().replace(microsecond=0, second=0, minute=0)
    start_date = end_date - timedelta(days=7)
    # Orders range a year beyond the feature window on either side.
    before_start_date = end_date - timedelta(days=365)
    after_end_date = end_date + timedelta(days=365)

    customer_entities = list(range(1001, 1110))
    customer_df = driver_test_data.create_customer_daily_profile_df(
        customer_entities, start_date, end_date
    )
    _customer_feature_view: Optional[FeatureView] = None

    driver_entities = list(range(5001, 5110))
    driver_df = driver_test_data.create_driver_hourly_stats_df(
        driver_entities, start_date, end_date
    )
    _driver_stats_feature_view: Optional[FeatureView] = None

    orders_df = driver_test_data.create_orders_df(
        customers=customer_entities,
        drivers=driver_entities,
        start_date=before_start_date,
        end_date=after_end_date,
        order_count=1000,
    )
    _orders_table: Optional[str] = None

    def customer_feature_view(self) -> FeatureView:
        """Return the customer profile feature view, creating it on first use."""
        if self._customer_feature_view is not None:
            return self._customer_feature_view

        table_id = self.data_source_creator.get_prefixed_table_name(
            self.name, "customer_profile"
        )
        source = self.data_source_creator.create_data_source(
            table_id,
            self.customer_df,
            event_timestamp_column="event_timestamp",
            created_timestamp_column="created",
        )
        self._customer_feature_view = create_customer_daily_profile_feature_view(
            source
        )
        return self._customer_feature_view

    def driver_stats_feature_view(self) -> FeatureView:
        """Return the driver stats feature view, creating it on first use."""
        if self._driver_stats_feature_view is not None:
            return self._driver_stats_feature_view

        table_id = self.data_source_creator.get_prefixed_table_name(
            self.name, "driver_hourly"
        )
        source = self.data_source_creator.create_data_source(
            table_id,
            self.driver_df,
            event_timestamp_column="event_timestamp",
            created_timestamp_column="created",
        )
        self._driver_stats_feature_view = create_driver_hourly_stats_feature_view(
            source
        )
        return self._driver_stats_feature_view

    def orders_table(self) -> Optional[str]:
        """Return the orders table reference, staging the orders on first use.

        The reference is read from the created source's ``table_ref``
        attribute, or from ``table`` when ``table_ref`` is absent. If the
        source has neither, nothing is stored and ``None`` is returned.
        """
        if self._orders_table is not None:
            return self._orders_table

        table_id = self.data_source_creator.get_prefixed_table_name(
            self.name, "orders"
        )
        source = self.data_source_creator.create_data_source(
            table_id,
            self.orders_df,
            event_timestamp_column="event_timestamp",
            created_timestamp_column="created",
        )
        # The first attribute present on the source wins.
        for attr_name in ("table_ref", "table"):
            if hasattr(source, attr_name):
                self._orders_table = getattr(source, attr_name)
                break
        return self._orders_table
Exemplo n.º 9
0
def test_historical_features_from_redshift_sources(
        provider_type, infer_event_timestamp_col, capsys, full_feature_names):
    """Check historical feature retrieval against Redshift-backed sources.

    The orders, driver and customer frames are handed to
    ``aws_utils.temporarily_upload_df_to_redshift`` under uniquely prefixed
    table names. A feature store is then built for ``provider_type``
    ("local" or "aws") and historical features are retrieved twice: once
    with a SQL entity query and once with the in-memory orders frame. Each
    result is compared with the frame from ``get_expected_training_df`` and
    with its own Arrow export. Entity inputs whose ``customer_id`` column
    has been renamed must raise ``FeastEntityDFMissingColumnsError``.

    Any other ``provider_type`` raises ``Exception``. Once ``store.apply``
    has returned, the store is always torn down.
    """
    aws_region = "us-west-2"
    client = aws_utils.get_redshift_data_client(aws_region)
    s3 = aws_utils.get_s3_resource(aws_region)

    offline_store = RedshiftOfflineStoreConfig(
        cluster_id="feast-integration-tests",
        region=aws_region,
        user="******",
        database="feast",
        s3_staging_location=(
            "s3://feast-integration-tests/redshift/tests/ingestion"),
        iam_role="arn:aws:iam::402087665549:role/redshift_s3_access_role",
    )

    # Feature references requested by every retrieval in this test.
    feature_refs = (
        "driver_stats:conv_rate",
        "driver_stats:avg_daily_trips",
        "customer_profile:current_balance",
        "customer_profile:avg_passenger_count",
        "customer_profile:lifetime_trip_count",
    )

    def stage(table, frame):
        # Context manager from aws_utils for ``frame`` under ``table``; it is
        # only entered later, in the ``with`` statement below.
        return aws_utils.temporarily_upload_df_to_redshift(
            client,
            offline_store.cluster_id,
            offline_store.database,
            offline_store.user,
            s3,
            f"{offline_store.s3_staging_location}/copy/{table}.parquet",
            offline_store.iam_role,
            table,
            frame,
        )

    current_hour = datetime.now().replace(microsecond=0, second=0, minute=0)
    (
        customer_entities,
        driver_entities,
        end_date,
        orders_df,
        start_date,
    ) = generate_entities(current_hour, infer_event_timestamp_col)

    # Timestamp plus a random suffix keeps table names unique per run.
    redshift_table_prefix = (
        f"test_hist_retrieval_{int(time.time_ns())}_{random.randint(1000, 9999)}"
    )

    # Orders: staged table doubles as the SQL entity source.
    table_name = f"{redshift_table_prefix}_orders"
    entity_df_query = f"SELECT * FROM {table_name}"
    orders_context = stage(table_name, orders_df)

    # Driver hourly stats.
    driver_df = driver_data.create_driver_hourly_stats_df(
        driver_entities, start_date, end_date)
    driver_table_name = f"{redshift_table_prefix}_driver_hourly"
    driver_context = stage(driver_table_name, driver_df)

    # Customer daily profile.
    customer_df = driver_data.create_customer_daily_profile_df(
        customer_entities, start_date, end_date)
    customer_table_name = f"{redshift_table_prefix}_customer_profile"
    customer_context = stage(customer_table_name, customer_df)

    with orders_context, driver_context, customer_context, \
            TemporaryDirectory() as temp_dir:
        driver_fv = create_driver_hourly_stats_feature_view(
            RedshiftSource(
                table=driver_table_name,
                event_timestamp_column="event_timestamp",
                created_timestamp_column="created",
            ))
        customer_fv = create_customer_daily_profile_feature_view(
            RedshiftSource(
                table=customer_table_name,
                event_timestamp_column="event_timestamp",
                created_timestamp_column="created",
            ))

        driver = Entity(name="driver",
                        join_key="driver_id",
                        value_type=ValueType.INT64)
        customer = Entity(name="customer_id", value_type=ValueType.INT64)

        if provider_type == "local":
            repo_config = RepoConfig(
                registry=os.path.join(temp_dir, "registry.db"),
                project="default",
                provider="local",
                online_store=SqliteOnlineStoreConfig(
                    path=os.path.join(temp_dir, "online_store.db")),
                offline_store=offline_store,
            )
        elif provider_type == "aws":
            repo_config = RepoConfig(
                registry=os.path.join(temp_dir, "registry.db"),
                project="".join(
                    random.choices(string.ascii_uppercase + string.digits,
                                   k=10)),
                provider="aws",
                online_store=DynamoDBOnlineStoreConfig(region=aws_region),
                offline_store=offline_store,
            )
        else:
            raise Exception(
                "Invalid provider used as part of test configuration")
        store = FeatureStore(config=repo_config)

        store.apply([driver, customer, driver_fv, customer_fv])

        try:
            if DEFAULT_ENTITY_DF_EVENT_TIMESTAMP_COL in orders_df.columns:
                event_timestamp = DEFAULT_ENTITY_DF_EVENT_TIMESTAMP_COL
            else:
                event_timestamp = "e_ts"

            def ordered(frame):
                # Row order of retrieval results is not relied upon, so sort
                # on a fixed key and drop the old index before comparing.
                return frame.sort_values(by=[
                    event_timestamp, "order_id", "driver_id", "customer_id"
                ]).reset_index(drop=True)

            expected_df = get_expected_training_df(
                customer_df,
                customer_fv,
                driver_df,
                driver_fv,
                orders_df,
                event_timestamp,
                full_feature_names,
            )

            # --- Retrieval driven by a SQL entity query -------------------
            job_from_sql = store.get_historical_features(
                entity_df=entity_df_query,
                features=list(feature_refs),
                full_feature_names=full_feature_names,
            )

            start_time = datetime.utcnow()
            actual_df_from_sql_entities = job_from_sql.to_df()
            end_time = datetime.utcnow()
            with capsys.disabled():
                print(
                    f"\nTime to execute job_from_sql.to_df() = '{(end_time - start_time)}'"
                )

            assert sorted(expected_df.columns) == sorted(
                actual_df_from_sql_entities.columns)
            assert_frame_equal(
                ordered(expected_df),
                ordered(actual_df_from_sql_entities[expected_df.columns]),
                check_dtype=False,
            )

            table_from_sql_entities = job_from_sql.to_arrow()
            assert_frame_equal(
                ordered(actual_df_from_sql_entities),
                ordered(table_from_sql_entities.to_pandas()),
            )

            timestamp_column = (DEFAULT_ENTITY_DF_EVENT_TIMESTAMP_COL
                                if not infer_event_timestamp_col else "e_ts")

            # Rename the join key; this should now raise an error.
            entity_df_query_with_invalid_join_key = (
                f"select order_id, driver_id, customer_id as customer, "
                f"order_is_success, {timestamp_column} FROM {table_name}")
            invalid_sql_job = store.get_historical_features(
                entity_df=entity_df_query_with_invalid_join_key,
                features=list(feature_refs),
            )
            assertpy.assert_that(invalid_sql_job.to_df).raises(
                errors.FeastEntityDFMissingColumnsError).when_called_with()

            # --- Retrieval driven by an in-memory entity frame ------------
            job_from_df = store.get_historical_features(
                entity_df=orders_df,
                features=list(feature_refs),
                full_feature_names=full_feature_names,
            )

            # Rename the join key; this should now raise an error.
            orders_df_with_invalid_join_key = orders_df.rename(
                {"customer_id": "customer"}, axis="columns")
            invalid_df_job = store.get_historical_features(
                entity_df=orders_df_with_invalid_join_key,
                features=list(feature_refs),
            )
            assertpy.assert_that(invalid_df_job.to_df).raises(
                errors.FeastEntityDFMissingColumnsError).when_called_with()

            start_time = datetime.utcnow()
            actual_df_from_df_entities = job_from_df.to_df()
            end_time = datetime.utcnow()
            with capsys.disabled():
                print(
                    f"Time to execute job_from_df.to_df() = '{(end_time - start_time)}'\n"
                )

            assert sorted(expected_df.columns) == sorted(
                actual_df_from_df_entities.columns)
            assert_frame_equal(
                ordered(expected_df),
                ordered(actual_df_from_df_entities[expected_df.columns]),
                check_dtype=False,
            )

            table_from_df_entities = job_from_df.to_arrow()
            assert_frame_equal(
                ordered(actual_df_from_df_entities),
                ordered(table_from_df_entities.to_pandas()),
            )
        finally:
            store.teardown()