def _driver_rw_test(event_ts, created_ts, write, expect_read):
    """Write one (lat, lon) row through the provider, then read it back and
    assert the stored values match ``expect_read``."""
    lat_in, lon_in = write
    lat_out, lon_out = expect_read

    row = (
        entity_key,
        {
            "lat": ValueProto(double_val=lat_in),
            "lon": ValueProto(string_val=lon_in),
        },
        event_ts,
        created_ts,
    )
    provider.online_write_batch(
        project=store.project,
        table=table,
        data=[row],
        progress=None,
    )

    rows = provider.online_read(
        project=store.project, table=table, entity_keys=[entity_key]
    )
    assert len(rows) == 1
    _, values = rows[0]
    assert values["lon"].string_val == lon_out
    # Float comparison with a small tolerance.
    assert abs(values["lat"].double_val - lat_out) < 1e-6
def _create_n_customer_test_samples(n=10):
    """Build ``n`` test rows shaped like online-store write tuples.

    Each row is ``(entity_key, feature_dict, event_ts, created_ts)`` for
    customers "0" .. "n-1"; ``created_ts`` is always None.
    """
    samples = []
    for i in range(n):
        key = EntityKeyProto(
            join_keys=["customer"],
            entity_values=[ValueProto(string_val=str(i))],
        )
        features = {
            "avg_orders_day": ValueProto(float_val=1.0),
            "name": ValueProto(string_val="John"),
            "age": ValueProto(int64_val=3),
        }
        samples.append((key, features, datetime.utcnow(), None))
    return samples
def _get_features_for_entity(
    self,
    values: List[ByteString],
    feature_view: str,
    requested_features: List[str],
) -> Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]:
    """Decode raw feature bytes for a single entity.

    ``values`` is positionally aligned with ``requested_features`` and is
    expected to include the serialized event timestamp under the
    ``_ts:<feature_view>`` key. Returns ``(event_ts, features)``, or
    ``(None, None)`` when no feature values were present.
    """
    name_to_bytes = dict(zip(requested_features, values))

    # Pull out and decode the per-row event timestamp.
    ts_bytes = name_to_bytes.pop(f"_ts:{feature_view}")
    event_ts_proto = Timestamp()
    if ts_bytes:
        event_ts_proto.ParseFromString(ts_bytes)

    features: Dict[str, ValueProto] = {}
    for name, raw in name_to_bytes.items():
        proto = ValueProto()
        if raw:
            proto.ParseFromString(raw)
        features[name] = proto

    if not features:
        return None, None
    # NOTE: only whole seconds of the proto Timestamp are kept here.
    return datetime.fromtimestamp(event_ts_proto.seconds), features
def online_read(
    self,
    config: RepoConfig,
    table: Union[FeatureTable, FeatureView],
    entity_keys: List[EntityKeyProto],
    requested_features: Optional[List[str]] = None,
) -> List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]]:
    """Read feature rows from Google Datastore, one ``get`` per entity key.

    Returns one ``(event_ts, feature_dict)`` pair per key in request order;
    ``(None, None)`` for keys without a stored row. ``requested_features``
    is accepted for interface compatibility but not used — all stored
    features are returned.
    """
    online_config = config.online_store
    assert isinstance(online_config, DatastoreOnlineStoreConfig)
    client = self._get_client(online_config)
    feast_project = config.project

    result: List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]] = []
    for entity_key in entity_keys:
        document_id = compute_datastore_entity_id(entity_key)
        row_key = client.key(
            "Project", feast_project, "Table", table.name, "Row", document_id
        )
        entity = client.get(row_key)
        if entity is None:
            result.append((None, None))
            continue
        features: Dict[str, ValueProto] = {}
        for feature_name, raw in entity["values"].items():
            proto = ValueProto()
            proto.ParseFromString(raw)
            features[feature_name] = proto
        result.append((entity["event_ts"], features))
    return result
def online_read(
    self,
    config: RepoConfig,
    table: FeatureView,
    entity_keys: List[EntityKeyProto],
    requested_features: Optional[List[str]] = None,
) -> List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]]:
    """Read feature rows from DynamoDB, one ``get_item`` per entity key.

    Args:
        config: The RepoConfig for the current FeatureStore.
        table: Feast FeatureView being read.
        entity_keys: entity keys to look up, in the order results are returned.
        requested_features: accepted for interface compatibility; unused —
            all stored features are returned.

    Returns:
        One ``(event_ts, feature_dict)`` pair per entity key, in request
        order; ``(None, None)`` for keys with no stored item.
    """
    online_config = config.online_store
    assert isinstance(online_config, DynamoDBOnlineStoreConfig)
    dynamodb_resource = self._get_dynamodb_resource(online_config.region)
    # The table handle does not depend on the entity key, so create it once.
    # (The original rebuilt it on every loop iteration.)
    table_instance = dynamodb_resource.Table(_get_table_name(config, table))

    result: List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]] = []
    for entity_key in entity_keys:
        entity_id = compute_entity_id(entity_key)
        with tracing_span(name="remote_call"):
            response = table_instance.get_item(Key={"entity_id": entity_id})
        value = response.get("Item")

        if value is not None:
            res = {}
            for feature_name, value_bin in value["values"].items():
                val = ValueProto()
                # DynamoDB binary attributes arrive wrapped; .value is the bytes.
                val.ParseFromString(value_bin.value)
                res[feature_name] = val
            result.append((datetime.fromisoformat(value["event_ts"]), res))
        else:
            result.append((None, None))
    return result
def online_read(
    self,
    config: RepoConfig,
    table: Union[FeatureTable, FeatureView],
    entity_keys: List[EntityKeyProto],
    requested_features: Optional[List[str]] = None,
) -> List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]]:
    """Read feature rows from DynamoDB, one ``get_item`` per entity key.

    Args:
        config: The RepoConfig for the current FeatureStore.
        table: Feast FeatureTable/FeatureView being read.
        entity_keys: entity keys to look up, in the order results are returned.
        requested_features: accepted for interface compatibility; unused.

    Returns:
        One ``(event_ts, feature_dict)`` pair per entity key, in request
        order; ``(None, None)`` for keys with no stored item.
    """
    online_config = config.online_store
    assert isinstance(online_config, DynamoDBOnlineStoreConfig)
    _, dynamodb_resource = self._initialize_dynamodb(online_config)
    # The table handle is the same for every entity key — resolve it once
    # instead of once per loop iteration as the original did.
    table_instance = dynamodb_resource.Table(f"{config.project}.{table.name}")

    result: List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]] = []
    for entity_key in entity_keys:
        entity_id = compute_entity_id(entity_key)
        response = table_instance.get_item(Key={"entity_id": entity_id})
        value = response.get("Item")

        if value is not None:
            res = {}
            for feature_name, value_bin in value["values"].items():
                val = ValueProto()
                # DynamoDB binary attributes arrive wrapped; .value is the bytes.
                val.ParseFromString(value_bin.value)
                res[feature_name] = val
            result.append((value["event_ts"], res))
        else:
            result.append((None, None))
    return result
def test_dynamodb_online_store_online_read_unknown_entity(
        repo_config, dynamodb_online_store):
    """Test DynamoDBOnlineStore online_read method."""
    n_samples = 2
    _create_test_table(PROJECT, f"{TABLE_NAME}_unknown_entity_{n_samples}",
                       REGION)
    data = _create_n_customer_test_samples(n=n_samples)
    _insert_data_test_table(data, PROJECT,
                            f"{TABLE_NAME}_unknown_entity_{n_samples}",
                            REGION)

    entity_keys, features, *rest = zip(*data)
    # Append a nonsensical entity to search for
    entity_keys = list(entity_keys)
    features = list(features)

    # Have the unknown entity be in the beginning, middle, and end of the list of entities.
    for pos in range(len(entity_keys)):
        # Insert an entity id ("12359") that was never written, at position
        # `pos`, and expect (None, None) at that same position in the result.
        entity_keys_with_unknown = deepcopy(entity_keys)
        entity_keys_with_unknown.insert(
            pos,
            EntityKeyProto(join_keys=["customer"],
                           entity_values=[ValueProto(string_val="12359")]),
        )
        features_with_none = deepcopy(features)
        features_with_none.insert(pos, None)

        returned_items = dynamodb_online_store.online_read(
            config=repo_config,
            table=MockFeatureView(
                name=f"{TABLE_NAME}_unknown_entity_{n_samples}"),
            entity_keys=entity_keys_with_unknown,
        )
        # online_read must return one entry per requested key, with the
        # unknown key mapped to None rather than dropped.
        assert len(returned_items) == len(entity_keys_with_unknown)
        assert [item[1] for item in returned_items] == list(features_with_none)
        # The order should match the original entity key order
        assert returned_items[pos] == (None, None)
def online_read(
    self,
    project: str,
    table: Union[FeatureTable, FeatureView],
    entity_keys: List[EntityKeyProto],
) -> List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]]:
    """Read feature rows from Datastore for each entity key, one get per key.

    Returns one ``(event_ts, feature_dict)`` pair per key in request order;
    ``(None, None)`` for keys with no stored row.
    """
    client = self._initialize_client()

    result: List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]] = []
    for entity_key in entity_keys:
        row_key = client.key(
            "Project", project, "Table", table.name, "Row",
            compute_datastore_entity_id(entity_key),
        )
        entity = client.get(row_key)
        if entity is None:
            result.append((None, None))
            continue
        features: Dict[str, ValueProto] = {}
        for feature_name, raw in entity["values"].items():
            proto = ValueProto()
            proto.ParseFromString(raw)
            features[feature_name] = proto
        result.append((entity["event_ts"], features))
    return result
def online_read(
    self,
    config: RepoConfig,
    table: Union[FeatureTable, FeatureView],
    entity_keys: List[EntityKeyProto],
    requested_features: Optional[List[str]] = None,
) -> List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]]:
    """Read feature rows from the SQLite online store, one query per key.

    Args:
        config: The RepoConfig for the current FeatureStore.
        table: Feast FeatureTable/FeatureView being read.
        entity_keys: entity keys to look up, in the order results are returned.
        requested_features: accepted for interface compatibility; unused —
            all stored features for the row are returned.

    Returns:
        One ``(event_ts, feature_dict)`` pair per entity key, in request
        order; ``(None, None)`` for keys with no stored row.
    """
    # The original body began with a stray `pass` statement — dead code,
    # removed here.
    conn = self._get_conn(config)
    cur = conn.cursor()

    result: List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]] = []

    project = config.project
    # The query text is the same for every key; build it once outside the loop.
    query = (
        f"SELECT feature_name, value, event_ts "
        f"FROM {_table_id(project, table)} WHERE entity_key = ?"
    )
    for entity_key in entity_keys:
        entity_key_bin = serialize_entity_key(entity_key)
        cur.execute(query, (entity_key_bin,))

        res = {}
        res_ts = None
        for feature_name, val_bin, ts in cur.fetchall():
            val = ValueProto()
            val.ParseFromString(val_bin)
            res[feature_name] = val
            # All rows of one entity share the same event_ts; keep the last seen.
            res_ts = ts

        if not res:
            result.append((None, None))
        else:
            result.append((res_ts, res))
    return result
def online_read(
    self,
    project: str,
    table: Union[FeatureTable, FeatureView],
    entity_key: EntityKeyProto,
) -> Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]:
    """Read the feature row for a single entity key from the SQLite store.

    Returns ``(event_ts, feature_dict)``, or ``(None, None)`` when the key
    has no stored row.
    """
    entity_key_bin = serialize_entity_key(entity_key)
    conn = self._get_conn()
    cur = conn.cursor()
    cur.execute(
        f"SELECT feature_name, value, event_ts FROM {_table_id(project, table)} WHERE entity_key = ?",
        (entity_key_bin, ),
    )

    features: Dict[str, ValueProto] = {}
    event_ts = None
    for feature_name, raw_value, ts in cur.fetchall():
        proto = ValueProto()
        proto.ParseFromString(raw_value)
        features[feature_name] = proto
        # All rows of one entity share the same event_ts; keep the last seen.
        event_ts = ts

    if not features:
        return None, None
    return event_ts, features
def _driver_rw_test(event_ts, created_ts, write, expect_read):
    """Write one (lat, lon) row and verify it reads back as expected,
    either through a feature service or directly from the provider."""
    lat_in, lon_in = write
    lat_out, lon_out = expect_read

    provider.online_write_batch(
        config=store.config,
        table=table,
        data=[
            (
                entity_key,
                {
                    "lat": ValueProto(double_val=lat_in),
                    "lon": ValueProto(string_val=lon_in),
                },
                event_ts,
                created_ts,
            )
        ],
        progress=None,
    )

    if not feature_service_name:
        # Direct provider read path.
        rows = provider.online_read(
            config=store.config, table=table, entity_keys=[entity_key]
        )
        assert len(rows) == 1
        _, values = rows[0]
        assert values["lon"].string_val == lon_out
        assert abs(values["lat"].double_val - lat_out) < 1e-6
    else:
        # Read back through the named feature service.
        feature_service = store.get_feature_service(feature_service_name)
        features = store.get_online_features(
            features=feature_service, entity_rows=[{"driver": 1}]
        ).to_dict()
        assert len(features["driver"]) == 1
        assert features["lon"][0] == lon_out
        assert abs(features["lat"][0] - lat_out) < 1e-6
def _driver_rw_test(event_ts, created_ts, write, expect_read):
    """Write one (lat, lon) row via the provider and read it straight back."""
    lat_in, lon_in = write
    lat_out, lon_out = expect_read

    row = (
        entity_key,
        {
            "lat": ValueProto(double_val=lat_in),
            "lon": ValueProto(string_val=lon_in),
        },
        event_ts,
    )
    provider.online_write_batch(
        project=project_name,
        table=table,
        data=[row],
        created_ts=created_ts,
    )

    _, values = provider.online_read(project=project_name,
                                     table=table,
                                     entity_key=entity_key)
    assert values["lon"].string_val == lon_out
    assert abs(values["lat"].double_val - lat_out) < 1e-6
def online_read(
    self,
    config: RepoConfig,
    table: Union[FeatureTable, FeatureView],
    entity_keys: List[EntityKeyProto],
    requested_features: Optional[List[str]] = None,
) -> List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]]:
    """Read the requested features for each entity key from Redis.

    Args:
        config: The RepoConfig for the current FeatureStore.
        table: Feast FeatureTable/FeatureView being read.
        entity_keys: entity keys to look up, in the order results are returned.
        requested_features: feature names to fetch; defaults to all of the
            view's features. The caller's list is NOT mutated.

    Returns:
        One ``(event_ts, feature_dict)`` pair per entity key, in request
        order; ``(None, None)`` for keys with no stored hash.
    """
    online_store_config = config.online_store
    assert isinstance(online_store_config, RedisOnlineStoreConfig)

    client = self._get_client(online_store_config)
    feature_view = table.name
    project = config.project

    result: List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]] = []

    if not requested_features:
        requested_features = [f.name for f in table.features]

    # BUG FIX: the original appended the timestamp key to `requested_features`
    # inside the per-entity loop. That mutated the caller's list and, from the
    # second entity onward, hashed the timestamp key with _mmh3 into
    # `hset_keys` (producing a wrong field name) and appended it again.
    # Build both key lists exactly once, without touching the input list.
    ts_key = f"_ts:{feature_view}"
    hset_keys = [_mmh3(f"{feature_view}:{k}") for k in requested_features]
    hset_keys.append(ts_key)
    requested_features_with_ts = requested_features + [ts_key]

    for entity_key in entity_keys:
        redis_key_bin = _redis_key(project, entity_key)
        values = client.hmget(redis_key_bin, hset_keys)
        res_val = dict(zip(requested_features_with_ts, values))

        res_ts = Timestamp()
        ts_val = res_val.pop(ts_key)
        if ts_val:
            res_ts.ParseFromString(ts_val)

        res = {}
        for feature_name, val_bin in res_val.items():
            val = ValueProto()
            # Missing fields come back as None; leave an empty ValueProto.
            if val_bin:
                val.ParseFromString(val_bin)
            res[feature_name] = val

        if not res:
            result.append((None, None))
        else:
            timestamp = datetime.fromtimestamp(res_ts.seconds)
            result.append((timestamp, res))
    return result
def online_read(
    self,
    config: RepoConfig,
    table: FeatureView,
    entity_keys: List[EntityKeyProto],
    requested_features: Optional[List[str]] = None,
) -> List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]]:
    """Batch-read feature rows from PostgreSQL, preserving request order.

    All entity keys are fetched in a single query; ``(None, None)`` is
    returned for keys with no stored rows. ``requested_features`` is
    accepted for interface compatibility but unused.
    """
    result: List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]] = []
    project = config.project

    with self._get_conn(config) as conn, conn.cursor() as cur:
        # Collecting all the keys to a list allows us to make fewer round trips
        # to PostgreSQL
        keys = [serialize_entity_key(entity_key) for entity_key in entity_keys]
        cur.execute(
            sql.SQL("""
                SELECT entity_key, feature_name, value, event_ts
                FROM {} WHERE entity_key = ANY(%s);
                """).format(sql.Identifier(_table_id(project, table)), ),
            (keys, ),
        )

        rows = cur.fetchall()

        # PostgreSQL gives no ordering guarantee, so bucket the returned rows
        # by entity key for O(1) lookup while walking keys in request order.
        values_dict = defaultdict(list)
        for row in rows if rows is not None else []:
            values_dict[row[0].tobytes()].append(row[1:])

        for key in keys:
            if key not in values_dict:
                result.append((None, None))
                continue
            res = {}
            event_ts = None
            for feature_name, value_bin, event_ts in values_dict[key]:
                val = ValueProto()
                val.ParseFromString(value_bin)
                res[feature_name] = val
            result.append((event_ts, res))
    return result
def online_read(
    self,
    config: RepoConfig,
    table: FeatureView,
    entity_keys: List[EntityKeyProto],
    requested_features: Optional[List[str]] = None,
) -> List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]]:
    """Batch-read feature rows from the SQLite store, preserving request order.

    Args:
        config: The RepoConfig for the current FeatureStore.
        table: Feast FeatureView being read.
        entity_keys: entity keys to look up, in the order results are returned.
        requested_features: accepted for interface compatibility; unused.

    Returns:
        One ``(event_ts, feature_dict)`` pair per entity key, in request
        order; ``(None, None)`` for keys with no stored rows.
    """
    conn = self._get_conn(config)
    cur = conn.cursor()

    result: List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]] = []

    # Serialize each entity key exactly once; the original serialized every
    # key twice (once for the query parameters and again for the lookup).
    serialized_keys = [serialize_entity_key(k) for k in entity_keys]

    with tracing_span(name="remote_call"):
        # Fetch all entities in one go
        cur.execute(
            f"SELECT entity_key, feature_name, value, event_ts "
            f"FROM {_table_id(config.project, table)} "
            f"WHERE entity_key IN ({','.join('?' * len(entity_keys))}) "
            f"ORDER BY entity_key",
            serialized_keys,
        )
        rows = cur.fetchall()

    # groupby requires its input sorted by the grouping key, which the
    # ORDER BY above guarantees.
    rows = {
        k: list(group)
        for k, group in itertools.groupby(rows, key=lambda r: r[0])
    }
    for entity_key_bin in serialized_keys:
        res = {}
        res_ts = None
        for _, feature_name, val_bin, ts in rows.get(entity_key_bin, []):
            val = ValueProto()
            val.ParseFromString(val_bin)
            res[feature_name] = val
            res_ts = ts

        if not res:
            result.append((None, None))
        else:
            result.append((res_ts, res))
    return result
def online_read(
    self,
    config: RepoConfig,
    table: Union[FeatureTable, FeatureView],
    entity_keys: List[EntityKeyProto],
    requested_features: Optional[List[str]] = None,
) -> List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]]:
    """Batch-read feature rows from Datastore, preserving request order.

    Returns one ``(event_ts, feature_dict)`` pair per entity key;
    ``(None, None)`` for keys with no stored row. ``requested_features``
    is accepted for interface compatibility but unused.
    """
    online_config = config.online_store
    assert isinstance(online_config, DatastoreOnlineStoreConfig)
    client = self._get_client(online_config)
    feast_project = config.project

    keys: List[Key] = [
        client.key("Project", feast_project, "Table", table.name, "Row",
                   compute_entity_id(entity_key))
        for entity_key in entity_keys
    ]

    # NOTE: get_multi doesn't return values in the same order as the keys in the request.
    # Also, len(values) can be less than len(keys) in the case of missing values.
    with tracing_span(name="remote_call"):
        values = client.get_multi(keys)
    values_dict = {v.key: v for v in values} if values is not None else {}

    result: List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]] = []
    for key in keys:
        entity = values_dict.get(key)
        if entity is None:
            result.append((None, None))
            continue
        features: Dict[str, ValueProto] = {}
        for feature_name, raw in entity["values"].items():
            proto = ValueProto()
            proto.ParseFromString(raw)
            features[feature_name] = proto
        result.append((entity["event_ts"], features))
    return result
def online_read(
    self,
    config: RepoConfig,
    table: Union[FeatureTable, FeatureView],
    entity_keys: List[EntityKeyProto],
    requested_features: Optional[List[str]] = None,
) -> List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]]:
    """Batch-read feature rows from Datastore, preserving request order.

    Args:
        config: The RepoConfig for the current FeatureStore.
        table: Feast FeatureTable/FeatureView being read.
        entity_keys: entity keys to look up, in the order results are returned.
        requested_features: accepted for interface compatibility; unused.

    Returns:
        One ``(event_ts, feature_dict)`` pair per entity key, in request
        order; ``(None, None)`` for keys with no stored row.
    """
    online_config = config.online_store
    assert isinstance(online_config, DatastoreOnlineStoreConfig)
    client = self._get_client(online_config)
    feast_project = config.project

    keys: List[Key] = []
    result: List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]] = []
    for entity_key in entity_keys:
        document_id = compute_entity_id(entity_key)
        key = client.key("Project", feast_project, "Table", table.name,
                         "Row", document_id)
        keys.append(key)

    # get_multi may return values out of order and may omit missing keys.
    values = client.get_multi(keys)

    if values is not None:
        # PERF FIX: the original called keys.index(...) per value and per
        # missing key — O(n) each, O(n^2) overall. Precompute the position
        # of each key's first occurrence once.
        key_index: Dict[Key, int] = {}
        for idx, key in enumerate(keys):
            key_index.setdefault(key, idx)

        keys_missing_from_response = set(keys) - {v.key for v in values}
        values = sorted(values, key=lambda v: key_index[v.key])
        for value in values:
            res = {}
            for feature_name, value_bin in value["values"].items():
                val = ValueProto()
                val.ParseFromString(value_bin)
                res[feature_name] = val
            result.append((value["event_ts"], res))
        # Re-insert (None, None) placeholders at the request positions of
        # keys the response did not contain, in ascending index order so
        # earlier inserts don't shift later targets.
        for missing_key_idx in sorted(
                key_index[k] for k in keys_missing_from_response):
            result.insert(missing_key_idx, (None, None))
    return result
def basic_rw_test(store: FeatureStore,
                  view_name: str,
                  feature_service_name: Optional[str] = None) -> None:
    """
    This is a provider-independent test suite for reading and writing from the
    online store, to be used by provider-specific tests.

    When `feature_service_name` is given, reads go through
    `store.get_online_features`; otherwise they go directly through the
    provider's `online_read`.
    """
    table = store.get_feature_view(name=view_name)
    provider = store._get_provider()

    entity_key = EntityKeyProto(join_keys=["driver"],
                                entity_values=[ValueProto(int64_val=1)])

    def _driver_rw_test(event_ts, created_ts, write, expect_read):
        """ A helper function to write values and read them back """
        write_lat, write_lon = write
        expect_lat, expect_lon = expect_read
        provider.online_write_batch(
            config=store.config,
            table=table,
            data=[(
                entity_key,
                {
                    "lat": ValueProto(double_val=write_lat),
                    "lon": ValueProto(string_val=write_lon),
                },
                event_ts,
                created_ts,
            )],
            progress=None,
        )
        if feature_service_name:
            # Read back through the named feature service.
            entity_dict = {"driver": 1}
            feature_service = store.get_feature_service(feature_service_name)
            features = store.get_online_features(features=feature_service,
                                                 entity_rows=[entity_dict
                                                              ]).to_dict()
            assert len(features["driver"]) == 1
            assert features["lon"][0] == expect_lon
            assert abs(features["lat"][0] - expect_lat) < 1e-6
        else:
            # Read back directly from the provider.
            read_rows = provider.online_read(config=store.config,
                                             table=table,
                                             entity_keys=[entity_key])
            assert len(read_rows) == 1
            _, val = read_rows[0]
            assert val["lon"].string_val == expect_lon
            assert abs(val["lat"].double_val - expect_lat) < 1e-6

    """ 1. Basic test: write value, read it back """
    time_1 = datetime.utcnow()
    _driver_rw_test(event_ts=time_1,
                    created_ts=time_1,
                    write=(1.1, "3.1"),
                    expect_read=(1.1, "3.1"))

    # Note: This behavior has changed for performance. We should test that older
    # value can't overwrite over a newer value once we add the respective flag
    """ Values with an older event_ts should overwrite newer ones """
    time_2 = datetime.utcnow()
    _driver_rw_test(
        event_ts=time_1 - timedelta(hours=1),
        created_ts=time_2,
        write=(-1000, "OLD"),
        expect_read=(-1000, "OLD"),
    )

    """ Values with an new event_ts should overwrite older ones """
    time_3 = datetime.utcnow()
    _driver_rw_test(
        event_ts=time_1 + timedelta(hours=1),
        created_ts=time_3,
        write=(1123, "NEWER"),
        expect_read=(1123, "NEWER"),
    )

    # Note: This behavior has changed for performance. We should test that older
    # value can't overwrite over a newer value once we add the respective flag
    """ created_ts is used as a tie breaker, using older created_ts here, but we still overwrite """
    _driver_rw_test(
        event_ts=time_1 + timedelta(hours=1),
        created_ts=time_3 - timedelta(hours=1),
        write=(54321, "I HAVE AN OLDER created_ts SO I LOSE"),
        expect_read=(54321, "I HAVE AN OLDER created_ts SO I LOSE"),
    )

    """ created_ts is used as a tie breaker, using newer created_ts here so we should overwrite """
    _driver_rw_test(
        event_ts=time_1 + timedelta(hours=1),
        created_ts=time_3 + timedelta(hours=1),
        write=(96864, "I HAVE A NEWER created_ts SO I WIN"),
        expect_read=(96864, "I HAVE A NEWER created_ts SO I WIN"),
    )
def basic_rw_test(store: FeatureStore, view_name: str) -> None:
    """
    This is a provider-independent test suite for reading and writing from the
    online store, to be used by provider-specific tests.

    Unlike the newer variant, this version expects the store to keep the
    newest row: older event_ts (or tied event_ts with older created_ts) must
    NOT overwrite a newer value.
    """
    table = store.get_feature_view(name=view_name)
    provider = store._get_provider()

    entity_key = EntityKeyProto(entity_names=["driver"],
                                entity_values=[ValueProto(int64_val=1)])

    def _driver_rw_test(event_ts, created_ts, write, expect_read):
        """ A helper function to write values and read them back """
        write_lat, write_lon = write
        expect_lat, expect_lon = expect_read
        provider.online_write_batch(
            project=store.project,
            table=table,
            data=[(
                entity_key,
                {
                    "lat": ValueProto(double_val=write_lat),
                    "lon": ValueProto(string_val=write_lon),
                },
                event_ts,
                created_ts,
            )],
            progress=None,
        )
        read_rows = provider.online_read(project=store.project,
                                         table=table,
                                         entity_keys=[entity_key])
        assert len(read_rows) == 1
        _, val = read_rows[0]
        assert val["lon"].string_val == expect_lon
        assert abs(val["lat"].double_val - expect_lat) < 1e-6

    """ 1. Basic test: write value, read it back """
    time_1 = datetime.utcnow()
    _driver_rw_test(event_ts=time_1,
                    created_ts=time_1,
                    write=(1.1, "3.1"),
                    expect_read=(1.1, "3.1"))

    """ Values with an older event_ts should not overwrite newer ones """
    time_2 = datetime.utcnow()
    _driver_rw_test(
        event_ts=time_1 - timedelta(hours=1),
        created_ts=time_2,
        write=(-1000, "OLD"),
        expect_read=(1.1, "3.1"),
    )

    """ Values with an new event_ts should overwrite older ones """
    time_3 = datetime.utcnow()
    _driver_rw_test(
        event_ts=time_1 + timedelta(hours=1),
        created_ts=time_3,
        write=(1123, "NEWER"),
        expect_read=(1123, "NEWER"),
    )

    """ created_ts is used as a tie breaker, using older created_ts here so no overwrite """
    _driver_rw_test(
        event_ts=time_1 + timedelta(hours=1),
        created_ts=time_3 - timedelta(hours=1),
        write=(54321, "I HAVE AN OLDER created_ts SO I LOSE"),
        expect_read=(1123, "NEWER"),
    )

    # NOTE(review): the string below says "older created_ts ... no overwrite",
    # but this case actually uses a *newer* created_ts and expects an
    # overwrite — the inline description looks like a copy-paste slip.
    """ created_ts is used as a tie breaker, using older created_ts here so no overwrite """
    _driver_rw_test(
        event_ts=time_1 + timedelta(hours=1),
        created_ts=time_3 + timedelta(hours=1),
        write=(96864, "I HAVE A NEWER created_ts SO I WIN"),
        expect_read=(96864, "I HAVE A NEWER created_ts SO I WIN"),
    )
def test_online() -> None:
    """
    Test reading from the online store in local mode.

    Covers: basic writes/reads across three feature views, missing-key
    behavior, invalid feature references, and registry caching with both a
    short TTL and an infinite TTL.
    """
    runner = CliRunner()
    with runner.local_repo(
            get_example_repo("example_feature_repo_1.py")) as store:
        # Write some data to two tables
        driver_locations_fv = store.get_feature_view(name="driver_locations")
        customer_profile_fv = store.get_feature_view(name="customer_profile")
        customer_driver_combined_fv = store.get_feature_view(
            name="customer_driver_combined")
        provider = store._get_provider()

        driver_key = EntityKeyProto(join_keys=["driver"],
                                    entity_values=[ValueProto(int64_val=1)])
        provider.online_write_batch(
            project=store.config.project,
            table=driver_locations_fv,
            data=[(
                driver_key,
                {
                    "lat": ValueProto(double_val=0.1),
                    "lon": ValueProto(string_val="1.0"),
                },
                datetime.utcnow(),
                datetime.utcnow(),
            )],
            progress=None,
        )

        customer_key = EntityKeyProto(join_keys=["customer"],
                                      entity_values=[ValueProto(int64_val=5)])
        provider.online_write_batch(
            project=store.config.project,
            table=customer_profile_fv,
            data=[(
                customer_key,
                {
                    "avg_orders_day": ValueProto(float_val=1.0),
                    "name": ValueProto(string_val="John"),
                    "age": ValueProto(int64_val=3),
                },
                datetime.utcnow(),
                datetime.utcnow(),
            )],
            progress=None,
        )

        # Composite entity key: (customer, driver).
        customer_key = EntityKeyProto(
            join_keys=["customer", "driver"],
            entity_values=[ValueProto(int64_val=5),
                           ValueProto(int64_val=1)],
        )
        provider.online_write_batch(
            project=store.config.project,
            table=customer_driver_combined_fv,
            data=[(
                customer_key,
                {
                    "trips": ValueProto(int64_val=7)
                },
                datetime.utcnow(),
                datetime.utcnow(),
            )],
            progress=None,
        )

        # Retrieve two features using two keys, one valid one non-existing
        result = store.get_online_features(
            feature_refs=[
                "driver_locations:lon",
                "customer_profile:avg_orders_day",
                "customer_profile:name",
                "customer_driver_combined:trips",
            ],
            entity_rows=[{
                "driver": 1,
                "customer": 5
            }, {
                "driver": 1,
                "customer": 5
            }],
        ).to_dict()

        assert "driver_locations__lon" in result
        assert "customer_profile__avg_orders_day" in result
        assert "customer_profile__name" in result

        assert result["driver"] == [1, 1]
        assert result["customer"] == [5, 5]
        assert result["driver_locations__lon"] == ["1.0", "1.0"]
        assert result["customer_profile__avg_orders_day"] == [1.0, 1.0]
        assert result["customer_profile__name"] == ["John", "John"]
        assert result["customer_driver_combined__trips"] == [7, 7]

        # Ensure features are still in result when keys not found
        result = store.get_online_features(
            feature_refs=["customer_driver_combined:trips"],
            entity_rows=[{
                "driver": 0,
                "customer": 0
            }],
        ).to_dict()

        assert "customer_driver_combined__trips" in result

        # invalid table reference
        with pytest.raises(FeatureViewNotFoundException):
            store.get_online_features(
                feature_refs=["driver_locations_bad:lon"],
                entity_rows=[{
                    "driver": 1
                }],
            )

        # Create new FeatureStore object with fast cache invalidation
        cache_ttl = 1
        fs_fast_ttl = FeatureStore(config=RepoConfig(
            registry=RegistryConfig(path=store.config.registry,
                                    cache_ttl_seconds=cache_ttl),
            online_store=store.config.online_store,
            project=store.config.project,
            provider=store.config.provider,
        ))

        # Should download the registry and cache it permanently (or until manually refreshed)
        result = fs_fast_ttl.get_online_features(
            feature_refs=[
                "driver_locations:lon",
                "customer_profile:avg_orders_day",
                "customer_profile:name",
                "customer_driver_combined:trips",
            ],
            entity_rows=[{
                "driver": 1,
                "customer": 5
            }],
        ).to_dict()
        assert result["driver_locations__lon"] == ["1.0"]
        assert result["customer_driver_combined__trips"] == [7]

        # Rename the registry.db so that it cant be used for refreshes
        os.rename(store.config.registry, store.config.registry + "_fake")

        # Wait for registry to expire
        time.sleep(cache_ttl)

        # Will try to reload registry because it has expired (it will fail because we deleted the actual registry file)
        with pytest.raises(FileNotFoundError):
            fs_fast_ttl.get_online_features(
                feature_refs=[
                    "driver_locations:lon",
                    "customer_profile:avg_orders_day",
                    "customer_profile:name",
                    "customer_driver_combined:trips",
                ],
                entity_rows=[{
                    "driver": 1,
                    "customer": 5
                }],
            ).to_dict()

        # Restore registry.db so that we can see if it actually reloads registry
        os.rename(store.config.registry + "_fake", store.config.registry)

        # Test if registry is actually reloaded and whether results return
        result = fs_fast_ttl.get_online_features(
            feature_refs=[
                "driver_locations:lon",
                "customer_profile:avg_orders_day",
                "customer_profile:name",
                "customer_driver_combined:trips",
            ],
            entity_rows=[{
                "driver": 1,
                "customer": 5
            }],
        ).to_dict()
        assert result["driver_locations__lon"] == ["1.0"]
        assert result["customer_driver_combined__trips"] == [7]

        # Create a registry with infinite cache (for users that want to manually refresh the registry)
        fs_infinite_ttl = FeatureStore(config=RepoConfig(
            registry=RegistryConfig(path=store.config.registry,
                                    cache_ttl_seconds=0),
            online_store=store.config.online_store,
            project=store.config.project,
            provider=store.config.provider,
        ))

        # Should return results (and fill the registry cache)
        result = fs_infinite_ttl.get_online_features(
            feature_refs=[
                "driver_locations:lon",
                "customer_profile:avg_orders_day",
                "customer_profile:name",
                "customer_driver_combined:trips",
            ],
            entity_rows=[{
                "driver": 1,
                "customer": 5
            }],
        ).to_dict()
        assert result["driver_locations__lon"] == ["1.0"]
        assert result["customer_driver_combined__trips"] == [7]

        # Wait a bit so that an arbitrary TTL would take effect
        time.sleep(2)

        # Rename the registry.db so that it cant be used for refreshes
        os.rename(store.config.registry, store.config.registry + "_fake")

        # TTL is infinite so this method should use registry cache
        result = fs_infinite_ttl.get_online_features(
            feature_refs=[
                "driver_locations:lon",
                "customer_profile:avg_orders_day",
                "customer_profile:name",
                "customer_driver_combined:trips",
            ],
            entity_rows=[{
                "driver": 1,
                "customer": 5
            }],
        ).to_dict()
        assert result["driver_locations__lon"] == ["1.0"]
        assert result["customer_driver_combined__trips"] == [7]

        # Force registry reload (should fail because file is missing)
        with pytest.raises(FileNotFoundError):
            fs_infinite_ttl.refresh_registry()

        # Restore registry.db so that teardown works
        os.rename(store.config.registry + "_fake", store.config.registry)
def test_online() -> None:
    """
    Test reading from the online store in local mode.

    Writes one row to each of two feature views, then verifies a batched
    online read returns the stored value for the known entity and None for
    an unknown one, and that a bad feature reference raises.
    """
    runner = CliRunner()
    with runner.local_repo(get_example_repo("example_feature_repo_1.py")) as store:
        # Write some data to two tables
        registry = store._get_registry()
        table = registry.get_feature_view(
            project=store.config.project, name="driver_locations"
        )
        table_2 = registry.get_feature_view(
            project=store.config.project, name="driver_locations_2"
        )

        provider = store._get_provider()

        entity_key = EntityKeyProto(
            entity_names=["driver"], entity_values=[ValueProto(int64_val=1)]
        )
        provider.online_write_batch(
            project=store.config.project,
            table=table,
            data=[
                (
                    entity_key,
                    {
                        "lat": ValueProto(double_val=0.1),
                        "lon": ValueProto(string_val="1.0"),
                    },
                    datetime.utcnow(),
                    datetime.utcnow(),
                )
            ],
            progress=None,
        )

        provider.online_write_batch(
            project=store.config.project,
            table=table_2,
            data=[
                (
                    entity_key,
                    {
                        "lat": ValueProto(double_val=2.0),
                        "lon": ValueProto(string_val="2.0"),
                    },
                    datetime.utcnow(),
                    datetime.utcnow(),
                )
            ],
            progress=None,
        )

        # Retrieve two features using two keys, one valid one non-existing
        result = store.get_online_features(
            feature_refs=["driver_locations:lon", "driver_locations_2:lon"],
            entity_rows=[{"driver": 1}, {"driver": 123}],
        )

        assert "driver_locations:lon" in result.to_dict()

        # Unknown entity (driver 123) yields None for each feature.
        assert result.to_dict()["driver_locations:lon"] == ["1.0", None]
        assert result.to_dict()["driver_locations_2:lon"] == ["2.0", None]

        # invalid table reference
        with pytest.raises(ValueError):
            store.get_online_features(
                feature_refs=["driver_locations_bad:lon"],
                entity_rows=[{"driver": 1}],
            )
def test_online_to_df():
    """
    Test dataframe conversion. Make sure the response columns and rows are
    the same order as the request.
    """
    driver_ids = [1, 2, 3]
    customer_ids = [4, 5, 6]
    name = "foo"
    lon_multiply = 1.0
    lat_multiply = 0.1
    age_multiply = 10
    avg_order_day_multiply = 1.0

    runner = CliRunner()
    with runner.local_repo(get_example_repo("example_feature_repo_1.py"),
                           "bigquery") as store:
        # Write three tables to online store
        driver_locations_fv = store.get_feature_view(name="driver_locations")
        customer_profile_fv = store.get_feature_view(name="customer_profile")
        customer_driver_combined_fv = store.get_feature_view(
            name="customer_driver_combined")
        provider = store._get_provider()

        # One write per (driver, customer) pair; feature values are derived
        # from the id times a fixed multiplier so expectations are computable.
        for (d, c) in zip(driver_ids, customer_ids):
            """
            driver table:
            driver driver_locations__lon driver_locations__lat
                1 1.0 0.1
                2 2.0 0.2
                3 3.0 0.3
            """
            driver_key = EntityKeyProto(
                join_keys=["driver"], entity_values=[ValueProto(int64_val=d)])
            provider.online_write_batch(
                config=store.config,
                table=driver_locations_fv,
                data=[(
                    driver_key,
                    {
                        "lat": ValueProto(double_val=d * lat_multiply),
                        "lon": ValueProto(string_val=str(d * lon_multiply)),
                    },
                    datetime.utcnow(),
                    datetime.utcnow(),
                )],
                progress=None,
            )

            """
            customer table
            customer customer_profile__avg_orders_day customer_profile__name customer_profile__age
                4 4.0 foo4 40
                5 5.0 foo5 50
                6 6.0 foo6 60
            """
            customer_key = EntityKeyProto(
                join_keys=["customer"],
                entity_values=[ValueProto(int64_val=c)])
            provider.online_write_batch(
                config=store.config,
                table=customer_profile_fv,
                data=[(
                    customer_key,
                    {
                        "avg_orders_day":
                        ValueProto(float_val=c * avg_order_day_multiply),
                        "name": ValueProto(string_val=name + str(c)),
                        "age": ValueProto(int64_val=c * age_multiply),
                    },
                    datetime.utcnow(),
                    datetime.utcnow(),
                )],
                progress=None,
            )

            """
            customer_driver_combined table
            customer driver customer_driver_combined__trips
                4 1 4
                5 2 10
                6 3 18
            """
            combo_keys = EntityKeyProto(
                join_keys=["customer", "driver"],
                entity_values=[
                    ValueProto(int64_val=c),
                    ValueProto(int64_val=d)
                ],
            )
            provider.online_write_batch(
                config=store.config,
                table=customer_driver_combined_fv,
                data=[(
                    combo_keys,
                    {
                        "trips": ValueProto(int64_val=c * d)
                    },
                    datetime.utcnow(),
                    datetime.utcnow(),
                )],
                progress=None,
            )

        # Get online features in dataframe
        result_df = store.get_online_features(
            feature_refs=[
                "driver_locations:lon",
                "driver_locations:lat",
                "customer_profile:avg_orders_day",
                "customer_profile:name",
                "customer_profile:age",
                "customer_driver_combined:trips",
            ],
            # Reverse the row order
            entity_rows=[{
                "driver": d,
                "customer": c
            } for (d, c) in zip(reversed(driver_ids), reversed(customer_ids))],
        ).to_df()

        """
        Construct the expected dataframe with reversed row order like so:
        driver customer driver_locations__lon driver_locations__lat customer_profile__avg_orders_day customer_profile__name customer_profile__age customer_driver_combined__trips
            3 6 3.0 0.3 6.0 foo6 60 18
            2 5 2.0 0.2 5.0 foo5 50 10
            1 4 1.0 0.1 4.0 foo4 40 4
        """
        df_dict = {
            "driver": driver_ids,
            "customer": customer_ids,
            "driver_locations__lon":
            [str(d * lon_multiply) for d in driver_ids],
            "driver_locations__lat": [d * lat_multiply for d in driver_ids],
            "customer_profile__avg_orders_day":
            [c * avg_order_day_multiply for c in customer_ids],
            "customer_profile__name": [name + str(c) for c in customer_ids],
            "customer_profile__age": [c * age_multiply for c in customer_ids],
            "customer_driver_combined__trips":
            [d * c for (d, c) in zip(driver_ids, customer_ids)],
        }
        # Requested column order
        ordered_column = [
            "driver",
            "customer",
            "driver_locations__lon",
            "driver_locations__lat",
            "customer_profile__avg_orders_day",
            "customer_profile__name",
            "customer_profile__age",
            "customer_driver_combined__trips",
        ]
        expected_df = pd.DataFrame(
            {k: reversed(v) for (k, v) in df_dict.items()})
        assert_frame_equal(result_df[ordered_column], expected_df)
def online_read(
    self,
    config: RepoConfig,
    table: FeatureView,
    entity_keys: List[EntityKeyProto],
    requested_features: Optional[List[str]] = None,
) -> List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]]:
    """
    Retrieve feature values from the online DynamoDB store.

    Args:
        config: The RepoConfig for the current FeatureStore.
        table: Feast FeatureView.
        entity_keys: a list of entity keys that should be read from the FeatureStore.
        requested_features: currently unused by this implementation; all stored
            feature values for each row are returned.

    Returns:
        One (event_ts, features) tuple per entry in ``entity_keys``, in the
        same order; entities not found in DynamoDB yield ``(None, None)``.
    """
    online_config = config.online_store
    assert isinstance(online_config, DynamoDBOnlineStoreConfig)
    dynamodb_resource = self._get_dynamodb_resource(
        online_config.region, online_config.endpoint_url
    )
    table_instance = dynamodb_resource.Table(
        _get_table_name(online_config, config, table)
    )
    result: List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]] = []
    entity_ids = [compute_entity_id(entity_key) for entity_key in entity_keys]
    batch_size = online_config.batch_size
    entity_ids_iter = iter(entity_ids)
    while True:
        batch = list(itertools.islice(entity_ids_iter, batch_size))
        # No more items to request
        if len(batch) == 0:
            break
        batch_entity_ids = {
            table_instance.name: {
                "Keys": [{"entity_id": entity_id} for entity_id in batch]
            }
        }
        with tracing_span(name="remote_call"):
            # NOTE(review): batch_get_item may return UnprocessedKeys under
            # throttling; those would show up here as missing rows (None, None)
            # rather than being retried — confirm whether a retry is needed.
            response = dynamodb_resource.batch_get_item(
                RequestItems=batch_entity_ids
            )
        # Collect this batch's rows separately so padding below is computed
        # per batch. The previous code padded with
        # ``len(batch) - len(result)`` against the CUMULATIVE result list,
        # which is <= 0 for every batch after the first, so missing entities
        # in later batches were silently dropped and the output became
        # misaligned with ``entity_keys``.
        batch_result: List[
            Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]
        ] = []
        # Guard against a response with no "Responses" key at all.
        table_responses = response.get("Responses", {}).get(table_instance.name)
        if table_responses:
            table_responses = self._sort_dynamodb_response(
                table_responses, entity_ids
            )
            entity_idx = 0
            for tbl_res in table_responses:
                entity_id = tbl_res["entity_id"]
                # Entities missing from the response are interleaved as
                # (None, None) so positions stay aligned with the batch.
                while entity_id != batch[entity_idx]:
                    batch_result.append((None, None))
                    entity_idx += 1
                res = {}
                for feature_name, value_bin in tbl_res["values"].items():
                    val = ValueProto()
                    val.ParseFromString(value_bin.value)
                    res[feature_name] = val
                batch_result.append(
                    (datetime.fromisoformat(tbl_res["event_ts"]), res)
                )
                entity_idx += 1
        # Not all entities in a batch may have responses.
        # Pad with (None, None) for batch entities that were not found.
        batch_result.extend(((None, None),) * (len(batch) - len(batch_result)))
        result.extend(batch_result)
    return result
def basic_rw_test(store: FeatureStore, view_name: str, feature_service_name: Optional[str] = None) -> None: """ This is a provider-independent test suite for reading and writing from the online store, to be used by provider-specific tests. """ table = store.get_feature_view(name=view_name) provider = store._get_provider() entity_key = EntityKeyProto(join_keys=["driver_id"], entity_values=[ValueProto(int64_val=1)]) def _driver_rw_test(event_ts, created_ts, write, expect_read): """A helper function to write values and read them back""" write_lat, write_lon = write expect_lat, expect_lon = expect_read provider.online_write_batch( config=store.config, table=table, data=[( entity_key, { "lat": ValueProto(double_val=write_lat), "lon": ValueProto(string_val=write_lon), }, event_ts, created_ts, )], progress=None, ) if feature_service_name: entity_dict = {"driver_id": 1} feature_service = store.get_feature_service(feature_service_name) features = store.get_online_features(features=feature_service, entity_rows=[entity_dict ]).to_dict() assert len(features["driver_id"]) == 1 assert features["lon"][0] == expect_lon assert abs(features["lat"][0] - expect_lat) < 1e-6 else: read_rows = provider.online_read(config=store.config, table=table, entity_keys=[entity_key]) assert len(read_rows) == 1 _, val = read_rows[0] assert val["lon"].string_val == expect_lon assert abs(val["lat"].double_val - expect_lat) < 1e-6 """ 1. Basic test: write value, read it back """ time_1 = datetime.utcnow() _driver_rw_test(event_ts=time_1, created_ts=time_1, write=(1.1, "3.1"), expect_read=(1.1, "3.1")) """ Values with an new event_ts should overwrite older ones """ time_3 = datetime.utcnow() _driver_rw_test( event_ts=time_1 + timedelta(hours=1), created_ts=time_3, write=(1123, "NEWER"), expect_read=(1123, "NEWER"), )
def test_bigquery_ingestion_correctness(self): # create dataset ts = pd.Timestamp.now(tz="UTC").round("ms") checked_value = ( random.random() ) # random value so test doesn't still work if no values written to online store data = { "id": [1, 2, 1], "value": [0.1, 0.2, checked_value], "ts_1": [ts - timedelta(minutes=2), ts, ts], "created_ts": [ts, ts, ts], } df = pd.DataFrame.from_dict(data) # load dataset into BigQuery job_config = bigquery.LoadJobConfig() table_id = ( f"{self.gcp_project}.{self.bigquery_dataset}.correctness_{int(time.time())}" ) job = self.client.load_table_from_dataframe(df, table_id, job_config=job_config) job.result() # create FeatureView fv = FeatureView( name="test_bq_correctness", entities=["driver_id"], features=[Feature("value", ValueType.FLOAT)], ttl=timedelta(minutes=5), input=BigQuerySource( event_timestamp_column="ts", table_ref=table_id, created_timestamp_column="created_ts", field_mapping={ "ts_1": "ts", "id": "driver_id" }, date_partition_column="", ), ) config = RepoConfig( metadata_store="./metadata.db", project="default", provider="gcp", online_store=OnlineStoreConfig( local=LocalOnlineStoreConfig("online_store.db")), ) fs = FeatureStore(config=config) fs.apply([fv]) # run materialize() fs.materialize( ["test_bq_correctness"], datetime.utcnow() - timedelta(minutes=5), datetime.utcnow() - timedelta(minutes=0), ) # check result of materialize() entity_key = EntityKeyProto(entity_names=["driver_id"], entity_values=[ValueProto(int64_val=1)]) t, val = fs._get_provider().online_read("default", fv, entity_key) assert abs(val["value"].double_val - checked_value) < 1e-6