Example 1
    def write(self, feature_set: FeatureSet, dataframe: DataFrame,
              spark_client: SparkClient) -> Any:
        """Write output to a single-file CSV dataset."""
        path = f"data/datasets/{feature_set.name}"
        spark_client.write_dataframe(
            dataframe=dataframe.coalesce(1),
            format_="csv",
            mode="overwrite",
            path=path,
            header=True,
        )
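
A standalone sketch of what the call above boils down to in plain PySpark, under the assumption that write_dataframe forwards its arguments to the DataFrame writer (this is not the library's code): coalesce(1) collapses the data to one partition, so the CSV dataset under data/datasets/<feature_set.name> comes out as a single file.

    from pyspark.sql import SparkSession

    spark = SparkSession.builder.getOrCreate()
    df = spark.createDataFrame([(1, "a"), (2, "b")], ["id", "feature"])

    # One partition -> one CSV part file under the target directory (plus the _SUCCESS marker).
    (
        df.coalesce(1)
        .write.format("csv")
        .mode("overwrite")
        .option("header", True)
        .save("data/datasets/example_feature_set")
    )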
Example 2
    def test_write_dataframe_invalid_params(
        self, target_df: DataFrame, format: Optional[str], mode: Union[str, int]
    ) -> None:
        # arrange
        spark_client = SparkClient()

        # act and assert
        with pytest.raises(ValueError):
            spark_client.write_dataframe(
                dataframe=target_df, format_=format, mode=mode  # type: ignore
            )
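
The format and mode arguments imply that this test is parametrized; the decorator is not shown in the source, so the following is only a plausible sketch of the invalid combinations it might cover (the concrete values are assumptions):

    # Hypothetical parametrization (the real decorator is not part of the source):
    @pytest.mark.parametrize(
        "format, mode",
        [
            (None, "append"),   # format missing
            ("parquet", None),  # mode missing
            ("parquet", 1),     # mode with a non-string type
        ],
    )
    def test_write_dataframe_invalid_params(self, target_df, format, mode):
        ...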
Example 3
    def test_write_in_debug_mode_with_interval_mode(
        self,
        feature_set_dataframe,
        historical_feature_set_dataframe,
        feature_set,
        spark_session,
        mocker,
    ):
        # given
        spark_client = SparkClient()
        spark_client.write_dataframe = mocker.stub("write_dataframe")
        spark_client.conn.conf.set("spark.sql.sources.partitionOverwriteMode",
                                   "dynamic")
        writer = HistoricalFeatureStoreWriter(debug_mode=True,
                                              interval_mode=True)

        # when
        writer.write(
            feature_set=feature_set,
            dataframe=feature_set_dataframe,
            spark_client=spark_client,
        )
        result_df = spark_session.table(
            f"historical_feature_store__{feature_set.name}")

        # then
        assert_dataframe_equality(historical_feature_set_dataframe, result_df)
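
The test never persists anything, yet reads a table back through the Spark catalog; consistent with the debug_mode description in the docstring of Example 4, the writer presumably registers a temporary view. A sketch of that assumed behavior, where processed_df stands for whatever dataframe the writer produced (this is not the writer's actual code):

        # Assumed debug-mode behavior: register the processed dataframe as a temporary view
        # so it can be read back with spark_session.table(...).
        processed_df.createOrReplaceTempView(
            f"historical_feature_store__{feature_set.name}"
        )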
Example 4
    def write(
        self,
        feature_set: FeatureSet,
        dataframe: DataFrame,
        spark_client: SparkClient,
    ) -> Union[StreamingQuery, None]:
        """Loads the latest data from a feature set into the Feature Store.

        Args:
            feature_set: object processed with feature set metadata.
            dataframe: Spark dataframe containing data from a feature set.
            spark_client: client for Spark connections with external services.

        Returns:
            Streaming handler if writing streaming df, None otherwise.

        If debug_mode is set to True, a temporary table named in the format
        `online_feature_store__my_feature_set` will be created instead of writing to
        the real online feature store. If the dataframe is streaming, this temporary
        table will be updated in real time.

        """
        table_name = feature_set.entity if self.write_to_entity else feature_set.name

        if dataframe.isStreaming:
            dataframe = self._apply_transformations(dataframe)
            if self.debug_mode:
                return self._write_in_debug_mode(
                    table_name=table_name,
                    dataframe=dataframe,
                    spark_client=spark_client,
                )
            return self._write_stream(
                feature_set=feature_set,
                dataframe=dataframe,
                spark_client=spark_client,
                table_name=table_name,
            )

        latest_df = self.filter_latest(dataframe=dataframe,
                                       id_columns=feature_set.keys_columns)

        latest_df = self._apply_transformations(latest_df)

        if self.debug_mode:
            return self._write_in_debug_mode(table_name=table_name,
                                             dataframe=latest_df,
                                             spark_client=spark_client)

        return spark_client.write_dataframe(
            dataframe=latest_df,
            format_=self.db_config.format_,
            mode=self.db_config.mode,
            **self.db_config.get_options(table_name),
        )
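
A minimal batch-mode usage sketch of the method above. The constructor call mirrors Example 6; the feature set, dataframe, and client are assumed to already exist:

    # Batch usage sketch (assumptions: feature_set and dataframe already exist).
    writer = OnlineFeatureStoreWriter(CassandraConfig(keyspace="feature_set"))
    handler = writer.write(
        feature_set=feature_set,
        dataframe=dataframe,        # dataframe.isStreaming is False here
        spark_client=SparkClient(),
    )
    # For a batch dataframe the latest row per key is written via spark_client.write_dataframe;
    # a StreamingQuery handler is returned only when the input dataframe is streaming.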
Example 5
    def test_flush_with_multiple_online_writers(self, feature_set,
                                                feature_set_dataframe):
        """Testing the flow of writing to a feature-set table and to an entity table."""
        # arrange
        spark_client = SparkClient()
        spark_client.write_dataframe = Mock()

        feature_set.entity = "my_entity"
        feature_set.name = "my_feature_set"

        online_feature_store_writer = OnlineFeatureStoreWriter()

        online_feature_store_writer_on_entity = OnlineFeatureStoreWriter(
            write_to_entity=True)

        sink = Sink(writers=[
            online_feature_store_writer, online_feature_store_writer_on_entity
        ])

        # act
        sink.flush(
            dataframe=feature_set_dataframe,
            feature_set=feature_set,
            spark_client=spark_client,
        )

        # assert
        spark_client.write_dataframe.assert_any_call(
            dataframe=ANY,
            format_=ANY,
            mode=ANY,
            **online_feature_store_writer.db_config.get_options(
                table="my_entity"),
        )

        spark_client.write_dataframe.assert_any_call(
            dataframe=ANY,
            format_=ANY,
            mode=ANY,
            **online_feature_store_writer.db_config.get_options(
                table="my_feature_set"),
        )
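
Both assertions follow from how the online writer resolves its target table; quoting the relevant line from Example 4 as a reminder (not new logic):

        # From OnlineFeatureStoreWriter.write (Example 4):
        table_name = feature_set.entity if self.write_to_entity else feature_set.name
        # -> "my_entity" for the writer built with write_to_entity=True, "my_feature_set" otherwise.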
Example 6
    def test_write_stream(self, feature_set, has_checkpoint, monkeypatch):
        # arrange
        spark_client = SparkClient()
        spark_client.write_stream = Mock()
        spark_client.write_dataframe = Mock()
        spark_client.write_stream.return_value = Mock(spec=StreamingQuery)

        dataframe = Mock(spec=DataFrame)
        dataframe.isStreaming = True

        if has_checkpoint:
            monkeypatch.setenv("STREAM_CHECKPOINT_PATH", "test")

        cassandra_config = CassandraConfig(keyspace="feature_set")
        target_checkpoint_path = (
            "test/entity/feature_set"
            if cassandra_config.stream_checkpoint_path
            else None
        )

        writer = OnlineFeatureStoreWriter(cassandra_config)
        writer.filter_latest = Mock()

        # act
        stream_handler = writer.write(feature_set, dataframe, spark_client)

        # assert
        assert isinstance(stream_handler, StreamingQuery)
        spark_client.write_stream.assert_any_call(
            dataframe,
            processing_time=cassandra_config.stream_processing_time,
            output_mode=cassandra_config.stream_output_mode,
            checkpoint_path=target_checkpoint_path,
            format_=cassandra_config.format_,
            mode=cassandra_config.mode,
            **cassandra_config.get_options(table=feature_set.name),
        )
        writer.filter_latest.assert_not_called()
        spark_client.write_dataframe.assert_not_called()
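
The expected value "test/entity/feature_set" suggests how the checkpoint location is assembled: the STREAM_CHECKPOINT_PATH root, then the feature set's entity and name. A sketch of that composition, inferred from this test rather than taken from the library:

        import os

        # Assumed composition: <STREAM_CHECKPOINT_PATH>/<entity>/<feature set name>,
        # or None when no checkpoint root is configured.
        checkpoint_root = os.getenv("STREAM_CHECKPOINT_PATH")
        target_checkpoint_path = (
            f"{checkpoint_root}/{feature_set.entity}/{feature_set.name}"
            if checkpoint_root
            else None
        )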
Example 7
    def test_write_interval_mode_invalid_partition_mode(
        self,
        feature_set_dataframe,
        historical_feature_set_dataframe,
        mocker,
        feature_set,
    ):
        # given
        spark_client = SparkClient()
        spark_client.write_dataframe = mocker.stub("write_dataframe")
        spark_client.conn.conf.set("spark.sql.sources.partitionOverwriteMode",
                                   "static")

        writer = HistoricalFeatureStoreWriter(interval_mode=True)

        # when
        with pytest.raises(RuntimeError):
            _ = writer.write(
                feature_set=feature_set,
                dataframe=feature_set_dataframe,
                spark_client=spark_client,
            )
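
The test pins down a precondition rather than a write path: with interval_mode=True the writer needs dynamic partition overwrite, and a static setting must fail before any data is written. A plausible guard, inferred from this test and not taken from the library:

        # Plausible validation (an assumption): interval mode overwrites individual partitions,
        # which only works when Spark's partitionOverwriteMode is set to "dynamic".
        mode = spark_client.conn.conf.get("spark.sql.sources.partitionOverwriteMode")
        if mode.lower() != "dynamic":
            raise RuntimeError(
                "interval_mode requires spark.sql.sources.partitionOverwriteMode=dynamic"
            )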
Example 8
    def test_write_dataframe(
        self, format: str, mode: str, mocked_spark_write: Mock
    ) -> None:
        # act
        SparkClient.write_dataframe(mocked_spark_write, format, mode)

        # assert
        mocked_spark_write.save.assert_called_with(format=format, mode=mode)
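
The mocked_spark_write fixture is not shown. For the .save assertion to hold whether the implementation calls save on the object itself or on its .write attribute, a self-referencing mock would suffice; the following fixture is purely hypothetical:

    # Hypothetical fixture (an assumption; the real fixture is not part of the source):
    @pytest.fixture()
    def mocked_spark_write(self) -> Mock:
        mock = Mock()
        mock.write = mock  # .write resolves back to the mock, so .save can be asserted directly
        return mock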