def test_load_exact(self, mocker):
    """An exact load version should resolve to the versioned HDFS path."""
    exact_ts = generate_timestamp()
    hdfs_ds = SparkDataSet(
        filepath="hdfs://{}".format(HDFS_PREFIX),
        version=Version(exact_ts, None),
    )
    spark_mock = mocker.patch.object(hdfs_ds, "_get_spark")

    hdfs_ds.load()

    # Spark must be asked to read exactly the versioned parquet location.
    spark_mock.return_value.read.load.assert_called_once_with(
        "hdfs://{fn}/{f}/{v}/{f}".format(fn=FOLDER_NAME, f=FILENAME, v=exact_ts),
        "parquet",
    )
def test_no_version(self, mocker, version):
    """Loading raises DataSetError when the HDFS walk finds no versions."""
    walk_mock = mocker.patch(
        "kedro.extras.datasets.spark.spark_dataset.InsecureClient.walk"
    )
    walk_mock.return_value = []  # simulate an empty versions directory
    hdfs_ds = SparkDataSet(
        filepath="hdfs://{}".format(HDFS_PREFIX), version=version
    )

    pattern = r"Did not find any versions for SparkDataSet\(.+\)"
    with pytest.raises(DataSetError, match=pattern):
        hdfs_ds.load()

    walk_mock.assert_called_once_with(HDFS_PREFIX)
def test_load_exact(self, mocker):
    """An exact load version should resolve to the versioned S3 path."""
    exact_ts = generate_timestamp()
    s3_ds = SparkDataSet(
        filepath="s3a://{}/{}".format(BUCKET_NAME, FILENAME),
        version=Version(exact_ts, None),
        credentials=AWS_CREDENTIALS,
    )
    spark_mock = mocker.patch.object(s3_ds, "_get_spark")

    s3_ds.load()

    # The read must target the version-expanded S3 key.
    spark_mock.return_value.read.load.assert_called_once_with(
        "s3a://{b}/{f}/{v}/{f}".format(b=BUCKET_NAME, f=FILENAME, v=exact_ts),
        "parquet",
    )
def test_load_parquet(self, tmp_path, sample_pandas_df):
    """A parquet file written via pandas is readable through SparkDataSet."""
    data_path = str(tmp_path / "data")

    # Produce the fixture file with the pandas-backed dataset first.
    ParquetDataSet(filepath=data_path).save(sample_pandas_df)

    # Then round-trip it through Spark and check the row count survived.
    loaded_df = SparkDataSet(filepath=data_path).load()
    assert loaded_df.count() == 4
def test_load_exact(self, tmp_path, sample_spark_df):
    """Save/load round-trip on DBFS with matching save and load versions."""
    exact_ts = generate_timestamp()
    # NOTE: the "/dbfs" prefix marks the Databricks filesystem mount point.
    dbfs_ds = SparkDataSet(
        filepath="/dbfs" + str(tmp_path / FILENAME),
        version=Version(exact_ts, exact_ts),
    )

    dbfs_ds.save(sample_spark_df)
    round_tripped = dbfs_ds.load()

    # exceptAll returning no rows means the reloaded frame equals the original.
    assert round_tripped.exceptAll(sample_spark_df).count() == 0
def test_load_options_csv(self, tmp_path, sample_pandas_df):
    """load_args (here: header=True) are forwarded to the Spark CSV reader."""
    csv_path = str(tmp_path / "data")

    # Write the fixture through the pandas-backed CSV dataset.
    CSVDataSet(filepath=csv_path).save(sample_pandas_df)

    spark_ds = SparkDataSet(
        filepath=csv_path,
        file_format="csv",
        load_args={"header": True},
    )
    loaded_df = spark_ds.load()

    # With header=True the "Name" column exists and is filterable.
    assert loaded_df.filter(col("Name") == "Alex").count() == 1
def test_load_exact(self, tmp_path, sample_spark_df):
    """Save/load round-trip on the local filesystem with an exact version."""
    exact_ts = generate_timestamp()
    local_ds = SparkDataSet(
        filepath=(tmp_path / FILENAME).as_posix(),
        version=Version(exact_ts, exact_ts),
    )

    local_ds.save(sample_spark_df)
    round_tripped = local_ds.load()

    # exceptAll returning no rows means the reloaded frame equals the original.
    assert round_tripped.exceptAll(sample_spark_df).count() == 0
def test_load_latest(self, mocker, version):
    """Without an explicit load version, the newest HDFS version is picked."""
    # Pretend the versioned path exists on HDFS.
    mocker.patch(
        "kedro.extras.datasets.spark.spark_dataset.InsecureClient.status",
        return_value=True,
    )
    walk_mock = mocker.patch(
        "kedro.extras.datasets.spark.spark_dataset.InsecureClient.walk"
    )
    walk_mock.return_value = HDFS_FOLDER_STRUCTURE

    hdfs_ds = SparkDataSet(
        filepath="hdfs://{}".format(HDFS_PREFIX), version=version
    )
    spark_mock = mocker.patch.object(hdfs_ds, "_get_spark")

    hdfs_ds.load()

    walk_mock.assert_called_once_with(HDFS_PREFIX)
    # The latest timestamp from the mocked folder listing must be chosen.
    spark_mock.return_value.read.load.assert_called_once_with(
        "hdfs://{fn}/{f}/{v}/{f}".format(
            fn=FOLDER_NAME, v="2019-01-02T01.00.00.000Z", f=FILENAME
        ),
        "parquet",
    )
def test_load(self, tmp_path, sample_spark_df):
    """Data saved as Spark delta is readable via DeltaTableDataSet too."""
    delta_path = (tmp_path / "test_data").as_posix()

    # Write the fixture out in delta format with the Spark dataset.
    spark_ds = SparkDataSet(filepath=delta_path, file_format="delta")
    spark_ds.save(sample_spark_df)
    via_spark = spark_ds.load()
    assert via_spark.exceptAll(sample_spark_df).count() == 0

    # The same location must load as a DeltaTable through DeltaTableDataSet.
    table = DeltaTableDataSet(filepath=delta_path).load()
    assert isinstance(table, DeltaTable)

    # Both access paths must yield identical contents.
    via_deltalake = table.toDF()
    assert via_deltalake.exceptAll(via_spark).count() == 0