    def test_load_exact(self, mocker):
        ts = generate_timestamp()
        versioned_hdfs = SparkDataSet(filepath="hdfs://{}".format(HDFS_PREFIX),
                                      version=Version(ts, None))
        get_spark = mocker.patch.object(versioned_hdfs, "_get_spark")

        versioned_hdfs.load()

        get_spark.return_value.read.load.assert_called_once_with(
            "hdfs://{fn}/{f}/{v}/{f}".format(fn=FOLDER_NAME, f=FILENAME, v=ts),
            "parquet",
        )
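
# HDFS_PREFIX, FOLDER_NAME, FILENAME and generate_timestamp are defined
# outside this snippet. generate_timestamp is a real Kedro helper
# (kedro.io.core); the constants below are hypothetical values reconstructed
# from the load-path assertion above, which only matches the dataset filepath
# if HDFS_PREFIX equals FOLDER_NAME/FILENAME:
from kedro.io.core import generate_timestamp

FOLDER_NAME = "fake_folder"
FILENAME = "test.parquet"
HDFS_PREFIX = "{}/{}".format(FOLDER_NAME, FILENAME)
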
    def test_no_version(self, mocker, version):
        hdfs_walk = mocker.patch(
            "kedro.extras.datasets.spark.spark_dataset.InsecureClient.walk")
        hdfs_walk.return_value = []

        versioned_hdfs = SparkDataSet(filepath="hdfs://{}".format(HDFS_PREFIX),
                                      version=version)

        pattern = r"Did not find any versions for SparkDataSet\(.+\)"
        with pytest.raises(DataSetError, match=pattern):
            versioned_hdfs.load()

        hdfs_walk.assert_called_once_with(HDFS_PREFIX)
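
# The `version` fixture consumed above is not part of this snippet. A minimal
# sketch of what it could look like (hypothetical; the real fixture may be
# parametrised differently):
import pytest
from kedro.io import Version

@pytest.fixture
def version():
    # Version(load, save): None for load means "resolve the latest version"
    return Version(None, None)
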
    def test_load_exact(self, mocker):
        ts = generate_timestamp()
        ds_s3 = SparkDataSet(
            filepath="s3a://{}/{}".format(BUCKET_NAME, FILENAME),
            version=Version(ts, None),
            credentials=AWS_CREDENTIALS,
        )
        get_spark = mocker.patch.object(ds_s3, "_get_spark")

        ds_s3.load()

        get_spark.return_value.read.load.assert_called_once_with(
            "s3a://{b}/{f}/{v}/{f}".format(b=BUCKET_NAME, f=FILENAME, v=ts),
            "parquet",
        )
    def test_load_parquet(self, tmp_path, sample_pandas_df):
        temp_path = str(tmp_path / "data")
        local_parquet_set = ParquetDataSet(filepath=temp_path)
        local_parquet_set.save(sample_pandas_df)
        spark_data_set = SparkDataSet(filepath=temp_path)
        spark_df = spark_data_set.load()
        assert spark_df.count() == 4
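
# ParquetDataSet here is kedro.extras.datasets.pandas.ParquetDataSet, and
# sample_pandas_df is a module-level fixture. Given the count() == 4 assertion
# above and the Name == "Alex" filter in the CSV test below, a compatible
# sketch (assumed, not the verbatim fixture) is:
import pandas as pd
import pytest
from kedro.extras.datasets.pandas import ParquetDataSet

@pytest.fixture
def sample_pandas_df():
    return pd.DataFrame(
        {"Name": ["Alex", "Bob", "Clarke", "Dave"], "Age": [31, 12, 65, 29]}
    )
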
    def test_load_exact(self, tmp_path, sample_spark_df):
        ts = generate_timestamp()
        ds_dbfs = SparkDataSet(filepath="/dbfs" + str(tmp_path / FILENAME),
                               version=Version(ts, ts))

        ds_dbfs.save(sample_spark_df)
        reloaded = ds_dbfs.load()

        assert reloaded.exceptAll(sample_spark_df).count() == 0
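
# sample_spark_df is likewise a fixture from the test module. A sketch
# consistent with sample_pandas_df above (assumed, not verbatim):
import pytest
from pyspark.sql import SparkSession
from pyspark.sql.types import IntegerType, StringType, StructField, StructType

@pytest.fixture
def sample_spark_df():
    schema = StructType(
        [
            StructField("Name", StringType(), True),
            StructField("Age", IntegerType(), True),
        ]
    )
    data = [("Alex", 31), ("Bob", 12), ("Clarke", 65), ("Dave", 29)]
    return SparkSession.builder.getOrCreate().createDataFrame(data, schema)
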
    def test_load_options_csv(self, tmp_path, sample_pandas_df):
        filepath = str(tmp_path / "data")
        local_csv_data_set = CSVDataSet(filepath=filepath)
        local_csv_data_set.save(sample_pandas_df)
        spark_data_set = SparkDataSet(filepath=filepath,
                                      file_format="csv",
                                      load_args={"header": True})
        spark_df = spark_data_set.load()
        assert spark_df.filter(col("Name") == "Alex").count() == 1
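
# CSVDataSet here is kedro.extras.datasets.pandas.CSVDataSet, and `col` is
# pyspark.sql.functions.col; header=True makes Spark read the first CSV row as
# column names, so the "Name" column exists to filter on:
from kedro.extras.datasets.pandas import CSVDataSet
from pyspark.sql.functions import col
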
    def test_load_exact(self, tmp_path, sample_spark_df):
        ts = generate_timestamp()
        ds_local = SparkDataSet(filepath=(tmp_path / FILENAME).as_posix(),
                                version=Version(ts, ts))

        ds_local.save(sample_spark_df)
        reloaded = ds_local.load()

        assert reloaded.exceptAll(sample_spark_df).count() == 0
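
# Version is a (load, save) namedtuple from kedro.io; pinning both fields to
# one generated timestamp, as the exact-version tests above do, makes the
# round trip deterministic. A sketch of the pieces those tests assume:
from kedro.io import Version
from kedro.io.core import generate_timestamp

ts = generate_timestamp()        # e.g. "2019-01-02T01.00.00.000Z"
exact_version = Version(ts, ts)  # load exactly what was just saved
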
    def test_load_latest(self, mocker, version):
        mocker.patch(
            "kedro.extras.datasets.spark.spark_dataset.InsecureClient.status",
            return_value=True,
        )
        hdfs_walk = mocker.patch(
            "kedro.extras.datasets.spark.spark_dataset.InsecureClient.walk")
        hdfs_walk.return_value = HDFS_FOLDER_STRUCTURE

        versioned_hdfs = SparkDataSet(filepath="hdfs://{}".format(HDFS_PREFIX),
                                      version=version)
        get_spark = mocker.patch.object(versioned_hdfs, "_get_spark")

        versioned_hdfs.load()

        hdfs_walk.assert_called_once_with(HDFS_PREFIX)
        get_spark.return_value.read.load.assert_called_once_with(
            "hdfs://{fn}/{f}/{v}/{f}".format(fn=FOLDER_NAME,
                                             v="2019-01-02T01.00.00.000Z",
                                             f=FILENAME),
            "parquet",
        )
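
# HDFS_FOLDER_STRUCTURE stubs what hdfs.InsecureClient.walk yields:
# (path, dirs, files) tuples, os.walk-style. A minimal sketch consistent with
# the "2019-01-02T01.00.00.000Z" latest-version assertion above (assumed, not
# the verbatim constant):
HDFS_FOLDER_STRUCTURE = [
    (
        HDFS_PREFIX,
        ["2019-01-01T23.59.59.999Z", "2019-01-02T01.00.00.000Z"],
        [],
    ),
    ("{}/2019-01-02T01.00.00.000Z".format(HDFS_PREFIX), [], [FILENAME]),
]
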
    def test_load(self, tmp_path, sample_spark_df):
        filepath = (tmp_path / "test_data").as_posix()
        spark_delta_ds = SparkDataSet(filepath=filepath, file_format="delta")
        spark_delta_ds.save(sample_spark_df)
        loaded_with_spark = spark_delta_ds.load()
        assert loaded_with_spark.exceptAll(sample_spark_df).count() == 0

        delta_ds = DeltaTableDataSet(filepath=filepath)
        delta_table = delta_ds.load()

        assert isinstance(delta_table, DeltaTable)
        loaded_with_deltalake = delta_table.toDF()
        assert loaded_with_deltalake.exceptAll(loaded_with_spark).count() == 0
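
# The Delta names this last snippet relies on resolve to delta-spark and
# Kedro imports of roughly this shape:
from delta.tables import DeltaTable
from kedro.extras.datasets.spark import DeltaTableDataSet, SparkDataSet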