Example #1
    def get_historical_features(
        config: RepoConfig,
        feature_views: List[FeatureView],
        feature_refs: List[str],
        entity_df: Union[pd.DataFrame, str],
        registry: Registry,
        project: str,
        full_feature_names: bool = False,
    ) -> RetrievalJob:
        # TODO: Add entity_df validation in order to fail before interacting with BigQuery
        assert isinstance(config.offline_store, BigQueryOfflineStoreConfig)

        client = _get_bigquery_client(project=config.offline_store.project_id)

        table_reference = _get_table_reference_for_new_entity(
            client, client.project, config.offline_store.dataset
        )

        entity_schema = _upload_entity_df_and_get_entity_schema(
            client=client, table_name=table_reference, entity_df=entity_df,
        )

        entity_df_event_timestamp_col = offline_utils.infer_event_timestamp_from_entity_df(
            entity_schema
        )

        expected_join_keys = offline_utils.get_expected_join_keys(
            project, feature_views, registry
        )

        offline_utils.assert_expected_columns_in_entity_df(
            entity_schema, expected_join_keys, entity_df_event_timestamp_col
        )

        # Build a query context containing all information required to template the BigQuery SQL query
        query_context = offline_utils.get_feature_view_query_context(
            feature_refs, feature_views, registry, project,
        )

        # Generate the BigQuery SQL query from the query context
        query = offline_utils.build_point_in_time_query(
            query_context,
            left_table_query_string=table_reference,
            entity_df_event_timestamp_col=entity_df_event_timestamp_col,
            query_template=MULTIPLE_FEATURE_VIEW_POINT_IN_TIME_JOIN,
            full_feature_names=full_feature_names,
        )

        return BigQueryRetrievalJob(
            query=query,
            client=client,
            config=config,
            full_feature_names=full_feature_names,
            on_demand_feature_views=OnDemandFeatureView.get_requested_odfvs(
                feature_refs, project, registry
            ),
        )
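
For context (not part of the source), applications typically reach this method through feast.FeatureStore rather than calling it directly. A minimal usage sketch, assuming a repo configured with a BigQuery offline store and a registered feature view named driver_hourly_stats (both hypothetical):

import pandas as pd
from feast import FeatureStore

store = FeatureStore(repo_path=".")  # repo with a BigQuery offline store

entity_df = pd.DataFrame(
    {
        "driver_id": [1001, 1002],
        "event_timestamp": pd.to_datetime(["2021-04-12", "2021-04-13"]),
    }
)

# FeatureStore.get_historical_features dispatches to the configured offline
# store's implementation, i.e. the function above for BigQuery.
job = store.get_historical_features(
    entity_df=entity_df,
    features=["driver_hourly_stats:conv_rate"],
)
training_df = job.to_df()  # runs the point-in-time join and fetches results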
Example #2
    def apply_on_demand_feature_view(
        self,
        on_demand_feature_view: OnDemandFeatureView,
        project: str,
        commit: bool = True,
    ):
        """
        Registers a single on demand feature view with Feast

        Args:
            on_demand_feature_view: Feature view that will be registered
            project: Feast project that this feature view belongs to
            commit: Whether the change should be persisted immediately
        """
        on_demand_feature_view_proto = on_demand_feature_view.to_proto()
        on_demand_feature_view_proto.spec.project = project
        self._prepare_registry_for_changes()
        assert self.cached_registry_proto

        if on_demand_feature_view.name in self._get_existing_feature_view_names():
            raise ConflictingFeatureViewNames(on_demand_feature_view.name)

        for idx, existing_feature_view_proto in enumerate(
                self.cached_registry_proto.on_demand_feature_views):
            if (existing_feature_view_proto.spec.name
                    == on_demand_feature_view_proto.spec.name
                    and existing_feature_view_proto.spec.project == project):
                if (OnDemandFeatureView.from_proto(existing_feature_view_proto)
                        == on_demand_feature_view):
                    return
                else:
                    del self.cached_registry_proto.on_demand_feature_views[idx]
                    break

        self.cached_registry_proto.on_demand_feature_views.append(
            on_demand_feature_view_proto)
        if commit:
            self.commit()
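
A hedged usage sketch: in application code the usual entry point is FeatureStore.apply, which routes on-demand feature views through this registry method; on_demand_feature_view_1 is assumed to be constructed as in Example #6 below.

from feast import FeatureStore

store = FeatureStore(repo_path=".")
# FeatureStore.apply registers the view via
# Registry.apply_on_demand_feature_view and persists the change.
store.apply([on_demand_feature_view_1])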
Example #3
    def list_on_demand_feature_views(
            self,
            project: str,
            allow_cache: bool = False) -> List[OnDemandFeatureView]:
        """
        Retrieve a list of on demand feature views from the registry

        Args:
            project: Filter on demand feature views based on project name
            allow_cache: Whether to allow returning on demand feature views from a cached registry

        Returns:
            List of on demand feature views
        """

        registry = self._get_registry_proto(allow_cache=allow_cache)
        on_demand_feature_views = []
        for on_demand_feature_view in registry.on_demand_feature_views:
            if on_demand_feature_view.spec.project == project:
                on_demand_feature_views.append(
                    OnDemandFeatureView.from_proto(on_demand_feature_view))
        return on_demand_feature_views
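
A short usage sketch, assuming the FeatureStore wrapper found in recent Feast versions (which fills in the project from the repo config):

from feast import FeatureStore

store = FeatureStore(repo_path=".")
# Delegates to Registry.list_on_demand_feature_views for the store's project.
for odfv in store.list_on_demand_feature_views():
    print(odfv.name)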
Example #4
    def get_on_demand_feature_view(
            self,
            name: str,
            project: str,
            allow_cache: bool = False) -> OnDemandFeatureView:
        """
        Retrieves an on demand feature view.

        Args:
            name: Name of on demand feature view
            project: Feast project that this on demand feature view belongs to
            allow_cache: Whether to allow returning the on demand feature view from a cached registry

        Returns:
            The specified on demand feature view; raises
            OnDemandFeatureViewNotFoundException if none is found
        """
        registry = self._get_registry_proto(allow_cache=allow_cache)

        for on_demand_feature_view in registry.on_demand_feature_views:
            if (on_demand_feature_view.spec.project == project
                    and on_demand_feature_view.spec.name == name):
                return OnDemandFeatureView.from_proto(on_demand_feature_view)
        raise OnDemandFeatureViewNotFoundException(name, project=project)
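
A usage sketch with the not-found case handled explicitly (the FeatureStore wrapper and the exception's module path are assumed from recent Feast versions):

from feast import FeatureStore
from feast.errors import OnDemandFeatureViewNotFoundException

store = FeatureStore(repo_path=".")
try:
    odfv = store.get_on_demand_feature_view("my-on-demand-feature-view")
except OnDemandFeatureViewNotFoundException:
    odfv = None  # the view is not registered in this project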
Example #5
    def get_historical_features(
        config: RepoConfig,
        feature_views: List[FeatureView],
        feature_refs: List[str],
        entity_df: Union[pd.DataFrame, str],
        registry: Registry,
        project: str,
        full_feature_names: bool = False,
        user: str = "user",
        auth: Optional[Authentication] = None,
        http_scheme: Optional[str] = None,
    ) -> TrinoRetrievalJob:
        if not isinstance(config.offline_store, TrinoOfflineStoreConfig):
            raise ValueError(
                f"This function should be used with a TrinoOfflineStoreConfig object; "
                f"got {type(config.offline_store)} instead."
            )

        client = _get_trino_client(config=config,
                                   user=user,
                                   auth=auth,
                                   http_scheme=http_scheme)

        table_reference = _get_table_reference_for_new_entity(
            catalog=config.offline_store.catalog,
            dataset_name=config.offline_store.dataset,
        )

        entity_schema = _upload_entity_df_and_get_entity_schema(
            client=client,
            table_name=table_reference,
            entity_df=entity_df,
            connector=config.offline_store.connector,
        )

        entity_df_event_timestamp_col = offline_utils.infer_event_timestamp_from_entity_df(
            entity_schema=entity_schema)

        entity_df_event_timestamp_range = _get_entity_df_event_timestamp_range(
            entity_df=entity_df,
            entity_df_event_timestamp_col=entity_df_event_timestamp_col,
            client=client,
        )

        expected_join_keys = offline_utils.get_expected_join_keys(
            project=project, feature_views=feature_views, registry=registry)

        offline_utils.assert_expected_columns_in_entity_df(
            entity_schema=entity_schema,
            join_keys=expected_join_keys,
            entity_df_event_timestamp_col=entity_df_event_timestamp_col,
        )

        # Build a query context containing all information required to template the Trino SQL query
        query_context = offline_utils.get_feature_view_query_context(
            feature_refs=feature_refs,
            feature_views=feature_views,
            registry=registry,
            project=project,
            entity_df_timestamp_range=entity_df_event_timestamp_range,
        )

        # Generate the Trino SQL query from the query context
        query = offline_utils.build_point_in_time_query(
            query_context,
            left_table_query_string=table_reference,
            entity_df_event_timestamp_col=entity_df_event_timestamp_col,
            entity_df_columns=entity_schema.keys(),
            query_template=MULTIPLE_FEATURE_VIEW_POINT_IN_TIME_JOIN,
            full_feature_names=full_feature_names,
        )

        return TrinoRetrievalJob(
            query=query,
            client=client,
            config=config,
            full_feature_names=full_feature_names,
            on_demand_feature_views=OnDemandFeatureView.get_requested_odfvs(
                feature_refs, project, registry),
            metadata=RetrievalMetadata(
                features=feature_refs,
                keys=list(
                    set(entity_schema.keys()) -
                    {entity_df_event_timestamp_col}),
                min_event_timestamp=entity_df_event_timestamp_range[0],
                max_event_timestamp=entity_df_event_timestamp_range[1],
            ),
        )
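
The Trino variant adds user, auth, and http_scheme so the client can authenticate. A hedged sketch of how the extra parameters might be supplied, using BasicAuthentication from the trino client library; config, feature_views, entity_df, and registry are assumed to already be in scope, and the credentials are hypothetical:

from trino.auth import BasicAuthentication

job = get_historical_features(
    config=config,  # RepoConfig whose offline_store is a TrinoOfflineStoreConfig
    feature_views=feature_views,
    feature_refs=["driver_hourly_stats:conv_rate"],
    entity_df=entity_df,
    registry=registry,
    project="my_project",
    user="feast_user",
    auth=BasicAuthentication("feast_user", "secret"),
    http_scheme="https",
)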
Example #6
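The test references udf1 and udf2 without defining them. A minimal set of imports and compatible definitions, assuming the Field/schema API of recent Feast releases (the exact transformations are hypothetical; only the function identities matter to the hash checks):

import pandas as pd

from feast import FeatureView, Field, FileSource, OnDemandFeatureView
from feast.types import Float32


def udf1(features_df: pd.DataFrame) -> pd.DataFrame:
    # Any deterministic transformation works here.
    df = pd.DataFrame()
    df["output1"] = features_df["feature1"]
    df["output2"] = features_df["feature2"]
    return df


def udf2(features_df: pd.DataFrame) -> pd.DataFrame:
    # A second, distinct udf so views built from it hash differently.
    df = pd.DataFrame()
    df["output1"] = features_df["feature1"] + 100.0
    df["output2"] = features_df["feature2"] + 100.0
    return df
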
def test_hash():
    file_source = FileSource(name="my-file-source", path="test.parquet")
    feature_view = FeatureView(
        name="my-feature-view",
        entities=[],
        schema=[
            Field(name="feature1", dtype=Float32),
            Field(name="feature2", dtype=Float32),
        ],
        source=file_source,
    )
    sources = [feature_view]
    on_demand_feature_view_1 = OnDemandFeatureView(
        name="my-on-demand-feature-view",
        sources=sources,
        schema=[
            Field(name="output1", dtype=Float32),
            Field(name="output2", dtype=Float32),
        ],
        udf=udf1,
    )
    on_demand_feature_view_2 = OnDemandFeatureView(
        name="my-on-demand-feature-view",
        sources=sources,
        schema=[
            Field(name="output1", dtype=Float32),
            Field(name="output2", dtype=Float32),
        ],
        udf=udf1,
    )
    on_demand_feature_view_3 = OnDemandFeatureView(
        name="my-on-demand-feature-view",
        sources=sources,
        schema=[
            Field(name="output1", dtype=Float32),
            Field(name="output2", dtype=Float32),
        ],
        udf=udf2,
    )
    on_demand_feature_view_4 = OnDemandFeatureView(
        name="my-on-demand-feature-view",
        sources=sources,
        schema=[
            Field(name="output1", dtype=Float32),
            Field(name="output2", dtype=Float32),
        ],
        udf=udf2,
        description="test",
    )

    # Views 1 and 2 are field-for-field identical, so they hash equal and
    # collapse to a single set element.
    s1 = {on_demand_feature_view_1, on_demand_feature_view_2}
    assert len(s1) == 1

    # View 3 differs from view 1 only in its udf, which makes it distinct.
    s2 = {on_demand_feature_view_1, on_demand_feature_view_3}
    assert len(s2) == 2

    # View 4 differs from view 3 only in its description, also distinct.
    s3 = {on_demand_feature_view_3, on_demand_feature_view_4}
    assert len(s3) == 2

    s4 = {
        on_demand_feature_view_1,
        on_demand_feature_view_2,
        on_demand_feature_view_3,
        on_demand_feature_view_4,
    }
    # Views 1 and 2 deduplicate; views 3 and 4 stay distinct: three in total.
    assert len(s4) == 3
Example #7
    def get_historical_features(
        config: RepoConfig,
        feature_views: List[FeatureView],
        feature_refs: List[str],
        entity_df: Union[pd.DataFrame, str],
        registry: Registry,
        project: str,
        full_feature_names: bool = False,
    ) -> RetrievalJob:
        # TODO: Add entity_df validation in order to fail before interacting with BigQuery
        assert isinstance(config.offline_store, BigQueryOfflineStoreConfig)

        client = _get_bigquery_client(
            project=config.offline_store.project_id,
            location=config.offline_store.location,
        )

        table_reference = _get_table_reference_for_new_entity(
            client,
            client.project,
            config.offline_store.dataset,
            config.offline_store.location,
        )

        entity_schema = _get_entity_schema(
            client=client,
            entity_df=entity_df,
        )

        entity_df_event_timestamp_col = offline_utils.infer_event_timestamp_from_entity_df(
            entity_schema)

        entity_df_event_timestamp_range = _get_entity_df_event_timestamp_range(
            entity_df,
            entity_df_event_timestamp_col,
            client,
        )

        @contextlib.contextmanager
        def query_generator() -> Iterator[str]:
            _upload_entity_df(
                client=client,
                table_name=table_reference,
                entity_df=entity_df,
            )

            expected_join_keys = offline_utils.get_expected_join_keys(
                project, feature_views, registry)

            offline_utils.assert_expected_columns_in_entity_df(
                entity_schema, expected_join_keys,
                entity_df_event_timestamp_col)

            # Build a query context containing all information required to template the BigQuery SQL query
            query_context = offline_utils.get_feature_view_query_context(
                feature_refs,
                feature_views,
                registry,
                project,
                entity_df_event_timestamp_range,
            )

            # Generate the BigQuery SQL query from the query context
            query = offline_utils.build_point_in_time_query(
                query_context,
                left_table_query_string=table_reference,
                entity_df_event_timestamp_col=entity_df_event_timestamp_col,
                entity_df_columns=entity_schema.keys(),
                query_template=MULTIPLE_FEATURE_VIEW_POINT_IN_TIME_JOIN,
                full_feature_names=full_feature_names,
            )

            try:
                yield query
            finally:
                # Asynchronously clean up the uploaded BigQuery table, which
                # will expire on its own if this cleanup fails
                client.delete_table(table=table_reference, not_found_ok=True)

        return BigQueryRetrievalJob(
            query=query_generator,
            client=client,
            config=config,
            full_feature_names=full_feature_names,
            on_demand_feature_views=OnDemandFeatureView.get_requested_odfvs(
                feature_refs, project, registry),
            metadata=RetrievalMetadata(
                features=feature_refs,
                keys=list(entity_schema.keys() -
                          {entity_df_event_timestamp_col}),
                min_event_timestamp=entity_df_event_timestamp_range[0],
                max_event_timestamp=entity_df_event_timestamp_range[1],
            ),
        )
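
The notable difference from Example #1 is that the entity-table upload and SQL generation are wrapped in a lazy query_generator context manager, so no BigQuery work happens until the retrieval job is actually consumed, and the temporary table is dropped afterwards. The same pattern in isolation, as a Feast-independent sketch:

import contextlib
from typing import Iterator


def upload_temp_table() -> None:  # hypothetical setup step
    print("table uploaded")


def drop_temp_table() -> None:  # hypothetical cleanup step
    print("table dropped")


@contextlib.contextmanager
def query_generator() -> Iterator[str]:
    upload_temp_table()
    try:
        # The consumer runs the query while the temporary table exists.
        yield "SELECT 1"
    finally:
        # Cleanup happens even if the consumer raises.
        drop_temp_table()


with query_generator() as query:
    print(query)  # prints "table uploaded", "SELECT 1", then "table dropped"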
Example #8
    def get_historical_features(
        config: RepoConfig,
        feature_views: List[FeatureView],
        feature_refs: List[str],
        entity_df: Union[pd.DataFrame, str],
        registry: Registry,
        project: str,
        full_feature_names: bool = False,
    ) -> RetrievalJob:
        @contextlib.contextmanager
        def query_generator() -> Iterator[str]:
            table_name = None
            if isinstance(entity_df, pd.DataFrame):
                table_name = offline_utils.get_temp_entity_table_name()
                entity_schema = df_to_postgres_table(config.offline_store,
                                                     entity_df, table_name)
                df_query = table_name
            elif isinstance(entity_df, str):
                df_query = f"({entity_df}) AS sub"
                entity_schema = get_query_schema(config.offline_store,
                                                 df_query)
            else:
                raise TypeError(
                    f"entity_df must be a pd.DataFrame or a SQL string, "
                    f"not {type(entity_df)}"
                )

            entity_df_event_timestamp_col = offline_utils.infer_event_timestamp_from_entity_df(
                entity_schema)

            expected_join_keys = offline_utils.get_expected_join_keys(
                project, feature_views, registry)

            offline_utils.assert_expected_columns_in_entity_df(
                entity_schema, expected_join_keys,
                entity_df_event_timestamp_col)

            entity_df_event_timestamp_range = _get_entity_df_event_timestamp_range(
                entity_df,
                entity_df_event_timestamp_col,
                config,
                df_query,
            )

            query_context = offline_utils.get_feature_view_query_context(
                feature_refs,
                feature_views,
                registry,
                project,
                entity_df_event_timestamp_range,
            )

            query_context_dict = [asdict(context) for context in query_context]
            # Hack: double-quote query_context.entity_selections so uppercase
            # column names are preserved
            for context in query_context_dict:
                context["entity_selections"] = [
                    f'''"{entity_selection.replace(' AS ', '" AS "')}\"'''
                    for entity_selection in context["entity_selections"]
                ]

            try:
                yield build_point_in_time_query(
                    query_context_dict,
                    left_table_query_string=df_query,
                    entity_df_event_timestamp_col=entity_df_event_timestamp_col,
                    entity_df_columns=entity_schema.keys(),
                    query_template=MULTIPLE_FEATURE_VIEW_POINT_IN_TIME_JOIN,
                    full_feature_names=full_feature_names,
                )
            finally:
                if table_name:
                    with _get_conn(config.offline_store) as conn, conn.cursor() as cur:
                        cur.execute(
                            sql.SQL("DROP TABLE IF EXISTS {};").format(
                                sql.Identifier(table_name)
                            )
                        )

        return PostgreSQLRetrievalJob(
            query=query_generator,
            config=config,
            full_feature_names=full_feature_names,
            on_demand_feature_views=OnDemandFeatureView.get_requested_odfvs(
                feature_refs, project, registry),
        )