import pathlib
from datetime import datetime, timedelta
from os import path

import pytest
from pyspark.sql import DataFrame, SparkSession
from pyspark.sql.types import (
    FloatType,
    IntegerType,
    StructField,
    StructType,
    TimestampType,
)

# The classes and functions under test are assumed to live in the historical
# feature retrieval job module; adjust the import path to your project layout.
from feast.pyspark.historical_feature_retrieval_job import (
    FeatureTable,
    Field,
    FileSource,
    SchemaError,
    as_of_join,
    join_entity_to_feature_tables,
    retrieve_historical_features,
)


def test_historical_feature_retrieval_with_mapping(spark: SparkSession):
    """The entity source's `id` column is renamed to `customer_id` via
    `field_mapping`, and the feature source uses non-default timestamp
    column names; both should be resolved before the join."""
    test_data_dir = path.join(pathlib.Path(__file__).parent.absolute(), "data")
    entity_source = FileSource(
        format="csv",
        path=f"file://{path.join(test_data_dir, 'column_mapping_test_entity.csv')}",
        event_timestamp_column="event_timestamp",
        field_mapping={"id": "customer_id"},
        options={"inferSchema": "true", "header": "true"},
    )
    booking_source = FileSource(
        format="csv",
        path=f"file://{path.join(test_data_dir, 'column_mapping_test_feature.csv')}",
        event_timestamp_column="datetime",
        created_timestamp_column="created_datetime",
        options={"inferSchema": "true", "header": "true"},
    )
    booking_table = FeatureTable(
        name="bookings",
        entities=[Field("customer_id", "int32")],
        features=[Field("total_bookings", "int32")],
    )

    joined_df = retrieve_historical_features(
        spark,
        entity_source,
        [booking_source],
        [booking_table],
    )

    expected_joined_schema = StructType([
        StructField("customer_id", IntegerType()),
        StructField("event_timestamp", TimestampType()),
        StructField("bookings__total_bookings", IntegerType()),
    ])
    expected_joined_data = [
        (1001, datetime(year=2020, month=9, day=2), 200),
        (1001, datetime(year=2020, month=9, day=3), 200),
        (2001, datetime(year=2020, month=9, day=4), 600),
        (2001, datetime(year=2020, month=9, day=4), 600),
        (3001, datetime(year=2020, month=9, day=4), 700),
    ]
    expected_joined_df = spark.createDataFrame(
        spark.sparkContext.parallelize(expected_joined_data),
        expected_joined_schema)

    assert_dataframe_equal(joined_df, expected_joined_df)
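
# The `spark` fixture is defined outside this section (typically in
# conftest.py). A minimal sketch, assuming a local session suffices for
# these tests; the app name and configuration are illustrative only.
@pytest.fixture(scope="session")
def spark():
    session = (
        SparkSession.builder.appName("historical-feature-retrieval-tests")
        .master("local[*]")
        .getOrCreate()
    )
    yield session
    session.stop()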

def test_large_historical_feature_retrieval(spark: SparkSession,
                                            large_entity_csv_file: str,
                                            large_feature_csv_file: str):
    nr_rows = 1000
    start_datetime = datetime(year=2020, month=8, day=31)
    expected_join_data = [(1000 + i, start_datetime + timedelta(days=i), i * 10)
                          for i in range(nr_rows)]
    expected_join_data_schema = StructType([
        StructField("customer_id", IntegerType()),
        StructField("event_timestamp", TimestampType()),
        StructField("feature__total_bookings", IntegerType()),
    ])
    expected_join_data_df = spark.createDataFrame(
        spark.sparkContext.parallelize(expected_join_data),
        expected_join_data_schema)

    entity_source = FileSource(
        format="csv",
        path=f"file://{large_entity_csv_file}",
        event_timestamp_column="event_timestamp",
        field_mapping={"id": "customer_id"},
        options={"inferSchema": "true", "header": "true"},
    )
    feature_source = FileSource(
        format="csv",
        path=f"file://{large_feature_csv_file}",
        event_timestamp_column="event_timestamp",
        created_timestamp_column="created_timestamp",
        options={"inferSchema": "true", "header": "true"},
    )
    feature_table = FeatureTable(
        name="feature",
        entities=[Field("customer_id", "int32")],
        features=[Field("total_bookings", "int32")],
    )

    joined_df = retrieve_historical_features(spark, entity_source,
                                             [feature_source], [feature_table])
    assert_dataframe_equal(joined_df, expected_join_data_df)
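
# `large_entity_csv_file` and `large_feature_csv_file` are fixtures defined
# outside this section. A sketch of what they would need to produce for the
# expectations above to hold: 1000 entity rows (column `id`, renamed to
# `customer_id` via field_mapping) and one feature row per entity whose
# `total_bookings` is `i * 10`. The column names and tmp-file mechanics are
# assumptions; `tmp_path_factory` is pytest's built-in fixture.
@pytest.fixture
def large_entity_csv_file(tmp_path_factory):
    csv_path = tmp_path_factory.mktemp("data") / "large_entity.csv"
    start = datetime(year=2020, month=8, day=31)
    with open(csv_path, "w") as f:
        f.write("id,event_timestamp\n")
        for i in range(1000):
            f.write(f"{1000 + i},{start + timedelta(days=i)}\n")
    return str(csv_path)


@pytest.fixture
def large_feature_csv_file(tmp_path_factory):
    csv_path = tmp_path_factory.mktemp("data") / "large_feature.csv"
    start = datetime(year=2020, month=8, day=31)
    with open(csv_path, "w") as f:
        f.write("customer_id,event_timestamp,created_timestamp,total_bookings\n")
        for i in range(1000):
            f.write(f"{1000 + i},{start + timedelta(days=i)},"
                    f"{start + timedelta(days=i)},{i * 10}\n")
    return str(csv_path)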

def test_multiple_join(
    spark: SparkSession,
    composite_entity_schema: StructType,
    customer_feature_schema: StructType,
    driver_feature_schema: StructType,
):
    """Entity rows are joined to several feature tables, one as-of join
    per table, accumulating feature columns along the way."""
    entity_data = [
        (1001, 8001, datetime(year=2020, month=9, day=2)),
        (1001, 8002, datetime(year=2020, month=9, day=2)),
        (2001, 8002, datetime(year=2020, month=9, day=3)),
    ]
    entity_df = spark.createDataFrame(
        spark.sparkContext.parallelize(entity_data), composite_entity_schema)

    customer_table_data = [
        (1001, datetime(year=2020, month=9, day=1),
         datetime(year=2020, month=9, day=1), 100.0),
        (2001, datetime(year=2020, month=9, day=1),
         datetime(year=2020, month=9, day=1), 200.0),
    ]
    customer_table_df = spark.createDataFrame(
        spark.sparkContext.parallelize(customer_table_data),
        customer_feature_schema)
    customer_table = FeatureTable(
        name="transactions",
        features=[Field("daily_transactions", "double")],
        entities=[Field("customer_id", "int32")],
        max_age=86400,
    )

    driver_table_data = [
        (8001, datetime(year=2020, month=8, day=31),
         datetime(year=2020, month=8, day=31), 200),
        (8001, datetime(year=2020, month=9, day=1),
         datetime(year=2020, month=9, day=1), 300),
        (8002, datetime(year=2020, month=9, day=1),
         datetime(year=2020, month=9, day=1), 600),
        (8002, datetime(year=2020, month=9, day=1),
         datetime(year=2020, month=9, day=2), 500),
    ]
    driver_table_df = spark.createDataFrame(
        spark.sparkContext.parallelize(driver_table_data),
        driver_feature_schema)
    driver_table = FeatureTable(
        name="bookings",
        features=[Field("completed_bookings", "int32")],
        entities=[Field("driver_id", "int32")],
    )

    joined_df = join_entity_to_feature_tables(
        entity_df,
        "event_timestamp",
        [customer_table_df, driver_table_df],
        [customer_table, driver_table],
        ["event_timestamp"] * 2,
        ["created_timestamp"] * 2,
    )

    expected_joined_schema = StructType([
        StructField("customer_id", IntegerType()),
        StructField("driver_id", IntegerType()),
        StructField("event_timestamp", TimestampType()),
        StructField("transactions__daily_transactions", FloatType()),
        StructField("bookings__completed_bookings", IntegerType()),
    ])
    expected_joined_data = [
        (1001, 8001, datetime(year=2020, month=9, day=2), 100.0, 300),
        (1001, 8002, datetime(year=2020, month=9, day=2), 100.0, 500),
        (2001, 8002, datetime(year=2020, month=9, day=3), None, 500),
    ]
    expected_joined_df = spark.createDataFrame(
        spark.sparkContext.parallelize(expected_joined_data),
        expected_joined_schema)

    assert_dataframe_equal(joined_df, expected_joined_df)
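
# The schema fixtures used by the join tests in this section are defined
# outside it (typically in conftest.py). A sketch inferred from the test
# data tuples above; the column order within each schema is an assumption.
@pytest.fixture
def composite_entity_schema():
    return StructType([
        StructField("customer_id", IntegerType()),
        StructField("driver_id", IntegerType()),
        StructField("event_timestamp", TimestampType()),
    ])


@pytest.fixture
def single_entity_schema():
    return StructType([
        StructField("customer_id", IntegerType()),
        StructField("event_timestamp", TimestampType()),
    ])


@pytest.fixture
def customer_feature_schema():
    return StructType([
        StructField("customer_id", IntegerType()),
        StructField("event_timestamp", TimestampType()),
        StructField("created_timestamp", TimestampType()),
        StructField("daily_transactions", FloatType()),
    ])


@pytest.fixture
def driver_feature_schema():
    return StructType([
        StructField("driver_id", IntegerType()),
        StructField("event_timestamp", TimestampType()),
        StructField("created_timestamp", TimestampType()),
        StructField("completed_bookings", IntegerType()),
    ])


@pytest.fixture
def rating_feature_schema():
    return StructType([
        StructField("customer_id", IntegerType()),
        StructField("driver_id", IntegerType()),
        StructField("event_timestamp", TimestampType()),
        StructField("created_timestamp", TimestampType()),
        StructField("customer_rating", FloatType()),
        StructField("driver_rating", FloatType()),
    ])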

def test_select_subset_of_columns_as_entity_primary_keys(
    spark: SparkSession,
    composite_entity_schema: StructType,
    customer_feature_schema: StructType,
):
    """Only the feature table's entity columns are used as join keys; extra
    entity columns such as driver_id pass through to the output untouched."""
    entity_data = [
        (1001, 8001, datetime(year=2020, month=9, day=2)),
        (2001, 8002, datetime(year=2020, month=9, day=2)),
    ]
    entity_df = spark.createDataFrame(
        spark.sparkContext.parallelize(entity_data), composite_entity_schema)

    feature_table_data = [
        (1001, datetime(year=2020, month=9, day=1),
         datetime(year=2020, month=9, day=2), 100.0),
        (2001, datetime(year=2020, month=9, day=1),
         datetime(year=2020, month=9, day=1), 400.0),
    ]
    feature_table_df = spark.createDataFrame(
        spark.sparkContext.parallelize(feature_table_data),
        customer_feature_schema)
    feature_table = FeatureTable(
        name="transactions",
        features=[Field("daily_transactions", "double")],
        entities=[Field("customer_id", "int32")],
    )

    joined_df = as_of_join(
        entity_df,
        "event_timestamp",
        feature_table_df,
        feature_table,
        "event_timestamp",
        "created_timestamp",
    )

    expected_joined_schema = StructType([
        StructField("customer_id", IntegerType()),
        StructField("driver_id", IntegerType()),
        StructField("event_timestamp", TimestampType()),
        StructField("transactions__daily_transactions", FloatType()),
    ])
    expected_joined_data = [
        (1001, 8001, datetime(year=2020, month=9, day=2), 100.0),
        (2001, 8002, datetime(year=2020, month=9, day=2), 400.0),
    ]
    expected_joined_df = spark.createDataFrame(
        spark.sparkContext.parallelize(expected_joined_data),
        expected_joined_schema)

    assert_dataframe_equal(joined_df, expected_joined_df)

def test_join_with_composite_entity(
    spark: SparkSession,
    composite_entity_schema: StructType,
    rating_feature_schema: StructType,
):
    """As-of join on a composite key (customer_id, driver_id); entity rows
    with no feature row within max_age get null feature values."""
    entity_data = [
        (1001, 8001, datetime(year=2020, month=9, day=1)),
        (1001, 8002, datetime(year=2020, month=9, day=3)),
        (1001, 8003, datetime(year=2020, month=9, day=1)),
        (2001, 8001, datetime(year=2020, month=9, day=2)),
    ]
    entity_df = spark.createDataFrame(
        spark.sparkContext.parallelize(entity_data), composite_entity_schema)

    feature_table_data = [
        (1001, 8001, datetime(year=2020, month=9, day=1),
         datetime(year=2020, month=9, day=1), 3.0, 5.0),
        (1001, 8002, datetime(year=2020, month=9, day=1),
         datetime(year=2020, month=9, day=1), 4.0, 3.0),
        (2001, 8001, datetime(year=2020, month=9, day=1),
         datetime(year=2020, month=9, day=1), 4.0, 4.5),
    ]
    feature_table_df = spark.createDataFrame(
        spark.sparkContext.parallelize(feature_table_data),
        rating_feature_schema)
    feature_table = FeatureTable(
        name="ratings",
        features=[
            Field("customer_rating", "double"),
            Field("driver_rating", "double"),
        ],
        entities=[Field("customer_id", "int32"), Field("driver_id", "int32")],
        max_age=86400,
    )

    joined_df = as_of_join(
        entity_df,
        "event_timestamp",
        feature_table_df,
        feature_table,
        "event_timestamp",
        "created_timestamp",
    )

    expected_joined_schema = StructType([
        StructField("customer_id", IntegerType()),
        StructField("driver_id", IntegerType()),
        StructField("event_timestamp", TimestampType()),
        StructField("ratings__customer_rating", FloatType()),
        StructField("ratings__driver_rating", FloatType()),
    ])
    expected_joined_data = [
        (1001, 8001, datetime(year=2020, month=9, day=1), 3.0, 5.0),
        (1001, 8002, datetime(year=2020, month=9, day=3), None, None),
        (1001, 8003, datetime(year=2020, month=9, day=1), None, None),
        (2001, 8001, datetime(year=2020, month=9, day=2), 4.0, 4.5),
    ]
    expected_joined_df = spark.createDataFrame(
        spark.sparkContext.parallelize(expected_joined_data),
        expected_joined_schema)

    assert_dataframe_equal(joined_df, expected_joined_df)

def test_join_with_max_age(
    spark: SparkSession,
    single_entity_schema: StructType,
    customer_feature_schema: StructType,
):
    """Feature rows older than max_age (one day here) relative to the entity
    timestamp are ignored, so those entity rows get null feature values."""
    entity_data = [
        (1001, datetime(year=2020, month=9, day=1)),
        (1001, datetime(year=2020, month=9, day=3)),
        (2001, datetime(year=2020, month=9, day=2)),
    ]
    entity_df = spark.createDataFrame(
        spark.sparkContext.parallelize(entity_data), single_entity_schema)

    feature_table_data = [
        (1001, datetime(year=2020, month=9, day=1),
         datetime(year=2020, month=9, day=1), 100.0),
        (2001, datetime(year=2020, month=9, day=1),
         datetime(year=2020, month=9, day=1), 200.0),
    ]
    feature_table_df = spark.createDataFrame(
        spark.sparkContext.parallelize(feature_table_data),
        customer_feature_schema)
    feature_table = FeatureTable(
        name="transactions",
        features=[Field("daily_transactions", "double")],
        entities=[Field("customer_id", "int32")],
        max_age=86400,
    )

    joined_df = as_of_join(
        entity_df,
        "event_timestamp",
        feature_table_df,
        feature_table,
        "event_timestamp",
        "created_timestamp",
    )

    expected_joined_schema = StructType([
        StructField("customer_id", IntegerType()),
        StructField("event_timestamp", TimestampType()),
        StructField("transactions__daily_transactions", FloatType()),
    ])
    expected_joined_data = [
        (1001, datetime(year=2020, month=9, day=1), 100.0),
        (1001, datetime(year=2020, month=9, day=3), None),
        (2001, datetime(year=2020, month=9, day=2), 200.0),
    ]
    expected_joined_df = spark.createDataFrame(
        spark.sparkContext.parallelize(expected_joined_data),
        expected_joined_schema)

    assert_dataframe_equal(joined_df, expected_joined_df)

def test_historical_feature_retrieval_with_schema_errors(spark: SparkSession):
    """Misconfigured sources or feature tables should raise SchemaError."""
    test_data_dir = path.join(pathlib.Path(__file__).parent.absolute(), "data")
    entity_source = FileSource(
        format="csv",
        path=f"file://{path.join(test_data_dir, 'customer_driver_pairs.csv')}",
        event_timestamp_column="event_timestamp",
        options={"inferSchema": "true", "header": "true"},
    )
    entity_source_missing_timestamp = FileSource(
        format="csv",
        path=f"file://{path.join(test_data_dir, 'customer_driver_pairs.csv')}",
        event_timestamp_column="datetime",
        options={"inferSchema": "true", "header": "true"},
    )
    entity_source_missing_entity = FileSource(
        format="csv",
        path=f"file://{path.join(test_data_dir, 'customers.csv')}",
        event_timestamp_column="event_timestamp",
        options={"inferSchema": "true", "header": "true"},
    )
    booking_source = FileSource(
        format="csv",
        path=f"file://{path.join(test_data_dir, 'bookings.csv')}",
        event_timestamp_column="event_timestamp",
        created_timestamp_column="created_timestamp",
        options={"inferSchema": "true", "header": "true"},
    )
    booking_source_missing_timestamp = FileSource(
        format="csv",
        path=f"file://{path.join(test_data_dir, 'bookings.csv')}",
        event_timestamp_column="datetime",
        created_timestamp_column="created_datetime",
        options={"inferSchema": "true", "header": "true"},
    )

    booking_table = FeatureTable(
        name="bookings",
        entities=[Field("driver_id", "int32")],
        features=[Field("completed_bookings", "int32")],
    )
    booking_table_missing_features = FeatureTable(
        name="bookings",
        entities=[Field("driver_id", "int32")],
        features=[Field("nonexist_feature", "int32")],
    )
    booking_table_wrong_column_type = FeatureTable(
        name="bookings",
        entities=[Field("driver_id", "string")],
        features=[Field("completed_bookings", "int32")],
    )

    # Entity source with a non-existent event timestamp column.
    with pytest.raises(SchemaError):
        retrieve_historical_features(
            spark,
            entity_source_missing_timestamp,
            [booking_source],
            [booking_table],
        )

    # Feature source with non-existent timestamp columns.
    with pytest.raises(SchemaError):
        retrieve_historical_features(
            spark,
            entity_source,
            [booking_source_missing_timestamp],
            [booking_table],
        )

    # Feature table referencing a feature absent from the source.
    with pytest.raises(SchemaError):
        retrieve_historical_features(
            spark,
            entity_source,
            [booking_source],
            [booking_table_missing_features],
        )

    # Entity column type does not match the feature table definition.
    with pytest.raises(SchemaError):
        retrieve_historical_features(
            spark,
            entity_source,
            [booking_source],
            [booking_table_wrong_column_type],
        )

    # Entity source missing the driver_id entity column.
    with pytest.raises(SchemaError):
        retrieve_historical_features(
            spark,
            entity_source_missing_entity,
            [booking_source],
            [booking_table],
        )

def test_historical_feature_retrieval(spark: SparkSession):
    """End-to-end retrieval: one entity source joined against two feature
    tables, with max_age applied to the transactions table only."""
    test_data_dir = path.join(pathlib.Path(__file__).parent.absolute(), "data")
    entity_source = FileSource(
        format="csv",
        path=f"file://{path.join(test_data_dir, 'customer_driver_pairs.csv')}",
        event_timestamp_column="event_timestamp",
        options={"inferSchema": "true", "header": "true"},
    )
    booking_source = FileSource(
        format="csv",
        path=f"file://{path.join(test_data_dir, 'bookings.csv')}",
        event_timestamp_column="event_timestamp",
        created_timestamp_column="created_timestamp",
        options={"inferSchema": "true", "header": "true"},
    )
    transaction_source = FileSource(
        format="csv",
        path=f"file://{path.join(test_data_dir, 'transactions.csv')}",
        event_timestamp_column="event_timestamp",
        created_timestamp_column="created_timestamp",
        options={"inferSchema": "true", "header": "true"},
    )
    booking_table = FeatureTable(
        name="bookings",
        entities=[Field("driver_id", "int32")],
        features=[Field("completed_bookings", "int32")],
    )
    transaction_table = FeatureTable(
        name="transactions",
        entities=[Field("customer_id", "int32")],
        features=[Field("daily_transactions", "double")],
        max_age=86400,
    )

    joined_df = retrieve_historical_features(
        spark,
        entity_source,
        [transaction_source, booking_source],
        [transaction_table, booking_table],
    )

    expected_joined_schema = StructType([
        StructField("customer_id", IntegerType()),
        StructField("driver_id", IntegerType()),
        StructField("event_timestamp", TimestampType()),
        StructField("transactions__daily_transactions", FloatType()),
        StructField("bookings__completed_bookings", IntegerType()),
    ])
    expected_joined_data = [
        (1001, 8001, datetime(year=2020, month=9, day=2), 100.0, 300),
        (1001, 8002, datetime(year=2020, month=9, day=2), 100.0, 500),
        (1001, 8002, datetime(year=2020, month=9, day=3), None, 500),
        (2001, 8002, datetime(year=2020, month=9, day=3), None, 500),
        (2001, 8002, datetime(year=2020, month=9, day=4), None, 500),
    ]
    expected_joined_df = spark.createDataFrame(
        spark.sparkContext.parallelize(expected_joined_data),
        expected_joined_schema)

    assert_dataframe_equal(joined_df, expected_joined_df)
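
# `assert_dataframe_equal` is a local test helper, not part of the module
# under test. A minimal sketch of an order-insensitive comparison, assuming
# exact row-multiset equality is the intended semantics. Defining it at the
# bottom of the module is fine: the name is resolved when the tests run.
def assert_dataframe_equal(left: DataFrame, right: DataFrame):
    # Column sets must match; join output column order is not guaranteed.
    assert set(left.columns) == set(right.columns)
    # Align column order, then compare row multisets in both directions so
    # duplicate rows are counted correctly.
    right_aligned = right.select(left.columns)
    assert left.exceptAll(right_aligned).count() == 0
    assert right_aligned.exceptAll(left).count() == 0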