def test_feature_set_start_date(
    self,
    timestamp_c,
    feature_set_with_distinct_dataframe,
):
    fs = AggregatedFeatureSet(
        name="name",
        entity="entity",
        description="description",
        features=[
            Feature(
                name="feature",
                description="test",
                transformation=AggregatedTransform(
                    functions=[Function(functions.sum, DataType.INTEGER)]
                ),
            ),
        ],
        keys=[KeyFeature(name="h3", description="test", dtype=DataType.STRING)],
        timestamp=timestamp_c,
    ).with_windows(["10 days", "3 weeks", "90 days"])

    # act
    start_date = fs.define_start_date("2016-04-14")

    # assert
    assert start_date == "2016-01-14"
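# A hedged sketch (not butterfree's implementation) of why the asserted start date is
# "2016-01-14": the largest window above is 90 days, and the expected value sits 91 days
# before the end date, which is consistent with rounding the biggest window up to whole
# weeks (ceil(90 / 7) = 13 weeks = 91 days). The exact rule the library applies may differ.
import math
from datetime import datetime, timedelta


def approximate_start_date(end_date: str, biggest_window_days: int) -> str:
    lookback_days = math.ceil(biggest_window_days / 7) * 7  # round up to whole weeks
    start = datetime.strptime(end_date, "%Y-%m-%d") - timedelta(days=lookback_days)
    return start.strftime("%Y-%m-%d")


assert approximate_start_date("2016-04-14", 90) == "2016-01-14"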
def feature_set(): feature_set = FeatureSet( name="feature_set", entity="entity", description="description", features=[ Feature( name="feature1", description="test", transformation=SparkFunctionTransform(functions=[ Function(functions.avg, DataType.FLOAT), Function(functions.stddev_pop, DataType.DOUBLE), ]).with_window( partition_by="id", order_by=TIMESTAMP_COLUMN, mode="fixed_windows", window_definition=["2 minutes", "15 minutes"], ), ), ], keys=[ KeyFeature( name="id", description="The user's Main ID or device ID", dtype=DataType.BIGINT, ) ], timestamp=TimestampFeature(), ) return feature_set
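# For reference, a rough PySpark equivalent of what the "2 minutes" fixed window above
# computes for feature1. Column naming and exact semantics are butterfree's; this sketch
# only illustrates the range-based window and is an approximation.
from pyspark.sql import functions as F
from pyspark.sql.window import Window


def avg_over_2_minutes_fixed_window(df):
    # range window keyed on the event timestamp (in seconds), looking 2 minutes back
    w = (
        Window.partitionBy("id")
        .orderBy(F.col("timestamp").cast("long"))
        .rangeBetween(-2 * 60, 0)
    )
    return df.withColumn(
        "feature1__avg_over_2_minutes_fixed_windows",
        F.avg("feature1").over(w),
    )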
def transformer():
    # primary key
    keys = [
        KeyFeature(
            name="customer_id",
            description="Unique identifier code for the customer.",
            from_column="customer_id",
            dtype=DataType.STRING,
        )
    ]

    ts_feature = TimestampFeature(from_column="order_created_at")

    # features transformations
    features = [
        # order_total_amount(),
        count_items_in_order(),
        avg_order_total_amount_from_last_1_month(),
        ratio_order_amount_and_items(),
        ratio_order_amount_and_average_ticket(),
    ]

    # joining all together
    feature_set = FeatureSet(
        name="orders_feature_master_table",
        # entity: to which "business context" this feature set belongs
        entity="orders_feature_master_table",
        description="Features describing events about the ifood store.",
        keys=keys,
        timestamp=ts_feature,
        features=features,
    )
    return feature_set
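# The feature factories used above (count_items_in_order, etc.) are defined elsewhere in
# this example. A minimal sketch of what one of them could look like is shown below; the
# source column name "items_count" and the factory body are assumptions for illustration.
from butterfree.constants import DataType
from butterfree.transform.features import Feature


def count_items_in_order_sketch():
    return Feature(
        name="items_count",
        description="Total number of items in the order.",
        dtype=DataType.INTEGER,
        from_column="items_count",
    )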
def test_feature_set(): return AggregatedFeatureSet( name="feature_set", entity="entity", description="description", features=[ Feature( name="feature1", description="test", transformation=AggregatedTransform(functions=[ Function(functions.avg, DataType.DOUBLE), Function(functions.stddev_pop, DataType.DOUBLE), ]), ), Feature( name="feature2", description="test", transformation=AggregatedTransform( functions=[Function(functions.count, DataType.INTEGER)]), ), ], keys=[ KeyFeature( name="id", description="The user's Main ID or device ID", dtype=DataType.BIGINT, ) ], timestamp=TimestampFeature(), ).with_windows(definitions=["1 week", "2 days"])
def test_construct( self, feature_set_dataframe, fixed_windows_output_feature_set_dataframe ): # given spark_client = SparkClient() # arrange feature_set = FeatureSet( name="feature_set", entity="entity", description="description", features=[ Feature( name="feature1", description="test", transformation=SparkFunctionTransform( functions=[ Function(F.avg, DataType.FLOAT), Function(F.stddev_pop, DataType.FLOAT), ] ).with_window( partition_by="id", order_by=TIMESTAMP_COLUMN, mode="fixed_windows", window_definition=["2 minutes", "15 minutes"], ), ), Feature( name="divided_feature", description="unit test", dtype=DataType.FLOAT, transformation=CustomTransform( transformer=divide, column1="feature1", column2="feature2", ), ), ], keys=[ KeyFeature( name="id", description="The user's Main ID or device ID", dtype=DataType.INTEGER, ) ], timestamp=TimestampFeature(), ) output_df = ( feature_set.construct(feature_set_dataframe, client=spark_client) .orderBy(feature_set.timestamp_column) .select(feature_set.columns) ) target_df = fixed_windows_output_feature_set_dataframe.orderBy( feature_set.timestamp_column ).select(feature_set.columns) # assert assert_dataframe_equality(output_df, target_df)
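# The custom transformer referenced above (divide) is not shown here. A minimal sketch,
# assuming CustomTransform hands the transformer the input dataframe plus the parent
# feature and the declared kwargs (column1/column2), could look like this:
from pyspark.sql import functions as F


def divide(df, parent_feature, column1, column2):
    # create the output column named after the parent feature ("divided_feature" above)
    return df.withColumn(parent_feature.name, F.col(column1) / F.col(column2))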
def test_h3_feature_set(self, h3_input_df, h3_target_df): spark_client = SparkClient() feature_set = AggregatedFeatureSet( name="h3_test", entity="h3geolocation", description="Test", keys=[ KeyFeature( name="h3_id", description="The h3 hash ID", dtype=DataType.DOUBLE, transformation=H3HashTransform( h3_resolutions=[6, 7, 8, 9, 10, 11, 12], lat_column="lat", lng_column="lng", ).with_stack(), ) ], timestamp=TimestampFeature(), features=[ Feature( name="house_id", description="Count of house ids over a day.", transformation=AggregatedTransform( functions=[Function(F.count, DataType.BIGINT)]), ), ], ).with_windows(definitions=["1 day"]) output_df = feature_set.construct(h3_input_df, client=spark_client, end_date="2016-04-14") assert_dataframe_equality(output_df, h3_target_df)
def test_feature_transform_with_distinct(
    self,
    timestamp_c,
    feature_set_with_distinct_dataframe,
    target_with_distinct_dataframe,
):
    spark_client = SparkClient()

    fs = (
        AggregatedFeatureSet(
            name="name",
            entity="entity",
            description="description",
            features=[
                Feature(
                    name="feature",
                    description="test",
                    transformation=AggregatedTransform(
                        functions=[Function(functions.sum, DataType.INTEGER)]
                    ),
                ),
            ],
            keys=[KeyFeature(name="h3", description="test", dtype=DataType.STRING)],
            timestamp=timestamp_c,
        )
        .with_windows(["3 days"])
        .with_distinct(subset=["id"], keep="last")
    )

    # act
    output_df = fs.construct(
        feature_set_with_distinct_dataframe, spark_client, end_date="2020-01-10"
    )

    # assert
    assert_dataframe_equality(output_df, target_with_distinct_dataframe)
def __init__(self):
    super(FirstPipeline, self).__init__(
        source=Source(
            readers=[TableReader(id="t", database="db", table="table")],
            query="select * from t",
        ),
        feature_set=FeatureSet(
            name="first",
            entity="entity",
            description="description",
            features=[
                Feature(
                    name="feature1",
                    description="test",
                    dtype=DataType.FLOAT,
                ),
                Feature(
                    name="feature2",
                    description="another test",
                    dtype=DataType.STRING,
                ),
            ],
            keys=[
                KeyFeature(
                    name="id",
                    description="identifier",
                    dtype=DataType.BIGINT,
                )
            ],
            timestamp=TimestampFeature(),
        ),
        sink=Sink(
            writers=[HistoricalFeatureStoreWriter(), OnlineFeatureStoreWriter()]
        ),
    )
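# Usage sketch: a FeatureSetPipeline subclass such as FirstPipeline above is executed
# through run(); other tests in this suite also pass run(end_date=...) to bound the load.
def run_first_pipeline():
    FirstPipeline().run()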
def test_get_schema(self): expected_schema = [ {"column_name": "id", "type": LongType(), "primary_key": True}, {"column_name": "timestamp", "type": TimestampType(), "primary_key": False}, { "column_name": "feature1__avg_over_2_minutes_fixed_windows", "type": FloatType(), "primary_key": False, }, { "column_name": "feature1__avg_over_15_minutes_fixed_windows", "type": FloatType(), "primary_key": False, }, { "column_name": "feature1__stddev_pop_over_2_minutes_fixed_windows", "type": DoubleType(), "primary_key": False, }, { "column_name": "feature1__stddev_pop_over_15_minutes_fixed_windows", "type": DoubleType(), "primary_key": False, }, ] feature_set = FeatureSet( name="feature_set", entity="entity", description="description", features=[ Feature( name="feature1", description="test", transformation=SparkFunctionTransform( functions=[ Function(F.avg, DataType.FLOAT), Function(F.stddev_pop, DataType.DOUBLE), ] ).with_window( partition_by="id", order_by=TIMESTAMP_COLUMN, mode="fixed_windows", window_definition=["2 minutes", "15 minutes"], ), ), ], keys=[ KeyFeature( name="id", description="The user's Main ID or device ID", dtype=DataType.BIGINT, ) ], timestamp=TimestampFeature(), ) schema = feature_set.get_schema() assert schema == expected_schema
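# A small helper sketch (not part of butterfree) showing how the schema returned by
# get_schema() above could be turned into a Spark StructType, e.g. to pre-create a table:
from pyspark.sql.types import StructField, StructType


def schema_to_struct_type(schema):
    return StructType(
        [
            StructField(
                item["column_name"], item["type"], nullable=not item["primary_key"]
            )
            for item in schema
        ]
    )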
def test_feature_transform_with_distinct_empty_subset( self, timestamp_c, feature_set_with_distinct_dataframe): spark_client = SparkClient() with pytest.raises(ValueError, match="The distinct subset param can't be empty."): AggregatedFeatureSet( name="name", entity="entity", description="description", features=[ Feature( name="feature", description="test", transformation=AggregatedTransform(functions=[ Function(functions.sum, DataType.INTEGER) ]), ), ], keys=[ KeyFeature(name="h3", description="test", dtype=DataType.STRING) ], timestamp=timestamp_c, ).with_windows(["3 days"]).with_distinct( subset=[], keep="first").construct(feature_set_with_distinct_dataframe, spark_client, end_date="2020-01-10")
def agg_feature_set(): return AggregatedFeatureSet( name="name", entity="entity", description="description", features=[ Feature( name="feature1", description="test", transformation=AggregatedTransform( functions=[Function(functions.avg, DataType.DOUBLE)], ), ), Feature( name="feature2", description="test", transformation=AggregatedTransform( functions=[Function(functions.avg, DataType.DOUBLE)]), ), ], keys=[ KeyFeature( name="id", description="description", dtype=DataType.BIGINT, ) ], timestamp=TimestampFeature(), )
def __init__(self): super(UserChargebacksPipeline, self).__init__( source=Source( readers=[ FileReader( id="chargeback_events", path="data/order_events/input.csv", format="csv", format_options={"header": True}, ) ], query=(""" select cpf, timestamp(chargeback_timestamp) as timestamp, order_id from chargeback_events where chargeback_timestamp is not null """), ), feature_set=AggregatedFeatureSet( name="user_chargebacks", entity="user", description="Aggregates the total of chargebacks from users in " "different time windows.", keys=[ KeyFeature( name="cpf", description="User unique identifier, entity key.", dtype=DataType.STRING, ) ], timestamp=TimestampFeature(), features=[ Feature( name="cpf_chargebacks", description= "Total of chargebacks registered on user's CPF", transformation=AggregatedTransform(functions=[ Function(functions.count, DataType.INTEGER) ]), from_column="order_id", ), ], ).with_windows( definitions=["3 days", "7 days", "30 days"]).add_post_hook( ZeroFillHook()), sink=Sink(writers=[ LocalHistoricalFSWriter(), OnlineFeatureStoreWriter( interval_mode=True, check_schema_hook=NotCheckSchemaHook(), debug_mode=True, ), ]), )
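# ZeroFillHook, LocalHistoricalFSWriter and NotCheckSchemaHook are custom classes defined
# elsewhere in this example. A minimal sketch of a zero-fill post hook, assuming
# butterfree's Hook contract is a single run(dataframe) method, could look like this:
from butterfree.hooks import Hook


class ZeroFillHookSketch(Hook):
    """Replaces nulls produced by the window aggregations with zero."""

    def run(self, dataframe):
        return dataframe.fillna(0)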
def test_run_agg_with_end_date(self, spark_session):
    test_pipeline = FeatureSetPipeline(
        spark_client=SparkClient(),
        source=Mock(
            spec=Source,
            readers=[
                TableReader(
                    id="source_a",
                    database="db",
                    table="table",
                )
            ],
            query="select * from source_a",
        ),
        feature_set=Mock(
            spec=AggregatedFeatureSet,
            name="feature_set",
            entity="entity",
            description="description",
            keys=[
                KeyFeature(
                    name="user_id",
                    description="The user's Main ID or device ID",
                    dtype=DataType.INTEGER,
                )
            ],
            timestamp=TimestampFeature(from_column="ts"),
            features=[
                Feature(
                    name="listing_page_viewed__rent_per_month",
                    description="Average of something.",
                    transformation=AggregatedTransform(
                        functions=[
                            Function(functions.avg, DataType.FLOAT),
                            Function(functions.stddev_pop, DataType.FLOAT),
                        ],
                    ),
                ),
            ],
        ),
        sink=Mock(
            spec=Sink,
            writers=[HistoricalFeatureStoreWriter(db_config=None)],
        ),
    )

    # the mocked feature set needs to return a real dataframe for streaming validation
    sample_df = spark_session.createDataFrame([{"a": "x", "b": "y", "c": "3"}])
    test_pipeline.feature_set.construct.return_value = sample_df

    test_pipeline.run(end_date="2016-04-18")

    test_pipeline.source.construct.assert_called_once()
    test_pipeline.feature_set.construct.assert_called_once()
    test_pipeline.sink.flush.assert_called_once()
    test_pipeline.sink.validate.assert_called_once()
def test_construct_rolling_windows_with_end_date( self, feature_set_dataframe, rolling_windows_output_feature_set_dataframe_base_date, ): # given spark_client = SparkClient() # arrange feature_set = AggregatedFeatureSet( name="feature_set", entity="entity", description="description", features=[ Feature( name="feature1", description="test", transformation=AggregatedTransform( functions=[ Function(F.avg, DataType.DOUBLE), Function(F.stddev_pop, DataType.DOUBLE), ], ), ), Feature( name="feature2", description="test", transformation=AggregatedTransform( functions=[ Function(F.avg, DataType.DOUBLE), Function(F.stddev_pop, DataType.DOUBLE), ], ), ), ], keys=[ KeyFeature( name="id", description="The user's Main ID or device ID", dtype=DataType.INTEGER, ) ], timestamp=TimestampFeature(), ).with_windows(definitions=["1 day", "1 week"]) # act output_df = feature_set.construct( feature_set_dataframe, client=spark_client, end_date="2016-04-18" ).orderBy("timestamp") target_df = rolling_windows_output_feature_set_dataframe_base_date.orderBy( feature_set.timestamp_column ).select(feature_set.columns) # assert assert_dataframe_equality(output_df, target_df)
def test_feature_transform(self, spark_context, spark_session): # arrange target_data = [ { "id": 1, "feature": 100, "id_a": 1, "id_b": 2 }, { "id": 2, "feature": 100, "id_a": 1, "id_b": 2 }, { "id": 3, "feature": 120, "id_a": 3, "id_b": 4 }, { "id": 4, "feature": 120, "id_a": 3, "id_b": 4 }, ] input_df = create_df_from_collection(self.input_data, spark_context, spark_session) target_df = create_df_from_collection(target_data, spark_context, spark_session) feature_using_names = KeyFeature( name="id", description="id_a and id_b stacked in a single column.", dtype=DataType.INTEGER, transformation=StackTransform("id_*"), ) # act result_df_1 = feature_using_names.transform(input_df) # assert assert_dataframe_equality(target_df, result_df_1)
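# For reference, a rough PySpark equivalent of stacking the id_* columns into a single
# "id" column, which is what StackTransform("id_*") does above (illustrative sketch only,
# not butterfree's implementation):
from pyspark.sql import functions as F


def stack_id_columns(df):
    # each input row yields one output row per id_* value, keeping the original columns
    return df.withColumn("id", F.explode(F.array("id_a", "id_b")))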
def test_with_stack(self, h3_input_df, h3_with_stack_target_df): # arrange test_feature = KeyFeature( name="id", description="unit test", dtype=DataType.STRING, transformation=H3HashTransform( h3_resolutions=[6, 7, 8, 9, 10, 11, 12], lat_column="lat", lng_column="lng", ).with_stack(), ) # act output_df = test_feature.transform(h3_input_df) # assert assert_dataframe_equality(h3_with_stack_target_df, output_df)
def test_pipeline_with_hooks(self, spark_session): # arrange hook1 = AddHook(value=1) spark_session.sql( "select 1 as id, timestamp('2020-01-01') as timestamp, 0 as feature" ).createOrReplaceTempView("test") target_df = spark_session.sql( "select 1 as id, timestamp('2020-01-01') as timestamp, 6 as feature, 2020 " "as year, 1 as month, 1 as day") historical_writer = HistoricalFeatureStoreWriter(debug_mode=True) test_pipeline = FeatureSetPipeline( source=Source( readers=[ TableReader( id="reader", table="test", ).add_post_hook(hook1) ], query="select * from reader", ).add_post_hook(hook1), feature_set=FeatureSet( name="feature_set", entity="entity", description="description", features=[ Feature( name="feature", description="test", transformation=SQLExpressionTransform( expression="feature + 1"), dtype=DataType.INTEGER, ), ], keys=[ KeyFeature( name="id", description="The user's Main ID or device ID", dtype=DataType.INTEGER, ) ], timestamp=TimestampFeature(), ).add_pre_hook(hook1).add_post_hook(hook1), sink=Sink(writers=[historical_writer], ).add_pre_hook(hook1), ) # act test_pipeline.run() output_df = spark_session.table( "historical_feature_store__feature_set") # assert output_df.show() assert_dataframe_equality(output_df, target_df)
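# AddHook is a helper hook defined elsewhere in this suite. Given the expected output
# above (feature goes from 0 to 6 through one "feature + 1" SQL transform plus five hook
# applications of value=1), a minimal sketch consistent with that behaviour could be:
from pyspark.sql import functions as F

from butterfree.hooks import Hook


class AddHookSketch(Hook):
    def __init__(self, value):
        self.value = value

    def run(self, dataframe):
        # add a constant to the "feature" column every time the hook runs
        return dataframe.withColumn("feature", F.col("feature") + self.value)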
def test_args_without_transformation(self): test_key = KeyFeature( name="id", from_column="origin", description="unit test", dtype=DataType.INTEGER, ) assert test_key.name == "id" assert test_key.from_column == "origin" assert test_key.description == "unit test"
def test_source_raise(self): with pytest.raises(ValueError, match="source must be a Source instance"): FeatureSetPipeline( spark_client=SparkClient(), source=Mock( spark_client=SparkClient(), readers=[ TableReader( id="source_a", database="db", table="table", ), ], query="select * from source_a", ), feature_set=Mock( spec=FeatureSet, name="feature_set", entity="entity", description="description", keys=[ KeyFeature( name="user_id", description="The user's Main ID or device ID", dtype=DataType.INTEGER, ) ], timestamp=TimestampFeature(from_column="ts"), features=[ Feature( name="listing_page_viewed__rent_per_month", description="Average of something.", transformation=SparkFunctionTransform(functions=[ Function(functions.avg, DataType.FLOAT), Function(functions.stddev_pop, DataType.FLOAT), ], ).with_window( partition_by="user_id", order_by=TIMESTAMP_COLUMN, window_definition=["7 days", "2 weeks"], mode="fixed_windows", ), ), ], ), sink=Mock( spec=Sink, writers=[HistoricalFeatureStoreWriter(db_config=None)], ), )
def feature_set(): key_features = [ KeyFeature(name="id", description="Description", dtype=DataType.INTEGER) ] ts_feature = TimestampFeature(from_column=TIMESTAMP_COLUMN) features = [ Feature(name="feature", description="Description", dtype=DataType.BIGINT,) ] return FeatureSet( "feature_set", "entity", "description", keys=key_features, timestamp=ts_feature, features=features, )
def feature_set_pipeline(
    spark_context,
    spark_session,
):
    feature_set_pipeline = FeatureSetPipeline(
        source=Source(
            readers=[
                TableReader(
                    id="b_source",
                    table="b_table",
                ).with_incremental_strategy(
                    incremental_strategy=IncrementalStrategy(column="timestamp")
                ),
            ],
            query="select * from b_source",
        ),
        feature_set=FeatureSet(
            name="feature_set",
            entity="entity",
            description="description",
            features=[
                Feature(
                    name="feature",
                    description="test",
                    transformation=SparkFunctionTransform(
                        functions=[
                            Function(F.avg, DataType.FLOAT),
                            Function(F.stddev_pop, DataType.FLOAT),
                        ],
                    ).with_window(
                        partition_by="id",
                        order_by=TIMESTAMP_COLUMN,
                        mode="fixed_windows",
                        window_definition=["1 day"],
                    ),
                ),
            ],
            keys=[
                KeyFeature(
                    name="id",
                    description="The user's Main ID or device ID",
                    dtype=DataType.INTEGER,
                )
            ],
            timestamp=TimestampFeature(),
        ),
        sink=Sink(writers=[HistoricalFeatureStoreWriter(debug_mode=True)]),
    )
    return feature_set_pipeline
def test_construct_without_window( self, feature_set_dataframe, target_df_without_window, ): # given spark_client = SparkClient() # arrange feature_set = AggregatedFeatureSet( name="feature_set", entity="entity", description="description", features=[ Feature( name="feature1", description="test", dtype=DataType.DOUBLE, transformation=AggregatedTransform( functions=[Function(F.avg, DataType.DOUBLE)]), ), Feature( name="feature2", description="test", dtype=DataType.FLOAT, transformation=AggregatedTransform( functions=[Function(F.count, DataType.BIGINT)]), ), ], keys=[ KeyFeature( name="id", description="The user's Main ID or device ID", dtype=DataType.INTEGER, ) ], timestamp=TimestampFeature(from_column="fixed_ts"), ) # act output_df = feature_set.construct(feature_set_dataframe, client=spark_client) # assert assert_dataframe_equality(output_df, target_df_without_window)
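# Rough PySpark equivalent of the windowless aggregation above: group by key and
# timestamp and aggregate, producing feature1__avg and feature2__count (column names
# follow butterfree's convention; this sketch is only illustrative).
from pyspark.sql import functions as F


def aggregate_without_window(df):
    return df.groupBy("id", "timestamp").agg(
        F.avg("feature1").alias("feature1__avg"),
        F.count("feature2").alias("feature2__count"),
    )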
def test_construct_with_date_boundaries( self, feature_set_dates_dataframe, feature_set_dates_output_dataframe): # given spark_client = SparkClient() # arrange feature_set = FeatureSet( name="feature_set", entity="entity", description="description", features=[ Feature( name="feature", description="test", dtype=DataType.FLOAT, ), ], keys=[ KeyFeature( name="id", description="The user's Main ID or device ID", dtype=DataType.INTEGER, ) ], timestamp=TimestampFeature(), ) output_df = (feature_set.construct( feature_set_dates_dataframe, client=spark_client, start_date="2016-04-11", end_date="2016-04-12", ).orderBy(feature_set.timestamp_column).select(feature_set.columns)) target_df = feature_set_dates_output_dataframe.orderBy( feature_set.timestamp_column).select(feature_set.columns) # assert assert_dataframe_equality(output_df, target_df)
def test_construct_with_pivot( self, feature_set_df_pivot, target_df_pivot_agg, ): # given spark_client = SparkClient() # arrange feature_set = AggregatedFeatureSet( name="feature_set", entity="entity", description="description", features=[ Feature( name="feature", description="unit test", transformation=AggregatedTransform(functions=[ Function(F.avg, DataType.FLOAT), Function(F.stddev_pop, DataType.DOUBLE), ], ), from_column="feature1", ) ], keys=[ KeyFeature( name="id", description="The user's Main ID or device ID", dtype=DataType.INTEGER, ) ], timestamp=TimestampFeature(from_column="fixed_ts"), ).with_pivot("pivot_col", ["S", "N"]) # act output_df = feature_set.construct(feature_set_df_pivot, client=spark_client) # assert assert_dataframe_equality(output_df, target_df_pivot_agg)
def feature_set(): key_features = [ KeyFeature(name="id", description="Description", dtype=DataType.INTEGER) ] ts_feature = TimestampFeature(from_column="timestamp") features = [ Feature(name="feature", description="Description", dtype=DataType.FLOAT), ] return FeatureSet( "test_sink_feature_set", "test_sink_entity", "description", keys=key_features, timestamp=ts_feature, features=features, )
def feature_set_incremental(): key_features = [ KeyFeature(name="id", description="Description", dtype=DataType.INTEGER) ] ts_feature = TimestampFeature(from_column=TIMESTAMP_COLUMN) features = [ Feature( name="feature", description="test", transformation=AggregatedTransform( functions=[Function(functions.sum, DataType.INTEGER)] ), ), ] return AggregatedFeatureSet( "feature_set", "entity", "description", keys=key_features, timestamp=ts_feature, features=features, )
def test_construct_rolling_windows_without_end_date( self, feature_set_dataframe, rolling_windows_output_feature_set_dataframe ): # given spark_client = SparkClient() # arrange feature_set = AggregatedFeatureSet( name="feature_set", entity="entity", description="description", features=[ Feature( name="feature1", description="test", transformation=AggregatedTransform( functions=[ Function(F.avg, DataType.DOUBLE), Function(F.stddev_pop, DataType.DOUBLE), ], ), ), ], keys=[ KeyFeature( name="id", description="The user's Main ID or device ID", dtype=DataType.INTEGER, ) ], timestamp=TimestampFeature(), ).with_windows(definitions=["1 day", "1 week"],) # act & assert with pytest.raises(ValueError): _ = feature_set.construct(feature_set_dataframe, client=spark_client)
def test_feature_transform_with_data_type_array(self, spark_context, spark_session):
    # arrange
    input_data = [
        {"id": 1, "timestamp": "2020-04-22T00:00:00+00:00", "feature": 10},
        {"id": 1, "timestamp": "2020-04-22T00:00:00+00:00", "feature": 20},
        {"id": 1, "timestamp": "2020-04-22T00:00:00+00:00", "feature": 30},
        {"id": 2, "timestamp": "2020-04-22T00:00:00+00:00", "feature": 10},
    ]
    target_data = [
        {
            "id": 1,
            "timestamp": "2020-04-22T00:00:00+00:00",
            "feature__collect_set": [30.0, 20.0, 10.0],
        },
        {
            "id": 2,
            "timestamp": "2020-04-22T00:00:00+00:00",
            "feature__collect_set": [10.0],
        },
    ]
    input_df = create_df_from_collection(
        input_data, spark_context, spark_session
    ).withColumn("timestamp", functions.to_timestamp(functions.col("timestamp")))
    target_df = create_df_from_collection(
        target_data, spark_context, spark_session
    ).withColumn("timestamp", functions.to_timestamp(functions.col("timestamp")))

    fs = AggregatedFeatureSet(
        name="name",
        entity="entity",
        description="description",
        keys=[KeyFeature(name="id", description="test", dtype=DataType.INTEGER)],
        timestamp=TimestampFeature(),
        features=[
            Feature(
                name="feature",
                description="aggregations with array data type",
                dtype=DataType.BIGINT,
                transformation=AggregatedTransform(
                    functions=[
                        Function(functions.collect_set, DataType.ARRAY_FLOAT),
                    ],
                ),
                from_column="feature",
            ),
        ],
    )

    # act
    output_df = fs.construct(input_df, SparkClient())

    # assert
    assert_dataframe_equality(target_df, output_df)
def test_feature_transform_with_filter_expression(self, spark_context, spark_session): # arrange input_data = [ { "id": 1, "timestamp": "2020-04-22T00:00:00+00:00", "feature": 10, "type": "a", }, { "id": 1, "timestamp": "2020-04-22T00:00:00+00:00", "feature": 20, "type": "a", }, { "id": 1, "timestamp": "2020-04-22T00:00:00+00:00", "feature": 30, "type": "b", }, { "id": 2, "timestamp": "2020-04-22T00:00:00+00:00", "feature": 10, "type": "a", }, ] target_data = [ { "id": 1, "timestamp": "2020-04-22T00:00:00+00:00", "feature_only_type_a__avg": 15.0, "feature_only_type_a__min": 10, "feature_only_type_a__max": 20, }, { "id": 2, "timestamp": "2020-04-22T00:00:00+00:00", "feature_only_type_a__avg": 10.0, "feature_only_type_a__min": 10, "feature_only_type_a__max": 10, }, ] input_df = create_df_from_collection( input_data, spark_context, spark_session).withColumn( "timestamp", functions.to_timestamp(functions.col("timestamp"))) target_df = create_df_from_collection( target_data, spark_context, spark_session).withColumn( "timestamp", functions.to_timestamp(functions.col("timestamp"))) fs = AggregatedFeatureSet( name="name", entity="entity", description="description", keys=[ KeyFeature(name="id", description="test", dtype=DataType.INTEGER) ], timestamp=TimestampFeature(), features=[ Feature( name="feature_only_type_a", description="aggregations only when type = a", dtype=DataType.BIGINT, transformation=AggregatedTransform( functions=[ Function(functions.avg, DataType.FLOAT), Function(functions.min, DataType.FLOAT), Function(functions.max, DataType.FLOAT), ], filter_expression="type = 'a'", ), from_column="feature", ), ], ) # act output_df = fs.construct(input_df, SparkClient()) # assert assert_dataframe_equality(target_df, output_df)
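# One way to express the filtered aggregation above in plain PySpark: apply the filter
# expression inside the aggregation so only "type = 'a'" rows contribute (illustrative
# sketch; butterfree builds these columns and their names internally).
from pyspark.sql import functions as F


def aggregate_only_type_a(df):
    only_a = F.when(F.expr("type = 'a'"), F.col("feature"))
    return df.groupBy("id", "timestamp").agg(
        F.avg(only_a).alias("feature_only_type_a__avg"),
        F.min(only_a).alias("feature_only_type_a__min"),
        F.max(only_a).alias("feature_only_type_a__max"),
    )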
def test_feature_set_args(self): # arrange and act out_columns = [ "user_id", "timestamp", "listing_page_viewed__rent_per_month__avg_over_7_days_fixed_windows", "listing_page_viewed__rent_per_month__avg_over_2_weeks_fixed_windows", "listing_page_viewed__rent_per_month__stddev_pop_over_7_days_fixed_windows", "listing_page_viewed__rent_per_month__" "stddev_pop_over_2_weeks_fixed_windows", # noqa ] pipeline = FeatureSetPipeline( source=Source( readers=[ TableReader( id="source_a", database="db", table="table", ), FileReader( id="source_b", path="path", format="parquet", ), ], query="select a.*, b.specific_feature " "from source_a left join source_b on a.id=b.id", ), feature_set=FeatureSet( name="feature_set", entity="entity", description="description", keys=[ KeyFeature( name="user_id", description="The user's Main ID or device ID", dtype=DataType.INTEGER, ) ], timestamp=TimestampFeature(from_column="ts"), features=[ Feature( name="listing_page_viewed__rent_per_month", description="Average of something.", transformation=SparkFunctionTransform(functions=[ Function(functions.avg, DataType.FLOAT), Function(functions.stddev_pop, DataType.FLOAT), ], ).with_window( partition_by="user_id", order_by=TIMESTAMP_COLUMN, window_definition=["7 days", "2 weeks"], mode="fixed_windows", ), ), ], ), sink=Sink(writers=[ HistoricalFeatureStoreWriter(db_config=None), OnlineFeatureStoreWriter(db_config=None), ], ), ) assert isinstance(pipeline.spark_client, SparkClient) assert len(pipeline.source.readers) == 2 assert all( isinstance(reader, Reader) for reader in pipeline.source.readers) assert isinstance(pipeline.source.query, str) assert pipeline.feature_set.name == "feature_set" assert pipeline.feature_set.entity == "entity" assert pipeline.feature_set.description == "description" assert isinstance(pipeline.feature_set.timestamp, TimestampFeature) assert len(pipeline.feature_set.keys) == 1 assert all( isinstance(k, KeyFeature) for k in pipeline.feature_set.keys) assert len(pipeline.feature_set.features) == 1 assert all( isinstance(feature, Feature) for feature in pipeline.feature_set.features) assert pipeline.feature_set.columns == out_columns assert len(pipeline.sink.writers) == 2 assert all( isinstance(writer, Writer) for writer in pipeline.sink.writers)