Example #1
def test_from_db_column(spark: SparkSession, df_from_db_left: DataFrame,
                        df_from_db_right: DataFrame) -> None:
    """[Compare two dataframe sourced from database. One side has NULLS/BLANKs in database. Expect no differences
    One side has a column not on right side and vice versa]

    Args:
        spark (SparkSession): [Spark session]
        df_from_db_left (DataFrame): [Spark dataframe source from database.]
        df_from_db_right (DataFrame): [Spark dataframe source from database]
    """
    from pyspark.sql.functions import lit
    df_from_db_left_mut = df_from_db_left.withColumn("LEFT_COL",
                                                     lit("left_col_value"))
    df_from_db_right_mut = df_from_db_right.withColumn("RIGHT_COL",
                                                       lit("right_col_value"))
    dfResult = dfc.compareDfs(
        spark,
        df_from_db_left_mut,
        df_from_db_right_mut,
        tolerance=0.1,
        keysLeft="bsr",
        keysRight="bsr",
        colExcludeList=[
            "n1", "n2", "n3", "n4", "n5", "tx", "this_col_does_not_exist"
        ],
        joinType="full_outer",
    )
    pass_count = dfResult.filter("PASS == True").count()
    overall_count = dfResult.count()
    assert pass_count == overall_count
Example #2
def null_out_values_array(df: DataFrame, array_colname: str, values_to_null: list):
    """Null out a user-defined list of undesirable values in a column that contains an array of values.
    Useful for columns that mostly contain valid data but occasionally
    contain other values such as 'unknown'.

    Args:
        df (DataFrame): The dataframe to clean
        array_colname (string): The name of the array column to clean
        values_to_null (list): A list of values to be nulled.
    Returns:
        DataFrame: The cleaned dataframe, with any values in values_to_null removed from the array column
    """
    if len(values_to_null) > 0:

        if str((dict(df.dtypes)[array_colname])).startswith("array"):

            array_args = [f.lit(v) for v in values_to_null]
            df = df.withColumn("vals_to_remove", f.array(*array_args))
            df = df.withColumn(
                array_colname, f.expr(f"array_except({array_colname}, vals_to_remove)")
            )
            df = df.drop("vals_to_remove")

        else:
            # if the column is not an array, emit a warning
            warnings.warn(
                f"column {array_colname} is not an array. Please use null_out_values instead"
            )

    return df
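# A minimal usage sketch, assuming pyspark.sql.functions is imported as `f` and
# `warnings` is imported at module level, as null_out_values_array expects:
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame(
    [(1, ["red", "unknown", "blue"]), (2, ["unknown"])],
    ["id", "colours"],
)
null_out_values_array(df, "colours", ["unknown", "n/a"]).show(truncate=False)
# Row 1 keeps ['red', 'blue']; row 2 is left with an empty array.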
Example #3
def standardise_names(df: DataFrame, name_cols: list, drop_orig: bool = True):
    """Take a one or more name columns in a list and standardise the names
    so one name appears in each column consistently

    Args:
        df (DataFrame): Spark DataFrame
        name_cols (list): A list of columns that contain names, in order from first name to last name
        drop_orig (bool, optional): Drop the original columns after standardisation. Defaults to True.

    Returns:
        DataFrame: A Spark DataFrame with standardised name columns
    """

    name_col_joined = ", ".join(name_cols)
    surname_col_name = name_cols[-1]
    df = df.withColumn('name_concat',
                       expr(f"concat_ws(' ', {name_col_joined})"))
    df = df.withColumn('name_concat', expr('lower(name_concat)'))
    df = df.withColumn('name_concat',
                       expr("regexp_replace(name_concat, '[\\-\\.]', ' ')"))
    df = df.withColumn('name_arr', expr("split(name_concat, ' ')"))
    df = df.withColumn(
        'surname_std',
        expr(
            f"case when {surname_col_name} is not null then element_at(name_arr,-1) else null end"
        ))
    df = df.withColumn(
        'forename1_std',
        expr(
            "case when size(name_arr) > 1 then element_at(name_arr,1) else null end"
        ))
    df = df.withColumn(
        'forename2_std',
        expr(
            "case when size(name_arr) > 2 then element_at(name_arr,2) else null end"
        ))
    df = df.withColumn(
        'forename3_std',
        expr(
            "case when size(name_arr) > 3 then element_at(name_arr,3) else null end"
        ))
    df = df.withColumn(
        'forename4_std',
        expr(
            "case when size(name_arr) > 4 then element_at(name_arr,4) else null end"
        ))
    df = df.withColumn(
        'forename5_std',
        expr(
            "case when size(name_arr) > 5 then element_at(name_arr,5) else null end"
        ))
    df = df.drop("name_arr", "name_concat")
    if drop_orig:
        for n in name_cols:
            df = df.drop(n)
    return df
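# A hedged usage sketch, assuming `expr` from pyspark.sql.functions is in scope
# for standardise_names:
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([("John-Paul", "Smith")], ["first_name", "last_name"])
standardise_names(df, ["first_name", "last_name"]).show()
# The hyphen becomes a space, so the name splits into forename1_std='john',
# forename2_std='paul' and surname_std='smith'; the original columns are dropped.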
Example #4
def convert_types_for_kafka(df: DataFrame) -> DataFrame:
    to_array = udf(lambda v: v.toArray().tolist(), ArrayType(FloatType()))
    to_value = udf(
        lambda radiant_lineup, dire_lineup, radiant_win_prediction, probability_arr:
        f'{{"radiant_lineup": {radiant_lineup}, "dire_lineup": {dire_lineup}, "radiant_win_prediction": {"true" if radiant_win_prediction else "false"}, "probability": {probability_arr}}}',
        StringType())

    df = df.withColumn("radiant_win_prediction", df.prediction.cast(BooleanType())) \
             .withColumn("probability_arr", to_array(df.probability))
    return df.withColumn("value", to_value(df.radiant_lineup, df.dire_lineup, df.radiant_win_prediction, df.probability_arr))
Example #5
    def _create_partitions(self, dataframe: DataFrame) -> DataFrame:
        # create year partition column
        dataframe = dataframe.withColumn(
            columns.PARTITION_YEAR, year(dataframe[columns.TIMESTAMP_COLUMN]))
        # create month partition column
        dataframe = dataframe.withColumn(
            columns.PARTITION_MONTH,
            month(dataframe[columns.TIMESTAMP_COLUMN]))
        # create day partition column
        dataframe = dataframe.withColumn(
            columns.PARTITION_DAY,
            dayofmonth(dataframe[columns.TIMESTAMP_COLUMN]))
        return repartition_df(dataframe, self.PARTITION_BY,
                              self.num_partitions)
Example #6
def validateXpath(df: DataFrame) -> DataFrame:
    def isXpathValid(xpath) -> str:
        # returns the strings 'true'/'false' rather than a boolean
        try:
            lxml.etree.XPath(str(xpath).strip())
            return 'true'
        except Exception:
            return 'false'

    validate_xpath_udf = udf(isXpathValid, StringType())
    df = df.withColumn('xpathvalidation', validate_xpath_udf('XPATH'))
    df = df.withColumn(
        'xpathleafnode',
        split(col('xpath'), '/')[size(split(col('xpath'), '/')) - 1])
    return df
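# A minimal usage sketch, assuming lxml and the PySpark imports the function relies on
# (udf, StringType, split, col, size) are available at module level:
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([("/bookstore/book/title",), ("///not[a(valid",)], ["xpath"])
validateXpath(df).show(truncate=False)
# xpathvalidation is 'true' for the first row and 'false' for the second;
# xpathleafnode holds the last path segment, e.g. 'title'.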
Example #7
def prepare_model(df: DataFrame) -> tuple:
    """[prepare pyspark model for logistic regression]

    Args:
        df ([DataFrame]): [spark dataframe]

    Returns:
        [tuple]: [tuple of vector assembler and lr model for predictions]
    """
    #label_indexer = StringIndexer(inputCol='class', outputCol='class_index')
    df = df.withColumn(
        'class_index',
        when(col('class') == 'Iris-setosa',
             0).when(col('class') == 'Iris-versicolor', 1).otherwise(2))
    vector_assembler = VectorAssembler(inputCols=[
        'sepal_length', 'sepal_width', 'petal_length', 'petal_width'
    ],
                                       outputCol='features')
    # stages = [vector_assembler]
    # pipeline = Pipeline(stages=stages)
    # pipelineModel = pipeline.fit(df)
    #transformed_df = pipelineModel.transform(df)
    transformed_df = vector_assembler.transform(df)
    selectedCols = ['features', 'class_index']
    final_df = transformed_df.select(selectedCols)
    log_reg = LogisticRegression(featuresCol='features',
                                 labelCol='class_index',
                                 regParam=100000)
    lrModel = log_reg.fit(final_df)
    return (vector_assembler, lrModel)
Example #8
    def _struct_df(self, df: DataFrame) -> DataFrame:
        """Struct the output dataframe generated by the reader.

        Under the default "value" column coming from Kafka there are the custom
        fields created by some producer. This function will struct the dataframe as
        to get all desired fields from "value" and insert all Kafka default columns,
        including "value", under "kafka_metadata" column. It is important to notice
        that the declared value_schema suffers from the same effects described in
        the explode_json_column method in the pre_processing module.

        Args:
            df: direct dataframe output from the KafkaReader.

        Returns:
            Structured dataframe with kafka value fields as columns.
                All other default fields from Kafka will be stored under
                "kafka_metadata" column.

        """
        df = df.withColumn("kafka_metadata", struct(*self.KAFKA_COLUMNS))
        df = explode_json_column(df,
                                 column="value",
                                 json_schema=self.value_schema)
        return df.select([field.name
                          for field in self.value_schema] + ["kafka_metadata"])
Example #9
    def _transform(self, dataset: DataFrame) -> DataFrame:
        self.transformSchema(dataset.schema)
        transformUDF = udf(self.createTransformFunc(), self.outputDataType())
        transformedDataset = dataset.withColumn(
            self.getOutputCol(), transformUDF(dataset[self.getInputCol()])
        )
        return transformedDataset
Example #10
def clean_immigration(df: SparkDataFrame) -> SparkDataFrame:
    """Clean immigration data

    :param df: immigration data frame to be cleaned.
    :return: cleaned immigration data frame
    """

    drop_cols = [
        'visapost', 'occup', 'entdepu', 'insnum', 'count', 'entdepa',
        'entdepd', 'matflag', 'dtaddto', 'biryear', 'admnum'
    ]
    int_cols = [
        'cicid', 'i94yr', 'i94mon', 'i94cit', 'i94res', 'i94mode', 'i94bir',
        'i94visa', 'dtadfile'
    ]
    date_cols = ['arrdate', 'depdate']
    date_udf = udf(lambda x: x and (timedelta(days=int(x)) + datetime(
        1960, 1, 1)).strftime('%Y-%m-%d'))

    df = df.drop(*drop_cols)
    df = convert_column_type(df, 'integer', int_cols)
    for col in date_cols:
        df = df.withColumn(col, date_udf(df[col]))

    # Remove rows where any of the foreign-key columns is null
    fk_columns = ['i94cit', 'i94port', 'i94addr']
    df = reduce(lambda df, idx: df.filter(df[fk_columns[idx]].isNotNull()),
                range(len(fk_columns)), df)

    return df
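# The arrival/departure columns appear to hold SAS-style numeric dates (days since
# 1960-01-01); a plain-Python sketch of what the date_udf lambda computes:
from datetime import datetime, timedelta

sas_to_iso = lambda x: x and (timedelta(days=int(x)) + datetime(1960, 1, 1)).strftime('%Y-%m-%d')
print(sas_to_iso(1))     # '1960-01-02'
print(sas_to_iso(None))  # None - falsy inputs pass through unconverted
print(sas_to_iso(0))     # 0 - also falsy, so it is returned as-is rather than '1960-01-01'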
Example #11
def clean_countries(df: SparkDataFrame) -> SparkDataFrame:
    """Clean countries data

    :param df: countries data frame to be cleaned.
    :return: cleaned countries data frame
    """

    df = convert_column_type(df, 'integer', ['code'])

    # Rename countries to match the names in the demographics data for further operations.
    name_to_change = [
        ('MEXICO Air Sea, and Not Reported (I-94, no land arrivals)',
         'MEXICO'), ('BOSNIA-HERZEGOVINA', 'BOSNIA AND HERZEGOVINA'),
        ('INVALID: CANADA', 'CANADA'), ('CHINA, PRC', 'CHINA'),
        ('GUINEA-BISSAU', 'GUINEA BISSAU'),
        ('INVALID: PUERTO RICO', 'PUERTO RICO'),
        ('INVALID: UNITED STATES', 'UNITED STATES')
    ]
    df = reduce(
        lambda df, idx: df.withColumn(
            'name',
            when(df['name'] == name_to_change[idx][0], name_to_change[idx][1]).
            otherwise(df['name'])), range(len(name_to_change)), df)

    return df
Example #12
    def _unpack_struct(self, df: DataFrame, col_name):
        sub_df = df.select(col_name + '.*')
        for subcol_name in sub_df.columns:
            df = df.withColumn(f'{col_name}_{subcol_name}',
                               df[col_name][subcol_name])
        df = df.drop(col_name)
        return self.unpack_nested(df)
Example #13
def null_out_values(df: DataFrame, colname: str, values_to_null):
    """Null out a list of undesirable values in a column
    Useful for columns that mostly contain valid data but occasionally
    contain other values such as 'unknown'
    Args:
        df (DataFrame): The dataframe to clean
        colname (string): The name of the column to clean
        values_to_null: A list of values to be nulled.

    Returns:
        DataFrame: The cleaned dataframe with incoming column overwritten
    """

    if len(values_to_null) == 0:
        return df

    values_to_null_string = [f'"{v}"' for v in values_to_null]
    values_to_null_joined = ", ".join(values_to_null_string)

    case_statement = f"""
    CASE
    WHEN {colname} in ({values_to_null_joined}) THEN NULL
    ELSE {colname}
    END
    """

    df = df.withColumn(colname, f.expr(case_statement))

    return df
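# A short usage sketch, assuming pyspark.sql.functions is imported as `f`,
# as null_out_values expects:
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([(1, "red"), (2, "unknown"), (3, "n/a")], ["id", "colour"])
null_out_values(df, "colour", ["unknown", "n/a"]).show()
# 'unknown' and 'n/a' become NULL; 'red' is left unchanged.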
Example #14
    def process_batch(self, df: DataFrame, batch_id):
        window = Window.partitionBy(FieldsName.USER_ID).orderBy(
            FieldsName.ACCESS_TIME)

        new_df = df.withColumn("before_time",
                               f.lag(FieldsName.ACCESS_TIME, 1).over(window))

        items_to_time = new_df.select(
            FieldsName.USER_ID, FieldsName.ITEMS,
            (f.col(FieldsName.ACCESS_TIME) -
             f.col("before_time")).alias("browsing_time"))

        window2 = Window.partitionBy(FieldsName.USER_ID).orderBy(
            f.col("browsing_time").desc())

        final_result = items_to_time.withColumn(
            "rank",
            f.rank().over(window2)).filter("rank <= 2")

        final_result = final_result.select(FieldsName.USER_ID,
                                           FieldsName.ITEMS, "browsing_time")

        for r in final_result.collect():
            self.send(topic="python_test1",
                      key=r[FieldsName.USER_ID],
                      value={
                          "items": r[FieldsName.ITEMS],
                          "browsing_time": r["browsing_time"]
                      })
Example #15
    def __generate_target_fill(self, df: DataFrame, partition_cols: List[str],
                               ts_col: str, target_col: str) -> DataFrame:
        """
        Create columns for previous and next value for a specific target column

        :param df: input DataFrame
        :param partition_cols: partition column names
        :param ts_col: timestamp column name
        :param target_col: target column name
        """
        # forward-looking window: rows from the start of the partition up to the current row
        fwd_window = (Window.partitionBy(*partition_cols)
                      .orderBy(ts_col)
                      .rowsBetween(Window.unboundedPreceding, 0))
        # backward-looking window: same frame, but ordered by descending timestamp
        bwd_window = (Window.partitionBy(*partition_cols)
                      .orderBy(col(ts_col).desc())
                      .rowsBetween(Window.unboundedPreceding, 0))
        return (
            df
            # last non-null value at or before the current row
            .withColumn(f"previous_{target_col}",
                        last(df[target_col], ignorenulls=True).over(fwd_window))
            # handle the case where the subsequent value is null
            .withColumn(f"next_null_{target_col}",
                        last(df[target_col], ignorenulls=True).over(bwd_window))
            # strictly the next value
            .withColumn(f"next_{target_col}",
                        lead(df[target_col]).over(
                            Window.partitionBy(*partition_cols).orderBy(ts_col))))
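# A standalone, hedged sketch of the same previous/next fill pattern outside the class,
# on a toy DataFrame with illustrative column names:
from pyspark.sql import SparkSession, Window
from pyspark.sql.functions import col, last, lead

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame(
    [("a", 1, 10), ("a", 2, None), ("a", 3, 30)],
    ["key", "event_ts", "value"],
)
fwd = Window.partitionBy("key").orderBy("event_ts").rowsBetween(Window.unboundedPreceding, 0)
bwd = Window.partitionBy("key").orderBy(col("event_ts").desc()).rowsBetween(Window.unboundedPreceding, 0)
(df.withColumn("previous_value", last("value", ignorenulls=True).over(fwd))
   .withColumn("next_null_value", last("value", ignorenulls=True).over(bwd))
   .withColumn("next_value", lead("value").over(Window.partitionBy("key").orderBy("event_ts")))
   .show())
# At event_ts=2 the null is bracketed: previous_value=10, next_null_value=30, next_value=30.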
Example #16
def transform_read_centerline_data(df: DataFrame) -> DataFrame:
    """Transforming centerline data to make it joinable, below are the things steps in high level

    1. Converted ST_LABEL & FULL_STREE to upper case
    2. Converted L_LOW_HN & L_HIGH_HN  separated by '-' for odd house number
    3. Converted R_LOW_HN & R_HIGH_HN  separated by '-' for even house number
    4. Removed any data having no house number in L_LOW_HN and R_LOW_HN
    """
    df = (df.select("PHYSICALID", "BOROCODE", "FULL_STREE", "ST_NAME",
                    "L_LOW_HN", "L_HIGH_HN", "R_LOW_HN", "R_HIGH_HN").orderBy(
                        "PHYSICALID", "BOROCODE", "FULL_STREE", "ST_NAME",
                        "L_LOW_HN", "L_HIGH_HN", "R_LOW_HN",
                        "R_HIGH_HN").coalesce(200).withColumn(
                            "ST_NAME", F.upper(F.col("ST_NAME"))).withColumn(
                                "FULL_STREE",
                                F.upper(F.col("FULL_STREE"))).filter(
                                    (F.col("L_LOW_HN").isNotNull())
                                    | (F.col("R_LOW_HN").isNotNull())))
    df = df.withColumn("L_TEMP_ODD", F.split("L_LOW_HN", "-")).withColumn(
        "L_LOW_HN",
        F.col("L_TEMP_ODD").getItem(0).cast("int") +
        F.when(F.col("L_TEMP_ODD").getItem(1).isNull(), "0").otherwise(
            F.col("L_TEMP_ODD").getItem(1)).cast("int") / 1000,
    )

    df = df.withColumn("L_TEMP_ODD", F.split("L_HIGH_HN", "-")).withColumn(
        "L_HIGH_HN",
        F.col("L_TEMP_ODD").getItem(0).cast("int") +
        F.when(F.col("L_TEMP_ODD").getItem(1).isNull(), "0").otherwise(
            F.col("L_TEMP_ODD").getItem(1)).cast("int") / 1000,
    )

    df = df.withColumn("L_TEMP_ODD", F.split("R_LOW_HN", "-")).withColumn(
        "R_LOW_HN",
        F.col("L_TEMP_ODD").getItem(0).cast("int") +
        F.when(F.col("L_TEMP_ODD").getItem(1).isNull(), "0").otherwise(
            F.col("L_TEMP_ODD").getItem(1)).cast("int") / 1000,
    )

    df = df.withColumn("L_TEMP_ODD", F.split("R_HIGH_HN", "-")).withColumn(
        "R_HIGH_HN",
        F.col("L_TEMP_ODD").getItem(0).cast("int") +
        F.when(F.col("L_TEMP_ODD").getItem(1).isNull(), "0").otherwise(
            F.col("L_TEMP_ODD").getItem(1)).cast("int") / 1000,
    )

    return df
Example #17
def postcode_to_inward_outward(df: DataFrame,
                               pc_field: str,
                               drop_orig: bool = True):
    """Given a field containing a postcode, creates new columns in the dataframe
    called outward_postcode_std and inward_postcode_std

    Original postcode can have spaces or not and be in any case

    Args:
        df (DataFrame): Spark Dataframe
        pc_field (str): Name of field containing the postcode
        drop_orig (bool, optional): Drop the original postcode column. Defaults to True.
    """

    sql = f"upper(replace({pc_field}, ' ', ''))"
    df = df.withColumn("pc_nospace_temp__", expr(sql))

    # If the postcode is long enough, parse out inner outer
    # If it's too short, assume we only have the outer part

    sql = """
    case 
    when length(pc_nospace_temp__) >= 5 then left(pc_nospace_temp__, length(pc_nospace_temp__) - 3)
    else left(pc_nospace_temp__, 4)
    end
    """

    df = df.withColumn("outward_postcode_std", expr(sql))

    sql = """
    case 
    when length(pc_nospace_temp__) >= 5 then right(pc_nospace_temp__, 3)
    else null 
    end
    """

    df = df.withColumn("inward_postcode_std", expr(sql))

    df = df.drop("pc_nospace_temp__")

    if drop_orig:
        df = df.drop(pc_field)

    return df
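# A hedged usage sketch, assuming `expr` from pyspark.sql.functions is in scope
# for postcode_to_inward_outward:
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([("sw1a 1aa",), ("W1A",)], ["postcode"])
postcode_to_inward_outward(df, "postcode", drop_orig=False).show()
# 'sw1a 1aa' -> outward_postcode_std='SW1A', inward_postcode_std='1AA'
# 'W1A' is too short to carry an inward part, so only outward_postcode_std='W1A' is populated.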
Example #18
def generate_idx_for_df(df: DataFrame, col_name: str, col_schema):
    idx_udf = udf(lambda x: udf_array_to_map(x),
                  MapType(IntegerType(), col_schema, True))
    df = df.withColumn("map", idx_udf(col(col_name)))
    df = df.select("problem_type", "user_id", "oms_protected", "problem_id",
                   "create_at",
                   explode("map").alias("item_id", "answer"))
    return df
Example #19
    def remove_illegal_chars(self, dataframe: DataFrame, source_column: str,
                             target_column: str):
        remover_udf = udf(lambda data: data.translate(
            {ord(i): self.replacement
             for i in self.chars}))
        return dataframe.withColumn(
            target_column,
            remover_udf(getattr(dataframe, source_column))).drop(source_column)
Example #20
def convert_columns(data: DataFrame, column_names: list,
                    convert_func: F.udf) -> DataFrame:
    for col in column_names:
        data = data.withColumn(col, convert_func(col))

    print_data_info(data,
                    f'[DF] converted {convert_func.__name__}',
                    isDetailed=False)
    return data
Example #21
    def remove_illegal_chars(self, dataframe: DataFrame, source_column: str,
                             target_column: str):
        df = dataframe.withColumn(
            target_column,
            regexp_replace(dataframe[source_column], self.pattern,
                           self.replacement),
        )
        df = df.drop(source_column)
        return df
Example #22
def convert_heroes_to_lineup(df: DataFrame) -> DataFrame:

    def onehot(heroes: ArrayType):
        lineup = tuple(heroes_dict[hero] for hero in heroes)
        return Vectors.dense([1 if hero_slot in lineup else 0 for hero_slot in range(len(heroes_dict))])

    heros_to_lineup_udf = udf(onehot, VectorUDT())
    return df.withColumn("dire_lineup_vec", heros_to_lineup_udf(df.dire_lineup))\
             .withColumn("radiant_lineup_vec", heros_to_lineup_udf(df.radiant_lineup))
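# A hedged usage sketch: convert_heroes_to_lineup relies on a module-level heroes_dict
# mapping a hero id to a slot in the one-hot vector, plus the udf/ArrayType/Vectors/VectorUDT
# imports. A toy setup might look like this (the mapping below is hypothetical):
from pyspark.sql import SparkSession
from pyspark.sql.functions import udf
from pyspark.sql.types import ArrayType
from pyspark.ml.linalg import Vectors, VectorUDT

heroes_dict = {101: 0, 102: 1, 103: 2}  # hypothetical hero-id -> vector-slot mapping

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([([101, 103], [102])], ["radiant_lineup", "dire_lineup"])
convert_heroes_to_lineup(df).show(truncate=False)
# radiant_lineup_vec = [1.0, 0.0, 1.0], dire_lineup_vec = [0.0, 1.0, 0.0]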
Example #23
def transform_countries(df: SparkDataFrame) -> SparkDataFrame:
    """Transform countries data
    Add a column of lower case country name for joining purpose

    :param df: the countries data frame to be transformed
    :return: the transformed countries frame
    """

    df = df.withColumn('lower_name', lower(df['name']))
    return df
Example #24
def cast(df: DataFrame, schemas: dict) -> DataFrame:
    """
    Change the data types of the given columns.
    :param df: the input dataframe
    :param schemas: dict with column names as keys and target data types as values,
                    e.g. {'sales': 'int'}; data types include 'int', 'float',
                    'string', 'date', 'bool', etc.
    :return: the converted dataframe
    """
    for col_name, d_type in schemas.items():
        df = df.withColumn(col_name, df[col_name].cast(d_type))
    return df
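# A minimal usage sketch of cast:
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([("10", "2021-01-01")], ["sales", "day"])
cast(df, {"sales": "int", "day": "date"}).printSchema()
# sales becomes integer and day becomes date.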
Example #25
def replace(dataframe: DataFrame, column: str,
            replace_dict: Dict[str, str]) -> DataFrame:
    """Replace values of a string column in the dataframe using a dict.

    Example:

    >>> from butterfree.extract.pre_processing import replace
    ... from butterfree.testing.dataframe import (
    ...     assert_dataframe_equality,
    ...     create_df_from_collection,
    ... )
    >>> from pyspark import SparkContext
    >>> from pyspark.sql import session
    >>> spark_context = SparkContext.getOrCreate()
    >>> spark_session = session.SparkSession(spark_context)
    >>> input_data = [
    ...     {"id":1, "type": "a"}, {"id":2, "type": "b"}, {"id":3, "type": "c"}
    ... ]
    >>> input_df = create_df_from_collection(input_data, spark_context, spark_session)
    >>> input_df.collect()

    [Row(id=1, type='a'), Row(id=2, type='b'), Row(id=3, type='c')]

    >>> replace_dict = {"a": "type_a", "b": "type_b"}
    >>> replace(input_df, "type", replace_dict).collect()

    [Row(id=1, type='type_a'), Row(id=2, type='type_b'), Row(id=3, type='c')]

    Args:
        dataframe: data to be transformed.
        column: string column on the dataframe where to apply the replace.
        replace_dict: dict with values to be replaced.
            All mapped values must be string.

    Returns:
        Dataframe with column values replaced.

    """
    if not isinstance(dataframe, DataFrame):
        raise ValueError("dataframe needs to be a Pyspark DataFrame type")
    if (column not in dict(
            dataframe.dtypes)) or (dict(dataframe.dtypes)[column] != "string"):
        raise ValueError(
            "column needs to be the name of a string column in dataframe")
    if (not isinstance(replace_dict, dict)) or (not all(
            isinstance(value, str) for value in chain(*replace_dict.items()))):
        raise ValueError("replace_dict needs to be a Python dict with "
                         "all keys and values as string values")

    mapping = create_map(
        [lit(value) for value in chain(*replace_dict.items())]  # type: ignore
    )
    return dataframe.withColumn(column,
                                coalesce(mapping[col(column)], col(column)))
Example #26
    def assign_sk(self, df: DataFrame, orderByCol: str):
        now = datetime.now()  # current date and time
        fmt = '%y%m%d%H'
        yymmddhh = now.strftime(fmt)
        df_with_row_num = df.withColumn(
            "row_num",
            row_number().over(Window.orderBy(col(orderByCol))))
        sk_df = df_with_row_num.select(
            concat(lit(yymmddhh), lpad(col("row_num"), 5,
                                       "0")).cast("long").alias("sys_sk"),
            col("*")).drop(col("row_num"))
        return sk_df
Example #27
def cleanTransactions(df: DataFrame) -> DataFrame:
    """Custom function - flatten nested columns and cast column types"""
    if isinstance(df, DataFrame):
        df1 = df.withColumn("basket_explode", explode(col("basket"))).drop("basket")
        df2 = df1.select(col("customer_id"), \
            col("date_of_purchase"), \
            col("basket_explode.*") \
        ) \
        .withColumn("date", col("date_of_purchase").cast("Date")) \
        .withColumn("price", col("price").cast("Integer"))
        showMySchema(df2, "transactions") 
        return df2
Example #28
def convert_column_type(df: SparkDataFrame, data_type: str,
                        cols: List[str]) -> SparkDataFrame:
    """Convert given columns to given data type.

    :param df: the spark data frame to be converted
    :param data_type: the data type to be used.
    :param cols: the column list to be converted
    :return: Converted spark data frame
    """

    for col in [col for col in cols if col in df.columns]:
        df = df.withColumn(col, df[col].cast(data_type))
    return df
Example #29
    def remove_illegal_chars(self, dataframe: DataFrame, source_column: str,
                             target_column: str):

        remove_str = ''
        for char in self.chars:
            if re.match('[A-Za-z0-9]', char) is not None:
                remove_str = remove_str + char
            else:
                remove_str = remove_str + '\\' + char
        df = dataframe.withColumn(target_column, regexp_replace(source_column,
                                  '[' + remove_str + ']', self.replacement))\
                      .drop(source_column)
        return df
Example #30
def null_out_entries_with_freq_above_n(df: DataFrame, colname: str, n: int,
                                       spark: SparkSession):
    """Null out values above a certain frequency threshold

    Useful for columns that mostly contain valid data but occasionally
    contain other values such as 'unknown'

    Args:
        df (DataFrame): The dataframe to clean
        colname (string): The name of the column to clean
        n (int): The maximum frequency allowed.  Any values with a frequency higher than n will be nulled out
        spark (SparkSession): The spark session

    Returns:
        DataFrame: The cleaned dataframe with incoming column overwritten
    """

    # Possible that a window function would be better than the following approach
    # But I think both require a shuffle so possibly doesn't make much difference

    df.createOrReplaceTempView("df")

    sql = f"""
    select {colname} as count
    from df
    group by {colname}
    having count(*) > {n}
    """

    df_groups = spark.sql(sql)

    collected = df_groups.collect()

    values_to_null = [row["count"] for row in collected]

    if len(values_to_null) == 0:
        return df

    values_to_null = [f'"{v}"' for v in values_to_null]
    values_to_null_joined = ", ".join(values_to_null)

    case_statement = f"""
    CASE
    WHEN {colname} in ({values_to_null_joined}) THEN NULL
    ELSE {colname}
    END
    """

    df = df.withColumn(colname, f.expr(case_statement))

    return df
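# A short usage sketch, assuming pyspark.sql.functions is imported as `f`,
# as null_out_entries_with_freq_above_n expects:
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame(
    [(i, "unknown") for i in range(5)] + [(5, "alice"), (6, "bob")],
    ["id", "name"],
)
# 'unknown' appears 5 times, above the threshold n=2, so it is nulled out;
# 'alice' and 'bob' each appear once and survive.
null_out_entries_with_freq_above_n(df, "name", 2, spark).show()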