コード例 #1
0
ファイル: test_feature_set.py プロジェクト: zuston/butterfree
    def test_construct(
        self, feature_set_dataframe, fixed_windows_output_feature_set_dataframe
    ):
        """Check ``FeatureSet.construct`` against the fixed-windows fixture.

        The feature set under test holds two features: ``feature1``, built from
        the avg and stddev_pop functions with a ``fixed_windows`` window over
        "2 minutes" and "15 minutes", and ``divided_feature``, built with the
        ``divide`` custom transformer. The actual and expected frames are both
        ordered by the timestamp column and reduced to the feature set columns
        before they are compared.
        """
        # arrange
        client = SparkClient()

        windowed_transformation = SparkFunctionTransform(
            functions=[
                Function(F.avg, DataType.FLOAT),
                Function(F.stddev_pop, DataType.FLOAT),
            ]
        ).with_window(
            partition_by="id",
            order_by=TIMESTAMP_COLUMN,
            mode="fixed_windows",
            window_definition=["2 minutes", "15 minutes"],
        )
        windowed_feature = Feature(
            name="feature1",
            description="test",
            transformation=windowed_transformation,
        )

        divide_transformation = CustomTransform(
            transformer=divide, column1="feature1", column2="feature2",
        )
        divided_feature = Feature(
            name="divided_feature",
            description="unit test",
            dtype=DataType.FLOAT,
            transformation=divide_transformation,
        )

        id_key = KeyFeature(
            name="id",
            description="The user's Main ID or device ID",
            dtype=DataType.INTEGER,
        )

        feature_set = FeatureSet(
            name="feature_set",
            entity="entity",
            description="description",
            features=[windowed_feature, divided_feature],
            keys=[id_key],
            timestamp=TimestampFeature(),
        )

        # act
        constructed = feature_set.construct(feature_set_dataframe, client=client)
        actual = constructed.orderBy(feature_set.timestamp_column).select(
            feature_set.columns
        )

        # Bring the expected fixture into the same row order and column set.
        expected_sorted = fixed_windows_output_feature_set_dataframe.orderBy(
            feature_set.timestamp_column
        )
        expected = expected_sorted.select(feature_set.columns)

        # assert
        assert_dataframe_equality(actual, expected)
コード例 #2
0
    def test_filtering(
        self,
        filtering_dataframe,
        key_id,
        timestamp_c,
        feature1,
        feature2,
        feature3,
        output_filtering_dataframe,
    ):
        """Compare collected ``construct`` rows with the expected fixture rows.

        A feature set with three features is constructed from
        ``filtering_dataframe``. The result is ordered by "timestamp" and
        collected, then compared with ``output_filtering_dataframe`` ordered
        the same way and restricted to the feature set columns.
        """
        # arrange
        client = Mock()
        features = [feature1, feature2, feature3]
        feature_set = FeatureSet(
            "name", "entity", "description", [key_id], timestamp_c, features,
        )

        # act
        constructed = feature_set.construct(filtering_dataframe, client)
        actual_rows = constructed.orderBy("timestamp").collect()

        # assert
        expected = output_filtering_dataframe.orderBy("timestamp").select(
            feature_set.columns
        )
        expected_rows = expected.collect()
        assert actual_rows == expected_rows
コード例 #3
0
    def test_construct_transformations(
        self,
        dataframe,
        feature_set_dataframe,
        key_id,
        timestamp_c,
        feature_add,
        feature_divide,
    ):
        """Check that constructing with the add/divide features matches the fixture.

        The feature set is built from ``feature_add`` and ``feature_divide``,
        constructed from ``dataframe`` with a mocked client, and the result is
        compared with ``feature_set_dataframe``.
        """
        # arrange
        client = Mock()
        features = [feature_add, feature_divide]
        feature_set = FeatureSet(
            "name", "entity", "description", [key_id], timestamp_c, features,
        )

        # act
        constructed = feature_set.construct(dataframe, client)

        # assert
        assert_dataframe_equality(constructed, feature_set_dataframe)
コード例 #4
0
    def test_construct(
        self,
        dataframe,
        feature_set_dataframe,
        key_id,
        timestamp_c,
        feature_add,
        feature_divide,
    ):
        """Check the columns, content and cached flag of the ``construct`` result.

        The result columns must equal the output columns of the key, the
        timestamp and the two features, concatenated in that order. The frame
        must also equal ``feature_set_dataframe`` and report ``is_cached``.
        """
        # arrange
        client = Mock()
        features = [feature_add, feature_divide]
        feature_set = FeatureSet(
            "name", "entity", "description", [key_id], timestamp_c, features,
        )

        # act
        constructed = feature_set.construct(dataframe, client)
        actual_columns = constructed.columns

        # assert
        expected_columns = (
            key_id.get_output_columns()
            + timestamp_c.get_output_columns()
            + feature_add.get_output_columns()
            + feature_divide.get_output_columns()
        )
        assert actual_columns == expected_columns
        assert_dataframe_equality(constructed, feature_set_dataframe)
        assert constructed.is_cached
コード例 #5
0
    def test_construct_invalid_df(self, key_id, timestamp_c, feature_add,
                                  feature_divide):
        """Expect ``construct`` to raise ``ValueError`` for a non-dataframe input."""
        # arrange
        client = Mock()
        features = [feature_add, feature_divide]
        feature_set = FeatureSet(
            "name", "entity", "description", [key_id], timestamp_c, features,
        )

        # act and assert
        with pytest.raises(ValueError):
            # A plain string stands in for the dataframe argument.
            feature_set.construct("not a dataframe", client)
コード例 #6
0
    def test_construct_with_date_boundaries(
            self, feature_set_dates_dataframe,
            feature_set_dates_output_dataframe):
        """Check ``construct`` output when ``start_date`` and ``end_date`` are given.

        The feature set has a single float feature, an integer ``id`` key and a
        timestamp. ``construct`` is called with ``start_date="2016-04-11"`` and
        ``end_date="2016-04-12"``. The actual and expected frames are both
        ordered by the timestamp column and reduced to the feature set columns
        before they are compared.
        """
        # arrange
        client = SparkClient()

        plain_feature = Feature(
            name="feature",
            description="test",
            dtype=DataType.FLOAT,
        )
        id_key = KeyFeature(
            name="id",
            description="The user's Main ID or device ID",
            dtype=DataType.INTEGER,
        )
        feature_set = FeatureSet(
            name="feature_set",
            entity="entity",
            description="description",
            features=[plain_feature],
            keys=[id_key],
            timestamp=TimestampFeature(),
        )

        # act
        constructed = feature_set.construct(
            feature_set_dates_dataframe,
            client=client,
            start_date="2016-04-11",
            end_date="2016-04-12",
        )
        actual = constructed.orderBy(feature_set.timestamp_column).select(
            feature_set.columns
        )

        # Bring the expected fixture into the same row order and column set.
        expected_sorted = feature_set_dates_output_dataframe.orderBy(
            feature_set.timestamp_column
        )
        expected = expected_sorted.select(feature_set.columns)

        # assert
        assert_dataframe_equality(actual, expected)