Ejemplo n.º 1
0
    def test_udt(self):
        """Check serialization, inference and verification of UDTs.

        For both the JVM-backed ``ExamplePointUDT`` and the ``PythonOnlyUDT``:
        the UDT (alone and nested in a StructType) must survive a pickle
        round-trip and a Python -> JVM -> Python JSON round-trip, be inferred
        from a point instance, accept a matching point in the type verifier,
        and reject a plain list with ``ValueError``.
        """
        from pyspark.sql.types import _parse_datatype_json_string, _infer_type, _make_type_verifier

        def check_datatype(datatype):
            # Round-trip through pickle.
            pickled = pickle.loads(pickle.dumps(datatype))
            self.assertEqual(datatype, pickled)
            # Round-trip through the JVM: Python JSON -> Scala DataType -> JSON -> Python.
            scala_datatype = self.spark._jsparkSession.parseDataType(datatype.json())
            python_datatype = _parse_datatype_json_string(scala_datatype.json())
            self.assertEqual(datatype, python_datatype)

        cases = [
            (ExamplePointUDT(), ExamplePoint(1.0, 2.0)),
            (PythonOnlyUDT(), PythonOnlyPoint(1.0, 2.0)),
        ]
        for udt, point in cases:
            check_datatype(udt)
            structtype_with_udt = StructType([StructField("label", DoubleType(), False),
                                              StructField("point", udt, False)])
            check_datatype(structtype_with_udt)
            self.assertEqual(_infer_type(point), udt)
            _make_type_verifier(udt)(point)
            # The lambda is invoked immediately by assertRaises, so binding the
            # loop variable late is not an issue here.
            self.assertRaises(ValueError, lambda: _make_type_verifier(udt)([1.0, 2.0]))
Ejemplo n.º 2
0
    def test_nested_udt_in_df(self):
        """Create and collect DataFrames whose UDT values are nested in an array and a map."""
        # UDT as the element type of an array column.
        array_schema = StructType().add("key", LongType()).add("val", ArrayType(PythonOnlyUDT()))
        array_rows = [(n % 3, [PythonOnlyPoint(float(n), float(n))]) for n in range(10)]
        self.spark.createDataFrame(array_rows, schema=array_schema).collect()

        # UDT as the value type of a map column.
        map_schema = StructType().add("key", LongType()).add(
            "val", MapType(LongType(), PythonOnlyUDT()))
        map_rows = [
            (n % 3, {n % 3: PythonOnlyPoint(float(n + 1), float(n + 1))})
            for n in range(10)
        ]
        self.spark.createDataFrame(map_rows, schema=map_schema).collect()
Ejemplo n.º 3
0
 def test_cast_to_udt_with_udt(self):
     """Casting a UDT column to a different UDT must raise AnalysisException."""
     from pyspark.sql.functions import col
     df = self.spark.createDataFrame(
         [Row(point=ExamplePoint(1.0, 2.0), python_only_point=PythonOnlyPoint(1.0, 2.0))])
     mismatched_casts = (
         ("point", PythonOnlyUDT()),
         ("python_only_point", ExamplePointUDT()),
     )
     for column_name, target_udt in mismatched_casts:
         # assertRaises calls the lambda right away, so the loop variables are current.
         self.assertRaises(AnalysisException,
                           lambda: df.select(col(column_name).cast(target_udt)))
Ejemplo n.º 4
0
 def test_simple_udt_in_df(self):
     """Create and collect a DataFrame with a top-level Python-only UDT column."""
     schema = StructType().add("key", LongType()).add("val", PythonOnlyUDT())
     rows = [(idx % 3, PythonOnlyPoint(float(idx), float(idx))) for idx in range(10)]
     self.spark.createDataFrame(rows, schema=schema).collect()
Ejemplo n.º 5
0
    def test_cast_to_string_with_udt(self):
        """Cast UDT columns to string and compare against the expected renderings."""
        from pyspark.sql.functions import col
        schema = StructType([
            StructField("point", ExamplePointUDT(), False),
            StructField("pypoint", PythonOnlyUDT(), False),
        ])
        data = [(ExamplePoint(1.0, 2.0), PythonOnlyPoint(3.0, 4.0))]
        df = self.spark.createDataFrame(data, schema)

        string_columns = [col(name).cast('string') for name in ('point', 'pypoint')]
        result = df.select(*string_columns).head()
        expected = Row(point=u'(1.0, 2.0)', pypoint=u'[3.0, 4.0]')
        self.assertEqual(result, expected)
Ejemplo n.º 6
0
    def test_udt_with_none(self):
        """A SQL-registered Python UDF declared to return a UDT may return None."""
        # A single partition, so take(2) below sees ids 0 and 1 in order.
        df = self.spark.range(0, 10, 1, 1)

        def myudf(x):
            # Null for id 0, a point for every positive id.
            if x > 0:
                return PythonOnlyPoint(float(x), float(x))
            return None

        # ``spark.catalog.registerFunction`` is deprecated (since Spark 2.3);
        # ``spark.udf.register`` is the supported way to expose a Python UDF to SQL.
        self.spark.udf.register("udf", myudf, PythonOnlyUDT())
        rows = [r[0] for r in df.selectExpr("udf(id)").take(2)]
        self.assertEqual(rows, [None, PythonOnlyPoint(1, 1)])
Ejemplo n.º 7
0
    def test_complex_nested_udt_in_df(self):
        """Feed grouped UDT values (collect_list of a UDT column) through a Python UDF.

        The UDF returns an array of structs shaped like ``df``'s schema, so the
        UDT travels JVM -> Python -> JVM inside nested array/struct types.
        """
        from pyspark.sql.functions import udf

        schema = StructType().add("key", LongType()).add("val", PythonOnlyUDT())
        df = self.spark.createDataFrame(
            [(i % 3, PythonOnlyPoint(float(i), float(i))) for i in range(10)],
            schema=schema)
        df.collect()

        gd = df.groupby("key").agg({"val": "collect_list"})
        gd.collect()
        # Keep the imported ``udf`` factory intact; name the resulting UDF separately.
        key_and_first_point = udf(lambda k, v: [(k, v[0])], ArrayType(df.schema))
        # Pass every column of ``gd`` (key, collected points) explicitly, in order.
        gd.select(key_and_first_point(*[gd[name] for name in gd.columns])).collect()
Ejemplo n.º 8
0
    def test_apply_schema_with_udt(self):
        """A UDT value survives createDataFrame with an explicit schema followed by head()."""
        def check_round_trip(point, udt):
            schema = StructType([StructField("label", DoubleType(), False),
                                 StructField("point", udt, False)])
            first_row = self.spark.createDataFrame([(1.0, point)], schema).head()
            self.assertEqual(first_row.point, point)

        check_round_trip(ExamplePoint(1.0, 2.0), ExamplePointUDT())
        check_round_trip(PythonOnlyPoint(1.0, 2.0), PythonOnlyUDT())
Ejemplo n.º 9
0
    def test_udf_with_udt(self):
        """UDT values work in RDD lambdas and as both input and output of Python UDFs."""
        def check(point, shift_fn, udt, expected_shifted):
            df = self.spark.createDataFrame([Row(label=1.0, point=point)])
            # The point is available as a Python object inside RDD lambdas.
            self.assertEqual(1.0, df.rdd.map(lambda r: r.point.x).first())
            # UDT as UDF input.
            y_of = UserDefinedFunction(lambda p: p.y, DoubleType())
            self.assertEqual(2.0, df.select(y_of(df.point)).first()[0])
            # UDT as UDF output.
            shifted = UserDefinedFunction(shift_fn, udt)
            self.assertEqual(expected_shifted, df.select(shifted(df.point)).first()[0])

        check(ExamplePoint(1.0, 2.0),
              lambda p: ExamplePoint(p.x + 1, p.y + 1),
              ExamplePointUDT(),
              ExamplePoint(2.0, 3.0))
        check(PythonOnlyPoint(1.0, 2.0),
              lambda p: PythonOnlyPoint(p.x + 1, p.y + 1),
              PythonOnlyUDT(),
              PythonOnlyPoint(2.0, 3.0))
Ejemplo n.º 10
0
    def test_verify_type_not_nullable(self):
        """Table-driven check of ``_make_type_verifier(..., nullable=False)``.

        ``success_spec`` lists (obj, data_type) pairs that must verify without
        raising. ``failure_spec`` lists (obj, data_type, exception class) triples
        that must raise exactly that exception class.
        """
        import array
        import datetime
        import decimal

        schema = StructType([
            StructField('s', StringType(), nullable=False),
            StructField('i', IntegerType(), nullable=True)])

        # Plain object exposing fields as attributes, for struct verification.
        class MyObj:
            def __init__(self, **kwargs):
                for k, v in kwargs.items():
                    setattr(self, k, v)

        # obj, data_type
        success_spec = [
            # String
            ("", StringType()),
            (u"", StringType()),
            (1, StringType()),
            (1.0, StringType()),
            ([], StringType()),
            ({}, StringType()),

            # UDT
            (ExamplePoint(1.0, 2.0), ExamplePointUDT()),

            # Boolean
            (True, BooleanType()),

            # Byte
            (-(2**7), ByteType()),
            (2**7 - 1, ByteType()),

            # Short
            (-(2**15), ShortType()),
            (2**15 - 1, ShortType()),

            # Integer
            (-(2**31), IntegerType()),
            (2**31 - 1, IntegerType()),

            # Long
            (-(2**63), LongType()),
            (2**63 - 1, LongType()),

            # Float & Double
            (1.0, FloatType()),
            (1.0, DoubleType()),

            # Decimal
            (decimal.Decimal("1.0"), DecimalType()),

            # Binary
            (bytearray([1, 2]), BinaryType()),

            # Date/Timestamp
            (datetime.date(2000, 1, 2), DateType()),
            # datetime.datetime is a subclass of datetime.date.
            (datetime.datetime(2000, 1, 2, 3, 4), DateType()),
            (datetime.datetime(2000, 1, 2, 3, 4), TimestampType()),

            # Array
            ([], ArrayType(IntegerType())),
            (["1", None], ArrayType(StringType(), containsNull=True)),
            ([1, 2], ArrayType(IntegerType())),
            ((1, 2), ArrayType(IntegerType())),
            (array.array('h', [1, 2]), ArrayType(IntegerType())),

            # Map
            ({}, MapType(StringType(), IntegerType())),
            ({"a": 1}, MapType(StringType(), IntegerType())),
            ({"a": None}, MapType(StringType(), IntegerType(), valueContainsNull=True)),

            # Struct
            ({"s": "a", "i": 1}, schema),
            ({"s": "a", "i": None}, schema),
            ({"s": "a"}, schema),
            ({"s": "a", "f": 1.0}, schema),
            (Row(s="a", i=1), schema),
            (Row(s="a", i=None), schema),
            (Row(s="a", i=1, f=1.0), schema),
            (["a", 1], schema),
            (["a", None], schema),
            (("a", 1), schema),
            (MyObj(s="a", i=1), schema),
            (MyObj(s="a", i=None), schema),
            (MyObj(s="a"), schema),
        ]

        # obj, data_type, exception class
        failure_spec = [
            # String (match anything but None)
            (None, StringType(), ValueError),

            # UDT
            (ExamplePoint(1.0, 2.0), PythonOnlyUDT(), ValueError),

            # Boolean
            (1, BooleanType(), TypeError),
            ("True", BooleanType(), TypeError),
            ([1], BooleanType(), TypeError),

            # Byte
            (-(2**7) - 1, ByteType(), ValueError),
            (2**7, ByteType(), ValueError),
            ("1", ByteType(), TypeError),
            (1.0, ByteType(), TypeError),

            # Short
            (-(2**15) - 1, ShortType(), ValueError),
            (2**15, ShortType(), ValueError),

            # Integer
            (-(2**31) - 1, IntegerType(), ValueError),
            (2**31, IntegerType(), ValueError),

            # Float & Double
            (1, FloatType(), TypeError),
            (1, DoubleType(), TypeError),

            # Decimal
            (1.0, DecimalType(), TypeError),
            (1, DecimalType(), TypeError),
            ("1.0", DecimalType(), TypeError),

            # Binary
            (1, BinaryType(), TypeError),

            # Date/Timestamp
            ("2000-01-02", DateType(), TypeError),
            (946811040, TimestampType(), TypeError),

            # Array
            (["1", None], ArrayType(StringType(), containsNull=False), ValueError),
            ([1, "2"], ArrayType(IntegerType()), TypeError),

            # Map
            ({"a": 1}, MapType(IntegerType(), IntegerType()), TypeError),
            ({"a": "1"}, MapType(StringType(), IntegerType()), TypeError),
            ({"a": None}, MapType(StringType(), IntegerType(), valueContainsNull=False),
             ValueError),

            # Struct
            ({"s": "a", "i": "1"}, schema, TypeError),
            (Row(s="a"), schema, ValueError),     # Row can't have missing field
            (Row(s="a", i="1"), schema, TypeError),
            (["a"], schema, ValueError),
            (["a", "1"], schema, TypeError),
            (MyObj(s="a", i="1"), schema, TypeError),
            (MyObj(s=None, i="1"), schema, ValueError),
        ]

        # Check success cases. %r keeps e.g. "1" and 1 distinguishable in the
        # message, and the raised exception is reported instead of discarded.
        for obj, data_type in success_spec:
            try:
                _make_type_verifier(data_type, nullable=False)(obj)
            except Exception as e:
                self.fail("verify_type(%r, %s, nullable=False) raised %r"
                          % (obj, data_type, e))

        # Check failure cases
        for obj, data_type, exp in failure_spec:
            msg = "verify_type(%r, %s, nullable=False) should raise %s" % (
                obj, data_type, exp.__name__)
            with self.assertRaises(exp, msg=msg):
                _make_type_verifier(data_type, nullable=False)(obj)