def test_global_lambda_calc(spark):

    # Imagine 10% of comparisons in the cartesian product are matches.
    # Further imagine that amongst records which are matches, 90% of the time surname matches,
    # and that amongst records which are non-matches, 1% of the time surname matches.
    # So:
    # 9% of comparisons are matches with a surname match
    # 0.9% are non-matches with a surname match
    # So the proportion of matches amongst surname matches is 9/9.9 = 90.909090...%
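
    # A quick arithmetic check of the reasoning above (plain Python, no splink
    # APIs involved): invert Bayes' rule to recover the global lambda from the
    # lambda observed within surname-blocked comparisons.
    lambda_blocked = 0.90909090909090  # P(match | surname agrees)
    m, u = 0.9, 0.01  # P(surname agrees | match), P(surname agrees | non-match)
    # lambda_blocked = m*p / (m*p + u*(1-p))  =>  solve for p:
    p = lambda_blocked * u / (m * (1 - lambda_blocked) + lambda_blocked * u)
    assert p == pytest.approx(0.10)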

    settings_1 = {
        "link_type":
        "dedupe_only",
        "proportion_of_matches":
        0.90909090909090,
        "blocking_rules": ["l.surname = r.surname"],
        "comparison_columns": [{
            "col_name": "first_name",
            "m_probabilities": [0.3, 0.7],
            "u_probabilities": [0.8, 0.2],
        }],
    }

    settings_2 = {
        "link_type":
        "dedupe_only",
        "proportion_of_matches":
        0.3,
        "blocking_rules": ["l.first_name = r.first_name"],
        "comparison_columns": [{
            "col_name": "surname",
            "m_probabilities": [0.1, 0.9],
            "u_probabilities": [0.99, 0.01],
        }],
    }

    model_1 = Model(settings_1, spark)
    model_2 = Model(settings_2, spark)

    surname_cc = model_2.current_settings_obj.get_comparison_column("surname")
    dict1 = {
        "name": "dob and first name block",
        "model": model_1,
        "comparison_columns_for_global_lambda": [surname_cc],
    }

    mc = ModelCombiner([dict1])
    settings = mc.get_combined_settings_dict()
    assert settings["proportion_of_matches"] == pytest.approx(0.10)
Example 2
def model_example(spark):

    case_expr = """
    case
    when email_l is null or email_r is null then -1
    when email_l = email_r then 1
    else 0
    end
    as gamma_my_custom
    """

    settings = {
        "link_type": "dedupe_only",
        "proportion_of_matches": 0.2,
        "comparison_columns": [
            {"col_name": "fname"},
            {"col_name": "sname", "num_levels": 3},
            {
                "custom_name": "my_custom",
                "custom_columns_used": ["email", "city"],
                "case_expression": case_expr,
                "num_levels": 2,
            },
        ],
        "blocking_rules": [],
    }

    model = Model(settings, spark=spark)

    yield model
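
# For illustration, a plain-Python sketch of the mapping the case expression
# above encodes (a hypothetical helper, not part of splink): -1 when either
# email is null, 1 when they agree, 0 otherwise.
def gamma_my_custom(email_l, email_r):
    if email_l is None or email_r is None:
        return -1  # null level: comparison is not possible
    return 1 if email_l == email_r else 0


assert gamma_my_custom(None, "a@example.com") == -1
assert gamma_my_custom("a@example.com", "a@example.com") == 1
assert gamma_my_custom("a@example.com", "b@example.com") == 0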
Example 3
    def __init__(
        self,
        settings: dict,
        df_or_dfs: Union[DataFrame, List[DataFrame]],
        spark: SparkSession,
        save_state_fn: Callable = None,
        break_lineage_blocked_comparisons: Callable = default_break_lineage_blocked_comparisons,
        break_lineage_scored_comparisons: Callable = default_break_lineage_scored_comparisons,
    ):
        """Splink data linker

        Provides easy access to the core user-facing functionality of splink

        Args:
            settings (dict): splink settings dictionary
            df_or_dfs (Union[DataFrame, List[DataFrame]]): Either a single Spark dataframe to dedupe, or a list of Spark dataframes to link and/or dedupe. Where `link_type` is `dedupe_only`, this should be a single dataframe to dedupe. Where `link_type` is `link_only` or `link_and_dedupe`, it should be a list of dfs.  Requires conformant dataframes (i.e. they must have the same columns)
            spark (SparkSession): SparkSession object
            save_state_fn (function, optional):  A function provided by the user that takes one argument, model (i.e. a Model from splink.model), and is executed each iteration.  This is a hook that allows the user to save the state between iterations, which is mostly useful for very large jobs which may need to be restarted from where they left off if they fail.
            break_lineage_blocked_comparisons (function, optional): Large jobs will likely run into memory errors unless the lineage is broken after blocking.  This is a user-provided function that takes two arguments - df and spark - and allows the user to break lineage.  For example, the function might save df to the AWS s3 file system, and then reload it from the saved files.
            break_lineage_scored_comparisons (function, optional): Large jobs will likely run into memory errors unless the lineage is broken after comparisons are scored and before term frequency adjustments.  This is a user-provided function that takes two arguments - df and spark - and allows the user to break lineage.  For example, the function might save df to the AWS s3 file system, and then reload it from the saved files.
        """

        self.spark = spark
        self.break_lineage_blocked_comparisons = break_lineage_blocked_comparisons
        self.break_lineage_scored_comparisons = break_lineage_scored_comparisons
        _check_jaro_registered(spark)

        validate_settings_against_schema(settings)
        validate_link_type(df_or_dfs, settings)

        self.model = Model(settings, spark)
        self.settings_dict = self.model.current_settings_obj.settings_dict
        self.settings_dict = normalise_probabilities(self.settings_dict)
        validate_probabilities(self.settings_dict)
        # dfs is a list of dfs irrespective of whether input was a df or list of dfs
        if isinstance(df_or_dfs, DataFrame):
            dfs = [df_or_dfs]
        else:
            dfs = df_or_dfs

        self.df = vertically_concatenate_datasets(dfs)
        validate_input_datasets(self.df, self.model.current_settings_obj)
        self.save_state_fn = save_state_fn
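
# A hedged sketch of a user-supplied lineage-breaking function of the kind the
# docstring above describes. The local parquet path is hypothetical; on AWS the
# function might write to s3 and reload from there instead.
def break_lineage_via_parquet(df, spark, path="/tmp/splink_checkpoint.parquet"):
    df.write.mode("overwrite").parquet(path)  # materialise the comparisons to disk
    return spark.read.parquet(path)  # reload with a short, fresh lineage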
Example 4

def estimate(df_gammas: DataFrame, settings: dict, spark: SparkSession):
    """Take a pandas dataframe of gammas and estimate a splink model

    Args:
        df_gammas (DataFrame): Pandas dataframe of gamma values
        settings (dict): Splink settings dictionary
        spark (SparkSession): SparkSession object

    Returns:
        Tuple of (DataFrame, Model): the scored comparisons and the estimated model
    """

    settings["retain_matching_columns"] = False

    df = spark.createDataFrame(df_gammas)

    model = Model(settings, spark)

    df_e = iterate(df, model, spark)

    return df_e, model
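
# A hedged usage sketch for estimate(): the toy gammas dataframe and settings
# are illustrative, and `spark` is assumed to be an active SparkSession.
import pandas as pd

toy_gammas = pd.DataFrame(
    {
        "unique_id_l": [1, 1],
        "unique_id_r": [2, 3],
        "gamma_fname": [1, 0],  # gamma_<col> convention, as used elsewhere here
    }
)
toy_settings = {
    "link_type": "dedupe_only",
    "proportion_of_matches": 0.2,
    "comparison_columns": [{"col_name": "fname"}],
    "blocking_rules": [],
}
# df_e, model = estimate(toy_gammas, toy_settings, spark)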
Example 5
class Splink:
    def __init__(
        self,
        settings: dict,
        df_or_dfs: Union[DataFrame, List[DataFrame]],
        spark: SparkSession,
        save_state_fn: Callable = None,
        break_lineage_blocked_comparisons: Callable = default_break_lineage_blocked_comparisons,
        break_lineage_scored_comparisons: Callable = default_break_lineage_scored_comparisons,
    ):
        """Splink data linker

        Provides easy access to the core user-facing functionality of splink

        Args:
            settings (dict): splink settings dictionary
            df_or_dfs (Union[DataFrame, List[DataFrame]]): Either a single Spark dataframe to dedupe, or a list of Spark dataframes to link and/or dedupe. Where `link_type` is `dedupe_only`, this should be a single dataframe to dedupe. Where `link_type` is `link_only` or `link_and_dedupe`, it should be a list of dfs.  Requires conformant dataframes (i.e. they must have the same columns)
            spark (SparkSession): SparkSession object
            save_state_fn (function, optional):  A function provided by the user that takes one argument, model (i.e. a Model from splink.model), and is executed each iteration.  This is a hook that allows the user to save the state between iterations, which is mostly useful for very large jobs which may need to be restarted from where they left off if they fail.
            break_lineage_blocked_comparisons (function, optional): Large jobs will likely run into memory errors unless the lineage is broken after blocking.  This is a user-provided function that takes two arguments - df and spark - and allows the user to break lineage.  For example, the function might save df to the AWS s3 file system, and then reload it from the saved files.
            break_lineage_scored_comparisons (function, optional): Large jobs will likely run into memory errors unless the lineage is broken after comparisons are scored and before term frequency adjustments.  This is a user-provided function that takes two arguments - df and spark - and allows the user to break lineage.  For example, the function might save df to the AWS s3 file system, and then reload it from the saved files.
        """

        self.spark = spark
        self.break_lineage_blocked_comparisons = break_lineage_blocked_comparisons
        self.break_lineage_scored_comparisons = break_lineage_scored_comparisons
        _check_jaro_registered(spark)

        validate_settings(settings)
        validate_link_type(df_or_dfs, settings)
        self.model = Model(settings, spark)
        self.settings_dict = self.model.current_settings_obj.settings_dict

        # dfs is a list of dfs irrespective of whether input was a df or list of dfs
        if isinstance(df_or_dfs, DataFrame):
            dfs = [df_or_dfs]
        else:
            dfs = df_or_dfs

        self.df = vertically_concatenate_datasets(dfs)
        validate_input_datasets(self.df, self.model.current_settings_obj)
        self.save_state_fn = save_state_fn

    def manually_apply_fellegi_sunter_weights(self):
        """Compute match probabilities from m and u probabilities specified in the splink settings object

        Returns:
            DataFrame: A spark dataframe including a match probability column
        """
        df_comparison = block_using_rules(self.settings_dict, self.df, self.spark)
        df_gammas = add_gammas(df_comparison, self.settings_dict, self.spark)
        return run_expectation_step(df_gammas, self.model, self.spark)

    def get_scored_comparisons(self, compute_ll=False):
        """Use the EM algorithm to estimate model parameters and return match probabilities.

        Args:
            compute_ll (bool, optional): If True, compute the log likelihood at each iteration. Defaults to False.

        Returns:
            DataFrame: A spark dataframe including a match probability column
        """

        df_comparison = block_using_rules(self.settings_dict, self.df, self.spark)

        df_gammas = add_gammas(df_comparison, self.settings_dict, self.spark)

        df_gammas = self.break_lineage_blocked_comparisons(df_gammas, self.spark)

        df_e = iterate(
            df_gammas,
            self.model,
            self.spark,
            compute_ll=compute_ll,
            save_state_fn=self.save_state_fn,
        )

        # In case the user's break lineage function has persisted it
        df_gammas.unpersist()

        df_e = self.break_lineage_scored_comparisons(df_e, self.spark)

        df_e_adj = self.make_term_frequency_adjustments(df_e)

        df_e.unpersist()

        return df_e_adj

    def make_term_frequency_adjustments(self, df_e: DataFrame):
        """Take the outputs of 'get_scored_comparisons' and make term frequency adjustments on designated columns in the settings dictionary

        Args:
            df_e (DataFrame): A dataframe produced by the get_scored_comparisons method

        Returns:
            DataFrame: A spark dataframe including a column with term frequency adjusted match probabilities
        """

        return make_adjustment_for_term_frequencies(
            df_e,
            self.model,
            retain_adjustment_columns=True,
            spark=self.spark,
        )

    def save_model_as_json(self, path: str, overwrite=False):
        """Save model (settings, parameters and parameter history) as a json file so it can later be re-loaded using load_from_json

        Args:
            path (str): Path to the json file.
            overwrite (bool): Whether to overwrite the file if it exists
        """
        self.model.save_model_to_json_file(path, overwrite=overwrite)
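
# A hedged end-to-end sketch of the Splink class above, assuming an active
# SparkSession `spark` and an input dataframe `df` containing the columns the
# settings reference. Illustrative only, not a definitive recipe.
def run_splink_job(df, spark):
    settings = {
        "link_type": "dedupe_only",
        "proportion_of_matches": 0.3,
        "comparison_columns": [{"col_name": "first_name"}, {"col_name": "surname"}],
        "blocking_rules": ["l.surname = r.surname"],
    }
    linker = Splink(settings, df, spark)
    df_e = linker.get_scored_comparisons()  # EM estimation + match probabilities
    linker.save_model_as_json("saved_model.json", overwrite=True)
    return df_e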
def test_average_calc_m_u(spark):
    settings_1 = {
        "link_type":
        "link_and_dedupe",
        "blocking_rules": ["l.forename = r.forename"],
        "comparison_columns": [
            {
                "col_name": "surname",
                "num_levels": 3,
                "m_probabilities": [0.1, 0.4, 0.5],
                "u_probabilities": [0.8, 0.1, 0.1],
            },
            {
                "col_name": "email",
                "num_levels": 2,
                "m_probabilities": [0.1, 0.9],
                "u_probabilities": [0.9, 0.1],
            },
        ],
    }

    settings_2 = {
        "link_type":
        "link_and_dedupe",
        "blocking_rules": ["l.surname = r.surname"],
        "comparison_columns": [
            {
                "col_name": "forename",
                "num_levels": 2,
                "m_probabilities": [0.1, 0.9],
                "u_probabilities": [0.9, 0.1],
            },
            {
                "col_name": "email",
                "num_levels": 2,
                "m_probabilities": [0.1, 0.9],
                "u_probabilities": [0.85, 0.15],
            },
        ],
    }

    settings_3 = {
        "link_type":
        "link_and_dedupe",
        "blocking_rules": ["l.dob = r.dob"],
        "comparison_columns": [
            {
                "col_name": "forename",
                "num_levels": 3,
                "m_probabilities": [0.1, 0.4, 0.5],
                "u_probabilities": [0.8, 0.1, 0.1],
            },
            {
                "col_name": "surname",
                "num_levels": 3,
                "m_probabilities": [0.2, 0.4, 0.4],
                "u_probabilities": [0.8, 0.1, 0.1],
            },
            {
                "col_name": "email",
                "num_levels": 2,
                "m_probabilities": [0.1, 0.9],
                "u_probabilities": [0.7, 0.3],
            },
        ],
    }

    model_1 = Model(settings_1, spark)
    model_2 = Model(settings_2, spark)
    model_3 = Model(settings_3, spark)

    dict1 = {
        "name": "first name block",
        "model": model_1,
        # "comparison_columns_for_global_lambda": [first_name_cc],
    }

    dict2 = {
        "name": "surname block",
        "model": model_2,
        # "comparison_columns_for_global_lambda": [surname_cc],
    }

    dict3 = {
        "name": "dob block",
        "model": model_3,
        # "comparison_columns_for_global_lambda": [dob_cc],
    }

    mc = ModelCombiner([dict1, dict2, dict3])

    settings_dict = mc.get_combined_settings_dict(median)

    settings = Settings(settings_dict)
    email = settings.get_comparison_column("email")
    actual = email["u_probabilities"][0]
    expected = median([0.9, 0.85, 0.7])
    assert actual == pytest.approx(expected)

    surname = settings.get_comparison_column("surname")
    actual = surname["m_probabilities"][2]
    expected = median([0.4, 0.5])
    assert actual == pytest.approx(expected)

    assert len(settings_dict["blocking_rules"]) == 3

    settings_4_with_nulls = {
        "link_type":
        "link_and_dedupe",
        "blocking_rules": ["l.email = r.email"],
        "comparison_columns": [
            {
                "col_name": "forename",
                "num_levels": 3,
                "m_probabilities": [None, 0.4, 0.5],
                "u_probabilities": [0.8, 0.1, 0.1],
            },
            {
                "col_name": "surname",
                "num_levels": 3,
                "m_probabilities": [0.1, 0.4, 0.5],
                "u_probabilities": [0.8, 0.1, 0.1],
            },
        ],
    }

    model_4 = Model(settings_4_with_nulls, spark)

    dict4 = {
        "name": "email block",
        "model": model_4,
        # "comparison_columns_for_global_lambda": [dob_cc],
    }

    mc = ModelCombiner([dict4])

    with pytest.warns(UserWarning):
        settings_dict = mc.get_combined_settings_dict(median)

    mc = ModelCombiner([dict1, dict2, dict3, dict4])
    settings = Settings(settings_dict)
    forename = settings.get_comparison_column("forename")
    actual = forename["m_probabilities"][0]
    assert actual is None
def test_expectation_and_maximisation(spark):
    settings = {
        "link_type":
        "dedupe_only",
        "proportion_of_matches":
        0.4,
        "comparison_columns": [
            {
                "col_name": "mob",
                "num_levels": 2,
                "m_probabilities": [0.1, 0.9],
                "u_probabilities": [0.8, 0.2],
            },
            {
                "custom_name": "surname",
                "custom_columns_used": ["surname"],
                "num_levels": 3,
                "case_expression": """
                    case
                    when surname_l is null or surname_r is null then -1
                    when surname_l = surname_r then 2
                    when substr(surname_l,1, 3) =  substr(surname_r, 1, 3) then 1
                    else 0
                    end
                    as gamma_surname
                    """,
                "m_probabilities": [0.1, 0.2, 0.7],
                "u_probabilities": [0.5, 0.25, 0.25],
            },
        ],
        "blocking_rules": [
            "l.mob = r.mob",
            "l.surname = r.surname",
        ],
        "retain_intermediate_calculation_columns":
        True,
    }

    rows = [
        {"unique_id": 1, "mob": 10, "surname": "Linacre"},
        {"unique_id": 2, "mob": 10, "surname": "Linacre"},
        {"unique_id": 3, "mob": 10, "surname": "Linacer"},  # near-miss spelling: exercises the fuzzy surname level
        {"unique_id": 4, "mob": 7, "surname": "Smith"},
        {"unique_id": 5, "mob": 8, "surname": "Smith"},
        {"unique_id": 6, "mob": 8, "surname": "Smith"},
        {"unique_id": 7, "mob": 8, "surname": "Jones"},
    ]

    df_input = spark.createDataFrame(Row(**x) for x in rows)
    df_input.persist()
    params = Model(settings, spark)

    df_comparison = block_using_rules(
        params.current_settings_obj.settings_dict, df_input, spark)
    df_gammas = add_gammas(df_comparison,
                           params.current_settings_obj.settings_dict, spark)
    df_gammas.persist()
    df_e = run_expectation_step(df_gammas, params, spark)
    df_e = df_e.sort("unique_id_l", "unique_id_r")

    df_e.persist()

    ################################################
    # Test probabilities correctly assigned
    ################################################

    df = df_e.toPandas()
    cols_to_keep = [
        "prob_gamma_mob_match",
        "prob_gamma_mob_non_match",
        "prob_gamma_surname_match",
        "prob_gamma_surname_non_match",
    ]
    pd_df_result = df[cols_to_keep][:4]

    df_correct = [
        {
            "prob_gamma_mob_match": 0.9,
            "prob_gamma_mob_non_match": 0.2,
            "prob_gamma_surname_match": 0.7,
            "prob_gamma_surname_non_match": 0.25,
        },
        {
            "prob_gamma_mob_match": 0.9,
            "prob_gamma_mob_non_match": 0.2,
            "prob_gamma_surname_match": 0.2,
            "prob_gamma_surname_non_match": 0.25,
        },
        {
            "prob_gamma_mob_match": 0.9,
            "prob_gamma_mob_non_match": 0.2,
            "prob_gamma_surname_match": 0.2,
            "prob_gamma_surname_non_match": 0.25,
        },
        {
            "prob_gamma_mob_match": 0.1,
            "prob_gamma_mob_non_match": 0.8,
            "prob_gamma_surname_match": 0.7,
            "prob_gamma_surname_non_match": 0.25,
        },
    ]

    pd_df_correct = pd.DataFrame(df_correct)

    assert_frame_equal(pd_df_correct, pd_df_result)

    ################################################
    # Test match probabilities correctly calculated
    ################################################

    result_list = list(df["match_probability"])
    # See https://github.com/moj-analytical-services/splink/blob/master/tests/expectation_maximisation_test_answers.xlsx
    # for derivation of these numbers
    correct_list = [
        0.893617021,
        0.705882353,
        0.705882353,
        0.189189189,
        0.189189189,
        0.893617021,
        0.375,
        0.375,
    ]
    assert result_list == pytest.approx(correct_list)
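
    # The first value can be reproduced by hand from the starting parameters:
    # both mob and surname agree fully for the first row, so
    # P(match) = lam*m_mob*m_surname / (lam*m_mob*m_surname + (1-lam)*u_mob*u_surname)
    lam = 0.4
    numerator = lam * 0.9 * 0.7
    denominator = numerator + (1 - lam) * 0.2 * 0.25
    assert numerator / denominator == pytest.approx(correct_list[0])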

    ################################################
    # Test new probabilities correctly calculated
    ################################################

    run_maximisation_step(df_e, params, spark)

    new_lambda = params.current_settings_obj["proportion_of_matches"]

    # See https://github.com/moj-analytical-services/splink/blob/master/tests/expectation_maximisation_test_answers.xlsx
    # for derivation of these numbers
    assert new_lambda == pytest.approx(0.540922141)
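
    # The maximisation step sets the new lambda to the mean of the match
    # probabilities from the expectation step, which we can check directly:
    assert new_lambda == pytest.approx(sum(correct_list) / len(correct_list))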

    rows = [
        ["mob", 0, 0.087438272, 0.441543191],
        ["mob", 1, 0.912561728, 0.558456809],
        ["surname", 0, 0.173315146, 0.340356209],
        ["surname", 1, 0.326240275, 0.160167628],
        ["surname", 2, 0.500444578, 0.499476163],
    ]

    settings_obj = params.current_settings_obj

    for r in rows:
        cc = settings_obj.get_comparison_column(r[0])
        level_dict = cc.level_as_dict(r[1])
        assert level_dict["m_probability"] == pytest.approx(r[2])
        assert level_dict["u_probability"] == pytest.approx(r[3])

    ################################################
    # Test revised probabilities correctly used
    ################################################

    df_e = run_expectation_step(df_gammas, params, spark)
    df_e = df_e.sort("unique_id_l", "unique_id_r")
    result_list = list(df_e.toPandas()["match_probability"])

    correct_list = [
        0.658602114,
        0.796821727,
        0.796821727,
        0.189486495,
        0.189486495,
        0.658602114,
        0.495063367,
        0.495063367,
    ]
    assert result_list == pytest.approx(correct_list)

    run_maximisation_step(df_e, params, spark)
    new_lambda = params.current_settings_obj["proportion_of_matches"]
    assert new_lambda == pytest.approx(0.534993426)

    rows = [
        ["mob", 0, 0.088546179, 0.435753788],
        ["mob", 1, 0.911453821, 0.564246212],
        ["surname", 0, 0.231340865, 0.27146747],
        ["surname", 1, 0.372351177, 0.109234086],
        ["surname", 2, 0.396307958, 0.619298443],
    ]

    settings_obj = params.current_settings_obj

    for r in rows:
        cc = settings_obj.get_comparison_column(r[0])
        level_dict = cc.level_as_dict(r[1])
        assert level_dict["m_probability"] == pytest.approx(r[2])
        assert level_dict["u_probability"] == pytest.approx(r[3])

    ################################################
    # Test whether saving and loading params works
    # (If we load params, does the expectation step yield same answer)
    ################################################
    import tempfile

    tmp_dir = tempfile.TemporaryDirectory()
    fname = os.path.join(tmp_dir.name, "params.json")

    df_e = run_expectation_step(df_gammas, params, spark)
    params.save_model_to_json_file(fname)

    from splink.model import load_model_from_json

    p = load_model_from_json(fname)

    df_e_2 = run_expectation_step(df_gammas, p, spark)

    assert_frame_equal(df_e.toPandas(), df_e_2.toPandas())
def test_term_frequency_adjustments(spark):

    # The strategy is to create a fake dataframe
    # with different match levels to model frequency imbalance:
    # gamma=3 is where name matches and name is robin (unusual name)
    # gamma=2 is where name matches and name is matt (normal name)
    # gamma=1 is where name matches and name is john (very common name)

    # We simulate the term frequency imbalance
    # by pooling this together, setting all gamma > 0
    # to equal 1

    # We then expect that
    # term frequency adjustments should adjust up the
    # robins but adjust down the johns

    # We also expect the tf-adjusted match probability to be more accurate

    forename_probs = _probabilities_from_freqs([3, 2, 1])
    surname_probs = _probabilities_from_freqs([10, 5, 1])

    settings_true = {
        "link_type":
        "dedupe_only",
        "proportion_of_matches":
        0.5,
        "comparison_columns": [
            {
                "col_name": "forename",
                "term_frequency_adjustments": True,
                "m_probabilities": forename_probs["m_probabilities"],
                "u_probabilities": forename_probs["u_probabilities"],
                "num_levels": 4,
            },
            {
                "col_name": "surname",
                "term_frequency_adjustments": True,
                "m_probabilities": surname_probs["m_probabilities"],
                "u_probabilities": surname_probs["u_probabilities"],
                "num_levels": 4,
            },
            {
                "col_name": "cat_20",
                "m_probabilities": [0.2, 0.8],
                "u_probabilities": [19 / 20, 1 / 20],
            },
        ],
    }

    settings_true = complete_settings_dict(settings_true, spark)

    df = generate_df_gammas_random(10000, settings_true)

    # Create new binary columns that binarise the more granular gammas to 0 and 1
    df["gamma_forename_binary"] = df["gamma_forename"].where(
        df["gamma_forename"] == 0, 1)

    df["gamma_surname_binary"] = df["gamma_surname"].where(
        df["gamma_surname"] == 0, 1)

    # Populate non matches with random values,
    # then assign left and right values based on the gamma values
    df["forename_binary_l"] = df["unique_id_l"]
    df["forename_binary_r"] = df["unique_id_r"]

    f1 = df["gamma_forename"] == 3
    df.loc[f1, "forename_binary_l"] = "Robin"
    df.loc[f1, "forename_binary_r"] = "Robin"

    f1 = df["gamma_forename"] == 2
    df.loc[f1, "forename_binary_l"] = "Matt"
    df.loc[f1, "forename_binary_r"] = "Matt"

    f1 = df["gamma_forename"] == 1
    df.loc[f1, "forename_binary_l"] = "John"
    df.loc[f1, "forename_binary_r"] = "John"

    # Populate non matches with random value
    df["surname_binary_l"] = df["unique_id_l"]
    df["surname_binary_r"] = df["unique_id_r"]

    f1 = df["gamma_surname"] == 3
    df.loc[f1, "surname_binary_l"] = "Linacre"
    df.loc[f1, "surname_binary_r"] = "Linacre"

    f1 = df["gamma_surname"] == 2
    df.loc[f1, "surname_binary_l"] = "Hughes"
    df.loc[f1, "surname_binary_r"] = "Hughes"

    f1 = df["gamma_surname"] == 1
    df.loc[f1, "surname_binary_l"] = "Smith"
    df.loc[f1, "surname_binary_r"] = "Smith"

    # cat20
    df["cat_20_l"] = df["unique_id_l"]
    df["cat_20_r"] = df["unique_id_r"]

    f1 = df["gamma_cat_20"] == 1
    df.loc[f1, "cat_20_l"] = "a"
    df.loc[f1, "cat_20_r"] = "a"

    df = add_match_prob(df, settings_true)
    df["match_probability"] = df["true_match_probability_l"]

    df_e = spark.createDataFrame(df)

    def four_to_two(probs):
        return [probs[0], sum(probs[1:])]
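
    # For example, four_to_two([0.5, 0.25, 0.15, 0.1]) == [0.5, 0.5]: the
    # non-match level keeps its probability and the match levels are pooled.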

    settings_binary = {
        "link_type":
        "dedupe_only",
        "proportion_of_matches":
        0.5,
        "comparison_columns": [
            {
                "col_name": "forename_binary",
                "term_frequency_adjustments": True,
                "num_levels": 2,
                "m_probabilities":
                four_to_two(forename_probs["m_probabilities"]),
                "u_probabilities":
                four_to_two(forename_probs["u_probabilities"]),
            },
            {
                "col_name": "surname_binary",
                "term_frequency_adjustments": True,
                "num_levels": 2,
                "m_probabilities":
                four_to_two(surname_probs["m_probabilities"]),
                "u_probabilities":
                four_to_two(surname_probs["u_probabilities"]),
            },
            {
                "col_name": "cat_20",
                "m_probabilities": [0.2, 0.8],
                "u_probabilities": [19 / 20, 1 / 20],
            },
        ],
        "retain_intermediate_calculation_columns":
        True,
        "max_iterations":
        0,
        "additional_columns_to_retain": ["true_match_probability"],
    }

    # Can't use linker = Splink() because we have df_gammas, not df
    settings_binary = complete_settings_dict(settings_binary, spark)
    model = Model(settings_binary, spark)
    df_e = iterate(df_e, model, spark)

    df_e = make_adjustment_for_term_frequencies(df_e,
                                                model,
                                                spark,
                                                retain_adjustment_columns=True)

    df = df_e.toPandas()

    #########
    # Tests start here
    #########

    # Test that overall square error is better for tf adjusted match prob
    df["e1"] = (df["match_probability"] - df["true_match_probability_l"])**2
    df["e2"] = (df["tf_adjusted_match_prob"] -
                df["true_match_probability_l"])**2
    assert df["e1"].sum() > df["e2"].sum()

    # We expect Johns to be adjusted down...
    f1 = df["forename_binary_l"] == "John"
    df_filtered = df[f1]
    adj = df_filtered["forename_binary_tf_adj"].mean()
    assert adj < 0.5

    # And Robins to be adjusted up
    f1 = df["forename_binary_l"] == "Robin"
    df_filtered = df[f1]
    adj = df_filtered["forename_binary_tf_adj"].mean()
    assert adj > 0.5

    # We expect Smiths to be adjusted down...
    f1 = df["surname_binary_l"] == "Smith"
    df_filtered = df[f1]
    adj = df_filtered["surname_binary_tf_adj"].mean()
    assert adj < 0.5

    # And Linacres to be adjusted up
    f1 = df["surname_binary_l"] == "Linacre"
    df_filtered = df[f1]
    adj = df_filtered["surname_binary_tf_adj"].mean()
    assert adj > 0.5

    # Check adjustments are applied correctly

    f1 = df["forename_binary_l"] == "Robin"
    f2 = df["surname_binary_l"] == "Linacre"
    df_filtered = df[f1 & f2]
    row = df_filtered.head(1).to_dict(orient="records")[0]

    prior = row["match_probability"]
    posterior = row["tf_adjusted_match_prob"]

    b1 = row["forename_binary_tf_adj"]
    b2 = row["surname_binary_tf_adj"]

    expected_post = (
        prior * b1 * b2 / (prior * b1 * b2 + (1 - prior) * (1 - b1) * (1 - b2))
    )
    assert posterior == pytest.approx(expected_post)
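
    # Sanity check on the combination formula: an adjustment of 0.5 is neutral
    # and leaves the prior unchanged.
    neutral = prior * 0.5 * 0.5 / (prior * 0.5 * 0.5 + (1 - prior) * 0.5 * 0.5)
    assert neutral == pytest.approx(prior)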

    #  We expect match probability to be equal to tf_adjusted match probability in cases where surname and forename don't match
    f1 = df["surname_binary_l"] != df["surname_binary_r"]
    f2 = df["forename_binary_l"] != df["forename_binary_r"]

    df_filtered = df[f1 & f2]
    sum_difference = (df_filtered["tf_adjusted_match_prob"] -
                      df_filtered["match_probability"]).sum()

    assert 0 == pytest.approx(sum_difference)