def test_h3_feature_set(self, h3_input_df, h3_target_df):
    """Construct a one-day windowed count keyed by an H3 hash.

    The key ``h3_id`` is produced by an ``H3HashTransform`` over the
    ``lat``/``lng`` columns for resolutions 6 through 12 (with
    ``with_stack()`` applied). The result of ``construct`` with
    ``end_date="2016-04-14"`` must equal ``h3_target_df``.
    """
    client = SparkClient()

    # key definition: H3 hash transform over the coordinate columns
    h3_hash = H3HashTransform(
        h3_resolutions=[6, 7, 8, 9, 10, 11, 12],
        lat_column="lat",
        lng_column="lng",
    ).with_stack()
    h3_key = KeyFeature(
        name="h3_id",
        description="The h3 hash ID",
        dtype=DataType.DOUBLE,
        transformation=h3_hash,
    )
    event_time = TimestampFeature()

    # single aggregated feature: count of house ids
    house_count = Feature(
        name="house_id",
        description="Count of house ids over a day.",
        transformation=AggregatedTransform(
            functions=[Function(F.count, DataType.BIGINT)]
        ),
    )

    h3_feature_set = AggregatedFeatureSet(
        name="h3_test",
        entity="h3geolocation",
        description="Test",
        keys=[h3_key],
        timestamp=event_time,
        features=[house_count],
    ).with_windows(definitions=["1 day"])

    result_df = h3_feature_set.construct(
        h3_input_df, client=client, end_date="2016-04-14"
    )

    assert_dataframe_equality(result_df, h3_target_df)
def test_construct_rolling_windows_with_end_date(
    self,
    feature_set_dataframe,
    rolling_windows_output_feature_set_dataframe_base_date,
):
    """Construct day/week rolling windows up to an explicit end date.

    Two features (``feature1`` and ``feature2``) each get an average and a
    population standard deviation. The output built with
    ``end_date="2016-04-18"`` is ordered by timestamp and compared against
    the fixture, which is ordered by the timestamp column and projected to
    the feature set columns.
    """
    # given
    client = SparkClient()

    # arrange: both features share the same aggregation definition
    aggregated_features = [
        Feature(
            name=feature_name,
            description="test",
            transformation=AggregatedTransform(
                functions=[
                    Function(F.avg, DataType.DOUBLE),
                    Function(F.stddev_pop, DataType.DOUBLE),
                ],
            ),
        )
        for feature_name in ("feature1", "feature2")
    ]
    id_key = KeyFeature(
        name="id",
        description="The user's Main ID or device ID",
        dtype=DataType.INTEGER,
    )
    windowed_set = AggregatedFeatureSet(
        name="feature_set",
        entity="entity",
        description="description",
        features=aggregated_features,
        keys=[id_key],
        timestamp=TimestampFeature(),
    ).with_windows(definitions=["1 day", "1 week"])

    # act
    constructed = windowed_set.construct(
        feature_set_dataframe, client=client, end_date="2016-04-18"
    )
    output_df = constructed.orderBy("timestamp")

    ordered_target = rolling_windows_output_feature_set_dataframe_base_date.orderBy(
        windowed_set.timestamp_column
    )
    target_df = ordered_target.select(windowed_set.columns)

    # assert
    assert_dataframe_equality(output_df, target_df)
def test_agg_feature_set_with_window(
    self, key_id, timestamp_c, dataframe, rolling_windows_agg_dataframe
):
    """Exercise a one-week windowed feature set with and without an end date.

    Checks three things in sequence: ``construct`` raises ``ValueError``
    when no end date is given; an end date of 2016-04-17 yields fewer rows
    than the expected fixture; an end date of 2016-05-01 yields a dataframe
    equal to the fixture.
    """
    client = SparkClient()

    # the two features only differ by name
    averaged_features = [
        Feature(
            name=feature_name,
            description="unit test",
            transformation=AggregatedTransform(
                functions=[Function(functions.avg, DataType.FLOAT)]
            ),
        )
        for feature_name in ("feature1", "feature2")
    ]
    weekly_fs = AggregatedFeatureSet(
        name="name",
        entity="entity",
        description="description",
        features=averaged_features,
        keys=[key_id],
        timestamp=timestamp_c,
    ).with_windows(definitions=["1 week"])

    # no end date: construction is expected to fail
    with pytest.raises(ValueError):
        weekly_fs.construct(dataframe, client)

    # end date earlier than the fixture's range: fewer rows than expected
    truncated_df = weekly_fs.construct(dataframe, client, end_date="2016-04-17")
    assert truncated_df.count() < rolling_windows_agg_dataframe.count()

    # later end date: output must match the fixture
    complete_df = weekly_fs.construct(dataframe, client, end_date="2016-05-01")
    assert_dataframe_equality(complete_df, rolling_windows_agg_dataframe)
def test_construct_without_window(
    self,
    feature_set_dataframe,
    target_df_without_window,
):
    """Construct an aggregated feature set that has no window definition.

    ``feature1`` is averaged and ``feature2`` is counted; the timestamp is
    taken from the ``fixed_ts`` column. The constructed dataframe must equal
    ``target_df_without_window``.
    """
    # given
    client = SparkClient()

    # arrange
    avg_feature = Feature(
        name="feature1",
        description="test",
        dtype=DataType.DOUBLE,
        transformation=AggregatedTransform(
            functions=[Function(F.avg, DataType.DOUBLE)]
        ),
    )
    count_feature = Feature(
        name="feature2",
        description="test",
        dtype=DataType.FLOAT,
        transformation=AggregatedTransform(
            functions=[Function(F.count, DataType.BIGINT)]
        ),
    )
    id_key = KeyFeature(
        name="id",
        description="The user's Main ID or device ID",
        dtype=DataType.INTEGER,
    )
    plain_set = AggregatedFeatureSet(
        name="feature_set",
        entity="entity",
        description="description",
        features=[avg_feature, count_feature],
        keys=[id_key],
        timestamp=TimestampFeature(from_column="fixed_ts"),
    )

    # act
    result_df = plain_set.construct(feature_set_dataframe, client=client)

    # assert
    assert_dataframe_equality(result_df, target_df_without_window)
def test_construct_with_pivot(
    self,
    feature_set_df_pivot,
    target_df_pivot_agg,
):
    """Construct an aggregated feature set pivoted on ``pivot_col``.

    A single feature read from ``feature1`` is aggregated with an average
    and a population standard deviation, and the set is pivoted over the
    values ``"S"`` and ``"N"``. The result must equal
    ``target_df_pivot_agg``.
    """
    # given
    client = SparkClient()

    # arrange
    aggregations = [
        Function(F.avg, DataType.FLOAT),
        Function(F.stddev_pop, DataType.DOUBLE),
    ]
    pivoted_feature = Feature(
        name="feature",
        description="unit test",
        transformation=AggregatedTransform(functions=aggregations),
        from_column="feature1",
    )
    id_key = KeyFeature(
        name="id",
        description="The user's Main ID or device ID",
        dtype=DataType.INTEGER,
    )
    pivot_set = AggregatedFeatureSet(
        name="feature_set",
        entity="entity",
        description="description",
        features=[pivoted_feature],
        keys=[id_key],
        timestamp=TimestampFeature(from_column="fixed_ts"),
    ).with_pivot("pivot_col", ["S", "N"])

    # act
    result_df = pivot_set.construct(feature_set_df_pivot, client=client)

    # assert
    assert_dataframe_equality(result_df, target_df_pivot_agg)
def test_construct_rolling_windows_without_end_date(
    self, feature_set_dataframe, rolling_windows_output_feature_set_dataframe
):
    """Expect ``ValueError`` when a windowed set is built without an end date.

    The feature set defines day and week windows over ``feature1``;
    ``construct`` is called with only the dataframe and the client.
    """
    # given
    client = SparkClient()

    # arrange
    windowed_feature = Feature(
        name="feature1",
        description="test",
        transformation=AggregatedTransform(
            functions=[
                Function(F.avg, DataType.DOUBLE),
                Function(F.stddev_pop, DataType.DOUBLE),
            ],
        ),
    )
    id_key = KeyFeature(
        name="id",
        description="The user's Main ID or device ID",
        dtype=DataType.INTEGER,
    )
    windowed_set = AggregatedFeatureSet(
        name="feature_set",
        entity="entity",
        description="description",
        features=[windowed_feature],
        keys=[id_key],
        timestamp=TimestampFeature(),
    ).with_windows(definitions=["1 day", "1 week"])

    # act & assert
    with pytest.raises(ValueError):
        windowed_set.construct(feature_set_dataframe, client=client)
def test_feature_transform_with_data_type_array(self, spark_context, spark_session):
    """Aggregate a column with ``collect_set`` into an array-typed feature.

    Four input rows (three for id 1, one for id 2) share one timestamp.
    The constructed output must equal a target with one
    ``feature__collect_set`` array per id.
    """

    def _to_timestamped_df(rows):
        # build a dataframe and cast the string "timestamp" column
        return create_df_from_collection(
            rows, spark_context, spark_session
        ).withColumn(
            "timestamp", functions.to_timestamp(functions.col("timestamp"))
        )

    # arrange
    event_ts = "2020-04-22T00:00:00+00:00"
    input_data = [
        {"id": row_id, "timestamp": event_ts, "feature": value}
        for row_id, value in [(1, 10), (1, 20), (1, 30), (2, 10)]
    ]
    target_data = [
        {
            "id": 1,
            "timestamp": event_ts,
            "feature__collect_set": [30.0, 20.0, 10.0],
        },
        {
            "id": 2,
            "timestamp": event_ts,
            "feature__collect_set": [10.0],
        },
    ]

    input_df = _to_timestamped_df(input_data)
    target_df = _to_timestamped_df(target_data)

    id_key = KeyFeature(name="id", description="test", dtype=DataType.INTEGER)
    event_time = TimestampFeature()
    collected_feature = Feature(
        name="feature",
        description="aggregations with ",
        dtype=DataType.BIGINT,
        transformation=AggregatedTransform(
            functions=[Function(functions.collect_set, DataType.ARRAY_FLOAT)],
        ),
        from_column="feature",
    )
    array_fs = AggregatedFeatureSet(
        name="name",
        entity="entity",
        description="description",
        keys=[id_key],
        timestamp=event_time,
        features=[collected_feature],
    )

    # act
    output_df = array_fs.construct(input_df, SparkClient())

    # assert
    assert_dataframe_equality(target_df, output_df)
def test_feature_transform_with_filter_expression(self, spark_context, spark_session):
    """Aggregate avg/min/max using the filter expression ``type = 'a'``.

    The input holds three rows of type ``a`` and one of type ``b``. The
    expected target contains, per id, the average, minimum and maximum of
    the type-``a`` values only.
    """

    def _to_timestamped_df(rows):
        # build a dataframe and cast the string "timestamp" column
        return create_df_from_collection(
            rows, spark_context, spark_session
        ).withColumn(
            "timestamp", functions.to_timestamp(functions.col("timestamp"))
        )

    # arrange
    event_ts = "2020-04-22T00:00:00+00:00"
    input_data = [
        {"id": row_id, "timestamp": event_ts, "feature": value, "type": kind}
        for row_id, value, kind in [
            (1, 10, "a"),
            (1, 20, "a"),
            (1, 30, "b"),
            (2, 10, "a"),
        ]
    ]
    target_data = [
        {
            "id": row_id,
            "timestamp": event_ts,
            "feature_only_type_a__avg": mean,
            "feature_only_type_a__min": lowest,
            "feature_only_type_a__max": highest,
        }
        for row_id, mean, lowest, highest in [(1, 15.0, 10, 20), (2, 10.0, 10, 10)]
    ]

    input_df = _to_timestamped_df(input_data)
    target_df = _to_timestamped_df(target_data)

    id_key = KeyFeature(name="id", description="test", dtype=DataType.INTEGER)
    event_time = TimestampFeature()
    filtered_transform = AggregatedTransform(
        functions=[
            Function(aggregation, DataType.FLOAT)
            for aggregation in (functions.avg, functions.min, functions.max)
        ],
        filter_expression="type = 'a'",
    )
    filtered_feature = Feature(
        name="feature_only_type_a",
        description="aggregations only when type = a",
        dtype=DataType.BIGINT,
        transformation=filtered_transform,
        from_column="feature",
    )
    filtered_fs = AggregatedFeatureSet(
        name="name",
        entity="entity",
        description="description",
        keys=[id_key],
        timestamp=event_time,
        features=[filtered_feature],
    )

    # act
    output_df = filtered_fs.construct(input_df, SparkClient())

    # assert
    assert_dataframe_equality(target_df, output_df)