def _save_mat(partition):
    """Save the rows of one partition into a per-partition ``.mat`` file.

    Each row is unpacked as ``(fpath, idx, mbytes)``. The values in ``mbytes``
    followed by ``get_label(fpath)`` become one row of the saved matrix, which
    is stored under the key ``"packets"`` with dtype ``'B'``. The output file
    is named ``part_<partition id, 5 digits>.mat``.

    Raises:
        AssertionError: if any element of ``mbytes`` is not an ``int``.
    """
    task_ctx = TaskContext()

    rows = []
    for fpath, idx, mbytes in partition:
        # Every value has to be an int before being packed with dtype 'B'.
        assert all(isinstance(value, int) for value in mbytes)
        rows.append(mbytes + [get_label(fpath)])

    packets = np.array(rows, dtype=np.dtype('B'))

    # NOTE(review): `fp` and `outfile` come from an enclosing scope; `fp`
    # presumably wraps a path and joins with `+` -- confirm.
    part_name = "part_%05d.mat" % task_ctx.partitionId()
    target = fp(outfile) + fp(part_name)
    savemat(target.path(), {"packets": packets})
def func_with_open_process_close(partition_id: Any, iterator: Iterator) -> Iterator:
    """Drive the ``open`` / ``process`` / ``close`` lifecycle of ``f`` for one partition.

    The batch id is read from the task-local property
    ``"streaming.sql.batchId"``. When ``open_exists`` is set, ``f.open`` decides
    whether the rows are processed. When ``close_exists`` is set, ``f.close``
    is always called with the captured exception (or ``None``), after which
    the exception, if any, is re-raised.

    Returns:
        An empty iterator.

    Raises:
        RuntimeError: if the batch id property is missing or empty.
    """
    task_context = cast(TaskContext, TaskContext.get())
    batch_id = task_context.getLocalProperty("streaming.sql.batchId")
    if not batch_id:
        raise RuntimeError("Could not get batch id from TaskContext")
    int_epoch_id = int(batch_id)

    # Without an open() hook the rows are always processed.
    if open_exists:
        should_process = f.open(partition_id, int_epoch_id)  # type: ignore[union-attr]
    else:
        should_process = True

    failure = None
    try:
        if should_process:
            for record in iterator:
                cast("SupportsProcess", f).process(record)
    except Exception as ex:
        # Remember the failure so close() can see it before it is re-raised.
        failure = ex
    finally:
        if close_exists:
            f.close(failure)  # type: ignore[union-attr]
        if failure:
            raise failure

    return iter([])
def ff(iter):
    """Write this partition to ``partDataPath`` and, on partition 0, run the MPI job.

    Every task writes its elements (one ``str(element)`` per line) to
    ``partDataPath`` and then sleeps one second as a stand-in for a barrier.
    Only the task whose partition id is 0 writes the host file, launches
    ``mpiRunPath`` through the shell and yields the command's stdout; all
    other tasks yield nothing.

    Args:
        iter: iterator over the elements of the partition.

    Yields:
        The captured stdout of the MPI command (partition 0 only).

    Raises:
        Exception: if the MPI command exits with a non-zero return code.
    """
    def _as_text(data):
        # communicate() returns bytes on Python 3; decode so the error message
        # can be built by string concatenation on both Python 2 and 3.
        return data if isinstance(data, str) else data.decode("utf-8", "replace")

    partitionId = TaskContext.get().partitionId()
    with open(partDataPath, "w") as fp:
        fp.writelines(str(i) + "\n" for i in iter)

    # We need a barrier here.
    # Sleep 1s for now to ensure the data file is generated on all nodes.
    time.sleep(1)

    if partitionId == 0:
        with open(hostFilePath, "w") as fp:
            fp.write(hosts)

        # NOTE: MPI requires every node process to run in the same working
        # directory, so `cd /tmp/` is prepended and every process runs in
        # `/tmp`. Without this, the default directory may not exist on other
        # nodes and cause an error.
        mpiCmd = "cd /tmp/;" + mpiRunPath + " -n " + str(numTasks) + " -f " + \
            hostFilePath + " " + mpiProgPath + " " + \
            partDataPath + " " + outputDataPath
        prc = Popen(mpiCmd, stdout=PIPE, stderr=PIPE, shell=True)
        stdout, stderr = prc.communicate()
        if prc.returncode != 0:
            # Call-style raise: the `raise Exception, "msg"` form is a
            # SyntaxError on Python 3.
            raise Exception(
                "cmd:\n" + mpiCmd + "\ncmd ouput:\n" + _as_text(stdout) +
                "\ncmd err\n: " + _as_text(stderr))

        # Data is still read from stdout; later this should read from a local
        # file, but that requires mapping process 0 to the mpirun node.
        yield stdout
def maybe_create_eggroll_client():
    """
    A tricky way to set up the eggroll client which may be used by spark tasks.

    The task-local property named by ``_EGGROLL_CLIENT`` holds a hex string;
    it is decoded and unpickled into ``(mode, eggroll_session)``, which is then
    handed to ``build_eggroll_runtime``.

    WARN: This may be removed or adjusted in future!
    """
    import pickle
    from pyspark.taskcontext import TaskContext

    hex_payload = TaskContext.get().getLocalProperty(_EGGROLL_CLIENT)
    payload = bytes.fromhex(hex_payload)
    mode, eggroll_session = pickle.loads(payload)

    work_mode = WorkMode(mode)
    build_eggroll_runtime(work_mode, eggroll_session)
def launchHorovodMPI(featureArrayFile, labelsFile):
    """Launch the Horovod MNIST training through ``mpirun`` on partition 0.

    Only the task whose partition id is 0 runs the command; every other task
    falls through and returns ``None``.

    Args:
        featureArrayFile: currently unused; intended to be passed to the MPI
            command later.
        labelsFile: currently unused; intended to be passed to the MPI
            command later.

    Returns:
        The captured stdout of the ``mpirun`` command on partition 0,
        otherwise ``None``.

    Raises:
        Exception: if the command exits with a non-zero return code.
    """
    def _as_text(data):
        # communicate() returns bytes on Python 3; decode so the error message
        # can be built by string concatenation on both Python 2 and 3.
        return data if isinstance(data, str) else data.decode("utf-8", "replace")

    # Later the two file path args will be passed to the mpi cmd.
    partitionId = TaskContext.get().partitionId()
    if partitionId == 0:
        # NOTE: MPI requires every node process to run in the same working
        # directory, so `cd /tmp/` is prepended and every process runs in
        # `/tmp`. Without this, the default directory may not exist on other
        # nodes and cause an error.
        mpiCmd = "cd /tmp/;mpirun -np 4 -H localhost:4 -bind-to none -map-by slot python hvd_run_mnist_training"
        prc = Popen(mpiCmd, stdout=PIPE, stderr=PIPE, shell=True)
        stdout, stderr = prc.communicate()
        if prc.returncode != 0:
            # Call-style raise: the `raise Exception, "msg"` form is a
            # SyntaxError on Python 3.
            raise Exception(
                "cmd:\n" + mpiCmd + "\ncmd ouput:\n" + _as_text(stdout) +
                "\ncmd err\n: " + _as_text(stderr))

        # Data is still read from stdout; later this should read from a local
        # file, but that requires mapping process 0 to the mpirun node.
        return stdout
def maybe_create_eggroll_client():
    """
    A tricky way to set up the eggroll client which may be used by spark tasks.

    The task-local property named by ``_EGGROLL_CLIENT`` holds a hex string;
    it is decoded and unpickled into ``(mode, eggroll_session)``. For mode
    ``1`` a cluster ``_EggRoll`` runtime is created and registered on the
    session, but only when ``_EggRoll.instance`` is still ``None``. For any
    other mode a ``Standalone`` runtime is constructed from the session.

    WARN: This may be removed or adjusted in future!
    """
    import pickle
    from pyspark.taskcontext import TaskContext

    hex_payload = TaskContext.get().getLocalProperty(_EGGROLL_CLIENT)
    mode, eggroll_session = pickle.loads(bytes.fromhex(hex_payload))

    if mode == 1:
        from eggroll.api.cluster.eggroll import _EggRoll

        # Only build a runtime when none has been registered yet.
        if _EggRoll.instance is None:
            from eggroll.api import ComputingEngine

            runtime = _EggRoll(eggroll_session=eggroll_session)
            eggroll_session.set_runtime(ComputingEngine.EGGROLL_DTABLE, runtime)
        return

    from eggroll.api.standalone.eggroll import Standalone

    Standalone(eggroll_session)
def test_spark_worker(monkeypatch, sentry_init, capture_events, capture_exceptions):
    """Check that a failing worker ``main`` is reported with task-context tags.

    ``pyspark.worker.main`` is replaced by a stub that fills in the task
    context ids, raises ``ZeroDivisionError`` and exits with ``-1``. The
    resulting event must contain exactly that one exception and the four
    task-context tags.
    """
    import pyspark.worker as original_worker
    import pyspark.daemon as original_daemon

    from pyspark.taskcontext import TaskContext

    task_context = TaskContext._getOrCreate()

    def mock_main():
        # Same assignment order as the tags are checked in below.
        for attr_name, attr_value in (
            ("_stageId", 0),
            ("_attemptNumber", 1),
            ("_partitionId", 2),
            ("_taskAttemptId", 3),
        ):
            setattr(task_context, attr_name, attr_value)

        try:
            raise ZeroDivisionError
        except ZeroDivisionError:
            sys.exit(-1)

    monkeypatch.setattr(original_worker, "main", mock_main)
    sentry_init(integrations=[SparkWorkerIntegration()])

    events = capture_events()
    exceptions = capture_exceptions()

    original_daemon.worker_main()

    # SystemExit called, but not recorded as part of event
    last_exception = exceptions.pop()
    assert type(last_exception) == SystemExit

    first_event = events[0]
    recorded = first_event["exception"]["values"]
    assert len(recorded) == 1
    assert recorded[0]["type"] == "ZeroDivisionError"

    expected_tags = {
        "stageId": 0,
        "attemptNumber": 1,
        "partitionId": 2,
        "taskAttemptId": 3,
    }
    assert first_event["tags"] == expected_tags
def process_event(event, hint):
    # type: (Event, Hint) -> Optional[Event]
    """Attach Spark task-context information to ``event``.

    When both the ``SparkWorkerIntegration`` and a task context are available,
    the stage, partition, attempt and task-attempt ids are added as tags
    (existing values are kept). If the task's local properties carry
    ``sentry_app_name`` the app name and application id tags are added, and
    ``callSite.short`` is copied to ``extra["callSite"]``. All of this runs
    under ``capture_internal_exceptions()``.

    Returns:
        The same ``event`` object.
    """
    with capture_internal_exceptions():
        integration = Hub.current.get_integration(SparkWorkerIntegration)
        task_context = TaskContext.get()

        if integration is None or task_context is None:
            return event

        tags = event.setdefault("tags", {})
        tags.setdefault("stageId", task_context.stageId())
        tags.setdefault("partitionId", task_context.partitionId())
        tags.setdefault("attemptNumber", task_context.attemptNumber())
        tags.setdefault("taskAttemptId", task_context.taskAttemptId())

        local_props = task_context._localProperties
        if local_props:
            if "sentry_app_name" in local_props:
                tags.setdefault("app_name", local_props["sentry_app_name"])
                tags.setdefault(
                    "application_id", local_props["sentry_application_id"]
                )

            if "callSite.short" in local_props:
                extra = event.setdefault("extra", {})
                extra.setdefault("callSite", local_props["callSite.short"])

    return event
def main(infile, outfile):
    """Run one task in the worker: read the task setup from ``infile``, execute
    the received function and write results and bookkeeping to ``outfile``.

    The values are read from ``infile`` in a fixed order: split index, Python
    version, barrier info, task context ids and local properties, the Spark
    files directory and Python includes, broadcast variables, the eval type
    and finally the command / UDFs. The order of the reads must not change.

    On any ``Exception`` during setup or processing, ``PYTHON_EXCEPTION_THROWN``
    and the formatted traceback are written to ``outfile`` and the process
    exits with ``-1``. On success, timing, spill counters and accumulator
    values are written, followed by the end-of-stream handshake.
    """
    try:
        boot_time = time.time()
        split_index = read_int(infile)
        if split_index == -1:  # for unit tests
            sys.exit(-1)

        # The worker must run the same major.minor Python version as the driver.
        version = utf8_deserializer.loads(infile)
        if version != "%d.%d" % sys.version_info[:2]:
            raise Exception(("Python in worker has different version %s than that in " +
                             "driver %s, PySpark cannot run with different minor versions." +
                             "Please check environment variables PYSPARK_PYTHON and " +
                             "PYSPARK_DRIVER_PYTHON are correctly set.") %
                            ("%d.%d" % sys.version_info[:2], version))

        # read inputs only for a barrier task
        # (the three values are always read, but only used when isBarrier is set)
        isBarrier = read_bool(infile)
        boundPort = read_int(infile)
        secret = UTF8Deserializer().loads(infile)

        # set up memory limits
        memory_limit_mb = int(os.environ.get('PYSPARK_EXECUTOR_MEMORY_MB', "-1"))
        if memory_limit_mb > 0 and has_resource_module:
            total_memory = resource.RLIMIT_AS
            try:
                (soft_limit, hard_limit) = resource.getrlimit(total_memory)
                msg = "Current mem limits: {0} of max {1}\n".format(soft_limit, hard_limit)
                print(msg, file=sys.stderr)

                # convert to bytes
                new_limit = memory_limit_mb * 1024 * 1024

                # Only ever lower the limit; an existing tighter soft limit is kept.
                if soft_limit == resource.RLIM_INFINITY or new_limit < soft_limit:
                    msg = "Setting mem limits to {0} of max {1}\n".format(new_limit, new_limit)
                    print(msg, file=sys.stderr)
                    resource.setrlimit(total_memory, (new_limit, new_limit))

            except (resource.error, OSError, ValueError) as e:
                # not all systems support resource limits, so warn instead of failing
                print("WARN: Failed to set memory limit: {0}\n".format(e), file=sys.stderr)

        # initialize global state
        taskContext = None
        if isBarrier:
            taskContext = BarrierTaskContext._getOrCreate()
            BarrierTaskContext._initialize(boundPort, secret)
        else:
            taskContext = TaskContext._getOrCreate()
        # read inputs for TaskContext info
        taskContext._stageId = read_int(infile)
        taskContext._partitionId = read_int(infile)
        taskContext._attemptNumber = read_int(infile)
        taskContext._taskAttemptId = read_long(infile)
        # Local properties arrive as a count followed by that many key/value pairs.
        taskContext._localProperties = dict()
        for i in range(read_int(infile)):
            k = utf8_deserializer.loads(infile)
            v = utf8_deserializer.loads(infile)
            taskContext._localProperties[k] = v

        # Reset per-task counters and accumulators.
        shuffle.MemoryBytesSpilled = 0
        shuffle.DiskBytesSpilled = 0
        _accumulatorRegistry.clear()

        # fetch name of workdir
        spark_files_dir = utf8_deserializer.loads(infile)
        SparkFiles._root_directory = spark_files_dir
        SparkFiles._is_running_on_worker = True

        # fetch names of includes (*.zip and *.egg files) and construct PYTHONPATH
        add_path(spark_files_dir)  # *.py files that were added will be copied here
        num_python_includes = read_int(infile)
        for _ in range(num_python_includes):
            filename = utf8_deserializer.loads(infile)
            add_path(os.path.join(spark_files_dir, filename))
        if sys.version > '3':
            import importlib
            importlib.invalidate_caches()

        # fetch names and values of broadcast variables
        needs_broadcast_decryption_server = read_bool(infile)
        num_broadcast_variables = read_int(infile)
        if needs_broadcast_decryption_server:
            # read the decrypted data from a server in the jvm
            port = read_int(infile)
            auth_secret = utf8_deserializer.loads(infile)
            (broadcast_sock_file, _) = local_connect_and_auth(port, auth_secret)

        for _ in range(num_broadcast_variables):
            bid = read_long(infile)
            if bid >= 0:
                # Non-negative id: register a broadcast variable.
                if needs_broadcast_decryption_server:
                    read_bid = read_long(broadcast_sock_file)
                    assert(read_bid == bid)
                    _broadcastRegistry[bid] = \
                        Broadcast(sock_file=broadcast_sock_file)
                else:
                    path = utf8_deserializer.loads(infile)
                    _broadcastRegistry[bid] = Broadcast(path=path)

            else:
                # Negative id: remove the broadcast whose id is (-bid - 1).
                bid = - bid - 1
                _broadcastRegistry.pop(bid)

        if needs_broadcast_decryption_server:
            # A single byte is written before the socket file is closed.
            broadcast_sock_file.write(b'1')
            broadcast_sock_file.close()

        _accumulatorRegistry.clear()
        eval_type = read_int(infile)
        if eval_type == PythonEvalType.NON_UDF:
            func, profiler, deserializer, serializer = read_command(pickleSer, infile)
        else:
            func, profiler, deserializer, serializer = read_udfs(pickleSer, infile, eval_type)

        init_time = time.time()

        def process():
            # Stream input rows through func and write the results to outfile.
            iterator = deserializer.load_stream(infile)
            serializer.dump_stream(func(split_index, iterator), outfile)

        if profiler:
            profiler.profile(process)
        else:
            process()
    except Exception:
        try:
            # Report the failure and its traceback through outfile.
            write_int(SpecialLengths.PYTHON_EXCEPTION_THROWN, outfile)
            write_with_length(traceback.format_exc().encode("utf-8"), outfile)
        except IOError:
            # JVM close the socket
            pass
        except Exception:
            # Write the error to stderr if it happened while serializing
            print("PySpark worker failed with exception:", file=sys.stderr)
            print(traceback.format_exc(), file=sys.stderr)
        sys.exit(-1)
    finish_time = time.time()
    report_times(outfile, boot_time, init_time, finish_time)
    write_long(shuffle.MemoryBytesSpilled, outfile)
    write_long(shuffle.DiskBytesSpilled, outfile)

    # Mark the beginning of the accumulators section of the output
    write_int(SpecialLengths.END_OF_DATA_SECTION, outfile)
    write_int(len(_accumulatorRegistry), outfile)
    for (aid, accum) in _accumulatorRegistry.items():
        pickleSer._write_with_length((aid, accum._value), outfile)

    # check end of stream
    if read_int(infile) == SpecialLengths.END_OF_STREAM:
        write_int(SpecialLengths.END_OF_STREAM, outfile)
    else:
        # write a different value to tell JVM to not reuse this worker
        write_int(SpecialLengths.END_OF_DATA_SECTION, outfile)
        sys.exit(-1)
def main(infile, outfile):
    """Run one task in the worker: read the task setup from ``infile``, execute
    the received function and write results and bookkeeping to ``outfile``.

    The values are read from ``infile`` in a fixed order: split index, Python
    version, task context ids, the Spark files directory and Python includes,
    broadcast variables, the eval type and finally the command / UDFs. The
    order of the reads must not change.

    On any ``Exception`` during setup or processing, ``PYTHON_EXCEPTION_THROWN``
    and the formatted traceback are written to ``outfile`` and the process
    exits with ``-1``. On success, timing, spill counters and accumulator
    values are written, followed by the end-of-stream handshake.
    """
    try:
        boot_time = time.time()
        split_index = read_int(infile)
        if split_index == -1:  # for unit tests
            sys.exit(-1)

        # The worker must run the same major.minor Python version as the driver.
        version = utf8_deserializer.loads(infile)
        if version != "%d.%d" % sys.version_info[:2]:
            raise Exception(
                ("Python in worker has different version %s than that in " +
                 "driver %s, PySpark cannot run with different minor versions."
                 + "Please check environment variables PYSPARK_PYTHON and " +
                 "PYSPARK_DRIVER_PYTHON are correctly set.") %
                ("%d.%d" % sys.version_info[:2], version))

        # initialize global state
        taskContext = TaskContext._getOrCreate()
        taskContext._stageId = read_int(infile)
        taskContext._partitionId = read_int(infile)
        taskContext._attemptNumber = read_int(infile)
        taskContext._taskAttemptId = read_long(infile)
        # Reset per-task counters and accumulators.
        shuffle.MemoryBytesSpilled = 0
        shuffle.DiskBytesSpilled = 0
        _accumulatorRegistry.clear()

        # fetch name of workdir
        spark_files_dir = utf8_deserializer.loads(infile)
        SparkFiles._root_directory = spark_files_dir
        SparkFiles._is_running_on_worker = True

        # fetch names of includes (*.zip and *.egg files) and construct PYTHONPATH
        add_path(
            spark_files_dir)  # *.py files that were added will be copied here
        num_python_includes = read_int(infile)
        for _ in range(num_python_includes):
            filename = utf8_deserializer.loads(infile)
            add_path(os.path.join(spark_files_dir, filename))
        if sys.version > '3':
            import importlib
            importlib.invalidate_caches()

        # fetch names and values of broadcast variables
        num_broadcast_variables = read_int(infile)
        for _ in range(num_broadcast_variables):
            bid = read_long(infile)
            if bid >= 0:
                # Non-negative id: register a broadcast variable read from a path.
                path = utf8_deserializer.loads(infile)
                _broadcastRegistry[bid] = Broadcast(path=path)
            else:
                # Negative id: remove the broadcast whose id is (-bid - 1).
                bid = -bid - 1
                _broadcastRegistry.pop(bid)

        _accumulatorRegistry.clear()
        eval_type = read_int(infile)
        if eval_type == PythonEvalType.NON_UDF:
            func, profiler, deserializer, serializer = read_command(
                pickleSer, infile)
        else:
            func, profiler, deserializer, serializer = read_udfs(
                pickleSer, infile, eval_type)

        init_time = time.time()

        def process():
            # Stream input rows through func and write the results to outfile.
            iterator = deserializer.load_stream(infile)
            serializer.dump_stream(func(split_index, iterator), outfile)

        if profiler:
            profiler.profile(process)
        else:
            process()
    except Exception:
        try:
            # Report the failure and its traceback through outfile.
            write_int(SpecialLengths.PYTHON_EXCEPTION_THROWN, outfile)
            write_with_length(traceback.format_exc().encode("utf-8"), outfile)
        except IOError:
            # JVM close the socket
            pass
        except Exception:
            # Write the error to stderr if it happened while serializing
            print("PySpark worker failed with exception:", file=sys.stderr)
            print(traceback.format_exc(), file=sys.stderr)
        sys.exit(-1)
    finish_time = time.time()
    report_times(outfile, boot_time, init_time, finish_time)
    write_long(shuffle.MemoryBytesSpilled, outfile)
    write_long(shuffle.DiskBytesSpilled, outfile)

    # Mark the beginning of the accumulators section of the output
    write_int(SpecialLengths.END_OF_DATA_SECTION, outfile)
    write_int(len(_accumulatorRegistry), outfile)
    for (aid, accum) in _accumulatorRegistry.items():
        pickleSer._write_with_length((aid, accum._value), outfile)

    # check end of stream
    if read_int(infile) == SpecialLengths.END_OF_STREAM:
        write_int(SpecialLengths.END_OF_STREAM, outfile)
    else:
        # write a different value to tell JVM to not reuse this worker
        write_int(SpecialLengths.END_OF_DATA_SECTION, outfile)
        sys.exit(-1)
def main(infile, outfile):
    """Run one task in the worker: read the task setup from ``infile``, execute
    the received function and write results and bookkeeping to ``outfile``.

    The values are read from ``infile`` in a fixed order: split index, Python
    version, barrier info, task context ids, cpus, resources and local
    properties, the Spark files directory and Python includes, broadcast
    variables, the eval type and finally the command / UDFs. The order of the
    reads must not change.

    When the ``PYTHON_FAULTHANDLER_DIR`` environment variable is set,
    ``faulthandler`` output is directed to a file named after the process id
    in that directory; the file is removed again in the ``finally`` clause.

    On any ``BaseException`` during setup or processing,
    ``PYTHON_EXCEPTION_THROWN`` and the (optionally simplified) traceback are
    written to ``outfile`` and the process exits with ``-1``. On success,
    timing, spill counters and accumulator values are written, followed by the
    end-of-stream handshake.
    """
    faulthandler_log_path = os.environ.get("PYTHON_FAULTHANDLER_DIR", None)
    try:
        if faulthandler_log_path:
            faulthandler_log_path = os.path.join(faulthandler_log_path, str(os.getpid()))
            faulthandler_log_file = open(faulthandler_log_path, "w")
            faulthandler.enable(file=faulthandler_log_file)

        boot_time = time.time()
        split_index = read_int(infile)
        if split_index == -1:  # for unit tests
            sys.exit(-1)

        # The worker must run the same major.minor Python version as the driver.
        version = utf8_deserializer.loads(infile)
        if version != "%d.%d" % sys.version_info[:2]:
            raise RuntimeError(
                (
                    "Python in worker has different version %s than that in "
                    + "driver %s, PySpark cannot run with different minor versions. "
                    + "Please check environment variables PYSPARK_PYTHON and "
                    + "PYSPARK_DRIVER_PYTHON are correctly set."
                )
                % ("%d.%d" % sys.version_info[:2], version)
            )

        # read inputs only for a barrier task
        # (the three values are always read, but only used when isBarrier is set)
        isBarrier = read_bool(infile)
        boundPort = read_int(infile)
        secret = UTF8Deserializer().loads(infile)

        # set up memory limits
        memory_limit_mb = int(os.environ.get("PYSPARK_EXECUTOR_MEMORY_MB", "-1"))
        if memory_limit_mb > 0 and has_resource_module:
            total_memory = resource.RLIMIT_AS
            try:
                (soft_limit, hard_limit) = resource.getrlimit(total_memory)
                msg = "Current mem limits: {0} of max {1}\n".format(soft_limit, hard_limit)
                print(msg, file=sys.stderr)

                # convert to bytes
                new_limit = memory_limit_mb * 1024 * 1024

                # Only ever lower the limit; an existing tighter soft limit is kept.
                if soft_limit == resource.RLIM_INFINITY or new_limit < soft_limit:
                    msg = "Setting mem limits to {0} of max {1}\n".format(new_limit, new_limit)
                    print(msg, file=sys.stderr)
                    resource.setrlimit(total_memory, (new_limit, new_limit))

            except (resource.error, OSError, ValueError) as e:
                # not all systems support resource limits, so warn instead of failing
                lineno = (
                    getframeinfo(currentframe()).lineno + 1 if currentframe() is not None else 0
                )
                print(
                    warnings.formatwarning(
                        "Failed to set memory limit: {0}".format(e),
                        ResourceWarning,
                        __file__,
                        lineno,
                    ),
                    file=sys.stderr,
                )

        # initialize global state
        taskContext = None
        if isBarrier:
            taskContext = BarrierTaskContext._getOrCreate()
            BarrierTaskContext._initialize(boundPort, secret)
            # Set the task context instance here, so we can get it by TaskContext.get for
            # both TaskContext and BarrierTaskContext
            TaskContext._setTaskContext(taskContext)
        else:
            taskContext = TaskContext._getOrCreate()
        # read inputs for TaskContext info
        taskContext._stageId = read_int(infile)
        taskContext._partitionId = read_int(infile)
        taskContext._attemptNumber = read_int(infile)
        taskContext._taskAttemptId = read_long(infile)
        taskContext._cpus = read_int(infile)
        # Resources arrive as a count followed by (key, name, address list)
        # entries. The dict is created once, before the loop: re-creating it
        # per entry would drop every resource except the last one.
        taskContext._resources = {}
        for r in range(read_int(infile)):
            key = utf8_deserializer.loads(infile)
            name = utf8_deserializer.loads(infile)
            addresses = []
            for a in range(read_int(infile)):
                addresses.append(utf8_deserializer.loads(infile))
            taskContext._resources[key] = ResourceInformation(name, addresses)

        # Local properties arrive as a count followed by that many key/value pairs.
        taskContext._localProperties = dict()
        for i in range(read_int(infile)):
            k = utf8_deserializer.loads(infile)
            v = utf8_deserializer.loads(infile)
            taskContext._localProperties[k] = v

        # Reset per-task counters and accumulators.
        shuffle.MemoryBytesSpilled = 0
        shuffle.DiskBytesSpilled = 0
        _accumulatorRegistry.clear()

        # fetch name of workdir
        spark_files_dir = utf8_deserializer.loads(infile)
        SparkFiles._root_directory = spark_files_dir
        SparkFiles._is_running_on_worker = True

        # fetch names of includes (*.zip and *.egg files) and construct PYTHONPATH
        add_path(spark_files_dir)  # *.py files that were added will be copied here
        num_python_includes = read_int(infile)
        for _ in range(num_python_includes):
            filename = utf8_deserializer.loads(infile)
            add_path(os.path.join(spark_files_dir, filename))

        importlib.invalidate_caches()

        # fetch names and values of broadcast variables
        needs_broadcast_decryption_server = read_bool(infile)
        num_broadcast_variables = read_int(infile)
        if needs_broadcast_decryption_server:
            # read the decrypted data from a server in the jvm
            port = read_int(infile)
            auth_secret = utf8_deserializer.loads(infile)
            (broadcast_sock_file, _) = local_connect_and_auth(port, auth_secret)

        for _ in range(num_broadcast_variables):
            bid = read_long(infile)
            if bid >= 0:
                # Non-negative id: register a broadcast variable.
                if needs_broadcast_decryption_server:
                    read_bid = read_long(broadcast_sock_file)
                    assert read_bid == bid
                    _broadcastRegistry[bid] = Broadcast(sock_file=broadcast_sock_file)
                else:
                    path = utf8_deserializer.loads(infile)
                    _broadcastRegistry[bid] = Broadcast(path=path)

            else:
                # Negative id: remove the broadcast whose id is (-bid - 1).
                bid = -bid - 1
                _broadcastRegistry.pop(bid)

        if needs_broadcast_decryption_server:
            # A single byte is written before the socket file is closed.
            broadcast_sock_file.write(b"1")
            broadcast_sock_file.close()

        _accumulatorRegistry.clear()
        eval_type = read_int(infile)
        if eval_type == PythonEvalType.NON_UDF:
            func, profiler, deserializer, serializer = read_command(pickleSer, infile)
        else:
            func, profiler, deserializer, serializer = read_udfs(pickleSer, infile, eval_type)

        init_time = time.time()

        def process():
            # Stream input rows through func and write the results to outfile;
            # close the output iterator afterwards when it supports close().
            iterator = deserializer.load_stream(infile)
            out_iter = func(split_index, iterator)
            try:
                serializer.dump_stream(out_iter, outfile)
            finally:
                if hasattr(out_iter, "close"):
                    out_iter.close()

        if profiler:
            profiler.profile(process)
        else:
            process()

        # Reset task context to None. This is a guard code to avoid residual context when worker
        # reuse.
        TaskContext._setTaskContext(None)
        BarrierTaskContext._setTaskContext(None)
    except BaseException as e:
        try:
            exc_info = None
            if os.environ.get("SPARK_SIMPLIFIED_TRACEBACK", False):
                tb = try_simplify_traceback(sys.exc_info()[-1])
                if tb is not None:
                    e.__cause__ = None
                    exc_info = "".join(traceback.format_exception(type(e), e, tb))
            # Fall back to the full traceback when no simplified one is available.
            if exc_info is None:
                exc_info = traceback.format_exc()

            write_int(SpecialLengths.PYTHON_EXCEPTION_THROWN, outfile)
            write_with_length(exc_info.encode("utf-8"), outfile)
        except IOError:
            # JVM close the socket
            pass
        except BaseException:
            # Write the error to stderr if it happened while serializing
            print("PySpark worker failed with exception:", file=sys.stderr)
            print(traceback.format_exc(), file=sys.stderr)
        sys.exit(-1)
    finally:
        if faulthandler_log_path:
            faulthandler.disable()
            faulthandler_log_file.close()
            os.remove(faulthandler_log_path)
    finish_time = time.time()
    report_times(outfile, boot_time, init_time, finish_time)
    write_long(shuffle.MemoryBytesSpilled, outfile)
    write_long(shuffle.DiskBytesSpilled, outfile)

    # Mark the beginning of the accumulators section of the output
    write_int(SpecialLengths.END_OF_DATA_SECTION, outfile)
    write_int(len(_accumulatorRegistry), outfile)
    for (aid, accum) in _accumulatorRegistry.items():
        pickleSer._write_with_length((aid, accum._value), outfile)

    # check end of stream
    if read_int(infile) == SpecialLengths.END_OF_STREAM:
        write_int(SpecialLengths.END_OF_STREAM, outfile)
    else:
        # write a different value to tell JVM to not reuse this worker
        write_int(SpecialLengths.END_OF_DATA_SECTION, outfile)
        sys.exit(-1)
def main(infile, outfile):
    """Run one task in the worker: read the task setup from ``infile``, execute
    the received function and write results and bookkeeping to ``outfile``.

    The values are read from ``infile`` in a fixed order: split index, Python
    version, task context ids, the Spark files directory and Python includes,
    broadcast variables, the eval type and finally the command / UDFs. The
    order of the reads must not change.

    On any ``Exception`` during setup or processing, ``PYTHON_EXCEPTION_THROWN``
    and the formatted traceback are written to ``outfile`` and the process
    exits with ``-1``. On success, timing, spill counters and accumulator
    values are written, followed by the end-of-stream handshake.

    Exits go through ``sys.exit``: the bare ``exit`` name is injected by the
    ``site`` module for interactive use and is not guaranteed to exist.
    """
    try:
        boot_time = time.time()
        split_index = read_int(infile)
        if split_index == -1:  # for unit tests
            sys.exit(-1)

        # The worker must run the same major.minor Python version as the driver.
        version = utf8_deserializer.loads(infile)
        if version != "%d.%d" % sys.version_info[:2]:
            raise Exception(("Python in worker has different version %s than that in " +
                             "driver %s, PySpark cannot run with different minor versions." +
                             "Please check environment variables PYSPARK_PYTHON and " +
                             "PYSPARK_DRIVER_PYTHON are correctly set.") %
                            ("%d.%d" % sys.version_info[:2], version))

        # initialize global state
        taskContext = TaskContext._getOrCreate()
        taskContext._stageId = read_int(infile)
        taskContext._partitionId = read_int(infile)
        taskContext._attemptNumber = read_int(infile)
        taskContext._taskAttemptId = read_long(infile)
        # Reset per-task counters and accumulators.
        shuffle.MemoryBytesSpilled = 0
        shuffle.DiskBytesSpilled = 0
        _accumulatorRegistry.clear()

        # fetch name of workdir
        spark_files_dir = utf8_deserializer.loads(infile)
        SparkFiles._root_directory = spark_files_dir
        SparkFiles._is_running_on_worker = True

        # fetch names of includes (*.zip and *.egg files) and construct PYTHONPATH
        add_path(spark_files_dir)  # *.py files that were added will be copied here
        num_python_includes = read_int(infile)
        for _ in range(num_python_includes):
            filename = utf8_deserializer.loads(infile)
            add_path(os.path.join(spark_files_dir, filename))
        if sys.version > '3':
            import importlib
            importlib.invalidate_caches()

        # fetch names and values of broadcast variables
        num_broadcast_variables = read_int(infile)
        for _ in range(num_broadcast_variables):
            bid = read_long(infile)
            if bid >= 0:
                # Non-negative id: register a broadcast variable read from a path.
                path = utf8_deserializer.loads(infile)
                _broadcastRegistry[bid] = Broadcast(path=path)
            else:
                # Negative id: remove the broadcast whose id is (-bid - 1).
                bid = - bid - 1
                _broadcastRegistry.pop(bid)

        _accumulatorRegistry.clear()
        eval_type = read_int(infile)
        if eval_type == PythonEvalType.NON_UDF:
            func, profiler, deserializer, serializer = read_command(pickleSer, infile)
        else:
            func, profiler, deserializer, serializer = read_udfs(pickleSer, infile, eval_type)

        init_time = time.time()

        def process():
            # Stream input rows through func and write the results to outfile.
            iterator = deserializer.load_stream(infile)
            serializer.dump_stream(func(split_index, iterator), outfile)

        if profiler:
            profiler.profile(process)
        else:
            process()
    except Exception:
        try:
            # Report the failure and its traceback through outfile.
            write_int(SpecialLengths.PYTHON_EXCEPTION_THROWN, outfile)
            write_with_length(traceback.format_exc().encode("utf-8"), outfile)
        except IOError:
            # JVM close the socket
            pass
        except Exception:
            # Write the error to stderr if it happened while serializing
            print("PySpark worker failed with exception:", file=sys.stderr)
            print(traceback.format_exc(), file=sys.stderr)
        sys.exit(-1)
    finish_time = time.time()
    report_times(outfile, boot_time, init_time, finish_time)
    write_long(shuffle.MemoryBytesSpilled, outfile)
    write_long(shuffle.DiskBytesSpilled, outfile)

    # Mark the beginning of the accumulators section of the output
    write_int(SpecialLengths.END_OF_DATA_SECTION, outfile)
    write_int(len(_accumulatorRegistry), outfile)
    for (aid, accum) in _accumulatorRegistry.items():
        pickleSer._write_with_length((aid, accum._value), outfile)

    # check end of stream
    if read_int(infile) == SpecialLengths.END_OF_STREAM:
        write_int(SpecialLengths.END_OF_STREAM, outfile)
    else:
        # write a different value to tell JVM to not reuse this worker
        write_int(SpecialLengths.END_OF_DATA_SECTION, outfile)
        sys.exit(-1)
def runHorovodMPI(iter):
    """Save the partition's data, then run Horovod training via ``mpirun`` on partition 0.

    Every task converts each pandas DataFrame of its partition into a pyarrow
    table and saves it to a fixed local path, then waits on the task context's
    barrier. The task with partition id 0 builds the host list from
    ``taskCtx.hosts()``, writes a rank file, launches ``mpirun`` and copies the
    first entry of the model export directory to ``destModelDir``. All tasks
    meet on a second barrier before returning.

    Args:
        iter: iterator over the pandas DataFrames of this partition (the code
            assumes there is only one).

    Returns:
        ``[destModelDir]`` on partition 0, ``[]`` on every other partition.

    Raises:
        Exception: if ``mpirun`` or the model copy command exits non-zero.
    """
    def _as_text(data):
        # communicate() returns bytes on Python 3; decode so the error message
        # can be built by string concatenation on both Python 2 and 3.
        return data if isinstance(data, str) else data.decode("utf-8", "replace")

    def _run_shell(cmd):
        # Run `cmd` through the shell and raise with its output on failure.
        prc = Popen(cmd, stdout=PIPE, stderr=PIPE, shell=True)
        stdout, stderr = prc.communicate()
        if prc.returncode != 0:
            raise Exception(
                "cmd:\n" + cmd + "\ncmd ouput:\n" + _as_text(stdout) +
                "\ncmd err\n: " + _as_text(stderr))

    taskCtx = TaskContext.get()

    # Assume only one element in the iterator, so the file name is fixed for now.
    dataFilePath = "/tmp/mpiInputData"
    modelExportDir = "/tmp/modelExportDir_" + str(random.randint(0, 2 << 30))
    # NOTE: change this to be a dbfs path.
    destModelDir = "/tmp/model_" + str(random.randint(0, 2 << 30))

    for pdf in iter:
        table = pa.Table.from_pandas(pdf)
        # Later the pyarrow table will be obtained directly from the RDD.
        save_pyarrow_table(table, dataFilePath)

    taskCtx.barrier()

    partitionID = taskCtx.partitionId()
    if partitionID != 0:
        taskCtx.barrier()
        return []

    hostsList = [i.split(":")[0] for i in taskCtx.hosts()]
    localHost = hostsList[0]  # need a new API
    numProc = len(hostsList)

    # Move the local host to be the first one. With localHost taken from
    # hostsList[0] this is currently a no-op placeholder.
    for i in range(0, numProc):
        if localHost == hostsList[i]:
            hostsList[0], hostsList[i] = hostsList[i], hostsList[0]
            break

    # Do not generate a host file, use the simpler -H param instead.
    hostsListParam = ",".join(hostsList)

    # Generate the rank file, one rank entry per line. It is not passed to
    # mpirun at the moment.
    rankFilePath = "/tmp/rankfile"
    with open(rankFilePath, "w") as rf:
        for i in range(0, numProc):
            rf.write("rank %d=%s slot=0-4\n" % (i, hostsList[i]))

    # NOTE: remember to change this to the real path.
    mpiProgPath = "/tmp/hvd_run_mnist_training.py"

    # NOTE: the mpi working dir is "/tmp"; the horovod estimator will generate
    # the checkpoint dir `mnist_convnet_model_${RANDOM_NUMBER}` in it.
    # NOTE: remember to add `sudo -u ubuntu` when running on a databricks
    # cluster and to change the python path.
    mpiCmd = "mpirun --wdir %s -np %d -H %s python %s %s %s" % (
        "/tmp", numProc, hostsListParam,
        mpiProgPath, dataFilePath, modelExportDir)
    _run_shell(mpiCmd)

    # Get the inner dir; os.path.join supplies the separator that plain string
    # concatenation was missing.
    modelDir = os.path.join(modelExportDir, os.listdir(modelExportDir)[0])
    copyModelCmd = "cp -r %s %s" % (modelDir, destModelDir)
    # A failure here reports the copy command, not the mpirun command.
    _run_shell(copyModelCmd)

    taskCtx.barrier()
    return [destModelDir]