def test_cast_to_string_with_udt(self):
    """Cast UDT-typed columns to string and compare against the expected text.

    Covers ExamplePointUDT, whose expected cast text is '(1.0, 2.0)', and
    PythonOnlyUDT, whose expected cast text is '[3.0, 4.0]'.
    """
    from pyspark.sql.functions import col

    data = [(ExamplePoint(1.0, 2.0), PythonOnlyPoint(3.0, 4.0))]
    fields = [
        StructField("point", ExamplePointUDT(), False),
        StructField("pypoint", PythonOnlyUDT(), False),
    ]
    frame = self.spark.createDataFrame(data, StructType(fields))

    as_strings = [col(name).cast('string') for name in ('point', 'pypoint')]
    first_row = frame.select(*as_strings).head()

    expected = Row(point=u'(1.0, 2.0)', pypoint=u'[3.0, 4.0]')
    self.assertEqual(first_row, expected)
def test_complex_nested_udt_in_df(self):
    """A PythonOnlyUDT column survives grouping and a UDF returning nested UDTs.

    Builds a DataFrame with a (key, PythonOnlyUDT) schema and groups it with
    ``collect_list``. It then applies a UDF whose return type is an array of
    structs shaped like the original schema, so the UDT ends up nested inside
    struct-inside-array.
    """
    from pyspark.sql.functions import udf

    schema = StructType().add("key", LongType()).add("val", PythonOnlyUDT())
    df = self.spark.createDataFrame(
        [(i % 3, PythonOnlyPoint(float(i), float(i))) for i in range(10)],
        schema=schema)
    df.collect()

    gd = df.groupby("key").agg({"val": "collect_list"})
    gd.collect()

    # Bind the UDF to its own name instead of shadowing the imported `udf`
    # factory. The return type is array<struct<key, val: PythonOnlyUDT>>.
    pair_udf = udf(lambda k, v: [(k, v[0])], ArrayType(df.schema))
    # `*gd` expands to the DataFrame's columns (key, collect_list(val)).
    rows = gd.select(pair_udf(*gd)).collect()

    # range(10) with i % 3 gives exactly three groups: keys 0, 1 and 2.
    self.assertEqual(len(rows), 3)
    self.assertEqual(sorted(r[0][0].key for r in rows), [0, 1, 2])
def test_udt(self):
    """Check UDT round-trips, type inference and type verification.

    For both ExamplePointUDT and PythonOnlyUDT:
    * the UDT alone, and a struct containing it, survive a pickle round-trip
      and a JSON round-trip through the JVM's ``parseDataType``;
    * ``_infer_type`` maps a point instance to its UDT;
    * the UDT's type verifier accepts a point and rejects a plain list with
      ValueError.
    """
    from pyspark.sql.types import _parse_datatype_json_string, _infer_type, _make_type_verifier

    def check_datatype(datatype):
        # Use unittest assertions rather than bare `assert`; the latter are
        # stripped under `python -O`, which would make this check a no-op.
        pickled = pickle.loads(pickle.dumps(datatype))
        self.assertEqual(datatype, pickled)
        scala_datatype = self.spark._jsparkSession.parseDataType(datatype.json())
        python_datatype = _parse_datatype_json_string(scala_datatype.json())
        self.assertEqual(datatype, python_datatype)

    cases = [
        (ExamplePointUDT, ExamplePoint),
        (PythonOnlyUDT, PythonOnlyPoint),
    ]
    for udt_class, point_class in cases:
        with self.subTest(udt=udt_class.__name__):
            check_datatype(udt_class())
            structtype_with_udt = StructType([
                StructField("label", DoubleType(), False),
                StructField("point", udt_class(), False),
            ])
            check_datatype(structtype_with_udt)

            p = point_class(1.0, 2.0)
            self.assertEqual(_infer_type(p), udt_class())

            _make_type_verifier(udt_class())(point_class(1.0, 2.0))
            with self.assertRaises(ValueError):
                _make_type_verifier(udt_class())([1.0, 2.0])
def test_infer_schema_with_udt(self):
    """Check inferred UDT types and a SQL read-back of the stored point.

    For each (point class, UDT class) pair, a DataFrame is built from a Row
    holding a point. The test checks that the inferred "point" field has the
    expected UDT type. It then checks that selecting the point back through a
    temp view returns an equal point.
    """
    pairs = (
        (ExamplePoint, ExamplePointUDT),
        (PythonOnlyPoint, PythonOnlyUDT),
    )
    for point_cls, udt_cls in pairs:
        source = Row(label=1.0, point=point_cls(1.0, 2.0))
        frame = self.spark.createDataFrame([source])
        inferred = frame.schema
        point_field = [fld for fld in inferred.fields if fld.name == "point"][0]
        self.assertEqual(type(point_field.dataType), udt_cls)

        with self.tempView("labeled_point"):
            frame.createOrReplaceTempView("labeled_point")
            fetched = self.spark.sql("SELECT point FROM labeled_point").head().point
            self.assertEqual(fetched, point_cls(1.0, 2.0))
def myudf(x):
    """Return a PythonOnlyPoint at (x, x) when ``x > 0``; otherwise return None."""
    return PythonOnlyPoint(float(x), float(x)) if x > 0 else None
def test_simple_udt_in_df(self):
    """Create and collect a DataFrame that has a PythonOnlyUDT column."""
    schema = (StructType()
              .add("key", LongType())
              .add("val", PythonOnlyUDT()))
    data = []
    for n in range(10):
        coord = float(n)
        data.append((n % 3, PythonOnlyPoint(coord, coord)))
    frame = self.spark.createDataFrame(data, schema=schema)
    frame.collect()