Example #1
0
    def test_write_stream_invalid_params(self, mocked_stream_df):
        """A non-streaming dataframe passed to write_stream must be rejected."""
        # arrange: mark the mocked dataframe as NOT streaming
        client = SparkClient()
        mocked_stream_df.isStreaming = False

        # act and assert: the guard clause should raise before any write
        with pytest.raises(ValueError):
            client.write_stream(
                mocked_stream_df,
                processing_time="0 seconds",
                output_mode="update",
                checkpoint_path="s3://path/to/checkpoint",
                format_="parquet",
                mode="append",
            )
Example #2
0
    def test_write_stream(self, mocked_stream_df):
        """write_stream should configure trigger, output mode and checkpoint,
        then start the streaming query and return its handler."""
        # arrange
        client = SparkClient()
        trigger_interval = "0 seconds"
        out_mode = "update"
        checkpoint = "s3://path/to/checkpoint"

        # act
        handler = client.write_stream(
            mocked_stream_df,
            trigger_interval,
            out_mode,
            checkpoint,
            format_="parquet",
            mode="append",
        )

        # assert: each builder call received the expected argument
        assert isinstance(handler, StreamingQuery)
        mocked_stream_df.trigger.assert_called_with(
            processingTime=trigger_interval)
        mocked_stream_df.outputMode.assert_called_with(out_mode)
        mocked_stream_df.option.assert_called_with(
            "checkpointLocation", checkpoint)
        mocked_stream_df.foreachBatch.assert_called_once()
        mocked_stream_df.start.assert_called_once()
 def _write_stream(self, feature_set: FeatureSet, dataframe: DataFrame,
                   spark_client: SparkClient):
     """Writes the dataframe in streaming mode.

     One streaming query is started per target table (the feature set's
     own table and its entity table); the handler of the last query
     started is returned.
     """
     # TODO: Refactor this logic using the Sink returning the Query Handler
     base_checkpoint = self.db_config.stream_checkpoint_path
     for table in (feature_set.name, feature_set.entity):
         if base_checkpoint:
             # the entity-level write gets a "__on_entity" suffix so its
             # checkpoint does not collide with the feature set's own
             leaf = (f"{feature_set.name}__on_entity"
                     if table == feature_set.entity else table)
             checkpoint_path = os.path.join(
                 base_checkpoint, feature_set.entity, leaf)
         else:
             checkpoint_path = None
         streaming_handler = spark_client.write_stream(
             dataframe,
             processing_time=self.db_config.stream_processing_time,
             output_mode=self.db_config.stream_output_mode,
             checkpoint_path=checkpoint_path,
             format_=self.db_config.format_,
             mode=self.db_config.mode,
             **self.db_config.get_options(table=table),
         )
     return streaming_handler
Example #4
0
    def test_write_stream(self, feature_set, has_checkpoint, monkeypatch):
        """A streaming dataframe should be delegated to
        SparkClient.write_stream with the config-derived arguments."""
        # arrange: stub out the client's write entry points
        client = SparkClient()
        client.write_stream = Mock(return_value=Mock(spec=StreamingQuery))
        client.write_dataframe = Mock()

        streaming_df = Mock(spec=DataFrame)
        streaming_df.isStreaming = True

        if has_checkpoint:
            monkeypatch.setenv("STREAM_CHECKPOINT_PATH", "test")

        config = CassandraConfig(keyspace="feature_set")
        expected_checkpoint = (
            "test/entity/feature_set"
            if config.stream_checkpoint_path
            else None
        )

        writer = OnlineFeatureStoreWriter(config)
        writer.filter_latest = Mock()

        # act
        handler = writer.write(feature_set, streaming_df, client)

        # assert: streaming path used, batch path untouched
        assert isinstance(handler, StreamingQuery)
        client.write_stream.assert_any_call(
            streaming_df,
            processing_time=config.stream_processing_time,
            output_mode=config.stream_output_mode,
            checkpoint_path=expected_checkpoint,
            format_=config.format_,
            mode=config.mode,
            **config.get_options(table=feature_set.name),
        )
        writer.filter_latest.assert_not_called()
        client.write_dataframe.assert_not_called()