def test_udt(self):
    """Check round-tripping, type inference and verification for both UDTs.

    For ``ExamplePointUDT``/``ExamplePoint`` and ``PythonOnlyUDT``/
    ``PythonOnlyPoint`` in turn, the test checks that:

    * the UDT, and a ``StructType`` containing it, survive a pickle round-trip
      and a JSON round-trip through ``self.spark._jsparkSession.parseDataType``;
    * ``_infer_type`` maps a point instance to its UDT;
    * the verifier built by ``_make_type_verifier`` accepts a point instance
      and raises ``ValueError`` for a plain list.
    """
    from pyspark.sql.types import _parse_datatype_json_string, _infer_type, _make_type_verifier

    def check_datatype(datatype):
        # Use unittest assertions rather than bare ``assert`` so the checks
        # still run under ``python -O`` and report both operands on failure.
        pickled = pickle.loads(pickle.dumps(datatype))
        self.assertEqual(datatype, pickled)
        scala_datatype = self.spark._jsparkSession.parseDataType(datatype.json())
        python_datatype = _parse_datatype_json_string(scala_datatype.json())
        self.assertEqual(datatype, python_datatype)

    def check_udt(udt_cls, point_cls):
        # Run the full set of checks for one UDT class and its point class.
        check_datatype(udt_cls())
        structtype_with_udt = StructType([
            StructField("label", DoubleType(), False),
            StructField("point", udt_cls(), False),
        ])
        check_datatype(structtype_with_udt)

        p = point_cls(1.0, 2.0)
        self.assertEqual(_infer_type(p), udt_cls())

        # A genuine point must verify; a bare list of coordinates must not.
        _make_type_verifier(udt_cls())(point_cls(1.0, 2.0))
        self.assertRaises(
            ValueError,
            lambda: _make_type_verifier(udt_cls())([1.0, 2.0]))

    check_udt(ExamplePointUDT, ExamplePoint)
    check_udt(PythonOnlyUDT, PythonOnlyPoint)
def test_nested_udt_in_df(self):
    """Create and collect DataFrames whose ``val`` column nests ``PythonOnlyUDT``.

    Two layouts are exercised: the UDT as the element type of an ``ArrayType``
    and as the value type of a ``MapType`` keyed by ``LongType``. There are no
    explicit assertions; the test only requires that ``createDataFrame`` and
    ``collect`` complete without raising.
    """
    # UDT as the element type of an array column.
    array_schema = StructType().add("key", LongType()).add("val", ArrayType(PythonOnlyUDT()))
    array_rows = [
        (n % 3, [PythonOnlyPoint(float(n), float(n))])
        for n in range(10)
    ]
    array_frame = self.spark.createDataFrame(array_rows, schema=array_schema)
    array_frame.collect()

    # UDT as the value type of a map column.
    map_schema = StructType().add("key", LongType()).add("val", MapType(LongType(), PythonOnlyUDT()))
    map_rows = [
        (n % 3, {n % 3: PythonOnlyPoint(float(n + 1), float(n + 1))})
        for n in range(10)
    ]
    map_frame = self.spark.createDataFrame(map_rows, schema=map_schema)
    map_frame.collect()
def test_apply_schema_with_udt(self):
    """Apply an explicit schema containing a UDT field and read the point back.

    Runs once for ``ExamplePointUDT``/``ExamplePoint`` and once for
    ``PythonOnlyUDT``/``PythonOnlyPoint``: a single ``(label, point)`` tuple is
    turned into a DataFrame with a non-nullable ``label``/``point`` schema, and
    the ``point`` of the first row must equal the point that was put in.
    """
    class_pairs = (
        (ExamplePointUDT, ExamplePoint),
        (PythonOnlyUDT, PythonOnlyPoint),
    )
    for udt_cls, point_cls in class_pairs:
        record = (1.0, point_cls(1.0, 2.0))
        fields = [
            StructField("label", DoubleType(), False),
            StructField("point", udt_cls(), False),
        ]
        frame = self.spark.createDataFrame([record], StructType(fields))
        fetched = frame.head().point
        self.assertEqual(fetched, point_cls(1.0, 2.0))
def test_parquet_with_udt(self):
    """Write UDT-valued rows to Parquet, read them back and compare the point.

    The same ``labeled_point`` directory under ``self.tempdir`` is used for
    both halves: the ``ExamplePoint`` row is written first, then the
    ``PythonOnlyPoint`` row is written with ``mode='overwrite'``.
    """
    target = os.path.join(self.tempdir.name, "labeled_point")

    # ExamplePoint: initial write into the directory.
    example_frame = self.spark.createDataFrame(
        [Row(label=1.0, point=ExamplePoint(1.0, 2.0))])
    example_frame.write.parquet(target)
    example_back = self.spark.read.parquet(target)
    self.assertEqual(example_back.head().point, ExamplePoint(1.0, 2.0))

    # PythonOnlyPoint: replaces the data written above.
    python_frame = self.spark.createDataFrame(
        [Row(label=1.0, point=PythonOnlyPoint(1.0, 2.0))])
    python_frame.write.parquet(target, mode='overwrite')
    python_back = self.spark.read.parquet(target)
    self.assertEqual(python_back.head().point, PythonOnlyPoint(1.0, 2.0))
def test_cast_to_udt_with_udt(self):
    """Expect ``AnalysisException`` when a UDT column is cast to the other UDT.

    A one-row DataFrame holds an ``ExamplePoint`` column (``point``) and a
    ``PythonOnlyPoint`` column (``python_only_point``); selecting either
    column cast to the other column's UDT must raise.
    """
    from pyspark.sql.functions import col

    source_row = Row(
        point=ExamplePoint(1.0, 2.0),
        python_only_point=PythonOnlyPoint(1.0, 2.0))
    frame = self.spark.createDataFrame([source_row])

    with self.assertRaises(AnalysisException):
        frame.select(col("point").cast(PythonOnlyUDT()))

    with self.assertRaises(AnalysisException):
        frame.select(col("python_only_point").cast(ExamplePointUDT()))
def test_udf_with_udt(self):
    """Exercise RDD maps and UDFs that consume and produce UDT values.

    For an ``ExamplePoint`` row and then a ``PythonOnlyPoint`` row, the test
    reads ``point.x`` through ``df.rdd.map``, reads ``point.y`` through a UDF
    returning ``DoubleType``, and shifts both coordinates by one through a UDF
    whose return type is the matching UDT.
    """
    # --- ExamplePoint / ExamplePointUDT ---
    example_frame = self.spark.createDataFrame(
        [Row(label=1.0, point=ExamplePoint(1.0, 2.0))])
    example_x = example_frame.rdd.map(lambda r: r.point.x).first()
    self.assertEqual(1.0, example_x)

    read_y = UserDefinedFunction(lambda p: p.y, DoubleType())
    example_y = example_frame.select(read_y(example_frame.point)).first()[0]
    self.assertEqual(2.0, example_y)

    shift_example = UserDefinedFunction(
        lambda p: ExamplePoint(p.x + 1, p.y + 1), ExamplePointUDT())
    shifted_example = example_frame.select(shift_example(example_frame.point)).first()[0]
    self.assertEqual(ExamplePoint(2.0, 3.0), shifted_example)

    # --- PythonOnlyPoint / PythonOnlyUDT ---
    python_frame = self.spark.createDataFrame(
        [Row(label=1.0, point=PythonOnlyPoint(1.0, 2.0))])
    python_x = python_frame.rdd.map(lambda r: r.point.x).first()
    self.assertEqual(1.0, python_x)

    read_y = UserDefinedFunction(lambda p: p.y, DoubleType())
    python_y = python_frame.select(read_y(python_frame.point)).first()[0]
    self.assertEqual(2.0, python_y)

    shift_python = UserDefinedFunction(
        lambda p: PythonOnlyPoint(p.x + 1, p.y + 1), PythonOnlyUDT())
    shifted_python = python_frame.select(shift_python(python_frame.point)).first()[0]
    self.assertEqual(PythonOnlyPoint(2.0, 3.0), shifted_python)
def test_simple_udt_in_df(self):
    """Create and collect a DataFrame with a top-level ``PythonOnlyUDT`` column.

    Ten ``(key, val)`` rows are built with ``key = i % 3`` and
    ``val = PythonOnlyPoint(i, i)``. There are no explicit assertions; the test
    only requires that ``createDataFrame`` and ``collect`` do not raise.
    """
    point_schema = StructType().add("key", LongType()).add("val", PythonOnlyUDT())
    records = [
        (n % 3, PythonOnlyPoint(float(n), float(n)))
        for n in range(10)
    ]
    frame = self.spark.createDataFrame(records, schema=point_schema)
    frame.collect()
def test_cast_to_string_with_udt(self):
    """Cast both UDT columns to ``string`` and compare the resulting row.

    The expected row is ``'(1.0, 2.0)'`` for the ``ExamplePointUDT`` column
    and ``'[3.0, 4.0]'`` for the ``PythonOnlyUDT`` column.
    """
    from pyspark.sql.functions import col

    record = (ExamplePoint(1.0, 2.0), PythonOnlyPoint(3.0, 4.0))
    fields = [
        StructField("point", ExamplePointUDT(), False),
        StructField("pypoint", PythonOnlyUDT(), False),
    ]
    frame = self.spark.createDataFrame([record], StructType(fields))

    as_strings = frame.select(
        col('point').cast('string'),
        col('pypoint').cast('string'))
    first_row = as_strings.head()

    expected = Row(point=u'(1.0, 2.0)', pypoint=u'[3.0, 4.0]')
    self.assertEqual(first_row, expected)
def test_udt_with_none(self):
    """Check that a UDT-returning UDF may return ``None`` for some inputs.

    ``myudf`` is registered under the SQL name ``udf`` with return type
    ``PythonOnlyUDT`` and applied to the ``id`` column of
    ``self.spark.range(0, 10, 1, 1)``. The first two results must be ``None``
    and ``PythonOnlyPoint(1, 1)``.
    """
    ids = self.spark.range(0, 10, 1, 1)

    def myudf(x):
        # Only positive inputs map to a point; anything else yields None.
        if not x > 0:
            return None
        return PythonOnlyPoint(float(x), float(x))

    self.spark.catalog.registerFunction("udf", myudf, PythonOnlyUDT())

    first_two = ids.selectExpr("udf(id)").take(2)
    values = [record[0] for record in first_two]
    self.assertEqual(values, [None, PythonOnlyPoint(1, 1)])
def test_infer_schema_with_udt(self):
    """Infer a schema from rows holding UDT values and query them through SQL.

    Runs once for ``ExamplePoint``/``ExamplePointUDT`` and once for
    ``PythonOnlyPoint``/``PythonOnlyUDT``: the inferred ``point`` field must
    have exactly the matching UDT type, and selecting ``point`` from the
    ``labeled_point`` temp view must return the original point.
    """
    class_pairs = (
        (ExamplePoint, ExamplePointUDT),
        (PythonOnlyPoint, PythonOnlyUDT),
    )
    for point_cls, udt_cls in class_pairs:
        frame = self.spark.createDataFrame(
            [Row(label=1.0, point=point_cls(1.0, 2.0))])

        inferred = frame.schema
        point_fields = [fld for fld in inferred.fields if fld.name == "point"]
        point_field = point_fields[0]
        self.assertEqual(type(point_field.dataType), udt_cls)

        with self.tempView("labeled_point"):
            frame.createOrReplaceTempView("labeled_point")
            selected = self.spark.sql("SELECT point FROM labeled_point").head()
            self.assertEqual(selected.point, point_cls(1.0, 2.0))
def test_complex_nested_udt_in_df(self):
    """Collect UDT values that are grouped and then re-wrapped by a UDF.

    A ``key``/``val`` DataFrame of ``PythonOnlyPoint`` values is grouped by
    ``key`` with ``collect_list`` over ``val``. A UDF returning an array of
    the original row schema is then applied to the grouped columns. There are
    no explicit assertions; every ``collect`` must simply succeed.
    """
    from pyspark.sql.functions import udf

    point_schema = StructType().add("key", LongType()).add("val", PythonOnlyUDT())
    records = [
        (n % 3, PythonOnlyPoint(float(n), float(n)))
        for n in range(10)
    ]
    points = self.spark.createDataFrame(records, schema=point_schema)
    points.collect()

    grouped = points.groupby("key").agg({"val": "collect_list"})
    grouped.collect()

    # Pair each key with the first collected value, typed as the source schema.
    pair_with_first = udf(lambda k, v: [(k, v[0])], ArrayType(points.schema))
    grouped.select(pair_with_first(*grouped)).collect()
def myudf(x):
    """Return ``PythonOnlyPoint(float(x), float(x))`` when ``x > 0``, else ``None``."""
    if not x > 0:
        return None
    return PythonOnlyPoint(float(x), float(x))