def test_register_feature_set(self, sqlite_store):
    """Registering a feature set spec with the SQLite store should not raise.

    Builds a feature set with two INT64 features and one INT64 entity, sets
    its version to 1 and registers the resulting spec proto. The test passes
    as long as ``register_feature_set`` completes without an exception.
    """
    fs = FeatureSet("my-feature-set")
    fs.add(Feature(name="my-feature-1", dtype=ValueType.INT64))
    fs.add(Feature(name="my-feature-2", dtype=ValueType.INT64))
    fs.add(Entity(name="my-entity-1", dtype=ValueType.INT64))
    # Assign the private version attribute directly before serialising.
    fs._version = 1
    feature_set_spec_proto = fs.to_proto().spec

    # Smoke test: success means registration did not raise.
    sqlite_store.register_feature_set(feature_set_spec_proto)

    # TODO(review): also cover sqlite_store.upsert_feature_row() using a
    # FeatureRow whose field names and value types match this feature set
    # (my-feature-1 / my-feature-2 as int64 values).
def test_get_online_features(self, mock_client, mocker):
    """get_online_features returns the field values produced by the stub.

    The serving stub's ``GetOnlineFeatures`` is patched to return ROW_COUNT
    identical rows. In each row, feature ``feature_set_1:1:feature_<n>``
    holds int64 value ``<n>`` for n in 1..9. The test then checks the first
    and last feature of the first returned row.
    """
    ROW_COUNT = 300
    feature_ids = [
        "feature_set_1:1:feature_" + str(feature_num)
        for feature_num in range(1, 10)
    ]

    mock_client._serving_service_stub = Serving.ServingServiceStub(
        grpc.insecure_channel("")
    )

    # Every stubbed row carries the same values: feature_<n> -> int64 n.
    field_values = GetOnlineFeaturesResponse.FieldValues(
        fields={
            feature_id: ValueProto.Value(int64_val=feature_num)
            for feature_num, feature_id in enumerate(feature_ids, start=1)
        }
    )

    stub_response = GetOnlineFeaturesResponse()
    entity_rows = []
    for row_number in range(1, ROW_COUNT + 1):
        stub_response.field_values.append(field_values)
        entity_rows.append(
            GetOnlineFeaturesRequest.EntityRow(
                fields={"customer_id": ValueProto.Value(int64_val=row_number)}
            )
        )

    mocker.patch.object(
        mock_client._serving_service_stub,
        "GetOnlineFeatures",
        return_value=stub_response,
    )

    response = mock_client.get_online_features(
        entity_rows=entity_rows,
        feature_ids=feature_ids,
    )  # type: GetOnlineFeaturesResponse

    # Separate asserts so a failure names the exact feature that mismatched.
    first_row = response.field_values[0].fields
    assert first_row["feature_set_1:1:feature_1"].int64_val == 1
    assert first_row["feature_set_1:1:feature_9"].int64_val == 9
def get_online_features_fields_statuses(self):
    """Return 100 ``(fields, statuses)`` pairs describing mocked driver rows.

    Row ``i`` (0-based) has ``driver_id`` = ``i``, ``driver:age`` = 1,
    ``driver:rating`` = "9" and an empty ``driver:null_value``. The matching
    statuses mark the first three keys PRESENT and ``driver:null_value``
    NULL_VALUE.
    """
    ROW_COUNT = 100
    present = GetOnlineFeaturesResponse.FieldStatus.PRESENT
    null_status = GetOnlineFeaturesResponse.FieldStatus.NULL_VALUE

    def build_pair(driver_id):
        # Fresh dicts and Value messages for every row.
        row_fields = {
            "driver_id": ValueProto.Value(int64_val=driver_id),
            "driver:age": ValueProto.Value(int64_val=1),
            "driver:rating": ValueProto.Value(string_val="9"),
            "driver:null_value": ValueProto.Value(),
        }
        row_statuses = {
            "driver_id": present,
            "driver:age": present,
            "driver:rating": present,
            "driver:null_value": null_status,
        }
        return row_fields, row_statuses

    return [build_pair(driver_id) for driver_id in range(ROW_COUNT)]
def test_get_online_features(
    self, mocked_client, auth_metadata, mocker, get_online_features_fields_statuses
):
    """get_online_features sends the expected V2 request and relays the reply.

    Checks that the client calls ``GetOnlineFeaturesV2`` with the expected
    request, ``auth_metadata`` and ``timeout=10``. Also checks that row 1 of
    the stubbed response (taken from the
    ``get_online_features_fields_statuses`` fixture) comes back with its
    fields and statuses intact.
    """
    ROW_COUNT = 100

    mocked_client._serving_service_stub = Serving.ServingServiceStub(
        grpc.insecure_channel("")
    )

    # The request the client is expected to build from the arguments below.
    request = GetOnlineFeaturesRequestV2(project="driver_project")
    request.features.extend(
        [
            FeatureRefProto(feature_table="driver", name="age"),
            FeatureRefProto(feature_table="driver", name="rating"),
            FeatureRefProto(feature_table="driver", name="null_value"),
        ]
    )

    receive_response = GetOnlineFeaturesResponse()
    entity_rows = []
    for row_number in range(ROW_COUNT):
        fields, statuses = get_online_features_fields_statuses[row_number]
        request.entity_rows.append(
            GetOnlineFeaturesRequestV2.EntityRow(
                fields={"driver_id": ValueProto.Value(int64_val=row_number)}
            )
        )
        entity_rows.append({"driver_id": ValueProto.Value(int64_val=row_number)})
        receive_response.field_values.append(
            GetOnlineFeaturesResponse.FieldValues(fields=fields, statuses=statuses)
        )

    mocker.patch.object(
        mocked_client._serving_service_stub,
        "GetOnlineFeaturesV2",
        return_value=receive_response,
    )

    got_response = mocked_client.get_online_features(
        entity_rows=entity_rows,
        feature_refs=["driver:age", "driver:rating", "driver:null_value"],
        project="driver_project",
    )  # type: GetOnlineFeaturesResponse

    mocked_client._serving_service_stub.GetOnlineFeaturesV2.assert_called_with(
        request, metadata=auth_metadata, timeout=10
    )

    # One assert per field/status so a failure pinpoints the mismatch.
    present = GetOnlineFeaturesResponse.FieldStatus.PRESENT
    got_fields = got_response.field_values[1].fields
    got_statuses = got_response.field_values[1].statuses
    assert got_fields["driver_id"] == ValueProto.Value(int64_val=1)
    assert got_statuses["driver_id"] == present
    assert got_fields["driver:age"] == ValueProto.Value(int64_val=1)
    assert got_statuses["driver:age"] == present
    assert got_fields["driver:rating"] == ValueProto.Value(string_val="9")
    assert got_statuses["driver:rating"] == present
    assert got_fields["driver:null_value"] == ValueProto.Value()
    assert (
        got_statuses["driver:null_value"]
        == GetOnlineFeaturesResponse.FieldStatus.NULL_VALUE
    )
def int_val(x):
    """Wrap ``x`` in a ``ValueProto.Value`` with its ``int64_val`` field set."""
    wrapped = ValueProto.Value(int64_val=x)
    return wrapped
def GetOnlineFeatures(self, request: GetOnlineFeaturesRequest, context):
    """Stub handler that returns a fixed response and ignores its arguments.

    The response holds one feature data set named "feature_set_1" with three
    identical feature rows. Each row has ``feature_set="feature_set_1"``, a
    freshly constructed ``Timestamp()`` and fields feature_1..feature_3, each
    holding ``float_val=1.2``.
    """
    row_count = 3
    field_names = ("feature_1", "feature_2", "feature_3")

    def make_row():
        row_fields = [
            FieldProto.Field(name=field_name, value=ValueProto.Value(float_val=1.2))
            for field_name in field_names
        ]
        return FeatureRowProto.FeatureRow(
            feature_set="feature_set_1",
            event_timestamp=Timestamp(),
            fields=row_fields,
        )

    data_set = GetOnlineFeaturesResponse.FeatureDataSet(
        name="feature_set_1",
        feature_rows=[make_row() for _ in range(row_count)],
    )
    return GetOnlineFeaturesResponse(feature_data_sets=[data_set])
def test_get_online_features(self, mocked_client, mocker):
    """get_online_features builds the expected request and relays the reply.

    Checks that the patched ``GetOnlineFeatures`` stub is called with the
    request built here. Also checks that the first returned row exposes the
    stubbed values and statuses under the feature refs passed by the caller.
    """
    ROW_COUNT = 300
    mocked_client._serving_service_stub = Serving.ServingServiceStub(
        grpc.insecure_channel("")
    )

    def int_val(x):
        return ValueProto.Value(int64_val=x)

    request = GetOnlineFeaturesRequest()
    request.features.extend(
        [
            FeatureRefProto(project="driver_project", feature_set="driver", name="age"),
            FeatureRefProto(project="driver_project", name="rating"),
            FeatureRefProto(project="driver_project", name="null_value"),
        ]
    )

    receive_response = GetOnlineFeaturesResponse()
    for row_number in range(1, ROW_COUNT + 1):
        # (The original had a stray trailing comma here that built and
        # discarded a tuple.)
        request.entity_rows.append(
            GetOnlineFeaturesRequest.EntityRow(fields={"driver_id": int_val(row_number)})
        )
        field_values = GetOnlineFeaturesResponse.FieldValues(
            fields={
                "driver_id": int_val(row_number),
                "driver_project/driver:age": int_val(1),
                "driver_project/rating": int_val(9),
                "driver_project/null_value": ValueProto.Value(),
            },
            statuses={
                "driver_id": GetOnlineFeaturesResponse.FieldStatus.PRESENT,
                "driver_project/driver:age": GetOnlineFeaturesResponse.FieldStatus.PRESENT,
                "driver_project/rating": GetOnlineFeaturesResponse.FieldStatus.PRESENT,
                "driver_project/null_value": GetOnlineFeaturesResponse.FieldStatus.NULL_VALUE,
            },
        )
        receive_response.field_values.append(field_values)

    mocker.patch.object(
        mocked_client._serving_service_stub,
        "GetOnlineFeatures",
        return_value=receive_response,
    )

    got_response = mocked_client.get_online_features(
        entity_rows=request.entity_rows,
        feature_refs=["driver:age", "rating", "null_value"],
        project="driver_project",
    )  # type: GetOnlineFeaturesResponse

    mocked_client._serving_service_stub.GetOnlineFeatures.assert_called_with(request)

    # The stubbed keys carry a "driver_project/" prefix. The assertions below
    # expect the client to return them under the caller's un-prefixed refs.
    # One assert per field/status so a failure pinpoints the mismatch.
    present = GetOnlineFeaturesResponse.FieldStatus.PRESENT
    got_fields = got_response.field_values[0].fields
    got_statuses = got_response.field_values[0].statuses
    assert got_fields["driver_id"] == int_val(1)
    assert got_statuses["driver_id"] == present
    assert got_fields["driver:age"] == int_val(1)
    assert got_statuses["driver:age"] == present
    assert got_fields["rating"] == int_val(9)
    assert got_statuses["rating"] == present
    assert got_fields["null_value"] == ValueProto.Value()
    assert got_statuses["null_value"] == GetOnlineFeaturesResponse.FieldStatus.NULL_VALUE