def test_can_run_validated_framework_pipeline(
        spark_session: SparkSession) -> None:
    with pytest.raises(AssertionError):
        # Arrange
        clean_spark_session(spark_session)
        data_dir: Path = Path(__file__).parent.joinpath("./")
        flights_path: str = f"file://{data_dir.joinpath('flights.csv')}"
        output_path: str = f"file://{data_dir.joinpath('temp').joinpath('validation.csv')}"

        if path.isdir(data_dir.joinpath("temp")):
            shutil.rmtree(data_dir.joinpath("temp"))

        schema = StructType([])

        df: DataFrame = spark_session.createDataFrame(
            spark_session.sparkContext.emptyRDD(), schema)

        spark_session.sql("DROP TABLE IF EXISTS default.flights")

        # Act
        parameters = {
            "flights_path": flights_path,
            "validation_source_path": str(data_dir),
            "validation_output_path": output_path,
        }

        with ProgressLogger() as progress_logger:
            pipeline: MyValidatedPipeline = MyValidatedPipeline(
                parameters=parameters, progress_logger=progress_logger)
            transformer = pipeline.fit(df)
            transformer.transform(df)
Example #2
def test_simple_csv_and_sql_pipeline(spark_session: SparkSession) -> None:
    # Arrange
    data_dir: Path = Path(__file__).parent.joinpath("./")
    flights_path: str = f"file://{data_dir.joinpath('flights.csv')}"

    schema = StructType([])

    df: DataFrame = spark_session.createDataFrame(
        spark_session.sparkContext.emptyRDD(), schema)

    spark_session.sql("DROP TABLE IF EXISTS default.flights")

    # Act
    parameters: Dict[str, Any] = {}

    stages: List[Transformer] = create_steps([
        FrameworkCsvLoader(view="flights", filepath=flights_path),
        FeaturesCarriersV1(parameters=parameters),
    ])

    pipeline: Pipeline = Pipeline(stages=stages)  # type: ignore
    transformer = pipeline.fit(df)
    transformer.transform(df)

    # Assert
    result_df: DataFrame = spark_session.sql("SELECT * FROM flights2")
    result_df.show()

    assert result_df.count() > 0
Example #3
def test_can_run_framework_pipeline(spark_session: SparkSession) -> None:
    # Arrange
    data_dir: Path = Path(__file__).parent.joinpath("./")
    flights_path: str = f"file://{data_dir.joinpath('flights.csv')}"

    schema = StructType([])

    df: DataFrame = spark_session.createDataFrame(
        spark_session.sparkContext.emptyRDD(), schema)

    spark_session.sql("DROP TABLE IF EXISTS default.flights")

    # Act
    parameters = {"flights_path": flights_path}

    with ProgressLogger() as progress_logger:
        pipeline: MyPipeline = MyPipeline(parameters=parameters,
                                          progress_logger=progress_logger)
        transformer = pipeline.fit(df)
        transformer.transform(df)

    # Assert
    result_df: DataFrame = spark_session.sql("SELECT * FROM flights2")
    result_df.show()

    assert result_df.count() > 0
def test_fail_fast_validated_framework_pipeline_writes_results(
        spark_session: SparkSession) -> None:
    # Arrange
    clean_spark_session(spark_session)
    data_dir: Path = Path(__file__).parent.joinpath("./")
    flights_path: str = f"file://{data_dir.joinpath('flights.csv')}"
    output_path: str = f"file://{data_dir.joinpath('temp').joinpath('validation.csv')}"

    if path.isdir(data_dir.joinpath("temp")):
        shutil.rmtree(data_dir.joinpath("temp"))

    schema = StructType([])

    df: DataFrame = spark_session.createDataFrame(
        spark_session.sparkContext.emptyRDD(), schema)

    spark_session.sql("DROP TABLE IF EXISTS default.flights")

    # Act
    parameters = {
        "flights_path": flights_path,
        "validation_source_path": str(data_dir),
        "validation_output_path": output_path,
    }

    try:
        with ProgressLogger() as progress_logger:
            pipeline: MyFailFastValidatedPipeline = MyFailFastValidatedPipeline(
                parameters=parameters, progress_logger=progress_logger)
            transformer = pipeline.fit(df)
            transformer.transform(df)
    except AssertionError:
        validation_df = df.sql_ctx.read.csv(output_path, header=True)
        validation_df.show(truncate=False)
        assert validation_df.count() == 1
Example #5
def test_simple_csv_loader_pipeline(spark_session: SparkSession) -> None:
    # Arrange
    data_dir: Path = Path(__file__).parent.joinpath('./')
    flights_path: str = f"file://{data_dir.joinpath('flights.csv')}"

    schema = StructType([])

    df: DataFrame = spark_session.createDataFrame(
        spark_session.sparkContext.emptyRDD(), schema)

    # noinspection SqlDialectInspection,SqlNoDataSourceInspection
    spark_session.sql("DROP TABLE IF EXISTS default.flights")

    # Act
    # parameters = Dict[str, Any]({
    # })

    stages: List[Union[Estimator, Transformer]] = [
        FrameworkCsvLoader(
            view="flights",
            path_to_csv=flights_path
        ),
        SQLTransformer(statement="SELECT * FROM flights"),
    ]

    pipeline: Pipeline = Pipeline(stages=stages)

    transformer = pipeline.fit(df)
    result_df: DataFrame = transformer.transform(df)

    # Assert
    result_df.show()

    assert result_df.count() > 0
Example #6
def make_adjustment_for_term_frequencies(
    df_e: DataFrame,
    model: Model,
    spark: SparkSession,
    retain_adjustment_columns: bool = False,
):

    # Running a maximisation step will eliminate errors caused by global parameters
    # being used in blocked jobs

    settings = model.current_settings_obj.settings_dict

    term_freq_column_list = [
        cc.name for cc in model.current_settings_obj.comparison_columns_list
        if cc["term_frequency_adjustments"] is True
    ]

    if len(term_freq_column_list) == 0:
        return df_e

    retain_source_dataset_col = _retain_source_dataset_column(settings, df_e)
    df_e.createOrReplaceTempView("df_e")

    old_settings = deepcopy(model.current_settings_obj.settings_dict)

    for cc in model.current_settings_obj.comparison_columns_list:
        cc.column_dict["fix_m_probabilities"] = False
        cc.column_dict["fix_u_probabilities"] = False

    run_maximisation_step(df_e, model, spark)

    # Generate a lookup table for each column with 'term specific' lambdas.
    for c in term_freq_column_list:
        sql = sql_gen_generate_adjusted_lambda(c, model)
        logger.debug(_format_sql(sql))
        lookup = spark.sql(sql)
        lookup.persist()
        lookup.createOrReplaceTempView(f"{c}_lookup")

    # Merge these lookup tables into main table
    sql = sql_gen_add_adjumentments_to_df_e(term_freq_column_list)
    logger.debug(_format_sql(sql))
    df_e_adj = spark.sql(sql)
    df_e_adj.createOrReplaceTempView("df_e_adj")

    sql = sql_gen_compute_final_group_membership_prob_from_adjustments(
        term_freq_column_list, settings, retain_source_dataset_col)
    logger.debug(_format_sql(sql))
    df = spark.sql(sql)
    if not retain_adjustment_columns:
        for c in term_freq_column_list:
            df = df.drop(c + "_tf_adj")

    # Restore original settings
    model.current_settings_obj.settings_dict = old_settings

    return df
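
A minimal usage sketch for the function above, assuming a scored dataframe df_e, a fitted splink Model and an active SparkSession already exist; the call and keyword argument simply mirror the signature shown.

# Apply term-frequency adjustments to the scored matches and keep the
# per-column "_tf_adj" columns so they can be inspected afterwards.
df_adjusted = make_adjustment_for_term_frequencies(
    df_e, model, spark, retain_adjustment_columns=True)
df_adjusted.show()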
Example #7
def test_model_codegen_registered(spark: SparkSession):
    init(spark, True)

    spark.sql(
        """CREATE MODEL foo_dynamic OPTIONS (foo="str",bar=True,max_score=1.23)
         USING 'test://model/a/b/c'"""
    ).count()

    init(spark, False)

    spark.sql(
        """CREATE MODEL foo_static OPTIONS (foo="str",bar=True,max_score=1.23)
         USING 'test://model/a/b/c'"""
    ).count()
Example #8
def test_can_load_non_standard_delimited_csv(
        spark_session: SparkSession) -> None:
    # Arrange
    clean_spark_session(spark_session)

    data_dir: Path = Path(__file__).parent.joinpath("./")
    test_file_path: str = f"{data_dir.joinpath('test.psv')}"

    schema = StructType([])

    df: DataFrame = spark_session.createDataFrame(
        spark_session.sparkContext.emptyRDD(), schema)

    # Act
    loader = FrameworkCsvLoader(view="my_view",
                                filepath=test_file_path,
                                delimiter="|")
    loader.transform(df)

    # noinspection SqlDialectInspection
    result: DataFrame = spark_session.sql("SELECT * FROM my_view")

    result.show()

    # Assert
    assert loader.getDelimiter() == "|"
    assert_results(result)
Example #9
def test_correctly_loads_csv_with_clean_flag_on(
        spark_session: SparkSession) -> None:
    # Arrange
    clean_spark_session(spark_session)

    data_dir: Path = Path(__file__).parent.joinpath("./")
    test_file_path: str = f"{data_dir.joinpath('column_name_test.csv')}"

    schema = StructType([])

    df: DataFrame = spark_session.createDataFrame(
        spark_session.sparkContext.emptyRDD(), schema)

    # Act
    FrameworkCsvLoader(
        view="my_view",
        filepath=test_file_path,
        delimiter=",",
        clean_column_names=True,
    ).transform(df)

    # noinspection SqlDialectInspection
    result: DataFrame = spark_session.sql("SELECT * FROM my_view")

    # Assert
    assert_results(result)
    assert result.collect()[1][0] == "2"
    assert (result.columns[2] ==
            "Ugly_column_with_chars_that_parquet_does_not_like_much_-")
Example #10
def test_can_keep_columns(spark_session: SparkSession) -> None:
    # Arrange
    clean_spark_session(spark_session)

    data_dir: Path = Path(__file__).parent.joinpath("./")
    test_file_path: str = f"{data_dir.joinpath('test.csv')}"

    schema = StructType([])

    df: DataFrame = spark_session.createDataFrame(
        spark_session.sparkContext.emptyRDD(), schema)

    # Act
    FrameworkCsvLoader(view="my_view", filepath=test_file_path,
                       delimiter=",").transform(df)

    FrameworkSelectColumnsTransformer(view="my_view",
                                      keep_columns=["Column2"]).transform(df)

    # noinspection SqlDialectInspection
    result: DataFrame = spark_session.sql("SELECT * FROM my_view")

    result.show()

    # Assert
    assert len(result.columns) == 1

    assert result.count() == 3

    assert result.collect()[1][0] == "bar"
def block_using_rules(settings: dict, df: DataFrame, spark: SparkSession):
    """Apply a series of blocking rules to create a dataframe of record comparisons. If no blocking rules provided, performs a cartesian join.

    Args:
        settings (dict): A splink settings dictionary
        df (DataFrame): Spark dataframe to block - if linking multiple datasets, assumes dataframes have already been vertically concatenated
        spark (SparkSession): The pyspark.sql.session.SparkSession

    Returns:
        pyspark.sql.dataframe.DataFrame: A dataframe of each record comparison
    """
    df.createOrReplaceTempView("df")
    columns_to_retain = _get_columns_to_retain_blocking(settings, df)
    unique_id_col = settings["unique_id_column_name"]
    if settings["link_type"] == "dedupe_only":
        source_dataset_col = None
    else:
        source_dataset_col = settings["source_dataset_column_name"]
    link_type = settings["link_type"]

    if "blocking_rules" not in settings or len(
            settings["blocking_rules"]) == 0:
        sql = _sql_gen_cartesian_block(link_type, columns_to_retain,
                                       unique_id_col, source_dataset_col)
    else:
        rules = settings["blocking_rules"]
        sql = _sql_gen_block_using_rules(link_type, columns_to_retain, rules,
                                         unique_id_col, source_dataset_col)

    logger.debug(_format_sql(sql))

    df_comparison = spark.sql(sql)

    return df_comparison
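
A short usage sketch for the function above, assuming a dataframe df and a SparkSession spark are already in scope; the settings keys follow the splink settings schema referenced in the docstring, and the blocking rules are illustrative only.

# Hypothetical settings dictionary; the blocking rules are SQL conditions
# comparing the left-hand and right-hand records.
settings = {
    "link_type": "dedupe_only",
    "unique_id_column_name": "unique_id",
    "blocking_rules": [
        "l.surname = r.surname",
        "l.postcode = r.postcode",
    ],
}

# One row per candidate record comparison; with an empty rule list the
# function falls back to a cartesian join.
df_comparison = block_using_rules(settings, df, spark)
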
def run_maximisation_step(df_e: DataFrame, model: Model, spark: SparkSession):
    """Compute new parameters and save them in the model object

    Note that the model object will be updated in-place by this function

    Args:
        df_e (DataFrame): the result of the expectation step
        model (Model): splink Model object
        spark (SparkSession): The spark session
    """

    sql = _sql_gen_intermediate_pi_aggregate(model)

    df_e.createOrReplaceTempView("df_e")
    df_intermediate = spark.sql(sql)
    logger.debug(_format_sql(sql))
    df_intermediate.createOrReplaceTempView("df_intermediate")
    df_intermediate.persist()

    new_lambda = _get_new_lambda(df_intermediate, spark)
    pi_df_collected = _get_new_pi_df(df_intermediate, spark, model)

    model._populate_model_from_maximisation_step(new_lambda, pi_df_collected)
    model.iteration += 1
    df_intermediate.unpersist()
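
A sketch of how the maximisation step above might be driven inside an EM loop, assuming run_expectation_step (shown later in these examples), a populated model object, df_with_gamma and spark are all available; the iteration count is arbitrary.

# Illustrative EM loop: each maximisation step updates the model in place and
# advances model.iteration, so the loop just alternates the two steps.
for _ in range(5):
    df_e = run_expectation_step(df_with_gamma, model, spark)
    run_maximisation_step(df_e, model, spark)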
Example #13
def test_can_load_xml_file_with_schema(spark_session: SparkSession) -> None:
    # Arrange
    clean_spark_session(spark_session)

    data_dir: Path = Path(__file__).parent.joinpath("./")
    test_file_path: str = f"{data_dir.joinpath('test.xml')}"

    schema = StructType([])

    df: DataFrame = spark_session.createDataFrame(
        spark_session.sparkContext.emptyRDD(), schema)

    xml_schema = StructType([
        StructField("_id", StringType(), True),
        StructField("author", StringType(), True),
        StructField("description", StringType(), True),
        StructField("genre", StringType(), True),
        StructField("price", DoubleType(), True),
        StructField("publish_date", StringType(), True),
        StructField("title", StringType(), True),
    ])
    # Act
    FrameworkXmlLoader(view="my_view",
                       filepath=test_file_path,
                       row_tag="book",
                       schema=xml_schema).transform(df)

    result: DataFrame = spark_session.sql("SELECT * FROM my_view")
    result.show()
    assert result.count() == 12
    assert len(result.columns) == 7
Example #14
def add_gammas(
    df_comparison: DataFrame,
    settings_dict: dict,
    spark: SparkSession,
    unique_id_col: str = "unique_id",
):
    """ Compute the comparison vectors and add them to the dataframe.  See
    https://imai.fas.harvard.edu/research/files/linkage.pdf for more details of what is meant by comparison vectors

    Args:
        df_comparison (spark dataframe): A Spark dataframe containing record comparisons, with records compared using the convention col_name_l, col_name_r
        settings_dict (dict): The `splink` settings dictionary
        spark (Spark session): The Spark session object
        unique_id_col (str, optional): Name of the unique id column. Defaults to "unique_id".

    Returns:
        Spark dataframe: A dataframe containing new columns representing the gammas of the model
    """


    settings_dict = complete_settings_dict(settings_dict, spark)

    sql = _sql_gen_add_gammas(
        settings_dict,
        unique_id_col=unique_id_col,
    )

    logger.debug(_format_sql(sql))
    df_comparison.createOrReplaceTempView("df_comparison")
    df_gammas = spark.sql(sql)

    return df_gammas
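
A usage sketch, assuming df_comparison is the output of a blocking step, settings_dict is a valid splink settings dictionary, and spark is a SparkSession; the "gamma_" column prefix used for inspection is an assumption.

# Compute comparison vectors for the blocked record pairs.
df_gammas = add_gammas(df_comparison, settings_dict, spark)

# Inspect the generated comparison-vector columns (assumed to share a
# "gamma_" prefix).
gamma_cols = [c for c in df_gammas.columns if c.startswith("gamma_")]
df_gammas.select(gamma_cols).show()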
Example #15
def test_can_load_parquet(spark_session: SparkSession):
    # Arrange
    SparkTestHelper.clear_tables(spark_session)

    data_dir: Path = Path(__file__).parent.joinpath('./')
    test_file_path: str = f"{data_dir.joinpath('test.csv')}"

    if path.isdir(data_dir.joinpath('temp')):
        shutil.rmtree(data_dir.joinpath('temp'))

    schema = StructType([])

    df: DataFrame = spark_session.createDataFrame(
        spark_session.sparkContext.emptyRDD(), schema)

    parquet_file_path: str = ParquetHelper.create_parquet_from_csv(
        spark_session=spark_session, file_path=test_file_path)

    # Act
    FrameworkParquetLoader(view="my_view",
                           file_path=parquet_file_path).transform(df)

    # noinspection SqlDialectInspection
    result: DataFrame = spark_session.sql("SELECT * FROM my_view")

    result.show()

    # Assert
    assert result.count() == 3

    assert result.collect()[1][0] == 2
    assert result.collect()[1][1] == "bar"
    assert result.collect()[1][2] == "bar2"
Example #16
def run_maximisation_step(df_e: DataFrame, params: Params,
                          spark: SparkSession):
    """Compute new parameters and save them in the params object

    Note that the params object will be updated in-place by this function

    Args:
        df_e (DataFrame): the result of the expectation step
        params (Params): splink Params object
        spark (SparkSession): The spark session
    """

    sql = _sql_gen_intermediate_pi_aggregate(params)

    df_e.createOrReplaceTempView("df_e")
    df_intermediate = spark.sql(sql)
    logger.debug(_format_sql(sql))
    df_intermediate.createOrReplaceTempView("df_intermediate")
    df_intermediate.persist()

    new_lambda = _get_new_lambda(df_intermediate, spark)
    pi_df_collected = _get_new_pi_df(df_intermediate, spark, params)

    params._update_params(new_lambda, pi_df_collected)
    df_intermediate.unpersist()
Example #17
def make_adjustment_for_term_frequencies(
    df_e: DataFrame,
    params: Params,
    settings: dict,
    spark: SparkSession,
    retain_adjustment_columns: bool = False
):

    df_e.createOrReplaceTempView("df_e")

    term_freq_column_list = [
        c["col_name"]
        for c in settings["comparison_columns"]
        if c["term_frequency_adjustments"] == True
    ]

    if len(term_freq_column_list) == 0:
        warnings.warn(
            "No term frequency adjustment columns are specified in your settings object.  Returning original df"
        )
        return df_e

    # Generate a lookup table for each column with 'term specific' lambdas.
    for c in term_freq_column_list:
        sql = sql_gen_generate_adjusted_lambda(c, params)
        logger.debug(_format_sql(sql))
        lookup = spark.sql(sql)
        lookup.persist()
        lookup.createOrReplaceTempView(f"{c}_lookup")

    # Merge these lookup tables into main table
    sql = sql_gen_add_adjumentments_to_df_e(term_freq_column_list)
    logger.debug(_format_sql(sql))
    df_e_adj = spark.sql(sql)
    df_e_adj.createOrReplaceTempView("df_e_adj")

    sql = sql_gen_compute_final_group_membership_prob_from_adjustments(
        term_freq_column_list, settings
    )
    logger.debug(_format_sql(sql))
    df = spark.sql(sql)
    if not retain_adjustment_columns:
        for c in term_freq_column_list:
            df = df.drop(c + "_adj")

    return df
Example #18
def run_expectation_step(
    df_with_gamma: DataFrame,
    model: Model,
    spark: SparkSession,
    compute_ll=False,
):
    """Run the expectation step of the EM algorithm described in the fastlink paper:
    http://imai.fas.harvard.edu/research/files/linkage.pdf

      Args:
          df_with_gamma (DataFrame): Spark dataframe with comparison vectors already populated
          model (Model): splink Model object
          spark (SparkSession): SparkSession
          compute_ll (bool, optional): Whether to compute the log likelihood. Degrades performance. Defaults to False.

      Returns:
          DataFrame: Spark dataframe with a match_probability column
    """

    retain_source_dataset = _retain_source_dataset_column(
        model.current_settings_obj.settings_dict, df_with_gamma)

    sql = _sql_gen_gamma_prob_columns(model, retain_source_dataset)

    df_with_gamma.createOrReplaceTempView("df_with_gamma")
    logger.debug(_format_sql(sql))
    df_with_gamma_probs = spark.sql(sql)

    # This is optional because it slows down execution
    if compute_ll:
        ll = get_overall_log_likelihood(df_with_gamma_probs, model, spark)
        message = f"Log likelihood for iteration {model.iteration-1}:  {ll}"
        logger.info(message)
        model.current_settings_obj["log_likelihood"] = ll

    sql = _sql_gen_expected_match_prob(model, retain_source_dataset)

    logger.debug(_format_sql(sql))
    df_with_gamma_probs.createOrReplaceTempView("df_with_gamma_probs")
    df_e = spark.sql(sql)

    df_e.createOrReplaceTempView("df_e")

    model.save_settings_to_iteration_history()

    return df_e
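
A short sketch of calling the expectation step above, assuming df_with_gamma, model and spark already exist; the summary call is just one way to inspect the resulting scores.

# Score the comparison vectors and look at the spread of match probabilities.
# compute_ll=True is optional because it slows execution down.
df_e = run_expectation_step(df_with_gamma, model, spark, compute_ll=True)
df_e.select("match_probability").summary("min", "mean", "max").show()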
Example #19
def test_can_load_fixed_width(spark_session: SparkSession) -> None:
    # Arrange
    clean_spark_session(spark_session)

    data_dir: Path = Path(__file__).parent.joinpath("./")
    test_file_path: str = f"{data_dir.joinpath('test.txt')}"

    schema = StructType([])

    df: DataFrame = spark_session.createDataFrame(
        spark_session.sparkContext.emptyRDD(), schema)

    # Act
    FrameworkFixedWidthLoader(
        view="my_view",
        filepath=test_file_path,
        columns=[
            ColumnSpec(column_name="id",
                       start_pos=1,
                       length=3,
                       data_type=StringType()),
            ColumnSpec(column_name="some_date",
                       start_pos=4,
                       length=8,
                       data_type=StringType()),
            ColumnSpec(
                column_name="some_string",
                start_pos=12,
                length=3,
                data_type=StringType(),
            ),
            ColumnSpec(
                column_name="some_integer",
                start_pos=15,
                length=4,
                data_type=IntegerType(),
            ),
        ],
    ).transform(df)

    # noinspection SqlDialectInspection
    result: DataFrame = spark_session.sql("SELECT * FROM my_view")

    result.show()

    # Assert
    assert result.count() == 2
    assert result.collect()[0][0] == "001"
    assert result.collect()[1][0] == "002"
    assert result.collect()[0][1] == "01292017"
    assert result.collect()[1][1] == "01302017"
    assert result.collect()[0][2] == "you"
    assert result.collect()[1][2] == "me"
    assert result.collect()[0][3] == 1234
    assert result.collect()[1][3] == 5678
Example #20
class TweetsSpark:
    def __init__(self):
        self.tweets = None
        self.tweets_table = None
        sc = SparkContext.getOrCreate()
        self.spark = SparkSession(sc)

    # Loads selected data from the file into an instance attribute.
    def set_tweets_data(self, path):
        # Read the data and split it into columns.
        self.tweets = self.spark.read.csv(path, header=True, escape='\"')
        self.tweets.createOrReplaceTempView("tweets")

    def set_tweets_foreign(self):
        sql = """\
        CREATE TEMP VIEW tweets_foreign AS
          SELECT userid, 
                 CAST(reply_count AS INT), 
                 account_language
            FROM tweets
            WHERE account_language != 'ru'
        """
        self.spark.sql(sql)

    def get_userid_max_rc(self):
        sql = """\
        SELECT userid FROM tweets_foreign
          WHERE reply_count = (
            SELECT MAX(reply_count)
              FROM tweets_foreign
            )
          LIMIT 1
        """
        # Result userid 4224729994
        return self.spark.sql(sql).show()

    def get_sorted_data(self):
        sql = """\
         SELECT * FROM tweets_foreign
           ORDER BY reply_count DESC
         """
        return self.spark.sql(sql).show()
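
A minimal driver for the class above, assuming a CSV export with userid, reply_count and account_language columns; the file name is a placeholder.

# Build the temp views and run the two queries defined on the class.
tweets = TweetsSpark()
tweets.set_tweets_data("tweets.csv")  # placeholder path
tweets.set_tweets_foreign()
tweets.get_userid_max_rc()
tweets.get_sorted_data()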
Example #21
def block_using_rules(
    settings: dict,
    spark: SparkSession,
    df_l: DataFrame = None,
    df_r: DataFrame = None,
    df: DataFrame = None
):
    """Apply a series of blocking rules to create a dataframe of record comparisons. If no blocking rules provided, performs a cartesian join.

    Args:
        settings (dict): A splink settings dictionary
        spark (SparkSession): The pyspark.sql.session.SparkSession
        df_l (DataFrame, optional): Where `link_type` is `link_only` or `link_and_dedupe`, one of the two dataframes to link. Should be omitted if `link_type` is `dedupe_only`.
        df_r (DataFrame, optional): Where `link_type` is `link_only` or `link_and_dedupe`, one of the two dataframes to link. Should be omitted if `link_type` is `dedupe_only`.
        df (DataFrame, optional): Where `link_type` is `dedupe_only`, the dataframe to dedupe. Should be omitted if `link_type` is `link_only` or `link_and_dedupe`.

    Returns:
        pyspark.sql.dataframe.DataFrame: A dataframe of each record comparison
    """

    if "blocking_rules" not in settings or len(settings["blocking_rules"])==0:
        return cartesian_block(settings, spark, df_l, df_r, df)

    link_type = settings["link_type"]

    columns_to_retain = _get_columns_to_retain_blocking(settings)
    unique_id_col = settings["unique_id_column_name"]

    if link_type == "dedupe_only":
        df.createOrReplaceTempView("df")

    if link_type == "link_only":
        df_l.createOrReplaceTempView("df_l")
        df_r.createOrReplaceTempView("df_r")

    if link_type == "link_and_dedupe":
        df_concat = _vertically_concatenate_datasets(df_l, df_r, settings, spark=spark)
        columns_to_retain.append("_source_table")
        df_concat.createOrReplaceTempView("df")
        df_concat.persist()

    rules = settings["blocking_rules"]

    sql = _sql_gen_block_using_rules(link_type, columns_to_retain, rules, unique_id_col)

    logger.debug(_format_sql(sql))

    df_comparison = spark.sql(sql)

    if link_type == "link_and_dedupe":
        df_concat.unpersist()


    return df_comparison
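
A sketch of a link_only call to this variant, assuming two dataframes df_left and df_right that share a schema and a SparkSession spark; the settings keys and blocking rule are illustrative.

settings = {
    "link_type": "link_only",
    "unique_id_column_name": "unique_id",
    "blocking_rules": ["l.postcode = r.postcode"],
}

# Candidate comparisons across the two datasets only; df is left unset
# because link_type is not dedupe_only.
df_comparison = block_using_rules(settings, spark, df_l=df_left, df_r=df_right)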
Example #22
def run_expectation_step(df_with_gamma: DataFrame,
                         params: Params,
                         settings: dict,
                         spark: SparkSession,
                         compute_ll=False):
    """Run the expectation step of the EM algorithm described in the fastlink paper:
    http://imai.fas.harvard.edu/research/files/linkage.pdf

      Args:
          df_with_gamma (DataFrame): Spark dataframe with comparison vectors already populated
          params (Params): splink params object
          settings (dict): splink settings dictionary
          spark (SparkSession): SparkSession
          compute_ll (bool, optional): Whether to compute the log likelihood. Degrades performance. Defaults to False.

      Returns:
          DataFrame: Spark dataframe with a match_probability column
      """


    sql = _sql_gen_gamma_prob_columns(params, settings)

    df_with_gamma.createOrReplaceTempView("df_with_gamma")
    logger.debug(_format_sql(sql))
    df_with_gamma_probs = spark.sql(sql)
    
    # This is optional because it slows down execution
    if compute_ll:
        ll = get_overall_log_likelihood(df_with_gamma_probs, params, spark)
        message = f"Log likelihood for iteration {params.iteration-1}:  {ll}"
        logger.info(message)
        params.params["log_likelihood"] = ll

    sql = _sql_gen_expected_match_prob(params, settings)

    logger.debug(_format_sql(sql))
    df_with_gamma_probs.createOrReplaceTempView("df_with_gamma_probs")
    df_e = spark.sql(sql)

    df_e.createOrReplaceTempView("df_e")
    return df_e
def df_e_with_truth_categories(
    df_labels_with_splink_scores,
    threshold_pred,
    spark: SparkSession,
    threshold_actual: float = 0.5,
    score_colname: str = None,
):
    """Join Splink's predictions to clerically labelled data and categorise
    rows by truth category (false positive, true positive etc.)

    Note that df_labels

    Args:
        df_labels_with_splink_scores (DataFrame): A dataframe of labels and associated splink scores
            usually the output of the truth.labels_with_splink_scores function
        threshold_pred (float): Threshold to use in categorising Splink predictions into
            match or no match
        spark (SparkSession): SparkSession object
        threshold_actual (float, optional): Threshold to use in categorising clerical match
            scores into match or no match. Defaults to 0.5.
        score_colname (float, optional): Allows user to explicitly state the column name
            in the Splink dataset containing the Splink score.  If none will be inferred

    Returns:
        DataFrame: Dataframe of labels associated with truth category
    """

    df_labels_with_splink_scores.createOrReplaceTempView(
        "df_labels_with_splink_scores")

    if score_colname is None:
        score_colname = _get_score_colname(df_labels_with_splink_scores)

    pred = f"({score_colname} >= {threshold_pred})"

    actual = f"(clerical_match_score >= {threshold_actual})"

    sql = f"""
    select
    *,
    cast ({threshold_pred} as float) as truth_threshold,
    {actual} = 1.0 as P,
    {actual} = 0.0 as N,
    {pred} = 1.0 and {actual} = 1.0 as TP,
    {pred} = 0.0 and {actual} = 0.0 as TN,
    {pred} = 1.0 and {actual} = 0.0 as FP,
    {pred} = 0.0 and {actual} = 1.0 as FN

    from
    df_labels_with_splink_scores

    """

    return spark.sql(sql)
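
A usage sketch, assuming a labelled dataframe df_labels_with_splink_scores with a clerical_match_score column and a splink score column; the 0.9 prediction threshold is illustrative.

# Categorise each labelled pair and count how many fall into each truth
# category (TP, TN, FP, FN columns produced by the function above).
df_truth = df_e_with_truth_categories(df_labels_with_splink_scores, 0.9, spark)
df_truth.groupBy("TP", "TN", "FP", "FN").count().show()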
Example #24
def null_out_entries_with_freq_above_n(df: DataFrame, colname: str, n: int,
                                       spark: SparkSession):
    """Null out values above a certain frequency threshold

    Useful for columns that mostly contain valid data but occasionally
    contain other values such as 'unknown'

    Args:
        df (DataFrame): The dataframe to clean
        colname (string): The name of the column to clean
        n (int): The maximum frequency allowed.  Any values with a frequency higher than n will be nulled out
        spark (SparkSession): The spark session

    Returns:
        DataFrame: The cleaned dataframe with incoming column overwritten
    """

    # Possible that a window function would be better than the following approach,
    # but both probably require a shuffle, so it may not make much difference

    df.createOrReplaceTempView("df")

    sql = f"""
    select {colname} as count
    from df
    group by {colname}
    having count(*) > {n}
    """

    df_groups = spark.sql(sql)

    collected = df_groups.collect()

    values_to_null = [row["count"] for row in collected]

    if len(values_to_null) == 0:
        return df

    values_to_null = [f'"{v}"' for v in values_to_null]
    values_to_null_joined = ", ".join(values_to_null)

    case_statement = f"""
    CASE
    WHEN {colname} in ({values_to_null_joined}) THEN NULL
    ELSE {colname}
    END
    """

    df = df.withColumn(colname, f.expr(case_statement))

    return df
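
A sketch of cleaning a single column with the helper above, assuming df and spark are in scope; the column name and threshold are illustrative.

# Null out any value of the "occupation" column that occurs more than 1000
# times (e.g. catch-all entries such as 'unknown'), then check frequencies.
df_clean = null_out_entries_with_freq_above_n(df, "occupation", 1000, spark)
df_clean.groupBy("occupation").count().orderBy("count", ascending=False).show()
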
def execute(spark: SparkSession, log: logging, config: dict):
    log.info("extract")
    params = config['params']
    ps_conf = config['postgres']

    ts: datetime.datetime = params['ts']
    in_path = ts.strftime(params['in_path'])
    ts_from = config['ts_from']
    ts_to = config['ts_to']
    df = spark.read.csv(in_path, header=True, sep=';')
    df = df.select(
        F.col('FROM_PHONE_NUMBER'), F.col('TO_PHONE_NUMBER'),
        F.to_timestamp(df['START_TIME'], 'dd/MM/yyyy HH:mm:ss').alias('START_TIME'),
        F.col('CALL_DURATION').cast('long'), F.col('IMEI'), F.col('LOCATION')
    ).withColumn("TS", F.date_format(F.date_trunc("hour", "START_TIME"), "yyyy-MM-dd-HH"))
    df.write.partitionBy("TS").mode('append').format('hive').saveAsTable('task_02')
    df = spark.sql("select * from task_02 where TS >= '{}' AND TS < '{}'".format(ts_from, ts_to)).drop_duplicates()
    df.cache()
    ts = df.select("TS").rdd.map(lambda x: x[0]).first()
    # Number of call, total call duration.
    num_call = df.count()
    total_call_duration = list(df.select(F.sum(df['CALL_DURATION'])).first().asDict().values())[0]

    # Number of call in working hour (8am to 5pm)
    num_call_working_hour = df.filter("hour(START_TIME) >= 8 AND hour(START_TIME) <= 17").count()

    # Find the IMEI which make most call.
    imei_most = df.groupBy('IMEI').count().sort(F.col("count").desc()).first().asDict()

    # Find top 2 locations which make most call.
    locations = list(map(lambda x: x.asDict(), df.groupBy('LOCATION').count().sort(F.col("count").desc()).head(2)))

    rs = (ts, num_call, total_call_duration, num_call_working_hour, imei_most, locations)
    with get_postgres_cli(ps_conf) as ps_cli:
        with ps_cli.cursor() as cur:
            sql = """
            INSERT INTO metric_hour(
                ts, num_call, total_call_duration, 
                num_call_working_hour, imei_most, locations
            ) VALUES(%s, %s, %s, %s, %s, %s) 
            ON CONFLICT (ts) 
            DO UPDATE SET(
                num_call, total_call_duration, num_call_working_hour, imei_most, locations) = 
                (EXCLUDED.num_call, EXCLUDED.total_call_duration, EXCLUDED.num_call_working_hour,
                 EXCLUDED.imei_most, EXCLUDED.locations)
            """
            cur.execute(sql, rs)
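
A sketch of how execute might be driven, assuming a SparkSession named spark and a reachable Postgres instance; every value in the config below is a placeholder, and the keys simply mirror the ones read inside the function.

import datetime
import logging

# Placeholder configuration; in_path is a strftime pattern expanded with
# params["ts"], and ts_from/ts_to bound the partition filter on TS.
config = {
    "params": {
        "ts": datetime.datetime(2021, 1, 1),
        "in_path": "/data/calls/%Y/%m/%d/*.csv",
    },
    "postgres": {"host": "localhost", "port": 5432, "dbname": "metrics"},
    "ts_from": "2021-01-01-00",
    "ts_to": "2021-01-01-01",
}
execute(spark, logging, config)
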
def test_can_save_csv(spark_session: SparkSession) -> None:
    # Arrange
    SparkTestHelper.clear_tables(spark_session)

    data_dir: Path = Path(__file__).parent.joinpath("./")
    test_file_path: str = f"{data_dir.joinpath('test.csv')}"

    if path.isdir(data_dir.joinpath("temp")):
        shutil.rmtree(data_dir.joinpath("temp"))

    schema = StructType([])

    df: DataFrame = spark_session.createDataFrame(
        spark_session.sparkContext.emptyRDD(), schema
    )

    FrameworkCsvLoader(
        view="my_view", filepath=test_file_path, delimiter=","
    ).transform(df)

    csv_file_path: str = f"file://{data_dir.joinpath('temp/').joinpath(f'test.csv')}"

    # Act
    FrameworkCsvExporter(
        view="my_view", file_path=csv_file_path, header=True, delimiter=","
    ).transform(df)

    # Assert
    FrameworkCsvLoader(
        view="my_view2", filepath=csv_file_path, delimiter=","
    ).transform(df)

    # noinspection SqlDialectInspection
    result: DataFrame = spark_session.sql("SELECT * FROM my_view2")

    result.show()

    assert result.count() == 3

    assert result.collect()[1][0] == "2"
    assert result.collect()[1][1] == "bar"
    assert result.collect()[1][2] == "bar2"
Example #27
def test_can_load_xml_file(spark_session: SparkSession) -> None:
    # Arrange
    clean_spark_session(spark_session)

    data_dir: Path = Path(__file__).parent.joinpath("./")
    test_file_path: str = f"{data_dir.joinpath('test.xml')}"

    schema = StructType([])

    df: DataFrame = spark_session.createDataFrame(
        spark_session.sparkContext.emptyRDD(), schema)

    # Act
    FrameworkXmlLoader(view="my_view", filepath=test_file_path,
                       row_tag="book").transform(df)

    result: DataFrame = spark_session.sql("SELECT * FROM my_view")
    result.show()
    assert result.count() == 12
    assert len(result.columns) == 7
Example #28
def test_can_load_multiline_csv(spark_session: SparkSession) -> None:
    # Arrange
    clean_spark_session(spark_session)

    data_dir: Path = Path(__file__).parent.joinpath("./")
    test_file_path: str = f"{data_dir.joinpath('multiline_row.csv')}"

    schema = StructType([])

    df: DataFrame = spark_session.createDataFrame(
        spark_session.sparkContext.emptyRDD(), schema)

    # Act
    FrameworkCsvLoader(view="my_view",
                       filepath=test_file_path,
                       delimiter=",",
                       multiline=True).transform(df)

    # noinspection SqlDialectInspection
    result: DataFrame = spark_session.sql("SELECT * FROM my_view")
    assert 1 == result.count()
Example #29
    def clear_tables(spark_session: SparkSession) -> None:
        """
        :param spark_session:
        """

        # spark_session.sql("SET -v").show(n=200, truncate=False)

        tables = spark_session.catalog.listTables("default")

        for table in tables:
            print(f"clear_tables() dropping table/view: {table.name}")
            spark_session.sql(f"DROP TABLE IF EXISTS default.{table.name}")
            spark_session.sql(f"DROP VIEW IF EXISTS default.{table.name}")
            spark_session.sql(f"DROP VIEW IF EXISTS {table.name}")

        spark_session.catalog.clearCache()
Example #30
def test_can_load_csv_without_header(spark_session: SparkSession) -> None:
    # Arrange
    clean_spark_session(spark_session)

    data_dir: Path = Path(__file__).parent.joinpath("./")
    test_file_path: str = f"{data_dir.joinpath('no_header.csv')}"

    schema = StructType([])

    df: DataFrame = spark_session.createDataFrame(
        spark_session.sparkContext.emptyRDD(), schema)

    # Act
    FrameworkCsvLoader(view="another_view",
                       filepath=test_file_path,
                       delimiter=",",
                       has_header=False).transform(df)

    # noinspection SqlDialectInspection
    result: DataFrame = spark_session.sql("SELECT * FROM another_view")

    # Assert
    assert_results(result)