def _ensure_initialized(cls):
    """Initialize the py4j gateway for Spark Streaming.

    Imports the Java-side streaming classes, starts the Python callback
    server (once), and registers the TransformFunction serializer exactly
    once.  May run before a SparkContext exists (e.g. when restoring a
    StreamingContext from a checkpoint), so it only relies on the gateway.

    NOTE(review): statement order matters here — the callback server must
    be started and its real port pushed to the JVM before any serializer
    registration; do not reorder.
    """
    SparkContext._ensure_initialized()
    gw = SparkContext._gateway

    java_import(gw.jvm, "org.apache.spark.streaming.*")
    java_import(gw.jvm, "org.apache.spark.streaming.api.*")
    java_import(gw.jvm, "org.apache.spark.streaming.api.java.*")
    java_import(gw.jvm, "org.apache.spark.streaming.api.python.*")

    # start callback server
    # getattr will fallback to JVM, so we cannot test by hasattr()
    if "_callback_server" not in gw.__dict__ or gw._callback_server is None:
        gw.callback_server_parameters.eager_load = True
        gw.callback_server_parameters.daemonize = True
        gw.callback_server_parameters.daemonize_connections = True
        # port 0 asks the OS for an ephemeral port; the real one is read back below
        gw.callback_server_parameters.port = 0
        gw.start_callback_server(gw.callback_server_parameters)
        cbport = gw._callback_server.server_socket.getsockname()[1]
        gw._callback_server.port = cbport
        # gateway with real port
        gw._python_proxy_port = gw._callback_server.port
        # get the GatewayServer object in JVM by ID
        jgws = JavaObject("GATEWAY_SERVER", gw._gateway_client)
        # update the port of CallbackClient with real port
        gw.jvm.PythonDStream.updatePythonGatewayPort(jgws, gw._python_proxy_port)
        _py4j_cleaner = Py4jCallbackConnectionCleaner(gw)
        _py4j_cleaner.start()

    # register serializer for TransformFunction
    # it happens before creating SparkContext when loading from checkpointing
    if cls._transformerSerializer is None:
        transformer_serializer = TransformFunctionSerializer()
        transformer_serializer.init(
            SparkContext._active_spark_context, CloudPickleSerializer(), gw)
        # SPARK-12511 streaming driver with checkpointing unable to finalize
        # leading to OOM.
        # There is an issue that Py4J's PythonProxyHandler.finalize blocks forever.
        # (https://github.com/bartdag/py4j/pull/184)
        #
        # Py4j will create a PythonProxyHandler in Java for "transformer_serializer"
        # when calling "registerSerializer". If we call "registerSerializer" twice,
        # the second PythonProxyHandler will override the first one, then the first
        # one will be GCed and trigger "PythonProxyHandler.finalize". To avoid that,
        # we should not call "registerSerializer" more than once, so that
        # "PythonProxyHandler" in Java side won't be GCed.
        #
        # TODO Once Py4J fixes this issue, we should upgrade Py4j to the latest
        # version.
        transformer_serializer.gateway.jvm.PythonDStream.registerSerializer(
            transformer_serializer)
        cls._transformerSerializer = transformer_serializer
    else:
        # Already registered once: just refresh its context/gateway references.
        cls._transformerSerializer.init(
            SparkContext._active_spark_context, CloudPickleSerializer(), gw)
def jvertex_rdd(self):
    """Lazily build and cache the Java PythonVertexRDD backing this RDD.

    Serializes the Python command (function + serializers) with cloudpickle,
    broadcasting it when larger than 1 MB, then constructs the JVM-side
    PythonVertexRDD and caches the resulting Java vertex RDD handle.

    :return: the cached JavaVertexRDD object.
    """
    if self.jvrdd_val:
        return self.jvrdd_val
    if self.bypass_serializer:
        self.jvertex_rdd_deserializer = NoOpSerializer()
    # Fix: these two definitions were commented out, but both names are
    # referenced below (in `command` and in the profiling block), which
    # raised NameError at runtime.
    enable_profile = self.ctx._conf.get("spark.python.profile", "false") == "true"
    profileStats = self.ctx.accumulator(None, PStatsParam) if enable_profile else None
    command = (self.func, profileStats, self.prev_jvertex_rdd_deserializer,
               self.jvertex_rdd_deserializer)
    # the serialized command will be compressed by broadcast
    ser = CloudPickleSerializer()
    pickled_command = ser.dumps(command)
    if len(pickled_command) > (1 << 20):  # 1M
        self.broadcast = self.ctx.broadcast(pickled_command)
        pickled_command = ser.dumps(self.broadcast)
    broadcast_vars = ListConverter().convert(
        [x._jbroadcast for x in self.ctx._pickled_broadcast_vars],
        self.ctx._gateway._gateway_client)
    self.ctx._pickled_broadcast_vars.clear()
    env = MapConverter().convert(self.ctx.environment,
                                 self.ctx._gateway._gateway_client)
    includes = ListConverter().convert(self.ctx._python_includes,
                                       self.ctx._gateway._gateway_client)
    java_storage_level = self.ctx._getJavaStorageLevel(StorageLevel.MEMORY_ONLY)
    python_rdd = self.ctx._jvm.PythonVertexRDD(
        self.prev_jvertex_rdd, bytearray(pickled_command), env, includes,
        self.preservesPartitioning, self.ctx.pythonExec, broadcast_vars,
        self.ctx._javaAccumulator, java_storage_level)
    self.jvrdd_val = python_rdd.asJavaVertexRDD()
    if enable_profile:
        self.id = self.jvrdd_val.id()
        self.ctx._add_profile(self.id, profileStats)
    return self.jvrdd_val
def _ensure_initialized(cls):
    """Initialize the py4j gateway for Spark Streaming.

    Imports the Java-side streaming classes, starts the Python callback
    server with an OS-assigned port (pushing the real port to the JVM),
    and registers the TransformFunction serializer.  May run before a
    SparkContext is created when restoring from a checkpoint.

    NOTE(review): the callback server must be up and its port propagated
    before the serializer registration below; do not reorder.
    """
    SparkContext._ensure_initialized()
    gw = SparkContext._gateway

    java_import(gw.jvm, "org.apache.spark.streaming.*")
    java_import(gw.jvm, "org.apache.spark.streaming.api.java.*")
    java_import(gw.jvm, "org.apache.spark.streaming.api.python.*")

    # start callback server
    # getattr will fallback to JVM, so we cannot test by hasattr()
    if "_callback_server" not in gw.__dict__ or gw._callback_server is None:
        gw.callback_server_parameters.eager_load = True
        gw.callback_server_parameters.daemonize = True
        gw.callback_server_parameters.daemonize_connections = True
        # port 0 asks the OS for an ephemeral port; the real one is read back below
        gw.callback_server_parameters.port = 0
        gw.start_callback_server(gw.callback_server_parameters)
        cbport = gw._callback_server.server_socket.getsockname()[1]
        gw._callback_server.port = cbport
        # gateway with real port
        gw._python_proxy_port = gw._callback_server.port
        # get the GatewayServer object in JVM by ID
        jgws = JavaObject("GATEWAY_SERVER", gw._gateway_client)
        # update the port of CallbackClient with real port
        gw.jvm.PythonDStream.updatePythonGatewayPort(jgws, gw._python_proxy_port)

    # register serializer for TransformFunction
    # it happens before creating SparkContext when loading from checkpointing
    cls._transformerSerializer = TransformFunctionSerializer(
        SparkContext._active_spark_context, CloudPickleSerializer(), gw)
def _jdstream(self):
    """Lazily build and cache the Java PythonDStream backing this DStream.

    Serializes the Python command (function + serializers) with cloudpickle,
    broadcasting it when larger than 1 MB, then constructs the JVM-side
    PythonDStream and caches the resulting JavaDStream handle.

    :return: the cached JavaDStream object.
    """
    if self._jdstream_val:
        return self._jdstream_val
    if self._bypass_serializer:
        # Fix: was `self.jrdd_deserializer` (no underscore), which set an
        # unrelated attribute — the command below reads
        # `self._jrdd_deserializer`, so the bypass never took effect.
        self._jrdd_deserializer = NoOpSerializer()
    command = (self.func, self._prev_jrdd_deserializer, self._jrdd_deserializer)
    # the serialized command will be compressed by broadcast
    ser = CloudPickleSerializer()
    pickled_command = ser.dumps(command)
    # Fix: compare the payload *length*, not the bytes object itself —
    # `bytes > int` raises TypeError on Python 3 and never was a size check.
    if len(pickled_command) > (1 << 20):  # 1M
        broadcast = self.ctx.broadcast(pickled_command)
        pickled_command = ser.dumps(broadcast)
    broadcast_vars = ListConverter().convert(
        [x._jbroadcast for x in self.ctx._pickled_broadcast_vars],
        self.ctx._gateway._gateway_client)
    self.ctx._pickled_broadcast_vars.clear()
    class_tag = self._prev_jdstream.classTag()
    env = MapConverter().convert(self.ctx.environment,
                                 self.ctx._gateway._gateway_client)
    includes = ListConverter().convert(self.ctx._python_includes,
                                       self.ctx._gateway._gateway_client)
    python_dstream = self.ctx._jvm.PythonDStream(
        self._prev_jdstream.dstream(), bytearray(pickled_command), env, includes,
        self.preservesPartitioning, self.ctx.pythonExec, broadcast_vars,
        self.ctx._javaAccumulator, class_tag)
    self._jdstream_val = python_dstream.asJavaDStream()
    return self._jdstream_val
def __init__(self, master=None, appName=None, sparkHome=None, pyFiles=None,
             environment=None, batchSize=0, serializer=PickleSerializer(), conf=None,
             gateway=None, jsc=None, profiler_cls=BasicProfiler,
             function_serializer=CloudPickleSerializer()):
    """
    Create a new SparkContext. At least the master and app name should be set,
    either through the named parameters here or through C{conf}.

    :param master: Cluster URL to connect to
           (e.g. mesos://host:port, spark://host:port, local[4]).
    :param appName: A name for your job, to display on the cluster web UI.
    :param sparkHome: Location where Spark is installed on cluster nodes.
    :param pyFiles: Collection of .zip or .py files to send to the cluster
           and add to PYTHONPATH.  These can be paths on the local file
           system or HDFS, HTTP, HTTPS, or FTP URLs.
    :param environment: A dictionary of environment variables to set on
           worker nodes.
    :param batchSize: The number of Python objects represented as a single
           Java object. Set 1 to disable batching, 0 to automatically choose
           the batch size based on object sizes, or -1 to use an unlimited
           batch size
    :param serializer: The serializer for RDDs.
    :param conf: A L{SparkConf} object setting Spark properties.
    :param gateway: Use an existing gateway and JVM, otherwise a new JVM
           will be instantiated.
    :param jsc: The JavaSparkContext instance (optional).
    :param profiler_cls: A class of custom Profiler used to do profiling
           (default is pyspark.profiler.BasicProfiler).
    :param function_serializer: The serializer for functions used in RDD
           transformations.


    >>> from pyspark.context import SparkContext
    >>> sc = SparkContext('local', 'test')

    >>> sc2 = SparkContext('local', 'test2') # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
        ...
    ValueError:...
    """
    # Record where the user created this context (for error attribution).
    self._callsite = first_spark_call() or CallSite(None, None, None)
    SparkContext._ensure_initialized(self, gateway=gateway)
    try:
        self._do_init(master, appName, sparkHome, pyFiles, environment,
                      batchSize, serializer, conf, jsc, profiler_cls,
                      function_serializer)
    except BaseException:
        # Tear down the partially-constructed context so that a future
        # SparkContext can still be created, then propagate the error.
        self.stop()
        raise
def test_function_module_name(self):
    """A named nested function keeps its ``__module__`` after a
    cloudpickle dumps/loads round trip."""
    serializer = CloudPickleSerializer()

    def func(x):
        return x

    restored = serializer.loads(serializer.dumps(func))
    self.assertEqual(func.__module__, restored.__module__)
def test_pickling_file_handles(self):
    """``sys.stderr`` survives a cloudpickle round trip unchanged.

    Only exercised when ``xmlrunner`` is *not* installed.
    """
    # to be corrected with SPARK-11160
    try:
        import xmlrunner  # type: ignore[import] # noqa: F401
    except ImportError:
        serializer = CloudPickleSerializer()
        original = sys.stderr
        restored = serializer.loads(serializer.dumps(original))
        self.assertEqual(original, restored)
def test_itemgetter(self):
    """``operator.itemgetter`` instances survive cloudpickle round trips,
    for both single- and multi-index getters."""
    from operator import itemgetter
    serializer = CloudPickleSerializer()
    data = range(10)

    single = itemgetter(1)
    single_restored = serializer.loads(serializer.dumps(single))
    self.assertEqual(single(data), single_restored(data))

    multi = itemgetter(0, 3)
    multi_restored = serializer.loads(serializer.dumps(multi))
    self.assertEqual(multi(data), multi_restored(data))
def fromJson(cls, json):
    """Rebuild a UDT instance from its JSON schema representation.

    Resolves the class named by ``json["pyClass"]`` via import; when the
    class is not importable (e.g. defined in ``__main__``), falls back to
    deserializing the cloudpickled class stored in
    ``json["serializedClass"]``.

    :param json: dict with at least a ``pyClass`` key, and a
                 ``serializedClass`` key for the fallback path.
    :return: a new instance of the resolved UDT class.
    """
    qualified = str(json["pyClass"])  # convert unicode to str
    dot = qualified.rfind(".")
    module_name = qualified[:dot]
    class_name = qualified[dot + 1:]
    module = __import__(module_name, globals(), locals(), [class_name])
    if hasattr(module, class_name):
        udt_class = getattr(module, class_name)
    else:
        payload = base64.b64decode(json['serializedClass'].encode('utf-8'))
        udt_class = CloudPickleSerializer().loads(payload)
    return udt_class()
def registerFunction(self, name, f, returnType="string"):
    """Register a Python function as a SQL UDF under ``name``.

    Wraps ``f`` so it maps over each partition's iterator, cloudpickles
    the command, and hands it to the JVM-side SQL context together with
    the worker environment and include paths.

    :param name: SQL-visible name for the function.
    :param f: the Python callable to register.
    :param returnType: SQL return type name (default ``"string"``).
    """
    def func(split, iterator):
        return imap(f, iterator)

    command = (func, self._sc.serializer, self._sc.serializer)
    gateway_client = self._sc._gateway._gateway_client
    env = MapConverter().convert(self._sc.environment, gateway_client)
    includes = ListConverter().convert(self._sc._python_includes, gateway_client)
    self._ssql_ctx.registerPython(
        name,
        bytearray(CloudPickleSerializer().dumps(command)),
        env,
        includes,
        self._sc.pythonExec,
        self._sc._javaAccumulator,
        returnType)
def _ensure_initialized(cls):
    """Initialize the py4j gateway for Spark Streaming.

    Imports the Java-side streaming classes, ensures the callback server
    is running, and registers the TransformFunction serializer.  May run
    before a SparkContext exists when loading from a checkpoint.
    """
    SparkContext._ensure_initialized()
    gw = SparkContext._gateway

    for package in ("org.apache.spark.streaming.*",
                    "org.apache.spark.streaming.api.java.*",
                    "org.apache.spark.streaming.api.python.*"):
        java_import(gw.jvm, package)

    from pyspark.java_gateway import ensure_callback_server_started
    ensure_callback_server_started(gw)

    # register serializer for TransformFunction
    # it happens before creating SparkContext when loading from checkpointing
    cls._transformerSerializer = TransformFunctionSerializer(
        SparkContext._active_spark_context, CloudPickleSerializer(), gw)
def test_func_globals(self):
    """Cloudpickling a function must not choke on unrelated unpicklable
    globals: dumping the unpicklable object itself fails, but dumping a
    function that merely *names* ``exit`` succeeds."""
    class Unpicklable:
        def __reduce__(self):
            raise RuntimeError("not picklable")

    # Shadow the builtin `exit` with an object that refuses to pickle.
    global exit
    exit = Unpicklable()

    serializer = CloudPickleSerializer()
    self.assertRaises(Exception, lambda: serializer.dumps(exit))

    def foo():
        sys.exit(0)

    # `foo` references the name "exit" (via sys.exit) but must still pickle.
    self.assertTrue("exit" in foo.__code__.co_names)
    serializer.dumps(foo)
def jsonValue(self):
    """Serialize this UDT's schema to a JSON-compatible dict.

    When a Scala UDT counterpart exists, reference it by class name;
    otherwise embed the cloudpickled Python class, base64-encoded, so the
    type can be reconstructed without an importable module.

    :return: dict describing the UDT schema.
    """
    if self.scalaUDT():
        assert self.module() != '__main__', 'UDT in __main__ cannot work with ScalaUDT'
        return {
            "type": "udt",
            "class": self.scalaUDT(),
            "pyClass": "%s.%s" % (self.module(), type(self).__name__),
            "sqlType": self.sqlType().jsonValue()
        }
    serialized = CloudPickleSerializer().dumps(type(self))
    return {
        "type": "udt",
        "pyClass": "%s.%s" % (self.module(), type(self).__name__),
        "serializedClass": base64.b64encode(serialized).decode('utf8'),
        "sqlType": self.sqlType().jsonValue()
    }
def tensorflowGridSearch():
    """
    Wrap tensorflow so that it can be used in sklearn GridSearchCV.

    Loads MNIST, converts labels from one-hot to class indices, runs a
    2-fold grid search over learning rate and iteration count with
    macro-precision scoring, then checks the fitted model can be
    cloudpickled.

    :return: None
    """
    dataX, dataY = getMnist()
    dataX = dataX.astype(np.float32)
    # one-hot -> class index labels
    dataY = np.argmax(dataY, axis=1).astype(np.int32)
    tuned_parameters = [{'lr': [1e-1, 1e-2], 'iters': [10, 20]}]
    # Fix: removed the unused `scores` list; only precision is searched.
    model = DisLRModel(400, 10, 0.01, 10)
    clf = GridSearchCV(model, param_grid=tuned_parameters, cv=2,
                       scoring="precision_macro")
    clf.fit(dataX, dataY)
    # test whether the model could be serialized
    cp = CloudPickleSerializer()
    cp.dumps(model)
def test_attrgetter(self):
    """``operator.attrgetter`` instances survive cloudpickle round trips:
    single, multiple, and dotted attribute paths."""
    from operator import attrgetter
    serializer = CloudPickleSerializer()

    class C(object):
        # Echo back any requested attribute name.
        def __getattr__(self, item):
            return item

    target = C()

    def check_roundtrip(getter):
        restored = serializer.loads(serializer.dumps(getter))
        self.assertEqual(getter(target), restored(target))

    check_roundtrip(attrgetter("a"))
    check_roundtrip(attrgetter("a", "b"))
    target.e = C()
    check_roundtrip(attrgetter("e.a"))
    check_roundtrip(attrgetter("e.a", "e.b"))
def _jrdd(self):
    """Lazily build and cache the Java PythonRDD backing this RDD.

    Cloudpickles the worker command (function + serializers), converts the
    environment/includes/broadcasts for py4j, constructs the JVM-side
    PythonRDD, and caches the resulting JavaRDD handle.

    :return: the cached JavaRDD object.
    """
    if self._jrdd_val:
        return self._jrdd_val

    serializer = NoOpSerializer() if self._bypass_serializer else self.ctx.serializer
    command = (self.func, self._prev_jrdd_deserializer, serializer)
    pickled_command = CloudPickleSerializer().dumps(command)

    gateway_client = self.ctx._gateway._gateway_client
    broadcast_vars = ListConverter().convert(
        [x._jbroadcast for x in self.ctx._pickled_broadcast_vars],
        gateway_client)
    self.ctx._pickled_broadcast_vars.clear()
    class_tag = self._prev_jrdd.classTag()
    env = MapConverter().convert(self.ctx.environment, gateway_client)
    includes = ListConverter().convert(self.ctx._python_includes, gateway_client)

    python_rdd = self.ctx._jvm.PythonRDD(
        self._prev_jrdd.rdd(), bytearray(pickled_command), env, includes,
        self.preservesPartitioning, self.ctx.pythonExec, broadcast_vars,
        self.ctx._javaAccumulator, class_tag)
    self._jrdd_val = python_rdd.asJavaRDD()
    return self._jrdd_val
def main(args=None):
    """Program entry point: load configuration, start a local Spark
    context with cloudpickle serialization, and launch the map-creation
    program.

    :param args: optional command-line arguments forwarded to the
                 configuration loader.
    """
    config = load_configurations(args)
    spark_context = SparkContext('local[*]', 'app',
                                 serializer=CloudPickleSerializer())
    MapCreationProgram(config=config, sc=spark_context)
def pickle_command(command):
    """Serialize ``command`` with cloudpickle.

    The serialized bytes will later be compressed by broadcast.

    :param command: any cloudpickle-serializable object (typically a
                    worker command tuple).
    :return: the pickled bytes.
    """
    return CloudPickleSerializer().dumps(command)
np.int32]).astype(np.int32) elif type(elem_type) == LongType: result = result.select_dtypes( [np.byte, np.ubyte, np.short, np.ushort, np.int, np.long]) elif type(elem_type) == FloatType: result = result.select_dtypes(include=(np.number, )).astype(np.float32) elif type(elem_type) == DoubleType: result = result.select_dtypes(include=(np.number, )).astype(np.float64) if len(result.columns) == 0: raise MlflowException( message= "The the model did not produce any values compatible with the requested " "type '{}'. Consider requesting udf with StringType or " "Arraytype(StringType).".format(str(elem_type)), error_code=INVALID_PARAMETER_VALUE) if type(elem_type) == StringType: result = result.applymap(str) if type(return_type) == ArrayType: return pandas.Series([row[1].values for row in result.iterrows()]) else: return result[result.columns[0]] f.write(CloudPickleSerializer().dumps((predict, return_type)))
def _create_judf(self, name):
    """Build the JVM-side UserDefinedJythonFunction for this UDF.

    If ``self.func`` is a string it is used as Jython source directly;
    otherwise the source is recovered via ``dill`` and the function's
    globals/closure are split into plain variables (cloudpickled) and
    PySpark classes (recorded as imports), both base64-encoded for the
    Scala side.

    :param name: SQL-visible function name, or None to derive it from the
                 function itself.
    :return: the Java UserDefinedJythonFunction object.

    NOTE(review): relies on cloudpickle's internal
    ``CloudPickler.extract_func_data`` and on ``types.ClassType``
    (Python 2 only) — fragile across cloudpickle/Python versions; verify
    before upgrading either.
    """
    func = self.func
    from pyspark.sql import SQLContext
    sc = SparkContext.getOrCreate()
    # Empty strings allow the Scala code to recognize no data and skip adding the Jython
    # code to handle vars or imports if not needed.
    serialized_vars = ""
    serialized_imports = ""
    if isinstance(func, string_types):
        # Caller supplied raw Jython source; no serialization needed.
        src = func
    else:
        try:
            import dill
        except ImportError:
            raise ImportError(
                "Failed to import dill, magic Jython function serialization " +
                "depends on dill on the driver machine. You may wish to pass " +
                "your function in as a string instead.")
        try:
            src = dill.source.getsource(func)
        except:
            print(
                "Failed to get the source code associated with provided function. " +
                "You may wish to try and assign you lambda to a variable or pass in as a " +
                "string.")
            raise
        # Extract the globals, classes, etc. needed for this function
        file = StringIO()
        cp = cloudpickle.CloudPickler(file)
        code, f_globals, defaults, closure, dct, base_globals = cp.extract_func_data(
            func)
        # Resolve free variables to their current cell contents (py2/py3
        # use different attribute names for the code/closure objects).
        closure_dct = {}
        if func.__closure__:
            if sys.version < "3":
                closure_dct = dict(
                    zip(func.func_code.co_freevars,
                        (c.cell_contents for c in func.func_closure)))
            else:
                closure_dct = dict(
                    zip(func.__code__.co_freevars,
                        (c.cell_contents for c in func.__closure__)))
        # Merge precedence: base globals < function globals < closure.
        req = dict(base_globals)
        req.update(f_globals)
        req.update(closure_dct)
        # Serialize the "extras" and drop PySpark imports
        ser = CloudPickleSerializer()

        def isClass(v):
            return isinstance(v, (type, types.ClassType))

        def isInternal(v):
            return v.__module__.startswith("pyspark")

        # Sort out PySpark and non PySpark requirements
        req_vars = dict((k, v) for k, v in req.items()
                        if not isClass(v) or not isInternal(v))
        req_imports = dict(
            (k, v) for k, v in req.items() if isClass(v) and isInternal(v))
        if req_vars:
            serialized_vars = b64encode(
                ser.dumps(req_vars)).decode("utf-8")
        if req_imports:
            # (module, class name, binding name) triples for the Scala side.
            formatted_imports = list((v.__module__, v.__name__, k)
                                     for k, v in req_imports.items())
            serialized_imports = b64encode(
                ser.dumps(formatted_imports)).decode("utf-8")
    from pyspark.sql import SparkSession
    spark = SparkSession.builder.getOrCreate()
    # Parse the declared return type into a JVM DataType.
    jdt = spark._jsparkSession.parseDataType(self.returnType.json())
    if name is None:
        f = self.func
        name = f.__name__ if hasattr(f, '__name__') else f.__class__.__name__
    # Create a Java representation
    wrapped_jython_func = _wrap_jython_func(sc, src, serialized_vars,
                                            serialized_imports, self.setupCode)
    judf = sc._jvm.org.apache.spark.sql.jython.UserDefinedJythonFunction(
        name, wrapped_jython_func, jdt)
    return judf
logger = logging.getLogger(__name__) class SAP(AccumulatorParam): def zero(self, initialValue): s = set() s.add(initialValue) return s def addInPlace(self, v1, v2): return v1.union(v2) if __name__ == '__main__': sc = SparkContext('local[1]', 'app', serializer=CloudPickleSerializer()) spark = SparkSession(sc) from pyspark.sql import Row l = [('Ankit', 25), ('Jalfaizy', 22), ('saurabh', 20), ('Bala', 26)] rdd = sc.parallelize(l) people = rdd.map(lambda x: Row(name=x[0], age=int(x[1]))) schemaPeople = spark.createDataFrame(people) schemaPeople = schemaPeople.repartition(20) schema = StructType([ StructField('name', StringType()), StructField('age', IntegerType()), ]) schemaPeople.show() def foo(pd):
def do_pickle(f, sc):
    """Cloudpickle a worker command tuple built from ``f`` and the
    context's serializer, mirroring how PySpark serializes commands.

    :param f: the function to embed in the command.
    :param sc: a SparkContext providing the serializer.
    """
    serializer = CloudPickleSerializer()
    serializer.dumps((f, None, sc.serializer, sc.serializer))
def test_function_module_name(self):
    """A lambda keeps its ``__module__`` after a cloudpickle dumps/loads
    round trip."""
    serializer = CloudPickleSerializer()
    original = lambda x: x
    restored = serializer.loads(serializer.dumps(original))
    self.assertEqual(original.__module__, restored.__module__)