def test_cast_row_to_string(self):
    self.assertEqual(
        cast_to_string(
            Row(
                a=collections.OrderedDict([('value', None), ('b', {'c': 7})]),
                b=None,
                c=True,
                d=5.2,
            ),
            StructType([
                StructField(
                    'a',
                    MapType(
                        StringType(),
                        MapType(StringType(), LongType(), True),
                        True,
                    ),
                    True,
                ),
                StructField('b', LongType(), True),
                StructField('c', BooleanType(), True),
                StructField('d', DoubleType(), True),
            ]),
            options=BASE_OPTIONS,
        ),
        '[[value ->, b -> [c -> 7]],, true, 5.2]',
    )
def agg(self, stats):
    grouping_schema = StructType([
        field
        for col in self.grouping_cols
        for field in col.find_fields_in_schema(self.jdf.bound_schema)
    ])

    aggregated_stats = self.jdf.aggregate(
        GroupedStats(
            self.grouping_cols,
            stats,
            pivot_col=self.pivot_col,
            pivot_values=self.pivot_values,
        ),
        lambda grouped_stats, row: grouped_stats.merge(
            row, self.jdf.bound_schema),
        lambda grouped_stats_1, grouped_stats_2: grouped_stats_1.mergeStats(
            grouped_stats_2, self.jdf.bound_schema),
    )

    data = []
    all_stats = self.add_subtotals(aggregated_stats)
    for group_key in all_stats.group_keys:
        key = [(str(key), None if value is GROUPED else value)
               for key, value in zip(self.grouping_cols, group_key)]
        grouping = tuple(value is GROUPED for value in group_key)
        key_as_row = row_from_keyed_values(key).set_grouping(grouping)
        data.append(row_from_keyed_values(
            key
            + [(str(stat),
                stat.with_pre_evaluation_schema(self.jdf.bound_schema).eval(
                    key_as_row, grouping_schema))
               for pivot_value in all_stats.pivot_values
               for stat in get_pivoted_stats(
                   all_stats.groups[group_key][pivot_value], pivot_value)]
        ))

    if self.pivot_col is not None:
        if len(stats) == 1:
            new_schema = StructType(grouping_schema.fields + [
                StructField(str(pivot_value), DataType(), True)
                for pivot_value in self.pivot_values
            ])
        else:
            new_schema = StructType(grouping_schema.fields + [
                StructField('{0}_{1}'.format(pivot_value, stat), DataType(), True)
                for pivot_value in self.pivot_values
                for stat in stats
            ])
    else:
        new_schema = StructType(grouping_schema.fields + [
            StructField(str(stat), DataType(), True) for stat in stats
        ])

    # noinspection PyProtectedMember
    return self.jdf._with_rdd(self.jdf._sc.parallelize(data), schema=new_schema)
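# Hedged usage sketch for the pivoted branch of agg() above. It assumes the usual
# PySpark-style entry point df.groupBy(...).pivot(...).agg(...) is wired to this
# implementation and that a `spark` session and the `count` aggregation are
# available; the data and column names are made up for illustration.
df = spark.createDataFrame(
    [('2019', 'dollars', 10), ('2019', 'euros', 20), ('2020', 'dollars', 30)],
    ['year', 'currency', 'amount'],
)
pivoted = df.groupBy('year').pivot('currency').agg(count('amount'))
# With a single aggregation, the result schema is the grouping fields followed by
# one column per pivot value, e.g. ['year', 'dollars', 'euros'].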
def test_session_create_data_frame_from_list(self):
    df = self.spark.createDataFrame([(1, 'one'), (2, 'two'), (3, 'three')])
    self.assertEqual(df.count(), 3)
    self.assertListEqual(
        df.collect(),
        [Row(_1=1, _2='one'), Row(_1=2, _2='two'), Row(_1=3, _2='three')],
    )
    self.assertEqual(
        df.schema,
        StructType([
            StructField('_1', LongType(), True),
            StructField('_2', StringType(), True),
        ]),
    )
def crosstab(self, df, col1, col2):
    table_name = '_'.join((col1, col2))
    counts = df.groupBy(col1, col2).agg(count('*')).take(int(1e6))
    if len(counts) == 1e6:
        warnings.warn(
            'The maximum limit of 1e6 pairs have been collected, '
            'which may not be all of the pairs. Please try reducing '
            'the amount of distinct items in your columns.')

    def clean_element(element):
        return str(element) if element is not None else 'null'

    distinct_col2 = counts.map(
        lambda row: clean_element(row[col2])
    ).distinct().sorted().zipWithIndex().toMap()
    column_size = len(distinct_col2)
    if column_size > 1e4:
        raise ValueError(
            "The number of distinct values for {0} can't exceed 1e4. "
            'Currently {1}'.format(col2, column_size))

    def create_counts_row(col1Item, rows):
        counts_row = [None] * (column_size + 1)

        def parse_row(row):
            column_index = distinct_col2[clean_element(row[1])]
            counts_row[int(column_index + 1)] = int(row[2])

        rows.foreach(parse_row)
        # the value of col1 is the first value, the rest are the counts
        counts_row[0] = clean_element(col1Item)
        return Row(counts_row)

    table = counts.groupBy(lambda r: r[col1]).map(create_counts_row).toSeq

    # Back ticks can't exist in DataFrame column names,
    # therefore drop them. To be able to accept special keywords and `.`,
    # wrap the column names in ``.
    def clean_column_name(name):
        return name.replace('`', '')

    # In the map, the column names (first element) are not ordered by the
    # index (second element). We need to explicitly sort by the column index
    # and assign the column names.
    header_names = distinct_col2.toSeq.sortBy(lambda r: r[1]).map(
        lambda r: StructField(clean_column_name(str(r[0])), LongType()))
    schema = StructType([StructField(table_name, StringType())] + header_names)
    return schema, table
def merge_schemas(left_schema, right_schema, how, on=None):
    if on is None:
        on = []

    left_on_fields, right_on_fields = get_on_fields(left_schema, right_schema, on)
    other_left_fields = [
        field for field in left_schema.fields if field not in left_on_fields
    ]
    other_right_fields = [
        field for field in right_schema.fields if field not in right_on_fields
    ]

    if how in (INNER_JOIN, CROSS_JOIN, LEFT_JOIN, LEFT_ANTI_JOIN, LEFT_SEMI_JOIN):
        on_fields = left_on_fields
    elif how == RIGHT_JOIN:
        on_fields = right_on_fields
    elif how == FULL_JOIN:
        on_fields = [
            StructField(field.name, field.dataType, nullable=True)
            for field in left_on_fields
        ]
    else:
        raise IllegalArgumentException(
            'Invalid how argument in join: {0}'.format(how))

    return StructType(fields=on_fields + other_left_fields + other_right_fields)
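# Hedged usage sketch for merge_schemas (not part of the original source): the type
# classes and the INNER_JOIN/FULL_JOIN constants are assumed to be importable from
# this package. It illustrates the branch above that relaxes nullability of the
# "on" fields for full outer joins.
left = StructType([
    StructField('id', LongType(), False),
    StructField('name', StringType(), True),
])
right = StructType([
    StructField('id', LongType(), False),
    StructField('score', DoubleType(), True),
])
inner_schema = merge_schemas(left, right, INNER_JOIN, on=['id'])
full_schema = merge_schemas(left, right, FULL_JOIN, on=['id'])
# inner_schema keeps the left-hand 'id' definition (nullable=False), while
# full_schema marks 'id' nullable because either side may lack a matching key.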
def output_fields(self, schema):
    if isinstance(self.expr, Expression):
        return self.expr.output_fields(schema)
    return [
        StructField(
            name=self.col_name,
            dataType=self.data_type,
            nullable=self.is_nullable,
        )
    ]
def test_session_create_data_frame_from_pandas_data_frame(self):
    try:
        # Pandas is an optional dependency
        # pylint: disable=import-outside-toplevel
        import pandas as pd
    except ImportError:
        raise Exception('pandas is not importable')

    pdf = pd.DataFrame([(1, 'one'), (2, 'two'), (3, 'three')])
    df = self.spark.createDataFrame(pdf)

    self.assertEqual(df.count(), 3)
    self.assertListEqual(
        df.collect(),
        [
            Row(**{'0': 1, '1': 'one'}),
            Row(**{'0': 2, '1': 'two'}),
            Row(**{'0': 3, '1': 'three'}),
        ],
    )
    self.assertEqual(
        df.schema,
        StructType([
            StructField('0', LongType(), True),
            StructField('1', StringType(), True),
        ]),
    )
def withColumnRenamed(self, existing, new):
    def mapper(row):
        keyed_values = [
            (new, row[col]) if col == existing else (col, row[col])
            for col in row.__fields__
        ]
        return row_from_keyed_values(keyed_values)

    new_schema = StructType([
        field if field.name != existing
        else StructField(new, field.dataType, field.nullable)
        for field in self.bound_schema.fields
    ])

    return self._with_rdd(self._rdd.map(mapper), schema=new_schema)
def toDF(self, new_names):
    def mapper(row):
        keyed_values = [
            (new_name, row[old])
            for new_name, old in zip(new_names, row.__fields__)
        ]
        return row_from_keyed_values(keyed_values)

    new_schema = StructType([
        StructField(new_name, field.dataType, field.nullable)
        for new_name, field in zip(new_names, self.bound_schema.fields)
    ])

    return self._with_rdd(self._rdd.map(mapper), schema=new_schema)
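# Hedged usage sketch (assumed public wrapper, mirroring PySpark): DataFrame.toDF is
# expected to forward the new names to the internal toDF above, renaming columns
# positionally while preserving each field's type and nullability.
df = spark.createDataFrame([(1, 'one'), (2, 'two')])   # default columns _1, _2
renamed = df.toDF('id', 'label')
# renamed.schema -> StructType([StructField('id', LongType(), True),
#                               StructField('label', StringType(), True)])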
def test_session_create_data_frame_from_list_with_col_names(self):
    df = self.spark.createDataFrame(
        [(0.0, [1.0, 0.8]), (1.0, [0.0, 0.0]), (2.0, [0.5, 0.5])],
        ['label', 'features'],
    )
    self.assertEqual(df.count(), 3)
    self.assertListEqual(
        df.collect(),
        [
            row_from_keyed_values([('label', 0.0), ('features', [1.0, 0.8])]),
            row_from_keyed_values([('label', 1.0), ('features', [0.0, 0.0])]),
            row_from_keyed_values([('label', 2.0), ('features', [0.5, 0.5])]),
        ],
    )
    self.assertEqual(
        df.schema,
        StructType([
            StructField('label', DoubleType(), True),
            StructField('features', ArrayType(DoubleType(), True), True),
        ]),
    )
def guess_schema_from_strings(schema_fields, data, options):
    field_values = [
        (field, [row[field] for row in data]) for field in schema_fields
    ]

    field_types_and_values = [
        (field, guess_type_from_values_as_string(values, options))
        for field, values in field_values
    ]

    schema = StructType(fields=[
        StructField(field, field_type)
        for field, field_type in field_types_and_values
    ])

    return schema
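# Hedged sketch of guess_schema_from_strings on rows of raw CSV strings; Row and
# BASE_OPTIONS are assumed to be the same helpers used in the tests in this section.
rows = [
    Row(age='42', name='Alice'),
    Row(age='7', name='Bob'),
]
schema = guess_schema_from_strings(['age', 'name'], rows, options=BASE_OPTIONS)
# Consistent with the inferSchema CSV test below, numeric strings are expected to
# come back as IntegerType and everything else as StringType:
# StructType([StructField('age', IntegerType(), True),
#             StructField('name', StringType(), True)])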
def test_cast_to_struct(self):
    self.assertEqual(
        cast_to_struct(
            Row(character='Alice', day='28', month='8', year='2019'),
            from_type=StructType(fields=[
                StructField('character', StringType()),
                StructField('day', StringType()),
                StructField('month', StringType()),
                StructField('year', StringType()),
            ]),
            to_type=StructType(fields=[
                StructField('character', StringType()),
                StructField('day', IntegerType()),
                StructField('month', IntegerType()),
                StructField('year', IntegerType()),
            ]),
            options=BASE_OPTIONS,
        ),
        Row(character='Alice', day=28, month=8, year=2019),
    )
def test_column_stat_helper():
    """
    Expected quantile values come from use of
    org.apache.spark.sql.catalyst.util.QuantileSummaries
    """
    schema = StructType([StructField('value', IntegerType())])
    helper = ColumnStatHelper(col('value'))
    for i in range(1, 100001):
        helper.merge(Row(value=i), schema)
    helper.finalize()
    assert helper.count == 100000
    assert helper.min == 1
    assert helper.max == 100000
    assert helper.mean == 50000.5
    assert helper.stddev == 28867.65779668774  # sample standard deviation
    assert helper.get_quantile(0) == 1
    assert helper.get_quantile(0.25) == 24998
    assert helper.get_quantile(0.5) == 50000
    assert helper.get_quantile(0.75) == 74993
    assert helper.get_quantile(1) == 100000
def read(self):
    sc = self.spark._sc
    paths = self.paths

    partitions, partition_schema = resolve_partitions(paths)

    rdd_filenames = sc.parallelize(sorted(partitions.keys()), len(partitions))
    rdd = rdd_filenames.flatMap(partial(
        parse_csv_file,
        partitions,
        partition_schema,
        self.schema,
        self.options,
    ))

    if self.schema is not None:
        schema = self.schema
    elif self.options.inferSchema:
        fields = rdd.take(1)[0].__fields__
        schema = guess_schema_from_strings(fields, rdd.collect(), options=self.options)
    else:
        schema = infer_schema_from_rdd(rdd)

    schema_with_string = StructType(fields=[
        StructField(field.name, StringType()) for field in schema.fields
    ])

    if partition_schema:
        partitions_fields = partition_schema.fields
        full_schema = StructType(
            schema.fields[:-len(partitions_fields)] + partitions_fields)
    else:
        full_schema = schema

    cast_row = get_caster(
        from_type=schema_with_string, to_type=full_schema, options=self.options)
    casted_rdd = rdd.map(cast_row)
    casted_rdd._name = paths

    return DataFrameInternal(sc, casted_rdd, schema=full_schema)
def get_summary_schema(self, exprs):
    return StructType(
        [StructField('summary', StringType(), True)]
        + [
            StructField(field.name, StringType(), True)
            for field in get_schema_from_cols(exprs, self.bound_schema).fields
        ]
    )
def __init__(self, spark, paths, schema, options):
    self.spark = spark
    self.paths = paths
    self.schema = schema or StructType([StructField('value', StringType())])
    self.options = Options(self.default_options, options)
def output_fields(self, schema):
    return [
        StructField(
            name=str(self),
            dataType=self.data_type,
            nullable=self.is_nullable,
        )
    ]
def test_csv_read_with_inferred_schema(self):
    df = spark.read.option('inferSchema', True).csv(
        os.path.join(os.path.dirname(os.path.realpath(__file__)), 'data/fundings/'),
        header=True,
    )
    self.assertEqual(df.count(), 4)
    self.assertEqual(
        df.schema,
        StructType([
            StructField('permalink', StringType()),
            StructField('company', StringType()),
            StructField('numEmps', IntegerType()),
            StructField('category', StringType()),
            StructField('city', StringType()),
            StructField('state', StringType()),
            StructField('fundedDate', TimestampType()),
            StructField('raisedAmt', IntegerType()),
            StructField('raisedCurrency', StringType()),
            StructField('round', StringType()),
        ]),
    )
    self.assertEqual(
        [Row(**r.asDict()) for r in df.collect()],
        [
            Row(
                permalink='mycityfaces',
                company='MyCityFaces',
                numEmps=7,
                category='web',
                city='Scottsdale',
                state='AZ',
                fundedDate=datetime.datetime(2008, 1, 1, 0, 0),
                raisedAmt=50000,
                raisedCurrency='USD',
                round='seed',
            ),
            Row(
                permalink='flypaper',
                company='Flypaper',
                numEmps=None,
                category='web',
                city='Phoenix',
                state='AZ',
                fundedDate=datetime.datetime(2008, 2, 1, 0, 0),
                raisedAmt=3000000,
                raisedCurrency='USD',
                round='a',
            ),
            Row(
                permalink='chosenlist-com',
                company='ChosenList.com',
                numEmps=5,
                category='web',
                city='Scottsdale',
                state='AZ',
                fundedDate=datetime.datetime(2008, 1, 25, 0, 0),
                raisedAmt=233750,
                raisedCurrency='USD',
                round='angel',
            ),
            Row(
                permalink='digg',
                company='Digg',
                numEmps=60,
                category='web',
                city='San Francisco',
                state='CA',
                fundedDate=datetime.datetime(2006, 12, 1, 0, 0),
                raisedAmt=8500000,
                raisedCurrency='USD',
                round='b',
            ),
        ],
    )
def test_csv_read_without_schema(self):
    df = spark.read.csv(
        os.path.join(os.path.dirname(os.path.realpath(__file__)), 'data/fundings/'),
        header=True,
    )
    self.assertEqual(df.count(), 4)
    self.assertEqual(
        df.schema,
        StructType([
            StructField('permalink', StringType()),
            StructField('company', StringType()),
            StructField('numEmps', StringType()),
            StructField('category', StringType()),
            StructField('city', StringType()),
            StructField('state', StringType()),
            StructField('fundedDate', StringType()),
            StructField('raisedAmt', StringType()),
            StructField('raisedCurrency', StringType()),
            StructField('round', StringType()),
        ]),
    )
    self.assertListEqual(
        [Row(**r.asDict()) for r in df.collect()],
        [
            Row(
                permalink='mycityfaces',
                company='MyCityFaces',
                numEmps='7',
                category='web',
                city='Scottsdale',
                state='AZ',
                fundedDate='2008-01-01',
                raisedAmt='50000',
                raisedCurrency='USD',
                round='seed',
            ),
            Row(
                permalink='flypaper',
                company='Flypaper',
                numEmps=None,
                category='web',
                city='Phoenix',
                state='AZ',
                fundedDate='2008-02-01',
                raisedAmt='3000000',
                raisedCurrency='USD',
                round='a',
            ),
            Row(
                permalink='chosenlist-com',
                company='ChosenList.com',
                numEmps='5',
                category='web',
                city='Scottsdale',
                state='AZ',
                fundedDate='2008-01-25',
                raisedAmt='233750',
                raisedCurrency='USD',
                round='angel',
            ),
            Row(
                permalink='digg',
                company='Digg',
                numEmps='60',
                category='web',
                city='San Francisco',
                state='CA',
                fundedDate='2006-12-01',
                raisedAmt='8500000',
                raisedCurrency='USD',
                round='b',
            ),
        ],
    )
def output_fields(self, schema):
    return [
        StructField('pos', IntegerType(), False),
        StructField('col', DataType(), False),
    ]
def test_session_create_data_frame_from_list_with_schema(self):
    schema = StructType([
        StructField('map', MapType(StringType(), IntegerType()), True),
    ])
    df = self.spark.createDataFrame([({'a': 1},)], schema=schema)
    self.assertEqual(df.count(), 1)
    self.assertListEqual(df.collect(), [Row(map={'a': 1})])
    self.assertEqual(df.schema, schema)