Code example #1
    def write(
        self, feature_set: FeatureSet, dataframe: DataFrame, spark_client: SparkClient,
    ):
        """Loads the data from a feature set into the Historical Feature Store.

        Args:
            feature_set: object processed with feature_set information.
            dataframe: spark dataframe containing data from a feature set.
            spark_client: client for spark connections with external services.

        If debug_mode is set to True, a temporary view named
        historical_feature_store__{feature_set.name} is created instead of
        writing to the real historical feature store.

        """
        dataframe = self._create_partitions(dataframe)

        if self.debug_mode:
            spark_client.create_temporary_view(
                dataframe=dataframe,
                name=f"historical_feature_store__{feature_set.name}",
            )
            return

        s3_key = os.path.join("historical", feature_set.entity, feature_set.name)
        spark_client.write_table(
            dataframe=dataframe,
            database=self.database,
            table_name=feature_set.name,
            partition_by=self.PARTITION_BY,
            **self.db_config.get_options(s3_key),
        )
Code example #2
    def test_write_table(
        self,
        format: str,
        mode: str,
        database: str,
        table_name: str,
        path: str,
        mocked_spark_write: Mock,
    ) -> None:
        # given
        name = "{}.{}".format(database, table_name)

        # when
        SparkClient.write_table(
            dataframe=mocked_spark_write,
            database=database,
            table_name=table_name,
            format_=format,
            mode=mode,
            path=path,
        )

        # then
        mocked_spark_write.saveAsTable.assert_called_with(
            mode=mode, format=format, partitionBy=None, name=name, path=path
        )
Code example #3
    def test_write_table_with_invalid_params(
        self, database: Optional[str], table_name: Optional[str], path: Optional[str]
    ) -> None:
        df_writer = "not a spark df writer"

        with pytest.raises(ValueError):
            SparkClient.write_table(
                dataframe=df_writer,  # type: ignore
                database=database,  # type: ignore
                table_name=table_name,  # type: ignore
                path=path,  # type: ignore
            )
Code example #4
    def write(
        self,
        feature_set: FeatureSet,
        dataframe: DataFrame,
        spark_client: SparkClient,
    ) -> None:
        """Loads the data from a feature set into the Historical Feature Store.

        Args:
            feature_set: object processed with feature_set information.
            dataframe: spark dataframe containing data from a feature set.
            spark_client: client for spark connections with external services.

        If debug_mode is set to True, a temporary view named
        historical_feature_store__{feature_set.name} is created instead of
        writing to the real historical feature store.

        """
        dataframe = self._create_partitions(dataframe)

        dataframe = self._apply_transformations(dataframe)

        if self.interval_mode:
            partition_overwrite_mode = spark_client.conn.conf.get(
                "spark.sql.sources.partitionOverwriteMode").lower()

            if partition_overwrite_mode != "dynamic":
                raise RuntimeError(
                    "m=load_incremental_table, "
                    "spark.sql.sources.partitionOverwriteMode={}, "
                    "msg=partitionOverwriteMode has to "
                    "be configured to 'dynamic'".format(
                        partition_overwrite_mode))

        if self.debug_mode:
            spark_client.create_temporary_view(
                dataframe=dataframe,
                name=f"historical_feature_store__{feature_set.name}",
            )
            return

        s3_key = os.path.join("historical", feature_set.entity,
                              feature_set.name)

        spark_client.write_table(
            dataframe=dataframe,
            database=self.database,
            table_name=feature_set.name,
            partition_by=self.PARTITION_BY,
            **self.db_config.get_options(s3_key),
        )
Code example #5
    def test_write_interval_mode(
        self,
        feature_set_dataframe,
        historical_feature_set_dataframe,
        mocker,
        feature_set,
    ):
        # given
        spark_client = SparkClient()
        spark_client.write_table = mocker.stub("write_table")
        spark_client.conn.conf.set("spark.sql.sources.partitionOverwriteMode",
                                   "dynamic")
        writer = HistoricalFeatureStoreWriter(interval_mode=True)

        # when
        writer.write(
            feature_set=feature_set,
            dataframe=feature_set_dataframe,
            spark_client=spark_client,
        )
        result_df = spark_client.write_table.call_args[1]["dataframe"]

        # then
        assert_dataframe_equality(historical_feature_set_dataframe, result_df)

        assert writer.database == spark_client.write_table.call_args[1]["database"]
        assert feature_set.name == spark_client.write_table.call_args[1]["table_name"]
        assert writer.PARTITION_BY == spark_client.write_table.call_args[1]["partition_by"]
Code example #6
    def test_write(
        self,
        feature_set_dataframe,
        historical_feature_set_dataframe,
        mocker,
        feature_set,
    ):
        # given
        spark_client = SparkClient()
        spark_client.write_table = mocker.stub("write_table")
        writer = HistoricalFeatureStoreWriter()

        # when
        writer.write(
            feature_set=feature_set,
            dataframe=feature_set_dataframe,
            spark_client=spark_client,
        )
        result_df = spark_client.write_table.call_args[1]["dataframe"]

        # then
        assert_dataframe_equality(historical_feature_set_dataframe, result_df)

        assert writer.db_config.format_ == spark_client.write_table.call_args[1]["format_"]
        assert writer.db_config.mode == spark_client.write_table.call_args[1]["mode"]
        assert writer.PARTITION_BY == spark_client.write_table.call_args[1]["partition_by"]