from os import environ
from pathlib import Path
from typing import Dict

from pyspark.sql import Column, DataFrame, SparkSession
from pyspark.sql.functions import col, exists, filter, lit, struct, transform
from pyspark.sql.types import (
    ArrayType,
    IntegerType,
    StringType,
    StructField,
    StructType,
)

from spark_auto_mapper.automappers.automapper import AutoMapper
from spark_auto_mapper.helpers.automapper_helpers import AutoMapperHelpers as A

# The helpers and complex types below are project-local; the import paths
# shown here are assumptions and may need adjusting to your test layout.
from tests.conftest import clean_spark_session  # assumed path
from spark_auto_mapper.helpers.expression_comparer import (  # assumed path
    assert_compare_expressions,
)
from tests.schemas import (  # assumed path: project-specific complex types
    AutoMapperElasticSearchLocation,
    AutoMapperElasticSearchSchedule,
)


def test_automapper_nested_array_filter_simple_with_array(
    spark_session: SparkSession,
) -> None:
    clean_spark_session(spark_session)
    data_dir: Path = Path(__file__).parent.joinpath("./")

    environ["LOGLEVEL"] = "DEBUG"

    data_json_file: Path = data_dir.joinpath("data.json")
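    # data.json (next to this test file) is assumed to hold two patient rows,
    # each shaped roughly like:
    #   {"array1": [{..., "array2": [{"reference": "..."}]}]}
    # Row 0 has no inner element whose reference is "bar"; row 1 does
    # (see the assertions at the end of this test).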

    source_df: DataFrame = spark_session.read.json(
        str(data_json_file), multiLine=True
    )

    source_df.createOrReplaceTempView("patients")

    source_df.show(truncate=False)

    # Act
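    # nested_array_filter keeps only those elements of the outer array
    # (array1) whose inner array (array2) contains at least one struct whose
    # match_property ("reference") equals the match_value ("bar").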
    mapper = AutoMapper(view="members", source_view="patients").columns(
        age=A.nested_array_filter(
            array_field=A.column("array1"),
            inner_array_field=A.field("array2"),
            match_property="reference",
            match_value=A.text("bar"),
        )
    )

    assert isinstance(mapper, AutoMapper)
    sql_expressions: Dict[str, Column] = mapper.get_column_specs(source_df=source_df)
    for column_name, sql_expression in sql_expressions.items():
        print(f"{column_name}: {sql_expression}")

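    # The generated spec should be a filter() over array1 wrapping an exists()
    # over each element's array2; the source view is aliased as "b" in the
    # generated expressions.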
    assert_compare_expressions(
        sql_expressions["age"],
        filter(
            col("b.array1"),
            lambda y: exists(
                y["array2"],
                lambda x: x["reference"] == lit("bar").cast("string"),
            ),
        ).alias("age"),
    )
    result_df: DataFrame = mapper.transform(df=source_df)

    result_df.printSchema()
    result_df.show(truncate=False)

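    # Assert: row 0 keeps no outer elements; row 1 keeps the element whose
    # array2 references "bar".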
    assert result_df.count() == 2
    assert result_df.select("age").collect()[0][0] == []
    assert (
        result_df.select("age").collect()[1][0][0]["array2"][0]["reference"] == "bar"
    )


def test_automapper_nested_array_filter_with_parent_column(
    spark_session: SparkSession,
) -> None:
    schema = StructType(
        [
            StructField("row_id", dataType=IntegerType(), nullable=False),
            StructField(
                "location",
                dataType=ArrayType(
                    StructType(
                        [
                            StructField("name", StringType(), True),
                        ]
                    )
                ),
            ),
            StructField(
                "schedule",
                dataType=ArrayType(
                    StructType(
                        [
                            StructField("name", StringType(), True),
                            StructField(
                                "actor",
                                ArrayType(
                                    StructType(
                                        [StructField("reference", StringType(), True)]
                                    ),
                                    True,
                                ),
                            ),
                        ]
                    )
                ),
            ),
            StructField(
                "single_level",
                dataType=ArrayType(
                    StructType(
                        [
                            StructField("reference", StringType(), True),
                        ]
                    )
                ),
            ),
        ]
    )
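    # A single row ties two locations to two schedules; each schedule's actor
    # array references exactly one of the locations by name.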
    spark_session.createDataFrame(
        [
            (
                1,
                [{"name": "location-100"}, {"name": "location-200"}],
                [
                    {
                        "name": "schedule-1",
                        "actor": [
                            {"reference": "location-100"},
                            {"reference": "practitioner-role-100"},
                        ],
                    },
                    {
                        "name": "schedule-2",
                        "actor": [
                            {"reference": "location-200"},
                            {"reference": "practitioner-role-200"},
                        ],
                    },
                ],
                [{"reference": "location-100"}, {"reference": "location-200"}],
            )
        ],
        schema,
    ).createOrReplaceTempView("patients")

    source_df: DataFrame = spark_session.table("patients")

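    # "{parent}" inside match_value resolves to the current element of the
    # enclosing select (each location struct), so every location keeps only
    # the schedules whose actor array references that location's name.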
    mapper = AutoMapper(
        view="schedule", source_view="patients", keys=["row_id"]
    ).columns(
        location=A.column("location").select(
            AutoMapperElasticSearchLocation(
                name=A.field("name"),
                scheduling=A.nested_array_filter(
                    array_field=A.column("schedule"),
                    inner_array_field=A.field("actor"),
                    match_property="reference",
                    match_value=A.field("{parent}.name"),
                ).select_one(AutoMapperElasticSearchSchedule(name=A.field("name"))),
            )
        )
    )

    assert isinstance(mapper, AutoMapper)
    sql_expressions: Dict[str, Column] = mapper.get_column_specs(source_df=source_df)
    print("------COLUMN SPECS------")
    for column_name, sql_expression in sql_expressions.items():
        print(f"{column_name}: {sql_expression}")
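    # Expected spec: transform() over location builds one struct per location;
    # its scheduling field filters schedule with an exists() over actor that
    # matches reference against the parent location's name, and select_one
    # takes element [0] of the filtered result.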
    assert_compare_expressions(
        sql_expressions["location"],
        transform(
            col("b.location"),
            lambda l: (
                struct(
                    l["name"].alias("name"),
                    transform(
                        filter(
                            col("b.schedule"),
                            lambda s: exists(
                                s["actor"],
                                lambda a: a["reference"] == l["name"],  # type: ignore
                            ),
                        ),
                        lambda s: struct(s["name"].alias("name")),
                    )[0].alias("scheduling"),
                )
            ),
        ).alias("___location"),
    )
    result_df: DataFrame = mapper.transform(df=source_df)

    # Assert
    # result_df.printSchema()
    # result_df.show(truncate=False)
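    # Each location should end up paired with exactly the one schedule that
    # references it.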
    location_row = result_df.collect()[0].location
    for index, location in enumerate(location_row):
        location_name = location.name
        location_scheduling = location.scheduling
        assert location_name == f"location-{index + 1}00"
        assert len(location_scheduling) == 1
        assert location_scheduling.name == f"schedule-{index + 1}"