def query_generator() -> Iterator[str]:
    table_name = offline_utils.get_temp_entity_table_name()

    _upload_entity_df(entity_df, snowflake_conn, config, table_name)

    expected_join_keys = offline_utils.get_expected_join_keys(
        project, feature_views, registry
    )

    offline_utils.assert_expected_columns_in_entity_df(
        entity_schema, expected_join_keys, entity_df_event_timestamp_col
    )

    # Build a query context containing all information required to template the Snowflake SQL query
    query_context = offline_utils.get_feature_view_query_context(
        feature_refs,
        feature_views,
        registry,
        project,
        entity_df_event_timestamp_range,
    )

    query_context = _fix_entity_selections_identifiers(query_context)

    # Generate the Snowflake SQL query from the query context
    query = offline_utils.build_point_in_time_query(
        query_context,
        left_table_query_string=table_name,
        entity_df_event_timestamp_col=entity_df_event_timestamp_col,
        entity_df_columns=entity_schema.keys(),
        query_template=MULTIPLE_FEATURE_VIEW_POINT_IN_TIME_JOIN,
        full_feature_names=full_feature_names,
    )

    yield query
def _get_table_reference_for_new_entity(
    catalog: str,
    dataset_name: str,
) -> str:
    """Gets the table_id for the new entity to be uploaded."""
    table_name = offline_utils.get_temp_entity_table_name()
    return f"{catalog}.{dataset_name}.{table_name}"
def query_generator() -> Iterator[str]:
    table_name = offline_utils.get_temp_entity_table_name()

    entity_schema = _upload_entity_df_and_get_entity_schema(
        entity_df, redshift_client, config, s3_resource, table_name
    )

    entity_df_event_timestamp_col = offline_utils.infer_event_timestamp_from_entity_df(
        entity_schema
    )

    expected_join_keys = offline_utils.get_expected_join_keys(
        project, feature_views, registry
    )

    offline_utils.assert_expected_columns_in_entity_df(
        entity_schema, expected_join_keys, entity_df_event_timestamp_col
    )

    entity_df_event_timestamp_range = _get_entity_df_event_timestamp_range(
        entity_df,
        entity_df_event_timestamp_col,
        redshift_client,
        config,
        table_name,
    )

    # Build a query context containing all information required to template the Redshift SQL query
    query_context = offline_utils.get_feature_view_query_context(
        feature_refs,
        feature_views,
        registry,
        project,
        entity_df_event_timestamp_range,
    )

    # Generate the Redshift SQL query from the query context
    query = offline_utils.build_point_in_time_query(
        query_context,
        left_table_query_string=table_name,
        entity_df_event_timestamp_col=entity_df_event_timestamp_col,
        entity_df_columns=entity_schema.keys(),
        query_template=MULTIPLE_FEATURE_VIEW_POINT_IN_TIME_JOIN,
        full_feature_names=full_feature_names,
    )

    try:
        yield query
    finally:
        # Always clean up the uploaded Redshift table
        aws_utils.execute_redshift_statement(
            redshift_client,
            config.offline_store.cluster_id,
            config.offline_store.database,
            config.offline_store.user,
            f"DROP TABLE IF EXISTS {table_name}",
        )
def _get_table_reference_for_new_entity(
    client: Client,
    dataset_project: str,
    dataset_name: str,
) -> str:
    """Gets the table_id for the new entity to be uploaded."""
    # First create the BigQuery dataset if it doesn't exist
    dataset = bigquery.Dataset(f"{dataset_project}.{dataset_name}")
    dataset.location = "US"

    try:
        client.get_dataset(dataset)
    except NotFound:
        # Only create the dataset if it does not exist
        client.create_dataset(dataset, exists_ok=True)

    table_name = offline_utils.get_temp_entity_table_name()

    return f"{dataset_project}.{dataset_name}.{table_name}"
def get_table_query_string(self) -> str:
    """Returns a string that can directly be used to reference this table in SQL."""
    if self.table:
        # Backticks make sure that Spark SQL knows this is a table reference.
        return f"`{self.table}`"
    if self.query:
        return f"({self.query})"

    # If both the table reference and the query are empty, load the data from file.
    spark_session = SparkSession.getActiveSession()
    if spark_session is None:
        raise AssertionError("Could not find an active spark session.")
    try:
        df = spark_session.read.format(self.file_format).load(self.path)
    except Exception:
        logger.exception("Spark read of file source failed.\n" + traceback.format_exc())
        # Re-raise: otherwise `df` would be undefined below.
        raise
    tmp_table_name = get_temp_entity_table_name()
    df.createOrReplaceTempView(tmp_table_name)

    return f"`{tmp_table_name}`"
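# Hedged usage sketch (not part of the original code): how the string returned by
# get_table_query_string() can be interpolated into Spark SQL. `source` is an assumed
# SparkSource-like object; the surrounding query is illustrative only.
spark_session = SparkSession.getActiveSession()
table_or_subquery = source.get_table_query_string()  # e.g. "`my_table`" or "(SELECT ...)"
preview_df = spark_session.sql(f"SELECT * FROM {table_or_subquery}")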
def query_generator() -> Iterator[str]:
    table_name = None
    if isinstance(entity_df, pd.DataFrame):
        table_name = offline_utils.get_temp_entity_table_name()
        entity_schema = df_to_postgres_table(
            config.offline_store, entity_df, table_name
        )
        df_query = table_name
    elif isinstance(entity_df, str):
        df_query = f"({entity_df}) AS sub"
        entity_schema = get_query_schema(config.offline_store, df_query)
    else:
        raise TypeError(entity_df)

    entity_df_event_timestamp_col = offline_utils.infer_event_timestamp_from_entity_df(
        entity_schema
    )

    expected_join_keys = offline_utils.get_expected_join_keys(
        project, feature_views, registry
    )

    offline_utils.assert_expected_columns_in_entity_df(
        entity_schema, expected_join_keys, entity_df_event_timestamp_col
    )

    entity_df_event_timestamp_range = _get_entity_df_event_timestamp_range(
        entity_df,
        entity_df_event_timestamp_col,
        config,
        df_query,
    )

    query_context = offline_utils.get_feature_view_query_context(
        feature_refs,
        feature_views,
        registry,
        project,
        entity_df_event_timestamp_range,
    )
    query_context_dict = [asdict(context) for context in query_context]
    # Hack for query_context.entity_selections to support uppercase in columns
    for context in query_context_dict:
        context["entity_selections"] = [
            f'''"{entity_selection.replace(' AS ', '" AS "')}\"'''
            for entity_selection in context["entity_selections"]
        ]

    try:
        yield build_point_in_time_query(
            query_context_dict,
            left_table_query_string=df_query,
            entity_df_event_timestamp_col=entity_df_event_timestamp_col,
            entity_df_columns=entity_schema.keys(),
            query_template=MULTIPLE_FEATURE_VIEW_POINT_IN_TIME_JOIN,
            full_feature_names=full_feature_names,
        )
    finally:
        if table_name:
            with _get_conn(config.offline_store) as conn, conn.cursor() as cur:
                cur.execute(
                    sql.SQL(
                        """
                        DROP TABLE IF EXISTS {};
                        """
                    ).format(sql.Identifier(table_name)),
                )
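# Minimal sketch (illustrative values only, not part of the original code) of what the
# entity_selections quoting hack above produces: each "<col> AS <alias>" selection gets
# both sides double-quoted so Postgres preserves uppercase identifiers.
selection = "Driver_Id AS driver_id"
quoted = f'''"{selection.replace(' AS ', '" AS "')}\"'''
print(quoted)  # "Driver_Id" AS "driver_id"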
def get_historical_features(
    config: RepoConfig,
    feature_views: List[FeatureView],
    feature_refs: List[str],
    entity_df: Union[pandas.DataFrame, str],
    registry: Registry,
    project: str,
    full_feature_names: bool = False,
) -> RetrievalJob:
    assert isinstance(config.offline_store, SparkOfflineStoreConfig)
    warnings.warn(
        "The Spark offline store is an experimental feature in alpha development. "
        "Some functionality may still be unstable, so it may change in the future.",
        RuntimeWarning,
    )
    spark_session = get_spark_session_or_start_new_with_repoconfig(
        store_config=config.offline_store
    )
    tmp_entity_df_table_name = offline_utils.get_temp_entity_table_name()

    entity_schema = _get_entity_schema(
        spark_session=spark_session,
        entity_df=entity_df,
    )
    event_timestamp_col = offline_utils.infer_event_timestamp_from_entity_df(
        entity_schema=entity_schema,
    )
    entity_df_event_timestamp_range = _get_entity_df_event_timestamp_range(
        entity_df,
        event_timestamp_col,
        spark_session,
    )

    _upload_entity_df(
        spark_session=spark_session,
        table_name=tmp_entity_df_table_name,
        entity_df=entity_df,
        event_timestamp_col=event_timestamp_col,
    )

    expected_join_keys = offline_utils.get_expected_join_keys(
        project=project, feature_views=feature_views, registry=registry
    )
    offline_utils.assert_expected_columns_in_entity_df(
        entity_schema=entity_schema,
        join_keys=expected_join_keys,
        entity_df_event_timestamp_col=event_timestamp_col,
    )

    query_context = offline_utils.get_feature_view_query_context(
        feature_refs,
        feature_views,
        registry,
        project,
        entity_df_event_timestamp_range,
    )

    query = offline_utils.build_point_in_time_query(
        feature_view_query_contexts=query_context,
        left_table_query_string=tmp_entity_df_table_name,
        entity_df_event_timestamp_col=event_timestamp_col,
        entity_df_columns=entity_schema.keys(),
        query_template=MULTIPLE_FEATURE_VIEW_POINT_IN_TIME_JOIN,
        full_feature_names=full_feature_names,
    )

    return SparkRetrievalJob(
        spark_session=spark_session,
        query=query,
        full_feature_names=full_feature_names,
        on_demand_feature_views=OnDemandFeatureView.get_requested_odfvs(
            feature_refs, project, registry
        ),
        metadata=RetrievalMetadata(
            features=feature_refs,
            keys=list(set(entity_schema.keys()) - {event_timestamp_col}),
            min_event_timestamp=entity_df_event_timestamp_range[0],
            max_event_timestamp=entity_df_event_timestamp_range[1],
        ),
    )
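# Hedged usage sketch: get_historical_features() above is normally reached through the
# public Feast API rather than called directly. The repo path, entity columns, and
# feature references below are illustrative assumptions, not taken from the code above.
from feast import FeatureStore
import pandas as pd

store = FeatureStore(repo_path=".")  # assumes a feature repo configured with the Spark offline store
entity_df = pd.DataFrame(
    {
        "driver_id": [1001, 1002],
        "event_timestamp": pd.to_datetime(["2022-05-01", "2022-05-02"]),
    }
)
training_df = store.get_historical_features(
    entity_df=entity_df,
    features=["driver_hourly_stats:conv_rate"],
).to_df()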