def test_h3_feature_set(self, h3_input_df, h3_target_df):
        """Construct a 1-day windowed feature set keyed by an H3 hash.

        The feature set is built from ``h3_input_df`` with an end date of
        2016-04-14, and the result must equal ``h3_target_df``.
        """
        client = SparkClient()

        # Key derived from the lat/lng columns at resolutions 6 through 12,
        # with ``with_stack()`` applied to the transform.
        h3_transform = H3HashTransform(
            h3_resolutions=[6, 7, 8, 9, 10, 11, 12],
            lat_column="lat",
            lng_column="lng",
        ).with_stack()
        h3_key = KeyFeature(
            name="h3_id",
            description="The h3 hash ID",
            dtype=DataType.DOUBLE,
            transformation=h3_transform,
        )

        # Single aggregated feature: F.count typed as BIGINT.
        house_count = Feature(
            name="house_id",
            description="Count of house ids over a day.",
            transformation=AggregatedTransform(
                functions=[Function(F.count, DataType.BIGINT)]
            ),
        )

        feature_set = AggregatedFeatureSet(
            name="h3_test",
            entity="h3geolocation",
            description="Test",
            keys=[h3_key],
            timestamp=TimestampFeature(),
            features=[house_count],
        ).with_windows(definitions=["1 day"])

        output_df = feature_set.construct(
            h3_input_df, client=client, end_date="2016-04-14"
        )

        assert_dataframe_equality(output_df, h3_target_df)
    def test_construct_rolling_windows_with_end_date(
        self,
        feature_set_dataframe,
        rolling_windows_output_feature_set_dataframe_base_date,
    ):
        """Construct 1-day and 1-week rolling windows up to a given end date.

        Both ``feature1`` and ``feature2`` are aggregated with ``F.avg`` and
        ``F.stddev_pop``. The output, ordered by timestamp, must equal the
        expected fixture once it is ordered and projected to the feature
        set's columns.
        """
        # arrange
        client = SparkClient()

        # Every feature gets its own transform with the same two aggregations.
        aggregated_features = [
            Feature(
                name=column,
                description="test",
                transformation=AggregatedTransform(
                    functions=[
                        Function(F.avg, DataType.DOUBLE),
                        Function(F.stddev_pop, DataType.DOUBLE),
                    ],
                ),
            )
            for column in ("feature1", "feature2")
        ]
        id_key = KeyFeature(
            name="id",
            description="The user's Main ID or device ID",
            dtype=DataType.INTEGER,
        )

        feature_set = AggregatedFeatureSet(
            name="feature_set",
            entity="entity",
            description="description",
            features=aggregated_features,
            keys=[id_key],
            timestamp=TimestampFeature(),
        ).with_windows(definitions=["1 day", "1 week"])

        # act
        constructed = feature_set.construct(
            feature_set_dataframe, client=client, end_date="2016-04-18"
        )
        output_df = constructed.orderBy("timestamp")

        expected = rolling_windows_output_feature_set_dataframe_base_date
        target_df = expected.orderBy(feature_set.timestamp_column).select(
            feature_set.columns
        )

        # assert
        assert_dataframe_equality(output_df, target_df)
    # Example no. 3 (0) -- extraction residue from a code-listing page, not part of the tests.
    def test_agg_feature_set_with_window(self, key_id, timestamp_c, dataframe,
                                         rolling_windows_agg_dataframe):
        """Exercise ``construct`` on a 1-week windowed aggregated feature set.

        Checks three things in order: constructing without an end date raises
        ``ValueError``; an end date of 2016-04-17 yields fewer rows than the
        expected fixture; an end date of 2016-05-01 reproduces the fixture.
        """
        client = SparkClient()

        # Two features sharing the same aggregation (avg typed as FLOAT).
        averaged_features = [
            Feature(
                name=column,
                description="unit test",
                transformation=AggregatedTransform(
                    functions=[Function(functions.avg, DataType.FLOAT)]
                ),
            )
            for column in ("feature1", "feature2")
        ]

        fs = AggregatedFeatureSet(
            name="name",
            entity="entity",
            description="description",
            features=averaged_features,
            keys=[key_id],
            timestamp=timestamp_c,
        ).with_windows(definitions=["1 week"])

        # Windowed construction without an end date must raise.
        with pytest.raises(ValueError):
            _ = fs.construct(dataframe, client)

        # An end date smaller than the mocked max filters rows out.
        truncated_df = fs.construct(dataframe, client, end_date="2016-04-17")
        assert truncated_df.count() < rolling_windows_agg_dataframe.count()

        # A later end date reproduces the full expected output.
        full_df = fs.construct(dataframe, client, end_date="2016-05-01")
        assert_dataframe_equality(full_df, rolling_windows_agg_dataframe)
    def test_construct_without_window(
        self,
        feature_set_dataframe,
        target_df_without_window,
    ):
        """Construct an aggregated feature set with no window definitions.

        ``feature1`` is aggregated with ``F.avg`` and ``feature2`` with
        ``F.count``; the timestamp is read from the ``fixed_ts`` column. The
        output must equal ``target_df_without_window``.
        """
        # arrange
        client = SparkClient()

        avg_feature = Feature(
            name="feature1",
            description="test",
            dtype=DataType.DOUBLE,
            transformation=AggregatedTransform(
                functions=[Function(F.avg, DataType.DOUBLE)]
            ),
        )
        count_feature = Feature(
            name="feature2",
            description="test",
            dtype=DataType.FLOAT,
            transformation=AggregatedTransform(
                functions=[Function(F.count, DataType.BIGINT)]
            ),
        )
        id_key = KeyFeature(
            name="id",
            description="The user's Main ID or device ID",
            dtype=DataType.INTEGER,
        )

        feature_set = AggregatedFeatureSet(
            name="feature_set",
            entity="entity",
            description="description",
            features=[avg_feature, count_feature],
            keys=[id_key],
            timestamp=TimestampFeature(from_column="fixed_ts"),
        )

        # act
        output_df = feature_set.construct(feature_set_dataframe, client=client)

        # assert
        assert_dataframe_equality(output_df, target_df_without_window)
    def test_construct_with_pivot(
        self,
        feature_set_df_pivot,
        target_df_pivot_agg,
    ):
        """Construct an aggregated feature set configured with a pivot.

        A single feature read from ``feature1`` is aggregated with ``F.avg``
        and ``F.stddev_pop``, and ``with_pivot("pivot_col", ["S", "N"])`` is
        applied to the feature set. The output must equal
        ``target_df_pivot_agg``.
        """
        # arrange
        client = SparkClient()

        pivoted_feature = Feature(
            name="feature",
            description="unit test",
            transformation=AggregatedTransform(
                functions=[
                    Function(F.avg, DataType.FLOAT),
                    Function(F.stddev_pop, DataType.DOUBLE),
                ],
            ),
            from_column="feature1",
        )
        id_key = KeyFeature(
            name="id",
            description="The user's Main ID or device ID",
            dtype=DataType.INTEGER,
        )

        feature_set = AggregatedFeatureSet(
            name="feature_set",
            entity="entity",
            description="description",
            features=[pivoted_feature],
            keys=[id_key],
            timestamp=TimestampFeature(from_column="fixed_ts"),
        ).with_pivot("pivot_col", ["S", "N"])

        # act
        output_df = feature_set.construct(feature_set_df_pivot, client=client)

        # assert
        assert_dataframe_equality(output_df, target_df_pivot_agg)
    def test_construct_rolling_windows_without_end_date(
        self, feature_set_dataframe, rolling_windows_output_feature_set_dataframe
    ):
        """Constructing a windowed feature set without an end date must fail.

        The feature set uses 1-day and 1-week window definitions; calling
        ``construct`` with no ``end_date`` is expected to raise ``ValueError``.
        """
        # arrange
        client = SparkClient()

        windowed_feature = Feature(
            name="feature1",
            description="test",
            transformation=AggregatedTransform(
                functions=[
                    Function(F.avg, DataType.DOUBLE),
                    Function(F.stddev_pop, DataType.DOUBLE),
                ],
            ),
        )
        id_key = KeyFeature(
            name="id",
            description="The user's Main ID or device ID",
            dtype=DataType.INTEGER,
        )

        feature_set = AggregatedFeatureSet(
            name="feature_set",
            entity="entity",
            description="description",
            features=[windowed_feature],
            keys=[id_key],
            timestamp=TimestampFeature(),
        ).with_windows(definitions=["1 day", "1 week"])

        # act & assert
        with pytest.raises(ValueError):
            _ = feature_set.construct(feature_set_dataframe, client=client)
    def test_feature_transform_with_data_type_array(self, spark_context,
                                                    spark_session):
        """Aggregate a column with ``collect_set`` typed as ``ARRAY_FLOAT``.

        Four input rows (three for id 1, one for id 2) share one timestamp.
        The constructed output must equal a frame with one
        ``feature__collect_set`` array per id.
        """
        # arrange
        event_time = "2020-04-22T00:00:00+00:00"

        def build_df(rows):
            # Create the frame, then parse the string timestamp column.
            frame = create_df_from_collection(rows, spark_context,
                                              spark_session)
            return frame.withColumn(
                "timestamp",
                functions.to_timestamp(functions.col("timestamp")))

        input_data = [
            {"id": row_id, "timestamp": event_time, "feature": value}
            for row_id, value in ((1, 10), (1, 20), (1, 30), (2, 10))
        ]
        target_data = [
            {
                "id": row_id,
                "timestamp": event_time,
                "feature__collect_set": collected,
            }
            for row_id, collected in ((1, [30.0, 20.0, 10.0]), (2, [10.0]))
        ]
        input_df = build_df(input_data)
        target_df = build_df(target_data)

        id_key = KeyFeature(name="id",
                            description="test",
                            dtype=DataType.INTEGER)
        collected_feature = Feature(
            name="feature",
            description="aggregations with ",
            dtype=DataType.BIGINT,
            transformation=AggregatedTransform(
                functions=[
                    Function(functions.collect_set, DataType.ARRAY_FLOAT),
                ],
            ),
            from_column="feature",
        )
        fs = AggregatedFeatureSet(
            name="name",
            entity="entity",
            description="description",
            keys=[id_key],
            timestamp=TimestampFeature(),
            features=[collected_feature],
        )

        # act
        output_df = fs.construct(input_df, SparkClient())

        # assert
        assert_dataframe_equality(target_df, output_df)
    def test_feature_transform_with_filter_expression(self, spark_context,
                                                      spark_session):
        """Aggregate a column using the filter expression ``type = 'a'``.

        The input holds three rows for id 1 (types a, a, b) and one row for
        id 2 (type a). The expected avg/min/max values per id are computed
        from the type-a rows only.
        """
        # arrange
        event_time = "2020-04-22T00:00:00+00:00"

        def build_df(rows):
            # Create the frame, then parse the string timestamp column.
            frame = create_df_from_collection(rows, spark_context,
                                              spark_session)
            return frame.withColumn(
                "timestamp",
                functions.to_timestamp(functions.col("timestamp")))

        input_data = [
            {
                "id": row_id,
                "timestamp": event_time,
                "feature": value,
                "type": kind,
            }
            for row_id, value, kind in (
                (1, 10, "a"),
                (1, 20, "a"),
                (1, 30, "b"),
                (2, 10, "a"),
            )
        ]
        target_data = [
            {
                "id": row_id,
                "timestamp": event_time,
                "feature_only_type_a__avg": average,
                "feature_only_type_a__min": lowest,
                "feature_only_type_a__max": highest,
            }
            for row_id, average, lowest, highest in (
                (1, 15.0, 10, 20),
                (2, 10.0, 10, 10),
            )
        ]
        input_df = build_df(input_data)
        target_df = build_df(target_data)

        id_key = KeyFeature(name="id",
                            description="test",
                            dtype=DataType.INTEGER)
        filtered_feature = Feature(
            name="feature_only_type_a",
            description="aggregations only when type = a",
            dtype=DataType.BIGINT,
            transformation=AggregatedTransform(
                functions=[
                    Function(functions.avg, DataType.FLOAT),
                    Function(functions.min, DataType.FLOAT),
                    Function(functions.max, DataType.FLOAT),
                ],
                filter_expression="type = 'a'",
            ),
            from_column="feature",
        )
        fs = AggregatedFeatureSet(
            name="name",
            entity="entity",
            description="description",
            keys=[id_key],
            timestamp=TimestampFeature(),
            features=[filtered_feature],
        )

        # act
        output_df = fs.construct(input_df, SparkClient())

        # assert
        assert_dataframe_equality(target_df, output_df)