def test_write_stream_invalid_params(self, mocked_stream_df):
    """write_stream must raise ValueError for a non-streaming dataframe."""
    # arrange
    client = SparkClient()
    mocked_stream_df.isStreaming = False

    # act and assert: a batch dataframe cannot be written as a stream
    with pytest.raises(ValueError):
        client.write_stream(
            mocked_stream_df,
            processing_time="0 seconds",
            output_mode="update",
            checkpoint_path="s3://path/to/checkpoint",
            format_="parquet",
            mode="append",
        )
def test_write_stream(self, mocked_stream_df):
    """write_stream should configure the stream options and return its handler."""
    # arrange
    client = SparkClient()
    processing_time = "0 seconds"
    output_mode = "update"
    checkpoint_path = "s3://path/to/checkpoint"

    # act
    stream_handler = client.write_stream(
        mocked_stream_df,
        processing_time,
        output_mode,
        checkpoint_path,
        format_="parquet",
        mode="append",
    )

    # assert: a handler is returned and every streaming option was applied
    assert isinstance(stream_handler, StreamingQuery)
    mocked_stream_df.trigger.assert_called_with(processingTime=processing_time)
    mocked_stream_df.outputMode.assert_called_with(output_mode)
    mocked_stream_df.option.assert_called_with(
        "checkpointLocation", checkpoint_path
    )
    mocked_stream_df.foreachBatch.assert_called_once()
    mocked_stream_df.start.assert_called_once()
def _write_stream(self, feature_set: FeatureSet, dataframe: DataFrame, spark_client: SparkClient): """Writes the dataframe in streaming mode.""" # TODO: Refactor this logic using the Sink returning the Query Handler for table in [feature_set.name, feature_set.entity]: checkpoint_path = (os.path.join( self.db_config.stream_checkpoint_path, feature_set.entity, f"{feature_set.name}__on_entity" if table == feature_set.entity else table, ) if self.db_config.stream_checkpoint_path else None) streaming_handler = spark_client.write_stream( dataframe, processing_time=self.db_config.stream_processing_time, output_mode=self.db_config.stream_output_mode, checkpoint_path=checkpoint_path, format_=self.db_config.format_, mode=self.db_config.mode, **self.db_config.get_options(table=table), ) return streaming_handler
def test_write_stream(self, feature_set, has_checkpoint, monkeypatch):
    """Streaming writes should be delegated to SparkClient.write_stream."""
    # arrange
    client = SparkClient()
    client.write_stream = Mock(return_value=Mock(spec=StreamingQuery))
    client.write_dataframe = Mock()

    dataframe = Mock(spec=DataFrame)
    dataframe.isStreaming = True

    if has_checkpoint:
        monkeypatch.setenv("STREAM_CHECKPOINT_PATH", "test")

    cassandra_config = CassandraConfig(keyspace="feature_set")
    # Expected checkpoint location depends on whether the env var is set.
    target_checkpoint_path = (
        "test/entity/feature_set"
        if cassandra_config.stream_checkpoint_path
        else None
    )

    writer = OnlineFeatureStoreWriter(cassandra_config)
    writer.filter_latest = Mock()

    # act
    stream_handler = writer.write(feature_set, dataframe, client)

    # assert: the stream path was taken, not the batch path
    assert isinstance(stream_handler, StreamingQuery)
    client.write_stream.assert_any_call(
        dataframe,
        processing_time=cassandra_config.stream_processing_time,
        output_mode=cassandra_config.stream_output_mode,
        checkpoint_path=target_checkpoint_path,
        format_=cassandra_config.format_,
        mode=cassandra_config.mode,
        **cassandra_config.get_options(table=feature_set.name),
    )
    writer.filter_latest.assert_not_called()
    client.write_dataframe.assert_not_called()