Example #1
def test_save_options_csv():
    # To cross-check that the Spark save operation worked correctly, we save
    # to a single Spark partition in CSV format and read the file back with
    # Kedro's CSVLocalDataSet
    with tempfile.TemporaryDirectory() as temp_dir:
        temp_path = join(temp_dir, "test_data")
        spark_data_set = SparkDataSet(
            filepath=temp_path,
            file_format="csv",
            save_args={
                "sep": "|",
                "header": True
            },
        )
        spark_df = _get_sample_spark_data_frame().coalesce(1)
        spark_data_set.save(spark_df)

        single_csv_file = [
            join(temp_path, f) for f in listdir(temp_path) if f.endswith("csv")
        ][0]

        csv_local_data_set = CSVLocalDataSet(filepath=single_csv_file,
                                             load_args={"sep": "|"})
        pandas_df = csv_local_data_set.load()

        assert pandas_df[pandas_df["name"] == "Alex"]["age"][0] == 31
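The module-level examples on this page call a _get_sample_spark_data_frame() helper that is not shown here. A minimal sketch of what such a helper might look like is given below; the SparkSession handling and the exact rows are assumptions, chosen only so that the assertions on Alex (age 31) and Bob (age 12) hold.

from pyspark.sql import SparkSession


def _get_sample_spark_data_frame():
    # Hypothetical helper: a tiny DataFrame whose rows match the values
    # asserted in these examples (Alex/31, Bob/12). The real helper may
    # contain more rows and an explicit schema.
    spark = SparkSession.builder.getOrCreate()
    data = [("Alex", 31), ("Bob", 12)]
    return spark.createDataFrame(data, ["name", "age"])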
Example #2
    def test_save(self, mocker, version):
        hdfs_status = mocker.patch(
            "kedro.contrib.io.pyspark.spark_data_set.InsecureClient.status")
        hdfs_status.return_value = None

        versioned_hdfs = SparkDataSet(filepath="hdfs://{}".format(HDFS_PREFIX),
                                      version=version)

        mocked_load_version = mocker.MagicMock()
        mocked_load_version.__eq__.return_value = True
        # need to mock _get_load_path() call inside _save()
        # also need _get_load_path() to return a load version that
        # _check_paths_consistency() will be happy with (hence mocking __eq__)
        mocker.patch.object(versioned_hdfs,
                            "_get_load_path",
                            return_value=mocked_load_version)

        mocked_spark_df = mocker.Mock()
        versioned_hdfs.save(mocked_spark_df)

        hdfs_status.assert_called_once_with(
            "{fn}/{f}/{v}/{f}".format(fn=FOLDER_NAME,
                                      v=version.save,
                                      f=FILENAME),
            strict=False,
        )
        mocked_spark_df.write.save.assert_called_once_with(
            "hdfs://{fn}/{f}/{v}/{f}".format(fn=FOLDER_NAME,
                                             v=version.save,
                                             f=FILENAME),
            "parquet",
        )
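Examples #2 to #4 and #18 also assume module-level constants and a version fixture defined elsewhere in the test module. A hedged sketch of those definitions follows; all concrete values are assumptions, and the import paths may differ between Kedro versions.

import pytest

from kedro.io import Version
from kedro.io.core import generate_timestamp

# Hypothetical values: any folder/file pair works, as long as HDFS_PREFIX
# equals "<FOLDER_NAME>/<FILENAME>", which the hdfs_status assertions rely on.
FOLDER_NAME = "test_folder"
FILENAME = "test.parquet"
HDFS_PREFIX = "{}/{}".format(FOLDER_NAME, FILENAME)


@pytest.fixture
def version():
    # Load version is left as None (resolve to latest); the save version
    # is a freshly generated timestamp.
    return Version(load=None, save=generate_timestamp())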
Example #3
    def test_save(self, mocker, version):
        hdfs_status = mocker.patch(
            "kedro.contrib.io.pyspark.spark_data_set.InsecureClient.status"
        )
        hdfs_status.return_value = None

        versioned_hdfs = SparkDataSet(
            filepath="hdfs://{}".format(HDFS_PREFIX), version=version
        )

        # need _lookup_load_version() call to return a load version that
        # matches save version due to consistency check in versioned_hdfs.save()
        mocker.patch.object(
            versioned_hdfs, "_lookup_load_version", return_value=version.save
        )

        mocked_spark_df = mocker.Mock()
        versioned_hdfs.save(mocked_spark_df)

        hdfs_status.assert_called_once_with(
            "{fn}/{f}/{v}/{f}".format(fn=FOLDER_NAME, v=version.save, f=FILENAME),
            strict=False,
        )
        mocked_spark_df.write.save.assert_called_once_with(
            "hdfs://{fn}/{f}/{v}/{f}".format(
                fn=FOLDER_NAME, v=version.save, f=FILENAME
            ),
            "parquet",
        )
Example #4
    def test_save_version_warning(self, mocker):
        exact_version = Version("2019-01-01T23.59.59.999Z",
                                "2019-01-02T00.00.00.000Z")
        versioned_hdfs = SparkDataSet(filepath="hdfs://{}".format(HDFS_PREFIX),
                                      version=exact_version)
        mocker.patch.object(versioned_hdfs,
                            "_exists_function",
                            return_value=False)
        mocked_spark_df = mocker.Mock()

        pattern = (r"Save path `{fn}/{f}/{sv}/{f}` did not match load path "
                   r"`{fn}/{f}/{lv}/{f}` for SparkDataSet\(.+\)".format(
                       fn=FOLDER_NAME,
                       f=FILENAME,
                       sv=exact_version.save,
                       lv=exact_version.load))

        with pytest.warns(UserWarning, match=pattern):
            versioned_hdfs.save(mocked_spark_df)
        mocked_spark_df.write.save.assert_called_once_with(
            "hdfs://{fn}/{f}/{sv}/{f}".format(fn=FOLDER_NAME,
                                              f=FILENAME,
                                              sv=exact_version.save),
            "parquet",
        )
Example #5
    def test_save_options_csv(self, tmp_path, sample_spark_df):
        # To cross-check that the Spark save operation worked correctly, we
        # save to a single Spark partition in CSV format and read the file
        # back with Kedro's CSVLocalDataSet
        temp_dir = Path(str(tmp_path / "test_data"))
        spark_data_set = SparkDataSet(
            filepath=str(temp_dir),
            file_format="csv",
            save_args={
                "sep": "|",
                "header": True
            },
        )
        spark_df = sample_spark_df.coalesce(1)
        spark_data_set.save(spark_df)

        single_csv_file = [
            f for f in temp_dir.iterdir() if f.is_file() and f.suffix == ".csv"
        ][0]

        csv_local_data_set = CSVLocalDataSet(filepath=str(single_csv_file),
                                             load_args={"sep": "|"})
        pandas_df = csv_local_data_set.load()

        assert pandas_df[pandas_df["name"] == "Alex"]["age"][0] == 31
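The class-based examples take sample_spark_df as a pytest fixture rather than calling a helper directly. A sketch of such a fixture, assuming a spark_session fixture that yields an active SparkSession, could look like this:

@pytest.fixture
def sample_spark_df(spark_session):
    # Hypothetical fixture: same sample rows as the helper sketched under
    # Example #1, built from an assumed spark_session fixture.
    data = [("Alex", 31), ("Bob", 12)]
    return spark_session.createDataFrame(data, ["name", "age"])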
Example #6
    def test_save_overwrite_fail(self, tmp_path, sample_spark_df):
        # Writes a data frame twice and expects the second save to fail.
        filepath = str(tmp_path / "test_data")
        spark_data_set = SparkDataSet(filepath=filepath)
        spark_data_set.save(sample_spark_df)

        with pytest.raises(DataSetError):
            spark_data_set.save(sample_spark_df)
Example #7
    def test_save_overwrite_mode(self, tmp_path, sample_spark_df):
        # Writes a data frame twice in overwrite mode; the second save
        # overwrites the first without error.
        filepath = str(tmp_path / "test_data")
        spark_data_set = SparkDataSet(filepath=filepath,
                                      save_args={"mode": "overwrite"})

        spark_data_set.save(sample_spark_df)
        spark_data_set.save(sample_spark_df)
Example #8
    def test_exists(self, file_format, tmp_path, sample_spark_df):
        filepath = str(tmp_path / "test_data")
        spark_data_set = SparkDataSet(filepath=filepath, file_format=file_format)

        assert not spark_data_set.exists()

        spark_data_set.save(sample_spark_df)
        assert spark_data_set.exists()
Example #9
    def test_load_exact(self, tmp_path, sample_spark_df):
        ts = generate_timestamp()
        ds_local = SparkDataSet(filepath=str(tmp_path / FILENAME),
                                version=Version(ts, ts))

        ds_local.save(sample_spark_df)
        reloaded = ds_local.load()

        assert reloaded.exceptAll(sample_spark_df).count() == 0
Example #10
def test_exists(file_format):
    with tempfile.TemporaryDirectory() as temp_dir:
        temp_path = join(temp_dir, "test_data")
        spark_data_set = SparkDataSet(filepath=temp_path,
                                      file_format=file_format)
        spark_df = _get_sample_spark_data_frame().coalesce(1)

        assert not spark_data_set.exists()

        spark_data_set.save(spark_df)
        assert spark_data_set.exists()
Example #11
    def test_save_version_warning(self, tmp_path, sample_spark_df):
        exact_version = Version("2019-01-01T23.59.59.999Z",
                                "2019-01-02T00.00.00.000Z")
        ds_local = SparkDataSet(filepath=str(tmp_path / FILENAME),
                                version=exact_version)

        pattern = (r"Save version `{ev.save}` did not match load version "
                   r"`{ev.load}` for SparkDataSet\(.+\)".format(
                       ev=exact_version))
        with pytest.warns(UserWarning, match=pattern):
            ds_local.save(sample_spark_df)
Example #12
    def test_prevent_overwrite(self, tmp_path, version, sample_spark_df):
        versioned_local = SparkDataSet(
            filepath=str(tmp_path / FILENAME),
            version=version,
            # second save should fail even in overwrite mode
            save_args={"mode": "overwrite"},
        )
        versioned_local.save(sample_spark_df)

        pattern = (r"Save path `.+` for SparkDataSet\(.+\) must not exist "
                   r"if versioning is enabled")
        with pytest.raises(DataSetError, match=pattern):
            versioned_local.save(sample_spark_df)
Example #13
    def test_save_partition(self, tmp_path, sample_spark_df):
        # To verify partitioning, this test partitions the data by one of
        # the columns and then checks whether the partitioned column is
        # added to the save path

        filepath = Path(str(tmp_path / "test_data"))
        spark_data_set = SparkDataSet(
            filepath=str(filepath),
            save_args={"mode": "overwrite", "partitionBy": ["name"]},
        )

        spark_data_set.save(sample_spark_df)

        expected_path = filepath / "name=Alex"

        assert expected_path.exists()
Example #14
def test_save_partition():
    # To verify partitioning, this test partitions the data by one of the
    # columns and checks whether the partitioned column is added to the
    # save path

    with tempfile.TemporaryDirectory() as temp_dir:
        temp_path = join(temp_dir, "test_data")
        spark_data_set = SparkDataSet(filepath=temp_path,
                                      save_args={
                                          "mode": "overwrite",
                                          "partitionBy": ["name"]
                                      })

        spark_df = _get_sample_spark_data_frame()
        spark_data_set.save(spark_df)

        expected_path = join(temp_path, "name=Alex")

        assert exists(expected_path)
Example #15
    def test_save_parquet(self, tmp_path, sample_spark_df):
        # To cross-check that the Spark save operation worked correctly, we
        # save to a single Spark partition and read the file back with
        # Kedro's ParquetLocalDataSet
        temp_dir = Path(str(tmp_path / "test_data"))
        spark_data_set = SparkDataSet(
            filepath=str(temp_dir), save_args={"compression": "none"}
        )
        spark_df = sample_spark_df.coalesce(1)
        spark_data_set.save(spark_df)

        single_parquet = [
            f for f in temp_dir.iterdir() if f.is_file() and f.name.startswith("part")
        ][0]

        local_parquet_data_set = ParquetLocalDataSet(filepath=str(single_parquet))

        pandas_df = local_parquet_data_set.load()

        assert pandas_df[pandas_df["name"] == "Bob"]["age"].iloc[0] == 12
Example #16
    def test_save_version_warning(self, mocker):
        exact_version = Version("2019-01-01T23.59.59.999Z", "2019-01-02T00.00.00.000Z")
        ds_s3 = SparkDataSet(
            filepath="s3a://{}/{}".format(BUCKET_NAME, FILENAME),
            version=exact_version,
            credentials=AWS_CREDENTIALS,
        )
        mocked_spark_df = mocker.Mock()

        pattern = (
            r"Save version `{ev.save}` did not match load version "
            r"`{ev.load}` for SparkDataSet\(.+\)".format(ev=exact_version)
        )
        with pytest.warns(UserWarning, match=pattern):
            ds_s3.save(mocked_spark_df)
        mocked_spark_df.write.save.assert_called_once_with(
            "s3a://{b}/{f}/{v}/{f}".format(
                b=BUCKET_NAME, f=FILENAME, v=exact_version.save
            ),
            "parquet",
        )
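Example #16 additionally references a bucket name and AWS credentials defined at module level. Because the DataFrame is a mock and nothing is actually written to S3, dummy values are enough; the exact credential keys expected by SparkDataSet are an assumption here.

# Hypothetical values: only used to build a syntactically valid s3a:// path
# and a credentials mapping; no real bucket is touched in this test.
BUCKET_NAME = "test_bucket"
AWS_CREDENTIALS = {
    "aws_access_key_id": "FAKE_ACCESS_KEY",
    "aws_secret_access_key": "FAKE_SECRET_KEY",
}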
Example #17
def test_save_parquet():
    # To cross-check that the Spark save operation worked correctly, we save
    # to a single Spark partition and read the file back with Kedro's
    # ParquetLocalDataSet
    with tempfile.TemporaryDirectory() as temp_dir:
        temp_path = join(temp_dir, "test_data")
        spark_data_set = SparkDataSet(filepath=temp_path,
                                      save_args={"compression": "none"})
        spark_df = _get_sample_spark_data_frame().coalesce(1)
        spark_data_set.save(spark_df)

        single_parquet = [
            join(temp_path, f) for f in listdir(temp_path)
            if f.startswith("part")
        ][0]

        local_parquet_data_set = ParquetLocalDataSet(filepath=single_parquet)

        pandas_df = local_parquet_data_set.load()

        assert pandas_df[pandas_df["name"] == "Bob"]["age"].iloc[0] == 12
Example #18
    def test_prevent_overwrite(self, mocker, version):
        hdfs_status = mocker.patch(
            "kedro.contrib.io.pyspark.spark_data_set.InsecureClient.status")
        hdfs_status.return_value = True

        versioned_hdfs = SparkDataSet(filepath="hdfs://{}".format(HDFS_PREFIX),
                                      version=version)

        mocked_spark_df = mocker.Mock()

        pattern = (r"Save path `.+` for SparkDataSet\(.+\) must not exist "
                   r"if versioning is enabled")
        with pytest.raises(DataSetError, match=pattern):
            versioned_hdfs.save(mocked_spark_df)

        hdfs_status.assert_called_once_with(
            "{fn}/{f}/{v}/{f}".format(fn=FOLDER_NAME,
                                      v=version.save,
                                      f=FILENAME),
            strict=False,
        )
        mocked_spark_df.write.save.assert_not_called()
Example #19
def test_save_overwrite():
    # This test is split into two sections.
    # First, it writes a data frame twice and expects the second save to fail.
    # Second, it writes the data frame twice in overwrite mode, which succeeds.

    with pytest.raises(DataSetError):
        with tempfile.TemporaryDirectory() as temp_dir:
            temp_path = join(temp_dir, "test_data")
            spark_data_set = SparkDataSet(filepath=temp_path)

            spark_df = _get_sample_spark_data_frame()
            spark_data_set.save(spark_df)
            spark_data_set.save(spark_df)

    with tempfile.TemporaryDirectory() as temp_dir:
        temp_path = join(temp_dir, "test_data")
        spark_data_set = SparkDataSet(filepath=temp_path,
                                      save_args={"mode": "overwrite"})

        spark_df = _get_sample_spark_data_frame()
        spark_data_set.save(spark_df)
        spark_data_set.save(spark_df)
Example #20
def spark_in(tmp_path, sample_spark_df):
    spark_in = SparkDataSet(filepath=str(tmp_path / "input"))
    spark_in.save(sample_spark_df)
    return spark_in
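Example #20 is itself a fixture rather than a test. A hedged sketch of a test that consumes it, assuming only the spark_in definition above, might look like this:

def test_spark_in_roundtrip(spark_in):
    # Hypothetical consumer: the fixture has already saved the sample data,
    # so loading it back should yield a non-empty DataFrame.
    reloaded = spark_in.load()
    assert reloaded.count() > 0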