def benchmark_writes():
    """Benchmark ``online_write_batch`` for the GCP provider.

    Registers the driver hourly-stats feature view (plus a ``driver_id``
    entity) under a randomly named throwaway project, generates 14 days of
    stats for 100 drivers, converts them to protos and writes them all through
    the provider with a tqdm progress bar.

    The provider infrastructure for the throwaway project is torn down in a
    ``finally`` block, so it is cleaned up even if conversion or writing fails.
    """
    # Random suffix so concurrent/repeated runs don't collide on project name.
    project_id = "test" + "".join(
        random.choice(string.ascii_lowercase + string.digits) for _ in range(10)
    )
    with tempfile.TemporaryDirectory() as temp_dir:
        store = FeatureStore(
            config=RepoConfig(
                registry=os.path.join(temp_dir, "registry.db"),
                project=project_id,
                provider="gcp",
            )
        )

        # This is just to set data source to something, we're not reading from parquet source here.
        parquet_path = os.path.join(temp_dir, "data.parquet")

        driver = Entity(name="driver_id", value_type=ValueType.INT64)
        table = create_driver_hourly_stats_feature_view(
            create_driver_hourly_stats_source(parquet_path=parquet_path)
        )
        store.apply([table, driver])

        provider = store._get_provider()

        try:
            end_date = datetime.utcnow()
            start_date = end_date - timedelta(days=14)
            customers = list(range(100))
            data = create_driver_hourly_stats_df(customers, start_date, end_date)

            # Show the data for reference
            print(data)
            proto_data = _convert_arrow_to_proto(
                pa.Table.from_pandas(data), table, ["driver_id"]
            )

            # Write it
            with tqdm(total=len(proto_data)) as progress:
                provider.online_write_batch(
                    project=store.project,
                    table=table,
                    data=proto_data,
                    progress=progress.update,
                )
        finally:
            # Always release the infra created by store.apply(); otherwise a
            # failed run leaves resources behind for the random project.
            registry_tables = store.list_feature_views()
            registry_entities = store.list_entities()
            provider.teardown_infra(
                store.project, tables=registry_tables, entities=registry_entities
            )
def basic_rw_test(store: FeatureStore, view_name: str, feature_service_name: Optional[str] = None) -> None:
    """Provider-independent read/write checks for the online store.

    Intended to be called from provider-specific tests. Each step writes a
    single (lat, lon) row for driver 1 and reads it back. The read goes through
    the feature service named ``feature_service_name`` when one is given.
    Otherwise it goes directly through the provider's ``online_read``.
    """
    feature_view = store.get_feature_view(name=view_name)
    online_provider = store._get_provider()

    key = EntityKeyProto(join_keys=["driver"], entity_values=[ValueProto(int64_val=1)])

    def _write_then_verify(event_ts, created_ts, write, expect_read):
        """Write ``write`` with the given timestamps, then assert the read equals ``expect_read``."""
        lat_in, lon_in = write
        lat_out, lon_out = expect_read

        feature_values = {
            "lat": ValueProto(double_val=lat_in),
            "lon": ValueProto(string_val=lon_in),
        }
        online_provider.online_write_batch(
            config=store.config,
            table=feature_view,
            data=[(key, feature_values, event_ts, created_ts)],
            progress=None,
        )

        if not feature_service_name:
            # No feature service: read straight from the provider.
            rows = online_provider.online_read(
                config=store.config, table=feature_view, entity_keys=[key]
            )
            assert len(rows) == 1
            _, stored = rows[0]
            assert stored["lon"].string_val == lon_out
            assert abs(stored["lat"].double_val - lat_out) < 1e-6
            return

        service = store.get_feature_service(feature_service_name)
        response = store.get_online_features(
            features=service, entity_rows=[{"driver": 1}]
        ).to_dict()
        assert len(response["driver"]) == 1
        assert response["lon"][0] == lon_out
        assert abs(response["lat"][0] - lat_out) < 1e-6

    # 1. Basic test: write value, read it back
    t1 = datetime.utcnow()
    _write_then_verify(event_ts=t1, created_ts=t1, write=(1.1, "3.1"), expect_read=(1.1, "3.1"))

    # Note: This behavior has changed for performance. We should test that an older
    # value can't overwrite a newer value once we add the respective flag.
    # Values with an older event_ts currently overwrite newer ones.
    t2 = datetime.utcnow()
    _write_then_verify(
        event_ts=t1 - timedelta(hours=1),
        created_ts=t2,
        write=(-1000, "OLD"),
        expect_read=(-1000, "OLD"),
    )

    # Values with a newer event_ts should overwrite older ones.
    t3 = datetime.utcnow()
    one_hour_later = t1 + timedelta(hours=1)
    _write_then_verify(
        event_ts=one_hour_later,
        created_ts=t3,
        write=(1123, "NEWER"),
        expect_read=(1123, "NEWER"),
    )

    # Note: This behavior has changed for performance. We should test that an older
    # value can't overwrite a newer value once we add the respective flag.
    # created_ts would be the tie breaker; an older created_ts still overwrites today.
    _write_then_verify(
        event_ts=one_hour_later,
        created_ts=t3 - timedelta(hours=1),
        write=(54321, "I HAVE AN OLDER created_ts SO I LOSE"),
        expect_read=(54321, "I HAVE AN OLDER created_ts SO I LOSE"),
    )

    # created_ts as tie breaker: a newer created_ts overwrites.
    _write_then_verify(
        event_ts=one_hour_later,
        created_ts=t3 + timedelta(hours=1),
        write=(96864, "I HAVE A NEWER created_ts SO I WIN"),
        expect_read=(96864, "I HAVE A NEWER created_ts SO I WIN"),
    )
def basic_rw_test(store: FeatureStore, view_name: str, feature_service_name: Optional[str] = None) -> None:
    """Provider-independent read/write checks for the online store.

    Intended to be called from provider-specific tests. Writes (lat, lon) rows
    for ``driver_id`` 1 and reads each one back. The read goes through the
    feature service named ``feature_service_name`` when one is given.
    Otherwise it goes through the provider's ``online_read``.
    """
    feature_view = store.get_feature_view(name=view_name)
    online_provider = store._get_provider()

    key = EntityKeyProto(join_keys=["driver_id"], entity_values=[ValueProto(int64_val=1)])

    def _write_then_verify(event_ts, created_ts, write, expect_read):
        """Write ``write`` with the given timestamps, then assert the read equals ``expect_read``."""
        lat_in, lon_in = write
        lat_out, lon_out = expect_read

        feature_values = {
            "lat": ValueProto(double_val=lat_in),
            "lon": ValueProto(string_val=lon_in),
        }
        online_provider.online_write_batch(
            config=store.config,
            table=feature_view,
            data=[(key, feature_values, event_ts, created_ts)],
            progress=None,
        )

        if not feature_service_name:
            # No feature service: read straight from the provider.
            rows = online_provider.online_read(
                config=store.config, table=feature_view, entity_keys=[key]
            )
            assert len(rows) == 1
            _, stored = rows[0]
            assert stored["lon"].string_val == lon_out
            assert abs(stored["lat"].double_val - lat_out) < 1e-6
            return

        service = store.get_feature_service(feature_service_name)
        response = store.get_online_features(
            features=service, entity_rows=[{"driver_id": 1}]
        ).to_dict()
        assert len(response["driver_id"]) == 1
        assert response["lon"][0] == lon_out
        assert abs(response["lat"][0] - lat_out) < 1e-6

    # 1. Basic test: write value, read it back
    t1 = datetime.utcnow()
    _write_then_verify(event_ts=t1, created_ts=t1, write=(1.1, "3.1"), expect_read=(1.1, "3.1"))

    # Values with a newer event_ts should overwrite older ones.
    t3 = datetime.utcnow()
    _write_then_verify(
        event_ts=t1 + timedelta(hours=1),
        created_ts=t3,
        write=(1123, "NEWER"),
        expect_read=(1123, "NEWER"),
    )
def basic_rw_test(store: FeatureStore, view_name: str) -> None:
    """
    This is a provider-independent test suite for reading and writing from the online store, to
    be used by provider-specific tests.

    Every step writes a single (lat, lon) row for driver 1 through the provider and immediately
    reads it back with ``online_read``. The expected value encodes which write should "win",
    based on its event and created timestamps.
    """
    table = store.get_feature_view(name=view_name)
    provider = store._get_provider()

    entity_key = EntityKeyProto(entity_names=["driver"], entity_values=[ValueProto(int64_val=1)])

    def _driver_rw_test(event_ts, created_ts, write, expect_read):
        """Write ``write`` with the given timestamps, then assert the stored row equals ``expect_read``."""
        write_lat, write_lon = write
        expect_lat, expect_lon = expect_read
        provider.online_write_batch(
            project=store.project,
            table=table,
            data=[
                (
                    entity_key,
                    {
                        "lat": ValueProto(double_val=write_lat),
                        "lon": ValueProto(string_val=write_lon),
                    },
                    event_ts,
                    created_ts,
                )
            ],
            progress=None,
        )

        read_rows = provider.online_read(project=store.project, table=table, entity_keys=[entity_key])
        assert len(read_rows) == 1
        _, val = read_rows[0]
        assert val["lon"].string_val == expect_lon
        assert abs(val["lat"].double_val - expect_lat) < 1e-6

    # 1. Basic test: write value, read it back
    time_1 = datetime.utcnow()
    _driver_rw_test(event_ts=time_1, created_ts=time_1, write=(1.1, "3.1"), expect_read=(1.1, "3.1"))

    # Values with an older event_ts should not overwrite newer ones.
    time_2 = datetime.utcnow()
    _driver_rw_test(
        event_ts=time_1 - timedelta(hours=1),
        created_ts=time_2,
        write=(-1000, "OLD"),
        expect_read=(1.1, "3.1"),
    )

    # Values with a newer event_ts should overwrite older ones.
    time_3 = datetime.utcnow()
    _driver_rw_test(
        event_ts=time_1 + timedelta(hours=1),
        created_ts=time_3,
        write=(1123, "NEWER"),
        expect_read=(1123, "NEWER"),
    )

    # created_ts is used as a tie breaker: same event_ts but an OLDER created_ts, so no overwrite.
    _driver_rw_test(
        event_ts=time_1 + timedelta(hours=1),
        created_ts=time_3 - timedelta(hours=1),
        write=(54321, "I HAVE AN OLDER created_ts SO I LOSE"),
        expect_read=(1123, "NEWER"),
    )

    # created_ts is used as a tie breaker: same event_ts but a NEWER created_ts, so it overwrites.
    _driver_rw_test(
        event_ts=time_1 + timedelta(hours=1),
        created_ts=time_3 + timedelta(hours=1),
        write=(96864, "I HAVE A NEWER created_ts SO I WIN"),
        expect_read=(96864, "I HAVE A NEWER created_ts SO I WIN"),
    )
def test_bigquery_ingestion_correctness(self):
    """Materialize a small BigQuery table and check the newest row lands online.

    Loads three rows into a fresh BigQuery table. ``driver_id`` 1 has one row
    at ``ts - 2min`` and one at ``ts`` carrying a random ``checked_value``.
    ``driver_id`` 2 has one row at ``ts``. The test materializes the last
    5 minutes into the local online store and asserts that driver 1 reads back
    ``checked_value``.

    The per-run BigQuery table is deleted in a ``finally`` block, so repeated
    runs do not accumulate tables in the dataset.
    """
    # create dataset
    ts = pd.Timestamp.now(tz="UTC").round("ms")
    # random value so test doesn't still work if no values written to online store
    checked_value = random.random()
    data = {
        "id": [1, 2, 1],
        "value": [0.1, 0.2, checked_value],
        "ts_1": [ts - timedelta(minutes=2), ts, ts],
        "created_ts": [ts, ts, ts],
    }
    df = pd.DataFrame.from_dict(data)

    table_id = f"{self.gcp_project}.{self.bigquery_dataset}.correctness_{int(time.time())}"
    try:
        # load dataset into BigQuery
        job_config = bigquery.LoadJobConfig()
        job = self.client.load_table_from_dataframe(df, table_id, job_config=job_config)
        job.result()

        # create FeatureView
        fv = FeatureView(
            name="test_bq_correctness",
            entities=["driver_id"],
            features=[Feature("value", ValueType.FLOAT)],
            ttl=timedelta(minutes=5),
            input=BigQuerySource(
                event_timestamp_column="ts",
                table_ref=table_id,
                created_timestamp_column="created_ts",
                # Source columns are renamed to the names the feature view expects.
                field_mapping={"ts_1": "ts", "id": "driver_id"},
                date_partition_column="",
            ),
        )
        # NOTE(review): metadata.db / online_store.db are relative paths and are not
        # cleaned up here; consider a temporary directory if that matters.
        config = RepoConfig(
            metadata_store="./metadata.db",
            project="default",
            provider="gcp",
            online_store=OnlineStoreConfig(local=LocalOnlineStoreConfig("online_store.db")),
        )
        fs = FeatureStore(config=config)
        fs.apply([fv])

        # run materialize()
        fs.materialize(
            ["test_bq_correctness"],
            datetime.utcnow() - timedelta(minutes=5),
            datetime.utcnow() - timedelta(minutes=0),
        )

        # check result of materialize()
        entity_key = EntityKeyProto(
            entity_names=["driver_id"], entity_values=[ValueProto(int64_val=1)]
        )
        _, val = fs._get_provider().online_read("default", fv, entity_key)
        assert abs(val["value"].double_val - checked_value) < 1e-6
    finally:
        # Drop the per-run table; not_found_ok covers a failed/partial load.
        self.client.delete_table(table_id, not_found_ok=True)