Code example #1
File: dstream.py Project: giworld/spark
 def _jdstream(self):
     if self._jdstream_val:
         return self._jdstream_val
     if self._bypass_serializer:
         self._jrdd_deserializer = NoOpSerializer()
     command = (self.func, self._prev_jrdd_deserializer,
                self._jrdd_deserializer)
     # the serialized command will be compressed by broadcast
     ser = CloudPickleSerializer()
     pickled_command = ser.dumps(command)
     if len(pickled_command) > (1 << 20):  # 1M
         broadcast = self.ctx.broadcast(pickled_command)
         pickled_command = ser.dumps(broadcast)
     broadcast_vars = ListConverter().convert(
         [x._jbroadcast for x in self.ctx._pickled_broadcast_vars],
         self.ctx._gateway._gateway_client)
     self.ctx._pickled_broadcast_vars.clear()
     class_tag = self._prev_jdstream.classTag()
     env = MapConverter().convert(self.ctx.environment,
                                  self.ctx._gateway._gateway_client)
     includes = ListConverter().convert(self.ctx._python_includes,
                                        self.ctx._gateway._gateway_client)
     python_dstream = self.ctx._jvm.PythonDStream(self._prev_jdstream.dstream(),
                                                  bytearray(pickled_command),
                                                  env, includes, self.preservesPartitioning,
                                                  self.ctx.pythonExec,
                                                  broadcast_vars, self.ctx._javaAccumulator,
                                                  class_tag)
     self._jdstream_val = python_dstream.asJavaDStream()
     return self._jdstream_val
Code example #2
File: test_serializers.py Project: Brett-A/spark
 def test_pickling_file_handles(self):
     # to be corrected with SPARK-11160
     try:
         import xmlrunner
     except ImportError:
         ser = CloudPickleSerializer()
         out1 = sys.stderr
         out2 = ser.loads(ser.dumps(out1))
         self.assertEqual(out1, out2)
Code example #3
File: test_serializers.py Project: Brett-A/spark
    def test_itemgetter(self):
        from operator import itemgetter
        ser = CloudPickleSerializer()
        d = range(10)
        getter = itemgetter(1)
        getter2 = ser.loads(ser.dumps(getter))
        self.assertEqual(getter(d), getter2(d))

        getter = itemgetter(0, 3)
        getter2 = ser.loads(ser.dumps(getter))
        self.assertEqual(getter(d), getter2(d))
Code example #4
File: test_serializers.py Project: Brett-A/spark
    def test_func_globals(self):

        class Unpicklable(object):
            def __reduce__(self):
                raise Exception("not picklable")

        global exit
        exit = Unpicklable()

        ser = CloudPickleSerializer()
        self.assertRaises(Exception, lambda: ser.dumps(exit))

        def foo():
            sys.exit(0)

        self.assertTrue("exit" in foo.__code__.co_names)
        ser.dumps(foo)
Code example #5
File: types.py Project: chenc10/Spark-PAF-INFOCOM18
 def jsonValue(self):
     if self.scalaUDT():
         assert self.module() != '__main__', 'UDT in __main__ cannot work with ScalaUDT'
         schema = {
             "type": "udt",
             "class": self.scalaUDT(),
             "pyClass": "%s.%s" % (self.module(), type(self).__name__),
             "sqlType": self.sqlType().jsonValue()
         }
     else:
         ser = CloudPickleSerializer()
         b = ser.dumps(type(self))
         schema = {
             "type": "udt",
             "pyClass": "%s.%s" % (self.module(), type(self).__name__),
             "serializedClass": base64.b64encode(b).decode('utf8'),
             "sqlType": self.sqlType().jsonValue()
         }
     return schema
Code example #6
 def tf_dataset(cls, func, total_size, bigdl_type="float"):
     """
     :param func: a function that returns a TensorFlow dataset
     :param total_size: total size of this dataset
     :param bigdl_type: numeric type
     :return: A feature set
     """
     func = CloudPickleSerializer.dumps(CloudPickleSerializer, func)
     jvalue = callZooFunc(bigdl_type, "createFeatureSetFromTfDataset", func, total_size)
     return cls(jvalue=jvalue)
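Note on the calling convention in this and the later BigDL / Analytics Zoo examples:
CloudPickleSerializer.dumps(CloudPickleSerializer, func) calls the unbound method with the class
itself standing in for self. PySpark's dumps does not read instance state, so this behaves the
same as the ordinary instance call sketched below (a minimal equivalence sketch, not taken from
the project source; func stands for whatever callable is being serialized):

    from pyspark.serializers import CloudPickleSerializer

    func_bytes = CloudPickleSerializer().dumps(func)  # equivalent to the class-as-self form above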
Code example #7
File: types.py Project: CatKyo/SparkInfoSystem
 def jsonValue(self):
     if self.scalaUDT():
         assert self.module() != "__main__", "UDT in __main__ cannot work with ScalaUDT"
         schema = {
             "type": "udt",
             "class": self.scalaUDT(),
             "pyClass": "%s.%s" % (self.module(), type(self).__name__),
             "sqlType": self.sqlType().jsonValue(),
         }
     else:
         ser = CloudPickleSerializer()
         b = ser.dumps(type(self))
         schema = {
             "type": "udt",
             "pyClass": "%s.%s" % (self.module(), type(self).__name__),
             "serializedClass": base64.b64encode(b).decode("utf8"),
             "sqlType": self.sqlType().jsonValue(),
         }
     return schema
Code example #8
File: types.py Project: sgaviner/spark-1
 def fromJson(cls, json):
     pyUDT = str(json["pyClass"])  # convert unicode to str
     split = pyUDT.rfind(".")
     pyModule = pyUDT[:split]
     pyClass = pyUDT[split + 1:]
     m = __import__(pyModule, globals(), locals(), [pyClass])
     if not hasattr(m, pyClass):
         s = base64.b64decode(json['serializedClass'].encode('utf-8'))
         UDT = CloudPickleSerializer().loads(s)
     else:
         UDT = getattr(m, pyClass)
     return UDT()
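Examples #5, #7 and #8 together form a round trip: jsonValue() pickles the UDT class with
CloudPickleSerializer and base64-encodes it into the schema, and fromJson() decodes and loads()
it back when the class cannot be imported by name. A minimal standalone sketch of that round
trip (hypothetical MyUDT class, not taken from the Spark source):

    import base64
    from pyspark.serializers import CloudPickleSerializer

    class MyUDT(object):          # hypothetical stand-in for a user-defined type class
        name = "my_udt"

    ser = CloudPickleSerializer()
    encoded = base64.b64encode(ser.dumps(MyUDT)).decode('utf8')      # what jsonValue() stores
    restored = ser.loads(base64.b64decode(encoded.encode('utf-8')))  # what fromJson() does
    assert restored().name == "my_udt"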
Code example #9
    def registerFunction(self, name, f, returnType="string"):
        def func(split, iterator):
            return imap(f, iterator)

        command = (func, self._sc.serializer, self._sc.serializer)
        env = MapConverter().convert(self._sc.environment,
                                     self._sc._gateway._gateway_client)
        includes = ListConverter().convert(self._sc._python_includes,
                                           self._sc._gateway._gateway_client)
        self._ssql_ctx.registerPython(
            name, bytearray(CloudPickleSerializer().dumps(command)), env,
            includes, self._sc.pythonExec, self._sc._javaAccumulator,
            returnType)
Code example #10
 def to_pytorch(self):
     """
     Convert to pytorch model
     :return: a pytorch model
     """
     new_weight = self.get_weights()
     assert len(new_weight) == 1, \
         "TorchModel's weights should be one tensor"
     m = CloudPickleSerializer.loads(CloudPickleSerializer,
                                     self.module_bytes)
     w = torch.Tensor(new_weight[0])
     torch.nn.utils.vector_to_parameters(w, m.parameters())
     return m
Code example #11
    def pytorch_dataloader(cls,
                           dataloader,
                           features="_data[0]",
                           labels="_data[1]",
                           bigdl_type="float"):
        """
        Create FeatureSet from pytorch dataloader
        :param dataloader: a PyTorch DataLoader, or a function that returns a PyTorch DataLoader.
        :param features: features in _data, where _data is fetched from the dataloader.
        :param labels: labels in _data, where _data is fetched from the dataloader.
        :param bigdl_type: numeric type
        :return: A feature set
        """
        import torch
        if isinstance(dataloader, torch.utils.data.DataLoader):
            node_num, core_num = get_node_and_core_number()
            if dataloader.batch_size % node_num != 0:
                true_bs = math.ceil(
                    dataloader.batch_size / node_num) * node_num
                warning_msg = "Detect dataloader's batch_size is not divisible by node number(" + \
                              str(node_num) + "), will adjust batch_size to " + str(true_bs) + \
                              " automatically"
                warnings.warn(warning_msg)

            bys = CloudPickleSerializer.dumps(CloudPickleSerializer,
                                              dataloader)
            jvalue = callZooFunc(bigdl_type, "createFeatureSetFromPyTorch",
                                 bys, False, features, labels)
            return cls(jvalue=jvalue)
        elif callable(dataloader):
            bys = CloudPickleSerializer.dumps(CloudPickleSerializer,
                                              dataloader)
            jvalue = callZooFunc(bigdl_type, "createFeatureSetFromPyTorch",
                                 bys, True, features, labels)
            return cls(jvalue=jvalue)
        else:
            raise ValueError(
                "Unsupported dataloader type, please pass pytorch dataloader" +
                " or a function to create pytorch dataloader.")
Code example #12
 def from_pytorch(model):
     """
      Create a TorchNet directly from a PyTorch model, e.g. a model in torchvision.models.
     :param model: a PyTorch model
     """
     weights = []
     for param in trainable_param(model):
         weights.append(param.view(-1))
     flatten_weight = torch.nn.utils.parameters_to_vector(
         weights).data.numpy()
     bys = CloudPickleSerializer.dumps(CloudPickleSerializer, model)
     net = TorchModel(bys, flatten_weight)
     return net
Code example #13
def tensorflowGridSearch():
    """
        Wrap tensorflow so that it can be used in sklearn GridSearchCV
    :return:
    """
    dataX, dataY = getMnist()
    dataX = dataX.astype(np.float32)
    dataY = np.argmax(dataY, axis=1).astype(np.int32)

    tuned_parameters = [{'lr': [1e-1, 1e-2], 'iters': [10, 20]}]
    scores = ['precision', 'recall']

    model = DisLRModel(400, 10, 0.01, 10)
    clf = GridSearchCV(model,
                       param_grid=tuned_parameters,
                       cv=2,
                       scoring='%s_macro' % "precision")
    clf.fit(dataX, dataY)

    # test whether the model could be serialized
    cp = CloudPickleSerializer()
    cp.dumps(model)
Code example #14
File: vertex.py Project: calhank/reddiculous
    def getJavaVertexRDD(self, rdd, rdd_deserializer):
        if self.bypass_serializer:
            self.jvertex_rdd_deserializer = NoOpSerializer()
            rdd_deserializer = NoOpSerializer()
        # enable_profile = self.ctx._conf.get("spark.python.profile", "false") == "true"
        # profileStats = self.ctx.accumulator(None, PStatsParam) if enable_profile else None
        def f(index, iterator):
            return iterator
        command = (f, rdd_deserializer, rdd_deserializer)
        # the serialized command will be compressed by broadcast
        ser = CloudPickleSerializer()
        pickled_command = ser.dumps(command)
        if len(pickled_command) > (1 << 20):  # 1M
            self.broadcast = self.ctx.broadcast(pickled_command)
            pickled_command = ser.dumps(self.broadcast)

        # the serialized command will be compressed by broadcast
        broadcast_vars = ListConverter().convert(
            [x._jbroadcast for x in self.ctx._pickled_broadcast_vars],
            self.ctx._gateway._gateway_client)
        self.ctx._pickled_broadcast_vars.clear()
        env = MapConverter().convert(self.ctx.environment,
                                     self.ctx._gateway._gateway_client)
        includes = ListConverter().convert(self.ctx._python_includes,
                                           self.ctx._gateway._gateway_client)
        target_storage_level = StorageLevel.MEMORY_ONLY
        java_storage_level = self.ctx._getJavaStorageLevel(target_storage_level)
        prdd = self.ctx._jvm.PythonVertexRDD(rdd._jrdd,
                                                   bytearray(pickled_command),
                                                   env, includes, self.preserve_partitioning,
                                                   self.ctx.pythonExec,
                                                   broadcast_vars, self.ctx._javaAccumulator,
                                                   java_storage_level)
        self.jvertex_rdd = prdd.asJavaVertexRDD()
        # if enable_profile:
        #     self.id = self.jvertex_rdd.id()
        #     self.ctx._add_profile(self.id, profileStats)
        return self.jvertex_rdd
Code example #15
    def getJavaVertexRDD(self, rdd, rdd_deserializer):
        if self.bypass_serializer:
            self.jvertex_rdd_deserializer = NoOpSerializer()
            rdd_deserializer = NoOpSerializer()
        # enable_profile = self.ctx._conf.get("spark.python.profile", "false") == "true"
        # profileStats = self.ctx.accumulator(None, PStatsParam) if enable_profile else None
        def f(index, iterator):
            return iterator

        command = (f, rdd_deserializer, rdd_deserializer)
        # the serialized command will be compressed by broadcast
        ser = CloudPickleSerializer()
        pickled_command = ser.dumps(command)
        if len(pickled_command) > (1 << 20):  # 1M
            self.broadcast = self.ctx.broadcast(pickled_command)
            pickled_command = ser.dumps(self.broadcast)

        # the serialized command will be compressed by broadcast
        broadcast_vars = ListConverter().convert(
            [x._jbroadcast for x in self.ctx._pickled_broadcast_vars],
            self.ctx._gateway._gateway_client)
        self.ctx._pickled_broadcast_vars.clear()
        env = MapConverter().convert(self.ctx.environment,
                                     self.ctx._gateway._gateway_client)
        includes = ListConverter().convert(self.ctx._python_includes,
                                           self.ctx._gateway._gateway_client)
        target_storage_level = StorageLevel.MEMORY_ONLY
        java_storage_level = self.ctx._getJavaStorageLevel(
            target_storage_level)
        prdd = self.ctx._jvm.PythonVertexRDD(
            rdd._jrdd, bytearray(pickled_command), env, includes,
            self.preserve_partitioning, self.ctx.pythonExec, broadcast_vars,
            self.ctx._javaAccumulator, java_storage_level)
        self.jvertex_rdd = prdd.asJavaVertexRDD()
        # if enable_profile:
        #     self.id = self.jvertex_rdd.id()
        #     self.ctx._add_profile(self.id, profileStats)
        return self.jvertex_rdd
Code example #16
    def test_attrgetter(self):
        from operator import attrgetter
        ser = CloudPickleSerializer()

        class C(object):
            def __getattr__(self, item):
                return item

        d = C()
        getter = attrgetter("a")
        getter2 = ser.loads(ser.dumps(getter))
        self.assertEqual(getter(d), getter2(d))
        getter = attrgetter("a", "b")
        getter2 = ser.loads(ser.dumps(getter))
        self.assertEqual(getter(d), getter2(d))

        d.e = C()
        getter = attrgetter("e.a")
        getter2 = ser.loads(ser.dumps(getter))
        self.assertEqual(getter(d), getter2(d))
        getter = attrgetter("e.a", "e.b")
        getter2 = ser.loads(ser.dumps(getter))
        self.assertEqual(getter(d), getter2(d))
Code example #17
    def jedge_rdd(self):
        if self.jerdd_val:
            return self.jerdd_val
        if self.bypass_serializer:
            self.jedge_rdd_deserializer = NoOpSerializer()
        enable_profile = self.ctx._conf.get("spark.python.profile", "false") == "true"
        profileStats = self.ctx.accumulator(None, PStatsParam) if enable_profile else None
        command = (self.func, profileStats, self.prev_jedge_rdd_deserializer,
                   self.jedge_rdd_deserializer)
        # the serialized command will be compressed by broadcast
        ser = CloudPickleSerializer()
        pickled_command = ser.dumps(command)
        if len(pickled_command) > (1 << 20):  # 1M
            self.broadcast = self.ctx.broadcast(pickled_command)
            pickled_command = ser.dumps(self.broadcast)
        broadcast_vars = ListConverter().convert(
            [x._jbroadcast for x in self.ctx._pickled_broadcast_vars],
            self.ctx._gateway._gateway_client)
        self.ctx._pickled_broadcast_vars.clear()
        env = MapConverter().convert(self.ctx.environment,
                                     self.ctx._gateway._gateway_client)
        includes = ListConverter().convert(self.ctx._python_includes,
                                           self.ctx._gateway._gateway_client)
        java_storage_level = self.ctx._getJavaStorageLevel(StorageLevel.MEMORY_ONLY)
        python_rdd = self.ctx._jvm.PythonEdgeRDD(self.prev_jedge_rdd,
                                                   bytearray(pickled_command),
                                                   env, includes, self.preservesPartitioning,
                                                   self.ctx.pythonExec,
                                                   broadcast_vars, self.ctx._javaAccumulator,
                                                   java_storage_level)
        self.jerdd_val = python_rdd.asJavaEdgeRDD()

        if enable_profile:
            self.id = self.jerdd_val.id()
            self.ctx._add_profile(self.id, profileStats)
        return self.jerdd_val
Code example #18
File: context.py Project: skambha/spark-ri
    def _ensure_initialized(cls):
        SparkContext._ensure_initialized()
        gw = SparkContext._gateway

        java_import(gw.jvm, "org.apache.spark.streaming.*")
        java_import(gw.jvm, "org.apache.spark.streaming.api.java.*")
        java_import(gw.jvm, "org.apache.spark.streaming.api.python.*")

        from pyspark.java_gateway import ensure_callback_server_started
        ensure_callback_server_started(gw)

        # register serializer for TransformFunction
        # it happens before creating SparkContext when loading from checkpointing
        cls._transformerSerializer = TransformFunctionSerializer(
            SparkContext._active_spark_context, CloudPickleSerializer(), gw)
Code example #19
File: test_serializers.py Project: Brett-A/spark
    def test_attrgetter(self):
        from operator import attrgetter
        ser = CloudPickleSerializer()

        class C(object):
            def __getattr__(self, item):
                return item
        d = C()
        getter = attrgetter("a")
        getter2 = ser.loads(ser.dumps(getter))
        self.assertEqual(getter(d), getter2(d))
        getter = attrgetter("a", "b")
        getter2 = ser.loads(ser.dumps(getter))
        self.assertEqual(getter(d), getter2(d))

        d.e = C()
        getter = attrgetter("e.a")
        getter2 = ser.loads(ser.dumps(getter))
        self.assertEqual(getter(d), getter2(d))
        getter = attrgetter("e.a", "e.b")
        getter2 = ser.loads(ser.dumps(getter))
        self.assertEqual(getter(d), getter2(d))
Code example #20
    def test_itemgetter(self):
        from operator import itemgetter
        ser = CloudPickleSerializer()
        d = range(10)
        getter = itemgetter(1)
        getter2 = ser.loads(ser.dumps(getter))
        self.assertEqual(getter(d), getter2(d))

        getter = itemgetter(0, 3)
        getter2 = ser.loads(ser.dumps(getter))
        self.assertEqual(getter(d), getter2(d))
Code example #21
File: common.py Project: yanwei-ji/analytics-zoo
    def pytorch_dataloader(cls, dataloader, bigdl_type="float"):
        """
        Create FeatureSet from pytorch dataloader
        :param dataloader: a pytorch dataloader
        :param bigdl_type: numeric type
        :return: A feature set
        """
        node_num, core_num = get_node_and_core_number()
        if dataloader.batch_size % node_num != 0:
            true_bs = math.ceil(dataloader.batch_size / node_num) * node_num
            warning_msg = "Detect dataloader's batch_size is not divisible by node number(" + \
                          node_num + "), will adjust batch_size to " + true_bs + " automatically"
            warnings.warn(warning_msg)

        bys = CloudPickleSerializer.dumps(CloudPickleSerializer, dataloader)
        jvalue = callZooFunc(bigdl_type, "createFeatureSetFromPyTorch", bys)
        return cls(jvalue=jvalue)
Code example #22
File: rdd.py Project: xoltar/spark
 def _jrdd(self):
     if self._jrdd_val:
         return self._jrdd_val
     if self._bypass_serializer:
         serializer = NoOpSerializer()
     else:
         serializer = self.ctx.serializer
     command = (self.func, self._prev_jrdd_deserializer, serializer)
     pickled_command = CloudPickleSerializer().dumps(command)
     broadcast_vars = ListConverter().convert(
         [x._jbroadcast for x in self.ctx._pickled_broadcast_vars],
         self.ctx._gateway._gateway_client)
     self.ctx._pickled_broadcast_vars.clear()
     class_tag = self._prev_jrdd.classTag()
     env = MapConverter().convert(self.ctx.environment,
                                  self.ctx._gateway._gateway_client)
     includes = ListConverter().convert(self.ctx._python_includes,
                                  self.ctx._gateway._gateway_client)
     python_rdd = self.ctx._jvm.PythonRDD(self._prev_jrdd.rdd(),
         bytearray(pickled_command), env, includes, self.preservesPartitioning,
         self.ctx.pythonExec, broadcast_vars, self.ctx._javaAccumulator,
         class_tag)
     self._jrdd_val = python_rdd.asJavaRDD()
     return self._jrdd_val
Code example #23
File: test_rdd.py Project: zhining-lu/spark
 def do_pickle(f, sc):
     command = (f, None, sc.serializer, sc.serializer)
     ser = CloudPickleSerializer()
     ser.dumps(command)
Code example #24
File: python_fun.py Project: zhuan77241/streamingpro
 def pickle_command(command):
     # the serialized command will be compressed by broadcast
     ser = CloudPickleSerializer()
     pickled_command = ser.dumps(command)
     return pickled_command
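A short usage sketch for the helper above (assumed usage, not part of the streamingpro source):
the bytes returned by pickle_command round-trip through the same serializer.

    from pyspark.serializers import CloudPickleSerializer

    def add_one(x):
        return x + 1

    payload = pickle_command((add_one, None, None))       # serialize a command tuple
    func, _, _ = CloudPickleSerializer().loads(payload)   # restore it on the worker side
    assert func(41) == 42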
Code example #25
File: test_rdd.py Project: apache/spark
 def do_pickle(f, sc):
     command = (f, None, sc.serializer, sc.serializer)
     ser = CloudPickleSerializer()
     ser.dumps(command)
Code example #26
             np.int32]).astype(np.int32)

    elif type(elem_type) == LongType:
        result = result.select_dtypes(
            [np.byte, np.ubyte, np.short, np.ushort, np.int, np.long])

    elif type(elem_type) == FloatType:
        result = result.select_dtypes(include=(np.number, )).astype(np.float32)

    elif type(elem_type) == DoubleType:
        result = result.select_dtypes(include=(np.number, )).astype(np.float64)

    if len(result.columns) == 0:
        raise MlflowException(
            message=
            "The the model did not produce any values compatible with the requested "
            "type '{}'. Consider requesting udf with StringType or "
            "Arraytype(StringType).".format(str(elem_type)),
            error_code=INVALID_PARAMETER_VALUE)

    if type(elem_type) == StringType:
        result = result.applymap(str)

    if type(return_type) == ArrayType:
        return pandas.Series([row[1].values for row in result.iterrows()])
    else:
        return result[result.columns[0]]


f.write(CloudPickleSerializer().dumps((predict, return_type)))
Code example #27
 def test_function_module_name(self):
     ser = CloudPickleSerializer()
     func = lambda x: x
     func2 = ser.loads(ser.dumps(func))
     self.assertEqual(func.__module__, func2.__module__)
Code example #28
 def from_pytorch(criterion):
     bys = CloudPickleSerializer.dumps(CloudPickleSerializer, criterion)
     net = TorchLoss(bys)
     return net
Code example #29
def main(args=None):
    config = load_configurations(args)
    sc = SparkContext('local[*]', 'app', serializer=CloudPickleSerializer())
    MapCreationProgram(config=config, sc=sc)
Code example #30
File: test_serializers.py Project: Brett-A/spark
 def test_function_module_name(self):
     ser = CloudPickleSerializer()
     func = lambda x: x
     func2 = ser.loads(ser.dumps(func))
     self.assertEqual(func.__module__, func2.__module__)
Code example #31
logger = logging.getLogger(__name__)


class SAP(AccumulatorParam):
    def zero(self, initialValue):
        s = set()
        s.add(initialValue)
        return s

    def addInPlace(self, v1, v2):

        return v1.union(v2)


if __name__ == '__main__':
    sc = SparkContext('local[1]', 'app', serializer=CloudPickleSerializer())
    spark = SparkSession(sc)
    from pyspark.sql import Row

    l = [('Ankit', 25), ('Jalfaizy', 22), ('saurabh', 20), ('Bala', 26)]
    rdd = sc.parallelize(l)
    people = rdd.map(lambda x: Row(name=x[0], age=int(x[1])))
    schemaPeople = spark.createDataFrame(people)
    schemaPeople = schemaPeople.repartition(20)
    schema = StructType([
        StructField('name', StringType()),
        StructField('age', IntegerType()),
    ])
    schemaPeople.show()

    def foo(pd):
Code example #32
    def _create_judf(self, name):
        func = self.func
        from pyspark.sql import SQLContext
        sc = SparkContext.getOrCreate()
        # Empty strings allow the Scala code to recognize no data and skip adding the Jython
        # code to handle vars or imports if not needed.
        serialized_vars = ""
        serialized_imports = ""
        if isinstance(func, string_types):
            src = func
        else:
            try:
                import dill
            except ImportError:
                raise ImportError(
                    "Failed to import dill, magic Jython function serialization "
                    +
                    "depends on dill on the driver machine. You may wish to pass "
                    + "your function in as a string instead.")
            try:
                src = dill.source.getsource(func)
            except:
                print(
                    "Failed to get the source code associated with provided function. "
                    +
                    "You may wish to try and assign you lambda to a variable or pass in as a "
                    + "string.")
                raise
            # Extract the globals, classes, etc. needed for this function
            file = StringIO()
            cp = cloudpickle.CloudPickler(file)
            code, f_globals, defaults, closure, dct, base_globals = cp.extract_func_data(
                func)
            closure_dct = {}
            if func.__closure__:
                if sys.version < "3":
                    closure_dct = dict(
                        zip(func.func_code.co_freevars,
                            (c.cell_contents for c in func.func_closure)))
                else:
                    closure_dct = dict(
                        zip(func.__code__.co_freevars,
                            (c.cell_contents for c in func.__closure__)))

            req = dict(base_globals)
            req.update(f_globals)
            req.update(closure_dct)
            # Serialize the "extras" and drop PySpark imports
            ser = CloudPickleSerializer()

            def isClass(v):
                return isinstance(v, (type, types.ClassType))

            def isInternal(v):
                return v.__module__.startswith("pyspark")

            # Sort out PySpark and non PySpark requirements
            req_vars = dict((k, v) for k, v in req.items()
                            if not isClass(v) or not isInternal(v))
            req_imports = dict(
                (k, v) for k, v in req.items() if isClass(v) and isInternal(v))
            if req_vars:
                serialized_vars = b64encode(
                    ser.dumps(req_vars)).decode("utf-8")
            if req_imports:
                formatted_imports = list((v.__module__, v.__name__, k)
                                         for k, v in req_imports.items())
                serialized_imports = b64encode(
                    ser.dumps(formatted_imports)).decode("utf-8")

        from pyspark.sql import SparkSession
        spark = SparkSession.builder.getOrCreate()
        jdt = spark._jsparkSession.parseDataType(self.returnType.json())
        if name is None:
            f = self.func
            name = f.__name__ if hasattr(f,
                                         '__name__') else f.__class__.__name__
        # Create a Java representation
        wrapped_jython_func = _wrap_jython_func(sc, src, serialized_vars,
                                                serialized_imports,
                                                self.setupCode)
        judf = sc._jvm.org.apache.spark.sql.jython.UserDefinedJythonFunction(
            name, wrapped_jython_func, jdt)
        return judf