Example #1
def test_save_options_csv():
    # To cross-check the Spark save operation we save to a single Spark
    # partition in CSV format and read it back with Kedro's CSVLocalDataSet
    with tempfile.TemporaryDirectory() as temp_dir:
        temp_path = join(temp_dir, "test_data")
        spark_data_set = SparkDataSet(
            filepath=temp_path,
            file_format="csv",
            save_args={
                "sep": "|",
                "header": True
            },
        )
        spark_df = _get_sample_spark_data_frame().coalesce(1)
        spark_data_set.save(spark_df)

        single_csv_file = [
            join(temp_path, f) for f in listdir(temp_path) if f.endswith("csv")
        ][0]

        csv_local_data_set = CSVLocalDataSet(filepath=single_csv_file,
                                             load_args={"sep": "|"})
        pandas_df = csv_local_data_set.load()

        assert pandas_df[pandas_df["name"] == "Alex"]["age"][0] == 31
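
The helper _get_sample_spark_data_frame() used by the function-style examples on this page is not shown here. A minimal sketch of what it might look like, assuming a four-row DataFrame with name and age columns consistent with the assertions (Alex/31, Bob/12):

# Hypothetical helper assumed by the function-style examples; the real test
# module defines its own version with its own sample data.
from pyspark.sql import SparkSession
from pyspark.sql.types import IntegerType, StringType, StructField, StructType


def _get_sample_spark_data_frame():
    schema = StructType(
        [
            StructField("name", StringType(), True),
            StructField("age", IntegerType(), True),
        ]
    )
    data = [("Alex", 31), ("Bob", 12), ("Clarke", 65), ("Dave", 29)]
    return SparkSession.builder.getOrCreate().createDataFrame(data, schema)
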
Example #2
    def test_load_parquet(self, tmp_path, sample_pandas_df):
        temp_path = str(tmp_path / "data")
        local_parquet_set = ParquetLocalDataSet(filepath=temp_path)
        local_parquet_set.save(sample_pandas_df)
        spark_data_set = SparkDataSet(filepath=temp_path)
        spark_df = spark_data_set.load()
        assert spark_df.count() == 4
Example #3
    def test_save_options_csv(self, tmp_path, sample_spark_df):
        # To cross-check the Spark save operation we save to a single Spark
        # partition in CSV format and read it back with Kedro's CSVLocalDataSet
        temp_dir = Path(str(tmp_path / "test_data"))
        spark_data_set = SparkDataSet(
            filepath=str(temp_dir),
            file_format="csv",
            save_args={
                "sep": "|",
                "header": True
            },
        )
        spark_df = sample_spark_df.coalesce(1)
        spark_data_set.save(spark_df)

        single_csv_file = [
            f for f in temp_dir.iterdir() if f.is_file() and f.suffix == ".csv"
        ][0]

        csv_local_data_set = CSVLocalDataSet(filepath=str(single_csv_file),
                                             load_args={"sep": "|"})
        pandas_df = csv_local_data_set.load()

        assert pandas_df[pandas_df["name"] == "Alex"]["age"][0] == 31
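
The class-based variants of these tests use sample_pandas_df and sample_spark_df pytest fixtures instead of module-level helpers; neither is shown on this page. A minimal sketch, assuming the same four sample rows (the column names are assumptions inferred from the assertions):

# Hypothetical pytest fixtures assumed by the class-based examples; the real
# test module defines its own versions, typically on top of a shared
# SparkSession fixture.
import pandas as pd
import pytest
from pyspark.sql import SparkSession


@pytest.fixture
def sample_pandas_df():
    # Capitalised column names match the col("Name") == "Alex" filter used in
    # the CSV load examples.
    return pd.DataFrame(
        {"Name": ["Alex", "Bob", "Clarke", "Dave"], "Age": [31, 12, 65, 29]}
    )


@pytest.fixture
def sample_spark_df():
    # Lower-case column names match the name == "Alex" / name == "Bob"
    # assertions in the save examples.
    spark = SparkSession.builder.getOrCreate()
    return spark.createDataFrame(
        [("Alex", 31), ("Bob", 12), ("Clarke", 65), ("Dave", 29)],
        ["name", "age"],
    )
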
Example #4
    def test_save(self, mocker, version):
        hdfs_status = mocker.patch(
            "kedro.contrib.io.pyspark.spark_data_set.InsecureClient.status")
        hdfs_status.return_value = None

        versioned_hdfs = SparkDataSet(filepath="hdfs://{}".format(HDFS_PREFIX),
                                      version=version)

        mocked_load_version = mocker.MagicMock()
        mocked_load_version.__eq__.return_value = True
        # need to mock _get_load_path() call inside _save()
        # also need _get_load_path() to return a load version that
        # _check_paths_consistency() will be happy with (hence mocking __eq__)
        mocker.patch.object(versioned_hdfs,
                            "_get_load_path",
                            return_value=mocked_load_version)

        mocked_spark_df = mocker.Mock()
        versioned_hdfs.save(mocked_spark_df)

        hdfs_status.assert_called_once_with(
            "{fn}/{f}/{v}/{f}".format(fn=FOLDER_NAME,
                                      v=version.save,
                                      f=FILENAME),
            strict=False,
        )
        mocked_spark_df.write.save.assert_called_once_with(
            "hdfs://{fn}/{f}/{v}/{f}".format(fn=FOLDER_NAME,
                                             v=version.save,
                                             f=FILENAME),
            "parquet",
        )
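
The versioned HDFS examples rely on module-level constants (FOLDER_NAME, FILENAME, HDFS_PREFIX) and a version pytest fixture that are not shown on this page. A minimal sketch with placeholder values; the repr tests below suggest the fixture's load version is None:

# Hypothetical constants and fixture assumed by the versioned examples; the
# real test module defines its own values.
import pytest
from kedro.io import Version
from kedro.io.core import generate_timestamp

FOLDER_NAME = "fake_folder"
FILENAME = "test.parquet"
HDFS_PREFIX = "{}/{}".format(FOLDER_NAME, FILENAME)


@pytest.fixture
def version():
    # load=None means "load the latest version"; save gets a fresh timestamp.
    return Version(load=None, save=generate_timestamp())
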
Example #5
    def test_save(self, mocker, version):
        hdfs_status = mocker.patch(
            "kedro.contrib.io.pyspark.spark_data_set.InsecureClient.status"
        )
        hdfs_status.return_value = None

        versioned_hdfs = SparkDataSet(
            filepath="hdfs://{}".format(HDFS_PREFIX), version=version
        )

        # need _lookup_load_version() call to return a load version that
        # matches save version due to consistency check in versioned_hdfs.save()
        mocker.patch.object(
            versioned_hdfs, "_lookup_load_version", return_value=version.save
        )

        mocked_spark_df = mocker.Mock()
        versioned_hdfs.save(mocked_spark_df)

        hdfs_status.assert_called_once_with(
            "{fn}/{f}/{v}/{f}".format(fn=FOLDER_NAME, v=version.save, f=FILENAME),
            strict=False,
        )
        mocked_spark_df.write.save.assert_called_once_with(
            "hdfs://{fn}/{f}/{v}/{f}".format(
                fn=FOLDER_NAME, v=version.save, f=FILENAME
            ),
            "parquet",
        )
Example #6
    def test_save_version_warning(self, mocker):
        exact_version = Version("2019-01-01T23.59.59.999Z",
                                "2019-01-02T00.00.00.000Z")
        versioned_hdfs = SparkDataSet(filepath="hdfs://{}".format(HDFS_PREFIX),
                                      version=exact_version)
        mocker.patch.object(versioned_hdfs,
                            "_exists_function",
                            return_value=False)
        mocked_spark_df = mocker.Mock()

        pattern = (r"Save path `{fn}/{f}/{sv}/{f}` did not match load path "
                   r"`{fn}/{f}/{lv}/{f}` for SparkDataSet\(.+\)".format(
                       fn=FOLDER_NAME,
                       f=FILENAME,
                       sv=exact_version.save,
                       lv=exact_version.load))

        with pytest.warns(UserWarning, match=pattern):
            versioned_hdfs.save(mocked_spark_df)
        mocked_spark_df.write.save.assert_called_once_with(
            "hdfs://{fn}/{f}/{sv}/{f}".format(fn=FOLDER_NAME,
                                              f=FILENAME,
                                              sv=exact_version.save),
            "parquet",
        )
Example #7
    def test_save_overwrite_mode(self, tmp_path, sample_spark_df):
        # Writes a data frame in overwrite mode.
        filepath = str(tmp_path / "test_data")
        spark_data_set = SparkDataSet(filepath=filepath,
                                      save_args={"mode": "overwrite"})

        spark_data_set.save(sample_spark_df)
        spark_data_set.save(sample_spark_df)
Example #8
def test_load_parquet(tmpdir):
    temp_path = str(tmpdir.join("data"))
    pandas_df = _get_sample_pandas_data_frame()
    local_parquet_set = ParquetLocalDataSet(filepath=temp_path)
    local_parquet_set.save(pandas_df)
    spark_data_set = SparkDataSet(filepath=temp_path)
    spark_df = spark_data_set.load()
    assert spark_df.count() == 4
Example #9
    def test_load_options_csv(self, tmp_path, sample_pandas_df):
        filepath = str(tmp_path / "data")
        local_csv_data_set = CSVLocalDataSet(filepath=filepath)
        local_csv_data_set.save(sample_pandas_df)
        spark_data_set = SparkDataSet(
            filepath=filepath, file_format="csv", load_args={"header": True}
        )
        spark_df = spark_data_set.load()
        assert spark_df.filter(col("Name") == "Alex").count() == 1
Example #10
    def test_load_exact(self, tmp_path, sample_spark_df):
        ts = generate_timestamp()
        ds_local = SparkDataSet(filepath=str(tmp_path / FILENAME),
                                version=Version(ts, ts))

        ds_local.save(sample_spark_df)
        reloaded = ds_local.load()

        assert reloaded.exceptAll(sample_spark_df).count() == 0
Example #11
def test_load_options_csv(tmpdir):
    temp_path = str(tmpdir.join("data"))
    pandas_df = _get_sample_pandas_data_frame()
    local_csv_data_set = CSVLocalDataSet(filepath=temp_path)
    local_csv_data_set.save(pandas_df)
    spark_data_set = SparkDataSet(filepath=temp_path,
                                  file_format="csv",
                                  load_args={"header": True})
    spark_df = spark_data_set.load()
    assert spark_df.filter(col("Name") == "Alex").count() == 1
Example #12
    def test_repr(self, version):
        versioned_hdfs = SparkDataSet(filepath="hdfs://{}".format(HDFS_PREFIX),
                                      version=version)
        assert "filepath=hdfs://" in str(versioned_hdfs)
        assert "version=Version(load=None, save='{}')".format(
            version.save) in str(versioned_hdfs)

        dataset_hdfs = SparkDataSet(filepath="hdfs://{}".format(HDFS_PREFIX))
        assert "filepath=hdfs://" in str(dataset_hdfs)
        assert "version=" not in str(dataset_hdfs)
Example #13
    def test_save_version_warning(self, tmp_path, sample_spark_df):
        exact_version = Version("2019-01-01T23.59.59.999Z",
                                "2019-01-02T00.00.00.000Z")
        ds_local = SparkDataSet(filepath=str(tmp_path / FILENAME),
                                version=exact_version)

        pattern = (r"Save version `{ev.save}` did not match load version "
                   r"`{ev.load}` for SparkDataSet\(.+\)".format(
                       ev=exact_version))
        with pytest.warns(UserWarning, match=pattern):
            ds_local.save(sample_spark_df)
Example #14
def test_exists_raises_error(monkeypatch):
    # exists should raise all errors except for
    # AnalysisExceptions clearly indicating a missing file
    def faulty_get_spark():
        raise AnalysisException("Other Exception", [])

    spark_data_set = SparkDataSet(filepath="")
    monkeypatch.setattr(spark_data_set, "_get_spark", faulty_get_spark)

    with pytest.raises(DataSetError) as error:
        spark_data_set.exists()
    assert "Other Exception" in str(error.value)
Example #15
    def test_load_exact(self, mocker):
        ts = generate_timestamp()
        versioned_hdfs = SparkDataSet(filepath="hdfs://{}".format(HDFS_PREFIX),
                                      version=Version(ts, None))
        get_spark = mocker.patch.object(versioned_hdfs, "_get_spark")

        versioned_hdfs.load()

        get_spark.return_value.read.load.assert_called_once_with(
            "hdfs://{fn}/{f}/{v}/{f}".format(fn=FOLDER_NAME, f=FILENAME, v=ts),
            "parquet",
        )
Example #16
    def test_exists_raises_error(self, mocker):
        # exists should raise all errors except for
        # AnalysisExceptions clearly indicating a missing file
        spark_data_set = SparkDataSet(filepath="")
        mocker.patch.object(
            spark_data_set,
            "_get_spark",
            side_effect=AnalysisException("Other Exception", []),
        )

        with pytest.raises(DataSetError, match="Other Exception"):
            spark_data_set.exists()
Example #17
    def test_no_version(self, mocker, version):
        hdfs_walk = mocker.patch(
            "kedro.contrib.io.pyspark.spark_data_set.InsecureClient.walk")
        hdfs_walk.return_value = []

        versioned_hdfs = SparkDataSet(filepath="hdfs://{}".format(HDFS_PREFIX),
                                      version=version)

        pattern = r"Did not find any versions for SparkDataSet\(.+\)"
        with pytest.raises(DataSetError, match=pattern):
            versioned_hdfs.load()

        hdfs_walk.assert_called_once_with(HDFS_PREFIX)
Example #18
    def test_load_exact(self, mocker):
        ts = generate_timestamp()
        ds_s3 = SparkDataSet(
            filepath="s3a://{}/{}".format(BUCKET_NAME, FILENAME),
            version=Version(ts, None),
            credentials=AWS_CREDENTIALS,
        )
        get_spark = mocker.patch.object(ds_s3, "_get_spark")

        ds_s3.load()

        get_spark.return_value.read.load.assert_called_once_with(
            "s3a://{b}/{f}/{v}/{f}".format(b=BUCKET_NAME, f=FILENAME, v=ts),
            "parquet")
Example #19
    def test_hdfs_warning(self, version):
        pattern = (
            "HDFS filesystem support for versioned SparkDataSet is in beta "
            "and uses `hdfs.client.InsecureClient`, please use with caution"
        )
        with pytest.warns(UserWarning, match=pattern):
            SparkDataSet(filepath="hdfs://{}".format(HDFS_PREFIX), version=version)
Example #20
    def test_save_partition(self, tmp_path, sample_spark_df):
        # To verify partitioning, this test partitions the data by one of the
        # columns and then checks whether the partition column is added to
        # the save path

        filepath = Path(str(tmp_path / "test_data"))
        spark_data_set = SparkDataSet(
            filepath=str(filepath),
            save_args={"mode": "overwrite", "partitionBy": ["name"]},
        )

        spark_data_set.save(sample_spark_df)

        expected_path = filepath / "name=Alex"

        assert expected_path.exists()
Example #21
    def test_repr(self, versioned_dataset_local, tmp_path, version):
        assert "version=Version(load=None, save='{}')".format(version.save) in str(
            versioned_dataset_local
        )

        dataset_local = SparkDataSet(filepath=str(tmp_path / FILENAME))
        assert "version=" not in str(dataset_local)
Example #22
def customer_dimension(df_customer: SparkDataSet, df_region: SparkDataSet,
                       df_nation: SparkDataSet) -> SparkDataSet:
    """
    Args:
        **df: Source data frames

    Returns:
        Spark data frame
    """

    column_name = [
        "Customer_Key", "Customer_Name", "Customer_Address", "Customer_Phone",
        "Customer_Account_Balance", "Customer_Marketing_Segment",
        "Customer_Country_Name", "Customer_Region_Name"
    ]
    df_customer_dimension = df_customer.join(df_nation, df_customer.C_NATIONKEY == df_nation.N_NATIONKEY) \
        .join(df_region, df_region.R_REGIONKEY == df_nation.N_REGIONKEY) \
        .drop("C_NATIONKEY",
              "N_NATIONKEY",
              "N_REGIONKEY",
              "R_REGIONKEY",
              "C_COMMENT",
              "N_COMMENT",
              "R_COMMENT") \
        .toDF(*column_name) \
        .withColumn("Customer_Account_Balance_Group",
                    f.when(f.col("Customer_Account_Balance") < 3000, "Less than 3000") \
                    .when(f.col("Customer_Account_Balance") < 8000, "Between 3000 and 8000") \
                    .otherwise("More than 8000"))

    return df_customer_dimension
Example #23
def supplier_dimension(df_supplier: SparkDataSet, df_region: SparkDataSet,
                       df_nation: SparkDataSet) -> SparkDataSet:
    """
    Args:
        **df: Source data frames

    Returns:
        Spark data frame
    """

    column_name = [
        "Supplier_Key", "Supplier_Name", "Supplier_Address", "Supplier_Phone",
        "Supplier_Account_Balance", "Supplier_Country_Name",
        "Supplier_Region_Name"
    ]
    df_supplier_dimension = df_supplier.join(df_nation, df_supplier.S_NATIONKEY == df_nation.N_NATIONKEY) \
        .join(df_region, df_region.R_REGIONKEY == df_nation.N_REGIONKEY) \
        .drop("S_COMMENT",
              "N_COMMENT",
              "R_COMMENT",
              "N_NATIONKEY",
              "N_REGIONKEY",
              "S_NATIONKEY",
              "R_REGIONKEY"
              ) \
        .toDF(*column_name)

    return df_supplier_dimension
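
customer_dimension and supplier_dimension above are written as Kedro node functions: despite the SparkDataSet type hints, at run time they receive and return Spark DataFrames that Kedro loads and saves through SparkDataSet catalog entries. A minimal sketch of how they might be wired into a pipeline; the dataset names are assumptions, not part of the original code:

# Hypothetical pipeline wiring for the two dimension nodes above. The dataset
# names ("customers", "regions", "nations", ...) are assumptions and would
# need matching SparkDataSet entries in the data catalog.
from kedro.pipeline import Pipeline, node


def create_pipeline() -> Pipeline:
    return Pipeline(
        [
            node(
                customer_dimension,
                inputs=["customers", "regions", "nations"],
                outputs="customer_dimension_table",
            ),
            node(
                supplier_dimension,
                inputs=["suppliers", "regions", "nations"],
                outputs="supplier_dimension_table",
            ),
        ]
    )
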
Example #24
def test_str_representation():
    with tempfile.NamedTemporaryFile() as temp_data_file:
        spark_data_set = SparkDataSet(filepath=temp_data_file.name,
                                      file_format="csv",
                                      load_args={"header": True})
        assert "SparkDataSet" in str(spark_data_set)
        assert "filepath={}".format(temp_data_file.name) in str(spark_data_set)
Example #25
    def test_repr(self, versioned_dataset_s3, version):
        assert "filepath=s3a://" in str(versioned_dataset_s3)
        assert "version=Version(load=None, save='{}')".format(
            version.save) in str(versioned_dataset_s3)

        dataset_s3 = SparkDataSet(
            filepath="s3a://{}/{}".format(BUCKET_NAME, FILENAME))
        assert "filepath=s3a://" in str(dataset_s3)
        assert "version=" not in str(dataset_s3)
Example #26
    def test_s3n_warning(self, version):
        pattern = (
            "`s3n` filesystem has now been deprecated by Spark, "
            "please consider switching to `s3a`"
        )
        with pytest.warns(DeprecationWarning, match=pattern):
            SparkDataSet(
                filepath="s3n://{}/{}".format(BUCKET_NAME, FILENAME), version=version
            )
Example #27
def test_save_partition():
    # To verify partitioning, this test partitions the data by one of the
    # columns and then checks whether the partition column is added to
    # the save path

    with tempfile.TemporaryDirectory() as temp_dir:
        temp_path = join(temp_dir, "test_data")
        spark_data_set = SparkDataSet(filepath=temp_path,
                                      save_args={
                                          "mode": "overwrite",
                                          "partitionBy": ["name"]
                                      })

        spark_df = _get_sample_spark_data_frame()
        spark_data_set.save(spark_df)

        expected_path = join(temp_path, "name=Alex")

        assert exists(expected_path)
Example #28
    def test_save_parquet(self, tmp_path, sample_spark_df):
        # To cross-check the Spark save operation we save to a single Spark
        # partition and read it back with Kedro's ParquetLocalDataSet
        temp_dir = Path(str(tmp_path / "test_data"))
        spark_data_set = SparkDataSet(
            filepath=str(temp_dir), save_args={"compression": "none"}
        )
        spark_df = sample_spark_df.coalesce(1)
        spark_data_set.save(spark_df)

        single_parquet = [
            f for f in temp_dir.iterdir() if f.is_file() and f.name.startswith("part")
        ][0]

        local_parquet_data_set = ParquetLocalDataSet(filepath=str(single_parquet))

        pandas_df = local_parquet_data_set.load()

        assert pandas_df[pandas_df["name"] == "Bob"]["age"].iloc[0] == 12
Example #29
    def test_save_version_warning(self, mocker):
        exact_version = Version("2019-01-01T23.59.59.999Z", "2019-01-02T00.00.00.000Z")
        ds_s3 = SparkDataSet(
            filepath="s3a://{}/{}".format(BUCKET_NAME, FILENAME),
            version=exact_version,
            credentials=AWS_CREDENTIALS,
        )
        mocked_spark_df = mocker.Mock()

        pattern = (
            r"Save version `{ev.save}` did not match load version "
            r"`{ev.load}` for SparkDataSet\(.+\)".format(ev=exact_version)
        )
        with pytest.warns(UserWarning, match=pattern):
            ds_s3.save(mocked_spark_df)
        mocked_spark_df.write.save.assert_called_once_with(
            "s3a://{b}/{f}/{v}/{f}".format(
                b=BUCKET_NAME, f=FILENAME, v=exact_version.save
            ),
            "parquet",
        )
Example #30
def test_save_parquet():
    # To cross-check the Spark save operation we save to a single Spark
    # partition and read it back with Kedro's ParquetLocalDataSet
    with tempfile.TemporaryDirectory() as temp_dir:
        temp_path = join(temp_dir, "test_data")
        spark_data_set = SparkDataSet(filepath=temp_path,
                                      save_args={"compression": "none"})
        spark_df = _get_sample_spark_data_frame().coalesce(1)
        spark_data_set.save(spark_df)

        single_parquet = [
            join(temp_path, f) for f in listdir(temp_path)
            if f.startswith("part")
        ][0]

        local_parquet_data_set = ParquetLocalDataSet(filepath=single_parquet)

        pandas_df = local_parquet_data_set.load()

        assert pandas_df[pandas_df["name"] == "Bob"]["age"].iloc[0] == 12