def test_cast_row_to_string(self):
     self.assertEqual(
         cast_to_string(
             Row(
                 a=collections.OrderedDict([('value', None), ('b', {
                     'c': 7
                 })]),
                 b=None,
                 c=True,
                 d=5.2,
             ),
             StructType([
                 StructField(
                     'a',
                     MapType(
                         StringType(),
                         MapType(StringType(), LongType(), True),
                         True,
                     ),
                     True,
                 ),
                 StructField('b', LongType(), True),
                 StructField('c', BooleanType(), True),
                 StructField('d', DoubleType(), True),
             ]),
             options=BASE_OPTIONS,
         ),
         '[[value ->, b -> [c -> 7]],, true, 5.2]',
     )
    def agg(self, stats):
        grouping_schema = StructType([
            field for col in self.grouping_cols
            for field in col.find_fields_in_schema(self.jdf.bound_schema)
        ])

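        # Fold each row into its group's statistics, then merge partial
        # per-group results together.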
        aggregated_stats = self.jdf.aggregate(
            GroupedStats(
                self.grouping_cols,
                stats,
                pivot_col=self.pivot_col,
                pivot_values=self.pivot_values,
            ),
            lambda grouped_stats, row: grouped_stats.merge(
                row, self.jdf.bound_schema),
            lambda grouped_stats_1, grouped_stats_2: grouped_stats_1.mergeStats(
                grouped_stats_2, self.jdf.bound_schema),
        )

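        # Build one output row per group key returned by add_subtotals;
        # GROUPED marks grouping columns aggregated over in subtotal rows.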
        data = []
        all_stats = self.add_subtotals(aggregated_stats)
        for group_key in all_stats.group_keys:
            key = [(str(key), None if value is GROUPED else value)
                   for key, value in zip(self.grouping_cols, group_key)]
            grouping = tuple(value is GROUPED for value in group_key)

            key_as_row = row_from_keyed_values(key).set_grouping(grouping)
            data.append(
                row_from_keyed_values(key + [
                    (
                        str(stat),
                        stat.with_pre_evaluation_schema(
                            self.jdf.bound_schema).eval(key_as_row, grouping_schema),
                    )
                    for pivot_value in all_stats.pivot_values
                    for stat in get_pivoted_stats(
                        all_stats.groups[group_key][pivot_value], pivot_value)
                ]))

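        # Output schema: grouping columns first, then one column per stat,
        # or one per (pivot value, stat) pair when a pivot column is set.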
        if self.pivot_col is not None:
            if len(stats) == 1:
                new_schema = StructType(grouping_schema.fields + [
                    StructField(str(pivot_value), DataType(), True)
                    for pivot_value in self.pivot_values
                ])
            else:
                new_schema = StructType(grouping_schema.fields + [
                    StructField('{0}_{1}'.format(pivot_value, stat),
                                DataType(), True)
                    for pivot_value in self.pivot_values for stat in stats
                ])
        else:
            new_schema = StructType(
                grouping_schema.fields +
                [StructField(str(stat), DataType(), True) for stat in stats])

        # noinspection PyProtectedMember
        return self.jdf._with_rdd(self.jdf._sc.parallelize(data),
                                  schema=new_schema)
Example #3
 def test_session_create_data_frame_from_list(self):
     df = self.spark.createDataFrame([(1, 'one'), (2, 'two'), (3, 'three')])
     self.assertEqual(df.count(), 3)
     self.assertListEqual(
         df.collect(), [Row(_1=1, _2='one'), Row(_1=2, _2='two'), Row(_1=3, _2='three')],
     )
     self.assertEqual(
         df.schema, StructType([StructField('_1', LongType(), True), StructField('_2', StringType(), True)]),
     )
    def crosstab(self, df, col1, col2):
        table_name = '_'.join((col1, col2))
        counts = df.groupBy(col1, col2).agg(count('*')).take(int(1e6))
        if len(counts) == 1e6:
            warnings.warn(
                'The maximum limit of 1e6 pairs have been collected, '
                'which may not be all of the pairs. Please try reducing '
                'the amount of distinct items in your columns.')

        def clean_element(element):
            return str(element) if element is not None else 'null'

        # Map each distinct (cleaned) col2 value to a column index, in sorted order.
        distinct_col2 = {
            value: index
            for index, value in enumerate(
                sorted(set(clean_element(row[col2]) for row in counts)))
        }
        column_size = len(distinct_col2)
        if column_size > 1e4:
            raise ValueError(
                "The number of distinct values for {0} can't exceed 1e4. "
                'Currently {1}'.format(col2, column_size))

        def create_counts_row(col1_item, rows):
            counts_row = [None] * (column_size + 1)

            for row in rows:
                # each row is (col1 value, col2 value, count)
                column_index = distinct_col2[clean_element(row[1])]
                counts_row[int(column_index + 1)] = int(row[2])

            # the value of col1 is the first value, the rest are the counts
            counts_row[0] = clean_element(col1_item)
            return Row(*counts_row)

        # Group the collected counts by their col1 value and build one row per group.
        groups = {}
        for row in counts:
            groups.setdefault(row[col1], []).append(row)
        table = [
            create_counts_row(col1_item, rows)
            for col1_item, rows in groups.items()
        ]

        # Back ticks can't exist in DataFrame column names,
        # therefore drop them. To be able to accept special keywords and `.`,
        # wrap the column names in ``.
        def clean_column_name(name):
            return name.replace('`', '')

        # The column names in distinct_col2 are not ordered by their index,
        # so sort by the column index before assigning the column names.
        header_names = [
            StructField(clean_column_name(str(name)), LongType())
            for name, _ in sorted(distinct_col2.items(), key=lambda kv: kv[1])
        ]
        schema = StructType([StructField(table_name, StringType())] +
                            header_names)

        return schema, table
def merge_schemas(left_schema, right_schema, how, on=None):
    if on is None:
        on = []

    left_on_fields, right_on_fields = get_on_fields(left_schema, right_schema,
                                                    on)
    other_left_fields = [
        field for field in left_schema.fields if field not in left_on_fields
    ]
    other_right_fields = [
        field for field in right_schema.fields if field not in right_on_fields
    ]

    if how in (INNER_JOIN, CROSS_JOIN, LEFT_JOIN, LEFT_ANTI_JOIN,
               LEFT_SEMI_JOIN):
        on_fields = left_on_fields
    elif how == RIGHT_JOIN:
        on_fields = right_on_fields
    elif how == FULL_JOIN:
        on_fields = [
            StructField(field.name, field.dataType, nullable=True)
            for field in left_on_fields
        ]
    else:
        raise IllegalArgumentException(
            'Invalid how argument in join: {0}'.format(how))

    return StructType(fields=on_fields + other_left_fields +
                      other_right_fields)
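A minimal usage sketch for merge_schemas, assuming get_on_fields matches the 'on' columns by name in both schemas and that FULL_JOIN is the library's full outer join constant:

left = StructType([
    StructField('id', LongType(), False),
    StructField('name', StringType(), True),
])
right = StructType([
    StructField('id', LongType(), False),
    StructField('age', LongType(), True),
])

# For a full outer join the join key becomes nullable and the remaining
# fields of both schemas are appended after it:
# StructType([StructField('id', LongType(), True),
#             StructField('name', StringType(), True),
#             StructField('age', LongType(), True)])
merged = merge_schemas(left, right, FULL_JOIN, on=['id'])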
 def output_fields(self, schema):
     if isinstance(self.expr, Expression):
         return self.expr.output_fields(schema)
     return [
         StructField(name=self.col_name,
                     dataType=self.data_type,
                     nullable=self.is_nullable)
     ]
Example #7
    def test_session_create_data_frame_from_pandas_data_frame(self):
        try:
            # Pandas is an optional dependency
            # pylint: disable=import-outside-toplevel
            import pandas as pd
        except ImportError:
            raise Exception('pandas is not importable')

        pdf = pd.DataFrame([(1, 'one'), (2, 'two'), (3, 'three')])

        df = self.spark.createDataFrame(pdf)

        self.assertEqual(df.count(), 3)
        self.assertListEqual(
            df.collect(), [Row(**{'0': 1, '1': 'one'}), Row(**{'0': 2, '1': 'two'}), Row(**{'0': 3, '1': 'three'})],
        )
        self.assertEqual(
            df.schema, StructType([StructField('0', LongType(), True), StructField('1', StringType(), True)]),
        )
    def withColumnRenamed(self, existing, new):
        def mapper(row):
            keyed_values = [(new, row[col]) if col == existing else
                            (col, row[col]) for col in row.__fields__]
            return row_from_keyed_values(keyed_values)

        new_schema = StructType([
            field if field.name != existing else StructField(
                new, field.dataType, field.nullable)
            for field in self.bound_schema.fields
        ])

        return self._with_rdd(self._rdd.map(mapper), schema=new_schema)
    def toDF(self, new_names):
        def mapper(row):
            keyed_values = [(new_name, row[old])
                            for new_name, old in zip(new_names, row.__fields__)
                            ]
            return row_from_keyed_values(keyed_values)

        new_schema = StructType([
            StructField(new_name, field.dataType, field.nullable)
            for new_name, field in zip(new_names, self.bound_schema.fields)
        ])

        return self._with_rdd(self._rdd.map(mapper), schema=new_schema)
Example #10
    def test_session_create_data_frame_from_list_with_col_names(self):
        df = self.spark.createDataFrame(
            [(0.0, [1.0, 0.8]), (1.0, [0.0, 0.0]), (2.0, [0.5, 0.5])], ['label', 'features'],
        )
        self.assertEqual(df.count(), 3)
        self.assertListEqual(
            df.collect(),
            [
                row_from_keyed_values([('label', 0.0), ('features', [1.0, 0.8])]),
                row_from_keyed_values([('label', 1.0), ('features', [0.0, 0.0])]),
                row_from_keyed_values([('label', 2.0), ('features', [0.5, 0.5])]),
            ],
        )

        self.assertEqual(
            df.schema,
            StructType(
                [
                    StructField('label', DoubleType(), True),
                    StructField('features', ArrayType(DoubleType(), True), True),
                ]
            ),
        )
Example #11
def guess_schema_from_strings(schema_fields, data, options):
    field_values = [(field, [row[field] for row in data])
                    for field in schema_fields]

    field_types_and_values = [
        (field, guess_type_from_values_as_string(values, options))
        for field, values in field_values
    ]

    schema = StructType(fields=[
        StructField(field, field_type)
        for field, field_type in field_types_and_values
    ])

    return schema
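A small illustration of guess_schema_from_strings; the rows are plain dicts here for brevity (anything supporting row[field] works), and the inferred types depend on guess_type_from_values_as_string:

data = [
    {'name': 'Alice', 'age': '28'},
    {'name': 'Bob', 'age': '31'},
]
# With the test options used above, 'age' should be inferred as an
# integer-like type and 'name' as a string type.
schema = guess_schema_from_strings(['name', 'age'], data, options=BASE_OPTIONS)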
 def test_cast_to_struct(self):
     self.assertEqual(
         cast_to_struct(
             Row(character='Alice', day='28', month='8', year='2019'),
             from_type=StructType(fields=[
                 StructField('character', StringType()),
                 StructField('day', StringType()),
                 StructField('month', StringType()),
                 StructField('year', StringType()),
             ]),
             to_type=StructType(fields=[
                 StructField('character', StringType()),
                 StructField('day', IntegerType()),
                 StructField('month', IntegerType()),
                 StructField('year', IntegerType()),
             ]),
             options=BASE_OPTIONS,
         ),
         Row(character='Alice', day=28, month=8, year=2019),
     )
def test_column_stat_helper():
    """
    Expected quantile values come from use of org.apache.spark.sql.catalyst.util.QuantileSummaries
    """
    schema = StructType([StructField('value', IntegerType())])
    helper = ColumnStatHelper(col('value'))
    for i in range(1, 100001):
        helper.merge(Row(value=i), schema)
    helper.finalize()
    assert helper.count == 100000
    assert helper.min == 1
    assert helper.max == 100000
    assert helper.mean == 50000.5
    assert helper.stddev == 28867.65779668774  # sample standard deviation
    assert helper.get_quantile(0) == 1
    assert helper.get_quantile(0.25) == 24998
    assert helper.get_quantile(0.5) == 50000
    assert helper.get_quantile(0.75) == 74993
    assert helper.get_quantile(1) == 100000
    def read(self):
        sc = self.spark._sc
        paths = self.paths

        partitions, partition_schema = resolve_partitions(paths)

        rdd_filenames = sc.parallelize(sorted(partitions.keys()),
                                       len(partitions))
        rdd = rdd_filenames.flatMap(
            partial(parse_csv_file, partitions, partition_schema, self.schema,
                    self.options))

        if self.schema is not None:
            schema = self.schema
        elif self.options.inferSchema:
            fields = rdd.take(1)[0].__fields__
            schema = guess_schema_from_strings(fields,
                                               rdd.collect(),
                                               options=self.options)
        else:
            schema = infer_schema_from_rdd(rdd)

        schema_with_string = StructType(fields=[
            StructField(field.name, StringType()) for field in schema.fields
        ])

        if partition_schema:
            partitions_fields = partition_schema.fields
            full_schema = StructType(schema.fields[:-len(partitions_fields)] +
                                     partitions_fields)
        else:
            full_schema = schema

        cast_row = get_caster(from_type=schema_with_string,
                              to_type=full_schema,
                              options=self.options)
        casted_rdd = rdd.map(cast_row)
        casted_rdd._name = paths

        return DataFrameInternal(sc, casted_rdd, schema=full_schema)
 def get_summary_schema(self, exprs):
     return StructType([StructField('summary', StringType(), True)] + [
         StructField(field.name, StringType(), True)
         for field in get_schema_from_cols(exprs, self.bound_schema).fields
     ])
 def __init__(self, spark, paths, schema, options):
     self.spark = spark
     self.paths = paths
     self.schema = schema or StructType(
         [StructField('value', StringType())])
     self.options = Options(self.default_options, options)
 def output_fields(self, schema):
     return [
         StructField(name=str(self),
                     dataType=self.data_type,
                     nullable=self.is_nullable)
     ]
Example #18
 def test_csv_read_with_inferred_schema(self):
     df = spark.read.option('inferSchema', True).csv(
         os.path.join(os.path.dirname(os.path.realpath(__file__)),
                      'data/fundings/'),
         header=True,
     )
     self.assertEqual(df.count(), 4)
     self.assertEqual(
         df.schema,
         StructType([
             StructField('permalink', StringType()),
             StructField('company', StringType()),
             StructField('numEmps', IntegerType()),
             StructField('category', StringType()),
             StructField('city', StringType()),
             StructField('state', StringType()),
             StructField('fundedDate', TimestampType()),
             StructField('raisedAmt', IntegerType()),
             StructField('raisedCurrency', StringType()),
             StructField('round', StringType()),
         ]),
     )
     self.assertEqual(
         [Row(**r.asDict()) for r in df.collect()],
         [
             Row(
                 permalink='mycityfaces',
                 company='MyCityFaces',
                 numEmps=7,
                 category='web',
                 city='Scottsdale',
                 state='AZ',
                 fundedDate=datetime.datetime(2008, 1, 1, 0, 0),
                 raisedAmt=50000,
                 raisedCurrency='USD',
                 round='seed',
             ),
             Row(
                 permalink='flypaper',
                 company='Flypaper',
                 numEmps=None,
                 category='web',
                 city='Phoenix',
                 state='AZ',
                 fundedDate=datetime.datetime(2008, 2, 1, 0, 0),
                 raisedAmt=3000000,
                 raisedCurrency='USD',
                 round='a',
             ),
             Row(
                 permalink='chosenlist-com',
                 company='ChosenList.com',
                 numEmps=5,
                 category='web',
                 city='Scottsdale',
                 state='AZ',
                 fundedDate=datetime.datetime(2008, 1, 25, 0, 0),
                 raisedAmt=233750,
                 raisedCurrency='USD',
                 round='angel',
             ),
             Row(
                 permalink='digg',
                 company='Digg',
                 numEmps=60,
                 category='web',
                 city='San Francisco',
                 state='CA',
                 fundedDate=datetime.datetime(2006, 12, 1, 0, 0),
                 raisedAmt=8500000,
                 raisedCurrency='USD',
                 round='b',
             ),
         ],
     )
Example #19
 def test_csv_read_without_schema(self):
     df = spark.read.csv(
         os.path.join(os.path.dirname(os.path.realpath(__file__)),
                      'data/fundings/'),
         header=True,
     )
     self.assertEqual(df.count(), 4)
     self.assertEqual(
         df.schema,
         StructType([
             StructField('permalink', StringType()),
             StructField('company', StringType()),
             StructField('numEmps', StringType()),
             StructField('category', StringType()),
             StructField('city', StringType()),
             StructField('state', StringType()),
             StructField('fundedDate', StringType()),
             StructField('raisedAmt', StringType()),
             StructField('raisedCurrency', StringType()),
             StructField('round', StringType()),
         ]),
     )
     self.assertListEqual(
         [Row(**r.asDict()) for r in df.collect()],
         [
             Row(
                 permalink='mycityfaces',
                 company='MyCityFaces',
                 numEmps='7',
                 category='web',
                 city='Scottsdale',
                 state='AZ',
                 fundedDate='2008-01-01',
                 raisedAmt='50000',
                 raisedCurrency='USD',
                 round='seed',
             ),
             Row(
                 permalink='flypaper',
                 company='Flypaper',
                 numEmps=None,
                 category='web',
                 city='Phoenix',
                 state='AZ',
                 fundedDate='2008-02-01',
                 raisedAmt='3000000',
                 raisedCurrency='USD',
                 round='a',
             ),
             Row(
                 permalink='chosenlist-com',
                 company='ChosenList.com',
                 numEmps='5',
                 category='web',
                 city='Scottsdale',
                 state='AZ',
                 fundedDate='2008-01-25',
                 raisedAmt='233750',
                 raisedCurrency='USD',
                 round='angel',
             ),
             Row(
                 permalink='digg',
                 company='Digg',
                 numEmps='60',
                 category='web',
                 city='San Francisco',
                 state='CA',
                 fundedDate='2006-12-01',
                 raisedAmt='8500000',
                 raisedCurrency='USD',
                 round='b',
             ),
         ],
     )
 def output_fields(self, schema):
     return [
         StructField('pos', IntegerType(), False),
         StructField('col', DataType(), False),
     ]
Example #21
 def test_session_create_data_frame_from_list_with_schema(self):
     schema = StructType([StructField('map', MapType(StringType(), IntegerType()), True)])
     df = self.spark.createDataFrame([({'a': 1},)], schema=schema)
     self.assertEqual(df.count(), 1)
     self.assertListEqual(df.collect(), [Row(map={'a': 1})])
     self.assertEqual(df.schema, schema)