예제 #1
0
    def test_get_online_features(self, mocked_client, auth_metadata, mocker,
                                 get_online_features_fields_statuses):
        """Check that ``get_online_features`` sends the expected V2 request to
        the serving stub and exposes the stubbed response's values/statuses.

        ``get_online_features_fields_statuses[i]`` is read as a
        ``(fields, statuses)`` pair for row ``i``; presumably the fixture
        provides at least ``ROW_COUNT`` rows — confirm against the fixture.
        """
        ROW_COUNT = 100

        mocked_client._serving_service_stub = Serving.ServingServiceStub(
            grpc.insecure_channel(""))

        # The request the client is expected to build from the arguments below.
        request = GetOnlineFeaturesRequestV2(project="driver_project")
        request.features.extend([
            FeatureRefProto(feature_table="driver", name="age"),
            FeatureRefProto(feature_table="driver", name="rating"),
            FeatureRefProto(feature_table="driver", name="null_value"),
        ])

        receive_response = GetOnlineFeaturesResponse()
        entity_rows = []
        for row_number in range(ROW_COUNT):
            fields = get_online_features_fields_statuses[row_number][0]
            statuses = get_online_features_fields_statuses[row_number][1]
            request.entity_rows.append(
                GetOnlineFeaturesRequestV2.EntityRow(
                    fields={
                        "driver_id": ValueProto.Value(int64_val=row_number)
                    }))
            entity_rows.append(
                {"driver_id": ValueProto.Value(int64_val=row_number)})
            receive_response.field_values.append(
                GetOnlineFeaturesResponse.FieldValues(fields=fields,
                                                      statuses=statuses))

        mocker.patch.object(
            mocked_client._serving_service_stub,
            "GetOnlineFeaturesV2",
            return_value=receive_response,
        )
        # The client wraps the proto reply; field_values is read off the wrapper.
        got_response = mocked_client.get_online_features(
            entity_rows=entity_rows,
            feature_refs=["driver:age", "driver:rating", "driver:null_value"],
            project="driver_project",
        )
        mocked_client._serving_service_stub.GetOnlineFeaturesV2.assert_called_with(
            request, metadata=auth_metadata, timeout=10)

        # One assert per check so a failure pinpoints the mismatching entry.
        field_status = GetOnlineFeaturesResponse.FieldStatus
        got_fields = got_response.field_values[1].fields
        got_statuses = got_response.field_values[1].statuses
        assert got_fields["driver_id"] == ValueProto.Value(int64_val=1)
        assert got_statuses["driver_id"] == field_status.PRESENT
        assert got_fields["driver:age"] == ValueProto.Value(int64_val=1)
        assert got_statuses["driver:age"] == field_status.PRESENT
        assert got_fields["driver:rating"] == ValueProto.Value(string_val="9")
        assert got_statuses["driver:rating"] == field_status.PRESENT
        assert got_fields["driver:null_value"] == ValueProto.Value()
        assert got_statuses["driver:null_value"] == field_status.NULL_VALUE
예제 #2
0
    def get_online_features(
        self,
        feature_refs: List[str],
        entity_rows: List[Dict[str, Any]],
        project: Optional[str] = None,
    ) -> OnlineResponse:
        """
        Retrieves the latest online feature data from Feast Serving.

        Args:
            feature_refs: List of feature references that will be returned for each entity.
                Each feature reference should have the following format:
                "feature_table:feature" where "feature_table" & "feature" refer to
                the feature table and feature names respectively.
                Only the feature name is required.
            entity_rows: A list of dictionaries where each key-value is an entity-name, entity-value pair.
            project: Optionally specify the project override. If specified, uses the given project
                for retrieval; otherwise the client's configured ``project`` is used.
                Overrides the projects specified in feature references, if any.
        Returns:
            OnlineResponse wrapping the GetOnlineFeaturesResponse returned by Feast Serving.
            Each EntityRow provided will yield one record, which contains
            data fields with data value and field status metadata (if included).
        Raises:
            grpc.RpcError: if the serving call fails. The raised error carries the
                original error's details string and chains the original error as
                its ``__cause__``.
        Examples:
            >>> from feast import Client
            >>>
            >>> feast_client = Client(core_url="localhost:6565", serving_url="localhost:6566")
            >>> feature_refs = ["sales:daily_transactions"]
            >>> entity_rows = [{"customer_id": 0},{"customer_id": 1}]
            >>>
            >>> online_response = feast_client.get_online_features(
            >>>     feature_refs, entity_rows, project="my_project")
            >>> online_response_dict = online_response.to_dict()
            >>> print(online_response_dict)
            {'sales:daily_transactions': [1.1,1.2], 'sales:customer_id': [0,1]}
        """

        if self._telemetry_enabled:
            # Usage is only logged when the per-method counter is a multiple of
            # 1000, i.e. roughly once every 1000 calls rather than on each call.
            if self._telemetry_counter["get_online_features"] % 1000 == 0:
                log_usage(
                    "get_online_features",
                    self._telemetry_id,
                    datetime.utcnow(),
                    self.version(sdk_only=True),
                )
            self._telemetry_counter["get_online_features"] += 1

        # Build the request outside the try so only the RPC itself is guarded.
        request = GetOnlineFeaturesRequestV2(
            features=_build_feature_references(feature_ref_strs=feature_refs),
            entity_rows=_infer_online_entity_rows(entity_rows),
            project=project if project is not None else self.project,
        )
        try:
            response = self._serving_service.GetOnlineFeaturesV2(
                request,
                timeout=self._config.getint(opt.GRPC_CONNECTION_TIMEOUT),
                metadata=self._get_grpc_metadata(),
            )
        except grpc.RpcError as e:
            # Re-raise with only the details message, but keep the original
            # error chained so its status code and traceback are not lost.
            raise grpc.RpcError(e.details()) from e

        return OnlineResponse(response)
예제 #3
0
def _infer_online_entity_rows(
    entity_rows: List[Dict[str, Any]]
) -> List[GetOnlineFeaturesRequestV2.EntityRow]:
    """
    Convert user-supplied entity rows into EntityRow protos.

    Values that are already ``Value`` protos are used as-is. Any other value
    has its Feast value type inferred and is converted to a ``Value`` proto;
    for such inferred values, a given entity name must resolve to the same
    type in every row.

    Args:
        entity_rows: A list of dictionaries where each key-value is an entity-name, entity-value pair.
    Returns:
        A list of EntityRow protos, one per input row, in input order.
    Raises:
        TypeError: if an entity name's inferred type differs between rows.
    """
    # Inferred type first seen for each entity name (non-proto values only).
    seen_types = {}

    def _row_to_proto(row):
        # Convert one {entity_name: value} mapping into an EntityRow proto.
        converted = {}
        for name, raw in row.items():
            if isinstance(raw, Value):
                converted[name] = raw
                continue
            inferred = python_type_to_feast_value_type(name=name, value=raw)
            expected = seen_types.setdefault(name, inferred)
            if inferred != expected:
                raise TypeError(
                    f"Input entity {name} has mixed types, {inferred} and {expected}. That is not allowed. "
                )
            converted[name] = _python_value_to_proto_value(inferred, raw)
        return GetOnlineFeaturesRequestV2.EntityRow(fields=converted)

    return [_row_to_proto(row) for row in entity_rows]