Example #1
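A Snowflake offline store query generator: it uploads the entity dataframe to a temporary table, validates the expected join keys, builds the feature view query context, and yields the templated point-in-time join query.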
        def query_generator() -> Iterator[str]:
            table_name = offline_utils.get_temp_entity_table_name()

            _upload_entity_df(entity_df, snowflake_conn, config, table_name)

            expected_join_keys = offline_utils.get_expected_join_keys(
                project, feature_views, registry)

            offline_utils.assert_expected_columns_in_entity_df(
                entity_schema, expected_join_keys,
                entity_df_event_timestamp_col)

            # Build a query context containing all information required to template the Snowflake SQL query
            query_context = offline_utils.get_feature_view_query_context(
                feature_refs,
                feature_views,
                registry,
                project,
                entity_df_event_timestamp_range,
            )

            query_context = _fix_entity_selections_identifiers(query_context)

            # Generate the Snowflake SQL query from the query context
            query = offline_utils.build_point_in_time_query(
                query_context,
                left_table_query_string=table_name,
                entity_df_event_timestamp_col=entity_df_event_timestamp_col,
                entity_df_columns=entity_schema.keys(),
                query_template=MULTIPLE_FEATURE_VIEW_POINT_IN_TIME_JOIN,
                full_feature_names=full_feature_names,
            )

            yield query
Example #2
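A helper that builds the fully qualified reference (catalog.dataset.table) for the temporary entity table to be uploaded.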
def _get_table_reference_for_new_entity(
    catalog: str,
    dataset_name: str,
) -> str:
    """Gets the table_id for the new entity to be uploaded."""
    table_name = offline_utils.get_temp_entity_table_name()
    return f"{catalog}.{dataset_name}.{table_name}"
Example #3
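A Redshift offline store query generator: it uploads the entity dataframe, infers the event timestamp column, validates the join keys, builds the point-in-time join query, and drops the temporary Redshift table once the query has been consumed.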
        def query_generator() -> Iterator[str]:
            table_name = offline_utils.get_temp_entity_table_name()

            entity_schema = _upload_entity_df_and_get_entity_schema(
                entity_df, redshift_client, config, s3_resource, table_name)

            entity_df_event_timestamp_col = offline_utils.infer_event_timestamp_from_entity_df(
                entity_schema)

            expected_join_keys = offline_utils.get_expected_join_keys(
                project, feature_views, registry)

            offline_utils.assert_expected_columns_in_entity_df(
                entity_schema, expected_join_keys,
                entity_df_event_timestamp_col)

            entity_df_event_timestamp_range = _get_entity_df_event_timestamp_range(
                entity_df,
                entity_df_event_timestamp_col,
                redshift_client,
                config,
                table_name,
            )

            # Build a query context containing all information required to template the Redshift SQL query
            query_context = offline_utils.get_feature_view_query_context(
                feature_refs,
                feature_views,
                registry,
                project,
                entity_df_event_timestamp_range,
            )

            # Generate the Redshift SQL query from the query context
            query = offline_utils.build_point_in_time_query(
                query_context,
                left_table_query_string=table_name,
                entity_df_event_timestamp_col=entity_df_event_timestamp_col,
                entity_df_columns=entity_schema.keys(),
                query_template=MULTIPLE_FEATURE_VIEW_POINT_IN_TIME_JOIN,
                full_feature_names=full_feature_names,
            )

            try:
                yield query
            finally:
                # Always clean up the uploaded Redshift table
                aws_utils.execute_redshift_statement(
                    redshift_client,
                    config.offline_store.cluster_id,
                    config.offline_store.database,
                    config.offline_store.user,
                    f"DROP TABLE IF EXISTS {table_name}",
                )
Example #4
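A BigQuery helper that creates the target dataset if it does not already exist and returns the fully qualified table_id for the temporary entity table.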
def _get_table_reference_for_new_entity(client: Client, dataset_project: str,
                                        dataset_name: str) -> str:
    """Gets the table_id for the new entity to be uploaded."""

    # First create the BigQuery dataset if it doesn't exist
    dataset = bigquery.Dataset(f"{dataset_project}.{dataset_name}")
    dataset.location = "US"

    try:
        client.get_dataset(dataset)
    except NotFound:
        # Only create the dataset if it does not exist
        client.create_dataset(dataset, exists_ok=True)

    table_name = offline_utils.get_temp_entity_table_name()

    return f"{dataset_project}.{dataset_name}.{table_name}"
Example #5
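A Spark data source method that returns a SQL-ready table reference, falling back to loading the underlying file into a temporary view when neither a table nor a query is configured.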
    def get_table_query_string(self) -> str:
        """Returns a string that can directly be used to reference this table in SQL"""
        if self.table:
            # Backticks make sure that spark sql knows this is a table reference.
            return f"`{self.table}`"
        if self.query:
            return f"({self.query})"

        # If neither a table nor a query is set, load the data directly from the file.
        spark_session = SparkSession.getActiveSession()
        if spark_session is None:
            raise AssertionError("Could not find an active spark session.")
        try:
            df = spark_session.read.format(self.file_format).load(self.path)
        except Exception:
            # logger.exception already records the traceback; re-raise so that
            # `df` is never used while undefined.
            logger.exception("Spark read of file source failed.")
            raise
        tmp_table_name = get_temp_entity_table_name()
        df.createOrReplaceTempView(tmp_table_name)

        return f"`{tmp_table_name}`"
Example #6
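A PostgreSQL offline store query generator that accepts either a dataframe or a SQL string as the entity source, builds the point-in-time join query, and drops the temporary table in a finally block.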
        def query_generator() -> Iterator[str]:
            table_name = None
            if isinstance(entity_df, pd.DataFrame):
                table_name = offline_utils.get_temp_entity_table_name()
                entity_schema = df_to_postgres_table(config.offline_store,
                                                     entity_df, table_name)
                df_query = table_name
            elif isinstance(entity_df, str):
                df_query = f"({entity_df}) AS sub"
                entity_schema = get_query_schema(config.offline_store,
                                                 df_query)
            else:
                raise TypeError(f"Unexpected entity_df type: {type(entity_df)}")

            entity_df_event_timestamp_col = offline_utils.infer_event_timestamp_from_entity_df(
                entity_schema)

            expected_join_keys = offline_utils.get_expected_join_keys(
                project, feature_views, registry)

            offline_utils.assert_expected_columns_in_entity_df(
                entity_schema, expected_join_keys,
                entity_df_event_timestamp_col)

            entity_df_event_timestamp_range = _get_entity_df_event_timestamp_range(
                entity_df,
                entity_df_event_timestamp_col,
                config,
                df_query,
            )

            # Build a query context containing all information required to template the point-in-time join query
            query_context = offline_utils.get_feature_view_query_context(
                feature_refs,
                feature_views,
                registry,
                project,
                entity_df_event_timestamp_range,
            )

            query_context_dict = [asdict(context) for context in query_context]
            # Hack for query_context.entity_selections to support uppercase in columns
            for context in query_context_dict:
                context["entity_selections"] = [
                    f'''"{entity_selection.replace(' AS ', '" AS "')}\"'''
                    for entity_selection in context["entity_selections"]
                ]

            try:
                yield build_point_in_time_query(
                    query_context_dict,
                    left_table_query_string=df_query,
                    entity_df_event_timestamp_col=entity_df_event_timestamp_col,
                    entity_df_columns=entity_schema.keys(),
                    query_template=MULTIPLE_FEATURE_VIEW_POINT_IN_TIME_JOIN,
                    full_feature_names=full_feature_names,
                )
            finally:
                if table_name:
                    with _get_conn(config.offline_store) as conn, conn.cursor() as cur:
                        cur.execute(
                            sql.SQL("DROP TABLE IF EXISTS {};").format(
                                sql.Identifier(table_name)))
Example #7
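The Spark offline store's get_historical_features flow: it uploads the entity dataframe to a temporary table, builds the point-in-time join query, and returns a SparkRetrievalJob with retrieval metadata.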
    def get_historical_features(
        config: RepoConfig,
        feature_views: List[FeatureView],
        feature_refs: List[str],
        entity_df: Union[pandas.DataFrame, str],
        registry: Registry,
        project: str,
        full_feature_names: bool = False,
    ) -> RetrievalJob:
        assert isinstance(config.offline_store, SparkOfflineStoreConfig)
        warnings.warn(
            "The spark offline store is an experimental feature in alpha development. "
            "Some functionality may still be unstable so functionality can change in the future.",
            RuntimeWarning,
        )
        spark_session = get_spark_session_or_start_new_with_repoconfig(
            store_config=config.offline_store)
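        # Generate a temporary table name for uploading the entity dataframe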
        tmp_entity_df_table_name = offline_utils.get_temp_entity_table_name()

        entity_schema = _get_entity_schema(
            spark_session=spark_session,
            entity_df=entity_df,
        )
        event_timestamp_col = offline_utils.infer_event_timestamp_from_entity_df(
            entity_schema=entity_schema)
        entity_df_event_timestamp_range = _get_entity_df_event_timestamp_range(
            entity_df,
            event_timestamp_col,
            spark_session,
        )
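        # Upload the entity dataframe under the temporary table name so the query can join against it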
        _upload_entity_df(
            spark_session=spark_session,
            table_name=tmp_entity_df_table_name,
            entity_df=entity_df,
            event_timestamp_col=event_timestamp_col,
        )

        expected_join_keys = offline_utils.get_expected_join_keys(
            project=project, feature_views=feature_views, registry=registry)
        offline_utils.assert_expected_columns_in_entity_df(
            entity_schema=entity_schema,
            join_keys=expected_join_keys,
            entity_df_event_timestamp_col=event_timestamp_col,
        )

        # Build a query context containing all information required to template the Spark SQL query
        query_context = offline_utils.get_feature_view_query_context(
            feature_refs,
            feature_views,
            registry,
            project,
            entity_df_event_timestamp_range,
        )

        # Generate the Spark SQL query from the query context
        query = offline_utils.build_point_in_time_query(
            feature_view_query_contexts=query_context,
            left_table_query_string=tmp_entity_df_table_name,
            entity_df_event_timestamp_col=event_timestamp_col,
            entity_df_columns=entity_schema.keys(),
            query_template=MULTIPLE_FEATURE_VIEW_POINT_IN_TIME_JOIN,
            full_feature_names=full_feature_names,
        )

        return SparkRetrievalJob(
            spark_session=spark_session,
            query=query,
            full_feature_names=full_feature_names,
            on_demand_feature_views=OnDemandFeatureView.get_requested_odfvs(
                feature_refs, project, registry),
            metadata=RetrievalMetadata(
                features=feature_refs,
                keys=list(set(entity_schema.keys()) - {event_timestamp_col}),
                min_event_timestamp=entity_df_event_timestamp_range[0],
                max_event_timestamp=entity_df_event_timestamp_range[1],
            ),
        )