def test_get_online_features(
    self,
    mocked_client,
    auth_metadata,
    mocker,
    get_online_features_fields_statuses,
):
    """
    Verify the request sent by ``get_online_features`` and the response it returns.

    The serving stub's ``GetOnlineFeaturesV2`` is patched to return a canned
    response assembled from the ``get_online_features_fields_statuses``
    fixture. The test then checks that the client called the stub with the
    expected request proto, auth metadata and timeout, and that the second
    returned row carries the expected values and field statuses.
    """
    ROW_COUNT = 100
    feature_names = ("age", "rating", "null_value")

    mocked_client._serving_service_stub = Serving.ServingServiceStub(
        grpc.insecure_channel("")
    )

    # Request proto the client is expected to build from the arguments below.
    expected_request = GetOnlineFeaturesRequestV2(project="driver_project")
    expected_request.features.extend(
        [
            FeatureRefProto(feature_table="driver", name=feature_name)
            for feature_name in feature_names
        ]
    )

    canned_response = GetOnlineFeaturesResponse()
    entity_rows = []
    for row_number in range(ROW_COUNT):
        row_fixture = get_online_features_fields_statuses[row_number]
        expected_request.entity_rows.append(
            GetOnlineFeaturesRequestV2.EntityRow(
                fields={"driver_id": ValueProto.Value(int64_val=row_number)}
            )
        )
        entity_rows.append({"driver_id": ValueProto.Value(int64_val=row_number)})
        canned_response.field_values.append(
            GetOnlineFeaturesResponse.FieldValues(
                fields=row_fixture[0], statuses=row_fixture[1]
            )
        )

    mocker.patch.object(
        mocked_client._serving_service_stub,
        "GetOnlineFeaturesV2",
        return_value=canned_response,
    )

    got_response = mocked_client.get_online_features(
        entity_rows=entity_rows,
        feature_refs=[f"driver:{feature_name}" for feature_name in feature_names],
        project="driver_project",
    )  # type: GetOnlineFeaturesResponse

    mocked_client._serving_service_stub.GetOnlineFeaturesV2.assert_called_with(
        expected_request, metadata=auth_metadata, timeout=10
    )

    # Spot-check the row for driver_id == 1.
    second_row = got_response.field_values[1]
    got_fields = second_row.fields
    got_statuses = second_row.statuses
    present = GetOnlineFeaturesResponse.FieldStatus.PRESENT
    null_value = GetOnlineFeaturesResponse.FieldStatus.NULL_VALUE

    assert got_fields["driver_id"] == ValueProto.Value(int64_val=1)
    assert got_statuses["driver_id"] == present
    assert got_fields["driver:age"] == ValueProto.Value(int64_val=1)
    assert got_statuses["driver:age"] == present
    assert got_fields["driver:rating"] == ValueProto.Value(string_val="9")
    assert got_statuses["driver:rating"] == present
    assert got_fields["driver:null_value"] == ValueProto.Value()
    assert got_statuses["driver:null_value"] == null_value
def get_online_features(
    self,
    feature_refs: List[str],
    entity_rows: List[Dict[str, Any]],
    project: Optional[str] = None,
) -> OnlineResponse:
    """
    Retrieves the latest online feature data from Feast Serving.

    Args:
        feature_refs: List of feature references that will be returned for each entity.
            Each feature reference should have the following format:
            "feature_table:feature" where "feature_table" & "feature" refer to
            the feature and feature table names respectively.
            Only the feature name is required.
        entity_rows: A list of dictionaries where each key-value is an
            entity-name, entity-value pair.
        project: Optionally specify the project override. If specified, uses
            the given project for retrieval; otherwise the client's own
            ``self.project`` is sent. Overrides the projects specified in
            Feature References if they are also specified.

    Returns:
        OnlineResponse wrapping the response received from Feast Serving.
        Each EntityRow provided will yield one record, which contains
        data fields with data value and field status metadata (if included).

    Raises:
        grpc.RpcError: If the serving call fails. The raised error carries the
            failure details as its message and is chained to the original
            RPC error via ``__cause__``.

    Examples:
        >>> from feast import Client
        >>>
        >>> feast_client = Client(core_url="localhost:6565", serving_url="localhost:6566")
        >>> feature_refs = ["sales:daily_transactions"]
        >>> entity_rows = [{"customer_id": 0},{"customer_id": 1}]
        >>>
        >>> online_response = feast_client.get_online_features(
        >>>     feature_refs, entity_rows, project="my_project")
        >>> online_response_dict = online_response.to_dict()
        >>> print(online_response_dict)
        {'sales:daily_transactions': [1.1,1.2], 'sales:customer_id': [0,1]}
    """
    if self._telemetry_enabled:
        # Usage is reported only when the per-method call counter is a
        # multiple of 1000; the counter itself advances on every call.
        if self._telemetry_counter["get_online_features"] % 1000 == 0:
            log_usage(
                "get_online_features",
                self._telemetry_id,
                datetime.utcnow(),
                self.version(sdk_only=True),
            )
        self._telemetry_counter["get_online_features"] += 1

    try:
        response = self._serving_service.GetOnlineFeaturesV2(
            GetOnlineFeaturesRequestV2(
                features=_build_feature_references(feature_ref_strs=feature_refs),
                entity_rows=_infer_online_entity_rows(entity_rows),
                project=project if project is not None else self.project,
            ),
            timeout=self._config.getint(opt.GRPC_CONNECTION_TIMEOUT),
            metadata=self._get_grpc_metadata(),
        )
    except grpc.RpcError as e:
        # Keep the exception type callers already catch, but chain the
        # original error explicitly so its status code and trailing metadata
        # remain reachable through __cause__.
        raise grpc.RpcError(e.details()) from e

    return OnlineResponse(response)
def _infer_online_entity_rows(
    entity_rows: List[Dict[str, Any]]
) -> List[GetOnlineFeaturesRequestV2.EntityRow]:
    """
    Convert user-supplied entity rows into EntityRow protos.

    Values that are already ``Value`` protos are passed through untouched.
    Every other value has its Feast value type inferred and is converted to
    a proto value. The type inferred for an entity name must be the same in
    every row where that entity is given as a native Python value.

    Args:
        entity_rows: A list of dictionaries where each key-value is an
            entity-name, entity-value pair.

    Returns:
        A list of EntityRow protos, one per input dictionary, in input order.

    Raises:
        TypeError: If the same entity name is inferred to have different
            value types in different rows.
    """
    # First inferred value type seen for each entity name, across all rows.
    dtype_by_entity: Dict[str, Any] = {}
    row_protos = []

    for row in entity_rows:
        row_fields = {}
        for key, value in row.items():
            # Allow for feast.types.Value: already a proto, no inference needed.
            if isinstance(value, Value):
                row_fields[key] = value
                continue

            # Infer the specific type for this row and compare it with the
            # type recorded the first time this entity name was seen.
            current_dtype = python_type_to_feast_value_type(name=key, value=value)
            known_dtype = dtype_by_entity.setdefault(key, current_dtype)
            if current_dtype != known_dtype:
                raise TypeError(
                    f"Input entity {key} has mixed types, {current_dtype} and {known_dtype}. That is not allowed. "
                )
            row_fields[key] = _python_value_to_proto_value(current_dtype, value)

        row_protos.append(GetOnlineFeaturesRequestV2.EntityRow(fields=row_fields))

    return row_protos