def test_alias__long_bigint():
    schema = StructType.fromDDL('i1: long, i2: bigint')
    assert schema == StructType([
        StructField('i1', LongType(), True),
        StructField('i2', LongType(), True),
    ])
    assert str(schema) == 'StructType(List(StructField(i1,LongType,true),StructField(i2,LongType,true)))'

def test_double_string():
    schema = StructType.fromDDL("a DOUBLE, b STRING")
    assert schema == StructType([
        StructField('a', DoubleType(), True),
        StructField('b', StringType(), True),
    ])
    assert str(schema) == 'StructType(List(StructField(a,DoubleType,true),StructField(b,StringType,true)))'

def test_byte_decimal():
    schema = StructType.fromDDL("a: byte, b: decimal( 16 , 8 ) ")
    assert schema == StructType([
        StructField('a', ByteType(), True),
        StructField('b', DecimalType(16, 8), True),
    ])
    assert str(schema) == 'StructType(List(StructField(a,ByteType,true),StructField(b,DecimalType(16,8),true)))'

def test_nested_array():
    schema = StructType.fromDDL('some_str: string, arr: array<array<string>>')
    assert schema == StructType([
        StructField('some_str', StringType(), True),
        StructField('arr', ArrayType(ArrayType(StringType())), True),
    ])
    assert str(schema) == (
        'StructType(List('
        'StructField(some_str,StringType,true),'
        'StructField(arr,ArrayType(ArrayType(StringType,true),true),true)'
        '))'
    )

def test_basic_entries():
    schema = StructType.fromDDL('some_str: string, some_int: integer, some_date: date')
    assert schema == StructType([
        StructField('some_str', StringType(), True),
        StructField('some_int', IntegerType(), True),
        StructField('some_date', DateType(), True),
    ])
    assert str(schema) == (
        'StructType(List('
        'StructField(some_str,StringType,true),'
        'StructField(some_int,IntegerType,true),'
        'StructField(some_date,DateType,true)'
        '))'
    )

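# The tests above exercise both DDL spellings: the colon form ('i1: long') and
# the space form ("a DOUBLE"). A minimal sketch asserting the two spellings
# parse to the same schema, assuming fromDDL treats them interchangeably as
# those tests suggest:
def test_colon_and_space_syntax_agree():
    assert (StructType.fromDDL('a: int, b: string')
            == StructType.fromDDL('a INT, b STRING'))
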
def test_session_create_data_frame_from_pandas_data_frame(self):
    try:
        # Pandas is an optional dependency
        # pylint: disable=import-outside-toplevel
        import pandas as pd
    except ImportError as e:
        raise ImportError("pandas is not importable") from e

    pdf = pd.DataFrame([(1, "one"), (2, "two"), (3, "three")])
    df = self.spark.createDataFrame(pdf)
    self.assertEqual(df.count(), 3)
    self.assertListEqual(df.collect(), [
        Row(**{"0": 1, "1": 'one'}),
        Row(**{"0": 2, "1": 'two'}),
        Row(**{"0": 3, "1": 'three'}),
    ])
    self.assertEqual(
        df.schema,
        StructType([
            StructField("0", LongType(), True),
            StructField("1", StringType(), True),
        ]))

def test_session_create_data_frame_from_list_with_schema(self):
    schema = StructType([
        StructField("map", MapType(StringType(), IntegerType()), True),
    ])
    df = self.spark.createDataFrame([({'a': 1},)], schema=schema)
    self.assertEqual(df.count(), 1)
    self.assertListEqual(df.collect(), [Row(map={'a': 1})])
    self.assertEqual(df.schema, schema)

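# A sketch of the same round trip with the schema written as DDL instead of
# constructed by hand; the field name 'm' and the assumption that fromDDL
# accepts map<...> in field position are both hypothetical here.
def test_session_create_data_frame_with_ddl_schema(self):
    schema = StructType.fromDDL('m: map<string, int>')
    df = self.spark.createDataFrame([({'a': 1},)], schema=schema)
    self.assertListEqual(df.collect(), [Row(m={'a': 1})])
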
def test_csv_read_with_given_schema(self):
    schema = StructType([
        StructField("permalink", StringType()),
        StructField("company", StringType()),
        StructField("numEmps", IntegerType()),
        StructField("category", StringType()),
        StructField("city", StringType()),
        StructField("state", StringType()),
        StructField("fundedDate", DateType()),
        StructField("raisedAmt", IntegerType()),
        StructField("raisedCurrency", StringType()),
        StructField("round", StringType()),
    ])
    df = spark.read.schema(schema).csv(
        os.path.join(os.path.dirname(os.path.realpath(__file__)), "data/fundings/"),
        header=True,
    )
    self.assertEqual([Row(**r.asDict()) for r in df.collect()], [
        Row(permalink='mycityfaces', company='MyCityFaces', numEmps=7, category='web',
            city='Scottsdale', state='AZ', fundedDate=datetime.date(2008, 1, 1),
            raisedAmt=50000, raisedCurrency='USD', round='seed'),
        Row(permalink='flypaper', company='Flypaper', numEmps=None, category='web',
            city='Phoenix', state='AZ', fundedDate=datetime.date(2008, 2, 1),
            raisedAmt=3000000, raisedCurrency='USD', round='a'),
        Row(permalink='chosenlist-com', company='ChosenList.com', numEmps=5, category='web',
            city='Scottsdale', state='AZ', fundedDate=datetime.date(2008, 1, 25),
            raisedAmt=233750, raisedCurrency='USD', round='angel'),
        Row(permalink='digg', company='Digg', numEmps=60, category='web',
            city='San Francisco', state='CA', fundedDate=datetime.date(2006, 12, 1),
            raisedAmt=8500000, raisedCurrency='USD', round='b'),
    ])

def test_session_create_data_frame_from_list_with_col_names(self):
    df = self.spark.createDataFrame(
        [(0.0, [1.0, 0.8]),
         (1.0, [0.0, 0.0]),
         (2.0, [0.5, 0.5])],
        ["label", "features"],
    )
    self.assertEqual(df.count(), 3)
    self.assertListEqual(df.collect(), [
        row_from_keyed_values([("label", 0.0), ("features", [1.0, 0.8])]),
        row_from_keyed_values([("label", 1.0), ("features", [0.0, 0.0])]),
        row_from_keyed_values([("label", 2.0), ("features", [0.5, 0.5])]),
    ])
    self.assertEqual(
        df.schema,
        StructType([
            StructField("label", DoubleType(), True),
            StructField("features", ArrayType(DoubleType(), True), True),
        ]))

def test_session_create_data_frame_from_list(self):
    df = self.spark.createDataFrame([
        (1, "one"),
        (2, "two"),
        (3, "three"),
    ])
    self.assertEqual(df.count(), 3)
    self.assertListEqual(
        df.collect(),
        [Row(_1=1, _2='one'), Row(_1=2, _2='two'), Row(_1=3, _2='three')],
    )
    self.assertEqual(
        df.schema,
        StructType([
            StructField("_1", LongType(), True),
            StructField("_2", StringType(), True),
        ]))

def test_column_stat_helper():
    """
    Expected quantile values come from use of
    org.apache.spark.sql.catalyst.util.QuantileSummaries
    """
    schema = StructType([StructField("value", IntegerType())])
    helper = ColumnStatHelper(col("value"))
    for i in range(1, 100001):
        helper.merge(Row(value=i), schema)
    helper.finalize()
    assert helper.count == 100000
    assert helper.min == 1
    assert helper.max == 100000
    assert helper.mean == 50000.5
    assert helper.stddev == 28867.65779668774  # sample standard deviation
    assert helper.get_quantile(0) == 1
    assert helper.get_quantile(0.25) == 24998
    assert helper.get_quantile(0.5) == 50000
    assert helper.get_quantile(0.75) == 74993
    assert helper.get_quantile(1) == 100000

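# Cross-check for the stddev constant above: for the values 1..n the sample
# standard deviation has the closed form sqrt(n * (n + 1) / 12), since the sum
# of squared deviations from the mean is n * (n**2 - 1) / 12 and the sample
# variance divides that by n - 1. A minimal sketch of the identity:
def test_column_stat_helper_stddev_closed_form():
    import math
    n = 100000
    assert abs(math.sqrt(n * (n + 1) / 12) - 28867.65779668774) < 1e-6
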
# Mapping of DDL type strings to their expected parsed types. The surrounding
# context is truncated, so some pieces are reconstructed: DDL_TYPE_SCENARIOS
# is a hypothetical name for the dict, the 'decimal' key is assumed for the
# DecimalType(10, 0) entry since (10, 0) is the default precision and scale,
# and the final entry's timestamp field is completed from its key.
DDL_TYPE_SCENARIOS = {
    'decimal': DecimalType(10, 0),
    "Array<string>": ArrayType(StringType()),
    "Array<doublE>": ArrayType(DoubleType()),
    "Array<int>": ArrayType(IntegerType()),
    "Array<map<int, tinYint>>": ArrayType(MapType(IntegerType(), ByteType())),
    "Map<string, int>": MapType(StringType(), IntegerType()),
    "Map < integer, String >": MapType(IntegerType(), StringType()),
    "Struct<name: string, age: int>": StructType([
        StructField(name="name", dataType=StringType()),
        StructField(name="age", dataType=IntegerType()),
    ]),
    "array<struct<tinYint:tinyint>>": ArrayType(StructType([StructField('tinYint', ByteType())])),
    "MAp<int, ARRAY<double>>": MapType(IntegerType(), ArrayType(DoubleType())),
    "MAP<int, struct<varchar:string>>": MapType(IntegerType(), StructType([StructField("varchar", StringType())])),
    "struct<intType: int, ts:timestamp>": StructType([
        StructField("intType", IntegerType()),
        StructField("ts", TimestampType()),
    ]),
    "Struct<int: int, timestamp:timestamp>": StructType([
        StructField("int", IntegerType()),
        StructField("timestamp", TimestampType()),
    ]),
}

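# A sketch of how the mapping above might drive a parametrized test; the
# wrapper field name 'f' and the assumption that fromDDL parses each of these
# strings in field-type position are both hypothetical.
@pytest.mark.parametrize('ddl_string, expected_type', DDL_TYPE_SCENARIOS.items())
def test_types_from_ddl(ddl_string, expected_type):
    assert (StructType.fromDDL(f'f: {ddl_string}')
            == StructType([StructField('f', expected_type, True)]))
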
def test_too_much_closed_map():
    with pytest.raises(ParseException):
        StructType.fromDDL("map<int, boolean>>")

def test_wrong_type():
    with pytest.raises(ParseException):
        StructType.fromDDL("blabla")

def test_array_short():
    schema = StructType.fromDDL("a: array< short>")
    assert schema == StructType([
        StructField('a', ArrayType(ShortType()), True),
    ])
    assert str(schema) == 'StructType(List(StructField(a,ArrayType(ShortType,true),true)))'

def test_comma_at_end():
    with pytest.raises(ParseException):
        StructType.fromDDL("a: int,")

def test_csv_read_without_schema(self):
    df = spark.read.csv(
        os.path.join(os.path.dirname(os.path.realpath(__file__)), "data/fundings/"),
        header=True,
    )
    self.assertEqual(df.count(), 4)
    self.assertEqual(
        df.schema,
        StructType([
            StructField("permalink", StringType()),
            StructField("company", StringType()),
            StructField("numEmps", StringType()),
            StructField("category", StringType()),
            StructField("city", StringType()),
            StructField("state", StringType()),
            StructField("fundedDate", StringType()),
            StructField("raisedAmt", StringType()),
            StructField("raisedCurrency", StringType()),
            StructField("round", StringType()),
        ]))
    self.assertListEqual([Row(**r.asDict()) for r in df.collect()], [
        Row(permalink='mycityfaces', company='MyCityFaces', numEmps='7', category='web',
            city='Scottsdale', state='AZ', fundedDate='2008-01-01', raisedAmt='50000',
            raisedCurrency='USD', round='seed'),
        Row(permalink='flypaper', company='Flypaper', numEmps=None, category='web',
            city='Phoenix', state='AZ', fundedDate='2008-02-01', raisedAmt='3000000',
            raisedCurrency='USD', round='a'),
        Row(permalink='chosenlist-com', company='ChosenList.com', numEmps='5', category='web',
            city='Scottsdale', state='AZ', fundedDate='2008-01-25', raisedAmt='233750',
            raisedCurrency='USD', round='angel'),
        Row(permalink='digg', company='Digg', numEmps='60', category='web',
            city='San Francisco', state='CA', fundedDate='2006-12-01', raisedAmt='8500000',
            raisedCurrency='USD', round='b'),
    ])

import pytest

from gelanis import Row
from gelanis.sql._ast.ast_to_python import parse_expression
from gelanis.sql.types import IntegerType, StructField, StructType

ROW = Row(a=1, b=2, c=3)
SCHEMA = StructType([
    StructField("a", IntegerType()),
    StructField("b", IntegerType()),
])

SCENARIOS = {
    'Least(-1,0,1)': ('least', 'least(-1, 0, 1)', -1),
    'GREATEST(-1,0,1)': ('greatest', 'greatest(-1, 0, 1)', 1),
    'shiftRight ( 42, 1 )': ('shiftright', 'shiftright(42, 1)', 21),
    'ShiftLeft ( 42, 1 )': ('shiftleft', 'shiftleft(42, 1)', 84),
    "concat_ws('/', a, b )": ('concat_ws', 'concat_ws(/, a, b)', "1/2"),
    'instr(a, a)': ('instr', 'instr(a, a)', 1),  # relies on columns
    'instr(a, b)': ('instr', 'instr(a, b)', 0),  # relies on columns
    "instr('abc', 'c')": ('instr', 'instr(abc, c)', 3),  # relies on lit
}


@pytest.mark.parametrize('string, expected', SCENARIOS.items())
def test_functions(string, expected):
    operator, expected_parsed, expected_result = expected
    actual_parsed = parse_expression(string, True)
    assert expected_parsed == str(actual_parsed)
    assert operator == actual_parsed.pretty_name
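
# ROW, SCHEMA and the expected_result values above are not exercised by
# test_functions; a sketch of the evaluation check they appear intended for,
# assuming parsed expressions expose an eval(row, schema) method (not verified
# against the gelanis API).
@pytest.mark.parametrize('string, expected', SCENARIOS.items())
def test_function_results(string, expected):
    _operator, _parsed, expected_result = expected
    assert parse_expression(string, True).eval(ROW, SCHEMA) == expected_result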