Example #1
    def test_save_version_warning(self, mocker):
        # A save version that differs from the load version should emit
        # a UserWarning, but the write should still go through.
        exact_version = Version("2019-01-01T23.59.59.999Z",
                                "2019-01-02T00.00.00.000Z")
        ds_s3 = SparkDataSet(
            filepath=f"s3a://{BUCKET_NAME}/{FILENAME}",
            version=exact_version,
            credentials=AWS_CREDENTIALS,
        )
        mocked_spark_df = mocker.Mock()

        pattern = (r"Save version `{ev.save}` did not match load version "
                   r"`{ev.load}` for SparkDataSet\(.+\)".format(
                       ev=exact_version))
        with pytest.warns(UserWarning, match=pattern):
            ds_s3.save(mocked_spark_df)
        mocked_spark_df.write.save.assert_called_once_with(
            "s3a://{b}/{f}/{v}/{f}".format(b=BUCKET_NAME,
                                           f=FILENAME,
                                           v=exact_version.save),
            "parquet",
        )
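
The method above comes from a SparkDataSet test class and relies on names defined at module level. Below is a minimal sketch of the setup it assumes; the concrete values are illustrative placeholders rather than the project's actual ones, and mocker is the fixture provided by the pytest-mock plugin:

import pytest

from kedro.extras.datasets.spark import SparkDataSet
from kedro.io import Version

# Placeholder values; only the names are referenced by the test.
BUCKET_NAME = "test_bucket"
FILENAME = "test.parquet"
AWS_CREDENTIALS = {"key": "FAKE_ACCESS_KEY", "secret": "FAKE_SECRET_KEY"}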
Example #2
    def test_save_parquet(self, tmp_path, sample_spark_df):
        # To cross check the correct Spark save operation we save to
        # a single spark partition and retrieve it with Kedro
        # ParquetDataSet
        temp_dir = tmp_path / "test_data"
        spark_data_set = SparkDataSet(filepath=temp_dir.as_posix(),
                                      save_args={"compression": "none"})
        spark_df = sample_spark_df.coalesce(1)
        spark_data_set.save(spark_df)

        single_parquet = [
            f for f in temp_dir.iterdir()
            if f.is_file() and f.name.startswith("part")
        ][0]

        local_parquet_data_set = ParquetDataSet(
            filepath=single_parquet.as_posix())

        pandas_df = local_parquet_data_set.load()

        assert pandas_df[pandas_df["name"] == "Bob"]["age"].iloc[0] == 12

    def test_save_version_warning(self, mocker):
        # HDFS variant of the mismatched-version check: the warning is
        # emitted, but the write still goes through.
        exact_version = Version("2019-01-01T23.59.59.999Z", "2019-01-02T00.00.00.000Z")
        versioned_hdfs = SparkDataSet(
            filepath=f"hdfs://{HDFS_PREFIX}", version=exact_version
        )
        mocker.patch.object(versioned_hdfs, "_exists_function", return_value=False)
        mocked_spark_df = mocker.Mock()

        pattern = (
            r"Save version `{ev.save}` did not match load version "
            r"`{ev.load}` for SparkDataSet\(.+\)".format(ev=exact_version)
        )

        with pytest.warns(UserWarning, match=pattern):
            versioned_hdfs.save(mocked_spark_df)
        mocked_spark_df.write.save.assert_called_once_with(
            "hdfs://{fn}/{f}/{sv}/{f}".format(
                fn=FOLDER_NAME, f=FILENAME, sv=exact_version.save
            ),
            "parquet",
        )

    def test_prevent_overwrite(self, mocker, version):
        # A versioned save path that already exists must raise instead of
        # silently overwriting the existing data.
        hdfs_status = mocker.patch(
            "kedro.extras.datasets.spark.spark_dataset.InsecureClient.status")
        hdfs_status.return_value = True

        versioned_hdfs = SparkDataSet(filepath="hdfs://{}".format(HDFS_PREFIX),
                                      version=version)

        mocked_spark_df = mocker.Mock()

        pattern = (r"Save path `.+` for SparkDataSet\(.+\) must not exist "
                   r"if versioning is enabled")
        with pytest.raises(DataSetError, match=pattern):
            versioned_hdfs.save(mocked_spark_df)

        hdfs_status.assert_called_once_with(
            "{fn}/{f}/{v}/{f}".format(fn=FOLDER_NAME,
                                      v=version.save,
                                      f=FILENAME),
            strict=False,
        )
        mocked_spark_df.write.save.assert_not_called()

    def test_save_options_csv(self, tmp_path, sample_spark_df):
        # To cross check the correct Spark save operation we save to
        # a single spark partition with csv format and retrieve it with Kedro
        # CSVDataSet
        temp_dir = tmp_path / "test_data"
        spark_data_set = SparkDataSet(
            filepath=temp_dir.as_posix(),
            file_format="csv",
            save_args={"sep": "|", "header": True},
        )
        spark_df = sample_spark_df.coalesce(1)
        spark_data_set.save(spark_df)

        single_csv_file = [
            f for f in temp_dir.iterdir() if f.is_file() and f.suffix == ".csv"
        ][0]

        csv_local_data_set = CSVDataSet(
            filepath=single_csv_file.as_posix(), load_args={"sep": "|"}
        )
        pandas_df = csv_local_data_set.load()

        assert pandas_df[pandas_df["name"] == "Alex"]["age"].iloc[0] == 31

@pytest.fixture
def spark_in(tmp_path, sample_spark_df):
    # Fixture: a SparkDataSet pre-populated with the sample Spark DataFrame.
    spark_in = SparkDataSet(filepath=str(tmp_path / "input"))
    spark_in.save(sample_spark_df)
    return spark_in
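
Example #2 likewise assumes module-level fixtures and constants. Below is a minimal sketch under stated assumptions: HDFS_PREFIX is assumed to combine FOLDER_NAME and FILENAME, the version fixture is assumed to supply the Version consumed by test_prevent_overwrite, and the sample rows are chosen so the assertions on "Alex" (age 31) and "Bob" (age 12) hold:

import pytest
from pyspark.sql import SparkSession
from pyspark.sql.types import IntegerType, StringType, StructField, StructType

from kedro.extras.datasets.pandas import CSVDataSet, ParquetDataSet
from kedro.extras.datasets.spark import SparkDataSet
from kedro.io import DataSetError, Version

FOLDER_NAME = "fake_folder"
FILENAME = "test.parquet"
HDFS_PREFIX = f"{FOLDER_NAME}/{FILENAME}"  # assumed layout: <folder>/<file>


@pytest.fixture
def version():
    # load=None means "load whatever the latest version is"
    return Version(None, "2019-01-02T00.00.00.000Z")


@pytest.fixture
def sample_spark_df():
    schema = StructType(
        [
            StructField("name", StringType(), True),
            StructField("age", IntegerType(), True),
        ]
    )
    data = [("Alex", 31), ("Bob", 12), ("Clarke", 65), ("Dave", 29)]
    return SparkSession.builder.getOrCreate().createDataFrame(data, schema)

A hypothetical consumer of the spark_in fixture just declares it as a parameter and lets pytest inject it:

def test_reload_spark_in(spark_in):
    # Hypothetical usage: load() returns the Spark DataFrame saved by the fixture.
    assert spark_in.load().count() == 4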