def test_read_table_invalid_params(self, database, table):
    """read_table must raise ValueError for invalid table/database params."""
    # arrange
    client = SparkClient()

    # act & assert: invalid arguments are rejected up front
    with pytest.raises(ValueError):
        client.read_table(table, database)
def test_sink(input_dataframe, feature_set):
    """End-to-end Sink test: flush a feature set through a historical and an
    online writer backed by local parquet folders, then read both back and
    compare against the expected dataframes.
    """
    # arrange: build the feature-set dataframe and the expected "latest" view
    client = SparkClient()
    feature_set_df = feature_set.construct(input_dataframe, client)
    target_latest_df = OnlineFeatureStoreWriter.filter_latest(
        feature_set_df, id_columns=[key.name for key in feature_set.keys])
    # fixed column order so collected rows compare deterministically
    columns_sort = feature_set_df.schema.fieldNames()

    # setup historical writer (S3 config mocked to write local parquet)
    s3config = Mock()
    s3config.get_options = Mock(
        return_value={
            "mode": "overwrite",
            "format_": "parquet",
            "path": "test_folder/historical/entity/feature_set",
        })
    historical_writer = HistoricalFeatureStoreWriter(db_config=s3config)

    # setup online writer
    # TODO: Change for CassandraConfig when Cassandra for test is ready
    online_config = Mock()
    online_config.mode = "overwrite"
    online_config.format_ = "parquet"
    online_config.get_options = Mock(
        return_value={"path": "test_folder/online/entity/feature_set"})
    online_writer = OnlineFeatureStoreWriter(db_config=online_config)

    writers = [historical_writer, online_writer]
    sink = Sink(writers)

    # act: the historical writer saves to a metastore table, so the database
    # must exist before flushing
    client.sql("CREATE DATABASE IF NOT EXISTS {}".format(
        historical_writer.database))
    sink.flush(feature_set, feature_set_df, client)

    # get historical results
    historical_result_df = client.read_table(feature_set.name,
                                             historical_writer.database)

    # get online results
    online_result_df = client.read(online_config.format_,
                                   options=online_config.get_options(
                                       feature_set.name))

    # assert
    # assert historical results: full feature-set dataframe round-trips
    assert sorted(feature_set_df.select(*columns_sort).collect()) == sorted(
        historical_result_df.select(*columns_sort).collect())

    # assert online results: only the latest row per key is persisted
    assert sorted(target_latest_df.select(*columns_sort).collect()) == sorted(
        online_result_df.select(*columns_sort).collect())

    # tear down: remove the parquet folders written by both writers
    shutil.rmtree("test_folder")
def consume(self, client: SparkClient) -> DataFrame:
    """Read this reader's configured table from the Spark metastore.

    Args:
        client: client responsible for connecting to Spark session.

    Returns:
        Dataframe with all the data from the table.

    """
    table, database = self.table, self.database
    return client.read_table(table, database)
def test_read_table(self, target_df, mocked_spark_read, database, table, target_table_name):
    """read_table should delegate to session.table with the qualified name."""
    # arrange: inject a mocked session that returns the target dataframe
    client = SparkClient()
    mocked_spark_read.table.return_value = target_df
    client._session = mocked_spark_read

    # act
    result_df = client.read_table(table, database)

    # assert: exactly one lookup, against the fully-qualified table name,
    # and the session's dataframe is returned untouched
    mocked_spark_read.table.assert_called_once_with(target_table_name)
    assert target_df == result_df
def validate(
    self, feature_set: FeatureSet, dataframe: DataFrame, spark_client: SparkClient
):
    """Compare written row count against the feature-set dataframe's count.

    Args:
        feature_set: object processed with feature_set informations.
        dataframe: spark dataframe containing data from a feature set.
        spark_client: client for spark connections with external services.

    Raises:
        AssertionError: if count of written data doesn't match count in
            current feature set dataframe.

    """
    # in debug mode data lands in a temp view instead of the real table
    if self.debug_mode:
        table_name = f"historical_feature_store__{feature_set.name}"
    else:
        table_name = f"{self.database}.{feature_set.name}"

    written_count = spark_client.read_table(table_name).count()
    dataframe_count = dataframe.count()

    self._assert_validation_count(table_name, written_count, dataframe_count)