Example #1
 def test_source_raise(self):
     with pytest.raises(ValueError,
                        match="source must be a Source instance"):
         FeatureSetPipeline(
             spark_client=SparkClient(),
             source=Mock(
                 spark_client=SparkClient(),
                 readers=[
                     TableReader(
                         id="source_a",
                         database="db",
                         table="table",
                     ),
                 ],
                 query="select * from source_a",
             ),
             feature_set=Mock(
                 spec=FeatureSet,
                 name="feature_set",
                 entity="entity",
                 description="description",
                 keys=[
                     KeyFeature(
                         name="user_id",
                         description="The user's Main ID or device ID",
                         dtype=DataType.INTEGER,
                     )
                 ],
                 timestamp=TimestampFeature(from_column="ts"),
                 features=[
                     Feature(
                         name="listing_page_viewed__rent_per_month",
                         description="Average of something.",
                         transformation=SparkFunctionTransform(functions=[
                             Function(functions.avg, DataType.FLOAT),
                             Function(functions.stddev_pop, DataType.FLOAT),
                         ], ).with_window(
                             partition_by="user_id",
                             order_by=TIMESTAMP_COLUMN,
                             window_definition=["7 days", "2 weeks"],
                             mode="fixed_windows",
                         ),
                     ),
                 ],
             ),
             sink=Mock(
                 spec=Sink,
                 writers=[HistoricalFeatureStoreWriter(db_config=None)],
             ),
         )
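The ValueError in this test comes from the pipeline rejecting a plain Mock that is not a Source. A minimal sketch of the kind of isinstance guard that produces this message, written as a standalone helper (an assumption, not the pipeline's actual setter; Source is the class the other examples import):

def _validate_source(source):
    # illustrative guard only; the real check lives inside FeatureSetPipeline
    if not isinstance(source, Source):
        raise ValueError("source must be a Source instance")
    return source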
Example #2
    def test_get_schema(self, target_df: DataFrame,
                        spark_session: SparkSession) -> None:
        # arrange
        spark_client = SparkClient()
        create_temp_view(dataframe=target_df, name="temp_view")
        create_db_and_table(
            spark=spark_session,
            database="test_db",
            table="test_table",
            view="temp_view",
        )

        expected_schema = [
            {
                "col_name": "col1",
                "data_type": "string"
            },
            {
                "col_name": "col2",
                "data_type": "bigint"
            },
        ]

        # act
        schema = spark_client.get_schema(table="test_table",
                                         database="test_db")

        # assert
        assert schema == expected_schema
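This test leans on two helpers from the test suite that are not shown; a plausible sketch of what they do (the names match the calls above, the bodies are assumptions):

def create_temp_view(dataframe, name):
    # register the dataframe as a temporary SQL view
    dataframe.createOrReplaceTempView(name)


def create_db_and_table(spark, database, table, view):
    # materialize the view as a managed table inside a dedicated database
    spark.sql(f"CREATE DATABASE IF NOT EXISTS {database}")
    spark.sql(f"DROP TABLE IF EXISTS {database}.{table}")
    spark.sql(f"CREATE TABLE {database}.{table} AS SELECT * FROM {view}")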
Example #3
    def test_read(
        self,
        format: str,
        stream: bool,
        schema: Optional[StructType],
        path: Any,
        options: Any,
        target_df: DataFrame,
        mocked_spark_read: Mock,
    ) -> None:
        # arrange
        spark_client = SparkClient()
        mocked_spark_read.load.return_value = target_df
        spark_client._session = mocked_spark_read

        # act
        result_df = spark_client.read(format=format,
                                      schema=schema,
                                      stream=stream,
                                      path=path,
                                      **options)

        # assert
        mocked_spark_read.format.assert_called_once_with(format)
        mocked_spark_read.load.assert_called_once_with(path=path, **options)
        assert target_df.collect() == result_df.collect()
Example #4
    def test_add_table_partitions(self, mock_spark_sql: Mock):
        # arrange
        target_command = (f"ALTER TABLE `db`.`table` ADD IF NOT EXISTS "
                          f"PARTITION ( year = 2020, month = 8, day = 14 ) "
                          f"PARTITION ( year = 2020, month = 8, day = 15 ) "
                          f"PARTITION ( year = 2020, month = 8, day = 16 )")

        spark_client = SparkClient()
        spark_client._session = mock_spark_sql
        partitions = [
            {
                "year": 2020,
                "month": 8,
                "day": 14
            },
            {
                "year": 2020,
                "month": 8,
                "day": 15
            },
            {
                "year": 2020,
                "month": 8,
                "day": 16
            },
        ]

        # act
        spark_client.add_table_partitions(partitions, "table", "db")

        # assert
        mock_spark_sql.assert_called_once_with(target_command)
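For reference, the target_command asserted above can be produced from the partition dictionaries with a simple formatter; this sketch is illustrative and not the client's actual implementation:

def build_add_partitions_sql(partitions, table, database):
    # render each partition spec as "PARTITION ( key = value, ... )"
    specs = " ".join(
        "PARTITION ( {} )".format(
            ", ".join(f"{key} = {value}" for key, value in spec.items()))
        for spec in partitions
    )
    return f"ALTER TABLE `{database}`.`{table}` ADD IF NOT EXISTS {specs}"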
Example #5
    def test_write_interval_mode(
        self,
        feature_set_dataframe,
        historical_feature_set_dataframe,
        mocker,
        feature_set,
    ):
        # given
        spark_client = SparkClient()
        spark_client.write_table = mocker.stub("write_table")
        spark_client.conn.conf.set("spark.sql.sources.partitionOverwriteMode",
                                   "dynamic")
        writer = HistoricalFeatureStoreWriter(interval_mode=True)

        # when
        writer.write(
            feature_set=feature_set,
            dataframe=feature_set_dataframe,
            spark_client=spark_client,
        )
        result_df = spark_client.write_table.call_args[1]["dataframe"]

        # then
        assert_dataframe_equality(historical_feature_set_dataframe, result_df)

        assert writer.database == spark_client.write_table.call_args[1][
            "database"]
        assert feature_set.name == spark_client.write_table.call_args[1][
            "table_name"]
        assert (writer.PARTITION_BY == spark_client.write_table.call_args[1]
                ["partition_by"])
Example #6
    def test_csv_file_with_schema_and_header(self):
        # given
        spark_client = SparkClient()
        schema_csv = StructType(
            [
                StructField("A", LongType()),
                StructField("B", DoubleType()),
                StructField("C", StringType()),
            ]
        )

        file = "tests/unit/butterfree/extract/readers/file-reader-test.csv"

        # when
        file_reader = FileReader(
            id="id",
            path=file,
            format="csv",
            schema=schema_csv,
            format_options={"header": True},
        )
        df = file_reader.consume(spark_client)

        # assert
        assert schema_csv == df.schema
        assert df.columns == ["A", "B", "C"]
        # the header values must not leak into the first data row
        for value in range(3):
            assert df.first()[value] != ["A", "B", "C"][value]
Example #7
    def test_feature_transform_with_distinct_empty_subset(
            self, timestamp_c, feature_set_with_distinct_dataframe):
        spark_client = SparkClient()

        with pytest.raises(ValueError,
                           match="The distinct subset param can't be empty."):
            AggregatedFeatureSet(
                name="name",
                entity="entity",
                description="description",
                features=[
                    Feature(
                        name="feature",
                        description="test",
                        transformation=AggregatedTransform(functions=[
                            Function(functions.sum, DataType.INTEGER)
                        ]),
                    ),
                ],
                keys=[
                    KeyFeature(name="h3",
                               description="test",
                               dtype=DataType.STRING)
                ],
                timestamp=timestamp_c,
            ).with_windows(["3 days"]).with_distinct(
                subset=[],
                keep="first").construct(feature_set_with_distinct_dataframe,
                                        spark_client,
                                        end_date="2020-01-10")
Example #8
    def test_write_stream(self, mocked_stream_df: Mock) -> None:
        # arrange
        spark_client = SparkClient()

        processing_time = "0 seconds"
        output_mode = "update"
        checkpoint_path = "s3://path/to/checkpoint"

        # act
        stream_handler = spark_client.write_stream(
            mocked_stream_df,
            processing_time,
            output_mode,
            checkpoint_path,
            format_="parquet",
            mode="append",
        )

        # assert
        assert isinstance(stream_handler, StreamingQuery)
        mocked_stream_df.trigger.assert_called_with(processingTime=processing_time)
        mocked_stream_df.outputMode.assert_called_with(output_mode)
        mocked_stream_df.option.assert_called_with(
            "checkpointLocation", checkpoint_path
        )
        mocked_stream_df.foreachBatch.assert_called_once()
        mocked_stream_df.start.assert_called_once()
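The mocked_stream_df fixture is not shown; a hypothetical version that satisfies the fluent writeStream chain exercised above could look like this (an assumption, not the project's real fixture):

import pytest
from unittest.mock import Mock
from pyspark.sql.streaming import StreamingQuery


@pytest.fixture
def mocked_stream_df():
    mock = Mock()
    # every builder call returns the mock itself so chained calls resolve
    mock.writeStream = mock
    mock.trigger.return_value = mock
    mock.outputMode.return_value = mock
    mock.option.return_value = mock
    mock.format.return_value = mock
    mock.foreachBatch.return_value = mock
    mock.start.return_value = Mock(spec=StreamingQuery)
    return mock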
Example #9
    def test_feature_set_with_invalid_feature(self, key_id, timestamp_c,
                                              dataframe):
        spark_client = SparkClient()

        with pytest.raises(ValueError):
            AggregatedFeatureSet(
                name="name",
                entity="entity",
                description="description",
                features=[
                    Feature(
                        name="feature1",
                        description="test",
                        transformation=SparkFunctionTransform(functions=[
                            Function(functions.avg, DataType.FLOAT)
                        ], ).with_window(
                            partition_by="id",
                            mode="row_windows",
                            window_definition=["2 events"],
                        ),
                    ),
                ],
                keys=[key_id],
                timestamp=timestamp_c,
            ).construct(dataframe, spark_client)
Example #10
    def test_feature_transform_with_distinct(
        self,
        timestamp_c,
        feature_set_with_distinct_dataframe,
        target_with_distinct_dataframe,
    ):
        spark_client = SparkClient()

        fs = (AggregatedFeatureSet(
            name="name",
            entity="entity",
            description="description",
            features=[
                Feature(
                    name="feature",
                    description="test",
                    transformation=AggregatedTransform(
                        functions=[Function(functions.sum, DataType.INTEGER)]),
                ),
            ],
            keys=[
                KeyFeature(name="h3",
                           description="test",
                           dtype=DataType.STRING)
            ],
            timestamp=timestamp_c,
        ).with_windows(["3 days"]).with_distinct(subset=["id"], keep="last"))

        # assert
        output_df = fs.construct(feature_set_with_distinct_dataframe,
                                 spark_client,
                                 end_date="2020-01-10")
        assert_dataframe_equality(output_df, target_with_distinct_dataframe)
Example #11
    def test_h3_feature_set(self, h3_input_df, h3_target_df):
        spark_client = SparkClient()

        feature_set = AggregatedFeatureSet(
            name="h3_test",
            entity="h3geolocation",
            description="Test",
            keys=[
                KeyFeature(
                    name="h3_id",
                    description="The h3 hash ID",
                    dtype=DataType.DOUBLE,
                    transformation=H3HashTransform(
                        h3_resolutions=[6, 7, 8, 9, 10, 11, 12],
                        lat_column="lat",
                        lng_column="lng",
                    ).with_stack(),
                )
            ],
            timestamp=TimestampFeature(),
            features=[
                Feature(
                    name="house_id",
                    description="Count of house ids over a day.",
                    transformation=AggregatedTransform(
                        functions=[Function(F.count, DataType.BIGINT)]),
                ),
            ],
        ).with_windows(definitions=["1 day"])

        output_df = feature_set.construct(h3_input_df,
                                          client=spark_client,
                                          end_date="2016-04-14")

        assert_dataframe_equality(output_df, h3_target_df)
Example #12
    def test_write_in_debug_and_stream_mode(
        self, feature_set, spark_session,
    ):
        # arrange
        spark_client = SparkClient()

        mocked_stream_df = Mock()
        mocked_stream_df.isStreaming = True
        mocked_stream_df.writeStream = mocked_stream_df
        mocked_stream_df.format.return_value = mocked_stream_df
        mocked_stream_df.queryName.return_value = mocked_stream_df
        mocked_stream_df.start.return_value = Mock(spec=StreamingQuery)

        writer = OnlineFeatureStoreWriter(debug_mode=True)

        # act
        handler = writer.write(
            feature_set=feature_set,
            dataframe=mocked_stream_df,
            spark_client=spark_client,
        )

        # assert
        mocked_stream_df.format.assert_called_with("memory")
        mocked_stream_df.queryName.assert_called_with(
            f"online_feature_store__{feature_set.name}"
        )
        assert isinstance(handler, StreamingQuery)
Example #13
def test_sink(input_dataframe, feature_set):
    # arrange
    client = SparkClient()
    client.conn.conf.set("spark.sql.sources.partitionOverwriteMode", "dynamic")
    feature_set_df = feature_set.construct(input_dataframe, client)
    target_latest_df = OnlineFeatureStoreWriter.filter_latest(
        feature_set_df, id_columns=[key.name for key in feature_set.keys])
    columns_sort = feature_set_df.schema.fieldNames()

    # setup historical writer
    s3config = Mock()
    s3config.mode = "overwrite"
    s3config.format_ = "parquet"
    s3config.get_options = Mock(
        return_value={"path": "test_folder/historical/entity/feature_set"})
    s3config.get_path_with_partitions = Mock(
        return_value="test_folder/historical/entity/feature_set")

    historical_writer = HistoricalFeatureStoreWriter(db_config=s3config,
                                                     interval_mode=True)

    # setup online writer
    # TODO: Change for CassandraConfig when Cassandra for test is ready
    online_config = Mock()
    online_config.mode = "overwrite"
    online_config.format_ = "parquet"
    online_config.get_options = Mock(
        return_value={"path": "test_folder/online/entity/feature_set"})
    online_writer = OnlineFeatureStoreWriter(db_config=online_config)

    writers = [historical_writer, online_writer]
    sink = Sink(writers)

    # act
    client.sql("CREATE DATABASE IF NOT EXISTS {}".format(
        historical_writer.database))
    sink.flush(feature_set, feature_set_df, client)

    # get historical results
    historical_result_df = client.read(
        s3config.format_,
        path=s3config.get_path_with_partitions(feature_set.name,
                                               feature_set_df),
    )

    # get online results
    online_result_df = client.read(
        online_config.format_, **online_config.get_options(feature_set.name))

    # assert
    # assert historical results
    assert sorted(feature_set_df.select(*columns_sort).collect()) == sorted(
        historical_result_df.select(*columns_sort).collect())

    # assert online results
    assert sorted(target_latest_df.select(*columns_sort).collect()) == sorted(
        online_result_df.select(*columns_sort).collect())

    # tear down
    shutil.rmtree("test_folder")
Example #14
    def test_flush(self, feature_set_dataframe, mocker):
        # given
        spark_client = SparkClient()
        writer = [
            HistoricalFeatureStoreWriter(),
            OnlineFeatureStoreWriter(),
        ]

        for w in writer:
            w.write = mocker.stub("write")

        feature_set = mocker.stub("feature_set")
        feature_set.entity = "house"
        feature_set.name = "test"

        # when
        sink = Sink(writers=writer)
        sink.flush(
            dataframe=feature_set_dataframe,
            feature_set=feature_set,
            spark_client=spark_client,
        )

        # then
        for w in writer:
            w.write.assert_called_once()
Example #15
    def test_construct(
        self, feature_set_dataframe, fixed_windows_output_feature_set_dataframe
    ):
        # given

        spark_client = SparkClient()

        # arrange

        feature_set = FeatureSet(
            name="feature_set",
            entity="entity",
            description="description",
            features=[
                Feature(
                    name="feature1",
                    description="test",
                    transformation=SparkFunctionTransform(
                        functions=[
                            Function(F.avg, DataType.FLOAT),
                            Function(F.stddev_pop, DataType.FLOAT),
                        ]
                    ).with_window(
                        partition_by="id",
                        order_by=TIMESTAMP_COLUMN,
                        mode="fixed_windows",
                        window_definition=["2 minutes", "15 minutes"],
                    ),
                ),
                Feature(
                    name="divided_feature",
                    description="unit test",
                    dtype=DataType.FLOAT,
                    transformation=CustomTransform(
                        transformer=divide, column1="feature1", column2="feature2",
                    ),
                ),
            ],
            keys=[
                KeyFeature(
                    name="id",
                    description="The user's Main ID or device ID",
                    dtype=DataType.INTEGER,
                )
            ],
            timestamp=TimestampFeature(),
        )

        output_df = (
            feature_set.construct(feature_set_dataframe, client=spark_client)
            .orderBy(feature_set.timestamp_column)
            .select(feature_set.columns)
        )

        target_df = fixed_windows_output_feature_set_dataframe.orderBy(
            feature_set.timestamp_column
        ).select(feature_set.columns)

        # assert
        assert_dataframe_equality(output_df, target_df)
Example #16
    def test_write(
        self,
        feature_set_dataframe,
        historical_feature_set_dataframe,
        mocker,
        feature_set,
    ):
        # given
        spark_client = SparkClient()
        spark_client.write_table = mocker.stub("write_table")
        writer = HistoricalFeatureStoreWriter()

        # when
        writer.write(
            feature_set=feature_set,
            dataframe=feature_set_dataframe,
            spark_client=spark_client,
        )
        result_df = spark_client.write_table.call_args[1]["dataframe"]

        # then
        assert_dataframe_equality(historical_feature_set_dataframe, result_df)

        assert (writer.db_config.format_ ==
                spark_client.write_table.call_args[1]["format_"])
        assert writer.db_config.mode == spark_client.write_table.call_args[1][
            "mode"]
        assert (writer.PARTITION_BY == spark_client.write_table.call_args[1]
                ["partition_by"])
Example #17
    def test_write_in_debug_mode_with_interval_mode(
        self,
        feature_set_dataframe,
        historical_feature_set_dataframe,
        feature_set,
        spark_session,
        mocker,
    ):
        # given
        spark_client = SparkClient()
        spark_client.write_dataframe = mocker.stub("write_dataframe")
        spark_client.conn.conf.set("spark.sql.sources.partitionOverwriteMode",
                                   "dynamic")
        writer = HistoricalFeatureStoreWriter(debug_mode=True,
                                              interval_mode=True)

        # when
        writer.write(
            feature_set=feature_set,
            dataframe=feature_set_dataframe,
            spark_client=spark_client,
        )
        result_df = spark_session.table(
            f"historical_feature_store__{feature_set.name}")

        # then
        assert_dataframe_equality(historical_feature_set_dataframe, result_df)
Example #18
    def test_add_invalid_partitions(self, mock_spark_sql: Mock, partition):
        # arrange
        spark_client = SparkClient()
        spark_client._session = mock_spark_sql

        # act and assert
        with pytest.raises(ValueError):
            spark_client.add_table_partitions(partition, "table", "db")
Example #19
 def __init__(
     self,
     database: str = None,
 ) -> None:
     self._db_config = MetastoreConfig()
     self.database = database or environment.get_variable(
         "FEATURE_STORE_HISTORICAL_DATABASE")
     super(MetastoreMigration, self).__init__(SparkClient())
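A hypothetical call site, showing the fallback to the FEATURE_STORE_HISTORICAL_DATABASE environment variable when no database is passed (the explicit database name below is illustrative only):

migration = MetastoreMigration()  # database resolved from FEATURE_STORE_HISTORICAL_DATABASE
explicit_migration = MetastoreMigration(database="feature_store_historical")  # explicit override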
Example #20
    def test_read_invalid_params(self, format: Optional[str],
                                 path: Any) -> None:
        # arrange
        spark_client = SparkClient()

        # act and assert
        with pytest.raises(ValueError):
            spark_client.read(format=format, path=path)  # type: ignore
Example #21
    def test_read_table_invalid_params(self, database: str,
                                       table: Optional[int]) -> None:
        # arrange
        spark_client = SparkClient()

        # act and assert
        with pytest.raises(ValueError):
            spark_client.read_table(table, database)  # type: ignore
Example #22
    def test_run_agg_with_end_date(self, spark_session):
        test_pipeline = FeatureSetPipeline(
            spark_client=SparkClient(),
            source=Mock(
                spec=Source,
                readers=[
                    TableReader(
                        id="source_a",
                        database="db",
                        table="table",
                    )
                ],
                query="select * from source_a",
            ),
            feature_set=Mock(
                spec=AggregatedFeatureSet,
                name="feature_set",
                entity="entity",
                description="description",
                keys=[
                    KeyFeature(
                        name="user_id",
                        description="The user's Main ID or device ID",
                        dtype=DataType.INTEGER,
                    )
                ],
                timestamp=TimestampFeature(from_column="ts"),
                features=[
                    Feature(
                        name="listing_page_viewed__rent_per_month",
                        description="Average of something.",
                        transformation=AggregatedTransform(functions=[
                            Function(functions.avg, DataType.FLOAT),
                            Function(functions.stddev_pop, DataType.FLOAT),
                        ], ),
                    ),
                ],
            ),
            sink=Mock(
                spec=Sink,
                writers=[HistoricalFeatureStoreWriter(db_config=None)],
            ),
        )

        # feature_set need to return a real df for streaming validation
        sample_df = spark_session.createDataFrame([{
            "a": "x",
            "b": "y",
            "c": "3"
        }])
        test_pipeline.feature_set.construct.return_value = sample_df

        test_pipeline.run(end_date="2016-04-18")

        test_pipeline.source.construct.assert_called_once()
        test_pipeline.feature_set.construct.assert_called_once()
        test_pipeline.sink.flush.assert_called_once()
        test_pipeline.sink.validate.assert_called_once()
Example #23
    def test_construct_rolling_windows_with_end_date(
        self,
        feature_set_dataframe,
        rolling_windows_output_feature_set_dataframe_base_date,
    ):
        # given

        spark_client = SparkClient()

        # arrange

        feature_set = AggregatedFeatureSet(
            name="feature_set",
            entity="entity",
            description="description",
            features=[
                Feature(
                    name="feature1",
                    description="test",
                    transformation=AggregatedTransform(
                        functions=[
                            Function(F.avg, DataType.DOUBLE),
                            Function(F.stddev_pop, DataType.DOUBLE),
                        ],
                    ),
                ),
                Feature(
                    name="feature2",
                    description="test",
                    transformation=AggregatedTransform(
                        functions=[
                            Function(F.avg, DataType.DOUBLE),
                            Function(F.stddev_pop, DataType.DOUBLE),
                        ],
                    ),
                ),
            ],
            keys=[
                KeyFeature(
                    name="id",
                    description="The user's Main ID or device ID",
                    dtype=DataType.INTEGER,
                )
            ],
            timestamp=TimestampFeature(),
        ).with_windows(definitions=["1 day", "1 week"])

        # act
        output_df = feature_set.construct(
            feature_set_dataframe, client=spark_client, end_date="2016-04-18"
        ).orderBy("timestamp")

        target_df = rolling_windows_output_feature_set_dataframe_base_date.orderBy(
            feature_set.timestamp_column
        ).select(feature_set.columns)

        # assert
        assert_dataframe_equality(output_df, target_df)
Example #24
    def test_conn(self) -> None:
        # arrange
        spark_client = SparkClient()

        # act
        start_conn = spark_client._session

        # assert
        assert start_conn is None
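The assertion only holds because the session is created lazily; a minimal sketch of that pattern, assuming the real client builds its session on first access to conn:

from pyspark.sql import SparkSession


class LazySessionClient:
    """Toy illustration of the lazy-session behaviour asserted above."""

    def __init__(self):
        self._session = None  # nothing is created in the constructor

    @property
    def conn(self):
        # build the Spark session only when it is first needed
        if self._session is None:
            self._session = SparkSession.builder.getOrCreate()
        return self._session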
Example #25
    def test_read_invalid_params(
        self, format: Optional[str], options: Union[Dict[str, Any], str]
    ) -> None:
        # arrange
        spark_client = SparkClient()

        # act and assert
        with pytest.raises(ValueError):
            spark_client.read(format, options)  # type: ignore
Example #26
    def test_sql(self, target_df: DataFrame) -> None:
        # arrange
        spark_client = SparkClient()
        create_temp_view(target_df, "test")

        # act
        result_df = spark_client.sql("select * from test")

        # assert
        assert result_df.collect() == target_df.collect()
Example #27
    def test_write_dataframe_invalid_params(
        self, target_df: DataFrame, format: Optional[str], mode: Union[str, int]
    ) -> None:
        # arrange
        spark_client = SparkClient()

        # act and assert
        with pytest.raises(ValueError):
            spark_client.write_dataframe(
                dataframe=target_df, format_=format, mode=mode  # type: ignore
            )
Example #28
    def test_create_temporary_view(self, target_df: DataFrame,
                                   spark_session: SparkSession) -> None:
        # arrange
        spark_client = SparkClient()

        # act
        spark_client.create_temporary_view(target_df, "temp_view")
        result_df = spark_session.table("temp_view")

        # assert
        assert_dataframe_equality(target_df, result_df)
Example #29
 def __init__(
     self,
     source: Source,
     feature_set: FeatureSet,
     sink: Sink,
     spark_client: SparkClient = None,
 ):
     self.source = source
     self.feature_set = feature_set
     self.sink = sink
     self.spark_client = spark_client or SparkClient()
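Hypothetical construction, illustrating that a default SparkClient is created only when none is injected (my_source, my_feature_set, and my_sink are placeholders, not objects from the examples above):

pipeline = FeatureSetPipeline(
    source=my_source,            # placeholder Source
    feature_set=my_feature_set,  # placeholder FeatureSet
    sink=my_sink,                # placeholder Sink
)                                # pipeline.spark_client is a fresh SparkClient

shared_client = SparkClient()
pipeline_with_client = FeatureSetPipeline(
    source=my_source,
    feature_set=my_feature_set,
    sink=my_sink,
    spark_client=shared_client,  # reuse an existing client instead
)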
Example #30
    def test_build_table_expression(self, table, database,
                                    target_table_expression):
        # arrange
        spark_client = SparkClient()

        # act
        result_table_expression = SparkTableSchemaCompatibilityHook(
            spark_client, table, database).table_expression

        # assert
        assert target_table_expression == result_table_expression
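A rough sketch of the behaviour the (table, database, target_table_expression) fixtures are expected to exercise; the exact quoting is an assumption rather than the hook's actual code:

def build_table_expression(table, database=None):
    # qualify the table with its database when one is given
    return f"`{database}`.`{table}`" if database else f"`{table}`"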