def test_write(
    self,
    feature_set_dataframe,
    historical_feature_set_dataframe,
    mocker,
    feature_set,
):
    # given
    spark_client = mocker.stub("spark_client")
    spark_client.write_table = mocker.stub("write_table")
    writer = HistoricalFeatureStoreWriter()

    # when
    writer.write(
        feature_set=feature_set,
        dataframe=feature_set_dataframe,
        spark_client=spark_client,
    )
    result_df = spark_client.write_table.call_args[1]["dataframe"]

    # then
    assert_dataframe_equality(historical_feature_set_dataframe, result_df)

    assert writer.db_config.format_ == spark_client.write_table.call_args[1]["format_"]
    assert writer.db_config.mode == spark_client.write_table.call_args[1]["mode"]
    assert writer.PARTITION_BY == spark_client.write_table.call_args[1]["partition_by"]
    assert feature_set.name == spark_client.write_table.call_args[1]["table_name"]
def test_write_interval_mode(
    self,
    feature_set_dataframe,
    historical_feature_set_dataframe,
    mocker,
    feature_set,
):
    # given
    spark_client = SparkClient()
    spark_client.write_table = mocker.stub("write_table")
    # interval mode requires dynamic partition overwrite
    spark_client.conn.conf.set(
        "spark.sql.sources.partitionOverwriteMode", "dynamic"
    )
    writer = HistoricalFeatureStoreWriter(interval_mode=True)

    # when
    writer.write(
        feature_set=feature_set,
        dataframe=feature_set_dataframe,
        spark_client=spark_client,
    )
    result_df = spark_client.write_table.call_args[1]["dataframe"]

    # then
    assert_dataframe_equality(historical_feature_set_dataframe, result_df)

    assert writer.database == spark_client.write_table.call_args[1]["database"]
    assert feature_set.name == spark_client.write_table.call_args[1]["table_name"]
    assert writer.PARTITION_BY == spark_client.write_table.call_args[1]["partition_by"]
def test_write_in_debug_mode_with_interval_mode(
    self,
    feature_set_dataframe,
    historical_feature_set_dataframe,
    feature_set,
    spark_session,
    mocker,
):
    # given
    spark_client = SparkClient()
    spark_client.write_dataframe = mocker.stub("write_dataframe")
    spark_client.conn.conf.set(
        "spark.sql.sources.partitionOverwriteMode", "dynamic"
    )
    writer = HistoricalFeatureStoreWriter(debug_mode=True, interval_mode=True)

    # when
    writer.write(
        feature_set=feature_set,
        dataframe=feature_set_dataframe,
        spark_client=spark_client,
    )
    # debug mode writes to a temporary view instead of a table
    result_df = spark_session.table(f"historical_feature_store__{feature_set.name}")

    # then
    assert_dataframe_equality(historical_feature_set_dataframe, result_df)
def test_write_in_debug_mode(
    self,
    feature_set_dataframe,
    historical_feature_set_dataframe,
    feature_set,
    spark_session,
):
    # given
    spark_client = SparkClient()
    writer = HistoricalFeatureStoreWriter(debug_mode=True)

    # when
    writer.write(
        feature_set=feature_set,
        dataframe=feature_set_dataframe,
        spark_client=spark_client,
    )
    result_df = spark_session.table(f"historical_feature_store__{feature_set.name}")

    # then
    assert_dataframe_equality(historical_feature_set_dataframe, result_df)
def test_write_interval_mode_invalid_partition_mode(
    self,
    feature_set_dataframe,
    historical_feature_set_dataframe,
    mocker,
    feature_set,
):
    # given
    spark_client = SparkClient()
    spark_client.write_dataframe = mocker.stub("write_dataframe")
    # static partition overwrite is invalid for interval mode and should fail
    spark_client.conn.conf.set(
        "spark.sql.sources.partitionOverwriteMode", "static"
    )
    writer = HistoricalFeatureStoreWriter(interval_mode=True)

    # when
    with pytest.raises(RuntimeError):
        _ = writer.write(
            feature_set=feature_set,
            dataframe=feature_set_dataframe,
            spark_client=spark_client,
        )