Example #1
def infer_spark_type(typeclass) -> t.DataType:
    if typeclass in (None, NoneType):
        return t.NullType()
    elif typeclass is str:
        return t.StringType()
    elif typeclass in {bytes, bytearray}:
        return t.BinaryType()
    elif typeclass is bool:
        return t.BooleanType()
    elif typeclass is date:
        return t.DateType()
    elif typeclass is datetime:
        return t.TimestampType()
    elif typeclass is Decimal:
        return t.DecimalType(precision=36, scale=6)
    elif isinstance(typeclass, type) and issubclass(typeclass, BoundDecimal):
        (precision, scale) = typeclass.__constraints__
        return t.DecimalType(precision=precision, scale=scale)
    elif typeclass is float:
        return t.DoubleType()
    elif typeclass is int:
        return t.IntegerType()
    elif typeclass is long:
        return t.LongType()
    elif typeclass is short:
        return t.ShortType()
    elif typeclass is byte:
        return t.ByteType()
    elif getattr(typeclass, "__origin__", None) is not None:
        return infer_complex_spark_type(typeclass)
    elif is_pyspark_class(typeclass):
        return transform(typeclass)
    else:
        raise TypeError(f"Don't know how to represent {typeclass} in Spark")
Example #2
def build_song_schema():
    """Build and return a schema to use for the song data.
    Returns:
        schema: a StructType object describing the schema and its defined fields.
    """
    schema = T.StructType(
        [
            T.StructField('artist_id', T.StringType(), True),
            T.StructField('artist_latitude', T.DecimalType(), True),
            T.StructField('artist_longitude', T.DecimalType(), True),
            T.StructField('artist_location', T.StringType(), True),
            T.StructField('artist_name', T.StringType(), True),
            T.StructField('duration', T.DecimalType(), True),
            T.StructField('num_songs', T.IntegerType(), True),
            T.StructField('song_id', T.StringType(), True),
            T.StructField('title', T.StringType(), True),
            T.StructField('year', T.IntegerType(), True)
        ]
    )
    return schema
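
A hedged usage sketch for the schema above; the input path is a placeholder and spark is assumed to be an existing SparkSession with access to the song-data JSON files.

# Hypothetical path; note that a bare T.DecimalType() defaults to decimal(10, 0).
song_df = spark.read.json("s3a://my-bucket/song-data/*/*/*/*.json",
                          schema=build_song_schema())
song_df.printSchema()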
Example #3
def load_table_from_csv(spark, filename, ignore_cols: List[str] = []):
    tablename = os.path.basename(filename).split(".")[0]
    schemaFile: str = os.path.join(os.path.dirname(filename),
                                   f"{tablename}.schema.json")
    if os.path.exists(schemaFile):
        # use schemafile
        with open(schemaFile, 'r') as schemaFileHandler:
            schema_fields = json.load(schemaFileHandler)
            df_schema = {'fields': [], 'type': 'struct'}
            for field in schema_fields["fields"]:
                df_schema["fields"].append({  # type:ignore
                    "name":
                    field["name"],
                    "type":
                    field["type"],
                    "nullable":
                    field.get("nullable", True),
                    "metadata":
                    field.get("metadata", {})
                })
        df = spark.read.option("header",True).\
            option("inferSchema",False).\
            option("ignoreTrailingWhiteSpace",True).\
            option("quote", "\"").\
            option("escape", "\"").\
            option("multiLine", "true").\
            csv(filename)
        df = df.select(*[
            F.col(details["name"]).cast(details["type"]).alias(
                details["name"])  # type: ignore
            for details in df_schema["fields"]
        ])
    else:
        # For performance, load only a subset of rows to infer the schema, then reuse that schema when loading the full file.
        # Otherwise, Spark would read the whole file just to infer the schema.
        sample_rows = spark.read.text(filename).limit(500)
        file_schema = spark.read.option("header", True).option(
            "inferSchema",
            True).csv(sample_rows.rdd.map(lambda x: x[0])).schema
        df = spark.read.option("header",True).\
            option("inferSchema",False).\
            option("ignoreTrailingWhiteSpace",True).\
            option("quote", "\"").\
            option("escape", "\"").\
            option("multiLine", "true").\
            csv(filename,schema=file_schema)
        # convert all double to decimal
        df = df.select(*[
            F.col(field.name).cast(T.DecimalType(38, 18)) if field.dataType ==
            T.DoubleType() else F.col(field.name) for field in df.schema
        ])
    # make column names lowercase
    df = df.toDF(*[c.lower() for c in df.columns])
    df = df.drop(*ignore_cols)
    return df
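
A hedged invocation sketch; the paths are placeholders. If customers.schema.json sits next to the CSV it drives the column casts, otherwise the schema is inferred from a 500-row sample as in the else branch above.

df = load_table_from_csv(spark, "/data/raw/customers.csv", ignore_cols=["comment"])
df.printSchema()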
Example #4
def as_spark_type(tpe) -> types.DataType:
    """
    Given a Python type, returns the equivalent spark type.
    Accepts:
    - the built-in types in Python
    - the built-in types in numpy
    - list of pairs of (field_name, type)
    - dictionaries of field_name -> type
    - Python3's typing system
    """
    # TODO: Add "boolean" and "string" types.
    # ArrayType
    if tpe in (np.ndarray,):
        return types.ArrayType(types.StringType())
    elif hasattr(tpe, "__origin__") and issubclass(tpe.__origin__, list):
        return types.ArrayType(as_spark_type(tpe.__args__[0]))
    # BinaryType
    elif tpe in (bytes, np.character, np.bytes_, np.string_):
        return types.BinaryType()
    # BooleanType
    elif tpe in (bool, np.bool, "bool", "?"):
        return types.BooleanType()
    # DateType
    elif tpe in (datetime.date,):
        return types.DateType()
    # NumericType
    elif tpe in (np.int8, np.byte, "int8", "byte", "b"):
        return types.ByteType()
    elif tpe in (decimal.Decimal,):
        # TODO: considering about the precision & scale for decimal type.
        return types.DecimalType(38, 18)
    elif tpe in (float, np.float, np.float64, "float", "float64", "double"):
        return types.DoubleType()
    elif tpe in (np.float32, "float32", "f"):
        return types.FloatType()
    elif tpe in (np.int32, "int32", "i"):
        return types.IntegerType()
    elif tpe in (int, np.int, np.int64, "int", "int64", "long", "bigint"):
        return types.LongType()
    elif tpe in (np.int16, "int16", "short"):
        return types.ShortType()
    # StringType
    elif tpe in (str, np.unicode_, "str", "U"):
        return types.StringType()
    # TimestampType
    elif tpe in (datetime.datetime, np.datetime64, "datetime64[ns]", "M"):
        return types.TimestampType()
    else:
        raise TypeError("Type %s was not understood." % tpe)
Example #5
def analysis(df):
    monthly_sales_df = df.select(['year', 'month', 'sales']).groupBy(['year', 'month']).sum() \
        .withColumn('monthNumber', F.date_format(F.to_date(F.col('month'), 'MMM'), 'MM').cast('int')) \
        .select('year', 'monthNumber', 'month', 'sum(sales)').sort('year', 'monthNumber')

    partition = Window.partitionBy("year") \
        .orderBy("year", "monthNumber") \
        .rowsBetween(Window.unboundedPreceding, Window.currentRow)

    partition_2 = Window.partitionBy().orderBy("year", "monthNumber")

    monthly_sales_df = monthly_sales_df.withColumn("sales_yearly_cum_sum", F.sum(F.col('sum(sales)')).over(partition)) \
        .withColumn('monthly_perc_change',
                    ((F.col('sum(sales)') / (F.lag(F.col('sum(sales)')).over(partition_2)) - 1) * 100)
                    .cast(st.DecimalType(10, 2)))
    return monthly_sales_df.drop('monthNumber')
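
A small standalone sketch of the lag plus DecimalType(10, 2) cast used above; it assumes pyspark.sql.functions as F, pyspark.sql.types as st, Window and an existing SparkSession named spark, and uses toy data.

toy = spark.createDataFrame([(2020, 1, 100.0), (2020, 2, 150.0)],
                            ['year', 'monthNumber', 'sales'])
w = Window.partitionBy().orderBy('year', 'monthNumber')
# Month-over-month percentage change, cast to decimal(10, 2) as in analysis().
toy.withColumn('monthly_perc_change',
               ((F.col('sales') / F.lag('sales').over(w) - 1) * 100)
               .cast(st.DecimalType(10, 2))).show()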
Example #6
def as_spark_type(tpe) -> types.DataType:
    """
    Given a python type, returns the equivalent spark type.
    Accepts:
    - the built-in types in python
    - the built-in types in numpy
    - list of pairs of (field_name, type)
    - dictionaries of field_name -> type
    - python3's typing system
    """
    if tpe in (str, "str", "string"):
        return types.StringType()
    elif tpe in (bytes, ):
        return types.BinaryType()
    elif tpe in (np.int8, "int8", "byte"):
        return types.ByteType()
    elif tpe in (np.int16, "int16", "short"):
        return types.ShortType()
    elif tpe in (int, "int", np.int, np.int32):
        return types.IntegerType()
    elif tpe in (np.int64, "int64", "long", "bigint"):
        return types.LongType()
    elif tpe in (float, "float", np.float):
        return types.FloatType()
    elif tpe in (np.float64, "float64", "double"):
        return types.DoubleType()
    elif tpe in (decimal.Decimal, ):
        return types.DecimalType(38, 18)
    elif tpe in (datetime.datetime, np.datetime64):
        return types.TimestampType()
    elif tpe in (datetime.date, ):
        return types.DateType()
    elif tpe in (bool, "boolean", "bool", np.bool):
        return types.BooleanType()
    elif tpe in (np.ndarray, ):
        # TODO: support other child types
        return types.ArrayType(types.StringType())
    else:
        raise TypeError("Type %s was not understood." % tpe)
Example #7
def create_demography_table(spark, df_clean_demograph, output_parquet):
    # create demography table per ethnic per state
    try:
        dim_demography = df_clean_demograph \
                        .filter('state_id !=""') \
                        .groupBy( 'state_id', 'ethnic') \
                        .agg((F.avg('ethnic_count').cast(T.DecimalType(22, 2))).alias('avg_ethnic')) \
                        .dropDuplicates() \
                        .orderBy("state_id")

        #dim_demography.printSchema()
        #print(dim_airport_us.count())
        #dim_demography.show(2)
        dim_demography.collect()
        #dim_demography.toPandas().to_csv(output_parquet+"demograph_table.csv", header=True)
        parquet_path = output_parquet + 'demograph_table'
        write_parquet(dim_demography, parquet_path)
        check_parquet(spark, parquet_path)
        return (dim_demography)
    except Exception as e:
        print("Unexpected error: %s" % e)
        sys.exit()
Example #8
def ibis_decimal_dtype_to_spark_dtype(ibis_dtype_obj):
    precision = ibis_dtype_obj.precision
    scale = ibis_dtype_obj.scale
    return pt.DecimalType(precision, scale)
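
A hedged usage sketch; it assumes ibis is installed and that its dtype module is importable as below.

import ibis.expr.datatypes as dt
import pyspark.sql.types as pt

# An Ibis decimal(18, 2) maps to the corresponding Spark DecimalType.
assert ibis_decimal_dtype_to_spark_dtype(dt.Decimal(18, 2)) == pt.DecimalType(18, 2)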
Example #9
spark.conf.set("spark.sql.sources.partitionOverwriteMode", "dynamic")

spark.sparkContext.addPyFile(custom_lib_path)
spark.sparkContext.addFile(file_path)
with open(SparkFiles.get("file")) as f:
    pass
spark.catalog.listDatabases()
spark.catalog.listTables("database_name")
spark.catalog.listColumns("table_name", dbName="database_name")
spark.sql("SHOW PARTITIONS database_name.table_name").collect()

schema = T.StructType([
    T.StructField("column_name_string", T.StringType(), True),
    T.StructField("column_name_bigint", T.IntegerType(), True),
    T.StructField("column_name_decimal", T.DecimalType(), True)
])


def example_function(x):
    return x


example_function_udf = F.udf(example_function, T.IntegerType())

df = spark.sql("SELECT * FROM database_name.table_name")
df = spark.table("database_name.table_name")

df.select("column_name")
df.first()
df.take(5)
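
A short continuation sketch of the cheat sheet above; the values are placeholders and spark, F, T, schema and example_function_udf are assumed to be in scope as defined earlier.

from decimal import Decimal

df = spark.createDataFrame(
    [("a", 1, Decimal("1")), ("b", 2, Decimal("2"))], schema)
# Apply the pass-through UDF to one column and inspect the result.
df = df.withColumn("udf_result", example_function_udf(F.col("column_name_bigint")))
df.select("column_name_string", "udf_result", "column_name_decimal").show()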
Example #10
     ('import_date', T.TimestampType()),
     ('sourcesystem_cd', T.StringType()),
     ('upload_id', T.IntegerType()),
 ]),
 'observation_fact':
 OrderedDict([
     ('encounter_num', T.StringType()),
     ('concept_cd', T.StringType()),
     ('provider_id', T.StringType()),
     ('start_date', T.TimestampType()),
     ('patient_num', T.StringType()),
     ('modifier_cd', T.StringType()),
     ('instance_num', T.StringType()),
     ('valtype_cd', T.StringType()),
     ('tval_char', T.StringType()),
     ('nval_num', T.DecimalType(18, 5)),
     ('valueflag_cd', T.StringType()),
     ('quantity_num', T.DecimalType(18, 5)),
     ('units_cd', T.StringType()),
     ('end_date', T.TimestampType()),
     ('location_cd', T.StringType()),
     ('observation_blob', T.StringType()),
     ('confidence_num', T.DecimalType(18, 5)),
     ('update_date', T.TimestampType()),
     ('download_date', T.TimestampType()),
     ('import_date', T.TimestampType()),
     ('sourcesystem_cd', T.StringType()),
     ('upload_id', T.IntegerType()),
 ]),
 'patient_dimension':
 OrderedDict([
Example #11
def test_custom_decimal():
    check(decimal(18, 11), t.DecimalType(precision=18, scale=11))
Example #12
def test_decimal():
    check(Decimal, t.DecimalType(precision=36, scale=6))
Example #13

# Forming table for GDP prediction

IN_PATH = "s3://mysparks/OUTPUT-Folder/"
OUT_PATH = "s3://mysparks/OUTPUT-Folder/"
#IN_PATH = "/home/at/project/output/"
#OUT_PATH = "/home/at/project/output/out/"


os.makedirs(OUT_PATH, exist_ok=True)
#exp_schema = json.load(open("../schema/statcan/" + exp_id + ".json"))
hours_schema = types.StructType([
	types.StructField('YEAR', types.StringType()),
	types.StructField('NAICS', types.StringType()),
	types.StructField('HOURS_WORKED', types.DecimalType()),
	types.StructField('GDP_VALUE', types.DecimalType()),
	types.StructField('labour_productivity', types.DecimalType()),
])

cpi_schema = types.StructType([
	types.StructField('YEAR', types.StringType()),
	types.StructField('Alberta', types.DecimalType()),
	types.StructField('British Columbia', types.DecimalType()),
	types.StructField('Canada', types.DecimalType()),
	types.StructField('Manitoba', types.DecimalType()),
	types.StructField('New Brunswick', types.DecimalType()),
	types.StructField('Newfoundland and Labrador', types.DecimalType()),
	types.StructField('Nova Scotia', types.DecimalType()),
	types.StructField('Ontario', types.DecimalType()),
	types.StructField('Prince Edward Island', types.DecimalType()),
Example #14
    ([('select', [['name', 'age', 'happy']])], 3),
    ([('select', [['*']])], 3)
])
def test_select_returns_df_successfully(args, col_count, spark_session):
    """Select returns DF successfully."""
    # Act
    result = transform(create_simple_df(spark_session), args)

    # Assert
    assert len(result.columns) == col_count


@pytest.mark.spark
@pytest.mark.parametrize('args', [
    ([('cast_column', ["age", T.StringType()])]),
    ([('cast_column', ["name", T.DecimalType()])]),
    ([('cast_column', ["happy", T.DateType()])]),
])
def test_cast_returns_df_successfully(args, spark_session):
    """cast_column returns DF successfully after type cast even when uncastable."""
    # Act
    result = transform(create_simple_df(spark_session), args)

    # Assert
    assert isinstance(result, Column)


@pytest.mark.spark
@pytest.mark.parametrize('args,col', [
    ([('cast_column', ["no_age", T.DateType()])], "no_age"),
    ([('cast_column', ["no_name", T.StringType()])], "no_name"),
    def dataB_process_otm(df_s: DataFrame, df_cost: DataFrame,
                          df_sr: DataFrame, df_sstat: DataFrame,
                          df_ss: DataFrame, df_ss_remark: DataFrame,
                          df_inv: DataFrame) -> DataFrame:
        """
        This function takes OTM source data frames and processes Freght Costs for the
        invoices in the dataB database that correspond to the Recycle Mills.
        :param df_s: Shipment table dataframe
        :param df_cost: Shipment_Cost table dataframe
        :param df_sr: Shipment_Refnum table dataframe
        :param df_sstat: Shipment_Status table dataframe
        :param df_ss: Shipment_Stop table dataframe
        :param df_ss_remark: Shipment_Stop_Remark table dataframe
        :param df_inv: dataB Recycle Invoices dataframe to match with OTM data
        :return: df: OTM processed dataframe
        """

        # Filter the Shipment_Refnum dataframe to only include records
        # with domain_name of SSCC and shipment_refnum_qual_gid's that equal MB or SSCC.MB.
        # Including a distinct because of duplicate values in the source.
        df_sr = df_sr.where(df_sr.domain_name == 'SSCC')\
            .where(df_sr.shipment_refnum_qual_gid.isin(['MB', 'SSCC.MB']))\
            .distinct()

        # Join on the shipment_stop_remark table to include sub-BOL's
        df_sr_sub_bol = df_sr.join(df_ss_remark, [df_sr.shipment_gid == df_ss_remark.shipment_gid], 'inner')\
            .select(df_sr.shipment_gid,
                    df_ss_remark.remark_text,
                    df_ss_remark.insert_date)
        df_sr_sub_bol = df_sr_sub_bol.withColumnRenamed(
            'remark_text', 'shipment_refnum_value')
        df_sr = df_sr.select(df_sr.shipment_gid, df_sr.shipment_refnum_value,
                             df_sr.insert_date)
        df_sr = df_sr.union(df_sr_sub_bol)

        # Filter down the Shipment_Refnum dataset to only include the records
        # that match with invoices so that the later joins and calculations
        # are more focused.  Since the bol's are unique in OTM, we only
        # need unique bol numbers here to filter down.
        df_inv = df_inv.where(df_inv.lbs >= 0)\
            .groupBy(df_inv.bol_number)\
            .agg((F.sum(df_inv.lbs)/2000).alias("total_bol_tons_from_inv"))

        df_sr = df_sr.groupBy(df_sr.shipment_gid,
                              df_sr.shipment_refnum_value)\
            .agg(F.max(df_sr.insert_date).alias("max_insert_date"))
        df_sr = df_sr.orderBy(df_sr.shipment_gid,
                              df_sr.shipment_refnum_value,
                              df_sr.max_insert_date.desc())\
            .dropDuplicates(["shipment_gid",
                             "shipment_refnum_value"])

        df_sr = df_inv.join(df_sr, [df_inv.bol_number == df_sr.shipment_refnum_value], 'inner')\
            .select(df_sr.shipment_gid,
                    df_sr.shipment_refnum_value,
                    df_inv.total_bol_tons_from_inv)

        df_sr_join_back = df_sr.select(df_sr.shipment_gid,
                                       df_sr.shipment_refnum_value)\
            .withColumnRenamed('shipment_gid', 'df_sr_join_back_shipment_gid')

        df_sr = df_sr.select(df_sr.shipment_gid,
                             df_sr.total_bol_tons_from_inv)\
            .groupBy(df_sr.shipment_gid)\
            .agg(F.sum(df_inv.total_bol_tons_from_inv).alias("total_shipment_gid_tons_from_inv"))

        df_sr = df_sr.join(df_sr_join_back, [df_sr.shipment_gid == df_sr_join_back.df_sr_join_back_shipment_gid], 'inner')\
            .select(df_sr.shipment_gid,
                    df_sr.total_shipment_gid_tons_from_inv,
                    df_sr_join_back.shipment_refnum_value)\
            .withColumnRenamed('shipment_gid', 'sr_shipment_gid')

        df_sstat = df_sstat.groupBy(df_sstat.domain_name,
                                    df_sstat.shipment_gid,
                                    df_sstat.status_type_gid,
                                    df_sstat.status_value_gid)\
            .agg(F.max(df_sstat.insert_date).alias("max_insert_date"))
        df_sstat = df_sstat.orderBy(df_sstat.shipment_gid,
                                    df_sstat.status_type_gid,
                                    df_sstat.status_value_gid,
                                    df_sstat.max_insert_date.desc())\
            .dropDuplicates(["shipment_gid",
                             "status_type_gid",
                             "status_value_gid"])

        # Filter the Shipment Refnum dataframe to only include records where the Shipment Status
        # matches what we are looking for with domain_name of SSCC and specific status_value_id's
        # At least three of the status_value_gid's need to match in order to put the confidence
        # level high enough to signify a match.
        df_sr = df_sr.join(df_sstat, [df_sr.sr_shipment_gid == df_sstat.shipment_gid], 'inner')\
            .where((df_sstat.domain_name == 'SSCC') &
                   (df_sstat.status_value_gid.isin({
                       'SSCC.BOL_ACTUALS_ENTERED_TRANSMISSION',
                       'SSCC.BOL DELETED_NO',
                       'SSCC.SECURE RESOURCES_ACCEPTED',
                       'SSCC.SECURE RESOURCES_PICKUP NOTIFICATION'})))\
            .groupBy(df_sr.sr_shipment_gid,
                     df_sr.shipment_refnum_value,
                     df_sr.total_shipment_gid_tons_from_inv).count()\
            .where(F.col('count') > 2) \
            .select(df_sr.sr_shipment_gid,
                    df_sr.shipment_refnum_value,
                    df_sr.total_shipment_gid_tons_from_inv)

        df_ss = df_ss.groupBy(df_ss.domain_name,
                              df_ss.shipment_gid,
                              df_ss.stop_num,
                              df_ss.dist_from_prev_stop_base)\
            .agg(F.max(df_ss.insert_date).alias("max_insert_date"))
        df_ss = df_ss.select(df_ss.domain_name,
                             df_ss.shipment_gid,
                             df_ss.stop_num,
                             df_ss.dist_from_prev_stop_base,
                             df_ss.max_insert_date)\
            .orderBy(df_ss.shipment_gid,
                     df_ss.stop_num,
                     df_ss.dist_from_prev_stop_base,
                     df_ss.max_insert_date.desc())\
            .dropDuplicates(["shipment_gid",
                             "stop_num"])

        # Filter the Shipment_Stop dataframe to only include records
        # with domain_name of SSCC and then add up the dist_from_prev_stop_base values
        # to determine the mileage
        df_ss = df_ss.where(df_ss.domain_name == 'SSCC')\
            .groupBy(df_ss.shipment_gid)\
            .agg(F.sum('dist_from_prev_stop_base').alias('mileage'))
        df_ss = df_ss.withColumnRenamed('shipment_gid', 'ss_shipment_gid')

        df_cost = df_cost.groupBy(df_cost.cost_type,
                                  df_cost.cost_base,
                                  df_cost.accessorial_code_gid,
                                  df_cost.is_weighted,
                                  df_cost.domain_name,
                                  df_cost.shipment_gid)\
            .agg(F.max(df_cost.insert_date).alias("max_insert_date"))

        # Drop the extra entries based on the max insert date
        df_cost = df_cost.orderBy(df_cost.shipment_gid,
                                  df_cost.max_insert_date.desc())\
            .dropDuplicates(["shipment_gid", "cost_base"])\
            .select(df_cost.shipment_gid,
                    df_cost.accessorial_code_gid,
                    df_cost.cost_type,
                    df_cost.cost_base,
                    df_cost.is_weighted,
                    df_cost.domain_name)

        # Filter the Shipment_Cost dataframe to only include records
        # with domain_name of SSCC
        df_cost = df_cost.where(df_cost.domain_name == 'SSCC')

        # Create a dataframe from Shipment_Cost that includes the detention costs
        df_cost_det = df_cost.where(df_cost.cost_type == 'A')\
            .where(df_cost.accessorial_code_gid.isin({
                'SSCC.DETENTION',
                'SSCC.DETENTION_DESTINATION',
                'SSCC.DTL LOADING',
                'SSCC.STORAGE'}))\
            .groupBy(df_cost.shipment_gid)\
            .agg(F.sum('cost_base').alias('det_cost_base_sum'))\
            .select(df_cost.shipment_gid, 'det_cost_base_sum')
        df_cost_det = df_cost_det.withColumnRenamed('shipment_gid',
                                                    'det_shipment_gid')
        df_cost_det = df_cost_det\
            .withColumn('det_cost_base_sum',
                        F.when(df_cost_det.det_cost_base_sum.isNotNull(), df_cost_det.det_cost_base_sum)
                        .otherwise(0))

        # Create a dataframe from Shipment_Cost that includes the accessorial costs
        df_cost_acc = df_cost.where(df_cost.cost_type.isin({'A', 'S', 'O'}))\
            .where(df_cost.accessorial_code_gid.isin({
                'SSCC.DETENTION',
                'SSCC.DETENTION_DESTINATION',
                'SSCC.DTL LOADING',
                'SSCC.STORAGE'}) == False)\
            .where(F.split(df_cost.accessorial_code_gid, '.').getItem(1).contains('FSC') == False)\
            .where(df_cost.is_weighted == 'N')\
            .groupBy(df_cost.shipment_gid)\
            .agg(F.sum('cost_base').alias('acc_cost_base_sum'))\
            .select(df_cost.shipment_gid, 'acc_cost_base_sum')
        df_cost_acc = df_cost_acc.withColumnRenamed('shipment_gid',
                                                    'acc_shipment_gid')
        df_cost_acc = df_cost_acc\
            .withColumn('acc_cost_base_sum',
                        F.when(df_cost_acc.acc_cost_base_sum.isNotNull(), df_cost_acc.acc_cost_base_sum)
                        .otherwise(0))

        # Create a dataframe from Shipment_Cost that includes the accessorial costs
        df_cost_fsrchg = df_cost.where(df_cost.cost_type == 'A')\
            .where(F.split(df_cost.accessorial_code_gid, '.').getItem(1).contains('FSC'))\
            .groupBy(df_cost.shipment_gid)\
            .agg(F.sum('cost_base').alias('fsrchg_cost_base_sum'))\
            .select(df_cost.shipment_gid, 'fsrchg_cost_base_sum')
        df_cost_fsrchg = df_cost_fsrchg.withColumnRenamed(
            'shipment_gid', 'fsrchg_shipment_gid')
        df_cost_fsrchg = df_cost_fsrchg\
            .withColumn('fsrchg_cost_base_sum',
                        F.when(df_cost_fsrchg.fsrchg_cost_base_sum.isNotNull(), df_cost_fsrchg.fsrchg_cost_base_sum)
                        .otherwise(0))

        # Create a dataframe from Shipment_Cost that includes the base rate costs
        # for LTL based shipments.  This will be later joined to the Shipment
        # table to apply a Freight cost value if the transport mode is LTL
        df_cost_ltl_base = df_cost.where(df_cost.cost_type.isin({'B', 'D'}))\
            .groupBy(df_cost.shipment_gid)\
            .agg(F.sum('cost_base').alias('ltl_base_cost_base_sum'))\
            .select(df_cost.shipment_gid, 'ltl_base_cost_base_sum')
        df_cost_ltl_base = df_cost_ltl_base.withColumnRenamed(
            'shipment_gid', 'ltl_base_shipment_gid')
        df_cost_ltl_base = df_cost_ltl_base\
            .withColumn('ltl_base_cost_base_sum',
                        F.when(df_cost_ltl_base.ltl_base_cost_base_sum.isNotNull(),
                               df_cost_ltl_base.ltl_base_cost_base_sum)
                        .otherwise(0))

        # Create a dataframe from Shipment_Cost that includes the base rate costs
        # for non-LTL based shipments.  This will be later joined to the Shipment
        # table to apply a Freight cost value if the transport mode is not LTL
        df_cost_nonltl_base = df_cost.where(df_cost.cost_type == 'B')\
            .groupBy(df_cost.shipment_gid)\
            .agg(F.sum('cost_base').alias('nonltl_base_cost_base_sum')) \
            .select(df_cost.shipment_gid, 'nonltl_base_cost_base_sum')
        df_cost_nonltl_base = df_cost_nonltl_base.withColumnRenamed(
            'shipment_gid', 'nonltl_base_shipment_gid')
        df_cost_nonltl_base = df_cost_nonltl_base\
            .withColumn('nonltl_base_cost_base_sum',
                        F.when(df_cost_nonltl_base.nonltl_base_cost_base_sum.isNotNull(),
                               df_cost_nonltl_base.nonltl_base_cost_base_sum)
                        .otherwise(0))

        df_s = df_s.groupBy(df_s.shipment_gid,
                            df_s.transport_mode_gid,
                            df_s.total_weight_base) \
            .agg(F.max(df_s.insert_date).alias("max_insert_date"))
        df_s = df_s.orderBy(df_s.shipment_gid, df_s.max_insert_date.desc())\
            .dropDuplicates(["shipment_gid"])

        # This join filters down the ref_nums to only shipments with good statuses that are
        # relevant to our invoices.
        df = df_sr.join(df_s, [df_sr.sr_shipment_gid == df_s.shipment_gid],
                        'left_outer')
        df = df.join(df_ss, [df_ss.ss_shipment_gid == df.sr_shipment_gid],
                     'left_outer')
        df = df.join(df_cost_det,
                     [df_cost_det.det_shipment_gid == df.sr_shipment_gid],
                     'left_outer')
        df = df.join(df_cost_acc,
                     [df_cost_acc.acc_shipment_gid == df.sr_shipment_gid],
                     'left_outer')
        df = df.join(
            df_cost_fsrchg,
            [df_cost_fsrchg.fsrchg_shipment_gid == df.sr_shipment_gid],
            'left_outer')
        df = df.join(
            df_cost_ltl_base,
            [df_cost_ltl_base.ltl_base_shipment_gid == df.sr_shipment_gid],
            'left_outer')
        df = df.join(df_cost_nonltl_base, [
            df_cost_nonltl_base.nonltl_base_shipment_gid == df.sr_shipment_gid
        ], 'left_outer')
        df = df.select(df.shipment_gid, df.transport_mode_gid,
                       df.total_shipment_gid_tons_from_inv,
                       df.shipment_refnum_value, df.mileage,
                       df.det_cost_base_sum, df.acc_cost_base_sum,
                       df.fsrchg_cost_base_sum, df.ltl_base_cost_base_sum,
                       df.nonltl_base_cost_base_sum)
        df = df.withColumn('det_cost_base_sum',
                           F.when(df.det_cost_base_sum.isNotNull(), df.det_cost_base_sum)
                           .otherwise(0))\
            .withColumn('acc_cost_base_sum',
                        F.when(df.acc_cost_base_sum.isNotNull(), df.acc_cost_base_sum)
                        .otherwise(0))\
            .withColumn('fsrchg_cost_base_sum',
                        F.when(df.fsrchg_cost_base_sum.isNotNull(), df.fsrchg_cost_base_sum)
                        .otherwise(0))\
            .withColumn('ltl_base_cost_base_sum',
                        F.when(df.ltl_base_cost_base_sum.isNotNull(), df.ltl_base_cost_base_sum)
                        .otherwise(0))\
            .withColumn('nonltl_base_cost_base_sum',
                        F.when(df.nonltl_base_cost_base_sum.isNotNull(), df.nonltl_base_cost_base_sum)
                        .otherwise(0))

        # Calculate the individual costs based on all of the joined tables.
        df = df.withColumn('tons',
                           F.when(df.total_shipment_gid_tons_from_inv.isNull(), 0)
                           .otherwise(df.total_shipment_gid_tons_from_inv).cast(T.DecimalType(38, 18)))\
            .withColumn('detention',
                        F.when(df.det_cost_base_sum.isNull(), 0)
                        .otherwise(df.det_cost_base_sum).cast(T.DecimalType(38, 18)))\
            .withColumn('accessorials',
                        F.when(df.acc_cost_base_sum.isNull(), 0)
                        .otherwise(df.acc_cost_base_sum).cast(T.DecimalType(38, 18)))\
            .withColumn('fuel_surcharge',
                        F.when(df.fsrchg_cost_base_sum.isNull(), 0)
                        .otherwise(df.fsrchg_cost_base_sum).cast(T.DecimalType(38, 18)))\
            .withColumn('base_rate',
                        F.when((df.ltl_base_cost_base_sum.isNull() & df.nonltl_base_cost_base_sum.isNull()), 0)
                        .when(F.trim(F.upper(df.transport_mode_gid)) == 'LTL', df.ltl_base_cost_base_sum)
                        .otherwise(df.nonltl_base_cost_base_sum).cast(T.DecimalType(38, 18)))
        # Calculate the freight rate per ton
        df = df.withColumn('freight_rate_per_ton',
                           F.when(df.tons > 0,
                                  (df.detention + df.accessorials + df.fuel_surcharge + df.base_rate) / df.tons)
                           .otherwise(0))\
            .withColumnRenamed('shipment_refnum_value', 'bol_number_join')

        # It is possible for multiple shipment_gid to match to the same bol_number and have the same cost.
        # This distinct removes those cases so as not to introduce duplicates when joining with the invoice table.
        df = (df.select(df.freight_rate_per_ton,
                        df.bol_number_join).distinct())

        return df
Example #16
from pyspark.sql import SparkSession
import pyspark.sql.functions as F
import pyspark.sql.types as T

# writeLegacyFormat makes Spark's decimal type compatible with Hive's decimal type.
spark = SparkSession.builder\
.config("spark.sql.parquet.writeLegacyFormat",True)\
.enableHiveSupport().getOrCreate()

df = spark.read.csv('hdfs://hive-namenode:8020/user/sqoop/restaurant_detail/part-m-00000', header=False)
rename = {
    '_c0' : 'id',
    '_c1' : 'restaurant_name',
    '_c2' : 'category',
    '_c3' : 'estimated_cooking_time',
    '_c4' : 'latitude',
    '_c5' : 'longitude',
}
df = df.toDF(*[rename[c] for c in df.columns])
df = df.withColumn('estimated_cooking_time', F.col('estimated_cooking_time').cast(T.FloatType()))
df = df.withColumn('latitude', F.col('latitude').cast(T.DecimalType(11,8)))
df = df.withColumn('longitude', F.col('longitude').cast(T.DecimalType(11,8)))
df = df.withColumn('dt', F.lit("latest"))
df.write.parquet('hdfs://hive-namenode:8020/user/spark/transformed_restaurant_detail', partitionBy='dt', mode='overwrite')
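
A quick verification sketch, assuming the same SparkSession and that the write above succeeded; latitude and longitude should read back as decimal(11,8).

out = spark.read.parquet('hdfs://hive-namenode:8020/user/spark/transformed_restaurant_detail')
out.printSchema()
out.select('latitude', 'longitude', 'dt').show(5, truncate=False)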
Example #17
db_path = os.environ.get("DATABASE_PATH")
output_path = os.environ.get("OUTPUT_PATH")

RAW_DATA_PATH = f"{db_path}/teinvento_inc/ventas_reportadas_mercado_tamales/mx/20200801/"

today = date.today().strftime("%Y%m%d")

# TeInvento Inc data
FACT_DATA_PATH = f"{RAW_DATA_PATH}fact_table/"
PRODUCT_DIM_DATA_PATH = f"{RAW_DATA_PATH}product_dim/"
REGION_DIM_DATA_PATH = f"{RAW_DATA_PATH}region_dim/"

fact_data_schema = st.StructType([
    st.StructField('year', st.StringType(), True),
    st.StructField('month', st.StringType(), True),
    st.StructField('sales', st.DecimalType(10, 2), True),
    st.StructField('region_id', st.StringType(), True),
    st.StructField('product_id', st.StringType(), True)
])

region_data_schema = st.StructType([
    st.StructField('region_id', st.StringType(), True),
    st.StructField('country', st.StringType(), True),
    st.StructField('location', st.StringType(), True)
])

product_data_schema = st.StructType([
    st.StructField('product_id', st.StringType(), True),
    st.StructField('type', st.StringType(), True),
    st.StructField('vendor', st.StringType(), True),
    st.StructField('flavor', st.StringType(), True),
Example #18
def process_song_data(spark, input_data, output_data):
    """
    Description: Processes song-data JSON files from the given input path,
    transforms them into songs and artists Spark tables, and writes these
    tables to the given output path as parquet tables.

    Arguments:
        spark: SparkSession object.
        input_data: Path to the input JSON files.
        output_data: Path to the output directory that stores output parquet tables.

    Returns:
        df: Song data dataframe.
    """
    # get filepath to song data file
    song_data = input_data + 'song-data/A/B/*/*.json'

    # define schema for song data file
    song_schema = t.StructType([
        t.StructField("artist_id", t.StringType(), True),
        t.StructField("artist_latitude", t.DecimalType(11, 7), True),
        t.StructField("artist_location", t.StringType(), True),
        t.StructField("artist_longitude", t.DecimalType(11, 7), True),
        t.StructField("artist_name", t.StringType(), True),
        t.StructField("duration", t.DecimalType(11, 7), True),
        t.StructField("num_songs", t.IntegerType(), True),
        t.StructField("song_id", t.StringType(), True),
        t.StructField("title", t.StringType(), True),
        t.StructField("year", t.ShortType(), True)
    ])

    # read song data file using schema
    df = spark \
        .read \
        .format("json") \
        .schema(song_schema) \
        .load(song_data)

    # extract columns to create songs table
    songs_table = df \
        .select(['song_id', 'title', 'artist_id', 'year', 'duration']) \
        .dropDuplicates()

    # write songs table to parquet files partitioned by year and artist
    songs_output = output_data + 'songs'

    songs_table \
        .write \
        .partitionBy('year', 'artist_id') \
        .option("path", songs_output) \
        .saveAsTable('songs', format='parquet')

    # extract columns to create artists table
    artists_table = df \
        .select(['artist_id', 'artist_name', 'artist_location', 'artist_longitude', 'artist_latitude']) \
        .dropDuplicates()

    # write artists table to parquet files
    artists_output = output_data + 'artists'

    artists_table \
        .write \
        .option("path", artists_output) \
        .saveAsTable('artists', format='parquet')

    return df
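
A hedged wiring sketch: the dataframe returned here is what process_log_data in Example #20 expects as its song_df argument; the bucket paths are placeholders.

song_df = process_song_data(spark, "s3a://my-input-bucket/", "s3a://my-output-bucket/")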
Example #19
def as_spark_type(tpe: Union[str, type, Dtype],
                  *,
                  raise_error: bool = True,
                  prefer_timestamp_ntz: bool = False) -> types.DataType:
    """
    Given a Python type, returns the equivalent spark type.
    Accepts:
    - the built-in types in Python
    - the built-in types in numpy
    - list of pairs of (field_name, type)
    - dictionaries of field_name -> type
    - Python3's typing system
    """
    # For NumPy typing, NumPy version should be 1.21+ and Python version should be 3.8+
    if sys.version_info >= (3, 8) and LooseVersion(
            np.__version__) >= LooseVersion("1.21"):
        if (hasattr(tpe, "__origin__")
                and tpe.__origin__ is np.ndarray  # type: ignore[union-attr]
                and hasattr(tpe, "__args__")
                and len(tpe.__args__) > 1  # type: ignore[union-attr]
            ):
            # numpy.typing.NDArray
            return types.ArrayType(
                as_spark_type(
                    tpe.__args__[1].__args__[0],
                    raise_error=raise_error  # type: ignore[union-attr]
                ))

    if isinstance(tpe, np.dtype) and tpe == np.dtype("object"):
        pass
    # ArrayType
    elif tpe in (np.ndarray, ):
        return types.ArrayType(types.StringType())
    elif hasattr(tpe, "__origin__") and issubclass(
            tpe.__origin__,
            list  # type: ignore[union-attr]
    ):
        element_type = as_spark_type(
            tpe.__args__[0],
            raise_error=raise_error  # type: ignore[union-attr]
        )
        if element_type is None:
            return None
        return types.ArrayType(element_type)
    # BinaryType
    elif tpe in (bytes, np.character, np.bytes_, np.string_):
        return types.BinaryType()
    # BooleanType
    elif tpe in (bool, np.bool_, "bool", "?"):
        return types.BooleanType()
    # DateType
    elif tpe in (datetime.date, ):
        return types.DateType()
    # NumericType
    elif tpe in (np.int8, np.byte, "int8", "byte", "b"):
        return types.ByteType()
    elif tpe in (decimal.Decimal, ):
        # TODO: considering about the precision & scale for decimal type.
        return types.DecimalType(38, 18)
    elif tpe in (float, np.float_, np.float64, "float", "float64", "double"):
        return types.DoubleType()
    elif tpe in (np.float32, "float32", "f"):
        return types.FloatType()
    elif tpe in (np.int32, "int32", "i"):
        return types.IntegerType()
    elif tpe in (int, np.int64, "int", "int64", "long"):
        return types.LongType()
    elif tpe in (np.int16, "int16", "short"):
        return types.ShortType()
    # StringType
    elif tpe in (str, np.unicode_, "str", "U"):
        return types.StringType()
    # TimestampType or TimestampNTZType if timezone is not specified.
    elif tpe in (datetime.datetime, np.datetime64, "datetime64[ns]", "M"):
        return types.TimestampNTZType(
        ) if prefer_timestamp_ntz else types.TimestampType()

    # categorical types
    elif isinstance(tpe, CategoricalDtype) or (isinstance(tpe, str)
                                               and tpe == "category"):
        return types.LongType()

    # extension types
    elif extension_dtypes_available:
        # IntegralType
        if isinstance(tpe, Int8Dtype) or (isinstance(tpe, str)
                                          and tpe == "Int8"):
            return types.ByteType()
        elif isinstance(tpe, Int16Dtype) or (isinstance(tpe, str)
                                             and tpe == "Int16"):
            return types.ShortType()
        elif isinstance(tpe, Int32Dtype) or (isinstance(tpe, str)
                                             and tpe == "Int32"):
            return types.IntegerType()
        elif isinstance(tpe, Int64Dtype) or (isinstance(tpe, str)
                                             and tpe == "Int64"):
            return types.LongType()

        if extension_object_dtypes_available:
            # BooleanType
            if isinstance(tpe, BooleanDtype) or (isinstance(tpe, str)
                                                 and tpe == "boolean"):
                return types.BooleanType()
            # StringType
            elif isinstance(tpe, StringDtype) or (isinstance(tpe, str)
                                                  and tpe == "string"):
                return types.StringType()

        if extension_float_dtypes_available:
            # FractionalType
            if isinstance(tpe, Float32Dtype) or (isinstance(tpe, str)
                                                 and tpe == "Float32"):
                return types.FloatType()
            elif isinstance(tpe, Float64Dtype) or (isinstance(tpe, str)
                                                   and tpe == "Float64"):
                return types.DoubleType()

    if raise_error:
        raise TypeError("Type %s was not understood." % tpe)
    else:
        return None
Example #20
def process_log_data(spark, input_data, output_data, song_df):
    """
    Description: Processes log-data JSON files from the given input path,
    transforms them into users, time and songplays Spark tables, and writes
    these tables to the given output path as parquet tables.

    Arguments:
        spark: SparkSession object.
        input_data: Path to the input JSON files.
        output_data: Path to the output directory that stores output parquet tables.
        song_df: Song data dataframe.

    Returns:
        None.
    """
    # get filepath to log data file
    log_data = input_data + 'log-data/2018/11'

    # define schema for log data file
    log_schema = t.StructType([
        t.StructField("artist", t.StringType(), True),
        t.StructField("auth", t.StringType(), True),
        t.StructField("firstName", t.StringType(), True),
        t.StructField("gender", t.StringType(), True),
        t.StructField("itemInSession", t.IntegerType(), True),
        t.StructField("lastName", t.StringType(), True),
        t.StructField("length", t.DecimalType(12, 7), True),
        t.StructField("level", t.StringType(), True),
        t.StructField("location", t.StringType(), True),
        t.StructField("method", t.StringType(), True),
        t.StructField("page", t.StringType(), True),
        t.StructField("registration", t.DecimalType(16, 2), True),
        t.StructField("sessionId", t.IntegerType(), True),
        t.StructField("song", t.StringType(), True),
        t.StructField("status", t.IntegerType(), True),
        t.StructField("ts", t.LongType(), True),
        t.StructField("userAgent", t.StringType(), True),
        t.StructField("userId", t.StringType(), True)
    ])

    # read log data file using schema
    df = spark \
        .read \
        .format("json") \
        .schema(log_schema) \
        .load(log_data)

    # filter by actions for song plays
    df = df \
        .filter('page = "NextSong"')

    # group by userId for unique users
    users_list = df \
        .groupBy('userId') \
        .agg(f.max('ts').alias('ts'))

    # extract columns to create users table
    users_table = df \
        .join(users_list, ['userId', 'ts'], 'inner') \
        .select([df.userId.cast(t.IntegerType()).alias('user_id'), col('firstName').alias('first_name'), col('lastName').alias('last_name'), 'gender', 'level']) \
        .dropDuplicates()

    # write users table to parquet files
    users_output = output_data + 'users'

    users_table \
        .write \
        .option("path", users_output) \
        .saveAsTable('users', format='parquet')

    # create timestamp column from original timestamp column
    df = df \
        .withColumn('timestamp', f.from_utc_timestamp((df.ts/1000.0).cast('timestamp'), 'UTC'))

    # create datetime column from original timestamp column
    get_datetime = udf(lambda ts: datetime.fromtimestamp(ts / 1000.0),
                       t.TimestampType())
    df = df.withColumn('datetime', get_datetime('ts'))

    # extract columns to create time table
    time_table = df \
        .select([col('datetime').alias('start_time'), dayofmonth(col('datetime')).alias('day'), weekofyear(col('datetime')).alias('week'), month(col('datetime')).alias('month'), year(col('datetime')).alias('year'), dayofweek(col('datetime')).alias('weekday')]) \
        .dropDuplicates()

    # write time table to parquet files partitioned by year and month
    time_output = output_data + 'time'

    time_table \
        .write \
        .partitionBy('year', 'month') \
        .option("path", time_output) \
        .saveAsTable('time', format='parquet')

    # join and extract columns from song and log datasets to create songplays table
    cond = [
        df.artist == song_df.artist_name, df.song == song_df.title,
        df.length == song_df.duration
    ]
    songplays_df = df.join(song_df, cond, 'left')

    songplays_df = songplays_df \
        .select(df.datetime.alias('start_time'), df.userId.alias('user_id'), df.level.alias('level'), song_df.song_id.alias('song_id'), song_df.artist_id.alias('artist_id'), df.sessionId.alias('session_id'), df.location.alias('location'), df.userAgent.alias('user_agent'), year(df.datetime).alias('year'), month(df.datetime).alias('month'))
    w = Window().orderBy(f.lit('A'))
    songplays_table = songplays_df.withColumn('songplay_id',
                                              f.row_number().over(w))

    # write songplays table to parquet files partitioned by year and month
    songplays_output = output_data + 'songplays'

    songplays_table \
        .select(['songplay_id', 'start_time', 'user_id', 'level', 'song_id', 'artist_id', 'session_id', 'location', 'user_agent', 'year', 'month'])\
        .write \
        .partitionBy('year', 'month') \
        .option("path", songplays_output) \
        .saveAsTable('songplays', format='parquet')
Example #21
def as_spark_type(tpe: typing.Union[str, type, Dtype],
                  *,
                  raise_error: bool = True) -> types.DataType:
    """
    Given a Python type, returns the equivalent spark type.
    Accepts:
    - the built-in types in Python
    - the built-in types in numpy
    - list of pairs of (field_name, type)
    - dictionaries of field_name -> type
    - Python3's typing system
    """
    # TODO: Add "boolean" and "string" types.
    # ArrayType
    if tpe in (np.ndarray, ):
        return types.ArrayType(types.StringType())
    elif hasattr(tpe, "__origin__") and issubclass(tpe.__origin__,
                                                   list):  # type: ignore
        element_type = as_spark_type(tpe.__args__[0],
                                     raise_error=raise_error)  # type: ignore
        if element_type is None:
            return None
        return types.ArrayType(element_type)
    # BinaryType
    elif tpe in (bytes, np.character, np.bytes_, np.string_):
        return types.BinaryType()
    # BooleanType
    elif tpe in (bool, np.bool, "bool", "?"):
        return types.BooleanType()
    # DateType
    elif tpe in (datetime.date, ):
        return types.DateType()
    # NumericType
    elif tpe in (np.int8, np.byte, "int8", "byte", "b"):
        return types.ByteType()
    elif tpe in (decimal.Decimal, ):
        # TODO: considering about the precision & scale for decimal type.
        return types.DecimalType(38, 18)
    elif tpe in (float, np.float, np.float64, "float", "float64", "double"):
        return types.DoubleType()
    elif tpe in (np.float32, "float32", "f"):
        return types.FloatType()
    elif tpe in (np.int32, "int32", "i"):
        return types.IntegerType()
    elif tpe in (int, np.int, np.int64, "int", "int64", "long"):
        return types.LongType()
    elif tpe in (np.int16, "int16", "short"):
        return types.ShortType()
    # StringType
    elif tpe in (str, np.unicode_, "str", "U"):
        return types.StringType()
    # TimestampType
    elif tpe in (datetime.datetime, np.datetime64, "datetime64[ns]", "M"):
        return types.TimestampType()

    # categorical types
    elif isinstance(tpe, CategoricalDtype) or (isinstance(tpe, str)
                                               and tpe == "category"):
        return types.LongType()

    # extension types
    elif extension_dtypes_available:
        # IntegralType
        if isinstance(tpe, Int8Dtype) or (isinstance(tpe, str)
                                          and tpe == "Int8"):
            return types.ByteType()
        elif isinstance(tpe, Int16Dtype) or (isinstance(tpe, str)
                                             and tpe == "Int16"):
            return types.ShortType()
        elif isinstance(tpe, Int32Dtype) or (isinstance(tpe, str)
                                             and tpe == "Int32"):
            return types.IntegerType()
        elif isinstance(tpe, Int64Dtype) or (isinstance(tpe, str)
                                             and tpe == "Int64"):
            return types.LongType()

        if extension_object_dtypes_available:
            # BooleanType
            if isinstance(tpe, BooleanDtype) or (isinstance(tpe, str)
                                                 and tpe == "boolean"):
                return types.BooleanType()
            # StringType
            elif isinstance(tpe, StringDtype) or (isinstance(tpe, str)
                                                  and tpe == "string"):
                return types.StringType()

        if extension_float_dtypes_available:
            # FractionalType
            if isinstance(tpe, Float32Dtype) or (isinstance(tpe, str)
                                                 and tpe == "Float32"):
                return types.FloatType()
            elif isinstance(tpe, Float64Dtype) or (isinstance(tpe, str)
                                                   and tpe == "Float64"):
                return types.DoubleType()

    if raise_error:
        raise TypeError("Type %s was not understood." % tpe)
    else:
        return None
Example #22
def parse(path_to_dir):
    if 'DAS5' in os.environ:  # If we want to execute it on the DAS-5 super computer
        print("We are on DAS5, {0} is master.".format(os.environ['HOSTNAME'] + ".ib.cluster"))
        spark = SparkSession.builder \
            .master("spark://" + os.environ['HOSTNAME'] + ".ib.cluster:7077") \
            .appName("WTA parser") \
            .config("spark.executor.memory", "28G") \
            .config("spark.executor.cores", "8") \
            .config("spark.executor.instances", "10") \
            .config("spark.driver.memory", "40G") \
            .getOrCreate()
    else:
        findspark.init(spark_home="<path_to_spark>")
        spark = SparkSession.builder \
            .master("local[8]") \
            .appName("WTA parser") \
            .config("spark.executor.memory", "20G") \
            .config("spark.driver.memory", "8G") \
            .getOrCreate()

    # Convert times from microseconds (which do not fit in a long) to milliseconds
    convert_micro_to_milliseconds = F.udf(lambda x: x / 1000)

    if not os.path.exists(os.path.join(TARGET_DIR, TaskState.output_path())):
        print("######\n Start parsing TaskState\n ######")
        task_usage_df = spark.read.format('com.databricks.spark.csv').options(mode="FAILFAST", inferschema="true").load(
            os.path.join(path_to_dir, 'task_usage', '*.csv'))
        # task_usage_df = spark.read.format('com.databricks.spark.csv').options(mode="FAILFAST", inferschema="true").load(
        #     'fake_task_usage.csv')
        oldColumns = task_usage_df.schema.names
        newColumns = ["ts_start",
                      "ts_end",
                      "workflow_id",
                      "id",
                      "resource_id",
                      "cpu_rate",
                      "memory_consumption",
                      "assigned_memory_usage",
                      "unmapped_page_cache",
                      "total_page_cache",
                      "max_memory_usage",
                      "mean_disk_io_time",
                      "mean_local_disk_space_usage",
                      "max_cpu_rate",
                      "max_disk_io_time",
                      "cycles_per_instruction",
                      "memory_accesses_per_instruction",
                      "sample_portion",
                      "aggregation_type",
                      "sampled_cpu_usage", ]

        task_usage_df = reduce(lambda data, idx: data.withColumnRenamed(oldColumns[idx], newColumns[idx]),
                               range(len(oldColumns)), task_usage_df)

        # Drop columns with too low level details
        task_usage_df = task_usage_df.drop('memory_accesses_per_instruction')
        task_usage_df = task_usage_df.drop('cycles_per_instruction')
        task_usage_df = task_usage_df.drop('unmapped_page_cache')
        task_usage_df = task_usage_df.drop('total_page_cache')

        # Convert the timestamps from micro- to milliseconds and cast them to long.
        task_usage_df = task_usage_df.withColumn('ts_start', convert_micro_to_milliseconds(F.col('ts_start')))
        task_usage_df = task_usage_df.withColumn('ts_start', F.col('ts_start').cast(T.LongType()))
        task_usage_df = task_usage_df.withColumn('ts_end', convert_micro_to_milliseconds(F.col('ts_end')))
        task_usage_df = task_usage_df.withColumn('ts_end', F.col('ts_end').cast(T.LongType()))

        # Some fields have weird symbols in them, clean those.
        truncate_at_lt_symbol_udf = F.udf(lambda x: re.sub('[^0-9.eE\-+]', '', str(x)) if x is not None else x)
        task_usage_df = task_usage_df.withColumn('workflow_id', truncate_at_lt_symbol_udf(F.col('workflow_id')))
        task_usage_df = task_usage_df.withColumn('max_cpu_rate', truncate_at_lt_symbol_udf(F.col('max_cpu_rate')))

        # Now that the columns have been sanitized, cast them to the right type
        task_usage_df = task_usage_df.withColumn('workflow_id', F.col('workflow_id').cast(T.LongType()))
        task_usage_df = task_usage_df.withColumn('max_cpu_rate', F.col('max_cpu_rate').cast(T.FloatType()))

        task_usage_df.write.parquet(os.path.join(TARGET_DIR, TaskState.output_path()), mode="overwrite",
                                    compression="snappy")
        print("######\n Done parsing TaskState\n ######")

    if not os.path.exists(os.path.join(TARGET_DIR, Task.output_path())):

        if 'task_usage_df' not in locals():
            task_usage_df = spark.read.parquet(os.path.join(TARGET_DIR, TaskState.output_path()))

        print("######\n Start parsing Tasks\n ######")
        task_df = spark.read.format('com.databricks.spark.csv').options(inferschema="true", mode="FAILFAST",
                                                                        parserLib="univocity").load(
            os.path.join(path_to_dir, 'task_events', '*.csv'))

        oldColumns = task_df.schema.names
        newColumns = ["ts_submit",
                      "missing_info",
                      "workflow_id",
                      "id",
                      "resource_id",
                      "event_type",
                      "user_id",
                      "scheduler",
                      "nfrs",
                      "resources_requested",
                      "memory_requested",
                      "disk_space_request",
                      "machine_restrictions", ]

        task_df = reduce(lambda data, idx: data.withColumnRenamed(oldColumns[idx], newColumns[idx]),
                         range(len(oldColumns)), task_df)

        task_df = task_df.withColumn('ts_submit', convert_micro_to_milliseconds(F.col('ts_submit')))
        task_df = task_df.withColumn('ts_submit', F.col('ts_submit').cast(T.LongType()))

        # Filter tasks that never reached completion
        task_df.createOrReplaceTempView("task_table")
        task_df = spark.sql("""WITH filtered_tasks AS (
        SELECT DISTINCT t1.workflow_id AS workflow_id, t1.id AS id
            FROM task_table t1
            WHERE t1.event_type IN(0, 1, 4)
            group by t1.workflow_id, t1.id
            having count(distinct event_type) = 3
        )
    SELECT t.*
    FROM task_table t INNER JOIN filtered_tasks f
    ON t.id = f.id AND t.workflow_id = f.workflow_id""")

        task_aggregation_structtype = T.StructType([
            T.StructField("workflow_id", T.LongType(), True),
            T.StructField("id", T.LongType(), True),
            T.StructField("type", T.StringType(), True),
            T.StructField("ts_submit", T.LongType(), True),
            T.StructField("submission_site", T.LongType(), True),
            T.StructField("runtime", T.LongType(), True),
            T.StructField("resource_type", T.StringType(), True),
            T.StructField("resource_amount_requested", T.DoubleType(), True),
            T.StructField("parents", T.ArrayType(T.LongType()), True),
            T.StructField("children", T.ArrayType(T.LongType()), True),
            T.StructField("user_id", T.LongType(), True),
            T.StructField("group_id", T.LongType(), True),
            T.StructField("nfrs", T.StringType(), True),
            T.StructField("wait_time", T.LongType(), True),
            T.StructField("params", T.StringType(), True),
            T.StructField("memory_requested", T.DoubleType(), True),
            T.StructField("network_io_time", T.DoubleType(), True),
            T.StructField("disk_space_requested", T.DoubleType(), True),
            T.StructField("energy_consumption", T.DoubleType(), True),
            T.StructField("resource_used", T.StringType(), True),
        ])
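        # This StructType is the return schema of the GROUPED_MAP pandas UDF below; the
        # pandas DataFrame returned by the UDF must provide a value for each of these fields.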

        # Compute based on the event type
        @F.pandas_udf(returnType=task_aggregation_structtype, functionType=F.PandasUDFType.GROUPED_MAP)
        def compute_aggregated_task_usage_metrics(df):
            def get_first_non_value_in_column(column_name):
                s = df[column_name]
                idx = s.first_valid_index()
                return s.loc[idx] if idx is not None else None

            task_workflow_id = get_first_non_value_in_column("workflow_id")
            task_id = get_first_non_value_in_column("id")

            task_submit_time = df[df['event_type'] == 0]['ts_submit'].min(skipna=True)
            task_start_time = df[df['event_type'] == 1]['ts_submit'].min(skipna=True)
            task_finish_time = df[df['event_type'] == 4]['ts_submit'].max(skipna=True)

            # pandas yields NaN (not None) for missing timestamps, and a GROUPED_MAP UDF must
            # return a DataFrame, so guard with isnull and return an empty frame instead.
            if pd.isnull([task_submit_time, task_start_time, task_finish_time]).any():
                return pd.DataFrame(columns=[field.name for field in task_aggregation_structtype.fields])

            task_resource_request = df['resources_requested'].max(skipna=True)
            task_memory_request = df['memory_requested'].max(skipna=True)
            task_priority = df['nfrs'].max(skipna=True)
            task_disk_space_requested = df['disk_space_request'].max(skipna=True)

            task_machine_id_list = df.resource_id.unique()

            task_waittime = int(task_start_time) - int(task_submit_time)
            task_runtime = int(task_finish_time) - int(task_start_time)

            def default(o):
                if isinstance(o, np.int64):
                    return int(o)

            data_dict = {
                "workflow_id": task_workflow_id,
                "id": task_id,
                "type": "",  # Unknown
                "ts_submit": task_submit_time,
                "submission_site": -1,  # Unknown
                "runtime": task_runtime,
                "resource_type": "core",  # Fields are called CPU, but they are core count (see Google documentation)
                "resource_amount_requested": task_resource_request,
                "parents": [],
                "children": [],
                "user_id": mmh3.hash64(get_first_non_value_in_column("user_id"))[0],
                "group_id": -1,
                "nfrs": json.dumps({"priority": task_priority}, default=default),
                "wait_time": task_waittime,
                "params": "{}",
                "memory_requested": task_memory_request,
                "network_io_time": -1,  # Unknown
                "disk_space_requested": task_disk_space_requested,
                "energy_consumption": -1,  # Unknown
                "resource_used": json.dumps(task_machine_id_list, default=default),
            }

            return pd.DataFrame(data_dict, index=[0])

        task_df = task_df.groupBy(["workflow_id", "id"]).apply(compute_aggregated_task_usage_metrics)
        task_df.explain(True)
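        # Minimal local sanity check of the aggregation logic (hypothetical sample data;
        # GROUPED_MAP UDFs are hard to debug once they run inside Spark). Exposing the
        # wrapped Python function via `.func` is an assumption about the PySpark UDF wrapper:
        #   sample = pd.DataFrame({
        #       "workflow_id": [1, 1, 1], "id": [7, 7, 7], "event_type": [0, 1, 4],
        #       "ts_submit": [0, 10, 50], "resources_requested": [0.5] * 3,
        #       "memory_requested": [0.25] * 3, "nfrs": [2] * 3,
        #       "disk_space_request": [0.01] * 3, "resource_id": [42] * 3,
        #       "user_id": ["someone"] * 3,
        #   })
        #   compute_aggregated_task_usage_metrics.func(sample)  # -> one aggregated row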

        # Now add disk IO time - This cannot be done in the previous Pandas UDF function as
        # accessing another dataframe in the apply function is not allowed
        disk_io_structtype = T.StructType([
            T.StructField("workflow_id", T.LongType(), True),
            T.StructField("id", T.LongType(), True),
            T.StructField("disk_io_time", T.DoubleType(), True),
        ])

        @F.pandas_udf(returnType=disk_io_structtype, functionType=F.PandasUDFType.GROUPED_MAP)
        def compute_disk_io_time(df):
            def get_first_non_value_in_column(column_name):
                s = df[column_name]
                idx = s.first_valid_index()
                return s.loc[idx] if idx is not None else None

            task_workflow_id = get_first_non_value_in_column("workflow_id")
            task_id = get_first_non_value_in_column("id")

            disk_io_time = ((df['ts_end'] - df['ts_start']) * df['mean_disk_io_time']).sum(skipna=True) / 1000
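            # The expression above weights each sample's mean_disk_io_time by the length of its
            # measurement window (ts_end - ts_start, already in milliseconds) and sums the
            # products; the division by 1000 rescales the result (an interpretation of the
            # original code, not something stated by the trace documentation).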
            data_dict = {
                "workflow_id": task_workflow_id,
                "id": task_id,
                "disk_io_time": disk_io_time
            }

            return pd.DataFrame(data_dict, index=[0])

        disk_io_df = task_usage_df.select(['workflow_id', 'id', 'mean_disk_io_time', 'ts_end', 'ts_start']).groupBy(
            ["workflow_id", "id"]).apply(compute_disk_io_time)
        disk_io_df.explain(True)

        # Join on the shared column names so workflow_id and id are not duplicated in the result.
        task_df = task_df.join(disk_io_df, ["workflow_id", "id"])

        task_df.write.parquet(os.path.join(TARGET_DIR, Task.output_path()), mode="overwrite", compression="snappy")
        print("######\n Done parsing Tasks\n ######")
    else:
        task_df = spark.read.parquet(os.path.join(TARGET_DIR, Task.output_path()))

    if not os.path.exists(os.path.join(TARGET_DIR, Resource.output_path())):
        print("######\n Start parsing Resource\n ######")
        # Parse the machine information in the traces; these IDs should match the resource_ids in task_usage
        resources_structtype = T.StructType([  # StringType for the fields we drop right after loading; only the id is kept
            T.StructField("time", T.StringType(), False),
            T.StructField("id", T.LongType(), False),
            T.StructField("attribute_name", T.StringType(), False),
            T.StructField("attribute_value", T.StringType(), False),
            T.StructField("attribute_deleted", T.StringType(), False),
        ])

        resource_df = spark.read.format('com.databricks.spark.csv').schema(resources_structtype).options(
            mode="FAILFAST").load(os.path.join(path_to_dir, 'machine_attributes', '*.csv'))

        resource_df = resource_df.select(["id"])  # Only keep the ID, the rest we do not need.

        # Since the information in the traces is completely opaque, we use the educated guess from Amvrosiadis et al.
        # in their ATC 2018 article.
        resource_df = resource_df.withColumn('type', F.lit("core"))
        resource_df = resource_df.withColumn('num_resources', F.lit(8))
        resource_df = resource_df.withColumn('proc_model', F.lit("AMD Opteron Barcelona"))
        resource_df = resource_df.withColumn('memory', F.lit(-1))
        resource_df = resource_df.withColumn('disk_space', F.lit(-1))
        resource_df = resource_df.withColumn('network', F.lit(-1))
        resource_df = resource_df.withColumn('os', F.lit(""))
        resource_df = resource_df.withColumn('details', F.lit("{}"))

        # Write the resource_df to the specified location
        resource_df.write.parquet(os.path.join(TARGET_DIR, Resource.output_path()), mode="overwrite",
                                  compression="snappy")
        print("######\n Done parsing Resource\n ######")

    if not os.path.exists(os.path.join(TARGET_DIR, ResourceState.output_path())):
        print("######\n Start parsing ResourceState\n ######")
        resource_events_structtype = T.StructType([
            T.StructField("timestamp", T.DecimalType(20, 0), False),
            T.StructField("machine_id", T.LongType(), False),
            T.StructField("event_type", T.IntegerType(), False),
            T.StructField("platform_id", T.StringType(), False),
            T.StructField("available_resources", T.FloatType(), False),
            T.StructField("available_memory", T.FloatType(), False),
        ])
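        # Column layout of the Google machine_events files (per the public trace format):
        # timestamp, machine ID, event type (0 = ADD, 1 = REMOVE, 2 = UPDATE), platform ID,
        # and the normalized CPU and memory capacities of the machine.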

        resource_event_df = spark.read.format('com.databricks.spark.csv').schema(resource_events_structtype).options(
            mode="FAILFAST").load(os.path.join(path_to_dir, 'machine_events', '*.csv'))

        resource_event_df = resource_event_df.withColumn('timestamp', convert_micro_to_milliseconds(F.col('timestamp')))
        resource_event_df = resource_event_df.withColumn('timestamp', F.col('timestamp').cast(T.LongType()))

        resource_event_df = resource_event_df.withColumn('available_disk_space', F.lit(-1))
        resource_event_df = resource_event_df.withColumn('available_disk_io_bandwidth', F.lit(-1))
        resource_event_df = resource_event_df.withColumn('available_network_bandwidth', F.lit(-1))
        resource_event_df = resource_event_df.withColumn('average_load_1_minute', F.lit(-1))
        resource_event_df = resource_event_df.withColumn('average_load_5_minute', F.lit(-1))
        resource_event_df = resource_event_df.withColumn('average_load_15_minute', F.lit(-1))

        # Write the resource_event_df to the specified location
        resource_event_df.write.parquet(os.path.join(TARGET_DIR, ResourceState.output_path()), mode="overwrite",
                                        compression="snappy")
        print("######\n Done parsing ResourceState\n ######")

    if not os.path.exists(os.path.join(TARGET_DIR, Workflow.output_path())):
        print("######\n Start parsing Workflows\n ######")
        workflow_structype = T.StructType([
            T.StructField("id", T.LongType(), False),
            T.StructField("ts_submit", T.LongType(), False),
            T.StructField("task_count", T.IntegerType(), False),
            T.StructField("critical_path_length", T.LongType(), False),
            T.StructField("critical_path_task_count", T.IntegerType(), False),
            T.StructField("approx_max_concurrent_tasks", T.IntegerType(), False),
            T.StructField("nfrs", T.StringType(), False),
            T.StructField("scheduler", T.StringType(), False),
            T.StructField("total_resources", T.DoubleType(), False),
            T.StructField("total_memory_usage", T.DoubleType(), False),
            T.StructField("total_network_usage", T.LongType(), False),
            T.StructField("total_disk_space_usage", T.LongType(), False),
            T.StructField("total_energy_consumption", T.LongType(), False),
        ])

        @F.pandas_udf(returnType=workflow_structype, functionType=F.PandasUDFType.GROUPED_MAP)
        def compute_workflow_stats(df):
            id = df['workflow_id'].iloc[0]
            ts_submit = df['ts_submit'].min()
            task_count = len(df)
            critical_path_length = -1  # We do not know the task dependencies, so -1
            critical_path_task_count = -1
            approx_max_concurrent_tasks = -1
            nfrs = "{}"
            scheduler = ""
            total_resources = df['resource_amount_requested'].sum()  # TODO or assigned?
            total_memory_usage = df['memory_requested'].sum()  # TODO or consumption, or assigned?
            total_network_usage = -1
            total_disk_space_usage = -1
            total_energy_consumption = -1

            data_dict = {
                "id": id, "ts_submit": ts_submit, 'task_count': task_count,
                'critical_path_length': critical_path_length,
                'critical_path_task_count': critical_path_task_count,
                'approx_max_concurrent_tasks': approx_max_concurrent_tasks, 'nfrs': nfrs, 'scheduler': scheduler,
                'total_resources': total_resources, 'total_memory_usage': total_memory_usage,
                'total_network_usage': total_network_usage, 'total_disk_space_usage': total_disk_space_usage,
                'total_energy_consumption': total_energy_consumption
            }

            return pd.DataFrame(data_dict, index=[0])

        # Create and write the workflow dataframe
        workflow_df = task_df.groupBy('workflow_id').apply(compute_workflow_stats)
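        # Each distinct workflow_id yields exactly one aggregated row, so workflow_df ends up
        # with one record per workflow present in task_df.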

        workflow_df.write.parquet(os.path.join(TARGET_DIR, Workflow.output_path()), mode="overwrite",
                                  compression="snappy")
        print("######\n Done parsing Workflows\n ######")

    print("######\n Start parsing Workload\n ######")
    json_dict = Workload.get_json_dict_from_spark_task_dataframe(task_df,
                                                                 domain="Industrial",
                                                                 start_date="2011-05-01",
                                                                 end_date="2011-05-30",
                                                                 authors=["Google"])

    os.makedirs(os.path.join(TARGET_DIR, Workload.output_path()), exist_ok=True)
    with open(os.path.join(TARGET_DIR, Workload.output_path(), "generic_information.json"), "w") as file:
        # json cannot serialize numpy integer types (e.g. np.int64) natively, so convert them to plain ints.
        def default(o):
            if isinstance(o, np.int64):
                return int(o)

        file.write(json.dumps(json_dict, default=default))
    print("######\n Done parsing Workload\n ######")