Example #1
    def _get_df_comparison(self):
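        # Dispatch on link_type: dedupe jobs pass a single dataframe to
        # block_using_rules, link jobs pass separate left and right dataframes.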

        if self.settings["link_type"] == "dedupe_only":
            return block_using_rules(self.settings, self.spark, df=self.df)

        if self.settings["link_type"] in ("link_only", "link_and_dedupe"):
            return block_using_rules(self.settings,
                                     self.spark,
                                     df_l=self.df_l,
                                     df_r=self.df_r)
Example #2
def test_link_dedupe(spark, link_dedupe_data, link_dedupe_data_repeat_ids):
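    # link_and_dedupe blocking produces comparisons both within and across the
    # two datasets; with no blocking rules it falls back to the full cartesian product.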

    settings = {
        "link_type": "link_and_dedupe",
        "comparison_columns": [{"col_name": "first_name"}, {"col_name": "surname"}],
        "blocking_rules": ["l.first_name = r.first_name", "l.surname = r.surname"],
    }
    settings = complete_settings_dict(settings, spark=spark)
    df_l = link_dedupe_data["df_l"]
    df_r = link_dedupe_data["df_r"]
    df = vertically_concatenate_datasets([df_l, df_r])
    df_comparison = block_using_rules(settings, df, spark)
    df = df_comparison.toPandas()
    df = df.sort_values(["unique_id_l", "unique_id_r"])

    assert list(df["unique_id_l"]) == [1, 1, 2, 2, 7, 8]
    assert list(df["unique_id_r"]) == [7, 9, 8, 9, 9, 9]

    df_l = link_dedupe_data_repeat_ids["df_l"]
    df_r = link_dedupe_data_repeat_ids["df_r"]
    df = vertically_concatenate_datasets([df_l, df_r])
    df = block_using_rules(settings, df, spark)
    df = df.toPandas()
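    # Ids repeat across the two datasets, so build composite ids from unique_id
    # plus the first letter of source_dataset to keep the assertions unambiguous.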
    df["u_l"] = df["unique_id_l"].astype(str) + df["source_dataset_l"].str.slice(0, 1)
    df["u_r"] = df["unique_id_r"].astype(str) + df["source_dataset_r"].str.slice(0, 1)

    df = df.sort_values(
        ["source_dataset_l", "source_dataset_r", "unique_id_l", "unique_id_r"]
    )

    assert list(df["u_l"]) == ["2l", "1l", "1l", "2l", "2l", "3l", "3l", "1r", "2r"]
    assert list(df["u_r"]) == ["3l", "1r", "3r", "2r", "3r", "2r", "3r", "3r", "3r"]

    settings = {
        "link_type": "link_and_dedupe",
        "comparison_columns": [{"col_name": "first_name"}, {"col_name": "surname"}],
        "blocking_rules": [],
    }
    settings = complete_settings_dict(settings, spark=spark)

    df_l = link_dedupe_data_repeat_ids["df_l"]
    df_r = link_dedupe_data_repeat_ids["df_r"]
    df = vertically_concatenate_datasets([df_l, df_r])
    df = block_using_rules(settings, df, spark)
    df = df.toPandas()

    df["u_l"] = df["unique_id_l"].astype(str) + df["source_dataset_l"].str.slice(0, 1)
    df["u_r"] = df["unique_id_r"].astype(str) + df["source_dataset_r"].str.slice(0, 1)
    df = df.sort_values(
        ["source_dataset_l", "unique_id_l", "source_dataset_r", "unique_id_r"]
    )
    # fmt: off
    assert list(df["u_l"]) == ["1l", "1l", "1l", "1l", "1l", "2l", "2l", "2l", "2l", "3l", "3l", "3l", "1r", "1r", "2r"]
    assert list(df["u_r"]) == ["2l", "3l", "1r", "2r", "3r", "3l", "1r", "2r", "3r", "1r", "2r", "3r", "2r", "3r", "3r"]
Example #3
def test_expectation(spark, sqlite_con_1, params_1, gamma_settings_1):
    dfpd = pd.read_sql("select * from test1", sqlite_con_1)
    df = spark.createDataFrame(dfpd)

    gamma_settings_1["blocking_rules"] = [
        "l.mob = r.mob",
        "l.surname = r.surname",
    ]

    df_comparison = block_using_rules(gamma_settings_1, df=df, spark=spark)

    df_gammas = add_gammas(df_comparison, gamma_settings_1, spark)

    # df_e = iterate(df_gammas, spark, params_1, num_iterations=1)
    df_e = run_expectation_step(df_gammas, params_1, gamma_settings_1, spark)

    df_e_pd = df_e.toPandas()
    df_e_pd = df_e_pd.sort_values(["unique_id_l", "unique_id_r"])

    correct_list = [
        0.893617021,
        0.705882353,
        0.705882353,
        0.189189189,
        0.189189189,
        0.893617021,
        0.375,
        0.375,
    ]
    result_list = list(df_e_pd["match_probability"].astype(float))

    for i in zip(result_list, correct_list):
        assert i[0] == pytest.approx(i[1])
Example #4
    def get_scored_comparisons(self, compute_ll=False):
        """Use the EM algorithm to estimate model parameters and return match probabilities.

        Returns:
            DataFrame: A spark dataframe including a match probability column
        """

        df_comparison = block_using_rules(self.settings_dict, self.df,
                                          self.spark)

        df_gammas = add_gammas(df_comparison, self.settings_dict, self.spark)

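        # Break Spark lineage on the blocked comparisons before iterating
        # (see https://github.com/moj-analytical-services/splink/issues/187).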
        df_gammas = self.break_lineage_blocked_comparisons(
            df_gammas, self.spark)

        df_e = iterate(
            df_gammas,
            self.model,
            self.spark,
            compute_ll=compute_ll,
            save_state_fn=self.save_state_fn,
        )

        # In case the user's break lineage function has persisted it
        df_gammas.unpersist()

        df_e = self.break_lineage_scored_comparisons(df_e, self.spark)

        df_e_adj = self.make_term_frequency_adjustments(df_e)

        df_e.unpersist()

        return df_e_adj
Example #5
def test_tiny_numbers(spark, sqlite_con_1):

    # Regression test, see https://github.com/moj-analytical-services/splink/issues/48
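    # Checks that an extremely small m probability (~6e-25) does not break
    # the expectation step.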

    dfpd = pd.read_sql("select * from test1", sqlite_con_1)
    df = spark.createDataFrame(dfpd)

    settings = {
        "link_type": "dedupe_only",
        "proportion_of_matches": 0.4,
        "comparison_columns": [
            {
                "col_name": "mob",
                "num_levels": 2,
                "m_probabilities": [5.9380419956766985e-25, 1 - 5.9380419956766985e-25],
                "u_probabilities": [0.8, 0.2],
            },
            {"col_name": "surname", "num_levels": 2,},
        ],
        "blocking_rules": ["l.mob = r.mob", "l.surname = r.surname",],
    }

    settings = complete_settings_dict(settings, spark=None)

    df_comparison = block_using_rules(settings, df=df, spark=spark)

    df_gammas = add_gammas(df_comparison, settings, spark)
    params = Params(settings, spark="supress_warnings")

    df_e = run_expectation_step(df_gammas, params, settings, spark)
Example #6
def test_link_option_link(spark, link_dedupe_data_repeat_ids):
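    # link_only blocking compares records across the two datasets only,
    # never within a single dataset.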
    settings = {
        "link_type": "link_only",
        "comparison_columns": [{
            "col_name": "first_name"
        }, {
            "col_name": "surname"
        }],
        "blocking_rules":
        ["l.first_name = r.first_name", "l.surname = r.surname"]
    }
    settings = complete_settings_dict(settings, spark=None)
    dfpd_l = pd.read_sql("select * from df_l", link_dedupe_data_repeat_ids)
    df_l = spark.createDataFrame(dfpd_l)
    dfpd_r = pd.read_sql("select * from df_r", link_dedupe_data_repeat_ids)
    df_r = spark.createDataFrame(dfpd_r)
    df = block_using_rules(settings, spark, df_l=df_l, df_r=df_r)
    df = df.toPandas()

    df = df.sort_values(["unique_id_l", "unique_id_r"])

    assert list(df["unique_id_l"]) == [1, 1, 2, 2, 3, 3]
    assert list(df["unique_id_r"]) == [1, 3, 2, 3, 2, 3]

    # Test cartesian version

    settings = {
        "link_type": "link_only",
        "comparison_columns": [{
            "col_name": "first_name"
        }, {
            "col_name": "surname"
        }],
        "blocking_rules": []
    }
    settings = complete_settings_dict(settings, spark=None)
    dfpd_l = pd.read_sql("select * from df_l", link_dedupe_data_repeat_ids)
    df_l = spark.createDataFrame(dfpd_l)
    dfpd_r = pd.read_sql("select * from df_r", link_dedupe_data_repeat_ids)
    df_r = spark.createDataFrame(dfpd_r)
    df = block_using_rules(settings, spark, df_l=df_l, df_r=df_r)
    df = df.toPandas()

    df = df.sort_values(["unique_id_l", "unique_id_r"])

    assert list(df["unique_id_l"]) == [1, 1, 1, 2, 2, 2, 3, 3, 3]
    assert list(df["unique_id_r"]) == [1, 2, 3, 1, 2, 3, 1, 2, 3]
Example #7
    def manually_apply_fellegi_sunter_weights(self):
        """Compute match probabilities from m and u probabilities specified in the splink settings object

        Returns:
            DataFrame: A spark dataframe including a match probability column
        """
        df_comparison = block_using_rules(self.settings_dict, self.df,
                                          self.spark)
        df_gammas = add_gammas(df_comparison, self.settings_dict, self.spark)
        return run_expectation_step(df_gammas, self.model, self.spark)
Example #8
def test_link_only(spark, link_dedupe_data, link_dedupe_data_repeat_ids):

    settings = {
        "link_type": "link_only",
        "comparison_columns": [{"col_name": "first_name"}, {"col_name": "surname"}],
        "blocking_rules": ["l.first_name = r.first_name", "l.surname = r.surname"],
    }
    settings = complete_settings_dict(settings, spark)
    df_l = link_dedupe_data["df_l"]
    df_r = link_dedupe_data["df_r"]
    df = vertically_concatenate_datasets([df_l, df_r])
    df_comparison = block_using_rules(settings, df, spark)
    df = df_comparison.toPandas()
    df = df.sort_values(["unique_id_l", "unique_id_r"])

    assert list(df["unique_id_l"]) == [1, 1, 2, 2]
    assert list(df["unique_id_r"]) == [7, 9, 8, 9]

    df_l = link_dedupe_data_repeat_ids["df_l"]
    df_r = link_dedupe_data_repeat_ids["df_r"]
    df = vertically_concatenate_datasets([df_l, df_r])
    df_comparison = block_using_rules(settings, df, spark)
    df = df_comparison.toPandas()
    df = df.sort_values(["unique_id_l", "unique_id_r"])

    assert list(df["unique_id_l"]) == [1, 1, 2, 2, 3, 3]
    assert list(df["unique_id_r"]) == [1, 3, 2, 3, 2, 3]

    settings = {
        "link_type": "link_only",
        "comparison_columns": [{"col_name": "first_name"}, {"col_name": "surname"}],
        "blocking_rules": [],
    }
    settings = complete_settings_dict(settings, spark)
    df = vertically_concatenate_datasets([df_l, df_r])
    df_comparison = block_using_rules(settings, df, spark)
    df = df_comparison.toPandas()
    df = df.sort_values(["unique_id_l", "unique_id_r"])

    assert list(df["unique_id_l"]) == [1, 1, 1, 2, 2, 2, 3, 3, 3]
    assert list(df["unique_id_r"]) == [1, 2, 3, 1, 2, 3, 1, 2, 3]
Example #9
def test_dedupe(spark, link_dedupe_data_repeat_ids):
    # This test checks that we only get one result when a comparison is hit by multiple blocking rules
    settings = {
        "link_type": "dedupe_only",
        "comparison_columns": [{"col_name": "first_name"}, {"col_name": "surname"}],
        "blocking_rules": ["l.first_name = r.first_name", "l.surname = r.surname"],
    }
    settings = complete_settings_dict(settings, spark=None)
    df_l = link_dedupe_data_repeat_ids["df_l"]
    df = block_using_rules(settings, df_l, spark)
    df = df.toPandas()

    df = df.sort_values(["unique_id_l", "unique_id_r"])

    assert list(df["unique_id_l"]) == [2]
    assert list(df["unique_id_r"]) == [3]

    # Is the source dataset column retained if it exists?
    assert "source_dataset_l" in list(df.columns)

    df_l = link_dedupe_data_repeat_ids["df_l"]
    df_l = df_l.drop("source_dataset")

    df = block_using_rules(settings, df_l, spark)

    # Is the source dataset column excluded if it doesn't exist?
    assert "source_dataset_l" not in list(df.columns)

    # Is the source dataset column included if it has a different name?
    df_l = link_dedupe_data_repeat_ids["df_l"]
    df_l = df_l.withColumnRenamed("source_dataset", "source_ds")

    settings["source_dataset_column_name"] = "source_ds"

    df = block_using_rules(settings, df_l, spark)

    # Is the renamed source dataset column retained under its new name?
    assert "source_ds_l" in list(df.columns)
    assert "source_dataset_l" not in list(df.columns)
Example #10
    def manually_apply_fellegi_sunter_weights(self):
        """Compute match probabilities from m and u probabilities specified in the splink settings object

        Returns:
            DataFrame: A spark dataframe including a match probability column
        """
        df_comparison = block_using_rules(self.settings_dict, self.df,
                                          self.spark)
        df_gammas = add_gammas(df_comparison, self.settings_dict, self.spark)
        # see https://github.com/moj-analytical-services/splink/issues/187
        df_gammas = self.break_lineage_blocked_comparisons(
            df_gammas, self.spark)
        return run_expectation_step(df_gammas, self.model, self.spark)
Example #11
def test_no_blocking(spark, link_dedupe_data):
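    # 2 left records x 3 right records -> 6 comparisons when there are no blocking rules.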
    settings = {
        "link_type": "link_only",
        "comparison_columns": [{"col_name": "first_name"},
                            {"col_name": "surname"}],
        "blocking_rules": []
    }
    settings = complete_settings_dict(settings, spark=None)
    dfpd_l = pd.read_sql("select * from df_l", link_dedupe_data)
    dfpd_r = pd.read_sql("select * from df_r", link_dedupe_data)
    df_l = spark.createDataFrame(dfpd_l)
    df_r = spark.createDataFrame(dfpd_r)


    df_comparison = block_using_rules(settings, spark, df_l=df_l, df_r=df_r)
    df = df_comparison.toPandas()
    df = df.sort_values(["unique_id_l", "unique_id_r"])

    assert list(df["unique_id_l"]) == [1,1,1,2,2,2]
    assert list(df["unique_id_r"]) == [7,8,9,7,8,9]
Example #12
def test_link_option_dedupe_only(spark, link_dedupe_data_repeat_ids):
    settings = {
        "link_type": "dedupe_only",
        "comparison_columns": [{"col_name": "first_name"},
                            {"col_name": "surname"}],
        "blocking_rules": [
            "l.first_name = r.first_name",
            "l.surname = r.surname"
        ]
    }
    settings = complete_settings_dict(settings, spark=None)
    dfpd = pd.read_sql("select * from df_l", link_dedupe_data_repeat_ids)
    df = spark.createDataFrame(dfpd)

    df = block_using_rules(settings, spark, df=df)
    df = df.toPandas()

    df = df.sort_values(["unique_id_l", "unique_id_r"])

    assert list(df["unique_id_l"]) == [2]
    assert list(df["unique_id_r"]) == [3]
Example #13
def test_link_option_link_dedupe(spark, link_dedupe_data_repeat_ids):
    settings = {
        "link_type": "link_and_dedupe",
        "comparison_columns": [{
            "col_name": "first_name"
        }, {
            "col_name": "surname"
        }],
        "blocking_rules":
        ["l.first_name = r.first_name", "l.surname = r.surname"]
    }
    settings = complete_settings_dict(settings, spark=None)
    dfpd_l = pd.read_sql("select * from df_l", link_dedupe_data_repeat_ids)
    df_l = spark.createDataFrame(dfpd_l)
    dfpd_r = pd.read_sql("select * from df_r", link_dedupe_data_repeat_ids)
    df_r = spark.createDataFrame(dfpd_r)
    df = block_using_rules(settings, spark, df_l=df_l, df_r=df_r)
    df = df.toPandas()
    df["u_l"] = df["unique_id_l"].astype(
        str) + df["_source_table_l"].str.slice(0, 1)
    df["u_r"] = df["unique_id_r"].astype(
        str) + df["_source_table_r"].str.slice(0, 1)
    df = df.sort_values(
        ["_source_table_l", "_source_table_r", "unique_id_l", "unique_id_r"])

    assert list(
        df["u_l"]) == ['2l', '1l', '1l', '2l', '2l', '3l', '3l', '1r', '2r']
    assert list(
        df["u_r"]) == ['3l', '1r', '3r', '2r', '3r', '2r', '3r', '3r', '3r']

    # Same for no blocking rules = cartesian product

    settings = {
        "link_type": "link_and_dedupe",
        "comparison_columns": [{
            "col_name": "first_name"
        }, {
            "col_name": "surname"
        }],
        "blocking_rules": []
    }
    settings = complete_settings_dict(settings, spark=None)
    dfpd_l = pd.read_sql("select * from df_l", link_dedupe_data_repeat_ids)
    df_l = spark.createDataFrame(dfpd_l)
    dfpd_r = pd.read_sql("select * from df_r", link_dedupe_data_repeat_ids)
    df_r = spark.createDataFrame(dfpd_r)
    df = block_using_rules(settings, spark, df_l=df_l, df_r=df_r)
    df = df.toPandas()

    df["u_l"] = df["unique_id_l"].astype(
        str) + df["_source_table_l"].str.slice(0, 1)
    df["u_r"] = df["unique_id_r"].astype(
        str) + df["_source_table_r"].str.slice(0, 1)
    df = df.sort_values(
        ["_source_table_l", "unique_id_l", "_source_table_r", "unique_id_r"])

    assert list(df["u_l"]) == [
        '1l', '1l', '1l', '1l', '1l', '2l', '2l', '2l', '2l', '3l', '3l', '3l',
        '1r', '1r', '2r'
    ]
    assert list(df["u_r"]) == [
        '2l', '3l', '1r', '2r', '3r', '3l', '1r', '2r', '3r', '1r', '2r', '3r',
        '2r', '3r', '3r'
    ]

    # Same when the blocking_rules key is omitted entirely (defaults to the cartesian product)

    settings = {
        "link_type": "link_and_dedupe",
        "comparison_columns": [{
            "col_name": "first_name"
        }, {
            "col_name": "surname"
        }]
    }
    settings = complete_settings_dict(settings, spark=None)
    dfpd_l = pd.read_sql("select * from df_l", link_dedupe_data_repeat_ids)
    df_l = spark.createDataFrame(dfpd_l)
    dfpd_r = pd.read_sql("select * from df_r", link_dedupe_data_repeat_ids)
    df_r = spark.createDataFrame(dfpd_r)
    df = block_using_rules(settings, spark, df_l=df_l, df_r=df_r)
    df = df.toPandas()
    df["u_l"] = df["unique_id_l"].astype(
        str) + df["_source_table_l"].str.slice(0, 1)
    df["u_r"] = df["unique_id_r"].astype(
        str) + df["_source_table_r"].str.slice(0, 1)
    df = df.sort_values(
        ["_source_table_l", "unique_id_l", "_source_table_r", "unique_id_r"])

    assert list(df["u_l"]) == [
        '1l', '1l', '1l', '1l', '1l', '2l', '2l', '2l', '2l', '3l', '3l', '3l',
        '1r', '1r', '2r'
    ]
    assert list(df["u_r"]) == [
        '2l', '3l', '1r', '2r', '3r', '3l', '1r', '2r', '3r', '1r', '2r', '3r',
        '2r', '3r', '3r'
    ]
Example #14
def test_iterate(spark, sqlite_con_1, params_1, gamma_settings_1):

    original_params = copy.deepcopy(params_1.params)
    dfpd = pd.read_sql("select * from test1", sqlite_con_1)
    df = spark.createDataFrame(dfpd)

    rules = [
        "l.mob = r.mob",
        "l.surname = r.surname",
    ]

    gamma_settings_1["blocking_rules"] = rules

    df_comparison = block_using_rules(gamma_settings_1, df=df, spark=spark)

    df_gammas = add_gammas(df_comparison, gamma_settings_1, spark)

    gamma_settings_1["max_iterations"] = 1
    df_e = iterate(df_gammas, params_1, gamma_settings_1, spark)

    assert params_1.params["λ"] == pytest.approx(0.540922141)

    assert params_1.params["π"]["gamma_mob"]["prob_dist_match"]["level_0"][
        "probability"] == pytest.approx(0.087438272, abs=0.0001)
    assert params_1.params["π"]["gamma_surname"]["prob_dist_non_match"][
        "level_1"]["probability"] == pytest.approx(0.160167628, abs=0.0001)

    first_it_params = copy.deepcopy(params_1.params)

    df_e_pd = df_e.toPandas()
    df_e_pd = df_e_pd.sort_values(["unique_id_l", "unique_id_r"])

    correct_list = [
        0.658602114,
        0.796821727,
        0.796821727,
        0.189486495,
        0.189486495,
        0.658602114,
        0.495063367,
        0.495063367,
    ]
    result_list = list(df_e_pd["match_probability"].astype(float))

    for i in zip(result_list, correct_list):
        assert i[0] == pytest.approx(i[1], abs=0.0001)

    # Does it still work with another iteration?
    gamma_settings_1["max_iterations"] = 1
    df_e = iterate(df_gammas, params_1, gamma_settings_1, spark)
    assert params_1.params["λ"] == pytest.approx(0.534993426, abs=0.0001)

    assert params_1.params["π"]["gamma_mob"]["prob_dist_match"]["level_0"][
        "probability"] == pytest.approx(0.088546179, abs=0.0001)
    assert params_1.params["π"]["gamma_surname"]["prob_dist_non_match"][
        "level_1"]["probability"] == pytest.approx(0.109234086, abs=0.0001)

    ## Test whether the params object is correctly storing the iteration history

    assert params_1.param_history[0] == original_params
    assert params_1.param_history[1] == first_it_params

    ## Now test whether the params dict is correctly converted to a dataframe

    data = params_1._convert_params_dict_to_dataframe(original_params)
    val1 = {
        "gamma": "gamma_mob",
        "match": 0,
        "value_of_gamma": "level_0",
        "probability": 0.8,
        "value": 0,
        "column": "mob",
    }
    val2 = {
        "gamma": "gamma_surname",
        "match": 1,
        "value_of_gamma": "level_1",
        "probability": 0.2,
        "value": 1,
        "column": "surname",
    }

    assert val1 in data
    assert val2 in data

    correct_list = [{
        "iteration": 0,
        "λ": 0.4
    }, {
        "iteration": 1,
        "λ": 0.540922141
    }]

    result_list = params_1._iteration_history_df_lambdas()

    for i in zip(result_list, correct_list):
        assert i[0]["iteration"] == i[1]["iteration"]
        assert i[0]["λ"] == pytest.approx(i[1]["λ"])

    result_list = params_1._iteration_history_df_gammas()

    val1 = {
        "iteration": 0,
        "gamma": "gamma_mob",
        "match": 0,
        "value_of_gamma": "level_0",
        "probability": 0.8,
        "value": 0,
        "column": "mob",
    }
    assert val1 in result_list

    val2 = {
        "iteration": 1,
        "gamma": "gamma_surname",
        "match": 0,
        "value_of_gamma": "level_1",
        "probability": 0.160167628,
        "value": 1,
        "column": "surname",
    }

    for r in result_list:
        if r["iteration"] == 1:
            if r["gamma"] == "gamma_surname":
                if r["match"] == 0:
                    if r["value"] == 1:
                        record = r

    for k, v in record.items():
        expected_value = val2[k]
        if k == "probability":
            assert v == pytest.approx(expected_value, abs=0.0001)
        else:
            assert v == expected_value

    # Test whether saving and loading parameters works
    import tempfile

    dir = tempfile.TemporaryDirectory()
    fname = os.path.join(dir.name, "params.json")

    # print(params_1.params)
    # import json
    # print(json.dumps(params_1.to_dict(), indent=4))

    params_1.save_params_to_json_file(fname)

    from splink.params import load_params_from_json

    p = load_params_from_json(fname)
    assert p.params["λ"] == pytest.approx(params_1.params["λ"])


def test_expectation_and_maximisation(spark):
    settings = {
        "link_type":
        "dedupe_only",
        "proportion_of_matches":
        0.4,
        "comparison_columns": [
            {
                "col_name": "mob",
                "num_levels": 2,
                "m_probabilities": [0.1, 0.9],
                "u_probabilities": [0.8, 0.2],
            },
            {
                "custom_name": "surname",
                "custom_columns_used": ["surname"],
                "num_levels": 3,
                "case_expression": """
                    case
                    when surname_l is null or surname_r is null then -1
                    when surname_l = surname_r then 2
                    when substr(surname_l,1, 3) =  substr(surname_r, 1, 3) then 1
                    else 0
                    end
                    as gamma_surname
                    """,
                "m_probabilities": [0.1, 0.2, 0.7],
                "u_probabilities": [0.5, 0.25, 0.25],
            },
        ],
        "blocking_rules": [
            "l.mob = r.mob",
            "l.surname = r.surname",
        ],
        "retain_intermediate_calculation_columns":
        True,
    }

    rows = [
        {"unique_id": 1, "mob": 10, "surname": "Linacre"},
        {"unique_id": 2, "mob": 10, "surname": "Linacre"},
        {"unique_id": 3, "mob": 10, "surname": "Linacer"},
        {"unique_id": 4, "mob": 7, "surname": "Smith"},
        {"unique_id": 5, "mob": 8, "surname": "Smith"},
        {"unique_id": 6, "mob": 8, "surname": "Smith"},
        {"unique_id": 7, "mob": 8, "surname": "Jones"},
    ]

    df_input = spark.createDataFrame(Row(**x) for x in rows)
    df_input.persist()
    params = Model(settings, spark)

    df_comparison = block_using_rules(
        params.current_settings_obj.settings_dict, df_input, spark)
    df_gammas = add_gammas(df_comparison,
                           params.current_settings_obj.settings_dict, spark)
    df_gammas.persist()
    df_e = run_expectation_step(df_gammas, params, spark)
    df_e = df_e.sort("unique_id_l", "unique_id_r")

    df_e.persist()

    ################################################
    # Test probabilities correctly assigned
    ################################################

    df = df_e.toPandas()
    cols_to_keep = [
        "prob_gamma_mob_match",
        "prob_gamma_mob_non_match",
        "prob_gamma_surname_match",
        "prob_gamma_surname_non_match",
    ]
    pd_df_result = df[cols_to_keep][:4]

    df_correct = [
        {
            "prob_gamma_mob_match": 0.9,
            "prob_gamma_mob_non_match": 0.2,
            "prob_gamma_surname_match": 0.7,
            "prob_gamma_surname_non_match": 0.25,
        },
        {
            "prob_gamma_mob_match": 0.9,
            "prob_gamma_mob_non_match": 0.2,
            "prob_gamma_surname_match": 0.2,
            "prob_gamma_surname_non_match": 0.25,
        },
        {
            "prob_gamma_mob_match": 0.9,
            "prob_gamma_mob_non_match": 0.2,
            "prob_gamma_surname_match": 0.2,
            "prob_gamma_surname_non_match": 0.25,
        },
        {
            "prob_gamma_mob_match": 0.1,
            "prob_gamma_mob_non_match": 0.8,
            "prob_gamma_surname_match": 0.7,
            "prob_gamma_surname_non_match": 0.25,
        },
    ]

    pd_df_correct = pd.DataFrame(df_correct)

    assert_frame_equal(pd_df_correct, pd_df_result)

    ################################################
    # Test match probabilities correctly calculated
    ################################################

    result_list = list(df["match_probability"])
    # See https://github.com/moj-analytical-services/splink/blob/master/tests/expectation_maximisation_test_answers.xlsx
    # for derivation of these numbers
    correct_list = [
        0.893617021,
        0.705882353,
        0.705882353,
        0.189189189,
        0.189189189,
        0.893617021,
        0.375,
        0.375,
    ]
    assert result_list == pytest.approx(correct_list)

    ################################################
    # Test new probabilities correctly calculated
    ################################################

    run_maximisation_step(df_e, params, spark)

    new_lambda = params.current_settings_obj["proportion_of_matches"]

    # See https://github.com/moj-analytical-services/splink/blob/master/tests/expectation_maximisation_test_answers.xlsx
    # for derivation of these numbers
    assert new_lambda == pytest.approx(0.540922141)

    rows = [
        ["mob", 0, 0.087438272, 0.441543191],
        ["mob", 1, 0.912561728, 0.558456809],
        ["surname", 0, 0.173315146, 0.340356209],
        ["surname", 1, 0.326240275, 0.160167628],
        ["surname", 2, 0.500444578, 0.499476163],
    ]

    settings_obj = params.current_settings_obj

    for r in rows:
        cc = settings_obj.get_comparison_column(r[0])
        level_dict = cc.level_as_dict(r[1])
        assert level_dict["m_probability"] == pytest.approx(r[2])
        assert level_dict["u_probability"] == pytest.approx(r[3])

    ################################################
    # Test revised probabilities correctly used
    ################################################

    df_e = run_expectation_step(df_gammas, params, spark)
    df_e = df_e.sort("unique_id_l", "unique_id_r")
    result_list = list(df_e.toPandas()["match_probability"])

    correct_list = [
        0.658602114,
        0.796821727,
        0.796821727,
        0.189486495,
        0.189486495,
        0.658602114,
        0.495063367,
        0.495063367,
    ]
    assert result_list == pytest.approx(correct_list)

    run_maximisation_step(df_e, params, spark)
    new_lambda = params.current_settings_obj["proportion_of_matches"]
    assert new_lambda == pytest.approx(0.534993426)

    rows = [
        ["mob", 0, 0.088546179, 0.435753788],
        ["mob", 1, 0.911453821, 0.564246212],
        ["surname", 0, 0.231340865, 0.27146747],
        ["surname", 1, 0.372351177, 0.109234086],
        ["surname", 2, 0.396307958, 0.619298443],
    ]

    settings_obj = params.current_settings_obj

    for r in rows:
        cc = settings_obj.get_comparison_column(r[0])
        level_dict = cc.level_as_dict(r[1])
        assert level_dict["m_probability"] == pytest.approx(r[2])
        assert level_dict["u_probability"] == pytest.approx(r[3])

    ################################################
    # Test whether saving and loading params works
    # (If we load params, does the expectation step yield same answer)
    ################################################
    import tempfile

    dir = tempfile.TemporaryDirectory()
    fname = os.path.join(dir.name, "params.json")

    df_e = run_expectation_step(df_gammas, params, spark)
    params.save_model_to_json_file(fname)

    from splink.model import load_model_from_json

    p = load_model_from_json(fname)

    df_e_2 = run_expectation_step(df_gammas, p, spark)

    assert_frame_equal(df_e.toPandas(), df_e_2.toPandas())