# Imports shared by the snippets below. The pyspark names require
# pyspark>=3.1 (functions.filter/transform); the spark_auto_mapper module
# paths are assumed from the library's layout. MyObject and
# clean_spark_session are helpers defined elsewhere in the test suite.
from pathlib import Path
from typing import Dict, Optional

from pyspark.sql import Column, DataFrame, SparkSession
from pyspark.sql.functions import array, col, filter, lit, struct, transform

from spark_auto_mapper.automappers.automapper import AutoMapper
from spark_auto_mapper.data_types.data_type_base import AutoMapperDataTypeBase
from spark_auto_mapper.data_types.list import AutoMapperList
from spark_auto_mapper.helpers.automapper_helpers import AutoMapperHelpers as A


def test_automapper_filter_and_transform(spark_session: SparkSession) -> None:
    clean_spark_session(spark_session)
    data_dir: Path = Path(__file__).parent

    data_json_file: Path = data_dir.joinpath("data.json")

    source_df: DataFrame = spark_session.read.json(str(data_json_file),
                                                   multiLine=True)

    source_df.createOrReplaceTempView("patients")

    source_df.show(truncate=False)

    # Act
    mapper = AutoMapper(view="members", source_view="patients").complex(
        MyObject(age=A.transform(
            A.filter(column=A.column("identifier"),
                     func=lambda x: x["use"] == lit("usual")),
            A.complex(bar=A.field("value"), bar2=A.field("system")))))

    assert isinstance(mapper, AutoMapper)
    sql_expressions: Dict[str, Column] = mapper.get_column_specs(
        source_df=source_df)
    for column_name, sql_expression in sql_expressions.items():
        print(f"{column_name}: {sql_expression}")

    assert str(sql_expressions["age"]) == str(
        transform(
            filter("b.identifier", lambda x: x["use"] == lit("usual")),
            lambda x: struct(x["value"].alias("bar"), x["system"].alias("bar2"))
        ).alias("age"))
    result_df: DataFrame = mapper.transform(df=source_df)

    result_df.show(truncate=False)
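
The mapper above compiles the "age" column to a plain filter-then-transform
over an array of structs, exactly as the assertion shows. A minimal standalone
sketch of the same expression, with hypothetical data shaped like data.json's
identifier entries (pyspark>=3.1 assumed):

from pyspark.sql import SparkSession
from pyspark.sql.functions import filter, lit, struct, transform

spark = SparkSession.builder.master("local[1]").getOrCreate()
# hypothetical rows shaped like data.json's identifier entries
df = spark.createDataFrame(
    [([("usual", "123", "sys-a"), ("official", "456", "sys-b")],)],
    "identifier: array<struct<use:string,value:string,system:string>>")
df.select(
    transform(
        filter(df["identifier"], lambda x: x["use"] == lit("usual")),
        lambda x: struct(x["value"].alias("bar"), x["system"].alias("bar2"))
    ).alias("age")
).show(truncate=False)  # only the "usual" identifier survives the filter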
Example #2
    # Method shown without its enclosing class (it matches AutoMapperList's
    # get_column_spec): turn self.value into a Spark column expression.
    def get_column_spec(self, source_df: Optional[DataFrame],
                        current_column: Optional[Column]) -> Column:
        if isinstance(self.value, str):
            # a plain string value is wrapped as a literal single-element array
            return array(lit(self.value))

        if isinstance(self.value, list):
            # if the value is a list, map each item to its column expression
            inner_array: Column = array([
                self.get_value(item, source_df=source_df,
                               current_column=current_column)
                for item in self.value
            ])
            # optionally drop null entries from the resulting array
            return filter(inner_array, lambda x: x.isNotNull()) \
                if self.remove_nulls else inner_array

        # if value is an AutoMapper then ask it for its column spec
        if isinstance(self.value, AutoMapperDataTypeBase):
            child: AutoMapperDataTypeBase = self.value
            return child.get_column_spec(source_df=source_df,
                                         current_column=current_column)

        raise ValueError(
            f"value: {self.value} is not a str, list, or AutoMapper type")
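
The list branch above is what produces the filter(array(...), isNotNull)
expressions asserted in the surrounding tests. A minimal standalone sketch of
that expression over dummy data (pyspark>=3.1 assumed):

from pyspark.sql import SparkSession
from pyspark.sql.functions import array, filter, lit

spark = SparkSession.builder.master("local[1]").getOrCreate()
df = spark.range(1)  # dummy single-row frame to evaluate the expression
df.select(
    filter(array(lit("address1"), lit("address2"), lit(None)),
           lambda x: x.isNotNull()).alias("dst2")
).show(truncate=False)  # [address1, address2] -- the null literal is dropped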
def test_auto_mapper_array_multiple_items_with_null(
        spark_session: SparkSession) -> None:
    # Arrange
    spark_session.createDataFrame(
        [
            (1, 'Qureshi', 'Imran'),
            (2, 'Vidal', 'Michael'),
        ],
        ['member_id', 'last_name', 'first_name'],
    ).createOrReplaceTempView("patients")

    source_df: DataFrame = spark_session.table("patients")

    df: DataFrame = source_df.select("member_id")
    df.createOrReplaceTempView("members")

    # Act
    mapper = AutoMapper(
        view="members",
        source_view="patients",
        keys=["member_id"],
        drop_key_columns=False).columns(
            dst2=AutoMapperList(["address1", "address2", None]))

    assert isinstance(mapper, AutoMapper)
    sql_expressions: Dict[str, Column] = mapper.get_column_specs(
        source_df=source_df)
    for column_name, sql_expression in sql_expressions.items():
        print(f"{column_name}: {sql_expression}")

    assert str(sql_expressions["dst2"]) == str(
        filter(array(lit("address1"), lit("address2"), lit(None)),
               lambda x: x.isNotNull()).alias("dst2"))

    result_df: DataFrame = mapper.transform(df=df)

    # Assert
    result_df.printSchema()
    result_df.show()

    assert result_df.where("member_id == 1").select(
        "dst2").collect()[0][0][0] == "address1"
    assert result_df.where("member_id == 1").select(
        "dst2").collect()[0][0][1] == "address2"
    assert result_df.where("member_id == 2").select(
        "dst2").collect()[0][0][0] == "address1"
    assert result_df.where("member_id == 2").select(
        "dst2").collect()[0][0][1] == "address2"
Example #4
def test_automapper_select_one(spark_session: SparkSession) -> None:
    clean_spark_session(spark_session)
    data_dir: Path = Path(__file__).parent

    data_json_file: Path = data_dir.joinpath("data.json")

    source_df: DataFrame = spark_session.read.json(
        str(data_json_file), multiLine=True
    )

    source_df.createOrReplaceTempView("patients")

    source_df.show(truncate=False)

    # Act
    mapper = AutoMapper(view="members", source_view="patients").columns(
        age=A.column("identifier").filter(
            lambda x: x["system"] == "http://hl7.org/fhir/sid/us-npi"
        ).select_one(A.field("_.value"))
    )

    assert isinstance(mapper, AutoMapper)
    sql_expressions: Dict[str, Column] = mapper.get_column_specs(
        source_df=source_df
    )
    for column_name, sql_expression in sql_expressions.items():
        print(f"{column_name}: {sql_expression}")

    assert str(sql_expressions["age"]) == str(
        transform(
            filter(
                "b.identifier",
                lambda x: x["system"] == lit("http://hl7.org/fhir/sid/us-npi")
            ), lambda x: x["value"]
        )[0].alias("age")
    )
    result_df: DataFrame = mapper.transform(df=source_df)

    result_df.show(truncate=False)

    assert result_df.select("age").collect()[0][0] == "1730325416"
    assert result_df.select("age").collect()[1][0] == "1467734301"
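
select_one compiles to ordinary element access on the transformed array, as
the assertion shows: the trailing [0] picks the first match. A standalone
sketch with a hypothetical identifier row (pyspark>=3.1 assumed):

from pyspark.sql import SparkSession
from pyspark.sql.functions import filter, lit, transform

spark = SparkSession.builder.master("local[1]").getOrCreate()
# hypothetical identifier row; the system URL matches the test's NPI system
df = spark.createDataFrame(
    [([("http://hl7.org/fhir/sid/us-npi", "1730325416")],)],
    "identifier: array<struct<system:string,value:string>>")
df.select(
    transform(
        filter(df["identifier"],
               lambda x: x["system"] == lit("http://hl7.org/fhir/sid/us-npi")),
        lambda x: x["value"]
    )[0].alias("age")  # [0] is what select_one compiles to
).show(truncate=False)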
Example #5
    # Method shown without its enclosing class (it matches the filter
    # expression type's get_column_spec): delegate to pyspark's higher-order
    # filter, applying self.func to each element of the wrapped array column.
    def get_column_spec(self, source_df: Optional[DataFrame],
                        current_column: Optional[Column]) -> Column:
        return filter(
            self.column.get_column_spec(source_df=source_df,
                                        current_column=current_column),
            self.func)
def test_auto_mapper_columns(spark_session: SparkSession) -> None:
    # Arrange
    spark_session.createDataFrame(
        [
            (1, 'Qureshi', 'Imran'),
            (2, 'Vidal', 'Michael'),
        ],
        ['member_id', 'last_name', 'first_name'],
    ).createOrReplaceTempView("patients")

    source_df: DataFrame = spark_session.table("patients")

    df = source_df.select("member_id")
    df.createOrReplaceTempView("members")

    # Act
    mapper = AutoMapper(view="members",
                        source_view="patients",
                        keys=["member_id"],
                        drop_key_columns=False).columns(
                            dst1="src1",
                            dst2=AutoMapperList(["address1"]),
                            dst3=AutoMapperList(["address1", "address2"]),
                            dst4=AutoMapperList([
                                A.complex(use="usual",
                                          family=A.column("last_name"))
                            ]))

    assert isinstance(mapper, AutoMapper)
    sql_expressions: Dict[str, Column] = mapper.get_column_specs(
        source_df=source_df)
    for column_name, sql_expression in sql_expressions.items():
        print(f"{column_name}: {sql_expression}")

    # Assert
    assert len(sql_expressions) == 4
    assert str(sql_expressions["dst1"]) == str(lit("src1").alias("dst1"))
    assert str(sql_expressions["dst2"]) == str(
        filter(array(lit("address1")), lambda x: x.isNotNull()).alias("dst2"))
    assert str(sql_expressions["dst3"]) == str(
        filter(array(lit("address1"), lit("address2")),
               lambda x: x.isNotNull()).alias("dst3"))
    assert str(sql_expressions["dst4"]) == str(
        filter(
            array(
                struct(
                    lit("usual").alias("use"),
                    col("b.last_name").alias("family"))),
            lambda x: x.isNotNull()).alias("dst4"))

    result_df: DataFrame = mapper.transform(df=df)

    # Assert
    result_df.printSchema()
    result_df.show()

    assert len(result_df.columns) == 5
    assert result_df.where("member_id == 1").select(
        "dst1").collect()[0][0] == "src1"
    assert result_df.where("member_id == 1").select(
        "dst2").collect()[0][0][0] == "address1"

    assert result_df.where("member_id == 1").select(
        "dst3").collect()[0][0][0] == "address1"
    assert result_df.where("member_id == 1").select(
        "dst3").collect()[0][0][1] == "address2"

    assert result_df.where("member_id == 1").select(
        "dst4").collect()[0][0][0][0] == "usual"
    assert result_df.where("member_id == 1").select(
        "dst4").collect()[0][0][0][1] == "Qureshi"