Пример #1
0
    def test_udt(self):
        """Round-trip UDT data types and check UDT inference/verification.

        For both ``ExamplePointUDT`` and ``PythonOnlyUDT``:

        * the UDT itself, and a ``StructType`` containing it, must survive a
          pickle round trip and a JSON round trip through
          ``self.sqlCtx._ssql_ctx.parseDataType``;
        * ``_infer_type`` must map a point instance to that UDT;
        * ``_verify_type`` must accept a point and raise ``ValueError`` for a
          plain list.
        """
        from pyspark.sql.types import _parse_datatype_json_string, _infer_type, _verify_type
        from pyspark.sql.tests import ExamplePointUDT, ExamplePoint

        def check_datatype(datatype):
            # unittest assertions instead of bare ``assert``: they are not
            # stripped under ``python -O`` and report both values on failure.
            pickled = pickle.loads(pickle.dumps(datatype))
            self.assertEqual(datatype, pickled)
            scala_datatype = self.sqlCtx._ssql_ctx.parseDataType(datatype.json())
            python_datatype = _parse_datatype_json_string(scala_datatype.json())
            self.assertEqual(datatype, python_datatype)

        check_datatype(ExamplePointUDT())
        structtype_with_udt = StructType([StructField("label", DoubleType(), False),
                                          StructField("point", ExamplePointUDT(), False)])
        check_datatype(structtype_with_udt)
        p = ExamplePoint(1.0, 2.0)
        self.assertEqual(_infer_type(p), ExamplePointUDT())
        _verify_type(ExamplePoint(1.0, 2.0), ExamplePointUDT())
        self.assertRaises(ValueError, lambda: _verify_type([1.0, 2.0], ExamplePointUDT()))

        check_datatype(PythonOnlyUDT())
        structtype_with_udt = StructType([StructField("label", DoubleType(), False),
                                          StructField("point", PythonOnlyUDT(), False)])
        check_datatype(structtype_with_udt)
        p = PythonOnlyPoint(1.0, 2.0)
        self.assertEqual(_infer_type(p), PythonOnlyUDT())
        _verify_type(PythonOnlyPoint(1.0, 2.0), PythonOnlyUDT())
        self.assertRaises(ValueError, lambda: _verify_type([1.0, 2.0], PythonOnlyUDT()))
Пример #2
0
 def test_apply_schema_with_udt(self):
     """Apply an explicit UDT schema to an RDD and read the point back.

     The ``point`` column is typed ``ExamplePointUDT``; the value returned by
     ``df.head().point`` must equal the original ``ExamplePoint``.
     """
     from pyspark.sql.tests import ExamplePoint, ExamplePointUDT
     row = (1.0, ExamplePoint(1.0, 2.0))
     rdd = self.sc.parallelize([row])
     schema = StructType([StructField("label", DoubleType(), False),
                          StructField("point", ExamplePointUDT(), False)])
     df = rdd.toDF(schema)
     point = df.head().point
     # assertEqual, not the deprecated assertEquals alias (removed in Python 3.12).
     self.assertEqual(point, ExamplePoint(1.0, 2.0))
Пример #3
0
 def test_udf_with_udt(self):
     """Check that UDT values flow through RDD maps and Python UDFs.

     A DataFrame is built from a ``Row`` holding an ``ExamplePoint``. The point
     must be readable inside an RDD ``map`` and usable as a UDF argument. It
     must also be returnable from a UDF declared with ``ExamplePointUDT`` as
     its return type.
     """
     from pyspark.sql.tests import ExamplePoint, ExamplePointUDT
     row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
     df = self.sc.parallelize([row]).toDF()
     # Go through ``df.rdd``: ``DataFrame.map`` was removed in Spark 2.0, while
     # ``DataFrame.rdd`` has existed since DataFrames were introduced (1.3).
     self.assertEqual(1.0, df.rdd.map(lambda r: r.point.x).first())
     udf = UserDefinedFunction(lambda p: p.y, DoubleType())
     self.assertEqual(2.0, df.select(udf(df.point)).first()[0])
     udf2 = UserDefinedFunction(lambda p: ExamplePoint(p.x + 1, p.y + 1),
                                ExamplePointUDT())
     self.assertEqual(ExamplePoint(2.0, 3.0),
                      df.select(udf2(df.point)).first()[0])
Пример #4
0
    def test_gapply_universal_udt_val(self):
        """Run ``checkGapplyEquivalentToPandas`` with a UDT-valued aggregate.

        Each group is reduced to an ``ExamplePoint`` whose two coordinates are
        both ``float(sum(int(pt.x) + int(pt.y)))`` over the group. The result
        is therefore deterministic and can be compared for exact equality.
        """
        def pandasAggFunction(series):
            # NOTE(review): ``series`` is presumably a pandas Series of
            # ExamplePoint values (it is used via ``.apply``/``.sum``) — confirm.
            x = float(series.apply(lambda pt: int(pt.x) + int(pt.y)).sum())
            return ExamplePoint(
                x, x)  # still deterministic, can have exact equivalence test

        dataType = ExamplePointUDT()

        # A named function rather than a lambda bound to a name (PEP 8 E731),
        # so tracebacks show ``dataGen``. Coordinates are integer-valued floats
        # drawn from ``range(GapplyTests.NVALS)``.
        def dataGen():
            return ExamplePoint(
                float(random.randrange(GapplyTests.NVALS)),
                float(random.randrange(GapplyTests.NVALS)))

        self.checkGapplyEquivalentToPandas(pandasAggFunction, dataType,
                                           dataGen)
Пример #5
0
    def test_apply_schema_with_udt(self):
        """Create DataFrames with an explicit UDT schema and read points back.

        Runs once with ``ExamplePoint``/``ExamplePointUDT`` and once with
        ``PythonOnlyPoint``/``PythonOnlyUDT``. Each time, ``df.head().point``
        must equal the point that was stored.
        """
        from pyspark.sql.tests import ExamplePoint, ExamplePointUDT
        row = (1.0, ExamplePoint(1.0, 2.0))
        schema = StructType([StructField("label", DoubleType(), False),
                             StructField("point", ExamplePointUDT(), False)])
        df = self.sqlCtx.createDataFrame([row], schema)
        point = df.head().point
        # assertEqual, not the deprecated assertEquals alias (removed in Python 3.12).
        self.assertEqual(point, ExamplePoint(1.0, 2.0))

        row = (1.0, PythonOnlyPoint(1.0, 2.0))
        schema = StructType([StructField("label", DoubleType(), False),
                             StructField("point", PythonOnlyUDT(), False)])
        df = self.sqlCtx.createDataFrame([row], schema)
        point = df.head().point
        self.assertEqual(point, PythonOnlyPoint(1.0, 2.0))
Пример #6
0
    def runTest(self):
        """Union two DataFrames with a UDT column and check the ordered rows.

        Both frames share a schema whose ``point`` column is
        ``ExamplePointUDT``. After ``union`` and ``orderBy("label")``, the
        collected rows must equal the two original points in label order.
        """
        # Schema types and Row come from their public modules rather than
        # relying on pyspark.sql.tests re-exporting them; only the test-only
        # UDT fixtures are taken from pyspark.sql.tests.
        from pyspark.sql import Row
        from pyspark.sql.types import DoubleType, StructField, StructType
        from pyspark.sql.tests import ExamplePoint, ExamplePointUDT
        row1 = (1.0, ExamplePoint(1.0, 2.0))
        row2 = (2.0, ExamplePoint(3.0, 4.0))
        schema = StructType([
            StructField("label", DoubleType(), False),
            StructField("point", ExamplePointUDT(), False)
        ])
        df1 = self.spark.createDataFrame([row1], schema)
        df2 = self.spark.createDataFrame([row2], schema)

        result = df1.union(df2).orderBy("label").collect()
        self.assertEqual(result, [
            Row(label=1.0, point=ExamplePoint(1.0, 2.0)),
            Row(label=2.0, point=ExamplePoint(3.0, 4.0))
        ])
Пример #7
0
    def test_udf_with_udt(self):
        """Check that UDT values flow through RDD maps and Python UDFs.

        Runs once with ``ExamplePoint``/``ExamplePointUDT`` and once with
        ``PythonOnlyPoint``/``PythonOnlyUDT``. In each pass the point column
        must be:

        * readable inside an RDD ``map``;
        * usable as a UDF argument;
        * returnable from a UDF declared with the UDT as its return type.
        """
        from pyspark.sql.tests import ExamplePoint, ExamplePointUDT
        row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
        df = self.sqlCtx.createDataFrame([row])
        # Go through ``df.rdd``: ``DataFrame.map`` was removed in Spark 2.0, while
        # ``DataFrame.rdd`` has existed since DataFrames were introduced (1.3).
        self.assertEqual(1.0, df.rdd.map(lambda r: r.point.x).first())
        udf = UserDefinedFunction(lambda p: p.y, DoubleType())
        self.assertEqual(2.0, df.select(udf(df.point)).first()[0])
        udf2 = UserDefinedFunction(lambda p: ExamplePoint(p.x + 1, p.y + 1), ExamplePointUDT())
        self.assertEqual(ExamplePoint(2.0, 3.0), df.select(udf2(df.point)).first()[0])

        row = Row(label=1.0, point=PythonOnlyPoint(1.0, 2.0))
        df = self.sqlCtx.createDataFrame([row])
        self.assertEqual(1.0, df.rdd.map(lambda r: r.point.x).first())
        udf = UserDefinedFunction(lambda p: p.y, DoubleType())
        self.assertEqual(2.0, df.select(udf(df.point)).first()[0])
        udf2 = UserDefinedFunction(lambda p: PythonOnlyPoint(p.x + 1, p.y + 1), PythonOnlyUDT())
        self.assertEqual(PythonOnlyPoint(2.0, 3.0), df.select(udf2(df.point)).first()[0])
Пример #8
0
class ExamplePoint:
    """
    An example class to demonstrate UDT in Scala, Java, and Python.

    A point holds two coordinates, ``x`` and ``y``. Two points are equal when
    both are ``ExamplePoint`` instances with equal ``x`` and equal ``y``.
    ``__ne__`` and ``__hash__`` are defined to stay consistent with that
    equality.
    """

    # NOTE(review): presumably the hook PySpark uses to find this class's UDT
    # (via the ``__UDT__`` attribute) — confirm against pyspark.sql.types.
    __UDT__ = ExamplePointUDT()

    def __init__(self, x, y):
        self.x = x
        self.y = y

    def __repr__(self):
        return "ExamplePoint(%s,%s)" % (self.x, self.y)

    def __str__(self):
        return "(%s,%s)" % (self.x, self.y)

    def __eq__(self, other):
        return isinstance(other, ExamplePoint) and \
            other.x == self.x and other.y == self.y

    def __ne__(self, other):
        # Python 2 does not derive ``!=`` from ``__eq__`` (it would fall back
        # to identity), so define it explicitly as the negation of equality.
        return not self == other

    def __hash__(self):
        # Equal points must hash equally. Defining __eq__ alone makes the class
        # unhashable in Python 3 and leaves an id-based hash in Python 2.
        # The hash uses the current coordinates, so do not mutate a point
        # while it is a dict key or set member.
        return hash((self.x, self.y))