def agg(self, stats):
    grouping_schema = StructType([
        field
        for col in self.grouping_cols
        for field in col.find_fields_in_schema(self.jdf.bound_schema)
    ])

    aggregated_stats = self.jdf.aggregate(
        GroupedStats(self.grouping_cols,
                     stats,
                     pivot_col=self.pivot_col,
                     pivot_values=self.pivot_values),
        lambda grouped_stats, row: grouped_stats.merge(
            row, self.jdf.bound_schema
        ),
        lambda grouped_stats_1, grouped_stats_2: grouped_stats_1.mergeStats(
            grouped_stats_2, self.jdf.bound_schema
        )
    )

    data = []
    all_stats = self.add_subtotals(aggregated_stats)
    for group_key in all_stats.group_keys:
        key = [
            (str(grouping_col), None if value is GROUPED else value)
            for grouping_col, value in zip(self.grouping_cols, group_key)
        ]
        grouping = tuple(value is GROUPED for value in group_key)
        key_as_row = row_from_keyed_values(key).set_grouping(grouping)
        data.append(row_from_keyed_values(
            key + [
                (str(stat),
                 stat.with_pre_evaluation_schema(self.jdf.bound_schema).eval(
                     key_as_row, grouping_schema
                 ))
                for pivot_value in all_stats.pivot_values
                for stat in get_pivoted_stats(
                    all_stats.groups[group_key][pivot_value], pivot_value
                )
            ]
        ))

    if self.pivot_col is not None:
        if len(stats) == 1:
            new_schema = StructType(grouping_schema.fields + [
                StructField(str(pivot_value), DataType(), True)
                for pivot_value in self.pivot_values
            ])
        else:
            new_schema = StructType(grouping_schema.fields + [
                StructField("{0}_{1}".format(pivot_value, stat), DataType(), True)
                for pivot_value in self.pivot_values
                for stat in stats
            ])
    else:
        new_schema = StructType(grouping_schema.fields + [
            StructField(str(stat), DataType(), True) for stat in stats
        ])

    # noinspection PyProtectedMember
    return self.jdf._with_rdd(
        self.jdf._sc.parallelize(data), schema=new_schema
    )

def read(self):
    sc = self.spark._sc
    paths = self.paths

    partitions, partition_schema = resolve_partitions(paths)

    rdd_filenames = sc.parallelize(sorted(partitions.keys()), len(partitions))
    rdd = rdd_filenames.flatMap(partial(
        parse_json_file,
        partitions,
        partition_schema,
        self.schema,
        self.options
    ))

    inferred_schema = infer_schema_from_rdd(rdd)
    schema = self.schema if self.schema is not None else inferred_schema
    schema_fields = {field.name: field for field in schema.fields}

    # Field order is defined by the fields in the record, not by the given schema.
    # Field type is defined by the given schema, or inferred when absent.
    full_schema = StructType(fields=[
        schema_fields.get(field.name, field)
        for field in inferred_schema.fields
    ])

    cast_row = get_struct_caster(inferred_schema, full_schema, options=self.options)
    casted_rdd = rdd.map(cast_row)
    casted_rdd._name = paths

    return DataFrameInternal(sc, casted_rdd, schema=full_schema)

def read(self):
    sc = self.spark._sc
    paths = self.paths

    partitions, partition_schema = resolve_partitions(paths)

    rdd_filenames = sc.parallelize(sorted(partitions.keys()), len(partitions))
    rdd = rdd_filenames.flatMap(partial(
        parse_text_file,
        partitions,
        partition_schema,
        self.schema,
        self.options
    ))

    if partition_schema:
        partitions_fields = partition_schema.fields
        full_schema = StructType(self.schema.fields + partitions_fields)
    else:
        full_schema = self.schema

    rdd._name = paths

    return DataFrameInternal(sc, rdd, schema=full_schema)

def merge_schemas(left_schema, right_schema, how, on=None):
    if on is None:
        on = []

    left_on_fields, right_on_fields = get_on_fields(left_schema, right_schema, on)
    other_left_fields = [
        field for field in left_schema.fields if field not in left_on_fields
    ]
    other_right_fields = [
        field for field in right_schema.fields if field not in right_on_fields
    ]

    if how in (INNER_JOIN, CROSS_JOIN, LEFT_JOIN, LEFT_ANTI_JOIN, LEFT_SEMI_JOIN):
        on_fields = left_on_fields
    elif how == RIGHT_JOIN:
        on_fields = right_on_fields
    elif how == FULL_JOIN:
        # A full outer join can leave join keys unmatched on either side,
        # so the merged key fields must be nullable.
        on_fields = [
            StructField(field.name, field.dataType, nullable=True)
            for field in left_on_fields
        ]
    else:
        raise IllegalArgumentException("Invalid how argument in join: {0}".format(how))

    return StructType(fields=on_fields + other_left_fields + other_right_fields)

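# Hedged usage sketch (illustrative only, not part of the library): merging two
# schemas for an inner join on "id". The schemas below are hypothetical, and the
# sketch assumes get_on_fields matches join columns by name and that the
# INNER_JOIN constant defined in this module is the expected `how` value.
def _example_merge_schemas():
    left = StructType([
        StructField("id", LongType()),
        StructField("name", StringType()),
    ])
    right = StructType([
        StructField("id", LongType()),
        StructField("age", LongType()),
    ])
    merged = merge_schemas(left, right, how=INNER_JOIN, on=["id"])
    # Expected ordering: join keys first, then remaining left fields, then right fields.
    print([field.name for field in merged.fields])  # ['id', 'name', 'age']
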
def test_session_create_data_frame_from_list_with_schema(self):
    schema = StructType(
        [StructField("map", MapType(StringType(), IntegerType()), True)]
    )
    df = self.spark.createDataFrame([({'a': 1},)], schema=schema)

    self.assertEqual(df.count(), 1)
    self.assertListEqual(df.collect(), [Row(map={'a': 1})])
    self.assertEqual(df.schema, schema)

def test_session_create_data_frame_from_pandas_data_frame(self):
    try:
        # Pandas is an optional dependency
        # pylint: disable=import-outside-toplevel
        import pandas as pd
    except ImportError as e:
        raise ImportError("pandas is not importable") from e

    pdf = pd.DataFrame([(1, "one"), (2, "two"), (3, "three")])
    df = self.spark.createDataFrame(pdf)

    self.assertEqual(df.count(), 3)
    self.assertListEqual(df.collect(), [
        Row(**{"0": 1, "1": 'one'}),
        Row(**{"0": 2, "1": 'two'}),
        Row(**{"0": 3, "1": 'three'})
    ])
    self.assertEqual(
        df.schema,
        StructType([
            StructField("0", LongType(), True),
            StructField("1", StringType(), True)
        ])
    )

def test_csv_read_with_given_schema(self):
    schema = StructType([
        StructField("permalink", StringType()),
        StructField("company", StringType()),
        StructField("numEmps", IntegerType()),
        StructField("category", StringType()),
        StructField("city", StringType()),
        StructField("state", StringType()),
        StructField("fundedDate", DateType()),
        StructField("raisedAmt", IntegerType()),
        StructField("raisedCurrency", StringType()),
        StructField("round", StringType())
    ])
    df = spark.read.schema(schema).csv(
        os.path.join(os.path.dirname(os.path.realpath(__file__)), "data/fundings/"),
        header=True
    )
    self.assertEqual(
        [Row(**r.asDict()) for r in df.collect()],
        [
            Row(permalink='mycityfaces', company='MyCityFaces', numEmps=7,
                category='web', city='Scottsdale', state='AZ',
                fundedDate=datetime.date(2008, 1, 1), raisedAmt=50000,
                raisedCurrency='USD', round='seed'),
            Row(permalink='flypaper', company='Flypaper', numEmps=None,
                category='web', city='Phoenix', state='AZ',
                fundedDate=datetime.date(2008, 2, 1), raisedAmt=3000000,
                raisedCurrency='USD', round='a'),
            Row(permalink='chosenlist-com', company='ChosenList.com', numEmps=5,
                category='web', city='Scottsdale', state='AZ',
                fundedDate=datetime.date(2008, 1, 25), raisedAmt=233750,
                raisedCurrency='USD', round='angel'),
            Row(permalink='digg', company='Digg', numEmps=60,
                category='web', city='San Francisco', state='CA',
                fundedDate=datetime.date(2006, 12, 1), raisedAmt=8500000,
                raisedCurrency='USD', round='b')
        ]
    )

def test_cast_to_struct(self):
    self.assertEqual(
        cast_to_struct(
            Row(character='Alice', day='28', month='8', year='2019'),
            from_type=StructType(fields=[
                StructField("character", StringType()),
                StructField("day", StringType()),
                StructField("month", StringType()),
                StructField("year", StringType()),
            ]),
            to_type=StructType(fields=[
                StructField("character", StringType()),
                StructField("day", IntegerType()),
                StructField("month", IntegerType()),
                StructField("year", IntegerType()),
            ]),
            options=BASE_OPTIONS
        ),
        Row(character='Alice', day=28, month=8, year=2019),
    )

def read(self):
    sc = self.spark._sc
    paths = self.paths

    partitions, partition_schema = resolve_partitions(paths)

    rdd_filenames = sc.parallelize(sorted(partitions.keys()), len(partitions))
    rdd = rdd_filenames.flatMap(partial(
        parse_csv_file,
        partitions,
        partition_schema,
        self.schema,
        self.options
    ))

    if self.schema is not None:
        schema = self.schema
    elif self.options.inferSchema:
        fields = rdd.take(1)[0].__fields__
        schema = guess_schema_from_strings(fields, rdd.collect(), options=self.options)
    else:
        schema = infer_schema_from_rdd(rdd)

    schema_with_string = StructType(fields=[
        StructField(field.name, StringType()) for field in schema.fields
    ])

    if partition_schema:
        partitions_fields = partition_schema.fields
        full_schema = StructType(
            schema.fields[:-len(partitions_fields)] + partitions_fields
        )
    else:
        full_schema = schema

    cast_row = get_caster(
        from_type=schema_with_string, to_type=full_schema, options=self.options
    )
    casted_rdd = rdd.map(cast_row)
    casted_rdd._name = paths

    return DataFrameInternal(sc, casted_rdd, schema=full_schema)

def toDF(self, new_names):
    def mapper(row):
        keyed_values = [
            (new_name, row[old])
            for new_name, old in zip(new_names, row.__fields__)
        ]
        return row_from_keyed_values(keyed_values)

    new_schema = StructType([
        StructField(new_name, field.dataType, field.nullable)
        for new_name, field in zip(new_names, self.bound_schema.fields)
    ])
    return self._with_rdd(self._rdd.map(mapper), schema=new_schema)

def withColumnRenamed(self, existing, new):
    def mapper(row):
        keyed_values = [
            (new, row[col]) if col == existing else (col, row[col])
            for col in row.__fields__
        ]
        return row_from_keyed_values(keyed_values)

    new_schema = StructType([
        field if field.name != existing
        else StructField(new, field.dataType, field.nullable)
        for field in self.bound_schema.fields
    ])
    return self._with_rdd(self._rdd.map(mapper), schema=new_schema)

def crosstab(self, df, col1, col2):
    table_name = "_".join((col1, col2))
    counts = df.groupBy(col1, col2).agg(count("*")).take(1000000)
    if len(counts) == 1000000:
        warnings.warn(
            "The maximum limit of 1e6 pairs have been collected, "
            "which may not be all of the pairs. Please try reducing "
            "the amount of distinct items in your columns."
        )

    def clean_element(element):
        return str(element) if element is not None else "null"

    # Map each distinct value of col2 to a stable column index.
    distinct_col2 = {
        value: index
        for index, value in enumerate(sorted(set(
            clean_element(row[col2]) for row in counts
        )))
    }
    column_size = len(distinct_col2)
    if column_size > 10000:
        raise ValueError(
            "The number of distinct values for {0} can't exceed 1e4. "
            "Currently {1}".format(col2, column_size)
        )

    def create_counts_row(col1_item, rows):
        counts_row = [None] * (column_size + 1)
        for row in rows:
            column_index = distinct_col2[clean_element(row[col2])]
            counts_row[column_index + 1] = int(row[2])
        # the value of col1 is the first value, the rest are the counts
        counts_row[0] = clean_element(col1_item)
        return Row(*counts_row)

    groups = {}
    for row in counts:
        groups.setdefault(row[col1], []).append(row)
    table = [create_counts_row(col1_item, rows) for col1_item, rows in groups.items()]

    # Back ticks can't exist in DataFrame column names,
    # therefore drop them. To be able to accept special keywords and `.`,
    # wrap the column names in ``.
    def clean_column_name(name):
        return name.replace("`", "")

    # The column names in distinct_col2 are not ordered by their index.
    # We need to explicitly sort by the column index and assign the column names.
    header_names = [
        StructField(clean_column_name(str(name)), LongType(), True)
        for name, _ in sorted(distinct_col2.items(), key=lambda item: item[1])
    ]
    schema = StructType([StructField(table_name, StringType(), True)] + header_names)

    return schema, table

def test_session_create_data_frame_from_list(self):
    df = self.spark.createDataFrame([
        (1, "one"),
        (2, "two"),
        (3, "three"),
    ])

    self.assertEqual(df.count(), 3)
    self.assertListEqual(
        df.collect(),
        [Row(_1=1, _2='one'), Row(_1=2, _2='two'), Row(_1=3, _2='three')]
    )
    self.assertEqual(
        df.schema,
        StructType([
            StructField("_1", LongType(), True),
            StructField("_2", StringType(), True)
        ])
    )

def test_session_create_data_frame_from_list_with_col_names(self):
    df = self.spark.createDataFrame(
        [(0.0, [1.0, 0.8]), (1.0, [0.0, 0.0]), (2.0, [0.5, 0.5])],
        ["label", "features"]
    )

    self.assertEqual(df.count(), 3)
    self.assertListEqual(df.collect(), [
        row_from_keyed_values([("label", 0.0), ("features", [1.0, 0.8])]),
        row_from_keyed_values([("label", 1.0), ("features", [0.0, 0.0])]),
        row_from_keyed_values([("label", 2.0), ("features", [0.5, 0.5])]),
    ])
    self.assertEqual(
        df.schema,
        StructType([
            StructField("label", DoubleType(), True),
            StructField("features", ArrayType(DoubleType(), True), True)
        ])
    )

def guess_schema_from_strings(schema_fields, data, options):
    field_values = [
        (field, [row[field] for row in data]) for field in schema_fields
    ]

    field_types_and_values = [
        (field, guess_type_from_values_as_string(values, options))
        for field, values in field_values
    ]

    schema = StructType(fields=[
        StructField(field, field_type)
        for field, field_type in field_types_and_values
    ])

    return schema

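# Hedged usage sketch (illustrative only): how the CSV reader above applies
# guess_schema_from_strings when inferSchema is enabled. The field names, row
# values and the availability of BASE_OPTIONS here are assumptions made for
# the example, not part of the library.
def _example_guess_schema_from_strings():
    rows = [
        row_from_keyed_values([("id", "1"), ("price", "1.5"), ("city", "Paris")]),
        row_from_keyed_values([("id", "2"), ("price", "0.75"), ("city", "Lyon")]),
    ]
    schema = guess_schema_from_strings(["id", "price", "city"], rows, options=BASE_OPTIONS)
    # Each field keeps its name; its type is whatever
    # guess_type_from_values_as_string derives from the string values.
    print([(field.name, field.dataType) for field in schema.fields])
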
def test_column_stat_helper():
    """
    Expected quantile values come from use of
    org.apache.spark.sql.catalyst.util.QuantileSummaries
    """
    schema = StructType([StructField("value", IntegerType())])
    helper = ColumnStatHelper(col("value"))
    for i in range(1, 100001):
        helper.merge(Row(value=i), schema)
    helper.finalize()

    assert helper.count == 100000
    assert helper.min == 1
    assert helper.max == 100000
    assert helper.mean == 50000.5
    assert helper.stddev == 28867.65779668774  # sample standard deviation
    assert helper.get_quantile(0) == 1
    assert helper.get_quantile(0.25) == 24998
    assert helper.get_quantile(0.5) == 50000
    assert helper.get_quantile(0.75) == 74993
    assert helper.get_quantile(1) == 100000

def drop(self, cols):
    positions_to_drop = []
    for col in cols:
        if isinstance(col, str):
            if col == "*":
                continue
            col = parse(col)
        try:
            positions_to_drop.append(col.find_position_in_schema(self.bound_schema))
        except ValueError:
            pass

    new_schema = StructType([
        field
        for i, field in enumerate(self.bound_schema.fields)
        if i not in positions_to_drop
    ])

    return self._with_rdd(
        self.rdd().map(lambda row: row_from_keyed_values([
            (field, row[i])
            for i, field in enumerate(row.__fields__)
            if i not in positions_to_drop
        ])),
        new_schema
    )

def test_cast_row_to_string(self):
    self.assertEqual(
        cast_to_string(
            Row(
                a=collections.OrderedDict([("value", None), ("b", {"c": 7})]),
                b=None,
                c=True,
                d=5.2
            ),
            StructType([
                StructField(
                    "a",
                    MapType(StringType(), MapType(StringType(), LongType(), True), True),
                    True
                ),
                StructField("b", LongType(), True),
                StructField("c", BooleanType(), True),
                StructField("d", DoubleType(), True)
            ]),
            options=BASE_OPTIONS
        ),
        "[[value ->, b -> [c -> 7]],, true, 5.2]"
    )

def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=True):
    SparkSession._activeSession = self
    if isinstance(data, DataFrame):
        raise TypeError("data is already a DataFrame")

    if isinstance(schema, basestring):
        schema = StructType.fromDDL(schema)
    elif isinstance(schema, (list, tuple)):
        # Must re-encode any unicode strings to be consistent with StructField names
        schema = [
            x.encode('utf-8') if not isinstance(x, str) else x for x in schema
        ]

    try:
        # pandas is an optional dependency
        # pylint: disable=import-outside-toplevel
        import pandas
        has_pandas = True
    except ImportError:
        has_pandas = False
    if has_pandas and isinstance(data, pandas.DataFrame):
        data, schema = self.parse_pandas_dataframe(data, schema)

    no_check = lambda _: True
    if isinstance(schema, StructType):
        verify_func = _make_type_verifier(schema) if verifySchema else no_check

        def prepare(obj):
            verify_func(obj)
            return obj
    elif isinstance(schema, DataType):
        dataType = schema
        schema = StructType().add("value", schema)

        verify_func = _make_type_verifier(
            dataType, name="field value"
        ) if verifySchema else no_check

        def prepare(obj):
            verify_func(obj)
            return tuple([obj])
    else:
        def prepare(obj):
            return obj

    if isinstance(data, RDD):
        rdd, schema = self._createFromRDD(data.map(prepare), schema, samplingRatio)
    else:
        rdd, schema = self._createFromLocal(map(prepare, data), schema)

    cols = [
        col_type.name if hasattr(col_type, "name") else "_" + str(i)
        for i, col_type in enumerate(schema)
    ]
    df = DataFrame(DataFrameInternal(self._sc, rdd, cols, True, schema), self._wrapped)
    return df

def get_schema_from_cols(cols, current_schema):
    new_schema = StructType(fields=[
        field
        for col in cols
        for field in col.find_fields_in_schema(current_schema)
    ])
    return new_schema

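# Hedged usage sketch (illustrative only): projecting a schema down to the fields
# referenced by a list of columns. The schema is hypothetical, and the sketch
# assumes col() resolves field references against the schema the same way it does
# in the tests above.
def _example_get_schema_from_cols():
    current_schema = StructType([
        StructField("label", DoubleType(), True),
        StructField("features", ArrayType(DoubleType(), True), True),
    ])
    selected = get_schema_from_cols([col("label")], current_schema)
    # Only the "label" field should remain in the projected schema.
    print([field.name for field in selected.fields])
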
def test_csv_read_without_schema(self):
    df = spark.read.csv(
        os.path.join(os.path.dirname(os.path.realpath(__file__)), "data/fundings/"),
        header=True
    )

    self.assertEqual(df.count(), 4)
    self.assertEqual(
        df.schema,
        StructType([
            StructField("permalink", StringType()),
            StructField("company", StringType()),
            StructField("numEmps", StringType()),
            StructField("category", StringType()),
            StructField("city", StringType()),
            StructField("state", StringType()),
            StructField("fundedDate", StringType()),
            StructField("raisedAmt", StringType()),
            StructField("raisedCurrency", StringType()),
            StructField("round", StringType())
        ])
    )
    self.assertListEqual(
        [Row(**r.asDict()) for r in df.collect()],
        [
            Row(permalink='mycityfaces', company='MyCityFaces', numEmps='7',
                category='web', city='Scottsdale', state='AZ',
                fundedDate='2008-01-01', raisedAmt='50000',
                raisedCurrency='USD', round='seed'),
            Row(permalink='flypaper', company='Flypaper', numEmps=None,
                category='web', city='Phoenix', state='AZ',
                fundedDate='2008-02-01', raisedAmt='3000000',
                raisedCurrency='USD', round='a'),
            Row(permalink='chosenlist-com', company='ChosenList.com', numEmps='5',
                category='web', city='Scottsdale', state='AZ',
                fundedDate='2008-01-25', raisedAmt='233750',
                raisedCurrency='USD', round='angel'),
            Row(permalink='digg', company='Digg', numEmps='60',
                category='web', city='San Francisco', state='CA',
                fundedDate='2006-12-01', raisedAmt='8500000',
                raisedCurrency='USD', round='b')
        ]
    )

def __init__(self, spark, paths, schema, options):
    self.spark = spark
    self.paths = paths
    self.schema = schema or StructType([StructField("value", StringType())])
    self.options = Options(self.default_options, options)

def get_summary_schema(self, exprs):
    return StructType(
        [StructField("summary", StringType(), True)]
        + [
            StructField(field.name, StringType(), True)
            for field in get_schema_from_cols(exprs, self.bound_schema).fields
        ]
    )