def _task_fn(index, driver_addresses, num_proc, tmout, key):
    task = task_service.TaskService(index, key)
    try:
        driver_client = driver_service.DriverClient(driver_addresses, key)
        driver_client.register_task(index, task.addresses(), host_hash.host_hash())
        task.wait_for_initial_registration(tmout)

        # Tasks ping each other in a circular fashion to determine the interfaces
        # reachable within the cluster.
        next_task_index = (index + 1) % num_proc
        next_task_addresses = driver_client.all_task_addresses(next_task_index)
        # We request interface matching to weed out all the NAT'ed interfaces.
        next_task_client = task_service.TaskClient(next_task_index, next_task_addresses,
                                                   key, match_intf=True)
        driver_client.register_task_to_task_addresses(next_task_index,
                                                      next_task_client.addresses())

        task_indices_on_this_host = driver_client.task_host_hash_indices(
            host_hash.host_hash())
        if task_indices_on_this_host[0] == index:
            # The task with the first index on this host executes orted, which in turn
            # runs mpirun_exec_fn for all tasks on the host.
            task.wait_for_command_start(tmout)
            task.wait_for_command_termination()
        else:
            # The remaining tasks wait for the first task to finish.
            first_task_addresses = driver_client.all_task_addresses(
                task_indices_on_this_host[0])
            first_task_client = task_service.TaskClient(
                task_indices_on_this_host[0], first_task_addresses, key)
            first_task_client.wait_for_command_termination()
        return task.fn_result()
    finally:
        task.shutdown()
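The circular ping above is worth pausing on: each task only probes its successor at `(index + 1) % num_proc`, yet collectively every task is probed exactly once, so the driver ends up with verified, routable addresses for the whole ring. A minimal standalone sketch of that wiring, with plain integers standing in for the task clients:

# Sketch of the ring wiring: task i probes task (i + 1) % num_proc, so every task
# is probed exactly once and the probes jointly cover the whole cluster.
num_proc = 4
probes = [(index, (index + 1) % num_proc) for index in range(num_proc)]
assert sorted(target for _, target in probes) == list(range(num_proc))
print(probes)  # [(0, 1), (1, 2), (2, 3), (3, 0)]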
def main(driver_addresses, host_hash, command):
    if ':' in host_hash:
        raise Exception('Illegal host hash provided. Are you using Open MPI 4.0.0+?')

    key = codec.loads_base64(os.environ[secret.HOROVOD_SECRET_KEY])
    driver_client = driver_service.DriverClient(driver_addresses, key)
    task_indices = driver_client.task_host_hash_indices(host_hash)
    # Since tasks with the same host hash have shared memory, we will run only
    # one ORTED process on the first task.
    first_task_index = task_indices[0]
    task_addresses = driver_client.all_task_addresses(first_task_index)
    task_client = task_service.TaskClient(first_task_index, task_addresses, key)
    task_client.run_command(command, os.environ)
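For context on how `main` above receives its arguments: Open MPI invokes the configured `plm_rsh_agent` much like it would invoke ssh, passing the target "host" (here, a host hash) followed by the orted command to run there. The `__main__` wiring below is a hedged sketch of that plumbing, not the exact horovod.spark.driver.mpirun_rsh entry point; argument quoting is omitted for brevity:

import sys  # the surrounding module is assumed to already import codec, secret, etc.

if __name__ == '__main__':
    # sys.argv[1] is the base64-encoded driver address list baked into the plm_rsh_agent
    # string by the mpirun command, sys.argv[2] is the "host" (a host hash), and the
    # remaining arguments form the orted command line that would normally run over ssh.
    if len(sys.argv) < 4:
        print('Usage: %s <encoded driver addresses> <host hash> <command...>' % sys.argv[0])
        sys.exit(1)
    main(codec.loads_base64(sys.argv[1]), sys.argv[2], ' '.join(sys.argv[3:]))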
def main(driver_addresses):
    # Die if parent process terminates
    bg = threading.Thread(target=parent_process_monitor, args=(os.getppid(),))
    bg.daemon = True
    bg.start()

    key = codec.loads_base64(os.environ[secret.HOROVOD_SECRET_KEY])
    rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
    driver_client = driver_service.DriverClient(driver_addresses, key)
    task_index = driver_client.task_index_by_rank(rank)
    task_addresses = driver_client.all_task_addresses(task_index)
    task_client = task_service.TaskClient(task_index, task_addresses, key)

    fn, args, kwargs = driver_client.code()
    result = fn(*args, **kwargs)
    task_client.register_code_result(result)
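`parent_process_monitor` is referenced above but not shown. A minimal sketch of what such a watchdog could look like is below; the polling interval and exit behavior are assumptions, not necessarily the actual Horovod implementation:

import os
import time


def parent_process_monitor(initial_ppid):
    # If the orted process that spawned us dies, this process is re-parented
    # (typically to PID 1), so the parent PID changes and we exit immediately
    # to avoid leaking an orphaned worker.
    while True:
        if os.getppid() != initial_ppid:
            os._exit(1)
        time.sleep(1)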
def run(fn, args=(), kwargs={}, num_proc=None, start_timeout=None,
        env=None, stdout=None, stderr=None, verbose=1):
    """
    Runs Horovod in Spark.  Runs `num_proc` processes executing `fn` using the same
    number of Spark tasks.

    Args:
        fn: Function to run.
        args: Arguments to pass to `fn`.
        kwargs: Keyword arguments to pass to `fn`.
        num_proc: Number of Horovod processes.  Defaults to `spark.default.parallelism`.
        start_timeout: Timeout for Spark tasks to spawn, register and start running the code,
                       in seconds.  If not set, falls back to the `HOROVOD_SPARK_START_TIMEOUT`
                       environment variable value.  If it is not set as well, defaults to
                       600 seconds.
        env: Environment dictionary to use in Horovod run.  Defaults to `os.environ`.
        stdout: Horovod stdout is redirected to this stream.  Defaults to sys.stdout.
        stderr: Horovod stderr is redirected to this stream.  Defaults to sys.stderr.
        verbose: Debug output verbosity (0-2).  Defaults to 1.

    Returns:
        List of results returned by running `fn` on each rank.
    """
    spark_context = pyspark.SparkContext._active_spark_context
    if spark_context is None:
        raise Exception('Could not find an active SparkContext, are you '
                        'running in a PySpark session?')

    if num_proc is None:
        num_proc = spark_context.defaultParallelism
        if verbose >= 1:
            print('Running %d processes (inferred from spark.default.parallelism)...' % num_proc)
    else:
        if verbose >= 1:
            print('Running %d processes...' % num_proc)

    if start_timeout is None:
        # Lookup default timeout from the environment variable.
        start_timeout = int(os.getenv('HOROVOD_SPARK_START_TIMEOUT', '600'))

    result_queue = queue.Queue(1)
    tmout = timeout.Timeout(start_timeout)
    key = secret.make_secret_key()
    spark_job_group = 'horovod.spark.run.%d' % job_id.next_job_id()
    driver = driver_service.DriverService(num_proc, fn, args, kwargs, key)
    spark_thread = _make_spark_thread(spark_context, spark_job_group, num_proc,
                                      driver, tmout, key, result_queue)
    try:
        driver.wait_for_initial_registration(tmout)
        if verbose >= 2:
            print('Initial Spark task registration is complete.')
        task_clients = [task_service.TaskClient(index,
                                                driver.task_addresses_for_driver(index),
                                                key)
                        for index in range(num_proc)]
        for task_client in task_clients:
            task_client.notify_initial_registration_complete()
        driver.wait_for_task_to_task_address_updates(tmout)
        if verbose >= 2:
            print('Spark task-to-task address registration is complete.')

        # Determine a set of common interfaces for task-to-task communication.
        common_intfs = set(driver.task_addresses_for_tasks(0).keys())
        for index in range(1, num_proc):
            common_intfs.intersection_update(driver.task_addresses_for_tasks(index).keys())
        if not common_intfs:
            raise Exception('Unable to find a set of common task-to-task communication '
                            'interfaces: %s'
                            % [(index, driver.task_addresses_for_tasks(index))
                               for index in range(num_proc)])

        # Determine the index grouping based on host hashes.
        # Barrel shift until index 0 is in the first host.
        host_hashes = list(driver.task_host_hash_indices().keys())
        host_hashes.sort()
        while 0 not in driver.task_host_hash_indices()[host_hashes[0]]:
            host_hashes = host_hashes[1:] + host_hashes[:1]

        ranks_to_indices = []
        for host_hash in host_hashes:
            ranks_to_indices += driver.task_host_hash_indices()[host_hash]
        driver.set_ranks_to_indices(ranks_to_indices)

        if env is None:
            env = os.environ.copy()

        # Pass secret key through the environment variables.
        env[secret.HOROVOD_SECRET_KEY] = codec.dumps_base64(key)

        mpirun_command = (
            'mpirun --allow-run-as-root --tag-output '
            '-np {num_proc} -H {hosts} '
            '-bind-to none -map-by slot '
            '-mca pml ob1 -mca btl ^openib -mca btl_tcp_if_include {common_intfs} '
            '-x NCCL_DEBUG=INFO -x NCCL_SOCKET_IFNAME={common_intfs} '
            '{env} '  # expect a lot of environment variables
            '-mca plm_rsh_agent "{python} -m horovod.spark.driver.mpirun_rsh {encoded_driver_addresses}" '
            '{python} -m horovod.spark.task.mpirun_exec_fn {encoded_driver_addresses} '
            .format(num_proc=num_proc,
                    hosts=','.join('%s:%d' % (host_hash,
                                              len(driver.task_host_hash_indices()[host_hash]))
                                   for host_hash in host_hashes),
                    common_intfs=','.join(common_intfs),
                    env=' '.join('-x %s' % key for key in env.keys()),
                    python=sys.executable,
                    encoded_driver_addresses=codec.dumps_base64(driver.addresses())))
        if verbose >= 2:
            print('+ %s' % mpirun_command)
        exit_code = safe_shell_exec.execute(mpirun_command, env, stdout, stderr)
        if exit_code != 0:
            raise Exception('mpirun exited with code %d, see the error above.' % exit_code)
    except:
        # Terminate Spark job.
        spark_context.cancelJobGroup(spark_job_group)

        # Re-raise exception.
        raise
    finally:
        spark_thread.join()
        driver.shutdown()

    # Make sure Spark Job did not fail.
    driver.check_for_spark_job_failure()

    # If there's no exception, execution results are in this queue.
    results = result_queue.get_nowait()
    return [results[index] for index in ranks_to_indices]
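Putting it all together, a typical invocation from a PySpark session could look like the sketch below. The training function and its Horovod calls are illustrative placeholders; only `horovod.spark.run` itself comes from the code above:

import horovod.spark


def train():
    # Each Spark task runs this function as one Horovod rank.
    import horovod.tensorflow as hvd
    hvd.init()
    # ... build and train a model, e.g. scaling the learning rate by hvd.size() ...
    return hvd.rank()


# One Horovod process per Spark task; results are returned ordered by rank.
results = horovod.spark.run(train, num_proc=4)
print(results)  # e.g. [0, 1, 2, 3]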