예제 #1
0
def guess_type_from_values_as_string(values, options):
    """Infer the narrowest type that can represent every string in *values*.

    Reproduces the inference available in Spark's
    PartitioningUtils.inferPartitionColumnValue(), located in
    org.apache.spark.sql.execution.datasources: candidates are tried from
    most to least specific, and the first type whose caster accepts every
    value wins.  StringType is the catch-all last candidate.
    """
    string_type = StringType()
    for candidate in (IntegerType(), LongType(), DecimalType(),
                      DoubleType(), TimestampType(), StringType()):
        caster = get_caster(from_type=string_type,
                            to_type=candidate,
                            options=options)
        try:
            for value in values:
                # A None result for a non-null input means the cast failed.
                if caster(value) is None and value not in ('null', None):
                    raise ValueError
        except ValueError:
            continue
        return candidate
    # Should never happen: the StringType fallback accepts any input.
    raise AnalysisException(
        'Unable to find a matching type for some fields, even StringType did not work'
    )
 def test_cast_array_to_string(self):
     """Nested arrays render in bracketed form, with None rendered empty."""
     nested_value = [[[1, None, 2], []]]
     nested_type = ArrayType(ArrayType(ArrayType(IntegerType())))
     rendered = cast_to_string(nested_value, nested_type,
                               options=BASE_OPTIONS)
     self.assertEqual(rendered, '[[[1,, 2], []]]')
 def test_cast_to_struct(self):
     """Casting a struct converts each field to its target field type."""
     field_names = ['character', 'day', 'month', 'year']
     # Source struct: everything is a string.
     from_type = StructType(fields=[
         StructField(name, StringType()) for name in field_names
     ])
     # Target struct: all fields but 'character' become integers.
     to_type = StructType(fields=[
         StructField('character', StringType()),
         StructField('day', IntegerType()),
         StructField('month', IntegerType()),
         StructField('year', IntegerType()),
     ])
     result = cast_to_struct(
         Row(character='Alice', day='28', month='8', year='2019'),
         from_type=from_type,
         to_type=to_type,
         options=BASE_OPTIONS,
     )
     self.assertEqual(result,
                      Row(character='Alice', day=28, month=8, year=2019))
 def test_cast_map_to_string(self):
     """Maps render as [key -> value, ...]; a None value renders empty."""
     inner_map = collections.OrderedDict([('one', 1), ('nothing', None),
                                          ('three', 3)])
     rendered = cast_to_string(
         {True: inner_map},
         MapType(BooleanType(), MapType(StringType(), IntegerType())),
         options=BASE_OPTIONS,
     )
     self.assertEqual(rendered,
                      '[true -> [one -> 1, nothing ->, three -> 3]]')
def test_column_stat_helper():
    """
    Expected quantile values come from use of org.apache.spark.sql.catalyst.util.QuantileSummaries
    """
    schema = StructType([StructField('value', IntegerType())])
    helper = ColumnStatHelper(col('value'))
    for value in range(1, 100001):
        helper.merge(Row(value=value), schema)
    helper.finalize()
    # Exact aggregates over the integers 1..100000.
    assert helper.count == 100000
    assert helper.min == 1
    assert helper.max == 100000
    assert helper.mean == 50000.5
    assert helper.stddev == 28867.65779668774  # sample standard deviation
    # Quantiles come from an approximate sketch, hence the inexact
    # mid-range values (24998 / 74993 rather than 25000 / 75000).
    expected_quantiles = {
        0: 1,
        0.25: 24998,
        0.5: 50000,
        0.75: 74993,
        1: 100000,
    }
    for quantile, expected in expected_quantiles.items():
        assert helper.get_quantile(quantile) == expected
 def test_cast_int_to_timestamp(self):
     """An integer number of seconds since the epoch casts to a datetime."""
     one_year_of_seconds = 86400 * 365
     result = cast_to_timestamp(one_year_of_seconds, IntegerType(),
                                options=BASE_OPTIONS)
     # NOTE(review): the 01:00 component suggests a UTC+1 local timezone
     # is assumed by this expectation — confirm against the test setup.
     self.assertEqual(result, datetime.datetime(1971, 1, 1, 1, 0, 0, 0))
 def test_cast_truish_to_boolean(self):
     """A non-zero integer (even negative) casts to True."""
     result = cast_to_boolean(-1, IntegerType(), options=BASE_OPTIONS)
     self.assertEqual(result, True)
 def test_cast_falsish_to_boolean(self):
     """The integer zero casts to False."""
     result = cast_to_boolean(0, IntegerType(), options=BASE_OPTIONS)
     self.assertEqual(result, False)
예제 #9
0
 def test_csv_read_with_inferred_schema(self):
     """Reading the fundings CSV with inferSchema=True yields typed columns."""
     data_path = os.path.join(os.path.dirname(os.path.realpath(__file__)),
                              'data/fundings/')
     df = spark.read.option('inferSchema', True).csv(data_path, header=True)
     self.assertEqual(df.count(), 4)
     # (name, inferred type) pairs, in column order.
     expected_fields = [
         ('permalink', StringType()),
         ('company', StringType()),
         ('numEmps', IntegerType()),
         ('category', StringType()),
         ('city', StringType()),
         ('state', StringType()),
         ('fundedDate', TimestampType()),
         ('raisedAmt', IntegerType()),
         ('raisedCurrency', StringType()),
         ('round', StringType()),
     ]
     self.assertEqual(
         df.schema,
         StructType([StructField(name, data_type)
                     for name, data_type in expected_fields]),
     )
     expected_rows = [
         Row(permalink='mycityfaces', company='MyCityFaces', numEmps=7,
             category='web', city='Scottsdale', state='AZ',
             fundedDate=datetime.datetime(2008, 1, 1, 0, 0),
             raisedAmt=50000, raisedCurrency='USD', round='seed'),
         Row(permalink='flypaper', company='Flypaper', numEmps=None,
             category='web', city='Phoenix', state='AZ',
             fundedDate=datetime.datetime(2008, 2, 1, 0, 0),
             raisedAmt=3000000, raisedCurrency='USD', round='a'),
         Row(permalink='chosenlist-com', company='ChosenList.com', numEmps=5,
             category='web', city='Scottsdale', state='AZ',
             fundedDate=datetime.datetime(2008, 1, 25, 0, 0),
             raisedAmt=233750, raisedCurrency='USD', round='angel'),
         Row(permalink='digg', company='Digg', numEmps=60,
             category='web', city='San Francisco', state='CA',
             fundedDate=datetime.datetime(2006, 12, 1, 0, 0),
             raisedAmt=8500000, raisedCurrency='USD', round='b'),
     ]
     # Rebuild each collected row as a plain Row so comparison ignores
     # any internal row subclass returned by collect().
     self.assertEqual([Row(**r.asDict()) for r in df.collect()],
                      expected_rows)
예제 #10
0
 def output_fields(self, schema):
     """Return the output columns: a non-nullable position and a value field.

     NOTE(review): the value column is typed as bare DataType() — presumably
     a subclass or caller supplies the concrete element type; confirm.
     """
     position_field = StructField('pos', IntegerType(), False)
     value_field = StructField('col', DataType(), False)
     return [position_field, value_field]
예제 #11
0
 def test_session_create_data_frame_from_list_with_schema(self):
     """A DataFrame built from a list of tuples honours the given schema."""
     map_type = MapType(StringType(), IntegerType())
     schema = StructType([StructField('map', map_type, True)])
     df = self.spark.createDataFrame([({'a': 1},)], schema=schema)
     self.assertEqual(df.count(), 1)
     self.assertListEqual(df.collect(), [Row(map={'a': 1})])
     self.assertEqual(df.schema, schema)