def test_read_invalid_params(self, format, options):
    # arrange
    spark_client = SparkClient()

    # act and assert
    with pytest.raises(ValueError):
        spark_client.read(format, options)
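# A minimal sketch of how the invalid (format, options) pairs could be supplied
# via pytest parametrization; the concrete pairs below are assumptions for
# illustration, not the suite's actual cases. pytest and SparkClient come from
# the module's existing imports.
@pytest.mark.parametrize(
    "format, options",
    [
        (None, {"path": "path/to/file.parquet"}),  # format is required
        ("parquet", None),                         # options must be a dict
        ("parquet", "not a dict"),                 # options must be a dict
    ],
)
def test_read_invalid_params_sketch(self, format, options):
    # arrange
    spark_client = SparkClient()

    # act and assert
    with pytest.raises(ValueError):
        spark_client.read(format, options)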
def test_sink(input_dataframe, feature_set):
    # arrange
    client = SparkClient()
    feature_set_df = feature_set.construct(input_dataframe, client)
    target_latest_df = OnlineFeatureStoreWriter.filter_latest(
        feature_set_df, id_columns=[key.name for key in feature_set.keys]
    )
    columns_sort = feature_set_df.schema.fieldNames()

    # setup historical writer
    s3config = Mock()
    s3config.get_options = Mock(
        return_value={
            "mode": "overwrite",
            "format_": "parquet",
            "path": "test_folder/historical/entity/feature_set",
        }
    )
    historical_writer = HistoricalFeatureStoreWriter(db_config=s3config)

    # setup online writer
    # TODO: Change for CassandraConfig when Cassandra for test is ready
    online_config = Mock()
    online_config.mode = "overwrite"
    online_config.format_ = "parquet"
    online_config.get_options = Mock(
        return_value={"path": "test_folder/online/entity/feature_set"}
    )
    online_writer = OnlineFeatureStoreWriter(db_config=online_config)

    writers = [historical_writer, online_writer]
    sink = Sink(writers)

    # act
    client.sql(
        "CREATE DATABASE IF NOT EXISTS {}".format(historical_writer.database)
    )
    sink.flush(feature_set, feature_set_df, client)

    # get historical results
    historical_result_df = client.read_table(
        feature_set.name, historical_writer.database
    )

    # get online results
    online_result_df = client.read(
        online_config.format_,
        options=online_config.get_options(feature_set.name),
    )

    # assert
    # assert historical results
    assert sorted(feature_set_df.select(*columns_sort).collect()) == sorted(
        historical_result_df.select(*columns_sort).collect()
    )

    # assert online results
    assert sorted(target_latest_df.select(*columns_sort).collect()) == sorted(
        online_result_df.select(*columns_sort).collect()
    )

    # tear down
    shutil.rmtree("test_folder")
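# A possible cleanup alternative (a sketch, not part of the suite): an autouse
# fixture would remove test_folder even when an assertion fails, whereas the
# shutil.rmtree call at the end of test_sink is skipped on failure.
import shutil

import pytest


@pytest.fixture(autouse=True)
def clean_test_folder():
    yield
    shutil.rmtree("test_folder", ignore_errors=True)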
def test_read(self, format, options, stream, schema, target_df, mocked_spark_read):
    # arrange
    spark_client = SparkClient()
    mocked_spark_read.load.return_value = target_df
    spark_client._session = mocked_spark_read

    # act
    result_df = spark_client.read(format, options, schema, stream)

    # assert
    mocked_spark_read.format.assert_called_once_with(format)
    mocked_spark_read.options.assert_called_once_with(**options)
    assert target_df.collect() == result_df.collect()
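# A minimal sketch of what the mocked_spark_read fixture could look like in
# conftest.py (an assumption, not the suite's actual fixture). It presumes
# SparkClient.read reaches the reader through session.read / session.readStream
# and chains .schema(...).format(...).options(...).load(); every chained call
# returns the same mock so the assertions above can be made directly on it.
from unittest.mock import Mock

import pytest


@pytest.fixture()
def mocked_spark_read():
    mock = Mock()
    mock.read = mock
    mock.readStream = mock
    mock.schema.return_value = mock
    mock.format.return_value = mock
    mock.options.return_value = mock
    return mock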
def consume(self, client: SparkClient) -> DataFrame:
    """Extract data from files stored in the defined path.

    Try to auto-infer the schema when in stream mode and no schema was
    manually defined.

    Args:
        client: client responsible for connecting to Spark session.

    Returns:
        Dataframe with all the files data.

    """
    # in stream mode without an explicit schema, run a batch read first
    # just to capture the inferred schema
    schema = (
        client.read(
            format=self.format,
            options=self.options,
        ).schema
        if (self.stream and not self.schema)
        else self.schema
    )

    return client.read(
        format=self.format,
        options=self.options,
        schema=schema,
        stream=self.stream,
    )
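# Hedged usage sketch: consuming a folder of JSON files in stream mode, where
# consume() first runs a batch read only to capture the inferred schema and
# then re-reads as a stream with that schema. The FileReader import path,
# constructor signature, id, and path below are assumptions for illustration.
from butterfree.extract.readers import FileReader  # import path assumed

file_reader = FileReader(
    id="user_events",      # hypothetical reader id
    path="data/events/",   # hypothetical folder of JSON files
    format="json",
    stream=True,
)
events_df = file_reader.consume(client=SparkClient())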
def consume(self, client: SparkClient) -> DataFrame:
    """Extract data from a Kafka topic.

    When in stream mode, it will get all the new data arriving at the topic
    in a streaming dataframe. When not in stream mode, it will get all data
    available in the Kafka topic.

    Args:
        client: client responsible for connecting to Spark session.

    Returns:
        Dataframe with data from topic.

    """
    # read using client and cast key and value columns from binary to string
    raw_df = (
        client.read(format="kafka", options=self.options, stream=self.stream)
        .withColumn("key", col("key").cast("string"))
        .withColumn("value", col("value").cast("string"))
    )

    # apply schema defined in self.value_schema
    return self._struct_df(raw_df)
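# Hedged sketch: the kind of options dict this reader is expected to carry for
# Spark's Kafka source. The keys are standard spark-sql-kafka options; the
# values are placeholders, not the project's defaults.
kafka_options = {
    "kafka.bootstrap.servers": "localhost:9092",
    "subscribe": "feature_set_events",
    "startingOffsets": "earliest",
}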