Exemplo n.º 1
0
 def test_source_raise(self):
     """FeatureSetPipeline rejects a ``source`` that is not a Source.

     ``invalid_source`` is a ``Mock`` created without ``spec=Source``, so the
     pipeline's source type check must raise ``ValueError``.
     """
     # A Mock without ``spec=Source`` is deliberately not a Source instance.
     invalid_source = Mock(
         readers=[TableReader(id="source_a", database="db", table="table",)],
         query="select * from source_a",
     )
     feature_set = Mock(
         spec=FeatureSet,
         entity="entity",
         description="description",
         keys=[
             KeyFeature(
                 name="user_id",
                 description="The user's Main ID or device ID",
                 dtype=DataType.INTEGER,
             )
         ],
         timestamp=TimestampFeature(from_column="ts"),
         features=[
             Feature(
                 name="listing_page_viewed__rent_per_month",
                 description="Average of something.",
                 transformation=SparkFunctionTransform(
                     functions=[
                         Function(functions.avg, DataType.FLOAT),
                         Function(functions.stddev_pop, DataType.FLOAT),
                     ],
                 ).with_window(
                     partition_by="user_id",
                     order_by=TIMESTAMP_COLUMN,
                     window_definition=["7 days", "2 weeks"],
                     mode="fixed_windows",
                 ),
             ),
         ],
     )
     # ``name`` is consumed by the Mock constructor itself (it only labels the
     # mock's repr), so it has to be assigned afterwards to become an attribute.
     feature_set.name = "feature_set"
     sink = Mock(
         spec=Sink, writers=[HistoricalFeatureStoreWriter(db_config=None)],
     )

     # Fixtures are built outside ``pytest.raises`` so that an unrelated
     # ValueError during setup cannot be mistaken for the expected one.
     with pytest.raises(ValueError, match="source must be a Source instance"):
         FeatureSetPipeline(
             spark_client=SparkClient(),
             source=invalid_source,
             feature_set=feature_set,
             sink=sink,
         )
Exemplo n.º 2
0
    def test_conn(self):
        """A freshly constructed SparkClient holds no session yet."""
        client = SparkClient()

        # Nothing should have populated the private session handle so far.
        assert client._session is None
Exemplo n.º 3
0
    def test_run_with_repartition(self, spark_session):
        """Running the pipeline with ``partition_by`` calls every stage once.

        Source, feature set and sink are mocks. The test checks that each
        stage's entry point is invoked exactly once.
        """
        source = Mock(
            spec=Source,
            readers=[TableReader(id="source_a", database="db", table="table",)],
            query="select * from source_a",
        )
        feature_set = Mock(
            spec=FeatureSet,
            entity="entity",
            description="description",
            keys=[
                KeyFeature(
                    name="user_id",
                    description="The user's Main ID or device ID",
                    dtype=DataType.INTEGER,
                )
            ],
            timestamp=TimestampFeature(from_column="ts"),
            features=[
                Feature(
                    name="listing_page_viewed__rent_per_month",
                    description="Average of something.",
                    transformation=SparkFunctionTransform(
                        functions=[
                            Function(functions.avg, DataType.FLOAT),
                            Function(functions.stddev_pop, DataType.FLOAT),
                        ],
                    ).with_window(
                        partition_by="user_id",
                        order_by=TIMESTAMP_COLUMN,
                        window_definition=["7 days", "2 weeks"],
                        mode="fixed_windows",
                    ),
                ),
            ],
        )
        # ``name`` is a Mock constructor argument (it only labels the repr), so
        # passing it as a kwarg would leave ``feature_set.name`` a child Mock.
        feature_set.name = "feature_set"
        sink = Mock(
            spec=Sink, writers=[HistoricalFeatureStoreWriter(db_config=None)],
        )

        test_pipeline = FeatureSetPipeline(
            spark_client=SparkClient(),
            source=source,
            feature_set=feature_set,
            sink=sink,
        )

        # feature_set needs to return a real df for streaming validation
        sample_df = spark_session.createDataFrame([{"a": "x", "b": "y", "c": "3"}])
        test_pipeline.feature_set.construct.return_value = sample_df

        test_pipeline.run(partition_by=["id"])

        test_pipeline.source.construct.assert_called_once()
        test_pipeline.feature_set.construct.assert_called_once()
        test_pipeline.sink.flush.assert_called_once()
        test_pipeline.sink.validate.assert_called_once()
Exemplo n.º 4
0
    def consume(self, client: SparkClient) -> DataFrame:
        """Load the configured Spark metastore table.

        Args:
            client: client responsible for connecting to Spark session.

        Returns:
            Dataframe holding the full contents of the configured table.

        """
        table_and_database = (self.table, self.database)
        return client.read_table(*table_and_database)
Exemplo n.º 5
0
    def test_construct_rolling_windows_with_end_date(
        self,
        feature_set_dataframe,
        rolling_windows_output_feature_set_dataframe_base_date,
    ):
        """Rolling-window aggregation with an explicit ``end_date``."""
        client = SparkClient()

        def _mean_and_stddev():
            # Average plus population standard deviation, both as doubles.
            return AggregatedTransform(
                functions=[
                    Function(F.avg, DataType.DOUBLE),
                    Function(F.stddev_pop, DataType.DOUBLE),
                ],
            )

        features = [
            Feature(
                name=feature_name,
                description="test",
                transformation=_mean_and_stddev(),
            )
            for feature_name in ("feature1", "feature2")
        ]
        id_key = KeyFeature(
            name="id",
            description="The user's Main ID or device ID",
            dtype=DataType.INTEGER,
        )
        feature_set = AggregatedFeatureSet(
            name="feature_set",
            entity="entity",
            description="description",
            features=features,
            keys=[id_key],
            timestamp=TimestampFeature(),
        ).with_windows(definitions=["1 day", "1 week"])

        constructed = feature_set.construct(
            feature_set_dataframe, client=client, end_date="2016-04-18"
        )
        output_df = constructed.orderBy("timestamp")

        expected = rolling_windows_output_feature_set_dataframe_base_date
        expected_df = expected.orderBy(feature_set.timestamp_column).select(
            feature_set.columns
        )

        assert_dataframe_equality(output_df, expected_df)
Exemplo n.º 6
0
    def test_write_table(self, format, mode, database, table_name, path,
                         mocked_spark_write):
        """write_table forwards its arguments to ``saveAsTable``."""
        qualified_name = f"{database}.{table_name}"

        SparkClient.write_table(
            dataframe=mocked_spark_write,
            database=database,
            table_name=table_name,
            format_=format,
            mode=mode,
            path=path,
        )

        # The table is addressed as "<database>.<table>" with no partitioning.
        mocked_spark_write.saveAsTable.assert_called_with(
            mode=mode,
            format=format,
            partitionBy=None,
            name=qualified_name,
            path=path,
        )
Exemplo n.º 7
0
    def test_source(
        self,
        target_df_source,
        target_df_table_reader,
        spark_session,
    ):
        """Source joins a metastore table with a JSON file through SQL."""
        client = SparkClient()

        table_id, table_db, table_name = (
            "a_test_source",
            "db",
            "table_test_source",
        )
        create_temp_view(dataframe=target_df_table_reader, name=table_id)
        create_db_and_table(
            spark=spark_session,
            table_reader_id=table_id,
            table_reader_db=table_db,
            table_reader_table=table_name,
        )

        file_id = "b_test_source"
        json_path = INPUT_PATH + "/data.json"

        readers = [
            TableReader(id=table_id, database=table_db, table=table_name),
            FileReader(id=file_id, path=json_path, format="json"),
        ]
        # Each reader is referenced in the query by its id.
        join_query = (
            "select a.*, b.feature2 "
            f"from {table_id} a "
            f"inner join {file_id} b on a.id = b.id "
        )
        source = Source(readers=readers, query=join_query)

        result_df = source.construct(client=client)

        same_content = compare_dataframes(
            actual_df=result_df,
            expected_df=target_df_source,
            columns_sort=result_df.columns,
        )
        assert same_content is True
Exemplo n.º 8
0
    def test_write_stream(self, feature_set, has_checkpoint, monkeypatch):
        """Streaming dataframes are written with write_stream only."""
        client = SparkClient()
        client.write_stream = Mock()
        client.write_dataframe = Mock()
        client.write_stream.return_value = Mock(spec=StreamingQuery)

        streaming_df = Mock(spec=DataFrame)
        streaming_df.isStreaming = True

        if has_checkpoint:
            monkeypatch.setenv("STREAM_CHECKPOINT_PATH", "test")

        config = CassandraConfig(keyspace="feature_set")
        if config.stream_checkpoint_path:
            expected_checkpoint = "test/entity/feature_set"
        else:
            expected_checkpoint = None

        writer = OnlineFeatureStoreWriter(config)
        writer.filter_latest = Mock()

        handler = writer.write(feature_set, streaming_df, client)

        assert isinstance(handler, StreamingQuery)
        # One of the stream writes must target the feature-set table.
        client.write_stream.assert_any_call(
            streaming_df,
            processing_time=config.stream_processing_time,
            output_mode=config.stream_output_mode,
            checkpoint_path=expected_checkpoint,
            format_=config.format_,
            mode=config.mode,
            **config.get_options(table=feature_set.name),
        )
        # The batch-only code paths must stay untouched.
        writer.filter_latest.assert_not_called()
        client.write_dataframe.assert_not_called()
Exemplo n.º 9
0
    def consume(self, client: SparkClient) -> DataFrame:
        """Load the files at the configured location into a dataframe.

        In stream mode without a user-supplied schema, the schema is first
        inferred from a one-off batch read of the same source.

        Args:
            client: client responsible for connecting to Spark session.

        Returns:
            Dataframe with all the files data.

        """
        schema = self.schema
        if self.stream and not schema:
            # Borrow the schema inferred by a plain (non-stream) read.
            schema = client.read(format=self.format, options=self.options).schema

        return client.read(
            format=self.format,
            options=self.options,
            schema=schema,
            stream=self.stream,
        )
Exemplo n.º 10
0
def test_sink(input_dataframe, feature_set):
    """Sink writes to both the historical and the online store end to end.

    All output goes under ``test_folder``. That directory is removed in a
    ``finally`` clause, so a failing assertion does not leave stale data
    behind for later runs.
    """
    # arrange
    client = SparkClient()
    feature_set_df = feature_set.construct(input_dataframe, client)
    target_latest_df = OnlineFeatureStoreWriter.filter_latest(
        feature_set_df, id_columns=[key.name for key in feature_set.keys])
    columns_sort = feature_set_df.schema.fieldNames()

    # setup historical writer
    s3config = Mock()
    s3config.get_options = Mock(
        return_value={
            "mode": "overwrite",
            "format_": "parquet",
            "path": "test_folder/historical/entity/feature_set",
        })
    historical_writer = HistoricalFeatureStoreWriter(db_config=s3config)

    # setup online writer
    # TODO: Change for CassandraConfig when Cassandra for test is ready
    online_config = Mock()
    online_config.mode = "overwrite"
    online_config.format_ = "parquet"
    online_config.get_options = Mock(
        return_value={"path": "test_folder/online/entity/feature_set"})
    online_writer = OnlineFeatureStoreWriter(db_config=online_config)

    sink = Sink([historical_writer, online_writer])

    try:
        # act
        client.sql(
            f"CREATE DATABASE IF NOT EXISTS {historical_writer.database}")
        sink.flush(feature_set, feature_set_df, client)

        # get historical results
        historical_result_df = client.read_table(feature_set.name,
                                                 historical_writer.database)

        # get online results
        online_result_df = client.read(online_config.format_,
                                       options=online_config.get_options(
                                           feature_set.name))

        # assert historical results: every constructed row is stored
        assert sorted(feature_set_df.select(*columns_sort).collect()) == sorted(
            historical_result_df.select(*columns_sort).collect())

        # assert online results: compared against the filter_latest output
        assert sorted(target_latest_df.select(*columns_sort).collect()) == sorted(
            online_result_df.select(*columns_sort).collect())
    finally:
        # tear down even when an assertion above fails; ignore a missing dir
        # so cleanup never masks the original error
        shutil.rmtree("test_folder", ignore_errors=True)
Exemplo n.º 11
0
    def test_construct_without_window(
        self,
        feature_set_dataframe,
        target_df_without_window,
    ):
        """AggregatedFeatureSet without windows yields the expected aggregates."""
        client = SparkClient()

        mean_feature = Feature(
            name="feature1",
            description="test",
            dtype=DataType.DOUBLE,
            transformation=AggregatedTransform(
                functions=[Function(F.avg, DataType.DOUBLE)]),
        )
        count_feature = Feature(
            name="feature2",
            description="test",
            dtype=DataType.FLOAT,
            transformation=AggregatedTransform(
                functions=[Function(F.count, DataType.BIGINT)]),
        )
        id_key = KeyFeature(
            name="id",
            description="The user's Main ID or device ID",
            dtype=DataType.INTEGER,
        )
        feature_set = AggregatedFeatureSet(
            name="feature_set",
            entity="entity",
            description="description",
            features=[mean_feature, count_feature],
            keys=[id_key],
            timestamp=TimestampFeature(from_column="fixed_ts"),
        )

        result_df = feature_set.construct(feature_set_dataframe, client=client)

        assert_dataframe_equality(result_df, target_df_without_window)
Exemplo n.º 12
0
    def test_construct(self, mocker, target_df):
        """Source.construct returns the result of its query over reader views.

        The reader is a stub whose ``build`` does nothing. The view the query
        selects from is registered directly from ``target_df``.
        """
        # given
        spark_client = SparkClient()

        reader_id = "a_source"
        # createOrReplaceTempView returns None, so it is called directly here
        # rather than being (mis)used as the value of ``build.side_effect``.
        target_df.createOrReplaceTempView(reader_id)
        reader = mocker.stub(reader_id)
        reader.build = mocker.stub("build")

        # when
        source_selector = Source(
            readers=[reader], query=f"select * from {reader_id}",  # noqa
        )
        result_df = source_selector.construct(spark_client)

        # then
        assert result_df.collect() == target_df.collect()
 def test_feature_without_datatype(self, key_id, timestamp_c, dataframe):
     """A SQL-expression feature with no dtype is expected to raise ValueError."""
     client = SparkClient()
     with pytest.raises(ValueError):
         untyped_feature = Feature(
             name="feature1",
             description="test",
             transformation=SQLExpressionTransform(expression="feature1 + a"),
         )
         FeatureSet(
             name="name",
             entity="entity",
             description="description",
             features=[untyped_feature],
             keys=[key_id],
             timestamp=timestamp_c,
         ).construct(dataframe, client)
 def test_feature_set_with_invalid_feature(self, key_id, timestamp_c,
                                           dataframe):
     """A plain FeatureSet with an AggregatedTransform should raise ValueError."""
     client = SparkClient()
     with pytest.raises(ValueError):
         aggregated_feature = Feature(
             name="feature1",
             description="test",
             transformation=AggregatedTransform(
                 functions=[Function(F.avg, DataType.FLOAT)]),
         )
         FeatureSet(
             name="name",
             entity="entity",
             description="description",
             features=[aggregated_feature],
             keys=[key_id],
             timestamp=timestamp_c,
         ).construct(dataframe, client)
Exemplo n.º 15
0
    def test_construct_with_pivot(
        self,
        feature_set_df_pivot,
        target_df_pivot_agg,
    ):
        """Aggregation pivoted on ``pivot_col`` matches the expected output."""
        client = SparkClient()

        pivoted_feature = Feature(
            name="feature",
            description="unit test",
            transformation=AggregatedTransform(
                functions=[
                    Function(F.avg, DataType.FLOAT),
                    Function(F.stddev_pop, DataType.DOUBLE),
                ],
            ),
            from_column="feature1",
        )
        id_key = KeyFeature(
            name="id",
            description="The user's Main ID or device ID",
            dtype=DataType.INTEGER,
        )
        feature_set = AggregatedFeatureSet(
            name="feature_set",
            entity="entity",
            description="description",
            features=[pivoted_feature],
            keys=[id_key],
            timestamp=TimestampFeature(from_column="fixed_ts"),
        ).with_pivot("pivot_col", ["S", "N"])

        result_df = feature_set.construct(feature_set_df_pivot, client=client)

        assert_dataframe_equality(result_df, target_df_pivot_agg)
    def test_json_file_with_schema(self):
        """An explicit schema passed to FileReader is the one applied."""
        client = SparkClient()
        expected_schema = StructType(
            [
                StructField("A", StringType()),
                StructField("B", DoubleType()),
                StructField("C", StringType()),
            ]
        )
        json_file = "tests/unit/butterfree/core/extract/readers/file-reader-test.json"

        reader = FileReader(
            id="id", path=json_file, format="json", schema=expected_schema
        )
        loaded_df = reader.consume(client)

        assert expected_schema == loaded_df.schema
 def _write_stream(self, feature_set: FeatureSet, dataframe: DataFrame,
                   spark_client: SparkClient):
     """Writes the dataframe in streaming mode.

     Starts one ``spark_client.write_stream`` call per target table. The first
     targets the table named ``feature_set.name``, the second the table named
     ``feature_set.entity``. When ``db_config.stream_checkpoint_path`` is set,
     each write gets its own checkpoint directory under
     ``<checkpoint>/<entity>/``. Otherwise no checkpoint path is passed.

     Args:
         feature_set: provides ``name`` and ``entity``, which select the
             target tables and the checkpoint sub-directories.
         dataframe: dataframe to write (presumably a streaming one; confirm
             against the caller).
         spark_client: client used to start each streaming write.

     Returns:
         The handler returned by the LAST ``write_stream`` call (the entity
         table). The handler of the first write is not returned.
     """
     # TODO: Refactor this logic using the Sink returning the Query Handler
     for table in [feature_set.name, feature_set.entity]:
         # Entity-table writes are checkpointed under "<name>__on_entity" so
         # they do not share a checkpoint with the feature-set table.
         # NOTE(review): if feature_set.name == feature_set.entity, both
         # iterations take the "__on_entity" branch and share one checkpoint.
         checkpoint_path = (os.path.join(
             self.db_config.stream_checkpoint_path,
             feature_set.entity,
             f"{feature_set.name}__on_entity"
             if table == feature_set.entity else table,
         ) if self.db_config.stream_checkpoint_path else None)
         # Overwritten on every iteration; only the last handler survives.
         streaming_handler = spark_client.write_stream(
             dataframe,
             processing_time=self.db_config.stream_processing_time,
             output_mode=self.db_config.stream_output_mode,
             checkpoint_path=checkpoint_path,
             format_=self.db_config.format_,
             mode=self.db_config.mode,
             **self.db_config.get_options(table=table),
         )
     return streaming_handler
Exemplo n.º 18
0
    def test_flush_with_invalid_df(self, not_feature_set_dataframe, mocker):
        """Flushing a dataframe that is not a feature-set dataframe fails."""
        client = SparkClient()
        writers = [HistoricalFeatureStoreWriter(), OnlineFeatureStoreWriter()]

        stub_feature_set = mocker.stub("feature_set")
        stub_feature_set.entity = "house"
        stub_feature_set.name = "test"

        sink = Sink(writers=writers)

        with pytest.raises(ValueError):
            sink.flush(
                dataframe=not_feature_set_dataframe,
                feature_set=stub_feature_set,
                spark_client=client,
            )
Exemplo n.º 19
0
    def test_agg_feature_set_with_window(self, key_id, timestamp_c, dataframe,
                                         rolling_windows_agg_dataframe):
        """Windowed aggregations require an end date and honour it."""
        client = SparkClient()

        averaged_features = [
            Feature(
                name=feature_name,
                description="unit test",
                transformation=AggregatedTransform(
                    functions=[Function(functions.avg, DataType.FLOAT)]),
            )
            for feature_name in ("feature1", "feature2")
        ]
        fs = AggregatedFeatureSet(
            name="name",
            entity="entity",
            description="description",
            features=averaged_features,
            keys=[key_id],
            timestamp=timestamp_c,
        ).with_windows(definitions=["1 week"])

        # No end date: construction is expected to raise.
        with pytest.raises(ValueError):
            _ = fs.construct(dataframe, client)

        # An end date smaller than the mocked max keeps fewer rows ...
        early_df = fs.construct(dataframe, client, end_date="2016-04-17")
        assert early_df.count() < rolling_windows_agg_dataframe.count()

        # ... while a later one reproduces the full expected result.
        late_df = fs.construct(dataframe, client, end_date="2016-05-01")
        assert_dataframe_equality(late_df, rolling_windows_agg_dataframe)
    def test_write_in_debug_mode(
        self,
        feature_set_dataframe,
        historical_feature_set_dataframe,
        feature_set,
        spark_session,
    ):
        """Debug mode writes to ``historical_feature_store__<name>``."""
        client = SparkClient()
        debug_writer = HistoricalFeatureStoreWriter(debug_mode=True)

        debug_writer.write(
            feature_set=feature_set,
            dataframe=feature_set_dataframe,
            spark_client=client,
        )
        debug_table = f"historical_feature_store__{feature_set.name}"
        written_df = spark_session.table(debug_table)

        assert_dataframe_equality(historical_feature_set_dataframe, written_df)
Exemplo n.º 21
0
    def consume(self, client: SparkClient) -> DataFrame:
        """Read the configured Kafka topic into a dataframe.

        In stream mode the result is a streaming dataframe that receives new
        data as it arrives on the topic. Otherwise it holds all data currently
        available in the topic.

        Args:
            client: client responsible for connecting to Spark session.

        Returns:
            Dataframe with data from topic.

        """
        kafka_df = client.read(format="kafka", options=self.options, stream=self.stream)
        # key and value arrive as binary; cast both to string before parsing
        for binary_column in ("key", "value"):
            kafka_df = kafka_df.withColumn(
                binary_column, col(binary_column).cast("string")
            )

        # parse the value column with the reader's configured schema
        return self._struct_df(kafka_df)
    def validate(
        self, feature_set: FeatureSet, dataframe: DataFrame, spark_client: SparkClient
    ):
        """Check that the written row count matches the feature set dataframe.

        Args:
            feature_set: object processed with feature_set informations.
            dataframe: spark dataframe containing data from a feature set.
            spark_client: client for spark connections with external services.

        Raises:
            AssertionError: if count of written data doesn't match count in current
                feature set dataframe.

        """
        if self.debug_mode:
            # debug writes go to a table named after the feature set only
            table_name = f"historical_feature_store__{feature_set.name}"
        else:
            table_name = f"{self.database}.{feature_set.name}"

        written_count = spark_client.read_table(table_name).count()
        dataframe_count = dataframe.count()
        self._assert_validation_count(table_name, written_count, dataframe_count)
Exemplo n.º 23
0
    def test_construct_rolling_windows_without_end_date(
            self, feature_set_dataframe,
            rolling_windows_output_feature_set_dataframe):
        """Rolling windows without an ``end_date`` raise ValueError."""
        client = SparkClient()

        windowed_feature = Feature(
            name="feature1",
            description="test",
            transformation=AggregatedTransform(
                functions=[
                    Function(F.avg, DataType.DOUBLE),
                    Function(F.stddev_pop, DataType.DOUBLE),
                ],
            ),
        )
        id_key = KeyFeature(
            name="id",
            description="The user's Main ID or device ID",
            dtype=DataType.INTEGER,
        )
        feature_set = AggregatedFeatureSet(
            name="feature_set",
            entity="entity",
            description="description",
            features=[windowed_feature],
            keys=[id_key],
            timestamp=TimestampFeature(),
        ).with_windows(definitions=["1 day", "1 week"])

        with pytest.raises(ValueError):
            _ = feature_set.construct(feature_set_dataframe, client=client)
Exemplo n.º 24
0
    def test_validate_false(self, feature_set_dataframe, mocker):
        """A writer's AssertionError during validation surfaces as RuntimeError."""
        client = SparkClient()
        writers = [HistoricalFeatureStoreWriter(), OnlineFeatureStoreWriter()]

        for failing_writer in writers:
            failing_writer.validate = mocker.stub("validate")
            failing_writer.validate.side_effect = AssertionError("test")

        stub_feature_set = mocker.stub("feature_set")
        sink = Sink(writers=writers)

        with pytest.raises(RuntimeError):
            sink.validate(
                dataframe=feature_set_dataframe,
                feature_set=stub_feature_set,
                spark_client=client,
            )
Exemplo n.º 25
0
    def test_validate(self, feature_set_dataframe, mocker):
        """Sink.validate delegates exactly once to every writer."""
        client = SparkClient()
        writers = [HistoricalFeatureStoreWriter(), OnlineFeatureStoreWriter()]

        for stubbed_writer in writers:
            stubbed_writer.validate = mocker.stub("validate")

        stub_feature_set = mocker.stub("feature_set")

        Sink(writers=writers).validate(
            dataframe=feature_set_dataframe,
            feature_set=stub_feature_set,
            spark_client=client,
        )

        for stubbed_writer in writers:
            stubbed_writer.validate.assert_called_once()
Exemplo n.º 26
0
 def test_write_dataframe(self, format, mode, mocked_spark_write):
     """write_dataframe forwards format and mode to ``save``."""
     SparkClient.write_dataframe(mocked_spark_write, format, mode)
     save_call = mocked_spark_write.save
     save_call.assert_called_with(format=format, mode=mode)
Exemplo n.º 27
0
    def test_feature_transform_with_data_type_array(self, spark_context,
                                                    spark_session):
        """collect_set aggregation yields an array-typed feature column.

        ``collect_set`` does not guarantee element order, so both sides are
        normalised with ``sort_array`` before comparing.
        """
        # arrange
        input_data = [
            {
                "id": 1,
                "timestamp": "2020-04-22T00:00:00+00:00",
                "feature": 10
            },
            {
                "id": 1,
                "timestamp": "2020-04-22T00:00:00+00:00",
                "feature": 20
            },
            {
                "id": 1,
                "timestamp": "2020-04-22T00:00:00+00:00",
                "feature": 30
            },
            {
                "id": 2,
                "timestamp": "2020-04-22T00:00:00+00:00",
                "feature": 10
            },
        ]
        target_data = [
            {
                "id": 1,
                "timestamp": "2020-04-22T00:00:00+00:00",
                "feature__collect_set": [10.0, 20.0, 30.0],
            },
            {
                "id": 2,
                "timestamp": "2020-04-22T00:00:00+00:00",
                "feature__collect_set": [10.0],
            },
        ]
        input_df = create_df_from_collection(
            input_data, spark_context, spark_session).withColumn(
                "timestamp",
                functions.to_timestamp(functions.col("timestamp")))
        target_df = create_df_from_collection(
            target_data, spark_context, spark_session).withColumn(
                "timestamp",
                functions.to_timestamp(functions.col("timestamp")))

        fs = AggregatedFeatureSet(
            name="name",
            entity="entity",
            description="description",
            keys=[
                KeyFeature(name="id",
                           description="test",
                           dtype=DataType.INTEGER)
            ],
            timestamp=TimestampFeature(),
            features=[
                Feature(
                    name="feature",
                    description="aggregations with ",
                    dtype=DataType.BIGINT,
                    transformation=AggregatedTransform(functions=[
                        Function(functions.collect_set, DataType.ARRAY_FLOAT),
                    ], ),
                    from_column="feature",
                ),
            ],
        )

        # act
        output_df = fs.construct(input_df, SparkClient())

        # assert
        # The element order of collect_set output is non-deterministic, so
        # compare order-independently by sorting the arrays on both sides.
        array_column = "feature__collect_set"
        assert_dataframe_equality(
            target_df.withColumn(array_column, functions.sort_array(array_column)),
            output_df.withColumn(array_column, functions.sort_array(array_column)),
        )
Exemplo n.º 28
0
    def test_feature_transform_with_filter_expression(self, spark_context,
                                                      spark_session):
        """``filter_expression`` restricts which rows feed the aggregations."""
        ts = "2020-04-22T00:00:00+00:00"
        input_data = [
            {"id": row_id, "timestamp": ts, "feature": value, "type": kind}
            for row_id, value, kind in (
                (1, 10, "a"),
                (1, 20, "a"),
                (1, 30, "b"),
                (2, 10, "a"),
            )
        ]
        # Only type "a" rows count: id 1 -> {10, 20}, id 2 -> {10}.
        target_data = [
            {
                "id": row_id,
                "timestamp": ts,
                "feature_only_type_a__avg": mean,
                "feature_only_type_a__min": lowest,
                "feature_only_type_a__max": highest,
            }
            for row_id, mean, lowest, highest in (
                (1, 15.0, 10, 20),
                (2, 10.0, 10, 10),
            )
        ]

        def _with_parsed_timestamp(rows):
            # Build a dataframe and turn the ISO string column into a timestamp.
            return create_df_from_collection(
                rows, spark_context, spark_session).withColumn(
                    "timestamp",
                    functions.to_timestamp(functions.col("timestamp")))

        input_df = _with_parsed_timestamp(input_data)
        target_df = _with_parsed_timestamp(target_data)

        type_a_feature = Feature(
            name="feature_only_type_a",
            description="aggregations only when type = a",
            dtype=DataType.BIGINT,
            transformation=AggregatedTransform(
                functions=[
                    Function(functions.avg, DataType.FLOAT),
                    Function(functions.min, DataType.FLOAT),
                    Function(functions.max, DataType.FLOAT),
                ],
                filter_expression="type = 'a'",
            ),
            from_column="feature",
        )
        fs = AggregatedFeatureSet(
            name="name",
            entity="entity",
            description="description",
            keys=[
                KeyFeature(name="id",
                           description="test",
                           dtype=DataType.INTEGER)
            ],
            timestamp=TimestampFeature(),
            features=[type_a_feature],
        )

        output_df = fs.construct(input_df, SparkClient())

        assert_dataframe_equality(target_df, output_df)
 def _write_in_debug_mode(feature_set: FeatureSet, dataframe: DataFrame,
                          spark_client: SparkClient):
     """Register the dataframe as a temporary view instead of writing it.

     The view is named ``online_feature_store__<feature set name>``, and the
     result of ``spark_client.create_temporary_view`` is returned unchanged.
     """
     view_name = f"online_feature_store__{feature_set.name}"
     return spark_client.create_temporary_view(dataframe=dataframe, name=view_name)