Code Example #1: loading a pipe-delimited (.psv) file with FrameworkCsvLoader and a custom delimiter
def test_can_load_non_standard_delimited_csv(
        spark_session: SparkSession) -> None:
    # Arrange
    clean_spark_session(spark_session)

    data_dir: Path = Path(__file__).parent.joinpath("./")
    test_file_path: str = f"{data_dir.joinpath('test.psv')}"

    schema = StructType([])

    df: DataFrame = spark_session.createDataFrame(
        spark_session.sparkContext.emptyRDD(), schema)

    # Act
    loader = FrameworkCsvLoader(view="my_view",
                                filepath=test_file_path,
                                delimiter="|")
    loader.transform(df)

    # noinspection SqlDialectInspection
    result: DataFrame = spark_session.sql("SELECT * FROM my_view")

    result.show()

    # Assert
    assert loader.getDelimiter() == "|"
    assert_results(result)
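
These snippets omit their imports and shared test fixtures. The sketch below shows what they assume; the exact spark_pipeline_framework import paths and the helpers used throughout (clean_spark_session, assert_results, create_empty_dataframe) depend on the project layout and library version, so treat them as assumptions.

# Minimal sketch of the shared imports and the pytest fixture these tests rely on.
# The framework and helper imports are intentionally left as placeholders; check
# the installed spark_pipeline_framework version for the exact module paths.
from pathlib import Path
from typing import Any, Dict, List, Union

import pytest
from pyspark.sql import DataFrame, SparkSession
from pyspark.sql.types import StructType

# from spark_pipeline_framework... import FrameworkCsvLoader, ...
# from ...test helpers... import clean_spark_session, assert_results, create_empty_dataframe


@pytest.fixture(scope="session")
def spark_session() -> SparkSession:
    # A local SparkSession shared across the tests (a sketch, not the project's
    # own fixture); the Hive-table examples may also need .enableHiveSupport().
    return (
        SparkSession.builder.master("local[*]")
        .appName("framework-csv-examples")
        .getOrCreate()
    )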
Code Example #2: filling missing values with FrameworkFillNaTransformer
def test_framework_fill_na_transformer(spark_session: SparkSession) -> None:
    # create an empty dataframe and load the test CSV into a view
    data_dir: Path = Path(__file__).parent.joinpath("./")

    df: DataFrame = create_empty_dataframe(spark_session=spark_session)

    view: str = "primary_care_protocol"
    FrameworkCsvLoader(
        view=view,
        filepath=data_dir.joinpath("primary_care_protocol.csv"),
        clean_column_names=False,
    ).transform(df)

    # cast Minimum Age to float and confirm all seven rows were loaded
    result_df: DataFrame = spark_session.table(view)
    result_df = result_df.withColumn("Minimum Age",
                                     result_df["Minimum Age"].cast("float"))
    result_df.createOrReplaceTempView(view)
    assert 7 == result_df.count()

    # fill missing Minimum Age with 1.0 and missing Maximum Age with "No Limit"
    FrameworkFillNaTransformer(view=view,
                               column_mapping={
                                   "Minimum Age": 1.0,
                                   "Maximum Age": "No Limit"
                               }).transform(df)

    # assert the row count is unchanged and the missing values were filled
    result_df = spark_session.table(view)
    assert 7 == result_df.count()
    assert "No Limit" == result_df.select(
        "Maximum Age").collect()[1]["Maximum Age"]
    assert 24.0 == result_df.agg({"Minimum Age": "sum"}).collect()[0][0]
Code Example #3: keeping a subset of columns with FrameworkSelectColumnsTransformer
def test_can_keep_columns(spark_session: SparkSession) -> None:
    # Arrange
    clean_spark_session(spark_session)

    data_dir: Path = Path(__file__).parent.joinpath("./")
    test_file_path: str = f"{data_dir.joinpath('test.csv')}"

    schema = StructType([])

    df: DataFrame = spark_session.createDataFrame(
        spark_session.sparkContext.emptyRDD(), schema)

    # Act
    FrameworkCsvLoader(view="my_view", filepath=test_file_path,
                       delimiter=",").transform(df)

    FrameworkSelectColumnsTransformer(view="my_view",
                                      keep_columns=["Column2"]).transform(df)

    # noinspection SqlDialectInspection
    result: DataFrame = spark_session.sql("SELECT * FROM my_view")

    result.show()

    # Assert
    assert len(result.columns) == 1

    assert result.count() == 3

    assert result.collect()[1][0] == "bar"
Code Example #4: FrameworkValidationTransformer recursing into a nested query directory
def test_validation_recurses_query_dir(spark_session: SparkSession) -> None:
    clean_spark_session(spark_session)
    query_dir: Path = Path(__file__).parent.joinpath("./queries")
    more_queries_dir: str = "more_queries"
    data_dir: Path = Path(__file__).parent.joinpath("./data")
    test_data_file: str = f"{data_dir.joinpath('test.csv')}"
    validation_query_file: str = "validate.sql"

    schema = StructType([])

    df: DataFrame = spark_session.createDataFrame(
        spark_session.sparkContext.emptyRDD(), schema)

    FrameworkCsvLoader(view="my_view", filepath=test_data_file).transform(df)

    FrameworkValidationTransformer(
        validation_source_path=str(query_dir),
        validation_queries=[validation_query_file, more_queries_dir],
    ).transform(df)

    df_validation = df.sql_ctx.table("pipeline_validation")
    df_validation.show(truncate=False)
    assert 3 == df_validation.count(
    ), "Expected 3 total rows in pipeline_validation"
    assert (1 == df_validation.filter("is_failed == 1").count()
            ), "Expected one failing row in the validation table"
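
The arguments above imply the directory layout sketched below: validation_queries mixes a single file (validate.sql) with a directory (more_queries) whose .sql files are picked up recursively. The file names under more_queries/ are not shown in the test and are left elided.

# Implied layout (names under more_queries/ are not shown in the test):
#
#   queries/
#       validate.sql      # referenced directly in validation_queries
#       more_queries/     # passed as a directory; its .sql queries are run recursively
#           ...
#   data/
#       test.csv          # loaded into my_view before validation runs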
Code Example #5: a simple CSV-and-SQL pipeline built with create_steps and a pyspark Pipeline
def test_simple_csv_and_sql_pipeline(spark_session: SparkSession) -> None:
    # Arrange
    data_dir: Path = Path(__file__).parent.joinpath("./")
    flights_path: str = f"file://{data_dir.joinpath('flights.csv')}"

    schema = StructType([])

    df: DataFrame = spark_session.createDataFrame(
        spark_session.sparkContext.emptyRDD(), schema)

    spark_session.sql("DROP TABLE IF EXISTS default.flights")

    # Act
    parameters: Dict[str, Any] = {}

    stages: List[Transformer] = create_steps([
        FrameworkCsvLoader(view="flights", filepath=flights_path),
        FeaturesCarriersV1(parameters=parameters),
    ])

    pipeline: Pipeline = Pipeline(stages=stages)  # type: ignore
    transformer = pipeline.fit(df)
    transformer.transform(df)

    # Assert
    result_df: DataFrame = spark_session.sql("SELECT * FROM flights2")
    result_df.show()

    assert result_df.count() > 0
Code Example #6: loading a CSV with clean_column_names=True to sanitize column names
def test_correctly_loads_csv_with_clean_flag_on(
        spark_session: SparkSession) -> None:
    # Arrange
    clean_spark_session(spark_session)

    data_dir: Path = Path(__file__).parent.joinpath("./")
    test_file_path: str = f"{data_dir.joinpath('column_name_test.csv')}"

    schema = StructType([])

    df: DataFrame = spark_session.createDataFrame(
        spark_session.sparkContext.emptyRDD(), schema)

    # Act
    FrameworkCsvLoader(
        view="my_view",
        filepath=test_file_path,
        delimiter=",",
        clean_column_names=True,
    ).transform(df)

    # noinspection SqlDialectInspection
    result: DataFrame = spark_session.sql("SELECT * FROM my_view")

    # Assert
    assert_results(result)
    assert result.collect()[1][0] == "2"
    assert (result.columns[2] ==
            "Ugly_column_with_chars_that_parquet_does_not_like_much_-")
Code Example #7: a fail-fast validated pipeline constructor (fail_on_validation=True)
    def __init__(self, parameters: Dict[str, Any],
                 progress_logger: ProgressLogger):
        super(MyFailFastValidatedPipeline, self).__init__(
            parameters=parameters,
            progress_logger=progress_logger,
            run_id="12345678",
            validation_output_path=parameters["validation_output_path"],
        )
        self.transformers = self.create_steps([
            FrameworkCsvLoader(
                view="flights",
                filepath=parameters["flights_path"],
                parameters=parameters,
                progress_logger=progress_logger,
            ),
            FrameworkValidationTransformer(
                validation_source_path=parameters["validation_source_path"],
                validation_queries=["validate.sql"],
                fail_on_validation=True,
            ),
            FeaturesCarriersV1(parameters=parameters,
                               progress_logger=progress_logger),
            FeaturesCarriersPythonV1(parameters=parameters,
                                     progress_logger=progress_logger),
        ])
Code Example #8: dropping duplicate rows with FrameworkDropDuplicatesTransformer
def test_framework_drop_duplicates_transformer(
        spark_session: SparkSession) -> None:
    # create an empty dataframe and load the test CSV into a view
    data_dir: Path = Path(__file__).parent.joinpath("./")

    df: DataFrame = create_empty_dataframe(spark_session=spark_session)

    view: str = "primary_care_protocol"
    FrameworkCsvLoader(
        view=view,
        filepath=data_dir.joinpath("primary_care_protocol.csv"),
        clean_column_names=False,
    ).transform(df)

    # ensure we have all the rows even the ones we want to drop
    result_df: DataFrame = spark_session.table(view)
    assert 3 == result_df.count()

    # drop duplicate rows based on the NPI column
    FrameworkDropDuplicatesTransformer(columns=["NPI"],
                                       view=view).transform(df)

    # assert the duplicate row was dropped
    result_df = spark_session.table(view)
    assert 2 == result_df.count()
Code Example #9: dropping rows with null values using FrameworkDropRowsWithNullTransformer
def test_framework_drop_rows_with_null_transformer(
        spark_session: SparkSession) -> None:
    # create an empty dataframe and load the test CSV into a view
    data_dir: Path = Path(__file__).parent.joinpath("./")

    df: DataFrame = create_empty_dataframe(spark_session=spark_session)

    view: str = "primary_care_protocol"
    FrameworkCsvLoader(
        view=view,
        filepath=data_dir.joinpath("primary_care_protocol.csv"),
        clean_column_names=False,
    ).transform(df)

    # ensure we have all the rows even the ones we want to drop
    result_df: DataFrame = spark_session.table(view)
    assert 7 == result_df.count()

    # drop the rows with null NPI or null Last Name
    FrameworkDropRowsWithNullTransformer(columns_to_check=["NPI", "Last Name"],
                                         view=view).transform(df)

    # assert only the row with both NPI and Last Name populated remains
    result_df = spark_session.table(view)
    assert 1 == result_df.count()

    # ensure that no rows are dropped when there are no null values
    FrameworkDropRowsWithNullTransformer(columns_to_check=["NPI", "Last Name"],
                                         view=view).transform(result_df)
    assert 1 == result_df.count()
Code Example #10: a CSV loader pipeline combining FrameworkCsvLoader with pyspark's SQLTransformer
def test_simple_csv_loader_pipeline(spark_session: SparkSession) -> None:
    # Arrange
    data_dir: Path = Path(__file__).parent.joinpath("./")
    flights_path: str = f"file://{data_dir.joinpath('flights.csv')}"

    schema = StructType([])

    df: DataFrame = spark_session.createDataFrame(
        spark_session.sparkContext.emptyRDD(), schema)

    # noinspection SqlDialectInspection,SqlNoDataSourceInspection
    spark_session.sql("DROP TABLE IF EXISTS default.flights")

    # Act
    # parameters = Dict[str, Any]({
    # })

    stages: List[Union[Estimator[Any], Transformer]] = [
        FrameworkCsvLoader(view="flights", filepath=flights_path),
        SQLTransformer(statement="SELECT * FROM flights"),
    ]

    pipeline: Pipeline = Pipeline(stages=stages)

    transformer = pipeline.fit(df)
    result_df: DataFrame = transformer.transform(df)

    # Assert
    result_df.show()

    assert result_df.count() > 0
Code Example #11: exporting a view with FrameworkCsvExporter and reloading it to verify the round trip
def test_can_save_csv(spark_session: SparkSession) -> None:
    # Arrange
    SparkTestHelper.clear_tables(spark_session)

    data_dir: Path = Path(__file__).parent.joinpath("./")
    test_file_path: str = f"{data_dir.joinpath('test.csv')}"

    if path.isdir(data_dir.joinpath("temp")):
        shutil.rmtree(data_dir.joinpath("temp"))

    schema = StructType([])

    df: DataFrame = spark_session.createDataFrame(
        spark_session.sparkContext.emptyRDD(), schema
    )

    FrameworkCsvLoader(
        view="my_view", filepath=test_file_path, delimiter=","
    ).transform(df)

    csv_file_path: str = f"file://{data_dir.joinpath('temp/').joinpath('test.csv')}"

    # Act
    FrameworkCsvExporter(
        view="my_view", file_path=csv_file_path, header=True, delimiter=","
    ).transform(df)

    # Assert
    FrameworkCsvLoader(
        view="my_view2", filepath=csv_file_path, delimiter=","
    ).transform(df)

    # noinspection SqlDialectInspection
    result: DataFrame = spark_session.sql("SELECT * FROM my_view2")

    result.show()

    assert result.count() == 3

    assert result.collect()[1][0] == "2"
    assert result.collect()[1][1] == "bar"
    assert result.collect()[1][2] == "bar2"
Code Example #12: a pipeline constructor wiring FrameworkCsvLoader and feature transformers via create_steps
    def __init__(self, parameters: Dict[str, Any],
                 progress_logger: ProgressLogger):
        super(MyPipeline, self).__init__(parameters=parameters,
                                         progress_logger=progress_logger)
        self.transformers = self.create_steps([
            FrameworkCsvLoader(
                view="flights",
                filepath=parameters["flights_path"],
                parameters=parameters,
                progress_logger=progress_logger,
            ),
            FeaturesCarriersV1(parameters=parameters,
                               progress_logger=progress_logger),
            FeaturesCarriersPythonV1(parameters=parameters,
                                     progress_logger=progress_logger),
        ])
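
A usage sketch for a pipeline class like this, assuming the unshown base class exposes the fit()/transform() pattern of the Pipeline in Code Example #5, that ProgressLogger can be constructed directly, and that FeaturesCarriersV1 populates the flights2 view queried in Code Example #5; the flights path is hypothetical.

# Usage sketch under the assumptions stated above; spark_session is the same
# fixture the tests use.
parameters: Dict[str, Any] = {"flights_path": "file:///tmp/flights.csv"}  # hypothetical path
progress_logger: ProgressLogger = ProgressLogger()  # assumed no-argument constructor

df: DataFrame = spark_session.createDataFrame(
    spark_session.sparkContext.emptyRDD(), StructType([]))

pipeline = MyPipeline(parameters=parameters, progress_logger=progress_logger)
transformer = pipeline.fit(df)
transformer.transform(df)

spark_session.sql("SELECT * FROM flights2").show()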
Code Example #13: loading a multiline CSV with multiline=True
def test_can_load_multiline_csv(spark_session: SparkSession) -> None:
    # Arrange
    clean_spark_session(spark_session)

    data_dir: Path = Path(__file__).parent.joinpath("./")
    test_file_path: str = f"{data_dir.joinpath('multiline_row.csv')}"

    schema = StructType([])

    df: DataFrame = spark_session.createDataFrame(
        spark_session.sparkContext.emptyRDD(), schema)

    # Act
    FrameworkCsvLoader(view="my_view",
                       filepath=test_file_path,
                       delimiter=",",
                       multiline=True).transform(df)

    # noinspection SqlDialectInspection
    result: DataFrame = spark_session.sql("SELECT * FROM my_view")
    assert 1 == result.count()
Code Example #14: fail_on_validation=True raising an AssertionError when a validation query fails
def test_validation_throws_error(spark_session: SparkSession) -> None:
    with pytest.raises(AssertionError):
        clean_spark_session(spark_session)
        query_dir: Path = Path(__file__).parent.joinpath("./queries")
        data_dir: Path = Path(__file__).parent.joinpath("./data")
        test_data_file: str = f"{data_dir.joinpath('test.csv')}"
        validation_query_file: str = "validate.sql"

        schema = StructType([])

        df: DataFrame = spark_session.createDataFrame(
            spark_session.sparkContext.emptyRDD(), schema)

        FrameworkCsvLoader(view="my_view",
                           filepath=test_data_file).transform(df)

        FrameworkValidationTransformer(
            validation_source_path=str(query_dir),
            validation_queries=[validation_query_file],
            fail_on_validation=True,
        ).transform(df)

        df.sql_ctx.table("pipeline_validation").show(truncate=False)
Code Example #15: loading a CSV without a header row (has_header=False)
def test_can_load_csv_without_header(spark_session: SparkSession) -> None:
    # Arrange
    clean_spark_session(spark_session)

    data_dir: Path = Path(__file__).parent.joinpath("./")
    test_file_path: str = f"{data_dir.joinpath('no_header.csv')}"

    schema = StructType([])

    df: DataFrame = spark_session.createDataFrame(
        spark_session.sparkContext.emptyRDD(), schema)

    # Act
    FrameworkCsvLoader(view="another_view",
                       filepath=test_file_path,
                       delimiter=",",
                       has_header=False).transform(df)

    # noinspection SqlDialectInspection
    result: DataFrame = spark_session.sql("SELECT * FROM another_view")

    # Assert
    assert_results(result)
Code Example #16: a constructor that builds transformers from the .csv, .sql, and .py files in a folder
    def __init__(
        self,
        parameters: Dict[str, Any],
        location: Union[str, Path],
        progress_logger: Optional[ProgressLogger] = None,
        verify_count_remains_same: bool = False,
    ) -> None:
        super().__init__(
            name=self.__class__.__name__,
            parameters=parameters,
            progress_logger=progress_logger,
        )
        self.verify_count_remains_same: bool = verify_count_remains_same
        self.location: str = str(location)
        self.my_transformers: List[Transformer] = []
        self.logger = get_logger(__name__)

        assert self.location
        # Iterate over files to create transformers
        files: List[str] = sorted(listdir(self.location))
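        # derive a dotted module name (e.g. "library.features.carriers") from the path segment starting at library/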
        index_of_module: int = self.location.rfind("/library/")
        module_ = index_of_module + 1
        module_name: str = self.location[module_:].replace("/", ".")

        # noinspection Mypy
        self.setParams(
            name=self.__class__.__name__,
            parameters=parameters,
            progress_logger=progress_logger,
        )

        for file in files:
            if file.endswith(".csv"):
                file_name = file.replace(".csv", "")
                self.my_transformers.append(
                    FrameworkCsvLoader(
                        view=file_name,
                        filepath=path.join(self.location, file),
                        delimiter=parameters.get("delimiter", ","),
                        has_header=parameters.get("has_header", True),
                        mapping_file_name=file,
                    ))
            elif file.endswith(".sql"):
                feature_sql: str = self.read_file_as_string(
                    path.join(self.location,
                              file)).format(parameters=parameters)
                self.my_transformers.append(
                    FrameworkSqlTransformer(
                        sql=feature_sql,
                        name=module_name,
                        progress_logger=progress_logger,
                        log_sql=parameters.get("debug_log_sql", False),
                        view=file.replace(".sql", ""),
                        verify_count_remains_same=verify_count_remains_same,
                        mapping_file_name=file,
                    ))
            elif file.endswith("mapping.py"):
                file_name_only: str = os.path.basename(file)
                # strip off .py to get the module name
                import_module_name: str = file_name_only.replace(".py", "")
                self.my_transformers.append(
                    self.get_python_mapping_transformer(
                        "." + import_module_name, file))
            elif file.endswith("calculate.py") or file.endswith("pipeline.py"):
                file_name_only = file.replace(".py", "")
                self.my_transformers.append(
                    self.get_python_transformer(f".{file_name_only}", file))

        assert len(self.my_transformers) > 0, (
            f"No transformer files found in {self.location}."
            "  There should be one or more .sql, .csv, *mapping.py, *calculate.py or *pipeline.py files"
        )
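
An illustrative (hypothetical) folder layout for a location passed to this constructor, inferred only from the extension handling above:

# Hypothetical feature-folder layout; the names are illustrative, and only the
# file extensions determine which transformer is created above.
#
#   .../library/features/carriers/
#       carriers.csv        -> FrameworkCsvLoader into view "carriers"
#       features.sql        -> FrameworkSqlTransformer into view "features"
#       address_mapping.py  -> get_python_mapping_transformer(".address_mapping", ...)
#       age_calculate.py    -> get_python_transformer(".age_calculate", ...)
#       full_pipeline.py    -> get_python_transformer(".full_pipeline", ...)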