def add_elapsed(df: pyspark.sql.DataFrame,
                cols: List[str]) -> pyspark.sql.DataFrame:
    """Add per-store "days relative to the last flagged row" columns.

    For every name in ``cols`` two integer columns are appended:
    ``"Before" + name`` (computed on a Date-descending pass) and
    ``"After" + name`` (computed on a Date-ascending pass). Each value is
    ``(row.Date - anchor).days``, where ``anchor`` is the Date of the most
    recent row (in pass order) of the same store whose ``name`` column is
    truthy, or the Date of that store's first row in the pass if no such
    row has been seen yet.

    Assumes ``Date`` values subtract to an object with ``.days`` (e.g.
    ``datetime.date``) — TODO confirm against the caller's schema.
    """

    def make_partition_mapper(flag_col, ascending):
        out_name = ("After" if ascending else "Before") + flag_col

        def mapper(rows):
            current_store = None
            anchor = None
            for row in rows:
                # A new store starts: restart the anchor at its first row.
                if row.Store != current_store:
                    current_store, anchor = row.Store, row.Date
                if row[flag_col]:
                    anchor = row.Date
                yield Row(**{**row.asDict(), out_name: (row.Date - anchor).days})

        return mapper

    # Put each store's rows into a single partition; after sorting within
    # partitions by (Store, Date) one sequential scan then sees every store's
    # rows contiguously and in date order.
    df = df.repartition(df.Store)
    for ascending in (False, True):
        date_order = df.Date.asc() if ascending else df.Date.desc()
        rdd = df.sortWithinPartitions(df.Store.asc(), date_order).rdd
        for flag_col in cols:
            rdd = rdd.mapPartitions(make_partition_mapper(flag_col, ascending))
        df = rdd.toDF()
    return df
示例#2
0
def extract_datepart(df: pyspark.sql.DataFrame,
                     dt_col: str,
                     to_extract: str,
                     drop: bool = False) -> pyspark.sql.DataFrame:
    """
    Base function for extracting dateparts. Used in less abstracted functions.

    Parameters
    ----------
    df : pyspark.sql.DataFrame
        Base dataframe which contains ``dt_col`` column for extracting ``to_extract``.
    dt_col : str
        Name of date column to extract ``to_extract`` from.
    to_extract : str
        Name of a function on the module aliased as ``F`` (presumably
        ``pyspark.sql.functions``, e.g. ``"year"``, ``"month"``,
        ``"dayofweek"``). It is applied to ``dt_col``, and its name is also
        used as the name of the new column. An unknown name raises
        ``AttributeError``.
    drop : bool
        Whether or not to drop dt_col after extraction (default is False).

    Returns
    -------
    df : pyspark.sql.DataFrame
        df with ``to_extract`` column, optionally without original ``dt_col`` column.
    """
    extractor = getattr(F, to_extract)
    result = df.withColumn(to_extract, extractor(F.col(dt_col)))
    return result.drop(dt_col) if drop else result
示例#3
0
def check_total_rows(left_df: pyspark.sql.DataFrame,
                     right_df: pyspark.sql.DataFrame) -> None:
    """Assert that ``left_df`` and ``right_df`` have the same row count.

    Raises:
        AssertionError: when the counts differ; the message shows both.
    """
    actual = left_df.count()
    expected = right_df.count()
    message = ("Number of rows are not same.\n\n"
               f"Actual Rows (left_df): {actual}\n"
               f"Expected Rows (right_df): {expected}\n")
    assert actual == expected, message
示例#4
0
def check_df_content(left_df: pyspark.sql.DataFrame,
                     right_df: pyspark.sql.DataFrame) -> None:
    """Assert that two DataFrames contain the same distinct rows.

    Both directions of ``DataFrame.subtract`` are computed and their sizes
    logged. ``subtract`` is a set difference (SQL ``EXCEPT DISTINCT``), so
    duplicate rows and row order are not compared; pair this with a
    row-count check.

    Raises:
        AssertionError: if either difference is non-empty.
    """
    logging.info('Executing: left_df - right_df')
    left_only_count = left_df.subtract(right_df).count()
    logging.info(left_only_count)

    logging.info('Executing: right_df - left_df')
    right_only_count = right_df.subtract(left_df).count()
    logging.info(right_only_count)

    # Reuse the counts above instead of re-running both subtract jobs.
    assert left_only_count == right_only_count == 0, (
        f"DataFrame contents differ: {left_only_count} row(s) only in "
        f"left_df, {right_only_count} row(s) only in right_df.")
示例#5
0
def group_mean(data: pyspark.sql.DataFrame, groups, response, features):
    """Return the mean of ``features`` for each group value.

    Parameters
    ----------
    data : pyspark.sql.DataFrame
        Observations.
    groups : sequence
        Values of ``response`` to compute means for, in output row order.
    response : str
        Name of the grouping column.
    features : list of str
        Feature columns to average.

    Returns
    -------
    numpy.ndarray
        Shape ``(len(groups), len(features))``; row ``i`` is
        ``column_means`` of the feature rows where ``response == groups[i]``.
    """
    # numpy.zeros instead of the deprecated scipy.zeros alias (same result).
    means = numpy.zeros((len(groups), len(features)))
    for i, target in enumerate(groups):
        # Column comparison instead of an unquoted SQL string: a string target
        # is compared as a literal instead of being parsed as a column name.
        df_t = data.filter(data[response] == target)
        X_t = df_t.select(features).rdd.map(numpy.array)
        means[i, :] = column_means(X_t)
    return means
示例#6
0
def test_read_write_parquet(
    test_parquet_in_asset: PySparkDataAsset,
    iris_spark: pyspark.sql.DataFrame,
    fake_airflow_context: Any,
    spark_session: pyspark.sql.SparkSession,
) -> None:
    """Round-trip ``iris_spark`` through parquet using ``PySparkDataAssetIO``.

    Writes the frame into the asset's staging "picked up" path. Then checks
    that reading without ``spark_session`` raises ``ValueError``, and that
    reads with a session (with and without extra reader kwargs) keep the
    row and column counts.
    """
    # os.path.join with a single argument is a no-op; abspath alone suffices.
    p = path.abspath(
        test_parquet_in_asset.staging_pickedup_path(fake_airflow_context))
    os.makedirs(path.dirname(p), exist_ok=True)
    iris_spark.write.mode("overwrite").parquet(p)

    count_before = iris_spark.count()
    columns_before = len(iris_spark.columns)

    # Without a spark_session the read is expected to fail.
    with pytest.raises(expected_exception=ValueError):
        PySparkDataAssetIO.read_data_asset(test_parquet_in_asset,
                                           source_files=[p])

    x = PySparkDataAssetIO.read_data_asset(test_parquet_in_asset,
                                           source_files=[p],
                                           spark_session=spark_session)

    assert count_before == x.count()
    assert columns_before == len(x.columns)

    # try with additional kwargs:
    x = PySparkDataAssetIO.read_data_asset(
        asset=test_parquet_in_asset,
        source_files=[p],
        spark_session=spark_session,
        mergeSchema=True,
    )

    assert count_before == x.count()
    # Check the schema width for the kwargs read too, like the first read.
    assert columns_before == len(x.columns)
示例#7
0
def _write_dataframe_to_s3(config, logger, df: pyspark.sql.DataFrame, df_name: str) -> None:
    """
    Convert a PySpark DataFrame to pandas and upload it to S3 as ``<df_name>.csv``.

    The CSV is rendered into an in-memory buffer (no local file is written)
    and uploaded with ``put`` to the bucket named by
    ``config['S3']['BUCKET_NAME']``, using the credentials in ``config['AWS']``.
    ``toPandas()`` collects the whole DataFrame to the driver, so it must fit
    in driver memory.

    :param config: mapping providing ``['AWS']['AWS_ACCESS_KEY_ID']``,
        ``['AWS']['AWS_SECRET_ACCESS_KEY']`` and ``['S3']['BUCKET_NAME']``.
    :param logger: logger receiving start/finish messages at WARNING level.
    :param df: DataFrame to upload.
    :param df_name: key stem; the object key is ``f'{df_name}.csv'``.
    """
    # Logger.warn is a deprecated alias of Logger.warning; the level is unchanged.
    logger.warning(f'About to write dataframe: {df_name} as CSV to S3')

    # Convert Pyspark dataframe to Pandas
    pd_df = df.toPandas()

    # Get S3 details
    s3 = boto3.resource('s3',
                        aws_access_key_id=config['AWS']['AWS_ACCESS_KEY_ID'],
                        aws_secret_access_key=config['AWS']['AWS_SECRET_ACCESS_KEY'])

    # Render the CSV in memory; the buffer is closed once its text is taken.
    with StringIO() as csv_buff:
        pd_df.to_csv(csv_buff, sep=',', index=False)
        body = csv_buff.getvalue()

    # Write to S3
    s3.Object(config['S3']['BUCKET_NAME'], f'{df_name}.csv').put(Body=body)

    logger.warning(f'Finished writing dataframe: {df_name} as CSV to S3')
示例#8
0
def assert_test_dfs_equal(expected_df: pyspark.sql.DataFrame,
                          generated_df: pyspark.sql.DataFrame) -> None:
    """
    Used to compare two dataframes (typically, in a unit test).
    Better than the direct df1.equals(df2) method, as this function
    allows for tolerances in the floating point columns, and is
    also more descriptive with which parts of the two dataframes
    are in disagreement.

    Both frames are collected to pandas and compared row by row in their
    collected order, so callers must order them consistently. Object-dtype
    columns must match exactly (``DataFrame.equals``). Every other column
    must have nulls in the same rows, and its non-null values must satisfy
    ``np.allclose``.

    :param expected_df: First dataframe to compare
    :param generated_df: Second dataframe to compare
    :raises ValueError: if either dataframe has more than 10000 rows
    :raises AssertionError: if row counts, column names/order, or any
        column's values differ
    """

    row_limit = 10000

    e_count = expected_df.count()
    g_count = generated_df.count()

    if (e_count > row_limit) or (g_count > row_limit):
        # ValueError subclasses Exception, so existing handlers still match.
        raise ValueError(
            f"One or both of the dataframes passed has too many rows (>{row_limit}). "
            f"Please limit your test sizes to be lower than this number.")

    assert e_count == g_count, "The dataframes have a different number of rows."

    expected_pdf = expected_df.toPandas()
    generated_pdf = generated_df.toPandas()

    assert list(expected_pdf.columns) == list(generated_pdf.columns), \
        "The two dataframes have different columns."

    for col in expected_pdf.columns:
        error_msg = f"The columns with name: `{col}` were not equal."
        if expected_pdf[col].dtype.type == np.object_:
            assert expected_pdf[[col]].equals(generated_pdf[[col]]), error_msg
            continue

        # Numpy will not equate nulls on both sides. Compare the null
        # positions explicitly and pass only non-null values to allclose.
        # The frames themselves are NOT filtered: filtering each side by its
        # own mask would misalign rows for this and every later column.
        expected_nulls = expected_pdf[col].isnull().values
        generated_nulls = generated_pdf[col].isnull().values
        assert np.array_equal(expected_nulls, generated_nulls), \
            f"{error_msg} Null values are in different rows."
        try:
            is_close = np.allclose(expected_pdf[col].values[~expected_nulls],
                                   generated_pdf[col].values[~generated_nulls])
        except ValueError:
            logging.error(
                f"Problem encountered while equating column '{col}'.")
            raise
        assert is_close, error_msg
def build_vocabulary(df: pyspark.sql.DataFrame) -> Dict[str, List[Any]]:
    """Collect the sorted distinct values of every column in ``CATEGORICAL_COLS``.

    Sorting maps falsy values (including ``None``) to the default-constructed
    value of the column's Python type (e.g. ``0`` or ``""``), so ``None`` sorts
    where that default would. If a column has no non-null value (all nulls,
    or an empty DataFrame), its distinct values are returned unsorted
    (``[None]`` or ``[]``) instead of raising ``IndexError``.

    Returns:
        Mapping from column name to its ordered list of distinct values.
    """
    vocab = {}
    for col in CATEGORICAL_COLS:
        values = [r[0] for r in df.select(col).distinct().collect()]
        non_null = [x for x in values if x is not None]
        if not non_null:
            # No non-null value to infer the element type (and sort key) from.
            vocab[col] = values
            continue
        default_value = type(non_null[0])()
        vocab[col] = sorted(values, key=lambda x: x or default_value)
    return vocab
def _sample_dfs(t_df: pyspark.sql.DataFrame, t_fracs: pd.DataFrame,
                c_can_df: pyspark.sql.DataFrame, c_fracs: pd.DataFrame,
                match_col: str) -> Tuple[DataFrame, DataFrame]:
    r"""Stratified-sample the treatment and control populations into balanced sets.

    Parameters
    ----------
    t_df : pyspark.sql.DataFrame
        treatment population
    t_fracs : pd.DataFrame
        with columns `match_col` and 'treatment_scaled_sample_fraction'
    c_can_df : pyspark.sql.DataFrame
        control candidate population
    c_fracs : pd.DataFrame
        with columns `match_col` and 'control_scaled_sample_fraction'
    match_col : str
        stratum column; its values are converted with ``int`` to key the
        per-stratum fractions

    Returns
    -------
    t_out : pyspark.sql.DataFrame
        ``t_df.sampleBy`` with treatment fractions capped at 1, seed 42
    c_out : pyspark.sql.DataFrame
        ``c_can_df.sampleBy`` with control fractions used as given, seed 42

    Raises
    ------
    UncaughtExceptions

    """
    _persist_if_unpersisted(t_df)
    _persist_if_unpersisted(c_can_df)

    def to_fraction_map(fracs, fraction_col, cap_at_one):
        # pandas frame -> {stratum as int: float sampling fraction}
        per_stratum = fracs.set_index(match_col)[fraction_col].to_dict()
        if cap_at_one:
            return {int(k): min(float(v), 1) for k, v in per_stratum.items()}
        return {int(k): float(v) for k, v in per_stratum.items()}

    t_out = t_df.sampleBy(
        col=match_col,
        fractions=to_fraction_map(t_fracs, 'treatment_scaled_sample_fraction',
                                  True),
        seed=42)

    # NOTE(review): only treatment fractions are capped at 1. Spark's sampleBy
    # appears to require fractions in [0, 1], so control fractions are
    # presumably guaranteed <= 1 upstream — confirm.
    c_out = c_can_df.sampleBy(
        col=match_col,
        fractions=to_fraction_map(c_fracs, 'control_scaled_sample_fraction',
                                  False),
        seed=42)

    return t_out, c_out
示例#11
0
def within_group_scatter(data: pyspark.sql.DataFrame, features, response,
                         targets):
    """Return the pooled within-group scatter matrix of ``features``.

    For each value in ``targets``, the rows where ``response`` equals that
    value (compared as a string literal) add ``cov * (n - 1)``. Here ``cov``
    is ``RowMatrix.computeCovariance()`` over those rows' ``features`` and
    ``n`` is their count. Returns a ``(p, p)`` numpy array with
    ``p = len(features)``.
    """
    p = len(features)
    sw = numpy.zeros((p, p))
    for target in targets:
        # Same string-literal comparison as a quoted SQL filter, but through
        # the Column API, so a target containing a quote cannot break or
        # alter the filter expression.
        df_t = data.filter(data[response] == str(target))
        X_t = RowMatrix(df_t.select(features).rdd.map(numpy.array))
        sw += X_t.computeCovariance().toArray() * (df_t.count() - 1)
    return sw
示例#12
0
def perc_weather_cancellations_per_week(spark: sk.sql.SparkSession,
                                        data: sk.sql.DataFrame) -> sk.RDD:
    """Per week, the percentage of rows whose ``CancellationCode`` is ``'B'``.

    Every row of ``data`` counts towards the weekly denominator; no filter
    on ``Cancelled`` is applied. Returns an RDD of ``(week, percentage)``
    pairs sorted by week, where ``week = week_from_row(row)``. ``spark`` is
    unused and kept for interface compatibility.

    NOTE(review): if the intent is "share of *cancellations* due to
    weather", ``data`` should first be filtered to ``Cancelled == 1`` —
    confirm the intended denominator.
    """
    def to_week_counts(row):
        week = week_from_row(row)
        is_weather = str(row['CancellationCode']).strip() == 'B'
        # (rows seen, weather-coded rows) for this week
        return week, (1, 1 if is_weather else 0)

    per_week = data.rdd.map(to_week_counts).reduceByKey(
        lambda l, r: (l[0] + r[0], l[1] + r[1]))
    return per_week.mapValues(lambda v: v[1] / v[0] * 100.0).sortByKey()
示例#13
0
def normalise_fields_names(df: pyspark.sql.DataFrame,
                           fieldname_normaliser=__normalise_fieldname__):
    """Rename all top-level columns (and, via the cast, nested fields).

    Each top-level column is selected by its backtick-quoted name and cast
    to ``__rename_nested_field__(field.dataType, fieldname_normaliser)``
    (presumably the same type with nested field names normalised — see
    that helper). It is then aliased to ``fieldname_normaliser(field.name)``.
    """
    def quoted(name):
        # Inside a backtick-quoted identifier a literal backtick must be
        # doubled, otherwise names containing '`' cannot be resolved.
        return "`{}`".format(name.replace("`", "``"))

    return df.select([
        f.col(quoted(field.name)).cast(
            __rename_nested_field__(field.dataType,
                                    fieldname_normaliser)).alias(
                                        fieldname_normaliser(field.name))
        for field in df.schema.fields
    ])
def lookup_columns(df: pyspark.sql.DataFrame,
                   vocab: Dict[str, List[Any]]) -> pyspark.sql.DataFrame:
    """Replace each vocabulary column's values with their vocabulary index.

    For every ``col -> mapping`` in ``vocab``, the column is overwritten by an
    integer UDF returning the position of the value in ``mapping`` (first
    occurrence, like ``list.index``). A value missing from ``mapping``
    raises ``ValueError`` inside the UDF, which fails the Spark job.
    """
    def lookup(mapping):
        # Build value -> first index once per column, so each row costs an
        # O(1) dict lookup instead of an O(len(mapping)) list.index scan.
        positions = {}
        for i, value in enumerate(mapping):
            positions.setdefault(value, i)

        def fn(v):
            try:
                return positions[v]
            except KeyError:
                raise ValueError(f"{v!r} is not in list") from None

        return F.udf(fn, returnType=T.IntegerType())

    for col, mapping in vocab.items():
        df = df.withColumn(col, lookup(mapping)(df[col]))
    return df
示例#15
0
    def save_to_ray(self, df: pyspark.sql.DataFrame) -> SharedDataset:
        """Put every pandas batch of ``df`` into Ray and return handles to them.

        ``save`` runs over ``df`` via ``mapInPandas``. For each batch it stores
        the pickled batch with ``ray.put``, registers the object with this
        node's block holder (``register_object_id.remote``), and emits one
        ``(node_label, fetch_index, size)`` row. The driver collects those rows
        and builds a ``StandaloneClusterSharedDataset`` from the
        ``(node_label, fetch_index)`` pairs, the batch sizes and the block
        holders they refer to.

        NOTE(review): released Spark 3.x ``mapInPandas`` takes a plain function
        plus a schema. Passing a ``@pandas_udf``-decorated iterator function
        looks like an earlier preview API — confirm against the Spark version
        in use.
        """
        # Schema of the rows emitted by ``save``: one row per stored batch.
        return_type = StructType()
        return_type.add(StructField("node_label", StringType(), True))
        return_type.add(StructField("fetch_index", IntegerType(), False))
        return_type.add(StructField("size", LongType(), False))

        @pandas_udf(return_type)
        def save(batch_iter: Iterator[pd.DataFrame]) -> Iterator[pd.DataFrame]:
            # Executed in Spark's Python workers: attach this process to the
            # Ray cluster if it is not connected yet, using the broadcast
            # Redis address/password.
            if not ray.is_initialized():
                redis_config = {}
                broadcasted = _global_broadcasted["redis"].value
                redis_config["address"] = broadcasted["address"]
                redis_config["password"] = broadcasted["password"]

                local_address = get_node_address()

                ray.init(address=redis_config["address"],
                         node_ip_address=local_address,
                         redis_password=redis_config["password"])

            # Block holders are looked up by "node:<ip>" of the current node.
            node_label = f"node:{ray.services.get_node_ip_address()}"
            block_holder = _global_block_holder[node_label]
            object_ids = []
            sizes = []
            for pdf in batch_iter:
                obj = ray.put(pickle.dumps(pdf))
                # TODO: register object in batch
                object_ids.append(block_holder.register_object_id.remote([obj
                                                                          ]))
                sizes.append(len(pdf))

            # Wait for all registrations; each result is the fetch index
            # assigned by the block holder.
            indexes = ray.get(object_ids)
            result_dfs = []
            for index, block_size in zip(indexes, sizes):
                result_dfs.append(
                    pd.DataFrame({
                        "node_label": [node_label],
                        "fetch_index": [index],
                        "size": [block_size]
                    }))
            return iter(result_dfs)

        results = df.mapInPandas(save).collect()
        fetch_indexes = []
        block_holder_mapping = {}
        block_sizes = []
        for row in results:
            fetch_indexes.append((row["node_label"], row["fetch_index"]))
            block_sizes.append(row["size"])
            # Keep one block-holder handle per distinct node label.
            if row["node_label"] not in block_holder_mapping:
                block_holder_mapping[row["node_label"]] = _global_block_holder[
                    row["node_label"]]
        return StandaloneClusterSharedDataset(fetch_indexes, block_sizes,
                                              block_holder_mapping)
示例#16
0
def add_increment(
    current_df: pyspark.sql.DataFrame,
    increment_df: pyspark.sql.DataFrame,
) -> pyspark.sql.DataFrame:
    """Union ``increment_df`` onto ``current_df`` and keep one row per ``link``.

    The kept row is the first by ascending ``scraped_at`` within each
    ``link`` (the earliest scrape). ``row_number`` does not break ties
    deterministically. ``union`` matches columns by position, so both
    inputs must have the same column order.
    """
    combined = current_df.union(increment_df)
    earliest_first = Window.partitionBy(combined['link']).orderBy(['scraped_at'])
    ranked = combined.withColumn('_row_number',
                                 F.row_number().over(earliest_first))
    return ranked.where(F.col('_row_number') == 1).drop('_row_number')
示例#17
0
def displayHead(df: pyspark.sql.DataFrame, nrows: int = 5):
    """
    Return up to ``nrows`` rows of a Spark DataFrame as a pandas DataFrame.

    Rows are taken with ``limit``; without an explicit ordering Spark does
    not guarantee which rows those are.

    Args:
        df: The Spark DataFrame.
        nrows: Maximum number of rows to fetch (default 5).

    Returns: pandas DataFrame with at most ``nrows`` rows.
    """
    head = df.limit(nrows)
    return head.toPandas()
示例#18
0
    def process_inquiries(self, review: pyspark.sql.DataFrame,
                          metadata: pyspark.sql.DataFrame) -> None:
        """Compute monthly average ratings for top-ranked categories and upsert them.

        1. Keep ``asin``, ``overall`` and ``unixReviewTime`` from ``review``,
           convert ``unixReviewTime`` with ``from_unixtime`` and add
           ``month``/``year`` columns.
        2. Explode ``metadata.categories`` twice (it is a list of lists) into
           one ``col`` value per (asin, category), then inner-join on ``asin``.
        3. Count reviews per (year, month, col), number them with
           ``row_number()`` over ``self.get_partitions()`` and keep ranks
           <= 5. These are presumably the top 5 categories per month,
           depending on how that window is ordered — confirm.
        4. Average ``overall`` per (year, month, col) for the kept categories
           into a ``rating`` column, show it, and upsert it into
           ``'mydb'``/``'myset'``.
        """
        logging.info("Start pipeline")

        logging.info("Processing")
        review_transform_date = review.select(
            'asin', 'overall',
            'unixReviewTime').withColumn("unixReviewTime",
                                         from_unixtime("unixReviewTime"))
        review_date_decompose = review_transform_date.withColumn(
            "month",
            month("unixReviewTime")).withColumn("year", year("unixReviewTime"))
        metadata_flatten_categories = metadata.select(
            'asin', explode('categories')).select('asin', explode('col'))
        join_review_metadata = review_date_decompose.join(
            metadata_flatten_categories, on=['asin'], how='inner')
        groupby_review_metadata = join_review_metadata.groupBy(
            "year", "month", "col").count().orderBy('year',
                                                    'month',
                                                    'count',
                                                    ascending=False).cache()
        patrions = groupby_review_metadata.withColumn(
            "rank",
            row_number().over(self.get_partitions())).cache()
        # The rank column lives on the local ``patrions`` frame built above.
        filter_patrions = patrions.filter(patrions['rank'] <= 5).cache()
        try:
            result_inner = join_review_metadata.join(filter_patrions,
                                                     on=['year', 'month', 'col'],
                                                     how='inner')
            # DataFrame.alias() only names the DataFrame, not a column, so the
            # generated aggregate column is renamed to ``rating`` explicitly.
            result_groupby = result_inner.groupBy(
                'year', 'month', 'col').avg('overall').withColumnRenamed(
                    'avg(overall)', 'rating').orderBy('year',
                                                      'month',
                                                      ascending=True)
            result_groupby.show()
            logging.info("Finished")
            self.upsert_database(result_groupby, 'mydb', 'myset')
        finally:
            # cache() is lazy: the frames are only materialised by the actions
            # above, so they are released afterwards rather than before.
            filter_patrions.unpersist()
            patrions.unpersist()
            groupby_review_metadata.unpersist()
示例#19
0
def show_df(df: pyspark.sql.DataFrame,
            columns: list,
            rows: int = 10,
            sample=False,
            truncate=True):
    """
    Print up to ``rows`` rows of the selected ``columns`` of a pyspark df.

    :param df:  pyspark dataframe
    :param columns: list of columns to print
    :param rows: how many rows to print - default 10
    :param sample: if True, print a random sample without replacement
        (fraction ``min(rows / row_count, 1.0)``, seed 1) instead of the
        leading rows; the sample size is approximate - default False
    :param truncate: truncate output - default True
    :return: None
    """
    if not sample:
        df.select(columns).show(rows, truncate=truncate)
        return

    total = df.count()
    if total == 0:
        # Nothing to sample; avoid dividing by zero and show the empty frame.
        df.select(columns).show(rows, truncate=truncate)
        return

    sample_percent = min(rows / total, 1.0)
    log.info(f'sampling percentage: {sample_percent}')
    df.select(columns).sample(False, sample_percent,
                              seed=1).show(rows, truncate=truncate)
示例#20
0
def assert_pyspark_df_equal(left_df: pyspark.sql.DataFrame,
                            right_df: pyspark.sql.DataFrame,
                            check_data_type: bool,
                            check_cols_in_order: bool = True,
                            order_by: Optional[str] = None) -> None:
    """
    Assert two DataFrames are equal: optional schema check, row count, content.

    left_df: destination layer, ex: raw
    right_df: origin layer, ex: oracle
    check_data_type: when True, ``check_schema`` runs first.
    check_cols_in_order: forwarded to ``check_schema``.
    order_by: optional column used to sort both frames before the content
        check. NOTE(review): ``check_df_content`` in this file compares via
        ``subtract``, which ignores row order.
    """
    if check_data_type:
        check_schema(check_cols_in_order, left_df, right_df)

    check_total_rows(left_df, right_df)

    left_sorted, right_sorted = (
        (left_df.orderBy(order_by), right_df.orderBy(order_by))
        if order_by else (left_df, right_df))

    check_df_content(left_sorted, right_sorted)
示例#21
0
    def reload_df(self,
                  df: pyspark.sql.DataFrame,
                  name: str,
                  num_partitions: int = None,
                  partition_cols: List[str] = None,
                  pre_final: bool = False) -> pyspark.sql.DataFrame:
        """Saves a DataFrame as parquet and reloads it.

        The reloaded frame is persisted with ``StorageLevel.MEMORY_AND_DISK``
        before it is returned.

        Args:
            df (pyspark.sql.DataFrame): frame to write.
            name (str): dataset name passed to ``save_to_parquet`` and
                ``load_from_parquet``.
            num_partitions (int): forwarded to ``save_to_parquet``.
            partition_cols: forwarded to ``save_to_parquet``.
            pre_final (bool): forwarded to both the save and the load.

        Returns:
            pyspark.sql.DataFrame: the reloaded, persisted frame.
        """
        self.save_to_parquet(df=df, name=name, num_partitions=num_partitions,
                             partition_cols=partition_cols,
                             pre_final=pre_final)
        reloaded = self.load_from_parquet(name=name, pre_final=pre_final)
        reloaded.persist(StorageLevel.MEMORY_AND_DISK)
        return reloaded
def prepare_google_trend(
    google_trend_csv: pyspark.sql.DataFrame, ) -> pyspark.sql.DataFrame:
    """Add ``Date`` and ``State`` columns to the Google Trends frame.

    ``Date`` is the text before the first `` -`` in ``week`` (presumably the
    start of a week range). ``State`` is the suffix after ``Rossmann_DE_`` in
    ``file``, with ``"NI"`` remapped to ``"HB,NI"`` to align with other data
    sources. The result is passed through ``expand_date``.
    """
    week_start = F.regexp_extract(google_trend_csv.week, "(.*?) -", 1)
    state_code = F.regexp_extract(google_trend_csv.file, "Rossmann_DE_(.*)", 1)
    trends = google_trend_csv.withColumn("Date", week_start).withColumn(
        "State", state_code)

    # map state NI -> HB,NI to align with other data sources
    state = trends.State
    trends = trends.withColumn(
        "State", F.when(state == "NI", "HB,NI").otherwise(state))

    # expand dates
    return expand_date(trends)
示例#23
0
def flatten(df: pyspark.sql.DataFrame,
            fieldname_normaliser=__normalise_fieldname__):
    """Select every field path from ``__get_fields_info__(df.schema)`` as a column.

    Each entry from ``__get_fields_info__`` is treated as a sequence of path
    segments. Paths of at most two segments are selected directly as
    ``"a.b"``. Longer paths become nested ``transform(...)`` SQL expressions,
    one per intermediate segment except segments equal to ``'``'``
    (presumably an array-element marker — confirm with
    ``__get_fields_info__``). Each column is aliased to
    ``fieldname_normaliser`` of the ``_``-joined path with backticks removed.
    """
    def to_expr(path):
        if len(path) <= 2:
            return ".".join(path)
        inner = "x.{}".format(path[-1])
        # Wrap from the innermost intermediate segment outwards.
        for seg in reversed(path[1:-1]):
            if seg == '``':
                continue
            inner = "transform(x.{outer}, x -> {inner})".format(outer=seg,
                                                                inner=inner)
        return "transform({outer}, x -> {inner})".format(outer=path[0],
                                                         inner=inner)

    return df.select([
        f.expr(to_expr(path)).alias(
            fieldname_normaliser("_".join(path).replace('`', '')))
        for path in __get_fields_info__(df.schema)
    ])
示例#24
0
    def save_to_parquet(self,
                        df: pyspark.sql.DataFrame,
                        name: str,
                        mode: str = "overwrite",
                        num_partitions: int = None,
                        partition_cols: List[str] = None,
                        pre_final: bool = False):
        """Saves a DataFrame into a parquet file.

        The output path is
        ``<self.df_data_folder>/<name>/<self.loop_counter>/<name>[.pre_final].parquet``.
        Before writing, the DataFrame is repartitioned by ``num_partitions``
        and/or ``partition_cols`` when given, otherwise into one partition.
        ``partition_cols`` only drives ``repartition``; the parquet output is
        not directory-partitioned by them.

        Args:
            df (pyspark.sql.DataFrame): frame to write.
            name (str): dataset name used for the folder and the file name.
            mode (str): Spark save mode, default ``"overwrite"``.
            num_partitions (int): number of partitions to repartition into.
            partition_cols (list): columns to repartition by.
            pre_final (bool): write ``<name>.pre_final.parquet`` instead of
                ``<name>.parquet``.
        """
        suffix = ".pre_final" if pre_final else ""
        # Lazy %-args: the message is only formatted when DEBUG is enabled.
        logger.debug("Saving %s%s to parquet..", name, suffix)
        path = os.path.join(self.df_data_folder, name, str(self.loop_counter))
        # exist_ok avoids the race between an exists() check and makedirs().
        os.makedirs(path, exist_ok=True)
        parquet_name = os.path.join(path, name + suffix + ".parquet")

        if num_partitions and partition_cols:
            df = df.repartition(num_partitions, *partition_cols)
        elif num_partitions:
            df = df.repartition(num_partitions)
        elif partition_cols:
            df = df.repartition(*partition_cols)
        else:
            df = df.repartition(1)
        df.write.mode(mode).parquet(parquet_name)
示例#25
0
    def save_to_ray(self, df: pyspark.sql.DataFrame,
                    num_shards: int) -> PandasDataset:
        """Write ``df`` into the Ray object store and expose it as ``num_shards`` shards.

        ``df`` is repartitioned into ``num_shards`` partitions and written by the
        JVM-side ``ObjectStoreWriter``. Each returned record carries an object id,
        its owner address and a record count. The object refs are registered with
        the local Ray worker, then grouped into ``num_shards`` ``RecordBatch``
        lists by ``divide_blocks`` (using the record counts). They are returned as
        a ``parallel_dataset`` whose items are resolved, flattened and converted
        with ``to_pandas(None)``.

        NOTE(review): relies on Ray internals (``ray.worker.global_worker``,
        ``core_worker.deserialize_and_register_object_ref``) whose names have
        changed across Ray versions — confirm against the pinned Ray release.
        """
        # call java function from python
        df = df.repartition(num_shards)
        sql_context = df.sql_ctx
        jvm = sql_context.sparkSession.sparkContext._jvm
        jdf = df._jdf
        object_store_writer = jvm.org.apache.spark.sql.raydp.ObjectStoreWriter(
            jdf)
        records = object_store_writer.save()

        worker = ray.worker.global_worker

        # NOTE(review): annotated as ray.ObjectRef but filled with ray.ObjectID
        # values — confirm these are the same type in the Ray version used.
        blocks: List[ray.ObjectRef] = []
        block_sizes: List[int] = []
        for record in records:
            owner_address = record.ownerAddress()
            object_id = ray.ObjectID(record.objectId())
            num_records = record.numRecords()
            # Register the ownership of the ObjectRef
            worker.core_worker.deserialize_and_register_object_ref(
                object_id.binary(), ray.ObjectRef.nil(), owner_address)

            blocks.append(object_id)
            block_sizes.append(num_records)

        # divide_blocks returns, per shard, the indexes into ``blocks`` that
        # shard should own.
        divided_blocks = divide_blocks(block_sizes, num_shards)
        record_batch_set: List[RecordBatch] = []
        for i in range(num_shards):
            indexes = divided_blocks[i]
            object_ids = [blocks[index] for index in indexes]
            record_batch_set.append(RecordBatch(object_ids))

        # TODO: we should specify the resource spec for each shard
        ds = parallel_dataset.from_iterators(generators=record_batch_set,
                                             name="spark_df")

        # Resolve each RecordBatch lazily as the dataset is iterated.
        def resolve_fn(it: "Iterable[RecordBatch]") -> "Iterator[RecordBatch]":
            for item in it:
                item.resolve()
                yield item

        return ds.transform(resolve_fn,
                            ".RecordBatch#resolve()").flatten().to_pandas(None)
示例#26
0
def test_read_write_csv(
    test_csv_asset: PySparkDataAsset,
    iris_spark: pyspark.sql.DataFrame,
    spark_session: pyspark.sql.SparkSession,
) -> None:
    """Write iris to a CSV data asset, promote it to 'ready' and retrieve it.

    Covers a default write, a write with extra writer kwargs, ``mode="error"``
    raising ``AnalysisException`` once data exists, the retrieved row count,
    and ``spark_session`` validation (missing -> ``ValueError``, wrong type ->
    ``TypeError``).
    """
    write = PySparkDataAssetIO.write_data_asset
    retrieve = PySparkDataAssetIO.retrieve_data_asset

    # A plain write, then a write with an additional writer kwarg.
    write(asset=test_csv_asset, data=iris_spark)
    write(asset=test_csv_asset, data=iris_spark, header=True)

    # The default mode overwrites; mode="error" must fail on existing data.
    with pytest.raises(AnalysisException):
        write(asset=test_csv_asset,
              data=iris_spark,
              header=True,
              mode="error")

    # Promote staging -> ready before retrieving. makedirs creates any missing
    # parents; removing the leaf again lets shutil.move place the staging
    # directory exactly at ready_path instead of nesting it inside it.
    os.makedirs(test_csv_asset.ready_path, exist_ok=True)
    shutil.rmtree(test_csv_asset.ready_path)
    shutil.move(test_csv_asset.staging_ready_path, test_csv_asset.ready_path)

    retrieved = retrieve(test_csv_asset,
                         spark_session=spark_session,
                         inferSchema=True,
                         header=True)
    assert retrieved.count() == iris_spark.count()

    # Missing 'spark_session' kwarg.
    with pytest.raises(ValueError):
        retrieve(test_csv_asset)

    # Invalid 'spark_session' kwarg.
    with pytest.raises(TypeError):
        retrieve(test_csv_asset, spark_session=42)
def _calc_var(df: pyspark.sql.DataFrame, label_col: str) -> pd.DataFrame:
    r"""calculate the variance of each predictor column per `label_col` group

    Parameters
    ----------
    df : pyspark.sql.DataFrame
        df where rows are observations, all columns except `label_col` are
        predictors.
    label_col : str
        column to group by; it is excluded from the predictors and from the
        returned index.

    Returns
    -------
    s_var_df : pd.DataFrame
        pandas dataframe indexed by predictor name with one column
        ``var_<i>`` per group. ``i`` is the group's row position in the
        collected aggregate, not the label value.

    Raises
    ------
    UncaughtExceptions

    Notes
    -----
    Spark names the aggregates ``variance(<col>)``. The wrapper is stripped
    with literal (``regex=False``) replacements, so the result does not
    depend on the installed pandas version's ``str.replace`` regex default.

    """
    pred_cols = [x for x in df.columns if x != label_col]
    s_var_df = df.groupby(label_col).agg({x: 'variance'
                                          for x in pred_cols
                                          }).toPandas().transpose()
    s_var_df = s_var_df.reset_index()
    s_var_df['index'] = (s_var_df['index']
                         .str.replace(')', '', regex=False)
                         .str.replace('variance(', '', regex=False))
    s_var_df = s_var_df.set_index('index')
    s_var_df.columns = ["var_{0}".format(x) for x in s_var_df.columns]
    # The grouping column is carried through the aggregate and transpose as a
    # row; drop it by its actual name rather than a hard-coded 'label'.
    s_var_df = s_var_df.loc[s_var_df.index != label_col, :]
    return s_var_df
def prepare_df(
    df: pyspark.sql.DataFrame,
    store_csv: pyspark.sql.DataFrame,
    store_states_csv: pyspark.sql.DataFrame,
    state_names_csv: pyspark.sql.DataFrame,
    google_trend_csv: pyspark.sql.DataFrame,
    weather_csv: pyspark.sql.DataFrame,
) -> pyspark.sql.DataFrame:
    """Join the sales frame with store, Google Trends and weather data and add features.

    Steps: expand dates. Turn ``Open``/``Promo``/``StateHoliday``/
    ``SchoolHoliday`` into booleans (value ``!= "0"``). Inner-join store and
    store-state info on ``Store``; state-level Google Trends on
    (State, Year, Week); the ``Rossmann_DE`` trend as ``trend_de`` on
    (Year, Week); and weather on (State, Date). Fill missing competition and
    promo dates with 1900 / 1 sentinels. Derive ``CompetitionDaysOpen``
    (clamped to 0..720), ``CompetitionMonthsOpen``, ``Promo2Days``
    (clamped to 0..175) and ``Promo2Weeks``.

    Raises:
        AssertionError: if the inner joins changed the number of rows.
    """
    num_rows = df.count()

    # expand dates
    df = expand_date(df)

    # Convert the string flags into booleans: any value other than "0" is True.
    df = (df.withColumn("Open", df.Open != "0").withColumn(
        "Promo",
        df.Promo != "0").withColumn("StateHoliday",
                                    df.StateHoliday != "0").withColumn(
                                        "SchoolHoliday",
                                        df.SchoolHoliday != "0"))

    # merge store information
    store = store_csv.join(store_states_csv, "Store")
    df = df.join(store, "Store")

    # merge Google Trend information
    google_trend_all = prepare_google_trend(google_trend_csv)
    df = df.join(google_trend_all,
                 ["State", "Year", "Week"]).select(df["*"],
                                                   google_trend_all.trend)

    # merge in Google Trend for whole Germany
    google_trend_de = google_trend_all[google_trend_all.file ==
                                       "Rossmann_DE"].withColumnRenamed(
                                           "trend", "trend_de")
    df = df.join(google_trend_de,
                 ["Year", "Week"]).select(df["*"], google_trend_de.trend_de)

    # merge weather
    weather = weather_csv.join(state_names_csv,
                               weather_csv.file == state_names_csv.StateName)
    df = df.join(weather, ["State", "Date"])

    # fix null values with sentinel dates (year 1900 / month or week 1)
    df = (df.withColumn(
        "CompetitionOpenSinceYear",
        F.coalesce(df.CompetitionOpenSinceYear, F.lit(1900)),
    ).withColumn(
        "CompetitionOpenSinceMonth",
        F.coalesce(df.CompetitionOpenSinceMonth, F.lit(1)),
    ).withColumn("Promo2SinceYear",
                 F.coalesce(df.Promo2SinceYear, F.lit(1900))).withColumn(
                     "Promo2SinceWeek", F.coalesce(df.Promo2SinceWeek,
                                                   F.lit(1))))

    # days and months since the competition has been open, capped at
    # 2 * 360 days (~2 years); the 1900 sentinel yields 0
    df = df.withColumn(
        "CompetitionOpenSince",
        F.to_date(
            F.format_string("%s-%s-15", df.CompetitionOpenSinceYear,
                            df.CompetitionOpenSinceMonth)),
    )
    df = df.withColumn(
        "CompetitionDaysOpen",
        F.when(
            df.CompetitionOpenSinceYear > 1900,
            F.greatest(
                F.lit(0),
                F.least(F.lit(360 * 2),
                        F.datediff(df.Date, df.CompetitionOpenSince)),
            ),
        ).otherwise(0),
    )
    df = df.withColumn("CompetitionMonthsOpen",
                       (df.CompetitionDaysOpen / 30).cast(T.IntegerType()))

    # days and weeks of promotion, capped at 25 weeks; the 1900 sentinel yields 0
    df = df.withColumn(
        "Promo2Since",
        F.expr(
            'date_add(format_string("%s-01-01", Promo2SinceYear), (cast(Promo2SinceWeek as int) - 1) * 7)'
        ),
    )
    df = df.withColumn(
        "Promo2Days",
        F.when(
            df.Promo2SinceYear > 1900,
            F.greatest(
                F.lit(0),
                F.least(F.lit(25 * 7), F.datediff(df.Date, df.Promo2Since))),
        ).otherwise(0),
    )
    df = df.withColumn("Promo2Weeks",
                       (df.Promo2Days / 7).cast(T.IntegerType()))

    # Ensure that no row was lost (or duplicated) through the inner joins.
    # An explicit raise keeps this check active under ``python -O``.
    final_rows = df.count()
    if final_rows != num_rows:
        raise AssertionError(
            f"lost rows in joins: {num_rows} rows before, {final_rows} after")
    return df
示例#29
0
def cast(df: pyspark.sql.DataFrame, col_name: str,
         dtype: str) -> pyspark.sql.DataFrame:
    """Return ``df`` with column ``col_name`` cast to the Spark type ``dtype``."""
    converted = F.col(col_name).cast(dtype)
    return df.withColumn(col_name, converted)
def cast_columns(df: pyspark.sql.DataFrame,
                 cols: List[str]) -> pyspark.sql.DataFrame:
    """Cast each column in ``cols`` to float, replacing nulls with 0.0.

    Values the cast turns into null are replaced with 0.0 as well.
    """
    result = df
    for name in cols:
        as_float = result[name].cast(T.FloatType())
        result = result.withColumn(name, F.coalesce(as_float, F.lit(0.0)))
    return result