def test_load_exact(self, mocker):
    """Loading a pinned version reads the exact versioned HDFS path."""
    timestamp = generate_timestamp()
    data_set = SparkDataSet(
        filepath="hdfs://{}".format(HDFS_PREFIX),
        version=Version(timestamp, None),
    )
    spark_mock = mocker.patch.object(data_set, "_get_spark")

    data_set.load()

    # Versioned layout nests the filename twice: <file>/<version>/<file>
    expected_path = "hdfs://{fn}/{f}/{v}/{f}".format(
        fn=FOLDER_NAME, f=FILENAME, v=timestamp
    )
    spark_mock.return_value.read.load.assert_called_once_with(
        expected_path, "parquet"
    )
def test_no_version(self, mocker, version):
    """When the HDFS walk finds no versions, loading raises DataSetError."""
    walk_mock = mocker.patch(
        "kedro.contrib.io.pyspark.spark_data_set.InsecureClient.walk"
    )
    walk_mock.return_value = []
    data_set = SparkDataSet(
        filepath="hdfs://{}".format(HDFS_PREFIX), version=version
    )

    pattern = r"Did not find any versions for SparkDataSet\(.+\)"
    with pytest.raises(DataSetError, match=pattern):
        data_set.load()

    walk_mock.assert_called_once_with(HDFS_PREFIX)
def test_load_exact(self, mocker):
    """Loading a pinned version reads the exact versioned S3 path."""
    timestamp = generate_timestamp()
    data_set = SparkDataSet(
        filepath="s3a://{}/{}".format(BUCKET_NAME, FILENAME),
        version=Version(timestamp, None),
        credentials=AWS_CREDENTIALS,
    )
    spark_mock = mocker.patch.object(data_set, "_get_spark")

    data_set.load()

    # Versioned layout nests the filename twice: <file>/<version>/<file>
    expected_path = "s3a://{b}/{f}/{v}/{f}".format(
        b=BUCKET_NAME, f=FILENAME, v=timestamp
    )
    spark_mock.return_value.read.load.assert_called_once_with(
        expected_path, "parquet"
    )
def test_load_parquet(self, tmp_path, sample_pandas_df):
    """A parquet file written via ParquetLocalDataSet loads back through Spark."""
    target = str(tmp_path / "data")
    ParquetLocalDataSet(filepath=target).save(sample_pandas_df)

    loaded = SparkDataSet(filepath=target).load()

    # sample_pandas_df fixture is expected to hold four rows
    assert loaded.count() == 4
def test_load_parquet(tmpdir):
    """A parquet file written via ParquetLocalDataSet loads back through Spark."""
    target = str(tmpdir.join("data"))
    source_df = _get_sample_pandas_data_frame()
    ParquetLocalDataSet(filepath=target).save(source_df)

    loaded = SparkDataSet(filepath=target).load()

    # The sample frame helper is expected to produce four rows
    assert loaded.count() == 4
def test_load_exact(self, tmp_path, sample_spark_df):
    """Saving and loading with the same pinned version round-trips the data."""
    timestamp = generate_timestamp()
    data_set = SparkDataSet(
        filepath=str(tmp_path / FILENAME),
        version=Version(timestamp, timestamp),
    )

    data_set.save(sample_spark_df)
    reloaded = data_set.load()

    # exceptAll yields an empty frame iff reloaded contains all original rows
    assert reloaded.exceptAll(sample_spark_df).count() == 0
def test_load_options_csv(self, tmp_path, sample_pandas_df):
    """load_args (header=True) are forwarded to Spark's CSV reader."""
    target = str(tmp_path / "data")
    CSVLocalDataSet(filepath=target).save(sample_pandas_df)

    data_set = SparkDataSet(
        filepath=target, file_format="csv", load_args={"header": True}
    )
    loaded = data_set.load()

    # Header parsing must succeed for the "Name" column to be addressable
    assert loaded.filter(col("Name") == "Alex").count() == 1
def test_load_options_csv(tmpdir):
    """load_args (header=True) are forwarded to Spark's CSV reader."""
    target = str(tmpdir.join("data"))
    source_df = _get_sample_pandas_data_frame()
    CSVLocalDataSet(filepath=target).save(source_df)

    data_set = SparkDataSet(
        filepath=target, file_format="csv", load_args={"header": True}
    )
    loaded = data_set.load()

    # Header parsing must succeed for the "Name" column to be addressable
    assert loaded.filter(col("Name") == "Alex").count() == 1
def test_load_latest(self, mocker, version):
    """With no version pinned, the newest version found on HDFS is loaded."""
    mocker.patch(
        "kedro.contrib.io.pyspark.spark_data_set.InsecureClient.status",
        return_value=True,
    )
    walk_mock = mocker.patch(
        "kedro.contrib.io.pyspark.spark_data_set.InsecureClient.walk"
    )
    walk_mock.return_value = HDFS_FOLDER_STRUCTURE
    data_set = SparkDataSet(
        filepath="hdfs://{}".format(HDFS_PREFIX), version=version
    )
    spark_mock = mocker.patch.object(data_set, "_get_spark")

    data_set.load()

    walk_mock.assert_called_once_with(HDFS_PREFIX)
    # Latest timestamp in the HDFS_FOLDER_STRUCTURE fixture
    expected_path = "hdfs://{fn}/{f}/{v}/{f}".format(
        fn=FOLDER_NAME, v="2019-01-02T01.00.00.000Z", f=FILENAME
    )
    spark_mock.return_value.read.load.assert_called_once_with(
        expected_path, "parquet"
    )