def read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index):
    """
    Read one UDF description from ``infile`` and wrap it according to ``eval_type``.

    The stream is consumed in this order: the number of argument offsets, the
    offsets themselves, the number of functions, then one command per function
    (via ``read_command``). When several functions are present they are
    composed with ``chain``.

    Returns a tuple ``(arg_offsets, wrapped_udf)``.

    Raises ``ValueError`` when ``eval_type`` is not one of the handled types.
    """
    offset_count = read_int(infile)
    arg_offsets = [read_int(infile) for _ in range(offset_count)]

    composed = None
    for _ in range(read_int(infile)):
        step, return_type = read_command(pickleSer, infile)
        composed = step if composed is None else chain(composed, step)

    # make sure StopIteration's raised in the user code are not ignored
    # when they are processed in a for loop, raise them as RuntimeError's instead
    guarded = fail_on_stopiteration(composed)

    # the return type read with the last command is the return type of the UDF
    if eval_type == PythonEvalType.SQL_SCALAR_PANDAS_UDF:
        wrapped = wrap_scalar_pandas_udf(guarded, return_type)
    elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF:
        # the signature was lost when wrapping, so inspect the unwrapped function
        argspec = _get_argspec(composed)
        wrapped = wrap_grouped_map_pandas_udf(guarded, return_type, argspec, runner_conf)
    elif eval_type == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF:
        wrapped = wrap_grouped_agg_pandas_udf(guarded, return_type)
    elif eval_type == PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF:
        wrapped = wrap_window_agg_pandas_udf(guarded, return_type, runner_conf, udf_index)
    elif eval_type == PythonEvalType.SQL_BATCHED_UDF:
        wrapped = wrap_udf(guarded, return_type)
    else:
        raise ValueError("Unknown eval type: {}".format(eval_type))
    return arg_offsets, wrapped
def read_udfs(pickleSer, infile, eval_type):
    """
    Read the UDFs for one task from ``infile`` and build the mapping function.

    For the pandas eval types a key/value runner configuration is read first
    and an ``ArrowStreamPandasSerializer`` is used; otherwise a batched pickle
    serializer is used. The UDFs are then read with ``read_single_udf`` and
    combined into one row mapper that is generated as a lambda expression.

    Returns ``(func, None, ser, ser)``; the second element is ``None`` because
    profiling is not supported for UDFs.
    """
    pandas_eval_types = (
        PythonEvalType.SQL_SCALAR_PANDAS_UDF,
        PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
        PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF,
        PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF,
    )

    runner_conf = {}
    if eval_type in pandas_eval_types:
        # Load conf used for pandas_udf evaluation
        for _ in range(read_int(infile)):
            conf_key = utf8_deserializer.loads(infile)
            runner_conf[conf_key] = utf8_deserializer.loads(infile)

        # NOTE: if timezone is set here, that implies respectSessionTimeZone is True
        ser = ArrowStreamPandasSerializer(runner_conf.get("spark.sql.session.timeZone"))
    else:
        ser = BatchedSerializer(PickleSerializer(), 100)

    num_udfs = read_int(infile)
    namespace = {}
    if eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF:
        # Create function like this:
        #   lambda a: f([a[0]], [a[0], a[1]])

        # We assume there is only one UDF here because grouped map doesn't
        # support combining multiple UDFs.
        assert num_udfs == 1

        # See FlatMapGroupsInPandasExec for how arg_offsets are used to
        # distinguish between grouping attributes and data attributes
        arg_offsets, udf = read_single_udf(
            pickleSer, infile, eval_type, runner_conf, udf_index=0)
        namespace['f'] = udf
        split_offset = arg_offsets[0] + 1
        first_args = ", ".join("a[%d]" % o for o in arg_offsets[1:split_offset])
        second_args = ", ".join("a[%d]" % o for o in arg_offsets[split_offset:])
        mapper_str = "lambda a: f([%s], [%s])" % (first_args, second_args)
    else:
        # Create function like this:
        #   lambda a: (f0(a[0]), f1(a[1], a[2]), f2(a[3]))
        # In the special case of a single UDF this will return a single result rather
        # than a tuple of results; this is the format that the JVM side expects.
        calls = []
        for index in range(num_udfs):
            arg_offsets, udf = read_single_udf(
                pickleSer, infile, eval_type, runner_conf, udf_index=index)
            namespace['f%d' % index] = udf
            call_args = ", ".join("a[%d]" % o for o in arg_offsets)
            calls.append("f%d(%s)" % (index, call_args))
        mapper_str = "lambda a: (%s)" % ", ".join(calls)

    # The generated source only contains integer offsets read from the stream;
    # the UDFs themselves are looked up by name in ``namespace``.
    mapper = eval(mapper_str, namespace)

    def func(_, it):
        return map(mapper, it)

    # profiling is not supported for UDF
    return func, None, ser, ser
def main():
    """
    Run one task using the protocol spoken over ``sys.stdin`` / ``old_stdout``.

    Reads the split index, the pickled files directory, the broadcast
    variables, the function and the bypass-serializer flag from stdin, then
    writes each result of the function to ``old_stdout``. On an exception the
    marker ``-2`` and the formatted traceback are written and the process
    exits with ``-1``. After the data, ``-1`` marks the start of the
    accumulator section.
    """
    split_index = read_int(sys.stdin)

    files_root = load_pickle(read_with_length(sys.stdin))
    SparkFiles._root_directory = files_root
    SparkFiles._is_running_on_worker = True
    sys.path.append(files_root)

    broadcast_count = read_int(sys.stdin)
    for _ in range(broadcast_count):
        bid = read_long(sys.stdin)
        payload = read_with_length(sys.stdin)
        _broadcastRegistry[bid] = Broadcast(bid, load_pickle(payload))

    func = load_obj()
    bypassSerializer = load_obj()
    # when the serializer is bypassed the results are written as-is
    dumps = (lambda x: x) if bypassSerializer else dump_pickle

    iterator = read_from_pickle_file(sys.stdin)
    try:
        for result in func(split_index, iterator):
            write_with_length(dumps(result), old_stdout)
    except Exception:
        write_int(-2, old_stdout)
        write_with_length(traceback.format_exc(), old_stdout)
        sys.exit(-1)

    # Mark the beginning of the accumulators section of the output
    write_int(-1, old_stdout)
    for aid, accum in _accumulatorRegistry.items():
        update = dump_pickle((aid, accum._value))
        write_with_length(update, old_stdout)
def main(infile, outfile):
    """
    Run one task: read its description from ``infile`` and write results to ``outfile``.

    The input stream is consumed in this order: split index (``-1`` returns
    immediately, used by unit tests), the files directory, the names of the
    Python includes, the broadcast variable updates, and finally the pickled
    ``(func, deserializer, serializer)`` command.

    A non-negative broadcast id registers a new broadcast value; a negative id
    ``b`` removes the broadcast registered under ``-b - 1``.

    If anything fails, the exception marker and traceback are written to
    ``outfile`` (or to stderr when writing itself fails for a reason other
    than ``IOError``) and the process exits with ``-1``. On success the timing
    report and the accumulator updates are written after the data.
    """
    try:
        boot_time = time.time()
        split_index = read_int(infile)
        if split_index == -1:  # for unit tests
            return

        # fetch name of workdir
        spark_files_dir = utf8_deserializer.loads(infile)
        SparkFiles._root_directory = spark_files_dir
        SparkFiles._is_running_on_worker = True

        # fetch names of includes (*.zip and *.egg files) and construct PYTHONPATH
        sys.path.append(spark_files_dir)  # *.py files that were added will be copied here
        num_python_includes = read_int(infile)
        for _ in range(num_python_includes):
            filename = utf8_deserializer.loads(infile)
            sys.path.append(os.path.join(spark_files_dir, filename))

        # fetch names and values of broadcast variables
        num_broadcast_variables = read_int(infile)
        ser = CompressedSerializer(pickleSer)
        for _ in range(num_broadcast_variables):
            bid = read_long(infile)
            if bid >= 0:
                value = ser._read_with_length(infile)
                _broadcastRegistry[bid] = Broadcast(bid, value)
            else:
                bid = - bid - 1
                # The registry is assigned by key above, so entries are
                # removed with pop(); mappings have no remove() method.
                _broadcastRegistry.pop(bid)

        _accumulatorRegistry.clear()
        command = pickleSer._read_with_length(infile)
        (func, deserializer, serializer) = command
        init_time = time.time()
        iterator = deserializer.load_stream(infile)
        serializer.dump_stream(func(split_index, iterator), outfile)
    except Exception:
        try:
            write_int(SpecialLengths.PYTHON_EXCEPTION_THROWN, outfile)
            write_with_length(traceback.format_exc(), outfile)
            outfile.flush()
        except IOError:
            # JVM close the socket
            pass
        except Exception:
            # Write the error to stderr if it happened while serializing.
            # sys.stderr.write is used so the block parses on both Python 2 and 3.
            sys.stderr.write("PySpark worker failed with exception:\n")
            sys.stderr.write(traceback.format_exc() + "\n")
        # sys.exit instead of the site-provided exit() builtin, which is not
        # guaranteed to exist (e.g. when the interpreter runs without site).
        sys.exit(-1)
    finish_time = time.time()
    report_times(outfile, boot_time, init_time, finish_time)
    # Mark the beginning of the accumulators section of the output
    write_int(SpecialLengths.END_OF_DATA_SECTION, outfile)
    write_int(len(_accumulatorRegistry), outfile)
    for (aid, accum) in _accumulatorRegistry.items():
        pickleSer._write_with_length((aid, accum._value), outfile)
def read_single_udf(pickleSer, infile):
    """
    Read one UDF description from ``infile``.

    Reads the argument offsets, then a sequence of commands whose functions
    are composed with ``chain``. Returns ``(arg_offsets, wrapped_udf)`` where
    the UDF is wrapped by ``wrap_udf`` with the return type of the last
    command read.
    """
    offset_count = read_int(infile)
    arg_offsets = [read_int(infile) for _ in range(offset_count)]

    composed = None
    function_count = read_int(infile)
    for _ in range(function_count):
        step, return_type = read_command(pickleSer, infile)
        composed = step if composed is None else chain(composed, step)

    # the last returnType will be the return type of UDF
    wrapped = wrap_udf(composed, return_type)
    return arg_offsets, wrapped
def read_single_udf(pickleSer, infile, eval_type):
    """
    Read one UDF description from ``infile`` and wrap it for ``eval_type``.

    Reads the argument offsets, then a sequence of commands whose functions
    are composed with ``chain``. The composed function is wrapped with
    ``wrap_pandas_udf`` for ``SQL_PANDAS_UDF`` and with ``wrap_udf`` otherwise.

    Returns ``(arg_offsets, wrapped_udf)``.
    """
    offset_count = read_int(infile)
    arg_offsets = [read_int(infile) for _ in range(offset_count)]

    composed = None
    function_count = read_int(infile)
    for _ in range(function_count):
        step, return_type = read_command(pickleSer, infile)
        composed = step if composed is None else chain(composed, step)

    # the last returnType will be the return type of UDF
    wrapper = wrap_pandas_udf if eval_type == PythonEvalType.SQL_PANDAS_UDF else wrap_udf
    return arg_offsets, wrapper(composed, return_type)
def do_termination_test(self, terminator):
    """
    Start ``daemon.py``, shut it down with ``terminator`` and verify it stops listening.

    ``terminator`` is called with the daemon's ``Popen`` object. The test
    passes when a connection attempt made one second after the shutdown
    request fails with ``ECONNREFUSED``.
    """
    import errno
    import subprocess

    # start daemon
    script = os.path.join(os.path.dirname(__file__), "daemon.py")
    daemon = subprocess.Popen(
        [sys.executable, script], stdin=subprocess.PIPE, stdout=subprocess.PIPE)

    # the daemon announces its port number on stdout
    port = read_int(daemon.stdout)

    # daemon should accept connections
    self.assertTrue(self.connect(port))

    # request shutdown
    terminator(daemon)
    time.sleep(1)

    # daemon should no longer accept connections
    refusal = None
    try:
        self.connect(port)
    except EnvironmentError as caught:
        refusal = caught
    if refusal is None:
        self.fail("Expected EnvironmentError to be raised")
    self.assertEqual(refusal.errno, errno.ECONNREFUSED)
def main():
    """
    Run one task using the protocol spoken over ``sys.stdin`` / ``old_stdout``.

    Reads the split index, the broadcast variables, the function and the
    bypass-serializer flag, then writes every result produced by the function
    to ``old_stdout``.
    """
    split_index = read_int(sys.stdin)

    broadcast_count = read_int(sys.stdin)
    for _ in range(broadcast_count):
        bid = read_long(sys.stdin)
        payload = read_with_length(sys.stdin)
        _broadcastRegistry[bid] = Broadcast(bid, load_pickle(payload))

    func = load_obj()
    bypassSerializer = load_obj()
    # when the serializer is bypassed the results are written as-is
    dumps = (lambda x: x) if bypassSerializer else dump_pickle

    iterator = read_from_pickle_file(sys.stdin)
    for result in func(split_index, iterator):
        write_with_length(dumps(result), old_stdout)
def read_udfs(pickleSer, infile, eval_type):
    """
    Read the UDFs for one task from ``infile`` and build the mapping function.

    The UDFs are read first; for the pandas eval types a timezone string is
    then read from the stream and an ``ArrowStreamPandasSerializer`` is used,
    otherwise a batched pickle serializer.

    Returns ``(func, None, ser, ser)``; the second element is ``None`` because
    profiling is not supported for UDFs.
    """
    namespace = {}
    calls = []
    for index in range(read_int(infile)):
        arg_offsets, udf = read_single_udf(pickleSer, infile, eval_type)
        namespace['f%d' % index] = udf
        call_args = ", ".join("a[%d]" % o for o in arg_offsets)
        calls.append("f%d(%s)" % (index, call_args))

    # Create function like this:
    #   lambda a: (f0(a0), f1(a1, a2), f2(a3))
    # In the special case of a single UDF this will return a single result rather
    # than a tuple of results; this is the format that the JVM side expects.
    mapper = eval("lambda a: (%s)" % ", ".join(calls), namespace)

    def func(_, it):
        return map(mapper, it)

    arrow_eval_types = (
        PythonEvalType.SQL_SCALAR_PANDAS_UDF,
        PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
        PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF,
    )
    if eval_type in arrow_eval_types:
        # the timezone follows the UDF definitions on the stream
        ser = ArrowStreamPandasSerializer(utf8_deserializer.loads(infile))
    else:
        ser = BatchedSerializer(PickleSerializer(), 100)

    # profiling is not supported for UDF
    return func, None, ser, ser
def main(infile, outfile):
    """
    Run one task: read its description from ``infile`` and write results to ``outfile``.

    The input stream is consumed in this order: split index (``-1`` returns
    immediately, used by unit tests), the pickled files directory, the
    broadcast variables, the pickled names of the Python includes, the
    function and the bypass-serializer flag.

    On an exception the marker ``-2`` and the formatted traceback are written
    and the process exits with ``-1``. On success the timing report is written,
    followed by ``-1``, the accumulator updates and a closing ``-1``.
    """
    boot_time = time.time()
    split_index = read_int(infile)
    if split_index == -1:  # for unit tests
        return

    # fetch name of workdir
    files_root = load_pickle(read_with_length(infile))
    SparkFiles._root_directory = files_root
    SparkFiles._is_running_on_worker = True

    # fetch names and values of broadcast variables
    broadcast_count = read_int(infile)
    for _ in range(broadcast_count):
        bid = read_long(infile)
        payload = read_with_length(infile)
        _broadcastRegistry[bid] = Broadcast(bid, load_pickle(payload))

    # fetch names of includes (*.zip and *.egg files) and construct PYTHONPATH
    sys.path.append(files_root)  # *.py files that were added will be copied here
    include_count = read_int(infile)
    for _ in range(include_count):
        include_name = load_pickle(read_with_length(infile))
        sys.path.append(os.path.join(files_root, include_name))

    # now load function
    func = load_obj(infile)
    bypassSerializer = load_obj(infile)
    # when the serializer is bypassed the results are written as-is
    dumps = (lambda x: x) if bypassSerializer else dump_pickle

    init_time = time.time()
    iterator = read_from_pickle_file(infile)
    try:
        for result in func(split_index, iterator):
            write_with_length(dumps(result), outfile)
    except Exception:
        write_int(-2, outfile)
        write_with_length(traceback.format_exc(), outfile)
        sys.exit(-1)
    finish_time = time.time()
    report_times(outfile, boot_time, init_time, finish_time)

    # Mark the beginning of the accumulators section of the output
    write_int(-1, outfile)
    for aid, accum in _accumulatorRegistry.items():
        update = dump_pickle((aid, accum._value))
        write_with_length(update, outfile)
    write_int(-1, outfile)
def _read_vec(stream):
    """
    Read a length-prefixed vector of doubles from ``stream``.

    The stream holds an integer element count (read with ``read_int``)
    followed by that many 8-byte doubles in network (big-endian) byte order.

    Returns a one-dimensional float64 ``numpy`` array of that length.

    Raises ``struct.error`` when the stream ends before all elements are read.
    """
    vector_length = read_int(stream)
    # Fetch and decode the whole payload at once instead of issuing one
    # read/unpack pair per element; struct.unpack still validates that exactly
    # ``vector_length`` doubles were received.
    payload = stream.read(8 * vector_length)
    values = struct.unpack('!%dd' % vector_length, payload)
    return np.array(values, dtype=np.float64)
def handle(self):
    """
    Apply one batch of accumulator updates read from ``self.rfile``.

    The request holds an update count followed by that many pickled
    ``(aid, update)`` pairs; each update is added in place to the matching
    entry of ``_accumulatorRegistry``. A single byte is written to
    ``self.wfile`` afterwards as acknowledgement.
    """
    from pyspark.accumulators import _accumulatorRegistry

    source = self.rfile
    update_count = read_int(source)
    for _ in range(update_count):
        aid, update = load_pickle(read_with_length(source))
        _accumulatorRegistry[aid] += update

    # Write a byte in acknowledgement
    ack = struct.pack("!b", 1)
    self.wfile.write(ack)
def accum_updates():
    """
    Apply one batch of accumulator updates read from ``self.rfile``.

    ``self`` is taken from the enclosing scope. The request holds an update
    count followed by that many ``(aid, update)`` pairs; each update is added
    in place to the matching entry of ``_accumulatorRegistry``. A single byte
    is written to ``self.wfile`` as acknowledgement.

    Always returns ``False``.
    """
    source = self.rfile
    update_count = read_int(source)
    for _ in range(update_count):
        aid, update = pickleSer._read_with_length(source)
        _accumulatorRegistry[aid] += update

    # Write a byte in acknowledgement
    ack = struct.pack("!b", 1)
    self.wfile.write(ack)
    return False
def func(event):
    """
    Decode one event into a ``(headers, body)`` tuple.

    ``event[0]`` holds the serialized headers: an integer count followed by
    that many key/value strings. ``event[1]`` is passed to ``bodyDecoder``
    (taken from the enclosing scope).
    """
    # pick the in-memory buffer type matching the running Python major version
    buffer_type = BytesIO if sys.version >= "3" else StringIO
    header_stream = buffer_type(event[0])

    decoder = UTF8Deserializer()
    headers = {}
    for _ in range(read_int(header_stream)):
        name = decoder.loads(header_stream)
        headers[name] = decoder.loads(header_stream)

    return (headers, bodyDecoder(event[1]))
def main(infile, outfile):
    """
    Run one task: read its description from ``infile`` and write results to ``outfile``.

    The input stream is consumed in this order: split index (``-1`` returns
    immediately, used by unit tests), the files directory, the broadcast
    variables, the names of the Python includes, and the pickled
    ``(func, deserializer, serializer)`` command.

    On an exception while processing, the exception marker and the formatted
    traceback are written and the process exits with ``-1``. On success the
    timing report and the accumulator updates are written after the data.
    """
    boot_time = time.time()
    split_index = read_int(infile)
    if split_index == -1:  # for unit tests
        return

    # fetch name of workdir
    files_root = mutf8_deserializer.loads(infile)
    SparkFiles._root_directory = files_root
    SparkFiles._is_running_on_worker = True

    # fetch names and values of broadcast variables
    broadcast_count = read_int(infile)
    for _ in range(broadcast_count):
        bid = read_long(infile)
        _broadcastRegistry[bid] = Broadcast(bid, pickleSer._read_with_length(infile))

    # fetch names of includes (*.zip and *.egg files) and construct PYTHONPATH
    sys.path.append(files_root)  # *.py files that were added will be copied here
    include_count = read_int(infile)
    for _ in range(include_count):
        include_name = mutf8_deserializer.loads(infile)
        sys.path.append(os.path.join(files_root, include_name))

    func, deserializer, serializer = pickleSer._read_with_length(infile)
    init_time = time.time()
    try:
        records = deserializer.load_stream(infile)
        serializer.dump_stream(func(split_index, records), outfile)
    except Exception:
        write_int(SpecialLengths.PYTHON_EXCEPTION_THROWN, outfile)
        write_with_length(traceback.format_exc(), outfile)
        sys.exit(-1)
    finish_time = time.time()
    report_times(outfile, boot_time, init_time, finish_time)

    # Mark the beginning of the accumulators section of the output
    write_int(SpecialLengths.END_OF_DATA_SECTION, outfile)
    write_int(len(_accumulatorRegistry), outfile)
    for aid, accum in _accumulatorRegistry.items():
        pickleSer._write_with_length((aid, accum._value), outfile)
def load_stream(self, stream):
    """
    Load a stream of un-ordered Arrow RecordBatches, where the last iteration yields
    a list of indices that can be used to put the RecordBatches in the correct order.

    Raises ``RuntimeError`` when the value following the batches is ``-1``; the
    error message is then read from the stream and included in the exception.
    """
    # load the batches
    for record_batch in self.serializer.load_stream(stream):
        yield record_batch

    # load the batch order indices or propagate any error that occurred in the JVM
    index_count = read_int(stream)
    if index_count == -1:
        error_msg = UTF8Deserializer().loads(stream)
        raise RuntimeError("An error occurred while calling "
                           "ArrowCollectSerializer.load_stream: {}".format(error_msg))

    yield [read_int(stream) for _ in range(index_count)]
def handle(self):
    """
    Apply accumulator updates from ``self.rfile`` until the server is shut down.

    The loop waits at most one second at a time for data so that
    ``self.server.server_shutdown`` is re-checked regularly. Each request holds
    an update count followed by that many ``(aid, update)`` pairs, which are
    added in place to ``_accumulatorRegistry``; every request is acknowledged
    with a single byte on ``self.wfile``.
    """
    from pyspark.accumulators import _accumulatorRegistry

    while not self.server.server_shutdown:
        # Poll every 1 second for new data -- don't block in case of shutdown.
        readable, _, _ = select.select([self.rfile], [], [], 1)
        if self.rfile not in readable:
            continue

        update_count = read_int(self.rfile)
        for _ in range(update_count):
            aid, update = pickleSer._read_with_length(self.rfile)
            _accumulatorRegistry[aid] += update

        # Write a byte in acknowledgement
        self.wfile.write(struct.pack("!b", 1))
def read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index):
    """
    Read one UDF description from ``infile`` and wrap it according to ``eval_type``.

    Reads the argument offsets, then a sequence of commands (function, return
    type, profiler) whose functions are composed with ``chain``. The return
    type and profiler of the last command read are the ones used.

    Returns a tuple ``(arg_offsets, wrapped_udf, profiler)``.

    Raises ``ValueError`` when ``eval_type`` is not one of the handled types.
    """
    offset_count = read_int(infile)
    arg_offsets = [read_int(infile) for _ in range(offset_count)]

    composed = None
    profiler = None
    for _ in range(read_int(infile)):
        step, return_type, profiler = read_command(pickleSer, infile)
        composed = step if composed is None else chain(composed, step)

    if eval_type == PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF:
        func = composed
    else:
        # make sure StopIteration's raised in the user code are not ignored
        # when they are processed in a for loop, raise them as RuntimeError's instead
        # NOTE(review): SQL_MAP_PANDAS_ITER_UDF also takes this branch although it
        # shares the iterator wrapper below -- confirm this is intended.
        func = fail_on_stopiteration(composed)

    # the last returnType will be the return type of UDF
    if eval_type == PythonEvalType.SQL_SCALAR_PANDAS_UDF:
        wrapped = wrap_scalar_pandas_udf(func, return_type)
    elif eval_type in (PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF,
                       PythonEvalType.SQL_MAP_PANDAS_ITER_UDF):
        wrapped = wrap_pandas_iter_udf(func, return_type)
    elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF:
        # the signature was lost when wrapping, so inspect the unwrapped function
        argspec = _get_argspec(composed)
        wrapped = wrap_grouped_map_pandas_udf(func, return_type, argspec)
    elif eval_type == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF:
        # the signature was lost when wrapping, so inspect the unwrapped function
        argspec = _get_argspec(composed)
        wrapped = wrap_cogrouped_map_pandas_udf(func, return_type, argspec)
    elif eval_type == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF:
        wrapped = wrap_grouped_agg_pandas_udf(func, return_type)
    elif eval_type == PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF:
        wrapped = wrap_window_agg_pandas_udf(func, return_type, runner_conf, udf_index)
    elif eval_type == PythonEvalType.SQL_BATCHED_UDF:
        wrapped = wrap_udf(func, return_type)
    else:
        raise ValueError("Unknown eval type: {}".format(eval_type))
    return arg_offsets, wrapped, profiler
def read_single_udf(pickleSer, infile, eval_type):
    """
    Read one UDF description from ``infile`` and wrap it according to ``eval_type``.

    Reads the argument offsets, then a sequence of commands whose functions
    are composed with ``chain``. Scalar and grouped-map pandas eval types get
    their dedicated wrappers; every other eval type is wrapped with ``wrap_udf``.

    Returns a tuple ``(arg_offsets, wrapped_udf)``.
    """
    offset_count = read_int(infile)
    arg_offsets = [read_int(infile) for _ in range(offset_count)]

    composed = None
    for _ in range(read_int(infile)):
        step, return_type = read_command(pickleSer, infile)
        composed = step if composed is None else chain(composed, step)

    # make sure StopIteration's raised in the user code are not ignored
    # when they are processed in a for loop, raise them as RuntimeError's instead
    guarded = fail_on_stopiteration(composed)

    # the last returnType will be the return type of UDF
    if eval_type == PythonEvalType.SQL_SCALAR_PANDAS_UDF:
        wrapper = wrap_scalar_pandas_udf
    elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF:
        wrapper = wrap_grouped_map_pandas_udf
    else:
        wrapper = wrap_udf
    return arg_offsets, wrapper(guarded, return_type)
def read_single_udf(pickleSer, infile, eval_type):
    """
    Read one UDF description from ``infile`` and wrap it according to ``eval_type``.

    Reads the argument offsets, then a sequence of commands whose functions
    are composed with ``chain``.

    Returns a tuple ``(arg_offsets, wrapped_udf)``.

    Raises ``ValueError`` when ``eval_type`` is not one of the handled types.
    """
    offset_count = read_int(infile)
    arg_offsets = [read_int(infile) for _ in range(offset_count)]

    composed = None
    for _ in range(read_int(infile)):
        step, return_type = read_command(pickleSer, infile)
        composed = step if composed is None else chain(composed, step)

    # the last returnType will be the return type of UDF
    if eval_type == PythonEvalType.SQL_SCALAR_PANDAS_UDF:
        wrapper = wrap_scalar_pandas_udf
    elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF:
        wrapper = wrap_grouped_map_pandas_udf
    elif eval_type == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF:
        wrapper = wrap_grouped_agg_pandas_udf
    elif eval_type == PythonEvalType.SQL_BATCHED_UDF:
        wrapper = wrap_udf
    else:
        raise ValueError("Unknown eval type: {}".format(eval_type))
    return arg_offsets, wrapper(composed, return_type)
def main(infile, outfile):
    """
    Run one task: read its description from ``infile`` and write results to ``outfile``.

    The input stream is consumed in this order: split index (``-1`` returns
    immediately, used by unit tests), the pickled files directory, the
    broadcast variables, the function and the bypass-serializer flag.

    On an exception the marker ``-2`` and the formatted traceback are written
    and the process exits with ``-1``. On success the timing report is written,
    followed by ``-1``, the accumulator updates and a closing ``-1``.
    """
    boot_time = time.time()
    split_index = read_int(infile)
    if split_index == -1:  # for unit tests
        return

    files_root = load_pickle(read_with_length(infile))
    SparkFiles._root_directory = files_root
    SparkFiles._is_running_on_worker = True
    sys.path.append(files_root)

    broadcast_count = read_int(infile)
    for _ in range(broadcast_count):
        bid = read_long(infile)
        payload = read_with_length(infile)
        _broadcastRegistry[bid] = Broadcast(bid, load_pickle(payload))

    func = load_obj(infile)
    bypassSerializer = load_obj(infile)
    # when the serializer is bypassed the results are written as-is
    dumps = (lambda x: x) if bypassSerializer else dump_pickle

    init_time = time.time()
    iterator = read_from_pickle_file(infile)
    try:
        for result in func(split_index, iterator):
            write_with_length(dumps(result), outfile)
    except Exception:
        write_int(-2, outfile)
        write_with_length(traceback.format_exc(), outfile)
        sys.exit(-1)
    finish_time = time.time()
    report_times(outfile, boot_time, init_time, finish_time)

    # Mark the beginning of the accumulators section of the output
    write_int(-1, outfile)
    for aid, accum in _accumulatorRegistry.items():
        update = dump_pickle((aid, accum._value))
        write_with_length(update, outfile)
    write_int(-1, outfile)
def read_udfs(pickleSer, infile, eval_type):
    """
    Read the UDFs for one task from ``infile`` and build the mapping function.

    The UDFs are read with ``read_single_udf`` and combined into one row
    mapper that is generated as a lambda expression. For the pandas eval types
    a timezone string is then read from the stream and an
    ``ArrowStreamPandasSerializer`` is used, otherwise a batched pickle
    serializer.

    Returns ``(func, None, ser, ser)``; the second element is ``None`` because
    profiling is not supported for UDFs.
    """
    num_udfs = read_int(infile)
    namespace = {}
    if eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF:
        # Create function like this:
        #   lambda a: f([a[0]], [a[0], a[1]])

        # We assume there is only one UDF here because grouped map doesn't
        # support combining multiple UDFs.
        assert num_udfs == 1

        # See FlatMapGroupsInPandasExec for how arg_offsets are used to
        # distinguish between grouping attributes and data attributes
        arg_offsets, udf = read_single_udf(pickleSer, infile, eval_type)
        namespace['f'] = udf
        split_offset = arg_offsets[0] + 1
        first_args = ", ".join("a[%d]" % o for o in arg_offsets[1:split_offset])
        second_args = ", ".join("a[%d]" % o for o in arg_offsets[split_offset:])
        mapper_str = "lambda a: f([%s], [%s])" % (first_args, second_args)
    else:
        # Create function like this:
        #   lambda a: (f0(a[0]), f1(a[1], a[2]), f2(a[3]))
        # In the special case of a single UDF this will return a single result rather
        # than a tuple of results; this is the format that the JVM side expects.
        calls = []
        for index in range(num_udfs):
            arg_offsets, udf = read_single_udf(pickleSer, infile, eval_type)
            namespace['f%d' % index] = udf
            call_args = ", ".join("a[%d]" % o for o in arg_offsets)
            calls.append("f%d(%s)" % (index, call_args))
        mapper_str = "lambda a: (%s)" % ", ".join(calls)

    mapper = eval(mapper_str, namespace)

    def func(_, it):
        return map(mapper, it)

    arrow_eval_types = (
        PythonEvalType.SQL_SCALAR_PANDAS_UDF,
        PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
        PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF,
        PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF,
    )
    if eval_type in arrow_eval_types:
        # the timezone follows the UDF definitions on the stream
        ser = ArrowStreamPandasSerializer(utf8_deserializer.loads(infile))
    else:
        ser = BatchedSerializer(PickleSerializer(), 100)

    # profiling is not supported for UDF
    return func, None, ser, ser
def read_udfs(pickleSer, infile):
    """
    Read the UDFs for one task from ``infile`` and build the mapping function.

    Each UDF is read with ``read_single_udf``; a row mapper calling every UDF
    with its argument offsets is generated as a lambda expression.

    Returns ``(func, None, ser, ser)`` where ``ser`` is a batched pickle
    serializer; the second element is ``None`` because profiling is not
    supported for UDFs.
    """
    namespace = {}
    calls = []
    for index in range(read_int(infile)):
        arg_offsets, udf = read_single_udf(pickleSer, infile)
        namespace['f%d' % index] = udf
        call_args = ", ".join("a[%d]" % o for o in arg_offsets)
        calls.append("f%d(%s)" % (index, call_args))

    # Create function like this:
    #   lambda a: (f0(a0), f1(a1, a2), f2(a3))
    # In the special case of a single UDF this will return a single result rather
    # than a tuple of results; this is the format that the JVM side expects.
    mapper = eval("lambda a: (%s)" % ", ".join(calls), namespace)

    def func(_, it):
        return map(mapper, it)

    ser = BatchedSerializer(PickleSerializer(), 100)
    # profiling is not supported for UDF
    return func, None, ser, ser
def do_termination_test(self, terminator):
    """
    Start ``daemon.py``, shut it down with ``terminator`` and verify it stops listening.

    ``terminator`` is called with the daemon's ``Popen`` object. The test
    passes when a connection attempt made one second after the shutdown
    request fails with ``ECONNREFUSED``.
    """
    import errno
    import subprocess

    # start daemon
    script = os.path.join(os.path.dirname(__file__), "daemon.py")
    daemon = subprocess.Popen(
        [sys.executable, script], stdin=subprocess.PIPE, stdout=subprocess.PIPE)

    # the daemon announces its port number on stdout
    port = read_int(daemon.stdout)

    # daemon should accept connections
    accepted = self.connect(port)
    self.assertTrue(accepted)

    # request shutdown
    terminator(daemon)
    time.sleep(1)

    # daemon should no longer accept connections
    with self.assertRaises(EnvironmentError) as trap:
        self.connect(port)
    self.assertEqual(trap.exception.errno, errno.ECONNREFUSED)
def load_stream(self, stream):
    """
    Deserialize Cogrouped ArrowRecordBatches to a tuple of Arrow tables and yield
    as two lists of pandas.Series.

    Every group is preceded by an integer: ``2`` announces a pair of Arrow
    streams, ``0`` ends the input, and any other value raises ``ValueError``.
    """
    import pyarrow as pa

    while True:
        dataframes_in_group = read_int(stream)
        if dataframes_in_group == 0:
            # end-of-input marker
            break
        if dataframes_in_group != 2:
            raise ValueError(
                'Invalid number of pandas.DataFrames in group {0}'.format(dataframes_in_group))

        # both Arrow streams are read completely before any conversion happens
        left_batches = list(ArrowStreamSerializer.load_stream(self, stream))
        right_batches = list(ArrowStreamSerializer.load_stream(self, stream))

        left_series = [self.arrow_to_pandas(column)
                       for column in pa.Table.from_batches(left_batches).itercolumns()]
        right_series = [self.arrow_to_pandas(column)
                        for column in pa.Table.from_batches(right_batches).itercolumns()]
        yield (left_series, right_series)
def _load_from_socket(
    port: Optional[Union[str, int]],
    auth_secret: str,
    function: int,
    all_gather_message: Optional[str] = None,
) -> List[str]:
    """
    Load data from a given socket, this is a blocking method thus only return when the socket
    connection has been closed.

    Parameters
    ----------
    port : str or int, optional
        Port passed to ``local_connect_and_auth``.
    auth_secret : str
        Secret passed to ``local_connect_and_auth``.
    function : int
        Either ``BARRIER_FUNCTION`` or ``ALL_GATHER_FUNCTION``.
    all_gather_message : str, optional
        Message sent (UTF-8 encoded) with an ``ALL_GATHER_FUNCTION`` request.

    Returns
    -------
    list of str
        The strings read from the socket after the request was sent.

    Raises
    ------
    ValueError
        If ``function`` is not a recognized function type.
    """
    (sockfile, sock) = local_connect_and_auth(port, auth_secret)
    try:
        # The call may block forever, so no timeout
        sock.settimeout(None)

        if function == BARRIER_FUNCTION:
            # Make a barrier() function call.
            write_int(function, sockfile)
        elif function == ALL_GATHER_FUNCTION:
            # Make a all_gather() function call.
            write_int(function, sockfile)
            write_with_length(cast(str, all_gather_message).encode("utf-8"), sockfile)
        else:
            raise ValueError("Unrecognized function type")
        sockfile.flush()

        # Collect result: a count followed by that many UTF-8 strings.
        num_results = read_int(sockfile)
        return [UTF8Deserializer().loads(sockfile) for _ in range(num_results)]
    finally:
        # Release resources even when the request was rejected or a read or
        # write failed; the socket is closed even if closing the file raises.
        try:
            sockfile.close()
        finally:
            sock.close()
def read_udfs(pickleSer, infile):
    """
    Read the UDFs for one task from ``infile`` and build the mapping function.

    With exactly one UDF the row is unpacked straight into that UDF and its
    argument offsets are not used. Otherwise a row mapper calling every UDF
    with its argument offsets is generated as a lambda expression.

    Returns ``(func, None, ser, ser)`` where ``ser`` is a batched pickle
    serializer; the second element is ``None`` because profiling is not
    supported for UDFs.
    """
    num_udfs = read_int(infile)
    if num_udfs == 1:
        # fast path for single UDF
        _, single_udf = read_single_udf(pickleSer, infile)

        def mapper(a):
            return single_udf(*a)
    else:
        namespace = {}
        calls = []
        for index in range(num_udfs):
            arg_offsets, udf = read_single_udf(pickleSer, infile)
            namespace['f%d' % index] = udf
            call_args = ", ".join("a[%d]" % o for o in arg_offsets)
            calls.append("f%d(%s)" % (index, call_args))
        # Create function like this:
        #   lambda a: (f0(a0), f1(a1, a2), f2(a3))
        mapper = eval("lambda a: (%s)" % ", ".join(calls), namespace)

    def func(_, it):
        return map(mapper, it)

    ser = BatchedSerializer(PickleSerializer(), 100)
    # profiling is not supported for UDF
    return func, None, ser, ser
def launch_gateway(conf=None, popen_kwargs=None):
    """
    launch jvm gateway

    Parameters
    ----------
    conf : :py:class:`pyspark.SparkConf`
        spark configuration passed to spark-submit
    popen_kwargs : dict
        Dictionary of kwargs to pass to Popen when spawning
        the py4j JVM. This is a developer feature intended for use in
        customizing how pyspark interacts with the py4j JVM (e.g., capturing
        stdout/stderr). The dictionary is copied; the caller's object is not
        modified.

    Returns
    -------
    ClientServer or JavaGateway

    Raises
    ------
    RuntimeError
        If the launched process exits before the connection information file
        has been written.
    """
    if "PYSPARK_GATEWAY_PORT" in os.environ:
        gateway_port = int(os.environ["PYSPARK_GATEWAY_PORT"])
        gateway_secret = os.environ["PYSPARK_GATEWAY_SECRET"]
        # Process already exists
        proc = None
    else:
        SPARK_HOME = _find_spark_home()
        # Launch the Py4j gateway using Spark's run command so that we pick up the
        # proper classpath and settings from spark-env.sh
        on_windows = platform.system() == "Windows"
        script = "./bin/spark-submit.cmd" if on_windows else "./bin/spark-submit"
        command = [os.path.join(SPARK_HOME, script)]
        if conf:
            for k, v in conf.getAll():
                command += ["--conf", "%s=%s" % (k, v)]
        submit_args = os.environ.get("PYSPARK_SUBMIT_ARGS", "pyspark-shell")
        if os.environ.get("SPARK_TESTING"):
            submit_args = " ".join(["--conf spark.ui.enabled=false", submit_args])
        command = command + shlex.split(submit_args)

        # Create a temporary directory where the gateway server should write the connection
        # information.
        conn_info_dir = tempfile.mkdtemp()
        try:
            # Reserve a unique file name inside the directory, then remove the
            # file so its later appearance signals that the gateway wrote it.
            fd, conn_info_file = tempfile.mkstemp(dir=conn_info_dir)
            os.close(fd)
            os.unlink(conn_info_file)

            env = dict(os.environ)
            env["_PYSPARK_DRIVER_CONN_INFO_PATH"] = conn_info_file

            # Launch the Java gateway.
            # Work on a copy so the entries added below do not leak into the
            # dictionary owned by the caller.
            popen_kwargs = {} if popen_kwargs is None else dict(popen_kwargs)
            # We open a pipe to stdin so that the Java gateway can die when the pipe is broken
            popen_kwargs["stdin"] = PIPE
            # We always set the necessary environment variables.
            popen_kwargs["env"] = env
            if not on_windows:
                # Don't send ctrl-c / SIGINT to the Java gateway:
                def preexec_func():
                    signal.signal(signal.SIGINT, signal.SIG_IGN)

                # preexec_fn not supported on Windows
                popen_kwargs["preexec_fn"] = preexec_func
            proc = Popen(command, **popen_kwargs)

            # Wait for the file to appear, or for the process to exit, whichever happens first.
            # poll() returns None only while the process is running; testing its
            # truthiness would treat exit code 0 as "still running" and wait forever.
            while proc.poll() is None and not os.path.isfile(conn_info_file):
                time.sleep(0.1)

            if not os.path.isfile(conn_info_file):
                raise RuntimeError("Java gateway process exited before sending its port number")

            with open(conn_info_file, "rb") as info:
                gateway_port = read_int(info)
                gateway_secret = UTF8Deserializer().loads(info)
        finally:
            shutil.rmtree(conn_info_dir)

        # In Windows, ensure the Java child processes do not linger after Python has exited.
        # In UNIX-based systems, the child process can kill itself on broken pipe (i.e. when
        # the parent process' stdin sends an EOF). In Windows, however, this is not possible
        # because java.lang.Process reads directly from the parent process' stdin, contending
        # with any opportunity to read an EOF from the parent. Note that this is only best
        # effort and will not take effect if the python process is violently terminated.
        if on_windows:
            # In Windows, the child process here is "spark-submit.cmd", not the JVM itself
            # (because the UNIX "exec" command is not available). This means we cannot simply
            # call proc.kill(), which kills only the "spark-submit.cmd" process but not the
            # JVMs. Instead, we use "taskkill" with the tree-kill option "/t" to terminate all
            # child processes in the tree (http://technet.microsoft.com/en-us/library/bb491009.aspx)
            def killChild():
                Popen(["cmd", "/c", "taskkill", "/f", "/t", "/pid", str(proc.pid)])

            atexit.register(killChild)

    # Connect to the gateway (or client server to pin the thread between JVM and Python)
    if os.environ.get("PYSPARK_PIN_THREAD", "true").lower() == "true":
        gateway = ClientServer(
            java_parameters=JavaParameters(
                port=gateway_port, auth_token=gateway_secret, auto_convert=True
            ),
            python_parameters=PythonParameters(port=0, eager_load=False),
        )
    else:
        gateway = JavaGateway(
            gateway_parameters=GatewayParameters(
                port=gateway_port, auth_token=gateway_secret, auto_convert=True
            )
        )

    # Store a reference to the Popen object for use by the caller (e.g., in reading stdout/stderr)
    gateway.proc = proc

    # Import the classes used by PySpark
    java_import(gateway.jvm, "org.apache.spark.SparkConf")
    java_import(gateway.jvm, "org.apache.spark.api.java.*")
    java_import(gateway.jvm, "org.apache.spark.api.python.*")
    java_import(gateway.jvm, "org.apache.spark.ml.python.*")
    java_import(gateway.jvm, "org.apache.spark.mllib.api.python.*")
    java_import(gateway.jvm, "org.apache.spark.resource.*")
    # TODO(davies): move into sql
    java_import(gateway.jvm, "org.apache.spark.sql.*")
    java_import(gateway.jvm, "org.apache.spark.sql.api.python.*")
    java_import(gateway.jvm, "org.apache.spark.sql.hive.*")
    java_import(gateway.jvm, "scala.Tuple2")

    return gateway
def main(infile, outfile):
    """Run a single task: read its setup and command from infile, write results to outfile.

    Reads, in order: the split index, the driver's "major.minor" Python
    version, the TaskContext fields, the SparkFiles root directory, the list
    of Python includes, the broadcast variable updates and the eval type
    followed by the command or UDFs. The function is then run over the
    deserialized input stream, and timing, spill and accumulator data are
    written after the results.

    :param infile: binary stream the task description and input are read from
    :param outfile: binary stream results and bookkeeping are written to
    :raises SystemExit: with status -1 when the split index is -1, when any
        exception is raised while setting up or running the task, or when the
        value read after the task is not END_OF_STREAM
    """
    try:
        boot_time = time.time()
        split_index = read_int(infile)
        if split_index == -1:  # for unit tests
            sys.exit(-1)

        version = utf8_deserializer.loads(infile)
        if version != "%d.%d" % sys.version_info[:2]:
            # Each fragment ends with a space so the concatenated sentences stay separated.
            raise Exception(("Python in worker has different version %s than that in " +
                             "driver %s, PySpark cannot run with different minor versions. " +
                             "Please check environment variables PYSPARK_PYTHON and " +
                             "PYSPARK_DRIVER_PYTHON are correctly set.") %
                            ("%d.%d" % sys.version_info[:2], version))

        # initialize global state
        taskContext = TaskContext._getOrCreate()
        taskContext._stageId = read_int(infile)
        taskContext._partitionId = read_int(infile)
        taskContext._attemptNumber = read_int(infile)
        taskContext._taskAttemptId = read_long(infile)
        shuffle.MemoryBytesSpilled = 0
        shuffle.DiskBytesSpilled = 0
        _accumulatorRegistry.clear()

        # fetch name of workdir
        spark_files_dir = utf8_deserializer.loads(infile)
        SparkFiles._root_directory = spark_files_dir
        SparkFiles._is_running_on_worker = True

        # fetch names of includes (*.zip and *.egg files) and construct PYTHONPATH
        add_path(spark_files_dir)  # *.py files that were added will be copied here
        num_python_includes = read_int(infile)
        for _ in range(num_python_includes):
            filename = utf8_deserializer.loads(infile)
            add_path(os.path.join(spark_files_dir, filename))
        if sys.version > '3':
            import importlib
            importlib.invalidate_caches()

        # fetch names and values of broadcast variables
        num_broadcast_variables = read_int(infile)
        for _ in range(num_broadcast_variables):
            bid = read_long(infile)
            if bid >= 0:
                path = utf8_deserializer.loads(infile)
                _broadcastRegistry[bid] = Broadcast(path=path)
            else:
                # a negative id encodes removal of broadcast (-bid - 1)
                bid = - bid - 1
                _broadcastRegistry.pop(bid)

        _accumulatorRegistry.clear()
        eval_type = read_int(infile)
        if eval_type == PythonEvalType.NON_UDF:
            func, profiler, deserializer, serializer = read_command(pickleSer, infile)
        else:
            func, profiler, deserializer, serializer = read_udfs(pickleSer, infile, eval_type)

        init_time = time.time()

        def process():
            iterator = deserializer.load_stream(infile)
            serializer.dump_stream(func(split_index, iterator), outfile)

        if profiler:
            profiler.profile(process)
        else:
            process()
    except Exception:
        try:
            write_int(SpecialLengths.PYTHON_EXCEPTION_THROWN, outfile)
            write_with_length(traceback.format_exc().encode("utf-8"), outfile)
        except IOError:
            # JVM close the socket
            pass
        except Exception:
            # Write the error to stderr if it happened while serializing
            print("PySpark worker failed with exception:", file=sys.stderr)
            print(traceback.format_exc(), file=sys.stderr)
        sys.exit(-1)
    finish_time = time.time()
    report_times(outfile, boot_time, init_time, finish_time)
    write_long(shuffle.MemoryBytesSpilled, outfile)
    write_long(shuffle.DiskBytesSpilled, outfile)

    # Mark the beginning of the accumulators section of the output
    write_int(SpecialLengths.END_OF_DATA_SECTION, outfile)
    write_int(len(_accumulatorRegistry), outfile)
    for (aid, accum) in _accumulatorRegistry.items():
        pickleSer._write_with_length((aid, accum._value), outfile)

    # check end of stream
    if read_int(infile) == SpecialLengths.END_OF_STREAM:
        write_int(SpecialLengths.END_OF_STREAM, outfile)
    else:
        # write a different value to tell JVM to not reuse this worker
        write_int(SpecialLengths.END_OF_DATA_SECTION, outfile)
        sys.exit(-1)
def main(infile, outfile):
    """Run a single task: read its setup and command from infile, write results to outfile.

    Reads, in order: the split index, the driver's "major.minor" Python
    version, the barrier-task inputs, the TaskContext fields and local
    properties, the SparkFiles root directory, the list of Python includes,
    the broadcast variable updates (optionally fetched from a decryption
    server socket) and the eval type followed by the command or UDFs. If
    PYSPARK_EXECUTOR_MEMORY_MB is a positive number and the resource module is
    available, the address-space limit is lowered first. The function is then
    run over the deserialized input stream, and timing, spill and accumulator
    data are written after the results.

    :param infile: binary stream the task description and input are read from
    :param outfile: binary stream results and bookkeeping are written to
    :raises SystemExit: with status -1 when the split index is -1, when any
        exception is raised while setting up or running the task, or when the
        value read after the task is not END_OF_STREAM
    """
    try:
        boot_time = time.time()
        split_index = read_int(infile)
        if split_index == -1:  # for unit tests
            sys.exit(-1)

        version = utf8_deserializer.loads(infile)
        if version != "%d.%d" % sys.version_info[:2]:
            # Each fragment ends with a space so the concatenated sentences stay separated.
            raise Exception(("Python in worker has different version %s than that in " +
                             "driver %s, PySpark cannot run with different minor versions. " +
                             "Please check environment variables PYSPARK_PYTHON and " +
                             "PYSPARK_DRIVER_PYTHON are correctly set.") %
                            ("%d.%d" % sys.version_info[:2], version))

        # read inputs only for a barrier task
        isBarrier = read_bool(infile)
        boundPort = read_int(infile)
        secret = UTF8Deserializer().loads(infile)

        # set up memory limits
        memory_limit_mb = int(os.environ.get('PYSPARK_EXECUTOR_MEMORY_MB', "-1"))
        if memory_limit_mb > 0 and has_resource_module:
            total_memory = resource.RLIMIT_AS
            try:
                (soft_limit, hard_limit) = resource.getrlimit(total_memory)
                msg = "Current mem limits: {0} of max {1}\n".format(soft_limit, hard_limit)
                print(msg, file=sys.stderr)

                # convert to bytes
                new_limit = memory_limit_mb * 1024 * 1024

                # only ever tighten the limit, never raise an existing lower one
                if soft_limit == resource.RLIM_INFINITY or new_limit < soft_limit:
                    msg = "Setting mem limits to {0} of max {1}\n".format(new_limit, new_limit)
                    print(msg, file=sys.stderr)
                    resource.setrlimit(total_memory, (new_limit, new_limit))

            except (resource.error, OSError, ValueError) as e:
                # not all systems support resource limits, so warn instead of failing
                print("WARN: Failed to set memory limit: {0}\n".format(e), file=sys.stderr)

        # initialize global state
        taskContext = None
        if isBarrier:
            taskContext = BarrierTaskContext._getOrCreate()
            BarrierTaskContext._initialize(boundPort, secret)
        else:
            taskContext = TaskContext._getOrCreate()
        # read inputs for TaskContext info
        taskContext._stageId = read_int(infile)
        taskContext._partitionId = read_int(infile)
        taskContext._attemptNumber = read_int(infile)
        taskContext._taskAttemptId = read_long(infile)
        taskContext._localProperties = dict()
        for i in range(read_int(infile)):
            k = utf8_deserializer.loads(infile)
            v = utf8_deserializer.loads(infile)
            taskContext._localProperties[k] = v

        shuffle.MemoryBytesSpilled = 0
        shuffle.DiskBytesSpilled = 0
        _accumulatorRegistry.clear()

        # fetch name of workdir
        spark_files_dir = utf8_deserializer.loads(infile)
        SparkFiles._root_directory = spark_files_dir
        SparkFiles._is_running_on_worker = True

        # fetch names of includes (*.zip and *.egg files) and construct PYTHONPATH
        add_path(spark_files_dir)  # *.py files that were added will be copied here
        num_python_includes = read_int(infile)
        for _ in range(num_python_includes):
            filename = utf8_deserializer.loads(infile)
            add_path(os.path.join(spark_files_dir, filename))
        if sys.version > '3':
            import importlib
            importlib.invalidate_caches()

        # fetch names and values of broadcast variables
        needs_broadcast_decryption_server = read_bool(infile)
        num_broadcast_variables = read_int(infile)
        if needs_broadcast_decryption_server:
            # read the decrypted data from a server in the jvm
            port = read_int(infile)
            auth_secret = utf8_deserializer.loads(infile)
            (broadcast_sock_file, _) = local_connect_and_auth(port, auth_secret)

        for _ in range(num_broadcast_variables):
            bid = read_long(infile)
            if bid >= 0:
                if needs_broadcast_decryption_server:
                    # the socket repeats the id; it must match the one read from infile
                    read_bid = read_long(broadcast_sock_file)
                    assert(read_bid == bid)
                    _broadcastRegistry[bid] = \
                        Broadcast(sock_file=broadcast_sock_file)
                else:
                    path = utf8_deserializer.loads(infile)
                    _broadcastRegistry[bid] = Broadcast(path=path)

            else:
                # a negative id encodes removal of broadcast (-bid - 1)
                bid = - bid - 1
                _broadcastRegistry.pop(bid)

        if needs_broadcast_decryption_server:
            broadcast_sock_file.write(b'1')
            broadcast_sock_file.close()

        _accumulatorRegistry.clear()
        eval_type = read_int(infile)
        if eval_type == PythonEvalType.NON_UDF:
            func, profiler, deserializer, serializer = read_command(pickleSer, infile)
        else:
            func, profiler, deserializer, serializer = read_udfs(pickleSer, infile, eval_type)

        init_time = time.time()

        def process():
            iterator = deserializer.load_stream(infile)
            serializer.dump_stream(func(split_index, iterator), outfile)

        if profiler:
            profiler.profile(process)
        else:
            process()
    except Exception:
        try:
            write_int(SpecialLengths.PYTHON_EXCEPTION_THROWN, outfile)
            write_with_length(traceback.format_exc().encode("utf-8"), outfile)
        except IOError:
            # JVM close the socket
            pass
        except Exception:
            # Write the error to stderr if it happened while serializing
            print("PySpark worker failed with exception:", file=sys.stderr)
            print(traceback.format_exc(), file=sys.stderr)
        sys.exit(-1)
    finish_time = time.time()
    report_times(outfile, boot_time, init_time, finish_time)
    write_long(shuffle.MemoryBytesSpilled, outfile)
    write_long(shuffle.DiskBytesSpilled, outfile)

    # Mark the beginning of the accumulators section of the output
    write_int(SpecialLengths.END_OF_DATA_SECTION, outfile)
    write_int(len(_accumulatorRegistry), outfile)
    for (aid, accum) in _accumulatorRegistry.items():
        pickleSer._write_with_length((aid, accum._value), outfile)

    # check end of stream
    if read_int(infile) == SpecialLengths.END_OF_STREAM:
        write_int(SpecialLengths.END_OF_STREAM, outfile)
    else:
        # write a different value to tell JVM to not reuse this worker
        write_int(SpecialLengths.END_OF_DATA_SECTION, outfile)
        sys.exit(-1)
def main(infile, outfile):
    """Run a single task: read its setup and command from infile, write results to outfile.

    Reads, in order: the split index, the SparkFiles root directory, the list
    of Python includes, the broadcast variable updates (values are read inline
    through a CompressedSerializer) and the pickled command. The function is
    then run over the deserialized input stream, and timing, spill and
    accumulator data are written after the results.

    :param infile: binary stream the task description and input are read from
    :param outfile: binary stream results and bookkeeping are written to
    :return: None; returns immediately when the split index is -1
    :raises SystemExit: with status -1 when any exception is raised while
        setting up or running the task
    """
    try:
        boot_time = time.time()
        split_index = read_int(infile)
        if split_index == -1:  # for unit tests
            return

        # initialize global state
        shuffle.MemoryBytesSpilled = 0
        shuffle.DiskBytesSpilled = 0
        _accumulatorRegistry.clear()

        # fetch name of workdir
        spark_files_dir = utf8_deserializer.loads(infile)
        SparkFiles._root_directory = spark_files_dir
        SparkFiles._is_running_on_worker = True

        # fetch names of includes (*.zip and *.egg files) and construct PYTHONPATH
        sys.path.append(spark_files_dir)  # *.py files that were added will be copied here
        num_python_includes = read_int(infile)
        for _ in range(num_python_includes):
            filename = utf8_deserializer.loads(infile)
            sys.path.append(os.path.join(spark_files_dir, filename))

        # fetch names and values of broadcast variables
        num_broadcast_variables = read_int(infile)
        ser = CompressedSerializer(pickleSer)
        for _ in range(num_broadcast_variables):
            bid = read_long(infile)
            if bid >= 0:
                value = ser._read_with_length(infile)
                _broadcastRegistry[bid] = Broadcast(bid, value)
            else:
                # a negative id encodes removal of broadcast (-bid - 1)
                bid = - bid - 1
                _broadcastRegistry.pop(bid)

        _accumulatorRegistry.clear()
        command = pickleSer._read_with_length(infile)
        if isinstance(command, Broadcast):
            command = pickleSer.loads(command.value)
        (func, deserializer, serializer) = command
        init_time = time.time()
        iterator = deserializer.load_stream(infile)
        serializer.dump_stream(func(split_index, iterator), outfile)
    except Exception:
        try:
            write_int(SpecialLengths.PYTHON_EXCEPTION_THROWN, outfile)
            write_with_length(traceback.format_exc(), outfile)
            outfile.flush()
        except IOError:
            # JVM close the socket
            pass
        except Exception:
            # Write the error to stderr if it happened while serializing.
            # sys.stderr.write is used because it is valid on both Python 2 and 3.
            sys.stderr.write("PySpark worker failed with exception:" + "\n")
            sys.stderr.write(traceback.format_exc() + "\n")
        # sys.exit rather than the site-provided exit(), which is not always installed
        sys.exit(-1)
    finish_time = time.time()
    report_times(outfile, boot_time, init_time, finish_time)
    write_long(shuffle.MemoryBytesSpilled, outfile)
    write_long(shuffle.DiskBytesSpilled, outfile)

    # Mark the beginning of the accumulators section of the output
    write_int(SpecialLengths.END_OF_DATA_SECTION, outfile)
    write_int(len(_accumulatorRegistry), outfile)
    for (aid, accum) in _accumulatorRegistry.items():
        pickleSer._write_with_length((aid, accum._value), outfile)
def main(infile, outfile):
    """Run a single task: read its setup and command from infile, write results to outfile.

    Reads, in order: the split index, the SparkFiles root directory, the list
    of Python includes, the broadcast variable updates (each value is read
    from a size-limited slice of infile) and the pickled command. The function
    is then run over the deserialized input stream, under cProfile when the
    command carries a stats object, and timing, spill and accumulator data are
    written after the results.

    :param infile: binary stream the task description and input are read from
    :param outfile: binary stream results and bookkeeping are written to
    :raises SystemExit: with status -1 when the split index is -1, when any
        exception is raised while setting up or running the task, or when the
        value read after the task is not END_OF_STREAM
    """
    try:
        boot_time = time.time()
        split_index = read_int(infile)
        if split_index == -1:  # for unit tests
            sys.exit(-1)

        # initialize global state
        shuffle.MemoryBytesSpilled = 0
        shuffle.DiskBytesSpilled = 0
        _accumulatorRegistry.clear()

        # fetch name of workdir
        spark_files_dir = utf8_deserializer.loads(infile)
        SparkFiles._root_directory = spark_files_dir
        SparkFiles._is_running_on_worker = True

        # fetch names of includes (*.zip and *.egg files) and construct PYTHONPATH
        add_path(spark_files_dir)  # *.py files that were added will be copied here
        num_python_includes = read_int(infile)
        for _ in range(num_python_includes):
            filename = utf8_deserializer.loads(infile)
            add_path(os.path.join(spark_files_dir, filename))

        # fetch names and values of broadcast variables
        num_broadcast_variables = read_int(infile)
        bser = LargeObjectSerializer()
        for _ in range(num_broadcast_variables):
            bid = read_long(infile)
            if bid >= 0:
                size = read_long(infile)
                s = SizeLimitedStream(infile, size)
                value = list((bser.load_stream(s)))[0]  # read out all the bytes
                _broadcastRegistry[bid] = Broadcast(bid, value)
            else:
                # a negative id encodes removal of broadcast (-bid - 1)
                bid = - bid - 1
                _broadcastRegistry.pop(bid)

        _accumulatorRegistry.clear()
        command = pickleSer._read_with_length(infile)
        if isinstance(command, Broadcast):
            command = pickleSer.loads(command.value)
        (func, stats, deserializer, serializer) = command
        init_time = time.time()

        def process():
            iterator = deserializer.load_stream(infile)
            serializer.dump_stream(func(split_index, iterator), outfile)

        if stats:
            p = cProfile.Profile()
            p.runcall(process)
            st = pstats.Stats(p)
            st.stream = None  # make it picklable
            stats.add(st.strip_dirs())
        else:
            process()
    except Exception:
        try:
            write_int(SpecialLengths.PYTHON_EXCEPTION_THROWN, outfile)
            write_with_length(traceback.format_exc(), outfile)
        except IOError:
            # JVM close the socket
            pass
        except Exception:
            # Write the error to stderr if it happened while serializing.
            # sys.stderr.write is used because it is valid on both Python 2 and 3.
            sys.stderr.write("PySpark worker failed with exception:" + "\n")
            sys.stderr.write(traceback.format_exc() + "\n")
        # sys.exit rather than the site-provided exit(), which is not always installed
        sys.exit(-1)
    finish_time = time.time()
    report_times(outfile, boot_time, init_time, finish_time)
    write_long(shuffle.MemoryBytesSpilled, outfile)
    write_long(shuffle.DiskBytesSpilled, outfile)

    # Mark the beginning of the accumulators section of the output
    write_int(SpecialLengths.END_OF_DATA_SECTION, outfile)
    write_int(len(_accumulatorRegistry), outfile)
    for (aid, accum) in _accumulatorRegistry.items():
        pickleSer._write_with_length((aid, accum._value), outfile)

    # check end of stream
    if read_int(infile) == SpecialLengths.END_OF_STREAM:
        write_int(SpecialLengths.END_OF_STREAM, outfile)
    else:
        # write a different value to tell JVM to not reuse this worker
        write_int(SpecialLengths.END_OF_DATA_SECTION, outfile)
        sys.exit(-1)
def main(infile, outfile):
    """Run a single task: read its setup and command from infile, write results to outfile.

    Reads, in order: the split index, the SparkFiles root directory, the list
    of Python includes, the broadcast variable updates and the pickled
    command, which carries the driver's (major, minor) Python version. The
    function is then run over the deserialized input stream, under the
    profiler when one is supplied, and timing, spill and accumulator data are
    written after the results.

    :param infile: binary stream the task description and input are read from
    :param outfile: binary stream results and bookkeeping are written to
    :raises SystemExit: with status -1 when the split index is -1, when any
        exception is raised while setting up or running the task (including a
        version mismatch), or when the value read after the task is not
        END_OF_STREAM
    """
    try:
        boot_time = time.time()
        split_index = read_int(infile)
        if split_index == -1:  # for unit tests
            sys.exit(-1)

        # initialize global state
        shuffle.MemoryBytesSpilled = 0
        shuffle.DiskBytesSpilled = 0
        _accumulatorRegistry.clear()

        # fetch name of workdir
        spark_files_dir = utf8_deserializer.loads(infile)
        SparkFiles._root_directory = spark_files_dir
        SparkFiles._is_running_on_worker = True

        # fetch names of includes (*.zip and *.egg files) and construct PYTHONPATH
        add_path(spark_files_dir)  # *.py files that were added will be copied here
        num_python_includes = read_int(infile)
        for _ in range(num_python_includes):
            filename = utf8_deserializer.loads(infile)
            add_path(os.path.join(spark_files_dir, filename))

        # fetch names and values of broadcast variables
        num_broadcast_variables = read_int(infile)
        for _ in range(num_broadcast_variables):
            bid = read_long(infile)
            if bid >= 0:
                path = utf8_deserializer.loads(infile)
                _broadcastRegistry[bid] = Broadcast(path=path)
            else:
                # a negative id encodes removal of broadcast (-bid - 1)
                bid = - bid - 1
                _broadcastRegistry.pop(bid)

        _accumulatorRegistry.clear()
        command = pickleSer._read_with_length(infile)
        if isinstance(command, Broadcast):
            command = pickleSer.loads(command.value)

        (func, profiler, deserializer, serializer), version = command
        if version != sys.version_info[:2]:
            raise Exception(("Python in worker has different version %s than that in " +
                             "driver %s, PySpark cannot run with different minor versions") %
                            (sys.version_info[:2], version))

        init_time = time.time()

        def process():
            iterator = deserializer.load_stream(infile)
            serializer.dump_stream(func(split_index, iterator), outfile)

        if profiler:
            profiler.profile(process)
        else:
            process()
    except Exception:
        try:
            write_int(SpecialLengths.PYTHON_EXCEPTION_THROWN, outfile)
            write_with_length(traceback.format_exc(), outfile)
        except IOError:
            # JVM close the socket
            pass
        except Exception:
            # Write the error to stderr if it happened while serializing.
            # sys.stderr.write is used because it is valid on both Python 2 and 3.
            sys.stderr.write("PySpark worker failed with exception:" + "\n")
            sys.stderr.write(traceback.format_exc() + "\n")
        # sys.exit rather than the site-provided exit(), which is not always installed
        sys.exit(-1)
    finish_time = time.time()
    report_times(outfile, boot_time, init_time, finish_time)
    write_long(shuffle.MemoryBytesSpilled, outfile)
    write_long(shuffle.DiskBytesSpilled, outfile)

    # Mark the beginning of the accumulators section of the output
    write_int(SpecialLengths.END_OF_DATA_SECTION, outfile)
    write_int(len(_accumulatorRegistry), outfile)
    for (aid, accum) in _accumulatorRegistry.items():
        pickleSer._write_with_length((aid, accum._value), outfile)

    # check end of stream
    if read_int(infile) == SpecialLengths.END_OF_STREAM:
        write_int(SpecialLengths.END_OF_STREAM, outfile)
    else:
        # write a different value to tell JVM to not reuse this worker
        write_int(SpecialLengths.END_OF_DATA_SECTION, outfile)
        sys.exit(-1)
def manager():
    """Run the daemon loop: accept loopback connections and fork one worker per connection.

    Binds a listening socket to an ephemeral port on 127.0.0.1 and writes the
    port number to stdout. It then loops on select() over stdin (fd 0) and the
    listening socket:

    * an integer read from stdin is treated as a worker pid and that process
      is sent SIGKILL; EOF on stdin shuts the daemon down;
    * each accepted connection forks a child that writes its pid to the
      socket and calls worker(sock), repeating while SPARK_REUSE_WORKER is set
      and worker() returns a falsy code.

    Takes no arguments and does not return normally; it leaves through
    shutdown(), which signals the process group with SIGHUP and raises
    SystemExit.
    """
    # Create a new process group to corral our children
    os.setpgid(0, 0)

    # Create a listening socket on the AF_INET loopback interface
    listen_sock = socket.socket(AF_INET, SOCK_STREAM)
    listen_sock.bind(('127.0.0.1', 0))
    listen_sock.listen(max(1024, SOMAXCONN))
    listen_port = listen_sock.getsockname()[1]
    write_int(listen_port, sys.stdout)
    sys.stdout.flush()

    def shutdown(code):
        # restore the default handler so the SIGHUP broadcast cannot re-enter us via SIGTERM
        signal.signal(SIGTERM, SIG_DFL)
        # Send SIGHUP to notify workers of shutdown
        os.kill(0, SIGHUP)
        # sys.exit rather than the site-provided exit(), which is not always installed
        sys.exit(code)

    def handle_sigterm(*args):
        shutdown(1)
    signal.signal(SIGTERM, handle_sigterm)  # Gracefully exit on SIGTERM
    signal.signal(SIGHUP, SIG_IGN)  # Don't die on SIGHUP

    reuse = os.environ.get("SPARK_REUSE_WORKER")

    # Initialization complete
    try:
        while True:
            try:
                ready_fds = select.select([0, listen_sock], [], [], 1)[0]
            except select.error as ex:
                # args[0] is the errno; indexing the exception itself only works on Python 2
                if ex.args[0] == EINTR:
                    continue
                else:
                    raise

            # cleanup in signal handler will cause deadlock
            cleanup_dead_children()

            if 0 in ready_fds:
                try:
                    worker_pid = read_int(sys.stdin)
                except EOFError:
                    # Spark told us to exit by closing stdin
                    shutdown(0)
                try:
                    os.kill(worker_pid, signal.SIGKILL)
                except OSError:
                    pass  # process already died

            if listen_sock in ready_fds:
                try:
                    sock, _ = listen_sock.accept()
                except OSError as e:
                    if e.errno == EINTR:
                        continue
                    raise

                # Launch a worker process
                try:
                    pid = os.fork()
                except OSError as e:
                    if e.errno in (EAGAIN, EINTR):
                        time.sleep(1)
                        pid = os.fork()  # error here will shutdown daemon
                    else:
                        outfile = sock.makefile('w')
                        write_int(e.errno, outfile)  # Signal that the fork failed
                        outfile.flush()
                        outfile.close()
                        sock.close()
                        continue

                if pid == 0:
                    # in child process
                    listen_sock.close()
                    try:
                        # Acknowledge that the fork was successful
                        outfile = sock.makefile("w")
                        write_int(os.getpid(), outfile)
                        outfile.flush()
                        outfile.close()
                        while True:
                            code = worker(sock)
                            if not reuse or code:
                                # wait for closing
                                try:
                                    while sock.recv(1024):
                                        pass
                                except Exception:
                                    pass
                                break
                            gc.collect()
                    except:
                        # deliberately broad: the forked child must always end in
                        # os._exit and never fall back into the daemon loop
                        traceback.print_exc()
                        os._exit(1)
                    else:
                        os._exit(0)
                else:
                    sock.close()

    finally:
        # NOTE(review): shutdown() raises SystemExit, so this also runs after
        # shutdown(0) above and determines the final exit status - confirm intended.
        shutdown(1)
def main(infile, outfile):
    """Run a single task: read its setup and command from infile, write results to outfile.

    Reads, in order: the split index, the driver's "major.minor" Python
    version, the barrier-task inputs, the TaskContext fields and local
    properties, the SparkFiles root directory, the list of Python includes,
    the broadcast variable updates and the eval type followed by the command
    or UDFs. The function is then run over the deserialized input stream, and
    timing, spill and accumulator data are written after the results.

    :param infile: binary stream the task description and input are read from
    :param outfile: binary stream results and bookkeeping are written to
    :raises SystemExit: with status -1 when the split index is -1, when any
        exception is raised while setting up or running the task, or when the
        value read after the task is not END_OF_STREAM
    """
    try:
        boot_time = time.time()
        split_index = read_int(infile)
        if split_index == -1:  # for unit tests
            sys.exit(-1)

        version = utf8_deserializer.loads(infile)
        if version != "%d.%d" % sys.version_info[:2]:
            # Each fragment ends with a space so the concatenated sentences stay separated.
            raise Exception(("Python in worker has different version %s than that in " +
                             "driver %s, PySpark cannot run with different minor versions. " +
                             "Please check environment variables PYSPARK_PYTHON and " +
                             "PYSPARK_DRIVER_PYTHON are correctly set.") %
                            ("%d.%d" % sys.version_info[:2], version))

        # read inputs only for a barrier task
        isBarrier = read_bool(infile)
        boundPort = read_int(infile)
        secret = UTF8Deserializer().loads(infile)
        # initialize global state
        taskContext = None
        if isBarrier:
            taskContext = BarrierTaskContext._getOrCreate()
            BarrierTaskContext._initialize(boundPort, secret)
        else:
            taskContext = TaskContext._getOrCreate()
        # read inputs for TaskContext info
        taskContext._stageId = read_int(infile)
        taskContext._partitionId = read_int(infile)
        taskContext._attemptNumber = read_int(infile)
        taskContext._taskAttemptId = read_long(infile)
        taskContext._localProperties = dict()
        for i in range(read_int(infile)):
            k = utf8_deserializer.loads(infile)
            v = utf8_deserializer.loads(infile)
            taskContext._localProperties[k] = v

        shuffle.MemoryBytesSpilled = 0
        shuffle.DiskBytesSpilled = 0
        _accumulatorRegistry.clear()

        # fetch name of workdir
        spark_files_dir = utf8_deserializer.loads(infile)
        SparkFiles._root_directory = spark_files_dir
        SparkFiles._is_running_on_worker = True

        # fetch names of includes (*.zip and *.egg files) and construct PYTHONPATH
        add_path(spark_files_dir)  # *.py files that were added will be copied here
        num_python_includes = read_int(infile)
        for _ in range(num_python_includes):
            filename = utf8_deserializer.loads(infile)
            add_path(os.path.join(spark_files_dir, filename))
        if sys.version > '3':
            import importlib
            importlib.invalidate_caches()

        # fetch names and values of broadcast variables
        num_broadcast_variables = read_int(infile)
        for _ in range(num_broadcast_variables):
            bid = read_long(infile)
            if bid >= 0:
                path = utf8_deserializer.loads(infile)
                _broadcastRegistry[bid] = Broadcast(path=path)
            else:
                # a negative id encodes removal of broadcast (-bid - 1)
                bid = - bid - 1
                _broadcastRegistry.pop(bid)

        _accumulatorRegistry.clear()
        eval_type = read_int(infile)
        if eval_type == PythonEvalType.NON_UDF:
            func, profiler, deserializer, serializer = read_command(pickleSer, infile)
        else:
            func, profiler, deserializer, serializer = read_udfs(pickleSer, infile, eval_type)

        init_time = time.time()

        def process():
            iterator = deserializer.load_stream(infile)
            serializer.dump_stream(func(split_index, iterator), outfile)

        if profiler:
            profiler.profile(process)
        else:
            process()
    except Exception:
        try:
            write_int(SpecialLengths.PYTHON_EXCEPTION_THROWN, outfile)
            write_with_length(traceback.format_exc().encode("utf-8"), outfile)
        except IOError:
            # JVM close the socket
            pass
        except Exception:
            # Write the error to stderr if it happened while serializing
            print("PySpark worker failed with exception:", file=sys.stderr)
            print(traceback.format_exc(), file=sys.stderr)
        sys.exit(-1)
    finish_time = time.time()
    report_times(outfile, boot_time, init_time, finish_time)
    write_long(shuffle.MemoryBytesSpilled, outfile)
    write_long(shuffle.DiskBytesSpilled, outfile)

    # Mark the beginning of the accumulators section of the output
    write_int(SpecialLengths.END_OF_DATA_SECTION, outfile)
    write_int(len(_accumulatorRegistry), outfile)
    for (aid, accum) in _accumulatorRegistry.items():
        pickleSer._write_with_length((aid, accum._value), outfile)

    # check end of stream
    if read_int(infile) == SpecialLengths.END_OF_STREAM:
        write_int(SpecialLengths.END_OF_STREAM, outfile)
    else:
        # write a different value to tell JVM to not reuse this worker
        write_int(SpecialLengths.END_OF_DATA_SECTION, outfile)
        sys.exit(-1)
def main(infile, outfile):
    """Run a single task: read its setup and command from infile, write results to outfile.

    Reads, in order: the split index, the SparkFiles root directory, the list
    of Python includes, the broadcast variable updates and the pickled
    command, which carries the driver's (major, minor) Python version. The
    function is then run over the deserialized input stream, under the
    profiler when one is supplied, and timing, spill and accumulator data are
    written after the results.

    :param infile: binary stream the task description and input are read from
    :param outfile: binary stream results and bookkeeping are written to
    :raises SystemExit: with status -1 when the split index is -1, when any
        exception is raised while setting up or running the task (including a
        version mismatch), or when the value read after the task is not
        END_OF_STREAM
    """
    try:
        boot_time = time.time()
        split_index = read_int(infile)
        if split_index == -1:  # for unit tests
            # sys.exit rather than the site-provided exit(), which is not always installed
            sys.exit(-1)

        # initialize global state
        shuffle.MemoryBytesSpilled = 0
        shuffle.DiskBytesSpilled = 0
        _accumulatorRegistry.clear()

        # fetch name of workdir
        spark_files_dir = utf8_deserializer.loads(infile)
        SparkFiles._root_directory = spark_files_dir
        SparkFiles._is_running_on_worker = True

        # fetch names of includes (*.zip and *.egg files) and construct PYTHONPATH
        add_path(spark_files_dir)  # *.py files that were added will be copied here
        num_python_includes = read_int(infile)
        for _ in range(num_python_includes):
            filename = utf8_deserializer.loads(infile)
            add_path(os.path.join(spark_files_dir, filename))
        if sys.version > '3':
            import importlib
            importlib.invalidate_caches()

        # fetch names and values of broadcast variables
        num_broadcast_variables = read_int(infile)
        for _ in range(num_broadcast_variables):
            bid = read_long(infile)
            if bid >= 0:
                path = utf8_deserializer.loads(infile)
                _broadcastRegistry[bid] = Broadcast(path=path)
            else:
                # a negative id encodes removal of broadcast (-bid - 1)
                bid = - bid - 1
                _broadcastRegistry.pop(bid)

        _accumulatorRegistry.clear()
        command = pickleSer._read_with_length(infile)
        if isinstance(command, Broadcast):
            command = pickleSer.loads(command.value)

        (func, profiler, deserializer, serializer), version = command
        if version != sys.version_info[:2]:
            raise Exception(("Python in worker has different version %s than that in " +
                             "driver %s, PySpark cannot run with different minor versions") %
                            (sys.version_info[:2], version))

        init_time = time.time()

        def process():
            iterator = deserializer.load_stream(infile)
            serializer.dump_stream(func(split_index, iterator), outfile)

        if profiler:
            profiler.profile(process)
        else:
            process()
    except Exception:
        try:
            write_int(SpecialLengths.PYTHON_EXCEPTION_THROWN, outfile)
            write_with_length(traceback.format_exc().encode("utf-8"), outfile)
        except IOError:
            # JVM close the socket
            pass
        except Exception:
            # Write the error to stderr if it happened while serializing
            print("PySpark worker failed with exception:", file=sys.stderr)
            print(traceback.format_exc(), file=sys.stderr)
        sys.exit(-1)
    finish_time = time.time()
    report_times(outfile, boot_time, init_time, finish_time)
    write_long(shuffle.MemoryBytesSpilled, outfile)
    write_long(shuffle.DiskBytesSpilled, outfile)

    # Mark the beginning of the accumulators section of the output
    write_int(SpecialLengths.END_OF_DATA_SECTION, outfile)
    write_int(len(_accumulatorRegistry), outfile)
    for (aid, accum) in _accumulatorRegistry.items():
        pickleSer._write_with_length((aid, accum._value), outfile)

    # check end of stream
    if read_int(infile) == SpecialLengths.END_OF_STREAM:
        write_int(SpecialLengths.END_OF_STREAM, outfile)
    else:
        # write a different value to tell JVM to not reuse this worker
        write_int(SpecialLengths.END_OF_DATA_SECTION, outfile)
        sys.exit(-1)
def launch_gateway(conf=None):
    """
    Launch the JVM gateway, or connect to an existing one, and return it.

    When PYSPARK_GATEWAY_PORT is set in the environment, the port and
    PYSPARK_GATEWAY_SECRET are taken from the environment and no process is
    started. Otherwise spark-submit is launched as a child process and the
    port and secret are read from a temporary connection-info file whose path
    is passed to the child in _PYSPARK_DRIVER_CONN_INFO_PATH.

    :param conf: spark configuration passed to spark-submit; each key/value
        pair from conf.getAll() is forwarded as a ``--conf k=v`` argument
    :return: a JavaGateway connected to the gateway port, with the PySpark
        JVM classes imported
    :raises Exception: if the launched process exits before the
        connection-info file appears
    """
    if "PYSPARK_GATEWAY_PORT" in os.environ:
        gateway_port = int(os.environ["PYSPARK_GATEWAY_PORT"])
        gateway_secret = os.environ["PYSPARK_GATEWAY_SECRET"]
    else:
        SPARK_HOME = _find_spark_home()
        # Launch the Py4j gateway using Spark's run command so that we pick up the
        # proper classpath and settings from spark-env.sh
        on_windows = platform.system() == "Windows"
        script = "./bin/spark-submit.cmd" if on_windows else "./bin/spark-submit"
        command = [os.path.join(SPARK_HOME, script)]
        if conf:
            for k, v in conf.getAll():
                command += ['--conf', '%s=%s' % (k, v)]
        submit_args = os.environ.get("PYSPARK_SUBMIT_ARGS", "pyspark-shell")
        if os.environ.get("SPARK_TESTING"):
            submit_args = ' '.join([
                "--conf spark.ui.enabled=false",
                submit_args
            ])
        command = command + shlex.split(submit_args)

        # Create a temporary directory where the gateway server should write the connection
        # information.
        conn_info_dir = tempfile.mkdtemp()
        try:
            # mkstemp is used only to reserve a unique name; the file itself is removed so
            # that its later appearance signals that the gateway has written it.
            fd, conn_info_file = tempfile.mkstemp(dir=conn_info_dir)
            os.close(fd)
            os.unlink(conn_info_file)

            env = dict(os.environ)
            env["_PYSPARK_DRIVER_CONN_INFO_PATH"] = conn_info_file

            # Launch the Java gateway.
            # We open a pipe to stdin so that the Java gateway can die when the pipe is broken
            if not on_windows:
                # Don't send ctrl-c / SIGINT to the Java gateway:
                def preexec_func():
                    signal.signal(signal.SIGINT, signal.SIG_IGN)
                proc = Popen(command, stdin=PIPE, preexec_fn=preexec_func, env=env)
            else:
                # preexec_fn not supported on Windows
                proc = Popen(command, stdin=PIPE, env=env)

            # Wait for the file to appear, or for the process to exit, whichever happens first.
            # poll() returns None only while the child is running; an exit status of 0 is
            # falsy, so testing truthiness would keep waiting on a cleanly exited child.
            while proc.poll() is None and not os.path.isfile(conn_info_file):
                time.sleep(0.1)

            if not os.path.isfile(conn_info_file):
                raise Exception("Java gateway process exited before sending its port number")

            with open(conn_info_file, "rb") as info:
                gateway_port = read_int(info)
                gateway_secret = UTF8Deserializer().loads(info)
        finally:
            shutil.rmtree(conn_info_dir)

        # In Windows, ensure the Java child processes do not linger after Python has exited.
        # In UNIX-based systems, the child process can kill itself on broken pipe (i.e. when
        # the parent process' stdin sends an EOF). In Windows, however, this is not possible
        # because java.lang.Process reads directly from the parent process' stdin, contending
        # with any opportunity to read an EOF from the parent. Note that this is only best
        # effort and will not take effect if the python process is violently terminated.
        if on_windows:
            # In Windows, the child process here is "spark-submit.cmd", not the JVM itself
            # (because the UNIX "exec" command is not available). This means we cannot simply
            # call proc.kill(), which kills only the "spark-submit.cmd" process but not the
            # JVMs. Instead, we use "taskkill" with the tree-kill option "/t" to terminate all
            # child processes in the tree (http://technet.microsoft.com/en-us/library/bb491009.aspx)
            def killChild():
                Popen(["cmd", "/c", "taskkill", "/f", "/t", "/pid", str(proc.pid)])
            atexit.register(killChild)

    # Connect to the gateway
    gateway = JavaGateway(
        gateway_parameters=GatewayParameters(port=gateway_port, auth_token=gateway_secret,
                                             auto_convert=True))

    # Import the classes used by PySpark
    java_import(gateway.jvm, "org.apache.spark.SparkConf")
    java_import(gateway.jvm, "org.apache.spark.api.java.*")
    java_import(gateway.jvm, "org.apache.spark.api.python.*")
    java_import(gateway.jvm, "org.apache.spark.ml.python.*")
    java_import(gateway.jvm, "org.apache.spark.mllib.api.python.*")
    # TODO(davies): move into sql
    java_import(gateway.jvm, "org.apache.spark.sql.*")
    java_import(gateway.jvm, "org.apache.spark.sql.api.python.*")
    java_import(gateway.jvm, "org.apache.spark.sql.hive.*")
    java_import(gateway.jvm, "scala.Tuple2")

    return gateway
def main(infile, outfile):
    """Run a single task: read its setup and command from infile, write results to outfile.

    Reads, in order: the split index, the SparkFiles root directory, the list
    of Python includes, the broadcast variable updates (each value is read
    from a size-limited slice of infile) and the pickled command. The function
    is then run over the deserialized input stream, under cProfile when the
    command carries a stats object, and timing, spill and accumulator data are
    written after the results.

    :param infile: binary stream the task description and input are read from
    :param outfile: binary stream results and bookkeeping are written to
    :raises SystemExit: with status -1 when the split index is -1, when any
        exception is raised while setting up or running the task, or when the
        value read after the task is not END_OF_STREAM
    """
    try:
        boot_time = time.time()
        split_index = read_int(infile)
        if split_index == -1:  # for unit tests
            sys.exit(-1)

        # initialize global state
        shuffle.MemoryBytesSpilled = 0
        shuffle.DiskBytesSpilled = 0
        _accumulatorRegistry.clear()

        # fetch name of workdir
        spark_files_dir = utf8_deserializer.loads(infile)
        SparkFiles._root_directory = spark_files_dir
        SparkFiles._is_running_on_worker = True

        # fetch names of includes (*.zip and *.egg files) and construct PYTHONPATH
        add_path(spark_files_dir)  # *.py files that were added will be copied here
        num_python_includes = read_int(infile)
        for _ in range(num_python_includes):
            filename = utf8_deserializer.loads(infile)
            add_path(os.path.join(spark_files_dir, filename))

        # fetch names and values of broadcast variables
        num_broadcast_variables = read_int(infile)
        bser = LargeObjectSerializer()
        for _ in range(num_broadcast_variables):
            bid = read_long(infile)
            if bid >= 0:
                size = read_long(infile)
                s = SizeLimitedStream(infile, size)
                value = list((bser.load_stream(s)))[0]  # read out all the bytes
                _broadcastRegistry[bid] = Broadcast(bid, value)
            else:
                # a negative id encodes removal of broadcast (-bid - 1)
                bid = - bid - 1
                _broadcastRegistry.pop(bid)

        _accumulatorRegistry.clear()
        command = pickleSer._read_with_length(infile)
        if isinstance(command, Broadcast):
            command = pickleSer.loads(command.value)
        (func, stats, deserializer, serializer) = command
        init_time = time.time()

        def process():
            iterator = deserializer.load_stream(infile)
            serializer.dump_stream(func(split_index, iterator), outfile)

        if stats:
            p = cProfile.Profile()
            p.runcall(process)
            st = pstats.Stats(p)
            st.stream = None  # make it picklable
            stats.add(st.strip_dirs())
        else:
            process()
    except Exception:
        try:
            write_int(SpecialLengths.PYTHON_EXCEPTION_THROWN, outfile)
            write_with_length(traceback.format_exc(), outfile)
        except IOError:
            # JVM close the socket
            pass
        except Exception:
            # Write the error to stderr if it happened while serializing.
            # sys.stderr.write is used because it is valid on both Python 2 and 3.
            sys.stderr.write("PySpark worker failed with exception:" + "\n")
            sys.stderr.write(traceback.format_exc() + "\n")
        # sys.exit rather than the site-provided exit(), which is not always installed
        sys.exit(-1)
    finish_time = time.time()
    report_times(outfile, boot_time, init_time, finish_time)
    write_long(shuffle.MemoryBytesSpilled, outfile)
    write_long(shuffle.DiskBytesSpilled, outfile)

    # Mark the beginning of the accumulators section of the output
    write_int(SpecialLengths.END_OF_DATA_SECTION, outfile)
    write_int(len(_accumulatorRegistry), outfile)
    for (aid, accum) in _accumulatorRegistry.items():
        pickleSer._write_with_length((aid, accum._value), outfile)

    # check end of stream
    if read_int(infile) == SpecialLengths.END_OF_STREAM:
        write_int(SpecialLengths.END_OF_STREAM, outfile)
    else:
        # write a different value to tell JVM to not reuse this worker
        write_int(SpecialLengths.END_OF_DATA_SECTION, outfile)
        sys.exit(-1)
def loads(self, obj):
    """Decode a serialized (key, vector) pair from a bytes object.

    The payload is read as a length-prefixed UTF-8 key, and whatever follows
    is handed to _read_vec.

    :param obj: bytes holding the serialized pair
    :return: tuple of (decoded key string, result of _read_vec on the rest)
    """
    buf = BytesIO(obj)
    # the key is stored as an int length followed by that many UTF-8 bytes
    name_len = read_int(buf)
    raw_name = buf.read(name_len)
    return (raw_name.decode('utf-8'), _read_vec(buf))
def read_udfs(pickleSer, infile, eval_type):
    """Read the UDF section of ``infile`` and build the function that evaluates it.

    For the pandas eval types the runner configuration (key/value string pairs) is
    read first and used to configure an ``ArrowStreamPandasUDFSerializer``; every
    other eval type uses a batched pickle serializer.  Then the UDF count and each
    UDF are read via ``read_single_udf``.

    Returns the 4-tuple ``(func, None, ser, ser)``: ``func(split_index, iterator)``
    maps the input iterator through the UDF(s), the second element is ``None``
    because profiling is not supported for UDFs, and the same serializer object is
    returned for both remaining slots.

    Raises ``AssertionError`` if an eval type that supports only one UDF receives a
    different count.
    """
    runner_conf = {}

    if eval_type in (PythonEvalType.SQL_SCALAR_PANDAS_UDF,
                     PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF,
                     PythonEvalType.SQL_MAP_PANDAS_ITER_UDF,
                     PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
                     PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF,
                     PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF):

        # Load conf used for pandas_udf evaluation
        num_conf = read_int(infile)
        for _ in range(num_conf):
            k = utf8_deserializer.loads(infile)
            v = utf8_deserializer.loads(infile)
            runner_conf[k] = v

        # NOTE: if timezone is set here, that implies respectSessionTimeZone is True
        timezone = runner_conf.get("spark.sql.session.timeZone", None)
        safecheck = runner_conf.get(
            "spark.sql.execution.pandas.arrowSafeTypeConversion", "false").lower() == 'true'
        # Used by SQL_GROUPED_MAP_PANDAS_UDF and SQL_SCALAR_PANDAS_UDF when returning StructType
        assign_cols_by_name = runner_conf.get(
            "spark.sql.legacy.execution.pandas.groupedMap.assignColumnsByName", "true")\
            .lower() == "true"

        # Scalar Pandas UDF handles struct type arguments as pandas DataFrames instead of
        # pandas Series. See SPARK-27240.
        df_for_struct = (
            eval_type == PythonEvalType.SQL_SCALAR_PANDAS_UDF or
            eval_type == PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF or
            eval_type == PythonEvalType.SQL_MAP_PANDAS_ITER_UDF)
        ser = ArrowStreamPandasUDFSerializer(timezone, safecheck, assign_cols_by_name,
                                             df_for_struct)
    else:
        ser = BatchedSerializer(PickleSerializer(), 100)

    num_udfs = read_int(infile)

    is_scalar_iter = eval_type == PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF
    is_map_iter = eval_type == PythonEvalType.SQL_MAP_PANDAS_ITER_UDF

    if is_scalar_iter or is_map_iter:
        if is_scalar_iter:
            assert num_udfs == 1, "One SCALAR_ITER UDF expected here."
        if is_map_iter:
            assert num_udfs == 1, "One MAP_ITER UDF expected here."

        arg_offsets, udf = read_single_udf(pickleSer, infile, eval_type, runner_conf,
                                           udf_index=0)

        def func(_, iterator):
            # one-element list so the nested function can update the counter
            num_input_rows = [0]

            def map_batch(batch):
                udf_args = [batch[offset] for offset in arg_offsets]
                num_input_rows[0] += len(udf_args[0])
                if len(udf_args) == 1:
                    return udf_args[0]
                else:
                    return tuple(udf_args)

            # lazily select the UDF's argument columns; the input row counter only
            # advances as the UDF pulls batches from this iterator
            iterator = map(map_batch, iterator)
            result_iter = udf(iterator)

            num_output_rows = 0
            for result_batch, result_type in result_iter:
                num_output_rows += len(result_batch)
                # Fail-fast check that applies to the scalar iterator UDF only
                # (map-iter UDFs may change the row count), so the message names
                # SCALAR_ITER rather than MAP_ITER.
                assert is_map_iter or num_output_rows <= num_input_rows[0], \
                    "Pandas SCALAR_ITER UDF outputted more rows than input rows."
                yield (result_batch, result_type)

            if is_scalar_iter:
                try:
                    next(iterator)
                except StopIteration:
                    pass
                else:
                    raise RuntimeError(
                        "SQL_SCALAR_PANDAS_ITER_UDF should exhaust the input "
                        "iterator.")

            if is_scalar_iter and num_output_rows != num_input_rows[0]:
                raise RuntimeError(
                    "The number of output rows of pandas iterator UDF should be "
                    "the same with input rows. The input rows number is %d but the "
                    "output rows number is %d." % (num_input_rows[0], num_output_rows))

        # profiling is not supported for UDF
        return func, None, ser, ser

    udfs = {}
    call_udf = []
    mapper_str = ""
    if eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF:
        # Create function like this:
        #   lambda a: f([a[0]], [a[0], a[1]])

        # We assume there is only one UDF here because grouped map doesn't
        # support combining multiple UDFs.
        assert num_udfs == 1

        # See FlatMapGroupsInPandasExec for how arg_offsets are used to
        # distinguish between grouping attributes and data attributes
        arg_offsets, udf = read_single_udf(pickleSer, infile, eval_type, runner_conf,
                                           udf_index=0)
        udfs['f'] = udf
        # arg_offsets[0] is the number of grouping offsets that follow it
        split_offset = arg_offsets[0] + 1
        arg0 = ["a[%d]" % o for o in arg_offsets[1:split_offset]]
        arg1 = ["a[%d]" % o for o in arg_offsets[split_offset:]]
        mapper_str = "lambda a: f([%s], [%s])" % (", ".join(arg0), ", ".join(arg1))
    else:
        # Create function like this:
        #   lambda a: (f0(a[0]), f1(a[1], a[2]), f2(a[3]))
        # In the special case of a single UDF this will return a single result rather
        # than a tuple of results; this is the format that the JVM side expects.
        for i in range(num_udfs):
            arg_offsets, udf = read_single_udf(pickleSer, infile, eval_type, runner_conf,
                                               udf_index=i)
            udfs['f%d' % i] = udf
            args = ["a[%d]" % o for o in arg_offsets]
            call_udf.append("f%d(%s)" % (i, ", ".join(args)))
        mapper_str = "lambda a: (%s)" % (", ".join(call_udf))

    # NOTE: the evaluated source is assembled only from "%d"-formatted offsets and
    # the function names bound in ``udfs``; no free-form text from the stream is
    # interpolated into it.
    mapper = eval(mapper_str, udfs)
    func = lambda _, it: map(mapper, it)

    # profiling is not supported for UDF
    return func, None, ser, ser
def manager():
    """Run the daemon loop that forks one worker process per accepted connection.

    A listening socket is bound to an ephemeral port on 127.0.0.1 and the port
    number is written to stdout, which is then closed.  The loop waits on stdin
    and the listening socket: an integer read from stdin is treated as the pid of
    a worker to kill (EOF on stdin shuts the daemon down with status 0), and each
    accepted connection is served by a forked child running ``worker(sock)``.

    The function never returns normally; it always ends in ``shutdown``, which
    sends SIGHUP to the process group and exits.
    """
    # Create a new process group to corral our children
    os.setpgid(0, 0)

    # Create a listening socket on the AF_INET loopback interface
    listen_sock = socket.socket(AF_INET, SOCK_STREAM)
    listen_sock.bind(('127.0.0.1', 0))
    listen_sock.listen(max(1024, SOMAXCONN))
    listen_host, listen_port = listen_sock.getsockname()
    write_int(listen_port, sys.stdout)

    def shutdown(code):
        signal.signal(SIGTERM, SIG_DFL)
        # Send SIGHUP to notify workers of shutdown
        os.kill(0, SIGHUP)
        # sys.exit rather than the builtin ``exit``, which only exists when the
        # ``site`` module has been loaded
        sys.exit(code)

    def handle_sigterm(*args):
        shutdown(1)
    signal.signal(SIGTERM, handle_sigterm)  # Gracefully exit on SIGTERM
    signal.signal(SIGHUP, SIG_IGN)  # Don't die on SIGHUP

    # Cleanup zombie children
    def handle_sigchld(*args):
        try:
            pid, status = os.waitpid(0, os.WNOHANG)
            if status != 0:
                msg = "worker %s crashed abruptly with exit status %s" % (pid, status)
                sys.stderr.write(msg + "\n")
        except EnvironmentError as err:
            if err.errno not in (ECHILD, EINTR):
                raise
    signal.signal(SIGCHLD, handle_sigchld)

    # Initialization complete
    sys.stdout.close()
    try:
        while True:
            try:
                ready_fds = select.select([0, listen_sock], [], [])[0]
            except select.error as ex:
                # ``ex.args[0]`` instead of ``ex[0]``: exceptions are not
                # subscriptable on Python 3, while ``args`` holds the errno first
                # on both Python 2 and 3.
                if ex.args[0] == EINTR:
                    continue
                else:
                    raise
            if 0 in ready_fds:
                try:
                    worker_pid = read_int(sys.stdin)
                except EOFError:
                    # Spark told us to exit by closing stdin
                    shutdown(0)
                try:
                    os.kill(worker_pid, signal.SIGKILL)
                except OSError:
                    pass  # process already died

            if listen_sock in ready_fds:
                sock, addr = listen_sock.accept()
                # Launch a worker process
                try:
                    pid = os.fork()
                    if pid == 0:
                        listen_sock.close()
                        try:
                            worker(sock)
                        except:
                            # Deliberately catch everything (including SystemExit):
                            # the forked child must always leave through os._exit
                            # and never fall back into the daemon loop.
                            traceback.print_exc()
                            os._exit(1)
                        else:
                            os._exit(0)
                    else:
                        sock.close()

                except OSError as e:
                    sys.stderr.write("Daemon failed to fork PySpark worker: %s\n" % e)
                    outfile = os.fdopen(os.dup(sock.fileno()), "a+", 65536)
                    write_int(-1, outfile)  # Signal that the fork failed
                    outfile.flush()
                    outfile.close()
                    sock.close()
    finally:
        shutdown(1)
def launch_gateway(conf=None):
    """
    Launch the JVM gateway (or connect to an existing one) and return it.

    If the ``PYSPARK_GATEWAY_PORT`` environment variable is set, no process is
    started and the gateway is reached on that port.  Otherwise ``spark-submit`` is
    started as a child process and reports its gateway port back over a temporary
    loopback callback socket.

    :param conf: spark configuration passed to spark-submit; every ``(k, v)`` pair
        from ``conf.getAll()`` is forwarded as ``--conf k=v``
    :return: a ``JavaGateway`` with the classes used by PySpark imported
    :raises Exception: if the child process exits before sending its port number
    """
    if "PYSPARK_GATEWAY_PORT" in os.environ:
        gateway_port = int(os.environ["PYSPARK_GATEWAY_PORT"])
    else:
        SPARK_HOME = _find_spark_home()
        # Launch the Py4j gateway using Spark's run command so that we pick up the
        # proper classpath and settings from spark-env.sh
        on_windows = platform.system() == "Windows"
        script = "./bin/spark-submit.cmd" if on_windows else "./bin/spark-submit"
        command = [os.path.join(SPARK_HOME, script)]
        if conf:
            for k, v in conf.getAll():
                command += ['--conf', '%s=%s' % (k, v)]
        submit_args = os.environ.get("PYSPARK_SUBMIT_ARGS", "pyspark-shell")
        if os.environ.get("SPARK_TESTING"):
            submit_args = ' '.join(["--conf spark.ui.enabled=false", submit_args])
        command = command + shlex.split(submit_args)

        # Start a socket that will be used by PythonGatewayServer to communicate its port to us
        callback_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        # try/finally so the callback socket is released even if launching the
        # child or reading the port raises
        try:
            callback_socket.bind(('127.0.0.1', 0))
            callback_socket.listen(1)
            callback_host, callback_port = callback_socket.getsockname()
            env = dict(os.environ)
            env['_PYSPARK_DRIVER_CALLBACK_HOST'] = callback_host
            env['_PYSPARK_DRIVER_CALLBACK_PORT'] = str(callback_port)

            # Launch the Java gateway.
            # We open a pipe to stdin so that the Java gateway can die when the pipe is broken
            if not on_windows:
                # Don't send ctrl-c / SIGINT to the Java gateway:
                def preexec_func():
                    signal.signal(signal.SIGINT, signal.SIG_IGN)
                proc = Popen(command, stdin=PIPE, preexec_fn=preexec_func, env=env)
            else:
                # preexec_fn not supported on Windows
                proc = Popen(command, stdin=PIPE, env=env)

            gateway_port = None
            # We use select() here in order to avoid blocking indefinitely if the subprocess dies
            # before connecting
            while gateway_port is None and proc.poll() is None:
                timeout = 1  # (seconds)
                readable, _, _ = select.select([callback_socket], [], [], timeout)
                if callback_socket in readable:
                    gateway_connection = callback_socket.accept()[0]
                    try:
                        # Determine which ephemeral port the server started on:
                        port_file = gateway_connection.makefile(mode="rb")
                        try:
                            gateway_port = read_int(port_file)
                        finally:
                            port_file.close()
                    finally:
                        gateway_connection.close()
        finally:
            callback_socket.close()
        if gateway_port is None:
            raise Exception(
                "Java gateway process exited before sending the driver its port number"
            )

        # In Windows, ensure the Java child processes do not linger after Python has exited.
        # In UNIX-based systems, the child process can kill itself on broken pipe (i.e. when
        # the parent process' stdin sends an EOF). In Windows, however, this is not possible
        # because java.lang.Process reads directly from the parent process' stdin, contending
        # with any opportunity to read an EOF from the parent. Note that this is only best
        # effort and will not take effect if the python process is violently terminated.
        if on_windows:
            # In Windows, the child process here is "spark-submit.cmd", not the JVM itself
            # (because the UNIX "exec" command is not available). This means we cannot simply
            # call proc.kill(), which kills only the "spark-submit.cmd" process but not the
            # JVMs. Instead, we use "taskkill" with the tree-kill option "/t" to terminate all
            # child processes in the tree (http://technet.microsoft.com/en-us/library/bb491009.aspx)
            def killChild():
                Popen([
                    "cmd", "/c", "taskkill", "/f", "/t", "/pid", str(proc.pid)
                ])
            atexit.register(killChild)

    # Connect to the gateway
    gateway = JavaGateway(GatewayClient(port=gateway_port), auto_convert=True)

    # Import the classes used by PySpark
    java_import(gateway.jvm, "org.apache.spark.SparkConf")
    java_import(gateway.jvm, "org.apache.spark.api.java.*")
    java_import(gateway.jvm, "org.apache.spark.api.python.*")
    java_import(gateway.jvm, "org.apache.spark.ml.python.*")
    java_import(gateway.jvm, "org.apache.spark.mllib.api.python.*")
    # TODO(davies): move into sql
    java_import(gateway.jvm, "org.apache.spark.sql.*")
    java_import(gateway.jvm, "org.apache.spark.sql.hive.*")
    java_import(gateway.jvm, "scala.Tuple2")

    return gateway
def main(infile, outfile):
    """Run a single task: read its setup and command from ``infile``, execute it,
    and report the outcome on ``outfile``.

    ``infile`` is consumed in a fixed order: split index, driver Python version,
    barrier flag/port/secret, task context fields, resources, local properties,
    files directory, Python includes, broadcast variables, the eval type, and
    finally the command or UDFs.  After the task has run, timing information,
    spill counters and accumulator values are written to ``outfile``.

    If ``PYTHON_FAULTHANDLER_DIR`` is set, faulthandler output is directed to a
    per-pid file in that directory for the duration of the task; the file is
    removed afterwards.

    The process exits with status -1 when any ``BaseException`` escapes the task
    (after the traceback has been sent to ``outfile`` where possible) or when the
    end-of-stream marker is not the next value on ``infile``.
    """
    faulthandler_log_path = os.environ.get("PYTHON_FAULTHANDLER_DIR", None)
    # Bound up front so the ``finally`` clause below never hits a NameError when
    # opening the log file itself fails.
    faulthandler_log_file = None
    try:
        if faulthandler_log_path:
            faulthandler_log_path = os.path.join(faulthandler_log_path, str(os.getpid()))
            faulthandler_log_file = open(faulthandler_log_path, "w")
            faulthandler.enable(file=faulthandler_log_file)

        boot_time = time.time()
        split_index = read_int(infile)
        if split_index == -1:  # for unit tests
            sys.exit(-1)

        version = utf8_deserializer.loads(infile)
        if version != "%d.%d" % sys.version_info[:2]:
            raise RuntimeError(
                (
                    "Python in worker has different version %s than that in "
                    + "driver %s, PySpark cannot run with different minor versions. "
                    + "Please check environment variables PYSPARK_PYTHON and "
                    + "PYSPARK_DRIVER_PYTHON are correctly set."
                )
                % ("%d.%d" % sys.version_info[:2], version)
            )

        # read inputs only for a barrier task
        isBarrier = read_bool(infile)
        boundPort = read_int(infile)
        secret = UTF8Deserializer().loads(infile)

        # set up memory limits
        memory_limit_mb = int(os.environ.get("PYSPARK_EXECUTOR_MEMORY_MB", "-1"))
        if memory_limit_mb > 0 and has_resource_module:
            total_memory = resource.RLIMIT_AS
            try:
                (soft_limit, hard_limit) = resource.getrlimit(total_memory)
                msg = "Current mem limits: {0} of max {1}\n".format(soft_limit, hard_limit)
                print(msg, file=sys.stderr)

                # convert to bytes
                new_limit = memory_limit_mb * 1024 * 1024

                # only ever lower the limit, never raise an existing one
                if soft_limit == resource.RLIM_INFINITY or new_limit < soft_limit:
                    msg = "Setting mem limits to {0} of max {1}\n".format(new_limit, new_limit)
                    print(msg, file=sys.stderr)
                    resource.setrlimit(total_memory, (new_limit, new_limit))

            except (resource.error, OSError, ValueError) as e:
                # not all systems support resource limits, so warn instead of failing
                lineno = (
                    getframeinfo(currentframe()).lineno + 1 if currentframe() is not None else 0
                )
                print(
                    warnings.formatwarning(
                        "Failed to set memory limit: {0}".format(e),
                        ResourceWarning,
                        __file__,
                        lineno,
                    ),
                    file=sys.stderr,
                )

        # initialize global state
        taskContext = None
        if isBarrier:
            taskContext = BarrierTaskContext._getOrCreate()
            BarrierTaskContext._initialize(boundPort, secret)
            # Set the task context instance here, so we can get it by TaskContext.get for
            # both TaskContext and BarrierTaskContext
            TaskContext._setTaskContext(taskContext)
        else:
            taskContext = TaskContext._getOrCreate()
        # read inputs for TaskContext info
        taskContext._stageId = read_int(infile)
        taskContext._partitionId = read_int(infile)
        taskContext._attemptNumber = read_int(infile)
        taskContext._taskAttemptId = read_long(infile)
        taskContext._cpus = read_int(infile)
        # The mapping is created once, before the loop.  Re-creating it for every
        # resource would discard all but the last resource read.
        taskContext._resources = {}
        for r in range(read_int(infile)):
            key = utf8_deserializer.loads(infile)
            name = utf8_deserializer.loads(infile)
            addresses = []
            for a in range(read_int(infile)):
                addresses.append(utf8_deserializer.loads(infile))
            taskContext._resources[key] = ResourceInformation(name, addresses)

        taskContext._localProperties = dict()
        for i in range(read_int(infile)):
            k = utf8_deserializer.loads(infile)
            v = utf8_deserializer.loads(infile)
            taskContext._localProperties[k] = v

        shuffle.MemoryBytesSpilled = 0
        shuffle.DiskBytesSpilled = 0
        _accumulatorRegistry.clear()

        # fetch name of workdir
        spark_files_dir = utf8_deserializer.loads(infile)
        SparkFiles._root_directory = spark_files_dir
        SparkFiles._is_running_on_worker = True

        # fetch names of includes (*.zip and *.egg files) and construct PYTHONPATH
        add_path(spark_files_dir)  # *.py files that were added will be copied here
        num_python_includes = read_int(infile)
        for _ in range(num_python_includes):
            filename = utf8_deserializer.loads(infile)
            add_path(os.path.join(spark_files_dir, filename))

        importlib.invalidate_caches()

        # fetch names and values of broadcast variables
        needs_broadcast_decryption_server = read_bool(infile)
        num_broadcast_variables = read_int(infile)
        if needs_broadcast_decryption_server:
            # read the decrypted data from a server in the jvm
            port = read_int(infile)
            auth_secret = utf8_deserializer.loads(infile)
            (broadcast_sock_file, _) = local_connect_and_auth(port, auth_secret)

        for _ in range(num_broadcast_variables):
            bid = read_long(infile)
            if bid >= 0:
                if needs_broadcast_decryption_server:
                    read_bid = read_long(broadcast_sock_file)
                    assert read_bid == bid
                    _broadcastRegistry[bid] = Broadcast(sock_file=broadcast_sock_file)
                else:
                    path = utf8_deserializer.loads(infile)
                    _broadcastRegistry[bid] = Broadcast(path=path)

            else:
                # a negative id encodes "remove broadcast (-id - 1)"
                bid = -bid - 1
                _broadcastRegistry.pop(bid)

        if needs_broadcast_decryption_server:
            broadcast_sock_file.write(b"1")
            broadcast_sock_file.close()

        _accumulatorRegistry.clear()
        eval_type = read_int(infile)
        if eval_type == PythonEvalType.NON_UDF:
            func, profiler, deserializer, serializer = read_command(pickleSer, infile)
        else:
            func, profiler, deserializer, serializer = read_udfs(pickleSer, infile, eval_type)

        init_time = time.time()

        def process():
            iterator = deserializer.load_stream(infile)
            out_iter = func(split_index, iterator)
            try:
                serializer.dump_stream(out_iter, outfile)
            finally:
                # close the output iterator (if it supports it) even when
                # dumping fails part-way
                if hasattr(out_iter, "close"):
                    out_iter.close()

        if profiler:
            profiler.profile(process)
        else:
            process()

        # Reset task context to None. This is a guard code to avoid residual context when worker
        # reuse.
        TaskContext._setTaskContext(None)
        BarrierTaskContext._setTaskContext(None)
    except BaseException as e:
        try:
            exc_info = None
            if os.environ.get("SPARK_SIMPLIFIED_TRACEBACK", False):
                tb = try_simplify_traceback(sys.exc_info()[-1])
                if tb is not None:
                    e.__cause__ = None
                    exc_info = "".join(traceback.format_exception(type(e), e, tb))
            if exc_info is None:
                exc_info = traceback.format_exc()

            write_int(SpecialLengths.PYTHON_EXCEPTION_THROWN, outfile)
            write_with_length(exc_info.encode("utf-8"), outfile)
        except IOError:
            # JVM close the socket
            pass
        except BaseException:
            # Write the error to stderr if it happened while serializing
            print("PySpark worker failed with exception:", file=sys.stderr)
            print(traceback.format_exc(), file=sys.stderr)
        sys.exit(-1)
    finally:
        # Only clean up what was actually set up: the file object is non-None
        # exactly when the log file was opened successfully.
        if faulthandler_log_file is not None:
            faulthandler.disable()
            faulthandler_log_file.close()
            os.remove(faulthandler_log_path)
    finish_time = time.time()
    report_times(outfile, boot_time, init_time, finish_time)
    write_long(shuffle.MemoryBytesSpilled, outfile)
    write_long(shuffle.DiskBytesSpilled, outfile)

    # Mark the beginning of the accumulators section of the output
    write_int(SpecialLengths.END_OF_DATA_SECTION, outfile)
    write_int(len(_accumulatorRegistry), outfile)
    for (aid, accum) in _accumulatorRegistry.items():
        pickleSer._write_with_length((aid, accum._value), outfile)

    # check end of stream
    if read_int(infile) == SpecialLengths.END_OF_STREAM:
        write_int(SpecialLengths.END_OF_STREAM, outfile)
    else:
        # write a different value to tell JVM to not reuse this worker
        write_int(SpecialLengths.END_OF_DATA_SECTION, outfile)
        sys.exit(-1)
def read_udfs(pickleSer, infile, eval_type):
    """Read the UDF section of ``infile`` and build the function that evaluates it.

    For the pandas eval types the runner configuration (key/value string pairs) is
    read first and used to configure an Arrow-based serializer
    (``CogroupUDFSerializer`` for cogrouped map, ``ArrowStreamPandasUDFSerializer``
    otherwise); every other eval type uses a batched pickle serializer.  Then the
    UDF count and each UDF are read via ``read_single_udf``.

    Returns the 4-tuple ``(func, None, ser, ser)``: ``func(split_index, iterator)``
    maps the input iterator through the UDF(s), the second element is ``None``
    because profiling is not supported for UDFs, and the same serializer object is
    returned for both remaining slots.
    """
    runner_conf = {}

    if eval_type in (PythonEvalType.SQL_SCALAR_PANDAS_UDF,
                     PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF,
                     PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF,
                     PythonEvalType.SQL_MAP_PANDAS_ITER_UDF,
                     PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
                     PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF,
                     PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF):

        # Load conf used for pandas_udf evaluation
        num_conf = read_int(infile)
        for i in range(num_conf):
            k = utf8_deserializer.loads(infile)
            v = utf8_deserializer.loads(infile)
            runner_conf[k] = v

        # NOTE: if timezone is set here, that implies respectSessionTimeZone is True
        timezone = runner_conf.get("spark.sql.session.timeZone", None)
        safecheck = runner_conf.get(
            "spark.sql.execution.pandas.convertToArrowArraySafely", "false").lower() == 'true'
        # Used by SQL_GROUPED_MAP_PANDAS_UDF and SQL_SCALAR_PANDAS_UDF when returning StructType
        assign_cols_by_name = runner_conf.get(
            "spark.sql.legacy.execution.pandas.groupedMap.assignColumnsByName", "true")\
            .lower() == "true"

        if eval_type == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF:
            ser = CogroupUDFSerializer(timezone, safecheck, assign_cols_by_name)
        else:
            # Scalar Pandas UDF handles struct type arguments as pandas DataFrames instead of
            # pandas Series. See SPARK-27240.
            df_for_struct = (
                eval_type == PythonEvalType.SQL_SCALAR_PANDAS_UDF or
                eval_type == PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF or
                eval_type == PythonEvalType.SQL_MAP_PANDAS_ITER_UDF)
            ser = ArrowStreamPandasUDFSerializer(timezone, safecheck, assign_cols_by_name,
                                                 df_for_struct)
    else:
        ser = BatchedSerializer(PickleSerializer(), 100)

    num_udfs = read_int(infile)

    is_scalar_iter = eval_type == PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF
    is_map_iter = eval_type == PythonEvalType.SQL_MAP_PANDAS_ITER_UDF

    if is_scalar_iter or is_map_iter:
        if is_scalar_iter:
            assert num_udfs == 1, "One SCALAR_ITER UDF expected here."
        if is_map_iter:
            assert num_udfs == 1, "One MAP_ITER UDF expected here."

        arg_offsets, udf = read_single_udf(pickleSer, infile, eval_type, runner_conf,
                                           udf_index=0)

        def func(_, iterator):
            # Generator that feeds the batches to the iterator UDF and yields
            # ``(result_batch, result_type)`` pairs while tracking row counts.
            num_input_rows = 0

            def map_batch(batch):
                # Select the UDF's argument columns from one input batch and
                # add the batch's length to the input row counter.
                nonlocal num_input_rows

                udf_args = [batch[offset] for offset in arg_offsets]
                num_input_rows += len(udf_args[0])
                if len(udf_args) == 1:
                    return udf_args[0]
                else:
                    return tuple(udf_args)

            # ``map`` is lazy: the input counter only advances as the UDF pulls
            # batches from this iterator.
            iterator = map(map_batch, iterator)
            result_iter = udf(iterator)

            num_output_rows = 0
            for result_batch, result_type in result_iter:
                num_output_rows += len(result_batch)
                # This assert is for Scalar Iterator UDF to fail fast.
                # The length of the entire input can only be explicitly known
                # by consuming the input iterator in user side. Therefore,
                # it's very unlikely the output length is higher than
                # input length.
                assert is_map_iter or num_output_rows <= num_input_rows, \
                    "Pandas SCALAR_ITER UDF outputted more rows than input rows."
                yield (result_batch, result_type)

            if is_scalar_iter:
                # After the UDF has finished, the input iterator must be
                # exhausted and the total row counts must match.
                try:
                    next(iterator)
                except StopIteration:
                    pass
                else:
                    raise RuntimeError(
                        "pandas iterator UDF should exhaust the input "
                        "iterator.")

                if num_output_rows != num_input_rows:
                    raise RuntimeError(
                        "The length of output in Scalar iterator pandas UDF should be "
                        "the same with the input's; however, the length of output was %d and the "
                        "length of input was %d." % (num_output_rows, num_input_rows))

        # profiling is not supported for UDF
        return func, None, ser, ser

    def extract_key_value_indexes(grouped_arg_offsets):
        """
        Helper function to extract the key and value indexes from arg_offsets for the grouped and
        cogrouped pandas udfs. See BasePandasGroupExec.resolveArgOffsets for equivalent scala code.

        Parameters
        ----------
        grouped_arg_offsets:  list
            List containing the key and value indexes of columns of the
            DataFrames to be passed to the udf. It consists of n repeating groups where n is the
            number of DataFrames.  Each group has the following format:
                group[0]: number of entries that follow in this group
                group[1]: number of key indexes
                group[2 .. group[1] + 1]: key indexes
                group[group[1] + 2 .. group[0]]: value indexes

        Returns
        -------
        list
            One ``[key_indexes, value_indexes]`` pair per group, in input order.
        """
        parsed = []
        idx = 0
        while idx < len(grouped_arg_offsets):
            offsets_len = grouped_arg_offsets[idx]
            idx += 1
            # ``offsets`` is the group without its leading length entry
            offsets = grouped_arg_offsets[idx:idx + offsets_len]
            split_index = offsets[0] + 1
            offset_keys = offsets[1:split_index]
            offset_values = offsets[split_index:]
            parsed.append([offset_keys, offset_values])
            idx += offsets_len
        return parsed

    if eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF:
        # We assume there is only one UDF here because grouped map doesn't
        # support combining multiple UDFs.
        assert num_udfs == 1

        # See FlatMapGroupsInPandasExec for how arg_offsets are used to
        # distinguish between grouping attributes and data attributes
        arg_offsets, f = read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index=0)
        parsed_offsets = extract_key_value_indexes(arg_offsets)

        # Create function like this:
        #   mapper a: f([a[0]], [a[0], a[1]])
        def mapper(a):
            keys = [a[o] for o in parsed_offsets[0][0]]
            vals = [a[o] for o in parsed_offsets[0][1]]
            return f(keys, vals)
    elif eval_type == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF:
        # We assume there is only one UDF here because cogrouped map doesn't
        # support combining multiple UDFs.
        assert num_udfs == 1
        arg_offsets, f = read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index=0)

        parsed_offsets = extract_key_value_indexes(arg_offsets)

        # ``a`` holds two inputs; the first parsed group indexes into a[0] and
        # the second into a[1].
        def mapper(a):
            df1_keys = [a[0][o] for o in parsed_offsets[0][0]]
            df1_vals = [a[0][o] for o in parsed_offsets[0][1]]
            df2_keys = [a[1][o] for o in parsed_offsets[1][0]]
            df2_vals = [a[1][o] for o in parsed_offsets[1][1]]
            return f(df1_keys, df1_vals, df2_keys, df2_vals)
    else:
        udfs = []
        for i in range(num_udfs):
            udfs.append(
                read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index=i))

        def mapper(a):
            # call every UDF with the columns selected by its own arg offsets
            result = tuple(
                f(*[a[o] for o in arg_offsets]) for (arg_offsets, f) in udfs)
            # In the special case of a single UDF this will return a single result rather
            # than a tuple of results; this is the format that the JVM side expects.
            if len(result) == 1:
                return result[0]
            else:
                return result

    func = lambda _, it: map(mapper, it)

    # profiling is not supported for UDF
    return func, None, ser, ser
def launch_gateway():
    """Launch the JVM gateway (or connect to an existing one) and return it.

    ``SPARK_HOME`` must be set in the environment.  If ``PYSPARK_GATEWAY_PORT`` is
    set, no process is started and the gateway is reached on that port.  Otherwise
    ``spark-submit`` is started as a child process and reports its gateway port
    back over a temporary loopback callback socket.

    Returns a ``JavaGateway`` with the classes used by PySpark imported.
    Raises ``KeyError`` if ``SPARK_HOME`` is not set, and ``Exception`` if the
    child process exits before sending its port number.
    """
    SPARK_HOME = os.environ["SPARK_HOME"]

    if "PYSPARK_GATEWAY_PORT" in os.environ:
        gateway_port = int(os.environ["PYSPARK_GATEWAY_PORT"])
    else:
        # Launch the Py4j gateway using Spark's run command so that we pick up the
        # proper classpath and settings from spark-env.sh
        on_windows = platform.system() == "Windows"
        script = "./bin/spark-submit.cmd" if on_windows else "./bin/spark-submit"
        submit_args = os.environ.get("PYSPARK_SUBMIT_ARGS", "pyspark-shell")
        command = [os.path.join(SPARK_HOME, script)] + shlex.split(submit_args)

        # Start a socket that will be used by PythonGatewayServer to communicate its port to us
        callback_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        # try/finally so the callback socket is released even if launching the
        # child or reading the port raises
        try:
            callback_socket.bind(('127.0.0.1', 0))
            callback_socket.listen(1)
            callback_host, callback_port = callback_socket.getsockname()
            env = dict(os.environ)
            env['_PYSPARK_DRIVER_CALLBACK_HOST'] = callback_host
            env['_PYSPARK_DRIVER_CALLBACK_PORT'] = str(callback_port)

            # Launch the Java gateway.
            # We open a pipe to stdin so that the Java gateway can die when the pipe is broken
            if not on_windows:
                # Don't send ctrl-c / SIGINT to the Java gateway:
                def preexec_func():
                    signal.signal(signal.SIGINT, signal.SIG_IGN)
                proc = Popen(command, stdin=PIPE, preexec_fn=preexec_func, env=env)
            else:
                # preexec_fn not supported on Windows
                proc = Popen(command, stdin=PIPE, env=env)

            gateway_port = None
            # We use select() here in order to avoid blocking indefinitely if the subprocess dies
            # before connecting
            while gateway_port is None and proc.poll() is None:
                timeout = 1  # (seconds)
                readable, _, _ = select.select([callback_socket], [], [], timeout)
                if callback_socket in readable:
                    gateway_connection = callback_socket.accept()[0]
                    try:
                        # Determine which ephemeral port the server started on:
                        port_file = gateway_connection.makefile()
                        try:
                            gateway_port = read_int(port_file)
                        finally:
                            port_file.close()
                    finally:
                        gateway_connection.close()
        finally:
            callback_socket.close()
        if gateway_port is None:
            raise Exception("Java gateway process exited before sending the driver its port number")

        # In Windows, ensure the Java child processes do not linger after Python has exited.
        # In UNIX-based systems, the child process can kill itself on broken pipe (i.e. when
        # the parent process' stdin sends an EOF). In Windows, however, this is not possible
        # because java.lang.Process reads directly from the parent process' stdin, contending
        # with any opportunity to read an EOF from the parent. Note that this is only best
        # effort and will not take effect if the python process is violently terminated.
        if on_windows:
            # In Windows, the child process here is "spark-submit.cmd", not the JVM itself
            # (because the UNIX "exec" command is not available). This means we cannot simply
            # call proc.kill(), which kills only the "spark-submit.cmd" process but not the
            # JVMs. Instead, we use "taskkill" with the tree-kill option "/t" to terminate all
            # child processes in the tree (http://technet.microsoft.com/en-us/library/bb491009.aspx)
            def killChild():
                Popen(["cmd", "/c", "taskkill", "/f", "/t", "/pid", str(proc.pid)])
            atexit.register(killChild)

    # Connect to the gateway
    gateway = JavaGateway(GatewayClient(port=gateway_port), auto_convert=False)

    # Import the classes used by PySpark
    java_import(gateway.jvm, "org.apache.spark.SparkConf")
    java_import(gateway.jvm, "org.apache.spark.api.java.*")
    java_import(gateway.jvm, "org.apache.spark.api.python.*")
    java_import(gateway.jvm, "org.apache.spark.mllib.api.python.*")
    # TODO(davies): move into sql
    java_import(gateway.jvm, "org.apache.spark.sql.*")
    java_import(gateway.jvm, "org.apache.spark.sql.hive.*")
    java_import(gateway.jvm, "scala.Tuple2")

    return gateway
def manager():
    """Run the daemon loop that forks one worker process per accepted connection.

    A listening socket is bound to an ephemeral port on 127.0.0.1 and the port
    number is written to stdout, which is then closed.  The loop waits on stdin
    and the listening socket: an integer read from stdin is treated as the pid of
    a worker to kill (EOF on stdin shuts the daemon down with status 0), and each
    accepted connection is served by a forked child running ``worker(sock)``.

    The function never returns normally; it always ends in ``shutdown``, which
    sends SIGHUP to the process group and exits.
    """
    # Create a new process group to corral our children
    os.setpgid(0, 0)

    # Create a listening socket on the AF_INET loopback interface
    listen_sock = socket.socket(AF_INET, SOCK_STREAM)
    listen_sock.bind(('127.0.0.1', 0))
    listen_sock.listen(max(1024, SOMAXCONN))
    listen_host, listen_port = listen_sock.getsockname()
    write_int(listen_port, sys.stdout)

    def shutdown(code):
        signal.signal(SIGTERM, SIG_DFL)
        # Send SIGHUP to notify workers of shutdown
        os.kill(0, SIGHUP)
        # sys.exit rather than the builtin ``exit``, which only exists when the
        # ``site`` module has been loaded
        sys.exit(code)

    def handle_sigterm(*args):
        shutdown(1)
    signal.signal(SIGTERM, handle_sigterm)  # Gracefully exit on SIGTERM
    signal.signal(SIGHUP, SIG_IGN)  # Don't die on SIGHUP

    # Cleanup zombie children
    def handle_sigchld(*args):
        try:
            pid, status = os.waitpid(0, os.WNOHANG)
            if status != 0:
                msg = "worker %s crashed abruptly with exit status %s" % (
                    pid, status)
                sys.stderr.write(msg + "\n")
        except EnvironmentError as err:
            if err.errno not in (ECHILD, EINTR):
                raise
    signal.signal(SIGCHLD, handle_sigchld)

    # Initialization complete
    sys.stdout.close()
    try:
        while True:
            try:
                ready_fds = select.select([0, listen_sock], [], [])[0]
            except select.error as ex:
                # ``ex.args[0]`` instead of ``ex[0]``: exceptions are not
                # subscriptable on Python 3, while ``args`` holds the errno first
                # on both Python 2 and 3.
                if ex.args[0] == EINTR:
                    continue
                else:
                    raise
            if 0 in ready_fds:
                try:
                    worker_pid = read_int(sys.stdin)
                except EOFError:
                    # Spark told us to exit by closing stdin
                    shutdown(0)
                try:
                    os.kill(worker_pid, signal.SIGKILL)
                except OSError:
                    pass  # process already died

            if listen_sock in ready_fds:
                sock, addr = listen_sock.accept()
                # Launch a worker process
                try:
                    pid = os.fork()
                    if pid == 0:
                        listen_sock.close()
                        try:
                            worker(sock)
                        except:
                            # Deliberately catch everything (including SystemExit):
                            # the forked child must always leave through os._exit
                            # and never fall back into the daemon loop.
                            traceback.print_exc()
                            os._exit(1)
                        else:
                            os._exit(0)
                    else:
                        sock.close()

                except OSError as e:
                    sys.stderr.write("Daemon failed to fork PySpark worker: %s\n" % e)
                    outfile = os.fdopen(os.dup(sock.fileno()), "a+", 65536)
                    write_int(-1, outfile)  # Signal that the fork failed
                    outfile.flush()
                    outfile.close()
                    sock.close()
    finally:
        shutdown(1)
def read_udfs(pickleSer, infile, eval_type):
    """Read the UDF section of ``infile`` and build the function that evaluates it.

    Pandas eval types first carry a runner configuration (key/value string pairs)
    that configures an ``ArrowStreamPandasSerializer``; all other eval types use a
    batched pickle serializer.  The UDFs are then read with ``read_single_udf``
    and combined into one row mapper.

    Returns ``(func, None, ser, ser)`` where ``func(split_index, iterator)`` maps
    the iterator through the UDF(s); the second element is ``None`` because
    profiling is not supported for UDFs.
    """
    runner_conf = {}
    pandas_eval_types = (PythonEvalType.SQL_SCALAR_PANDAS_UDF,
                         PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
                         PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF,
                         PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF)

    if eval_type not in pandas_eval_types:
        ser = BatchedSerializer(PickleSerializer(), 100)
    else:
        # Load conf used for pandas_udf evaluation
        conf_count = read_int(infile)
        for _ in range(conf_count):
            conf_key = utf8_deserializer.loads(infile)
            conf_value = utf8_deserializer.loads(infile)
            runner_conf[conf_key] = conf_value

        # NOTE: if timezone is set here, that implies respectSessionTimeZone is True
        timezone = runner_conf.get("spark.sql.session.timeZone", None)
        safecheck_text = runner_conf.get(
            "spark.sql.execution.pandas.arrowSafeTypeConversion", "false")
        safecheck = safecheck_text.lower() == 'true'
        # NOTE: this is duplicated from wrap_grouped_map_pandas_udf
        by_name_text = runner_conf.get(
            "spark.sql.legacy.execution.pandas.groupedMap.assignColumnsByName", "true")
        assign_cols_by_name = by_name_text.lower() == "true"
        ser = ArrowStreamPandasSerializer(timezone, safecheck, assign_cols_by_name)

    num_udfs = read_int(infile)
    namespace = {}

    if eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF:
        # The generated function looks like:
        #   lambda a: f([a[0]], [a[0], a[1]])

        # Grouped map doesn't support combining multiple UDFs, so exactly one
        # is expected here.
        assert num_udfs == 1

        # See FlatMapGroupsInPandasExec for how arg_offsets are used to
        # distinguish between grouping attributes and data attributes
        arg_offsets, udf = read_single_udf(
            pickleSer, infile, eval_type, runner_conf, udf_index=0)
        namespace['f'] = udf
        split_offset = arg_offsets[0] + 1
        key_exprs = ", ".join(
            ["a[%d]" % o for o in arg_offsets[1: split_offset]])
        value_exprs = ", ".join(
            ["a[%d]" % o for o in arg_offsets[split_offset:]])
        mapper_source = "lambda a: f([%s], [%s])" % (key_exprs, value_exprs)
    else:
        # The generated function looks like:
        #   lambda a: (f0(a[0]), f1(a[1], a[2]), f2(a[3]))
        # With a single UDF the parentheses do not form a tuple, so a single
        # result is returned; this is the format that the JVM side expects.
        calls = []
        for index in range(num_udfs):
            arg_offsets, udf = read_single_udf(
                pickleSer, infile, eval_type, runner_conf, udf_index=index)
            namespace['f%d' % index] = udf
            arg_exprs = ", ".join(["a[%d]" % o for o in arg_offsets])
            calls.append("f%d(%s)" % (index, arg_exprs))
        mapper_source = "lambda a: (%s)" % (", ".join(calls))

    mapper = eval(mapper_source, namespace)

    def func(_, it):
        return map(mapper, it)

    # profiling is not supported for UDF
    return func, None, ser, ser
def read_udfs(pickleSer, infile, eval_type):
    """
    Read the UDFs for ``eval_type`` from ``infile`` and build the evaluation function.

    The stream is consumed in this order: the runner configuration (pandas eval
    types only), the number of UDFs, then every UDF through ``read_single_udf``.

    :param pickleSer: serializer handed through to ``read_single_udf``.
    :param infile: stream the configuration and the UDF definitions are read from.
    :param eval_type: one of the ``PythonEvalType`` constants.
    :return: ``(func, None, ser, ser)``. ``func(_, iterator)`` applies the UDF(s)
        to the input iterator, ``None`` stands for "no profiler", and the same
        serializer object is returned twice.
    """
    runner_conf = {}

    if eval_type in (PythonEvalType.SQL_SCALAR_PANDAS_UDF,
                     PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF,
                     PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF,
                     PythonEvalType.SQL_MAP_PANDAS_ITER_UDF,
                     PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
                     PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF,
                     PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF):

        # Load conf used for pandas_udf evaluation
        num_conf = read_int(infile)
        for _ in range(num_conf):
            k = utf8_deserializer.loads(infile)
            v = utf8_deserializer.loads(infile)
            runner_conf[k] = v

        # NOTE: if timezone is set here, that implies respectSessionTimeZone is True
        timezone = runner_conf.get("spark.sql.session.timeZone", None)
        safecheck = runner_conf.get(
            "spark.sql.execution.pandas.arrowSafeTypeConversion", "false").lower() == 'true'
        # Used by SQL_GROUPED_MAP_PANDAS_UDF and SQL_SCALAR_PANDAS_UDF when returning StructType
        assign_cols_by_name = runner_conf.get(
            "spark.sql.legacy.execution.pandas.groupedMap.assignColumnsByName", "true")\
            .lower() == "true"

        if eval_type == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF:
            ser = CogroupUDFSerializer(timezone, safecheck, assign_cols_by_name)
        else:
            # Scalar Pandas UDF handles struct type arguments as pandas DataFrames instead of
            # pandas Series. See SPARK-27240.
            df_for_struct = (eval_type == PythonEvalType.SQL_SCALAR_PANDAS_UDF or
                             eval_type == PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF or
                             eval_type == PythonEvalType.SQL_MAP_PANDAS_ITER_UDF)
            ser = ArrowStreamPandasUDFSerializer(timezone, safecheck, assign_cols_by_name,
                                                 df_for_struct)
    else:
        ser = BatchedSerializer(PickleSerializer(), 100)

    num_udfs = read_int(infile)

    is_scalar_iter = eval_type == PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF
    is_map_iter = eval_type == PythonEvalType.SQL_MAP_PANDAS_ITER_UDF

    if is_scalar_iter or is_map_iter:
        if is_scalar_iter:
            assert num_udfs == 1, "One SCALAR_ITER UDF expected here."
        if is_map_iter:
            assert num_udfs == 1, "One MAP_ITER UDF expected here."

        arg_offsets, udf = read_single_udf(
            pickleSer, infile, eval_type, runner_conf, udf_index=0)

        def func(_, iterator):
            # One-element list so the nested map_batch can update the counter
            # without ``nonlocal``.
            num_input_rows = [0]

            def map_batch(batch):
                # Select the UDF's argument columns and count the input rows
                # using the first selected argument.
                udf_args = [batch[offset] for offset in arg_offsets]
                num_input_rows[0] += len(udf_args[0])
                if len(udf_args) == 1:
                    return udf_args[0]
                else:
                    return tuple(udf_args)

            iterator = map(map_batch, iterator)
            result_iter = udf(iterator)

            num_output_rows = 0
            for result_batch, result_type in result_iter:
                num_output_rows += len(result_batch)
                # This check is skipped for MAP_ITER, so when it fires the
                # offending UDF is always a SCALAR_ITER one; the message says so.
                assert is_map_iter or num_output_rows <= num_input_rows[0], \
                    "Pandas SCALAR_ITER UDF outputted more rows than input rows."
                yield (result_batch, result_type)

            if is_scalar_iter:
                # The UDF must have consumed every input batch: one more
                # next() has to hit StopIteration.
                try:
                    next(iterator)
                except StopIteration:
                    pass
                else:
                    raise RuntimeError("SQL_SCALAR_PANDAS_ITER_UDF should exhaust the input "
                                       "iterator.")

            if is_scalar_iter and num_output_rows != num_input_rows[0]:
                raise RuntimeError("The number of output rows of pandas iterator UDF should be "
                                   "the same with input rows. The input rows number is %d but the "
                                   "output rows number is %d." %
                                   (num_input_rows[0], num_output_rows))

        # profiling is not supported for UDF
        return func, None, ser, ser

    def extract_key_value_indexes(grouped_arg_offsets):
        """
        Helper function to extract the key and value indexes from arg_offsets for the grouped and
        cogrouped pandas udfs. See BasePandasGroupExec.resolveArgOffsets for equivalent scala code.

        :param grouped_arg_offsets: flat list holding one group per DataFrame passed to the
            udf, laid out back to back. Each group has the following format:
                group[0]: number of entries that follow in this group
                group[1]: number of key indexes
                group[2 .. group[1] + 1]: key indexes
                group[group[1] + 2 .. group[0]]: value indexes
        :return: a list with one ``[key_indexes, value_indexes]`` pair per group.
        """
        parsed = []
        idx = 0
        while idx < len(grouped_arg_offsets):
            offsets_len = grouped_arg_offsets[idx]
            idx += 1
            # ``offsets`` excludes the length prefix, so offsets[0] is the key count.
            offsets = grouped_arg_offsets[idx:idx + offsets_len]
            split_index = offsets[0] + 1
            offset_keys = offsets[1:split_index]
            offset_values = offsets[split_index:]
            parsed.append([offset_keys, offset_values])
            idx += offsets_len
        return parsed

    # Namespace the generated lambda is evaluated in: UDF name -> wrapped UDF.
    udfs = {}
    call_udf = []
    mapper_str = ""
    if eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF:
        # Create function like this:
        #   lambda a: f([a[0]], [a[0], a[1]])

        # We assume there is only one UDF here because grouped map doesn't
        # support combining multiple UDFs.
        assert num_udfs == 1

        # See FlatMapGroupsInPandasExec for how arg_offsets are used to
        # distinguish between grouping attributes and data attributes
        arg_offsets, udf = read_single_udf(
            pickleSer, infile, eval_type, runner_conf, udf_index=0)
        udfs['f'] = udf
        parsed_offsets = extract_key_value_indexes(arg_offsets)
        keys = ["a[%d]" % (o, ) for o in parsed_offsets[0][0]]
        vals = ["a[%d]" % (o, ) for o in parsed_offsets[0][1]]
        mapper_str = "lambda a: f([%s], [%s])" % (", ".join(keys), ", ".join(vals))
    elif eval_type == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF:
        # We assume there is only one UDF here because cogrouped map doesn't
        # support combining multiple UDFs.
        assert num_udfs == 1
        arg_offsets, udf = read_single_udf(
            pickleSer, infile, eval_type, runner_conf, udf_index=0)
        udfs['f'] = udf
        # Two groups are expected: a[0] indexes the first DataFrame's columns
        # and a[1] the second's.
        parsed_offsets = extract_key_value_indexes(arg_offsets)
        df1_keys = ["a[0][%d]" % (o, ) for o in parsed_offsets[0][0]]
        df1_vals = ["a[0][%d]" % (o, ) for o in parsed_offsets[0][1]]
        df2_keys = ["a[1][%d]" % (o, ) for o in parsed_offsets[1][0]]
        df2_vals = ["a[1][%d]" % (o, ) for o in parsed_offsets[1][1]]
        mapper_str = "lambda a: f([%s], [%s], [%s], [%s])" % (
            ", ".join(df1_keys), ", ".join(df1_vals),
            ", ".join(df2_keys), ", ".join(df2_vals))
    else:
        # Create function like this:
        #   lambda a: (f0(a[0]), f1(a[1], a[2]), f2(a[3]))
        # In the special case of a single UDF this will return a single result rather
        # than a tuple of results; this is the format that the JVM side expects.
        for i in range(num_udfs):
            arg_offsets, udf = read_single_udf(
                pickleSer, infile, eval_type, runner_conf, udf_index=i)
            udfs['f%d' % i] = udf
            args = ["a[%d]" % o for o in arg_offsets]
            call_udf.append("f%d(%s)" % (i, ", ".join(args)))
        mapper_str = "lambda a: (%s)" % (", ".join(call_udf))

    # The evaluated text is assembled only from fixed fragments and
    # %d-formatted offsets.
    mapper = eval(mapper_str, udfs)
    func = lambda _, it: map(mapper, it)

    # profiling is not supported for UDF
    return func, None, ser, ser
def manager():
    """
    Run the daemon loop: accept connections and fork one worker process per connection.

    The listening port is written to stdout once at start-up. Afterwards the
    loop waits on stdin and on the listening socket:

    * an integer read from stdin is treated as a worker pid and that process is
      sent SIGKILL; EOF on stdin shuts the daemon down with exit code 0;
    * every accepted connection is handed to a forked child, which writes its
      pid back on the socket and then runs ``worker(sock, authenticated)``.

    This function does not return normally: every exit path goes through
    ``shutdown``, which calls ``sys.exit`` (the child process leaves through
    ``os._exit``).
    """
    # Create a new process group to corral our children
    os.setpgid(0, 0)

    # Create a listening socket on the AF_INET loopback interface
    listen_sock = socket.socket(AF_INET, SOCK_STREAM)
    listen_sock.bind(("127.0.0.1", 0))
    listen_sock.listen(max(1024, SOMAXCONN))
    listen_port = listen_sock.getsockname()[1]

    # re-open stdin/stdout in 'wb' mode
    stdin_bin = os.fdopen(sys.stdin.fileno(), "rb", 4)
    stdout_bin = os.fdopen(sys.stdout.fileno(), "wb", 4)
    write_int(listen_port, stdout_bin)
    stdout_bin.flush()

    def shutdown(code):
        # Restore the default SIGTERM handler first so this function cannot be
        # re-entered through handle_sigterm.
        signal.signal(SIGTERM, SIG_DFL)
        # Send SIGHUP to notify workers of shutdown
        os.kill(0, SIGHUP)
        sys.exit(code)

    def handle_sigterm(*args):
        shutdown(1)
    signal.signal(SIGTERM, handle_sigterm)  # Gracefully exit on SIGTERM
    signal.signal(SIGHUP, SIG_IGN)  # Don't die on SIGHUP
    signal.signal(SIGCHLD, SIG_IGN)

    reuse = os.environ.get("SPARK_REUSE_WORKER")

    # Initialization complete
    try:
        while True:
            try:
                ready_fds = select.select([0, listen_sock], [], [], 1)[0]
            except select.error as ex:
                # ``ex.args[0]`` is the errno both for the legacy select.error
                # and for OSError (which select.error aliases on Python 3);
                # indexing the exception object itself fails on Python 3.
                if ex.args[0] == EINTR:
                    continue
                else:
                    raise

            if 0 in ready_fds:
                try:
                    worker_pid = read_int(stdin_bin)
                except EOFError:
                    # Spark told us to exit by closing stdin
                    shutdown(0)
                try:
                    os.kill(worker_pid, signal.SIGKILL)
                except OSError:
                    pass  # process already died

            if listen_sock in ready_fds:
                try:
                    sock, _ = listen_sock.accept()
                except OSError as e:
                    if e.errno == EINTR:
                        continue
                    raise

                # Launch a worker process
                try:
                    pid = os.fork()
                except OSError as e:
                    if e.errno in (EAGAIN, EINTR):
                        time.sleep(1)
                        pid = os.fork()  # error here will shutdown daemon
                    else:
                        outfile = sock.makefile(mode="wb")
                        write_int(e.errno, outfile)  # Signal that the fork failed
                        outfile.flush()
                        outfile.close()
                        sock.close()
                        continue

                if pid == 0:
                    # in child process
                    listen_sock.close()

                    # It should close the standard input in the child process so that
                    # Python native function executions stay intact.
                    #
                    # Note that if we just close the standard input (file descriptor 0),
                    # the lowest file descriptor (file descriptor 0) will be allocated,
                    # later when other file descriptors should happen to open.
                    #
                    # Therefore, here we redirect it to '/dev/null' by duplicating
                    # another file descriptor for '/dev/null' to the standard input (0).
                    # See SPARK-26175.
                    with open(os.devnull, "r") as devnull:
                        os.dup2(devnull.fileno(), 0)

                    try:
                        # Acknowledge that the fork was successful
                        outfile = sock.makefile(mode="wb")
                        write_int(os.getpid(), outfile)
                        outfile.flush()
                        outfile.close()
                        authenticated = False
                        while True:
                            code = worker(sock, authenticated)
                            if code == 0:
                                authenticated = True
                            if not reuse or code:
                                # wait for closing; failures while draining the
                                # socket are deliberately ignored
                                try:
                                    while sock.recv(1024):
                                        pass
                                except Exception:
                                    pass
                                break
                            gc.collect()
                    except:
                        # Deliberately broad: whatever happens, the child must
                        # terminate here and never fall back into the daemon loop.
                        traceback.print_exc()
                        os._exit(1)
                    else:
                        os._exit(0)
                else:
                    # The parent keeps no reference to the connection it just
                    # handed to the child.
                    sock.close()

    finally:
        shutdown(1)
def main(infile, outfile):
    """
    Run one task: read its setup and command(s) from ``infile``, write results to ``outfile``.

    The setup section is read in this order: split index, driver Python version,
    files directory, Python includes, broadcast variable updates, the
    ``row_based`` flag, the number of commands and the command(s) themselves.
    After processing, timing information, spill counters and accumulator
    updates are written, followed by the end-of-stream handshake.

    :param infile: stream the task description and the input data are read from.
    :param outfile: stream the results and the trailing metadata are written to.
    :raises SystemExit: with code -1 when the split index is -1, when setup or
        processing raises an exception (after trying to report it on
        ``outfile``), or when the end-of-stream marker is not received.
    """
    try:
        boot_time = time.time()
        split_index = read_int(infile)
        if split_index == -1:  # for unit tests
            # sys.exit instead of the ``exit`` builtin: the latter is injected
            # by the ``site`` module and is not guaranteed to exist.
            sys.exit(-1)

        version = utf8_deserializer.loads(infile)
        if version != "%d.%d" % sys.version_info[:2]:
            raise Exception(("Python in worker has different version %s than that in " +
                             "driver %s, PySpark cannot run with different minor versions") %
                            ("%d.%d" % sys.version_info[:2], version))

        # initialize global state
        shuffle.MemoryBytesSpilled = 0
        shuffle.DiskBytesSpilled = 0
        _accumulatorRegistry.clear()

        # fetch name of workdir
        spark_files_dir = utf8_deserializer.loads(infile)
        SparkFiles._root_directory = spark_files_dir
        SparkFiles._is_running_on_worker = True

        # fetch names of includes (*.zip and *.egg files) and construct PYTHONPATH
        add_path(spark_files_dir)  # *.py files that were added will be copied here
        num_python_includes = read_int(infile)
        for _ in range(num_python_includes):
            filename = utf8_deserializer.loads(infile)
            add_path(os.path.join(spark_files_dir, filename))
        # Compare the numeric major version; comparing ``sys.version`` as a
        # string is lexicographic.
        if sys.version_info[0] >= 3:
            import importlib
            importlib.invalidate_caches()

        # fetch names and values of broadcast variables
        num_broadcast_variables = read_int(infile)
        for _ in range(num_broadcast_variables):
            bid = read_long(infile)
            if bid >= 0:
                path = utf8_deserializer.loads(infile)
                _broadcastRegistry[bid] = Broadcast(path=path)
            else:
                # A negative id encodes the removal of broadcast ``-bid - 1``.
                bid = - bid - 1
                _broadcastRegistry.pop(bid)

        _accumulatorRegistry.clear()
        row_based = read_int(infile)
        num_commands = read_int(infile)
        if row_based:
            profiler = None  # profiling is not supported for UDF
            row_func = None
            # Chain the commands into one row function; the return type and the
            # deserializer of the last command read are the ones used below.
            for i in range(num_commands):
                f, returnType, deserializer = read_command(pickleSer, infile)
                if row_func is None:
                    row_func = f
                else:
                    row_func = chain(row_func, f)
            serializer = deserializer
            func = lambda _, it: map(lambda x: returnType.toInternal(row_func(*x)), it)
        else:
            assert num_commands == 1
            func, profiler, deserializer, serializer = read_command(pickleSer, infile)

        init_time = time.time()

        def process():
            iterator = deserializer.load_stream(infile)
            serializer.dump_stream(func(split_index, iterator), outfile)

        if profiler:
            profiler.profile(process)
        else:
            process()
    except Exception:
        try:
            write_int(SpecialLengths.PYTHON_EXCEPTION_THROWN, outfile)
            write_with_length(traceback.format_exc().encode("utf-8"), outfile)
        except IOError:
            # JVM close the socket
            pass
        except Exception:
            # Write the error to stderr if it happened while serializing
            print("PySpark worker failed with exception:", file=sys.stderr)
            print(traceback.format_exc(), file=sys.stderr)
        sys.exit(-1)
    finish_time = time.time()
    report_times(outfile, boot_time, init_time, finish_time)
    write_long(shuffle.MemoryBytesSpilled, outfile)
    write_long(shuffle.DiskBytesSpilled, outfile)

    # Mark the beginning of the accumulators section of the output
    write_int(SpecialLengths.END_OF_DATA_SECTION, outfile)
    write_int(len(_accumulatorRegistry), outfile)
    for (aid, accum) in _accumulatorRegistry.items():
        pickleSer._write_with_length((aid, accum._value), outfile)

    # check end of stream
    if read_int(infile) == SpecialLengths.END_OF_STREAM:
        write_int(SpecialLengths.END_OF_STREAM, outfile)
    else:
        # write a different value to tell JVM to not reuse this worker
        write_int(SpecialLengths.END_OF_DATA_SECTION, outfile)
        sys.exit(-1)