Example no. 1
    def test_write(
        self,
        feature_set_dataframe,
        latest_feature_set_dataframe,
        cassandra_config,
        mocker,
        feature_set,
    ):
        # with
        spark_client = mocker.stub("spark_client")
        spark_client.write_dataframe = mocker.stub("write_dataframe")
        writer = OnlineFeatureStoreWriter(cassandra_config)

        # when
        writer.write(feature_set, feature_set_dataframe, spark_client)

        assert sorted(latest_feature_set_dataframe.collect()) == sorted(
            spark_client.write_dataframe.call_args[1]["dataframe"].collect()
        )
        assert (
            writer.db_config.mode == spark_client.write_dataframe.call_args[1]["mode"]
        )
        assert (
            writer.db_config.format_
            == spark_client.write_dataframe.call_args[1]["format_"]
        )
        # assert that all additional options from db_config
        # are present in the call args of write_dataframe
        assert all(
            item in spark_client.write_dataframe.call_args[1].items()
            for item in writer.db_config.get_options(table=feature_set.name).items()
        )
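The fixtures these tests rely on (cassandra_config, feature_set, feature_set_dataframe, latest_feature_set_dataframe) are not shown in this listing. A minimal, hypothetical sketch of what a cassandra_config fixture could look like, assuming a plain Mock is enough for the call-argument assertions above; this is not the project's real conftest.py:

import pytest
from unittest.mock import Mock


@pytest.fixture()
def cassandra_config():
    # Hypothetical stand-in for the real CassandraConfig fixture: only the
    # attributes asserted above (mode, format_, get_options) are stubbed.
    config = Mock()
    config.mode = "append"
    config.format_ = "org.apache.spark.sql.cassandra"
    config.get_options = Mock(
        return_value={"keyspace": "feature_store", "table": "feature_set"}
    )
    return config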
Example no. 2
def test_sink(input_dataframe, feature_set):
    # arrange
    client = SparkClient()
    client.conn.conf.set("spark.sql.sources.partitionOverwriteMode", "dynamic")
    feature_set_df = feature_set.construct(input_dataframe, client)
    target_latest_df = OnlineFeatureStoreWriter.filter_latest(
        feature_set_df, id_columns=[key.name for key in feature_set.keys])
    columns_sort = feature_set_df.schema.fieldNames()

    # setup historical writer
    s3config = Mock()
    s3config.mode = "overwrite"
    s3config.format_ = "parquet"
    s3config.get_options = Mock(
        return_value={"path": "test_folder/historical/entity/feature_set"})
    s3config.get_path_with_partitions = Mock(
        return_value="test_folder/historical/entity/feature_set")

    historical_writer = HistoricalFeatureStoreWriter(db_config=s3config,
                                                     interval_mode=True)

    # setup online writer
    # TODO: Change for CassandraConfig when Cassandra for test is ready
    online_config = Mock()
    online_config.mode = "overwrite"
    online_config.format_ = "parquet"
    online_config.get_options = Mock(
        return_value={"path": "test_folder/online/entity/feature_set"})
    online_writer = OnlineFeatureStoreWriter(db_config=online_config)

    writers = [historical_writer, online_writer]
    sink = Sink(writers)

    # act
    client.sql("CREATE DATABASE IF NOT EXISTS {}".format(
        historical_writer.database))
    sink.flush(feature_set, feature_set_df, client)

    # get historical results
    historical_result_df = client.read(
        s3config.format_,
        path=s3config.get_path_with_partitions(feature_set.name,
                                               feature_set_df),
    )

    # get online results
    online_result_df = client.read(
        online_config.format_, **online_config.get_options(feature_set.name))

    # assert
    # assert historical results
    assert sorted(feature_set_df.select(*columns_sort).collect()) == sorted(
        historical_result_df.select(*columns_sort).collect())

    # assert online results
    assert sorted(target_latest_df.select(*columns_sort).collect()) == sorted(
        online_result_df.select(*columns_sort).collect())

    # tear down
    shutil.rmtree("test_folder")
Example no. 3
    def test_write_in_debug_and_stream_mode(
        self, feature_set, spark_session,
    ):
        # arrange
        spark_client = SparkClient()

        mocked_stream_df = Mock()
        mocked_stream_df.isStreaming = True
        mocked_stream_df.writeStream = mocked_stream_df
        mocked_stream_df.format.return_value = mocked_stream_df
        mocked_stream_df.queryName.return_value = mocked_stream_df
        mocked_stream_df.start.return_value = Mock(spec=StreamingQuery)

        writer = OnlineFeatureStoreWriter(debug_mode=True)

        # act
        handler = writer.write(
            feature_set=feature_set,
            dataframe=mocked_stream_df,
            spark_client=spark_client,
        )

        # assert
        mocked_stream_df.format.assert_called_with("memory")
        mocked_stream_df.queryName.assert_called_with(
            f"online_feature_store__{feature_set.name}"
        )
        assert isinstance(handler, StreamingQuery)
Example no. 4
    def test_filter_latest_without_ts(
        self, feature_set_dataframe_without_ts, cassandra_config
    ):
        # with
        writer = OnlineFeatureStoreWriter(cassandra_config)

        # then
        with pytest.raises(KeyError):
            _ = writer.filter_latest(
                feature_set_dataframe_without_ts, id_columns=["id"]
            )
Example no. 5
    def test_filter_latest(
        self, feature_set_dataframe, latest_feature_set_dataframe, cassandra_config
    ):
        # with
        writer = OnlineFeatureStoreWriter(cassandra_config)

        # when
        result_df = writer.filter_latest(feature_set_dataframe, id_columns=["id"])
        sort_columns = result_df.schema.fieldNames()

        # then
        assert sorted(
            latest_feature_set_dataframe.select(*sort_columns).collect()
        ) == sorted(result_df.select(*sort_columns).collect())
Example no. 6
    def test_filter_latest_without_id_columns(
        self, feature_set_dataframe, cassandra_config
    ):
        # with
        writer = OnlineFeatureStoreWriter(cassandra_config)

        # then
        with pytest.raises(ValueError, match="must provide the unique identifiers"):
            _ = writer.filter_latest(feature_set_dataframe, id_columns=[])

        # then
        with pytest.raises(KeyError, match="not found"):
            _ = writer.filter_latest(
                feature_set_dataframe.drop("id"), id_columns=["id"]
            )
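For reference, a minimal sketch of the "filter latest" behaviour the three tests above exercise: keep only the newest row per id and fail loudly when the id or timestamp columns are missing. This is an illustration of the idea, not OnlineFeatureStoreWriter.filter_latest itself:

from pyspark.sql import DataFrame, Window
from pyspark.sql import functions as F


def filter_latest_sketch(dataframe: DataFrame, id_columns: list) -> DataFrame:
    if not id_columns:
        raise ValueError("You must provide the unique identifiers (id_columns).")
    missing = [c for c in id_columns + ["timestamp"] if c not in dataframe.columns]
    if missing:
        raise KeyError(f"Columns not found in the dataframe: {missing}")

    # Rank rows per id by descending timestamp and keep only the newest one.
    window = Window.partitionBy(*id_columns).orderBy(F.col("timestamp").desc())
    return (
        dataframe.withColumn("row_number", F.row_number().over(window))
        .filter("row_number = 1")
        .drop("row_number")
    )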
Example no. 7
    def test_flush(self, feature_set_dataframe, mocker):
        # given
        spark_client = SparkClient()
        writer = [
            HistoricalFeatureStoreWriter(),
            OnlineFeatureStoreWriter(),
        ]

        for w in writer:
            w.write = mocker.stub("write")

        feature_set = mocker.stub("feature_set")
        feature_set.entity = "house"
        feature_set.name = "test"

        # when
        sink = Sink(writers=writer)
        sink.flush(
            dataframe=feature_set_dataframe,
            feature_set=feature_set,
            spark_client=spark_client,
        )

        # then
        for w in writer:
            w.write.assert_called_once()
Example no. 8
    def __init__(self):
        super(FirstPipeline, self).__init__(
            source=Source(
                readers=[TableReader(id="t", database="db", table="table",)],
                query=f"select * from t",  # noqa
            ),
            feature_set=FeatureSet(
                name="first",
                entity="entity",
                description="description",
                features=[
                    Feature(name="feature1", description="test", dtype=DataType.FLOAT,),
                    Feature(
                        name="feature2",
                        description="another test",
                        dtype=DataType.STRING,
                    ),
                ],
                keys=[
                    KeyFeature(
                        name="id", description="identifier", dtype=DataType.BIGINT,
                    )
                ],
                timestamp=TimestampFeature(),
            ),
            sink=Sink(
                writers=[HistoricalFeatureStoreWriter(), OnlineFeatureStoreWriter()]
            ),
        )
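Assuming FirstPipeline extends butterfree's FeatureSetPipeline, it could be executed with the standard run() entry point; a usage sketch, not part of the original example:

if __name__ == "__main__":
    # Reads from the source, builds the "first" feature set and flushes it
    # through both writers configured in the sink.
    FirstPipeline().run()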
Example no. 9
    def __init__(self):
        super(UserChargebacksPipeline, self).__init__(
            source=Source(
                readers=[
                    FileReader(
                        id="chargeback_events",
                        path="data/order_events/input.csv",
                        format="csv",
                        format_options={"header": True},
                    )
                ],
                query=("""
                    select
                        cpf,
                        timestamp(chargeback_timestamp) as timestamp,
                        order_id
                    from
                        chargeback_events
                    where
                        chargeback_timestamp is not null
                    """),
            ),
            feature_set=AggregatedFeatureSet(
                name="user_chargebacks",
                entity="user",
                description="Aggregates the total of chargebacks from users in "
                "different time windows.",
                keys=[
                    KeyFeature(
                        name="cpf",
                        description="User unique identifier, entity key.",
                        dtype=DataType.STRING,
                    )
                ],
                timestamp=TimestampFeature(),
                features=[
                    Feature(
                        name="cpf_chargebacks",
                        description="Total of chargebacks registered on user's CPF",
                        transformation=AggregatedTransform(
                            functions=[Function(functions.count, DataType.INTEGER)]
                        ),
                        from_column="order_id",
                    ),
                ],
            )
            .with_windows(definitions=["3 days", "7 days", "30 days"])
            .add_post_hook(ZeroFillHook()),
            sink=Sink(
                writers=[
                    LocalHistoricalFSWriter(),
                    OnlineFeatureStoreWriter(
                        interval_mode=True,
                        check_schema_hook=NotCheckSchemaHook(),
                        debug_mode=True,
                    ),
                ]
            ),
        )
Example no. 10
    def test_flush_with_multiple_online_writers(self, feature_set,
                                                feature_set_dataframe):
        """Testing the flow of writing to a feature-set table and to an entity table."""
        # arrange
        spark_client = SparkClient()
        spark_client.write_dataframe = Mock()

        feature_set.entity = "my_entity"
        feature_set.name = "my_feature_set"

        online_feature_store_writer = OnlineFeatureStoreWriter()

        online_feature_store_writer_on_entity = OnlineFeatureStoreWriter(
            write_to_entity=True)

        sink = Sink(writers=[
            online_feature_store_writer, online_feature_store_writer_on_entity
        ])

        # act
        sink.flush(
            dataframe=feature_set_dataframe,
            feature_set=feature_set,
            spark_client=spark_client,
        )

        # assert
        spark_client.write_dataframe.assert_any_call(
            dataframe=ANY,
            format_=ANY,
            mode=ANY,
            **online_feature_store_writer.db_config.get_options(
                table="my_entity"),
        )

        spark_client.write_dataframe.assert_any_call(
            dataframe=ANY,
            format_=ANY,
            mode=ANY,
            **online_feature_store_writer.db_config.get_options(
                table="my_feature_set"),
        )
Example no. 11
    def test_write_with_kafka_config(
        self,
        feature_set_dataframe,
        online_feature_set_dataframe_json,
        mocker,
        feature_set,
    ):
        # with
        spark_client = mocker.stub("spark_client")
        spark_client.write_dataframe = mocker.stub("write_dataframe")
        kafka_config = KafkaConfig()
        writer = OnlineFeatureStoreWriter(kafka_config).with_(json_transform)

        # when
        writer.write(feature_set, feature_set_dataframe, spark_client)

        assert all(
            item in spark_client.write_dataframe.call_args[1].items()
            for item in writer.db_config.get_options(topic=feature_set.name).items()
        )
Example no. 12
    def test_write_in_debug_mode(
        self,
        feature_set_dataframe,
        latest_feature_set_dataframe,
        feature_set,
        spark_session,
    ):
        # given
        spark_client = SparkClient()
        writer = OnlineFeatureStoreWriter(debug_mode=True)

        # when
        writer.write(
            feature_set=feature_set,
            dataframe=feature_set_dataframe,
            spark_client=spark_client,
        )
        result_df = spark_session.table(f"online_feature_store__{feature_set.name}")

        # then
        assert_dataframe_equality(latest_feature_set_dataframe, result_df)
Example no. 13
    def test_flush_streaming_df(self, feature_set):
        """Testing the return of the streaming handlers by the sink."""
        # arrange
        spark_client = SparkClient()

        mocked_stream_df = Mock()
        mocked_stream_df.isStreaming = True
        mocked_stream_df.writeStream = mocked_stream_df
        mocked_stream_df.trigger.return_value = mocked_stream_df
        mocked_stream_df.outputMode.return_value = mocked_stream_df
        mocked_stream_df.option.return_value = mocked_stream_df
        mocked_stream_df.foreachBatch.return_value = mocked_stream_df
        mocked_stream_df.start.return_value = Mock(spec=StreamingQuery)

        online_feature_store_writer = OnlineFeatureStoreWriter()

        online_feature_store_writer_on_entity = OnlineFeatureStoreWriter(
            write_to_entity=True)

        sink = Sink(
            writers=[
                online_feature_store_writer,
                online_feature_store_writer_on_entity,
            ],
            validation=Mock(spec=BasicValidation),
        )

        # act
        handlers = sink.flush(
            dataframe=mocked_stream_df,
            feature_set=feature_set,
            spark_client=spark_client,
        )

        # assert
        for handler in handlers:
            assert isinstance(handler, StreamingQuery)
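For reference, Sink.flush returns the handlers as pyspark StreamingQuery objects when the dataframe is streaming; outside a test, a caller would typically block on them. An illustrative fragment, assuming sink, stream_df, feature_set and spark_client already exist:

handlers = sink.flush(
    dataframe=stream_df, feature_set=feature_set, spark_client=spark_client
)
for handler in handlers:
    # StreamingQuery.awaitTermination blocks until the streaming query stops.
    handler.awaitTermination()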
Example no. 14
    def test_write_stream(self, feature_set, has_checkpoint, monkeypatch):
        # arrange
        spark_client = SparkClient()
        spark_client.write_stream = Mock()
        spark_client.write_dataframe = Mock()
        spark_client.write_stream.return_value = Mock(spec=StreamingQuery)

        dataframe = Mock(spec=DataFrame)
        dataframe.isStreaming = True

        if has_checkpoint:
            monkeypatch.setenv("STREAM_CHECKPOINT_PATH", "test")

        cassandra_config = CassandraConfig(keyspace="feature_set")
        target_checkpoint_path = (
            "test/entity/feature_set"
            if cassandra_config.stream_checkpoint_path
            else None
        )

        writer = OnlineFeatureStoreWriter(cassandra_config)
        writer.filter_latest = Mock()

        # act
        stream_handler = writer.write(feature_set, dataframe, spark_client)

        # assert
        assert isinstance(stream_handler, StreamingQuery)
        spark_client.write_stream.assert_any_call(
            dataframe,
            processing_time=cassandra_config.stream_processing_time,
            output_mode=cassandra_config.stream_output_mode,
            checkpoint_path=target_checkpoint_path,
            format_=cassandra_config.format_,
            mode=cassandra_config.mode,
            **cassandra_config.get_options(table=feature_set.name),
        )
        writer.filter_latest.assert_not_called()
        spark_client.write_dataframe.assert_not_called()
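The has_checkpoint argument is presumably a parametrized fixture, so the test runs both with and without the STREAM_CHECKPOINT_PATH variable set. A hypothetical sketch of it (the real fixture is not shown in this listing):

import pytest


@pytest.fixture(params=[True, False])
def has_checkpoint(request):
    # Hypothetical fixture: True -> checkpoint env var gets set, False -> unset.
    return request.param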
Example no. 15
    def test_write_stream_on_entity(self, feature_set, monkeypatch):
        """Test write method with stream dataframe and write_to_entity enabled.

        The main purpose of this test is to assert the correct setup of the stream
        checkpoint path and that the target table name is the entity.

        """

        # arrange
        spark_client = SparkClient()
        spark_client.write_stream = Mock()
        spark_client.write_stream.return_value = Mock(spec=StreamingQuery)

        dataframe = Mock(spec=DataFrame)
        dataframe.isStreaming = True

        feature_set.entity = "my_entity"
        feature_set.name = "my_feature_set"
        monkeypatch.setenv("STREAM_CHECKPOINT_PATH", "test")
        target_checkpoint_path = "test/my_entity/my_feature_set__on_entity"

        writer = OnlineFeatureStoreWriter(write_to_entity=True)

        # act
        stream_handler = writer.write(feature_set, dataframe, spark_client)

        # assert
        assert isinstance(stream_handler, StreamingQuery)
        spark_client.write_stream.assert_called_with(
            dataframe,
            processing_time=writer.db_config.stream_processing_time,
            output_mode=writer.db_config.stream_output_mode,
            checkpoint_path=target_checkpoint_path,
            format_=writer.db_config.format_,
            mode=writer.db_config.mode,
            **writer.db_config.get_options(table="my_entity"),
        )
Example no. 16
def loader(features_set_df: pyspark.sql.DataFrame) -> Sink:

    db_config = get_config()

    keyspace = "feature_store"
    table_name = "orders_feature_master_table_"
    primary_key = "customer_id"

    create_table(features_set_df, keyspace, table_name, primary_key)

    writers = [
        HistoricalFeatureStoreWriter(debug_mode=True),
        OnlineFeatureStoreWriter(db_config=db_config)
    ]

    sink = Sink(writers=writers)
    return sink
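A hypothetical usage sketch of the returned Sink, following the flush signature used in the tests above; features_set_df, feature_set and spark_client are assumed to exist in the calling code:

sink = loader(features_set_df)
sink.flush(
    dataframe=features_set_df,
    feature_set=feature_set,
    spark_client=spark_client,
)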
Example no. 17
    def test_flush_with_invalid_df(self, not_feature_set_dataframe, mocker):
        # given
        spark_client = SparkClient()
        writer = [
            HistoricalFeatureStoreWriter(),
            OnlineFeatureStoreWriter(),
        ]
        feature_set = mocker.stub("feature_set")
        feature_set.entity = "house"
        feature_set.name = "test"

        # when
        sink = Sink(writers=writer)

        # then
        with pytest.raises(ValueError):
            sink.flush(
                dataframe=not_feature_set_dataframe,
                feature_set=feature_set,
                spark_client=spark_client,
            )
Example no. 18
    def test_validate(self, feature_set_dataframe, mocker):
        # given
        spark_client = SparkClient()
        writer = [
            HistoricalFeatureStoreWriter(),
            OnlineFeatureStoreWriter(),
        ]

        for w in writer:
            w.validate = mocker.stub("validate")

        feature_set = mocker.stub("feature_set")

        # when
        sink = Sink(writers=writer)
        sink.validate(
            dataframe=feature_set_dataframe,
            feature_set=feature_set,
            spark_client=spark_client,
        )

        # then
        for w in writer:
            w.validate.assert_called_once()
Example no. 19
    def test_validate_false(self, feature_set_dataframe, mocker):
        # given
        spark_client = SparkClient()
        writer = [
            HistoricalFeatureStoreWriter(),
            OnlineFeatureStoreWriter(),
        ]

        for w in writer:
            w.validate = mocker.stub("validate")
            w.validate.side_effect = AssertionError("test")

        feature_set = mocker.stub("feature_set")

        # when
        sink = Sink(writers=writer)

        # then
        with pytest.raises(RuntimeError):
            sink.validate(
                dataframe=feature_set_dataframe,
                feature_set=feature_set,
                spark_client=spark_client,
            )
Example no. 20
    def test_feature_set_args(self):
        # arrange and act
        out_columns = [
            "user_id",
            "timestamp",
            "listing_page_viewed__rent_per_month__avg_over_7_days_fixed_windows",
            "listing_page_viewed__rent_per_month__avg_over_2_weeks_fixed_windows",
            "listing_page_viewed__rent_per_month__stddev_pop_over_7_days_fixed_windows",
            "listing_page_viewed__rent_per_month__"
            "stddev_pop_over_2_weeks_fixed_windows",
            # noqa
        ]
        pipeline = FeatureSetPipeline(
            source=Source(
                readers=[
                    TableReader(
                        id="source_a",
                        database="db",
                        table="table",
                    ),
                    FileReader(
                        id="source_b",
                        path="path",
                        format="parquet",
                    ),
                ],
                query="select a.*, b.specific_feature "
                "from source_a left join source_b on a.id=b.id",
            ),
            feature_set=FeatureSet(
                name="feature_set",
                entity="entity",
                description="description",
                keys=[
                    KeyFeature(
                        name="user_id",
                        description="The user's Main ID or device ID",
                        dtype=DataType.INTEGER,
                    )
                ],
                timestamp=TimestampFeature(from_column="ts"),
                features=[
                    Feature(
                        name="listing_page_viewed__rent_per_month",
                        description="Average of something.",
                        transformation=SparkFunctionTransform(functions=[
                            Function(functions.avg, DataType.FLOAT),
                            Function(functions.stddev_pop, DataType.FLOAT),
                        ], ).with_window(
                            partition_by="user_id",
                            order_by=TIMESTAMP_COLUMN,
                            window_definition=["7 days", "2 weeks"],
                            mode="fixed_windows",
                        ),
                    ),
                ],
            ),
            sink=Sink(writers=[
                HistoricalFeatureStoreWriter(db_config=None),
                OnlineFeatureStoreWriter(db_config=None),
            ], ),
        )

        assert isinstance(pipeline.spark_client, SparkClient)
        assert len(pipeline.source.readers) == 2
        assert all(
            isinstance(reader, Reader) for reader in pipeline.source.readers)
        assert isinstance(pipeline.source.query, str)
        assert pipeline.feature_set.name == "feature_set"
        assert pipeline.feature_set.entity == "entity"
        assert pipeline.feature_set.description == "description"
        assert isinstance(pipeline.feature_set.timestamp, TimestampFeature)
        assert len(pipeline.feature_set.keys) == 1
        assert all(
            isinstance(k, KeyFeature) for k in pipeline.feature_set.keys)
        assert len(pipeline.feature_set.features) == 1
        assert all(
            isinstance(feature, Feature)
            for feature in pipeline.feature_set.features)
        assert pipeline.feature_set.columns == out_columns
        assert len(pipeline.sink.writers) == 2
        assert all(
            isinstance(writer, Writer) for writer in pipeline.sink.writers)
Example no. 21
    def test_get_db_schema(self, cassandra_config, test_feature_set, expected_schema):
        writer = OnlineFeatureStoreWriter(cassandra_config)
        schema = writer.get_db_schema(test_feature_set)

        assert schema == expected_schema