Example #1
def add_gammas(
    df_comparison: DataFrame,
    settings_dict: dict,
    spark: SparkSession,
    unique_id_col: str = "unique_id",
):
    """ Compute the comparison vectors and add them to the dataframe.  See
    https://imai.fas.harvard.edu/research/files/linkage.pdf for more details of what is meant by comparison vectors

    Args:
        df_comparison (spark dataframe): A Spark dataframe containing record comparisons, with records compared using the convention col_name_l, col_name_r
        settings_dict (dict): The `splink` settings dictionary
        spark (Spark session): The Spark session object
        unique_id_col (str, optional): Name of the unique id column. Defaults to "unique_id".

    Returns:
        Spark dataframe: A dataframe containing new columns representing the gammas of the model
    """


    settings_dict = complete_settings_dict(settings_dict, spark)

    sql = _sql_gen_add_gammas(
        settings_dict,
        unique_id_col=unique_id_col,
    )

    logger.debug(_format_sql(sql))
    df_comparison.createOrReplaceTempView("df_comparison")
    df_gammas = spark.sql(sql)

    return df_gammas
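A minimal usage sketch, assuming the function above is in scope, `spark` is an active SparkSession, and `df_comparison` is a blocked dataframe with `_l`/`_r` suffixed columns (for example produced by `block_using_rules` later in this collection). The settings values are illustrative only.

# Sketch only: column names and settings values are hypothetical.
settings = {
    "link_type": "dedupe_only",
    "comparison_columns": [
        {"col_name": "first_name"},
        {"col_name": "surname"},
    ],
}

df_gammas = add_gammas(df_comparison, settings, spark)
df_gammas.show(5)  # new columns hold the comparison vector (gamma) levels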
Example #2
def run_maximisation_step(df_e: DataFrame, params: Params,
                          spark: SparkSession):
    """Compute new parameters and save them in the params object

    Note that the params object will be updated in-place by this function

    Args:
        df_e (DataFrame): the result of the expectation step
        params (Params): splink Params object
        spark (SparkSession): The spark session
    """

    sql = _sql_gen_intermediate_pi_aggregate(params)

    df_e.createOrReplaceTempView("df_e")
    df_intermediate = spark.sql(sql)
    logger.debug(_format_sql(sql))
    df_intermediate.createOrReplaceTempView("df_intermediate")
    df_intermediate.persist()

    new_lambda = _get_new_lambda(df_intermediate, spark)
    pi_df_collected = _get_new_pi_df(df_intermediate, spark, params)

    params._update_params(new_lambda, pi_df_collected)
    df_intermediate.unpersist()
Example #3
    def _transform(self, df: DataFrame) -> DataFrame:
        sql_text: Optional[str] = self.getSql()
        name: Optional[str] = self.getName()
        view: Optional[str] = self.getView()
        progress_logger: Optional[ProgressLogger] = self.getProgressLogger()

        assert sql_text
        with ProgressLogMetric(name=name or view or "",
                               progress_logger=progress_logger):
            if progress_logger and name:
                # mlflow opens .txt files inline so we use that extension
                progress_logger.log_artifact(key=f"{name}.sql.txt",
                                             contents=sql_text)
                progress_logger.write_to_log(name=name, message=sql_text)
            try:
                df = df.sql_ctx.sql(sql_text)
            except Exception:
                self.logger.info(f"Error in {name}")
                self.logger.info(sql_text)
                raise

            if view:
                df.createOrReplaceTempView(view)
            self.logger.info(
                f"GenericSqlTransformer [{name}] finished running SQL")

        return df
Example #4
def run_maximisation_step(df_e: DataFrame, model: Model, spark: SparkSession):
    """Compute new parameters and save them in the model object

    Note that the model object will be updated in-place by this function

    Args:
        df_e (DataFrame): the result of the expectation step
        model (Model): splink Model object
        spark (SparkSession): The spark session
    """

    sql = _sql_gen_intermediate_pi_aggregate(model)

    df_e.createOrReplaceTempView("df_e")
    df_intermediate = spark.sql(sql)
    logger.debug(_format_sql(sql))
    df_intermediate.createOrReplaceTempView("df_intermediate")
    df_intermediate.persist()

    new_lambda = _get_new_lambda(df_intermediate, spark)
    pi_df_collected = _get_new_pi_df(df_intermediate, spark, model)

    model._populate_model_from_maximisation_step(new_lambda, pi_df_collected)
    model.iteration += 1
    df_intermediate.unpersist()
Example #5
def block_using_rules(settings: dict, df: DataFrame, spark: SparkSession):
    """Apply a series of blocking rules to create a dataframe of record comparisons. If no blocking rules provided, performs a cartesian join.

    Args:
        settings (dict): A splink settings dictionary
        df (DataFrame): Spark dataframe to block - if linking multiple datasets, assumes dataframes have already been vertically concatenated
        spark (SparkSession): The pyspark.sql.session.SparkSession

    Returns:
        pyspark.sql.dataframe.DataFrame: A dataframe of each record comparison
    """
    df.createOrReplaceTempView("df")
    columns_to_retain = _get_columns_to_retain_blocking(settings, df)
    unique_id_col = settings["unique_id_column_name"]
    if settings["link_type"] == "dedupe_only":
        source_dataset_col = None
    else:
        source_dataset_col = settings["source_dataset_column_name"]
    link_type = settings["link_type"]

    if "blocking_rules" not in settings or len(
            settings["blocking_rules"]) == 0:
        sql = _sql_gen_cartesian_block(link_type, columns_to_retain,
                                       unique_id_col, source_dataset_col)
    else:
        rules = settings["blocking_rules"]
        sql = _sql_gen_block_using_rules(link_type, columns_to_retain, rules,
                                         unique_id_col, source_dataset_col)

    logger.debug(_format_sql(sql))

    df_comparison = spark.sql(sql)

    return df_comparison
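A hedged usage sketch for this version of `block_using_rules`: the settings keys shown are the ones the function reads above, the values are illustrative, and `df` is assumed to be a Spark dataframe of the records to dedupe.

# Sketch only: settings values and column names are illustrative.
settings = {
    "link_type": "dedupe_only",
    "unique_id_column_name": "unique_id",
    "blocking_rules": [
        "l.surname = r.surname",
        "l.dob = r.dob",
    ],
    "comparison_columns": [{"col_name": "first_name"}, {"col_name": "surname"}],
}

df_comparison = block_using_rules(settings, df, spark)
# With an empty "blocking_rules" list the same call would fall back to a cartesian join.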
Example #6
def make_adjustment_for_term_frequencies(
    df_e: DataFrame,
    model: Model,
    spark: SparkSession,
    retain_adjustment_columns: bool = False,
):

    # Running a maximisation step will eliminate errors caused by global parameters
    # being used in blocked jobs

    settings = model.current_settings_obj.settings_dict

    term_freq_column_list = [
        cc.name for cc in model.current_settings_obj.comparison_columns_list
        if cc["term_frequency_adjustments"] is True
    ]

    if len(term_freq_column_list) == 0:
        return df_e

    retain_source_dataset_col = _retain_source_dataset_column(settings, df_e)
    df_e.createOrReplaceTempView("df_e")

    old_settings = deepcopy(model.current_settings_obj.settings_dict)

    for cc in model.current_settings_obj.comparison_columns_list:
        cc.column_dict["fix_m_probabilities"] = False
        cc.column_dict["fix_u_probabilities"] = False

    run_maximisation_step(df_e, model, spark)

    # Generate a lookup table for each column with 'term specific' lambdas.
    for c in term_freq_column_list:
        sql = sql_gen_generate_adjusted_lambda(c, model)
        logger.debug(_format_sql(sql))
        lookup = spark.sql(sql)
        lookup.persist()
        lookup.createOrReplaceTempView(f"{c}_lookup")

    # Merge these lookup tables into main table
    sql = sql_gen_add_adjumentments_to_df_e(term_freq_column_list)
    logger.debug(_format_sql(sql))
    df_e_adj = spark.sql(sql)
    df_e_adj.createOrReplaceTempView("df_e_adj")

    sql = sql_gen_compute_final_group_membership_prob_from_adjustments(
        term_freq_column_list, settings, retain_source_dataset_col)
    logger.debug(_format_sql(sql))
    df = spark.sql(sql)
    if not retain_adjustment_columns:
        for c in term_freq_column_list:
            df = df.drop(c + "_tf_adj")

    # Restore original settings
    model.current_settings_obj.settings_dict = old_settings

    return df
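A short usage sketch, assuming `df_e` is the scored output of an expectation step and `model` is a splink `Model` whose settings flag at least one column with `term_frequency_adjustments`; retaining the adjustment columns is optional and only useful for inspection.

# Sketch only: df_e, model and spark are assumed to exist as in the surrounding snippets.
df_e_adj = make_adjustment_for_term_frequencies(
    df_e, model, spark, retain_adjustment_columns=True
)
# The retained *_tf_adj columns show how much each term-frequency adjustment
# moved the final match probability.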
Example #7
    def registerDataFrameAsTable(self, df: DataFrame, tableName: str) -> None:
        """Registers the given :class:`DataFrame` as a temporary table in the catalog.

        Temporary tables exist only during the lifetime of this instance of :class:`SQLContext`.

        .. versionadded:: 1.3.0

        Examples
        --------
        >>> sqlContext.registerDataFrameAsTable(df, "table1")
        """
        df.createOrReplaceTempView(tableName)
Example #8
def null_out_entries_with_freq_above_n(df: DataFrame, colname: str, n: int,
                                       spark: SparkSession):
    """Null out values above a certain frequency threshold

    Useful for columns that mostly contain valid data but occasionally
    contain other values such as 'unknown'

    Args:
        df (DataFrame): The dataframe to clean
        colname (string): The name of the column to clean
        n (int): The maximum frequency allowed.  Any values with a frequency higher than n will be nulled out
        spark (SparkSession): The spark session

    Returns:
        DataFrame: The cleaned dataframe with incoming column overwritten
    """

    # A window function might be a better approach than the one below,
    # but both probably require a shuffle, so it may not make much difference

    df.createOrReplaceTempView("df")

    sql = f"""
    select {colname} as count
    from df
    group by {colname}
    having count(*) > {n}
    """

    df_groups = spark.sql(sql)

    collected = df_groups.collect()

    values_to_null = [row["count"] for row in collected]

    if len(values_to_null) == 0:
        return df

    values_to_null = [f'"{v}"' for v in values_to_null]
    values_to_null_joined = ", ".join(values_to_null)

    case_statement = f"""
    CASE
    WHEN {colname} in ({values_to_null_joined}) THEN NULL
    ELSE {colname}
    END
    """

    df = df.withColumn(colname, f.expr(case_statement))

    return df
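A small usage sketch, assuming an active SparkSession and the `pyspark.sql.functions as f` import used inside the function: the placeholder value 'unknown' appears three times, which exceeds the threshold of 2, so it is nulled out.

# Sketch only: the data is made up for illustration.
data = [("alice",), ("bob",), ("unknown",), ("unknown",), ("unknown",)]
df_names = spark.createDataFrame(data, ["first_name"])

cleaned = null_out_entries_with_freq_above_n(df_names, "first_name", 2, spark)
cleaned.show()  # the three 'unknown' rows now contain NULL in first_name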
Example #9
    def _transform(self, df: DataFrame) -> DataFrame:
        # pushdown_query = "(select * from employees where emp_no < 10008) emp_alias"
        query: str = self.getQuery()
        jdbc_url: str = self.getJdbcUrl()
        driver: str = self.getDriver()
        view: Optional[str] = self.getView()
        df = (
            # this execution requires an Option either 'dbtable' or 'query' parameter
            df.sql_ctx.read.format("jdbc").option("url", jdbc_url).option(
                "dbtable", query).option("driver", driver).load())

        if view:
            df.createOrReplaceTempView(view)

        return df
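A brief aside on the `dbtable` option used above: Spark's JDBC reader accepts either a bare table name or a parenthesised subquery with an alias (as the commented-out `pushdown_query` hints), so the query is pushed down to the database. A standalone sketch with placeholder connection details:

# Sketch only: URL, query and driver are placeholders.
pushdown_query = "(select * from employees where emp_no < 10008) emp_alias"
df_emp = (
    spark.read.format("jdbc")
    .option("url", "jdbc:postgresql://localhost:5432/company")
    .option("dbtable", pushdown_query)  # the subquery runs on the database side
    .option("driver", "org.postgresql.Driver")
    .load()
)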
Example #10
    def upsert(self, df: DataFrame, full_table_name: str, schema: StructType, primary_key: list):
        temp_source_table = (
            f"upsert_{full_table_name.replace('.', '__')}_{''.join(random.choice(string.ascii_lowercase) for _ in range(6))}"
        )

        df.createOrReplaceTempView(temp_source_table)

        upsert_sql_statement = self.__upsert_query_creator.create(full_table_name, schema, primary_key, temp_source_table)

        try:
            self.__spark.sql(upsert_sql_statement)

        except BaseException:  # pylint: disable = broad-except, try-except-raise
            raise

        finally:
            self.__spark.catalog.dropTempView(temp_source_table)
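The upsert query creator itself is not shown; as a rough, hypothetical illustration, the statement it builds against the temporary view is presumably something along the lines of a `MERGE INTO` (this sketch assumes a Delta-style target table, a single-column primary key, and a made-up temp view name).

# Hypothetical example of the kind of SQL the query creator might produce for
# full_table_name="db.customers" and primary_key=["id"].
merge_sql = """
MERGE INTO db.customers AS target
USING upsert_db__customers_abcdef AS source
ON target.id = source.id
WHEN MATCHED THEN UPDATE SET *
WHEN NOT MATCHED THEN INSERT *
"""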
Example #11
def run_expectation_step(
    df_with_gamma: DataFrame,
    model: Model,
    spark: SparkSession,
    compute_ll=False,
):
    """Run the expectation step of the EM algorithm described in the fastlink paper:
    http://imai.fas.harvard.edu/research/files/linkage.pdf

      Args:
          df_with_gamma (DataFrame): Spark dataframe with comparison vectors already populated
          model (Model): splink Model object
          spark (SparkSession): SparkSession
          compute_ll (bool, optional): Whether to compute the log likelihood. Degrades performance. Defaults to False.

      Returns:
          DataFrame: Spark dataframe with a match_probability column
    """

    retain_source_dataset = _retain_source_dataset_column(
        model.current_settings_obj.settings_dict, df_with_gamma)

    sql = _sql_gen_gamma_prob_columns(model, retain_source_dataset)

    df_with_gamma.createOrReplaceTempView("df_with_gamma")
    logger.debug(_format_sql(sql))
    df_with_gamma_probs = spark.sql(sql)

    # This is optional because it slows down execution
    if compute_ll:
        ll = get_overall_log_likelihood(df_with_gamma_probs, model, spark)
        message = f"Log likelihood for iteration {model.iteration-1}:  {ll}"
        logger.info(message)
        model.current_settings_obj["log_likelihood"] = ll

    sql = _sql_gen_expected_match_prob(model, retain_source_dataset)

    logger.debug(_format_sql(sql))
    df_with_gamma_probs.createOrReplaceTempView("df_with_gamma_probs")
    df_e = spark.sql(sql)

    df_e.createOrReplaceTempView("df_e")

    model.save_settings_to_iteration_history()

    return df_e
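For orientation, a hedged sketch of how the expectation and maximisation steps shown in this collection are typically combined into an EM loop; the fixed iteration budget stands in for whatever convergence check the real training routine uses.

# Sketch only: df_with_gamma, model and spark are assumed to exist as in the snippets above.
for _ in range(20):  # fixed budget; a real loop would also test parameter convergence
    df_e = run_expectation_step(df_with_gamma, model, spark)
    run_maximisation_step(df_e, model, spark)  # updates `model` in place

df_e = run_expectation_step(df_with_gamma, model, spark)  # final scores under the fitted model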
Example #12
def make_adjustment_for_term_frequencies(
    df_e: DataFrame,
    params: Params,
    settings: dict,
    spark: SparkSession,
    retain_adjustment_columns: bool = False
):

    df_e.createOrReplaceTempView("df_e")

    term_freq_column_list = [
        c["col_name"]
        for c in settings["comparison_columns"]
        if c["term_frequency_adjustments"] == True
    ]

    if len(term_freq_column_list) == 0:
        warnings.warn(
            "No term frequency adjustment columns are specified in your settings object.  Returning original df"
        )
        return df_e

    # Generate a lookup table for each column with 'term specific' lambdas.
    for c in term_freq_column_list:
        sql = sql_gen_generate_adjusted_lambda(c, params)
        logger.debug(_format_sql(sql))
        lookup = spark.sql(sql)
        lookup.persist()
        lookup.createOrReplaceTempView(f"{c}_lookup")

    # Merge these lookup tables into main table
    sql = sql_gen_add_adjumentments_to_df_e(term_freq_column_list)
    logger.debug(_format_sql(sql))
    df_e_adj = spark.sql(sql)
    df_e_adj.createOrReplaceTempView("df_e_adj")

    sql = sql_gen_compute_final_group_membership_prob_from_adjustments(
        term_freq_column_list, settings
    )
    logger.debug(_format_sql(sql))
    df = spark.sql(sql)
    if not retain_adjustment_columns:
        for c in term_freq_column_list:
            df = df.drop(c + "_adj")

    return df
Example #13
def run_expectation_step(df_with_gamma: DataFrame,
                         params: Params,
                         settings: dict,
                         spark: SparkSession,
                         compute_ll=False):
    """Run the expectation step of the EM algorithm described in the fastlink paper:
    http://imai.fas.harvard.edu/research/files/linkage.pdf

      Args:
          df_with_gamma (DataFrame): Spark dataframe with comparison vectors already populated
          params (Params): splink params object
          settings (dict): splink settings dictionary
          spark (SparkSession): SparkSession
          compute_ll (bool, optional): Whether to compute the log likelihood. Degrades performance. Defaults to False.

      Returns:
          DataFrame: Spark dataframe with a match_probability column
      """


    sql = _sql_gen_gamma_prob_columns(params, settings)

    df_with_gamma.createOrReplaceTempView("df_with_gamma")
    logger.debug(_format_sql(sql))
    df_with_gamma_probs = spark.sql(sql)
    
    # This is optional because it slows down execution
    if compute_ll:
        ll = get_overall_log_likelihood(df_with_gamma_probs, params, spark)
        message = f"Log likelihood for iteration {params.iteration-1}:  {ll}"
        logger.info(message)
        params.params["log_likelihood"] = ll

    sql = _sql_gen_expected_match_prob(params, settings)

    logger.debug(_format_sql(sql))
    df_with_gamma_probs.createOrReplaceTempView("df_with_gamma_probs")
    df_e = spark.sql(sql)

    df_e.createOrReplaceTempView("df_e")
    return df_e
Example #14
    def synthesize_data(self,
                        stats_nom: DataFrame,
                        record_layout: DataFrame,
                        qty: int = 100) -> pd.DataFrame:
        spark = self.__spark

        create_object(self.t_item_view, self.concepts_script, spark)
        create_object(self.txform_view, self.txform_script, spark)

        stats_nom.createOrReplaceTempView(self.agg_view)

        entity = spark.createDataFrame([(ix, ) for ix in range(0, qty)],
                                       ['case_index'])
        entity.createOrReplaceTempView(self.entity_view)
        # simulated_entity.limit(5).toPandas()

        for view in self.views:
            create_object(view, self.script, spark)
        spark.catalog.cacheTable(self.views[-1])

        # ISSUE: SQL goes in .sql files
        sim_records_nom = spark.sql('''
        select data.case_index, data.xmlId, data.value
        from simulated_naaccr_nom data
        join record_layout rl on rl.xmlId = data.xmlId
        join section on rl.section = section.section
        order by case_index, rl.start
        ''').toPandas()
        sim_records_nom = sim_records_nom.pivot(index='case_index',
                                                columns='xmlId',
                                                values='value')
        for col in sim_records_nom.columns:
            sim_records_nom[col] = sim_records_nom[col].astype('category')

        col_start = {row.xmlId: row.start for row in record_layout.collect()}
        sim_records_nom = sim_records_nom[sorted(
            sim_records_nom.columns, key=lambda xid: col_start[xid])]
        return sim_records_nom
Example #15
def block_using_rules(
    settings: dict,
    spark: SparkSession,
    df_l: DataFrame = None,
    df_r: DataFrame = None,
    df: DataFrame = None
):
    """Apply a series of blocking rules to create a dataframe of record comparisons. If no blocking rules provided, performs a cartesian join.

    Args:
        settings (dict): A splink settings dictionary
        spark (SparkSession): The pyspark.sql.session.SparkSession
        df_l (DataFrame, optional): Where `link_type` is `link_only` or `link_and_dedupe`, one of the two dataframes to link. Should be omitted if `link_type` is `dedupe_only`.
        df_r (DataFrame, optional): Where `link_type` is `link_only` or `link_and_dedupe`, one of the two dataframes to link. Should be omitted if `link_type` is `dedupe_only`.
        df (DataFrame, optional): Where `link_type` is `dedupe_only`, the dataframe to dedupe. Should be omitted if `link_type` is `link_only` or `link_and_dedupe`.

    Returns:
        pyspark.sql.dataframe.DataFrame: A dataframe of each record comparison
    """

    if "blocking_rules" not in settings or len(settings["blocking_rules"])==0:
        return cartesian_block(settings, spark, df_l, df_r, df)

    link_type = settings["link_type"]

    columns_to_retain = _get_columns_to_retain_blocking(settings)
    unique_id_col = settings["unique_id_column_name"]

    if link_type == "dedupe_only":
        df.createOrReplaceTempView("df")

    if link_type == "link_only":
        df_l.createOrReplaceTempView("df_l")
        df_r.createOrReplaceTempView("df_r")

    if link_type == "link_and_dedupe":
        df_concat = _vertically_concatenate_datasets(df_l, df_r, settings, spark=spark)
        columns_to_retain.append("_source_table")
        df_concat.createOrReplaceTempView("df")
        df_concat.persist()

    rules = settings["blocking_rules"]

    sql = _sql_gen_block_using_rules(link_type, columns_to_retain, rules, unique_id_col)

    logger.debug(_format_sql(sql))

    df_comparison = spark.sql(sql)

    if link_type == "link_and_dedupe":
        df_concat.unpersist()


    return df_comparison
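A hedged usage sketch for the `link_only` case of this older signature; the settings values are illustrative and the two dataframes are placeholders that share a `unique_id` and `surname` column.

# Sketch only: df_customers_a and df_customers_b are placeholder DataFrames.
settings = {
    "link_type": "link_only",
    "unique_id_column_name": "unique_id",
    "blocking_rules": ["l.surname = r.surname"],
    "comparison_columns": [{"col_name": "surname"}],
}

df_comparison = block_using_rules(
    settings, spark, df_l=df_customers_a, df_r=df_customers_b
)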
Example #16
def truth_space_table(
    df_labels_with_splink_scores: DataFrame,
    spark: SparkSession,
    threshold_actual: float = 0.5,
    score_colname: str = None,
):
    """Create a table of the ROC space i.e. truth table statistics
    for each discrimination threshold

    Args:
        df_labels_with_splink_scores (DataFrame): A dataframe of labels and associated splink scores
            usually the output of the truth.labels_with_splink_scores function
        spark (SparkSession): The spark session
        threshold_actual (float, optional): Threshold to use in categorising clerical match
            scores into match or no match. Defaults to 0.5.
        score_colname (str, optional): Allows user to explicitly state the column name
            in the Splink dataset containing the Splink score.  If none will be inferred

    Returns:
        DataFrame: Table of 'truth space' i.e. truth categories for each threshold level
    """

    # At a truth threshold of 1.0, we say a splink score of 1.0 is a positive in ROC space, i.e. the threshold is inclusive, so if there are splink scores of exactly 1.0 it's not possible to have zero positives in the truth table.
    # This means that at a truth threshold of 0.0 we say a splink score of 0.0 is a positive, so it's possible to have zero negatives in the truth table.

    # This code provides an efficient way to compute the truth space
    # It's more complex than the previous code, but executes much faster because only a single SQL query/PySpark Action is needed
    # The previous implementation, which is easier to understand, is [here](https://github.com/moj-analytical-services/splink/blob/b4f601e6d180c6abfd64ab40775dca3e3513c0b5/splink/truth.py#L396)

    # We start with df_labels_with_splink_scores
    # This is a table of each pair of clerically labelled records accompanied by the Splink match score.
    # It is sorted in order of the splink score, low to high

    # This means that for any row, if we picked a threshold equal to its splink score, all rows _above_ it (in the table) are categorised by splink as non-matches.

    # For instance, if a row's splink score is 0.25, then any records above it in the table have a score of <0.25, and are therefore negative.  We categorise a score of exactly 0.25 as positive.

    # In addition, we can categorise any individual row as containing a false positive or false negative _at_ the splink score for the row.

    # This allows us to say things like:  Of the records above this row, we have _classified_ them all as negative, but we have _seen_ two true (clerically labelled) positives.  Therefore these must be false negatives.

    # In particular, the calculations are as follows:
    # True positives:  The cumulative total of positive labels in records BELOW this row, INCLUSIVE (all of these records are classified as positive)
    # False positives:  The number of records below this row (inclusive) minus the true positives

    # True negatives:  The cumulative total of negative labels in records above this row (all of these records are classified as negative)
    # False negatives:  The number of records above this row minus the true negatives
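    # As a made-up worked example: with threshold_actual = 0.5 and five labelled pairs
    # sorted by splink score ascending -
    #   (0.1, clerical 0), (0.3, clerical 1), (0.6, clerical 0), (0.8, clerical 1), (0.9, clerical 1)
    # - the row for truth_threshold = 0.6 has three records at or below it in the table
    # (splink scores >= 0.6), of which two are clerical positives, so TP = 2 and FP = 3 - 2 = 1.
    # The two records above it (splink scores < 0.6) contain one clerical negative, so TN = 1
    # and FN = 2 - 1 = 1, which reconciles with the totals P = 3 and N = 2.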

    # We want percentiles of score to compute
    score_colname = _get_score_colname(df_labels_with_splink_scores,
                                       score_colname)

    df_labels_with_splink_scores.createOrReplaceTempView(
        "df_labels_with_splink_scores")
    sql = f"""
    select
    *,
    {score_colname} as truth_threshold,
    case when clerical_match_score >= {threshold_actual} then 1
    else 0
    end
    as c_P,
    case when clerical_match_score >= {threshold_actual} then 0
    else 1
    end
    as c_N
    from df_labels_with_splink_scores
    order by {score_colname}
    """
    df_with_labels = spark.sql(sql)
    df_with_labels.createOrReplaceTempView("df_with_labels")

    sql = """
    select truth_threshold, count(*) as num_records_in_row, sum(c_P) as c_P, sum(c_N) as c_N
    from
    df_with_labels
    group by truth_threshold
    order by truth_threshold
    """
    df_with_labels_grouped = spark.sql(sql)
    df_with_labels_grouped.createOrReplaceTempView("df_with_labels_grouped")

    sql = """
    select
    truth_threshold,

    (sum(c_P) over (order by truth_threshold desc))  as cum_clerical_P,
    (sum(c_N) over (order by truth_threshold)) - c_N as cum_clerical_N,

    (select sum(c_P) from df_with_labels_grouped) as total_clerical_P,
    (select sum(c_N) from df_with_labels_grouped) as total_clerical_N,
    (select sum(num_records_in_row) from df_with_labels_grouped) as row_count,

    -num_records_in_row + sum(num_records_in_row) over (order by truth_threshold) as N_labels,
    sum(num_records_in_row) over (order by truth_threshold desc) as P_labels
    from df_with_labels_grouped
    order by  truth_threshold
    """
    df_with_cumulative_labels = spark.sql(sql)
    df_with_cumulative_labels.createOrReplaceTempView(
        "df_with_cumulative_labels")

    sql = """
    select
    truth_threshold,
    row_count,
    total_clerical_P as P,
    total_clerical_N as N,

    P_labels - cum_clerical_P as FP,
    cum_clerical_P as TP,

    N_labels - cum_clerical_N as FN,
    cum_clerical_N as TN

    from df_with_cumulative_labels
    """
    df_with_truth_cats = spark.sql(sql)
    df_with_truth_cats.createOrReplaceTempView("df_with_truth_cats")
    df_with_truth_cats.toPandas()

    sql = """
    select
    truth_threshold,
    row_count,
    P,
    N,
    TP,
    TN,
    FP,
    FN,
    P/row_count as P_rate,
    N/row_count as N_rate,
    TP/P as TP_rate,
    TN/N as TN_rate,
    FP/N as FP_rate,
    FN/P as FN_rate,
    TP/(TP+FP) as precision,
    TP/(TP+FN) as recall
    from df_with_truth_cats
    """
    df_truth_space = spark.sql(sql)

    return df_truth_space
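A brief usage sketch, assuming `df_labels_with_splink_scores` comes from the labelling step mentioned in the docstring; converting the result to pandas is simply a convenient way to inspect or plot the ROC curve.

# Sketch only: df_labels_with_splink_scores and spark are assumed to exist.
df_truth_space = truth_space_table(
    df_labels_with_splink_scores, spark, threshold_actual=0.5
)
truth_space_pd = df_truth_space.toPandas()
# e.g. plot the ROC curve: FP_rate on the x axis against TP_rate on the y axis
truth_space_pd.plot(x="FP_rate", y="TP_rate")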