Example #1
0
    def test_cast_to_string_with_udt(self):
        """Casting UDT-typed columns to string yields each value's text form.

        The ExamplePointUDT column is expected to render as ``(1.0, 2.0)`` and
        the PythonOnlyUDT column as ``[3.0, 4.0]``.
        """
        from pyspark.sql.functions import col

        data = [(ExamplePoint(1.0, 2.0), PythonOnlyPoint(3.0, 4.0))]
        schema = StructType([
            StructField("point", ExamplePointUDT(), False),
            StructField("pypoint", PythonOnlyUDT(), False)
        ])
        df = self.spark.createDataFrame(data, schema)

        # Cast both UDT columns to strings, keeping the schema's column order.
        string_columns = [col(name).cast('string') for name in ('point', 'pypoint')]
        first_row = df.select(*string_columns).head()

        expected = Row(point=u'(1.0, 2.0)', pypoint=u'[3.0, 4.0]')
        self.assertEqual(first_row, expected)
Example #2
0
    def test_complex_nested_udt_in_df(self):
        """A Python UDF can consume and return data nesting a PythonOnlyUDT.

        Points are grouped by key with ``collect_list`` (producing an array of
        UDT values per group), then a UDF whose return type is an array of the
        original row schema pairs each key with the first collected point.
        """
        from pyspark.sql.functions import udf

        schema = StructType().add("key", LongType()).add("val", PythonOnlyUDT())
        df = self.spark.createDataFrame(
            [(i % 3, PythonOnlyPoint(float(i), float(i))) for i in range(10)],
            schema=schema)
        df.collect()

        gd = df.groupby("key").agg({"val": "collect_list"})
        # Keys are i % 3 for i in range(10), so there are exactly three groups.
        self.assertEqual(len(gd.collect()), 3)

        # Bound to its own name so it does not shadow the imported `udf` factory.
        first_pair = udf(lambda k, v: [(k, v[0])], ArrayType(df.schema))
        # Pass every column of `gd` (the key, then the collected values) in order.
        result = gd.select(first_pair(*[gd[c] for c in gd.columns])).collect()
        self.assertEqual(len(result), 3)
        # The lambda always returns a one-element list per group.
        self.assertTrue(all(len(row[0]) == 1 for row in result))
Example #3
0
    def test_udt(self):
        """Check UDT serialization, type inference and type verification.

        For both ExamplePointUDT and PythonOnlyUDT, verify that:

        * the UDT, and a StructType containing it, survive a pickle round trip
          and a JSON round trip through the JVM session's ``parseDataType``;
        * ``_infer_type`` maps a point instance to its UDT;
        * the verifier from ``_make_type_verifier`` accepts a point and raises
          ``ValueError`` for a plain list.
        """
        from pyspark.sql.types import (
            _parse_datatype_json_string, _infer_type, _make_type_verifier)

        def check_datatype(datatype):
            # Use unittest assertions rather than bare `assert`, which is
            # stripped under `python -O` and would make these checks no-ops.
            pickled = pickle.loads(pickle.dumps(datatype))
            self.assertEqual(datatype, pickled)
            scala_datatype = self.spark._jsparkSession.parseDataType(
                datatype.json())
            python_datatype = _parse_datatype_json_string(
                scala_datatype.json())
            self.assertEqual(datatype, python_datatype)

        cases = (
            (ExamplePointUDT, ExamplePoint),
            (PythonOnlyUDT, PythonOnlyPoint),
        )
        for udt_class, point_class in cases:
            check_datatype(udt_class())
            structtype_with_udt = StructType([
                StructField("label", DoubleType(), False),
                StructField("point", udt_class(), False)
            ])
            check_datatype(structtype_with_udt)

            self.assertEqual(_infer_type(point_class(1.0, 2.0)), udt_class())

            verify = _make_type_verifier(udt_class())
            verify(point_class(1.0, 2.0))
            # A raw list is not an instance of the UDT's user class.
            self.assertRaises(ValueError, verify, [1.0, 2.0])
Example #4
0
    def test_infer_schema_with_udt(self):
        """Schema inference picks up UDTs from Row values, and they survive SQL.

        For each (point class, UDT class) pair, a DataFrame built from a Row
        holding a point must infer that UDT for its ``point`` field, and
        selecting the column back through a temp view must return an equal
        point.
        """
        cases = [
            (ExamplePoint, ExamplePointUDT),
            (PythonOnlyPoint, PythonOnlyUDT),
        ]
        for point_cls, udt_cls in cases:
            df = self.spark.createDataFrame(
                [Row(label=1.0, point=point_cls(1.0, 2.0))])
            fields = df.schema.fields
            point_field = [fld for fld in fields if fld.name == "point"][0]
            self.assertEqual(type(point_field.dataType), udt_cls)

            with self.tempView("labeled_point"):
                df.createOrReplaceTempView("labeled_point")
                selected = self.spark.sql("SELECT point FROM labeled_point")
                self.assertEqual(selected.head().point, point_cls(1.0, 2.0))
Example #5
0
 def myudf(x):
     """Return ``PythonOnlyPoint(float(x), float(x))`` when x > 0, else None."""
     return PythonOnlyPoint(float(x), float(x)) if x > 0 else None
Example #6
0
 def test_simple_udt_in_df(self):
     """Create a DataFrame with a PythonOnlyUDT column and collect it."""
     schema = StructType().add("key", LongType()).add("val", PythonOnlyUDT())
     # Ten rows keyed by i % 3, each holding a point (i, i) as floats.
     rows = [(i % 3, PythonOnlyPoint(float(i), float(i))) for i in range(10)]
     df = self.spark.createDataFrame(rows, schema=schema)
     df.collect()