def test_can_run_validated_framework_pipeline(
        spark_session: SparkSession) -> None:
    with pytest.raises(AssertionError):
        # Arrange
        clean_spark_session(spark_session)
        data_dir: Path = Path(__file__).parent.joinpath("./")
        flights_path: str = f"file://{data_dir.joinpath('flights.csv')}"
        output_path: str = f"file://{data_dir.joinpath('temp').joinpath('validation.csv')}"
        if path.isdir(data_dir.joinpath("temp")):
            shutil.rmtree(data_dir.joinpath("temp"))

        schema = StructType([])
        df: DataFrame = spark_session.createDataFrame(
            spark_session.sparkContext.emptyRDD(), schema)

        spark_session.sql("DROP TABLE IF EXISTS default.flights")

        # Act
        parameters = {
            "flights_path": flights_path,
            "validation_source_path": str(data_dir),
            "validation_output_path": output_path,
        }

        with ProgressLogger() as progress_logger:
            pipeline: MyValidatedPipeline = MyValidatedPipeline(
                parameters=parameters, progress_logger=progress_logger)
            transformer = pipeline.fit(df)
            transformer.transform(df)

def test_simple_csv_and_sql_pipeline(spark_session: SparkSession) -> None:
    # Arrange
    data_dir: Path = Path(__file__).parent.joinpath("./")
    flights_path: str = f"file://{data_dir.joinpath('flights.csv')}"

    schema = StructType([])
    df: DataFrame = spark_session.createDataFrame(
        spark_session.sparkContext.emptyRDD(), schema)

    spark_session.sql("DROP TABLE IF EXISTS default.flights")

    # Act
    parameters: Dict[str, Any] = {}

    stages: List[Transformer] = create_steps([
        FrameworkCsvLoader(view="flights", filepath=flights_path),
        FeaturesCarriersV1(parameters=parameters),
    ])

    pipeline: Pipeline = Pipeline(stages=stages)  # type: ignore
    transformer = pipeline.fit(df)
    transformer.transform(df)

    # Assert
    result_df: DataFrame = spark_session.sql("SELECT * FROM flights2")
    result_df.show()

    assert result_df.count() > 0

def test_can_run_framework_pipeline(spark_session: SparkSession) -> None:
    # Arrange
    data_dir: Path = Path(__file__).parent.joinpath("./")
    flights_path: str = f"file://{data_dir.joinpath('flights.csv')}"

    schema = StructType([])
    df: DataFrame = spark_session.createDataFrame(
        spark_session.sparkContext.emptyRDD(), schema)

    spark_session.sql("DROP TABLE IF EXISTS default.flights")

    # Act
    parameters = {"flights_path": flights_path}

    with ProgressLogger() as progress_logger:
        pipeline: MyPipeline = MyPipeline(parameters=parameters,
                                          progress_logger=progress_logger)
        transformer = pipeline.fit(df)
        transformer.transform(df)

    # Assert
    result_df: DataFrame = spark_session.sql("SELECT * FROM flights2")
    result_df.show()

    assert result_df.count() > 0

def test_fail_fast_validated_framework_pipeline_writes_results(
    spark_session: SparkSession,
) -> None:
    # Arrange
    clean_spark_session(spark_session)
    data_dir: Path = Path(__file__).parent.joinpath("./")
    flights_path: str = f"file://{data_dir.joinpath('flights.csv')}"
    output_path: str = f"file://{data_dir.joinpath('temp').joinpath('validation.csv')}"
    if path.isdir(data_dir.joinpath("temp")):
        shutil.rmtree(data_dir.joinpath("temp"))

    schema = StructType([])
    df: DataFrame = spark_session.createDataFrame(
        spark_session.sparkContext.emptyRDD(), schema)

    spark_session.sql("DROP TABLE IF EXISTS default.flights")

    # Act
    parameters = {
        "flights_path": flights_path,
        "validation_source_path": str(data_dir),
        "validation_output_path": output_path,
    }

    try:
        with ProgressLogger() as progress_logger:
            pipeline: MyFailFastValidatedPipeline = MyFailFastValidatedPipeline(
                parameters=parameters, progress_logger=progress_logger)
            transformer = pipeline.fit(df)
            transformer.transform(df)
    except AssertionError:
        validation_df = df.sql_ctx.read.csv(output_path, header=True)
        validation_df.show(truncate=False)
        assert validation_df.count() == 1

def test_simple_csv_loader_pipeline(spark_session: SparkSession) -> None:
    # Arrange
    data_dir: Path = Path(__file__).parent.joinpath('./')
    flights_path: str = f"file://{data_dir.joinpath('flights.csv')}"

    schema = StructType([])
    df: DataFrame = spark_session.createDataFrame(
        spark_session.sparkContext.emptyRDD(), schema)

    # noinspection SqlDialectInspection,SqlNoDataSourceInspection
    spark_session.sql("DROP TABLE IF EXISTS default.flights")

    # Act
    # parameters = Dict[str, Any]({
    # })

    stages: List[Union[Estimator, Transformer]] = [
        FrameworkCsvLoader(
            view="flights",
            path_to_csv=flights_path
        ),
        SQLTransformer(statement="SELECT * FROM flights"),
    ]

    pipeline: Pipeline = Pipeline(stages=stages)
    transformer = pipeline.fit(df)
    result_df: DataFrame = transformer.transform(df)

    # Assert
    result_df.show()

    assert result_df.count() > 0

def make_adjustment_for_term_frequencies(
    df_e: DataFrame,
    model: Model,
    spark: SparkSession,
    retain_adjustment_columns: bool = False,
):

    # Running a maximisation step will eliminate errors caused by global parameters
    # being used in blocked jobs

    settings = model.current_settings_obj.settings_dict

    term_freq_column_list = [
        cc.name
        for cc in model.current_settings_obj.comparison_columns_list
        if cc["term_frequency_adjustments"] is True
    ]

    if len(term_freq_column_list) == 0:
        return df_e

    retain_source_dataset_col = _retain_source_dataset_column(settings, df_e)
    df_e.createOrReplaceTempView("df_e")

    old_settings = deepcopy(model.current_settings_obj.settings_dict)

    for cc in model.current_settings_obj.comparison_columns_list:
        cc.column_dict["fix_m_probabilities"] = False
        cc.column_dict["fix_u_probabilities"] = False

    run_maximisation_step(df_e, model, spark)

    # Generate a lookup table for each column with 'term specific' lambdas.
    for c in term_freq_column_list:
        sql = sql_gen_generate_adjusted_lambda(c, model)
        logger.debug(_format_sql(sql))
        lookup = spark.sql(sql)
        lookup.persist()
        lookup.createOrReplaceTempView(f"{c}_lookup")

    # Merge these lookup tables into the main table
    sql = sql_gen_add_adjumentments_to_df_e(term_freq_column_list)
    logger.debug(_format_sql(sql))
    df_e_adj = spark.sql(sql)
    df_e_adj.createOrReplaceTempView("df_e_adj")

    sql = sql_gen_compute_final_group_membership_prob_from_adjustments(
        term_freq_column_list, settings, retain_source_dataset_col)
    logger.debug(_format_sql(sql))
    df = spark.sql(sql)

    if not retain_adjustment_columns:
        for c in term_freq_column_list:
            df = df.drop(c + "_tf_adj")

    # Restore original settings
    model.current_settings_obj.settings_dict = old_settings

    return df

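# Hedged usage sketch (not part of the original source): this assumes `df_e` is the
# output of a completed EM run and that `model` was built from a settings dictionary
# in which the relevant comparison columns set "term_frequency_adjustments": True;
# otherwise the function above simply returns df_e unchanged. Passing
# retain_adjustment_columns=True keeps the per-column "*_tf_adj" diagnostic columns
# instead of dropping them.
def apply_tf_adjustments_example(df_e, model, spark):
    df_e_adj = make_adjustment_for_term_frequencies(
        df_e, model, spark, retain_adjustment_columns=True)
    return df_e_adj
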
def test_model_codegen_registered(spark: SparkSession):
    init(spark, True)
    spark.sql(
        """CREATE MODEL foo_dynamic OPTIONS (foo="str",bar=True,max_score=1.23) USING 'test://model/a/b/c'"""
    ).count()

    init(spark, False)
    spark.sql(
        """CREATE MODEL foo_static OPTIONS (foo="str",bar=True,max_score=1.23) USING 'test://model/a/b/c'"""
    ).count()

def test_can_load_non_standard_delimited_csv(
        spark_session: SparkSession) -> None:
    # Arrange
    clean_spark_session(spark_session)
    data_dir: Path = Path(__file__).parent.joinpath("./")
    test_file_path: str = f"{data_dir.joinpath('test.psv')}"

    schema = StructType([])
    df: DataFrame = spark_session.createDataFrame(
        spark_session.sparkContext.emptyRDD(), schema)

    # Act
    loader = FrameworkCsvLoader(view="my_view",
                                filepath=test_file_path,
                                delimiter="|")
    loader.transform(df)

    # noinspection SqlDialectInspection
    result: DataFrame = spark_session.sql("SELECT * FROM my_view")
    result.show()

    # Assert
    assert loader.getDelimiter() == "|"
    assert_results(result)

def test_correctly_loads_csv_with_clean_flag_on(
        spark_session: SparkSession) -> None:
    # Arrange
    clean_spark_session(spark_session)
    data_dir: Path = Path(__file__).parent.joinpath("./")
    test_file_path: str = f"{data_dir.joinpath('column_name_test.csv')}"

    schema = StructType([])
    df: DataFrame = spark_session.createDataFrame(
        spark_session.sparkContext.emptyRDD(), schema)

    # Act
    FrameworkCsvLoader(
        view="my_view",
        filepath=test_file_path,
        delimiter=",",
        clean_column_names=True,
    ).transform(df)

    # noinspection SqlDialectInspection
    result: DataFrame = spark_session.sql("SELECT * FROM my_view")

    # Assert
    assert_results(result)
    assert result.collect()[1][0] == "2"
    assert (result.columns[2] ==
            "Ugly_column_with_chars_that_parquet_does_not_like_much_-")

def test_can_keep_columns(spark_session: SparkSession) -> None:
    # Arrange
    clean_spark_session(spark_session)
    data_dir: Path = Path(__file__).parent.joinpath("./")
    test_file_path: str = f"{data_dir.joinpath('test.csv')}"

    schema = StructType([])
    df: DataFrame = spark_session.createDataFrame(
        spark_session.sparkContext.emptyRDD(), schema)

    # Act
    FrameworkCsvLoader(view="my_view",
                       filepath=test_file_path,
                       delimiter=",").transform(df)

    FrameworkSelectColumnsTransformer(view="my_view",
                                      keep_columns=["Column2"]).transform(df)

    # noinspection SqlDialectInspection
    result: DataFrame = spark_session.sql("SELECT * FROM my_view")

    result.show()

    # Assert
    assert len(result.columns) == 1
    assert result.count() == 3

    assert result.collect()[1][0] == "bar"

def block_using_rules(settings: dict, df: DataFrame, spark: SparkSession):
    """Apply a series of blocking rules to create a dataframe of record comparisons.

    If no blocking rules are provided, performs a cartesian join.

    Args:
        settings (dict): A splink settings dictionary
        df (DataFrame): Spark dataframe to block - if linking multiple datasets,
            assumes dataframes have already been vertically concatenated
        spark (SparkSession): The pyspark.sql.session.SparkSession

    Returns:
        pyspark.sql.dataframe.DataFrame: A dataframe of each record comparison
    """

    df.createOrReplaceTempView("df")
    columns_to_retain = _get_columns_to_retain_blocking(settings, df)
    unique_id_col = settings["unique_id_column_name"]

    if settings["link_type"] == "dedupe_only":
        source_dataset_col = None
    else:
        source_dataset_col = settings["source_dataset_column_name"]
    link_type = settings["link_type"]

    if "blocking_rules" not in settings or len(settings["blocking_rules"]) == 0:
        sql = _sql_gen_cartesian_block(link_type, columns_to_retain,
                                       unique_id_col, source_dataset_col)
    else:
        rules = settings["blocking_rules"]
        sql = _sql_gen_block_using_rules(link_type, columns_to_retain, rules,
                                         unique_id_col, source_dataset_col)

    logger.debug(_format_sql(sql))

    df_comparison = spark.sql(sql)

    return df_comparison

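# Hedged usage sketch (not from the original source): an illustrative dedupe-only
# call of the function above. The keys shown ("link_type", "unique_id_column_name",
# "blocking_rules") are the ones this function reads; the column names and rules are
# made up for illustration, and a real splink settings dictionary will normally
# contain further keys as well.
def dedupe_block_example(df, spark):
    settings = {
        "link_type": "dedupe_only",
        "unique_id_column_name": "unique_id",
        "blocking_rules": [
            "l.surname = r.surname",    # candidate pairs must share a surname
            "l.postcode = r.postcode",  # ...or share a postcode
        ],
    }
    # Each row of the result is one candidate record comparison.
    return block_using_rules(settings, df, spark)
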
def run_maximisation_step(df_e: DataFrame, model: Model, spark: SparkSession):
    """Compute new parameters and save them in the model object

    Note that the model object will be updated in-place by this function

    Args:
        df_e (DataFrame): the result of the expectation step
        model (Model): splink Model object
        spark (SparkSession): The spark session
    """

    sql = _sql_gen_intermediate_pi_aggregate(model)

    df_e.createOrReplaceTempView("df_e")
    df_intermediate = spark.sql(sql)
    logger.debug(_format_sql(sql))
    df_intermediate.createOrReplaceTempView("df_intermediate")
    df_intermediate.persist()

    new_lambda = _get_new_lambda(df_intermediate, spark)
    pi_df_collected = _get_new_pi_df(df_intermediate, spark, model)

    model._populate_model_from_maximisation_step(new_lambda, pi_df_collected)
    model.iteration += 1
    df_intermediate.unpersist()

def test_can_load_xml_file_with_schema(spark_session: SparkSession) -> None:
    # Arrange
    clean_spark_session(spark_session)
    data_dir: Path = Path(__file__).parent.joinpath("./")
    test_file_path: str = f"{data_dir.joinpath('test.xml')}"

    schema = StructType([])
    df: DataFrame = spark_session.createDataFrame(
        spark_session.sparkContext.emptyRDD(), schema)

    xml_schema = StructType([
        StructField("_id", StringType(), True),
        StructField("author", StringType(), True),
        StructField("description", StringType(), True),
        StructField("genre", StringType(), True),
        StructField("price", DoubleType(), True),
        StructField("publish_date", StringType(), True),
        StructField("title", StringType(), True),
    ])

    # Act
    FrameworkXmlLoader(view="my_view",
                       filepath=test_file_path,
                       row_tag="book",
                       schema=xml_schema).transform(df)

    result: DataFrame = spark_session.sql("SELECT * FROM my_view")
    result.show()
    assert result.count() == 12
    assert len(result.columns) == 7

def add_gammas(
    df_comparison: DataFrame,
    settings_dict: dict,
    spark: SparkSession,
    unique_id_col: str = "unique_id",
):
    """Compute the comparison vectors and add them to the dataframe.

    See https://imai.fas.harvard.edu/research/files/linkage.pdf for more
    details of what is meant by comparison vectors.

    Args:
        df_comparison (spark dataframe): A Spark dataframe containing record
            comparisons, with records compared using the convention
            col_name_l, col_name_r
        settings_dict (dict): The `splink` settings dictionary
        spark (Spark session): The Spark session object
        unique_id_col (str, optional): Name of the unique id column.
            Defaults to "unique_id".

    Returns:
        Spark dataframe: A dataframe containing new columns representing the
        gammas of the model
    """

    settings_dict = complete_settings_dict(settings_dict, spark)

    sql = _sql_gen_add_gammas(
        settings_dict,
        unique_id_col=unique_id_col,
    )

    logger.debug(_format_sql(sql))
    df_comparison.createOrReplaceTempView("df_comparison")
    df_gammas = spark.sql(sql)

    return df_gammas

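# Hedged pipeline sketch (not from the original source): shows how the output of a
# blocking step is typically fed into add_gammas. The settings shape follows what is
# used elsewhere in this file ("comparison_columns" entries with a "col_name"); the
# column names are illustrative, and complete_settings_dict may require or fill in
# further keys for a real splink settings dictionary.
def comparison_vectors_example(df_comparison, spark):
    settings_dict = {
        "link_type": "dedupe_only",
        "comparison_columns": [
            {"col_name": "first_name"},
            {"col_name": "surname"},
            {"col_name": "dob"},
        ],
    }
    # df_comparison holds col_name_l / col_name_r pairs from blocking;
    # add_gammas appends the comparison-vector (gamma) columns to it.
    df_gammas = add_gammas(df_comparison, settings_dict, spark)
    return df_gammas
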
def test_can_load_parquet(spark_session: SparkSession):
    # Arrange
    SparkTestHelper.clear_tables(spark_session)

    data_dir: Path = Path(__file__).parent.joinpath('./')
    test_file_path: str = f"{data_dir.joinpath('test.csv')}"

    if path.isdir(data_dir.joinpath('temp')):
        shutil.rmtree(data_dir.joinpath('temp'))

    schema = StructType([])
    df: DataFrame = spark_session.createDataFrame(
        spark_session.sparkContext.emptyRDD(), schema)

    parquet_file_path: str = ParquetHelper.create_parquet_from_csv(
        spark_session=spark_session, file_path=test_file_path)

    # Act
    FrameworkParquetLoader(view="my_view",
                           file_path=parquet_file_path).transform(df)

    # noinspection SqlDialectInspection
    result: DataFrame = spark_session.sql("SELECT * FROM my_view")

    result.show()

    # Assert
    assert result.count() == 3

    assert result.collect()[1][0] == 2
    assert result.collect()[1][1] == "bar"
    assert result.collect()[1][2] == "bar2"

def run_maximisation_step(df_e: DataFrame, params: Params, spark: SparkSession):
    """Compute new parameters and save them in the params object

    Note that the params object will be updated in-place by this function

    Args:
        df_e (DataFrame): the result of the expectation step
        params (Params): splink Params object
        spark (SparkSession): The spark session
    """

    sql = _sql_gen_intermediate_pi_aggregate(params)

    df_e.createOrReplaceTempView("df_e")
    df_intermediate = spark.sql(sql)
    logger.debug(_format_sql(sql))
    df_intermediate.createOrReplaceTempView("df_intermediate")
    df_intermediate.persist()

    new_lambda = _get_new_lambda(df_intermediate, spark)
    pi_df_collected = _get_new_pi_df(df_intermediate, spark, params)

    params._update_params(new_lambda, pi_df_collected)
    df_intermediate.unpersist()

def make_adjustment_for_term_frequencies(
    df_e: DataFrame,
    params: Params,
    settings: dict,
    spark: SparkSession,
    retain_adjustment_columns: bool = False,
):

    df_e.createOrReplaceTempView("df_e")

    term_freq_column_list = [
        c["col_name"]
        for c in settings["comparison_columns"]
        if c["term_frequency_adjustments"] == True
    ]

    if len(term_freq_column_list) == 0:
        warnings.warn(
            "No term frequency adjustment columns are specified in your "
            "settings object. Returning original df")
        return df_e

    # Generate a lookup table for each column with 'term specific' lambdas.
    for c in term_freq_column_list:
        sql = sql_gen_generate_adjusted_lambda(c, params)
        logger.debug(_format_sql(sql))
        lookup = spark.sql(sql)
        lookup.persist()
        lookup.createOrReplaceTempView(f"{c}_lookup")

    # Merge these lookup tables into the main table
    sql = sql_gen_add_adjumentments_to_df_e(term_freq_column_list)
    logger.debug(_format_sql(sql))
    df_e_adj = spark.sql(sql)
    df_e_adj.createOrReplaceTempView("df_e_adj")

    sql = sql_gen_compute_final_group_membership_prob_from_adjustments(
        term_freq_column_list, settings)
    logger.debug(_format_sql(sql))
    df = spark.sql(sql)

    if not retain_adjustment_columns:
        for c in term_freq_column_list:
            df = df.drop(c + "_adj")

    return df

def run_expectation_step(
    df_with_gamma: DataFrame,
    model: Model,
    spark: SparkSession,
    compute_ll=False,
):
    """Run the expectation step of the EM algorithm described in the fastlink paper:
    http://imai.fas.harvard.edu/research/files/linkage.pdf

    Args:
        df_with_gamma (DataFrame): Spark dataframe with comparison vectors already populated
        model (Model): splink Model object
        spark (SparkSession): SparkSession
        compute_ll (bool, optional): Whether to compute the log likelihood.
            Degrades performance. Defaults to False.

    Returns:
        DataFrame: Spark dataframe with a match_probability column
    """

    retain_source_dataset = _retain_source_dataset_column(
        model.current_settings_obj.settings_dict, df_with_gamma)

    sql = _sql_gen_gamma_prob_columns(model, retain_source_dataset)

    df_with_gamma.createOrReplaceTempView("df_with_gamma")
    logger.debug(_format_sql(sql))
    df_with_gamma_probs = spark.sql(sql)

    # This is optional because it slows down execution
    if compute_ll:
        ll = get_overall_log_likelihood(df_with_gamma_probs, model, spark)
        message = f"Log likelihood for iteration {model.iteration - 1}: {ll}"
        logger.info(message)
        model.current_settings_obj["log_likelihood"] = ll

    sql = _sql_gen_expected_match_prob(model, retain_source_dataset)

    logger.debug(_format_sql(sql))
    df_with_gamma_probs.createOrReplaceTempView("df_with_gamma_probs")
    df_e = spark.sql(sql)

    df_e.createOrReplaceTempView("df_e")

    model.save_settings_to_iteration_history()

    return df_e

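# Hedged sketch (not from the original source) of how the two steps above are
# typically alternated: run_expectation_step scores each comparison with the current
# parameters, and run_maximisation_step (the Model-based version defined earlier in
# this file) re-estimates the parameters from those scores, updating `model` in
# place. The fixed iteration count is illustrative; a real driver would normally
# also check a convergence criterion.
def run_em_example(df_with_gamma, model, spark, num_iterations=5):
    df_e = None
    for i in range(num_iterations):
        # Only compute the log likelihood on the last pass, since it slows execution.
        compute_ll = i == num_iterations - 1
        df_e = run_expectation_step(df_with_gamma, model, spark, compute_ll=compute_ll)
        run_maximisation_step(df_e, model, spark)
    return df_e
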
def test_can_load_fixed_width(spark_session: SparkSession) -> None:
    # Arrange
    clean_spark_session(spark_session)
    data_dir: Path = Path(__file__).parent.joinpath("./")
    test_file_path: str = f"{data_dir.joinpath('test.txt')}"

    schema = StructType([])
    df: DataFrame = spark_session.createDataFrame(
        spark_session.sparkContext.emptyRDD(), schema)

    # Act
    FrameworkFixedWidthLoader(
        view="my_view",
        filepath=test_file_path,
        columns=[
            ColumnSpec(column_name="id",
                       start_pos=1,
                       length=3,
                       data_type=StringType()),
            ColumnSpec(column_name="some_date",
                       start_pos=4,
                       length=8,
                       data_type=StringType()),
            ColumnSpec(
                column_name="some_string",
                start_pos=12,
                length=3,
                data_type=StringType(),
            ),
            ColumnSpec(
                column_name="some_integer",
                start_pos=15,
                length=4,
                data_type=IntegerType(),
            ),
        ],
    ).transform(df)

    # noinspection SqlDialectInspection
    result: DataFrame = spark_session.sql("SELECT * FROM my_view")
    result.show()

    # Assert
    assert result.count() == 2
    assert result.collect()[0][0] == "001"
    assert result.collect()[1][0] == "002"
    assert result.collect()[0][1] == "01292017"
    assert result.collect()[1][1] == "01302017"
    assert result.collect()[0][2] == "you"
    assert result.collect()[1][2] == "me"
    assert result.collect()[0][3] == 1234
    assert result.collect()[1][3] == 5678

class TweetsSpark:
    def __init__(self):
        self.tweets = None
        self.tweets_table = None
        sc = SparkContext.getOrCreate()
        self.spark = SparkSession(sc)

    # Load selected data from the file into an attribute.
    def set_tweets_data(self, path):
        # Read the data and split it into columns.
        self.tweets = self.spark.read.csv(path, header=True, escape='\"')
        self.tweets.createOrReplaceTempView("tweets")

    def set_tweets_foreign(self):
        sql = """\
            CREATE TEMP VIEW tweets_foreign AS
            SELECT userid, CAST(reply_count AS INT), account_language
            FROM tweets
            WHERE account_language != 'ru'
        """
        self.spark.sql(sql)

    def get_userid_max_rc(self):
        sql = """\
            SELECT userid
            FROM tweets_foreign
            WHERE reply_count = (
                SELECT MAX(reply_count)
                FROM tweets_foreign
            )
            LIMIT 1
        """
        # Result: userid 4224729994
        return self.spark.sql(sql).show()

    def get_sorted_data(self):
        sql = """\
            SELECT *
            FROM tweets_foreign
            ORDER BY reply_count DESC
        """
        return self.spark.sql(sql).show()

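# Hedged usage sketch (not part of the original source): the CSV path below is
# hypothetical. The sequence mirrors how the class is written: load the raw tweets,
# build the tweets_foreign view, then run the two queries.
if __name__ == "__main__":
    tweets = TweetsSpark()
    tweets.set_tweets_data("data/tweets.csv")   # hypothetical path
    tweets.set_tweets_foreign()
    tweets.get_userid_max_rc()   # prints the userid with the highest reply_count
    tweets.get_sorted_data()     # prints foreign tweets sorted by reply_count
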
def block_using_rules(
    settings: dict,
    spark: SparkSession,
    df_l: DataFrame = None,
    df_r: DataFrame = None,
    df: DataFrame = None,
):
    """Apply a series of blocking rules to create a dataframe of record comparisons.

    If no blocking rules are provided, performs a cartesian join.

    Args:
        settings (dict): A splink settings dictionary
        spark (SparkSession): The pyspark.sql.session.SparkSession
        df_l (DataFrame, optional): Where `link_type` is `link_only` or
            `link_and_dedupe`, one of the two dataframes to link. Should be
            omitted if `link_type` is `dedupe_only`.
        df_r (DataFrame, optional): Where `link_type` is `link_only` or
            `link_and_dedupe`, one of the two dataframes to link. Should be
            omitted if `link_type` is `dedupe_only`.
        df (DataFrame, optional): Where `link_type` is `dedupe_only`, the
            dataframe to dedupe. Should be omitted if `link_type` is
            `link_only` or `link_and_dedupe`.

    Returns:
        pyspark.sql.dataframe.DataFrame: A dataframe of each record comparison
    """

    if "blocking_rules" not in settings or len(settings["blocking_rules"]) == 0:
        return cartesian_block(settings, spark, df_l, df_r, df)

    link_type = settings["link_type"]
    columns_to_retain = _get_columns_to_retain_blocking(settings)
    unique_id_col = settings["unique_id_column_name"]

    if link_type == "dedupe_only":
        df.createOrReplaceTempView("df")

    if link_type == "link_only":
        df_l.createOrReplaceTempView("df_l")
        df_r.createOrReplaceTempView("df_r")

    if link_type == "link_and_dedupe":
        df_concat = _vertically_concatenate_datasets(df_l, df_r, settings,
                                                     spark=spark)
        columns_to_retain.append("_source_table")
        df_concat.createOrReplaceTempView("df")
        df_concat.persist()

    rules = settings["blocking_rules"]

    sql = _sql_gen_block_using_rules(link_type, columns_to_retain, rules,
                                     unique_id_col)

    logger.debug(_format_sql(sql))

    df_comparison = spark.sql(sql)

    if link_type == "link_and_dedupe":
        df_concat.unpersist()

    return df_comparison

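# Hedged usage sketch (not from the original source): a link_only call following the
# docstring above, passing the two dataframes to link as df_l and df_r. The settings
# keys and blocking rule are illustrative only; a real splink settings dictionary
# will normally contain further keys.
def link_only_block_example(df_left, df_right, spark):
    settings = {
        "link_type": "link_only",
        "unique_id_column_name": "unique_id",
        "blocking_rules": ["l.postcode = r.postcode"],
    }
    return block_using_rules(settings, spark, df_l=df_left, df_r=df_right)
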
def run_expectation_step(df_with_gamma: DataFrame,
                         params: Params,
                         settings: dict,
                         spark: SparkSession,
                         compute_ll=False):
    """Run the expectation step of the EM algorithm described in the fastlink paper:
    http://imai.fas.harvard.edu/research/files/linkage.pdf

    Args:
        df_with_gamma (DataFrame): Spark dataframe with comparison vectors already populated
        params (Params): splink params object
        settings (dict): splink settings dictionary
        spark (SparkSession): SparkSession
        compute_ll (bool, optional): Whether to compute the log likelihood.
            Degrades performance. Defaults to False.

    Returns:
        DataFrame: Spark dataframe with a match_probability column
    """

    sql = _sql_gen_gamma_prob_columns(params, settings)

    df_with_gamma.createOrReplaceTempView("df_with_gamma")
    logger.debug(_format_sql(sql))
    df_with_gamma_probs = spark.sql(sql)

    # This is optional because it slows down execution
    if compute_ll:
        ll = get_overall_log_likelihood(df_with_gamma_probs, params, spark)
        message = f"Log likelihood for iteration {params.iteration - 1}: {ll}"
        logger.info(message)
        params.params["log_likelihood"] = ll

    sql = _sql_gen_expected_match_prob(params, settings)

    logger.debug(_format_sql(sql))
    df_with_gamma_probs.createOrReplaceTempView("df_with_gamma_probs")
    df_e = spark.sql(sql)

    df_e.createOrReplaceTempView("df_e")
    return df_e

def df_e_with_truth_categories(
    df_labels_with_splink_scores,
    threshold_pred,
    spark: SparkSession,
    threshold_actual: float = 0.5,
    score_colname: str = None,
):
    """Join Splink's predictions to clerically labelled data and categorise
    rows by truth category (false positive, true positive etc.)

    Note that df_labels

    Args:
        df_labels_with_splink_scores (DataFrame): A dataframe of labels and
            associated splink scores, usually the output of the
            truth.labels_with_splink_scores function
        threshold_pred (float): Threshold to use in categorising Splink
            predictions into match or no match
        spark (SparkSession): SparkSession object
        threshold_actual (float, optional): Threshold to use in categorising
            clerical match scores into match or no match. Defaults to 0.5.
        score_colname (str, optional): Allows the user to explicitly state the
            column name in the Splink dataset containing the Splink score.
            If None, it will be inferred.

    Returns:
        DataFrame: Dataframe of labels associated with truth category
    """

    df_labels_with_splink_scores.createOrReplaceTempView(
        "df_labels_with_splink_scores")

    score_colname = _get_score_colname(df_labels_with_splink_scores)

    pred = f"({score_colname} >= {threshold_pred})"
    actual = f"(clerical_match_score >= {threshold_actual})"

    sql = f"""
    select
    *,
    cast ({threshold_pred} as float) as truth_threshold,
    {actual} = 1.0 as P,
    {actual} = 0.0 as N,
    {pred} = 1.0 and {actual} = 1.0 as TP,
    {pred} = 0.0 and {actual} = 0.0 as TN,
    {pred} = 1.0 and {actual} = 0.0 as FP,
    {pred} = 0.0 and {actual} = 1.0 as FN
    from
    df_labels_with_splink_scores
    """

    return spark.sql(sql)

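# Hedged usage sketch (not from the original source): categorise labelled comparisons
# at an illustrative prediction threshold and count each truth category. The
# TP/TN/FP/FN boolean columns are the ones produced by the SQL above.
from pyspark.sql import functions as F

def truth_category_counts_example(df_labels_with_splink_scores, spark):
    df_truth = df_e_with_truth_categories(
        df_labels_with_splink_scores, threshold_pred=0.9, spark=spark)
    return df_truth.select(
        F.sum(F.col("TP").cast("int")).alias("true_positives"),
        F.sum(F.col("TN").cast("int")).alias("true_negatives"),
        F.sum(F.col("FP").cast("int")).alias("false_positives"),
        F.sum(F.col("FN").cast("int")).alias("false_negatives"),
    )
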
def null_out_entries_with_freq_above_n(df: DataFrame, colname: str, n: int,
                                       spark: SparkSession):
    """Null out values above a certain frequency threshold

    Useful for columns that mostly contain valid data but occasionally
    contain other values such as 'unknown'

    Args:
        df (DataFrame): The dataframe to clean
        colname (string): The name of the column to clean
        n (int): The maximum frequency allowed. Any values with a frequency
            higher than n will be nulled out
        spark (SparkSession): The spark session

    Returns:
        DataFrame: The cleaned dataframe with incoming column overwritten
    """

    # Possible that a window function would be better than the following approach,
    # but I think both require a shuffle so possibly doesn't make much difference

    df.createOrReplaceTempView("df")

    sql = f"""
    select {colname} as count
    from df
    group by {colname}
    having count(*) > {n}
    """

    df_groups = spark.sql(sql)
    collected = df_groups.collect()
    values_to_null = [row["count"] for row in collected]

    if len(values_to_null) == 0:
        return df

    values_to_null = [f'"{v}"' for v in values_to_null]
    values_to_null_joined = ", ".join(values_to_null)

    case_statement = f"""
    CASE
    WHEN {colname} in ({values_to_null_joined}) THEN NULL
    ELSE {colname}
    END
    """

    df = df.withColumn(colname, f.expr(case_statement))

    return df

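# Hedged usage sketch (not from the original source): the column name and threshold
# are illustrative. Any value of `occupation` that appears more than 10,000 times
# (e.g. a catch-all like 'unknown') is replaced with NULL, leaving the rest of the
# column untouched.
def clean_occupation_example(df, spark):
    return null_out_entries_with_freq_above_n(df, colname="occupation", n=10000,
                                              spark=spark)
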
def execute(spark: SparkSession, log: logging.Logger, config: dict):
    log.info("extract")
    params = config['params']
    ps_conf = config['postgres']

    ts: datetime.datetime = params['ts']
    in_path = ts.strftime(params['in_path'])
    ts_from = config['ts_from']
    ts_to = config['ts_to']

    df = spark.read.csv(in_path, header=True, sep=';')
    df = df.select(
        F.col('FROM_PHONE_NUMBER'),
        F.col('TO_PHONE_NUMBER'),
        F.to_timestamp(df['START_TIME'], 'dd/MM/yyyy HH:mm:ss').alias('START_TIME'),
        F.col('CALL_DURATION').cast('long'),
        F.col('IMEI'),
        F.col('LOCATION')
    ).withColumn("TS", F.date_format(F.date_trunc("hour", "START_TIME"),
                                     "yyyy-MM-dd-HH"))
    df.write.partitionBy("TS").mode('append').format('hive').saveAsTable('task_02')

    df = spark.sql(
        "select * from task_02 where TS >= '{}' AND TS < '{}'".format(ts_from, ts_to)
    ).drop_duplicates()
    df.cache()
    ts = df.select("TS").rdd.map(lambda x: x[0]).first()

    # Number of calls and total call duration.
    num_call = df.count()
    total_call_duration = list(
        df.select(F.sum(df['CALL_DURATION'])).first().asDict().values())[0]

    # Number of calls during working hours (8am to 5pm).
    num_call_working_hour = df.filter(
        "hour(START_TIME) >= 8 AND hour(START_TIME) <= 17").count()

    # Find the IMEI which makes the most calls.
    imei_most = df.groupBy('IMEI').count().sort(F.col("count").desc()).first().asDict()

    # Find the top 2 locations which make the most calls.
    locations = list(map(
        lambda x: x.asDict(),
        df.groupBy('LOCATION').count().sort(F.col("count").desc()).head(2)))

    rs = (ts, num_call, total_call_duration, num_call_working_hour,
          imei_most, locations)
    with get_postgres_cli(ps_conf) as ps_cli:
        with ps_cli.cursor() as cur:
            sql = """
                INSERT INTO metric_hour(
                    ts, num_call, total_call_duration,
                    num_call_working_hour, imei_most, locations
                )
                VALUES(%s, %s, %s, %s, %s, %s)
                ON CONFLICT (ts)
                DO UPDATE SET(
                    num_call, total_call_duration,
                    num_call_working_hour, imei_most, locations) = (
                    EXCLUDED.num_call, EXCLUDED.total_call_duration,
                    EXCLUDED.num_call_working_hour, EXCLUDED.imei_most,
                    EXCLUDED.locations)
            """
            cur.execute(sql, rs)

def test_can_save_csv(spark_session: SparkSession) -> None:
    # Arrange
    SparkTestHelper.clear_tables(spark_session)

    data_dir: Path = Path(__file__).parent.joinpath("./")
    test_file_path: str = f"{data_dir.joinpath('test.csv')}"

    if path.isdir(data_dir.joinpath("temp")):
        shutil.rmtree(data_dir.joinpath("temp"))

    schema = StructType([])
    df: DataFrame = spark_session.createDataFrame(
        spark_session.sparkContext.emptyRDD(), schema
    )

    FrameworkCsvLoader(
        view="my_view", filepath=test_file_path, delimiter=","
    ).transform(df)

    csv_file_path: str = f"file://{data_dir.joinpath('temp/').joinpath('test.csv')}"

    # Act
    FrameworkCsvExporter(
        view="my_view", file_path=csv_file_path, header=True, delimiter=","
    ).transform(df)

    # Assert
    FrameworkCsvLoader(
        view="my_view2", filepath=csv_file_path, delimiter=","
    ).transform(df)

    # noinspection SqlDialectInspection
    result: DataFrame = spark_session.sql("SELECT * FROM my_view2")

    result.show()

    assert result.count() == 3

    assert result.collect()[1][0] == "2"
    assert result.collect()[1][1] == "bar"
    assert result.collect()[1][2] == "bar2"

def test_can_load_xml_file(spark_session: SparkSession) -> None:
    # Arrange
    clean_spark_session(spark_session)
    data_dir: Path = Path(__file__).parent.joinpath("./")
    test_file_path: str = f"{data_dir.joinpath('test.xml')}"

    schema = StructType([])
    df: DataFrame = spark_session.createDataFrame(
        spark_session.sparkContext.emptyRDD(), schema)

    # Act
    FrameworkXmlLoader(view="my_view",
                       filepath=test_file_path,
                       row_tag="book").transform(df)

    result: DataFrame = spark_session.sql("SELECT * FROM my_view")
    result.show()
    assert result.count() == 12
    assert len(result.columns) == 7

def test_can_load_multiline_csv(spark_session: SparkSession) -> None:
    # Arrange
    clean_spark_session(spark_session)
    data_dir: Path = Path(__file__).parent.joinpath("./")
    test_file_path: str = f"{data_dir.joinpath('multiline_row.csv')}"

    schema = StructType([])
    df: DataFrame = spark_session.createDataFrame(
        spark_session.sparkContext.emptyRDD(), schema)

    # Act
    FrameworkCsvLoader(view="my_view",
                       filepath=test_file_path,
                       delimiter=",",
                       multiline=True).transform(df)

    # noinspection SqlDialectInspection
    result: DataFrame = spark_session.sql("SELECT * FROM my_view")

    assert 1 == result.count()

def clear_tables(spark_session: SparkSession) -> None:
    """
    Drop all tables and views in the default database and clear the cache.

    :param spark_session: the Spark session to clean up
    """
    # spark_session.sql("SET -v").show(n=200, truncate=False)
    tables = spark_session.catalog.listTables("default")

    for table in tables:
        print(f"clear_tables() dropping table/view: {table.name}")
        spark_session.sql(f"DROP TABLE IF EXISTS default.{table.name}")
        spark_session.sql(f"DROP VIEW IF EXISTS default.{table.name}")
        spark_session.sql(f"DROP VIEW IF EXISTS {table.name}")

    spark_session.catalog.clearCache()

def test_can_load_csv_without_header(spark_session: SparkSession) -> None:
    # Arrange
    clean_spark_session(spark_session)
    data_dir: Path = Path(__file__).parent.joinpath("./")
    test_file_path: str = f"{data_dir.joinpath('no_header.csv')}"

    schema = StructType([])
    df: DataFrame = spark_session.createDataFrame(
        spark_session.sparkContext.emptyRDD(), schema)

    # Act
    FrameworkCsvLoader(view="another_view",
                       filepath=test_file_path,
                       delimiter=",",
                       has_header=False).transform(df)

    # noinspection SqlDialectInspection
    result: DataFrame = spark_session.sql("SELECT * FROM another_view")

    # Assert
    assert_results(result)