def _task_fn(index, driver_addresses, settings):
    """Entry point executed in each Spark task (MPI-only variant).

    Starts a ``SparkTaskService``, registers it with the driver, participates
    in the circular interface-discovery handshake, then waits for the MPI
    command to run and returns its result.

    :param index: global index of this task
    :param driver_addresses: addresses of the Spark driver service
    :param settings: run settings (key, timeout, num_proc, verbose, ...)
    :return: result produced by the task function (``task.fn_result()``)
    """
    task = task_service.SparkTaskService(index, settings.key)
    try:
        driver_client = driver_service.SparkDriverClient(driver_addresses, settings.key, settings.verbose)
        # Tell the driver where this task's service listens and on which host it runs.
        driver_client.register_task(index, task.addresses(), host_hash.host_hash())
        # Block until the driver has seen all tasks register.
        task.wait_for_initial_registration(settings.timeout)
        # Tasks ping each other in a circular fashion to determine interfaces reachable within
        # the cluster.
        next_task_index = (index + 1) % settings.num_proc
        next_task_addresses = driver_client.all_task_addresses(next_task_index)
        # We request interface matching to weed out all the NAT'ed interfaces.
        next_task_client = \
            task_service.SparkTaskClient(next_task_index, next_task_addresses,
                                         settings.key, settings.verbose,
                                         match_intf=True)
        # Report back which of the neighbour's addresses were actually reachable.
        driver_client.register_task_to_task_addresses(next_task_index, next_task_client.addresses())
        task_indices_on_this_host = driver_client.task_host_hash_indices(
            host_hash.host_hash())
        if task_indices_on_this_host[0] == index:
            # Task with first index will execute orted that will run mpirun_exec_fn for all tasks.
            task.wait_for_command_start(settings.timeout)
            task.wait_for_command_termination()
        else:
            # The rest of tasks need to wait for the first task to finish.
            first_task_addresses = driver_client.all_task_addresses(task_indices_on_this_host[0])
            first_task_client = \
                task_service.SparkTaskClient(task_indices_on_this_host[0],
                                            first_task_addresses, settings.key,
                                            settings.verbose)
            first_task_client.wait_for_command_termination()
        return task.fn_result()
    finally:
        # Always stop the task service, even if registration or the command failed.
        task.shutdown()
def _task_fn(index, driver_addresses, key, settings, use_gloo):
    """Entry point executed in each Spark task (Gloo/MPI variant).

    Starts a ``SparkTaskService``, registers with the driver, then either
    waits for the command itself (Gloo, or first MPI task on this host) or
    waits for the host's first task to finish (other MPI tasks).

    :param index: global index of this task
    :param driver_addresses: addresses of the Spark driver service
    :param key: secret key, transported separately via Spark RPC (see below)
    :param settings: run settings (nics, timeout, verbose, ...); key is re-attached here
    :param use_gloo: True to run with Gloo, False to run with MPI
    :return: result produced by the task function (``task.fn_result()``)
    """
    # deserialized on Spark workers, settings do not contain the key, so it is given here explicitly
    # Spark RPC communicates the key and supports encryption
    # for convenience, we put it back into settings
    settings.key = key
    task = task_service.SparkTaskService(index, settings.key, settings.nics, settings.verbose)
    try:
        driver_client = driver_service.SparkDriverClient(
            driver_addresses, settings.key, settings.verbose)
        # Tell the driver where this task's service listens and on which host it runs.
        driver_client.register_task(index, task.addresses(), host_hash.host_hash())
        # Block until the driver has seen all tasks register.
        task.wait_for_initial_registration(settings.timeout)
        task_indices_on_this_host = driver_client.task_host_hash_indices(
            host_hash.host_hash())
        # With Gloo all tasks wait for the command
        # With MPI task with first index executes orted which will run mpirun_exec_fn for all tasks.
        if use_gloo or task_indices_on_this_host[0] == index:
            task.wait_for_command_start(settings.timeout)
            task.wait_for_command_termination()
        else:
            # The rest of tasks need to wait for the first task to finish.
            first_task_addresses = driver_client.all_task_addresses(
                task_indices_on_this_host[0])
            first_task_client = \
                task_service.SparkTaskClient(task_indices_on_this_host[0],
                                             first_task_addresses, settings.key,
                                             settings.verbose)
            first_task_client.wait_for_command_termination()
        return task.fn_result()
    finally:
        # Always stop the task service, even if registration or the command failed.
        task.shutdown()
def _task_fn(index, driver_addresses, key, settings, use_gloo):
    """Entry point executed in each Spark task (Gloo/MPI variant with minimum lifetime).

    Like the plain Gloo/MPI variant, but after the command terminates this
    task stays alive for at least ``MINIMUM_COMMAND_LIFETIME_S`` seconds
    measured from command start, so clients that launched the command have
    time to reconnect and fetch the exit code.

    :param index: global index of this task
    :param driver_addresses: addresses of the Spark driver service
    :param key: secret key, transported separately via Spark RPC (see below)
    :param settings: run settings (nics, timeout, verbose, ...); key is re-attached here
    :param use_gloo: True to run with Gloo, False to run with MPI
    :return: result produced by the task function (``task.fn_result()``)
    """
    # deserialized on Spark workers, settings do not contain the key, so it is given here explicitly
    # Spark RPC communicates the key and supports encryption
    # for convenience, we put it back into settings
    settings.key = key
    task = task_service.SparkTaskService(index, settings.key, settings.nics, settings.verbose)
    try:
        driver_client = driver_service.SparkDriverClient(
            driver_addresses, settings.key, settings.verbose)
        # Tell the driver where this task's service listens and on which host it runs.
        driver_client.register_task(index, task.addresses(), host_hash.host_hash())
        # Block until the driver has seen all tasks register.
        task.wait_for_initial_registration(settings.timeout)
        task_indices_on_this_host = driver_client.task_host_hash_indices(
            host_hash.host_hash())
        # With Gloo all tasks wait for the command
        # With MPI task with first index executes orted which will run mpirun_exec_fn for all tasks.
        minimum_lifetime_after_start = None
        if use_gloo or task_indices_on_this_host[0] == index:
            task.wait_for_command_start(settings.timeout)
            # Start measuring the command's runtime; used below to enforce a
            # minimum task lifetime after command start.
            minimum_lifetime_after_start = timeout.Timeout(
                MINIMUM_COMMAND_LIFETIME_S, message='Just measuring runtime')
            task.wait_for_command_termination()
        else:
            # The rest of tasks need to wait for the first task to finish.
            first_task_addresses = driver_client.all_task_addresses(
                task_indices_on_this_host[0])
            first_task_client = \
                task_service.SparkTaskClient(task_indices_on_this_host[0],
                                             first_task_addresses, settings.key,
                                             settings.verbose)
            first_task_client.wait_for_command_termination()

        # command terminated, make sure this task service does not shutdown too quickly after
        # the client started the command as it needs some time to connect again
        # to wait for the result after starting the command (see horovod.spark.driver.rsh).
        if minimum_lifetime_after_start is not None:
            time.sleep(minimum_lifetime_after_start.remaining())

        return task.fn_result()
    finally:
        # this has to block on running requests (wait_for_command_exit_code)
        # so they can finish serving the exit code
        # shutdown does block with network.BasicService._server._block_on_close = True
        task.shutdown()
def _task_fn(index, driver_addresses, key, settings, use_gloo, is_elastic):
    """Entry point executed in each Spark task (elastic-capable variant).

    Starts a ``SparkTaskService`` and registers with the driver. In elastic
    mode every task (attempt) presents itself as its own host and loops
    waiting for either a command to run or a shutdown signal from the driver;
    otherwise behaves like the Gloo/MPI variant.

    :param index: global index of this task
    :param driver_addresses: addresses of the Spark driver service
    :param key: secret key, transported separately via Spark RPC (see below)
    :param settings: run settings (nics, start_timeout, verbose, ...); key is re-attached here
    :param use_gloo: True to run with Gloo, False to run with MPI
    :param is_elastic: True to run Elastic Horovod on Spark
    :return: result produced by the task function (``task.fn_result()``)
    """
    # deserialized on Spark workers, settings do not contain the key, so it is given here explicitly
    # Spark RPC communicates the key and supports encryption
    # for convenience, we put it back into settings
    settings.key = key

    # to simplify things, each task is an individual host in Elastic Horovod on Spark
    # further, each attempt (instance) of a task is an individual host in Elastic Horovod on Spark
    # hides availability of shared memory among executors on the same Spark node
    hosthash = host_hash(
        salt='{}-{}'.format(index, time.time()) if is_elastic else None)

    # provide host hash to mpirun_exec_fn.py via task service
    # gloo_exec_fn.py will get this env var set in request env as well
    os.environ['HOROVOD_HOSTNAME'] = hosthash

    # The minimum lifetime keeps the service alive long enough after command
    # start for clients to reconnect (only needed for elastic/Gloo runs).
    task = task_service.SparkTaskService(
        index, settings.key, settings.nics,
        MINIMUM_COMMAND_LIFETIME_S if is_elastic or use_gloo else None,
        settings.verbose)
    try:
        driver_client = driver_service.SparkDriverClient(
            driver_addresses, settings.key, settings.verbose)
        # Tell the driver where this task's service listens and on which host it runs.
        driver_client.register_task(index, task.addresses(), hosthash)

        if not is_elastic:
            # Non-elastic runs block until the driver has seen all tasks register.
            task.wait_for_initial_registration(settings.start_timeout)
            task_indices_on_this_host = driver_client.task_host_hash_indices(
                hosthash)
            local_rank_zero_index = task_indices_on_this_host[0]
        else:
            # Elastic: each task is its own host, so there is no local rank-zero peer.
            local_rank_zero_index = None

        # In elastic all tasks wait for task shutdown signal from driver.
        # With Gloo all tasks wait for the command to start and terminate.
        # With MPI task with first index executes orted which will run mpirun_exec_fn for all tasks.
        if is_elastic:
            # either terminate on task shutdown or command termination
            shutdown_thread = in_thread(driver_client.wait_for_task_shutdown)

            while shutdown_thread.is_alive():
                # Once the command started we wait for its termination
                if task.check_for_command_start(
                        WAIT_FOR_COMMAND_START_DELAY_SECONDS):
                    task.wait_for_command_termination()
                    if task.command_exit_code() != 0:
                        # Failing this Spark task makes Spark restart it, which
                        # re-attempts the Horovod task.
                        raise Exception(
                            'Command failed, making Spark task fail to restart the task'
                        )
                    break

                # While no command started, we can shutdown any time
                shutdown_thread.join(WAIT_FOR_SHUTDOWN_DELAY_SECONDS)
        elif use_gloo or index == local_rank_zero_index:
            # Either Gloo or first task with MPI.
            task.wait_for_command_start(settings.start_timeout)
            task.wait_for_command_termination()
        else:
            # The other tasks with MPI need to wait for the first task to finish.
            first_task_addresses = driver_client.all_task_addresses(
                local_rank_zero_index)
            first_task_client = \
                task_service.SparkTaskClient(local_rank_zero_index,
                                             first_task_addresses, settings.key,
                                             settings.verbose)
            first_task_client.wait_for_command_termination()

        return task.fn_result()
    finally:
        # we must not call into shutdown too quickly, task clients run a command
        # and want to wait on the result, we have told task service not to return
        # from wait_for_command_termination too quickly, so we are safe here to shutdown
        # clients have had enough time to connect to the service already
        #
        # the shutdown has to block on running requests (wait_for_command_exit_code)
        # so they can finish serving the exit code
        # shutdown does block with network.BasicService._server._block_on_close = True
        task.shutdown()