示例#1
0
def _load_from_socket(port, auth_secret, function, all_gather_message=None):
    """
    Perform a blocking barrier()/all_gather() call over a local authenticated socket.

    The call blocks (no timeout) until the reply has been read.

    :param port: port of the local server, passed to ``local_connect_and_auth``.
    :param auth_secret: secret passed to ``local_connect_and_auth``.
    :param function: ``BARRIER_FUNCTION`` or ``ALL_GATHER_FUNCTION``.
    :param all_gather_message: text sent (UTF-8, length-prefixed) with an
        all_gather() call; unused for a barrier() call.
    :return: the reply decoded by ``UTF8Deserializer``.
    :raises ValueError: if ``function`` is not a recognized function type.
    """
    (sockfile, sock) = local_connect_and_auth(port, auth_secret)
    try:
        # The call may block forever, so no timeout
        sock.settimeout(None)

        if function == BARRIER_FUNCTION:
            # Make a barrier() function call.
            write_int(function, sockfile)
        elif function == ALL_GATHER_FUNCTION:
            # Make an all_gather() function call.
            write_int(function, sockfile)
            write_with_length(all_gather_message.encode("utf-8"), sockfile)
        else:
            raise ValueError("Unrecognized function type")
        sockfile.flush()

        # Collect result.
        return UTF8Deserializer().loads(sockfile)
    finally:
        # Release resources on every path, including errors, so the
        # connection is never leaked.
        sockfile.close()
        sock.close()
示例#2
0
def main(infile, outfile):
    """
    Run one PySpark task read from ``infile`` and write its output to ``outfile``.

    Input, in order: split index, Spark files directory, Python include names,
    broadcast variables (a negative id asks for removal), then the pickled
    (func, deserializer, serializer) command followed by the data stream.

    Output: the serialized results, timing information, an END_OF_DATA_SECTION
    marker, the accumulator count and one pickled (id, value) pair per accumulator.
    On failure the traceback is sent to the JVM (or to stderr if that fails too)
    and the process exits with status -1.
    """
    try:
        boot_time = time.time()
        split_index = read_int(infile)
        if split_index == -1:  # for unit tests
            return

        # fetch name of workdir
        spark_files_dir = utf8_deserializer.loads(infile)
        SparkFiles._root_directory = spark_files_dir
        SparkFiles._is_running_on_worker = True

        # fetch names of includes (*.zip and *.egg files) and construct PYTHONPATH
        sys.path.append(spark_files_dir)  # *.py files that were added will be copied here
        num_python_includes = read_int(infile)
        for _ in range(num_python_includes):
            filename = utf8_deserializer.loads(infile)
            sys.path.append(os.path.join(spark_files_dir, filename))

        # fetch names and values of broadcast variables
        num_broadcast_variables = read_int(infile)
        ser = CompressedSerializer(pickleSer)
        for _ in range(num_broadcast_variables):
            bid = read_long(infile)
            if bid >= 0:
                value = ser._read_with_length(infile)
                _broadcastRegistry[bid] = Broadcast(bid, value)
            else:
                # A negative id encodes "drop broadcast -(bid + 1)".
                # The registry is a dict (it is indexed above), and dicts have
                # no remove() method, so pop() the entry instead.
                bid = - bid - 1
                _broadcastRegistry.pop(bid)

        _accumulatorRegistry.clear()
        command = pickleSer._read_with_length(infile)
        (func, deserializer, serializer) = command
        init_time = time.time()
        iterator = deserializer.load_stream(infile)
        serializer.dump_stream(func(split_index, iterator), outfile)
    except Exception:
        try:
            write_int(SpecialLengths.PYTHON_EXCEPTION_THROWN, outfile)
            write_with_length(traceback.format_exc(), outfile)
            outfile.flush()
        except IOError:
            # JVM close the socket
            pass
        except Exception:
            # Write the error to stderr if it happened while serializing.
            # sys.stderr.write works on both Python 2 and 3, unlike `print >>`.
            sys.stderr.write("PySpark worker failed with exception:\n")
            sys.stderr.write(traceback.format_exc() + "\n")
        # sys.exit, not the site-provided exit(), which is meant for interactive use.
        sys.exit(-1)
    finish_time = time.time()
    report_times(outfile, boot_time, init_time, finish_time)
    # Mark the beginning of the accumulators section of the output
    write_int(SpecialLengths.END_OF_DATA_SECTION, outfile)
    write_int(len(_accumulatorRegistry), outfile)
    for (aid, accum) in _accumulatorRegistry.items():
        pickleSer._write_with_length((aid, accum._value), outfile)
示例#3
0
def do_server_auth(conn, auth_secret):
    """
    Performs the authentication protocol defined by the SocketAuthHelper class on the given
    file-like object 'conn'.

    The secret is sent UTF-8 encoded and length-prefixed; the peer must reply "ok".

    :param conn: file-like object connected to the server.
    :param auth_secret: the shared secret (str) to present.
    :raises RuntimeError: if the reply is anything other than "ok" (``conn`` is
        closed first). RuntimeError subclasses Exception, so existing
        ``except Exception`` handlers still catch it.
    """
    write_with_length(auth_secret.encode("utf-8"), conn)
    conn.flush()
    reply = UTF8Deserializer().loads(conn)
    if reply != "ok":
        conn.close()
        raise RuntimeError("Unexpected reply from iterator server.")
示例#4
0
def _do_server_auth(conn, auth_secret):
    """
    Authenticate to the server over the file-like object ``conn``, following the
    SocketAuthHelper protocol.

    The secret goes out UTF-8 encoded and length-prefixed. A reply of "ok" means
    success; any other reply closes ``conn`` and raises RuntimeError.
    """
    payload = auth_secret.encode("utf-8")
    write_with_length(payload, conn)
    conn.flush()

    if UTF8Deserializer().loads(conn) == "ok":
        return

    conn.close()
    raise RuntimeError("Unexpected reply from iterator server.")
示例#5
0
def main(infile, outfile):
    """
    Run one PySpark task read from ``infile`` and write its output to ``outfile``.

    Input, in order: split index, Spark files directory, Python include names,
    broadcast variables, then the pickled (func, deserializer, serializer)
    command followed by the data stream.

    Output: the serialized results, timing information, an END_OF_DATA_SECTION
    marker, the accumulator count and one pickled (id, value) pair per accumulator.
    On failure the traceback is sent to the JVM (or to stderr if that fails too)
    and the process exits with status -1.
    """
    try:
        boot_time = time.time()
        split_index = read_int(infile)
        if split_index == -1:  # for unit tests
            return

        # fetch name of workdir
        spark_files_dir = utf8_deserializer.loads(infile)
        SparkFiles._root_directory = spark_files_dir
        SparkFiles._is_running_on_worker = True

        # fetch names of includes (*.zip and *.egg files) and construct PYTHONPATH
        sys.path.append(spark_files_dir)  # *.py files that were added will be copied here
        num_python_includes = read_int(infile)
        for _ in range(num_python_includes):
            filename = utf8_deserializer.loads(infile)
            sys.path.append(os.path.join(spark_files_dir, filename))

        # fetch names and values of broadcast variables
        num_broadcast_variables = read_int(infile)
        ser = CompressedSerializer(pickleSer)
        for _ in range(num_broadcast_variables):
            bid = read_long(infile)
            value = ser._read_with_length(infile)
            _broadcastRegistry[bid] = Broadcast(bid, value)

        command = pickleSer._read_with_length(infile)
        (func, deserializer, serializer) = command
        init_time = time.time()
        iterator = deserializer.load_stream(infile)
        serializer.dump_stream(func(split_index, iterator), outfile)
    except Exception:
        try:
            write_int(SpecialLengths.PYTHON_EXCEPTION_THROWN, outfile)
            write_with_length(traceback.format_exc(), outfile)
            outfile.flush()
        except IOError:
            # JVM close the socket
            pass
        except Exception:
            # Write the error to stderr if it happened while serializing.
            # sys.stderr.write works on both Python 2 and 3, unlike `print >>`.
            sys.stderr.write("PySpark worker failed with exception:\n")
            sys.stderr.write(traceback.format_exc() + "\n")
        # sys.exit, not the site-provided exit(), which is meant for interactive use.
        sys.exit(-1)
    finish_time = time.time()
    report_times(outfile, boot_time, init_time, finish_time)
    # Mark the beginning of the accumulators section of the output
    write_int(SpecialLengths.END_OF_DATA_SECTION, outfile)
    write_int(len(_accumulatorRegistry), outfile)
    for (aid, accum) in _accumulatorRegistry.items():
        pickleSer._write_with_length((aid, accum._value), outfile)
示例#6
0
文件: worker.py 项目: fernand/spark
def main():
    """
    Worker entry point: read one task from stdin and stream its results to old_stdout.

    Reads the split index, the broadcast variables, the pickled function and a
    bypass-serializer flag, then writes each result of the function length-prefixed.
    """
    split_index = read_int(sys.stdin)

    # Register every broadcast variable shipped with this task.
    broadcast_count = read_int(sys.stdin)
    for _ in range(broadcast_count):
        broadcast_id = read_long(sys.stdin)
        raw_value = read_with_length(sys.stdin)
        _broadcastRegistry[broadcast_id] = Broadcast(broadcast_id, load_pickle(raw_value))

    func = load_obj()
    bypass = load_obj()
    # With the serializer bypassed, results are written out unchanged.
    encode = (lambda obj: obj) if bypass else dump_pickle

    records = read_from_pickle_file(sys.stdin)
    for result in func(split_index, records):
        write_with_length(encode(result), old_stdout)
示例#7
0
文件: context.py 项目: fernand/spark
 def parallelize(self, c, numSlices=None):
     """
     Distribute a local Python collection to form an RDD.

     :param c: local iterable whose elements become the RDD's records.
     :param numSlices: number of partitions; falls back to
         ``self.defaultParallelism`` when not given (or falsy).
     :return: an ``RDD`` read from a temporary pickle file.
     """
     numSlices = numSlices or self.defaultParallelism
     # Calling the Java parallelize() method with an ArrayList is too slow,
     # because it sends O(n) Py4J commands.  As an alternative, serialized
     # objects are written to a file and loaded through textFile().
     tempFile = NamedTemporaryFile(delete=False)
     atexit.register(lambda: os.unlink(tempFile.name))
     try:
         if self.batchSize != 1:
             c = batched(c, self.batchSize)
         for x in c:
             write_with_length(dump_pickle(x), tempFile)
     finally:
         # Close even if pickling fails: the atexit callback above keeps a
         # reference to tempFile, so otherwise its descriptor stays open
         # until interpreter exit.
         tempFile.close()
     jrdd = self._readRDDFromPickleFile(self._jsc, tempFile.name, numSlices)
     return RDD(jrdd, self)
示例#8
0
 def parallelize(self, c, numSlices=None):
     """
     Turn a local Python collection into an RDD.

     Elements are pickled (batched when ``self.batchSize`` is not 1) into a
     temporary file under ``self._temp_dir``, which the JVM then loads via
     ``PythonRDD.readRDDFromPickleFile``.
     """
     # Shipping an ArrayList through Java's parallelize() costs O(n) Py4J
     # commands, so the data travels through a pickle file on disk instead.
     slices = numSlices or self.defaultParallelism
     tempFile = NamedTemporaryFile(delete=False, dir=self._temp_dir)
     items = c if self.batchSize == 1 else batched(c, self.batchSize)
     for item in items:
         write_with_length(dump_pickle(item), tempFile)
     tempFile.close()
     reader = self._jvm.PythonRDD.readRDDFromPickleFile
     return RDD(reader(self._jsc, tempFile.name, slices), self)
示例#9
0
def main(infile, outfile):
    """
    Execute a single task read from ``infile`` and write its output to ``outfile``.

    Input, in order: split index, Spark files directory, broadcast variables,
    Python include names, then the pickled (func, deserializer, serializer)
    command followed by the data stream.

    Output: the serialized results, timing information, an END_OF_DATA_SECTION
    marker, the accumulator count and one pickled (id, value) pair per
    accumulator. If running the function fails, PYTHON_EXCEPTION_THROWN and the
    traceback are written instead and the process exits with status -1.
    """
    boot_time = time.time()
    split_index = read_int(infile)
    if split_index == -1:  # for unit tests
        return

    # Directory into which SparkFiles are placed on this worker.
    files_dir = mutf8_deserializer.loads(infile)
    SparkFiles._root_directory = files_dir
    SparkFiles._is_running_on_worker = True

    # Broadcast variables arrive as (id, pickled value) pairs.
    for _ in range(read_int(infile)):
        broadcast_id = read_long(infile)
        _broadcastRegistry[broadcast_id] = Broadcast(
            broadcast_id, pickleSer._read_with_length(infile))

    # Added *.py files land in files_dir itself; *.zip / *.egg includes follow by name.
    sys.path.append(files_dir)
    for _ in range(read_int(infile)):
        sys.path.append(os.path.join(files_dir, mutf8_deserializer.loads(infile)))

    func, deserializer, serializer = pickleSer._read_with_length(infile)
    init_time = time.time()
    try:
        records = deserializer.load_stream(infile)
        serializer.dump_stream(func(split_index, records), outfile)
    except Exception:
        write_int(SpecialLengths.PYTHON_EXCEPTION_THROWN, outfile)
        write_with_length(traceback.format_exc(), outfile)
        sys.exit(-1)

    finish_time = time.time()
    report_times(outfile, boot_time, init_time, finish_time)

    # The accumulators section starts after this marker.
    write_int(SpecialLengths.END_OF_DATA_SECTION, outfile)
    write_int(len(_accumulatorRegistry), outfile)
    for aid, accum in _accumulatorRegistry.items():
        pickleSer._write_with_length((aid, accum._value), outfile)
示例#10
0
文件: worker.py 项目: zhngc3170/spark
def main():
    """
    Worker entry point: read one task from stdin and write its output to old_stdout.

    Results are written length-prefixed. If the function raises, -2 and the
    traceback are written and the process exits with status -1. Otherwise -1 is
    written, followed by one pickled (id, value) pair per accumulator.
    """
    split_index = read_int(sys.stdin)

    files_dir = load_pickle(read_with_length(sys.stdin))
    SparkFiles._root_directory = files_dir
    SparkFiles._is_running_on_worker = True
    sys.path.append(files_dir)

    # Broadcast variables: (id, pickled value) pairs.
    for _ in range(read_int(sys.stdin)):
        broadcast_id = read_long(sys.stdin)
        payload = read_with_length(sys.stdin)
        _broadcastRegistry[broadcast_id] = Broadcast(broadcast_id, load_pickle(payload))

    func = load_obj()
    bypass = load_obj()

    def _passthrough(obj):
        return obj

    encode = _passthrough if bypass else dump_pickle

    records = read_from_pickle_file(sys.stdin)
    try:
        for result in func(split_index, records):
            write_with_length(encode(result), old_stdout)
    except Exception:
        write_int(-2, old_stdout)
        write_with_length(traceback.format_exc(), old_stdout)
        sys.exit(-1)

    # -1 marks the start of the accumulators section.
    write_int(-1, old_stdout)
    for aid, accum in _accumulatorRegistry.items():
        write_with_length(dump_pickle((aid, accum._value)), old_stdout)
示例#11
0
文件: worker.py 项目: Alienfeel/spark
def main():
    """
    Read one task from stdin, run it, and stream the results to old_stdout.

    Each result is written length-prefixed. On failure, -2 and the traceback
    are written and the process exits with status -1. On success, -1 is written,
    followed by one pickled (id, value) pair per accumulator.
    """
    split_index = read_int(sys.stdin)

    spark_files_dir = load_pickle(read_with_length(sys.stdin))
    SparkFiles._root_directory = spark_files_dir
    SparkFiles._is_running_on_worker = True
    sys.path.append(spark_files_dir)

    remaining = read_int(sys.stdin)
    while remaining > 0:
        remaining -= 1
        bid = read_long(sys.stdin)
        _broadcastRegistry[bid] = Broadcast(bid, load_pickle(read_with_length(sys.stdin)))

    func = load_obj()
    bypass_serializer = load_obj()
    serialize = dump_pickle if not bypass_serializer else (lambda obj: obj)

    iterator = read_from_pickle_file(sys.stdin)
    try:
        for item in func(split_index, iterator):
            write_with_length(serialize(item), old_stdout)
    except Exception:
        write_int(-2, old_stdout)
        write_with_length(traceback.format_exc(), old_stdout)
        sys.exit(-1)

    # Mark the beginning of the accumulators section of the output
    write_int(-1, old_stdout)
    for aid, accum in _accumulatorRegistry.items():
        write_with_length(dump_pickle((aid, accum._value)), old_stdout)
示例#12
0
文件: daemon.py 项目: zoelin7/spark
def worker(sock, authenticated):
    """
    Called by a worker process after the fork().

    Restores default signal handlers, wraps ``sock`` in buffered read/write file
    objects, checks the client's secret against PYTHON_WORKER_FACTORY_SECRET
    unless ``authenticated`` is true, and then runs ``worker_main``.

    :param sock: connected socket for this worker.
    :param authenticated: True if the connection has already been authenticated.
    :return: the exit code: 1 on authentication failure, otherwise 0 or the code
        derived from a SystemExit raised by ``worker_main``.
    """
    import hmac

    signal.signal(SIGHUP, SIG_DFL)
    signal.signal(SIGCHLD, SIG_DFL)
    signal.signal(SIGTERM, SIG_DFL)
    # restore the handler for SIGINT,
    # it's useful for debugging (show the stacktrace before exit)
    signal.signal(SIGINT, signal.default_int_handler)

    # Read the socket using fdopen instead of socket.makefile() because the latter
    # seems to be very slow; note that we need to dup() the file descriptor because
    # otherwise writes also cause a seek that makes us miss data on the read side.
    buffer_size = int(os.environ.get("SPARK_BUFFER_SIZE", 65536))
    infile = os.fdopen(os.dup(sock.fileno()), "rb", buffer_size)
    outfile = os.fdopen(os.dup(sock.fileno()), "wb", buffer_size)

    if not authenticated:
        client_secret = UTF8Deserializer().loads(infile)
        expected_secret = os.environ["PYTHON_WORKER_FACTORY_SECRET"]
        # Compare in constant time so the secret cannot be recovered through a
        # timing side channel. Both sides are encoded to bytes first, because
        # compare_digest only accepts ASCII-only str or bytes-like objects.
        if hmac.compare_digest(expected_secret.encode("utf-8"),
                               client_secret.encode("utf-8")):
            write_with_length("ok".encode("utf-8"), outfile)
            outfile.flush()
        else:
            write_with_length("err".encode("utf-8"), outfile)
            outfile.flush()
            sock.close()
            return 1

    exit_code = 0
    try:
        worker_main(infile, outfile)
    except SystemExit as exc:
        exit_code = compute_real_exit_code(exc.code)
    finally:
        try:
            outfile.flush()
        except Exception:
            pass
    return exit_code
示例#13
0
文件: daemon.py 项目: BaiBenny/spark
def worker(sock, authenticated):
    """
    Called by a worker process after the fork().

    Restores default signal handlers, wraps ``sock`` in 64 KiB-buffered read/write
    file objects, checks the client's secret against PYTHON_WORKER_FACTORY_SECRET
    unless ``authenticated`` is true, and then runs ``worker_main``.

    :param sock: connected socket for this worker.
    :param authenticated: True if the connection has already been authenticated.
    :return: the exit code: 1 on authentication failure, otherwise 0 or the code
        derived from a SystemExit raised by ``worker_main``.
    """
    import hmac

    signal.signal(SIGHUP, SIG_DFL)
    signal.signal(SIGCHLD, SIG_DFL)
    signal.signal(SIGTERM, SIG_DFL)
    # restore the handler for SIGINT,
    # it's useful for debugging (show the stacktrace before exit)
    signal.signal(SIGINT, signal.default_int_handler)

    # Read the socket using fdopen instead of socket.makefile() because the latter
    # seems to be very slow; note that we need to dup() the file descriptor because
    # otherwise writes also cause a seek that makes us miss data on the read side.
    infile = os.fdopen(os.dup(sock.fileno()), "rb", 65536)
    outfile = os.fdopen(os.dup(sock.fileno()), "wb", 65536)

    if not authenticated:
        client_secret = UTF8Deserializer().loads(infile)
        expected_secret = os.environ["PYTHON_WORKER_FACTORY_SECRET"]
        # Compare in constant time so the secret cannot be recovered through a
        # timing side channel. Both sides are encoded to bytes first, because
        # compare_digest only accepts ASCII-only str or bytes-like objects.
        if hmac.compare_digest(expected_secret.encode("utf-8"),
                               client_secret.encode("utf-8")):
            write_with_length("ok".encode("utf-8"), outfile)
            outfile.flush()
        else:
            write_with_length("err".encode("utf-8"), outfile)
            outfile.flush()
            sock.close()
            return 1

    exit_code = 0
    try:
        worker_main(infile, outfile)
    except SystemExit as exc:
        exit_code = compute_real_exit_code(exc.code)
    finally:
        try:
            outfile.flush()
        except Exception:
            pass
    return exit_code
示例#14
0
def _load_from_socket(
    port: Optional[Union[str, int]],
    auth_secret: str,
    function: int,
    all_gather_message: Optional[str] = None,
) -> List[str]:
    """
    Perform a blocking barrier()/all_gather() call over a local authenticated socket.

    The call blocks (no timeout) until the full reply has been read.

    :param port: port of the local server, passed to ``local_connect_and_auth``.
    :param auth_secret: secret passed to ``local_connect_and_auth``.
    :param function: ``BARRIER_FUNCTION`` or ``ALL_GATHER_FUNCTION``.
    :param all_gather_message: text sent (UTF-8, length-prefixed) with an
        all_gather() call; unused for a barrier() call.
    :return: the strings sent back by the server. The reply is an int count
        followed by that many UTF-8 messages.
    :raises ValueError: if ``function`` is not a recognized function type.
    """
    (sockfile, sock) = local_connect_and_auth(port, auth_secret)
    try:
        # The call may block forever, so no timeout
        sock.settimeout(None)

        if function == BARRIER_FUNCTION:
            # Make a barrier() function call.
            write_int(function, sockfile)
        elif function == ALL_GATHER_FUNCTION:
            # Make an all_gather() function call.
            write_int(function, sockfile)
            write_with_length(
                cast(str, all_gather_message).encode("utf-8"), sockfile)
        else:
            raise ValueError("Unrecognized function type")
        sockfile.flush()

        # Collect result: a count, then that many UTF-8 strings.
        num_messages = read_int(sockfile)
        deserializer = UTF8Deserializer()
        return [deserializer.loads(sockfile) for _ in range(num_messages)]
    finally:
        # Release resources on every path, including errors, so the
        # connection is never leaked.
        sockfile.close()
        sock.close()
示例#15
0
    def parallelize(self, c, numSlices=None):
        """
        Distribute a local Python collection to form an RDD.

        ``c`` may be any iterable; unsized iterables (e.g. generators) are
        materialized into a list first. ``numSlices`` falls back to
        ``self.defaultParallelism`` when not given.

        >>> sc.parallelize(range(5), 5).glom().collect()
        [[0], [1], [2], [3], [4]]
        """
        numSlices = numSlices or self.defaultParallelism
        # Calling the Java parallelize() method with an ArrayList is too slow,
        # because it sends O(n) Py4J commands.  As an alternative, serialized
        # objects are written to a file and loaded through textFile().
        tempFile = NamedTemporaryFile(delete=False, dir=self._temp_dir)
        try:
            # Make sure we distribute data evenly if it's smaller than self.batchSize.
            # Ask len() directly (EAFP) instead of scanning dir(c), which is only
            # a convenience listing and not a reliable capability check.
            try:
                size = len(c)
            except TypeError:
                c = list(c)  # Make it a list so we can compute its length
                size = len(c)
            batchSize = min(size // numSlices, self.batchSize)
            if batchSize > 1:
                c = batched(c, batchSize)
            for x in c:
                write_with_length(dump_pickle(x), tempFile)
        finally:
            # Always release the file handle, even if pickling an element fails.
            tempFile.close()
        readRDDFromPickleFile = self._jvm.PythonRDD.readRDDFromPickleFile
        jrdd = readRDDFromPickleFile(self._jsc, tempFile.name, numSlices)
        return RDD(jrdd, self)
示例#16
0
    def parallelize(self, c, numSlices=None):
        """
        Create an RDD from a local Python collection.

        The elements are pickled into a temporary file under ``self._temp_dir``
        (in batches when that helps) and loaded by the JVM from there.

        >>> sc.parallelize(range(5), 5).glom().collect()
        [[0], [1], [2], [3], [4]]
        """
        # Java's parallelize() on an ArrayList issues O(n) Py4J commands, which
        # is too slow, so the elements travel through a pickle file instead.
        slices = numSlices or self.defaultParallelism
        tempFile = NamedTemporaryFile(delete=False, dir=self._temp_dir)
        # Only sized collections can be measured; anything else is materialized.
        data = c if "__len__" in dir(c) else list(c)
        # Shrink the batch size for small inputs so the data spreads across slices.
        batch_size = min(len(data) // slices, self.batchSize)
        if batch_size > 1:
            data = batched(data, batch_size)
        for record in data:
            write_with_length(dump_pickle(record), tempFile)
        tempFile.close()
        reader = self._jvm.PythonRDD.readRDDFromPickleFile
        return RDD(reader(self._jsc, tempFile.name, slices), self)
示例#17
0
文件: worker.py 项目: ljzzju/spark-1
def main(infile, outfile):
    """
    Run one task described on ``infile`` and write its output to ``outfile``.

    Input, in order: split index, pickled Spark files directory, broadcast
    variables, include file names, the pickled function, the bypass-serializer
    flag, then the pickled input records.

    Output: each result length-prefixed, timing information, then -1, one
    pickled (id, value) pair per accumulator, and a trailing -1. If the
    function raises, -2 and the traceback are written and the process exits
    with status -1.
    """
    boot_time = time.time()
    split_index = read_int(infile)
    if split_index == -1:  # for unit tests
        return

    # Working directory for SparkFiles on this worker.
    files_dir = load_pickle(read_with_length(infile))
    SparkFiles._root_directory = files_dir
    SparkFiles._is_running_on_worker = True

    # Broadcast variables: (id, pickled value) pairs.
    broadcast_count = read_int(infile)
    for _ in range(broadcast_count):
        broadcast_id = read_long(infile)
        pickled = read_with_length(infile)
        _broadcastRegistry[broadcast_id] = Broadcast(broadcast_id, load_pickle(pickled))

    # *.py files are copied into files_dir; *.zip / *.egg includes are listed by name.
    sys.path.append(files_dir)
    include_count = read_int(infile)
    for _ in range(include_count):
        include_name = load_pickle(read_with_length(infile))
        sys.path.append(os.path.join(files_dir, include_name))

    func = load_obj(infile)
    bypass_serializer = load_obj(infile)
    encode = (lambda obj: obj) if bypass_serializer else dump_pickle

    init_time = time.time()
    records = read_from_pickle_file(infile)
    try:
        for result in func(split_index, records):
            write_with_length(encode(result), outfile)
    except Exception:
        write_int(-2, outfile)
        write_with_length(traceback.format_exc(), outfile)
        sys.exit(-1)

    finish_time = time.time()
    report_times(outfile, boot_time, init_time, finish_time)
    # -1 opens the accumulators section; a trailing -1 follows it.
    write_int(-1, outfile)
    for aid, accum in _accumulatorRegistry.items():
        write_with_length(dump_pickle((aid, accum._value)), outfile)
    write_int(-1, outfile)
示例#18
0
def main(infile, outfile):
    """
    Execute one task read from ``infile``, writing its output to ``outfile``.

    Reads the split index, the Spark files directory, the broadcast variables,
    the include names, the function and the bypass-serializer flag, then
    streams the results. After the timing report it writes -1, the pickled
    (id, value) accumulator pairs, and another -1. A failing function produces
    -2 plus the traceback and exits the process with status -1.
    """
    boot_time = time.time()
    split_index = read_int(infile)
    if split_index == -1:  # for unit tests
        return

    spark_files_dir = load_pickle(read_with_length(infile))
    SparkFiles._root_directory = spark_files_dir
    SparkFiles._is_running_on_worker = True

    pending = read_int(infile)
    while pending > 0:
        pending -= 1
        bid = read_long(infile)
        _broadcastRegistry[bid] = Broadcast(bid, load_pickle(read_with_length(infile)))

    # Added *.py files are copied straight into spark_files_dir.
    sys.path.append(spark_files_dir)
    pending = read_int(infile)
    while pending > 0:
        pending -= 1
        name = load_pickle(read_with_length(infile))
        sys.path.append(os.path.join(spark_files_dir, name))

    func = load_obj(infile)
    bypass = load_obj(infile)
    if not bypass:
        serialize = dump_pickle
    else:
        serialize = lambda obj: obj

    init_time = time.time()
    iterator = read_from_pickle_file(infile)
    try:
        for item in func(split_index, iterator):
            write_with_length(serialize(item), outfile)
    except Exception:
        write_int(-2, outfile)
        write_with_length(traceback.format_exc(), outfile)
        sys.exit(-1)

    finish_time = time.time()
    report_times(outfile, boot_time, init_time, finish_time)
    # Mark the beginning of the accumulators section of the output
    write_int(-1, outfile)
    for aid, accum in _accumulatorRegistry.items():
        write_with_length(dump_pickle((aid, accum._value)), outfile)
    write_int(-1, outfile)
示例#19
0
def main(infile, outfile):
    """
    Run one task read from ``infile`` and write its output to ``outfile``.

    Input, in order: split index, pickled Spark files directory, broadcast
    variables, the pickled function, the bypass-serializer flag, then the
    pickled input records.

    Output: each result length-prefixed, timing information, then -1, one
    pickled (id, value) pair per accumulator, and a trailing -1. If the
    function raises, -2 and the traceback are written and the process exits
    with status -1.
    """
    boot_time = time.time()
    split_index = read_int(infile)
    if split_index == -1:  # for unit tests
        return

    files_dir = load_pickle(read_with_length(infile))
    SparkFiles._root_directory = files_dir
    SparkFiles._is_running_on_worker = True
    sys.path.append(files_dir)

    # Broadcast variables: (id, pickled value) pairs.
    for _ in range(read_int(infile)):
        broadcast_id = read_long(infile)
        pickled = read_with_length(infile)
        _broadcastRegistry[broadcast_id] = Broadcast(broadcast_id, load_pickle(pickled))

    func = load_obj(infile)
    bypass_serializer = load_obj(infile)
    encode = (lambda obj: obj) if bypass_serializer else dump_pickle

    init_time = time.time()
    records = read_from_pickle_file(infile)
    try:
        for result in func(split_index, records):
            write_with_length(encode(result), outfile)
    except Exception:
        write_int(-2, outfile)
        write_with_length(traceback.format_exc(), outfile)
        sys.exit(-1)

    finish_time = time.time()
    report_times(outfile, boot_time, init_time, finish_time)
    # -1 opens the accumulators section; a trailing -1 follows it.
    write_int(-1, outfile)
    for aid, accum in _accumulatorRegistry.items():
        write_with_length(dump_pickle((aid, accum._value)), outfile)
    write_int(-1, outfile)
示例#20
0
def main(infile, outfile):
    """
    Execute one task read from ``infile``, writing its output to ``outfile``.

    Reads the split index, the Spark files directory, the broadcast variables,
    the function and the bypass-serializer flag, then streams the results.
    After the timing report it writes -1, the pickled (id, value) accumulator
    pairs, and another -1. A failing function produces -2 plus the traceback
    and exits the process with status -1.
    """
    boot_time = time.time()
    split_index = read_int(infile)
    if split_index == -1:  # for unit tests
        return

    spark_files_dir = load_pickle(read_with_length(infile))
    SparkFiles._root_directory = spark_files_dir
    SparkFiles._is_running_on_worker = True
    sys.path.append(spark_files_dir)

    pending = read_int(infile)
    while pending > 0:
        pending -= 1
        bid = read_long(infile)
        _broadcastRegistry[bid] = Broadcast(bid, load_pickle(read_with_length(infile)))

    func = load_obj(infile)
    bypass = load_obj(infile)
    if not bypass:
        serialize = dump_pickle
    else:
        serialize = lambda obj: obj

    init_time = time.time()
    iterator = read_from_pickle_file(infile)
    try:
        for item in func(split_index, iterator):
            write_with_length(serialize(item), outfile)
    except Exception:
        write_int(-2, outfile)
        write_with_length(traceback.format_exc(), outfile)
        sys.exit(-1)

    finish_time = time.time()
    report_times(outfile, boot_time, init_time, finish_time)
    # Mark the beginning of the accumulators section of the output
    write_int(-1, outfile)
    for aid, accum in _accumulatorRegistry.items():
        write_with_length(dump_pickle((aid, accum._value)), outfile)
    write_int(-1, outfile)
示例#21
0
def main(infile, outfile):
    """
    Run one PySpark task read from ``infile`` and write its output to ``outfile``.

    Reads the driver's Python version (which must match this worker's major.minor
    version), the Spark files directory, include names, broadcast variables and
    then either a chain of row-based UDF commands or a single command.

    Output: the serialized results, timing information, shuffle spill counters,
    an END_OF_DATA_SECTION marker and the accumulator updates. If the JVM then
    sends END_OF_STREAM it is echoed back; otherwise END_OF_DATA_SECTION is
    written and the process exits with -1 so the worker is not reused.
    Any failure before that point sends the traceback to the JVM (or prints it
    to stderr) and exits with -1.
    """
    try:
        boot_time = time.time()
        split_index = read_int(infile)
        if split_index == -1:  # for unit tests
            # sys.exit, not the site-provided exit(), which is meant for interactive use.
            sys.exit(-1)

        version = utf8_deserializer.loads(infile)
        if version != "%d.%d" % sys.version_info[:2]:
            raise Exception(("Python in worker has different version %s than that in " +
                             "driver %s, PySpark cannot run with different minor versions") %
                            ("%d.%d" % sys.version_info[:2], version))

        # initialize global state
        shuffle.MemoryBytesSpilled = 0
        shuffle.DiskBytesSpilled = 0
        _accumulatorRegistry.clear()

        # fetch name of workdir
        spark_files_dir = utf8_deserializer.loads(infile)
        SparkFiles._root_directory = spark_files_dir
        SparkFiles._is_running_on_worker = True

        # fetch names of includes (*.zip and *.egg files) and construct PYTHONPATH
        add_path(spark_files_dir)  # *.py files that were added will be copied here
        num_python_includes = read_int(infile)
        for _ in range(num_python_includes):
            filename = utf8_deserializer.loads(infile)
            add_path(os.path.join(spark_files_dir, filename))
        # Compare the numeric major version rather than the sys.version string,
        # whose lexicographic ordering is fragile.
        if sys.version_info[0] >= 3:
            import importlib
            importlib.invalidate_caches()

        # fetch names and values of broadcast variables
        num_broadcast_variables = read_int(infile)
        for _ in range(num_broadcast_variables):
            bid = read_long(infile)
            if bid >= 0:
                path = utf8_deserializer.loads(infile)
                _broadcastRegistry[bid] = Broadcast(path=path)
            else:
                # A negative id encodes "drop broadcast -(bid + 1)".
                bid = - bid - 1
                _broadcastRegistry.pop(bid)

        _accumulatorRegistry.clear()
        row_based = read_int(infile)
        num_commands = read_int(infile)
        if row_based:
            profiler = None  # profiling is not supported for UDF
            row_func = None
            # Compose the UDFs; the last command's returnType converts the final output.
            for i in range(num_commands):
                f, returnType, deserializer = read_command(pickleSer, infile)
                if row_func is None:
                    row_func = f
                else:
                    row_func = chain(row_func, f)
            serializer = deserializer
            func = lambda _, it: map(lambda x: returnType.toInternal(row_func(*x)), it)
        else:
            assert num_commands == 1
            func, profiler, deserializer, serializer = read_command(pickleSer, infile)

        init_time = time.time()

        def process():
            iterator = deserializer.load_stream(infile)
            serializer.dump_stream(func(split_index, iterator), outfile)

        if profiler:
            profiler.profile(process)
        else:
            process()
    except Exception:
        try:
            write_int(SpecialLengths.PYTHON_EXCEPTION_THROWN, outfile)
            write_with_length(traceback.format_exc().encode("utf-8"), outfile)
        except IOError:
            # JVM close the socket
            pass
        except Exception:
            # Write the error to stderr if it happened while serializing
            print("PySpark worker failed with exception:", file=sys.stderr)
            print(traceback.format_exc(), file=sys.stderr)
        sys.exit(-1)
    finish_time = time.time()
    report_times(outfile, boot_time, init_time, finish_time)
    write_long(shuffle.MemoryBytesSpilled, outfile)
    write_long(shuffle.DiskBytesSpilled, outfile)

    # Mark the beginning of the accumulators section of the output
    write_int(SpecialLengths.END_OF_DATA_SECTION, outfile)
    write_int(len(_accumulatorRegistry), outfile)
    for (aid, accum) in _accumulatorRegistry.items():
        pickleSer._write_with_length((aid, accum._value), outfile)

    # check end of stream
    if read_int(infile) == SpecialLengths.END_OF_STREAM:
        write_int(SpecialLengths.END_OF_STREAM, outfile)
    else:
        # write a different value to tell JVM to not reuse this worker
        write_int(SpecialLengths.END_OF_DATA_SECTION, outfile)
        sys.exit(-1)
示例#22
0
def main(infile, outfile):
    """Run one PySpark task described on ``infile`` and write its results to ``outfile``.

    The task description is consumed strictly in the order it is read below:
    split index, driver Python version, barrier-task info, task-context fields
    and local properties, the work directory and Python includes, broadcast
    variables, and finally the command (or UDFs) to evaluate. The command is
    then run over the records deserialized from ``infile`` and its output is
    serialized to ``outfile``, followed by timing data, spill metrics,
    accumulator updates and an end-of-stream marker.

    :param infile: stream the task description and input records are read from.
    :param outfile: stream results and metadata are written to.

    If setup or execution raises an ``Exception``, the formatted traceback is
    sent as ``SpecialLengths.PYTHON_EXCEPTION_THROWN`` and the process exits
    with status -1. A split index of -1 (unit tests) also exits with -1.
    """
    try:
        boot_time = time.time()
        split_index = read_int(infile)
        if split_index == -1:  # for unit tests
            sys.exit(-1)

        # Refuse to run when the driver's "major.minor" Python version differs from ours.
        version = utf8_deserializer.loads(infile)
        if version != "%d.%d" % sys.version_info[:2]:
            raise RuntimeError(("Python in worker has different version %s than that in " +
                                "driver %s, PySpark cannot run with different minor versions. " +
                                "Please check environment variables PYSPARK_PYTHON and " +
                                "PYSPARK_DRIVER_PYTHON are correctly set.") %
                               ("%d.%d" % sys.version_info[:2], version))

        # read inputs only for a barrier task
        isBarrier = read_bool(infile)
        boundPort = read_int(infile)
        secret = UTF8Deserializer().loads(infile)

        # set up memory limits
        memory_limit_mb = int(os.environ.get('PYSPARK_EXECUTOR_MEMORY_MB', "-1"))
        if memory_limit_mb > 0 and has_resource_module:
            total_memory = resource.RLIMIT_AS
            try:
                (soft_limit, hard_limit) = resource.getrlimit(total_memory)
                msg = "Current mem limits: {0} of max {1}\n".format(soft_limit, hard_limit)
                print(msg, file=sys.stderr)

                # convert to bytes
                new_limit = memory_limit_mb * 1024 * 1024

                # Only tighten: apply when currently unlimited or the new limit is lower.
                if soft_limit == resource.RLIM_INFINITY or new_limit < soft_limit:
                    msg = "Setting mem limits to {0} of max {1}\n".format(new_limit, new_limit)
                    print(msg, file=sys.stderr)
                    resource.setrlimit(total_memory, (new_limit, new_limit))

            except (resource.error, OSError, ValueError) as e:
                # not all systems support resource limits, so warn instead of failing
                print("WARN: Failed to set memory limit: {0}\n".format(e), file=sys.stderr)

        # initialize global state
        taskContext = None
        if isBarrier:
            taskContext = BarrierTaskContext._getOrCreate()
            BarrierTaskContext._initialize(boundPort, secret)
        else:
            taskContext = TaskContext._getOrCreate()
        # read inputs for TaskContext info
        taskContext._stageId = read_int(infile)
        taskContext._partitionId = read_int(infile)
        taskContext._attemptNumber = read_int(infile)
        taskContext._taskAttemptId = read_long(infile)
        taskContext._localProperties = dict()
        for _ in range(read_int(infile)):
            k = utf8_deserializer.loads(infile)
            v = utf8_deserializer.loads(infile)
            taskContext._localProperties[k] = v

        shuffle.MemoryBytesSpilled = 0
        shuffle.DiskBytesSpilled = 0
        _accumulatorRegistry.clear()

        # fetch name of workdir
        spark_files_dir = utf8_deserializer.loads(infile)
        SparkFiles._root_directory = spark_files_dir
        SparkFiles._is_running_on_worker = True

        # fetch names of includes (*.zip and *.egg files) and construct PYTHONPATH
        add_path(spark_files_dir)  # *.py files that were added will be copied here
        num_python_includes = read_int(infile)
        for _ in range(num_python_includes):
            filename = utf8_deserializer.loads(infile)
            add_path(os.path.join(spark_files_dir, filename))
        if sys.version_info[0] >= 3:
            # Make newly added path entries visible to the import system.
            import importlib
            importlib.invalidate_caches()

        # fetch names and values of broadcast variables
        needs_broadcast_decryption_server = read_bool(infile)
        num_broadcast_variables = read_int(infile)
        if needs_broadcast_decryption_server:
            # read the decrypted data from a server in the jvm
            port = read_int(infile)
            auth_secret = utf8_deserializer.loads(infile)
            (broadcast_sock_file, _) = local_connect_and_auth(port, auth_secret)

        for _ in range(num_broadcast_variables):
            bid = read_long(infile)
            if bid >= 0:
                if needs_broadcast_decryption_server:
                    read_bid = read_long(broadcast_sock_file)
                    if read_bid != bid:
                        # Explicit check instead of ``assert`` so it survives ``python -O``.
                        raise RuntimeError(
                            "Broadcast id mismatch: expected %d, got %d" % (bid, read_bid))
                    _broadcastRegistry[bid] = \
                        Broadcast(sock_file=broadcast_sock_file)
                else:
                    path = utf8_deserializer.loads(infile)
                    _broadcastRegistry[bid] = Broadcast(path=path)

            else:
                # A negative id -(bid + 1) means "remove broadcast bid".
                bid = - bid - 1
                _broadcastRegistry.pop(bid)

        if needs_broadcast_decryption_server:
            broadcast_sock_file.write(b'1')
            broadcast_sock_file.close()

        _accumulatorRegistry.clear()
        eval_type = read_int(infile)
        if eval_type == PythonEvalType.NON_UDF:
            func, profiler, deserializer, serializer = read_command(pickleSer, infile)
        else:
            func, profiler, deserializer, serializer = read_udfs(pickleSer, infile, eval_type)

        init_time = time.time()

        def process():
            iterator = deserializer.load_stream(infile)
            serializer.dump_stream(func(split_index, iterator), outfile)

        if profiler:
            profiler.profile(process)
        else:
            process()
    except Exception:
        try:
            write_int(SpecialLengths.PYTHON_EXCEPTION_THROWN, outfile)
            write_with_length(traceback.format_exc().encode("utf-8"), outfile)
        except IOError:
            # JVM close the socket
            pass
        except Exception:
            # Write the error to stderr if it happened while serializing
            print("PySpark worker failed with exception:", file=sys.stderr)
            print(traceback.format_exc(), file=sys.stderr)
        sys.exit(-1)
    finish_time = time.time()
    report_times(outfile, boot_time, init_time, finish_time)
    write_long(shuffle.MemoryBytesSpilled, outfile)
    write_long(shuffle.DiskBytesSpilled, outfile)

    # Mark the beginning of the accumulators section of the output
    write_int(SpecialLengths.END_OF_DATA_SECTION, outfile)
    write_int(len(_accumulatorRegistry), outfile)
    for (aid, accum) in _accumulatorRegistry.items():
        pickleSer._write_with_length((aid, accum._value), outfile)

    # check end of stream
    if read_int(infile) == SpecialLengths.END_OF_STREAM:
        write_int(SpecialLengths.END_OF_STREAM, outfile)
    else:
        # write a different value to tell JVM to not reuse this worker
        write_int(SpecialLengths.END_OF_DATA_SECTION, outfile)
        sys.exit(-1)
示例#23
0
文件: worker.py 项目: mgyucht/spark-1
def main(infile, outfile):
    """Run one PySpark task described on ``infile`` and write its results to ``outfile``.

    Reads, in order: split index, driver Python version, task-context fields,
    work directory, Python includes, broadcast variables and the command (or
    UDFs) to evaluate. Runs the command over the records deserialized from
    ``infile``, serializes its output to ``outfile``, then writes timing data,
    spill metrics, accumulator updates and an end-of-stream marker.

    :param infile: stream the task description and input records are read from.
    :param outfile: stream results and metadata are written to.

    Any ``Exception`` during setup or execution is reported as
    ``SpecialLengths.PYTHON_EXCEPTION_THROWN`` plus the formatted traceback,
    and the process exits with status -1.
    """
    try:
        boot_time = time.time()
        split_index = read_int(infile)
        if split_index == -1:  # for unit tests
            sys.exit(-1)

        # Refuse to run when the driver's "major.minor" Python version differs from ours.
        version = utf8_deserializer.loads(infile)
        if version != "%d.%d" % sys.version_info[:2]:
            raise RuntimeError(
                ("Python in worker has different version %s than that in " +
                 "driver %s, PySpark cannot run with different minor versions. "
                 + "Please check environment variables PYSPARK_PYTHON and " +
                 "PYSPARK_DRIVER_PYTHON are correctly set.") %
                ("%d.%d" % sys.version_info[:2], version))

        # initialize global state
        taskContext = TaskContext._getOrCreate()
        taskContext._stageId = read_int(infile)
        taskContext._partitionId = read_int(infile)
        taskContext._attemptNumber = read_int(infile)
        taskContext._taskAttemptId = read_long(infile)
        shuffle.MemoryBytesSpilled = 0
        shuffle.DiskBytesSpilled = 0
        _accumulatorRegistry.clear()

        # fetch name of workdir
        spark_files_dir = utf8_deserializer.loads(infile)
        SparkFiles._root_directory = spark_files_dir
        SparkFiles._is_running_on_worker = True

        # fetch names of includes (*.zip and *.egg files) and construct PYTHONPATH
        add_path(
            spark_files_dir)  # *.py files that were added will be copied here
        num_python_includes = read_int(infile)
        for _ in range(num_python_includes):
            filename = utf8_deserializer.loads(infile)
            add_path(os.path.join(spark_files_dir, filename))
        if sys.version_info[0] >= 3:
            # Make newly added path entries visible to the import system.
            import importlib
            importlib.invalidate_caches()

        # fetch names and values of broadcast variables
        num_broadcast_variables = read_int(infile)
        for _ in range(num_broadcast_variables):
            bid = read_long(infile)
            if bid >= 0:
                path = utf8_deserializer.loads(infile)
                _broadcastRegistry[bid] = Broadcast(path=path)
            else:
                # A negative id -(bid + 1) means "remove broadcast bid".
                bid = -bid - 1
                _broadcastRegistry.pop(bid)

        _accumulatorRegistry.clear()
        eval_type = read_int(infile)
        if eval_type == PythonEvalType.NON_UDF:
            func, profiler, deserializer, serializer = read_command(
                pickleSer, infile)
        else:
            func, profiler, deserializer, serializer = read_udfs(
                pickleSer, infile, eval_type)

        init_time = time.time()

        def process():
            iterator = deserializer.load_stream(infile)
            serializer.dump_stream(func(split_index, iterator), outfile)

        if profiler:
            profiler.profile(process)
        else:
            process()
    except Exception:
        try:
            write_int(SpecialLengths.PYTHON_EXCEPTION_THROWN, outfile)
            write_with_length(traceback.format_exc().encode("utf-8"), outfile)
        except IOError:
            # JVM close the socket
            pass
        except Exception:
            # Write the error to stderr if it happened while serializing
            print("PySpark worker failed with exception:", file=sys.stderr)
            print(traceback.format_exc(), file=sys.stderr)
        sys.exit(-1)
    finish_time = time.time()
    report_times(outfile, boot_time, init_time, finish_time)
    write_long(shuffle.MemoryBytesSpilled, outfile)
    write_long(shuffle.DiskBytesSpilled, outfile)

    # Mark the beginning of the accumulators section of the output
    write_int(SpecialLengths.END_OF_DATA_SECTION, outfile)
    write_int(len(_accumulatorRegistry), outfile)
    for (aid, accum) in _accumulatorRegistry.items():
        pickleSer._write_with_length((aid, accum._value), outfile)

    # check end of stream
    if read_int(infile) == SpecialLengths.END_OF_STREAM:
        write_int(SpecialLengths.END_OF_STREAM, outfile)
    else:
        # write a different value to tell JVM to not reuse this worker
        write_int(SpecialLengths.END_OF_DATA_SECTION, outfile)
        sys.exit(-1)
示例#24
0
文件: worker.py 项目: 0asa/spark
def main(infile, outfile):
    """Run one PySpark task described on ``infile`` and write its results to ``outfile``.

    Reads, in order: split index, work directory, Python includes, broadcast
    variables (each as a size-limited stream deserialized with
    ``LargeObjectSerializer``) and the pickled command tuple
    ``(func, stats, deserializer, serializer)``. Runs the command over the
    records deserialized from ``infile`` (under cProfile when ``stats`` is set),
    then writes timing data, spill metrics, accumulator updates and an
    end-of-stream marker to ``outfile``.

    :param infile: stream the task description and input records are read from.
    :param outfile: stream results and metadata are written to.

    Any ``Exception`` during setup or execution is reported as
    ``SpecialLengths.PYTHON_EXCEPTION_THROWN`` plus the UTF-8 encoded traceback,
    and the process exits with status -1.
    """
    try:
        boot_time = time.time()
        split_index = read_int(infile)
        if split_index == -1:  # for unit tests
            sys.exit(-1)

        # initialize global state
        shuffle.MemoryBytesSpilled = 0
        shuffle.DiskBytesSpilled = 0
        _accumulatorRegistry.clear()

        # fetch name of workdir
        spark_files_dir = utf8_deserializer.loads(infile)
        SparkFiles._root_directory = spark_files_dir
        SparkFiles._is_running_on_worker = True

        # fetch names of includes (*.zip and *.egg files) and construct PYTHONPATH
        add_path(spark_files_dir)  # *.py files that were added will be copied here
        num_python_includes = read_int(infile)
        for _ in range(num_python_includes):
            filename = utf8_deserializer.loads(infile)
            add_path(os.path.join(spark_files_dir, filename))

        # fetch names and values of broadcast variables
        num_broadcast_variables = read_int(infile)
        bser = LargeObjectSerializer()
        for _ in range(num_broadcast_variables):
            bid = read_long(infile)
            if bid >= 0:
                size = read_long(infile)
                s = SizeLimitedStream(infile, size)
                value = list((bser.load_stream(s)))[0]  # read out all the bytes
                _broadcastRegistry[bid] = Broadcast(bid, value)
            else:
                # A negative id -(bid + 1) means "remove broadcast bid".
                bid = - bid - 1
                _broadcastRegistry.pop(bid)

        _accumulatorRegistry.clear()
        command = pickleSer._read_with_length(infile)
        if isinstance(command, Broadcast):
            # Large commands arrive wrapped in a Broadcast; unpickle its payload.
            command = pickleSer.loads(command.value)
        (func, stats, deserializer, serializer) = command
        init_time = time.time()

        def process():
            iterator = deserializer.load_stream(infile)
            serializer.dump_stream(func(split_index, iterator), outfile)

        if stats:
            p = cProfile.Profile()
            p.runcall(process)
            st = pstats.Stats(p)
            st.stream = None  # make it picklable
            stats.add(st.strip_dirs())
        else:
            process()
    except Exception:
        try:
            write_int(SpecialLengths.PYTHON_EXCEPTION_THROWN, outfile)
            # Encode so a text traceback can be written to the binary stream.
            write_with_length(traceback.format_exc().encode("utf-8"), outfile)
        except IOError:
            # JVM close the socket
            pass
        except Exception:
            # Write the error to stderr if it happened while serializing
            print("PySpark worker failed with exception:", file=sys.stderr)
            print(traceback.format_exc(), file=sys.stderr)
        sys.exit(-1)
    finish_time = time.time()
    report_times(outfile, boot_time, init_time, finish_time)
    write_long(shuffle.MemoryBytesSpilled, outfile)
    write_long(shuffle.DiskBytesSpilled, outfile)

    # Mark the beginning of the accumulators section of the output
    write_int(SpecialLengths.END_OF_DATA_SECTION, outfile)
    write_int(len(_accumulatorRegistry), outfile)
    for (aid, accum) in _accumulatorRegistry.items():
        pickleSer._write_with_length((aid, accum._value), outfile)

    # check end of stream
    if read_int(infile) == SpecialLengths.END_OF_STREAM:
        write_int(SpecialLengths.END_OF_STREAM, outfile)
    else:
        # write a different value to tell JVM to not reuse this worker
        write_int(SpecialLengths.END_OF_DATA_SECTION, outfile)
        sys.exit(-1)
示例#25
0
def main(infile, outfile):
    """Run one PySpark task described on ``infile`` and write its results to ``outfile``.

    When ``PYTHON_FAULTHANDLER_DIR`` is set, a per-process faulthandler log file
    is created there for the duration of the task and removed afterwards.

    The task description is read strictly in order: split index, driver Python
    version, barrier-task info, task-context fields (ids, cpus, resources,
    local properties), work directory and Python includes, broadcast variables,
    and the command (or UDFs) to evaluate. The command runs over the records
    deserialized from ``infile``; its output is serialized to ``outfile``,
    followed by timing data, spill metrics, accumulator updates and an
    end-of-stream marker.

    :param infile: stream the task description and input records are read from.
    :param outfile: stream results and metadata are written to.

    Any ``BaseException`` during setup or execution is reported as
    ``SpecialLengths.PYTHON_EXCEPTION_THROWN`` plus the formatted (optionally
    simplified) traceback, and the process exits with status -1.
    """
    faulthandler_log_path = os.environ.get("PYTHON_FAULTHANDLER_DIR", None)
    # Tracked separately so cleanup never touches a file that failed to open.
    faulthandler_log_file = None
    try:
        if faulthandler_log_path:
            faulthandler_log_path = os.path.join(faulthandler_log_path, str(os.getpid()))
            faulthandler_log_file = open(faulthandler_log_path, "w")
            faulthandler.enable(file=faulthandler_log_file)

        boot_time = time.time()
        split_index = read_int(infile)
        if split_index == -1:  # for unit tests
            sys.exit(-1)

        # Refuse to run when the driver's "major.minor" Python version differs from ours.
        version = utf8_deserializer.loads(infile)
        if version != "%d.%d" % sys.version_info[:2]:
            raise RuntimeError(
                (
                    "Python in worker has different version %s than that in "
                    + "driver %s, PySpark cannot run with different minor versions. "
                    + "Please check environment variables PYSPARK_PYTHON and "
                    + "PYSPARK_DRIVER_PYTHON are correctly set."
                )
                % ("%d.%d" % sys.version_info[:2], version)
            )

        # read inputs only for a barrier task
        isBarrier = read_bool(infile)
        boundPort = read_int(infile)
        secret = UTF8Deserializer().loads(infile)

        # set up memory limits
        memory_limit_mb = int(os.environ.get("PYSPARK_EXECUTOR_MEMORY_MB", "-1"))
        if memory_limit_mb > 0 and has_resource_module:
            total_memory = resource.RLIMIT_AS
            try:
                (soft_limit, hard_limit) = resource.getrlimit(total_memory)
                msg = "Current mem limits: {0} of max {1}\n".format(soft_limit, hard_limit)
                print(msg, file=sys.stderr)

                # convert to bytes
                new_limit = memory_limit_mb * 1024 * 1024

                # Only tighten: apply when currently unlimited or the new limit is lower.
                if soft_limit == resource.RLIM_INFINITY or new_limit < soft_limit:
                    msg = "Setting mem limits to {0} of max {1}\n".format(new_limit, new_limit)
                    print(msg, file=sys.stderr)
                    resource.setrlimit(total_memory, (new_limit, new_limit))

            except (resource.error, OSError, ValueError) as e:
                # not all systems support resource limits, so warn instead of failing
                lineno = (
                    getframeinfo(currentframe()).lineno + 1 if currentframe() is not None else 0
                )
                print(
                    warnings.formatwarning(
                        "Failed to set memory limit: {0}".format(e),
                        ResourceWarning,
                        __file__,
                        lineno,
                    ),
                    file=sys.stderr,
                )

        # initialize global state
        taskContext = None
        if isBarrier:
            taskContext = BarrierTaskContext._getOrCreate()
            BarrierTaskContext._initialize(boundPort, secret)
            # Set the task context instance here, so we can get it by TaskContext.get for
            # both TaskContext and BarrierTaskContext
            TaskContext._setTaskContext(taskContext)
        else:
            taskContext = TaskContext._getOrCreate()
        # read inputs for TaskContext info
        taskContext._stageId = read_int(infile)
        taskContext._partitionId = read_int(infile)
        taskContext._attemptNumber = read_int(infile)
        taskContext._taskAttemptId = read_long(infile)
        taskContext._cpus = read_int(infile)
        # Create the resources map once; re-creating it per entry would keep only the last one.
        taskContext._resources = {}
        for _ in range(read_int(infile)):
            key = utf8_deserializer.loads(infile)
            name = utf8_deserializer.loads(infile)
            addresses = []
            for _ in range(read_int(infile)):
                addresses.append(utf8_deserializer.loads(infile))
            taskContext._resources[key] = ResourceInformation(name, addresses)

        taskContext._localProperties = dict()
        for _ in range(read_int(infile)):
            k = utf8_deserializer.loads(infile)
            v = utf8_deserializer.loads(infile)
            taskContext._localProperties[k] = v

        shuffle.MemoryBytesSpilled = 0
        shuffle.DiskBytesSpilled = 0
        _accumulatorRegistry.clear()

        # fetch name of workdir
        spark_files_dir = utf8_deserializer.loads(infile)
        SparkFiles._root_directory = spark_files_dir
        SparkFiles._is_running_on_worker = True

        # fetch names of includes (*.zip and *.egg files) and construct PYTHONPATH
        add_path(spark_files_dir)  # *.py files that were added will be copied here
        num_python_includes = read_int(infile)
        for _ in range(num_python_includes):
            filename = utf8_deserializer.loads(infile)
            add_path(os.path.join(spark_files_dir, filename))

        # Make newly added path entries visible to the import system.
        importlib.invalidate_caches()

        # fetch names and values of broadcast variables
        needs_broadcast_decryption_server = read_bool(infile)
        num_broadcast_variables = read_int(infile)
        if needs_broadcast_decryption_server:
            # read the decrypted data from a server in the jvm
            port = read_int(infile)
            auth_secret = utf8_deserializer.loads(infile)
            (broadcast_sock_file, _) = local_connect_and_auth(port, auth_secret)

        for _ in range(num_broadcast_variables):
            bid = read_long(infile)
            if bid >= 0:
                if needs_broadcast_decryption_server:
                    read_bid = read_long(broadcast_sock_file)
                    if read_bid != bid:
                        # Explicit check instead of ``assert`` so it survives ``python -O``.
                        raise RuntimeError(
                            "Broadcast id mismatch: expected %d, got %d" % (bid, read_bid)
                        )
                    _broadcastRegistry[bid] = Broadcast(sock_file=broadcast_sock_file)
                else:
                    path = utf8_deserializer.loads(infile)
                    _broadcastRegistry[bid] = Broadcast(path=path)

            else:
                # A negative id -(bid + 1) means "remove broadcast bid".
                bid = -bid - 1
                _broadcastRegistry.pop(bid)

        if needs_broadcast_decryption_server:
            broadcast_sock_file.write(b"1")
            broadcast_sock_file.close()

        _accumulatorRegistry.clear()
        eval_type = read_int(infile)
        if eval_type == PythonEvalType.NON_UDF:
            func, profiler, deserializer, serializer = read_command(pickleSer, infile)
        else:
            func, profiler, deserializer, serializer = read_udfs(pickleSer, infile, eval_type)

        init_time = time.time()

        def process():
            iterator = deserializer.load_stream(infile)
            out_iter = func(split_index, iterator)
            try:
                serializer.dump_stream(out_iter, outfile)
            finally:
                # Close generator-style outputs so their cleanup code runs even on failure.
                if hasattr(out_iter, "close"):
                    out_iter.close()

        if profiler:
            profiler.profile(process)
        else:
            process()

        # Reset task context to None. This is a guard code to avoid residual context when worker
        # reuse.
        TaskContext._setTaskContext(None)
        BarrierTaskContext._setTaskContext(None)
    except BaseException as e:
        try:
            exc_info = None
            if os.environ.get("SPARK_SIMPLIFIED_TRACEBACK", False):
                tb = try_simplify_traceback(sys.exc_info()[-1])
                if tb is not None:
                    e.__cause__ = None
                    exc_info = "".join(traceback.format_exception(type(e), e, tb))
            if exc_info is None:
                exc_info = traceback.format_exc()

            write_int(SpecialLengths.PYTHON_EXCEPTION_THROWN, outfile)
            write_with_length(exc_info.encode("utf-8"), outfile)
        except IOError:
            # JVM close the socket
            pass
        except BaseException:
            # Write the error to stderr if it happened while serializing
            print("PySpark worker failed with exception:", file=sys.stderr)
            print(traceback.format_exc(), file=sys.stderr)
        sys.exit(-1)
    finally:
        if faulthandler_log_file is not None:
            faulthandler.disable()
            faulthandler_log_file.close()
            os.remove(faulthandler_log_path)
    finish_time = time.time()
    report_times(outfile, boot_time, init_time, finish_time)
    write_long(shuffle.MemoryBytesSpilled, outfile)
    write_long(shuffle.DiskBytesSpilled, outfile)

    # Mark the beginning of the accumulators section of the output
    write_int(SpecialLengths.END_OF_DATA_SECTION, outfile)
    write_int(len(_accumulatorRegistry), outfile)
    for (aid, accum) in _accumulatorRegistry.items():
        pickleSer._write_with_length((aid, accum._value), outfile)

    # check end of stream
    if read_int(infile) == SpecialLengths.END_OF_STREAM:
        write_int(SpecialLengths.END_OF_STREAM, outfile)
    else:
        # write a different value to tell JVM to not reuse this worker
        write_int(SpecialLengths.END_OF_DATA_SECTION, outfile)
        sys.exit(-1)
示例#26
0
文件: worker.py 项目: 1ambda/spark
def main(infile, outfile):
    """Run one PySpark task described on ``infile`` and write its results to ``outfile``.

    Reads, in order: split index, work directory, Python includes, broadcast
    variables and the pickled command ``((func, profiler, deserializer,
    serializer), version)``. The driver's Python version is checked only after
    the command has been read, because it travels inside the command. Runs the
    command over the records deserialized from ``infile``, then writes timing
    data, spill metrics, accumulator updates and an end-of-stream marker to
    ``outfile``.

    :param infile: stream the task description and input records are read from.
    :param outfile: stream results and metadata are written to.

    Any ``Exception`` during setup or execution is reported as
    ``SpecialLengths.PYTHON_EXCEPTION_THROWN`` plus the formatted traceback,
    and the process exits with status -1.
    """
    try:
        boot_time = time.time()
        split_index = read_int(infile)
        if split_index == -1:  # for unit tests
            sys.exit(-1)

        # initialize global state
        shuffle.MemoryBytesSpilled = 0
        shuffle.DiskBytesSpilled = 0
        _accumulatorRegistry.clear()

        # fetch name of workdir
        spark_files_dir = utf8_deserializer.loads(infile)
        SparkFiles._root_directory = spark_files_dir
        SparkFiles._is_running_on_worker = True

        # fetch names of includes (*.zip and *.egg files) and construct PYTHONPATH
        add_path(spark_files_dir)  # *.py files that were added will be copied here
        num_python_includes = read_int(infile)
        for _ in range(num_python_includes):
            filename = utf8_deserializer.loads(infile)
            add_path(os.path.join(spark_files_dir, filename))
        if sys.version_info[0] >= 3:
            # Make newly added path entries visible to the import system.
            import importlib
            importlib.invalidate_caches()

        # fetch names and values of broadcast variables
        num_broadcast_variables = read_int(infile)
        for _ in range(num_broadcast_variables):
            bid = read_long(infile)
            if bid >= 0:
                path = utf8_deserializer.loads(infile)
                _broadcastRegistry[bid] = Broadcast(path=path)
            else:
                # A negative id -(bid + 1) means "remove broadcast bid".
                bid = - bid - 1
                _broadcastRegistry.pop(bid)

        _accumulatorRegistry.clear()
        command = pickleSer._read_with_length(infile)
        if isinstance(command, Broadcast):
            # Large commands arrive wrapped in a Broadcast; unpickle its payload.
            command = pickleSer.loads(command.value)
        (func, profiler, deserializer, serializer), version = command
        if version != sys.version_info[:2]:
            raise RuntimeError(("Python in worker has different version %s than that in " +
                                "driver %s, PySpark cannot run with different minor versions") %
                               (sys.version_info[:2], version))
        init_time = time.time()

        def process():
            iterator = deserializer.load_stream(infile)
            serializer.dump_stream(func(split_index, iterator), outfile)

        if profiler:
            profiler.profile(process)
        else:
            process()
    except Exception:
        try:
            write_int(SpecialLengths.PYTHON_EXCEPTION_THROWN, outfile)
            write_with_length(traceback.format_exc().encode("utf-8"), outfile)
        except IOError:
            # JVM close the socket
            pass
        except Exception:
            # Write the error to stderr if it happened while serializing
            print("PySpark worker failed with exception:", file=sys.stderr)
            print(traceback.format_exc(), file=sys.stderr)
        sys.exit(-1)
    finish_time = time.time()
    report_times(outfile, boot_time, init_time, finish_time)
    write_long(shuffle.MemoryBytesSpilled, outfile)
    write_long(shuffle.DiskBytesSpilled, outfile)

    # Mark the beginning of the accumulators section of the output
    write_int(SpecialLengths.END_OF_DATA_SECTION, outfile)
    write_int(len(_accumulatorRegistry), outfile)
    for (aid, accum) in _accumulatorRegistry.items():
        pickleSer._write_with_length((aid, accum._value), outfile)

    # check end of stream
    if read_int(infile) == SpecialLengths.END_OF_STREAM:
        write_int(SpecialLengths.END_OF_STREAM, outfile)
    else:
        # write a different value to tell JVM to not reuse this worker
        write_int(SpecialLengths.END_OF_DATA_SECTION, outfile)
        sys.exit(-1)
示例#27
0
文件: worker.py 项目: lshoo/spark
def main(infile, outfile):
    """Run one PySpark task described on ``infile`` and write its results to ``outfile``.

    Reads, in order: split index, driver Python version, barrier-task info,
    task-context fields and local properties, work directory, Python includes,
    broadcast variables and the command (or UDFs) to evaluate. Runs the command
    over the records deserialized from ``infile``, serializes its output to
    ``outfile``, then writes timing data, spill metrics, accumulator updates
    and an end-of-stream marker.

    :param infile: stream the task description and input records are read from.
    :param outfile: stream results and metadata are written to.

    Any ``Exception`` during setup or execution is reported as
    ``SpecialLengths.PYTHON_EXCEPTION_THROWN`` plus the formatted traceback,
    and the process exits with status -1.
    """
    try:
        boot_time = time.time()
        split_index = read_int(infile)
        if split_index == -1:  # for unit tests
            sys.exit(-1)

        # Refuse to run when the driver's "major.minor" Python version differs from ours.
        version = utf8_deserializer.loads(infile)
        if version != "%d.%d" % sys.version_info[:2]:
            raise RuntimeError(("Python in worker has different version %s than that in " +
                                "driver %s, PySpark cannot run with different minor versions. " +
                                "Please check environment variables PYSPARK_PYTHON and " +
                                "PYSPARK_DRIVER_PYTHON are correctly set.") %
                               ("%d.%d" % sys.version_info[:2], version))

        # read inputs only for a barrier task
        isBarrier = read_bool(infile)
        boundPort = read_int(infile)
        secret = UTF8Deserializer().loads(infile)
        # initialize global state
        taskContext = None
        if isBarrier:
            taskContext = BarrierTaskContext._getOrCreate()
            BarrierTaskContext._initialize(boundPort, secret)
        else:
            taskContext = TaskContext._getOrCreate()
        # read inputs for TaskContext info
        taskContext._stageId = read_int(infile)
        taskContext._partitionId = read_int(infile)
        taskContext._attemptNumber = read_int(infile)
        taskContext._taskAttemptId = read_long(infile)
        taskContext._localProperties = dict()
        for _ in range(read_int(infile)):
            k = utf8_deserializer.loads(infile)
            v = utf8_deserializer.loads(infile)
            taskContext._localProperties[k] = v

        shuffle.MemoryBytesSpilled = 0
        shuffle.DiskBytesSpilled = 0
        _accumulatorRegistry.clear()

        # fetch name of workdir
        spark_files_dir = utf8_deserializer.loads(infile)
        SparkFiles._root_directory = spark_files_dir
        SparkFiles._is_running_on_worker = True

        # fetch names of includes (*.zip and *.egg files) and construct PYTHONPATH
        add_path(spark_files_dir)  # *.py files that were added will be copied here
        num_python_includes = read_int(infile)
        for _ in range(num_python_includes):
            filename = utf8_deserializer.loads(infile)
            add_path(os.path.join(spark_files_dir, filename))
        if sys.version_info[0] >= 3:
            # Make newly added path entries visible to the import system.
            import importlib
            importlib.invalidate_caches()

        # fetch names and values of broadcast variables
        num_broadcast_variables = read_int(infile)
        for _ in range(num_broadcast_variables):
            bid = read_long(infile)
            if bid >= 0:
                path = utf8_deserializer.loads(infile)
                _broadcastRegistry[bid] = Broadcast(path=path)
            else:
                # A negative id -(bid + 1) means "remove broadcast bid".
                bid = - bid - 1
                _broadcastRegistry.pop(bid)

        _accumulatorRegistry.clear()
        eval_type = read_int(infile)
        if eval_type == PythonEvalType.NON_UDF:
            func, profiler, deserializer, serializer = read_command(pickleSer, infile)
        else:
            func, profiler, deserializer, serializer = read_udfs(pickleSer, infile, eval_type)

        init_time = time.time()

        def process():
            iterator = deserializer.load_stream(infile)
            serializer.dump_stream(func(split_index, iterator), outfile)

        if profiler:
            profiler.profile(process)
        else:
            process()
    except Exception:
        try:
            write_int(SpecialLengths.PYTHON_EXCEPTION_THROWN, outfile)
            write_with_length(traceback.format_exc().encode("utf-8"), outfile)
        except IOError:
            # JVM close the socket
            pass
        except Exception:
            # Write the error to stderr if it happened while serializing
            print("PySpark worker failed with exception:", file=sys.stderr)
            print(traceback.format_exc(), file=sys.stderr)
        sys.exit(-1)
    finish_time = time.time()
    report_times(outfile, boot_time, init_time, finish_time)
    write_long(shuffle.MemoryBytesSpilled, outfile)
    write_long(shuffle.DiskBytesSpilled, outfile)

    # Mark the beginning of the accumulators section of the output
    write_int(SpecialLengths.END_OF_DATA_SECTION, outfile)
    write_int(len(_accumulatorRegistry), outfile)
    for (aid, accum) in _accumulatorRegistry.items():
        pickleSer._write_with_length((aid, accum._value), outfile)

    # check end of stream
    if read_int(infile) == SpecialLengths.END_OF_STREAM:
        write_int(SpecialLengths.END_OF_STREAM, outfile)
    else:
        # write a different value to tell JVM to not reuse this worker
        write_int(SpecialLengths.END_OF_DATA_SECTION, outfile)
        sys.exit(-1)
示例#28
0
def main(infile, outfile):
    """Run one PySpark worker task driven by the JVM over ``infile``/``outfile``.

    Reads, in order, from ``infile``:
      * the split index (``-1`` means unit-test mode and exits immediately),
      * the SparkFiles root directory and the Python include files to add
        to the import path,
      * broadcast-variable updates (additions and removals),
      * the pickled ``(func, stats, deserializer, serializer)`` command.

    It then streams the deserialized input through ``func`` and writes the
    serialized results to ``outfile``, optionally under ``cProfile`` when
    ``stats`` is set. Afterwards it writes timing information, the shuffle
    spill counters and the accumulator updates, and finally checks for the
    end-of-stream marker.

    :param infile: readable stream connected to the JVM.
    :param outfile: writable stream connected to the JVM.

    The process exits with status -1 (``SystemExit``) in the unit-test mode,
    when the task raises, or when the input does not end with
    ``SpecialLengths.END_OF_STREAM``.
    """
    try:
        boot_time = time.time()
        split_index = read_int(infile)
        if split_index == -1:  # for unit tests
            sys.exit(-1)

        # initialize global state
        shuffle.MemoryBytesSpilled = 0
        shuffle.DiskBytesSpilled = 0
        _accumulatorRegistry.clear()

        # fetch name of workdir
        spark_files_dir = utf8_deserializer.loads(infile)
        SparkFiles._root_directory = spark_files_dir
        SparkFiles._is_running_on_worker = True

        # fetch names of includes (*.zip and *.egg files) and construct PYTHONPATH
        add_path(spark_files_dir)  # *.py files that were added will be copied here
        num_python_includes = read_int(infile)
        for _ in range(num_python_includes):
            filename = utf8_deserializer.loads(infile)
            add_path(os.path.join(spark_files_dir, filename))

        # fetch names and values of broadcast variables; a negative id
        # encodes the removal of broadcast ``-bid - 1``
        num_broadcast_variables = read_int(infile)
        ser = CompressedSerializer(pickleSer)
        for _ in range(num_broadcast_variables):
            bid = read_long(infile)
            if bid >= 0:
                value = ser._read_with_length(infile)
                _broadcastRegistry[bid] = Broadcast(bid, value)
            else:
                bid = -bid - 1
                _broadcastRegistry.pop(bid)

        # Reset again right before loading the command, so that only
        # accumulators registered from here on are reported back below.
        _accumulatorRegistry.clear()
        command = pickleSer._read_with_length(infile)
        if isinstance(command, Broadcast):
            # the command may arrive wrapped in a Broadcast; unpickle its value
            command = pickleSer.loads(command.value)
        (func, stats, deserializer, serializer) = command
        init_time = time.time()

        def process():
            iterator = deserializer.load_stream(infile)
            serializer.dump_stream(func(split_index, iterator), outfile)

        if stats:
            p = cProfile.Profile()
            p.runcall(process)
            st = pstats.Stats(p)
            st.stream = None  # make it picklable
            stats.add(st.strip_dirs())
        else:
            process()
    except Exception:
        try:
            write_int(SpecialLengths.PYTHON_EXCEPTION_THROWN, outfile)
            # format_exc() returns text; encode it so raw bytes are written
            write_with_length(traceback.format_exc().encode("utf-8"), outfile)
        except IOError:
            # JVM closed the socket; there is nobody left to report to
            pass
        except Exception:
            # Write the error to stderr if it happened while serializing
            print("PySpark worker failed with exception:", file=sys.stderr)
            print(traceback.format_exc(), file=sys.stderr)
        sys.exit(-1)
    finish_time = time.time()
    report_times(outfile, boot_time, init_time, finish_time)
    write_long(shuffle.MemoryBytesSpilled, outfile)
    write_long(shuffle.DiskBytesSpilled, outfile)

    # Mark the beginning of the accumulators section of the output
    write_int(SpecialLengths.END_OF_DATA_SECTION, outfile)
    write_int(len(_accumulatorRegistry), outfile)
    for (aid, accum) in _accumulatorRegistry.items():
        pickleSer._write_with_length((aid, accum._value), outfile)

    # check end of stream
    if read_int(infile) == SpecialLengths.END_OF_STREAM:
        write_int(SpecialLengths.END_OF_STREAM, outfile)
    else:
        # write a different value to tell JVM to not reuse this worker
        write_int(SpecialLengths.END_OF_DATA_SECTION, outfile)
        sys.exit(-1)