Ejemplo n.º 1
0
def test_join_with_composite_entity(
    spark: SparkSession,
    composite_entity_schema: StructType,
    rating_feature_schema: StructType,
):
    """Check an as-of join keyed on the composite (customer_id, driver_id) entity.

    The "ratings" feature table is declared with ``max_age=86400`` (one day).
    The expected output encodes four cases:

    * an entity row whose timestamp equals the feature row's picks up its values;
    * an entity row two days after its only feature row gets nulls;
    * an entity key with no feature row at all gets nulls;
    * an entity row exactly one day after its feature row still gets values.
    """

    def to_df(rows, schema):
        # Route the local rows through an RDD before building the DataFrame.
        return spark.createDataFrame(spark.sparkContext.parallelize(rows), schema)

    sep_1 = datetime(2020, 9, 1)
    sep_2 = datetime(2020, 9, 2)
    sep_3 = datetime(2020, 9, 3)

    # Presumably (customer_id, driver_id, event_timestamp) — confirm against
    # the composite_entity_schema fixture.
    entities = to_df(
        [
            (1001, 8001, sep_1),
            (1001, 8002, sep_3),
            (1001, 8003, sep_1),
            (2001, 8001, sep_2),
        ],
        composite_entity_schema,
    )

    # Presumably (customer_id, driver_id, event_timestamp, created_timestamp,
    # customer_rating, driver_rating) — confirm against rating_feature_schema.
    features = to_df(
        [
            (1001, 8001, sep_1, sep_1, 3.0, 5.0),
            (1001, 8002, sep_1, sep_1, 4.0, 3.0),
            (2001, 8001, sep_1, sep_1, 4.0, 4.5),
        ],
        rating_feature_schema,
    )

    # NOTE(review): features are declared "double" while the expected output
    # below uses FloatType — confirm this matches the fixture schema and how
    # assert_dataframe_equal compares types.
    ratings = FeatureTable(
        name="ratings",
        features=[
            Field("customer_rating", "double"),
            Field("driver_rating", "double"),
        ],
        entities=[
            Field("customer_id", "int32"),
            Field("driver_id", "int32"),
        ],
        max_age=86400,  # one day, presuming the unit is seconds
    )

    joined_df = as_of_join(entities, "event_timestamp", features, ratings)

    expected_schema = StructType(
        [
            StructField("customer_id", IntegerType()),
            StructField("driver_id", IntegerType()),
            StructField("event_timestamp", TimestampType()),
            StructField("ratings__customer_rating", FloatType()),
            StructField("ratings__driver_rating", FloatType()),
        ]
    )
    expected_df = to_df(
        [
            (1001, 8001, sep_1, 3.0, 5.0),
            (1001, 8002, sep_3, None, None),
            (1001, 8003, sep_1, None, None),
            (2001, 8001, sep_2, 4.0, 4.5),
        ],
        expected_schema,
    )

    assert_dataframe_equal(joined_df, expected_df)
Ejemplo n.º 2
0
def test_select_subset_of_columns_as_entity_primary_keys(
    spark: SparkSession,
    composite_entity_schema: StructType,
    customer_feature_schema: StructType,
):
    """Join a feature table keyed on only part of the entity DataFrame's columns.

    The entity rows carry both customer_id and driver_id, but the
    "transactions" feature table lists only customer_id as its entity. The
    expected output matches on customer_id alone and keeps driver_id as a
    passthrough column. No ``max_age`` is set on this feature table.
    """

    def to_df(rows, schema):
        # Route the local rows through an RDD before building the DataFrame.
        return spark.createDataFrame(spark.sparkContext.parallelize(rows), schema)

    sep_1 = datetime(2020, 9, 1)
    sep_2 = datetime(2020, 9, 2)

    # Presumably (customer_id, driver_id, event_timestamp) — confirm against
    # the composite_entity_schema fixture.
    entities = to_df(
        [
            (1001, 8001, sep_2),
            (2001, 8002, sep_2),
        ],
        composite_entity_schema,
    )

    # Presumably (customer_id, event_timestamp, created_timestamp,
    # daily_transactions) — confirm against customer_feature_schema.
    features = to_df(
        [
            (1001, sep_1, sep_2, 100.0),
            (2001, sep_1, sep_1, 400.0),
        ],
        customer_feature_schema,
    )

    transactions = FeatureTable(
        name="transactions",
        features=[Field("daily_transactions", "double")],
        entities=[Field("customer_id", "int32")],
    )

    joined_df = as_of_join(entities, "event_timestamp", features, transactions)

    expected_schema = StructType(
        [
            StructField("customer_id", IntegerType()),
            StructField("driver_id", IntegerType()),
            StructField("event_timestamp", TimestampType()),
            StructField("transactions__daily_transactions", FloatType()),
        ]
    )
    expected_df = to_df(
        [
            (1001, 8001, sep_2, 100.0),
            (2001, 8002, sep_2, 400.0),
        ],
        expected_schema,
    )

    assert_dataframe_equal(joined_df, expected_df)
Ejemplo n.º 3
0
def test_join_with_max_age(
    spark: SparkSession,
    single_entity_schema: StructType,
    customer_feature_schema: StructType,
):
    """Check that ``max_age`` limits how stale a joined feature value may be.

    Every feature row is stamped 2020-09-01 and the table's ``max_age`` is
    86400 (one day). The expected output keeps values for entity rows on
    09-01 and 09-02, and expects a null for the 09-03 row, which is two days
    after the only matching feature row.
    """

    def to_df(rows, schema):
        # Route the local rows through an RDD before building the DataFrame.
        return spark.createDataFrame(spark.sparkContext.parallelize(rows), schema)

    sep_1 = datetime(2020, 9, 1)
    sep_2 = datetime(2020, 9, 2)
    sep_3 = datetime(2020, 9, 3)

    # Presumably (customer_id, event_timestamp) — confirm against the
    # single_entity_schema fixture.
    entities = to_df(
        [
            (1001, sep_1),
            (1001, sep_3),
            (2001, sep_2),
        ],
        single_entity_schema,
    )

    # Presumably (customer_id, event_timestamp, created_timestamp,
    # daily_transactions) — confirm against customer_feature_schema.
    features = to_df(
        [
            (1001, sep_1, sep_1, 100.0),
            (2001, sep_1, sep_1, 200.0),
        ],
        customer_feature_schema,
    )

    transactions = FeatureTable(
        name="transactions",
        features=[Field("daily_transactions", "double")],
        entities=[Field("customer_id", "int32")],
        max_age=86400,  # one day, presuming the unit is seconds
    )

    joined_df = as_of_join(entities, "event_timestamp", features, transactions)

    expected_schema = StructType(
        [
            StructField("customer_id", IntegerType()),
            StructField("event_timestamp", TimestampType()),
            StructField("transactions__daily_transactions", FloatType()),
        ]
    )
    expected_df = to_df(
        [
            (1001, sep_1, 100.0),
            (1001, sep_3, None),
            (2001, sep_2, 200.0),
        ],
        expected_schema,
    )

    assert_dataframe_equal(joined_df, expected_df)