예제 #1
0
    def test_register_feature_set(self, sqlite_store):
        """Registering a feature-set spec with the SQLite store should not raise.

        Builds ``my-feature-set`` with two INT64 features and one INT64 entity,
        sets its version to 1, converts it to a proto and passes the ``spec``
        part to ``sqlite_store.register_feature_set``. Nothing is asserted
        about the store's state: the test passes iff registration completes
        without raising.
        """
        fs = FeatureSet("my-feature-set")
        fs.add(Feature(name="my-feature-1", dtype=ValueType.INT64))
        fs.add(Feature(name="my-feature-2", dtype=ValueType.INT64))
        fs.add(Entity(name="my-entity-1", dtype=ValueType.INT64))
        # Set the version via the private attribute before serialising.
        fs._version = 1
        feature_set_spec_proto = fs.to_proto().spec

        sqlite_store.register_feature_set(feature_set_spec_proto)
        # TODO(review): assert on the store's contents, and exercise feature
        # row upserts, once the store exposes an API suitable for testing them.
예제 #2
0
    def test_get_online_features(self, mock_client, mocker):
        """Online feature values returned by the serving stub reach the caller.

        The serving stub's ``GetOnlineFeatures`` is patched to return a canned
        response of 300 identical rows holding ``feature_1``..``feature_9``
        (``feature_N`` is int64 ``N``), one row per ``customer_id`` entity row.
        The test checks the first returned row exposes ``feature_1 == 1`` and
        ``feature_9 == 9``.
        """
        row_count = 300
        feature_numbers = range(1, 10)

        def feature_id(number):
            return "feature_set_1:1:feature_" + str(number)

        mock_client._serving_service_stub = Serving.ServingServiceStub(
            grpc.insecure_channel("")
        )

        # Every row of the canned response shares this one FieldValues message.
        one_row = GetOnlineFeaturesResponse.FieldValues(
            fields={
                feature_id(number): ValueProto.Value(int64_val=number)
                for number in feature_numbers
            }
        )

        canned_response = GetOnlineFeaturesResponse()
        entity_rows = []
        for customer_id in range(1, row_count + 1):
            canned_response.field_values.append(one_row)
            entity_rows.append(
                GetOnlineFeaturesRequest.EntityRow(
                    fields={"customer_id": ValueProto.Value(int64_val=customer_id)}
                )
            )

        mocker.patch.object(
            mock_client._serving_service_stub,
            "GetOnlineFeatures",
            return_value=canned_response,
        )

        result = mock_client.get_online_features(
            entity_rows=entity_rows,
            feature_ids=[feature_id(number) for number in feature_numbers],
        )  # type: GetOnlineFeaturesResponse

        first_row = result.field_values[0].fields
        assert first_row[feature_id(1)].int64_val == 1
        assert first_row[feature_id(9)].int64_val == 9
예제 #3
0
 def get_online_features_fields_statuses(self):
     """Build 100 canned ``(fields, statuses)`` pairs for driver lookups.

     Row ``i`` (0-based) carries ``driver_id = i`` plus three driver features:
     ``driver:age`` (int64 1), ``driver:rating`` (string "9") and
     ``driver:null_value`` (an empty Value). The matching status map marks
     every field PRESENT except ``driver:null_value``, which is NULL_VALUE.

     Returns:
         A list of 100 ``(fields, statuses)`` tuples ordered by driver id.
     """
     row_count = 100
     status = GetOnlineFeaturesResponse.FieldStatus

     def make_row(driver_id):
         # Fresh dicts (and Value messages) per row; nothing is shared.
         fields = {
             "driver_id": ValueProto.Value(int64_val=driver_id),
             "driver:age": ValueProto.Value(int64_val=1),
             "driver:rating": ValueProto.Value(string_val="9"),
             "driver:null_value": ValueProto.Value(),
         }
         statuses = {
             "driver_id": status.PRESENT,
             "driver:age": status.PRESENT,
             "driver:rating": status.PRESENT,
             "driver:null_value": status.NULL_VALUE,
         }
         return fields, statuses

     return [make_row(driver_id) for driver_id in range(row_count)]
예제 #4
0
    def test_get_online_features(self, mocked_client, auth_metadata, mocker,
                                 get_online_features_fields_statuses):
        """Round-trip a GetOnlineFeaturesV2 call through the client.

        The serving stub's ``GetOnlineFeaturesV2`` is patched to return one
        response row per ``(fields, statuses)`` pair from the
        ``get_online_features_fields_statuses`` fixture. The test checks that:

        * the client sends the expected V2 request (project, three ``driver``
          feature refs, one ``driver_id`` entity row per fixture row) with
          ``metadata=auth_metadata`` and ``timeout=10``;
        * the values and statuses of returned row 1 are passed through
          unchanged, including the NULL_VALUE status of ``driver:null_value``.
        """
        mocked_client._serving_service_stub = Serving.ServingServiceStub(
            grpc.insecure_channel(""))

        request = GetOnlineFeaturesRequestV2(project="driver_project")
        request.features.extend([
            FeatureRefProto(feature_table="driver", name="age"),
            FeatureRefProto(feature_table="driver", name="rating"),
            FeatureRefProto(feature_table="driver", name="null_value"),
        ])

        receive_response = GetOnlineFeaturesResponse()
        entity_rows = []
        # Drive the loop from the fixture itself rather than a separately
        # hard-coded row count, so the two cannot drift apart.
        for row_number, (fields, statuses) in enumerate(
                get_online_features_fields_statuses):
            request.entity_rows.append(
                GetOnlineFeaturesRequestV2.EntityRow(
                    fields={
                        "driver_id": ValueProto.Value(int64_val=row_number)
                    }))
            entity_rows.append(
                {"driver_id": ValueProto.Value(int64_val=row_number)})
            receive_response.field_values.append(
                GetOnlineFeaturesResponse.FieldValues(fields=fields,
                                                      statuses=statuses))

        mocker.patch.object(
            mocked_client._serving_service_stub,
            "GetOnlineFeaturesV2",
            return_value=receive_response,
        )
        got_response = mocked_client.get_online_features(
            entity_rows=entity_rows,
            feature_refs=["driver:age", "driver:rating", "driver:null_value"],
            project="driver_project",
        )  # type: GetOnlineFeaturesResponse
        mocked_client._serving_service_stub.GetOnlineFeaturesV2.assert_called_with(
            request, metadata=auth_metadata, timeout=10)

        got_fields = got_response.field_values[1].fields
        got_statuses = got_response.field_values[1].statuses
        present = GetOnlineFeaturesResponse.FieldStatus.PRESENT
        null_value = GetOnlineFeaturesResponse.FieldStatus.NULL_VALUE
        # One assert per check so a failure names the offending field instead
        # of reporting a single opaque boolean chain.
        assert got_fields["driver_id"] == ValueProto.Value(int64_val=1)
        assert got_statuses["driver_id"] == present
        assert got_fields["driver:age"] == ValueProto.Value(int64_val=1)
        assert got_statuses["driver:age"] == present
        assert got_fields["driver:rating"] == ValueProto.Value(string_val="9")
        assert got_statuses["driver:rating"] == present
        assert got_fields["driver:null_value"] == ValueProto.Value()
        assert got_statuses["driver:null_value"] == null_value
예제 #5
0
 def int_val(x):
     """Return *x* wrapped in a ``ValueProto.Value`` via its ``int64_val`` field."""
     wrapped = ValueProto.Value(int64_val=x)
     return wrapped
예제 #6
0
    def GetOnlineFeatures(self, request: GetOnlineFeaturesRequest, context):
        """Return a fixed response: one ``feature_set_1`` data set of 3 rows.

        ``request`` and ``context`` are not read. Each of the three feature
        rows belongs to ``feature_set_1``, has a default ``Timestamp()`` event
        time and holds ``feature_1``..``feature_3``, all set to float 1.2.
        """

        def make_row():
            # Build a fresh FeatureRow each call; the rows are identical.
            return FeatureRowProto.FeatureRow(
                feature_set="feature_set_1",
                event_timestamp=Timestamp(),
                fields=[
                    FieldProto.Field(
                        name=field_name,
                        value=ValueProto.Value(float_val=1.2),
                    )
                    for field_name in ("feature_1", "feature_2", "feature_3")
                ],
            )

        data_set = GetOnlineFeaturesResponse.FeatureDataSet(
            name="feature_set_1",
            feature_rows=[make_row() for _ in range(3)],
        )
        return GetOnlineFeaturesResponse(feature_data_sets=[data_set])
예제 #7
0
    def test_get_online_features(self, mocked_client, mocker):
        """Check the V1 request the client sends and the keys of its result.

        The serving stub's ``GetOnlineFeatures`` is patched to return 300 rows
        whose feature keys carry the ``driver_project/`` prefix. The test
        checks that:

        * ``get_online_features`` forwards exactly ``request`` (project-scoped
          feature refs plus one ``driver_id`` entity row per row) to the stub;
        * the first returned row is keyed by the feature refs as the caller
          passed them (``driver:age``, ``rating``, ``null_value``) and carries
          the values and statuses from the stub response.
        """
        ROW_COUNT = 300

        mocked_client._serving_service_stub = Serving.ServingServiceStub(
            grpc.insecure_channel(""))

        def int_val(x):
            return ValueProto.Value(int64_val=x)

        request = GetOnlineFeaturesRequest()
        request.features.extend([
            FeatureRefProto(project="driver_project",
                            feature_set="driver",
                            name="age"),
            FeatureRefProto(project="driver_project", name="rating"),
            FeatureRefProto(project="driver_project", name="null_value"),
        ])
        receive_response = GetOnlineFeaturesResponse()
        for row_number in range(1, ROW_COUNT + 1):
            request.entity_rows.append(
                GetOnlineFeaturesRequest.EntityRow(
                    fields={"driver_id": int_val(row_number)}))
            # The stub answers with project-prefixed keys.
            field_values = GetOnlineFeaturesResponse.FieldValues(
                fields={
                    "driver_id": int_val(row_number),
                    "driver_project/driver:age": int_val(1),
                    "driver_project/rating": int_val(9),
                    "driver_project/null_value": ValueProto.Value(),
                },
                statuses={
                    "driver_id":
                    GetOnlineFeaturesResponse.FieldStatus.PRESENT,
                    "driver_project/driver:age":
                    GetOnlineFeaturesResponse.FieldStatus.PRESENT,
                    "driver_project/rating":
                    GetOnlineFeaturesResponse.FieldStatus.PRESENT,
                    "driver_project/null_value":
                    GetOnlineFeaturesResponse.FieldStatus.NULL_VALUE,
                },
            )
            receive_response.field_values.append(field_values)

        mocker.patch.object(
            mocked_client._serving_service_stub,
            "GetOnlineFeatures",
            return_value=receive_response,
        )
        got_response = mocked_client.get_online_features(
            entity_rows=request.entity_rows,
            feature_refs=["driver:age", "rating", "null_value"],
            project="driver_project",
        )  # type: GetOnlineFeaturesResponse
        mocked_client._serving_service_stub.GetOnlineFeatures.assert_called_with(
            request)

        got_fields = got_response.field_values[0].fields
        got_statuses = got_response.field_values[0].statuses
        present = GetOnlineFeaturesResponse.FieldStatus.PRESENT
        null_value = GetOnlineFeaturesResponse.FieldStatus.NULL_VALUE
        # One assert per check so a failure names the offending field instead
        # of reporting a single opaque boolean chain.
        assert got_fields["driver_id"] == int_val(1)
        assert got_statuses["driver_id"] == present
        assert got_fields["driver:age"] == int_val(1)
        assert got_statuses["driver:age"] == present
        assert got_fields["rating"] == int_val(9)
        assert got_statuses["rating"] == present
        assert got_fields["null_value"] == ValueProto.Value()
        assert got_statuses["null_value"] == null_value