예제 #1
0
    def test_get_online_features(self, mocked_client, auth_metadata, mocker,
                                 get_online_features_fields_statuses):
        """Verify ``get_online_features`` against a mocked V2 serving stub.

        The stub's ``GetOnlineFeaturesV2`` is patched to return a response
        built from the ``get_online_features_fields_statuses`` fixture: entry
        ``i`` supplies the fields (``[0]``) and statuses (``[1]``) for entity
        row ``i``. The test checks that the client sends the expected request
        (with ``auth_metadata`` and ``timeout=10``) and that row 1 of the
        result carries the expected values and statuses.
        """
        ROW_COUNT = 100

        mocked_client._serving_service_stub = Serving.ServingServiceStub(
            grpc.insecure_channel(""))

        # The request the client is expected to send to the serving stub.
        request = GetOnlineFeaturesRequestV2(project="driver_project")
        request.features.extend([
            FeatureRefProto(feature_table="driver", name="age"),
            FeatureRefProto(feature_table="driver", name="rating"),
            FeatureRefProto(feature_table="driver", name="null_value"),
        ])

        receive_response = GetOnlineFeaturesResponse()
        entity_rows = []
        for row_number in range(0, ROW_COUNT):
            fields = get_online_features_fields_statuses[row_number][0]
            statuses = get_online_features_fields_statuses[row_number][1]
            request.entity_rows.append(
                GetOnlineFeaturesRequestV2.EntityRow(
                    fields={
                        "driver_id": ValueProto.Value(int64_val=row_number)
                    }))
            entity_rows.append(
                {"driver_id": ValueProto.Value(int64_val=row_number)})
            receive_response.field_values.append(
                GetOnlineFeaturesResponse.FieldValues(fields=fields,
                                                      statuses=statuses))

        mocker.patch.object(
            mocked_client._serving_service_stub,
            "GetOnlineFeaturesV2",
            return_value=receive_response,
        )
        got_response = mocked_client.get_online_features(
            entity_rows=entity_rows,
            feature_refs=["driver:age", "driver:rating", "driver:null_value"],
            project="driver_project",
        )  # type: GetOnlineFeaturesResponse
        mocked_client._serving_service_stub.GetOnlineFeaturesV2.assert_called_with(
            request, metadata=auth_metadata, timeout=10)

        got_fields = got_response.field_values[1].fields
        got_statuses = got_response.field_values[1].statuses

        present = GetOnlineFeaturesResponse.FieldStatus.PRESENT
        null_value = GetOnlineFeaturesResponse.FieldStatus.NULL_VALUE
        expected = {
            "driver_id": (ValueProto.Value(int64_val=1), present),
            "driver:age": (ValueProto.Value(int64_val=1), present),
            "driver:rating": (ValueProto.Value(string_val="9"), present),
            "driver:null_value": (ValueProto.Value(), null_value),
        }
        # One assertion per field (with the key as message) so a failure
        # pinpoints the offending field instead of a single opaque and-chain.
        for key, (want_value, want_status) in expected.items():
            assert got_fields[key] == want_value, key
            assert got_statuses[key] == want_status, key
예제 #2
0
    def from_str(cls, feature_ref_str: str, ignore_project: bool = False):
        """
        Parse the given string feature reference into FeatureRef model
        String feature reference should be in the format feature_set:feature
        (or just feature). Where "feature_set" and "feature" are the
        feature_set name and feature name respectively. A leading
        "project/" prefix is only accepted when ignore_project is True, in
        which case the project part is discarded.

        Args:
            feature_ref_str: String representation of the feature reference
            ignore_project: Ignore projects in given string feature reference
                            instead of throwing an error

        Returns:
            FeatureRef that refers to the given feature

        Raises:
            ValueError: if the reference contains a project while
                ignore_project is False, or contains more than one "/" or
                more than one ":" separator.
        """
        proto = FeatureRefProto()
        # Work on a copy so error messages always show the caller's input.
        ref = feature_ref_str
        if "/" in ref:
            # More than one "/" would otherwise fail tuple-unpacking below
            # with an opaque "too many values to unpack" message.
            if not ignore_project or ref.count("/") > 1:
                raise ValueError(
                    f"Unsupported feature reference: {feature_ref_str}")
            _, ref = ref.split("/")

        # parse feature set name if specified
        if ":" in ref:
            if ref.count(":") > 1:
                raise ValueError(
                    f"Unsupported feature reference: {feature_ref_str}")
            proto.feature_set, ref = ref.split(":")

        proto.name = ref
        return cls.from_proto(proto)
예제 #3
0
    def test_get_online_features(self, mocked_client, mocker):
        """Verify ``get_online_features`` against a mocked serving stub.

        The stub's ``GetOnlineFeatures`` is patched to return rows whose
        feature keys carry a ``driver_project/`` prefix. The test checks that
        the client sends the expected request and expects the returned field
        names without that prefix.
        """
        ROW_COUNT = 300

        mocked_client._serving_service_stub = Serving.ServingServiceStub(
            grpc.insecure_channel(""))

        def int_val(x):
            return ValueProto.Value(int64_val=x)

        # The request the client is expected to send to the serving stub.
        request = GetOnlineFeaturesRequest()
        request.features.extend([
            FeatureRefProto(project="driver_project",
                            feature_set="driver",
                            name="age"),
            FeatureRefProto(project="driver_project", name="rating"),
        ])
        receive_response = GetOnlineFeaturesResponse()
        for row_number in range(1, ROW_COUNT + 1):
            # No trailing comma here: it previously wrapped the append call
            # in a throwaway tuple expression.
            request.entity_rows.append(
                GetOnlineFeaturesRequest.EntityRow(
                    fields={"driver_id": int_val(row_number)}))
            field_values = GetOnlineFeaturesResponse.FieldValues(
                fields={
                    "driver_id": int_val(row_number),
                    "driver_project/driver:age": int_val(1),
                    "driver_project/rating": int_val(9),
                })
            receive_response.field_values.append(field_values)

        mocker.patch.object(
            mocked_client._serving_service_stub,
            "GetOnlineFeatures",
            return_value=receive_response,
        )
        got_response = mocked_client.get_online_features(
            entity_rows=request.entity_rows,
            feature_refs=["driver:age", "rating"],
            project="driver_project",
        )  # type: GetOnlineFeaturesResponse
        mocked_client._serving_service_stub.GetOnlineFeatures.assert_called_with(
            request)

        got_fields = got_response.field_values[0].fields
        # Separate assertions so a failure names the offending field.
        assert got_fields["driver_id"] == int_val(1), "driver_id"
        assert got_fields["driver:age"] == int_val(1), "driver:age"
        assert got_fields["rating"] == int_val(9), "rating"
예제 #4
0
    def from_str(cls, feature_ref_str: str):
        """
        Parse the given string feature reference into FeatureRef model
        String feature reference should be in the format feature_table:feature.
        Where "feature_table" and "feature" are the feature_table name and
        feature name respectively.

        Args:
            feature_ref_str: String representation of the feature reference

        Returns:
            FeatureRef that refers to the given feature

        Raises:
            ValueError: if the reference does not contain exactly one ":"
        """
        # Require exactly one ":"; a reference such as "a:b:c" would otherwise
        # pass a plain membership check and then fail tuple-unpacking with an
        # unhelpful "too many values to unpack" message.
        if feature_ref_str.count(":") != 1:
            raise ValueError(
                f"Unsupported feature reference: {feature_ref_str} - Feature "
                "reference string should be in the form "
                "[featuretable_name:featurename]")

        proto = FeatureRefProto()
        proto.feature_table, proto.name = feature_ref_str.split(":")
        return cls.from_proto(proto)
예제 #5
0
 def __init__(self, name: str, feature_table: "str | None" = None):
     """
     Create a feature reference from a feature name and an optional
     feature table name.

     Args:
         name: Name of the feature
         feature_table: Name of the feature table the feature belongs to;
             defaults to None and is passed through to FeatureRefProto as-is
     """
     # Quoted annotation: explicit Optional without needing a typing import,
     # and not evaluated at runtime.
     self.proto = FeatureRefProto(name=name, feature_table=feature_table)
예제 #6
0
 def __init__(self, name: str, feature_set: "str | None" = None):
     """
     Create a feature reference from a feature name and an optional
     feature set name.

     Args:
         name: Name of the feature
         feature_set: Name of the feature set the feature belongs to;
             defaults to None and is passed through to FeatureRefProto as-is
     """
     # Quoted annotation: explicit Optional without needing a typing import,
     # and not evaluated at runtime.
     self.proto = FeatureRefProto(name=name, feature_set=feature_set)
예제 #7
0
파일: test_client.py 프로젝트: vjrkr/feast
    def test_get_online_features(self, mocked_client, auth_metadata, mocker):
        """Verify ``get_online_features`` values and statuses via a mocked stub.

        The stub's ``GetOnlineFeatures`` is patched to return rows where
        ``null_value`` has an empty value with status ``NULL_VALUE`` and every
        other field is ``PRESENT``. The test checks that the client sends the
        expected request with ``auth_metadata`` and that the first result row
        carries those values and statuses.
        """
        ROW_COUNT = 300

        mocked_client._serving_service_stub = Serving.ServingServiceStub(
            grpc.insecure_channel(""))

        def int_val(x):
            return ValueProto.Value(int64_val=x)

        # The request the client is expected to send to the serving stub.
        request = GetOnlineFeaturesRequest(project="driver_project")
        request.features.extend([
            FeatureRefProto(feature_set="driver", name="age"),
            FeatureRefProto(name="rating"),
            FeatureRefProto(name="null_value"),
        ])
        receive_response = GetOnlineFeaturesResponse()
        entity_rows = []
        for row_number in range(1, ROW_COUNT + 1):
            request.entity_rows.append(
                GetOnlineFeaturesRequest.EntityRow(
                    fields={"driver_id": int_val(row_number)}))
            entity_rows.append({"driver_id": int_val(row_number)})
            field_values = GetOnlineFeaturesResponse.FieldValues(
                fields={
                    "driver_id": int_val(row_number),
                    "driver:age": int_val(1),
                    "rating": int_val(9),
                    "null_value": ValueProto.Value(),
                },
                statuses={
                    "driver_id": GetOnlineFeaturesResponse.FieldStatus.PRESENT,
                    "driver:age":
                    GetOnlineFeaturesResponse.FieldStatus.PRESENT,
                    "rating": GetOnlineFeaturesResponse.FieldStatus.PRESENT,
                    "null_value":
                    GetOnlineFeaturesResponse.FieldStatus.NULL_VALUE,
                },
            )
            receive_response.field_values.append(field_values)

        mocker.patch.object(
            mocked_client._serving_service_stub,
            "GetOnlineFeatures",
            return_value=receive_response,
        )
        got_response = mocked_client.get_online_features(
            entity_rows=entity_rows,
            feature_refs=["driver:age", "rating", "null_value"],
            project="driver_project",
        )  # type: GetOnlineFeaturesResponse
        mocked_client._serving_service_stub.GetOnlineFeatures.assert_called_with(
            request, metadata=auth_metadata)

        got_fields = got_response.field_values[0].fields
        got_statuses = got_response.field_values[0].statuses

        present = GetOnlineFeaturesResponse.FieldStatus.PRESENT
        expected = {
            "driver_id": (int_val(1), present),
            "driver:age": (int_val(1), present),
            "rating": (int_val(9), present),
            "null_value": (ValueProto.Value(),
                           GetOnlineFeaturesResponse.FieldStatus.NULL_VALUE),
        }
        # One assertion per field (with the key as message) so a failure
        # pinpoints the offending field instead of a single opaque and-chain.
        for key, (want_value, want_status) in expected.items():
            assert got_fields[key] == want_value, key
            assert got_statuses[key] == want_status, key