Code example #1
def params_4(gamma_settings_4):

    # Probability columns
    params = Params(gamma_settings_4, spark="supress_warnings")

    params._generate_param_dict()
    yield params
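Note: as scraped here the decorator has been stripped. In the splink test suite a generator function like this is presumably registered as a pytest fixture and injected into tests by argument name; a minimal sketch under that assumption (the `@pytest.fixture` decorator, the import and the test name are not part of the original snippet, and `gamma_settings_4` is assumed to be another fixture defined elsewhere):

import pytest
from splink.params import Params

@pytest.fixture()
def params_4(gamma_settings_4):
    # gamma_settings_4 is assumed to be a fixture providing a settings dictionary
    params = Params(gamma_settings_4, spark="supress_warnings")
    params._generate_param_dict()
    yield params

def test_uses_params_4(params_4):
    # pytest injects the yielded Params object by matching the argument name
    assert params_4.settings is not None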
Code example #2
def test_tiny_numbers(spark, sqlite_con_1):

    # Regression test, see https://github.com/moj-analytical-services/splink/issues/48

    dfpd = pd.read_sql("select * from test1", sqlite_con_1)
    df = spark.createDataFrame(dfpd)

    settings = {
        "link_type": "dedupe_only",
        "proportion_of_matches": 0.4,
        "comparison_columns": [
            {
                "col_name": "mob",
                "num_levels": 2,
                "m_probabilities": [5.9380419956766985e-25, 1 - 5.9380419956766985e-25],
                "u_probabilities": [0.8, 0.2],
            },
            {"col_name": "surname", "num_levels": 2,},
        ],
        "blocking_rules": ["l.mob = r.mob", "l.surname = r.surname",],
    }

    settings = complete_settings_dict(settings, spark=None)

    df_comparison = block_using_rules(settings, df=df, spark=spark)

    df_gammas = add_gammas(df_comparison, settings, spark)
    params = Params(settings, spark="supress_warnings")

    # The test passes if the expectation step completes without raising an error
    df_e = run_expectation_step(df_gammas, params, settings, spark)
Code example #3
    def __init__(self,
                 settings: dict,
                 spark: SparkSession,
                 df_l: DataFrame = None,
                 df_r: DataFrame = None,
                 df: DataFrame = None,
                 save_state_fn: Callable = None,
                 break_lineage_blocked_comparisons:
                 Callable = default_break_lineage_blocked_comparisons,
                 break_lineage_scored_comparisons:
                 Callable = default_break_lineage_scored_comparisons):
        """splink data linker

        Provides easy access to the core user-facing functionality of splink

        Args:
            settings (dict): splink settings dictionary
            spark (SparkSession): SparkSession object
            df_l (DataFrame, optional): A dataframe to link/dedupe. Where `link_type` is `link_only` or `link_and_dedupe`, one of the two dataframes to link. Should be omitted where `link_type` is `dedupe_only`.
            df_r (DataFrame, optional): A dataframe to link/dedupe. Where `link_type` is `link_only` or `link_and_dedupe`, one of the two dataframes to link. Should be omitted where `link_type` is `dedupe_only`.
            df (DataFrame, optional): The dataframe to dedupe where `link_type` is `dedupe_only`. Should be omitted where `link_type` is `link_only` or `link_and_dedupe`.
            save_state_fn (function, optional):  A function provided by the user that takes two arguments, params and settings, and is executed each iteration.  This is a hook that allows the user to save the state between iterations, which is mostly useful for very large jobs which may need to be restarted from where they left off if they fail.
            break_lineage_blocked_comparisons (function, optional): Large jobs will likely run into memory errors unless the lineage is broken after blocking.  This is a user-provided function that takes the dataframe and the SparkSession as its arguments and allows the user to break lineage.  For example, the function might save df to AWS S3 and then reload it from the saved files.
            break_lineage_scored_comparisons (function, optional): Large jobs will likely run into memory errors unless the lineage is broken after comparisons are scored and before term frequency adjustments.  This is a user-provided function that takes the dataframe and the SparkSession as its arguments and allows the user to break lineage.  For example, the function might save df to AWS S3 and then reload it from the saved files.
        """

        self.spark = spark
        self.break_lineage_blocked_comparisons = break_lineage_blocked_comparisons
        self.break_lineage_scored_comparisons = break_lineage_scored_comparisons
        _check_jaro_registered(spark)

        settings = complete_settings_dict(settings, spark)
        validate_settings(settings)
        self.settings = settings

        self.params = Params(settings, spark)

        self.df_r = df_r
        self.df_l = df_l
        self.df = df
        self.save_state_fn = save_state_fn
        self._check_args()
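The break-lineage hooks above are described only in prose. As a rough sketch (not splink's built-in defaults), a user-supplied function might write the dataframe out and read it back, which truncates the accumulated Spark query plan; the function name and storage path below are hypothetical, and the two-argument signature follows the call sites shown in code example #9:

from pyspark.sql import DataFrame, SparkSession

def break_lineage_via_parquet(df: DataFrame, spark: SparkSession) -> DataFrame:
    # Hypothetical example: persist the intermediate result and re-read it, so that
    # downstream stages start from files rather than from the full upstream query plan
    path = "s3a://my-bucket/splink-checkpoints/blocked-comparisons"  # illustrative path
    df.write.mode("overwrite").parquet(path)
    return spark.read.parquet(path)

# linker = Splink(settings, spark, df=my_df,
#                 break_lineage_blocked_comparisons=break_lineage_via_parquet)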
Code example #4
def test_update_settings():

    old_settings = {
        "link_type": "dedupe_only",
        "proportion_of_matches": 0.2,
        "comparison_columns": [
            {"col_name": "fname"},
            {"col_name": "sname", "num_levels": 3},
        ],
        "blocking_rules": [],
    }

    params = Params(old_settings, spark="supress_warnings")

    new_settings = {
        "link_type": "dedupe_only",
        "blocking_rules": [],
        "comparison_columns": [
            {
                "col_name": "fname",
                "num_levels": 3,
                "m_probabilities": [0.02, 0.03, 0.95],
                "u_probabilities": [0.92, 0.05, 0.03],
            },
            {
                "custom_name": "sname",
                "custom_columns_used": ["fname", "sname"],
                "num_levels": 3,
                "case_expression": """
                    case when concat(fname_l, sname_l) = concat(fname_r, sname_r) then 1
                    else 0 end
                """,
                "m_probabilities": [0.01, 0.02, 0.97],
                "u_probabilities": [0.9, 0.05, 0.05],
            },
            {"col_name": "dob"},
        ],
    }

    update = get_or_update_settings(params, new_settings)

    # New settings are used because num_levels no longer matches the old column
    assert (
        update["comparison_columns"][0]["m_probabilities"]
        == new_settings["comparison_columns"][0]["m_probabilities"]
    )
    # Where the column definition still matches, values from the old settings are carried over
    assert update["comparison_columns"][1]["u_probabilities"] == pytest.approx(
        params.settings["comparison_columns"][1]["u_probabilities"]
    )
Code example #5
    def __init__(
        self,
        settings: dict,
        spark: SparkSession,
        df_l: DataFrame = None,
        df_r: DataFrame = None,
        df: DataFrame = None,
        save_state_fn: Callable = None,
    ):
        """splink data linker

        Provides easy access to the core user-facing functionality of splink

        Args:
            settings (dict): splink settings dictionary
            spark (SparkSession): SparkSession object
            df_l (DataFrame, optional): A dataframe to link/dedupe. Where `link_type` is `link_only` or `link_and_dedupe`, one of the two dataframes to link. Should be omitted where `link_type` is `dedupe_only`.
            df_r (DataFrame, optional): A dataframe to link/dedupe. Where `link_type` is `link_only` or `link_and_dedupe`, one of the two dataframes to link. Should be omitted where `link_type` is `dedupe_only`.
            df (DataFrame, optional): The dataframe to dedupe where `link_type` is `dedupe_only`. Should be omitted where `link_type` is `link_only` or `link_and_dedupe`.
            save_state_fn (function, optional):  A function provided by the user that takes two arguments, params and settings, and is executed each iteration.  This is a hook that allows the user to save the state between iterations, which is mostly useful for very large jobs which may need to be restarted from where they left off if they fail.

        """

        self.spark = spark
        _check_jaro_registered(spark)

        settings = complete_settings_dict(settings, spark)
        validate_settings(settings)
        self.settings = settings

        self.params = Params(settings, spark)

        self.df_r = df_r
        self.df_l = df_l
        self.df = df
        self.save_state_fn = save_state_fn
        self._check_args()
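This variant has no break-lineage arguments, but it keeps the save_state_fn hook. A minimal sketch of such a callback, assuming the params object it receives exposes save_params_to_json_file as used by Splink.save_model_as_json in code example #9 (the callback name and output path are illustrative):

def save_state(params, settings):
    # Called once per EM iteration; persist enough state to resume a failed job
    params.save_params_to_json_file("/tmp/splink_checkpoint.json", overwrite=True)

# linker = Splink(settings, spark, df=my_df, save_state_fn=save_state)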
Code example #6
def param_example():
    gamma_settings = {
        "link_type": "dedupe_only",
        "proportion_of_matches": 0.2,
        "comparison_columns": [
            {"col_name": "fname"},
            {"col_name": "sname", "num_levels": 3},
        ],
        "blocking_rules": [],
    }

    params = Params(gamma_settings, spark="supress_warnings")

    yield params
Code example #7
def test_probability_columns(sqlite_con_1, gamma_settings_1):

    params = Params(gamma_settings_1, spark="supress_warnings")

    sql = _sql_gen_gamma_prob_columns(params, gamma_settings_1, "df_gammas1")
    df = pd.read_sql(sql, sqlite_con_1)

    cols_to_keep = [
        "prob_gamma_mob_match", "prob_gamma_mob_non_match",
        "prob_gamma_surname_match", "prob_gamma_surname_non_match"
    ]
    pd_df_result = df[cols_to_keep][:4]

    df_correct = [{
        "prob_gamma_mob_match": 0.9,
        "prob_gamma_mob_non_match": 0.2,
        "prob_gamma_surname_match": 0.7,
        "prob_gamma_surname_non_match": 0.25
    }, {
        "prob_gamma_mob_match": 0.9,
        "prob_gamma_mob_non_match": 0.2,
        "prob_gamma_surname_match": 0.2,
        "prob_gamma_surname_non_match": 0.25
    }, {
        "prob_gamma_mob_match": 0.9,
        "prob_gamma_mob_non_match": 0.2,
        "prob_gamma_surname_match": 0.2,
        "prob_gamma_surname_non_match": 0.25
    }, {
        "prob_gamma_mob_match": 0.1,
        "prob_gamma_mob_non_match": 0.8,
        "prob_gamma_surname_match": 0.7,
        "prob_gamma_surname_non_match": 0.25
    }]

    pd_df_correct = pd.DataFrame(df_correct)

    assert_frame_equal(pd_df_correct, pd_df_result)
Code example #8
File: test_spark.py  Project: sdmurff/splink
def test_term_frequency_adjustments(spark):

    settings = {
        "link_type": "dedupe_only",
        "proportion_of_matches": 0.1,
        "comparison_columns": [
            {
                "col_name": "name",
                "term_frequency_adjustments": True,
                "m_probabilities": [
                    0.1,  # Amongst matches, 10% have typos
                    0.9,  # The remaining 90% agree on name
                ],
                "u_probabilities": [
                    4 / 5,  # Amongst non-matches, 80% of the time there's no match on name
                    1 / 5,  # But 20% of the time names 'collide'.  We want these u probabilities to depend on the name itself.
                ],
            },
            {
                "col_name": "cat_12",
                "m_probabilities": [0.05, 0.95],
                "u_probabilities": [11 / 12, 1 / 12],
            },
            {
                "col_name": "cat_20",
                "m_probabilities": [0.2, 0.8],
                "u_probabilities": [19 / 20, 1 / 20],
            },
        ],
        "em_convergence": 0.001,
    }

    from string import ascii_letters
    import statistics
    import random
    from splink.settings import complete_settings_dict
    settings = complete_settings_dict(settings, spark="supress_warnings")

    def is_match(settings):
        p = settings["proportion_of_matches"]
        return random.choices([0, 1], [1 - p, p])[0]

    def get_row_portion(match, comparison_col, skew="auto"):
        # The problem is that, at the moment, we guarantee that a match on 'john' is just as likely as a match on 'james'

        # What we want is to generate more 'collisions' for 'john' than 'robin', i.e. if it's a non-match, we want more gamma = 1 on name for 'john'

        if match:
            gamma_pdist = comparison_col["m_probabilities"]
        else:
            gamma_pdist = comparison_col["u_probabilities"]

        # To decide whether gamma = 0 or 1 in the case of skew, we first need to decide what value the left-hand column will take (or rather, what probability it has of being selected)

        # How many distinct values should we choose?
        num_values = int(round(1 / comparison_col["u_probabilities"][1]))

        if skew == "auto":
            skew = comparison_col["term_frequency_adjustments"]

        if skew:

            # 'a' is most frequent, the last value least frequent
            prob_dist = range(1, num_values + 1)[::-1]
            # Normalise
            prob_dist = [p / sum(prob_dist) for p in prob_dist]

            index_of_value = random.choices(range(num_values), prob_dist)[0]
            if not match:  # If it's a u probability
                this_prob = prob_dist[index_of_value]
                gamma_pdist = [1 - this_prob, this_prob]

        else:
            prob_dist = [1 / num_values] * num_values
            index_of_value = random.choices(range(num_values), prob_dist)[0]

        levels = comparison_col["num_levels"]
        gamma = random.choices(range(levels), gamma_pdist)[0]

        values = ascii_letters[:26]
        if num_values > 26:
            values = [a + b for a in ascii_letters[:26] for b in ascii_letters[:26]]  # aa, ab, etc.

        values = values[:num_values]

        if gamma == 1:
            value_1 = values[index_of_value]
            value_2 = value_1

        if gamma == 0:
            value_1 = values[index_of_value]
            same_value = True
            while same_value:
                value_2 = random.choices(values, prob_dist)[0]
                if value_1 != value_2:
                    same_value = False

        cname = comparison_col["col_name"]
        return {
            f"{cname}_l": value_1,
            f"{cname}_r": value_2,
            f"gamma_{cname}": gamma
        }

    import uuid
    rows = []
    for uid in range(100000):
        m = is_match(settings)
        row = {
            "unique_id_l": str(uuid.uuid4()),
            "unique_id_r": str(uuid.uuid4()),
            "match": m
        }
        for cc in settings["comparison_columns"]:
            row_portion = get_row_portion(m, cc)
            row = {**row, **row_portion}
        rows.append(row)

    all_rows = pd.DataFrame(rows)
    df_gammas = spark.createDataFrame(all_rows)

    settings["comparison_columns"][1]["term_frequency_adjustments"] = True

    from splink import Splink
    from splink.params import Params
    from splink.iterate import iterate
    from splink.term_frequencies import make_adjustment_for_term_frequencies

    # We have table of gammas - need to work from there within splink
    params = Params(settings, spark)

    df_e = iterate(df_gammas, params, settings, spark, compute_ll=False)

    df_e_adj = make_adjustment_for_term_frequencies(
        df_e, params, settings, retain_adjustment_columns=True, spark=spark)

    df_e_adj.createOrReplaceTempView("df_e_adj")
    sql = """
    select name_l, name_tf_adj,  count(*)
    from df_e_adj
    where name_l = name_r
    group by name_l, name_tf_adj
    order by name_l
    """
    df = spark.sql(sql).toPandas()
    df = df.set_index("name_l")
    df_dict = df.to_dict(orient='index')
    assert df_dict['a']["name_tf_adj"] < 0.5

    assert df_dict['e']["name_tf_adj"] > 0.5
    assert df_dict['e'][
        "name_tf_adj"] > 0.6  #Arbitrary numbers, but we do expect a big uplift here
    assert df_dict['e'][
        "name_tf_adj"] < 0.95  #Arbitrary numbers, but we do expect a big uplift here

    df_e_adj.createOrReplaceTempView("df_e_adj")
    sql = """
    select cat_12_l, cat_12_tf_adj,  count(*) as count
    from df_e_adj
    where cat_12_l = cat_12_r
    group by cat_12_l, cat_12_tf_adj
    order by cat_12_l
    """
    df = spark.sql(sql).toPandas()
    # Keep these bounds loose because when generating random data anything can happen!
    assert df["cat_12_tf_adj"].max() < 0.55
    assert df["cat_12_tf_adj"].min() > 0.45

    # Test adjustments are applied correctly when there is a single adjustment
    df_e_adj.createOrReplaceTempView("df_e_adj")
    sql = """
    select *
    from df_e_adj
    where name_l = name_r and cat_12_l != cat_12_r
    limit 1
    """
    df = spark.sql(sql).toPandas()
    df_dict = df.loc[0, :].to_dict()

    def bayes(p1, p2):
        return p1 * p2 / (p1 * p2 + (1 - p1) * (1 - p2))

    assert df_dict["tf_adjusted_match_prob"] == pytest.approx(
        bayes(df_dict["match_probability"], df_dict["name_tf_adj"]))

    # Test adjustments are applied correctly when there are multiple adjustments
    df_e_adj.createOrReplaceTempView("df_e_adj")
    sql = """
    select *
    from df_e_adj
    where name_l = name_r and cat_12_l = cat_12_r
    limit 1
    """
    df = spark.sql(sql).toPandas()
    df_dict = df.loc[0, :].to_dict()

    double_b = bayes(
        bayes(df_dict["match_probability"], df_dict["name_tf_adj"]),
        df_dict["cat_12_tf_adj"])

    assert df_dict["tf_adjusted_match_prob"] == pytest.approx(double_b)
Code example #9
class Splink:
    @check_types
    def __init__(self,
                 settings: dict,
                 spark: SparkSession,
                 df_l: DataFrame = None,
                 df_r: DataFrame = None,
                 df: DataFrame = None,
                 save_state_fn: Callable = None,
                 break_lineage_blocked_comparisons:
                 Callable = default_break_lineage_blocked_comparisons,
                 break_lineage_scored_comparisons:
                 Callable = default_break_lineage_scored_comparisons):
        """splink data linker

        Provides easy access to the core user-facing functionality of splink

        Args:
            settings (dict): splink settings dictionary
            spark (SparkSession): SparkSession object
            df_l (DataFrame, optional): A dataframe to link/dedupe. Where `link_type` is `link_only` or `link_and_dedupe`, one of the two dataframes to link. Should be omitted where `link_type` is `dedupe_only`.
            df_r (DataFrame, optional): A dataframe to link/dedupe. Where `link_type` is `link_only` or `link_and_dedupe`, one of the two dataframes to link. Should be omitted where `link_type` is `dedupe_only`.
            df (DataFrame, optional): The dataframe to dedupe where `link_type` is `dedupe_only`. Should be omitted where `link_type` is `link_only` or `link_and_dedupe`.
            save_state_fn (function, optional):  A function provided by the user that takes two arguments, params and settings, and is executed each iteration.  This is a hook that allows the user to save the state between iterations, which is mostly useful for very large jobs which may need to be restarted from where they left off if they fail.
            break_lineage_blocked_comparisons (function, optional): Large jobs will likely run into memory errors unless the lineage is broken after blocking.  This is a user-provided function that takes the dataframe and the SparkSession as its arguments and allows the user to break lineage.  For example, the function might save df to AWS S3 and then reload it from the saved files.
            break_lineage_scored_comparisons (function, optional): Large jobs will likely run into memory errors unless the lineage is broken after comparisons are scored and before term frequency adjustments.  This is a user-provided function that takes the dataframe and the SparkSession as its arguments and allows the user to break lineage.  For example, the function might save df to AWS S3 and then reload it from the saved files.
        """

        self.spark = spark
        self.break_lineage_blocked_comparisons = break_lineage_blocked_comparisons
        self.break_lineage_scored_comparisons = break_lineage_scored_comparisons
        _check_jaro_registered(spark)

        settings = complete_settings_dict(settings, spark)
        validate_settings(settings)
        self.settings = settings

        self.params = Params(settings, spark)

        self.df_r = df_r
        self.df_l = df_l
        self.df = df
        self.save_state_fn = save_state_fn
        self._check_args()

    def _check_args(self):

        link_type = self.settings["link_type"]

        if link_type == "dedupe_only":
            check_1 = self.df_r is None
            check_2 = self.df_l is None
            check_3 = isinstance(self.df, DataFrame)

            if not all([check_1, check_2, check_3]):
                raise ValueError(
                    "For link_type = 'dedupe_only', you must pass a single Spark dataframe to Splink using the df argument. "
                    "The df_l and df_r arguments should be omitted or set to None. "
                    "e.g. linker = Splink(settings, spark, df=my_df)")

        if link_type in ["link_only", "link_and_dedupe"]:
            check_1 = isinstance(self.df_l, DataFrame)
            check_2 = isinstance(self.df_r, DataFrame)
            check_3 = self.df is None

            if not all([check_1, check_2, check_3]):
                raise ValueError(
                    f"For link_type = '{link_type}', you must pass two Spark dataframes to Splink using the df_l and df_r argument. "
                    "The df argument should be omitted or set to None. "
                    "e.g. linker = Splink(settings, spark, df_l=my_first_df, df_r=df_to_link_to_first_one)"
                )

    def _get_df_comparison(self):

        if self.settings["link_type"] == "dedupe_only":
            return block_using_rules(self.settings, self.spark, df=self.df)

        if self.settings["link_type"] in ("link_only", "link_and_dedupe"):
            return block_using_rules(self.settings,
                                     self.spark,
                                     df_l=self.df_l,
                                     df_r=self.df_r)

    def manually_apply_fellegi_sunter_weights(self):
        """Compute match probabilities from m and u probabilities specified in the splink settings object

        Returns:
            DataFrame: A spark dataframe including a match probability column
        """
        df_comparison = self._get_df_comparison()
        df_gammas = add_gammas(df_comparison, self.settings, self.spark)
        return run_expectation_step(df_gammas, self.params, self.settings,
                                    self.spark)

    def get_scored_comparisons(self):
        """Use the EM algorithm to estimate model parameters and return match probabilities.

        Note: Term frequency adjustments are also applied to the result (see `_make_term_frequency_adjustments` below).

        Returns:
            DataFrame: A spark dataframe including a match probability column
        """

        df_comparison = self._get_df_comparison()

        df_gammas = add_gammas(df_comparison, self.settings, self.spark)

        df_gammas = self.break_lineage_blocked_comparisons(
            df_gammas, self.spark)

        df_e = iterate(
            df_gammas,
            self.params,
            self.settings,
            self.spark,
            compute_ll=False,
            save_state_fn=self.save_state_fn,
        )

        # In case the user's break lineage function has persisted it
        df_gammas.unpersist()

        df_e = self.break_lineage_scored_comparisons(df_e, self.spark)

        df_e_adj = self._make_term_frequency_adjustments(df_e)

        df_e.unpersist()

        return df_e_adj

    def _make_term_frequency_adjustments(self, df_e: DataFrame):
        """Take the outputs of 'get_scored_comparisons' and make term frequency adjustments on designated columns in the settings dictionary

        Args:
            df_e (DataFrame): A dataframe produced by the get_scored_comparisons method

        Returns:
            DataFrame: A spark dataframe including a column with term frequency adjusted match probabilities
        """

        return make_adjustment_for_term_frequencies(
            df_e,
            self.params,
            self.settings,
            retain_adjustment_columns=True,
            spark=self.spark,
        )

    def save_model_as_json(self, path: str, overwrite=False):
        """Save model (settings, parameters and parameter history) as a json file so it can later be re-loaded using load_from_json

        Args:
            path (str): Path to the json file.
            overwrite (bool): Whether to overwrite the file if it exists
        """
        self.params.save_params_to_json_file(path, overwrite=overwrite)
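Putting the class above together, a typical dedupe run might look like the following sketch. The settings dictionary and input path are illustrative only; the constructor call and method names come from the class shown above:

from pyspark.sql import SparkSession
from splink import Splink

spark = SparkSession.builder.getOrCreate()

# Illustrative settings: a real dictionary would describe the comparison columns of your data
settings = {
    "link_type": "dedupe_only",
    "comparison_columns": [{"col_name": "surname"}, {"col_name": "mob"}],
    "blocking_rules": ["l.surname = r.surname"],
}

df = spark.read.parquet("path/to/records.parquet")  # hypothetical input data

linker = Splink(settings, spark, df=df)
df_e_adj = linker.get_scored_comparisons()  # blocks, runs EM and applies term frequency adjustments
linker.save_model_as_json("model.json", overwrite=True)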
Code example #10
def sqlite_con_1(gamma_settings_1, params_1):

    # Create the database and the database table
    con = sqlite3.connect(":memory:")
    con.row_factory = sqlite3.Row
    cur = con.cursor()
    cur.execute("create table test1 (unique_id, mob, surname)")
    cur.execute("insert into test1 values (?, ?, ?)", (1, 10, "Linacre"))
    cur.execute("insert into test1 values (?, ?, ?)", (2, 10, "Linacre"))
    cur.execute("insert into test1 values (?, ?, ?)", (3, 10, "Linacer"))
    cur.execute("insert into test1 values (?, ?, ?)", (4, 7, "Smith"))
    cur.execute("insert into test1 values (?, ?, ?)", (5, 8, "Smith"))
    cur.execute("insert into test1 values (?, ?, ?)", (6, 8, "Smith"))
    cur.execute("insert into test1 values (?, ?, ?)", (7, 8, "Jones"))

    # Create comparison table
    rules = [
        "l.mob = r.mob",
        "l.surname = r.surname",
    ]

    sql = "select * from test1 limit 1"
    cur.execute(sql)
    one = cur.fetchone()
    columns = one.keys()

    sql = _sql_gen_block_using_rules("dedupe_only",
                                     columns,
                                     rules,
                                     table_name_dedupe="test1")
    df = pd.read_sql(sql, con)
    df = df.drop_duplicates(["unique_id_l", "unique_id_r"])
    df = df.sort_values(["unique_id_l", "unique_id_r"])
    df.to_sql("df_comparison1", con, index=False)

    sql = _sql_gen_add_gammas(gamma_settings_1, table_name="df_comparison1")

    df = pd.read_sql(sql, con)
    df.to_sql("df_gammas1", con, index=False)

    sql = _sql_gen_gamma_prob_columns(params_1, gamma_settings_1, "df_gammas1")
    df = pd.read_sql(sql, con)
    df.to_sql("df_with_gamma_probs1", con, index=False)

    sql = _sql_gen_expected_match_prob(params_1, gamma_settings_1,
                                       "df_with_gamma_probs1")
    df = pd.read_sql(sql, con)
    df.to_sql("df_with_match_probability1", con, index=False)

    sql = _sql_gen_intermediate_pi_aggregate(
        params_1, table_name="df_with_match_probability1")
    df = pd.read_sql(sql, con)
    df.to_sql("df_intermediate1", con, index=False)

    sql = _sql_gen_pi_df(params_1, "df_intermediate1")

    df = pd.read_sql(sql, con)
    df.to_sql("df_pi1", con, index=False)

    # Create a new parameters object and run everything again for a second iteration
    # Probability columns
    gamma_settings_it_2 = copy.deepcopy(gamma_settings_1)
    gamma_settings_it_2["proportion_of_matches"] = 0.540922141
    gamma_settings_it_2["comparison_columns"][0]["m_probabilities"] = [
        0.087438272, 0.912561728
    ]
    gamma_settings_it_2["comparison_columns"][0]["u_probabilities"] = [
        0.441543191, 0.558456809
    ]
    gamma_settings_it_2["comparison_columns"][1]["m_probabilities"] = [
        0.173315146,
        0.326240275,
        0.500444578,
    ]
    gamma_settings_it_2["comparison_columns"][1]["u_probabilities"] = [
        0.340356209,
        0.160167628,
        0.499476163,
    ]

    params2 = Params(gamma_settings_it_2, spark="supress_warnings")

    params2._generate_param_dict()

    sql = _sql_gen_gamma_prob_columns(params2, gamma_settings_it_2,
                                      "df_gammas1")
    df = pd.read_sql(sql, con)
    df.to_sql("df_with_gamma_probs1_it2", con, index=False)

    sql = _sql_gen_expected_match_prob(params2, gamma_settings_it_2,
                                       "df_with_gamma_probs1_it2")
    df = pd.read_sql(sql, con)
    df.to_sql("df_with_match_probability1_it2", con, index=False)

    sql = _sql_gen_intermediate_pi_aggregate(
        params2, table_name="df_with_match_probability1_it2")
    df = pd.read_sql(sql, con)
    df.to_sql("df_intermediate1_it2", con, index=False)

    sql = _sql_gen_pi_df(params2, "df_intermediate1_it2")

    df = pd.read_sql(sql, con)
    df.to_sql("df_pi1_it2", con, index=False)

    yield con
Code example #11
def params_1(gamma_settings_1):

    # Probability columns
    params = Params(gamma_settings_1, spark="supress_warnings")
    yield params