Example #1
@pytest.fixture
def data_catalog(tmp_path):
    source_path = Path(__file__).parent / "data/test.parquet"
    spark_in = SparkDataSet(source_path.as_posix())
    spark_out = SparkDataSet((tmp_path / "spark_data").as_posix())
    pickle_ds = PickleDataSet((tmp_path / "pickle/test.pkl").as_posix())

    return DataCatalog(
        {"spark_in": spark_in, "spark_out": spark_out, "pickle_ds": pickle_ds}
    )
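These snippets come from a larger pytest suite, so the module-level constants and fixtures they reference (FILENAME, HDFS_PREFIX, BUCKET_NAME, AWS_CREDENTIALS, sample_spark_df, sample_pandas_df, version, and so on) are not shown. Below is a minimal sketch of that shared scaffolding; the names are taken from the snippets themselves, but the concrete values are assumptions inferred from the assertions, not the verbatim originals.

# Shared test scaffolding assumed by the examples in this listing
# (illustrative values, not the verbatim originals).
from pathlib import Path

import pandas as pd
import pytest
from pyspark.sql import SparkSession
from pyspark.sql.functions import col

from kedro.io import DataCatalog, DataSetError, Version
from kedro.io.core import generate_timestamp
from kedro.extras.datasets.pandas import CSVDataSet, ParquetDataSet
from kedro.extras.datasets.pickle import PickleDataSet
from kedro.extras.datasets.spark import DeltaTableDataSet, SparkDataSet

FILENAME = "test.parquet"                   # file name used by versioned tests
FOLDER_NAME = "fake_folder"                 # HDFS folder (assumed value)
HDFS_PREFIX = f"{FOLDER_NAME}/{FILENAME}"   # matches the mocked load paths
BUCKET_NAME = "test_bucket"                 # S3 bucket (assumed value)
AWS_CREDENTIALS = {"key": "FAKE_ACCESS_KEY", "secret": "FAKE_SECRET_KEY"}


@pytest.fixture
def version():
    # load=None, save=<timestamp>, matching the `test_repr` assertions below
    return Version(None, generate_timestamp())


@pytest.fixture
def sample_spark_df():
    spark = SparkSession.builder.getOrCreate()
    return spark.createDataFrame([("Alex", 31), ("Bob", 12)], ["name", "age"])


@pytest.fixture
def sample_pandas_df():
    return pd.DataFrame({"Name": ["Alex", "Bob"], "Age": [31, 12]})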
Example #2
    def test_load_exact(self, tmp_path, sample_spark_df):
        ts = generate_timestamp()
        ds_dbfs = SparkDataSet(filepath="/dbfs" + str(tmp_path / FILENAME),
                               version=Version(ts, ts))

        ds_dbfs.save(sample_spark_df)
        reloaded = ds_dbfs.load()

        assert reloaded.exceptAll(sample_spark_df).count() == 0
Example #3
    def test_load_exact(self, tmp_path, sample_spark_df):
        ts = generate_timestamp()
        ds_local = SparkDataSet(filepath=(tmp_path / FILENAME).as_posix(),
                                version=Version(ts, ts))

        ds_local.save(sample_spark_df)
        reloaded = ds_local.load()

        assert reloaded.exceptAll(sample_spark_df).count() == 0
Example #4
    def test_save_overwrite_mode(self, tmp_path, sample_spark_df):
        # Saves the data frame twice in overwrite mode; the second save
        # must replace the first without raising.
        filepath = (tmp_path / "test_data").as_posix()
        spark_data_set = SparkDataSet(
            filepath=filepath, save_args={"mode": "overwrite"}
        )

        spark_data_set.save(sample_spark_df)
        spark_data_set.save(sample_spark_df)
Example #5
    def test_load_options_csv(self, tmp_path, sample_pandas_df):
        filepath = (tmp_path / "data").as_posix()
        local_csv_data_set = CSVDataSet(filepath=filepath)
        local_csv_data_set.save(sample_pandas_df)
        spark_data_set = SparkDataSet(filepath=filepath,
                                      file_format="csv",
                                      load_args={"header": True})
        spark_df = spark_data_set.load()
        assert spark_df.filter(col("Name") == "Alex").count() == 1
Example #6
    def test_repr(self, version):
        versioned_hdfs = SparkDataSet(filepath=f"hdfs://{HDFS_PREFIX}",
                                      version=version)
        assert "filepath=hdfs://" in str(versioned_hdfs)
        assert f"version=Version(load=None, save='{version.save}')" in str(
            versioned_hdfs)

        dataset_hdfs = SparkDataSet(filepath=f"hdfs://{HDFS_PREFIX}")
        assert "filepath=hdfs://" in str(dataset_hdfs)
        assert "version=" not in str(dataset_hdfs)
Example #7
    def test_exists(self, tmp_path, sample_spark_df):
        filepath = (tmp_path / "test_data").as_posix()
        delta_ds = DeltaTableDataSet(filepath=filepath)

        assert not delta_ds.exists()

        spark_delta_ds = SparkDataSet(filepath=filepath, file_format="delta")
        spark_delta_ds.save(sample_spark_df)

        assert delta_ds.exists()
Example #8
    def test_repr(self, version):
        versioned_hdfs = SparkDataSet(filepath="hdfs://{}".format(HDFS_PREFIX),
                                      version=version)
        assert "filepath=hdfs://" in str(versioned_hdfs)
        assert "version=Version(load=None, save='{}')".format(
            version.save) in str(versioned_hdfs)

        dataset_hdfs = SparkDataSet(filepath="hdfs://{}".format(HDFS_PREFIX))
        assert "filepath=hdfs://" in str(dataset_hdfs)
        assert "version=" not in str(dataset_hdfs)
Example #9
    def test_versioning_existing_dataset(self, versioned_dataset_local,
                                         sample_spark_df):
        """Check behavior when attempting to save a versioned dataset on top of an
        already existing (non-versioned) dataset. Note: because SparkDataSet saves to a
        directory even if non-versioned, an error is not expected."""
        spark_data_set = SparkDataSet(
            filepath=versioned_dataset_local._filepath.as_posix())
        spark_data_set.save(sample_spark_df)
        assert spark_data_set.exists()
        versioned_dataset_local.save(sample_spark_df)
        assert versioned_dataset_local.exists()
Example #10
    def test_save_version_warning(self, tmp_path, sample_spark_df):
        exact_version = Version("2019-01-01T23.59.59.999Z",
                                "2019-01-02T00.00.00.000Z")
        ds_local = SparkDataSet(filepath=str(tmp_path / FILENAME),
                                version=exact_version)

        pattern = (r"Save version `{ev.save}` did not match load version "
                   r"`{ev.load}` for SparkDataSet\(.+\)".format(
                       ev=exact_version))
        with pytest.warns(UserWarning, match=pattern):
            ds_local.save(sample_spark_df)
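For reference, `Version` here is Kedro's simple load/save pair from kedro.io.core, which is why the test can build both the dataset argument and the warning pattern from its attributes. A minimal equivalent:

from collections import namedtuple

# Stand-in for kedro.io.core.Version: `load` and `save` hold version
# timestamps, or None to mean "latest".
Version = namedtuple("Version", ["load", "save"])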
Example #11
    def test_copy(self):
        spark_dataset = SparkDataSet(filepath="/tmp/data",
                                     save_args={"mode": "overwrite"})
        assert spark_dataset._file_format == "parquet"

        spark_dataset_copy = spark_dataset._copy(_file_format="csv")

        assert spark_dataset is not spark_dataset_copy
        assert spark_dataset._file_format == "parquet"
        assert spark_dataset._save_args == {"mode": "overwrite"}
        assert spark_dataset_copy._file_format == "csv"
        assert spark_dataset_copy._save_args == {"mode": "overwrite"}
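`_copy` is private API, so the test pins its observable behavior: the copy is a distinct object, overridden attributes change, and everything else is preserved. A sketch consistent with these assertions (an assumption, not the verbatim implementation):

import copy

def _copy(self, **overwrite_params):
    # Deep-copy so mutable attributes such as _save_args are not shared
    # between the original and the copy, then apply the overrides.
    dataset_copy = copy.deepcopy(self)
    for name, value in overwrite_params.items():
        setattr(dataset_copy, name, value)
    return dataset_copy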
Example #12
    def test_load_exact(self, mocker):
        ts = generate_timestamp()
        versioned_hdfs = SparkDataSet(filepath="hdfs://{}".format(HDFS_PREFIX),
                                      version=Version(ts, None))
        get_spark = mocker.patch.object(versioned_hdfs, "_get_spark")

        versioned_hdfs.load()

        get_spark.return_value.read.load.assert_called_once_with(
            "hdfs://{fn}/{f}/{v}/{f}".format(fn=FOLDER_NAME, f=FILENAME, v=ts),
            "parquet",
        )
Example #13
    def test_exists_raises_error(self, mocker):
        # `exists` should re-raise all errors except AnalysisExceptions
        # that clearly indicate a missing file
        spark_data_set = SparkDataSet(filepath="")
        mocker.patch.object(
            spark_data_set,
            "_get_spark",
            side_effect=AnalysisException("Other Exception", []),
        )

        with pytest.raises(DataSetError, match="Other Exception"):
            spark_data_set.exists()
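A sketch of the `exists` logic this test pins down, under the assumption that only an AnalysisException signalling a missing path counts as "does not exist" and anything else propagates (surfacing as a DataSetError); `_get_load_path` is a hypothetical helper name:

from pyspark.sql.utils import AnalysisException

# Assumed shape of SparkDataSet._exists, not the verbatim implementation.
def _exists(self) -> bool:
    load_path = self._get_load_path()  # hypothetical helper
    try:
        self._get_spark().read.load(load_path, self._file_format)
    except AnalysisException as exception:
        # A missing path means the dataset does not exist; any other
        # analysis error is re-raised.
        if "Path does not exist:" in str(exception):
            return False
        raise
    return True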
Example #14
    def test_load(self, tmp_path, sample_spark_df):
        filepath = (tmp_path / "test_data").as_posix()
        spark_delta_ds = SparkDataSet(filepath=filepath, file_format="delta")
        spark_delta_ds.save(sample_spark_df)
        loaded_with_spark = spark_delta_ds.load()
        assert loaded_with_spark.exceptAll(sample_spark_df).count() == 0

        delta_ds = DeltaTableDataSet(filepath=filepath)
        delta_table = delta_ds.load()

        assert isinstance(delta_table, DeltaTable)
        loaded_with_deltalake = delta_table.toDF()
        assert loaded_with_deltalake.exceptAll(loaded_with_spark).count() == 0
Example #15
    def test_no_version(self, mocker, version):
        hdfs_walk = mocker.patch(
            "kedro.extras.datasets.spark.spark_dataset.InsecureClient.walk")
        hdfs_walk.return_value = []

        versioned_hdfs = SparkDataSet(filepath="hdfs://{}".format(HDFS_PREFIX),
                                      version=version)

        pattern = r"Did not find any versions for SparkDataSet\(.+\)"
        with pytest.raises(DataSetError, match=pattern):
            versioned_hdfs.load()

        hdfs_walk.assert_called_once_with(HDFS_PREFIX)
Example #16
    def test_load_exact(self, mocker):
        ts = generate_timestamp()
        ds_s3 = SparkDataSet(
            filepath="s3a://{}/{}".format(BUCKET_NAME, FILENAME),
            version=Version(ts, None),
            credentials=AWS_CREDENTIALS,
        )
        get_spark = mocker.patch.object(ds_s3, "_get_spark")

        ds_s3.load()

        get_spark.return_value.read.load.assert_called_once_with(
            "s3a://{b}/{f}/{v}/{f}".format(b=BUCKET_NAME, f=FILENAME, v=ts),
            "parquet")
Example #17
    def test_hdfs_warning(self, version):
        pattern = (
            "HDFS filesystem support for versioned SparkDataSet is in beta "
            "and uses `hdfs.client.InsecureClient`, please use with caution")
        with pytest.warns(UserWarning, match=pattern):
            SparkDataSet(filepath="hdfs://{}".format(HDFS_PREFIX),
                         version=version)
Example #18
    def test_s3n_warning(self, version):
        pattern = (
            "`s3n` filesystem has now been deprecated by Spark, "
            "please consider switching to `s3a`"
        )
        with pytest.warns(DeprecationWarning, match=pattern):
            SparkDataSet(filepath=f"s3n://{BUCKET_NAME}/{FILENAME}", version=version)
Example #19
    def test_save_partition(self, tmp_path, sample_spark_df):
        # To verify partitioning, this test partitions the data by one of
        # the columns and then checks that the partition column appears in
        # the save path.

        filepath = Path(str(tmp_path / "test_data"))
        spark_data_set = SparkDataSet(
            filepath=filepath.as_posix(),
            save_args={"mode": "overwrite", "partitionBy": ["name"]},
        )

        spark_data_set.save(sample_spark_df)

        expected_path = filepath / "name=Alex"

        assert expected_path.exists()
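As a usage note, Spark's partitionBy encodes each distinct value of the partition column as a sub-directory of the save path, which is exactly what the assertion relies on. An illustrative layout for the sample data (file names abbreviated):

test_data/
├── _SUCCESS
├── name=Alex/part-00000-<uuid>.snappy.parquet
└── name=Bob/part-00000-<uuid>.snappy.parquet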
Example #20
    def test_repr(self, versioned_dataset_local, tmp_path, version):
        assert "version=Version(load=None, save='{}')".format(version.save) in str(
            versioned_dataset_local
        )

        dataset_local = SparkDataSet(filepath=str(tmp_path / FILENAME))
        assert "version=" not in str(dataset_local)
Example #21
    def test_repr(self, versioned_dataset_local, tmp_path, version):
        assert f"version=Version(load=None, save='{version.save}')" in str(
            versioned_dataset_local
        )

        dataset_local = SparkDataSet(filepath=(tmp_path / FILENAME).as_posix())
        assert "version=" not in str(dataset_local)
Example #22
    def test_str_representation(self):
        with tempfile.NamedTemporaryFile() as temp_data_file:
            filepath = Path(temp_data_file.name).as_posix()
            spark_data_set = SparkDataSet(
                filepath=filepath, file_format="csv", load_args={"header": True}
            )
            assert "SparkDataSet" in str(spark_data_set)
            assert f"filepath={filepath}" in str(spark_data_set)
Example #23
    def test_repr(self, versioned_dataset_s3, version):
        assert "filepath=s3a://" in str(versioned_dataset_s3)
        assert f"version=Version(load=None, save='{version.save}')" in str(
            versioned_dataset_s3)

        dataset_s3 = SparkDataSet(filepath=f"s3a://{BUCKET_NAME}/{FILENAME}")
        assert "filepath=s3a://" in str(dataset_s3)
        assert "version=" not in str(dataset_s3)
Example #24
    def test_str_representation(self):
        with tempfile.NamedTemporaryFile() as temp_data_file:
            spark_data_set = SparkDataSet(
                filepath=temp_data_file.name,
                file_format="csv",
                load_args={"header": True},
            )
            assert "SparkDataSet" in str(spark_data_set)
            assert "filepath={}".format(temp_data_file.name) in str(spark_data_set)
Example #25
    def test_repr(self, versioned_dataset_s3, version):
        assert "filepath=s3a://" in str(versioned_dataset_s3)
        assert "version=Version(load=None, save='{}')".format(
            version.save) in str(versioned_dataset_s3)

        dataset_s3 = SparkDataSet(
            filepath="s3a://{}/{}".format(BUCKET_NAME, FILENAME))
        assert "filepath=s3a://" in str(dataset_s3)
        assert "version=" not in str(dataset_s3)
Example #26
    def test_ds_init_no_dbutils(self, mocker):
        get_dbutils_mock = mocker.patch(
            "kedro.extras.datasets.spark.spark_dataset._get_dbutils",
            return_value=None)

        data_set = SparkDataSet(filepath="/dbfs/tmp/data")

        get_dbutils_mock.assert_called_once()
        assert data_set._glob_function.__name__ == "iglob"
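A sketch of the constructor branch this test exercises, assuming the logic is: on a DBFS path, ask for a dbutils handle and fall back to the standard-library glob when none is available (`_dbfs_glob` below is a hypothetical dbutils-backed helper):

from functools import partial
from glob import iglob

# Assumed shape of the branch under test, not the verbatim implementation.
def _select_glob_function(data_set):
    # With dbutils available, glob via DBFS; otherwise fall back to
    # iglob, which is what the assertion on __name__ checks.
    dbutils = _get_dbutils(data_set._get_spark())
    if dbutils:
        return partial(_dbfs_glob, dbutils=dbutils)  # hypothetical helper
    return iglob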
Example #27
    def test_save_parquet(self, tmp_path, sample_spark_df):
        # To cross-check the Spark save operation, we save to a single
        # Spark partition and read the resulting file back with Kedro's
        # pandas ParquetDataSet.
        temp_dir = Path(str(tmp_path / "test_data"))
        spark_data_set = SparkDataSet(filepath=str(temp_dir),
                                      save_args={"compression": "none"})
        spark_df = sample_spark_df.coalesce(1)
        spark_data_set.save(spark_df)

        single_parquet = [
            f for f in temp_dir.iterdir()
            if f.is_file() and f.name.startswith("part")
        ][0]

        local_parquet_data_set = ParquetDataSet(filepath=str(single_parquet))

        pandas_df = local_parquet_data_set.load()

        assert pandas_df[pandas_df["name"] == "Bob"]["age"].iloc[0] == 12
Example #28
    def test_save_version_warning(self, mocker):
        exact_version = Version("2019-01-01T23.59.59.999Z",
                                "2019-01-02T00.00.00.000Z")
        ds_s3 = SparkDataSet(
            filepath="s3a://{}/{}".format(BUCKET_NAME, FILENAME),
            version=exact_version,
            credentials=AWS_CREDENTIALS,
        )
        mocked_spark_df = mocker.Mock()

        pattern = (r"Save version `{ev.save}` did not match load version "
                   r"`{ev.load}` for SparkDataSet\(.+\)".format(
                       ev=exact_version))
        with pytest.warns(UserWarning, match=pattern):
            ds_s3.save(mocked_spark_df)
        mocked_spark_df.write.save.assert_called_once_with(
            "s3a://{b}/{f}/{v}/{f}".format(b=BUCKET_NAME,
                                           f=FILENAME,
                                           v=exact_version.save),
            "parquet",
        )
Example #29
    def test_save_version_warning(self, mocker):
        exact_version = Version("2019-01-01T23.59.59.999Z", "2019-01-02T00.00.00.000Z")
        versioned_hdfs = SparkDataSet(
            filepath=f"hdfs://{HDFS_PREFIX}", version=exact_version
        )
        mocker.patch.object(versioned_hdfs, "_exists_function", return_value=False)
        mocked_spark_df = mocker.Mock()

        pattern = (
            r"Save version `{ev.save}` did not match load version "
            r"`{ev.load}` for SparkDataSet\(.+\)".format(ev=exact_version)
        )

        with pytest.warns(UserWarning, match=pattern):
            versioned_hdfs.save(mocked_spark_df)
        mocked_spark_df.write.save.assert_called_once_with(
            "hdfs://{fn}/{f}/{sv}/{f}".format(
                fn=FOLDER_NAME, f=FILENAME, sv=exact_version.save
            ),
            "parquet",
        )
Example #30
    def test_file_format_delta_and_unsupported_mode(self, tmp_path, mode):
        filepath = (tmp_path / "test_data").as_posix()
        pattern = (
            f"It is not possible to perform `save()` for file format `delta` "
            f"with mode `{mode}` on `SparkDataSet`. "
            f"Please use `spark.DeltaTableDataSet` instead."
        )

        with pytest.raises(DataSetError, match=re.escape(pattern)):
            _ = SparkDataSet(
                filepath=filepath, file_format="delta", save_args={"mode": mode}
            )
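The `mode` argument is supplied by pytest parametrization, which the snippet omits. A plausible decorator, where the exact set of rejected modes is an assumption:

# Hypothetical parametrization; the real suite may reject a different
# set of modes for the `delta` file format.
@pytest.mark.parametrize("mode", ["ignore", "error", "errorifexists"])
def test_file_format_delta_and_unsupported_mode(self, tmp_path, mode):
    ...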