Example #1
    def test_write_stream_invalid_params(self, mocked_stream_df: Mock) -> None:
        # arrange
        spark_client = SparkClient()
        mocked_stream_df.isStreaming = False

        # act and assert
        with pytest.raises(ValueError):
            spark_client.write_stream(
                mocked_stream_df,
                processing_time="0 seconds",
                output_mode="update",
                checkpoint_path="s3://path/to/checkpoint",
                format_="parquet",
                mode="append",
            )
Example #2
    def _write_stream(
        self,
        feature_set: FeatureSet,
        dataframe: DataFrame,
        spark_client: SparkClient,
        table_name: str,
    ) -> StreamingQuery:
        """Writes the dataframe in streaming mode."""
        checkpoint_folder = (
            f"{feature_set.name}__on_entity" if self.write_to_entity else table_name
        )
        checkpoint_path = (
            os.path.join(
                self.db_config.stream_checkpoint_path,
                feature_set.entity,
                checkpoint_folder,
            )
            if self.db_config.stream_checkpoint_path
            else None
        )
        streaming_handler = spark_client.write_stream(
            dataframe,
            processing_time=self.db_config.stream_processing_time,
            output_mode=self.db_config.stream_output_mode,
            checkpoint_path=checkpoint_path,
            format_=self.db_config.format_,
            mode=self.db_config.mode,
            **self.db_config.get_options(table_name),
        )
        return streaming_handler
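For reference, the checkpoint path built above is simply the configured root joined with the entity and the checkpoint folder. The small sketch below just replays that join with the literal values used in the tests further down (Examples #4 and #5); it is illustrative only and assumes a POSIX path separator.

import os

# Root taken from the STREAM_CHECKPOINT_PATH value set in the tests below.
stream_checkpoint_path = "test"

# write_to_entity disabled: the checkpoint folder is the target table name.
print(os.path.join(stream_checkpoint_path, "entity", "feature_set"))
# -> test/entity/feature_set

# write_to_entity enabled: the checkpoint folder is "<feature_set.name>__on_entity".
print(os.path.join(stream_checkpoint_path, "my_entity", "my_feature_set__on_entity"))
# -> test/my_entity/my_feature_set__on_entity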
Example #3
    def test_write_stream(self, mocked_stream_df: Mock) -> None:
        # arrange
        spark_client = SparkClient()

        processing_time = "0 seconds"
        output_mode = "update"
        checkpoint_path = "s3://path/to/checkpoint"

        # act
        stream_handler = spark_client.write_stream(
            mocked_stream_df,
            processing_time,
            output_mode,
            checkpoint_path,
            format_="parquet",
            mode="append",
        )

        # assert
        assert isinstance(stream_handler, StreamingQuery)
        mocked_stream_df.trigger.assert_called_with(processingTime=processing_time)
        mocked_stream_df.outputMode.assert_called_with(output_mode)
        mocked_stream_df.option.assert_called_with(
            "checkpointLocation", checkpoint_path
        )
        mocked_stream_df.foreachBatch.assert_called_once()
        mocked_stream_df.start.assert_called_once()
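The assertions in Example #3 pin down the streaming calls that write_stream is expected to make: a processing-time trigger, an output mode, a checkpoint location option, a foreachBatch callback, and a final start(). The following is a minimal sketch consistent with those assertions; the exact signature, the ValueError raised for non-streaming input (see Example #1), and the body of the foreachBatch callback are assumptions, not the library's actual implementation.

from pyspark.sql import DataFrame
from pyspark.sql.streaming import StreamingQuery


def write_stream(
    dataframe: DataFrame,
    processing_time: str,
    output_mode: str,
    checkpoint_path: str,
    format_: str,
    mode: str,
    **options,
) -> StreamingQuery:
    # Reject batch dataframes up front, matching the ValueError test in Example #1.
    if not dataframe.isStreaming:
        raise ValueError("write_stream requires a streaming dataframe.")

    def _write_batch(batch_df: DataFrame, batch_id: int) -> None:
        # Hypothetical per-micro-batch sink write; the real callback is not shown
        # in the excerpts above.
        batch_df.write.format(format_).mode(mode).options(**options).save()

    return (
        dataframe.writeStream.trigger(processingTime=processing_time)
        .outputMode(output_mode)
        .option("checkpointLocation", checkpoint_path)
        .foreachBatch(_write_batch)
        .start()
    )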
Example #4
    def test_write_stream(self, feature_set, has_checkpoint, monkeypatch):
        # arrange
        spark_client = SparkClient()
        spark_client.write_stream = Mock()
        spark_client.write_dataframe = Mock()
        spark_client.write_stream.return_value = Mock(spec=StreamingQuery)

        dataframe = Mock(spec=DataFrame)
        dataframe.isStreaming = True

        if has_checkpoint:
            monkeypatch.setenv("STREAM_CHECKPOINT_PATH", "test")

        cassandra_config = CassandraConfig(keyspace="feature_set")
        target_checkpoint_path = (
            "test/entity/feature_set"
            if cassandra_config.stream_checkpoint_path
            else None
        )

        writer = OnlineFeatureStoreWriter(cassandra_config)
        writer.filter_latest = Mock()

        # act
        stream_handler = writer.write(feature_set, dataframe, spark_client)

        # assert
        assert isinstance(stream_handler, StreamingQuery)
        spark_client.write_stream.assert_any_call(
            dataframe,
            processing_time=cassandra_config.stream_processing_time,
            output_mode=cassandra_config.stream_output_mode,
            checkpoint_path=target_checkpoint_path,
            format_=cassandra_config.format_,
            mode=cassandra_config.mode,
            **cassandra_config.get_options(table=feature_set.name),
        )
        writer.filter_latest.assert_not_called()
        spark_client.write_dataframe.assert_not_called()
Example #5
    def test_write_stream_on_entity(self, feature_set, monkeypatch):
        """Test write method with stream dataframe and write_to_entity enabled.

        The main purpose of this test is to assert the correct setup of the stream
        checkpoint path and that the target table name is the entity.

        """

        # arrange
        spark_client = SparkClient()
        spark_client.write_stream = Mock()
        spark_client.write_stream.return_value = Mock(spec=StreamingQuery)

        dataframe = Mock(spec=DataFrame)
        dataframe.isStreaming = True

        feature_set.entity = "my_entity"
        feature_set.name = "my_feature_set"
        monkeypatch.setenv("STREAM_CHECKPOINT_PATH", "test")
        target_checkpoint_path = "test/my_entity/my_feature_set__on_entity"

        writer = OnlineFeatureStoreWriter(write_to_entity=True)

        # act
        stream_handler = writer.write(feature_set, dataframe, spark_client)

        # assert
        assert isinstance(stream_handler, StreamingQuery)
        spark_client.write_stream.assert_called_with(
            dataframe,
            processing_time=writer.db_config.stream_processing_time,
            output_mode=writer.db_config.stream_output_mode,
            checkpoint_path=target_checkpoint_path,
            format_=writer.db_config.format_,
            mode=writer.db_config.mode,
            **writer.db_config.get_options(table="my_entity"),
        )