def test_get_online_features(self, mocked_client, auth_metadata, mocker,
                             get_online_features_fields_statuses):
    """Round-trip get_online_features through a mocked GetOnlineFeaturesV2 stub.

    Builds the exact request proto the client is expected to send, patches the
    serving stub to return a canned response, and verifies both the outgoing
    request (including auth metadata and timeout) and the decoded fields and
    statuses of one returned row.
    """
    ROW_COUNT = 100
    mocked_client._serving_service_stub = Serving.ServingServiceStub(
        grpc.insecure_channel(""))

    # The request proto we expect the client to construct internally.
    request = GetOnlineFeaturesRequestV2(project="driver_project")
    request.features.extend([
        FeatureRefProto(feature_table="driver", name=feature_name)
        for feature_name in ("age", "rating", "null_value")
    ])

    receive_response = GetOnlineFeaturesResponse()
    entity_rows = []
    for idx in range(ROW_COUNT):
        row_fixture = get_online_features_fields_statuses[idx]
        fields = row_fixture[0]
        statuses = row_fixture[1]
        driver_key = {"driver_id": ValueProto.Value(int64_val=idx)}
        request.entity_rows.append(
            GetOnlineFeaturesRequestV2.EntityRow(fields=driver_key))
        entity_rows.append(driver_key)
        receive_response.field_values.append(
            GetOnlineFeaturesResponse.FieldValues(fields=fields,
                                                  statuses=statuses))

    mocker.patch.object(
        mocked_client._serving_service_stub,
        "GetOnlineFeaturesV2",
        return_value=receive_response,
    )

    got_response = mocked_client.get_online_features(
        entity_rows=entity_rows,
        feature_refs=["driver:age", "driver:rating", "driver:null_value"],
        project="driver_project",
    )  # type: GetOnlineFeaturesResponse

    # The stub must have been called with the exact request, auth metadata
    # attached, and the client's default 10s timeout.
    mocked_client._serving_service_stub.GetOnlineFeaturesV2.assert_called_with(
        request, metadata=auth_metadata, timeout=10)

    # Spot-check row 1: values and statuses per feature ref.
    row = got_response.field_values[1]
    expected = {
        "driver_id": (ValueProto.Value(int64_val=1),
                      GetOnlineFeaturesResponse.FieldStatus.PRESENT),
        "driver:age": (ValueProto.Value(int64_val=1),
                       GetOnlineFeaturesResponse.FieldStatus.PRESENT),
        "driver:rating": (ValueProto.Value(string_val="9"),
                          GetOnlineFeaturesResponse.FieldStatus.PRESENT),
        "driver:null_value": (ValueProto.Value(),
                              GetOnlineFeaturesResponse.FieldStatus.NULL_VALUE),
    }
    assert all(
        row.fields[name] == value and row.statuses[name] == status
        for name, (value, status) in expected.items()
    )
def _entity_row_to_field_values(
    row: GetOnlineFeaturesRequestV2.EntityRow,
) -> GetOnlineFeaturesResponse.FieldValues:
    """Seed a response FieldValues message from an entity row.

    Every entity field in ``row`` is copied into the result and its status is
    marked PRESENT, so feature lookups can later be merged into the same row.
    """
    field_values = GetOnlineFeaturesResponse.FieldValues()
    for name, value in row.fields.items():
        field_values.fields[name].CopyFrom(value)
        field_values.statuses[name] = (
            GetOnlineFeaturesResponse.FieldStatus.PRESENT)
    return field_values
def _augment_response_with_on_demand_transforms(
    self,
    feature_refs: List[str],
    full_feature_names: bool,
    initial_response: OnlineResponse,
    result_rows: List[GetOnlineFeaturesResponse.FieldValues],
) -> OnlineResponse:
    """Apply on demand feature view transforms and merge them into the response.

    Args:
        feature_refs: All requested feature refs ("view:feature"); refs whose
            view is an on demand feature view are transformed here.
        full_feature_names: If True, output names are "<view>__<feature>",
            otherwise just "<feature>".
        initial_response: The response built from regular feature views.
        result_rows: The mutable per-entity rows backing ``initial_response``;
            transformed values are written into these in place.

    Returns:
        OnlineResponse containing the original rows plus any on demand
        transformed features, all marked PRESENT.
    """
    all_on_demand_feature_views = {
        view.name: view
        for view in self._registry.list_on_demand_feature_views(
            project=self.project, allow_cache=True)
    }
    if len(all_on_demand_feature_views) == 0:
        return initial_response

    # Group the requested ODFV feature names by their owning view.
    odfv_feature_refs = defaultdict(list)
    for feature_ref in feature_refs:
        view_name, feature_name = feature_ref.split(":")
        if view_name in all_on_demand_feature_views:
            odfv_feature_refs[view_name].append(feature_name)

    # Fix: skip the (potentially expensive) DataFrame conversion entirely when
    # none of the requested features belongs to an on demand feature view.
    # The original converted to a DataFrame before knowing whether it was needed.
    if not odfv_feature_refs:
        return initial_response

    initial_response_df = initial_response.to_df()

    # Apply on demand transformations
    for odfv_name, _feature_refs in odfv_feature_refs.items():
        odfv = all_on_demand_feature_views[odfv_name]
        transformed_features_df = odfv.get_transformed_features_df(
            full_feature_names, initial_response_df)
        # Fix: the selected columns are independent of the row, so compute the
        # subset once per view instead of once per result row.
        selected_subset = [
            f for f in transformed_features_df.columns if f in _feature_refs
        ]
        for row_idx, result_row in enumerate(result_rows):
            for transformed_feature in selected_subset:
                transformed_feature_name = (
                    f"{odfv.name}__{transformed_feature}"
                    if full_feature_names else transformed_feature)
                proto_value = python_value_to_proto_value(
                    transformed_features_df[transformed_feature].
                    values[row_idx])
                result_row.fields[transformed_feature_name].CopyFrom(
                    proto_value)
                result_row.statuses[
                    transformed_feature_name] = (
                        GetOnlineFeaturesResponse.FieldStatus.PRESENT)
    return OnlineResponse(
        GetOnlineFeaturesResponse(field_values=result_rows))
def _get_online_features(
    self,
    entity_rows: List[GetOnlineFeaturesRequestV2.EntityRow],
    feature_refs: List[str],
    project: str,
) -> GetOnlineFeaturesResponse:
    """Read the latest values of the requested features for each entity row.

    Rows are pre-populated with the entity key fields; each grouped feature
    view is then read from the provider and merged in, with per-feature
    PRESENT / NOT_FOUND statuses.
    """
    provider = self._get_provider()

    entity_keys = [_entity_row_to_key(row) for row in entity_rows]
    result_rows: List[GetOnlineFeaturesResponse.FieldValues] = [
        _entity_row_to_field_values(row) for row in entity_rows
    ]

    registry = self._get_registry()
    all_feature_views = registry.list_feature_views(
        project=self.config.project)

    for table, requested_features in _group_refs(feature_refs,
                                                 all_feature_views):
        read_rows = provider.online_read(
            project=project,
            table=table,
            entity_keys=entity_keys,
        )
        for row_idx, (row_ts, feature_data) in enumerate(read_rows):
            result_row = result_rows[row_idx]
            if feature_data is None:
                # Nothing stored for this entity key: mark every requested
                # feature of this view as missing.
                for feature_name in requested_features:
                    result_row.statuses[
                        f"{table.name}:{feature_name}"
                    ] = GetOnlineFeaturesResponse.FieldStatus.NOT_FOUND
                continue
            for feature_name, value in feature_data.items():
                if feature_name not in requested_features:
                    continue
                feature_ref = f"{table.name}:{feature_name}"
                result_row.fields[feature_ref].CopyFrom(value)
                result_row.statuses[
                    feature_ref] = (
                        GetOnlineFeaturesResponse.FieldStatus.PRESENT)

    return GetOnlineFeaturesResponse(field_values=result_rows)
def get_online_features(
    self,
    feature_refs: List[str],
    entity_rows: List[Dict[str, Any]],
) -> OnlineResponse:
    """
    Retrieves the latest online feature data.

    Note: This method will download the full feature registry the first time it is run. If you are using a
    remote registry like GCS or S3 then that may take a few seconds. The registry remains cached up to a TTL
    duration (which can be set to infinity). If the cached registry is stale (more time than the TTL has
    passed), then a new registry will be downloaded synchronously by this method. This download may
    introduce latency to online feature retrieval. In order to avoid synchronous downloads, please call
    refresh_registry() prior to the TTL being reached. Remember it is possible to set the cache TTL to
    infinity (cache forever).

    Args:
        feature_refs: List of feature references that will be returned for each entity.
            Each feature reference should have the following format:
            "feature_table:feature" where "feature_table" & "feature" refer to
            the feature and feature table names respectively.
            Only the feature name is required.
        entity_rows: A list of dictionaries where each key-value is an entity-name, entity-value pair.

    Returns:
        OnlineResponse containing the feature data in records.

    Raises:
        Exception: An entity name in ``entity_rows`` is not registered in the project.

    Examples:
        >>> from feast import FeatureStore
        >>>
        >>> store = FeatureStore(repo_path="...")
        >>> feature_refs = ["sales:daily_transactions"]
        >>> entity_rows = [{"customer_id": 0},{"customer_id": 1}]
        >>>
        >>> online_response = store.get_online_features(
        >>>     feature_refs, entity_rows)
        >>> online_response_dict = online_response.to_dict()
        >>> print(online_response_dict)
        {'sales:daily_transactions': [1.1,1.2], 'sales:customer_id': [0,1]}
    """
    provider = self._get_provider()
    entities = self.list_entities(allow_cache=True)
    # Translate caller-facing entity names into storage join keys.
    entity_name_to_join_key_map = {}
    for entity in entities:
        entity_name_to_join_key_map[entity.name] = entity.join_key

    # Re-key every entity row by join key; unknown entity names are an error.
    join_key_rows = []
    for row in entity_rows:
        join_key_row = {}
        for entity_name, entity_value in row.items():
            try:
                join_key = entity_name_to_join_key_map[entity_name]
            except KeyError:
                raise Exception(
                    f"Entity {entity_name} does not exist in project {self.project}"
                )
            join_key_row[join_key] = entity_value
        join_key_rows.append(join_key_row)

    entity_row_proto_list = _infer_online_entity_rows(join_key_rows)

    union_of_entity_keys = []
    result_rows: List[GetOnlineFeaturesResponse.FieldValues] = []
    for entity_row_proto in entity_row_proto_list:
        # Keys are used for provider lookups; the field-values rows are
        # pre-seeded with the entity values so features can be merged in.
        union_of_entity_keys.append(_entity_row_to_key(entity_row_proto))
        result_rows.append(_entity_row_to_field_values(entity_row_proto))

    all_feature_views = self._registry.list_feature_views(
        project=self.project, allow_cache=True)
    grouped_refs = _group_refs(feature_refs, all_feature_views)
    for table, requested_features in grouped_refs:
        # Filter the union of entity keys down to the ones this view uses.
        entity_keys = _get_table_entity_keys(table, union_of_entity_keys,
                                             entity_name_to_join_key_map)
        read_rows = provider.online_read(
            project=self.project,
            table=table,
            entity_keys=entity_keys,
        )
        for row_idx, read_row in enumerate(read_rows):
            row_ts, feature_data = read_row
            result_row = result_rows[row_idx]
            if feature_data is None:
                # No stored data for this entity key: every requested feature
                # of this view is reported as NOT_FOUND.
                for feature_name in requested_features:
                    feature_ref = f"{table.name}__{feature_name}"
                    result_row.statuses[
                        feature_ref] = GetOnlineFeaturesResponse.FieldStatus.NOT_FOUND
            else:
                for feature_name in feature_data:
                    feature_ref = f"{table.name}__{feature_name}"
                    if feature_name in requested_features:
                        result_row.fields[feature_ref].CopyFrom(
                            feature_data[feature_name])
                        result_row.statuses[
                            feature_ref] = GetOnlineFeaturesResponse.FieldStatus.PRESENT
    return OnlineResponse(
        GetOnlineFeaturesResponse(field_values=result_rows))
def get_online_features(
    self,
    features: Union[List[str], FeatureService],
    entity_rows: List[Dict[str, Any]],
    feature_refs: Optional[List[str]] = None,
    full_feature_names: bool = False,
) -> OnlineResponse:
    """
    Retrieves the latest online feature data.

    Note: This method will download the full feature registry the first time it is run. If you are using a
    remote registry like GCS or S3 then that may take a few seconds. The registry remains cached up to a TTL
    duration (which can be set to infinity). If the cached registry is stale (more time than the TTL has
    passed), then a new registry will be downloaded synchronously by this method. This download may
    introduce latency to online feature retrieval. In order to avoid synchronous downloads, please call
    refresh_registry() prior to the TTL being reached. Remember it is possible to set the cache TTL to
    infinity (cache forever).

    Args:
        features: List of feature references that will be returned for each entity.
            Each feature reference should have the following format:
            "feature_table:feature" where "feature_table" & "feature" refer to
            the feature and feature table names respectively.
            Only the feature name is required.
        entity_rows: A list of dictionaries where each key-value is an entity-name, entity-value pair.

    Returns:
        OnlineResponse containing the feature data in records.

    Raises:
        Exception: No entity with the specified name exists.

    Examples:
        Materialize all features into the online store over the interval
        from 3 hours ago to 10 minutes ago, and then retrieve these online features.

        >>> from feast import FeatureStore, RepoConfig
        >>> fs = FeatureStore(repo_path="feature_repo")
        >>> online_response = fs.get_online_features(
        ...     features=[
        ...         "driver_hourly_stats:conv_rate",
        ...         "driver_hourly_stats:acc_rate",
        ...         "driver_hourly_stats:avg_daily_trips",
        ...     ],
        ...     entity_rows=[{"driver_id": 1001}, {"driver_id": 1002}, {"driver_id": 1003}, {"driver_id": 1004}],
        ... )
        >>> online_response_dict = online_response.to_dict()
    """
    _feature_refs = self._get_features(features, feature_refs)
    provider = self._get_provider()

    # Map caller-facing entity names to the join keys used in storage.
    entity_name_to_join_key_map = {
        entity.name: entity.join_key
        for entity in self.list_entities(allow_cache=True)
    }

    # Re-key each entity row by join key; unknown names are an error.
    join_key_rows = []
    for row in entity_rows:
        join_key_row = {}
        for entity_name, entity_value in row.items():
            if entity_name not in entity_name_to_join_key_map:
                raise EntityNotFoundException(entity_name, self.project)
            join_key_row[
                entity_name_to_join_key_map[entity_name]] = entity_value
        join_key_rows.append(join_key_row)

    entity_row_proto_list = _infer_online_entity_rows(join_key_rows)

    # Keys drive the provider lookups; the rows are pre-seeded with entity
    # values so feature values can be merged into them below.
    union_of_entity_keys = [
        _entity_row_to_key(proto) for proto in entity_row_proto_list
    ]
    result_rows: List[GetOnlineFeaturesResponse.FieldValues] = [
        _entity_row_to_field_values(proto) for proto in entity_row_proto_list
    ]

    all_feature_views = self._registry.list_feature_views(
        project=self.project, allow_cache=True)
    _validate_feature_refs(_feature_refs, full_feature_names)

    for table, requested_features in _group_feature_refs(
            _feature_refs, all_feature_views):
        # Narrow the union of entity keys to the ones this view uses.
        entity_keys = _get_table_entity_keys(table, union_of_entity_keys,
                                             entity_name_to_join_key_map)
        read_rows = provider.online_read(
            config=self.config,
            table=table,
            entity_keys=entity_keys,
            requested_features=requested_features,
        )
        for row_idx, (row_ts, feature_data) in enumerate(read_rows):
            result_row = result_rows[row_idx]
            if feature_data is None:
                # No stored data for this key: report every requested feature
                # of this view as NOT_FOUND.
                for feature_name in requested_features:
                    feature_ref = (f"{table.name}__{feature_name}"
                                   if full_feature_names else feature_name)
                    result_row.statuses[feature_ref] = (
                        GetOnlineFeaturesResponse.FieldStatus.NOT_FOUND)
                continue
            for feature_name, value in feature_data.items():
                if feature_name not in requested_features:
                    continue
                feature_ref = (f"{table.name}__{feature_name}"
                               if full_feature_names else feature_name)
                result_row.fields[feature_ref].CopyFrom(value)
                result_row.statuses[feature_ref] = (
                    GetOnlineFeaturesResponse.FieldStatus.PRESENT)

    return OnlineResponse(
        GetOnlineFeaturesResponse(field_values=result_rows))
def get_online_features(
    self,
    features: Union[List[str], FeatureService],
    entity_rows: List[Dict[str, Any]],
    feature_refs: Optional[List[str]] = None,
    full_feature_names: bool = False,
) -> OnlineResponse:
    """
    Retrieves the latest online feature data.

    Note: This method will download the full feature registry the first time it is run. If you are using a
    remote registry like GCS or S3 then that may take a few seconds. The registry remains cached up to a TTL
    duration (which can be set to infinity). If the cached registry is stale (more time than the TTL has
    passed), then a new registry will be downloaded synchronously by this method. This download may
    introduce latency to online feature retrieval. In order to avoid synchronous downloads, please call
    refresh_registry() prior to the TTL being reached. Remember it is possible to set the cache TTL to
    infinity (cache forever).

    Args:
        features: List of feature references that will be returned for each entity.
            Each feature reference should have the following format:
            "feature_table:feature" where "feature_table" & "feature" refer to
            the feature and feature table names respectively.
            Only the feature name is required.
        entity_rows: A list of dictionaries where each key-value is an entity-name, entity-value pair.

    Returns:
        OnlineResponse containing the feature data in records.

    Raises:
        Exception: No entity with the specified name exists.

    Examples:
        Materialize all features into the online store over the interval
        from 3 hours ago to 10 minutes ago, and then retrieve these online features.

        >>> from feast import FeatureStore, RepoConfig
        >>> fs = FeatureStore(repo_path="feature_repo")
        >>> online_response = fs.get_online_features(
        ...     features=[
        ...         "driver_hourly_stats:conv_rate",
        ...         "driver_hourly_stats:acc_rate",
        ...         "driver_hourly_stats:avg_daily_trips",
        ...     ],
        ...     entity_rows=[{"driver_id": 1001}, {"driver_id": 1002}, {"driver_id": 1003}, {"driver_id": 1004}],
        ... )
        >>> online_response_dict = online_response.to_dict()
    """
    _feature_refs = self._get_features(features, feature_refs)
    all_feature_views = self._list_feature_views(allow_cache=True,
                                                 hide_dummy_entity=False)
    all_on_demand_feature_views = self._registry.list_on_demand_feature_views(
        project=self.project, allow_cache=True)
    _validate_feature_refs(_feature_refs, full_feature_names)
    # Split requested refs into regular feature views and on-demand ones.
    grouped_refs, grouped_odfv_refs = _group_feature_refs(
        _feature_refs, all_feature_views, all_on_demand_feature_views)
    if len(grouped_odfv_refs) > 0:
        log_event(UsageEvent.GET_ONLINE_FEATURES_WITH_ODFV)

    feature_views = list(view for view, _ in grouped_refs)
    # True when any requested view has no real entity (uses the dummy entity).
    entityless_case = DUMMY_ENTITY_NAME in [
        entity_name for feature_view in feature_views
        for entity_name in feature_view.entities
    ]

    provider = self._get_provider()
    entities = self._list_entities(allow_cache=True, hide_dummy_entity=False)
    # Map caller-facing entity names to the join keys used in storage.
    entity_name_to_join_key_map = {}
    for entity in entities:
        entity_name_to_join_key_map[entity.name] = entity.join_key

    # Request-data "features" that ODFV transformations need as input.
    needed_request_data_features = self._get_needed_request_data_features(
        grouped_odfv_refs)

    join_key_rows = []
    request_data_features: Dict[str, List[Any]] = {}
    # Entity rows may be either entities or request data.
    for row in entity_rows:
        join_key_row = {}
        for entity_name, entity_value in row.items():
            # Found request data
            if entity_name in needed_request_data_features:
                if entity_name not in request_data_features:
                    request_data_features[entity_name] = []
                request_data_features[entity_name].append(entity_value)
                continue
            try:
                join_key = entity_name_to_join_key_map[entity_name]
            except KeyError:
                raise EntityNotFoundException(entity_name, self.project)
            join_key_row[join_key] = entity_value
            if entityless_case:
                # Entityless views are stored under a fixed dummy key.
                join_key_row[DUMMY_ENTITY_ID] = DUMMY_ENTITY_VAL
        if len(join_key_row) > 0:
            # May be empty if this entity row was request data
            join_key_rows.append(join_key_row)

    # Every request-data input needed by the ODFVs must have been supplied.
    if len(needed_request_data_features) != len(
            request_data_features.keys()):
        raise RequestDataNotFoundInEntityRowsException(
            feature_names=needed_request_data_features)

    entity_row_proto_list = _infer_online_entity_rows(join_key_rows)

    union_of_entity_keys: List[EntityKeyProto] = []
    result_rows: List[GetOnlineFeaturesResponse.FieldValues] = []
    for entity_row_proto in entity_row_proto_list:
        # Create a list of entity keys to filter down for each feature view at lookup time.
        union_of_entity_keys.append(_entity_row_to_key(entity_row_proto))
        # Also create entity values to append to the result
        result_rows.append(_entity_row_to_field_values(entity_row_proto))

    # Add more feature values to the existing result rows for the request data features
    for feature_name, feature_values in request_data_features.items():
        for row_idx, feature_value in enumerate(feature_values):
            result_row = result_rows[row_idx]
            result_row.fields[feature_name].CopyFrom(
                python_value_to_proto_value(feature_value))
            result_row.statuses[
                feature_name] = GetOnlineFeaturesResponse.FieldStatus.PRESENT

    for table, requested_features in grouped_refs:
        self._populate_result_rows_from_feature_view(
            entity_name_to_join_key_map,
            full_feature_names,
            provider,
            requested_features,
            result_rows,
            table,
            union_of_entity_keys,
        )

    initial_response = OnlineResponse(
        GetOnlineFeaturesResponse(field_values=result_rows))
    # ODFV transforms are applied last, on top of the assembled rows.
    return self._augment_response_with_on_demand_transforms(
        _feature_refs, full_feature_names, initial_response, result_rows)