def test_can_load_non_standard_delimited_csv(
        spark_session: SparkSession) -> None:
    # Arrange
    clean_spark_session(spark_session)
    data_dir: Path = Path(__file__).parent.joinpath("./")
    test_file_path: str = f"{data_dir.joinpath('test.psv')}"
    schema = StructType([])
    df: DataFrame = spark_session.createDataFrame(
        spark_session.sparkContext.emptyRDD(), schema)

    # Act
    loader = FrameworkCsvLoader(view="my_view",
                                filepath=test_file_path,
                                delimiter="|")
    loader.transform(df)

    # noinspection SqlDialectInspection
    result: DataFrame = spark_session.sql("SELECT * FROM my_view")
    result.show()

    # Assert
    assert loader.getDelimiter() == "|"
    assert_results(result)

def test_framework_fill_na_transformer(spark_session: SparkSession) -> None:
    # create a dataframe with the test data
    data_dir: Path = Path(__file__).parent.joinpath("./")
    df: DataFrame = create_empty_dataframe(spark_session=spark_session)
    view: str = "primary_care_protocol"
    FrameworkCsvLoader(
        view=view,
        filepath=data_dir.joinpath("primary_care_protocol.csv"),
        clean_column_names=False,
    ).transform(df)

    # ensure we loaded all the rows, including the ones with missing values
    result_df: DataFrame = spark_session.table(view)
    result_df = result_df.withColumn("Minimum Age",
                                     result_df["Minimum Age"].cast("float"))
    result_df.createOrReplaceTempView(view)
    assert 7 == result_df.count()

    # fill the missing values in the Minimum Age and Maximum Age columns
    FrameworkFillNaTransformer(view=view,
                               column_mapping={
                                   "Minimum Age": 1.0,
                                   "Maximum Age": "No Limit"
                               }).transform(df)

    # assert no rows were lost and the missing values were filled
    result_df = spark_session.table(view)
    assert 7 == result_df.count()
    assert "No Limit" == result_df.select(
        "Maximum Age").collect()[1]["Maximum Age"]
    assert 24.0 == result_df.agg({"Minimum Age": "sum"}).collect()[0][0]

def test_can_keep_columns(spark_session: SparkSession) -> None:
    # Arrange
    clean_spark_session(spark_session)
    data_dir: Path = Path(__file__).parent.joinpath("./")
    test_file_path: str = f"{data_dir.joinpath('test.csv')}"
    schema = StructType([])
    df: DataFrame = spark_session.createDataFrame(
        spark_session.sparkContext.emptyRDD(), schema)

    # Act
    FrameworkCsvLoader(view="my_view",
                       filepath=test_file_path,
                       delimiter=",").transform(df)
    FrameworkSelectColumnsTransformer(view="my_view",
                                      keep_columns=["Column2"]).transform(df)

    # noinspection SqlDialectInspection
    result: DataFrame = spark_session.sql("SELECT * FROM my_view")
    result.show()

    # Assert
    assert len(result.columns) == 1
    assert result.count() == 3
    assert result.collect()[1][0] == "bar"

def test_validation_recurses_query_dir(spark_session: SparkSession) -> None:
    clean_spark_session(spark_session)
    query_dir: Path = Path(__file__).parent.joinpath("./queries")
    more_queries_dir: str = "more_queries"
    data_dir: Path = Path(__file__).parent.joinpath("./data")
    test_data_file: str = f"{data_dir.joinpath('test.csv')}"
    validation_query_file: str = "validate.sql"
    schema = StructType([])
    df: DataFrame = spark_session.createDataFrame(
        spark_session.sparkContext.emptyRDD(), schema)

    FrameworkCsvLoader(view="my_view", filepath=test_data_file).transform(df)
    FrameworkValidationTransformer(
        validation_source_path=str(query_dir),
        validation_queries=[validation_query_file, more_queries_dir],
    ).transform(df)

    df_validation = df.sql_ctx.table("pipeline_validation")
    df_validation.show(truncate=False)
    assert 3 == df_validation.count(), \
        "Expected 3 total rows in pipeline_validation"
    assert 1 == df_validation.filter("is_failed == 1").count(), \
        "Expected one failing row in the validation table"

def test_simple_csv_and_sql_pipeline(spark_session: SparkSession) -> None:
    # Arrange
    data_dir: Path = Path(__file__).parent.joinpath("./")
    flights_path: str = f"file://{data_dir.joinpath('flights.csv')}"
    schema = StructType([])
    df: DataFrame = spark_session.createDataFrame(
        spark_session.sparkContext.emptyRDD(), schema)
    spark_session.sql("DROP TABLE IF EXISTS default.flights")

    # Act
    parameters: Dict[str, Any] = {}
    stages: List[Transformer] = create_steps([
        FrameworkCsvLoader(view="flights", filepath=flights_path),
        FeaturesCarriersV1(parameters=parameters),
    ])
    pipeline: Pipeline = Pipeline(stages=stages)  # type: ignore
    transformer = pipeline.fit(df)
    transformer.transform(df)

    # Assert
    result_df: DataFrame = spark_session.sql("SELECT * FROM flights2")
    result_df.show()
    assert result_df.count() > 0

def test_correctly_loads_csv_with_clean_flag_on(
        spark_session: SparkSession) -> None:
    # Arrange
    clean_spark_session(spark_session)
    data_dir: Path = Path(__file__).parent.joinpath("./")
    test_file_path: str = f"{data_dir.joinpath('column_name_test.csv')}"
    schema = StructType([])
    df: DataFrame = spark_session.createDataFrame(
        spark_session.sparkContext.emptyRDD(), schema)

    # Act
    FrameworkCsvLoader(
        view="my_view",
        filepath=test_file_path,
        delimiter=",",
        clean_column_names=True,
    ).transform(df)

    # noinspection SqlDialectInspection
    result: DataFrame = spark_session.sql("SELECT * FROM my_view")

    # Assert
    assert_results(result)
    assert result.collect()[1][0] == "2"
    assert (result.columns[2] ==
            "Ugly_column_with_chars_that_parquet_does_not_like_much_-")

def __init__(self, parameters: Dict[str, Any],
             progress_logger: ProgressLogger):
    super(MyFailFastValidatedPipeline, self).__init__(
        parameters=parameters,
        progress_logger=progress_logger,
        run_id="12345678",
        validation_output_path=parameters["validation_output_path"],
    )
    self.transformers = self.create_steps([
        FrameworkCsvLoader(
            view="flights",
            filepath=parameters["flights_path"],
            parameters=parameters,
            progress_logger=progress_logger,
        ),
        FrameworkValidationTransformer(
            validation_source_path=parameters["validation_source_path"],
            validation_queries=["validate.sql"],
            fail_on_validation=True,
        ),
        FeaturesCarriersV1(parameters=parameters,
                           progress_logger=progress_logger),
        FeaturesCarriersPythonV1(parameters=parameters,
                                 progress_logger=progress_logger),
    ])

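# A minimal usage sketch, not from the source: it shows the parameter keys the
# __init__ above actually reads ("flights_path", "validation_source_path",
# "validation_output_path"). The paths are hypothetical; it also assumes
# ProgressLogger can be used as a context manager and that fit()/transform()
# behave as in the pyspark Pipeline tests above.
def run_fail_fast_pipeline_sketch(df: DataFrame) -> None:
    parameters: Dict[str, Any] = {
        "flights_path": "file:///data/flights.csv",
        "validation_source_path": "/queries",
        "validation_output_path": "file:///tmp/validation_output",
    }
    with ProgressLogger() as progress_logger:
        pipeline = MyFailFastValidatedPipeline(
            parameters=parameters, progress_logger=progress_logger)
        # because fail_on_validation=True, this raises as soon as
        # validate.sql reports a failing row
        pipeline.fit(df).transform(df)
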
def test_framework_drop_duplicates_transformer(
        spark_session: SparkSession) -> None:
    # create a dataframe with the test data
    data_dir: Path = Path(__file__).parent.joinpath("./")
    df: DataFrame = create_empty_dataframe(spark_session=spark_session)
    view: str = "primary_care_protocol"
    FrameworkCsvLoader(
        view=view,
        filepath=data_dir.joinpath("primary_care_protocol.csv"),
        clean_column_names=False,
    ).transform(df)

    # ensure we have all the rows, including the duplicates we want to drop
    result_df: DataFrame = spark_session.table(view)
    assert 3 == result_df.count()

    # drop the rows that duplicate an existing NPI
    FrameworkDropDuplicatesTransformer(columns=["NPI"],
                                       view=view).transform(df)

    # assert the duplicate rows were removed
    result_df = spark_session.table(view)
    assert 2 == result_df.count()

def test_framework_drop_rows_with_null_transformer(
        spark_session: SparkSession) -> None:
    # create a dataframe with the test data
    data_dir: Path = Path(__file__).parent.joinpath("./")
    df: DataFrame = create_empty_dataframe(spark_session=spark_session)
    view: str = "primary_care_protocol"
    FrameworkCsvLoader(
        view=view,
        filepath=data_dir.joinpath("primary_care_protocol.csv"),
        clean_column_names=False,
    ).transform(df)

    # ensure we have all the rows, even the ones we want to drop
    result_df: DataFrame = spark_session.table(view)
    assert 7 == result_df.count()

    # drop the rows with null NPI or null Last Name
    FrameworkDropRowsWithNullTransformer(columns_to_check=["NPI", "Last Name"],
                                         view=view).transform(df)

    # assert we get only the rows with a populated NPI and Last Name
    result_df = spark_session.table(view)
    assert 1 == result_df.count()

    # ensure that no rows are dropped when there are no null values
    FrameworkDropRowsWithNullTransformer(columns_to_check=["NPI", "Last Name"],
                                         view=view).transform(result_df)
    assert 1 == result_df.count()

def test_simple_csv_loader_pipeline(spark_session: SparkSession) -> None:
    # Arrange
    data_dir: Path = Path(__file__).parent.joinpath("./")
    flights_path: str = f"file://{data_dir.joinpath('flights.csv')}"
    schema = StructType([])
    df: DataFrame = spark_session.createDataFrame(
        spark_session.sparkContext.emptyRDD(), schema)
    # noinspection SqlDialectInspection,SqlNoDataSourceInspection
    spark_session.sql("DROP TABLE IF EXISTS default.flights")

    # Act
    stages: List[Union[Estimator[Any], Transformer]] = [
        FrameworkCsvLoader(view="flights", filepath=flights_path),
        SQLTransformer(statement="SELECT * FROM flights"),
    ]
    pipeline: Pipeline = Pipeline(stages=stages)
    transformer = pipeline.fit(df)
    result_df: DataFrame = transformer.transform(df)

    # Assert
    result_df.show()
    assert result_df.count() > 0

def test_can_save_csv(spark_session: SparkSession) -> None:
    # Arrange
    SparkTestHelper.clear_tables(spark_session)
    data_dir: Path = Path(__file__).parent.joinpath("./")
    test_file_path: str = f"{data_dir.joinpath('test.csv')}"
    if path.isdir(data_dir.joinpath("temp")):
        shutil.rmtree(data_dir.joinpath("temp"))
    schema = StructType([])
    df: DataFrame = spark_session.createDataFrame(
        spark_session.sparkContext.emptyRDD(), schema)
    FrameworkCsvLoader(view="my_view",
                       filepath=test_file_path,
                       delimiter=",").transform(df)
    csv_file_path: str = f"file://{data_dir.joinpath('temp/').joinpath('test.csv')}"

    # Act
    FrameworkCsvExporter(view="my_view",
                         file_path=csv_file_path,
                         header=True,
                         delimiter=",").transform(df)

    # Assert
    FrameworkCsvLoader(view="my_view2",
                       filepath=csv_file_path,
                       delimiter=",").transform(df)

    # noinspection SqlDialectInspection
    result: DataFrame = spark_session.sql("SELECT * FROM my_view2")
    result.show()

    assert result.count() == 3
    assert result.collect()[1][0] == "2"
    assert result.collect()[1][1] == "bar"
    assert result.collect()[1][2] == "bar2"

def __init__(self, parameters: Dict[str, Any],
             progress_logger: ProgressLogger):
    super(MyPipeline, self).__init__(parameters=parameters,
                                     progress_logger=progress_logger)
    self.transformers = self.create_steps([
        FrameworkCsvLoader(
            view="flights",
            filepath=parameters["flights_path"],
            parameters=parameters,
            progress_logger=progress_logger,
        ),
        FeaturesCarriersV1(parameters=parameters,
                           progress_logger=progress_logger),
        FeaturesCarriersPythonV1(parameters=parameters,
                                 progress_logger=progress_logger),
    ])

def test_can_load_multiline_csv(spark_session: SparkSession) -> None:
    # Arrange
    clean_spark_session(spark_session)
    data_dir: Path = Path(__file__).parent.joinpath("./")
    test_file_path: str = f"{data_dir.joinpath('multiline_row.csv')}"
    schema = StructType([])
    df: DataFrame = spark_session.createDataFrame(
        spark_session.sparkContext.emptyRDD(), schema)

    # Act
    FrameworkCsvLoader(view="my_view",
                       filepath=test_file_path,
                       delimiter=",",
                       multiline=True).transform(df)

    # noinspection SqlDialectInspection
    result: DataFrame = spark_session.sql("SELECT * FROM my_view")

    # Assert
    assert 1 == result.count()

def test_validation_throws_error(spark_session: SparkSession) -> None:
    with pytest.raises(AssertionError):
        clean_spark_session(spark_session)
        query_dir: Path = Path(__file__).parent.joinpath("./queries")
        data_dir: Path = Path(__file__).parent.joinpath("./data")
        test_data_file: str = f"{data_dir.joinpath('test.csv')}"
        validation_query_file: str = "validate.sql"
        schema = StructType([])
        df: DataFrame = spark_session.createDataFrame(
            spark_session.sparkContext.emptyRDD(), schema)

        FrameworkCsvLoader(view="my_view",
                           filepath=test_data_file).transform(df)
        FrameworkValidationTransformer(
            validation_source_path=str(query_dir),
            validation_queries=[validation_query_file],
            fail_on_validation=True,
        ).transform(df)

        df.sql_ctx.table("pipeline_validation").show(truncate=False)

def test_can_load_csv_without_header(spark_session: SparkSession) -> None:
    # Arrange
    clean_spark_session(spark_session)
    data_dir: Path = Path(__file__).parent.joinpath("./")
    test_file_path: str = f"{data_dir.joinpath('no_header.csv')}"
    schema = StructType([])
    df: DataFrame = spark_session.createDataFrame(
        spark_session.sparkContext.emptyRDD(), schema)

    # Act
    FrameworkCsvLoader(view="another_view",
                       filepath=test_file_path,
                       delimiter=",",
                       has_header=False).transform(df)

    # noinspection SqlDialectInspection
    result: DataFrame = spark_session.sql("SELECT * FROM another_view")

    # Assert
    assert_results(result)

def __init__(
    self,
    parameters: Dict[str, Any],
    location: Union[str, Path],
    progress_logger: Optional[ProgressLogger] = None,
    verify_count_remains_same: bool = False,
) -> None:
    super().__init__(
        name=self.__class__.__name__,
        parameters=parameters,
        progress_logger=progress_logger,
    )
    self.verify_count_remains_same: bool = verify_count_remains_same
    self.location: str = str(location)
    self.my_transformers: List[Transformer] = []
    self.logger = get_logger(__name__)
    assert self.location

    # Iterate over files to create transformers
    files: List[str] = sorted(listdir(self.location))
    # derive a dotted module name from the part of the path starting at "library/"
    index_of_module: int = self.location.rfind("/library/")
    module_ = index_of_module + 1
    module_name: str = self.location[module_:].replace("/", ".")
    # noinspection Mypy
    self.setParams(
        name=self.__class__.__name__,
        parameters=parameters,
        progress_logger=progress_logger,
    )
    for file in files:
        if file.endswith(".csv"):
            file_name = file.replace(".csv", "")
            self.my_transformers.append(
                FrameworkCsvLoader(
                    view=file_name,
                    filepath=path.join(self.location, file),
                    delimiter=parameters.get("delimiter", ","),
                    has_header=parameters.get("has_header", True),
                    mapping_file_name=file,
                ))
        elif file.endswith(".sql"):
            feature_sql: str = self.read_file_as_string(
                path.join(self.location, file)).format(parameters=parameters)
            self.my_transformers.append(
                FrameworkSqlTransformer(
                    sql=feature_sql,
                    name=module_name,
                    progress_logger=progress_logger,
                    log_sql=parameters.get("debug_log_sql", False),
                    view=file.replace(".sql", ""),
                    verify_count_remains_same=verify_count_remains_same,
                    mapping_file_name=file,
                ))
        elif file.endswith("mapping.py"):
            file_name_only: str = os.path.basename(file)
            # strip off .py to get the module name
            import_module_name: str = file_name_only.replace(".py", "")
            self.my_transformers.append(
                self.get_python_mapping_transformer(
                    "." + import_module_name, file))
        elif file.endswith("calculate.py") or file.endswith("pipeline.py"):
            file_name_only = file.replace(".py", "")
            self.my_transformers.append(
                self.get_python_transformer(f".{file_name_only}", file))

    assert len(self.my_transformers) > 0, (
        f"No transformer files found in {self.location}."
        " There should be one or more .sql, .csv, *mapping.py,"
        " *calculate.py or *pipeline.py files")
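
# A minimal subclass sketch, not from the source: FeaturesMyFolder and the
# base-class name FrameworkPipelineFromFolder are hypothetical, but the
# parameter keys ("delimiter", "has_header", "debug_log_sql") match what the
# __init__ above reads. The class builds one transformer per .csv, .sql,
# *mapping.py, *calculate.py or *pipeline.py file found in `location`, in
# sorted file-name order.
#
# class FeaturesMyFolder(FrameworkPipelineFromFolder):  # hypothetical base
#     def __init__(self, parameters: Dict[str, Any],
#                  progress_logger: Optional[ProgressLogger] = None) -> None:
#         super().__init__(
#             parameters=parameters,
#             location=Path(__file__).parent,  # folder with the feature files
#             progress_logger=progress_logger,
#         )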