def test_save_version_warning(self, mocker):
    exact_version = Version("2019-01-01T23.59.59.999Z", "2019-01-02T00.00.00.000Z")
    ds_s3 = SparkDataSet(
        filepath=f"s3a://{BUCKET_NAME}/{FILENAME}",
        version=exact_version,
        credentials=AWS_CREDENTIALS,
    )
    mocked_spark_df = mocker.Mock()

    pattern = (
        r"Save version `{ev.save}` did not match load version "
        r"`{ev.load}` for SparkDataSet\(.+\)".format(ev=exact_version)
    )
    with pytest.warns(UserWarning, match=pattern):
        ds_s3.save(mocked_spark_df)
    mocked_spark_df.write.save.assert_called_once_with(
        "s3a://{b}/{f}/{v}/{f}".format(
            b=BUCKET_NAME, f=FILENAME, v=exact_version.save
        ),
        "parquet",
    )
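# The tests in this excerpt reference module-level constants that are not
# shown here. Below is a minimal sketch of plausible definitions, inferred
# from how the tests use them; the exact values in the real test module may
# differ. HDFS_PREFIX is assumed to combine FOLDER_NAME and FILENAME, since
# the HDFS warning test builds its expected path as
# "hdfs://{FOLDER_NAME}/{FILENAME}/...".
FILENAME = "test.parquet"
BUCKET_NAME = "test_bucket"
AWS_CREDENTIALS = {"key": "FAKE_ACCESS_KEY", "secret": "FAKE_SECRET_KEY"}
FOLDER_NAME = "fake_folder"
HDFS_PREFIX = f"{FOLDER_NAME}/{FILENAME}"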
def test_save_parquet(self, tmp_path, sample_spark_df):
    # To cross-check the correct Spark save operation, save to a single
    # Spark partition and retrieve it with Kedro's ParquetDataSet
    temp_dir = Path(str(tmp_path / "test_data"))
    spark_data_set = SparkDataSet(
        filepath=temp_dir.as_posix(), save_args={"compression": "none"}
    )
    spark_df = sample_spark_df.coalesce(1)
    spark_data_set.save(spark_df)

    single_parquet = [
        f for f in temp_dir.iterdir() if f.is_file() and f.name.startswith("part")
    ][0]

    local_parquet_data_set = ParquetDataSet(filepath=single_parquet.as_posix())

    pandas_df = local_parquet_data_set.load()

    assert pandas_df[pandas_df["name"] == "Bob"]["age"].iloc[0] == 12
def test_save_version_warning(self, mocker):
    exact_version = Version("2019-01-01T23.59.59.999Z", "2019-01-02T00.00.00.000Z")
    versioned_hdfs = SparkDataSet(
        filepath=f"hdfs://{HDFS_PREFIX}", version=exact_version
    )
    mocker.patch.object(versioned_hdfs, "_exists_function", return_value=False)
    mocked_spark_df = mocker.Mock()

    pattern = (
        r"Save version `{ev.save}` did not match load version "
        r"`{ev.load}` for SparkDataSet\(.+\)".format(ev=exact_version)
    )
    with pytest.warns(UserWarning, match=pattern):
        versioned_hdfs.save(mocked_spark_df)
    mocked_spark_df.write.save.assert_called_once_with(
        "hdfs://{fn}/{f}/{sv}/{f}".format(
            fn=FOLDER_NAME, f=FILENAME, sv=exact_version.save
        ),
        "parquet",
    )
def test_prevent_overwrite(self, mocker, version):
    hdfs_status = mocker.patch(
        "kedro.extras.datasets.spark.spark_dataset.InsecureClient.status"
    )
    hdfs_status.return_value = True

    versioned_hdfs = SparkDataSet(filepath=f"hdfs://{HDFS_PREFIX}", version=version)

    mocked_spark_df = mocker.Mock()

    pattern = (
        r"Save path `.+` for SparkDataSet\(.+\) must not exist "
        r"if versioning is enabled"
    )
    with pytest.raises(DataSetError, match=pattern):
        versioned_hdfs.save(mocked_spark_df)

    hdfs_status.assert_called_once_with(
        "{fn}/{f}/{v}/{f}".format(fn=FOLDER_NAME, v=version.save, f=FILENAME),
        strict=False,
    )
    mocked_spark_df.write.save.assert_not_called()
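# `test_prevent_overwrite` takes a `version` fixture that is not part of this
# excerpt. A minimal sketch consistent with how the test uses `version.save`;
# the pinned save timestamp and the unpinned (None) load version are
# assumptions, not the real fixture definition:
@pytest.fixture
def version():
    # Version(load, save): load latest, save to a fixed timestamped path
    return Version(None, "2019-01-01T23.59.59.999Z")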
def test_save_options_csv(self, tmp_path, sample_spark_df):
    # To cross-check the correct Spark save operation, save to a single
    # Spark partition in CSV format and retrieve it with Kedro's CSVDataSet
    temp_dir = Path(str(tmp_path / "test_data"))
    spark_data_set = SparkDataSet(
        filepath=temp_dir.as_posix(),
        file_format="csv",
        save_args={"sep": "|", "header": True},
    )
    spark_df = sample_spark_df.coalesce(1)
    spark_data_set.save(spark_df)

    single_csv_file = [
        f for f in temp_dir.iterdir() if f.is_file() and f.suffix == ".csv"
    ][0]

    csv_local_data_set = CSVDataSet(
        filepath=single_csv_file.as_posix(), load_args={"sep": "|"}
    )
    pandas_df = csv_local_data_set.load()

    # Use positional indexing, as in the parquet test above: label-based
    # indexing with [0] only works while "Alex" happens to sit at label 0
    assert pandas_df[pandas_df["name"] == "Alex"]["age"].iloc[0] == 31
@pytest.fixture
def spark_in(tmp_path, sample_spark_df):
    spark_in = SparkDataSet(filepath=str(tmp_path / "input"))
    spark_in.save(sample_spark_df)
    return spark_in
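# Several tests above rely on a `sample_spark_df` fixture that is not part of
# this excerpt. A minimal sketch consistent with the assertions above ("Alex"
# is 31, "Bob" is 12); any extra rows and the exact schema are assumptions:
from pyspark.sql import SparkSession
from pyspark.sql.types import IntegerType, StringType, StructField, StructType


@pytest.fixture
def sample_spark_df():
    schema = StructType(
        [
            StructField("name", StringType(), True),
            StructField("age", IntegerType(), True),
        ]
    )
    data = [("Alex", 31), ("Bob", 12)]
    return SparkSession.builder.getOrCreate().createDataFrame(data, schema)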