示例#1
0
    def test_session_create_data_frame_from_pandas_data_frame(self):
        """Check that a pandas DataFrame is converted with the right rows and schema.

        The pandas frame is built without column names, and the expected rows
        and schema below refer to its two columns as ``"0"`` and ``"1"``.

        :raises ImportError: if pandas cannot be imported.
        """
        try:
            # Pandas is an optional dependency
            # pylint: disable=import-outside-toplevel
            import pandas as pd
        except ImportError as e:
            raise ImportError("pandas is not importable") from e

        pdf = pd.DataFrame([(1, "one"), (2, "two"), (3, "three")])

        df = self.spark.createDataFrame(pdf)

        self.assertEqual(df.count(), 3)
        # Every expected row must use the same two column names as the schema
        # asserted below ("0" and "1").
        self.assertListEqual(df.collect(), [
            Row(**{
                "0": 1,
                "1": 'one'
            }),
            Row(**{
                "0": 2,
                "1": 'two'
            }),
            Row(**{
                "0": 3,
                "1": 'three'
            })
        ])
        self.assertEqual(
            df.schema,
            StructType([
                StructField("0", LongType(), True),
                StructField("1", StringType(), True)
            ]))
示例#2
0
 def test_csv_read_with_given_schema(self):
     """Read the fundings CSV fixture with an explicit schema and check the rows.

     The schema declares ``numEmps`` and ``raisedAmt`` as integers and
     ``fundedDate`` as a date, so the expected rows hold typed values.
     """
     typed_columns = [
         ("permalink", StringType),
         ("company", StringType),
         ("numEmps", IntegerType),
         ("category", StringType),
         ("city", StringType),
         ("state", StringType),
         ("fundedDate", DateType),
         ("raisedAmt", IntegerType),
         ("raisedCurrency", StringType),
         ("round", StringType),
     ]
     schema = StructType([
         StructField(name, type_factory())
         for name, type_factory in typed_columns
     ])
     column_names = [name for name, _ in typed_columns]

     def expected_row(*values):
         # Pair positional values with the column names, in schema order.
         return Row(**dict(zip(column_names, values)))

     test_dir = os.path.dirname(os.path.realpath(__file__))
     fixture_path = os.path.join(test_dir, "data/fundings/")
     df = spark.read.schema(schema).csv(fixture_path, header=True)

     actual_rows = [Row(**r.asDict()) for r in df.collect()]
     expected_rows = [
         expected_row('mycityfaces', 'MyCityFaces', 7, 'web', 'Scottsdale',
                      'AZ', datetime.date(2008, 1, 1), 50000, 'USD', 'seed'),
         expected_row('flypaper', 'Flypaper', None, 'web', 'Phoenix',
                      'AZ', datetime.date(2008, 2, 1), 3000000, 'USD', 'a'),
         expected_row('chosenlist-com', 'ChosenList.com', 5, 'web',
                      'Scottsdale', 'AZ', datetime.date(2008, 1, 25), 233750,
                      'USD', 'angel'),
         expected_row('digg', 'Digg', 60, 'web', 'San Francisco',
                      'CA', datetime.date(2006, 12, 1), 8500000, 'USD', 'b'),
     ]
     self.assertEqual(actual_rows, expected_rows)
示例#3
0
    def agg(self, stats):
        """Aggregate the grouped data with ``stats`` and return the result.

        The result is built by calling ``self.jdf._with_rdd`` on a parallelized
        list of rows, one row per group key, together with a new schema made of
        the grouping fields followed by one column per statistic (or per
        pivot value / statistic combination when a pivot column is set).

        :param stats: sequence of statistics to compute; each is converted
            with ``str`` to name its output column.
        :return: whatever ``self.jdf._with_rdd`` returns for the aggregated
            rows and the new schema.
        """
        # Schema fields of the grouping columns, looked up in the bound schema.
        grouping_schema = StructType([
            field for col in self.grouping_cols
            for field in col.find_fields_in_schema(self.jdf.bound_schema)
        ])

        # Fold every row into a GroupedStats accumulator (first lambda) and
        # combine partial accumulators with mergeStats (second lambda).
        aggregated_stats = self.jdf.aggregate(
            GroupedStats(self.grouping_cols,
                         stats,
                         pivot_col=self.pivot_col,
                         pivot_values=self.pivot_values),
            lambda grouped_stats, row: grouped_stats.merge(
                row, self.jdf.bound_schema),
            lambda grouped_stats_1, grouped_stats_2: grouped_stats_1.
            mergeStats(grouped_stats_2, self.jdf.bound_schema))

        data = []
        # NOTE(review): add_subtotals presumably adds subtotal groups (the
        # GROUPED sentinel below marks aggregated-away columns) — confirm.
        all_stats = self.add_subtotals(aggregated_stats)
        for group_key in all_stats.group_keys:
            # Key columns whose value is the GROUPED sentinel are emitted as
            # None; `grouping` records which positions held the sentinel.
            key = [(str(key), None if value is GROUPED else value)
                   for key, value in zip(self.grouping_cols, group_key)]
            grouping = tuple(value is GROUPED for value in group_key)

            key_as_row = row_from_keyed_values(key).set_grouping(grouping)
            # Output row = key columns followed by each evaluated statistic,
            # iterating pivot values first and statistics second.
            data.append(
                row_from_keyed_values(key + [
                    (str(stat),
                     stat.with_pre_evaluation_schema(self.jdf.bound_schema).
                     eval(key_as_row, grouping_schema))
                    for pivot_value in all_stats.pivot_values
                    for stat in get_pivoted_stats(
                        all_stats.groups[group_key][pivot_value], pivot_value)
                ]))

        # Column naming: with a pivot and a single stat the column is named
        # after the pivot value; with several stats it is
        # "<pivot_value>_<stat>"; without a pivot it is named after the stat.
        # All result columns use a generic, nullable DataType().
        if self.pivot_col is not None:
            if len(stats) == 1:
                new_schema = StructType(grouping_schema.fields + [
                    StructField(str(pivot_value), DataType(), True)
                    for pivot_value in self.pivot_values
                ])
            else:
                new_schema = StructType(grouping_schema.fields + [
                    StructField("{0}_{1}".format(pivot_value, stat),
                                DataType(), True)
                    for pivot_value in self.pivot_values for stat in stats
                ])
        else:
            new_schema = StructType(
                grouping_schema.fields +
                [StructField(str(stat), DataType(), True) for stat in stats])

        # noinspection PyProtectedMember
        return self.jdf._with_rdd(self.jdf._sc.parallelize(data),
                                  schema=new_schema)
示例#4
0
    def crosstab(self, df, col1, col2):
        """Build the schema and rows of a contingency table for two columns.

        Row counts of ``df`` are grouped by ``(col1, col2)``; at most 1e6
        groups are taken, and a warning is emitted when that cap is reached.
        Each distinct ``col2`` value (``None`` rendered as ``"null"``) becomes
        one count column, and the first column, named ``"<col1>_<col2>"``,
        holds the ``col1`` value.

        :param df: object providing ``groupBy(col1, col2).agg(...).take(...)``.
        :param col1: name of the column whose values label the rows.
        :param col2: name of the column whose distinct values label the
            count columns.
        :return: tuple ``(schema, table)``.
        :raises ValueError: if ``col2`` has 1e4 or more distinct values.
        """
        table_name = "_".join((col1, col2))
        counts = df.groupBy(col1, col2).agg(count("*")).take(1e6)
        if len(counts) == 1e6:
            warnings.warn(
                "The maximum limit of 1e6 pairs have been collected, "
                "which may not be all of the pairs. Please try reducing "
                "the amount of distinct items in your columns.")

        def clean_element(element):
            return str(element) if element is not None else "null"

        # NOTE(review): the chained map/distinct/sorted/zipWithIndex/toMap
        # calls assume `counts` exposes a collection API with those methods —
        # confirm what take() returns here.
        distinct_col2 = (counts.map(lambda row: clean_element(row[col2])).
                         distinct().sorted().zipWithIndex().toMap())
        column_size = len(distinct_col2)
        # Reject too many distinct values. The guard used to be inverted
        # (`< 1e4`), which raised for every acceptable input and let oversized
        # ones through, contradicting the error message.
        if column_size >= 1e4:
            raise ValueError(
                "The number of distinct values for {0} can't exceed 1e4. "
                "Currently {1}".format(col2, column_size))

        def create_counts_row(col1Item, rows):
            counts_row = [None] * (column_size + 1)

            def parse_row(row):
                column_index = distinct_col2[clean_element(row[1])]
                counts_row[int(column_index + 1)] = int(row[2])

            rows.foreach(parse_row)
            # the value of col1 is the first value, the rest are the counts
            counts_row[0] = clean_element(col1Item)
            return Row(counts_row)

        table = counts.groupBy(lambda r: r[col1]).map(create_counts_row).toSeq

        # Back ticks can't exist in DataFrame column names,
        # therefore drop them. To be able to accept special keywords and `.`,
        # wrap the column names in ``.
        def clean_column_name(name):
            return name.replace("`", "")

        # In the map, the column names (._1) are not ordered by the index (._2).
        # We need to explicitly sort by the column index and assign the column names.
        # NOTE(review): r[1] / r[2] look like 1-based tuple accessors (._1 /
        # ._2); with 0-based Python indexing on (name, index) pairs these
        # would be r[0] / r[1] — confirm. LongType / StringType are also
        # passed uninstantiated here — verify against StructField's contract.
        header_names = distinct_col2.toSeq.sortBy(lambda r: r[2]).map(
            lambda r: StructField(clean_column_name(str(r[1])), LongType))
        schema = StructType([StructField(table_name, StringType)] +
                            header_names)

        return schema, table
示例#5
0
def merge_schemas(left_schema, right_schema, how, on=None):
    """Build the schema of a join result from the two input schemas.

    The join-key fields come first, followed by the remaining left fields and
    then the remaining right fields.

    :param left_schema: schema of the left side; its ``fields`` are read.
    :param right_schema: schema of the right side; its ``fields`` are read.
    :param how: join type; compared against the ``*_JOIN`` constants.
    :param on: join columns forwarded to ``get_on_fields``; ``None`` is
        treated as an empty list.
    :return: a ``StructType`` holding the merged fields.
    :raises IllegalArgumentException: if ``how`` is not a recognised join type.
    """
    join_columns = [] if on is None else on

    left_keys, right_keys = get_on_fields(left_schema, right_schema,
                                          join_columns)
    remaining_left = [
        candidate for candidate in left_schema.fields
        if candidate not in left_keys
    ]
    remaining_right = [
        candidate for candidate in right_schema.fields
        if candidate not in right_keys
    ]

    left_keyed_joins = (INNER_JOIN, CROSS_JOIN, LEFT_JOIN, LEFT_ANTI_JOIN,
                        LEFT_SEMI_JOIN)
    if how in left_keyed_joins:
        key_fields = left_keys
    elif how == RIGHT_JOIN:
        key_fields = right_keys
    elif how == FULL_JOIN:
        # Full-join keys are rebuilt from the left key fields with
        # nullable=True.
        key_fields = []
        for key_field in left_keys:
            key_fields.append(
                StructField(key_field.name, key_field.dataType, nullable=True))
    else:
        message = "Invalid how argument in join: {0}".format(how)
        raise IllegalArgumentException(message)

    merged_fields = key_fields + remaining_left + remaining_right
    return StructType(fields=merged_fields)
示例#6
0
 def test_session_create_data_frame_from_list(self):
     """Create a DataFrame from plain tuples and check the default names and types.

     Without column names the expected columns are ``_1`` and ``_2``.
     """
     records = [
         (1, "one"),
         (2, "two"),
         (3, "three"),
     ]
     df = self.spark.createDataFrame(records)

     self.assertEqual(df.count(), 3)

     collected = df.collect()
     expected_rows = [
         Row(_1=1, _2='one'),
         Row(_1=2, _2='two'),
         Row(_1=3, _2='three'),
     ]
     self.assertListEqual(collected, expected_rows)

     actual_schema = df.schema
     expected_schema = StructType([
         StructField("_1", LongType(), True),
         StructField("_2", StringType(), True),
     ])
     self.assertEqual(actual_schema, expected_schema)
示例#7
0
 def test_session_create_data_frame_from_list_with_schema(self):
     """Create a one-row DataFrame with an explicit map-typed schema."""
     map_field = StructField("map", MapType(StringType(), IntegerType()),
                             True)
     schema = StructType([map_field])
     records = [({'a': 1}, )]

     df = self.spark.createDataFrame(records, schema=schema)

     self.assertEqual(df.count(), 1)
     expected_rows = [Row(map={'a': 1})]
     self.assertListEqual(df.collect(), expected_rows)
     # The schema passed in must be the schema reported back.
     self.assertEqual(df.schema, schema)
示例#8
0
 def output_fields(self, schema):
     """Return the list of fields this object produces for ``schema``.

     When ``self.expr`` is an ``Expression`` the call is delegated to it;
     otherwise a single field is built from ``col_name``, ``data_type`` and
     ``is_nullable``.
     """
     if not isinstance(self.expr, Expression):
         own_field = StructField(name=self.col_name,
                                 dataType=self.data_type,
                                 nullable=self.is_nullable)
         return [own_field]
     return self.expr.output_fields(schema)
示例#9
0
    def test_session_create_data_frame_from_list_with_col_names(self):
        """Create a DataFrame from tuples plus explicit column names.

        Checks the row count, the collected rows and the resulting schema
        (a double column and an array-of-double column).
        """
        records = [
            (0.0, [1.0, 0.8]),
            (1.0, [0.0, 0.0]),
            (2.0, [0.5, 0.5]),
        ]
        column_names = ["label", "features"]
        df = self.spark.createDataFrame(records, column_names)

        self.assertEqual(df.count(), 3)

        collected = df.collect()
        expected_rows = [
            row_from_keyed_values([("label", 0.0), ("features", [1.0, 0.8])]),
            row_from_keyed_values([("label", 1.0), ("features", [0.0, 0.0])]),
            row_from_keyed_values([("label", 2.0), ("features", [0.5, 0.5])]),
        ]
        self.assertListEqual(collected, expected_rows)

        actual_schema = df.schema
        label_field = StructField("label", DoubleType(), True)
        features_field = StructField("features",
                                     ArrayType(DoubleType(), True), True)
        self.assertEqual(actual_schema,
                         StructType([label_field, features_field]))
示例#10
0
    def toDF(self, new_names):
        """Return a copy of this object with its columns renamed by position.

        Names are paired with existing columns via ``zip``, so pairing stops
        at the shorter of ``new_names`` and the existing fields.

        :param new_names: new column names, in column order.
        :return: the result of ``self._with_rdd`` for the renamed rows and
            schema.
        """
        def rename_row(row):
            # Re-key each value of the row under its new column name.
            return row_from_keyed_values([
                (renamed, row[current])
                for renamed, current in zip(new_names, row.__fields__)
            ])

        renamed_fields = [
            StructField(renamed, original.dataType, original.nullable)
            for renamed, original in zip(new_names, self.bound_schema.fields)
        ]
        new_schema = StructType(renamed_fields)

        renamed_rdd = self._rdd.map(rename_row)
        return self._with_rdd(renamed_rdd, schema=new_schema)
示例#11
0
    def withColumnRenamed(self, existing, new):
        """Return a copy of this object in which column ``existing`` is named ``new``.

        Columns whose name differs from ``existing`` are kept unchanged; the
        renamed field keeps its data type and nullability.

        :param existing: current column name.
        :param new: replacement column name.
        :return: the result of ``self._with_rdd`` for the renamed rows and
            schema.
        """
        def rename_in_row(row):
            pairs = []
            for col in row.__fields__:
                label = new if col == existing else col
                pairs.append((label, row[col]))
            return row_from_keyed_values(pairs)

        def rename_field(field):
            if field.name != existing:
                return field
            return StructField(new, field.dataType, field.nullable)

        new_schema = StructType(
            [rename_field(field) for field in self.bound_schema.fields])

        renamed_rdd = self._rdd.map(rename_in_row)
        return self._with_rdd(renamed_rdd, schema=new_schema)
示例#12
0
 def test_cast_row_to_string(self):
     """Cast a row with a nested map, a null, a boolean and a double to a string."""
     nested_map = collections.OrderedDict([("value", None), ("b", {"c": 7})])
     row = Row(a=nested_map, b=None, c=True, d=5.2)

     inner_map_type = MapType(StringType(), LongType(), True)
     row_type = StructType([
         StructField("a", MapType(StringType(), inner_map_type, True), True),
         StructField("b", LongType(), True),
         StructField("c", BooleanType(), True),
         StructField("d", DoubleType(), True),
     ])

     rendered = cast_to_string(row, row_type, options=BASE_OPTIONS)

     self.assertEqual(rendered, "[[value ->, b -> [c -> 7]],, true, 5.2]")
示例#13
0
def guess_schema_from_strings(schema_fields, data, options):
    """Guess a schema from the string values found in ``data``.

    For every name in ``schema_fields`` the values ``row[name]`` of all rows
    are gathered and handed to ``guess_type_from_values_as_string``; the
    guessed types are then assembled into a ``StructType``.

    :param schema_fields: iterable of field names, in output order.
    :param data: rows indexable by field name; iterated once per field.
    :param options: forwarded unchanged to the type-guessing helper.
    :return: a ``StructType`` with one ``StructField`` per field name.
    """
    # Stage 1: gather the raw values of each column.
    columns = []
    for name in schema_fields:
        columns.append((name, [record[name] for record in data]))

    # Stage 2: guess one type per column.
    guessed = [
        (name, guess_type_from_values_as_string(raw_values, options))
        for name, raw_values in columns
    ]

    # Stage 3: assemble the schema.
    return StructType(fields=[
        StructField(name, guessed_type) for name, guessed_type in guessed
    ])
示例#14
0
 def test_cast_to_struct(self):
     """Cast a row of strings to a struct whose day/month/year are integers."""
     source_row = Row(character='Alice', day='28', month='8', year='2019')

     all_strings = StructType(fields=[
         StructField(name, StringType())
         for name in ("character", "day", "month", "year")
     ])
     with_integers = StructType(fields=[
         StructField("character", StringType()),
         StructField("day", IntegerType()),
         StructField("month", IntegerType()),
         StructField("year", IntegerType()),
     ])

     converted = cast_to_struct(source_row,
                                from_type=all_strings,
                                to_type=with_integers,
                                options=BASE_OPTIONS)

     expected = Row(character='Alice', day=28, month=8, year=2019)
     self.assertEqual(converted, expected)
示例#15
0
def test_column_stat_helper():
    """
    Merge the integers 1..100000 into a ColumnStatHelper and check its statistics.

    Expected quantile values come from use of org.apache.spark.sql.catalyst.util.QuantileSummaries
    """
    row_count = 100000
    schema = StructType([StructField("value", IntegerType())])
    helper = ColumnStatHelper(col("value"))
    for value in range(1, row_count + 1):
        helper.merge(Row(value=value), schema)
    helper.finalize()

    assert helper.count == row_count
    assert helper.min == 1
    assert helper.max == row_count
    assert helper.mean == 50000.5
    assert helper.stddev == 28867.65779668774  # sample standard deviation

    # (probability, expected value) pairs, checked in ascending order.
    expected_quantiles = [
        (0, 1),
        (0.25, 24998),
        (0.5, 50000),
        (0.75, 74993),
        (1, 100000),
    ]
    for probability, expected in expected_quantiles:
        assert helper.get_quantile(probability) == expected
示例#16
0
    def read(self):
        """Load the files under ``self.paths`` into a ``DataFrameInternal``.

        Steps, in order:

        1. resolve the partitions and partition schema for the paths;
        2. parse every file with ``parse_csv_file`` through a flat-mapped RDD
           holding one slice per partition entry;
        3. pick the schema: ``self.schema`` if set, otherwise a schema guessed
           from the string values when ``options.inferSchema`` is truthy,
           otherwise one inferred from the RDD;
        4. cast every row from an all-string version of that schema to the
           final schema, whose trailing fields are replaced by the partition
           fields when a partition schema exists.

        :return: a ``DataFrameInternal`` over the cast rows and final schema.
        """
        spark_context = self.spark._sc
        paths = self.paths

        partitions, partition_schema = resolve_partitions(paths)

        file_names = sorted(partitions.keys())
        parse_file = partial(parse_csv_file, partitions, partition_schema,
                             self.schema, self.options)
        rdd = spark_context.parallelize(file_names,
                                        len(partitions)).flatMap(parse_file)

        schema = self.schema
        if schema is None:
            if self.options.inferSchema:
                # Field names are taken from the first parsed row.
                first_row = rdd.take(1)[0]
                schema = guess_schema_from_strings(first_row.__fields__,
                                                   rdd.collect(),
                                                   options=self.options)
            else:
                schema = infer_schema_from_rdd(rdd)

        # Source type for the caster: same field names, every type a string.
        string_fields = [
            StructField(field.name, StringType()) for field in schema.fields
        ]
        schema_with_string = StructType(fields=string_fields)

        full_schema = schema
        if partition_schema:
            # Swap the trailing fields of the schema for the partition fields.
            partitions_fields = partition_schema.fields
            leading_fields = schema.fields[:-len(partitions_fields)]
            full_schema = StructType(leading_fields + partitions_fields)

        cast_row = get_caster(from_type=schema_with_string,
                              to_type=full_schema,
                              options=self.options)
        casted_rdd = rdd.map(cast_row)
        casted_rdd._name = paths

        return DataFrameInternal(spark_context, casted_rdd, schema=full_schema)
示例#17
0
 def output_fields(self, schema):
     """Return a one-element list holding the field this object produces.

     The field is named ``str(self)`` and takes its type and nullability
     from ``self.data_type`` and ``self.is_nullable``; ``schema`` is not
     used here.
     """
     own_field = StructField(name=str(self),
                             dataType=self.data_type,
                             nullable=self.is_nullable)
     return [own_field]
示例#18
0
 def get_summary_schema(self, exprs):
     """Return the schema of a summary table for ``exprs``.

     The schema starts with a nullable string column named ``"summary"``,
     followed by one nullable string column per field of the schema that
     ``get_schema_from_cols`` derives from ``exprs`` and the bound schema.
     """
     summary_fields = [StructField("summary", StringType(), True)]
     source_schema = get_schema_from_cols(exprs, self.bound_schema)
     for source_field in source_schema.fields:
         summary_fields.append(
             StructField(source_field.name, StringType(), True))
     return StructType(summary_fields)
示例#19
0
 def __init__(self, spark, paths, schema, options):
     """Store the reader inputs and resolve the schema and options.

     :param spark: stored as ``self.spark``.
     :param paths: stored as ``self.paths``.
     :param schema: schema to use; any falsy value is replaced by a schema
         with a single string field named ``"value"``.
     :param options: combined with ``self.default_options`` through
         ``Options``.
     """
     self.spark = spark
     self.paths = paths
     if schema:
         self.schema = schema
     else:
         self.schema = StructType([StructField("value", StringType())])
     self.options = Options(self.default_options, options)
示例#20
0
 def test_csv_read_without_schema(self):
     """Read the fundings CSV fixture without a schema and check the result.

     With no schema given, every column is expected to be a string column and
     every non-null value a string.
     """
     column_names = [
         "permalink", "company", "numEmps", "category", "city",
         "state", "fundedDate", "raisedAmt", "raisedCurrency", "round",
     ]

     def expected_row(*values):
         # Pair positional values with the column names, in file order.
         return Row(**dict(zip(column_names, values)))

     test_dir = os.path.dirname(os.path.realpath(__file__))
     fixture_path = os.path.join(test_dir, "data/fundings/")
     df = spark.read.csv(fixture_path, header=True)

     self.assertEqual(df.count(), 4)

     actual_schema = df.schema
     expected_schema = StructType([
         StructField(name, StringType()) for name in column_names
     ])
     self.assertEqual(actual_schema, expected_schema)

     actual_rows = [Row(**r.asDict()) for r in df.collect()]
     expected_rows = [
         expected_row('mycityfaces', 'MyCityFaces', '7', 'web', 'Scottsdale',
                      'AZ', '2008-01-01', '50000', 'USD', 'seed'),
         expected_row('flypaper', 'Flypaper', None, 'web', 'Phoenix',
                      'AZ', '2008-02-01', '3000000', 'USD', 'a'),
         expected_row('chosenlist-com', 'ChosenList.com', '5', 'web',
                      'Scottsdale', 'AZ', '2008-01-25', '233750', 'USD',
                      'angel'),
         expected_row('digg', 'Digg', '60', 'web', 'San Francisco',
                      'CA', '2006-12-01', '8500000', 'USD', 'b'),
     ]
     self.assertListEqual(actual_rows, expected_rows)