Example #1
def test_sink(input_dataframe, feature_set):
    # arrange
    client = SparkClient()
    client.conn.conf.set("spark.sql.sources.partitionOverwriteMode", "dynamic")
    feature_set_df = feature_set.construct(input_dataframe, client)
    target_latest_df = OnlineFeatureStoreWriter.filter_latest(
        feature_set_df, id_columns=[key.name for key in feature_set.keys])
    columns_sort = feature_set_df.schema.fieldNames()

    # setup historical writer
    s3config = Mock()
    s3config.mode = "overwrite"
    s3config.format_ = "parquet"
    s3config.get_options = Mock(
        return_value={"path": "test_folder/historical/entity/feature_set"})
    s3config.get_path_with_partitions = Mock(
        return_value="test_folder/historical/entity/feature_set")

    historical_writer = HistoricalFeatureStoreWriter(db_config=s3config,
                                                     interval_mode=True)

    # setup online writer
    # TODO: Change to CassandraConfig when a Cassandra test setup is ready
    online_config = Mock()
    online_config.mode = "overwrite"
    online_config.format_ = "parquet"
    online_config.get_options = Mock(
        return_value={"path": "test_folder/online/entity/feature_set"})
    online_writer = OnlineFeatureStoreWriter(db_config=online_config)

    writers = [historical_writer, online_writer]
    sink = Sink(writers)

    # act
    client.sql("CREATE DATABASE IF NOT EXISTS {}".format(
        historical_writer.database))
    sink.flush(feature_set, feature_set_df, client)

    # get historical results
    historical_result_df = client.read(
        s3config.format_,
        path=s3config.get_path_with_partitions(feature_set.name,
                                               feature_set_df),
    )

    # get online results
    online_result_df = client.read(
        online_config.format_, **online_config.get_options(feature_set.name))

    # assert
    # assert historical results
    assert sorted(feature_set_df.select(*columns_sort).collect()) == sorted(
        historical_result_df.select(*columns_sort).collect())

    # assert online results
    assert sorted(target_latest_df.select(*columns_sort).collect()) == sorted(
        online_result_df.select(*columns_sort).collect())

    # tear down
    shutil.rmtree("test_folder")
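For reference, a minimal sketch of the imports this test appears to rely on. The butterfree module paths are assumptions (they have shifted between releases, e.g. older versions used a butterfree.core.* layout), so adjust them to the installed version:

import shutil
from unittest.mock import Mock

from butterfree.clients import SparkClient  # assumed module path
from butterfree.load import Sink  # assumed module path
from butterfree.load.writers import (  # assumed module path
    HistoricalFeatureStoreWriter,
    OnlineFeatureStoreWriter,
)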
Example #2
    def test_read_invalid_params(self, format: Optional[str],
                                 path: Any) -> None:
        # arrange
        spark_client = SparkClient()

        # act and assert
        with pytest.raises(ValueError):
            spark_client.read(format=format, path=path)  # type: ignore
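The parametrization that feeds format and path is not part of this excerpt. A hypothetical decorator for the test above could look like the following; the invalid values are illustrative, not taken from the original test module:

    # Hypothetical parametrization: each case is expected to make
    # SparkClient.read raise ValueError.
    @pytest.mark.parametrize(
        "format, path",
        [
            (None, "some/path/to/file"),  # no format given
            ("parquet", 123),  # path is not a string or list of strings
        ],
    )
    # ... followed by the test_read_invalid_params method shown above.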
Example #3
    def test_read_invalid_params(
        self, format: Optional[str], options: Union[Dict[str, Any], str]
    ) -> None:
        # arrange
        spark_client = SparkClient()

        # act and assert
        with pytest.raises(ValueError):
            spark_client.read(format, options)  # type: ignore
Example #4
    def validate(self, feature_set: FeatureSet, dataframe: DataFrame,
                 spark_client: SparkClient) -> None:
        """Calculate dataframe rows to validate data into Feature Store.

        Args:
            feature_set: object processed with feature set information.
            dataframe: spark dataframe containing data from a feature set.
            spark_client: client for spark connections with external services.

        Raises:
            AssertionError: if the count of written data doesn't match the count in
                the current feature set dataframe.
        """
        table_name = (
            os.path.join("historical", feature_set.entity, feature_set.name)
            if self.interval_mode and not self.debug_mode
            else (
                f"{self.database}.{feature_set.name}"
                if not self.debug_mode
                else f"historical_feature_store__{feature_set.name}"
            )
        )

        written_count = (
            spark_client.read(
                self.db_config.format_,
                path=self.db_config.get_path_with_partitions(
                    table_name, self._create_partitions(dataframe)
                ),
            ).count()
            if self.interval_mode and not self.debug_mode
            else spark_client.read_table(table_name).count()
        )

        dataframe_count = dataframe.count()

        self._assert_validation_count(table_name, written_count,
                                      dataframe_count)
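_assert_validation_count is not shown in this excerpt. Based on the docstring, a minimal sketch of such a helper could look like this (the name matches the call above, but the body and message are assumptions):

    def _assert_validation_count(
        self, table_name: str, written_count: int, dataframe_count: int
    ) -> None:
        # Hypothetical sketch: fail when the written row count diverges from the
        # feature set dataframe row count, as validate()'s docstring describes.
        assert written_count == dataframe_count, (
            f"Table {table_name} has {written_count} written rows, but the "
            f"feature set dataframe has {dataframe_count} rows."
        )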
Example #5
    def consume(self, client: SparkClient) -> DataFrame:
        """Extract data from a kafka topic.

        When stream mode it will get all the new data arriving at the topic in a
        streaming dataframe. When not in stream mode it will get all data
        available in the kafka topic.

        Args:
            client: client responsible for connecting to the Spark session.

        Returns:
            Dataframe with data from topic.

        """
        # read using client and cast key and value columns from binary to string
        raw_df = (
            client.read(format="kafka", options=self.options, stream=self.stream)
            .withColumn("key", col("key").cast("string"))
            .withColumn("value", col("value").cast("string"))
        )

        # apply schema defined in self.value_schema
        return self._struct_df(raw_df)
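Here self.options is the dict of options forwarded to Spark's Kafka source. For context, a typical dict uses the standard Spark Kafka option keys below; the values are placeholders, not taken from the source:

# Illustrative only: standard option keys for Spark's Kafka source.
kafka_options = {
    "kafka.bootstrap.servers": "localhost:9092",  # broker address(es)
    "subscribe": "feature-set-topic",  # topic(s) to consume
    "startingOffsets": "earliest",  # where to start when no offsets are stored
}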
Example #6
    def test_read(
        self,
        format: str,
        stream: bool,
        schema: Optional[StructType],
        path: Any,
        options: Any,
        target_df: DataFrame,
        mocked_spark_read: Mock,
    ) -> None:
        # arrange
        spark_client = SparkClient()
        mocked_spark_read.load.return_value = target_df
        spark_client._session = mocked_spark_read

        # act
        result_df = spark_client.read(format=format,
                                      schema=schema,
                                      stream=stream,
                                      path=path,
                                      **options)

        # assert
        mocked_spark_read.format.assert_called_once_with(format)
        mocked_spark_read.load.assert_called_once_with(path=path, **options)
        assert target_df.collect() == result_df.collect()
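The mocked_spark_read fixture is not included in this excerpt. A hypothetical version that would satisfy the assertions above is sketched below; it assumes SparkClient.read reaches the underlying reader through the session's read/readStream attributes, which may not match the real implementation:

@pytest.fixture()
def mocked_spark_read() -> Mock:
    # Hypothetical sketch, not the original fixture: a single Mock stands in for
    # the session, the DataFrameReader and the DataStreamReader, so the chained
    # format(...)/schema(...)/load(...) calls can be asserted on directly.
    mock = Mock()
    mock.read = mock
    mock.readStream = mock
    mock.format.return_value = mock
    mock.schema.return_value = mock
    return mock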
Example #7
    def consume(self, client: SparkClient) -> DataFrame:
        """Extract data from files stored in defined path.

        Try to auto-infer schema if in stream mode and not manually defining a
        schema.

        Args:
            client: client responsible for connecting to the Spark session.

        Returns:
            Dataframe with all the files data.

        """
        schema = (
            client.read(format=self.format, options=self.options).schema
            if (self.stream and not self.schema)
            else self.schema
        )

        return client.read(
            format=self.format,
            options=self.options,
            schema=schema,
            stream=self.stream,
        )
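The same infer-then-reuse pattern can be reproduced with plain PySpark, which may help as a mental model for what this reader does when streaming without an explicit schema; the format and path below are illustrative:

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# Batch-read once only to capture the inferred schema...
inferred_schema = spark.read.format("json").load("events/").schema

# ...then reuse it for the streaming read, since streaming file sources
# do not infer schemas by default.
stream_df = (
    spark.readStream.format("json")
    .schema(inferred_schema)
    .load("events/")
)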