コード例 #1
0
ファイル: test_types.py プロジェクト: tchenthilkumar/spark
    def test_udt(self):
        # Exercises ExamplePointUDT and PythonOnlyUDT identically: data type round-trips,
        # type inference from a point instance, and type verification of values.
        from pyspark.sql.types import _parse_datatype_json_string, _infer_type, _make_type_verifier

        def check_datatype(datatype):
            # Use the TestCase assertions rather than bare ``assert`` so the checks are
            # not stripped under ``python -O`` and failures show both operands.
            pickled = pickle.loads(pickle.dumps(datatype))
            self.assertEqual(datatype, pickled)
            # Round-trip the JSON form through the session's ``parseDataType`` and back
            # through the Python-side JSON parser.
            scala_datatype = self.spark._jsparkSession.parseDataType(datatype.json())
            python_datatype = _parse_datatype_json_string(scala_datatype.json())
            self.assertEqual(datatype, python_datatype)

        def check_udt(udt_cls, point_cls):
            # The UDT must round-trip both on its own and as a struct field.
            check_datatype(udt_cls())
            structtype_with_udt = StructType([StructField("label", DoubleType(), False),
                                              StructField("point", udt_cls(), False)])
            check_datatype(structtype_with_udt)

            # A point instance is inferred as its UDT.
            p = point_cls(1.0, 2.0)
            self.assertEqual(_infer_type(p), udt_cls())

            # The verifier accepts a point instance and rejects a plain list.
            _make_type_verifier(udt_cls())(point_cls(1.0, 2.0))
            with self.assertRaises(ValueError):
                _make_type_verifier(udt_cls())([1.0, 2.0])

        check_udt(ExamplePointUDT, ExamplePoint)
        check_udt(PythonOnlyUDT, PythonOnlyPoint)
コード例 #2
0
ファイル: test_types.py プロジェクト: tchenthilkumar/spark
    def test_nested_udt_in_df(self):
        # Build and collect DataFrames in which the UDT is nested inside a container type.

        # UDT as the element type of an array column.
        array_schema = StructType().add("key", LongType()).add("val", ArrayType(PythonOnlyUDT()))
        array_rows = [(i % 3, [PythonOnlyPoint(float(i), float(i))]) for i in range(10)]
        self.spark.createDataFrame(array_rows, schema=array_schema).collect()

        # UDT as the value type of a map column.
        map_type = MapType(LongType(), PythonOnlyUDT())
        map_schema = StructType().add("key", LongType()).add("val", map_type)
        map_rows = [
            (i % 3, {i % 3: PythonOnlyPoint(float(i + 1), float(i + 1))})
            for i in range(10)
        ]
        self.spark.createDataFrame(map_rows, schema=map_schema).collect()
コード例 #3
0
ファイル: test_types.py プロジェクト: tchenthilkumar/spark
    def test_apply_schema_with_udt(self):
        # For each point class / UDT pair, build a single-row DataFrame under an explicit
        # schema and check that the point read back from the first row equals the input.
        pairs = ((ExamplePoint, ExamplePointUDT), (PythonOnlyPoint, PythonOnlyUDT))
        for point_cls, udt_cls in pairs:
            record = (1.0, point_cls(1.0, 2.0))
            fields = [StructField("label", DoubleType(), False),
                      StructField("point", udt_cls(), False)]
            frame = self.spark.createDataFrame([record], StructType(fields))
            self.assertEqual(frame.head().point, point_cls(1.0, 2.0))
コード例 #4
0
ファイル: test_types.py プロジェクト: tchenthilkumar/spark
    def test_parquet_with_udt(self):
        # Write a row holding a UDT point to Parquet, read it back, and compare the point.
        example_row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
        written = self.spark.createDataFrame([example_row])
        parquet_path = os.path.join(self.tempdir.name, "labeled_point")
        written.write.parquet(parquet_path)
        loaded = self.spark.read.parquet(parquet_path)
        self.assertEqual(loaded.head().point, ExamplePoint(1.0, 2.0))

        # Same check for the Python-only point; the path is reused, hence the overwrite mode.
        python_only_row = Row(label=1.0, point=PythonOnlyPoint(1.0, 2.0))
        written = self.spark.createDataFrame([python_only_row])
        written.write.parquet(parquet_path, mode='overwrite')
        loaded = self.spark.read.parquet(parquet_path)
        self.assertEqual(loaded.head().point, PythonOnlyPoint(1.0, 2.0))
コード例 #5
0
ファイル: test_types.py プロジェクト: tchenthilkumar/spark
 def test_cast_to_udt_with_udt(self):
     # Casting a column of one UDT to the other UDT is expected to raise
     # AnalysisException, in both directions.
     from pyspark.sql.functions import col

     points = Row(point=ExamplePoint(1.0, 2.0), python_only_point=PythonOnlyPoint(1.0, 2.0))
     frame = self.spark.createDataFrame([points])

     with self.assertRaises(AnalysisException):
         frame.select(col("point").cast(PythonOnlyUDT()))
     with self.assertRaises(AnalysisException):
         frame.select(col("python_only_point").cast(ExamplePointUDT()))
コード例 #6
0
ファイル: test_types.py プロジェクト: tchenthilkumar/spark
    def test_udf_with_udt(self):
        # Checks that UDT points can be read in an RDD map, consumed by a UDF, and
        # returned from a UDF, first for ExamplePoint and then for PythonOnlyPoint.
        example_df = self.spark.createDataFrame([Row(label=1.0, point=ExamplePoint(1.0, 2.0))])
        first_x = example_df.rdd.map(lambda r: r.point.x).first()
        self.assertEqual(1.0, first_x)
        # UDF taking a point and returning a plain double.
        get_y = UserDefinedFunction(lambda p: p.y, DoubleType())
        self.assertEqual(2.0, example_df.select(get_y(example_df.point)).first()[0])
        # UDF taking a point and returning a new point.
        shift = UserDefinedFunction(lambda p: ExamplePoint(p.x + 1, p.y + 1), ExamplePointUDT())
        shifted = example_df.select(shift(example_df.point)).first()[0]
        self.assertEqual(ExamplePoint(2.0, 3.0), shifted)

        python_df = self.spark.createDataFrame(
            [Row(label=1.0, point=PythonOnlyPoint(1.0, 2.0))])
        first_x = python_df.rdd.map(lambda r: r.point.x).first()
        self.assertEqual(1.0, first_x)
        get_y = UserDefinedFunction(lambda p: p.y, DoubleType())
        self.assertEqual(2.0, python_df.select(get_y(python_df.point)).first()[0])
        shift = UserDefinedFunction(lambda p: PythonOnlyPoint(p.x + 1, p.y + 1), PythonOnlyUDT())
        shifted = python_df.select(shift(python_df.point)).first()[0]
        self.assertEqual(PythonOnlyPoint(2.0, 3.0), shifted)
コード例 #7
0
 def test_simple_udt_in_df(self):
     # Build a DataFrame with a top-level PythonOnlyUDT column and collect it.
     key_val_schema = StructType().add("key", LongType()).add("val", PythonOnlyUDT())
     rows = [(i % 3, PythonOnlyPoint(float(i), float(i))) for i in range(10)]
     self.spark.createDataFrame(rows, schema=key_val_schema).collect()
コード例 #8
0
ファイル: test_types.py プロジェクト: tchenthilkumar/spark
    def test_cast_to_string_with_udt(self):
        # Cast both UDT columns to string and compare against the expected text forms.
        from pyspark.sql.functions import col

        fields = [StructField("point", ExamplePointUDT(), False),
                  StructField("pypoint", PythonOnlyUDT(), False)]
        points = (ExamplePoint(1.0, 2.0), PythonOnlyPoint(3.0, 4.0))
        frame = self.spark.createDataFrame([points], StructType(fields))

        as_strings = frame.select(col('point').cast('string'), col('pypoint').cast('string'))
        expected = Row(point=u'(1.0, 2.0)', pypoint=u'[3.0, 4.0]')
        self.assertEqual(as_strings.head(), expected)
コード例 #9
0
ファイル: test_types.py プロジェクト: tchenthilkumar/spark
    def test_udt_with_none(self):
        # A UDF registered with a PythonOnlyUDT return type may also return None.
        ids = self.spark.range(0, 10, 1, 1)

        def myudf(x):
            # Non-positive input yields None; the first expected row below relies on that.
            if not x > 0:
                return None
            return PythonOnlyPoint(float(x), float(x))

        self.spark.catalog.registerFunction("udf", myudf, PythonOnlyUDT())
        first_two = ids.selectExpr("udf(id)").take(2)
        values = [r[0] for r in first_two]
        self.assertEqual(values, [None, PythonOnlyPoint(1, 1)])
コード例 #10
0
ファイル: test_types.py プロジェクト: tchenthilkumar/spark
    def test_infer_schema_with_udt(self):
        # No schema is passed to createDataFrame, so the "point" field's type is whatever
        # the DataFrame's schema reports; it must be exactly the matching UDT class.
        example_df = self.spark.createDataFrame(
            [Row(label=1.0, point=ExamplePoint(1.0, 2.0))])
        point_fields = [f for f in example_df.schema.fields if f.name == "point"]
        self.assertEqual(type(point_fields[0].dataType), ExamplePointUDT)

        # The point must also come back unchanged through a SQL query on a temp view.
        with self.tempView("labeled_point"):
            example_df.createOrReplaceTempView("labeled_point")
            selected = self.spark.sql("SELECT point FROM labeled_point").head()
            self.assertEqual(selected.point, ExamplePoint(1.0, 2.0))

        # Same two checks for the Python-only point.
        python_df = self.spark.createDataFrame(
            [Row(label=1.0, point=PythonOnlyPoint(1.0, 2.0))])
        point_fields = [f for f in python_df.schema.fields if f.name == "point"]
        self.assertEqual(type(point_fields[0].dataType), PythonOnlyUDT)

        with self.tempView("labeled_point"):
            python_df.createOrReplaceTempView("labeled_point")
            selected = self.spark.sql("SELECT point FROM labeled_point").head()
            self.assertEqual(selected.point, PythonOnlyPoint(1.0, 2.0))
コード例 #11
0
ファイル: test_types.py プロジェクト: tchenthilkumar/spark
    def test_complex_nested_udt_in_df(self):
        # Collect a UDT column, aggregate it with collect_list, then pass the grouped
        # columns through a UDF whose return type is an array of the original row schema.
        from pyspark.sql.functions import udf

        key_val_schema = StructType().add("key", LongType()).add("val", PythonOnlyUDT())
        rows = [(i % 3, PythonOnlyPoint(float(i), float(i))) for i in range(10)]
        df = self.spark.createDataFrame(rows, schema=key_val_schema)
        df.collect()

        grouped = df.groupby("key").agg({"val": "collect_list"})
        grouped.collect()

        # Pairs each key with the first collected point.
        first_pair = udf(lambda k, v: [(k, v[0])], ArrayType(df.schema))
        grouped.select(first_pair(*grouped)).collect()
コード例 #12
0
ファイル: test_types.py プロジェクト: tchenthilkumar/spark
 def myudf(x):
     # Return PythonOnlyPoint(x, x) with float coordinates for positive x, else None.
     if not x > 0:
         return None
     coord = float(x)
     return PythonOnlyPoint(coord, coord)