def test_read_write_parquet(
    test_parquet_in_asset: PySparkDataAsset,
    iris_spark: pyspark.sql.DataFrame,
    fake_airflow_context: Any,
    spark_session: pyspark.sql.SparkSession,
) -> None:
    p = path.abspath(
        path.join(test_parquet_in_asset.staging_pickedup_path(fake_airflow_context))
    )
    os.makedirs(path.dirname(p), exist_ok=True)
    iris_spark.write.mode("overwrite").parquet(p)

    count_before = iris_spark.count()
    columns_before = len(iris_spark.columns)

    # reading without the required 'spark_session' kwarg should raise:
    with pytest.raises(expected_exception=ValueError):
        PySparkDataAssetIO.read_data_asset(test_parquet_in_asset, source_files=[p])

    x = PySparkDataAssetIO.read_data_asset(
        test_parquet_in_asset, source_files=[p], spark_session=spark_session
    )
    assert count_before == x.count()
    assert columns_before == len(x.columns)

    # try with additional kwargs:
    x = PySparkDataAssetIO.read_data_asset(
        asset=test_parquet_in_asset,
        source_files=[p],
        spark_session=spark_session,
        mergeSchema=True,
    )
    assert count_before == x.count()
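# For reference, a minimal self-contained sketch (not part of the test suite) of the
# plain PySpark parquet round-trip that PySparkDataAssetIO wraps above; the local path
# and the toy DataFrame are illustrative assumptions only:
def _example_plain_parquet_roundtrip() -> None:
    from pyspark.sql import SparkSession

    spark = SparkSession.builder.master("local[1]").getOrCreate()
    df = spark.createDataFrame([(1, "setosa"), (2, "virginica")], ["id", "species"])
    df.write.mode("overwrite").parquet("/tmp/example_iris_parquet")
    roundtrip = spark.read.parquet("/tmp/example_iris_parquet")
    # row count and column count survive the round-trip:
    assert roundtrip.count() == df.count()
    assert len(roundtrip.columns) == len(df.columns)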
def test_read_write_unsupported(
    test_xlsx_in_asset: PySparkDataAsset, spark_session: pyspark.sql.SparkSession
) -> None:
    with pytest.raises(expected_exception=ValueError):
        PySparkDataAssetIO.read_data_asset(
            test_xlsx_in_asset, source_files=[], spark_session=spark_session
        )
def rebuild_for_store(asset: PySparkDataAsset, airflow_context):
    spark_session = pyspark.sql.SparkSession.builder.getOrCreate()
    enrollment_data = PySparkDataAssetIO.read_data_asset(
        asset=asset,
        source_files=asset.pickedup_files(airflow_context),
        spark_session=spark_session,
        header=True,
        inferSchema=True,
    )
    PySparkDataAssetIO.write_data_asset(asset=asset, data=enrollment_data)
    spark_session.stop()
def rebuild_for_store(asset: PySparkDataAsset, airflow_context):
    spark_session = pyspark.sql.SparkSession.builder.getOrCreate()
    programme_data = PySparkDataAssetIO.read_data_asset(
        asset=asset,
        source_files=asset.pickedup_files(airflow_context),
        spark_session=spark_session,
        header=True,
        inferSchema=True,
    )
    # de-duplicate on the asset's declared key columns before persisting:
    programme_data = programme_data.drop_duplicates(
        subset=asset.declarations.key_columns
    )
    PySparkDataAssetIO.write_data_asset(asset=asset, data=programme_data)
    spark_session.stop()
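# A small illustration (toy in-memory data, assumed column names, not the programme
# asset itself) of how drop_duplicates(subset=...) keeps one row per key combination,
# as used in the rebuild above:
def _example_drop_duplicates() -> None:
    from pyspark.sql import SparkSession

    spark = SparkSession.builder.master("local[1]").getOrCreate()
    programmes = spark.createDataFrame(
        [("p1", "Maths"), ("p1", "Mathematics"), ("p2", "Physics")],
        ["programme_id", "programme_name"],
    )
    deduplicated = programmes.drop_duplicates(subset=["programme_id"])
    assert deduplicated.count() == 2  # one row is kept per programme_id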
def rebuild_for_store(asset: PySparkDataAsset, airflow_context):
    spark_session = pyspark.sql.SparkSession.builder.getOrCreate()

    student = PySparkDataAsset(name="student_pyspark")
    programme = PySparkDataAsset(name="programme_pyspark")
    enrollment = PySparkDataAsset(name="enrollment_pyspark")

    student_df = student.retrieve_from_store(
        airflow_context=airflow_context,
        consuming_asset=asset,
        spark_session=spark_session,
    )
    programme_df = programme.retrieve_from_store(
        airflow_context=airflow_context,
        consuming_asset=asset,
        spark_session=spark_session,
    )
    enrollment_df = enrollment.retrieve_from_store(
        airflow_context=airflow_context,
        consuming_asset=asset,
        spark_session=spark_session,
    )

    # join enrollments to students and programmes on their declared key columns,
    # then count enrollments per (student_major, programme_name):
    enrollment_summary: pyspark.sql.DataFrame = enrollment_df.join(
        other=student_df, on=student.declarations.key_columns
    ).join(other=programme_df, on=programme.declarations.key_columns)

    enrollment_summary = (
        enrollment_summary.select(["student_major", "programme_name", "student_id"])
        .groupby(["student_major", "programme_name"])
        .agg(f.count("*").alias("count"))
    )

    PySparkDataAssetIO.write_data_asset(asset=asset, data=enrollment_summary)
    spark_session.stop()
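# A self-contained sketch of the join/aggregate pattern used above, on toy in-memory
# data; the data asset store and its key-column declarations are replaced by plain
# DataFrames and hard-coded join keys (assumptions for illustration only):
def _example_enrollment_summary() -> None:
    from pyspark.sql import SparkSession
    from pyspark.sql import functions as f

    spark = SparkSession.builder.master("local[1]").getOrCreate()
    students = spark.createDataFrame(
        [("s1", "biology"), ("s2", "physics")], ["student_id", "student_major"]
    )
    programmes = spark.createDataFrame(
        [("p1", "BSc")], ["programme_id", "programme_name"]
    )
    enrollments = spark.createDataFrame(
        [("s1", "p1"), ("s2", "p1")], ["student_id", "programme_id"]
    )
    summary = (
        enrollments.join(students, on=["student_id"])
        .join(programmes, on=["programme_id"])
        .groupby(["student_major", "programme_name"])
        .agg(f.count("*").alias("count"))
    )
    assert summary.count() == 2  # one row per (major, programme) combination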
def test_read_write_csv(
    test_csv_asset: PySparkDataAsset,
    iris_spark: pyspark.sql.DataFrame,
    spark_session: pyspark.sql.SparkSession,
) -> None:
    # try without any extra kwargs:
    PySparkDataAssetIO.write_data_asset(asset=test_csv_asset, data=iris_spark)

    # try with additional kwargs:
    PySparkDataAssetIO.write_data_asset(
        asset=test_csv_asset, data=iris_spark, header=True
    )

    # test mode; the default is overwrite, switching to "error" (if exists) should raise:
    with pytest.raises(AnalysisException):
        PySparkDataAssetIO.write_data_asset(
            asset=test_csv_asset, data=iris_spark, header=True, mode="error"
        )

    # test retrieval
    # before we can retrieve, we need to move the data from 'staging' to 'ready':
    os.makedirs(test_csv_asset.ready_path, exist_ok=True)
    # replace the ready folder with the prepared, staged data:
    shutil.rmtree(test_csv_asset.ready_path)
    shutil.move(test_csv_asset.staging_ready_path, test_csv_asset.ready_path)

    retrieved = PySparkDataAssetIO.retrieve_data_asset(
        test_csv_asset, spark_session=spark_session, inferSchema=True, header=True
    )
    assert retrieved.count() == iris_spark.count()

    # test check for missing 'spark_session' kwarg:
    with pytest.raises(ValueError):
        PySparkDataAssetIO.retrieve_data_asset(test_csv_asset)

    # test check for invalid 'spark_session' kwarg:
    with pytest.raises(TypeError):
        PySparkDataAssetIO.retrieve_data_asset(test_csv_asset, spark_session=42)
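# For reference, a minimal sketch (illustrative path and toy data, not the fixtures
# above) of the plain PySpark CSV save modes exercised by the test: "overwrite"
# replaces existing data, while "error" raises if the target path already exists:
def _example_csv_save_modes() -> None:
    from pyspark.sql import SparkSession
    from pyspark.sql.utils import AnalysisException

    spark = SparkSession.builder.master("local[1]").getOrCreate()
    df = spark.createDataFrame([(1, "setosa")], ["id", "species"])
    df.write.mode("overwrite").csv("/tmp/example_iris_csv", header=True)
    try:
        df.write.mode("error").csv("/tmp/example_iris_csv", header=True)
    except AnalysisException:
        pass  # a second write with mode="error" fails because the path exists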