Beispiel #1
0
    def test_get_online_features(self, mocked_client, auth_metadata, mocker,
                                 get_online_features_fields_statuses):
        """Check the request sent by ``get_online_features`` and the reply it returns.

        The serving stub's ``GetOnlineFeaturesV2`` is patched to return a canned
        response built from the ``get_online_features_fields_statuses`` fixture.
        The test then verifies that the stub was called with the expected
        request, metadata and timeout, and spot-checks the second response row.
        """
        row_count = 100
        feature_names = ["age", "rating", "null_value"]

        mocked_client._serving_service_stub = Serving.ServingServiceStub(
            grpc.insecure_channel(""))

        # Request the client is expected to send to the serving service.
        expected_request = GetOnlineFeaturesRequestV2(project="driver_project")
        expected_request.features.extend([
            FeatureRefProto(feature_table="driver", name=feature_name)
            for feature_name in feature_names
        ])

        canned_response = GetOnlineFeaturesResponse()
        entity_rows = []
        for row_number in range(row_count):
            row_fixture = get_online_features_fields_statuses[row_number]
            expected_request.entity_rows.append(
                GetOnlineFeaturesRequestV2.EntityRow(
                    fields={
                        "driver_id": ValueProto.Value(int64_val=row_number)
                    }))
            entity_rows.append(
                {"driver_id": ValueProto.Value(int64_val=row_number)})
            canned_response.field_values.append(
                GetOnlineFeaturesResponse.FieldValues(
                    fields=row_fixture[0], statuses=row_fixture[1]))

        mocker.patch.object(
            mocked_client._serving_service_stub,
            "GetOnlineFeaturesV2",
            return_value=canned_response,
        )
        got_response = mocked_client.get_online_features(
            entity_rows=entity_rows,
            feature_refs=[f"driver:{name}" for name in feature_names],
            project="driver_project",
        )  # type: GetOnlineFeaturesResponse
        mocked_client._serving_service_stub.GetOnlineFeaturesV2.assert_called_with(
            expected_request, metadata=auth_metadata, timeout=10)

        # Only the second row (index 1) is inspected.
        got_fields = got_response.field_values[1].fields
        got_statuses = got_response.field_values[1].statuses

        field_status = GetOnlineFeaturesResponse.FieldStatus
        expected_row = [
            ("driver_id", ValueProto.Value(int64_val=1),
             field_status.PRESENT),
            ("driver:age", ValueProto.Value(int64_val=1),
             field_status.PRESENT),
            ("driver:rating", ValueProto.Value(string_val="9"),
             field_status.PRESENT),
            ("driver:null_value", ValueProto.Value(),
             field_status.NULL_VALUE),
        ]
        for key, expected_value, expected_status in expected_row:
            assert got_fields[key] == expected_value
            assert got_statuses[key] == expected_status
Beispiel #2
0
def _entity_row_to_field_values(
    row: GetOnlineFeaturesRequestV2.EntityRow,
) -> GetOnlineFeaturesResponse.FieldValues:
    """Build a response row pre-populated with the fields of an entity row.

    Args:
        row: Entity row whose ``fields`` map is copied.

    Returns:
        A new ``FieldValues`` message containing a copy of every field in
        ``row``, each with its status set to ``PRESENT``.
    """
    present = GetOnlineFeaturesResponse.FieldStatus.PRESENT
    field_values = GetOnlineFeaturesResponse.FieldValues()
    for field_name, field_value in row.fields.items():
        field_values.fields[field_name].CopyFrom(field_value)
        field_values.statuses[field_name] = present
    return field_values
Beispiel #3
0
    def _augment_response_with_on_demand_transforms(
        self,
        feature_refs: List[str],
        full_feature_names: bool,
        initial_response: OnlineResponse,
        result_rows: List[GetOnlineFeaturesResponse.FieldValues],
    ) -> OnlineResponse:
        """Add on-demand feature view outputs to the result rows.

        Feature references whose view name matches a registered on-demand
        feature view are grouped per view. Each such view is asked to
        transform the dataframe form of ``initial_response``, and the
        requested output columns are written into ``result_rows`` with status
        ``PRESENT``.

        Args:
            feature_refs: Feature references in ``"view_name:feature_name"``
                form. Each reference must contain exactly one ``":"``.
            full_feature_names: When True, outputs are stored under
                ``"<odfv name>__<feature>"``; otherwise under the bare
                feature name.
            initial_response: Response built from the already retrieved
                values. It is the input of the transformations.
            result_rows: Rows to extend. They are modified in place.

        Returns:
            ``initial_response`` unchanged when the project has no on-demand
            feature views, otherwise a new OnlineResponse wrapping
            ``result_rows``.

        Raises:
            ValueError: A feature reference does not split into exactly two
                parts on ``":"``.
        """
        all_on_demand_feature_views = {
            view.name: view
            for view in self._registry.list_on_demand_feature_views(
                project=self.project, allow_cache=True)
        }
        if not all_on_demand_feature_views:
            return initial_response
        initial_response_df = initial_response.to_df()

        # Requested feature names grouped by the on-demand view that owns them.
        odfv_feature_refs = defaultdict(list)
        for feature_ref in feature_refs:
            view_name, feature_name = feature_ref.split(":")
            if view_name in all_on_demand_feature_views:
                odfv_feature_refs[view_name].append(feature_name)

        present = GetOnlineFeaturesResponse.FieldStatus.PRESENT

        # Apply on demand transformations
        for odfv_name, requested_names in odfv_feature_refs.items():
            odfv = all_on_demand_feature_views[odfv_name]
            transformed_features_df = odfv.get_transformed_features_df(
                full_feature_names, initial_response_df)

            # Column selection, output names and value arrays do not depend on
            # the row, so resolve them once per view instead of once per row.
            requested = set(requested_names)
            selected_columns = [
                (
                    f"{odfv.name}__{column}"
                    if full_feature_names else column,
                    transformed_features_df[column].values,
                )
                for column in transformed_features_df.columns
                if column in requested
            ]

            for row_idx, result_row in enumerate(result_rows):
                for output_name, column_values in selected_columns:
                    result_row.fields[output_name].CopyFrom(
                        python_value_to_proto_value(column_values[row_idx]))
                    result_row.statuses[output_name] = present
        return OnlineResponse(
            GetOnlineFeaturesResponse(field_values=result_rows))
Beispiel #4
0
    def _get_online_features(
        self,
        entity_rows: List[GetOnlineFeaturesRequestV2.EntityRow],
        feature_refs: List[str],
        project: str,
    ) -> GetOnlineFeaturesResponse:
        """Read the requested features from the online store for each entity row.

        Args:
            entity_rows: Entity rows to look up; one result row is produced
                per entity row.
            feature_refs: Feature references to retrieve.
            project: Project name passed to the provider's ``online_read``.

        Returns:
            A GetOnlineFeaturesResponse whose rows hold the entity fields plus
            the retrieved features, keyed as ``"<table name>:<feature>"``.
            Rows for which the store returned no data get ``NOT_FOUND`` for
            every requested feature of that table.
        """
        provider = self._get_provider()

        entity_keys = []
        result_rows: List[GetOnlineFeaturesResponse.FieldValues] = []
        for entity_row in entity_rows:
            entity_keys.append(_entity_row_to_key(entity_row))
            result_rows.append(_entity_row_to_field_values(entity_row))

        # NOTE(review): feature views are listed for self.config.project while
        # the reads below use the `project` argument - confirm the two are
        # always the same.
        all_feature_views = self._get_registry().list_feature_views(
            project=self.config.project)

        field_status = GetOnlineFeaturesResponse.FieldStatus
        for table, requested_features in _group_refs(feature_refs,
                                                     all_feature_views):
            rows_read = provider.online_read(
                project=project,
                table=table,
                entity_keys=entity_keys,
            )
            for row_idx, (_row_ts, feature_data) in enumerate(rows_read):
                result_row = result_rows[row_idx]

                if feature_data is None:
                    # Nothing stored for this entity: flag each requested feature.
                    for feature_name in requested_features:
                        result_row.statuses[
                            f"{table.name}:{feature_name}"] = field_status.NOT_FOUND
                    continue

                for feature_name in feature_data:
                    if feature_name not in requested_features:
                        continue
                    feature_ref = f"{table.name}:{feature_name}"
                    result_row.fields[feature_ref].CopyFrom(
                        feature_data[feature_name])
                    result_row.statuses[feature_ref] = field_status.PRESENT

        return GetOnlineFeaturesResponse(field_values=result_rows)
Beispiel #5
0
    def get_online_features(
        self,
        feature_refs: List[str],
        entity_rows: List[Dict[str, Any]],
    ) -> OnlineResponse:
        """
        Fetch the most recent online values of the requested features.

        Note: The first call downloads the complete feature registry, which can take a few seconds for a
        remote registry such as GCS or S3. The registry then stays cached for a TTL (which can be set to
        infinity, i.e. cache forever). Once the cached copy is older than the TTL, this method downloads a
        fresh registry synchronously, which adds latency to online feature retrieval. Call
        refresh_registry() before the TTL is reached to avoid that synchronous download.

        Args:
            feature_refs: Feature references to return for every entity row, each written as
                "feature_table:feature", where "feature_table" is the feature table name and
                "feature" is the feature name.
            entity_rows: One dictionary per entity row, mapping entity names to entity values.

        Returns:
            OnlineResponse containing the feature data in records. Retrieved feature values are stored
            under "<feature_table>__<feature>" keys; entity values are keyed by the entity's join key.

        Raises:
            Exception: An entity name used in entity_rows does not exist in this project.

        Examples:
            >>> from feast import FeatureStore
            >>>
            >>> store = FeatureStore(repo_path="...")
            >>> feature_refs = ["sales:daily_transactions"]
            >>> entity_rows = [{"customer_id": 0},{"customer_id": 1}]
            >>>
            >>> online_response = store.get_online_features(
            ...     feature_refs, entity_rows)
            >>> online_response_dict = online_response.to_dict()
            >>> print(online_response_dict)
            {'customer_id': [0, 1], 'sales__daily_transactions': [1.1, 1.2]}
        """

        provider = self._get_provider()
        entity_name_to_join_key_map = {
            entity.name: entity.join_key
            for entity in self.list_entities(allow_cache=True)
        }

        # Re-key every entity row from entity names to join keys.
        join_key_rows = []
        for entity_row in entity_rows:
            join_key_row = {}
            for entity_name, entity_value in entity_row.items():
                try:
                    join_key = entity_name_to_join_key_map[entity_name]
                except KeyError:
                    raise Exception(
                        f"Entity {entity_name} does not exist in project {self.project}"
                    )
                else:
                    join_key_row[join_key] = entity_value
            join_key_rows.append(join_key_row)

        union_of_entity_keys = []
        result_rows: List[GetOnlineFeaturesResponse.FieldValues] = []
        for entity_row_proto in _infer_online_entity_rows(join_key_rows):
            union_of_entity_keys.append(_entity_row_to_key(entity_row_proto))
            result_rows.append(_entity_row_to_field_values(entity_row_proto))

        all_feature_views = self._registry.list_feature_views(
            project=self.project, allow_cache=True)

        field_status = GetOnlineFeaturesResponse.FieldStatus
        for table, requested_features in _group_refs(feature_refs,
                                                     all_feature_views):
            table_entity_keys = _get_table_entity_keys(
                table, union_of_entity_keys, entity_name_to_join_key_map)
            rows_read = provider.online_read(
                project=self.project,
                table=table,
                entity_keys=table_entity_keys,
            )
            for row_idx, (_row_ts, feature_data) in enumerate(rows_read):
                result_row = result_rows[row_idx]

                if feature_data is None:
                    # Nothing stored for this entity: flag each requested feature.
                    for feature_name in requested_features:
                        result_row.statuses[
                            f"{table.name}__{feature_name}"] = field_status.NOT_FOUND
                    continue

                for feature_name in feature_data:
                    if feature_name not in requested_features:
                        continue
                    feature_ref = f"{table.name}__{feature_name}"
                    result_row.fields[feature_ref].CopyFrom(
                        feature_data[feature_name])
                    result_row.statuses[feature_ref] = field_status.PRESENT

        return OnlineResponse(
            GetOnlineFeaturesResponse(field_values=result_rows))
Beispiel #6
0
    def get_online_features(
        self,
        features: Union[List[str], FeatureService],
        entity_rows: List[Dict[str, Any]],
        feature_refs: Optional[List[str]] = None,
        full_feature_names: bool = False,
    ) -> OnlineResponse:
        """
        Fetch the most recent online values of the requested features.

        Note: The first call downloads the complete feature registry, which can take a few seconds for a
        remote registry such as GCS or S3. The registry then stays cached for a TTL (which can be set to
        infinity, i.e. cache forever). Once the cached copy is older than the TTL, this method downloads a
        fresh registry synchronously, which adds latency to online feature retrieval. Call
        refresh_registry() before the TTL is reached to avoid that synchronous download.

        Args:
            features: Feature references to return for every entity row, or a FeatureService.
                Each feature reference is written as "feature_table:feature", where "feature_table"
                is the feature table name and "feature" is the feature name.
            entity_rows: One dictionary per entity row, mapping entity names to entity values.
            feature_refs: Optional list of feature references, resolved together with ``features``
                by ``_get_features`` (presumably an older spelling of ``features`` - confirm).
            full_feature_names: When True, retrieved features are keyed as
                "<feature_table>__<feature>" in the response; when False, by the bare feature name.

        Returns:
            OnlineResponse containing the feature data in records.

        Raises:
            EntityNotFoundException: An entity name used in entity_rows does not exist in this project.

        Examples:
            Materialize all features into the online store over the interval
            from 3 hours ago to 10 minutes ago, and then retrieve these online features.

            >>> from feast import FeatureStore, RepoConfig
            >>> fs = FeatureStore(repo_path="feature_repo")
            >>> online_response = fs.get_online_features(
            ...     features=[
            ...         "driver_hourly_stats:conv_rate",
            ...         "driver_hourly_stats:acc_rate",
            ...         "driver_hourly_stats:avg_daily_trips",
            ...     ],
            ...     entity_rows=[{"driver_id": 1001}, {"driver_id": 1002}, {"driver_id": 1003}, {"driver_id": 1004}],
            ... )
            >>> online_response_dict = online_response.to_dict()
        """
        _feature_refs = self._get_features(features, feature_refs)

        provider = self._get_provider()
        entity_name_to_join_key_map = {
            entity.name: entity.join_key
            for entity in self.list_entities(allow_cache=True)
        }

        # Re-key every entity row from entity names to join keys.
        join_key_rows = []
        for entity_row in entity_rows:
            join_key_row = {}
            for entity_name, entity_value in entity_row.items():
                try:
                    join_key = entity_name_to_join_key_map[entity_name]
                except KeyError:
                    raise EntityNotFoundException(entity_name, self.project)
                else:
                    join_key_row[join_key] = entity_value
            join_key_rows.append(join_key_row)

        union_of_entity_keys = []
        result_rows: List[GetOnlineFeaturesResponse.FieldValues] = []
        for entity_row_proto in _infer_online_entity_rows(join_key_rows):
            union_of_entity_keys.append(_entity_row_to_key(entity_row_proto))
            result_rows.append(_entity_row_to_field_values(entity_row_proto))

        all_feature_views = self._registry.list_feature_views(
            project=self.project, allow_cache=True)

        _validate_feature_refs(_feature_refs, full_feature_names)

        field_status = GetOnlineFeaturesResponse.FieldStatus
        for table, requested_features in _group_feature_refs(
                _feature_refs, all_feature_views):
            table_entity_keys = _get_table_entity_keys(
                table, union_of_entity_keys, entity_name_to_join_key_map)
            rows_read = provider.online_read(
                config=self.config,
                table=table,
                entity_keys=table_entity_keys,
                requested_features=requested_features,
            )
            # Response keys carry the table name only when full names are requested.
            key_prefix = f"{table.name}__" if full_feature_names else ""

            for row_idx, (_row_ts, feature_data) in enumerate(rows_read):
                result_row = result_rows[row_idx]

                if feature_data is None:
                    # Nothing stored for this entity: flag each requested feature.
                    for feature_name in requested_features:
                        result_row.statuses[
                            f"{key_prefix}{feature_name}"] = field_status.NOT_FOUND
                    continue

                for feature_name in feature_data:
                    if feature_name not in requested_features:
                        continue
                    feature_ref = f"{key_prefix}{feature_name}"
                    result_row.fields[feature_ref].CopyFrom(
                        feature_data[feature_name])
                    result_row.statuses[feature_ref] = field_status.PRESENT

        return OnlineResponse(
            GetOnlineFeaturesResponse(field_values=result_rows))
Beispiel #7
0
    def get_online_features(
        self,
        features: Union[List[str], FeatureService],
        entity_rows: List[Dict[str, Any]],
        feature_refs: Optional[List[str]] = None,
        full_feature_names: bool = False,
    ) -> OnlineResponse:
        """
        Fetch the most recent online values of the requested features, including on-demand features.

        Note: The first call downloads the complete feature registry, which can take a few seconds for a
        remote registry such as GCS or S3. The registry then stays cached for a TTL (which can be set to
        infinity, i.e. cache forever). Once the cached copy is older than the TTL, this method downloads a
        fresh registry synchronously, which adds latency to online feature retrieval. Call
        refresh_registry() before the TTL is reached to avoid that synchronous download.

        Args:
            features: Feature references to return for every entity row, or a FeatureService.
                Each feature reference is written as "feature_table:feature", where "feature_table"
                is the feature table name and "feature" is the feature name.
            entity_rows: One dictionary per row. Each key is either an entity name or the name of a
                request data feature needed by a requested on-demand feature view.
            feature_refs: Optional list of feature references, resolved together with ``features``
                by ``_get_features`` (presumably an older spelling of ``features`` - confirm).
            full_feature_names: Forwarded to the feature view lookup and to the on-demand
                transformation step to choose between prefixed and bare feature names.

        Returns:
            OnlineResponse containing the feature data in records.

        Raises:
            EntityNotFoundException: A key in entity_rows is neither needed request data nor the name
                of an entity in this project.
            RequestDataNotFoundInEntityRowsException: The number of distinct request data features
                found in entity_rows differs from the number that is needed.

        Examples:
            Materialize all features into the online store over the interval
            from 3 hours ago to 10 minutes ago, and then retrieve these online features.

            >>> from feast import FeatureStore, RepoConfig
            >>> fs = FeatureStore(repo_path="feature_repo")
            >>> online_response = fs.get_online_features(
            ...     features=[
            ...         "driver_hourly_stats:conv_rate",
            ...         "driver_hourly_stats:acc_rate",
            ...         "driver_hourly_stats:avg_daily_trips",
            ...     ],
            ...     entity_rows=[{"driver_id": 1001}, {"driver_id": 1002}, {"driver_id": 1003}, {"driver_id": 1004}],
            ... )
            >>> online_response_dict = online_response.to_dict()
        """
        _feature_refs = self._get_features(features, feature_refs)
        all_feature_views = self._list_feature_views(allow_cache=True,
                                                     hide_dummy_entity=False)
        all_on_demand_feature_views = self._registry.list_on_demand_feature_views(
            project=self.project, allow_cache=True)

        _validate_feature_refs(_feature_refs, full_feature_names)
        grouped_refs, grouped_odfv_refs = _group_feature_refs(
            _feature_refs, all_feature_views, all_on_demand_feature_views)
        if len(grouped_odfv_refs) > 0:
            log_event(UsageEvent.GET_ONLINE_FEATURES_WITH_ODFV)

        # True when any requested feature view is keyed on the dummy entity.
        entityless_case = any(
            entity_name == DUMMY_ENTITY_NAME
            for feature_view, _ in grouped_refs
            for entity_name in feature_view.entities)

        provider = self._get_provider()
        entity_name_to_join_key_map = {
            entity.name: entity.join_key
            for entity in self._list_entities(allow_cache=True,
                                              hide_dummy_entity=False)
        }

        needed_request_data_features = self._get_needed_request_data_features(
            grouped_odfv_refs)

        join_key_rows = []
        request_data_features: Dict[str, List[Any]] = {}
        # Each key of an entity row is either request data or an entity name.
        for entity_row in entity_rows:
            join_key_row = {}
            for entity_name, entity_value in entity_row.items():
                if entity_name in needed_request_data_features:
                    # Request data: collect the value and skip the entity lookup.
                    request_data_features.setdefault(entity_name,
                                                     []).append(entity_value)
                    continue
                try:
                    join_key = entity_name_to_join_key_map[entity_name]
                except KeyError:
                    raise EntityNotFoundException(entity_name, self.project)
                join_key_row[join_key] = entity_value
                if entityless_case:
                    join_key_row[DUMMY_ENTITY_ID] = DUMMY_ENTITY_VAL
            if join_key_row:
                # Stays empty when the row carried request data only.
                join_key_rows.append(join_key_row)

        if len(request_data_features) != len(needed_request_data_features):
            raise RequestDataNotFoundInEntityRowsException(
                feature_names=needed_request_data_features)

        union_of_entity_keys: List[EntityKeyProto] = []
        result_rows: List[GetOnlineFeaturesResponse.FieldValues] = []
        for entity_row_proto in _infer_online_entity_rows(join_key_rows):
            # Entity keys are filtered down per feature view at lookup time.
            union_of_entity_keys.append(_entity_row_to_key(entity_row_proto))
            # The entity values themselves are part of the result.
            result_rows.append(_entity_row_to_field_values(entity_row_proto))

        # Copy the request data values into the result rows.
        # NOTE(review): values are matched to result rows by position, yet a row
        # that contributes no join keys is skipped above - confirm the two
        # sequences cannot get out of step.
        present = GetOnlineFeaturesResponse.FieldStatus.PRESENT
        for feature_name, feature_values in request_data_features.items():
            for row_idx, feature_value in enumerate(feature_values):
                result_row = result_rows[row_idx]
                result_row.fields[feature_name].CopyFrom(
                    python_value_to_proto_value(feature_value))
                result_row.statuses[feature_name] = present

        for table, requested_features in grouped_refs:
            self._populate_result_rows_from_feature_view(
                entity_name_to_join_key_map,
                full_feature_names,
                provider,
                requested_features,
                result_rows,
                table,
                union_of_entity_keys,
            )

        initial_response = OnlineResponse(
            GetOnlineFeaturesResponse(field_values=result_rows))
        return self._augment_response_with_on_demand_transforms(
            _feature_refs, full_feature_names, initial_response, result_rows)