def test_sink(input_dataframe, feature_set):
    # arrange
    client = SparkClient()
    client.conn.conf.set("spark.sql.sources.partitionOverwriteMode", "dynamic")
    feature_set_df = feature_set.construct(input_dataframe, client)
    target_latest_df = OnlineFeatureStoreWriter.filter_latest(
        feature_set_df, id_columns=[key.name for key in feature_set.keys]
    )
    columns_sort = feature_set_df.schema.fieldNames()

    # setup historical writer
    s3config = Mock()
    s3config.mode = "overwrite"
    s3config.format_ = "parquet"
    s3config.get_options = Mock(
        return_value={"path": "test_folder/historical/entity/feature_set"}
    )
    s3config.get_path_with_partitions = Mock(
        return_value="test_folder/historical/entity/feature_set"
    )
    historical_writer = HistoricalFeatureStoreWriter(
        db_config=s3config, interval_mode=True
    )

    # setup online writer
    # TODO: Change for CassandraConfig when Cassandra for test is ready
    online_config = Mock()
    online_config.mode = "overwrite"
    online_config.format_ = "parquet"
    online_config.get_options = Mock(
        return_value={"path": "test_folder/online/entity/feature_set"}
    )
    online_writer = OnlineFeatureStoreWriter(db_config=online_config)

    writers = [historical_writer, online_writer]
    sink = Sink(writers)

    # act
    client.sql("CREATE DATABASE IF NOT EXISTS {}".format(historical_writer.database))
    sink.flush(feature_set, feature_set_df, client)

    # get historical results
    historical_result_df = client.read(
        s3config.format_,
        path=s3config.get_path_with_partitions(feature_set.name, feature_set_df),
    )

    # get online results
    online_result_df = client.read(
        online_config.format_, **online_config.get_options(feature_set.name)
    )

    # assert historical results
    assert sorted(feature_set_df.select(*columns_sort).collect()) == sorted(
        historical_result_df.select(*columns_sort).collect()
    )

    # assert online results
    assert sorted(target_latest_df.select(*columns_sort).collect()) == sorted(
        online_result_df.select(*columns_sort).collect()
    )

    # tear down
    shutil.rmtree("test_folder")
def test_read_invalid_params(self, format: Optional[str], path: Any) -> None:
    # arrange
    spark_client = SparkClient()

    # act and assert
    with pytest.raises(ValueError):
        spark_client.read(format=format, path=path)  # type: ignore
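# A hedged sketch of how the invalid-parameter test above could be driven with
# pytest.mark.parametrize; the specific (format, path) combinations are
# illustrative assumptions, not the repository's actual parameters.
@pytest.mark.parametrize(
    "format, path",
    [
        (None, "path/to/file"),  # missing format
        ("parquet", 123),  # path is not a string or list of strings
    ],
)
def test_read_invalid_params_example(format, path):
    spark_client = SparkClient()

    with pytest.raises(ValueError):
        spark_client.read(format=format, path=path)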
def test_read_invalid_params(
    self, format: Optional[str], options: Union[Dict[str, Any], str]
) -> None:
    # arrange
    spark_client = SparkClient()

    # act and assert
    with pytest.raises(ValueError):
        spark_client.read(format, options)  # type: ignore
def validate(
    self, feature_set: FeatureSet, dataframe: DataFrame, spark_client: SparkClient
) -> None:
    """Calculate dataframe rows to validate data into Feature Store.

    Args:
        feature_set: object processed with feature set information.
        dataframe: spark dataframe containing data from a feature set.
        spark_client: client for spark connections with external services.

    Raises:
        AssertionError: if the count of written data doesn't match the count in
            the current feature set dataframe.

    """
    table_name = (
        os.path.join("historical", feature_set.entity, feature_set.name)
        if self.interval_mode and not self.debug_mode
        else (
            f"{self.database}.{feature_set.name}"
            if not self.debug_mode
            else f"historical_feature_store__{feature_set.name}"
        )
    )

    written_count = (
        spark_client.read(
            self.db_config.format_,
            path=self.db_config.get_path_with_partitions(
                table_name, self._create_partitions(dataframe)
            ),
        ).count()
        if self.interval_mode and not self.debug_mode
        else spark_client.read_table(table_name).count()
    )

    dataframe_count = dataframe.count()
    self._assert_validation_count(table_name, written_count, dataframe_count)
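# A minimal, hedged usage sketch (not taken from the repository): after a
# feature set has been written, validate() compares the row count persisted in
# the Historical Feature Store against the source dataframe and raises
# AssertionError on mismatch. `feature_set`, `feature_set_df` and `client` are
# assumed to exist as in the integration test above; the config is a stand-in.
config = Mock()
config.format_ = "parquet"
config.get_path_with_partitions = Mock(
    return_value="test_folder/historical/entity/feature_set"
)
writer = HistoricalFeatureStoreWriter(db_config=config, interval_mode=True)

# raises AssertionError if the written count differs from feature_set_df.count()
writer.validate(feature_set, feature_set_df, client)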
def consume(self, client: SparkClient) -> DataFrame:
    """Extract data from a kafka topic.

    When in stream mode, it will get all the new data arriving at the topic in
    a streaming dataframe. When not in stream mode, it will get all data
    available in the kafka topic.

    Args:
        client: client responsible for connecting to Spark session.

    Returns:
        Dataframe with data from topic.

    """
    # read using client and cast key and value columns from binary to string
    raw_df = (
        client.read(format="kafka", options=self.options, stream=self.stream)
        .withColumn("key", col("key").cast("string"))
        .withColumn("value", col("value").cast("string"))
    )

    # apply schema defined in self.value_schema
    return self._struct_df(raw_df)
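# Standalone PySpark illustration (not this library's API) of the cast done
# inside consume(): a raw Kafka dataframe exposes binary `key` and `value`
# columns, which are cast to string before the value schema is applied. The
# bootstrap servers and topic name below are placeholder assumptions.
from pyspark.sql import SparkSession
from pyspark.sql.functions import col

spark = SparkSession.builder.getOrCreate()

raw_df = (
    spark.readStream.format("kafka")
    .option("kafka.bootstrap.servers", "localhost:9092")
    .option("subscribe", "feature_set_topic")
    .load()
    .withColumn("key", col("key").cast("string"))
    .withColumn("value", col("value").cast("string"))
)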
def test_read(
    self,
    format: str,
    stream: bool,
    schema: Optional[StructType],
    path: Any,
    options: Any,
    target_df: DataFrame,
    mocked_spark_read: Mock,
) -> None:
    # arrange
    spark_client = SparkClient()
    mocked_spark_read.load.return_value = target_df
    spark_client._session = mocked_spark_read

    # act
    result_df = spark_client.read(
        format=format, schema=schema, stream=stream, path=path, **options
    )

    # assert
    mocked_spark_read.format.assert_called_once_with(format)
    mocked_spark_read.load.assert_called_once_with(path=path, **options)
    assert target_df.collect() == result_df.collect()
def consume(self, client: SparkClient) -> DataFrame:
    """Extract data from files stored in defined path.

    Tries to auto-infer the schema when in stream mode and no schema is
    manually defined.

    Args:
        client: client responsible for connecting to Spark session.

    Returns:
        Dataframe with all the files data.

    """
    # when streaming without an explicit schema, infer it with a batch read
    schema = (
        client.read(
            format=self.format,
            options=self.options,
        ).schema
        if (self.stream and not self.schema)
        else self.schema
    )

    return client.read(
        format=self.format,
        options=self.options,
        schema=schema,
        stream=self.stream,
    )
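# Plain PySpark illustration (not this library's API) of the auto-infer
# strategy used above: run a cheap batch read only to grab the schema, then
# reuse it for the streaming read, since Spark Structured Streaming requires an
# explicit schema for file sources. Path and format are placeholder assumptions.
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

inferred_schema = spark.read.format("json").load("data/events/").schema

stream_df = (
    spark.readStream.format("json")
    .schema(inferred_schema)
    .load("data/events/")
)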