def test_udt(self):
    """Check UDT data types for round-tripping, inference and verification.

    For both ``ExamplePointUDT`` and ``PythonOnlyUDT`` this verifies that:

    * the bare UDT and a struct containing it survive a pickle round trip
      and a JSON round trip through ``_jsparkSession.parseDataType``;
    * ``_infer_type`` maps a point instance to its UDT;
    * ``_make_type_verifier`` accepts a point instance and raises
      ``ValueError`` for a bare list of coordinates.
    """
    from pyspark.sql.types import _parse_datatype_json_string, _infer_type, _make_type_verifier

    def check_datatype(datatype):
        # unittest assertions are used instead of bare ``assert`` so the checks
        # still run under ``python -O`` and report both operands on failure.
        pickled = pickle.loads(pickle.dumps(datatype))
        self.assertEqual(datatype, pickled)
        scala_datatype = self.spark._jsparkSession.parseDataType(datatype.json())
        python_datatype = _parse_datatype_json_string(scala_datatype.json())
        self.assertEqual(datatype, python_datatype)

    udt_cases = [
        (ExamplePointUDT, ExamplePoint),
        (PythonOnlyUDT, PythonOnlyPoint),
    ]
    for udt_cls, point_cls in udt_cases:
        check_datatype(udt_cls())
        structtype_with_udt = StructType([StructField("label", DoubleType(), False),
                                          StructField("point", udt_cls(), False)])
        check_datatype(structtype_with_udt)

        p = point_cls(1.0, 2.0)
        self.assertEqual(_infer_type(p), udt_cls())

        # A genuine point instance must pass verification without raising.
        _make_type_verifier(udt_cls())(point_cls(1.0, 2.0))
        # A bare list of coordinates must be rejected.
        with self.assertRaises(ValueError):
            _make_type_verifier(udt_cls())([1.0, 2.0])
def test_cast_to_udt_with_udt(self):
    """Casting a UDT column to a different UDT must raise AnalysisException."""
    from pyspark.sql.functions import col

    points = Row(point=ExamplePoint(1.0, 2.0), python_only_point=PythonOnlyPoint(1.0, 2.0))
    frame = self.spark.createDataFrame([points])

    # ExamplePoint column -> PythonOnlyUDT
    with self.assertRaises(AnalysisException):
        frame.select(col("point").cast(PythonOnlyUDT()))

    # PythonOnlyPoint column -> ExamplePointUDT
    with self.assertRaises(AnalysisException):
        frame.select(col("python_only_point").cast(ExamplePointUDT()))
def test_cast_to_string_with_udt(self):
    """Casting UDT columns to string yields the expected text for each point."""
    from pyspark.sql.functions import col

    fields = [
        StructField("point", ExamplePointUDT(), False),
        StructField("pypoint", PythonOnlyUDT(), False),
    ]
    data = [(ExamplePoint(1.0, 2.0), PythonOnlyPoint(3.0, 4.0))]
    frame = self.spark.createDataFrame(data, StructType(fields))

    as_strings = frame.select(col('point').cast('string'), col('pypoint').cast('string'))
    expected = Row(point=u'(1.0, 2.0)', pypoint=u'[3.0, 4.0]')
    self.assertEqual(as_strings.head(), expected)
def test_apply_schema_with_udt(self):
    """A DataFrame built with an explicit UDT schema returns the original point.

    Runs the same check for ``ExamplePoint`` and then ``PythonOnlyPoint``.
    """
    cases = (
        (ExamplePoint, ExamplePointUDT),
        (PythonOnlyPoint, PythonOnlyUDT),
    )
    for point_cls, udt_cls in cases:
        fields = [
            StructField("label", DoubleType(), False),
            StructField("point", udt_cls(), False),
        ]
        frame = self.spark.createDataFrame([(1.0, point_cls(1.0, 2.0))], StructType(fields))
        self.assertEqual(frame.head().point, point_cls(1.0, 2.0))
def test_union_with_udt(self):
    """Union of two single-row DataFrames sharing a UDT schema keeps both points."""
    schema = StructType([
        StructField("label", DoubleType(), False),
        StructField("point", ExamplePointUDT(), False),
    ])
    first = self.spark.createDataFrame([(1.0, ExamplePoint(1.0, 2.0))], schema)
    second = self.spark.createDataFrame([(2.0, ExamplePoint(3.0, 4.0))], schema)

    # Order by label so the collected rows come back in a deterministic order.
    combined = first.union(second).orderBy("label")
    expected = [
        Row(label=1.0, point=ExamplePoint(1.0, 2.0)),
        Row(label=2.0, point=ExamplePoint(3.0, 4.0)),
    ]
    self.assertEqual(combined.collect(), expected)
def test_udf_with_udt(self):
    """UDT values can be read in RDD maps, passed into UDFs and returned from UDFs."""
    # --- ExamplePoint / ExamplePointUDT ---
    example_df = self.spark.createDataFrame([Row(label=1.0, point=ExamplePoint(1.0, 2.0))])
    first_x = example_df.rdd.map(lambda r: r.point.x).first()
    self.assertEqual(1.0, first_x)

    take_y = UserDefinedFunction(lambda p: p.y, DoubleType())
    self.assertEqual(2.0, example_df.select(take_y(example_df.point)).first()[0])

    shift = UserDefinedFunction(lambda p: ExamplePoint(p.x + 1, p.y + 1), ExamplePointUDT())
    shifted = example_df.select(shift(example_df.point)).first()[0]
    self.assertEqual(ExamplePoint(2.0, 3.0), shifted)

    # --- PythonOnlyPoint / PythonOnlyUDT ---
    py_df = self.spark.createDataFrame([Row(label=1.0, point=PythonOnlyPoint(1.0, 2.0))])
    first_x = py_df.rdd.map(lambda r: r.point.x).first()
    self.assertEqual(1.0, first_x)

    take_y = UserDefinedFunction(lambda p: p.y, DoubleType())
    self.assertEqual(2.0, py_df.select(take_y(py_df.point)).first()[0])

    shift = UserDefinedFunction(lambda p: PythonOnlyPoint(p.x + 1, p.y + 1), PythonOnlyUDT())
    shifted = py_df.select(shift(py_df.point)).first()[0]
    self.assertEqual(PythonOnlyPoint(2.0, 3.0), shifted)
def test_verify_type_not_nullable(self):
    """Exercise ``_make_type_verifier(..., nullable=False)`` over a table of cases.

    ``success_spec`` lists ``(obj, data_type)`` pairs that must verify without
    raising. ``failure_spec`` lists ``(obj, data_type, exception_class)``
    triples where verification must raise the given exception class.
    """
    import array
    import datetime
    import decimal

    schema = StructType([
        StructField('s', StringType(), nullable=False),
        StructField('i', IntegerType(), nullable=True)])

    class MyObj:
        """Plain object whose attributes are set from keyword arguments."""

        def __init__(self, **kwargs):
            for k, v in kwargs.items():
                setattr(self, k, v)

    # obj, data_type
    success_spec = [
        # String
        ("", StringType()),
        (u"", StringType()),
        (1, StringType()),
        (1.0, StringType()),
        ([], StringType()),
        ({}, StringType()),

        # UDT
        (ExamplePoint(1.0, 2.0), ExamplePointUDT()),

        # Boolean
        (True, BooleanType()),

        # Byte
        (-(2**7), ByteType()),
        (2**7 - 1, ByteType()),

        # Short
        (-(2**15), ShortType()),
        (2**15 - 1, ShortType()),

        # Integer
        (-(2**31), IntegerType()),
        (2**31 - 1, IntegerType()),

        # Long
        (-(2**63), LongType()),
        (2**63 - 1, LongType()),

        # Float & Double
        (1.0, FloatType()),
        (1.0, DoubleType()),

        # Decimal
        (decimal.Decimal("1.0"), DecimalType()),

        # Binary
        (bytearray([1, 2]), BinaryType()),

        # Date/Timestamp
        (datetime.date(2000, 1, 2), DateType()),
        (datetime.datetime(2000, 1, 2, 3, 4), DateType()),
        (datetime.datetime(2000, 1, 2, 3, 4), TimestampType()),

        # Array
        ([], ArrayType(IntegerType())),
        (["1", None], ArrayType(StringType(), containsNull=True)),
        ([1, 2], ArrayType(IntegerType())),
        ((1, 2), ArrayType(IntegerType())),
        (array.array('h', [1, 2]), ArrayType(IntegerType())),

        # Map
        ({}, MapType(StringType(), IntegerType())),
        ({"a": 1}, MapType(StringType(), IntegerType())),
        ({"a": None}, MapType(StringType(), IntegerType(), valueContainsNull=True)),

        # Struct
        ({"s": "a", "i": 1}, schema),
        ({"s": "a", "i": None}, schema),
        ({"s": "a"}, schema),
        ({"s": "a", "f": 1.0}, schema),
        (Row(s="a", i=1), schema),
        (Row(s="a", i=None), schema),
        (Row(s="a", i=1, f=1.0), schema),
        (["a", 1], schema),
        (["a", None], schema),
        (("a", 1), schema),
        (MyObj(s="a", i=1), schema),
        (MyObj(s="a", i=None), schema),
        (MyObj(s="a"), schema),
    ]

    # obj, data_type, exception class
    failure_spec = [
        # String (match anything but None)
        (None, StringType(), ValueError),

        # UDT
        (ExamplePoint(1.0, 2.0), PythonOnlyUDT(), ValueError),

        # Boolean
        (1, BooleanType(), TypeError),
        ("True", BooleanType(), TypeError),
        ([1], BooleanType(), TypeError),

        # Byte
        (-(2**7) - 1, ByteType(), ValueError),
        (2**7, ByteType(), ValueError),
        ("1", ByteType(), TypeError),
        (1.0, ByteType(), TypeError),

        # Short
        (-(2**15) - 1, ShortType(), ValueError),
        (2**15, ShortType(), ValueError),

        # Integer
        (-(2**31) - 1, IntegerType(), ValueError),
        (2**31, IntegerType(), ValueError),

        # Float & Double
        (1, FloatType(), TypeError),
        (1, DoubleType(), TypeError),

        # Decimal
        (1.0, DecimalType(), TypeError),
        (1, DecimalType(), TypeError),
        ("1.0", DecimalType(), TypeError),

        # Binary
        (1, BinaryType(), TypeError),

        # Date/Timestamp
        ("2000-01-02", DateType(), TypeError),
        (946811040, TimestampType(), TypeError),

        # Array
        (["1", None], ArrayType(StringType(), containsNull=False), ValueError),
        ([1, "2"], ArrayType(IntegerType()), TypeError),

        # Map
        ({"a": 1}, MapType(IntegerType(), IntegerType()), TypeError),
        ({"a": "1"}, MapType(StringType(), IntegerType()), TypeError),
        ({"a": None}, MapType(StringType(), IntegerType(), valueContainsNull=False),
         ValueError),

        # Struct
        ({"s": "a", "i": "1"}, schema, TypeError),
        (Row(s="a"), schema, ValueError),     # Row can't have missing field
        (Row(s="a", i="1"), schema, TypeError),
        (["a"], schema, ValueError),
        (["a", "1"], schema, TypeError),
        (MyObj(s="a", i="1"), schema, TypeError),
        (MyObj(s=None, i="1"), schema, ValueError),
    ]

    # Check success cases
    for obj, data_type in success_spec:
        try:
            _make_type_verifier(data_type, nullable=False)(obj)
        except Exception as e:
            # Report the underlying error as well; without it a failing case
            # only says which pair failed, not why.
            self.fail("verify_type(%s, %s, nullable=False) raised %s: %s"
                      % (obj, data_type, type(e).__name__, e))

    # Check failure cases
    for obj, data_type, exp in failure_spec:
        msg = "verify_type(%s, %s, nullable=False) == %s" % (obj, data_type, exp)
        with self.assertRaises(exp, msg=msg):
            _make_type_verifier(data_type, nullable=False)(obj)