def test_source_raise(self):
    with pytest.raises(ValueError, match="source must be a Source instance"):
        FeatureSetPipeline(
            spark_client=SparkClient(),
            source=Mock(
                spark_client=SparkClient(),
                readers=[
                    TableReader(id="source_a", database="db", table="table"),
                ],
                query="select * from source_a",
            ),
            feature_set=Mock(
                spec=FeatureSet,
                name="feature_set",
                entity="entity",
                description="description",
                keys=[
                    KeyFeature(
                        name="user_id",
                        description="The user's Main ID or device ID",
                        dtype=DataType.INTEGER,
                    )
                ],
                timestamp=TimestampFeature(from_column="ts"),
                features=[
                    Feature(
                        name="listing_page_viewed__rent_per_month",
                        description="Average of something.",
                        transformation=SparkFunctionTransform(
                            functions=[
                                Function(functions.avg, DataType.FLOAT),
                                Function(functions.stddev_pop, DataType.FLOAT),
                            ],
                        ).with_window(
                            partition_by="user_id",
                            order_by=TIMESTAMP_COLUMN,
                            window_definition=["7 days", "2 weeks"],
                            mode="fixed_windows",
                        ),
                    ),
                ],
            ),
            sink=Mock(
                spec=Sink,
                writers=[HistoricalFeatureStoreWriter(db_config=None)],
            ),
        )
def test_conn(self):
    # arrange
    spark_client = SparkClient()

    # act
    start_conn = spark_client._session

    # assert
    assert start_conn is None
def test_run_with_repartition(self, spark_session):
    test_pipeline = FeatureSetPipeline(
        spark_client=SparkClient(),
        source=Mock(
            spec=Source,
            readers=[TableReader(id="source_a", database="db", table="table")],
            query="select * from source_a",
        ),
        feature_set=Mock(
            spec=FeatureSet,
            name="feature_set",
            entity="entity",
            description="description",
            keys=[
                KeyFeature(
                    name="user_id",
                    description="The user's Main ID or device ID",
                    dtype=DataType.INTEGER,
                )
            ],
            timestamp=TimestampFeature(from_column="ts"),
            features=[
                Feature(
                    name="listing_page_viewed__rent_per_month",
                    description="Average of something.",
                    transformation=SparkFunctionTransform(
                        functions=[
                            Function(functions.avg, DataType.FLOAT),
                            Function(functions.stddev_pop, DataType.FLOAT),
                        ],
                    ).with_window(
                        partition_by="user_id",
                        order_by=TIMESTAMP_COLUMN,
                        window_definition=["7 days", "2 weeks"],
                        mode="fixed_windows",
                    ),
                ),
            ],
        ),
        sink=Mock(
            spec=Sink,
            writers=[HistoricalFeatureStoreWriter(db_config=None)],
        ),
    )

    # feature_set needs to return a real df for streaming validation
    sample_df = spark_session.createDataFrame([{"a": "x", "b": "y", "c": "3"}])
    test_pipeline.feature_set.construct.return_value = sample_df

    test_pipeline.run(partition_by=["id"])

    test_pipeline.source.construct.assert_called_once()
    test_pipeline.feature_set.construct.assert_called_once()
    test_pipeline.sink.flush.assert_called_once()
    test_pipeline.sink.validate.assert_called_once()
def consume(self, client: SparkClient) -> DataFrame:
    """Extract data from a table in Spark metastore.

    Args:
        client: client responsible for connecting to Spark session.

    Returns:
        Dataframe with all the data from the table.

    """
    return client.read_table(self.table, self.database)
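# Usage sketch (not part of the source): consuming a metastore table through a
# Source, mirroring the pattern exercised in the test_source integration test in
# this section. The database/table names are placeholders taken from the tests,
# not real tables; each reader is expected to expose its data under its id so the
# query can reference it.
def _example_table_reader_usage() -> DataFrame:
    spark_client = SparkClient()
    source = Source(
        readers=[TableReader(id="source_a", database="db", table="table")],
        query="select * from source_a",
    )
    return source.construct(client=spark_client)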
def test_construct_rolling_windows_with_end_date(
    self,
    feature_set_dataframe,
    rolling_windows_output_feature_set_dataframe_base_date,
):
    # given
    spark_client = SparkClient()

    # arrange
    feature_set = AggregatedFeatureSet(
        name="feature_set",
        entity="entity",
        description="description",
        features=[
            Feature(
                name="feature1",
                description="test",
                transformation=AggregatedTransform(
                    functions=[
                        Function(F.avg, DataType.DOUBLE),
                        Function(F.stddev_pop, DataType.DOUBLE),
                    ],
                ),
            ),
            Feature(
                name="feature2",
                description="test",
                transformation=AggregatedTransform(
                    functions=[
                        Function(F.avg, DataType.DOUBLE),
                        Function(F.stddev_pop, DataType.DOUBLE),
                    ],
                ),
            ),
        ],
        keys=[
            KeyFeature(
                name="id",
                description="The user's Main ID or device ID",
                dtype=DataType.INTEGER,
            )
        ],
        timestamp=TimestampFeature(),
    ).with_windows(definitions=["1 day", "1 week"])

    # act
    output_df = feature_set.construct(
        feature_set_dataframe, client=spark_client, end_date="2016-04-18"
    ).orderBy("timestamp")
    target_df = rolling_windows_output_feature_set_dataframe_base_date.orderBy(
        feature_set.timestamp_column
    ).select(feature_set.columns)

    # assert
    assert_dataframe_equality(output_df, target_df)
def test_write_table(
    self, format, mode, database, table_name, path, mocked_spark_write
):
    # given
    name = "{}.{}".format(database, table_name)

    # when
    SparkClient.write_table(
        dataframe=mocked_spark_write,
        database=database,
        table_name=table_name,
        format_=format,
        mode=mode,
        path=path,
    )

    # then
    mocked_spark_write.saveAsTable.assert_called_with(
        mode=mode, format=format, partitionBy=None, name=name, path=path
    )
def test_source(
    self, target_df_source, target_df_table_reader, spark_session,
):
    # given
    spark_client = SparkClient()

    table_reader_id = "a_test_source"
    table_reader_db = "db"
    table_reader_table = "table_test_source"
    create_temp_view(dataframe=target_df_table_reader, name=table_reader_id)
    create_db_and_table(
        spark=spark_session,
        table_reader_id=table_reader_id,
        table_reader_db=table_reader_db,
        table_reader_table=table_reader_table,
    )

    file_reader_id = "b_test_source"
    data_sample_path = INPUT_PATH + "/data.json"

    # when
    source = Source(
        readers=[
            TableReader(
                id=table_reader_id,
                database=table_reader_db,
                table=table_reader_table,
            ),
            FileReader(id=file_reader_id, path=data_sample_path, format="json"),
        ],
        query=f"select a.*, b.feature2 "  # noqa
        f"from {table_reader_id} a "  # noqa
        f"inner join {file_reader_id} b on a.id = b.id ",  # noqa
    )

    result_df = source.construct(client=spark_client)
    target_df = target_df_source

    # then
    assert (
        compare_dataframes(
            actual_df=result_df,
            expected_df=target_df,
            columns_sort=result_df.columns,
        )
        is True
    )
def test_write_stream(self, feature_set, has_checkpoint, monkeypatch):
    # arrange
    spark_client = SparkClient()
    spark_client.write_stream = Mock()
    spark_client.write_dataframe = Mock()
    spark_client.write_stream.return_value = Mock(spec=StreamingQuery)

    dataframe = Mock(spec=DataFrame)
    dataframe.isStreaming = True

    if has_checkpoint:
        monkeypatch.setenv("STREAM_CHECKPOINT_PATH", "test")

    cassandra_config = CassandraConfig(keyspace="feature_set")
    target_checkpoint_path = (
        "test/entity/feature_set"
        if cassandra_config.stream_checkpoint_path
        else None
    )

    writer = OnlineFeatureStoreWriter(cassandra_config)
    writer.filter_latest = Mock()

    # act
    stream_handler = writer.write(feature_set, dataframe, spark_client)

    # assert
    assert isinstance(stream_handler, StreamingQuery)
    spark_client.write_stream.assert_any_call(
        dataframe,
        processing_time=cassandra_config.stream_processing_time,
        output_mode=cassandra_config.stream_output_mode,
        checkpoint_path=target_checkpoint_path,
        format_=cassandra_config.format_,
        mode=cassandra_config.mode,
        **cassandra_config.get_options(table=feature_set.name),
    )
    writer.filter_latest.assert_not_called()
    spark_client.write_dataframe.assert_not_called()
def consume(self, client: SparkClient) -> DataFrame:
    """Extract data from files stored in a defined path.

    Try to auto-infer the schema when in stream mode and no schema was
    manually defined.

    Args:
        client: client responsible for connecting to Spark session.

    Returns:
        Dataframe with all the files data.

    """
    schema = (
        client.read(format=self.format, options=self.options).schema
        if (self.stream and not self.schema)
        else self.schema
    )

    return client.read(
        format=self.format,
        options=self.options,
        schema=schema,
        stream=self.stream,
    )
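# Usage sketch (not part of the source): a streaming FileReader with no explicit
# schema, which exercises the branch above where consume() first does a batch read
# only to infer the schema. The constructor arguments (path, format, stream) are
# assumptions based on the attributes referenced in consume(); check the real
# FileReader signature before relying on them.
def _example_streaming_file_reader(spark_client: SparkClient) -> DataFrame:
    file_reader = FileReader(
        id="events",
        path="data/events/",  # hypothetical landing folder with JSON files
        format="json",
        stream=True,
    )
    # Returns a streaming dataframe whose schema was inferred from a batch read.
    return file_reader.consume(spark_client)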
def test_sink(input_dataframe, feature_set):
    # arrange
    client = SparkClient()
    feature_set_df = feature_set.construct(input_dataframe, client)
    target_latest_df = OnlineFeatureStoreWriter.filter_latest(
        feature_set_df, id_columns=[key.name for key in feature_set.keys]
    )
    columns_sort = feature_set_df.schema.fieldNames()

    # setup historical writer
    s3config = Mock()
    s3config.get_options = Mock(
        return_value={
            "mode": "overwrite",
            "format_": "parquet",
            "path": "test_folder/historical/entity/feature_set",
        }
    )
    historical_writer = HistoricalFeatureStoreWriter(db_config=s3config)

    # setup online writer
    # TODO: Change for CassandraConfig when Cassandra for test is ready
    online_config = Mock()
    online_config.mode = "overwrite"
    online_config.format_ = "parquet"
    online_config.get_options = Mock(
        return_value={"path": "test_folder/online/entity/feature_set"}
    )
    online_writer = OnlineFeatureStoreWriter(db_config=online_config)

    writers = [historical_writer, online_writer]
    sink = Sink(writers)

    # act
    client.sql(
        "CREATE DATABASE IF NOT EXISTS {}".format(historical_writer.database)
    )
    sink.flush(feature_set, feature_set_df, client)

    # get historical results
    historical_result_df = client.read_table(
        feature_set.name, historical_writer.database
    )

    # get online results
    online_result_df = client.read(
        online_config.format_, options=online_config.get_options(feature_set.name)
    )

    # assert
    # assert historical results
    assert sorted(feature_set_df.select(*columns_sort).collect()) == sorted(
        historical_result_df.select(*columns_sort).collect()
    )

    # assert online results
    assert sorted(target_latest_df.select(*columns_sort).collect()) == sorted(
        online_result_df.select(*columns_sort).collect()
    )

    # tear down
    shutil.rmtree("test_folder")
def test_construct_without_window(
    self, feature_set_dataframe, target_df_without_window,
):
    # given
    spark_client = SparkClient()

    # arrange
    feature_set = AggregatedFeatureSet(
        name="feature_set",
        entity="entity",
        description="description",
        features=[
            Feature(
                name="feature1",
                description="test",
                dtype=DataType.DOUBLE,
                transformation=AggregatedTransform(
                    functions=[Function(F.avg, DataType.DOUBLE)]
                ),
            ),
            Feature(
                name="feature2",
                description="test",
                dtype=DataType.FLOAT,
                transformation=AggregatedTransform(
                    functions=[Function(F.count, DataType.BIGINT)]
                ),
            ),
        ],
        keys=[
            KeyFeature(
                name="id",
                description="The user's Main ID or device ID",
                dtype=DataType.INTEGER,
            )
        ],
        timestamp=TimestampFeature(from_column="fixed_ts"),
    )

    # act
    output_df = feature_set.construct(feature_set_dataframe, client=spark_client)

    # assert
    assert_dataframe_equality(output_df, target_df_without_window)
def test_construct(self, mocker, target_df):
    # given
    spark_client = SparkClient()

    reader_id = "a_source"
    reader = mocker.stub(reader_id)
    reader.build = mocker.stub("build")
    reader.build.side_effect = target_df.createOrReplaceTempView(reader_id)

    # when
    source_selector = Source(
        readers=[reader], query=f"select * from {reader_id}",  # noqa
    )

    result_df = source_selector.construct(spark_client)

    assert result_df.collect() == target_df.collect()
def test_feature_without_datatype(self, key_id, timestamp_c, dataframe):
    spark_client = SparkClient()

    with pytest.raises(ValueError):
        FeatureSet(
            name="name",
            entity="entity",
            description="description",
            features=[
                Feature(
                    name="feature1",
                    description="test",
                    transformation=SQLExpressionTransform(
                        expression="feature1 + a"
                    ),
                ),
            ],
            keys=[key_id],
            timestamp=timestamp_c,
        ).construct(dataframe, spark_client)
def test_feature_set_with_invalid_feature(self, key_id, timestamp_c, dataframe):
    spark_client = SparkClient()

    with pytest.raises(ValueError):
        FeatureSet(
            name="name",
            entity="entity",
            description="description",
            features=[
                Feature(
                    name="feature1",
                    description="test",
                    transformation=AggregatedTransform(
                        functions=[Function(F.avg, DataType.FLOAT)]
                    ),
                ),
            ],
            keys=[key_id],
            timestamp=timestamp_c,
        ).construct(dataframe, spark_client)
def test_construct_with_pivot(
    self, feature_set_df_pivot, target_df_pivot_agg,
):
    # given
    spark_client = SparkClient()

    # arrange
    feature_set = AggregatedFeatureSet(
        name="feature_set",
        entity="entity",
        description="description",
        features=[
            Feature(
                name="feature",
                description="unit test",
                transformation=AggregatedTransform(
                    functions=[
                        Function(F.avg, DataType.FLOAT),
                        Function(F.stddev_pop, DataType.DOUBLE),
                    ],
                ),
                from_column="feature1",
            )
        ],
        keys=[
            KeyFeature(
                name="id",
                description="The user's Main ID or device ID",
                dtype=DataType.INTEGER,
            )
        ],
        timestamp=TimestampFeature(from_column="fixed_ts"),
    ).with_pivot("pivot_col", ["S", "N"])

    # act
    output_df = feature_set.construct(feature_set_df_pivot, client=spark_client)

    # assert
    assert_dataframe_equality(output_df, target_df_pivot_agg)
def test_json_file_with_schema(self):
    # given
    spark_client = SparkClient()
    schema_json = StructType(
        [
            StructField("A", StringType()),
            StructField("B", DoubleType()),
            StructField("C", StringType()),
        ]
    )

    file = "tests/unit/butterfree/core/extract/readers/file-reader-test.json"

    # when
    file_reader = FileReader(id="id", path=file, format="json", schema=schema_json)
    df = file_reader.consume(spark_client)

    # assert
    assert schema_json == df.schema
def _write_stream(
    self,
    feature_set: FeatureSet,
    dataframe: DataFrame,
    spark_client: SparkClient,
):
    """Writes the dataframe in streaming mode."""
    # TODO: Refactor this logic using the Sink returning the Query Handler
    for table in [feature_set.name, feature_set.entity]:
        checkpoint_path = (
            os.path.join(
                self.db_config.stream_checkpoint_path,
                feature_set.entity,
                f"{feature_set.name}__on_entity"
                if table == feature_set.entity
                else table,
            )
            if self.db_config.stream_checkpoint_path
            else None
        )
        streaming_handler = spark_client.write_stream(
            dataframe,
            processing_time=self.db_config.stream_processing_time,
            output_mode=self.db_config.stream_output_mode,
            checkpoint_path=checkpoint_path,
            format_=self.db_config.format_,
            mode=self.db_config.mode,
            **self.db_config.get_options(table=table),
        )
    return streaming_handler
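# Illustration (not part of the source): the checkpoint paths the loop above builds
# for a feature set named "feature_set" on entity "entity" when
# stream_checkpoint_path is "test" (the same values exercised by test_write_stream).
# The entity table gets a "__on_entity" suffix so the two streaming queries do not
# share a checkpoint location.
def _example_checkpoint_paths():
    import os

    root, entity, name = "test", "entity", "feature_set"
    return [
        os.path.join(root, entity, name),  # "test/entity/feature_set"
        os.path.join(root, entity, f"{name}__on_entity"),  # "test/entity/feature_set__on_entity"
    ]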
def test_flush_with_invalid_df(self, not_feature_set_dataframe, mocker):
    # given
    spark_client = SparkClient()
    writer = [
        HistoricalFeatureStoreWriter(),
        OnlineFeatureStoreWriter(),
    ]
    feature_set = mocker.stub("feature_set")
    feature_set.entity = "house"
    feature_set.name = "test"

    # when
    sink = Sink(writers=writer)

    # then
    with pytest.raises(ValueError):
        sink.flush(
            dataframe=not_feature_set_dataframe,
            feature_set=feature_set,
            spark_client=spark_client,
        )
def test_agg_feature_set_with_window(
    self, key_id, timestamp_c, dataframe, rolling_windows_agg_dataframe
):
    spark_client = SparkClient()

    fs = AggregatedFeatureSet(
        name="name",
        entity="entity",
        description="description",
        features=[
            Feature(
                name="feature1",
                description="unit test",
                transformation=AggregatedTransform(
                    functions=[Function(functions.avg, DataType.FLOAT)]
                ),
            ),
            Feature(
                name="feature2",
                description="unit test",
                transformation=AggregatedTransform(
                    functions=[Function(functions.avg, DataType.FLOAT)]
                ),
            ),
        ],
        keys=[key_id],
        timestamp=timestamp_c,
    ).with_windows(definitions=["1 week"])

    # raises without an end date
    with pytest.raises(ValueError):
        _ = fs.construct(dataframe, spark_client)

    # filters with a date smaller than the mocked max
    output_df = fs.construct(dataframe, spark_client, end_date="2016-04-17")
    assert output_df.count() < rolling_windows_agg_dataframe.count()

    output_df = fs.construct(dataframe, spark_client, end_date="2016-05-01")
    assert_dataframe_equality(output_df, rolling_windows_agg_dataframe)
def test_write_in_debug_mode(
    self,
    feature_set_dataframe,
    historical_feature_set_dataframe,
    feature_set,
    spark_session,
):
    # given
    spark_client = SparkClient()
    writer = HistoricalFeatureStoreWriter(debug_mode=True)

    # when
    writer.write(
        feature_set=feature_set,
        dataframe=feature_set_dataframe,
        spark_client=spark_client,
    )
    result_df = spark_session.table(
        f"historical_feature_store__{feature_set.name}"
    )

    # then
    assert_dataframe_equality(historical_feature_set_dataframe, result_df)
def consume(self, client: SparkClient) -> DataFrame:
    """Extract data from a Kafka topic.

    When in stream mode, it will get all the new data arriving at the topic
    in a streaming dataframe. When not in stream mode, it will get all the
    data available in the Kafka topic.

    Args:
        client: client responsible for connecting to Spark session.

    Returns:
        Dataframe with data from the topic.

    """
    # read using the client and cast the key and value columns from binary to string
    raw_df = (
        client.read(format="kafka", options=self.options, stream=self.stream)
        .withColumn("key", col("key").cast("string"))
        .withColumn("value", col("value").cast("string"))
    )

    # apply the schema defined in self.value_schema
    return self._struct_df(raw_df)
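# Reference sketch (not part of the source): the plain Spark equivalent of the read
# performed above in stream mode, using only standard Spark APIs. The Kafka options
# (bootstrap servers, topic) are placeholders here; the reader builds its own
# options from its configuration.
def _example_raw_kafka_read(spark_session) -> DataFrame:
    from pyspark.sql.functions import col

    raw_df = (
        spark_session.readStream.format("kafka")
        .option("kafka.bootstrap.servers", "localhost:9092")  # placeholder
        .option("subscribe", "some_topic")  # placeholder
        .load()
        .withColumn("key", col("key").cast("string"))
        .withColumn("value", col("value").cast("string"))
    )
    # The reader still needs to parse "value" against its value_schema
    # (what _struct_df does) before the data is usable downstream.
    return raw_df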
def validate(
    self, feature_set: FeatureSet, dataframe: DataFrame, spark_client: SparkClient
):
    """Count written rows to validate the data loaded into the Feature Store.

    Args:
        feature_set: object processed with feature set information.
        dataframe: spark dataframe containing data from a feature set.
        spark_client: client for spark connections with external services.

    Raises:
        AssertionError: if the count of written data doesn't match the count
            of the current feature set dataframe.

    """
    table_name = (
        f"{self.database}.{feature_set.name}"
        if not self.debug_mode
        else f"historical_feature_store__{feature_set.name}"
    )
    written_count = spark_client.read_table(table_name).count()
    dataframe_count = dataframe.count()
    self._assert_validation_count(table_name, written_count, dataframe_count)
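# Minimal sketch (not part of the source) of the kind of comparison
# _assert_validation_count is expected to perform; the real helper may word the
# error differently.
def _example_assert_validation_count(
    table_name: str, written_count: int, dataframe_count: int
):
    assert written_count == dataframe_count, (
        f"Data written to {table_name} is not what was expected: "
        f"written {written_count} rows, feature set dataframe has "
        f"{dataframe_count} rows."
    )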
def test_construct_rolling_windows_without_end_date(
    self, feature_set_dataframe, rolling_windows_output_feature_set_dataframe
):
    # given
    spark_client = SparkClient()

    # arrange
    feature_set = AggregatedFeatureSet(
        name="feature_set",
        entity="entity",
        description="description",
        features=[
            Feature(
                name="feature1",
                description="test",
                transformation=AggregatedTransform(
                    functions=[
                        Function(F.avg, DataType.DOUBLE),
                        Function(F.stddev_pop, DataType.DOUBLE),
                    ],
                ),
            ),
        ],
        keys=[
            KeyFeature(
                name="id",
                description="The user's Main ID or device ID",
                dtype=DataType.INTEGER,
            )
        ],
        timestamp=TimestampFeature(),
    ).with_windows(definitions=["1 day", "1 week"])

    # act & assert
    with pytest.raises(ValueError):
        _ = feature_set.construct(feature_set_dataframe, client=spark_client)
def test_validate_false(self, feature_set_dataframe, mocker):
    # given
    spark_client = SparkClient()
    writer = [
        HistoricalFeatureStoreWriter(),
        OnlineFeatureStoreWriter(),
    ]

    for w in writer:
        w.validate = mocker.stub("validate")
        w.validate.side_effect = AssertionError("test")

    feature_set = mocker.stub("feature_set")

    # when
    sink = Sink(writers=writer)

    # then
    with pytest.raises(RuntimeError):
        sink.validate(
            dataframe=feature_set_dataframe,
            feature_set=feature_set,
            spark_client=spark_client,
        )
def test_validate(self, feature_set_dataframe, mocker):
    # given
    spark_client = SparkClient()
    writer = [
        HistoricalFeatureStoreWriter(),
        OnlineFeatureStoreWriter(),
    ]

    for w in writer:
        w.validate = mocker.stub("validate")

    feature_set = mocker.stub("feature_set")

    # when
    sink = Sink(writers=writer)
    sink.validate(
        dataframe=feature_set_dataframe,
        feature_set=feature_set,
        spark_client=spark_client,
    )

    # then
    for w in writer:
        w.validate.assert_called_once()
def test_write_dataframe(self, format, mode, mocked_spark_write):
    SparkClient.write_dataframe(mocked_spark_write, format, mode)

    mocked_spark_write.save.assert_called_with(format=format, mode=mode)
def test_feature_transform_with_data_type_array(self, spark_context, spark_session):
    # arrange
    input_data = [
        {"id": 1, "timestamp": "2020-04-22T00:00:00+00:00", "feature": 10},
        {"id": 1, "timestamp": "2020-04-22T00:00:00+00:00", "feature": 20},
        {"id": 1, "timestamp": "2020-04-22T00:00:00+00:00", "feature": 30},
        {"id": 2, "timestamp": "2020-04-22T00:00:00+00:00", "feature": 10},
    ]
    target_data = [
        {
            "id": 1,
            "timestamp": "2020-04-22T00:00:00+00:00",
            "feature__collect_set": [30.0, 20.0, 10.0],
        },
        {
            "id": 2,
            "timestamp": "2020-04-22T00:00:00+00:00",
            "feature__collect_set": [10.0],
        },
    ]
    input_df = create_df_from_collection(
        input_data, spark_context, spark_session
    ).withColumn("timestamp", functions.to_timestamp(functions.col("timestamp")))
    target_df = create_df_from_collection(
        target_data, spark_context, spark_session
    ).withColumn("timestamp", functions.to_timestamp(functions.col("timestamp")))

    fs = AggregatedFeatureSet(
        name="name",
        entity="entity",
        description="description",
        keys=[KeyFeature(name="id", description="test", dtype=DataType.INTEGER)],
        timestamp=TimestampFeature(),
        features=[
            Feature(
                name="feature",
                description="aggregations with ",
                dtype=DataType.BIGINT,
                transformation=AggregatedTransform(
                    functions=[
                        Function(functions.collect_set, DataType.ARRAY_FLOAT),
                    ],
                ),
                from_column="feature",
            ),
        ],
    )

    # act
    output_df = fs.construct(input_df, SparkClient())

    # assert
    assert_dataframe_equality(target_df, output_df)
def test_feature_transform_with_filter_expression(self, spark_context, spark_session):
    # arrange
    input_data = [
        {
            "id": 1,
            "timestamp": "2020-04-22T00:00:00+00:00",
            "feature": 10,
            "type": "a",
        },
        {
            "id": 1,
            "timestamp": "2020-04-22T00:00:00+00:00",
            "feature": 20,
            "type": "a",
        },
        {
            "id": 1,
            "timestamp": "2020-04-22T00:00:00+00:00",
            "feature": 30,
            "type": "b",
        },
        {
            "id": 2,
            "timestamp": "2020-04-22T00:00:00+00:00",
            "feature": 10,
            "type": "a",
        },
    ]
    target_data = [
        {
            "id": 1,
            "timestamp": "2020-04-22T00:00:00+00:00",
            "feature_only_type_a__avg": 15.0,
            "feature_only_type_a__min": 10,
            "feature_only_type_a__max": 20,
        },
        {
            "id": 2,
            "timestamp": "2020-04-22T00:00:00+00:00",
            "feature_only_type_a__avg": 10.0,
            "feature_only_type_a__min": 10,
            "feature_only_type_a__max": 10,
        },
    ]
    input_df = create_df_from_collection(
        input_data, spark_context, spark_session
    ).withColumn("timestamp", functions.to_timestamp(functions.col("timestamp")))
    target_df = create_df_from_collection(
        target_data, spark_context, spark_session
    ).withColumn("timestamp", functions.to_timestamp(functions.col("timestamp")))

    fs = AggregatedFeatureSet(
        name="name",
        entity="entity",
        description="description",
        keys=[KeyFeature(name="id", description="test", dtype=DataType.INTEGER)],
        timestamp=TimestampFeature(),
        features=[
            Feature(
                name="feature_only_type_a",
                description="aggregations only when type = a",
                dtype=DataType.BIGINT,
                transformation=AggregatedTransform(
                    functions=[
                        Function(functions.avg, DataType.FLOAT),
                        Function(functions.min, DataType.FLOAT),
                        Function(functions.max, DataType.FLOAT),
                    ],
                    filter_expression="type = 'a'",
                ),
                from_column="feature",
            ),
        ],
    )

    # act
    output_df = fs.construct(input_df, SparkClient())

    # assert
    assert_dataframe_equality(target_df, output_df)
def _write_in_debug_mode(
    feature_set: FeatureSet, dataframe: DataFrame, spark_client: SparkClient
):
    """Creates a temporary table instead of writing to the real data source."""
    return spark_client.create_temporary_view(
        dataframe=dataframe, name=f"online_feature_store__{feature_set.name}"
    )
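# Usage sketch (not part of the source): after a debug-mode write, results can be
# inspected through the temporary view created above, mirroring how
# test_write_in_debug_mode reads "historical_feature_store__<name>" for the
# historical writer.
def _example_read_online_debug_view(spark_session, feature_set_name: str) -> DataFrame:
    return spark_session.table(f"online_feature_store__{feature_set_name}")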