예제 #1
0
def test_alias__long_bigint():
    """Both the `long` and `bigint` DDL spellings parse to a nullable LongType field."""
    parsed = StructType.fromDDL('i1: long, i2: bigint')
    expected = StructType([
        StructField(field_name, LongType(), True) for field_name in ('i1', 'i2')
    ])
    assert parsed == expected
    assert str(parsed) == 'StructType(List(StructField(i1,LongType,true),StructField(i2,LongType,true)))'
예제 #2
0
def test_double_string():
    """Upper-case type names in space-separated `name TYPE` form parse to Double/String."""
    parsed = StructType.fromDDL("a DOUBLE, b STRING")
    name_and_type = [('a', DoubleType()), ('b', StringType())]
    assert parsed == StructType([StructField(n, t, True) for n, t in name_and_type])
    expected_repr = 'StructType(List(StructField(a,DoubleType,true),StructField(b,StringType,true)))'
    assert str(parsed) == expected_repr
예제 #3
0
def test_byte_decimal():
    """Padding inside `decimal( p , s )` and a trailing space in the DDL are accepted."""
    ddl = "a: byte, b: decimal(  16 , 8   ) "
    parsed = StructType.fromDDL(ddl)
    expected_fields = [
        StructField('a', ByteType(), True),
        StructField('b', DecimalType(16, 8), True),
    ]
    assert parsed == StructType(expected_fields)
    assert str(parsed) == (
        'StructType(List(StructField(a,ByteType,true),'
        'StructField(b,DecimalType(16,8),true)))'
    )
예제 #4
0
def test_nested_array():
    """`array<array<string>>` keeps both ArrayType layers around StringType."""
    parsed = StructType.fromDDL('some_str: string, arr: array<array<string>>')
    inner_array = ArrayType(StringType())
    assert parsed == StructType([
        StructField('some_str', StringType(), True),
        StructField('arr', ArrayType(inner_array), True),
    ])
    assert str(parsed) == (
        'StructType(List('
        'StructField(some_str,StringType,true),'
        'StructField(arr,ArrayType(ArrayType(StringType,true),true),true)'
        '))'
    )
예제 #5
0
def test_basic_entries():
    """string/integer/date columns in a colon-separated DDL become nullable fields."""
    parsed = StructType.fromDDL('some_str: string, some_int: integer, some_date: date')
    expected = StructType([
        StructField(field_name, data_type, True)
        for field_name, data_type in (
            ('some_str', StringType()),
            ('some_int', IntegerType()),
            ('some_date', DateType()),
        )
    ])
    assert parsed == expected
    expected_repr = (
        'StructType(List('
        'StructField(some_str,StringType,true),'
        'StructField(some_int,IntegerType,true),'
        'StructField(some_date,DateType,true)'
        '))'
    )
    assert str(parsed) == expected_repr
예제 #6
0
    def test_session_create_data_frame_from_pandas_data_frame(self):
        """Build a DataFrame from a pandas DataFrame that has default column labels.

        The pandas frame is created without explicit column names. The resulting
        DataFrame is expected to expose its columns as the string fields "0"
        (LongType) and "1" (StringType), as asserted on ``df.schema`` below.
        """
        try:
            # Pandas is an optional dependency
            # pylint: disable=import-outside-toplevel
            import pandas as pd
        except ImportError as e:
            # NOTE(review): this fails (rather than skips) the test when pandas
            # is missing; consider self.skipTest if pandas should stay optional.
            raise ImportError("pandas is not importable") from e

        pdf = pd.DataFrame([(1, "one"), (2, "two"), (3, "three")])

        df = self.spark.createDataFrame(pdf)

        self.assertEqual(df.count(), 3)
        # Every expected row uses the same field names as the schema ("0", "1").
        self.assertListEqual(df.collect(), [
            Row(**{
                "0": 1,
                "1": 'one'
            }),
            Row(**{
                "0": 2,
                "1": 'two'
            }),
            Row(**{
                "0": 3,
                "1": 'three'
            })
        ])
        self.assertEqual(
            df.schema,
            StructType([
                StructField("0", LongType(), True),
                StructField("1", StringType(), True)
            ]))
예제 #7
0
 def test_session_create_data_frame_from_list_with_schema(self):
     """A map column declared through an explicit schema round-trips unchanged."""
     map_field = StructField("map", MapType(StringType(), IntegerType()), True)
     schema = StructType([map_field])
     only_row = ({'a': 1}, )
     df = self.spark.createDataFrame([only_row], schema=schema)
     self.assertEqual(df.count(), 1)
     self.assertListEqual(df.collect(), [Row(map={'a': 1})])
     self.assertEqual(df.schema, schema)
예제 #8
0
 def test_csv_read_with_given_schema(self):
     """Read the fundings CSV with an explicit schema and check the typed rows.

     numEmps and raisedAmt are declared IntegerType and fundedDate DateType, so
     the rows are expected to carry ints and datetime.date values. The
     'flypaper' row is expected to have numEmps=None.
     """
     column_spec = (
         ("permalink", StringType()),
         ("company", StringType()),
         ("numEmps", IntegerType()),
         ("category", StringType()),
         ("city", StringType()),
         ("state", StringType()),
         ("fundedDate", DateType()),
         ("raisedAmt", IntegerType()),
         ("raisedCurrency", StringType()),
         ("round", StringType()),
     )
     schema = StructType([StructField(name, data_type) for name, data_type in column_spec])
     field_names = [name for name, _ in column_spec]

     data_dir = os.path.join(
         os.path.dirname(os.path.realpath(__file__)), "data/fundings/")
     df = spark.read.schema(schema).csv(data_dir, header=True)

     # One tuple per expected row, in the same column order as field_names.
     expected_values = [
         ('mycityfaces', 'MyCityFaces', 7, 'web', 'Scottsdale', 'AZ',
          datetime.date(2008, 1, 1), 50000, 'USD', 'seed'),
         ('flypaper', 'Flypaper', None, 'web', 'Phoenix', 'AZ',
          datetime.date(2008, 2, 1), 3000000, 'USD', 'a'),
         ('chosenlist-com', 'ChosenList.com', 5, 'web', 'Scottsdale', 'AZ',
          datetime.date(2008, 1, 25), 233750, 'USD', 'angel'),
         ('digg', 'Digg', 60, 'web', 'San Francisco', 'CA',
          datetime.date(2006, 12, 1), 8500000, 'USD', 'b'),
     ]
     actual_rows = [Row(**r.asDict()) for r in df.collect()]
     expected_rows = [Row(**dict(zip(field_names, values))) for values in expected_values]
     self.assertEqual(actual_rows, expected_rows)
예제 #9
0
    def test_session_create_data_frame_from_list_with_col_names(self):
        """Explicit column names label tuple positions; float lists become arrays of doubles."""
        records = [(0.0, [1.0, 0.8]), (1.0, [0.0, 0.0]), (2.0, [0.5, 0.5])]
        df = self.spark.createDataFrame(records, ["label", "features"])
        self.assertEqual(df.count(), 3)

        expected_rows = [
            row_from_keyed_values([("label", label), ("features", features)])
            for label, features in (
                (0.0, [1.0, 0.8]),
                (1.0, [0.0, 0.0]),
                (2.0, [0.5, 0.5]),
            )
        ]
        self.assertListEqual(df.collect(), expected_rows)

        expected_schema = StructType([
            StructField("label", DoubleType(), True),
            StructField("features", ArrayType(DoubleType(), True), True),
        ])
        self.assertEqual(df.schema, expected_schema)
예제 #10
0
 def test_session_create_data_frame_from_list(self):
     """Tuples given without a schema get positional names _1, _2 and inferred types."""
     records = [(1, "one"), (2, "two"), (3, "three")]
     df = self.spark.createDataFrame(records)
     self.assertEqual(df.count(), 3)
     expected_rows = [Row(_1=number, _2=word) for number, word in records]
     self.assertListEqual(df.collect(), expected_rows)
     expected_schema = StructType([
         StructField("_1", LongType(), True),
         StructField("_2", StringType(), True),
     ])
     self.assertEqual(df.schema, expected_schema)
예제 #11
0
def test_column_stat_helper():
    """
    Check ColumnStatHelper statistics over the integers 1..100000.

    Expected quantile values come from use of org.apache.spark.sql.catalyst.util.QuantileSummaries
    """
    import math  # local import keeps this test self-contained

    schema = StructType([StructField("value", IntegerType())])
    helper = ColumnStatHelper(col("value"))
    for i in range(1, 100001):
        helper.merge(Row(value=i), schema)
    helper.finalize()
    assert helper.count == 100000
    assert helper.min == 1
    assert helper.max == 100000
    # mean and stddev are floating-point aggregates: compare with a tight
    # relative tolerance rather than exact equality, which is sensitive to the
    # accumulation order.
    assert math.isclose(helper.mean, 50000.5, rel_tol=1e-9)
    assert math.isclose(helper.stddev, 28867.65779668774, rel_tol=1e-9)  # sample standard deviation
    assert helper.get_quantile(0) == 1
    assert helper.get_quantile(0.25) == 24998
    assert helper.get_quantile(0.5) == 50000
    assert helper.get_quantile(0.75) == 74993
    assert helper.get_quantile(1) == 100000
예제 #12
0
 DecimalType(10, 0),
 "Array<string>":
 ArrayType(StringType()),
 "Array<doublE>":
 ArrayType(DoubleType()),
 "Array<int>":
 ArrayType(IntegerType()),
 "Array<map<int, tinYint>>":
 ArrayType(MapType(IntegerType(), ByteType())),
 "Map<string, int>":
 MapType(StringType(), IntegerType()),
 "Map < integer, String >":
 MapType(IntegerType(), StringType()),
 "Struct<name: string, age: int>":
 StructType([
     StructField(name="name", dataType=StringType()),
     StructField(name="age", dataType=IntegerType()),
 ]),
 "array<struct<tinYint:tinyint>>":
 ArrayType(StructType([StructField('tinYint', ByteType())]), ),
 "MAp<int, ARRAY<double>>":
 MapType(IntegerType(), ArrayType(DoubleType())),
 "MAP<int, struct<varchar:string>>":
 MapType(IntegerType(), StructType([StructField("varchar", StringType())])),
 "struct<intType: int, ts:timestamp>":
 StructType([
     StructField("intType", IntegerType()),
     StructField("ts", TimestampType())
 ]),
 "Struct<int: int, timestamp:timestamp>":
 StructType([
     StructField("int", IntegerType()),
예제 #13
0
def test_too_much_closed_map():
    """A DDL map type with one closing `>` too many raises ParseException."""
    unbalanced_ddl = "map<int, boolean>>"
    with pytest.raises(ParseException):
        StructType.fromDDL(unbalanced_ddl)
예제 #14
0
def test_wrong_type():
    """A DDL string that is not a recognised type/schema raises ParseException."""
    invalid_ddl = "blabla"
    with pytest.raises(ParseException):
        StructType.fromDDL(invalid_ddl)
예제 #15
0
def test_array_short():
    """Whitespace after `<` is accepted and `short` maps to ShortType."""
    parsed = StructType.fromDDL("a: array< short>")
    expected = StructType([StructField('a', ArrayType(ShortType()), True)])
    assert parsed == expected
    assert str(parsed) == 'StructType(List(StructField(a,ArrayType(ShortType,true),true)))'
예제 #16
0
def test_comma_at_end():
    """A trailing comma after the last field is rejected with ParseException."""
    with pytest.raises(ParseException):
        # The call itself must raise; there is nothing to print or inspect.
        StructType.fromDDL("a: int,")
예제 #17
0
 def test_csv_read_without_schema(self):
     """Without an explicit schema, every column of the fundings CSV is a string.

     The data is expected to hold 4 rows. The 'flypaper' row is expected to have
     numEmps=None.
     """
     data_dir = os.path.join(
         os.path.dirname(os.path.realpath(__file__)), "data/fundings/")
     df = spark.read.csv(data_dir, header=True)
     self.assertEqual(df.count(), 4)

     field_names = ("permalink", "company", "numEmps", "category", "city",
                    "state", "fundedDate", "raisedAmt", "raisedCurrency",
                    "round")
     self.assertEqual(
         df.schema,
         StructType([StructField(name, StringType()) for name in field_names]))

     # One tuple per expected row, in the same column order as field_names.
     expected_values = [
         ('mycityfaces', 'MyCityFaces', '7', 'web', 'Scottsdale', 'AZ',
          '2008-01-01', '50000', 'USD', 'seed'),
         ('flypaper', 'Flypaper', None, 'web', 'Phoenix', 'AZ',
          '2008-02-01', '3000000', 'USD', 'a'),
         ('chosenlist-com', 'ChosenList.com', '5', 'web', 'Scottsdale', 'AZ',
          '2008-01-25', '233750', 'USD', 'angel'),
         ('digg', 'Digg', '60', 'web', 'San Francisco', 'CA',
          '2006-12-01', '8500000', 'USD', 'b'),
     ]
     self.assertListEqual(
         [Row(**r.asDict()) for r in df.collect()],
         [Row(**dict(zip(field_names, values))) for values in expected_values])
예제 #18
0
import pytest

from gelanis import Row
from gelanis.sql._ast.ast_to_python import parse_expression
from gelanis.sql.types import IntegerType, StructField, StructType

# Sample row; only columns "a" and "b" are declared in SCHEMA.
ROW = Row(a=1, b=2, c=3)
SCHEMA = StructType([StructField(column, IntegerType()) for column in ("a", "b")])

# Expression text -> (expected pretty_name, expected str() of the parsed
# expression, expected result).
SCENARIOS = {
    'Least(-1,0,1)': ('least', 'least(-1, 0, 1)', -1),
    'GREATEST(-1,0,1)': ('greatest', 'greatest(-1, 0, 1)', 1),
    'shiftRight ( 42, 1 )': ('shiftright', 'shiftright(42, 1)', 21),
    'ShiftLeft ( 42, 1 )': ('shiftleft', 'shiftleft(42, 1)', 84),
    "concat_ws('/', a, b )": ('concat_ws', 'concat_ws(/, a, b)', "1/2"),
    # The next two reference columns of ROW by name.
    'instr(a, a)': ('instr', 'instr(a, a)', 1),
    'instr(a, b)': ('instr', 'instr(a, b)', 0),
    # This one uses string literals only.
    "instr('abc', 'c')": ('instr', 'instr(abc, c)', 3),
}


@pytest.mark.parametrize('string, expected', SCENARIOS.items())
def test_functions(string, expected):
    """Parse each SCENARIOS expression and check its str() form and pretty_name.

    NOTE(review): ``expected_result`` (and the module-level ROW/SCHEMA) are not
    used by the assertions visible here, so the evaluated value of each
    expression is not checked. TODO: confirm whether an evaluation assertion
    is missing.
    """
    operator, expected_parsed, expected_result = expected
    # The meaning of the positional `True` argument is defined in
    # gelanis.sql._ast.ast_to_python.parse_expression (not visible here).
    actual_parsed = parse_expression(string, True)

    assert expected_parsed == str(actual_parsed)
    assert operator == actual_parsed.pretty_name