# NOTE: the import paths and the placeholder constants below are reconstructed
# from how the tests use them; exact values and import locations may differ in
# the surrounding suite (`AnalysisException(msg, stackTrace)` matches older
# pyspark versions).
import re
import tempfile
from pathlib import Path

import pytest
from delta.tables import DeltaTable
from pyspark.sql.functions import col
from pyspark.sql.utils import AnalysisException

from kedro.extras.datasets.pandas import CSVDataSet, ParquetDataSet
from kedro.extras.datasets.pickle import PickleDataSet
from kedro.extras.datasets.spark import DeltaTableDataSet, SparkDataSet
from kedro.io import DataCatalog, DataSetError, Version
from kedro.io.core import generate_timestamp

FILENAME = "test.parquet"  # assumed placeholder
FOLDER_NAME = "fake_folder"  # assumed placeholder
# The mocked-load assertions below imply HDFS_PREFIX == f"{FOLDER_NAME}/{FILENAME}".
HDFS_PREFIX = f"{FOLDER_NAME}/{FILENAME}"
BUCKET_NAME = "test_bucket"  # assumed placeholder
AWS_CREDENTIALS = {"key": "FAKE_KEY", "secret": "FAKE_SECRET"}  # assumed placeholder


@pytest.fixture
def data_catalog(tmp_path):
    source_path = Path(__file__).parent / "data/test.parquet"
    spark_in = SparkDataSet(source_path.as_posix())
    spark_out = SparkDataSet((tmp_path / "spark_data").as_posix())
    pickle_ds = PickleDataSet((tmp_path / "pickle/test.pkl").as_posix())

    return DataCatalog(
        {"spark_in": spark_in, "spark_out": spark_out, "pickle_ds": pickle_ds}
    )

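# The tests below assert on specific rows ("Bob" aged 12, a "name=Alex"
# partition, a pandas frame with a capitalised "Name" column), so the data
# fixtures they rely on presumably look like this sketch. The exact columns,
# values, and the `spark_session` fixture are assumptions, not the originals.
import pandas as pd


@pytest.fixture
def sample_pandas_df():
    return pd.DataFrame(
        {"Name": ["Alex", "Bob", "Clarke", "Dave"], "Age": [31, 12, 65, 29]}
    )


@pytest.fixture
def sample_spark_df(spark_session):
    data = [("Alex", 31), ("Bob", 12), ("Clarke", 65), ("Dave", 29)]
    return spark_session.createDataFrame(data, ["name", "age"])
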
def test_load_exact(self, tmp_path, sample_spark_df):
    ts = generate_timestamp()
    ds_dbfs = SparkDataSet(
        filepath="/dbfs" + str(tmp_path / FILENAME), version=Version(ts, ts)
    )

    ds_dbfs.save(sample_spark_df)
    reloaded = ds_dbfs.load()

    assert reloaded.exceptAll(sample_spark_df).count() == 0

def test_load_exact(self, tmp_path, sample_spark_df):
    ts = generate_timestamp()
    ds_local = SparkDataSet(
        filepath=(tmp_path / FILENAME).as_posix(), version=Version(ts, ts)
    )

    ds_local.save(sample_spark_df)
    reloaded = ds_local.load()

    assert reloaded.exceptAll(sample_spark_df).count() == 0

def test_save_overwrite_mode(self, tmp_path, sample_spark_df):
    # Writes the data frame twice in overwrite mode; the second save must not fail.
    filepath = (tmp_path / "test_data").as_posix()
    spark_data_set = SparkDataSet(filepath=filepath, save_args={"mode": "overwrite"})

    spark_data_set.save(sample_spark_df)
    spark_data_set.save(sample_spark_df)

def test_load_options_csv(self, tmp_path, sample_pandas_df):
    filepath = (tmp_path / "data").as_posix()
    local_csv_data_set = CSVDataSet(filepath=filepath)
    local_csv_data_set.save(sample_pandas_df)

    spark_data_set = SparkDataSet(
        filepath=filepath, file_format="csv", load_args={"header": True}
    )
    spark_df = spark_data_set.load()

    assert spark_df.filter(col("Name") == "Alex").count() == 1

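# Several tests below take a `version` fixture that is not part of this
# excerpt. The repr assertions expect `load=None`, so a minimal sketch
# (an assumption, not the original definition) is:
@pytest.fixture
def version():
    return Version(None, generate_timestamp())
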
def test_repr(self, version):
    versioned_hdfs = SparkDataSet(filepath=f"hdfs://{HDFS_PREFIX}", version=version)
    assert "filepath=hdfs://" in str(versioned_hdfs)
    assert f"version=Version(load=None, save='{version.save}')" in str(versioned_hdfs)

    dataset_hdfs = SparkDataSet(filepath=f"hdfs://{HDFS_PREFIX}")
    assert "filepath=hdfs://" in str(dataset_hdfs)
    assert "version=" not in str(dataset_hdfs)

def test_exists(self, tmp_path, sample_spark_df):
    filepath = (tmp_path / "test_data").as_posix()
    delta_ds = DeltaTableDataSet(filepath=filepath)
    assert not delta_ds.exists()

    spark_delta_ds = SparkDataSet(filepath=filepath, file_format="delta")
    spark_delta_ds.save(sample_spark_df)

    assert delta_ds.exists()

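# `versioned_dataset_local` is referenced but not defined in this excerpt.
# A minimal sketch consistent with how the tests use it (assumed):
@pytest.fixture
def versioned_dataset_local(tmp_path, version):
    return SparkDataSet(filepath=(tmp_path / FILENAME).as_posix(), version=version)
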
def test_versioning_existing_dataset(self, versioned_dataset_local, sample_spark_df):
    """Check behavior when attempting to save a versioned dataset on top of an
    already existing (non-versioned) dataset. Note: because SparkDataSet saves
    to a directory even if non-versioned, an error is not expected."""
    spark_data_set = SparkDataSet(
        filepath=versioned_dataset_local._filepath.as_posix()
    )
    spark_data_set.save(sample_spark_df)
    assert spark_data_set.exists()

    versioned_dataset_local.save(sample_spark_df)
    assert versioned_dataset_local.exists()

def test_save_version_warning(self, tmp_path, sample_spark_df):
    exact_version = Version("2019-01-01T23.59.59.999Z", "2019-01-02T00.00.00.000Z")
    ds_local = SparkDataSet(filepath=str(tmp_path / FILENAME), version=exact_version)

    pattern = (
        rf"Save version `{exact_version.save}` did not match load version "
        rf"`{exact_version.load}` for SparkDataSet\(.+\)"
    )
    with pytest.warns(UserWarning, match=pattern):
        ds_local.save(sample_spark_df)

def test_copy(self):
    spark_dataset = SparkDataSet(filepath="/tmp/data", save_args={"mode": "overwrite"})
    assert spark_dataset._file_format == "parquet"

    spark_dataset_copy = spark_dataset._copy(_file_format="csv")

    assert spark_dataset is not spark_dataset_copy
    assert spark_dataset._file_format == "parquet"
    assert spark_dataset._save_args == {"mode": "overwrite"}
    assert spark_dataset_copy._file_format == "csv"
    assert spark_dataset_copy._save_args == {"mode": "overwrite"}

def test_load_exact(self, mocker):
    ts = generate_timestamp()
    versioned_hdfs = SparkDataSet(
        filepath=f"hdfs://{HDFS_PREFIX}", version=Version(ts, None)
    )
    get_spark = mocker.patch.object(versioned_hdfs, "_get_spark")

    versioned_hdfs.load()

    get_spark.return_value.read.load.assert_called_once_with(
        f"hdfs://{FOLDER_NAME}/{FILENAME}/{ts}/{FILENAME}",
        "parquet",
    )

def test_exists_raises_error(self, mocker):
    # `exists` should re-raise any error except an AnalysisException that
    # clearly indicates a missing file.
    spark_data_set = SparkDataSet(filepath="")
    mocker.patch.object(
        spark_data_set,
        "_get_spark",
        side_effect=AnalysisException("Other Exception", []),
    )

    with pytest.raises(DataSetError, match="Other Exception"):
        spark_data_set.exists()

def test_load(self, tmp_path, sample_spark_df):
    filepath = (tmp_path / "test_data").as_posix()
    spark_delta_ds = SparkDataSet(filepath=filepath, file_format="delta")
    spark_delta_ds.save(sample_spark_df)

    loaded_with_spark = spark_delta_ds.load()
    assert loaded_with_spark.exceptAll(sample_spark_df).count() == 0

    delta_ds = DeltaTableDataSet(filepath=filepath)
    delta_table = delta_ds.load()
    assert isinstance(delta_table, DeltaTable)

    loaded_with_deltalake = delta_table.toDF()
    assert loaded_with_deltalake.exceptAll(loaded_with_spark).count() == 0

def test_no_version(self, mocker, version):
    hdfs_walk = mocker.patch(
        "kedro.extras.datasets.spark.spark_dataset.InsecureClient.walk"
    )
    hdfs_walk.return_value = []
    versioned_hdfs = SparkDataSet(filepath=f"hdfs://{HDFS_PREFIX}", version=version)

    pattern = r"Did not find any versions for SparkDataSet\(.+\)"
    with pytest.raises(DataSetError, match=pattern):
        versioned_hdfs.load()

    hdfs_walk.assert_called_once_with(HDFS_PREFIX)

def test_load_exact(self, mocker):
    ts = generate_timestamp()
    ds_s3 = SparkDataSet(
        filepath=f"s3a://{BUCKET_NAME}/{FILENAME}",
        version=Version(ts, None),
        credentials=AWS_CREDENTIALS,
    )
    get_spark = mocker.patch.object(ds_s3, "_get_spark")

    ds_s3.load()

    get_spark.return_value.read.load.assert_called_once_with(
        f"s3a://{BUCKET_NAME}/{FILENAME}/{ts}/{FILENAME}", "parquet"
    )

def test_hdfs_warning(self, version):
    pattern = (
        "HDFS filesystem support for versioned SparkDataSet is in beta "
        "and uses `hdfs.client.InsecureClient`, please use with caution"
    )
    with pytest.warns(UserWarning, match=pattern):
        SparkDataSet(filepath=f"hdfs://{HDFS_PREFIX}", version=version)

def test_s3n_warning(self, version):
    pattern = (
        "`s3n` filesystem has now been deprecated by Spark, "
        "please consider switching to `s3a`"
    )
    with pytest.warns(DeprecationWarning, match=pattern):
        SparkDataSet(filepath=f"s3n://{BUCKET_NAME}/{FILENAME}", version=version)

def test_save_partition(self, tmp_path, sample_spark_df):
    # To verify partitioning, partition the data by one of the columns and
    # check that the partition column is appended to the save path.
    filepath = tmp_path / "test_data"
    spark_data_set = SparkDataSet(
        filepath=filepath.as_posix(),
        save_args={"mode": "overwrite", "partitionBy": ["name"]},
    )

    spark_data_set.save(sample_spark_df)

    expected_path = filepath / "name=Alex"
    assert expected_path.exists()

def test_repr(self, versioned_dataset_local, tmp_path, version):
    assert f"version=Version(load=None, save='{version.save}')" in str(
        versioned_dataset_local
    )

    dataset_local = SparkDataSet(filepath=(tmp_path / FILENAME).as_posix())
    assert "version=" not in str(dataset_local)

def test_str_representation(self):
    with tempfile.NamedTemporaryFile() as temp_data_file:
        filepath = Path(temp_data_file.name).as_posix()
        spark_data_set = SparkDataSet(
            filepath=filepath, file_format="csv", load_args={"header": True}
        )
        assert "SparkDataSet" in str(spark_data_set)
        assert f"filepath={filepath}" in str(spark_data_set)

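# `versioned_dataset_s3` is likewise not part of this excerpt; a sketch
# consistent with the s3a tests above and below (assumed):
@pytest.fixture
def versioned_dataset_s3(version):
    return SparkDataSet(
        filepath=f"s3a://{BUCKET_NAME}/{FILENAME}",
        version=version,
        credentials=AWS_CREDENTIALS,
    )
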
def test_repr(self, versioned_dataset_s3, version):
    assert "filepath=s3a://" in str(versioned_dataset_s3)
    assert f"version=Version(load=None, save='{version.save}')" in str(
        versioned_dataset_s3
    )

    dataset_s3 = SparkDataSet(filepath=f"s3a://{BUCKET_NAME}/{FILENAME}")
    assert "filepath=s3a://" in str(dataset_s3)
    assert "version=" not in str(dataset_s3)

def test_ds_init_no_dbutils(self, mocker):
    get_dbutils_mock = mocker.patch(
        "kedro.extras.datasets.spark.spark_dataset._get_dbutils", return_value=None
    )

    data_set = SparkDataSet(filepath="/dbfs/tmp/data")

    get_dbutils_mock.assert_called_once()
    assert data_set._glob_function.__name__ == "iglob"

def test_save_parquet(self, tmp_path, sample_spark_df):
    # To cross-check the Spark save operation, save to a single Spark
    # partition and retrieve it with Kedro's ParquetDataSet.
    temp_dir = tmp_path / "test_data"
    spark_data_set = SparkDataSet(
        filepath=str(temp_dir), save_args={"compression": "none"}
    )
    spark_df = sample_spark_df.coalesce(1)
    spark_data_set.save(spark_df)

    single_parquet = [
        f for f in temp_dir.iterdir() if f.is_file() and f.name.startswith("part")
    ][0]
    local_parquet_data_set = ParquetDataSet(filepath=str(single_parquet))
    pandas_df = local_parquet_data_set.load()

    assert pandas_df[pandas_df["name"] == "Bob"]["age"].iloc[0] == 12

def test_save_version_warning(self, mocker):
    exact_version = Version("2019-01-01T23.59.59.999Z", "2019-01-02T00.00.00.000Z")
    ds_s3 = SparkDataSet(
        filepath=f"s3a://{BUCKET_NAME}/{FILENAME}",
        version=exact_version,
        credentials=AWS_CREDENTIALS,
    )
    mocked_spark_df = mocker.Mock()

    pattern = (
        rf"Save version `{exact_version.save}` did not match load version "
        rf"`{exact_version.load}` for SparkDataSet\(.+\)"
    )
    with pytest.warns(UserWarning, match=pattern):
        ds_s3.save(mocked_spark_df)

    mocked_spark_df.write.save.assert_called_once_with(
        f"s3a://{BUCKET_NAME}/{FILENAME}/{exact_version.save}/{FILENAME}",
        "parquet",
    )

def test_save_version_warning(self, mocker):
    exact_version = Version("2019-01-01T23.59.59.999Z", "2019-01-02T00.00.00.000Z")
    versioned_hdfs = SparkDataSet(
        filepath=f"hdfs://{HDFS_PREFIX}", version=exact_version
    )
    mocker.patch.object(versioned_hdfs, "_exists_function", return_value=False)
    mocked_spark_df = mocker.Mock()

    pattern = (
        rf"Save version `{exact_version.save}` did not match load version "
        rf"`{exact_version.load}` for SparkDataSet\(.+\)"
    )
    with pytest.warns(UserWarning, match=pattern):
        versioned_hdfs.save(mocked_spark_df)

    mocked_spark_df.write.save.assert_called_once_with(
        f"hdfs://{FOLDER_NAME}/{FILENAME}/{exact_version.save}/{FILENAME}",
        "parquet",
    )

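# The next test receives `mode` from a parametrize decorator that is not part
# of this excerpt; the modes below are assumed examples of save modes that
# `SparkDataSet` rejects for the `delta` format, not the original list.
@pytest.mark.parametrize("mode", ["merge", "delete", "update"])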
def test_file_format_delta_and_unsupported_mode(self, tmp_path, mode):
    filepath = (tmp_path / "test_data").as_posix()
    pattern = (
        "It is not possible to perform `save()` for file format `delta` "
        f"with mode `{mode}` on `SparkDataSet`. "
        "Please use `spark.DeltaTableDataSet` instead."
    )

    with pytest.raises(DataSetError, match=re.escape(pattern)):
        _ = SparkDataSet(
            filepath=filepath, file_format="delta", save_args={"mode": mode}
        )