# Imports shared by the tests below; the kedro import paths are assumptions
# inferred from the `kedro.contrib.io.pyspark.spark_data_set` patch targets
# used in the HDFS tests.
import tempfile
from os import listdir
from os.path import exists, join

import pytest

from kedro.contrib.io.pyspark import SparkDataSet
from kedro.io import CSVLocalDataSet, DataSetError, ParquetLocalDataSet
from kedro.io.core import Version, generate_timestamp


def test_save_options_csv():
    # To cross-check the correct Spark save operation we save to
    # a single Spark partition with csv format and retrieve it with Kedro
    # CSVLocalDataSet
    with tempfile.TemporaryDirectory() as temp_dir:
        temp_path = join(temp_dir, "test_data")
        spark_data_set = SparkDataSet(
            filepath=temp_path,
            file_format="csv",
            save_args={"sep": "|", "header": True},
        )
        spark_df = _get_sample_spark_data_frame().coalesce(1)
        spark_data_set.save(spark_df)

        single_csv_file = [
            join(temp_path, f) for f in listdir(temp_path) if f.endswith("csv")
        ][0]

        csv_local_data_set = CSVLocalDataSet(
            filepath=single_csv_file, load_args={"sep": "|"}
        )
        pandas_df = csv_local_data_set.load()

        assert pandas_df[pandas_df["name"] == "Alex"]["age"][0] == 31

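# The module-level tests in this file rely on a _get_sample_spark_data_frame()
# helper that is not shown in this section. Below is a minimal sketch: the
# schema and rows are inferred from the assertions (Alex/31, Bob/12), while
# building the session via SparkSession.builder.getOrCreate() is an assumption.
from pyspark.sql import SparkSession
from pyspark.sql.types import IntegerType, StringType, StructField, StructType


def _get_sample_spark_data_frame():
    schema = StructType(
        [
            StructField("name", StringType(), True),
            StructField("age", IntegerType(), True),
        ]
    )
    data = [("Alex", 31), ("Bob", 12)]
    return SparkSession.builder.getOrCreate().createDataFrame(data, schema)
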
def test_save(self, mocker, version):
    hdfs_status = mocker.patch(
        "kedro.contrib.io.pyspark.spark_data_set.InsecureClient.status"
    )
    hdfs_status.return_value = None

    versioned_hdfs = SparkDataSet(
        filepath="hdfs://{}".format(HDFS_PREFIX), version=version
    )

    # need to mock the _get_load_path() call inside _save();
    # _get_load_path() must also return a load version that
    # _check_paths_consistency() will be happy with (hence mocking __eq__)
    mocked_load_version = mocker.MagicMock()
    mocked_load_version.__eq__.return_value = True
    mocker.patch.object(
        versioned_hdfs, "_get_load_path", return_value=mocked_load_version
    )

    mocked_spark_df = mocker.Mock()
    versioned_hdfs.save(mocked_spark_df)

    hdfs_status.assert_called_once_with(
        "{fn}/{f}/{v}/{f}".format(fn=FOLDER_NAME, v=version.save, f=FILENAME),
        strict=False,
    )
    mocked_spark_df.write.save.assert_called_once_with(
        "hdfs://{fn}/{f}/{v}/{f}".format(fn=FOLDER_NAME, v=version.save, f=FILENAME),
        "parquet",
    )

def test_save(self, mocker, version):
    hdfs_status = mocker.patch(
        "kedro.contrib.io.pyspark.spark_data_set.InsecureClient.status"
    )
    hdfs_status.return_value = None

    versioned_hdfs = SparkDataSet(
        filepath="hdfs://{}".format(HDFS_PREFIX), version=version
    )

    # need the _lookup_load_version() call to return a load version that
    # matches the save version due to the consistency check in
    # versioned_hdfs.save()
    mocker.patch.object(
        versioned_hdfs, "_lookup_load_version", return_value=version.save
    )

    mocked_spark_df = mocker.Mock()
    versioned_hdfs.save(mocked_spark_df)

    hdfs_status.assert_called_once_with(
        "{fn}/{f}/{v}/{f}".format(fn=FOLDER_NAME, v=version.save, f=FILENAME),
        strict=False,
    )
    mocked_spark_df.write.save.assert_called_once_with(
        "hdfs://{fn}/{f}/{v}/{f}".format(
            fn=FOLDER_NAME, v=version.save, f=FILENAME
        ),
        "parquet",
    )

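# The HDFS tests here assume module-level HDFS_PREFIX, FOLDER_NAME and
# FILENAME constants. Their relationship is implied by the assertions: the
# dataset is built from "hdfs://" + HDFS_PREFIX, while status() is checked
# against FOLDER_NAME/FILENAME/<version>/FILENAME. A sketch with assumed
# placeholder values:
FOLDER_NAME = "fake_folder"
FILENAME = "test.parquet"
HDFS_PREFIX = "{}/{}".format(FOLDER_NAME, FILENAME)
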
def test_save_version_warning(self, mocker): exact_version = Version("2019-01-01T23.59.59.999Z", "2019-01-02T00.00.00.000Z") versioned_hdfs = SparkDataSet(filepath="hdfs://{}".format(HDFS_PREFIX), version=exact_version) mocker.patch.object(versioned_hdfs, "_exists_function", return_value=False) mocked_spark_df = mocker.Mock() pattern = (r"Save path `{fn}/{f}/{sv}/{f}` did not match load path " r"`{fn}/{f}/{lv}/{f}` for SparkDataSet\(.+\)".format( fn=FOLDER_NAME, f=FILENAME, sv=exact_version.save, lv=exact_version.load)) with pytest.warns(UserWarning, match=pattern): versioned_hdfs.save(mocked_spark_df) mocked_spark_df.write.save.assert_called_once_with( "hdfs://{fn}/{f}/{sv}/{f}".format(fn=FOLDER_NAME, f=FILENAME, sv=exact_version.save), "parquet", )
def test_save_options_csv(self, tmp_path, sample_spark_df):
    # To cross-check the correct Spark save operation we save to
    # a single Spark partition with csv format and retrieve it with Kedro
    # CSVLocalDataSet
    temp_dir = tmp_path / "test_data"
    spark_data_set = SparkDataSet(
        filepath=str(temp_dir),
        file_format="csv",
        save_args={"sep": "|", "header": True},
    )
    spark_df = sample_spark_df.coalesce(1)
    spark_data_set.save(spark_df)

    single_csv_file = [
        f for f in temp_dir.iterdir() if f.is_file() and f.suffix == ".csv"
    ][0]

    csv_local_data_set = CSVLocalDataSet(
        filepath=str(single_csv_file), load_args={"sep": "|"}
    )
    pandas_df = csv_local_data_set.load()

    assert pandas_df[pandas_df["name"] == "Alex"]["age"][0] == 31

def test_save_overwrite_fail(self, tmp_path, sample_spark_df):
    # Writes a data frame once, then expects the second save to fail
    # because the default save mode forbids overwriting an existing path.
    filepath = str(tmp_path / "test_data")
    spark_data_set = SparkDataSet(filepath=filepath)
    spark_data_set.save(sample_spark_df)

    with pytest.raises(DataSetError):
        spark_data_set.save(sample_spark_df)

def test_save_overwrite_mode(self, tmp_path, sample_spark_df):
    # Writes a data frame twice in overwrite mode; both saves succeed.
    filepath = str(tmp_path / "test_data")
    spark_data_set = SparkDataSet(
        filepath=filepath, save_args={"mode": "overwrite"}
    )

    spark_data_set.save(sample_spark_df)
    spark_data_set.save(sample_spark_df)

def test_exists(self, file_format, tmp_path, sample_spark_df):
    filepath = str(tmp_path / "test_data")
    spark_data_set = SparkDataSet(filepath=filepath, file_format=file_format)

    assert not spark_data_set.exists()

    spark_data_set.save(sample_spark_df)

    assert spark_data_set.exists()

def test_load_exact(self, tmp_path, sample_spark_df):
    ts = generate_timestamp()
    ds_local = SparkDataSet(
        filepath=str(tmp_path / FILENAME), version=Version(ts, ts)
    )

    ds_local.save(sample_spark_df)
    reloaded = ds_local.load()

    assert reloaded.exceptAll(sample_spark_df).count() == 0

def test_exists(file_format):
    with tempfile.TemporaryDirectory() as temp_dir:
        temp_path = join(temp_dir, "test_data")
        spark_data_set = SparkDataSet(filepath=temp_path, file_format=file_format)
        spark_df = _get_sample_spark_data_frame().coalesce(1)

        assert not spark_data_set.exists()

        spark_data_set.save(spark_df)

        assert spark_data_set.exists()

def test_save_version_warning(self, tmp_path, sample_spark_df):
    exact_version = Version("2019-01-01T23.59.59.999Z", "2019-01-02T00.00.00.000Z")
    ds_local = SparkDataSet(
        filepath=str(tmp_path / FILENAME), version=exact_version
    )

    pattern = (
        r"Save version `{ev.save}` did not match load version "
        r"`{ev.load}` for SparkDataSet\(.+\)".format(ev=exact_version)
    )
    with pytest.warns(UserWarning, match=pattern):
        ds_local.save(sample_spark_df)

def test_prevent_overwrite(self, tmp_path, version, sample_spark_df):
    versioned_local = SparkDataSet(
        filepath=str(tmp_path / FILENAME),
        version=version,
        # second save should fail even in overwrite mode
        save_args={"mode": "overwrite"},
    )
    versioned_local.save(sample_spark_df)

    pattern = (
        r"Save path `.+` for SparkDataSet\(.+\) must not exist "
        r"if versioning is enabled"
    )
    with pytest.raises(DataSetError, match=pattern):
        versioned_local.save(sample_spark_df)

def test_save_partition(self, tmp_path, sample_spark_df):
    # To verify partitioning, this test partitions the data by one
    # of the columns and then checks whether the partitioned column
    # is added to the save path
    filepath = tmp_path / "test_data"
    spark_data_set = SparkDataSet(
        filepath=str(filepath),
        save_args={"mode": "overwrite", "partitionBy": ["name"]},
    )

    spark_data_set.save(sample_spark_df)

    expected_path = filepath / "name=Alex"

    assert expected_path.exists()

def test_save_partition():
    # To verify partitioning, this test partitions the data by one of the
    # columns and then checks whether the partitioned column is added to
    # the save path
    with tempfile.TemporaryDirectory() as temp_dir:
        temp_path = join(temp_dir, "test_data")
        spark_data_set = SparkDataSet(
            filepath=temp_path,
            save_args={"mode": "overwrite", "partitionBy": ["name"]},
        )
        spark_df = _get_sample_spark_data_frame()
        spark_data_set.save(spark_df)

        expected_path = join(temp_path, "name=Alex")

        assert exists(expected_path)

def test_save_parquet(self, tmp_path, sample_spark_df):
    # To cross-check the correct Spark save operation we save to
    # a single Spark partition and retrieve it with Kedro
    # ParquetLocalDataSet
    temp_dir = tmp_path / "test_data"
    spark_data_set = SparkDataSet(
        filepath=str(temp_dir), save_args={"compression": "none"}
    )
    spark_df = sample_spark_df.coalesce(1)
    spark_data_set.save(spark_df)

    single_parquet = [
        f for f in temp_dir.iterdir() if f.is_file() and f.name.startswith("part")
    ][0]

    local_parquet_data_set = ParquetLocalDataSet(filepath=str(single_parquet))
    pandas_df = local_parquet_data_set.load()

    assert pandas_df[pandas_df["name"] == "Bob"]["age"].iloc[0] == 12

def test_save_version_warning(self, mocker): exact_version = Version("2019-01-01T23.59.59.999Z", "2019-01-02T00.00.00.000Z") ds_s3 = SparkDataSet( filepath="s3a://{}/{}".format(BUCKET_NAME, FILENAME), version=exact_version, credentials=AWS_CREDENTIALS, ) mocked_spark_df = mocker.Mock() pattern = ( r"Save version `{ev.save}` did not match load version " r"`{ev.load}` for SparkDataSet\(.+\)".format(ev=exact_version) ) with pytest.warns(UserWarning, match=pattern): ds_s3.save(mocked_spark_df) mocked_spark_df.write.save.assert_called_once_with( "s3a://{b}/{f}/{v}/{f}".format( b=BUCKET_NAME, f=FILENAME, v=exact_version.save ), "parquet", )
def test_save_parquet():
    # To cross-check the correct Spark save operation we save to
    # a single Spark partition and retrieve it with Kedro
    # ParquetLocalDataSet
    with tempfile.TemporaryDirectory() as temp_dir:
        temp_path = join(temp_dir, "test_data")
        spark_data_set = SparkDataSet(
            filepath=temp_path, save_args={"compression": "none"}
        )
        spark_df = _get_sample_spark_data_frame().coalesce(1)
        spark_data_set.save(spark_df)

        single_parquet = [
            join(temp_path, f) for f in listdir(temp_path) if f.startswith("part")
        ][0]

        local_parquet_data_set = ParquetLocalDataSet(filepath=single_parquet)
        pandas_df = local_parquet_data_set.load()

        assert pandas_df[pandas_df["name"] == "Bob"]["age"].iloc[0] == 12

def test_prevent_overwrite(self, mocker, version):
    hdfs_status = mocker.patch(
        "kedro.contrib.io.pyspark.spark_data_set.InsecureClient.status"
    )
    hdfs_status.return_value = True

    versioned_hdfs = SparkDataSet(
        filepath="hdfs://{}".format(HDFS_PREFIX), version=version
    )
    mocked_spark_df = mocker.Mock()

    pattern = (
        r"Save path `.+` for SparkDataSet\(.+\) must not exist "
        r"if versioning is enabled"
    )
    with pytest.raises(DataSetError, match=pattern):
        versioned_hdfs.save(mocked_spark_df)

    hdfs_status.assert_called_once_with(
        "{fn}/{f}/{v}/{f}".format(fn=FOLDER_NAME, v=version.save, f=FILENAME),
        strict=False,
    )
    mocked_spark_df.write.save.assert_not_called()

def test_save_overwrite():
    # This test is split into two sections.
    # Firstly, it writes a data frame twice and expects the second
    # write to fail. Secondly, it writes a data frame twice with
    # overwrite mode enabled.
    with tempfile.TemporaryDirectory() as temp_dir:
        temp_path = join(temp_dir, "test_data")
        spark_data_set = SparkDataSet(filepath=temp_path)
        spark_df = _get_sample_spark_data_frame()

        spark_data_set.save(spark_df)
        # only the second save should raise, so the raises block is
        # narrowed to that call
        with pytest.raises(DataSetError):
            spark_data_set.save(spark_df)

    with tempfile.TemporaryDirectory() as temp_dir:
        temp_path = join(temp_dir, "test_data")
        spark_data_set = SparkDataSet(
            filepath=temp_path, save_args={"mode": "overwrite"}
        )
        spark_df = _get_sample_spark_data_frame()

        spark_data_set.save(spark_df)
        spark_data_set.save(spark_df)

# declared as a pytest fixture, which its tmp_path/sample_spark_df
# parameters and usage as a prepared input dataset indicate
@pytest.fixture
def spark_in(tmp_path, sample_spark_df):
    spark_in = SparkDataSet(filepath=str(tmp_path / "input"))
    spark_in.save(sample_spark_df)
    return spark_in

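# The class-based tests above receive `sample_spark_df` and `version` as
# pytest fixtures that are not shown in this section. A minimal sketch,
# assuming a `spark_session` fixture exists and reusing the rows the tests
# assert on:
@pytest.fixture
def sample_spark_df(spark_session):
    schema = StructType(
        [
            StructField("name", StringType(), True),
            StructField("age", IntegerType(), True),
        ]
    )
    return spark_session.createDataFrame([("Alex", 31), ("Bob", 12)], schema)


@pytest.fixture
def version():
    # load version left as None so the latest available version is loaded
    return Version(None, generate_timestamp())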