Example #1
    def baseline_model(self, df_processed):
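        # Cap dvs_p30days at 400 before computing per-category statistics.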
        cond_dvs = F.when(col("dvs_p30days") > 400,
                          400).otherwise(col("dvs_p30days"))


        stats = df_processed\
            .withColumn("dvs_p30days", cond_dvs)\
            .filter(col("recent_event_date") == 1)\
            .filter(col("event_date") >= F.date_add(F.current_date(), -15))\
            .groupby(['grt_l2_cat_name'])\
            .agg(F.avg(col("dvs_p30days")).alias("avg_dvs_p30days"),
                 F.stddev(col("dvs_p30days")).alias("std_deal_view"),
                 F.round(F.avg(col("dvs_p30days"))).cast('integer').alias("avg_deal_view"),
                 F.max("dvs_p30days").alias("max_deal_view"))

        w = Window.partitionBy(F.col('consumer_id')).orderBy(
            F.col('normalized_dvs_p30days').desc())
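        # Rank each consumer's rows by z-scored 30-day deal views, highest first.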

        df_final = df_processed\
            .filter(col("recent_event_date") == 1)\
            .filter(col("event_date") >= F.date_add(F.current_date(), -15))\
            .join(stats, on='grt_l2_cat_name')\
            .withColumn('normalized_dvs_p30days',
                        (col('dvs_p30days') - col('avg_deal_view')) / col('std_deal_view'))\
            .withColumn('normalized_dvs_p30days_rank', F.row_number().over(w))


        df_micro = df_final\
            .filter(col('normalized_dvs_p30days') >= 0)\
            .filter(col('normalized_dvs_p30days_rank') == 1)\
            .filter(col('grt_l2_purchaser_14d') == 0)

        return df_micro
Example #2
def metricRetention(data,
                    needed_dimension_variables,
                    feature_col,
                    sampling_multiplier,
                    activated=False):
    activity_data = data.filter(col(feature_col) > 0).select(
        ["id", "date", feature_col]).distinct()

    pcd_table = data.select(["date", "id", "bucket"] +
                            needed_dimension_variables)
    windowSpec = Window.partitionBy([pcd_table.id] +
                                    needed_dimension_variables).orderBy(
                                        pcd_table.date).rowsBetween(0, 13)
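    # For each dimension, take the last non-null value over the current and next 13 rows.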
    for v in needed_dimension_variables:
        pcd_table = pcd_table.withColumn(v, F.last(v, True).over(windowSpec))
    pcd_table = pcd_table.filter(col("new_profile") == 1)

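    # If activated, keep only new profiles with activity in days 1-6 after their profile date.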
    if activated:
        pcd_table = pcd_table.alias("pcd_t").join(
            activity_data.alias("i_t"), (col('pcd_t.id') == col('i_t.id')) &
            (col('i_t.date') >= F.date_add(col('pcd_t.date'), 1)) &
            (col('i_t.date') <= F.date_add(col('pcd_t.date'), 6)),
            "inner").filter(col("i_t." + feature_col) > 0).dropDuplicates([
                'id'
            ]).select([
                col('pcd_t.{}'.format(c))
                for c in ['id', 'bucket', "date"] + needed_dimension_variables
            ])

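    # Retention: a profile counts as retained if it shows any activity in days 7-13 after its date.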
    intermediate_table3 = pcd_table.alias("pcd_t").join(
        activity_data.alias("i_t"), (col('pcd_t.id') == col('i_t.id')) &
        (col('i_t.date') >= F.date_add(col('pcd_t.date'), 7)) &
        (col('i_t.date') <= F.date_add(col('pcd_t.date'), 13)),
        "outer").select([
            'pcd_t.{}'.format(c)
            for c in ['id', 'date', 'bucket'] + needed_dimension_variables
        ] + [feature_col]).fillna(0, [feature_col]).groupBy([
            'pcd_t.{}'.format(c)
            for c in ['id', 'date', 'bucket'] + needed_dimension_variables
        ]).agg(F.max(col(feature_col)).alias(feature_col)).select([
                col("pcd_t.{}".format(c)).alias(c)
                for c in ['id', 'bucket', 'date'] + needed_dimension_variables
            ] + [feature_col])

    intermediate_table4 = intermediate_table3.groupBy(
        ["date", "bucket"] +
        needed_dimension_variables).mean(feature_col).withColumnRenamed(
            'avg({})'.format(feature_col), feature_col)
    intermediate_table4_allbucket = intermediate_table3.groupBy(
        ["date"] +
        needed_dimension_variables).mean(feature_col).withColumnRenamed(
            'avg({})'.format(feature_col),
            feature_col).withColumn('bucket', lit("ALL"))

    joined_intermediate = intermediate_table4.unionByName(
        intermediate_table4_allbucket)
    return joined_intermediate
Example #3
def createDataFile(start_date, end_date, spark_instance, jackknife_buckets,
                   sample_percent, output_path):
    feature_data_phase1 = spark_instance.table(_TABLE_SOURCE).select([
        _COL_ID.alias("id"),
        _DATE_PARSED.alias("date"),
        # TODO: Use MD5 instead of CRC32
        ((F.crc32(_COL_ID) / 100) % jackknife_buckets).alias("bucket"),
        lit(1).alias("is_active"),
        F.when(_COL_URI_COUNT >= _NUM_ADAU_THRESHOLD, 1).otherwise(0).alias(
            "is_active_active"),
        F.to_date(_COL_PC_DATE).alias("profile_creation_date")
    ] + list(_MAP_NATURAL_DIMENSIONS.keys())).filter(
        (_DATE_PARSED.between(start_date, end_date))
        & (_COL_SAMPLE < sample_percent)).withColumn(
            "young_profile",
            F.when(
                col("date") < F.date_add(col("profile_creation_date"), 14),
                "TRUE").otherwise("FALSE"))

    new_profile_window = Window.partitionBy(col("id")).orderBy(col("date"))
    new_profile_data = feature_data_phase1.filter(
        (col("date") >= col("profile_creation_date"))
        & (col("date") <= F.date_add(col("profile_creation_date"), 6))).select(
            "*",
            F.rank().over(new_profile_window).alias('rank')).filter(
                col('rank') == 1).withColumn("new_profile",
                                             lit(1)).drop("date").withColumn(
                                                 "date",
                                                 col("profile_creation_date"))

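    # Full outer join on (id, date) so new-profile rows are kept even on days with no activity record.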
    feature_data = feature_data_phase1.alias("fd").join(
        new_profile_data.alias("np"),
        (col("fd.id") == col("np.id")) & (col("fd.date") == col("np.date")),
        how='full',
    ).select(
        [F.coalesce(col("np.new_profile"), lit(0)).alias("new_profile")] +
        [F.coalesce(col("fd.is_active"), lit(0)).alias("is_active")] + [
            F.coalesce(col("fd.is_active_active"), lit(0)).alias(
                "is_active_active")
        ] + [
            F.coalesce(col("fd.{}".format(c)), col("np.{}".format(c))).alias(c)
            for c in feature_data_phase1.columns
            if c not in ["is_active", "is_active_active"]
        ])

    once_ever_profiles = feature_data.filter(
        col("is_active") == 1).groupBy("id").count().filter(
            col("count") == 1).select("id").withColumn("single_day_profile",
                                                       lit("1"))

    feature_data = feature_data.alias("fd").join(
        once_ever_profiles.alias("oep"), "id",
        "outer").fillna({"single_day_profile": "0"})

    feature_data.write.partitionBy("date").mode('overwrite').parquet(
        output_path)
Example #4
    def test_date_add_function(self):
        dt = datetime.date(2021, 12, 27)

        # Note: a Python int becomes a LongType column, which date_add does not
        # accept, so cast it to an integer explicitly.
        df = self.spark.createDataFrame([Row(date=dt, add=2)],
                                        "date date, add integer")

        self.assertTrue(
            all(
                df.select(
                    date_add(df.date, df.add) == datetime.date(2021, 12, 29),
                    date_add(df.date, "add") == datetime.date(2021, 12, 29),
                    date_add(df.date, 3) == datetime.date(2021, 12, 30),
                ).first()))
Example #5
    def process_present_df(self, present_df, table_name):
        """
        Generate a result dataframe that contains the date, the number of
        create events per day, and each day's growth rate compared to the
        same day one week earlier.
        """
        df_columns = present_df.columns
        df_first_record = present_df.first()
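        # The CreateEvent payload schema differs between API versions: some use 'object', others 'ref_type'.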
        keyword = 'object' if 'object' in df_first_record[
            'payload'] else 'ref_type'

        num_create_events_df = \
            present_df \
            .filter(col('payload')[keyword] == 'repository') \
            .filter((col('type') == 'CreateEvent') | (col('type') == 'Event'))

        num_create_events_by_date_df = \
            num_create_events_df \
            .groupby(to_date(present_df.created_at).alias('date_created_at')) \
            .count()

        return num_create_events_by_date_df.withColumn(
            'weekly_increase_rate',
            self.get_num_created_repo(
                table_name,
                date_add(num_create_events_by_date_df.date_created_at, -7)))
Example #6
    def main(self, sc: SparkContext, *args):
        os.makedirs(DATASET_DIR, exist_ok=True)

        spark = SparkSession(sc)
        df = spark.read.json(self.get_path_dataset())

        if 'item_bought' not in df.columns:
            df = df.withColumn('item_bought', lit(0))

        df = df.withColumn("session_id", F.monotonically_increasing_id())

        df = df.withColumn("event", explode(df.user_history))

        df = df.withColumn('event_info', col("event").getItem("event_info"))\
                .withColumn('event_timestamp', col("event").getItem("event_timestamp"))\
                .withColumn('event_type', col("event").getItem("event_type"))

        df_view = df.select("session_id", "event_timestamp", "event_info",
                            "event_type")

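        # Synthesize one "buy" event per session, stamped one day after the session's last event.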
        df_buy = df.groupBy("session_id").agg(
            F.max(df.event_timestamp).alias("event_timestamp"),
            F.max(df.item_bought).alias("event_info"))
        df_buy = df_buy.withColumn('event_type', lit("buy"))
        df_buy = df_buy.withColumn('event_timestamp',
                                   F.date_add(df_buy['event_timestamp'], 1))

        df = df_view.union(df_buy)
        df = df.withColumn('event_timestamp2',
                           parse_date(col('event_timestamp')))

        df.orderBy(col('event_timestamp2')).toPandas().to_csv(
            self.output().path, index=False)
Example #7
def county_reality_supply():
    # Last week's actual supply of this item (gauge) for each county and customer tier
    try:
        print(f"{str(dt.now())} last week's supply of this item by county and tier")
        plm_item = get_plm_item(spark).select("item_id", "item_name")

        co_cust = get_co_cust(spark).select("cust_id", "sale_center_id",
                                            "cust_seg")

        area = get_area(spark)
        # Mapping from com_id to city
        city = area.dropDuplicates(["com_id"]).select("com_id", "city")
        # Mapping from sale_center_id to its list of counties
        county = area.groupBy("sale_center_id") \
            .agg(f.collect_list("county").alias("county")) \
            .select("sale_center_id", "county")

        # Get the actual quantity supplied last week
        # cust_item_spw = spark.sql(
        #     "select com_id,cust_id,item_id,qty_allocco,begin_date,end_date from DB2_DB2INST1_SGP_CUST_ITEM_SPW") \
        #     .withColumn("begin_date", f.to_date(col("begin_date"), "yyyyMMdd")) \
        #     .withColumn("end_date", f.to_date(col("end_date"), "yyyyMMdd")) \
        #     .withColumn("last_mon", f.date_sub(f.date_trunc("week", f.current_date()), 7)) \
        #     .withColumn("last_sun", f.date_add(col("last_mon"), 6)) \
        #     .where((col("begin_date") == col("last_mon")) & (col("end_date") == col("last_sun")))\
        #     .join(co_cust,"cust_id")

        cust_item_spw = spark.sql(
            "select com_id,cust_id,item_id,qty_allocco,begin_date,end_date from DB2_DB2INST1_SGP_CUST_ITEM_SPW") \
            .withColumn("begin_date", f.to_date(col("begin_date"), "yyyyMMdd")) \
            .withColumn("end_date", f.to_date(col("end_date"), "yyyyMMdd")) \
            .withColumn("last_mon", f.date_sub(f.date_trunc("week", f.current_date()), 7 * 4)) \
            .withColumn("last_sun", f.date_add(col("last_mon"), 6 + 7 * 3)) \
            .where((col("begin_date") >= col("last_mon")) & (col("end_date") <= col("last_sun")))\
            .join(co_cust,"cust_id")
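        # The window covers the last four full weeks: from the Monday four weeks back through last Sunday.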

        # Column name for the value to compute
        colName = "county_gauge_week_volume"
        result = cust_item_spw.groupBy("com_id","sale_center_id","cust_seg", "item_id") \
                                .agg(f.sum("qty_allocco").alias(colName))

        columns = [
            "com_id", "city", "sale_center_id", "county", "gears", "gauge_id",
            "gauge_name", "city", "gears_data_marker", colName
        ]
        result.withColumn("row", f.concat_ws("_", col("sale_center_id"),col("cust_seg"), col("item_id"))) \
            .withColumn("gears_data_marker", f.lit("4")) \
            .join(plm_item, "item_id") \
            .join(city, "com_id") \
            .join(county,"sale_center_id")\
            .withColumnRenamed("item_id","gauge_id")\
            .withColumnRenamed("item_name","gauge_name")\
            .withColumnRenamed("cust_seg","gears")\
            .foreachPartition(lambda x: write_hbase1(x, columns, hbase))
    except Exception:
        tb.print_exc()
Example #8
def week_start_date(col, week_start_day="Sun"):
    _raise_if_invalid_day(week_start_day)
    # the "standard week" in Spark is from Sunday to Saturday
    mapping = {
        "Sun": "Sat",
        "Mon": "Sun",
        "Tue": "Mon",
        "Wed": "Tue",
        "Thu": "Wed",
        "Fri": "Thu",
        "Sat": "Fri",
    }
    end = week_end_date(col, mapping[week_start_day])
    return F.date_add(end, -6)
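
# Usage sketch (week_end_date is the companion helper defined alongside this function):
#   df.withColumn("week_start", week_start_date(F.col("dt"), "Mon"))
# The week start is computed as six days before the corresponding week-end date.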
Example #9
def week_start_date(col, week_start_day='Sun'):
    _raise_if_invalid_day(week_start_day)
    # the "standard week" in Spark is from Sunday to Saturday
    mapping = {
        'Sun': 'Sat',
        'Mon': 'Sun',
        'Tue': 'Mon',
        'Wed': 'Tue',
        'Thu': 'Wed',
        'Fri': 'Thu',
        'Sat': 'Fri'
    }
    end = week_end_date(col, mapping[week_start_day])
    return F.date_add(end, -6)
Example #10
def to_friday(df, dt):
    '''Convert all days to Fridays.

    Note that this is not a forward mapping: Saturdays and Sundays are mapped
    backward to the Friday of the same Monday-based week. This is fine since
    all the data are supposed to be on business days. To stay consistent over
    weekends, use the timestamps.'''
    cols = df.columns
    # Convert all days to Friday of the week
    df = df.withColumn('friday', F.date_add(F.date_trunc('week', dt), 4))
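    # F.date_trunc('week', ...) snaps to the Monday of that week, so adding 4 days gives its Friday.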
    # Keep only the last record in each week
    w = Window.partitionBy('friday').orderBy(F.col(dt).desc())
    df = df.withColumn('rn', F.row_number().over(w)).where(F.col('rn') == 1)
    df = df.drop(dt).withColumnRenamed('friday', dt)
    return df.select(*cols)
Example #11
  def in_days_of_week(self, days: List[Text], input_time_column=None) -> "Ezlink":
    """Return an Ezlink object that is representative for the selected days of week."""
    if self.meta.get("in_days_of_week"):
      raise RuntimeError("Ezlink is already in_days_of_week={}"
                         .format(self.meta.get("in_days_of_week")))
    # if input time is before 3am, the ride date should be 1 day earlier
    # Ride on Monday 2:30am = Sunday night ride
    input_time_column = input_time_column if input_time_column else self.tap_in_time
    dataframe = (self.dataframe
                 .withColumn('dayofweek',
                             F.when(F.date_format(input_time_column, "HH") < 3,
                                    F.date_format(F.date_add(input_time_column, -1), "EEEE"))
                             .otherwise(F.date_format(input_time_column, "EEEE"))))

    dataframe = (dataframe
                 .filter(F.col('dayofweek').isin(days))
                 .drop('dayofweek'))
    return (Ezlink(dataframe, **self.columns)
            .annotate("in_days_of_week", days))
Example #12
def get_data_last_ndays(events_df, ndays=20, page_filter=None):
    """
    Get the last N days of the log events
    Args:
        events_df: PySpark dataframe
        ndays: number of days to keep, counted back from each user's last day
        page_filter: page to filter on (e.g. "NextSong")
    """
    if page_filter:
        usage_days_df = events_df \
            .where(events_df.page == page_filter)
    else:
        usage_days_df = events_df

    usage_days_df = usage_days_df \
        .select('userId', 'date', 'churn') \
        .groupBy('userId') \
        .agg(F.max(events_df.date), F.min(events_df.date)) \
        .withColumnRenamed('max(date)', 'last_day') \
        .withColumnRenamed('min(date)', 'first_day') \
        .withColumn("{}_days_before".format(ndays), F.date_add(col("last_day"), - ndays + 1))\
        .filter(account_age_in_days(col("last_day"), col("first_day")) >= ndays)
    return usage_days_df
Example #13
    def run_preprocessor(self):
        if type(self) == data_preprocessor:
            raise NotImplementedError(
                "Method need to be called in sub-class but currently called in base class"
            )

        from pyspark.sql.functions import struct, col, split, date_add
        try:
            return self.spark.read.parquet(self.out_file_name)
        except Exception as ex:
            template = "An exception of type {0} occurred. Arguments:\n{1!r}"
            message = template.format(type(ex).__name__, ex.args)
            self.logger.info(message)
            self.logger.info("PROCESS")
            self.logger.info(self.out_file_name)

        try:
            return self.spark.read.parquet(self.out_file_name)
        except Exception as ex:
            template = "An exception of type {0} occurred. Arguments:\n{1!r}"
            message = template.format(type(ex).__name__, ex.args)
            self.logger.info(message)
            self.logger.info("PROCESS_PREPROCESS")
            self.logger.info(self.out_file_name)


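        # TIME_SPAN is a one-day window from midnight of the observation date to midnight of the next day.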
        self.cur_obs_bf_dropna = self.obs_df.select(
            "ID", "ITEMID", "VALUE", "TIME_OBS",
            split("TIME_OBS", " ").getItem(0).alias("DATE_OBS")) \
            .withColumn("TIME_SPAN",
                        struct(col("DATE_OBS").cast("timestamp").alias("TIME_FROM"),
                               date_add("DATE_OBS", 1).cast("timestamp").alias("TIME_TO")))

        self.cur_obs = self.cur_obs_bf_dropna.dropna()

        self.run_remaining()
        return self.spark.read.parquet(self.out_file_name)
Example #14
class preprocessor_gen():
    def __init__(self):
        cur_spark_and_logger = spark_and_logger()
        spark_inited = False
        self.spark = cur_spark_and_logger.spark
        self.logger = cur_spark_and_logger.logger
        if spark_inited:
            self.logger.info("SPARK_INITED")

    def num_cat_tagger(self,
                       data_frame,
                       inputCol="VALUE",
                       outputCol="IS_CAT",
                       labelCol="ITEMID",
                       REPARTITION_CONST=None,
                       nominal_th=2):
        #DL0411: Output should be list of features with corresponding num/cat instances
        from pyspark.sql.functions import size, collect_set, col, lit, when
        if REPARTITION_CONST is None:
            get_nominal_var = data_frame.repartition(labelCol).groupBy(labelCol).agg(size(collect_set(inputCol)).alias("value_cnt"))\
                         .where("value_cnt<={0}".format(nominal_th)).select(labelCol).rdd.flatMap(list).collect()
        else:
            get_nominal_var = data_frame.repartition(REPARTITION_CONST).groupBy(labelCol).agg(size(collect_set(inputCol)).alias("value_cnt"))\
                         .where("value_cnt<={0}".format(nominal_th)).select(labelCol).rdd.flatMap(list).collect()
        self.logger.debug(get_nominal_var)
        data_frame.show()

        self.logger.debug("[PREPROCESSOR_GEN]GET_BINARY_VAR")
        if REPARTITION_CONST is None:
            ret_data_frame = data_frame.withColumn(outputCol,
                when((col(inputCol).rlike(r'^(?!-0?(\.0+)?(E|$))-?(0|[1-9]\d*)?(\.\d+)?(?<=\d)(E-?(0|[1-9]\d*))?$')) &
                     (~col(labelCol).isin(get_nominal_var)), lit("0"))\
                .otherwise(lit("1")))
        else:
            ret_data_frame = data_frame.withColumn(outputCol,
                when((col(inputCol).rlike(r'^(?!-0?(\.0+)?(E|$))-?(0|[1-9]\d*)?(\.\d+)?(?<=\d)(E-?(0|[1-9]\d*))?$')) &
                     (~col(labelCol).isin(get_nominal_var)), lit("0"))\
                .otherwise(lit("1"))).repartition(REPARTITION_CONST)
        self.logger.debug("[PREPROCESSOR_GEN]RET_DATA_FRAME")
        return ret_data_frame

    def count_instance(self, raw_feature1, raw_feature2=None):
        from pyspark.sql.functions import collect_set, size, col
        raw1_distinct_instance = raw_feature1\
            .select("ID","TIME_SPAN").withColumn("T1",col("TIME_SPAN").TIME_FROM.cast("timestamp"))\
            .withColumn("T2",col("TIME_SPAN").TIME_TO.cast("timestamp")).select("ID","T1","T2")
        if raw_feature2 is not None:
            raw2_distinct_instance = raw_feature2\
                .select("ID","TIME_SPAN").withColumn("T1",col("TIME_SPAN").TIME_FROM.cast("timestamp"))\
                .withColumn("T2",col("TIME_SPAN").TIME_TO.cast("timestamp")).select("ID","T1","T2")
        else:
            raw2_distinct_instance = None
        if raw2_distinct_instance is not None:
            all_inst = raw1_distinct_instance.unionAll(raw2_distinct_instance)
        else:
            all_inst = raw1_distinct_instance

        return all_inst.distinct().groupBy("ID").agg(
            collect_set("T1").alias("collected_set")).select(
                "ID",
                size("collected_set").alias("list_size")).rdd.map(
                    lambda x: x.list_size).reduce(lambda a, b: a + b)

    def num_iqr_filter(self,
                       data_frame,
                       inputCol="VALUE",
                       labelCol="ITEMID",
                       REPARTITION_CONST=None,
                       sc=None):
        from pyspark.sql.window import Window
        from pyspark.sql.functions import abs, percent_rank, row_number, collect_list, udf, struct, count, avg, stddev_pop
        from pyspark.sql.types import MapType, StringType, DoubleType
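        # Compute per-ITEMID quartiles (percent_rank here, or percentile_approx in SQL when
        # REPARTITION_CONST is set) and keep values inside [Q1 - 1.5*IQR, Q3 + 1.5*IQR].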
        self.logger.debug("[NUM_IQR_FILTER] IN")
        self.logger.debug("[NUM_IQR_FILTER] BEFORE QUANTILE TAGGED")
        if REPARTITION_CONST is None:
            value_order = Window.partitionBy(labelCol).orderBy(
                col(inputCol).cast("float"))
            Q1_percentile = Window.partitionBy(labelCol).orderBy(
                abs(0.25 - col("percentile")))
            Q2_percentile = Window.partitionBy(labelCol).orderBy(
                abs(0.5 - col("percentile")))
            Q3_percentile = Window.partitionBy(labelCol).orderBy(
                abs(0.75 - col("percentile")))
            percent_data_frame = data_frame.select(
                labelCol, inputCol,
                percent_rank().over(value_order).alias("percentile"))
            Q1_data_frame = percent_data_frame.withColumn("Q1_rn",row_number().over(Q1_percentile)).where("Q1_rn == 1")\
                            .select(labelCol,inputCol,lit("Q1").alias("quantile"))
            Q2_data_frame = percent_data_frame.withColumn("Q2_rn",row_number().over(Q2_percentile)).where("Q2_rn == 1")\
                            .select(labelCol,inputCol,lit("Q2").alias("quantile"))
            Q3_data_frame = percent_data_frame.withColumn("Q3_rn",row_number().over(Q3_percentile)).where("Q3_rn == 1")\
                            .select(labelCol,inputCol,lit("Q3").alias("quantile"))
            self.logger.debug(
                "[NUM_IQR_FILTER] REPARTITION_CONST Not Asserted")
            merge_all = Q1_data_frame.unionAll(Q2_data_frame).unionAll(
                Q3_data_frame).persist()

            self.logger.debug("[NUM_IQR_FILTER] Qs Merged")

            udf_parse_list_to_map = udf(
                lambda maps: dict(list(tuple(x) for x in maps)),
                MapType(StringType(), StringType()))

            self.logger.debug("[NUM_IQR_FILTER] Before merge quantiles")
            aggregate_quantiles = merge_all.groupBy(labelCol).agg(
                collect_list(struct("quantile", inputCol)).alias("quant_info"))
            self.logger.debug("AGG ONLY")
            aggregate_quantiles = aggregate_quantiles.select(
                labelCol,
                udf_parse_list_to_map("quant_info").alias("quant_info"))
            self.logger.debug("TRANSFORM")
            iqr_data_frame = aggregate_quantiles.withColumn("Q1",col("quant_info").getItem("Q1").cast("float"))\
                .withColumn("Q2",col("quant_info").getItem("Q2").cast("float"))\
                .withColumn("Q3",col("quant_info").getItem("Q3").cast("float"))
            self.logger.debug("QUANTILE_EXTRACTION")
        else:
            cur_label_list = data_frame.select(
                labelCol).distinct().rdd.flatMap(list).collect()
            cur_iqr_list = list()
            cnt = -1
            for cur_item in cur_label_list:
                cnt = cnt + 1
                data_frame.where(
                    col(labelCol) == cur_item).registerTempTable("cur_table")
                self.logger.debug("{0}/{1},::{2}".format(
                    cnt, len(cur_label_list),
                    sc.sql(
                        "select {0} from cur_table".format(labelCol)).count()))
                cur_iqr = sc.sql(
                    "select {0}, percentile_approx({1},0.25) as Q1, percentile_approx({2},0.5) as Q2, percentile_approx({3},0.75) as Q3 from cur_table group by {4}"
                    .format(labelCol, inputCol, inputCol, inputCol,
                            labelCol)).first().asDict()
                cur_iqr_list.append(cur_iqr)
                sc.catalog.dropTempView("cur_table")
                #percent_data_frame = data_frame.select(labelCol, inputCol, percent_rank().over(value_order).alias("percentile")).repartition(REPARTITION_CONST).cache().checkpoint()
            iqr_data_frame = sc.createDataFrame(cur_iqr_list).repartition(
                REPARTITION_CONST)

        if REPARTITION_CONST is None:
            iqr_data_frame = iqr_data_frame.withColumn("IQR",col("Q3")-col("Q1"))\
                                       .withColumn("LB",col("Q1")-1.5*col("IQR"))\
                                       .withColumn("UB",col("Q3")+1.5*col("IQR"))\
                                       .select(labelCol,"LB","UB")
        else:
            iqr_data_frame = iqr_data_frame.withColumn("IQR",col("Q3")-col("Q1"))\
                                       .withColumn("LB",col("Q1")-1.5*col("IQR"))\
                                       .withColumn("UB",col("Q3")+1.5*col("IQR"))\
                                       .select(labelCol,"LB","UB").repartition(REPARTITION_CONST).persist()

            self.logger.debug("CUR_ITEMID_ALL_COUNT:{0}".format(
                iqr_data_frame.count()))

        self.logger.debug("[NUM_IQR_FILTER] iqr_data_frame merged")
        if REPARTITION_CONST is None:
            self.logger.debug(
                "[NUM_IQR_FILTER] RETURN_PREP, REPARTITION_CONST NOT ASSERTED")
            ret_data_frame = data_frame.repartition(labelCol).join(iqr_data_frame,labelCol).where((col("LB").cast("float") <= col(inputCol).cast("float")) & (col("UB").cast("float")>=col(inputCol).cast("float")))\
                                                                 .drop("LB").drop("UB").persist()
            ref_df = ret_data_frame.repartition(labelCol).groupBy(labelCol)\
                               .agg(count(inputCol).alias("ref_count"),avg(inputCol).alias("ref_avg"),stddev_pop(inputCol).alias("ref_std")).persist()
            self.logger.debug("CHECK DF")
            self.logger.debug(ref_df.count())
            self.logger.debug(ret_data_frame.count())

            return (ret_data_frame, ref_df)
        else:
            self.logger.debug(
                "[NUM_IQR_FILTER] RETURN_PREP, REPARTITION_CONST ASSERTED: {0}"
                .format(REPARTITION_CONST))
            ret_data_frame = data_frame.join(iqr_data_frame,labelCol).where((col("LB").cast("float") <= col(inputCol).cast("float")) & (col("UB").cast("float")>=col(inputCol).cast("float")))\
                                                                 .drop("LB").drop("UB").repartition(REPARTITION_CONST)
            ref_df = ret_data_frame.groupBy(labelCol)\
                               .agg(count(inputCol).alias("ref_count"),avg(inputCol).alias("ref_avg"),stddev_pop(inputCol).alias("ref_std")).repartition(REPARTITION_CONST)

            return (ret_data_frame, ref_df)

    def cat_frequency_filter(self,
                             data_frame,
                             threshold_lb=0,
                             threshold_ub=1,
                             inputCol="VALUE",
                             labelCol="ITEMID",
                             REPARTITION_CONST=None):
        from pyspark.sql.functions import count, monotonically_increasing_id
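        # Keep (ITEMID, VALUE) pairs whose within-item frequency lies in [threshold_lb, threshold_ub],
        # and build a vocabulary that assigns a unique idx to every surviving pair.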
        label_count = data_frame.groupBy(labelCol).agg(
            count("*").alias("label_count"))
        if REPARTITION_CONST is not None:
            label_count = label_count.repartition(REPARTITION_CONST).persist()
        self.logger.debug("[CAT_FREQUENCY_FILTER] label_count done]")
        if REPARTITION_CONST is None:
            cur_freq = data_frame.groupBy(labelCol,inputCol).agg(count("*").alias("indiv_count")).join(label_count,labelCol)\
                                .withColumn("cat_freq",col("indiv_count")/col("label_count"))
        else:
            cur_freq = data_frame.groupBy(labelCol,inputCol).agg(count("*").alias("indiv_count")).repartition(REPARTITION_CONST).join(label_count,labelCol)\
                                .withColumn("cat_freq",col("indiv_count")/col("label_count")).repartition(REPARTITION_CONST).checkpoint()
            self.logger.debug(
                "[CAT_FREQUENCY_FILTER] CUR_FREQ_CHECKPOINTED:{0}".format(
                    cur_freq.count()))
        self.logger.debug("[CAT_FREQUENCY_FILTER] frequency calc done")
        cur_freq = cur_freq.where((col("cat_freq") >= threshold_lb) & (
            col("cat_freq") <= threshold_ub)).drop("cat_freq")
        #cur_freq.orderBy(col(labelCol),-1*col("cat_freq")).show(500)

        ret_data_frame = data_frame.join(
            cur_freq,
            [inputCol, labelCol]).drop("indiv_count").drop("label_count")
        self.logger.debug("[CAT_FREQUENCY_FILTER] ret_df prepared")
        if REPARTITION_CONST is not None:
            ret_data_frame = ret_data_frame.repartition(REPARTITION_CONST)
        ret_voca = ret_data_frame.select("ITEMID", "VALUE").distinct()
        ret_voca = ret_voca.rdd.map(lambda x: (x.ITEMID, x.VALUE)
                                    ).zipWithUniqueId().map(lambda x: {
                                        "idx": x[1],
                                        "ITEMID": x[0][0],
                                        "VALUE": x[0][1]
                                    }).toDF()
        if REPARTITION_CONST is not None:
            ret_voca = ret_voca.repartition(REPARTITION_CONST)
        self.logger.debug("[CAT_FREQUENCY_FILTER] ALL DONE")
        return (ret_data_frame, ret_voca)

    @staticmethod
    def calc_summary_stat(x, labelCol):
        import numpy as np
        cur_array = np.array(x, dtype=float)
        ret_dict = dict()
        ret_dict["N_{0}_avg".format(labelCol)] = float(np.average(cur_array))
        ret_dict["N_{0}_min".format(labelCol)] = float(np.min(cur_array))
        ret_dict["N_{0}_max".format(labelCol)] = float(np.max(cur_array))
        ret_dict["N_{0}_std".format(labelCol)] = float(np.std(cur_array))
        ret_dict["N_{0}_count".format(labelCol)] = float(
            np.shape(cur_array)[0])
        return ret_dict

    @staticmethod
    def merge_dict_all(x):  #will not be used outside of the package
        ret_dict = dict()
        for cur_dict in x:
            ret_dict.update(cur_dict)
        return ret_dict

    @staticmethod
    def sustainment_quantifier(x, cur_label, ref_avg, ref_std, ref_count):
        from scipy.stats import ttest_ind_from_stats
        import numpy as np
        ret_dict = dict()
        statistic, p_val = ttest_ind_from_stats(
            x["N_{0}_avg".format(cur_label)],
            x["N_{0}_std".format(cur_label)],
            x["N_{0}_count".format(cur_label)],
            ref_avg,
            ref_std,
            ref_count,
            equal_var=True)
        if not np.isnan(statistic):
            if statistic > 0:
                one_tailed_pval = 1.0 - p_val / 2.0
            else:
                one_tailed_pval = p_val / 2.0
            ret_dict["N_{0}_TT".format(cur_label)] = float(p_val)
            ret_dict["N_{0}_LT".format(cur_label)] = float(one_tailed_pval)
        return ret_dict

    def num_featurizer(self,
                       data_frame,
                       ref_df=None,
                       featurize_process=["summary_stat", "sustainment_q"],
                       inputCol="VALUE",
                       labelCol="ITEMID",
                       outputCol="num_features",
                       REPARTITION_CONST=None):
        from pyspark.sql.functions import udf, array
        from pyspark.sql.types import StringType, DoubleType, MapType
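        # Aggregate numeric VALUEs per (ID, TIME_SPAN, ITEMID), derive summary statistics, and
        # optionally compare each instance against the reference stats with a t-test ("sustainment_q").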
        if not data_frame:
            return
        ret_data_frame = self.value_aggregator(data_frame)
        if REPARTITION_CONST is not None:
            ret_data_frame = ret_data_frame.checkpoint()
            self.logger.debug(
                "[NUM_FEATURIZER] ret_dataframe checkpointed:{0}".format(
                    ret_data_frame.count()))
        if "summary_stat" in featurize_process:
            udf_summary_stat = udf(preprocessor_gen.calc_summary_stat,
                                   MapType(StringType(), DoubleType()))
            ret_data_frame = ret_data_frame.withColumn(
                "summary_stat", udf_summary_stat(inputCol + "_LIST", labelCol))
            if REPARTITION_CONST is not None:
                ret_data_frame = ret_data_frame.checkpoint()
                self.logger.debug(
                    "[NUM_FEATURIZER] summary_stat, ret_dataframe checkpointed:{0}"
                    .format(ret_data_frame.count()))

        if "sustainment_q" in featurize_process:
            udf_sustainment_quant = udf(
                preprocessor_gen.sustainment_quantifier,
                MapType(StringType(), DoubleType()))
            ret_data_frame = ret_data_frame.join(ref_df, labelCol).withColumn(
                "sustainment_q",
                udf_sustainment_quant("summary_stat", labelCol, "ref_avg",
                                      "ref_std",
                                      "ref_count")).drop("ref_avg").drop(
                                          "ref_std").drop("ref_count")
            if REPARTITION_CONST is not None:
                ret_data_frame = ret_data_frame.checkpoint()
                self.logger.debug(
                    "[NUM_FEATURIZER] sustainment_q, ret_dataframe checkpointed:{0}"
                    .format(ret_data_frame.count()))
        udf_merge_dict_all = udf(preprocessor_gen.merge_dict_all,
                                 MapType(StringType(), DoubleType()))
        ret_data_frame = ret_data_frame.withColumn(
            outputCol, udf_merge_dict_all(array(featurize_process)))
        return ret_data_frame

    def cat_featurizer(self,
                       data_frame,
                       voca_df,
                       inputCol="VALUE",
                       labelCol="ITEMID",
                       outputCol="cat_features",
                       REPARTITION_CONST=None):
        def prep_cat_dict(avail, pos):  # internal
            ret_dict_key = list(map(lambda x: "C_" + str(x), avail))
            ret_dict = dict(zip(ret_dict_key, [0.0] * len(ret_dict_key)))
            pos_dict_key = list(map(lambda x: "C_" + str(x), pos))
            update_dict = dict(zip(pos_dict_key, [1.0] * len(pos_dict_key)))
            ret_dict.update(update_dict)
            return ret_dict

        from pyspark.sql.functions import udf, collect_set
        from pyspark.sql.types import MapType, StringType, DoubleType
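        # One-hot style encoding: every idx in the item's vocabulary defaults to 0.0 and the
        # values observed in this (ID, TIME_SPAN) are set to 1.0.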
        if not data_frame:
            return

        all_var = voca_df.groupBy(labelCol).agg(
            collect_set("idx").alias("AVAIL_LIST"))
        ret_data_frame = data_frame.join(
            voca_df, [inputCol, labelCol]).drop("VALUE").withColumnRenamed(
                "idx", "VALUE")
        if REPARTITION_CONST is not None:
            ret_data_frame = ret_data_frame.repartition(
                REPARTITION_CONST).checkpoint()
            self.logger.debug(
                "[CAT_FEATURIZER] VOCA_JOINED_CHECKPOINTED:{0}".format(
                    ret_data_frame.count()))
        ret_data_frame = self.value_aggregator(ret_data_frame).join(
            all_var, "ITEMID")
        self.logger.debug(ret_data_frame.select(labelCol).distinct().count())
        self.logger.debug("[CAT_FEATURIZER] VALUE_AGGREGATOR OUT")
        udf_prep_cat_dict = udf(prep_cat_dict,
                                MapType(StringType(), DoubleType()))
        ret_data_frame = ret_data_frame.withColumn(
            "cat_features", udf_prep_cat_dict("AVAIL_LIST", "VALUE_LIST"))
        return ret_data_frame

    def availability_filter(self,
                            data_frame,
                            n_inst=None,
                            availability_th=0.80,
                            labelCol="ITEMID",
                            idCol="ID",
                            timeCol="TIME_SPAN",
                            REPARTITION_CONST=None):
        from pyspark.sql.functions import count
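        # Keep only ITEMIDs observed in at least availability_th of all (ID, TIME_SPAN) instances.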
        if not n_inst:
            total_cnt = data_frame.select(idCol, timeCol).distinct().count()
            self.logger.debug(total_cnt)
        else:
            total_cnt = n_inst
        target_label_set = data_frame.select(idCol, timeCol,
                                             labelCol).distinct()

        if REPARTITION_CONST is None:
            target_label_set = target_label_set.groupBy(labelCol).agg(
                (count("*") / float(total_cnt)).alias("freq"))
        else:
            target_label_set = target_label_set.repartition(REPARTITION_CONST)\
                .groupBy(labelCol).agg((count("*")/float(total_cnt)).alias("freq"))

        target_label_set.orderBy(col("freq").desc()).show()

        if REPARTITION_CONST is not None:
            target_label_set = target_label_set.repartition(
                REPARTITION_CONST).checkpoint()
            self.logger.info(
                "[AVAILABILITY_FILTER] target_label_Set checkpointed:{0}".
                format(target_label_set.count()))
        self.logger.info(target_label_set.rdd.toDebugString())
        target_label_set = target_label_set.where(
            col("freq") >= availability_th).select(labelCol).rdd.flatMap(
                list).collect()
        self.logger.info(target_label_set)
        if len(target_label_set) == 0:
            return
        ret_data_frame = data_frame.where(col(labelCol).isin(target_label_set))

        self.logger.debug(target_label_set)
        self.logger.debug(len(target_label_set))
        return ret_data_frame

    @staticmethod
    def check_key_in_dict(target_dict, target_key):
        return target_key in target_dict

    def flattener_df_prep(self,
                          data_frame,
                          descCol=["ID", "TIME_SPAN"],
                          inputCol="feature_aggregated",
                          drop_cnt=True):
        from pyspark.sql import Row
        from pyspark.sql.functions import col, udf, lit
        from pyspark.sql.types import StructType, StructField, StringType, DoubleType, BooleanType
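        # Flatten the per-instance feature map into one wide row per descriptor; the output schema
        # is built from the union of all feature keys seen in the data.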
        data_frame.show()
        desc_schema = data_frame.select(descCol).schema
        all_feature_column = list(
            data_frame.select(inputCol).rdd.map(lambda x: set(x[inputCol].keys(
            ))).reduce(lambda a, b: a.union(b)))
        ret_df = data_frame.rdd.map(lambda x: x.asDict()).map(lambda cur_item: [dict((cur_col,cur_item[cur_col]) for cur_col in descCol)]\
                                                                                + [cur_item[inputCol]]).map(lambda x: preprocessor_gen.merge_dict_all(x))
        inst_count = data_frame.count()
        ret_schema = desc_schema
        ret_feature_col = list()
        key_checker = udf(preprocessor_gen.check_key_in_dict, BooleanType())
        for cur_col in all_feature_column:
            #        print ("{0}//{1}//{2}".format(data_frame.where(key_checker(inputCol,lit(cur_col))).count(),inst_count,cur_col))
            #        if (data_frame.where(key_checker(inputCol,lit(cur_col))).count() == inst_count):
            #            continue
            if drop_cnt:
                if cur_col.find("count") != -1:
                    continue
            ret_feature_col.append(cur_col)
            ret_schema = ret_schema.add(
                StructField(cur_col, DoubleType(), True))
        return (ret_df, ret_schema, ret_feature_col)

    def value_aggregator(self,
                         data_frame,
                         aggregateCols=["ID", "TIME_SPAN", "ITEMID"],
                         catmarkerCol="IS_CAT",
                         inputCol="VALUE",
                         outputCol="VALUE_LIST"):
        from pyspark.sql.functions import collect_set, collect_list
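        # Categorical values are collected as a set, numeric values as a list, per (ID, TIME_SPAN, ITEMID).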
        cat_data_agg_frame = data_frame.where(col(catmarkerCol) == 1).groupBy(aggregateCols+[catmarkerCol])\
                                                                     .agg(collect_set(inputCol).alias(outputCol))
        num_data_agg_frame = data_frame.where(col(catmarkerCol) == 0).groupBy(aggregateCols+[catmarkerCol])\
                                                                     .agg(collect_list(inputCol).alias(outputCol))
        return cat_data_agg_frame.unionAll(num_data_agg_frame)

    def feature_aggregator(self,
                           num_features,
                           cat_features,
                           catinputCol="cat_features",
                           numinputCol="num_features",
                           aggregatorCol=["ID", "TIME_SPAN"],
                           outputCol="feature_aggregated",
                           idCol="ID",
                           REPARTITION_CONST=None):
        from pyspark.sql.functions import col, udf, collect_list, rand
        from pyspark.sql.types import MapType, StringType, DoubleType

        if not num_features:
            ret_data_frame = cat_features.withColumnRenamed(
                catinputCol, "features")
            self.logger.debug("CAT_ONLY")
        elif not cat_features:
            ret_data_frame = num_features.withColumnRenamed(
                numinputCol, "features")
            self.logger.debug("NUM_ONLY")
        else:
            ret_data_frame = num_features.select(aggregatorCol+[numinputCol]).withColumnRenamed(numinputCol,"features")\
                                     .unionAll(cat_features.select(aggregatorCol+[catinputCol]).withColumnRenamed(catinputCol,"features"))
            self.logger.debug("BOTH")
        if REPARTITION_CONST is not None:
            ret_data_frame = ret_data_frame.repartition(
                REPARTITION_CONST).checkpoint()
            self.logger.debug(
                "[FEATURE_AGGREGATOR] ret_data_Frame checkpointed before groupby:{0}"
                .format(ret_data_frame.count()))
        udf_merge_dict_all = udf(preprocessor_gen.merge_dict_all,
                                 MapType(StringType(), DoubleType()))
        self.logger.debug("[FEATURE_AGGREGATOR] before groupby")
        # FROM HERE
        ret_data_frame = ret_data_frame.groupBy(aggregatorCol)
        if REPARTITION_CONST is not None:
            ret_data_frame = ret_data_frame.agg(
                collect_list("features").alias("features")).repartition(
                    REPARTITION_CONST).checkpoint()
            self.logger.debug(
                "[FEATURE_AGGREGATOR] ret_data_frame chkpointed:{0}".format(
                    ret_data_frame.count()))
        else:
            ret_data_frame = ret_data_frame.agg(
                collect_list("features").alias("features"))
        #ret_data_frame.orderBy(rand()).show(truncate=False)
        ret_data_frame = ret_data_frame.withColumn(
            outputCol, udf_merge_dict_all("features"))
        return ret_data_frame
        # TILL HERE. Maybe something triggers this weird spark burst

    @staticmethod
    def post_processor(data_frame,
                       algo="None",
                       outputCol="features_postprocessed"):
        ret_data_frame = data_frame
        return ret_data_frame

    @staticmethod
    def normalizer(data_frame,
                   inputCol="features",
                   outputCol="scaled_features"):
        ret_data_frame = data_frame
        return ret_data_frame

    @staticmethod
    def prep_TR_TE(merged_df, per_instance=False, tr_prop=0.9, targetCol="ID"):
        from pyspark.sql.functions import col
        if per_instance:
            tr_inst, te_inst = merged_df.randomSplit([tr_prop, 1 - tr_prop])
        else:
            tr_id, te_id = merged_df.select(targetCol).distinct().randomSplit(
                [tr_prop, 1 - tr_prop])
            tr_id = tr_id.rdd.flatMap(list).collect()
            te_id = te_id.rdd.flatMap(list).collect()
            tr_inst = merged_df.where(col(targetCol).isin(tr_id))
            te_inst = merged_df.where(col(targetCol).isin(te_id))
        return (tr_inst, te_inst)


# NEED TO GET RID OF THIS PART. NEVER USE FOR THE MAIN PROCESSING!#
# Will get rid of this as soon as dev is done
if __name__ == "__main__":
    from pyspark.sql import SparkSession
    from pyspark.sql.functions import col, struct, split, date_add

    # TODO remove all this information
    cur_dir = "/Users/dhlee4/mimic3_data/CHARTEVENTS.csv_parquet"
    cur_home = "/Users/dhlee4/mimic3_data/"

    spark = SparkSession.builder.master("local[*]")\
                                .appName("PreProcessorGen_local_test")\
                                .getOrCreate()
    spark.sparkContext.setLogLevel("WARN")
    cur_df = spark.read.parquet(cur_dir + "_test_pts")

    prep_df = cur_df.select(col("HADM_ID").alias("ID"),
                            col("CHARTTIME").alias("TIME_OBS"),
                            "ITEMID", col("VALUE"))
    prep_df = prep_df.withColumn(
        "TIME_SPAN",
        struct(split("TIME_OBS", " ").getItem(0).cast("timestamp").alias("FROM_TIME"),
               date_add(split("TIME_OBS", " ").getItem(0), 1).cast("timestamp").alias("TO_TIME")))
    prep_df.show()

    cur_prep = preprocessor_gen()
    cur_prep.logger.debug(prep_df.count())
    zz = cur_prep.num_cat_tagger(prep_df)
    zz.show(truncate=False)
    cat_raw_filtered, voca_list = cur_prep.cat_frequency_filter(
        zz.where("IS_CAT == 1"))
    num_raw_filtered, num_ref_list = cur_prep.num_iqr_filter(
        zz.where("IS_CAT == 0"))
    cur_id = cat_raw_filtered.select("ID", "TIME_SPAN").unionAll(
        num_raw_filtered.select("ID", "TIME_SPAN")).distinct()
    num_filtered = cur_prep.availability_filter(num_raw_filtered, cur_id.count())
    num_filtered.show()
    cat_filtered = cur_prep.availability_filter(cat_raw_filtered, cur_id.count())
    cat_filtered.show()

    cat_featurized = cur_prep.cat_featurizer(cat_filtered, voca_df=voca_list)
    num_featurized = cur_prep.num_featurizer(num_filtered, ref_df=num_ref_list)

    merged_all = cur_prep.feature_aggregator(num_featurized, cat_featurized)

    target_rdd, target_schema, feature_column = cur_prep.flattener_df_prep(
        merged_all)

    cur_df = spark.createDataFrame(target_rdd, target_schema)
    cur_df.show(50)

    imputed_df, feature_col = imputer(cur_df, inputCols=feature_column)
    imputed_df.show(truncate=False)
    cur_prep.logger.debug(feature_col)
    imputed_df.write.save(cur_home + "test_obs", mode="overwrite")
Example #15
    
    patient_timeline_pd = patient_event \
        .join(ra_patient, 'person_id') \
        .where(F.col('index_date').between(F.col('lower_bound'), F.col('upper_bound')))\
        .withColumn('date_concept_id', F.struct(F.col('index_date'), patient_event['standard_concept_id']))\
        .groupBy('person_id').agg(join_collection_udf(F.collect_list('date_concept_id')).alias('sequence'), 
                                  F.size(F.collect_list('date_concept_id')).alias('size')) \
        .where(F.col('size') > 1) \
        .select('person_id', 'sequence').toPandas()
    
    return patient_timeline_pd


patient_event = spark.read.parquet('/data/research_ops/omops/omop_2019q4_embeddings/visit/patient_event/')
patient_event = patient_event \
    .withColumn("lower_bound", F.unix_timestamp(F.date_add(F.from_unixtime(F.col('date'), 'yyyy-MM-dd'), -30), 'yyyy-MM-dd')) \
    .withColumn("upper_bound", F.unix_timestamp(F.date_add(F.from_unixtime(F.col('date'), 'yyyy-MM-dd'), 30), 'yyyy-MM-dd'))
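# lower_bound/upper_bound are unix timestamps 30 days before and after each event date.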
patient_event.cache()

data = visualize_time_lines(patient_event, 80809, 30)

sequences = data['sequence'].to_list()

sequence_1 = [int(c) for c in sequences[20].split(' ')]
sequence_2 = [int(c) for c in sequences[1].split(' ')]

sequence_1 = [1000560, 1126658, 1125315]
sequence_2 = [1125315, 1124957, 1125315, 1125315, 1112807]

patient_similarity = PatientSimilarity(max_cost=0, is_similarity=True)
Example #16
               F.rand(seed=10).alias("temp_raw"),
               F.randn(seed=27).alias("pressure_raw"),
               F.rand(seed=45).alias("duration_raw"),
               F.randn(seed=54).alias("temp_n"),
               F.randn(seed=78).alias("pressure_n"),
               F.randn(seed=96).alias("duration_n"),
               F.round(F.rand() * 7.5 * 60, 0).alias("timestamp_n"))
df = df.withColumn('pid', (100000 + df["id"]))
df = (df.withColumn("temp_raw", (10.0 * df["temp_raw"]) + 350).withColumn(
    "pressure_raw", (2.0 * df["pressure_raw"]) + 12).withColumn(
        "duration_raw", (4.0 * df["duration_raw"]) + 28.5).withColumn(
            "timestamp", ((df["id"] * 7.5 * 60) + 1561939200 +
                          df["timestamp_n"]).cast('timestamp')))
df = df.withColumn("process_time", df["timestamp"])
df = df.withColumn("qualitycheck_time",
                   F.date_trunc("day", F.date_add(df["timestamp"], 2)))
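# The quality check is logged at the start of the day two days after the process timestamp.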

# Assign good vs bad for quality
df = df.withColumn(
    'quality',
    F.when(F.col('id').between(1,3400) & F.col('temp_raw').between(351,359) & F.col('pressure_raw').between(8,15) & F.col('duration_raw').between(28,32), 1)\
    .when(F.col('id').between(3401,6500) & F.col('temp_raw').between(354,359) & F.col('pressure_raw').between(8,15) & F.col('duration_raw').between(28,32), 1)\
    .when(F.col('id').between(6501,8000) & F.col('temp_raw').between(354,359) & F.col('pressure_raw').between(12,15) & F.col('duration_raw').between(28,32), 1)\
    .otherwise(0)
)
# Add some noise
df = (df.withColumn(
    "temp",
    df["temp_raw"] + (dg_noise["temp_noise"] * df["temp_n"])).withColumn(
        "pressure", df["pressure_raw"] +
        (dg_noise["pressure_noise"] * df["pressure_n"])).withColumn(
Example #17
    def process_history_df(self, df):
        """
        Process function for history data, generate result dataframe
        that contains date, number of create events and
        growth rate of a day compare to last week
        """
        # There are two versions of API for CreateEvent of repository:
        # - One is        col("payload")['object'] == 'repository'
        # - Another is    col("payload")['ref_type'] == 'repository'
        df_columns = df.columns
        df_first_record = df.first()

        num_create_events_df = \
            df \
            .filter((col('payload')['ref_type'] == 'repository') | (col('payload')['object'] == 'repository')) \
            .filter((col('type') == 'CreateEvent') | (col('type') == 'Event'))

        # count the number of create events happened in one day (group by date)
        num_create_events_by_date_df = num_create_events_df.groupby(
            to_date(df.created_at).alias('date_created_at')).count()

        # calculate the growth rate of that day compared to last week.
        # Duplicate the dataframe: for each day in the first dataframe,
        # find the number of create events in the second dataframe
        # for the day 7 days before the day in the first dataframe.
        # [df1] 2015-01-07 -> [df2] 2015-01-01 (7 days)
        num_create_events_by_date_df_1 = num_create_events_by_date_df.alias(
            'num_create_events_by_date_df_1')

        num_create_events_by_date_df_1 = \
            num_create_events_by_date_df_1 \
            .select(
                col('date_created_at').alias('date_created_at_1'),
                col('count').alias('count_1'))

        num_create_events_by_date_df_2 = num_create_events_by_date_df.alias(
            'num_create_events_by_date_df_2')

        num_create_events_by_date_df_2 = \
            num_create_events_by_date_df_2 \
            .select(
                col('date_created_at').alias('date_created_at_2'),
                col('count').alias('count_2'))

        joined_num_create_events_df = \
            num_create_events_by_date_df_1 \
            .withColumn(
                'last_week_date_created_at',
                date_add(num_create_events_by_date_df_1.date_created_at_1, -7)) \
            .join(
                num_create_events_by_date_df_2,
                col('last_week_date_created_at')
                == col('date_created_at_2'),
                how='left_outer')

        joined_num_create_events_df = joined_num_create_events_df.withColumn(
            'count_2', coalesce('count_2', 'count_1'))
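        # If there is no count from 7 days earlier, fall back to the same day's count so the rate is 0.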

        num_create_events_with_growth_rate_df = \
            joined_num_create_events_df \
            .withColumn(
                'weekly_increase_rate',
                ((joined_num_create_events_df.count_1 - joined_num_create_events_df.count_2) / joined_num_create_events_df.count_2)
            ) \
            .select(
                'date_created_at_1',
                'count_1',
                'weekly_increase_rate')

        num_create_events_with_growth_rate_df.show()

        return num_create_events_with_growth_rate_df
Example #18
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, lit, to_date, current_date, \
    current_timestamp, date_add, date_sub, datediff, months_between, to_timestamp, hour

if __name__ == '__main__':
    spark = SparkSession.builder.appName("learning").master(
        "local").getOrCreate()

    spark.range(5).withColumn('date', to_date(lit('2019-01-01'))).show()

    spark.read.jdbc

    spark.range(5)\
         .select(current_date().alias('date'), current_timestamp().alias('timestamp'))\
         .select(date_add(col('date'), 1), date_sub(col('timestamp'), 1)).show()

    spark.range(5).select(to_date(lit('2019-01-01')).alias('date1'),
                          to_date(lit('2019-01-05')).alias('date2'))\
                  .select(datediff(col('date2'), col('date1')),
                          months_between(col('date2'), col('date1'))).show()

    spark.range(5).withColumn('date', to_date(
        lit('2019-XX-XX'))).show()  # Does not raise an exception

    spark.range(5).withColumn('date_comp1', to_date(lit('2019-01-01')) > to_date(lit('2019-01-02'))) \
                  .withColumn('date_comp2', to_date(lit('2019-01-01')) > to_timestamp(lit('2019-01-02'))) \
        .withColumn('date_comp3', to_date(lit('2019-01-01')) > "2019-01-02") \
                  .withColumn('date_comp3', to_date(lit('2019-01-01')) > "'2019-01-02'").show()

    spark.range(5).select(current_timestamp().alias("timestamp")).select(
        hour(col('timestamp'))).show()
Example #19


# COMMAND ----------

from pyspark.sql.functions import current_date, current_timestamp
dateDF = spark.range(10)\
  .withColumn("today", current_date())\
  .withColumn("now", current_timestamp())
dateDF.createOrReplaceTempView("dateTable")


# COMMAND ----------

from pyspark.sql.functions import date_add, date_sub
dateDF.select(date_sub(col("today"), 5), date_add(col("today"), 5)).show(1)


# COMMAND ----------

from pyspark.sql.functions import datediff, months_between, to_date
dateDF.withColumn("week_ago", date_sub(col("today"), 7))\
  .select(datediff(col("week_ago"), col("today"))).show(1)

dateDF.select(
    to_date(lit("2016-01-01")).alias("start"),
    to_date(lit("2017-05-22")).alias("end"))\
  .select(months_between(col("start"), col("end"))).show(1)


# COMMAND ----------
Example #20

# In[10]:

## construct df_test based on the latest week's data
df_sales_date_max = processed_time_d[1] - timedelta(7)
df_test = df_sales.filter(
    df_sales.week_start == df_sales_date_max).withColumnRenamed(
        'week_start', 'week_start_origin')
## how many weeks to predict ahead: num_weeks_ahead
## note: when using the current simulator, num_weeks_ahead must stay 0 because of dependencies in the current version of the simulator
## note: with real data it is fine to change this parameter, to provide a suggested price several weeks ahead of the next week
num_weeks_ahead = 0
df_test = df_test.withColumn(
    'week_start', date_add(col('week_start_origin'),
                           7 * (num_weeks_ahead + 1)))
## only do optimization for stores in the treatment group
df_test = df_test.filter(col('group_val') == 'treatment').drop('group_val')
df_test = df_test.select([
    'store_id', 'product_id', 'department_id', 'brand_id', 'msrp', 'cost',
    'avg_hhi', 'avg_traffic', 'week_start'
]).distinct()

# In[21]:

################################################## 3: load the model built last time
dirfilename_load = spark.read.parquet(modelDir + "model_name").filter(
    '_1 == "model_name"').rdd.map(lambda p: p[1]).collect()
dirfilename_load = dirfilename_load[0]
rfModel = RandomForestModel.load(sc, dirfilename_load)
示例#21
0
# ## Save model in HDFS
stockRfrModel.write().overwrite().save('stockRfr.mdl')

##### TO DO !!!! #####


# # Model 2: Predict long and short positions based on historic analyst ratings
# ## Load in historical analyst ratings from Yahoo Finance!
ratings = RM.loadData(spark, 'finance/data/analystRatings.json')
ratings.printSchema()
ratings.show(5)

# ## Transform the data
modelData = ratings.join(priceChanges,
                         (F.date_add(ratings.Date,1) == priceChanges.Date) & \
                          (ratings.Symbol == priceChanges.Symbol))\
  .select(priceChanges.Date, 
          priceChanges.Symbol,
          priceChanges.PriceChange,
          'Action',
          'Research Firm',
          'From',
          'To')
modelData.persist(StorageLevel.MEMORY_ONLY)
modelData.show(5)

# ## Build a pipeline that does all feature transformations and model generation
stages = RM.engineerFeatures()
ratingsPipeline = Pipeline(stages=stages)
(trainingData, testData) = modelData.randomSplit([0.7, 0.3])
示例#22
0
import pyspark.sql.functions as F

# COMMAND ----------

# For the following dataframe, select "dt" as well as a new column "next_date" that adds one day to the date:
df = spark.createDataFrame([('2015-04-08', )], ['dt'])
# Expected:
# +----------+----------+
# |        dt| next_date|
# +----------+----------+
# |2015-04-08|2015-04-09|
# +----------+----------+

# Answer
df.select("dt", F.date_add(df.dt, 1).alias('next_date')).show()

# COMMAND ----------

# For the following dataframe, convert format to MM/dd/yyyy and alias it as "date"
df = spark.createDataFrame([('2015-04-08', )], ['dt'])
# Expected
# +----------+
# |      date|
# +----------+
# |04/08/2015|
# +----------+

# Answer
df.select(F.date_format('dt', 'MM/dd/yyyy').alias('date')).show()
示例#23
0
#maintSensorDailyPD = maintTypes.filter('year(date)>=2016')\
#  .select(maintTypes.date.alias('day'), 'maintenanceType')\
#  .join(dailyRawMeasurements, 'day')\
#  .select('maintenanceType', 'sensor_name', 'value')\
#  .toPandas()
#sb.swarmplot(y='sensor_name', x='value', hue='maintenanceType', data=maintSensorDailyPD)

# Let's plot the maintenance type against sensor readings right before we have an outage
progPrevMaint = maintTypes.filter('maintenanceType!="Corrective"')\
  .select('date', 'maintenanceType')\
  .join(dailyRawMeasurements, maintTypes.date==dailyRawMeasurements.day)\
  .select('date', 'maintenanceType', 'tag_id', 'value', 'sensor_name')

corrMaint = maintTypes.filter('maintenanceType=="Corrective"')\
  .select('date', 'maintenanceType')\
  .join(dailyRawMeasurements, maintTypes.date==F.date_add(dailyRawMeasurements.day, 1))\
  .select('date', 'maintenanceType', 'tag_id', 'value', 'sensor_name')

rawSensorsByMaint = progPrevMaint.union(corrMaint)

sb.stripplot(y='sensor_name', x='value', hue='maintenanceType', jitter=True,
             data=rawSensorsByMaint.filter('year(date)>=2016').select('sensor_name', 'value', 'maintenanceType').toPandas())

# ### Spark Machine Learning Capabilities
# The analysis and visualizations above indicate that there is some relationship between
# sensor readings and the type of maintenance that occurs. Thus we should be able to
# build a model to predict this.
# Spark has an extensive list of built-in machine learning algorithms, including:
# * Classification algorithms
#  * Binomial and multinomial logistic regression
#  * Decision tree and random forest classifiers
#  * Gradient-boosted tree classifiers
示例#24
0

'''Now we drop the year, month, day, hour, minute, date and time columns, since we will recreate them from the timestamp column we created'''
df_nycflights = df_nycflights. \
                drop('year'). \
                drop('month'). \
                drop('day'). \
                drop('hour'). \
                drop('minute'). \
                drop('date'). \
                drop('time')

df_nycflights.show() 

'''Now we extract the fields back'''
df_nycflights = df_nycflights. \
                withColumn('year',year(df_nycflights.timestamp)). \
                withColumn('month',month(df_nycflights.timestamp)). \
                withColumn('day',dayofmonth(df_nycflights.timestamp)). \
                withColumn('hour',hour(df_nycflights.timestamp)). \
                withColumn('minute',minute(df_nycflights.timestamp))  

df_nycflights.show()

'''Now a few operations on the timestamp'''
df_nycflights = df_nycflights.\
                withColumn('date_sub',date_sub(df_nycflights.timestamp ,10)). \
                withColumn('date_add',date_add(df_nycflights.timestamp ,10)). \
                withColumn('months_between',months_between(df_nycflights.timestamp,df_nycflights.timestamp))

df_nycflights.show()                 
示例#25
0
I have answered a similar question on stackoverflow: https://stackoverflow.com/questions/60426113/how-to-add-delimiters-to-a-csv-file/60428023#60428023. The reasoning for using an expression there was a little different, since the length of the string was not changing; it was rather out of laziness (I did not want to count the length of the string).
Example 2:
This example is actually straight out of a stackoverflow question I have answered: https://stackoverflow.com/questions/60494549/how-to-filter-a-column-in-a-data-frame-by-the-regex-value-of-another-column-in-s/60494657#60494657
Suppose you have a DataFrame with a column (query) of StringType that you have to apply a regexp_extract function to, and another column (regex_patt) which holds the pattern for that regex, row by row. If you didn't know how to make the regexp_extract function dynamic for each row, you would build a UDF taking the two columns as input and computing the regular expression for each row (which would be very slow and cost inefficient).
The question basically wants to filter out rows that do not match a given pattern.
The PySpark API has a built-in regexp_extract:
pyspark.sql.functions.regexp_extract(str, pattern, idx)
However, it only takes the str as a column, not the pattern. The pattern has to be specified as a static string value in the function. Therefore, we can use an expression to send a column to the pattern part of the function:
from pyspark.sql import functions as F
df.withColumn("query1", F.expr("""regexp_extract(query, regex_patt)""")).filter(F.col("query1")!='').drop("query1").show(truncate=False)
The expression, shown in bold, allows us to apply the regex row by row and filter out the non-matching rows; hence row 2 was removed by the filter.
As stated above, if you try to pass regex_patt as a column in the usual pyspark regexp_extract function syntax, you will get this error:
TypeError: Column is not iterable
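To make this concrete, here is a minimal, self-contained sketch of the pattern (the SparkSession setup and the two sample rows are invented purely for illustration; the column names query and regex_patt follow the example above):
from pyspark.sql import SparkSession
from pyspark.sql import functions as F

spark = SparkSession.builder.appName("expr_regexp_demo").getOrCreate()

# invented sample data: every row carries its own regex pattern
df = spark.createDataFrame(
    [("abcd-12345", r"(\d+)"), ("hello world", r"(\d+)")],
    ["query", "regex_patt"])

# expr() lets the pattern come from the regex_patt column; rows that do not
# match produce an empty string and are filtered out
df.withColumn("query1", F.expr("regexp_extract(query, regex_patt, 1)")) \
  .filter(F.col("query1") != '') \
  .drop("query1") \
  .show(truncate=False)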
Example 3:
Suppose you have the DataFrame shown below, with a loan_date column (DateType) and a days_to_make_payment column (IntegerType). You would like to compute the last date for payment, which basically means adding the days column to the date column to get the new date.
You can do this using Spark's built-in date_add function:
pyspark.sql.functions.date_add(start, days)
It returns the date that is the given number of days after start. However, this syntax only allows start to be a column, while days must be a static integer value. Hence, we can use an expression to send the days_to_make_payment column as days into our function.
from pyspark.sql import functions as F
df.withColumn("last_date_for_payment", F.expr("""date_add(Loan_date,days_to_make_payment)""")).show()
I would just like to reiterate one last time that if you had used the usual pyspark syntax and passed days_to_make_payment as days like this:
from pyspark.sql import functions as F
df.withColumn("last_date_for_payment", F.date_add(F.col("Loan_date"),F.col("days_to_make_payment"))).show()
You would get this error:
TypeError: Column is not iterable
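For completeness, here is a minimal, self-contained sketch of the working approach (the SparkSession setup and the sample rows are invented for illustration; the column names Loan_date and days_to_make_payment mirror the example above):
import datetime
from pyspark.sql import SparkSession
from pyspark.sql import functions as F

spark = SparkSession.builder.appName("expr_date_add_demo").getOrCreate()

# invented sample loans: each row has its own number of days to make the payment
df = spark.createDataFrame(
    [(datetime.date(2020, 1, 1), 10), (datetime.date(2020, 1, 15), 30)],
    ["Loan_date", "days_to_make_payment"])

# expr() lets the days argument of date_add come from a column,
# so each row gets its own offset
df.withColumn("last_date_for_payment",
              F.expr("date_add(Loan_date, days_to_make_payment)")).show()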
Conclusion:
Spark is the gold standard of big data processing engines, and it has a vast open source community contributing to it all the time. It has a plethora of functions that allow you to perform transformations at petabyte scale. That said, one should be well aware of its limitations when it comes to UDFs (which require moving data from the executor's JVM to a Python interpreter) and joins (which shuffle data across partitions/cores), and one should always try to push its built-in functions to their limits, as they are highly optimized and scalable for big data tasks.
示例#26
0
def get_item_historical_sales():
    # weekly sales of each cigarette product in each of the previous four weeks, by city / district
    try:

        # cigarette id and name
        plm_item = get_plm_item(spark).select("item_id", "item_name")

        area = get_area(spark)
        # mapping from com_id to city
        city = area.dropDuplicates(["com_id"]).select("com_id", "city")
        # mapping from sale_center_id to its list of counties
        county = area.groupBy("sale_center_id") \
            .agg(f.collect_list("county").alias("county")) \
            .select("sale_center_id", "county")

        # values of the marker column
        markers = ["1", "3"]
        # aggregate by city or by district (sale center)
        groups = ["com_id", "sale_center_id"]
        joins = [city, county]
        # columns other than the value to be computed
        cols_comm = [["city", "gauge_id", "gauge_name", "ciga_data_marker"],
                     [
                         "county", "sale_center_id", "gauge_id", "gauge_name",
                         "ciga_data_marker"
                     ]]
        # column names of the values to be computed
        cols = ["gauge_city_sales_history", "gauge_county_sales_history"]

        # fetch the order-line data for the previous four weeks
        # date is the Friday of the week the order falls in, for front-end display
        co_co_line = get_co_co_line(spark, scope=[1, 4], filter="week") \
            .select("item_id", "com_id","sale_center_id", "qty_ord", "born_date") \
            .withColumn("date", f.date_add(f.date_trunc("week", col("born_date")), 4))
        co_co_line.cache()
        for i in range(len(groups)):
            group = groups[i]
            join = joins[i]
            c = cols[i]
            marker = markers[i]
            columns = cols_comm[i]
            print(f"{str(dt.now())} {group} 每款卷烟上四周各周的销量")
            try:
                json_udf = f.udf(lambda x, y: json.dumps({
                    "date": str(x),
                    "value": y
                }))
                # compute the weekly sales of each cigarette over the previous four weeks for each city / district
                # assemble the result as [{"date":"2019-06-21","value":12354},{"date":"2019-06-14","value":14331}, ...]
                result=co_co_line.groupBy(group,"item_id","date")\
                          .agg(f.sum(col("qty_ord")).alias("qty_ord"))\
                          .withColumn("json",json_udf(col("date"),col("qty_ord")))\
                          .groupBy(group,"item_id")\
                          .agg(f.collect_list(col("json")).alias(c))

                columns.append(c)
                result.withColumn("row", f.concat_ws("_", col(group), col("item_id"))) \
                    .join(plm_item, "item_id") \
                    .join(join, group) \
                    .withColumnRenamed("item_id", "gauge_id") \
                    .withColumnRenamed("item_name", "gauge_name") \
                    .withColumn("ciga_data_marker", f.lit(marker)) \
                    .foreachPartition(lambda x: write_hbase1(x, columns, hbase))
            except Exception:
                tb.print_exc()
        co_co_line.unpersist()
    except Exception:
        tb.print_exc()
示例#27
0
flat = flat.withColumn('absolute_margin',
                       flat.product_net_cost - flat.product_net_revenue)
flat = flat.withColumn('percentage_margin',
                       flat.absolute_margin / flat.product_net_cost)
flat = flat.withColumn('average_price_per_item',
                       flat.product_net_revenue / flat.item_sold)

flat = flat.select(
    '*',
    F.unix_timestamp(flat['order_date'].cast('string'),
                     'yyyyMMdd').cast('timestamp').alias('order_date_iso'))

avg_day = flat.groupBy('order_date_iso').agg(F.avg('average_price_per_item'))

avg_day = avg_day.withColumn(
    'next_day',
    F.date_add(avg_day.order_date_iso, 1).cast('timestamp')).cache()
avg_day = avg_day.withColumnRenamed('avg(average_price_per_item)',
                                    'average_price_per_item_yesterday')

flat = flat.join(avg_day, flat.order_date_iso == avg_day.next_day).drop(
    avg_day.next_day).drop(flat.order_date_iso)
# to do: transform this into a left join and take care of missing values

flat = flat.withColumn(
    'var_average_price_per_item',
    (flat.average_price_per_item / flat.average_price_per_item_yesterday) - 1)
"""
avg_day.show()
+-------------------+--------------------------------+-------------------+      
|     order_date_iso|average_price_per_item_yesterday|           next_day|
+-------------------+--------------------------------+-------------------+
示例#28
0
string_with_space = "     hello     "

df.select(ltrim(lit(string_with_space)), rtrim(lit(string_with_space)),
          trim(lit(string_with_space))).show()

#regular expressions

#working with dates, timestamps
from pyspark.sql.functions import current_date, current_timestamp, date_add, date_sub, datediff, months_between, to_date, to_timestamp

dateDF = spark.range(10).withColumn("today", current_date()).withColumn(
    "now", current_timestamp())
dateDF.show()

dateDF.select(
    date_add(col("today"), 5).alias("today+5"),
    date_sub(col("today"), 5).alias("today-5")).show()

#convert string to date, default format is 'yyyy-MM-dd'
spark.range(1).select(
    to_date(lit("2019-02-01")).alias("start_date"),
    to_date(lit("2019-03-06")).alias("end_date")).select(
        datediff(col("start_date"), col("end_date"))).show()

cleanDateDF = spark.range(1).withColumn("date1", current_date())
date_format = 'yyyy-MM-dd'
cleanDateDF.select(to_timestamp(col("date1"), date_format)).show()

#working with nulls in data

#coalesce returns the first non-null value from a set of columns
示例#29
0
def clear_spark_data(data):
    drop_pool = [
        'lm_stu_cancel_cnt', 'near_stu_cancel_distance', 'gift_gr13_hours',
        'gift_gr15_hours', 'first_new_distance', 'first_subjoin_distance'
    ]
    for col in drop_pool:
        data = data.drop(col)
    data = data.filter(data['age'].between(2, 18))
    data = data.filter(data['left_hours'] >= 0)
    data = data.filter(data['left_hours'].isNotNull())
    data = data.filter(data['left_hours'] != np.nan)

    data = data.filter(data['level_sequence'] >= 0)
    data = data.filter(data['level_sequence'].isNotNull())
    data = data.filter(data['level_sequence'] != np.nan)

    data = data.filter(data['register_distance'] > 0)
    data = data.filter(data['register_distance'].isNotNull())
    data = data.filter(data['register_distance'] != np.nan)

    data = data.filter(data['entry_distance'] > 0)
    data = data.filter(data['entry_distance'].isNotNull())
    data = data.filter(data['entry_distance'] != np.nan)

    data = data.filter(data['paid_success_cnt'] > 0)
    data = data.filter(data['paid_success_cnt'].isNotNull())
    data = data.filter(data['paid_success_cnt'] != np.nan)

    data = data.fillna({
        'city_level':
        8,
        'info_length':
        0,
        'l6m_hw_avg_score':
        0,
        'l6m_hw_avg_time':
        0,
        'l6m_avg_bad_rating':
        data.select('l6m_avg_bad_rating').toPandas().median()[0]
    })

    for x in data.columns:
        if 'cnt' in x or 'hours' in x:
            data = data.fillna({x: 0})
        elif 'distance' in x:
            data = data.fillna({x: 2000})
        else:
            pass

    data = data.withColumn(
        'all_paid_hours', data['paid_new_hours'] + data['paid_subjoin_hours'] +
        data['paid_renew_hours'])
    data = data.filter(data['all_paid_hours'] > 0)
    data = data.filter(data['all_paid_hours'].isNotNull())
    data = data.filter(data['all_paid_hours'] != np.nan)

    data = data.withColumn('now_during', fn.when(fn.dayofmonth('point_36_date')<=10,1)\
                           .when(fn.dayofmonth('point_36_date')<=20,2)\
                           .otherwise(3))
    data = data.withColumn('will_renew_date',
                           fn.date_add(data['point_36_date'], 11))
    data = data.withColumn('renew_during', fn.when(fn.dayofmonth('will_renew_date')<=10,1)\
                           .when(fn.dayofmonth('will_renew_date')<=20,2)\
                           .otherwise(3))#avg(datediff(point_renew_date,point_36_date))
    data = data.drop('will_renew_date')

    data = data.withColumn(
        'register2new_distance',
        data['register_distance'] - data['last_new_distance'])
    data = data.withColumn('new2entry_distance',
                           data['last_new_distance'] - data['entry_distance'])
    # data = data.withColumn('new2subjoin_distance', data['last_new_distance'] - data['last_subjoin_distance'])#TODO
    # data = data.withColumn('subjoin2renew_distance', data['last_subjoin_distance'] - data['first_renew_distance'])#TODO
    data = data.drop('last_subjoin_distance')

    data = data.withColumn(
        'f2l_paid_distance',
        data['first_paid_distance'] - data['last_paid_distance'])
    data = data.withColumn(
        'f2l_paid_avg_distance',
        data['f2l_paid_distance'] / data['paid_success_cnt'])

    data = data.withColumn(
        'all_gift_hours',
        data['gift_positive_hours'] + data['gift_negative_hours'])
    data = data.withColumn(
        'gift_positive_percent',
        data['gift_positive_hours'] / (data['all_gift_hours'] + 1))
    data = data.withColumn(
        'gift_negative_percent',
        data['gift_negative_hours'] / (data['all_gift_hours'] + 1))
    data = data.withColumn('gift_PN_rate', (data['gift_positive_hours'] + 1) /
                           (data['gift_negative_hours'] + 1))
    data = data.withColumn('left_hours_percent_paid',
                           data['left_hours'] / data['all_paid_hours'])
    data = data.withColumn(
        'left_hours_percent_all',
        data['left_hours'] / (data['all_paid_hours'] + data['all_gift_hours']))
    data = data.withColumn('left_hours_percent_last_paid',
                           data['left_hours'] / data['last_paid_hours'])
    data = data.withColumn(
        'lpd2n_noconsume_class_cnt',
        (data['lastpaid2now_class2_hours'] - data['lastpaid2now_class1_hours'])
        / data['last_paid_distance'])
    data = data.withColumn(
        'lpd2n_abs_diff_left_hours',
        fn.abs(data['last_left_hours'] - data['left_hours']))
    data = data.withColumn(
        'lpd2n_isless_left_hours',
        fn.when(data['last_left_hours'] > data['left_hours'], 1).otherwise(0))
    data = data.withColumn(
        'lpd2n_day_class_hours',
        (data['last_paid_hours'] + data['last_left_hours']) /
        data['last_paid_distance'])
    data = data.withColumn(
        'will_zero_hours_distance', data['left_hours'] /
        ((data['lastpaid2now_class1_hours'] + 1) / data['last_paid_distance']))
    data = data.withColumn('gift_gr1_hours_percent_entry_distance',
                           data['gift_gr1_hours'] / data['entry_distance'])
    data = data.withColumn('student_type', fn.when(data['student_type']=='NORMAL',1).when(data['student_type']=='VIP',2)\
                           .when(data['student_type']=='KOL',3)\
                           .otherwise(0))

    data = data.drop('lastpaid2now_class2_hours')
    data = data.drop('point_36_date')
    return data
def data_loader_base(spark_ctx):
    hive_ctx = spark_ctx.hive_ctx()

    demo_cols = spark_ctx.parallelize(demographic_respondent.split('\n')).map(
        lambda x: Row(dat=x.split(','))).map(lambda x: Row(
            demographic_key=x.dat[2], is_lookup=int(x.dat[4]))).collect()

    is_lookup_keys = []
    for key in [x.demographic_key for x in demo_cols if x.is_lookup == 1]:
        is_lookup_keys.append(key)

    value_mapper = spark_ctx.parallelize(
        csv.reader(demographic_respondent_value.splitlines(),
                   quotechar='"',
                   delimiter=',',
                   quoting=csv.QUOTE_ALL,
                   skipinitialspace=True)
    ).map(lambda x: Row(dat=x)).map(
        lambda x: Row(demographic_value_key='_'.join([
            str(x.dat[0]),
            str(int(x.dat[1])) if re.match(r'\d', str(x.dat[1])) else x.dat[1]
        ]),
                      demographic_key=str(x.dat[0]),
                      demographic_value_sequence=x.dat[3])).map(lambda x: {
                          x.demographic_value_key:
                          [x.demographic_key, x.demographic_value_sequence]
                      }).collect()

    base_mapper = {}
    for m in value_mapper:
        demographic_value_key = list(m.keys())[0]
        demographic_key = list(m.values())[0][0]
        demographic_value_sequence = list(m.values())[0][1]
        if demographic_key in is_lookup_keys:
            base_mapper[demographic_value_key] = demographic_value_sequence

    tv_resp_df = hive_ctx.table('.'.join(
        [dwh_name, dwh_tv_respondent.table.name])).coalesce(1)

    person_df = hive_ctx.table('.'.join(
        [dwh_name, dwh_person_deduplicated.table.name]))

    hh_df = hive_ctx.table('.'.join(
        [dwh_name, dwh_household_deduplicated.table.name]))

    columns_altered = [
        'age', 'agerange', 'agerangedar', 'agegenderbuildingblockcode',
        'visitorstatuscode', 'workinghours',
        'relationshiptoheadofhouseholdcode', 'internetusagehome',
        'internetusagework', 'num_of_years_spent_in_the_usa',
        'ladyofhouseoccupationcode', 'numberoftvsets', 'numberoftvsetswithpay',
        'numberoftvsetswithwiredcable', 'numberoftvsetswithwiredcableandpay',
        'numberofvcrs',
        'headofhouseholdhispanicspecificethnicityeffectivejune282010',
        'numberofcars', 'householdincomeamount',
        'relationshiptoheadofhouseholdcode_af_20180101', 'broadcastonly',
        'zerotv'
    ]

    columns_not_altered = [
        x.demographic_key.lower() for x in demo_cols
        if x.demographic_key.lower() not in columns_altered
    ]

    demographicdata = tv_resp_df.alias('t').join(
        person_df.alias('p'),
        on=[
            f.col('t.household_id') == f.col('p.household_id'),
            f.col('t.person_id') == f.col('p.person_id')
        ]).join(
            hh_df.alias('h'),
            on=[
                f.col('p.household_id') == f.col('h.household_id'),
                f.col('h.start_date') <= f.col('p.end_date'),
                f.col('h.end_date') >= f.col('p.start_date')
            ]).withColumn(
                'startdate',
                f.when(
                    f.col('h.start_date') > f.col('p.start_date'),
                    f.col('h.start_date')).otherwise(f.col('p.start_date'))
            ).withColumn(
                'enddate',
                f.when(
                    f.col('h.end_date') < f.col('p.end_date'),
                    f.col('h.end_date')).otherwise(f.col('p.end_date'))
            ).select(
                f.col('tv_respondent_id'),
                f.col('startdate').alias('start_date'),
                f.col('enddate').alias('end_date'),
                f.when(f.col('p.age') == '999',
                       f.lit(None).astype('string')).otherwise(
                           f.col('p.age')).alias('age'),
                f.when(f.col('p.age').astype('int').between(
                    2, 5), f.lit(19)).when(
                        f.col('p.age').astype('int').between(6,
                                                             8),
                        f.lit(20)).when(
                            f.col('p.age').astype('int').between(9,
                                                                 11),
                            f.lit(21)).when(
                                f.col('p.age').astype('int').between(12, 14),
                                f.lit(22)).when(
                                    f.col('p.age').astype('int').between(
                                        15, 17),
                                    f.lit(23)).when(
                                        f.col('p.age').astype('int') == 18,
                                        f.lit(2)).when(
                                            f.col('p.age').astype('int') == 19,
                                            f.lit(3)).
                when(f.col('p.age').astype('int') == 20, f.lit(4)).when(
                    f.col('p.age').astype('int') == 21, f.lit(5)).when(
                        f.col('p.age').astype('int').between(22,
                                                             24),
                        f.lit(6)).when(
                            f.col('p.age').astype('int').between(25,
                                                                 29),
                            f.lit(7)).when(
                                f.col('p.age').astype('int').between(30, 34),
                                f.lit(8)).
                when(f.col('p.age').astype('int').between(35, 39),
                     f.lit(9)).when(
                         f.col('p.age').astype('int').between(40,
                                                              44),
                         f.lit(10)).when(
                             f.col('p.age').astype('int').between(45,
                                                                  49),
                             f.lit(11)).when(
                                 f.col('p.age').astype('int').between(50, 54),
                                 f.lit(12)).
                when(f.col('p.age').astype('int').between(55, 59),
                     f.lit(13)).when(
                         f.col('p.age').astype('int').between(60,
                                                              64),
                         f.lit(14)).when(
                             f.col('p.age').astype('int').between(65,
                                                                  69),
                             f.lit(15)).when(
                                 f.col('p.age').astype('int').between(70, 74),
                                 f.lit(16)).when(
                                     f.col('p.age').astype('int').between(
                                         75, 85), f.lit(17)).when(
                                             f.col('p.age').astype('int') > 85,
                                             f.lit(18)).alias('agerange'),
                f.when(f.col('p.age').astype('int').between(2, 17),
                       f.lit(1)).when(
                           f.col('p.age').astype('int').between(18, 20),
                           f.lit(2)).when(
                               f.col('p.age').astype('int').between(21, 24),
                               f.lit(3)).when(
                                   f.col('p.age').astype('int').between(
                                       25, 29), f.lit(4)).
                when(f.col('p.age').astype('int').between(30, 34),
                     f.lit(5)).when(
                         f.col('p.age').astype('int').between(35, 39),
                         f.lit(6)).when(
                             f.col('p.age').astype('int').between(40, 44),
                             f.lit(7)).when(
                                 f.col('p.age').astype('int').between(45, 49),
                                 f.lit(8)).when(
                                     f.col('p.age').astype('int').between(
                                         50, 54),
                                     f.lit(9)).when(
                                         f.col('p.age').astype('int').between(
                                             55, 64),
                                         f.lit(10)).when(
                                             f.col('p.age').astype('int') > 64,
                                             f.lit(11)).alias('agerangedar'),
                f.lit(9999).alias('agegenderbuildingblockcode'),
                f.lit(9999).alias('visitorstatuscode'),
                f.col('p.workinghours').astype('int').alias('workinghours'),
                f.when(
                    f.col('startdate') < f.to_date(f.lit('2018-01-01')),
                    f.col('p.relationshiptoheadofhouseholdcode')).otherwise(
                        f.lit(None).astype('string')).alias(
                            'relationshiptoheadofhouseholdcode'),
                f.col('p.internetusagehome').astype('int').alias(
                    'internetusagehome'),
                f.col('p.internetusagework').astype('int').alias(
                    'internetusagework'),
                f.col('p.number_of_years_spent_in_the_united_states').alias(
                    'num_of_years_spent_in_the_usa'),
                f.col('h.ladyofhouseoccupationcode').astype('int').alias(
                    'ladyofhouseoccupationcode'),
                f.col('h.numberoftvsets').astype('int').alias(
                    'numberoftvsets'),
                f.col('h.numberoftvsetswithpay').astype('int').alias(
                    'numberoftvsetswithpay'),
                f.col('h.numberoftvsetswithwiredcable').astype('int').alias(
                    'numberoftvsetswithwiredcable'),
                f.col('h.numberoftvsetswithwiredcableandpay').astype(
                    'int').alias('numberoftvsetswithwiredcableandpay'),
                f.col('h.numberofvcrs').astype('int').alias('numberofvcrs'),
                f.col('h.headofhouseholdhispanicspecificethnicity').alias(
                    'headofhouseholdhispanicspecificethnicityeffectivejune282010'
                ),
                f.col('h.numberofcars').astype('int').alias('numberofcars'),
                f.col('h.householdincomeamount').astype('int').alias(
                    'householdincomeamount'),
                f.when(
                    f.col('startdate') >= f.to_date(f.lit('2018-01-01')),
                    f.col('p.relationshiptoheadofhouseholdcode')).otherwise(
                        f.lit(None).astype('string')).alias(
                            'RelationshipToHeadOfHouseholdCode_af_20180101'),
                f.col('h.broadcast_only').alias('broadcastonly'),
                f.col('t.zero_tv').alias('zerotv'),
                *columns_not_altered).withColumn(
                    'rn',
                    f.row_number().over(
                        w.partitionBy('tv_respondent_id',
                                      'start_date').orderBy(
                                          f.col('end_date').desc()))).filter(
                                              f.col('rn') == 1)

    result = demographicdata.select(
        f.col('tv_respondent_id'), f.col('start_date'),
        f.date_add(
            f.lead(f.col('start_date'), 1, '9999-12-31').over(
                w.partitionBy('tv_respondent_id').orderBy('start_date')),
            -1).alias('end_date'),
        f.array(*create_values(demo_cols)).alias('values'))

    max_weight_date = hive_ctx.table('.'.join([
        dwh_name, dwh_tv_respondent_weight.table.name
    ])).groupBy('tv_respondent_id').agg(
        f.max(f.col('date')).astype('date').alias('date'))

    return max_weight_date.alias('md').join(
        result.alias('r'), on=['tv_respondent_id']).select(
            f.col('tv_respondent_id'),
            f.when(
                f.to_date(f.lit('2013-12-30')).between(f.col('r.start_date'),
                                                       f.col('r.end_date')),
                f.to_date(f.lit('2013-12-30'))).otherwise(
                    f.col('r.start_date')).alias('start_date'),
            f.when(
                f.col('md.date').between(f.col('r.start_date'),
                                         f.col('r.end_date')),
                f.col('md.date')).otherwise(
                    f.col('r.end_date')).alias('end_date'), f.col('values')
        ).filter(f.col('end_date') >= f.to_date(f.lit('2013-12-30'))).rdd.map(
            lambda x: Row(tv_respondent_id=int(x.tv_respondent_id),
                          end_date=x.end_date,
                          values=list(
                              map(
                                  lambda v: base_mapper[v] if v in base_mapper
                                  else v, x.values)),
                          start_date=x.start_date)).toDF().repartition(100)