def _convert_arrow_to_proto(
    table: pyarrow.Table, feature_view: FeatureView
) -> List[Tuple[EntityKeyProto, Dict[str, ValueProto], datetime, Optional[datetime]]]:
    rows_to_write = []
    for row in zip(*table.to_pydict().values()):
        entity_key = EntityKeyProto()
        for entity_name in feature_view.entities:
            entity_key.entity_names.append(entity_name)
            idx = table.column_names.index(entity_name)
            value = python_value_to_proto_value(row[idx])
            entity_key.entity_values.append(value)
        feature_dict = {}
        for feature in feature_view.features:
            idx = table.column_names.index(feature.name)
            value = python_value_to_proto_value(row[idx])
            feature_dict[feature.name] = value
        event_timestamp_idx = table.column_names.index(
            feature_view.input.event_timestamp_column
        )
        event_timestamp = row[event_timestamp_idx]
        if feature_view.input.created_timestamp_column is not None:
            created_timestamp_idx = table.column_names.index(
                feature_view.input.created_timestamp_column
            )
            created_timestamp = row[created_timestamp_idx]
        else:
            created_timestamp = None
        rows_to_write.append(
            (entity_key, feature_dict, event_timestamp, created_timestamp)
        )
    return rows_to_write
def _get_table_entity_keys(
    table: FeatureView,
    entity_keys: List[EntityKeyProto],
    join_key_map: Dict[str, str],
) -> List[EntityKeyProto]:
    table_join_keys = [join_key_map[entity_name] for entity_name in table.entities]
    required_entities = OrderedDict.fromkeys(sorted(table_join_keys))
    entity_key_protos = []
    for entity_key in entity_keys:
        required_entities_to_values = required_entities.copy()
        for i in range(len(entity_key.join_keys)):
            entity_name = entity_key.join_keys[i]
            entity_value = entity_key.entity_values[i]
            if entity_name in required_entities_to_values:
                if required_entities_to_values[entity_name] is not None:
                    raise ValueError(
                        f"Duplicate entity keys detected. Table {table.name} expects {table_join_keys}. The entity "
                        f"{entity_name} was provided at least twice"
                    )
                required_entities_to_values[entity_name] = entity_value
        entity_names = []
        entity_values = []
        for entity_name, entity_value in required_entities_to_values.items():
            if entity_value is None:
                raise ValueError(
                    f"Table {table.name} expects entity field {table_join_keys}. No entity value was found for "
                    f"{entity_name}"
                )
            entity_names.append(entity_name)
            entity_values.append(entity_value)
        entity_key_protos.append(
            EntityKeyProto(join_keys=entity_names, entity_values=entity_values)
        )
    return entity_key_protos
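# A minimal usage sketch, not from the source: normalize a caller-provided entity
# key against a composite-entity view. `customer_driver_combined_fv` is a
# hypothetical FeatureView with entities ["customer", "driver"]; the proto import
# paths below are those of recent Feast releases and are an assumption here.
from feast.protos.feast.types.EntityKey_pb2 import EntityKey as EntityKeyProto
from feast.protos.feast.types.Value_pb2 import Value as ValueProto

entity_key = EntityKeyProto(
    join_keys=["driver", "customer"],
    entity_values=[ValueProto(int64_val=1), ValueProto(int64_val=5)],
)
normalized = _get_table_entity_keys(
    table=customer_driver_combined_fv,
    entity_keys=[entity_key],
    join_key_map={"customer": "customer", "driver": "driver"},
)
# Join keys come back in sorted order ("customer", "driver") regardless of the
# order they were supplied in, because the required-entities dict is built from
# sorted(table_join_keys).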
def test_dynamodb_online_store_online_read_unknown_entity(
    repo_config, dynamodb_online_store
):
    """Test DynamoDBOnlineStore online_read method with an unknown entity."""
    n_samples = 2
    _create_test_table(PROJECT, f"{TABLE_NAME}_unknown_entity_{n_samples}", REGION)
    data = _create_n_customer_test_samples(n=n_samples)
    _insert_data_test_table(
        data, PROJECT, f"{TABLE_NAME}_unknown_entity_{n_samples}", REGION
    )

    entity_keys, features, *rest = zip(*data)
    # Insert a nonsensical entity to search for
    entity_keys = list(entity_keys)
    features = list(features)

    # Have the unknown entity be at the beginning, middle, and end of the list of entities.
    for pos in range(len(entity_keys)):
        entity_keys_with_unknown = deepcopy(entity_keys)
        entity_keys_with_unknown.insert(
            pos,
            EntityKeyProto(
                join_keys=["customer"], entity_values=[ValueProto(string_val="12359")]
            ),
        )
        features_with_none = deepcopy(features)
        features_with_none.insert(pos, None)

        returned_items = dynamodb_online_store.online_read(
            config=repo_config,
            table=MockFeatureView(name=f"{TABLE_NAME}_unknown_entity_{n_samples}"),
            entity_keys=entity_keys_with_unknown,
        )
        assert len(returned_items) == len(entity_keys_with_unknown)
        # The order should match the original entity key order
        assert [item[1] for item in returned_items] == list(features_with_none)
        # The unknown entity should produce an empty (None, None) result
        assert returned_items[pos] == (None, None)
def _convert_arrow_to_proto(
    table: Union[pyarrow.Table, pyarrow.RecordBatch],
    feature_view: FeatureView,
    join_keys: Dict[str, ValueType],
) -> List[Tuple[EntityKeyProto, Dict[str, ValueProto], datetime, Optional[datetime]]]:
    # Avoid ChunkedArrays, which guarantees `zero_copy_only` is available.
    if isinstance(table, pyarrow.Table):
        table = table.to_batches()[0]

    columns = [
        (field.name, field.dtype.to_value_type()) for field in feature_view.schema
    ] + list(join_keys.items())

    proto_values_by_column = {
        column: python_values_to_proto_values(
            table.column(column).to_numpy(zero_copy_only=False), value_type
        )
        for column, value_type in columns
    }

    # Iterating the join_keys dict yields the join key names for the proto's
    # repeated string field.
    entity_keys = [
        EntityKeyProto(
            join_keys=join_keys,
            entity_values=[proto_values_by_column[k][idx] for k in join_keys],
        )
        for idx in range(table.num_rows)
    ]

    # Serialize the features per row
    feature_dict = {
        feature.name: proto_values_by_column[feature.name]
        for feature in feature_view.features
    }
    features = [dict(zip(feature_dict, vals)) for vals in zip(*feature_dict.values())]

    # Convert event_timestamps
    event_timestamps = [
        _coerce_datetime(val)
        for val in pandas.to_datetime(
            table.column(feature_view.batch_source.timestamp_field).to_numpy(
                zero_copy_only=False
            )
        )
    ]

    # Convert created_timestamps if they exist
    if feature_view.batch_source.created_timestamp_column:
        created_timestamps = [
            _coerce_datetime(val)
            for val in pandas.to_datetime(
                table.column(
                    feature_view.batch_source.created_timestamp_column
                ).to_numpy(zero_copy_only=False)
            )
        ]
    else:
        created_timestamps = [None] * table.num_rows

    return list(zip(entity_keys, features, event_timestamps, created_timestamps))
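# A usage sketch with assumed names, not from the source: here `join_keys` maps
# join key name -> Feast ValueType, so entity columns can be converted in bulk
# without consulting the schema. `driver_stats_fv` is a hypothetical FeatureView
# whose schema lists "conv_rate" and whose batch source has timestamp_field
# "event_ts"; the module helpers above are assumed to be in scope.
import pyarrow
from datetime import datetime
from feast import ValueType

batch = pyarrow.Table.from_pydict(
    {
        "driver": [1001, 1002],
        "conv_rate": [0.5, 0.7],
        "event_ts": [datetime(2022, 1, 1), datetime(2022, 1, 2)],
    }
)
rows = _convert_arrow_to_proto(
    batch, driver_stats_fv, join_keys={"driver": ValueType.INT64}
)
# Each row is (EntityKeyProto, {feature: ValueProto}, event_ts, created_ts or None).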
def _create_n_customer_test_samples(n=10):
    return [
        (
            EntityKeyProto(
                join_keys=["customer"], entity_values=[ValueProto(string_val=str(i))]
            ),
            {
                "avg_orders_day": ValueProto(float_val=1.0),
                "name": ValueProto(string_val="John"),
                "age": ValueProto(int64_val=3),
            },
            datetime.utcnow(),
            None,
        )
        for i in range(n)
    ]
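# A short sketch of how these samples are shaped: each is an
# (entity_key, features, event_ts, created_ts) tuple, the same row shape
# _convert_arrow_to_proto produces, so they can be fed straight into an online
# store write path in tests.
data = _create_n_customer_test_samples(n=3)
entity_keys, features, event_timestamps, created_timestamps = zip(*data)
assert len(entity_keys) == 3 and created_timestamps == (None, None, None)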
def _convert_arrow_to_proto(
    table: pyarrow.Table,
    feature_view: FeatureView,
    join_keys: List[str],
) -> List[Tuple[EntityKeyProto, Dict[str, ValueProto], datetime, Optional[datetime]]]:
    rows_to_write = []

    def _coerce_datetime(ts):
        """
        Depending on the underlying time resolution, arrow to_pydict() sometimes
        returns a pandas Timestamp (nanosecond resolution) and sometimes a
        standard Python datetime (microsecond resolution). While pandas Timestamp
        is a subclass of datetime, it doesn't always behave the same way. We
        convert it to a plain datetime so that downstream consumers don't have to
        deal with these quirks.
        """
        if isinstance(ts, pandas.Timestamp):
            return ts.to_pydatetime()
        else:
            return ts

    column_names_idx = {k: i for i, k in enumerate(table.column_names)}
    for row in zip(*table.to_pydict().values()):
        entity_key = EntityKeyProto()
        for join_key in join_keys:
            entity_key.join_keys.append(join_key)
            idx = column_names_idx[join_key]
            value = python_value_to_proto_value(row[idx])
            entity_key.entity_values.append(value)
        feature_dict = {}
        for feature in feature_view.features:
            idx = column_names_idx[feature.name]
            value = python_value_to_proto_value(row[idx], feature.dtype)
            feature_dict[feature.name] = value
        event_timestamp_idx = column_names_idx[
            feature_view.batch_source.event_timestamp_column
        ]
        event_timestamp = _coerce_datetime(row[event_timestamp_idx])

        if feature_view.batch_source.created_timestamp_column:
            created_timestamp_idx = column_names_idx[
                feature_view.batch_source.created_timestamp_column
            ]
            created_timestamp = _coerce_datetime(row[created_timestamp_idx])
        else:
            created_timestamp = None

        rows_to_write.append(
            (entity_key, feature_dict, event_timestamp, created_timestamp)
        )
    return rows_to_write
def _convert_arrow_to_proto(
    table: Union[pyarrow.Table, pyarrow.RecordBatch],
    feature_view: FeatureView,
    join_keys: List[str],
) -> List[Tuple[EntityKeyProto, Dict[str, ValueProto], datetime, Optional[datetime]]]:
    # Handle join keys
    join_key_values = {k: table.column(k).to_pylist() for k in join_keys}
    entity_keys = [
        EntityKeyProto(
            join_keys=join_keys,
            entity_values=[
                python_value_to_proto_value(join_key_values[k][idx]) for k in join_keys
            ],
        )
        for idx in range(table.num_rows)
    ]

    # Serialize the features per row
    feature_dict = {
        feature.name: [
            python_value_to_proto_value(val, feature.dtype)
            for val in table.column(feature.name).to_pylist()
        ]
        for feature in feature_view.features
    }
    features = [dict(zip(feature_dict, vals)) for vals in zip(*feature_dict.values())]

    # Convert event_timestamps
    event_timestamps = [
        _coerce_datetime(val)
        for val in table.column(
            feature_view.batch_source.event_timestamp_column
        ).to_pylist()
    ]

    # Convert created_timestamps if they exist
    if feature_view.batch_source.created_timestamp_column:
        created_timestamps = [
            _coerce_datetime(val)
            for val in table.column(
                feature_view.batch_source.created_timestamp_column
            ).to_pylist()
        ]
    else:
        created_timestamps = [None] * table.num_rows

    return list(zip(entity_keys, features, event_timestamps, created_timestamps))
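# A usage sketch with assumed names, not from the source: this variant takes
# `join_keys` as a plain list of strings and converts values row-by-row via
# to_pylist(). `driver_stats_fv` is again a hypothetical FeatureView whose batch
# source has event_timestamp_column "event_ts".
import pyarrow
from datetime import datetime

batch = pyarrow.Table.from_pydict(
    {
        "driver": [1001],
        "conv_rate": [0.5],
        "event_ts": [datetime(2022, 1, 1)],
    }
)
rows = _convert_arrow_to_proto(batch, driver_stats_fv, join_keys=["driver"])
entity_key, feature_values, event_ts, created_ts = rows[0]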
def teardown_infra(
    self,
    project: str,
    tables: Sequence[Union[FeatureTable, FeatureView]],
    entities: Sequence[Entity],
) -> None:
    # According to repo_operations.py, we can delete the whole project.
    client = self._get_client()

    tables_join_keys = [[e for e in t.entities] for t in tables]
    for table_join_keys in tables_join_keys:
        redis_key_bin = _redis_key(project, EntityKeyProto(join_keys=table_join_keys))
        keys = [k for k in client.scan_iter(match=f"{redis_key_bin}*", count=100)]
        if keys:
            client.unlink(*keys)
def basic_rw_test(
    store: FeatureStore, view_name: str, feature_service_name: Optional[str] = None
) -> None:
    """
    This is a provider-independent test suite for reading and writing from the
    online store, to be used by provider-specific tests.
    """
    table = store.get_feature_view(name=view_name)
    provider = store._get_provider()

    entity_key = EntityKeyProto(
        join_keys=["driver"], entity_values=[ValueProto(int64_val=1)]
    )

    def _driver_rw_test(event_ts, created_ts, write, expect_read):
        """A helper function to write values and read them back"""
        write_lat, write_lon = write
        expect_lat, expect_lon = expect_read
        provider.online_write_batch(
            config=store.config,
            table=table,
            data=[
                (
                    entity_key,
                    {
                        "lat": ValueProto(double_val=write_lat),
                        "lon": ValueProto(string_val=write_lon),
                    },
                    event_ts,
                    created_ts,
                )
            ],
            progress=None,
        )

        if feature_service_name:
            entity_dict = {"driver": 1}
            feature_service = store.get_feature_service(feature_service_name)
            features = store.get_online_features(
                features=feature_service, entity_rows=[entity_dict]
            ).to_dict()
            assert len(features["driver"]) == 1
            assert features["lon"][0] == expect_lon
            assert abs(features["lat"][0] - expect_lat) < 1e-6
        else:
            read_rows = provider.online_read(
                config=store.config, table=table, entity_keys=[entity_key]
            )
            assert len(read_rows) == 1
            _, val = read_rows[0]
            assert val["lon"].string_val == expect_lon
            assert abs(val["lat"].double_val - expect_lat) < 1e-6

    """ 1. Basic test: write value, read it back """
    time_1 = datetime.utcnow()
    _driver_rw_test(
        event_ts=time_1, created_ts=time_1, write=(1.1, "3.1"), expect_read=(1.1, "3.1")
    )

    # Note: This behavior has changed for performance. We should test that an older
    # value can't overwrite a newer value once we add the respective flag.
    """ Values with an older event_ts now overwrite newer ones """
    time_2 = datetime.utcnow()
    _driver_rw_test(
        event_ts=time_1 - timedelta(hours=1),
        created_ts=time_2,
        write=(-1000, "OLD"),
        expect_read=(-1000, "OLD"),
    )

    """ Values with a new event_ts should overwrite older ones """
    time_3 = datetime.utcnow()
    _driver_rw_test(
        event_ts=time_1 + timedelta(hours=1),
        created_ts=time_3,
        write=(1123, "NEWER"),
        expect_read=(1123, "NEWER"),
    )

    # Note: This behavior has changed for performance. We should test that an older
    # value can't overwrite a newer value once we add the respective flag.
    """ created_ts would be used as a tie breaker; an older created_ts still overwrites here """
    _driver_rw_test(
        event_ts=time_1 + timedelta(hours=1),
        created_ts=time_3 - timedelta(hours=1),
        write=(54321, "I HAVE AN OLDER created_ts SO I LOSE"),
        expect_read=(54321, "I HAVE AN OLDER created_ts SO I LOSE"),
    )

    """ created_ts is used as a tie breaker; a newer created_ts here should overwrite """
    _driver_rw_test(
        event_ts=time_1 + timedelta(hours=1),
        created_ts=time_3 + timedelta(hours=1),
        write=(96864, "I HAVE A NEWER created_ts SO I WIN"),
        expect_read=(96864, "I HAVE A NEWER created_ts SO I WIN"),
    )
def _entity_row_to_key(row: GetOnlineFeaturesRequestV2.EntityRow) -> EntityKeyProto:
    names, values = zip(*row.fields.items())
    return EntityKeyProto(join_keys=names, entity_values=values)  # type: ignore
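# A usage sketch; the import paths below follow recent Feast releases and are an
# assumption here. EntityRow.fields is a map<string, Value>, so the dict items
# unzip directly into the proto's parallel join_keys / entity_values lists.
from feast.protos.feast.serving.ServingService_pb2 import GetOnlineFeaturesRequestV2
from feast.protos.feast.types.Value_pb2 import Value as ValueProto

row = GetOnlineFeaturesRequestV2.EntityRow(fields={"driver": ValueProto(int64_val=1)})
key = _entity_row_to_key(row)  # EntityKeyProto with join_keys ("driver",)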
def basic_rw_test(
    store: FeatureStore, view_name: str, feature_service_name: Optional[str] = None
) -> None:
    """
    This is a provider-independent test suite for reading and writing from the
    online store, to be used by provider-specific tests.
    """
    table = store.get_feature_view(name=view_name)
    provider = store._get_provider()

    entity_key = EntityKeyProto(
        join_keys=["driver_id"], entity_values=[ValueProto(int64_val=1)]
    )

    def _driver_rw_test(event_ts, created_ts, write, expect_read):
        """A helper function to write values and read them back"""
        write_lat, write_lon = write
        expect_lat, expect_lon = expect_read
        provider.online_write_batch(
            config=store.config,
            table=table,
            data=[
                (
                    entity_key,
                    {
                        "lat": ValueProto(double_val=write_lat),
                        "lon": ValueProto(string_val=write_lon),
                    },
                    event_ts,
                    created_ts,
                )
            ],
            progress=None,
        )

        if feature_service_name:
            entity_dict = {"driver_id": 1}
            feature_service = store.get_feature_service(feature_service_name)
            features = store.get_online_features(
                features=feature_service, entity_rows=[entity_dict]
            ).to_dict()
            assert len(features["driver_id"]) == 1
            assert features["lon"][0] == expect_lon
            assert abs(features["lat"][0] - expect_lat) < 1e-6
        else:
            read_rows = provider.online_read(
                config=store.config, table=table, entity_keys=[entity_key]
            )
            assert len(read_rows) == 1
            _, val = read_rows[0]
            assert val["lon"].string_val == expect_lon
            assert abs(val["lat"].double_val - expect_lat) < 1e-6

    """ 1. Basic test: write value, read it back """
    time_1 = datetime.utcnow()
    _driver_rw_test(
        event_ts=time_1, created_ts=time_1, write=(1.1, "3.1"), expect_read=(1.1, "3.1")
    )

    """ Values with a new event_ts should overwrite older ones """
    time_3 = datetime.utcnow()
    _driver_rw_test(
        event_ts=time_1 + timedelta(hours=1),
        created_ts=time_3,
        write=(1123, "NEWER"),
        expect_read=(1123, "NEWER"),
    )
def test_online() -> None:
    """
    Test reading from the online store in local mode.
    """
    runner = CliRunner()
    with runner.local_repo(get_example_repo("example_feature_repo_1.py")) as store:
        # Write some data to two tables
        driver_locations_fv = store.get_feature_view(name="driver_locations")
        customer_profile_fv = store.get_feature_view(name="customer_profile")
        customer_driver_combined_fv = store.get_feature_view(
            name="customer_driver_combined"
        )

        provider = store._get_provider()

        driver_key = EntityKeyProto(
            join_keys=["driver"], entity_values=[ValueProto(int64_val=1)]
        )
        provider.online_write_batch(
            project=store.config.project,
            table=driver_locations_fv,
            data=[
                (
                    driver_key,
                    {
                        "lat": ValueProto(double_val=0.1),
                        "lon": ValueProto(string_val="1.0"),
                    },
                    datetime.utcnow(),
                    datetime.utcnow(),
                )
            ],
            progress=None,
        )

        customer_key = EntityKeyProto(
            join_keys=["customer"], entity_values=[ValueProto(int64_val=5)]
        )
        provider.online_write_batch(
            project=store.config.project,
            table=customer_profile_fv,
            data=[
                (
                    customer_key,
                    {
                        "avg_orders_day": ValueProto(float_val=1.0),
                        "name": ValueProto(string_val="John"),
                        "age": ValueProto(int64_val=3),
                    },
                    datetime.utcnow(),
                    datetime.utcnow(),
                )
            ],
            progress=None,
        )

        customer_key = EntityKeyProto(
            join_keys=["customer", "driver"],
            entity_values=[ValueProto(int64_val=5), ValueProto(int64_val=1)],
        )
        provider.online_write_batch(
            project=store.config.project,
            table=customer_driver_combined_fv,
            data=[
                (
                    customer_key,
                    {"trips": ValueProto(int64_val=7)},
                    datetime.utcnow(),
                    datetime.utcnow(),
                )
            ],
            progress=None,
        )

        # Retrieve features for the written keys
        result = store.get_online_features(
            feature_refs=[
                "driver_locations:lon",
                "customer_profile:avg_orders_day",
                "customer_profile:name",
                "customer_driver_combined:trips",
            ],
            entity_rows=[
                {"driver": 1, "customer": 5},
                {"driver": 1, "customer": 5},
            ],
        ).to_dict()
        assert "driver_locations__lon" in result
        assert "customer_profile__avg_orders_day" in result
        assert "customer_profile__name" in result
        assert result["driver"] == [1, 1]
        assert result["customer"] == [5, 5]
        assert result["driver_locations__lon"] == ["1.0", "1.0"]
        assert result["customer_profile__avg_orders_day"] == [1.0, 1.0]
        assert result["customer_profile__name"] == ["John", "John"]
        assert result["customer_driver_combined__trips"] == [7, 7]

        # Ensure features are still in the result when keys are not found
        result = store.get_online_features(
            feature_refs=["customer_driver_combined:trips"],
            entity_rows=[{"driver": 0, "customer": 0}],
        ).to_dict()
        assert "customer_driver_combined__trips" in result

        # Invalid table reference
        with pytest.raises(FeatureViewNotFoundException):
            store.get_online_features(
                feature_refs=["driver_locations_bad:lon"],
                entity_rows=[{"driver": 1}],
            )

        # Create a new FeatureStore object with fast cache invalidation
        cache_ttl = 1
        fs_fast_ttl = FeatureStore(
            config=RepoConfig(
                registry=RegistryConfig(
                    path=store.config.registry, cache_ttl_seconds=cache_ttl
                ),
                online_store=store.config.online_store,
                project=store.config.project,
                provider=store.config.provider,
            )
        )

        # Should download the registry and cache it permanently (or until manually refreshed)
        result = fs_fast_ttl.get_online_features(
            feature_refs=[
                "driver_locations:lon",
                "customer_profile:avg_orders_day",
                "customer_profile:name",
                "customer_driver_combined:trips",
            ],
            entity_rows=[{"driver": 1, "customer": 5}],
        ).to_dict()
        assert result["driver_locations__lon"] == ["1.0"]
        assert result["customer_driver_combined__trips"] == [7]

        # Rename the registry.db so that it can't be used for refreshes
        os.rename(store.config.registry, store.config.registry + "_fake")

        # Wait for the registry cache to expire
        time.sleep(cache_ttl)

        # Will try to reload the registry because it has expired (it will fail
        # because we renamed the actual registry file)
        with pytest.raises(FileNotFoundError):
            fs_fast_ttl.get_online_features(
                feature_refs=[
                    "driver_locations:lon",
                    "customer_profile:avg_orders_day",
                    "customer_profile:name",
                    "customer_driver_combined:trips",
                ],
                entity_rows=[{"driver": 1, "customer": 5}],
            ).to_dict()

        # Restore registry.db so that we can check whether the registry actually reloads
        os.rename(store.config.registry + "_fake", store.config.registry)

        # Test that the registry is actually reloaded and results are returned
        result = fs_fast_ttl.get_online_features(
            feature_refs=[
                "driver_locations:lon",
                "customer_profile:avg_orders_day",
                "customer_profile:name",
                "customer_driver_combined:trips",
            ],
            entity_rows=[{"driver": 1, "customer": 5}],
        ).to_dict()
        assert result["driver_locations__lon"] == ["1.0"]
        assert result["customer_driver_combined__trips"] == [7]

        # Create a registry with an infinite cache (for users that want to manually
        # refresh the registry)
        fs_infinite_ttl = FeatureStore(
            config=RepoConfig(
                registry=RegistryConfig(
                    path=store.config.registry, cache_ttl_seconds=0
                ),
                online_store=store.config.online_store,
                project=store.config.project,
                provider=store.config.provider,
            )
        )

        # Should return results (and fill the registry cache)
        result = fs_infinite_ttl.get_online_features(
            feature_refs=[
                "driver_locations:lon",
                "customer_profile:avg_orders_day",
                "customer_profile:name",
                "customer_driver_combined:trips",
            ],
            entity_rows=[{"driver": 1, "customer": 5}],
        ).to_dict()
        assert result["driver_locations__lon"] == ["1.0"]
        assert result["customer_driver_combined__trips"] == [7]

        # Wait a bit so that an arbitrary TTL would take effect
        time.sleep(2)

        # Rename the registry.db so that it can't be used for refreshes
        os.rename(store.config.registry, store.config.registry + "_fake")

        # The TTL is infinite, so this method should use the registry cache
        result = fs_infinite_ttl.get_online_features(
            feature_refs=[
                "driver_locations:lon",
                "customer_profile:avg_orders_day",
                "customer_profile:name",
                "customer_driver_combined:trips",
            ],
            entity_rows=[{"driver": 1, "customer": 5}],
        ).to_dict()
        assert result["driver_locations__lon"] == ["1.0"]
        assert result["customer_driver_combined__trips"] == [7]

        # Force a registry reload (should fail because the file is missing)
        with pytest.raises(FileNotFoundError):
            fs_infinite_ttl.refresh_registry()

        # Restore registry.db so that teardown works
        os.rename(store.config.registry + "_fake", store.config.registry)
def test_online() -> None:
    """
    Test reading from the online store in local mode.
    """
    runner = CliRunner()
    with runner.local_repo(get_example_repo("example_feature_repo_1.py")) as store:
        # Write some data to two tables
        registry = store._get_registry()
        table = registry.get_feature_view(
            project=store.config.project, name="driver_locations"
        )
        table_2 = registry.get_feature_view(
            project=store.config.project, name="driver_locations_2"
        )

        provider = store._get_provider()

        entity_key = EntityKeyProto(
            entity_names=["driver"], entity_values=[ValueProto(int64_val=1)]
        )
        provider.online_write_batch(
            project=store.config.project,
            table=table,
            data=[
                (
                    entity_key,
                    {
                        "lat": ValueProto(double_val=0.1),
                        "lon": ValueProto(string_val="1.0"),
                    },
                    datetime.utcnow(),
                    datetime.utcnow(),
                )
            ],
            progress=None,
        )
        provider.online_write_batch(
            project=store.config.project,
            table=table_2,
            data=[
                (
                    entity_key,
                    {
                        "lat": ValueProto(double_val=2.0),
                        "lon": ValueProto(string_val="2.0"),
                    },
                    datetime.utcnow(),
                    datetime.utcnow(),
                )
            ],
            progress=None,
        )

        # Retrieve two features using two keys, one valid and one non-existing
        result = store.get_online_features(
            feature_refs=["driver_locations:lon", "driver_locations_2:lon"],
            entity_rows=[{"driver": 1}, {"driver": 123}],
        )

        assert "driver_locations:lon" in result.to_dict()
        assert result.to_dict()["driver_locations:lon"] == ["1.0", None]
        assert result.to_dict()["driver_locations_2:lon"] == ["2.0", None]

        # Invalid table reference
        with pytest.raises(ValueError):
            store.get_online_features(
                feature_refs=["driver_locations_bad:lon"],
                entity_rows=[{"driver": 1}],
            )
def test_online_to_df():
    """
    Test dataframe conversion. Make sure the response columns and rows are in
    the same order as the request.
    """
    driver_ids = [1, 2, 3]
    customer_ids = [4, 5, 6]
    name = "foo"
    lon_multiply = 1.0
    lat_multiply = 0.1
    age_multiply = 10
    avg_order_day_multiply = 1.0

    runner = CliRunner()
    with runner.local_repo(
        get_example_repo("example_feature_repo_1.py"), "bigquery"
    ) as store:
        # Write three tables to the online store
        driver_locations_fv = store.get_feature_view(name="driver_locations")
        customer_profile_fv = store.get_feature_view(name="customer_profile")
        customer_driver_combined_fv = store.get_feature_view(
            name="customer_driver_combined"
        )
        provider = store._get_provider()

        for (d, c) in zip(driver_ids, customer_ids):
            """
            driver table:
            driver  driver_locations__lon  driver_locations__lat
                 1                    1.0                    0.1
                 2                    2.0                    0.2
                 3                    3.0                    0.3
            """
            driver_key = EntityKeyProto(
                join_keys=["driver"], entity_values=[ValueProto(int64_val=d)]
            )
            provider.online_write_batch(
                config=store.config,
                table=driver_locations_fv,
                data=[
                    (
                        driver_key,
                        {
                            "lat": ValueProto(double_val=d * lat_multiply),
                            "lon": ValueProto(string_val=str(d * lon_multiply)),
                        },
                        datetime.utcnow(),
                        datetime.utcnow(),
                    )
                ],
                progress=None,
            )

            """
            customer table:
            customer  customer_profile__avg_orders_day  customer_profile__name  customer_profile__age
                   4                               4.0                    foo4                     40
                   5                               5.0                    foo5                     50
                   6                               6.0                    foo6                     60
            """
            customer_key = EntityKeyProto(
                join_keys=["customer"], entity_values=[ValueProto(int64_val=c)]
            )
            provider.online_write_batch(
                config=store.config,
                table=customer_profile_fv,
                data=[
                    (
                        customer_key,
                        {
                            "avg_orders_day": ValueProto(
                                float_val=c * avg_order_day_multiply
                            ),
                            "name": ValueProto(string_val=name + str(c)),
                            "age": ValueProto(int64_val=c * age_multiply),
                        },
                        datetime.utcnow(),
                        datetime.utcnow(),
                    )
                ],
                progress=None,
            )

            """
            customer_driver_combined table:
            customer  driver  customer_driver_combined__trips
                   4       1                                4
                   5       2                               10
                   6       3                               18
            """
            combo_keys = EntityKeyProto(
                join_keys=["customer", "driver"],
                entity_values=[ValueProto(int64_val=c), ValueProto(int64_val=d)],
            )
            provider.online_write_batch(
                config=store.config,
                table=customer_driver_combined_fv,
                data=[
                    (
                        combo_keys,
                        {"trips": ValueProto(int64_val=c * d)},
                        datetime.utcnow(),
                        datetime.utcnow(),
                    )
                ],
                progress=None,
            )

        # Get online features as a dataframe
        result_df = store.get_online_features(
            feature_refs=[
                "driver_locations:lon",
                "driver_locations:lat",
                "customer_profile:avg_orders_day",
                "customer_profile:name",
                "customer_profile:age",
                "customer_driver_combined:trips",
            ],
            # Reverse the row order
            entity_rows=[
                {"driver": d, "customer": c}
                for (d, c) in zip(reversed(driver_ids), reversed(customer_ids))
            ],
        ).to_df()

        """
        Construct the expected dataframe with reversed row order like so:
        driver  customer  driver_locations__lon  driver_locations__lat  customer_profile__avg_orders_day  customer_profile__name  customer_profile__age  customer_driver_combined__trips
             3         6                    3.0                    0.3                               6.0                    foo6                     60                               18
             2         5                    2.0                    0.2                               5.0                    foo5                     50                               10
             1         4                    1.0                    0.1                               4.0                    foo4                     40                                4
        """
        df_dict = {
            "driver": driver_ids,
            "customer": customer_ids,
            "driver_locations__lon": [str(d * lon_multiply) for d in driver_ids],
            "driver_locations__lat": [d * lat_multiply for d in driver_ids],
            "customer_profile__avg_orders_day": [
                c * avg_order_day_multiply for c in customer_ids
            ],
            "customer_profile__name": [name + str(c) for c in customer_ids],
            "customer_profile__age": [c * age_multiply for c in customer_ids],
            "customer_driver_combined__trips": [
                d * c for (d, c) in zip(driver_ids, customer_ids)
            ],
        }
        # Requested column order
        ordered_column = [
            "driver",
            "customer",
            "driver_locations__lon",
            "driver_locations__lat",
            "customer_profile__avg_orders_day",
            "customer_profile__name",
            "customer_profile__age",
            "customer_driver_combined__trips",
        ]
        expected_df = pd.DataFrame(
            {k: list(reversed(v)) for (k, v) in df_dict.items()}
        )
        assert_frame_equal(result_df[ordered_column], expected_df)
def basic_rw_test(store: FeatureStore, view_name: str) -> None:
    """
    This is a provider-independent test suite for reading and writing from the
    online store, to be used by provider-specific tests.
    """
    table = store.get_feature_view(name=view_name)
    provider = store._get_provider()

    entity_key = EntityKeyProto(
        entity_names=["driver"], entity_values=[ValueProto(int64_val=1)]
    )

    def _driver_rw_test(event_ts, created_ts, write, expect_read):
        """A helper function to write values and read them back"""
        write_lat, write_lon = write
        expect_lat, expect_lon = expect_read
        provider.online_write_batch(
            project=store.project,
            table=table,
            data=[
                (
                    entity_key,
                    {
                        "lat": ValueProto(double_val=write_lat),
                        "lon": ValueProto(string_val=write_lon),
                    },
                    event_ts,
                    created_ts,
                )
            ],
            progress=None,
        )
        read_rows = provider.online_read(
            project=store.project, table=table, entity_keys=[entity_key]
        )
        assert len(read_rows) == 1
        _, val = read_rows[0]
        assert val["lon"].string_val == expect_lon
        assert abs(val["lat"].double_val - expect_lat) < 1e-6

    """ 1. Basic test: write value, read it back """
    time_1 = datetime.utcnow()
    _driver_rw_test(
        event_ts=time_1, created_ts=time_1, write=(1.1, "3.1"), expect_read=(1.1, "3.1")
    )

    """ Values with an older event_ts should not overwrite newer ones """
    time_2 = datetime.utcnow()
    _driver_rw_test(
        event_ts=time_1 - timedelta(hours=1),
        created_ts=time_2,
        write=(-1000, "OLD"),
        expect_read=(1.1, "3.1"),
    )

    """ Values with a new event_ts should overwrite older ones """
    time_3 = datetime.utcnow()
    _driver_rw_test(
        event_ts=time_1 + timedelta(hours=1),
        created_ts=time_3,
        write=(1123, "NEWER"),
        expect_read=(1123, "NEWER"),
    )

    """ created_ts is used as a tie breaker; an older created_ts here means no overwrite """
    _driver_rw_test(
        event_ts=time_1 + timedelta(hours=1),
        created_ts=time_3 - timedelta(hours=1),
        write=(54321, "I HAVE AN OLDER created_ts SO I LOSE"),
        expect_read=(1123, "NEWER"),
    )

    """ created_ts is used as a tie breaker; a newer created_ts here means we overwrite """
    _driver_rw_test(
        event_ts=time_1 + timedelta(hours=1),
        created_ts=time_3 + timedelta(hours=1),
        write=(96864, "I HAVE A NEWER created_ts SO I WIN"),
        expect_read=(96864, "I HAVE A NEWER created_ts SO I WIN"),
    )
def test_bigquery_ingestion_correctness(self):
    # create dataset
    ts = pd.Timestamp.now(tz="UTC").round("ms")
    checked_value = (
        random.random()
    )  # random value so the test can't pass spuriously if nothing was written to the online store
    data = {
        "id": [1, 2, 1],
        "value": [0.1, 0.2, checked_value],
        "ts_1": [ts - timedelta(minutes=2), ts, ts],
        "created_ts": [ts, ts, ts],
    }
    df = pd.DataFrame.from_dict(data)

    # load dataset into BigQuery
    job_config = bigquery.LoadJobConfig()
    table_id = (
        f"{self.gcp_project}.{self.bigquery_dataset}.correctness_{int(time.time())}"
    )
    job = self.client.load_table_from_dataframe(df, table_id, job_config=job_config)
    job.result()

    # create FeatureView
    fv = FeatureView(
        name="test_bq_correctness",
        entities=["driver_id"],
        features=[Feature("value", ValueType.FLOAT)],
        ttl=timedelta(minutes=5),
        input=BigQuerySource(
            event_timestamp_column="ts",
            table_ref=table_id,
            created_timestamp_column="created_ts",
            field_mapping={"ts_1": "ts", "id": "driver_id"},
            date_partition_column="",
        ),
    )
    config = RepoConfig(
        metadata_store="./metadata.db",
        project="default",
        provider="gcp",
        online_store=OnlineStoreConfig(
            local=LocalOnlineStoreConfig("online_store.db")
        ),
    )
    fs = FeatureStore(config=config)
    fs.apply([fv])

    # run materialize()
    fs.materialize(
        ["test_bq_correctness"],
        datetime.utcnow() - timedelta(minutes=5),
        datetime.utcnow() - timedelta(minutes=0),
    )

    # check the result of materialize()
    entity_key = EntityKeyProto(
        entity_names=["driver_id"], entity_values=[ValueProto(int64_val=1)]
    )
    t, val = fs._get_provider().online_read("default", fv, entity_key)
    assert abs(val["value"].double_val - checked_value) < 1e-6