Example no. 1
    def test_add_table_partitions(self, mock_spark_sql: Mock):
        # arrange
        target_command = (f"ALTER TABLE `db`.`table` ADD IF NOT EXISTS "
                          f"PARTITION ( year = 2020, month = 8, day = 14 ) "
                          f"PARTITION ( year = 2020, month = 8, day = 15 ) "
                          f"PARTITION ( year = 2020, month = 8, day = 16 )")

        spark_client = SparkClient()
        spark_client._session = mock_spark_sql
        partitions = [
            {
                "year": 2020,
                "month": 8,
                "day": 14
            },
            {
                "year": 2020,
                "month": 8,
                "day": 15
            },
            {
                "year": 2020,
                "month": 8,
                "day": 16
            },
        ]

        # act
        spark_client.add_table_partitions(partitions, "table", "db")

        # assert
        mock_spark_sql.assert_called_once_with(target_command)
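A minimal sketch (not part of the library) of how a partitions list in the shape expected above could be built from a date range; daily_partitions is a hypothetical helper and the commented call mirrors the test.

from datetime import date, timedelta

def daily_partitions(start: date, end: date):
    # Yield one {"year", "month", "day"} dict per day in the closed range.
    current = start
    while current <= end:
        yield {"year": current.year, "month": current.month, "day": current.day}
        current += timedelta(days=1)

# partitions = list(daily_partitions(date(2020, 8, 14), date(2020, 8, 16)))
# spark_client.add_table_partitions(partitions, "table", "db")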
    def test_write_in_debug_mode_with_interval_mode(
        self,
        feature_set_dataframe,
        historical_feature_set_dataframe,
        feature_set,
        spark_session,
        mocker,
    ):
        # given
        spark_client = SparkClient()
        spark_client.write_dataframe = mocker.stub("write_dataframe")
        spark_client.conn.conf.set("spark.sql.sources.partitionOverwriteMode",
                                   "dynamic")
        writer = HistoricalFeatureStoreWriter(debug_mode=True,
                                              interval_mode=True)

        # when
        writer.write(
            feature_set=feature_set,
            dataframe=feature_set_dataframe,
            spark_client=spark_client,
        )
        result_df = spark_session.table(
            f"historical_feature_store__{feature_set.name}")

        # then
        assert_dataframe_equality(historical_feature_set_dataframe, result_df)
Example no. 3
    def test_write_table(
        self,
        format: str,
        mode: str,
        database: str,
        table_name: str,
        path: str,
        mocked_spark_write: Mock,
    ) -> None:
        # given
        name = "{}.{}".format(database, table_name)

        # when
        SparkClient.write_table(
            dataframe=mocked_spark_write,
            database=database,
            table_name=table_name,
            format_=format,
            mode=mode,
            path=path,
        )

        # then
        mocked_spark_write.saveAsTable.assert_called_with(
            mode=mode, format=format, partitionBy=None, name=name, path=path
        )
    def validate(self, feature_set: FeatureSet, dataframe: DataFrame,
                 spark_client: SparkClient) -> None:
        """Calculate dataframe rows to validate data into Feature Store.

        Args:
            feature_set: object processed with feature set information.
            dataframe: spark dataframe containing data from a feature set.
            spark_client: client for spark connections with external services.

        Raises:
            AssertionError: if the written row count doesn't match the row count
                of the current feature set dataframe.
        """
        if self.interval_mode and not self.debug_mode:
            table_name = os.path.join(
                "historical", feature_set.entity, feature_set.name
            )
            written_count = spark_client.read(
                self.db_config.format_,
                path=self.db_config.get_path_with_partitions(
                    table_name, self._create_partitions(dataframe)
                ),
            ).count()
        else:
            table_name = (
                f"{self.database}.{feature_set.name}"
                if not self.debug_mode
                else f"historical_feature_store__{feature_set.name}"
            )
            written_count = spark_client.read_table(table_name).count()

        dataframe_count = dataframe.count()

        self._assert_validation_count(table_name, written_count,
                                      dataframe_count)
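For orientation, a minimal usage sketch, assuming a writer and the feature_set, dataframe and spark_client objects seen in the surrounding tests: validate is meant to run right after write and raises AssertionError when the persisted row count diverges from the source dataframe.

# Hypothetical call sequence (object names are assumptions, not fixtures from this file).
writer = HistoricalFeatureStoreWriter()
writer.write(feature_set=feature_set, dataframe=dataframe, spark_client=spark_client)
writer.validate(feature_set=feature_set, dataframe=dataframe, spark_client=spark_client)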
    def test_write_interval_mode(
        self,
        feature_set_dataframe,
        historical_feature_set_dataframe,
        mocker,
        feature_set,
    ):
        # given
        spark_client = SparkClient()
        spark_client.write_table = mocker.stub("write_table")
        spark_client.conn.conf.set("spark.sql.sources.partitionOverwriteMode",
                                   "dynamic")
        writer = HistoricalFeatureStoreWriter(interval_mode=True)

        # when
        writer.write(
            feature_set=feature_set,
            dataframe=feature_set_dataframe,
            spark_client=spark_client,
        )
        result_df = spark_client.write_table.call_args[1]["dataframe"]

        # then
        assert_dataframe_equality(historical_feature_set_dataframe, result_df)

        assert writer.database == spark_client.write_table.call_args[1][
            "database"]
        assert feature_set.name == spark_client.write_table.call_args[1][
            "table_name"]
        assert (writer.PARTITION_BY == spark_client.write_table.call_args[1]
                ["partition_by"])
    def write(
        self, feature_set: FeatureSet, dataframe: DataFrame, spark_client: SparkClient,
    ):
        """Loads the data from a feature set into the Historical Feature Store.

        Args:
            feature_set: object processed with feature set information.
            dataframe: spark dataframe containing data from a feature set.
            spark_client: client for spark connections with external services.

        If debug_mode is set to True, a temporary view named
        historical_feature_store__{feature_set.name} is created instead of
        writing to the real historical feature store.

        """
        dataframe = self._create_partitions(dataframe)

        if self.debug_mode:
            spark_client.create_temporary_view(
                dataframe=dataframe,
                name=f"historical_feature_store__{feature_set.name}",
            )
            return

        s3_key = os.path.join("historical", feature_set.entity, feature_set.name)
        spark_client.write_table(
            dataframe=dataframe,
            database=self.database,
            table_name=feature_set.name,
            partition_by=self.PARTITION_BY,
            **self.db_config.get_options(s3_key),
        )
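A minimal sketch, assuming an active spark_session and the same objects as above, of how the output of a debug-mode write can be inspected through the temporary view mentioned in the docstring.

# Hypothetical debug-mode inspection (names are assumptions).
writer = HistoricalFeatureStoreWriter(debug_mode=True)
writer.write(feature_set=feature_set, dataframe=dataframe, spark_client=spark_client)
debug_df = spark_session.table(f"historical_feature_store__{feature_set.name}")
debug_df.show(truncate=False)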
Example no. 7
    def test_write_stream(self, mocked_stream_df: Mock) -> None:
        # arrange
        spark_client = SparkClient()

        processing_time = "0 seconds"
        output_mode = "update"
        checkpoint_path = "s3://path/to/checkpoint"

        # act
        stream_handler = spark_client.write_stream(
            mocked_stream_df,
            processing_time,
            output_mode,
            checkpoint_path,
            format_="parquet",
            mode="append",
        )

        # assert
        assert isinstance(stream_handler, StreamingQuery)
        mocked_stream_df.trigger.assert_called_with(processingTime=processing_time)
        mocked_stream_df.outputMode.assert_called_with(output_mode)
        mocked_stream_df.option.assert_called_with(
            "checkpointLocation", checkpoint_path
        )
        mocked_stream_df.foreachBatch.assert_called_once()
        mocked_stream_df.start.assert_called_once()
    def test_write(
        self,
        feature_set_dataframe,
        historical_feature_set_dataframe,
        mocker,
        feature_set,
    ):
        # given
        spark_client = SparkClient()
        spark_client.write_table = mocker.stub("write_table")
        writer = HistoricalFeatureStoreWriter()

        # when
        writer.write(
            feature_set=feature_set,
            dataframe=feature_set_dataframe,
            spark_client=spark_client,
        )
        result_df = spark_client.write_table.call_args[1]["dataframe"]

        # then
        assert_dataframe_equality(historical_feature_set_dataframe, result_df)

        assert (writer.db_config.format_ ==
                spark_client.write_table.call_args[1]["format_"])
        assert writer.db_config.mode == spark_client.write_table.call_args[1][
            "mode"]
        assert (writer.PARTITION_BY == spark_client.write_table.call_args[1]
                ["partition_by"])
Example no. 9
    def test_read(
        self,
        format: str,
        stream: bool,
        schema: Optional[StructType],
        path: Any,
        options: Any,
        target_df: DataFrame,
        mocked_spark_read: Mock,
    ) -> None:
        # arrange
        spark_client = SparkClient()
        mocked_spark_read.load.return_value = target_df
        spark_client._session = mocked_spark_read

        # act
        result_df = spark_client.read(format=format,
                                      schema=schema,
                                      stream=stream,
                                      path=path,
                                      **options)

        # assert
        mocked_spark_read.format.assert_called_once_with(format)
        mocked_spark_read.load.assert_called_once_with(path=path, **options)
        assert target_df.collect() == result_df.collect()
Example no. 10
    def test_get_schema(self, target_df: DataFrame,
                        spark_session: SparkSession) -> None:
        # arrange
        spark_client = SparkClient()
        create_temp_view(dataframe=target_df, name="temp_view")
        create_db_and_table(
            spark=spark_session,
            database="test_db",
            table="test_table",
            view="temp_view",
        )

        expected_schema = [
            {
                "col_name": "col1",
                "data_type": "string"
            },
            {
                "col_name": "col2",
                "data_type": "bigint"
            },
        ]

        # act
        schema = spark_client.get_schema(table="test_table",
                                         database="test_db")

        # assert
        assert schema == expected_schema
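As a small illustration, a hypothetical helper (not part of SparkClient) that turns the list of column dicts returned by get_schema, in the shape shown above, into a DDL-like string.

def schema_to_ddl(schema):
    # schema is a list of {"col_name": ..., "data_type": ...} dicts.
    return ", ".join(f"{col['col_name']} {col['data_type'].upper()}" for col in schema)

# schema_to_ddl(spark_client.get_schema(table="test_table", database="test_db"))
# -> "col1 STRING, col2 BIGINT"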
Example no. 11
    def test_read_table_invalid_params(self, database: str,
                                       table: Optional[int]) -> None:
        # arrange
        spark_client = SparkClient()

        # act and assert
        with pytest.raises(ValueError):
            spark_client.read_table(table, database)  # type: ignore
Example no. 12
    def test_read_invalid_params(self, format: Optional[str],
                                 path: Any) -> None:
        # arrange
        spark_client = SparkClient()

        # act and assert
        with pytest.raises(ValueError):
            spark_client.read(format=format, path=path)  # type: ignore
Example no. 13
    def test_add_invalid_partitions(self, mock_spark_sql: Mock, partition):
        # arrange
        spark_client = SparkClient()
        spark_client._session = mock_spark_sql

        # act and assert
        with pytest.raises(ValueError):
            spark_client.add_table_partitions(partition, "table", "db")
Example no. 14
    def test_read_invalid_params(
        self, format: Optional[str], options: Union[Dict[str, Any], str]
    ) -> None:
        # arrange
        spark_client = SparkClient()

        # act and assert
        with pytest.raises(ValueError):
            spark_client.read(format, options)  # type: ignore
Example no. 15
    def test_sql(self, target_df: DataFrame) -> None:
        # arrange
        spark_client = SparkClient()
        create_temp_view(target_df, "test")

        # act
        result_df = spark_client.sql("select * from test")

        # assert
        assert result_df.collect() == target_df.collect()
Example no. 16
    def test_create_temporary_view(self, target_df: DataFrame,
                                   spark_session: SparkSession) -> None:
        # arrange
        spark_client = SparkClient()

        # act
        spark_client.create_temporary_view(target_df, "temp_view")
        result_df = spark_session.table("temp_view")

        # assert
        assert_dataframe_equality(target_df, result_df)
Example no. 17
 def write(self, feature_set: FeatureSet, dataframe: DataFrame,
           spark_client: SparkClient) -> Any:
     """Write output to single file CSV dataset."""
     path = f"data/datasets/{feature_set.name}"
     spark_client.write_dataframe(
         dataframe=dataframe.coalesce(1),
         format_="csv",
         mode="overwrite",
         path=path,
         header=True,
     )
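A minimal sketch, assuming the same spark_client and feature_set, of reading the single-file CSV written above back into a dataframe; without an explicit schema the columns come back as strings.

# Hypothetical read-back of the dataset written above (names are assumptions).
result_df = spark_client.read(
    format="csv",
    path=f"data/datasets/{feature_set.name}",
    header=True,
)
result_df.show()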
Example no. 18
    def test_write_dataframe_invalid_params(
        self, target_df: DataFrame, format: Optional[str], mode: Union[str, int]
    ) -> None:
        # arrange
        spark_client = SparkClient()

        # act and assert
        with pytest.raises(ValueError):
            spark_client.write_dataframe(
                dataframe=target_df, format_=format, mode=mode  # type: ignore
            )
Example no. 19
    def test_write_table_with_invalid_params(
        self, database: Optional[str], table_name: Optional[str], path: Optional[str]
    ) -> None:
        df_writer = "not a spark df writer"

        with pytest.raises(ValueError):
            SparkClient.write_table(
                dataframe=df_writer,  # type: ignore
                database=database,  # type: ignore
                table_name=table_name,  # type: ignore
                path=path,  # type: ignore
            )
    def write(
        self,
        feature_set: FeatureSet,
        dataframe: DataFrame,
        spark_client: SparkClient,
    ) -> None:
        """Loads the data from a feature set into the Historical Feature Store.

        Args:
            feature_set: object processed with feature set information.
            dataframe: spark dataframe containing data from a feature set.
            spark_client: client for spark connections with external services.

        If debug_mode is set to True, a temporary view named
        historical_feature_store__{feature_set.name} is created instead of
        writing to the real historical feature store.

        """
        dataframe = self._create_partitions(dataframe)

        dataframe = self._apply_transformations(dataframe)

        if self.interval_mode:
            partition_overwrite_mode = spark_client.conn.conf.get(
                "spark.sql.sources.partitionOverwriteMode").lower()

            if partition_overwrite_mode != "dynamic":
                raise RuntimeError(
                    "m=load_incremental_table, "
                    "spark.sql.sources.partitionOverwriteMode={}, "
                    "msg=partitionOverwriteMode have to "
                    "be configured to 'dynamic'".format(
                        partition_overwrite_mode))

        if self.debug_mode:
            spark_client.create_temporary_view(
                dataframe=dataframe,
                name=f"historical_feature_store__{feature_set.name}",
            )
            return

        s3_key = os.path.join("historical", feature_set.entity,
                              feature_set.name)

        spark_client.write_table(
            dataframe=dataframe,
            database=self.database,
            table_name=feature_set.name,
            partition_by=self.PARTITION_BY,
            **self.db_config.get_options(s3_key),
        )
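A minimal setup sketch, assuming the usual fixtures, showing the configuration the interval-mode branch above enforces: partitionOverwriteMode must be set to dynamic on the Spark session before writing.

# Hypothetical interval-mode setup (mirrors the guard in write above).
spark_client = SparkClient()
spark_client.conn.conf.set("spark.sql.sources.partitionOverwriteMode", "dynamic")
writer = HistoricalFeatureStoreWriter(interval_mode=True)
writer.write(feature_set=feature_set, dataframe=dataframe, spark_client=spark_client)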
 def test_source_raise(self):
     with pytest.raises(ValueError,
                        match="source must be a Source instance"):
         FeatureSetPipeline(
             spark_client=SparkClient(),
             source=Mock(
                 spark_client=SparkClient(),
                 readers=[
                     TableReader(
                         id="source_a",
                         database="db",
                         table="table",
                     ),
                 ],
                 query="select * from source_a",
             ),
             feature_set=Mock(
                 spec=FeatureSet,
                 name="feature_set",
                 entity="entity",
                 description="description",
                 keys=[
                     KeyFeature(
                         name="user_id",
                         description="The user's Main ID or device ID",
                         dtype=DataType.INTEGER,
                     )
                 ],
                 timestamp=TimestampFeature(from_column="ts"),
                 features=[
                     Feature(
                         name="listing_page_viewed__rent_per_month",
                         description="Average of something.",
                         transformation=SparkFunctionTransform(functions=[
                             Function(functions.avg, DataType.FLOAT),
                             Function(functions.stddev_pop, DataType.FLOAT),
                         ], ).with_window(
                             partition_by="user_id",
                             order_by=TIMESTAMP_COLUMN,
                             window_definition=["7 days", "2 weeks"],
                             mode="fixed_windows",
                         ),
                     ),
                 ],
             ),
             sink=Mock(
                 spec=Sink,
                 writers=[HistoricalFeatureStoreWriter(db_config=None)],
             ),
         )
Example no. 22
    def test_write_stream_invalid_params(self, mocked_stream_df: Mock) -> None:
        # arrange
        spark_client = SparkClient()
        mocked_stream_df.isStreaming = False

        # act and assert
        with pytest.raises(ValueError):
            spark_client.write_stream(
                mocked_stream_df,
                processing_time="0 seconds",
                output_mode="update",
                checkpoint_path="s3://path/to/checkpoint",
                format_="parquet",
                mode="append",
            )
    def test_feature_transform_with_distinct_empty_subset(
            self, timestamp_c, feature_set_with_distinct_dataframe):
        spark_client = SparkClient()

        with pytest.raises(ValueError,
                           match="The distinct subset param can't be empty."):
            AggregatedFeatureSet(
                name="name",
                entity="entity",
                description="description",
                features=[
                    Feature(
                        name="feature",
                        description="test",
                        transformation=AggregatedTransform(functions=[
                            Function(functions.sum, DataType.INTEGER)
                        ]),
                    ),
                ],
                keys=[
                    KeyFeature(name="h3",
                               description="test",
                               dtype=DataType.STRING)
                ],
                timestamp=timestamp_c,
            ).with_windows(["3 days"]).with_distinct(
                subset=[],
                keep="first").construct(feature_set_with_distinct_dataframe,
                                        spark_client,
                                        end_date="2020-01-10")
    def test_feature_set_with_invalid_feature(self, key_id, timestamp_c,
                                              dataframe):
        spark_client = SparkClient()

        with pytest.raises(ValueError):
            AggregatedFeatureSet(
                name="name",
                entity="entity",
                description="description",
                features=[
                    Feature(
                        name="feature1",
                        description="test",
                        transformation=SparkFunctionTransform(functions=[
                            Function(functions.avg, DataType.FLOAT)
                        ], ).with_window(
                            partition_by="id",
                            mode="row_windows",
                            window_definition=["2 events"],
                        ),
                    ),
                ],
                keys=[key_id],
                timestamp=timestamp_c,
            ).construct(dataframe, spark_client)
    def test_feature_transform_with_distinct(
        self,
        timestamp_c,
        feature_set_with_distinct_dataframe,
        target_with_distinct_dataframe,
    ):
        spark_client = SparkClient()

        fs = (AggregatedFeatureSet(
            name="name",
            entity="entity",
            description="description",
            features=[
                Feature(
                    name="feature",
                    description="test",
                    transformation=AggregatedTransform(
                        functions=[Function(functions.sum, DataType.INTEGER)]),
                ),
            ],
            keys=[
                KeyFeature(name="h3",
                           description="test",
                           dtype=DataType.STRING)
            ],
            timestamp=timestamp_c,
        ).with_windows(["3 days"]).with_distinct(subset=["id"], keep="last"))

        # assert
        output_df = fs.construct(feature_set_with_distinct_dataframe,
                                 spark_client,
                                 end_date="2020-01-10")
        assert_dataframe_equality(output_df, target_with_distinct_dataframe)
Example no. 26
    def test_csv_file_with_schema_and_header(self):
        # given
        spark_client = SparkClient()
        schema_csv = StructType(
            [
                StructField("A", LongType()),
                StructField("B", DoubleType()),
                StructField("C", StringType()),
            ]
        )

        file = "tests/unit/butterfree/extract/readers/file-reader-test.csv"

        # when
        file_reader = FileReader(
            id="id",
            path=file,
            format="csv",
            schema=schema_csv,
            format_options={"header": True},
        )
        df = file_reader.consume(spark_client)

        # assert
        assert schema_csv == df.schema
        assert df.columns == ["A", "B", "C"]
        for value in range(3):
            assert df.first()[value] != ["A", "B", "C"][value]
Example no. 27
    def test_construct(
        self, feature_set_dataframe, fixed_windows_output_feature_set_dataframe
    ):
        # given

        spark_client = SparkClient()

        # arrange

        feature_set = FeatureSet(
            name="feature_set",
            entity="entity",
            description="description",
            features=[
                Feature(
                    name="feature1",
                    description="test",
                    transformation=SparkFunctionTransform(
                        functions=[
                            Function(F.avg, DataType.FLOAT),
                            Function(F.stddev_pop, DataType.FLOAT),
                        ]
                    ).with_window(
                        partition_by="id",
                        order_by=TIMESTAMP_COLUMN,
                        mode="fixed_windows",
                        window_definition=["2 minutes", "15 minutes"],
                    ),
                ),
                Feature(
                    name="divided_feature",
                    description="unit test",
                    dtype=DataType.FLOAT,
                    transformation=CustomTransform(
                        transformer=divide, column1="feature1", column2="feature2",
                    ),
                ),
            ],
            keys=[
                KeyFeature(
                    name="id",
                    description="The user's Main ID or device ID",
                    dtype=DataType.INTEGER,
                )
            ],
            timestamp=TimestampFeature(),
        )

        output_df = (
            feature_set.construct(feature_set_dataframe, client=spark_client)
            .orderBy(feature_set.timestamp_column)
            .select(feature_set.columns)
        )

        target_df = fixed_windows_output_feature_set_dataframe.orderBy(
            feature_set.timestamp_column
        ).select(feature_set.columns)

        # assert
        assert_dataframe_equality(output_df, target_df)
Example no. 28
 def _write_stream(
     self,
     feature_set: FeatureSet,
     dataframe: DataFrame,
     spark_client: SparkClient,
     table_name: str,
 ) -> StreamingQuery:
     """Writes the dataframe in streaming mode."""
     checkpoint_folder = (f"{feature_set.name}__on_entity"
                          if self.write_to_entity else table_name)
     checkpoint_path = (os.path.join(
         self.db_config.stream_checkpoint_path,
         feature_set.entity,
         checkpoint_folder,
     ) if self.db_config.stream_checkpoint_path else None)
     streaming_handler = spark_client.write_stream(
         dataframe,
         processing_time=self.db_config.stream_processing_time,
         output_mode=self.db_config.stream_output_mode,
         checkpoint_path=checkpoint_path,
         format_=self.db_config.format_,
         mode=self.db_config.mode,
         **self.db_config.get_options(table_name),
     )
     return streaming_handler
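A minimal sketch of handling the StreamingQuery returned above, assuming a writer exposing _write_stream and a real streaming dataframe; awaitTermination and stop are standard PySpark StreamingQuery calls.

# Hypothetical caller-side handling of the returned handler (names are assumptions).
handler = writer._write_stream(
    feature_set=feature_set,
    dataframe=streaming_df,
    spark_client=spark_client,
    table_name=feature_set.name,
)
handler.awaitTermination()  # blocks until the stream stops (or use handler.stop())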
    def test_h3_feature_set(self, h3_input_df, h3_target_df):
        spark_client = SparkClient()

        feature_set = AggregatedFeatureSet(
            name="h3_test",
            entity="h3geolocation",
            description="Test",
            keys=[
                KeyFeature(
                    name="h3_id",
                    description="The h3 hash ID",
                    dtype=DataType.DOUBLE,
                    transformation=H3HashTransform(
                        h3_resolutions=[6, 7, 8, 9, 10, 11, 12],
                        lat_column="lat",
                        lng_column="lng",
                    ).with_stack(),
                )
            ],
            timestamp=TimestampFeature(),
            features=[
                Feature(
                    name="house_id",
                    description="Count of house ids over a day.",
                    transformation=AggregatedTransform(
                        functions=[Function(F.count, DataType.BIGINT)]),
                ),
            ],
        ).with_windows(definitions=["1 day"])

        output_df = feature_set.construct(h3_input_df,
                                          client=spark_client,
                                          end_date="2016-04-14")

        assert_dataframe_equality(output_df, h3_target_df)
Example no. 30
    def test_write_in_debug_and_stream_mode(
        self, feature_set, spark_session,
    ):
        # arrange
        spark_client = SparkClient()

        mocked_stream_df = Mock()
        mocked_stream_df.isStreaming = True
        mocked_stream_df.writeStream = mocked_stream_df
        mocked_stream_df.format.return_value = mocked_stream_df
        mocked_stream_df.queryName.return_value = mocked_stream_df
        mocked_stream_df.start.return_value = Mock(spec=StreamingQuery)

        writer = OnlineFeatureStoreWriter(debug_mode=True)

        # act
        handler = writer.write(
            feature_set=feature_set,
            dataframe=mocked_stream_df,
            spark_client=spark_client,
        )

        # assert
        mocked_stream_df.format.assert_called_with("memory")
        mocked_stream_df.queryName.assert_called_with(
            f"online_feature_store__{feature_set.name}"
        )
        assert isinstance(handler, StreamingQuery)
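As a follow-up sketch, assuming a real streaming dataframe instead of the mock: Spark's memory sink registers an in-memory table under the query name used above, so the debug output can be inspected with plain SQL and the handler stopped afterwards.

# Hypothetical inspection of the in-memory debug sink (names are assumptions).
debug_df = spark_session.sql(
    f"select * from online_feature_store__{feature_set.name}"
)
debug_df.show(truncate=False)
handler.stop()  # stop the streaming query when done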