Example #1
    def test_load_exact(self, mocker):
        ts = generate_timestamp()
        versioned_hdfs = SparkDataSet(filepath="hdfs://{}".format(HDFS_PREFIX),
                                      version=Version(ts, None))
        get_spark = mocker.patch.object(versioned_hdfs, "_get_spark")

        versioned_hdfs.load()

        get_spark.return_value.read.load.assert_called_once_with(
            "hdfs://{fn}/{f}/{v}/{f}".format(fn=FOLDER_NAME, f=FILENAME, v=ts),
            "parquet",
        )
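
Examples #1, #2 and #9 lean on module-level names that this excerpt does not show. A minimal sketch of that setup, assuming the relationship implied by the asserted load path (the FOLDER_NAME and FILENAME values are hypothetical; generate_timestamp, Version and DataSetError ship with kedro):

# Sketch of the assumed test-module setup; the real suite may differ.
from kedro.contrib.io.pyspark.spark_data_set import SparkDataSet
from kedro.io import DataSetError, Version
from kedro.io.core import generate_timestamp

FOLDER_NAME = "test_folder"  # hypothetical value
FILENAME = "test.parquet"    # hypothetical value
# The asserted load path is hdfs://FOLDER_NAME/FILENAME/<version>/FILENAME,
# so the prefix combines folder and file name:
HDFS_PREFIX = "{}/{}".format(FOLDER_NAME, FILENAME)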
Example #2
    def test_no_version(self, mocker, version):
        hdfs_walk = mocker.patch(
            "kedro.contrib.io.pyspark.spark_data_set.InsecureClient.walk")
        hdfs_walk.return_value = []

        versioned_hdfs = SparkDataSet(filepath="hdfs://{}".format(HDFS_PREFIX),
                                      version=version)

        pattern = r"Did not find any versions for SparkDataSet\(.+\)"
        with pytest.raises(DataSetError, match=pattern):
            versioned_hdfs.load()

        hdfs_walk.assert_called_once_with(HDFS_PREFIX)
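
The version argument here is a pytest fixture. A plausible sketch, given that a Version whose load member is None makes the data set resolve the latest available version on load:

import pytest
from kedro.io import Version

@pytest.fixture
def version():
    # Hypothetical fixture: no load/save version pinned, so load()
    # falls back to version resolution (the walk() call on HDFS).
    return Version(load=None, save=None)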
Example #3
    def test_load_exact(self, mocker):
        ts = generate_timestamp()
        ds_s3 = SparkDataSet(
            filepath="s3a://{}/{}".format(BUCKET_NAME, FILENAME),
            version=Version(ts, None),
            credentials=AWS_CREDENTIALS,
        )
        get_spark = mocker.patch.object(ds_s3, "_get_spark")

        ds_s3.load()

        get_spark.return_value.read.load.assert_called_once_with(
            "s3a://{b}/{f}/{v}/{f}".format(b=BUCKET_NAME, f=FILENAME, v=ts),
            "parquet")
Example #4
    def test_load_parquet(self, tmp_path, sample_pandas_df):
        temp_path = str(tmp_path / "data")
        local_parquet_set = ParquetLocalDataSet(filepath=temp_path)
        local_parquet_set.save(sample_pandas_df)
        spark_data_set = SparkDataSet(filepath=temp_path)
        spark_df = spark_data_set.load()
        assert spark_df.count() == 4
Example #5
def test_load_parquet(tmpdir):
    temp_path = str(tmpdir.join("data"))
    pandas_df = _get_sample_pandas_data_frame()
    local_parquet_set = ParquetLocalDataSet(filepath=temp_path)
    local_parquet_set.save(pandas_df)
    spark_data_set = SparkDataSet(filepath=temp_path)
    spark_df = spark_data_set.load()
    assert spark_df.count() == 4
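
Examples #4 and #5 are the same pandas-to-Spark round trip in two styles: one takes a sample_pandas_df fixture, the other builds the frame through a module-level helper. Judging from the assertions (four rows here, exactly one row with Name == "Alex" in Examples #7 and #8), the sample data could look like this (the column contents are assumptions consistent with those assertions):

import pandas as pd
import pytest

def _get_sample_pandas_data_frame():
    # Hypothetical sample data: 4 rows, exactly one of them named "Alex".
    return pd.DataFrame(
        {"Name": ["Alex", "Bob", "Clarke", "Dave"], "Age": [31, 12, 65, 29]}
    )

@pytest.fixture
def sample_pandas_df():
    return _get_sample_pandas_data_frame()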
Example #6
    def test_load_exact(self, tmp_path, sample_spark_df):
        ts = generate_timestamp()
        ds_local = SparkDataSet(filepath=str(tmp_path / FILENAME),
                                version=Version(ts, ts))

        ds_local.save(sample_spark_df)
        reloaded = ds_local.load()

        assert reloaded.exceptAll(sample_spark_df).count() == 0
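
The exceptAll assertion (available in Spark 2.4+) checks that the reloaded frame contains no rows, duplicates included, beyond those in the original. The sample_spark_df fixture is presumably the Spark counterpart of the pandas sample; a minimal sketch:

import pytest
from pyspark.sql import SparkSession

@pytest.fixture
def sample_spark_df(sample_pandas_df):
    # Hypothetical fixture: convert the sample pandas frame to Spark.
    spark = SparkSession.builder.getOrCreate()
    return spark.createDataFrame(sample_pandas_df)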
Example #7
    def test_load_options_csv(self, tmp_path, sample_pandas_df):
        filepath = str(tmp_path / "data")
        local_csv_data_set = CSVLocalDataSet(filepath=filepath)
        local_csv_data_set.save(sample_pandas_df)
        spark_data_set = SparkDataSet(filepath=filepath,
                                      file_format="csv",
                                      load_args={"header": True})
        spark_df = spark_data_set.load()
        assert spark_df.filter(col("Name") == "Alex").count() == 1
Example #8
def test_load_options_csv(tmpdir):
    temp_path = str(tmpdir.join("data"))
    pandas_df = _get_sample_pandas_data_frame()
    local_csv_data_set = CSVLocalDataSet(filepath=temp_path)
    local_csv_data_set.save(pandas_df)
    spark_data_set = SparkDataSet(filepath=temp_path,
                                  file_format="csv",
                                  load_args={"header": True})
    spark_df = spark_data_set.load()
    assert spark_df.filter(col("Name") == "Alex").count() == 1
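
Note that load_args={"header": True} is what makes the col("Name") filter work: by default Spark's CSV reader treats the header row as data and auto-names the columns _c0, _c1, and so on. (col is imported from pyspark.sql.functions.)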
Example #9
    def test_load_latest(self, mocker, version):
        mocker.patch(
            "kedro.contrib.io.pyspark.spark_data_set.InsecureClient.status",
            return_value=True,
        )
        hdfs_walk = mocker.patch(
            "kedro.contrib.io.pyspark.spark_data_set.InsecureClient.walk")
        hdfs_walk.return_value = HDFS_FOLDER_STRUCTURE

        versioned_hdfs = SparkDataSet(filepath="hdfs://{}".format(HDFS_PREFIX),
                                      version=version)
        get_spark = mocker.patch.object(versioned_hdfs, "_get_spark")

        versioned_hdfs.load()

        hdfs_walk.assert_called_once_with(HDFS_PREFIX)
        get_spark.return_value.read.load.assert_called_once_with(
            "hdfs://{fn}/{f}/{v}/{f}".format(fn=FOLDER_NAME,
                                             v="2019-01-02T01.00.00.000Z",
                                             f=FILENAME),
            "parquet",
        )
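
HDFS_FOLDER_STRUCTURE stands in for the output of InsecureClient.walk, which, like os.walk, yields (path, directories, files) tuples. For the test to pick 2019-01-02T01.00.00.000Z as the latest version, the mocked structure could look roughly like this (the older folder name is an assumption):

# Hypothetical walk() output: two timestamped version folders under the
# data set path; test_load_latest expects the newer one to be chosen.
HDFS_FOLDER_STRUCTURE = [
    (
        HDFS_PREFIX,
        ["2019-01-01T23.59.59.999Z", "2019-01-02T01.00.00.000Z"],
        [],
    ),
]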