Example #1
    def test_h3_feature_set(self, h3_input_df, h3_target_df):
        spark_client = SparkClient()

        feature_set = AggregatedFeatureSet(
            name="h3_test",
            entity="h3geolocation",
            description="Test",
            keys=[
                KeyFeature(
                    name="h3_id",
                    description="The h3 hash ID",
                    dtype=DataType.DOUBLE,
                    transformation=H3HashTransform(
                        h3_resolutions=[6, 7, 8, 9, 10, 11, 12],
                        lat_column="lat",
                        lng_column="lng",
                    ).with_stack(),
                )
            ],
            timestamp=TimestampFeature(),
            features=[
                Feature(
                    name="house_id",
                    description="Count of house ids over a day.",
                    transformation=AggregatedTransform(
                        functions=[Function(F.count, DataType.BIGINT)]),
                ),
            ],
        ).with_windows(definitions=["1 day"])

        output_df = feature_set.construct(h3_input_df,
                                          client=spark_client,
                                          end_date="2016-04-14")

        assert_dataframe_equality(output_df, h3_target_df)
Example #2
    def test_construct_transformations(
        self,
        dataframe,
        feature_set_dataframe,
        key_id,
        timestamp_c,
        feature_add,
        feature_divide,
    ):
        spark_client = Mock()

        # arrange
        feature_set = FeatureSet(
            "name",
            "entity",
            "description",
            [key_id],
            timestamp_c,
            [feature_add, feature_divide],
        )

        # act
        result_df = feature_set.construct(dataframe, spark_client)

        # assert
        assert_dataframe_equality(result_df, feature_set_dataframe)
Example #3
def test_explode_json_column(spark_context, spark_session):
    # arrange
    input_data = [{
        "json_column":
        '{"a": 123, "b": "abc", "c": "123", "d": [1, 2, 3]}'
    }]
    target_data = [{
        "json_column": '{"a": 123, "b": "abc", "c": "123", "d": [1, 2, 3]}',
        "a": 123,
        "b": "abc",
        "c": 123,
        "d": [1, 2, 3],
    }]

    input_df = create_df_from_collection(input_data, spark_context,
                                         spark_session)
    target_df = create_df_from_collection(target_data, spark_context,
                                          spark_session)

    json_column_schema = StructType([
        StructField("a", IntegerType()),
        StructField("b", StringType()),
        StructField("c", IntegerType()),
        StructField("d", ArrayType(IntegerType())),
    ])

    # act
    output_df = explode_json_column(input_df, "json_column",
                                    json_column_schema)

    # assert
    assert_dataframe_equality(target_df, output_df)
Example #4
    def test_write_interval_mode(
        self,
        feature_set_dataframe,
        historical_feature_set_dataframe,
        mocker,
        feature_set,
    ):
        # given
        spark_client = SparkClient()
        spark_client.write_table = mocker.stub("write_table")
        spark_client.conn.conf.set("spark.sql.sources.partitionOverwriteMode",
                                   "dynamic")
        writer = HistoricalFeatureStoreWriter(interval_mode=True)

        # when
        writer.write(
            feature_set=feature_set,
            dataframe=feature_set_dataframe,
            spark_client=spark_client,
        )
        result_df = spark_client.write_table.call_args[1]["dataframe"]

        # then
        assert_dataframe_equality(historical_feature_set_dataframe, result_df)

        assert writer.database == spark_client.write_table.call_args[1][
            "database"]
        assert feature_set.name == spark_client.write_table.call_args[1][
            "table_name"]
        assert (writer.PARTITION_BY == spark_client.write_table.call_args[1]
                ["partition_by"])
Example #5
    def test_construct(
        self,
        dataframe,
        feature_set_dataframe,
        key_id,
        timestamp_c,
        feature_add,
        feature_divide,
    ):
        spark_client = Mock()

        # arrange
        feature_set = FeatureSet(
            "name",
            "entity",
            "description",
            [key_id],
            timestamp_c,
            [feature_add, feature_divide],
        )

        # act
        result_df = feature_set.construct(dataframe, spark_client)
        result_columns = result_df.columns

        # assert
        assert (result_columns == key_id.get_output_columns() +
                timestamp_c.get_output_columns() +
                feature_add.get_output_columns() +
                feature_divide.get_output_columns())
        assert_dataframe_equality(result_df, feature_set_dataframe)
        assert result_df.is_cached
Example #6
    def test_construct(
        self, feature_set_dataframe, fixed_windows_output_feature_set_dataframe
    ):
        # given

        spark_client = SparkClient()

        # arrange

        feature_set = FeatureSet(
            name="feature_set",
            entity="entity",
            description="description",
            features=[
                Feature(
                    name="feature1",
                    description="test",
                    transformation=SparkFunctionTransform(
                        functions=[
                            Function(F.avg, DataType.FLOAT),
                            Function(F.stddev_pop, DataType.FLOAT),
                        ]
                    ).with_window(
                        partition_by="id",
                        order_by=TIMESTAMP_COLUMN,
                        mode="fixed_windows",
                        window_definition=["2 minutes", "15 minutes"],
                    ),
                ),
                Feature(
                    name="divided_feature",
                    description="unit test",
                    dtype=DataType.FLOAT,
                    transformation=CustomTransform(
                        transformer=divide, column1="feature1", column2="feature2",
                    ),
                ),
            ],
            keys=[
                KeyFeature(
                    name="id",
                    description="The user's Main ID or device ID",
                    dtype=DataType.INTEGER,
                )
            ],
            timestamp=TimestampFeature(),
        )

        output_df = (
            feature_set.construct(feature_set_dataframe, client=spark_client)
            .orderBy(feature_set.timestamp_column)
            .select(feature_set.columns)
        )

        target_df = fixed_windows_output_feature_set_dataframe.orderBy(
            feature_set.timestamp_column
        ).select(feature_set.columns)

        # assert
        assert_dataframe_equality(output_df, target_df)
Example #7
    def test_write_in_debug_mode_with_interval_mode(
        self,
        feature_set_dataframe,
        historical_feature_set_dataframe,
        feature_set,
        spark_session,
        mocker,
    ):
        # given
        spark_client = SparkClient()
        spark_client.write_dataframe = mocker.stub("write_dataframe")
        spark_client.conn.conf.set("spark.sql.sources.partitionOverwriteMode",
                                   "dynamic")
        writer = HistoricalFeatureStoreWriter(debug_mode=True,
                                              interval_mode=True)

        # when
        writer.write(
            feature_set=feature_set,
            dataframe=feature_set_dataframe,
            spark_client=spark_client,
        )
        result_df = spark_session.table(
            f"historical_feature_store__{feature_set.name}")

        # then
        assert_dataframe_equality(historical_feature_set_dataframe, result_df)
Example #8
    def test_feature_set_pipeline_with_execution_date(
        self,
        mocked_date_df,
        spark_session,
        fixed_windows_output_feature_set_date_dataframe,
        feature_set_pipeline,
    ):
        # arrange
        table_reader_table = "b_table"
        create_temp_view(dataframe=mocked_date_df, name=table_reader_table)

        target_df = fixed_windows_output_feature_set_date_dataframe.filter(
            "timestamp < '2016-04-13'")

        historical_writer = HistoricalFeatureStoreWriter(debug_mode=True)

        feature_set_pipeline.sink.writers = [historical_writer]

        # act
        feature_set_pipeline.run_for_date(execution_date="2016-04-12")

        df = spark_session.sql(
            "select * from historical_feature_store__feature_set")

        # assert
        assert_dataframe_equality(df, target_df)
Example #9
def test_assert_dataframe_equality_different_shapes(spark_context,
                                                    spark_session):
    # arrange
    data1 = [{"value": "abc"}, {"value": "cba"}, {"value": "cba"}]
    data2 = [{"value": "abc"}, {"value": "cba"}]

    df1 = spark_session.read.json(spark_context.parallelize(data1, 1))
    df2 = spark_session.read.json(spark_context.parallelize(data2, 1))

    # act and assert
    with pytest.raises(AssertionError, match="DataFrame shape mismatch:"):
        assert_dataframe_equality(df1, df2)
Example #10
def test_dataset_writer(spark_session):
    # arrange
    dataset_writer = DatasetWriter()

    mock_feature_set = Mock(spec=FeatureSet)
    mock_feature_set.name = "feature_set"

    mock_spark_client = Mock(spec=SparkClient)
    mock_spark_client.write_dataframe = Mock()
    mock_spark_client.write_dataframe.return_value = True

    input_df = spark_session.sql("select 1")

    # act
    dataset_writer.check_schema(0, 0, 0, 0)  # nothing should happen
    dataset_writer.write(feature_set=mock_feature_set,
                         dataframe=input_df,
                         spark_client=mock_spark_client)
    args = mock_spark_client.write_dataframe.call_args[1]

    # assert
    assert_dataframe_equality(args["dataframe"], input_df)
    assert args["format_"] == "csv"
    assert args["mode"] == "overwrite"
    assert args["path"] == "data/datasets/feature_set"
    assert args["header"] is True
    assert dataset_writer.__name__ == "Dataset Writer"
Example #11
    def test_write(
        self,
        feature_set_dataframe,
        historical_feature_set_dataframe,
        mocker,
        feature_set,
    ):
        # given
        spark_client = mocker.stub("spark_client")
        spark_client.write_table = mocker.stub("write_table")
        writer = HistoricalFeatureStoreWriter()

        # when
        writer.write(
            feature_set=feature_set,
            dataframe=feature_set_dataframe,
            spark_client=spark_client,
        )
        result_df = spark_client.write_table.call_args[1]["dataframe"]

        # then
        assert_dataframe_equality(historical_feature_set_dataframe, result_df)

        assert (writer.db_config.format_ ==
                spark_client.write_table.call_args[1]["format_"])
        assert writer.db_config.mode == spark_client.write_table.call_args[1][
            "mode"]
        assert (writer.PARTITION_BY == spark_client.write_table.call_args[1]
                ["partition_by"])
        assert feature_set.name == spark_client.write_table.call_args[1][
            "table_name"]
Example #12
    def test_feature_transform_with_distinct(
        self,
        timestamp_c,
        feature_set_with_distinct_dataframe,
        target_with_distinct_dataframe,
    ):
        spark_client = SparkClient()

        fs = (AggregatedFeatureSet(
            name="name",
            entity="entity",
            description="description",
            features=[
                Feature(
                    name="feature",
                    description="test",
                    transformation=AggregatedTransform(
                        functions=[Function(functions.sum, DataType.INTEGER)]),
                ),
            ],
            keys=[
                KeyFeature(name="h3",
                           description="test",
                           dtype=DataType.STRING)
            ],
            timestamp=timestamp_c,
        ).with_windows(["3 days"]).with_distinct(subset=["id"], keep="last"))

        # assert
        output_df = fs.construct(feature_set_with_distinct_dataframe,
                                 spark_client,
                                 end_date="2020-01-10")
        assert_dataframe_equality(output_df, target_with_distinct_dataframe)
Example #13
    def test_construct_rolling_windows_with_end_date(
        self,
        feature_set_dataframe,
        rolling_windows_output_feature_set_dataframe_base_date,
    ):
        # given

        spark_client = SparkClient()

        # arrange

        feature_set = AggregatedFeatureSet(
            name="feature_set",
            entity="entity",
            description="description",
            features=[
                Feature(
                    name="feature1",
                    description="test",
                    transformation=AggregatedTransform(
                        functions=[
                            Function(F.avg, DataType.DOUBLE),
                            Function(F.stddev_pop, DataType.DOUBLE),
                        ],
                    ),
                ),
                Feature(
                    name="feature2",
                    description="test",
                    transformation=AggregatedTransform(
                        functions=[
                            Function(F.avg, DataType.DOUBLE),
                            Function(F.stddev_pop, DataType.DOUBLE),
                        ],
                    ),
                ),
            ],
            keys=[
                KeyFeature(
                    name="id",
                    description="The user's Main ID or device ID",
                    dtype=DataType.INTEGER,
                )
            ],
            timestamp=TimestampFeature(),
        ).with_windows(definitions=["1 day", "1 week"])

        # act
        output_df = feature_set.construct(
            feature_set_dataframe, client=spark_client, end_date="2016-04-18"
        ).orderBy("timestamp")

        target_df = rolling_windows_output_feature_set_dataframe_base_date.orderBy(
            feature_set.timestamp_column
        ).select(feature_set.columns)

        # assert
        assert_dataframe_equality(output_df, target_df)
Example #14
    def test_pipeline_with_hooks(self, spark_session):
        # arrange
        hook1 = AddHook(value=1)

        spark_session.sql(
            "select 1 as id, timestamp('2020-01-01') as timestamp, 0 as feature"
        ).createOrReplaceTempView("test")

        target_df = spark_session.sql(
            "select 1 as id, timestamp('2020-01-01') as timestamp, 6 as feature, 2020 "
            "as year, 1 as month, 1 as day")

        historical_writer = HistoricalFeatureStoreWriter(debug_mode=True)

        test_pipeline = FeatureSetPipeline(
            source=Source(
                readers=[
                    TableReader(
                        id="reader",
                        table="test",
                    ).add_post_hook(hook1)
                ],
                query="select * from reader",
            ).add_post_hook(hook1),
            feature_set=FeatureSet(
                name="feature_set",
                entity="entity",
                description="description",
                features=[
                    Feature(
                        name="feature",
                        description="test",
                        transformation=SQLExpressionTransform(
                            expression="feature + 1"),
                        dtype=DataType.INTEGER,
                    ),
                ],
                keys=[
                    KeyFeature(
                        name="id",
                        description="The user's Main ID or device ID",
                        dtype=DataType.INTEGER,
                    )
                ],
                timestamp=TimestampFeature(),
            ).add_pre_hook(hook1).add_post_hook(hook1),
            sink=Sink(writers=[historical_writer], ).add_pre_hook(hook1),
        )

        # act
        test_pipeline.run()
        output_df = spark_session.table(
            "historical_feature_store__feature_set")

        # assert
        output_df.show()
        assert_dataframe_equality(output_df, target_df)
Example #15
    def test_create_temporary_view(self, target_df, spark_session):
        # arrange
        spark_client = SparkClient()

        # act
        spark_client.create_temporary_view(target_df, "temp_view")
        result_df = spark_session.table("temp_view")

        # assert
        assert_dataframe_equality(target_df, result_df)
Example #16
    def test_build_with_incremental_strategy(
        self, incremental_source_df, spark_client, spark_session
    ):
        # arrange
        readers = [
            # directly from column
            FileReader(
                id="test_1", path="path/to/file", format="format"
            ).with_incremental_strategy(
                incremental_strategy=IncrementalStrategy(column="date")
            ),
            # from milliseconds
            FileReader(
                id="test_2", path="path/to/file", format="format"
            ).with_incremental_strategy(
                incremental_strategy=IncrementalStrategy().from_milliseconds(
                    column_name="milliseconds"
                )
            ),
            # from str
            FileReader(
                id="test_3", path="path/to/file", format="format"
            ).with_incremental_strategy(
                incremental_strategy=IncrementalStrategy().from_string(
                    column_name="date_str", mask="dd/MM/yyyy"
                )
            ),
            # from year, month, day partitions
            FileReader(
                id="test_4", path="path/to/file", format="format"
            ).with_incremental_strategy(
                incremental_strategy=(
                    IncrementalStrategy().from_year_month_day_partitions()
                )
            ),
        ]

        spark_client.read.return_value = incremental_source_df
        target_df = incremental_source_df.where(
            "date >= date('2020-07-29') and date <= date('2020-07-31')"
        )

        # act
        for reader in readers:
            reader.build(
                client=spark_client, start_date="2020-07-29", end_date="2020-07-31"
            )

        output_dfs = [
            spark_session.table(f"test_{i + 1}") for i, _ in enumerate(readers)
        ]

        # assert
        for output_df in output_dfs:
            assert_dataframe_equality(output_df=output_df, target_df=target_df)
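All four incremental strategies in this example ultimately reduce to a date-range predicate applied to the DataFrame returned by the reader, as the target_df filter above shows. A minimal sketch of the simplest case (filtering directly on a date column), assuming plain string formatting rather than the library's actual API:

from pyspark.sql import DataFrame


def naive_incremental_filter(df: DataFrame, column: str, start_date: str,
                             end_date: str) -> DataFrame:
    # Hypothetical helper: keep only rows whose date column falls inside the
    # [start_date, end_date] window, mirroring the predicate used for target_df.
    return df.where(
        f"{column} >= date('{start_date}') and {column} <= date('{end_date}')"
    )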
Example #17
def test_not_check_schema_hook(spark_session):
    # arrange
    hook = NotCheckSchemaHook()
    input_df = spark_session.sql(
        "select 1 as id, int(null) as orders, int(null) as chargebacks")

    # act
    output_df = hook.run(input_df)

    # assert
    assert_dataframe_equality(output_df, input_df)
Example #18
    def test_feature_transform(self, feature_set_dataframe, target_df_spark):
        test_feature = Feature(
            name="feature",
            description="unit test",
            transformation=SparkFunctionTransform(
                functions=[Function(functions.cos, DataType.DOUBLE)], ),
            from_column="feature1",
        )

        output_df = test_feature.transform(feature_set_dataframe)

        assert_dataframe_equality(output_df, target_df_spark)
Example #19
def test_zero_fill_hook(spark_session):
    # arrange
    hook = ZeroFillHook()
    input_df = spark_session.sql(
        "select 1 as id, int(null) as orders, int(null) as chargebacks")
    expected_df = spark_session.sql(
        "select 1 as id, 0 as orders, 0 as chargebacks")

    # act
    output_df = hook.run(input_df)

    # assert
    assert_dataframe_equality(output_df, expected_df)
Example #20
def test_create_df_from_collection(spark_context, spark_session):
    # arrange
    input_data = [{"json_column": '{"abc": 123}', "a": 123, "b": "abc"}]

    # act
    output_df = create_df_from_collection(input_data, spark_context,
                                          spark_session)
    target_df = spark_session.sql(
        "select 123 as a, 'abc' as b, replace("
        "to_json(named_struct('abc', 123)), ':', ': ') as json_column"
    )  # generate the same data but with SparkSQL directly to df

    # assert
    assert_dataframe_equality(target_df, output_df)
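create_df_from_collection is used throughout these examples to turn a list of dicts into a small test DataFrame. A hypothetical minimal equivalent (not necessarily the library's implementation) that serializes the dicts and lets Spark infer the schema:

import json

from pyspark import SparkContext
from pyspark.sql import DataFrame, SparkSession


def naive_create_df_from_collection(data: list, spark_context: SparkContext,
                                    spark_session: SparkSession) -> DataFrame:
    # Hypothetical stand-in: dump each dict to a JSON string and let Spark
    # infer the schema while reading the resulting RDD.
    rdd = spark_context.parallelize(data, 1).map(json.dumps)
    return spark_session.read.json(rdd)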
Example #21
    def test_run_hooks(self, spark_session):
        # arrange
        input_dataframe = spark_session.sql("select 2 as feature")
        test_component = (
            TestComponent()
            .add_pre_hook(AddHook(value=1))
            .add_post_hook(AddHook(value=1))
        )
        target_table = spark_session.sql("select 10 as feature")

        # act
        output_df = test_component.construct(input_dataframe)

        # assert
        assert_dataframe_equality(output_df, target_table)
Example #22
    def test_overwriting_column(self, spark_session):
        # arrange
        input_df = spark_session.sql("select 0 as feature")
        feature_with_same_name = Feature(
            name="feature",
            description="description",
            dtype=DataType.INTEGER,
            transformation=SQLExpressionTransform(expression="feature + 1"),
        )
        target_df = spark_session.sql("select 1 as feature")

        # act
        output_df = feature_with_same_name.transform(input_df)

        # assert
        assert_dataframe_equality(output_df, target_df)
Example #23
    def test_construct_without_window(
        self,
        feature_set_dataframe,
        target_df_without_window,
    ):
        # given

        spark_client = SparkClient()

        # arrange

        feature_set = AggregatedFeatureSet(
            name="feature_set",
            entity="entity",
            description="description",
            features=[
                Feature(
                    name="feature1",
                    description="test",
                    dtype=DataType.DOUBLE,
                    transformation=AggregatedTransform(
                        functions=[Function(F.avg, DataType.DOUBLE)]),
                ),
                Feature(
                    name="feature2",
                    description="test",
                    dtype=DataType.FLOAT,
                    transformation=AggregatedTransform(
                        functions=[Function(F.count, DataType.BIGINT)]),
                ),
            ],
            keys=[
                KeyFeature(
                    name="id",
                    description="The user's Main ID or device ID",
                    dtype=DataType.INTEGER,
                )
            ],
            timestamp=TimestampFeature(from_column="fixed_ts"),
        )

        # act
        output_df = feature_set.construct(feature_set_dataframe,
                                          client=spark_client)

        # assert
        assert_dataframe_equality(output_df, target_df_without_window)
Example #24
def test_assert_dataframe_equality(spark_context, spark_session):
    # arrange
    data1 = [
        {"ts": 1582911000000, "flag": 1, "value": 1234.0},
        {"ts": 1577923200000, "flag": 0, "value": 123.0},
    ]
    data2 = [
        {"ts": "2020-01-02T00:00:00+00:00", "flag": "false", "value": 123},
        {"ts": "2020-02-28T17:30:00+00:00", "flag": "true", "value": 1234},
    ]  # same data declared in different formats and in different order

    df1 = spark_session.read.json(spark_context.parallelize(data1, 1))
    df1 = (
        df1.withColumn("ts", from_unixtime(col("ts") / 1000.0).cast("timestamp"))
        .withColumn("flag", col("flag").cast("boolean"))
        .withColumn("value", col("value").cast("integer"))
    )

    df2 = spark_session.read.json(spark_context.parallelize(data2, 1))
    df2 = (
        df2.withColumn("ts", col("ts").cast("timestamp"))
        .withColumn("flag", col("flag").cast("boolean"))
        .withColumn("value", col("value").cast("integer"))
    )

    # act and assert
    assert_dataframe_equality(df1, df2)
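Every snippet on this page ends with assert_dataframe_equality, which, as the comment in this example notes, treats DataFrames as equal regardless of row order or the format the data was declared in (once columns are cast). A rough order-insensitive comparison in the same spirit, assuming a simple sort-and-compare approach rather than butterfree's actual implementation:

from pyspark.sql import DataFrame


def naive_assert_dataframe_equality(output_df: DataFrame,
                                    target_df: DataFrame) -> None:
    # Hypothetical sketch: fail fast on shape differences, then compare rows
    # independently of their order by sorting the collected results.
    if (len(output_df.columns) != len(target_df.columns)
            or output_df.count() != target_df.count()):
        raise AssertionError("DataFrame shape mismatch")

    columns = sorted(output_df.columns)
    output_rows = sorted(output_df.select(*columns).collect(), key=str)
    target_rows = sorted(target_df.select(*columns).collect(), key=str)
    assert output_rows == target_rows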
Example #25
    def test_feature_transform_with_window(self, feature_set_dataframe,
                                           target_df_rows_agg):
        test_feature = Feature(
            name="feature1",
            description="unit test",
            transformation=SparkFunctionTransform(functions=[
                Function(functions.avg, DataType.DOUBLE)
            ], ).with_window(
                partition_by="id",
                mode="row_windows",
                window_definition=["2 events", "3 events"],
            ),
        )

        output_df = test_feature.transform(feature_set_dataframe)

        assert_dataframe_equality(output_df, target_df_rows_agg)
Example #26
    def test_consume(
        self, topic, topic_options, stream, spark_client, spark_context, spark_session
    ):
        """Test for consume method in KafkaReader class.

        The test checks the correct use of the read method called inside
        consume. With the Kafka format, some columns received from the client
        come in binary. The RAW_DATA and TARGET_DATA defined in the test class
        are used to assert that the consume method casts the data types
        correctly, besides checking that read is called with the correct args.

        """
        # arrange
        raw_stream_df = create_df_from_collection(
            self.RAW_DATA, spark_context, spark_session
        )
        target_df = create_df_from_collection(
            self.TARGET_DATA, spark_context, spark_session
        )

        spark_client.read.return_value = raw_stream_df
        value_json_schema = StructType(
            [
                StructField("a", LongType()),
                StructField("b", StringType()),
                StructField("c", LongType()),
            ]
        )
        kafka_reader = KafkaReader(
            "test", topic, value_json_schema, topic_options=topic_options, stream=stream
        )

        # act
        output_df = kafka_reader.consume(spark_client)
        connection_string = specification["KAFKA_CONSUMER_CONNECTION_STRING"]
        options = dict(
            {"kafka.bootstrap.servers": connection_string, "subscribe": topic},
            **topic_options if topic_options else {},
        )

        # assert
        spark_client.read.assert_called_once_with(
            format="kafka", stream=kafka_reader.stream, **options
        )
        assert_dataframe_equality(target_df, output_df)
Example #27
    def test_feature_transform(self, spark_context, spark_session):
        # arrange
        target_data = [
            {"id": 1, "feature": 100, "id_a": 1, "id_b": 2},
            {"id": 2, "feature": 100, "id_a": 1, "id_b": 2},
            {"id": 3, "feature": 120, "id_a": 3, "id_b": 4},
            {"id": 4, "feature": 120, "id_a": 3, "id_b": 4},
        ]
        input_df = create_df_from_collection(self.input_data, spark_context,
                                             spark_session)
        target_df = create_df_from_collection(target_data, spark_context,
                                              spark_session)

        feature_using_names = KeyFeature(
            name="id",
            description="id_a and id_b stacked in a single column.",
            dtype=DataType.INTEGER,
            transformation=StackTransform("id_*"),
        )

        # act
        result_df_1 = feature_using_names.transform(input_df)

        # assert
        assert_dataframe_equality(target_df, result_df_1)
Example #28
    def test__struct_df(self, spark_context, spark_session):
        # arrange
        input_df = create_df_from_collection(self.RAW_DATA, spark_context,
                                             spark_session)
        target_df = create_df_from_collection(self.TARGET_DATA, spark_context,
                                              spark_session)
        value_schema = StructType([
            StructField("a", LongType()),
            StructField("b", StringType()),
            StructField("c", LongType()),
        ])
        kafka_reader = KafkaReader("id", "topic", value_schema)

        # act
        output_df = kafka_reader._struct_df(input_df)

        # assert
        assert_dataframe_equality(target_df, output_df)
Example #29
    def test_with_stack(self, h3_input_df, h3_with_stack_target_df):
        # arrange
        test_feature = KeyFeature(
            name="id",
            description="unit test",
            dtype=DataType.STRING,
            transformation=H3HashTransform(
                h3_resolutions=[6, 7, 8, 9, 10, 11, 12],
                lat_column="lat",
                lng_column="lng",
            ).with_stack(),
        )

        # act
        output_df = test_feature.transform(h3_input_df)

        # assert
        assert_dataframe_equality(h3_with_stack_target_df, output_df)
Example #30
    def test_construct_with_pivot(
        self,
        feature_set_df_pivot,
        target_df_pivot_agg,
    ):
        # given

        spark_client = SparkClient()

        # arrange

        feature_set = AggregatedFeatureSet(
            name="feature_set",
            entity="entity",
            description="description",
            features=[
                Feature(
                    name="feature",
                    description="unit test",
                    transformation=AggregatedTransform(functions=[
                        Function(F.avg, DataType.FLOAT),
                        Function(F.stddev_pop, DataType.DOUBLE),
                    ], ),
                    from_column="feature1",
                )
            ],
            keys=[
                KeyFeature(
                    name="id",
                    description="The user's Main ID or device ID",
                    dtype=DataType.INTEGER,
                )
            ],
            timestamp=TimestampFeature(from_column="fixed_ts"),
        ).with_pivot("pivot_col", ["S", "N"])

        # act
        output_df = feature_set.construct(feature_set_df_pivot,
                                          client=spark_client)

        # assert
        assert_dataframe_equality(output_df, target_df_pivot_agg)