def create_city_demographics_dim_table(spark: SparkSession, us_demographics_df: DataFrame,
                                        ports_df: DataFrame) -> DataFrame:
    """
    The demographics dataset contains multiple entries per city. This function aggregates
    those rows and builds the city demographics dimension table by joining the ports and
    aggregated demographics datasets.
    """
    us_demographics_df.createOrReplaceTempView('staging_us_demographics')
    ports_df.createOrReplaceTempView('staging_ports')

    aggregated_df = spark.sql("""
        SELECT sud.city,
               sud.state_code,
               SUM(sud.male_population)    AS male_population,
               SUM(sud.female_population)  AS female_population,
               SUM(sud.total_population)   AS total_population,
               SUM(sud.number_of_veterans) AS number_of_veterans,
               SUM(sud.foreign_born)       AS num_foreign_born
        FROM staging_us_demographics sud
        GROUP BY sud.city, sud.state_code
    """)
    aggregated_df.createOrReplaceTempView('combined_demographics')

    return spark.sql("""
        SELECT sp.port_code AS port_code,
               cd.*
        FROM staging_ports sp
        JOIN combined_demographics cd
            ON lower(cd.city) = lower(sp.city)
            AND cd.state_code = sp.state_code
    """)
def group_transactions(spark: SparkSession):
    spark.sql("""SELECT raw_transactions.customer_id, raw_transactions.basket
                 FROM raw_transactions""") \
        .select("*", F.explode("basket").alias("exploded_data")) \
        .select("customer_id", "exploded_data.product_id") \
        .groupBy("customer_id", "product_id") \
        .count() \
        .sort("customer_id") \
        .createOrReplaceTempView("transactions_grouped")
    logging.info("Grouped transactions.")
def process_log_data(spark: SparkSession, input_data: str, output_data: str):
    """Process events

    Args:
        spark (SparkSession): Spark Session
        input_data (str): input data path
        output_data (str): output data path
    """
    events_path = os.path.join(input_data, "log_data/2018/11/*events.json")
    log_data = spark.read.json(events_path).dropDuplicates()
    log_data.createOrReplaceTempView("events_raw")

    # Keep only song plays and re-register the filtered view under the same name
    events = spark.sql("SELECT * FROM events_raw WHERE page = 'NextSong';").cache()
    events.createOrReplaceTempView("events_raw")

    user_table = spark.sql("""SELECT DISTINCT(INT(userId)) AS user_id,
                                     STRING(firstName) AS first_name,
                                     STRING(lastName)  AS last_name,
                                     STRING(gender)    AS gender,
                                     STRING(level)     AS level
                              FROM events_raw;""")

    # uuid() returns a string, so casting it to INT would yield NULL ids; keep it as a
    # string surrogate key. song_id and artist_id are alphanumeric codes in this dataset,
    # so they are kept as-is rather than cast to INT.
    songplays_table = spark.sql("""SELECT uuid() AS songplay_id,
                                          TIMESTAMP(from_unixtime(ts / 1000)) AS start_time,
                                          INT(userId) AS user_id,
                                          STRING(level) AS level,
                                          songs.song_id AS song_id,
                                          artists.artist_id AS artist_id,
                                          INT(sessionId) AS session_id,
                                          STRING(events_raw.location) AS location,
                                          STRING(userAgent) AS user_agent,
                                          SMALLINT(year(from_unixtime(ts / 1000))) AS year,
                                          SMALLINT(month(from_unixtime(ts / 1000))) AS month
                                   FROM events_raw
                                   JOIN artists ON events_raw.artist = artists.name
                                   JOIN songs ON events_raw.song = songs.title""")
    songplays_table.createOrReplaceTempView("songplays")

    time_table = spark.sql("""SELECT DISTINCT(TIMESTAMP(start_time)) AS start_time,
                                     SMALLINT(hour(start_time)) AS hour,
                                     SMALLINT(day(start_time)) AS day,
                                     SMALLINT(weekofyear(start_time)) AS week,
                                     SMALLINT(month) AS month,
                                     SMALLINT(year) AS year,
                                     SMALLINT(dayofweek(start_time)) AS weekday
                              FROM songplays;""")

    # Leaving the writes for last in case anything goes wrong in the conversion step
    time_output = os.path.join(output_data, "time_parquet")
    time_table.write.partitionBy("year", "month").parquet(time_output, "overwrite")

    user_output = os.path.join(output_data, "users_parquet")
    user_table.write.parquet(user_output, "overwrite")

    songplays_output = os.path.join(output_data, "songplays_parquet")
    songplays_table.write.partitionBy("year", "month").parquet(songplays_output, "overwrite")
def main(spark: SparkSession, inputfile: str):
    logger.info(inputfile)
    flights_df = spark.read.format("parquet").load(inputfile)

    spark.sql("CREATE DATABASE IF NOT EXISTS AIRLINE_DB")
    spark.catalog.setCurrentDatabase("AIRLINE_DB")

    # Write DataFrame table with partitions
    flights_df.write \
        .mode('overwrite') \
        .partitionBy("OP_CARRIER", "ORIGIN") \
        .saveAsTable("flight_data")

    flights_df.write \
        .mode('overwrite') \
        .bucketBy(5, "OP_CARRIER", "ORIGIN") \
        .sortBy("OP_CARRIER", "ORIGIN") \
        .saveAsTable("flight_data_bucket")

    logger.info(spark.catalog.listTables("AIRLINE_DB"))
    logger.info("done")
def spark_start(config: Dict = {}) -> SparkSession:
    pro_home = arcpy.GetInstallInfo()["InstallDir"]
    # pro_lib_dir = os.path.join(pro_home, "Java", "lib")
    pro_runtime_dir = os.path.join(pro_home, "Java", "runtime")
    os.environ["HADOOP_HOME"] = os.path.join(pro_runtime_dir, "hadoop")

    conf = SparkConf()
    for k, v in config.items():
        conf.set(k, v)

    # These need to be reset on every run or pyspark will think the Java gateway is still up and running
    os.environ.unsetenv("PYSPARK_GATEWAY_PORT")
    os.environ.unsetenv("PYSPARK_GATEWAY_SECRET")
    SparkContext._jvm = None
    SparkContext._gateway = None

    popen_kwargs = {
        # need to redirect stdout & stderr when running in Pro or the JVM fails immediately
        'stdout': subprocess.DEVNULL,
        'stderr': subprocess.DEVNULL,
        # keeps the command-line window from showing
        'shell': True
    }
    # We have to manage the py4j gateway ourselves so that we can control the JVM process
    gateway = launch_gateway(conf=conf, popen_kwargs=popen_kwargs)
    sc = SparkContext(gateway=gateway)
    spark = SparkSession(sc)

    # Kick-start the Spark engine.
    spark.sql("select 1").collect()
    return spark
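# A minimal usage sketch for spark_start(). The configuration keys shown are standard
# Spark properties, but the specific values are illustrative assumptions, not settings
# required by the function above.
if __name__ == "__main__":
    spark = spark_start({
        "spark.app.name": "ArcGISProSpark",   # assumed app name
        "spark.driver.memory": "4g",          # assumed driver memory
    })
    spark.sql("select 1 as ok").show()
    spark.stop()  # stops the SparkContext started by spark_start()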
def generic_file_source_options_example(spark: SparkSession) -> None:
    # $example on:ignore_corrupt_files$
    # enable ignore corrupt files
    spark.sql("set spark.sql.files.ignoreCorruptFiles=true")
    # dir1/file3.json is corrupt from parquet's view
    test_corrupt_df = spark.read.parquet(
        "examples/src/main/resources/dir1/",
        "examples/src/main/resources/dir1/dir2/")
    test_corrupt_df.show()
    # +-------------+
    # |         file|
    # +-------------+
    # |file1.parquet|
    # |file2.parquet|
    # +-------------+
    # $example off:ignore_corrupt_files$

    # $example on:recursive_file_lookup$
    recursive_loaded_df = spark.read.format("parquet")\
        .option("recursiveFileLookup", "true")\
        .load("examples/src/main/resources/dir1")
    recursive_loaded_df.show()
    # +-------------+
    # |         file|
    # +-------------+
    # |file1.parquet|
    # |file2.parquet|
    # +-------------+
    # $example off:recursive_file_lookup$
    spark.sql("set spark.sql.files.ignoreCorruptFiles=false")

    # $example on:load_with_path_glob_filter$
    df = spark.read.load("examples/src/main/resources/dir1",
                         format="parquet", pathGlobFilter="*.parquet")
    df.show()
    # +-------------+
    # |         file|
    # +-------------+
    # |file1.parquet|
    # +-------------+
    # $example off:load_with_path_glob_filter$

    # $example on:load_with_modified_time_filter$
    # Only load files modified before 07/01/2050 @ 08:30:00
    df = spark.read.load("examples/src/main/resources/dir1",
                         format="parquet", modifiedBefore="2050-07-01T08:30:00")
    df.show()
    # +-------------+
    # |         file|
    # +-------------+
    # |file1.parquet|
    # +-------------+

    # Only load files modified after 06/01/2050 @ 08:30:00
    df = spark.read.load("examples/src/main/resources/dir1",
                         format="parquet", modifiedAfter="2050-06-01T08:30:00")
    df.show()
def test_mlflow_model_from_model_version(spark: SparkSession, mlflow_client):
    # peg to a particular version of a model
    spark.sql("CREATE MODEL resnet_m_fizz USING 'mlflow://rikai-test/1'")
    check_ml_predict(spark, "resnet_m_fizz")

    # use the latest version in a given stage (omitted means none)
    spark.sql("CREATE MODEL resnet_m_buzz USING 'mlflow://rikai-test/'")
    check_ml_predict(spark, "resnet_m_buzz")
def spark_session_test(sc):
    spark = SparkSession(sc)
    df = spark.read.json("people.json")
    df.printSchema()
    df.createOrReplaceTempView("people")
    spark.sql("SELECT name FROM people WHERE age > 10").show()
def generate_poi_vocabulary(spark: SparkSession, args):
    sql = config.POI_VOCABULARY_SQL_TPL.format(
        START_DATE=args.start, END_DATE=args.end
    )
    print(sql)
    save_path = config.SAMPLE_SAVE_DIR + "/poi_vocabulary/"
    spark.sql(sql).repartition(1).write.csv(
        save_path, mode="overwrite", sep="\t", header="true")
def create(self, title, cscore=True, start_date=None, end_date=None):
    assert type(cscore) is bool, \
        'Error: a bool value is expected for cscore, signalling whether the data file contains cscore.'
    self.set_selector_dates(start_date, end_date)

    # Create a SparkSession
    # Note: if this is run on Windows and generates errors, use (the tmp folder must exist):
    # spark = SparkSession.builder.config("spark.sql.warehouse.dir", "file:///C:/temp").appName("Postprocessing").getOrCreate()
    conf = SparkConf().setMaster("local[*]").setAppName("Events")
    sc = SparkContext(conf=conf)
    spark = SparkSession(sc).builder.appName("GtEvents").getOrCreate()

    # Create results path and filename
    results_path = os.path.join(self.graph.curr_data_path, str(title))
    # print(results_path)
    self.check_results_path(results_path)
    events_results_file = str(title) + '_' + self.graph.source_events[0]

    # Register events dataframe
    for i in range(len(self.graph.source_events)):
        events_source = spark.sparkContext.textFile(
            os.path.join(self.graph.source_events_location, self.graph.source_events[i]))
        events = events_source.map(self.mapper_events)
        events_df = spark.createDataFrame(events).cache()
        if i == 0:
            all_events_df = events_df
        else:
            all_events_df = all_events_df.union(events_df)

    # Register dataframe for wiki_id to gt_id mapping
    id_map_source = spark.sparkContext.textFile(
        os.path.join(self.gt_wiki_id_map_path, self.gt_wiki_id_map_file))
    id_map = id_map_source.map(self.mapper_ids)
    id_map_df = spark.createDataFrame(id_map).cache()

    # Resolve wiki_ids to gt_id
    all_events_df.createOrReplaceTempView("events")
    id_map_df.createOrReplaceTempView("id_map")
    events_resolved_df = spark.sql('SELECT e.revision, i.gt_id as gt_source, e.target, e.event, e.cscore '
                                   'FROM events e LEFT OUTER JOIN id_map i ON e.source = i.wiki_id')
    events_resolved_df.createOrReplaceTempView("events_resolved")
    events_resolved_df = spark.sql('SELECT e.revision, e.gt_source, i.gt_id as gt_target, e.event, e.cscore '
                                   'FROM events_resolved e LEFT OUTER JOIN id_map i ON e.target = i.wiki_id')

    # Collect results and write to file
    events_resolved = events_resolved_df.rdd.collect()
    self.write_list(os.path.join(results_path, events_results_file), events_resolved)

    self.results['files'] = [events_results_file]
    self.results['type'] = 'gt_events'
    self.results['start'] = str(datetime.fromtimestamp(self.start_date))
    self.results['end'] = str(datetime.fromtimestamp(self.end_date))
    self.data[self.graph_id][title] = self.results
    self.graph.update_graph_data(self.data)
    sc.stop()
def basic_datasource_example(spark: SparkSession) -> None:
    # $example on:generic_load_save_functions$
    df = spark.read.load("examples/src/main/resources/users.parquet")
    df.select("name", "favorite_color").write.save("namesAndFavColors.parquet")
    # $example off:generic_load_save_functions$

    # $example on:write_partitioning$
    df.write.partitionBy("favorite_color").format("parquet").save("namesPartByColor.parquet")
    # $example off:write_partitioning$

    # $example on:write_partition_and_bucket$
    df = spark.read.parquet("examples/src/main/resources/users.parquet")
    (df
        .write
        .partitionBy("favorite_color")
        .bucketBy(42, "name")
        .saveAsTable("users_partitioned_bucketed"))
    # $example off:write_partition_and_bucket$

    # $example on:manual_load_options$
    df = spark.read.load("examples/src/main/resources/people.json", format="json")
    df.select("name", "age").write.save("namesAndAges.parquet", format="parquet")
    # $example off:manual_load_options$

    # $example on:manual_load_options_csv$
    df = spark.read.load("examples/src/main/resources/people.csv",
                         format="csv", sep=";", inferSchema="true", header="true")
    # $example off:manual_load_options_csv$

    # $example on:manual_save_options_orc$
    df = spark.read.orc("examples/src/main/resources/users.orc")
    (df.write.format("orc")
        .option("orc.bloom.filter.columns", "favorite_color")
        .option("orc.dictionary.key.threshold", "1.0")
        .option("orc.column.encoding.direct", "name")
        .save("users_with_options.orc"))
    # $example off:manual_save_options_orc$

    # $example on:manual_save_options_parquet$
    df = spark.read.parquet("examples/src/main/resources/users.parquet")
    (df.write.format("parquet")
        .option("parquet.bloom.filter.enabled#favorite_color", "true")
        .option("parquet.bloom.filter.expected.ndv#favorite_color", "1000000")
        .option("parquet.enable.dictionary", "true")
        .option("parquet.page.write-checksum.enabled", "false")
        .save("users_with_options.parquet"))
    # $example off:manual_save_options_parquet$

    # $example on:write_sorting_and_bucketing$
    df.write.bucketBy(42, "name").sortBy("age").saveAsTable("people_bucketed")
    # $example off:write_sorting_and_bucketing$

    # $example on:direct_sql$
    df = spark.sql("SELECT * FROM parquet.`examples/src/main/resources/users.parquet`")
    # $example off:direct_sql$

    spark.sql("DROP TABLE IF EXISTS people_bucketed")
    spark.sql("DROP TABLE IF EXISTS users_partitioned_bucketed")
def registerAll(cls, spark: SparkSession) -> bool:
    """
    This is the core of the whole package. It uses py4j to run a wrapper that takes an
    existing SparkSession and registers all the User Defined Functions provided by the
    GeoSpark developers for that SparkSession.

    :param spark: pyspark.sql.SparkSession, spark session instance
    :return: bool, True if registration was correct.
    """
    spark.sql("SELECT 1 as geom").count()
    cls.register(spark)
    return True
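# Usage sketch for registerAll(): assuming the method above is a classmethod of the
# GeoSpark registrator class (the import path below is an assumption based on the
# geospark package), registration happens once per SparkSession before any spatial SQL.
from pyspark.sql import SparkSession
from geospark.register import GeoSparkRegistrator  # assumed import path

spark = SparkSession.builder.appName("geospark-demo").getOrCreate()
GeoSparkRegistrator.registerAll(spark)  # registers the ST_* UDFs/UDTs on this session
spark.sql("SELECT ST_Point(1.0, 2.0) AS geom").show()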
def simple_profile(spark: SparkSession):
    """
    :param spark: Spark session
    :return: DataFrame
    """
    social_graph_rollup = spark.sql(
        "select first_user_id, collect_set(second_user_id) friends from socialgraph group by 1"
    )
    social_graph_rollup.createOrReplaceTempView("social_graph_rollup")

    checkins_rollup = spark.sql("""
        select c.user_id,
               collect_list(struct(c.id,
                                   c.venue_id,
                                   c.latitude,
                                   c.longitude,
                                   c.c_latitude,
                                   c.c_longitude,
                                   c.created_at,
                                   r.rating,
                                   c.total_chk,
                                   c.checkin_seq,
                                   c.total_venue_chk,
                                   c.dst)) chks,
               count(1) total_chk,
               count(distinct c.venue_id) distinct_chk,
               approx_percentile(dst, 0.50) median_dst,
               approx_percentile(r.rating, 0.50) median_rating,
               approx_percentile(r.rating, 0.95) p95_rating
        from checkins_agg c
        left join (
            select user_id, venue_id, avg(rating) rating
            from ratings
            group by 1, 2
        ) r on r.user_id = c.user_id and r.venue_id = c.venue_id
        group by 1
    """)
    checkins_rollup.createOrReplaceTempView("checkins_rollup")

    profiles = spark.sql("""
        select u.id,
               u.latitude u_latitude,
               u.longitude u_longitude,
               encode_geohash(u.latitude, u.longitude) u_gh4,
               c.*,
               g.friends
        from users u
        inner join checkins_rollup c on u.id = c.user_id
        left join social_graph_rollup g on u.id = g.first_user_id
        order by total_chk desc
    """)
    profiles.createOrReplaceTempView("profiles")
    profiles.write.format("orc").mode("overwrite").save("data/profiles")
    return profiles
def test_mlflow_model_from_runid(
    spark: SparkSession, mlflow_client: MlflowClient
):
    run_id = mlflow_client.search_model_versions("name='rikai-test'")[0].run_id
    spark.sql(
        "CREATE MODEL resnet_m_foo USING 'mlflow://{}/model'".format(run_id)
    )
    check_ml_predict(spark, "resnet_m_foo")

    # if no path is given but only one artifact exists then use it by default
    spark.sql("CREATE MODEL resnet_m_bar USING 'mlflow://{}'".format(run_id))
    check_ml_predict(spark, "resnet_m_bar")
def run_ddl(self, spark: SparkSession) -> None:
    '''
    Runs the DDL operations listed in the text file of DDL statements. The statements are
    generated from the template using the current timestamp and the current project
    directory taken from the configuration.
    '''
    with Path(self.mode_directory, 'schwacke_hive_tables_ddl.txt').open('r') as ddl:
        _LOGGER.debug('text file with a list of ddl operations is available')
        # Statements in the file are separated by blank lines
        for operation in ddl.read().split('\n\n'):
            spark.sql(operation)
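# Illustrative sketch of the layout run_ddl() expects in schwacke_hive_tables_ddl.txt:
# statements are separated by one blank line, because the function splits the file on
# '\n\n'. The database, table and column names below are assumptions, not taken from
# the real template.
EXAMPLE_DDL = """CREATE DATABASE IF NOT EXISTS schwacke

CREATE TABLE IF NOT EXISTS schwacke.make (
    make_id INT,
    make_name STRING
)
STORED AS PARQUET"""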
def recommendation_als_based():
    print('ALS Model')
    sc = SparkContext()
    spark = SparkSession(sc)

    spark.read.format("jdbc").option("url", jdbc_url)\
        .option("user", username).option("password", password)\
        .option("dbtable", "game_steam_user_inventory")\
        .load().createOrReplaceTempView('user_inventory')
    spark.read.format("jdbc").option("url", jdbc_url)\
        .option("user", username).option("password", password)\
        .option("dbtable", "game_steam_app")\
        .load().createOrReplaceTempView('game_steam_app')

    df_user_playtime = spark.sql(
        'SELECT DENSE_RANK() OVER (ORDER BY user_id) AS user, user_id, app_id AS item, '
        'playtime_forever AS rating FROM user_inventory WHERE playtime_forever > 0'
    )
    df_valid_games = spark.sql(
        'SELECT app_id FROM game_steam_app WHERE short_description IS NOT NULL '
        'AND type = "game" AND initial_price IS NOT NULL'
    )
    df_user_inventory = df_user_playtime.join(
        df_valid_games,
        df_user_playtime['item'] == df_valid_games['app_id'],
        'inner').select('user', 'user_id', 'item', 'rating')
    dic_real_user_id = df_user_inventory.select(
        'user', 'user_id').toPandas().set_index('user')['user_id'].to_dict()

    als = ALS(rank=5)
    model = als.fit(df_user_inventory)
    recommended_games = model.recommendForAllUsers(10)

    dic_recomended_als_based = {}
    for user, lst_recommended_games in recommended_games.select(
            'user', 'recommendations.item').toPandas().set_index(
            'user')['item'].to_dict().items():
        user_id = dic_real_user_id.get(user)
        dic_recomended_als_based[user_id] = {}
        for i, app_id in enumerate(lst_recommended_games):
            dic_recomended_als_based[user_id].update({i: app_id})

    df_als_based_result = pd.DataFrame.from_dict(dic_recomended_als_based, 'index')
    df_als_based_result.index.name = 'user_id'
    df_als_based_result.reset_index(inplace=True)
    df_als_based_result.to_sql('recommended_games_als_based', engine,
                               if_exists='replace', chunksize=1000, index=False)
def _get_data(spark: SparkSession, args):
    sql = config.SAMPLE_SQL_TPL.format(
        cols=','.join(col for col in config.COLUMNS),
        START_DATE=args.start,
        END_DATE=args.end)
    print("[INFO] get data sql: {0}".format(sql))
    return spark.sql(sql).na.fill(-999.0)
def drop_dup_keep_latest(spark: SparkSession, df: DataFrame, id_col: str,
                         date_col: str, keep_date_null: bool) -> DataFrame:
    df.createOrReplaceTempView('dataset')
    add_cond = ''
    if keep_date_null:
        add_cond = 'OR (T2.{0} IS NULL AND T1.{0} IS NULL)'.format(date_col)
    query = '''
        SELECT T1.*
        FROM dataset AS T1
        INNER JOIN (
            SELECT {id_col}, MAX({date_col}) AS {date_col}
            FROM dataset
            GROUP BY {id_col}
        ) AS T2
            ON T2.{id_col} = T1.{id_col}
            AND (T2.{date_col} = T1.{date_col} {add_cond})
        ORDER BY {id_col}
    '''.format(id_col=id_col, date_col=date_col, add_cond=add_cond)
    print(query)
    df = spark.sql(query)
    spark.catalog.dropTempView('dataset')
    return df
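# Usage sketch for drop_dup_keep_latest(). The column names and rows below are
# illustrative assumptions; with keep_date_null=True an id that only ever appears with a
# NULL date is also kept.
spark = SparkSession.builder.getOrCreate()
updates = spark.createDataFrame(
    [(1, "2021-01-01", "old"), (1, "2021-02-01", "new"), (2, None, "only")],
    ["customer_id", "updated_at", "payload"],
)
latest = drop_dup_keep_latest(spark, updates, id_col="customer_id",
                              date_col="updated_at", keep_date_null=True)
latest.show()  # keeps (1, 2021-02-01, new) and (2, NULL, only)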
def main():
    sc = SparkContext.getOrCreate()
    spark = SparkSession(sc)
    print("What the heck")
    print("Hello World!")

    peopleDF = spark.read.json(
        "/opt/spark/examples/src/main/resources/people.json")

    # DataFrames can be saved as Parquet files, maintaining the schema information.
    peopleDF.write.parquet("people.parquet")

    # Read in the Parquet file created above.
    # Parquet files are self-describing so the schema is preserved.
    # The result of loading a parquet file is also a DataFrame.
    parquetFile = spark.read.parquet("people.parquet")

    # Parquet files can also be used to create a temporary view and then used in SQL statements.
    parquetFile.createOrReplaceTempView("parquetFile")
    teenagers = spark.sql(
        "SELECT name FROM parquetFile WHERE age >= 13 AND age <= 19")
    teenagers.show()
    # +------+
    # |  name|
    # +------+
    # |Justin|
    # +------+

    sc.stop()
def load_delay_data(df_reader: DataFrameReader, spark: SparkSession,
                    mes_part_opplan, mes_part_info_dispatch, mes_part_info):
    df_reader.option(
        "dbtable", mes_part_opplan).load().createOrReplaceTempView("mes_part_opplan")
    df_reader.option("dbtable", mes_part_info_dispatch).load().createOrReplaceTempView(
        "mes_part_info_dispatch")
    df_reader.option(
        "dbtable", mes_part_info).load().createOrReplaceTempView("mes_part_info")

    df = spark.sql('''
        select c.MES_PART_INFO_ID, b.MES_PART_INFO_DISPATCH_ID, a.MES_PART_OPPLAN_ID,
               c.PART_NO, c.LOT_NO, c.FOPLAN_ID, c.TASK_QTY, c.MANUFACTURER,
               b.SCHEDULED_OPERATOR_TYPE, b.SCHEDULED_OPERATOR_NO,
               a.SCHEDULED_START_DATE, a.SCHEDULED_COMPLETION_DATE,
               a.START_DATE, a.END_DATE, a.OP_NO, a.OP_NAME, a.OP_DESCRIPTION,
               a.OPER_DEPART, a.ACTUAL_MADE_BY, b.COMPLETION_RECORD_CREATOR
        from mes_part_opplan a
        join mes_part_info_dispatch b
            on a.MES_PART_INFO_ID = b.MES_PART_INFO_ID
            and a.MES_PART_OPPLAN_ID = b.MES_PART_OPPLAN_ID
        join mes_part_info c
            on a.MES_PART_INFO_ID = c.MES_PART_INFO_ID
    ''')
    return df.filter(F.col("START_DATE").isNotNull())
def simple_als(spark: SparkSession):
    """
    :param spark: Spark session
    :return: None; writes the aggregated check-ins to ORC
    """
    df = spark.sql("""
        select c.user_id,
               c.venue_id,
               c.v_gh4,
               max(c.total_venue_chk) total_venue_chk,
               avg(r.rating) rating
        from checkins_agg c
        left join (
            select user_id, venue_id, avg(rating) rating
            from ratings
            group by 1, 2
        ) r on r.user_id = c.user_id and r.venue_id = c.venue_id
        group by 1, 2, 3""")
    df.write.format("orc").mode("overwrite").save("data/checkins")
def checkin_aggregation(spark: SparkSession):
    """
    :param spark: Spark session
    :return: DataFrame
    """
    df = spark.sql("""
        select c.id,
               c.user_id,
               venue_id,
               v.latitude,
               v.longitude,
               c.latitude c_latitude,
               c.longitude c_longitude,
               approx_dist(v.latitude, v.longitude, c.latitude, c.longitude) dst,
               encode_geohash(v.latitude, v.longitude) v_gh4,
               timestamp(created_at) as created_at,
               count(1) over(partition by user_id) total_chk,
               row_number() over(partition by user_id order by timestamp(created_at)) checkin_seq,
               count(1) over(partition by user_id, venue_id) total_venue_chk
        from checkins c
        left join venues v on v.id = c.venue_id
    """)
    df.createOrReplaceTempView(DerivedTables.checkins_agg.value)
    return df
def _write_users(output_data: str, spark: SparkSession):
    """Save users data in the parquet format.

    Keeps only the latest records to reflect the current user level (free/paid).
    """
    users_table = spark.sql(select_users)
    users_table.write.parquet(os.path.join(output_data, 'users.parquet'))
def find_references_drug_clinical_trial(spark: SparkSession, datadir: Path):
    dfdrugs = spark.read.parquet(str(datadir / labels.parquet_drug))
    dfdrugs.createOrReplaceTempView('drug')
    dfclinicaltrial = spark.read.parquet(
        str(datadir / labels.parquet_clinical_trial))
    dfclinicaltrial.createOrReplaceTempView('clinical_trial')

    cross = spark.sql(
        'select drug.atccode as drug_atccode, drug.drug as drug_name, '
        'clinical_trial.id as clinical_trial_id, '
        'clinical_trial.scientific_title as clinical_trial_scientific_title, '
        'clinical_trial.date as clinical_trial_date '
        'from drug cross join clinical_trial'
    )
    # Keep only the pairs where the drug name appears in the trial title
    cross = cross.rdd.filter(
        lambda item: item.drug_name.upper() in item.clinical_trial_scientific_title.upper())

    schema = StructType([
        StructField('drug_atccode', StringType(), False),
        StructField('drug_name', StringType(), True),
        StructField('clinical_trial_id', StringType(), True),
        StructField('clinical_trial_scientific_title', StringType(), True),
        StructField('clinical_trial_date', DateType(), True),
    ])
    cross.toDF(schema).write.parquet(
        path=str(datadir / labels.parquet_drug_clinical_trial), mode='overwrite')
def test_correctly_loads_csv_with_clean_flag_off(spark_session: SparkSession) -> None:
    # Arrange
    clean_spark_session(spark_session)
    data_dir: Path = Path(__file__).parent.joinpath("./")
    test_file_path: str = f"{data_dir.joinpath('column_name_test.json')}"
    schema = StructType([])
    df: DataFrame = spark_session.createDataFrame(
        spark_session.sparkContext.emptyRDD(), schema
    )

    # Act
    FrameworkJsonLoader(
        view="books", filepath=test_file_path, clean_column_names=False
    ).transform(df)

    # noinspection SqlDialectInspection
    result: DataFrame = spark_session.sql("SELECT * FROM books")

    # Assert
    assert result.count() == 2
    assert result.collect()[1]["title"] == "Essentials of Programming Languages"
    assert len(result.collect()[1]["authors"]) == 2
    assert result.collect()[1]["authors"][0]["surname"] == "Friedman"
    assert (
        result.collect()[1]["Ugly column,with;chars{that}parquet(does)not like=much_-"]
        == 3
    )
def _gen_cat_feature_vocabulary(data: DataFrame, orig_sparse_features: list,
                                transform_sparse_features_dict: dict,
                                bucket_feature_prefix, args,
                                spark: SparkSession, sc: SparkContext):
    """
    Generate the vocabulary for categorical (sparse) features.

    :param data:
    :param orig_sparse_features:
    :param transform_sparse_features_dict:
    :param bucket_feature_prefix:
    :param args:
    :param sc:
    :return:
    """
    tmp_table = "tmp_table"
    data.createOrReplaceTempView(tmp_table)
    if args.is_train:
        cat_feature_vocabulary_dict = dict()
        for col in orig_sparse_features:
            vocabulary_df = spark.sql(
                "select distinct {col} as {col} from {table}".format(
                    col=col, table=tmp_table)).toPandas()
            cat_feature_vocabulary_dict[col] = vocabulary_df[col].tolist()
        for k, v in transform_sparse_features_dict.items():
            cat_feature_vocabulary_dict[bucket_feature_prefix + k] = [
                i for i in range(0, len(v))]
        utils.write_to_hdfs(sc, config.VOCABULARY_DICT,
                            json.dumps(cat_feature_vocabulary_dict),
                            overwrite=True)
        print("[INFO] VOCABULARY_DICT write success: {0}".format(
            config.VOCABULARY_DICT))
    return config.VOCABULARY_DICT
def _get_entity_df_event_timestamp_range(
    entity_df: Union[pd.DataFrame, str],
    entity_df_event_timestamp_col: str,
    spark_session: SparkSession,
) -> Tuple[datetime, datetime]:
    if isinstance(entity_df, pd.DataFrame):
        entity_df_event_timestamp = entity_df.loc[
            :, entity_df_event_timestamp_col
        ].infer_objects()
        if pd.api.types.is_string_dtype(entity_df_event_timestamp):
            entity_df_event_timestamp = pd.to_datetime(
                entity_df_event_timestamp, utc=True)
        entity_df_event_timestamp_range = (
            entity_df_event_timestamp.min().to_pydatetime(),
            entity_df_event_timestamp.max().to_pydatetime(),
        )
    elif isinstance(entity_df, str):
        # If the entity_df is a string (SQL query), determine the range from the table
        df = spark_session.sql(entity_df).select(entity_df_event_timestamp_col)
        # TODO(kzhang132): need utc conversion here.
        # Return (min, max) so both branches order the tuple the same way
        entity_df_event_timestamp_range = (
            df.agg({
                entity_df_event_timestamp_col: "min"
            }).collect()[0][0],
            df.agg({
                entity_df_event_timestamp_col: "max"
            }).collect()[0][0],
        )
    else:
        raise InvalidEntityType(type(entity_df))

    return entity_df_event_timestamp_range
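# Usage sketch for _get_entity_df_event_timestamp_range(), showing both input forms.
# The column name, table name and timestamps are illustrative assumptions, and an
# active `spark_session` is assumed to exist.
entity_pdf = pd.DataFrame({
    "driver_id": [1001, 1002],
    "event_timestamp": pd.to_datetime(["2021-04-12 08:12:10", "2021-04-12 16:40:26"]),
})
start, end = _get_entity_df_event_timestamp_range(
    entity_pdf, "event_timestamp", spark_session)

# Or derive the range from a SQL query against a registered table (assumed name)
start, end = _get_entity_df_event_timestamp_range(
    "SELECT driver_id, event_timestamp FROM driver_entity_df",
    "event_timestamp",
    spark_session,
)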
def transformData(spark: SparkSession, transactionsDf: DataFrame, customersDf: DataFrame,
                  productsDf: DataFrame, path: str) -> DataFrame:
    """
    Call your custom functions to transform your data.
    """
    # Wrap the exportResult function around the transform functions, which saves each
    # transform result to a Delta table. After exporting the results the ETL is done.
    exportResult([
        (spark, cleanTransactions(transactionsDf),
         {"format": "delta", "path": f"{path}/transactions", "key": "customer_id"}),
        (spark, cleanCustomers(customersDf),
         {"format": "delta", "path": f"{path}/customers", "key": "customer_id"}),
        (spark, cleanProducts(productsDf),
         {"format": "delta", "path": f"{path}/products", "key": "product_id"}),
    ])

    # This final step needn't be in jobs/sales.py and would ideally live in a separate jobs
    # application. It is included to show that we could run further queries on the Delta
    # tables, e.g. to combine them with historic data from existing tables or data from a
    # different source. Here the Delta tables we have just saved are used as the example.
    l = loadDeltaTables([
        (spark, f"{path}/transactions", "delta"),
        (spark, f"{path}/customers", "delta"),
        (spark, f"{path}/products", "delta"),
    ])

    # If you prefer SQL to querying dataframes directly, zip the list of dataframes with a
    # list of table names...
    listOfDf = list(zip(l, ["transactions", "customers", "products"]))
    # ...then create temp tables that we can run SQL queries against.
    createTempTables(spark, listOfDf)

    # From here, include functions to do more work.
    # Print a Delta table to the terminal.
    df = spark.sql("SELECT * FROM transactions")
    df.show()
def find_references_drug_pubmed(spark: SparkSession, datadir: Path):
    dfdrugs = spark.read.parquet(str(datadir / labels.parquet_drug))
    dfdrugs.createOrReplaceTempView('drug')
    dfpubmed = spark.read.parquet(str(datadir / labels.parquet_pubmed))
    dfpubmed.createOrReplaceTempView('pubmed')

    cross = spark.sql(
        'select drug.atccode as drug_atccode, drug.drug as drug_name, '
        'pubmed.id as pubmed_id, pubmed.title as pubmed_title, '
        'pubmed.date as pubmed_date from drug cross join pubmed'
    )

    def filter_cross(row):
        # Keep only the pairs where the drug name appears in the article title
        return row.drug_name.upper() in row.pubmed_title.upper()

    cross = cross.rdd.filter(filter_cross)

    schema = StructType([
        StructField('drug_atccode', StringType(), False),
        StructField('drug_name', StringType(), True),
        StructField('pubmed_id', IntegerType(), True),
        StructField('pubmed_title', StringType(), True),
        StructField('pubmed_date', DateType(), True),
    ])
    cross.toDF(schema).write.parquet(
        path=str(datadir / labels.parquet_drug_pubmed), mode='overwrite')
def linked_profiles(spark: SparkSession):
    profiles_complete = spark.sql("""
        select p.*,
               g.friends_profile
        from profiles p
        left join (
            select first_user_id user_id,
                   collect_list(struct(p.user_id, p.chks)) friends_profile
            from socialgraph g
            inner join profiles p on g.second_user_id = p.id
            group by 1
        ) g on g.user_id = p.user_id
    """)
    profiles_complete.write.format("orc").mode("overwrite").save(
        "data/profiles_complete")
    return profiles_complete