예제 #1
0
def prep_local_fs_and_fv() -> Iterator[Tuple[FeatureStore, FeatureView]]:
    """Yield a local ``FeatureStore`` with a parquet-backed feature view applied.

    The frame returned by ``create_dataset()`` is written to a parquet file and
    exposed through a ``FileSource``. Its ``ts_1``/``id`` columns are mapped to
    ``ts``/``driver_id``, and it is turned into a feature view by
    ``get_feature_view``. A ``RepoConfig`` for the ``local`` provider is built
    with a registry and online store placed in fresh temporary directories,
    under a unique project name. The feature view is applied, and
    ``(feature_store, feature_view)`` is yielded.

    Every temporary file and directory, including the parquet file, is removed
    once the generator resumes or is closed.
    """
    # A temporary directory owns the parquet file. Closing a NamedTemporaryFile
    # deletes it, and a file written to that path afterwards is no longer
    # tracked by tempfile, so it would never be cleaned up.
    with tempfile.TemporaryDirectory() as parquet_dir_name:
        parquet_path = Path(parquet_dir_name) / "dataset.parquet"
        df = create_dataset()
        df.to_parquet(str(parquet_path))
        file_source = FileSource(
            file_format=ParquetFormat(),
            file_url=f"file://{parquet_path}",
            event_timestamp_column="ts",
            created_timestamp_column="created_ts",
            date_partition_column="",
            field_mapping={
                "ts_1": "ts",
                "id": "driver_id"
            },
        )
        fv = get_feature_view(file_source)
        with tempfile.TemporaryDirectory(
        ) as repo_dir_name, tempfile.TemporaryDirectory() as data_dir_name:
            config = RepoConfig(
                registry=str(Path(repo_dir_name) / "registry.db"),
                # A unique project name keeps concurrent/repeated runs apart.
                project=f"test_bq_correctness_{uuid.uuid4().hex}",
                provider="local",
                online_store=OnlineStoreConfig(local=LocalOnlineStoreConfig(
                    path=str(Path(data_dir_name) / "online_store.db"))),
            )
            fs = FeatureStore(config=config)
            fs.apply([fv])

            yield fs, fv
예제 #2
0
    def __init__(
        self, repo_path: Optional[str] = None, config: Optional[RepoConfig] = None,
    ):
        """Create a feature store from an explicit config, a repo path, or defaults.

        Args:
            repo_path: Directory whose repo config is loaded via
                ``load_repo_config`` when ``config`` is not given.
            config: A ready-made ``RepoConfig``; takes priority when given.

        When neither argument is supplied, a default config is used: registry
        ``./registry.db``, project ``default``, provider ``local``, and a local
        online store at ``online_store.db``.

        Raises:
            ValueError: if both ``repo_path`` and ``config`` are supplied.
        """
        self.repo_path = repo_path
        path_given = repo_path is not None
        config_given = config is not None
        if path_given and config_given:
            raise ValueError("You cannot specify both repo_path and config")

        if config_given:
            resolved_config = config
        elif path_given:
            resolved_config = load_repo_config(Path(repo_path))
        else:
            resolved_config = RepoConfig(
                registry="./registry.db",
                project="default",
                provider="local",
                online_store=OnlineStoreConfig(
                    local=LocalOnlineStoreConfig(path="online_store.db")
                ),
            )
        self.config = resolved_config

        # The registry's location and cache lifetime both come from the config.
        reg_cfg = self.config.get_registry_config()
        registry_location = reg_cfg.path
        registry_ttl = timedelta(seconds=reg_cfg.cache_ttl_seconds)
        self._registry = Registry(
            registry_path=registry_location, cache_ttl=registry_ttl,
        )
        self._tele = Telemetry()
예제 #3
0
 def feature_store_with_local_registry(self):
     """Return a ``FeatureStore`` whose registry and online store are temp files.

     ``mkstemp`` creates each file and returns an open OS-level descriptor plus
     its path. Only the paths are needed, so both descriptors are closed right
     away rather than leaked. The files themselves are not deleted here.
     NOTE(review): presumably cleanup is left to the OS or test harness; confirm.
     """
     import os

     registry_fd, registry_path = mkstemp()
     os.close(registry_fd)
     online_store_fd, online_store_path = mkstemp()
     os.close(online_store_fd)
     return FeatureStore(config=RepoConfig(
         metadata_store=registry_path,
         project="default",
         provider="local",
         online_store=OnlineStoreConfig(local=LocalOnlineStoreConfig(
             path=online_store_path)),
     ))
예제 #4
0
 def __init__(
     self,
     repo_path: Optional[str] = None,
     config: Optional[RepoConfig] = None,
 ):
     """Resolve this store's configuration.

     Priority: an explicit ``config``, then the config loaded from
     ``repo_path`` via ``load_repo_config``, then a built-in default with
     metadata store ``./metadata.db``, project ``default``, provider
     ``local`` and a local online store at ``online_store.db``.

     Raises:
         ValueError: if both ``repo_path`` and ``config`` are supplied.
     """
     if repo_path is not None and config is not None:
         raise ValueError("You cannot specify both repo_path and config")

     if config is None:
         config = (
             load_repo_config(Path(repo_path))
             if repo_path is not None
             else RepoConfig(
                 metadata_store="./metadata.db",
                 project="default",
                 provider="local",
                 online_store=OnlineStoreConfig(
                     local=LocalOnlineStoreConfig(path="online_store.db")),
             )
         )
     self.config = config
예제 #5
0
def test_historical_features_from_bigquery_sources():
    """Check BigQuery historical retrieval against the locally computed expectation.

    Driver and customer feature data are staged in a fresh BigQuery dataset and
    registered with a ``gcp`` provider store. Entity rows are then supplied
    twice: once as a BigQuery SQL query and once as the pandas ``orders_df``.
    Each retrieval result must equal ``get_expected_training_df`` (dtypes
    ignored) once both frames are sorted the same way.
    """
    hour_start = datetime.now().replace(microsecond=0, second=0, minute=0)
    (
        customer_entities,
        driver_entities,
        end_date,
        orders_df,
        start_date,
    ) = generate_entities(hour_start)

    requested_features = [
        "driver_stats:conv_rate",
        "driver_stats:avg_daily_trips",
        "customer_profile:current_balance",
        "customer_profile:avg_passenger_count",
        "customer_profile:lifetime_trip_count",
    ]
    ordering = [
        ENTITY_DF_EVENT_TIMESTAMP_COL,
        "order_id",
        "driver_id",
        "customer_id",
    ]

    def canonical(frame):
        # Retrieval does not guarantee row order; compare in a fixed order.
        return frame.sort_values(by=ordering).reset_index(drop=True)

    dataset_name = f"test_hist_retrieval_{int(time.time())}"

    with BigQueryDataSet(dataset_name), TemporaryDirectory() as temp_dir:
        project_id = bigquery.Client().project

        # Entity rows, exposed to the store as a SQL query.
        orders_table = f"{dataset_name}.orders"
        stage_orders_bigquery(orders_df, orders_table)
        orders_query = f"SELECT * FROM {project_id}.{orders_table}"

        # Driver feature view.
        driver_df = driver_data.create_driver_hourly_stats_df(
            driver_entities, start_date, end_date)
        driver_table = f"{project_id}.{dataset_name}.driver_hourly"
        stage_driver_hourly_stats_bigquery_source(driver_df, driver_table)
        driver_fv = create_driver_hourly_stats_feature_view(
            BigQuerySource(
                table_ref=driver_table,
                event_timestamp_column="datetime",
                created_timestamp_column="created",
            ))

        # Customer feature view.
        customer_df = driver_data.create_customer_daily_profile_df(
            customer_entities, start_date, end_date)
        customer_table = f"{project_id}.{dataset_name}.customer_profile"
        stage_customer_daily_profile_bigquery_source(customer_df,
                                                     customer_table)
        customer_fv = create_customer_daily_profile_feature_view(
            BigQuerySource(
                table_ref=customer_table,
                event_timestamp_column="datetime",
                created_timestamp_column="created",
            ))

        driver = Entity(name="driver", value_type=ValueType.INT64)
        customer = Entity(name="customer", value_type=ValueType.INT64)

        store_config = RepoConfig(
            registry=os.path.join(temp_dir, "registry.db"),
            project="default",
            provider="gcp",
            online_store=OnlineStoreConfig(local=LocalOnlineStoreConfig(
                path=os.path.join(temp_dir, "online_store.db"), )),
        )
        store = FeatureStore(config=store_config)
        store.apply([driver, customer, driver_fv, customer_fv])

        expected_df = get_expected_training_df(
            customer_df,
            customer_fv,
            driver_df,
            driver_fv,
            orders_df,
        )

        # First the SQL-query entity source, then the DataFrame one.
        for entity_source in (orders_query, orders_df):
            retrieval = store.get_historical_features(
                entity_df=entity_source,
                feature_refs=requested_features,
            )
            retrieved_df = retrieval.to_df()
            assert_frame_equal(
                canonical(expected_df),
                canonical(retrieved_df),
                check_dtype=False,
            )
예제 #6
0
def test_historical_features_from_parquet_sources():
    """Check parquet-backed historical retrieval against the expected frame.

    Driver and customer feature data are staged as parquet sources in a
    temporary directory and registered with a ``local`` provider store. The
    frame retrieved for ``orders_df`` must equal ``get_expected_training_df``
    once both are sorted the same way.
    """
    hour_start = datetime.now().replace(microsecond=0, second=0, minute=0)
    (
        customer_entities,
        driver_entities,
        end_date,
        orders_df,
        start_date,
    ) = generate_entities(hour_start)

    requested_features = [
        "driver_stats:conv_rate",
        "driver_stats:avg_daily_trips",
        "customer_profile:current_balance",
        "customer_profile:avg_passenger_count",
        "customer_profile:lifetime_trip_count",
    ]
    ordering = [
        ENTITY_DF_EVENT_TIMESTAMP_COL,
        "order_id",
        "driver_id",
        "customer_id",
    ]

    with TemporaryDirectory() as temp_dir:
        driver_df = driver_data.create_driver_hourly_stats_df(
            driver_entities, start_date, end_date)
        driver_fv = create_driver_hourly_stats_feature_view(
            stage_driver_hourly_stats_parquet_source(temp_dir, driver_df))

        customer_df = driver_data.create_customer_daily_profile_df(
            customer_entities, start_date, end_date)
        customer_fv = create_customer_daily_profile_feature_view(
            stage_customer_daily_profile_parquet_source(temp_dir, customer_df))

        driver = Entity(name="driver", value_type=ValueType.INT64)
        customer = Entity(name="customer", value_type=ValueType.INT64)

        store_config = RepoConfig(
            registry=os.path.join(temp_dir, "registry.db"),
            project="default",
            provider="local",
            online_store=OnlineStoreConfig(local=LocalOnlineStoreConfig(
                path=os.path.join(temp_dir, "online_store.db"))),
        )
        store = FeatureStore(config=store_config)
        store.apply([driver, customer, driver_fv, customer_fv])

        retrieved_df = store.get_historical_features(
            entity_df=orders_df,
            feature_refs=requested_features,
        ).to_df()
        expected_df = get_expected_training_df(
            customer_df,
            customer_fv,
            driver_df,
            driver_fv,
            orders_df,
        )

        # Retrieval does not guarantee row order; compare in a fixed order.
        expected_sorted = expected_df.sort_values(by=ordering).reset_index(drop=True)
        retrieved_sorted = retrieved_df.sort_values(by=ordering).reset_index(drop=True)
        assert_frame_equal(expected_sorted, retrieved_sorted)
예제 #7
0
    def test_bigquery_ingestion_correctness(self):
        """Materialize a BigQuery-backed feature view and verify the online value.

        Three rows are loaded into a new BigQuery table. Driver 1 has an older
        row (value 0.1) and a newer row whose value is random. The last five
        minutes are materialized, and the online value read for driver 1 must
        equal the newer random value.

        The metadata store and local online store are placed in a temporary
        directory so the test neither litters the working directory nor reuses
        state from a previous run.
        """
        import os
        import tempfile

        # create dataset
        ts = pd.Timestamp.now(tz="UTC").round("ms")
        # Random value so the test cannot pass if nothing is written to the
        # online store.
        checked_value = random.random()
        data = {
            "id": [1, 2, 1],
            "value": [0.1, 0.2, checked_value],
            "ts_1": [ts - timedelta(minutes=2), ts, ts],
            "created_ts": [ts, ts, ts],
        }
        df = pd.DataFrame.from_dict(data)

        # load dataset into BigQuery
        job_config = bigquery.LoadJobConfig()
        table_id = (
            f"{self.gcp_project}.{self.bigquery_dataset}.correctness_{int(time.time())}"
        )
        job = self.client.load_table_from_dataframe(df,
                                                    table_id,
                                                    job_config=job_config)
        job.result()

        # create FeatureView
        fv = FeatureView(
            name="test_bq_correctness",
            entities=["driver_id"],
            features=[Feature("value", ValueType.FLOAT)],
            ttl=timedelta(minutes=5),
            input=BigQuerySource(
                event_timestamp_column="ts",
                table_ref=table_id,
                created_timestamp_column="created_ts",
                field_mapping={
                    "ts_1": "ts",
                    "id": "driver_id"
                },
                date_partition_column="",
            ),
        )

        with tempfile.TemporaryDirectory() as repo_dir:
            config = RepoConfig(
                metadata_store=os.path.join(repo_dir, "metadata.db"),
                project="default",
                provider="gcp",
                online_store=OnlineStoreConfig(
                    local=LocalOnlineStoreConfig(
                        path=os.path.join(repo_dir, "online_store.db"))),
            )
            fs = FeatureStore(config=config)
            fs.apply([fv])

            # run materialize()
            fs.materialize(
                ["test_bq_correctness"],
                datetime.utcnow() - timedelta(minutes=5),
                datetime.utcnow() - timedelta(minutes=0),
            )

            # check result of materialize()
            entity_key = EntityKeyProto(entity_names=["driver_id"],
                                        entity_values=[ValueProto(int64_val=1)])
            t, val = fs._get_provider().online_read("default", fv, entity_key)
            assert abs(val["value"].double_val - checked_value) < 1e-6