Esempio n. 1
0
def read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index):
    """
    Read one UDF description from ``infile`` and wrap it for ``eval_type``.

    The stream is consumed in this order: the number of argument offsets, the
    offsets themselves, the number of chained functions, and then one command
    per function (via ``read_command``).  The functions are composed with
    ``chain`` in the order they are read.

    :param pickleSer: serializer passed through to ``read_command``
    :param infile: input stream to read from
    :param eval_type: ``PythonEvalType`` constant selecting the wrapper
    :param runner_conf: configuration forwarded to the grouped-map and
        window-agg wrappers
    :param udf_index: index forwarded to the window-agg wrapper
    :return: tuple ``(arg_offsets, wrapped_udf)``
    :raises ValueError: if ``eval_type`` is not one of the handled types
    """
    offset_count = read_int(infile)
    arg_offsets = [read_int(infile) for _ in range(offset_count)]

    composed = None
    for _ in range(read_int(infile)):
        f, return_type = read_command(pickleSer, infile)
        composed = f if composed is None else chain(composed, f)

    # make sure StopIteration's raised in the user code are not ignored
    # when they are processed in a for loop, raise them as RuntimeError's instead
    func = fail_on_stopiteration(composed)

    # return_type now holds the type read with the last command; it is used
    # as the return type of the whole UDF
    if eval_type == PythonEvalType.SQL_SCALAR_PANDAS_UDF:
        wrapped = wrap_scalar_pandas_udf(func, return_type)
    elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF:
        # the signature was lost when wrapping, so inspect the unwrapped function
        argspec = _get_argspec(composed)
        wrapped = wrap_grouped_map_pandas_udf(func, return_type, argspec, runner_conf)
    elif eval_type == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF:
        wrapped = wrap_grouped_agg_pandas_udf(func, return_type)
    elif eval_type == PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF:
        wrapped = wrap_window_agg_pandas_udf(func, return_type, runner_conf, udf_index)
    elif eval_type == PythonEvalType.SQL_BATCHED_UDF:
        wrapped = wrap_udf(func, return_type)
    else:
        raise ValueError("Unknown eval type: {}".format(eval_type))
    return arg_offsets, wrapped
Esempio n. 2
0
def read_udfs(pickleSer, infile, eval_type):
    """
    Read all UDFs for a task and build the per-row mapper plus its serializer.

    For the pandas eval types a block of key/value configuration strings is
    read first and an ``ArrowStreamPandasSerializer`` is used; every other
    eval type uses a batched pickle serializer.  The mapper itself is
    generated as a lambda source string and compiled with ``eval``, with the
    wrapped UDFs supplied as its globals.

    :param pickleSer: serializer passed through to ``read_single_udf``
    :param infile: input stream to read from
    :param eval_type: ``PythonEvalType`` constant
    :return: tuple ``(func, None, ser, ser)``; the ``None`` slot is the
        profiler, which is not supported for UDFs
    """
    pandas_eval_types = (PythonEvalType.SQL_SCALAR_PANDAS_UDF,
                         PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
                         PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF,
                         PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF)
    runner_conf = {}

    if eval_type in pandas_eval_types:
        # Load conf used for pandas_udf evaluation
        for _ in range(read_int(infile)):
            conf_key = utf8_deserializer.loads(infile)
            runner_conf[conf_key] = utf8_deserializer.loads(infile)

        # NOTE: if timezone is set here, that implies respectSessionTimeZone is True
        timezone = runner_conf.get("spark.sql.session.timeZone", None)
        ser = ArrowStreamPandasSerializer(timezone)
    else:
        ser = BatchedSerializer(PickleSerializer(), 100)

    num_udfs = read_int(infile)
    udfs = {}
    if eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF:
        # Generated function looks like:
        #   lambda a: f([a[0]], [a[0], a[1]])

        # Grouped map doesn't support combining multiple UDFs, so exactly
        # one is expected here.
        assert num_udfs == 1

        # See FlatMapGroupsInPandasExec for how arg_offsets are used to
        # distinguish between grouping attributes and data attributes:
        # arg_offsets[0] gives the length of the first offset list.
        arg_offsets, udf = read_single_udf(
            pickleSer, infile, eval_type, runner_conf, udf_index=0)
        udfs['f'] = udf
        split_offset = arg_offsets[0] + 1
        first_args = ", ".join("a[%d]" % o for o in arg_offsets[1:split_offset])
        second_args = ", ".join("a[%d]" % o for o in arg_offsets[split_offset:])
        mapper_str = "lambda a: f([%s], [%s])" % (first_args, second_args)
    else:
        # Generated function looks like:
        #   lambda a: (f0(a[0]), f1(a[1], a[2]), f2(a[3]))
        # In the special case of a single UDF this will return a single result rather
        # than a tuple of results; this is the format that the JVM side expects.
        calls = []
        for i in range(num_udfs):
            arg_offsets, udf = read_single_udf(
                pickleSer, infile, eval_type, runner_conf, udf_index=i)
            udfs['f%d' % i] = udf
            call_args = ", ".join("a[%d]" % o for o in arg_offsets)
            calls.append("f%d(%s)" % (i, call_args))
        mapper_str = "lambda a: (%s)" % ", ".join(calls)

    mapper = eval(mapper_str, udfs)

    def func(_, it):
        return map(mapper, it)

    # profiling is not supported for UDF
    return func, None, ser, ser
Esempio n. 3
0
def main():
    """
    Worker entry point: read a task from ``sys.stdin``, run it and write the
    results to ``old_stdout``.

    The input is consumed in this order: split index, files directory,
    broadcast variables, the function, the bypass-serializer flag, and then
    the data.  If the function raises, ``-2`` and the formatted traceback are
    written and the process exits with status -1.  Otherwise ``-1`` is written
    followed by one pickled ``(id, value)`` pair per accumulator.
    """
    split_index = read_int(sys.stdin)

    # The files directory is recorded on SparkFiles and made importable.
    files_dir = load_pickle(read_with_length(sys.stdin))
    SparkFiles._root_directory = files_dir
    SparkFiles._is_running_on_worker = True
    sys.path.append(files_dir)

    broadcast_count = read_int(sys.stdin)
    for _ in range(broadcast_count):
        bid = read_long(sys.stdin)
        pickled = read_with_length(sys.stdin)
        _broadcastRegistry[bid] = Broadcast(bid, load_pickle(pickled))

    func = load_obj()
    # When the bypass flag is set, results are written as-is instead of pickled.
    dumps = (lambda x: x) if load_obj() else dump_pickle

    iterator = read_from_pickle_file(sys.stdin)
    try:
        for result in func(split_index, iterator):
            write_with_length(dumps(result), old_stdout)
    except Exception:
        write_int(-2, old_stdout)
        write_with_length(traceback.format_exc(), old_stdout)
        sys.exit(-1)

    # Mark the beginning of the accumulators section of the output
    write_int(-1, old_stdout)
    for aid, accum in _accumulatorRegistry.items():
        write_with_length(dump_pickle((aid, accum._value)), old_stdout)
Esempio n. 4
0
def main(infile, outfile):
    """
    Run one worker task: read the task setup from ``infile``, execute the
    command and write results, timing data and accumulator values to
    ``outfile``.

    A split index of -1 makes the function return immediately (used by unit
    tests).  If anything during setup or execution raises, the traceback is
    written to ``outfile`` after a ``PYTHON_EXCEPTION_THROWN`` marker and the
    process exits with status -1.
    """
    try:
        boot_time = time.time()
        split_index = read_int(infile)
        if split_index == -1:  # for unit tests
            return

        # fetch name of workdir
        spark_files_dir = utf8_deserializer.loads(infile)
        SparkFiles._root_directory = spark_files_dir
        SparkFiles._is_running_on_worker = True

        # *.py files that were added will be copied into the workdir itself
        sys.path.append(spark_files_dir)
        # the includes (*.zip and *.egg files) are named relative to the workdir
        include_count = read_int(infile)
        for _ in range(include_count):
            include_name = utf8_deserializer.loads(infile)
            sys.path.append(os.path.join(spark_files_dir, include_name))

        # fetch names and values of broadcast variables
        broadcast_count = read_int(infile)
        broadcast_ser = CompressedSerializer(pickleSer)
        for _ in range(broadcast_count):
            bid = read_long(infile)
            if bid < 0:
                # a negative id asks for removal of broadcast (-bid - 1)
                _broadcastRegistry.remove(-bid - 1)
                continue
            _broadcastRegistry[bid] = Broadcast(
                bid, broadcast_ser._read_with_length(infile))

        _accumulatorRegistry.clear()
        func, deserializer, serializer = pickleSer._read_with_length(infile)
        init_time = time.time()
        iterator = deserializer.load_stream(infile)
        serializer.dump_stream(func(split_index, iterator), outfile)
    except Exception:
        try:
            write_int(SpecialLengths.PYTHON_EXCEPTION_THROWN, outfile)
            write_with_length(traceback.format_exc(), outfile)
            outfile.flush()
        except IOError:
            # JVM close the socket
            pass
        except Exception:
            # Write the error to stderr if it happened while serializing
            print >> sys.stderr, "PySpark worker failed with exception:"
            print >> sys.stderr, traceback.format_exc()
        exit(-1)

    finish_time = time.time()
    report_times(outfile, boot_time, init_time, finish_time)

    # Mark the beginning of the accumulators section of the output
    write_int(SpecialLengths.END_OF_DATA_SECTION, outfile)
    write_int(len(_accumulatorRegistry), outfile)
    for aid, accum in _accumulatorRegistry.items():
        pickleSer._write_with_length((aid, accum._value), outfile)
Esempio n. 5
0
def read_single_udf(pickleSer, infile):
    """
    Read one UDF description from ``infile`` and wrap it with ``wrap_udf``.

    Reads the argument offsets, then a count of functions followed by one
    command per function; the functions are composed with ``chain`` in the
    order they are read.

    :return: tuple ``(arg_offsets, wrapped_udf)``
    """
    offset_count = read_int(infile)
    arg_offsets = [read_int(infile) for _ in range(offset_count)]

    composed = None
    for _ in range(read_int(infile)):
        f, return_type = read_command(pickleSer, infile)
        composed = f if composed is None else chain(composed, f)

    # the return type read with the last command is the UDF's return type
    return arg_offsets, wrap_udf(composed, return_type)
Esempio n. 6
0
def read_single_udf(pickleSer, infile, eval_type):
    """
    Read one UDF description from ``infile`` and wrap it for ``eval_type``.

    Reads the argument offsets, then a count of functions followed by one
    command per function; the functions are composed with ``chain`` in the
    order they are read.  ``SQL_PANDAS_UDF`` uses ``wrap_pandas_udf``; any
    other eval type uses ``wrap_udf``.

    :return: tuple ``(arg_offsets, wrapped_udf)``
    """
    offset_count = read_int(infile)
    arg_offsets = [read_int(infile) for _ in range(offset_count)]

    composed = None
    for _ in range(read_int(infile)):
        f, return_type = read_command(pickleSer, infile)
        composed = f if composed is None else chain(composed, f)

    # the return type read with the last command is the UDF's return type
    if eval_type == PythonEvalType.SQL_PANDAS_UDF:
        wrapped = wrap_pandas_udf(composed, return_type)
    else:
        wrapped = wrap_udf(composed, return_type)
    return arg_offsets, wrapped
Esempio n. 7
0
    def do_termination_test(self, terminator):
        """
        Start ``daemon.py``, shut it down with ``terminator`` and check that
        its port then refuses connections.

        :param terminator: callable invoked with the daemon ``Popen`` object
            to request the shutdown
        """
        import errno
        import subprocess

        # start daemon
        script = os.path.join(os.path.dirname(__file__), "daemon.py")
        daemon = subprocess.Popen(
            [sys.executable, script], stdin=subprocess.PIPE, stdout=subprocess.PIPE)

        # the daemon reports its port number on stdout
        port = read_int(daemon.stdout)

        # daemon should accept connections
        self.assertTrue(self.connect(port))

        # request shutdown and give the daemon a moment to act on it
        terminator(daemon)
        time.sleep(1)

        # daemon should no longer accept connections
        refusal = None
        try:
            self.connect(port)
        except EnvironmentError as exception:
            refusal = exception
        if refusal is None:
            self.fail("Expected EnvironmentError to be raised")
        self.assertEqual(refusal.errno, errno.ECONNREFUSED)
Esempio n. 8
0
def main():
    """
    Worker entry point: read a task from ``sys.stdin``, run it and write each
    result to ``old_stdout``.

    The input is consumed in this order: split index, broadcast variables,
    the function, the bypass-serializer flag, and then the data.
    """
    split_index = read_int(sys.stdin)

    broadcast_count = read_int(sys.stdin)
    for _ in range(broadcast_count):
        bid = read_long(sys.stdin)
        pickled = read_with_length(sys.stdin)
        _broadcastRegistry[bid] = Broadcast(bid, load_pickle(pickled))

    func = load_obj()
    # When the bypass flag is set, results are written as-is instead of pickled.
    dumps = (lambda x: x) if load_obj() else dump_pickle

    iterator = read_from_pickle_file(sys.stdin)
    for result in func(split_index, iterator):
        write_with_length(dumps(result), old_stdout)
Esempio n. 9
0
def read_udfs(pickleSer, infile, eval_type):
    """
    Read all UDFs for a task and build the per-row mapper plus its serializer.

    The mapper is generated as a lambda source string and compiled with
    ``eval``, with the wrapped UDFs supplied as its globals.  For the pandas
    eval types a timezone string is read from ``infile`` after the UDFs and
    an ``ArrowStreamPandasSerializer`` is used; otherwise a batched pickle
    serializer is used.

    :return: tuple ``(func, None, ser, ser)``; the ``None`` slot is the
        profiler, which is not supported for UDFs
    """
    num_udfs = read_int(infile)
    udfs = {}
    calls = []
    for i in range(num_udfs):
        arg_offsets, udf = read_single_udf(pickleSer, infile, eval_type)
        udfs['f%d' % i] = udf
        call_args = ", ".join("a[%d]" % o for o in arg_offsets)
        calls.append("f%d(%s)" % (i, call_args))

    # Generated function looks like:
    #   lambda a: (f0(a0), f1(a1, a2), f2(a3))
    # In the special case of a single UDF this will return a single result rather
    # than a tuple of results; this is the format that the JVM side expects.
    mapper = eval("lambda a: (%s)" % ", ".join(calls), udfs)

    def func(_, it):
        return map(mapper, it)

    arrow_eval_types = (PythonEvalType.SQL_SCALAR_PANDAS_UDF,
                        PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
                        PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF)
    if eval_type in arrow_eval_types:
        timezone = utf8_deserializer.loads(infile)
        ser = ArrowStreamPandasSerializer(timezone)
    else:
        ser = BatchedSerializer(PickleSerializer(), 100)

    # profiling is not supported for UDF
    return func, None, ser, ser
Esempio n. 10
0
def main(infile, outfile):
    """
    Run one worker task: read the task setup from ``infile``, execute the
    function and write results, timing data and accumulator values to
    ``outfile``.

    A split index of -1 makes the function return immediately (used by unit
    tests).  If the function raises, ``-2`` and the formatted traceback are
    written and the process exits with status -1.  On success the accumulator
    section is written between two ``-1`` markers.
    """
    boot_time = time.time()
    split_index = read_int(infile)
    if split_index == -1:  # for unit tests
        return

    # fetch name of workdir
    spark_files_dir = load_pickle(read_with_length(infile))
    SparkFiles._root_directory = spark_files_dir
    SparkFiles._is_running_on_worker = True

    # fetch names and values of broadcast variables
    broadcast_count = read_int(infile)
    for _ in range(broadcast_count):
        bid = read_long(infile)
        pickled = read_with_length(infile)
        _broadcastRegistry[bid] = Broadcast(bid, load_pickle(pickled))

    # *.py files that were added will be copied into the workdir itself
    sys.path.append(spark_files_dir)
    # the includes (*.zip and *.egg files) are named relative to the workdir
    include_count = read_int(infile)
    for _ in range(include_count):
        include_name = load_pickle(read_with_length(infile))
        sys.path.append(os.path.join(spark_files_dir, include_name))

    # now load function
    func = load_obj(infile)
    # When the bypass flag is set, results are written as-is instead of pickled.
    dumps = (lambda x: x) if load_obj(infile) else dump_pickle
    init_time = time.time()

    iterator = read_from_pickle_file(infile)
    try:
        for result in func(split_index, iterator):
            write_with_length(dumps(result), outfile)
    except Exception:
        write_int(-2, outfile)
        write_with_length(traceback.format_exc(), outfile)
        sys.exit(-1)

    finish_time = time.time()
    report_times(outfile, boot_time, init_time, finish_time)

    # Mark the beginning of the accumulators section of the output
    write_int(-1, outfile)
    for aid, accum in _accumulatorRegistry.items():
        write_with_length(dump_pickle((aid, accum._value)), outfile)
    write_int(-1, outfile)
def _read_vec(stream):
    """
    Read a length-prefixed vector of doubles from ``stream``.

    The length is read with ``read_int``; each element is then read as
    8 bytes and unpacked with the ``'!d'`` struct format.

    :return: a NumPy array holding the values that were read
    """
    size = read_int(stream)
    result = np.empty(size)
    # TODO: maybe some optimized way to read this all at once?
    for idx in xrange(size):
        (result[idx],) = struct.unpack('!d', stream.read(8))

    return result
Esempio n. 12
0
 def handle(self):
     """
     Apply one batch of accumulator updates read from ``self.rfile``.

     Reads the number of updates, then one pickled ``(aid, update)`` pair
     per update, adding each to the matching registry entry.  A single
     byte is written to ``self.wfile`` afterwards as acknowledgement.
     """
     from pyspark.accumulators import _accumulatorRegistry

     remaining = read_int(self.rfile)
     while remaining > 0:
         aid, update = load_pickle(read_with_length(self.rfile))
         _accumulatorRegistry[aid] += update
         remaining -= 1

     # Write a byte in acknowledgement
     self.wfile.write(struct.pack("!b", 1))
Esempio n. 13
0
 def accum_updates():
     """
     Apply one batch of accumulator updates read from ``self.rfile``.

     Reads the number of updates, then one ``(aid, update)`` pair per
     update via ``pickleSer``, adding each to the matching registry entry.
     A single byte is written to ``self.wfile`` as acknowledgement.

     :return: always ``False``
     """
     remaining = read_int(self.rfile)
     while remaining > 0:
         aid, update = pickleSer._read_with_length(self.rfile)
         _accumulatorRegistry[aid] += update
         remaining -= 1

     # Write a byte in acknowledgement
     self.wfile.write(struct.pack("!b", 1))
     return False
Esempio n. 14
0
 def func(event):
     """
     Decode one event into a ``(headers, body)`` tuple.

     ``event[0]`` holds the serialized headers: a count (read with
     ``read_int``) followed by that many key/value strings.  ``event[1]``
     is passed to ``bodyDecoder``.

     :param event: two-element sequence ``(header_bytes, body)``
     :return: tuple ``(headers, body)`` where ``headers`` is a dict
     """
     # Compare the parsed version tuple, not the ``sys.version`` string: a
     # lexicographic check such as "10.0" >= "3" is False.
     if sys.version_info[0] >= 3:
         headersBytes = BytesIO(event[0])
     else:
         headersBytes = StringIO(event[0])

     headers = {}
     strSer = UTF8Deserializer()
     for _ in range(read_int(headersBytes)):
         key = strSer.loads(headersBytes)
         headers[key] = strSer.loads(headersBytes)

     body = bodyDecoder(event[1])
     return (headers, body)
Esempio n. 15
0
def main(infile, outfile):
    """
    Run one worker task: read the task setup from ``infile``, execute the
    command and write results, timing data and accumulator values to
    ``outfile``.

    A split index of -1 makes the function return immediately (used by unit
    tests).  If loading or processing the data raises, the traceback is
    written after a ``PYTHON_EXCEPTION_THROWN`` marker and the process exits
    with status -1.
    """
    boot_time = time.time()
    split_index = read_int(infile)
    if split_index == -1:  # for unit tests
        return

    # fetch name of workdir
    spark_files_dir = mutf8_deserializer.loads(infile)
    SparkFiles._root_directory = spark_files_dir
    SparkFiles._is_running_on_worker = True

    # fetch names and values of broadcast variables
    broadcast_count = read_int(infile)
    for _ in range(broadcast_count):
        bid = read_long(infile)
        _broadcastRegistry[bid] = Broadcast(bid, pickleSer._read_with_length(infile))

    # *.py files that were added will be copied into the workdir itself
    sys.path.append(spark_files_dir)
    # the includes (*.zip and *.egg files) are named relative to the workdir
    include_count = read_int(infile)
    for _ in range(include_count):
        include_name = mutf8_deserializer.loads(infile)
        sys.path.append(os.path.join(spark_files_dir, include_name))

    # the command bundles the function with its input and output serializers
    func, deserializer, serializer = pickleSer._read_with_length(infile)
    init_time = time.time()
    try:
        iterator = deserializer.load_stream(infile)
        serializer.dump_stream(func(split_index, iterator), outfile)
    except Exception:
        write_int(SpecialLengths.PYTHON_EXCEPTION_THROWN, outfile)
        write_with_length(traceback.format_exc(), outfile)
        sys.exit(-1)

    finish_time = time.time()
    report_times(outfile, boot_time, init_time, finish_time)

    # Mark the beginning of the accumulators section of the output
    write_int(SpecialLengths.END_OF_DATA_SECTION, outfile)
    write_int(len(_accumulatorRegistry), outfile)
    for aid, accum in _accumulatorRegistry.items():
        pickleSer._write_with_length((aid, accum._value), outfile)
Esempio n. 16
0
    def load_stream(self, stream):
        """
        Load a stream of un-ordered Arrow RecordBatches, where the last iteration yields
        a list of indices that can be used to put the RecordBatches in the correct order.

        :raises RuntimeError: if ``-1`` is read in place of the index count;
            the error message that follows it in the stream is included
        """
        # pass the batches through in the order they arrive
        for record_batch in self.serializer.load_stream(stream):
            yield record_batch

        # load the batch order indices or propagate any error that occurred in the JVM
        count = read_int(stream)
        if count == -1:
            message = UTF8Deserializer().loads(stream)
            raise RuntimeError(
                "An error occurred while calling "
                "ArrowCollectSerializer.load_stream: {}".format(message))
        yield [read_int(stream) for _ in range(count)]
Esempio n. 17
0
 def handle(self):
     """
     Apply batches of accumulator updates until the server shuts down.

     Each batch read from ``self.rfile`` is a count followed by that many
     ``(aid, update)`` pairs; a single byte is written to ``self.wfile``
     after every batch as acknowledgement.
     """
     from pyspark.accumulators import _accumulatorRegistry

     while not self.server.server_shutdown:
         # Poll every 1 second for new data -- don't block in case of shutdown.
         readable, _, _ = select.select([self.rfile], [], [], 1)
         if self.rfile not in readable:
             continue

         for _ in range(read_int(self.rfile)):
             aid, update = pickleSer._read_with_length(self.rfile)
             _accumulatorRegistry[aid] += update
         # Write a byte in acknowledgement
         self.wfile.write(struct.pack("!b", 1))
Esempio n. 18
0
 def handle(self):
     """
     Apply batches of accumulator updates until the server shuts down.

     Each batch read from ``self.rfile`` is a count followed by that many
     ``(aid, update)`` pairs; a single byte is written to ``self.wfile``
     after every batch as acknowledgement.
     """
     from pyspark.accumulators import _accumulatorRegistry

     while not self.server.server_shutdown:
         # Poll every 1 second for new data -- don't block in case of shutdown.
         readable, _, _ = select.select([self.rfile], [], [], 1)
         if self.rfile not in readable:
             continue

         for _ in range(read_int(self.rfile)):
             aid, update = pickleSer._read_with_length(self.rfile)
             _accumulatorRegistry[aid] += update
         # Write a byte in acknowledgement
         self.wfile.write(struct.pack("!b", 1))
Esempio n. 19
0
def read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index):
    """
    Read one UDF description from ``infile`` and wrap it for ``eval_type``.

    Reads the argument offsets, then a count of functions followed by one
    command per function; the functions are composed with ``chain`` in the
    order they are read.  The return type and profiler read with the last
    command are the ones used.

    :param pickleSer: serializer passed through to ``read_command``
    :param infile: input stream to read from
    :param eval_type: ``PythonEvalType`` constant selecting the wrapper
    :param runner_conf: configuration forwarded to the window-agg wrapper
    :param udf_index: index forwarded to the window-agg wrapper
    :return: tuple ``(arg_offsets, wrapped_udf, profiler)``
    :raises ValueError: if ``eval_type`` is not one of the handled types
    """
    offset_count = read_int(infile)
    arg_offsets = [read_int(infile) for _ in range(offset_count)]

    composed = None
    profiler = None
    for _ in range(read_int(infile)):
        f, return_type, profiler = read_command(pickleSer, infile)
        composed = f if composed is None else chain(composed, f)

    if eval_type == PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF:
        # this eval type uses the composed function without the guard below
        func = composed
    else:
        # make sure StopIteration's raised in the user code are not ignored
        # when they are processed in a for loop, raise them as RuntimeError's instead
        func = fail_on_stopiteration(composed)

    # the last returnType will be the return type of UDF
    if eval_type == PythonEvalType.SQL_SCALAR_PANDAS_UDF:
        wrapped = wrap_scalar_pandas_udf(func, return_type)
    elif eval_type in (PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF,
                       PythonEvalType.SQL_MAP_PANDAS_ITER_UDF):
        wrapped = wrap_pandas_iter_udf(func, return_type)
    elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF:
        # the signature was lost when wrapping, so inspect the unwrapped function
        argspec = _get_argspec(composed)
        wrapped = wrap_grouped_map_pandas_udf(func, return_type, argspec)
    elif eval_type == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF:
        # the signature was lost when wrapping, so inspect the unwrapped function
        argspec = _get_argspec(composed)
        wrapped = wrap_cogrouped_map_pandas_udf(func, return_type, argspec)
    elif eval_type == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF:
        wrapped = wrap_grouped_agg_pandas_udf(func, return_type)
    elif eval_type == PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF:
        wrapped = wrap_window_agg_pandas_udf(
            func, return_type, runner_conf, udf_index)
    elif eval_type == PythonEvalType.SQL_BATCHED_UDF:
        wrapped = wrap_udf(func, return_type)
    else:
        raise ValueError("Unknown eval type: {}".format(eval_type))
    return arg_offsets, wrapped, profiler
Esempio n. 20
0
def read_single_udf(pickleSer, infile, eval_type):
    """
    Read one UDF description from ``infile`` and wrap it for ``eval_type``.

    Reads the argument offsets, then a count of functions followed by one
    command per function; the functions are composed with ``chain`` in the
    order they are read.  Eval types other than the scalar and grouped-map
    pandas ones use ``wrap_udf``.

    :return: tuple ``(arg_offsets, wrapped_udf)``
    """
    offset_count = read_int(infile)
    arg_offsets = [read_int(infile) for _ in range(offset_count)]

    composed = None
    for _ in range(read_int(infile)):
        f, return_type = read_command(pickleSer, infile)
        composed = f if composed is None else chain(composed, f)

    # make sure StopIteration's raised in the user code are not ignored
    # when they are processed in a for loop, raise them as RuntimeError's instead
    func = fail_on_stopiteration(composed)

    # the return type read with the last command is the UDF's return type
    if eval_type == PythonEvalType.SQL_SCALAR_PANDAS_UDF:
        wrapped = wrap_scalar_pandas_udf(func, return_type)
    elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF:
        wrapped = wrap_grouped_map_pandas_udf(func, return_type)
    else:
        wrapped = wrap_udf(func, return_type)
    return arg_offsets, wrapped
Esempio n. 21
0
def read_single_udf(pickleSer, infile, eval_type):
    """
    Read one UDF description from ``infile`` and wrap it for ``eval_type``.

    Reads the argument offsets, then a count of functions followed by one
    command per function; the functions are composed with ``chain`` in the
    order they are read.

    :return: tuple ``(arg_offsets, wrapped_udf)``
    :raises ValueError: if ``eval_type`` is not one of the handled types
    """
    offset_count = read_int(infile)
    arg_offsets = [read_int(infile) for _ in range(offset_count)]

    composed = None
    for _ in range(read_int(infile)):
        f, return_type = read_command(pickleSer, infile)
        composed = f if composed is None else chain(composed, f)

    # the return type read with the last command is the UDF's return type
    if eval_type == PythonEvalType.SQL_SCALAR_PANDAS_UDF:
        wrapped = wrap_scalar_pandas_udf(composed, return_type)
    elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF:
        wrapped = wrap_grouped_map_pandas_udf(composed, return_type)
    elif eval_type == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF:
        wrapped = wrap_grouped_agg_pandas_udf(composed, return_type)
    elif eval_type == PythonEvalType.SQL_BATCHED_UDF:
        wrapped = wrap_udf(composed, return_type)
    else:
        raise ValueError("Unknown eval type: {}".format(eval_type))
    return arg_offsets, wrapped
Esempio n. 22
0
def read_single_udf(pickleSer, infile, eval_type):
    """
    Read one UDF description from ``infile`` and wrap it for ``eval_type``.

    Reads the argument offsets, then a count of functions followed by one
    command per function; the functions are composed with ``chain`` in the
    order they are read.

    :return: tuple ``(arg_offsets, wrapped_udf)``
    :raises ValueError: if ``eval_type`` is not one of the handled types
    """
    offset_count = read_int(infile)
    arg_offsets = [read_int(infile) for _ in range(offset_count)]

    composed = None
    for _ in range(read_int(infile)):
        f, return_type = read_command(pickleSer, infile)
        composed = f if composed is None else chain(composed, f)

    # the return type read with the last command is the UDF's return type
    if eval_type == PythonEvalType.SQL_SCALAR_PANDAS_UDF:
        wrapped = wrap_scalar_pandas_udf(composed, return_type)
    elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF:
        wrapped = wrap_grouped_map_pandas_udf(composed, return_type)
    elif eval_type == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF:
        wrapped = wrap_grouped_agg_pandas_udf(composed, return_type)
    elif eval_type == PythonEvalType.SQL_BATCHED_UDF:
        wrapped = wrap_udf(composed, return_type)
    else:
        raise ValueError("Unknown eval type: {}".format(eval_type))
    return arg_offsets, wrapped
Esempio n. 23
0
def main(infile, outfile):
    """
    Run one worker task: read the task setup from ``infile``, execute the
    function and write results, timing data and accumulator values to
    ``outfile``.

    A split index of -1 makes the function return immediately (used by unit
    tests).  If the function raises, ``-2`` and the formatted traceback are
    written and the process exits with status -1.  On success the accumulator
    section is written between two ``-1`` markers.
    """
    boot_time = time.time()
    split_index = read_int(infile)
    if split_index == -1:  # for unit tests
        return

    # The files directory is recorded on SparkFiles and made importable.
    files_dir = load_pickle(read_with_length(infile))
    SparkFiles._root_directory = files_dir
    SparkFiles._is_running_on_worker = True
    sys.path.append(files_dir)

    broadcast_count = read_int(infile)
    for _ in range(broadcast_count):
        bid = read_long(infile)
        pickled = read_with_length(infile)
        _broadcastRegistry[bid] = Broadcast(bid, load_pickle(pickled))

    func = load_obj(infile)
    # When the bypass flag is set, results are written as-is instead of pickled.
    dumps = (lambda x: x) if load_obj(infile) else dump_pickle
    init_time = time.time()

    iterator = read_from_pickle_file(infile)
    try:
        for result in func(split_index, iterator):
            write_with_length(dumps(result), outfile)
    except Exception:
        write_int(-2, outfile)
        write_with_length(traceback.format_exc(), outfile)
        sys.exit(-1)

    finish_time = time.time()
    report_times(outfile, boot_time, init_time, finish_time)

    # Mark the beginning of the accumulators section of the output
    write_int(-1, outfile)
    for aid, accum in _accumulatorRegistry.items():
        write_with_length(dump_pickle((aid, accum._value)), outfile)
    write_int(-1, outfile)
Esempio n. 24
0
def main(infile, outfile):
    """
    Run one worker task: read the task setup from ``infile``, execute the
    function and write results, timing data and accumulator values to
    ``outfile``.

    A split index of -1 makes the function return immediately (used by unit
    tests).  If the function raises, ``-2`` and the formatted traceback are
    written and the process exits with status -1.  On success the accumulator
    section is written between two ``-1`` markers.
    """
    boot_time = time.time()
    split_index = read_int(infile)
    if split_index == -1:  # for unit tests
        return

    # The files directory is recorded on SparkFiles and made importable.
    files_dir = load_pickle(read_with_length(infile))
    SparkFiles._root_directory = files_dir
    SparkFiles._is_running_on_worker = True
    sys.path.append(files_dir)

    broadcast_count = read_int(infile)
    for _ in range(broadcast_count):
        bid = read_long(infile)
        pickled = read_with_length(infile)
        _broadcastRegistry[bid] = Broadcast(bid, load_pickle(pickled))

    func = load_obj(infile)
    # When the bypass flag is set, results are written as-is instead of pickled.
    dumps = (lambda x: x) if load_obj(infile) else dump_pickle
    init_time = time.time()

    iterator = read_from_pickle_file(infile)
    try:
        for result in func(split_index, iterator):
            write_with_length(dumps(result), outfile)
    except Exception:
        write_int(-2, outfile)
        write_with_length(traceback.format_exc(), outfile)
        sys.exit(-1)

    finish_time = time.time()
    report_times(outfile, boot_time, init_time, finish_time)

    # Mark the beginning of the accumulators section of the output
    write_int(-1, outfile)
    for aid, accum in _accumulatorRegistry.items():
        write_with_length(dump_pickle((aid, accum._value)), outfile)
    write_int(-1, outfile)
Esempio n. 25
0
def read_udfs(pickleSer, infile, eval_type):
    """
    Read all UDFs for a task and build the per-row mapper plus its serializer.

    The mapper is generated as a lambda source string and compiled with
    ``eval``, with the wrapped UDFs supplied as its globals.  For the pandas
    eval types a timezone string is read from ``infile`` after the UDFs and
    an ``ArrowStreamPandasSerializer`` is used; otherwise a batched pickle
    serializer is used.

    :return: tuple ``(func, None, ser, ser)``; the ``None`` slot is the
        profiler, which is not supported for UDFs
    """
    num_udfs = read_int(infile)
    udfs = {}
    if eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF:
        # Generated function looks like:
        #   lambda a: f([a[0]], [a[0], a[1]])

        # Grouped map doesn't support combining multiple UDFs, so exactly
        # one is expected here.
        assert num_udfs == 1

        # See FlatMapGroupsInPandasExec for how arg_offsets are used to
        # distinguish between grouping attributes and data attributes:
        # arg_offsets[0] gives the length of the first offset list.
        arg_offsets, udf = read_single_udf(pickleSer, infile, eval_type)
        udfs['f'] = udf
        split_offset = arg_offsets[0] + 1
        first_args = ", ".join("a[%d]" % o for o in arg_offsets[1:split_offset])
        second_args = ", ".join("a[%d]" % o for o in arg_offsets[split_offset:])
        mapper_str = "lambda a: f([%s], [%s])" % (first_args, second_args)
    else:
        # Generated function looks like:
        #   lambda a: (f0(a[0]), f1(a[1], a[2]), f2(a[3]))
        # In the special case of a single UDF this will return a single result rather
        # than a tuple of results; this is the format that the JVM side expects.
        calls = []
        for i in range(num_udfs):
            arg_offsets, udf = read_single_udf(pickleSer, infile, eval_type)
            udfs['f%d' % i] = udf
            call_args = ", ".join("a[%d]" % o for o in arg_offsets)
            calls.append("f%d(%s)" % (i, call_args))
        mapper_str = "lambda a: (%s)" % ", ".join(calls)

    mapper = eval(mapper_str, udfs)

    def func(_, it):
        return map(mapper, it)

    arrow_eval_types = (PythonEvalType.SQL_SCALAR_PANDAS_UDF,
                        PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
                        PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF,
                        PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF)
    if eval_type in arrow_eval_types:
        timezone = utf8_deserializer.loads(infile)
        ser = ArrowStreamPandasSerializer(timezone)
    else:
        ser = BatchedSerializer(PickleSerializer(), 100)

    # profiling is not supported for UDF
    return func, None, ser, ser
Esempio n. 26
0
def read_udfs(pickleSer, infile):
    """
    Read all UDFs for a task and build the per-row mapper plus its serializer.

    The mapper is generated as a lambda source string and compiled with
    ``eval``, with the wrapped UDFs supplied as its globals.

    :return: tuple ``(func, None, ser, ser)``; the ``None`` slot is the
        profiler, which is not supported for UDFs
    """
    num_udfs = read_int(infile)
    udfs = {}
    calls = []
    for i in range(num_udfs):
        arg_offsets, udf = read_single_udf(pickleSer, infile)
        udfs['f%d' % i] = udf
        call_args = ", ".join("a[%d]" % o for o in arg_offsets)
        calls.append("f%d(%s)" % (i, call_args))

    # Generated function looks like:
    #   lambda a: (f0(a0), f1(a1, a2), f2(a3))
    # In the special case of a single UDF this will return a single result rather
    # than a tuple of results; this is the format that the JVM side expects.
    mapper = eval("lambda a: (%s)" % ", ".join(calls), udfs)

    def func(_, it):
        return map(mapper, it)

    ser = BatchedSerializer(PickleSerializer(), 100)
    # profiling is not supported for UDF
    return func, None, ser, ser
Esempio n. 27
0
def read_udfs(pickleSer, infile):
    """
    Read all UDFs for a task and build the per-row mapper plus its serializer.

    The mapper is generated as a lambda source string and compiled with
    ``eval``, with the wrapped UDFs supplied as its globals.

    :return: tuple ``(func, None, ser, ser)``; the ``None`` slot is the
        profiler, which is not supported for UDFs
    """
    num_udfs = read_int(infile)
    udfs = {}
    calls = []
    for i in range(num_udfs):
        arg_offsets, udf = read_single_udf(pickleSer, infile)
        udfs['f%d' % i] = udf
        call_args = ", ".join("a[%d]" % o for o in arg_offsets)
        calls.append("f%d(%s)" % (i, call_args))

    # Generated function looks like:
    #   lambda a: (f0(a0), f1(a1, a2), f2(a3))
    # In the special case of a single UDF this will return a single result rather
    # than a tuple of results; this is the format that the JVM side expects.
    mapper = eval("lambda a: (%s)" % ", ".join(calls), udfs)

    def func(_, it):
        return map(mapper, it)

    ser = BatchedSerializer(PickleSerializer(), 100)
    # profiling is not supported for UDF
    return func, None, ser, ser
Esempio n. 28
0
    def do_termination_test(self, terminator):
        """
        Start ``daemon.py``, shut it down with ``terminator`` and check that
        its port then refuses connections.

        :param terminator: callable invoked with the daemon ``Popen`` object
            to request the shutdown
        """
        import errno
        import subprocess

        # start daemon
        script = os.path.join(os.path.dirname(__file__), "daemon.py")
        daemon = subprocess.Popen(
            [sys.executable, script], stdin=subprocess.PIPE, stdout=subprocess.PIPE)

        # the daemon reports its port number on stdout
        port = read_int(daemon.stdout)

        # daemon should accept connections
        self.assertTrue(self.connect(port))

        # request shutdown and give the daemon a moment to act on it
        terminator(daemon)
        time.sleep(1)

        # daemon should no longer accept connections
        with self.assertRaises(EnvironmentError) as trap:
            self.connect(port)
        self.assertEqual(trap.exception.errno, errno.ECONNREFUSED)
Esempio n. 29
0
    def load_stream(self, stream):
        """
        Deserialize Cogrouped ArrowRecordBatches to a tuple of Arrow tables and yield as two
        lists of pandas.Series.

        Each group starts with an int giving the number of dataframes in it:
        ``2`` is followed by two Arrow streams, ``0`` ends the iteration.

        :raises ValueError: if the number read is neither 0 nor 2
        """
        import pyarrow as pa

        def as_series_list(batches):
            # one converted entry per column of the table built from the batches
            table = pa.Table.from_batches(batches)
            return [self.arrow_to_pandas(c) for c in table.itercolumns()]

        while True:
            group_size = read_int(stream)
            if group_size == 0:
                break
            if group_size != 2:
                raise ValueError(
                    'Invalid number of pandas.DataFrames in group {0}'.format(group_size))

            # both streams are read fully before either is converted
            first = list(ArrowStreamSerializer.load_stream(self, stream))
            second = list(ArrowStreamSerializer.load_stream(self, stream))
            yield (as_series_list(first), as_series_list(second))
Esempio n. 30
0
def _load_from_socket(
    port: Optional[Union[str, int]],
    auth_secret: str,
    function: int,
    all_gather_message: Optional[str] = None,
) -> List[str]:
    """
    Send a barrier()/allGather() request over a local socket and collect the reply.

    This is a blocking call: the socket timeout is disabled, so it returns only once
    the full response has been read.

    Parameters
    ----------
    port : str or int, optional
        Passed to ``local_connect_and_auth`` to open the connection.
    auth_secret : str
        Passed to ``local_connect_and_auth`` to authenticate the connection.
    function : int
        Either ``BARRIER_FUNCTION`` or ``ALL_GATHER_FUNCTION``.
    all_gather_message : str, optional
        Message sent (UTF-8 encoded) after the function code for ``ALL_GATHER_FUNCTION``.

    Returns
    -------
    list of str
        The strings read from the socket; their count is sent first as an int.

    Raises
    ------
    ValueError
        If ``function`` is not one of the two recognized function codes.
    """
    (sockfile, sock) = local_connect_and_auth(port, auth_secret)

    # Everything after connecting runs under try/finally so the socket and its file
    # wrapper are released even when the request is invalid or the read fails.
    try:
        # The call may block forever, so no timeout
        sock.settimeout(None)

        if function == BARRIER_FUNCTION:
            # Make a barrier() function call.
            write_int(function, sockfile)
        elif function == ALL_GATHER_FUNCTION:
            # Make a all_gather() function call.
            write_int(function, sockfile)
            write_with_length(
                cast(str, all_gather_message).encode("utf-8"), sockfile)
        else:
            raise ValueError("Unrecognized function type")
        sockfile.flush()

        # Collect result: a count followed by that many UTF-8 strings.
        num_results = read_int(sockfile)
        deserializer = UTF8Deserializer()
        return [deserializer.loads(sockfile) for _ in range(num_results)]
    finally:
        # Release resources; close the socket even if closing the file fails.
        try:
            sockfile.close()
        finally:
            sock.close()
Esempio n. 31
0
def read_udfs(pickleSer, infile):
    """
    Read the UDFs from ``infile`` and combine them into one per-row mapper.

    Returns a 4-tuple ``(func, None, ser, ser)``: ``func(_, it)`` maps the combined
    mapper over the iterator ``it``, the second element is ``None`` because profiling
    is not supported for UDFs, and ``ser`` is a batched pickle serializer used for
    both input and output.
    """
    udf_count = read_int(infile)
    if udf_count == 1:
        # Fast path for a single UDF: the whole row is unpacked as its arguments.
        _, udf = read_single_udf(pickleSer, infile)

        def mapper(a):
            return udf(*a)
    else:
        namespace = {}
        calls = []
        for idx in range(udf_count):
            offsets, fn = read_single_udf(pickleSer, infile)
            namespace['f%d' % idx] = fn
            arg_list = ", ".join("a[%d]" % off for off in offsets)
            calls.append("f%d(%s)" % (idx, arg_list))
        # Build a function of this shape, evaluated with the UDFs as its globals:
        #   lambda a: (f0(a[0]), f1(a[1], a[2]), f2(a[3]))
        mapper = eval("lambda a: (%s)" % ", ".join(calls), namespace)

    def func(_, it):
        return map(mapper, it)

    ser = BatchedSerializer(PickleSerializer(), 100)
    # profiling is not supported for UDF
    return func, None, ser, ser
Esempio n. 32
0
def read_udfs(pickleSer, infile):
    """
    Read the UDFs from ``infile`` and combine them into one per-row mapper.

    Returns a 4-tuple ``(func, None, ser, ser)``: ``func(_, it)`` maps the combined
    mapper over the iterator ``it``, the second element is ``None`` because profiling
    is not supported for UDFs, and ``ser`` is a batched pickle serializer used for
    both input and output.
    """
    udf_count = read_int(infile)
    if udf_count == 1:
        # Fast path for a single UDF: the whole row is unpacked as its arguments.
        _, udf = read_single_udf(pickleSer, infile)

        def mapper(a):
            return udf(*a)
    else:
        namespace = {}
        calls = []
        for idx in range(udf_count):
            offsets, fn = read_single_udf(pickleSer, infile)
            namespace['f%d' % idx] = fn
            arg_list = ", ".join("a[%d]" % off for off in offsets)
            calls.append("f%d(%s)" % (idx, arg_list))
        # Build a function of this shape, evaluated with the UDFs as its globals:
        #   lambda a: (f0(a[0]), f1(a[1], a[2]), f2(a[3]))
        mapper = eval("lambda a: (%s)" % ", ".join(calls), namespace)

    def func(_, it):
        return map(mapper, it)

    ser = BatchedSerializer(PickleSerializer(), 100)
    # profiling is not supported for UDF
    return func, None, ser, ser
Esempio n. 33
0
def launch_gateway(conf=None, popen_kwargs=None):
    """
    launch jvm gateway

    If ``PYSPARK_GATEWAY_PORT`` is set in the environment, connect to that existing
    gateway (using ``PYSPARK_GATEWAY_SECRET``). Otherwise spawn ``spark-submit`` and
    wait for it to write its port and secret to a temporary connection-info file.

    Parameters
    ----------
    conf : :py:class:`pyspark.SparkConf`
        spark configuration passed to spark-submit
    popen_kwargs : dict
        Dictionary of kwargs to pass to Popen when spawning
        the py4j JVM. This is a developer feature intended for use in
        customizing how pyspark interacts with the py4j JVM (e.g., capturing
        stdout/stderr). The dictionary is copied; the caller's object is not modified.

    Returns
    -------
    ClientServer or JavaGateway

    Raises
    ------
    RuntimeError
        If the spawned process exits before the connection-info file appears.
    """
    if "PYSPARK_GATEWAY_PORT" in os.environ:
        gateway_port = int(os.environ["PYSPARK_GATEWAY_PORT"])
        gateway_secret = os.environ["PYSPARK_GATEWAY_SECRET"]
        # Process already exists
        proc = None
    else:
        SPARK_HOME = _find_spark_home()
        # Launch the Py4j gateway using Spark's run command so that we pick up the
        # proper classpath and settings from spark-env.sh
        on_windows = platform.system() == "Windows"
        script = "./bin/spark-submit.cmd" if on_windows else "./bin/spark-submit"
        command = [os.path.join(SPARK_HOME, script)]
        if conf:
            for k, v in conf.getAll():
                command += ["--conf", "%s=%s" % (k, v)]
        submit_args = os.environ.get("PYSPARK_SUBMIT_ARGS", "pyspark-shell")
        if os.environ.get("SPARK_TESTING"):
            submit_args = " ".join(
                ["--conf spark.ui.enabled=false", submit_args])
        command = command + shlex.split(submit_args)

        # Create a temporary directory where the gateway server should write the connection
        # information.
        conn_info_dir = tempfile.mkdtemp()
        try:
            # Reserve a unique file name, then remove the file so that its later
            # appearance signals that the gateway has written its connection info.
            fd, conn_info_file = tempfile.mkstemp(dir=conn_info_dir)
            os.close(fd)
            os.unlink(conn_info_file)

            env = dict(os.environ)
            env["_PYSPARK_DRIVER_CONN_INFO_PATH"] = conn_info_file

            # Launch the Java gateway.
            # Work on a copy so the caller's dictionary is never mutated.
            popen_kwargs = {} if popen_kwargs is None else dict(popen_kwargs)
            # We open a pipe to stdin so that the Java gateway can die when the pipe is broken
            popen_kwargs["stdin"] = PIPE
            # We always set the necessary environment variables.
            popen_kwargs["env"] = env
            if not on_windows:
                # Don't send ctrl-c / SIGINT to the Java gateway:
                def preexec_func():
                    signal.signal(signal.SIGINT, signal.SIG_IGN)

                # preexec_fn is not supported on Windows, so it is only set here.
                popen_kwargs["preexec_fn"] = preexec_func
            proc = Popen(command, **popen_kwargs)

            # Wait for the file to appear, or for the process to exit, whichever happens first.
            # poll() returns None while the process is running; comparing against None
            # (rather than truthiness) also stops waiting when the process exits with code 0.
            while proc.poll() is None and not os.path.isfile(conn_info_file):
                time.sleep(0.1)

            if not os.path.isfile(conn_info_file):
                raise RuntimeError(
                    "Java gateway process exited before sending its port number"
                )

            with open(conn_info_file, "rb") as info:
                gateway_port = read_int(info)
                gateway_secret = UTF8Deserializer().loads(info)
        finally:
            shutil.rmtree(conn_info_dir)

        # In Windows, ensure the Java child processes do not linger after Python has exited.
        # In UNIX-based systems, the child process can kill itself on broken pipe (i.e. when
        # the parent process' stdin sends an EOF). In Windows, however, this is not possible
        # because java.lang.Process reads directly from the parent process' stdin, contending
        # with any opportunity to read an EOF from the parent. Note that this is only best
        # effort and will not take effect if the python process is violently terminated.
        if on_windows:
            # In Windows, the child process here is "spark-submit.cmd", not the JVM itself
            # (because the UNIX "exec" command is not available). This means we cannot simply
            # call proc.kill(), which kills only the "spark-submit.cmd" process but not the
            # JVMs. Instead, we use "taskkill" with the tree-kill option "/t" to terminate all
            # child processes in the tree (http://technet.microsoft.com/en-us/library/bb491009.aspx)
            def killChild():
                Popen([
                    "cmd", "/c", "taskkill", "/f", "/t", "/pid",
                    str(proc.pid)
                ])

            atexit.register(killChild)

    # Connect to the gateway (or client server to pin the thread between JVM and Python)
    if os.environ.get("PYSPARK_PIN_THREAD", "true").lower() == "true":
        gateway = ClientServer(
            java_parameters=JavaParameters(port=gateway_port,
                                           auth_token=gateway_secret,
                                           auto_convert=True),
            python_parameters=PythonParameters(port=0, eager_load=False),
        )
    else:
        gateway = JavaGateway(gateway_parameters=GatewayParameters(
            port=gateway_port, auth_token=gateway_secret, auto_convert=True))

    # Store a reference to the Popen object for use by the caller (e.g., in reading stdout/stderr)
    gateway.proc = proc

    # Import the classes used by PySpark
    java_import(gateway.jvm, "org.apache.spark.SparkConf")
    java_import(gateway.jvm, "org.apache.spark.api.java.*")
    java_import(gateway.jvm, "org.apache.spark.api.python.*")
    java_import(gateway.jvm, "org.apache.spark.ml.python.*")
    java_import(gateway.jvm, "org.apache.spark.mllib.api.python.*")
    java_import(gateway.jvm, "org.apache.spark.resource.*")
    # TODO(davies): move into sql
    java_import(gateway.jvm, "org.apache.spark.sql.*")
    java_import(gateway.jvm, "org.apache.spark.sql.api.python.*")
    java_import(gateway.jvm, "org.apache.spark.sql.hive.*")
    java_import(gateway.jvm, "scala.Tuple2")

    return gateway
Esempio n. 34
0
def main(infile, outfile):
    """
    Run a single task: read its setup and input from `infile`, write results to `outfile`.

    The setup is read in a fixed order: split index, driver Python version, task
    context fields, the files directory, Python includes, broadcast variable updates
    and finally the eval type followed by the command or the UDFs. After processing,
    timing information, spill counters and accumulator values are written, followed
    by an end-of-stream handshake.

    Any exception raised during setup or processing is written to `outfile` as a
    traceback (falling back to stderr if that write itself fails) and the process
    exits with -1. The process also exits with -1 when the split index is -1 or when
    the end-of-stream marker is not received.
    """
    try:
        boot_time = time.time()
        split_index = read_int(infile)
        if split_index == -1:  # for unit tests
            sys.exit(-1)

        version = utf8_deserializer.loads(infile)
        if version != "%d.%d" % sys.version_info[:2]:
            # Note the trailing space after "versions." so the two sentences do not run together.
            raise Exception(
                ("Python in worker has different version %s than that in " +
                 "driver %s, PySpark cannot run with different minor versions. " +
                 "Please check environment variables PYSPARK_PYTHON and " +
                 "PYSPARK_DRIVER_PYTHON are correctly set.") %
                ("%d.%d" % sys.version_info[:2], version))

        # initialize global state
        taskContext = TaskContext._getOrCreate()
        taskContext._stageId = read_int(infile)
        taskContext._partitionId = read_int(infile)
        taskContext._attemptNumber = read_int(infile)
        taskContext._taskAttemptId = read_long(infile)
        shuffle.MemoryBytesSpilled = 0
        shuffle.DiskBytesSpilled = 0
        _accumulatorRegistry.clear()

        # fetch name of workdir
        spark_files_dir = utf8_deserializer.loads(infile)
        SparkFiles._root_directory = spark_files_dir
        SparkFiles._is_running_on_worker = True

        # fetch names of includes (*.zip and *.egg files) and construct PYTHONPATH
        add_path(spark_files_dir)  # *.py files that were added will be copied here
        num_python_includes = read_int(infile)
        for _ in range(num_python_includes):
            filename = utf8_deserializer.loads(infile)
            add_path(os.path.join(spark_files_dir, filename))
        # Compare the structured version tuple rather than the free-form version string.
        if sys.version_info[0] >= 3:
            import importlib
            importlib.invalidate_caches()

        # fetch names and values of broadcast variables
        num_broadcast_variables = read_int(infile)
        for _ in range(num_broadcast_variables):
            bid = read_long(infile)
            if bid >= 0:
                path = utf8_deserializer.loads(infile)
                _broadcastRegistry[bid] = Broadcast(path=path)
            else:
                # A negative id encodes the removal of broadcast (-bid - 1).
                bid = -bid - 1
                _broadcastRegistry.pop(bid)

        _accumulatorRegistry.clear()
        eval_type = read_int(infile)
        if eval_type == PythonEvalType.NON_UDF:
            func, profiler, deserializer, serializer = read_command(
                pickleSer, infile)
        else:
            func, profiler, deserializer, serializer = read_udfs(
                pickleSer, infile, eval_type)

        init_time = time.time()

        def process():
            iterator = deserializer.load_stream(infile)
            serializer.dump_stream(func(split_index, iterator), outfile)

        if profiler:
            profiler.profile(process)
        else:
            process()
    except Exception:
        try:
            write_int(SpecialLengths.PYTHON_EXCEPTION_THROWN, outfile)
            write_with_length(traceback.format_exc().encode("utf-8"), outfile)
        except IOError:
            # JVM close the socket
            pass
        except Exception:
            # Write the error to stderr if it happened while serializing
            print("PySpark worker failed with exception:", file=sys.stderr)
            print(traceback.format_exc(), file=sys.stderr)
        sys.exit(-1)
    finish_time = time.time()
    report_times(outfile, boot_time, init_time, finish_time)
    write_long(shuffle.MemoryBytesSpilled, outfile)
    write_long(shuffle.DiskBytesSpilled, outfile)

    # Mark the beginning of the accumulators section of the output
    write_int(SpecialLengths.END_OF_DATA_SECTION, outfile)
    write_int(len(_accumulatorRegistry), outfile)
    for (aid, accum) in _accumulatorRegistry.items():
        pickleSer._write_with_length((aid, accum._value), outfile)

    # check end of stream
    if read_int(infile) == SpecialLengths.END_OF_STREAM:
        write_int(SpecialLengths.END_OF_STREAM, outfile)
    else:
        # write a different value to tell JVM to not reuse this worker
        write_int(SpecialLengths.END_OF_DATA_SECTION, outfile)
        sys.exit(-1)
Esempio n. 35
0
def main(infile, outfile):
    """
    Run a single task: read its setup and input from `infile`, write results to `outfile`.

    The setup is read in a fixed order: split index, driver Python version, barrier
    task inputs, task context fields and local properties, the files directory,
    Python includes, broadcast variable updates (optionally fetched through a
    decryption server socket) and finally the eval type followed by the command or
    the UDFs. If ``PYSPARK_EXECUTOR_MEMORY_MB`` is a positive number and the
    ``resource`` module is available, an address-space limit is applied first.

    After processing, timing information, spill counters and accumulator values are
    written, followed by an end-of-stream handshake.

    Any exception raised during setup or processing is written to `outfile` as a
    traceback (falling back to stderr if that write itself fails) and the process
    exits with -1. The process also exits with -1 when the split index is -1 or when
    the end-of-stream marker is not received.
    """
    try:
        boot_time = time.time()
        split_index = read_int(infile)
        if split_index == -1:  # for unit tests
            sys.exit(-1)

        version = utf8_deserializer.loads(infile)
        if version != "%d.%d" % sys.version_info[:2]:
            # Note the trailing space after "versions." so the two sentences do not run together.
            raise Exception(("Python in worker has different version %s than that in " +
                             "driver %s, PySpark cannot run with different minor versions. " +
                             "Please check environment variables PYSPARK_PYTHON and " +
                             "PYSPARK_DRIVER_PYTHON are correctly set.") %
                            ("%d.%d" % sys.version_info[:2], version))

        # read inputs only for a barrier task
        isBarrier = read_bool(infile)
        boundPort = read_int(infile)
        secret = UTF8Deserializer().loads(infile)

        # set up memory limits
        memory_limit_mb = int(os.environ.get('PYSPARK_EXECUTOR_MEMORY_MB', "-1"))
        if memory_limit_mb > 0 and has_resource_module:
            total_memory = resource.RLIMIT_AS
            try:
                (soft_limit, hard_limit) = resource.getrlimit(total_memory)
                msg = "Current mem limits: {0} of max {1}\n".format(soft_limit, hard_limit)
                print(msg, file=sys.stderr)

                # convert to bytes
                new_limit = memory_limit_mb * 1024 * 1024

                # Only ever lower the limit: apply it when none is set or ours is stricter.
                if soft_limit == resource.RLIM_INFINITY or new_limit < soft_limit:
                    msg = "Setting mem limits to {0} of max {1}\n".format(new_limit, new_limit)
                    print(msg, file=sys.stderr)
                    resource.setrlimit(total_memory, (new_limit, new_limit))

            except (resource.error, OSError, ValueError) as e:
                # not all systems support resource limits, so warn instead of failing
                print("WARN: Failed to set memory limit: {0}\n".format(e), file=sys.stderr)

        # initialize global state
        taskContext = None
        if isBarrier:
            taskContext = BarrierTaskContext._getOrCreate()
            BarrierTaskContext._initialize(boundPort, secret)
        else:
            taskContext = TaskContext._getOrCreate()
        # read inputs for TaskContext info
        taskContext._stageId = read_int(infile)
        taskContext._partitionId = read_int(infile)
        taskContext._attemptNumber = read_int(infile)
        taskContext._taskAttemptId = read_long(infile)
        taskContext._localProperties = dict()
        for i in range(read_int(infile)):
            k = utf8_deserializer.loads(infile)
            v = utf8_deserializer.loads(infile)
            taskContext._localProperties[k] = v

        shuffle.MemoryBytesSpilled = 0
        shuffle.DiskBytesSpilled = 0
        _accumulatorRegistry.clear()

        # fetch name of workdir
        spark_files_dir = utf8_deserializer.loads(infile)
        SparkFiles._root_directory = spark_files_dir
        SparkFiles._is_running_on_worker = True

        # fetch names of includes (*.zip and *.egg files) and construct PYTHONPATH
        add_path(spark_files_dir)  # *.py files that were added will be copied here
        num_python_includes = read_int(infile)
        for _ in range(num_python_includes):
            filename = utf8_deserializer.loads(infile)
            add_path(os.path.join(spark_files_dir, filename))
        # Compare the structured version tuple rather than the free-form version string.
        if sys.version_info[0] >= 3:
            import importlib
            importlib.invalidate_caches()

        # fetch names and values of broadcast variables
        needs_broadcast_decryption_server = read_bool(infile)
        num_broadcast_variables = read_int(infile)
        if needs_broadcast_decryption_server:
            # read the decrypted data from a server in the jvm
            port = read_int(infile)
            auth_secret = utf8_deserializer.loads(infile)
            (broadcast_sock_file, _) = local_connect_and_auth(port, auth_secret)

        for _ in range(num_broadcast_variables):
            bid = read_long(infile)
            if bid >= 0:
                if needs_broadcast_decryption_server:
                    # The id read from the server socket must match the one from infile.
                    read_bid = read_long(broadcast_sock_file)
                    assert(read_bid == bid)
                    _broadcastRegistry[bid] = \
                        Broadcast(sock_file=broadcast_sock_file)
                else:
                    path = utf8_deserializer.loads(infile)
                    _broadcastRegistry[bid] = Broadcast(path=path)

            else:
                # A negative id encodes the removal of broadcast (-bid - 1).
                bid = - bid - 1
                _broadcastRegistry.pop(bid)

        if needs_broadcast_decryption_server:
            # Write a single byte to the server socket before closing it.
            broadcast_sock_file.write(b'1')
            broadcast_sock_file.close()

        _accumulatorRegistry.clear()
        eval_type = read_int(infile)
        if eval_type == PythonEvalType.NON_UDF:
            func, profiler, deserializer, serializer = read_command(pickleSer, infile)
        else:
            func, profiler, deserializer, serializer = read_udfs(pickleSer, infile, eval_type)

        init_time = time.time()

        def process():
            iterator = deserializer.load_stream(infile)
            serializer.dump_stream(func(split_index, iterator), outfile)

        if profiler:
            profiler.profile(process)
        else:
            process()
    except Exception:
        try:
            write_int(SpecialLengths.PYTHON_EXCEPTION_THROWN, outfile)
            write_with_length(traceback.format_exc().encode("utf-8"), outfile)
        except IOError:
            # JVM close the socket
            pass
        except Exception:
            # Write the error to stderr if it happened while serializing
            print("PySpark worker failed with exception:", file=sys.stderr)
            print(traceback.format_exc(), file=sys.stderr)
        sys.exit(-1)
    finish_time = time.time()
    report_times(outfile, boot_time, init_time, finish_time)
    write_long(shuffle.MemoryBytesSpilled, outfile)
    write_long(shuffle.DiskBytesSpilled, outfile)

    # Mark the beginning of the accumulators section of the output
    write_int(SpecialLengths.END_OF_DATA_SECTION, outfile)
    write_int(len(_accumulatorRegistry), outfile)
    for (aid, accum) in _accumulatorRegistry.items():
        pickleSer._write_with_length((aid, accum._value), outfile)

    # check end of stream
    if read_int(infile) == SpecialLengths.END_OF_STREAM:
        write_int(SpecialLengths.END_OF_STREAM, outfile)
    else:
        # write a different value to tell JVM to not reuse this worker
        write_int(SpecialLengths.END_OF_DATA_SECTION, outfile)
        sys.exit(-1)
Esempio n. 36
0
def main(infile, outfile):
    """
    Run a single task: read its setup and input from `infile`, write results to `outfile`.

    The setup is read in a fixed order: split index, the files directory, Python
    includes, broadcast variable updates and finally the pickled command (which may
    itself be wrapped in a Broadcast). After processing, timing information, spill
    counters and accumulator values are written.

    Returns immediately when the split index is -1 (used by unit tests). Any exception
    raised during setup or processing is written to `outfile` as a traceback (falling
    back to stderr if that write itself fails) and the process exits with -1.
    """
    try:
        boot_time = time.time()
        split_index = read_int(infile)
        if split_index == -1:  # for unit tests
            return

        # initialize global state
        shuffle.MemoryBytesSpilled = 0
        shuffle.DiskBytesSpilled = 0
        _accumulatorRegistry.clear()

        # fetch name of workdir
        spark_files_dir = utf8_deserializer.loads(infile)
        SparkFiles._root_directory = spark_files_dir
        SparkFiles._is_running_on_worker = True

        # fetch names of includes (*.zip and *.egg files) and construct PYTHONPATH
        sys.path.append(spark_files_dir)  # *.py files that were added will be copied here
        num_python_includes = read_int(infile)
        for _ in range(num_python_includes):
            filename = utf8_deserializer.loads(infile)
            sys.path.append(os.path.join(spark_files_dir, filename))

        # fetch names and values of broadcast variables
        num_broadcast_variables = read_int(infile)
        ser = CompressedSerializer(pickleSer)
        for _ in range(num_broadcast_variables):
            bid = read_long(infile)
            if bid >= 0:
                value = ser._read_with_length(infile)
                _broadcastRegistry[bid] = Broadcast(bid, value)
            else:
                # A negative id encodes the removal of broadcast (-bid - 1).
                bid = -bid - 1
                _broadcastRegistry.pop(bid)

        _accumulatorRegistry.clear()
        command = pickleSer._read_with_length(infile)
        if isinstance(command, Broadcast):
            # The command was shipped as a broadcast; unpickle its value.
            command = pickleSer.loads(command.value)
        (func, deserializer, serializer) = command
        init_time = time.time()
        iterator = deserializer.load_stream(infile)
        serializer.dump_stream(func(split_index, iterator), outfile)
    except Exception:
        try:
            write_int(SpecialLengths.PYTHON_EXCEPTION_THROWN, outfile)
            write_with_length(traceback.format_exc(), outfile)
            outfile.flush()
        except IOError:
            # JVM close the socket
            pass
        except Exception:
            # Write the error to stderr if it happened while serializing
            print >> sys.stderr, "PySpark worker failed with exception:"
            print >> sys.stderr, traceback.format_exc()
        # sys.exit rather than the site-provided exit(), which is meant for
        # interactive use and is not guaranteed to be defined.
        sys.exit(-1)
    finish_time = time.time()
    report_times(outfile, boot_time, init_time, finish_time)
    write_long(shuffle.MemoryBytesSpilled, outfile)
    write_long(shuffle.DiskBytesSpilled, outfile)

    # Mark the beginning of the accumulators section of the output
    write_int(SpecialLengths.END_OF_DATA_SECTION, outfile)
    write_int(len(_accumulatorRegistry), outfile)
    for (aid, accum) in _accumulatorRegistry.items():
        pickleSer._write_with_length((aid, accum._value), outfile)
Esempio n. 37
0
def main(infile, outfile):
    """
    Run a single task: read its setup and input from `infile`, write results to `outfile`.

    The setup is read in a fixed order: split index, the files directory, Python
    includes, broadcast variable updates (each value is a size-prefixed stream) and
    finally the pickled command (which may itself be wrapped in a Broadcast). When the
    command carries a truthy `stats` object, processing runs under cProfile and the
    stripped statistics are added to it. After processing, timing information, spill
    counters and accumulator values are written, followed by an end-of-stream handshake.

    Any exception raised during setup or processing is written to `outfile` as a
    traceback (falling back to stderr if that write itself fails) and the process
    exits with -1. The process also exits with -1 when the split index is -1 or when
    the end-of-stream marker is not received.
    """
    try:
        boot_time = time.time()
        split_index = read_int(infile)
        if split_index == -1:  # for unit tests
            # sys.exit rather than the site-provided exit(), which is meant for
            # interactive use and is not guaranteed to be defined.
            sys.exit(-1)

        # initialize global state
        shuffle.MemoryBytesSpilled = 0
        shuffle.DiskBytesSpilled = 0
        _accumulatorRegistry.clear()

        # fetch name of workdir
        spark_files_dir = utf8_deserializer.loads(infile)
        SparkFiles._root_directory = spark_files_dir
        SparkFiles._is_running_on_worker = True

        # fetch names of includes (*.zip and *.egg files) and construct PYTHONPATH
        add_path(spark_files_dir)  # *.py files that were added will be copied here
        num_python_includes = read_int(infile)
        for _ in range(num_python_includes):
            filename = utf8_deserializer.loads(infile)
            add_path(os.path.join(spark_files_dir, filename))

        # fetch names and values of broadcast variables
        num_broadcast_variables = read_int(infile)
        bser = LargeObjectSerializer()
        for _ in range(num_broadcast_variables):
            bid = read_long(infile)
            if bid >= 0:
                size = read_long(infile)
                s = SizeLimitedStream(infile, size)
                value = list((bser.load_stream(s)))[0]  # read out all the bytes
                _broadcastRegistry[bid] = Broadcast(bid, value)
            else:
                # A negative id encodes the removal of broadcast (-bid - 1).
                bid = -bid - 1
                _broadcastRegistry.pop(bid)

        _accumulatorRegistry.clear()
        command = pickleSer._read_with_length(infile)
        if isinstance(command, Broadcast):
            # The command was shipped as a broadcast; unpickle its value.
            command = pickleSer.loads(command.value)
        (func, stats, deserializer, serializer) = command
        init_time = time.time()

        def process():
            iterator = deserializer.load_stream(infile)
            serializer.dump_stream(func(split_index, iterator), outfile)

        if stats:
            p = cProfile.Profile()
            p.runcall(process)
            st = pstats.Stats(p)
            st.stream = None  # make it picklable
            stats.add(st.strip_dirs())
        else:
            process()
    except Exception:
        try:
            write_int(SpecialLengths.PYTHON_EXCEPTION_THROWN, outfile)
            write_with_length(traceback.format_exc(), outfile)
        except IOError:
            # JVM close the socket
            pass
        except Exception:
            # Write the error to stderr if it happened while serializing
            print >> sys.stderr, "PySpark worker failed with exception:"
            print >> sys.stderr, traceback.format_exc()
        sys.exit(-1)
    finish_time = time.time()
    report_times(outfile, boot_time, init_time, finish_time)
    write_long(shuffle.MemoryBytesSpilled, outfile)
    write_long(shuffle.DiskBytesSpilled, outfile)

    # Mark the beginning of the accumulators section of the output
    write_int(SpecialLengths.END_OF_DATA_SECTION, outfile)
    write_int(len(_accumulatorRegistry), outfile)
    for (aid, accum) in _accumulatorRegistry.items():
        pickleSer._write_with_length((aid, accum._value), outfile)

    # check end of stream
    if read_int(infile) == SpecialLengths.END_OF_STREAM:
        write_int(SpecialLengths.END_OF_STREAM, outfile)
    else:
        # write a different value to tell JVM to not reuse this worker
        write_int(SpecialLengths.END_OF_DATA_SECTION, outfile)
        sys.exit(-1)
Esempio n. 38
0
def main(infile, outfile):
    """
    Run a single task: read its setup and input from `infile`, write results to `outfile`.

    The setup is read in a fixed order: split index, the files directory, Python
    includes, broadcast variable updates and finally the pickled command (which may
    itself be wrapped in a Broadcast). The command carries the driver's
    ``(major, minor)`` Python version, which must equal this worker's. After
    processing, timing information, spill counters and accumulator values are
    written, followed by an end-of-stream handshake.

    Any exception raised during setup or processing is written to `outfile` as a
    traceback (falling back to stderr if that write itself fails) and the process
    exits with -1. The process also exits with -1 when the split index is -1 or when
    the end-of-stream marker is not received.
    """
    try:
        boot_time = time.time()
        split_index = read_int(infile)
        if split_index == -1:  # for unit tests
            # sys.exit rather than the site-provided exit(), which is meant for
            # interactive use and is not guaranteed to be defined.
            sys.exit(-1)

        # initialize global state
        shuffle.MemoryBytesSpilled = 0
        shuffle.DiskBytesSpilled = 0
        _accumulatorRegistry.clear()

        # fetch name of workdir
        spark_files_dir = utf8_deserializer.loads(infile)
        SparkFiles._root_directory = spark_files_dir
        SparkFiles._is_running_on_worker = True

        # fetch names of includes (*.zip and *.egg files) and construct PYTHONPATH
        add_path(spark_files_dir)  # *.py files that were added will be copied here
        num_python_includes = read_int(infile)
        for _ in range(num_python_includes):
            filename = utf8_deserializer.loads(infile)
            add_path(os.path.join(spark_files_dir, filename))

        # fetch names and values of broadcast variables
        num_broadcast_variables = read_int(infile)
        for _ in range(num_broadcast_variables):
            bid = read_long(infile)
            if bid >= 0:
                path = utf8_deserializer.loads(infile)
                _broadcastRegistry[bid] = Broadcast(path=path)
            else:
                # A negative id encodes the removal of broadcast (-bid - 1).
                bid = - bid - 1
                _broadcastRegistry.pop(bid)

        _accumulatorRegistry.clear()
        command = pickleSer._read_with_length(infile)
        if isinstance(command, Broadcast):
            # The command was shipped as a broadcast; unpickle its value.
            command = pickleSer.loads(command.value)
        (func, profiler, deserializer, serializer), version = command
        if version != sys.version_info[:2]:
            raise Exception(("Python in worker has different version %s than that in " +
                            "driver %s, PySpark cannot run with different minor versions") %
                            (sys.version_info[:2], version))
        init_time = time.time()

        def process():
            iterator = deserializer.load_stream(infile)
            serializer.dump_stream(func(split_index, iterator), outfile)

        if profiler:
            profiler.profile(process)
        else:
            process()
    except Exception:
        try:
            write_int(SpecialLengths.PYTHON_EXCEPTION_THROWN, outfile)
            write_with_length(traceback.format_exc(), outfile)
        except IOError:
            # JVM close the socket
            pass
        except Exception:
            # Write the error to stderr if it happened while serializing
            print >> sys.stderr, "PySpark worker failed with exception:"
            print >> sys.stderr, traceback.format_exc()
        sys.exit(-1)
    finish_time = time.time()
    report_times(outfile, boot_time, init_time, finish_time)
    write_long(shuffle.MemoryBytesSpilled, outfile)
    write_long(shuffle.DiskBytesSpilled, outfile)

    # Mark the beginning of the accumulators section of the output
    write_int(SpecialLengths.END_OF_DATA_SECTION, outfile)
    write_int(len(_accumulatorRegistry), outfile)
    for (aid, accum) in _accumulatorRegistry.items():
        pickleSer._write_with_length((aid, accum._value), outfile)

    # check end of stream
    if read_int(infile) == SpecialLengths.END_OF_STREAM:
        write_int(SpecialLengths.END_OF_STREAM, outfile)
    else:
        # write a different value to tell JVM to not reuse this worker
        write_int(SpecialLengths.END_OF_DATA_SECTION, outfile)
        sys.exit(-1)
Esempio n. 39
0
def manager():
    """
    Run the daemon loop: accept connections on a loopback socket and fork one worker
    process per connection.

    The listening port is written to stdout once the socket is ready. The loop then
    waits on stdin and the listening socket:

    * an integer read from stdin is treated as a worker pid, which is sent SIGKILL;
      EOF on stdin shuts the daemon down with exit code 0;
    * each accepted connection is handed to a forked child, which first writes its
      pid to the socket and then calls ``worker(sock)`` repeatedly while
      ``SPARK_REUSE_WORKER`` is set and the worker returns a falsy code.

    On SIGTERM, or whenever the loop is left, SIGHUP is sent to the process group
    and the daemon exits.
    """
    # Create a new process group to corral our children
    os.setpgid(0, 0)

    # Create a listening socket on the AF_INET loopback interface
    listen_sock = socket.socket(AF_INET, SOCK_STREAM)
    listen_sock.bind(('127.0.0.1', 0))
    listen_sock.listen(max(1024, SOMAXCONN))
    _, listen_port = listen_sock.getsockname()
    write_int(listen_port, sys.stdout)
    sys.stdout.flush()

    def shutdown(code):
        # Restore the default SIGTERM handler before signalling the group.
        signal.signal(SIGTERM, SIG_DFL)
        # Send SIGHUP to notify workers of shutdown
        os.kill(0, SIGHUP)
        # sys.exit rather than the site-provided exit(), which is meant for
        # interactive use and is not guaranteed to be defined.
        sys.exit(code)

    def handle_sigterm(*args):
        shutdown(1)
    signal.signal(SIGTERM, handle_sigterm)  # Gracefully exit on SIGTERM
    signal.signal(SIGHUP, SIG_IGN)  # Don't die on SIGHUP

    reuse = os.environ.get("SPARK_REUSE_WORKER")

    # Initialization complete
    try:
        while True:
            try:
                ready_fds = select.select([0, listen_sock], [], [], 1)[0]
            except select.error as ex:
                # Read the error number from args: indexing the exception object
                # itself fails where select.error is an alias of OSError.
                if ex.args[0] == EINTR:
                    continue
                else:
                    raise

            # cleanup in signal handler will cause deadlock
            cleanup_dead_children()

            if 0 in ready_fds:
                try:
                    worker_pid = read_int(sys.stdin)
                except EOFError:
                    # Spark told us to exit by closing stdin
                    shutdown(0)
                try:
                    os.kill(worker_pid, signal.SIGKILL)
                except OSError:
                    pass  # process already died

            if listen_sock in ready_fds:
                try:
                    sock, _ = listen_sock.accept()
                except OSError as e:
                    if e.errno == EINTR:
                        continue
                    raise

                # Launch a worker process
                try:
                    pid = os.fork()
                except OSError as e:
                    if e.errno in (EAGAIN, EINTR):
                        # Transient failure: wait briefly and retry once.
                        time.sleep(1)
                        pid = os.fork()  # error here will shutdown daemon
                    else:
                        outfile = sock.makefile('w')
                        write_int(e.errno, outfile)  # Signal that the fork failed
                        outfile.flush()
                        outfile.close()
                        sock.close()
                        continue

                if pid == 0:
                    # in child process
                    listen_sock.close()
                    try:
                        # Acknowledge that the fork was successful
                        outfile = sock.makefile("w")
                        write_int(os.getpid(), outfile)
                        outfile.flush()
                        outfile.close()
                        while True:
                            code = worker(sock)
                            if not reuse or code:
                                # wait for closing
                                try:
                                    while sock.recv(1024):
                                        pass
                                except Exception:
                                    pass
                                break
                            gc.collect()
                    except:
                        # On any failure in the child, print the traceback and
                        # leave via os._exit.
                        traceback.print_exc()
                        os._exit(1)
                    else:
                        os._exit(0)
                else:
                    # Parent: the child owns the connection now.
                    sock.close()

    finally:
        shutdown(1)
Esempio n. 40
0
def main(infile, outfile):
    """Run a single task described on ``infile`` and report to ``outfile``.

    The setup header is read from ``infile`` in a fixed order: split index,
    driver Python version, barrier flag/port/secret, task context fields,
    local properties, the files directory, Python includes, broadcast
    variables, the eval type and finally the command itself.  The resulting
    function is applied to the deserialized input stream and its output is
    serialized to ``outfile``, followed by timing data, spill counters and
    accumulator updates.

    :param infile: stream the task description and input data are read from
    :param outfile: stream results and bookkeeping data are written to
    :raises SystemExit: with code -1 when the split index is -1, when setting
        up or running the task raises, or when the end-of-stream marker is
        not the next value on ``infile``
    """
    try:
        boot_time = time.time()
        split_index = read_int(infile)
        if split_index == -1:  # for unit tests
            sys.exit(-1)

        version = utf8_deserializer.loads(infile)
        worker_version = "%d.%d" % sys.version_info[:2]
        if version != worker_version:
            # Note the space after "versions." so the two sentences do not run together.
            raise Exception(("Python in worker has different version %s than that in " +
                             "driver %s, PySpark cannot run with different minor versions. " +
                             "Please check environment variables PYSPARK_PYTHON and " +
                             "PYSPARK_DRIVER_PYTHON are correctly set.") %
                            (worker_version, version))

        # The barrier fields are always present in the stream, even though they
        # are only used when the task is a barrier task.
        isBarrier = read_bool(infile)
        boundPort = read_int(infile)
        secret = UTF8Deserializer().loads(infile)
        # initialize global state
        if isBarrier:
            taskContext = BarrierTaskContext._getOrCreate()
            BarrierTaskContext._initialize(boundPort, secret)
        else:
            taskContext = TaskContext._getOrCreate()
        # read inputs for TaskContext info
        taskContext._stageId = read_int(infile)
        taskContext._partitionId = read_int(infile)
        taskContext._attemptNumber = read_int(infile)
        taskContext._taskAttemptId = read_long(infile)
        taskContext._localProperties = dict()
        for _ in range(read_int(infile)):
            k = utf8_deserializer.loads(infile)
            v = utf8_deserializer.loads(infile)
            taskContext._localProperties[k] = v

        shuffle.MemoryBytesSpilled = 0
        shuffle.DiskBytesSpilled = 0
        _accumulatorRegistry.clear()

        # fetch name of workdir
        spark_files_dir = utf8_deserializer.loads(infile)
        SparkFiles._root_directory = spark_files_dir
        SparkFiles._is_running_on_worker = True

        # fetch names of includes (*.zip and *.egg files) and construct PYTHONPATH
        add_path(spark_files_dir)  # *.py files that were added will be copied here
        num_python_includes = read_int(infile)
        for _ in range(num_python_includes):
            filename = utf8_deserializer.loads(infile)
            add_path(os.path.join(spark_files_dir, filename))
        # Compare the numeric version; comparing the sys.version string is
        # lexicographic and only works by accident for single-digit majors.
        if sys.version_info[0] >= 3:
            import importlib
            importlib.invalidate_caches()

        # fetch names and values of broadcast variables
        num_broadcast_variables = read_int(infile)
        for _ in range(num_broadcast_variables):
            bid = read_long(infile)
            if bid >= 0:
                path = utf8_deserializer.loads(infile)
                _broadcastRegistry[bid] = Broadcast(path=path)
            else:
                # a negative id encodes the removal of broadcast (-bid - 1)
                bid = - bid - 1
                _broadcastRegistry.pop(bid)

        _accumulatorRegistry.clear()
        eval_type = read_int(infile)
        if eval_type == PythonEvalType.NON_UDF:
            func, profiler, deserializer, serializer = read_command(pickleSer, infile)
        else:
            func, profiler, deserializer, serializer = read_udfs(pickleSer, infile, eval_type)

        init_time = time.time()

        def process():
            iterator = deserializer.load_stream(infile)
            serializer.dump_stream(func(split_index, iterator), outfile)

        if profiler:
            profiler.profile(process)
        else:
            process()
    except Exception:
        try:
            write_int(SpecialLengths.PYTHON_EXCEPTION_THROWN, outfile)
            write_with_length(traceback.format_exc().encode("utf-8"), outfile)
        except IOError:
            # JVM close the socket
            pass
        except Exception:
            # Write the error to stderr if it happened while serializing
            print("PySpark worker failed with exception:", file=sys.stderr)
            print(traceback.format_exc(), file=sys.stderr)
        sys.exit(-1)
    finish_time = time.time()
    report_times(outfile, boot_time, init_time, finish_time)
    write_long(shuffle.MemoryBytesSpilled, outfile)
    write_long(shuffle.DiskBytesSpilled, outfile)

    # Mark the beginning of the accumulators section of the output
    write_int(SpecialLengths.END_OF_DATA_SECTION, outfile)
    write_int(len(_accumulatorRegistry), outfile)
    for (aid, accum) in _accumulatorRegistry.items():
        pickleSer._write_with_length((aid, accum._value), outfile)

    # check end of stream
    if read_int(infile) == SpecialLengths.END_OF_STREAM:
        write_int(SpecialLengths.END_OF_STREAM, outfile)
    else:
        # write a different value to tell JVM to not reuse this worker
        write_int(SpecialLengths.END_OF_DATA_SECTION, outfile)
        sys.exit(-1)
Esempio n. 41
0
def main(infile, outfile):
    """Run a single task described on ``infile`` and report to ``outfile``.

    Reads, in order, the split index, the files directory, the Python
    includes, the broadcast variable updates and the pickled command (which
    carries the driver's Python version).  The command's function is applied
    to the deserialized input stream, and the output is followed by timing
    data, spill counters and accumulator updates.

    :param infile: stream the task description and input data are read from
    :param outfile: stream results and bookkeeping data are written to
    :raises SystemExit: with code -1 when the split index is -1, when setting
        up or running the task raises, or when the end-of-stream marker is
        not the next value on ``infile``
    """
    try:
        boot_time = time.time()
        split_index = read_int(infile)
        if split_index == -1:  # for unit tests
            # sys.exit rather than the builtin exit(): the latter is injected
            # by the site module and is not guaranteed to exist.
            sys.exit(-1)

        # initialize global state
        shuffle.MemoryBytesSpilled = 0
        shuffle.DiskBytesSpilled = 0
        _accumulatorRegistry.clear()

        # fetch name of workdir
        spark_files_dir = utf8_deserializer.loads(infile)
        SparkFiles._root_directory = spark_files_dir
        SparkFiles._is_running_on_worker = True

        # fetch names of includes (*.zip and *.egg files) and construct PYTHONPATH
        add_path(spark_files_dir)  # *.py files that were added will be copied here
        num_python_includes = read_int(infile)
        for _ in range(num_python_includes):
            filename = utf8_deserializer.loads(infile)
            add_path(os.path.join(spark_files_dir, filename))
        # Compare the numeric version; comparing the sys.version string is
        # lexicographic and only works by accident for single-digit majors.
        if sys.version_info[0] >= 3:
            import importlib
            importlib.invalidate_caches()

        # fetch names and values of broadcast variables
        num_broadcast_variables = read_int(infile)
        for _ in range(num_broadcast_variables):
            bid = read_long(infile)
            if bid >= 0:
                path = utf8_deserializer.loads(infile)
                _broadcastRegistry[bid] = Broadcast(path=path)
            else:
                # a negative id encodes the removal of broadcast (-bid - 1)
                bid = - bid - 1
                _broadcastRegistry.pop(bid)

        _accumulatorRegistry.clear()
        command = pickleSer._read_with_length(infile)
        if isinstance(command, Broadcast):
            # the command itself was shipped as a broadcast; unpickle its value
            command = pickleSer.loads(command.value)
        (func, profiler, deserializer, serializer), version = command
        if version != sys.version_info[:2]:
            raise Exception(("Python in worker has different version %s than that in " +
                            "driver %s, PySpark cannot run with different minor versions") %
                            (sys.version_info[:2], version))
        init_time = time.time()

        def process():
            iterator = deserializer.load_stream(infile)
            serializer.dump_stream(func(split_index, iterator), outfile)

        if profiler:
            profiler.profile(process)
        else:
            process()
    except Exception:
        try:
            write_int(SpecialLengths.PYTHON_EXCEPTION_THROWN, outfile)
            write_with_length(traceback.format_exc().encode("utf-8"), outfile)
        except IOError:
            # JVM close the socket
            pass
        except Exception:
            # Write the error to stderr if it happened while serializing
            print("PySpark worker failed with exception:", file=sys.stderr)
            print(traceback.format_exc(), file=sys.stderr)
        sys.exit(-1)
    finish_time = time.time()
    report_times(outfile, boot_time, init_time, finish_time)
    write_long(shuffle.MemoryBytesSpilled, outfile)
    write_long(shuffle.DiskBytesSpilled, outfile)

    # Mark the beginning of the accumulators section of the output
    write_int(SpecialLengths.END_OF_DATA_SECTION, outfile)
    write_int(len(_accumulatorRegistry), outfile)
    for (aid, accum) in _accumulatorRegistry.items():
        pickleSer._write_with_length((aid, accum._value), outfile)

    # check end of stream
    if read_int(infile) == SpecialLengths.END_OF_STREAM:
        write_int(SpecialLengths.END_OF_STREAM, outfile)
    else:
        # write a different value to tell JVM to not reuse this worker
        write_int(SpecialLengths.END_OF_DATA_SECTION, outfile)
        sys.exit(-1)
Esempio n. 42
0
def launch_gateway(conf=None):
    """
    Launch the JVM gateway (or attach to an existing one) and connect to it.

    If ``PYSPARK_GATEWAY_PORT`` is set in the environment, the port and
    ``PYSPARK_GATEWAY_SECRET`` are taken from there and no process is started.
    Otherwise ``spark-submit`` is launched and the connection information is
    read from a temporary file whose path is handed to the child through the
    ``_PYSPARK_DRIVER_CONN_INFO_PATH`` environment variable.

    :param conf: spark configuration passed to spark-submit
    :return: a ``JavaGateway`` connected to the gateway, with the PySpark JVM
        classes imported into ``gateway.jvm``
    :raises Exception: if the launched process exits without writing the
        connection information file
    """
    if "PYSPARK_GATEWAY_PORT" in os.environ:
        gateway_port = int(os.environ["PYSPARK_GATEWAY_PORT"])
        gateway_secret = os.environ["PYSPARK_GATEWAY_SECRET"]
    else:
        SPARK_HOME = _find_spark_home()
        # Launch the Py4j gateway using Spark's run command so that we pick up the
        # proper classpath and settings from spark-env.sh
        on_windows = platform.system() == "Windows"
        script = "./bin/spark-submit.cmd" if on_windows else "./bin/spark-submit"
        command = [os.path.join(SPARK_HOME, script)]
        if conf:
            for k, v in conf.getAll():
                command += ['--conf', '%s=%s' % (k, v)]
        submit_args = os.environ.get("PYSPARK_SUBMIT_ARGS", "pyspark-shell")
        if os.environ.get("SPARK_TESTING"):
            submit_args = ' '.join([
                "--conf spark.ui.enabled=false",
                submit_args
            ])
        command = command + shlex.split(submit_args)

        # Create a temporary directory where the gateway server should write the connection
        # information.
        conn_info_dir = tempfile.mkdtemp()
        try:
            # Reserve a unique name inside the directory, then remove the file so
            # that its later appearance signals that the gateway has written it.
            fd, conn_info_file = tempfile.mkstemp(dir=conn_info_dir)
            os.close(fd)
            os.unlink(conn_info_file)

            env = dict(os.environ)
            env["_PYSPARK_DRIVER_CONN_INFO_PATH"] = conn_info_file

            # Launch the Java gateway.
            # We open a pipe to stdin so that the Java gateway can die when the pipe is broken
            if not on_windows:
                # Don't send ctrl-c / SIGINT to the Java gateway:
                def preexec_func():
                    signal.signal(signal.SIGINT, signal.SIG_IGN)
                proc = Popen(command, stdin=PIPE, preexec_fn=preexec_func, env=env)
            else:
                # preexec_fn not supported on Windows
                proc = Popen(command, stdin=PIPE, env=env)

            # Wait for the file to appear, or for the process to exit, whichever happens first.
            # poll() returns None only while the child is running; testing its
            # truthiness would treat exit code 0 as "still running" and wait forever.
            while proc.poll() is None and not os.path.isfile(conn_info_file):
                time.sleep(0.1)

            if not os.path.isfile(conn_info_file):
                raise Exception("Java gateway process exited before sending its port number")

            with open(conn_info_file, "rb") as info:
                gateway_port = read_int(info)
                gateway_secret = UTF8Deserializer().loads(info)
        finally:
            shutil.rmtree(conn_info_dir)

        # In Windows, ensure the Java child processes do not linger after Python has exited.
        # In UNIX-based systems, the child process can kill itself on broken pipe (i.e. when
        # the parent process' stdin sends an EOF). In Windows, however, this is not possible
        # because java.lang.Process reads directly from the parent process' stdin, contending
        # with any opportunity to read an EOF from the parent. Note that this is only best
        # effort and will not take effect if the python process is violently terminated.
        if on_windows:
            # In Windows, the child process here is "spark-submit.cmd", not the JVM itself
            # (because the UNIX "exec" command is not available). This means we cannot simply
            # call proc.kill(), which kills only the "spark-submit.cmd" process but not the
            # JVMs. Instead, we use "taskkill" with the tree-kill option "/t" to terminate all
            # child processes in the tree (http://technet.microsoft.com/en-us/library/bb491009.aspx)
            def killChild():
                Popen(["cmd", "/c", "taskkill", "/f", "/t", "/pid", str(proc.pid)])
            atexit.register(killChild)

    # Connect to the gateway
    gateway = JavaGateway(
        gateway_parameters=GatewayParameters(port=gateway_port, auth_token=gateway_secret,
                                             auto_convert=True))

    # Import the classes used by PySpark
    java_import(gateway.jvm, "org.apache.spark.SparkConf")
    java_import(gateway.jvm, "org.apache.spark.api.java.*")
    java_import(gateway.jvm, "org.apache.spark.api.python.*")
    java_import(gateway.jvm, "org.apache.spark.ml.python.*")
    java_import(gateway.jvm, "org.apache.spark.mllib.api.python.*")
    # TODO(davies): move into sql
    java_import(gateway.jvm, "org.apache.spark.sql.*")
    java_import(gateway.jvm, "org.apache.spark.sql.api.python.*")
    java_import(gateway.jvm, "org.apache.spark.sql.hive.*")
    java_import(gateway.jvm, "scala.Tuple2")

    return gateway
Esempio n. 43
0
File: worker.py Progetto: 0asa/spark
def main(infile, outfile):
    """Run a single task described on ``infile`` and report to ``outfile``.

    Reads, in order, the split index, the files directory, the Python
    includes, the broadcast variable updates (whose values are read inline
    from ``infile``) and the pickled command.  The command's function is
    applied to the deserialized input stream, optionally under cProfile, and
    the output is followed by timing data, spill counters and accumulator
    updates.

    :param infile: stream the task description and input data are read from
    :param outfile: stream results and bookkeeping data are written to
    :raises SystemExit: with code -1 when the split index is -1, when setting
        up or running the task raises, or when the end-of-stream marker is
        not the next value on ``infile``
    """
    try:
        boot_time = time.time()
        split_index = read_int(infile)
        if split_index == -1:  # for unit tests
            # sys.exit rather than the builtin exit(): the latter is injected
            # by the site module and is not guaranteed to exist.
            sys.exit(-1)

        # initialize global state
        shuffle.MemoryBytesSpilled = 0
        shuffle.DiskBytesSpilled = 0
        _accumulatorRegistry.clear()

        # fetch name of workdir
        spark_files_dir = utf8_deserializer.loads(infile)
        SparkFiles._root_directory = spark_files_dir
        SparkFiles._is_running_on_worker = True

        # fetch names of includes (*.zip and *.egg files) and construct PYTHONPATH
        add_path(spark_files_dir)  # *.py files that were added will be copied here
        num_python_includes = read_int(infile)
        for _ in range(num_python_includes):
            filename = utf8_deserializer.loads(infile)
            add_path(os.path.join(spark_files_dir, filename))

        # fetch names and values of broadcast variables
        num_broadcast_variables = read_int(infile)
        bser = LargeObjectSerializer()
        for _ in range(num_broadcast_variables):
            bid = read_long(infile)
            if bid >= 0:
                size = read_long(infile)
                s = SizeLimitedStream(infile, size)
                value = list((bser.load_stream(s)))[0]  # read out all the bytes
                _broadcastRegistry[bid] = Broadcast(bid, value)
            else:
                # a negative id encodes the removal of broadcast (-bid - 1)
                bid = - bid - 1
                _broadcastRegistry.pop(bid)

        _accumulatorRegistry.clear()
        command = pickleSer._read_with_length(infile)
        if isinstance(command, Broadcast):
            # the command itself was shipped as a broadcast; unpickle its value
            command = pickleSer.loads(command.value)
        (func, stats, deserializer, serializer) = command
        init_time = time.time()

        def process():
            iterator = deserializer.load_stream(infile)
            serializer.dump_stream(func(split_index, iterator), outfile)

        if stats:
            p = cProfile.Profile()
            p.runcall(process)
            st = pstats.Stats(p)
            st.stream = None  # make it picklable
            stats.add(st.strip_dirs())
        else:
            process()
    except Exception:
        try:
            write_int(SpecialLengths.PYTHON_EXCEPTION_THROWN, outfile)
            write_with_length(traceback.format_exc(), outfile)
        except IOError:
            # JVM close the socket
            pass
        except Exception:
            # Write the error to stderr if it happened while serializing.
            # stderr.write is used instead of the Python-2-only
            # "print >> stream" form; the emitted text is the same.
            sys.stderr.write("PySpark worker failed with exception:\n")
            sys.stderr.write(traceback.format_exc() + "\n")
        sys.exit(-1)
    finish_time = time.time()
    report_times(outfile, boot_time, init_time, finish_time)
    write_long(shuffle.MemoryBytesSpilled, outfile)
    write_long(shuffle.DiskBytesSpilled, outfile)

    # Mark the beginning of the accumulators section of the output
    write_int(SpecialLengths.END_OF_DATA_SECTION, outfile)
    write_int(len(_accumulatorRegistry), outfile)
    for (aid, accum) in _accumulatorRegistry.items():
        pickleSer._write_with_length((aid, accum._value), outfile)

    # check end of stream
    if read_int(infile) == SpecialLengths.END_OF_STREAM:
        write_int(SpecialLengths.END_OF_STREAM, outfile)
    else:
        # write a different value to tell JVM to not reuse this worker
        write_int(SpecialLengths.END_OF_DATA_SECTION, outfile)
        sys.exit(-1)
    def loads(self, obj):
        """Decode ``obj`` into a ``(key, value)`` pair.

        The buffer starts with a length (read with ``read_int``) followed by
        that many bytes of UTF-8 key text; whatever ``_read_vec`` parses from
        the remainder of the buffer becomes the value.
        """
        buf = BytesIO(obj)
        name_size = read_int(buf)
        raw_name = buf.read(name_size)
        # The key is decoded before the remainder of the buffer is consumed.
        return (raw_name.decode('utf-8'), _read_vec(buf))
Esempio n. 45
0
def launch_gateway(conf=None):
    """
    Launch the JVM gateway (or attach to an existing one) and connect to it.

    If ``PYSPARK_GATEWAY_PORT`` is set in the environment, the port and
    ``PYSPARK_GATEWAY_SECRET`` are taken from there and no process is started.
    Otherwise ``spark-submit`` is launched and the connection information is
    read from a temporary file whose path is handed to the child through the
    ``_PYSPARK_DRIVER_CONN_INFO_PATH`` environment variable.

    :param conf: spark configuration passed to spark-submit
    :return: a ``JavaGateway`` connected to the gateway, with the PySpark JVM
        classes imported into ``gateway.jvm``
    :raises Exception: if the launched process exits without writing the
        connection information file
    """
    if "PYSPARK_GATEWAY_PORT" in os.environ:
        gateway_port = int(os.environ["PYSPARK_GATEWAY_PORT"])
        gateway_secret = os.environ["PYSPARK_GATEWAY_SECRET"]
    else:
        SPARK_HOME = _find_spark_home()
        # Launch the Py4j gateway using Spark's run command so that we pick up the
        # proper classpath and settings from spark-env.sh
        on_windows = platform.system() == "Windows"
        script = "./bin/spark-submit.cmd" if on_windows else "./bin/spark-submit"
        command = [os.path.join(SPARK_HOME, script)]
        if conf:
            for k, v in conf.getAll():
                command += ['--conf', '%s=%s' % (k, v)]
        submit_args = os.environ.get("PYSPARK_SUBMIT_ARGS", "pyspark-shell")
        if os.environ.get("SPARK_TESTING"):
            submit_args = ' '.join([
                "--conf spark.ui.enabled=false",
                submit_args
            ])
        command = command + shlex.split(submit_args)

        # Create a temporary directory where the gateway server should write the connection
        # information.
        conn_info_dir = tempfile.mkdtemp()
        try:
            # Reserve a unique name inside the directory, then remove the file so
            # that its later appearance signals that the gateway has written it.
            fd, conn_info_file = tempfile.mkstemp(dir=conn_info_dir)
            os.close(fd)
            os.unlink(conn_info_file)

            env = dict(os.environ)
            env["_PYSPARK_DRIVER_CONN_INFO_PATH"] = conn_info_file

            # Launch the Java gateway.
            # We open a pipe to stdin so that the Java gateway can die when the pipe is broken
            if not on_windows:
                # Don't send ctrl-c / SIGINT to the Java gateway:
                def preexec_func():
                    signal.signal(signal.SIGINT, signal.SIG_IGN)
                proc = Popen(command, stdin=PIPE, preexec_fn=preexec_func, env=env)
            else:
                # preexec_fn not supported on Windows
                proc = Popen(command, stdin=PIPE, env=env)

            # Wait for the file to appear, or for the process to exit, whichever happens first.
            # poll() returns None only while the child is running; testing its
            # truthiness would treat exit code 0 as "still running" and wait forever.
            while proc.poll() is None and not os.path.isfile(conn_info_file):
                time.sleep(0.1)

            if not os.path.isfile(conn_info_file):
                raise Exception("Java gateway process exited before sending its port number")

            with open(conn_info_file, "rb") as info:
                gateway_port = read_int(info)
                gateway_secret = UTF8Deserializer().loads(info)
        finally:
            shutil.rmtree(conn_info_dir)

        # In Windows, ensure the Java child processes do not linger after Python has exited.
        # In UNIX-based systems, the child process can kill itself on broken pipe (i.e. when
        # the parent process' stdin sends an EOF). In Windows, however, this is not possible
        # because java.lang.Process reads directly from the parent process' stdin, contending
        # with any opportunity to read an EOF from the parent. Note that this is only best
        # effort and will not take effect if the python process is violently terminated.
        if on_windows:
            # In Windows, the child process here is "spark-submit.cmd", not the JVM itself
            # (because the UNIX "exec" command is not available). This means we cannot simply
            # call proc.kill(), which kills only the "spark-submit.cmd" process but not the
            # JVMs. Instead, we use "taskkill" with the tree-kill option "/t" to terminate all
            # child processes in the tree (http://technet.microsoft.com/en-us/library/bb491009.aspx)
            def killChild():
                Popen(["cmd", "/c", "taskkill", "/f", "/t", "/pid", str(proc.pid)])
            atexit.register(killChild)

    # Connect to the gateway
    gateway = JavaGateway(
        gateway_parameters=GatewayParameters(port=gateway_port, auth_token=gateway_secret,
                                             auto_convert=True))

    # Import the classes used by PySpark
    java_import(gateway.jvm, "org.apache.spark.SparkConf")
    java_import(gateway.jvm, "org.apache.spark.api.java.*")
    java_import(gateway.jvm, "org.apache.spark.api.python.*")
    java_import(gateway.jvm, "org.apache.spark.ml.python.*")
    java_import(gateway.jvm, "org.apache.spark.mllib.api.python.*")
    # TODO(davies): move into sql
    java_import(gateway.jvm, "org.apache.spark.sql.*")
    java_import(gateway.jvm, "org.apache.spark.sql.api.python.*")
    java_import(gateway.jvm, "org.apache.spark.sql.hive.*")
    java_import(gateway.jvm, "scala.Tuple2")

    return gateway
Esempio n. 46
0
def read_udfs(pickleSer, infile, eval_type):
    """Read the UDF definitions for ``eval_type`` from ``infile``.

    For the pandas eval types a runner configuration is read first and used
    to build an ``ArrowStreamPandasUDFSerializer``; every other eval type
    uses a batched pickle serializer.  The UDFs are then read with
    ``read_single_udf`` and combined into one function over the input
    iterator.

    :param pickleSer: serializer passed through to ``read_single_udf``
    :param infile: stream the configuration and UDF definitions are read from
    :param eval_type: one of the ``PythonEvalType`` constants
    :return: a tuple ``(func, None, ser, ser)`` where ``func(split_index, iterator)``
        evaluates the UDF(s), the profiler slot is always ``None`` and ``ser``
        is used as both deserializer and serializer
    """
    runner_conf = {}

    if eval_type in (PythonEvalType.SQL_SCALAR_PANDAS_UDF,
                     PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF,
                     PythonEvalType.SQL_MAP_PANDAS_ITER_UDF,
                     PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
                     PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF,
                     PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF):

        # Load conf used for pandas_udf evaluation
        num_conf = read_int(infile)
        for _ in range(num_conf):
            k = utf8_deserializer.loads(infile)
            v = utf8_deserializer.loads(infile)
            runner_conf[k] = v

        # NOTE: if timezone is set here, that implies respectSessionTimeZone is True
        timezone = runner_conf.get("spark.sql.session.timeZone", None)
        safecheck = runner_conf.get(
            "spark.sql.execution.pandas.arrowSafeTypeConversion",
            "false").lower() == 'true'
        # Used by SQL_GROUPED_MAP_PANDAS_UDF and SQL_SCALAR_PANDAS_UDF when returning StructType
        assign_cols_by_name = runner_conf.get(
            "spark.sql.legacy.execution.pandas.groupedMap.assignColumnsByName", "true")\
            .lower() == "true"

        # Scalar Pandas UDF handles struct type arguments as pandas DataFrames instead of
        # pandas Series. See SPARK-27240.
        df_for_struct = (
            eval_type == PythonEvalType.SQL_SCALAR_PANDAS_UDF
            or eval_type == PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF
            or eval_type == PythonEvalType.SQL_MAP_PANDAS_ITER_UDF)
        ser = ArrowStreamPandasUDFSerializer(timezone, safecheck,
                                             assign_cols_by_name,
                                             df_for_struct)
    else:
        ser = BatchedSerializer(PickleSerializer(), 100)

    num_udfs = read_int(infile)

    is_scalar_iter = eval_type == PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF
    is_map_iter = eval_type == PythonEvalType.SQL_MAP_PANDAS_ITER_UDF

    if is_scalar_iter or is_map_iter:
        if is_scalar_iter:
            assert num_udfs == 1, "One SCALAR_ITER UDF expected here."
        if is_map_iter:
            assert num_udfs == 1, "One MAP_ITER UDF expected here."

        arg_offsets, udf = read_single_udf(pickleSer,
                                           infile,
                                           eval_type,
                                           runner_conf,
                                           udf_index=0)

        def func(_, iterator):
            # A one-element list so the nested map_batch can update the count.
            num_input_rows = [0]

            def map_batch(batch):
                udf_args = [batch[offset] for offset in arg_offsets]
                num_input_rows[0] += len(udf_args[0])
                if len(udf_args) == 1:
                    return udf_args[0]
                else:
                    return tuple(udf_args)

            iterator = map(map_batch, iterator)
            result_iter = udf(iterator)

            num_output_rows = 0
            for result_batch, result_type in result_iter:
                num_output_rows += len(result_batch)
                # The row-count bound is skipped for MAP_ITER, so a failure here
                # always concerns a SCALAR_ITER UDF; the message says so.
                assert is_map_iter or num_output_rows <= num_input_rows[0], \
                    "Pandas SCALAR_ITER UDF outputted more rows than input rows."
                yield (result_batch, result_type)

            if is_scalar_iter:
                # The UDF must have consumed every input batch by now.
                try:
                    next(iterator)
                except StopIteration:
                    pass
                else:
                    raise RuntimeError(
                        "SQL_SCALAR_PANDAS_ITER_UDF should exhaust the input "
                        "iterator.")

            if is_scalar_iter and num_output_rows != num_input_rows[0]:
                raise RuntimeError(
                    "The number of output rows of pandas iterator UDF should be "
                    "the same with input rows. The input rows number is %d but the "
                    "output rows number is %d." %
                    (num_input_rows[0], num_output_rows))

        # profiling is not supported for UDF
        return func, None, ser, ser

    udfs = {}
    call_udf = []
    mapper_str = ""
    if eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF:
        # Create function like this:
        #   lambda a: f([a[0]], [a[0], a[1]])

        # We assume there is only one UDF here because grouped map doesn't
        # support combining multiple UDFs.
        assert num_udfs == 1

        # See FlatMapGroupsInPandasExec for how arg_offsets are used to
        # distinguish between grouping attributes and data attributes
        arg_offsets, udf = read_single_udf(pickleSer,
                                           infile,
                                           eval_type,
                                           runner_conf,
                                           udf_index=0)
        udfs['f'] = udf
        # arg_offsets[0] holds how many of the following offsets go into the
        # first argument list; the remaining offsets go into the second.
        split_offset = arg_offsets[0] + 1
        arg0 = ["a[%d]" % o for o in arg_offsets[1:split_offset]]
        arg1 = ["a[%d]" % o for o in arg_offsets[split_offset:]]
        mapper_str = "lambda a: f([%s], [%s])" % (", ".join(arg0),
                                                  ", ".join(arg1))
    else:
        # Create function like this:
        #   lambda a: (f0(a[0]), f1(a[1], a[2]), f2(a[3]))
        # In the special case of a single UDF this will return a single result rather
        # than a tuple of results; this is the format that the JVM side expects.
        for i in range(num_udfs):
            arg_offsets, udf = read_single_udf(pickleSer,
                                               infile,
                                               eval_type,
                                               runner_conf,
                                               udf_index=i)
            udfs['f%d' % i] = udf
            args = ["a[%d]" % o for o in arg_offsets]
            call_udf.append("f%d(%s)" % (i, ", ".join(args)))
        mapper_str = "lambda a: (%s)" % (", ".join(call_udf))

    # NOTE: eval() input is assembled only from generated function names and
    # offsets formatted with %d, with the UDFs supplied as the globals dict.
    mapper = eval(mapper_str, udfs)
    func = lambda _, it: map(mapper, it)

    # profiling is not supported for UDF
    return func, None, ser, ser
Esempio n. 47
0
def manager():
    """Daemon loop that forks one worker process per accepted connection.

    Binds a listening socket on 127.0.0.1 with an OS-chosen port, writes that
    port to stdout, then waits on both stdin (fd 0) and the listening socket.
    An integer read from stdin is treated as the pid of a worker to SIGKILL;
    EOF on stdin shuts the daemon down.  Each accepted connection is handed
    to ``worker(sock)`` in a forked child.

    NOTE(review): this block uses Python 2 only constructs
    (``print >> sys.stderr`` and indexing an exception with ``ex[0]``).
    """
    # Create a new process group to corral our children
    os.setpgid(0, 0)

    # Create a listening socket on the AF_INET loopback interface
    listen_sock = socket.socket(AF_INET, SOCK_STREAM)
    listen_sock.bind(('127.0.0.1', 0))  # port 0: let the OS pick a free port
    listen_sock.listen(max(1024, SOMAXCONN))
    listen_host, listen_port = listen_sock.getsockname()
    # Report the chosen port to whoever launched us, via stdout
    write_int(listen_port, sys.stdout)

    def shutdown(code):
        # Restore the default SIGTERM action so a second SIGTERM is not
        # routed back into handle_sigterm while we are shutting down.
        signal.signal(SIGTERM, SIG_DFL)
        # Send SIGHUP to notify workers of shutdown
        # (pid 0 targets the whole process group created above; this process
        # ignores SIGHUP itself, see the SIG_IGN registration below)
        os.kill(0, SIGHUP)
        exit(code)

    def handle_sigterm(*args):
        shutdown(1)
    signal.signal(SIGTERM, handle_sigterm)  # Gracefully exit on SIGTERM
    signal.signal(SIGHUP, SIG_IGN)  # Don't die on SIGHUP

    # Cleanup zombie children
    def handle_sigchld(*args):
        try:
            # Non-blocking reap of a single child from our process group.
            # NOTE(review): only one child is reaped per handler invocation;
            # confirm that coalesced SIGCHLDs cannot leave zombies behind.
            pid, status = os.waitpid(0, os.WNOHANG)
            if status != 0:
                msg = "worker %s crashed abruptly with exit status %s" % (pid, status)
                print >> sys.stderr, msg
        except EnvironmentError as err:
            # No children left, or the call was interrupted: nothing to report
            if err.errno not in (ECHILD, EINTR):
                raise
    signal.signal(SIGCHLD, handle_sigchld)

    # Initialization complete
    sys.stdout.close()
    try:
        while True:
            try:
                # Block until stdin or the listening socket becomes readable
                ready_fds = select.select([0, listen_sock], [], [])[0]
            except select.error as ex:
                # select interrupted by a signal: just retry
                if ex[0] == EINTR:
                    continue
                else:
                    raise
            if 0 in ready_fds:
                try:
                    # An int on stdin is the pid of a worker to be killed
                    worker_pid = read_int(sys.stdin)
                except EOFError:
                    # Spark told us to exit by closing stdin
                    # NOTE(review): shutdown() leaves via exit(), which presumably
                    # raises SystemExit, so the outer "finally: shutdown(1)" would
                    # still run afterwards - confirm the intended exit status.
                    shutdown(0)
                try:
                    os.kill(worker_pid, signal.SIGKILL)
                except OSError:
                    pass  # process already died

            if listen_sock in ready_fds:
                sock, addr = listen_sock.accept()
                # Launch a worker process
                try:
                    pid = os.fork()
                    if pid == 0:
                        # Child: it only needs the accepted connection
                        listen_sock.close()
                        try:
                            worker(sock)
                        except:
                            traceback.print_exc()
                            os._exit(1)
                        else:
                            os._exit(0)
                    else:
                        # Parent: the child owns the connection now
                        sock.close()

                except OSError as e:
                    print >> sys.stderr, "Daemon failed to fork PySpark worker: %s" % e
                    # Wrap a duplicate of the socket's descriptor so that closing
                    # the file object does not close the socket's own descriptor.
                    outfile = os.fdopen(os.dup(sock.fileno()), "a+", 65536)
                    write_int(-1, outfile)  # Signal that the fork failed
                    outfile.flush()
                    outfile.close()
                    sock.close()
    finally:
        shutdown(1)
Esempio n. 48
0
def launch_gateway(conf=None):
    """
    Launch the JVM gateway (or attach to an existing one) and return a Py4J handle to it.

    If the ``PYSPARK_GATEWAY_PORT`` environment variable is set, no process is started and
    the gateway listening on that port is used. Otherwise ``spark-submit`` is started as a
    child process and the port it reports back over a loopback callback socket is used.

    :param conf: spark configuration passed to spark-submit; every ``(key, value)`` pair
        returned by ``conf.getAll()`` is forwarded as a ``--conf key=value`` argument
    :return: a ``JavaGateway`` connected to the gateway, with the PySpark JVM classes imported
    :raises Exception: if the child process exits before sending the driver its port number
    """
    if "PYSPARK_GATEWAY_PORT" in os.environ:
        gateway_port = int(os.environ["PYSPARK_GATEWAY_PORT"])
    else:
        SPARK_HOME = _find_spark_home()
        # Launch the Py4j gateway using Spark's run command so that we pick up the
        # proper classpath and settings from spark-env.sh
        on_windows = platform.system() == "Windows"
        script = "./bin/spark-submit.cmd" if on_windows else "./bin/spark-submit"
        command = [os.path.join(SPARK_HOME, script)]
        if conf:
            for k, v in conf.getAll():
                command += ['--conf', '%s=%s' % (k, v)]
        submit_args = os.environ.get("PYSPARK_SUBMIT_ARGS", "pyspark-shell")
        if os.environ.get("SPARK_TESTING"):
            submit_args = ' '.join(
                ["--conf spark.ui.enabled=false", submit_args])
        command = command + shlex.split(submit_args)

        # Start a socket that will be used by PythonGatewayServer to communicate its port to us
        callback_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            callback_socket.bind(('127.0.0.1', 0))
            callback_socket.listen(1)
            callback_host, callback_port = callback_socket.getsockname()
            env = dict(os.environ)
            env['_PYSPARK_DRIVER_CALLBACK_HOST'] = callback_host
            env['_PYSPARK_DRIVER_CALLBACK_PORT'] = str(callback_port)

            # Launch the Java gateway.
            # We open a pipe to stdin so that the Java gateway can die when the pipe is broken
            if not on_windows:
                # Don't send ctrl-c / SIGINT to the Java gateway:
                def preexec_func():
                    signal.signal(signal.SIGINT, signal.SIG_IGN)

                proc = Popen(command, stdin=PIPE, preexec_fn=preexec_func, env=env)
            else:
                # preexec_fn not supported on Windows
                proc = Popen(command, stdin=PIPE, env=env)

            gateway_port = None
            # We use select() here in order to avoid blocking indefinitely if the subprocess
            # dies before connecting
            while gateway_port is None and proc.poll() is None:
                timeout = 1  # (seconds)
                readable, _, _ = select.select([callback_socket], [], [], timeout)
                if callback_socket in readable:
                    gateway_connection = callback_socket.accept()[0]
                    try:
                        port_file = gateway_connection.makefile(mode="rb")
                        try:
                            # Determine which ephemeral port the server started on:
                            gateway_port = read_int(port_file)
                        finally:
                            port_file.close()
                    finally:
                        gateway_connection.close()
        finally:
            # Release the listening socket on every path, including when the child exits
            # before connecting or when reading the port fails.
            callback_socket.close()
        if gateway_port is None:
            raise Exception(
                "Java gateway process exited before sending the driver its port number"
            )

        # In Windows, ensure the Java child processes do not linger after Python has exited.
        # In UNIX-based systems, the child process can kill itself on broken pipe (i.e. when
        # the parent process' stdin sends an EOF). In Windows, however, this is not possible
        # because java.lang.Process reads directly from the parent process' stdin, contending
        # with any opportunity to read an EOF from the parent. Note that this is only best
        # effort and will not take effect if the python process is violently terminated.
        if on_windows:
            # In Windows, the child process here is "spark-submit.cmd", not the JVM itself
            # (because the UNIX "exec" command is not available). This means we cannot simply
            # call proc.kill(), which kills only the "spark-submit.cmd" process but not the
            # JVMs. Instead, we use "taskkill" with the tree-kill option "/t" to terminate all
            # child processes in the tree (http://technet.microsoft.com/en-us/library/bb491009.aspx)
            def killChild():
                Popen([
                    "cmd", "/c", "taskkill", "/f", "/t", "/pid",
                    str(proc.pid)
                ])

            atexit.register(killChild)

    # Connect to the gateway
    gateway = JavaGateway(GatewayClient(port=gateway_port), auto_convert=True)

    # Import the classes used by PySpark
    java_import(gateway.jvm, "org.apache.spark.SparkConf")
    java_import(gateway.jvm, "org.apache.spark.api.java.*")
    java_import(gateway.jvm, "org.apache.spark.api.python.*")
    java_import(gateway.jvm, "org.apache.spark.ml.python.*")
    java_import(gateway.jvm, "org.apache.spark.mllib.api.python.*")
    # TODO(davies): move into sql
    java_import(gateway.jvm, "org.apache.spark.sql.*")
    java_import(gateway.jvm, "org.apache.spark.sql.hive.*")
    java_import(gateway.jvm, "scala.Tuple2")

    return gateway
Esempio n. 49
0
def main(infile, outfile):
    """
    Run one task in this Python worker.

    Reads the task description from ``infile`` (partition index, Python version, barrier
    information, task context fields, local properties, include paths, broadcast variables
    and finally the command or UDFs to run), executes it over the remaining input stream
    and writes the results, timing information, spill counters and accumulator values to
    ``outfile``.

    If anything raises while the task is set up or run, the formatted traceback is written
    to ``outfile`` after the ``SpecialLengths.PYTHON_EXCEPTION_THROWN`` marker and the
    process exits with status -1.

    :param infile: stream the task description and input data are read from
    :param outfile: stream the results and status markers are written to
    """
    faulthandler_log_path = os.environ.get("PYTHON_FAULTHANDLER_DIR", None)
    # Stays None until the log file has really been opened, so that the cleanup in the
    # ``finally`` clause never touches an unbound name when opening the file fails.
    faulthandler_log_file = None
    try:
        if faulthandler_log_path:
            faulthandler_log_path = os.path.join(faulthandler_log_path, str(os.getpid()))
            faulthandler_log_file = open(faulthandler_log_path, "w")
            faulthandler.enable(file=faulthandler_log_file)

        boot_time = time.time()
        split_index = read_int(infile)
        if split_index == -1:  # for unit tests
            sys.exit(-1)

        version = utf8_deserializer.loads(infile)
        if version != "%d.%d" % sys.version_info[:2]:
            raise RuntimeError(
                (
                    "Python in worker has different version %s than that in "
                    + "driver %s, PySpark cannot run with different minor versions. "
                    + "Please check environment variables PYSPARK_PYTHON and "
                    + "PYSPARK_DRIVER_PYTHON are correctly set."
                )
                % ("%d.%d" % sys.version_info[:2], version)
            )

        # read inputs only for a barrier task
        isBarrier = read_bool(infile)
        boundPort = read_int(infile)
        secret = UTF8Deserializer().loads(infile)

        # set up memory limits
        memory_limit_mb = int(os.environ.get("PYSPARK_EXECUTOR_MEMORY_MB", "-1"))
        if memory_limit_mb > 0 and has_resource_module:
            total_memory = resource.RLIMIT_AS
            try:
                (soft_limit, hard_limit) = resource.getrlimit(total_memory)
                msg = "Current mem limits: {0} of max {1}\n".format(soft_limit, hard_limit)
                print(msg, file=sys.stderr)

                # convert to bytes
                new_limit = memory_limit_mb * 1024 * 1024

                # Only ever tighten the limit; never raise an existing lower one.
                if soft_limit == resource.RLIM_INFINITY or new_limit < soft_limit:
                    msg = "Setting mem limits to {0} of max {1}\n".format(new_limit, new_limit)
                    print(msg, file=sys.stderr)
                    resource.setrlimit(total_memory, (new_limit, new_limit))

            except (resource.error, OSError, ValueError) as e:
                # not all systems support resource limits, so warn instead of failing
                lineno = (
                    getframeinfo(currentframe()).lineno + 1 if currentframe() is not None else 0
                )
                print(
                    warnings.formatwarning(
                        "Failed to set memory limit: {0}".format(e),
                        ResourceWarning,
                        __file__,
                        lineno,
                    ),
                    file=sys.stderr,
                )

        # initialize global state
        taskContext = None
        if isBarrier:
            taskContext = BarrierTaskContext._getOrCreate()
            BarrierTaskContext._initialize(boundPort, secret)
            # Set the task context instance here, so we can get it by TaskContext.get for
            # both TaskContext and BarrierTaskContext
            TaskContext._setTaskContext(taskContext)
        else:
            taskContext = TaskContext._getOrCreate()
        # read inputs for TaskContext info
        taskContext._stageId = read_int(infile)
        taskContext._partitionId = read_int(infile)
        taskContext._attemptNumber = read_int(infile)
        taskContext._taskAttemptId = read_long(infile)
        taskContext._cpus = read_int(infile)
        # The mapping is created exactly once; re-creating it for every resource would
        # discard all entries except the last one read.
        taskContext._resources = {}
        for _ in range(read_int(infile)):
            key = utf8_deserializer.loads(infile)
            name = utf8_deserializer.loads(infile)
            addresses = [utf8_deserializer.loads(infile) for _ in range(read_int(infile))]
            taskContext._resources[key] = ResourceInformation(name, addresses)

        taskContext._localProperties = dict()
        for _ in range(read_int(infile)):
            k = utf8_deserializer.loads(infile)
            v = utf8_deserializer.loads(infile)
            taskContext._localProperties[k] = v

        shuffle.MemoryBytesSpilled = 0
        shuffle.DiskBytesSpilled = 0
        _accumulatorRegistry.clear()

        # fetch name of workdir
        spark_files_dir = utf8_deserializer.loads(infile)
        SparkFiles._root_directory = spark_files_dir
        SparkFiles._is_running_on_worker = True

        # fetch names of includes (*.zip and *.egg files) and construct PYTHONPATH
        add_path(spark_files_dir)  # *.py files that were added will be copied here
        num_python_includes = read_int(infile)
        for _ in range(num_python_includes):
            filename = utf8_deserializer.loads(infile)
            add_path(os.path.join(spark_files_dir, filename))

        importlib.invalidate_caches()

        # fetch names and values of broadcast variables
        needs_broadcast_decryption_server = read_bool(infile)
        num_broadcast_variables = read_int(infile)
        if needs_broadcast_decryption_server:
            # read the decrypted data from a server in the jvm
            port = read_int(infile)
            auth_secret = utf8_deserializer.loads(infile)
            (broadcast_sock_file, _) = local_connect_and_auth(port, auth_secret)

        for _ in range(num_broadcast_variables):
            bid = read_long(infile)
            if bid >= 0:
                if needs_broadcast_decryption_server:
                    read_bid = read_long(broadcast_sock_file)
                    assert read_bid == bid
                    _broadcastRegistry[bid] = Broadcast(sock_file=broadcast_sock_file)
                else:
                    path = utf8_deserializer.loads(infile)
                    _broadcastRegistry[bid] = Broadcast(path=path)

            else:
                # A negative id encodes "remove broadcast (-bid - 1)".
                bid = -bid - 1
                _broadcastRegistry.pop(bid)

        if needs_broadcast_decryption_server:
            broadcast_sock_file.write(b"1")
            broadcast_sock_file.close()

        _accumulatorRegistry.clear()
        eval_type = read_int(infile)
        if eval_type == PythonEvalType.NON_UDF:
            func, profiler, deserializer, serializer = read_command(pickleSer, infile)
        else:
            func, profiler, deserializer, serializer = read_udfs(pickleSer, infile, eval_type)

        init_time = time.time()

        def process():
            """Feed the deserialized input through ``func`` and serialize its output."""
            iterator = deserializer.load_stream(infile)
            out_iter = func(split_index, iterator)
            try:
                serializer.dump_stream(out_iter, outfile)
            finally:
                # Close generator-like outputs even when dumping the stream fails.
                if hasattr(out_iter, "close"):
                    out_iter.close()

        if profiler:
            profiler.profile(process)
        else:
            process()

        # Reset task context to None. This is a guard code to avoid residual context when worker
        # reuse.
        TaskContext._setTaskContext(None)
        BarrierTaskContext._setTaskContext(None)
    except BaseException as e:
        try:
            exc_info = None
            if os.environ.get("SPARK_SIMPLIFIED_TRACEBACK", False):
                tb = try_simplify_traceback(sys.exc_info()[-1])
                if tb is not None:
                    e.__cause__ = None
                    exc_info = "".join(traceback.format_exception(type(e), e, tb))
            if exc_info is None:
                exc_info = traceback.format_exc()

            write_int(SpecialLengths.PYTHON_EXCEPTION_THROWN, outfile)
            write_with_length(exc_info.encode("utf-8"), outfile)
        except IOError:
            # JVM close the socket
            pass
        except BaseException:
            # Write the error to stderr if it happened while serializing
            print("PySpark worker failed with exception:", file=sys.stderr)
            print(traceback.format_exc(), file=sys.stderr)
        sys.exit(-1)
    finally:
        if faulthandler_log_file is not None:
            faulthandler.disable()
            faulthandler_log_file.close()
            os.remove(faulthandler_log_path)
    finish_time = time.time()
    report_times(outfile, boot_time, init_time, finish_time)
    write_long(shuffle.MemoryBytesSpilled, outfile)
    write_long(shuffle.DiskBytesSpilled, outfile)

    # Mark the beginning of the accumulators section of the output
    write_int(SpecialLengths.END_OF_DATA_SECTION, outfile)
    write_int(len(_accumulatorRegistry), outfile)
    for (aid, accum) in _accumulatorRegistry.items():
        pickleSer._write_with_length((aid, accum._value), outfile)

    # check end of stream
    if read_int(infile) == SpecialLengths.END_OF_STREAM:
        write_int(SpecialLengths.END_OF_STREAM, outfile)
    else:
        # write a different value to tell JVM to not reuse this worker
        write_int(SpecialLengths.END_OF_DATA_SECTION, outfile)
        sys.exit(-1)
Esempio n. 50
0
def read_udfs(pickleSer, infile, eval_type):
    """
    Read the UDFs for ``eval_type`` from ``infile`` and build the function that runs them.

    For the pandas eval types listed below, the runner configuration is read from the
    stream first and used to choose and configure an Arrow-based serializer; every other
    eval type uses a batched pickle serializer. The UDFs themselves are then read with
    ``read_single_udf`` and combined into a single ``func(split_index, iterator)``.

    :param pickleSer: serializer handed through to ``read_single_udf``
    :param infile: stream the configuration and UDF definitions are read from
    :param eval_type: one of the ``PythonEvalType`` constants
    :return: the tuple ``(func, None, ser, ser)``; the ``None`` slot is the profiler
        (profiling is not supported for UDF) and ``ser`` is returned twice
    """
    runner_conf = {}

    if eval_type in (PythonEvalType.SQL_SCALAR_PANDAS_UDF,
                     PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF,
                     PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF,
                     PythonEvalType.SQL_MAP_PANDAS_ITER_UDF,
                     PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
                     PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF,
                     PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF):

        # Load conf used for pandas_udf evaluation
        num_conf = read_int(infile)
        for i in range(num_conf):
            k = utf8_deserializer.loads(infile)
            v = utf8_deserializer.loads(infile)
            runner_conf[k] = v

        # NOTE: if timezone is set here, that implies respectSessionTimeZone is True
        timezone = runner_conf.get("spark.sql.session.timeZone", None)
        # Conf values arrive as strings; anything other than "true" (case-insensitive)
        # is treated as False.
        safecheck = runner_conf.get(
            "spark.sql.execution.pandas.convertToArrowArraySafely",
            "false").lower() == 'true'
        # Used by SQL_GROUPED_MAP_PANDAS_UDF and SQL_SCALAR_PANDAS_UDF when returning StructType
        assign_cols_by_name = runner_conf.get(
            "spark.sql.legacy.execution.pandas.groupedMap.assignColumnsByName", "true")\
            .lower() == "true"

        if eval_type == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF:
            ser = CogroupUDFSerializer(timezone, safecheck,
                                       assign_cols_by_name)
        else:
            # Scalar Pandas UDF handles struct type arguments as pandas DataFrames instead of
            # pandas Series. See SPARK-27240.
            df_for_struct = (
                eval_type == PythonEvalType.SQL_SCALAR_PANDAS_UDF
                or eval_type == PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF
                or eval_type == PythonEvalType.SQL_MAP_PANDAS_ITER_UDF)
            ser = ArrowStreamPandasUDFSerializer(timezone, safecheck,
                                                 assign_cols_by_name,
                                                 df_for_struct)
    else:
        # Non-pandas UDFs exchange pickled rows in batches of 100.
        ser = BatchedSerializer(PickleSerializer(), 100)

    num_udfs = read_int(infile)

    is_scalar_iter = eval_type == PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF
    is_map_iter = eval_type == PythonEvalType.SQL_MAP_PANDAS_ITER_UDF

    # Iterator UDFs receive the whole stream of batches rather than one batch per call,
    # so they are wired up separately and returned early.
    if is_scalar_iter or is_map_iter:
        if is_scalar_iter:
            assert num_udfs == 1, "One SCALAR_ITER UDF expected here."
        if is_map_iter:
            assert num_udfs == 1, "One MAP_ITER UDF expected here."

        arg_offsets, udf = read_single_udf(pickleSer,
                                           infile,
                                           eval_type,
                                           runner_conf,
                                           udf_index=0)

        def func(_, iterator):
            """
            Generator that feeds the selected columns of every input batch to ``udf`` and
            yields its ``(result_batch, result_type)`` pairs, counting rows on both sides.
            """
            num_input_rows = 0

            def map_batch(batch):
                """Pick the UDF's argument columns out of ``batch`` and count its rows."""
                nonlocal num_input_rows

                udf_args = [batch[offset] for offset in arg_offsets]
                # The row count of a batch is taken from its first argument column.
                num_input_rows += len(udf_args[0])
                if len(udf_args) == 1:
                    return udf_args[0]
                else:
                    return tuple(udf_args)

            # Lazy: input rows are only counted as the UDF actually consumes batches.
            iterator = map(map_batch, iterator)
            result_iter = udf(iterator)

            num_output_rows = 0
            for result_batch, result_type in result_iter:
                num_output_rows += len(result_batch)
                # This assert is for Scalar Iterator UDF to fail fast.
                # The length of the entire input can only be explicitly known
                # by consuming the input iterator in user side. Therefore,
                # it's very unlikely the output length is higher than
                # input length.
                assert is_map_iter or num_output_rows <= num_input_rows, \
                    "Pandas SCALAR_ITER UDF outputted more rows than input rows."
                yield (result_batch, result_type)

            if is_scalar_iter:
                # Once the UDF has finished, the input must be fully consumed ...
                try:
                    next(iterator)
                except StopIteration:
                    pass
                else:
                    raise RuntimeError(
                        "pandas iterator UDF should exhaust the input "
                        "iterator.")

                # ... and the total number of output rows must match the input.
                if num_output_rows != num_input_rows:
                    raise RuntimeError(
                        "The length of output in Scalar iterator pandas UDF should be "
                        "the same with the input's; however, the length of output was %d and the "
                        "length of input was %d." %
                        (num_output_rows, num_input_rows))

        # profiling is not supported for UDF
        return func, None, ser, ser

    def extract_key_value_indexes(grouped_arg_offsets):
        """
        Helper function to extract the key and value indexes from arg_offsets for the grouped and
        cogrouped pandas udfs. See BasePandasGroupExec.resolveArgOffsets for equivalent scala code.

        Parameters
        ----------
        grouped_arg_offsets:  list
            Flat list containing the key and value indexes of columns of the
            DataFrames to be passed to the udf. It consists of n repeating groups where n is the
            number of DataFrames.  Each group has the following format (``k = group[1]``):
                group[0]: number of entries that follow in this group
                group[1]: number of key indexes
                group[2 : 2 + k]: key indexes
                group[2 + k : 1 + group[0]]: value indexes

        Returns
        -------
        list
            One ``[key_indexes, value_indexes]`` pair per group, in input order.
        """
        parsed = []
        idx = 0
        while idx < len(grouped_arg_offsets):
            offsets_len = grouped_arg_offsets[idx]
            idx += 1
            # Everything in the group after its length prefix.
            offsets = grouped_arg_offsets[idx:idx + offsets_len]
            # offsets[0] is the key count, so the keys occupy offsets[1:split_index].
            split_index = offsets[0] + 1
            offset_keys = offsets[1:split_index]
            offset_values = offsets[split_index:]
            parsed.append([offset_keys, offset_values])
            idx += offsets_len
        return parsed

    if eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF:
        # We assume there is only one UDF here because grouped map doesn't
        # support combining multiple UDFs.
        assert num_udfs == 1

        # See FlatMapGroupsInPandasExec for how arg_offsets are used to
        # distinguish between grouping attributes and data attributes
        arg_offsets, f = read_single_udf(pickleSer,
                                         infile,
                                         eval_type,
                                         runner_conf,
                                         udf_index=0)
        parsed_offsets = extract_key_value_indexes(arg_offsets)

        # Create function like this:
        #   mapper a: f([a[0]], [a[0], a[1]])
        def mapper(a):
            """Call ``f`` with the key columns and the value columns selected from ``a``."""
            keys = [a[o] for o in parsed_offsets[0][0]]
            vals = [a[o] for o in parsed_offsets[0][1]]
            return f(keys, vals)
    elif eval_type == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF:
        # We assume there is only one UDF here because cogrouped map doesn't
        # support combining multiple UDFs.
        assert num_udfs == 1
        arg_offsets, f = read_single_udf(pickleSer,
                                         infile,
                                         eval_type,
                                         runner_conf,
                                         udf_index=0)

        parsed_offsets = extract_key_value_indexes(arg_offsets)

        def mapper(a):
            """
            Call ``f`` with keys and values of both sides; ``a[0]`` and ``a[1]`` are indexed
            with the first and second parsed offset group respectively.
            """
            df1_keys = [a[0][o] for o in parsed_offsets[0][0]]
            df1_vals = [a[0][o] for o in parsed_offsets[0][1]]
            df2_keys = [a[1][o] for o in parsed_offsets[1][0]]
            df2_vals = [a[1][o] for o in parsed_offsets[1][1]]
            return f(df1_keys, df1_vals, df2_keys, df2_vals)
    else:
        # Every remaining eval type may carry several UDFs, each with its own arg offsets.
        udfs = []
        for i in range(num_udfs):
            udfs.append(
                read_single_udf(pickleSer,
                                infile,
                                eval_type,
                                runner_conf,
                                udf_index=i))

        def mapper(a):
            """Apply every UDF to its own argument columns of ``a``."""
            result = tuple(
                f(*[a[o] for o in arg_offsets]) for (arg_offsets, f) in udfs)
            # In the special case of a single UDF this will return a single result rather
            # than a tuple of results; this is the format that the JVM side expects.
            if len(result) == 1:
                return result[0]
            else:
                return result

    # The partition index (first argument) is not used by UDF evaluation.
    func = lambda _, it: map(mapper, it)

    # profiling is not supported for UDF
    return func, None, ser, ser
Esempio n. 51
0
def launch_gateway():
    """
    Launch the JVM gateway (or attach to an existing one) and return a Py4J handle to it.

    If the ``PYSPARK_GATEWAY_PORT`` environment variable is set, no process is started and
    the gateway listening on that port is used. Otherwise ``spark-submit`` is started as a
    child process and the port it reports back over a loopback callback socket is used.

    :return: a ``JavaGateway`` connected to the gateway, with the PySpark JVM classes imported
    :raises KeyError: if the ``SPARK_HOME`` environment variable is not set
    :raises Exception: if the child process exits before sending the driver its port number
    """
    SPARK_HOME = os.environ["SPARK_HOME"]

    if "PYSPARK_GATEWAY_PORT" in os.environ:
        gateway_port = int(os.environ["PYSPARK_GATEWAY_PORT"])
    else:
        # Launch the Py4j gateway using Spark's run command so that we pick up the
        # proper classpath and settings from spark-env.sh
        on_windows = platform.system() == "Windows"
        script = "./bin/spark-submit.cmd" if on_windows else "./bin/spark-submit"
        submit_args = os.environ.get("PYSPARK_SUBMIT_ARGS", "pyspark-shell")
        command = [os.path.join(SPARK_HOME, script)] + shlex.split(submit_args)

        # Start a socket that will be used by PythonGatewayServer to communicate its port to us
        callback_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            callback_socket.bind(('127.0.0.1', 0))
            callback_socket.listen(1)
            callback_host, callback_port = callback_socket.getsockname()
            env = dict(os.environ)
            env['_PYSPARK_DRIVER_CALLBACK_HOST'] = callback_host
            env['_PYSPARK_DRIVER_CALLBACK_PORT'] = str(callback_port)

            # Launch the Java gateway.
            # We open a pipe to stdin so that the Java gateway can die when the pipe is broken
            if not on_windows:
                # Don't send ctrl-c / SIGINT to the Java gateway:
                def preexec_func():
                    signal.signal(signal.SIGINT, signal.SIG_IGN)
                proc = Popen(command, stdin=PIPE, preexec_fn=preexec_func, env=env)
            else:
                # preexec_fn not supported on Windows
                proc = Popen(command, stdin=PIPE, env=env)

            gateway_port = None
            # We use select() here in order to avoid blocking indefinitely if the subprocess
            # dies before connecting
            while gateway_port is None and proc.poll() is None:
                timeout = 1  # (seconds)
                readable, _, _ = select.select([callback_socket], [], [], timeout)
                if callback_socket in readable:
                    gateway_connection = callback_socket.accept()[0]
                    try:
                        # The port is sent as raw bytes, so the file must be opened in
                        # binary mode (the default text mode cannot be unpacked on Python 3).
                        port_file = gateway_connection.makefile(mode="rb")
                        try:
                            # Determine which ephemeral port the server started on:
                            gateway_port = read_int(port_file)
                        finally:
                            port_file.close()
                    finally:
                        gateway_connection.close()
        finally:
            # Release the listening socket on every path, including when the child exits
            # before connecting or when reading the port fails.
            callback_socket.close()
        if gateway_port is None:
            raise Exception("Java gateway process exited before sending the driver its port number")

        # In Windows, ensure the Java child processes do not linger after Python has exited.
        # In UNIX-based systems, the child process can kill itself on broken pipe (i.e. when
        # the parent process' stdin sends an EOF). In Windows, however, this is not possible
        # because java.lang.Process reads directly from the parent process' stdin, contending
        # with any opportunity to read an EOF from the parent. Note that this is only best
        # effort and will not take effect if the python process is violently terminated.
        if on_windows:
            # In Windows, the child process here is "spark-submit.cmd", not the JVM itself
            # (because the UNIX "exec" command is not available). This means we cannot simply
            # call proc.kill(), which kills only the "spark-submit.cmd" process but not the
            # JVMs. Instead, we use "taskkill" with the tree-kill option "/t" to terminate all
            # child processes in the tree (http://technet.microsoft.com/en-us/library/bb491009.aspx)
            def killChild():
                Popen(["cmd", "/c", "taskkill", "/f", "/t", "/pid", str(proc.pid)])
            atexit.register(killChild)

    # Connect to the gateway
    gateway = JavaGateway(GatewayClient(port=gateway_port), auto_convert=False)

    # Import the classes used by PySpark
    java_import(gateway.jvm, "org.apache.spark.SparkConf")
    java_import(gateway.jvm, "org.apache.spark.api.java.*")
    java_import(gateway.jvm, "org.apache.spark.api.python.*")
    java_import(gateway.jvm, "org.apache.spark.mllib.api.python.*")
    # TODO(davies): move into sql
    java_import(gateway.jvm, "org.apache.spark.sql.*")
    java_import(gateway.jvm, "org.apache.spark.sql.hive.*")
    java_import(gateway.jvm, "scala.Tuple2")

    return gateway
Esempio n. 52
0
    def loads(self, obj):
        """
        Decode ``obj`` into a ``(key, vector)`` pair.

        The payload starts with the key length (read with ``read_int``), followed by that
        many bytes of UTF-8 encoded key; the rest is handed to ``_read_vec``.
        """
        buf = BytesIO(obj)
        n_key_bytes = read_int(buf)
        raw_key = buf.read(n_key_bytes)
        name = raw_key.decode('utf-8')
        # The vector is parsed from whatever follows the key in the same buffer.
        vec = _read_vec(buf)
        return (name, vec)
Esempio n. 53
0
def manager():
    """
    Run the daemon loop that forks one worker process per accepted connection.

    Binds a listening socket on the loopback interface, writes its port to stdout, and
    then waits on both stdin and the socket: an integer read from stdin is treated as the
    pid of a worker to kill, EOF on stdin triggers shutdown, and every accepted connection
    is handed to ``worker`` in a forked child.
    """
    # Create a new process group to corral our children
    os.setpgid(0, 0)

    # Create a listening socket on the AF_INET loopback interface
    listen_sock = socket.socket(AF_INET, SOCK_STREAM)
    listen_sock.bind(('127.0.0.1', 0))
    listen_sock.listen(max(1024, SOMAXCONN))
    listen_host, listen_port = listen_sock.getsockname()
    # Report the chosen (ephemeral) port over stdout.
    write_int(listen_port, sys.stdout)

    def shutdown(code):
        """Restore default SIGTERM handling, signal the process group and exit."""
        signal.signal(SIGTERM, SIG_DFL)
        # Send SIGHUP to notify workers of shutdown
        os.kill(0, SIGHUP)
        exit(code)

    def handle_sigterm(*args):
        """SIGTERM handler: shut down with exit code 1."""
        shutdown(1)

    signal.signal(SIGTERM, handle_sigterm)  # Gracefully exit on SIGTERM
    signal.signal(SIGHUP, SIG_IGN)  # Don't die on SIGHUP

    # Cleanup zombie children
    def handle_sigchld(*args):
        """SIGCHLD handler: reap one exited child without blocking and report crashes."""
        try:
            # NOTE(review): only a single child is reaped per invocation — confirm that is
            # enough when several children exit close together.
            pid, status = os.waitpid(0, os.WNOHANG)
            if status != 0:
                msg = "worker %s crashed abruptly with exit status %s" % (
                    pid, status)
                print >> sys.stderr, msg
        except EnvironmentError as err:
            # No children left, or the call was interrupted: nothing to clean up.
            if err.errno not in (ECHILD, EINTR):
                raise

    signal.signal(SIGCHLD, handle_sigchld)

    # Initialization complete
    sys.stdout.close()
    try:
        while True:
            try:
                # Block until stdin (fd 0) or the listening socket becomes readable.
                ready_fds = select.select([0, listen_sock], [], [])[0]
            except select.error as ex:
                # A signal handler interrupted select(); just retry.
                if ex[0] == EINTR:
                    continue
                else:
                    raise
            if 0 in ready_fds:
                try:
                    worker_pid = read_int(sys.stdin)
                except EOFError:
                    # Spark told us to exit by closing stdin
                    # NOTE(review): if `exit` raises SystemExit, this unwinds through the
                    # `finally: shutdown(1)` below — confirm the final exit status is intended.
                    shutdown(0)
                try:
                    os.kill(worker_pid, signal.SIGKILL)
                except OSError:
                    pass  # process already died

            if listen_sock in ready_fds:
                sock, addr = listen_sock.accept()
                # Launch a worker process
                try:
                    pid = os.fork()
                    if pid == 0:
                        # Child: it only needs the accepted connection.
                        listen_sock.close()
                        try:
                            worker(sock)
                        except:
                            traceback.print_exc()
                            os._exit(1)
                        else:
                            os._exit(0)
                    else:
                        # Parent: the connection now belongs to the child.
                        sock.close()

                except OSError as e:
                    print >> sys.stderr, "Daemon failed to fork PySpark worker: %s" % e
                    # Write through a duplicated descriptor so closing the file object
                    # does not close the socket before sock.close() below.
                    outfile = os.fdopen(os.dup(sock.fileno()), "a+", 65536)
                    write_int(-1, outfile)  # Signal that the fork failed
                    outfile.flush()
                    outfile.close()
                    sock.close()
    finally:
        shutdown(1)
Esempio n. 54
0
def read_udfs(pickleSer, infile, eval_type):
    """
    Read the UDFs for ``eval_type`` from ``infile`` and build the function that runs them.

    For the pandas eval types the runner configuration is read from the stream first and
    used to configure an Arrow-based serializer; every other eval type uses a batched
    pickle serializer. The UDFs are then read with ``read_single_udf`` and combined into a
    per-row ``mapper`` built from plain closures (no ``eval`` of generated source).

    :param pickleSer: serializer handed through to ``read_single_udf``
    :param infile: stream the configuration and UDF definitions are read from
    :param eval_type: one of the ``PythonEvalType`` constants
    :return: the tuple ``(func, None, ser, ser)``; the ``None`` slot is the profiler
        (profiling is not supported for UDF) and ``ser`` is returned twice
    """
    runner_conf = {}

    if eval_type in (PythonEvalType.SQL_SCALAR_PANDAS_UDF,
                     PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
                     PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF,
                     PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF):

        # Load conf used for pandas_udf evaluation
        num_conf = read_int(infile)
        for _ in range(num_conf):
            k = utf8_deserializer.loads(infile)
            v = utf8_deserializer.loads(infile)
            runner_conf[k] = v

        # NOTE: if timezone is set here, that implies respectSessionTimeZone is True
        timezone = runner_conf.get("spark.sql.session.timeZone", None)
        safecheck = runner_conf.get("spark.sql.execution.pandas.arrowSafeTypeConversion",
                                    "false").lower() == 'true'
        # NOTE: this is duplicated from wrap_grouped_map_pandas_udf
        assign_cols_by_name = runner_conf.get(
            "spark.sql.legacy.execution.pandas.groupedMap.assignColumnsByName", "true")\
            .lower() == "true"

        ser = ArrowStreamPandasSerializer(timezone, safecheck, assign_cols_by_name)
    else:
        ser = BatchedSerializer(PickleSerializer(), 100)

    num_udfs = read_int(infile)
    if eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF:
        # We assume there is only one UDF here because grouped map doesn't
        # support combining multiple UDFs.
        assert num_udfs == 1

        # See FlatMapGroupsInPandasExec for how arg_offsets are used to
        # distinguish between grouping attributes and data attributes
        arg_offsets, udf = read_single_udf(
            pickleSer, infile, eval_type, runner_conf, udf_index=0)
        # arg_offsets[0] is the number of key offsets; the keys follow it and the
        # remaining entries are the value offsets.
        split_offset = arg_offsets[0] + 1
        key_offsets = arg_offsets[1:split_offset]
        value_offsets = arg_offsets[split_offset:]

        # Behaves like:
        #   lambda a: f([a[0]], [a[0], a[1]])
        def mapper(a):
            return udf([a[o] for o in key_offsets], [a[o] for o in value_offsets])
    else:
        udfs = [
            read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index=i)
            for i in range(num_udfs)
        ]

        # Behaves like:
        #   lambda a: (f0(a[0]), f1(a[1], a[2]), f2(a[3]))
        def mapper(a):
            # A list comprehension (not a generator expression) is used so that an
            # exception raised by a UDF propagates unchanged.
            result = tuple([f(*[a[o] for o in offsets]) for (offsets, f) in udfs])
            # In the special case of a single UDF this will return a single result rather
            # than a tuple of results; this is the format that the JVM side expects.
            if len(result) == 1:
                return result[0]
            return result

    func = lambda _, it: map(mapper, it)

    # profiling is not supported for UDF
    return func, None, ser, ser
Esempio n. 55
0
def read_udfs(pickleSer, infile, eval_type):
    """
    Read the UDF section of a task from ``infile`` and build the function that evaluates it.

    For the pandas eval types listed in the first ``if`` below, a block of string key/value
    configuration entries is read first and used to configure an Arrow-based serializer;
    every other eval type uses a batched pickle serializer. The UDFs themselves are then
    read one at a time with ``read_single_udf``.

    :param pickleSer: serializer forwarded unchanged to ``read_single_udf``.
    :param infile: input stream the configuration and UDF definitions are read from.
    :param eval_type: one of the ``PythonEvalType`` constants.
    :return: a 4-tuple ``(func, None, ser, ser)``. ``func`` takes ``(_, iterator)`` and
        produces the UDF results for the items of ``iterator``; the second element is
        always None because profiling is not supported for UDFs; the same serializer
        object is returned in the last two positions.
    """
    runner_conf = {}

    if eval_type in (PythonEvalType.SQL_SCALAR_PANDAS_UDF,
                     PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF,
                     PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF,
                     PythonEvalType.SQL_MAP_PANDAS_ITER_UDF,
                     PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
                     PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF,
                     PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF):

        # Load conf used for pandas_udf evaluation
        num_conf = read_int(infile)
        for i in range(num_conf):
            k = utf8_deserializer.loads(infile)
            v = utf8_deserializer.loads(infile)
            runner_conf[k] = v

        # NOTE: if timezone is set here, that implies respectSessionTimeZone is True
        timezone = runner_conf.get("spark.sql.session.timeZone", None)
        safecheck = runner_conf.get(
            "spark.sql.execution.pandas.arrowSafeTypeConversion",
            "false").lower() == 'true'
        # Used by SQL_GROUPED_MAP_PANDAS_UDF and SQL_SCALAR_PANDAS_UDF when returning StructType
        assign_cols_by_name = runner_conf.get(
            "spark.sql.legacy.execution.pandas.groupedMap.assignColumnsByName", "true")\
            .lower() == "true"

        if eval_type == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF:
            ser = CogroupUDFSerializer(timezone, safecheck,
                                       assign_cols_by_name)
        else:
            # Scalar Pandas UDF handles struct type arguments as pandas DataFrames instead of
            # pandas Series. See SPARK-27240.
            df_for_struct = (
                eval_type == PythonEvalType.SQL_SCALAR_PANDAS_UDF
                or eval_type == PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF
                or eval_type == PythonEvalType.SQL_MAP_PANDAS_ITER_UDF)
            ser = ArrowStreamPandasUDFSerializer(timezone, safecheck,
                                                 assign_cols_by_name,
                                                 df_for_struct)
    else:
        ser = BatchedSerializer(PickleSerializer(), 100)

    num_udfs = read_int(infile)

    is_scalar_iter = eval_type == PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF
    is_map_iter = eval_type == PythonEvalType.SQL_MAP_PANDAS_ITER_UDF

    if is_scalar_iter or is_map_iter:
        if is_scalar_iter:
            assert num_udfs == 1, "One SCALAR_ITER UDF expected here."
        if is_map_iter:
            assert num_udfs == 1, "One MAP_ITER UDF expected here."

        arg_offsets, udf = read_single_udf(pickleSer,
                                           infile,
                                           eval_type,
                                           runner_conf,
                                           udf_index=0)

        def func(_, iterator):
            # One-element list so that the nested map_batch can update the running count
            # without rebinding a name in this scope.
            num_input_rows = [0]

            def map_batch(batch):
                # Select the UDF's argument columns from the batch and count the input
                # rows by the length of the first argument.
                udf_args = [batch[offset] for offset in arg_offsets]
                num_input_rows[0] += len(udf_args[0])
                if len(udf_args) == 1:
                    return udf_args[0]
                else:
                    return tuple(udf_args)

            iterator = map(map_batch, iterator)
            result_iter = udf(iterator)

            num_output_rows = 0
            for result_batch, result_type in result_iter:
                num_output_rows += len(result_batch)
                # This check is skipped for MAP_ITER (the condition short-circuits), so a
                # failure here always concerns a SCALAR_ITER UDF.
                assert is_map_iter or num_output_rows <= num_input_rows[0], \
                    "Pandas SCALAR_ITER UDF outputted more rows than input rows."
                yield (result_batch, result_type)

            if is_scalar_iter:
                # The UDF must have consumed every input batch; one more successful
                # next() means it stopped early.
                try:
                    next(iterator)
                except StopIteration:
                    pass
                else:
                    raise RuntimeError(
                        "SQL_SCALAR_PANDAS_ITER_UDF should exhaust the input "
                        "iterator.")

            if is_scalar_iter and num_output_rows != num_input_rows[0]:
                raise RuntimeError(
                    "The number of output rows of pandas iterator UDF should be "
                    "the same with input rows. The input rows number is %d but the "
                    "output rows number is %d." %
                    (num_input_rows[0], num_output_rows))

        # profiling is not supported for UDF
        return func, None, ser, ser

    def extract_key_value_indexes(grouped_arg_offsets):
        """
        Helper function to extract the key and value indexes from arg_offsets for the grouped and
        cogrouped pandas udfs. See BasePandasGroupExec.resolveArgOffsets for equivalent scala code.

        :param grouped_arg_offsets:  List containing the key and value indexes of columns of the
            DataFrames to be passed to the udf. It consists of n repeating groups where n is the
            number of DataFrames.  Each group has the following format (inclusive ranges):
                group[0]: number of entries that follow in this group
                group[1]: number of key indexes
                group[2 .. group[1] + 1]: key indexes
                group[group[1] + 2 .. group[0]]: value indexes
        :return: a list with one ``[key_indexes, value_indexes]`` pair per group.
        """
        parsed = []
        idx = 0
        while idx < len(grouped_arg_offsets):
            offsets_len = grouped_arg_offsets[idx]
            idx += 1
            offsets = grouped_arg_offsets[idx:idx + offsets_len]
            split_index = offsets[0] + 1
            offset_keys = offsets[1:split_index]
            offset_values = offsets[split_index:]
            parsed.append([offset_keys, offset_values])
            idx += offsets_len
        return parsed

    udfs = {}
    call_udf = []
    mapper_str = ""
    if eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF:
        # Create function like this:
        #   lambda a: f([a[0]], [a[0], a[1]])

        # We assume there is only one UDF here because grouped map doesn't
        # support combining multiple UDFs.
        assert num_udfs == 1

        # See FlatMapGroupsInPandasExec for how arg_offsets are used to
        # distinguish between grouping attributes and data attributes
        arg_offsets, udf = read_single_udf(pickleSer,
                                           infile,
                                           eval_type,
                                           runner_conf,
                                           udf_index=0)
        udfs['f'] = udf
        parsed_offsets = extract_key_value_indexes(arg_offsets)
        keys = ["a[%d]" % (o, ) for o in parsed_offsets[0][0]]
        vals = ["a[%d]" % (o, ) for o in parsed_offsets[0][1]]
        mapper_str = "lambda a: f([%s], [%s])" % (", ".join(keys),
                                                  ", ".join(vals))
    elif eval_type == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF:
        # Create function like this:
        #   lambda a: f([a[0][0]], [a[0][0], a[0][1]], [a[1][0]], [a[1][0], a[1][1]])

        # We assume there is only one UDF here because cogrouped map doesn't
        # support combining multiple UDFs.
        assert num_udfs == 1
        arg_offsets, udf = read_single_udf(pickleSer,
                                           infile,
                                           eval_type,
                                           runner_conf,
                                           udf_index=0)
        udfs['f'] = udf
        parsed_offsets = extract_key_value_indexes(arg_offsets)
        df1_keys = ["a[0][%d]" % (o, ) for o in parsed_offsets[0][0]]
        df1_vals = ["a[0][%d]" % (o, ) for o in parsed_offsets[0][1]]
        df2_keys = ["a[1][%d]" % (o, ) for o in parsed_offsets[1][0]]
        df2_vals = ["a[1][%d]" % (o, ) for o in parsed_offsets[1][1]]
        mapper_str = "lambda a: f([%s], [%s], [%s], [%s])" % (
            ", ".join(df1_keys), ", ".join(df1_vals), ", ".join(df2_keys),
            ", ".join(df2_vals))
    else:
        # Create function like this:
        #   lambda a: (f0(a[0]), f1(a[1], a[2]), f2(a[3]))
        # In the special case of a single UDF this will return a single result rather
        # than a tuple of results; this is the format that the JVM side expects.
        for i in range(num_udfs):
            arg_offsets, udf = read_single_udf(pickleSer,
                                               infile,
                                               eval_type,
                                               runner_conf,
                                               udf_index=i)
            udfs['f%d' % i] = udf
            args = ["a[%d]" % o for o in arg_offsets]
            call_udf.append("f%d(%s)" % (i, ", ".join(args)))
        mapper_str = "lambda a: (%s)" % (", ".join(call_udf))

    # NOTE: mapper_str is assembled above only from generated names ("f", "f<i>") and
    # offsets rendered with "%d", and it is evaluated with `udfs` as its globals. Keep it
    # that way: nothing read from the stream as text may be interpolated into this string.
    mapper = eval(mapper_str, udfs)
    func = lambda _, it: map(mapper, it)

    # profiling is not supported for UDF
    return func, None, ser, ser
Esempio n. 56
0
def manager():
    """
    Run the daemon loop: accept connections on a loopback socket and fork one worker
    process per accepted connection.

    The chosen listening port is written to stdout as an int. Ints read from stdin are
    treated as worker pids to be killed with SIGKILL; EOF on stdin triggers shutdown.
    This function does not return normally: it leaves through ``shutdown`` (which calls
    ``sys.exit``) in the parent and through ``os._exit`` in each forked child.
    """
    # Create a new process group to corral our children
    os.setpgid(0, 0)

    # Create a listening socket on the AF_INET loopback interface
    listen_sock = socket.socket(AF_INET, SOCK_STREAM)
    listen_sock.bind(("127.0.0.1", 0))
    listen_sock.listen(max(1024, SOMAXCONN))
    listen_host, listen_port = listen_sock.getsockname()

    # re-open stdin/stdout in binary mode
    stdin_bin = os.fdopen(sys.stdin.fileno(), "rb", 4)
    stdout_bin = os.fdopen(sys.stdout.fileno(), "wb", 4)
    write_int(listen_port, stdout_bin)
    stdout_bin.flush()

    def shutdown(code):
        # Restore the default SIGTERM action first so a second SIGTERM is not routed
        # back into handle_sigterm while we are already shutting down.
        signal.signal(SIGTERM, SIG_DFL)
        # Send SIGHUP to notify workers of shutdown
        os.kill(0, SIGHUP)
        sys.exit(code)

    def handle_sigterm(*args):
        shutdown(1)

    signal.signal(SIGTERM, handle_sigterm)  # Gracefully exit on SIGTERM
    signal.signal(SIGHUP, SIG_IGN)  # Don't die on SIGHUP
    signal.signal(SIGCHLD, SIG_IGN)

    reuse = os.environ.get("SPARK_REUSE_WORKER")

    # Initialization complete
    try:
        while True:
            try:
                ready_fds = select.select([0, listen_sock], [], [], 1)[0]
            except select.error as ex:
                # Exceptions are not subscriptable on Python 3 (where select.error is
                # OSError), so read the errno from .args, which works on 2 and 3 alike.
                if ex.args and ex.args[0] == EINTR:
                    continue
                else:
                    raise

            if 0 in ready_fds:
                try:
                    worker_pid = read_int(stdin_bin)
                except EOFError:
                    # Spark told us to exit by closing stdin
                    shutdown(0)
                try:
                    os.kill(worker_pid, signal.SIGKILL)
                except OSError:
                    pass  # process already died

            if listen_sock in ready_fds:
                try:
                    sock, _ = listen_sock.accept()
                except OSError as e:
                    if e.errno == EINTR:
                        continue
                    raise

                # Launch a worker process
                try:
                    pid = os.fork()
                except OSError as e:
                    if e.errno in (EAGAIN, EINTR):
                        time.sleep(1)
                        pid = os.fork()  # error here will shutdown daemon
                    else:
                        outfile = sock.makefile(mode="wb")
                        write_int(e.errno,
                                  outfile)  # Signal that the fork failed
                        outfile.flush()
                        outfile.close()
                        sock.close()
                        continue

                if pid == 0:
                    # in child process
                    listen_sock.close()

                    # It should close the standard input in the child process so that
                    # Python native function executions stay intact.
                    #
                    # Note that if we just close the standard input (file descriptor 0),
                    # the lowest file descriptor (file descriptor 0) will be allocated,
                    # later when other file descriptors should happen to open.
                    #
                    # Therefore, here we redirect it to '/dev/null' by duplicating
                    # another file descriptor for '/dev/null' to the standard input (0).
                    # See SPARK-26175.
                    with open(os.devnull, "r") as devnull:
                        os.dup2(devnull.fileno(), 0)

                    try:
                        # Acknowledge that the fork was successful
                        outfile = sock.makefile(mode="wb")
                        write_int(os.getpid(), outfile)
                        outfile.flush()
                        outfile.close()
                        authenticated = False
                        while True:
                            code = worker(sock, authenticated)
                            if code == 0:
                                authenticated = True
                            if not reuse or code:
                                # wait for closing
                                try:
                                    while sock.recv(1024):
                                        pass
                                except Exception:
                                    pass
                                break
                            gc.collect()
                    except:
                        # Deliberately catch everything: the child must never fall back
                        # into the parent's accept loop, so it always ends in os._exit.
                        traceback.print_exc()
                        os._exit(1)
                    else:
                        os._exit(0)
                else:
                    # Parent: the child owns the connection now.
                    sock.close()

    finally:
        # NOTE(review): sys.exit() raises SystemExit, so this clause also runs after the
        # shutdown(0) call above and the final exit status appears to become 1 -- confirm
        # that this is intended.
        shutdown(1)
Esempio n. 57
0
def main(infile, outfile):
    """
    Run one task: read its setup and commands from ``infile``, evaluate them over the
    input stream, and write the results plus bookkeeping data to ``outfile``.

    The setup is read in this order: split index, Python "major.minor" version of the
    sender, files directory, include file names, broadcast variable updates, the
    row-based flag and the number of commands, followed by the commands themselves.

    :param infile: input stream the setup, the commands and the data are read from.
    :param outfile: output stream for the results, timings, spill counters and
        accumulator values.
    :raises SystemExit: with code -1 when the split index is -1, when any exception is
        raised while setting up or processing the task (after attempting to report the
        traceback on ``outfile``), or when the end-of-stream marker is not found.
    """
    try:
        boot_time = time.time()
        split_index = read_int(infile)
        if split_index == -1:  # for unit tests
            # sys.exit rather than the site-provided exit() builtin, which is not
            # guaranteed to exist (e.g. when the interpreter runs without `site`).
            sys.exit(-1)

        version = utf8_deserializer.loads(infile)
        if version != "%d.%d" % sys.version_info[:2]:
            raise Exception(("Python in worker has different version %s than that in " +
                             "driver %s, PySpark cannot run with different minor versions") %
                            ("%d.%d" % sys.version_info[:2], version))

        # initialize global state
        shuffle.MemoryBytesSpilled = 0
        shuffle.DiskBytesSpilled = 0
        _accumulatorRegistry.clear()

        # fetch name of workdir
        spark_files_dir = utf8_deserializer.loads(infile)
        SparkFiles._root_directory = spark_files_dir
        SparkFiles._is_running_on_worker = True

        # fetch names of includes (*.zip and *.egg files) and construct PYTHONPATH
        add_path(spark_files_dir)  # *.py files that were added will be copied here
        num_python_includes = read_int(infile)
        for _ in range(num_python_includes):
            filename = utf8_deserializer.loads(infile)
            add_path(os.path.join(spark_files_dir, filename))
        # Compare the numeric version: comparing the sys.version string against '3' is
        # lexicographic and would be False for a two-digit major version.
        if sys.version_info[0] >= 3:
            import importlib
            importlib.invalidate_caches()

        # fetch names and values of broadcast variables
        num_broadcast_variables = read_int(infile)
        for _ in range(num_broadcast_variables):
            bid = read_long(infile)
            if bid >= 0:
                path = utf8_deserializer.loads(infile)
                _broadcastRegistry[bid] = Broadcast(path=path)
            else:
                # A negative id encodes the removal of broadcast variable (-bid - 1)
                bid = - bid - 1
                _broadcastRegistry.pop(bid)

        _accumulatorRegistry.clear()
        row_based = read_int(infile)
        num_commands = read_int(infile)
        if row_based:
            profiler = None  # profiling is not supported for UDF
            row_func = None
            # Compose all commands into a single row function; the return type and
            # deserializer of the last command read are the ones used below.
            for i in range(num_commands):
                f, returnType, deserializer = read_command(pickleSer, infile)
                if row_func is None:
                    row_func = f
                else:
                    row_func = chain(row_func, f)
            serializer = deserializer
            func = lambda _, it: map(lambda x: returnType.toInternal(row_func(*x)), it)
        else:
            assert num_commands == 1
            func, profiler, deserializer, serializer = read_command(pickleSer, infile)

        init_time = time.time()

        def process():
            iterator = deserializer.load_stream(infile)
            serializer.dump_stream(func(split_index, iterator), outfile)

        if profiler:
            profiler.profile(process)
        else:
            process()
    except Exception:
        try:
            write_int(SpecialLengths.PYTHON_EXCEPTION_THROWN, outfile)
            write_with_length(traceback.format_exc().encode("utf-8"), outfile)
        except IOError:
            # JVM close the socket
            pass
        except Exception:
            # Write the error to stderr if it happened while serializing
            print("PySpark worker failed with exception:", file=sys.stderr)
            print(traceback.format_exc(), file=sys.stderr)
        sys.exit(-1)
    finish_time = time.time()
    report_times(outfile, boot_time, init_time, finish_time)
    write_long(shuffle.MemoryBytesSpilled, outfile)
    write_long(shuffle.DiskBytesSpilled, outfile)

    # Mark the beginning of the accumulators section of the output
    write_int(SpecialLengths.END_OF_DATA_SECTION, outfile)
    write_int(len(_accumulatorRegistry), outfile)
    for (aid, accum) in _accumulatorRegistry.items():
        pickleSer._write_with_length((aid, accum._value), outfile)

    # check end of stream
    if read_int(infile) == SpecialLengths.END_OF_STREAM:
        write_int(SpecialLengths.END_OF_STREAM, outfile)
    else:
        # write a different value to tell JVM to not reuse this worker
        write_int(SpecialLengths.END_OF_DATA_SECTION, outfile)
        sys.exit(-1)