def add_elapsed(df: pyspark.sql.DataFrame, cols: List[str]) -> pyspark.sql.DataFrame:
    def add_elapsed_column(col, asc):
        def fn(rows):
            last_store, last_date = None, None
            for r in rows:
                if last_store != r.Store:
                    last_store = r.Store
                    last_date = r.Date
                if r[col]:
                    last_date = r.Date
                fields = r.asDict().copy()
                fields[("After" if asc else "Before") + col] = (r.Date - last_date).days
                yield Row(**fields)
        return fn

    # repartition: rearrange the rows in the DataFrame based on the partitioning expression
    # sortWithinPartitions: sort every partition in the DataFrame based on specific columns
    # mapPartitions: apply the 'add_elapsed_column' method to each partition in the dataset, and convert the partitions into a DataFrame
    df = df.repartition(df.Store)
    for asc in [False, True]:
        sort_col = df.Date.asc() if asc else df.Date.desc()
        rdd = df.sortWithinPartitions(df.Store.asc(), sort_col).rdd
        for col in cols:
            rdd = rdd.mapPartitions(add_elapsed_column(col, asc))
        df = rdd.toDF()
    return df
def extract_datepart(df: pyspark.sql.DataFrame, dt_col: str, to_extract: str,
                     drop: bool = False) -> pyspark.sql.DataFrame:
    """
    Base function for extracting dateparts. Used in less abstracted functions.

    Parameters
    ----------
    df : pyspark.sql.DataFrame
        Base dataframe which contains ``dt_col`` column for extracting ``to_extract``.
    dt_col : str
        Name of date column to extract ``to_extract`` from.
    to_extract : str
        Name of the datepart to extract; must match a ``pyspark.sql.functions``
        function name (e.g. ``"year"``, ``"month"``, ``"dayofmonth"``).
    drop : bool
        Whether or not to drop ``dt_col`` after extraction (default is False).

    Returns
    -------
    df : pyspark.sql.DataFrame
        df with ``to_extract`` column, optionally without original ``dt_col`` column.
    """
    df = df.withColumn(to_extract, getattr(F, to_extract)(F.col(dt_col)))
    if drop:
        df = df.drop(dt_col)
    return df
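# Usage sketch for extract_datepart (illustrative only; assumes a local SparkSession
# is acceptable and that `pyspark.sql.functions` is imported as `F`, as the function
# above expects):
import datetime

from pyspark.sql import SparkSession


def _demo_extract_datepart() -> None:
    spark = SparkSession.builder.master("local[1]").getOrCreate()
    df = spark.createDataFrame([(1, datetime.date(2020, 3, 15))], ["id", "Date"])
    # Adds a "year" column computed by F.year("Date") and drops the original date column.
    out = extract_datepart(df, "Date", "year", drop=True)
    out.show()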
def check_total_rows(left_df: pyspark.sql.DataFrame,
                     right_df: pyspark.sql.DataFrame) -> None:
    left_df_count = left_df.count()
    right_df_count = right_df.count()
    assert left_df_count == right_df_count, \
        f"Number of rows is not the same.\n\n" \
        f"Actual Rows (left_df): {left_df_count}\n" \
        f"Expected Rows (right_df): {right_df_count}\n"
def check_df_content(left_df: pyspark.sql.DataFrame,
                     right_df: pyspark.sql.DataFrame) -> None:
    logging.info('Executing: left_df - right_df')
    df_diff_oracle_raw = left_df.subtract(right_df)
    logging.info(df_diff_oracle_raw.count())

    logging.info('Executing: right_df - left_df')
    df_diff_raw_oracle = right_df.subtract(left_df)
    logging.info(df_diff_raw_oracle.count())

    assert df_diff_oracle_raw.count() == df_diff_raw_oracle.count() == 0
def group_mean(data: pyspark.sql.DataFrame, groups, response, features):
    # numpy.zeros is used here; the deprecated scipy.zeros alias was removed
    # from recent SciPy releases.
    means = numpy.zeros((len(groups), len(features)))
    for i, target in enumerate(groups):
        df_t = data.filter("{} == {}".format(response, target))
        X_t = df_t.select(features).rdd.map(numpy.array)
        means[i, :] = column_means(X_t)
    return means
def test_read_write_parquet(
    test_parquet_in_asset: PySparkDataAsset,
    iris_spark: pyspark.sql.DataFrame,
    fake_airflow_context: Any,
    spark_session: pyspark.sql.SparkSession,
) -> None:
    p = path.abspath(
        path.join(test_parquet_in_asset.staging_pickedup_path(fake_airflow_context)))
    os.makedirs(path.dirname(p), exist_ok=True)
    iris_spark.write.mode("overwrite").parquet(p)

    count_before = iris_spark.count()
    columns_before = len(iris_spark.columns)

    with pytest.raises(expected_exception=ValueError):
        PySparkDataAssetIO.read_data_asset(test_parquet_in_asset, source_files=[p])

    x = PySparkDataAssetIO.read_data_asset(test_parquet_in_asset,
                                           source_files=[p],
                                           spark_session=spark_session)
    assert count_before == x.count()
    assert columns_before == len(x.columns)

    # try with additional kwargs:
    x = PySparkDataAssetIO.read_data_asset(
        asset=test_parquet_in_asset,
        source_files=[p],
        spark_session=spark_session,
        mergeSchema=True,
    )
    assert count_before == x.count()
def _write_dataframe_to_s3(config, logger, df: pyspark.sql.DataFrame,
                           df_name: str) -> None:
    """
    Converts a PySpark DataFrame to Pandas, then writes it out as a CSV file
    to the Amazon S3 bucket named in the config object.
    """
    logger.warn(f'About to write dataframe: {df_name} as CSV to S3')

    # Convert PySpark dataframe to Pandas
    pd_df = df.toPandas()

    # Get S3 details
    s3 = boto3.resource(
        's3',
        aws_access_key_id=config['AWS']['AWS_ACCESS_KEY_ID'],
        aws_secret_access_key=config['AWS']['AWS_SECRET_ACCESS_KEY'])

    # Write Pandas df to an in-memory CSV buffer
    csv_buff = StringIO()
    pd_df.to_csv(csv_buff, sep=',', index=False)

    # Write to S3
    s3.Object(config['S3']['BUCKET_NAME'],
              f'{df_name}.csv').put(Body=csv_buff.getvalue())

    logger.warn(f'Finished writing dataframe: {df_name} as CSV to S3')
def assert_test_dfs_equal(expected_df: pyspark.sql.DataFrame,
                          generated_df: pyspark.sql.DataFrame) -> None:
    """
    Used to compare two dataframes (typically, in a unit test).

    Better than the direct df1.equals(df2) method, as this function allows for
    tolerances in the floating point columns, and is also more descriptive about
    which parts of the two dataframes are in disagreement.

    :param expected_df: First dataframe to compare
    :param generated_df: Second dataframe to compare
    """
    row_limit = 10000
    e_count = expected_df.count()
    g_count = generated_df.count()
    if (e_count > row_limit) or (g_count > row_limit):
        raise Exception(
            f"One or both of the dataframes passed has too many rows (>{row_limit})."
            f"Please limit your test sizes to be lower than this number.")

    assert e_count == g_count, "The dataframes have a different number of rows."

    expected_pdf = expected_df.toPandas()
    generated_pdf = generated_df.toPandas()

    assert list(expected_pdf.columns) == list(generated_pdf.columns), \
        "The two dataframes have different columns."

    for col in expected_pdf.columns:
        error_msg = f"The columns with name: `{col}` were not equal."
        if expected_pdf[col].dtype.type == np.object_:
            assert expected_pdf[[col]].equals(generated_pdf[[col]]), error_msg
        else:
            # Numpy will not equate nulls on both sides. Filter them out
            # per column, without mutating the full frames for later columns.
            expected_col = expected_pdf.loc[expected_pdf[col].notnull(), col]
            generated_col = generated_pdf.loc[generated_pdf[col].notnull(), col]
            try:
                is_close = np.allclose(expected_col.values, generated_col.values)
            except ValueError:
                logging.error(
                    f"Problem encountered while equating column '{col}'.")
                raise
            assert is_close, error_msg
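# Usage sketch for assert_test_dfs_equal (illustrative only; assumes a local
# SparkSession is available in the test environment):
from pyspark.sql import SparkSession


def _demo_assert_test_dfs_equal() -> None:
    spark = SparkSession.builder.master("local[1]").getOrCreate()
    expected = spark.createDataFrame([(1, 0.30000001), (2, 0.5)], ["id", "score"])
    generated = spark.createDataFrame([(1, 0.3), (2, 0.5)], ["id", "score"])
    # Passes: the float column is compared with np.allclose tolerances rather
    # than exact equality.
    assert_test_dfs_equal(expected, generated)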
def build_vocabulary(df: pyspark.sql.DataFrame) -> Dict[str, List[Any]]:
    vocab = {}
    for col in CATEGORICAL_COLS:
        values = [r[0] for r in df.select(col).distinct().collect()]
        col_type = type([x for x in values if x is not None][0])
        default_value = col_type()
        vocab[col] = sorted(values, key=lambda x: x or default_value)
    return vocab
def _sample_dfs(t_df: pyspark.sql.DataFrame, t_fracs: pd.DataFrame,
                c_can_df: pyspark.sql.DataFrame, c_fracs: pd.DataFrame,
                match_col: str) -> Tuple[DataFrame, DataFrame]:
    r"""Given treatment and control pops and their stratified sample fracs,
    return balanced pops.

    Parameters
    ----------
    t_df : pyspark.sql.DataFrame
        treatment pop
    t_fracs : pd.DataFrame
        with columns `match_col` and 'treatment_scaled_sample_fraction'
    c_can_df : pyspark.sql.DataFrame
        control candidate pop
    c_fracs : pd.DataFrame
        with columns `match_col` and 'control_scaled_sample_fraction'
    match_col : str
        column used for the stratified sampling

    Returns
    -------
    t_out : pyspark.sql.DataFrame
    c_out : pyspark.sql.DataFrame

    Raises
    ------
    UncaughtExceptions
    """
    _persist_if_unpersisted(t_df)
    _persist_if_unpersisted(c_can_df)

    t_fracs = t_fracs.set_index(
        match_col).treatment_scaled_sample_fraction.to_dict()
    t_dict = {}
    for key, value in t_fracs.items():
        # sample fractions must not exceed 1
        t_dict[int(key)] = min(float(value), 1)
    t_out = t_df.sampleBy(col=match_col, fractions=t_dict, seed=42)

    c_fracs = c_fracs.set_index(
        match_col).control_scaled_sample_fraction.to_dict()
    c_dict = {}
    for key, value in c_fracs.items():
        c_dict[int(key)] = float(value)
    c_out = c_can_df.sampleBy(col=match_col, fractions=c_dict, seed=42)

    return t_out, c_out
def within_group_scatter(data: pyspark.sql.DataFrame, features, response,
                         targets):
    p = len(features)
    sw = numpy.zeros((p, p))
    for target in targets:
        df_t = data.filter("{} == '{}'".format(response, target))
        X_t = RowMatrix(df_t.select(features).rdd.map(numpy.array))
        sw += X_t.computeCovariance().toArray() * (df_t.count() - 1)
    return sw
def perc_weather_cancellations_per_week(spark: sk.sql.SparkSession,
                                        data: sk.sql.DataFrame) -> sk.RDD:
    onlycancelled = data.select(data['Cancelled'] == 1)
    codeperweek = data.rdd.map(lambda row: (
        week_from_row(row),
        (1, 1 if (str(row['CancellationCode']).strip() == 'B') else 0)))
    fractioncancelled = codeperweek.reduceByKey(
        lambda l, r: (l[0] + r[0], l[1] + r[1]))
    return fractioncancelled.mapValues(
        lambda v: v[1] / v[0] * 100.0).sortByKey()
def normalise_fields_names(df: pyspark.sql.DataFrame,
                           fieldname_normaliser=__normalise_fieldname__):
    return df.select([
        f.col("`{}`".format(field.name)).cast(
            __rename_nested_field__(field.dataType, fieldname_normaliser)).alias(
                fieldname_normaliser(field.name))
        for field in df.schema.fields
    ])
def lookup_columns(df: pyspark.sql.DataFrame,
                   vocab: Dict[str, List[Any]]) -> pyspark.sql.DataFrame:
    def lookup(mapping):
        def fn(v):
            return mapping.index(v)
        return F.udf(fn, returnType=T.IntegerType())

    for col, mapping in vocab.items():
        df = df.withColumn(col, lookup(mapping)(df[col]))
    return df
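# Usage sketch tying build_vocabulary and lookup_columns together (illustrative only;
# it assumes CATEGORICAL_COLS is the module-level constant the two functions above
# read, that F/T are the usual pyspark.sql.functions/types imports, and that a local
# SparkSession is acceptable for the demo):
from pyspark.sql import SparkSession

CATEGORICAL_COLS = ["colour"]  # assumption: normally defined elsewhere in the module


def _demo_vocab_lookup() -> None:
    spark = SparkSession.builder.master("local[1]").getOrCreate()
    df = spark.createDataFrame(
        [("red", 1), ("blue", 2), ("red", 3)], ["colour", "value"])
    vocab = build_vocabulary(df)         # e.g. {'colour': ['blue', 'red']}
    indexed = lookup_columns(df, vocab)  # 'colour' values replaced by vocab indices
    indexed.show()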
def save_to_ray(self, df: pyspark.sql.DataFrame) -> SharedDataset:
    return_type = StructType()
    return_type.add(StructField("node_label", StringType(), True))
    return_type.add(StructField("fetch_index", IntegerType(), False))
    return_type.add(StructField("size", LongType(), False))

    @pandas_udf(return_type)
    def save(batch_iter: Iterator[pd.DataFrame]) -> Iterator[pd.DataFrame]:
        if not ray.is_initialized():
            redis_config = {}
            broadcasted = _global_broadcasted["redis"].value
            redis_config["address"] = broadcasted["address"]
            redis_config["password"] = broadcasted["password"]

            local_address = get_node_address()

            ray.init(address=redis_config["address"],
                     node_ip_address=local_address,
                     redis_password=redis_config["password"])

        node_label = f"node:{ray.services.get_node_ip_address()}"
        block_holder = _global_block_holder[node_label]
        object_ids = []
        sizes = []
        for pdf in batch_iter:
            obj = ray.put(pickle.dumps(pdf))
            # TODO: register object in batch
            object_ids.append(block_holder.register_object_id.remote([obj]))
            sizes.append(len(pdf))

        indexes = ray.get(object_ids)
        result_dfs = []
        for index, block_size in zip(indexes, sizes):
            result_dfs.append(
                pd.DataFrame({
                    "node_label": [node_label],
                    "fetch_index": [index],
                    "size": [block_size]
                }))
        return iter(result_dfs)

    results = df.mapInPandas(save).collect()
    fetch_indexes = []
    block_holder_mapping = {}
    block_sizes = []
    for row in results:
        fetch_indexes.append((row["node_label"], row["fetch_index"]))
        block_sizes.append(row["size"])
        if row["node_label"] not in block_holder_mapping:
            block_holder_mapping[row["node_label"]] = _global_block_holder[
                row["node_label"]]
    return StandaloneClusterSharedDataset(fetch_indexes, block_sizes,
                                          block_holder_mapping)
def add_increment(
    current_df: pyspark.sql.DataFrame,
    increment_df: pyspark.sql.DataFrame,
) -> pyspark.sql.DataFrame:
    union_df = current_df.union(increment_df)
    return (union_df
            .withColumn(
                '_row_number',
                F.row_number().over(
                    Window.partitionBy(union_df['link']).orderBy(['scraped_at'])))
            .where(F.col('_row_number') == 1)
            .drop('_row_number'))
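# Illustrative check of add_increment's dedup rule (it keeps the first row per 'link'
# ordered by 'scraped_at' ascending); assumes a local SparkSession and the imports the
# function above relies on (pyspark.sql.functions as F, pyspark.sql.Window as Window):
from pyspark.sql import SparkSession


def _demo_add_increment() -> None:
    spark = SparkSession.builder.master("local[1]").getOrCreate()
    current = spark.createDataFrame(
        [("http://example.com/a", "2021-01-01")], ["link", "scraped_at"])
    increment = spark.createDataFrame(
        [("http://example.com/a", "2021-02-01"),
         ("http://example.com/b", "2021-02-01")], ["link", "scraped_at"])
    # One row per link survives; for /a it is the 2021-01-01 row, because
    # row_number() is assigned over 'scraped_at' in ascending order.
    add_increment(current, increment).show()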
def displayHead(df: pyspark.sql.DataFrame, nrows: int = 5):
    """
    Returns the first `nrows` rows of a Spark dataframe as a Pandas dataframe.

    Args:
        df: The Spark dataframe
        nrows: number of rows

    Returns:
        Pandas dataframe
    """
    return df.limit(nrows).toPandas()
def process_inquiries(self, review: pyspark.sql.DataFrame,
                      metadata: pyspark.sql.DataFrame) -> None:
    logging.info("Start pipeline")
    logging.info("Processing")
    review_transform_date = review.select(
        'asin', 'overall', 'unixReviewTime').withColumn(
            "unixReviewTime", from_unixtime("unixReviewTime"))
    review_date_decompose = review_transform_date.withColumn(
        "month", month("unixReviewTime")).withColumn(
            "year", year("unixReviewTime"))
    metadata_flatten_categories = metadata.select(
        'asin', explode('categories')).select('asin', explode('col'))
    join_review_metadata = review_date_decompose.join(
        metadata_flatten_categories, on=['asin'], how='inner')
    groupby_review_metadata = join_review_metadata.groupBy(
        "year", "month", "col").count().orderBy(
            'year', 'month', 'count', ascending=False).cache()
    partitions = groupby_review_metadata.withColumn(
        "rank", row_number().over(self.get_partitions())).cache()
    filter_partitions = partitions.filter(partitions.rank <= 5).cache()
    groupby_review_metadata.unpersist()
    result_inner = join_review_metadata.join(filter_partitions,
                                             on=['year', 'month', 'col'],
                                             how='inner')
    partitions.unpersist()
    filter_partitions.unpersist()
    result_groupby = result_inner.groupBy(
        'year', 'month', 'col').avg('overall').alias('rating').orderBy(
            'year', 'month', ascending=True)
    result_groupby.show()
    logging.info("Finished")
    self.upsert_database(result_groupby, 'mydb', 'myset')
def show_df(df: pyspark.sql.DataFrame,
            columns: list,
            rows: int = 10,
            sample=False,
            truncate=True):
    """
    Prints up to `rows` rows of the selected columns of a pyspark df

    :param df: pyspark dataframe
    :param columns: list of columns to print
    :param rows: how many rows to print - default 10
    :param sample: should we sample - default False
    :param truncate: truncate output - default True
    :return:
    """
    if sample:
        sample_percent = min(rows / df.count(), 1.0)
        log.info(f'sampling percentage: {sample_percent}')
        df.select(columns).sample(False, sample_percent,
                                  seed=1).show(rows, truncate=truncate)
    else:
        df.select(columns).show(rows, truncate=truncate)
def assert_pyspark_df_equal(left_df: pyspark.sql.DataFrame,
                            right_df: pyspark.sql.DataFrame,
                            check_data_type: bool,
                            check_cols_in_order: bool = True,
                            order_by: Optional[str] = None) -> None:
    """
    left_df: destination layer, ex: raw
    right_df: origin layer, ex: oracle
    """
    # Check data types
    if check_data_type:
        check_schema(check_cols_in_order, left_df, right_df)

    # Check total rows
    check_total_rows(left_df, right_df)

    # Sort df
    if order_by:
        left_df = left_df.orderBy(order_by)
        right_df = right_df.orderBy(order_by)

    # Check dataframe content
    check_df_content(left_df, right_df)
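# Usage sketch for assert_pyspark_df_equal together with the check_total_rows and
# check_df_content helpers above (illustrative only; assumes a local SparkSession;
# check_data_type=False so the external check_schema helper is not required here):
from pyspark.sql import SparkSession


def _demo_assert_pyspark_df_equal() -> None:
    spark = SparkSession.builder.master("local[1]").getOrCreate()
    left = spark.createDataFrame([(1, "a"), (2, "b")], ["id", "val"])
    right = spark.createDataFrame([(2, "b"), (1, "a")], ["id", "val"])
    # Passes: same rows on both sides; ordering differences are absorbed by subtract().
    assert_pyspark_df_equal(left, right, check_data_type=False, order_by="id")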
def reload_df(self,
              df: pyspark.sql.DataFrame,
              name: str,
              num_partitions: int = None,
              partition_cols: List[str] = None,
              pre_final: bool = False) -> pyspark.sql.DataFrame:
    """Saves a DataFrame as parquet and reloads it.

    Args:
        df (pyspark.sql.DataFrame): DataFrame to save and reload.
        name (str): Name the parquet file is stored under.
        num_partitions (int): Optional number of partitions to repartition to.
        partition_cols (list): Optional columns to repartition by.
        pre_final (bool): Whether to use the '.pre_final' parquet name.
    """
    self.save_to_parquet(df=df,
                         name=name,
                         num_partitions=num_partitions,
                         partition_cols=partition_cols,
                         pre_final=pre_final)
    df = self.load_from_parquet(name=name, pre_final=pre_final)
    df.persist(StorageLevel.MEMORY_AND_DISK)
    return df
def prepare_google_trend(
    google_trend_csv: pyspark.sql.DataFrame,
) -> pyspark.sql.DataFrame:
    google_trend_all = google_trend_csv.withColumn(
        "Date",
        F.regexp_extract(google_trend_csv.week, "(.*?) -", 1)).withColumn(
            "State",
            F.regexp_extract(google_trend_csv.file, "Rossmann_DE_(.*)", 1))

    # map state NI -> HB,NI to align with other data sources
    google_trend_all = google_trend_all.withColumn(
        "State",
        F.when(google_trend_all.State == "NI",
               "HB,NI").otherwise(google_trend_all.State),
    )

    # expand dates
    return expand_date(google_trend_all)
def flatten(df: pyspark.sql.DataFrame,
            fieldname_normaliser=__normalise_fieldname__):
    cols = []
    for child in __get_fields_info__(df.schema):
        if len(child) > 2:
            ex = "x.{}".format(child[-1])
            for seg in child[-2:0:-1]:
                if seg != '``':
                    ex = "transform(x.{outer}, x -> {inner})".format(outer=seg,
                                                                     inner=ex)
            ex = "transform({outer}, x -> {inner})".format(outer=child[0],
                                                           inner=ex)
        else:
            ex = ".".join(child)
        cols.append(
            f.expr(ex).alias(
                fieldname_normaliser("_".join(child).replace('`', ''))))
    return df.select(cols)
def save_to_parquet(self,
                    df: pyspark.sql.DataFrame,
                    name: str,
                    mode: str = "overwrite",
                    num_partitions: int = None,
                    partition_cols: List[str] = None,
                    pre_final: bool = False):
    """Saves a DataFrame into a parquet file.

    Args:
        df (pyspark.sql.DataFrame): DataFrame to save.
        name (str): Name to store the parquet file under.
        mode (str): Spark write mode (default "overwrite").
        num_partitions (int): Optional number of partitions to repartition to.
        partition_cols (list): Optional columns to repartition by.
        pre_final (bool): Whether to save under the '.pre_final' parquet name.
    """
    logger.debug(
        "Saving %s to parquet.." % name
        if not pre_final else "Saving %s.pre_final to parquet.." % name)
    path = os.path.join(self.df_data_folder, name, str(self.loop_counter))
    if not os.path.exists(path):
        os.makedirs(path)
    if pre_final:
        parquet_name = os.path.join(path, name + ".pre_final.parquet")
    else:
        parquet_name = os.path.join(path, name + ".parquet")

    if partition_cols and num_partitions:
        df.repartition(num_partitions,
                       *partition_cols).write.mode(mode).parquet(parquet_name)
    elif num_partitions and not partition_cols:
        df.repartition(num_partitions).write.mode(mode).parquet(parquet_name)
    elif partition_cols and not num_partitions:
        df.repartition(*partition_cols).write.mode(mode).parquet(parquet_name)
    else:
        df.repartition(1).write.mode(mode).parquet(parquet_name)
def save_to_ray(self, df: pyspark.sql.DataFrame,
                num_shards: int) -> PandasDataset:
    # call java function from python
    df = df.repartition(num_shards)
    sql_context = df.sql_ctx
    jvm = sql_context.sparkSession.sparkContext._jvm
    jdf = df._jdf
    object_store_writer = jvm.org.apache.spark.sql.raydp.ObjectStoreWriter(jdf)
    records = object_store_writer.save()

    worker = ray.worker.global_worker

    blocks: List[ray.ObjectRef] = []
    block_sizes: List[int] = []
    for record in records:
        owner_address = record.ownerAddress()
        object_id = ray.ObjectID(record.objectId())
        num_records = record.numRecords()
        # Register the ownership of the ObjectRef
        worker.core_worker.deserialize_and_register_object_ref(
            object_id.binary(), ray.ObjectRef.nil(), owner_address)
        blocks.append(object_id)
        block_sizes.append(num_records)

    divided_blocks = divide_blocks(block_sizes, num_shards)
    record_batch_set: List[RecordBatch] = []
    for i in range(num_shards):
        indexes = divided_blocks[i]
        object_ids = [blocks[index] for index in indexes]
        record_batch_set.append(RecordBatch(object_ids))

    # TODO: we should specify the resource spec for each shard
    ds = parallel_dataset.from_iterators(generators=record_batch_set,
                                         name="spark_df")

    def resolve_fn(it: "Iterable[RecordBatch]") -> "Iterator[RecordBatch]":
        for item in it:
            item.resolve()
            yield item

    return ds.transform(resolve_fn,
                        ".RecordBatch#resolve()").flatten().to_pandas(None)
def test_read_write_csv(
    test_csv_asset: PySparkDataAsset,
    iris_spark: pyspark.sql.DataFrame,
    spark_session: pyspark.sql.SparkSession,
) -> None:
    # try without any extra kwargs:
    PySparkDataAssetIO.write_data_asset(asset=test_csv_asset, data=iris_spark)

    # try with additional kwargs:
    PySparkDataAssetIO.write_data_asset(asset=test_csv_asset,
                                        data=iris_spark,
                                        header=True)

    # test mode; default is overwrite, switch to error (if exists) should raise:
    with pytest.raises(AnalysisException):
        PySparkDataAssetIO.write_data_asset(asset=test_csv_asset,
                                            data=iris_spark,
                                            header=True,
                                            mode="error")

    # test retrieval
    # before we can retrieve, we need to move the data from 'staging' to 'ready'
    os.makedirs(test_csv_asset.ready_path, exist_ok=True)

    # load the prepared data
    shutil.rmtree(test_csv_asset.ready_path)
    shutil.move(test_csv_asset.staging_ready_path, test_csv_asset.ready_path)

    retrieved = PySparkDataAssetIO.retrieve_data_asset(
        test_csv_asset,
        spark_session=spark_session,
        inferSchema=True,
        header=True)
    assert retrieved.count() == iris_spark.count()

    # Test check for missing 'spark_session' kwarg
    with pytest.raises(ValueError):
        PySparkDataAssetIO.retrieve_data_asset(test_csv_asset)

    # Test check for invalid 'spark_session' kwarg
    with pytest.raises(TypeError):
        PySparkDataAssetIO.retrieve_data_asset(test_csv_asset, spark_session=42)
def _calc_var(df: pyspark.sql.DataFrame, label_col: str) -> pd.DataFrame:
    r"""Calculate variance for each column that isn't the label_col.

    Parameters
    ----------
    df : pyspark.sql.DataFrame
        df where rows are observations, all columns except `label_col` are predictors.
    label_col : str

    Returns
    -------
    s_var_df : pd.DataFrame
        pandas dataframe where predictors are the index and each column holds
        the per-group variance.

    Raises
    ------
    UncaughtExceptions
    """
    pred_cols = [x for x in df.columns if x != label_col]
    s_var_df = df.groupby(label_col).agg(
        {x: 'variance' for x in pred_cols}).toPandas().transpose()
    s_var_df = s_var_df.reset_index()
    # strip the "variance(...)" wrapper from the aggregated column names
    s_var_df['index'] = (s_var_df['index']
                         .str.replace('variance(', '', regex=False)
                         .str.replace(')', '', regex=False))
    s_var_df = s_var_df.set_index('index')
    s_var_df.columns = ["var_{0}".format(x) for x in s_var_df.columns]
    # drop the row belonging to the grouping column itself
    s_var_df = s_var_df.loc[s_var_df.index != label_col, :]
    return s_var_df
def prepare_df(
    df: pyspark.sql.DataFrame,
    store_csv: pyspark.sql.DataFrame,
    store_states_csv: pyspark.sql.DataFrame,
    state_names_csv: pyspark.sql.DataFrame,
    google_trend_csv: pyspark.sql.DataFrame,
    weather_csv: pyspark.sql.DataFrame,
) -> pyspark.sql.DataFrame:
    num_rows = df.count()

    # expand dates
    df = expand_date(df)

    # convert the special-event flag columns (open/promo/holiday) from "0"/"1" strings to booleans
    df = (df.withColumn("Open", df.Open != "0")
            .withColumn("Promo", df.Promo != "0")
            .withColumn("StateHoliday", df.StateHoliday != "0")
            .withColumn("SchoolHoliday", df.SchoolHoliday != "0"))

    # merge store information
    store = store_csv.join(store_states_csv, "Store")
    df = df.join(store, "Store")

    # merge Google Trend information
    google_trend_all = prepare_google_trend(google_trend_csv)
    df = df.join(google_trend_all,
                 ["State", "Year", "Week"]).select(df["*"],
                                                   google_trend_all.trend)

    # merge in Google Trend for whole Germany
    google_trend_de = google_trend_all[
        google_trend_all.file == "Rossmann_DE"].withColumnRenamed(
            "trend", "trend_de")
    df = df.join(google_trend_de,
                 ["Year", "Week"]).select(df["*"], google_trend_de.trend_de)

    # merge weather
    weather = weather_csv.join(state_names_csv,
                               weather_csv.file == state_names_csv.StateName)
    df = df.join(weather, ["State", "Date"])

    # fill in null values
    df = (df.withColumn(
        "CompetitionOpenSinceYear",
        F.coalesce(df.CompetitionOpenSinceYear, F.lit(1900)),
    ).withColumn(
        "CompetitionOpenSinceMonth",
        F.coalesce(df.CompetitionOpenSinceMonth, F.lit(1)),
    ).withColumn(
        "Promo2SinceYear",
        F.coalesce(df.Promo2SinceYear, F.lit(1900)),
    ).withColumn(
        "Promo2SinceWeek",
        F.coalesce(df.Promo2SinceWeek, F.lit(1)),
    ))

    # days and months since the competition has been open, capped at 2 years
    df = df.withColumn(
        "CompetitionOpenSince",
        F.to_date(
            F.format_string("%s-%s-15", df.CompetitionOpenSinceYear,
                            df.CompetitionOpenSinceMonth)),
    )
    df = df.withColumn(
        "CompetitionDaysOpen",
        F.when(
            df.CompetitionOpenSinceYear > 1900,
            F.greatest(
                F.lit(0),
                F.least(F.lit(360 * 2),
                        F.datediff(df.Date, df.CompetitionOpenSince)),
            ),
        ).otherwise(0),
    )
    df = df.withColumn("CompetitionMonthsOpen",
                       (df.CompetitionDaysOpen / 30).cast(T.IntegerType()))

    # days and weeks of promotion, capped at 25 weeks
    df = df.withColumn(
        "Promo2Since",
        F.expr(
            'date_add(format_string("%s-01-01", Promo2SinceYear), (cast(Promo2SinceWeek as int) - 1) * 7)'
        ),
    )
    df = df.withColumn(
        "Promo2Days",
        F.when(
            df.Promo2SinceYear > 1900,
            F.greatest(
                F.lit(0),
                F.least(F.lit(25 * 7), F.datediff(df.Date, df.Promo2Since))),
        ).otherwise(0),
    )
    df = df.withColumn("Promo2Weeks",
                       (df.Promo2Days / 7).cast(T.IntegerType()))

    # ensure that no rows were lost through the inner joins
    assert num_rows == df.count(), "lost rows in joins"
    return df
def cast(df: pyspark.sql.DataFrame, col_name: str,
         dtype: str) -> pyspark.sql.DataFrame:
    return df.withColumn(col_name, F.col(col_name).cast(dtype))
def cast_columns(df: pyspark.sql.DataFrame,
                 cols: List[str]) -> pyspark.sql.DataFrame:
    for col in cols:
        df = df.withColumn(
            col, F.coalesce(df[col].cast(T.FloatType()), F.lit(0.0)))
    return df
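# Usage sketch for cast and cast_columns (illustrative only; assumes a local
# SparkSession and the imports the functions above rely on, i.e.
# pyspark.sql.functions as F and pyspark.sql.types as T):
from pyspark.sql import SparkSession


def _demo_casts() -> None:
    spark = SparkSession.builder.master("local[1]").getOrCreate()
    df = spark.createDataFrame([("1", "2.5"), ("x", None)], ["a", "b"])
    df = cast(df, "a", "int")     # non-numeric strings become NULL
    df = cast_columns(df, ["b"])  # floats, with NULLs replaced by 0.0
    df.show()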