예제 #1
0
    def test_write(
        self,
        feature_set_dataframe,
        latest_feature_set_dataframe,
        cassandra_config,
        mocker,
        feature_set,
    ):
        """Check the keyword arguments ``write`` hands to ``write_dataframe``.

        The stubbed ``spark_client.write_dataframe`` must receive a dataframe
        whose rows match ``latest_feature_set_dataframe``, the ``mode`` and
        ``format_`` of the writer's db config, and every option returned by
        ``db_config.get_options`` for the feature set's table.
        """
        # given
        spark_client = mocker.stub("spark_client")
        spark_client.write_dataframe = mocker.stub("write_dataframe")
        writer = OnlineFeatureStoreWriter(cassandra_config)

        # when
        writer.write(feature_set, feature_set_dataframe, spark_client)

        # then
        expected_rows = sorted(latest_feature_set_dataframe.collect())
        # call_args[1] holds the keyword arguments of the most recent call
        called_kwargs = spark_client.write_dataframe.call_args[1]
        written_rows = sorted(called_kwargs["dataframe"].collect())
        assert expected_rows == written_rows

        db_config = writer.db_config
        assert db_config.mode == called_kwargs["mode"]
        assert db_config.format_ == called_kwargs["format_"]

        # every additional option from db_config must appear, with the same
        # value, among the keyword arguments given to write_dataframe
        expected_options = db_config.get_options(table=feature_set.name)
        assert all(
            option in called_kwargs.items() for option in expected_options.items()
        )
예제 #2
0
    def test_write_in_debug_and_stream_mode(
        self, feature_set, spark_session,
    ):
        """Check ``write`` on a streaming dataframe with ``debug_mode=True``.

        The mocked stream must be written with the ``"memory"`` format under
        the query name ``online_feature_store__<feature set name>``, and
        ``write`` must return the ``StreamingQuery`` produced by ``start``.
        """
        # arrange
        spark_client = SparkClient()

        stream_df = Mock()
        stream_df.isStreaming = True
        # writeStream, format() and queryName() all hand back the same mock,
        # so any chained call on the stream stays on this one object
        stream_df.writeStream = stream_df
        stream_df.format.return_value = stream_df
        stream_df.queryName.return_value = stream_df
        stream_df.start.return_value = Mock(spec=StreamingQuery)

        debug_writer = OnlineFeatureStoreWriter(debug_mode=True)

        # act
        query = debug_writer.write(
            feature_set=feature_set, dataframe=stream_df, spark_client=spark_client,
        )

        # assert
        expected_query_name = f"online_feature_store__{feature_set.name}"
        stream_df.format.assert_called_with("memory")
        stream_df.queryName.assert_called_with(expected_query_name)
        assert isinstance(query, StreamingQuery)
예제 #3
0
    def test_write_with_kafka_config(
        self,
        feature_set_dataframe,
        online_feature_set_dataframe_json,
        mocker,
        feature_set,
    ):
        """Check that Kafka options reach ``write_dataframe`` as keyword args.

        The writer is built from a default ``KafkaConfig`` with
        ``json_transform`` attached through ``with_``; every option returned by
        ``db_config.get_options(topic=<feature set name>)`` must be present in
        the keyword arguments of the stubbed ``write_dataframe`` call.
        """
        # given
        spark_client = mocker.stub("spark_client")
        spark_client.write_dataframe = mocker.stub("write_dataframe")
        writer = OnlineFeatureStoreWriter(KafkaConfig()).with_(json_transform)

        # when
        writer.write(feature_set, feature_set_dataframe, spark_client)

        # then
        expected_options = writer.db_config.get_options(topic=feature_set.name)
        # call_args[1] holds the keyword arguments of the most recent call
        called_kwargs = spark_client.write_dataframe.call_args[1]
        assert all(
            option in called_kwargs.items() for option in expected_options.items()
        )
예제 #4
0
    def test_write_in_debug_mode(
        self,
        feature_set_dataframe,
        latest_feature_set_dataframe,
        feature_set,
        spark_session,
    ):
        """Check the table produced by ``write`` when ``debug_mode=True``.

        After writing, the table ``online_feature_store__<feature set name>``
        read through ``spark_session`` must equal
        ``latest_feature_set_dataframe``.
        """
        # given
        client = SparkClient()
        debug_writer = OnlineFeatureStoreWriter(debug_mode=True)

        # when
        debug_writer.write(
            feature_set=feature_set,
            dataframe=feature_set_dataframe,
            spark_client=client,
        )
        debug_table = f"online_feature_store__{feature_set.name}"
        written_df = spark_session.table(debug_table)

        # then
        assert_dataframe_equality(latest_feature_set_dataframe, written_df)
예제 #5
0
    def test_write_stream(self, feature_set, has_checkpoint, monkeypatch):
        """Check that a streaming dataframe is routed to ``write_stream``.

        ``write`` must return the ``StreamingQuery`` produced by the mocked
        ``spark_client.write_stream``, call it with the stream settings,
        format, mode and options of the Cassandra config, and touch neither
        ``filter_latest`` nor ``write_dataframe``.

        Args:
            feature_set: feature set fixture; the expected checkpoint path
                assumes it resolves to ``entity``/``feature_set``
                (NOTE(review): confirm against the fixture definition).
            has_checkpoint: when truthy, ``STREAM_CHECKPOINT_PATH`` is set to
                ``"test"``; otherwise it is removed. Presumably supplied by a
                parametrize decorator or fixture outside this view.
            monkeypatch: pytest fixture used to control the environment.
        """
        # arrange
        spark_client = SparkClient()
        spark_client.write_stream = Mock()
        spark_client.write_dataframe = Mock()
        spark_client.write_stream.return_value = Mock(spec=StreamingQuery)

        dataframe = Mock(spec=DataFrame)
        dataframe.isStreaming = True

        if has_checkpoint:
            monkeypatch.setenv("STREAM_CHECKPOINT_PATH", "test")
        else:
            # Keep a value exported by the surrounding environment from leaking
            # in: the expected path below is hard-coded to the "test" prefix,
            # so any other ambient value would fail the assertion spuriously.
            monkeypatch.delenv("STREAM_CHECKPOINT_PATH", raising=False)

        cassandra_config = CassandraConfig(keyspace="feature_set")
        target_checkpoint_path = (
            "test/entity/feature_set"
            if cassandra_config.stream_checkpoint_path
            else None
        )

        writer = OnlineFeatureStoreWriter(cassandra_config)
        # replaced by a mock so the test can assert it is never called
        writer.filter_latest = Mock()

        # act
        stream_handler = writer.write(feature_set, dataframe, spark_client)

        # assert
        assert isinstance(stream_handler, StreamingQuery)
        spark_client.write_stream.assert_any_call(
            dataframe,
            processing_time=cassandra_config.stream_processing_time,
            output_mode=cassandra_config.stream_output_mode,
            checkpoint_path=target_checkpoint_path,
            format_=cassandra_config.format_,
            mode=cassandra_config.mode,
            **cassandra_config.get_options(table=feature_set.name),
        )
        writer.filter_latest.assert_not_called()
        spark_client.write_dataframe.assert_not_called()
예제 #6
0
    def test_write_stream_on_entity(self, feature_set, monkeypatch):
        """Test writing a stream dataframe with write_to_entity enabled.

        Asserts that the stream checkpoint path is built as
        ``<STREAM_CHECKPOINT_PATH>/<entity>/<feature set name>__on_entity`` and
        that the target table name is the entity.

        """

        # arrange
        spark_client = SparkClient()
        stream_query = Mock(spec=StreamingQuery)
        spark_client.write_stream = Mock()
        spark_client.write_stream.return_value = stream_query

        dataframe = Mock(spec=DataFrame)
        dataframe.isStreaming = True

        monkeypatch.setenv("STREAM_CHECKPOINT_PATH", "test")
        feature_set.name = "my_feature_set"
        feature_set.entity = "my_entity"

        writer = OnlineFeatureStoreWriter(write_to_entity=True)

        # act
        stream_handler = writer.write(feature_set, dataframe, spark_client)

        # assert
        assert isinstance(stream_handler, StreamingQuery)
        config = writer.db_config
        expected_kwargs = dict(
            processing_time=config.stream_processing_time,
            output_mode=config.stream_output_mode,
            checkpoint_path="test/my_entity/my_feature_set__on_entity",
            format_=config.format_,
            mode=config.mode,
            **config.get_options(table="my_entity"),
        )
        spark_client.write_stream.assert_called_with(dataframe, **expected_kwargs)