Example #1
    def test_udt(self):
        from pyspark.sql.types import _parse_datatype_json_string, _infer_type, _verify_type
        from pyspark.sql.tests import ExamplePointUDT, ExamplePoint

        def check_datatype(datatype):
            pickled = pickle.loads(pickle.dumps(datatype))
            assert datatype == pickled
            scala_datatype = self.sqlCtx._ssql_ctx.parseDataType(datatype.json())
            python_datatype = _parse_datatype_json_string(scala_datatype.json())
            assert datatype == python_datatype

        check_datatype(ExamplePointUDT())
        structtype_with_udt = StructType([StructField("label", DoubleType(), False),
                                          StructField("point", ExamplePointUDT(), False)])
        check_datatype(structtype_with_udt)
        p = ExamplePoint(1.0, 2.0)
        self.assertEqual(_infer_type(p), ExamplePointUDT())
        _verify_type(ExamplePoint(1.0, 2.0), ExamplePointUDT())
        self.assertRaises(ValueError, lambda: _verify_type([1.0, 2.0], ExamplePointUDT()))

        check_datatype(PythonOnlyUDT())
        structtype_with_udt = StructType([StructField("label", DoubleType(), False),
                                          StructField("point", PythonOnlyUDT(), False)])
        check_datatype(structtype_with_udt)
        p = PythonOnlyPoint(1.0, 2.0)
        self.assertEqual(_infer_type(p), PythonOnlyUDT())
        _verify_type(PythonOnlyPoint(1.0, 2.0), PythonOnlyUDT())
        self.assertRaises(ValueError, lambda: _verify_type([1.0, 2.0], PythonOnlyUDT()))
Example #2
    def test_parquet_with_udt(self):
        from pyspark.sql.tests import ExamplePoint
        row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
        df0 = self.sc.parallelize([row]).toDF()
        output_dir = os.path.join(self.tempdir.name, "labeled_point")
        df0.saveAsParquetFile(output_dir)
        df1 = self.sqlCtx.parquetFile(output_dir)
        point = df1.head().point
        self.assertEqual(point, ExamplePoint(1.0, 2.0))
Example #3
    def test_apply_schema_with_udt(self):
        from pyspark.sql.tests import ExamplePoint, ExamplePointUDT
        row = (1.0, ExamplePoint(1.0, 2.0))
        rdd = self.sc.parallelize([row])
        schema = StructType([StructField("label", DoubleType(), False),
                             StructField("point", ExamplePointUDT(), False)])
        df = rdd.toDF(schema)
        point = df.head().point
        self.assertEqual(point, ExamplePoint(1.0, 2.0))
Example #4
    def test_infer_schema_with_udt(self):
        from pyspark.sql.tests import ExamplePoint, ExamplePointUDT
        row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
        df = self.sc.parallelize([row]).toDF()
        schema = df.schema
        field = [f for f in schema.fields if f.name == "point"][0]
        self.assertEqual(type(field.dataType), ExamplePointUDT)
        df.registerTempTable("labeled_point")
        point = self.sqlCtx.sql("SELECT point FROM labeled_point").head().point
        self.assertEqual(point, ExamplePoint(1.0, 2.0))
Example #5
    def test_udf_with_udt(self):
        from pyspark.sql.tests import ExamplePoint, ExamplePointUDT
        row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
        df = self.sc.parallelize([row]).toDF()
        self.assertEqual(1.0, df.map(lambda r: r.point.x).first())
        udf = UserDefinedFunction(lambda p: p.y, DoubleType())
        self.assertEqual(2.0, df.select(udf(df.point)).first()[0])
        udf2 = UserDefinedFunction(lambda p: ExamplePoint(p.x + 1, p.y + 1),
                                   ExamplePointUDT())
        self.assertEqual(ExamplePoint(2.0, 3.0),
                         df.select(udf2(df.point)).first()[0])
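
The same UDF-over-UDT pattern can also be written with the public pyspark.sql.functions.udf helper instead of constructing UserDefinedFunction directly. A minimal sketch, assuming the df, ExamplePoint, and ExamplePointUDT objects from the example above are in scope (shift_point is a hypothetical name):

from pyspark.sql.functions import udf

# udf() accepts a UDT instance as the return type, just like UserDefinedFunction.
shift_point = udf(lambda p: ExamplePoint(p.x + 1, p.y + 1), ExamplePointUDT())
shifted = df.select(shift_point(df.point).alias("point"))
# shifted.first().point == ExamplePoint(2.0, 3.0)
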
Example #6
    def test_apply_schema_with_udt(self):
        from pyspark.sql.tests import ExamplePoint, ExamplePointUDT
        row = (1.0, ExamplePoint(1.0, 2.0))
        schema = StructType([StructField("label", DoubleType(), False),
                             StructField("point", ExamplePointUDT(), False)])
        df = self.sqlCtx.createDataFrame([row], schema)
        point = df.head().point
        self.assertEqual(point, ExamplePoint(1.0, 2.0))

        row = (1.0, PythonOnlyPoint(1.0, 2.0))
        schema = StructType([StructField("label", DoubleType(), False),
                             StructField("point", PythonOnlyUDT(), False)])
        df = self.sqlCtx.createDataFrame([row], schema)
        point = df.head().point
        self.assertEqual(point, PythonOnlyPoint(1.0, 2.0))
Example #7
    def runTest(self):
        from pyspark.sql import Row
        from pyspark.sql.types import DoubleType, StructType, StructField
        from pyspark.sql.tests import ExamplePoint, ExamplePointUDT
        row1 = (1.0, ExamplePoint(1.0, 2.0))
        row2 = (2.0, ExamplePoint(3.0, 4.0))
        schema = StructType([
            StructField("label", DoubleType(), False),
            StructField("point", ExamplePointUDT(), False)
        ])
        df1 = self.spark.createDataFrame([row1], schema)
        df2 = self.spark.createDataFrame([row2], schema)

        result = df1.union(df2).orderBy("label").collect()
        self.assertEqual(result, [
            Row(label=1.0, point=ExamplePoint(1.0, 2.0)),
            Row(label=2.0, point=ExamplePoint(3.0, 4.0))
        ])
Example #8
    def test_parquet_with_udt(self):
        from pyspark.sql.tests import ExamplePoint, ExamplePointUDT
        row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
        df0 = self.sqlCtx.createDataFrame([row])
        output_dir = os.path.join(self.tempdir.name, "labeled_point")
        df0.write.parquet(output_dir)
        df1 = self.sqlCtx.parquetFile(output_dir)
        point = df1.head().point
        self.assertEqual(point, ExamplePoint(1.0, 2.0))

        row = Row(label=1.0, point=PythonOnlyPoint(1.0, 2.0))
        df0 = self.sqlCtx.createDataFrame([row])
        df0.write.parquet(output_dir, mode='overwrite')
        df1 = self.sqlCtx.parquetFile(output_dir)
        point = df1.head().point
        self.assertEqual(point, PythonOnlyPoint(1.0, 2.0))
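
On newer Spark versions, where SQLContext.parquetFile is deprecated, the read side of the round trip above is usually expressed through the DataFrameReader API instead. A sketch under that assumption, reusing the same output_dir and a SparkSession named spark:

# Read the UDT-bearing Parquet data back with the DataFrameReader API.
df1 = spark.read.parquet(output_dir)
point = df1.head().point
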
Example #9
    def test_udf_with_udt(self):
        from pyspark.sql.tests import ExamplePoint, ExamplePointUDT
        row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
        df = self.sqlCtx.createDataFrame([row])
        self.assertEqual(1.0, df.map(lambda r: r.point.x).first())
        udf = UserDefinedFunction(lambda p: p.y, DoubleType())
        self.assertEqual(2.0, df.select(udf(df.point)).first()[0])
        udf2 = UserDefinedFunction(lambda p: ExamplePoint(p.x + 1, p.y + 1), ExamplePointUDT())
        self.assertEqual(ExamplePoint(2.0, 3.0), df.select(udf2(df.point)).first()[0])

        row = Row(label=1.0, point=PythonOnlyPoint(1.0, 2.0))
        df = self.sqlCtx.createDataFrame([row])
        self.assertEqual(1.0, df.map(lambda r: r.point.x).first())
        udf = UserDefinedFunction(lambda p: p.y, DoubleType())
        self.assertEqual(2.0, df.select(udf(df.point)).first()[0])
        udf2 = UserDefinedFunction(lambda p: PythonOnlyPoint(p.x + 1, p.y + 1), PythonOnlyUDT())
        self.assertEqual(PythonOnlyPoint(2.0, 3.0), df.select(udf2(df.point)).first()[0])
Example #10
    def test_gapply_universal_udt_val(self):
        def pandasAggFunction(series):
            x = float(series.apply(lambda pt: int(pt.x) + int(pt.y)).sum())
            return ExamplePoint(
                x, x)  # still deterministic, can have exact equivalence test

        dataType = ExamplePointUDT()
        dataGen = lambda: ExamplePoint(
            float(random.randrange(GapplyTests.NVALS)),
            float(random.randrange(GapplyTests.NVALS)))
        self.checkGapplyEquivalentToPandas(pandasAggFunction, dataType,
                                           dataGen)
Example #11
    def deserialize(self, datum):
        return ExamplePoint(datum[0], datum[1])
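
For context, deserialize is one half of a UDT's storage contract: serialize maps the Python object to its Catalyst representation, and deserialize rebuilds it on read. Below is a minimal sketch of a class this method could belong to, loosely modeled on the ExamplePoint/ExamplePointUDT pair used throughout these examples (the real classes in pyspark.sql.tests differ in details such as their Scala counterparts):

from pyspark.sql.types import UserDefinedType, ArrayType, DoubleType


class ExamplePoint:
    """Plain Python value object carried through Spark SQL as a UDT."""

    def __init__(self, x, y):
        self.x = x
        self.y = y

    def __eq__(self, other):
        return (isinstance(other, ExamplePoint)
                and self.x == other.x and self.y == other.y)


class ExamplePointUDT(UserDefinedType):
    """Maps ExamplePoint to and from an array of two doubles."""

    @classmethod
    def sqlType(cls):
        # Catalyst storage type for the point.
        return ArrayType(DoubleType(), False)

    @classmethod
    def module(cls):
        # Module where the Python UDT class lives.
        return "pyspark.sql.tests"

    def serialize(self, obj):
        return [obj.x, obj.y]

    def deserialize(self, datum):
        return ExamplePoint(datum[0], datum[1])
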
Example #12
def pandasAggFunction(series):
    x = float(series.apply(lambda pt: int(pt.x) + int(pt.y)).sum())
    return ExamplePoint(
        x, x)  # still deterministic, can have exact equivalence test
Example #13
def dataGen():
    return ExamplePoint(float(random.randrange(GapplyTests.NVALS)),
                        float(random.randrange(GapplyTests.NVALS)))
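
A short standalone check of the helpers above, outside the gapply test harness. This is only an illustrative sketch: NVALS stands in for GapplyTests.NVALS, and pandasAggFunction is assumed to be defined as in Example #12:

import random

import pandas as pd
from pyspark.sql.tests import ExamplePoint

NVALS = 100  # stand-in for GapplyTests.NVALS

points = [ExamplePoint(float(random.randrange(NVALS)),
                       float(random.randrange(NVALS))) for _ in range(5)]
agg = pandasAggFunction(pd.Series(points))
# The aggregate stores the same deterministic sum in both coordinates.
assert agg.x == agg.y
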