Example #1
    def test_udf_with_string_return_type(self):
        add_one = UserDefinedFunction(lambda x: x + 1, "integer")
        make_pair = UserDefinedFunction(lambda x: (-x, x), "struct<x:integer,y:integer>")
        make_array = UserDefinedFunction(
            lambda x: [float(i) for i in range(x, x + 3)], "array<double>")

        expected = (2, Row(x=-1, y=1), [1.0, 2.0, 3.0])
        actual = (self.spark.range(1, 2).toDF("x")
                  .select(add_one("x"), make_pair("x"), make_array("x"))
                  .first())

        self.assertTupleEqual(expected, actual)
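
The string return types above ("integer", "struct<x:integer,y:integer>", "array<double>") are DDL-style shorthand for DataType objects. Below is a minimal sketch of the equivalent declarations with explicit types, assuming an active SparkSession named `spark`:

    from pyspark.sql.functions import udf
    from pyspark.sql.types import (ArrayType, DoubleType, IntegerType,
                                   StructField, StructType)

    # Same UDFs as the test above, but with explicit DataType return types
    # instead of DDL-formatted strings.
    add_one = udf(lambda x: x + 1, IntegerType())
    make_pair = udf(lambda x: (-x, x),
                    StructType([StructField("x", IntegerType()),
                                StructField("y", IntegerType())]))
    make_array = udf(lambda x: [float(i) for i in range(x, x + 3)],
                     ArrayType(DoubleType()))

    spark.range(1, 2).toDF("x").select(add_one("x"), make_pair("x"), make_array("x")).show()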
Example #2
    def registerFunction(self, name, f, returnType=StringType()):
        """Registers a Python function (including lambda function) or a :class:`UserDefinedFunction`
        as a UDF. The registered UDF can be used in SQL statements.

        In addition to a name and the function itself, the return type can be optionally specified.
        When the return type is not given, it defaults to a string and conversion will automatically
        be done. For any other return type, the produced object must match the specified type.

        :param name: name of the UDF
        :param f: a Python function, or a wrapped/native UserDefinedFunction
        :param returnType: a :class:`pyspark.sql.types.DataType` object
        :return: a wrapped :class:`UserDefinedFunction`

        >>> strlen = spark.catalog.registerFunction("stringLengthString", len)
        >>> spark.sql("SELECT stringLengthString('test')").collect()
        [Row(stringLengthString(test)=u'4')]

        >>> spark.sql("SELECT 'foo' AS text").select(strlen("text")).collect()
        [Row(stringLengthString(text)=u'3')]

        >>> from pyspark.sql.types import IntegerType
        >>> _ = spark.catalog.registerFunction("stringLengthInt", len, IntegerType())
        >>> spark.sql("SELECT stringLengthInt('test')").collect()
        [Row(stringLengthInt(test)=4)]

        >>> from pyspark.sql.types import IntegerType
        >>> _ = spark.udf.register("stringLengthInt", len, IntegerType())
        >>> spark.sql("SELECT stringLengthInt('test')").collect()
        [Row(stringLengthInt(test)=4)]

        >>> import random
        >>> from pyspark.sql.functions import udf
        >>> from pyspark.sql.types import IntegerType, StringType
        >>> random_udf = udf(lambda: random.randint(0, 100), IntegerType()).asNondeterministic()
        >>> newRandom_udf = spark.catalog.registerFunction("random_udf", random_udf, StringType())
        >>> spark.sql("SELECT random_udf()").collect()  # doctest: +SKIP
        [Row(random_udf()=u'82')]
        >>> spark.range(1).select(newRandom_udf()).collect()  # doctest: +SKIP
        [Row(random_udf()=u'62')]
        """

        # This is to check whether the input function is a wrapped/native UserDefinedFunction
        if hasattr(f, 'asNondeterministic'):
            udf = UserDefinedFunction(f.func, returnType=returnType, name=name,
                                      evalType=PythonEvalType.SQL_BATCHED_UDF,
                                      deterministic=f.deterministic)
        else:
            udf = UserDefinedFunction(f, returnType=returnType, name=name,
                                      evalType=PythonEvalType.SQL_BATCHED_UDF)
        self._jsparkSession.udf().registerPython(name, udf._judf)
        return udf._wrapped()
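
A hedged usage sketch of the two branches above, following the signature shown in this version and assuming an active SparkSession named `spark`: a plain Python function is wrapped with the given returnType, while an object that already exposes asNondeterministic() is treated as an existing UDF and keeps its deterministic flag.

    import random
    from pyspark.sql.functions import udf
    from pyspark.sql.types import IntegerType

    # else branch: a plain Python function is wrapped with the given returnType.
    spark.catalog.registerFunction("plus_one", lambda x: x + 1, IntegerType())

    # hasattr branch: an existing UDF; its deterministic=False flag is preserved.
    # IntegerType() is passed explicitly so the integer result is not converted to a string.
    rand_udf = udf(lambda: random.randint(0, 100), IntegerType()).asNondeterministic()
    spark.catalog.registerFunction("rand_int", rand_udf, IntegerType())

    spark.sql("SELECT plus_one(1), rand_int()").show()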
Example #3
 def test_udf_registration_return_type_not_none(self):
     with QuietTest(self.sc):
         with self.assertRaisesRegex(TypeError, "Invalid return type"):
             self.spark.catalog.registerFunction(
                 "f",
                 UserDefinedFunction(lambda x, y: len(x) + y, StringType()),
                 StringType())
Example #4
 def test_udf_registration_return_type_none(self):
     two_args = self.spark.catalog.registerFunction(
         "twoArgs", UserDefinedFunction(lambda x, y: len(x) + y, "integer"),
         None)
     self.assertEqual(two_args.deterministic, True)
     [row] = self.spark.sql("SELECT twoArgs('test', 1)").collect()
     self.assertEqual(row[0], 5)
Example #5
    def test_udf_with_udt(self):
        row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
        df = self.spark.createDataFrame([row])
        self.assertEqual(1.0, df.rdd.map(lambda r: r.point.x).first())
        udf = UserDefinedFunction(lambda p: p.y, DoubleType())
        self.assertEqual(2.0, df.select(udf(df.point)).first()[0])
        udf2 = UserDefinedFunction(lambda p: ExamplePoint(p.x + 1, p.y + 1), ExamplePointUDT())
        self.assertEqual(ExamplePoint(2.0, 3.0), df.select(udf2(df.point)).first()[0])

        row = Row(label=1.0, point=PythonOnlyPoint(1.0, 2.0))
        df = self.spark.createDataFrame([row])
        self.assertEqual(1.0, df.rdd.map(lambda r: r.point.x).first())
        udf = UserDefinedFunction(lambda p: p.y, DoubleType())
        self.assertEqual(2.0, df.select(udf(df.point)).first()[0])
        udf2 = UserDefinedFunction(lambda p: PythonOnlyPoint(p.x + 1, p.y + 1), PythonOnlyUDT())
        self.assertEqual(PythonOnlyPoint(2.0, 3.0), df.select(udf2(df.point)).first()[0])
Example #6
 def test_udf3(self):
     two_args = self.spark.catalog.registerFunction(
         "twoArgs", UserDefinedFunction(lambda x, y: len(x) + y)
     )
     self.assertEqual(two_args.deterministic, True)
     [row] = self.spark.sql("SELECT twoArgs('test', 1)").collect()
     self.assertEqual(row[0], "5")
Example #7
    def test_udf_init_should_not_initialize_context(self):
        UserDefinedFunction(lambda x: x, StringType())

        self.assertIsNone(
            SparkContext._active_spark_context,
            "SparkContext shouldn't be initialized when UserDefinedFunction is created."
        )
        self.assertIsNone(
            SparkSession._instantiatedSession,
            "SparkSession shouldn't be initialized when UserDefinedFunction is created."
        )
Example #8
    def test_udf_with_partial_function(self):
        d = [Row(number=i, squared=i**2) for i in range(10)]
        rdd = self.sc.parallelize(d)
        data = self.spark.createDataFrame(rdd)

        def some_func(col, param):
            if col is not None:
                return col + param

        pfunc = functools.partial(some_func, param=4)
        pudf = UserDefinedFunction(pfunc, LongType())
        res = data.select(pudf(data['number']).alias('plus_four'))
        self.assertEqual(res.agg({'plus_four': 'sum'}).collect()[0][0], 85)
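
The same partial-function pattern also works through the public udf() helper. A minimal sketch, assuming an active SparkSession named `spark` (the helper and column names here are illustrative):

    import functools
    from pyspark.sql.functions import udf
    from pyspark.sql.types import LongType

    def add_param(col, param):
        # Mirror the null handling of some_func above.
        if col is not None:
            return col + param

    # Bind param=4 up front, then wrap the partial as a row-at-a-time UDF.
    plus_four = udf(functools.partial(add_param, param=4), LongType())
    spark.range(10).select(plus_four("id").alias("plus_four")).show()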
Example #9
    def test_udf_with_callable(self):
        d = [Row(number=i, squared=i**2) for i in range(10)]
        rdd = self.sc.parallelize(d)
        data = self.spark.createDataFrame(rdd)

        class PlusFour:
            def __call__(self, col):
                if col is not None:
                    return col + 4

        call = PlusFour()
        pudf = UserDefinedFunction(call, LongType())
        res = data.select(pudf(data['number']).alias('plus_four'))
        self.assertEqual(res.agg({'plus_four': 'sum'}).collect()[0][0], 85)
Example #10
    def registerFunction(self, name, f, returnType=StringType()):
        """Registers a python function (including lambda function) as a UDF
        so it can be used in SQL statements.

        In addition to a name and the function itself, the return type can be optionally specified.
        When the return type is not given, it defaults to a string and conversion will automatically
        be done. For any other return type, the produced object must match the specified type.

        :param name: name of the UDF
        :param f: python function
        :param returnType: a :class:`pyspark.sql.types.DataType` object
        :return: a wrapped :class:`UserDefinedFunction`

        >>> strlen = spark.catalog.registerFunction("stringLengthString", len)
        >>> spark.sql("SELECT stringLengthString('test')").collect()
        [Row(stringLengthString(test)=u'4')]

        >>> spark.sql("SELECT 'foo' AS text").select(strlen("text")).collect()
        [Row(stringLengthString(text)=u'3')]

        >>> from pyspark.sql.types import IntegerType
        >>> _ = spark.catalog.registerFunction("stringLengthInt", len, IntegerType())
        >>> spark.sql("SELECT stringLengthInt('test')").collect()
        [Row(stringLengthInt(test)=4)]

        >>> from pyspark.sql.types import IntegerType
        >>> _ = spark.udf.register("stringLengthInt", len, IntegerType())
        >>> spark.sql("SELECT stringLengthInt('test')").collect()
        [Row(stringLengthInt(test)=4)]
        """
        udf = UserDefinedFunction(f,
                                  returnType=returnType,
                                  name=name,
                                  evalType=PythonEvalType.SQL_BATCHED_UDF)
        self._jsparkSession.udf().registerPython(name, udf._judf)
        return udf._wrapped()
Example #11
    def test_udf_defers_judf_initialization(self):
        # This is kept separate from UDFInitializationTests to avoid
        # context initialization when the UDF is called.
        f = UserDefinedFunction(lambda x: x, StringType())

        self.assertIsNone(
            f._judf_placeholder, "judf should not be initialized before the first call."
        )

        self.assertIsInstance(f("foo"), Column, "UDF call should return a Column.")

        self.assertIsNotNone(
            f._judf_placeholder, "judf should be initialized after UDF has been called."
        )
Example #12
    def registerFunction(self, name, f, returnType=None):
        """Registers a Python function (including lambda function) or a :class:`UserDefinedFunction`
        as a UDF. The registered UDF can be used in SQL statements.

        :func:`spark.udf.register` is an alias for :func:`spark.catalog.registerFunction`.

        In addition to a name and the function itself, `returnType` can be optionally specified.
        1) When f is a Python function, `returnType` defaults to a string, and the produced object
        must match the specified type. 2) When f is a :class:`UserDefinedFunction`, Spark uses the
        return type of the given UDF as the return type of the registered UDF; in that case
        `returnType` must be left as None (its default).

        :param name: name of the UDF in SQL statements.
        :param f: a Python function, or a wrapped/native UserDefinedFunction. The UDF can be either
            row-at-a-time or vectorized.
        :param returnType: the return type of the registered UDF.
        :return: a wrapped/native :class:`UserDefinedFunction`

        >>> strlen = spark.catalog.registerFunction("stringLengthString", len)
        >>> spark.sql("SELECT stringLengthString('test')").collect()
        [Row(stringLengthString(test)=u'4')]

        >>> spark.sql("SELECT 'foo' AS text").select(strlen("text")).collect()
        [Row(stringLengthString(text)=u'3')]

        >>> from pyspark.sql.types import IntegerType
        >>> _ = spark.catalog.registerFunction("stringLengthInt", len, IntegerType())
        >>> spark.sql("SELECT stringLengthInt('test')").collect()
        [Row(stringLengthInt(test)=4)]

        >>> from pyspark.sql.types import IntegerType
        >>> _ = spark.udf.register("stringLengthInt", len, IntegerType())
        >>> spark.sql("SELECT stringLengthInt('test')").collect()
        [Row(stringLengthInt(test)=4)]

        >>> from pyspark.sql.types import IntegerType
        >>> from pyspark.sql.functions import udf
        >>> slen = udf(lambda s: len(s), IntegerType())
        >>> _ = spark.udf.register("slen", slen)
        >>> spark.sql("SELECT slen('test')").collect()
        [Row(slen(test)=4)]

        >>> import random
        >>> from pyspark.sql.functions import udf
        >>> from pyspark.sql.types import IntegerType
        >>> random_udf = udf(lambda: random.randint(0, 100), IntegerType()).asNondeterministic()
        >>> new_random_udf = spark.catalog.registerFunction("random_udf", random_udf)
        >>> spark.sql("SELECT random_udf()").collect()  # doctest: +SKIP
        [Row(random_udf()=82)]
        >>> spark.range(1).select(new_random_udf()).collect()  # doctest: +SKIP
        [Row(<lambda>()=26)]

        >>> from pyspark.sql.functions import pandas_udf, PandasUDFType
        >>> @pandas_udf("integer", PandasUDFType.SCALAR)  # doctest: +SKIP
        ... def add_one(x):
        ...     return x + 1
        ...
        >>> _ = spark.udf.register("add_one", add_one)  # doctest: +SKIP
        >>> spark.sql("SELECT add_one(id) FROM range(3)").collect()  # doctest: +SKIP
        [Row(add_one(id)=1), Row(add_one(id)=2), Row(add_one(id)=3)]
        """

        # This is to check whether the input function is a wrapped/native UserDefinedFunction
        if hasattr(f, 'asNondeterministic'):
            if returnType is not None:
                raise TypeError(
                    "Invalid returnType: None is expected when f is a UserDefinedFunction, "
                    "but got %s." % returnType)
            if f.evalType not in [
                    PythonEvalType.SQL_BATCHED_UDF,
                    PythonEvalType.SQL_PANDAS_SCALAR_UDF
            ]:
                raise ValueError(
                    "Invalid f: f must be either SQL_BATCHED_UDF or SQL_PANDAS_SCALAR_UDF"
                )
            register_udf = UserDefinedFunction(f.func,
                                               returnType=f.returnType,
                                               name=name,
                                               evalType=f.evalType,
                                               deterministic=f.deterministic)
            return_udf = f
        else:
            if returnType is None:
                returnType = StringType()
            register_udf = UserDefinedFunction(
                f,
                returnType=returnType,
                name=name,
                evalType=PythonEvalType.SQL_BATCHED_UDF)
            return_udf = register_udf._wrapped()
        self._jsparkSession.udf().registerPython(name, register_udf._judf)
        return return_udf
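
A hedged sketch of the returnType rule enforced above, assuming an active SparkSession named `spark`: passing a returnType together with an already-built UDF is rejected, because the UDF's own return type is used instead.

    from pyspark.sql.functions import udf
    from pyspark.sql.types import IntegerType, StringType

    slen = udf(lambda s: len(s), IntegerType())
    try:
        # Rejected: returnType must be None when f is a UserDefinedFunction.
        spark.catalog.registerFunction("slen", slen, StringType())
    except TypeError as e:
        print(e)

    # Correct: omit returnType and let the UDF's own IntegerType be used.
    slen_registered = spark.catalog.registerFunction("slen", slen)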