# Imports shared by the test snippets below; the kedro.contrib paths match
# the patch target "kedro.contrib.io.pyspark.spark_data_set" used further down.
import tempfile
from os import listdir
from os.path import exists, join
from pathlib import Path

import pytest
from pyspark.sql.functions import col
from pyspark.sql.utils import AnalysisException

from kedro.contrib.io.pyspark import SparkDataSet
from kedro.io import CSVLocalDataSet, DataSetError, ParquetLocalDataSet, Version
from kedro.io.core import generate_timestamp


def test_save_options_csv():
    # To cross-check the correct Spark save operation, we save to a single
    # Spark partition in CSV format and retrieve it with Kedro's
    # CSVLocalDataSet
    with tempfile.TemporaryDirectory() as temp_dir:
        temp_path = join(temp_dir, "test_data")
        spark_data_set = SparkDataSet(
            filepath=temp_path,
            file_format="csv",
            save_args={"sep": "|", "header": True},
        )
        spark_df = _get_sample_spark_data_frame().coalesce(1)
        spark_data_set.save(spark_df)

        single_csv_file = [
            join(temp_path, f) for f in listdir(temp_path) if f.endswith("csv")
        ][0]

        csv_local_data_set = CSVLocalDataSet(
            filepath=single_csv_file, load_args={"sep": "|"}
        )
        pandas_df = csv_local_data_set.load()

        assert pandas_df[pandas_df["name"] == "Alex"]["age"][0] == 31

def test_load_parquet(self, tmp_path, sample_pandas_df):
    temp_path = str(tmp_path / "data")
    local_parquet_set = ParquetLocalDataSet(filepath=temp_path)
    local_parquet_set.save(sample_pandas_df)
    spark_data_set = SparkDataSet(filepath=temp_path)
    spark_df = spark_data_set.load()
    assert spark_df.count() == 4

def test_save_options_csv(self, tmp_path, sample_spark_df):
    # To cross-check the correct Spark save operation, we save to a single
    # Spark partition in CSV format and retrieve it with Kedro's
    # CSVLocalDataSet
    temp_dir = Path(str(tmp_path / "test_data"))
    spark_data_set = SparkDataSet(
        filepath=str(temp_dir),
        file_format="csv",
        save_args={"sep": "|", "header": True},
    )
    spark_df = sample_spark_df.coalesce(1)
    spark_data_set.save(spark_df)

    single_csv_file = [
        f for f in temp_dir.iterdir() if f.is_file() and f.suffix == ".csv"
    ][0]

    csv_local_data_set = CSVLocalDataSet(
        filepath=str(single_csv_file), load_args={"sep": "|"}
    )
    pandas_df = csv_local_data_set.load()

    assert pandas_df[pandas_df["name"] == "Alex"]["age"][0] == 31

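# The tests in this section reference sample-data helpers and pytest fixtures
# (_get_sample_pandas_data_frame, _get_sample_spark_data_frame,
# sample_pandas_df, sample_spark_df, version) that are defined elsewhere in
# the test module. A minimal sketch of what they are assumed to look like,
# inferred from the assertions (four rows; Alex is 31, Bob is 12 -- the
# remaining two rows are placeholders):
import pandas as pd
from pyspark.sql import SparkSession


def _get_sample_pandas_data_frame() -> pd.DataFrame:
    # Capitalised column names: the Spark CSV load tests filter on col("Name")
    return pd.DataFrame(
        {"Name": ["Alex", "Bob", "Clarke", "Dave"], "Age": [31, 12, 65, 29]}
    )


def _get_sample_spark_data_frame():
    # Lowercase column names: the save tests partition by "name" and assert
    # on pandas_df["name"] after reading the Spark output back
    spark = SparkSession.builder.getOrCreate()
    return spark.createDataFrame(
        [("Alex", 31), ("Bob", 12), ("Clarke", 65), ("Dave", 29)],
        ["name", "age"],
    )


@pytest.fixture
def sample_pandas_df():
    return _get_sample_pandas_data_frame()


@pytest.fixture
def sample_spark_df():
    return _get_sample_spark_data_frame()


@pytest.fixture
def version():
    # No load version, freshly generated save version
    return Version(None, generate_timestamp())
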
def test_save(self, mocker, version):
    hdfs_status = mocker.patch(
        "kedro.contrib.io.pyspark.spark_data_set.InsecureClient.status"
    )
    hdfs_status.return_value = None

    versioned_hdfs = SparkDataSet(
        filepath="hdfs://{}".format(HDFS_PREFIX), version=version
    )

    mocked_load_version = mocker.MagicMock()
    mocked_load_version.__eq__.return_value = True
    # need to mock _get_load_path() call inside _save();
    # also need _get_load_path() to return a load version that
    # _check_paths_consistency() will be happy with (hence mocking __eq__)
    mocker.patch.object(
        versioned_hdfs, "_get_load_path", return_value=mocked_load_version
    )

    mocked_spark_df = mocker.Mock()
    versioned_hdfs.save(mocked_spark_df)

    hdfs_status.assert_called_once_with(
        "{fn}/{f}/{v}/{f}".format(fn=FOLDER_NAME, v=version.save, f=FILENAME),
        strict=False,
    )
    mocked_spark_df.write.save.assert_called_once_with(
        "hdfs://{fn}/{f}/{v}/{f}".format(
            fn=FOLDER_NAME, v=version.save, f=FILENAME
        ),
        "parquet",
    )

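# The HDFS and S3 tests reference module-level constants (FOLDER_NAME,
# FILENAME, HDFS_PREFIX, BUCKET_NAME, AWS_CREDENTIALS) defined elsewhere in
# the test module. Plausible definitions with illustrative values only; the
# versioned tests then expect save/load paths of the form
# "<prefix>/<filename>/<version>/<filename>":
FILENAME = "test.parquet"
FOLDER_NAME = "fake_folder"
HDFS_PREFIX = "{}/{}".format(FOLDER_NAME, FILENAME)
BUCKET_NAME = "test_bucket"
AWS_CREDENTIALS = {
    "aws_access_key_id": "FAKE_ACCESS_KEY",
    "aws_secret_access_key": "FAKE_SECRET_KEY",
}
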
def test_save(self, mocker, version):
    hdfs_status = mocker.patch(
        "kedro.contrib.io.pyspark.spark_data_set.InsecureClient.status"
    )
    hdfs_status.return_value = None

    versioned_hdfs = SparkDataSet(
        filepath="hdfs://{}".format(HDFS_PREFIX), version=version
    )

    # need _lookup_load_version() call to return a load version that
    # matches save version due to consistency check in versioned_hdfs.save()
    mocker.patch.object(
        versioned_hdfs, "_lookup_load_version", return_value=version.save
    )

    mocked_spark_df = mocker.Mock()
    versioned_hdfs.save(mocked_spark_df)

    hdfs_status.assert_called_once_with(
        "{fn}/{f}/{v}/{f}".format(fn=FOLDER_NAME, v=version.save, f=FILENAME),
        strict=False,
    )
    mocked_spark_df.write.save.assert_called_once_with(
        "hdfs://{fn}/{f}/{v}/{f}".format(
            fn=FOLDER_NAME, v=version.save, f=FILENAME
        ),
        "parquet",
    )

def test_save_version_warning(self, mocker):
    exact_version = Version("2019-01-01T23.59.59.999Z", "2019-01-02T00.00.00.000Z")
    versioned_hdfs = SparkDataSet(
        filepath="hdfs://{}".format(HDFS_PREFIX), version=exact_version
    )
    mocker.patch.object(versioned_hdfs, "_exists_function", return_value=False)
    mocked_spark_df = mocker.Mock()

    pattern = (
        r"Save path `{fn}/{f}/{sv}/{f}` did not match load path "
        r"`{fn}/{f}/{lv}/{f}` for SparkDataSet\(.+\)".format(
            fn=FOLDER_NAME, f=FILENAME, sv=exact_version.save, lv=exact_version.load
        )
    )
    with pytest.warns(UserWarning, match=pattern):
        versioned_hdfs.save(mocked_spark_df)

    mocked_spark_df.write.save.assert_called_once_with(
        "hdfs://{fn}/{f}/{sv}/{f}".format(
            fn=FOLDER_NAME, f=FILENAME, sv=exact_version.save
        ),
        "parquet",
    )

def test_save_overwrite_mode(self, tmp_path, sample_spark_df):
    # Saves a data frame twice in overwrite mode; without mode="overwrite"
    # the second save would fail because the path already exists
    filepath = str(tmp_path / "test_data")
    spark_data_set = SparkDataSet(
        filepath=filepath, save_args={"mode": "overwrite"}
    )

    spark_data_set.save(sample_spark_df)
    spark_data_set.save(sample_spark_df)

def test_load_parquet(tmpdir):
    temp_path = str(tmpdir.join("data"))
    pandas_df = _get_sample_pandas_data_frame()
    local_parquet_set = ParquetLocalDataSet(filepath=temp_path)
    local_parquet_set.save(pandas_df)
    spark_data_set = SparkDataSet(filepath=temp_path)
    spark_df = spark_data_set.load()
    assert spark_df.count() == 4

def test_load_options_csv(self, tmp_path, sample_pandas_df):
    filepath = str(tmp_path / "data")
    local_csv_data_set = CSVLocalDataSet(filepath=filepath)
    local_csv_data_set.save(sample_pandas_df)
    spark_data_set = SparkDataSet(
        filepath=filepath, file_format="csv", load_args={"header": True}
    )
    spark_df = spark_data_set.load()
    assert spark_df.filter(col("Name") == "Alex").count() == 1

def test_load_exact(self, tmp_path, sample_spark_df):
    ts = generate_timestamp()
    ds_local = SparkDataSet(
        filepath=str(tmp_path / FILENAME), version=Version(ts, ts)
    )

    ds_local.save(sample_spark_df)
    reloaded = ds_local.load()

    assert reloaded.exceptAll(sample_spark_df).count() == 0

def test_load_options_csv(tmpdir):
    temp_path = str(tmpdir.join("data"))
    pandas_df = _get_sample_pandas_data_frame()
    local_csv_data_set = CSVLocalDataSet(filepath=temp_path)
    local_csv_data_set.save(pandas_df)
    spark_data_set = SparkDataSet(
        filepath=temp_path, file_format="csv", load_args={"header": True}
    )
    spark_df = spark_data_set.load()
    assert spark_df.filter(col("Name") == "Alex").count() == 1

def test_repr(self, version):
    versioned_hdfs = SparkDataSet(
        filepath="hdfs://{}".format(HDFS_PREFIX), version=version
    )
    assert "filepath=hdfs://" in str(versioned_hdfs)
    assert "version=Version(load=None, save='{}')".format(version.save) in str(
        versioned_hdfs
    )

    dataset_hdfs = SparkDataSet(filepath="hdfs://{}".format(HDFS_PREFIX))
    assert "filepath=hdfs://" in str(dataset_hdfs)
    assert "version=" not in str(dataset_hdfs)

def test_save_version_warning(self, tmp_path, sample_spark_df):
    exact_version = Version("2019-01-01T23.59.59.999Z", "2019-01-02T00.00.00.000Z")
    ds_local = SparkDataSet(
        filepath=str(tmp_path / FILENAME), version=exact_version
    )

    pattern = (
        r"Save version `{ev.save}` did not match load version "
        r"`{ev.load}` for SparkDataSet\(.+\)".format(ev=exact_version)
    )
    with pytest.warns(UserWarning, match=pattern):
        ds_local.save(sample_spark_df)

def test_exists_raises_error(monkeypatch):
    # exists should raise all errors except for
    # AnalysisExceptions clearly indicating a missing file
    def faulty_get_spark():
        raise AnalysisException("Other Exception", [])

    spark_data_set = SparkDataSet(filepath="")
    monkeypatch.setattr(spark_data_set, "_get_spark", faulty_get_spark)

    with pytest.raises(DataSetError) as error:
        spark_data_set.exists()
    assert "Other Exception" in str(error.value)

def test_load_exact(self, mocker):
    ts = generate_timestamp()
    versioned_hdfs = SparkDataSet(
        filepath="hdfs://{}".format(HDFS_PREFIX), version=Version(ts, None)
    )
    get_spark = mocker.patch.object(versioned_hdfs, "_get_spark")

    versioned_hdfs.load()

    get_spark.return_value.read.load.assert_called_once_with(
        "hdfs://{fn}/{f}/{v}/{f}".format(fn=FOLDER_NAME, f=FILENAME, v=ts),
        "parquet",
    )

def test_exists_raises_error(self, mocker):
    # exists should raise all errors except for
    # AnalysisExceptions clearly indicating a missing file
    spark_data_set = SparkDataSet(filepath="")
    mocker.patch.object(
        spark_data_set,
        "_get_spark",
        side_effect=AnalysisException("Other Exception", []),
    )

    with pytest.raises(DataSetError, match="Other Exception"):
        spark_data_set.exists()

def test_no_version(self, mocker, version):
    hdfs_walk = mocker.patch(
        "kedro.contrib.io.pyspark.spark_data_set.InsecureClient.walk"
    )
    hdfs_walk.return_value = []

    versioned_hdfs = SparkDataSet(
        filepath="hdfs://{}".format(HDFS_PREFIX), version=version
    )

    pattern = r"Did not find any versions for SparkDataSet\(.+\)"
    with pytest.raises(DataSetError, match=pattern):
        versioned_hdfs.load()

    hdfs_walk.assert_called_once_with(HDFS_PREFIX)

def test_load_exact(self, mocker):
    ts = generate_timestamp()
    ds_s3 = SparkDataSet(
        filepath="s3a://{}/{}".format(BUCKET_NAME, FILENAME),
        version=Version(ts, None),
        credentials=AWS_CREDENTIALS,
    )
    get_spark = mocker.patch.object(ds_s3, "_get_spark")

    ds_s3.load()

    get_spark.return_value.read.load.assert_called_once_with(
        "s3a://{b}/{f}/{v}/{f}".format(b=BUCKET_NAME, f=FILENAME, v=ts),
        "parquet",
    )

def test_hdfs_warning(self, version):
    pattern = (
        "HDFS filesystem support for versioned SparkDataSet is in beta "
        "and uses `hdfs.client.InsecureClient`, please use with caution"
    )
    with pytest.warns(UserWarning, match=pattern):
        SparkDataSet(filepath="hdfs://{}".format(HDFS_PREFIX), version=version)

def test_save_partition(self, tmp_path, sample_spark_df):
    # To verify partitioning, this test partitions the data by one of the
    # columns and then checks whether the partition column is added to
    # the save path
    filepath = Path(str(tmp_path / "test_data"))
    spark_data_set = SparkDataSet(
        filepath=str(filepath),
        save_args={"mode": "overwrite", "partitionBy": ["name"]},
    )

    spark_data_set.save(sample_spark_df)

    expected_path = filepath / "name=Alex"
    assert expected_path.exists()

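# For reference, with partitionBy=["name"] Spark writes one sub-directory per
# distinct value of the partition column, which is what the two partitioning
# tests assert on. Illustrative layout only (part-file names vary):
#
#   test_data/
#       _SUCCESS
#       name=Alex/part-00000-<uuid>.snappy.parquet
#       name=Bob/part-00000-<uuid>.snappy.parquet
#       ...
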
def test_repr(self, versioned_dataset_local, tmp_path, version):
    assert "version=Version(load=None, save='{}')".format(version.save) in str(
        versioned_dataset_local
    )

    dataset_local = SparkDataSet(filepath=str(tmp_path / FILENAME))
    assert "version=" not in str(dataset_local)

from pyspark.sql import DataFrame
from pyspark.sql import functions as f


def customer_dimension(
    df_customer: DataFrame, df_region: DataFrame, df_nation: DataFrame
) -> DataFrame:
    """Build the customer dimension table.

    Args:
        df_customer: Customer source data frame.
        df_region: Region source data frame.
        df_nation: Nation source data frame.

    Returns:
        Customer dimension Spark data frame.
    """
    column_name = [
        "Customer_Key",
        "Customer_Name",
        "Customer_Address",
        "Customer_Phone",
        "Customer_Account_Balance",
        "Customer_Marketing_Segment",
        "Customer_Country_Name",
        "Customer_Region_Name",
    ]
    df_customer_dimension = (
        df_customer.join(df_nation, df_customer.C_NATIONKEY == df_nation.N_NATIONKEY)
        .join(df_region, df_region.R_REGIONKEY == df_nation.N_REGIONKEY)
        .drop(
            "C_NATIONKEY",
            "N_NATIONKEY",
            "N_REGIONKEY",
            "R_REGIONKEY",
            "C_COMMENT",
            "N_COMMENT",
            "R_COMMENT",
        )
        .toDF(*column_name)
        # Bucket the account balance at the 4000 and 8000 thresholds
        .withColumn(
            "Customer_Account_Balance_Group",
            f.when(f.col("Customer_Account_Balance") < 4000, "Less than 4000")
            .when(f.col("Customer_Account_Balance") < 8000, "Between 4000 and 8000")
            .otherwise("More than 8000"),
        )
    )
    return df_customer_dimension

def supplier_dimension(
    df_supplier: DataFrame, df_region: DataFrame, df_nation: DataFrame
) -> DataFrame:
    """Build the supplier dimension table.

    Args:
        df_supplier: Supplier source data frame.
        df_region: Region source data frame.
        df_nation: Nation source data frame.

    Returns:
        Supplier dimension Spark data frame.
    """
    column_name = [
        "Supplier_Key",
        "Supplier_Name",
        "Supplier_Address",
        "Supplier_Phone",
        "Supplier_Account_Balance",
        "Supplier_Country_Name",
        "Supplier_Region_Name",
    ]
    df_supplier_dimension = (
        df_supplier.join(df_nation, df_supplier.S_NATIONKEY == df_nation.N_NATIONKEY)
        .join(df_region, df_region.R_REGIONKEY == df_nation.N_REGIONKEY)
        .drop(
            "S_COMMENT",
            "N_COMMENT",
            "R_COMMENT",
            "N_NATIONKEY",
            "N_REGIONKEY",
            "S_NATIONKEY",
            "R_REGIONKEY",
        )
        .toDF(*column_name)
    )
    return df_supplier_dimension

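# A hypothetical wiring of the two dimension nodes above into a Kedro
# pipeline. The dataset names ("customers", "regions", "nations", ...) are
# assumptions for illustration; they would normally be declared in the
# project's data catalog.
from kedro.pipeline import Pipeline, node

dimension_pipeline = Pipeline(
    [
        node(
            customer_dimension,
            ["customers", "regions", "nations"],
            "customer_dim",
        ),
        node(
            supplier_dimension,
            ["suppliers", "regions", "nations"],
            "supplier_dim",
        ),
    ]
)
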
def test_str_representation():
    with tempfile.NamedTemporaryFile() as temp_data_file:
        spark_data_set = SparkDataSet(
            filepath=temp_data_file.name,
            file_format="csv",
            load_args={"header": True},
        )
        assert "SparkDataSet" in str(spark_data_set)
        assert "filepath={}".format(temp_data_file.name) in str(spark_data_set)

def test_repr(self, versioned_dataset_s3, version):
    assert "filepath=s3a://" in str(versioned_dataset_s3)
    assert "version=Version(load=None, save='{}')".format(version.save) in str(
        versioned_dataset_s3
    )

    dataset_s3 = SparkDataSet(filepath="s3a://{}/{}".format(BUCKET_NAME, FILENAME))
    assert "filepath=s3a://" in str(dataset_s3)
    assert "version=" not in str(dataset_s3)

def test_s3n_warning(self, version):
    pattern = (
        "`s3n` filesystem has now been deprecated by Spark, "
        "please consider switching to `s3a`"
    )
    with pytest.warns(DeprecationWarning, match=pattern):
        SparkDataSet(
            filepath="s3n://{}/{}".format(BUCKET_NAME, FILENAME), version=version
        )

def test_save_partition():
    # To verify partitioning, this test partitions the data by one of the
    # columns and checks whether the partition column is added to the
    # save path
    with tempfile.TemporaryDirectory() as temp_dir:
        temp_path = join(temp_dir, "test_data")
        spark_data_set = SparkDataSet(
            filepath=temp_path,
            save_args={"mode": "overwrite", "partitionBy": ["name"]},
        )
        spark_df = _get_sample_spark_data_frame()
        spark_data_set.save(spark_df)

        expected_path = join(temp_path, "name=Alex")
        assert exists(expected_path)

def test_save_parquet(self, tmp_path, sample_spark_df):
    # To cross-check the correct Spark save operation, we save to a single
    # Spark partition and retrieve it with Kedro's ParquetLocalDataSet
    temp_dir = Path(str(tmp_path / "test_data"))
    spark_data_set = SparkDataSet(
        filepath=str(temp_dir), save_args={"compression": "none"}
    )
    spark_df = sample_spark_df.coalesce(1)
    spark_data_set.save(spark_df)

    single_parquet = [
        f for f in temp_dir.iterdir() if f.is_file() and f.name.startswith("part")
    ][0]

    local_parquet_data_set = ParquetLocalDataSet(filepath=str(single_parquet))
    pandas_df = local_parquet_data_set.load()

    assert pandas_df[pandas_df["name"] == "Bob"]["age"].iloc[0] == 12

def test_save_version_warning(self, mocker):
    exact_version = Version("2019-01-01T23.59.59.999Z", "2019-01-02T00.00.00.000Z")
    ds_s3 = SparkDataSet(
        filepath="s3a://{}/{}".format(BUCKET_NAME, FILENAME),
        version=exact_version,
        credentials=AWS_CREDENTIALS,
    )
    mocked_spark_df = mocker.Mock()

    pattern = (
        r"Save version `{ev.save}` did not match load version "
        r"`{ev.load}` for SparkDataSet\(.+\)".format(ev=exact_version)
    )
    with pytest.warns(UserWarning, match=pattern):
        ds_s3.save(mocked_spark_df)

    mocked_spark_df.write.save.assert_called_once_with(
        "s3a://{b}/{f}/{v}/{f}".format(
            b=BUCKET_NAME, f=FILENAME, v=exact_version.save
        ),
        "parquet",
    )

def test_save_parquet():
    # To cross-check the correct Spark save operation, we save to a single
    # Spark partition and retrieve it with Kedro's ParquetLocalDataSet
    with tempfile.TemporaryDirectory() as temp_dir:
        temp_path = join(temp_dir, "test_data")
        spark_data_set = SparkDataSet(
            filepath=temp_path, save_args={"compression": "none"}
        )
        spark_df = _get_sample_spark_data_frame().coalesce(1)
        spark_data_set.save(spark_df)

        single_parquet = [
            join(temp_path, f) for f in listdir(temp_path) if f.startswith("part")
        ][0]

        local_parquet_data_set = ParquetLocalDataSet(filepath=single_parquet)
        pandas_df = local_parquet_data_set.load()

        assert pandas_df[pandas_df["name"] == "Bob"]["age"].iloc[0] == 12

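# For context, a minimal sketch of how SparkDataSet might be used outside the
# tests, via Kedro's DataCatalog. The entry name and file path are assumptions
# for illustration only:
from kedro.io import DataCatalog

catalog = DataCatalog(
    {
        "customers": SparkDataSet(
            filepath="data/01_raw/customer.parquet", file_format="parquet"
        )
    }
)
df = catalog.load("customers")  # returns a pyspark.sql.DataFrame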