def _load_from_socket(port, auth_secret, function, all_gather_message=None):
    """
    Send a barrier() or all_gather() request over a local socket and return
    the reply.

    This is a blocking call: the socket timeout is disabled, so it only
    returns once the other end has produced a result.

    :param port: port handed to ``local_connect_and_auth``.
    :param auth_secret: secret handed to ``local_connect_and_auth``.
    :param function: ``BARRIER_FUNCTION`` or ``ALL_GATHER_FUNCTION``.
    :param all_gather_message: string sent (UTF-8 encoded, length-prefixed)
        with an all_gather() request; not used for barrier().
    :return: the value read back from the socket with ``UTF8Deserializer``.
    :raises ValueError: if ``function`` is neither of the two known codes.
    """
    (sockfile, sock) = local_connect_and_auth(port, auth_secret)
    try:
        # The call may block forever, so no timeout
        sock.settimeout(None)

        if function == BARRIER_FUNCTION:
            # Make a barrier() function call.
            write_int(function, sockfile)
        elif function == ALL_GATHER_FUNCTION:
            # Make a all_gather() function call.
            write_int(function, sockfile)
            write_with_length(all_gather_message.encode("utf-8"), sockfile)
        else:
            raise ValueError("Unrecognized function type")
        sockfile.flush()

        # Collect result.
        return UTF8Deserializer().loads(sockfile)
    finally:
        # Release resources on every path, including a failed request or an
        # unrecognized function code, so the connection is never leaked.
        try:
            sockfile.close()
        finally:
            sock.close()
def _serialize_to_jvm(self, data, parallelism, serializer):
    """
    Using py4j to send a large dataset to the jvm is really slow, so we use
    either a file or a socket if we have encryption enabled.

    :param data: the data handed to ``serializer.dump_stream``.
    :param parallelism: forwarded to ``PythonParallelizeServer`` (encryption
        enabled) or to ``PythonRDD.readRDDFromFile`` (encryption disabled).
    :param serializer: object whose ``dump_stream(data, stream)`` writes the
        serialized data to the given stream.
    :return: whatever ``server.getResult()`` or ``readRDDFromFile`` returns.
    """
    if self._encryption_enabled:
        # with encryption, we open a server in java and send the data directly
        server = self._jvm.PythonParallelizeServer(self._jsc.sc(), parallelism)
        (sock_file, _) = local_connect_and_auth(server.port(), server.secret())
        chunked_out = ChunkedStream(sock_file, 8192)
        serializer.dump_stream(data, chunked_out)
        chunked_out.close()
        # this call will block until the server has read all the data and processed it (or
        # throws an exception)
        return server.getResult()
    else:
        # without encryption, we serialize to a file, and we read the file in java and
        # parallelize from there.
        tempFile = NamedTemporaryFile(delete=False, dir=self._temp_dir)
        try:
            try:
                serializer.dump_stream(data, tempFile)
            finally:
                # close even when serialization fails, so the handle is not
                # left open while the file is being removed below
                tempFile.close()
            readRDDFromFile = self._jvm.PythonRDD.readRDDFromFile
            return readRDDFromFile(self._jsc, tempFile.name, parallelism)
        finally:
            # we eagerly read the file so we can delete right after.
            os.unlink(tempFile.name)
def _serialize_to_jvm(self, data, serializer, reader_func, createRDDServer):
    """
    Using py4j to send a large dataset to the jvm is really slow, so we use
    either a file or a socket if we have encryption enabled.

    :param data: the data handed to ``serializer.dump_stream``.
    :param serializer: object whose ``dump_stream(data, stream)`` writes the
        serialized data to the given stream.
    :param reader_func: A function which takes a filename and reads in the
        data in the jvm and returns a JavaRDD. Only used when encryption is
        disabled.
    :param createRDDServer: A function which creates a PythonRDDServer in the
        jvm to accept the serialized data, for use when encryption is enabled.
    :return: the result of ``server.getResult()`` (encryption enabled) or of
        ``reader_func`` (encryption disabled).
    """
    if self._encryption_enabled:
        # with encryption, we open a server in java and send the data directly
        server = createRDDServer()
        (sock_file, _) = local_connect_and_auth(server.port(), server.secret())
        chunked_out = ChunkedStream(sock_file, 8192)
        serializer.dump_stream(data, chunked_out)
        chunked_out.close()
        # this call will block until the server has read all the data and processed it (or
        # throws an exception)
        r = server.getResult()
        return r
    else:
        # without encryption, we serialize to a file, and we read the file in java and
        # parallelize from there.
        tempFile = NamedTemporaryFile(delete=False, dir=self._temp_dir)
        try:
            try:
                serializer.dump_stream(data, tempFile)
            finally:
                # close even when serialization fails, so the handle is not
                # left open while the file is being removed below
                tempFile.close()
            return reader_func(tempFile.name)
        finally:
            # we eagerly read the file so we can delete right after.
            os.unlink(tempFile.name)
def value(self):
    """Return the broadcasted value.

    An already-loaded ``_value`` is returned as is. Otherwise, when a path
    is set, the value is either read through a decryption server (driver
    with encryption enabled; that result is returned without being stored)
    or loaded from the path and stored in ``_value``.
    """
    if hasattr(self, "_value") or self._path is None:
        return self._value

    # we only need to decrypt it here when encryption is enabled and
    # if its on the driver, since executor decryption is handled already
    decrypt_on_driver = self._sc is not None and self._sc._encryption_enabled
    if not decrypt_on_driver:
        self._value = self.load_from_path(self._path)
        return self._value

    port, auth_secret = self._python_broadcast.setupDecryptionServer()
    (decrypted_sock_file, _) = local_connect_and_auth(port, auth_secret)
    self._python_broadcast.waitTillBroadcastDataSent()
    return self.load(decrypted_sock_file)
def __init__(
    self,
    sc: Optional["SparkContext"] = None,
    value: Optional[T] = None,
    pickle_registry: Optional["BroadcastPickleRegistry"] = None,
    path: Optional[str] = None,
    sock_file: Optional[BinaryIO] = None,
):
    """
    Should not be called directly by users -- use :meth:`SparkContext.broadcast`
    instead.

    With ``sc`` given (driver side) the value is dumped to a temporary file,
    or sent to a JVM encryption server when encryption is enabled, and then
    registered through ``sc._jsc.broadcast``. Without ``sc`` (executor side)
    the value is either loaded immediately from ``sock_file`` or ``path`` is
    remembered for later.
    """
    if sc is None:
        # we're on an executor
        self._jbroadcast = None
        self._sc: Optional["SparkContext"] = None
        self._python_broadcast = None
        if sock_file is None:
            # the jvm just dumps the pickled data in path -- we'll unpickle lazily when
            # the value is requested
            assert path is not None
            self._path = path
        else:
            # the jvm is doing decryption for us.  Read the value
            # immediately from the sock_file
            self._value = self.load(sock_file)
        return

    # we're on the driver.  We want the pickled data to end up in a file (maybe encrypted)
    spill_file = NamedTemporaryFile(delete=False, dir=sc._temp_dir)
    self._path = spill_file.name
    self._sc = sc
    assert sc._jvm is not None
    self._python_broadcast = sc._jvm.PythonRDD.setupBroadcast(self._path)

    broadcast_out: Union[ChunkedStream, IO[bytes]]
    if not sc._encryption_enabled:
        # no encryption, we can just write pickled data directly to the file from python
        broadcast_out = spill_file
    else:
        # with encryption, we ask the jvm to do the encryption for us, we send it data
        # over a socket
        port, auth_secret = self._python_broadcast.setupEncryptionServer()
        (encryption_sock_file, _) = local_connect_and_auth(port, auth_secret)
        broadcast_out = ChunkedStream(encryption_sock_file, 8192)

    self.dump(value, broadcast_out)  # type: ignore[arg-type]
    if sc._encryption_enabled:
        self._python_broadcast.waitTillDataReceived()
    self._jbroadcast = sc._jsc.broadcast(self._python_broadcast)
    self._pickle_registry = pickle_registry
def value(self):
    """
    Return the broadcasted value.

    A value that is already loaded is returned directly. Otherwise, if a
    path is set, the value is read through a decryption server (driver with
    encryption enabled; not stored on the instance) or loaded from the path
    and kept in ``_value``.
    """
    needs_load = not hasattr(self, "_value") and self._path is not None
    if not needs_load:
        return self._value

    # we only need to decrypt it here when encryption is enabled and
    # if its on the driver, since executor decryption is handled already
    if self._sc is None or not self._sc._encryption_enabled:
        self._value = self.load_from_path(self._path)
        return self._value

    port, auth_secret = self._python_broadcast.setupDecryptionServer()
    (decrypted_sock_file, _) = local_connect_and_auth(port, auth_secret)
    self._python_broadcast.waitTillBroadcastDataSent()
    return self.load(decrypted_sock_file)
def _load_from_socket(port, auth_secret):
    """
    Make a barrier() call over a local socket and return the reply.

    This is a blocking call: the socket timeout is disabled, so it only
    returns once the other end has produced a result.

    :param port: port handed to ``local_connect_and_auth``.
    :param auth_secret: secret handed to ``local_connect_and_auth``.
    :return: the value read back from the socket with ``UTF8Deserializer``.
    """
    (sockfile, sock) = local_connect_and_auth(port, auth_secret)
    try:
        # The barrier() call may block forever, so no timeout
        sock.settimeout(None)
        # Make a barrier() function call.
        write_int(BARRIER_FUNCTION, sockfile)
        sockfile.flush()

        # Collect result.
        return UTF8Deserializer().loads(sockfile)
    finally:
        # Release resources even if the request or the read raised, so the
        # connection is never leaked.
        try:
            sockfile.close()
        finally:
            sock.close()
def _load_from_socket(port, auth_secret):
    """
    Make a barrier() call over a local socket and return the reply.

    This is a blocking call: the socket timeout is disabled, so it only
    returns once the other end has produced a result.

    :param port: port handed to ``local_connect_and_auth``.
    :param auth_secret: secret handed to ``local_connect_and_auth``.
    :return: the value read back from the socket with ``UTF8Deserializer``.
    """
    (sockfile, sock) = local_connect_and_auth(port, auth_secret)
    try:
        # The barrier() call may block forever, so no timeout
        sock.settimeout(None)
        # Make a barrier() function call.
        write_int(BARRIER_FUNCTION, sockfile)
        sockfile.flush()

        # Collect result.
        return UTF8Deserializer().loads(sockfile)
    finally:
        # Release resources even if the request or the read raised, so the
        # connection is never leaked.
        try:
            sockfile.close()
        finally:
            sock.close()
def __init__(self, sc=None, value=None, pickle_registry=None, path=None, sock_file=None):
    """
    Should not be called directly by users -- use L{SparkContext.broadcast()}
    instead.

    With ``sc`` given (driver side) the value is dumped to a temporary file,
    or sent to a JVM encryption server when encryption is enabled, and then
    registered through ``sc._jsc.broadcast``. Without ``sc`` (executor side)
    the value is either loaded immediately from ``sock_file`` or ``path`` is
    remembered for later.
    """
    if sc is None:
        # we're on an executor
        self._jbroadcast = None
        self._sc = None
        self._python_broadcast = None
        if sock_file is None:
            # the jvm just dumps the pickled data in path -- we'll unpickle lazily when
            # the value is requested
            assert path is not None
            self._path = path
        else:
            # the jvm is doing decryption for us.  Read the value
            # immediately from the sock_file
            self._value = self.load(sock_file)
        return

    # we're on the driver.  We want the pickled data to end up in a file (maybe encrypted)
    spill_file = NamedTemporaryFile(delete=False, dir=sc._temp_dir)
    self._path = spill_file.name
    self._sc = sc
    self._python_broadcast = sc._jvm.PythonRDD.setupBroadcast(self._path)

    if not sc._encryption_enabled:
        # no encryption, we can just write pickled data directly to the file from python
        broadcast_out = spill_file
    else:
        # with encryption, we ask the jvm to do the encryption for us, we send it data
        # over a socket
        port, auth_secret = self._python_broadcast.setupEncryptionServer()
        (encryption_sock_file, _) = local_connect_and_auth(port, auth_secret)
        broadcast_out = ChunkedStream(encryption_sock_file, 8192)

    self.dump(value, broadcast_out)
    if sc._encryption_enabled:
        self._python_broadcast.waitTillDataReceived()
    self._jbroadcast = sc._jsc.broadcast(self._python_broadcast)
    self._pickle_registry = pickle_registry
def _load_from_socket(
    port: Optional[Union[str, int]],
    auth_secret: str,
    function: int,
    all_gather_message: Optional[str] = None,
) -> List[str]:
    """
    Send a barrier() or all_gather() request over a local socket and return
    the replies.

    This is a blocking call: the socket timeout is disabled, so it only
    returns once the other end has produced a result. The reply is read as
    an integer count followed by that many strings.

    :param port: port handed to ``local_connect_and_auth``.
    :param auth_secret: secret handed to ``local_connect_and_auth``.
    :param function: ``BARRIER_FUNCTION`` or ``ALL_GATHER_FUNCTION``.
    :param all_gather_message: string sent (UTF-8 encoded, length-prefixed)
        with an all_gather() request; not used for barrier().
    :return: the strings read back from the socket, in the order received.
    :raises ValueError: if ``function`` is neither of the two known codes.
    """
    (sockfile, sock) = local_connect_and_auth(port, auth_secret)
    try:
        # The call may block forever, so no timeout
        sock.settimeout(None)

        if function == BARRIER_FUNCTION:
            # Make a barrier() function call.
            write_int(function, sockfile)
        elif function == ALL_GATHER_FUNCTION:
            # Make a all_gather() function call.
            write_int(function, sockfile)
            write_with_length(cast(str, all_gather_message).encode("utf-8"), sockfile)
        else:
            raise ValueError("Unrecognized function type")
        sockfile.flush()

        # Collect result. The count is kept in its own name rather than
        # shadowing the builtin ``len``.
        count = read_int(sockfile)
        return [UTF8Deserializer().loads(sockfile) for _ in range(count)]
    finally:
        # Release resources on every path, including a failed request or an
        # unrecognized function code, so the connection is never leaked.
        try:
            sockfile.close()
        finally:
            sock.close()
def _serialize_to_jvm(self, data, serializer, reader_func, createRDDServer):
    """
    Hand ``data`` to the jvm without going through py4j, which is really slow
    for large datasets: through a temporary file normally, or through a
    socket when encryption is enabled.

    :param data: the data handed to ``serializer.dump_stream``.
    :param serializer: object whose ``dump_stream(data, stream)`` writes the
        serialized data to the given stream.
    :param reader_func: A function which takes a filename and reads in the
        data in the jvm and returns a JavaRDD. Only used when encryption is
        disabled.
    :param createRDDServer: A function which creates a PythonRDDServer in the
        jvm to accept the serialized data, for use when encryption is enabled.
    :return: the result of ``reader_func`` (encryption disabled) or of the
        server's ``getResult()`` (encryption enabled).
    """
    if not self._encryption_enabled:
        # without encryption, we serialize to a file, and we read the file in java and
        # parallelize from there.
        spill_file = NamedTemporaryFile(delete=False, dir=self._temp_dir)
        try:
            try:
                serializer.dump_stream(data, spill_file)
            finally:
                spill_file.close()
            return reader_func(spill_file.name)
        finally:
            # the file is read eagerly, so it can be deleted right after.
            os.unlink(spill_file.name)

    # with encryption, we open a server in java and send the data directly
    rdd_server = createRDDServer()
    (sock_file, _) = local_connect_and_auth(rdd_server.port(), rdd_server.secret())
    chunked_out = ChunkedStream(sock_file, 8192)
    serializer.dump_stream(data, chunked_out)
    chunked_out.close()
    # this call will block until the server has read all the data and processed it (or
    # throws an exception)
    return rdd_server.getResult()
def main(infile, outfile):
    """
    Run one task: read its description from ``infile``, execute it, and write
    the results followed by timing, spill and accumulator data to ``outfile``.

    The reads from ``infile`` happen in a fixed order (split index, Python
    version, barrier info, task context fields, local properties, files dir,
    python includes, broadcast variables, eval type, command) and must not be
    reordered.

    :param infile: binary stream the task description and input are read from.
    :param outfile: binary stream results and bookkeeping data are written to.

    Exits the process with status -1 when the split index is -1, when the
    task raises, or when the final marker read from ``infile`` is not
    ``SpecialLengths.END_OF_STREAM``.
    """
    try:
        boot_time = time.time()
        split_index = read_int(infile)
        if split_index == -1:  # for unit tests
            sys.exit(-1)

        version = utf8_deserializer.loads(infile)
        if version != "%d.%d" % sys.version_info[:2]:
            # Note the trailing space after "versions." so the two sentences
            # do not run together in the rendered message.
            raise Exception(("Python in worker has different version %s than that in " +
                             "driver %s, PySpark cannot run with different minor versions. " +
                             "Please check environment variables PYSPARK_PYTHON and " +
                             "PYSPARK_DRIVER_PYTHON are correctly set.") %
                            ("%d.%d" % sys.version_info[:2], version))

        # read inputs only for a barrier task
        isBarrier = read_bool(infile)
        boundPort = read_int(infile)
        secret = UTF8Deserializer().loads(infile)

        # set up memory limits
        memory_limit_mb = int(os.environ.get('PYSPARK_EXECUTOR_MEMORY_MB', "-1"))
        if memory_limit_mb > 0 and has_resource_module:
            total_memory = resource.RLIMIT_AS
            try:
                (soft_limit, hard_limit) = resource.getrlimit(total_memory)
                msg = "Current mem limits: {0} of max {1}\n".format(soft_limit, hard_limit)
                print(msg, file=sys.stderr)

                # convert to bytes
                new_limit = memory_limit_mb * 1024 * 1024

                if soft_limit == resource.RLIM_INFINITY or new_limit < soft_limit:
                    msg = "Setting mem limits to {0} of max {1}\n".format(new_limit, new_limit)
                    print(msg, file=sys.stderr)
                    resource.setrlimit(total_memory, (new_limit, new_limit))

            except (resource.error, OSError, ValueError) as e:
                # not all systems support resource limits, so warn instead of failing
                print("WARN: Failed to set memory limit: {0}\n".format(e), file=sys.stderr)

        # initialize global state
        taskContext = None
        if isBarrier:
            taskContext = BarrierTaskContext._getOrCreate()
            BarrierTaskContext._initialize(boundPort, secret)
        else:
            taskContext = TaskContext._getOrCreate()
        # read inputs for TaskContext info
        taskContext._stageId = read_int(infile)
        taskContext._partitionId = read_int(infile)
        taskContext._attemptNumber = read_int(infile)
        taskContext._taskAttemptId = read_long(infile)
        taskContext._localProperties = dict()
        for i in range(read_int(infile)):
            k = utf8_deserializer.loads(infile)
            v = utf8_deserializer.loads(infile)
            taskContext._localProperties[k] = v

        shuffle.MemoryBytesSpilled = 0
        shuffle.DiskBytesSpilled = 0
        _accumulatorRegistry.clear()

        # fetch name of workdir
        spark_files_dir = utf8_deserializer.loads(infile)
        SparkFiles._root_directory = spark_files_dir
        SparkFiles._is_running_on_worker = True

        # fetch names of includes (*.zip and *.egg files) and construct PYTHONPATH
        add_path(spark_files_dir)  # *.py files that were added will be copied here
        num_python_includes = read_int(infile)
        for _ in range(num_python_includes):
            filename = utf8_deserializer.loads(infile)
            add_path(os.path.join(spark_files_dir, filename))
        # compare the numeric major version instead of the version string
        if sys.version_info[0] >= 3:
            import importlib
            importlib.invalidate_caches()

        # fetch names and values of broadcast variables
        needs_broadcast_decryption_server = read_bool(infile)
        num_broadcast_variables = read_int(infile)
        if needs_broadcast_decryption_server:
            # read the decrypted data from a server in the jvm
            port = read_int(infile)
            auth_secret = utf8_deserializer.loads(infile)
            (broadcast_sock_file, _) = local_connect_and_auth(port, auth_secret)

        for _ in range(num_broadcast_variables):
            bid = read_long(infile)
            if bid >= 0:
                if needs_broadcast_decryption_server:
                    read_bid = read_long(broadcast_sock_file)
                    assert(read_bid == bid)
                    _broadcastRegistry[bid] = \
                        Broadcast(sock_file=broadcast_sock_file)
                else:
                    path = utf8_deserializer.loads(infile)
                    _broadcastRegistry[bid] = Broadcast(path=path)

            else:
                # a negative id encodes removal of broadcast (-bid - 1)
                bid = - bid - 1
                _broadcastRegistry.pop(bid)

        if needs_broadcast_decryption_server:
            broadcast_sock_file.write(b'1')
            broadcast_sock_file.close()

        _accumulatorRegistry.clear()
        eval_type = read_int(infile)
        if eval_type == PythonEvalType.NON_UDF:
            func, profiler, deserializer, serializer = read_command(pickleSer, infile)
        else:
            func, profiler, deserializer, serializer = read_udfs(pickleSer, infile, eval_type)

        init_time = time.time()

        def process():
            iterator = deserializer.load_stream(infile)
            serializer.dump_stream(func(split_index, iterator), outfile)

        if profiler:
            profiler.profile(process)
        else:
            process()
    except Exception:
        try:
            write_int(SpecialLengths.PYTHON_EXCEPTION_THROWN, outfile)
            write_with_length(traceback.format_exc().encode("utf-8"), outfile)
        except IOError:
            # JVM close the socket
            pass
        except Exception:
            # Write the error to stderr if it happened while serializing
            print("PySpark worker failed with exception:", file=sys.stderr)
            print(traceback.format_exc(), file=sys.stderr)
        sys.exit(-1)
    finish_time = time.time()
    report_times(outfile, boot_time, init_time, finish_time)
    write_long(shuffle.MemoryBytesSpilled, outfile)
    write_long(shuffle.DiskBytesSpilled, outfile)

    # Mark the beginning of the accumulators section of the output
    write_int(SpecialLengths.END_OF_DATA_SECTION, outfile)
    write_int(len(_accumulatorRegistry), outfile)
    for (aid, accum) in _accumulatorRegistry.items():
        pickleSer._write_with_length((aid, accum._value), outfile)

    # check end of stream
    if read_int(infile) == SpecialLengths.END_OF_STREAM:
        write_int(SpecialLengths.END_OF_STREAM, outfile)
    else:
        # write a different value to tell JVM to not reuse this worker
        write_int(SpecialLengths.END_OF_DATA_SECTION, outfile)
        sys.exit(-1)
# Write the error to stderr if it happened while serializing print("PySpark worker failed with exception:", file=sys.stderr) print(traceback.format_exc(), file=sys.stderr) sys.exit(-1) finish_time = time.time() report_times(outfile, boot_time, init_time, finish_time) write_long(shuffle.MemoryBytesSpilled, outfile) write_long(shuffle.DiskBytesSpilled, outfile) # Mark the beginning of the accumulators section of the output write_int(SpecialLengths.END_OF_DATA_SECTION, outfile) write_int(len(_accumulatorRegistry), outfile) for (aid, accum) in _accumulatorRegistry.items(): pickleSer._write_with_length((aid, accum._value), outfile) # check end of stream if read_int(infile) == SpecialLengths.END_OF_STREAM: write_int(SpecialLengths.END_OF_STREAM, outfile) else: # write a different value to tell JVM to not reuse this worker write_int(SpecialLengths.END_OF_DATA_SECTION, outfile) sys.exit(-1) if __name__ == '__main__': # Read information about how to connect back to the JVM from the environment. java_port = int(os.environ["PYTHON_WORKER_FACTORY_PORT"]) auth_secret = os.environ["PYTHON_WORKER_FACTORY_SECRET"] (sock_file, _) = local_connect_and_auth(java_port, auth_secret) main(sock_file, sock_file)
def main(infile, outfile):
    """
    Run one task: read its description from ``infile``, execute it, and write
    the results followed by timing, spill and accumulator data to ``outfile``.

    The reads from ``infile`` happen in a fixed order (split index, Python
    version, barrier info, task context fields, resources, local properties,
    files dir, python includes, broadcast variables, eval type, command) and
    must not be reordered.

    When the ``PYTHON_FAULTHANDLER_DIR`` environment variable is set,
    ``faulthandler`` output goes to a per-pid file in that directory for the
    duration of the task; the file is removed afterwards.

    :param infile: binary stream the task description and input are read from.
    :param outfile: binary stream results and bookkeeping data are written to.

    Exits the process with status -1 when the split index is -1, when the
    task raises, or when the final marker read from ``infile`` is not
    ``SpecialLengths.END_OF_STREAM``.
    """
    faulthandler_log_path = os.environ.get("PYTHON_FAULTHANDLER_DIR", None)
    try:
        if faulthandler_log_path:
            faulthandler_log_path = os.path.join(faulthandler_log_path, str(os.getpid()))
            faulthandler_log_file = open(faulthandler_log_path, "w")
            faulthandler.enable(file=faulthandler_log_file)

        boot_time = time.time()
        split_index = read_int(infile)
        if split_index == -1:  # for unit tests
            sys.exit(-1)

        version = utf8_deserializer.loads(infile)
        if version != "%d.%d" % sys.version_info[:2]:
            raise RuntimeError(
                (
                    "Python in worker has different version %s than that in "
                    + "driver %s, PySpark cannot run with different minor versions. "
                    + "Please check environment variables PYSPARK_PYTHON and "
                    + "PYSPARK_DRIVER_PYTHON are correctly set."
                )
                % ("%d.%d" % sys.version_info[:2], version)
            )

        # read inputs only for a barrier task
        isBarrier = read_bool(infile)
        boundPort = read_int(infile)
        secret = UTF8Deserializer().loads(infile)

        # set up memory limits
        memory_limit_mb = int(os.environ.get("PYSPARK_EXECUTOR_MEMORY_MB", "-1"))
        if memory_limit_mb > 0 and has_resource_module:
            total_memory = resource.RLIMIT_AS
            try:
                (soft_limit, hard_limit) = resource.getrlimit(total_memory)
                msg = "Current mem limits: {0} of max {1}\n".format(soft_limit, hard_limit)
                print(msg, file=sys.stderr)

                # convert to bytes
                new_limit = memory_limit_mb * 1024 * 1024

                if soft_limit == resource.RLIM_INFINITY or new_limit < soft_limit:
                    msg = "Setting mem limits to {0} of max {1}\n".format(new_limit, new_limit)
                    print(msg, file=sys.stderr)
                    resource.setrlimit(total_memory, (new_limit, new_limit))

            except (resource.error, OSError, ValueError) as e:
                # not all systems support resource limits, so warn instead of failing
                lineno = (
                    getframeinfo(currentframe()).lineno + 1 if currentframe() is not None else 0
                )
                print(
                    warnings.formatwarning(
                        "Failed to set memory limit: {0}".format(e),
                        ResourceWarning,
                        __file__,
                        lineno,
                    ),
                    file=sys.stderr,
                )

        # initialize global state
        taskContext = None
        if isBarrier:
            taskContext = BarrierTaskContext._getOrCreate()
            BarrierTaskContext._initialize(boundPort, secret)
            # Set the task context instance here, so we can get it by TaskContext.get for
            # both TaskContext and BarrierTaskContext
            TaskContext._setTaskContext(taskContext)
        else:
            taskContext = TaskContext._getOrCreate()
        # read inputs for TaskContext info
        taskContext._stageId = read_int(infile)
        taskContext._partitionId = read_int(infile)
        taskContext._attemptNumber = read_int(infile)
        taskContext._taskAttemptId = read_long(infile)
        taskContext._cpus = read_int(infile)
        # The mapping is created once, before the loop. Resetting it inside
        # the loop would drop every resource except the last one read.
        taskContext._resources = {}
        for r in range(read_int(infile)):
            key = utf8_deserializer.loads(infile)
            name = utf8_deserializer.loads(infile)
            addresses = [utf8_deserializer.loads(infile) for _ in range(read_int(infile))]
            taskContext._resources[key] = ResourceInformation(name, addresses)

        taskContext._localProperties = dict()
        for i in range(read_int(infile)):
            k = utf8_deserializer.loads(infile)
            v = utf8_deserializer.loads(infile)
            taskContext._localProperties[k] = v

        shuffle.MemoryBytesSpilled = 0
        shuffle.DiskBytesSpilled = 0
        _accumulatorRegistry.clear()

        # fetch name of workdir
        spark_files_dir = utf8_deserializer.loads(infile)
        SparkFiles._root_directory = spark_files_dir
        SparkFiles._is_running_on_worker = True

        # fetch names of includes (*.zip and *.egg files) and construct PYTHONPATH
        add_path(spark_files_dir)  # *.py files that were added will be copied here
        num_python_includes = read_int(infile)
        for _ in range(num_python_includes):
            filename = utf8_deserializer.loads(infile)
            add_path(os.path.join(spark_files_dir, filename))

        importlib.invalidate_caches()

        # fetch names and values of broadcast variables
        needs_broadcast_decryption_server = read_bool(infile)
        num_broadcast_variables = read_int(infile)
        if needs_broadcast_decryption_server:
            # read the decrypted data from a server in the jvm
            port = read_int(infile)
            auth_secret = utf8_deserializer.loads(infile)
            (broadcast_sock_file, _) = local_connect_and_auth(port, auth_secret)

        for _ in range(num_broadcast_variables):
            bid = read_long(infile)
            if bid >= 0:
                if needs_broadcast_decryption_server:
                    read_bid = read_long(broadcast_sock_file)
                    assert read_bid == bid
                    _broadcastRegistry[bid] = Broadcast(sock_file=broadcast_sock_file)
                else:
                    path = utf8_deserializer.loads(infile)
                    _broadcastRegistry[bid] = Broadcast(path=path)

            else:
                # a negative id encodes removal of broadcast (-bid - 1)
                bid = -bid - 1
                _broadcastRegistry.pop(bid)

        if needs_broadcast_decryption_server:
            broadcast_sock_file.write(b"1")
            broadcast_sock_file.close()

        _accumulatorRegistry.clear()
        eval_type = read_int(infile)
        if eval_type == PythonEvalType.NON_UDF:
            func, profiler, deserializer, serializer = read_command(pickleSer, infile)
        else:
            func, profiler, deserializer, serializer = read_udfs(pickleSer, infile, eval_type)

        init_time = time.time()

        def process():
            iterator = deserializer.load_stream(infile)
            out_iter = func(split_index, iterator)
            try:
                serializer.dump_stream(out_iter, outfile)
            finally:
                # close the output iterator (when it supports it) even if
                # writing the results failed
                if hasattr(out_iter, "close"):
                    out_iter.close()

        if profiler:
            profiler.profile(process)
        else:
            process()

        # Reset task context to None. This is a guard code to avoid residual context when worker
        # reuse.
        TaskContext._setTaskContext(None)
        BarrierTaskContext._setTaskContext(None)
    except BaseException as e:
        try:
            exc_info = None
            if os.environ.get("SPARK_SIMPLIFIED_TRACEBACK", False):
                tb = try_simplify_traceback(sys.exc_info()[-1])
                if tb is not None:
                    e.__cause__ = None
                    exc_info = "".join(traceback.format_exception(type(e), e, tb))
            if exc_info is None:
                exc_info = traceback.format_exc()

            write_int(SpecialLengths.PYTHON_EXCEPTION_THROWN, outfile)
            write_with_length(exc_info.encode("utf-8"), outfile)
        except IOError:
            # JVM close the socket
            pass
        except BaseException:
            # Write the error to stderr if it happened while serializing
            print("PySpark worker failed with exception:", file=sys.stderr)
            print(traceback.format_exc(), file=sys.stderr)
        sys.exit(-1)
    finally:
        if faulthandler_log_path:
            faulthandler.disable()
            faulthandler_log_file.close()
            os.remove(faulthandler_log_path)
    finish_time = time.time()
    report_times(outfile, boot_time, init_time, finish_time)
    write_long(shuffle.MemoryBytesSpilled, outfile)
    write_long(shuffle.DiskBytesSpilled, outfile)

    # Mark the beginning of the accumulators section of the output
    write_int(SpecialLengths.END_OF_DATA_SECTION, outfile)
    write_int(len(_accumulatorRegistry), outfile)
    for (aid, accum) in _accumulatorRegistry.items():
        pickleSer._write_with_length((aid, accum._value), outfile)

    # check end of stream
    if read_int(infile) == SpecialLengths.END_OF_STREAM:
        write_int(SpecialLengths.END_OF_STREAM, outfile)
    else:
        # write a different value to tell JVM to not reuse this worker
        write_int(SpecialLengths.END_OF_DATA_SECTION, outfile)
        sys.exit(-1)
def main(infile, outfile):
    """
    Run one task: read its description from ``infile``, execute it, and write
    the results followed by timing, spill and accumulator data to ``outfile``.

    The reads from ``infile`` happen in a fixed order (split index, Python
    version, task context fields, files dir, python includes, broadcast
    variables, UDF flag, command) and must not be reordered.

    :param infile: binary stream the task description and input are read from.
    :param outfile: binary stream results and bookkeeping data are written to.

    Exits the process with status -1 when the split index is -1, when the
    task raises, or when the final marker read from ``infile`` is not
    ``SpecialLengths.END_OF_STREAM``.
    """
    try:
        boot_time = time.time()
        split_index = read_int(infile)
        if split_index == -1:  # for unit tests
            # sys.exit rather than the bare exit(): the latter is installed
            # by the site module and is not guaranteed to exist.
            sys.exit(-1)

        version = utf8_deserializer.loads(infile)
        if version != "%d.%d" % sys.version_info[:2]:
            # Note the trailing space after "versions." so the two sentences
            # do not run together in the rendered message.
            raise Exception(
                ("Python in worker has different version %s than that in " +
                 "driver %s, PySpark cannot run with different minor versions. " +
                 "Please check environment variables PYSPARK_PYTHON and " +
                 "PYSPARK_DRIVER_PYTHON are correctly set.") %
                ("%d.%d" % sys.version_info[:2], version))

        # initialize global state
        taskContext = TaskContext._getOrCreate()
        taskContext._stageId = read_int(infile)
        taskContext._partitionId = read_int(infile)
        taskContext._attemptNumber = read_int(infile)
        taskContext._taskAttemptId = read_long(infile)
        shuffle.MemoryBytesSpilled = 0
        shuffle.DiskBytesSpilled = 0
        _accumulatorRegistry.clear()

        # fetch name of workdir
        spark_files_dir = utf8_deserializer.loads(infile)
        SparkFiles._root_directory = spark_files_dir
        SparkFiles._is_running_on_worker = True

        # fetch names of includes (*.zip and *.egg files) and construct PYTHONPATH
        add_path(spark_files_dir)  # *.py files that were added will be copied here
        num_python_includes = read_int(infile)
        for _ in range(num_python_includes):
            filename = utf8_deserializer.loads(infile)
            add_path(os.path.join(spark_files_dir, filename))
        # compare the numeric major version instead of the version string
        if sys.version_info[0] >= 3:
            import importlib
            importlib.invalidate_caches()

        # fetch names and values of broadcast variables
        needs_broadcast_decryption_server = read_bool(infile)
        num_broadcast_variables = read_int(infile)
        if needs_broadcast_decryption_server:
            # read the decrypted data from a server in the jvm
            port = read_int(infile)
            auth_secret = utf8_deserializer.loads(infile)
            (broadcast_sock_file, _) = local_connect_and_auth(port, auth_secret)

        for _ in range(num_broadcast_variables):
            bid = read_long(infile)
            if bid >= 0:
                if needs_broadcast_decryption_server:
                    read_bid = read_long(broadcast_sock_file)
                    assert (read_bid == bid)
                    _broadcastRegistry[bid] = \
                        Broadcast(sock_file=broadcast_sock_file)
                else:
                    path = utf8_deserializer.loads(infile)
                    _broadcastRegistry[bid] = Broadcast(path=path)

            else:
                # a negative id encodes removal of broadcast (-bid - 1)
                bid = -bid - 1
                _broadcastRegistry.pop(bid)

        if needs_broadcast_decryption_server:
            broadcast_sock_file.write(b'1')
            broadcast_sock_file.close()

        _accumulatorRegistry.clear()
        is_sql_udf = read_int(infile)
        if is_sql_udf:
            func, profiler, deserializer, serializer = read_udfs(pickleSer, infile)
        else:
            func, profiler, deserializer, serializer = read_command(pickleSer, infile)

        init_time = time.time()

        def process():
            iterator = deserializer.load_stream(infile)
            serializer.dump_stream(func(split_index, iterator), outfile)

        if profiler:
            profiler.profile(process)
        else:
            process()
    except Exception:
        try:
            write_int(SpecialLengths.PYTHON_EXCEPTION_THROWN, outfile)
            write_with_length(traceback.format_exc().encode("utf-8"), outfile)
        except IOError:
            # JVM close the socket
            pass
        except Exception:
            # Write the error to stderr if it happened while serializing
            print("PySpark worker failed with exception:", file=sys.stderr)
            print(traceback.format_exc(), file=sys.stderr)
        sys.exit(-1)
    finish_time = time.time()
    report_times(outfile, boot_time, init_time, finish_time)
    write_long(shuffle.MemoryBytesSpilled, outfile)
    write_long(shuffle.DiskBytesSpilled, outfile)

    # Mark the beginning of the accumulators section of the output
    write_int(SpecialLengths.END_OF_DATA_SECTION, outfile)
    write_int(len(_accumulatorRegistry), outfile)
    for (aid, accum) in _accumulatorRegistry.items():
        pickleSer._write_with_length((aid, accum._value), outfile)

    # check end of stream
    if read_int(infile) == SpecialLengths.END_OF_STREAM:
        write_int(SpecialLengths.END_OF_STREAM, outfile)
    else:
        # write a different value to tell JVM to not reuse this worker
        write_int(SpecialLengths.END_OF_DATA_SECTION, outfile)
        sys.exit(-1)