def test_higher_order_function_failures(self):
        """Check that ``transform`` rejects callables it cannot use.

        Every callable listed below is passed to ``transform(col("foo"), ...)``
        and the call is expected to raise ``ValueError``.

        NOTE(review): the lambdas reference ``lit``, which is not imported in
        this method; presumably it is imported at module level -- confirm.
        """
        from pyspark.sql.functions import col, transform

        rejected_callables = (
            # Should fail with varargs
            lambda *x: lit(1),
            # Should fail with kwargs
            lambda **x: lit(1),
            # Should fail with nullary function
            lambda: lit(1),
            # Should fail with quaternary function
            lambda x1, x2, x3, x4: lit(1),
            # Should fail if function doesn't return Column
            lambda x: 1,
        )

        # Same order as the individual checks; the first callable that does
        # not raise ValueError fails the test.
        for rejected in rejected_callables:
            with self.assertRaises(ValueError):
                transform(col("foo"), rejected)
# Example no. 2
# 0
    def get_column_spec(
        self, source_df: Optional[DataFrame], current_column: Optional[Column]
    ) -> Column:
        """Build the column spec that applies ``self.value`` to each element.

        :param source_df: forwarded unchanged to ``self.column`` and to
            ``self.value`` when their column specs are requested
        :param current_column: forwarded to ``self.column`` only; for
            ``self.value`` the element handed over by ``transform`` is used
        :return: ``transform`` of the spec of ``self.column`` with a function
            that yields the spec of ``self.value`` for each element
        """

        def spec_for_element(x: Column) -> Column:
            # ``x`` is the argument ``transform`` passes to this callback; it
            # becomes the current column for the value's own spec.
            return self.value.get_column_spec(source_df=source_df, current_column=x)

        outer_spec: Column = self.column.get_column_spec(
            source_df=source_df, current_column=current_column
        )
        return transform(outer_spec, spec_for_element)
def test_automapper_select_one(spark_session: SparkSession) -> None:
    """Check the column spec and result of ``filter(...).select_one(...)``.

    Loads ``data.json`` from the directory of this file into the ``patients``
    view, maps ``age`` from the ``identifier`` entries whose ``system`` is the
    NPI URL, compares the generated expression with a hand-written
    ``transform(filter(...))[0]`` expression, and finally checks the ``age``
    values of the first two result rows.
    """
    # Arrange
    clean_spark_session(spark_session)
    json_path: Path = Path(__file__).parent / "data.json"

    source_df: DataFrame = spark_session.read.json(str(json_path), multiLine=True)
    source_df.createOrReplaceTempView("patients")
    source_df.show(truncate=False)

    # Act
    mapper_base = AutoMapper(view="members", source_view="patients")
    npi_identifiers = A.column("identifier").filter(
        lambda x: x["system"] == "http://hl7.org/fhir/sid/us-npi"
    )
    mapper = mapper_base.columns(age=npi_identifiers.select_one(A.field("_.value")))

    # Assert
    assert isinstance(mapper, AutoMapper)
    column_specs: Dict[str, Column] = mapper.get_column_specs(source_df=source_df)
    for spec_name, spec in column_specs.items():
        print(f"{spec_name}: {spec}")

    actual_age: Column = column_specs["age"]
    expected_age: Column = transform(
        filter(
            "b.identifier",
            lambda x: x["system"] == lit("http://hl7.org/fhir/sid/us-npi"),
        ),
        lambda x: x["value"],
    )[0].alias("age")
    assert_compare_expressions(actual_age, expected_age)

    result_df: DataFrame = mapper.transform(df=source_df)
    result_df.show(truncate=False)

    # The result is collected once per checked row.
    for row_index, expected_value in enumerate(("1730325416", "1467734301")):
        assert result_df.select("age").collect()[row_index][0] == expected_value
# Example no. 4
# 0
def test_automapper_transform(spark_session: SparkSession) -> None:
    """Check the column spec and result of ``A.transform`` over ``identifier``.

    Loads ``data.json`` from the directory of this file into the ``patients``
    view, maps ``age`` by transforming each ``identifier`` entry into a
    ``bar``/``bar2`` structure, compares the generated expression with a
    hand-written ``transform(...)`` expression, and checks one value of the
    first result row.
    """
    # Arrange
    clean_spark_session(spark_session)
    json_path: Path = Path(__file__).parent / "data.json"

    source_df: DataFrame = spark_session.read.json(str(json_path), multiLine=True)
    source_df.createOrReplaceTempView("patients")
    source_df.show(truncate=False)

    # Act
    mapper_base = AutoMapper(view="members", source_view="patients")
    age_mapping = A.transform(
        A.column("identifier"),
        A.complex(bar=A.field("value"), bar2=A.field("system")),
    )
    mapper = mapper_base.complex(MyObject(age=age_mapping))

    # Assert
    assert isinstance(mapper, AutoMapper)
    column_specs: Dict[str, Column] = mapper.get_column_specs(source_df=source_df)
    for spec_name, spec in column_specs.items():
        print(f"{spec_name}: {spec}")

    actual_age: Column = column_specs["age"]
    expected_age: Column = transform(
        "b.identifier",
        lambda x: struct(
            col("x[value]").alias("bar"),
            col("x[system]").alias("bar2"),
        ),
    ).alias("age")
    assert_compare_expressions(actual_age, expected_age)

    result_df: DataFrame = mapper.transform(df=source_df)
    result_df.show(truncate=False)

    # first row -> "age" column -> first array element -> first field
    first_row = result_df.select("age").collect()[0]
    first_element = first_row[0][0]
    assert first_element[0] == "123"
# Example no. 5
# 0
def test_automapper_nested_array_filter_with_parent_column(
    spark_session: SparkSession,
) -> None:
    """Check ``A.nested_array_filter`` when the match value is ``{parent}.name``.

    A one-row ``patients`` view is created with ``location``, ``schedule`` and
    ``single_level`` array columns.  The mapper selects, for every location,
    its name plus one schedule obtained by filtering ``schedule`` on
    ``actor.reference``.  The generated ``location`` expression is compared
    with a hand-written ``transform``/``filter``/``exists`` expression, and the
    mapped rows are then checked element by element.

    NOTE(review): the expected expression is aliased ``"___location"`` while
    the mapped column is named ``location``; presumably the mapper or
    ``assert_compare_expressions`` accounts for that prefix -- confirm.
    """
    # Arrange: element types of the array columns.
    location_element = StructType([StructField("name", StringType(), True)])
    actor_element = StructType([StructField("reference", StringType(), True)])
    schedule_element = StructType(
        [
            StructField("name", StringType(), True),
            StructField("actor", ArrayType(actor_element, True)),
        ]
    )
    single_level_element = StructType(
        [StructField("reference", StringType(), True)]
    )
    schema = StructType(
        [
            StructField("row_id", dataType=IntegerType(), nullable=False),
            StructField("location", dataType=ArrayType(location_element)),
            StructField("schedule", dataType=ArrayType(schedule_element)),
            StructField("single_level", dataType=ArrayType(single_level_element)),
        ]
    )

    # Arrange: the single input row, one value per schema column.
    location_values = [{"name": "location-100"}, {"name": "location-200"}]
    schedule_values = [
        {
            "name": "schedule-1",
            "actor": [
                {"reference": "location-100"},
                {"reference": "practitioner-role-100"},
            ],
        },
        {
            "name": "schedule-2",
            "actor": [
                {"reference": "location-200"},
                {"reference": "practitioner-role-200"},
            ],
        },
    ]
    single_level_values = [
        {"reference": "location-100"},
        {"reference": "location-200"},
    ]
    input_rows = [(1, location_values, schedule_values, single_level_values)]
    spark_session.createDataFrame(input_rows, schema).createOrReplaceTempView(
        "patients"
    )

    source_df: DataFrame = spark_session.table("patients")

    # Act
    mapper_base = AutoMapper(view="schedule", source_view="patients", keys=["row_id"])
    location_column = A.column("location")
    location_name_field = A.field("name")
    matching_schedules = A.nested_array_filter(
        array_field=A.column("schedule"),
        inner_array_field=A.field("actor"),
        match_property="reference",
        match_value=A.field("{parent}.name"),
    )
    first_matching_schedule = matching_schedules.select_one(
        AutoMapperElasticSearchSchedule(name=A.field("name"))
    )
    mapper = mapper_base.columns(
        location=location_column.select(
            AutoMapperElasticSearchLocation(
                name=location_name_field,
                scheduling=first_matching_schedule,
            )
        )
    )

    assert isinstance(mapper, AutoMapper)
    column_specs: Dict[str, Column] = mapper.get_column_specs(source_df=source_df)
    print("------COLUMN SPECS------")
    for spec_name, spec in column_specs.items():
        print(f"{spec_name}: {spec}")

    actual_location: Column = column_specs["location"]
    # Expected: per location ``l``, keep the schedules ``s`` that have an actor
    # ``a`` whose reference equals the location name, then take the first one.
    expected_location: Column = transform(
        col("b.location"),
        lambda l: struct(
            l["name"].alias("name"),
            transform(
                filter(
                    col("b.schedule"),
                    lambda s: exists(
                        s["actor"],
                        lambda a: a["reference"] == l["name"],  # type: ignore
                    ),
                ),
                lambda s: struct(s["name"].alias("name")),
            )[0].alias("scheduling"),
        ),
    ).alias("___location")
    assert_compare_expressions(actual_location, expected_location)

    result_df: DataFrame = mapper.transform(df=source_df)

    # Assert
    # For debugging: result_df.printSchema() / result_df.show(truncate=False)
    mapped_locations = result_df.collect()[0].location
    for position, entry in enumerate(mapped_locations, start=1):
        entry_name, entry_scheduling = entry.name, entry.scheduling
        assert entry_name == f"location-{position}00"
        # presumably ``scheduling`` is a one-field struct, hence length 1 -- confirm
        assert len(entry_scheduling) == 1
        assert entry_scheduling.name == f"schedule-{position}"