Example #1
    def run(self,
            model_file_path: str,
            predict_data_dir_path: str,
            user_data_file_path: str,
            item_data_file_path: str,
            processed_training_data_file_path: str,
            data_limit: int = -1) -> bool:
        """execute."""
        # check parameter
        self.__param_check(model_file_path, predict_data_dir_path,
                           user_data_file_path, item_data_file_path,
                           processed_training_data_file_path, data_limit)

        # make spark context
        spark = SparkSession\
            .builder\
            .appName('create_predicted_score')\
            .config('spark.sql.crossJoin.enabled', 'true')\
            .config('spark.debug.maxToStringFields', 500)\
            .getOrCreate()
        sqlContext = SQLContext(sparkContext=spark.sparkContext,
                                sparkSession=spark)

        # load user data
        users_df = sqlContext\
            .read\
            .format('csv')\
            .options(header='false')\
            .load(user_data_file_path)
        users_id_rdd = users_df.rdd.map(lambda l: Row(user_id=l[0]))
        users_id_df = sqlContext.createDataFrame(users_id_rdd)

        # load item data
        items_df = sqlContext\
            .read\
            .format('csv')\
            .options(header='false')\
            .load(item_data_file_path)
        items_id_rdd = items_df.rdd.map(lambda l: Row(item_id=l[0]))
        items_id_df = sqlContext.createDataFrame(items_id_rdd)

        # cross join user_id and item_id
        joined_df = users_id_df.join(items_id_df)
        joined_df.cache()

        # delete unnecessary variables
        del users_df
        del users_id_rdd
        del users_id_df
        del items_df
        del items_id_rdd
        del items_id_df

        # load training data
        custom_schema = StructType([
            StructField('user', StringType(), True),
            StructField('item', StringType(), True),
            StructField('rating', FloatType(), True),
            StructField('unique_user_id', IntegerType(), True),
            StructField('unique_item_id', IntegerType(), True),
        ])
        training_df = sqlContext\
            .read\
            .format('csv')\
            .options(header='true')\
            .load(processed_training_data_file_path, schema=custom_schema)
        # users
        unique_users_rdd = training_df.rdd.map(lambda l: [l[0], l[3]])
        unique_users_df = sqlContext.createDataFrame(
            unique_users_rdd, ('user', 'unique_user_id')).dropDuplicates()
        unique_users_df.cache()
        # items
        unique_items_rdd = training_df.rdd.map(lambda l: [l[1], l[4]])
        unique_items_df = sqlContext.createDataFrame(
            unique_items_rdd, ('item', 'unique_item_id')).dropDuplicates()
        unique_items_df.cache()

        # delete unnecessary variables
        del training_df
        del unique_users_rdd
        del unique_items_rdd

        # add unique user id
        joined_df = joined_df.join(
            unique_users_df, joined_df['user_id'] == unique_users_df['user'],
            'inner').drop(unique_users_df['user'])

        # add unique item id
        joined_df = joined_df.join(
            unique_items_df, joined_df['item_id'] == unique_items_df['item'],
            'inner').drop(unique_items_df['item'])

        # load model
        model = ALSModel.load(model_file_path)

        # predict score
        predictions = model.transform(joined_df)
        all_predict_data = predictions\
            .select('user_id', 'item_id', 'prediction')\
            .filter('prediction > 0')

        # save
        ymd = datetime.today().strftime('%Y%m%d')
        # all score
        saved_data_file_path = predict_data_dir_path \
            + 'als_predict_data_all_%s.csv' % (ymd)
        all_predict_data.write\
            .format('csv')\
            .mode('overwrite')\
            .options(header='true')\
            .save(saved_data_file_path)
        copied_data_file_path = predict_data_dir_path + 'als_predict_data_all.csv'
        all_predict_data.write\
            .format('csv')\
            .mode('overwrite')\
            .options(header='true')\
            .save(copied_data_file_path)

        # limited score
        data_limit = int(data_limit)
        if data_limit > 0:
            all_predict_data.createOrReplaceTempView('predictions')
            sql = 'SELECT user_id, item_id, prediction ' \
                + 'FROM ( ' \
                + '  SELECT user_id, item_id, prediction, dense_rank() ' \
                + '  OVER (PARTITION BY user_id ORDER BY prediction DESC) AS rank ' \
                + '  FROM predictions ' \
                + ') tmp WHERE rank <= %d' % (data_limit)
            limited_predict_data = sqlContext.sql(sql)
        else:
            limited_predict_data = all_predict_data

        saved_data_file_path = predict_data_dir_path \
            + 'als_predict_data_limit_%s.csv' % (ymd)
        limited_predict_data.write\
            .format('csv')\
            .mode('overwrite')\
            .options(header='true')\
            .save(saved_data_file_path)
        copied_data_file_path = predict_data_dir_path + 'als_predict_data_limit.csv'
        limited_predict_data.write\
            .format('csv')\
            .mode('overwrite')\
            .options(header='true')\
            .save(copied_data_file_path)

        return True
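
# Side note on Example #1: the per-user top-N limit computed above with the dense_rank()
# SQL can also be expressed with the DataFrame/Window API. A minimal sketch, assuming the
# same `all_predict_data` and `data_limit` variables used in the method above:
#
#     from pyspark.sql import Window
#     from pyspark.sql import functions as F
#
#     rank_window = Window.partitionBy('user_id').orderBy(F.col('prediction').desc())
#     limited_predict_data = all_predict_data \
#         .withColumn('rank', F.dense_rank().over(rank_window)) \
#         .filter(F.col('rank') <= data_limit) \
#         .drop('rank')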
Example #2
        .load(absolute_file_path_xml)
count_xml = df_xml.count()

df_parquet = spark.read.format("parquet") \
        .option("header", "true") \
        .option("multiline", True) \
        .option("sep", ";") \
        .option("quote", "*") \
        .option("mode", "DROPMALFORMED") \
        .option("dateFormat", "MM/dd/yyyy") \
        .option("inferSchema", True) \
        .load(absolute_file_path_parquet)
count_parquet = df_parquet.count()

schema = StructType([
    StructField('WeatherStation', StringType(), False),
    StructField('WBAN', StringType(), True),
    StructField('ObservationDate', StringType(), False),
    StructField('ObservatioonHous', IntegerType(), True),
    StructField('Latitude', StringType(), True),
    StructField('Longitude', StringType(), True),
    StructField('Elevation', IntegerType(), True),
    StructField('WinDirection', IntegerType(), True),
    StructField('WDQualityCode', IntegerType(), True),
    StructField('SkyCeilingHeight', IntegerType(), True),
    StructField('SCQualityCode', IntegerType(), True),
    StructField('VisibilityDistance', IntegerType(), True),
    StructField('VDQualityCode', IntegerType(), True),
    StructField('AirTemperature', FloatType(), True),
    StructField('ATQualityCode', IntegerType(), True),
    StructField('DewPoint', FloatType(), True),
#Data Exploration :
#Data Visualisation: Databricks cloud dataviz  (or Matplotlib in Jupyter)
#Machine Learning : Spark Mllib ( ML packages)

# COMMAND ----------

display(dataset.limit(3))

# COMMAND ----------

from pyspark.sql.types import StructType, StructField, IntegerType, StringType, BooleanType

# COMMAND ----------

#Create a schema for a Pyspark Dataframe Api
fireSchema = StructType([StructField('CallNumber', IntegerType(), True),
                     StructField('UnitID', StringType(), True),
                     StructField('IncidentNumber', IntegerType(), True),
                     StructField('CallType', StringType(), True),                  
                     StructField('CallDate', StringType(), True),       
                     StructField('WatchDate', StringType(), True),       
                     StructField('ReceivedDtTm', StringType(), True),       
                     StructField('EntryDtTm', StringType(), True),       
                     StructField('DispatchDtTm', StringType(), True),       
                     StructField('ResponseDtTm', StringType(), True),       
                     StructField('OnSceneDtTm', StringType(), True),       
                     StructField('TransportDtTm', StringType(), True),                  
                     StructField('HospitalDtTm', StringType(), True),       
                     StructField('CallFinalDisposition', StringType(), True),       
                     StructField('AvailableDtTm', StringType(), True),       
                     StructField('Address', StringType(), True),       
conf = SparkConf().setAppName("RAPIDS_Accelerator_Spark_XGBoost_test")
conf.set("spark.executor.instances", "1")
conf.set("spark.executor.cores", "1")
conf.set("spark.task.cpus", "1")
conf.set("spark.executor.memory", "2g")
conf.set("spark.task.resource.gpu.amount", "1")
conf.set("spark.plugins", "com.nvidia.spark.SQLPlugin")
conf.set("spark.rapids.memory.gpu.pooling.enabled", "false")
spark = SparkSession.builder \
                    .config(conf=conf) \
                    .getOrCreate()

label = 'l'
schema = StructType([
    StructField('c0', FloatType()),
    StructField('c1', FloatType()),
    StructField(label, IntegerType()),
])

features = [x.name for x in schema if x.name != label]
df = spark.createDataFrame(
    [
        (1.05, 9.05, 0),  # create your data here, be consistent in the types.
        (2.95, 1.95, 1),
    ],
    schema)

params = {
    'missing': 0.0,
    'treeMethod': 'gpu_hist',
    # Using format()
    str += '{0}'
    list = [str.format(i) for i in list]
    return list


start_date = '2017,1,4'
end_date = '2017,1,5'
deployment = 'cbr'
filter_list = list_all_file(deployment, start_date, end_date)
filter_list = prepend(
    filter_list,
    "/Users/datami/anshuman/sd-sync-process/s3-data/" + deployment + "/accmi/")

schema = StructType([
    StructField("crId", StringType(), True),
    StructField("op", StringType(), True),
    StructField("trace", StringType(), True),
    StructField("res", StringType(), True),
    StructField("ts", StringType(), True),
    StructField("uri", StringType(), True),
    StructField("version", StringType(), True),
    StructField("req.aStore", StringType(), True)

    # StructField("req.aStore", StringType(), True),
    # StructField("req.apikey", StringType(), True),
    # StructField("req.appId", StringType(), True),
    # StructField("req.appVer", StringType(), True),
    # StructField("req.callState", StringType(), True),
    # StructField("req.clientIpAddress", StringType(), True),
    # StructField("req.correlationId", StringType(), True),
Example #6
                                       "io.delta:delta-core_2.12:0.7.0,"
                                       "org.postgresql:postgresql:42.2.19") \
        .config("spark.driver.extraJavaOptions", "-Dlog4j.configuration=file:log4j.properties "
                                                 "-Dspark.yarn.app.container.log.dir=app-logs "
                                                 "-Dlogfile.name=hello-spark") \
        .config("spark.sql.extensions", "io.delta.sql.DeltaSparkSessionExtension")\
        .config("spark.sql.catalog.spark_catalog", "spark.sql.catalog.spark_catalog")\
        .getOrCreate()

    # conf_out = spark.sparkContext.getConf()
    # print(conf_out.toDebugString())

    logger = Log4j(spark)

    schema = StructType([
        StructField("InvoiceNumber", StringType()),
        StructField("CreatedTime", LongType()),
        StructField("StoreID", StringType()),
        StructField("PosID", StringType()),
        StructField("CashierID", StringType()),
        StructField("CustomerType", StringType()),
        StructField("CustomerCardNo", StringType()),
        StructField("TotalAmount", DoubleType()),
        StructField("NumberOfItems", IntegerType()),
        StructField("PaymentMethod", StringType()),
        StructField("CGST", DoubleType()),
        StructField("SGST", DoubleType()),
        StructField("CESS", DoubleType()),
        StructField("DeliveryType", StringType()),
        StructField(
            "DeliveryAddress",
Example #7
spark.sparkContext.setLogLevel("WARN")

stediEventsRawDF = spark \
    .readStream \
    .format("kafka") \
    .option("kafka.bootstrap.servers", KAFKA_HOST) \
    .option("subscribe", "stedi-events") \
    .option("startingOffsets", "earliest") \
    .load()

# parse the JSON from the single column "value" with a json object in it, like this:

stediEventsDF = stediEventsRawDF.selectExpr("cast(value as string) as value")

stediEventsSchema = StructType([
    StructField("customer", StringType()),
    StructField("score", FloatType()),
    StructField("riskDate", DateType())
])

# storing them in a temporary view called CustomerRisk

stediEventsDF \
    .withColumn("value", F.from_json("value", stediEventsSchema)) \
    .select(F.col("value.customer"), F.col("value.score"), F.col("value.riskDate")) \
    .createOrReplaceTempView("CustomerRisk")

# Execute a sql statement against a temporary view, selecting the customer and the
# score from the temporary view, creating a dataframe called customerRiskStreamingDF

customerRiskStreamingDF = spark.sql("select customer, score from CustomerRisk")
Example #8
# Output target: database
conn_param = {}
conn_param['user'] = '******'
conn_param['password'] = '******'
conn_param['driver'] = "com.mysql.jdbc.Driver"

spark = SparkSession \
    .builder \
    .appName("Python Spark SQL basic example") \
    .config("spark.some.config.option", "some-value") \
    .getOrCreate()

sc = spark.sparkContext

# Build the schema for the data
field1 = StructField("director", StringType(), nullable=True)
field2 = StructField("box", FloatType(), nullable=True)
fields = [field1, field2]
schema = StructType(fields)

# Read the data from HDFS
loan_info = sc.textFile(readFileName)

# Convert the RDD into a DataFrame
loan_info_row = loan_info.map(lambda line: line.split('@'))
loan_info_row_rdd = loan_info_row.map(
    lambda attributes: Row(attributes[3], float(eval(attributes[7]))))
loan_info_df = spark.createDataFrame(loan_info_row_rdd, schema)

# Process the data
result_df = loan_info_df.groupBy("director").sum()
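
# The snippet ends before the database write that the conn_param block above prepares for.
# A minimal sketch of that final step (the JDBC URL and table name are assumptions, not
# taken from the source):
result_df.write.jdbc(url='jdbc:mysql://localhost:3306/test',
                     table='director_box_office',
                     mode='overwrite',
                     properties=conn_param)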
Example #9
                         "io.delta.sql.DeltaSparkSessionExtension")\
          .getOrCreate()

meteoDataFrame  = spark.read.format('csv')\
    .option('sep',';')\
    .option('header','true')\
    .option('nullValue','mq')\
    .option('inferSchema', 'true')\
    .load('donnees/meteo')\
    .cache()

meteoDataFrame.columns
meteoDataFrame.printSchema()

schema = StructType([
    StructField('Id', StringType(), True),
    StructField('ville', StringType(), True),
    StructField('latitude', FloatType(), True),
    StructField('longitude', FloatType(), True),
    StructField('altitude', IntegerType(), True)
])

villes  = spark.read.format('csv')   \
      .option('sep',';')                \
      .option('mergeSchema', 'true')    \
      .option('header','true')          \
      .schema(schema)                   \
      .load('/user/spark/donnees/postesSynop.csv')  \
      .cache()

from pyspark.sql.types import StructType, StructField, IntegerType, StringType

nreco = 12

rating_schema = StructType([
	StructField("userId", IntegerType(), True),
	StructField("movieId", IntegerType(), True),
	StructField("rating", IntegerType(), True),
	StructField("timestamp", IntegerType(), True)]
)

movie_schema = StructType([
	StructField("movieId", IntegerType(), True),
	StructField("title", StringType(), True),
	StructField("release_date", StringType(), True),
	StructField("video release date", StringType(), True),
	StructField("IMDb URL", StringType(), True),
	StructField("unknown", IntegerType(), True),
	StructField("Action", IntegerType(), True),
	StructField("Adventure", IntegerType(), True),
	StructField("Animation", IntegerType(), True),
	StructField("Children's", IntegerType(), True),
	StructField("Comedy", IntegerType(), True),
	StructField("Crime", IntegerType(), True),
	StructField("Documentary", IntegerType(), True),
	StructField("Drama", IntegerType(), True),
	StructField("Fantasy", IntegerType(), True),
	StructField("Film-Noir", IntegerType(), True),
	StructField("Horror", IntegerType(), True),
	StructField("Musical", IntegerType(), True),
	StructField("Mystery", IntegerType(), True),
Example #11
def compute_churn_week(df, week_start):
    """Compute the churn data for this week. Note that it takes 10 days
    from the end of this period for all the activity to arrive. This data
    should be from Sunday through Saturday.

    df: DataFrame of the dataset relevant to computing the churn
    week_start: datestring of this time period"""

    week_start_date = datetime.strptime(week_start, "%Y%m%d")
    week_end_date = week_start_date + timedelta(6)
    week_start = fmt(week_start_date)
    week_end = fmt(week_end_date)

    # Verify that the start date is a Sunday
    if week_start_date.weekday() != 6:
        msg = "Week start date {} is not a Sunday".format(week_start)
        raise RuntimeError(msg)

    # If the data for this week can still be coming, don't try to compute the
    # churn.
    week_end_slop = fmt(week_end_date + timedelta(10))
    today = fmt(datetime.utcnow())
    if week_end_slop >= today:
        msg = ("Skipping week of {} to {} - Data is still arriving until {}.".
               format(week_start, week_end, week_end_slop))
        raise RuntimeError(msg)

    logger.info("Starting week from {} to {}".format(week_start, week_end))

    # the subsession_start_date field has a different form than
    # submission_date_s3, so needs to be formatted with hyphens.
    week_end_excl = fmt(week_end_date + timedelta(1), date_format="%Y-%m-%d")
    week_start_hyphenated = fmt(week_start_date, date_format="%Y-%m-%d")

    current_week = (df.filter(df['submission_date_s3'] >= week_start).filter(
        df['submission_date_s3'] <= week_end_slop).filter(
            df['subsession_start_date'] >= week_start_hyphenated).filter(
                df['subsession_start_date'] < week_end_excl))

    # take a subset and rename the app_version field
    current_week = (current_week.select(source_columns).withColumnRenamed(
        "scalar_parent_browser_engagement_total_uri_count",
        "total_uri_count").withColumnRenamed(
            "scalar_parent_browser_engagement_unique_domains_count",
            "unique_domains_count").withColumnRenamed("app_version",
                                                      "version"))

    # clean some of the aggregate fields
    current_week = current_week.na.fill(
        0, ["total_uri_count", "unique_domains_count"])

    # Clamp broken subsession values in the [0, MAX_SUBSESSION_LENGTH] range.
    clamped_subsession_subquery = (F.when(
        F.col('subsession_length') > MAX_SUBSESSION_LENGTH,
        MAX_SUBSESSION_LENGTH).otherwise(
            F.when(F.col('subsession_length') < 0,
                   0).otherwise(F.col('subsession_length'))))

    # Compute per client aggregates lost during newest client computation
    per_client_aggregates = (current_week.select(
        'client_id', 'total_uri_count', 'unique_domains_count',
        clamped_subsession_subquery.alias('subsession_length')).groupby(
            'client_id').agg(
                F.sum('subsession_length').alias('usage_seconds'),
                F.sum('total_uri_count').alias('total_uri_count_per_client'),
                F.avg('unique_domains_count').alias(
                    'average_unique_domains_count_per_client')))

    # Get the newest ping per client and append to original dataframe
    newest_per_client = get_newest_per_client(current_week)
    newest_with_usage = newest_per_client.join(per_client_aggregates,
                                               'client_id', 'inner')

    # Build the "effective version" cache:
    d2v = make_d2v(get_release_info())

    converted = newest_with_usage.rdd.map(
        lambda x: convert(d2v, week_start, x))
    """
    - channel (appUpdateChannel)
    - geo (bucketed into top 30 countries + "rest of world")
    - is_funnelcake (contains "-cck-"?)
    - acquisition_period (cohort_week)
    - start_version (effective version on profile creation date)
    - sync_usage ("no", "single" or "multiple" devices)
    - current_version (current appVersion)
    - current_week (week)
    - source (associated attribution)
    - medium (associated with attribution)
    - campaign (associated with attribution)
    - content (associated with attribution)
    - distribution_id (funnelcake associated with profile)
    - default_search_engine
    - locale
    - is_active (were the client_ids active this week or not)
    - n_profiles (count of matching client_ids)
    - usage_hours (sum of the per-client subsession lengths,
            clamped in the [0, MAX_SUBSESSION_LENGTH] range)
    - sum_squared_usage_hours (the sum of squares of the usage hours)
    - total_uri_count (sum of per-client uri counts)
    - unique_domains_count_per_profile (average of the average unique
             domains per-client)
    """
    churn_schema = StructType([
        StructField('channel', StringType(), True),
        StructField('geo', StringType(), True),
        StructField('is_funnelcake', StringType(), True),
        StructField('acquisition_period', StringType(), True),
        StructField('start_version', StringType(), True),
        StructField('sync_usage', StringType(), True),
        StructField('current_version', StringType(), True),
        StructField('current_week', LongType(), True),
        StructField('source', StringType(), True),
        StructField('medium', StringType(), True),
        StructField('campaign', StringType(), True),
        StructField('content', StringType(), True),
        StructField('distribution_id', StringType(), True),
        StructField('default_search_engine', StringType(), True),
        StructField('locale', StringType(), True),
        StructField('is_active', StringType(), True),
        StructField('n_profiles', LongType(), True),
        StructField('usage_hours', DoubleType(), True),
        StructField('sum_squared_usage_hours', DoubleType(), True),
        StructField('total_uri_count', LongType(), True),
        StructField('unique_domains_count', DoubleType(), True)
    ])

    # Don't bother to filter out non-good records - they will appear
    # as 'unknown' in the output.
    countable = converted.map(lambda x: (
        (
            # attributes unique to a client
            x.get('channel', 'unknown'),
            x.get('geo', 'unknown'),
            "yes" if x.get('is_funnelcake', False) else "no",
            datetime.strftime(x.get('acquisition_period', date(2000, 1, 1)),
                              "%Y-%m-%d"),
            x.get('start_version', 'unknown'),
            x.get('sync_usage', 'unknown'),
            x.get('current_version', 'unknown'),
            x.get('current_week', -1),
            x.get('source', 'unknown'),
            x.get('medium', 'unknown'),
            x.get('campaign', 'unknown'),
            x.get('content', 'unknown'),
            x.get('distribution_id', 'unknown'),
            x.get('default_search_engine', 'unknown'),
            x.get('locale', 'unknown'),
            x.get('is_active', 'unknown')),
        (
            1,  # active users
            x.get('usage_hours', 0.0),
            x.get('squared_usage_hours', 0.0),
            x.get('total_uri_count', 0),
            x.get('unique_domains_count', 0.0))))

    def reduce_func(x, y):
        return tuple(map(sum, zip(x, y)))

    aggregated = countable.reduceByKey(reduce_func)

    records_df = aggregated.map(lambda x: x[0] + x[1]).toDF(churn_schema)

    # Apply some post-processing for other aggregates
    # (i.e. unique_domains_count). This needs to be done when you want
    # something other than just a simple sum
    def average(total, n):
        if not n:
            return 0.0
        return float(total) / n

    average_udf = F.udf(average, DoubleType())

    # Create new derived columns and drop any unnecessary ones
    records_df = (
        records_df
        # The total number of unique domains divided by the number of profiles
        # over a set of dimensions. This should be aggregated using a weighted
        # mean, i.e. sum(unique_domains_count_per_profile * n_profiles)
        .withColumn('unique_domains_count_per_profile',
                    average_udf(F.col('unique_domains_count'),
                                F.col('n_profiles')))
        # This value is meaningless because of overlapping domains between
        # profiles
        .drop('unique_domains_count')
    )

    return records_df
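
# Hypothetical usage of compute_churn_week (the DataFrame name is an assumption):
# week_start must be a Sunday given in YYYYMMDD form, e.g.
#
#     churn_df = compute_churn_week(main_summary_df, "20170115")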
Example #12
from pyspark.sql.types import StructType, StructField, StringType, ArrayType

from sparkql import merge_schemas

schema_a = StructType([
    StructField("message", StringType()),
    StructField("author",
                ArrayType(StructType([StructField("name", StringType())])))
])

schema_b = StructType([
    StructField("author",
                ArrayType(StructType([StructField("address", StringType())])))
])

merged_schema = merge_schemas(schema_a, schema_b)

pretty_merged_schema = """
StructType(List(
    StructField(message,StringType,true),
    StructField(author,
        ArrayType(StructType(List(
            StructField(name,StringType,true),
            StructField(address,StringType,true))),true),
        true)))
"""
Example #13
    def get_associated_body_weight(
        self,
        specimen_level_experiment_df: DataFrame,
        mouse_specimen_df: DataFrame,
        impress_df: DataFrame,
    ) -> DataFrame:
        """
        Takes in DataFrame with Experimental data, one with Mouse Specimens and one with Impress information,
        and applies the algorithm to select the associated BW to any given experiment
        and calculate the age of experiment for the selected BW measurement.
        """
        # Explode the nested experiment DF structure so every row represents an observation
        weight_observations: DataFrame = specimen_level_experiment_df.withColumn(
            "simpleParameter", explode_outer("simpleParameter"))

        # Select the parameter relevant pieces from the IMPReSS DF
        parameters = impress_df.select(
            "pipelineKey",
            "procedure.procedureKey",
            "parameter.parameterKey",
            "parameter.analysisWithBodyweight",
        ).distinct()

        # Filter the IMPReSS using the analysisWithBodyweight flag
        weight_parameters = parameters.where(
            col("analysisWithBodyweight").isin(
                ["is_body_weight", "is_fasted_body_weight"]))

        # Join both the  observations DF and the BW parameters DF to obtain the observations that are BW
        weight_observations = weight_observations.join(
            weight_parameters,
            ((weight_observations["_pipeline"]
              == weight_parameters["pipelineKey"])
             & (weight_observations["_procedureID"]
                == weight_parameters["procedureKey"])
             & (weight_observations["simpleParameter._parameterID"]
                == weight_parameters["parameterKey"])),
        )
        # Create a boolean flag for fasted BW procedures
        weight_observations = weight_observations.withColumn(
            "weightFasted",
            col("analysisWithBodyweight") == "is_fasted_body_weight")

        weight_observations = weight_observations.select(
            "specimenID",
            "_centreID",
            col("unique_id").alias("sourceExperimentId"),
            col("_dateOfExperiment").alias("weightDate"),
            col("simpleParameter._parameterID").alias("weightParameterID"),
            col("simpleParameter.value").alias("weightValue"),
            "weightFasted",
        )
        weight_observations = weight_observations.where(
            col("weightValue").isNotNull())

        # Join the body weight observations so we can determine the  age of the specimen for any BW measurement
        weight_observations = weight_observations.join(
            mouse_specimen_df,
            (weight_observations["specimenID"]
             == mouse_specimen_df["_specimenID"])
            & (weight_observations["_centreID"]
               == mouse_specimen_df["_centreID"]),
        )
        weight_observations = weight_observations.withColumn(
            "weightDaysOld", datediff("weightDate", "_DOB"))

        #  Group the weight observations by Specimen
        weight_observations = weight_observations.groupBy("specimenID").agg(
            collect_set(
                struct(
                    "sourceExperimentId",
                    "weightDate",
                    "weightParameterID",
                    "weightValue",
                    "weightDaysOld",
                    "weightFasted",
                )).alias("weight_observations"))

        # Create a temporary "procedureGroup" column to be used in the  BW selection
        specimen_level_experiment_df = specimen_level_experiment_df.withColumn(
            "procedureGroup",
            udf(lambda prod_id: prod_id[:prod_id.rfind("_")],
                StringType())(col("_procedureID")),
        )

        # Join all the observations with the BW observations grouped by specimen
        specimen_level_experiment_df = specimen_level_experiment_df.join(
            weight_observations, "specimenID", "left_outer")
        # Schema for the struct that is going to group all the associated BW data
        output_weight_schema = StructType([
            StructField("sourceExperimentId", StringType()),
            StructField("weightDate", DateType()),
            StructField("weightParameterID", StringType()),
            StructField("weightValue", StringType()),
            StructField("weightDaysOld", IntegerType()),
            StructField("error", ArrayType(StringType())),
        ])

        # Alias both the experiment and the specimen df so is easier to join and manipulate
        experiment_df_a = specimen_level_experiment_df.alias("exp")
        mice_df_a = mouse_specimen_df.alias("mice")

        specimen_level_experiment_df = experiment_df_a.join(
            mice_df_a,
            (specimen_level_experiment_df["specimenID"]
             == mouse_specimen_df["_specimenID"])
            & (specimen_level_experiment_df["_centreID"]
               == mouse_specimen_df["_centreID"]),
            "left_outer",
        )

        # Add special dates to the experiment x specimen dataframe
        # for some experiments the date of sacrifice or date of blood collection
        # has to be used as reference for age and  BW calculations
        specimen_level_experiment_df = self._add_special_dates(
            specimen_level_experiment_df)
        get_associated_body_weight_udf = udf(self._get_closest_weight,
                                             output_weight_schema)
        specimen_level_experiment_df = specimen_level_experiment_df.withColumn(
            "weight",
            get_associated_body_weight_udf(
                when(
                    col("_dateOfBloodCollection").isNotNull(),
                    col("_dateOfBloodCollection"),
                ).when(
                    col("_dateOfSacrifice").isNotNull(),
                    col("_dateOfSacrifice")).otherwise(
                        col("_dateOfExperiment")),
                col("procedureGroup"),
                col("weight_observations"),
            ),
        )
        specimen_level_experiment_df = specimen_level_experiment_df.select(
            "exp.*", "weight")
        return specimen_level_experiment_df
from pyspark.sql import SparkSession
from pyspark.sql.functions import from_json, col, unbase64, base64, split
from pyspark.sql.types import StructField, StructType, StringType, BooleanType, ArrayType, DateType, FloatType

# TO-DO: using the spark application object, read a streaming dataframe from the Kafka topic stedi-events as the source
# Be sure to specify the option that reads all the events from the topic including those that were published before you started the spark stream
stediEventsSchema = StructType([
    StructField('customer', StringType()),
    StructField('score', FloatType()),
    StructField('riskData', DateType())
])

spark = SparkSession.builder.appName('stedi-app').getOrCreate()
spark.sparkContext.setLogLevel('WARN')

stediEventsRawStreamingDF = spark \
    .readStream \
    .format('kafka') \
    .option('kafka.bootstrap.servers', 'localhost:9092') \
    .option('subscribe', 'stedi-events') \
    .option('startingOffsets', 'earliest') \
    .load()

# TO-DO: cast the value column in the streaming dataframe as a STRING
stediEventsStreamingDF = stediEventsRawStreamingDF.selectExpr(
    'cast(key as string) key', 'cast(value as string) value')

# TO-DO: parse the JSON from the single column "value" with a json object in it, like this:
# +------------+
# | value      |
# +------------+
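
# The example is cut off here; mirroring the earlier stedi snippet, the parsing step would
# deserialize "value" with from_json and expose the fields in a temporary view. A sketch,
# not the original code:
stediEventsStreamingDF \
    .withColumn('value', from_json('value', stediEventsSchema)) \
    .select(col('value.customer').alias('customer'), col('value.score').alias('score')) \
    .createOrReplaceTempView('CustomerRisk')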
Example #15
#######################################################################################
##### 9. Create StructType and StructFields
# Struct type allows you to create desired data type or structure the data type of your columns
# Struct fields are like column in your datastructure

from pyspark.sql.types import StructType, StructField, StringType, IntegerType

data = [("James", "", "Smith", "36636", "M", 3000),
        ("Michael", "Rose", "", "40288", "M", 4000),
        ("Robert", "", "Williams", "42114", "M", 4000),
        ("Maria", "Anne", "Jones", "39192", "F", 4000),
        ("Jen", "Mary", "Brown", "", "F", -1)]

schema = StructType([
    StructField("firstname", StringType(), True),
    StructField("middlename", StringType(), True),
    StructField("lastname", StringType(), True),
    StructField("id", StringType(), True),
    StructField("gender", StringType(), True),
    StructField("salary", IntegerType(), True)
])
df = spark.createDataFrame(data=data, schema=schema)
df.printSchema()
df.show(truncate=False)

##################################################################
################## Nested Schema
structureData = [(("James", "", "Smith"), "36636", "M", 3100),
                 (("Michael", "Rose", ""), "40288", "M", 4300),
                 (("Robert", "", "Williams"), "42114", "M", 1400),
Example #16
def make_df_solid(context):
    schema = StructType([StructField("name", StringType()), StructField("age", IntegerType())])
    rows = [Row(name="John", age=19), Row(name="Jennifer", age=29), Row(name="Henry", age=50)]
    return context.resources.pyspark.spark_session.createDataFrame(rows, schema)

from pyspark import SparkContext, SparkConf
from pyspark.sql import SQLContext
from pyspark.sql.types import StructField, StructType, StringType, IntegerType, BooleanType, DoubleType, ByteType

sparkConf = SparkConf().setAppName("DataFrameExample").setMaster("local[2]")

sc = SparkContext(conf=sparkConf)
sqlContext = SQLContext(sc)

schema = StructType([
    StructField("Arrest", BooleanType()),
    StructField("Beat", StringType()),
    StructField("Block", StringType()),
    StructField("CaseNumber", IntegerType()),
    StructField("CommunityArea", IntegerType()),
    StructField("Date", StringType()),
    StructField("Description", StringType()),
    StructField("District", IntegerType()),
    StructField("Domestic", BooleanType()),
    StructField("FBICode", IntegerType()),
    StructField("ID", StringType()),
    StructField("IUCR", IntegerType()),
    StructField("Latitude", DoubleType()),
    StructField("Location", StringType()),
    StructField("LocationDescription", StringType()),
    StructField("Longitude", DoubleType()),
def ecg_autosense_data_quality(ecg,
                               Fs=64,
                               sensor_name='autosense',
                               outlier_threshold_high=4000,
                               outlier_threshold_low=20,
                               slope_threshold=100,
                               range_threshold=50,
                               eck_threshold_band_loose=400,
                               window_size=3,
                               acceptable_outlier_percent=34):
    """
    Label the data quality of AutoSense chest ECG data in 3-second windows.

    Args:
        ecg (DataStream): raw chest ECG datastream
        Fs (int): sampling frequency in Hz
        sensor_name (str): source sensor name; only 'autosense' is handled
        outlier_threshold_high (int): samples at or above this value count as outliers
        outlier_threshold_low (int): samples at or below this value count as outliers
        slope_threshold (int): sample-to-sample difference that marks a discontinuity
        range_threshold (int): window range at or below this is classified as sensor off body
        eck_threshold_band_loose (int): window range at or below this is classified as loose attachment
        window_size (int): window length in seconds
        acceptable_outlier_percent (int): maximum percentage of outlier samples per window

    Returns:
        DataStream: structure [timestamp, localtime, version, user, quality, ecg]
    """
    data_quality_band_loose = 'loose/improper attachment'
    data_quality_not_worn = 'sensor off body'
    data_quality_band_off = 'battery down/disconnected'
    data_quality_missing = 'intermittent data loss'
    data_quality_good = 'acceptable'
    stream_name = 'org.md2k.autosense.ecg.quality'

    def get_metadata():
        stream_metadata = Metadata()
        stream_metadata.set_name(stream_name).set_description("Chest ECG quality 3 seconds") \
            .add_input_stream(ecg.metadata.get_name()) \
            .add_dataDescriptor(DataDescriptor().set_name("timestamp").set_type("datetime")) \
            .add_dataDescriptor(DataDescriptor().set_name("localtime").set_type("datetime")) \
            .add_dataDescriptor(DataDescriptor().set_name("version").set_type("int")) \
            .add_dataDescriptor(DataDescriptor().set_name("user").set_type("string")) \
            .add_dataDescriptor(
            DataDescriptor().set_name("quality").set_type("string") \
                .set_attribute("description", "ECG data quality") \
                .set_attribute('Loose/Improper Attachment','Electrode Displacement') \
                .set_attribute('Sensor off Body', 'Autosense not worn') \
                .set_attribute('Battery down/Disconnected', 'No data is present - Can be due to battery down or sensor disconnection') \
                .set_attribute('Intermittent Data Loss','Not enough samples are present') \
                .set_attribute('Acceptable','Good Quality')) \
            .add_dataDescriptor(
            DataDescriptor().set_name("ecg").set_type("double").set_attribute("description", \
                                                                              "ecg sample value")) \
            .add_module(
            ModuleMetadata().set_name("ecg data quality").set_attribute("url", "http://md2k.org/").set_author(
                "Md Azim Ullah", "*****@*****.**"))
        return stream_metadata

    def get_quality_autosense(data):
        """

        Args:
            data:

        Returns:

        """
        minimum_expected_samples = window_size * acceptable_outlier_percent * Fs / 100

        if (len(data) == 0):
            return data_quality_band_off
        if (len(data) <= minimum_expected_samples):
            return data_quality_missing
        range_data = max(data) - min(data)
        if range_data <= range_threshold:
            return data_quality_not_worn
        if range_data <= eck_threshold_band_loose:
            return data_quality_band_loose
        outlier_counts = 0
        for i in range(0, len(data)):
            # circular neighbor indices (wrap around at the ends of the window)
            im = (i - 1) % len(data)
            ip = (i + 1) % len(data)
            stuck = ((data[i] == data[im]) and (data[i] == data[ip]))
            flip = ((abs(data[i] - data[im]) > ((int(outlier_threshold_high))))
                    or (abs(data[i] - data[ip]) >
                        ((int(outlier_threshold_high)))))
            disc = ((abs(data[i] - data[im]) > ((int(slope_threshold))))
                    and (abs(data[i] - data[ip]) > ((int(slope_threshold)))))
            if disc:
                outlier_counts += 1
            elif stuck:
                outlier_counts += 1
            elif flip:
                outlier_counts += 1
            elif data[i] >= outlier_threshold_high:
                outlier_counts += 1
            elif data[i] <= outlier_threshold_low:
                outlier_counts += 1
        if (100 * outlier_counts > acceptable_outlier_percent * len(data)):
            return data_quality_band_loose
        return data_quality_good

    schema = StructType([
        StructField("timestamp", TimestampType()),
        StructField("localtime", TimestampType()),
        StructField("version", IntegerType()),
        StructField("user", StringType()),
        StructField("quality", StringType()),
        StructField("ecg", DoubleType())
    ])

    @pandas_udf(schema, PandasUDFType.GROUPED_MAP)
    @CC_MProvAgg('ecg--org.md2k.autosense--autosense_chest--chest',
                 'ecg_autosense_data_quality', stream_name,
                 ['user', 'timestamp'], ['user', 'timestamp'])
    def data_quality(data):
        """

        Args:
            data:

        Returns:

        """
        data['quality'] = ''
        if data.shape[0] > 0:
            data = data.sort_values('timestamp')
            if sensor_name in ['autosense']:
                data['quality'] = get_quality_autosense(list(data['ecg']))
        return data

    ecg_quality_stream = ecg.compute(data_quality,
                                     windowDuration=3,
                                     startTime='0 seconds')
    data = ecg_quality_stream._data
    ds = DataStream(data=data, metadata=get_metadata())
    return ds
def loadMovieNames():
    movieNames = {}
    # CHANGE THIS TO THE PATH TO YOUR u.ITEM FILE:
    with codecs.open("/opt/bitnami/spark/spark-data/ml-100k/u.item",
                     "r",
                     encoding='ISO-8859-1',
                     errors='ignore') as f:
        for line in f:
            fields = line.split('|')
            movieNames[int(fields[0])] = fields[1]
    return movieNames


spark = get_spark_session('ALSExample')

moviesSchema = StructType([
    StructField("userID", IntegerType(), True),
    StructField("movieID", IntegerType(), True),
    StructField("rating", IntegerType(), True),
    StructField("timestamp", LongType(), True)])

names = loadMovieNames()

ratings = spark.read.option("sep", "\t").schema(moviesSchema) \
    .csv(f"{SPARK_DATA_PATH}/ml-100k/u.data")

print("Training recommendation model...")

als = ALS().setMaxIter(5).setRegParam(0.01).setUserCol("userID").setItemCol("movieID") \
    .setRatingCol("rating")

model = als.fit(ratings)
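
# The original example is cut off after fitting; a typical follow-up (an assumption, not
# part of the source) is to produce top-N recommendations per user from the fitted model:
recommendations = model.recommendForAllUsers(10)
recommendations.show(5, truncate=False)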
def read_pigeon_obs(pigeon_s3):
    pigeon_schema = StructType([
        StructField("eventid", StringType(), False),
        StructField("visible", StringType(), False),
        StructField("timestamp", TimestampType(), False),
        StructField("longitude", FloatType(), False),
        StructField("latitude", FloatType(), False),
        StructField("gps", IntegerType(), True),
        StructField("ground_speed", FloatType(), True),
        StructField("height_above_sealevel", FloatType(), True),
        StructField("outlier_flag", StringType(), True),
        StructField("sensor_type", StringType(), True),
        StructField("taxon_name", StringType(), True),
        StructField("tag_local_identifier", StringType(), True),
        StructField("individual_local_identifier", StringType(), True),
        StructField("study_name", StringType(), True)
    ])
    pigeon_obs = spark.read.csv(pigeon_s3,
                                schema=pigeon_schema,
                                timestampFormat='yyyy-MM-dd HH:mm:ss.SSS',
                                header=True)
    return pigeon_obs
        .config("spark.some.config.option", "some-value") \
        .getOrCreate()

    sqlContext = SQLContext(spark)

    # get command-line arguments
    inFile = sys.argv[1]
    supp = sys.argv[2]
    conf = sys.argv[3]
    prot = sys.argv[4]

    print("Executing HW2SQL with input from " + inFile + ", support=" + supp +
          ", confidence=" + conf + ", protection=" + prot)

    pp_schema = StructType([
        StructField("uid", IntegerType(), True),
        StructField("attr", StringType(), True),
        StructField("val", IntegerType(), True)
    ])

    Pro_Publica = sqlContext.read.format('csv').options(
        header=False).schema(pp_schema).load(inFile)
    Pro_Publica.createOrReplaceTempView("Pro_Publica")
    sqlContext.cacheTable("Pro_Publica")
    spark.sql("select count(*) from Pro_Publica").show()

    # compute frequent itemsets of size 1, store in F1(attr, val)
    query = "select attr, val, count(*) as supp \
               from Pro_Publica \
              group by attr, val \
             having count(*) >= " + str(supp)
Example #22
}), ('Washington', {
    'hair': 'grey',
    'eye': 'grey'
}), ('Jefferson', {
    'hair': 'brown',
    'eye': ''
})]

df = spark.createDataFrame(data=dataDictionary, schema=['name', 'properties'])
df.printSchema()
df.show(truncate=False)

# Using StructType schema
from pyspark.sql.types import StructField, StructType, StringType, MapType, IntegerType
schema = StructType([
    StructField('Name', StringType(), True),
    StructField('properties', MapType(StringType(), StringType()), True)
])
df2 = spark.createDataFrame(data=dataDictionary, schema=schema)
df2.printSchema()
df2.show(truncate=False)

df3=df.rdd.map(lambda x: \
    (x.name,x.properties["hair"],x.properties["eye"])) \
    .toDF(["name","hair","eye"])
df3.printSchema()
df3.show()

df.withColumn("hair",df.properties.getItem("hair")) \
  .withColumn("eye",df.properties.getItem("eye")) \
  .drop("properties") \
Example #23
#! /usr/bin/env python3.6
#import findspark
#findspark.init('/usr/lib/spark-current')

import pyspark
from pyspark.sql import SparkSession
ss = SparkSession.builder.appName("Hehe Final Python Spark with ML").getOrCreate()

#### 1 Data preprocessing
#### 1.1 Read the data
from pyspark.sql.types import StructField, StructType, StringType,  IntegerType, DoubleType

schema_sdf = StructType([
        StructField('ID', StringType(), True),
        StructField('Source', StringType(), True),
        StructField('TMC', StringType(), True), # Traffic Message Channel (TMC) code
        StructField('Severity', StringType(), True), # accident severity, 1-4 (ordinal)
        StructField('Start_Time', StringType(), True),
        StructField('End_Time', StringType(), True),
        StructField('Start_Lat', DoubleType(), True), # latitude
        StructField('Start_Lng', DoubleType(), True), # longitude
        StructField('End_Lat', DoubleType(), True),
        StructField('End_Lng', DoubleType(), True),
        StructField('Distance_mi', DoubleType(), True), # length of the road extent affected by the accident
        StructField('Description', StringType(), True),
        StructField('Number', StringType(), True), # street number
        StructField('Street', StringType(), True), # street name
        StructField('Side', StringType(), True),   # left/right side
        StructField('City', StringType(), True),  # city
        StructField('County', StringType(), True), # county
        StructField('State', StringType(), True),  # state
Example #24
 def with_idx(sdf):
     new_schema = StructType(sdf.schema.fields + [StructField("idx", LongType(), False), ])
     return sdf.rdd.zipWithIndex().map(lambda row: row[0] + (row[1],)).toDF(
         schema=new_schema)
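# Hypothetical usage (not part of the original snippet): append a 0-based row index
# column to any DataFrame, e.g.
#     indexed_df = with_idx(some_df)
#     indexed_df.select("idx").show()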
if __name__ == "__main__":

# Create spark session and set configurations, here we set yarn, because we want to run spark application on hadoop environment
    spark = SparkSession \
        .builder \
        .appName("Sliding Window Steram") \
        .master("yarn") \
        .config("spark.streaming.stopGracefullyOnShutdown", "true") \
        .config("spark.sql.shuffle.partitions", 1) \
        .getOrCreate()

# Define schema for kafka data
# Schema will be used during deserialization of kafka data
    stock_schema = StructType([
        StructField("CreatedTime", StringType()),
        StructField("Reading", DoubleType())
    ])

# Read data from kafka topic
    kafka_source_df = spark.readStream \
        .format("kafka") \
        .option("kafka.bootstrap.servers", "localhost:9092") \
        .option("subscribe", "sensor") \
        .option("startingOffsets", "earliest") \
        .load()

# Deserialization
    value_df = kafka_source_df.select(col("key").cast("string").alias("SensorID"),
                                      from_json(col("value").cast("string"), stock_schema).alias("value"))
Example #26
from pyspark.sql.types import (
    ArrayType,
    BinaryType,
    DoubleType,
    StructType,
    StructField,
    StringType,
    IntegerType,
    LongType,
)

ARCHIVE_ORG_SCHEMA = StructType(
    [
        StructField("created", LongType(), True),
        StructField("d1", StringType(), True),
        StructField("d2", StringType(), True),
        StructField("dir", StringType(), True),
        StructField(
            "files",
            ArrayType(
                StructType(
                    [
                        StructField("bitrate", StringType(), True),
                        StructField("btih", StringType(), True),
                        StructField("crc32", StringType(), True),
                        StructField("format", StringType(), True),
                        StructField("height", StringType(), True),
                        StructField("length", StringType(), True),
                        StructField("license", StringType(), True),
                        StructField("md5", StringType(), True),
                        StructField("mtime", StringType(), True),
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
plt.style.use(['dark_background'])
import os
import datetime
from flask import Flask
from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, StringType, BooleanType

os.environ["PYSPARK_PYTHON"]="/usr/bin/python3"
os.environ["PYSPARK_DRIVER_PYTHON"]="/usr/bin/python3"


app = Flask(__name__)
map_count=0
spark = SparkSession.builder.appName("Chicago_crime_analysis").getOrCreate()
crimes_schema = StructType([StructField("_c0", StringType(), True),
                            StructField("ID", StringType(), True),
                            StructField("CaseNumber", StringType(), True),
                            StructField("Date", StringType(), True ),
                            StructField("Block", StringType(), True),
                            StructField("IUCR", StringType(), True),
                            StructField("PrimaryType", StringType(), True  ),
                            StructField("Description", StringType(), True ),
                            StructField("LocationDescription", StringType(), True ),
                            StructField("Arrest", BooleanType(), True),
                            StructField("Domestic", BooleanType(), True),
                            StructField("Beat", StringType(), True),
                            StructField("District", StringType(), True),
                            StructField("Ward", StringType(), True),
                            StructField("CommunityArea", StringType(), True),
                            StructField("FBICode", StringType(), True ),
from pyspark.sql import SparkSession
from pyspark.sql.types import StringType, StructField, StructType
from pyspark.sql.functions import udf
from datetime import datetime
from time import time

if __name__ == "__main__":

    spark = SparkSession.builder.appName(
        'Stream data groupby timestamp').getOrCreate()

    spark.sparkContext.setLogLevel("ERROR")
    # declare schema defination
    schema = StructType([
        StructField("lsoc_code", StringType(), True),
        StructField("borough", StringType(), True),
        StructField("major_category", StringType(), True),
        StructField("minor_category", StringType(), True),
        StructField("value", StringType(), True),
        StructField("year", StringType(), True),
        StructField("month", StringType(), True)
    ])

    # read the streaming
    fileStreamDF = spark.readStream\
        .option("header","true")\
        .schema(schema)\
        .csv("../datasets/droplocation")

    # Custom function to create the timestamp
    def get_timestamp():
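        # Assumed body (the original is cut off here): return the current time as a string.
        return datetime.now().strftime("%Y-%m-%d %H:%M:%S")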
def run(argv):
    parser = argparse.ArgumentParser()
    parser.add_argument('--project', dest='project', type=str)
    parser.add_argument('--bucket', dest='bucket', type=str)
    known_args, pipeline_args = parser.parse_known_args(argv)

    # Known args
    # Project - Google Cloud project
    # Bucket table - Bucket for temporal table creation
    project = known_args.project
    bucket = known_args.bucket

    spark = SparkSession\
        .builder\
        .appName(f'spark-etl_retentions_remittances')\
        .getOrCreate()
    spark.conf.set('temporaryGcsBucket', bucket)

    # Read source transactions Dataframe
    transactions = spark.read.format('bigquery')\
        .option('table', f'{project}:remittances.remittances_movements')\
        .load()\
        .where(F.col('created_at').isNotNull())
    transactions.createOrReplaceTempView('remittances_movements')

    # Get starting date of users
    users_date = transactions\
        .select('created_by_user', 'created_at')\
        .groupBy('created_by_user')\
        .agg({'created_at': 'min'})\
        .withColumnRenamed('min(created_at)', 'min_date')\
        .withColumn('min_date', F.to_date(F.col('min_date')))\
        .orderBy('created_by_user', ascending=True)

    # Get transactions dates
    transactions_dates = transactions\
        .select('created_at')\
        .withColumnRenamed('created_at', 'date')\
        .withColumn('date', F.to_date(F.col('date')))\
        .distinct()\
        .orderBy('date', ascending=True)

    # Get user by day
    day_users = transactions\
        .select('created_by_user', 'created_at')\
        .withColumn('created_at', F.to_date(F.col('created_at')))\
        .groupBy('created_at')\
        .count()\
        .withColumnRenamed('created_at', 'date')\
        .withColumnRenamed('count', 'total_users')\
        .orderBy('date', ascending=False)

    # Get new users of transactions days
    transactions_new = transactions_dates.alias('d')\
        .join(users_date.alias('u'), F.col('d.date') == F.col('u.min_date'), how='left')\
        .select('date', 'created_by_user')\
        .groupBy('date') \
        .count() \
        .withColumnRenamed('count', 'new_users')

    # Join total and new users count by day
    transactions_dates = transactions_dates.alias('c')\
        .join(transactions_new.alias('n'), 'date')\
        .join(day_users.alias('t'), 'date')\
        .orderBy('date', ascending=False)

    # Calculate retention by date
    retention_window = Window.orderBy('date').rowsBetween(-1, -1)
    retentions_day = transactions_dates\
        .withColumn('prev_total_users', F.lag('total_users', 1, 0).over(retention_window))\
        .withColumn(
            'retention_rate',
            (F.col('total_users') - F.col('new_users')) / F.lag('total_users', 1, 0).over(retention_window)
        )\
        .withColumn('retention_rate', F.col('retention_rate').cast('decimal(38, 9)'))\
        .filter(F.col('date').isNotNull())\
        .filter(F.col('new_users').isNotNull())\
        .filter(F.col('total_users').isNotNull())\
        .filter(F.col('retention_rate').isNotNull())

    # Set required fields for schema - Preparing to write in BigQuery table
    retention_schema = [
        StructField('date', DateType(), False),
        StructField('new_users', IntegerType(), False),
        StructField('total_users', IntegerType(), False),
        StructField('prev_total_users', IntegerType(), False),
        StructField('retention_rate', DecimalType(38, 9), False)
    ]
    retention_df = spark.createDataFrame(retentions_day.rdd,
                                         StructType(retention_schema))

    # Write results in BigQuery
    retention_df.write\
        .format('bigquery')\
        .option('table', f'{project}:results.retentions_remittances_d')\
        .mode("append")\
        .save()
def loadMovieNames():
    movie_names = {}
    # CHANGE THIS TO THE PATH TO YOUR u.ITEM FILE:
    with codecs.open('datasets/u.item', 'r',
                     encoding='ISO-8859-1', errors='ignore') as f:
        for line in f:
            fields = line.split('|')
            movie_names[int(fields[0])] = fields[1]
    return movie_names

spark = SparkSession.builder.appName('PopularMovies').getOrCreate()

# Broadcast a dictionary
name_dict = spark.sparkContext.broadcast(loadMovieNames())

schema = StructType([
    StructField("user_id", IntegerType(), True),
    StructField("movie_id", IntegerType(), True),
    StructField("rating", IntegerType(), True),
    StructField("timestamp", IntegerType(), True)
])
ratings_df = spark.read.csv('datasets/u.data',
                            sep='\t',
                            schema=schema)
ratings_df = ratings_df.select(['movie_id', 'rating'])

ratings_df.show(5)
# +--------+------+
# |movie_id|rating|
# +--------+------+
# |     242|     3|
# |     302|     3|