# Imports assumed from butterfree's public modules and the standard library
# (they are not shown in the original excerpt).
import shutil
from unittest.mock import Mock

from butterfree.clients import SparkClient
from butterfree.load import Sink
from butterfree.load.writers import (
    HistoricalFeatureStoreWriter,
    OnlineFeatureStoreWriter,
)


def test_sink(input_dataframe, feature_set):
    # arrange
    client = SparkClient()
    client.conn.conf.set("spark.sql.sources.partitionOverwriteMode", "dynamic")
    feature_set_df = feature_set.construct(input_dataframe, client)
    target_latest_df = OnlineFeatureStoreWriter.filter_latest(
        feature_set_df, id_columns=[key.name for key in feature_set.keys]
    )
    columns_sort = feature_set_df.schema.fieldNames()

    # setup historical writer
    s3config = Mock()
    s3config.mode = "overwrite"
    s3config.format_ = "parquet"
    s3config.get_options = Mock(
        return_value={"path": "test_folder/historical/entity/feature_set"}
    )
    s3config.get_path_with_partitions = Mock(
        return_value="test_folder/historical/entity/feature_set"
    )

    historical_writer = HistoricalFeatureStoreWriter(
        db_config=s3config, interval_mode=True
    )

    # setup online writer
    # TODO: Change for CassandraConfig when Cassandra for test is ready
    online_config = Mock()
    online_config.mode = "overwrite"
    online_config.format_ = "parquet"
    online_config.get_options = Mock(
        return_value={"path": "test_folder/online/entity/feature_set"}
    )
    online_writer = OnlineFeatureStoreWriter(db_config=online_config)

    writers = [historical_writer, online_writer]
    sink = Sink(writers)

    # act
    client.sql("CREATE DATABASE IF NOT EXISTS {}".format(historical_writer.database))
    sink.flush(feature_set, feature_set_df, client)

    # get historical results
    historical_result_df = client.read(
        s3config.format_,
        path=s3config.get_path_with_partitions(feature_set.name, feature_set_df),
    )

    # get online results
    online_result_df = client.read(
        online_config.format_, **online_config.get_options(feature_set.name)
    )

    # assert historical results
    assert sorted(feature_set_df.select(*columns_sort).collect()) == sorted(
        historical_result_df.select(*columns_sort).collect()
    )

    # assert online results
    assert sorted(target_latest_df.select(*columns_sort).collect()) == sorted(
        online_result_df.select(*columns_sort).collect()
    )

    # tear down
    shutil.rmtree("test_folder")
def construct(self, client: SparkClient) -> DataFrame:
    """Construct an entry point dataframe for a feature set.

    This method assembles the multiple readers by building each one and then
    querying them with Spark SQL. The resulting dataframe is cached; since
    cache() in Spark is lazy, an action (count) is triggered to force
    persistence.

    Args:
        client: client responsible for connecting to the Spark session.

    Returns:
        DataFrame with the query result against all readers.

    """
    # create a temporary view for each reader
    for reader in self.readers:
        reader.build(client)

    dataframe = client.sql(self.query)

    if not dataframe.isStreaming:
        dataframe.cache().count()

    return dataframe
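# A minimal usage sketch, not part of the original excerpt: it shows how a
# construct() call like the one above could be wired up, assuming butterfree's
# documented Source/TableReader API. The database, table, and query names
# below are illustrative assumptions.
from butterfree.clients import SparkClient
from butterfree.extract import Source
from butterfree.extract.readers import TableReader

source = Source(
    readers=[
        TableReader(id="events", database="bronze", table="events"),
    ],
    query="SELECT * FROM events",  # runs against the readers' temporary views
)

spark_client = SparkClient()
# builds the readers, runs the query, and caches the non-streaming result
source_df = source.construct(spark_client)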
def test_sql(self, target_df: DataFrame) -> None:
    # arrange
    spark_client = SparkClient()
    create_temp_view(target_df, "test")

    # act
    result_df = spark_client.sql("select * from test")

    # assert
    assert result_df.collect() == target_df.collect()
def construct(
    self, client: SparkClient, start_date: str = None, end_date: str = None
) -> DataFrame:
    """Construct an entry point dataframe for a feature set.

    This method assembles the multiple readers by building each one and then
    querying them with Spark SQL. Note that to filter the dataframe by date
    boundaries, an IncrementalStrategy must be defined on the readers;
    otherwise the data will not be filtered. Both the start and end date
    parameters are optional.

    The resulting dataframe is cached; since cache() in Spark is lazy, an
    action (count) is triggered to force persistence.

    Args:
        client: client responsible for connecting to the Spark session.
        start_date: user defined start date for filtering.
        end_date: user defined end date for filtering.

    Returns:
        DataFrame with the query result against all readers.

    """
    # create a temporary view for each reader
    for reader in self.readers:
        reader.build(client=client, start_date=start_date, end_date=end_date)

    dataframe = client.sql(self.query)

    if not dataframe.isStreaming:
        dataframe.cache().count()

    post_hook_df = self.run_post_hooks(dataframe)

    return post_hook_df
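# A minimal usage sketch, not part of the original excerpt: it illustrates the
# optional date boundaries accepted by this construct() variant. The database,
# table, query, and dates below are illustrative assumptions; as the docstring
# notes, the dates only take effect if the readers are configured with an
# IncrementalStrategy.
from butterfree.clients import SparkClient
from butterfree.extract import Source
from butterfree.extract.readers import TableReader

source = Source(
    readers=[TableReader(id="events", database="bronze", table="events")],
    query="SELECT * FROM events",
)

spark_client = SparkClient()
filtered_df = source.construct(
    client=spark_client,
    start_date="2021-01-01",  # optional lower date boundary
    end_date="2021-01-31",  # optional upper date boundary
)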