Example #1
def test_median_metric_spark(spark_session):
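    # column.median is resolved in stages: the row-count aggregate partial, then
    # table.row_count, and finally column.median, which uses the row count to
    # locate the middle value of [1, 2, 3].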
    engine = build_spark_engine(
        spark=spark_session,
        df=pd.DataFrame({"a": [1, 2, 3]}, ),
        batch_id="my_id",
    )

    desired_metric = MetricConfiguration(
        metric_name="table.row_count.aggregate_fn",
        metric_domain_kwargs={"column": "a"},
        metric_value_kwargs=dict(),
    )
    metrics = engine.resolve_metrics(metrics_to_resolve=(desired_metric, ))

    row_count = MetricConfiguration(
        metric_name="table.row_count",
        metric_domain_kwargs={},
        metric_value_kwargs=dict(),
        metric_dependencies={"metric_partial_fn": desired_metric},
    )
    metrics = engine.resolve_metrics(metrics_to_resolve=(row_count, ),
                                     metrics=metrics)

    desired_metric = MetricConfiguration(
        metric_name="column.median",
        metric_domain_kwargs={"column": "a"},
        metric_value_kwargs=dict(),
        metric_dependencies={"table.row_count": row_count},
    )
    results = engine.resolve_metrics(metrics_to_resolve=(desired_metric, ),
                                     metrics=metrics)
    assert results == {desired_metric.id: 2}
Example #2
def test_get_compute_domain_with_nonexistent_condition_parser(
    spark_session, basic_spark_df_execution_engine
):
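    # A row_condition supplied with an unknown condition_parser should raise
    # GreatExpectationsError when the compute domain is requested.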
    engine = build_spark_engine(
        spark=spark_session,
        df=pd.DataFrame(
            {"a": [1, 2, 3, 4], "b": [2, 3, 4, None]},
        ),
        batch_id="1234",
    )
    df = engine.dataframe

    # Loading batch data
    engine.load_batch_data(batch_data=df, batch_id="1234")

    # Expect GreatExpectationsError because parser doesn't exist
    with pytest.raises(GreatExpectationsError):
        # noinspection PyUnusedLocal
        data, compute_kwargs, accessor_kwargs = engine.get_compute_domain(
            domain_kwargs={
                "row_condition": "b > 24",
                "condition_parser": "nonexistent",
            },
            domain_type=MetricDomainTypes.IDENTITY,
        )
Example #3
def test_get_compute_domain_with_row_condition_alt(
    spark_session, basic_spark_df_execution_engine
):
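    # With the "spark" condition parser and the identity domain type, the row
    # condition filters the data directly and remains in compute_kwargs.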
    engine = build_spark_engine(
        spark=spark_session,
        df=pd.DataFrame(
            {"a": [1, 2, 3, 4], "b": [2, 3, 4, None]},
        ),
        batch_id="1234",
    )
    df = engine.dataframe
    expected_df = df.where("b > 2")

    # Loading batch data
    engine.load_batch_data(batch_data=df, batch_id="1234")

    data, compute_kwargs, accessor_kwargs = engine.get_compute_domain(
        domain_kwargs={"row_condition": "b > 2", "condition_parser": "spark"},
        domain_type="identity",
    )

    # Ensuring data has been properly queried
    assert dataframes_equal(
        data, expected_df
    ), "Data does not match after getting compute domain"

    # Ensuring compute kwargs have not been modified
    assert (
        "row_condition" in compute_kwargs.keys()
    ), "Row condition should be located within compute kwargs"
    assert accessor_kwargs == {}, "Accessor kwargs have been modified"
Example #4
def test_distinct_metric_spark(spark_session):
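    # column.distinct_values is derived from column.value_counts; None values are
    # excluded, leaving the distinct set {1, 2, 3}.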
    engine = build_spark_engine(
        spark=spark_session,
        df=pd.DataFrame({"a": [1, 2, 1, 2, 3, 3, None]}, ),
        batch_id="my_id",
    )

    desired_metric = MetricConfiguration(
        metric_name="column.value_counts",
        metric_domain_kwargs={"column": "a"},
        metric_value_kwargs={
            "sort": "value",
            "collate": None
        },
    )

    metrics = engine.resolve_metrics(metrics_to_resolve=(desired_metric, ))
    assert pd.Series(index=[1, 2, 3],
                     data=[2, 2, 2]).equals(metrics[desired_metric.id])

    desired_metric = MetricConfiguration(
        metric_name="column.distinct_values",
        metric_domain_kwargs={"column": "a"},
        metric_value_kwargs=dict(),
        metric_dependencies={"column.value_counts": desired_metric},
    )

    results = engine.resolve_metrics(metrics_to_resolve=(desired_metric, ),
                                     metrics=metrics)
    assert results == {desired_metric.id: {1, 2, 3}}
Example #5
def test_get_compute_domain_with_ge_experimental_condition_parser(
    spark_session, basic_spark_df_execution_engine
):
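    # The great_expectations__experimental__ condition parser filters rows for both
    # domain types: the column domain exposes "column" as an accessor kwarg, while
    # the identity domain selects the column directly.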
    engine = build_spark_engine(
        spark=spark_session,
        df=pd.DataFrame(
            {"a": [1, 2, 3, 4], "b": [2, 3, 4, None]},
        ),
        batch_id="1234",
    )
    df = engine.dataframe

    # Filtering expected data based on row condition
    expected_df = df.where("b == 2")

    # Loading batch data
    engine.load_batch_data(batch_data=df, batch_id="1234")

    # Obtaining data from computation
    data, compute_kwargs, accessor_kwargs = engine.get_compute_domain(
        domain_kwargs={
            "column": "b",
            "row_condition": 'col("b") == 2',
            "condition_parser": "great_expectations__experimental__",
        },
        domain_type="column",
    )
    # Ensuring data has been properly queried
    assert dataframes_equal(
        data, expected_df
    ), "Data does not match after getting compute domain"

    # Ensuring compute kwargs have not been modified
    assert (
        "row_condition" in compute_kwargs.keys()
    ), "Row condition should be located within compute kwargs"
    assert accessor_kwargs == {"column": "b"}, "Accessor kwargs have been modified"

    # Should react differently for domain type identity
    data, compute_kwargs, accessor_kwargs = engine.get_compute_domain(
        domain_kwargs={
            "column": "b",
            "row_condition": 'col("b") == 2',
            "condition_parser": "great_expectations__experimental__",
        },
        domain_type="identity",
    )
    # Ensuring data has been properly queried
    assert dataframes_equal(
        data, expected_df.select("b")
    ), "Data does not match after getting compute domain"

    # Ensuring compute kwargs have not been modified
    assert (
        "row_condition" in compute_kwargs.keys()
    ), "Row condition should be located within compute kwargs"
    assert accessor_kwargs == {}, "Accessor kwargs have been modified"
Example #6
def test_dataframe_property_given_loaded_batch(
    spark_session, basic_spark_df_execution_engine
):
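    # The dataframe property should return the batch data the engine was built with.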
    engine = build_spark_engine(
        spark=spark_session,
        df=pd.DataFrame(
            {"a": [1, 5, 22, 3, 5, 10]},
        ),
        batch_id="1234",
    )
    df = engine.dataframe

    # Ensuring the data has not been modified
    assert engine.dataframe == df
Example #7
def test_get_compute_domain_with_unmeetable_row_condition_alt(
    spark_session, basic_spark_df_execution_engine
):
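    # A row condition that matches no rows still resolves for the identity domain,
    # but the same condition raises GreatExpectationsError for column and
    # column_pair domains.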
    engine = build_spark_engine(
        spark=spark_session,
        df=pd.DataFrame(
            {"a": [1, 2, 3, 4], "b": [2, 3, 4, None]},
        ),
        batch_id="1234",
    )
    df = engine.dataframe
    expected_df = df.where("b > 24")

    # Loading batch data
    engine.load_batch_data(batch_data=df, batch_id="1234")

    data, compute_kwargs, accessor_kwargs = engine.get_compute_domain(
        domain_kwargs={"row_condition": "b > 24", "condition_parser": "spark"},
        domain_type="identity",
    )
    # Ensuring data has been properly queried
    assert dataframes_equal(
        data, expected_df
    ), "Data does not match after getting compute domain"

    # Ensuring compute kwargs have not been modified
    assert (
        "row_condition" in compute_kwargs.keys()
    ), "Row condition should be located within compute kwargs"
    assert accessor_kwargs == {}, "Accessor kwargs have been modified"

    # Ensuring errors for column and column_pair domains are caught
    with pytest.raises(GreatExpectationsError):
        # noinspection PyUnusedLocal
        data, compute_kwargs, accessor_kwargs = engine.get_compute_domain(
            domain_kwargs={
                "row_condition": "b > 24",
                "condition_parser": "spark",
            },
            domain_type="column",
        )
    with pytest.raises(GreatExpectationsError) as g:
        # noinspection PyUnusedLocal
        data, compute_kwargs, accessor_kwargs = engine.get_compute_domain(
            domain_kwargs={
                "row_condition": "b > 24",
                "condition_parser": "spark",
            },
            domain_type="column_pair",
        )
Example #8
def test_resolve_metric_bundle_with_nonexistent_metric(
    spark_session, basic_spark_df_execution_engine
):
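    # Resolving a bundle that includes an unregistered metric
    # ("column.does_not_exist") should raise MetricProviderError.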
    engine = build_spark_engine(
        spark=spark_session,
        df=pd.DataFrame(
            {"a": [1, 2, 1, 2, 3, 3], "b": [4, 4, 4, 4, 4, 4]},
        ),
        batch_id="1234",
    )

    desired_metric_1 = MetricConfiguration(
        metric_name="column_values.unique",
        metric_domain_kwargs={"column": "a"},
        metric_value_kwargs=dict(),
    )
    desired_metric_2 = MetricConfiguration(
        metric_name="column.min",
        metric_domain_kwargs={"column": "a"},
        metric_value_kwargs=dict(),
    )
    desired_metric_3 = MetricConfiguration(
        metric_name="column.max",
        metric_domain_kwargs={"column": "b"},
        metric_value_kwargs=dict(),
    )
    desired_metric_4 = MetricConfiguration(
        metric_name="column.does_not_exist",
        metric_domain_kwargs={"column": "b"},
        metric_value_kwargs=dict(),
    )

    # Ensuring a metric provider error is raised if metric does not exist
    with pytest.raises(MetricProviderError) as e:
        # noinspection PyUnusedLocal
        res = engine.resolve_metrics(
            metrics_to_resolve=(
                desired_metric_1,
                desired_metric_2,
                desired_metric_3,
                desired_metric_4,
            )
        )
        print(e)
Example #9
def test_max_metric_spark_column_does_not_exist(spark_session):
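    # Requesting an aggregate partial against a missing column should raise
    # ExecutionEngineError with a descriptive message.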
    engine = build_spark_engine(
        spark=spark_session,
        df=pd.DataFrame({"a": [1, 2, 1]}),
        batch_id="my_id",
    )

    partial_metric = MetricConfiguration(
        metric_name="column.max.aggregate_fn",
        metric_domain_kwargs={"column": "non_existent_column"},
        metric_value_kwargs=dict(),
    )

    with pytest.raises(ge_exceptions.ExecutionEngineError) as eee:
        # noinspection PyUnusedLocal
        metrics = engine.resolve_metrics(metrics_to_resolve=(partial_metric, ))
    assert (
        str(eee.value) ==
        'Error: The column "non_existent_column" in BatchData does not exist.')
Example #10
def test_get_compute_domain_with_column_pair(
    spark_session, basic_spark_df_execution_engine
):
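    # For a column_pair domain the column names become accessor kwargs; for the
    # identity domain they remain in compute kwargs.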
    engine = build_spark_engine(
        spark=spark_session,
        df=pd.DataFrame(
            {"a": [1, 2, 3, 4], "b": [2, 3, 4, None]},
        ),
        batch_id="1234",
    )
    df = engine.dataframe

    data, compute_kwargs, accessor_kwargs = engine.get_compute_domain(
        domain_kwargs={"column_A": "a", "column_B": "b"}, domain_type="column_pair"
    )

    # Ensuring that the data itself is unchanged for a column_pair domain
    assert dataframes_equal(
        data, df
    ), "Data does not match after getting compute domain"
    assert compute_kwargs == {}, "Compute domain kwargs should be existent"
    assert accessor_kwargs == {
        "column_A": "a",
        "column_B": "b",
    }, "Accessor kwargs have been modified"

    data, compute_kwargs, accessor_kwargs = engine.get_compute_domain(
        domain_kwargs={"column_A": "a", "column_B": "b"}, domain_type="identity"
    )

    # Ensuring that the data itself is unchanged for the identity domain
    assert dataframes_equal(
        data, df
    ), "Data does not match after getting compute domain"
    assert compute_kwargs == {
        "column_A": "a",
        "column_B": "b",
    }, "Compute domain kwargs should not be modified"
    assert accessor_kwargs == {}, "Accessor kwargs have been modified"
Example #11
def test_get_compute_domain_with_multicolumn(
    spark_session, basic_spark_df_execution_engine
):
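    # For a multicolumn domain the column list becomes an accessor kwarg; for the
    # identity domain it remains in compute kwargs.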
    engine = build_spark_engine(
        spark=spark_session,
        df=pd.DataFrame(
            {"a": [1, 2, 3, 4], "b": [2, 3, 4, None], "c": [1, 2, 3, None]},
        ),
        batch_id="1234",
    )
    df = engine.dataframe

    data, compute_kwargs, accessor_kwargs = engine.get_compute_domain(
        domain_kwargs={"columns": ["a", "b", "c"]}, domain_type="multicolumn"
    )

    # Ensuring that the data itself is unchanged for a multicolumn domain
    assert dataframes_equal(
        data, df
    ), "Data does not match after getting compute domain"
    assert compute_kwargs == {}, "Compute domain kwargs should be empty"
    assert accessor_kwargs == {
        "columns": ["a", "b", "c"]
    }, "Accessor kwargs have been modified"

    # Checking for identity
    engine.load_batch_data(batch_data=df, batch_id="1234")
    data, compute_kwargs, accessor_kwargs = engine.get_compute_domain(
        domain_kwargs={"columns": ["a", "b", "c"]}, domain_type="identity"
    )

    # Ensuring that the data itself is unchanged for the identity domain
    assert dataframes_equal(
        data, df
    ), "Data does not match after getting compute domain"
    assert compute_kwargs == {
        "columns": ["a", "b", "c"]
    }, "Compute domain kwargs should not change for identity domain"
    assert accessor_kwargs == {}, "Accessor kwargs have been modified"
Example #12
def test_get_compute_domain_with_column_domain_alt(
    spark_session, basic_spark_df_execution_engine
):
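    # A plain column domain leaves the data untouched and moves the column name
    # into accessor kwargs.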
    engine = build_spark_engine(
        spark=spark_session,
        df=pd.DataFrame(
            {"a": [1, 2, 3, 4], "b": [2, 3, 4, None]},
        ),
        batch_id="1234",
    )
    df = engine.dataframe

    data, compute_kwargs, accessor_kwargs = engine.get_compute_domain(
        domain_kwargs={"column": "a"}, domain_type="column"
    )

    # Ensuring that column domain is now an accessor kwarg, and data remains unmodified
    assert dataframes_equal(
        data, df
    ), "Data does not match after getting compute domain"
    assert compute_kwargs == {}, "Compute domain kwargs should be empty"
    assert accessor_kwargs == {"column": "a"}, "Accessor kwargs have been modified"
Example #13
def test_map_unique_spark_column_does_not_exist(spark_session):
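    # A map condition metric against a missing column should raise
    # ExecutionEngineError with a descriptive message.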
    engine = build_spark_engine(
        spark=spark_session,
        df=pd.DataFrame({
            "a": [1, 2, 3, 3, 4, None],
            "b": [None, "foo", "bar", "baz", "qux", "fish"],
        }),
        batch_id="my_id",
    )

    condition_metric = MetricConfiguration(
        metric_name="column_values.unique.condition",
        metric_domain_kwargs={"column": "non_existent_column"},
        metric_value_kwargs=dict(),
    )

    with pytest.raises(ge_exceptions.ExecutionEngineError) as eee:
        # noinspection PyUnusedLocal
        metrics = engine.resolve_metrics(
            metrics_to_resolve=(condition_metric, ))
    assert (
        str(eee.value) ==
        'Error: The column "non_existent_column" in BatchData does not exist.')
Example #14
def test_max_metric_spark_column_exists(spark_session):
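    # column.max is resolved from its aggregate-function partial; the maximum of
    # [1, 2, 1] is 2.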
    engine = build_spark_engine(
        spark=spark_session,
        df=pd.DataFrame({"a": [1, 2, 1]}),
        batch_id="my_id",
    )
    partial_metric = MetricConfiguration(
        metric_name="column.max.aggregate_fn",
        metric_domain_kwargs={"column": "a"},
        metric_value_kwargs=dict(),
    )

    metrics = engine.resolve_metrics(metrics_to_resolve=(partial_metric, ))
    desired_metric = MetricConfiguration(
        metric_name="column.max",
        metric_domain_kwargs={"column": "a"},
        metric_value_kwargs=dict(),
        metric_dependencies={"metric_partial_fn": partial_metric},
    )

    results = engine.resolve_metrics(metrics_to_resolve=(desired_metric, ),
                                     metrics=metrics)
    assert results == {desired_metric.id: 2}
Example #15
def test_sparkdf_batch_aggregate_metrics(caplog, spark_session):
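    # Four aggregate partials over two columns should be bundled into a single pass;
    # the debug log confirms all four metrics were computed on one domain.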
    import datetime

    engine = build_spark_engine(
        spark=spark_session,
        df=pd.DataFrame({
            "a": [1, 2, 1, 2, 3, 3],
            "b": [4, 4, 4, 4, 4, 4]
        }, ),
        batch_id="my_id",
    )

    desired_metric_1 = MetricConfiguration(
        metric_name="column.max.aggregate_fn",
        metric_domain_kwargs={"column": "a"},
        metric_value_kwargs=dict(),
    )
    desired_metric_2 = MetricConfiguration(
        metric_name="column.min.aggregate_fn",
        metric_domain_kwargs={"column": "a"},
        metric_value_kwargs=dict(),
    )
    desired_metric_3 = MetricConfiguration(
        metric_name="column.max.aggregate_fn",
        metric_domain_kwargs={"column": "b"},
        metric_value_kwargs=dict(),
    )
    desired_metric_4 = MetricConfiguration(
        metric_name="column.min.aggregate_fn",
        metric_domain_kwargs={"column": "b"},
        metric_value_kwargs=dict(),
    )
    metrics = engine.resolve_metrics(metrics_to_resolve=(
        desired_metric_1,
        desired_metric_2,
        desired_metric_3,
        desired_metric_4,
    ))
    desired_metric_1 = MetricConfiguration(
        metric_name="column.max",
        metric_domain_kwargs={"column": "a"},
        metric_value_kwargs=dict(),
        metric_dependencies={"metric_partial_fn": desired_metric_1},
    )
    desired_metric_2 = MetricConfiguration(
        metric_name="column.min",
        metric_domain_kwargs={"column": "a"},
        metric_value_kwargs=dict(),
        metric_dependencies={"metric_partial_fn": desired_metric_2},
    )
    desired_metric_3 = MetricConfiguration(
        metric_name="column.max",
        metric_domain_kwargs={"column": "b"},
        metric_value_kwargs=dict(),
        metric_dependencies={"metric_partial_fn": desired_metric_3},
    )
    desired_metric_4 = MetricConfiguration(
        metric_name="column.min",
        metric_domain_kwargs={"column": "b"},
        metric_value_kwargs=dict(),
        metric_dependencies={"metric_partial_fn": desired_metric_4},
    )
    start = datetime.datetime.now()
    caplog.clear()
    caplog.set_level(logging.DEBUG, logger="great_expectations")
    res = engine.resolve_metrics(
        metrics_to_resolve=(
            desired_metric_1,
            desired_metric_2,
            desired_metric_3,
            desired_metric_4,
        ),
        metrics=metrics,
    )
    end = datetime.datetime.now()
    print(end - start)
    assert res[desired_metric_1.id] == 3
    assert res[desired_metric_2.id] == 1
    assert res[desired_metric_3.id] == 4
    assert res[desired_metric_4.id] == 4

    # Check that all four of these metrics were computed on a single domain
    found_message = False
    for record in caplog.records:
        if (record.message ==
                "SparkDFExecutionEngine computed 4 metrics on domain_id ()"):
            found_message = True
    assert found_message
Example #16
def test_z_score_under_threshold_spark(spark_session):
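    # Builds the full z-score dependency chain (mean and stdev partials, z-score map,
    # threshold condition, unexpected-count partial) and expects zero values outside
    # the threshold.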
    engine = build_spark_engine(
        spark=spark_session,
        df=pd.DataFrame({"a": [1, 2, 3, 3, None]}, ),
        batch_id="my_id",
    )

    mean = MetricConfiguration(
        metric_name="column.mean.aggregate_fn",
        metric_domain_kwargs={"column": "a"},
        metric_value_kwargs=dict(),
    )
    stdev = MetricConfiguration(
        metric_name="column.standard_deviation.aggregate_fn",
        metric_domain_kwargs={"column": "a"},
        metric_value_kwargs=dict(),
    )
    desired_metrics = (mean, stdev)
    metrics = engine.resolve_metrics(metrics_to_resolve=desired_metrics)

    mean = MetricConfiguration(
        metric_name="column.mean",
        metric_domain_kwargs={"column": "a"},
        metric_value_kwargs=dict(),
        metric_dependencies={"metric_partial_fn": mean},
    )
    stdev = MetricConfiguration(
        metric_name="column.standard_deviation",
        metric_domain_kwargs={"column": "a"},
        metric_value_kwargs=dict(),
        metric_dependencies={"metric_partial_fn": stdev},
    )
    desired_metrics = (mean, stdev)
    metrics = engine.resolve_metrics(metrics_to_resolve=desired_metrics,
                                     metrics=metrics)

    desired_metric = MetricConfiguration(
        metric_name="column_values.z_score.map",
        metric_domain_kwargs={"column": "a"},
        metric_value_kwargs=dict(),
        metric_dependencies={
            "column.standard_deviation": stdev,
            "column.mean": mean
        },
    )
    results = engine.resolve_metrics(metrics_to_resolve=(desired_metric, ),
                                     metrics=metrics)
    metrics.update(results)
    desired_metric = MetricConfiguration(
        metric_name="column_values.z_score.under_threshold.condition",
        metric_domain_kwargs={"column": "a"},
        metric_value_kwargs={
            "double_sided": True,
            "threshold": 2
        },
        metric_dependencies={"column_values.z_score.map": desired_metric},
    )
    results = engine.resolve_metrics(metrics_to_resolve=(desired_metric, ),
                                     metrics=metrics)
    metrics.update(results)

    desired_metric = MetricConfiguration(
        metric_name=
        "column_values.z_score.under_threshold.unexpected_count.aggregate_fn",
        metric_domain_kwargs={"column": "a"},
        metric_value_kwargs={
            "double_sided": True,
            "threshold": 2
        },
        metric_dependencies={"unexpected_condition": desired_metric},
    )
    results = engine.resolve_metrics(metrics_to_resolve=(desired_metric, ),
                                     metrics=metrics)
    metrics.update(results)

    desired_metric = MetricConfiguration(
        metric_name="column_values.z_score.under_threshold.unexpected_count",
        metric_domain_kwargs={"column": "a"},
        metric_value_kwargs={
            "double_sided": True,
            "threshold": 2
        },
        metric_dependencies={"metric_partial_fn": desired_metric},
    )
    results = engine.resolve_metrics(metrics_to_resolve=(desired_metric, ),
                                     metrics=metrics)
    assert results[desired_metric.id] == 0
Example #17
def test_map_unique_spark_column_exists(spark_session):
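    # The duplicated value 3 yields two unexpected rows; the condition metric feeds
    # the unexpected_count, unexpected_values, unexpected_value_counts, and
    # unexpected_rows metrics.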
    engine = build_spark_engine(
        spark=spark_session,
        df=pd.DataFrame({
            "a": [1, 2, 3, 3, 4, None],
            "b": [None, "foo", "bar", "baz", "qux", "fish"],
        }),
        batch_id="my_id",
    )

    condition_metric = MetricConfiguration(
        metric_name="column_values.unique.condition",
        metric_domain_kwargs={"column": "a"},
        metric_value_kwargs=dict(),
    )
    metrics = engine.resolve_metrics(metrics_to_resolve=(condition_metric, ))

    # unique is a *window* function so does not use the aggregate_fn version of unexpected count
    desired_metric = MetricConfiguration(
        metric_name="column_values.unique.unexpected_count",
        metric_domain_kwargs={"column": "a"},
        metric_value_kwargs=dict(),
        metric_dependencies={"unexpected_condition": condition_metric},
    )
    results = engine.resolve_metrics(metrics_to_resolve=(desired_metric, ),
                                     metrics=metrics)
    assert results[desired_metric.id] == 2

    desired_metric = MetricConfiguration(
        metric_name="column_values.unique.unexpected_values",
        metric_domain_kwargs={"column": "a"},
        metric_value_kwargs={
            "result_format": {
                "result_format": "BASIC",
                "partial_unexpected_count": 20
            }
        },
        metric_dependencies={"unexpected_condition": condition_metric},
    )
    results = engine.resolve_metrics(metrics_to_resolve=(desired_metric, ),
                                     metrics=metrics)
    assert results[desired_metric.id] == [3, 3]

    desired_metric = MetricConfiguration(
        metric_name="column_values.unique.unexpected_value_counts",
        metric_domain_kwargs={"column": "a"},
        metric_value_kwargs={
            "result_format": {
                "result_format": "BASIC",
                "partial_unexpected_count": 20
            }
        },
        metric_dependencies={"unexpected_condition": condition_metric},
    )
    results = engine.resolve_metrics(metrics_to_resolve=(desired_metric, ),
                                     metrics=metrics)
    assert results[desired_metric.id] == [(3, 2)]

    desired_metric = MetricConfiguration(
        metric_name="column_values.unique.unexpected_rows",
        metric_domain_kwargs={"column": "a"},
        metric_value_kwargs={
            "result_format": {
                "result_format": "BASIC",
                "partial_unexpected_count": 20
            }
        },
        metric_dependencies={"unexpected_condition": condition_metric},
    )
    results = engine.resolve_metrics(metrics_to_resolve=(desired_metric, ),
                                     metrics=metrics)
    assert results[desired_metric.id] == [(3, "bar"), (3, "baz")]
Example #18
def test_map_value_set_spark(spark_session, basic_spark_df_execution_engine):
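    # Built via build_spark_engine, the missing value becomes NULL and is ignored
    # (0 unexpected); loading the same pandas data directly keeps it as NaN, which
    # falls outside the value set (1 unexpected).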
    engine = build_spark_engine(
        spark=spark_session,
        df=pd.DataFrame({"a": [1, 2, 3, 3, None]}, ),
        batch_id="my_id",
    )

    condition_metric = MetricConfiguration(
        metric_name="column_values.in_set.condition",
        metric_domain_kwargs={"column": "a"},
        metric_value_kwargs={"value_set": [1, 2, 3]},
    )
    metrics = engine.resolve_metrics(metrics_to_resolve=(condition_metric, ))

    # Note: metric_dependencies is optional here in the config when called from a validator.
    aggregate_partial = MetricConfiguration(
        metric_name="column_values.in_set.unexpected_count.aggregate_fn",
        metric_domain_kwargs={"column": "a"},
        metric_value_kwargs={"value_set": [1, 2, 3]},
        metric_dependencies={"unexpected_condition": condition_metric},
    )
    metrics = engine.resolve_metrics(metrics_to_resolve=(aggregate_partial, ),
                                     metrics=metrics)
    desired_metric = MetricConfiguration(
        metric_name="column_values.in_set.unexpected_count",
        metric_domain_kwargs={"column": "a"},
        metric_value_kwargs={"value_set": [1, 2, 3]},
        metric_dependencies={"metric_partial_fn": aggregate_partial},
    )

    results = engine.resolve_metrics(metrics_to_resolve=(desired_metric, ),
                                     metrics=metrics)
    assert results == {desired_metric.id: 0}

    # We run the same computation again, this time with None being replaced by nan instead of NULL
    # to demonstrate this behavior
    df = pd.DataFrame({"a": [1, 2, 3, 3, None]})
    df = spark_session.createDataFrame(df)
    engine = basic_spark_df_execution_engine
    engine.load_batch_data(batch_id="my_id", batch_data=df)

    condition_metric = MetricConfiguration(
        metric_name="column_values.in_set.condition",
        metric_domain_kwargs={"column": "a"},
        metric_value_kwargs={"value_set": [1, 2, 3]},
    )
    metrics = engine.resolve_metrics(metrics_to_resolve=(condition_metric, ))

    # Note: metric_dependencies is optional here in the config when called from a validator.
    aggregate_partial = MetricConfiguration(
        metric_name="column_values.in_set.unexpected_count.aggregate_fn",
        metric_domain_kwargs={"column": "a"},
        metric_value_kwargs={"value_set": [1, 2, 3]},
        metric_dependencies={"unexpected_condition": condition_metric},
    )
    metrics = engine.resolve_metrics(metrics_to_resolve=(aggregate_partial, ),
                                     metrics=metrics)
    desired_metric = MetricConfiguration(
        metric_name="column_values.in_set.unexpected_count",
        metric_domain_kwargs={"column": "a"},
        metric_value_kwargs={"value_set": [1, 2, 3]},
        metric_dependencies={"metric_partial_fn": aggregate_partial},
    )

    results = engine.resolve_metrics(metrics_to_resolve=(desired_metric, ),
                                     metrics=metrics)
    assert results == {desired_metric.id: 1}