Example #1
# NOTE: the imports below are assumed from butterfree's package layout and the
# standard library; adjust the paths if the project structure differs.
import shutil
from unittest.mock import Mock

from butterfree.clients import SparkClient
from butterfree.load import Sink
from butterfree.load.writers import (
    HistoricalFeatureStoreWriter,
    OnlineFeatureStoreWriter,
)


def test_sink(input_dataframe, feature_set):
    # arrange
    client = SparkClient()
    client.conn.conf.set("spark.sql.sources.partitionOverwriteMode", "dynamic")
    feature_set_df = feature_set.construct(input_dataframe, client)
    target_latest_df = OnlineFeatureStoreWriter.filter_latest(
        feature_set_df, id_columns=[key.name for key in feature_set.keys])
    columns_sort = feature_set_df.schema.fieldNames()

    # setup historical writer
    s3config = Mock()
    s3config.mode = "overwrite"
    s3config.format_ = "parquet"
    s3config.get_options = Mock(
        return_value={"path": "test_folder/historical/entity/feature_set"})
    s3config.get_path_with_partitions = Mock(
        return_value="test_folder/historical/entity/feature_set")

    historical_writer = HistoricalFeatureStoreWriter(db_config=s3config,
                                                     interval_mode=True)

    # setup online writer
    # TODO: Change for CassandraConfig when Cassandra for test is ready
    online_config = Mock()
    online_config.mode = "overwrite"
    online_config.format_ = "parquet"
    online_config.get_options = Mock(
        return_value={"path": "test_folder/online/entity/feature_set"})
    online_writer = OnlineFeatureStoreWriter(db_config=online_config)

    writers = [historical_writer, online_writer]
    sink = Sink(writers)

    # act
    client.sql("CREATE DATABASE IF NOT EXISTS {}".format(
        historical_writer.database))
    sink.flush(feature_set, feature_set_df, client)

    # get historical results
    historical_result_df = client.read(
        s3config.format_,
        path=s3config.get_path_with_partitions(feature_set.name,
                                               feature_set_df),
    )

    # get online results
    online_result_df = client.read(
        online_config.format_, **online_config.get_options(feature_set.name))

    # assert
    # assert historical results
    assert sorted(feature_set_df.select(*columns_sort).collect()) == sorted(
        historical_result_df.select(*columns_sort).collect())

    # assert online results
    assert sorted(target_latest_df.select(*columns_sort).collect()) == sorted(
        online_result_df.select(*columns_sort).collect())

    # tear down
    shutil.rmtree("test_folder")
Example #2
    def construct(self, client: SparkClient) -> DataFrame:
        """Construct an entry point dataframe for a feature set.

        This method will assemble multiple readers, by building each one and
        querying them using a Spark SQL.

        After that, there's the caching of the dataframe, however since cache()
        in Spark is lazy, an action is triggered in order to force persistence.

        Args:
            client: client responsible for connecting to Spark session.

        Returns:
            DataFrame with the query result against all readers.

        """
        for reader in self.readers:
            reader.build(client)  # create temporary views for each reader

        dataframe = client.sql(self.query)

        if not dataframe.isStreaming:
            dataframe.cache().count()

        return dataframe
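
# A minimal usage sketch for the construct() method above. It assumes the
# surrounding class is butterfree's Source and that SparkClient and
# TableReader come from butterfree's public API; the reader id, database,
# table and query below are illustrative only.
from butterfree.clients import SparkClient
from butterfree.extract import Source
from butterfree.extract.readers import TableReader

source = Source(
    readers=[TableReader(id="orders", database="sales", table="orders")],
    query="SELECT id, amount, timestamp FROM orders",  # runs against the temp views
)
client = SparkClient()
orders_df = source.construct(client)  # builds each reader's view, runs the query, caches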
Example #3
    def test_sql(self, target_df: DataFrame) -> None:
        # arrange
        spark_client = SparkClient()
        create_temp_view(target_df, "test")

        # act
        result_df = spark_client.sql("select * from test")

        # assert
        assert result_df.collect() == target_df.collect()
Example #4
    def construct(
        self, client: SparkClient, start_date: str = None, end_date: str = None
    ) -> DataFrame:
        """Construct an entry point dataframe for a feature set.

        This method will assemble multiple readers, by building each one and
        querying them using a Spark SQL. It's important to highlight that in
        order to filter a dataframe regarding date boundaries, it's important
        to define a IncrementalStrategy, otherwise your data will not be filtered.
        Besides, both start and end dates parameters are optional.

        After that, there's the caching of the dataframe, however since cache()
        in Spark is lazy, an action is triggered in order to force persistence.

        Args:
            client: client responsible for connecting to Spark session.
            start_date: user defined start date for filtering.
            end_date: user defined end date for filtering.

        Returns:
            DataFrame with the query result against all readers.

        """
        for reader in self.readers:
            reader.build(
                client=client, start_date=start_date, end_date=end_date
            )  # create temporary views for each reader

        dataframe = client.sql(self.query)

        if not dataframe.isStreaming:
            dataframe.cache().count()

        post_hook_df = self.run_post_hooks(dataframe)

        return post_hook_df
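
# A minimal sketch of the incremental path described in the docstring above,
# assuming butterfree's IncrementalStrategy and the reader's
# with_incremental_strategy() helper are available; the column name, table
# names and dates are illustrative assumptions, not values from the source.
from butterfree.clients import SparkClient
from butterfree.dataframe_service import IncrementalStrategy
from butterfree.extract import Source
from butterfree.extract.readers import TableReader

orders_reader = TableReader(id="orders", database="sales", table="orders")
# attach the strategy so start_date/end_date boundaries are applied on "timestamp"
orders_reader.with_incremental_strategy(IncrementalStrategy(column="timestamp"))

source = Source(
    readers=[orders_reader],
    query="SELECT id, amount, timestamp FROM orders",
)
client = SparkClient()
# Only rows whose "timestamp" falls inside the given boundaries are kept.
orders_df = source.construct(client, start_date="2021-01-01", end_date="2021-01-31")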