Ejemplo n.º 1
0
    def test_get_online_features(self, mocked_client, auth_metadata, mocker,
                                 get_online_features_fields_statuses):
        """Check the request sent by ``get_online_features`` and the returned rows.

        The serving stub's ``GetOnlineFeaturesV2`` is patched to return a canned
        response. The response is built from the
        ``get_online_features_fields_statuses`` fixture, indexed per row as
        ``(fields, statuses)``.

        The test then verifies that the stub was called with the expected V2
        request, the auth metadata and ``timeout=10``. Finally it spot-checks
        row 1 of the response returned by the client.
        """
        ROW_COUNT = 100

        mocked_client._serving_service_stub = Serving.ServingServiceStub(
            grpc.insecure_channel(""))

        # The request the client is expected to build for the feature refs below.
        request = GetOnlineFeaturesRequestV2(project="driver_project")
        request.features.extend([
            FeatureRefProto(feature_table="driver", name="age"),
            FeatureRefProto(feature_table="driver", name="rating"),
            FeatureRefProto(feature_table="driver", name="null_value"),
        ])

        receive_response = GetOnlineFeaturesResponse()
        entity_rows = []
        for row_number in range(ROW_COUNT):
            row_fixture = get_online_features_fields_statuses[row_number]
            fields = row_fixture[0]
            statuses = row_fixture[1]
            # The same driver_id goes into the expected proto request and into
            # the plain-dict rows handed to the client API.
            request.entity_rows.append(
                GetOnlineFeaturesRequestV2.EntityRow(
                    fields={
                        "driver_id": ValueProto.Value(int64_val=row_number)
                    }))
            entity_rows.append(
                {"driver_id": ValueProto.Value(int64_val=row_number)})
            receive_response.field_values.append(
                GetOnlineFeaturesResponse.FieldValues(fields=fields,
                                                      statuses=statuses))

        mocker.patch.object(
            mocked_client._serving_service_stub,
            "GetOnlineFeaturesV2",
            return_value=receive_response,
        )
        got_response = mocked_client.get_online_features(
            entity_rows=entity_rows,
            feature_refs=["driver:age", "driver:rating", "driver:null_value"],
            project="driver_project",
        )  # type: GetOnlineFeaturesResponse
        mocked_client._serving_service_stub.GetOnlineFeaturesV2.assert_called_with(
            request, metadata=auth_metadata, timeout=10)

        got_fields = got_response.field_values[1].fields
        got_statuses = got_response.field_values[1].statuses
        present = GetOnlineFeaturesResponse.FieldStatus.PRESENT

        # One assertion per check so a failure pinpoints the exact field/status.
        assert got_fields["driver_id"] == ValueProto.Value(int64_val=1)
        assert got_statuses["driver_id"] == present
        assert got_fields["driver:age"] == ValueProto.Value(int64_val=1)
        assert got_statuses["driver:age"] == present
        assert got_fields["driver:rating"] == ValueProto.Value(string_val="9")
        assert got_statuses["driver:rating"] == present
        assert got_fields["driver:null_value"] == ValueProto.Value()
        assert (got_statuses["driver:null_value"]
                == GetOnlineFeaturesResponse.FieldStatus.NULL_VALUE)
Ejemplo n.º 2
0
def _infer_online_entity_rows(
    entity_rows: List[Dict[str, Any]]
) -> List[GetOnlineFeaturesRequestV2.EntityRow]:
    """
    Convert user-supplied entity dictionaries into EntityRow protos.

    Each dictionary maps an entity name to its value. A value that is already
    a ``Value`` proto is used as-is. Any other value has its Feast value type
    inferred with ``python_type_to_feast_value_type`` and is then converted
    with ``_python_value_to_proto_value``.

    The type inferred for the first native (non-``Value``) occurrence of each
    entity name is remembered across all rows. A later native value for the
    same name whose inferred type differs is rejected.

    Args:
        entity_rows: A list of dictionaries, each mapping entity name to entity value.
    Returns:
        One EntityRow proto per input dictionary, in input order.
    Raises:
        TypeError: If native values for one entity name infer to different types.
    """
    # First inferred type per entity name; shared by every row.
    type_by_entity: Dict[str, Any] = {}

    def _to_proto(name: str, raw: Any):
        # Pre-built protos skip inference and the mixed-type check entirely.
        if isinstance(raw, Value):
            return raw
        inferred = python_type_to_feast_value_type(name=name, value=raw)
        expected = type_by_entity.setdefault(name, inferred)
        if inferred != expected:
            raise TypeError(
                f"Input entity {name} has mixed types, {inferred} and {expected}. That is not allowed. "
            )
        return _python_value_to_proto_value(inferred, raw)

    return [
        GetOnlineFeaturesRequestV2.EntityRow(
            fields={name: _to_proto(name, raw) for name, raw in row.items()}
        )
        for row in cast(List[Dict[str, Any]], entity_rows)
    ]